Step 1 — Simulate Gene Expression¶
In STpuppeteer, we assume that the expression of each gene across cells follows a negative binomial distribution, written as a Gamma–Poisson mixture. The first step of the simulation is to generate the gene-wise gamma parameters: μ (cell-wise mean expression) and θ (cell-wise expression variation).
Given the nature of simulation, genes are generated arbitrarily to serve as either celltype-marker or general housekeeping genes. Due to the different modes of expression (activated/muted) of marker genes, the μ parameters in the gamma distribution are generated per gene per cell type.
How do we sample those gamma parameters? Because the number of parameters can easily explode (300 genes × 4 cell types = 1200 pairs of parameters), the package does not encourage you to define them all by hand. Instead, you define the parameters of the prior distributions (which are also gamma) from which the gene-by-cell-type parameters are drawn. The parameter generation pipeline therefore has two sub-steps:
- Gamma priors — define the shape of the μ and θ distributions
- Gene parameter table — sample one (μ, θ) per gene per cell type
Note that in the current model, the dispersion parameter θ is sampled from the same prior distribution across gene types and cell types.
So the overall pipeline looks like:
Gamma priors (marker_mu, marker_cv, theta_alpha, …)
│
▼
generate_gene_parameters()
│
▼
gpar_df — shape (n_genes × (2·n_celltypes + metadata columns))
columns: ct_0_mu, ct_0_theta, ct_1_mu, ct_1_theta, …
The plate diagram below illustrates the generative structure — hyperparameters (blue) feed into per-gene or per-gene-per-celltype latent variables, which together determine the observed transcript counts Y_gc.
# Plate diagram: generative model for gene expression
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.patches import FancyBboxPatch
def _node(ax, xy, label, observed=False, is_param=False, fontsize=9.5):
    """Place a text label wrapped in a styled bounding box at ``xy`` on ``ax``.

    Hyperparameter nodes (``is_param``) get a blue rounded rectangle, observed
    nodes a grey circle, and every other node a white circle. ``is_param``
    takes precedence when both flags are set.
    """
    if is_param:
        box_style = {"boxstyle": "round,pad=0.4", "fc": "#dde8ff", "ec": "#3355aa"}
    elif observed:
        box_style = {"boxstyle": "circle,pad=0.4", "fc": "#cccccc", "ec": "black"}
    else:
        box_style = {"boxstyle": "circle,pad=0.4", "fc": "white", "ec": "black"}
    box_style["lw"] = 1.6  # same outline width for every node type

    cx, cy = xy
    ax.text(
        cx, cy, label,
        ha="center", va="center",
        fontsize=fontsize, bbox=box_style, zorder=5,
    )
def _arrow(ax, src, dst, rad=0.0):
    """Annotate ``ax`` with an arrow running from ``src`` to ``dst``.

    ``rad`` is forwarded to the ``arc3`` connection style and controls how
    strongly the arrow is curved.
    """
    arrow_style = {
        "arrowstyle": "-|>",
        "color": "#333333",
        "lw": 1.3,
        "connectionstyle": f"arc3,rad={rad}",
    }
    # Empty annotation text: only the arrow itself is drawn.
    ax.annotate("", xy=dst, xytext=src, arrowprops=arrow_style, zorder=4)
# Canvas: a 10 × 6 data-unit drawing area with equal aspect and no axes.
fig, ax = plt.subplots(figsize=(10, 5.2))
ax.set_xlim(0, 10)
ax.set_ylim(0, 6)
ax.set_aspect("equal")
ax.axis("off")

# Plates: the outer rectangle repeats over genes, the dashed inner one over
# cell types. Each entry is
# (anchor, width, height, extra patch style, label anchor, label text).
plates = [
    ((1.4, 0.5), 7.8, 5.1, {}, (1.58, 0.72), "g = 1 … G (genes)"),
    ((4.6, 1.1), 4.0, 4.0, {"ls": "--"}, (4.78, 1.32), "c = 1 … C (cell types)"),
]
for anchor, width, height, extra_style, label_xy, plate_label in plates:
    ax.add_patch(FancyBboxPatch(anchor, width, height,
                                boxstyle="square,pad=0", fc="none",
                                ec="#888888", lw=1.5, **extra_style))
    ax.text(*label_xy, plate_label, fontsize=8.5, color="#666666", style="italic")

