update service
This commit is contained in:
parent
44a247390a
commit
b8fe3320a9
|
@ -0,0 +1,134 @@
|
|||
from asyncio.log import logger
|
||||
from cmath import log
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
from torch.utils.data import DataLoader
|
||||
import torch.optim as optim
|
||||
import tqdm
|
||||
import numpy as np
|
||||
|
||||
batch_size = 256
|
||||
random_seed = 1884
|
||||
DEVICE = torch.device('cuda')
|
||||
torch.manual_seed(random_seed)
|
||||
|
||||
|
||||
def load_data():
|
||||
train_loader = DataLoader(
|
||||
torchvision.datasets.MNIST('./data/', train=True, download=True,
|
||||
transform=torchvision.transforms.Compose([
|
||||
torchvision.transforms.ToTensor(),
|
||||
torchvision.transforms.Normalize(
|
||||
(0.1307,), (0.3081,))
|
||||
])),
|
||||
batch_size=batch_size, shuffle=True)
|
||||
test_loader = DataLoader(
|
||||
torchvision.datasets.MNIST('./data/', train=False, download=True,
|
||||
transform=torchvision.transforms.Compose([
|
||||
torchvision.transforms.ToTensor(),
|
||||
torchvision.transforms.Normalize(
|
||||
(0.1307,), (0.3081,))
|
||||
])),
|
||||
batch_size=batch_size, shuffle=True)
|
||||
return train_loader, test_loader
|
||||
|
||||
|
||||
class CNN(nn.Module):
|
||||
def __init__(self):
|
||||
super(CNN, self).__init__()
|
||||
self.model = nn.Sequential(
|
||||
nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
|
||||
nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
|
||||
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),
|
||||
nn.ReLU(),
|
||||
|
||||
nn.Flatten(),
|
||||
nn.Linear(in_features=7*7*64, out_features=128),
|
||||
nn.ReLU(),
|
||||
nn.Linear(in_features=128, out_features=10)
|
||||
# nn.Softmax(dim=1)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.model(x)
|
||||
|
||||
|
||||
def train(epochs: int, lr: float, train_loader, test_loader):
|
||||
cnn = CNN()
|
||||
cnn.to(DEVICE)
|
||||
optimizer = optim.Adam(cnn.parameters(), lr=lr)
|
||||
loss_func = nn.CrossEntropyLoss()
|
||||
history = {'Test Loss': [], 'Test Accuracy': []}
|
||||
for epoch in range(1, epochs + 1):
|
||||
process_bar = tqdm.tqdm(train_loader, unit='step')
|
||||
cnn.train(True)
|
||||
for step, (img, label) in enumerate(process_bar):
|
||||
img = img.to(DEVICE)
|
||||
label = label.to(DEVICE)
|
||||
cnn.zero_grad()
|
||||
outputs = cnn(img)
|
||||
loss = loss_func(outputs, label)
|
||||
predictions = torch.argmax(outputs, dim=1)
|
||||
accuracy = torch.sum(predictions==label) / label.shape[0]
|
||||
|
||||
# 进行反向传播求出模型参数的梯度
|
||||
loss.backward()
|
||||
# 使用迭代器更新模型权重
|
||||
optimizer.step()
|
||||
|
||||
# 将本step结果进行可视化处理
|
||||
process_bar.set_description(f"[{epoch}/{epochs}] Loss: {round(loss.item(), 4)}, Acc: {round(accuracy.item(), 4)}")
|
||||
if step == len(process_bar) - 1:
|
||||
correct, total_loss = 0, 0
|
||||
cnn.train(False)
|
||||
with torch.no_grad():
|
||||
for img, labels in test_loader:
|
||||
img = img.to(DEVICE)
|
||||
labels = labels.to(DEVICE)
|
||||
outputs = cnn(img)
|
||||
loss = loss_func(outputs, labels)
|
||||
predictions = torch.argmax(outputs, dim=1)
|
||||
total_loss += loss
|
||||
correct += torch.sum(predictions == labels)
|
||||
test_acc = correct / (batch_size * len(test_loader))
|
||||
test_loss = total_loss / len(test_loader)
|
||||
history['Test Loss'].append(test_loss.item())
|
||||
history['Test Accuracy'].append(test_acc.item())
|
||||
process_bar.set_description(f"[{epoch}/{epochs}] Loss: {round(loss.item(), 4)}, Acc: {round(accuracy.item(), 4)}, test_loss: {round(test_loss.item(), 4)}, test_acc: {round(test_acc.item(), 4)}")
|
||||
process_bar.close()
|
||||
return cnn
|
||||
|
||||
|
||||
def load_model(model_path):
|
||||
model = torch.load(model_path, map_location='cpu')
|
||||
return model
|
||||
|
||||
def run_mnist_infer(img, model:CNN):
|
||||
trans = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
|
||||
torchvision.transforms.Normalize((0.1307,), (0.3081,))])
|
||||
data = trans(img)
|
||||
data = data.unsqueeze(0)
|
||||
pred = model(data)
|
||||
logger.info(pred)
|
||||
rst = pred.detach().numpy().argmax(axis=1)[0]
|
||||
return rst
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# model=load_model('./models/MNIST_torch.pth')
|
||||
# a, b = load_data()
|
||||
# for i, l in a:
|
||||
# print(model(i)[0])
|
||||
# print(np.argmax(model(i).detach().numpy()[0]), l)
|
||||
# break
|
||||
# trl, tel = load_data()
|
||||
# cnn = train(10, 0.01, trl, tel)
|
||||
# torch.save(cnn, './models/MNIST_torch.pth')
|
||||
pass
|
Loading…
Reference in New Issue