mmagic.models.editors.inst_colorization

Package Contents

Classes
ColorizationNet: Real-Time User-Guided Image Colorization with Learned Deep Priors.
FusionNet: Instance-aware Image Colorization.
InstColorization: Colorization with the InstColorization method.
- class mmagic.models.editors.inst_colorization.ColorizationNet(input_nc, output_nc, norm_type, use_tanh=True, classification=True)[source]
Bases:
mmengine.model.BaseModule
Real-Time User-Guided Image Colorization with Learned Deep Priors. The backbone used for colorization.
https://arxiv.org/abs/1705.02999
Code adapted from https://github.com/ericsujw/InstColorization (InstColorization/blob/master/models/networks.py#L108).
- Parameters
input_nc (int) – number of input image channels
output_nc (int) – number of output image channels
norm_type (str) – type of normalization layer, instance or batch normalization
use_tanh (bool) – whether to apply nn.Tanh() to the output. Default: True.
classification (bool) – if True, backpropagate through the trunk with a classification loss; otherwise use regression. Default: True.
- forward(input_A, input_B, mask_B)[source]
Forward function.
- Parameters
input_A (tensor) – The L channel of the image in Lab color space
input_B (tensor) – Color patch
mask_B (tensor) – Color patch mask
- Returns
out_class (tensor) – Classification output. out_reg (tensor) – Regression output. feature_map (dict) – The full-image feature.
- Return type
tuple(tensor, tensor, dict)
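The inputs above can be sketched with dummy tensors. The channel layout here is an assumption based on the typical user-guided Lab colorization setup (input_A is the lightness channel, input_B holds ab color hints, mask_B marks hinted pixels); the commented-out network call shows how the forward signature would be used with mmagic installed.

```python
import torch

# Dummy inputs matching the documented forward() signature.
batch, height, width = 1, 256, 256
input_A = torch.rand(batch, 1, height, width)   # L channel of the Lab image
input_B = torch.zeros(batch, 2, height, width)  # ab color hints (all zero = no hints)
mask_B = torch.zeros(batch, 1, height, width)   # 1 where a user hint exists

# Place a single user color hint at the image center.
mask_B[:, :, height // 2, width // 2] = 1.0
input_B[:, :, height // 2, width // 2] = torch.tensor([0.4, -0.2])

# With mmagic installed, the network could then be applied as:
# net = ColorizationNet(input_nc=4, output_nc=2, norm_type='batch')
# out_class, out_reg, feature_map = net(input_A, input_B, mask_B)
```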
- class mmagic.models.editors.inst_colorization.FusionNet(input_nc, output_nc, norm_type, use_tanh=True, classification=True)[source]
Bases:
mmengine.model.BaseModule
Instance-aware Image Colorization.
https://arxiv.org/abs/2005.10825
Code adapted from https://github.com/ericsujw/InstColorization (InstColorization/blob/master/models/networks.py#L314). FusionNet is the full-image model with a weight layer for fusion.
- Parameters
input_nc (int) – number of input image channels
output_nc (int) – number of output image channels
norm_type (str) – type of normalization layer, instance or batch normalization
use_tanh (bool) – whether to apply nn.Tanh() to the output. Default: True.
classification (bool) – if True, backpropagate through the trunk with a classification loss; otherwise use regression. Default: True.
- forward(input_A, input_B, mask_B, instance_feature, box_info_list)[source]
Forward function.
- Parameters
input_A (tensor) – The L channel of the image in Lab color space
input_B (tensor) – Color patch
mask_B (tensor) – Color patch mask
instance_feature (dict) – Features of the detected instances
box_info_list (list) – Bounding-box information for each instance
- Returns
Regression output
- Return type
out_reg (tensor)
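The fusion idea behind FusionNet can be illustrated with a small sketch. This is not mmagic's exact implementation: the tensor names, the 8x8 box size, and the box coordinates are made up for the example; it only shows how a per-pixel weight can blend an instance feature map back into the full-image feature map at its bounding-box location.

```python
import torch

full_feat = torch.rand(1, 64, 32, 32)            # full-image feature map
inst_feat = torch.rand(1, 64, 8, 8)              # feature from one instance crop
weight = torch.sigmoid(torch.rand(1, 1, 8, 8))   # per-pixel fusion weight in (0, 1)

x0, y0 = 10, 12                                  # top-left corner of the instance box
region = full_feat[:, :, y0:y0 + 8, x0:x0 + 8].clone()
# Convex blend of instance and full-image features inside the box.
full_feat[:, :, y0:y0 + 8, x0:x0 + 8] = weight * inst_feat + (1 - weight) * region
```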
- class mmagic.models.editors.inst_colorization.InstColorization(data_preprocessor: Union[dict, mmengine.config.Config], image_model, instance_model, fusion_model, color_data_opt, which_direction='AtoB', loss=None, init_cfg=None, train_cfg=None, test_cfg=None)[source]
Bases:
mmengine.model.BaseModel
Colorization with the InstColorization method.
- This colorization model is implemented according to the paper:
Instance-aware Image Colorization, CVPR 2020
Adapted from https://github.com/ericsujw/InstColorization (InstColorization/models/train_model). Copyright (c) 2020, Su, under MIT License.
- Parameters
data_preprocessor (dict, optional) – The pre-process config of BaseDataPreprocessor.
image_model (dict) – Config for the single image model.
instance_model (dict) – Config for the instance model.
fusion_model (dict) – Config for the fusion model.
color_data_opt (dict) – Options for colorspace conversion.
which_direction (str) – AtoB or BtoA. Default: 'AtoB'.
loss (dict) – Config for the loss. Default: None.
init_cfg (dict, optional) – Initialization config dict. Default: None.
train_cfg (dict) – Config for training. Default: None.
test_cfg (dict) – Config for testing. Default: None.
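A config for this class might look like the following sketch. The field names follow the parameter list above, but the contents of each sub-config (channel counts, norm type, preprocessor type) are assumptions for illustration, not mmagic's shipped configuration.

```python
# Hypothetical InstColorization config sketch.
model = dict(
    type='InstColorization',
    data_preprocessor=dict(type='DataPreprocessor'),
    image_model=dict(type='ColorizationNet', input_nc=4, output_nc=2, norm_type='batch'),
    instance_model=dict(type='ColorizationNet', input_nc=4, output_nc=2, norm_type='batch'),
    fusion_model=dict(type='FusionNet', input_nc=4, output_nc=2, norm_type='batch'),
    color_data_opt=dict(),        # colorspace-conversion options go here
    which_direction='AtoB',
)
```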
- forward(inputs: torch.Tensor, data_samples: Optional[List[mmagic.structures.DataSample]] = None, mode: str = 'tensor', **kwargs)[source]
Returns losses or predictions for the training, validation, testing, and simple inference process.
forward is an abstract method of BaseModel, and its subclasses must implement it. It accepts inputs and data_samples processed by data_preprocessor and returns results according to the mode argument.
During non-distributed training, validation, and testing, forward is called directly by BaseModel.train_step, BaseModel.val_step, and BaseModel.test_step. During distributed data parallel training, MMSeparateDistributedDataParallel.train_step first calls DistributedDataParallel.forward to enable automatic gradient synchronization, and then calls forward to get the training loss.
- Parameters
inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – data samples collated by data_preprocessor.
mode (str) – one of loss, predict, and tensor. Default: 'tensor'.
  - loss: called by train_step; returns a dict of losses used for logging.
  - predict: called by val_step and test_step; returns a list of BaseDataElement results used for computing metrics.
  - tensor: called by custom code to get Tensor-type results.
- Returns
If mode == loss, a dict of loss tensors used for backward and logging. If mode == predict, a list of BaseDataElement for computing metrics and getting inference results. If mode == tensor, a tensor, tuple of tensors, or dict of tensors for custom use.
- Return type
ForwardResults
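The mode contract above can be sketched as a small dispatcher. The returned values here are placeholders, not mmagic objects; the sketch only illustrates which result shape each mode produces.

```python
def forward_dispatch(mode: str):
    """Illustrative sketch of the BaseModel-style forward mode contract."""
    if mode == 'loss':
        return {'loss_total': 0.0}         # dict of losses, consumed by train_step
    if mode == 'predict':
        return []                          # list of DataSample-like results
    if mode == 'tensor':
        return 'raw-tensor-placeholder'    # raw network output for custom use
    raise ValueError(f"mode must be 'loss', 'predict' or 'tensor', got {mode!r}")
```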
- convert_to_datasample(inputs, data_samples)[source]
Add predictions and destructed inputs (if passed) to the data samples.
- Parameters
inputs (Optional[torch.Tensor]) – The input of the model. Defaults to None.
data_samples (List[DataSample]) – The data samples loaded from the dataloader.
- Returns
Modified data samples.
- Return type
List[DataSample]
- abstract train_step(data: List[dict], optim_wrapper: mmengine.optim.OptimWrapperDict) Dict[str, torch.Tensor] [source]
Train step function.
- Parameters
data (List[dict]) – Batch of data as input.
optim_wrapper (OptimWrapperDict) – Dict with optimizer wrappers for the generator and discriminator (if present).
- Returns
Dict with losses, information for the logger, the number of samples, and results for visualization.
- Return type
dict
- forward_inference(inputs, data_samples=None, **kwargs)[source]
Forward inference. Returns predictions for validation and testing.
- Parameters
inputs (torch.Tensor) – batch input tensor collated by data_preprocessor.
data_samples (List[BaseDataElement], optional) – data samples collated by data_preprocessor.
- Returns
Predictions.
- Return type
List[DataSample]