Skip to content

Commit

Permalink
add unique key loader check to YAML to avoid problematic configs (#658)
Browse files Browse the repository at this point in the history
  • Loading branch information
misko authored Apr 25, 2024
1 parent 069442d commit 9409555
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 2 deletions.
20 changes: 18 additions & 2 deletions ocpmodels/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,22 @@
from torch.nn.modules.module import _IncompatibleKeys


# copied from https://stackoverflow.com/questions/33490870/parsing-yaml-in-python-detect-duplicated-keys
# prevents loading YAMLS where keys have been overwritten
class UniqueKeyLoader(yaml.SafeLoader):
def construct_mapping(self, node, deep=False):
mapping = set()
for key_node, value_node in node.value:
each_key = self.construct_object(key_node, deep=deep)
if each_key in mapping:
raise ValueError(
f"Duplicate Key: {each_key!r} is found in YAML File.\n"
f"Error File location: {key_node.end_mark}"
)
mapping.add(each_key)
return super().construct_mapping(node, deep)


def pyg2_data_transform(data: Data):
"""
if we're on the new pyg (2.0 or later) and if the Data stored is in older format
Expand Down Expand Up @@ -392,7 +408,7 @@ def load_config(path: str, previous_includes: list = []):
)
previous_includes = previous_includes + [path]

direct_config = yaml.safe_load(open(path, "r"))
direct_config = yaml.load(open(path, "r"), Loader=UniqueKeyLoader)

# Load config from included files.
if "includes" in direct_config:
Expand Down Expand Up @@ -489,7 +505,7 @@ def _update_config(config, keys, override_vals, sep: str = "."):
child_config[key_path[-1]] = value
return config

sweeps = yaml.safe_load(open(sweep_file, "r"))
sweeps = yaml.load(open(sweep_file, "r"), Loader=UniqueKeyLoader)
flat_sweeps = _flatten_sweeps(sweeps)
keys = list(flat_sweeps.keys())
values = list(itertools.product(*flat_sweeps.values()))
Expand Down
45 changes: 45 additions & 0 deletions tests/common/test_yaml_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import tempfile

import pytest
import yaml

from ocpmodels.common.utils import UniqueKeyLoader


@pytest.fixture(scope="class")
def invalid_yaml_config():
return """
key1:
- a
- b
key1:
- c
- d
"""


@pytest.fixture(scope="class")
def valid_yaml_config():
return """
key1:
- a
- b
key2:
- c
- d
"""


def test_invalid_config(invalid_yaml_config):
with tempfile.NamedTemporaryFile(delete=False) as fp:
fp.write(invalid_yaml_config.encode())
fp.close()
with pytest.raises(ValueError):
yaml.load(open(fp.name, "r"), Loader=UniqueKeyLoader)


def test_valid_config(valid_yaml_config):
with tempfile.NamedTemporaryFile(delete=False) as fp:
fp.write(valid_yaml_config.encode())
fp.close()
yaml.load(open(fp.name, "r"), Loader=UniqueKeyLoader)

0 comments on commit 9409555

Please sign in to comment.