# ai_platform_cv/mnist/mnist_torch.py
# -*-coding:utf-8-*-
from logzero import logger
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
import torch.optim as optim
import tqdm
batch_size = 256  # mini-batch size used by both the train and test loaders
random_seed = 1884  # fixed seed so weight init / shuffling are reproducible
DEVICE = torch.device('cpu')  # CPU-only; this script has no CUDA path
torch.manual_seed(random_seed)  # seed PyTorch's global RNG
def load_data():
    """Build the MNIST train/test DataLoaders (downloads to ./data/ if absent).

    Returns:
        (train_loader, test_loader): loaders yielding (image, label) batches
        of size ``batch_size``; images are float tensors standardized with
        the usual MNIST mean/std (0.1307, 0.3081).
    """
    # Shared preprocessing (the original duplicated this Compose verbatim
    # for both splits): PIL -> float tensor in [0, 1] -> standardized.
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_loader = DataLoader(
        torchvision.datasets.MNIST('./data/', train=True, download=True,
                                   transform=transform),
        batch_size=batch_size, shuffle=True)
    # NOTE(review): shuffle=True on the test split is kept for parity with
    # the original; shuffling is unnecessary for evaluation.
    test_loader = DataLoader(
        torchvision.datasets.MNIST('./data/', train=False, download=True,
                                   transform=transform),
        batch_size=batch_size, shuffle=True)
    return train_loader, test_loader
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.model = nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.Flatten(),
nn.Linear(in_features=7*7*64, out_features=128),
nn.ReLU(),
nn.Linear(in_features=128, out_features=10)
# nn.Softmax(dim=1)
)
def forward(self, x):
return self.model(x)
def train(epochs: int, lr: float, train_loader, test_loader):
    """Train a fresh CNN on MNIST and return it.

    Args:
        epochs: number of full passes over ``train_loader``.
        lr: Adam learning rate.
        train_loader: (image, label) training batches, e.g. from load_data().
        test_loader: (image, label) evaluation batches.

    Returns:
        The trained CNN (in eval mode after the final per-epoch evaluation).
    """
    cnn = CNN()
    cnn.to(DEVICE)
    optimizer = optim.Adam(cnn.parameters(), lr=lr)
    loss_func = nn.CrossEntropyLoss()  # expects raw logits from the model
    # NOTE(review): history is accumulated but never returned; kept for
    # parity with the original (callers may rely on the single return value).
    history = {'Test Loss': [], 'Test Accuracy': []}
    for epoch in range(1, epochs + 1):
        process_bar = tqdm.tqdm(train_loader, unit='step')
        cnn.train(True)
        for step, (img, label) in enumerate(process_bar):
            img = img.to(DEVICE)
            label = label.to(DEVICE)
            cnn.zero_grad()
            outputs = cnn(img)
            loss = loss_func(outputs, label)
            predictions = torch.argmax(outputs, dim=1)
            accuracy = torch.sum(predictions == label) / label.shape[0]
            # Backpropagate the loss and apply the parameter update.
            loss.backward()
            optimizer.step()
            # Surface per-step metrics on the progress bar.
            process_bar.set_description(
                f"[{epoch}/{epochs}] Loss: {round(loss.item(), 4)}, "
                f"Acc: {round(accuracy.item(), 4)}")
            if step == len(process_bar) - 1:
                # End of epoch: evaluate on the held-out set.
                correct, total_loss, total_samples = 0, 0.0, 0
                cnn.train(False)
                with torch.no_grad():
                    for img, labels in test_loader:
                        img = img.to(DEVICE)
                        labels = labels.to(DEVICE)
                        outputs = cnn(img)
                        loss = loss_func(outputs, labels)
                        predictions = torch.argmax(outputs, dim=1)
                        # Accumulate as Python numbers via .item() so no
                        # tensors are kept around between batches.
                        total_loss += loss.item()
                        correct += torch.sum(predictions == labels).item()
                        total_samples += labels.shape[0]
                # BUG FIX: divide by the true number of evaluated samples.
                # The original used batch_size * len(test_loader), which
                # over-counts whenever the last batch is partial (MNIST test
                # has 10000 samples but 256 * 40 = 10240), understating
                # accuracy.
                test_acc = correct / total_samples
                test_loss = total_loss / len(test_loader)
                history['Test Loss'].append(test_loss)
                history['Test Accuracy'].append(test_acc)
                process_bar.set_description(
                    f"[{epoch}/{epochs}] Loss: {round(loss.item(), 4)}, "
                    f"Acc: {round(accuracy.item(), 4)}, "
                    f"test_loss: {round(test_loss, 4)}, "
                    f"test_acc: {round(test_acc, 4)}")
        process_bar.close()
    return cnn
def load_model(model_path):
    """Deserialize a checkpoint from ``model_path`` onto the CPU.

    NOTE: torch.load unpickles arbitrary objects — only load trusted files.
    """
    return torch.load(model_path, map_location='cpu')
def run_mnist_infer(img, model:CNN):
    """Classify a single image and return the predicted digit index.

    Args:
        img: a PIL image (or array accepted by torchvision's ToTensor);
            presumably a 28x28 grayscale digit — confirm against callers.
        model: a trained CNN mapping a (1, 1, 28, 28) batch to logits.

    Returns:
        The argmax class index of the logits (numpy integer).
    """
    # Same preprocessing the model was trained with: tensor + MNIST norm.
    preprocess = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ])
    batch = preprocess(img).unsqueeze(0)  # add the leading batch dimension
    logits = model(batch)
    logger.info(logits)
    return logits.detach().numpy().argmax(axis=1)[0]
if __name__ == '__main__':
    # No script-level behavior: the module currently exposes its functions
    # for import only.
    pass