Skip to content

Commit

Permalink
minor fix
Browse files Browse the repository at this point in the history
  • Loading branch information
albertbou92 committed Apr 10, 2024
1 parent 1b49bb6 commit 585c3af
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions acegen/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,20 @@ def adapt_state_dict(source_state_dict: dict, target_state_dict: dict):
target_state_dict[key_target] = value_source

return target_state_dict


def get_primers_from_module(module):
"""Get all tensordict primers from all submodules of a module."""
primers = []

def make_primers(submodule):
if hasattr(submodule, "make_tensordict_primer"):
primers.append(submodule.make_tensordict_primer())

module.apply(make_primers)
if not primers:
import warnings

raise warnings.warn("No primers found in the module.")
else:
return primers

0 comments on commit 585c3af

Please sign in to comment.