# Hyperparameter nodes sit outside both plates.
hyper_nodes = [
    ((0.62, 5.0), "marker\nμ-prior"),
    ((0.62, 3.2), "silence\nμ-prior"),
    ((0.62, 1.6), "hk\nμ-prior"),
    ((9.38, 3.0), "θ-prior"),
]
for position, text in hyper_nodes:
    _node(ax, position, text, is_param=True, fontsize=8.2)

# Random-variable nodes.
# θ_g is drawn in the outer plate only (one per gene).
# NOTE(review): the gpar_df table printed later in this notebook has one theta
# column per cell type (ct_0_theta, ct_1_theta, …) — confirm whether this node
# should instead be θ_gc inside the inner plate.
_node(ax, (3.3, 3.0), "θ_g")
# μ_gc lies inside both plates (one per gene × cell type).
_node(ax, (6.1, 4.2), "μ_gc")
# Y_gc is the observed count, also inside both plates.
_node(ax, (7.8, 3.0), "Y_gc", observed=True)

# Directed edges as (source, destination) coordinate pairs.
# NOTE(review): the θ-prior → θ_g edge is a straight segment along y = 3.0 and
# therefore passes through the Y_gc position (7.8, 3.0); check the rendered
# figure for ambiguity.
edges = [
    ((1.02, 5.0), (5.73, 4.44)),   # marker prior → μ_gc
    ((1.02, 3.2), (5.73, 4.12)),   # silence prior → μ_gc
    ((1.02, 1.6), (5.73, 3.96)),   # hk prior → μ_gc
    ((8.98, 3.0), (3.70, 3.0)),    # θ prior → θ_g
    ((6.1, 3.87), (7.47, 3.2)),    # μ_gc → Y_gc
    ((3.70, 3.0), (7.47, 2.88)),   # θ_g → Y_gc
]
for start, end in edges:
    _arrow(ax, start, end)

# Legend: one swatch per node/plate style, given as (face, edge, label).
legend_spec = [
    ("white", "black", "Latent variable"),
    ("#cccccc", "black", "Observed variable (Y_gc = counts)"),
    ("#dde8ff", "#3355aa", "Hyperparameter (prior)"),
    ("none", "#888888", "Plate — repeated structure"),
]
legend_items = [
    mpatches.Patch(fc=face, ec=edge, label=text)
    for face, edge, text in legend_spec
]
ax.legend(handles=legend_items, loc="lower right", fontsize=8,
          frameon=True, framealpha=0.92, edgecolor="#cccccc")

ax.set_title("Generative Model — Plate Diagram", fontsize=12, pad=8)
plt.tight_layout()
plt.show()
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec
from matplotlib.colors import to_rgba
from STpuppeteer.simulation import (
SimulationConfig, SpotlessSimulator,
get_marker_gamma_prior, get_silence_gamma_prior, get_theta_gamma_prior, get_hk_gamma_prior,
)
# One fixed colour per simulated cell type, reused by the figures below.
CT_PALETTE = dict(ct_0="#49997c", ct_1="#1ebecd", ct_2="#ae3918")

# Global matplotlib defaults: 110 dpi figures without top/right axes spines.
plt.rcParams["figure.dpi"] = 110
plt.rcParams["axes.spines.top"] = False
plt.rcParams["axes.spines.right"] = False

