81 changes: 81 additions & 0 deletions pf2barcode/figures/figure4.py
@@ -0,0 +1,81 @@
"""
Figure 4 -- Recreation of the GEMLI paper's boxplot of correlation distances
for related cells (blue) versus randomly sampled cells.
"""

import numpy as np
import pandas as pd
import seaborn as sns
from scipy.spatial.distance import pdist

from pf2barcode.imports import import_CCLE

from .common import getSetup, subplotLabel


def makeFigure():
"""Boxplot of correlation distance for related and random cells per lineage."""
ax, f = getSetup((8, 4), (1, 1))
subplotLabel(ax)

X = import_CCLE(min_cell_count=10000, pca_option="none")

# Filter out unknown or rare barcodes
X = X[X.obs["SW"] != "unknown"]
good_SW = X.obs["SW"].value_counts().index[X.obs["SW"].value_counts() > 10]
X = X[X.obs["SW"].isin(good_SW)]

# Convert matrix to dense for correlation computation
mat = X.X.toarray()
Member: Now converted to dense on import, so this can go.
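If that suggestion is adopted, the conversion line could become something like the following (a sketch, assuming X.X is already a plain ndarray after import):

# X.X is already dense after import, so no sparse-to-dense conversion is needed
mat = np.asarray(X.X)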

df = X.obs.copy()
df["index"] = np.arange(len(df))

results = []

for sw, cells in df.groupby("SW"):
Member: Please move the calculation of the results to a separate function that only takes in the data it needs for the work. Please add comments to the loops and key operations explaining what is being done.

idx = cells["index"].values
if len(idx) < 2:
continue

# Related (within-lineage) distances
related = pdist(mat[idx], metric="correlation")

# Random (same number of pairs)
n_pairs = len(related)
n_cells = mat.shape[0]
random_corrs = []
for _ in range(100):
pairs = np.random.choice(n_cells, (n_pairs, 2), replace=True)
random_corrs.extend(
[1 - np.corrcoef(mat[i], mat[j])[0, 1] for i, j in pairs]
)

results.append(
pd.DataFrame(
{
"Correlation distance": np.concatenate([related, random_corrs]),
"Group": ["Cell lineage"] * len(related)
+ ["Random cells"] * len(random_corrs),
"Lineage": [sw] * (len(related) + len(random_corrs)),
}
)
)

df_plot = pd.concat(results, ignore_index=True)

sns.boxplot(
data=df_plot,
x="Lineage",
y="Correlation distance",
hue="Group",
showfliers=False,
palette={"Cell lineage": "#377eb8", "Random cells": "#bbbbbb"},
ax=ax[0],
)

ax[0].set_title("Correlation distance for related and random cells per lineage")
ax[0].set_xlabel("Lineage barcode (SW)")
ax[0].set_ylabel("Correlation distance")
ax[0].legend(title=None)

return f
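One possible shape for the refactor requested in the review comment above; the helper name compute_lineage_distances and its exact signature are hypothetical, and it reuses the numpy and pdist imports already present in this file:

def compute_lineage_distances(mat, idx, n_reps=100):
    """Correlation distances within one lineage, plus a random-pair background."""
    # Pairwise correlation distances among the cells sharing this barcode
    related = pdist(mat[idx], metric="correlation")

    # Draw the same number of random cell pairs n_reps times to build a background distribution
    random_corrs = []
    for _ in range(n_reps):
        pairs = np.random.choice(mat.shape[0], (len(related), 2), replace=True)
        random_corrs.extend(1 - np.corrcoef(mat[i], mat[j])[0, 1] for i, j in pairs)

    return related, random_corrs

makeFigure would then call this helper once per lineage and only assemble the plotting DataFrame from the returned arrays.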
107 changes: 18 additions & 89 deletions pf2barcode/imports.py
@@ -5,91 +5,11 @@
import scanpy as sc
from anndata import AnnData, concat
from anndata.io import read_text
from scipy.sparse import csr_array, csr_matrix
from scipy.special import xlogy
from scipy.sparse import csr_array
from sklearn.preprocessing import scale
from sklearn.utils.sparsefuncs import (
inplace_column_scale,
mean_variance_axis,
)


def prepare_dataset(X: AnnData, geneThreshold: float) -> AnnData:
assert isinstance(X.X, csr_matrix)
assert np.amin(X.X.data) >= 0.0

# Filter out genes with too few reads
readmean, _ = mean_variance_axis(X.X, axis=0) # type: ignore
X = X[:, readmean > geneThreshold]

# Normalize read depth
sc.pp.normalize_total(
X, exclude_highly_expressed=False, inplace=True, key_added="n_counts"
)

# Scale genes by sum
readmean, _ = mean_variance_axis(X.X, axis=0) # type: ignore
readsum = X.shape[0] * readmean
inplace_column_scale(X.X, 1.0 / readsum) # type: ignore

# Transform values
X.X.data = np.log10((1000.0 * X.X.data) + 1.0) # type: ignore

return X


def prepare_dataset_dev(X: AnnData) -> AnnData:
X.X = csr_array(X.X) # type: ignore
assert np.amin(X.X.data) >= 0.0

# Remove cells and genes with fewer than 30 reads
X = X[X.X.sum(axis=1) > 5, X.X.sum(axis=0) > 5]

# Copy so that the subsetting is preserved
X._init_as_actual(X.copy())

# deviance transform
y_ij = X.X.toarray() # type: ignore

# counts per cell
n_i_col = y_ij.sum(axis=1).reshape(-1, 1)

# MLE of gene expression
pi_j = y_ij.sum(axis=0) / np.sum(n_i_col)
mu_ij = n_i_col * pi_j[None, :]

# --- Calculate Deviance Terms using numerically stable xlogy ---
# D = 2 * [ y*log(y/mu) + (n-y)*log((n-y)/(n-mu)) ]
# D = 2 * [ (xlogy(y, y) - xlogy(y, mu)) + (xlogy(n-y, n-y) - xlogy(n-y, n-mu)) ]

n_minus_y = n_i_col - y_ij
n_minus_mu = n_i_col - mu_ij

# Term 1: y * log(y / mu) = xlogy(y, y) - xlogy(y, mu)
# xlogy handles y=0 case correctly returning 0.
term1 = xlogy(y_ij, y_ij) - xlogy(y_ij, mu_ij)

Member: I simplified the import code to make it easier to fix issues.

# Term 2: (n-y) * log((n-y) / (n-mu)) = xlogy(n-y, n-y) - xlogy(n-y, n-mu)
# xlogy handles n-y=0 case correctly returning 0.
term2 = xlogy(n_minus_y, n_minus_y) - xlogy(n_minus_y, n_minus_mu)

# Calculate full deviance: D = 2 * (term1 + term2)
# Handle potential floating point inaccuracies leading to small negatives
deviance = 2 * (term1 + term2)
deviance = np.maximum(deviance, 0.0) # Ensure non-negative before sqrt

# Calculate signed square root residuals: sign(y - mu) * sqrt(D)
signs = np.sign(y_ij - mu_ij)

# Reset sign to 0 if deviance is exactly 0 (e.g. y=0, mu=0 or y=n, mu=n)
# Avoids sign(-0.0) sometimes being -1
signs[deviance == 0] = 0

X.X = signs * np.sqrt(deviance)
return X


def import_CCLE(pca_option="dev_pca", n_comp=10) -> AnnData:
def import_CCLE(pca_option="pca", n_comp=10, min_cell_count: int = 10) -> AnnData:
    # pca_option should be either "pca" (default) or "none"
"""Imports barcoded cell data."""
adatas = {}
@@ -117,7 +37,7 @@ def import_CCLE(pca_option="dev_pca", n_comp=10) -> AnnData:
adatas[name] = data

X = concat(adatas, label="sample", index_unique="-")
X.X = csr_matrix(X.X)
X.X = csr_array(X.X).todense()

counts = X.obs["SW"].value_counts()
counts = counts[counts > 5]
@@ -129,15 +49,24 @@ def import_CCLE(pca_option="dev_pca", n_comp=10) -> AnnData:

# Counts per cell
X.obs["n_counts"] = X.X.sum(axis=1)
X = X[X.obs["n_counts"] >= min_cell_count, :]

# conditional statement for either dev_pca or pca
if pca_option == "dev_pca":
X = prepare_dataset_dev(X)
X.X = scale(X.X)
sc.pp.pca(X, n_comps=n_comp, svd_solver="arpack")
if pca_option == "none":
sc.pp.normalize_total(X, target_sum=40000)
else:
X = prepare_dataset(X, geneThreshold=0.001)
geneThreshold = 0.1

# Filter out genes with too few reads
X = X[:, X.X.mean(axis=0) > geneThreshold] # type: ignore

# Normalize read depth
sc.pp.normalize_total(X, exclude_highly_expressed=False, inplace=True)

# Transform values
X.X = np.log10((1000.0 * X.X) + 1.0) # type: ignore

X.X = scale(X.X)
sc.pp.pca(X, n_comps=n_comp, svd_solver="arpack")
sc.pp.pca(X, n_comps=n_comp, svd_solver="randomized")

return X
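For reference, a minimal sketch of how the updated import is exercised; the first call mirrors figure4.py, while the second is an assumption about typical default use:

from pf2barcode.imports import import_CCLE

# Depth-normalized counts only, keeping cells with at least 10,000 counts (as in figure4.py)
X_raw = import_CCLE(min_cell_count=10000, pca_option="none")

# Default "pca" path: gene filtering, log10 transform, scaling, and PCA with the randomized solver
X_pca = import_CCLE(pca_option="pca", n_comp=10)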