#!/usr/bin/env python3
# coding=utf-8

import math


class LinearLr:
    """Step-driven learning-rate schedule: optional delay, linear warmup, cosine decay.

    Each call advances an internal step counter and writes the resulting
    learning rate into ``param_group["lr"]``. The schedule has three phases:

    1. Delay (only when ``delay`` is true): the rate is ``0.0`` for the first
       ``total_steps / 20`` steps.
    2. Warmup: the rate grows linearly from ``0.0`` towards ``learning_rate``
       until step ``total_steps / 10``.
    3. Decay: the rate follows half a cosine wave from ``learning_rate`` down
       to ``learning_rate / multiplier`` at step ``total_steps`` and then
       stays at that floor.

    Args:
        param_group: Mutable mapping whose ``"lr"`` entry is overwritten on
            every call (presumably an optimizer parameter group; any object
            supporting item assignment works).
        learning_rate: Peak learning rate reached at the end of warmup.
        total_steps: Number of steps over which the schedule is laid out.
        delay: Whether to hold the rate at zero for the first 5% of steps.
        multiplier: Divisor giving the final rate, ``learning_rate / multiplier``.
    """

    def __init__(self, param_group, learning_rate: float, total_steps: int, delay: bool, multiplier: int):
        self.total_steps = total_steps
        self.delay_steps = total_steps / 20 if delay else 0
        self.max_lr = learning_rate
        self.steps = 0
        self.param_group = param_group
        self.decay_multiplier = multiplier

    def __call__(self, _):
        """Advance one step and store the new rate in ``param_group["lr"]``.

        The positional argument is ignored; the schedule relies solely on its
        own step counter. Nothing is returned.
        """
        self.steps += 1
        warmup_end = self.total_steps / 10

        if self.steps < self.delay_steps:
            lr = 0.0
        elif self.steps < warmup_end:
            lr = self.max_lr * (self.steps - self.delay_steps) / (warmup_end - self.delay_steps)
        else:
            min_lr = self.max_lr / self.decay_multiplier
            amplitude = self.max_lr - min_lr
            decay_span = self.total_steps * 9 / 10
            if decay_span > 0:
                # Clamp to 1.0: cosine is periodic, so without the clamp the
                # rate would climb back up once steps exceed total_steps.
                progress = min((self.steps - warmup_end) / decay_span, 1.0)
            else:
                # Degenerate schedule (total_steps <= 0): there is no decay
                # window to traverse, so go straight to the floor rate.
                progress = 1.0
            lr = amplitude * (math.cos(math.pi * progress) + 1) / 2 + min_lr

        # Safety first: never hand a negative rate to the parameter group.
        self.param_group["lr"] = max(lr, 0.0)

    def lr(self) -> float:
        """Return the rate currently stored under ``param_group["lr"]``."""
        return self.param_group["lr"]