mmagic.models.losses.clip_loss

Module Contents

Classes

CLIPLossModel – Wrapped CLIP model to calculate CLIP loss.

CLIPLoss – CLIP loss. In StyleCLIP, this loss is used to optimize the latent code.
- class mmagic.models.losses.clip_loss.CLIPLossModel(in_size: int = 1024, scale_factor: int = 7, pool_size: int = 224, clip_type: str = 'ViT-B/32')

Bases: torch.nn.Module
Wrapped CLIP model to calculate CLIP loss.
Ref: https://github.com/orpatashnik/StyleCLIP/blob/main/criteria/clip_loss.py
- Parameters
in_size (int, optional) – Input image size. Defaults to 1024.
scale_factor (int, optional) – Upsampling factor. Defaults to 7.
pool_size (int, optional) – Pooling output size. Defaults to 224.
clip_type (str, optional) – A model name listed by clip.available_models(), or the path to a model checkpoint containing the state_dict. For more details, refer to https://github.com/openai/CLIP/blob/573315e83f07b53a61ff5098757e8fc885f1703e/clip/clip.py#L91. Defaults to 'ViT-B/32'.
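As described above, the wrapped model produces a CLIP-based loss from an image and a text prompt. The following is a minimal, dependency-free sketch of the underlying idea (a StyleCLIP-style loss derived from image-text embedding similarity); clip_style_loss is a hypothetical helper, not the mmagic implementation, which operates on torch tensors and a real CLIP encoder.

```python
import math

def clip_style_loss(image_emb, text_emb):
    """Hypothetical sketch: a CLIP-style loss as 1 minus the cosine
    similarity between an image embedding and a text embedding.
    Lower values mean the image better matches the text."""
    dot = sum(a * b for a, b in zip(image_emb, text_emb))
    norm_i = math.sqrt(sum(a * a for a in image_emb))
    norm_t = math.sqrt(sum(b * b for b in text_emb))
    return 1.0 - dot / (norm_i * norm_t)

# Identical embeddings give loss 0; orthogonal embeddings give loss 1.
perfect_match = clip_style_loss([1.0, 0.0], [1.0, 0.0])
no_match = clip_style_loss([1.0, 0.0], [0.0, 1.0])
```

In the real module, the image is first upsampled by scale_factor and pooled to pool_size so that it matches the input resolution CLIP expects (224 for ViT-B/32).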
- class mmagic.models.losses.clip_loss.CLIPLoss(loss_weight: float = 1.0, data_info: Optional[dict] = None, clip_model: dict = dict(), loss_name: str = 'loss_clip')[source]¶
Bases:
torch.nn.Module
CLIP loss. In StyleCLIP, this loss is used to optimize the latent code to generate images that match the text.
In this loss, we may need to provide image and text. Thus, an example of the data_info is:

    data_info = dict(
        image='fake_imgs',
        text='descriptions')
Then, the module will automatically construct this mapping from the input data dictionary.
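The remapping described above can be sketched as a plain dictionary lookup; build_loss_kwargs is a hypothetical helper shown only to illustrate the behavior, which the actual module performs internally.

```python
def build_loss_kwargs(data_dict, data_info):
    """Hypothetical sketch: remap entries of a data dictionary to the
    argument names expected by the loss, following data_info."""
    return {arg: data_dict[key] for arg, key in data_info.items()}

data_info = dict(image='fake_imgs', text='descriptions')
outputs_dict = dict(fake_imgs='<image tensor>', descriptions='<text tokens>')
kwargs = build_loss_kwargs(outputs_dict, data_info)
# kwargs now maps 'image' and 'text' to the corresponding entries,
# ready to be passed as loss(**kwargs).
```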
- Parameters
loss_weight (float, optional) – Weight of this loss item. Defaults to 1.0.
data_info (dict, optional) – Dictionary containing the mapping between loss input args and the data dictionary. If None, this module will directly pass the input data to the loss function. Defaults to None.
clip_model (dict, optional) – Kwargs for the CLIP loss model. Defaults to dict().
loss_name (str, optional) – Name of the loss item. If you want this loss item to be included in the backward graph, 'loss_' must be the prefix of the name. Defaults to 'loss_clip'.
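The naming convention for loss_name can be sketched as a simple prefix check; contributes_to_backward is a hypothetical helper illustrating the rule, not part of the mmagic API.

```python
def contributes_to_backward(loss_name):
    """Hypothetical sketch: only items whose name starts with 'loss_'
    are meant to be included in the backward graph."""
    return loss_name.startswith('loss_')

# 'loss_clip' is backpropagated; a name like 'clip_score' would be
# treated as a metric and excluded.
```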
- forward(image: torch.Tensor, text: torch.Tensor) -> torch.Tensor

Forward function.

If self.data_info is not None, a dictionary containing all of the data and necessary modules should be passed into this function. If this dictionary is given as a non-keyword argument, it should be offered as the first argument. If you are using a keyword argument, please name it outputs_dict.

If self.data_info is None, the input arguments or keyword arguments will be directly passed to the loss function, third_party_net_loss.