Skip to content

Commit

Permalink
chore: reformat code with black
Browse files Browse the repository at this point in the history
  • Loading branch information
aexvir committed Aug 14, 2020
1 parent dd26891 commit a3bee37
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 35 deletions.
17 changes: 15 additions & 2 deletions kw/json/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,19 @@

from ._compat import _load as load
from ._compat import _loads as loads
from .encode import KiwiJSONEncoder, MaskedJSONEncoder, default_encoder, dump, dumps, raw_encoder
from .encode import (
KiwiJSONEncoder,
MaskedJSONEncoder,
default_encoder,
dump,
dumps,
raw_encoder,
)
from .flask import JSONExtension
from .utils import DEFAULT_BLACKLIST, DEFAULT_PLACEHOLDER, DEFAULT_WHITELIST, mask_dict, mask_dict_factory
from .utils import (
DEFAULT_BLACKLIST,
DEFAULT_PLACEHOLDER,
DEFAULT_WHITELIST,
mask_dict,
mask_dict_factory,
)
9 changes: 7 additions & 2 deletions kw/json/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
enum = None

try:
from simplejson.encoder import JSONEncoder as BaseJSONEncoder # pylint: disable=W0611
from simplejson.encoder import ( # pylint: disable=W0611
JSONEncoder as BaseJSONEncoder,
)
from simplejson import dumps as json_dumps # pylint: disable=W0611
from simplejson import dump as json_dump # pylint: disable=W0611
from simplejson import loads as json_loads # pylint: disable=W0611
Expand Down Expand Up @@ -51,7 +53,10 @@ def wrapper(*args, **kwargs):
try:
result = func(*args, **kwargs)
except TypeError as err:
if str(err) == "__init__() got an unexpected keyword argument 'use_decimal'":
if (
str(err)
== "__init__() got an unexpected keyword argument 'use_decimal'"
):
raise KiwiJsonError(__use_decimal_error_message)
raise
return result
Expand Down
26 changes: 20 additions & 6 deletions kw/json/encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@


def _fail(obj, *args, **kwargs):
raise TypeError("Object of type {} is not JSON serializable".format(obj.__class__.__name__))
raise TypeError(
"Object of type {} is not JSON serializable".format(obj.__class__.__name__)
)


try:
Expand All @@ -23,7 +25,9 @@ def _fail(obj, *args, **kwargs):
dc_asdict = _fail


def default_encoder(obj, dict_factory=dict, date_as_unix_time=False): # Ignore RadonBear
def default_encoder(
obj, dict_factory=dict, date_as_unix_time=False
): # Ignore RadonBear
if hasattr(obj, "isoformat"): # date, datetime, arrow
if date_as_unix_time:
if obj.__class__.__name__ == "Arrow":
Expand All @@ -41,7 +45,10 @@ def default_encoder(obj, dict_factory=dict, date_as_unix_time=False): # Ignore
return obj.name

# Second option is for `iteritems()` on Python 2
if isinstance(obj, ItemsView) or obj.__class__.__name__ == "dictionary-itemiterator":
if (
isinstance(obj, ItemsView)
or obj.__class__.__name__ == "dictionary-itemiterator"
):
return dict_factory(obj)

if hasattr(obj, "asdict"): # dictablemodel
Expand All @@ -68,7 +75,9 @@ def default_encoder(obj, dict_factory=dict, date_as_unix_time=False): # Ignore
def raw_encoder(obj, date_as_unix_time=False):
"""Return representation of values that are not encodable instead of encoding them."""
try:
return default_encoder(obj, dict_factory=mask_dict, date_as_unix_time=date_as_unix_time)
return default_encoder(
obj, dict_factory=mask_dict, date_as_unix_time=date_as_unix_time
)
except TypeError:
return repr(obj)

