"""Velocity computation and validation."""
import numpy as np
from typing import Optional, Union
from anndata import AnnData
from .._utils.io import get_matrix, to_numpy, get_genes_used, ensure_sigmoid_layer
from .._utils.math import sigmoid
[docs]
def compute_reconstructed_velocity(
    adata: AnnData,
    cluster: Optional[str] = None,
    spliced_key: str = 'Ms',
    degradation_key: str = 'gamma',
    cluster_key: str = 'cell_type',
    layer_key: Optional[str] = None,
    copy: bool = False
) -> Optional[Union[AnnData, np.ndarray]]:
    """
    Compute reconstructed velocity from Hopfield model.

    The velocity is computed as: v = W @ sigmoid(X) - gamma * X + I

    Parameters
    ----------
    adata : AnnData
        Annotated data object with fitted interactions
    cluster : str, optional
        Cluster to compute velocity for. If None, computes for all cells
        using their respective cluster parameters
    spliced_key : str, optional (default: 'Ms')
        Key in adata.layers for spliced counts
    degradation_key : str, optional (default: 'gamma')
        Key in adata.var for degradation rates. Only used for clusters
        that have no ``gamma_<cluster>`` column in adata.var
    cluster_key : str, optional (default: 'cell_type')
        Key in adata.obs for cluster labels
    layer_key : str, optional
        If provided, stores reconstructed velocity in adata.layers[layer_key]
        If cluster is specified, uses f'{layer_key}_{cluster}'
    copy : bool, optional (default: False)
        If True (and layer_key is provided), work on a copy of adata
        instead of modifying it in-place

    Returns
    -------
    AnnData, np.ndarray or None
        If layer_key is provided and copy=False: None (modifies adata in-place)
        If layer_key is provided and copy=True: modified copy of adata
        If layer_key is None: np.ndarray with reconstructed velocities,
        one row per selected cell and one column per gene used

    Raises
    ------
    KeyError
        If ``adata.varp`` has no ``W_<cluster>`` entry for a cluster that
        contains cells (raised by the lookup itself).
    """
    if layer_key is not None and copy:
        adata = adata.copy()

    ensure_sigmoid_layer(adata, spliced_key)
    genes = get_genes_used(adata)
    # Count the selected genes through the var index so the count is right
    # whether `genes` is an integer index array or a boolean mask
    # (same approach as compute_velocity).
    n_genes = len(adata.var.index[genes])

    # Fetch the expression and sigmoid matrices once; they do not depend on
    # the cluster, so there is no reason to rebuild them per cluster.
    X_full = get_matrix(adata, spliced_key, genes=genes)
    sigmoid_full = get_matrix(adata, 'sigmoid', genes=genes)
    labels = adata.obs[cluster_key]

    if cluster is not None:
        # Compute for a specific cluster only
        cluster_mask = (labels == cluster).values
        reconstructed_v = _reconstruct_cluster_velocity(
            adata, cluster, cluster_mask, genes, n_genes,
            X_full, sigmoid_full, degradation_key,
        )
        if layer_key is None:
            return reconstructed_v

        # Store in a cluster-specific layer; rows of other clusters and
        # columns of unused genes are left untouched (zero if newly created).
        key = f'{layer_key}_{cluster}'
        if key not in adata.layers:
            adata.layers[key] = np.zeros((adata.n_obs, adata.n_vars), dtype=np.float32)
        cluster_indices = np.where(cluster_mask)[0]
        adata.layers[key][np.ix_(cluster_indices, genes)] = reconstructed_v
        return adata if copy else None

    # Compute for all cells using their respective cluster parameters
    reconstructed_v = np.zeros((adata.n_obs, n_genes), dtype=np.float32)
    for clust in labels.unique():
        cluster_mask = (labels == clust).values
        if not cluster_mask.any():
            # e.g. a NaN label never compares equal to itself: there is
            # nothing to compute and no parameters to look up.
            continue
        reconstructed_v[cluster_mask] = _reconstruct_cluster_velocity(
            adata, clust, cluster_mask, genes, n_genes,
            X_full, sigmoid_full, degradation_key,
        )

    if layer_key is None:
        return reconstructed_v

    # Store in layer
    if layer_key not in adata.layers:
        adata.layers[layer_key] = np.zeros((adata.n_obs, adata.n_vars), dtype=np.float32)
    adata.layers[layer_key][:, genes] = reconstructed_v
    return adata if copy else None


