"""
Visualization functions for perturbation simulation results.
References
----------
Logic for the transition vector field is inspired by the perturbation
simulation workflow in CellOracle:
Kamimoto et al. (2023). Nature. https://doi.org/10.1038/s41586-022-05688-9
"""
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Optional, List, Dict, Union, Tuple
from anndata import AnnData
from .._utils.io import get_genes_used
def _get_perturbed_genes(adata):
"""Get list of perturbed gene names from adata.uns."""
if 'scHopfield' in adata.uns and 'perturb_condition' in adata.uns['scHopfield']:
return list(adata.uns['scHopfield']['perturb_condition'].keys())
return []
def _filter_perturbed_genes(gene_names, perturbed_genes):
"""Return mask for genes that are NOT perturbed."""
return ~np.isin(gene_names, perturbed_genes)
[docs]
def plot_perturbation_effect_heatmap(
    adata: AnnData,
    cluster_key: str = 'cell_type',
    n_genes: int = 30,
    figsize: Tuple[float, float] = (12, 8),
    cmap: str = 'RdBu_r',
    center: float = 0,
    cluster_cols: bool = True,
    cluster_rows: bool = False,
    order: Optional[List[str]] = None,
    colors: Optional[Dict[str, str]] = None
) -> sns.matrix.ClusterGrid:
    """
    Draw a gene-by-cluster heatmap of the mean simulated expression change.

    For every cluster the ``delta_X`` layer is averaged over its cells. The
    perturbed genes themselves are dropped, and the ``n_genes`` genes whose
    mean change varies most across clusters are shown.

    Parameters
    ----------
    adata : AnnData
        Annotated data object holding the ``delta_X`` layer.
    cluster_key : str, optional (default: 'cell_type')
        Column of ``adata.obs`` with the cluster labels.
    n_genes : int, optional (default: 30)
        How many of the most variable genes to display.
    figsize : tuple, optional
        Figure size handed to ``sns.clustermap``.
    cmap : str, optional (default: 'RdBu_r')
        Colormap name.
    center : float, optional (default: 0)
        Value at which the colormap is centered.
    cluster_cols : bool, optional (default: True)
        Hierarchically cluster the columns (clusters).
    cluster_rows : bool, optional (default: False)
        Hierarchically cluster the rows (genes).
    order : list, optional
        Explicit column order; only honoured when ``cluster_cols`` is False.
        Names that are not columns of the table are skipped.
    colors : dict, optional
        Cluster name -> color, used for the column color bar. Clusters
        missing from the dict are drawn in ``'#cccccc'``.

    Returns
    -------
    sns.matrix.ClusterGrid
        The grid produced by ``sns.clustermap``.

    Raises
    ------
    ValueError
        If ``adata.layers`` has no ``delta_X`` entry.
    """
    if 'delta_X' not in adata.layers:
        raise ValueError("No simulation results found. Run simulate_shift first.")

    used = get_genes_used(adata)
    used_names = adata.var_names[used].values
    shifts = adata.layers['delta_X'][:, used]

    # One column per cluster: the per-gene mean shift over that cluster's cells.
    labels = adata.obs[cluster_key]
    per_cluster_mean = {
        label: shifts[(labels == label).values, :].mean(axis=0)
        for label in labels.unique()
    }
    effects = pd.DataFrame(per_cluster_mean, index=used_names)

    # The perturbed genes are set by the simulation itself, so leave them out.
    not_perturbed = ~effects.index.isin(_get_perturbed_genes(adata))
    effects = effects.loc[not_perturbed]

    # Keep the genes whose effect differs most between clusters.
    most_variable = effects.var(axis=1).nlargest(n_genes).index
    effects = effects.loc[most_variable]

    # A manual column order only makes sense when columns are not clustered.
    if not cluster_cols and order is not None:
        present = [name for name in order if name in effects.columns]
        effects = effects[present]

    if colors is None:
        col_colors = None
    else:
        col_colors = pd.Series(
            [colors.get(name, '#cccccc') for name in effects.columns],
            index=effects.columns,
        )

    g = sns.clustermap(
        effects,
        cmap=cmap,
        center=center,
        figsize=figsize,
        row_cluster=cluster_rows,
        col_cluster=cluster_cols,
        xticklabels=True,
        yticklabels=True,
        col_colors=col_colors,
        dendrogram_ratio=(0.1, 0.15),
        cbar_kws={'label': 'Mean Δ Expression'},
    )
    heatmap_ax = g.ax_heatmap
    heatmap_ax.set_xlabel('Cluster', fontsize=11)
    heatmap_ax.set_ylabel('Gene', fontsize=11)

    # Mention the perturbation in the title when it was recorded.
    sch_store = adata.uns['scHopfield'] if 'scHopfield' in adata.uns else {}
    if 'perturb_condition' in sch_store:
        described = ', '.join(
            [f"{name}={value}" for name, value in sch_store['perturb_condition'].items()]
        )
        title = f'Perturbation Effects: {described}'
    else:
        title = 'Perturbation Effects by Cluster'
    g.fig.suptitle(title, fontsize=12, fontweight='bold', y=1.02)

    return g
