diff --git a/schema_salad/codegen.py b/schema_salad/codegen.py index 6ef6e443..f81dc5ca 100644 --- a/schema_salad/codegen.py +++ b/schema_salad/codegen.py @@ -19,7 +19,7 @@ from .typescript_codegen import TypeScriptCodeGen from .utils import aslist -FIELD_SORT_ORDER = ["id", "class", "name"] +FIELD_SORT_ORDER = ["class", "id", "name"] def codegen( diff --git a/schema_salad/metaschema.py b/schema_salad/metaschema.py index 74bbc9f8..a7610080 100644 --- a/schema_salad/metaschema.py +++ b/schema_salad/metaschema.py @@ -1162,6 +1162,8 @@ class RecordField(Documented): A field of a record. """ + class_uri = "https://w3id.org/cwl/salad#RecordField" + def __init__( self, name: Any, @@ -1428,6 +1430,8 @@ def save( class RecordSchema(Saveable): + class_uri = "https://w3id.org/cwl/salad#RecordSchema" + def __init__( self, type_: Any, @@ -1632,6 +1636,8 @@ class EnumSchema(Saveable): """ + class_uri = "https://w3id.org/cwl/salad#EnumSchema" + def __init__( self, symbols: Any, @@ -1898,6 +1904,8 @@ def save( class ArraySchema(Saveable): + class_uri = "https://w3id.org/cwl/salad#ArraySchema" + def __init__( self, items: Any, @@ -2097,6 +2105,8 @@ def save( class MapSchema(Saveable): + class_uri = "https://w3id.org/cwl/salad#MapSchema" + def __init__( self, type_: Any, @@ -2296,6 +2306,8 @@ def save( class UnionSchema(Saveable): + class_uri = "https://w3id.org/cwl/salad#UnionSchema" + def __init__( self, names: Any, @@ -2501,6 +2513,8 @@ class JsonldPredicate(Saveable): """ + class_uri = "https://w3id.org/cwl/salad#JsonldPredicate" + def __init__( self, _id: Optional[Any] = None, @@ -3239,6 +3253,8 @@ def save( class SpecializeDef(Saveable): + class_uri = "https://w3id.org/cwl/salad#SpecializeDef" + def __init__( self, specializeFrom: Any, @@ -3463,6 +3479,8 @@ class SaladRecordField(RecordField): A field of a record. """ + class_uri = "https://w3id.org/cwl/salad#SaladRecordField" + def __init__( self, name: Any, @@ -3844,6 +3862,8 @@ def save( class SaladRecordSchema(NamedType, RecordSchema, SchemaDefinedType): + class_uri = "https://w3id.org/cwl/salad#SaladRecordSchema" + def __init__( self, name: Any, @@ -4705,6 +4725,8 @@ class SaladEnumSchema(NamedType, EnumSchema, SchemaDefinedType): """ + class_uri = "https://w3id.org/cwl/salad#SaladEnumSchema" + def __init__( self, symbols: Any, @@ -5446,6 +5468,8 @@ class SaladMapSchema(NamedType, MapSchema, SchemaDefinedType): """ + class_uri = "https://w3id.org/cwl/salad#SaladMapSchema" + def __init__( self, name: Any, @@ -6131,6 +6155,8 @@ class SaladUnionSchema(NamedType, UnionSchema, DocType): """ + class_uri = "https://w3id.org/cwl/salad#SaladUnionSchema" + def __init__( self, name: Any, @@ -6757,6 +6783,8 @@ class Documentation(NamedType, DocType): """ + class_uri = "https://w3id.org/cwl/salad#Documentation" + def __init__( self, name: Any, @@ -7612,4 +7640,4 @@ def load_document_by_yaml( uri, loadingOptions, ) - return result + return result \ No newline at end of file diff --git a/schema_salad/python_codegen.py b/schema_salad/python_codegen.py index 619c6c4c..104a1af4 100644 --- a/schema_salad/python_codegen.py +++ b/schema_salad/python_codegen.py @@ -143,6 +143,7 @@ def begin_class( idfield: str, optional_fields: set[str], ) -> None: + class_uri = classname classname = self.safe_name(classname) if extends: @@ -163,6 +164,8 @@ def begin_class( self.out.write(" pass\n\n\n") return + self.out.write(f' class_uri = "{class_uri}"\n\n') + required_field_names = [f for f in field_names if f not in optional_fields] optional_field_names = [f for f in field_names if f in optional_fields] @@ -276,27 +279,6 @@ def save( """ ) - if "class" in field_names: - self.out.write( - """ - if "class" not in _doc: - raise ValidationException("Missing 'class' field") - if _doc.get("class") != "{class_}": - raise ValidationException("tried `{class_}` but") - -""".format( - class_=classname - ) - ) - - self.serializer.write( - """ - r["class"] = "{class_}" -""".format( - class_=classname - ) - ) - def end_class(self, classname: str, field_names: list[str]) -> None: """Signal that we are done with this class.""" if self.current_class_is_abstract: @@ -554,9 +536,6 @@ def declare_field( if self.current_class_is_abstract: return - if shortname(name) == "class": - return - if optional: self.out.write(f""" {self.safe_name(name)} = None\n""") self.out.write(f""" if "{shortname(name)}" in _doc:\n""") # noqa: B907 @@ -608,8 +587,22 @@ def declare_field( spc=spc, ) ) - self.out.write( - """ + + if shortname(name) == "class": + self.out.write( + """{spc} if {safename} != cls.__name__ and {safename} != cls.class_uri: +{spc} raise ValidationException(f"tried `{{cls.__name__}}` but") +{spc} except ValidationException as e: +{spc} raise e +""".format( + safename=self.safe_name(name), + spc=spc, + ) + ) + + else: + self.out.write( + """ {spc} except ValidationException as e: {spc} error_message, to_print, verb_tensage = parse_errors(str(e)) diff --git a/schema_salad/tests/test_codegen_errors.py b/schema_salad/tests/test_codegen_errors.py index 2a05702e..3e2d1158 100644 --- a/schema_salad/tests/test_codegen_errors.py +++ b/schema_salad/tests/test_codegen_errors.py @@ -67,11 +67,11 @@ def test_error_message5(tmp_path: Path) -> None: def test_error_message6(tmp_path: Path) -> None: t = "test_schema/test6.cwl" match = r"""\*\s+tried\s+`CommandLineTool`\s+but -\s+Missing\s+'class'\s+field +\s+missing\s+required\s+field\s+`class` +\*\s+tried\s+`ExpressionTool`\s+but -\s+Missing\s+'class'\s+field +\s+missing\s+required\s+field\s+`class` +\*\s+tried\s+`Workflow`\s+but -\s+Missing\s+'class'\s+field""" +\s+missing\s+required\s+field\s+`class`""" path = get_data("tests/" + t) assert path with pytest.raises(ValidationException, match=match):