Skip to content

Commit

Permalink
Merge pull request #227 from aisk/strict-decode
Browse files: browse the repository at this point in the history
add 'strict_decode' option for protocols
  • Loading branch information
ethe authored Nov 28, 2023
2 parents 9a2553a + 8d340f7 commit fc26372
Show file tree
Hide file tree
Showing 10 changed files with 176 additions and 45 deletions.
33 changes: 33 additions & 0 deletions tests/test_aio_protocol_binary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-

from io import BytesIO

import pytest

from thriftpy2.thrift import TType, TPayload
from thriftpy2.contrib.aio.protocol import binary as proto


class TItem(TPayload):
    """Payload fixture declaring an I32 ``id`` and a LIST ``phones`` field."""

    # Keyed by Thrift field id. The parentheses the original put around
    # TType.STRING were plain grouping (not a tuple), so they are dropped.
    thrift_spec = {
        1: (TType.I32, "id", False),
        2: (TType.LIST, "phones", TType.STRING, False),
    }

    # Both fields start out as None.
    default_spec = [("id", None), ("phones", None)]


class AsyncBytesIO:
    """Give a synchronous file-like object an awaitable ``read`` method."""

    def __init__(self, b):
        # Wrapped stream; every read is forwarded to it unchanged.
        self.b = b

    async def read(self, *args, **kwargs):
        """Return ``self.b.read(*args, **kwargs)`` from inside a coroutine."""
        stream = self.b
        data = stream.read(*args, **kwargs)
        return data


@pytest.mark.asyncio
async def test_strict_decode():
    """``strict_decode=True`` must raise UnicodeDecodeError on invalid UTF-8."""
    # Four bytes encoding 12, one extra NUL, then 12 bytes of UTF-8 text.
    # Presumably the reader takes 12 bytes after the prefix, so the stray NUL
    # pushes the final character's last byte out of the window.
    payload = (b"\x00\x00\x00\x0c\x00"  # there is a redundant '\x00'
               b"\xe4\xbd\xa0\xe5\xa5\xbd\xe4\xb8\x96\xe7\x95\x8c")
    bs = AsyncBytesIO(BytesIO(payload))

    with pytest.raises(UnicodeDecodeError):
        await proto.read_val(bs, TType.STRING, decode_response=True,
                             strict_decode=True)
33 changes: 33 additions & 0 deletions tests/test_aio_protocol_compact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-

from io import BytesIO

import pytest

from thriftpy2.thrift import TType, TPayload
from thriftpy2.contrib.aio.protocol import compact
from test_aio_protocol_binary import AsyncBytesIO


class TItem(TPayload):
    """Payload fixture declaring an I32 ``id`` and a LIST ``phones`` field."""

    # Keyed by Thrift field id. The parentheses the original put around
    # TType.STRING were plain grouping (not a tuple), so they are dropped.
    thrift_spec = {
        1: (TType.I32, "id", False),
        2: (TType.LIST, "phones", TType.STRING, False),
    }

    # Both fields start out as None.
    default_spec = [("id", None), ("phones", None)]


def gen_proto(bytearray=b''):
    """Create an async compact protocol reading from an in-memory buffer.

    ``bytearray`` holds the initial bytes of the buffer (the name shadows the
    builtin but is kept because callers may pass it by keyword).

    Returns a ``(buffer, protocol)`` tuple where ``buffer`` is the
    ``AsyncBytesIO`` handed to ``compact.TAsyncCompactProtocol``.
    """
    transport = AsyncBytesIO(BytesIO(bytearray))
    return transport, compact.TAsyncCompactProtocol(transport)


@pytest.mark.asyncio
async def test_strict_decode():
    """With ``strict_decode`` set, invalid UTF-8 raises UnicodeDecodeError."""
    # 0x0c (12) followed by UTF-8 text with a NUL byte inserted in the middle
    # of the second character's three-byte sequence.
    corrupted = (b'\x0c\xe4\xbd\xa0\xe5\xa5\x00'
                 b'\xbd\xe4\xb8\x96\xe7\x95\x8c')
    _, proto = gen_proto(corrupted)
    proto.strict_decode = True

    with pytest.raises(UnicodeDecodeError):
        await proto._read_val(TType.STRING)
