mmagic.models.base_models.base_gan
Module Contents
Attributes
- class mmagic.models.base_models.base_gan.BaseGAN(generator: ModelType, discriminator: Optional[ModelType] = None, data_preprocessor: Optional[Union[dict, mmengine.Config]] = None, generator_steps: int = 1, discriminator_steps: int = 1, noise_size: Optional[int] = None, ema_config: Optional[Dict] = None, loss_config: Optional[Dict] = None)
Bases: mmengine.model.BaseModel
Base class for GAN models.
- Parameters
generator (ModelType) – The config or model of the generator.
discriminator (Optional[ModelType]) – The config or model of the discriminator. Defaults to None.
data_preprocessor (Optional[Union[dict, Config]]) – The pre-process config or DataPreprocessor.
generator_steps (int) – The number of times the generator is completely updated before the discriminator is updated. Defaults to 1.
discriminator_steps (int) – The number of times the discriminator is completely updated before the generator is updated. Defaults to 1.
ema_config (Optional[Dict]) – The config for the generator’s exponential moving average (EMA) setting. Defaults to None.
loss_config (Optional[Dict]) – Loss config used to build loss modules or define the loss weights. See _init_loss(). Defaults to None.
- property generator_steps: int
The number of times the generator is completely updated before the discriminator is updated.
- Type
int
- property discriminator_steps: int
The number of times the discriminator is completely updated before the generator is updated.
- Type
int
- property device: torch.device
Get current device of the model.
- Returns
The current device of the model.
- Return type
torch.device
- property with_ema_gen: bool
Whether the GAN uses an exponential moving average (EMA) of the generator.
- Returns
True if this GAN model maintains an EMA generator, False otherwise.
- Return type
bool
- static gather_log_vars(log_vars_list: List[Dict[str, torch.Tensor]]) → Dict[str, torch.Tensor]
Gather a list of log_vars into a single dict.
- Parameters
log_vars_list (List[Dict[str, Tensor]]) – The log_vars dicts to gather.
- Returns
The gathered log_vars.
- Return type
Dict[str, Tensor]
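As a minimal sketch, assuming gather_log_vars averages each key over the list (the docstring above only states the input and output types), the gathering step could look like this. Plain floats stand in for tensors, and gather_log_vars_sketch is a hypothetical name, not part of MMagic.

```python
from typing import Dict, List


def gather_log_vars_sketch(
        log_vars_list: List[Dict[str, float]]) -> Dict[str, float]:
    """Average each logged value over a list of log dicts (hypothetical)."""
    # A single entry needs no gathering.
    if len(log_vars_list) == 1:
        return log_vars_list[0]
    gathered = {}
    for key in log_vars_list[0].keys():
        # Every dict is expected to report the same keys.
        if not all(key in log_vars for log_vars in log_vars_list):
            raise KeyError(f"'{key}' missing from some log_vars")
        values = [log_vars[key] for log_vars in log_vars_list]
        gathered[key] = sum(values) / len(values)
    return gathered


# Two generator steps produce two log dicts; the result averages them.
print(gather_log_vars_sketch([{'loss_gen': 1.0}, {'loss_gen': 3.0}]))
# {'loss_gen': 2.0}
```

Averaging keeps the logged scale independent of how many inner steps were run, which is why it is a natural (assumed) choice here.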
- _init_loss(loss_config: Optional[Dict] = None) → None
Initialize customized loss modules.
loss_config may take one of the following forms:
- loss_config is None: Users implement all loss calculations in their own functions, and the weight of each loss term is hard-coded.
- loss_config is a dict of scalars or strings: Users implement all loss calculations and use the passed loss_config to control the weight or behavior of each loss calculation, unpacking and using each field themselves.
    loss_config = dict(gp_norm_mode='HWC', gp_loss_weight=10)
- loss_config is a dict of dicts: Each field in loss_config is used to build a corresponding loss module, and the loss calculation functions predefined by BaseGAN are used to calculate the loss.
    loss_config = dict()
Example
    loss_config = dict(
        # BaseGAN pre-defined fields
        gan_loss=dict(type='GANLoss', gan_type='wgan-logistic-ns'),
        disc_auxiliary_loss=dict(
            type='R1GradientPenalty', loss_weight=10. / 2., interval=2,
            norm_mode='HWC',
            data_info=dict(real_data='real_imgs', discriminator='disc')),
        gen_auxiliary_loss=dict(
            type='GeneratorPathRegularizer', loss_weight=2, pl_batch_shrink=2,
            interval=g_reg_interval,  # must be defined elsewhere in the config
            data_info=dict(generator='gen', num_batches='batch_size')),
        # user-defined fields for loss weights or loss calculation
        my_loss_2=dict(weight=2, norm_mode='L1'),
        my_loss_3=2,
        my_loss_4_norm_type='L2')
- Parameters
loss_config (Optional[Dict], optional) – Loss config used to build loss modules or define the loss weights. Defaults to None.
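The three forms above can be illustrated with a small dispatch sketch. It assumes that only the fields marked "BaseGAN pre-defined" in the example (gan_loss, disc_auxiliary_loss, gen_auxiliary_loss) are built into loss modules when they hold a dict, and that everything else is left for the user to unpack. split_loss_config and build_fn are hypothetical names; build_fn stands in for whatever registry-based builder the real model uses.

```python
from typing import Any, Callable, Dict, Optional, Tuple

# Field names marked as "BaseGAN pre-defined" in the example config above.
PREDEFINED_FIELDS = ('gan_loss', 'disc_auxiliary_loss', 'gen_auxiliary_loss')


def split_loss_config(
        loss_config: Optional[Dict[str, Any]],
        build_fn: Callable[[dict], Any],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split loss_config into built loss modules and user-handled fields."""
    if loss_config is None:
        # Users hard-code every loss term in their own functions.
        return {}, {}
    modules, user_fields = {}, {}
    for name, value in loss_config.items():
        if name in PREDEFINED_FIELDS and isinstance(value, dict):
            # A dict under a pre-defined field describes a module to build.
            modules[name] = build_fn(value)
        else:
            # Scalars, strings and custom dicts are unpacked by the user.
            user_fields[name] = value
    return modules, user_fields


modules, user_fields = split_loss_config(
    dict(gan_loss=dict(type='GANLoss', gan_type='wgan-logistic-ns'),
         my_loss_3=2, my_loss_4_norm_type='L2'),
    build_fn=lambda cfg: f"built {cfg['type']}")
print(modules)      # {'gan_loss': 'built GANLoss'}
print(user_fields)  # {'my_loss_3': 2, 'my_loss_4_norm_type': 'L2'}
```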
- noise_fn(noise: mmagic.utils.typing.NoiseVar = None, num_batches: int = 1)
Sampling function for noise. There are three scenarios in this function:
- If noise is a callable function, it is used to sample num_batches of noise.
- If noise is None, num_batches of noise are sampled from a Gaussian distribution.
- If noise is a torch.Tensor, it is returned directly.
- Parameters
noise (Union[Tensor, Callable, List[int], None]) – A batch of noise given directly as a torch.Tensor, or a callable function that samples a batch of noise. None means the default noise sampler is used. Defaults to None.
num_batches (int, optional) – The number of batches of noise to sample. Ignored if noise is a Tensor. Defaults to 1.
- Returns
Sampled noise tensor.
- Return type
Tensor
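The three scenarios can be sketched in plain Python, with nested lists standing in for a (num_batches, noise_size) tensor. noise_fn_sketch is a hypothetical name, noise_size=4 is an arbitrary placeholder, and the List[int] variant in the type annotation is not covered.

```python
import random
from typing import Callable, List, Union

Noise = List[List[float]]  # stand-in for a (num_batches, noise_size) tensor


def noise_fn_sketch(noise: Union[Noise, Callable[[int], Noise], None] = None,
                    num_batches: int = 1,
                    noise_size: int = 4) -> Noise:
    """Mirror the three documented noise scenarios (hypothetical helper)."""
    if callable(noise):
        # Callable: let the user's sampler produce num_batches rows.
        return noise(num_batches)
    if noise is None:
        # None: draw num_batches rows from a standard Gaussian.
        return [[random.gauss(0.0, 1.0) for _ in range(noise_size)]
                for _ in range(num_batches)]
    # Pre-built noise: return it unchanged; num_batches is ignored.
    return noise


fixed = [[0.0, 1.0, 2.0, 3.0]]
print(noise_fn_sketch(fixed, num_batches=8) is fixed)  # True
print(len(noise_fn_sketch(num_batches=3)))             # 3
```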
- _init_ema_model(ema_config: dict)
Initialize an EMA model corresponding to the given ema_config. If ema_config is an empty dict or None, the EMA model is not initialized.
- Parameters
ema_config (dict) – Config to initialize the EMA model.
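A minimal sketch of the documented rule (None or an empty dict means no EMA model) together with the standard exponential moving average update, ema ← decay · ema + (1 − decay) · param. Parameters are plain floats instead of tensors, and EMASketch, init_ema_model and decay are hypothetical names. Libraries differ in convention (some parameterize the update by a small momentum equal to 1 − decay), so this is not MMagic's actual EMA module. In this sketch, with_ema_gen would simply correspond to the returned object not being None.

```python
import copy
from typing import Dict, Optional


class EMASketch:
    """Hypothetical EMA holder; params are plain floats instead of tensors."""

    def __init__(self, params: Dict[str, float], decay: float = 0.999):
        # Start the averaged copy from the current generator weights.
        self.averaged = copy.deepcopy(params)
        self.decay = decay

    def update(self, params: Dict[str, float]) -> None:
        # ema <- decay * ema + (1 - decay) * param
        for name, value in params.items():
            self.averaged[name] = (self.decay * self.averaged[name]
                                   + (1.0 - self.decay) * value)


def init_ema_model(params: Dict[str, float],
                   ema_config: Optional[dict]) -> Optional[EMASketch]:
    # None or {} means no EMA model is created.
    if not ema_config:
        return None
    return EMASketch(params, **ema_config)


gen_params = {'w': 1.0}
ema = init_ema_model(gen_params, dict(decay=0.5))
ema.update({'w': 3.0})
print(ema.averaged)                      # {'w': 2.0}
print(init_ema_model(gen_params, None))  # None
```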
- _get_valid_model(batch_inputs: mmagic.utils.typing.ForwardInputs) → str
Try to get the valid forward model from the inputs.
If a forward model is defined in batch_inputs, it is used as the forward model.
If no forward model is defined in batch_inputs, 'ema' is returned if with_ema_gen is True; otherwise, 'orig' is returned.
- Parameters
batch_inputs (ForwardInputs) – Inputs passed to forward().
- Returns
Forward model used to generate images ('orig', 'ema' or 'ema/orig').
- Return type
str
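The selection rule can be sketched as follows. The dict key holding the requested forward model is assumed here to be 'sample_model', and the two validation checks are an assumption about reasonable behavior rather than something stated above; get_valid_model_sketch is a hypothetical name.

```python
from typing import Any


def get_valid_model_sketch(batch_inputs: Any, with_ema_gen: bool,
                           key: str = 'sample_model') -> str:
    """Pick 'orig', 'ema' or 'ema/orig' as described above (hypothetical)."""
    # Only dict inputs can carry an explicit choice; the key name is assumed.
    sample_model = (batch_inputs.get(key)
                    if isinstance(batch_inputs, dict) else None)
    if sample_model is None:
        # Default: prefer the EMA generator whenever one exists.
        sample_model = 'ema' if with_ema_gen else 'orig'
    if sample_model not in ('orig', 'ema', 'ema/orig'):
        raise ValueError(f'Unsupported forward model: {sample_model!r}')
    if 'ema' in sample_model and not with_ema_gen:
        raise ValueError('EMA generator requested but not initialized')
    return sample_model


print(get_valid_model_sketch(dict(num_batches=4), with_ema_gen=True))      # ema
print(get_valid_model_sketch(dict(sample_model='orig'), with_ema_gen=True))  # orig
```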
- forward(inputs: mmagic.utils.typing.ForwardInputs, data_samples: Optional[list] = None, mode: Optional[str] = None) → mmagic.utils.typing.SampleList
Sample images with the given inputs. If the forward model is 'ema' or 'orig', the images generated by the corresponding generator are returned. If it is 'ema/orig', the images generated by both the original generator and the EMA generator are returned in a dict.
- Parameters
inputs (ForwardInputs) – Dict containing the necessary information (e.g. noise, num_batches, mode) to generate images.
data_samples (Optional[list]) – Data samples collated by data_preprocessor. Defaults to None.
mode (Optional[str]) – Not used in BaseGAN. Defaults to None.
- Returns
A list of DataSample containing the generated results.
- Return type
SampleList
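The dispatch on the forward model can be sketched with two generator callables. This is a simplified illustration: lists of floats stand in for image tensors, forward_sketch is a hypothetical name, and the real method additionally packs results into DataSample objects. Feeding the same noise to both generators in the 'ema/orig' case is a design choice of this sketch that makes the two outputs directly comparable.

```python
from typing import Callable, Dict, List, Union

Images = List[float]  # stand-in for a batch of generated images


def forward_sketch(noise: Images, sample_model: str,
                   gen: Callable[[Images], Images],
                   gen_ema: Callable[[Images], Images]
                   ) -> Union[Images, Dict[str, Images]]:
    """Hypothetical dispatch on 'orig', 'ema' or 'ema/orig'."""
    if sample_model == 'orig':
        return gen(noise)
    if sample_model == 'ema':
        return gen_ema(noise)
    # 'ema/orig': run both generators on the same noise, return a dict.
    return {'orig': gen(noise), 'ema': gen_ema(noise)}


out = forward_sketch([1.0, 2.0], 'ema/orig',
                     gen=lambda z: [2 * v for v in z],
                     gen_ema=lambda z: [3 * v for v in z])
print(out)  # {'orig': [2.0, 4.0], 'ema': [3.0, 6.0]}
```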
- val_step(data: dict) → mmagic.utils.typing.SampleList
Get the generated images for the given data.
Calls self.data_preprocessor(data) and self(inputs, data_samples, mode=None) in order, and returns the generated results, which are passed to the evaluator.
- Parameters
data (dict) – Data sampled from the metric-specific sampler. More details in Metrics and Evaluator.
- Returns
Generated image or image dict.
- Return type
SampleList
- test_step(data: dict) → mmagic.utils.typing.SampleList
Get the generated images for the given data. Same as val_step().
- Parameters
data (dict) – Data sampled from the metric-specific sampler. More details in Metrics and Evaluator.
- Returns
Generated image or image dict.
- Return type
List[DataSample]
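The call order shared by val_step() and test_step() can be sketched with a toy class. ToyGAN is hypothetical, and the assumption that the preprocessor returns a dict with 'inputs' and 'data_samples' keys follows the usual MMEngine convention rather than anything stated above; its __call__ merely stands in for forward().

```python
from typing import Any, Callable, Dict, List


class ToyGAN:
    """Hypothetical stand-in that reproduces the documented val_step order."""

    def __init__(self, data_preprocessor: Callable[[dict], Dict[str, Any]]):
        self.data_preprocessor = data_preprocessor

    def __call__(self, inputs: Any, data_samples: Any,
                 mode=None) -> List[str]:
        # Stand-in for forward(): return one "result" per requested batch.
        return [f'sample_{i}' for i in range(inputs['num_batches'])]

    def val_step(self, data: dict) -> List[str]:
        # 1) preprocess, 2) forward with mode=None; results go to the evaluator.
        data = self.data_preprocessor(data)
        return self(data['inputs'], data.get('data_samples'), mode=None)

    def test_step(self, data: dict) -> List[str]:
        # Documented as identical to val_step.
        return self.val_step(data)


model = ToyGAN(data_preprocessor=lambda d: d)
print(model.val_step(dict(inputs=dict(num_batches=2))))
# ['sample_0', 'sample_1']
```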
- train_step(data: dict, optim_wrapper: mmengine.optim.OptimWrapperDict) → Dict[str, torch.Tensor]
Train the GAN model. In GAN training, the generator and discriminator are updated alternately. In MMagic’s design, self.train_step is called with a batch of data. Therefore the discriminator, whose update relies on real data, is always updated, and the current iteration number then determines whether the generator is updated as well. More details about whether to update the generator can be found in should_gen_update().
- Parameters
data (dict) – Data sampled from the dataloader.
optim_wrapper (OptimWrapperDict) – OptimWrapperDict instance containing the OptimWrapper of the generator and of the discriminator.
- Returns
A dict of tensors for logging.
- Return type
Dict[str, torch.Tensor]
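A minimal sketch of the alternating schedule, assuming gradient accumulation is disabled and that the generator is updated generator_steps times once every discriminator_steps iterations (the discriminator itself is updated on every call). update_schedule is a hypothetical helper for illustration only.

```python
from typing import List, Tuple


def update_schedule(num_iters: int, discriminator_steps: int = 1,
                    generator_steps: int = 1) -> List[Tuple[int, int]]:
    """Return (iteration, generator updates) pairs (hypothetical helper)."""
    schedule = []
    for curr_iter in range(num_iters):
        # The discriminator is updated on every iteration with real data.
        # curr_iter + 1 because the loop counter starts at 0.
        gen_updates = (generator_steps
                       if (curr_iter + 1) % discriminator_steps == 0 else 0)
        schedule.append((curr_iter, gen_updates))
    return schedule


print(update_schedule(4, discriminator_steps=2, generator_steps=3))
# [(0, 0), (1, 3), (2, 0), (3, 3)]
```

With the defaults (both steps equal to 1), every iteration performs one discriminator update followed by one generator update.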
- train_generator(inputs: dict, data_samples: List[mmagic.structures.DataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) → Dict[str, torch.Tensor]
Training function for the generator. All GANs should implement this function themselves.
- Parameters
inputs (dict) – Inputs from the dataloader.
data_samples (List[DataSample]) – Data samples from the dataloader.
optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.
- Returns
A dict of tensors for logging.
- Return type
Dict[str, Tensor]
- train_discriminator(inputs: dict, data_samples: List[mmagic.structures.DataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) → Dict[str, torch.Tensor]
Training function for the discriminator. All GANs should implement this function themselves.
- Parameters
inputs (dict) – Inputs from the dataloader.
data_samples (List[DataSample]) – Data samples from the dataloader.
optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.
- Returns
A dict of tensors for logging.
- Return type
Dict[str, Tensor]
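The contract for the two training hooks can be sketched in pure Python. GANHooksSketch, ToyGANHooks and RecordingWrapper are hypothetical names; the wrapper only mimics the update_params(loss) method of MMEngine's OptimWrapper, and constant floats stand in for the losses a real subclass would compute from real and generated images.

```python
from typing import Any, Dict, List


class RecordingWrapper:
    """Hypothetical stand-in for an OptimWrapper that only records losses."""

    def __init__(self) -> None:
        self.losses: List[float] = []

    def update_params(self, loss: float) -> None:
        # MMEngine's OptimWrapper.update_params(loss) runs the backward pass
        # and, when due, the optimizer step; this stand-in records the loss.
        self.losses.append(loss)


class GANHooksSketch:
    """Hypothetical base exposing the two hooks every GAN must implement."""

    def train_generator(self, inputs: dict, data_samples: List[Any],
                        optimizer_wrapper: Any) -> Dict[str, float]:
        raise NotImplementedError('Subclasses must implement train_generator.')

    def train_discriminator(self, inputs: dict, data_samples: List[Any],
                            optimizer_wrapper: Any) -> Dict[str, float]:
        raise NotImplementedError(
            'Subclasses must implement train_discriminator.')


class ToyGANHooks(GANHooksSketch):
    def train_discriminator(self, inputs, data_samples, optimizer_wrapper):
        loss = 0.5  # stand-in for the computed discriminator loss
        optimizer_wrapper.update_params(loss)
        return {'loss_disc': loss}

    def train_generator(self, inputs, data_samples, optimizer_wrapper):
        loss = 1.5  # stand-in for the computed generator loss
        optimizer_wrapper.update_params(loss)
        return {'loss_gen': loss}


wrapper = RecordingWrapper()
print(ToyGANHooks().train_discriminator({}, [], optimizer_wrapper=wrapper))
# {'loss_disc': 0.5}
print(wrapper.losses)  # [0.5]
```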