mmagic.models.losses.loss_comps
Package Contents
Classes
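CLIPLossComps | CLIP loss. In StyleCLIP, this loss is used to optimize the latent code so that the generated image matches the text.
DiscShiftLossComps | Disc Shift Loss.
GradientPenaltyLossComps | Gradient Penalty for WGAN-GP.
R1GradientPenaltyComps | R1 gradient penalty for WGAN-GP.
FaceIdLossComps | Face similarity loss. Generally this loss is used to keep the id consistency of the input face image and output face image.
GANLossComps | Define GAN loss.
GeneratorPathRegularizerComps | Generator Path Regularizer.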
- class mmagic.models.losses.loss_comps.CLIPLossComps(loss_weight: float = 1.0, data_info: Optional[dict] = None, clip_model: dict = dict(), loss_name: str = 'loss_clip')
Bases: torch.nn.Module
CLIP loss. In StyleCLIP, this loss is used to optimize the latent code so that the generated image matches the text.
In this loss, we may need to provide image and text. Thus, an example of the data_info is:
    data_info = dict(
        image='fake_imgs',
        text='descriptions')
Then, the module will automatically construct this mapping from the input data dictionary.
- Parameters
loss_weight (float, optional) – Weight of this loss item. Defaults to 1.0.
data_info (dict, optional) – Dictionary containing the mapping between loss input args and the data dictionary. If None, this module will directly pass the input data to the loss function. Defaults to None.
clip_model (dict, optional) – Kwargs for the CLIP loss model. Defaults to dict().
loss_name (str, optional) – Name of the loss item. If you want this loss item to be included in the backward graph, loss_ must be the prefix of the name. Defaults to 'loss_clip'.
- forward(*args, **kwargs) -> torch.Tensor
Forward function.
If self.data_info is not None, a dictionary containing all of the data and necessary modules should be passed into this function. If this dictionary is given as a non-keyword argument, it should be the first argument. If you pass it as a keyword argument, name it outputs_dict.
If self.data_info is None, the positional or keyword arguments are passed directly to the loss function, third_party_net_loss.
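The dispatch rule described for forward can be sketched in plain Python. This is a minimal illustration under stated assumptions, not MMagic's implementation: the class LossCompSketch and the function toy_loss are hypothetical stand-ins, and only the behaviour documented above (first positional argument or the outputs_dict keyword, then the data_info mapping) is modelled.

```python
from typing import Any, Callable, Optional


class LossCompSketch:
    """Hypothetical stand-in for a loss component; not MMagic's code."""

    def __init__(self, loss_fn: Callable[..., Any],
                 data_info: Optional[dict] = None) -> None:
        self.loss_fn = loss_fn
        self.data_info = data_info

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        if self.data_info is None:
            # No mapping: hand the arguments straight to the loss function.
            return self.loss_fn(*args, **kwargs)
        # The data dictionary is either the first positional argument
        # or the keyword argument named ``outputs_dict``.
        outputs_dict = args[0] if args else kwargs["outputs_dict"]
        # Translate {loss_arg: data_key} into {loss_arg: value}.
        loss_kwargs = {arg: outputs_dict[key]
                       for arg, key in self.data_info.items()}
        return self.loss_fn(**loss_kwargs)


def toy_loss(image: str, text: str) -> str:
    # Placeholder for a real loss; it only reports what it received.
    return f"{image}|{text}"


comp = LossCompSketch(
    toy_loss, data_info=dict(image="fake_imgs", text="descriptions"))
data = dict(fake_imgs="IMG", descriptions="TXT", iteration=3)

by_position = comp.forward(data)              # dictionary as first argument
by_keyword = comp.forward(outputs_dict=data)  # dictionary as keyword
direct = LossCompSketch(toy_loss).forward(image="A", text="B")
```

Keys of the data dictionary that data_info does not mention (iteration in this example) are simply ignored by the mapping.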
- static loss_name() -> str
Loss Name.
This function must be implemented and returns the name of this loss item. Loss items are combined under these names by a simple sum. In addition, if you want this loss item to be included in the backward graph, loss_ must be the prefix of the name.
- Returns
The name of this loss item.
- Return type
str
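The naming rule can be illustrated with a short sketch. It assumes that the training loop collects loss items in a dictionary and sums only the entries whose key starts with loss_; the helper total_backward_loss and the sample values are hypothetical.

```python
def total_backward_loss(loss_items: dict) -> float:
    """Sum every entry whose name starts with ``loss_`` (assumed rule)."""
    return sum(value for name, value in loss_items.items()
               if name.startswith("loss_"))


# ``grad_norm`` is logged but, lacking the prefix, stays out of the sum.
items = {"loss_clip": 0.5, "loss_id": 0.25, "grad_norm": 7.0}
total = total_backward_loss(items)
```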
- class mmagic.models.losses.loss_comps.DiscShiftLossComps(loss_weight: float = 1.0, data_info: Optional[dict] = None, loss_name: str = 'loss_disc_shift')
Bases: torch.nn.Module
Disc Shift Loss.
This loss is proposed in PGGAN as an auxiliary loss for the discriminator.
Note for the design of data_info: In MMagic, almost all loss modules take the argument data_info, which links the input items needed in the loss calculation to the data from the generative model. For example, when training a GAN model, we collect all of the important data/modules into a dictionary:
    data_dict_ = dict(
        gen=self.generator,
        disc=self.discriminator,
        disc_pred_fake=disc_pred_fake,
        disc_pred_real=disc_pred_real,
        fake_imgs=fake_imgs,
        real_imgs=real_imgs,
        iteration=curr_iter,
        batch_size=batch_size)
But in this loss, we need to provide pred as input. Thus, an example of the data_info is:
    data_info = dict(
        pred='disc_pred_fake')
Then, the module will automatically construct this mapping from the input data dictionary.
In addition, disc_shift_loss is generally applied to both real and fake data. In this case, users just need to add this loss module twice, with a different data_info each time. Our model will automatically add these two items.
- Parameters
loss_weight (float, optional) – Weight of this loss item. Defaults to 1.0.
data_info (dict, optional) – Dictionary containing the mapping between loss input args and the data dictionary. If None, this module will directly pass the input data to the loss function. Defaults to None.
loss_name (str, optional) – Name of the loss item. If you want this loss item to be included in the backward graph, loss_ must be the prefix of the name. Defaults to 'loss_disc_shift'.
- forward(*args, **kwargs) -> torch.Tensor
Forward function.
If self.data_info is not None, a dictionary containing all of the data and necessary modules should be passed into this function. If this dictionary is given as a non-keyword argument, it should be the first argument. If you pass it as a keyword argument, name it outputs_dict.
If self.data_info is None, the positional or keyword arguments are passed directly to the loss function, disc_shift_loss.
- loss_name() -> str
Loss Name.
This function must be implemented and returns the name of this loss item. Loss items are combined under these names by a simple sum. In addition, if you want this loss item to be included in the backward graph, loss_ must be the prefix of the name.
- Returns
The name of this loss item.
- Return type
str
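As a rough illustration of applying the loss to both real and fake predictions, the sketch below assumes the PGGAN drift term, i.e. the mean of the squared discriminator outputs scaled by loss_weight. The function disc_shift_sketch, the plain-list predictions, and the weight 0.5 are hypothetical; the actual module operates on tensors.

```python
from statistics import fmean


def disc_shift_sketch(pred, loss_weight=1.0):
    """Mean squared discriminator output, scaled by the loss weight."""
    return loss_weight * fmean([p * p for p in pred])


data_dict = dict(disc_pred_real=[1.0, -1.0], disc_pred_fake=[2.0, 0.0])

# Two instances of the same loss that differ only in ``data_info``.
configs = [dict(pred="disc_pred_real"), dict(pred="disc_pred_fake")]

shift_total = sum(
    disc_shift_sketch(data_dict[info["pred"]], loss_weight=0.5)
    for info in configs)
```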
- class mmagic.models.losses.loss_comps.GradientPenaltyLossComps(loss_weight: float = 1.0, norm_mode: str = 'pixel', data_info: Optional[dict] = None, loss_name: str = 'loss_gp')
Bases: torch.nn.Module
Gradient Penalty for WGAN-GP.
In practice, there are two variants: one uses the pixel-wise gradient norm, while the other normalizes along the instance (HWC) dimensions. The norm_mode argument selects which one is used.
Note for the design of data_info: In MMagic, almost all loss modules take the argument data_info, which links the input items needed in the loss calculation to the data from the generative model. For example, when training a GAN model, we collect all of the important data/modules into a dictionary:
    data_dict_ = dict(
        gen=self.generator,
        disc=self.discriminator,
        disc_pred_fake=disc_pred_fake,
        disc_pred_real=disc_pred_real,
        fake_imgs=fake_imgs,
        real_imgs=real_imgs,
        iteration=curr_iter,
        batch_size=batch_size)
But in this loss, we need to provide discriminator, real_data, and fake_data as input. Thus, an example of the data_info is:
    data_info = dict(
        discriminator='disc',
        real_data='real_imgs',
        fake_data='fake_imgs')
Then, the module will automatically construct this mapping from the input data dictionary.
- Parameters
loss_weight (float, optional) – Weight of this loss item. Defaults to 1.0.
data_info (dict, optional) – Dictionary containing the mapping between loss input args and the data dictionary. If None, this module will directly pass the input data to the loss function. Defaults to None.
norm_mode (str) – Decides along which dimension the norm of the gradients is calculated. Currently, 'pixel' and 'HWC' are supported. Defaults to 'pixel'.
loss_name (str, optional) – Name of the loss item. If you want this loss item to be included in the backward graph, loss_ must be the prefix of the name. Defaults to 'loss_gp'.
- forward(*args, **kwargs) -> torch.Tensor
Forward function.
If self.data_info is not None, a dictionary containing all of the data and necessary modules should be passed into this function. If this dictionary is given as a non-keyword argument, it should be the first argument. If you pass it as a keyword argument, name it outputs_dict.
If self.data_info is None, the positional or keyword arguments are passed directly to the loss function, gradient_penalty_loss.
- loss_name() -> str
Loss Name.
This function must be implemented and returns the name of this loss item. Loss items are combined under these names by a simple sum. In addition, if you want this loss item to be included in the backward graph, loss_ must be the prefix of the name.
- Returns
The name of this loss item.
- Return type
str
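The difference between the two norm_mode settings can be sketched as follows. This is an interpretation, not the library's code: it assumes that 'pixel' takes the L2 norm over the channel axis at every spatial location, and that 'HWC' takes a single L2 norm per sample over all flattened entries. The function gp_from_gradients and the nested-list layout (N, C, H, W) are hypothetical. The real module obtains the gradients from the discriminator through autograd, which is omitted here.

```python
import math


def gp_from_gradients(grads, norm_mode="pixel"):
    """Penalty mean((||g|| - 1)^2) for gradients shaped (N, C, H, W)."""
    norms = []
    for sample in grads:  # one sample has shape (C, H, W)
        if norm_mode == "pixel":
            height, width = len(sample[0]), len(sample[0][0])
            # One norm per spatial location, taken across channels.
            for i in range(height):
                for j in range(width):
                    norms.append(
                        math.sqrt(sum(ch[i][j] ** 2 for ch in sample)))
        elif norm_mode == "HWC":
            # One norm per sample, taken across every entry.
            norms.append(math.sqrt(
                sum(v ** 2 for ch in sample for row in ch for v in row)))
        else:
            raise ValueError(f"unsupported norm_mode: {norm_mode}")
    return sum((n - 1.0) ** 2 for n in norms) / len(norms)


# One sample, one channel, a 1x2 image whose gradient is 1 everywhere.
grads = [[[[1.0, 1.0]]]]
gp_pixel = gp_from_gradients(grads, "pixel")
gp_hwc = gp_from_gradients(grads, "HWC")
```

With this input every per-pixel norm already equals 1, so the pixel-wise penalty vanishes, while the per-sample norm is sqrt(2) and is penalized.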
- class mmagic.models.losses.loss_comps.R1GradientPenaltyComps(loss_weight: float = 1.0, norm_mode: str = 'pixel', interval: int = 1, data_info: Optional[dict] = None, use_apex_amp: bool = False, loss_name: str = 'loss_r1_gp')
Bases: torch.nn.Module
R1 gradient penalty for WGAN-GP.
The R1 regularizer comes from: "Which Training Methods for GANs do actually Converge?" ICML'2018
Unlike the original gradient penalty, this regularizer only penalizes the gradient w.r.t. real data.
Note for the design of data_info: In MMagic, almost all loss modules take the argument data_info, which links the input items needed in the loss calculation to the data from the generative model. For example, when training a GAN model, we collect all of the important data/modules into a dictionary:
    data_dict_ = dict(
        gen=self.generator,
        disc=self.discriminator,
        disc_pred_fake=disc_pred_fake,
        disc_pred_real=disc_pred_real,
        fake_imgs=fake_imgs,
        real_imgs=real_imgs,
        iteration=curr_iter,
        batch_size=batch_size)
But in this loss, we need to provide discriminator and real_data as input. Thus, an example of the data_info is:
    data_info = dict(
        discriminator='disc',
        real_data='real_imgs')
Then, the module will automatically construct this mapping from the input data dictionary.
- Parameters
loss_weight (float, optional) – Weight of this loss item. Defaults to 1.0.
data_info (dict, optional) – Dictionary containing the mapping between loss input args and the data dictionary. If None, this module will directly pass the input data to the loss function. Defaults to None.
norm_mode (str) – Decides along which dimension the norm of the gradients is calculated. Currently, 'pixel' and 'HWC' are supported. Defaults to 'pixel'.
interval (int, optional) – The interval of calculating this loss. Defaults to 1.
loss_name (str, optional) – Name of the loss item. If you want this loss item to be included in the backward graph, loss_ must be the prefix of the name. Defaults to 'loss_r1_gp'.
- forward(*args, **kwargs) -> torch.Tensor
Forward function.
If self.data_info is not None, a dictionary containing all of the data and necessary modules should be passed into this function. If this dictionary is given as a non-keyword argument, it should be the first argument. If you pass it as a keyword argument, name it outputs_dict.
If self.data_info is None, the positional or keyword arguments are passed directly to the loss function, r1_gradient_penalty_loss.
- loss_name() -> str
Loss Name.
This function must be implemented and returns the name of this loss item. Loss items are combined under these names by a simple sum. In addition, if you want this loss item to be included in the backward graph, loss_ must be the prefix of the name.
- Returns
The name of this loss item.
- Return type
str
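For reference, the two penalties can be written side by side as they appear in the respective papers. Here D is the discriminator, x is a real sample, x-tilde is a generated sample, and x-hat is a random interpolation between the two. The coefficients lambda and gamma are the papers' weighting factors; how they correspond to loss_weight in this module is not stated on this page, so treat the scaling as an assumption.

```latex
\begin{aligned}
\mathcal{L}_{\text{GP}} &= \lambda\, \mathbb{E}_{\hat{x}}\!\left[\left(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\right)^2\right],
\qquad \hat{x} = \epsilon x + (1 - \epsilon)\tilde{x},\quad \epsilon \sim U[0, 1] \\
R_1 &= \frac{\gamma}{2}\, \mathbb{E}_{x \sim p_{\text{data}}}\!\left[\lVert \nabla_{x} D(x) \rVert_2^2\right]
\end{aligned}
```

The first expression pushes the gradient norm towards 1 at interpolated points, whereas R1 pushes the gradient towards zero and evaluates it on real data only.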
- class mmagic.models.losses.loss_comps.FaceIdLossComps(loss_weight: float = 1.0, data_info: Optional[dict] = None, facenet: dict = dict(type='ArcFace', ir_se50_weights=None), loss_name: str = 'loss_id')
Bases: torch.nn.Module
Face similarity loss. Generally this loss is used to keep the id consistency of the input face image and output face image.
In this loss, we may need to provide gt, pred and x. Thus, an example of the data_info is:
    data_info = dict(
        gt='real_imgs',
        pred='fake_imgs')
Then, the module will automatically construct this mapping from the input data dictionary.
- Parameters
loss_weight (float, optional) – Weight of this loss item. Defaults to 1.0.
data_info (dict, optional) – Dictionary containing the mapping between loss input args and the data dictionary. If None, this module will directly pass the input data to the loss function. Defaults to None.
facenet (dict, optional) – Config dict for facenet. Defaults to dict(type='ArcFace', ir_se50_weights=None).
loss_name (str, optional) – Name of the loss item. If you want this loss item to be included in the backward graph, loss_ must be the prefix of the name. Defaults to 'loss_id'.
- forward(*args, **kwargs) -> torch.Tensor
Forward function.
If self.data_info is not None, a dictionary containing all of the data and necessary modules should be passed into this function. If this dictionary is given as a non-keyword argument, it should be the first argument. If you pass it as a keyword argument, name it outputs_dict.
If self.data_info is None, the positional or keyword arguments are passed directly to the loss function, third_party_net_loss.
- loss_name() -> str
Loss Name.
This function must be implemented and returns the name of this loss item. Loss items are combined under these names by a simple sum. In addition, if you want this loss item to be included in the backward graph, loss_ must be the prefix of the name.
- Returns
The name of this loss item.
- Return type
str
- class mmagic.models.losses.loss_comps.GANLossComps(gan_type: str, real_label_val: float = 1.0, fake_label_val: float = 0.0, loss_weight: float = 1.0)
Bases: torch.nn.Module
Define GAN loss.
- Parameters
gan_type (str) – Supported types are 'vanilla', 'lsgan', 'wgan', 'hinge', and 'wgan-logistic-ns'.
real_label_val (float) – The value for the real label. Default: 1.0.
fake_label_val (float) – The value for the fake label. Default: 0.0.
loss_weight (float) – Loss weight. Default: 1.0. Note that loss_weight only applies to generators; it is always 1.0 for discriminators.
- _wgan_loss(input: torch.Tensor, target: bool) -> torch.Tensor
WGAN loss.
- Parameters
input (Tensor) – Input tensor.
target (bool) – Target label.
- Returns
WGAN loss.
- Return type
Tensor
- _wgan_logistic_ns_loss(input: torch.Tensor, target: bool) -> torch.Tensor
WGAN loss in logistically non-saturating mode.
This loss is widely used in StyleGANv2.
- Parameters
input (Tensor) – Input tensor.
target (bool) – Target label.
- Returns
WGAN loss.
- Return type
Tensor
- get_target_label(input: torch.Tensor, target_is_real: bool) -> Union[bool, torch.Tensor]
Get target label.
- Parameters
input (Tensor) – Input tensor.
target_is_real (bool) – Whether the target is real or fake.
- Returns
Target tensor. Returns a bool for wgan; otherwise, returns a Tensor.
- Return type
(bool | Tensor)
- forward(input: torch.Tensor, target_is_real: bool, is_disc: bool = False) -> torch.Tensor
- Parameters
input (Tensor) – The input for the loss module, i.e., the network prediction.
target_is_real (bool) – Whether the target is real or fake.
is_disc (bool) – Whether the loss is computed for the discriminator. Default: False.
- Returns
GAN loss value.
- Return type
Tensor
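To make the listed gan_type options concrete, here is a minimal scalar sketch of the textbook formulation of each loss. It is an assumption-laden illustration rather than the module's code: gan_loss_sketch, softplus, and mean are hypothetical helpers working on plain lists, 'vanilla' is assumed to be binary cross-entropy on logits, and the hinge branch assumes the usual margin of 1. The only behaviour taken directly from this page is that loss_weight is applied for generators and not for discriminators.

```python
import math


def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def mean(values):
    return sum(values) / len(values)


def gan_loss_sketch(preds, target_is_real, is_disc=False, gan_type="vanilla",
                    real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
    label = real_label_val if target_is_real else fake_label_val
    if gan_type == "vanilla":
        # Binary cross-entropy on logits: softplus(x) - label * x.
        loss = mean([softplus(x) - label * x for x in preds])
    elif gan_type == "lsgan":
        loss = mean([(x - label) ** 2 for x in preds])
    elif gan_type == "wgan":
        loss = -mean(preds) if target_is_real else mean(preds)
    elif gan_type == "wgan-logistic-ns":
        sign = -1.0 if target_is_real else 1.0
        loss = mean([softplus(sign * x) for x in preds])
    elif gan_type == "hinge":
        if is_disc:
            sign = -1.0 if target_is_real else 1.0
            loss = mean([max(0.0, 1.0 + sign * x) for x in preds])
        else:
            loss = -mean(preds)
    else:
        raise ValueError(f"unsupported gan_type: {gan_type}")
    # The weight is only applied to the generator's loss.
    return loss if is_disc else loss * loss_weight


d_loss = gan_loss_sketch([1.0, 3.0], True, is_disc=True,
                         gan_type="wgan", loss_weight=0.5)
g_loss = gan_loss_sketch([1.0, 3.0], True, is_disc=False,
                         gan_type="wgan", loss_weight=0.5)
```

In this example both calls see the same predictions, yet only g_loss is halved by loss_weight.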
- class mmagic.models.losses.loss_comps.GeneratorPathRegularizerComps(loss_weight: float = 1.0, pl_batch_shrink: int = 1, decay: float = 0.01, pl_batch_size: Optional[int] = None, sync_mean_buffer: bool = False, interval: int = 1, data_info: Optional[dict] = None, use_apex_amp: bool = False, loss_name: str = 'loss_path_regular')
Bases: torch.nn.Module
Generator Path Regularizer.
Path regularization is proposed in StyleGAN2 and can help improve the continuity of the latent space. More details can be found in: Analyzing and Improving the Image Quality of StyleGAN, CVPR2020.
Users can achieve lazy regularization by setting the interval argument.
Note for the design of data_info: In MMagic, almost all loss modules take the argument data_info, which links the input items needed in the loss calculation to the data from the generative model. For example, when training a GAN model, we collect all of the important data/modules into a dictionary:
    data_dict_ = dict(
        gen=self.generator,
        disc=self.discriminator,
        fake_imgs=fake_imgs,
        disc_pred_fake_g=disc_pred_fake_g,
        iteration=curr_iter,
        batch_size=batch_size)
But in this loss, we need to provide generator and num_batches as input. Thus, an example of the data_info is:
    data_info = dict(
        generator='gen',
        num_batches='batch_size')
Then, the module will automatically construct this mapping from the input data dictionary.
- Parameters
loss_weight (float, optional) – Weight of this loss item. Defaults to 1.0.
pl_batch_shrink (int, optional) – The factor by which the batch size is shrunk to save GPU memory. Defaults to 1.
decay (float, optional) – Decay for the moving average of the mean path length. Defaults to 0.01.
pl_batch_size (int | None, optional) – The batch size used when calculating the generator path. Once this argument is set, num_batches will be overridden with this argument and won't be affected by pl_batch_shrink. Defaults to None.
sync_mean_buffer (bool, optional) – Whether to sync the mean path length across all GPUs. Defaults to False.
interval (int, optional) – The interval of calculating this loss. This argument is used to support lazy regularization. Defaults to 1.
data_info (dict, optional) – Dictionary containing the mapping between loss input args and the data dictionary. If None, this module will directly pass the input data to the loss function. Defaults to None.
loss_name (str, optional) – Name of the loss item. If you want this loss item to be included in the backward graph, loss_ must be the prefix of the name. Defaults to 'loss_path_regular'.
- forward(*args, **kwargs) -> torch.Tensor
Forward function.
If self.data_info is not None, a dictionary containing all of the data and necessary modules should be passed into this function. If this dictionary is given as a non-keyword argument, it should be the first argument. If you pass it as a keyword argument, name it outputs_dict.
If self.data_info is None, the positional or keyword arguments are passed directly to the loss function, gen_path_regularizer.
- loss_name() -> str
Loss Name.
This function must be implemented and returns the name of this loss item. Loss items are combined under these names by a simple sum. In addition, if you want this loss item to be included in the backward graph, loss_ must be the prefix of the name.
- Returns
The name of this loss item.
- Return type
str
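The interplay of pl_batch_size, pl_batch_shrink, interval, and decay can be sketched in plain Python. This is a hedged illustration that follows the parameter descriptions above and the moving-average update commonly used for StyleGAN2 path regularization. The helpers path_batch_size, should_compute, and update_mean_path are hypothetical, the rule "compute when iteration is a multiple of interval" is an assumption, and the path lengths are given as plain numbers instead of being derived from the generator's Jacobian.

```python
from typing import List, Optional, Tuple


def path_batch_size(num_batches: int, pl_batch_shrink: int = 1,
                    pl_batch_size: Optional[int] = None) -> int:
    # An explicit pl_batch_size overrides num_batches and the shrink factor.
    if pl_batch_size is not None:
        return pl_batch_size
    return max(1, num_batches // pl_batch_shrink)


def should_compute(iteration: int, interval: int = 1) -> bool:
    # Lazy regularization: evaluate the loss only every ``interval`` steps.
    return iteration % interval == 0


def update_mean_path(mean_path: float, path_lengths: List[float],
                     decay: float = 0.01) -> Tuple[float, float]:
    # Exponential moving average of the mean path length.
    batch_mean = sum(path_lengths) / len(path_lengths)
    new_mean = mean_path + decay * (batch_mean - mean_path)
    # Penalize the deviation of each path length from the running mean.
    penalty = sum((p - new_mean) ** 2 for p in path_lengths) / len(path_lengths)
    return new_mean, penalty


new_mean, penalty = update_mean_path(0.0, [2.0, 4.0], decay=0.5)
```

A larger pl_batch_shrink trades a noisier path-length estimate for lower memory use, which is why the shrunken batch size is clamped to at least 1 in this sketch.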