[docs]
def plot_perturbation_magnitude(
    adata: AnnData,
    cluster_key: str = 'cell_type',
    basis: str = 'umap',
    figsize: Tuple[float, float] = (12, 5),
    cmap: str = 'viridis',
    vmax: Optional[float] = None,
    order: Optional[List[str]] = None,
    colors: Optional[Dict[str, str]] = None
) -> plt.Figure:
    """
    Plot perturbation magnitude on embedding and as boxplot.

    The magnitude of a cell is the Euclidean norm of its ``delta_X`` row,
    restricted to the genes used in the analysis and excluding the perturbed
    genes themselves. As a side effect the per-cell magnitude is written to
    ``adata.obs['perturbation_magnitude']``.

    Parameters
    ----------
    adata : AnnData
        Annotated data object with simulation results
    cluster_key : str, optional (default: 'cell_type')
        Key in adata.obs for cluster labels
    basis : str, optional (default: 'umap')
        Embedding basis; coordinates are read from ``adata.obsm[f'X_{basis}']``.
        If that key is missing, the left panel shows a notice instead.
    figsize : tuple, optional
        Figure size
    cmap : str, optional (default: 'viridis')
        Colormap for scatter plot
    vmax : float, optional
        Maximum value for colormap
    order : list, optional
        Order of clusters in boxplot. If None, sorts by median magnitude
        (largest first).
    colors : dict, optional
        Dictionary mapping cluster names to colors for boxplot. Clusters
        missing from the dict are drawn in ``'#cccccc'``.

    Returns
    -------
    plt.Figure
        Figure with plots

    Raises
    ------
    ValueError
        If ``adata.layers`` has no ``delta_X`` entry.
    """
    if 'delta_X' not in adata.layers:
        raise ValueError("No simulation results found. Run simulate_shift first.")

    # Calculate magnitude (excluding perturbed genes)
    genes = get_genes_used(adata)
    gene_names = adata.var_names[genes].values
    gene_mask = _filter_perturbed_genes(gene_names, _get_perturbed_genes(adata))
    delta_X = adata.layers['delta_X'][:, genes][:, gene_mask]
    # np.linalg.norm cannot take a scipy sparse matrix; densify if needed.
    if hasattr(delta_X, 'toarray'):
        delta_X = delta_X.toarray()
    magnitude = np.linalg.norm(delta_X, axis=1)
    adata.obs['perturbation_magnitude'] = magnitude

    fig, axes = plt.subplots(1, 2, figsize=figsize)
    embed_ax, box_ax = axes

    # Scatter plot on embedding
    embedding_key = f'X_{basis}'
    if embedding_key in adata.obsm:
        coords = adata.obsm[embedding_key]
        sc = embed_ax.scatter(
            coords[:, 0], coords[:, 1],
            c=magnitude, cmap=cmap, s=10, alpha=0.7,
            vmax=vmax, rasterized=True
        )
        embed_ax.set_xlabel(f'{basis.upper()} 1', fontsize=10)
        embed_ax.set_ylabel(f'{basis.upper()} 2', fontsize=10)
        embed_ax.set_title('Perturbation Magnitude', fontsize=12, fontweight='bold')
        # Bind to this figure explicitly rather than pyplot's current figure.
        fig.colorbar(sc, ax=embed_ax, label='||Δx||')
        embed_ax.axis('equal')
    else:
        embed_ax.text(0.5, 0.5, f'Embedding {embedding_key} not found',
                      ha='center', va='center', transform=embed_ax.transAxes)

    # Boxplot by cluster
    df = pd.DataFrame({
        'Cluster': adata.obs[cluster_key].values,
        'Magnitude': magnitude
    })

    # Determine order
    if order is None:
        # observed=True: for categorical labels, skip categories that have no
        # cells; otherwise they show up with NaN medians and empty box slots.
        order = (
            df.groupby('Cluster', observed=True)['Magnitude']
            .median()
            .sort_values(ascending=False)
            .index.tolist()
        )

    # Create palette from colors dict
    palette = None
    if colors is not None:
        palette = [colors.get(c, '#cccccc') for c in order]

    sns.boxplot(data=df, x='Cluster', y='Magnitude', order=order, palette=palette, ax=box_ax)
    box_ax.set_xlabel('Cluster', fontsize=10)
    box_ax.set_ylabel('Perturbation Magnitude', fontsize=10)
    box_ax.set_title('Effect by Cluster', fontsize=12, fontweight='bold')
    if len(order) > 5:
        # Restyle the existing tick labels in place instead of re-setting them.
        plt.setp(box_ax.get_xticklabels(), rotation=45, ha='right')
    box_ax.grid(True, alpha=0.3, axis='y')

    fig.tight_layout()
    return fig