10 changes: 10 additions & 0 deletions tests/test_protocol_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from io import BytesIO

import pytest

from thriftpy2._compat import u
from thriftpy2.thrift import TType, TPayload
from thriftpy2.utils import hexlify
Expand Down Expand Up @@ -98,6 +100,14 @@ def test_unpack_binary():
bs, TType.STRING, decode_response=False)


def test_strict_decode():
    """``strict_decode=True`` must raise UnicodeDecodeError on invalid UTF-8."""
    # Four bytes encoding 12, one extra NUL, then 12 bytes of UTF-8 text.
    # Presumably the reader takes 12 bytes after the prefix, so the stray NUL
    # pushes the final character's last byte out of the window.
    payload = (b"\x00\x00\x00\x0c\x00"  # there is a redundant '\x00'
               b"\xe4\xbd\xa0\xe5\xa5\xbd\xe4\xb8\x96\xe7\x95\x8c")
    bs = BytesIO(payload)

    with pytest.raises(UnicodeDecodeError):
        proto.read_val(bs, TType.STRING, decode_response=True,
                       strict_decode=True)


def test_write_message_begin():
b = BytesIO()
proto.TBinaryProtocol(b).write_message_begin("test", TType.STRING, 1)
Expand Down
11 changes: 11 additions & 0 deletions tests/test_protocol_compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from io import BytesIO

import pytest

from thriftpy2._compat import u
from thriftpy2.thrift import TType, TPayload
from thriftpy2.utils import hexlify
Expand Down Expand Up @@ -115,6 +117,15 @@ def test_unpack_binary():
assert u('你好世界').encode("utf-8") == proto._read_val(TType.STRING)


def test_strict_decode():
    """With ``strict_decode`` set, invalid UTF-8 raises UnicodeDecodeError."""
    # 0x0c (12) followed by UTF-8 text with a NUL byte inserted in the middle
    # of the second character's three-byte sequence.
    corrupted = (b'\x0c\xe4\xbd\xa0\xe5\xa5\x00'
                 b'\xbd\xe4\xb8\x96\xe7\x95\x8c')
    _, proto = gen_proto(corrupted)
    proto.strict_decode = True

    with pytest.raises(UnicodeDecodeError):
        proto._read_val(TType.STRING)


def test_pack_bool():
b, proto = gen_proto()
proto._write_bool(True)
Expand Down
8 changes: 8 additions & 0 deletions tests/test_protocol_cybinary.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,14 @@ def test_read_binary():
b, TType.STRING, decode_response=False)


def test_strict_decode():
    """``strict_decode=True`` must raise UnicodeDecodeError on invalid UTF-8."""
    # Four bytes encoding 12, one extra NUL, then 12 bytes of UTF-8 text.
    # Presumably the reader takes 12 bytes after the prefix, so the stray NUL
    # pushes the final character's last byte out of the window.
    payload = (b"\x00\x00\x00\x0c\x00"  # there is a redundant '\x00'
               b"\xe4\xbd\xa0\xe5\xa5\xbd\xe4\xb8\x96\xe7\x95\x8c")
    bs = TCyMemoryBuffer(payload)

    with pytest.raises(UnicodeDecodeError):
        proto.read_val(bs, TType.STRING, decode_response=True,
                       strict_decode=True)


def test_write_message_begin():
trans = TCyMemoryBuffer()
b = proto.TCyBinaryProtocol(trans)
Expand Down
33 changes: 21 additions & 12 deletions thriftpy2/contrib/aio/protocol/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ async def read_map_begin(inbuf):
return k_type, v_type, sz


