"""
Perturbation simulation using GRN signal propagation.
This module implements CellOracle-style perturbation simulations using
the scHopfield GRN framework. It simulates how gene expression changes
propagate through the inferred gene regulatory network.
References
----------
Logic for the transition vector field is inspired by the perturbation
simulation workflow in CellOracle:
Kamimoto et al. (2023). Nature. https://doi.org/10.1038/s41586-022-05688-9
"""
import numpy as np
import pandas as pd
from typing import Dict, Optional, Union, List, Tuple
from anndata import AnnData
from tqdm.auto import tqdm
from .._utils.math import sigmoid
from .._utils.io import get_matrix, to_numpy, get_genes_used
from ._utils import _parse_perturb_genes, _get_W_matrix, _compute_x_bounds, _update_scHopfield_uns
from ..tools.perturbation_analysis import compute_lineage_bias, compute_cluster_effects
def _propagate_signal(
    X_current: np.ndarray,
    X_original: np.ndarray,
    W: np.ndarray,
    source_indices: np.ndarray,
    threshold: np.ndarray,
    exponent: np.ndarray,
    dt: float = 1.0,
    x_min: float = 0.0,
    x_max: Optional[np.ndarray] = None
) -> np.ndarray:
    """
    Advance the expression matrix by one GRN propagation step.

    Only the genes listed in ``source_indices`` act as senders. For every
    receiving gene ``i`` the update is::

        x_i^new = x_i^current
                  + dt * sum_k W[i, k] * (sigmoid_k(x_k^current) - sigmoid_k(x_k^original))

    with ``k`` running over the source genes. The result is then bounded
    below by ``x_min`` and, when given, above by ``x_max``.

    Parameters
    ----------
    X_current : np.ndarray
        Current expression matrix (n_cells, n_genes).
    X_original : np.ndarray
        Unperturbed reference expression matrix (n_cells, n_genes).
    W : np.ndarray
        Interaction matrix (n_genes, n_genes); ``W[i, k]`` is the effect of
        gene ``k`` on gene ``i``.
    source_indices : np.ndarray
        Indices of the genes whose activation change is propagated.
    threshold : np.ndarray
        Sigmoid threshold parameter for every gene.
    exponent : np.ndarray
        Sigmoid exponent parameter for every gene.
    dt : float, optional (default: 1.0)
        Scaling factor applied to the step.
    x_min : float, optional (default: 0.0)
        Lower bound applied to the updated expression.
    x_max : np.ndarray, optional
        Upper bound applied to the updated expression. Skipped when None.

    Returns
    -------
    np.ndarray
        Expression matrix after one propagation step (a new array;
        ``X_current`` is not modified).
    """
    # Both sigmoid evaluations use the same per-gene parameters
    src_threshold = threshold[source_indices]
    src_exponent = exponent[source_indices]

    activation_now = sigmoid(X_current[:, source_indices], src_threshold, src_exponent)
    activation_ref = sigmoid(X_original[:, source_indices], src_threshold, src_exponent)

    # (n_cells, n_source): how far each source gene's activation has moved
    activation_shift = activation_now - activation_ref

    # (n_genes, n_source): outgoing edges of the source genes
    outgoing = W[:, source_indices]

    # (n_cells, n_genes): delta_X_i = dt * sum_k W[i, k] * activation_shift[k]
    X_new = X_current + dt * (activation_shift @ outgoing.T)

    # Bounding the values keeps repeated propagation from diverging
    X_new = np.maximum(X_new, x_min)
    if x_max is None:
        return X_new
    return np.minimum(X_new, x_max)
