diff --git a/schema_salad/codegen.py b/schema_salad/codegen.py index 6ef6e443..f81dc5ca 100644 --- a/schema_salad/codegen.py +++ b/schema_salad/codegen.py @@ -19,7 +19,7 @@ from .typescript_codegen import TypeScriptCodeGen from .utils import aslist -FIELD_SORT_ORDER = ["id", "class", "name"] +FIELD_SORT_ORDER = ["class", "id", "name"] def codegen( diff --git a/schema_salad/python_codegen.py b/schema_salad/python_codegen.py index 619c6c4c..d39ee57e 100644 --- a/schema_salad/python_codegen.py +++ b/schema_salad/python_codegen.py @@ -143,6 +143,7 @@ def begin_class( idfield: str, optional_fields: set[str], ) -> None: + class_uri = classname classname = self.safe_name(classname) if extends: @@ -163,6 +164,8 @@ def begin_class( self.out.write(" pass\n\n\n") return + self.out.write(f' class_uri = "{class_uri}"\n\n') + required_field_names = [f for f in field_names if f not in optional_fields] optional_field_names = [f for f in field_names if f in optional_fields] @@ -276,27 +279,6 @@ def save( """ ) - if "class" in field_names: - self.out.write( - """ - if "class" not in _doc: - raise ValidationException("Missing 'class' field") - if _doc.get("class") != "{class_}": - raise ValidationException("tried `{class_}` but") - -""".format( - class_=classname - ) - ) - - self.serializer.write( - """ - r["class"] = "{class_}" -""".format( - class_=classname - ) - ) - def end_class(self, classname: str, field_names: list[str]) -> None: """Signal that we are done with this class.""" if self.current_class_is_abstract: @@ -554,9 +536,6 @@ def declare_field( if self.current_class_is_abstract: return - if shortname(name) == "class": - return - if optional: self.out.write(f""" {self.safe_name(name)} = None\n""") self.out.write(f""" if "{shortname(name)}" in _doc:\n""") # noqa: B907 @@ -608,6 +587,17 @@ def declare_field( spc=spc, ) ) + + if shortname(name) == "class": + self.out.write( + """{spc} if {safename} != cls.__name__ and {safename} != cls.class_uri: +{spc} raise ValidationException(f"tried `{{cls.__name__}}` but") +""".format( + safename=self.safe_name(name), + spc=spc, + ) + ) + self.out.write( """ {spc} except ValidationException as e: diff --git a/schema_salad/tests/test_codegen_errors.py b/schema_salad/tests/test_codegen_errors.py index 2a05702e..3e2d1158 100644 --- a/schema_salad/tests/test_codegen_errors.py +++ b/schema_salad/tests/test_codegen_errors.py @@ -67,11 +67,11 @@ def test_error_message5(tmp_path: Path) -> None: def test_error_message6(tmp_path: Path) -> None: t = "test_schema/test6.cwl" match = r"""\*\s+tried\s+`CommandLineTool`\s+but -\s+Missing\s+'class'\s+field +\s+missing\s+required\s+field\s+`class` +\*\s+tried\s+`ExpressionTool`\s+but -\s+Missing\s+'class'\s+field +\s+missing\s+required\s+field\s+`class` +\*\s+tried\s+`Workflow`\s+but -\s+Missing\s+'class'\s+field""" +\s+missing\s+required\s+field\s+`class`""" path = get_data("tests/" + t) assert path with pytest.raises(ValidationException, match=match):