@@ -61,6 +61,73 @@ def mrna_dynamics(
61
61
return ut , st
62
62
63
63
64
+ @beartype
65
+ def atac_mrna_dynamics (
66
+ tau_c : Tensor ,
67
+ tau : Tensor ,
68
+ c0 : Tensor ,
69
+ u0 : Tensor ,
70
+ s0 : Tensor ,
71
+ alpha_c : Tensor ,
72
+ alpha : Tensor ,
73
+ beta : Tensor ,
74
+ gamma : Tensor ,
75
+ ) -> Tuple [Tensor , Tensor ]:
76
+ """
77
+ Computes the ATAC and mRNA dynamics given temporal coordinate, parameter values, and
78
+ initial conditions.
79
+
80
+ `st_gamma_equals_beta` for the case where the gamma parameter is equal
81
+ to the beta parameter is taken from Equation 2.12 of
82
+
83
+ Args:
84
+ tau (Tensor): Time points starting at last change in RNA transcription rate.
85
+ tau_c (Tensor): Time points starting at last change in chromatin opening/closing rate.
86
+ c0 (Tensor): Initial value of c.
87
+ u0 (Tensor): Initial value of u.
88
+ s0 (Tensor): Initial value of s.
89
+ alpha_c (Tensor): Rate of chromatin opening/closing.
90
+ alpha (Tensor): Alpha parameter.
91
+ beta (Tensor): Beta parameter.
92
+ gamma (Tensor): Gamma parameter.
93
+
94
+ Returns:
95
+ Tuple[Tensor, Tensor]: Tuple containing the final values of c, u and s.
96
+
97
+ Examples:
98
+ >>> import torch
99
+ >>> tau = torch.tensor(2.0)
100
+ >>> tau_c = torch.tensor(2.0)
101
+ >>> c0 = torch.tensor(1.0)
102
+ >>> u0 = torch.tensor(1.0)
103
+ >>> s0 = torch.tensor(0.5)
104
+ >>> alpha_c = torch.tensor(0.45)
105
+ >>> alpha = torch.tensor(0.5)
106
+ >>> beta = torch.tensor(0.4)
107
+ >>> gamma = torch.tensor(0.3)
108
+ >>> mrna_dynamics(tau_c, tau, c0, u0, s0, alpha_c, alpha, beta, gamma)
109
+ (tensor(1.1377), tensor(0.9269))
110
+ """
111
+
112
+ A = torch .exp (- alpha_c * tau_c )
113
+ B = torch .exp (- beta * tau )
114
+ C = torch .exp (- gamma * tau )
115
+
116
+ ct = c0 * A + k_c * (1 - A )
117
+ ut = (
118
+ u0 * B
119
+ + alpha * k_c / beta * (1 - B )
120
+ + (k_c - c0 ) * alpha / (beta - alpha_c ) * (B - A )
121
+ )
122
+ st = s0 * C + alpha * k_c / gamma * (1 - C )
123
+ + beta / (gamma - beta ) * (
124
+ (alpha * k_c ) / beta - u0 - (k_c - c0 ) * alpha / (beta - alpha_c )
125
+ ) * (C - B )
126
+ + beta / (gamma - alpha_c ) * (k_c - c0 ) * alpha / (beta - alpha_c ) * (C - A )
127
+
128
+ return ct , ut , st
129
+
130
+
64
131
@beartype
65
132
def inv (x : Tensor ) -> Tensor :
66
133
"""
0 commit comments