[docs]
def plot_gene_response(
    adata: AnnData,
    genes: Union[str, List[str]],
    cluster_key: str = 'cell_type',
    figsize: Optional[Tuple[float, float]] = None,
    order: Optional[List[str]] = None,
    colors: Optional[Dict[str, str]] = None
) -> plt.Figure:
    """
    Plot expression change for specific genes across clusters.

    One violin panel is drawn per requested gene, showing the distribution
    of that gene's ``delta_X`` values within each cluster. A gene that is
    not among the genes used in the analysis gets a panel with a notice
    instead of a plot.

    Parameters
    ----------
    adata : AnnData
        Annotated data object with simulation results
    genes : str or list
        Gene(s) to plot
    cluster_key : str, optional (default: 'cell_type')
        Key in adata.obs for cluster labels
    figsize : tuple, optional
        Figure size. Defaults to ``(5 * number_of_genes, 5)``.
    order : list, optional
        Order of clusters to plot. If None, each panel sorts its clusters by
        median effect (ascending).
    colors : dict, optional
        Dictionary mapping cluster names to colors. Clusters missing from
        the dict are drawn in ``'#cccccc'``.

    Returns
    -------
    plt.Figure
        Figure with plots

    Raises
    ------
    ValueError
        If ``adata.layers`` has no ``delta_X`` entry, or if ``genes`` is empty.
    """
    if 'delta_X' not in adata.layers:
        raise ValueError("No simulation results found. Run simulate_shift first.")

    if isinstance(genes, str):
        genes = [genes]

    n_genes = len(genes)
    if n_genes == 0:
        # plt.subplots(1, 0) would fail with an unrelated-looking message.
        raise ValueError("At least one gene must be provided.")
    if figsize is None:
        figsize = (5 * n_genes, 5)

    fig, axes = plt.subplots(1, n_genes, figsize=figsize, squeeze=False)
    axes = axes.flatten()

    sch_genes = get_genes_used(adata)
    gene_names = adata.var_names[sch_genes].values
    cluster_labels = adata.obs[cluster_key].values  # same for every panel

    for ax, gene in zip(axes, genes):
        if gene not in gene_names:
            ax.text(0.5, 0.5, f'{gene} not in analysis',
                    ha='center', va='center', transform=ax.transAxes)
            continue

        # Position among the used genes -> column of the full layer.
        gene_idx = np.where(gene_names == gene)[0][0]
        delta = adata.layers['delta_X'][:, sch_genes[gene_idx]]
        # A sparse layer yields a sparse column; pandas needs a flat array.
        if hasattr(delta, 'toarray'):
            delta = delta.toarray().flatten()

        df = pd.DataFrame({
            'Cluster': cluster_labels,
            'Δ Expression': delta
        })

        # Determine order
        if order is None:
            # observed=True: skip categorical levels that have no cells so
            # they do not appear as empty violins.
            plot_order = (
                df.groupby('Cluster', observed=True)['Δ Expression']
                .median()
                .sort_values()
                .index.tolist()
            )
        else:
            plot_order = order

        # Create palette from colors dict
        palette = None
        if colors is not None:
            palette = [colors.get(c, '#cccccc') for c in plot_order]

        sns.violinplot(data=df, x='Cluster', y='Δ Expression',
                       order=plot_order, ax=ax, palette=palette, inner='box')
        ax.axhline(0, color='red', linestyle='--', alpha=0.5)
        ax.set_title(f'{gene}', fontsize=12, fontweight='bold')
        ax.set_xlabel('Cluster', fontsize=10)
        ax.set_ylabel('Δ Expression', fontsize=10)
        if len(plot_order) > 5:
            # Restyle the existing tick labels in place instead of re-setting them.
            plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        ax.grid(True, alpha=0.3, axis='y')

    fig.tight_layout()
    return fig
