Skip to content

Commit bef8f62

Browse files
feat(_transcription_dynamics): Added function for multiome dynamics.
Signed-off-by: Alexander Aivazidis <alexander.aivazidis@sanger.ac.uk>
1 parent 6809aec commit bef8f62

File tree

1 file changed

+67
-0
lines changed

1 file changed

+67
-0
lines changed

src/pyrovelocity/models/_transcription_dynamics.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,73 @@ def mrna_dynamics(
6161
return ut, st
6262

6363

64+
@beartype
65+
def atac_mrna_dynamics(
66+
tau_c: Tensor,
67+
tau: Tensor,
68+
c0: Tensor,
69+
u0: Tensor,
70+
s0: Tensor,
71+
alpha_c: Tensor,
72+
alpha: Tensor,
73+
beta: Tensor,
74+
gamma: Tensor,
75+
) -> Tuple[Tensor, Tensor]:
76+
"""
77+
Computes the ATAC and mRNA dynamics given temporal coordinate, parameter values, and
78+
initial conditions.
79+
80+
`st_gamma_equals_beta` for the case where the gamma parameter is equal
81+
to the beta parameter is taken from Equation 2.12 of
82+
83+
Args:
84+
tau (Tensor): Time points starting at last change in RNA transcription rate.
85+
tau_c (Tensor): Time points starting at last change in chromatin opening/closing rate.
86+
c0 (Tensor): Initial value of c.
87+
u0 (Tensor): Initial value of u.
88+
s0 (Tensor): Initial value of s.
89+
alpha_c (Tensor): Rate of chromatin opening/closing.
90+
alpha (Tensor): Alpha parameter.
91+
beta (Tensor): Beta parameter.
92+
gamma (Tensor): Gamma parameter.
93+
94+
Returns:
95+
Tuple[Tensor, Tensor]: Tuple containing the final values of c, u and s.
96+
97+
Examples:
98+
>>> import torch
99+
>>> tau = torch.tensor(2.0)
100+
>>> tau_c = torch.tensor(2.0)
101+
>>> c0 = torch.tensor(1.0)
102+
>>> u0 = torch.tensor(1.0)
103+
>>> s0 = torch.tensor(0.5)
104+
>>> alpha_c = torch.tensor(0.45)
105+
>>> alpha = torch.tensor(0.5)
106+
>>> beta = torch.tensor(0.4)
107+
>>> gamma = torch.tensor(0.3)
108+
>>> mrna_dynamics(tau_c, tau, c0, u0, s0, alpha_c, alpha, beta, gamma)
109+
(tensor(1.1377), tensor(0.9269))
110+
"""
111+
112+
A = torch.exp(-alpha_c * tau_c)
113+
B = torch.exp(-beta * tau)
114+
C = torch.exp(-gamma * tau)
115+
116+
ct = c0 * A + k_c * (1 - A)
117+
ut = (
118+
u0 * B
119+
+ alpha * k_c / beta * (1 - B)
120+
+ (k_c - c0) * alpha / (beta - alpha_c) * (B - A)
121+
)
122+
st = s0 * C + alpha * k_c / gamma * (1 - C)
123+
+beta / (gamma - beta) * (
124+
(alpha * k_c) / beta - u0 - (k_c - c0) * alpha / (beta - alpha_c)
125+
) * (C - B)
126+
+beta / (gamma - alpha_c) * (k_c - c0) * alpha / (beta - alpha_c) * (C - A)
127+
128+
return ct, ut, st
129+
130+
64131
@beartype
65132
def inv(x: Tensor) -> Tensor:
66133
"""

0 commit comments

Comments
 (0)