Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 186 additions & 0 deletions pf2rnaseq/factorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import scanpy as sc
import scipy.sparse as sps
from pacmap import PaCMAP
from parafac2.normalize import prepare_dataset
from parafac2.parafac2 import parafac2_nd, store_pf2
from scipy.stats import gmean
from sklearn.decomposition import PCA
Expand Down Expand Up @@ -190,3 +191,188 @@ def fms_diff_ranks(
)

return df


def downsample_counts_multinomial(
X: anndata.AnnData,
percent_drop: float,
random_state: int = 0,
) -> anndata.AnnData:
"""
Create a downsampled counts copy of AnnData using multinomial sampling.

Parameters:
-----------
X : anndata.AnnData
Input dataset
percent_drop : float
Percentage of counts to drop (0-100)
random_state : int
Random seed for reproducibility

Returns:
--------
anndata.AnnData
Downsampled copy of the input data
"""
import scipy.sparse as sp

# Handle 0% drop case
if percent_drop == 0:
return X.copy()

# Set random seed
np.random.seed(random_state)

# Convert to CSR and extract structure
original_csr = X.X.tocsr()
data = original_csr.data.copy()
indices = original_csr.indices
indptr = original_csr.indptr

# Process each cell
for cell_idx in range(X.n_obs):
start_idx = indptr[cell_idx]
end_idx = indptr[cell_idx + 1]

if start_idx == end_idx:
continue

cell_data = data[start_idx:end_idx]
total_counts = int(np.sum(cell_data))

if total_counts == 0:
continue

new_total = int(total_counts * (1 - percent_drop / 100))
if new_total == 0:
data[start_idx:end_idx] = 0
continue

# Convert to probabilities and normalize
probs = cell_data / total_counts
probs = probs / np.sum(probs) # Ensure sum = 1.0

# Multinomial sampling
new_counts = np.random.multinomial(new_total, probs)
data[start_idx:end_idx] = new_counts.astype(cell_data.dtype)

# Create new sparse matrix
sampled_csr = sp.csr_matrix((data, indices, indptr), shape=original_csr.shape)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ohh preserving the sparsity is clever.


# Create new AnnData object
sampled_data = X.copy()
sampled_data.X = sampled_csr

return sampled_data


def calculate_fms_downsample(
X: anndata.AnnData,
X_pf2: anndata.AnnData,
percent_drop: float,
rank: int = 30,
deviance: bool = False,
condition: str = "Condition",
random_state: int = 0,
) -> float:
"""
Calculate FMS for a single downsampling scenario.

Parameters:
-----------
X : anndata.AnnData
Original dataset for reference
X_pf2 : anndata.AnnData
Full factorized dataset
percent_drop : float
Percentage of counts to drop (0-100)
rank : int
Factorization rank
deviance : bool
Whether to use deviance normalization
condition : str
Condition column name
random_state : int
Random seed

Returns:
--------
float
FMS score
"""

# Handle 0% drop case
if percent_drop == 0:
return 1.0

# Create downsampled data
sampled_data = downsample_counts_multinomial(
X, percent_drop, random_state=random_state
)

# Apply same processing as reference
sampled_data = prepare_dataset(
sampled_data, condition, geneThreshold=0.0, deviance=deviance
)

# Factorization
sampledX = pf2(sampled_data, rank, random_state=random_state + 2, doEmbedding=False)

return calculateFMS(X_pf2, sampledX)


def fms_percent_drop_counts(
X: anndata.AnnData,
percentList: np.ndarray,
rank: int = 30,
deviance: bool = False,
condition: str = "Condition",
geneThreshold: float = 0.0,
random_state: int = 0,
) -> pd.DataFrame:
"""
Calculate FMS for multiple downsampling percentages (single run).

Parameters:
-----------
X : anndata.AnnData
Input dataset
percentList : np.ndarray
Array of dropout percentages to test
rank : int
Factorization rank
deviance : bool
Whether to use deviance normalization
condition : str
Condition column name
geneThreshold : float
Gene threshold for preparation
random_state : int
Random seed

Returns:
--------
pd.DataFrame
DataFrame with columns: Percentage of Counts Dropped, FMS
"""
results = []
X_prepared = prepare_dataset(
X, condition, geneThreshold=geneThreshold, deviance=deviance
)
X_pf2 = pf2(X_prepared, rank, doEmbedding=False)

for percent_drop in percentList:
fms_score = calculate_fms_downsample(
X=X,
X_pf2=X_pf2,
percent_drop=percent_drop,
rank=rank,
deviance=deviance,
condition=condition,
random_state=random_state,
)

results.append({"Percentage of Counts Dropped": percent_drop, "FMS": fms_score})

return pd.DataFrame(results)
22 changes: 21 additions & 1 deletion pf2rnaseq/figures/commonFuncs/plotGeneral.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
import seaborn as sns
from matplotlib.axes import Axes

from ...factorization import fms_percent_drop, pf2_pca_r2x, fms_diff_ranks
from ...factorization import (
fms_diff_ranks,
fms_percent_drop,
fms_percent_drop_counts,
pf2_pca_r2x,
)


def plot_r2x(data, rank_vec, ax: Axes):
Expand Down Expand Up @@ -460,3 +465,18 @@ def plot_fms_percent_drop(
df = fms_percent_drop(X, percentList, runs, rank)
sns.lineplot(data=df, x="Percentage of Data Dropped", y="FMS", ax=ax)
ax.set_ylim(0, 1)


def plot_fms_percent_drop_counts(
X: anndata.AnnData,
ax: Axes,
percentList: np.ndarray,
rank: int = 30,
deviance: bool = False,
label: str = None,
):
"""Plots FMS when dropping different percentages of data"""
df = fms_percent_drop_counts(X, percentList, rank, deviance=deviance)
sns.lineplot(data=df, x="Percentage of Counts Dropped", y="FMS", ax=ax, label=label)
ax.set_ylim(0, 1)

31 changes: 31 additions & 0 deletions pf2rnaseq/figures/figureCountFMS.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""
factorization score

"""

from anndata import read_h5ad

from .common import getSetup, subplotLabel
from .commonFuncs.plotGeneral import plot_fms_percent_drop_counts


def makeFigure():
ax, f = getSetup((6, 3), (1, 1))
subplotLabel(ax)
# Using our cytokine dataset
X = read_h5ad("/opt/extra-storage/Treg_h5ads/Treg_raw.h5ad")

# Remove multiplexing identifiers
X = X[:, ~X.var_names.str.match("^CMO3[0-9]{2}$")] # type: ignore
# Remove genes with too few reads now
X = X[X.X.sum(axis=1) > 10, X.X.mean(axis=0) > 0.1]
X = X.copy()
percentList = [0.0, 30.0, 50.0]
plot_fms_percent_drop_counts(
X, ax[0], percentList, rank=15, deviance=True, label="Deviance"
)
plot_fms_percent_drop_counts(
X, ax[0], percentList, rank=15, deviance=False, label="CPM"
)

return f
Loading