TCN framework code submission
commit eb1987cf9e
@@ -0,0 +1,134 @@
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm  # weight normalization utility
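# Note (not part of the original commit): recent PyTorch releases deprecate
# this import in favor of the parametrization-based variant:
#     from torch.nn.utils.parametrizations import weight_norm
# Both expose the same call signature used below.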

# Dilated causal convolution module
class DilatedCausalConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation):
        """
        in_channels: number of input channels
        out_channels: number of output channels
        kernel_size: convolution kernel size
        dilation: dilation factor (controls the receptive field)
        """
        super().__init__()
        # Padding required for a causal convolution: (kernel_size - 1) * dilation.
        # This preserves temporal causality (no future information leaks into the past).
        self.padding = (kernel_size - 1) * dilation

        # 1D convolution with weight normalization.
        # Note: nn.Conv1d pads BOTH sides by self.padding; the surplus on the
        # right is trimmed in forward() to keep the convolution causal.
        self.conv = weight_norm(
            nn.Conv1d(in_channels,
                      out_channels,
                      kernel_size,
                      padding=self.padding,
                      dilation=dilation)
        )

    def forward(self, x):
        """
        Input shape:  (batch_size, in_channels, seq_len)
        Output shape: (batch_size, out_channels, seq_len)
        """
        x = self.conv(x)
        # Trim the surplus right-side padding so the output length matches the
        # input length; guard against kernel_size == 1, where padding is 0 and
        # the slice x[:, :, :-0] would return an empty tensor.
        return x[:, :, :-self.padding] if self.padding > 0 else x
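
# A minimal causality check (a sketch, not part of the original commit):
# if the convolution is truly causal, perturbing timesteps >= t must leave
# every output before t unchanged. `t` is an arbitrary cut point.
def _check_causality(t: int = 10):
    layer = DilatedCausalConv1d(1, 1, kernel_size=3, dilation=2)
    a = torch.randn(1, 1, 20)
    b = a.clone()
    b[:, :, t:] += 1.0  # perturb the "future" part of the sequence
    with torch.no_grad():
        assert torch.allclose(layer(a)[:, :, :t], layer(b)[:, :, :t])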

# Residual block module
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation, dropout=0.2):
        super().__init__()
        # First convolution layer (weight normalization is applied inside)
        self.conv1 = DilatedCausalConv1d(in_channels, out_channels, kernel_size, dilation)
        # Second convolution layer
        self.conv2 = DilatedCausalConv1d(out_channels, out_channels, kernel_size, dilation)

        # Shared components
        self.dropout = nn.Dropout(dropout)  # dropout layer
        self.relu = nn.ReLU()               # activation function

        # When input and output channel counts differ, a 1x1 convolution
        # adapts the residual to the block's output width.
        self.downsample = nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else None

    def forward(self, x):
        residual = x  # keep the original input for the residual connection

        # First stage: dropout -> activation -> dilated causal convolution
        x = self.dropout(x)
        x = self.relu(x)
        x = self.conv1(x)

        # Second stage: same pattern
        x = self.dropout(x)
        x = self.relu(x)
        x = self.conv2(x)

        # Residual connection
        if self.downsample is not None:
            residual = self.downsample(residual)  # adapt channels via 1x1 conv
        return residual + x  # add the skip connection
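
# A quick shape check (a sketch, not part of the original commit): with
# mismatched channel counts the 1x1 downsample path must activate, while
# the sequence length stays unchanged.
def _check_residual_block():
    block = ResidualBlock(in_channels=16, out_channels=32, kernel_size=3, dilation=4)
    x = torch.randn(8, 16, 50)
    assert block(x).shape == (8, 32, 50)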

# Full TCN model
class TCN(nn.Module):
    def __init__(self, input_size, num_channels, kernel_size=3, dropout=0.2):
        """
        input_size: input feature dimension (number of channels)
        num_channels: list of per-layer output channels (sets the network depth)
        kernel_size: convolution kernel size
        dropout: dropout probability passed to each residual block
        """
        super().__init__()
        layers = []  # holds all residual blocks
        num_levels = len(num_channels)  # number of layers

        # Build the network layer by layer
        for i in range(num_levels):
            dilation = 2 ** i  # dilation grows exponentially (2^0, 2^1, 2^2, ...)
            in_channels = input_size if i == 0 else num_channels[i - 1]  # input channels of this layer
            out_channels = num_channels[i]  # output channels of this layer

            # Append one residual block
            layers += [
                ResidualBlock(
                    in_channels,
                    out_channels,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    dropout=dropout
                )
            ]

        # Chain all residual blocks into one sequential network
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """
        Input shape:  (batch_size, input_size, seq_len)
        Output shape: (batch_size, num_channels[-1], seq_len)
        """
        return self.network(x)
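
# Receptive field of the stack (a worked note, not part of the original
# commit): each block applies two causal convolutions with the same dilation,
# so with kernel size k and L levels the last output sees
#     RF = 1 + 2 * (k - 1) * (2**L - 1)
# past timesteps. For k=3 and L=3 (the demo below): 1 + 2*2*7 = 29.
def receptive_field(kernel_size: int, num_levels: int) -> int:
    return 1 + 2 * (kernel_size - 1) * (2 ** num_levels - 1)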

# Example usage
if __name__ == "__main__":
    # Configuration
    batch_size = 32   # batch size
    seq_len = 100     # sequence length
    input_size = 64   # input feature dimension
    num_channels = [64, 64, 64]  # per-layer output channels (3 layers of 64 here)
    kernel_size = 3   # convolution kernel size

    # Build the model
    model = TCN(input_size, num_channels, kernel_size)

    # Generate test data
    x = torch.randn(batch_size, input_size, seq_len)  # random input

    # Forward pass
    output = model(x)

    # Verify shapes (output sequence length should match the input)
    print(f"Input shape: {x.shape}")    # (32, 64, 100)
    print(f"Output shape: {output.shape}")  # (32, 64, 100)
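
# Appended sketch (not part of the original commit): since the TCN output at
# the final timestep summarizes the entire causal history, a single linear
# head on top is a common way to produce one prediction per sequence.
# `output_size` is a hypothetical parameter added for illustration.
class TCNForecaster(nn.Module):
    def __init__(self, input_size, num_channels, output_size, kernel_size=3, dropout=0.2):
        super().__init__()
        self.tcn = TCN(input_size, num_channels, kernel_size, dropout)
        self.head = nn.Linear(num_channels[-1], output_size)

    def forward(self, x):
        # x: (batch, input_size, seq_len) -> features at the last timestep
        features = self.tcn(x)[:, :, -1]  # (batch, num_channels[-1])
        return self.head(features)        # (batch, output_size)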