Source code for scHopfield.plotting.genes

"""Plotting functions for gene-level analysis."""

import numpy as np
import matplotlib.pyplot as plt
from typing import Optional
from anndata import AnnData
from .._utils.math import sigmoid
from .._utils.io import get_matrix, to_numpy


def plot_sigmoid_fit(
    adata: AnnData,
    gene: str,
    spliced_key: str = 'Ms',
    color_clusters: bool = False,
    cluster_key: str = 'cell_type',
    show_zeros: bool = True,
    ax: Optional[plt.Axes] = None,
    **kwargs
) -> plt.Axes:
    """
    Plot sigmoid fit for a gene showing expression CDF and fitted curve.

    The gene's expression values are sorted and plotted against an evenly
    spaced grid on [0, 1] (the empirical CDF); the fitted sigmoid stored in
    ``adata.var`` is drawn on top, evaluated at the same sorted values.

    Parameters
    ----------
    adata : AnnData
        Annotated data object with fitted sigmoid parameters. The columns
        ``sigmoid_threshold``, ``sigmoid_exponent``, ``sigmoid_offset`` and
        ``sigmoid_mse`` are read from ``adata.var``.
    gene : str
        Gene name to plot
    spliced_key : str, optional (default: 'Ms')
        Layer key for spliced expression
    color_clusters : bool, optional (default: False)
        If True, color points by cluster. Each cluster is sorted and given
        its own 0-to-1 CDF, so every cluster traces a separate curve.
    cluster_key : str, optional (default: 'cell_type')
        Key in adata.obs for cluster labels (used if color_clusters=True).
        If the key is absent, a warning is issued and all cells are drawn
        in a single color.
    show_zeros : bool, optional (default: True)
        If True, show all expression values including zeros.
        If False, filter out zero values and plot sigmoid without offset.
    ax : plt.Axes, optional
        Axes to plot on. If None, creates new figure.
    **kwargs
        Additional keyword arguments:
        - c1: color for expression data (default: 'gray')
        - c2: color for fitted curve (default: 'red')
        - alpha: transparency for scatter points (default: 0.5)
        - s: size for scatter points (default: 10)

    Returns
    -------
    plt.Axes
        Axes with plot

    Raises
    ------
    ValueError
        If the sigmoid parameter columns are missing from ``adata.var``, or
        if ``gene`` is not in ``adata.var_names``.

    Examples
    --------
    >>> import scHopfield as sch
    >>> sch.pl.plot_sigmoid_fit(adata, 'Gata1')
    >>> sch.pl.plot_sigmoid_fit(adata, 'Gata1', color_clusters=True)
    >>> sch.pl.plot_sigmoid_fit(adata, 'Gata1', show_zeros=False)  # Hide zeros
    """
    import warnings

    # Validate inputs before touching pyplot, so a bad call does not leave
    # an orphan empty figure behind.
    required_columns = (
        'sigmoid_threshold',
        'sigmoid_exponent',
        'sigmoid_offset',
        'sigmoid_mse',
    )
    if any(column not in adata.var.columns for column in required_columns):
        raise ValueError("Sigmoid parameters not found. Run sch.pp.fit_all_sigmoids() first.")

    if gene not in adata.var_names:
        raise ValueError(f"Gene '{gene}' not found in adata.var_names")
    gene_idx = adata.var_names.get_loc(gene)

    if ax is None:
        _, ax = plt.subplots(figsize=(6, 5))

    # Genes flagged as unused get a placeholder panel instead of a plot.
    if 'scHopfield_used' in adata.var.columns and not adata.var['scHopfield_used'].iloc[gene_idx]:
        ax.text(0.5, 0.5, f'{gene}\nNot included in analysis',
                ha='center', va='center', transform=ax.transAxes)
        ax.set_xlabel('Expression')
        ax.set_ylabel('CDF')
        return ax

    # Fetch the per-cell expression vector once; the unfiltered copy is
    # reused for per-cluster masking because it stays aligned with adata.obs.
    gexp_full = to_numpy(get_matrix(adata, spliced_key, genes=[gene_idx])).flatten()

    if show_zeros:
        gexp = gexp_full
    else:
        gexp = gexp_full[gexp_full > 0]
        if len(gexp) == 0:
            ax.text(0.5, 0.5, f'{gene}\nNo non-zero expression values',
                    ha='center', va='center', transform=ax.transAxes)
            ax.set_xlabel('Expression')
            ax.set_ylabel('CDF')
            return ax

    # Sort expression and create empirical CDF
    sorted_expr = np.sort(gexp)
    empirical_cdf = np.linspace(0, 1, len(sorted_expr))

    # Fitted parameters for this gene, addressed positionally
    threshold = adata.var['sigmoid_threshold'].iloc[gene_idx]
    exponent = adata.var['sigmoid_exponent'].iloc[gene_idx]
    offset = adata.var['sigmoid_offset'].iloc[gene_idx]
    mse = adata.var['sigmoid_mse'].iloc[gene_idx]

    # Styling overrides; any other keyword argument is ignored
    c1 = kwargs.get('c1', 'gray')
    c2 = kwargs.get('c2', 'red')
    alpha = kwargs.get('alpha', 0.5)
    size = kwargs.get('s', 10)

    has_cluster_labels = cluster_key in adata.obs.columns
    if color_clusters and not has_cluster_labels:
        # Keep the best-effort single-color fallback, but make it visible.
        warnings.warn(
            f"cluster_key '{cluster_key}' not found in adata.obs; "
            "plotting all cells in a single color.",
            UserWarning,
            stacklevel=2,
        )

    if color_clusters and has_cluster_labels:
        # One scatter per cluster, each with its own within-cluster CDF
        for cluster in adata.obs[cluster_key].unique():
            cluster_mask = (adata.obs[cluster_key] == cluster).values
            cluster_expr = gexp_full[cluster_mask]
            if not show_zeros:
                cluster_expr = cluster_expr[cluster_expr > 0]
            if len(cluster_expr) > 0:
                cluster_expr_sorted = np.sort(cluster_expr)
                cluster_cdf = np.linspace(0, 1, len(cluster_expr_sorted))
                ax.scatter(cluster_expr_sorted, cluster_cdf, s=size, alpha=alpha,
                           label=f'{cluster}', rasterized=True)
    else:
        # Single color
        ax.scatter(sorted_expr, empirical_cdf, s=size, alpha=alpha,
                   color=c1, label='Expression', rasterized=True)

    # Compute fitted sigmoid curve
    fitted_curve = sigmoid(sorted_expr, threshold, exponent)
    if show_zeros:
        # With zeros shown the curve is rescaled to start at the offset:
        # sigmoid(x) * (1 - offset) + offset
        fitted_curve = fitted_curve * (1 - offset) + offset

    # Plot fitted curve above the scatter points
    ax.plot(sorted_expr, fitted_curve, '-', linewidth=2.5, color=c2,
            label='Sigmoid fit', zorder=10)

    # Sigmoid formula rendered as mathtext: x^n / (x^n + k^n)
    sigmoid_formula = r"$\frac{{x^{{{:.2f}}}}}{{x^{{{:.2f}}} + {:.2f}^{{{:.2f}}}}}$".format(
        exponent, exponent, threshold, exponent
    )

    # The MSE shown is the stored fit value in both modes; it is not
    # recomputed for the zero-filtered view.
    if show_zeros:
        textstr = f'{sigmoid_formula}\nOffset = {offset:.3f}\nMSE = {mse:.4f}'
    else:
        textstr = f'{sigmoid_formula}\n(no offset, zeros excluded)\nMSE = {mse:.4f}'

    # Position text in upper left
    ax.text(0.05, 0.95, textstr, transform=ax.transAxes,
            fontsize=10, verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.3))

    # Styling
    ax.set_xlabel('Expression', fontsize=11)
    ax.set_ylabel('Cumulative Distribution', fontsize=11)
    ax.set_title(f'{gene}', fontsize=12, fontweight='bold')
    ax.grid(True, alpha=0.3, linestyle='--')
    ax.legend(loc='lower right', framealpha=0.9)
    ax.set_xlim(left=0)
    ax.set_ylim([0, 1])

    return ax