"""
Flow visualization functions for perturbation analysis.
This module contains pure visualization functions for flow analysis.
Computation functions are in:
- scHopfield.tools.flow: calculate_flow, calculate_grid_flow, calculate_inner_product
- scHopfield.tools.velocity: compute_velocity, compute_velocity_delta
- scHopfield.tools.embedding: project_to_embedding
- scHopfield.dynamics.simulation: calculate_trajectory_flow
References
----------
Logic for the transition vector field is inspired by the perturbation
simulation workflow in CellOracle:
Kamimoto et al. (2023). Nature. https://doi.org/10.1038/s41586-022-05688-9
"""
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mpl_colors
import seaborn as sns
from typing import Optional, List, Dict, Tuple
from anndata import AnnData
from ..tools.flow import calculate_flow, calculate_grid_flow, calculate_inner_product
from .._utils.io import get_genes_used
# =============================================================================
# Main Plotting Functions
# =============================================================================
def plot_flow(
adata: AnnData,
flow_key: Optional[str] = None,
basis: str = 'umap',
ax: Optional[plt.Axes] = None,
on_grid: bool = False,
scale: float = 1.0,
color: str = 'black',
alpha: float = 0.8,
show_background: bool = True,
cluster_key: Optional[str] = None,
colors: Optional[Dict[str, str]] = None,
s: float = 10,
figsize: Tuple[float, float] = (8, 8),
title: Optional[str] = None,
# Grid options
n_grid: int = 40,
n_neighbors: int = 200,
min_mass: float = 1.0,
recalculate: bool = False,
n_jobs: int = 4,
**quiver_kwargs
) -> plt.Axes:
"""
Unified flow plotting function.
Can plot flow vectors directly on cells or interpolated onto a grid.
Parameters
----------
adata : AnnData
Annotated data with flow vectors
flow_key : str, optional
Key in adata.obsm for flow vectors.
If None, uses 'perturbation_flow_{basis}'.
basis : str, optional (default: 'umap')
Embedding basis
ax : plt.Axes, optional
Axes to plot on. If None, creates new figure.
on_grid : bool, optional (default: False)
If True, interpolate flow to grid before plotting.
scale : float, optional (default: 1.0)
Scale factor for arrows
color : str, optional (default: 'black')
Arrow color
alpha : float, optional (default: 0.8)
Arrow transparency
show_background : bool, optional (default: True)
Show background scatter of cells
cluster_key : str, optional
Key for cluster labels (for coloring background)
colors : dict, optional
Dictionary mapping cluster names to colors
s : float, optional (default: 10)
Scatter point size
figsize : tuple, optional
Figure size if creating new figure
title : str, optional
Plot title. If None, auto-generates based on flow_key.
n_grid : int, optional (default: 40)
Number of grid points per dimension (when on_grid=True)
n_neighbors : int, optional (default: 200)
Number of neighbors for grid interpolation
min_mass : float, optional (default: 1.0)
Minimum probability mass to show arrows
recalculate : bool, optional (default: False)
If True, recalculate grid flow even if cached
n_jobs : int, optional (default: 4)
Number of parallel jobs
**quiver_kwargs
Additional arguments for matplotlib quiver
Returns
-------
plt.Axes
"""
if ax is None:
fig, ax = plt.subplots(figsize=figsize)
embedding_key = f'X_{basis}'
embedding = adata.obsm[embedding_key]
if flow_key is None:
flow_key = f'perturbation_flow_{basis}'
if flow_key not in adata.obsm:
raise ValueError(f"Flow '{flow_key}' not found. Run calculate_flow first.")
# Background scatter
if show_background:
if cluster_key is not None and colors is not None:
c = [colors.get(cl, 'lightgray') for cl in adata.obs[cluster_key]]
else:
c = 'lightgray'
ax.scatter(embedding[:, 0], embedding[:, 1], c=c,
s=s, alpha=0.5, rasterized=True)
# Default quiver settings
default_quiver = dict(
headaxislength=4, headlength=5, headwidth=4,
linewidths=0.5, width=0.003
)
default_quiver.update(quiver_kwargs)
if on_grid:
# Interpolate to grid
grid_key = f'grid_flow_{flow_key}'
if grid_key in adata.uns and not recalculate:
grid_data = adata.uns[grid_key]
else:
grid_data = calculate_grid_flow(
adata, flow_key=flow_key, basis=basis, n_grid=n_grid,
n_neighbors=n_neighbors, min_mass=min_mass, n_jobs=n_jobs
)
adata.uns[grid_key] = grid_data
grid_coords = grid_data['grid_coords']
grid_flow = grid_data['grid_flow']
mass_filter = grid_data['mass_filter']
valid = ~mass_filter
default_quiver['width'] = 0.004 # Slightly wider for grid
ax.quiver(
grid_coords[valid, 0], grid_coords[valid, 1],
grid_flow[valid, 0], grid_flow[valid, 1],
color=color, alpha=alpha, scale=scale,
**default_quiver
)
else:
# Plot directly on cells
flow = adata.obsm[flow_key]
ax.quiver(
embedding[:, 0], embedding[:, 1],
flow[:, 0], flow[:, 1],
color=color, alpha=alpha, scale=scale,
**default_quiver
)
# Title
if title is None:
if 'perturbation' in flow_key:
if 'scHopfield' in adata.uns and 'perturb_condition' in adata.uns['scHopfield']:
perturb = adata.uns['scHopfield']['perturb_condition']
perturb_str = ', '.join([f"{k}={'KO' if v==0 else v}" for k, v in perturb.items()])
title = f'Perturbation Flow: {perturb_str}'
else:
title = 'Perturbation Flow'
else:
title = flow_key.replace('_', ' ').title()
ax.set_title(title, fontsize=12, fontweight='bold')
ax.axis('off')
ax.set_aspect('equal')
return ax
def plot_inner_product(
adata: AnnData,
basis: str = 'umap',
by_cluster: bool = False,
cluster_key: str = 'cell_type',
ax: Optional[plt.Axes] = None,
inner_product_key: str = 'perturbation_inner_product',
vmin: float = -1,
vmax: float = 1,
cmap: str = 'RdBu_r',
s: float = 15,
figsize: Tuple[float, float] = (8, 8),
title: Optional[str] = None,
show_colorbar: bool = True,
order: Optional[List[str]] = None,
colors: Optional[Dict[str, str]] = None,
on_grid: bool = False,
n_grid: int = 40,
min_mass: float = 1.0,
) -> plt.Axes:
"""
Plot inner product values on embedding or by cluster.
Parameters
----------
adata : AnnData
Annotated data with inner product calculated
basis : str, optional (default: 'umap')
Embedding basis
by_cluster : bool, optional (default: False)
If True, show boxplot by cluster. If False, show on embedding.
cluster_key : str, optional (default: 'cell_type')
Key for cluster labels
ax : plt.Axes, optional
Axes to plot on
inner_product_key : str, optional (default: 'perturbation_inner_product')
Key in adata.obs for inner product values
vmin, vmax : float, optional
Color scale limits
cmap : str, optional (default: 'RdBu_r')
Colormap
s : float, optional (default: 15)
Point size
figsize : tuple, optional
Figure size
title : str, optional
Plot title
show_colorbar : bool, optional (default: True)
Whether to show colorbar (embedding mode)
order : list, optional
Order of clusters (cluster mode)
colors : dict, optional
Colors for clusters (cluster mode)
Returns
-------
plt.Axes
"""
if inner_product_key not in adata.obs:
raise ValueError(f"Inner product '{inner_product_key}' not found. "
"Run calculate_inner_product first.")
if by_cluster:
return _plot_inner_product_by_cluster(
adata, cluster_key=cluster_key, ax=ax,
inner_product_key=inner_product_key,
figsize=figsize, title=title, order=order, colors=colors
)
else:
return _plot_inner_product_on_embedding(
adata, basis=basis, ax=ax, inner_product_key=inner_product_key,
vmin=vmin, vmax=vmax, cmap=cmap, s=s, figsize=figsize,
title=title, show_colorbar=show_colorbar,
on_grid=on_grid, n_grid=n_grid, min_mass=min_mass
)
def _plot_inner_product_on_embedding(
adata: AnnData,
basis: str = 'umap',
ax: Optional[plt.Axes] = None,
inner_product_key: str = 'perturbation_inner_product',
vmin: float = -1,
vmax: float = 1,
cmap: str = 'RdBu_r',
s: float = 15,
figsize: Tuple[float, float] = (8, 8),
title: Optional[str] = None,
show_colorbar: bool = True,
on_grid: bool = False,
n_grid: int = 40,
min_mass: float = 1.0,
) -> plt.Axes:
"""Plot inner product on embedding."""
if ax is None:
fig, ax = plt.subplots(figsize=figsize)
embedding = adata.obsm[f'X_{basis}']
inner_product = adata.obs[inner_product_key].values
try:
norm = mpl_colors.TwoSlopeNorm(vmin=vmin, vcenter=0, vmax=vmax)
except Exception:
norm = mpl_colors.Normalize(vmin=vmin, vmax=vmax)
if on_grid:
from ..tools.flow import calculate_grid_scalar
grid_data = calculate_grid_scalar(
adata, scalar_key=inner_product_key, basis=basis, n_grid=n_grid,
min_mass=min_mass
)
grid_coords = grid_data['grid_coords']
grid_scalar = grid_data['grid_scalar']
mass_filter = grid_data['mass_filter']
valid = ~mass_filter
# Use a slightly larger point size for grid, or fallback to s
grid_s = s * 4 if on_grid else s
sc = ax.scatter(grid_coords[valid, 0], grid_coords[valid, 1], c=grid_scalar[valid],
cmap=cmap, norm=norm, s=grid_s, rasterized=True)
else:
sc = ax.scatter(embedding[:, 0], embedding[:, 1], c=inner_product,
cmap=cmap, norm=norm, s=s, rasterized=True)
if show_colorbar:
cbar = plt.colorbar(sc, ax=ax, shrink=0.6)
cbar.set_label('Inner Product', fontsize=10)
if title is None:
title = 'Inner Product\n(Perturbation vs Reference)'
ax.set_title(title, fontsize=12, fontweight='bold')
ax.axis('off')
ax.set_aspect('equal')
return ax
def _plot_inner_product_by_cluster(
adata: AnnData,
cluster_key: str = 'cell_type',
ax: Optional[plt.Axes] = None,
inner_product_key: str = 'perturbation_inner_product',
figsize: Tuple[float, float] = (10, 5),
title: Optional[str] = None,
order: Optional[List[str]] = None,
colors: Optional[Dict[str, str]] = None,
) -> plt.Axes:
"""Plot inner product by cluster."""
if ax is None:
fig, ax = plt.subplots(figsize=figsize)
df = pd.DataFrame({
'Cluster': adata.obs[cluster_key].values,
'Inner Product': adata.obs[inner_product_key].values
})
if order is None:
order = df.groupby('Cluster')['Inner Product'].median().sort_values().index.tolist()
palette = None
if colors is not None:
palette = [colors.get(c, '#cccccc') for c in order]
sns.boxplot(data=df, x='Cluster', y='Inner Product', order=order,
palette=palette, ax=ax)
ax.axhline(0, color='gray', linestyle='--', alpha=0.5)
ax.set_xlabel('Cluster', fontsize=11)
ax.set_ylabel('Inner Product Score', fontsize=11)
if title is None:
title = 'Inner Product by Cluster'
ax.set_title(title, fontsize=12, fontweight='bold')
if len(order) > 5:
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
ax.grid(True, alpha=0.3, axis='y')
sns.despine()
return ax
def visualize_flow_comparison(
adata: AnnData,
flows: Optional[List[str]] = None,
basis: str = 'umap',
cluster_key: str = 'cell_type',
colors: Optional[Dict[str, str]] = None,
scale: float = 1.0,
figsize: Tuple[float, float] = (20, 6),
n_neighbors: int = 30,
n_jobs: int = 4,
use_cluster_specific: bool = True,
**kwargs
) -> plt.Figure:
"""
Multi-panel comparison of different flow fields.
Parameters
----------
adata : AnnData
Annotated data with perturbation results
flows : list of str, optional
Flow types to compare. If None, uses ['original', 'perturbed', 'delta'].
Valid options: 'original', 'perturbed', 'delta', or any flow_key in obsm.
basis : str, optional (default: 'umap')
Embedding basis
cluster_key : str, optional (default: 'cell_type')
Cluster key
colors : dict, optional
Cluster colors
scale : float, optional (default: 1.0)
Arrow scale
figsize : tuple, optional
Figure size
n_neighbors : int, optional (default: 30)
Neighbors for flow calculation
n_jobs : int, optional (default: 4)
Parallel jobs
use_cluster_specific : bool, optional (default: True)
Use cluster-specific GRNs
Returns
-------
plt.Figure
"""
if flows is None:
flows = ['clusters', 'original', 'perturbed', 'delta']
n_panels = len(flows)
fig, axes = plt.subplots(1, n_panels, figsize=figsize)
if n_panels == 1:
axes = [axes]
embedding = adata.obsm[f'X_{basis}']
# Get perturbation info
perturb_str = "Perturbation"
if 'scHopfield' in adata.uns and 'perturb_condition' in adata.uns['scHopfield']:
perturb = adata.uns['scHopfield']['perturb_condition']
perturb_str = ', '.join([f"{k}={'KO' if v==0 else v}" for k, v in perturb.items()])
flow_colors = {
'original': '#3498DB',
'perturbed': '#27AE60',
'delta': '#E74C3C'
}
for i, flow_type in enumerate(flows):
ax = axes[i]
if flow_type == 'clusters':
# Show clusters
if colors is not None:
c = [colors.get(cl, 'gray') for cl in adata.obs[cluster_key]]
else:
c = adata.obs[cluster_key].astype('category').cat.codes
ax.scatter(embedding[:, 0], embedding[:, 1], c=c, s=10, alpha=0.7)
ax.set_title('Clusters', fontsize=12, fontweight='bold')
elif flow_type in ['original', 'perturbed', 'delta']:
# Calculate and plot flow
flow_key = f'{flow_type}_velocity_flow_{basis}' if flow_type != 'delta' else f'perturbation_flow_{basis}'
if flow_key not in adata.obsm:
# Calculate flow
source = flow_type if flow_type != 'delta' else 'delta'
calculate_flow(
adata, source=source, basis=basis,
cluster_key=cluster_key, use_cluster_specific=use_cluster_specific,
n_neighbors=n_neighbors, n_jobs=n_jobs, verbose=False
)
# Background
if colors is not None:
c = [colors.get(cl, 'lightgray') for cl in adata.obs[cluster_key]]
else:
c = 'lightgray'
ax.scatter(embedding[:, 0], embedding[:, 1], c=c, s=10, alpha=0.5)
# Flow arrows
flow = adata.obsm[flow_key]
ax.quiver(
embedding[:, 0], embedding[:, 1],
flow[:, 0], flow[:, 1],
color=flow_colors.get(flow_type, 'black'),
alpha=0.8, scale=scale,
headaxislength=4, headlength=5, headwidth=4
)
title_map = {
'original': 'Original Hopfield Velocity',
'perturbed': f'Perturbed Velocity\n({perturb_str})',
'delta': 'Delta Velocity\n(Perturbed - Original)'
}
ax.set_title(title_map.get(flow_type, flow_type), fontsize=12, fontweight='bold')
else:
# Custom flow key
if flow_type in adata.obsm:
# Background
if colors is not None:
c = [colors.get(cl, 'lightgray') for cl in adata.obs[cluster_key]]
else:
c = 'lightgray'
ax.scatter(embedding[:, 0], embedding[:, 1], c=c, s=10, alpha=0.5)
flow = adata.obsm[flow_type]
ax.quiver(
embedding[:, 0], embedding[:, 1],
flow[:, 0], flow[:, 1],
color='black', alpha=0.8, scale=scale,
headaxislength=4, headlength=5, headwidth=4
)
ax.set_title(flow_type.replace('_', ' ').title(), fontsize=12, fontweight='bold')
else:
ax.text(0.5, 0.5, f'Flow not found:\n{flow_type}',
ha='center', va='center', transform=ax.transAxes)
ax.axis('off')
ax.set_aspect('equal')
fig.suptitle(f'Hopfield Velocity Analysis: {perturb_str}', fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
return fig
[docs]
def visualize_perturbation_flow(
adata: AnnData,
basis: str = 'umap',
velocity_key: Optional[str] = None,
cluster_key: str = 'cell_type',
colors: Optional[Dict[str, str]] = None,
scale_reference: float = 1.0,
scale_perturbation: float = 1.0,
figsize: Tuple[float, float] = (20, 10),
vm: float = 1.0
) -> plt.Figure:
"""
Create comprehensive visualization of perturbation flow analysis.
Creates a 2x3 figure with:
- Row 0: Clusters, Reference velocity, Perturbation flow
- Row 1: Inner product on embedding, Inner product + flow, Inner product by cluster
Parameters
----------
adata : AnnData
Annotated data with perturbation simulation results
basis : str, optional (default: 'umap')
Embedding basis
velocity_key : str, optional
Key for reference velocity
cluster_key : str, optional (default: 'cell_type')
Key for cluster labels
colors : dict, optional
Colors for clusters
scale_reference : float, optional (default: 1.0)
Scale for reference flow arrows
scale_perturbation : float, optional (default: 1.0)
Scale for perturbation flow arrows
figsize : tuple, optional
Figure size
vm : float, optional (default: 1.0)
Max value for inner product colorscale
Returns
-------
plt.Figure
"""
# Get perturbation info
perturb_str = "Perturbation"
if 'scHopfield' in adata.uns and 'perturb_condition' in adata.uns['scHopfield']:
perturb = adata.uns['scHopfield']['perturb_condition']
perturb_str = ', '.join([f"{k}={'KO' if v==0 else v}" for k, v in perturb.items()])
fig, axes = plt.subplots(2, 3, figsize=figsize)
embedding = adata.obsm[f'X_{basis}']
# Row 0, Col 0: Clusters
ax = axes[0, 0]
if colors is not None:
c = [colors.get(cl, 'gray') for cl in adata.obs[cluster_key]]
else:
c = adata.obs[cluster_key].astype('category').cat.codes
ax.scatter(embedding[:, 0], embedding[:, 1], c=c, s=10, alpha=0.7, rasterized=True)
ax.set_title('Clusters', fontsize=12, fontweight='bold')
ax.axis('off')
ax.set_aspect('equal')
# Row 0, Col 1: Reference velocity
ax = axes[0, 1]
if velocity_key is None:
velocity_key = f'velocity_{basis}'
try:
plot_reference_flow(
adata, basis=basis, velocity_key=velocity_key, ax=ax,
scale=scale_reference, title='Reference Velocity'
)
except ValueError as e:
ax.text(0.5, 0.5, f'No velocity data\n({e})',
ha='center', va='center', transform=ax.transAxes)
ax.axis('off')
# Row 0, Col 2: Perturbation flow
ax = axes[0, 2]
try:
plot_flow(
adata, basis=basis, ax=ax, scale=scale_perturbation,
cluster_key=cluster_key, colors=colors,
title=f'Perturbation Flow\n({perturb_str})', color='#EC7063'
)
except ValueError as e:
ax.text(0.5, 0.5, f'No perturbation flow\n({e})',
ha='center', va='center', transform=ax.transAxes)
ax.axis('off')
# Row 1, Col 0: Inner product on embedding
ax = axes[1, 0]
try:
plot_inner_product(
adata, basis=basis, ax=ax, vmin=-vm, vmax=vm,
title='Inner Product\n(Perturbation \u00d7 Reference)'
)
except ValueError as e:
ax.text(0.5, 0.5, f'No inner product\n({e})',
ha='center', va='center', transform=ax.transAxes)
ax.axis('off')
# Row 1, Col 1: Inner product + flow overlay
ax = axes[1, 1]
try:
plot_inner_product(
adata, basis=basis, ax=ax, vmin=-vm, vmax=vm,
show_colorbar=False, s=10, title=''
)
plot_flow(
adata, basis=basis, ax=ax,
show_background=False, color='black', alpha=0.6
)
ax.set_title('Inner Product + Flow', fontsize=12, fontweight='bold')
except ValueError:
ax.axis('off')
# Row 1, Col 2: Inner product by cluster
ax = axes[1, 2]
try:
plot_inner_product(
adata, by_cluster=True, cluster_key=cluster_key, ax=ax,
colors=colors, title='Inner Product by Cluster'
)
except ValueError as e:
ax.text(0.5, 0.5, f'No inner product\n({e})',
ha='center', va='center', transform=ax.transAxes)
ax.axis('off')
fig.suptitle(f'Perturbation Analysis: {perturb_str}', fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
return fig
# =============================================================================
# Additional Plotting Functions
# =============================================================================
[docs]
def plot_reference_flow(
adata: AnnData,
basis: str = 'umap',
velocity_key: Optional[str] = None,
ax: Optional[plt.Axes] = None,
scale: float = 1.0,
color: str = 'black',
alpha: float = 0.8,
show_background: bool = True,
background_color: str = 'lightgray',
s: float = 10,
figsize: Tuple[float, float] = (8, 8),
title: str = 'Reference Velocity',
**quiver_kwargs
) -> plt.Axes:
"""
Plot reference velocity flow (e.g., from scVelo).
Parameters
----------
adata : AnnData
Annotated data with velocity
basis : str, optional (default: 'umap')
Embedding basis
velocity_key : str, optional
Key for velocity in obsm
ax : plt.Axes, optional
Axes to plot on
scale : float, optional (default: 1.0)
Scale factor for arrows
color : str, optional (default: 'black')
Arrow color
alpha : float, optional (default: 0.8)
Arrow transparency
show_background : bool, optional (default: True)
Show background scatter
background_color : str, optional (default: 'lightgray')
Background scatter color
s : float, optional (default: 10)
Scatter point size
figsize : tuple, optional
Figure size
title : str, optional
Plot title
**quiver_kwargs
Additional arguments for quiver plot
Returns
-------
plt.Axes
"""
if ax is None:
fig, ax = plt.subplots(figsize=figsize)
embedding = adata.obsm[f'X_{basis}']
if velocity_key is None:
velocity_key = f'velocity_{basis}'
if velocity_key not in adata.obsm:
raise ValueError(f"Velocity '{velocity_key}' not found in adata.obsm")
velocity = adata.obsm[velocity_key]
if show_background:
ax.scatter(embedding[:, 0], embedding[:, 1], c=background_color,
s=s, alpha=0.5, rasterized=True)
default_quiver = dict(headaxislength=4, headlength=5, headwidth=4,
linewidths=0.5, width=0.003)
default_quiver.update(quiver_kwargs)
ax.quiver(embedding[:, 0], embedding[:, 1],
velocity[:, 0], velocity[:, 1],
color=color, alpha=alpha, scale=scale,
**default_quiver)
ax.set_title(title, fontsize=12, fontweight='bold')
ax.axis('off')
ax.set_aspect('equal')
return ax
def plot_ode_perturbation_flow(
adata: AnnData,
basis: str = 'umap',
ax: Optional[plt.Axes] = None,
scale: float = 1.0,
color: str = '#9B59B6',
alpha: float = 0.8,
show_background: bool = True,
cluster_key: Optional[str] = None,
colors: Optional[Dict[str, str]] = None,
s: float = 10,
figsize: Tuple[float, float] = (8, 8),
title: str = 'ODE Perturbation Flow',
**quiver_kwargs
) -> plt.Axes:
"""
Plot ODE trajectory perturbation flow.
Parameters
----------
adata : AnnData
Annotated data with ODE perturbation flow
basis : str, optional (default: 'umap')
Embedding basis
ax : plt.Axes, optional
Axes to plot on
scale : float, optional (default: 1.0)
Arrow scale
color : str, optional (default: '#9B59B6')
Arrow color (purple)
alpha : float, optional (default: 0.8)
Arrow transparency
show_background : bool, optional (default: True)
Show background scatter
cluster_key : str, optional
Cluster key for coloring
colors : dict, optional
Cluster colors
s : float, optional (default: 10)
Point size
figsize : tuple, optional
Figure size
title : str, optional
Plot title
**quiver_kwargs
Additional quiver arguments
Returns
-------
plt.Axes
"""
flow_key = f'ode_perturbation_flow_{basis}'
if flow_key not in adata.obsm:
raise ValueError(f"ODE flow '{flow_key}' not found. "
"Run calculate_trajectory_flow first.")
return plot_flow(
adata, flow_key=flow_key, basis=basis, ax=ax,
scale=scale, color=color, alpha=alpha,
show_background=show_background, cluster_key=cluster_key,
colors=colors, s=s, figsize=figsize, title=title,
**quiver_kwargs
)
def visualize_ode_perturbation(
adata: AnnData,
wt_trajectories: Dict[str, np.ndarray],
perturbed_trajectories: Dict[str, np.ndarray],
gene_perturbations: Dict[str, float],
t_span: np.ndarray,
cluster_key: str = 'cell_type',
basis: str = 'umap',
velocity_key: Optional[str] = None,
colors: Optional[Dict[str, str]] = None,
method: str = 'hopfield',
figsize: Tuple[float, float] = (20, 10),
scale_flow: float = 1.0,
vm: float = 1.0
) -> plt.Figure:
"""
Create comprehensive visualization of ODE perturbation analysis.
Parameters
----------
adata : AnnData
Annotated data
wt_trajectories : dict
WT trajectories per cluster
perturbed_trajectories : dict
Perturbed trajectories per cluster
gene_perturbations : dict
Perturbation conditions
t_span : np.ndarray
Time points
cluster_key : str, optional
Cluster key
basis : str, optional
Embedding basis
velocity_key : str, optional
Reference velocity key
colors : dict, optional
Cluster colors
method : str, optional (default: 'hopfield')
Flow calculation method
figsize : tuple, optional
Figure size
scale_flow : float, optional
Arrow scale
vm : float, optional
Inner product colorscale max
Returns
-------
plt.Figure
"""
# Import and calculate flow
from ..dynamics.simulation import calculate_trajectory_flow
calculate_trajectory_flow(
adata, wt_trajectories, perturbed_trajectories,
cluster_key=cluster_key, basis=basis, method=method
)
# Calculate inner product
if velocity_key is None:
velocity_key = f'velocity_{basis}'
if velocity_key in adata.obsm:
flow_key = f'ode_perturbation_flow_{basis}'
calculate_inner_product(
adata, velocity_key, flow_key,
store_key='ode_perturbation_inner_product'
)
# Create figure
perturb_str = ', '.join([f"{k}={'KO' if v==0 else 'OE' if v>0 else v}"
for k, v in gene_perturbations.items()])
fig, axes = plt.subplots(2, 3, figsize=figsize)
embedding = adata.obsm[f'X_{basis}']
# Row 0, Col 0: Clusters
ax = axes[0, 0]
if colors is not None:
c = [colors.get(cl, 'gray') for cl in adata.obs[cluster_key]]
else:
c = adata.obs[cluster_key].astype('category').cat.codes
ax.scatter(embedding[:, 0], embedding[:, 1], c=c, s=10, alpha=0.7)
ax.set_title('Clusters', fontsize=12, fontweight='bold')
ax.axis('off')
# Row 0, Col 1: Reference velocity
ax = axes[0, 1]
try:
plot_reference_flow(adata, basis=basis, velocity_key=velocity_key,
ax=ax, scale=scale_flow*5, title='Reference Velocity')
except Exception:
ax.text(0.5, 0.5, 'No velocity data', ha='center', va='center',
transform=ax.transAxes)
ax.axis('off')
# Row 0, Col 2: ODE perturbation flow
ax = axes[0, 2]
plot_ode_perturbation_flow(
adata, basis=basis, ax=ax, scale=scale_flow,
cluster_key=cluster_key, colors=colors,
title=f'ODE Perturbation Flow\n({perturb_str})'
)
# Row 1, Col 0: Inner product
ax = axes[1, 0]
if 'ode_perturbation_inner_product' in adata.obs:
try:
norm = mpl_colors.TwoSlopeNorm(vmin=-vm, vcenter=0, vmax=vm)
except Exception:
norm = mpl_colors.Normalize(vmin=-vm, vmax=vm)
sc = ax.scatter(embedding[:, 0], embedding[:, 1],
c=adata.obs['ode_perturbation_inner_product'],
cmap='RdBu_r', norm=norm, s=15)
plt.colorbar(sc, ax=ax, shrink=0.6)
ax.set_title('Inner Product (ODE)', fontsize=12, fontweight='bold')
ax.axis('off')
else:
ax.text(0.5, 0.5, 'No inner product', ha='center', va='center',
transform=ax.transAxes)
ax.axis('off')
# Row 1, Col 1: Trajectory examples
ax = axes[1, 1]
genes = get_genes_used(adata)
gene_names = adata.var.index[genes]
cluster = list(wt_trajectories.keys())[0]
wt = wt_trajectories[cluster]
pert = perturbed_trajectories[cluster]
delta_final = np.abs(pert[-1] - wt[-1])
top_gene_idx = np.argsort(delta_final)[-3:]
for idx in top_gene_idx:
ax.plot(t_span, wt[:, idx], '-', label=f'{gene_names[idx]} (WT)')
ax.plot(t_span, pert[:, idx], '--', label=f'{gene_names[idx]} (Pert)')
ax.set_xlabel('Time')
ax.set_ylabel('Expression')
ax.set_title(f'Trajectory ({cluster})', fontsize=12, fontweight='bold')
ax.legend(fontsize=8)
ax.grid(True, alpha=0.3)
# Row 1, Col 2: Inner product by cluster
ax = axes[1, 2]
if 'ode_perturbation_inner_product' in adata.obs:
df = pd.DataFrame({
'Cluster': adata.obs[cluster_key],
'Inner Product': adata.obs['ode_perturbation_inner_product']
})
cluster_order = df.groupby('Cluster')['Inner Product'].median().sort_values().index
palette = [colors.get(c, 'gray') for c in cluster_order] if colors else None
sns.boxplot(data=df, x='Cluster', y='Inner Product',
order=cluster_order, palette=palette, ax=ax)
ax.axhline(0, color='gray', linestyle='--', alpha=0.5)
ax.set_title('Inner Product by Cluster', fontsize=12, fontweight='bold')
if len(cluster_order) > 5:
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
else:
ax.axis('off')
fig.suptitle(f'ODE Perturbation Analysis: {perturb_str}', fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
return fig