print("Imports OK")
Imports OK
1.1 Gamma priors for μ and θ¶
All gene parameters are drawn from Gamma distributions. There are two kinds of priors, with intentionally different parameterisations:
μ priors — parameterised by mean and CV¶
The three μ prior functions (get_marker_gamma_prior, get_silence_gamma_prior, get_hk_gamma_prior) all use the same intuitive interface:
| Argument | Meaning |
|---|---|
| `mu` | Mean of the prior — controls the average expression level drawn |
| `cv` | Coefficient of variation (std / mean) — controls gene-to-gene spread |
Internally the shape (α) and scale are derived as α = 1/cv², scale = mu/α, so that the Gamma distribution has the requested mean and relative spread.
θ prior — currently parameterised by shape and rate¶
get_theta_gamma_prior uses the raw Gamma parameterisation (alpha, rate), where mean(θ) = alpha / rate. This is different from the μ prior interface above.
Note (future work): Unify `get_theta_gamma_prior` to also accept `mu`/`cv` for consistency with the μ priors.
Gene classes and their priors¶
In the current version, there are three gene classes, each drawing μ from a different prior:
| Class | Role | μ prior in own cell type | μ prior in other cell types |
|---|---|---|---|
| Marker | cell-type identity genes | marker prior (high μ) | silence prior (low μ) |
| Housekeeping | constitutive expression | hk prior (medium μ) | hk prior (medium μ) |
| Silence | background / unexpressed | silence prior (low μ) | silence prior (low μ) |
θ is shared: all genes draw from the same θ prior regardless of gene class or cell type.
With the default configuration the prior distributions look like this:
# TODO: Turn this into a future enhancement: can we define cell-type-specific mu and theta values? Can we allow cell types to share markers?
# Default prior parameters.
# μ priors are described by (mean, CV); the θ prior by (shape α, rate).
marker_mu = 0.75
marker_cv = 0.6
silence_mu = 0.05
silence_cv = 1.0
hk_mu = 0.2
hk_cv = 0.3
theta_alpha = 2.0
theta_rate = 1.0

# One prior object per parameter family.
mu_mkr_prior = get_marker_gamma_prior(mu=marker_mu, cv=marker_cv)
mu_sil_prior = get_silence_gamma_prior(mu=silence_mu, cv=silence_cv)
# NOTE(review): passed positionally, unlike the keyword calls above —
# presumably in (mu, cv) order; confirm against the function signature.
mu_hk_prior = get_hk_gamma_prior(hk_mu, hk_cv)
theta_prior = get_theta_gamma_prior(alpha=theta_alpha, rate=theta_rate)
# Hide visualization code blocks
# TODO: Reduce the figure to two panels (mu and theta priors) and add HK histograms too
# TODO: define theta gamma prior in the same manner (mu and cv, instead of rate)
N = 50_000  # number of samples drawn from each prior for the histograms

fig, axs = plt.subplots(1, 2, figsize=(8, 3.5))
ax_mu, ax_theta = axs

# Left panel: the three μ priors overlaid, as (prior, colour, legend label).
# NOTE(review): rvs() is called without a seed here, so the histograms differ
# slightly between runs — seed the draws if the figure must be reproducible.
mu_priors = [
    (mu_mkr_prior, CT_PALETTE["ct_0"],
     f"marker μ (mean={marker_mu}, CV={marker_cv})"),
    (mu_hk_prior, "#FF9F30",
     f"housekeeping μ (mean={hk_mu}, CV={hk_cv})"),
    (mu_sil_prior, "#999999",
     f"silence μ (mean={silence_mu}, CV={silence_cv})"),
]
for prior, color, label in mu_priors:
    ax_mu.hist(prior.rvs(N), bins=60, color=color,
               alpha=0.7, density=True, label=label)
ax_mu.set_title("Prior on mean expression μ")
ax_mu.set_xlabel("μ")
ax_mu.legend(fontsize=8)

# Right panel: the single θ prior used for the NB dispersion.
ax_theta.hist(theta_prior.rvs(N), bins=60, color=CT_PALETTE["ct_2"],
              alpha=0.8, density=True)
ax_theta.set_title(f"Prior on NB dispersion θ (α={theta_alpha}, rate={theta_rate})")
ax_theta.set_xlabel("θ")

