mmagic.evaluation.metrics.equivariance
Module Contents
Classes
Equivariance | Metric for generative metrics. Except for the preparation phase (prepare()), generative metrics do not need extra real images. |
- class mmagic.evaluation.metrics.equivariance.Equivariance(fake_nums: int, real_nums: int = 0, fake_key: Optional[str] = None, real_key: Optional[str] = 'gt_img', need_cond_input: bool = False, sample_mode: str = 'ema', sample_kwargs: dict = dict(), collect_device: str = 'cpu', prefix: Optional[str] = None, eq_cfg=dict())
Bases: mmagic.evaluation.metrics.base_gen_metric.GenerativeMetric
Metric for generative metrics. Except for the preparation phase (prepare()), generative metrics do not need extra real images.
- Parameters
fake_nums (int) – Number of generated images needed for the metric.
real_nums (int) – Number of real images needed for the metric. If -1 is passed, all images in the dataset are used. Defaults to 0.
fake_key (Optional[str]) – Key for getting fake images from the output dict. Defaults to None.
real_key (Optional[str]) – Key for getting real images from the input dict. Defaults to ‘gt_img’.
need_cond_input (bool) – If True, the sampler returns conditional inputs randomly sampled from the original dataset. This requires the dataset to implement get_data_info, and the return value of get_data_info must contain the field gt_label. Note that for unconditional models, setting need_cond_input to True may influence the evaluation results, since the conditional inputs are then sampled from the dataset distribution; otherwise they are sampled from a uniform distribution. Defaults to False.
sample_mode (str) – Sampling mode for the generative model. Supports ‘orig’ and ‘ema’. Defaults to ‘ema’.
collect_device (str) – Device name used for collecting results from different ranks during distributed training. Must be ‘cpu’ or ‘gpu’. Defaults to ‘cpu’.
prefix (str, optional) – The prefix that will be added in the metric names to disambiguate homonymous metrics of different evaluators. If prefix is not provided in the argument, self.default_prefix will be used instead. Defaults to None.
sample_kwargs (dict) – Sampling arguments used when testing the model. Defaults to dict().
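As a hedged illustration, the constructor arguments above could be collected into a metric config dict. This is a sketch only: the `type` key convention, the variable name, and the values `fake_nums=50000` and `prefix='EQ'` are assumptions chosen for illustration, and `eq_cfg` is left at its default because its keys are not documented in this section.

```python
# Hypothetical metric config; the keys mirror the constructor signature above.
equivariance_metric = dict(
    type='Equivariance',    # assumed registry name, taken from the class name
    fake_nums=50000,        # illustrative value: number of generated images
    real_nums=0,            # default: no real images are needed
    fake_key=None,
    real_key='gt_img',
    need_cond_input=False,
    sample_mode='ema',      # 'orig' or 'ema'
    sample_kwargs=dict(),
    collect_device='cpu',   # 'cpu' or 'gpu'
    prefix='EQ',            # illustrative prefix for the reported metric names
    eq_cfg=dict(),          # keys not documented here; left at the default
)
```

Every key other than `type` corresponds to a parameter in the signature, so omitting a key falls back to the documented default.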
- process(data_batch: dict, data_samples: Sequence[dict]) → None
Process one batch of data samples and predictions. The processed results should be stored in self.fake_results, which will be used to compute the metrics when all batches have been processed.
- Parameters
data_batch (dict) – A batch of data from the dataloader.
data_samples (Sequence[dict]) – A batch of outputs from the model.
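The accumulate-then-compute pattern described for process() can be sketched in plain Python. The class below is a hypothetical stand-in, not the MMagic implementation: `ToyGenerativeMetric`, its `compute` method, the `'fake_img'` key, and the use of scalars in place of images are all assumptions made to keep the example self-contained.

```python
from typing import List, Sequence


class ToyGenerativeMetric:
    """Hypothetical stand-in that mimics the batch-accumulation pattern."""

    def __init__(self, fake_nums: int, fake_key: str = 'fake_img') -> None:
        self.fake_nums = fake_nums  # how many generated samples to collect
        self.fake_key = fake_key    # assumed key holding the generated sample
        self.fake_results: List[float] = []

    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
        # Store per-sample results until fake_nums samples have been collected.
        for sample in data_samples:
            if len(self.fake_results) >= self.fake_nums:
                break
            self.fake_results.append(float(sample[self.fake_key]))

    def compute(self) -> dict:
        # Runs once, after every batch has been processed.
        return {'mean': sum(self.fake_results) / len(self.fake_results)}


metric = ToyGenerativeMetric(fake_nums=3)
metric.process({}, [{'fake_img': 1.0}, {'fake_img': 2.0}])
metric.process({}, [{'fake_img': 3.0}, {'fake_img': 4.0}])
result = metric.compute()
```

Here the fourth sample is dropped because only `fake_nums` results are kept, and the final value is computed only after all batches have been seen.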
- get_metric_sampler(model: torch.nn.Module, dataloader: torch.utils.data.dataloader.DataLoader, metrics: List[mmagic.evaluation.metrics.base_gen_metric.GenerativeMetric])
Get the sampler for generative metrics. Returns a dummy iterator; each iteration yields a dict containing the batch size and sample mode used to generate images.
- Parameters
model (nn.Module) – Model to evaluate.
dataloader (DataLoader) – Dataloader for real images. Used to get the batch size when generating fake images.
metrics (List['GenerativeMetric']) – Metrics with the same sampler mode.
- Returns
Sampler for generative metrics.
- Return type
dummy_iterator
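A minimal sketch of what such a dummy iterator might look like, assuming it only needs to hand out generation settings until enough images have been requested. The class name `DummyIterator`, the `max_length` argument, and the dict keys `batch_size` and `sample_mode` are hypothetical; the actual keys returned by MMagic are not stated in this section.

```python
import math
from typing import Iterator


class DummyIterator:
    """Hypothetical sampler that yields generation settings, not real data."""

    def __init__(self, batch_size: int, max_length: int, sample_mode: str) -> None:
        self.batch_size = batch_size
        self.max_length = max_length  # total number of images to generate
        self.sample_mode = sample_mode
        self.idx = 0

    def __len__(self) -> int:
        # Number of iterations needed to cover max_length images.
        return math.ceil(self.max_length / self.batch_size)

    def __iter__(self) -> Iterator[dict]:
        self.idx = 0
        return self

    def __next__(self) -> dict:
        if self.idx >= self.max_length:
            raise StopIteration
        self.idx += self.batch_size
        return dict(batch_size=self.batch_size, sample_mode=self.sample_mode)


batches = list(DummyIterator(batch_size=4, max_length=10, sample_mode='ema'))
```

With a batch size of 4 and 10 images requested, the iterator yields three settings dicts, and the last batch overshoots the target; a metric would be expected to discard the surplus samples.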