34 lines
787 B
Python
34 lines
787 B
Python
|
|
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
"""
@project:
@File : test1
@Author : qiqq
@create_time : 2022/11/4 21:22
"""
|
|
import torch
|
|
from torch import nn
|
|
from schedulers.polyscheduler import LinoPolyScheduler
|
|
class ddd(nn.Module):
    """Minimal module wrapping a single 2-D convolution.

    The only layer is ``self.conv``: a ``nn.Conv2d`` with 3 input channels,
    5 output channels and a 3x3 kernel. No other arguments are passed, so
    the remaining Conv2d settings are PyTorch's defaults.
    """

    def __init__(self):
        super().__init__()
        # 3 input channels -> 5 output channels, 3x3 kernel.
        self.conv = nn.Conv2d(3, 5, kernel_size=3)

    def forward(self, x):
        """Return ``self.conv(x)``.

        Args:
            x: Input passed straight to the convolution; it must have
                3 channels to match ``self.conv``.

        Returns:
            The convolution output (5 channels).
        """
        return self.conv(x)
|
|
|
|
|
|
# Driver script: run a few dummy "epochs" and print the optimizer's learning
# rate so the effect of LinoPolyScheduler can be inspected by eye.

# Single source of truth for the epoch count; it is used both to configure the
# scheduler and to bound the loop below, so the two cannot drift apart.
NUM_EPOCHS = 10

model = ddd()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Stand-in for a data loader: the values are never read, only the number of
# items (8) matters, i.e. how many optimizer steps happen per epoch.
lisss = [1, 2, 3, 5, 4, 78, 7, 8]

# NOTE(review): LinoPolyScheduler is project-local; presumably it decays the
# learning rate from 0.1 towards min_lr over `epochs` -- confirm against
# schedulers.polyscheduler.
scheduler = LinoPolyScheduler(optimizer, min_lr=0.01, epochs=NUM_EPOCHS)

for epoch in range(NUM_EPOCHS):
    # Learning rate of the first parameter group, read once per epoch before
    # any step is taken.
    lee = optimizer.param_groups[0]['lr']

    # NOTE(review): nesting was reconstructed from a source whose indentation
    # was stripped; confirm the print belongs inside the per-item loop.
    for _ in lisss:
        # No loss or backward pass is computed anywhere in this script; the
        # optimizer is stepped without gradients having been produced here.
        optimizer.step()
        print("当前学习率:", lee)

    # Advance the schedule once per epoch, after that epoch's optimizer steps.
    scheduler.step()
|
|
|
|
|