# Title spelling corrected ("Examplery" → "Exemplary").
fig.suptitle("Exemplary Gamma priors, mu and theta", fontsize=12)
fig.tight_layout()
plt.show()
From the histograms, we can see that there is a nice separation between the two modes (marker on/off) in terms of the μ parameter.
1.2 Generate the gene parameter table¶
generate_gene_parameters() samples (μ, θ) for every gene-celltype combination:
- Marker gene of `ct_k` → high μ in `ct_k`, silence μ in all other types
- Housekeeping genes → medium μ in all types
config = SimulationConfig(
    # Reproducibility and problem size.
    seed=42,
    n_cells=200,
    n_celltype=3,
    n_genes=300,
    n_markers=50,
    # μ prior for marker genes (mean, CV).
    marker_mu=0.75,
    marker_cv=0.6,
    # μ prior for silenced expression (mean, CV).
    silence_mu=0.05,
    silence_cv=1.0,
    # θ prior (shape, rate).
    theta_alpha=2.0,
    theta_rate=1.0,
)

# Build the simulator and let it fill in its gene parameter table.
sim = SpotlessSimulator(config)
sim.generate_gene_parameters()

gpar_df = sim.gpar_df
print(f"Shape: {gpar_df.shape} (n_genes × 2·n_celltypes + metadata)")
gpar_df.head(4)
Shape: (300, 9) (n_genes × 2·n_celltypes + metadata)
| feature_name | gene_type | ct_0_theta | ct_0_mu | ct_1_theta | ct_1_mu | ct_2_theta | ct_2_mu | gene_leakage | |
|---|---|---|---|---|---|---|---|---|---|
| 0 | ct_0_mkr_0 | ct_0_mkr | 2.091817 | 0.306475 | 0.705229 | 0.039907 | 1.209854 | 0.056850 | 0.0 |
| 1 | ct_0_mkr_1 | ct_0_mkr | 2.835346 | 0.475299 | 0.164149 | 0.009448 | 0.934799 | 0.084951 | 0.0 |
| 2 | ct_0_mkr_2 | ct_0_mkr | 1.837216 | 0.617711 | 2.841613 | 0.236585 | 3.179418 | 0.025210 | 0.0 |
| 3 | ct_0_mkr_3 | ct_0_mkr | 1.645070 | 0.414340 | 0.077426 | 0.032049 | 4.852572 | 0.023355 | 0.0 |
Heatmap of μ across genes × cell types¶
Each row is a cell type; each column is a gene.
The annotation strip at the bottom shows gene class — marker genes form bright stripes in their own cell type's row.
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
# Names of the per-cell-type mean columns (every column ending in "_mu").
mu_cols = list(filter(lambda name: name.endswith("_mu"), gpar_df.columns))
# μ values of those columns as a plain NumPy array (one row per gene).
mu_mtx = gpar_df.loc[:, mu_cols].to_numpy()
# Gene class labels
def gene_class(gt):
    """Map a ``gene_type`` label to the gene-class name used for colouring.

    Returns ``"<ct> marker"`` for the first of ct_0, ct_1, ct_2 whose
    ``"<ct>_mkr"`` tag occurs in ``gt``; any other label is reported as
    ``"housekeeping"``.
    """
    marker_hits = (
        f"{ct} marker"
        for ct in ("ct_0", "ct_1", "ct_2")
        if f"{ct}_mkr" in gt
    )
    return next(marker_hits, "housekeeping")
# Attach the gene-class label used by the colour strip and the boxplots.
# gpar_df was assigned straight from sim.gpar_df, so work on a copy: the
# plotting-only column then never ends up in the simulator's own table
# (the marker_mu sweep later in this notebook also works on a copy).
gpar_df = gpar_df.copy()
gpar_df["gene_class"] = gpar_df["gene_type"].apply(gene_class)

class_color_map = {
    "ct_0 marker": CT_PALETTE["ct_0"], "ct_1 marker": CT_PALETTE["ct_1"],
    "ct_2 marker": CT_PALETTE["ct_2"], "housekeeping": "#aaaaaa",
}

# Heatmap row labels derived from the μ column names ("ct_0_mu" → "ct_0") so
# they always match the rows of mu_mtx.T instead of being hard-coded.
ct_labels = [col[:-len("_mu")] for col in mu_cols]

