"""Dimensionality reduction and energy landscape embedding."""
from typing import Tuple, Optional
import numpy as np
import pickle
from anndata import AnnData
from .._utils.math import soften, sigmoid, int_sig_act_inv
from .._utils.io import get_matrix, to_numpy, get_genes_used
[docs]
def compute_umap(
    adata: AnnData,
    spliced_key: str = 'Ms',
    n_neighbors: int = 30,
    min_dist: float = 0.1,
    basis: str = 'umap',
    copy: bool = False,
    random_state: Optional[int] = None,
) -> Optional[AnnData]:
    """
    Compute UMAP embedding from gene expression data.

    Performs dimensionality reduction using UMAP on the selected gene expression
    layer (restricted to the genes returned by ``get_genes_used``). The fitted
    UMAP model is stored in adata.uns['scHopfield']['embedding'] (the
    'scHopfield' dict is created if absent) and the 2D coordinates are stored in
    adata.obsm[f'X_{basis}'].

    Parameters
    ----------
    adata : AnnData
        Annotated data object
    spliced_key : str, optional (default: 'Ms')
        Key in adata.layers for expression data to use
    n_neighbors : int, optional (default: 30)
        Number of neighbors for UMAP
    min_dist : float, optional (default: 0.1)
        Minimum distance parameter for UMAP
    basis : str, optional (default: 'umap')
        Name for the embedding basis (stored as 'X_{basis}' in obsm)
    copy : bool, optional (default: False)
        Whether to return a copy or modify in place
    random_state : int, optional (default: None)
        Seed forwarded to ``umap.UMAP``. ``None`` keeps UMAP's default
        (non-reproducible) behaviour; pass an int for reproducible embeddings.

    Returns
    -------
    Optional[AnnData]
        Returns AnnData if copy=True, otherwise modifies in place and returns None
    """
    # Imported lazily so umap-learn is only required when an embedding is computed.
    import umap

    adata = adata.copy() if copy else adata
    genes = get_genes_used(adata)
    X = to_numpy(get_matrix(adata, spliced_key, genes=genes))

    reducer = umap.UMAP(
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        n_components=2,
        random_state=random_state,
    )
    cells2d = reducer.fit_transform(X)

    # Keep the fitted model: energy_embedding/load_embedding reuse it for
    # inverse_transform/transform.
    adata.uns.setdefault('scHopfield', {})['embedding'] = reducer
    adata.obsm[f'X_{basis}'] = cells2d
    return adata if copy else None
