Source code for scHopfield.plotting.dynamics

"""Plotting functions for dynamics and trajectories."""

import matplotlib.pyplot as plt
import numpy as np
from typing import Optional


def plot_trajectory(
    trajectory: np.ndarray,
    t_span: np.ndarray,
    gene_names: Optional[list] = None,
    ax: Optional[plt.Axes] = None,
    max_legend_entries: int = 10,
) -> plt.Axes:
    """
    Plot gene expression trajectories over time.

    Each gene (column of ``trajectory``) is drawn as one line against
    ``t_span``.

    Parameters
    ----------
    trajectory : np.ndarray
        Trajectory array (n_timepoints × n_genes). A 1-D array is treated
        as the trajectory of a single gene.
    t_span : np.ndarray
        Time points, one per row of ``trajectory``.
    gene_names : list, optional
        Gene names used as line labels. Any indexable sequence with a
        length works (list, numpy array, pandas Index). Genes whose index
        is beyond the end of ``gene_names`` are labelled ``'Gene {i}'``.
    ax : plt.Axes, optional
        Axes to plot on. If omitted, a new 10x6 figure is created.
    max_legend_entries : int, default 10
        A legend is drawn only when the number of genes is at most this
        value, so that plots with many genes are not swamped by it.

    Returns
    -------
    plt.Axes
        Axes with plot
    """
    trajectory = np.asarray(trajectory)
    # Promote a single-gene (1-D) trajectory to one column so the
    # per-gene loop below works uniformly.
    if trajectory.ndim == 1:
        trajectory = trajectory[:, np.newaxis]

    if ax is None:
        _, ax = plt.subplots(figsize=(10, 6))

    n_genes = trajectory.shape[1]
    # Compare against None explicitly: bool() of a numpy array or pandas
    # Index with several elements raises ValueError.
    n_names = 0 if gene_names is None else len(gene_names)

    for i in range(n_genes):
        label = gene_names[i] if i < n_names else f'Gene {i}'
        ax.plot(t_span, trajectory[:, i], label=label, alpha=0.7)

    ax.set_xlabel('Time')
    ax.set_ylabel('Expression')
    ax.set_title('Gene Expression Trajectory')

    # Skip the legend when nothing was plotted (avoids matplotlib's
    # "No artists with labels" warning) or when it would be too crowded.
    if 0 < n_genes <= max_legend_entries:
        ax.legend()

    return ax