diff --git a/Makefile b/Makefile index 13b5f0a..2605c39 100644 --- a/Makefile +++ b/Makefile @@ -10,6 +10,7 @@ check: coverage: coverage erase coverage run --source marshmallow_jsonschema -m py.test -v + coverage report -m pypitest: python setup.py sdist upload -r pypitest diff --git a/marshmallow_jsonschema/base.py b/marshmallow_jsonschema/base.py index ffae465..fbda2aa 100644 --- a/marshmallow_jsonschema/base.py +++ b/marshmallow_jsonschema/base.py @@ -3,7 +3,8 @@ import decimal from marshmallow import fields, missing, Schema, validate -from marshmallow.compat import text_type, binary_type +from marshmallow.class_registry import get_class +from marshmallow.compat import text_type, binary_type, basestring from .validation import handle_length, handle_one_of, handle_range @@ -144,7 +145,11 @@ def _from_python_type(cls, field, pytype): @classmethod def _from_nested_schema(cls, field): - schema = cls().dump(field.nested()).data + if isinstance(field.nested, basestring): + nested = get_class(field.nested) + else: + nested = field.nested + schema = cls().dump(nested()).data if field.metadata.get('metadata', {}).get('description'): schema['description'] = ( diff --git a/tests/test_dump.py b/tests/test_dump.py index 4476fe9..039c72f 100644 --- a/tests/test_dump.py +++ b/tests/test_dump.py @@ -57,6 +57,22 @@ class TestNestedSchema(Schema): assert nested_dmp['title'] == 'Title1' +def test_nested_string_to_cls(): + class TestSchema(Schema): + foo = fields.Integer(required=True) + + class TestNestedSchema(Schema): + foo2 = fields.Integer(required=True) + nested = fields.Nested('TestSchema') + schema = TestNestedSchema() + json_schema = JSONSchema() + dumped = json_schema.dump(schema).data + _validate_schema(dumped) + nested_json = dumped['properties']['nested'] + assert nested_json['properties']['foo']['format'] == 'integer' + assert nested_json['type'] == 'object' + + def test_one_of_validator(): schema = UserSchema() json_schema = JSONSchema() @@ -64,6 +80,7 @@ def test_one_of_validator(): _validate_schema(dumped) assert dumped['properties']['sex']['enum'] == ['male', 'female'] + def test_range_validator(): schema = Address() json_schema = JSONSchema()