# Shared colour scale for the heatmap and the hand-built colorbar.
mu_min, mu_max = mu_mtx.min(), mu_mtx.max()

fig = plt.figure(figsize=(15, 5.5))
gs = GridSpec(1, 3, figure=fig, wspace=0.38)

# Left panel: heatmap (row 0), gene-class strip (row 1), colorbar+legend (row 2)
gs_left = GridSpecFromSubplotSpec(3, 1, subplot_spec=gs[0],
                                  height_ratios=[13, 1, 2.5], hspace=0.04)
ax_hmap = fig.add_subplot(gs_left[0])
ax_strip = fig.add_subplot(gs_left[1])
ax_cbl = fig.add_subplot(gs_left[2])  # colorbar + legend axis
ax_bp0 = fig.add_subplot(gs[1])
ax_bp1 = fig.add_subplot(gs[2])

# ── Heatmap (no built-in colorbar) ──────────────────────────────────────────
# mu_mtx is transposed, so rows are cell types and columns are genes.
sns.heatmap(mu_mtx.T, ax=ax_hmap, cmap="YlOrRd",
            vmin=mu_min, vmax=mu_max,
            xticklabels=False, yticklabels=ct_labels,
            cbar=False)
ax_hmap.set_title("Mean expression μ (cell types × genes)")
ax_hmap.tick_params(bottom=False)
ax_hmap.set_xlabel("")

# ── Gene-class colour strip ──────────────────────────────────────────────────
# A 1 × n_genes RGBA image, one coloured pixel column per gene.
strip_rgba = np.array([to_rgba(class_color_map[g])
                       for g in gpar_df["gene_class"]])[np.newaxis, :, :]
ax_strip.imshow(strip_rgba, aspect="auto", interpolation="none")
ax_strip.set_xticks([])
ax_strip.set_yticks([])
for sp in ax_strip.spines.values():
    sp.set_visible(False)
ax_strip.set_xlabel("← genes →", fontsize=8, labelpad=3)

# ── Combined colorbar + legend in ax_cbl ─────────────────────────────────────
ax_cbl.set_axis_off()

# Colorbar: small inset axes on the left of ax_cbl
cbar_inset = inset_axes(ax_cbl, width="45%", height="35%",
                        loc="center left", borderpad=0)
norm = Normalize(vmin=mu_min, vmax=mu_max)
sm = ScalarMappable(cmap="YlOrRd", norm=norm)
sm.set_array([])
fig.colorbar(sm, cax=cbar_inset, orientation="horizontal", label="μ")
cbar_inset.tick_params(labelsize=7)

# Gene class legend: right half of ax_cbl
strip_handles = [mpatches.Patch(color=v, label=k) for k, v in class_color_map.items()]
ax_cbl.legend(handles=strip_handles, ncol=2, fontsize=7, frameon=False,
              loc="center right", bbox_to_anchor=(1.0, 0.5))

# ── Boxplots ─────────────────────────────────────────────────────────────────
# μ per gene class, shown for the first two cell types only.
order = list(class_color_map.keys())
fp = {"marker": ".", "markersize": 3}
for ax, ct, title in [(ax_bp0, "ct_0_mu", "μ in ct_0"), (ax_bp1, "ct_1_mu", "μ in ct_1")]:
    sns.boxplot(data=gpar_df, x="gene_class", y=ct, order=order,
                palette=class_color_map, width=0.5, ax=ax, flierprops=fp)
    ax.set_title(title + " — own markers express")
    ax.set_xlabel("")
    ax.tick_params(axis="x", rotation=18)

fig.suptitle("Gene parameter table — marker genes express only in their own cell type", fontsize=12)
plt.show()
1.3 Effect of marker_mu — signal strength¶
marker_mu sets the mean of the Gamma prior used to sample marker gene means.
Higher → stronger cell-type signal (larger separation between marker and silence distributions).
# Marker-prior means to compare, one subplot each.
mu_values = [0.3, 0.5, 0.75, 1.5]
fig, axs = plt.subplots(1, len(mu_values), figsize=(14, 3.5), sharey=False, sharex=True)

