Skip to content
This repository has been archived by the owner on Jun 6, 2024. It is now read-only.

Add include_data=True option for including all relationships #264

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion marshmallow_jsonapi/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,10 @@ class Meta:
def __init__(self, *args, **kwargs):
self.include_data = kwargs.pop("include_data", ())
super().__init__(*args, **kwargs)
if self.include_data:

if self.include_data is True:
self.include_all_data()
elif self.include_data:
self.check_relations(self.include_data)

if not self.opts.type_:
Expand All @@ -93,6 +96,15 @@ def __init__(self, *args, **kwargs):

OPTIONS_CLASS = SchemaOpts

def include_all_data(self):
"""
Recursively set include_data for all relationships to this schema
"""
for field in self.fields.values():
if isinstance(field, BaseRelationship):
field.include_data = True
field.schema.include_all_data()

def check_relations(self, relations):
"""Recursive function which checks if a relation is valid."""
for rel in relations:
Expand Down
11 changes: 11 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,17 @@ def test_include_data_with_all_relations(self, post):
}
assert included_comments_author_ids == expected_comments_author_ids

def test_include_data_auto_all(self, post):
"""
Test that we can use include_data=True to include all relations recursively
"""
data = unpack(PostSchema(include_data=True).dump(post))
assert "included" in data
assert len(data["included"]) == 8
for included in data["included"]:
assert included["id"]
assert included["type"] in ("people", "comments", "keywords")

def test_include_no_data(self, post):
data = unpack(PostSchema(include_data=()).dump(post))
assert "included" not in data
Expand Down