Batched inference CEBRA & padding at the Solver level
#168
base: main
Conversation
…ional models in _transform
@stes @MMathisLab, if you have time to review this, that would be great :)
This looks fine to me, and it is already used internally in production, right @stes?
@CeliaBenquet can you resolve the conflicts? Then I think it's fine to merge!
@MMathisLab there have been big code changes / refactoring since @stes's last review, so I would be more confident about merging after an in-depth review, but your call :)
reviewing now
Left a few comments.
The biggest issue I see is that a lot of features are changed that do not seem to be directly related to the batched implementation (but I might be wrong). So one iteration addressing some of these questions in my review would help me understand the logic a bit better.
cebra/data/base.py (outdated diff)
    raise NotImplementedError
    self.offset = model.get_offset()
Typo? / missing cleanup?
I.e., should the line below be removed here? And why is that relevant for batched inference?
configure_for was previously called in the cebra.CEBRA class (in configure_for_all) and is now moved to the solvers directly; also, the configure_for in the multi-session solver was wrongly implemented and not used.
So it is now not implemented in the base class, and defined in the multi- and single-session solvers instead.
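For context, a minimal sketch of how that split could look at the solver level. This is only an illustration under assumptions: the class names and the iter_sessions() helper are hypothetical, and only configure_for / get_offset come from the snippets in this thread.

import abc


class Solver(abc.ABC):

    def __init__(self, model):
        self.model = model

    @abc.abstractmethod
    def configure_for(self, dataset):
        """Configure the dataset (e.g. its offset) for this solver's model(s)."""
        raise NotImplementedError()


class SingleSessionSolver(Solver):

    def configure_for(self, dataset):
        # Single model: configure the single dataset with it.
        dataset.configure_for(self.model)


class MultiSessionSolver(Solver):

    def configure_for(self, dataset):
        # self.model is assumed to be a list with one model per session:
        # configure each session's dataset with its own model.
        for session, model in zip(dataset.iter_sessions(), self.model):
            session.configure_for(model)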
cebra/solver/base.py (outdated diff)
    if not hasattr(self, "n_features"):
        raise ValueError(
            f"This {type(self).__name__} instance is not fitted yet. Call 'fit' with "
            "appropriate arguments before using this estimator.")
I think it is not ideal to use n_features for this. Can you implement a @property that gives you that info directly, for example is_fitted()?
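A minimal sketch of what such a property could look like, assuming the fitted state keeps being tracked through the n_features attribute from the diff above (the _check_is_fitted helper name is made up for illustration):

@property
def is_fitted(self) -> bool:
    """True once fit() has set the attributes needed for inference."""
    return hasattr(self, "n_features")

def _check_is_fitted(self):
    # Hypothetical helper: raise the same error as before, but via the property.
    if not self.is_fitted:
        raise ValueError(
            f"This {type(self).__name__} instance is not fitted yet. Call 'fit' with "
            "appropriate arguments before using this estimator.")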
That's what was done initially with the sklearn function.
I'm not sure I understand how is_fitted changes that; it's just an implementation detail, right? Should I keep n_features?
cebra/solver/base.py (outdated diff)
@@ -336,7 +647,7 @@ def load(self, logdir, filename="checkpoint.pth"):
        checkpoint = torch.load(savepath, map_location=self.device)
        self.load_state_dict(checkpoint, strict=True)

-   def save(self, logdir, filename="checkpoint_last.pth"):
+   def save(self, logdir, filename="checkpoint.pth"):
let's keep the old naming here
Is there a reason it's different from the default in the load function?
cebra/solver/multi_session.py (outdated diff)
    def parameters(self, session_id: Optional[int] = None):
        """Iterate over all parameters."""
        self._check_is_session_id_valid(session_id=session_id)
        for parameter in self.model[session_id].parameters():
            yield parameter

        for parameter in self.criterion.parameters():
            yield parameter
If None is given, we should return all parameters from the super() class.
Not sure I get it: do you mean we should return the parameters from the criterion at least, or the parameters for all models?
The super class only has an abstract method for that.
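For illustration, one possible way parameters() could yield everything when session_id is None. This is a sketch of an assumed intended behavior, not the final implementation; self.model is taken to be the list of per-session models from the diff above.

from typing import Optional

def parameters(self, session_id: Optional[int] = None):
    """Iterate over the parameters of one session's model, or of all of them."""
    if session_id is None:
        models = list(self.model)          # all per-session models
    else:
        self._check_is_session_id_valid(session_id=session_id)
        models = [self.model[session_id]]  # only the requested session

    for model in models:
        yield from model.parameters()

    # The criterion parameters are shared across sessions.
    yield from self.criterion.parameters()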
Re your comment about changes not directly related to the batched inference: the PR was started with two (related) goals at once, if I'm correct (it wasn't me who started it); see the other linked issues for a better understanding.
Ok, makes sense!
Also updated this, and will check again once the tests have passed.
fix https://github.com/AdaptiveMotorControlLab/CEBRA-dev/pull/746
fix https://github.com/AdaptiveMotorControlLab/CEBRA-dev/issues/624
fix https://github.com/AdaptiveMotorControlLab/CEBRA-dev/issues/637
fix https://github.com/AdaptiveMotorControlLab/CEBRA-dev/pull/594