Skip to content

Commit

Permalink
move embedding files
Browse files Browse the repository at this point in the history
  • Loading branch information
mshuaibii committed Dec 9, 2020
1 parent d32a2a4 commit 328ad29
Show file tree
Hide file tree
Showing 7 changed files with 14 additions and 14 deletions.
2 changes: 1 addition & 1 deletion ocpmodels/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def setup_imports():
trainer_folder = os.path.join(root_folder, "trainers")
trainer_pattern = os.path.join(trainer_folder, "**", "*.py")
datasets_folder = os.path.join(root_folder, "datasets")
datasets_pattern = os.path.join(datasets_folder, "**", "*.py")
datasets_pattern = os.path.join(datasets_folder, "*.py")
model_folder = os.path.join(root_folder, "models")
model_pattern = os.path.join(model_folder, "*.py")

Expand Down
9 changes: 9 additions & 0 deletions ocpmodels/datasets/embeddings/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
__all__ = [
"ATOMIC_RADII",
"KHOT_EMBEDDINGS",
"CONTINUOUS_EMBEDDINGS",
]

from .atomic_radii import ATOMIC_RADII
from .continuous_embeddings import CONTINUOUS_EMBEDDINGS
from .khot_embeddings import KHOT_EMBEDDINGS
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
NaN stored for unavaialable parameters.
"""
FORCENET_EMBEDDINGS = {
CONTINUOUS_EMBEDDINGS = {
0: [
float("NaN"),
float("NaN"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
Original CGCNN k-hot elemental embeddings.
"""

EMBEDDINGS = {
KHOT_EMBEDDINGS = {
1: [
0,
1,
Expand Down
4 changes: 2 additions & 2 deletions ocpmodels/models/cgcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@

from ocpmodels.common.registry import registry
from ocpmodels.common.utils import get_pbc_distances, radius_graph_pbc
from ocpmodels.datasets.embeddings import KHOT_EMBEDDINGS
from ocpmodels.models.base import BaseModel
from ocpmodels.models.utils.embeddings import EMBEDDINGS


@registry.register_model("cgcnn")
Expand Down Expand Up @@ -72,7 +72,7 @@ def __init__(
# Get CGCNN atom embeddings
self.embedding = torch.zeros(100, 92)
for i in range(100):
self.embedding[i] = torch.tensor(EMBEDDINGS[i + 1])
self.embedding[i] = torch.tensor(KHOT_EMBEDDINGS[i + 1])
self.embedding_fc = nn.Linear(92, atom_embedding_size)

self.convs = nn.ModuleList(
Expand Down
9 changes: 0 additions & 9 deletions ocpmodels/models/utils/embeddings/__init__.py

This file was deleted.

0 comments on commit 328ad29

Please sign in to comment.