mmagic.evaluation.metrics.base_sample_wise_metric

Evaluation metrics based on each sample.

Module Contents

Classes

BaseSampleWiseMetric

Base sample-wise metric for editing tasks.

class mmagic.evaluation.metrics.base_sample_wise_metric.BaseSampleWiseMetric(gt_key: str = 'gt_img', pred_key: str = 'pred_img', mask_key: Optional[str] = None, scaling=1, device='cpu', collect_device: str = 'cpu', prefix: Optional[str] = None)[source]

Bases: mmengine.evaluator.BaseMetric

Base sample-wise metric for editing tasks.

Subclasses must implement the abstract process_image method.

Parameters
  • gt_key (str) – Key of ground truth. Default: 'gt_img'

  • pred_key (str) – Key of prediction. Default: 'pred_img'

  • mask_key (str, optional) – Key of mask; if mask_key is None, the metric is computed over all regions. Default: None

  • collect_device (str) – Device name used for collecting results from different ranks during distributed training. Must be 'cpu' or 'gpu'. Defaults to 'cpu'.

  • device (str) – Device on which torch tensors are placed when computing metrics. Defaults to 'cpu'.

  • prefix (str, optional) – Prefix added to the metric names to disambiguate homonymous metrics of different evaluators. If prefix is not provided, self.default_prefix is used instead. Default: None

  • scaling (float, optional) – Scaling factor for the final metric. E.g. scaling=100 means the final metric is multiplied by 100 for output. Default: 1
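To make the sample-wise pattern concrete, here is a minimal, dependency-free sketch of how such a metric works: process collects one result per sample, and compute_metrics averages them. The class name SampleWiseMAE and the use of flat lists of pixel values are illustrative assumptions, not the actual mmagic implementation (which subclasses mmengine's BaseMetric and operates on tensors).

```python
# Hypothetical sketch of the sample-wise metric pattern. Names and the
# flat-list data representation are illustrative, not mmagic's real API.
class SampleWiseMAE:
    def __init__(self, gt_key='gt_img', pred_key='pred_img', scaling=1):
        self.gt_key = gt_key
        self.pred_key = pred_key
        self.scaling = scaling
        self.results = []  # one entry per processed sample

    def process_image(self, gt, pred, mask=None):
        # Mean absolute error over one sample (flat lists of pixel values).
        return sum(abs(g - p) for g, p in zip(gt, pred)) / len(gt)

    def process(self, data_batch, data_samples):
        # Pull ground truth and prediction out of each sample by key and
        # store a per-sample result for later aggregation.
        for sample in data_samples:
            gt = sample[self.gt_key]
            pred = sample[self.pred_key]
            self.results.append(self.process_image(gt, pred))

    def compute_metrics(self, results):
        # Average per-sample results, then apply the scaling factor.
        return {'MAE': self.scaling * sum(results) / len(results)}


metric = SampleWiseMAE(scaling=100)
metric.process(None, [{'gt_img': [0.0, 1.0], 'pred_img': [0.5, 0.5]}])
print(metric.compute_metrics(metric.results))  # {'MAE': 50.0}
```

The same division of labor holds in the real class: process iterates over data_samples, the per-sample computation lives in process_image, and compute_metrics aggregates the collected results.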

SAMPLER_MODE = 'normal'[source]
sample_model = 'orig'[source]
metric[source]
compute_metrics(results: List)[source]

Compute the metrics from processed results.

Parameters

results (List) – The processed results of each batch.

Returns

The computed metrics. The keys are the names of the metrics, and the values are corresponding results.

Return type

Dict

process(data_batch: Sequence[dict], data_samples: Sequence[dict]) → None[source]

Process one batch of data and predictions.

Parameters
  • data_batch (Sequence[dict]) – A batch of data from the dataloader.

  • data_samples (Sequence[dict]) – A batch of outputs from the model.

abstract process_image(gt, pred, mask)[source]
evaluate() → dict[source]
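The abstract process_image hook is where a subclass computes the metric for a single sample, optionally restricted by a mask. The following standalone function is a hypothetical sketch of a mask-aware implementation (flat lists of pixel values and truthy mask entries are illustrative assumptions):

```python
# Hypothetical mask-aware process_image (illustrative only): when a mask
# is given, the metric is computed over masked-in pixels only.
def process_image(gt, pred, mask=None):
    if mask is None:
        mask = [1] * len(gt)  # no mask: every region counts
    valid = [(g, p) for g, p, m in zip(gt, pred, mask) if m]
    if not valid:
        return 0.0  # empty mask: nothing to score
    return sum(abs(g - p) for g, p in valid) / len(valid)


print(process_image([0.0, 1.0, 2.0], [1.0, 1.0, 0.0], mask=[1, 1, 0]))  # 0.5
```

This mirrors the mask_key behavior described above: with mask_key set to None, the metric covers all regions.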

Evaluate the model performance of the whole dataset after processing all batches.

Parameters

size (int) – Length of the entire validation dataset. When batch size > 1, the dataloader may pad some data samples to make sure all ranks have the same length of dataset slice. The collect_results function will drop the padded data based on this size.

Returns

Evaluation metrics dict on the val dataset. The keys are the names of the metrics, and the values are corresponding results.

Return type

dict

prepare(module: torch.nn.Module, dataloader: torch.utils.data.dataloader.DataLoader)[source]
get_metric_sampler(model: torch.nn.Module, dataloader: torch.utils.data.dataloader.DataLoader, metrics) → torch.utils.data.dataloader.DataLoader[source]

Get sampler for normal metrics. Directly returns the dataloader.

Parameters
  • model (nn.Module) – Model to evaluate.

  • dataloader (DataLoader) – Dataloader for real images.

  • metrics (List['GenMetric']) – Metrics with the same sample mode.

Returns

Default sampler for normal metrics.

Return type

DataLoader