[docs]
def energy_embedding(
    adata: AnnData,
    basis: str = 'umap',
    resolution: int = 50,
    cluster_key: str = 'cell_type',
    degradation_key: str = 'gamma',
    copy: bool = False
) -> Optional[AnnData]:
    """
    Compute energy landscape on 2D embedding space.

    For each cluster, creates a grid spanning the cluster's bounding box in the
    embedding space and computes the Hopfield energy at each grid point. The
    grid is transformed to the high-dimensional gene expression space using the
    inverse UMAP transform, and energies are computed using the
    cluster-specific interaction matrices.

    Stores grid coordinates and energy values in adata.uns['scHopfield'] with
    keys: 'grid_X_{cluster}', 'grid_Y_{cluster}', 'grid_energy_{cluster}',
    'grid_energy_interaction_{cluster}', 'grid_energy_degradation_{cluster}',
    'grid_energy_bias_{cluster}'. The inverse-transformed grid points of all
    clusters are stored in adata.uns['scHopfield']['highD_grid']. They are
    stacked in cluster order, with resolution**2 rows per cluster and one
    column per used gene.

    Parameters
    ----------
    adata : AnnData
        Annotated data object with computed UMAP embedding
    basis : str, optional (default: 'umap')
        Name of the embedding basis to use (from obsm['X_{basis}'])
    resolution : int, optional (default: 50)
        Number of grid points per dimension (creates resolution x resolution grid)
    cluster_key : str, optional (default: 'cell_type')
        Key in adata.obs for cluster labels
    degradation_key : str, optional (default: 'gamma')
        Key in adata.var for degradation rates (fallback if cluster-specific not found)
    copy : bool, optional (default: False)
        Whether to return a copy or modify in place

    Returns
    -------
    Optional[AnnData]
        Returns AnnData if copy=True, otherwise modifies in place and returns None

    Raises
    ------
    KeyError
        If the embedding, the cluster column, or a required per-cluster
        parameter (adata.varp['W_{cluster}'], adata.var['I_{cluster}'],
        degradation rates, sigmoid parameters) is missing.
    """
    adata = adata.copy() if copy else adata
    genes = get_genes_used(adata)
    store = adata.uns['scHopfield']
    cells2d = adata.obsm[f'X_{basis}']
    embedding = store['embedding']
    labels = adata.obs[cluster_key]
    clusters = labels.unique()
    n_points = resolution ** 2

    # One resolution x resolution grid per cluster over its 2D bounding box.
    grid_X, grid_Y = {}, {}
    for cluster in clusters:
        cidx = (labels == cluster).values
        minx, miny = np.min(cells2d[cidx], axis=0)
        maxx, maxy = np.max(cells2d[cidx], axis=0)
        grid_X[cluster], grid_Y[cluster] = np.mgrid[
            minx:maxx:resolution * 1j, miny:maxy:resolution * 1j
        ]

    # Map every grid point back to gene space in a single inverse_transform call
    # and clip negatives to 0 (presumably because expression is non-negative).
    all_grid_points = np.vstack(
        [np.c_[grid_X[k].ravel(), grid_Y[k].ravel()] for k in clusters]
    )
    highD_grid = np.maximum(embedding.inverse_transform(all_grid_points), 0)
    # Rows are grid points, not genes, so this cannot be stored in varm
    # (AnnData requires varm values to have n_vars rows).
    store['highD_grid'] = highD_grid

    threshold = adata.var['sigmoid_threshold'].values[genes]
    exponent = adata.var['sigmoid_exponent'].values[genes]
    for i, cluster in enumerate(clusters):
        x_grid = highD_grid[i * n_points:(i + 1) * n_points]
        sig_grid = sigmoid(x_grid, threshold[None, :], exponent[None, :])

        # varp entries span all n_vars genes; restrict to the genes used so the
        # matrix lines up with sig_grid / threshold / bias (no-op when all
        # genes are used).
        W = adata.varp[f'W_{cluster}'][genes][:, genes]
        bias_vector = adata.var[f'I_{cluster}'].values[genes]
        gamma_key = f'gamma_{cluster}'
        deg_column = gamma_key if gamma_key in adata.var else degradation_key
        g = adata.var[deg_column].values[genes]

        # Hopfield energy split into interaction, degradation and bias terms.
        e_int = -0.5 * np.sum((sig_grid @ W.T) * sig_grid, axis=1)
        integral = int_sig_act_inv(sig_grid, threshold, exponent)
        e_deg = np.sum(g[None, :] * integral, axis=1)
        e_bias = -np.sum(bias_vector[None, :] * sig_grid, axis=1)
        e_total = e_int + e_deg + e_bias

        shape = grid_X[cluster].shape
        store[f'grid_X_{cluster}'] = grid_X[cluster]
        store[f'grid_Y_{cluster}'] = grid_Y[cluster]
        store[f'grid_energy_{cluster}'] = soften(e_total.reshape(shape))
        store[f'grid_energy_interaction_{cluster}'] = soften(e_int.reshape(shape))
        store[f'grid_energy_degradation_{cluster}'] = soften(e_deg.reshape(shape))
        store[f'grid_energy_bias_{cluster}'] = soften(e_bias.reshape(shape))
    return adata if copy else None
