Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix msgspec root import #1611

Merged
merged 1 commit into from
Oct 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datamodel_code_generator/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def get_data_model_types(
elif data_model_type == DataModelType.MsgspecStruct:
return DataModelSet(
data_model=msgspec.Struct,
root_model=rootmodel.RootModel,
root_model=msgspec.RootModel,
field_model=msgspec.DataModelField,
data_type_manager=DataTypeManager,
dump_resolve_reference_action=None,
Expand Down
54 changes: 41 additions & 13 deletions datamodel_code_generator/model/msgspec.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
from functools import wraps
from pathlib import Path
from typing import Any, ClassVar, DefaultDict, Dict, List, Optional, Set, Tuple
from typing import (
Any,
ClassVar,
DefaultDict,
Dict,
List,
Optional,
Set,
Tuple,
Type,
TypeVar,
)

from pydantic import Field

Expand All @@ -15,7 +27,7 @@
from datamodel_code_generator.model.pydantic.base_model import (
Constraints as _Constraints,
)
from datamodel_code_generator.model.rootmodel import RootModel
from datamodel_code_generator.model.rootmodel import RootModel as _RootModel
from datamodel_code_generator.reference import Reference
from datamodel_code_generator.types import chain_as_tuple, get_optional_type

Expand All @@ -27,6 +39,33 @@ def _has_field_assignment(field: DataModelFieldBase) -> bool:
)


DataModelT = TypeVar('DataModelT', bound=DataModel)


def import_extender(cls: Type[DataModelT]) -> Type[DataModelT]:
original_imports: property = getattr(cls, 'imports', None) # type: ignore

@wraps(original_imports.fget) # type: ignore
def new_imports(self: DataModelT) -> Tuple[Import, ...]:
extra_imports = []
if any(f for f in self.fields if f.field):
extra_imports.append(IMPORT_MSGSPEC_FIELD)
if any(f for f in self.fields if f.field and 'lambda: convert' in f.field):
extra_imports.append(IMPORT_MSGSPEC_CONVERT)
if any(f for f in self.fields if f.annotated):
extra_imports.append(IMPORT_MSGSPEC_META)
return chain_as_tuple(original_imports.fget(self), extra_imports) # type: ignore

setattr(cls, 'imports', property(new_imports))
return cls


@import_extender
class RootModel(_RootModel):
pass


@import_extender
class Struct(DataModel):
TEMPLATE_FILE_PATH: ClassVar[str] = 'msgspec.jinja2'
BASE_CLASS: ClassVar[str] = 'msgspec.Struct'
Expand Down Expand Up @@ -63,17 +102,6 @@ def __init__(
nullable=nullable,
)

@property
def imports(self) -> Tuple[Import, ...]:
extra_imports = []
if any(f for f in self.fields if f.field):
extra_imports.append(IMPORT_MSGSPEC_FIELD)
if any(f for f in self.fields if f.field and 'lambda: convert' in f.field):
extra_imports.append(IMPORT_MSGSPEC_CONVERT)
if any(f for f in self.fields if f.annotated):
extra_imports.append(IMPORT_MSGSPEC_META)
return chain_as_tuple(super().imports, extra_imports)


class Constraints(_Constraints):
# To override existing pattern alias
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# generated by datamodel-codegen:
# filename: common.yml
# timestamp: 2019-07-26T00:00:00+00:00

from __future__ import annotations

from typing import Annotated, Any

from msgspec import Meta

Model = Any


Ulid = Annotated[str, Meta(pattern='[0-9ABCDEFGHJKMNPQRSTVWXYZ]{26,26}')]
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# generated by datamodel-codegen:
# filename: test.yml
# timestamp: 2019-07-26T00:00:00+00:00

from __future__ import annotations

from typing import Annotated

from msgspec import Meta, Struct

from . import common


class Test(Struct):
uid: Annotated[common.Ulid, Meta(description='ulid of this object')]
49 changes: 34 additions & 15 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5832,29 +5832,48 @@ def test_main_duplicate_field_constraints():
assert result == path.read_text()


@pytest.mark.parametrize(
'collapse_root_models,python_version,expected_output',
[
(
'--collapse-root-models',
'3.8',
'duplicate_field_constraints_msgspec_py38_collapse_root_models',
),
(
None,
'3.9',
'duplicate_field_constraints_msgspec',
),
],
)
@freeze_time('2019-07-26')
def test_main_duplicate_field_constraints_py38():
def test_main_duplicate_field_constraints_msgspec(
collapse_root_models, python_version, expected_output
):
with TemporaryDirectory() as output_dir:
output_path: Path = Path(output_dir)
return_code: Exit = main(
[
'--input',
str(JSON_SCHEMA_DATA_PATH / 'duplicate_field_constraints'),
'--output',
str(output_path),
'--input-file-type',
'jsonschema',
'--collapse-root-models',
'--output-model-type',
'msgspec.Struct',
'--target-python-version',
'3.8',
a
for a in [
'--input',
str(JSON_SCHEMA_DATA_PATH / 'duplicate_field_constraints'),
'--output',
str(output_path),
'--input-file-type',
'jsonschema',
'--output-model-type',
'msgspec.Struct',
'--target-python-version',
python_version,
collapse_root_models,
]
if a
]
)
assert return_code == Exit.OK
main_modular_dir = (
EXPECTED_MAIN_PATH / 'duplicate_field_constraints_msgspec_py38'
)
main_modular_dir = EXPECTED_MAIN_PATH / expected_output
for path in main_modular_dir.rglob('*.py'):
result = output_path.joinpath(
path.relative_to(main_modular_dir)
Expand Down
Loading