mmagic.models.base_models.average_model
Module Contents
Classes
| ExponentialMovingAverage | Implements the exponential moving average (EMA) of the model. |
| RampUpEMA | Implements the exponential moving average with ramping up momentum. |
- class mmagic.models.base_models.average_model.ExponentialMovingAverage(model: torch.nn.Module, momentum: float = 0.0002, interval: int = 1, device: Optional[torch.device] = None, update_buffers: bool = False)
Bases: mmengine.model.BaseAveragedModel
Implements the exponential moving average (EMA) of the model.
All parameters are updated by the formula as below:
\[X^{\mathrm{ema}}_{t+1} = (1 - \mathrm{momentum}) \cdot X^{\mathrm{ema}}_{t} + \mathrm{momentum} \cdot X_t\]
- Parameters
model (nn.Module) – The model to be averaged.
momentum (float) – The momentum used for updating the EMA parameters. Defaults to 0.0002. The EMA parameters are updated with the formula \(averaged\_param = (1 - momentum) \cdot averaged\_param + momentum \cdot source\_param\).
interval (int) – Interval between two updates. Defaults to 1.
device (torch.device, optional) – If provided, the averaged model will be stored on the specified device. Defaults to None.
update_buffers (bool) – If True, it will compute running averages for both the parameters and the buffers of the model. Defaults to False.
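As a minimal sketch of the update rule, the snippet below applies it to plain Python floats standing in for parameter tensors. The helpers ema_step, run_ema, and ema_half_life are hypothetical and not part of mmagic. run_ema assumes that the first value initializes the average and that a later step triggers an update only when it is divisible by interval; this is an illustrative assumption, not a description of mmengine's exact scheduling logic.

```python
import math
from typing import Sequence


def ema_step(avg: float, src: float, momentum: float) -> float:
    """One EMA update: Xema_{t+1} = (1 - momentum) * Xema_t + momentum * X_t."""
    return (1.0 - momentum) * avg + momentum * src


def run_ema(values: Sequence[float], momentum: float, interval: int = 1) -> float:
    """Average a stream of scalars (hypothetical helper).

    Assumption: values[0] initialises the average, and step k >= 1 triggers
    an update only when k % interval == 0.
    """
    avg = values[0]
    for step, value in enumerate(values[1:], start=1):
        if step % interval == 0:
            avg = ema_step(avg, value, momentum)
    return avg


def ema_half_life(momentum: float) -> float:
    """Updates needed for the old average's weight (1 - momentum)**n to reach 0.5."""
    return math.log(0.5) / math.log(1.0 - momentum)


print(run_ema([0.0, 1.0, 1.0], momentum=0.5))              # 0.75
print(run_ema([0.0, 1.0, 1.0], momentum=0.5, interval=2))  # 0.5
print(round(ema_half_life(0.0002)))                         # 3465
```

With the default momentum of 0.0002, the weight of the existing average halves roughly every 3465 updates; under the scheduling assumption above, an interval greater than 1 stretches that window by the same factor when measured in training iterations.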
- avg_func(averaged_param: torch.Tensor, source_param: torch.Tensor, steps: int) → None
Compute the exponential moving average of the parameters, updating averaged_param in place.
- Parameters
averaged_param (Tensor) – The averaged parameters.
source_param (Tensor) – The source parameters.
steps (int) – The number of times the parameters have been updated.
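To illustrate the in-place contract of avg_func (it returns None and overwrites averaged_param), here is a hedged sketch on Python lists instead of tensors. The name ema_avg_func and its explicit momentum argument are assumptions for illustration; in the class, the momentum comes from the constructor. Because the fixed-momentum rule above does not depend on the step count, the sketch accepts steps only for signature parity.

```python
from typing import List


def ema_avg_func(averaged_param: List[float], source_param: List[float],
                 steps: int, momentum: float = 0.0002) -> None:
    """Hypothetical list-based analogue of avg_func: updates averaged_param in place."""
    del steps  # the fixed-momentum rule does not use the step count
    for i, src in enumerate(source_param):
        averaged_param[i] = (1.0 - momentum) * averaged_param[i] + momentum * src


avg = [0.0, 2.0]
result = ema_avg_func(avg, [1.0, 0.0], steps=1, momentum=0.5)
print(result, avg)  # None [0.5, 1.0]
```

With PyTorch tensors, the same in-place update can be written as averaged_param.mul_(1 - momentum).add_(source_param, alpha=momentum), or equivalently averaged_param.lerp_(source_param, momentum).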
- _load_from_state_dict(state_dict: dict, prefix: str, local_metadata: dict, strict: bool, missing_keys: list, unexpected_keys: list, error_msgs: List[str]) → None
Overrides nn.Module._load_from_state_dict to support loading a state_dict whose EMA module was not wrapped with BaseAveragedModel.
In OpenMMLab 1.0, the model does not wrap the EMA submodule with BaseAveragedModel, so the EMA weight keys in state_dict are missing the "module" prefix. Therefore, BaseAveragedModel needs to automatically add the "module" prefix to any key in state_dict that lacks it.
- Parameters
state_dict (dict) – A dict containing parameters and persistent buffers.
prefix (str) – The prefix for parameters and buffers used in this module.
local_metadata (dict) – A dict containing the metadata for this module.
strict (bool) – Whether to strictly enforce that the keys in state_dict with prefix match the names of parameters and buffers in this module.
missing_keys (List[str]) – If strict=True, add missing keys to this list.
unexpected_keys (List[str]) – If strict=True, add unexpected keys to this list.
error_msgs (List[str]) – Error messages should be added to this list, and will be reported together in load_state_dict().
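The key remapping described above can be sketched on a plain dict. add_module_prefix is a hypothetical helper, not mmagic's implementation: it inserts "module." after prefix for every key under prefix that lacks it. Leaving a "steps" entry untouched is an assumption, based on the averaged model keeping its own update counter outside the wrapped module.

```python
from typing import Any, Dict


def add_module_prefix(state_dict: Dict[str, Any], prefix: str) -> None:
    """Hypothetical sketch: rename '<prefix><name>' to '<prefix>module.<name>' in place."""
    for key in list(state_dict):  # copy the keys, since we mutate the dict
        if not key.startswith(prefix):
            continue  # belongs to another submodule
        name = key[len(prefix):]
        # Assumption: 'steps' is the averaged model's own counter, not a weight.
        if name.startswith('module.') or name == 'steps':
            continue
        state_dict[prefix + 'module.' + name] = state_dict.pop(key)


ckpt = {
    'ema.conv.weight': 1,          # saved without the wrapper: needs the prefix
    'ema.module.fc.bias': 2,       # already wrapped: left as-is
    'ema.steps': 3,                # counter: left as-is (assumption)
    'generator.conv.weight': 4,    # outside `prefix`: ignored
}
add_module_prefix(ckpt, 'ema.')
print(sorted(ckpt))
# ['ema.module.conv.weight', 'ema.module.fc.bias', 'ema.steps', 'generator.conv.weight']
```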
- class mmagic.models.base_models.average_model.RampUpEMA(model: torch.nn.Module, interval: int = 1, ema_kimg: int = 10, ema_rampup: float = 0.05, batch_size: int = 32, eps: float = 1e-08, start_iter: int = 0, device: Optional[torch.device] = None, update_buffers: bool = False)
Bases: mmengine.model.BaseAveragedModel
Implements the exponential moving average with ramping up momentum.
Ref: https://github.com/NVlabs/stylegan3/blob/master/training/training_loop.py
- Parameters
model (nn.Module) – The model to be averaged.
interval (int) – Interval between two updates. Defaults to 1.
ema_kimg (int, optional) – Half-life of the EMA, in thousands of images (kimg). Defaults to 10.
ema_rampup (float, optional) – EMA ramp-up coefficient. Defaults to 0.05.
batch_size (int, optional) – Global batch size. Defaults to 32.
eps (float, optional) – Ramp-up epsilon that prevents division by zero. Defaults to 1e-8.
start_iter (int, optional) – Iteration at which EMA starts. Defaults to 0.
device (torch.device, optional) – If provided, the averaged model will be stored on the specified device. Defaults to None.
update_buffers (bool) – If True, it will compute running averages for both the parameters and the buffers of the model. Defaults to False.
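To relate ema_kimg and batch_size to a per-update momentum, the sketch below assumes the convention of the referenced StyleGAN3 training loop: without ramp-up, the decay beta is chosen so that the old average's weight halves every ema_kimg * 1000 images, i.e. beta = 0.5 ** (batch_size / (ema_kimg * 1000)), and the momentum in the EMA formula is 1 - beta. Whether RampUpEMA derives its momentum exactly this way should be checked against its source.

```python
ema_kimg = 10      # defaults from the RampUpEMA signature
batch_size = 32

ema_nimg = ema_kimg * 1000                     # half-life measured in images
beta = 0.5 ** (batch_size / ema_nimg)          # decay applied once per update
momentum = 1.0 - beta                          # weight given to the new parameters
updates_per_half_life = ema_nimg / batch_size  # 312.5 updates

print(f"momentum per update: {momentum:.6f}")  # momentum per update: 0.002216
print(round(beta ** updates_per_half_life, 9))  # 0.5
```

Under this assumption the steady-state momentum is about an order of magnitude larger than the fixed 0.0002 default of ExponentialMovingAverage.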
- static rampup(steps, ema_kimg=10, ema_rampup=0.05, batch_size=4, eps=1e-08)
Ramp up the EMA momentum.
- Parameters
steps (int) – The number of times the parameters have been updated.
ema_kimg (int, optional) – Half-life of the exponential moving average of generator weights. Defaults to 10.
ema_rampup (float, optional) – EMA ramp-up coefficient. If set to None, ramp-up is disabled. Defaults to 0.05.
batch_size (int, optional) – Total batch size for one training iteration. Defaults to 4.
eps (float, optional) – Epsilon to avoid dividing batch_size by zero. Defaults to 1e-8.
- Returns
Updated momentum.
- Return type
float
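The ramp-up can be sketched after the StyleGAN3 training loop cited above. rampup_sketch is a hypothetical stand-in, and counting the images seen so far as (steps + 1) * batch_size is an assumption. Early in training the effective half-life is capped at ema_rampup times the images seen, so the returned decay is small and the average tracks the model closely; once that cap exceeds ema_kimg * 1000 images, the decay settles at a constant value.

```python
from typing import Optional


def rampup_sketch(steps: int, ema_kimg: float = 10,
                  ema_rampup: Optional[float] = 0.05,
                  batch_size: int = 4, eps: float = 1e-8) -> float:
    """Return the EMA decay beta for the given step (hypothetical sketch)."""
    cur_nimg = (steps + 1) * batch_size  # assumption: images seen so far
    ema_nimg = ema_kimg * 1000           # target half-life in images
    if ema_rampup is not None:
        # Ramp-up: cap the half-life at a fraction of the images seen so far.
        ema_nimg = min(ema_nimg, cur_nimg * ema_rampup)
    # eps guards against dividing batch_size by zero.
    return 0.5 ** (batch_size / max(ema_nimg, eps))


for step in (0, 100, 10**6):
    beta = rampup_sketch(step)
    print(step, f"momentum = {1 - beta:.4f}")
```

Under the same assumption, the momentum used in the EMA formula would be 1 - rampup_sketch(steps, ...).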
- avg_func(averaged_param: torch.Tensor, source_param: torch.Tensor, steps: int) → None
Compute the exponential moving average of the parameters with the ramped-up momentum, updating averaged_param in place.
- Parameters
averaged_param (Tensor) – The averaged parameters.
source_param (Tensor) – The source parameters.
steps (int) – The number of times the parameters have been updated.
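Putting the pieces together, here is a hedged list-based sketch of a ramped update: the momentum is recomputed from the step count on every call, assumed here to be 1 - beta with beta from a StyleGAN3-style ramp-up, and the EMA formula is then applied in place. rampup_avg_func is hypothetical, as is counting images as (steps + 1) * batch_size.

```python
from typing import List, Optional


def rampup_avg_func(averaged_param: List[float], source_param: List[float],
                    steps: int, ema_kimg: float = 10,
                    ema_rampup: Optional[float] = 0.05,
                    batch_size: int = 32, eps: float = 1e-8) -> None:
    """Hypothetical ramped EMA update on lists, modifying averaged_param in place."""
    cur_nimg = (steps + 1) * batch_size  # assumption: images seen so far
    ema_nimg = ema_kimg * 1000
    if ema_rampup is not None:
        ema_nimg = min(ema_nimg, cur_nimg * ema_rampup)
    beta = 0.5 ** (batch_size / max(ema_nimg, eps))
    momentum = 1.0 - beta  # large early in training, small later on
    for i, src in enumerate(source_param):
        averaged_param[i] = (1.0 - momentum) * averaged_param[i] + momentum * src


early, late = [0.0], [0.0]
rampup_avg_func(early, [1.0], steps=0)      # nearly copies the source
rampup_avg_func(late, [1.0], steps=10**6)   # moves only slightly
print(early, late)
```

The batch_size default of 32 here mirrors the class signature; note that the static rampup method documents a different default of 4.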
- _load_from_state_dict(state_dict: dict, prefix: str, local_metadata: dict, strict: bool, missing_keys: list, unexpected_keys: list, error_msgs: List[str]) → None
Overrides nn.Module._load_from_state_dict to support loading a state_dict whose EMA module was not wrapped with BaseAveragedModel.
In OpenMMLab 1.0, the model does not wrap the EMA submodule with BaseAveragedModel, so the EMA weight keys in state_dict are missing the "module" prefix. Therefore, BaseAveragedModel needs to automatically add the "module" prefix to any key in state_dict that lacks it.
- Parameters
state_dict (dict) – A dict containing parameters and persistent buffers.
prefix (str) – The prefix for parameters and buffers used in this module.
local_metadata (dict) – A dict containing the metadata for this module.
strict (bool) – Whether to strictly enforce that the keys in state_dict with prefix match the names of parameters and buffers in this module.
missing_keys (List[str]) – If strict=True, add missing keys to this list.
unexpected_keys (List[str]) – If strict=True, add unexpected keys to this list.
error_msgs (List[str]) – Error messages should be added to this list, and will be reported together in load_state_dict().