Skip to content

Commit

Permalink
Fix class detection for namespaced classes (Py)
Browse files Browse the repository at this point in the history
This commit asjusts the Python generated parser to correctly deal with
namespaced classes (e.g., those coming from cwltool extensions).
  • Loading branch information
GlassOfWhiskey committed Dec 4, 2024
1 parent f3518e2 commit d584673
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 28 deletions.
2 changes: 1 addition & 1 deletion schema_salad/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .typescript_codegen import TypeScriptCodeGen
from .utils import aslist

FIELD_SORT_ORDER = ["id", "class", "name"]
FIELD_SORT_ORDER = ["class", "id", "name"]


def codegen(
Expand Down
38 changes: 14 additions & 24 deletions schema_salad/python_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def begin_class(
idfield: str,
optional_fields: set[str],
) -> None:
class_uri = classname
classname = self.safe_name(classname)

if extends:
Expand All @@ -163,6 +164,8 @@ def begin_class(
self.out.write(" pass\n\n\n")
return

self.out.write(f' class_uri = "{class_uri}"\n\n')

required_field_names = [f for f in field_names if f not in optional_fields]
optional_field_names = [f for f in field_names if f in optional_fields]

Expand Down Expand Up @@ -276,27 +279,6 @@ def save(
"""
)

if "class" in field_names:
self.out.write(
"""
if "class" not in _doc:
raise ValidationException("Missing 'class' field")
if _doc.get("class") != "{class_}":
raise ValidationException("tried `{class_}` but")
""".format(
class_=classname
)
)

self.serializer.write(
"""
r["class"] = "{class_}"
""".format(
class_=classname
)
)

def end_class(self, classname: str, field_names: list[str]) -> None:
"""Signal that we are done with this class."""
if self.current_class_is_abstract:
Expand Down Expand Up @@ -554,9 +536,6 @@ def declare_field(
if self.current_class_is_abstract:
return

if shortname(name) == "class":
return

if optional:
self.out.write(f""" {self.safe_name(name)} = None\n""")
self.out.write(f""" if "{shortname(name)}" in _doc:\n""") # noqa: B907
Expand Down Expand Up @@ -608,6 +587,17 @@ def declare_field(
spc=spc,
)
)

if shortname(name) == "class":
self.out.write(
"""{spc} if {safename} != cls.__name__ and {safename} != cls.class_uri:
{spc} raise ValidationException(f"tried `{{cls.__name__}}` but")
""".format(
safename=self.safe_name(name),
spc=spc,
)
)

self.out.write(
"""
{spc} except ValidationException as e:
Expand Down
6 changes: 3 additions & 3 deletions schema_salad/tests/test_codegen_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ def test_error_message5(tmp_path: Path) -> None:
def test_error_message6(tmp_path: Path) -> None:
t = "test_schema/test6.cwl"
match = r"""\*\s+tried\s+`CommandLineTool`\s+but
\s+Missing\s+'class'\s+field
\s+missing\s+required\s+field\s+`class`
+\*\s+tried\s+`ExpressionTool`\s+but
\s+Missing\s+'class'\s+field
\s+missing\s+required\s+field\s+`class`
+\*\s+tried\s+`Workflow`\s+but
\s+Missing\s+'class'\s+field"""
\s+missing\s+required\s+field\s+`class`"""
path = get_data("tests/" + t)
assert path
with pytest.raises(ValidationException, match=match):
Expand Down

0 comments on commit d584673

Please sign in to comment.