
mmagic.engine.hooks

Package Contents

Classes

ExponentialMovingAverageHook

Exponential Moving Average Hook.

IterTimerHook

IterTimerHook inherits from mmengine.hooks.IterTimerHook and overrides self._after_iter().

PGGANFetchDataHook

PGGAN Fetch Data Hook.

PickleDataHook

Pickle Useful Data Hook.

ReduceLRSchedulerHook

A hook to update learning rate.

BasicVisualizationHook

Basic hook that invokes visualizers during validation and testing.

VisualizationHook

MMagic Visualization Hook. Used to visualize output samples during training, validation and testing.

class mmagic.engine.hooks.ExponentialMovingAverageHook(module_keys, interp_mode='lerp', interp_cfg=None, interval=-1, start_iter=0)[source]

Bases: mmengine.hooks.Hook

Exponential Moving Average Hook.

Exponential moving average (EMA) is a technique widely used in the GAN literature, e.g., PGGAN, StyleGAN, and BigGAN. The general idea is to maintain a second model with the same architecture whose parameters are updated as a moving average of the trained weights of the original model. In general, the model with moving-averaged weights achieves better performance.

Parameters
  • module_keys (str | tuple[str]) – The name(s) of the EMA model(s). Each key must end with ‘_ema’ so that the original model can be found by discarding the last four characters.

  • interp_mode (str, optional) – Mode of the interpolation method. Defaults to ‘lerp’.

  • interp_cfg (dict | None, optional) – Set arguments of the interpolation function. Defaults to None.

  • interval (int, optional) – Interval (in iterations) between EMA updates. Default: -1.

  • start_iter (int, optional) – Iteration at which EMA starts. Before this iteration is reached, the weights of the EMA model are kept identical to those of the original model. Afterwards, its parameters are updated as a moving average of the trained weights of the original model. Default: 0.
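As a hedged illustration, the hook could be enabled through an MMEngine-style custom_hooks config. The sketch below assumes a model exposing a generator submodule together with its generator_ema copy (hypothetical names), and assumes interp_cfg is forwarded as keyword arguments to the interpolation function selected by interp_mode. The small helper shows the ‘_ema’ suffix convention described for module_keys.

```python
# Hypothetical config fragment: EMA for a model whose trained network is
# `generator` and whose EMA copy is `generator_ema` (assumed names).
custom_hooks = [
    dict(
        type='ExponentialMovingAverageHook',
        module_keys=('generator_ema', ),
        interp_mode='lerp',
        # Assumption: forwarded as keyword arguments to `lerp`.
        interp_cfg=dict(momentum=0.001),
        start_iter=0)
]


def original_key(ema_key: str) -> str:
    """Recover the original module name by dropping the '_ema' suffix."""
    if not ema_key.endswith('_ema'):
        raise ValueError(f"EMA key must end with '_ema', got {ema_key!r}")
    return ema_key[:-4]  # '_ema' is four characters long


print(original_key('generator_ema'))  # prints "generator"
```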

static lerp(a, b, momentum=0.001, momentum_nontrainable=1.0, trainable=True)[source]

Linearly interpolates between two parameters or buffers.

Parameters
  • a (torch.Tensor) – Interpolation start point; refers to the original state.

  • b (torch.Tensor) – Interpolation end point; refers to the EMA state.

  • momentum (float, optional) – The weight in the interpolation formula. Defaults to 0.001.

  • momentum_nontrainable (float, optional) – The weight in the interpolation formula used for non-trainable parameters. Defaults to 1.0.

  • trainable (bool, optional) – Whether the input parameters are trainable. If set to False, momentum_nontrainable is used. Defaults to True.

Returns

Interpolation result.

Return type

torch.Tensor
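The interpolation can be sketched as below. This is a minimal illustration, not the library's code: it assumes the result moves the EMA state b toward the original state a by a fraction m, i.e. b + (a - b) * m, so the default momentum=0.001 changes EMA weights slowly while momentum_nontrainable=1.0 copies non-trainable buffers from the original model. The helper name lerp_sketch is hypothetical; it runs on Python floats here and, via operator overloading, would also accept torch.Tensor inputs.

```python
def lerp_sketch(a, b, momentum=0.001, momentum_nontrainable=1.0,
                trainable=True):
    """Blend EMA state `b` toward original state `a` (assumed formula)."""
    m = momentum if trainable else momentum_nontrainable
    # m = 0 keeps the EMA value unchanged; m = 1 copies the original value.
    return b + (a - b) * m


ema = 0.0
for _ in range(3):
    # The original weight stays at 1.0; the EMA value creeps toward it.
    ema = lerp_sketch(1.0, ema, momentum=0.5)
print(ema)  # 0.875
print(lerp_sketch(1.0, 0.0, trainable=False))  # 1.0: buffers are copied
```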

every_n_iters(runner: mmengine.runner.Runner, n: int)[source]

This is the function to perform every n iterations.

Parameters
  • runner (Runner) – runner used to drive the whole pipeline

  • n (int) – the number of iterations

Returns

the latest iterations

Return type

int

after_train_iter(runner: mmengine.runner.Runner, batch_idx: int, data_batch: DATA_BATCH = None, outputs: Optional[dict] = None) None[source]

Called after each training iteration to update the weights of the EMA model.

Parameters
  • runner (Runner) – The runner that drives the pipeline.

  • batch_idx (int) – The index of the current batch.

  • data_batch (DATA_BATCH, optional) – Data batch. Defaults to None.

  • outputs (Optional[dict], optional) – Outputs from the model. Defaults to None.
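One training step of the hook might look like the following sketch, which uses plain dictionaries in place of module state dicts. Only the ‘_ema’ suffix convention and the start_iter behavior come from the parameter descriptions above; the interval test (it + 1 - start_iter) % interval == 0, the treatment of non-positive intervals as "update every iteration", and the helper name ema_step are assumptions made for illustration.

```python
from typing import Dict, Iterable


def ema_step(model: Dict[str, Dict[str, float]], module_keys: Iterable[str],
             it: int, start_iter: int = 0, interval: int = -1,
             momentum: float = 0.001) -> None:
    """Update EMA entries of `model` in place (illustrative sketch)."""
    # Assumption: interval <= 0 means "update every iteration".
    if (it >= start_iter and interval > 0
            and (it + 1 - start_iter) % interval != 0):
        return
    for ema_key in module_keys:
        orig = model[ema_key[:-4]]  # drop the '_ema' suffix
        ema = model[ema_key]
        for name, value in orig.items():
            if it < start_iter:
                # Before start_iter the EMA copy mirrors the original.
                ema[name] = value
            else:
                ema[name] = ema[name] + (value - ema[name]) * momentum


model = {'generator': {'w': 1.0}, 'generator_ema': {'w': 0.0}}
ema_step(model, ('generator_ema', ), it=0, start_iter=5)
print(model['generator_ema']['w'])  # 1.0 (copied before start_iter)
model['generator']['w'] = 3.0
ema_step(model, ('generator_ema', ), it=5, start_iter=5, momentum=0.5)
print(model['generator_ema']['w'])  # 2.0
```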

before_run(runner: mmengine.runner.Runner)[source]

Called once before the run starts.

Parameters

runner (Runner) – runner used to drive the whole pipeline

Raises

RuntimeError – error message

class mmagic.engine.hooks.IterTimerHook[source]

Bases: mmengine.hooks.IterTimerHook

IterTimerHook inherits from mmengine.hooks.IterTimerHook and overrides self._after_iter().

This hook should be used together with mmagic.engine.runner.MultiValLoop and mmagic.engine.runner.MultiTestLoop.

_after_iter(runner, batch_idx: int, data_batch: DATA_BATCH = None, outputs: Optional[Union[dict, Sequence[mmengine.structures.BaseDataElement]]] = None, mode: str = 'train') None[source]

Calculates the time of an iteration and updates the “time” HistoryBuffer of runner.message_hub. If mode is ‘train’, runner.max_iters is taken as the total number of iterations and used to calculate the remaining time. If mode is ‘val’ or ‘test’, runner.val_loop.total_length or runner.test_loop.total_length, respectively, is used as the total number of iterations. For how total_length is calculated, refer to mmagic.engine.runner.MultiValLoop.run() and mmagic.engine.runner.MultiTestLoop.run().

Parameters
  • runner (Runner) – The runner of the training, validation and testing process.

  • batch_idx (int) – The index of the current batch in the loop.

  • data_batch (Sequence[dict], optional) – Data from dataloader. Defaults to None.

  • outputs (dict or sequence, optional) – Outputs from model. Defaults to None.

  • mode (str) – Current mode of runner. Defaults to ‘train’.
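The choice of total used for the remaining-time estimate can be illustrated with a small sketch. The choice of total (max_iters in training, the loop's total_length in validation or testing) follows the description above; the helper name estimate_eta and the simple "average iteration time times remaining iterations" estimate are assumptions for illustration.

```python
def estimate_eta(mode: str, cur_iter: int, avg_iter_time: float,
                 max_iters: int = 0, val_total_length: int = 0,
                 test_total_length: int = 0) -> float:
    """Return the estimated remaining seconds (hypothetical helper)."""
    if mode == 'train':
        total = max_iters
    elif mode == 'val':
        total = val_total_length
    elif mode == 'test':
        total = test_total_length
    else:
        raise ValueError(f'Unknown mode: {mode!r}')
    # `cur_iter` is 0-based, so `cur_iter + 1` iterations are finished.
    remaining = max(total - (cur_iter + 1), 0)
    return remaining * avg_iter_time


print(estimate_eta('train', cur_iter=99, avg_iter_time=0.5, max_iters=1000))  # 450.0
```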

class mmagic.engine.hooks.PGGANFetchDataHook[source]

Bases: mmengine.hooks.Hook

PGGAN Fetch Data Hook.

Parameters

interval (int, optional) – The interval of calling this hook. If set to -1, this hook will not be called. Defaults to 1.

before_train_iter(runner, batch_idx: int, data_batch: DATA_BATCH = None) None[source]

Subclasses should override this method if they need to perform operations before each training iteration.

Parameters
  • runner (Runner) – The runner of the training process.

  • batch_idx (int) – The index of the current batch in the train loop.

  • data_batch (dict or tuple or list, optional) – Data from dataloader.

update_dataloader(dataloader: torch.utils.data.dataloader.DataLoader, curr_scale: int) Optional[torch.utils.data.dataloader.DataLoader][source]

Update the data loader.

Parameters
  • dataloader (DataLoader) – The dataloader to be updated.

  • curr_scale (int) – The current scale of the generated image.

Returns

The updated dataloader, or None if the dataloader does not need to be updated.

Return type

Optional[DataLoader]
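PGGAN trains at progressively higher resolutions and typically lowers the batch size as the scale grows. The sketch below mirrors the documented contract of update_dataloader (return a new loader, or None when nothing needs to change) using a hypothetical LoaderSpec stand-in and a hypothetical per-scale batch-size table instead of a real torch DataLoader; it is not the hook's implementation.

```python
from dataclasses import dataclass, replace
from typing import Dict, Optional


@dataclass(frozen=True)
class LoaderSpec:
    """Hypothetical stand-in for a DataLoader's settings."""
    batch_size: int
    scale: int


# Hypothetical table: batch size to use at each image scale.
BATCH_SIZE_PER_SCALE: Dict[int, int] = {4: 64, 8: 64, 16: 32, 32: 16}


def update_loader(loader: LoaderSpec,
                  curr_scale: int) -> Optional[LoaderSpec]:
    """Return a rebuilt loader spec, or None if no update is needed."""
    new_bs = BATCH_SIZE_PER_SCALE.get(curr_scale, loader.batch_size)
    if new_bs == loader.batch_size and curr_scale == loader.scale:
        return None  # same scale and batch size: keep the old loader
    return replace(loader, batch_size=new_bs, scale=curr_scale)


loader = LoaderSpec(batch_size=64, scale=4)
print(update_loader(loader, 4))   # None
print(update_loader(loader, 16))  # LoaderSpec(batch_size=32, scale=16)
```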

class mmagic.engine.hooks.PickleDataHook(output_dir, data_name_list, interval=-1, before_run=False, after_run=False, filename_tmpl='iter_{}.pkl')[source]

Bases: mmengine.hooks.Hook

Pickle Useful Data Hook.

This hook is used in SinGAN training to save important data that is needed later for testing or inference.

Parameters
  • output_dir (str) – The output path for saving pickled data.

  • data_name_list (list[str]) – The names of the results in the outputs dict to be saved.

  • interval (int) – The interval of calling this hook. If set to -1, the PickleDataHook will not be called during training. Default: -1.

  • before_run (bool, optional) – Whether to save before running. Defaults to False.

  • after_run (bool, optional) – Whether to save after running. Defaults to False.

  • filename_tmpl (str, optional) – Format string for the name of the saved pickle file. Defaults to ‘iter_{}.pkl’.
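A minimal sketch of what such a hook saves: it selects the entries named in data_name_list from an outputs dict and pickles them into output_dir under a name built from filename_tmpl. The helper name save_selected, the example keys, and the caller supplying the iteration number that fills the {} placeholder are assumptions for illustration.

```python
import os
import pickle
import tempfile


def save_selected(outputs: dict, data_name_list, output_dir: str,
                  iteration: int, filename_tmpl: str = 'iter_{}.pkl') -> str:
    """Pickle the selected outputs and return the file path (sketch)."""
    missing = [k for k in data_name_list if k not in outputs]
    if missing:
        raise KeyError(f'Missing keys in outputs: {missing}')
    data = {k: outputs[k] for k in data_name_list}
    os.makedirs(output_dir, exist_ok=True)
    path = os.path.join(output_dir, filename_tmpl.format(iteration))
    with open(path, 'wb') as f:
        pickle.dump(data, f)
    return path


with tempfile.TemporaryDirectory() as tmp:
    # Hypothetical output keys; only the selected ones are saved.
    outputs = {'noise_weights': [0.1, 0.2], 'curr_stage': 3, 'loss': 1.5}
    path = save_selected(outputs, ['noise_weights', 'curr_stage'], tmp, 100)
    with open(path, 'rb') as f:
        restored = pickle.load(f)
    print(os.path.basename(path), restored)
```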

after_run(runner)[source]

The behavior after the run.

Parameters

runner (object) – The runner.

before_run(runner)[source]

The behavior before the run.

Parameters

runner (object) – The runner.

after_train_iter(runner, batch_idx: int, data_batch: DATA_BATCH = None, outputs: Optional[dict] = None)[source]

The behavior after each train iteration.

Parameters
  • runner (Runner) – The runner of the training process.

  • batch_idx (int) – The index of the current batch in the train loop.

  • data_batch (Sequence[dict], optional) – Data from dataloader. Defaults to None.

  • outputs (dict, optional) – Outputs from model. Defaults to None.

_pickle_data(runner: mmengine.runner.Runner)[source]

Save target data to pickle file.

Parameters

runner (Runner) – The runner of the training process.

_get_numpy_data(data: Tuple[List[torch.Tensor], torch.Tensor, int]) Tuple[List[numpy.ndarray], numpy.ndarray, int][source]

Converts a tensor or a list of tensors to a NumPy array or a list of NumPy arrays.

Parameters

data (Tuple[List[Tensor], Tensor, int]) – Data to be converted.

Returns

Converted data.

Return type

Tuple[List[np.ndarray], np.ndarray, int]

class mmagic.engine.hooks.ReduceLRSchedulerHook(val_metric: str = None, by_epoch=True, interval=1)[source]

Bases: mmengine.hooks.ParamSchedulerHook

A hook to update learning rate.

Parameters
  • val_metric (str) – The validation metric to monitor. If val_metric is not None, it is checked to decide when to reduce the learning rate. Default: None.

  • by_epoch (bool) – Whether to update by epoch. Default: True.

  • interval (int) – The interval of iterations to update. Default: 1.
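A hedged config sketch: since the hook subclasses mmengine.hooks.ParamSchedulerHook, it would plausibly replace the default param_scheduler entry in default_hooks. The metric name 'MAE' is a placeholder; use whichever key your evaluator reports.

```python
# Hypothetical config fragment: step schedulers by iteration and monitor
# the validation metric 'MAE' (placeholder name) to reduce the LR.
default_hooks = dict(
    param_scheduler=dict(
        type='ReduceLRSchedulerHook',
        val_metric='MAE',
        by_epoch=False,
        interval=1))
```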

_calculate_average_value()[source]
after_train_epoch(runner: mmengine.runner.Runner)[source]

Call step function for each scheduler after each train epoch.

Parameters

runner (Runner) – The runner of the training process.

after_train_iter(runner: mmengine.runner.Runner, batch_idx: int, data_batch: DATA_BATCH = None, outputs: Optional[dict] = None) None[source]

Call step function for each scheduler after each iteration.

Parameters
  • runner (Runner) – The runner of the training process.

  • batch_idx (int) – The index of the current batch in the train loop.

  • data_batch (Sequence[dict], optional) – Data from dataloader. In order to keep this interface consistent with other hooks, we keep data_batch here. Defaults to None.

  • outputs (dict, optional) – Outputs from model. In order to keep this interface consistent with other hooks, we keep outputs here. Defaults to None.

after_val_epoch(runner, metrics: Optional[Dict[str, float]] = None)[source]

Call step function for each scheduler after each validation epoch.

Parameters
  • runner (Runner) – The runner of the training process.

  • metrics (dict, optional) – The metrics of validation. Default: None.

class mmagic.engine.hooks.BasicVisualizationHook(interval: dict = {}, on_train=False, on_val=True, on_test=True)[source]

Bases: mmengine.hooks.Hook

Basic hook that invokes visualizers during validation and testing.

Parameters
  • interval (int | dict) – Visualization interval. Default: {}.

  • on_train (bool) – Whether to call the hook during training. Defaults to False.

  • on_val (bool) – Whether to call the hook during validation. Defaults to True.

  • on_test (bool) – Whether to call the hook during testing. Defaults to True.
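A hedged usage sketch: the config below enables the hook for validation and testing only. The gating helper should_visualize is hypothetical; it only illustrates how the on_train, on_val and on_test flags might combine with an integer interval, and is not the hook's actual logic (the dict form of interval is not covered here).

```python
# Hypothetical config fragment using the documented arguments.
custom_hooks = [
    dict(type='BasicVisualizationHook', interval=1,
         on_train=False, on_val=True, on_test=True)
]


def should_visualize(mode: str, batch_idx: int, interval: int = 1,
                     on_train: bool = False, on_val: bool = True,
                     on_test: bool = True) -> bool:
    """Hypothetical gate combining the on_* flags with an int interval."""
    enabled = {'train': on_train, 'val': on_val, 'test': on_test}.get(mode, False)
    # Visualize every `interval` batches (1-based count) when enabled.
    return enabled and interval > 0 and (batch_idx + 1) % interval == 0
```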

priority = NORMAL
_after_iter(runner, batch_idx: int, data_batch: Optional[Sequence[dict]], outputs: Optional[Sequence[mmengine.structures.BaseDataElement]], mode=None) None[source]

Shows or writes the predicted results.

Parameters
  • runner (Runner) – The runner of the training process.

  • batch_idx (int) – The index of the current batch in the test loop.

  • data_batch (Sequence[dict], optional) – Data from dataloader. Defaults to None.

  • outputs (Sequence[BaseDataElement], optional) – Outputs from model. Defaults to None.

class mmagic.engine.hooks.VisualizationHook(interval: int = 1000, vis_kwargs_list: Tuple[List[dict], dict] = None, fixed_input: bool = True, n_samples: Optional[int] = 64, n_row: Optional[int] = None, message_hub_vis_kwargs: Optional[Tuple[str, dict, List[str], List[Dict]]] = None, save_at_test: bool = True, max_save_at_test: int = 100, test_vis_keys: Optional[Union[str, List[str]]] = None, show: bool = False, wait_time: float = 0)[source]

Bases: mmengine.hooks.Hook

MMagic Visualization Hook. Used to visualize output samples during training, validation and testing. In this hook, a list called vis_kwargs_list controls how samples are generated and how they are visualized. Each element of this list, called sample_kwargs, may contain the following keys:

  • Required keys:
    • ‘type’: Value must be a string. Denotes which kind of sampler is used to generate images. See get_sampler().

  • Optional keys (default values are used if not passed):
    • ‘n_row’: Value must be an int. The number of images in one row.

    • ‘num_samples’: Value must be an int. The number of samples to visualize.

    • ‘vis_mode’: Value must be a string. How to visualize the generated samples (e.g. image, gif).

    • ‘fixed_input’: Value must be a bool. Whether to use a fixed input during the loop.

    • ‘draw_gt’: Value must be a bool. Whether to save the real images.

    • ‘target_keys’: Value must be a string or a list of strings. The keys of the target images to visualize.

    • ‘name’: Value must be a string. If not passed, sample_kwargs[‘type’] is used as the default.

For convenience, aliases of sampler types are also defined for the models supported in MMagic. See VIS_KWARGS_MAPPING.

Example

>>> # for GAN models
>>> custom_hooks = [
...     dict(
...         type='VisualizationHook',
...         interval=1000,
...         fixed_input=True,
...         vis_kwargs_list=dict(type='GAN', name='fake_img'))]
>>> # for Translation models
>>> custom_hooks = [
...     dict(
...         type='VisualizationHook',
...         interval=10,
...         fixed_input=False,
...         vis_kwargs_list=[dict(type='Translation',
...                               name='translation_train',
...                               n_samples=6, draw_gt=True,
...                               n_row=3),
...                          dict(type='TranslationVal',
...                               name='translation_val',
...                               n_samples=16, draw_gt=True,
...                               n_row=4)])]

Note: user-defined vis_kwargs take precedence over VIS_KWARGS_MAPPING, which in turn takes precedence over the hook's init arguments.
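The precedence order (user-defined vis_kwargs over the mapping over the hook's init arguments) can be illustrated with plain dictionary merging, where later updates win. The helper name, the mapping contents and the default values below are hypothetical; only the ordering and the fallback of ‘name’ to ‘type’ come from the text. For simplicity, the mapping in this sketch only supplies default options for an alias and does not rewrite ‘type’.

```python
def resolve_vis_kwargs(user_kwargs: dict, mapping: dict,
                       hook_defaults: dict) -> dict:
    """Merge sources with precedence user > mapping > hook defaults."""
    merged = dict(hook_defaults)                        # lowest precedence
    merged.update(mapping.get(user_kwargs.get('type'), {}))  # alias defaults
    merged.update(user_kwargs)                          # highest precedence
    merged.setdefault('name', merged.get('type'))       # 'name' falls back to 'type'
    return merged


HYPOTHETICAL_MAPPING = {'GAN': dict(n_row=8, vis_mode='image')}
HOOK_DEFAULTS = dict(n_samples=64, n_row=None, fixed_input=True)

resolved = resolve_vis_kwargs(dict(type='GAN', n_row=4),
                              HYPOTHETICAL_MAPPING, HOOK_DEFAULTS)
print(resolved['n_row'], resolved['vis_mode'], resolved['name'])  # 4 image GAN
```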

Parameters
  • interval (int) – Visualization interval. Default: 1000.

  • vis_kwargs_list (List[dict] | dict) – The list of sampling behaviors used to generate images. Defaults to None.

  • fixed_input (bool) – The default action of whether to use a fixed input to generate samples during the loop. Defaults to True.

  • n_samples (Optional[int]) – The default value of number of samples to visualize. Defaults to 64.

  • n_row (Optional[int]) – The default value of number of images in each row in the visualization results. Defaults to None.

  • message_hub_vis_kwargs (Optional[Tuple[str, dict, List[str], List[Dict]]]) – Keyword arguments for visualizing images stored in the message hub. Defaults to None.

  • save_at_test (bool) – Whether save images during test. Defaults to True.

  • max_save_at_test (int) – Maximum number of samples saved at test time. If None is passed, all samples will be saved. Defaults to 100.

  • show (bool) – Whether to display the drawn image. Defaults to False.

  • wait_time (float) – The time to wait, in seconds, when displaying images. Defaults to 0.

priority = NORMAL
VIS_KWARGS_MAPPING
after_val_iter(runner: mmengine.runner.Runner, batch_idx: int, data_batch: dict, outputs) None[source]

VisualizationHook does not support visualization during validation.

Parameters
  • runner (Runner) – The runner of the training process.

  • batch_idx (int) – The index of the current batch in the test loop.

  • data_batch (Sequence[dict], optional) – Data from dataloader. Defaults to None.

  • outputs – Outputs of the generation model.

after_test_iter(runner: mmengine.runner.Runner, batch_idx: int, data_batch: dict, outputs)[source]

Visualize samples after test iteration.

Parameters
  • runner (Runner) – The runner of the training process.

  • batch_idx (int) – The index of the current batch in the test loop.

  • data_batch (dict, optional) – Data from dataloader. Defaults to None.

  • outputs – Outputs of the generation model. Defaults to None.

after_train_iter(runner: mmengine.runner.Runner, batch_idx: int, data_batch: dict = None, outputs: Optional[dict] = None) None[source]

Visualize samples after train iteration.

Parameters
  • runner (Runner) – The runner of the training process.

  • batch_idx (int) – The index of the current batch in the train loop.

  • data_batch (dict) – Data from dataloader. Defaults to None.

  • outputs (dict, optional) – Outputs from model. Defaults to None.

vis_sample(runner: mmengine.runner.Runner, batch_idx: int, data_batch: dict, outputs: Optional[dict] = None) None[source]

Visualize samples.

Parameters
  • runner (Runner) – The runner contains model to visualize.

  • batch_idx (int) – The index of the current batch in loop.

  • data_batch (dict) – Data from dataloader. Defaults to None.

  • outputs (dict, optional) – Outputs from model. Defaults to None.

vis_from_message_hub(batch_idx: int)[source]

Visualize samples from message hub.

Parameters
  • batch_idx (int) – The index of the current batch in the test loop.

  • color_order (str) – The color order of generated images.

  • target_mean (Sequence[Union[float, int]]) – The original mean of the image tensor before preprocessing. Image will be re-shifted to target_mean before visualizing.

  • target_std (Sequence[Union[float, int]]) – The original std of the image tensor before preprocessing. Image will be re-scaled to target_std before visualizing.
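The target_mean and target_std arguments undo input normalization before visualizing. Assuming images were normalized per channel as (x - mean) / std, are stored channel-first, and should be displayed in RGB order, the inverse can be sketched as follows. Plain lists stand in for tensors, and the helper name denormalize is hypothetical.

```python
from typing import List, Sequence


def denormalize(img: List[List[float]], target_mean: Sequence[float],
                target_std: Sequence[float],
                color_order: str = 'rgb') -> List[List[float]]:
    """Invert per-channel normalization on a channel-first image (sketch)."""
    if not (len(img) == len(target_mean) == len(target_std)):
        raise ValueError('channel count must match mean/std length')
    # Re-scale by std, then re-shift by mean: x = x_norm * std + mean.
    out = [[v * s + m for v in ch]
           for ch, m, s in zip(img, target_mean, target_std)]
    if color_order.lower() == 'bgr':
        out = out[::-1]  # assumption: reorder BGR -> RGB for display
    return out


img = [[0.0, 1.0], [0.0, -1.0], [1.0, 1.0]]  # 3 channels, 2 pixels each
print(denormalize(img, [127.5] * 3, [127.5] * 3))
# [[127.5, 255.0], [127.5, 0.0], [255.0, 255.0]]
```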
