mmagic.engine.hooks
Package Contents
Classes
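- ExponentialMovingAverageHook – Exponential Moving Average Hook.
- IterTimerHook – IterTimerHook inherits from mmengine.hooks.IterTimerHook and overrides self._after_iter().
- PGGANFetchDataHook – PGGAN Fetch Data Hook.
- PickleDataHook – Pickle Useful Data Hook.
- ReduceLRSchedulerHook – A hook to update the learning rate.
- BasicVisualizationHook – Basic hook that invokes visualizers during validation and test.
- VisualizationHook – MMagic Visualization Hook. Used to visualize output samples in training, validation and testing.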
- class mmagic.engine.hooks.ExponentialMovingAverageHook(module_keys, interp_mode='lerp', interp_cfg=None, interval=-1, start_iter=0)
Bases: mmengine.hooks.Hook
Exponential Moving Average Hook.
Exponential moving average is a trick that is widely used in current GAN literature, e.g., PGGAN, StyleGAN, and BigGAN. The general idea is to maintain a second model with the same architecture whose parameters are updated as a moving average of the trained weights of the original model. In general, the model with moving-averaged weights achieves better performance.
- Parameters
module_keys (str | tuple[str]) – The name(s) of the EMA model(s). Each key must end with ‘_ema’ so that the original model can be found by discarding the last four characters.
interp_mode (str, optional) – Mode of the interpolation method. Defaults to ‘lerp’.
interp_cfg (dict | None, optional) – Set arguments of the interpolation function. Defaults to None.
interval (int, optional) – Update interval (in iterations). Default: -1.
start_iter (int, optional) – Start iteration for EMA. Before this iteration is reached, the weights of the EMA model are kept identical to those of the original model. Afterwards, its parameters are updated as a moving average of the trained weights of the original model. Default: 0.
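A minimal config sketch for enabling the hook. It assumes a model with generator and generator_ema submodules (hypothetical names chosen to follow the ‘_ema’ suffix rule above) and assumes interp_cfg is forwarded to lerp() as keyword arguments; adapt the keys to your own model.

```python
# Hypothetical MMEngine-style config fragment registering the EMA hook.
# 'generator_ema' is an assumed attribute name; the hook is documented to
# strip the trailing '_ema' to locate the original module ('generator').
custom_hooks = [
    dict(
        type='ExponentialMovingAverageHook',
        module_keys=('generator_ema', ),
        interp_mode='lerp',
        interp_cfg=dict(momentum=0.001),  # assumed to be passed to lerp()
        interval=1,       # update the EMA weights every iteration
        start_iter=0,     # start averaging immediately
    )
]

# Sanity check of the naming convention: drop the last four characters.
ema_key = custom_hooks[0]['module_keys'][0]
orig_key = ema_key[:-4]
print(orig_key)  # generator
```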
- static lerp(a, b, momentum=0.001, momentum_nontrainable=1.0, trainable=True)
Does a linear interpolation of two parameters/buffers.
- Parameters
a (torch.Tensor) – Interpolation start point, refer to orig state.
b (torch.Tensor) – Interpolation end point, refer to ema state.
momentum (float, optional) – The weight for the interpolation formula. Defaults to 0.001.
momentum_nontrainable (float, optional) – The weight for the interpolation formula used for non-trainable parameters. Defaults to 1.0.
trainable (bool, optional) – Whether the input parameters are trainable. If set to False, momentum_nontrainable will be used. Defaults to True.
- Returns
Interpolation result.
- Return type
torch.Tensor
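The interpolation can be sketched in plain Python as below. The sketch assumes the update ema + momentum * (orig - ema): a small momentum moves the EMA state only slightly toward the original weights, and the default momentum_nontrainable=1.0 simply copies non-trainable buffers from the original model. Check the library source for the exact formula.

```python
def lerp(a, b, momentum=0.001, momentum_nontrainable=1.0, trainable=True):
    """Sketch of the EMA interpolation (assumed formula).

    a: value from the original model (orig state).
    b: value from the EMA model (ema state).
    Returns b + (a - b) * m, so m is the weight given to `a`.
    """
    m = momentum if trainable else momentum_nontrainable
    return b + (a - b) * m


# A trainable parameter moves 10% of the way toward the original value.
print(lerp(1.0, 0.0, momentum=0.1))     # 0.1
# A non-trainable buffer with the default weight 1.0 is copied from `a`.
print(lerp(1.0, 0.0, trainable=False))  # 1.0
```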
- every_n_iters(runner: mmengine.runner.Runner, n: int)
This is the function to perform every n iterations.
- Parameters
runner (Runner) – runner used to drive the whole pipeline
n (int) – the number of iterations
- Returns
the latest iterations
- Return type
int
- after_train_iter(runner: mmengine.runner.Runner, batch_idx: int, data_batch: DATA_BATCH = None, outputs: Optional[dict] = None) -> None
This is the function to perform after each training iteration.
- Parameters
runner (Runner) – runner to drive the pipeline
batch_idx (int) – The index of the current batch.
data_batch (DATA_BATCH, optional) – data batch. Defaults to None.
outputs (Optional[dict], optional) – output. Defaults to None.
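The per-iteration update described above (find the original model by dropping ‘_ema’, mirror it before start_iter, average afterwards) can be sketched in plain Python. Nested dicts of floats stand in for real state dicts, and ema_after_train_iter is a hypothetical helper; the interval check and tensor handling are omitted.

```python
def ema_after_train_iter(model_states, module_keys, cur_iter,
                         start_iter=0, momentum=0.001):
    """Update EMA states in place (hypothetical helper).

    model_states: dict mapping module names to {param_name: value},
    a stand-in for real state_dicts.
    """
    if isinstance(module_keys, str):
        module_keys = (module_keys, )
    for key in module_keys:
        ema_state = model_states[key]
        orig_state = model_states[key[:-4]]  # strip the '_ema' suffix
        for name, orig_value in orig_state.items():
            if cur_iter < start_iter:
                # Before start_iter the EMA copy mirrors the original model.
                ema_state[name] = orig_value
            else:
                # Afterwards it moves a small step toward the original.
                ema_state[name] += (orig_value - ema_state[name]) * momentum


states = {'generator': {'w': 1.0}, 'generator_ema': {'w': 0.0}}
ema_after_train_iter(states, 'generator_ema', cur_iter=0, start_iter=5)
print(states['generator_ema']['w'])  # 1.0 (copied)
states['generator']['w'] = 2.0
ema_after_train_iter(states, 'generator_ema', cur_iter=10, start_iter=5,
                     momentum=0.5)
print(states['generator_ema']['w'])  # 1.5
```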
- class mmagic.engine.hooks.IterTimerHook
Bases: mmengine.hooks.IterTimerHook
IterTimerHook inherits from mmengine.hooks.IterTimerHook and overrides self._after_iter(). This hook should be used together with mmagic.engine.runner.MultiValLoop and mmagic.engine.runner.MultiTestLoop.
- _after_iter(runner, batch_idx: int, data_batch: DATA_BATCH = None, outputs: Optional[Union[dict, Sequence[mmengine.structures.BaseDataElement]]] = None, mode: str = 'train') -> None
Calculate the time of an iteration and update the “time” HistoryBuffer of runner.message_hub. If mode is ‘train’, runner.max_iters is taken as the total number of iterations and used to calculate the remaining time. If mode is ‘val’ or ‘test’, runner.val_loop.total_length or runner.test_loop.total_length, respectively, is used as the total number of iterations. For how total_length is calculated, refer to mmagic.engine.runner.MultiValLoop.run() and mmagic.engine.runner.MultiTestLoop.run().
- Parameters
runner (Runner) – The runner of the training, validation and testing process.
batch_idx (int) – The index of the current batch in the loop.
data_batch (Sequence[dict], optional) – Data from dataloader. Defaults to None.
outputs (dict or sequence, optional) – Outputs from model. Defaults to None.
mode (str) – Current mode of runner. Defaults to ‘train’.
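The choice of total iteration count described above can be sketched as a small ETA helper. remaining_seconds is hypothetical, and it assumes the average iteration time is already known; in practice that average is maintained through runner.message_hub, which this sketch does not model.

```python
def remaining_seconds(mode, cur_iter, avg_iter_time,
                      max_iters=None, total_length=None):
    """Estimate the remaining time (hypothetical helper).

    mode: 'train', 'val' or 'test'.
    cur_iter: zero-based index of the iteration that just finished.
    max_iters: total training iterations (used in 'train' mode).
    total_length: total iterations of the val/test loop.
    """
    if mode == 'train':
        total = max_iters
    elif mode in ('val', 'test'):
        total = total_length
    else:
        raise ValueError(f'unknown mode: {mode}')
    # Iterations still to run after the current one.
    return avg_iter_time * (total - cur_iter - 1)


print(remaining_seconds('train', cur_iter=99, avg_iter_time=0.5,
                        max_iters=1000))  # 450.0
print(remaining_seconds('val', cur_iter=9, avg_iter_time=0.2,
                        total_length=50))
```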
- class mmagic.engine.hooks.PGGANFetchDataHook
Bases: mmengine.hooks.Hook
PGGAN Fetch Data Hook.
- Parameters
interval (int, optional) – The interval of calling this hook. If set to -1, this hook will not be called. Defaults to 1.
- before_train_iter(runner, batch_idx: int, data_batch: DATA_BATCH = None) -> None
All subclasses should override this method if they need any operations before each training iteration.
- Parameters
runner (Runner) – The runner of the training process.
batch_idx (int) – The index of the current batch in the train loop.
data_batch (dict or tuple or list, optional) – Data from dataloader.
- update_dataloader(dataloader: torch.utils.data.dataloader.DataLoader, curr_scale: int) -> Optional[torch.utils.data.dataloader.DataLoader]
Update the data loader.
- Parameters
dataloader (DataLoader) – The dataloader to be updated.
curr_scale (int) – The current scale of the generated image.
- Returns
The updated dataloader, or None if the dataloader does not need to be updated.
- Return type
Optional[DataLoader]
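The return contract (a new loader, or None when nothing changes) can be illustrated with a toy caller. ToyLoader, BATCH_SIZE_PER_SCALE and the idea that the loader is rebuilt when the per-scale batch size changes are assumptions for illustration only; the real hook operates on torch DataLoader objects.

```python
from dataclasses import dataclass


@dataclass
class ToyLoader:
    """Hypothetical stand-in for a torch DataLoader."""
    batch_size: int


# Assumed per-scale batch sizes (illustrative, not taken from MMagic).
BATCH_SIZE_PER_SCALE = {4: 64, 8: 32, 16: 16}


def update_dataloader(dataloader, curr_scale):
    """Return a rebuilt loader, or None if nothing needs to change."""
    new_bs = BATCH_SIZE_PER_SCALE.get(curr_scale, dataloader.batch_size)
    if new_bs == dataloader.batch_size:
        return None
    return ToyLoader(batch_size=new_bs)


loader = ToyLoader(batch_size=64)
for scale in (4, 8, 8, 16):
    updated = update_dataloader(loader, scale)
    if updated is not None:  # on None, the caller keeps the old loader
        loader = updated
print(loader.batch_size)  # 16
```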
- class mmagic.engine.hooks.PickleDataHook(output_dir, data_name_list, interval=-1, before_run=False, after_run=False, filename_tmpl='iter_{}.pkl')
Bases: mmengine.hooks.Hook
Pickle Useful Data Hook.
This hook is used in SinGAN training to save data that is needed later for testing or inference.
- Parameters
output_dir (str) – The output path for saving pickled data.
data_name_list (list[str]) – The names of the entries in the outputs dict to be saved.
interval (int) – The interval of calling this hook. If set to -1, the PickleDataHook will not be called during training. Default: -1.
before_run (bool, optional) – Whether to save before running. Defaults to False.
after_run (bool, optional) – Whether to save after running. Defaults to False.
filename_tmpl (str, optional) – Format string used to build the names of the saved files. Defaults to ‘iter_{}.pkl’.
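A plain-Python sketch of what saving a subset of outputs with filename_tmpl might look like. pickle_selected is a hypothetical helper, the entry names are made up, and it assumes the file index is the 1-based iteration number; the real hook also converts tensors to NumPy before pickling, which is omitted here.

```python
import os
import pickle
import tempfile


def pickle_selected(outputs, data_name_list, output_dir, cur_iter,
                    filename_tmpl='iter_{}.pkl'):
    """Save the entries named in data_name_list to one pickle file.

    Hypothetical helper; assumes the file index is cur_iter + 1.
    """
    os.makedirs(output_dir, exist_ok=True)
    data = {name: outputs[name] for name in data_name_list}
    path = os.path.join(output_dir, filename_tmpl.format(cur_iter + 1))
    with open(path, 'wb') as f:
        pickle.dump(data, f)
    return path


with tempfile.TemporaryDirectory() as tmp:
    # Hypothetical output names; only the selected keys are saved.
    outputs = {'noise_weights': [0.1, 0.2], 'curr_stage': 3, 'loss': 0.5}
    path = pickle_selected(outputs, ['noise_weights', 'curr_stage'], tmp,
                           cur_iter=999)
    print(os.path.basename(path))  # iter_1000.pkl
    with open(path, 'rb') as f:
        loaded = pickle.load(f)
    print(loaded)  # {'noise_weights': [0.1, 0.2], 'curr_stage': 3}
```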
- after_run(runner)
The behavior after the whole run.
- Parameters
runner (object) – The runner.
- before_run(runner)
The behavior before the whole run.
- Parameters
runner (object) – The runner.
- after_train_iter(runner, batch_idx: int, data_batch: DATA_BATCH = None, outputs: Optional[dict] = None)
The behavior after each train iteration.
- Parameters
runner (Runner) – The runner of the training process.
batch_idx (int) – The index of the current batch in the train loop.
data_batch (Sequence[dict], optional) – Data from dataloader. Defaults to None.
outputs (dict, optional) – Outputs from model. Defaults to None.
- _pickle_data(runner: mmengine.runner.Runner)
Save target data to pickle file.
- Parameters
runner (Runner) – The runner of the training process.
- _get_numpy_data(data: Tuple[List[torch.Tensor], torch.Tensor, int]) -> Tuple[List[numpy.ndarray], numpy.ndarray, int]
Convert a tensor or a list of tensors to a NumPy array or a list of NumPy arrays.
- Parameters
data (Tuple[List[Tensor], Tensor, int]) – Data to be converted.
- Returns
Converted data.
- Return type
Tuple[List[np.ndarray], np.ndarray, int]
- class mmagic.engine.hooks.ReduceLRSchedulerHook(val_metric: str = None, by_epoch=True, interval=1)
Bases: mmengine.hooks.ParamSchedulerHook
A hook to update learning rate.
- Parameters
val_metric (str) – The validation metric. If val_metric is not None, it is checked to decide whether to reduce the learning rate. Default: None.
by_epoch (bool) – Whether to update by epoch. Default: True.
interval (int) – The interval of iterations to update. Default: 1.
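A minimal config sketch using only the parameters documented above. 'PSNR' is an assumed metric name and must match a key actually reported by your evaluator; pairing this hook with a suitable learning-rate scheduler is left to your config.

```python
# Hypothetical config fragment for ReduceLRSchedulerHook.
custom_hooks = [
    dict(
        type='ReduceLRSchedulerHook',
        val_metric='PSNR',  # assumed metric name
        by_epoch=False,     # step the schedulers by iteration, not epoch
        interval=1000,      # ... every 1000 iterations
    )
]
```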
- after_train_epoch(runner: mmengine.runner.Runner)
Call step function for each scheduler after each train epoch.
- Parameters
runner (Runner) – The runner of the training process.
- after_train_iter(runner: mmengine.runner.Runner, batch_idx: int, data_batch: DATA_BATCH = None, outputs: Optional[dict] = None) -> None
Call step function for each scheduler after each iteration.
- Parameters
runner (Runner) – The runner of the training process.
batch_idx (int) – The index of the current batch in the train loop.
data_batch (Sequence[dict], optional) – Data from dataloader. Retained to keep this interface consistent with other hooks. Defaults to None.
outputs (dict, optional) – Outputs from model. Retained to keep this interface consistent with other hooks. Defaults to None.
- class mmagic.engine.hooks.BasicVisualizationHook(interval: dict = {}, on_train=False, on_val=True, on_test=True)
Bases: mmengine.hooks.Hook
Basic hook that invokes visualizers during validation and test.
- Parameters
interval (int | dict) – Visualization interval. Default: {}.
on_train (bool) – Whether to call the hook during training. Defaults to False.
on_val (bool) – Whether to call the hook during validation. Defaults to True.
on_test (bool) – Whether to call the hook during test. Defaults to True.
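A minimal config sketch that registers the hook with the parameters listed above; the values simply restate the defaults for on_train, on_val and on_test and set an integer interval.

```python
# Hypothetical config fragment: visualize during validation and test only.
custom_hooks = [
    dict(
        type='BasicVisualizationHook',
        interval=1,       # int form of the interval argument
        on_train=False,
        on_val=True,
        on_test=True,
    )
]
```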
- priority = NORMAL
- _after_iter(runner, batch_idx: int, data_batch: Optional[Sequence[dict]], outputs: Optional[Sequence[mmengine.structures.BaseDataElement]], mode=None) -> None
Show or write the predicted results.
- Parameters
runner (Runner) – The runner of the training process.
batch_idx (int) – The index of the current batch in the test loop.
data_batch (Sequence[dict], optional) – Data from dataloader. Defaults to None.
outputs (Sequence[BaseDataElement], optional) – Outputs from model. Defaults to None.
- class mmagic.engine.hooks.VisualizationHook(interval: int = 1000, vis_kwargs_list: Tuple[List[dict], dict] = None, fixed_input: bool = True, n_samples: Optional[int] = 64, n_row: Optional[int] = None, message_hub_vis_kwargs: Optional[Tuple[str, dict, List[str], List[Dict]]] = None, save_at_test: bool = True, max_save_at_test: int = 100, test_vis_keys: Optional[Union[str, List[str]]] = None, show: bool = False, wait_time: float = 0)
Bases: mmengine.hooks.Hook
MMagic Visualization Hook. Used to visualize output samples in training, validation and testing. In this hook, a list called vis_kwargs_list controls how samples are generated and how they are visualized. Each element of vis_kwargs_list, called vis_kwargs, may contain the following keywords:
- Required keywords:
  - ‘type’: Value must be a string. Denotes which kind of sampler is used to generate images. Refer to get_sampler().
- Optional keywords (if not passed, the default value is used):
  - ‘n_row’: Value must be an int. The number of images in one row.
  - ‘num_samples’: Value must be an int. The number of samples to visualize.
  - ‘vis_mode’: Value must be a string. How to visualize the generated samples (e.g. image, gif).
  - ‘fixed_input’: Value must be a bool. Whether to use the fixed input during the loop.
  - ‘draw_gt’: Value must be a bool. Whether to save the real images.
  - ‘target_keys’: Value must be a string or a list of strings. The keys of the target images to visualize.
  - ‘name’: Value must be a string. If not passed, vis_kwargs[‘type’] is used as the default.
For convenience, a group of aliases of sampler types is also defined for the models supported in MMagic; refer to self.VIS_KWARGS_MAPPING.
Example
>>> # for GAN models
>>> custom_hooks = [
...     dict(
...         type='VisualizationHook',
...         interval=1000,
...         fixed_input=True,
...         vis_kwargs_list=dict(type='GAN', name='fake_img'))]
>>> # for Translation models
>>> custom_hooks = [
...     dict(
...         type='VisualizationHook',
...         interval=10,
...         fixed_input=False,
...         vis_kwargs_list=[dict(type='Translation',
...                               name='translation_train',
...                               n_samples=6, draw_gt=True,
...                               n_row=3),
...                          dict(type='TranslationVal',
...                               name='translation_val',
...                               n_samples=16, draw_gt=True,
...                               n_row=4)])]
NOTE: settings are resolved with the precedence user-defined vis_kwargs > vis_kwargs_mapping > hook init args.
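The precedence rule above can be sketched as a dict merge. resolve_vis_kwargs is a hypothetical helper, and the alias entry mapping ‘GAN’ to a sampler type ‘Noise’ with n_row=8 is an assumption made for illustration; the sketch also assumes an alias entry may replace the user's ‘type’ while ‘name’ still defaults to the user's original ‘type’, as documented above.

```python
from copy import deepcopy


def resolve_vis_kwargs(user_kwargs, vis_kwargs_mapping, hook_defaults):
    """Merge settings (hypothetical helper). Precedence:
    user-defined vis_kwargs > vis_kwargs_mapping > hook init args.
    """
    merged = deepcopy(user_kwargs)
    merged.setdefault('name', merged['type'])  # 'name' defaults to 'type'
    alias = vis_kwargs_mapping.get(merged['type'])
    if alias is not None:
        alias = deepcopy(alias)
        merged['type'] = alias.pop('type', merged['type'])  # resolve alias
        for k, v in alias.items():
            merged.setdefault(k, v)  # mapping fills gaps left by the user
    for k, v in hook_defaults.items():
        merged.setdefault(k, v)      # hook init args fill what remains
    return merged


mapping = {'GAN': {'type': 'Noise', 'n_row': 8}}  # assumed alias entry
defaults = {'n_samples': 64, 'n_row': None, 'fixed_input': True}
resolved = resolve_vis_kwargs({'type': 'GAN', 'n_samples': 16},
                              mapping, defaults)
print(resolved)
# {'type': 'Noise', 'n_samples': 16, 'name': 'GAN', 'n_row': 8, 'fixed_input': True}
```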
- Parameters
interval (int) – Visualization interval. Default: 1000.
vis_kwargs_list (Tuple[List[dict], dict]) – The list of sampling behaviors used to generate images.
fixed_input (bool) – The default setting for whether to use a fixed input to generate samples during the loop. Defaults to True.
n_samples (Optional[int]) – The default value of number of samples to visualize. Defaults to 64.
n_row (Optional[int]) – The default value of number of images in each row in the visualization results. Defaults to None.
message_hub_vis_kwargs (Optional[Tuple[str, dict, List[str], List[Dict]]]) – Keyword arguments for visualizing images stored in the message hub. Defaults to None.
save_at_test (bool) – Whether to save images during test. Defaults to True.
max_save_at_test (int) – Maximum number of samples saved at test time. If None is passed, all samples will be saved. Defaults to 100.
show (bool) – Whether to display the drawn image. Defaults to False.
wait_time (float) – The display interval in seconds. Defaults to 0.
- priority = NORMAL
- VIS_KWARGS_MAPPING
- after_val_iter(runner: mmengine.runner.Runner, batch_idx: int, data_batch: dict, outputs) -> None
VisualizationHook does not support visualization during validation.
- Parameters
runner (Runner) – The runner of the training process.
batch_idx (int) – The index of the current batch in the test loop.
data_batch (Sequence[dict], optional) – Data from dataloader. Defaults to None.
outputs – outputs of the generation model
- after_test_iter(runner: mmengine.runner.Runner, batch_idx: int, data_batch: dict, outputs)
Visualize samples after test iteration.
- Parameters
runner (Runner) – The runner of the training process.
batch_idx (int) – The index of the current batch in the test loop.
data_batch (dict, optional) – Data from dataloader. Defaults to None.
outputs – Outputs of the generation model. Defaults to None.
- after_train_iter(runner: mmengine.runner.Runner, batch_idx: int, data_batch: dict = None, outputs: Optional[dict] = None) -> None
Visualize samples after train iteration.
- Parameters
runner (Runner) – The runner of the training process.
batch_idx (int) – The index of the current batch in the train loop.
data_batch (dict) – Data from dataloader. Defaults to None.
outputs (dict, optional) – Outputs from model. Defaults to None.
- vis_sample(runner: mmengine.runner.Runner, batch_idx: int, data_batch: dict, outputs: Optional[dict] = None) -> None
Visualize samples.
- Parameters
runner (Runner) – The runner containing the model to visualize.
batch_idx (int) – The index of the current batch in the loop.
data_batch (dict) – Data from dataloader. Defaults to None.
outputs (dict, optional) – Outputs from model. Defaults to None.
- vis_from_message_hub(batch_idx: int)
Visualize samples from message hub.
- Parameters
batch_idx (int) – The index of the current batch in the test loop.
color_order (str) – The color order of generated images.
target_mean (Sequence[Union[float, int]]) – The original mean of the image tensor before preprocessing. The image is re-shifted to target_mean before visualizing.
target_std (Sequence[Union[float, int]]) – The original std of the image tensor before preprocessing. The image is re-scaled to target_std before visualizing.
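Re-shifting and re-scaling amounts to undoing the usual (x - mean) / std normalization, i.e. x * std + mean per channel. The sketch below assumes that normalization scheme; denormalize and the list-of-pixels layout are hypothetical, and color-order conversion is omitted.

```python
def denormalize(pixels, target_mean, target_std):
    """Undo (x - mean) / std per channel (hypothetical helper).

    pixels: list of per-pixel channel lists, e.g. [[c0, c1, c2], ...].
    """
    return [[v * s + m for v, m, s in zip(px, target_mean, target_std)]
            for px in pixels]


normalized = [[0.0, 1.0, -1.0]]
print(denormalize(normalized, target_mean=[127.5] * 3,
                  target_std=[127.5] * 3))  # [[127.5, 255.0, 0.0]]
```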