Add keep_conditions argument to continuous_approximator.sample #523
Conversation
Codecov Report: All modified and coverable lines are covered by tests ✅
OK, after your review and any resulting changes this would be ready to merge. The convenience this PR provides is that instead of the following tricky commands
users can directly get
Both result in shapes like this:
Please see the docstrings for a precise description.
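The exact commands and shapes were not preserved above; as a rough, hypothetical illustration of the convenience (array names, shapes, and the call usage are assumptions based on this PR's description):

```python
import numpy as np

# Hypothetical setup: 5 datasets, 3 condition features, 100 posterior draws each.
conditions = {"x": np.random.standard_normal((5, 3))}
num_samples = 100

# Manual route: repeat each condition along a new samples axis by hand so it
# lines up with the drawn samples, e.g. (5, 3) -> (5, 100, 3).
repeated_by_hand = {
    key: np.repeat(value[:, None, ...], num_samples, axis=1)
    for key, value in conditions.items()
}

# Proposed convenience (assumed usage of the new flag): the approximator
# returns the repeated conditions together with the freshly drawn samples.
# out = approximator.sample(num_samples=num_samples, conditions=conditions, keep_conditions=True)
```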
While writing tests for this, I came across a few skip instructions for multivariate normal approximators that can be re-enabled, so I added that to this PR.
```diff
@@ -465,6 +475,15 @@ def sample(
         if split:
             samples = split_arrays(samples, axis=-1)

+        if keep_conditions:
+            conditions = keras.tree.map_structure(keras.ops.convert_to_numpy, conditions)
```
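For readers unfamiliar with the Keras utilities used in the added line, a minimal standalone sketch of what it does (the example structure is made up):

```python
import keras
import numpy as np

# Walk a (possibly nested) conditions structure and convert every backend
# tensor to a NumPy array, leaving the nesting intact.
conditions = {"x": keras.ops.ones((5, 3)), "y": keras.ops.zeros((5, 2))}
conditions_np = keras.tree.map_structure(keras.ops.convert_to_numpy, conditions)

assert all(isinstance(v, np.ndarray) for v in conditions_np.values())
```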
The conditions that are repeated and kept are not the original ones, but the prepared (`self._prepare_data`) and adapted ones. Is this intentional? What are the downstream functions that this should be the input for, and which input would they expect?
As the same code is used below, do you think it would make sense to move this into a helper function?
Thanks for taking a look!
Yes, exactly. My thinking was that this way we can avoid the difficulty of conditions that are not `(batch_size, ...)` tensors but rather non-batchable values. One example is the linear regression notebook, where the sample size is just an integer that goes into the conditions.
When we repeat the output of `self._prepare_data`, we don't have to distinguish between what needs to be repeated and what does not.
Repeating the prepared conditions is also what happens in `self._sample` to generate the conditions passed along to `self.inference_network`, which might help with being robust against special adapter transformations.
However, the current implementation relies on the adapter inverse being faithful in all special cases, and I am not sure about that at the moment.
Do you think this is unsafe in some way? I need to think about it some more.
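To make the difficulty concrete, a small made-up example (names and shapes are not taken from the notebook itself): when a conditions dict mixes batched arrays with a plain integer, naively repeating everything requires guessing which entries are batchable.

```python
import numpy as np

# A conditions dict mixing a batched array with a non-batchable integer,
# similar in spirit to the sample-size condition mentioned above.
conditions = {"x": np.random.standard_normal((5, 3)), "N": 10}
num_samples = 100

repeated = {}
for key, value in conditions.items():
    if isinstance(value, np.ndarray):
        # Batched condition: insert a samples axis and repeat, (5, 3) -> (5, 100, 3).
        repeated[key] = np.repeat(value[:, None, ...], num_samples, axis=1)
    else:
        # Non-batchable condition: it is unclear whether and how it should be
        # repeated -- this is exactly the guesswork that repeating the output
        # of self._prepare_data avoids.
        repeated[key] = value
```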
The downstream functions would be, for example, `approximator.log_prob`, but also any other post-processing users might want to do with samples from the joint posterior predictive.
Thanks for the explanation, I missed the `inverse=True` in the adapter call and misinterpreted what happens. I'm still a bit confused, though.
What would be your current assumption regarding the adapter calls? That `conditions = self.adapter(conditions, inverse=True, strict=False, **kwargs)` gives the same output as the `conditions` initially passed by the user? Or that there is a difference (i.e., the adapter is not perfectly invertible) that we try to exploit?
I am a bit confused as well. I think users should have the choice of whether they want to output the conditions before or after the adapter call. Likely, the conditions before the adapter call will have more use cases.
I admit it's a bit confusing, sorry! I definitely agree: we always want the output to be the repeated pre-adapter conditions. The question is only how to get there.
Two options:
- repeat all array values of the condition dictionaries while leaving non-batched conditions intact.
- apply the adapter to the conditions (which broadcasts the non-batched conditions to `batch_shape`), then apply the inverse adapter. This way only dict values that are repeated in `approximator._sample` survive, which can then be repeated safely.
The second sounds more complicated, but it parallels how conditions are actually repeated and passed to the inference network during sampling, and it avoids having to guess what needs to be repeated and what does not.
The implementation I have proposed represents option 2 (roughly sketched below).
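Option 1 is essentially the guesswork sketched earlier in the thread. For option 2, a rough, hypothetical helper (not the PR's actual code; the forward adapter call is assumed to mirror the inverse call quoted above) could look like this:

```python
import numpy as np

def repeat_conditions_via_adapter(adapter, conditions, num_samples, **kwargs):
    """Hypothetical sketch of option 2: adapt forward (broadcasting non-batched
    entries), repeat along a new samples axis, then map back through the
    inverse adapter."""
    prepared = adapter(conditions, strict=False, **kwargs)
    repeated = {
        key: np.repeat(value[:, None, ...], num_samples, axis=1)
        for key, value in prepared.items()
    }
    return adapter(repeated, inverse=True, strict=False, **kwargs)
```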
My uncertainties in choosing which option is better are:
- How hard is it to reliably guess what needs to be repeated? Are there special adapters where the naive guess (just not repeating anything that isn't an array) would fail?
- Is the inverse adapter sufficiently reliable?
EDIT: I am now also confused. I am misrepresenting the order of what the code actually does.
EDIT2: I fixed the order.
After talking directly, Valentin and I think the feature to keep conditions in a repeated way cannot be provided to users in general. Ignoring the flexibility of real adapters is bound to produce hard-to-debug problems for users.
Both of the uncertainties I mentioned have negative answers, so both options are inadequate.
Uncertainty 1 (how hard is it to guess which conditions need to be repeated): the input to adapters can be too diverse to reliably guess how to repeat it. This takes option 1 off the table.
Uncertainty 2 (is the inverse adapter cycle-consistent in general): we can think of adapters that fail this requirement by essentially losing information in the forward pass. The forward-then-inverse chain then produces incorrect repeated conditions.
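As a concrete (made-up) illustration of such an information-losing adapter, consider one that silently drops a variable in the forward pass; the forward-then-inverse chain then cannot return the user's original conditions:

```python
import numpy as np

class DroppingAdapter:
    """Hypothetical adapter (not a bayesflow class) that discards a variable
    in the forward pass, so the inverse cannot restore it."""

    def __call__(self, data, inverse=False, **kwargs):
        if not inverse:
            return {"x": data["x"]}  # "meta" is silently dropped
        return dict(data)  # nothing left to restore

adapter = DroppingAdapter()
original = {"x": np.random.standard_normal((5, 3)), "meta": np.arange(5)}
roundtrip = adapter(adapter(original), inverse=True)

# "meta" is gone, so the "kept" conditions no longer match what the user passed in.
assert "meta" not in roundtrip
```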
As discussed in the comments (#523 (comment)), the proposed feature cannot be delivered for arbitrary adapters, so I will close the PR tomorrow. How to proceed: to get a similar level of convenience without having to rely on the inverse adapter, we can solve this inside `log_prob`.
Thank you for thinking about it more thoroughly. I agree with your assessment. I believe such a feature would (down the line) be better suited for packages that build on bayesflow and support a restricted class of models with known variable structure, for which it could be reliably implemented.
One comment about `log_prob`: I would not solve this inside `log_prob` for now, until we know that this is actually a sensible diagnostic. I.e., we would need to repeat some of the experiments in Modrak et al. with `log_prob` to see whether it behaves similarly to the (analytic) joint log-likelihood. Only if the results are encouraging should we implement it somewhat natively (because an implementation from us implies that we encourage users to use it).
To compute quantities that depend on both conditions and inference variables, it is useful to keep the repeated conditions that correspond to the samples generated by an approximator.
For convenience, I propose adding a keep_conditions flag to the approximators that repeats the conditions and returns them along with the newly generated samples.
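As a hedged sketch of why aligned, repeated conditions are convenient downstream (array names and shapes are illustrative, not taken from the codebase):

```python
import numpy as np

# Posterior samples and repeated conditions with matching leading shape
# (num_datasets, num_samples, ...): per-draw quantities that depend on both
# become simple vectorized expressions.
samples = {"beta": np.random.standard_normal((5, 100, 3))}
conditions = {"x": np.random.standard_normal((5, 100, 3))}

# Example: the linear predictor x . beta for every dataset and posterior draw.
linear_predictor = np.einsum("dsk,dsk->ds", conditions["x"], samples["beta"])
print(linear_predictor.shape)  # (5, 100)
```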
To-Do:
- [ ] add `keep_conditions` for `ContinuousApproximator.sample`
- [ ] add `keep_conditions` for `PointApproximator.sample`
- [ ] add `keep_conditions` for `ModelComparisonApproximator.sample`

EDIT: `ModelComparisonApproximator` doesn't have `sample`, so I removed that to-do item.