mmagic.models.base_models.base_conditional_gan
Module Contents
Classes
BaseConditionalGAN – Base class for Conditional GAN models.
Attributes
- class mmagic.models.base_models.base_conditional_gan.BaseConditionalGAN(generator: ModelType, discriminator: Optional[ModelType] = None, data_preprocessor: Optional[Union[dict, mmengine.Config]] = None, generator_steps: int = 1, discriminator_steps: int = 1, noise_size: Optional[int] = None, num_classes: Optional[int] = None, ema_config: Optional[Dict] = None, loss_config: Optional[Dict] = None)
Bases: mmagic.models.base_models.base_gan.BaseGAN
Base class for Conditional GAN models.
- Parameters
generator (ModelType) – The config or model of the generator.
discriminator (Optional[ModelType]) – The config or model of the discriminator. Defaults to None.
data_preprocessor (Optional[Union[dict, Config]]) – The pre-process config or DataPreprocessor. Defaults to None.
generator_steps (int) – The number of times the generator is completely updated before the discriminator is updated. Defaults to 1.
discriminator_steps (int) – The number of times the discriminator is completely updated before the generator is updated. Defaults to 1.
noise_size (Optional[int]) – Size of the input noise vector. Defaults to None.
num_classes (Optional[int]) – The number of classes you would like to generate. Defaults to None.
ema_config (Optional[Dict]) – The config for the generator’s exponential moving average (EMA) setting. Defaults to None.
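The interplay of generator_steps and discriminator_steps can be illustrated with a simplified, framework-free sketch. It assumes the discriminator is updated once per training iteration and that, after every discriminator_steps discriminator updates, the generator is updated generator_steps times in a row. update_schedule is a hypothetical helper, not part of mmagic, and the real training loop may also account for details such as gradient accumulation.

```python
from typing import List, Tuple


def update_schedule(num_iters: int, generator_steps: int = 1,
                    discriminator_steps: int = 1) -> List[Tuple[int, str]]:
    """Return the (iteration, network) update order of a GAN training run.

    Simplified sketch: 'D' marks one discriminator update and 'G' marks one
    generator update. Gradient accumulation is ignored.
    """
    if generator_steps < 1 or discriminator_steps < 1:
        raise ValueError('step counts must be positive')
    order: List[Tuple[int, str]] = []
    for it in range(num_iters):
        # The discriminator is updated in every iteration.
        order.append((it, 'D'))
        # Once the discriminator has finished `discriminator_steps` full
        # updates, the generator is updated `generator_steps` times.
        if (it + 1) % discriminator_steps == 0:
            order.extend((it, 'G') for _ in range(generator_steps))
    return order


print(update_schedule(4, generator_steps=1, discriminator_steps=2))
# [(0, 'D'), (1, 'D'), (1, 'G'), (2, 'D'), (3, 'D'), (3, 'G')]
```

With the defaults (both set to 1) the two networks simply alternate, one update each per iteration.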
- label_fn(label: mmagic.utils.typing.LabelVar = None, num_batches: int = 1) → torch.Tensor
Sampling function for labels. There are three scenarios:
If label is a callable, num_batches labels are sampled with it.
If label is None, num_batches labels are sampled uniformly from [0, self.num_classes-1].
If label is a torch.Tensor, its values are checked to lie in [0, self.num_classes-1]. If all values are in the valid range, label is returned directly.
- Parameters
label (Union[Tensor, Callable, List[int], None]) – A batch of labels given directly as a torch.Tensor, or a callable that samples a batch of labels. None means the default label sampler is used. Defaults to None.
num_batches (int, optional) – The batch size of the sampled labels. Ignored if label is a Tensor. Defaults to 1.
- Returns
Sampled label tensor.
- Return type
Tensor
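The three scenarios above can be sketched without PyTorch. This is a minimal illustration under stated assumptions: plain Python lists stand in for torch.Tensor, num_classes is passed explicitly instead of being read from self.num_classes, the callable is assumed to take the batch size as its only argument, raising ValueError for out-of-range labels is an assumed behavior, and sample_labels is a hypothetical helper rather than part of mmagic.

```python
import random
from typing import Callable, List, Optional, Sequence, Union

# A directly given batch (list/tuple stands in for a Tensor), a sampler
# callable, or None for the default uniform sampler.
LabelLike = Union[Sequence[int], Callable[[int], Sequence[int]], None]


def sample_labels(label: LabelLike = None, num_batches: int = 1, *,
                  num_classes: int,
                  rng: Optional[random.Random] = None) -> List[int]:
    """Pure-Python sketch of the three label sampling scenarios."""
    if callable(label):
        # Scenario 1: delegate sampling to the user-supplied callable.
        return list(label(num_batches))
    if label is None:
        # Scenario 2: draw num_batches labels uniformly from
        # [0, num_classes - 1].
        rng = rng or random.Random()
        return [rng.randrange(num_classes) for _ in range(num_batches)]
    # Scenario 3: labels were given directly, so num_batches is ignored and
    # every value must lie in [0, num_classes - 1].
    labels = list(label)
    if any(not 0 <= y < num_classes for y in labels):
        raise ValueError(f'labels must be in [0, {num_classes - 1}]')
    return labels


print(sample_labels(lambda n: [0] * n, 3, num_classes=10))  # [0, 0, 0]
print(sample_labels([2, 1], 99, num_classes=3))  # [2, 1]
```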
- data_sample_to_label(data_sample: mmagic.structures.DataSample) → Optional[torch.Tensor]
Get labels from the input data_sample and pack them into a torch.Tensor. If no label is found in data_sample, None is returned.
- Parameters
data_sample (DataSample) – Input data samples.
- Returns
Packed label tensor.
- Return type
Optional[torch.Tensor]
- static _get_valid_num_classes(num_classes: Optional[int], generator: ModelType, discriminator: Optional[ModelType]) → int
Try to get num_classes from the input argument, the generator, and the discriminator, and check that these values are consistent. If no conflict is found, return num_classes.
- Parameters
num_classes (Optional[int]) – num_classes passed to BaseConditionalGAN’s initialization function.
generator (ModelType) – The config or model of the generator.
discriminator (Optional[ModelType]) – The config or model of the discriminator.
- Returns
The number of classes to be generated.
- Return type
int
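The consistency check described above can be sketched as follows. Assumptions, not taken from mmagic: a config is a dict that may carry a 'num_classes' key, a built model exposes a num_classes attribute, and both a missing value and a conflict raise ValueError. get_valid_num_classes and _num_classes_of are hypothetical helpers; the real method may differ in error type and lookup details.

```python
from typing import Any, Dict, Optional


def _num_classes_of(model: Any) -> Optional[int]:
    """Read num_classes from a dict config or from a model attribute."""
    if model is None:
        return None
    if isinstance(model, dict):
        return model.get('num_classes')
    return getattr(model, 'num_classes', None)


def get_valid_num_classes(num_classes: Optional[int], generator: Any,
                          discriminator: Any = None) -> int:
    """Collect num_classes from every source and require them to agree."""
    candidates = {
        'argument': num_classes,
        'generator': _num_classes_of(generator),
        'discriminator': _num_classes_of(discriminator),
    }
    # Keep only the sources that actually define a value.
    found: Dict[str, int] = {
        src: n for src, n in candidates.items() if n is not None}
    if not found:
        raise ValueError('num_classes is not set in the argument, '
                         'generator or discriminator')
    if len(set(found.values())) > 1:
        raise ValueError(f'conflicting num_classes values: {found}')
    return next(iter(found.values()))


print(get_valid_num_classes(None, {'num_classes': 10}))  # 10
```

Accepting both dicts and built modules mirrors the documented ModelType ("the config or model") for generator and discriminator.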
- forward(inputs: mmagic.utils.typing.ForwardInputs, data_samples: Optional[list] = None, mode: Optional[str] = None) → List[mmagic.structures.DataSample]
Sample images with the given inputs. If the forward mode is ‘ema’ or ‘orig’, images generated by the corresponding generator are returned. If the forward mode is ‘ema/orig’, images generated by both the original generator and the EMA generator are returned in a dict.
- Parameters
inputs (ForwardInputs) – Dict containing the necessary information (e.g. noise, num_batches, mode) to generate images.
data_samples (Optional[list]) – Data samples collated by data_preprocessor. Defaults to None.
mode (Optional[str]) – Not used in BaseConditionalGAN. Defaults to None.
- Returns
Generated images or image dict.
- Return type
List[DataSample]
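The mode handling described for forward can be pictured as a simple dispatch. In this hedged sketch the two generators are arbitrary callables, sample_by_mode is a hypothetical helper, and the 'ema' and 'orig' keys of the returned dict are an assumption about how the two outputs are labeled.

```python
from typing import Any, Callable, Dict, Union

Generator = Callable[[Any], Any]


def sample_by_mode(mode: str, orig_gen: Generator, ema_gen: Generator,
                   noise: Any) -> Union[Any, Dict[str, Any]]:
    """Return one generator's output, or both outputs in a dict."""
    if mode == 'orig':
        return orig_gen(noise)
    if mode == 'ema':
        return ema_gen(noise)
    if mode == 'ema/orig':
        # In this sketch both generators receive the same noise so their
        # outputs can be compared side by side.
        return {'ema': ema_gen(noise), 'orig': orig_gen(noise)}
    raise ValueError(
        f"unknown mode {mode!r}; expected 'ema', 'orig' or 'ema/orig'")


# Toy "generators" standing in for the original and EMA networks.
print(sample_by_mode('ema/orig', lambda z: z + 1, lambda z: z * 10, 2))
# {'ema': 20, 'orig': 3}
```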
- train_generator(inputs: dict, data_samples: List[mmagic.structures.DataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) → Dict[str, torch.Tensor]
Training function for the generator. All GANs should implement this function by themselves.
- Parameters
inputs (dict) – Inputs from dataloader.
data_samples (List[DataSample]) – Data samples from dataloader.
optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.
- Returns
A dict of tensors for logging.
- Return type
Dict[str, Tensor]
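Since subclasses implement train_generator themselves, the following framework-free sketch only illustrates the documented contract: compute a loss, hand it to the optimizer wrapper, and return a dict of values for logging. The non-saturating logistic loss is one common choice picked here for illustration, not necessarily what any mmagic model uses. RecordingOptimWrapper is a hypothetical stub mimicking the update_params(loss) method that mmengine's OptimWrapper provides, and the 'loss_gen' key is likewise an assumption.

```python
import math
from typing import Dict, List


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


class RecordingOptimWrapper:
    """Hypothetical stand-in for an optimizer wrapper that records losses."""

    def __init__(self) -> None:
        self.losses: List[float] = []

    def update_params(self, loss: float) -> None:
        # A real wrapper would backpropagate, step the optimizer and zero
        # the gradients here; this stub only records the loss value.
        self.losses.append(loss)


def train_generator_step(fake_logits: List[float],
                         optimizer_wrapper: RecordingOptimWrapper
                         ) -> Dict[str, float]:
    """Non-saturating generator loss: mean of softplus(-D(G(z, y), y)).

    fake_logits are discriminator scores for images generated from noise z
    and conditioned on labels y.
    """
    loss = sum(softplus(-s) for s in fake_logits) / len(fake_logits)
    optimizer_wrapper.update_params(loss)
    # Scalars for logging, in the spirit of the documented Dict[str, Tensor].
    return {'loss_gen': loss}


wrapper = RecordingOptimWrapper()
print(train_generator_step([0.0, 0.0], wrapper))  # loss_gen is about 0.693
```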
- train_discriminator(inputs: dict, data_samples: List[mmagic.structures.DataSample], optimizer_wrapper: mmengine.optim.OptimWrapper) → Dict[str, torch.Tensor]
Training function for the discriminator. All GANs should implement this function by themselves.
- Parameters
inputs (dict) – Inputs from dataloader.
data_samples (List[DataSample]) – Data samples from dataloader.
optimizer_wrapper (OptimWrapper) – OptimWrapper instance used to update model parameters.
- Returns
A dict of tensors for logging.
- Return type
Dict[str, Tensor]
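A matching sketch for train_discriminator, again framework-free and illustrative only. It uses the hinge loss, one common choice for conditional GANs picked here for illustration rather than taken from mmagic; RecordingOptimWrapper is a hypothetical stub mimicking the update_params(loss) method of mmengine's OptimWrapper, and the logging keys are assumptions.

```python
from typing import Dict, List


class RecordingOptimWrapper:
    """Hypothetical stand-in for an optimizer wrapper that records losses."""

    def __init__(self) -> None:
        self.losses: List[float] = []

    def update_params(self, loss: float) -> None:
        # Stub: a real wrapper would run backward() and an optimizer step.
        self.losses.append(loss)


def relu(x: float) -> float:
    return max(x, 0.0)


def train_discriminator_step(real_logits: List[float],
                             fake_logits: List[float],
                             optimizer_wrapper: RecordingOptimWrapper
                             ) -> Dict[str, float]:
    """Hinge loss on label-conditioned real and fake scores.

    loss = mean(relu(1 - D(x, y))) + mean(relu(1 + D(G(z, y), y)))
    """
    loss_real = sum(relu(1.0 - s) for s in real_logits) / len(real_logits)
    loss_fake = sum(relu(1.0 + s) for s in fake_logits) / len(fake_logits)
    loss = loss_real + loss_fake
    optimizer_wrapper.update_params(loss)
    # Per-term values are returned separately for logging.
    return {'loss_disc_real': loss_real, 'loss_disc_fake': loss_fake,
            'loss_disc': loss}


print(train_discriminator_step([2.0, 0.0], [-2.0, 0.0],
                               RecordingOptimWrapper()))
# {'loss_disc_real': 0.5, 'loss_disc_fake': 0.5, 'loss_disc': 1.0}
```

Scores already beyond the margin (real above 1, fake below -1) contribute zero loss, which is the defining property of the hinge formulation.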