def _get_tf_indices(W: np.ndarray) -> np.ndarray:
"""
Get indices of transcription factors (genes with outgoing edges in GRN).
Parameters
----------
W : np.ndarray
Interaction matrix (n_genes, n_genes), W[i,k] = effect of gene k on gene i
Returns
-------
np.ndarray
Indices of genes that have at least one non-zero outgoing edge
"""
# TFs are genes that regulate at least one other gene (non-zero column sum)
has_targets = np.abs(W).sum(axis=0) > 0
return np.where(has_targets)[0]
[docs]
def simulate_perturbation(
    adata: AnnData,
    perturb_condition: Dict[str, float],
    cluster_key: str = 'cell_type',
    target_clusters: Optional[List[str]] = None,
    n_propagation: int = 3,
    dt: float = 1.0,
    use_cluster_specific_GRN: bool = True,
    clip_delta_X: bool = True,
    x_max_percentile: float = 99.0,
    residual_gene_dynamics: bool = False,
    verbose: bool = True
) -> AnnData:
    """
    Simulate expression changes caused by a perturbation via GRN propagation.

    Perturbed genes are set to their requested values and the resulting change
    in sigmoid activation is pushed through the interaction matrix ``W``::

        x_i^new = x_i^current
                  + dt * sum_k W_ik * (sigmoid_k(x_k^current) - sigmoid_k(x_k^original))

    Propagation schedule, applied independently to each cluster:

    - Step 1: only the manually perturbed genes act as sources.
    - Steps 2+: every TF (gene with outgoing edges in ``W``) acts as a source,
      so genes changed in earlier steps cascade further through the network.

    Parameters
    ----------
    adata : AnnData
        Annotated data object with fitted interactions (W matrices).
    perturb_condition : dict
        Perturbation conditions as ``{gene_name: value}``. Examples:

        - Knockout: ``{"Gata1": 0.0}``
        - Overexpression: ``{"Gata1": 5.0}``
        - Multiple: ``{"Gata1": 0.0, "Tal1": 2.0}``
    cluster_key : str, optional (default: 'cell_type')
        Key in ``adata.obs`` holding the cluster labels.
    target_clusters : list of str, optional
        Clusters to simulate. If None, all clusters are simulated. Cells
        outside the target clusters keep their base expression, so their
        ``delta_X`` is 0.
    n_propagation : int, optional (default: 3)
        Number of propagation steps; more steps capture more indirect effects.
    dt : float, optional (default: 1.0)
        Scaling factor for each propagation step.
    use_cluster_specific_GRN : bool, optional (default: True)
        Forwarded to ``_get_W_matrix`` to choose between cluster-specific W
        matrices and the shared one.
    clip_delta_X : bool, optional (default: True)
        If True, the final simulated values are clipped gene-wise to the
        observed minimum/maximum of the base expression.
    x_max_percentile : float, optional (default: 99.0)
        Forwarded to ``_compute_x_bounds`` (with ``multiplier=2.0``) to derive
        the bounds enforced at every propagation step. Documented upstream as
        accepting None to disable the step-wise upper bound; that handling
        lives in ``_compute_x_bounds``.
    residual_gene_dynamics : bool, optional (default: False)
        If False, perturbed genes are reset to their perturbed values after
        every step. If True, they evolve with the rest of the network after
        the initial perturbation.
    verbose : bool, optional (default: True)
        Whether to print progress information and show a progress bar.

    Returns
    -------
    AnnData
        The same ``adata`` object, modified in place with added layers:

        - ``'simulated_count'``: simulated expression after perturbation
        - ``'delta_X'``: simulated minus original expression

        and metadata passed to ``_update_scHopfield_uns``:
        ``perturb_condition``, ``n_propagation`` and ``dt``.

    Raises
    ------
    ValueError
        If any entry of ``target_clusters`` is not present in
        ``adata.obs[cluster_key]``, or if ``perturb_condition`` fails
        validation.

    References
    ----------
    Logic for the transition vector field is inspired by the perturbation
    simulation workflow in CellOracle:
    Kamimoto et al. (2023). Nature. https://doi.org/10.1038/s41586-022-05688-9

    Examples
    --------
    >>> import scHopfield as sch
    >>> # Knockout simulation
    >>> sch.dyn.simulate_perturbation(adata, {"Gata1": 0.0})
    >>> # Overexpression
    >>> sch.dyn.simulate_perturbation(adata, {"Gata1": 5.0})
    >>> # Check results
    >>> delta = adata.layers['delta_X']
    """
    _validate_perturb_condition(adata, perturb_condition, verbose=verbose)

    # Everything below works in the reduced space of genes used by scHopfield
    genes = get_genes_used(adata)
    gene_names = adata.var_names[genes].values

    spliced_key = adata.uns.get('scHopfield', {}).get('spliced_key', 'Ms')
    base_expression = to_numpy(get_matrix(adata, spliced_key, genes=genes))

    threshold = adata.var['sigmoid_threshold'].values[genes]
    exponent = adata.var['sigmoid_exponent'].values[genes]

    # Step-wise bounds that keep the iteration from diverging
    x_min, x_max = _compute_x_bounds(base_expression, x_max_percentile, multiplier=2.0)

    perturb_indices, perturb_values = _parse_perturb_genes(gene_names, perturb_condition)

    cluster_labels = adata.obs[cluster_key]
    all_clusters = cluster_labels.unique()
    if target_clusters is None:
        clusters = all_clusters
    else:
        invalid_clusters = set(target_clusters) - set(all_clusters)
        if invalid_clusters:
            raise ValueError(f"Target clusters not found in data: {invalid_clusters}")
        clusters = [c for c in target_clusters if c in all_clusters]
        if verbose:
            print(f"Simulating perturbation in {len(clusters)} target clusters: {clusters}")

    # Cells of clusters that are never visited keep their base expression
    simulated = base_expression.copy()

    cluster_iter = tqdm(clusters, desc="Simulating perturbation") if verbose else clusters
    for cluster in cluster_iter:
        in_cluster = (cluster_labels == cluster).values
        if not in_cluster.any():
            continue

        W = _get_W_matrix(adata, cluster, use_cluster_specific=use_cluster_specific_GRN)
        tf_indices = _get_tf_indices(W)

        X_original = base_expression[in_cluster, :].copy()
        X_current = simulated[in_cluster, :].copy()

        # Initial perturbation: force the perturbed genes to their target values
        X_current[:, perturb_indices] = perturb_values[None, :]

        for step in range(n_propagation):
            # First step spreads only the direct perturbation; later steps let
            # every TF forward whatever change has reached it
            source_indices = perturb_indices if step == 0 else tf_indices

            X_current = _propagate_signal(
                X_current=X_current,
                X_original=X_original,
                W=W,
                source_indices=source_indices,
                threshold=threshold,
                exponent=exponent,
                dt=dt,
                x_min=x_min,
                x_max=x_max
            )

            # Pin perturbed genes unless they are allowed to follow the dynamics
            if not residual_gene_dynamics:
                X_current[:, perturb_indices] = perturb_values[None, :]

        simulated[in_cluster, :] = X_current

    if clip_delta_X:
        # Restrict predictions to the gene-wise observed range
        observed_min = base_expression.min(axis=0)
        observed_max = base_expression.max(axis=0)
        simulated = np.clip(simulated, observed_min, observed_max)

    delta_X = simulated - base_expression

    _store_layer(adata, simulated, 'simulated_count', genes)
    _store_layer(adata, delta_X, 'delta_X', genes)

    _update_scHopfield_uns(adata, perturb_condition=perturb_condition,
                           n_propagation=n_propagation, dt=dt)

    if verbose:
        print("Perturbation simulation complete")
        print(f" Genes perturbed: {list(perturb_condition.keys())}")
        print(f" Propagation steps: {n_propagation}")
        print(f" dt (scaling): {dt}")
        if residual_gene_dynamics:
            print(" Perturbed genes: can evolve (residual_gene_dynamics=True)")
        else:
            print(" Perturbed genes: held constant")
        print(" Results stored in adata.layers['simulated_count'] and adata.layers['delta_X']")

    return adata