[docs]
def save_embedding(adata: AnnData, filename: str):
    """
    Save UMAP embedding and energy grid data to file.

    Saves the UMAP model, the high-dimensional grid points and the per-cluster
    grid coordinates ('grid_X_*' / 'grid_Y_*' entries) to a pickle file for
    later loading with load_embedding().

    The high-dimensional grid is looked up in adata.uns['scHopfield']['highD_grid']
    first and, for backward compatibility, in adata.varm['highD_grid'].

    Parameters
    ----------
    adata : AnnData
        Annotated data object with computed embedding
    filename : str
        Path to save the embedding data (will be saved as pickle file)

    Raises
    ------
    KeyError
        If the UMAP model or the high-dimensional grid is not present.
    """
    store = adata.uns['scHopfield']
    embedding = store['embedding']
    if 'highD_grid' in store:
        highD_grid = store['highD_grid']
    elif 'highD_grid' in adata.varm:
        highD_grid = adata.varm['highD_grid']
    else:
        raise KeyError(
            "'highD_grid' not found; run energy_embedding() before save_embedding()"
        )

    emb_data = {'embedding': embedding, 'highD_grid': highD_grid}
    emb_data.update(
        {key: value for key, value in store.items()
         if key.startswith(('grid_X_', 'grid_Y_'))}
    )
    with open(filename, 'wb') as f:
        pickle.dump(emb_data, f, pickle.HIGHEST_PROTOCOL)
[docs]
def load_embedding(
    adata: AnnData,
    filename: str,
    basis: str = 'umap',
    copy: bool = False,
    spliced_key: str = 'Ms',
) -> Optional[AnnData]:
    """
    Load UMAP embedding and energy grid data from file.

    Loads the UMAP model and grid data from a pickle file saved with
    save_embedding() into adata.uns['scHopfield'] (including
    adata.uns['scHopfield']['highD_grid']). It then transforms the current
    expression data with the loaded UMAP model to get 2D coordinates.

    Security note: the file is read with ``pickle.load``, which can execute
    arbitrary code. Only load files from trusted sources.

    Parameters
    ----------
    adata : AnnData
        Annotated data object to load embedding into
    filename : str
        Path to the saved embedding pickle file
    basis : str, optional (default: 'umap')
        Name for the embedding basis (stored as 'X_{basis}' in obsm)
    copy : bool, optional (default: False)
        Whether to return a copy or modify in place
    spliced_key : str, optional (default: 'Ms')
        Key in adata.layers for the expression data to project with the
        loaded UMAP model.

    Returns
    -------
    Optional[AnnData]
        Returns AnnData if copy=True, otherwise modifies in place and returns None
    """
    adata = adata.copy() if copy else adata
    # NOTE(security): pickle.load on untrusted files can execute arbitrary code.
    with open(filename, 'rb') as f:
        emb_data = pickle.load(f)

    store = adata.uns.setdefault('scHopfield', {})
    embedding = emb_data['embedding']
    store['embedding'] = embedding
    # Rows are grid points, not genes, so this belongs in uns rather than varm.
    store['highD_grid'] = emb_data['highD_grid']
    for key, value in emb_data.items():
        if key.startswith(('grid_X_', 'grid_Y_')):
            store[key] = value

    genes = get_genes_used(adata)
    X = to_numpy(get_matrix(adata, spliced_key, genes=genes))
    adata.obsm[f'X_{basis}'] = embedding.transform(X)
    return adata if copy else None
