1.6 KiB
1.6 KiB
按照图示完成多任务学习的网络结构，其中部分网络层共享参数。
In [1]:
import torch from torch import nn
/home/zhaojh/miniconda3/envs/py37/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html from .autonotebook import tqdm as notebook_tqdm
In [ ]:
class MTLNN(nn.Module):
    """Skeleton of a multi-task learning network.

    The notebook's task statement asks for a network in which some layers
    share parameters across tasks. The layer layout comes from a figure
    that is not part of this file, so only the container for the main
    branch is created here; the concrete layers still have to be added.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self):
        """Register the (currently empty) main branch container."""
        super().__init__()
        # The original cell stopped at a dangling ``nn.`` inside the
        # Sequential call, which is a syntax error. An empty Sequential keeps
        # the class importable without guessing at the architecture.
        # TODO: add the layers of the main branch according to the figure.
        # NOTE(review): presumably this is the branch whose parameters are
        # shared between tasks — confirm against the figure.
        self.main_branch = nn.Sequential()