# students_git_repo/周家林/TCN.py

import torch
import torch.nn as nn
from torch.nn.utils import weight_norm  # utility for weight normalization


# Dilated causal convolution module
class DilatedCausalConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation):
        """
        in_channels: number of input channels
        out_channels: number of output channels
        kernel_size: convolution kernel size
        dilation: dilation factor (controls the receptive field)
        """
        super().__init__()
        # Padding needed for a causal convolution: (kernel_size - 1) * dilation.
        # This preserves temporal causality (no leakage of future information).
        self.padding = (kernel_size - 1) * dilation
        # 1D convolution layer with weight normalization. Note that nn.Conv1d
        # pads both sides symmetrically; the surplus padding on the right is
        # trimmed in forward() to make the convolution causal.
        self.conv = weight_norm(
            nn.Conv1d(in_channels,
                      out_channels,
                      kernel_size,
                      padding=self.padding,
                      dilation=dilation)
        )

    def forward(self, x):
        """
        Input shape:  (batch_size, in_channels, seq_len)
        Output shape: (batch_size, out_channels, seq_len)
        """
        x = self.conv(x)
        if self.padding == 0:  # kernel_size == 1: nothing to trim
            return x
        # Trim the surplus right-side padding so output length matches the input
        return x[:, :, :-self.padding]
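

# A minimal causality sanity check (added for illustration; the function name
# and values are assumptions, not part of the original file). Causality means
# output[t] must not change when inputs after time t change.
def _check_causality():
    layer = DilatedCausalConv1d(in_channels=1, out_channels=1, kernel_size=3, dilation=2)
    layer.eval()
    x = torch.randn(1, 1, 20)
    x_future = x.clone()
    x_future[:, :, 10:] += 1.0  # perturb only the "future" half of the sequence
    with torch.no_grad():
        y, y_future = layer(x), layer(x_future)
    # Outputs before the perturbation point must be identical
    assert torch.allclose(y[:, :, :10], y_future[:, :, :10])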


# Residual block module
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation, dropout=0.2):
        super().__init__()
        # First dilated causal convolution layer
        self.conv1 = DilatedCausalConv1d(in_channels, out_channels, kernel_size, dilation)
        # Second dilated causal convolution layer
        self.conv2 = DilatedCausalConv1d(out_channels, out_channels, kernel_size, dilation)
        # Shared components
        self.dropout = nn.Dropout(dropout)  # dropout layer
        self.relu = nn.ReLU()               # activation function
        # 1x1 convolution to match channel counts when input and output differ
        self.downsample = nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else None

    def forward(self, x):
        residual = x  # keep the original input for the residual connection
        # First stage (pre-activation order: dropout -> ReLU -> convolution)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.conv1(x)
        # Second stage
        x = self.dropout(x)
        x = self.relu(x)
        x = self.conv2(x)
        # Residual connection
        if self.downsample is not None:
            residual = self.downsample(residual)  # adjust channels via the 1x1 convolution
        return residual + x
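

# A quick shape check for the residual path (added for illustration; the
# function name is an assumption, not part of the original file). When
# in_channels != out_channels, the 1x1 downsample keeps the residual add valid.
def _demo_residual_block():
    block = ResidualBlock(in_channels=32, out_channels=64, kernel_size=3, dilation=1)
    y = block(torch.randn(8, 32, 50))
    assert y.shape == (8, 64, 50)  # channels adjusted, sequence length preserved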


# Full TCN model
class TCN(nn.Module):
    def __init__(self, input_size, num_channels, kernel_size=3, dropout=0.2):
        """
        input_size: input feature dimension (number of channels)
        num_channels: list of output channels per level (determines network depth)
        kernel_size: convolution kernel size
        dropout: dropout probability inside each residual block
        """
        super().__init__()
        layers = []  # holds all residual blocks
        num_levels = len(num_channels)  # number of levels in the network
        # Build the network level by level
        for i in range(num_levels):
            dilation = 2 ** i  # dilation grows exponentially: 2^0, 2^1, 2^2, ...
            in_channels = input_size if i == 0 else num_channels[i - 1]  # input channels for this level
            out_channels = num_channels[i]  # output channels for this level
            # Add a residual block
            layers.append(
                ResidualBlock(
                    in_channels,
                    out_channels,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    dropout=dropout
                )
            )
        # Chain all residual blocks into a sequential network
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """
        Input shape:  (batch_size, input_size, seq_len)
        Output shape: (batch_size, num_channels[-1], seq_len)
        """
        return self.network(x)
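

# Receptive-field helper (added for illustration; the function name is an
# assumption, not part of the original file). Each residual block applies two
# dilated convolutions, and each convolution with dilation d extends the
# receptive field by (kernel_size - 1) * d, so with dilations 2^0 .. 2^(L-1):
def receptive_field(num_levels, kernel_size):
    return 1 + 2 * (kernel_size - 1) * (2 ** num_levels - 1)
# e.g. receptive_field(3, 3) == 29 past time steps for the example below.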


# Example usage
if __name__ == "__main__":
    # Configuration
    batch_size = 32              # batch size
    seq_len = 100                # sequence length
    input_size = 64              # input feature dimension
    num_channels = [64, 64, 64]  # output channels per level: 3 levels, 64 channels each
    kernel_size = 3              # convolution kernel size

    # Build the model
    model = TCN(input_size, num_channels, kernel_size)
    # Generate test data
    x = torch.randn(batch_size, input_size, seq_len)  # random input
    # Forward pass
    output = model(x)
    # Verify the output shape (sequence length should match the input)
    print(f"Input shape: {x.shape}")        # (32, 64, 100)
    print(f"Output shape: {output.shape}")  # (32, 64, 100)