Skip to content

Commit

Permalink
Fix some edge-cases with module references (#73)
Browse files Browse the repository at this point in the history
  • Loading branch information
joshbode authored and koxudaxi committed Oct 17, 2019
1 parent 5697b69 commit b54c3d0
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 11 deletions.
18 changes: 16 additions & 2 deletions datamodel_code_generator/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,21 @@ def __init__(
if auto_import:
if base_class_full_path:
self.imports.append(Import.from_full_path(base_class_full_path))
self.base_class = base_class_full_path.split('.')[-1]
self.base_class = base_class_full_path.rsplit('.', 1)[-1]

if '.' in name:
module, class_name = name.rsplit('.', 1)
prefix = f'{module}.'
if self.base_class.startswith(prefix):
self.base_class = self.base_class.replace(prefix, '', 1)
for field in self.fields:
type_hint = field.type_hint
if type_hint is not None and prefix in type_hint:
field.type_hint = type_hint.replace(prefix, '', 1)
else:
class_name = name

self.class_name: str = class_name

self.extra_template_data = (
extra_template_data.get(self.name, {})
Expand All @@ -177,7 +191,7 @@ def __init__(

def render(self) -> str:
response = self._render(
class_name=self.name,
class_name=self.class_name,
fields=self.fields,
decorators=self.decorators,
base_class=self.base_class,
Expand Down
2 changes: 1 addition & 1 deletion datamodel_code_generator/parser/jsonschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def is_array(self) -> bool:

@property
def ref_object_name(self) -> str:
return self.ref.split('/')[-1] # type: ignore
return self.ref.rsplit('/', 1)[-1] # type: ignore


JsonSchemaObject.update_forward_refs()
5 changes: 1 addition & 4 deletions datamodel_code_generator/parser/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def parse(
for ref_name in model.reference_classes:
if '.' not in ref_name:
continue
ref_path = ref_name.split('.', 1)[0]
ref_path = ref_name.rsplit('.', 1)[0]
if ref_path == module_path:
continue
imports.append(Import(from_='.', import_=ref_path))
Expand All @@ -345,9 +345,6 @@ def parse(
result += [imports.dump(), self.imports.dump(), '\n']

code = dump_templates(models)
if module_path:
# make references relative to current module
code = code.replace(f'{module_path}.', '')
result += [code]

if self.dump_resolve_reference_action is not None:
Expand Down
11 changes: 7 additions & 4 deletions tests/data/modular.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,10 @@ components:
name:
type: string
Result:
type: object
properties:
event:
$ref: '#/components/schemas/models.Event'
type: object
properties:
event:
$ref: '#/components/schemas/models.Event'
foo.bar.Thing:
properties:
attributes:
Expand All @@ -194,3 +194,6 @@ components:
type: array
items:
type: object
foo.bar.Clone:
allOf:
- $ref: '#/components/schemas/foo.bar.Thing'
4 changes: 4 additions & 0 deletions tests/parser/test_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1012,6 +1012,10 @@ class Thing(BaseModel):
class Thang(BaseModel):
attributes: Optional[List[Dict[str, Any]]] = None
class Clone(Thing):
pass
''',
},
)
Expand Down
4 changes: 4 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,10 @@ class Thing(BaseModel):
class Thang(BaseModel):
attributes: Optional[List[Dict[str, Any]]] = None
class Clone(Thing):
pass
''',
}
],
Expand Down

0 comments on commit b54c3d0

Please sign in to comment.