22-T67/multi-task-learning.ipynb


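Following the diagram, complete the network structure for multi-task learning, in which some of the network layers share parameters.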

In [1]:
import torch
from torch import nn
In [ ]:
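class MTLNN(nn.Module):
    # Layer sizes and the two-head layout are assumptions; the original diagram is not available
    def __init__(self, in_dim=16, hidden_dim=32, out_a=1, out_b=1):
        super(MTLNN, self).__init__()
        # Shared trunk: these parameters are used (and trained) by every task
        self.main_branch = nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.ReLU())
        # Task-specific heads (hypothetical names): parameters not shared between tasks
        self.head_a, self.head_b = nn.Linear(hidden_dim, out_a), nn.Linear(hidden_dim, out_b)

    def forward(self, x):
        shared = self.main_branch(x)
        return self.head_a(shared), self.head_b(shared)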
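Because the trunk's parameters are shared, a single backward pass through the combined task losses accumulates gradient from both tasks in the same weights. The sketch below checks this on random data. It assumes an unweighted sum of two MSE losses, and the sizes, the synthetic targets, the helper `trunk_grad` and the plain-module layout (instead of the `MTLNN` class) are all illustrative assumptions, not part of the original exercise.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical sizes; the dimensions in the original diagram are unknown
in_dim, hidden_dim, batch = 8, 16, 4

# Hard parameter sharing: one trunk feeding two task-specific heads
trunk = nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.ReLU())
head_a = nn.Linear(hidden_dim, 1)
head_b = nn.Linear(hidden_dim, 1)

x = torch.randn(batch, in_dim)
y_a = torch.randn(batch, 1)  # synthetic targets for task A
y_b = torch.randn(batch, 1)  # synthetic targets for task B
mse = nn.MSELoss()


def trunk_grad(loss):
    """Hypothetical helper: gradient of `loss` w.r.t. the trunk's first weight matrix."""
    trunk.zero_grad()
    loss.backward()
    return trunk[0].weight.grad.clone()


# Gradient that each task alone sends into the shared layer
grad_a = trunk_grad(mse(head_a(trunk(x)), y_a))
grad_b = trunk_grad(mse(head_b(trunk(x)), y_b))

# Gradient from the joint loss (assumed: unweighted sum of the task losses)
h = trunk(x)
joint_loss = mse(head_a(h), y_a) + mse(head_b(h), y_b)
grad_joint = trunk_grad(joint_loss)

# The shared layer receives the sum of both tasks' gradients
shared_grad_is_sum = torch.allclose(grad_joint, grad_a + grad_b, atol=1e-6)
print(shared_grad_is_sum)  # prints True
```

In a training loop, the same `joint_loss.backward()` followed by an optimizer step over the trunk and both heads updates the shared layers with the combined signal, while each head only sees its own task's gradient. Weighting the task losses differently is a common variation, but the original task does not specify one.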