@@ -188,7 +188,7 @@ def train_dataset(
188
188
"batch_size" : batch_size ,
189
189
}
190
190
)
191
-
191
+
192
192
adata , trained_model , posterior_samples = train_model (
193
193
adata = adata ,
194
194
guide_type = guide_type ,
@@ -279,6 +279,7 @@ def check_shared_time(posterior_samples, adata):
279
279
@beartype
280
280
def train_model (
281
281
adata : str | AnnData ,
282
+ atac_layer : Optional [str ] = None ,
282
283
guide_type : str = "auto" ,
283
284
model_type : str = "auto" ,
284
285
batch_size : int = - 1 ,
@@ -305,6 +306,7 @@ def train_model(
305
306
306
307
Args:
307
308
adata (str | AnnData): Path to a file that can be read to an AnnData object or an AnnData object.
309
+ atac_layer (Optional[str], optional): Name of AnnData layer that contains atac data, if present.
308
310
guide_type (str, optional): The type of guide function for the Pyro model. Default is "auto".
309
311
model_type (str, optional): The type of Pyro model. Default is "auto".
310
312
batch_size (int, optional): Batch size for training. Default is -1, which indicates using the full dataset.
@@ -347,6 +349,9 @@ def train_model(
347
349
>>> copy_raw_counts(adata)
348
350
>>> _, model, posterior_samples = train_model(adata, use_gpu="auto", seed=99, max_epochs=200, loss_plot_path=loss_plot_path)
349
351
"""
352
+
353
+ if atac_layer :
354
+ logger .info ("Multiome model not yet implemented. Proceeding without atac data." )
350
355
if isinstance (adata , str ):
351
356
adata = load_anndata_from_path (adata )
352
357
0 commit comments