async def read_val(inbuf, ttype, spec=None, decode_response=True):
async def read_val(inbuf, ttype, spec=None, decode_response=True,
strict_decode=False):
if ttype == TType.BOOL:
return bool(unpack_i8(await inbuf.read(1)))

Expand Down Expand Up @@ -103,7 +104,8 @@ async def read_val(inbuf, ttype, spec=None, decode_response=True):
try:
return byte_payload.decode('utf-8')
except UnicodeDecodeError:
pass
if strict_decode:
raise
return byte_payload

elif ttype == TType.SET or ttype == TType.LIST:
Expand All @@ -123,7 +125,8 @@ async def read_val(inbuf, ttype, spec=None, decode_response=True):

for i in range(sz):
result.append(
await read_val(inbuf, v_type, v_spec, decode_response)
await read_val(inbuf, v_type, v_spec, decode_response,
strict_decode)
)
return result

Expand Down Expand Up @@ -153,19 +156,21 @@ async def read_val(inbuf, ttype, spec=None, decode_response=True):
return {}

for i in range(sz):
k_val = await read_val(inbuf, k_type, k_spec, decode_response)
v_val = await read_val(inbuf, v_type, v_spec, decode_response)
k_val = await read_val(inbuf, k_type, k_spec, decode_response,
strict_decode)
v_val = await read_val(inbuf, v_type, v_spec, decode_response,
strict_decode)
result[k_val] = v_val

return result

elif ttype == TType.STRUCT:
obj = spec()
await read_struct(inbuf, obj, decode_response)
await read_struct(inbuf, obj, decode_response, strict_decode)
return obj


async def read_struct(inbuf, obj, decode_response=True):
async def read_struct(inbuf, obj, decode_response=True, strict_decode=False):
while True:
f_type, fid = await read_field_begin(inbuf)
if f_type == TType.STOP:
Expand All @@ -191,7 +196,7 @@ async def read_struct(inbuf, obj, decode_response=True):
continue

_buf = await read_val(
inbuf, f_type, f_container_spec, decode_response)
inbuf, f_type, f_container_spec, decode_response, strict_decode)
setattr(obj, f_name, _buf)


Expand Down Expand Up @@ -239,11 +244,12 @@ class TAsyncBinaryProtocol(TAsyncProtocolBase):

def __init__(self, trans,
strict_read=True, strict_write=True,
decode_response=True):
decode_response=True, strict_decode=False):
TAsyncProtocolBase.__init__(self, trans)
self.strict_read = strict_read
self.strict_write = strict_write
self.decode_response = decode_response
self.strict_decode = strict_decode

async def skip(self, ttype):
await skip(self.trans, ttype)
Expand All @@ -266,23 +272,26 @@ def write_message_end(self):
pass

async def read_struct(self, obj):
return await read_struct(self.trans, obj, self.decode_response)
return await read_struct(self.trans, obj, self.decode_response,
self.strict_decode)

def write_struct(self, obj):
write_val(self.trans, TType.STRUCT, obj)


class TAsyncBinaryProtocolFactory(object):
def __init__(self, strict_read=True, strict_write=True,
decode_response=True):
decode_response=True, strict_decode=False):
self.strict_read = strict_read
self.strict_write = strict_write
self.decode_response = decode_response
self.strict_decode = strict_decode

def get_protocol(self, trans):
return TAsyncBinaryProtocol(
trans,
self.strict_read,
self.strict_write,
self.decode_response
self.decode_response,
self.strict_decode,
)
7 changes: 5 additions & 2 deletions thriftpy2/contrib/aio/protocol/compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ async def _read_string(self):
try:
byte_payload = byte_payload.decode('utf-8')
except UnicodeDecodeError:
pass
if self.strict_decode:
raise
return byte_payload

async def _read_bool(self):
Expand Down Expand Up @@ -305,11 +306,13 @@ async def skip(self, ttype):


