Source code for tabeval.metrics.eval_statistical

# stdlib
import platform
from abc import abstractmethod
from typing import Any, Dict, Optional, Tuple

# third party
import numpy as np
import pandas as pd
import torch
from geomloss import SamplesLoss
from pydantic import validate_arguments
from scipy import linalg
from scipy.spatial.distance import jensenshannon
from scipy.special import kl_div
from scipy.stats import chisquare, ks_2samp
from sklearn import metrics
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import MinMaxScaler

# tabeval absolute
import tabeval.logger as log
from tabeval.metrics._utils import get_frequency
from tabeval.metrics.core import MetricEvaluator
from tabeval.plugins.core.dataloader import DataLoader
from tabeval.plugins.core.models.survival_analysis.metrics import nonparametric_distance
from tabeval.utils.reproducibility import clear_cache
from tabeval.utils.serialization import load_from_file, save_to_file


class StatisticalEvaluator(MetricEvaluator):
    """
    .. inheritance-diagram:: tabeval.metrics.eval_statistical.StatisticalEvaluator
        :parts: 1
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)

    @staticmethod
    def type() -> str:
        return "stats"

    @abstractmethod
    def _evaluate(self, X_gt: DataLoader, X_syn: DataLoader) -> Dict: ...

    @validate_arguments(config=dict(arbitrary_types_allowed=True))
    def evaluate(self, X_gt: DataLoader, X_syn: DataLoader) -> Dict:
        cache_file = (
            self._workspace
            / f"sc_metric_cache_{self.type()}_{self.name()}_{X_gt.hash()}_{X_syn.hash()}_{self._reduction}_{platform.python_version()}.bkp"
        )
        if self.use_cache(cache_file):
            return load_from_file(cache_file)

        clear_cache()

        results = self._evaluate(X_gt, X_syn)
        save_to_file(cache_file, results)

        return results

    @validate_arguments(config=dict(arbitrary_types_allowed=True))
    def evaluate_default(
        self,
        X_gt: DataLoader,
        X_syn: DataLoader,
    ) -> float:
        return self.evaluate(X_gt, X_syn)[self._default_metric]

class InverseKLDivergence(StatisticalEvaluator):
    """
    .. inheritance-diagram:: tabeval.metrics.eval_statistical.InverseKLDivergence
        :parts: 1

    Returns the average inverse of the Kullback–Leibler Divergence metric.

    Score:
        0: the datasets are from different distributions.
        1: the datasets are from the same distribution.
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(default_metric="marginal", **kwargs)

    @staticmethod
    def name() -> str:
        return "inv_kl_divergence"

    @staticmethod
    def direction() -> str:
        return "maximize"

    @validate_arguments(config=dict(arbitrary_types_allowed=True))
    def _evaluate(self, X_gt: DataLoader, X_syn: DataLoader) -> Dict:
        freqs = get_frequency(X_gt.dataframe(), X_syn.dataframe(), n_histogram_bins=self._n_histogram_bins)
        res = []
        for col in X_gt.columns:
            gt_freq, synth_freq = freqs[col]
            res.append(1 / (1 + np.sum(kl_div(gt_freq, synth_freq))))

        return {"marginal": float(self.reduction()(res))}
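
# A minimal standalone sketch of the per-column score above, assuming two toy
# histogram frequency vectors (illustrative values, not taken from tabeval):
# kl_div gives elementwise contributions whose sum is KL(gt || synth), and
# 1 / (1 + KL) maps it into (0, 1], with 1 meaning identical histograms.
def _example_inverse_kl_score() -> float:
    gt_freq = np.array([0.2, 0.3, 0.5])
    synth_freq = np.array([0.25, 0.25, 0.5])
    return float(1 / (1 + np.sum(kl_div(gt_freq, synth_freq))))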

class KolmogorovSmirnovTest(StatisticalEvaluator):
    """
    .. inheritance-diagram:: tabeval.metrics.eval_statistical.KolmogorovSmirnovTest
        :parts: 1

    Performs the Kolmogorov-Smirnov test for goodness of fit.

    Score:
        0: the distributions are totally different.
        1: the distributions are identical.
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(default_metric="marginal", **kwargs)

    @staticmethod
    def name() -> str:
        return "ks_test"

    @staticmethod
    def direction() -> str:
        return "maximize"

    @validate_arguments(config=dict(arbitrary_types_allowed=True))
    def _evaluate(self, X_gt: DataLoader, X_syn: DataLoader) -> Dict:
        res = []
        for col in X_gt.columns:
            statistic, _ = ks_2samp(X_gt[col], X_syn[col])
            res.append(1 - statistic)

        return {"marginal": float(self.reduction()(res))}
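
# A minimal standalone sketch of the per-column score above, assuming two toy
# 1-D samples (illustrative values only): ks_2samp returns the KS statistic in
# [0, 1], and the score is its complement, so identical samples score 1.
def _example_ks_score() -> float:
    real = np.array([0.1, 0.4, 0.5, 0.9])
    synth = np.array([0.2, 0.4, 0.6, 0.8])
    statistic, _ = ks_2samp(real, synth)
    return float(1 - statistic)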

class ChiSquaredTest(StatisticalEvaluator):
    """
    .. inheritance-diagram:: tabeval.metrics.eval_statistical.ChiSquaredTest
        :parts: 1

    Performs the one-way chi-square test.

    Returns:
        The p-value. A small value indicates that we can reject the null hypothesis and that the distributions are different.

    Score:
        0: the distributions are different.
        1: the distributions are identical.
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(default_metric="marginal", **kwargs)

    @staticmethod
    def name() -> str:
        return "chi_squared_test"

    @staticmethod
    def direction() -> str:
        return "maximize"

    @validate_arguments(config=dict(arbitrary_types_allowed=True))
    def _evaluate(self, X_gt: DataLoader, X_syn: DataLoader) -> Dict:
        res = []
        freqs = get_frequency(X_gt.dataframe(), X_syn.dataframe(), n_histogram_bins=self._n_histogram_bins)

        for col in X_gt.columns:
            gt_freq, synth_freq = freqs[col]
            try:
                _, pvalue = chisquare(gt_freq, synth_freq)
                if np.isnan(pvalue):
                    pvalue = 0
            except BaseException:
                log.error("chisquare failed")
                pvalue = 0

            res.append(pvalue)

        return {"marginal": float(self.reduction()(res))}
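
# A minimal standalone sketch of the per-column p-value above, assuming two toy
# frequency vectors with equal totals (scipy's chisquare expects observed and
# expected frequencies with matching sums); the values are illustrative only.
def _example_chi_squared_pvalue() -> float:
    gt_freq = np.array([0.2, 0.3, 0.5])
    synth_freq = np.array([0.25, 0.25, 0.5])
    _, pvalue = chisquare(gt_freq, synth_freq)
    # Mirror the guard above: treat a NaN p-value as "distributions differ".
    return float(0.0 if np.isnan(pvalue) else pvalue)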

class MaximumMeanDiscrepancy(StatisticalEvaluator):
    """
    .. inheritance-diagram:: tabeval.metrics.eval_statistical.MaximumMeanDiscrepancy
        :parts: 1

    Empirical maximum mean discrepancy. The lower the result, the more evidence that the distributions are the same.

    Args:
        kernel: "rbf", "linear" or "polynomial"

    Score:
        0: The distributions are the same.
        1: The distributions are totally different.
    """

    @validate_arguments(config=dict(arbitrary_types_allowed=True))
    def __init__(self, kernel: str = "rbf", **kwargs: Any) -> None:
        super().__init__(default_metric="joint", **kwargs)
        self.kernel = kernel

    @staticmethod
    def name() -> str:
        return "max_mean_discrepancy"

    @staticmethod
    def direction() -> str:
        return "minimize"

    @validate_arguments(config=dict(arbitrary_types_allowed=True))
    def _evaluate(
        self,
        X_gt: DataLoader,
        X_syn: DataLoader,
    ) -> Dict:
        if self.kernel == "linear":
            """
            MMD using linear kernel (i.e., k(x,y) = <x,y>)
            """
            delta_df = X_gt.dataframe().mean(axis=0) - X_syn.dataframe().mean(axis=0)
            delta = delta_df.values

            score = delta.dot(delta.T)
        elif self.kernel == "rbf":
            """
            MMD using rbf (gaussian) kernel (i.e., k(x,y) = exp(-gamma * ||x-y||^2))
            """
            gamma = 1.0
            XX = metrics.pairwise.rbf_kernel(
                X_gt.numpy().reshape(len(X_gt), -1),
                X_gt.numpy().reshape(len(X_gt), -1),
                gamma,
            )
            YY = metrics.pairwise.rbf_kernel(
                X_syn.numpy().reshape(len(X_syn), -1),
                X_syn.numpy().reshape(len(X_syn), -1),
                gamma,
            )
            XY = metrics.pairwise.rbf_kernel(
                X_gt.numpy().reshape(len(X_gt), -1),
                X_syn.numpy().reshape(len(X_syn), -1),
                gamma,
            )
            score = XX.mean() + YY.mean() - 2 * XY.mean()
        elif self.kernel == "polynomial":
            """
            MMD using polynomial kernel (i.e., k(x,y) = (gamma <X, Y> + coef0)^degree)
            """
            degree = 2
            gamma = 1
            coef0 = 0
            XX = metrics.pairwise.polynomial_kernel(
                X_gt.numpy().reshape(len(X_gt), -1),
                X_gt.numpy().reshape(len(X_gt), -1),
                degree,
                gamma,
                coef0,
            )
            YY = metrics.pairwise.polynomial_kernel(
                X_syn.numpy().reshape(len(X_syn), -1),
                X_syn.numpy().reshape(len(X_syn), -1),
                degree,
                gamma,
                coef0,
            )
            XY = metrics.pairwise.polynomial_kernel(
                X_gt.numpy().reshape(len(X_gt), -1),
                X_syn.numpy().reshape(len(X_syn), -1),
                degree,
                gamma,
                coef0,
            )
            score = XX.mean() + YY.mean() - 2 * XY.mean()
        else:
            raise ValueError(f"Unsupported kernel {self.kernel}")

        return {"joint": float(score)}
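
# A minimal standalone sketch of the linear-kernel branch above, assuming two
# small toy samples (illustrative values only): with k(x, y) = <x, y>, the
# squared MMD reduces to the squared norm of the difference of the sample means.
def _example_linear_mmd() -> float:
    X_real = np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
    X_fake = np.array([[0.1, 0.9], [0.9, 0.1], [1.0, 1.0]])
    delta = X_real.mean(axis=0) - X_fake.mean(axis=0)
    return float(delta.dot(delta))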

class JensenShannonDistance(StatisticalEvaluator):
    """Evaluate the average Jensen-Shannon distance (metric) between two probability arrays."""

    @validate_arguments(config=dict(arbitrary_types_allowed=True))
    def __init__(self, normalize: bool = True, **kwargs: Any) -> None:
        super().__init__(default_metric="marginal", **kwargs)
        self.normalize = normalize

    @staticmethod
    def name() -> str:
        return "jensenshannon_dist"

    @staticmethod
    def direction() -> str:
        return "minimize"

    @validate_arguments(config=dict(arbitrary_types_allowed=True))
    def _evaluate_stats(
        self,
        X_gt: DataLoader,
        X_syn: DataLoader,
    ) -> Tuple[Dict, Dict, Dict]:
        stats_gt = {}
        stats_syn = {}
        stats_ = {}

        for col in X_gt.columns:
            local_bins = min(self._n_histogram_bins, len(X_gt[col].unique()))
            X_gt_bin, gt_bins = pd.cut(X_gt[col], bins=local_bins, retbins=True)
            X_syn_bin = pd.cut(X_syn[col], bins=gt_bins)
            stats_gt[col], stats_syn[col] = X_gt_bin.value_counts(dropna=False, normalize=self.normalize).align(
                X_syn_bin.value_counts(dropna=False, normalize=self.normalize),
                join="outer",
                axis=0,
                fill_value=0,
            )
            stats_gt[col] += 1
            stats_syn[col] += 1

            stats_[col] = jensenshannon(stats_gt[col], stats_syn[col])
            if np.isnan(stats_[col]):
                raise RuntimeError("NaNs in prediction")

        return stats_, stats_gt, stats_syn

    @validate_arguments(config=dict(arbitrary_types_allowed=True))
    def _evaluate(
        self,
        X_gt: DataLoader,
        X_syn: DataLoader,
    ) -> Dict:
        stats_, _, _ = self._evaluate_stats(X_gt, X_syn)

        return {"marginal": sum(stats_.values()) / len(stats_.keys()) if len(stats_.keys()) > 0 else np.nan}
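
# A minimal standalone sketch of the per-column distance above, assuming two toy
# binned count vectors (illustrative values only): jensenshannon normalizes its
# inputs and returns the Jensen-Shannon distance, the square root of the JS
# divergence, which is 0 for identical distributions.
def _example_js_distance() -> float:
    gt_counts = np.array([3.0, 5.0, 2.0])
    synth_counts = np.array([4.0, 4.0, 2.0])
    return float(jensenshannon(gt_counts, synth_counts))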

class WassersteinDistance(StatisticalEvaluator):
    """
    .. inheritance-diagram:: tabeval.metrics.eval_statistical.WassersteinDistance
        :parts: 1

    Compare Wasserstein distance between original data and synthetic data.

    Args:
        X: original data
        X_syn: synthetically generated data

    Returns:
        WD_value: Wasserstein distance
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(default_metric="joint", **kwargs)

    @staticmethod
    def name() -> str:
        return "wasserstein_dist"

    @staticmethod
    def direction() -> str:
        return "minimize"

    @validate_arguments(config=dict(arbitrary_types_allowed=True))
    def _evaluate(
        self,
        X: DataLoader,
        X_syn: DataLoader,
    ) -> Dict:
        X_ = np.ascontiguousarray(X.numpy().reshape(len(X), -1))
        X_syn_ = np.ascontiguousarray(X_syn.numpy().reshape(len(X_syn), -1))

        if len(X_) > len(X_syn_):
            X_syn_ = np.concatenate([X_syn_, np.zeros((len(X_) - len(X_syn_), X_.shape[1]))])

        scaler = MinMaxScaler().fit(X_)

        X_ = scaler.transform(X_)
        X_syn_ = scaler.transform(X_syn_)

        X_ten = torch.from_numpy(X_).double()
        Xsyn_ten = torch.from_numpy(X_syn_).double()
        OT_solver = SamplesLoss(loss="sinkhorn")

        return {"joint": OT_solver(X_ten, Xsyn_ten).cpu().numpy().item()}
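
# A minimal standalone sketch of the Sinkhorn step above, assuming two small toy
# point clouds already scaled to [0, 1] (illustrative values only): geomloss's
# SamplesLoss with loss="sinkhorn" returns an entropy-regularised approximation
# of the Wasserstein distance between the two empirical distributions.
def _example_sinkhorn_distance() -> float:
    X_real = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double)
    X_fake = torch.tensor([[0.1, 0.0], [0.9, 1.0]], dtype=torch.double)
    OT_solver = SamplesLoss(loss="sinkhorn")
    return float(OT_solver(X_real, X_fake).item())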

class PRDCScore(StatisticalEvaluator):
    """
    .. inheritance-diagram:: tabeval.metrics.eval_statistical.PRDCScore
        :parts: 1

    Computes precision, recall, density, and coverage given two manifolds.

    Args:
        nearest_k: int.
    """

    def __init__(self, nearest_k: int = 5, **kwargs: Any) -> None:
        super().__init__(default_metric="precision", **kwargs)
        self.nearest_k = nearest_k

    @staticmethod
    def name() -> str:
        return "prdc"

    @staticmethod
    def direction() -> str:
        return "maximize"

    @validate_arguments(config=dict(arbitrary_types_allowed=True))
    def _evaluate(
        self,
        X: DataLoader,
        X_syn: DataLoader,
    ) -> Dict:
        X_ = np.ascontiguousarray(X.numpy().reshape(len(X), -1))
        X_syn_ = np.ascontiguousarray(X_syn.numpy().reshape(len(X_syn), -1))

        # Default representation
        results = self._compute_prdc(X_, X_syn_)

        return results

    def _compute_pairwise_distance(self, data_x: np.ndarray, data_y: Optional[np.ndarray] = None) -> np.ndarray:
        """
        Args:
            data_x: numpy.ndarray([N, feature_dim], dtype=np.float32)
            data_y: numpy.ndarray([N, feature_dim], dtype=np.float32)

        Returns:
            numpy.ndarray([N, N], dtype=np.float32) of pairwise distances.
        """
        if data_y is None:
            data_y = data_x

        dists = metrics.pairwise_distances(data_x, data_y)
        return dists

    def _get_kth_value(self, unsorted: np.ndarray, k: int, axis: int = -1) -> np.ndarray:
        """
        Args:
            unsorted: numpy.ndarray of any dimensionality.
            k: int

        Returns:
            kth values along the designated axis.
        """
        indices = np.argpartition(unsorted, k, axis=axis)[..., :k]
        k_smallests = np.take_along_axis(unsorted, indices, axis=axis)
        kth_values = k_smallests.max(axis=axis)
        return kth_values

    def _compute_nearest_neighbour_distances(self, input_features: np.ndarray, nearest_k: int) -> np.ndarray:
        """
        Args:
            input_features: numpy.ndarray
            nearest_k: int

        Returns:
            Distances to kth nearest neighbours.
        """
        distances = self._compute_pairwise_distance(input_features)
        radii = self._get_kth_value(distances, k=nearest_k + 1, axis=-1)
        return radii

    def _compute_prdc(self, real_features: np.ndarray, fake_features: np.ndarray) -> Dict:
        """
        Computes precision, recall, density, and coverage given two manifolds.

        Args:
            real_features: numpy.ndarray([N, feature_dim], dtype=np.float32)
            fake_features: numpy.ndarray([N, feature_dim], dtype=np.float32)

        Returns:
            dict of precision, recall, density, and coverage.
        """
        real_nearest_neighbour_distances = self._compute_nearest_neighbour_distances(real_features, self.nearest_k)
        fake_nearest_neighbour_distances = self._compute_nearest_neighbour_distances(fake_features, self.nearest_k)
        distance_real_fake = self._compute_pairwise_distance(real_features, fake_features)

        precision = (distance_real_fake < np.expand_dims(real_nearest_neighbour_distances, axis=1)).any(axis=0).mean()

        recall = (distance_real_fake < np.expand_dims(fake_nearest_neighbour_distances, axis=0)).any(axis=1).mean()

        density = (1.0 / float(self.nearest_k)) * (
            distance_real_fake < np.expand_dims(real_nearest_neighbour_distances, axis=1)
        ).sum(axis=0).mean()

        coverage = (distance_real_fake.min(axis=1) < real_nearest_neighbour_distances).mean()

        return dict(precision=precision, recall=recall, density=density, coverage=coverage)
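
# A minimal standalone sketch of the precision term computed above, assuming two
# toy samples and nearest_k=1 (illustrative values only): a synthetic point counts
# as precise if it falls inside the k-NN ball of at least one real point.
def _example_prdc_precision(nearest_k: int = 1) -> float:
    real = np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
    fake = np.array([[0.1, 0.1], [5.0, 5.0]])
    real_dists = metrics.pairwise_distances(real, real)
    # k-NN radius of each real point; sorted index nearest_k skips the zero self-distance.
    radii = np.sort(real_dists, axis=1)[:, nearest_k]
    real_fake = metrics.pairwise_distances(real, fake)
    return float((real_fake < radii[:, None]).any(axis=0).mean())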

class AlphaPrecision(StatisticalEvaluator):
    """
    .. inheritance-diagram:: tabeval.metrics.eval_statistical.AlphaPrecision
        :parts: 1

    Evaluates the alpha-precision, beta-recall, and authenticity scores.

    The class evaluates the synthetic data using a tuple of three metrics: alpha-precision, beta-recall, and
    authenticity. Note that these metrics can be evaluated for each synthetic data point, which is useful for
    auditing and post-processing. Here we average the scores to reflect the overall quality of the data.
    The formal definitions can be found in the reference below:

    Alaa, Ahmed, Boris Van Breugel, Evgeny S. Saveliev, and Mihaela van der Schaar. "How faithful is your synthetic
    data? Sample-level metrics for evaluating and auditing generative models." In International Conference on
    Machine Learning, pp. 290-306. PMLR, 2022.
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(default_metric="authenticity_OC", **kwargs)

    @staticmethod
    def name() -> str:
        return "alpha_precision"

    @staticmethod
    def direction() -> str:
        return "maximize"

    @validate_arguments(config=dict(arbitrary_types_allowed=True))
    def metrics(
        self,
        X: np.ndarray,
        X_syn: np.ndarray,
        emb_center: Optional[np.ndarray] = None,
    ) -> Tuple:
        if emb_center is None:
            emb_center = np.mean(X, axis=0)

        n_steps = 30
        alphas = np.linspace(0, 1, n_steps)

        Radii = np.quantile(np.sqrt(np.sum((X - emb_center) ** 2, axis=1)), alphas)

        synth_center = np.mean(X_syn, axis=0)

        alpha_precision_curve = []
        beta_coverage_curve = []

        synth_to_center = np.sqrt(np.sum((X_syn - emb_center) ** 2, axis=1))

        nbrs_real = NearestNeighbors(n_neighbors=2, n_jobs=-1, p=2).fit(X)
        real_to_real, real_to_real_args = nbrs_real.kneighbors(X)

        nbrs_synth = NearestNeighbors(n_neighbors=1, n_jobs=-1, p=2).fit(X_syn)
        real_to_synth, real_to_synth_args = nbrs_synth.kneighbors(X)

        # Find the closest real point to each real point, excluding itself (therefore index 1 instead of 0).
        real_to_real = real_to_real[:, 1].squeeze()
        real_to_real_args = real_to_real_args[:, 1].squeeze()
        real_to_synth = real_to_synth.squeeze()
        real_to_synth_args = real_to_synth_args.squeeze()

        real_synth_closest = X_syn[real_to_synth_args]

        real_synth_closest_d = np.sqrt(np.sum((real_synth_closest - synth_center) ** 2, axis=1))
        closest_synth_Radii = np.quantile(real_synth_closest_d, alphas)

        for k in range(len(Radii)):
            precision_audit_mask = synth_to_center <= Radii[k]
            alpha_precision = np.mean(precision_audit_mask)

            beta_coverage = np.mean(
                ((real_to_synth <= real_to_real) * (real_synth_closest_d <= closest_synth_Radii[k]))
            )

            alpha_precision_curve.append(alpha_precision)
            beta_coverage_curve.append(beta_coverage)

        # See which one is bigger.
        # The original implementation uses the following, where the index of `real_to_real` seems wrong.
        # According to the paper (https://arxiv.org/abs/2102.08921), the first distance should be from each real
        # sample to its nearest real sample.
        # authen = real_to_real[real_to_synth_args] < real_to_synth
        authen = real_to_real[real_to_real_args] < real_to_synth
        authenticity = np.mean(authen)

        Delta_precision_alpha = 1 - np.sum(np.abs(np.array(alphas) - np.array(alpha_precision_curve))) / np.sum(alphas)

        if Delta_precision_alpha < 0:
            raise RuntimeError("negative value detected for Delta_precision_alpha")

        Delta_coverage_beta = 1 - np.sum(np.abs(np.array(alphas) - np.array(beta_coverage_curve))) / np.sum(alphas)

        if Delta_coverage_beta < 0:
            raise RuntimeError("negative value detected for Delta_coverage_beta")

        return (
            alphas,
            alpha_precision_curve,
            beta_coverage_curve,
            Delta_precision_alpha,
            Delta_coverage_beta,
            authenticity,
        )

    def _normalize_covariates(
        self,
        X: DataLoader,
        X_syn: DataLoader,
    ) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """_normalize_covariates
        This is an internal method to replicate the old, naive method for evaluating AlphaPrecision.

        Args:
            X (DataLoader): The ground truth dataset.
            X_syn (DataLoader): The synthetic dataset.

        Returns:
            Tuple[pd.DataFrame, pd.DataFrame]: normalised version of the datasets
        """
        X_gt_norm = X.dataframe().copy()
        X_syn_norm = X_syn.dataframe().copy()
        if self._task_type != "survival_analysis":
            if hasattr(X, "target_column") and hasattr(X_gt_norm, X.target_column):
                X_gt_norm = X_gt_norm.drop(columns=[X.target_column])
            if hasattr(X_syn, "target_column") and hasattr(X_syn_norm, X_syn.target_column):
                X_syn_norm = X_syn_norm.drop(columns=[X_syn.target_column])

        scaler = MinMaxScaler().fit(X_gt_norm)
        if hasattr(X, "target_column"):
            X_gt_norm_df = pd.DataFrame(
                scaler.transform(X_gt_norm),
                columns=[col for col in X.train().dataframe().columns if col != X.target_column],
            )
        else:
            X_gt_norm_df = pd.DataFrame(scaler.transform(X_gt_norm), columns=X.train().dataframe().columns)

        if hasattr(X_syn, "target_column"):
            X_syn_norm_df = pd.DataFrame(
                scaler.transform(X_syn_norm),
                columns=[col for col in X_syn.dataframe().columns if col != X_syn.target_column],
            )
        else:
            X_syn_norm_df = pd.DataFrame(scaler.transform(X_syn_norm), columns=X_syn.dataframe().columns)

        return (X_gt_norm_df, X_syn_norm_df)

    @validate_arguments(config=dict(arbitrary_types_allowed=True))
    def _evaluate(
        self,
        X: DataLoader,
        X_syn: DataLoader,
    ) -> Dict:
        results = {}

        X_ = np.ascontiguousarray(X.numpy().reshape(len(X), -1))
        X_syn_ = np.ascontiguousarray(X_syn.numpy().reshape(len(X_syn), -1))

        # OneClass representation
        emb = "_OC"
        oneclass_model = self._get_oneclass_model(X_)
        X_ = self._oneclass_predict(oneclass_model, X_)
        X_syn_ = self._oneclass_predict(oneclass_model, X_syn_)
        emb_center = oneclass_model.c.detach().cpu().numpy()

        (
            alphas,
            alpha_precision_curve,
            beta_coverage_curve,
            Delta_precision_alpha,
            Delta_coverage_beta,
            authenticity,
        ) = self.metrics(X_, X_syn_, emb_center=emb_center)

        results[f"delta_precision_alpha{emb}"] = Delta_precision_alpha
        results[f"delta_coverage_beta{emb}"] = Delta_coverage_beta
        results[f"authenticity{emb}"] = authenticity

        X_df, X_syn_df = self._normalize_covariates(X, X_syn)
        (
            alphas_naive,
            alpha_precision_curve_naive,
            beta_coverage_curve_naive,
            Delta_precision_alpha_naive,
            Delta_coverage_beta_naive,
            authenticity_naive,
        ) = self.metrics(X_df.to_numpy(), X_syn_df.to_numpy(), emb_center=None)

        results["delta_precision_alpha_naive"] = Delta_precision_alpha_naive
        results["delta_coverage_beta_naive"] = Delta_coverage_beta_naive
        results["authenticity_naive"] = authenticity_naive

        return results
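
# A minimal standalone sketch of the curve summary used in metrics() above,
# assuming a toy alpha-precision curve (illustrative values only): Delta is 1
# minus the normalised absolute deviation of the curve from the ideal diagonal,
# so a curve that tracks the diagonal perfectly scores 1.
def _example_delta_precision_alpha() -> float:
    alphas = np.linspace(0, 1, 30)
    alpha_precision_curve = alphas**1.1  # slightly below the ideal diagonal
    return float(1 - np.sum(np.abs(alphas - alpha_precision_curve)) / np.sum(alphas))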

class SurvivalKMDistance(StatisticalEvaluator):
    """
    .. inheritance-diagram:: tabeval.metrics.eval_statistical.SurvivalKMDistance
        :parts: 1

    The distance between two Kaplan-Meier plots. Used for survival analysis.
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(default_metric="optimism", **kwargs)

    @staticmethod
    def name() -> str:
        return "survival_km_distance"

    @staticmethod
    def direction() -> str:
        return "minimize"

    @validate_arguments(config=dict(arbitrary_types_allowed=True))
    def _evaluate(
        self,
        X: DataLoader,
        X_syn: DataLoader,
    ) -> Dict:
        if self._task_type != "survival_analysis":
            raise RuntimeError(f"The metric is valid only for survival analysis tasks, but got {self._task_type}")
        if X.type() != "survival_analysis" or X_syn.type() != "survival_analysis":
            raise RuntimeError(
                f"The metric is valid only for survival analysis tasks, but got datasets {X.type()} and {X_syn.type()}"
            )

        _, real_T, real_E = X.unpack()
        _, syn_T, syn_E = X_syn.unpack()

        optimism, abs_optimism, sightedness = nonparametric_distance((real_T, real_E), (syn_T, syn_E))

        return {
            "optimism": optimism,
            "abs_optimism": abs_optimism,
            "sightedness": sightedness,
        }

class FrechetInceptionDistance(StatisticalEvaluator):
    """
    .. inheritance-diagram:: tabeval.metrics.eval_statistical.FrechetInceptionDistance
        :parts: 1

    Calculates the Frechet Inception Distance (FID) to evaluate GANs.

    Paper: GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium.

    The FID metric calculates the distance between two distributions of images.
    Typically, we have summary statistics (mean & covariance matrix) of one of these distributions, while the second
    distribution is given by a GAN.

    Adapted by Boris van Breugel (bv292@cam.ac.uk)
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)

    @staticmethod
    def name() -> str:
        return "fid"

    @staticmethod
    def direction() -> str:
        return "minimize"

    def _fit_gaussian(self, act: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Calculation of the statistics used by the FID.

        Params:
            -- act   : activations

        Returns:
            -- mu    : The mean over samples of the activations
            -- sigma : The covariance matrix of the activations
        """
        mu = np.mean(act, axis=0)
        sigma = np.cov(act.T)
        return mu, sigma

    def _calculate_frechet_distance(
        self,
        mu1: np.ndarray,
        sigma1: np.ndarray,
        mu2: np.ndarray,
        sigma2: np.ndarray,
        eps: float = 1e-6,
    ) -> float:
        """Numpy implementation of the Frechet Distance.

        The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2) is
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).

        Stable version by Dougal J. Sutherland.

        Params:
            -- mu1   : Numpy array containing the activations of the pool_3 layer of the inception net
                       (as returned by the function 'get_predictions') for generated samples.
            -- mu2   : The sample mean over activations of the pool_3 layer, precalculated on a representative data set.
            -- sigma1: The covariance matrix over activations of the pool_3 layer for generated samples.
            -- sigma2: The covariance matrix over activations of the pool_3 layer, precalculated on a representative data set.

        Returns:
            --       : The Frechet Distance.
        """
        mu1 = np.atleast_1d(mu1)
        mu2 = np.atleast_1d(mu2)

        sigma1 = np.atleast_2d(sigma1)
        sigma2 = np.atleast_2d(sigma2)

        if mu1.shape != mu2.shape:
            raise RuntimeError("Training and test mean vectors have different lengths")

        if sigma1.shape != sigma2.shape:
            raise RuntimeError("Training and test covariances have different dimensions")

        diff = mu1 - mu2

        # product might be almost singular
        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
        if not np.isfinite(covmean).all():
            offset = np.eye(sigma1.shape[0]) * eps
            covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

        # numerical error might give slight imaginary component
        if np.iscomplexobj(covmean):
            if not np.allclose(np.diagonal(covmean).imag, 0, atol=2e-3):
                m = np.max(np.abs(covmean.imag))
                raise ValueError("Imaginary component {}".format(m))
            covmean = covmean.real

        tr_covmean = np.trace(covmean)

        return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean

    @validate_arguments(config=dict(arbitrary_types_allowed=True))
    def _evaluate(
        self,
        X: DataLoader,
        X_syn: DataLoader,
    ) -> Dict:
        if X.type() != "images":
            raise RuntimeError(
                f"The metric is valid only for image tasks, but got datasets {X.type()} and {X_syn.type()}"
            )

        X1 = X.numpy().reshape(len(X), -1)
        X2 = X_syn.numpy().reshape(len(X_syn), -1)

        mu1, cov1 = self._fit_gaussian(X1)
        mu2, cov2 = self._fit_gaussian(X2)
        score = self._calculate_frechet_distance(mu1, cov1, mu2, cov2)

        return {
            "score": score,
        }
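
# A minimal standalone sketch of the closed-form distance above, assuming two toy
# Gaussian fits (illustrative values only): for N(mu1, C1) and N(mu2, C2) the
# squared Frechet distance is ||mu1 - mu2||^2 + Tr(C1 + C2 - 2 * sqrtm(C1 @ C2)).
def _example_frechet_distance() -> float:
    mu1, sigma1 = np.zeros(2), np.eye(2)
    mu2, sigma2 = np.array([1.0, 0.0]), 2 * np.eye(2)
    covmean = linalg.sqrtm(sigma1.dot(sigma2))
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    diff = mu1 - mu2
    return float(diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean))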