def _reconstruct_cluster_velocity(
    adata, cluster, cluster_mask, genes, n_genes,
    X_full, sigmoid_full, degradation_key,
):
    """Return W @ sigmoid(X) - gamma * X + I for the cells selected by cluster_mask,
    using the parameters stored for `cluster`."""
    W = adata.varp[f'W_{cluster}']
    # W may span all variables while the model uses only a subset of genes;
    # restrict it to the used genes so it lines up with the expression
    # columns (compute_velocity applies the same slicing).
    if W.shape[0] == adata.n_vars and n_genes != adata.n_vars:
        W = W[np.ix_(genes, genes)]

    # Bias vector: zero when no cluster-specific column exists
    i_col = f'I_{cluster}'
    if i_col in adata.var.columns:
        I_vec = adata.var[i_col].values[genes]
    else:
        I_vec = np.zeros(n_genes, dtype=np.float32)

    # Degradation rates: cluster-specific column first, shared column otherwise
    gamma_col = f'gamma_{cluster}'
    if gamma_col not in adata.var.columns:
        gamma_col = degradation_key
    gamma_vec = adata.var[gamma_col].values[genes]

    X = to_numpy(X_full[cluster_mask])
    sigmoid_vals = to_numpy(sigmoid_full[cluster_mask])

    # Cells are rows, hence the transposes around the W product
    return (W @ sigmoid_vals.T).T - gamma_vec * X + I_vec
[docs]
def validate_velocity(
    adata: AnnData,
    velocity_key: str = 'velocity',
    spliced_key: str = 'Ms',
    degradation_key: str = 'gamma',
    cluster_key: str = 'cell_type',
    return_mse: bool = True
) -> Union[float, dict]:
    """
    Compare the reconstructed velocity with the stored velocity estimates.

    For every cluster label found in ``adata.obs[cluster_key]`` the
    velocity is reconstructed with ``compute_reconstructed_velocity`` and
    the element-wise squared difference to ``adata.layers[velocity_key]``
    (restricted to the genes used) is accumulated.

    Parameters
    ----------
    adata : AnnData
        Annotated data object with fitted interactions and velocity
    velocity_key : str, optional (default: 'velocity')
        Key in adata.layers for original velocity
    spliced_key : str, optional (default: 'Ms')
        Key in adata.layers for spliced counts
    degradation_key : str, optional (default: 'gamma')
        Key in adata.var for degradation rates
    cluster_key : str, optional (default: 'cell_type')
        Key in adata.obs for cluster labels
    return_mse : bool, optional (default: True)
        If True, return a single MSE pooled over all cells and genes.
        If False, return a dict with one MSE per cluster

    Returns
    -------
    float or dict
        Pooled MSE (return_mse=True) or a mapping from cluster label to
        that cluster's MSE (return_mse=False)
    """
    genes_used = get_genes_used(adata)
    # The per-cluster table is only needed when the caller asked for it
    per_cluster_mse = None if return_mse else {}
    sum_squared = 0
    n_values = 0

    for label in adata.obs[cluster_key].unique():
        # Model prediction for the cells of this cluster
        predicted = compute_reconstructed_velocity(
            adata,
            cluster=label,
            spliced_key=spliced_key,
            degradation_key=degradation_key,
            cluster_key=cluster_key
        )

        # Reference velocity for the same cells and genes
        in_cluster = (adata.obs[cluster_key] == label).values
        observed = to_numpy(get_matrix(adata, velocity_key, genes=genes_used)[in_cluster])

        residual_sq = np.square(predicted - observed)
        label_mse = np.mean(residual_sq)
        if per_cluster_mse is not None:
            per_cluster_mse[label] = label_mse

        # Running totals for the pooled MSE (weights clusters by their size)
        sum_squared += np.sum(residual_sq)
        n_values += residual_sq.size

    if per_cluster_mse is not None:
        return per_cluster_mse
    return sum_squared / n_values
