data_generation module
GARAGE — Graph-Attentive Rare-cell-Aware single-cell data GEneration.
A two-stage framework for generating high-fidelity synthetic scRNA-seq data:
- Stage 1 (GAT Subsampling):
A Graph Attention Network (GAT) classifier is trained on a KNN cell-cell graph. Rare cell types receive a priority weight boost. After training, per-cell attention scores from the second GAT layer are extracted and the top-k cells (k = leakage_fraction * n_cells) are selected as “seeds”.
- Stage 2 (GAN Generation with Attention-Guided Seeding):
A Generator/Discriminator GAN is trained. Instead of pure noise, the generator receives a hybrid input batch: a mix of random noise vectors and the GAT-selected seed cells. This seeding anchors the generator to biologically realistic states, stabilises training, and ensures rare cell types are represented in the output.
- Datasets supported: Yan (124 cells, 6 types), Pollen (301 cells, 11 types),
CBMC (7,895 cells, 13 types), Muraro (2,126 cells, 10 types).
Usage
python -m data_generation.garage --dataset muraro
Citation
Ganguly, R., et al. “GARAGE: A Graph-Attentive GAN for Rare-Cell-Aware Single-Cell RNA-seq Data Generation.” bioRxiv, 2025. DOI: 10.1101/2025.09.28.679012
- class data_generation.garage.Discriminator(*args, **kwargs)
Bases: Module
Discriminator: 512 → 256 → n_genes → 1, LeakyReLU(0.2). Returns a (logits, intermediate_representation) tuple.
- forward(x_in)
- class data_generation.garage.GATClassifier(*args, **kwargs)
Bases: Module
2-layer GAT classifier for node-level cell-type prediction. The priority_weight boosts attention toward rare-type cells after the first GAT layer, encouraging the model to attend to them.
- forward(data)
- class data_generation.garage.Generator(*args, **kwargs)
Bases: Module
Generator: 1024 → 1024 → n_genes, LeakyReLU(0.2), no output activation.
- forward(z)
- data_generation.garage.gat_subsample(X, y_str, dataset_name, seed=42)
Train the GAT classifier and return the indices of the top-k cells ranked by attention weight from the second GAT layer.
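The top-k selection step can be sketched in plain Python. This is a minimal illustration, not the module's code: `select_seed_indices` is a hypothetical helper, the attention scores are made up, and rounding k up with `math.ceil` is an assumption (the source only states k = leakage_fraction * n_cells).

```python
import math

def select_seed_indices(attention_scores, leakage_fraction):
    """Return the indices of the k highest-scoring cells.

    Hypothetical helper: k = leakage_fraction * n_cells, rounded up
    (the rounding rule is an assumption of this sketch).
    """
    n_cells = len(attention_scores)
    k = math.ceil(leakage_fraction * n_cells)
    # Rank cell indices by descending attention score; ties keep input order.
    ranked = sorted(range(n_cells), key=lambda i: attention_scores[i], reverse=True)
    return ranked[:k]

# Made-up per-cell attention scores for 10 cells.
scores = [0.10, 0.90, 0.30, 0.75, 0.05, 0.60, 0.20, 0.40, 0.85, 0.15]
seed_idx = select_seed_indices(scores, leakage_fraction=0.2)
print(seed_idx)
```

The returned indices would then be used to pull the seed rows out of the expression matrix X.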
- data_generation.garage.generate_data(generator, n_sample, n_features, Xnew, out_dir, dataset_name, k)
Generate synthetic data at multiple volume multipliers (0.25x – 1.5x n_sample). Save each batch as a CSV file.
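As a rough illustration of the volume multipliers, the sketch below computes how many synthetic cells each batch would contain. The 0.25 step between multipliers and the helper name `volume_sample_counts` are assumptions; the source only gives the 0.25x to 1.5x range.

```python
def volume_sample_counts(n_sample, multipliers=(0.25, 0.5, 0.75, 1.0, 1.25, 1.5)):
    """Map each volume multiplier to a number of synthetic cells.

    Hypothetical helper: the multiplier grid (steps of 0.25) is assumed.
    """
    return {m: int(round(m * n_sample)) for m in multipliers}

# Example with the Yan dataset size (124 cells).
counts = volume_sample_counts(124)
for multiplier, n_cells in counts.items():
    print(f"{multiplier}x -> {n_cells} cells")
```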
- data_generation.garage.load_dataset(dataset_name)
Load expression matrix and cell-type labels for dataset_name.
- Returns:
X (np.ndarray (n_cells, n_genes) float32)
y_str (np.ndarray (n_cells,) string labels)
n_sample (int)
n_features (int)
- data_generation.garage.main()
- data_generation.garage.run_garage(dataset_name, out_dir=None, seed=42)
Full GARAGE pipeline for a single dataset: (1) GAT subsampling → (2) GAN training → (3) data generation.
- data_generation.garage.sample_Z(m, n)
Uniform noise U(-1, 1) of shape (m, n).
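A dependency-free sketch of this kind of noise sampler is shown below. The name `sample_uniform_noise` and the optional `rng` argument are illustrative assumptions; the real function presumably returns an array rather than nested lists.

```python
import random

def sample_uniform_noise(m, n, rng=None):
    """Return an m x n nested list with entries drawn from U(-1, 1).

    Stand-in for illustration only; not the module's sample_Z.
    """
    rng = rng or random.Random()
    return [[rng.uniform(-1.0, 1.0) for _ in range(n)] for _ in range(m)]

# Seeded generator so the example is reproducible.
Z = sample_uniform_noise(4, 3, rng=random.Random(42))
print(len(Z), len(Z[0]))
```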
- data_generation.garage.train_gan(x_plot, Xnew, n_features, seed=42)
Train the GAN with the GAT-seeded hybrid input.
The generator’s input batch Z_batch is a vertical stack of random noise plus GAT-selected seed cells, giving the generator a biological “anchor” for rare cell types.
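The vertical stacking of noise and seed cells can be sketched as follows. This is an illustration under stated assumptions: `build_hybrid_batch` is a hypothetical helper, every seed cell is placed in the batch, noise rows fill the remainder, and the noise width equals the number of genes so the rows can be stacked. The actual mixing ratio and batch size used by `train_gan` are not given in the source.

```python
import random

def build_hybrid_batch(seed_cells, batch_size, rng):
    """Stack random noise rows on top of the seed-cell rows.

    Hypothetical helper: assumes all seeds go into the batch and
    noise rows from U(-1, 1) fill the remaining slots.
    """
    n_features = len(seed_cells[0])
    n_noise = batch_size - len(seed_cells)
    if n_noise < 0:
        raise ValueError("batch_size must be at least the number of seed cells")
    noise = [[rng.uniform(-1.0, 1.0) for _ in range(n_features)]
             for _ in range(n_noise)]
    # Vertical stack: noise rows first, then copies of the seed rows.
    return noise + [list(row) for row in seed_cells]

# Two made-up seed cells with three genes each.
seeds = [[0.5, 1.2, 0.0], [2.0, 0.1, 0.7]]
Z_batch = build_hybrid_batch(seeds, batch_size=5, rng=random.Random(0))
print(len(Z_batch), "rows,", len(Z_batch[0]), "features")
```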
Wasserstein distance between real and generated scRNA-seq distributions.
For each (dataset, leakage_level), this script:
1. Loads the real expression matrix.
2. Loads the GARAGE-generated data for the requested iteration.
3. Computes the 1-Wasserstein (Earth Mover’s) distance between the two histograms using the Python Optimal Transport library (POT).
Datasets
- Yan: real data header=None, transposed; generated data header=0, index_col=0
- Pollen: real data header=None, not transposed; generated data header=0, index_col=0
- CBMC: real data header=0, index_col=0, transposed; generated data header=0, index_col=0
- Muraro: real data header=0, not transposed; generated data header=0, index_col=0
Usage
python -m data_generation.wasserstein_distance --dataset muraro --leakage 0.2 --gen_iter 3 --gen_csv data/gen_data/muraro_data_mixdata_iter3_top_426.csv
- data_generation.wasserstein_distance.load_generated(gen_csv, transpose=False)
Load generated CSV. Default: header=None (no index_col).
- data_generation.wasserstein_distance.load_real(dataset_name)
- data_generation.wasserstein_distance.main()
- data_generation.wasserstein_distance.wasserstein_distance(data, generated)
Compute 1-Wasserstein (Earth Mover’s) distance using POT.emd2.
Each row is treated as a sample; cost matrix is Euclidean between rows.
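To make the quantity concrete, the sketch below computes the same kind of distance by brute force on a toy example. It does not use POT: it assumes both matrices have the same number of rows with uniform weights, in which case an optimal transport plan can always be taken to be a one-to-one matching, so the minimum over all row permutations gives the exact value. `emd_uniform_bruteforce` is a hypothetical name, and the approach is only practical for a handful of rows.

```python
import itertools
import math

def euclidean(u, v):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

def emd_uniform_bruteforce(data, generated):
    """1-Wasserstein distance between two equal-size, uniformly weighted row sets.

    Brute-force illustration: tries every one-to-one matching of rows.
    """
    n = len(data)
    if n != len(generated):
        raise ValueError("this sketch requires the same number of rows")
    # Cost matrix: Euclidean distance between every pair of rows.
    cost = [[euclidean(x, y) for y in generated] for x in data]
    best_total = min(
        sum(cost[i][perm[i]] for i in range(n))
        for perm in itertools.permutations(range(n))
    )
    return best_total / n  # each row carries weight 1/n

# Toy data: `fake` is `real` shifted by (3, 4) and shuffled.
real = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
fake = [[3.0, 5.0], [3.0, 4.0], [4.0, 4.0]]
distance = emd_uniform_bruteforce(real, fake)
print(distance)
```

For realistically sized matrices, a solver such as POT's `ot.emd2` on a cost matrix of this form is the practical choice, as the docstring above indicates.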
Core Functions
The garage module contains:
- load_dataset(): loads expression matrix and labels via config.DATASET_CONFIG.
- gat_main(): trains the GAT classifier, extracts attention weights, and returns seed cell indices.
- Generator: MLP that receives the hybrid (noise + seeds) input batch.
- Discriminator: MLP binary classifier.
- sample_Z(): constructs the hybrid input batch $(1-\lambda) \cdot z \oplus \lambda \cdot x_{\text{seed}}$.
The wasserstein_distance module contains:
- compute_wasserstein(): returns the Earth Mover’s Distance between two expression matrices.