|
import sys |
|
import warnings |
|
from bisect import bisect_right |
|
|
|
import torch |
|
import torch.nn as nn |
|
from torch.optim import lr_scheduler |
|
|
|
from pytorch_lightning.utilities.rank_zero import rank_zero_debug |
|
|
|
|
|
class ChainedScheduler(lr_scheduler._LRScheduler):
    """Chains a list of learning rate schedulers.

    It takes a list of chainable learning rate schedulers and performs the
    ``step()`` of each of them consecutively with a single call.

    Args:
        optimizer (Optimizer): Wrapped optimizer. ``get_last_lr()`` reports the
            learning rates of its param groups.
        schedulers (list): List of chained schedulers. They must all wrap the
            same optimizer.

    Raises:
        ValueError: if the schedulers do not all share one optimizer.

    Example:
        >>> # Assuming optimizer uses lr = 1. for all groups
        >>> # lr = 0.1     if epoch == 0
        >>> # lr = 0.09    if epoch == 1
        >>> # lr = 0.81    if epoch == 2
        >>> # lr = 0.729   if epoch == 3
        >>> # lr = 0.6561  if epoch == 4
        >>> # i.e. lr = 0.9 ** epoch for epoch >= 2
        >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2)
        >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9)
        >>> scheduler = ChainedScheduler(self.opt, [scheduler1, scheduler2])
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, schedulers):
        schedulers = list(schedulers)
        for scheduler_idx in range(1, len(schedulers)):
            if schedulers[scheduler_idx].optimizer != schedulers[0].optimizer:
                raise ValueError(
                    "ChainedScheduler expects all schedulers to belong to the same optimizer, but "
                    "got schedulers at index {} and {} to be different".format(0, scheduler_idx)
                )
        self._schedulers = schedulers
        self.optimizer = optimizer
        # The base-class __init__ is deliberately not called, so set the field
        # that the inherited get_last_lr() reads.
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

    def step(self):
        """Step every wrapped scheduler once, in list order."""
        for scheduler in self._schedulers:
            scheduler.step()
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer. The wrapped scheduler states are saved as a
        list under ``'_schedulers'``.
        """
        state = {key: value for key, value in self.__dict__.items()
                 if key not in ('optimizer', '_schedulers')}
        state['_schedulers'] = [scheduler.state_dict() for scheduler in self._schedulers]
        return state

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`. It is not modified.

        Raises:
            ValueError: if the number of saved scheduler states differs from
                the number of wrapped schedulers.
        """
        scheduler_states = state_dict['_schedulers']
        if len(scheduler_states) != len(self._schedulers):
            raise ValueError(
                "ChainedScheduler expects {} scheduler states, but got {}".format(
                    len(self._schedulers), len(scheduler_states))
            )
        self.__dict__.update({key: value for key, value in state_dict.items() if key != '_schedulers'})
        for scheduler, scheduler_state in zip(self._schedulers, scheduler_states):
            scheduler.load_state_dict(scheduler_state)
|
|
|
|
|
class SequentialLR(lr_scheduler._LRScheduler):
    """Steps a list of schedulers one after another, switching at milestones.

    Receives the list of schedulers that are expected to be called sequentially
    during the optimization process, and the milestone points that define
    exactly which scheduler is stepped at a given epoch.

    Args:
        optimizer (Optimizer): Wrapped optimizer. ``get_last_lr()`` reports the
            learning rates of its param groups.
        schedulers (list): Schedulers to step in order. They must all wrap the
            same optimizer.
        milestones (list): ``len(schedulers) - 1`` integers in non-decreasing
            order. Once the step count reaches ``milestones[i]``,
            ``schedulers[i + 1]`` is stepped instead of ``schedulers[i]``. On
            the milestone itself it is called as ``step(0)``, which restarts it
            at its epoch 0.
        last_epoch (int): The index of the last epoch. Default: -1.
        verbose (bool): Accepted for signature compatibility; not used.

    Raises:
        ValueError: if the schedulers do not share one optimizer, if the number
            of milestones is not ``len(schedulers) - 1``, or if the milestones
            are not sorted.

    Example:
        >>> # Assuming optimizer uses lr = 1. for all groups
        >>> # lr = 0.1   if epoch == 0
        >>> # lr = 0.1   if epoch == 1
        >>> # lr = 1.0   if epoch == 2  (scheduler2 restarts from its base lr)
        >>> # lr = 0.9   if epoch == 3
        >>> # lr = 0.81  if epoch == 4
        >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2)
        >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9)
        >>> scheduler = SequentialLR(self.opt, schedulers=[scheduler1, scheduler2], milestones=[2])
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, schedulers, milestones, last_epoch=-1, verbose=False):
        schedulers = list(schedulers)
        milestones = list(milestones)
        for scheduler_idx in range(1, len(schedulers)):
            if schedulers[scheduler_idx].optimizer != schedulers[0].optimizer:
                raise ValueError(
                    "Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
                    "got schedulers at index {} and {} to be different".format(0, scheduler_idx)
                )
        if len(milestones) != len(schedulers) - 1:
            raise ValueError(
                "Sequential Schedulers expects number of schedulers provided to be one more "
                "than the number of milestone points, but got number of schedulers {} and the "
                "number of milestones to be equal to {}".format(len(schedulers), len(milestones))
            )
        # step() locates the active scheduler with bisect_right, which is only
        # meaningful on sorted milestones.
        if any(later < earlier for earlier, later in zip(milestones, milestones[1:])):
            raise ValueError(
                "Sequential Schedulers expects milestones in non-decreasing order, "
                "but got {}".format(milestones)
            )
        self._schedulers = schedulers
        self._milestones = milestones
        # NOTE(review): offset by one, presumably because each sub-scheduler has
        # already taken its initial step on construction.
        self.last_epoch = last_epoch + 1
        self.optimizer = optimizer
        # The base-class __init__ is deliberately not called, so set the field
        # that the inherited get_last_lr() reads.
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

    def step(self):
        """Advance one epoch and step whichever scheduler owns that epoch."""
        self.last_epoch += 1
        idx = bisect_right(self._milestones, self.last_epoch)
        scheduler = self._schedulers[idx]
        if idx > 0 and self._milestones[idx - 1] == self.last_epoch:
            # Just reached a milestone: restart the incoming scheduler at epoch 0.
            scheduler.step(0)
        else:
            scheduler.step()
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer. The wrapped scheduler states are saved as a
        list under ``'_schedulers'``.
        """
        state = {key: value for key, value in self.__dict__.items()
                 if key not in ('optimizer', '_schedulers')}
        state['_schedulers'] = [scheduler.state_dict() for scheduler in self._schedulers]
        return state

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`. It is not modified.

        Raises:
            ValueError: if the number of saved scheduler states differs from
                the number of wrapped schedulers.
        """
        scheduler_states = state_dict['_schedulers']
        if len(scheduler_states) != len(self._schedulers):
            raise ValueError(
                "Sequential Schedulers expects {} scheduler states, but got {}".format(
                    len(self._schedulers), len(scheduler_states))
            )
        self.__dict__.update({key: value for key, value in state_dict.items() if key != '_schedulers'})
        for scheduler, scheduler_state in zip(self._schedulers, scheduler_states):
            scheduler.load_state_dict(scheduler_state)
|
|
|
|
|
class ConstantLR(lr_scheduler._LRScheduler):
    """Multiplies the learning rate of each parameter group by a constant factor
    until the epoch count reaches a pre-defined milestone: ``total_iters``.

    Notice that such decay can happen simultaneously with other changes to the
    learning rate from outside this scheduler. When last_epoch=-1, sets initial
    lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        factor (float): The number we multiply the learning rate by until the
            milestone. Must lie in [0, 1]. Default: 1./3.
        total_iters (int): The number of steps that the scheduler decays the
            learning rate. Default: 5.
        last_epoch (int): The index of the last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Raises:
        ValueError: if ``factor`` is outside [0, 1].

    Example:
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.025   if epoch == 0
        >>> # lr = 0.025   if epoch == 1
        >>> # lr = 0.025   if epoch == 2
        >>> # lr = 0.025   if epoch == 3
        >>> # lr = 0.05    if epoch >= 4
        >>> scheduler = ConstantLR(self.opt, factor=0.5, total_iters=4)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, factor=1.0 / 3, total_iters=5, last_epoch=-1, verbose=False):
        if factor > 1.0 or factor < 0:
            raise ValueError('Constant multiplicative factor expected to be between 0 and 1.')

        self.factor = factor
        self.total_iters = total_iters
        if verbose:
            super(ConstantLR, self).__init__(optimizer, last_epoch, verbose)
        else:
            # Omit the flag when it is off: newer PyTorch releases deprecate
            # passing ``verbose`` and warn whenever it is given explicitly.
            super(ConstantLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Chainable update: rescale each group's current lr relative to the previous step."""
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch == 0:
            return [group['lr'] * self.factor for group in self.optimizer.param_groups]

        if self.last_epoch != self.total_iters:
            return [group['lr'] for group in self.optimizer.param_groups]

        # last_epoch == total_iters: undo the factor applied at epoch 0.
        if self.factor == 0:
            # A multiplication by 0 cannot be inverted; restore from base lrs
            # (matches the closed form at this epoch) instead of dividing by 0.
            return self._get_closed_form_lr()
        return [group['lr'] * (1.0 / self.factor) for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        """Learning rates for ``last_epoch`` computed directly from ``base_lrs``."""
        return [base_lr * (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor))
                for base_lr in self.base_lrs]
|
|
|
|
|
class LinearLR(lr_scheduler._LRScheduler):
    """Scales the learning rate of each parameter group by a multiplicative
    factor that changes linearly until the epoch count reaches a pre-defined
    milestone: ``total_iters``.

    Notice that such decay can happen simultaneously with other changes to the
    learning rate from outside this scheduler. When last_epoch=-1, sets initial
    lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        start_factor (float): The number we multiply the learning rate by in the
            first epoch. The multiplication factor changes towards end_factor in
            the following epochs. Must lie in [0, 1]; 0 (warm-up from zero) is
            supported. Default: 1./3.
        end_factor (float): The number we multiply the learning rate by at the
            end of the linear changing process. Must lie in [0, 1]. Default: 1.0.
        total_iters (int): The number of iterations after which the
            multiplicative factor reaches end_factor. Default: 5.
        last_epoch (int): The index of the last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Raises:
        ValueError: if ``start_factor`` or ``end_factor`` is outside [0, 1].

    Example:
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.025    if epoch == 0
        >>> # lr = 0.03125  if epoch == 1
        >>> # lr = 0.0375   if epoch == 2
        >>> # lr = 0.04375  if epoch == 3
        >>> # lr = 0.05     if epoch >= 4
        >>> scheduler = LinearLR(self.opt, start_factor=0.5, total_iters=4)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, start_factor=1.0 / 3, end_factor=1.0, total_iters=5, last_epoch=-1,
                 verbose=False):
        if start_factor > 1.0 or start_factor < 0:
            raise ValueError('Starting multiplicative factor expected to be between 0 and 1.')

        if end_factor > 1.0 or end_factor < 0:
            raise ValueError('Ending multiplicative factor expected to be between 0 and 1.')

        self.start_factor = start_factor
        self.end_factor = end_factor
        self.total_iters = total_iters
        if verbose:
            super(LinearLR, self).__init__(optimizer, last_epoch, verbose)
        else:
            # Omit the flag when it is off: newer PyTorch releases deprecate
            # passing ``verbose`` and warn whenever it is given explicitly.
            super(LinearLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Chainable update: rescale each group's current lr relative to the previous step."""
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch == 0:
            return [group['lr'] * self.start_factor for group in self.optimizer.param_groups]

        if self.last_epoch > self.total_iters:
            return [group['lr'] for group in self.optimizer.param_groups]

        # The lr is multiplied by factor(last_epoch) / factor(last_epoch - 1);
        # ``denominator`` is total_iters * factor(last_epoch - 1).
        denominator = (self.total_iters * self.start_factor
                       + (self.last_epoch - 1) * (self.end_factor - self.start_factor))
        if denominator == 0:
            # The previous factor was 0 (start_factor == 0 on the first step),
            # so the ratio is undefined; compute this step from base_lrs instead.
            return self._get_closed_form_lr()
        ratio = 1. + (self.end_factor - self.start_factor) / denominator
        return [group['lr'] * ratio for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        """Learning rates for ``last_epoch`` computed directly from ``base_lrs``."""
        return [base_lr * (self.start_factor +
                (self.end_factor - self.start_factor) * min(self.total_iters, self.last_epoch) / self.total_iters)
                for base_lr in self.base_lrs]
|
|
|
|
|
# Scheduler names that fall back to this module's implementations when
# torch.optim.lr_scheduler does not provide them.
custom_schedulers = ['ConstantLR', 'LinearLR']


def get_scheduler(name):
    """Look up a learning rate scheduler class by name.

    ``torch.optim.lr_scheduler`` is searched first. Names listed in
    ``custom_schedulers`` then fall back to the classes defined in this module.

    Args:
        name (str): Class name of the scheduler.

    Returns:
        The scheduler class (not an instance).

    Raises:
        NotImplementedError: if ``name`` is found in neither place.
    """
    if hasattr(lr_scheduler, name):
        return getattr(lr_scheduler, name)
    if name in custom_schedulers:
        return getattr(sys.modules[__name__], name)
    raise NotImplementedError(f"Unknown learning rate scheduler: {name!r}")
|
|
|
|
|
def getattr_recursive(m, attr):
    """Resolve a dotted attribute path such as ``'a.b.c'`` starting from ``m``.

    Each ``.``-separated component is fetched with ``getattr`` in turn, so a
    missing or empty component raises ``AttributeError``.
    """
    head, sep, rest = attr.partition('.')
    target = getattr(m, head)
    if not sep:
        return target
    return getattr_recursive(target, rest)
|
|
|
|
|
def get_parameters(model, name):
    """Return the trainable parameters found at dotted path ``name`` in ``model``.

    An ``nn.Module`` yields its ``parameters()`` iterator. A single
    ``nn.Parameter`` is returned as-is. Anything else yields an empty list.
    """
    target = getattr_recursive(model, name)
    if isinstance(target, nn.Parameter):
        return target
    if isinstance(target, nn.Module):
        return target.parameters()
    return []
|
|
|
|
|
def parse_optimizer(config, model):
    """Build an optimizer for ``model`` from a config node.

    ``config.name`` selects the optimizer class. ``FusedAdam`` comes from
    ``apex.optimizers``, imported lazily so apex is only required when used.
    Every other name comes from ``torch.optim``. ``config.args`` is passed as
    keyword arguments.

    If ``config`` has a ``params`` mapping, one parameter group is created per
    entry. The key is a dotted attribute path into ``model``, and the value's
    items are merged into that group as per-group overrides. Otherwise all of
    ``model.parameters()`` form a single group.

    Returns:
        The constructed optimizer instance.
    """
    if hasattr(config, 'params'):
        params = [{'params': get_parameters(model, name), 'name': name, **args}
                  for name, args in config.params.items()]
        # rank_zero_debug hands its positional args to a logging call, which
        # %-formats the message with them; without a placeholder the extra
        # argument caused a logging formatting error.
        rank_zero_debug('Specify optimizer params: %s', config.params)
    else:
        params = model.parameters()
    if config.name in ['FusedAdam']:
        import apex
        optim = getattr(apex.optimizers, config.name)(params, **config.args)
    else:
        optim = getattr(torch.optim, config.name)(params, **config.args)
    return optim
|
|
|
|
|
def parse_scheduler(config, optimizer):
    """Build a scheduler entry ``{'scheduler': ..., 'interval': ...}`` from a config node.

    ``config.name`` chooses the construction path:

    * ``'SequentialLR'``: each entry of ``config.schedulers`` is parsed
      recursively and wrapped in :class:`SequentialLR` with ``config.milestones``.
    * ``'Chained'``: each entry of ``config.schedulers`` is parsed recursively
      and wrapped in :class:`ChainedScheduler`.
    * anything else: resolved with :func:`get_scheduler` and called with
      ``optimizer`` and ``**config.args``.

    Only the top-level ``interval`` is returned. Nested configs have their
    interval validated but otherwise ignored. The returned dict is presumably
    consumed by Lightning's ``configure_optimizers``.

    Raises:
        ValueError: if ``interval`` is not ``'epoch'`` or ``'step'``.
    """
    interval = config.get('interval', 'epoch')
    if interval not in ('epoch', 'step'):
        raise ValueError(f"Scheduler interval must be 'epoch' or 'step', got {interval!r}")
    if config.name == 'SequentialLR':
        sub_schedulers = [parse_scheduler(conf, optimizer)['scheduler'] for conf in config.schedulers]
        scheduler = SequentialLR(optimizer, sub_schedulers, milestones=config.milestones)
    elif config.name == 'Chained':
        sub_schedulers = [parse_scheduler(conf, optimizer)['scheduler'] for conf in config.schedulers]
        # ChainedScheduler takes (optimizer, schedulers); passing only the list
        # used to raise a TypeError for the missing argument.
        scheduler = ChainedScheduler(optimizer, sub_schedulers)
    else:
        scheduler = get_scheduler(config.name)(optimizer, **config.args)
    return {'scheduler': scheduler, 'interval': interval}
|
|
|
|
|
def update_module_step(m, epoch, global_step):
    """Notify ``m`` of the current epoch and global step.

    ``m.update_step(epoch, global_step)`` is called only when ``m`` has an
    ``update_step`` attribute. Other objects are skipped silently.
    """
    if not hasattr(m, 'update_step'):
        return
    m.update_step(epoch, global_step)
|
|