34 lines
787 B
Python
34 lines
787 B
Python
|
|
||
|
#!/usr/bin/env python
|
||
|
# -*- coding: utf-8 -*-
|
||
|
"""
|
||
|
@project:
|
||
|
@File : test1
|
||
|
@Author : qiqq
|
||
|
@create_time : 2022/11/4 21:22
|
||
|
"""
|
||
|
import torch
|
||
|
from torch import nn
|
||
|
from schedulers.polyscheduler import LinoPolyScheduler
|
||
|
class ddd(nn.Module):
    """Minimal test model: one 3x3 convolution from 3 to 5 channels.

    Uses ``nn.Conv2d`` defaults for everything else (stride 1, no padding,
    with bias), so each spatial dimension shrinks by 2.
    """

    def __init__(self):
        super().__init__()
        # The attribute is named ``conv`` so parameters are registered as
        # ``conv.weight`` / ``conv.bias``.
        self.conv = nn.Conv2d(in_channels=3, out_channels=5, kernel_size=3)

    def forward(self, x):
        """Run ``x`` through the convolution and return the output tensor."""
        out = self.conv(x)
        return out
|
||
|
|
||
|
|
||
|
# Number of epochs to simulate. The scheduler and the loop below share it, so
# the schedule length and the number of scheduler.step() calls cannot drift
# apart (previously these were two independent literals).
EPOCHS = 10

model = ddd()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Dummy "batches": only the length of this list matters. Each element
# triggers one optimizer.step() per epoch.
lisss = [1, 2, 3, 5, 4, 78, 7, 8]

scheduler = LinoPolyScheduler(optimizer, min_lr=0.01, epochs=EPOCHS)

# NOTE(review): the original indentation of this loop was lost. The nesting
# below (print once per inner iteration, scheduler.step() once per epoch) is
# a reconstruction; confirm against the intended script.
for epoch in range(EPOCHS):
    # Learning rate of the first param group, read before this epoch's steps.
    lee = optimizer.param_groups[0]['lr']
    for _ in lisss:
        # No loss/backward() is computed here, so the parameters presumably
        # carry no gradients. This call only mirrors a training loop's
        # ordering: optimizer.step() before scheduler.step().
        optimizer.step()
        print("当前学习率:", lee)

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()
|
||
|
|
||
|
|