def _validate_perturb_condition(
    adata: AnnData,
    perturb_condition: Dict[str, float],
    verbose: bool = True
) -> None:
    """
    Check every ``{gene: value}`` entry of a perturbation condition.

    Parameters
    ----------
    adata : AnnData
        Annotated data object the perturbation will be applied to.
    perturb_condition : dict
        Perturbation conditions as ``{gene_name: value}``.
    verbose : bool, optional (default: True)
        If True, print a warning for values far outside the observed range.

    Raises
    ------
    ValueError
        If a gene is not in ``adata.var_names``, is not among the genes
        returned by ``get_genes_used``, or has a negative value.
    """
    used = get_genes_used(adata)
    used_names = adata.var_names[used].values
    position_of = {name: pos for pos, name in enumerate(used_names)}

    # The expression layer is the same for every gene, so resolve it once
    spliced_key = adata.uns.get('scHopfield', {}).get('spliced_key', 'Ms')

    for gene, value in perturb_condition.items():
        if gene not in adata.var_names:
            raise ValueError(f"Gene '{gene}' not found in adata.var_names")

        if gene not in used_names:
            raise ValueError(f"Gene '{gene}' was not included in scHopfield analysis. "
                             f"Check adata.var['scHopfield_used']")

        if value < 0:
            raise ValueError(f"Perturbation value must be non-negative. Got {value} for '{gene}'")

        # Soft check only: values well outside the observed range are allowed
        # but reported
        column = used[position_of[gene]]
        observed = to_numpy(get_matrix(adata, spliced_key, genes=[column])).flatten()
        lowest, highest = observed.min(), observed.max()
        if (value < lowest * 0.5 or value > highest * 2) and verbose:
            print(f" Warning: Perturbation value {value} for '{gene}' is outside "
                  f"typical range [{lowest:.2f}, {highest:.2f}]")
def _store_layer(
    adata: AnnData,
    data: np.ndarray,
    layer_name: str,
    gene_indices: np.ndarray
) -> None:
    """
    Write ``data`` into ``adata.layers[layer_name]`` at full gene width.

    A zero matrix of shape ``(adata.n_obs, adata.n_vars)`` with the dtype of
    ``data`` is created; the columns listed in ``gene_indices`` receive
    ``data`` and every other column stays 0.
    """
    expanded = np.zeros(shape=(adata.n_obs, adata.n_vars), dtype=data.dtype)
    expanded[:, gene_indices] = data
    adata.layers[layer_name] = expanded
[docs]
def calculate_perturbation_effect_scores(
    adata: AnnData,
    cluster_key: str = 'cell_type',
    method: str = 'mean'
) -> pd.DataFrame:
    """
    Calculate perturbation effect scores per cluster.

    Summarizes the ``delta_X`` values by cluster to quantify the overall
    effect of the perturbation on each cell population.

    Parameters
    ----------
    adata : AnnData
        Annotated data object with simulation results (``delta_X`` layer).
    cluster_key : str, optional (default: 'cell_type')
        Key in ``adata.obs`` for cluster labels.
    method : str, optional (default: 'mean')
        How to summarize effects per gene over the cells of a cluster:

        - ``'mean'``: mean absolute delta_X
        - ``'median'``: median absolute delta_X
        - ``'max'``: maximum absolute delta_X
        - ``'norm'``: L2 norm of delta_X over the cells, divided by the
          number of cells in the cluster

    Returns
    -------
    pd.DataFrame
        DataFrame with clusters as index and genes as columns,
        containing the summarized perturbation effects.

    Raises
    ------
    ValueError
        If no ``delta_X`` layer exists or ``method`` is not one of the
        supported names.
    """
    if 'delta_X' not in adata.layers:
        raise ValueError("No simulation results found. Run simulate_perturbation first.")

    reducers = {
        'mean': lambda d: np.abs(d).mean(axis=0),
        'median': lambda d: np.median(np.abs(d), axis=0),
        'max': lambda d: np.abs(d).max(axis=0),
        'norm': lambda d: np.linalg.norm(d, axis=0) / d.shape[0],
    }
    # Validate before any work so that a bad method is reported even when
    # there are no clusters to iterate over
    reducer = reducers.get(method) if isinstance(method, str) else None
    if reducer is None:
        raise ValueError(f"Unknown method: {method}")

    genes = get_genes_used(adata)
    gene_names = adata.var_names[genes].values
    delta_X = adata.layers['delta_X'][:, genes]

    labels = adata.obs[cluster_key]
    results = {}
    for cluster in labels.unique():
        cluster_mask = (labels == cluster).values
        results[cluster] = reducer(delta_X[cluster_mask, :])

    # Built gene-by-cluster, then transposed so that clusters index the rows
    return pd.DataFrame(results, index=gene_names).T
