From b1187ba97312c08d8e3ee1301861a3ff2cf44951 Mon Sep 17 00:00:00 2001 From: AN Long Date: Mon, 4 Mar 2024 18:31:26 +0800 Subject: [PATCH 1/2] Support build cython codes with `-Werror=strict-aliasing` --- thriftpy2/protocol/cybin/cybin.pyx | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/thriftpy2/protocol/cybin/cybin.pyx b/thriftpy2/protocol/cybin/cybin.pyx index 8ecf722..97d6ec9 100644 --- a/thriftpy2/protocol/cybin/cybin.pyx +++ b/thriftpy2/protocol/cybin/cybin.pyx @@ -2,6 +2,7 @@ import sys from libc.stdlib cimport free, malloc from libc.stdint cimport int16_t, int32_t, int64_t +from libc.string cimport memcpy from cpython cimport bool import six @@ -99,7 +100,9 @@ cdef inline int write_i64(CyTransportBase buf, int64_t val) except -1: cdef inline int write_double(CyTransportBase buf, double val) except -1: - cdef int64_t v = htobe64((<int64_t*>(&val))[0]) + cdef int64_t v + memcpy(&v, &val, 8) + v = htobe64(v) buf.c_write(<char*>(&v), 8) return 0 @@ -269,6 +272,7 @@ cdef c_read_val(CyTransportBase buf, TType ttype, spec=None, cdef int size cdef int64_t n cdef TType v_type, k_type, orig_type, orig_key_type + cdef double double_value if ttype == T_BOOL: return read_i08(buf) @@ -287,7 +291,8 @@ cdef c_read_val(CyTransportBase buf, TType ttype, spec=None, elif ttype == T_DOUBLE: n = read_i64(buf) - return (<double*>(&n))[0] + memcpy(&double_value, &n, 8) + return double_value elif ttype == T_BINARY: size = read_i32(buf) From 91b28ea3de4bc3d734364d9e902b8a2ddee5fc80 Mon Sep 17 00:00:00 2001 From: Erik Cederstrand Date: Thu, 7 Mar 2024 10:36:44 +0100 Subject: [PATCH 2/2] Add sasl transport support (#222) * Allow building package without pre-installing Cython * Add support for SASL transport * Remove unused code * Fix buffer fetching * Remove redundant condition * Add link to original code * Revert build_ext changes * Create the buffer with a size that matches the data --------- Co-authored-by: Erik Cederstrand --- setup.py | 2 + thriftpy2/transport/__init__.py | 7 +-
thriftpy2/transport/sasl/__init__.py | 203 +++++++++++++++++++++++++ thriftpy2/transport/sasl/cysasl.pyx | 214 +++++++++++++++++++++++++++ 4 files changed, 425 insertions(+), 1 deletion(-) create mode 100644 thriftpy2/transport/sasl/__init__.py create mode 100644 thriftpy2/transport/sasl/cysasl.pyx diff --git a/setup.py b/setup.py index dde1a0b..113017a 100644 --- a/setup.py +++ b/setup.py @@ -62,6 +62,8 @@ ["thriftpy2/transport/memory/cymemory.c"])) ext_modules.append(Extension("thriftpy2.transport.framed.cyframed", ["thriftpy2/transport/framed/cyframed.c"])) + ext_modules.append(Extension("thriftpy2.transport.sasl.cysasl", + ["thriftpy2/transport/sasl/cysasl.c"])) ext_modules.append(Extension("thriftpy2.protocol.cybin", ["thriftpy2/protocol/cybin/cybin.c"])) diff --git a/thriftpy2/transport/__init__.py b/thriftpy2/transport/__init__.py index c36753a..8c6c59a 100644 --- a/thriftpy2/transport/__init__.py +++ b/thriftpy2/transport/__init__.py @@ -11,11 +11,13 @@ from .buffered import TBufferedTransport, TBufferedTransportFactory from .framed import TFramedTransport, TFramedTransportFactory from .memory import TMemoryBuffer +from .sasl import TSaslClientTransport if CYTHON: from .buffered import TCyBufferedTransport, TCyBufferedTransportFactory from .framed import TCyFramedTransport, TCyFramedTransportFactory from .memory import TCyMemoryBuffer + from .sasl import TCySaslClientTransport # enable cython binary by default for CPython. TMemoryBuffer = TCyMemoryBuffer # noqa @@ -23,6 +25,7 @@ TBufferedTransportFactory = TCyBufferedTransportFactory # noqa TFramedTransport = TCyFramedTransport # noqa TFramedTransportFactory = TCyFramedTransportFactory # noqa + TSaslClientTransport = TCySaslClientTransport # noqa else: # disable cython binary protocol for PYPY since it's slower. 
TCyMemoryBuffer = TMemoryBuffer @@ -30,6 +33,7 @@ TCyBufferedTransportFactory = TBufferedTransportFactory TCyFramedTransport = TFramedTransport TCyFramedTransportFactory = TFramedTransportFactory + TCySaslClientTransport = TSaslClientTransport __all__ = [ "TSocket", "TServerSocket", @@ -38,5 +42,6 @@ "TMemoryBuffer", "TFramedTransport", "TFramedTransportFactory", "TBufferedTransport", "TBufferedTransportFactory", "TCyMemoryBuffer", "TCyBufferedTransport", "TCyBufferedTransportFactory", - "TCyFramedTransport", "TCyFramedTransportFactory" + "TCyFramedTransport", "TCyFramedTransportFactory", + "TSaslClientTransport", "TCySaslClientTransport", ] diff --git a/thriftpy2/transport/sasl/__init__.py b/thriftpy2/transport/sasl/__init__.py new file mode 100644 index 0000000..3ec1c5a --- /dev/null +++ b/thriftpy2/transport/sasl/__init__.py @@ -0,0 +1,203 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +""" SASL transports for Thrift. """ + +# Initially copied from +# https://github.com/cloudera/thrift_sasl/blob/master/thrift_sasl/__init__.py + +from __future__ import absolute_import + +import struct +from io import BytesIO + +from ..._compat import CYTHON +from ..base import TTransportBase, readall +from .. 
import TTransportException + + +class TSaslClientTransport(TTransportBase): + START = 1 + OK = 2 + BAD = 3 + ERROR = 4 + COMPLETE = 5 + + def __init__(self, sasl_client_factory, mechanism, trans): + """ + @param sasl_client_factory: a callable that returns a new sasl.Client object + @param mechanism: the SASL mechanism (e.g. "GSSAPI") + @param trans: the underlying transport over which to communicate. + """ + self._trans = trans + self.sasl_client_factory = sasl_client_factory + self.sasl = None + self.mechanism = mechanism + self.__wbuf = BytesIO() + self.__rbuf = BytesIO(b'') + self.encode = None + + def is_open(self): + return self._trans.is_open() + + def open(self): + if not self.is_open(): + self._trans.open() + + if self.sasl is not None: + raise TTransportException( + type=TTransportException.NOT_OPEN, + message="Already open!") + self.sasl = self.sasl_client_factory() + + ret, chosen_mech, initial_response = self.sasl.start(self.mechanism) + if not ret: + raise TTransportException(type=TTransportException.NOT_OPEN, + message=("Could not start SASL: %s" % self.sasl.getError())) + + # Send initial response + self._send_message(self.START, chosen_mech) + self._send_message(self.OK, initial_response) + + # SASL negotiation loop + while True: + status, payload = self._recv_sasl_message() + if status not in (self.OK, self.COMPLETE): + raise TTransportException(type=TTransportException.NOT_OPEN, + message=("Bad status: %d (%s)" % (status, payload))) + if status == self.COMPLETE: + break + ret, response = self.sasl.step(payload) + if not ret: + raise TTransportException(type=TTransportException.NOT_OPEN, + message=("Bad SASL result: %s" % (self.sasl.getError()))) + self._send_message(self.OK, response) + + def _send_message(self, status, body): + header = struct.pack(">BI", status, len(body)) + self._trans.write(header + body) + self._trans.flush() + + def _recv_sasl_message(self): + header = readall(self._trans.read, 5) + status, length = struct.unpack(">BI", 
header) + if length > 0: + payload = readall(self._trans.read, length) + else: + payload = "" + return status, payload + + def write(self, data): + self.__wbuf.write(data) + + def flush(self): + buffer = self.__wbuf.getvalue() + # The first time we flush data, we send it to sasl.encode() + # If the length doesn't change, then we must be using a QOP + # of auth and we should no longer call sasl.encode(), otherwise + # we encode every time. + if self.encode is None: + success, encoded = self.sasl.encode(buffer) + if not success: + raise TTransportException(type=TTransportException.UNKNOWN, + message=self.sasl.getError()) + if (len(encoded) == len(buffer)): + self.encode = False + self._flushPlain(buffer) + else: + self.encode = True + self._trans.write(encoded) + elif self.encode: + self._flushEncoded(buffer) + else: + self._flushPlain(buffer) + + self._trans.flush() + self.__wbuf = BytesIO() + + def _flushEncoded(self, buffer): + # sasl.encode() does the encoding and adds the length header, so nothing + # to do but call it and write the result. + success, encoded = self.sasl.encode(buffer) + if not success: + raise TTransportException(type=TTransportException.UNKNOWN, + message=self.sasl.getError()) + self._trans.write(encoded) + + def _flushPlain(self, buffer): + # When we have QOP of auth, sasl.encode() will pass the input to the output + # but won't put a length header, so we have to do that. + + # Note stolen from TFramedTransport: + # N.B.: Doing this string concatenation is WAY cheaper than making + # two separate calls to the underlying socket object.
Socket writes in + # Python turn out to be REALLY expensive, but it seems to do a pretty + # good job of managing string buffer operations without excessive copies + self._trans.write(struct.pack(">I", len(buffer)) + buffer) + + def c_flush(self): + return self.flush() + + def read(self, sz): + ret = self.__rbuf.read(sz) + if len(ret) == sz: + return ret + + self._read_frame() + return ret + self.__rbuf.read(sz - len(ret)) + + def _read_frame(self): + header = readall(self._trans.read, 4) + (length,) = struct.unpack(">I", header) + if self.encode: + # If the frames are encoded (i.e. you're using a QOP of auth-int or + # auth-conf), then make sure to include the header in the bytes you send to + # sasl.decode() + encoded = header + readall(self._trans.read, length) + success, decoded = self.sasl.decode(encoded) + if not success: + raise TTransportException(type=TTransportException.UNKNOWN, + message=self.sasl.getError()) + else: + # If the frames are not encoded, just pass it through + decoded = readall(self._trans.read, length) + self.__rbuf = BytesIO(decoded) + + def close(self): + self._trans.close() + self.sasl = None + + # XXX: Is this actually needed? + # Implement the CReadableTransport interface. + # Stolen shamelessly from TFramedTransport + @property + def cstringio_buf(self): + return self.__rbuf + + def cstringio_refill(self, prefix, reqlen): + # self.__rbuf will already be empty here because fastbinary doesn't + # ask for a refill until the previous buffer is empty. Therefore, + # we can start reading new frames immediately. 
+ while len(prefix) < reqlen: + self._read_frame() + prefix += self.__rbuf.getvalue() + self.__rbuf = BytesIO(prefix) + return self.__rbuf + + +if CYTHON: + from .cysasl import TCySaslClientTransport # noqa diff --git a/thriftpy2/transport/sasl/cysasl.pyx b/thriftpy2/transport/sasl/cysasl.pyx new file mode 100644 index 0000000..29bfdb4 --- /dev/null +++ b/thriftpy2/transport/sasl/cysasl.pyx @@ -0,0 +1,214 @@ +import struct + +from thriftpy2.transport.cybase cimport ( + TCyBuffer, + CyTransportBase, + DEFAULT_BUFFER +) + +from ..base import readall +from .. import TTransportException + +from libc.string cimport memcpy + +DEF MIN_BUFFER_SIZE = 1024 + +cdef class TCySaslClientTransport(CyTransportBase): + """sasl wrapper""" + + START = 1 + OK = 2 + BAD = 3 + ERROR = 4 + COMPLETE = 5 + + cdef object sasl, sasl_client_factory + cdef TCyBuffer __wbuf, __rbuf + cdef bint opened, encode, encode_decided + cdef str mechanism + + def __init__(self, sasl_client_factory, mechanism, trans): + """ + @param sasl_client_factory: a callable that returns a new sasl.Client object + @param mechanism: the SASL mechanism (e.g. "GSSAPI") + @param trans: the underlying transport over which to communicate. 
+ """ + self.trans = trans + self.sasl_client_factory = sasl_client_factory + self.sasl = None + self.mechanism = mechanism + self.__wbuf = TCyBuffer(DEFAULT_BUFFER) + self.__rbuf = TCyBuffer(DEFAULT_BUFFER) + self.encode_decided = False + self.encode = False + + def is_open(self): + return self.trans.is_open() + + def open(self): + if not self.is_open(): + self.trans.open() + + if self.sasl is not None: + raise TTransportException( + type=TTransportException.NOT_OPEN, + message="Already open!") + self.sasl = self.sasl_client_factory() + + ret, chosen_mech, initial_response = self.sasl.start(self.mechanism) + if not ret: + raise TTransportException(type=TTransportException.NOT_OPEN, + message=("Could not start SASL: %s" % self.sasl.getError())) + + # Send initial response + self._send_message(self.START, chosen_mech) + self._send_message(self.OK, initial_response) + + # SASL negotiation loop + while True: + status, payload = self._recv_sasl_message() + if status not in (self.OK, self.COMPLETE): + raise TTransportException(type=TTransportException.NOT_OPEN, + message=("Bad status: %d (%s)" % (status, payload))) + if status == self.COMPLETE: + break + ret, response = self.sasl.step(payload) + if not ret: + raise TTransportException(type=TTransportException.NOT_OPEN, + message=("Bad SASL result: %s" % (self.sasl.getError()))) + self._send_message(self.OK, response) + + def _send_message(self, status, body): + header = struct.pack(">BI", status, len(body)) + self.trans.write(header + body) + self.trans.flush() + + def _recv_sasl_message(self): + header = readall(self.trans.read, 5) + status, length = struct.unpack(">BI", header) + if length > 0: + payload = readall(self.trans.read, length) + else: + payload = "" + return status, payload + + def write(self, bytes data): + cdef int sz = len(data) + return self.c_write(data, sz) + + cdef c_write(self, const char *data, int sz): + cdef: + int cap = self.__wbuf.buf_size - self.__wbuf.data_size + int r + + if cap < sz: + 
self.c_flush() + + r = self.__wbuf.write(sz, data) + if r == -1: + raise MemoryError("Write to buffer error") + + def flush(self): + return self.c_flush() + + cdef c_flush(self): + cdef bytes data + if self.__wbuf.data_size > 0: + data = self.__wbuf.buf[:self.__wbuf.data_size] + # The first time we flush data, we send it to sasl.encode() + # If the length doesn't change, then we must be using a QOP + # of auth and we should no longer call sasl.encode(), otherwise + # we encode every time. + if not self.encode_decided: + success, encoded = self.sasl.encode(data) + if not success: + raise TTransportException(type=TTransportException.UNKNOWN, + message=self.sasl.getError()) + if (len(encoded)==len(data)): + self.encode = False + self._flushPlain(data) + else: + self.encode = True + self.trans.write(encoded) + self.encode_decided = True + elif self.encode: + self._flushEncoded(data) + else: + self._flushPlain(data) + + self.trans.flush() + self.__wbuf.clean() + return("DUN FLUSHING IN SASL") + + def _flushEncoded(self, buffer): + # sasl.encode() does the encoding and adds the length header, so nothing + # to do but call it and write the result. + success, encoded = self.sasl.encode(buffer) + if not success: + raise TTransportException(type=TTransportException.UNKNOWN, + message=self.sasl.getError()) + self.trans.write(encoded) + + def _flushPlain(self, buffer): + # When we have QOP of auth, sasl.encode() will pass the input to the output + # but won't put a length header, so we have to do that. + + # Note stolen from TFramedTransport: + # N.B.: Doing this string concatenation is WAY cheaper than making + # two separate calls to the underlying socket object.
Socket writes in + # Python turn out to be REALLY expensive, but it seems to do a pretty + # good job of managing string buffer operations without excessive copies + self.trans.write(struct.pack(">I", len(buffer)) + buffer) + + def read(self, sz): + return self.get_string(sz) + + cdef c_read(self, int sz, char* out): + cdef bytes ret + + ret = b"" + + if sz <= 0: + return 0 + + orig_sz = sz + if self.__rbuf.data_size < sz: + # Read what remains, then get more data plz + ret += self.__rbuf.buf[:self.__rbuf.data_size] + sz -= self.__rbuf.data_size + self._read_frame() + + ret += self.__rbuf.buf[self.__rbuf.cur:self.__rbuf.cur + sz] + self.__rbuf.cur += sz + self.__rbuf.data_size -= sz + + memcpy(out, ret, orig_sz) + + def _read_frame(self): + header = readall(self.trans.read, 4) + (length,) = struct.unpack(">I", header) + if self.encode_decided and self.encode: + # If the frames are encoded (i.e. you're using a QOP of auth-int or + # auth-conf), then make sure to include the header in the bytes you send to + # sasl.decode() + encoded = header + readall(self.trans.read, length) + success, decoded = self.sasl.decode(encoded) + if not success: + raise TTransportException(type=TTransportException.UNKNOWN, + message=self.sasl.getError()) + else: + # If the frames are not encoded, just pass it through + decoded = readall(self.trans.read, length) + self.__rbuf = TCyBuffer(len(decoded)+1) # just to be sure make room for an extra byte + memcpy(self.__rbuf.buf, decoded, len(decoded)) + self.__rbuf.data_size = len(decoded) + self.__rbuf.cur = 0 + + def clean(self): + self.__rbuf.clean() + self.__wbuf.clean() + + def close(self): + self.trans.close() + self.sasl = None +