From b8fe3320a9dce33c0bd66ae8f5fd96e1d6b2aaab Mon Sep 17 00:00:00 2001 From: zhaojinghao Date: Thu, 4 Aug 2022 08:43:24 +0800 Subject: [PATCH] update service --- mnist/mnist_torch.py | 134 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 mnist/mnist_torch.py diff --git a/mnist/mnist_torch.py b/mnist/mnist_torch.py new file mode 100644 index 0000000..9aa9522 --- /dev/null +++ b/mnist/mnist_torch.py @@ -0,0 +1,134 @@ +from asyncio.log import logger +from cmath import log +import torch +import torch.nn as nn +import torchvision +from torch.utils.data import DataLoader +import torch.optim as optim +import tqdm +import numpy as np + +batch_size = 256 +random_seed = 1884 +DEVICE = torch.device('cuda') +torch.manual_seed(random_seed) + + +def load_data(): + train_loader = DataLoader( + torchvision.datasets.MNIST('./data/', train=True, download=True, + transform=torchvision.transforms.Compose([ + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize( + (0.1307,), (0.3081,)) + ])), + batch_size=batch_size, shuffle=True) + test_loader = DataLoader( + torchvision.datasets.MNIST('./data/', train=False, download=True, + transform=torchvision.transforms.Compose([ + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize( + (0.1307,), (0.3081,)) + ])), + batch_size=batch_size, shuffle=True) + return train_loader, test_loader + + +class CNN(nn.Module): + def __init__(self): + super(CNN, self).__init__() + self.model = nn.Sequential( + nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1), + nn.ReLU(), + nn.MaxPool2d(kernel_size=2, stride=2), + + nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1), + nn.ReLU(), + nn.MaxPool2d(kernel_size=2, stride=2), + + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1), + nn.ReLU(), + + nn.Flatten(), + nn.Linear(in_features=7*7*64, out_features=128), + nn.ReLU(), + nn.Linear(in_features=128, out_features=10) + # nn.Softmax(dim=1) + ) + + def forward(self, x): + return self.model(x) + + +def train(epochs: int, lr: float, train_loader, test_loader): + cnn = CNN() + cnn.to(DEVICE) + optimizer = optim.Adam(cnn.parameters(), lr=lr) + loss_func = nn.CrossEntropyLoss() + history = {'Test Loss': [], 'Test Accuracy': []} + for epoch in range(1, epochs + 1): + process_bar = tqdm.tqdm(train_loader, unit='step') + cnn.train(True) + for step, (img, label) in enumerate(process_bar): + img = img.to(DEVICE) + label = label.to(DEVICE) + cnn.zero_grad() + outputs = cnn(img) + loss = loss_func(outputs, label) + predictions = torch.argmax(outputs, dim=1) + accuracy = torch.sum(predictions==label) / label.shape[0] + + # 进行反向传播求出模型参数的梯度 + loss.backward() + # 使用迭代器更新模型权重 + optimizer.step() + + # 将本step结果进行可视化处理 + process_bar.set_description(f"[{epoch}/{epochs}] Loss: {round(loss.item(), 4)}, Acc: {round(accuracy.item(), 4)}") + if step == len(process_bar) - 1: + correct, total_loss = 0, 0 + cnn.train(False) + with torch.no_grad(): + for img, labels in test_loader: + img = img.to(DEVICE) + labels = labels.to(DEVICE) + outputs = cnn(img) + loss = loss_func(outputs, labels) + predictions = torch.argmax(outputs, dim=1) + total_loss += loss + correct += torch.sum(predictions == labels) + test_acc = correct / (batch_size * len(test_loader)) + test_loss = total_loss / len(test_loader) + history['Test Loss'].append(test_loss.item()) + history['Test Accuracy'].append(test_acc.item()) + process_bar.set_description(f"[{epoch}/{epochs}] Loss: {round(loss.item(), 4)}, Acc: {round(accuracy.item(), 4)}, test_loss: {round(test_loss.item(), 4)}, test_acc: {round(test_acc.item(), 4)}") + process_bar.close() + return cnn + + +def load_model(model_path): + model = torch.load(model_path, map_location='cpu') + return model + +def run_mnist_infer(img, model:CNN): + trans = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize((0.1307,), (0.3081,))]) + data = trans(img) + data = data.unsqueeze(0) + pred = model(data) + logger.info(pred) + rst = pred.detach().numpy().argmax(axis=1)[0] + return rst + + +if __name__ == '__main__': + # model=load_model('./models/MNIST_torch.pth') + # a, b = load_data() + # for i, l in a: + # print(model(i)[0]) + # print(np.argmax(model(i).detach().numpy()[0]), l) + # break + # trl, tel = load_data() + # cnn = train(10, 0.01, trl, tel) + # torch.save(cnn, './models/MNIST_torch.pth') + pass \ No newline at end of file