mmagic.evaluation.metrics.base_gen_metric
Module Contents
Classes
GenMetric | Metric for MMagic.
GenerativeMetric | Metric for generative metrics. Except for the preparation phase (prepare()), generative metrics do not need extra real images.
- class mmagic.evaluation.metrics.base_gen_metric.GenMetric(fake_nums: int, real_nums: int = 0, fake_key: Optional[str] = None, real_key: Optional[str] = 'gt_img', sample_model: str = 'ema', collect_device: str = 'cpu', prefix: Optional[str] = None)[source]
Bases: mmengine.evaluator.BaseMetric
Metric for MMagic.
- Parameters
fake_nums (int) – Number of generated images needed for the metric.
real_nums (int) – Number of real images needed for the metric. If -1 is passed, all images from the dataset are used. Defaults to 0.
fake_key (Optional[str]) – Key used to get fake images from the output dict. Defaults to None.
real_key (Optional[str]) – Key used to get real images from the input dict. Defaults to ‘gt_img’.
sample_model (str) – Sampling model for the generative model. Supports ‘orig’ and ‘ema’. Defaults to ‘ema’.
collect_device (str) – Device name used for collecting results from different ranks during distributed training. Must be ‘cpu’ or ‘gpu’. Defaults to ‘cpu’.
prefix (str, optional) – The prefix that will be added in the metric names to disambiguate homonymous metrics of different evaluators. If prefix is not provided in the argument, self.default_prefix will be used instead. Defaults to None.
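The interaction of fake_nums, real_nums, fake_key and real_key can be sketched as follows. ToyGenMetric is a hypothetical stand-in, not MMagic code. It assumes that real images are read from the input batch via real_key, that fake images are read from the model outputs via fake_key, and that each list stops growing once its quota is reached. Treating fake_key=None as "the output is the image itself" is also an assumption made only for this sketch.

```python
from typing import Any, List, Optional


class ToyGenMetric:
    """Hypothetical stand-in for GenMetric's sample bookkeeping; not MMagic code."""

    def __init__(self, fake_nums: int, real_nums: int = 0,
                 fake_key: Optional[str] = None,
                 real_key: Optional[str] = 'gt_img') -> None:
        self.fake_nums = fake_nums
        self.real_nums = real_nums
        self.fake_key = fake_key
        self.real_key = real_key
        self.fake_results: List[Any] = []
        self.real_results: List[Any] = []

    def _real_quota(self, dataset_len: int) -> int:
        # real_nums == -1 means "every image in the dataset".
        return dataset_len if self.real_nums == -1 else self.real_nums

    def process(self, data_batch: List[dict], outputs: List[Any],
                dataset_len: int) -> None:
        # Real images come from the input side, looked up by real_key.
        for sample in data_batch:
            if len(self.real_results) < self._real_quota(dataset_len):
                self.real_results.append(sample[self.real_key])
        # Fake images come from the output side, looked up by fake_key.
        for out in outputs:
            if len(self.fake_results) < self.fake_nums:
                # Assumption: fake_key=None means the output is the image itself.
                self.fake_results.append(
                    out if self.fake_key is None else out[self.fake_key])


metric = ToyGenMetric(fake_nums=3, real_nums=-1, fake_key='fake_img')
batch = [{'gt_img': f'real_{i}'} for i in range(2)]
outputs = [{'fake_img': f'fake_{i}'} for i in range(4)]
metric.process(batch, outputs, dataset_len=2)
print(len(metric.fake_results), len(metric.real_results))  # 3 2
```

With the default real_nums=0, the real quota is zero and no real images are stored. This matches metrics that only need generated samples.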
- _collect_target_results(target: str) → Optional[list] [source]
Collect results in distributed environments.
- Parameters
target (str) – Target results to collect.
- Returns
The collected results.
- Return type
Optional[list]
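The following stdlib-only sketch shows what collecting one target ('fake' or 'real') across ranks typically involves. It simulates the interleave-and-truncate pattern that mmengine's collect_results applies: samples dealt round-robin by a distributed sampler are put back in order, padding beyond the target size is dropped, and only the main rank receives the list. The function name and the handling of size == -1 are assumptions. Whether GenMetric follows exactly these steps is not confirmed here.

```python
import itertools
from typing import List, Optional


def collect_target_results(per_rank: List[list], size: int,
                           rank: int = 0) -> Optional[list]:
    """Simulate gathering one target ('fake' or 'real') from all ranks.

    per_rank[r] holds the results produced on rank r. As with a
    round-robin DistributedSampler, all ranks are assumed to hold the
    same number of items, the last ones possibly being padding.
    """
    if rank != 0:
        # Only the main rank receives the gathered list.
        return None
    if size == -1:
        # Assumption: -1 keeps everything that was produced.
        size = sum(len(part) for part in per_rank)
    # Undo the round-robin split: rank0[0], rank1[0], rank0[1], ...
    ordered = list(itertools.chain.from_iterable(zip(*per_rank)))
    # Drop samples the sampler duplicated to even out the ranks.
    return ordered[:size]


# 5 samples on 2 ranks; rank 1 repeats sample 0 as padding.
print(collect_target_results([[0, 2, 4], [1, 3, 0]], size=5))  # [0, 1, 2, 3, 4]
```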
- evaluate() → dict [source]
Evaluate the model performance on the whole dataset after processing all batches. Unlike BaseMetric, this function evaluates the metric with paired results (results_fake and results_real).
- Returns
Evaluation metrics dict on the val dataset. The keys are the names of the metrics, and the values are the corresponding results.
- Return type
dict
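Here is a minimal sketch of evaluation with paired results. It assumes that the collected fake and real lists are passed together to a metric function, and that each key is prefixed as "<prefix>/<name>" in the way mmengine's BaseMetric names its outputs. evaluate_paired and mean_gap are hypothetical helpers, and returning an empty dict when no results were gathered is an assumption.

```python
from typing import Callable, Dict, List, Optional

MetricFn = Callable[[list, list], Dict[str, float]]


def evaluate_paired(results_fake: Optional[list], results_real: Optional[list],
                    compute_metrics: MetricFn,
                    prefix: Optional[str] = None) -> Dict[str, float]:
    # Non-main ranks get None from collection; nothing to compute there (assumption).
    if results_fake is None:
        return {}
    metrics = compute_metrics(results_fake, results_real)
    if prefix:
        # Namespace each key as "<prefix>/<name>" to tell evaluators apart.
        metrics = {f'{prefix}/{name}': value for name, value in metrics.items()}
    return metrics


def mean_gap(fake: List[float], real: List[float]) -> Dict[str, float]:
    # Toy paired metric: absolute gap between the fake and real means.
    return {'mean_gap': abs(sum(fake) / len(fake) - sum(real) / len(real))}


print(evaluate_paired([1.0, 3.0], [1.0, 1.0], mean_gap, prefix='toy'))
# {'toy/mean_gap': 1.0}
```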
- get_metric_sampler(model: torch.nn.Module, dataloader: torch.utils.data.dataloader.DataLoader, metrics: List[GenMetric]) → torch.utils.data.dataloader.DataLoader [source]
Get sampler for normal metrics. Directly returns the dataloader.
- Parameters
model (nn.Module) – Model to evaluate.
dataloader (DataLoader) – Dataloader for real images.
metrics (List['GenMetric']) – Metrics with the same sample mode.
- Returns
Default sampler for normal metrics.
- Return type
DataLoader
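For normal metrics, the sampler is just the dataloader itself. PassThroughMetric below is a hypothetical class that mirrors the documented signature with plain Python types in place of torch objects. The model and metrics arguments are accepted only so that the interface matches.

```python
from typing import Any, Iterable, List


class PassThroughMetric:
    """Hypothetical 'normal' metric whose sampler is the real-image loader."""

    def get_metric_sampler(self, model: Any, dataloader: Iterable,
                           metrics: List['PassThroughMetric']) -> Iterable:
        # No images are generated ahead of time, so the loader is returned
        # unchanged; model and metrics are accepted only to match the interface.
        return dataloader


loader = [[1, 2], [3, 4]]  # stand-in for a DataLoader yielding two batches
metric = PassThroughMetric()
sampler = metric.get_metric_sampler(model=None, dataloader=loader, metrics=[metric])
print(sampler is loader)  # True
```

Because normal and generative metrics share this method, an evaluation loop can presumably ask any metric group for its sampler without special-casing. That reading is an assumption about how the method is used.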
- class mmagic.evaluation.metrics.base_gen_metric.GenerativeMetric(fake_nums: int, real_nums: int = 0, fake_key: Optional[str] = None, real_key: Optional[str] = 'img', need_cond_input: bool = False, sample_model: str = 'ema', collect_device: str = 'cpu', prefix: Optional[str] = None, sample_kwargs: dict = dict())[source]
Bases: GenMetric
Metric for generative metrics. Except for the preparation phase (prepare()), generative metrics do not need extra real images.
- Parameters
fake_nums (int) – Number of generated images needed for the metric.
real_nums (int) – Number of real images needed for the metric. If -1 is passed, all images from the dataset are used. Defaults to 0.
fake_key (Optional[str]) – Key used to get fake images from the output dict. Defaults to None.
real_key (Optional[str]) – Key used to get real images from the input dict. Defaults to ‘img’.
need_cond_input (bool) – If True, the sampler returns conditional inputs randomly sampled from the original dataset. This requires the dataset to implement get_data_info, and the field gt_label must be contained in the return value of get_data_info. Note that for unconditional models, setting need_cond_input to True may influence the evaluation results, since the conditional inputs are then sampled from the dataset distribution; otherwise, they are sampled from the uniform distribution. Defaults to False.
sample_model (str) – Sampling mode for the generative model. Supports ‘orig’ and ‘ema’. Defaults to ‘ema’.
collect_device (str) – Device name used for collecting results from different ranks during distributed training. Must be ‘cpu’ or ‘gpu’. Defaults to ‘cpu’.
prefix (str, optional) – The prefix that will be added in the metric names to disambiguate homonymous metrics of different evaluators. If prefix is not provided in the argument, self.default_prefix will be used instead. Defaults to None.
sample_kwargs (dict) – Sampling arguments used when testing the model. Defaults to dict().
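The effect of need_cond_input on the label distribution can be sketched as follows. ToyDataset and sample_cond_labels are hypothetical. The sketch only assumes what the parameter description states: when need_cond_input is True, labels come from gt_label in get_data_info and therefore follow the dataset's class frequencies; when it is False, labels are drawn uniformly over the classes.

```python
import random
from typing import List


class ToyDataset:
    """Hypothetical dataset exposing get_data_info with a gt_label field."""

    def __init__(self, labels: List[int]) -> None:
        self.labels = labels

    def __len__(self) -> int:
        return len(self.labels)

    def get_data_info(self, idx: int) -> dict:
        return {'gt_label': self.labels[idx]}


def sample_cond_labels(dataset: ToyDataset, num_classes: int, batch_size: int,
                       need_cond_input: bool, rng: random.Random) -> List[int]:
    if need_cond_input:
        # Pick random dataset entries: labels follow the dataset's class frequencies.
        return [dataset.get_data_info(rng.randrange(len(dataset)))['gt_label']
                for _ in range(batch_size)]
    # Otherwise every class is equally likely.
    return [rng.randrange(num_classes) for _ in range(batch_size)]


rng = random.Random(0)
skewed = ToyDataset([0] * 9 + [1])  # 90% of the samples are class 0
from_data = sample_cond_labels(skewed, 2, 1000, True, rng)
uniform = sample_cond_labels(skewed, 2, 1000, False, rng)
print(from_data.count(0) / 1000, uniform.count(0) / 1000)  # close to 0.9 and 0.5
```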
- get_metric_sampler(model: torch.nn.Module, dataloader: torch.utils.data.dataloader.DataLoader, metrics: GenMetric) [source]
Get sampler for generative metrics. Returns a dummy iterator; each iteration yields a dict containing the batch size and sample mode used to generate images.
- Parameters
model (nn.Module) – Model to evaluate.
dataloader (DataLoader) – Dataloader for real images. Used to get the batch size when generating fake images.
metrics (List['GenMetric']) – Metrics with the same sampler mode.
- Returns
Sampler for generative metrics.
- Return type
dummy_iterator
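The dummy iterator described above can be sketched as follows. DummyIterator is a hypothetical class, not MMagic's dummy_iterator, and the dict keys batch_size, sample_model and sample_kwargs are assumptions. The sketch also assumes that max_length is the largest fake_nums among the metrics sharing the sampler. Each item carries no real data, only the instructions the model needs to generate one batch.

```python
import math
from typing import Dict, Iterator


class DummyIterator:
    """Hypothetical sketch of a sampler for generative metrics; not MMagic's class."""

    def __init__(self, batch_size: int, max_length: int, sample_model: str,
                 sample_kwargs: Dict) -> None:
        self.batch_size = batch_size
        # Assumption: the largest fake_nums among the grouped metrics.
        self.max_length = max_length
        self.sample_model = sample_model
        self.sample_kwargs = sample_kwargs

    def __len__(self) -> int:
        # Enough batches to cover max_length; the last one may overshoot.
        return math.ceil(self.max_length / self.batch_size)

    def __iter__(self) -> Iterator[Dict]:
        for _ in range(len(self)):
            # No real data: each item only says how many images to generate
            # and which weights ('orig' or 'ema') to sample from.
            yield dict(batch_size=self.batch_size,
                       sample_model=self.sample_model,
                       sample_kwargs=dict(self.sample_kwargs))


it = DummyIterator(batch_size=4, max_length=10, sample_model='ema', sample_kwargs={})
batches = list(it)
print(len(batches), batches[0])
# 3 {'batch_size': 4, 'sample_model': 'ema', 'sample_kwargs': {}}
```

In this sketch the final batch produces 12 images for a target of 10. The metric is assumed to stop storing fake results once fake_nums is reached.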