Skip to content

Commit f34196b

Browse files
added atac_layer argument to train_model tasks and made tests for it
1 parent 4429882 commit f34196b

File tree

3 files changed

+19
-1
lines changed

3 files changed

+19
-1
lines changed

src/pyrovelocity/tasks/train.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def train_dataset(
188188
"batch_size": batch_size,
189189
}
190190
)
191-
191+
192192
adata, trained_model, posterior_samples = train_model(
193193
adata=adata,
194194
guide_type=guide_type,
@@ -279,6 +279,7 @@ def check_shared_time(posterior_samples, adata):
279279
@beartype
280280
def train_model(
281281
adata: str | AnnData,
282+
atac_layer: Optional[str] = None,
282283
guide_type: str = "auto",
283284
model_type: str = "auto",
284285
batch_size: int = -1,
@@ -305,6 +306,7 @@ def train_model(
305306
306307
Args:
307308
adata (str | AnnData): Path to a file that can be read to an AnnData object or an AnnData object.
309+
atac_layer (Optional[str], optional): Name of AnnData layer that contains atac data, if present.
308310
guide_type (str, optional): The type of guide function for the Pyro model. Default is "auto".
309311
model_type (str, optional): The type of Pyro model. Default is "auto".
310312
batch_size (int, optional): Batch size for training. Default is -1, which indicates using the full dataset.
@@ -347,6 +349,9 @@ def train_model(
347349
>>> copy_raw_counts(adata)
348350
>>> _, model, posterior_samples = train_model(adata, use_gpu="auto", seed=99, max_epochs=200, loss_plot_path=loss_plot_path)
349351
"""
352+
353+
if atac_layer:
354+
logger.info("Multiome model not yet implemented. Proceeding without atac data.")
350355
if isinstance(adata, str):
351356
adata = load_anndata_from_path(adata)
352357

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Unit test package for pyrovelocity."""
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
"""Tests for `pyrovelocity._train_model` task."""
2+
3+
from pyrovelocity.tasks.train import train_model
4+
from pyrovelocity.utils import generate_sample_data
5+
from pyrovelocity.tasks.preprocess import copy_raw_counts
6+
def test_train_model(tmp_path):
7+
loss_plot_path = str(tmp_path) + "/loss_plot_docs.png"
8+
print(loss_plot_path)
9+
adata = generate_sample_data(random_seed=99)
10+
copy_raw_counts(adata)
11+
_, model, posterior_samples = train_model(adata, atac_layer = 'atac',
12+
use_gpu="auto", seed=99, max_epochs=200, loss_plot_path=loss_plot_path)

0 commit comments

Comments
 (0)