mmagic.models.editors.cyclegan
Package Contents
Classes
CycleGAN | CycleGAN model for unpaired image-to-image translation.
ResnetGenerator | Construct a Resnet-based generator that consists of residual blocks between a few downsampling/upsampling operations.
- class mmagic.models.editors.cyclegan.CycleGAN(*args, buffer_size=50, loss_config=dict(cycle_loss_weight=10.0, id_loss_weight=0.5), **kwargs)[source]
Bases:
mmagic.models.base_models.BaseTranslationModel
CycleGAN model for unpaired image-to-image translation.
Ref: Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks
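The buffer_size=50 argument most likely corresponds to the history buffer of generated images used in the original CycleGAN training recipe, where discriminators see a mix of fresh and previously generated images. The sketch below is a minimal pure-Python illustration of that idea under that assumption. ImageHistoryBuffer and its query method are hypothetical names, not part of the mmagic API, and strings stand in for image tensors.

```python
import random


class ImageHistoryBuffer:
    """Hypothetical history buffer of generated images (not mmagic API)."""

    def __init__(self, buffer_size=50, rng=None):
        self.buffer_size = buffer_size
        self.images = []
        # A seeded generator can be passed in for reproducible behaviour.
        self.rng = rng if rng is not None else random.Random()

    def query(self, images):
        """Return one image per input, possibly swapped with a stored one."""
        if self.buffer_size == 0:
            # A size of 0 disables the buffer entirely.
            return list(images)
        returned = []
        for image in images:
            if len(self.images) < self.buffer_size:
                # Buffer not full yet: store the new image and pass it through.
                self.images.append(image)
                returned.append(image)
            elif self.rng.random() > 0.5:
                # About half of the time, return an older image
                # and keep the new one in its place.
                index = self.rng.randrange(self.buffer_size)
                returned.append(self.images[index])
                self.images[index] = image
            else:
                # Otherwise pass the new image through unchanged.
                returned.append(image)
        return returned


buffer = ImageHistoryBuffer(buffer_size=2, rng=random.Random(0))
first = buffer.query(["fake_0", "fake_1"])  # buffer fills up, inputs pass through
later = buffer.query(["fake_2", "fake_3"])  # may now return older images
```

Feeding the discriminator from such a buffer, rather than only from the latest generator output, is intended to reduce oscillation during adversarial training.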
- forward_test(img, target_domain, **kwargs)
Forward function for testing.
- Parameters
img (Tensor) – Input image tensor.
target_domain (str) – Target domain of output image.
kwargs (dict) – Other arguments.
- Returns
Forward results.
- Return type
dict
- _get_disc_loss(outputs)
Backward function for the discriminators.
- Parameters
outputs (dict) – Dict of forward results.
- Returns
Discriminators’ loss and loss dict.
- Return type
dict
- _get_gen_loss(outputs)
Backward function for the generators.
- Parameters
outputs (dict) – Dict of forward results.
- Returns
Generators’ loss and loss dict.
- Return type
dict
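As a rough illustration of how loss_config could enter the generator objective, the sketch below combines an adversarial term with weighted cycle-consistency and identity terms. The function name, the plain-float inputs, the keys of the returned dict, and the choice to scale the identity term by both weights (as the reference CycleGAN implementation does) are all assumptions for illustration. The actual computation inside _get_gen_loss may differ.

```python
def generator_objective(adv_loss, cycle_loss, identity_loss,
                        cycle_loss_weight=10.0, id_loss_weight=0.5):
    """Hypothetical helper: combine generator loss terms into one scalar.

    Assumption: the identity term is scaled by both weights, as in the
    reference CycleGAN implementation. mmagic may apply them differently.
    """
    weighted_cycle = cycle_loss_weight * cycle_loss
    weighted_identity = cycle_loss_weight * id_loss_weight * identity_loss
    total = adv_loss + weighted_cycle + weighted_identity
    # Return the total together with a dict of the individual terms,
    # mirroring the "loss and loss dict" shape described above.
    loss_dict = dict(
        loss_gan=adv_loss,
        loss_cycle=weighted_cycle,
        loss_id=weighted_identity,
    )
    return total, loss_dict


total, loss_dict = generator_objective(
    adv_loss=0.5, cycle_loss=0.25, identity_loss=0.125)
```

With the default weights, the cycle-consistency term dominates the objective, which is what pushes a round-trip translation back toward the original image.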
- _get_opposite_domain(domain)
Get the domain opposite to the input domain.
- Parameters
domain (str) – The input domain.
- Returns
The opposite domain.
- Return type
str
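Because CycleGAN translates between exactly two domains, the lookup described above reduces to picking the domain that is not the input. A minimal sketch, assuming the model keeps a two-element collection of domain names; the function name and the "photo"/"mask" domain names are hypothetical:

```python
def opposite_domain(domain, domains=("photo", "mask")):
    """Return the first entry of `domains` that differs from `domain`.

    Hypothetical stand-in for the lookup; returns None if no other
    domain exists.
    """
    for candidate in domains:
        if candidate != domain:
            return candidate
    return None


target = opposite_domain("photo")
```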
- train_step(data: dict, optim_wrapper: mmengine.optim.OptimWrapperDict)
Training step function.
- Parameters
data_batch (dict) – Dict of the input data batch.
optimizer (dict[torch.optim.Optimizer]) – Dict of optimizers for the generators and discriminators.
ddp_reducer (Reducer | None, optional) – Reducer from ddp. It is used to prepare for backward() in ddp. Defaults to None.
running_status (dict | None, optional) – Contains necessary basic information for training, e.g., iteration number. Defaults to None.
- Returns
Dict of loss, information for logger, the number of samples and results for visualization.
- Return type
dict
- test_step(data: dict) → mmagic.utils.typing.SampleList
Gets the generated image of given data. Same as val_step().
- Parameters
data (dict) – Data sampled from metric specific sampler. More details in Metrics and Evaluator.
- Returns
A list of DataSample containing generated results.
- Return type
SampleList
- val_step(data: dict) → mmagic.utils.typing.SampleList
Gets the generated image of given data. Same as test_step().
- Parameters
data (dict) – Data sampled from metric specific sampler. More details in Metrics and Evaluator.
- Returns
A list of DataSample containing generated results.
- Return type
SampleList
- class mmagic.models.editors.cyclegan.ResnetGenerator(in_channels, out_channels, base_channels=64, norm_cfg=dict(type='IN'), use_dropout=False, num_blocks=9, padding_mode='reflect', init_cfg=dict(type='normal', gain=0.02))[source]
Bases:
mmengine.model.BaseModule
Construct a Resnet-based generator that consists of residual blocks between a few downsampling/upsampling operations.
- Parameters
in_channels (int) – Number of channels in input images.
out_channels (int) – Number of channels in output images.
base_channels (int) – Number of filters at the last conv layer. Default: 64.
norm_cfg (dict) – Config dict to build norm layer. Default: dict(type='IN').
use_dropout (bool) – Whether to use dropout layers. Default: False.
num_blocks (int) – Number of residual blocks. Default: 9.
padding_mode (str) – The name of padding layer in conv layers: 'reflect' | 'replicate' | 'zeros'. Default: 'reflect'.
init_cfg (dict) – Config dict for initialization. type: The name of the initialization method. Default: 'normal'. gain: Scaling factor for normal, xavier and orthogonal. Default: 0.02.
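The parameters above map directly onto a config dict. The fragment below is a sketch of what such a generator config might look like, using the defaults from the signature. It assumes the class is registered under its own name (type='ResnetGenerator'), and the choice of 3 input and output channels for RGB images is illustrative.

```python
# Hypothetical generator config; every value except the channel counts
# is the documented default from the ResnetGenerator signature.
generator_cfg = dict(
    type='ResnetGenerator',      # assumed registry name
    in_channels=3,               # illustrative: RGB input
    out_channels=3,              # illustrative: RGB output
    base_channels=64,
    norm_cfg=dict(type='IN'),    # instance normalization
    use_dropout=False,
    num_blocks=9,                # number of residual blocks
    padding_mode='reflect',      # 'reflect' | 'replicate' | 'zeros'
    init_cfg=dict(type='normal', gain=0.02),
)
```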
- forward(x)
Forward function.
- Parameters
x (Tensor) – Input tensor with shape (n, c, h, w).
- Returns
Forward results.
- Return type
Tensor
- init_weights()
Initialize weights for the model.