#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Smoke test for ``LinoPolyScheduler``: print the learning rate used in each epoch.

@project:
@File : test1
@Author : qiqq
@create_time : 2022/11/4 21:22
"""
import torch
from torch import nn

from schedulers.polyscheduler import LinoPolyScheduler

# Number of epochs; shared by the scheduler and the loop so the two cannot drift apart.
EPOCHS = 10
# Stand-in "batches": only the length matters (one optimizer step per element).
BATCHES = (1, 2, 3, 5, 4, 78, 7, 8)


class ddd(nn.Module):
    """Tiny model (one 3->5 channel, 3x3 conv) that gives the optimizer parameters to manage."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 5, kernel_size=3)

    def forward(self, x):
        return self.conv(x)


def main(epochs=EPOCHS, batches=BATCHES):
    """Run *epochs* epochs of optimizer/scheduler steps, printing each epoch's learning rate.

    Args:
        epochs: number of epochs; also passed to the scheduler as its ``epochs`` argument.
        batches: iterable whose length sets the number of optimizer steps per epoch.
    """
    model = ddd()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    scheduler = LinoPolyScheduler(optimizer, min_lr=0.01, epochs=epochs)
    for _epoch in range(epochs):
        # Learning rate in effect for this epoch, read before scheduler.step() updates it.
        lr = optimizer.param_groups[0]["lr"]
        for _batch in batches:
            # No backward pass is run, so there are no gradients and SGD leaves the
            # weights unchanged; the call only keeps the usual
            # optimizer.step() -> scheduler.step() ordering of a training loop.
            optimizer.step()
        print("当前学习率:", lr)  # "Current learning rate:"
        scheduler.step()


if __name__ == "__main__":
    main()