Skip to content

Commit

Permalink
Invert experimental import logic to opt-in (#651)
Browse files Browse the repository at this point in the history
* Invert experimental import logic to opt-in

* remove default user import
  • Loading branch information
levineds authored Apr 12, 2024
1 parent be20a86 commit daf72a5
Showing 1 changed file with 32 additions and 21 deletions.
53 changes: 32 additions & 21 deletions ocpmodels/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,22 +264,31 @@ def _import_local_file(path: Path, *, project_root: Path) -> None:


def setup_experimental_imports(project_root: Path) -> None:
experimental_folder = (project_root / "experimental").resolve()
if not experimental_folder.exists() or not experimental_folder.is_dir():
"""
Import selected directories of modules from the "experimental" subdirectory.
If a file named ".include" is present in the "experimental" subdirectory,
this will be read as a list of experimental subdirectories whose module
(including in any subsubdirectories) should be imported.
:param project_root: The root directory of the project (i.e., the "ocp" folder)
"""
experimental_dir = (project_root / "experimental").resolve()
if not experimental_dir.exists() or not experimental_dir.is_dir():
return

experimental_files = [
f.resolve().absolute() for f in experimental_folder.rglob("*.py")
]
# Ignore certain directories within experimental
ignore_file = experimental_folder / ".ignore"
if ignore_file.exists():
with open(ignore_file, "r") as f:
for line in f.read().splitlines():
for ignored_file in (experimental_folder / line).rglob("*.py"):
experimental_files.remove(
ignored_file.resolve().absolute()
)
experimental_files = []
include_file = experimental_dir / ".include"

if include_file.exists():
with open(include_file, "r") as f:
include_dirs = f.read().splitlines()

for inc_dir in include_dirs:
experimental_files.extend(
f.resolve().absolute()
for f in (experimental_dir / inc_dir).rglob("*.py")
)

for f in experimental_files:
_import_local_file(f, project_root=project_root)
Expand Down Expand Up @@ -313,7 +322,7 @@ def setup_imports(config: Optional[dict] = None) -> None:
from ocpmodels.common.registry import registry

skip_experimental_imports = (config or {}).get(
"skip_experimental_imports", None
"skip_experimental_imports", False
)

# First, check if imports are already setup
Expand Down Expand Up @@ -1159,11 +1168,11 @@ def get_commit_hash():
return commit_hash


def cg_change_mat(l, device="cpu"):
if l not in [2]:
def cg_change_mat(ang_mom: int, device: str = "cpu") -> torch.tensor:
if ang_mom not in [2]:
raise NotImplementedError

if l == 2:
if ang_mom == 2:
change_mat = torch.tensor(
[
[3 ** (-0.5), 0, 0, 0, 3 ** (-0.5), 0, 0, 0, 3 ** (-0.5)],
Expand Down Expand Up @@ -1192,12 +1201,14 @@ def cg_change_mat(l, device="cpu"):
return change_mat


def irreps_sum(l):
def irreps_sum(ang_mom: int) -> int:
"""
Returns the sum of the dimensions of the irreps up to the specified l.
Returns the sum of the dimensions of the irreps up to the specified angular momentum.
:param ang_mom: max angular momenttum to sum up dimensions of irreps
"""
total = 0
for i in range(l + 1):
for i in range(ang_mom + 1):
total += 2 * i + 1

return total
Expand Down

0 comments on commit daf72a5

Please sign in to comment.