for ax, mu in zip(axs, mu_values):
    # Fresh simulator per setting; only marker_mu changes between panels.
    cfg = SimulationConfig(
        seed=0, n_cells=50, n_celltype=3, n_genes=500, n_markers=125,
        marker_mu=mu, silence_mu=0.05,
    )
    s = SpotlessSimulator(cfg)
    s.generate_gene_parameters()

    # assign() returns a new frame, leaving the simulator's table untouched.
    df = s.gpar_df.assign(gene_class=lambda d: d["gene_type"].apply(gene_class))

    # One histogram of ct_0's μ per gene class; empty classes are skipped.
    for cls, color in class_color_map.items():
        vals = df.loc[df["gene_class"] == cls, "ct_0_mu"]
        if vals.empty:
            continue
        ax.hist(vals, bins=10, color=color, alpha=0.65, density=True, label=cls)

    ax.set_title(f"marker_mu = {mu}")
    ax.set_xlabel("μ in ct_0")

axs[0].set_ylabel("Density")
axs[-1].legend(fontsize=8, title="Gene class")
fig.suptitle("marker_mu controls signal strength (separation from silence)", fontsize=12)
fig.tight_layout()
plt.show()
1.4 Effect of theta_alpha — overdispersion¶
θ is the NB dispersion parameter: lower θ = more overdispersion (more variance relative to mean).
theta_alpha is the shape of the Gamma prior on θ. Lower theta_alpha → prior mass at small θ → more overdispersed genes.
# TODO: Change this into a mu-cv based expression
theta_alphas = [0.5, 1.0, 2.0, 5.0]
fig, axs = plt.subplots(1, len(theta_alphas), figsize=(14, 3.5), sharey=True, sharex=True)
rng = np.random.default_rng(0)
MU = 1.0  # fixed mean — vary only dispersion
MAX_COUNT = 14  # largest count shown on the x-axis

# One unit-width bin centred on each integer 0 … MAX_COUNT. Integer edges
# (range(0, 15)) would merge 13 and 14 into the closed last bin and draw every
# bar half a unit to the right of its count.
count_bins = np.arange(-0.5, MAX_COUNT + 1.5, 1.0)

for ax, ta in zip(axs, theta_alphas):
    theta_pr = get_theta_gamma_prior(alpha=ta, rate=1.0)
    # NOTE(review): rvs() is not tied to `rng`, so the θ draws (and hence the
    # Var/Mean shown in the titles) vary between runs — seed it if needed.
    thetas = np.asarray(theta_pr.rvs(500))

    # Simulate NB counts: Poisson(Gamma(θ, μ/θ)), vectorised over all θ draws.
    lam = rng.gamma(shape=thetas, scale=MU / thetas)
    counts = rng.poisson(lam)

    ax.hist(counts, bins=count_bins, color="#555599", alpha=0.8, edgecolor="white")
    ax.set_title(f"theta_alpha = {ta}\nVar/Mean = {np.var(counts)/max(np.mean(counts),1e-9):.2f}")
    ax.set_xlabel("Simulated count")
    ax.set_xlim(-0.5, MAX_COUNT + 0.5)

axs[0].set_ylabel("# genes")
fig.suptitle("theta_alpha — lower = more overdispersion (wider count distribution)", fontsize=12)
fig.tight_layout()
plt.show()
Summary¶
| Parameter | Effect | Typical range |
|---|---|---|
| `marker_mu` | Mean expression of marker genes — controls signal strength | 0.3 – 2.0 |
| `marker_cv` | Variability of marker μ across genes of the same class | 0.3 – 1.0 |
| `silence_mu` | Background expression in non-expressing cell types | 0.01 – 0.1 |
| `theta_alpha` | Shape of θ prior — lower = more overdispersed counts | 0.5 – 5.0 |
Next: 02_cell_generation.ipynb — how spatial positions and nucleus shapes are generated.