This page was generated from docs/notebooks/05_perturbation_analysis.ipynb.

Perturbation Analysis

Dynamics Simulation

Simulate gene expression trajectories using the inferred Hopfield network dynamics.

Theory

Gene expression evolves according to:

\[\frac{dx}{dt} = W \cdot \sigma(x) - \gamma \cdot x + I\]

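Here x is the gene expression vector, W the inferred interaction matrix, σ the sigmoid activation fitted for each gene, γ the degradation rate and I the bias vector. scHopfield integrates this ODE system to simulate cellular dynamics.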

Basic Trajectory Simulation

Trajectories are simulated from a cell’s initial state by integrating the ODE above forward in time.

scHopfield supports two complementary strategies for in-silico perturbations:

ODE-based — integrates the full dynamical system

\[\frac{dx}{dt} = W \cdot \sigma(x) - \gamma \cdot x + I\]

with a fixed perturbation (e.g. Gata1 = 0 for knockout).

GRN propagation — propagates expression shifts through the learned network

without full time integration (analogous to the CellOracle approach).

This notebook uses the Paul et al. 2015 hematopoiesis dataset loaded through

the CellOracle tutorial object, which provides a graph-based embedding and

pseudotime ordering.

Topics covered:

  1. Data loading, velocity estimation, sigmoid fitting, GRN inference

  2. ODE-based single-cell trajectories (WT, KO, OE)

  3. Dataset-wide ODE perturbation shift

  4. GRN propagation-based perturbation

  5. Perturbation flow on embedding (CellOracle-style and Hopfield-style)

5.1 Setup

[ ]:
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import scvelo as scv
import torch

import celloracle as co
import scHopfield as sch

# Use Python's 'default' warning action for every warning category.
warnings.filterwarnings('default')

# Analysis parameters
CLUSTER_KEY    = 'paul15_clusters'   # adata.obs column holding the cluster labels
SPLICED_KEY    = 'Ms'                # expression layer used throughout (moments layer)
DEGRADATION_KEY = 'gamma'            # adata.var column holding degradation rates
BASIS          = 'draw_graph_fa'     # embedding basis used for velocity plots
VELOCITY_KEY   = 'velocity_S'        # layer in which the estimated velocity is stored
VELOCITY_SCALE = 500.0               # `scale` passed to the pseudotime velocity estimator
SCAFFOLD_REG   = 1e-1                # `scaffold_regularization` for GRN inference
N_EPOCHS       = 1000
BATCH_SIZE     = 128
DEVICE         = 'cuda' if torch.cuda.is_available() else 'cpu'

# Gene of interest for perturbation
GOI = 'Gata1'

# Paul15 cluster display order.
# The dataset contains 19 clusters; the label after '8Mk' is '9GMP'
# (there is no '9Mk' cluster), so the list must not contain '9Mk'.
CLUSTER_ORDER = [
    '1Ery', '2Ery', '3Ery', '4Ery', '5Ery', '6Ery', '7MEP',
    '8Mk', '9GMP', '10GMP', '11DC', '12Baso', '13Baso',
    '14Mo', '15Mo', '16Neu', '17Neu', '18Eos', '19Lymph',
]

print(f"Device: {DEVICE}")

/home/bernaljp/miniconda3/envs/SCH/lib/python3.11/site-packages/louvain/__init__.py:54: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  from pkg_resources import get_distribution, DistributionNotFound
which: no R in (/opt/slurm/puppet/bin:/opt/slurm/cluster/ibex/install-v2/RedHat-9/bin:/opt/slurm/scripts/bin:/usr/lpp/mmfs/bin:/home/bernaljp/miniconda3/envs/SCH/bin:/opt/slurm/puppet/bin:/opt/slurm/cluster/ibex/install-v2/RedHat-9/bin:/opt/slurm/scripts/bin:/usr/lpp/mmfs/bin:/home/bernaljp/miniconda3/condabin:/opt/slurm/puppet/bin:/usr/share/Modules/bin:/opt/slurm/cluster/ibex/install-v2/RedHat-9/bin:/opt/slurm/scripts/bin:/usr/lpp/mmfs/bin:/usr/local/bin:/usr/bin:/usr/local/sbin:/usr/sbin:/opt/slurm/scripts/bin:/opt/puppetlabs/bin:/home/bernaljp/.local/bin:/home/bernaljp/bin:/opt/slurm/scripts/bin:/home/bernaljp/.local/bin:/home/bernaljp/bin:/opt/slurm/scripts/bin:/home/bernaljp/.local/bin:/home/bernaljp/bin)
Device: cuda
[27]:
# IPython line magic: draw matplotlib figures inline in the cell output.
%matplotlib inline

Load CellOracle tutorial data (Paul et al. 2015)

[2]:
# Fetch the CellOracle tutorial object and work on an independent copy of its AnnData.
oracle_demo = co.data.load_tutorial_oracle_object()
adata = oracle_demo.adata.copy()

# Sanity report: dataset size and the cluster labels found under CLUSTER_KEY.
cell_types = sorted(adata.obs[CLUSTER_KEY].unique())
print(f"Loaded: {adata.n_obs} cells × {adata.n_vars} genes")
print(f"Cluster key: {CLUSTER_KEY}")
print(f"Cell types: {cell_types}")

Loaded: 2671 cells × 1999 genes
Cluster key: paul15_clusters
Cell types: ['10GMP', '11DC', '12Baso', '13Baso', '14Mo', '15Mo', '16Neu', '17Neu', '18Eos', '19Lymph', '1Ery', '2Ery', '3Ery', '4Ery', '5Ery', '6Ery', '7MEP', '8Mk', '9GMP']

Estimate velocities from pseudotime

[3]:
# scVelo moments: the normalized counts stand in for both the 'spliced' and
# 'unspliced' layers (presumably because scv.pp.moments expects both — confirm).
for layer_name in ('spliced', 'unspliced'):
    adata.layers[layer_name] = adata.layers['normalized_count']
scv.pp.moments(adata, n_pcs=30, n_neighbors=30)
# The placeholder 'unspliced' layer is removed again once the moments exist.
del adata.layers['unspliced']

# Derive a velocity layer from the pseudotime ordering and store it under VELOCITY_KEY.
velocity_options = dict(
    pseudotime_key='Pseudotime',
    spliced_key=SPLICED_KEY,
    connectivity_key='connectivities',
    scale=VELOCITY_SCALE,
    store_key=VELOCITY_KEY,
)
sch.pp.estimate_velocity_from_pseudotime(adata, **velocity_options)

# Velocity graph and its projection onto the embedding, needed for later plots.
scv.tl.velocity_graph(adata, vkey=VELOCITY_KEY, xkey=SPLICED_KEY, n_jobs=-1)
scv.tl.velocity_embedding(adata, basis=BASIS, vkey=VELOCITY_KEY)
# Expose the embedded velocity under the generic 'velocity_<basis>' key as well.
embedded_velocity_key = f'{VELOCITY_KEY}_{BASIS}'
adata.obsm[f'velocity_{BASIS}'] = adata.obsm[embedded_velocity_key]

print(f"Velocity estimated with scale={VELOCITY_SCALE}.")

computing neighbors
    finished (0:00:06) --> added
    'distances' and 'connectivities', weighted adjacency matrices (adata.obsp)
computing moments based on connectivities
    finished (0:00:00) --> added
    'Ms' and 'Mu', moments of un/spliced abundances (adata.layers)
computing velocity graph (using 32/32 cores)
/home/bernaljp/miniconda3/envs/SCH/lib/python3.11/site-packages/louvain/__init__.py:54: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  from pkg_resources import get_distribution, DistributionNotFound
/home/bernaljp/miniconda3/envs/SCH/lib/python3.11/site-packages/louvain/__init__.py:54: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  from pkg_resources import get_distribution, DistributionNotFound
/home/bernaljp/miniconda3/envs/SCH/lib/python3.11/site-packages/louvain/__init__.py:54: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  from pkg_resources import get_distribution, DistributionNotFound
/home/bernaljp/miniconda3/envs/SCH/lib/python3.11/site-packages/louvain/__init__.py:54: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  from pkg_resources import get_distribution, DistributionNotFound
/home/bernaljp/miniconda3/envs/SCH/lib/python3.11/site-packages/louvain/__init__.py:54: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  from pkg_resources import get_distribution, DistributionNotFound
/home/bernaljp/miniconda3/envs/SCH/lib/python3.11/site-packages/louvain/__init__.py:54: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  from pkg_resources import get_distribution, DistributionNotFound
/home/bernaljp/miniconda3/envs/SCH/lib/python3.11/site-packages/louvain/__init__.py:54: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  from pkg_resources import get_distribution, DistributionNotFound
/home/bernaljp/miniconda3/envs/SCH/lib/python3.11/site-packages/louvain/__init__.py:54: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  from pkg_resources import get_distribution, DistributionNotFound
/home/bernaljp/miniconda3/envs/SCH/lib/python3.11/site-packages/louvain/__init__.py:54: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  from pkg_resources import get_distribution, DistributionNotFound
/home/bernaljp/miniconda3/envs/SCH/lib/python3.11/site-packages/louvain/__init__.py:54: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  from pkg_resources import get_distribution, DistributionNotFound
/home/bernaljp/miniconda3/envs/SCH/lib/python3.11/site-packages/louvain/__init__.py:54: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  from pkg_resources import get_distribution, DistributionNotFound
/home/bernaljp/miniconda3/envs/SCH/lib/python3.11/site-packages/louvain/__init__.py:54: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  from pkg_resources import get_distribution, DistributionNotFound
/home/bernaljp/miniconda3/envs/SCH/lib/python3.11/site-packages/louvain/__init__.py:54: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  from pkg_resources import get_distribution, DistributionNotFound
/home/bernaljp/miniconda3/envs/SCH/lib/python3.11/site-packages/louvain/__init__.py:54: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  from pkg_resources import get_distribution, DistributionNotFound
/home/bernaljp/miniconda3/envs/SCH/lib/python3.11/site-packages/louvain/__init__.py:54: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  from pkg_resources import get_distribution, DistributionNotFound
/home/bernaljp/miniconda3/envs/SCH/lib/python3.11/site-packages/louvain/__init__.py:54: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  from pkg_resources import get_distribution, DistributionNotFound
/home/bernaljp/miniconda3/envs/SCH/lib/python3.11/site-packages/louvain/__init__.py:54: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  from pkg_resources import get_distribution, DistributionNotFound
/home/bernaljp/miniconda3/envs/SCH/lib/python3.11/site-packages/louvain/__init__.py:54: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  from pkg_resources import get_distribution, DistributionNotFound
/home/bernaljp/miniconda3/envs/SCH/lib/python3.11/site-packages/louvain/__init__.py:54: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  from pkg_resources import get_distribution, DistributionNotFound
/home/bernaljp/miniconda3/envs/SCH/lib/python3.11/site-packages/louvain/__init__.py:54: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  from pkg_resources import get_distribution, DistributionNotFound
/home/bernaljp/miniconda3/envs/SCH/lib/python3.11/site-packages/louvain/__init__.py:54: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  from pkg_resources import get_distribution, DistributionNotFound
/home/bernaljp/miniconda3/envs/SCH/lib/python3.11/site-packages/louvain/__init__.py:54: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  from pkg_resources import get_distribution, DistributionNotFound
/home/bernaljp/miniconda3/envs/SCH/lib/python3.11/site-packages/louvain/__init__.py:54: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  from pkg_resources import get_distribution, DistributionNotFound
/home/bernaljp/miniconda3/envs/SCH/lib/python3.11/site-packages/louvain/__init__.py:54: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  from pkg_resources import get_distribution, DistributionNotFound
/home/bernaljp/miniconda3/envs/SCH/lib/python3.11/site-packages/louvain/__init__.py:54: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  from pkg_resources import get_distribution, DistributionNotFound
/home/bernaljp/miniconda3/envs/SCH/lib/python3.11/site-packages/louvain/__init__.py:54: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  from pkg_resources import get_distribution, DistributionNotFound
/home/bernaljp/miniconda3/envs/SCH/lib/python3.11/site-packages/louvain/__init__.py:54: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  from pkg_resources import get_distribution, DistributionNotFound
/home/bernaljp/miniconda3/envs/SCH/lib/python3.11/site-packages/louvain/__init__.py:54: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  from pkg_resources import get_distribution, DistributionNotFound
/home/bernaljp/miniconda3/envs/SCH/lib/python3.11/site-packages/louvain/__init__.py:54: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  from pkg_resources import get_distribution, DistributionNotFound
/home/bernaljp/miniconda3/envs/SCH/lib/python3.11/site-packages/louvain/__init__.py:54: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  from pkg_resources import get_distribution, DistributionNotFound
/home/bernaljp/miniconda3/envs/SCH/lib/python3.11/site-packages/louvain/__init__.py:54: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  from pkg_resources import get_distribution, DistributionNotFound
/home/bernaljp/miniconda3/envs/SCH/lib/python3.11/site-packages/louvain/__init__.py:54: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  from pkg_resources import get_distribution, DistributionNotFound
    finished (0:00:24) --> added
    'velocity_S_graph', sparse matrix with cosine correlations (adata.uns)
computing velocity embedding
    finished (0:00:00) --> added
    'velocity_S_draw_graph_fa', embedded velocity vectors (adata.obsm)
Velocity estimated with scale=500.0.
[4]:
# Estimate degradation rates (gamma) from velocity / expression ratio:
# per gene, gamma = mean(|velocity|) / mean(|expression|), clipped to
# [GAMMA_MIN, GAMMA_MAX].
GAMMA_MIN, GAMMA_MAX = 0.1, 10.0
EXPR_EPS = 1e-6  # keeps the ratio finite for genes whose mean expression is zero

expression = adata.layers[SPLICED_KEY].copy()
velocities  = adata.layers[VELOCITY_KEY]

# np.asarray(...).ravel() keeps the per-gene means 1-D even when a layer is a
# sparse matrix, whose .mean(axis=0) returns a 2-D np.matrix.
mean_expr = np.asarray(np.abs(expression).mean(axis=0)).ravel() + EXPR_EPS
mean_vel  = np.asarray(np.abs(velocities).mean(axis=0)).ravel()
raw_gamma = mean_vel / mean_expr
gamma     = np.clip(raw_gamma, GAMMA_MIN, GAMMA_MAX)
adata.var[DEGRADATION_KEY] = gamma

print(f"Gamma range: [{gamma.min():.3f}, {gamma.max():.3f}]  |  median={np.median(gamma):.3f}")

# Clipping is otherwise silent: report how many genes hit a bound, because a
# fully clipped vector carries no gene-specific information.
n_clipped = int(((raw_gamma < GAMMA_MIN) | (raw_gamma > GAMMA_MAX)).sum())
if n_clipped:
    print(f"Note: {n_clipped}/{gamma.size} genes were clipped to "
          f"[{GAMMA_MIN}, {GAMMA_MAX}]; consider adjusting VELOCITY_SCALE or the bounds.")

Gamma range: [0.100, 0.100]  |  median=0.100

Sigmoid fitting

[5]:
# Flag every gene as used by scHopfield (no gene filtering in this notebook).
adata.var['scHopfield_used'] = True  # use all genes

# Fit the sigmoids for the flagged genes, then run compute_sigmoid on the same layer.
used_gene_mask = adata.var['scHopfield_used'].values
sch.pp.fit_all_sigmoids(adata, genes=used_gene_mask, spliced_key=SPLICED_KEY)
sch.pp.compute_sigmoid(adata, spliced_key=SPLICED_KEY)

# Summarise the fit error ('sigmoid_mse' column) over the flagged genes.
mse = adata.var.loc[adata.var['scHopfield_used'], 'sigmoid_mse']
print(f"Sigmoid MSE: mean={mse.mean():.4f}, median={mse.median():.4f}")

/home/bernaljp/packages/scHopfield/scHopfield/_utils/math.py:93: RuntimeWarning: divide by zero encountered in divide
  ty = np.log(y / (1 - y))
/home/bernaljp/packages/scHopfield/scHopfield/_utils/math.py:93: RuntimeWarning: divide by zero encountered in log
  ty = np.log(y / (1 - y))
Sigmoid MSE: mean=0.0025, median=0.0016
[28]:
# Plot sigmoid fits for key regulatory genes (only those present in the data)
candidate_genes = ['Gata1', 'Klf1', 'Gata2', 'Spi1', 'Cebpa', 'Mpo']
genes_to_plot = [g for g in candidate_genes if g in adata.var_names]

if not genes_to_plot:
    # Without this guard n_cols would be 0 and the row count below would
    # raise ZeroDivisionError.
    print("None of the requested genes are present in adata.var_names; nothing to plot.")
else:
    n_cols = min(3, len(genes_to_plot))
    n_rows = (len(genes_to_plot) + n_cols - 1) // n_cols  # ceiling division
    # squeeze=False always yields a 2-D axes array, so one flattening works
    # for any grid shape (including a single panel).
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows),
                             tight_layout=True, squeeze=False)
    axes_flat = axes.ravel()
    for ax, gene in zip(axes_flat, genes_to_plot):
        sch.pl.plot_sigmoid_fit(adata, gene=gene, ax=ax,
                                spliced_key=SPLICED_KEY, show_zeros=False)
    # Hide any grid cells left over after the last gene.
    for ax in axes_flat[len(genes_to_plot):]:
        ax.axis('off')
    fig.suptitle('Sigmoid fits for key regulatory genes', y=1.02)
    plt.show()

../_images/notebooks_05_perturbation_analysis_13_0.png

GRN scaffold from CellOracle base GRN

[7]:
# Load the CellOracle mouse scATAC-atlas base GRN and discard the peak identifier column.
base_GRN = co.data.load_mouse_scATAC_atlas_base_GRN()
base_GRN = base_GRN.drop(columns=['peak_id'])

# Square gene x gene scaffold, all zeros: rows are set per TF, columns per target.
gene_names = adata.var.index[adata.var['scHopfield_used'].values]
scaffold = pd.DataFrame(0, index=gene_names, columns=gene_names)

# Names are matched case-insensitively; these maps recover the original spelling.
index_map = {g.lower(): g for g in scaffold.index}
col_map   = {g.lower(): g for g in scaffold.columns}
tfs = list(set(base_GRN.columns.str.lower()) & set(scaffold.index.str.lower()))
targets = list(set(base_GRN['gene_short_name'].str.lower().values) & set(scaffold.columns.str.lower()))

# First base-GRN column for each lower-cased name (same choice as taking the
# first case-insensitive match while scanning the columns in order).
first_column_for = {}
for column in base_GRN.columns:
    first_column_for.setdefault(column.lower(), column)

for tf in tfs:
    # Targets whose entry in this TF's column equals 1.
    bound_targets = base_GRN.loc[base_GRN[first_column_for[tf]] == 1, 'gene_short_name']
    for target in bound_targets:
        target_name = col_map.get(target.lower())
        if target_name is not None:
            scaffold.loc[index_map[tf], target_name] = 1

print(f"Scaffold: {len(tfs)} TFs, {len(targets)} targets, {int(scaffold.sum().sum())} potential edges")

Scaffold: 90 TFs, 1857 targets, 75325 potential edges

GRN inference

[8]:
# Training is stochastic (mini-batches, balanced sampling), so fix the RNG
# seeds before fitting to make reruns comparable.
# NOTE(review): assumes fit_interactions draws from the global NumPy / PyTorch
# generators — confirm against the scHopfield implementation.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Infer the interaction matrix W and bias vector I for every cluster.
sch.inf.fit_interactions(
    adata,
    cluster_key=CLUSTER_KEY,
    spliced_key=SPLICED_KEY,
    velocity_key=VELOCITY_KEY,
    degradation_key=DEGRADATION_KEY,
    n_epochs=N_EPOCHS,
    batch_size=BATCH_SIZE,
    device=DEVICE,
    refit_gamma=True,
    # scaffold is indexed [TF, target]; it is passed transposed.
    w_scaffold=scaffold.values.T,
    scaffold_regularization=SCAFFOLD_REG,
    reconstruction_regularization=100,
    bias_regularization=1,
    only_TFs=True,
    w_threshold=1e-12,
    skip_all=True,
    learning_rate=0.1,
    use_plateau_scheduler=True,
    plateau_patience=100,
    plateau_factor=0.1,
    balanced_sampling=True,
    drop_last=True,
    include_neighbors=True,
    neighbor_fraction=0.2,
    get_plots=False,
)

clusters = list(adata.obs[CLUSTER_KEY].unique())
print(f"GRN inference complete for {len(clusters)} clusters.")

Inferring interaction matrix W and bias vector I for cluster 7MEP
  Including 402 neighboring cells (167 cluster + 402 neighbors = 569 total)
/home/bernaljp/packages/scHopfield/scHopfield/inference/optimizer.py:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
  self.register_buffer('mask', torch.tensor(mask, dtype=torch.float32, device=device))
Training Epochs:   1%|          | 9/1000 [00:00<01:12, 13.62it/s]
[Epoch 1/1000] Total Loss: 609.155510, Reconstruction Loss: 91.877766, LR: 1.00e-01, Batch size: 128
Training Epochs:  11%|█▏        | 113/1000 [00:02<00:11, 77.17it/s]
[Epoch 101/1000] Total Loss: 139.680561, Reconstruction Loss: 0.456839, LR: 1.00e-01, Batch size: 128
Training Epochs:  21%|██        | 209/1000 [00:03<00:10, 77.90it/s]
[Epoch 201/1000] Total Loss: 12.805837, Reconstruction Loss: 0.011106, LR: 1.00e-02, Batch size: 128
Training Epochs:  31%|███▏      | 313/1000 [00:04<00:08, 77.77it/s]
[Epoch 301/1000] Total Loss: 1.276333, Reconstruction Loss: 0.001566, LR: 1.00e-03, Batch size: 128
Training Epochs:  41%|████      | 409/1000 [00:05<00:07, 77.83it/s]
[Epoch 401/1000] Total Loss: 0.126107, Reconstruction Loss: 0.001534, LR: 1.00e-04, Batch size: 128
Training Epochs:  51%|█████▏    | 513/1000 [00:07<00:06, 77.72it/s]
[Epoch 501/1000] Total Loss: 0.124727, Reconstruction Loss: 0.001458, LR: 1.00e-04, Batch size: 128
Training Epochs:  61%|██████    | 609/1000 [00:08<00:04, 79.05it/s]
[Epoch 601/1000] Total Loss: 0.013114, Reconstruction Loss: 0.001465, LR: 1.00e-05, Batch size: 128
Training Epochs:  71%|███████▏  | 714/1000 [00:09<00:03, 79.60it/s]
[Epoch 701/1000] Total Loss: 0.002784, Reconstruction Loss: 0.001483, LR: 1.00e-06, Batch size: 128
Training Epochs:  81%|████████  | 811/1000 [00:11<00:02, 79.67it/s]
[Epoch 801/1000] Total Loss: 0.002675, Reconstruction Loss: 0.001483, LR: 1.00e-06, Batch size: 128
Training Epochs:  92%|█████████▏| 915/1000 [00:12<00:01, 78.02it/s]
[Epoch 901/1000] Total Loss: 0.002693, Reconstruction Loss: 0.001496, LR: 1.00e-06, Batch size: 128
Training Epochs: 100%|██████████| 1000/1000 [00:13<00:00, 74.18it/s]
/home/bernaljp/packages/scHopfield/scHopfield/inference/optimizer.py:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
  self.register_buffer('mask', torch.tensor(mask, dtype=torch.float32, device=device))
[Epoch 1000/1000] Total Loss: 0.002727, Reconstruction Loss: 0.001511, LR: 1.00e-06, Batch size: 128
Inferring interaction matrix W and bias vector I for cluster 15Mo
  Including 581 neighboring cells (186 cluster + 581 neighbors = 767 total)
Training Epochs:   0%|          | 0/1000 [00:00<?, ?it/s]
[Epoch 1/1000] Total Loss: 538.838196, Reconstruction Loss: 62.778238, LR: 1.00e-01, Batch size: 128
Training Epochs:  11%|█         | 111/1000 [00:01<00:14, 63.26it/s]
[Epoch 101/1000] Total Loss: 139.759558, Reconstruction Loss: 0.440999, LR: 1.00e-01, Batch size: 128
Training Epochs:  21%|██        | 209/1000 [00:03<00:12, 63.30it/s]
[Epoch 201/1000] Total Loss: 12.295334, Reconstruction Loss: 0.016251, LR: 1.00e-02, Batch size: 128
Training Epochs:  31%|███       | 307/1000 [00:04<00:11, 62.57it/s]
[Epoch 301/1000] Total Loss: 1.128002, Reconstruction Loss: 0.001945, LR: 1.00e-03, Batch size: 128
Training Epochs:  41%|████      | 412/1000 [00:06<00:09, 62.29it/s]
[Epoch 401/1000] Total Loss: 1.152868, Reconstruction Loss: 0.001676, LR: 1.00e-03, Batch size: 128
Training Epochs:  51%|█████     | 510/1000 [00:08<00:07, 62.27it/s]
[Epoch 501/1000] Total Loss: 0.130117, Reconstruction Loss: 0.001528, LR: 1.00e-04, Batch size: 128
Training Epochs:  61%|██████    | 608/1000 [00:09<00:06, 62.10it/s]
[Epoch 601/1000] Total Loss: 0.014284, Reconstruction Loss: 0.001637, LR: 1.00e-05, Batch size: 128
Training Epochs:  71%|███████▏  | 713/1000 [00:11<00:04, 62.21it/s]
[Epoch 701/1000] Total Loss: 0.046023, Reconstruction Loss: 0.033146, LR: 1.00e-05, Batch size: 128
Training Epochs:  81%|████████  | 811/1000 [00:12<00:03, 62.18it/s]
[Epoch 801/1000] Total Loss: 0.002575, Reconstruction Loss: 0.001372, LR: 1.00e-06, Batch size: 128
Training Epochs:  91%|█████████ | 909/1000 [00:14<00:01, 62.19it/s]
[Epoch 901/1000] Total Loss: 0.002905, Reconstruction Loss: 0.001677, LR: 1.00e-06, Batch size: 128
Training Epochs: 100%|██████████| 1000/1000 [00:16<00:00, 62.46it/s]
/home/bernaljp/packages/scHopfield/scHopfield/inference/optimizer.py:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
  self.register_buffer('mask', torch.tensor(mask, dtype=torch.float32, device=device))
[Epoch 1000/1000] Total Loss: 0.002899, Reconstruction Loss: 0.001681, LR: 1.00e-06, Batch size: 128
Inferring interaction matrix W and bias vector I for cluster 3Ery
  Including 701 neighboring cells (246 cluster + 701 neighbors = 947 total)
Training Epochs:   0%|          | 0/1000 [00:00<?, ?it/s]
[Epoch 1/1000] Total Loss: 525.634530, Reconstruction Loss: 68.826119, LR: 1.00e-01, Batch size: 128
Training Epochs:  10%|█         | 105/1000 [00:02<00:20, 44.72it/s]
[Epoch 101/1000] Total Loss: 138.873123, Reconstruction Loss: 0.488235, LR: 1.00e-01, Batch size: 128
Training Epochs:  20%|██        | 205/1000 [00:04<00:18, 44.11it/s]
[Epoch 201/1000] Total Loss: 14.462436, Reconstruction Loss: 0.015452, LR: 1.00e-02, Batch size: 128
Training Epochs:  30%|███       | 305/1000 [00:06<00:15, 44.08it/s]
[Epoch 301/1000] Total Loss: 1.268832, Reconstruction Loss: 0.002057, LR: 1.00e-03, Batch size: 128
Training Epochs:  40%|████      | 405/1000 [00:09<00:13, 44.16it/s]
[Epoch 401/1000] Total Loss: 0.122728, Reconstruction Loss: 0.001337, LR: 1.00e-04, Batch size: 128
Training Epochs:  51%|█████     | 510/1000 [00:11<00:10, 45.02it/s]
[Epoch 501/1000] Total Loss: 0.014770, Reconstruction Loss: 0.001524, LR: 1.00e-05, Batch size: 128
Training Epochs:  61%|██████    | 610/1000 [00:13<00:08, 44.91it/s]
[Epoch 601/1000] Total Loss: 0.014905, Reconstruction Loss: 0.001558, LR: 1.00e-05, Batch size: 128
Training Epochs:  71%|███████   | 710/1000 [00:15<00:06, 45.10it/s]
[Epoch 701/1000] Total Loss: 0.002500, Reconstruction Loss: 0.001272, LR: 1.00e-06, Batch size: 128
Training Epochs:  81%|████████  | 810/1000 [00:18<00:04, 45.13it/s]
[Epoch 801/1000] Total Loss: 0.002746, Reconstruction Loss: 0.001492, LR: 1.00e-06, Batch size: 128
Training Epochs:  91%|█████████ | 910/1000 [00:20<00:01, 45.13it/s]
[Epoch 901/1000] Total Loss: 0.003025, Reconstruction Loss: 0.001752, LR: 1.00e-06, Batch size: 128
Training Epochs: 100%|██████████| 1000/1000 [00:22<00:00, 44.77it/s]
/home/bernaljp/packages/scHopfield/scHopfield/inference/optimizer.py:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
  self.register_buffer('mask', torch.tensor(mask, dtype=torch.float32, device=device))
[Epoch 1000/1000] Total Loss: 0.002575, Reconstruction Loss: 0.001333, LR: 1.00e-06, Batch size: 128
Inferring interaction matrix W and bias vector I for cluster 4Ery
  Including 730 neighboring cells (124 cluster + 730 neighbors = 854 total)
Training Epochs:   0%|          | 0/1000 [00:00<?, ?it/s]
[Epoch 1/1000] Total Loss: 536.244888, Reconstruction Loss: 69.043932, LR: 1.00e-01, Batch size: 128
Training Epochs:  11%|█         | 108/1000 [00:02<00:16, 52.84it/s]
[Epoch 101/1000] Total Loss: 136.876204, Reconstruction Loss: 0.452873, LR: 1.00e-01, Batch size: 128
Training Epochs:  21%|██        | 210/1000 [00:03<00:14, 52.95it/s]
[Epoch 201/1000] Total Loss: 13.560978, Reconstruction Loss: 0.012148, LR: 1.00e-02, Batch size: 128
Training Epochs:  31%|███       | 307/1000 [00:05<00:15, 45.74it/s]
[Epoch 301/1000] Total Loss: 1.391287, Reconstruction Loss: 0.001285, LR: 1.00e-03, Batch size: 128
Training Epochs:  41%|████      | 409/1000 [00:07<00:11, 52.96it/s]
[Epoch 401/1000] Total Loss: 0.118935, Reconstruction Loss: 0.001132, LR: 1.00e-04, Batch size: 128
Training Epochs:  51%|█████     | 511/1000 [00:09<00:09, 52.96it/s]
[Epoch 501/1000] Total Loss: 0.119260, Reconstruction Loss: 0.001161, LR: 1.00e-04, Batch size: 128
Training Epochs:  61%|██████    | 607/1000 [00:11<00:07, 52.94it/s]
[Epoch 601/1000] Total Loss: 0.122002, Reconstruction Loss: 0.001189, LR: 1.00e-04, Batch size: 128
Training Epochs:  71%|███████   | 709/1000 [00:13<00:05, 52.92it/s]
[Epoch 701/1000] Total Loss: 0.012550, Reconstruction Loss: 0.001074, LR: 1.00e-05, Batch size: 128
Training Epochs:  81%|████████  | 811/1000 [00:15<00:03, 52.91it/s]
[Epoch 801/1000] Total Loss: 0.002449, Reconstruction Loss: 0.001267, LR: 1.00e-06, Batch size: 128
Training Epochs:  91%|█████████ | 907/1000 [00:17<00:01, 52.91it/s]
[Epoch 901/1000] Total Loss: 0.002390, Reconstruction Loss: 0.001174, LR: 1.00e-06, Batch size: 128
Training Epochs: 100%|██████████| 1000/1000 [00:19<00:00, 52.49it/s]
/home/bernaljp/packages/scHopfield/scHopfield/inference/optimizer.py:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
  self.register_buffer('mask', torch.tensor(mask, dtype=torch.float32, device=device))
[Epoch 1000/1000] Total Loss: 0.002350, Reconstruction Loss: 0.001144, LR: 1.00e-06, Batch size: 128
Inferring interaction matrix W and bias vector I for cluster 2Ery
  Including 482 neighboring cells (329 cluster + 482 neighbors = 811 total)
Training Epochs:   1%|          | 6/1000 [00:00<00:19, 51.02it/s]
[Epoch 1/1000] Total Loss: 512.708354, Reconstruction Loss: 53.270665, LR: 1.00e-01, Batch size: 128
Training Epochs:  11%|█         | 108/1000 [00:02<00:16, 52.64it/s]
[Epoch 101/1000] Total Loss: 136.463707, Reconstruction Loss: 0.385647, LR: 1.00e-01, Batch size: 128
Training Epochs:  20%|██        | 205/1000 [00:04<00:18, 43.12it/s]
[Epoch 201/1000] Total Loss: 13.812432, Reconstruction Loss: 0.013796, LR: 1.00e-02, Batch size: 128
Training Epochs:  30%|███       | 305/1000 [00:06<00:16, 41.94it/s]
[Epoch 301/1000] Total Loss: 1.424286, Reconstruction Loss: 0.003600, LR: 1.00e-03, Batch size: 128
Training Epochs:  40%|████      | 405/1000 [00:08<00:14, 41.93it/s]
[Epoch 401/1000] Total Loss: 0.122108, Reconstruction Loss: 0.004466, LR: 1.00e-04, Batch size: 128
Training Epochs:  50%|█████     | 505/1000 [00:11<00:11, 41.95it/s]
[Epoch 501/1000] Total Loss: 0.013510, Reconstruction Loss: 0.002228, LR: 1.00e-05, Batch size: 128
Training Epochs:  60%|██████    | 605/1000 [00:13<00:09, 41.87it/s]
[Epoch 601/1000] Total Loss: 0.014668, Reconstruction Loss: 0.003431, LR: 1.00e-05, Batch size: 128
Training Epochs:  70%|███████   | 705/1000 [00:15<00:07, 41.98it/s]
[Epoch 701/1000] Total Loss: 0.003955, Reconstruction Loss: 0.002734, LR: 1.00e-06, Batch size: 128
Training Epochs:  80%|████████  | 805/1000 [00:18<00:04, 41.93it/s]
[Epoch 801/1000] Total Loss: 0.004509, Reconstruction Loss: 0.003286, LR: 1.00e-06, Batch size: 128
Training Epochs:  91%|█████████ | 908/1000 [00:20<00:01, 48.74it/s]
[Epoch 901/1000] Total Loss: 0.003324, Reconstruction Loss: 0.002104, LR: 1.00e-06, Batch size: 128
Training Epochs: 100%|██████████| 1000/1000 [00:22<00:00, 44.60it/s]
/home/bernaljp/packages/scHopfield/scHopfield/inference/optimizer.py:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
  self.register_buffer('mask', torch.tensor(mask, dtype=torch.float32, device=device))
[Epoch 1000/1000] Total Loss: 0.003930, Reconstruction Loss: 0.002725, LR: 1.00e-06, Batch size: 128
Inferring interaction matrix W and bias vector I for cluster 17Neu
  Including 127 neighboring cells (22 cluster + 127 neighbors = 149 total)
Training Epochs:   3%|▎         | 30/1000 [00:00<00:03, 299.04it/s]
[Epoch 1/1000] Total Loss: 621.897522, Reconstruction Loss: 41.692261, LR: 1.00e-01, Batch size: 128
Training Epochs:  15%|█▌        | 154/1000 [00:00<00:02, 305.23it/s]
[Epoch 101/1000] Total Loss: 136.151749, Reconstruction Loss: 0.710504, LR: 1.00e-01, Batch size: 128
Training Epochs:  25%|██▍       | 247/1000 [00:00<00:02, 305.48it/s]
[Epoch 201/1000] Total Loss: 33.380272, Reconstruction Loss: 0.131969, LR: 1.00e-02, Batch size: 128
Training Epochs:  34%|███▍      | 340/1000 [00:01<00:02, 305.42it/s]
[Epoch 301/1000] Total Loss: 11.830083, Reconstruction Loss: 0.054620, LR: 1.00e-02, Batch size: 128
Training Epochs:  43%|████▎     | 434/1000 [00:01<00:01, 302.46it/s]
[Epoch 401/1000] Total Loss: 12.702967, Reconstruction Loss: 0.040102, LR: 1.00e-02, Batch size: 128
Training Epochs:  56%|█████▌    | 561/1000 [00:01<00:01, 310.21it/s]
[Epoch 501/1000] Total Loss: 1.765706, Reconstruction Loss: 0.029421, LR: 1.00e-03, Batch size: 128
Training Epochs:  66%|██████▌   | 657/1000 [00:02<00:01, 311.45it/s]
[Epoch 601/1000] Total Loss: 1.327378, Reconstruction Loss: 0.020119, LR: 1.00e-03, Batch size: 128
Training Epochs:  75%|███████▌  | 753/1000 [00:02<00:00, 311.80it/s]
[Epoch 701/1000] Total Loss: 1.350153, Reconstruction Loss: 0.043456, LR: 1.00e-03, Batch size: 128
Training Epochs:  85%|████████▍ | 849/1000 [00:02<00:00, 312.22it/s]
[Epoch 801/1000] Total Loss: 0.681535, Reconstruction Loss: 0.019119, LR: 1.00e-04, Batch size: 128
Training Epochs:  94%|█████████▍| 945/1000 [00:03<00:00, 312.29it/s]
[Epoch 901/1000] Total Loss: 0.164094, Reconstruction Loss: 0.033097, LR: 1.00e-04, Batch size: 128
Training Epochs: 100%|██████████| 1000/1000 [00:03<00:00, 308.11it/s]
/home/bernaljp/packages/scHopfield/scHopfield/inference/optimizer.py:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
  self.register_buffer('mask', torch.tensor(mask, dtype=torch.float32, device=device))
[Epoch 1000/1000] Total Loss: 0.168568, Reconstruction Loss: 0.042523, LR: 1.00e-04, Batch size: 128
Inferring interaction matrix W and bias vector I for cluster 14Mo
  Including 774 neighboring cells (373 cluster + 774 neighbors = 1147 total)
Training Epochs:   0%|          | 4/1000 [00:00<00:26, 37.34it/s]
[Epoch 1/1000] Total Loss: 470.605244, Reconstruction Loss: 42.556326, LR: 1.00e-01, Batch size: 128
Training Epochs:  11%|█         | 108/1000 [00:02<00:22, 39.50it/s]
[Epoch 101/1000] Total Loss: 140.323347, Reconstruction Loss: 0.480498, LR: 1.00e-01, Batch size: 128
Training Epochs:  21%|██        | 208/1000 [00:05<00:20, 39.56it/s]
[Epoch 201/1000] Total Loss: 13.552452, Reconstruction Loss: 0.020117, LR: 1.00e-02, Batch size: 128
Training Epochs:  31%|███       | 308/1000 [00:07<00:17, 39.56it/s]
[Epoch 301/1000] Total Loss: 1.248959, Reconstruction Loss: 0.022513, LR: 1.00e-03, Batch size: 128
Training Epochs:  41%|████      | 408/1000 [00:10<00:14, 39.57it/s]
[Epoch 401/1000] Total Loss: 0.162997, Reconstruction Loss: 0.039692, LR: 1.00e-04, Batch size: 128
Training Epochs:  51%|█████     | 508/1000 [00:12<00:12, 39.56it/s]
[Epoch 501/1000] Total Loss: 0.191767, Reconstruction Loss: 0.073711, LR: 1.00e-04, Batch size: 128
Training Epochs:  61%|██████    | 608/1000 [00:15<00:09, 39.48it/s]
[Epoch 601/1000] Total Loss: 0.145229, Reconstruction Loss: 0.021892, LR: 1.00e-04, Batch size: 128
Training Epochs:  71%|███████   | 708/1000 [00:17<00:07, 38.73it/s]
[Epoch 701/1000] Total Loss: 0.068800, Reconstruction Loss: 0.056120, LR: 1.00e-05, Batch size: 128
Training Epochs:  81%|████████  | 808/1000 [00:20<00:04, 38.73it/s]
[Epoch 801/1000] Total Loss: 0.040081, Reconstruction Loss: 0.038826, LR: 1.00e-06, Batch size: 128
Training Epochs:  91%|█████████ | 908/1000 [00:23<00:02, 38.76it/s]
[Epoch 901/1000] Total Loss: 0.006514, Reconstruction Loss: 0.005245, LR: 1.00e-06, Batch size: 128
Training Epochs: 100%|██████████| 1000/1000 [00:25<00:00, 39.31it/s]
/home/bernaljp/packages/scHopfield/scHopfield/inference/optimizer.py:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
  self.register_buffer('mask', torch.tensor(mask, dtype=torch.float32, device=device))
[Epoch 1000/1000] Total Loss: 0.040346, Reconstruction Loss: 0.039067, LR: 1.00e-06, Batch size: 128
Inferring interaction matrix W and bias vector I for cluster 13Baso
  Including 732 neighboring cells (300 cluster + 732 neighbors = 1032 total)
Training Epochs:   0%|          | 0/1000 [00:00<?, ?it/s]
[Epoch 1/1000] Total Loss: 487.833374, Reconstruction Loss: 51.935591, LR: 1.00e-01, Batch size: 128
Training Epochs:  11%|█         | 108/1000 [00:02<00:22, 38.89it/s]
[Epoch 101/1000] Total Loss: 140.142778, Reconstruction Loss: 0.478408, LR: 1.00e-01, Batch size: 128
Training Epochs:  21%|██        | 208/1000 [00:05<00:20, 38.69it/s]
[Epoch 201/1000] Total Loss: 13.275340, Reconstruction Loss: 0.016671, LR: 1.00e-02, Batch size: 128
Training Epochs:  31%|███       | 308/1000 [00:07<00:17, 39.69it/s]
[Epoch 301/1000] Total Loss: 1.217559, Reconstruction Loss: 0.001884, LR: 1.00e-03, Batch size: 128
Training Epochs:  40%|████      | 405/1000 [00:10<00:14, 39.73it/s]
[Epoch 401/1000] Total Loss: 0.123009, Reconstruction Loss: 0.001673, LR: 1.00e-04, Batch size: 128
Training Epochs:  51%|█████     | 507/1000 [00:12<00:12, 39.62it/s]
[Epoch 501/1000] Total Loss: 0.121865, Reconstruction Loss: 0.001754, LR: 1.00e-04, Batch size: 128
Training Epochs:  61%|██████    | 608/1000 [00:15<00:09, 39.56it/s]
[Epoch 601/1000] Total Loss: 0.012941, Reconstruction Loss: 0.001801, LR: 1.00e-05, Batch size: 128
Training Epochs:  71%|███████   | 708/1000 [00:18<00:07, 39.05it/s]
[Epoch 701/1000] Total Loss: 0.003469, Reconstruction Loss: 0.002291, LR: 1.00e-06, Batch size: 128
Training Epochs:  81%|████████  | 808/1000 [00:20<00:04, 38.91it/s]
[Epoch 801/1000] Total Loss: 0.003229, Reconstruction Loss: 0.001975, LR: 1.00e-06, Batch size: 128
Training Epochs:  91%|█████████ | 908/1000 [00:23<00:02, 38.90it/s]
[Epoch 901/1000] Total Loss: 0.003367, Reconstruction Loss: 0.002149, LR: 1.00e-06, Batch size: 128
Training Epochs: 100%|██████████| 1000/1000 [00:25<00:00, 39.21it/s]
/home/bernaljp/packages/scHopfield/scHopfield/inference/optimizer.py:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
  self.register_buffer('mask', torch.tensor(mask, dtype=torch.float32, device=device))
[Epoch 1000/1000] Total Loss: 0.002648, Reconstruction Loss: 0.001397, LR: 1.00e-06, Batch size: 128
Inferring interaction matrix W and bias vector I for cluster 8Mk
  Including 314 neighboring cells (68 cluster + 314 neighbors = 382 total)
Training Epochs:   2%|▏         | 16/1000 [00:00<00:06, 151.02it/s]
[Epoch 1/1000] Total Loss: 703.936890, Reconstruction Loss: 151.828520, LR: 1.00e-01, Batch size: 128
Training Epochs:  13%|█▎        | 128/1000 [00:00<00:05, 157.50it/s]
[Epoch 101/1000] Total Loss: 129.811188, Reconstruction Loss: 0.616909, LR: 1.00e-01, Batch size: 128
Training Epochs:  22%|██▏       | 224/1000 [00:01<00:04, 157.90it/s]
[Epoch 201/1000] Total Loss: 11.894033, Reconstruction Loss: 0.013081, LR: 1.00e-02, Batch size: 128
Training Epochs:  32%|███▏      | 320/1000 [00:02<00:04, 158.01it/s]
[Epoch 301/1000] Total Loss: 12.339972, Reconstruction Loss: 0.015128, LR: 1.00e-02, Batch size: 128
Training Epochs:  43%|████▎     | 432/1000 [00:02<00:03, 157.83it/s]
[Epoch 401/1000] Total Loss: 1.175145, Reconstruction Loss: 0.004216, LR: 1.00e-03, Batch size: 128
Training Epochs:  53%|█████▎    | 528/1000 [00:03<00:02, 157.84it/s]
[Epoch 501/1000] Total Loss: 0.278549, Reconstruction Loss: 0.003703, LR: 1.00e-04, Batch size: 128
Training Epochs:  62%|██████▏   | 624/1000 [00:03<00:02, 157.87it/s]
[Epoch 601/1000] Total Loss: 0.128252, Reconstruction Loss: 0.003602, LR: 1.00e-04, Batch size: 128
Training Epochs:  72%|███████▏  | 720/1000 [00:04<00:01, 157.93it/s]
[Epoch 701/1000] Total Loss: 0.130609, Reconstruction Loss: 0.003589, LR: 1.00e-04, Batch size: 128
Training Epochs:  83%|████████▎ | 832/1000 [00:05<00:01, 157.97it/s]
[Epoch 801/1000] Total Loss: 0.016030, Reconstruction Loss: 0.003621, LR: 1.00e-05, Batch size: 128
Training Epochs:  93%|█████████▎| 928/1000 [00:05<00:00, 157.88it/s]
[Epoch 901/1000] Total Loss: 0.016061, Reconstruction Loss: 0.003530, LR: 1.00e-05, Batch size: 128
Training Epochs: 100%|██████████| 1000/1000 [00:06<00:00, 157.89it/s]
/home/bernaljp/packages/scHopfield/scHopfield/inference/optimizer.py:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
  self.register_buffer('mask', torch.tensor(mask, dtype=torch.float32, device=device))
[Epoch 1000/1000] Total Loss: 0.005079, Reconstruction Loss: 0.003791, LR: 1.00e-06, Batch size: 128
Inferring interaction matrix W and bias vector I for cluster 9GMP
  Including 508 neighboring cells (63 cluster + 508 neighbors = 571 total)
Training Epochs:   1%|          | 8/1000 [00:00<00:13, 75.41it/s]
[Epoch 1/1000] Total Loss: 642.452698, Reconstruction Loss: 107.872519, LR: 1.00e-01, Batch size: 128
Training Epochs:  11%|█         | 112/1000 [00:01<00:11, 77.65it/s]
[Epoch 101/1000] Total Loss: 136.588480, Reconstruction Loss: 0.513962, LR: 1.00e-01, Batch size: 128
Training Epochs:  22%|██▏       | 216/1000 [00:02<00:10, 77.59it/s]
[Epoch 201/1000] Total Loss: 12.774814, Reconstruction Loss: 0.015243, LR: 1.00e-02, Batch size: 128
Training Epochs:  31%|███       | 311/1000 [00:04<00:10, 63.16it/s]
[Epoch 301/1000] Total Loss: 1.216847, Reconstruction Loss: 0.001584, LR: 1.00e-03, Batch size: 128
Training Epochs:  41%|████      | 409/1000 [00:05<00:09, 62.34it/s]
[Epoch 401/1000] Total Loss: 0.128342, Reconstruction Loss: 0.001505, LR: 1.00e-04, Batch size: 128
Training Epochs:  51%|█████     | 507/1000 [00:07<00:07, 62.82it/s]
[Epoch 501/1000] Total Loss: 0.126572, Reconstruction Loss: 0.001508, LR: 1.00e-04, Batch size: 128
Training Epochs:  61%|██████    | 612/1000 [00:09<00:06, 62.93it/s]
[Epoch 601/1000] Total Loss: 0.013297, Reconstruction Loss: 0.001543, LR: 1.00e-05, Batch size: 128
Training Epochs:  71%|███████   | 710/1000 [00:10<00:04, 62.92it/s]
[Epoch 701/1000] Total Loss: 0.013331, Reconstruction Loss: 0.001476, LR: 1.00e-05, Batch size: 128
Training Epochs:  81%|████████  | 811/1000 [00:11<00:02, 78.82it/s]
[Epoch 801/1000] Total Loss: 0.002651, Reconstruction Loss: 0.001440, LR: 1.00e-06, Batch size: 128
Training Epochs:  92%|█████████▏| 915/1000 [00:13<00:01, 79.24it/s]
[Epoch 901/1000] Total Loss: 0.002769, Reconstruction Loss: 0.001563, LR: 1.00e-06, Batch size: 128
Training Epochs: 100%|██████████| 1000/1000 [00:14<00:00, 69.82it/s]
/home/bernaljp/packages/scHopfield/scHopfield/inference/optimizer.py:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
  self.register_buffer('mask', torch.tensor(mask, dtype=torch.float32, device=device))
[Epoch 1000/1000] Total Loss: 0.002786, Reconstruction Loss: 0.001589, LR: 1.00e-06, Batch size: 128
Inferring interaction matrix W and bias vector I for cluster 10GMP
  Including 640 neighboring cells (153 cluster + 640 neighbors = 793 total)
Training Epochs:   0%|          | 5/1000 [00:00<00:20, 49.59it/s]
[Epoch 1/1000] Total Loss: 538.226852, Reconstruction Loss: 68.540875, LR: 1.00e-01, Batch size: 128
Training Epochs:  11%|█         | 107/1000 [00:02<00:16, 53.25it/s]
[Epoch 101/1000] Total Loss: 140.375003, Reconstruction Loss: 0.498863, LR: 1.00e-01, Batch size: 128
Training Epochs:  21%|██        | 209/1000 [00:03<00:14, 53.22it/s]
[Epoch 201/1000] Total Loss: 12.876518, Reconstruction Loss: 0.015265, LR: 1.00e-02, Batch size: 128
Training Epochs:  31%|███       | 311/1000 [00:05<00:12, 53.12it/s]
[Epoch 301/1000] Total Loss: 1.330200, Reconstruction Loss: 0.001173, LR: 1.00e-03, Batch size: 128
Training Epochs:  41%|████      | 407/1000 [00:07<00:11, 53.21it/s]
[Epoch 401/1000] Total Loss: 0.124834, Reconstruction Loss: 0.001121, LR: 1.00e-04, Batch size: 128
Training Epochs:  51%|█████     | 507/1000 [00:09<00:11, 42.47it/s]
[Epoch 501/1000] Total Loss: 0.123480, Reconstruction Loss: 0.001061, LR: 1.00e-04, Batch size: 128
Training Epochs:  61%|██████    | 607/1000 [00:12<00:09, 42.15it/s]
[Epoch 601/1000] Total Loss: 0.013597, Reconstruction Loss: 0.001077, LR: 1.00e-05, Batch size: 128
Training Epochs:  71%|███████   | 707/1000 [00:14<00:06, 42.20it/s]
[Epoch 701/1000] Total Loss: 0.002369, Reconstruction Loss: 0.001120, LR: 1.00e-06, Batch size: 128
Training Epochs:  81%|████████  | 808/1000 [00:16<00:03, 52.74it/s]
[Epoch 801/1000] Total Loss: 0.002366, Reconstruction Loss: 0.001123, LR: 1.00e-06, Batch size: 128
Training Epochs:  91%|█████████ | 910/1000 [00:18<00:01, 52.89it/s]
[Epoch 901/1000] Total Loss: 0.002376, Reconstruction Loss: 0.001130, LR: 1.00e-06, Batch size: 128
Training Epochs: 100%|██████████| 1000/1000 [00:20<00:00, 49.56it/s]
/home/bernaljp/packages/scHopfield/scHopfield/inference/optimizer.py:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
  self.register_buffer('mask', torch.tensor(mask, dtype=torch.float32, device=device))
[Epoch 1000/1000] Total Loss: 0.002322, Reconstruction Loss: 0.001096, LR: 1.00e-06, Batch size: 128
Inferring interaction matrix W and bias vector I for cluster 16Neu
  Including 352 neighboring cells (164 cluster + 352 neighbors = 516 total)
Training Epochs:   1%|          | 8/1000 [00:00<00:12, 77.12it/s]
[Epoch 1/1000] Total Loss: 555.262131, Reconstruction Loss: 52.856586, LR: 1.00e-01, Batch size: 128
Training Epochs:  11%|█         | 112/1000 [00:01<00:11, 77.41it/s]
[Epoch 101/1000] Total Loss: 141.538776, Reconstruction Loss: 0.440156, LR: 1.00e-01, Batch size: 128
Training Epochs:  22%|██▏       | 216/1000 [00:02<00:10, 77.39it/s]
[Epoch 201/1000] Total Loss: 13.089549, Reconstruction Loss: 0.011326, LR: 1.00e-02, Batch size: 128
Training Epochs:  31%|███       | 312/1000 [00:04<00:08, 77.30it/s]
[Epoch 301/1000] Total Loss: 1.303519, Reconstruction Loss: 0.003004, LR: 1.00e-03, Batch size: 128
Training Epochs:  42%|████▏     | 416/1000 [00:05<00:07, 77.36it/s]
[Epoch 401/1000] Total Loss: 0.369445, Reconstruction Loss: 0.003713, LR: 1.00e-04, Batch size: 128
Training Epochs:  51%|█████     | 512/1000 [00:06<00:06, 77.43it/s]
[Epoch 501/1000] Total Loss: 0.122316, Reconstruction Loss: 0.002713, LR: 1.00e-04, Batch size: 128
Training Epochs:  62%|██████▏   | 616/1000 [00:07<00:04, 78.10it/s]
[Epoch 601/1000] Total Loss: 0.046275, Reconstruction Loss: 0.003134, LR: 1.00e-05, Batch size: 128
Training Epochs:  71%|███████   | 712/1000 [00:09<00:03, 77.75it/s]
[Epoch 701/1000] Total Loss: 0.015851, Reconstruction Loss: 0.003053, LR: 1.00e-05, Batch size: 128
Training Epochs:  82%|████████▏ | 816/1000 [00:10<00:02, 77.66it/s]
[Epoch 801/1000] Total Loss: 0.003822, Reconstruction Loss: 0.002674, LR: 1.00e-06, Batch size: 128
Training Epochs:  91%|█████████ | 912/1000 [00:11<00:01, 77.73it/s]
[Epoch 901/1000] Total Loss: 0.003275, Reconstruction Loss: 0.002093, LR: 1.00e-06, Batch size: 128
Training Epochs: 100%|██████████| 1000/1000 [00:12<00:00, 77.65it/s]
/home/bernaljp/packages/scHopfield/scHopfield/inference/optimizer.py:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
  self.register_buffer('mask', torch.tensor(mask, dtype=torch.float32, device=device))
[Epoch 1000/1000] Total Loss: 0.042998, Reconstruction Loss: 0.041784, LR: 1.00e-06, Batch size: 128
Inferring interaction matrix W and bias vector I for cluster 5Ery
  Including 636 neighboring cells (180 cluster + 636 neighbors = 816 total)
Training Epochs:   0%|          | 0/1000 [00:00<?, ?it/s]
[Epoch 1/1000] Total Loss: 526.589417, Reconstruction Loss: 64.303920, LR: 1.00e-01, Batch size: 128
Training Epochs:  11%|█         | 108/1000 [00:02<00:16, 53.03it/s]
[Epoch 101/1000] Total Loss: 137.988588, Reconstruction Loss: 0.397441, LR: 1.00e-01, Batch size: 128
Training Epochs:  21%|██        | 210/1000 [00:03<00:14, 53.01it/s]
[Epoch 201/1000] Total Loss: 13.420788, Reconstruction Loss: 0.012431, LR: 1.00e-02, Batch size: 128
Training Epochs:  31%|███       | 306/1000 [00:05<00:13, 52.91it/s]
[Epoch 301/1000] Total Loss: 1.377526, Reconstruction Loss: 0.001665, LR: 1.00e-03, Batch size: 128
Training Epochs:  41%|████      | 408/1000 [00:07<00:11, 53.00it/s]
[Epoch 401/1000] Total Loss: 0.120787, Reconstruction Loss: 0.001807, LR: 1.00e-04, Batch size: 128
Training Epochs:  51%|█████     | 509/1000 [00:09<00:10, 47.01it/s]
[Epoch 501/1000] Total Loss: 0.120013, Reconstruction Loss: 0.001550, LR: 1.00e-04, Batch size: 128
Training Epochs:  61%|██████    | 609/1000 [00:12<00:09, 41.97it/s]
[Epoch 601/1000] Total Loss: 0.013205, Reconstruction Loss: 0.001683, LR: 1.00e-05, Batch size: 128
Training Epochs:  71%|███████   | 709/1000 [00:14<00:06, 41.90it/s]
[Epoch 701/1000] Total Loss: 0.003225, Reconstruction Loss: 0.001858, LR: 1.00e-06, Batch size: 128
Training Epochs:  81%|████████  | 809/1000 [00:16<00:04, 41.89it/s]
[Epoch 801/1000] Total Loss: 0.002649, Reconstruction Loss: 0.001429, LR: 1.00e-06, Batch size: 128
Training Epochs:  91%|█████████ | 909/1000 [00:19<00:02, 41.84it/s]
[Epoch 901/1000] Total Loss: 0.003078, Reconstruction Loss: 0.001834, LR: 1.00e-06, Batch size: 128
Training Epochs: 100%|██████████| 1000/1000 [00:21<00:00, 46.82it/s]
/home/bernaljp/packages/scHopfield/scHopfield/inference/optimizer.py:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
  self.register_buffer('mask', torch.tensor(mask, dtype=torch.float32, device=device))
[Epoch 1000/1000] Total Loss: 0.003025, Reconstruction Loss: 0.001758, LR: 1.00e-06, Batch size: 128
Inferring interaction matrix W and bias vector I for cluster 1Ery
  Including 154 neighboring cells (43 cluster + 154 neighbors = 197 total)
Training Epochs:   3%|▎         | 30/1000 [00:00<00:03, 296.95it/s]
[Epoch 1/1000] Total Loss: 616.029419, Reconstruction Loss: 37.509773, LR: 1.00e-01, Batch size: 128
Training Epochs:  16%|█▌        | 157/1000 [00:00<00:02, 302.88it/s]
[Epoch 101/1000] Total Loss: 137.103165, Reconstruction Loss: 0.618848, LR: 1.00e-01, Batch size: 128
Training Epochs:  24%|██▍       | 245/1000 [00:00<00:02, 253.85it/s]
[Epoch 201/1000] Total Loss: 33.040031, Reconstruction Loss: 0.098704, LR: 1.00e-02, Batch size: 128
Training Epochs:  35%|███▍      | 348/1000 [00:01<00:02, 248.96it/s]
[Epoch 301/1000] Total Loss: 10.811853, Reconstruction Loss: 0.016890, LR: 1.00e-02, Batch size: 128
Training Epochs:  44%|████▍     | 442/1000 [00:01<00:01, 288.69it/s]
[Epoch 401/1000] Total Loss: 12.890885, Reconstruction Loss: 0.028755, LR: 1.00e-02, Batch size: 128
Training Epochs:  54%|█████▎    | 537/1000 [00:01<00:01, 303.84it/s]
[Epoch 501/1000] Total Loss: 1.933045, Reconstruction Loss: 0.007760, LR: 1.00e-03, Batch size: 128
Training Epochs:  63%|██████▎   | 633/1000 [00:02<00:01, 309.52it/s]
[Epoch 601/1000] Total Loss: 1.360277, Reconstruction Loss: 0.012070, LR: 1.00e-03, Batch size: 128
Training Epochs:  76%|███████▌  | 761/1000 [00:02<00:00, 311.73it/s]
[Epoch 701/1000] Total Loss: 1.432901, Reconstruction Loss: 0.011833, LR: 1.00e-03, Batch size: 128
Training Epochs:  86%|████████▌ | 857/1000 [00:02<00:00, 311.59it/s]
[Epoch 801/1000] Total Loss: 0.159296, Reconstruction Loss: 0.026645, LR: 1.00e-04, Batch size: 128
Training Epochs:  95%|█████████▌| 953/1000 [00:03<00:00, 311.82it/s]
[Epoch 901/1000] Total Loss: 0.132341, Reconstruction Loss: 0.015429, LR: 1.00e-04, Batch size: 128
Training Epochs: 100%|██████████| 1000/1000 [00:03<00:00, 294.71it/s]
/home/bernaljp/packages/scHopfield/scHopfield/inference/optimizer.py:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
  self.register_buffer('mask', torch.tensor(mask, dtype=torch.float32, device=device))
[Epoch 1000/1000] Total Loss: 0.142051, Reconstruction Loss: 0.019410, LR: 1.00e-04, Batch size: 128
Inferring interaction matrix W and bias vector I for cluster 6Ery
  Including 573 neighboring cells (173 cluster + 573 neighbors = 746 total)
Training Epochs:   1%|          | 6/1000 [00:00<00:16, 58.68it/s]
[Epoch 1/1000] Total Loss: 571.778064, Reconstruction Loss: 86.155782, LR: 1.00e-01, Batch size: 128
Training Epochs:  11%|█         | 111/1000 [00:01<00:14, 63.34it/s]
[Epoch 101/1000] Total Loss: 137.577292, Reconstruction Loss: 0.459664, LR: 1.00e-01, Batch size: 128
Training Epochs:  21%|██        | 209/1000 [00:03<00:12, 63.36it/s]
[Epoch 201/1000] Total Loss: 12.259250, Reconstruction Loss: 0.014818, LR: 1.00e-02, Batch size: 128
Training Epochs:  31%|███       | 307/1000 [00:04<00:11, 62.04it/s]
[Epoch 301/1000] Total Loss: 1.134596, Reconstruction Loss: 0.001273, LR: 1.00e-03, Batch size: 128
Training Epochs:  41%|████      | 412/1000 [00:06<00:09, 63.30it/s]
[Epoch 401/1000] Total Loss: 1.144775, Reconstruction Loss: 0.001195, LR: 1.00e-03, Batch size: 128
Training Epochs:  51%|█████     | 510/1000 [00:08<00:07, 63.27it/s]
[Epoch 501/1000] Total Loss: 0.130114, Reconstruction Loss: 0.001181, LR: 1.00e-04, Batch size: 128
Training Epochs:  61%|██████    | 608/1000 [00:09<00:06, 63.30it/s]
[Epoch 601/1000] Total Loss: 0.013968, Reconstruction Loss: 0.001169, LR: 1.00e-05, Batch size: 128
Training Epochs:  71%|███████▏  | 713/1000 [00:11<00:04, 63.45it/s]
[Epoch 701/1000] Total Loss: 0.002475, Reconstruction Loss: 0.001206, LR: 1.00e-06, Batch size: 128
Training Epochs:  81%|████████  | 811/1000 [00:12<00:02, 63.52it/s]
[Epoch 801/1000] Total Loss: 0.002309, Reconstruction Loss: 0.001073, LR: 1.00e-06, Batch size: 128
Training Epochs:  91%|█████████ | 909/1000 [00:14<00:01, 63.53it/s]
[Epoch 901/1000] Total Loss: 0.002420, Reconstruction Loss: 0.001169, LR: 1.00e-06, Batch size: 128
Training Epochs: 100%|██████████| 1000/1000 [00:15<00:00, 63.33it/s]
/home/bernaljp/packages/scHopfield/scHopfield/inference/optimizer.py:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
  self.register_buffer('mask', torch.tensor(mask, dtype=torch.float32, device=device))
[Epoch 1000/1000] Total Loss: 0.002381, Reconstruction Loss: 0.001159, LR: 1.00e-06, Batch size: 128
Inferring interaction matrix W and bias vector I for cluster 12Baso
  Including 482 neighboring cells (69 cluster + 482 neighbors = 551 total)
Training Epochs:   1%|          | 8/1000 [00:00<00:13, 72.38it/s]
[Epoch 1/1000] Total Loss: 599.778030, Reconstruction Loss: 83.738147, LR: 1.00e-01, Batch size: 128
Training Epochs:  11%|█         | 112/1000 [00:01<00:11, 79.08it/s]
[Epoch 101/1000] Total Loss: 140.298676, Reconstruction Loss: 0.488358, LR: 1.00e-01, Batch size: 128
Training Epochs:  22%|██▏       | 216/1000 [00:02<00:09, 78.92it/s]
[Epoch 201/1000] Total Loss: 12.743252, Reconstruction Loss: 0.050250, LR: 1.00e-02, Batch size: 128
Training Epochs:  31%|███       | 312/1000 [00:03<00:08, 79.13it/s]
[Epoch 301/1000] Total Loss: 1.270438, Reconstruction Loss: 0.001581, LR: 1.00e-03, Batch size: 128
Training Epochs:  42%|████▏     | 416/1000 [00:05<00:07, 79.04it/s]
[Epoch 401/1000] Total Loss: 0.127087, Reconstruction Loss: 0.003528, LR: 1.00e-04, Batch size: 128
Training Epochs:  51%|█████     | 512/1000 [00:06<00:06, 79.18it/s]
[Epoch 501/1000] Total Loss: 0.163729, Reconstruction Loss: 0.040517, LR: 1.00e-04, Batch size: 128
Training Epochs:  61%|██████    | 611/1000 [00:07<00:05, 68.87it/s]
[Epoch 601/1000] Total Loss: 0.013352, Reconstruction Loss: 0.001577, LR: 1.00e-05, Batch size: 128
Training Epochs:  71%|███████   | 711/1000 [00:09<00:04, 63.67it/s]
[Epoch 701/1000] Total Loss: 0.005730, Reconstruction Loss: 0.001591, LR: 1.00e-06, Batch size: 128
Training Epochs:  81%|████████  | 809/1000 [00:10<00:03, 62.04it/s]
[Epoch 801/1000] Total Loss: 0.041779, Reconstruction Loss: 0.040590, LR: 1.00e-06, Batch size: 128
Training Epochs:  91%|█████████ | 907/1000 [00:12<00:01, 62.75it/s]
[Epoch 901/1000] Total Loss: 0.004689, Reconstruction Loss: 0.003509, LR: 1.00e-06, Batch size: 128
Training Epochs: 100%|██████████| 1000/1000 [00:13<00:00, 71.53it/s]
/home/bernaljp/packages/scHopfield/scHopfield/inference/optimizer.py:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
  self.register_buffer('mask', torch.tensor(mask, dtype=torch.float32, device=device))
[Epoch 1000/1000] Total Loss: 0.002782, Reconstruction Loss: 0.001572, LR: 1.00e-06, Batch size: 128
Inferring interaction matrix W and bias vector I for cluster 18Eos
  Including 79 neighboring cells (9 cluster + 79 neighbors = 88 total)
Training Epochs:   0%|          | 0/1000 [00:00<?, ?it/s]
[Epoch 1/1000] Total Loss: 615.155457, Reconstruction Loss: 36.572170, LR: 1.00e-01, Batch size: 88
Training Epochs:  15%|█▌        | 154/1000 [00:00<00:03, 269.35it/s]
[Epoch 101/1000] Total Loss: 136.959961, Reconstruction Loss: 0.593444, LR: 1.00e-01, Batch size: 88
Training Epochs:  24%|██▍       | 240/1000 [00:00<00:02, 276.11it/s]
[Epoch 201/1000] Total Loss: 14.810553, Reconstruction Loss: 0.022653, LR: 1.00e-02, Batch size: 88
Training Epochs:  35%|███▌      | 353/1000 [00:01<00:02, 277.39it/s]
[Epoch 301/1000] Total Loss: 11.316996, Reconstruction Loss: 0.010906, LR: 1.00e-02, Batch size: 88
Training Epochs:  44%|████▍     | 438/1000 [00:01<00:02, 278.58it/s]
[Epoch 401/1000] Total Loss: 12.383344, Reconstruction Loss: 0.011407, LR: 1.00e-02, Batch size: 88
Training Epochs:  55%|█████▌    | 554/1000 [00:02<00:01, 280.45it/s]
[Epoch 501/1000] Total Loss: 1.325688, Reconstruction Loss: 0.004976, LR: 1.00e-03, Batch size: 88
Training Epochs:  64%|██████▍   | 641/1000 [00:02<00:01, 280.88it/s]
[Epoch 601/1000] Total Loss: 1.225986, Reconstruction Loss: 0.005140, LR: 1.00e-03, Batch size: 88
Training Epochs:  76%|███████▌  | 757/1000 [00:02<00:00, 280.97it/s]
[Epoch 701/1000] Total Loss: 0.306577, Reconstruction Loss: 0.005236, LR: 1.00e-04, Batch size: 88
Training Epochs:  85%|████████▍ | 846/1000 [00:03<00:00, 286.53it/s]
[Epoch 801/1000] Total Loss: 0.134355, Reconstruction Loss: 0.004800, LR: 1.00e-04, Batch size: 88
Training Epochs:  94%|█████████▍| 938/1000 [00:03<00:00, 288.19it/s]
[Epoch 901/1000] Total Loss: 0.127948, Reconstruction Loss: 0.005161, LR: 1.00e-04, Batch size: 88
Training Epochs: 100%|██████████| 1000/1000 [00:03<00:00, 276.39it/s]
/home/bernaljp/packages/scHopfield/scHopfield/inference/optimizer.py:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
  self.register_buffer('mask', torch.tensor(mask, dtype=torch.float32, device=device))
[Epoch 1000/1000] Total Loss: 0.022849, Reconstruction Loss: 0.004958, LR: 1.00e-05, Batch size: 88
Inferring interaction matrix W and bias vector I for cluster 11DC
  Including 29 neighboring cells (1 cluster + 29 neighbors = 30 total)
Training Epochs:   4%|▍         | 38/1000 [00:00<00:02, 375.10it/s]
[Epoch 1/1000] Total Loss: 617.309387, Reconstruction Loss: 39.388676, LR: 1.00e-01, Batch size: 30
Training Epochs:  16%|█▌        | 160/1000 [00:00<00:02, 398.49it/s]
[Epoch 101/1000] Total Loss: 134.232605, Reconstruction Loss: 0.760372, LR: 1.00e-01, Batch size: 30
Training Epochs:  24%|██▍       | 242/1000 [00:00<00:01, 400.80it/s]
[Epoch 201/1000] Total Loss: 32.072197, Reconstruction Loss: 0.108587, LR: 1.00e-02, Batch size: 30
Training Epochs:  36%|███▋      | 365/1000 [00:00<00:01, 402.23it/s]
[Epoch 301/1000] Total Loss: 12.060165, Reconstruction Loss: 0.013782, LR: 1.00e-02, Batch size: 30
Training Epochs:  45%|████▍     | 447/1000 [00:01<00:01, 401.81it/s]
[Epoch 401/1000] Total Loss: 12.669285, Reconstruction Loss: 0.013119, LR: 1.00e-02, Batch size: 30
Training Epochs:  57%|█████▋    | 570/1000 [00:01<00:01, 401.60it/s]
[Epoch 501/1000] Total Loss: 12.392544, Reconstruction Loss: 0.015478, LR: 1.00e-02, Batch size: 30
Training Epochs:  65%|██████▌   | 652/1000 [00:01<00:00, 401.82it/s]
[Epoch 601/1000] Total Loss: 1.319816, Reconstruction Loss: 0.004883, LR: 1.00e-03, Batch size: 30
Training Epochs:  78%|███████▊  | 775/1000 [00:01<00:00, 401.27it/s]
[Epoch 701/1000] Total Loss: 1.251818, Reconstruction Loss: 0.004567, LR: 1.00e-03, Batch size: 30
Training Epochs:  86%|████████▌ | 857/1000 [00:02<00:00, 400.06it/s]
[Epoch 801/1000] Total Loss: 1.280651, Reconstruction Loss: 0.004779, LR: 1.00e-03, Batch size: 30
Training Epochs:  98%|█████████▊| 980/1000 [00:02<00:00, 401.90it/s]
[Epoch 901/1000] Total Loss: 0.134520, Reconstruction Loss: 0.004028, LR: 1.00e-04, Batch size: 30
Training Epochs: 100%|██████████| 1000/1000 [00:02<00:00, 400.21it/s]
/home/bernaljp/packages/scHopfield/scHopfield/inference/optimizer.py:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
  self.register_buffer('mask', torch.tensor(mask, dtype=torch.float32, device=device))
[Epoch 1000/1000] Total Loss: 0.134135, Reconstruction Loss: 0.004956, LR: 1.00e-04, Batch size: 30
Inferring interaction matrix W and bias vector I for cluster 19Lymph
  Including 29 neighboring cells (1 cluster + 29 neighbors = 30 total)
Training Epochs:   4%|▍         | 38/1000 [00:00<00:02, 379.08it/s]
[Epoch 1/1000] Total Loss: 619.515259, Reconstruction Loss: 42.021152, LR: 1.00e-01, Batch size: 30
Training Epochs:  16%|█▌        | 159/1000 [00:00<00:02, 396.89it/s]
[Epoch 101/1000] Total Loss: 132.461731, Reconstruction Loss: 0.842844, LR: 1.00e-01, Batch size: 30
Training Epochs:  28%|██▊       | 281/1000 [00:00<00:01, 399.08it/s]
[Epoch 201/1000] Total Loss: 31.706991, Reconstruction Loss: 0.132093, LR: 1.00e-02, Batch size: 30
Training Epochs:  36%|███▌      | 362/1000 [00:00<00:01, 398.31it/s]
[Epoch 301/1000] Total Loss: 11.970432, Reconstruction Loss: 0.016650, LR: 1.00e-02, Batch size: 30
Training Epochs:  44%|████▍     | 443/1000 [00:01<00:01, 399.75it/s]
[Epoch 401/1000] Total Loss: 12.622685, Reconstruction Loss: 0.016970, LR: 1.00e-02, Batch size: 30
Training Epochs:  57%|█████▋    | 566/1000 [00:01<00:01, 399.97it/s]
[Epoch 501/1000] Total Loss: 12.394630, Reconstruction Loss: 0.018205, LR: 1.00e-02, Batch size: 30
Training Epochs:  65%|██████▍   | 647/1000 [00:01<00:00, 398.73it/s]
[Epoch 601/1000] Total Loss: 1.330454, Reconstruction Loss: 0.005943, LR: 1.00e-03, Batch size: 30
Training Epochs:  77%|███████▋  | 768/1000 [00:01<00:00, 398.89it/s]
[Epoch 701/1000] Total Loss: 1.259791, Reconstruction Loss: 0.006181, LR: 1.00e-03, Batch size: 30
Training Epochs:  85%|████████▍ | 849/1000 [00:02<00:00, 398.46it/s]
[Epoch 801/1000] Total Loss: 1.288557, Reconstruction Loss: 0.007436, LR: 1.00e-03, Batch size: 30
Training Epochs:  97%|█████████▋| 970/1000 [00:02<00:00, 398.86it/s]
[Epoch 901/1000] Total Loss: 0.145170, Reconstruction Loss: 0.005128, LR: 1.00e-04, Batch size: 30
Training Epochs: 100%|██████████| 1000/1000 [00:02<00:00, 397.79it/s]
[Epoch 1000/1000] Total Loss: 0.132250, Reconstruction Loss: 0.007338, LR: 1.00e-04, Batch size: 30
GRN inference complete for 19 clusters.

[29]:
# Colour map for consistent plotting (extracted from scVelo scatter).
# scVelo draws the categorical scatter on a throw-away axis; each cluster's
# colour is then read back from the first cell that belongs to that cluster.
fig_tmp, ax_tmp = plt.subplots()
scv.pl.scatter(adata, color=CLUSTER_KEY, basis=BASIS, ax=ax_tmp, show=False)

# The first child of the axis is taken to be the scatter collection; use the
# public accessor instead of the private `_facecolors` attribute.
# NOTE(review): assumes the face-colour rows follow adata.obs order — confirm
# that scVelo does not reorder points before drawing.
point_colors = ax_tmp.get_children()[0].get_facecolor()

colors = {}
for k in adata.obs[CLUSTER_KEY].unique():
    idx = np.where(adata.obs[CLUSTER_KEY] == k)[0][0]  # first cell of this cluster
    c = point_colors[idx].copy()
    c[3] = 1.0  # force full opacity regardless of the scatter's alpha
    colors[k] = c
plt.close(fig_tmp)

5.2 ODE-based Single-Cell Trajectories

Integrate the Hopfield ODE forward in time from a representative cell’s initial state under wild-type, knockout, and overexpression conditions.

[30]:
# Simulate WT, KO, and OE trajectories for the first N_TRAJ_CLUSTERS entries of
# `target_clusters` (one representative cell per cluster) and plot, per cluster:
#   left  — the genes whose endpoint differs most between WT and KO
#   right — distance travelled from the WT starting state over time
target_clusters = ['1Ery', '2Ery', '3Ery', '4Ery', '5Ery', '6Ery', '7MEP']
N_TRAJ_CLUSTERS = 3          # how many of the clusters above are simulated/plotted
N_TOP_GENES     = 5          # genes shown in the left-hand panel
t_span = np.linspace(0, 10, 1000)

plot_clusters = target_clusters[:N_TRAJ_CLUSTERS]
trajectory_results = {}

# squeeze=False keeps `axes` two-dimensional even when only one cluster is plotted
fig, axes = plt.subplots(len(plot_clusters), 2,
                         figsize=(14, 4 * len(plot_clusters)),
                         tight_layout=True, squeeze=False)

for row, cluster in enumerate(plot_clusters):
    print(f"  Simulating trajectories in {cluster}...")

    cluster_mask = adata.obs[CLUSTER_KEY] == cluster
    cluster_idx  = np.where(cluster_mask)[0]
    cell_idx     = cluster_idx[len(cluster_idx) // 2]  # representative cell

    # Arguments shared by the three simulations of this cell
    sim_kwargs = dict(cluster=cluster, spliced_key=SPLICED_KEY,
                      cell_idx=cell_idx, t_span=t_span)

    wt = sch.dyn.simulate_trajectory(adata, **sim_kwargs)
    ko = sch.dyn.simulate_perturbation_ode(
        adata, gene_perturbations={GOI: 0.0}, **sim_kwargs
    )
    oe = sch.dyn.simulate_perturbation_ode(
        adata, gene_perturbations={GOI: 10.0}, **sim_kwargs
    )
    trajectory_results[cluster] = {'wt': wt, 'ko': ko, 'oe': oe}

    # Most responsive genes, ranked by |KO endpoint - WT endpoint|.
    # NOTE(review): the perturbed gene itself is not excluded from this ranking.
    # NOTE(review): labels assume trajectory columns align with adata.var_names — confirm.
    top_idx    = np.argsort(np.abs(ko[-1] - wt[-1]))[-N_TOP_GENES:]
    clrs       = plt.cm.tab10(np.linspace(0, 1, len(top_idx)))

    ax_genes = axes[row, 0]
    for i, idx in enumerate(top_idx):
        ax_genes.plot(t_span, wt[:, idx], '-',  color=clrs[i], lw=2,
                      label=adata.var_names[idx])
        ax_genes.plot(t_span, ko[:, idx], '--', color=clrs[i], lw=2)
    ax_genes.set_xlabel('Time')
    ax_genes.set_ylabel('Expression')
    ax_genes.set_title(f'{cluster}: top responsive genes ({GOI} KO)\nSolid=WT  Dashed=KO')
    ax_genes.legend(fontsize=8)
    ax_genes.grid(True, alpha=0.3)

    # Phase-space divergence: every condition is measured against the WT state at t=0
    dist_wt = np.linalg.norm(wt - wt[0], axis=1)
    dist_ko = np.linalg.norm(ko - wt[0], axis=1)
    dist_oe = np.linalg.norm(oe - wt[0], axis=1)

    ax_dist = axes[row, 1]
    ax_dist.plot(t_span, dist_wt, color='gray',    lw=2, label='WT')
    ax_dist.plot(t_span, dist_ko, color='#E74C3C', lw=2, label=f'{GOI} KO')
    ax_dist.plot(t_span, dist_oe, color='#3498DB', lw=2, label=f'{GOI} OE')
    ax_dist.set_xlabel('Time')
    ax_dist.set_ylabel('Phase-space distance from t=0')
    ax_dist.set_title(f'{cluster}: trajectory divergence')
    ax_dist.legend()
    ax_dist.grid(True, alpha=0.3)

plt.show()

  Simulating trajectories in 1Ery...
  Simulating trajectories in 2Ery...
  Simulating trajectories in 3Ery...
../_images/notebooks_05_perturbation_analysis_20_1.png

5.3 Dataset-wide ODE Perturbation

Run the ODE simulation on every cell to obtain a perturbed expression landscape.

[11]:
# Dataset-wide ODE perturbation: run the same simulation for a knockout
# (GOI clamped to 0.0) and an overexpression (GOI clamped to 10.0), each on its
# own copy of `adata` so the original object is left untouched.
ode_shift_kwargs = dict(
    cluster_key=CLUSTER_KEY,
    dt=5.0,
    use_cluster_specific_GRN=True,
    verbose=True,
    n_jobs=-1,
)

print(f"Dataset-wide ODE simulation — {GOI} KO...")
adata_ode_ko = sch.dyn.simulate_shift_ode(
    adata.copy(), perturb_condition={GOI: 0.0}, **ode_shift_kwargs
)

print(f"Dataset-wide ODE simulation — {GOI} OE...")
adata_ode_oe = sch.dyn.simulate_shift_ode(
    adata.copy(), perturb_condition={GOI: 10.0}, **ode_shift_kwargs
)

top_ko = sch.dyn.get_top_affected_genes(adata_ode_ko, n_genes=20)
top_oe = sch.dyn.get_top_affected_genes(adata_ode_oe, n_genes=20)
# Start the table on its own line; otherwise the label shifts the header row
# and the columns no longer line up with the values below.
print(f"Top ODE KO-affected genes:\n{top_ko[:5]}")
print(f"Top ODE OE-affected genes:\n{top_oe[:5]}")

Dataset-wide ODE simulation — Gata1 KO...
Processing cluster: 7MEP
Processing cluster: 15Mo
Processing cluster: 3Ery
Processing cluster: 4Ery
Processing cluster: 2Ery
Processing cluster: 17Neu
Processing cluster: 14Mo
Processing cluster: 13Baso
Processing cluster: 8Mk
Processing cluster: 9GMP
Processing cluster: 10GMP
Processing cluster: 16Neu
Processing cluster: 5Ery
Processing cluster: 1Ery
Processing cluster: 6Ery
Processing cluster: 12Baso
Processing cluster: 18Eos
Processing cluster: 11DC
Processing cluster: 19Lymph
Dataset-wide ODE simulation — Gata1 OE...
Processing cluster: 7MEP
Processing cluster: 15Mo
Processing cluster: 3Ery
Processing cluster: 4Ery
Processing cluster: 2Ery
Processing cluster: 17Neu
Processing cluster: 14Mo
Processing cluster: 13Baso
Processing cluster: 8Mk
Processing cluster: 9GMP
Processing cluster: 10GMP
Processing cluster: 16Neu
Processing cluster: 5Ery
Processing cluster: 1Ery
Processing cluster: 6Ery
Processing cluster: 12Baso
Processing cluster: 18Eos
Processing cluster: 11DC
Processing cluster: 19Lymph
Top ODE KO-affected genes:     gene  mean_delta_X  abs_mean_delta_X direction
0   Car2     -0.139609          0.139609      down
1   Rps3     -0.139453          0.139453      down
2    Ubb     -0.130804          0.130804      down
3  Snrpb     -0.103208          0.103208      down
4   Fth1     -0.093429          0.093429      down
Top ODE OE-affected genes:     gene  mean_delta_X  abs_mean_delta_X direction
0   Rps3      0.201916          0.201916        up
1   Ly6e      0.197750          0.197750        up
2  Rpl32      0.184011          0.184011        up
3   Rpl4      0.177697          0.177697        up
4   Calr      0.142530          0.142530        up
[31]:
# Perturbation effect heatmaps for the ODE-based shift: one figure per condition
heatmap_by_condition = {}
for condition, shifted in (('KO', adata_ode_ko), ('OE', adata_ode_oe)):
    heatmap_by_condition[condition] = sch.pl.plot_perturbation_effect_heatmap(
        shifted, cluster_key=CLUSTER_KEY, n_genes=30
    )
    plt.suptitle(f'{GOI} {condition} (ODE) — perturbation effect heatmap', y=1.01)
    plt.show()

# Keep the per-condition handles available under their usual names
fig_ko = heatmap_by_condition['KO']
fig_oe = heatmap_by_condition['OE']

../_images/notebooks_05_perturbation_analysis_23_0.png
../_images/notebooks_05_perturbation_analysis_23_1.png
[32]:
# Perturbation magnitude on the embedding (ODE-based shift), KO first, then OE
for condition, shifted in (('KO', adata_ode_ko), ('OE', adata_ode_oe)):
    sch.pl.plot_perturbation_magnitude(shifted, cluster_key=CLUSTER_KEY, basis=BASIS)
    plt.title(f'{GOI} {condition} (ODE) — perturbation magnitude')
    plt.show()

../_images/notebooks_05_perturbation_analysis_24_0.png
../_images/notebooks_05_perturbation_analysis_24_1.png
[33]:
# Top affected genes (bar charts): KO in the left panel, OE in the right panel
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(16, 8), tight_layout=True)
for panel, condition, shifted in zip(axes, ('KO', 'OE'), (adata_ode_ko, adata_ode_oe)):
    sch.pl.plot_top_affected_genes_bar(shifted, n_genes=20, ax=panel)
    panel.set_title(f'{GOI} {condition} (ODE): top 20 affected genes')
plt.show()

../_images/notebooks_05_perturbation_analysis_25_0.png
[34]:
# Per-cluster mean ΔX under KO and OE; these feed the symmetry scatter and the
# dual-column heatmaps in the following cells.

# Column positions (and names) of the genes flagged as used by scHopfield
used_flags      = adata_ode_ko.var['scHopfield_used']
genes_used      = np.where(used_flags)[0]
gene_names_used = adata_ode_ko.var_names[genes_used]

# Exclude the perturbed gene itself, using the condition recorded on the KO
# object (empty if none was stored)
perturb_keys = list(adata_ode_ko.uns['scHopfield'].get('perturb_condition', {}))
keep = np.isin(gene_names_used, perturb_keys, invert=True)
genes_used, gene_names_used = genes_used[keep], gene_names_used[keep]

delta_X_ko = np.asarray(adata_ode_ko.layers['delta_X'][:, genes_used])
delta_X_oe = np.asarray(adata_ode_oe.layers['delta_X'][:, genes_used])

# The KO object's cluster labels select rows in both matrices.
# NOTE(review): this relies on both objects sharing the same cell order.
cluster_changes_ko = {}
cluster_changes_oe = {}
for cl in clusters:
    mask = adata_ode_ko.obs[CLUSTER_KEY] == cl
    cluster_changes_ko[cl] = delta_X_ko[mask].mean(axis=0)
    cluster_changes_oe[cl] = delta_X_oe[mask].mean(axis=0)

[35]:
# Scatter: KO vs OE per-cluster symmetry.
# One panel per cluster; each point is a gene, x = mean ΔX under KO and
# y = mean ΔX under OE. The dashed anti-diagonal (y = -x) marks where a gene
# would respond with equal magnitude and opposite sign to the two perturbations.
n_cl = len(clusters)
n_cols_plt = 4
n_rows_plt = (n_cl + n_cols_plt - 1) // n_cols_plt   # ceil(n_cl / n_cols_plt)

# squeeze=False guarantees a 2-D axes array whatever the grid shape
fig, axes = plt.subplots(n_rows_plt, n_cols_plt,
                         figsize=(5 * n_cols_plt, 4 * n_rows_plt),
                         tight_layout=True, squeeze=False)
axes_flat = axes.flatten()

for i, cl in enumerate(clusters):
    ax  = axes_flat[i]
    m_ko = cluster_changes_ko[cl]
    m_oe = cluster_changes_oe[cl]
    t_ko = np.argsort(np.abs(m_ko))[-20:]   # 20 genes with the largest |KO shift|
    t_oe = np.argsort(np.abs(m_oe))[-20:]   # 20 genes with the largest |OE shift|
    ax.scatter(m_ko, m_oe, s=2, alpha=0.3, color='grey')
    ax.scatter(m_ko[t_ko], m_oe[t_ko], s=10, color='firebrick', label='Top KO')
    ax.scatter(m_ko[t_oe], m_oe[t_oe], s=10, color='steelblue', label='Top OE')
    # Extent of the anti-diagonal: largest absolute axis limit (floored to avoid a zero-length line)
    lim = max(np.abs([ax.get_xlim(), ax.get_ylim()]).max(), 1e-9)
    ax.plot([-lim, lim], [lim, -lim], 'k--', alpha=0.2)
    ax.set_xlabel('KO mean ΔX', fontsize=8)
    ax.set_ylabel('OE mean ΔX', fontsize=8)
    ax.set_title(cl, fontsize=9)
    if i == 0:
        ax.legend(fontsize=7)

# Hide the grid cells left over after the last cluster (no reliance on the loop index)
for unused_ax in axes_flat[n_cl:]:
    unused_ax.axis('off')
plt.suptitle(f'ODE KO vs OE symmetry per cluster ({GOI})', y=1.02)
plt.show()

../_images/notebooks_05_perturbation_analysis_27_0.png
[36]:
# Dual-column heatmap: for each cluster, the 20 genes with the largest |KO shift|
# (strongest first), shown side by side with their OE shift.
fig, axes = plt.subplots(nrows=n_rows_plt, ncols=n_cols_plt,
                         figsize=(5 * n_cols_plt, 6 * n_rows_plt),
                         tight_layout=True)
axes_flat = axes.flatten()

for panel_no, cl in enumerate(clusters):
    panel    = axes_flat[panel_no]
    ko_shift = cluster_changes_ko[cl]
    oe_shift = cluster_changes_oe[cl]

    # Descending |KO shift|, truncated to the top 20
    ranked = np.argsort(np.abs(ko_shift))[::-1][:20]
    pair   = np.column_stack((ko_shift[ranked], oe_shift[ranked]))

    # Symmetric colour range per panel; fall back to a tiny value if all shifts are zero
    peak = np.abs(pair).max()
    if not peak:
        peak = 1e-9

    panel.imshow(pair, cmap='coolwarm', aspect='auto', vmin=-peak, vmax=peak)
    panel.set_yticks(range(len(ranked)))
    panel.set_yticklabels(gene_names_used[ranked], fontsize=7)
    panel.set_xticks([0, 1])
    panel.set_xticklabels(['KO', 'OE'], fontsize=8)
    panel.set_title(cl, fontsize=9)

# Blank out grid cells beyond the last cluster
for spare in axes_flat[len(clusters):]:
    spare.axis('off')
plt.suptitle(f'Per-cluster top-KO genes — KO vs OE ({GOI})', y=1.02)
plt.show()

../_images/notebooks_05_perturbation_analysis_28_0.png

5.4 GRN Propagation-based Perturbation

Propagate expression shifts through the inferred GRN iteratively (analogous to the CellOracle approach).

[18]:
# GRN propagation-based perturbation: the same settings are used for the
# knockout (GOI held at 0.0) and the overexpression (GOI held at 10.0); each run
# works on its own copy of `adata`.
N_PROPAGATION  = 3
DT             = 5.0

propagation_kwargs = dict(
    cluster_key=CLUSTER_KEY,
    n_propagation=N_PROPAGATION,
    use_cluster_specific_GRN=True,
    clip_delta_X=True,
    x_max_percentile=90.0,
    verbose=True,
    dt=DT,
)

print(f"GRN propagation — {GOI} KO...")
adata_ko = sch.dyn.simulate_shift(
    adata.copy(), perturb_condition={GOI: 0.0}, **propagation_kwargs
)

print(f"GRN propagation — {GOI} OE...")
adata_oe = sch.dyn.simulate_shift(
    adata.copy(), perturb_condition={GOI: 10.0}, **propagation_kwargs
)

top_ko_prop = sch.dyn.get_top_affected_genes(adata_ko, n_genes=20)
top_oe_prop = sch.dyn.get_top_affected_genes(adata_oe, n_genes=20)
# Start the table on its own line; otherwise the label shifts the header row
# and the columns no longer line up with the values below.
print(f"Propagation KO top genes:\n{top_ko_prop[:5]}")
print(f"Propagation OE top genes:\n{top_oe_prop[:5]}")

GRN propagation — Gata1 KO...
Perturbation simulation complete
  Genes perturbed: ['Gata1']
  Propagation steps: 3
  dt (scaling): 5.0
  Perturbed genes: held constant
  Results stored in adata.layers['simulated_count'] and adata.layers['delta_X']
GRN propagation — Gata1 OE...
  Warning: Perturbation value 10.0 for 'Gata1' is outside typical range [0.00, 0.97]
Perturbation simulation complete
  Genes perturbed: ['Gata1']
  Propagation steps: 3
  dt (scaling): 5.0
  Perturbed genes: held constant
  Results stored in adata.layers['simulated_count'] and adata.layers['delta_X']
Propagation KO top genes:       gene  mean_delta_X  abs_mean_delta_X direction
0     Car2     -0.605450          0.605450      down
1     Fth1     -0.402631          0.402631      down
2    Vamp5     -0.369535          0.369535      down
3   Atpif1     -0.349242          0.349242      down
4  Fam132a     -0.322652          0.322652      down
Propagation OE top genes:      gene  mean_delta_X  abs_mean_delta_X direction
0   Rpl32      0.439467          0.439467        up
1    Car2      0.417303          0.417303        up
2    Rps3      0.385896          0.385896        up
3  Lgals9      0.376973          0.376973        up
4    Fth1      0.360218          0.360218        up
[37]:
# Heatmaps of the propagation-based shift: one figure per condition
heatmap_by_condition = {}
for condition, shifted in (('KO', adata_ko), ('OE', adata_oe)):
    heatmap_by_condition[condition] = sch.pl.plot_perturbation_effect_heatmap(
        shifted, cluster_key=CLUSTER_KEY, n_genes=30
    )
    plt.suptitle(f'{GOI} {condition} (propagation) — perturbation effect heatmap', y=1.01)
    plt.show()

# Keep the per-condition handles available under their usual names
fig_ko = heatmap_by_condition['KO']
fig_oe = heatmap_by_condition['OE']

../_images/notebooks_05_perturbation_analysis_31_0.png
../_images/notebooks_05_perturbation_analysis_31_1.png
[38]:
# Perturbation magnitude on the embedding (propagation-based shift), KO first, then OE
for condition, shifted in (('KO', adata_ko), ('OE', adata_oe)):
    sch.pl.plot_perturbation_magnitude(shifted, cluster_key=CLUSTER_KEY, basis=BASIS)
    plt.title(f'{GOI} {condition} (propagation) — perturbation magnitude')
    plt.show()

../_images/notebooks_05_perturbation_analysis_32_0.png
../_images/notebooks_05_perturbation_analysis_32_1.png
[39]:
# Top affected genes (bar charts): KO in the left panel, OE in the right panel
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(16, 8), tight_layout=True)
for panel, condition, shifted in zip(axes, ('KO', 'OE'), (adata_ko, adata_oe)):
    sch.pl.plot_top_affected_genes_bar(shifted, n_genes=20, ax=panel)
    panel.set_title(f'{GOI} {condition} (propagation): top 20 affected genes')
plt.show()

../_images/notebooks_05_perturbation_analysis_33_0.png

5.5 Perturbation Flow on Embedding

Project the perturbation-induced expression shifts onto the 2-D embedding using two approaches:

  • CellOracle-style (method='celloracle') — correlation-based neighbour voting

  • Hopfield-style (method='hopfield') — velocity derived directly from the model: Δv = v(x_perturbed) − v(x_original)

CellOracle-style flow (propagation-based shift)

[22]:
# CellOracle-style perturbation flow for the propagation-based KO and OE
# objects, followed by the inner product between the reference velocity flow
# and the perturbation flow on the same embedding.
n_neighbors_flow = 50

for shifted in (adata_ko, adata_oe):
    sch.tl.calculate_flow(
        shifted, basis=BASIS, method='celloracle',
        n_neighbors=n_neighbors_flow, correlation_mode='sampled', sigma_corr=0.05,
    )
    # Called inside the loop so the returned array is not echoed as cell output.
    # NOTE(review): presumably the result is also stored on the AnnData object
    # for the later plotting call — confirm.
    sch.tl.calculate_inner_product(
        shifted,
        flow_key_1=f'velocity_S_{BASIS}',
        flow_key_2=f'perturbation_flow_{BASIS}',
    )

Projecting using embedding-space correlation (CellOracle style)...
Flow stored in adata.obsm['perturbation_flow_draw_graph_fa']
Projecting using embedding-space correlation (CellOracle style)...
Flow stored in adata.obsm['perturbation_flow_draw_graph_fa']
[22]:
array([-0.98475942, -0.89885457,  0.90522183, ...,  0.69439025,
       -0.66494506,  0.72592479])
[40]:
# 3x2 overview figure:
#   row 0 — cell types | reference velocity
#   row 1 — KO perturbation flow | OE perturbation flow
#   row 2 — KO inner product     | OE inner product
fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(10, 15), tight_layout=True)

# Cell-type scatter (clusters without a known colour fall back to gray)
embedding = adata_ko.obsm[f'X_{BASIS}']
c = [colors.get(label, 'gray') for label in adata_ko.obs[CLUSTER_KEY]]
ax_types = axes[0, 0]
ax_types.scatter(embedding[:, 0], embedding[:, 1], c=c, s=10, alpha=0.7)
ax_types.set_title('Cell types')
ax_types.axis('off')

# Reference velocity, drawn from the unperturbed object
sch.pl.plot_flow(
    adata, basis=BASIS, ax=axes[0, 1], on_grid=True,
    flow_key=f'velocity_S_{BASIS}',
    n_grid=25, min_mass=45, scale=1000, color='black',
    cluster_key=CLUSTER_KEY, colors=colors, title='Reference velocity (grid)',
)

# (condition label, perturbed object, arrow colour), one entry per column
flow_panels = (
    ('KO', adata_ko, '#E74C3C'),
    ('OE', adata_oe, '#3498DB'),
)

# Perturbation flow on a grid
for col, (condition, shifted, arrow_color) in enumerate(flow_panels):
    sch.pl.plot_flow(
        shifted, basis=BASIS, ax=axes[1, col], on_grid=True,
        flow_key=f'perturbation_flow_{BASIS}',
        n_grid=25, min_mass=25, scale=5, color=arrow_color,
        cluster_key=CLUSTER_KEY, colors=colors,
        title=f'{GOI} {condition} — perturbation flow (grid)',
    )

# Inner product of reference and perturbation flow
for col, (condition, shifted, arrow_color) in enumerate(flow_panels):
    sch.pl.plot_inner_product(shifted, basis=BASIS, ax=axes[2, col],
                              title=f'{GOI} {condition} — inner product')

plt.show()

../_images/notebooks_05_perturbation_analysis_37_0.png

Hopfield-style flow (ODE-based shift)

[24]:
n_neighbors_hopfield = 30

# For each ODE-perturbed object (KO, then OE) compute two Hopfield flows:
#   'delta'     — Δv = v' − v₀
#   'perturbed' — v'
# and, for each, its inner product with the reference velocity flow.
for condition, shifted in (('KO', adata_ode_ko), ('OE', adata_ode_oe)):
    for source in ('delta', 'perturbed'):
        flow_key = f'{source}_velocity_hopfield'

        print(f"Computing Hopfield {source} velocity — {GOI} {condition}...")
        sch.tl.calculate_flow(
            shifted, basis=BASIS, method='hopfield',
            cluster_key=CLUSTER_KEY, use_cluster_specific=True,
            source=source, store_key=flow_key,
            n_neighbors=n_neighbors_hopfield, n_jobs=-1, verbose=False,
        )
        sch.tl.calculate_inner_product(
            shifted,
            flow_key_1=f'velocity_S_{BASIS}',
            flow_key_2=flow_key,
            store_key=f'{source}_inner_product_hopfield',
        )

print("Done computing all Hopfield flows.")

Computing Hopfield delta velocity — Gata1 KO...
Computing Hopfield perturbed velocity — Gata1 KO...
Computing Hopfield delta velocity — Gata1 OE...
Computing Hopfield perturbed velocity — Gata1 OE...
Done computing all Hopfield flows.
[41]:
# Visualise Hopfield delta and perturbed flow for KO:
# top row = flow on a grid, bottom row = inner product with the reference
# velocity; left column = Δv, right column = v'.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(16, 12), tight_layout=True)

# (key prefix, symbol used in the panel title), one entry per column
ko_columns = (('delta', 'Δv'), ('perturbed', "v'"))

for col, (prefix, symbol) in enumerate(ko_columns):
    sch.pl.plot_flow(
        adata_ode_ko, flow_key=f'{prefix}_velocity_hopfield', basis=BASIS, on_grid=True,
        ax=axes[0, col], n_grid=25, min_mass=25, scale=5000, color='#E74C3C',
        cluster_key=CLUSTER_KEY, colors=colors,
        title=f'{GOI} KO — Hopfield {symbol} (grid)',
    )

for col, (prefix, symbol) in enumerate(ko_columns):
    sch.pl.plot_inner_product(
        adata_ode_ko, basis=BASIS, ax=axes[1, col],
        inner_product_key=f'{prefix}_inner_product_hopfield',
        title=f'{GOI} KO — Hopfield {symbol} inner product',
    )
plt.show()

../_images/notebooks_05_perturbation_analysis_40_0.png
[42]:
# Hopfield delta and perturbed flow for OE:
# top row = flow on a grid, bottom row = inner product with the reference
# velocity; left column = Δv, right column = v'.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(16, 12), tight_layout=True)

# (key prefix, symbol used in the panel title), one entry per column
oe_columns = (('delta', 'Δv'), ('perturbed', "v'"))

for col, (prefix, symbol) in enumerate(oe_columns):
    sch.pl.plot_flow(
        adata_ode_oe, flow_key=f'{prefix}_velocity_hopfield', basis=BASIS, on_grid=True,
        ax=axes[0, col], n_grid=25, min_mass=25, scale=5000, color='#3498DB',
        cluster_key=CLUSTER_KEY, colors=colors,
        title=f'{GOI} OE — Hopfield {symbol} (grid)',
    )

for col, (prefix, symbol) in enumerate(oe_columns):
    sch.pl.plot_inner_product(
        adata_ode_oe, basis=BASIS, ax=axes[1, col],
        inner_product_key=f'{prefix}_inner_product_hopfield',
        title=f'{GOI} OE — Hopfield {symbol} inner product',
    )
plt.show()

../_images/notebooks_05_perturbation_analysis_41_0.png