Expand All @@ -95,7 +104,9 @@ def modify_kwargs(kwargs):
kwargs.setdefault("use_decimal", False)
if "default" not in kwargs:
date_as_unix_time = kwargs.pop("date_as_unix_time", False)
kwargs["default"] = partial(default_encoder, date_as_unix_time=date_as_unix_time)
kwargs["default"] = partial(
default_encoder, date_as_unix_time=date_as_unix_time
)


def format_value(value, precision):
Expand All @@ -104,7 +115,10 @@ def format_value(value, precision):
return round(value, precision)
if isinstance(value, (list, set)):
return traverse_iterable(value, precision)
if isinstance(value, ItemsView) or value.__class__.__name__ == "dictionary-itemiterator":
if (
isinstance(value, ItemsView)
or value.__class__.__name__ == "dictionary-itemiterator"
):
return traverse_dict(dict(value), precision)
if isinstance(value, dict):
return traverse_dict(value, precision)
Expand Down
7 changes: 5 additions & 2 deletions kw/json/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@


def mask_dict_factory(
placeholder=DEFAULT_PLACEHOLDER, blacklist=DEFAULT_BLACKLIST, whitelist=DEFAULT_WHITELIST,
placeholder=DEFAULT_PLACEHOLDER,
blacklist=DEFAULT_BLACKLIST,
whitelist=DEFAULT_WHITELIST,
):
def mask_dict(pairs):
"""Return a dict with dangerous looking key/value pairs masked."""
Expand All @@ -19,7 +21,8 @@ def mask_dict(pairs):
return {
key: (
placeholder
if key.lower() not in whitelist and any(word in key.lower() for word in blacklist)
if key.lower() not in whitelist
and any(word in key.lower() for word in blacklist)
else value
)
for key, value in items
Expand Down
4 changes: 3 additions & 1 deletion test/test_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
simplejson_loads = None


@pytest.mark.skipif(simplejson_loads is None, reason="Decimal encoding with simplejson only")
@pytest.mark.skipif(
simplejson_loads is None, reason="Decimal encoding with simplejson only"
)
@pytest.mark.parametrize(
"value, expected",
[
Expand Down
99 changes: 77 additions & 22 deletions test/test_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,14 @@
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

from kw.json import KiwiJSONEncoder, MaskedJSONEncoder, default_encoder, dump, dumps, raw_encoder
from kw.json import (
KiwiJSONEncoder,
MaskedJSONEncoder,
default_encoder,
dump,
dumps,
raw_encoder,
)
from kw.json._compat import DataclassItem, enum
from kw.json.exceptions import KiwiJsonError

Expand Down Expand Up @@ -100,10 +107,16 @@ def test_default_encoder(value, expected, date_as_unix_time):
assert default_encoder(value, date_as_unix_time=date_as_unix_time) == expected


@pytest.mark.skipif(simplejson_dumps is None, reason="Decimal encoding with simplejson only")
@pytest.mark.skipif(
simplejson_dumps is None, reason="Decimal encoding with simplejson only"
)
@pytest.mark.parametrize(
"value, expected",
((Decimal("1"), "1"), (Decimal("-1"), "-1"), (Decimal("0.123456789123456789"), "0.123456789123456789"),),
(
(Decimal("1"), "1"),
(Decimal("-1"), "-1"),
(Decimal("0.123456789123456789"), "0.123456789123456789"),
),
)
def test_simplejson_encoder_with_decimal(value, expected):
assert dumps(value, use_decimal=True) == expected
Expand All @@ -127,7 +140,11 @@ def test_default_encoder_defaults():
(Decimal("1"), '"1"', False),
(UUID, '"{}"'.format(str(UUID)), False),
(datetime.datetime(2018, 1, 1), '"2018-01-01T00:00:00"', False),
(datetime.datetime(2018, 1, 1, tzinfo=UTC), '"2018-01-01T00:00:00+00:00"', False,),
(
datetime.datetime(2018, 1, 1, tzinfo=UTC),
'"2018-01-01T00:00:00+00:00"',
False,
),
(arrow.get("2018-01-01"), '"2018-01-01T00:00:00+00:00"', False),
(datetime.date(2018, 1, 1), '"2018-01-01"', False),
(datetime.datetime(2018, 1, 1), "1514764800", True),
Expand All @@ -149,15 +166,22 @@ def test_dumps(value, expected, date_as_unix_time):
(({1: 1.333}, {1: 1.333}), '{"1": 1.33}', 2),
(([1.333, 2.333], [1.333, 2.333]), "[1.33, 2.33]", 2),
(([1.333, {1: 1.333}], [1.333, {1: 1.333}]), '[1.33, {"1": 1.33}]', 2),
(([1.333, {1: 1.333}, {1.333}], [1.333, {1: 1.333}, {1.333}]), '[1.33, {"1": 1.33}, [1.33]]', 2),
(
([1.333, {1: 1.333}, {1.333}], [1.333, {1: 1.333}, {1.333}]),
'[1.33, {"1": 1.33}, [1.33]]',
2,
),
((items_view_float, items_view_float), '{"foo": 1.33}', 2),
((items_view_complex, items_view_complex), '{"1": 1.33, "2": {"2": 0.33}}', 2),
((HTML(), None), '"foo"', 2),
(([set()], [set()]), "[[]]", 2),
(((1.3333, 2.3333), (1.3333, 2.3333)), "[1.33, 2.33]", 2),
((({1: 1.33333}, 1.33333), ({1: 1.33333}, 1.33333)), '[{"1": 1.33}, 1.33]', 2),
(
([{1: 1.222, 2: [1.333, {1: 1.333}, {3: {3.333}}]}], [{1: 1.222, 2: [1.333, {1: 1.333}, {3: {3.333}}]}]),
(
[{1: 1.222, 2: [1.333, {1: 1.333}, {3: {3.333}}]}],
[{1: 1.222, 2: [1.333, {1: 1.333}, {3: {3.333}}]}],
),
'[{"1": 1.22, "2": [1.33, {"1": 1.33}, {"3": [3.33]}]}]',
2,
),
Expand All @@ -174,10 +198,17 @@ def test_rounding(values, expected, precision):
@pytest.mark.parametrize(
"values, expected, precision",
(
((test_namedtuple, test_namedtuple), {"as_object": '{"a": 1.33, "b": 2.33}', "as_list": "[1.33, 2.33]"}, 2),
(
(test_namedtuple, test_namedtuple),
{"as_object": '{"a": 1.33, "b": 2.33}', "as_list": "[1.33, 2.33]"},
2,
),
(
(test_namedtuple_complex, test_namedtuple_complex),
{"as_object": '{"a": 1.33, "b": {"a": 1.33, "b": {"1": 1.33}}}', "as_list": '[1.33, [1.33, {"1": 1.33}]]',},
{
"as_object": '{"a": 1.33, "b": {"a": 1.33, "b": {"1": 1.33}}}',
"as_list": '[1.33, [1.33, {"1": 1.33}]]',
},
2,
),
),
Expand All @@ -187,7 +218,10 @@ def test_rounding_tuples(values, expected, precision):
if simplejson_dumps:
# simplejson supports `namedtuple_as_object` param unlike json
assert dumps(before, precision=precision) == expected["as_object"]
assert dumps(before, precision=precision, namedtuple_as_object=False) == expected["as_list"]
assert (
dumps(before, precision=precision, namedtuple_as_object=False)
== expected["as_list"]
)
else:
assert dumps(before, precision=precision) == expected["as_list"]
assert before == after
Expand All @@ -204,9 +238,9 @@ def __repr__(self):
return "<Foo>"

# by default `raw_encoder` encodes dates as ISO
assert dumps({"foo": Foo(), "bar": datetime.date(2018, 1, 1)}, default=raw_encoder) == dumps(
{"foo": "<Foo>", "bar": "2018-01-01"}
)
assert dumps(
{"foo": Foo(), "bar": datetime.date(2018, 1, 1)}, default=raw_encoder
) == dumps({"foo": "<Foo>", "bar": "2018-01-01"})


def test_dump_with_default():
Expand Down Expand Up @@ -275,7 +309,9 @@ def test_unknown_raises():
class Foo(object):
bar = True # pylint: disable=C0102

with pytest.raises(TypeError, match="^Object of type Foo is not JSON serializable$"):
with pytest.raises(
TypeError, match="^Object of type Foo is not JSON serializable$"
):
default_encoder(Foo())


Expand Down Expand Up @@ -305,23 +341,32 @@ def test_masked_json_encoders(value, expected):
(partial(json_dumps, cls=MaskedJSONEncoder), '{"attrib": 1}'),
),
)
@pytest.mark.skipif(DataclassItem is None, reason="Dataclasses are available only on Python 3.7+")
@pytest.mark.skipif(
DataclassItem is None, reason="Dataclasses are available only on Python 3.7+"
)
def test_dataclasses(dumper, expected):
assert dumper(DataclassItem(attrib=1)) == expected # pylint: disable=not-callable


@pytest.mark.parametrize(
"dumper, expected",
((default_encoder, {"attrib": 1}), (partial(json_dumps, default=default_encoder), '{"attrib": 1}'),),
(
(default_encoder, {"attrib": 1}),
(partial(json_dumps, default=default_encoder), '{"attrib": 1}'),
),
)
def test_attrs(dumper, expected):
assert dumper(AttrsItem(attrib=1)) == expected


@pytest.mark.skipif(sys.version_info[:2] >= (3, 7), reason="Dataclasses should not be available")
@pytest.mark.skipif(
sys.version_info[:2] >= (3, 7), reason="Dataclasses should not be available"
)
def test_missing_dependency():
"""If we have a class that have the same attributes as attrs provide."""
with pytest.raises(TypeError, match="Object of type NotDataclassesItem is not JSON serializable"):
with pytest.raises(
TypeError, match="Object of type NotDataclassesItem is not JSON serializable"
):
default_encoder(NotDataclassesItem())


Expand Down Expand Up @@ -354,22 +399,32 @@ def test_sqlalchemy_cursor_row(alchemy_session):
assert_json(data, [{"id": 1, "name": "test"}])


@pytest.mark.skipif(sys.version_info[0] == 2, reason="That trick doesn't work on Python 2")
@pytest.mark.skipif(
sys.version_info[0] == 2, reason="That trick doesn't work on Python 2"
)
def test_no_attrs():
# Need to re-import
del sys.modules["kw.json"]
del sys.modules["kw.json.encode"]
sys.modules["attr"] = None
from kw.json import default_encoder # pylint: disable=reimported,import-outside-toplevel
from kw.json import ( # pylint: disable=reimported,import-outside-toplevel
default_encoder,
)

with pytest.raises(TypeError, match="Object of type NotAttrsItem is not JSON serializable"):
with pytest.raises(
TypeError, match="Object of type NotAttrsItem is not JSON serializable"
):
default_encoder(NotAttrsItem())


@pytest.mark.skipif(get_asyncpg_record is None, reason="Asyncpg is available only on Python 3.5+.")
@pytest.mark.skipif(
get_asyncpg_record is None, reason="Asyncpg is available only on Python 3.5+."
)
def test_asyncpg():
import asyncio # pylint: disable=import-outside-toplevel

loop = asyncio.get_event_loop()
result = loop.run_until_complete(get_asyncpg_record(os.getenv("DATABASE_URI"))) # pylint: disable=not-callable
result = loop.run_until_complete(
get_asyncpg_record(os.getenv("DATABASE_URI")) # pylint: disable=not-callable
)
assert json_dumps(result, default=default_encoder) == '[{"value": 1}]'

0 comments on commit a3bee37

Please sign in to comment.