mmagic.evaluation.metrics.precision_and_recall
Module Contents
Classes
PrecisionAndRecall | Improved Precision and recall metric.
Functions
compute_pr_distances | Compute distances between real images and fake images.
- mmagic.evaluation.metrics.precision_and_recall.compute_pr_distances(row_features, col_features, num_gpus=1, rank=0, col_batch_size=10000)
Compute distances between real images and fake images.
This function is used to calculate the Precision and Recall metric. Refer to: https://github.com/NVlabs/stylegan2-ada-pytorch/blob/main/metrics/precision_recall.py
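As a rough illustration of what a column-batched distance computation can look like, here is a minimal NumPy sketch. It is not the mmagic implementation: the helper name `pairwise_distances_batched` is hypothetical, it runs on a single process, and it ignores the `num_gpus` and `rank` arguments entirely.

```python
import numpy as np


def pairwise_distances_batched(row_features, col_features, col_batch_size=10000):
    """Euclidean distances between every row feature and every column feature.

    Hypothetical helper: the column features are processed in chunks so that
    only one (num_rows, col_batch_size, dim) block of differences exists at
    a time.
    """
    row_features = np.asarray(row_features, dtype=np.float64)
    col_features = np.asarray(col_features, dtype=np.float64)
    blocks = []
    for start in range(0, len(col_features), col_batch_size):
        col_batch = col_features[start:start + col_batch_size]
        # Broadcast to shape (num_rows, batch, dim), then reduce over dim.
        diff = row_features[:, None, :] - col_batch[None, :, :]
        blocks.append(np.sqrt((diff ** 2).sum(axis=-1)))
    # Stitch the per-batch blocks back into one (num_rows, num_cols) matrix.
    return np.concatenate(blocks, axis=1)


rows = np.array([[0.0, 0.0], [3.0, 4.0]])
cols = np.array([[0.0, 0.0], [3.0, 0.0], [0.0, 4.0]])
dists = pairwise_distances_batched(rows, cols, col_batch_size=2)
print(dists)
```

A smaller `col_batch_size` lowers peak memory at the cost of more loop iterations; the resulting distance matrix is the same.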
- class mmagic.evaluation.metrics.precision_and_recall.PrecisionAndRecall(fake_nums, real_nums=-1, k=3, fake_key: Optional[str] = None, real_key: Optional[str] = 'gt_img', need_cond_input: bool = False, sample_model: str = 'ema', collect_device: str = 'cpu', prefix: Optional[str] = None, vgg16_script='work_dirs/cache/vgg16.pt', vgg16_pkl=None, row_batch_size=10000, col_batch_size=10000, auto_save=True)
Bases: mmagic.evaluation.metrics.base_gen_metric.GenerativeMetric
Improved Precision and recall metric.
In this metric, we draw real and generated samples and embed them into a high-dimensional feature space using a pre-trained classifier network. We use these features to estimate the manifold of each set: we compute pairwise Euclidean distances between all feature vectors in the set and, for each feature vector, construct a hypersphere whose radius is the distance to its kth nearest neighbor. Together, these hyperspheres define a volume in feature space that serves as an estimate of the true manifold. Precision is the fraction of generated images that fall within the estimated manifold of real images. Symmetrically, recall is the fraction of real images that fall within the estimated manifold of generated images.
Ref: https://github.com/NVlabs/stylegan2-ada-pytorch/blob/main/metrics/precision_recall.py
Note that we highly recommend downloading the vgg16 script module from the address below and setting vgg16_script to its local path. If not given, we will use the vgg16 from the PyTorch model zoo. However, this may lead to significantly different final results.
Tero’s vgg16: https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt
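The k-nearest-neighbor manifold estimate described in the class summary can be sketched in a few lines of NumPy. This is an illustrative sketch only, not the code mmagic runs: the function name `manifold_coverage` is hypothetical, the features are toy 1-D values rather than VGG16 embeddings, and everything is computed in one unbatched pass.

```python
import numpy as np


def manifold_coverage(manifold_feats, query_feats, k=3):
    """Fraction of query features inside the k-NN manifold of manifold_feats.

    Hypothetical helper; requires more than k reference features.
    """
    manifold_feats = np.asarray(manifold_feats, dtype=np.float64)
    query_feats = np.asarray(query_feats, dtype=np.float64)
    # Pairwise distances inside the reference set, shape (n, n).
    intra = np.linalg.norm(manifold_feats[:, None] - manifold_feats[None], axis=-1)
    # After an ascending sort, index 0 is the point itself (distance 0), so
    # index k holds the distance to the kth nearest neighbor: the radius.
    radii = np.sort(intra, axis=1)[:, k]
    # Distances from every query to every reference point, shape (m, n).
    cross = np.linalg.norm(query_feats[:, None] - manifold_feats[None], axis=-1)
    # A query is covered if it lies inside at least one hypersphere.
    return float((cross <= radii[None, :]).any(axis=1).mean())


real = np.array([[0.0], [1.0], [2.0], [3.0]])
fake = np.array([[0.5], [10.0]])

# Precision: generated samples queried against the manifold of real samples.
precision = manifold_coverage(real, fake, k=1)
# Recall: real samples queried against the manifold of generated samples.
recall = manifold_coverage(fake, real, k=1)
print(precision, recall)
```

In this toy case only one of the two generated points lies near the real data, so precision is 0.5, while the single outlier at 10.0 inflates the radii of the generated manifold enough to cover every real point, giving a recall of 1.0.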
- Parameters
num_images (int) – The number of evaluated generated samples.
image_shape (tuple) – Image shape in order “CHW”. Defaults to None.
num_real_need (int | None, optional) – The number of real images. Defaults to None.
full_dataset (bool, optional) – Whether to use full dataset for evaluation. Defaults to False.
k (int, optional) – The k used when taking the distance to the kth nearest neighbor. Defaults to 3.
bgr2rgb (bool, optional) – Whether to change the order of image channel. Defaults to True.
vgg16_script (str, optional) – Path for the Tero’s vgg16 module. Defaults to ‘work_dirs/cache/vgg16.pt’.
row_batch_size (int, optional) – The batch size of row data. Defaults to 10000.
col_batch_size (int, optional) – The batch size of col data. Defaults to 10000.
auto_save (bool, optional) – Whether to save the VGG features automatically.
need_cond_input (bool) – If True, the sampler returns conditional inputs randomly sampled from the original dataset. This requires the dataset to implement get_data_info, and the field gt_label must be contained in the return value of get_data_info. Note that for unconditional models, setting need_cond_input to True may influence the evaluation results, since the conditional inputs are then sampled from the dataset distribution; otherwise they are sampled from the uniform distribution. Defaults to False.
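For orientation, a metric entry in a Python config file might look like the following. This is a sketch that assumes the registry-style convention common to OpenMMLab projects, where `type` names the class and the remaining keys are passed to the constructor shown in the signature above. The values `50000` and `'PR-50K'` are illustrative choices, not recommended defaults.

```python
# Illustrative metric entry; keyword names are taken from the class signature.
pr_metric = dict(
    type='PrecisionAndRecall',
    fake_nums=50000,           # number of generated samples (illustrative)
    k=3,                       # kth nearest neighbor used for the radius
    sample_model='ema',        # signature default
    vgg16_script='work_dirs/cache/vgg16.pt',  # local path to Tero's vgg16
    prefix='PR-50K',           # illustrative prefix for the reported names
)

metrics = [pr_metric]
```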
- _load_vgg(vgg16_script: Optional[str]) -> Tuple[torch.nn.Module, bool]
Load VGG network from the given path.
- Parameters
vgg16_script – Path to the script model of the VGG network. If None, the PyTorch version is loaded.
- Returns
The actually loaded VGG network and its corresponding style.
- Return type
Tuple[nn.Module, bool]
- extract_features(images: torch.Tensor) -> torch.Tensor
Extract image features.
- Parameters
images (torch.Tensor) – Images tensor.
- Returns
Vgg16 features of input images.
- Return type
torch.Tensor
- compute_metrics(results_fake) -> dict
Compute the precision and recall metrics.
- Returns
Summarized results.
- Return type
dict
- process(data_batch: dict, data_samples: Sequence[dict]) -> None
Process one batch of data samples and predictions. The processed results should be stored in self.fake_results, which will be used to compute the metrics when all batches have been processed.
- Parameters
data_batch (dict) – A batch of data from the dataloader.
data_samples (Sequence[dict]) – A batch of outputs from the model.
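The split between process and compute_metrics follows an accumulate-then-summarize pattern. The toy class below is a hypothetical stand-in written only to show that pattern. It is not a subclass of GenerativeMetric, and the `score` key and the `mean_score` result are invented for the example.

```python
from typing import Sequence


class ToyAccumulatingMetric:
    """Hypothetical stand-in that mimics the process/compute_metrics split."""

    def __init__(self):
        self.fake_results = []

    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
        # Store one entry per sample; nothing is summarized yet.
        for sample in data_samples:
            self.fake_results.append(sample["score"])

    def compute_metrics(self, results_fake) -> dict:
        # Runs once, after every batch has been processed.
        return {"mean_score": sum(results_fake) / len(results_fake)}


metric = ToyAccumulatingMetric()
metric.process({}, [{"score": 1.0}, {"score": 3.0}])
metric.process({}, [{"score": 5.0}])
summary = metric.compute_metrics(metric.fake_results)
print(summary)
```

Deferring the summary to the end matters for this metric in particular, because the manifold estimate needs the features of all samples before any distances can be ranked.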