def project_to_embedding(
    adata: 'AnnData',
    vectors: np.ndarray,
    basis: str = 'umap',
    method: str = 'dot_product',
    n_neighbors: int = 30,
    n_jobs: int = 4,
    spliced_key: Optional[str] = None,
    # CellOracle correlation specific parameters
    sigma_corr: float = 0.05,
    correlation_mode: str = 'sampled',
    sampled_fraction: float = 0.3,
    sampling_probs: Tuple[float, float] = (0.5, 0.1),
    random_seed: int = 42,
    verbose: bool = False
) -> np.ndarray:
    """
    Project gene-space vectors to embedding space.

    Parameters
    ----------
    adata : AnnData
        Annotated data with embedding
    vectors : np.ndarray or scipy sparse matrix
        Gene-space vectors (n_cells, n_genes_used) - can be velocities,
        delta_X, etc. Columns must line up with the genes returned by
        ``get_genes_used(adata)``.
    basis : str, optional (default: 'umap')
        Embedding to project onto (key in adata.obsm as X_{basis})
    method : str, optional (default: 'dot_product')
        Projection method:
        - 'dot_product': Gene-space KNN, weights neighbors by vector alignment.
        - 'correlation': Embedding-space KNN, CellOracle-style correlation.
    n_neighbors : int, optional (default: 30)
        Number of neighbors for projection (Note: correlation method often uses ~200)
    n_jobs : int, optional (default: 4)
        Parallel jobs for the scikit-learn nearest-neighbour search.
    spliced_key : str, optional
        Key for expression data. Defaults to adata.uns['scHopfield']['spliced_key']
        or 'Ms'. Falls back to adata.X when the layer does not exist.
    sigma_corr : float, optional (default: 0.05)
        'correlation' only: transition weights are exp(corr / sigma_corr).
    correlation_mode : str, optional (default: 'sampled')
        'correlation' only: 'sampled' correlates each cell against a random
        subset of its embedding neighbours. 'full' correlates it against all
        cells; only embedding neighbours are then used for the transitions.
    sampled_fraction : float, optional (default: 0.3)
        'sampled' mode: fraction of the n_neighbors + 1 neighbours kept per cell.
    sampling_probs : tuple of float, optional (default: (0.5, 0.1))
        'sampled' mode: sampling weights, linearly interpolated from the
        nearest to the farthest neighbour, then normalised.
    random_seed : int, optional (default: 42)
        'sampled' mode: seed of a local RNG. The global NumPy RNG is not touched.
    verbose : bool, optional (default: False)
        Print progress messages and show progress bars.

    Returns
    -------
    np.ndarray
        (n_cells, n_embedding_dims) displacement of each cell in the embedding.

    Raises
    ------
    ValueError
        If the embedding is missing, ``method`` or ``correlation_mode`` is
        unknown, or ``sampled_fraction`` selects no neighbours.
    """
    from scipy.sparse import issparse

    embedding_key = f'X_{basis}'
    if embedding_key not in adata.obsm:
        raise ValueError(f"Embedding {embedding_key} not found in adata.obsm")
    # Validate options before doing any expensive work.
    if method not in ('dot_product', 'correlation'):
        raise ValueError(f"Unknown method: '{method}'. Use 'dot_product' or 'correlation'.")
    if method == 'correlation' and correlation_mode not in ('sampled', 'full'):
        raise ValueError(f"Unknown correlation_mode: {correlation_mode}")
    embedding = np.asarray(adata.obsm[embedding_key])

    genes = get_genes_used(adata)
    if spliced_key is None:
        spliced_key = adata.uns.get('scHopfield', {}).get('spliced_key', 'Ms')
    source = adata.layers[spliced_key] if spliced_key in adata.layers else adata.X
    X = source[:, genes]
    if issparse(X):
        X = X.toarray()
    if issparse(vectors):
        vectors = vectors.toarray()

    if method == 'dot_product':
        if verbose:
            print("Projecting using gene-space dot product alignment...")
        return _project_dot_product(X, vectors, embedding, n_neighbors, n_jobs)

    if verbose:
        print("Projecting using embedding-space correlation (CellOracle style)...")
    return _project_correlation(
        X, vectors, embedding, n_neighbors, n_jobs, sigma_corr,
        correlation_mode, sampled_fraction, sampling_probs, random_seed, verbose,
    )


