From 431207bccdb315c717ff4ab8e65b09053998ae50 Mon Sep 17 00:00:00 2001 From: GlassOfWhiskey Date: Wed, 4 Dec 2024 16:28:47 +0100 Subject: [PATCH] Fix class detection for namespaced classes (Py) This commit asjusts the Python generated parser to correctly deal with namespaced classes (e.g., those coming from cwltool extensions). --- schema_salad/codegen.py | 2 +- schema_salad/python_codegen.py | 53 ++++++++++------------- schema_salad/tests/test_codegen_errors.py | 6 +-- 3 files changed, 28 insertions(+), 33 deletions(-) diff --git a/schema_salad/codegen.py b/schema_salad/codegen.py index 6ef6e443..f81dc5ca 100644 --- a/schema_salad/codegen.py +++ b/schema_salad/codegen.py @@ -19,7 +19,7 @@ from .typescript_codegen import TypeScriptCodeGen from .utils import aslist -FIELD_SORT_ORDER = ["id", "class", "name"] +FIELD_SORT_ORDER = ["class", "id", "name"] def codegen( diff --git a/schema_salad/python_codegen.py b/schema_salad/python_codegen.py index 619c6c4c..56ac250f 100644 --- a/schema_salad/python_codegen.py +++ b/schema_salad/python_codegen.py @@ -143,6 +143,7 @@ def begin_class( idfield: str, optional_fields: set[str], ) -> None: + class_uri = classname classname = self.safe_name(classname) if extends: @@ -163,6 +164,8 @@ def begin_class( self.out.write(" pass\n\n\n") return + self.out.write(f' class_uri = "{class_uri}"\n\n') + required_field_names = [f for f in field_names if f not in optional_fields] optional_field_names = [f for f in field_names if f in optional_fields] @@ -276,27 +279,6 @@ def save( """ ) - if "class" in field_names: - self.out.write( - """ - if "class" not in _doc: - raise ValidationException("Missing 'class' field") - if _doc.get("class") != "{class_}": - raise ValidationException("tried `{class_}` but") - -""".format( - class_=classname - ) - ) - - self.serializer.write( - """ - r["class"] = "{class_}" -""".format( - class_=classname - ) - ) - def end_class(self, classname: str, field_names: list[str]) -> None: """Signal that we are done with this class.""" if self.current_class_is_abstract: @@ -554,9 +536,6 @@ def declare_field( if self.current_class_is_abstract: return - if shortname(name) == "class": - return - if optional: self.out.write(f""" {self.safe_name(name)} = None\n""") self.out.write(f""" if "{shortname(name)}" in _doc:\n""") # noqa: B907 @@ -608,18 +587,34 @@ def declare_field( spc=spc, ) ) + + if shortname(name) == "class": + self.out.write( + """{spc} if {safename} != cls.__name__ and {safename} != cls.class_uri: +{spc} raise ValidationException(f"tried `{{cls.__name__}}` but") +""".format( + safename=self.safe_name(name), + spc=spc, + ) + ) + self.out.write( """ {spc} except ValidationException as e: {spc} error_message, to_print, verb_tensage = parse_errors(str(e)) {spc} if str(e) == "missing required field `{fieldname}`": -{spc} _errors__.append( -{spc} ValidationException( -{spc} str(e), -{spc} None +{spc} if "{fieldname}" == "class": +{spc} raise e +{spc} else: +{spc} _errors__.append( +{spc} ValidationException( +{spc} str(e), +{spc} None +{spc} ) {spc} ) -{spc} ) +{spc} elif str(e) == f"tried `{{cls.__name__}}` but": +{spc} raise e {spc} else: {spc} val = _doc.get("{fieldname}") {spc} if error_message != str(e): diff --git a/schema_salad/tests/test_codegen_errors.py b/schema_salad/tests/test_codegen_errors.py index 2a05702e..3e2d1158 100644 --- a/schema_salad/tests/test_codegen_errors.py +++ b/schema_salad/tests/test_codegen_errors.py @@ -67,11 +67,11 @@ def test_error_message5(tmp_path: Path) -> None: def test_error_message6(tmp_path: Path) -> None: t = "test_schema/test6.cwl" match = r"""\*\s+tried\s+`CommandLineTool`\s+but -\s+Missing\s+'class'\s+field +\s+missing\s+required\s+field\s+`class` +\*\s+tried\s+`ExpressionTool`\s+but -\s+Missing\s+'class'\s+field +\s+missing\s+required\s+field\s+`class` +\*\s+tried\s+`Workflow`\s+but -\s+Missing\s+'class'\s+field""" +\s+missing\s+required\s+field\s+`class`""" path = get_data("tests/" + t) assert path with pytest.raises(ValidationException, match=match):