Add GrassiaIIGeometric Distribution #528
Open: ColtAllen wants to merge 40 commits into pymc-devs:main from ColtAllen:grassia2geo-dist
Changes from 33 commits
Commits (40, all by ColtAllen):

269dd75  dist and rv init commit
b264161  Merge branch 'pymc-devs:main' into grassia2geo-dist
d734c68  docstrings
71bd632  Merge branch 'grassia2geo-dist' of https://github.com/ColtAllen/pymc-…
48e93f3  Merge branch 'pymc-devs:main' into grassia2geo-dist
93c4a60  unit tests
d2e72b5  alpha min value
8685005  revert alpha lim
026f182  small lam value tests
d12dd0b  ruff formatting
bcd9cac  TODOs
78be107  WIP add covar support to RV
f3ae359  Merge branch 'main' into grassia2geo-dist
8a30459  WIP time indexing
7c7afc8  WIP time indexing
fa9c1ec  Merge branch 'grassia2geo-dist' of https://github.com/ColtAllen/pymc-…
b957333  WIP symbolic indexing
d0c1d98  delete test_simple.py
264c55e  fix symbolic indexing errors
05e7c55  Merge branch 'pymc-devs:main' into grassia2geo-dist
0fa3390  clean up cursor code
5baa6f7  warn for ndims deprecation
a715ec7  clean up comments and final TODO
f3c0f29  remove ndims deprecation and extraneous code
a232e4c  revert changes to irrelevant test
ffc059f  remove time_covariate_vector default args
1d41eb7  revert remaining changes in irrelevant tests
47ad523  remove test_sampling_consistency
5b77263  checkpoint commit for log_cdf and test frameworks
eb7222f  checkpoint commit for log_cdf and test frameworks
b34e3d8  make C_t external function, code cleanup
9803321  rng_fn cleanup
5ff6853  WIP test frameworks
63a0b10  inverse cdf
932a046  covariate pos constraint and WIP RV
b78a5c4  Merge branch 'pymc-devs:main' into grassia2geo-dist
ab45a9c  WIP rng_fn testing
0d1dcea  WIP time covars required param
434e5a5  C_t for RV time covar support
86085ce  Merge branch 'pymc-devs:main' into grassia2geo-dist
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pymc as pm
import pytensor
@@ -26,13 +27,15 @@
    Rplus,
    assert_support_point_is_expected,
    check_logp,
    check_selfconsistency_discrete_logcdf,
    discrete_random_tester,
)
from pytensor import config

from pymc_extras.distributions import (
    BetaNegativeBinomial,
    GeneralizedPoisson,
    GrassiaIIGeometric,
    Skellam,
)
@@ -208,3 +211,161 @@ def test_logp(self):
            {"mu1": Rplus_small, "mu2": Rplus_small},
            lambda value, mu1, mu2: scipy.stats.skellam.logpmf(value, mu1, mu2),
        )


class TestGrassiaIIGeometric:
    class TestRandomVariable(BaseTestDistributionRandom):
        pymc_dist = GrassiaIIGeometric
        pymc_dist_params = {"r": 0.5, "alpha": 2.0, "time_covariate_vector": 0.0}
        expected_rv_op_params = {"r": 0.5, "alpha": 2.0, "time_covariate_vector": 0.0}
        tests_to_run = [
            "check_pymc_params_match_rv_op",
            "check_rv_size",
        ]

        def test_random_basic_properties(self):
            """Test basic random sampling properties."""
            # Test with standard parameter values
            r_vals = [0.5, 1.0, 2.0]
            alpha_vals = [0.5, 1.0, 2.0]
            time_cov_vals = [-1.0, 1.0, 2.0]

            for r in r_vals:
                for alpha in alpha_vals:
                    for time_cov in time_cov_vals:
                        dist = self.pymc_dist.dist(
                            r=r, alpha=alpha, time_covariate_vector=time_cov, size=1000
                        )
                        draws = dist.eval()

                        # Check basic properties
                        assert np.all(draws > 0)
                        assert np.all(draws.astype(int) == draws)
                        assert np.mean(draws) > 0
                        assert np.var(draws) > 0

        def test_random_edge_cases(self):
            """Test edge cases with more reasonable parameter values."""
            # Test with small r and large alpha values
            r_vals = [0.1, 0.5]
            alpha_vals = [5.0, 10.0]
            time_cov_vals = [0.0, 1.0]

            for r in r_vals:
                for alpha in alpha_vals:
                    for time_cov in time_cov_vals:
                        dist = self.pymc_dist.dist(
                            r=r, alpha=alpha, time_covariate_vector=time_cov, size=1000
                        )
                        draws = dist.eval()

                        # Check basic properties
                        assert np.all(draws > 0)
                        assert np.all(draws.astype(int) == draws)
                        assert np.mean(draws) > 0
                        assert np.var(draws) > 0

        def test_random_none_covariates(self):
            """Test random sampling with None time_covariate_vector."""
            r_vals = [0.5, 1.0, 2.0]
            alpha_vals = [0.5, 1.0, 2.0]

            for r in r_vals:
                for alpha in alpha_vals:
                    dist = self.pymc_dist.dist(
                        r=r,
                        alpha=alpha,
                        time_covariate_vector=0.0,  # Changed from None to avoid zip issues
                        size=1000,
                    )
                    draws = dist.eval()

                    # Check basic properties
                    assert np.all(draws > 0)
                    assert np.all(draws.astype(int) == draws)
                    assert np.mean(draws) > 0
                    assert np.var(draws) > 0

        @pytest.mark.parametrize(
            "r,alpha,time_covariate_vector",
            [
                (0.5, 1.0, 0.0),
                (1.0, 2.0, 1.0),
                (2.0, 0.5, -1.0),
                (5.0, 1.0, 0.0),  # Changed from None to avoid zip issues
            ],
        )
        def test_random_moments(self, r, alpha, time_covariate_vector):
            dist = self.pymc_dist.dist(
                r=r, alpha=alpha, time_covariate_vector=time_covariate_vector, size=10_000
            )
            draws = dist.eval()

            assert np.all(draws > 0)
            assert np.all(draws.astype(int) == draws)
            assert np.mean(draws) > 0
            assert np.var(draws) > 0

    def test_logp_basic(self):
        # Create PyTensor variables with explicit values to ensure proper initialization
        r = pt.as_tensor_variable(1.0)
        alpha = pt.as_tensor_variable(2.0)
        time_covariate_vector = pt.as_tensor_variable(0.5)
        value = pt.vector("value", dtype="int64")

        # Create the distribution with the PyTensor variables
        dist = GrassiaIIGeometric.dist(r, alpha, time_covariate_vector)
        logp = pm.logp(dist, value)
        logp_fn = pytensor.function([value], logp)

        # Test basic properties of logp
        test_value = np.array([1, 2, 3, 4, 5])

        logp_vals = logp_fn(test_value)
        assert not np.any(np.isnan(logp_vals))
        assert np.all(np.isfinite(logp_vals))

        # Test invalid values
        assert logp_fn(np.array([0])) == -np.inf  # Value must be > 0

        with pytest.raises(TypeError):
            logp_fn(np.array([1.5]))  # Value must be an integer

    def test_logcdf(self):
        # Test that logcdf is self-consistent with summed pmf across parameter values
        check_selfconsistency_discrete_logcdf(
            GrassiaIIGeometric, I, {"r": Rplus, "alpha": Rplus, "time_covariate_vector": I}
        )

    @pytest.mark.parametrize(
        "r, alpha, time_covariate_vector, size, expected_shape",
        [
            (1.0, 1.0, 0.0, None, ()),  # Scalar output with no covariates (0.0 instead of None)
            ([1.0, 2.0], 1.0, 0.0, None, (2,)),  # Vector output from r
            (1.0, [1.0, 2.0], 0.0, None, (2,)),  # Vector output from alpha
            (1.0, 1.0, [1.0, 2.0], None, (2,)),  # Vector output from time covariates
            (1.0, 1.0, 1.0, (3, 2), (3, 2)),  # Explicit size with scalar time covariates
        ],
    )
    def test_support_point(self, r, alpha, time_covariate_vector, size, expected_shape):
        """Test that support_point returns reasonable values with correct shapes."""
        with pm.Model() as model:
            GrassiaIIGeometric(
                "x", r=r, alpha=alpha, time_covariate_vector=time_covariate_vector, size=size
            )

        init_point = model.initial_point()["x"]

        # Check shape
        assert init_point.shape == expected_shape

        # Check values are positive integers
        assert np.all(init_point > 0)
        assert np.all(init_point.astype(int) == init_point)

        # Check values are finite and reasonable
        assert np.all(np.isfinite(init_point))
        assert np.all(init_point < 1e6)  # Should not be extremely large

        # TODO: expected values must be provided
        # assert_support_point_is_expected(model, init_point)
Reviewer: I don't understand this. Why do you take the mean? What is this distribution supposed to do with a time_covariate_vector in the theoretical sense?
Reviewer: Also, if it's a vector, your signature is wrong: it should be `(),(),(a)->()` or perhaps `(),(),(a)->(a)`; it's no longer a univariate distribution.
Reviewer: Also, you should try to make your RV/logp work with batch dims, so you should be doing things like `mean(axis=-1)` and indexing with `[..., safe_idx]`, or explicitly raise NotImplementedError if you don't want to support batch dims. By default we assume implementations work with batch parameters.
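The batch-dims convention discussed here (reduce over `axis=-1`, index with an ellipsis) can be sketched in plain NumPy; the array and index below are hypothetical, not the PR's actual parameters:

```python
import numpy as np

# Hypothetical batched parameter array: the leading axes (2, 3) are batch
# dims, the trailing axis (length 4) is the core covariate axis.
covars = np.arange(24.0).reshape(2, 3, 4)

# Reduce over the core axis only, so arbitrary batch dims pass through:
core_sum = covars.sum(axis=-1)   # shape (2, 3): one value per batch element

# Index the core axis via an ellipsis, leaving batch dims untouched:
safe_idx = 2
picked = covars[..., safe_idx]   # shape (2, 3)
```

Hard-coding `axis=0` or `covars[safe_idx]` instead would silently operate on a batch dimension as soon as the parameters are batched.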
ColtAllen: For the PMF and survival function, the covariate vector is exponentiated and summed over all active time periods. However, the research note implies geometric samples are drawn for each time period. If batch parameters are supported by default, perhaps there's nothing to worry about.

I'll have to investigate this more; changing the signature breaks most of the tests. Also, doesn't `pymc.Censored` still only support univariate distributions?
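For reference, the construction being described appears to be the Grassia(II)-Geometric of Fader and Hardie's research note; this is an assumption about the intended model, not taken from the diff:

```latex
S(t \mid r, \alpha) = \left(\frac{\alpha}{\alpha + C(t)}\right)^{r},
\qquad
C(t) = \sum_{s=1}^{t} e^{x_s},
\qquad
P(T = t \mid r, \alpha) = S(t-1) - S(t),
```

where $\lambda \sim \mathrm{Gamma}(r, \alpha)$ is the subject-level rate and $x_s$ are the time covariates ($C(t) = t$ when there are none).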
Reviewer: What is the RV actually representing in plain English? What would be the meaning of each geometric per covariate? Would their sum (or some other aggregation) make sense?

If it's `(),(),(a)->()` it's still univariate; it just has a vector parameter in the core case (like Categorical, which takes a vector of p and returns an integer). Still, we can relax the constraint on Censored; we just never did because there was no multivariate distribution with a cdf, so it didn't make sense to bother back then.
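The `(),(),(a)->()` case, a scalar output whose core computation consumes a vector parameter, can be illustrated with NumPy's gufunc-style `np.vectorize`; the `core` function here is hypothetical, not the actual RV:

```python
import numpy as np

def core(r, alpha, covar):
    # Core case: scalar r, scalar alpha, vector covar -> scalar output.
    return r + alpha * np.exp(covar).sum()

# "(),(),(a)->()" marks the last axis of covar as a core dim; everything
# else broadcasts as batch dims, so the distribution stays univariate.
gufunc = np.vectorize(core, signature="(),(),(a)->()")

# Batch dim 3 on r, scalar alpha, covar with batch 3 and core dim a=4:
out = gufunc(np.ones(3), 2.0, np.zeros((3, 4)))  # -> array([9., 9., 9.])
```

Each core call computes 1 + 2 * 4 = 9; the output shape (3,) comes purely from the batch dims, with the core axis consumed.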
Reviewer: Regarding tests breaking: yes, the signature changes the meaning of some things, so we need to restructure it a bit. First we should clarify what the RV is supposed to do.
ColtAllen: The time period at which an event occurs. `p` is typically fixed for Geometric draws, but the covariate vector allows it to vary over time, similar to an Accelerated Failure Time model. Summing would make the most sense, but I'll have to think about the formulation to ensure 0 < p <= 1.
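A minimal sketch of that idea, assuming the standard gamma-mixed geometric construction with a constant covariate; `g2g_draws` and its arguments are hypothetical and are not the PR's actual `rng_fn`:

```python
import numpy as np

def g2g_draws(r, alpha, x=0.0, size=1000, rng=None):
    # Subject-level rate: lambda ~ Gamma(shape=r, rate=alpha).
    rng = np.random.default_rng(0) if rng is None else rng
    lam = rng.gamma(shape=r, scale=1.0 / alpha, size=size)
    # AFT-style covariate scales the hazard; p = 1 - exp(-lam * exp(x))
    # lies in (0, 1] for any lam > 0, so the constraint holds by design.
    # expm1 keeps p > 0 even for tiny lam.
    p = -np.expm1(-lam * np.exp(x))
    return rng.geometric(p)

draws = g2g_draws(0.5, 2.0, x=0.0)  # positive integer waiting times
```

This matches the properties the tests assert: strictly positive, integer-valued draws with positive mean and variance.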
Reviewer: So if you have a covariate vector, you also observe multiple times, and in your logp you have a value vector that's as long as the covariate vector (for a single subject)? The thing that makes this a univariate distribution is that a subject has a constant lambda over all these events?

In that case it sounds like you have `(),(),(a)->(a)` indeed. You just have to adjust, because then `a` should always be at least a vector (even if there's only one constant 0), and `size` doesn't include `a`; it's what's batched on top of that, as in the mvnormal, where size doesn't include the shape implied by the last dimensions of mu or cov.
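The mvnormal size semantics referenced in this thread, shown with NumPy's generator API:

```python
import numpy as np

rng = np.random.default_rng(0)
mu = np.zeros(2)   # parameters imply a core dim of length 2
cov = np.eye(2)

# size=3 batches *on top of* the core dim: 3 draws, each of length 2.
draws = rng.multivariate_normal(mu, cov, size=3)  # shape (3, 2)
```

The requested `size` never mentions the core dimension; the same rule would apply to an `(),(),(a)->(a)` GrassiaIIGeometric, where `size` batches on top of the covariate length.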
ColtAllen: It's a value scalar: `value == len(covariate_vector)`. Covariates are actually optional for this distribution.

Yes.

This might be true in the case of RV draws. Is `a == len(covariate_vector)`?
Reviewer: You can't have a scalar and a multivariate distribution with the same signature, so you could consider `len(covar) == 1` in the scalar case and set it to zero like you already do. Otherwise you could have two instances of the RV with two signatures, but I'm not sure it's worth the trouble.

Yes, `a` in the signature is `covar`. You can give it whatever name you want in the signature; we usually keep it short or one letter, but it doesn't really matter. `a` isn't great, I guess.