def compute_velocity(
    adata: AnnData,
    X: Optional[np.ndarray] = None,
    cluster: Optional[str] = None,
    cluster_key: str = 'cell_type',
    use_cluster_specific: bool = True,
    spliced_key: str = 'Ms',
    degradation_key: str = 'gamma',
) -> np.ndarray:
    """
    Compute Hopfield velocity at given expression state.

    v = W @ sigmoid(X) - gamma * X + I

    This is a unified function that replaces the various velocity computation
    functions that were previously in plotting/flow.py.

    Parameters
    ----------
    adata : AnnData
        Annotated data object with fitted interactions
    X : np.ndarray, optional
        Expression matrix (n_cells, n_genes) to compute velocity at.
        The caller is responsible for matching its rows to the cells the
        parameters apply to. If None, uses expression from
        adata.layers[spliced_key] (only the cells of `cluster` when a
        cluster other than 'all' is given).
    cluster : str, optional
        Specific cluster to use parameters from. If None and use_cluster_specific=True,
        iterates over all clusters using their respective parameters.
    cluster_key : str, optional (default: 'cell_type')
        Key in adata.obs for cluster labels
    use_cluster_specific : bool, optional (default: True)
        If True, use cluster-specific W/I/gamma. If False, use 'all' parameters.
    spliced_key : str, optional (default: 'Ms')
        Key in adata.layers for spliced counts (used if X is None)
    degradation_key : str, optional (default: 'gamma')
        Key in adata.var for the shared degradation rates, used when a
        cluster has no ``gamma_<cluster>`` column. If this column is
        missing as well, a rate of 1 is used for every gene.

    Returns
    -------
    np.ndarray
        Velocity matrix (n_cells, n_genes), dtype float32

    Raises
    ------
    ValueError
        If the W matrix of a cluster that has cells is missing from adata.varp.
    """
    genes_mask = get_genes_used(adata)
    gene_names = adata.var.index[genes_mask]
    n_genes = len(gene_names)

    # Get sigmoid parameters
    threshold = adata.var.loc[gene_names, 'sigmoid_threshold'].values
    exponent = adata.var.loc[gene_names, 'sigmoid_exponent'].values

    # Resolve the expression matrix when the caller did not provide one
    if X is None:
        X_full = get_matrix(adata, spliced_key, genes=genes_mask)
        if cluster is not None and cluster != 'all':
            # Specific cluster requested, slice to only those cells
            cluster_mask = (adata.obs[cluster_key] == cluster).values
            X = to_numpy(X_full[cluster_mask])
        else:
            # All cells
            X = to_numpy(X_full)

    # nan_to_num returns a new array, so a caller-provided X is not modified
    X = np.nan_to_num(X)
    # Derive the row count from the final X so it is defined on every path
    n_cells = X.shape[0]

    # Determine clusters to iterate over
    if cluster is not None:
        # Single cluster specified - use its parameters for all cells in X
        clusters = [cluster]
    elif use_cluster_specific:
        # All clusters - iterate and apply each cluster's parameters to its cells
        clusters = adata.obs[cluster_key].unique().tolist()
    else:
        # Use 'all' parameters for everything
        clusters = ['all']

    velocity = np.zeros((n_cells, n_genes), dtype=np.float32)
    for clust in clusters:
        # Determine which rows of X to process
        if cluster is not None or clust == 'all':
            # Explicit cluster or shared 'all' parameters: every row of X
            clust_mask = np.ones(n_cells, dtype=bool)
        else:
            # Iterating over clusters: rows of X are matched to adata's cells
            clust_mask = (adata.obs[cluster_key] == clust).values
        if not np.any(clust_mask):
            continue

        # Get parameters for this cluster
        W_key = f'W_{clust}'
        I_key = f'I_{clust}'
        gamma_key = f'gamma_{clust}'

        if W_key not in adata.varp:
            raise ValueError(f"W matrix '{W_key}' not found in adata.varp")
        W = adata.varp[W_key]
        # Slice W if it's full size so it lines up with the used genes
        if W.shape[0] == adata.n_vars:
            W = W[np.ix_(genes_mask, genes_mask)]

        # Get I vector (zero bias when the cluster has none stored)
        if I_key in adata.var.columns:
            I_vec = adata.var.loc[gene_names, I_key].values
        else:
            I_vec = np.zeros(n_genes, dtype=np.float32)

        # Get gamma: cluster-specific, then shared column, then ones
        if gamma_key in adata.var.columns:
            gamma = adata.var.loc[gene_names, gamma_key].values
        elif degradation_key in adata.var.columns:
            gamma = adata.var.loc[gene_names, degradation_key].values
        else:
            gamma = np.ones(n_genes, dtype=np.float32)

        # Compute velocity: v = W @ sigmoid(X) - gamma * X + I
        # (cells are rows, so the W product is written as sig_X @ W.T)
        X_clust = X[clust_mask]
        sig_X = sigmoid(X_clust, threshold, exponent)
        v_clust = (sig_X @ W.T) - (gamma * X_clust) + I_vec

        # Store results
        velocity[clust_mask] = v_clust

    return velocity
