From 934816d2b319b4e362760e801dbca8b3aec291bc Mon Sep 17 00:00:00 2001 From: GlassOfWhiskey Date: Wed, 4 Dec 2024 16:28:47 +0100 Subject: [PATCH] Fix class detection for namespaced classes (Py) This commit asjusts the Python generated parser to correctly deal with namespaced classes (e.g., those coming from cwltool extensions). --- schema_salad/codegen.py | 2 +- schema_salad/python_codegen.py | 38 +++++++++++++--------------------- 2 files changed, 15 insertions(+), 25 deletions(-) diff --git a/schema_salad/codegen.py b/schema_salad/codegen.py index 6ef6e443..f81dc5ca 100644 --- a/schema_salad/codegen.py +++ b/schema_salad/codegen.py @@ -19,7 +19,7 @@ from .typescript_codegen import TypeScriptCodeGen from .utils import aslist -FIELD_SORT_ORDER = ["id", "class", "name"] +FIELD_SORT_ORDER = ["class", "id", "name"] def codegen( diff --git a/schema_salad/python_codegen.py b/schema_salad/python_codegen.py index 619c6c4c..d39ee57e 100644 --- a/schema_salad/python_codegen.py +++ b/schema_salad/python_codegen.py @@ -143,6 +143,7 @@ def begin_class( idfield: str, optional_fields: set[str], ) -> None: + class_uri = classname classname = self.safe_name(classname) if extends: @@ -163,6 +164,8 @@ def begin_class( self.out.write(" pass\n\n\n") return + self.out.write(f' class_uri = "{class_uri}"\n\n') + required_field_names = [f for f in field_names if f not in optional_fields] optional_field_names = [f for f in field_names if f in optional_fields] @@ -276,27 +279,6 @@ def save( """ ) - if "class" in field_names: - self.out.write( - """ - if "class" not in _doc: - raise ValidationException("Missing 'class' field") - if _doc.get("class") != "{class_}": - raise ValidationException("tried `{class_}` but") - -""".format( - class_=classname - ) - ) - - self.serializer.write( - """ - r["class"] = "{class_}" -""".format( - class_=classname - ) - ) - def end_class(self, classname: str, field_names: list[str]) -> None: """Signal that we are done with this class.""" if self.current_class_is_abstract: @@ -554,9 +536,6 @@ def declare_field( if self.current_class_is_abstract: return - if shortname(name) == "class": - return - if optional: self.out.write(f""" {self.safe_name(name)} = None\n""") self.out.write(f""" if "{shortname(name)}" in _doc:\n""") # noqa: B907 @@ -608,6 +587,17 @@ def declare_field( spc=spc, ) ) + + if shortname(name) == "class": + self.out.write( + """{spc} if {safename} != cls.__name__ and {safename} != cls.class_uri: +{spc} raise ValidationException(f"tried `{{cls.__name__}}` but") +""".format( + safename=self.safe_name(name), + spc=spc, + ) + ) + self.out.write( """ {spc} except ValidationException as e: