analysis module

Compute MMD and Sliced Wasserstein Distance for all methods and datasets. Produces label-agnostic distributional similarity metrics.

Usage: conda run -n ritwik_base python run_distribution_metrics.py

analysis.distribution_metrics.load_real(dataset)
analysis.distribution_metrics.load_synthetic(dataset, method)
analysis.distribution_metrics.main()
analysis.distribution_metrics.mmd_rbf(real, fake)
analysis.distribution_metrics.sliced_wasserstein(real, fake, seed=42)

Data validation: feature selection + clustering + metrics across seeds.

Feature selection technique (implemented in Python, matching R logic):

CV2 — largest normalized CV² dispersion → top 100 features

Clustering:
  • Select features from generated data (CV2) → filter real data to those features

  • Cluster filtered REAL data with Leiden

  • GAT-GAN (GARAGE): grid search over dataset-specific resolution ranges, maximise ARI

  • All other methods: fixed resolution = 1.0

  • npcs = 20, n_neighbors = 30 for all

Metrics: ARI, NMI (cluster labels vs ground truth labels on filtered real data)

Output: results/rev8_raw_results.csv

Usage: conda run -n scrna python validate_and_evaluate.py

analysis.clustering_evaluation.cluster_and_evaluate(real_filt, true_labels, resolution, n_pcs, n_neighbors)
analysis.clustering_evaluation.cv2(data, n_genes=100)

Select top n_genes by normalized CV² dispersion. Exact reimplementation of the R CV2() function from Data_vaidation_ARI_new.ipynb.

analysis.clustering_evaluation.evaluate_baseline(real_filt, true_labels, n_pcs=20, n_neighbors=30)
analysis.clustering_evaluation.evaluate_sweep(real_filt, true_labels, res_range, n_pcs=20, n_neighbors=30)
analysis.clustering_evaluation.load_gen(ds_name, method_dir, method_prefix, seed, wanted_iter)
analysis.clustering_evaluation.load_labels(ds_cfg)
analysis.clustering_evaluation.load_real(ds_cfg)
analysis.clustering_evaluation.main()

Create publication-ready summary tables from raw clustering results.

Extended table: per-seed, per-method, per-dataset, per-feature-selection

Main table: per-method, per-dataset, per-feature-selection → mean±std (3 seeds)

Outputs:

  • results/extended_table.csv

  • results/main_table.csv

Usage: python analysis/build_summary_tables.py

analysis.build_summary_tables.fmt_mean_std(series)

Format as ‘mean±std’ to 4 decimal places, or ‘–’ if all NaN.
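The formatting rule above can be illustrated with a minimal sketch. The name fmt_mean_std_sketch, the decimals parameter, and the use of the sample standard deviation (ddof=1, the pandas default) are assumptions made for illustration; the actual function may differ.

```python
import numpy as np


def fmt_mean_std_sketch(values, decimals=4):
    """Format values as 'mean±std', or '–' when every entry is NaN."""
    arr = np.asarray(values, dtype=float)
    arr = arr[~np.isnan(arr)]  # ignore NaN entries
    if arr.size == 0:
        return "–"
    mean = arr.mean()
    # Assumption: sample std (ddof=1); undefined for a single value
    std = arr.std(ddof=1) if arr.size > 1 else float("nan")
    return f"{mean:.{decimals}f}±{std:.{decimals}f}"


print(fmt_mean_std_sketch([0.5, 0.7, 0.6]))
print(fmt_mean_std_sketch([float("nan")] * 3))
```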

analysis.build_summary_tables.main()

ARI/NMI/F1 benchmarking for scRNA-seq-specific baseline methods.

Fixed params: resolution=1.0, cosine distance, n_neighbors=30, n_pcs=20.

Feature selection: Fano factor (lowest variance-to-mean ratio) → top 100 genes.

F1 is macro-averaged (f1_score with average='macro').

Usage: conda run -n scrna python run_rev4_validation.py

analysis.sc_specific_benchmark.evaluate_ari(gen_data, real_data, labels, n_pcs=20, n_neighbors=30, resolution=1.0)

Feature selection on generated data → filter real data → Leiden on real data → ARI.

analysis.sc_specific_benchmark.fano_selection(data, n_genes=100)

Select n_genes with LOWEST Fano factor (variance/mean).
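A minimal sketch of lowest-Fano-factor selection, assuming data is a cells × genes array and that genes with zero mean are never selected (their ratio is set to infinity). The name fano_selection_sketch and the zero-mean handling are assumptions, not taken from the module.

```python
import numpy as np


def fano_selection_sketch(data, n_genes=100):
    """Return column indices of the n_genes genes with the lowest variance/mean."""
    data = np.asarray(data, dtype=float)
    mean = data.mean(axis=0)
    var = data.var(axis=0)
    # Assumption: genes with zero mean get an infinite ratio, so they rank last
    fano = np.divide(var, mean, out=np.full_like(var, np.inf), where=mean > 0)
    return np.argsort(fano, kind="stable")[:n_genes]


toy = np.array([[1, 2, 0, 10],
                [1, 4, 0, 10],
                [1, 6, 0, 11]])
selected = fano_selection_sketch(toy, n_genes=2)
```

The returned indices could then be used to filter the real matrix to the same columns, as the evaluate_ari description states.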

analysis.sc_specific_benchmark.load_real(dataset)
analysis.sc_specific_benchmark.main()

Fixed marker-gene clustering evaluation (grid sweep).

Methods: GAN, VAE, LSH-GAN, GARAGE. Metrics: ARI, NMI.

All methods sweep Leiden resolution 0.1–3.0 (step 0.01) across all datasets. Pseudo-labels: NearestCentroid; fallback 3-NN majority vote.

Outputs:
  • results/marker_genes.csv

  • results/clustering_performance.csv

Usage: conda run -n scrna python run_rev5_marker_clustering_grid.py
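The resolution sweep described above (0.1–3.0, step 0.01, keeping the best ARI) might look roughly like the sketch below. To stay self-contained it takes a hypothetical cluster_fn callable that maps a resolution to cluster labels; in the real script this would presumably wrap Leiden clustering, which is not reproduced here. Tie-breaking (first best resolution wins) is also an assumption.

```python
import numpy as np
from sklearn.metrics import adjusted_rand_score


def sweep_resolution_sketch(cluster_fn, true_labels,
                            res_min=0.1, res_max=3.0, step=0.01):
    """Try each resolution on the grid and keep the one with the highest ARI."""
    n_steps = int(round((res_max - res_min) / step)) + 1
    # Rounded to 2 decimals, which matches a step of 0.01
    resolutions = np.round(np.linspace(res_min, res_max, n_steps), 2)
    best_res, best_ari = None, -np.inf
    for res in resolutions:
        ari = adjusted_rand_score(true_labels, cluster_fn(float(res)))
        if ari > best_ari:  # strict '>' keeps the first best resolution
            best_res, best_ari = float(res), ari
    return best_res, best_ari


# Toy stand-in for a clustering routine: threshold 1-D points at `res`
points = np.array([0.2, 0.4, 0.6, 1.4, 1.6, 1.8])
truth = np.array([0, 0, 0, 1, 1, 1])
best_res, best_ari = sweep_resolution_sketch(
    lambda res: (points > res).astype(int), truth)
```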

analysis.marker_clustering_grid.compute_ari_nmi(gen_filt, real_filt, real_labels_enc, resolution, n_pcs=20, n_neighbors=30)
analysis.marker_clustering_grid.evaluate_clustering(gen_data, real_data, real_labels, marker_idx, method_name, n_pcs=20, n_neighbors=30)
analysis.marker_clustering_grid.evaluate_real_reference(real_filt, real_labels_enc, n_pcs=20, n_neighbors=30)
analysis.marker_clustering_grid.evaluate_sweep(gen_filt, real_filt, real_labels_enc, n_pcs=20, n_neighbors=30)
analysis.marker_clustering_grid.get_pseudo_labels(gen_filt, real_filt, real_labels_enc)

Primary: fit NearestCentroid on real data → predict on generated data. Fallback (when NearestCentroid predicts fewer than 2 classes): 3-NN majority vote on real data (cosine distance).
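One possible reading of this pseudo-labelling scheme, sketched with scikit-learn. It assumes the fallback fits a 3-nearest-neighbour classifier with cosine distance on the real data and predicts labels for the generated data; the function name pseudo_labels_sketch is hypothetical and the actual implementation may differ.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier, NearestCentroid


def pseudo_labels_sketch(gen, real, real_labels):
    """Assign labels to generated samples from labelled real samples."""
    # Primary: nearest class centroid (Euclidean) learned on real data
    pred = NearestCentroid().fit(real, real_labels).predict(gen)
    if np.unique(pred).size < 2:
        # Fallback (assumed interpretation): 3-NN majority vote, cosine distance
        knn = KNeighborsClassifier(n_neighbors=3, metric="cosine")
        pred = knn.fit(real, real_labels).predict(gen)
    return pred


real = np.array([[1.0, 0.1], [0.9, 0.0], [1.1, 0.1],
                 [0.1, 1.0], [0.0, 0.9], [0.1, 1.1]])
labels = np.array([0, 0, 0, 1, 1, 1])
gen = np.array([[1.05, 0.05], [0.05, 0.95]])
pseudo = pseudo_labels_sketch(gen, real, labels)
```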

analysis.marker_clustering_grid.load_garage_data(dataset)
analysis.marker_clustering_grid.load_gen_data(dataset, dir_name, prefix, suffix, iter_idx)
analysis.marker_clustering_grid.load_real_data(dataset)
analysis.marker_clustering_grid.main()
analysis.marker_clustering_grid.select_markers_cbmc(real, labels)
analysis.marker_clustering_grid.select_markers_muraro(real, labels)
analysis.marker_clustering_grid.select_markers_pollen(real, labels)
analysis.marker_clustering_grid.select_markers_yan(real, labels)

MMD Analysis: Maximum Mean Discrepancy with an RBF kernel (median-heuristic bandwidth). Compares real vs synthetic data on the full preprocessed gene-expression matrix, before feature selection. Lower MMD = better distributional agreement.

Usage:

python mmd_analysis.py

analysis.mmd_analysis.load_gan_data(dataset, iter_idx)
analysis.mmd_analysis.load_garage_data(dataset)
analysis.mmd_analysis.load_lsh_gan_data(dataset, iter_idx)
analysis.mmd_analysis.load_real_data(dataset)
analysis.mmd_analysis.main()
analysis.mmd_analysis.mmd_rbf(real, fake, sigma=None)
analysis.mmd_analysis.rbf_kernel(X, Y, sigma)
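For orientation, a minimal NumPy sketch of an RBF-kernel MMD with a median-heuristic bandwidth. The function names mirror the signatures listed above, but the bodies are assumptions: this version returns the biased squared-MMD estimate and takes sigma as the median pairwise distance over the pooled samples. The module itself may use a different estimator or bandwidth convention.

```python
import numpy as np


def _sq_dists(X, Y):
    """Pairwise squared Euclidean distances between rows of X and Y."""
    x2 = (X ** 2).sum(axis=1)[:, None]
    y2 = (Y ** 2).sum(axis=1)[None, :]
    return np.maximum(x2 + y2 - 2.0 * X @ Y.T, 0.0)


def rbf_kernel(X, Y, sigma):
    """Gaussian kernel k(x, y) = exp(-||x - y||^2 / (2 sigma^2))."""
    return np.exp(-_sq_dists(X, Y) / (2.0 * sigma ** 2))


def mmd_rbf(real, fake, sigma=None):
    """Biased estimate of squared MMD between two sample sets."""
    real = np.asarray(real, dtype=float)
    fake = np.asarray(fake, dtype=float)
    if sigma is None:
        # Median heuristic (assumed form): median distance over pooled pairs
        pooled = np.vstack([real, fake])
        d = np.sqrt(_sq_dists(pooled, pooled))
        sigma = np.median(d[np.triu_indices_from(d, k=1)])
        if sigma == 0:
            sigma = 1.0  # guard against degenerate, all-identical samples
    k_rr = rbf_kernel(real, real, sigma).mean()
    k_ff = rbf_kernel(fake, fake, sigma).mean()
    k_rf = rbf_kernel(real, fake, sigma).mean()
    return k_rr + k_ff - 2.0 * k_rf


rng = np.random.default_rng(0)
X = rng.normal(size=(50, 5))
Y = rng.normal(size=(50, 5)) + 2.0  # shifted distribution
same, shifted = mmd_rbf(X, X), mmd_rbf(X, Y)
```

Identical samples give a value of zero, and the shifted set gives a clearly larger value, consistent with "lower MMD = better distributional agreement".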

SWD Analysis: Sliced Wasserstein Distance on the full gene-expression space. Lower distance = better distributional agreement.

Usage:

python swd_analysis.py

analysis.swd_analysis.load_gan_data(dataset, iter_idx)
analysis.swd_analysis.load_garage_data(dataset)
analysis.swd_analysis.load_lsh_gan_data(dataset, iter_idx)
analysis.swd_analysis.load_real_data(dataset)
analysis.swd_analysis.main()
analysis.swd_analysis.sliced_wasserstein(real, fake, n_proj=200, seed=42)
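A minimal sketch of a sliced Wasserstein distance with the same signature. It is assumption-laden: it uses random unit directions, the 1-D Wasserstein-1 distance approximated by comparing quantiles on a common grid (so unequal sample sizes are allowed), and an average over projections. The module's own choice of order and estimator is not documented here.

```python
import numpy as np


def sliced_wasserstein(real, fake, n_proj=200, seed=42):
    """Average 1-D Wasserstein-1 distance over random projections."""
    real = np.asarray(real, dtype=float)
    fake = np.asarray(fake, dtype=float)
    rng = np.random.default_rng(seed)
    # Random directions on the unit sphere
    theta = rng.normal(size=(n_proj, real.shape[1]))
    theta /= np.linalg.norm(theta, axis=1, keepdims=True)
    # Compare empirical quantiles of the projections on a shared grid
    k = max(len(real), len(fake))
    levels = (np.arange(k) + 0.5) / k
    q_real = np.quantile(real @ theta.T, levels, axis=0)
    q_fake = np.quantile(fake @ theta.T, levels, axis=0)
    return float(np.mean(np.abs(q_real - q_fake)))


rng = np.random.default_rng(1)
X = rng.normal(size=(60, 5))
d_same = sliced_wasserstein(X, X)
d_shift = sliced_wasserstein(X, X + 3.0)
```

Fixing the seed makes the set of projections, and therefore the reported distance, reproducible across runs.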

Modules Overview

  • distribution_metrics — batch computation of MMD and SWD for all methods × datasets.

  • clustering_evaluation — feature selection + clustering across multiple random seeds.

  • aggregate_losses — aggregates per-run GAN loss CSV files into a single record.

  • build_summary_tables — builds mean ± std summary tables for WD, ARI, NMI, F1, MMD, SWD.

  • sc_specific_benchmark — ARI/NMI/F1 for scRNA-seq-specific baselines only.

  • marker_clustering_grid — grid search over clustering parameters.

  • plot_wasserstein_vs_leakage — generates WD vs leakage fraction figure.

  • mmd_analysis — standalone MMD computation and analysis.

  • swd_analysis — standalone Sliced Wasserstein Distance computation and analysis.