
Batched inference CEBRA & padding at the Solver level #168

Open
wants to merge 49 commits into base: main

Conversation

@cla-bot cla-bot bot added the CLA signed label Aug 23, 2024
@CeliaBenquet CeliaBenquet self-assigned this Aug 23, 2024
@CeliaBenquet CeliaBenquet requested a review from stes August 23, 2024 12:02
@MMathisLab MMathisLab added the enhancement New feature or request label Aug 23, 2024
@CeliaBenquet
Member Author

@stes @MMathisLab, if you have time to review this that would be great :)

Member

@MMathisLab MMathisLab left a comment


This looks fine to me, and it's already used internally in production, right @stes?

@MMathisLab
Member

@CeliaBenquet can you resolve the conflicts? Then I think it's fine to merge!

@CeliaBenquet
Member Author

CeliaBenquet commented Sep 18, 2024

@MMathisLab there have been big code changes / refactoring since @stes's last review, so I would feel more confident about merging after an in-depth review, but it's your call :)

@stes
Member

stes commented Sep 18, 2024

reviewing now

Member

@stes stes left a comment


Left a few comments.

The biggest issue I see is that a lot of features are changed that do not seem to be directly related to the batched implementation (but I might be wrong). So one iteration addressing some of these questions in my review would help me understand the logic a bit better.

Comment on lines 209 to 210
raise NotImplementedError
self.offset = model.get_offset()
Member


Typo? / missing cleanup?

Member


I.e., should the line below be removed here? Why is that relevant for batched inference?

Member Author


configure_for was done in the cebra.CEBRA class (in configure_for_all) and is now moved to the solvers directly; the configure_for in the multisession solver was wrongly implemented and not used.

So it is now left unimplemented in the base class and defined in the multi- and single-session solvers.
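
For readers following along, here is a minimal sketch of what that split could look like, purely for illustration (the exact method names, and whether the datasets expose a configure_for(model) / iter_sessions() API like this, are assumptions, not the PR's actual code):

import abc


class Solver(abc.ABC):
    """Sketch: the base solver leaves dataset/model configuration abstract."""

    @abc.abstractmethod
    def configure_for(self, dataset):
        """Configure the dataset(s) to match the solver's model(s)."""
        ...


class SingleSessionSolver(Solver):

    def configure_for(self, dataset):
        # One model, one session: let the dataset adapt to the model's offset.
        dataset.configure_for(self.model)


class MultiSessionSolver(Solver):

    def configure_for(self, dataset):
        # One model per session: configure each session with its own model.
        for session, model in zip(dataset.iter_sessions(), self.model):
            session.configure_for(model)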

cebra/data/single_session.py (resolved)
cebra/integrations/sklearn/cebra.py (outdated, resolved)
cebra/integrations/sklearn/cebra.py (resolved)
cebra/integrations/sklearn/cebra.py (resolved)
Comment on lines 591 to 594
if not hasattr(self, "n_features"):
    raise ValueError(
        f"This {type(self).__name__} instance is not fitted yet. Call 'fit' with "
        "appropriate arguments before using this estimator.")
Member


I think it is not ideal to use n_features for this. Can you implement a @property that gives you that info directly, for example is_fitted?

Member Author

@CeliaBenquet CeliaBenquet Sep 19, 2024


That's what was done initially with the sklearn function.

I'm not sure I understand how is_fitted changes that; it's just an implementation detail, right? Should I keep n_features?
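
To make the suggestion above concrete, here is a minimal sketch of the pattern being discussed (illustrative only; the property can keep relying on n_features under the hood, so no extra state is needed):

class Estimator:
    """Sketch of the suggested pattern: wrap the attribute check in a property."""

    def fit(self, X):
        self.n_features = X.shape[1]
        return self

    @property
    def is_fitted(self) -> bool:
        # Still backed by n_features, but callers no longer need hasattr().
        return hasattr(self, "n_features")

    def transform(self, X):
        if not self.is_fitted:
            raise ValueError(
                f"This {type(self).__name__} instance is not fitted yet. Call "
                "'fit' with appropriate arguments before using this estimator.")
        return X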

@@ -336,7 +647,7 @@ def load(self, logdir, filename="checkpoint.pth"):
checkpoint = torch.load(savepath, map_location=self.device)
self.load_state_dict(checkpoint, strict=True)

- def save(self, logdir, filename="checkpoint_last.pth"):
+ def save(self, logdir, filename="checkpoint.pth"):
Member


Let's keep the old naming here.

Member Author


Is there a reason it's different from the default in the load function?

Comment on lines 45 to 52
def parameters(self, session_id: Optional[int] = None):
    """Iterate over all parameters."""
    self._check_is_session_id_valid(session_id=session_id)
    for parameter in self.model[session_id].parameters():
        yield parameter

    for parameter in self.criterion.parameters():
        yield parameter
Member


If None is given, we should return all parameters from the super() class.

Member Author

@CeliaBenquet CeliaBenquet Sep 19, 2024


Not sure I get it. Do you mean we should return the parameters from the criterion at least, or the parameters of all models?

The super class only defines an abstract method for this.
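
For illustration, one possible reading of the suggestion, sketched against the snippet above (assuming self.model is a list of per-session models; this is a sketch, not the PR's actual resolution):

from typing import Iterator, Optional

import torch


def parameters(self, session_id: Optional[int] = None) -> Iterator[torch.nn.Parameter]:
    """Yield the parameters of one session's model, or of all models when
    session_id is None, followed by the criterion's parameters."""
    if session_id is None:
        models = list(self.model)
    else:
        self._check_is_session_id_valid(session_id=session_id)
        models = [self.model[session_id]]
    for model in models:
        yield from model.parameters()
    yield from self.criterion.parameters()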

cebra/solver/multi_session.py (resolved)
docs/source/conf.py (outdated, resolved)
@CeliaBenquet
Member Author

Re your comment on changes not related to the batched inference: the PR was started with two (related) goals at once, if I'm correct (I'm not the one who started it):

  • batched inference,
  • but to do that in the solver, all padding etc. operations in transform needed to be moved from the cebra.CEBRA() class to the solver.

--> see the other linked issues for more context.
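
For context, a rough sketch of the general idea being moved into the solver (batched transform with extra context around each batch and padding at the data boundaries); the function names, shapes, and padding mode here are assumptions for illustration, not the PR's actual implementation:

import torch
import torch.nn.functional as F


def batched_transform(model, inputs, batch_size, offset_left, offset_right):
    """Apply ``model`` to ``inputs`` (time x features) in batches, adding
    offset_left/offset_right samples of context around each batch and
    replicate-padding where that context runs off the data."""
    n = len(inputs)
    outputs = []
    for start in range(0, n, batch_size):
        end = min(start + batch_size, n)
        # Grab extra context around the batch, clamped to the valid range.
        left = max(0, start - offset_left)
        right = min(n, end + offset_right)
        window = inputs[left:right]
        # Pad only where the context ran off the edges of the data.
        pad_left = offset_left - (start - left)
        pad_right = offset_right - (right - end)
        if pad_left or pad_right:
            # F.pad in "replicate" mode expects (N, C, L) for 1D padding.
            window = F.pad(window.T.unsqueeze(0), (pad_left, pad_right),
                           mode="replicate").squeeze(0).T
        with torch.no_grad():
            outputs.append(model(window))
    return torch.cat(outputs)


# Toy "model": the embedding of a time step is the mean over k surrounding samples.
def toy_model(window, k=5):
    return window.unfold(0, k, 1).mean(dim=-1)


x = torch.randn(1000, 3)
embedding = batched_transform(lambda w: toy_model(w, k=5), x,
                              batch_size=256, offset_left=2, offset_right=2)
assert embedding.shape == (1000, 3)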

@stes
Member

stes commented Sep 18, 2024

Ok, makes sense!

@CeliaBenquet CeliaBenquet requested a review from stes September 19, 2024 11:57
@stes
Member

stes commented Oct 20, 2024

Also upgraded this; checking again once tests have passed.

Labels: CLA signed, enhancement (New feature or request)
4 participants