Source code for scHopfield.plotting.correlation

"""Plotting functions for energy-gene correlations."""

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba
from typing import Optional, Union, List, Dict
from anndata import AnnData

from .._utils.io import get_genes_used


[docs] def plot_gene_correlation_scatter( adata: AnnData, clus1: str, clus2: str, energy: str = 'total', cluster_key: str = 'cell_type', ax: Optional[plt.Axes] = None, annotate: Optional[int] = None, clus1_low: float = -0.5, clus1_high: float = 0.5, clus2_low: float = -0.5, clus2_high: float = 0.5 ) -> plt.Axes: """ Plot scatter of gene correlations between two clusters. Creates a scatter plot comparing the gene correlations with energy landscapes between two clusters, highlighting genes with divergent behavior (strongly positive in one cluster, strongly negative in the other). Parameters ---------- adata : AnnData Annotated data object with computed correlations clus1 : str First cluster name (x-axis) clus2 : str Second cluster name (y-axis) energy : str, optional (default: 'total') Energy type: 'total', 'interaction', 'degradation', or 'bias' cluster_key : str, optional (default: 'cell_type') Key in adata.obs for cluster labels ax : plt.Axes, optional Axes to plot on. If None, creates new figure annotate : int, optional If provided, annotates top N divergent genes clus1_low : float, optional (default: -0.5) Lower threshold for clus1 to identify divergent genes clus1_high : float, optional (default: 0.5) Upper threshold for clus1 to identify divergent genes clus2_low : float, optional (default: -0.5) Lower threshold for clus2 to identify divergent genes clus2_high : float, optional (default: 0.5) Upper threshold for clus2 to identify divergent genes Returns ------- plt.Axes Axes with plot """ genes = get_genes_used(adata) gene_names = adata.var.index[genes] # Get correlations corr_col1 = f'correlation_{energy}_{clus1}' corr_col2 = f'correlation_{energy}_{clus2}' if corr_col1 not in adata.var.columns or corr_col2 not in adata.var.columns: raise ValueError( "Correlation data not found. Please run sch.tl.energy_gene_correlation() first." ) corr1 = adata.var[corr_col1].values[genes] corr2 = adata.var[corr_col2].values[genes] # Create a new figure and axes if none are provided if ax is None: fig, ax = plt.subplots(1, 1, figsize=(8, 8), tight_layout=True) # Set the limits for the axes ax.set_xlim((-1, 1)) ax.set_ylim((-1, 1)) # Identify correlations that are in opposite corners of the plot positions_corners = np.logical_or( np.logical_and(corr1 >= clus1_high, corr2 <= clus2_low), np.logical_and(corr1 <= clus1_low, corr2 >= clus2_high) ) corr_corners = np.where(positions_corners)[0] corr_center = np.where(~positions_corners)[0] # Plot the correlations using different colors for clarity ax.scatter(corr1[corr_corners], corr2[corr_corners], c='k', s=0.6, label='Divergent Correlations') ax.scatter(corr1[corr_center], corr2[corr_center], c='lightgray', s=0.5, label='Other Correlations') # Annotate top N divergent genes if requested if annotate is not None: nn = annotate # Get top N genes with the highest absolute correlation values cor_indices = np.argsort((corr1[corr_corners])**2 + (corr2[corr_corners])**2)[-nn:] # Get the names of the top N genes gois = gene_names[corr_corners][cor_indices] # Adding labels for the top N genes arrow_dict = {"width": 0.5, "headwidth": 0.5, "headlength": 1, "color": "gray"} for gg, xx, yy in zip(gois, corr1[corr_corners][cor_indices], corr2[corr_corners][cor_indices]): rand_shift_1 = np.random.uniform(-0.08, 0.08) rand_shift_2 = np.random.uniform(-0.08, 0.08) ax.annotate(gg, xy=(xx, yy), xytext=(xx+rand_shift_1, yy+rand_shift_2), arrowprops=arrow_dict) # Add reference lines ax.vlines([clus1_low, clus1_high], ymin=-1, ymax=1, linestyles='dashed', color='r') ax.hlines([clus2_low, clus2_high], xmin=-1, xmax=1, linestyles='dashed', color='r') ax.set_xlabel(clus1) ax.set_ylabel(clus2) return ax
[docs] def plot_correlations_grid( adata: AnnData, cluster_key: str = 'cell_type', energy: str = 'total', order: Optional[List[str]] = None, colors: Optional[Union[List, Dict]] = None, x_low: float = -0.5, x_high: float = 0.5, y_low: float = -0.5, y_high: float = 0.5, **kwargs ) -> plt.Figure: """ Plot grid of correlation scatter plots between all pairs of clusters. Creates a matrix where the diagonal shows cluster names and the off-diagonal plots show gene correlation scatter plots between clusters. Parameters ---------- adata : AnnData Annotated data object with computed correlations cluster_key : str, optional (default: 'cell_type') Key in adata.obs for cluster labels energy : str, optional (default: 'total') Energy type: 'total', 'interaction', 'degradation', or 'bias' order : list, optional Order of clusters to display. If None, uses all unique clusters colors : list or dict, optional Colors for each cluster. If dict, maps cluster names to colors. If list, colors in order matching `order` parameter. Colors should be RGBA tuples or RGB tuples. x_low : float, optional (default: -0.5) Lower x threshold for highlighting divergent genes x_high : float, optional (default: 0.5) Upper x threshold for highlighting divergent genes y_low : float, optional (default: -0.5) Lower y threshold for highlighting divergent genes y_high : float, optional (default: 0.5) Upper y threshold for highlighting divergent genes **kwargs Additional arguments: - figsize : tuple (default: (15, 15)) - tight_layout : bool (default: True) Returns ------- plt.Figure Figure with correlation grid """ if order is None: cell_types = adata.obs[cluster_key].unique().tolist() else: cell_types = order n = len(cell_types) figsize = kwargs.get('figsize', (15, 15)) tight_layout = kwargs.get('tight_layout', True) # Convert colors to dict if it's a list if colors is not None and not isinstance(colors, dict): colors = {cell_types[i]: colors[i] for i in range(len(cell_types))} fig, axs = plt.subplots(n, n, figsize=figsize, tight_layout=tight_layout) # Handle case where n=1 (axs is not an array) if n == 1: axs = np.array([[axs]]) elif n > 1 and axs.ndim == 1: axs = axs.reshape(n, n) for i in range(n): for j in range(i, n): if i == j: # Diagonal: show cluster name for spine in axs[i, j].spines.values(): spine.set_visible(True) spine.set_linewidth(2) if colors is not None: spine.set_color(colors[cell_types[i]]) # Remove ticks axs[i, j].set_xticks([]) axs[i, j].set_yticks([]) # Add text in the middle text = cell_types[i] text = text.replace(' ', '\n', 1) text = text.replace('-', '-\n') axs[i, j].text( 0.5, 0.5, text, ha='center', va='center', fontsize=18, fontweight='bold', fontname='serif', transform=axs[i, j].transAxes ) # Set background color if colors is not None: # Get the color (could be hex '#ff0000', name 'red', or list [1, 0, 0]) raw_color = colors[cell_types[i]] # Normalize everything to an RGBA tuple (values 0.0 to 1.0) rgba = list(to_rgba(raw_color)) # Set the alpha channel to 0.2 rgba[3] = 0.2 axs[i, j].set_facecolor(rgba) else: # Upper triangle: turn off axs[i, j].axis('off') # Lower triangle: plot correlation scatter plot_gene_correlation_scatter( adata, clus1=cell_types[i], clus2=cell_types[j], energy=energy, cluster_key=cluster_key, ax=axs[j, i], clus1_low=x_low, clus1_high=x_high, clus2_low=y_low, clus2_high=y_high ) # Clean up axes axs[j, i].set_xticks([]) axs[j, i].set_yticks([]) axs[j, i].set_xlabel('') axs[j, i].set_ylabel('') # Add ticks for edges if i == 0: # First column axs[j, i].set_yticks([-1, -0.5, 0, 0.5, 1]) if j == n - 1: # Last row axs[j, i].set_xticks([-1, -0.5, 0, 0.5, 1]) return fig