Source code for scHopfield.plotting.energy

"""Plotting functions for energy landscapes."""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Optional, Union, List, Dict
from anndata import AnnData


def plot_energy_landscape(
    adata: AnnData,
    cluster: str,
    basis: str = 'umap',
    ax: Optional[plt.Axes] = None,
    **kwargs
) -> plt.Axes:
    """
    Plot energy landscape on embedding space.

    Draws a filled contour plot of the gridded energy stored for ``cluster``
    and attaches a colorbar labelled ``'Energy'``.

    Parameters
    ----------
    adata : AnnData
        Annotated data object with energy embedding. The grid is read from
        ``adata.uns['scHopfield']`` under the keys ``grid_X_{cluster}``,
        ``grid_Y_{cluster}`` and ``grid_energy_{cluster}``.
    cluster : str
        Cluster name
    basis : str, optional
        Embedding basis; only used to build the axis labels.
    ax : plt.Axes, optional
        Axes to plot on. A new 8x6 figure is created when omitted.
    **kwargs
        Forwarded to ``ax.contourf``. ``levels`` (default 20) and ``cmap``
        (default ``'viridis'``) may be overridden here.

    Returns
    -------
    plt.Axes
        Axes with plot
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(8, 6))

    store = adata.uns['scHopfield']
    grid_X = store[f'grid_X_{cluster}']
    grid_Y = store[f'grid_Y_{cluster}']
    grid_energy = store[f'grid_energy_{cluster}']

    # Apply defaults with setdefault so callers can override them; passing
    # them literally alongside **kwargs raised "multiple values for keyword".
    kwargs.setdefault('levels', 20)
    kwargs.setdefault('cmap', 'viridis')
    im = ax.contourf(grid_X, grid_Y, grid_energy, **kwargs)

    ax.set_xlabel(f'{basis.upper()} 1')
    ax.set_ylabel(f'{basis.upper()} 2')
    ax.set_title(f'Energy Landscape: {cluster}')
    # Use the figure that owns ``ax`` instead of pyplot's "current" figure,
    # which can differ when the caller supplies ``ax``.
    ax.figure.colorbar(im, ax=ax, label='Energy')
    return ax
def plot_energy_components(
    adata: AnnData,
    cluster: str,
    basis: str = 'umap'
) -> plt.Figure:
    """
    Plot all energy components (total, interaction, degradation, bias) for a cluster.

    Each component is drawn as a filled contour plot in a 2x2 grid of
    subplots, each with its own colorbar.

    Parameters
    ----------
    adata : AnnData
        Annotated data object. The grid is read from
        ``adata.uns['scHopfield']`` under ``grid_X_{cluster}``,
        ``grid_Y_{cluster}`` and ``grid_energy_{etype}_{cluster}`` for each
        ``etype`` in ``total``, ``interaction``, ``degradation``, ``bias``.
    cluster : str
        Cluster name
    basis : str, optional
        Embedding basis; used to build the axis labels.

    Returns
    -------
    plt.Figure
        Figure with subplots
    """
    store = adata.uns['scHopfield']
    # The grid coordinates are shared by every component, so read them once
    # (and before creating the figure, so a missing key leaks no figure).
    grid_X = store[f'grid_X_{cluster}']
    grid_Y = store[f'grid_Y_{cluster}']

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    energy_types = ['total', 'interaction', 'degradation', 'bias']
    for ax, etype in zip(axes.flat, energy_types):
        grid_energy = store[f'grid_energy_{etype}_{cluster}']
        im = ax.contourf(grid_X, grid_Y, grid_energy, levels=20, cmap='viridis')
        ax.set_title(f'{etype.capitalize()} Energy')
        ax.set_xlabel(f'{basis.upper()} 1')
        ax.set_ylabel(f'{basis.upper()} 2')
        fig.colorbar(im, ax=ax)
    fig.tight_layout()
    return fig
def plot_energy_boxplots(
    adata: AnnData,
    cluster_key: str = 'cell_type',
    order: Optional[List[str]] = None,
    plot_energy: str = 'all',
    colors: Optional[Union[List, Dict]] = None,
    palette: Optional[str] = None,
    show_points: bool = False,
    **fig_kws
) -> Union[plt.Figure, plt.Axes]:
    """
    Plot energy distributions for different clusters using boxplots.

    Parameters
    ----------
    adata : AnnData
        Annotated data object with computed energies stored in
        ``adata.obs`` as ``energy_total``, ``energy_interaction``,
        ``energy_degradation`` and ``energy_bias``.
    cluster_key : str, optional (default: 'cell_type')
        Key in adata.obs for cluster labels
    order : list, optional
        Order of clusters to display in the boxplots. Defaults to the order
        of first appearance in ``adata.obs[cluster_key]``.
    plot_energy : str, optional (default: 'all')
        Which energy to plot: 'all', 'total', 'interaction', 'degradation', or 'bias'
    colors : list, tuple or dict, optional
        Colors for each cluster. Overrides palette. A dict maps cluster name
        to color (clusters missing from it are drawn gray); a sequence is
        assigned to the clusters in ``order``.
    palette : str, optional
        Seaborn palette name (e.g., 'Set2', 'husl', 'tab10')
    show_points : bool, optional (default: False)
        If True, overlay individual points as strip plot
    **fig_kws
        Additional keyword arguments for plt.subplots()

    Returns
    -------
    plt.Figure or plt.Axes
        The figure (if plot_energy='all') or the single axes otherwise.

    Raises
    ------
    ValueError
        If ``colors`` is a sequence with fewer entries than clusters.

    Examples
    --------
    >>> import scHopfield as sch
    >>> sch.pl.plot_energy_boxplots(adata, cluster_key='cell_type')
    >>> sch.pl.plot_energy_boxplots(adata, plot_energy='interaction', palette='Set2')
    """
    if order is None:
        order = adata.obs[cluster_key].unique().tolist()

    # Resolve colors before creating the figure so bad input leaks no figure.
    plot_palette = None
    if colors is not None:
        if isinstance(colors, dict):
            plot_palette = [colors.get(k, 'gray') for k in order]
        elif isinstance(colors, (list, tuple)):
            # Explicit check: an ``assert`` would be stripped under ``python -O``.
            if len(colors) < len(order):
                raise ValueError(
                    "Colors list should have at least as many colors as clusters."
                )
            plot_palette = list(colors[:len(order)])
    elif palette is not None:
        plot_palette = palette

    # Set up figure
    if plot_energy == 'all':
        fig_kws.setdefault('figsize', (14, 10))
        fig, grid = plt.subplots(2, 2, **fig_kws)
        axs = grid.flatten()
        titles = ['Total Energy', 'Interaction Energy',
                  'Degradation Energy', 'Bias Energy']
        energy_cols = ['energy_total', 'energy_interaction',
                       'energy_degradation', 'energy_bias']
    else:
        fig_kws.setdefault('figsize', (10, 6))
        fig, single_ax = plt.subplots(1, 1, **fig_kws)
        axs = np.array([single_ax])
        titles = [f'{plot_energy.capitalize()} Energy']
        energy_cols = [f'energy_{plot_energy.lower()}']
    for ax, title in zip(axs, titles):
        ax.set_title(title, fontsize=12, fontweight='bold', pad=10)

    # Create boxplots
    for energy_col, ax in zip(energy_cols, axs):
        # Missing energies are reported in-panel instead of raising.
        if energy_col not in adata.obs.columns:
            ax.text(0.5, 0.5,
                    f'{energy_col} not found\nRun sch.tl.compute_energies() first',
                    ha='center', va='center', transform=ax.transAxes, fontsize=10)
            continue

        df = pd.DataFrame({
            'Cluster': adata.obs[cluster_key],
            'Energy': adata.obs[energy_col]
        })

        # NOTE(review): seaborn >= 0.13 warns when ``palette`` is passed
        # without ``hue``; switching to ``hue='Cluster', legend=False`` would
        # silence it but ``legend=`` is not accepted by older seaborn.
        sns.boxplot(
            data=df, x='Cluster', y='Energy', order=order, ax=ax,
            palette=plot_palette, linewidth=1.5, fliersize=3, width=0.6
        )

        # Optionally add strip plot for individual points
        if show_points:
            sns.stripplot(
                data=df, x='Cluster', y='Energy', order=order, ax=ax,
                color='black', alpha=0.3, size=2, jitter=True
            )

        # Styling
        ax.set_xlabel('Cell Type', fontsize=10, fontweight='bold')
        ax.set_ylabel('Energy', fontsize=10, fontweight='bold')
        ax.grid(True, alpha=0.3, linestyle='--', axis='y')
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

        # Rotate x-axis labels in place if many clusters
        if len(order) > 5:
            plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

    fig.tight_layout()
    return fig if plot_energy == 'all' else axs[0]
def plot_energy_scatters(
    adata: AnnData,
    cluster_key: str = 'cell_type',
    basis: str = 'umap',
    order: Optional[List[str]] = None,
    plot_energy: str = 'all',
    show_legend: bool = True,
    colors: Optional[Union[List, Dict]] = None,
    palette: Optional[str] = None,
    alpha: float = 0.6,
    s: float = 20,
    elev: float = 30,
    azim: float = -60,
    **fig_kws
) -> Union[plt.Figure, plt.Axes]:
    """
    Plot energy landscapes for different clusters using 3D scatter plots.

    Each cell is placed at the first two coordinates of
    ``adata.obsm[f'X_{basis}']`` with its energy on the z axis, colored by
    cluster.

    Parameters
    ----------
    adata : AnnData
        Annotated data object with computed energies stored in
        ``adata.obs`` as ``energy_total``, ``energy_interaction``,
        ``energy_degradation`` and ``energy_bias``.
    cluster_key : str, optional (default: 'cell_type')
        Key in adata.obs for cluster labels
    basis : str, optional (default: 'umap')
        The basis used for embedding (read from ``adata.obsm['X_<basis>']``)
    order : list, optional
        Order of clusters to display. Defaults to the order of first
        appearance in ``adata.obs[cluster_key]``.
    plot_energy : str, optional (default: 'all')
        Which energy to plot: 'all', 'total', 'interaction', 'degradation', or 'bias'
    show_legend : bool, optional (default: True)
        Whether to show legend
    colors : list, tuple or dict, optional
        Colors for each cluster. Overrides palette. A dict maps cluster name
        to color (clusters missing from it are drawn gray); a sequence is
        assigned to the clusters in ``order``.
    palette : str, optional
        Seaborn or matplotlib colormap name (e.g., 'tab10', 'Set2')
    alpha : float, optional (default: 0.6)
        Transparency of points
    s : float, optional (default: 20)
        Size of points
    elev : float, optional (default: 30)
        Elevation viewing angle
    azim : float, optional (default: -60)
        Azimuthal viewing angle
    **fig_kws
        Additional keyword arguments for plt.subplots()

    Returns
    -------
    plt.Figure or plt.Axes
        Figure (if plot_energy='all') or single axes

    Raises
    ------
    ValueError
        If ``adata.obsm`` has no ``X_<basis>`` entry, if ``colors`` is a
        sequence with fewer entries than clusters, or if ``palette`` is
        neither a matplotlib colormap nor a seaborn palette name.
    TypeError
        If ``colors`` is neither a dict nor a list/tuple.

    Examples
    --------
    >>> import scHopfield as sch
    >>> sch.pl.plot_energy_scatters(adata, cluster_key='cell_type')
    >>> sch.pl.plot_energy_scatters(adata, plot_energy='interaction', palette='tab10')
    """
    if order is None:
        order = adata.obs[cluster_key].unique().tolist()

    # Validate inputs before creating any figure so failures leak no figure.
    embedding_key = f'X_{basis}'
    if embedding_key not in adata.obsm:
        raise ValueError(f"Embedding '{embedding_key}' not found in adata.obsm. "
                         f"Available: {list(adata.obsm.keys())}")

    # Map every cluster in ``order`` to a color.
    if colors is not None:
        if isinstance(colors, dict):
            color_map = {k: colors.get(k, 'gray') for k in order}
        elif isinstance(colors, (list, tuple)):
            # Explicit check: an ``assert`` would be stripped under ``python -O``.
            if len(colors) < len(order):
                raise ValueError(
                    "Colors list should have at least as many colors as clusters."
                )
            color_map = dict(zip(order, colors))
        else:
            raise TypeError(
                f"colors must be a list, tuple or dict, got {type(colors).__name__}"
            )
    elif palette is not None:
        # ``matplotlib.cm.get_cmap`` was removed in Matplotlib 3.9;
        # ``pyplot.get_cmap`` is the supported lookup.
        try:
            cmap = plt.get_cmap(palette)
        except ValueError:
            # Not a matplotlib colormap: treat it as a seaborn palette name.
            color_map = dict(zip(order, sns.color_palette(palette, n_colors=len(order))))
        else:
            color_map = {k: cmap(i / len(order)) for i, k in enumerate(order)}
    else:
        # Default to tab10, cycling when there are more than 10 clusters.
        cmap = plt.get_cmap('tab10')
        color_map = {k: cmap(i % 10) for i, k in enumerate(order)}

    # Set up figure
    if plot_energy == 'all':
        fig_kws.setdefault('figsize', (16, 12))
        fig, grid = plt.subplots(2, 2, subplot_kw={'projection': '3d'}, **fig_kws)
        axs = grid.flatten()
        titles = ['Total Energy', 'Interaction Energy',
                  'Degradation Energy', 'Bias Energy']
        energy_cols = ['energy_total', 'energy_interaction',
                       'energy_degradation', 'energy_bias']
    else:
        fig_kws.setdefault('figsize', (10, 8))
        fig, single_ax = plt.subplots(1, 1, subplot_kw={'projection': '3d'}, **fig_kws)
        axs = np.array([single_ax])
        titles = [f'{plot_energy.capitalize()} Energy']
        energy_cols = [f'energy_{plot_energy.lower()}']
    for ax, title in zip(axs, titles):
        ax.set_title(title, fontsize=12, fontweight='bold', pad=15)

    # Per-cluster masks and 2-D coordinates do not depend on the energy
    # column, so compute them once rather than once per subplot.
    coords = np.asarray(adata.obsm[embedding_key])[:, :2]
    labels = adata.obs[cluster_key]
    masks = {k: (labels == k).to_numpy() for k in order}

    # Plot each cluster
    for ax, energy_col in zip(axs, energy_cols):
        # Missing energies are reported in-panel instead of raising.
        if energy_col not in adata.obs.columns:
            ax.text2D(0.5, 0.5,
                      f'{energy_col} not found\nRun sch.tl.compute_energies() first',
                      ha='center', va='center', transform=ax.transAxes, fontsize=10)
            continue

        energies = adata.obs[energy_col].to_numpy()
        for k in order:
            mask = masks[k]
            ax.scatter(coords[mask, 0], coords[mask, 1], energies[mask],
                       c=[color_map[k]], label=k, alpha=alpha, s=s,
                       edgecolors='none')

        # Styling
        ax.set_xlabel(f'{basis.upper()} 1', fontsize=10, labelpad=8)
        ax.set_ylabel(f'{basis.upper()} 2', fontsize=10, labelpad=8)
        ax.set_zlabel('Energy', fontsize=10, labelpad=8)
        ax.view_init(elev=elev, azim=azim)
        ax.grid(True, alpha=0.3, linestyle='--')

        # Legend placed outside the plot area
        if show_legend:
            ax.legend(loc='upper left', bbox_to_anchor=(1.05, 1),
                      frameon=True, framealpha=0.9, fontsize=8)

    fig.tight_layout()
    return fig if plot_energy == 'all' else axs[0]