def _project_dot_product(X, vectors, embedding, n_neighbors, n_jobs):
    """Gene-space KNN projection: move each cell toward neighbours its vector points at."""
    from sklearn.neighbors import NearestNeighbors

    n_cells = embedding.shape[0]
    nn = NearestNeighbors(n_neighbors=n_neighbors + 1, n_jobs=n_jobs)
    nn.fit(X)
    distances, indices = nn.kneighbors(X)

    projected = np.zeros((n_cells, embedding.shape[1]), dtype=np.float32)
    for i in range(n_cells):
        # Column 0 is normally the query cell itself (X is also the fitted
        # data), so it is dropped.
        neighbors = indices[i, 1:]
        dists = distances[i, 1:]
        alignment = (vectors[i] * (X[neighbors] - X[i])).sum(axis=1)
        weights = np.exp(-dists / (np.median(dists) + 1e-10))
        # Only neighbours in the direction of the vector contribute.
        weights = weights * np.maximum(alignment, 0)
        weights = weights / (weights.sum() + 1e-10)
        offsets = embedding[neighbors] - embedding[i]
        projected[i] = (weights[:, None] * offsets).sum(axis=0)
    return projected


def _project_correlation(
    X, vectors, embedding, n_neighbors, n_jobs, sigma_corr,
    correlation_mode, sampled_fraction, sampling_probs, random_seed, verbose,
):
    """CellOracle-style projection via correlation-weighted transitions between embedding neighbours."""
    from sklearn.neighbors import NearestNeighbors

    n_cells = embedding.shape[0]
    k = n_neighbors + 1
    nn = NearestNeighbors(n_neighbors=k, n_jobs=n_jobs)
    nn.fit(embedding)
    # With no query argument, sklearn excludes each point from its own
    # neighbour list; columns are ordered by increasing distance.
    neigh_ixs = nn.kneighbors(return_distance=False)
    rows = np.arange(n_cells)[:, None]

    if correlation_mode == 'sampled':
        n_sampled = int(sampled_fraction * k)
        if n_sampled < 1:
            raise ValueError(
                f"sampled_fraction={sampled_fraction} selects no neighbours "
                f"out of n_neighbors + 1 = {k}"
            )
        # Local RNG: same draws as seeding the global RNG, without the side effect.
        rng = np.random.RandomState(random_seed)
        p = np.linspace(sampling_probs[0], sampling_probs[1], k)
        p = p / p.sum()
        sampling_ixs = np.stack(
            [rng.choice(k, size=n_sampled, replace=False, p=p) for _ in range(n_cells)],
            axis=0,
        )
        neigh_ixs = neigh_ixs[rows, sampling_ixs]
        corrcoef = _calculate_correlation_sampled(X, vectors, neigh_ixs, verbose=verbose)
    else:  # 'full' (validated by the caller)
        corrcoef = _calculate_correlation_full(X, vectors, verbose=verbose)
        np.fill_diagonal(corrcoef, 0)

    # Dense 0/1 indicator of the neighbours used for the transitions.
    knn_array = np.zeros((n_cells, n_cells))
    knn_array[rows, neigh_ixs] = 1

    if np.any(np.isnan(corrcoef)):
        corrcoef[np.isnan(corrcoef)] = 1
        if verbose:
            print("Warning: NaNs in correlation matrix corrected to 1s.")

    transition_prob = np.exp(corrcoef / sigma_corr) * knn_array
    transition_prob /= transition_prob.sum(axis=1, keepdims=True) + 1e-10
    return _calculate_embedding_shift(embedding, transition_prob, knn_array)