[docs]
def calculate_cell_transition_scores(
    adata: AnnData,
    cluster_key: str = 'cell_type',
    basis: str = 'umap'
) -> pd.DataFrame:
    """
    Score how strongly each cell is shifted by the simulated perturbation.

    The score of a cell is the L2 norm of its ``delta_X`` row over the genes
    used by scHopfield. The per-cell scores are written to
    ``adata.obs['perturbation_magnitude']`` and summarized per cluster.

    Parameters
    ----------
    adata : AnnData
        Annotated data object with simulation results (``delta_X`` layer).
    cluster_key : str, optional (default: 'cell_type')
        Key in ``adata.obs`` for cluster labels.
    basis : str, optional (default: 'umap')
        Embedding basis name. Not used by the current implementation.

    Returns
    -------
    pd.DataFrame
        One row per cluster with columns ``mean_magnitude``,
        ``std_magnitude`` and ``max_magnitude`` (rounded to 4 decimals).

    Raises
    ------
    ValueError
        If no ``delta_X`` layer exists.
    """
    if 'delta_X' not in adata.layers:
        raise ValueError("No simulation results found. Run simulate_perturbation first.")

    used = get_genes_used(adata)
    per_cell_shift = np.linalg.norm(adata.layers['delta_X'][:, used], axis=1)

    # Side effect: keep the cell-level scores on the AnnData object
    adata.obs['perturbation_magnitude'] = per_cell_shift

    per_cell = pd.DataFrame({
        'cluster': adata.obs[cluster_key].values,
        'magnitude': per_cell_shift
    })
    cluster_stats = per_cell.groupby('cluster').agg(['mean', 'std', 'max'])
    cluster_stats = cluster_stats.round(4)
    # Flatten the (column, statistic) MultiIndex produced by agg
    cluster_stats.columns = ['mean_magnitude', 'std_magnitude', 'max_magnitude']
    return cluster_stats
[docs]
def get_top_affected_genes(
    adata: AnnData,
    n_genes: int = 20,
    cluster: Optional[str] = None,
    cluster_key: str = 'cell_type',
    exclude_perturbed: bool = True
) -> pd.DataFrame:
    """
    Get the top genes most affected by the perturbation.

    Genes are ranked by the absolute value of their mean ``delta_X`` over the
    selected cells.

    Parameters
    ----------
    adata : AnnData
        Annotated data object with simulation results (``delta_X`` layer).
    n_genes : int, optional (default: 20)
        Number of top genes to return. Values of 0 or less give an empty
        result; values larger than the number of genes return all genes.
    cluster : str, optional
        If specified, analyze only cells in this cluster.
    cluster_key : str, optional (default: 'cell_type')
        Key in ``adata.obs`` for cluster labels.
    exclude_perturbed : bool, optional (default: True)
        If True, drop the genes listed in
        ``adata.uns['scHopfield']['perturb_condition']`` (when present).

    Returns
    -------
    pd.DataFrame
        Columns ``gene``, ``mean_delta_X``, ``abs_mean_delta_X`` and
        ``direction`` (``'up'`` for a positive mean, otherwise ``'down'``),
        ordered from most to least affected.

    Raises
    ------
    ValueError
        If no ``delta_X`` layer exists.
    """
    if 'delta_X' not in adata.layers:
        raise ValueError("No simulation results found. Run simulate_perturbation first.")

    genes = get_genes_used(adata)
    gene_names = adata.var_names[genes].values

    genes_filtered = genes
    if exclude_perturbed and 'scHopfield' in adata.uns and 'perturb_condition' in adata.uns['scHopfield']:
        perturbed_genes = list(adata.uns['scHopfield']['perturb_condition'].keys())
        gene_mask = ~np.isin(gene_names, perturbed_genes)
        gene_names = gene_names[gene_mask]
        genes_filtered = genes[gene_mask]

    if cluster is not None:
        mask = (adata.obs[cluster_key] == cluster).values
        delta_X = adata.layers['delta_X'][mask, :][:, genes_filtered]
    else:
        delta_X = adata.layers['delta_X'][:, genes_filtered]

    mean_delta = delta_X.mean(axis=0)
    abs_mean_delta = np.abs(mean_delta)

    # Reverse first, then truncate: slicing with [-n_genes:] would return
    # every gene for n_genes == 0 because -0 == 0
    ranked = np.argsort(abs_mean_delta)[::-1]
    top_idx = ranked[:max(n_genes, 0)]

    return pd.DataFrame({
        'gene': gene_names[top_idx],
        'mean_delta_X': mean_delta[top_idx],
        'abs_mean_delta_X': abs_mean_delta[top_idx],
        'direction': ['up' if d > 0 else 'down' for d in mean_delta[top_idx]]
    })
