Derived scan logprob fails when observed data provides more broadcastable information than generative graph #7892

@ricardoV94

Description

Reported by @jessegrabowski

```python
import numpy as np
import pytensor
import pymc as pm
from pymc.pytensorf import collect_default_updates
from pymc.distributions.shape_utils import rv_size_is_none


def GRW(y_init, size=None):
    # Gaussian random walk built with scan: each step draws y_t ~ Normal(y_{t-1}, 1)
    def grw_step(y_tm1):
        y = pm.Normal.dist(mu=y_tm1)
        # Return the RNG updates so scan advances the random state on every step
        return y, collect_default_updates([y])

    # The leading (time) entry of size gives the number of steps
    if rv_size_is_none(size):
        n_steps = 10
    else:
        n_steps = size[0]

    y_hat, updates = pytensor.scan(fn=grw_step, outputs_info=[y_init], n_steps=n_steps)
    return y_hat


coords = {
    'date': range(10),
    'item': [1],
}
with pm.Model(coords=coords) as m:
    y0 = pm.Normal('y0', 0, 0.1, dims=['item'])
    y_hat = pm.CustomDist(
        'y_hat',
        y0,
        dist=GRW,
        dims=['date', 'item'],
        observed=np.ones((10, 1)),
    )

m.logp()
# TypeError: The broadcast pattern of the output of scan (Matrix(float64, shape=(?, 1))) is
# inconsistent with the one provided in `output_info` (Vector(float64, shape=(?,))). The output
# on axis 0 is `True`, but it is `False` on axis 1 in `output_info`. This can happen if one of
# the dimension is fixed to 1 in the input, while it is still variable in the output, or
# vice-verca. You have to make them consistent, e.g. using pytensor.tensor.specify_broadcastable.
```

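We need a way to introduce the correct broadcastable information during logprob inference: the observed value has static shape `(10, 1)`, so each step of the derived scan is broadcastable along `item`, while the initial state `y0` passed as `outputs_info` only has shape `(?,)`.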
