ablation_study module

GAN training stability ablation study.

Varies leakage fraction (lambda) — the proportion of GAT-selected real data mixed into the generator’s noise input Z_batch — across 4 scRNA-seq datasets.

Datasets: Muraro, CBMC, Yan, Pollen
Leakage: lambda in {0.0, 0.1, 0.2, 0.3}
Seed: single seed = 42
Logging: every 1000 iterations over 40001 total iterations

Output: results/rev6_losses.csv
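The experiment grid described above can be sketched as follows. This is a minimal illustration, not the module's actual code: the names DATASETS, LEAKAGES and logged_iterations are hypothetical, and it assumes logging happens at iteration 0 and at every multiple of LOG_EVERY thereafter, which is one plausible reading of "40001 total iterations".

```python
from itertools import product

# Values taken from the module description; the variable names are assumptions.
DATASETS = ["Muraro", "CBMC", "Yan", "Pollen"]
LEAKAGES = [0.0, 0.1, 0.2, 0.3]
SEED = 42
TOTAL_ITERS = 40001
LOG_EVERY = 1000


def logged_iterations(total_iters=TOTAL_ITERS, log_every=LOG_EVERY):
    """Iterations at which losses would be logged (assumed: it % log_every == 0)."""
    return list(range(0, total_iters, log_every))


# One training run per (dataset, leakage) pair, all with the same seed.
runs = list(product(DATASETS, LEAKAGES))
```

Under these assumptions the grid holds 16 runs, each contributing one logged row per 1000 iterations from 0 through 40000.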

Usage:

conda activate ritwik_base
python ablation_study/leakage_ablation.py

class ablation_study.leakage_ablation.Discriminator(*args, **kwargs)

Bases: Module

forward(x_in)
class ablation_study.leakage_ablation.Generator(*args, **kwargs)

Bases: Module

forward(z)
ablation_study.leakage_ablation.gat_subsample(X_np, y_str, gat_seed=42)

Run GAT to rank cells by attention importance. Returns Xnew (top-k rows as float32 np array) and k.
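The selection step alone (keeping the top-k rows once each cell has an importance score) might look like the sketch below. The GAT model and the way attention is reduced to one score per cell are not shown in this documentation, so the `scores` array and the helper name top_k_by_score are assumptions made purely for illustration.

```python
import numpy as np


def top_k_by_score(X_np, scores, k):
    """Return the k highest-scoring rows of X_np as float32, plus k.

    `scores` is a hypothetical per-cell importance vector; in the real
    module it would be derived from GAT attention weights.
    """
    # Sort descending by score; a stable sort keeps ties in input order.
    order = np.argsort(-np.asarray(scores), kind="stable")[:k]
    Xnew = np.asarray(X_np)[order].astype(np.float32)
    return Xnew, k
```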

ablation_study.leakage_ablation.load_cbmc()
ablation_study.leakage_ablation.load_muraro()
ablation_study.leakage_ablation.load_pollen()
ablation_study.leakage_ablation.load_yan()
ablation_study.leakage_ablation.main()
ablation_study.leakage_ablation.sample_Z(rows, cols)

Uniform noise U(-1, 1) of shape (rows, cols).
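A NumPy sketch of such a sampler is shown below. The name sample_Z_sketch and the optional `rng` argument are assumptions added for reproducibility in the example; the documented function takes only (rows, cols), and its actual backend (NumPy or PyTorch) is not stated here.

```python
import numpy as np


def sample_Z_sketch(rows, cols, rng=None):
    """Draw uniform noise in [-1, 1) with shape (rows, cols)."""
    rng = np.random.default_rng() if rng is None else rng
    return rng.uniform(-1.0, 1.0, size=(rows, cols))
```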

ablation_study.leakage_ablation.train_gan_with_leakage(X_np, Xnew, leakage, seed, device)

Train GAN for TOTAL_ITERS iterations, logging losses every LOG_EVERY steps.

leakage : float

Fraction of Z_batch rows that are real GAT-selected data. 0.0 = all noise, 0.2 = 20 % real (approx. current behaviour).

Returns list of dicts: [{iteration, d_loss, g_loss}, …]
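One way the leakage mixing could work is sketched below. It is an illustration under stated assumptions, not the module's implementation: the helper name mix_leakage is hypothetical, the noise dimension is assumed to equal the number of features so that real rows fit into Z_batch, and the real rows are assumed to be drawn at random from Xnew and placed at the top of the batch.

```python
import numpy as np


def mix_leakage(Z_noise, Xnew, leakage, rng):
    """Replace a `leakage` fraction of noise rows with real GAT-selected rows.

    Assumes Z_noise.shape[1] == Xnew.shape[1]. Returns the mixed batch
    and the number of real rows inserted.
    """
    batch = Z_noise.shape[0]
    n_real = int(round(leakage * batch))
    Z_batch = Z_noise.copy()
    if n_real > 0:
        # Sample with replacement only if Xnew has fewer rows than needed.
        idx = rng.choice(Xnew.shape[0], size=n_real,
                         replace=Xnew.shape[0] < n_real)
        Z_batch[:n_real] = Xnew[idx]
    return Z_batch, n_real
```

With leakage = 0.0 the batch is returned unchanged (all noise); with leakage = 0.2 and a batch of 10 rows, 2 rows are replaced by real data.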

Multi-seed synthetic data generation for 5 methods × 5 seeds × 4 datasets.

Methods: WGAN, f-GAN, LSH-GAN, Vanilla GAN, GAT-GAN.
Seeds: 42, 123, 456, 789, 1024 (same seed used by all 5 methods per run).

Iteration mapping:

Yan/CBMC/Muraro → iter3 (1.0 × n_real)
Pollen → iter5 (1.5 × n_real)

Output:

gen_data/seed_{s}/{method}/{dataset_lower}_{prefix}_generated_mixdata_iter{iter}.csv
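The output template and iteration mapping above can be combined as in the sketch below. The helper name output_path and the dictionary ITER_FOR_DATASET are hypothetical, and the method directory names and the `prefix` value used by the real script are not given in this documentation, so the strings passed in here are placeholders.

```python
from itertools import product

# Iteration mapping from the description: Pollen uses iter5, the rest iter3.
ITER_FOR_DATASET = {"Yan": 3, "CBMC": 3, "Muraro": 3, "Pollen": 5}
SEEDS = [42, 123, 456, 789, 1024]
METHODS = ["WGAN", "f-GAN", "LSH-GAN", "Vanilla GAN", "GAT-GAN"]


def output_path(seed, method, dataset, prefix):
    """Fill the documented path template (method/prefix strings are placeholders)."""
    it = ITER_FOR_DATASET[dataset]
    return (f"gen_data/seed_{seed}/{method}/"
            f"{dataset.lower()}_{prefix}_generated_mixdata_iter{it}.csv")


# 5 methods x 5 seeds x 4 datasets
jobs = list(product(METHODS, SEEDS, ITER_FOR_DATASET))
```

Under these assumptions the full sweep has 100 jobs, one output file each.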

Usage: conda run -n ritwik_base python generate_synthetic_data.py

class ablation_study.multi_seed_synthesis.Critic(*args, **kwargs)

Bases: Module

forward(x)
class ablation_study.multi_seed_synthesis.FDiscriminator(*args, **kwargs)

Bases: Module

forward(x)
class ablation_study.multi_seed_synthesis.FGenerator(*args, **kwargs)

Bases: Module

forward(z)
class ablation_study.multi_seed_synthesis.GATG_Discriminator(*args, **kwargs)

Bases: Module

forward(x_in)
class ablation_study.multi_seed_synthesis.GATG_Generator(*args, **kwargs)

Bases: Module

forward(z)
class ablation_study.multi_seed_synthesis.LSHDiscriminator(*args, **kwargs)

Bases: Module

forward(x)
class ablation_study.multi_seed_synthesis.LSHGenerator(*args, **kwargs)

Bases: Module

forward(z)
class ablation_study.multi_seed_synthesis.VanillaDiscriminator(*args, **kwargs)

Bases: Module

forward(x)
class ablation_study.multi_seed_synthesis.VanillaGenerator(*args, **kwargs)

Bases: Module

forward(z)
class ablation_study.multi_seed_synthesis.WGenerator(*args, **kwargs)

Bases: Module

forward(z)
ablation_study.multi_seed_synthesis.fisher_ratio(T_real, T_fake)
ablation_study.multi_seed_synthesis.gat_subsample(X, y, index_list, k, device)
ablation_study.multi_seed_synthesis.knn_subsample(X, k=5)
ablation_study.multi_seed_synthesis.load_labels(ds_cfg)
ablation_study.multi_seed_synthesis.load_real(ds_cfg)
ablation_study.multi_seed_synthesis.main()
ablation_study.multi_seed_synthesis.sample_Z(m, n)
ablation_study.multi_seed_synthesis.train_and_generate_fgan(real, n_features, device, seed, ds_name, wanted_iter)
ablation_study.multi_seed_synthesis.train_and_generate_gan(real, n_features, device, seed, ds_name, wanted_iter)
ablation_study.multi_seed_synthesis.train_and_generate_gatgan(real, labels, n_features, device, seed, ds_name, wanted_iter)
ablation_study.multi_seed_synthesis.train_and_generate_lshgan(real, n_features, device, seed, ds_name, wanted_iter)
ablation_study.multi_seed_synthesis.train_and_generate_wgan(real, n_features, device, seed, ds_name, wanted_iter)

Modules Overview

  • leakage_ablation — runs GARAGE on all 4 datasets with varying leakage fractions $\lambda \in \{0.0, 0.1, 0.2, 0.3\}$ and logs GAN losses to results/rev6_losses.csv.

  • multi_seed_synthesis — runs GARAGE with 5 different random seeds on all 4 datasets to assess reproducibility of the generated data.