[docs]
def compare_perturbations(
    adata: AnnData,
    perturbations: Union[Dict[str, Dict[str, float]], List[Dict[str, float]]],
    labels: Optional[List[str]] = None,
    cluster_key: str = 'cell_type',
    target_clusters: Optional[List[str]] = None,
    n_propagation: int = 3,
    dt: float = 1.0,
    verbose: bool = True
) -> pd.DataFrame:
    """
    Compare multiple perturbation conditions.

    Each condition is simulated with ``simulate_perturbation`` directly on
    ``adata``, so the ``'simulated_count'`` and ``'delta_X'`` layers are
    overwritten on every run and hold the results of the last condition when
    this function returns.

    Parameters
    ----------
    adata : AnnData
        Annotated data object with fitted interactions.
    perturbations : dict or list
        Either:

        - Dict mapping labels to perturbation conditions:
          ``{"KO": {"Gata1": 0.0}, "OE": {"Gata1": 1.0}}``
        - List of perturbation conditions (labels come from ``labels``)
    labels : list of str, optional
        Labels for each perturbation when ``perturbations`` is a list.
        Defaults to ``perturb_1``, ``perturb_2``, ... Ignored if
        ``perturbations`` is a dict (keys are used as labels).
    cluster_key : str, optional (default: 'cell_type')
        Key in ``adata.obs`` for cluster labels.
    target_clusters : list of str, optional
        List of cluster names to simulate perturbation in.
        If None, simulates in all clusters.
    n_propagation : int, optional (default: 3)
        Number of propagation steps.
    dt : float, optional (default: 1.0)
        Scaling factor for each propagation step.
    verbose : bool, optional (default: True)
        Whether to print which condition is being simulated.

    Returns
    -------
    pd.DataFrame
        Genes as index, one column per perturbation label holding the mean
        ``|delta_X|`` over all cells. Rows are sorted by the summed effect
        across conditions, largest first.

    Raises
    ------
    ValueError
        If the number of labels does not match the number of perturbations.

    References
    ----------
    Logic for the transition vector field is inspired by the perturbation
    simulation workflow in CellOracle:
    Kamimoto et al. (2023). Nature. https://doi.org/10.1038/s41586-022-05688-9
    """
    if isinstance(perturbations, dict):
        labels = list(perturbations.keys())
        perturbations_list = list(perturbations.values())
    else:
        perturbations_list = perturbations
        if labels is None:
            labels = [f"perturb_{i+1}" for i in range(len(perturbations_list))]

    # Explicit raise instead of assert: asserts are removed under ``python -O``
    if len(labels) != len(perturbations_list):
        raise ValueError("Number of labels must match perturbations")

    genes = get_genes_used(adata)
    gene_names = adata.var_names[genes].values

    all_deltas = {}
    for label, perturb in zip(labels, perturbations_list):
        if verbose:
            print(f"\nRunning simulation for: {label}")
            print(f" Condition: {perturb}")

        simulate_perturbation(
            adata, perturb,
            cluster_key=cluster_key,
            target_clusters=target_clusters,
            n_propagation=n_propagation,
            dt=dt,
            verbose=False
        )

        # Read the freshly written layer before the next run overwrites it
        delta_X = adata.layers['delta_X'][:, genes]
        all_deltas[label] = np.abs(delta_X).mean(axis=0)

    result = pd.DataFrame(all_deltas, index=gene_names)

    # Rank by total effect without adding a helper column, which would
    # clobber a user-supplied label with the same name
    totals = result.sum(axis=1).to_numpy()
    order = np.argsort(-totals, kind='stable')
    return result.iloc[order]
# ---------------------------------------------------------------------------
# High-level KO screen helpers
# ---------------------------------------------------------------------------
def run_ko_screen(
    adata: AnnData,
    genes: List[str],
    lineage_A_clusters: List[str],
    lineage_B_clusters: List[str],
    basis: str,
    wt_flow_key: str,
    cluster_key: str = 'cell_type',
    cluster_order: Optional[List[str]] = None,
    simulate_kwargs: Optional[Dict] = None,
    verbose: bool = True,
) -> Tuple[Dict[str, Dict[str, float]], Dict[str, pd.Series]]:
    """
    Knock out genes one at a time and score lineage bias and cluster effects.

    Every gene is set to 0.0 in an ODE-based simulation
    (``simulate_shift_ode``) run on a copy of ``adata``. The simulated object
    is then passed to:

    - :func:`~scHopfield.tools.compute_lineage_bias` (**lineage bias**)
    - :func:`~scHopfield.tools.compute_cluster_effects` (**cluster effects**)

    Parameters
    ----------
    adata : AnnData
        Base (WT) AnnData with fitted model. Simulations receive
        ``adata.copy()``.
    genes : list of str
        Gene names to screen. Genes absent from ``adata.var_names`` are skipped.
    lineage_A_clusters : list of str
        Cluster names for lineage A (e.g. erythroid).
    lineage_B_clusters : list of str
        Cluster names for lineage B (e.g. myeloid).
    basis : str
        Embedding basis for flow projection (e.g. ``'draw_graph_fa'``).
    wt_flow_key : str
        Key in ``adata.obsm`` for the pre-computed WT Hopfield velocity.
    cluster_key : str, optional (default: 'cell_type')
        Key in adata.obs for cluster labels.
    cluster_order : list of str, optional
        Ordered cluster names for ``compute_cluster_effects``.
        If None, uses ``adata.obs[cluster_key].unique()``.
    simulate_kwargs : dict, optional
        Extra keyword arguments forwarded to ``simulate_shift_ode``; they
        override the defaults ``dt=5.0, n_steps=100,
        use_cluster_specific_GRN=True, n_jobs=-1, verbose=False``.
    verbose : bool, optional (default: True)
        Print progress for each gene.

    Returns
    -------
    bias_dict : dict[str, dict]
        ``{gene: {'score_A', 'score_B', 'lineage_bias'}}``
    effects_dict : dict[str, pd.Series]
        ``{gene: pd.Series(mean |delta_X| per cluster)}``

    Examples
    --------
    >>> bias, effects = sch.dyn.run_ko_screen(
    ...     adata, CANDIDATES, ERYTHROID, MYELOID,
    ...     basis='draw_graph_fa', wt_flow_key='original_velocity_flow_draw_graph_fa',
    ...     cluster_key='paul15_clusters', cluster_order=CLUSTER_ORDER,
    ... )
    >>> bias_df = pd.DataFrame(bias).T.sort_values('lineage_bias', ascending=False)
    """
    from .simulation import simulate_shift_ode

    # Defaults first; caller-supplied options win on conflict
    sim_kw = {
        'cluster_key': cluster_key,
        'dt': 5.0,
        'n_steps': 100,
        'use_cluster_specific_GRN': True,
        'n_jobs': -1,
        'verbose': False,
    }
    if simulate_kwargs is not None:
        sim_kw.update(simulate_kwargs)

    if cluster_order is None:
        cluster_order = list(adata.obs[cluster_key].unique())

    bias_dict = {}
    effects_dict = {}
    for gene in genes:
        if gene in adata.var_names:
            if verbose:
                print(f" KO: {gene}...")

            adata_ko = simulate_shift_ode(
                adata.copy(),
                perturb_condition={gene: 0.0},
                **sim_kw,
            )
            bias_dict[gene] = compute_lineage_bias(
                adata_ko, adata,
                lineage_A_clusters, lineage_B_clusters,
                basis, wt_flow_key,
                cluster_key=cluster_key,
            )
            effects_dict[gene] = compute_cluster_effects(
                adata_ko, cluster_order, cluster_key=cluster_key
            )
        elif verbose:
            print(f" Skip {gene}: not in adata")

    if verbose:
        print(f"\nCompleted {len(bias_dict)} single KOs.")
    return bias_dict, effects_dict
