mmagic.engine.hooks.ema
Module Contents
Classes
ExponentialMovingAverageHook | Exponential Moving Average Hook.
Attributes
- class mmagic.engine.hooks.ema.ExponentialMovingAverageHook(module_keys, interp_mode='lerp', interp_cfg=None, interval=-1, start_iter=0)
Bases: mmengine.hooks.Hook
Exponential Moving Average Hook.
Exponential moving average (EMA) is a trick widely used in current GAN literature, e.g., PGGAN, StyleGAN, and BigGAN. The general idea is to maintain a second model with the same architecture whose parameters are updated as a moving average of the trained weights of the original model. In general, the model with moving-averaged weights achieves better performance.
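As a rough illustration of the idea above, the sketch below keeps a second copy of some weights and nudges it toward the trained weights after each step. Plain Python floats stand in for tensors, and the function name `ema_update` and the momentum value are assumptions made for this example, not part of MMagic.

```python
def ema_update(ema_weights, trained_weights, momentum=0.001):
    """Move each averaged weight a small step toward the trained weight."""
    return {
        name: ema_weights[name]
        + (trained_weights[name] - ema_weights[name]) * momentum
        for name in trained_weights
    }


# Hypothetical one-parameter "model": the trained weight jumps around,
# while the averaged copy changes more slowly.
trained_history = [1.0, 3.0, 2.0, 4.0]
ema = {"w": trained_history[0]}  # start as a copy of the original weights
for value in trained_history[1:]:
    # A large momentum is used here only to keep the numbers easy to follow.
    ema = ema_update(ema, {"w": value}, momentum=0.5)

print(ema)
```

With a small momentum the averaged copy reacts slowly to each training step, which is what smooths out the noise in the trained weights.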
- Parameters
module_keys (str | tuple[str]) – The name(s) of the EMA model(s). Each key must end with ‘_ema’ so that the original model can be found by discarding the last four characters.
interp_mode (str, optional) – Mode of the interpolation method. Defaults to ‘lerp’.
interp_cfg (dict | None, optional) – Set arguments of the interpolation function. Defaults to None.
interval (int, optional) – Evaluation interval (by iterations). Default: -1.
start_iter (int, optional) – Start iteration for EMA. Until the start iteration is reached, the weights of the EMA model remain the same as those of the original model. After that, its parameters are updated as a moving average of the trained weights in the original model. Default: 0.
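A configuration for this hook might look like the sketch below. The key name `generator_ema` and the numeric values are illustrative assumptions; only the argument names come from the signature above, and registering the hook through a `custom_hooks` list follows the usual MMEngine config convention. The last lines show how the original module name can be recovered from an EMA key by dropping the trailing `_ema`.

```python
# Hypothetical config fragment; 'generator_ema' and the values are examples only.
custom_hooks = [
    dict(
        type='ExponentialMovingAverageHook',
        module_keys=('generator_ema', ),
        interp_mode='lerp',
        interp_cfg=dict(momentum=0.001),
        interval=1,
        start_iter=0,
    )
]

# Each key must end with '_ema'; stripping the last four characters
# gives the name of the original module.
for key in custom_hooks[0]['module_keys']:
    assert key.endswith('_ema')
    print(key, '->', key[:-4])
```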
- static lerp(a, b, momentum=0.001, momentum_nontrainable=1.0, trainable=True)
Performs a linear interpolation of two parameters/buffers.
- Parameters
a (torch.Tensor) – Interpolation start point; refers to the original state.
b (torch.Tensor) – Interpolation end point; refers to the EMA state.
momentum (float, optional) – The weight for the interpolation formula. Defaults to 0.001.
momentum_nontrainable (float, optional) – The weight for the interpolation formula used for non-trainable parameters. Defaults to 1.0.
trainable (bool, optional) – Whether the input parameters are trainable. If set to False, momentum_nontrainable will be used. Defaults to True.
- Returns
Interpolation result.
- Return type
torch.Tensor
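The sketch below shows one way such an interpolation could be written. It is not the library's code: this page does not state the exact formula, so the direction used here (the EMA state `b` moves toward the original state `a` by a fraction `m`) is an assumption, and `lerp_sketch` is a hypothetical name. Plain floats stand in for tensors; the same arithmetic applies elementwise to tensors.

```python
def lerp_sketch(a, b, momentum=0.001, momentum_nontrainable=1.0,
                trainable=True):
    """Interpolate between the original state `a` and the EMA state `b`.

    Assumed convention: the result is `b` moved toward `a` by a fraction `m`,
    where `m` depends on whether the parameter is trainable.
    """
    m = momentum if trainable else momentum_nontrainable
    return b + (a - b) * m


# Trainable parameter: the EMA value moves a quarter of the way
# toward the original value.
print(lerp_sketch(a=2.0, b=1.0, momentum=0.25))

# Non-trainable buffer with the default weight of 1.0:
# the original value is taken over unchanged.
print(lerp_sketch(a=2.0, b=1.0, trainable=False))
```

Under this assumed convention, the default `momentum_nontrainable=1.0` means buffers are simply copied from the original model rather than averaged.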
- every_n_iters(runner: mmengine.runner.Runner, n: int)
This is the function to perform every n iterations.
- Parameters
runner (Runner) – runner used to drive the whole pipeline
n (int) – the number of iterations
- Returns
the latest iterations
- Return type
int
- after_train_iter(runner: mmengine.runner.Runner, batch_idx: int, data_batch: DATA_BATCH = None, outputs: Optional[dict] = None) -> None
This is the function to perform after each training iteration.
- Parameters
runner (Runner) – runner to drive the pipeline
batch_idx (int) – the id of batch
data_batch (DATA_BATCH, optional) – data batch. Defaults to None.
outputs (Optional[dict], optional) – output. Defaults to None.
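To make the interplay of `start_iter` and `interval` more concrete, the sketch below simulates what a per-iteration update could look like. It follows the `start_iter` behaviour described in the class parameters; the interval rule (update when `(cur_iter + 1) % interval == 0`, skip when `interval` is not positive) and the name `ema_step` are assumptions made for this example, and plain floats stand in for tensors.

```python
def ema_step(cur_iter, orig, ema, start_iter=0, interval=1, momentum=0.5):
    """Return the EMA weights after training iteration `cur_iter` (0-based)."""
    if cur_iter < start_iter:
        # Start iteration not reached: EMA weights mirror the original ones.
        return dict(orig)
    if interval <= 0 or (cur_iter + 1) % interval != 0:
        # Assumed rule: not an update iteration, keep the EMA weights as is.
        return ema
    # Update iteration: move the EMA weights toward the trained weights.
    return {k: ema[k] + (orig[k] - ema[k]) * momentum for k in orig}


ema = {"w": 0.0}
for cur_iter, w in enumerate([1.0, 2.0, 4.0, 8.0]):
    ema = ema_step(cur_iter, {"w": w}, ema, start_iter=2)
    print(cur_iter, ema)
```

In this run the first two iterations copy the trained weight, and averaging only begins at iteration 2.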