
mmagic.evaluation.metrics.base_gen_metric

Module Contents

Classes

GenMetric

Metric for MMagic.

GenerativeMetric

Metric for generative metrics. Except for the preparation phase (prepare()), generative metrics do not need extra real images.

class mmagic.evaluation.metrics.base_gen_metric.GenMetric(fake_nums: int, real_nums: int = 0, fake_key: Optional[str] = None, real_key: Optional[str] = 'gt_img', sample_model: str = 'ema', collect_device: str = 'cpu', prefix: Optional[str] = None)[source]

Bases: mmengine.evaluator.BaseMetric

Metric for MMagic.

Parameters
  • fake_nums (int) – Number of generated images needed for the metric.

  • real_nums (int) – Number of real images needed for the metric. If -1 is passed, all images from the dataset are used. Defaults to 0.

  • fake_key (Optional[str]) – Key for getting fake images from the output dict. Defaults to None.

  • real_key (Optional[str]) – Key for getting real images from the input dict. Defaults to ‘gt_img’.

  • sample_model (str) – Sampling model for the generative model. Supports ‘orig’ and ‘ema’. Defaults to ‘ema’.

  • collect_device (str) – Device name used for collecting results from different ranks during distributed training. Must be ‘cpu’ or ‘gpu’. Defaults to ‘cpu’.

  • prefix (str, optional) – The prefix that will be added in the metric names to disambiguate homonymous metrics of different evaluators. If prefix is not provided in the argument, self.default_prefix will be used instead. Defaults to None.
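The constraints stated for these arguments can be expressed as a small validation step. The sketch below uses a hypothetical helper, resolve_real_nums, which is not part of MMagic. It checks the documented allowed values of sample_model and collect_device and resolves real_nums=-1 to the dataset size. The dataset_len argument is an assumption introduced for illustration.

```python
from typing import Optional

# Hypothetical helper (not part of MMagic) applying the documented
# constraints on GenMetric arguments.
VALID_SAMPLE_MODELS = ('orig', 'ema')
VALID_COLLECT_DEVICES = ('cpu', 'gpu')


def resolve_real_nums(real_nums: int = 0,
                      sample_model: str = 'ema',
                      collect_device: str = 'cpu',
                      dataset_len: Optional[int] = None) -> int:
    """Validate arguments and return how many real images are needed."""
    if sample_model not in VALID_SAMPLE_MODELS:
        raise ValueError(f'sample_model must be one of {VALID_SAMPLE_MODELS}, '
                         f'got {sample_model!r}')
    if collect_device not in VALID_COLLECT_DEVICES:
        raise ValueError(f'collect_device must be one of '
                         f'{VALID_COLLECT_DEVICES}, got {collect_device!r}')
    if real_nums == -1:
        # -1 means "use every image in the dataset".
        if dataset_len is None:
            raise ValueError('dataset_len is required when real_nums == -1')
        return dataset_len
    return real_nums


print(resolve_real_nums())                       # 0: no real images needed
print(resolve_real_nums(-1, dataset_len=70000))  # 70000: whole dataset
```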

property real_nums_per_device[source]

Number of real images needed for the current device.

property fake_nums_per_device[source]

Number of fake images needed for the current device.
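One plausible way to derive a per-device count is to divide the total evenly across ranks and round up, so that the ranks together produce at least the requested number of images. The sketch below assumes this rule and that real_nums=-1 has already been resolved to a concrete count. The function name nums_per_device is hypothetical.

```python
import math


def nums_per_device(total_nums: int, world_size: int) -> int:
    """Split total_nums across world_size ranks, rounding up (assumed)."""
    # Rounding up guarantees world_size * per_device >= total_nums, so the
    # ranks together produce at least the requested number of images.
    return math.ceil(total_nums / world_size)


print(nums_per_device(50000, 8))  # 6250
print(nums_per_device(50000, 3))  # 16667; 3 * 16667 = 50001 >= 50000
```

Any surplus produced by rounding up is expected to be dropped when the results are collected and truncated.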

SAMPLER_MODE = 'normal'

_collect_target_results(target: str) Optional[list][source]

Collect results in distributed environments.

Parameters

target (str) – Target results to collect.

Returns

The collected results.

Return type

Optional[list]
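The sketch below simulates, on a single process, what collecting results from several ranks involves. It is modeled loosely on the interleave-then-truncate ordering used by mmengine's CPU result collection. Whether _collect_target_results follows exactly these steps is an assumption, and the function name collect_results_sketch is hypothetical.

```python
from typing import List


def collect_results_sketch(part_list: List[list], size: int) -> list:
    """Merge per-rank result lists and drop padding beyond `size`."""
    ordered = []
    # With a round-robin distributed sampler, rank r sees samples
    # r, r + world_size, ...; interleaving restores the original order.
    for items in zip(*part_list):
        ordered.extend(items)
    # Samplers may pad so that every rank gets the same number of samples;
    # truncating to `size` removes those duplicates.
    return ordered[:size]


# Two ranks, five samples; rank 1 received a padded duplicate of sample 0.
rank0 = ['s0', 's2', 's4']
rank1 = ['s1', 's3', 's0']
print(collect_results_sketch([rank0, rank1], size=5))
# ['s0', 's1', 's2', 's3', 's4']
```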

evaluate() dict[source]

Evaluate the model performance on the whole dataset after all batches have been processed. Unlike BaseMetric, this function evaluates the metric with paired results (results_fake and results_real).

Returns

Evaluation metrics dict on the val dataset. The keys are the names of the metrics, and the values are the corresponding results.

Return type

dict
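To illustrate the paired flow, the toy class below keeps separate fake and real buffers. It truncates each buffer to the requested count, passes both to compute_metrics(results_fake, results_real), and prefixes the keys in the prefix/name style used by mmengine's BaseMetric. ToyPairedMetric and its mean-gap metric are hypothetical stand-ins, not the MMagic implementation.

```python
from typing import Dict, List, Optional


class ToyPairedMetric:
    """Stand-in for a GenMetric subclass; not the MMagic implementation."""

    default_prefix = 'toy'

    def __init__(self, fake_nums: int, real_nums: int,
                 prefix: Optional[str] = None) -> None:
        self.fake_nums = fake_nums
        self.real_nums = real_nums
        self.prefix = prefix or self.default_prefix
        # Filled batch by batch during evaluation.
        self.fake_results: List[float] = []
        self.real_results: List[float] = []

    def compute_metrics(self, results_fake: List[float],
                        results_real: List[float]) -> Dict[str, float]:
        # Toy metric: absolute gap between the mean fake and mean real value.
        mean_fake = sum(results_fake) / len(results_fake)
        mean_real = sum(results_real) / len(results_real)
        return {'mean_gap': abs(mean_fake - mean_real)}

    def evaluate(self) -> Dict[str, float]:
        # Keep exactly as many samples as the metric requested.
        fake = self.fake_results[:self.fake_nums]
        real = self.real_results[:self.real_nums]
        metrics = self.compute_metrics(fake, real)
        # Reset the buffers for the next evaluation round.
        self.fake_results.clear()
        self.real_results.clear()
        # Prefix the keys so metrics from different evaluators do not clash.
        return {f'{self.prefix}/{k}': v for k, v in metrics.items()}


metric = ToyPairedMetric(fake_nums=2, real_nums=2, prefix='FID')
metric.fake_results.extend([1.0, 3.0, 100.0])  # third sample is surplus
metric.real_results.extend([2.0, 4.0])
print(metric.evaluate())  # {'FID/mean_gap': 1.0}
```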

get_metric_sampler(model: torch.nn.Module, dataloader: torch.utils.data.dataloader.DataLoader, metrics: List[GenMetric]) torch.utils.data.dataloader.DataLoader[source]

Get the sampler for normal metrics. Directly returns the dataloader.

Parameters
  • model (nn.Module) – Model to evaluate.

  • dataloader (DataLoader) – Dataloader for real images.

  • metrics (List['GenMetric']) – Metrics with the same sample mode.

Returns

Default sampler for normal metrics.

Return type

DataLoader
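Because get_metric_sampler receives a list of metrics that share a sample mode, a caller must first group metrics that can share one sampler. The sketch below assumes the grouping key is the pair (SAMPLER_MODE, sample_model); that key is an assumption, not something this page confirms. NormalToy, GenerativeToy and group_metrics are hypothetical. normal_metric_sampler mirrors the documented behavior of returning the dataloader unchanged.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


class NormalToy:
    """Hypothetical metric with only the fields used for grouping."""

    SAMPLER_MODE = 'normal'

    def __init__(self, name: str, sample_model: str = 'ema') -> None:
        self.name = name
        self.sample_model = sample_model


class GenerativeToy(NormalToy):
    SAMPLER_MODE = 'Generative'


def group_metrics(metrics: List[NormalToy]) -> Dict[Tuple[str, str], List[NormalToy]]:
    """Group metrics that can share one sampler (assumed grouping key)."""
    groups: Dict[Tuple[str, str], List[NormalToy]] = defaultdict(list)
    for metric in metrics:
        groups[(metric.SAMPLER_MODE, metric.sample_model)].append(metric)
    return dict(groups)


def normal_metric_sampler(model, dataloader, metrics):
    # Normal metrics iterate over the real-image dataloader unchanged;
    # `model` and `metrics` are unused here.
    return dataloader


metrics = [NormalToy('paired'), GenerativeToy('gen_a'), GenerativeToy('gen_b')]
groups = group_metrics(metrics)
print({key: [m.name for m in ms] for key, ms in groups.items()})
# {('normal', 'ema'): ['paired'], ('Generative', 'ema'): ['gen_a', 'gen_b']}
```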

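compute_metrics(results_fake, results_real) dict[source]

Compute the metrics from processed results.

Parameters
  • results_fake (list) – The processed results of the generated images for each batch.

  • results_real (list) – The processed results of the real images for each batch.

Returns

The computed metrics. The keys are the names of the metrics, and the values are the corresponding results.

Return type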

dict

prepare(module: torch.nn.Module, dataloader: torch.utils.data.dataloader.DataLoader) None[source]

Prepare the items that the metric needs to pre-calculate. Does nothing by default.

Parameters
  • module (nn.Module) – Model to evaluate.

  • dataloader (DataLoader) – Dataloader for the real images.
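A subclass that needs real-image statistics can override prepare() to compute them once before sampling starts. In the toy sketch below, the dataloader is a plain iterable of float batches and the cached statistic is a simple mean. Both simplifications are assumptions for illustration. ToyPrepareMetric is hypothetical.

```python
from typing import Iterable, List, Optional


class ToyPrepareMetric:
    """Hypothetical metric that caches real-image statistics in prepare()."""

    def __init__(self, real_nums: int) -> None:
        self.real_nums = real_nums  # -1 means "use the whole dataloader"
        self.real_mean: Optional[float] = None

    def prepare(self, module, dataloader: Iterable[List[float]]) -> None:
        # `module` is unused in this sketch; the signature mirrors prepare().
        if self.real_mean is not None:
            return  # statistics already cached, skip the second pass
        values: List[float] = []
        for batch in dataloader:
            values.extend(batch)
            if self.real_nums != -1 and len(values) >= self.real_nums:
                break  # collected enough real samples
        if self.real_nums != -1:
            values = values[:self.real_nums]
        self.real_mean = sum(values) / len(values)


metric = ToyPrepareMetric(real_nums=3)
metric.prepare(None, [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
print(metric.real_mean)  # 2.0, the mean of the first three values
```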

class mmagic.evaluation.metrics.base_gen_metric.GenerativeMetric(fake_nums: int, real_nums: int = 0, fake_key: Optional[str] = None, real_key: Optional[str] = 'img', need_cond_input: bool = False, sample_model: str = 'ema', collect_device: str = 'cpu', prefix: Optional[str] = None, sample_kwargs: dict = dict())[source]

Bases: GenMetric

Metric for generative metrics. Except for the preparation phase (prepare()), generative metrics do not need extra real images.

Parameters
  • fake_nums (int) – Number of generated images needed for the metric.

  • real_nums (int) – Number of real images needed for the metric. If -1 is passed, all images from the dataset are used. Defaults to 0.

  • fake_key (Optional[str]) – Key for getting fake images from the output dict. Defaults to None.

  • real_key (Optional[str]) – Key for getting real images from the input dict. Defaults to ‘img’.

  • need_cond_input (bool) – If True, the sampler returns conditional inputs randomly sampled from the original dataset. This requires the dataset to implement get_data_info, and the field gt_label must be contained in the return value of get_data_info. Note that for unconditional models, setting need_cond_input to True may influence the evaluation results, since the conditional inputs are then sampled from the dataset distribution; otherwise, they are sampled from a uniform distribution. Defaults to False.

  • sample_model (str) – Sampling model for the generative model. Supports ‘orig’ and ‘ema’. Defaults to ‘ema’.

  • collect_device (str) – Device name used for collecting results from different ranks during distributed training. Must be ‘cpu’ or ‘gpu’. Defaults to ‘cpu’.

  • prefix (str, optional) – The prefix that will be added in the metric names to disambiguate homonymous metrics of different evaluators. If prefix is not provided in the argument, self.default_prefix will be used instead. Defaults to None.

  • sample_kwargs (dict) – Sampling arguments for model testing. Defaults to dict().
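The need_cond_input description contrasts two ways of obtaining conditional labels: drawing gt_label values from the dataset, or drawing them uniformly. The sketch below shows only that statistical difference. The function sample_cond_labels and its arguments are hypothetical, and where MMagic actually performs the uniform sampling is not specified here.

```python
import random
from typing import List, Optional


def sample_cond_labels(num: int, num_classes: int,
                       dataset_labels: Optional[List[int]] = None,
                       seed: int = 0) -> List[int]:
    """Draw `num` conditional labels (illustrative only)."""
    rng = random.Random(seed)
    if dataset_labels is not None:
        # need_cond_input=True: pick gt_label values from the dataset, so
        # label frequencies follow the dataset distribution.
        return [rng.choice(dataset_labels) for _ in range(num)]
    # need_cond_input=False: labels are uniform over all classes.
    return [rng.randrange(num_classes) for _ in range(num)]


# An imbalanced dataset: class 3 appears twice as often as class 7.
from_dataset = sample_cond_labels(100, 10, dataset_labels=[3, 3, 7])
uniform = sample_cond_labels(100, 10)
print(sorted(set(from_dataset)), len(uniform))
```

If the dataset is class-imbalanced, the two options feed the generator different label distributions, which is why the choice can change evaluation results.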

SAMPLER_MODE = 'Generative'

get_metric_sampler(model: torch.nn.Module, dataloader: torch.utils.data.dataloader.DataLoader, metrics: GenMetric)[source]

Get the sampler for generative metrics. Returns a dummy iterator whose value at each iteration is a dict containing the batch size and sample mode used to generate images.

Parameters
  • model (nn.Module) – Model to evaluate.

  • dataloader (DataLoader) – Dataloader for real images. Used to get the batch size when generating fake images.

  • metrics (List['GenMetric']) – Metrics with the same sampler mode.

Returns

Sampler for generative metrics.

Return type

dummy_iterator
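A dummy iterator of this kind yields sampling instructions rather than data, and it stops once enough batches have been requested to cover the number of fake images. The sketch below assumes the count is covered by ceil(max_length / batch_size) batches. The dict keys inputs, num_batches, sample_model and sample_kwargs are assumptions chosen for illustration, as is the class name DummyIteratorSketch.

```python
import math
from typing import Any, Dict, Iterator, Optional


class DummyIteratorSketch:
    """Yields sampling instructions instead of real data (illustrative)."""

    def __init__(self, batch_size: int, max_length: int,
                 sample_model: str = 'ema',
                 sample_kwargs: Optional[dict] = None) -> None:
        self.batch_size = batch_size
        self.max_length = max_length  # total number of fake images wanted
        self.sample_model = sample_model
        self.sample_kwargs = sample_kwargs or {}

    def __len__(self) -> int:
        # Enough batches to cover max_length; the last may overshoot.
        return math.ceil(self.max_length / self.batch_size)

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        for _ in range(len(self)):
            # Each item tells the model how many images to generate and
            # which weights ('orig' or 'ema') to use.
            yield dict(inputs=dict(sample_model=self.sample_model,
                                   num_batches=self.batch_size,
                                   sample_kwargs=self.sample_kwargs))


sampler = DummyIteratorSketch(batch_size=4, max_length=10)
batches = list(sampler)
print(len(batches))  # 3 batches -> 12 images, 2 more than needed
```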

evaluate()[source]

Evaluate the generative metric. This function only collects fake_results, because generative metrics do not need real images.

Returns

Evaluation metrics dict on the val dataset. The keys are the names of the metrics, and the values are the corresponding results.

Return type

dict
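Compared with the paired case, a generative metric keeps a single buffer of fake results and passes it to compute_metrics(results). The toy class below also drops the surplus samples that the last batch of a dummy iterator can produce. ToyGenerativeMetric and its mean-value metric are hypothetical stand-ins, not the MMagic code.

```python
from typing import Dict, List, Optional


class ToyGenerativeMetric:
    """Stand-in for a GenerativeMetric subclass; not the MMagic code."""

    default_prefix = 'toy'

    def __init__(self, fake_nums: int, prefix: Optional[str] = None) -> None:
        self.fake_nums = fake_nums
        self.prefix = prefix or self.default_prefix
        self.fake_results: List[float] = []  # no real-results buffer needed

    def compute_metrics(self, results: List[float]) -> Dict[str, float]:
        # Toy unpaired metric: the mean of the generated values.
        return {'fake_mean': sum(results) / len(results)}

    def evaluate(self) -> Dict[str, float]:
        # Only fake results are used; surplus samples from the last batch
        # are dropped so exactly fake_nums samples are scored.
        results = self.fake_results[:self.fake_nums]
        self.fake_results.clear()
        metrics = self.compute_metrics(results)
        return {f'{self.prefix}/{k}': v for k, v in metrics.items()}


metric = ToyGenerativeMetric(fake_nums=3, prefix='IS')
metric.fake_results.extend([1.0, 2.0, 3.0, 99.0])  # last batch overshoots
print(metric.evaluate())  # {'IS/fake_mean': 2.0}
```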

compute_metrics(results) dict[source]

Compute the metrics from processed results.

Parameters

results (list) – The processed results of each batch.

Returns

The computed metrics. The keys are the names of the metrics, and the values are the corresponding results.

Return type

dict