def run_pairwise_ko_screen(
    adata: AnnData,
    pairs: List[Tuple[str, str]],
    lineage_A_clusters: List[str],
    lineage_B_clusters: List[str],
    basis: str,
    wt_flow_key: str,
    cluster_key: str = 'cell_type',
    cluster_order: Optional[List[str]] = None,
    simulate_kwargs: Optional[Dict] = None,
    verbose: bool = True,
) -> Tuple[Dict[Tuple[str, str], Dict[str, float]], Dict[Tuple[str, str], pd.Series]]:
    """
    Knock out gene pairs and score lineage bias and cluster effects.

    Both genes of a pair are set to 0.0 in an ODE-based simulation
    (``simulate_shift_ode``) run on a copy of ``adata``.

    Parameters
    ----------
    adata : AnnData
        Base (WT) AnnData with fitted model. Simulations receive
        ``adata.copy()``.
    pairs : list of (str, str)
        Gene-name tuples to screen. Pairs where either gene is absent from
        ``adata.var_names`` are skipped, as are pairs of a gene with itself
        (the latter without a message).
    lineage_A_clusters : list of str
        Cluster names for lineage A.
    lineage_B_clusters : list of str
        Cluster names for lineage B.
    basis : str
        Embedding basis for flow projection.
    wt_flow_key : str
        Key in ``adata.obsm`` for the pre-computed WT Hopfield velocity.
    cluster_key : str, optional (default: 'cell_type')
        Key in adata.obs for cluster labels.
    cluster_order : list of str, optional
        Ordered cluster names for ``compute_cluster_effects``.
        If None, uses ``adata.obs[cluster_key].unique()``.
    simulate_kwargs : dict, optional
        Extra keyword arguments forwarded to ``simulate_shift_ode``; they
        override the defaults ``dt=5.0, n_steps=100,
        use_cluster_specific_GRN=True, n_jobs=-1, verbose=False``.
    verbose : bool, optional (default: True)
        Print progress for each pair.

    Returns
    -------
    bias_dict : dict[(str, str), dict]
        ``{(geneA, geneB): {'score_A', 'score_B', 'lineage_bias'}}``
    effects_dict : dict[(str, str), pd.Series]
        ``{(geneA, geneB): pd.Series(mean |delta_X| per cluster)}``

    Examples
    --------
    >>> import itertools
    >>> cross_pairs = list(itertools.product(top5_ery, top5_mye))
    >>> bias, effects = sch.dyn.run_pairwise_ko_screen(
    ...     adata, cross_pairs, ERYTHROID, MYELOID,
    ...     basis='draw_graph_fa', wt_flow_key='original_velocity_flow_draw_graph_fa',
    ...     cluster_key='paul15_clusters', cluster_order=CLUSTER_ORDER,
    ... )
    """
    from .simulation import simulate_shift_ode

    # Defaults first; caller-supplied options win on conflict
    sim_kw = {
        'cluster_key': cluster_key,
        'dt': 5.0,
        'n_steps': 100,
        'use_cluster_specific_GRN': True,
        'n_jobs': -1,
        'verbose': False,
    }
    if simulate_kwargs is not None:
        sim_kw.update(simulate_kwargs)

    if cluster_order is None:
        cluster_order = list(adata.obs[cluster_key].unique())

    bias_dict = {}
    effects_dict = {}
    for gene_a, gene_b in pairs:
        # A gene paired with itself is just a single KO
        if gene_a == gene_b:
            continue

        if not (gene_a in adata.var_names and gene_b in adata.var_names):
            if verbose:
                print(f" Skip ({gene_a}, {gene_b}): gene not in adata")
            continue

        if verbose:
            print(f" KO pair: ({gene_a}, {gene_b})...")

        pair_key = (gene_a, gene_b)
        adata_pair = simulate_shift_ode(
            adata.copy(),
            perturb_condition={gene_a: 0.0, gene_b: 0.0},
            **sim_kw,
        )
        bias_dict[pair_key] = compute_lineage_bias(
            adata_pair, adata,
            lineage_A_clusters, lineage_B_clusters,
            basis, wt_flow_key,
            cluster_key=cluster_key,
        )
        effects_dict[pair_key] = compute_cluster_effects(
            adata_pair, cluster_order, cluster_key=cluster_key
        )

    if verbose:
        print(f"\nCompleted {len(bias_dict)} pairwise KOs.")
    return bias_dict, effects_dict
