# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader

import tqdm
from logzero import logger

batch_size = 256
|
|
random_seed = 1884
|
|
DEVICE = torch.device('cpu')
|
|
torch.manual_seed(random_seed)
|
|
|
|
|
|
def load_data():
|
|
train_loader = DataLoader(
|
|
torchvision.datasets.MNIST('./data/', train=True, download=True,
|
|
transform=torchvision.transforms.Compose([
|
|
torchvision.transforms.ToTensor(),
|
|
torchvision.transforms.Normalize(
|
|
(0.1307,), (0.3081,))
|
|
])),
|
|
batch_size=batch_size, shuffle=True)
|
|
test_loader = DataLoader(
|
|
torchvision.datasets.MNIST('./data/', train=False, download=True,
|
|
transform=torchvision.transforms.Compose([
|
|
torchvision.transforms.ToTensor(),
|
|
torchvision.transforms.Normalize(
|
|
(0.1307,), (0.3081,))
|
|
])),
|
|
batch_size=batch_size, shuffle=True)
|
|
return train_loader, test_loader
|
|
|
|
|
|
class CNN(nn.Module):
|
|
def __init__(self):
|
|
super(CNN, self).__init__()
|
|
self.model = nn.Sequential(
|
|
nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),
|
|
nn.ReLU(),
|
|
nn.MaxPool2d(kernel_size=2, stride=2),
|
|
|
|
nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),
|
|
nn.ReLU(),
|
|
nn.MaxPool2d(kernel_size=2, stride=2),
|
|
|
|
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),
|
|
nn.ReLU(),
|
|
|
|
nn.Flatten(),
|
|
nn.Linear(in_features=7*7*64, out_features=128),
|
|
nn.ReLU(),
|
|
nn.Linear(in_features=128, out_features=10)
|
|
# nn.Softmax(dim=1)
|
|
)
|
|
|
|
def forward(self, x):
|
|
return self.model(x)
|
|
|
|
|
|
def train(epochs: int, lr: float, train_loader, test_loader):
|
|
cnn = CNN()
|
|
cnn.to(DEVICE)
|
|
optimizer = optim.Adam(cnn.parameters(), lr=lr)
|
|
loss_func = nn.CrossEntropyLoss()
|
|
history = {'Test Loss': [], 'Test Accuracy': []}
|
|
for epoch in range(1, epochs + 1):
|
|
process_bar = tqdm.tqdm(train_loader, unit='step')
|
|
cnn.train(True)
|
|
for step, (img, label) in enumerate(process_bar):
|
|
img = img.to(DEVICE)
|
|
label = label.to(DEVICE)
|
|
cnn.zero_grad()
|
|
outputs = cnn(img)
|
|
loss = loss_func(outputs, label)
|
|
predictions = torch.argmax(outputs, dim=1)
|
|
accuracy = torch.sum(predictions==label) / label.shape[0]
|
|
|
|
# 进行反向传播求出模型参数的梯度
|
|
loss.backward()
|
|
# 使用迭代器更新模型权重
|
|
optimizer.step()
|
|
|
|
# 将本step结果进行可视化处理
|
|
process_bar.set_description(f"[{epoch}/{epochs}] Loss: {round(loss.item(), 4)}, Acc: {round(accuracy.item(), 4)}")
|
|
if step == len(process_bar) - 1:
|
|
correct, total_loss = 0, 0
|
|
cnn.train(False)
|
|
with torch.no_grad():
|
|
for img, labels in test_loader:
|
|
img = img.to(DEVICE)
|
|
labels = labels.to(DEVICE)
|
|
outputs = cnn(img)
|
|
loss = loss_func(outputs, labels)
|
|
predictions = torch.argmax(outputs, dim=1)
|
|
total_loss += loss
|
|
correct += torch.sum(predictions == labels)
|
|
test_acc = correct / (batch_size * len(test_loader))
|
|
test_loss = total_loss / len(test_loader)
|
|
history['Test Loss'].append(test_loss.item())
|
|
history['Test Accuracy'].append(test_acc.item())
|
|
process_bar.set_description(f"[{epoch}/{epochs}] Loss: {round(loss.item(), 4)}, Acc: {round(accuracy.item(), 4)}, test_loss: {round(test_loss.item(), 4)}, test_acc: {round(test_acc.item(), 4)}")
|
|
process_bar.close()
|
|
return cnn
|
|
|
|
|
|
def load_model(model_path):
|
|
model = torch.load(model_path, map_location='cpu')
|
|
return model
|
|
|
|
def run_mnist_infer(img, model:CNN):
|
|
trans = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
|
|
torchvision.transforms.Normalize((0.1307,), (0.3081,))])
|
|
data = trans(img)
|
|
data = data.unsqueeze(0)
|
|
pred = model(data)
|
|
logger.info(pred)
|
|
rst = pred.detach().numpy().argmax(axis=1)[0]
|
|
return rst
|
|
|
|
|
|
if __name__ == '__main__':
|
|
pass |