[docs]
def plot_top_affected_genes_bar(
    adata: AnnData,
    n_genes: int = 20,
    cluster: Optional[str] = None,
    cluster_key: str = 'cell_type',
    figsize: Tuple[float, float] = (10, 8),
    ax: Optional[plt.Axes] = None
) -> plt.Axes:
    """
    Bar plot of top affected genes showing direction and magnitude.

    Genes are ranked by the absolute value of their mean ``delta_X`` over the
    selected cells (perturbed genes excluded). The top ``n_genes`` are drawn
    as horizontal bars sorted by signed mean change.

    Parameters
    ----------
    adata : AnnData
        Annotated data object with simulation results
    n_genes : int, optional (default: 20)
        Number of genes to show. Must be at least 1.
    cluster : str, optional
        Specific cluster to analyze. If None, uses all cells.
    cluster_key : str, optional (default: 'cell_type')
        Key in adata.obs for cluster labels
    figsize : tuple, optional
        Figure size; only used when ``ax`` is None.
    ax : plt.Axes, optional
        Axes to plot on

    Returns
    -------
    plt.Axes
        Axes with plot

    Raises
    ------
    ValueError
        If ``adata.layers`` has no ``delta_X`` entry, if ``n_genes`` is less
        than 1, or if ``cluster`` is given but matches no cell.
    """
    if 'delta_X' not in adata.layers:
        raise ValueError("No simulation results found. Run simulate_shift first.")
    # A slice like arr[-0:] returns the WHOLE array, so n_genes=0 would plot
    # every gene; negative values would give an equally meaningless slice.
    if n_genes < 1:
        raise ValueError("n_genes must be a positive integer.")

    genes = get_genes_used(adata)
    gene_names = adata.var_names[genes].values

    # Exclude perturbed genes
    gene_mask = _filter_perturbed_genes(gene_names, _get_perturbed_genes(adata))
    gene_names = gene_names[gene_mask]

    delta_X = adata.layers['delta_X']
    if cluster is not None:
        cell_mask = (adata.obs[cluster_key] == cluster).values
        if not cell_mask.any():
            # Averaging zero cells would silently give an all-NaN ranking.
            raise ValueError(
                f"No cells found for cluster '{cluster}' in adata.obs['{cluster_key}']."
            )
        delta_X = delta_X[cell_mask, :]
        title_suffix = f' ({cluster})'
    else:
        title_suffix = ' (All cells)'
    delta_X = delta_X[:, genes][:, gene_mask]
    # Densify a sparse layer so the mean below is a plain 1-D array.
    if hasattr(delta_X, 'toarray'):
        delta_X = delta_X.toarray()

    mean_delta = delta_X.mean(axis=0)
    abs_delta = np.abs(mean_delta)

    # Get top genes (argsort is ascending, so the largest sit at the end)
    top_idx = np.argsort(abs_delta)[-n_genes:]
    top_genes = gene_names[top_idx]
    top_values = mean_delta[top_idx]

    # Sort by actual value (not absolute)
    sort_idx = np.argsort(top_values)
    top_genes = top_genes[sort_idx]
    top_values = top_values[sort_idx]

    # Plot
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    bar_colors = ['#1f77b4' if v > 0 else '#d62728' for v in top_values]
    positions = range(len(top_genes))
    ax.barh(positions, top_values, color=bar_colors, edgecolor='black', linewidth=0.5)
    ax.set_yticks(positions)
    ax.set_yticklabels(top_genes)
    ax.axvline(0, color='black', linewidth=0.5)
    ax.set_xlabel('Mean Δ Expression', fontsize=11)
    ax.set_ylabel('Gene', fontsize=11)
    ax.set_title(f'Top Affected Genes{title_suffix}', fontsize=12, fontweight='bold')
    ax.grid(True, alpha=0.3, axis='x')

    # Add legend
    from matplotlib.patches import Patch
    legend_elements = [Patch(facecolor='#1f77b4', edgecolor='black', label='Upregulated'),
                       Patch(facecolor='#d62728', edgecolor='black', label='Downregulated')]
    ax.legend(handles=legend_elements, loc='lower right')

    # Lay out the figure that owns ``ax``; pyplot's current figure may be a
    # different one when the caller supplied the axes.
    ax.figure.tight_layout()
    return ax