def compute_synergy(
    pair_bias: Dict[str, float],
    single_bias_A: Dict[str, float],
    single_bias_B: Dict[str, float],
) -> float:
    """
    Compute synergy between a gene pair KO and individual KOs.

    Synergy measures whether the pair produces a stronger directional
    lineage bias than either gene alone:

    ``synergy = |pair_lineage_bias| - max(|single_A_lineage_bias|, |single_B_lineage_bias|)``

    Positive synergy → pair amplifies directional bias beyond either single KO.
    Negative synergy → pair is redundant with one of the single KOs.

    Parameters
    ----------
    pair_bias : dict
        Bias dict for the double KO, e.g. from ``run_pairwise_ko_screen``.
        Should contain key ``'lineage_bias'``.
    single_bias_A : dict
        Bias dict for the single KO of gene A.
        Should contain key ``'lineage_bias'``.
    single_bias_B : dict
        Bias dict for the single KO of gene B.
        Should contain key ``'lineage_bias'``.

    Returns
    -------
    float
        Synergy score. Positive = synergistic, negative = redundant.
        NaN if ``'lineage_bias'`` is missing (or NaN) in any of the three
        inputs.

    Examples
    --------
    >>> syn = sch.dyn.compute_synergy(
    ...     pair_bias=pair_ko_bias[('Gata1', 'Spi1')],
    ...     single_bias_A=single_ko_bias['Gata1'],
    ...     single_bias_B=single_ko_bias['Spi1'],
    ... )
    """
    bias_pair = abs(pair_bias.get('lineage_bias', np.nan))
    bias_singleA = abs(single_bias_A.get('lineage_bias', np.nan))
    bias_singleB = abs(single_bias_B.get('lineage_bias', np.nan))

    # np.maximum propagates NaN from either operand; the built-in max() only
    # does so for its first argument, which would hide a missing gene-B bias
    best_single = np.maximum(bias_singleA, bias_singleB)
    return float(bias_pair - best_single)
def compute_epistasis(
    pair_ko_bias: Dict,
    single_ko_bias,
    lineage_A_genes: Optional[List[str]] = None,
    lineage_B_genes: Optional[List[str]] = None,
) -> pd.DataFrame:
    """
    Compute epistasis metrics for all pairwise KO results.

    For each gene pair (A, B) computes:

    - **cancellation_error**: ``actual_bias - (bias_A + bias_B)`` — deviation
      from the additive expectation on lineage bias.
    - **synergy_ery** / **synergy_mye**: pair ``score_A`` / ``score_B`` minus
      the larger of the two single-KO scores (HSA-style).
    - **dominant_epistasis**: the synergy value with the larger absolute
      magnitude, preserving sign (``synergy_ery`` wins ties).

    Parameters
    ----------
    pair_ko_bias : dict
        ``{(geneA, geneB): {'score_A', 'score_B', 'lineage_bias'}}``
        from ``run_pairwise_ko_screen``.
    single_ko_bias : dict or pd.DataFrame
        ``{gene: {'score_A', 'score_B', 'lineage_bias'}}`` for all single KOs.
        A DataFrame indexed by gene name with these columns is also accepted.
        Genes or keys that are missing contribute 0.0.
    lineage_A_genes : list of str, optional
        Genes in lineage A; used only to classify pair type (``'ery-ery'``,
        ``'cross'``, etc.).
    lineage_B_genes : list of str, optional
        Genes in lineage B; used only for pair type classification.
        Unless both gene lists are non-empty, all pairs are labelled
        ``'unknown'``.

    Returns
    -------
    pd.DataFrame
        Indexed by ``'geneA+geneB'`` pair string, sorted by ``lineage_bias``
        descending. Columns: ``geneA``, ``geneB``, ``score_A``, ``score_B``,
        ``lineage_bias``, ``expected_bias``, ``cancellation_error``,
        ``synergy_ery``, ``synergy_mye``, ``dominant_epistasis``,
        ``pair_type``. An empty ``pair_ko_bias`` yields an empty frame with
        the same columns.

    Examples
    --------
    >>> pair_df = sch.dyn.compute_epistasis(
    ...     pair_ko_bias, single_ko_bias,
    ...     lineage_A_genes=top5_ery, lineage_B_genes=top5_mye,
    ... )
    """
    # Fixed schema so that an empty screen still produces a well-formed frame
    columns = [
        'geneA', 'geneB', 'pair',
        'score_A', 'score_B', 'lineage_bias',
        'expected_bias', 'cancellation_error',
        'synergy_ery', 'synergy_mye', 'dominant_epistasis',
        'pair_type',
    ]

    def _max_magnitude(a, b):
        # Keep the sign of whichever value is larger in absolute terms
        return a if abs(a) >= abs(b) else b

    def _get(sko, gene, key):
        # Single-KO lookup for both supported containers; missing -> 0.0
        if isinstance(sko, pd.DataFrame):
            return float(sko.loc[gene, key]) if gene in sko.index else 0.0
        return float(sko.get(gene, {}).get(key, 0.0))

    # Sets give O(1) membership tests inside the pair loop
    ery_genes = set(lineage_A_genes) if lineage_A_genes is not None else set()
    mye_genes = set(lineage_B_genes) if lineage_B_genes is not None else set()

    records = []
    for (gA, gB), bias in pair_ko_bias.items():
        score_A_A = _get(single_ko_bias, gA, 'score_A')
        score_B_A = _get(single_ko_bias, gA, 'score_B')
        bias_A = _get(single_ko_bias, gA, 'lineage_bias')

        score_A_B = _get(single_ko_bias, gB, 'score_A')
        score_B_B = _get(single_ko_bias, gB, 'score_B')
        bias_B = _get(single_ko_bias, gB, 'lineage_bias')

        score_A_pair = bias.get('score_A', 0.0)
        score_B_pair = bias.get('score_B', 0.0)
        actual_bias = bias.get('lineage_bias', np.nan)

        synergy_ery = score_A_pair - max(score_A_A, score_A_B)
        synergy_mye = score_B_pair - max(score_B_A, score_B_B)
        dominant_epi = _max_magnitude(synergy_ery, synergy_mye)

        expected_bias = bias_A + bias_B
        cancellation_error = actual_bias - expected_bias

        if ery_genes and mye_genes:
            in_ery_A = gA in ery_genes
            in_ery_B = gB in ery_genes
            in_mye_A = gA in mye_genes
            in_mye_B = gB in mye_genes
            if (in_ery_A and in_mye_B) or (in_mye_A and in_ery_B):
                pair_type = 'cross'
            elif in_ery_A and in_ery_B:
                pair_type = 'ery-ery'
            elif in_mye_A and in_mye_B:
                pair_type = 'mye-mye'
            else:
                pair_type = 'other'
        else:
            pair_type = 'unknown'

        records.append({
            'geneA': gA,
            'geneB': gB,
            'pair': f'{gA}+{gB}',
            'score_A': score_A_pair,
            'score_B': score_B_pair,
            'lineage_bias': actual_bias,
            'expected_bias': expected_bias,
            'cancellation_error': cancellation_error,
            'synergy_ery': synergy_ery,
            'synergy_mye': synergy_mye,
            'dominant_epistasis': dominant_epi,
            'pair_type': pair_type,
        })

    return (
        pd.DataFrame(records, columns=columns)
        .set_index('pair')
        .sort_values('lineage_bias', ascending=False)
    )
