Source code for tabeval.metrics.plots

# stdlib
from typing import Any

# third party
import numpy as np
import pandas as pd
from pydantic import validate_arguments
from sklearn.manifold import TSNE

# tabeval absolute
from tabeval.metrics.eval_statistical import JensenShannonDistance
from tabeval.plugins.core.dataloader import DataLoader

COLOR_PALETTE = ["#2b2d42", "#d90429"]
LABELS = ["real", "syn"]


@validate_arguments(config=dict(arbitrary_types_allowed=True))
def plot_marginal_comparison(
    plt: Any, X_gt: DataLoader, X_syn: DataLoader, normalize: bool = True
) -> None:
    """Draw per-column bar charts comparing real and synthetic marginals.

    One subplot is created for every column reported by
    ``JensenShannonDistance._evaluate_stats`` (10 histogram bins). Each
    subplot shows the real and synthetic values as paired bars and is titled
    with the column name and its Jensen-Shannon distance.

    Args:
        plt: A ``matplotlib.pyplot``-like object exposing ``subplots``.
        X_gt: Loader holding the real (ground-truth) data.
        X_syn: Loader holding the synthetic data.
        normalize: Selects the y-axis label only ("Probability" when true,
            "Count" otherwise); the bar heights are whatever the evaluator
            returns.

    Returns:
        None. The figure is drawn through ``plt``. Nothing is drawn when the
        evaluator reports no columns.
    """
    evaluator = JensenShannonDistance(n_histogram_bins=10)
    stats_, stats_gt, stats_syn = evaluator._evaluate_stats(X_gt, X_syn)

    column_names = list(stats_gt.keys())
    plots_cnt = len(column_names)
    if plots_cnt == 0:
        # A grid with zero rows cannot be created; there is nothing to plot.
        return

    row_len = 2
    n_rows = int(np.ceil(plots_cnt / row_len))
    # squeeze=False keeps the axes container two-dimensional even when there
    # is a single row, so the [row][col] indexing below is always valid.
    fig, ax = plt.subplots(
        n_rows, row_len, figsize=(14, plots_cnt * 3), squeeze=False
    )
    fig.subplots_adjust(hspace=1)

    if plots_cnt % row_len != 0:
        # Odd number of plots: drop the unused last cell of the grid.
        fig.delaxes(ax[-1][-1])

    bar_width = 0.4
    for idx, col in enumerate(column_names):
        row_idx, col_idx = divmod(idx, row_len)
        local_ax = ax[row_idx][col_idx]

        column_value_counts_original = stats_gt[col]
        column_value_counts_synthetic = stats_syn[col]

        bar_position = np.arange(len(column_value_counts_original.values))

        # real distribution
        local_ax.bar(
            x=bar_position,
            height=column_value_counts_original.values,
            color=COLOR_PALETTE[0],
            label=LABELS[0],
            width=bar_width,
        )
        # synthetic distribution, shifted right by one bar width
        local_ax.bar(
            x=bar_position + bar_width,
            height=column_value_counts_synthetic.values,
            color=COLOR_PALETTE[1],
            label=LABELS[1],
            width=bar_width,
        )

        # Place each tick midway between the paired bars.
        local_ax.set_xticks(bar_position + bar_width / 2)
        local_ax.set_xticklabels(column_value_counts_original.keys(), rotation=90)

        # str() lets non-string column labels (e.g. integers) be rendered;
        # underscores are escaped because the name is set in mathtext bold.
        title = (
            r"$\bf{"
            + str(col).replace("_", "\\_")
            + "}$"
            + f"\n jensen-shannon distance: {stats_[col]:.2f}"
        )
        local_ax.set_title(title)

        local_ax.set_ylabel("Probability" if normalize else "Count")
        local_ax.legend()
@validate_arguments(config=dict(arbitrary_types_allowed=True))
def plot_tsne(
    plt: Any,
    X_gt: DataLoader,
    X_syn: DataLoader,
) -> None:
    """Scatter 2-D t-SNE projections of real and synthetic data on one axis.

    Args:
        plt: A ``matplotlib.pyplot``-like object exposing ``subplots``.
        X_gt: Loader holding the real (ground-truth) data.
        X_syn: Loader holding the synthetic data.

    Returns:
        None. The scatter plot is drawn through ``plt``.
    """

    def _embed(loader) -> pd.DataFrame:
        # A fresh estimator per dataset, with identical fixed settings/seed.
        reducer = TSNE(
            n_components=2, random_state=0, learning_rate="auto", init="pca"
        )
        return pd.DataFrame(reducer.fit_transform(loader.dataframe()))

    _, axis = plt.subplots(1, 1, figsize=(12, 10))

    # NOTE(review): the two datasets are embedded by independent fits, so the
    # resulting coordinates are not guaranteed to share a common space.
    # Both projections are computed before anything is drawn.
    projections = (
        (_embed(X_gt), "Real data"),
        (_embed(X_syn), "Synthetic data"),
    )
    for points, legend_name in projections:
        axis.scatter(x=points[0], y=points[1], s=10, label=legend_name)

    axis.legend(loc="upper left")
    axis.set_ylabel("t-SNE plot")