def compute_velocity_delta(
    adata: AnnData,
    perturbed_key: str = 'simulated_count',
    original_key: str = 'Ms',
    cluster_key: str = 'cell_type',
    use_cluster_specific: bool = True,
) -> np.ndarray:
    """
    Compute how much the velocity changes between two expression layers.

    For each cell the result is v(X_perturbed) - v(X_original), where both
    velocities come from ``compute_velocity`` evaluated on the genes used.

    Parameters
    ----------
    adata : AnnData
        Annotated data object with perturbation results
    perturbed_key : str, optional (default: 'simulated_count')
        Key in adata.layers for perturbed expression
    original_key : str, optional (default: 'Ms')
        Key in adata.layers for original expression
    cluster_key : str, optional (default: 'cell_type')
        Key in adata.obs for cluster labels
    use_cluster_specific : bool, optional (default: True)
        If True, each cluster's cells are evaluated with that cluster's
        parameters. If False, the 'all' parameters are used for every cell.

    Returns
    -------
    np.ndarray
        Velocity delta matrix (n_cells, n_genes), with the same shape and
        dtype as the original expression matrix

    Raises
    ------
    ValueError
        If ``perturbed_key`` is not a layer of adata.
    """
    if perturbed_key not in adata.layers:
        raise ValueError(f"'{perturbed_key}' not found in adata.layers. Run simulation first.")

    used_genes = get_genes_used(adata)

    # Dense expression matrices restricted to the genes used
    baseline = to_numpy(get_matrix(adata, original_key, genes=used_genes))
    perturbed = to_numpy(get_matrix(adata, perturbed_key, genes=used_genes))

    # Either one parameter set per cluster label, or the shared 'all' set
    labels = adata.obs[cluster_key].unique().tolist() if use_cluster_specific else ['all']

    delta = np.zeros_like(baseline)
    for label in labels:
        if label == 'all':
            rows = np.ones(adata.n_obs, dtype=bool)
        else:
            rows = (adata.obs[cluster_key] == label).values

        if np.any(rows):
            # Same parameters, two expression states: original first, then perturbed
            v_before = compute_velocity(adata, X=baseline[rows], cluster=label)
            v_after = compute_velocity(adata, X=perturbed[rows], cluster=label)
            delta[rows] = v_after - v_before

    return delta