Skip to content

Commit

Permalink
Utils: YAML goes brrrt (ArchipelagoMW#2868)
Browse files Browse the repository at this point in the history
Also tests to validate we dont break the API.
  • Loading branch information
black-sliver authored Feb 27, 2024
1 parent 738a9eb commit c126418
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 5 deletions.
8 changes: 3 additions & 5 deletions Utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,12 @@
from argparse import Namespace
from settings import Settings, get_settings
from typing import BinaryIO, Coroutine, Optional, Set, Dict, Any, Union
from yaml import load, load_all, dump, SafeLoader
from yaml import load, load_all, dump

try:
from yaml import CLoader as UnsafeLoader
from yaml import CDumper as Dumper
from yaml import CLoader as UnsafeLoader, CSafeLoader as SafeLoader, CDumper as Dumper
except ImportError:
from yaml import Loader as UnsafeLoader
from yaml import Dumper
from yaml import Loader as UnsafeLoader, SafeLoader, Dumper

if typing.TYPE_CHECKING:
import tkinter
Expand Down
68 changes: 68 additions & 0 deletions test/utils/test_yaml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Tests that yaml wrappers in Utils.py do what they should

import unittest
from typing import cast, Any, ClassVar, Dict

from Utils import dump, Dumper # type: ignore[attr-defined]
from Utils import parse_yaml, parse_yamls, unsafe_parse_yaml


class AClass:
def __eq__(self, other: Any) -> bool:
return isinstance(other, self.__class__)


class TestYaml(unittest.TestCase):
safe_data: ClassVar[Dict[str, Any]] = {
"a": [1, 2, 3],
"b": None,
"c": True,
}
unsafe_data: ClassVar[Dict[str, Any]] = {
"a": AClass()
}

@property
def safe_str(self) -> str:
return cast(str, dump(self.safe_data, Dumper=Dumper))

@property
def unsafe_str(self) -> str:
return cast(str, dump(self.unsafe_data, Dumper=Dumper))

def assertIsNonEmptyString(self, string: str) -> None:
self.assertTrue(string)
self.assertIsInstance(string, str)

def test_dump(self) -> None:
self.assertIsNonEmptyString(self.safe_str)
self.assertIsNonEmptyString(self.unsafe_str)

def test_safe_parse(self) -> None:
self.assertEqual(self.safe_data, parse_yaml(self.safe_str))
with self.assertRaises(Exception):
parse_yaml(self.unsafe_str)
with self.assertRaises(Exception):
parse_yaml("1\n---\n2\n")

def test_unsafe_parse(self) -> None:
self.assertEqual(self.safe_data, unsafe_parse_yaml(self.safe_str))
self.assertEqual(self.unsafe_data, unsafe_parse_yaml(self.unsafe_str))
with self.assertRaises(Exception):
unsafe_parse_yaml("1\n---\n2\n")

def test_multi_parse(self) -> None:
self.assertEqual(self.safe_data, next(parse_yamls(self.safe_str)))
with self.assertRaises(Exception):
next(parse_yamls(self.unsafe_str))
self.assertEqual(2, len(list(parse_yamls("1\n---\n2\n"))))

def test_unique_key(self) -> None:
s = """
a: 1
a: 2
"""
with self.assertRaises(Exception):
parse_yaml(s)
with self.assertRaises(Exception):
next(parse_yamls(s))

0 comments on commit c126418

Please sign in to comment.