def run_dose_response(
    adata: AnnData,
    gene: str,
    levels,
    lineage_A_clusters: List[str],
    lineage_B_clusters: List[str],
    basis: str,
    wt_flow_key: str,
    natural_max: Optional[float] = None,
    cluster_key: str = 'cell_type',
    simulate_kwargs: Optional[Dict] = None,
    verbose: bool = True,
) -> pd.DataFrame:
    """
    Run ODE perturbation at multiple expression levels for dose-response analysis.

    For every requested level the gene is fixed at that value in an ODE-based
    simulation (``simulate_shift_ode``) on a copy of ``adata``, the
    perturbation flow is computed with ``calculate_flow`` and the lineage
    bias is recorded. Sweeping from 0 (complete KO) up to strong
    overexpression shows whether the lineage switch is graded or
    threshold-like.

    Parameters
    ----------
    adata : AnnData
        Base (WT) AnnData with fitted model. Simulations receive
        ``adata.copy()``.
    gene : str
        Gene name to perturb.
    levels : array-like
        Absolute expression levels to test (e.g. ``np.linspace(0, max*2, 10)``).
    lineage_A_clusters : list of str
        Cluster names for lineage A (e.g. erythroid).
    lineage_B_clusters : list of str
        Cluster names for lineage B (e.g. myeloid).
    basis : str
        Embedding basis for flow projection.
    wt_flow_key : str
        Key in ``adata.obsm`` for the pre-computed WT Hopfield velocity.
    natural_max : float, optional
        Natural expression maximum for the gene (e.g. 99th percentile).
        Used to fill the ``level_frac`` column (``level / natural_max``).
        When None or 0, ``level_frac`` is NaN.
    cluster_key : str, optional (default: 'cell_type')
        Key in adata.obs for cluster labels.
    simulate_kwargs : dict, optional
        Extra keyword arguments forwarded to ``simulate_shift_ode``; they
        override the defaults ``dt=5.0, n_steps=100,
        use_cluster_specific_GRN=True, n_jobs=-1, verbose=False``.
    verbose : bool, optional (default: True)
        Show tqdm progress bar.

    Returns
    -------
    pd.DataFrame
        Columns: ``['gene', 'level', 'level_frac', 'score_A', 'score_B', 'lineage_bias']``.
        One row per level.

    Raises
    ------
    ValueError
        If ``gene`` is not in ``adata.var_names``.

    Examples
    --------
    >>> gata1_max = float(np.percentile(adata.layers['spliced'][:, idx], 99))
    >>> levels = np.linspace(0, gata1_max * 2, 10)
    >>> dr = sch.dyn.run_dose_response(
    ...     adata, 'Gata1', levels, ERYTHROID, MYELOID,
    ...     basis='draw_graph_fa', wt_flow_key='original_velocity_flow_draw_graph_fa',
    ...     natural_max=gata1_max, cluster_key='paul15_clusters',
    ... )
    """
    from .simulation import simulate_shift_ode
    from ..tools.flow import calculate_flow

    if gene not in adata.var_names:
        raise ValueError(f"Gene '{gene}' not found in adata.var_names")

    # Defaults first; caller-supplied options win on conflict
    sim_kw = dict(
        cluster_key=cluster_key,
        dt=5.0,
        n_steps=100,
        use_cluster_specific_GRN=True,
        n_jobs=-1,
        verbose=False,
    )
    if simulate_kwargs is not None:
        sim_kw.update(simulate_kwargs)

    levels = np.asarray(levels, dtype=float)

    # A zero natural maximum (e.g. the 99th percentile of a sparsely expressed
    # gene) cannot be used as a scale; report NaN instead of failing with a
    # division error after the first simulation has already been paid for
    can_normalise = natural_max is not None and natural_max != 0

    records = []
    iter_levels = (
        tqdm(levels, desc=f'{gene} dose-response') if verbose else levels
    )
    for level in iter_levels:
        adata_t = simulate_shift_ode(
            adata.copy(),
            perturb_condition={gene: float(level)},
            **sim_kw,
        )
        calculate_flow(
            adata_t, source='delta', basis=basis, method='celloracle',
            cluster_key=cluster_key,
            store_key=f'perturbation_flow_{basis}',
            verbose=False,
        )
        bias = compute_lineage_bias(
            adata_t, adata,
            lineage_A_clusters, lineage_B_clusters,
            basis, wt_flow_key,
            cluster_key=cluster_key,
        )
        level_frac = float(level) / natural_max if can_normalise else np.nan
        records.append({
            'gene': gene,
            'level': float(level),
            'level_frac': level_frac,
            'score_A': bias.get('score_A', np.nan),
            'score_B': bias.get('score_B', np.nan),
            'lineage_bias': bias.get('lineage_bias', np.nan),
        })

    return pd.DataFrame(records)