mmagic.models.editors.singan.singan
Module Contents
Attributes
- class mmagic.models.editors.singan.singan.SinGAN(generator: ModelType, discriminator: Optional[ModelType] = None, data_preprocessor: Optional[Union[dict, mmengine.Config]] = None, generator_steps: int = 1, discriminator_steps: int = 1, num_scales: Optional[int] = None, iters_per_scale: int = 2000, noise_weight_init: int = 0.1, lr_scheduler_args: Optional[dict] = None, test_pkl_data: Optional[str] = None, ema_confg: Optional[dict] = None)
Bases: mmagic.models.base_models.BaseGAN
SinGAN.
This model implements the single-image generative adversarial model proposed in: SinGAN: Learning a Generative Model from a Single Natural Image, ICCV’19.
Notes for training:
This model should be trained with our dataset SinGANDataset. In training, the total_iters argument is related to the number of scales in the image pyramid and to iters_per_scale in the train_cfg. You should set it carefully in the training config file.
Notes for model architectures:
The generator and discriminator need num_scales at initialization. However, this argument is generated by the create_real_pyramid function in singan_dataset.py. The last element in the returned list (stop_scale) is the value for num_scales. Note that this scale is counted from zero. Please see our tutorial for SinGAN for more details, or our standard config for reference.
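As a rough illustration of how these settings interact, the sketch below derives num_scales and a matching total_iters from an image pyramid. It assumes that each of the num_scales + 1 scales is trained for iters_per_scale iterations, so total_iters is their product. The helper name singan_schedule and the pyramid sizes are hypothetical and exist only for this example, so check the result against your own config.

```python
def singan_schedule(pyramid_sizes, iters_per_scale=2000):
    """Return (num_scales, total_iters) for a given image pyramid.

    `pyramid_sizes` is a hypothetical list of (height, width) pairs, one per
    scale, ordered from coarsest to finest.
    """
    # num_scales is counted from zero, so it is the index of the last scale.
    num_scales = len(pyramid_sizes) - 1
    # Assumption: every scale is trained for `iters_per_scale` iterations.
    total_iters = (num_scales + 1) * iters_per_scale
    return num_scales, total_iters


# Hypothetical 9-level pyramid (the sizes are illustrative only).
sizes = [(25, 34), (31, 41), (38, 50), (47, 62), (57, 76),
         (70, 93), (86, 114), (105, 140), (129, 172)]
num_scales, total_iters = singan_schedule(sizes)
print(num_scales, total_iters)  # 8 18000
```

With nine pyramid levels, num_scales is 8 because counting starts at zero, and the schedule needs nine blocks of 2000 iterations.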
- Parameters
generator (ModelType) – The config or model of the generator.
discriminator (Optional[ModelType]) – The config or model of the discriminator. Defaults to None.
data_preprocessor (Optional[Union[dict, Config]]) – The pre-process config or DataPreprocessor.
generator_steps (int) – The number of times the generator is completely updated before the discriminator is updated. Defaults to 1.
discriminator_steps (int) – The number of times the discriminator is completely updated before the generator is updated. Defaults to 1.
num_scales (Optional[int]) – The number of scales/stages in the generator/discriminator. Note that this number is counted from zero, which is the same as in the original paper. Defaults to None.
iters_per_scale (int) – The training iteration for each resolution scale. Defaults to 2000.
noise_weight_init (float) – The initial weight of the fixed noise. Defaults to 0.1.
lr_scheduler_args (Optional[dict]) – Arguments for learning schedulers. Note that in SinGAN, we use MultiStepLR, which is the same as the original paper. If not passed, no learning schedule will be used. Defaults to None.
test_pkl_data (Optional[str]) – The path of the pickle file that contains the fixed noise and noise weight. This is required for testing. Defaults to None.
ema_confg (Optional[dict]) – The config for the generator’s exponential moving average setting. The name is spelled ema_confg in the signature. Defaults to None.
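A minimal config sketch using the arguments listed above might look like the following. The SinGAN-level values are the documented defaults. The generator and discriminator type names, their channel settings, and the num_scales value are assumptions made for illustration and should be checked against the installed version and the standard config.

```python
# Hypothetical value; in practice it comes from the image pyramid
# (it is counted from zero).
num_scales = 8

model = dict(
    type='SinGAN',
    generator=dict(
        type='SinGANMultiScaleGenerator',  # assumed registry name
        in_channels=3,
        out_channels=3,
        num_scales=num_scales),
    discriminator=dict(
        type='SinGANMultiScaleDiscriminator',  # assumed registry name
        in_channels=3,
        num_scales=num_scales),
    generator_steps=1,
    discriminator_steps=1,
    num_scales=num_scales,
    iters_per_scale=2000,
    noise_weight_init=0.1,
    lr_scheduler_args=None,  # no learning-rate schedule is used
    test_pkl_data=None,  # must point to a pickle file for testing
)
```

Keeping num_scales in a single variable makes it harder for the generator, discriminator, and model to disagree about the number of scales.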
- _from_numpy(data: Tuple[list, numpy.ndarray]) -> Tuple[torch.Tensor, List[torch.Tensor]]
Convert an input numpy array or list of numpy arrays to a Tensor or list of Tensors.
- Parameters
data (Tuple[list, np.ndarray]) – Input data to convert.
- Returns
Converted Tensor or list of Tensors.
- Return type
Tuple[Tensor, List[Tensor]]
- get_module(model: torch.nn.Module, module_name: str) -> torch.nn.Module
Get an inner module from model.
Since some models may be wrapped with DDP, we have to check whether the module can be accessed directly.
- Parameters
model (nn.Module) – The model, which may or may not be wrapped with DDP.
module_name (str) – The name of specific module.
- Returns
The requested submodule.
- Return type
nn.Module
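The lookup described above can be sketched as follows. This is not the library's code: it assumes that a DDP-style wrapper exposes the wrapped model as a `.module` attribute, and it uses two hypothetical stand-in classes (Inner, FakeDDP) instead of real torch modules so that the example runs without extra dependencies.

```python
class Inner:
    """Stand-in for a model that owns a sub-module called `blocks`."""

    def __init__(self):
        self.blocks = ['block0', 'block1']


class FakeDDP:
    """Stand-in for a DDP-style wrapper exposing the model as `.module`."""

    def __init__(self, module):
        self.module = module


def get_module(model, module_name):
    # Unwrap first if the model looks like a DDP wrapper.
    unwrapped = model.module if hasattr(model, 'module') else model
    return getattr(unwrapped, module_name)


plain = Inner()
wrapped = FakeDDP(plain)
# Both calls resolve to the same underlying object.
same = get_module(plain, 'blocks') is get_module(wrapped, 'blocks')
print(same)  # True
```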
- forward(inputs: mmagic.utils.ForwardInputs, data_samples: Optional[list] = None, mode=None) -> List[mmagic.structures.DataSample]
Forward function for SinGAN. For SinGAN, inputs should be a dict containing ‘num_batches’, ‘mode’ and other input arguments for the generator.
- Parameters
inputs (dict) – Dict containing the necessary information (e.g., noise, num_batches, mode) to generate image.
data_samples (Optional[list]) – Data samples collated by data_preprocessor. Defaults to None.
mode (Optional[str]) – mode is not used in BaseConditionalGAN. Defaults to None.
- gen_loss(disc_pred_fake: torch.Tensor, recon_imgs: torch.Tensor) -> Tuple[torch.Tensor, dict]
Generator loss for SinGAN. SinGAN uses WGAN’s loss and MSE loss to train the generator.
- Parameters
disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.
recon_imgs (Tensor) – Reconstructed images.
- Returns
Loss value and a dict of log variables.
- Return type
Tuple[Tensor, dict]
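To make the combination of the two loss terms concrete, here is a small sketch on plain Python lists rather than tensors. It is an assumption-laden illustration, not the library implementation: the function name gen_loss_sketch, the explicit real_imgs argument, and the reconstruction weight of 10 are all introduced here for the example.

```python
def gen_loss_sketch(disc_pred_fake, recon_imgs, real_imgs, recon_weight=10.0):
    """Return (total_loss, log_vars) for flat lists of floats."""
    # WGAN generator term: raise the critic's score on fake samples.
    loss_gen = -sum(disc_pred_fake) / len(disc_pred_fake)
    # MSE between the reconstruction and the real image at this scale.
    mse = sum((r - t) ** 2
              for r, t in zip(recon_imgs, real_imgs)) / len(real_imgs)
    # Assumed weighting of the reconstruction term.
    loss_mse = recon_weight * mse
    log_vars = {'loss_gen': loss_gen, 'loss_mse': loss_mse}
    return loss_gen + loss_mse, log_vars


loss, log_vars = gen_loss_sketch(
    disc_pred_fake=[0.5, 1.5],
    recon_imgs=[0.0, 1.0],
    real_imgs=[0.0, 0.0])
print(loss, log_vars)  # 4.0 {'loss_gen': -1.0, 'loss_mse': 5.0}
```

The returned pair mirrors the documented return type: a scalar loss plus a dict of the individual terms for logging.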
- disc_loss(disc_pred_fake: torch.Tensor, disc_pred_real: torch.Tensor, fake_data: torch.Tensor, real_data: torch.Tensor) -> Tuple[torch.Tensor, dict]
Discriminator loss for SinGAN. SinGAN uses WGAN’s loss with a gradient penalty to train the discriminator.
- Parameters
disc_pred_fake (Tensor) – Discriminator’s prediction of the fake images.
disc_pred_real (Tensor) – Discriminator’s prediction of the real images.
fake_data (Tensor) – Generated images, used to calculate gradient penalty.
real_data (Tensor) – Real images, used to calculate gradient penalty.
- Returns
Loss value and a dict of log variables.
- Return type
Tuple[Tensor, dict]
- train_generator(inputs: dict, data_samples: List[mmagic.structures.DataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) -> Dict[str, torch.Tensor]
Train generator.
- Parameters
inputs (dict) – Inputs from dataloader.
data_samples (List[DataSample]) – Data samples from dataloader. Not used in the generator’s training.
optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.
- Returns
A dict of tensors for logging.
- Return type
Dict[str, Tensor]
- train_discriminator(inputs: dict, data_samples: List[mmagic.structures.DataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) -> Dict[str, torch.Tensor]
Train discriminator.
- Parameters
inputs (dict) – Inputs from dataloader.
data_samples (List[DataSample]) – Data samples from dataloader.
optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.
- Returns
A dict of tensors for logging.
- Return type
Dict[str, Tensor]
- train_gan(inputs_dict: dict, data_sample: List[mmagic.structures.DataSample], optim_wrapper: mmengine.optim.OptimWrapperDict) -> Dict[str, torch.Tensor]
Train GAN model. In the training of GAN models, the generator and discriminator are updated alternately. In MMagic’s design, self.train_step is called with data input. Therefore we always update the discriminator, whose update relies on real data, and then determine whether the generator needs to be updated based on the current number of iterations. More details about whether to update the generator can be found in should_gen_update().
- Parameters
inputs_dict (dict) – Data sampled from dataloader.
data_sample (List[DataSample]) – List of data samples containing GT and meta information.
optim_wrapper (OptimWrapperDict) – OptimWrapperDict instance containing the OptimWrappers of the generator and discriminator.
- Returns
A dict of tensors for logging.
- Return type
Dict[str, torch.Tensor]
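The alternation described for train_gan can be sketched as a simple update plan. The rule used here (update the generator generator_steps times whenever the discriminator has completed discriminator_steps updates) is an assumption for illustration; gradient accumulation is ignored, update_plan is a hypothetical helper, and the actual decision is made by should_gen_update().

```python
def update_plan(num_iters, discriminator_steps=1, generator_steps=1):
    """List which networks would be updated at each iteration."""
    plan = []
    for curr_iter in range(num_iters):
        # The discriminator is updated on every call.
        updates = ['disc']
        # Assumed rule: the generator follows once the discriminator has
        # completed `discriminator_steps` updates.
        if (curr_iter + 1) % discriminator_steps == 0:
            updates += ['gen'] * generator_steps
        plan.append(updates)
    return plan


print(update_plan(4, discriminator_steps=2))
# [['disc'], ['disc', 'gen'], ['disc'], ['disc', 'gen']]
```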
- train_step(data: dict, optim_wrapper: mmengine.optim.OptimWrapperDict) -> Dict[str, torch.Tensor]
Train step for SinGAN model. SinGAN is trained with multi-resolution images, and each resolution is trained for self.iters_per_scale iterations.
We initialize the weights and learning rate scheduler of the corresponding module at the start of each resolution’s training. At the end of each resolution’s training, we update the noise weight of the current resolution using the MSE loss between the reconstructed image and the real image.
- Parameters
data (dict) – Data sampled from dataloader.
optim_wrapper (OptimWrapperDict) – OptimWrapperDict instance containing the OptimWrappers of the generator and discriminator.
- Returns
A dict of tensors for logging.
- Return type
Dict[str, torch.Tensor]
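The per-resolution schedule described for train_step can be illustrated with a small helper that maps a global iteration to its scale. This is a sketch under the assumption of a zero-based global iteration counter; stage_at is a hypothetical name, and the bookkeeping inside the library may differ.

```python
def stage_at(curr_iter, iters_per_scale=2000):
    """Return (stage, is_first_iter, is_last_iter) for a global iteration."""
    stage = curr_iter // iters_per_scale
    offset = curr_iter % iters_per_scale
    # First iteration of a stage: weights and LR scheduler are initialized.
    is_first = offset == 0
    # Last iteration of a stage: the noise weight of this scale is updated.
    is_last = offset == iters_per_scale - 1
    return stage, is_first, is_last


print(stage_at(0))     # (0, True, False)
print(stage_at(1999))  # (0, False, True)
print(stage_at(2000))  # (1, True, False)
```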
- test_step(data: dict) -> mmagic.utils.SampleList
Gets the generated images for the given data during testing. Before generating images, we call self.load_test_pkl to load the fixed noise and current stage of the model from the pickle file.
- Parameters
data (dict) – Data sampled from a metric-specific sampler. More details in Metrics and Evaluator.
- Returns
A list of DataSample containing the generated results.
- Return type
SampleList