2025.3assignment/24ALexNet_02.py

116 lines
4.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torchvision.datasets
from matplotlib import pyplot as plt
from torch import nn
from torch.utils.data import DataLoader
# Number of samples per batch for both the training and the test loader.
batch_size = 10
train_data = torchvision.datasets.MNIST(
    root="./data", train=True, download=True,
    transform=torchvision.transforms.ToTensor(),
)
test_data = torchvision.datasets.MNIST(
    root="./data", train=False, download=True,
    transform=torchvision.transforms.ToTensor(),
)
# Reshuffle the training set on every pass so each epoch does not see the
# batches in the same fixed order; evaluation order does not matter.
train_dataloader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(dataset=test_data, batch_size=batch_size)
# # Show the first image of the training set
# plt.imshow(train_data.data[0].numpy(),cmap='gray')
# plt.title('%i' % train_data.targets[0])
# plt.show()
class AlexNet(nn.Module):
    """AlexNet-style CNN for small square images (MNIST by default).

    Every convolution keeps the spatial size ("same" padding), and each of the
    three stages ends in ``MaxPool2d(kernel_size=3, stride=1)`` without
    padding, which shrinks each side by 2.  An ``image_size`` x ``image_size``
    input therefore reaches the classifier as a
    ``256 x (image_size - 6) x (image_size - 6)`` feature map (28 -> 22 with
    the defaults).

    The defaults reproduce the original hard-coded network (1 input channel,
    10 classes, 28x28 images), and the submodule names (conv1, conv2, conv3,
    flatten, end) are unchanged, so the ``state_dict`` keys are the same as
    before.

    Args:
        in_channels: Channels of the input images (1 for MNIST, 3 for RGB).
        num_classes: Number of output logits.
        image_size: Side length of the square input images.

    Raises:
        ValueError: If ``image_size`` is smaller than 7, in which case the
            three pooling stages would leave no spatial extent.
    """

    def __init__(self, in_channels=1, num_classes=10, image_size=28):
        super().__init__()
        # Three unpadded 3x3 stride-1 max-pools each remove 2 pixels per side.
        final_size = image_size - 6
        if final_size < 1:
            raise ValueError(f"image_size must be at least 7, got {image_size}")
        # Original author's notes (translated):
        #   2. switch to 3 input channels  -> now the `in_channels` argument
        #   3. difference between passing the net to functions vs. writing
        #      a new model as a class
        #   4. use batches
        # Spatial sizes in the comments below are for the default 28x28 input.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=96,
                      kernel_size=3, stride=1, padding=1),  # 28x28
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=1),  # 26x26
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, padding=2),  # 26x26
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=1),  # 24x24
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=1),  # 22x22
        )
        self.flatten = nn.Flatten()
        self.end = nn.Sequential(
            nn.Linear(in_features=256 * final_size * final_size, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=num_classes),
        )

    def forward(self, x):
        """Map a batch of images to raw (unnormalized) class logits.

        ``x`` must be shaped (batch, in_channels, image_size, image_size): the
        first convolution fixes the channel count and the first Linear layer
        fixes the flattened size.  Returns a (batch, num_classes) tensor.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.flatten(x)
        return self.end(x)
# Module-level model, loss and optimizer; loss_cross and optim are referenced
# from module scope by train() below.
net = AlexNet()
# Cross-entropy applied to raw logits (the model's last layer is a plain Linear).
loss_cross = nn.CrossEntropyLoss()
# Adam over the parameters of this particular `net` instance only.
optim = torch.optim.Adam(params=net.parameters(), lr=1e-3)
def train(net, train_data, optimizer=None, loss_fn=None):
    """Run one training pass of ``net`` over ``train_data`` and return ``net``.

    For every batch this prints the number of correct predictions, the batch
    accuracy and the loss, then takes one optimizer step.

    Args:
        net: Model to train; called as ``net(data)`` and expected to return
            logits of shape (batch, num_classes).
        train_data: Iterable of ``(data, targets)`` batches, e.g. a DataLoader.
        optimizer: Optimizer over ``net``'s parameters.  Defaults to the
            module-level ``optim`` for backward compatibility.  That optimizer
            only holds the parameters of the module-level ``net``, so pass one
            explicitly when training any other model; otherwise that model's
            weights are never updated.
        loss_fn: Callable ``loss_fn(logits, targets)``.  Defaults to the
            module-level ``loss_cross``.

    Returns:
        The same ``net`` object; it is trained in place.
    """
    if optimizer is None:
        optimizer = optim
    if loss_fn is None:
        loss_fn = loss_cross
    # Put train/eval-sensitive layers back into training mode, e.g. after the
    # model was passed to yanzheng(), which calls net.eval().
    net.train()
    for data, targets in train_data:
        outputs = net(data)
        # Divide by the real batch length, not the global batch_size, so a
        # smaller final batch is reported correctly.  Using .item() turns the
        # counts into Python numbers, so the division is done in double
        # precision (e.g. 3/10 prints as 0.3, not as a float32 tensor value).
        correct = (outputs.argmax(1) == targets).sum().item()
        accuracy = correct / targets.size(0)
        print("正确数:{},正确率:{}".format(correct, accuracy))
        # With the default nn.CrossEntropyLoss() this is the mean
        # cross-entropy over the batch.
        loss = loss_fn(outputs, targets)
        print("loss:{}".format(loss.item()))
        # Clear stale gradients before backward so nothing left over from an
        # earlier backward pass is accumulated into this step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return net
def yanzheng(net, test_data):
    """Evaluate ``net`` on ``test_data`` and return the overall accuracy.

    ("yanzheng" is pinyin for "verify/validate".)  Prints the correct count
    and accuracy of every batch, then the totals over all batches.  Runs under
    ``torch.no_grad()`` and leaves ``net`` in eval mode.

    Args:
        net: Model returning logits of shape (batch, num_classes).
        test_data: Iterable of ``(data, targets)`` batches, e.g. a DataLoader.

    Returns:
        Fraction of correctly classified samples (a float in [0, 1]); 0.0 if
        ``test_data`` yields no samples.
    """
    net.eval()
    total_correct = 0
    total_seen = 0
    with torch.no_grad():
        for data, targets in test_data:
            outputs = net(data)
            correct = (outputs.argmax(1) == targets).sum().item()
            # Use the real batch length: the last batch may be smaller than
            # the global batch_size.
            batch_len = targets.size(0)
            print("正确数:", correct, " ", "正确率", correct / batch_len)
            total_correct += correct
            total_seen += batch_len
    if total_seen == 0:
        return 0.0
    accuracy = total_correct / total_seen
    print("总正确数:", total_correct, " ", "总正确率", accuracy)
    return accuracy
# Train the model (one pass over the training set).  The return value need
# not be used: the parameters of `net` are updated in place and persist for
# the rest of this program run.
#net=train(net,train_dataloader)
# Save the model.  Saving the whole module (torch.save(net, ...)) requires the
# AlexNet class to be defined/importable in the program that loads it.
# (Loading a library-defined model such as VGG directly should also work.)
# The state_dict approach likewise needs the model class in the loading
# program, whether it is someone else's predefined model or your own class
# imported with `from <module> import *` or pasted in.
#torch.save(net,"ALexNet_02.path")
# Load the model.  The file holds a whole pickled nn.Module, so
# weights_only=False is needed: PyTorch >= 2.6 defaults to weights_only=True
# and refuses to unpickle arbitrary classes.  Only load checkpoints you trust,
# because unpickling can execute arbitrary code.  (".pth" is the usual
# extension; the file name is left as-is so a previously saved file is found.)
# NOTE(review): consider an `if __name__ == "__main__":` guard so importing
# this module (e.g. just to get AlexNet) does not trigger loading/evaluation.
net_loading = torch.load("ALexNet_02.path", weights_only=False)
# Evaluate on the test set.
yanzheng(net_loading, test_dataloader)