-
Notifications
You must be signed in to change notification settings - Fork 85
Batched inference CEBRA & padding at the Solver
level
#168
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Batched inference CEBRA & padding at the Solver
level
#168
Conversation
…ional models in _transform
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
some early comments; apologies if i have asked some of these before
tests/test_solver.py
Outdated
@pytest.mark.parametrize( | ||
"data_name, loader_initfunc, model_architecture, solver_initfunc", | ||
multi_session_tests) | ||
def test_multi_session(data_name, loader_initfunc, model_architecture, | ||
solver_initfunc): | ||
data = cebra.datasets.init(data_name) | ||
loader = _get_loader(data, loader_initfunc) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why the changes here? i.e. did anything change that would cause the "old" multi session tests to break?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I restablished _get_loader as it was but added a return value as I need the dataset to configure it with the model.
Else,
- I added the model_architecture as offset1-model is a special case for padding at transform.
- I added the configure_for(model) as now this is handled in the solver.
- I added some tests on the transform (was not done at all before), similar to the sklearn tests but at the pytorch level.
tests/test_solver_batched.py
Outdated
single_session_tests_select_model = [] | ||
single_session_hybrid_tests_select_model = [] | ||
for model_name in ["offset1-model", "offset10-model"]: | ||
for session_id in [None, 0, 5]: | ||
for args in [ | ||
("demo-discrete", model_name, session_id, | ||
cebra.data.DiscreteDataLoader), | ||
("demo-continuous", model_name, session_id, | ||
cebra.data.ContinuousDataLoader), | ||
("demo-mixed", model_name, session_id, cebra.data.MixedDataLoader), | ||
]: | ||
single_session_tests_select_model.append( | ||
(*args, cebra.solver.SingleSessionSolver)) | ||
single_session_hybrid_tests_select_model.append( | ||
(*args, cebra.solver.SingleSessionHybridSolver)) | ||
|
||
multi_session_tests_select_model = [] | ||
for model_name in ["offset10-model"]: | ||
for session_id in [None, 0, 1, 5, 2, 6, 4]: | ||
for args in [("demo-continuous-multisession", model_name, session_id, | ||
cebra.data.ContinuousMultiSessionDataLoader)]: | ||
multi_session_tests_select_model.append( | ||
(*args, cebra.solver.MultiSessionSolver)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you wrap the for loops here (quite complex) in functions, and only do the assingment on the global level?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I proposed something lmk if that's what you meant :)
doc error is: |
@CeliaBenquet not sure I see your edits post review; did you push them? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Left some initial comments; broader discussion is a bit on the api design in the solver/base class --- lets discuss offline.
Co-authored-by: Steffen Schneider <steffen@bethgelab.org>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, review got a bit longer again; I realized I missed a few things on the last review. High level comments:
- I made some comments in solver which could be fine; I think some arguments were moved from the sklearn class to the solver class, but the motivation for that is not entirely clear. Mostly needs one round of discussion so we can settle on a good API design for these. Specifically, what is the usecase for storing these variables now in the solver, where are they called?
- the new
transform
function adds a lot of duplicated code that should be unified; again, could be first discussed
if hasattr(self, "n_features"): | ||
state_dict["n_features"] = self.n_features |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this an attribute of the solver, vs. being returned directly from the model? For sklearn it makes sense to fix this, but for the solver this could also simply be a property to be returned from the model? Where is this used?
E.g. what would happen for an xCEBRA solver, where you have not a single feature dim, but multiple
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for the multisession case that's already the case and that's a list.
num_features
cannot be a property I think, because that can be defined only based on the inputs provided to the fit(), and later if we adapt the solver, it needs to be reset. This is used to be saved with the solver as it's needed when reloading it + to be called in the sklearn + to see if the solver is fitted when calling transform().
for xcebra that's just similar to the original sklearn one but at a lower level, so yes we need to think about it but we would have had to in any case.
@@ -127,12 +317,27 @@ def _inference(self, batch): | |||
|
|||
@register("single-session-hybrid") | |||
@dataclasses.dataclass | |||
class SingleSessionHybridSolver(abc_.MultiobjectiveSolver): | |||
class SingleSessionHybridSolver(abc_.MultiobjectiveSolver, SingleSessionSolver): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This does not work, I think. Both inherit from Solver
base, this might have some weird effects; what was the motivation though?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it's the same transform() method as well as all the checks methods etc. so that's to avoid a lot of duplicate code.
I thought so as well but all tests pass, and they don't have redefined methods in common. Else happy to hear your suggestion to avoid duplication.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What are the methods that would be duplicated, could you list here, @CeliaBenquet ? If this is the issue, the proper way is that we write a Mixin
functionality that puts the otherwise duplicated functions with same functionality in a new class.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
duplicate (because single session mode)
- parameters
- _set_fitted_params
- _check_is_inputs_valid
- _check_is_session_id_valid
differences:
- _get_model
- _inference
note that this is similar for the multi session mode and multiobjective and auxiliary variable option in single session mode.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
notes, but not final
class BaseSolverMixin(abc.ABC):
# all abstract
parameters
_set_fitted_params
_check_is_inputs_valid
_check_is_session_id_valid
class SingleSessionMixin(BaseSolverMixin)
...
class MultiSessionMixin(BaseSolverMixin)
...
@@ -1,5 +1,7 @@ | |||
import pickle | |||
|
|||
import _utils_deprecated |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
multiobjective is tested, single objective is not
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
single objective solver transform was not tested before that PR + not used in the CEBRA() class, so we test the CEBRA transform (single objective) and the multiobjective one, as the solver transform was more similar to the new structure (padding, etc).
@@ -209,6 +210,13 @@ def __post_init__(self): | |||
renormalize=self.renormalize, | |||
) | |||
|
|||
def parameters(self, session_id: Optional[int] = None): | |||
"""Iterate over all parameters.""" | |||
super().parameters(session_id=session_id) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this an error? this does not do anything besides checking, right? shouldnt the params be also returend?
fix https://github.com/AdaptiveMotorControlLab/CEBRA-dev/pull/746
fix #199
This PR adds the following features:
CEBRA.transform()
orsolver.transform()
) can be performed in batch, allowing inference on larger datasets or with larger models in a memory-efficient way (fix https://github.com/AdaptiveMotorControlLab/CEBRA-dev/issues/624).Example Usage of the new PyTorch API:
all is similar to previous implementation but the inference part, which doesn't require to handle the padding of the input before passing it to the model.