[docs]
def plot_simulation_comparison(
    adata: AnnData,
    gene: str,
    cluster_key: str = 'cell_type',
    figsize: Tuple[float, float] = (12, 5),
    order: Optional[List[str]] = None,
    colors: Optional[Dict[str, str]] = None
) -> plt.Figure:
    """
    Compare original and simulated expression for a gene.

    Three panels are drawn: overlaid histograms of original and simulated
    values, a scatter of simulated against original with the ``y=x`` line,
    and a per-cluster boxplot of ``simulated - original``.

    The original values come from the layer named by
    ``adata.uns['scHopfield']['spliced_key']`` (``'Ms'`` if unset), falling
    back to ``adata.X`` when that layer does not exist. The simulated values
    come from ``adata.layers['simulated_count']``.

    Parameters
    ----------
    adata : AnnData
        Annotated data object with simulation results
    gene : str
        Gene to compare
    cluster_key : str, optional (default: 'cell_type')
        Key in adata.obs for cluster labels
    figsize : tuple, optional
        Figure size
    order : list, optional
        Order of clusters in boxplot. If None, sorts by median delta
        (ascending).
    colors : dict, optional
        Dictionary mapping cluster names to colors for boxplot. Clusters
        missing from the dict are drawn in ``'#cccccc'``.

    Returns
    -------
    plt.Figure
        Figure with comparison plots

    Raises
    ------
    ValueError
        If ``adata.layers`` has no ``simulated_count`` entry, or if ``gene``
        is not among the genes used in the analysis.
    """
    # Same error as the other plotting functions instead of a bare KeyError.
    if 'simulated_count' not in adata.layers:
        raise ValueError("No simulation results found. Run simulate_shift first.")

    genes = get_genes_used(adata)
    gene_names = adata.var_names[genes].values

    if gene not in gene_names:
        raise ValueError(f"Gene '{gene}' not found in analysis")

    # Position among the used genes -> column of the full matrix.
    gene_idx = np.where(gene_names == gene)[0][0]
    sch_gene_idx = genes[gene_idx]

    # Get expression data from spliced layer
    spliced_key = adata.uns.get('scHopfield', {}).get('spliced_key', 'Ms')
    if spliced_key in adata.layers:
        original = adata.layers[spliced_key][:, sch_gene_idx]
    else:
        original = adata.X[:, sch_gene_idx]
    simulated = adata.layers['simulated_count'][:, sch_gene_idx]

    # Either source may be sparse; the plots below need flat dense arrays.
    if hasattr(original, 'toarray'):
        original = original.toarray().flatten()
    if hasattr(simulated, 'toarray'):
        simulated = simulated.toarray().flatten()

    fig, axes = plt.subplots(1, 3, figsize=figsize)
    hist_ax, scatter_ax, box_ax = axes

    # Histogram comparison
    hist_ax.hist(original, bins=50, alpha=0.5, label='Original', density=True)
    hist_ax.hist(simulated, bins=50, alpha=0.5, label='Simulated', density=True)
    hist_ax.set_xlabel('Expression', fontsize=10)
    hist_ax.set_ylabel('Density', fontsize=10)
    hist_ax.set_title(f'{gene} Distribution', fontsize=12, fontweight='bold')
    hist_ax.legend()

    # Scatter: original vs simulated
    scatter_ax.scatter(original, simulated, alpha=0.3, s=5, rasterized=True)
    lims = [min(original.min(), simulated.min()), max(original.max(), simulated.max())]
    scatter_ax.plot(lims, lims, 'r--', alpha=0.5, label='y=x')
    scatter_ax.set_xlabel('Original', fontsize=10)
    scatter_ax.set_ylabel('Simulated', fontsize=10)
    scatter_ax.set_title('Original vs Simulated', fontsize=12, fontweight='bold')
    scatter_ax.legend()

    # Delta by cluster
    delta = simulated - original
    df = pd.DataFrame({
        'Cluster': adata.obs[cluster_key].values,
        'Δ Expression': delta
    })

    # Determine order
    if order is None:
        # observed=True: skip categorical levels that have no cells so they
        # do not appear as empty box slots.
        plot_order = (
            df.groupby('Cluster', observed=True)['Δ Expression']
            .median()
            .sort_values()
            .index.tolist()
        )
    else:
        plot_order = order

    # Create palette from colors dict
    palette = None
    if colors is not None:
        palette = [colors.get(c, '#cccccc') for c in plot_order]

    sns.boxplot(data=df, x='Cluster', y='Δ Expression', order=plot_order, palette=palette, ax=box_ax)
    box_ax.axhline(0, color='red', linestyle='--', alpha=0.5)
    box_ax.set_xlabel('Cluster', fontsize=10)
    box_ax.set_ylabel('Δ Expression', fontsize=10)
    box_ax.set_title('Change by Cluster', fontsize=12, fontweight='bold')
    if len(plot_order) > 5:
        # Restyle the existing tick labels in place instead of re-setting them.
        plt.setp(box_ax.get_xticklabels(), rotation=45, ha='right')

    fig.tight_layout()
    return fig