Add GrassiaIIGeometric Distribution #528

Open: ColtAllen wants to merge 38 commits into pymc-devs:main from ColtAllen:grassia2geo-dist
Changes from 23 commits

Commits (38):
269dd75  dist and rv init commit (ColtAllen, Mar 29, 2025)
b264161  Merge branch 'pymc-devs:main' into grassia2geo-dist (ColtAllen, Apr 11, 2025)
d734c68  docstrings (ColtAllen, Apr 15, 2025)
71bd632  Merge branch 'grassia2geo-dist' of https://github.com/ColtAllen/pymc-… (ColtAllen, Apr 15, 2025)
48e93f3  Merge branch 'pymc-devs:main' into grassia2geo-dist (ColtAllen, Apr 15, 2025)
93c4a60  unit tests (ColtAllen, Apr 20, 2025)
d2e72b5  alpha min value (ColtAllen, Apr 20, 2025)
8685005  revert alpha lim (ColtAllen, Apr 21, 2025)
026f182  small lam value tests (ColtAllen, Apr 22, 2025)
d12dd0b  ruff formatting (ColtAllen, Apr 22, 2025)
bcd9cac  TODOs (ColtAllen, Apr 22, 2025)
78be107  WIP add covar support to RV (ColtAllen, Apr 22, 2025)
f3ae359  Merge branch 'main' into grassia2geo-dist (ColtAllen, Jun 20, 2025)
8a30459  WIP time indexing (ColtAllen, Jun 20, 2025)
7c7afc8  WIP time indexing (ColtAllen, Jun 20, 2025)
fa9c1ec  Merge branch 'grassia2geo-dist' of https://github.com/ColtAllen/pymc-… (ColtAllen, Jun 20, 2025)
b957333  WIP symbolic indexing (ColtAllen, Jun 20, 2025)
d0c1d98  delete test_simple.py (ColtAllen, Jun 20, 2025)
264c55e  fix symbolic indexing errors (ColtAllen, Jul 11, 2025)
05e7c55  Merge branch 'pymc-devs:main' into grassia2geo-dist (ColtAllen, Jul 11, 2025)
0fa3390  clean up cursor code (ColtAllen, Jul 11, 2025)
5baa6f7  warn for ndims deprecation (ColtAllen, Jul 11, 2025)
a715ec7  clean up comments and final TODO (ColtAllen, Jul 11, 2025)
f3c0f29  remove ndims deprecation and extraneous code (ColtAllen, Jul 11, 2025)
a232e4c  revert changes to irrelevant test (ColtAllen, Jul 12, 2025)
ffc059f  remove time_covariate_vector default args (ColtAllen, Jul 12, 2025)
1d41eb7  revert remaining changes in irrelevant tests (ColtAllen, Jul 12, 2025)
47ad523  remove test_sampling_consistency (ColtAllen, Jul 12, 2025)
5b77263  checkpoint commit for log_cdf and test frameworks (ColtAllen, Jul 12, 2025)
eb7222f  checkpoint commit for log_cdf and test frameworks (ColtAllen, Jul 12, 2025)
b34e3d8  make C_t external function, code cleanup (ColtAllen, Jul 12, 2025)
9803321  rng_fn cleanup (ColtAllen, Jul 13, 2025)
5ff6853  WIP test frameworks (ColtAllen, Jul 13, 2025)
63a0b10  inverse cdf (ColtAllen, Jul 15, 2025)
932a046  covariate pos constraint and WIP RV (ColtAllen, Jul 15, 2025)
b78a5c4  Merge branch 'pymc-devs:main' into grassia2geo-dist (ColtAllen, Jul 28, 2025)
ab45a9c  WIP rng_fn testing (ColtAllen, Jul 28, 2025)
0d1dcea  WIP time covars required param (ColtAllen, Jul 29, 2025)
2 changes: 2 additions & 0 deletions pymc_extras/distributions/__init__.py
@@ -22,6 +22,7 @@
    BetaNegativeBinomial,
    GeneralizedPoisson,
    Skellam,
    GrassiaIIGeometric,
)
from pymc_extras.distributions.histogram_utils import histogram_approximation
from pymc_extras.distributions.multivariate import R2D2M2CP
@@ -38,5 +39,6 @@
"R2D2M2CP",
"Skellam",
"histogram_approximation",
"GrassiaIIGeometric",
"PartialOrder",
]
223 changes: 223 additions & 0 deletions pymc_extras/distributions/discrete.py
@@ -12,14 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings

import numpy as np
import pymc as pm

from pymc.distributions.dist_math import betaln, check_parameters, factln, logpow
from pymc.distributions.distribution import Discrete
from pymc.distributions.shape_utils import rv_size_is_none
from pytensor import tensor as pt
from pytensor.tensor.random.op import RandomVariable

warnings.filterwarnings("ignore", category=FutureWarning, message="ndims_params is deprecated")


def log1mexp(x):
    cond = x < np.log(0.5)
@@ -399,3 +404,221 @@ def dist(cls, mu1, mu2, **kwargs):
class_name="Skellam",
**kwargs,
)


class GrassiaIIGeometricRV(RandomVariable):
    name = "g2g"
    signature = "(),(),()->()"
    ndims_params = [0, 0, 0]  # deprecated in PyTensor 2.31.7, but still required for RandomVariable

    dtype = "int64"
    _print_name = ("GrassiaIIGeometric", "\\operatorname{GrassiaIIGeometric}")

    @classmethod
    def rng_fn(cls, rng, r, alpha, time_covariate_vector, size):
        # Cast inputs as numpy arrays
        r = np.asarray(r, dtype=np.float64)
        alpha = np.asarray(alpha, dtype=np.float64)
        time_covariate_vector = np.asarray(time_covariate_vector, dtype=np.float64)

        # Determine output size
        if size is None:
            size = np.broadcast_shapes(r.shape, alpha.shape, time_covariate_vector.shape)

        # Broadcast parameters to output size
        r = np.broadcast_to(r, size)
        alpha = np.broadcast_to(alpha, size)
        time_covariate_vector = np.broadcast_to(time_covariate_vector, size)

        lam = rng.gamma(shape=r, scale=1 / alpha, size=size)

        # Calculate exp(time_covariate_vector) for all samples
        exp_time_covar = np.exp(time_covariate_vector)
        lam_covar = lam * exp_time_covar

        # TODO: This is not aggregated over time
        p = 1 - np.exp(-lam_covar)

        # Ensure p is in valid range for geometric distribution
        min_p = max(1e-6, np.finfo(float).tiny)  # Minimum probability to prevent infinite values
        p = np.clip(p, min_p, 1.0)
ricardoV94 (Member), Jul 11, 2025: Both clips are suspicious. Can't we compute draws in a more stable way? You could use log1mexp to get p on the log scale; is there a way to sample from a geometric with log_p instead of p?

ColtAllen (Author): Is it possible to sample from a custom PMF rather than a Geometric?

Member: You need an inverse CDF or a custom sampling algorithm. There's no generic way of sampling from a PMF (other than running some MCMC algorithm).

ColtAllen (Author): I can derive the inverse CDF (another task for the trusty whiteboard). Would adding this to rng_fn look something like rng.inverse_cdf?

Member: If you have a stable inverse CDF (or, even better, an inverse log_cdf), you could then take a uniform (or log-uniform) draw and pass it through the icdf to get a draw. I'm not saying that's the best route, just that those clips are very indicative of a poor random implementation.

ColtAllen (Author): I found a stable inverse CDF, which is probably the best path forward, because I no longer think taking draws from a geometric is viable with these time-indexed covariates.

Member: From what you explained to me, the geometric draws are what actually make the most sense. You should be able to do pm.logp(rv, pm.draw(rv)), so you should get the same size when you eval the logp as what the rv returns. I suppose your inverse CDF only returns one value?

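As an illustration of the inverse-CDF route discussed above (a sketch only, not part of this diff): since p = 1 - exp(-lam_covar), log(1 - p) is exactly -lam_covar, so geometric draws can be taken from uniform draws without ever forming or clipping p on the natural scale. The names rng, size, and lam_covar are assumed to be those already defined in rng_fn.

# Sketch: geometric inverse CDF, X = ceil(log(1 - U) / log(1 - p)),
# with log(1 - p) = -lam_covar substituted directly for stability.
u = rng.uniform(size=size)
samples = np.maximum(np.ceil(np.log1p(-u) / -lam_covar), 1).astype(np.int64)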

        samples = rng.geometric(p)

        # Clip samples to reasonable bounds to prevent infinite values
        max_sample = 10000  # Reasonable upper bound for discrete time-to-event data
Member: Can we make this an argument?

        samples = np.clip(samples, 1, max_sample)

        return samples


g2g = GrassiaIIGeometricRV()


class GrassiaIIGeometric(Discrete):
    r"""Grassia(II)-Geometric distribution.

    This distribution is a flexible alternative to the Geometric distribution for the number of trials until a
    discrete event, and can be extended to support both static and time-varying covariates.
    Fader and Hardie describe this distribution with the following PMF and survival function in [1]_:

    .. math::
        \begin{align}
        \mathbb{P}(T=t \mid r,\alpha,\beta;Z(t)) &= \left(\frac{\alpha}{\alpha+C(t-1)}\right)^{r} - \left(\frac{\alpha}{\alpha+C(t)}\right)^{r} \\
        \mathbb{S}(t \mid r,\alpha,\beta;Z(t)) &= \left(\frac{\alpha}{\alpha+C(t)}\right)^{r}
        \end{align}

    .. plot::
        :context: close-figs

        import matplotlib.pyplot as plt
        import numpy as np
        import scipy.stats as st
        import arviz as az
        plt.style.use('arviz-darkgrid')
        t = np.arange(1, 11)
        alpha_vals = [1., 1., 2., 2.]
        r_vals = [.1, .25, .5, 1.]
        for alpha, r in zip(alpha_vals, r_vals):
            pmf = (alpha/(alpha + t - 1))**r - (alpha/(alpha+t))**r
            plt.plot(t, pmf, '-o', label=r'$\alpha$ = {}, $r$ = {}'.format(alpha, r))
        plt.xlabel('t', fontsize=12)
        plt.ylabel('p(t)', fontsize=12)
        plt.legend(loc=1)
        plt.show()

    ========  ===============================================
    Support   :math:`t \in \mathbb{N}_{>0}`
    ========  ===============================================

    Parameters
    ----------
    r : tensor_like of float
        Shape parameter (r > 0).
    alpha : tensor_like of float
        Scale parameter (alpha > 0).
    time_covariate_vector : tensor_like of float, optional
        Optional vector containing dot products of time-varying covariates and coefficients.

    References
    ----------
    .. [1] Fader, Peter S. & Hardie, Bruce G. S. (2020).
       "Incorporating Time-Varying Covariates in a Simple Mixture Model for Discrete-Time Duration Data."
       https://www.brucehardie.com/notes/037/time-varying_covariates_in_BG.pdf
    """

    rv_op = g2g

    @classmethod
    def dist(cls, r, alpha, time_covariate_vector=None, *args, **kwargs):
        r = pt.as_tensor_variable(r)
        alpha = pt.as_tensor_variable(alpha)
        if time_covariate_vector is None:
            time_covariate_vector = pt.constant(0.0)
        time_covariate_vector = pt.as_tensor_variable(time_covariate_vector)
        return super().dist([r, alpha, time_covariate_vector], *args, **kwargs)

    def logp(value, r, alpha, time_covariate_vector=None):
        if time_covariate_vector is None:
Member: Seems like your logp doesn't handle ndim > 1, right? In that case, raise NotImplementedError if value.ndim > 1?

ColtAllen (Author): Would hierarchical models still be supported if this were the case?

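For reference, the guard suggested in the thread above would amount to a short check at the top of logp; the following is a hypothetical sketch, not code from this diff:

# Hypothetical guard per the review suggestion above
if value.ndim > 1:
    raise NotImplementedError("GrassiaIIGeometric.logp currently supports only scalar or 1-D values")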
            time_covariate_vector = pt.constant(0.0)
        time_covariate_vector = pt.as_tensor_variable(time_covariate_vector)

        def C_t(t):
            if t == 0:
                return pt.constant(0.0)
            if time_covariate_vector.ndim == 0:
                return t
            else:
                # Ensure t is a valid index
                t_idx = pt.maximum(0, t - 1)  # Convert to 0-based indexing
                # If t_idx exceeds length of time_covariate_vector, use last value
                max_idx = pt.shape(time_covariate_vector)[0] - 1
                safe_idx = pt.minimum(t_idx, max_idx)
                covariate_value = time_covariate_vector[safe_idx]
                return t * pt.exp(covariate_value)

        logp = pt.log(
            pt.pow(alpha / (alpha + C_t(value - 1)), r) - pt.pow(alpha / (alpha + C_t(value)), r)
        )

        # Handle invalid values
        logp = pt.switch(
            pt.or_(
                value < 1,  # Value must be >= 1
                pt.isnan(logp),  # Handle NaN cases
            ),
            -np.inf,
            logp,
        )

        return check_parameters(
            logp,
            r > 0,
            alpha > 0,
            msg="r > 0, alpha > 0",
        )

    def logcdf(value, r, alpha, time_covariate_vector=None):
        if time_covariate_vector is None:
            time_covariate_vector = pt.constant(0.0)
        time_covariate_vector = pt.as_tensor_variable(time_covariate_vector)

        def C_t(t):
Member: Can this be moved outside the function so that it can be reused by logp?

ColtAllen (Author), Jul 12, 2025: Not sure how kosher this is, but due to how instantiation is handled in logp and logcdf, I had to move C_t outside the distribution class altogether to get it to work.

            if t == 0:
                return pt.constant(0.0)
            if time_covariate_vector.ndim == 0:
                return t
            else:
                # Ensure t is a valid index
                t_idx = pt.maximum(0, t - 1)  # Convert to 0-based indexing
                # If t_idx exceeds length of time_covariate_vector, use last value
                max_idx = pt.shape(time_covariate_vector)[0] - 1
                safe_idx = pt.minimum(t_idx, max_idx)
                covariate_value = time_covariate_vector[safe_idx]
                return t * pt.exp(covariate_value)

        survival = pt.pow(alpha / (alpha + C_t(value)), r)
        logcdf = pt.log(1 - survival)

        return check_parameters(
            logcdf,
            r > 0,
            alpha > 0,
            msg="r > 0, alpha > 0",
        )

    def support_point(rv, size, r, alpha, time_covariate_vector=None):
        """Calculate a reasonable starting point for sampling.

        For the GrassiaIIGeometric distribution, we use a point estimate based on
        the expected value of the mixing distribution. Since the mixing distribution
        is Gamma(r, 1/alpha), its mean is r/alpha. We then transform this through
        the geometric link function and round to ensure an integer value.

        When time_covariate_vector is provided, it affects the expected value through
        the exponential link function: exp(time_covariate_vector).
        """
        if time_covariate_vector is None:
            time_covariate_vector = pt.constant(0.0)

        base_lambda = r / alpha

        # Approximate expected value of geometric distribution
        mean = pt.switch(
            base_lambda < 0.1,
            1.0 / base_lambda,  # Approximation for small lambda
            1.0 / (1.0 - pt.exp(-base_lambda)),  # Full expression for larger lambda
        )

        # Apply time covariates if provided
        mean = mean * pt.exp(time_covariate_vector)

        # Round up to nearest integer and ensure >= 1
        mean = pt.maximum(pt.ceil(mean), 1.0)

        # Handle size parameter
        if not rv_size_is_none(size):
            mean = pt.full(size, mean)

        return mean
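
For context, here is a minimal usage sketch of the new distribution as exported in this PR. The priors, data, and covariate values are illustrative assumptions, not taken from the diff:

import numpy as np
import pymc as pm

from pymc_extras.distributions import GrassiaIIGeometric

# Illustrative data: discrete times-to-event and a per-period covariate effect
observed_t = np.array([1, 2, 2, 3, 5, 8])
time_covariates = np.zeros(10)  # dot products of time-varying covariates and coefficients

with pm.Model():
    r = pm.HalfNormal("r", sigma=1.0)
    alpha = pm.HalfNormal("alpha", sigma=1.0)
    GrassiaIIGeometric(
        "obs",
        r=r,
        alpha=alpha,
        time_covariate_vector=time_covariates,
        observed=observed_t,
    )
    idata = pm.sample()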