class TAsyncCompactProtocolFactory(object):
def __init__(self, decode_response=True):
def __init__(self, decode_response=True, strict_decode=False):
self.decode_response = decode_response
self.strict_decode = strict_decode

def get_protocol(self, trans):
return TAsyncCompactProtocol(
trans,
decode_response=self.decode_response,
strict_decode=self.strict_decode,
)
33 changes: 21 additions & 12 deletions thriftpy2/protocol/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,8 @@ def read_map_begin(inbuf):
return k_type, v_type, sz


def read_val(inbuf, ttype, spec=None, decode_response=True):
def read_val(inbuf, ttype, spec=None, decode_response=True,
strict_decode=False):
if ttype == TType.BOOL:
return bool(unpack_i8(inbuf.read(1)))

Expand Down Expand Up @@ -248,7 +249,8 @@ def read_val(inbuf, ttype, spec=None, decode_response=True):
try:
return byte_payload.decode('utf-8')
except UnicodeDecodeError:
pass
if strict_decode:
raise
return byte_payload

elif ttype == TType.SET or ttype == TType.LIST:
Expand All @@ -267,7 +269,8 @@ def read_val(inbuf, ttype, spec=None, decode_response=True):
return []

for i in range(sz):
result.append(read_val(inbuf, v_type, v_spec, decode_response))
result.append(read_val(inbuf, v_type, v_spec, decode_response,
strict_decode))
return result

elif ttype == TType.MAP:
Expand Down Expand Up @@ -296,19 +299,21 @@ def read_val(inbuf, ttype, spec=None, decode_response=True):
return {}

for i in range(sz):
k_val = read_val(inbuf, k_type, k_spec, decode_response)
v_val = read_val(inbuf, v_type, v_spec, decode_response)
k_val = read_val(inbuf, k_type, k_spec, decode_response,
strict_decode)
v_val = read_val(inbuf, v_type, v_spec, decode_response,
strict_decode)
result[k_val] = v_val

return result

elif ttype == TType.STRUCT:
obj = spec()
read_struct(inbuf, obj, decode_response)
read_struct(inbuf, obj, decode_response, strict_decode)
return obj


def read_struct(inbuf, obj, decode_response=True):
def read_struct(inbuf, obj, decode_response=True, strict_decode=False):
while True:
f_type, fid = read_field_begin(inbuf)
if f_type == TType.STOP:
Expand All @@ -334,7 +339,8 @@ def read_struct(inbuf, obj, decode_response=True):
continue

setattr(obj, f_name,
read_val(inbuf, f_type, f_container_spec, decode_response))
read_val(inbuf, f_type, f_container_spec, decode_response,
strict_decode))


def skip(inbuf, ftype):
Expand Down Expand Up @@ -380,11 +386,12 @@ class TBinaryProtocol(TProtocolBase):

def __init__(self, trans,
strict_read=True, strict_write=True,
decode_response=True):
decode_response=True, strict_decode=False):
TProtocolBase.__init__(self, trans)
self.strict_read = strict_read
self.strict_write = strict_write
self.decode_response = decode_response
self.strict_decode = strict_decode

def skip(self, ttype):
skip(self.trans, ttype)
Expand All @@ -405,20 +412,22 @@ def write_message_end(self):
pass

def read_struct(self, obj):
return read_struct(self.trans, obj, self.decode_response)
return read_struct(self.trans, obj, self.decode_response,
self.strict_decode)

def write_struct(self, obj):
write_val(self.trans, TType.STRUCT, obj)


class TBinaryProtocolFactory(object):
def __init__(self, strict_read=True, strict_write=True,
decode_response=True):
decode_response=True, strict_decode=False):
self.strict_read = strict_read
self.strict_write = strict_write
self.decode_response = decode_response
self.strict_decode = strict_decode

def get_protocol(self, trans):
return TBinaryProtocol(trans,
self.strict_read, self.strict_write,
self.decode_response)
self.decode_response, self.strict_decode)
Loading

0 comments on commit fc26372

Please sign in to comment.