# =============================================================================
# Private helper functions
# =============================================================================
def _calculate_embedding_shift(
embedding: np.ndarray,
transition_prob: np.ndarray,
knn_array: np.ndarray
) -> np.ndarray:
"""
Calculate embedding shift from transition probabilities.
Follows CellOracle's calculate_embedding_shift logic.
"""
# Unitary vectors from each cell to all other cells
unitary_vectors = embedding.T[:, None, :] - embedding.T[:, :, None]
# Normalize to unit vectors
with np.errstate(divide='ignore', invalid='ignore'):
norms = np.linalg.norm(unitary_vectors, ord=2, axis=0)
unitary_vectors = unitary_vectors / (norms + 1e-10)
np.fill_diagonal(unitary_vectors[0], 0)
np.fill_diagonal(unitary_vectors[1], 0)
# Weighted sum of directions
delta_embedding = (transition_prob * unitary_vectors).sum(axis=2)
# Subtract baseline
knn_sum = knn_array.sum(axis=1, keepdims=True)
baseline = (knn_array * unitary_vectors).sum(axis=2) / (knn_sum.T + 1e-10)
delta_embedding = delta_embedding - baseline
return delta_embedding.T
def _calculate_correlation_sampled(
    X: np.ndarray,
    delta_X: np.ndarray,
    neigh_ixs: np.ndarray,
    verbose: bool = True
) -> np.ndarray:
    """
    Calculate correlation between delta_X and neighbor expression differences.

    For each cell i and each of its sampled neighbors j (j != i), computes
    corr(delta_X[i], X[j] - X[i]). All other entries stay 0.

    Parameters
    ----------
    X : np.ndarray
        (n_cells, n_genes) expression matrix.
    delta_X : np.ndarray
        (n_cells, n_genes) gene-space vectors, row-aligned with X.
    neigh_ixs : np.ndarray
        (n_cells, n_sampled) integer indices of each cell's sampled neighbors.
    verbose : bool, optional (default: True)
        Show a tqdm progress bar. tqdm is imported only in this case.

    Returns
    -------
    np.ndarray
        Dense (n_cells, n_cells) float32 correlation matrix.
    """
    n_cells = neigh_ixs.shape[0]
    corrcoef = np.zeros((n_cells, n_cells), dtype=np.float32)
    rows = range(n_cells)
    if verbose:
        # Lazy import: tqdm is only needed when a progress bar is requested.
        from tqdm.auto import tqdm
        rows = tqdm(rows, desc="Calculating correlations (sampled)")
    for i in rows:
        neighbors = neigh_ixs[i]
        corrs = _pearson_correlation_rows(delta_X[i:i + 1], X[neighbors] - X[i])
        # Never record a self-correlation.
        keep = neighbors != i
        corrcoef[i, neighbors[keep]] = corrs[keep]
    return corrcoef
def _calculate_correlation_full(
    X: np.ndarray,
    delta_X: np.ndarray,
    verbose: bool = True
) -> np.ndarray:
    """
    Calculate full correlation matrix between delta_X and expression differences.

    Entry (i, j) is corr(delta_X[i], X[j] - X[i]) for every pair of cells,
    including i == j; the caller is responsible for clearing the diagonal.

    Parameters
    ----------
    X : np.ndarray
        (n_cells, n_genes) expression matrix.
    delta_X : np.ndarray
        (n_cells, n_genes) gene-space vectors, row-aligned with X.
    verbose : bool, optional (default: True)
        Show a tqdm progress bar. tqdm is imported only in this case.

    Returns
    -------
    np.ndarray
        Dense (n_cells, n_cells) float32 correlation matrix.
    """
    n_cells = X.shape[0]
    corrcoef = np.zeros((n_cells, n_cells), dtype=np.float32)
    rows = range(n_cells)
    if verbose:
        # Lazy import: tqdm is only needed when a progress bar is requested.
        from tqdm.auto import tqdm
        rows = tqdm(rows, desc="Calculating correlations (full)")
    for i in rows:
        corrcoef[i, :] = _pearson_correlation_rows(delta_X[i:i + 1], X - X[i])
    return corrcoef
def _pearson_correlation_rows(a: np.ndarray, B: np.ndarray) -> np.ndarray:
"""
Compute Pearson correlation between vector a and each row of matrix B.
"""
a_centered = a - a.mean()
B_centered = B - B.mean(axis=1, keepdims=True)
ss_a = np.sum(a_centered ** 2)
ss_B = np.sum(B_centered ** 2, axis=1)
if ss_a < 1e-10:
return np.zeros(B.shape[0], dtype=np.float32)
numerator = (a_centered @ B_centered.T).flatten()
denominator = np.sqrt(ss_a) * np.sqrt(ss_B)
with np.errstate(divide='ignore', invalid='ignore'):
corrs = numerator / (denominator + 1e-10)
corrs[ss_B < 1e-10] = 0
return corrs