mmagic.models.editors.inst_colorization.inst_colorization
Module Contents¶
Classes¶
InstColorization — Colorization InstColorization method.
- class mmagic.models.editors.inst_colorization.inst_colorization.InstColorization(data_preprocessor: Union[dict, mmengine.config.Config], image_model, instance_model, fusion_model, color_data_opt, which_direction='AtoB', loss=None, init_cfg=None, train_cfg=None, test_cfg=None)[source]¶
Bases: mmengine.model.BaseModel
Colorization InstColorization method.
- This colorization method is implemented according to the paper: Instance-aware Image Colorization, CVPR 2020.
Adapted from ‘https://github.com/ericsujw/InstColorization.git’, ‘InstColorization/models/train_model’. Copyright (c) 2020, Su, under MIT License.
- Parameters
data_preprocessor (dict, optional) – The pre-process config of BaseDataPreprocessor.
image_model (dict) – Config for the single-image model.
instance_model (dict) – Config for the instance model.
fusion_model (dict) – Config for the fusion model.
color_data_opt (dict) – Options for colorspace conversion.
which_direction (str) – ‘AtoB’ or ‘BtoA’.
loss (dict) – Config for the loss.
init_cfg (dict, optional) – Initialization config dict. Default: None.
train_cfg (dict) – Config for training. Default: None.
test_cfg (dict) – Config for testing. Default: None.
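The constructor arguments above are usually collected into a single config dict. A minimal sketch of that shape follows; the sub-config bodies here are hypothetical placeholders, not the real mmagic defaults — consult the shipped mmagic configs for actual values.

```python
# Sketch of an InstColorization model config; every sub-config body below is a
# hypothetical placeholder (real type names/options live in the mmagic configs).
model_cfg = dict(
    type='InstColorization',
    data_preprocessor=dict(type='DataPreprocessor'),   # placeholder
    image_model=dict(type='FullImageNet'),             # single-image branch (placeholder)
    instance_model=dict(type='InstanceNet'),           # per-instance branch (placeholder)
    fusion_model=dict(type='FusionNet'),               # fuses the two branches (placeholder)
    color_data_opt=dict(ab_norm=110.0),                # colorspace options (illustrative)
    which_direction='AtoB',
    loss=None,
    init_cfg=None,
    train_cfg=None,
    test_cfg=None,
)
```

In mmengine-based codebases a dict like this is typically passed to the model registry's build function rather than to the constructor directly.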
- forward(inputs: torch.Tensor, data_samples: Optional[List[mmagic.structures.DataSample]] = None, mode: str = 'tensor', **kwargs)[source]¶
Return losses or predictions for training, validation, testing, or simple inference.
The forward method of BaseModel is an abstract method; its subclasses must implement it. It accepts inputs and data_samples processed by data_preprocessor, and returns results according to the mode argument.
During non-distributed training, validation, and testing, forward is called directly by BaseModel.train_step, BaseModel.val_step and BaseModel.test_step. During distributed data parallel training, MMSeparateDistributedDataParallel.train_step first calls DistributedDataParallel.forward to enable automatic gradient synchronization, and then calls forward to get the training loss.
- Parameters
inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – data samples collated by data_preprocessor.
mode (str) – should be one of loss, predict and tensor. Default: ‘tensor’.
loss: called by train_step; returns a loss dict used for logging.
predict: called by val_step and test_step; returns a list of BaseDataElement results used for computing metrics.
tensor: called by custom code to get Tensor-type results.
- Returns
If mode == loss, return a dict of loss tensors used for backward and logging.
If mode == predict, return a list of BaseDataElement for computing metrics and getting inference results.
If mode == tensor, return a tensor or tuple of tensors or dict of tensors for custom use.
- Return type
ForwardResults
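The three modes map onto three distinct code paths. A minimal, framework-free sketch of that dispatch (dummy Python types stand in for torch.Tensor and BaseDataElement; the loss name and result shape are illustrative, not mmagic's):

```python
class ToyModel:
    """Mimics BaseModel.forward's mode dispatch (illustrative only)."""

    def forward(self, inputs, data_samples=None, mode='tensor'):
        if mode == 'loss':
            # train_step path: return a dict of named losses for logging
            return {'loss_l1': sum(inputs)}
        if mode == 'predict':
            # val_step / test_step path: one result element per sample
            return [{'pred': x} for x in inputs]
        if mode == 'tensor':
            # raw outputs handed back for custom use
            return inputs
        raise ValueError(f'Invalid mode "{mode}".')


model = ToyModel()
print(model.forward([1, 2, 3], mode='loss'))   # {'loss_l1': 6}
```

In the real class, `inputs` is a collated torch.Tensor and `data_samples` a list of DataSample; only the dispatch structure carries over.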
- convert_to_datasample(inputs, data_samples)[source]¶
Add predictions and destructed inputs (if passed) to data samples.
- Parameters
inputs (Optional[torch.Tensor]) – The input of model. Defaults to None.
data_samples (List[DataSample]) – The data samples loaded from dataloader.
- Returns
Modified data samples.
- Return type
List[DataSample]
- abstract train_step(data: List[dict], optim_wrapper: mmengine.optim.OptimWrapperDict) Dict[str, torch.Tensor] [source]¶
Train step function.
- Parameters
data (List[dict]) – Batch of data as input.
optim_wrapper (OptimWrapperDict) – Dict of optimizer wrappers for the generator and discriminator (if any).
- Returns
Dict with loss, information for the logger, the number of samples, and results for visualization.
- Return type
dict
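Since train_step is abstract here, a concrete subclass supplies the body. A sketch of the usual mmengine pattern (forward in loss mode, reduce, update parameters); the optimizer wrapper and loss forward are mocked stand-ins, not mmagic code:

```python
class ToyOptimWrapper:
    """Stand-in for mmengine's OptimWrapper (illustrative only)."""

    def update_params(self, loss):
        # the real wrapper runs loss.backward(), optimizer.step(), zero_grad()
        self.last_loss = loss


def toy_loss_forward(inputs):
    """Stand-in for model.forward(..., mode='loss'); loss name is made up."""
    return {'loss_l1': sum(inputs)}


def train_step(data, optim_wrapper):
    """Sketch of a typical concrete train_step implementation."""
    losses = toy_loss_forward(data['inputs'])   # 1. forward in 'loss' mode
    total = sum(losses.values())                # 2. reduce named losses to a scalar
    optim_wrapper.update_params(total)          # 3. backward + optimizer step
    return dict(losses)                         # 4. log vars for the logger
```

The real implementation also runs the data through data_preprocessor first and may update generator and discriminator with separate wrappers from the OptimWrapperDict.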
- forward_inference(inputs, data_samples=None, **kwargs)[source]¶
Forward inference. Returns predictions for validation and testing.
- Parameters
inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – data samples collated by data_preprocessor.
- Returns
predictions.
- Return type
List[DataSample]