diff --git a/README.rst b/README.rst index 4a0c598a..2ad6a2ff 100644 --- a/README.rst +++ b/README.rst @@ -168,6 +168,10 @@ python lib): * json protocol + * Apache JSON protocol compatible with apache thrift distribution's JSON protocol. + Simply do ``from thriftpy2.protocol import TApacheJSONProtocolFactory`` and pass + this to the ``proto_factory`` argument where appropriate. + * buffered transport (python & cython) * framed transport diff --git a/setup.py b/setup.py index db5148f3..11f77522 100644 --- a/setup.py +++ b/setup.py @@ -15,6 +15,7 @@ install_requires = [ "ply>=3.4,<4.0", + "six~=1.15", ] tornado_requires = [ @@ -35,6 +36,7 @@ "pytest>=2.8", "sphinx-rtd-theme>=0.1.9", "sphinx>=1.3", + "pytest>=6.1.1", ] + tornado_requires diff --git a/tests/apache_json_test.thrift b/tests/apache_json_test.thrift new file mode 100644 index 00000000..83008e0c --- /dev/null +++ b/tests/apache_json_test.thrift @@ -0,0 +1,41 @@ +exception TestException { + 1: string message +} +struct Foo { + 1: string bar +} +struct Test { + 1: bool tbool, + 2: i8 tbyte, + 3: i16 tshort, + 4: i32 tint, + 5: i64 tlong, + 6: double tdouble, + 7: string tstr, + 8: list tlist_of_strings, + 9: map tmap_of_int2str, + 10: set tsetofints, + 11: map tmap_of_str2foo, + 12: map> tmap_of_str2stringlist, + 13: map> tmap_of_str2mapofstring2foo, + 14: list tlist_of_foo, + 15: Foo tfoo, + 16: list> tlist_of_maps2int, + 17: map> tmap_of_str2foolist, + 18: map tmap_of_int2foo, + 19: binary tbinary, + 20: optional map tmap_of_bool2str + 21: optional map tmap_of_bool2int, + 22: list tlist_of_binary, + 23: set tset_of_binary, + 24: map tbin2bin, + +} + +service TestService { + // Testing Service that just returns what you give it + Test test(1: Test test); + void do_error(1: string arg) throws ( + 1: TestException e + ) +} diff --git a/tests/bin_test.thrift b/tests/bin_test.thrift new file mode 100644 index 00000000..be5df50d --- /dev/null +++ b/tests/bin_test.thrift @@ -0,0 +1,15 @@ +struct BinTest { + 1: binary tbinary, + 2: map str2bin, + 3: map bin2bin, + 4: map bin2str, + 5: list binlist, + 6: set binset, + 7: map> map_of_str2binlist, + 8: map> map_of_bin2bin, + 9: optional list> list_of_bin2str +} +service BinService { + // Testing Service that just returns what you give it + BinTest test(1: BinTest test); +} diff --git a/tests/test_all_protocols_binary_field.py b/tests/test_all_protocols_binary_field.py new file mode 100644 index 00000000..03c604b0 --- /dev/null +++ b/tests/test_all_protocols_binary_field.py @@ -0,0 +1,346 @@ +from __future__ import absolute_import + +import time +import traceback +from multiprocessing import Process + +import pytest +import six + +from thriftpy2.thrift import TType, TPayloadMeta +try: + from thriftpy2.protocol import cybin +except ImportError: + cybin = None +import thriftpy2 +from thriftpy2.http import ( + make_server as make_http_server, + make_client as make_http_client, +) +from thriftpy2.protocol import ( + TApacheJSONProtocolFactory, + TJSONProtocolFactory, + TCompactProtocolFactory, +) +from thriftpy2.protocol import TBinaryProtocolFactory +from thriftpy2.rpc import make_server as make_rpc_server, \ + make_client as make_rpc_client +from thriftpy2.transport import TBufferedTransportFactory, TCyMemoryBuffer + +protocols = [TApacheJSONProtocolFactory, + TJSONProtocolFactory, + TBinaryProtocolFactory, + TCompactProtocolFactory] + + +def recursive_vars(obj): + if isinstance(obj, six.string_types): + return six.ensure_str(obj) + if isinstance(obj, six.binary_type): + return six.ensure_binary(obj) + if isinstance(obj, (int, float, bool)): + return obj + if isinstance(obj, dict): + return {k: recursive_vars(v) for k, v in obj.items()} + if isinstance(obj, (list, set)): + return [recursive_vars(v) for v in obj] + if hasattr(obj, '__dict__'): + return recursive_vars(vars(obj)) + + +@pytest.mark.parametrize('server_func', + [(make_rpc_server, make_rpc_client), + (make_http_server, make_http_client)]) +@pytest.mark.parametrize('tlist', [[], ['a', 'b', 'c']]) +@pytest.mark.parametrize('binary', [b'', b'\x01\x03test binary\x03\xff']) +@pytest.mark.parametrize('proto_factory', protocols) +def test_protocols(proto_factory, binary, tlist, server_func): + test_thrift = thriftpy2.load( + "apache_json_test.thrift", + module_name="test_thrift" + ) + Foo = test_thrift.Foo + + class Handler(object): + @staticmethod + def test(t): + return t + + trans_factory = TBufferedTransportFactory + + def run_server(): + server = server_func[0]( + test_thrift.TestService, + handler=Handler(), + host='localhost', + port=9090, + proto_factory=proto_factory(), + trans_factory=trans_factory(), + ) + server.serve() + + proc = Process(target=run_server) + proc.start() + time.sleep(0.2) + err = None + try: + test_object = test_thrift.Test( + tdouble=12.3456, + tint=567, + tstr='A test \'{["string', + tbinary=binary, + tlist_of_strings=tlist, + tbool=False, + tbyte=16, + tlong=123123123, + tshort=123, + tsetofints={1, 2, 3, 4, 5}, + tmap_of_int2str={ + 1: "one", + 2: "two", + 3: "three" + }, + tmap_of_str2foo={'first': Foo("first"), "2nd": Foo("baz")}, + tmap_of_str2foolist={ + 'test': [Foo("test list entry")] + }, + tmap_of_str2mapofstring2foo={ + "first": { + "second": Foo("testing") + } + }, + tmap_of_str2stringlist={ + "words": ["dog", "cat", "pie"], + "other": ["test", "foo", "bar", "baz", "quux"] + }, + tfoo=Foo("test food"), + tlist_of_foo=[Foo("1"), Foo("2"), Foo("3")], + tlist_of_maps2int=[ + {"one": 1, "two": 2, "three": 3} + ], + tmap_of_int2foo={ + 1: Foo("One"), + 2: Foo("Two"), + 5: Foo("Five") + }, + tbin2bin={b'Binary': b'data'}, + tset_of_binary={b'bin one', b'bin two'}, + tlist_of_binary=[b'foo roo', b'baz boo'], + ) + + client = server_func[1]( + test_thrift.TestService, + host='localhost', + port=9090, + proto_factory=proto_factory(), + trans_factory=trans_factory(), + ) + res = client.test(test_object) + assert recursive_vars(res) == recursive_vars(test_object) + except Exception as e: + traceback.print_exc() + err = e + finally: + proc.terminate() + if err: + raise err + time.sleep(0.1) + + +@pytest.mark.parametrize('server_func', + [(make_rpc_server, make_rpc_client), + (make_http_server, make_http_client)]) +@pytest.mark.parametrize('proto_factory', protocols) +def test_exceptions(server_func, proto_factory): + test_thrift = thriftpy2.load( + "apache_json_test.thrift", + module_name="test_thrift" + ) + TestException = test_thrift.TestException + + class Handler(object): + def do_error(self, arg): + raise TestException(message=arg) + + def do_server(): + server = server_func[0]( + service=test_thrift.TestService, + handler=Handler(), + host='localhost', + port=9090, + proto_factory=proto_factory() + ) + server.serve() + + proc = Process(target=do_server) + proc.start() + time.sleep(0.25) + msg = "exception raised!" + with pytest.raises(TestException)as e: + client = server_func[1]( + test_thrift.TestService, + host='localhost', + port=9090, + proto_factory=proto_factory() + ) + client.do_error(msg) + assert e.value.message == msg + + proc.terminate() + time.sleep(1) + + +@pytest.mark.parametrize('proto_factory', protocols) +def test_complex_binary(proto_factory): + + spec = thriftpy2.load("bin_test.thrift", module_name="bin_thrift") + bin_test_obj = spec.BinTest( + tbinary=b'\x01\x0f\xffa binary string\x0f\xee', + str2bin={ + 'key': 'value', + 'foo': 'bar' + }, + bin2bin={ + b'bin_key': b'bin_val', + 'str2bytes': b'bin bar' + }, + bin2str={ + b'bin key': 'str val', + }, + binlist=[b'bin one', b'bin two', 'str should become bin'], + binset={b'val 1', b'foo', b'bar', b'baz'}, + map_of_str2binlist={ + 'key1': [b'bin 1', b'pop 2'] + }, + map_of_bin2bin={ + b'abc': { + b'def': b'val', + b'\x1a\x04': b'\x45' + } + }, + list_of_bin2str=[ + { + b'bin key': 'str val', + b'other key\x04': 'bob' + } + ] + ) + + class Handler(object): + @staticmethod + def test(t): + return t + + trans_factory = TBufferedTransportFactory + + def run_server(): + server = make_rpc_server( + spec.BinService, + handler=Handler(), + host='localhost', + port=9090, + proto_factory=proto_factory(), + trans_factory=trans_factory(), + ) + server.serve() + + proc = Process(target=run_server) + proc.start() + time.sleep(0.2) + + try: + client = make_rpc_client( + spec.BinService, + host='localhost', + port=9090, + proto_factory=proto_factory(), + trans_factory=trans_factory(), + ) + res = client.test(bin_test_obj) + check_types(spec.BinTest.thrift_spec, res) + finally: + proc.terminate() + time.sleep(0.2) + + +@pytest.mark.skipif(cybin is None, reason="Must be run in cpython") +def test_complex_map(): + """ + Test from #156 + """ + proto = cybin + b1 = TCyMemoryBuffer() + proto.write_val(b1, TType.MAP, {"hello": "1"}, + spec=(TType.STRING, TType.STRING)) + b1.flush() + + b2 = TCyMemoryBuffer() + proto.write_val(b2, TType.MAP, {"hello": b"1"}, + spec=(TType.STRING, TType.BINARY)) + b2.flush() + + assert b1.getvalue() != b2.getvalue() + + +type_map = { + TType.BYTE: (int,), + TType.I16: (int,), + TType.I32: (int,), + TType.I64: (int,), + TType.DOUBLE: (float,), + TType.STRING: six.string_types, + TType.BOOL: (bool,), + TType.STRUCT: TPayloadMeta, + TType.SET: (set, list), + TType.LIST: (list,), + TType.MAP: (dict,), + TType.BINARY: six.binary_type +} + +type_names = { + TType.BYTE: "Byte", + TType.I16: "I16", + TType.I32: "I32", + TType.I64: "I64", + TType.DOUBLE: "Double", + TType.STRING: "String", + TType.BOOL: "Bool", + TType.STRUCT: "Struct", + TType.SET: "Set", + TType.LIST: "List", + TType.MAP: "Map", + TType.BINARY: "Binary" +} + + +def check_types(spec, val): + """ + This function should check if a given thrift object matches + a thrift spec + Nb. This function isn't complete + + """ + if isinstance(spec, int): + assert isinstance(val, type_map.get(spec)) + elif isinstance(spec, tuple): + if len(spec) >= 2: + if spec[0] in (TType.LIST, TType.SET): + for item in val: + check_types(spec[1], item) + else: + for i in spec.values(): + t, field_name, to_type = i[:3] + value = getattr(val, field_name) + assert isinstance(value, type_map.get(t)), \ + "Field {} expected {} got {}".format( + field_name, type_names.get(t), type(value)) + if to_type: + if t in (TType.SET, TType.LIST): + for _val in value: + check_types(to_type, _val) + elif t == TType.MAP: + for _key, _val in value.items(): + check_types(to_type[0], _key) + check_types(to_type[1], _val) + elif t == TType.STRUCT: + check_types(to_type, value) diff --git a/tests/test_apache_json.py b/tests/test_apache_json.py new file mode 100644 index 00000000..1a2b8d94 --- /dev/null +++ b/tests/test_apache_json.py @@ -0,0 +1,177 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import + +import json +import time +from multiprocessing import Process + +import pytest +import six + +import thriftpy2 +from thriftpy2.http import make_server as make_http_server, \ + make_client as make_http_client +from thriftpy2.protocol import TApacheJSONProtocolFactory +from thriftpy2.rpc import make_server as make_rpc_server, \ + make_client as make_rpc_client +from thriftpy2.thrift import TProcessor, TType +from thriftpy2.transport import TMemoryBuffer +from thriftpy2.transport.buffered import TBufferedTransportFactory + + +def recursive_vars(obj): + if isinstance(obj, six.string_types): + return six.ensure_str(obj) + if isinstance(obj, six.binary_type): + return six.ensure_binary(obj) + if isinstance(obj, (int, float, bool)): + return obj + if isinstance(obj, dict): + return {k: recursive_vars(v) for k, v in obj.items()} + if isinstance(obj, (list, set)): + return [recursive_vars(v) for v in obj] + if hasattr(obj, '__dict__'): + return recursive_vars(vars(obj)) + + +def test_thrift_transport(): + test_thrift = thriftpy2.load( + "apache_json_test.thrift", + module_name="test_thrift" + ) + Test = test_thrift.Test + Foo = test_thrift.Foo + test_object = Test( + tbool=False, + tbyte=16, + tdouble=1.234567, + tlong=123123123, + tshort=123, + tint=12345678, + tstr="Testing String", + tsetofints={1, 2, 3, 4, 5}, + tmap_of_int2str={ + 1: "one", + 2: "two", + 3: "three" + }, + tlist_of_strings=["how", "do", "i", "test", "this?"], + tmap_of_str2foo={'first': Foo("first"), "2nd": Foo("baz")}, + tmap_of_str2foolist={ + 'test': [Foo("test list entry")] + }, + tmap_of_str2mapofstring2foo={ + "first": { + "second": Foo("testing") + } + }, + tmap_of_str2stringlist={ + "words": ["dog", "cat", "pie"], + "other": ["test", "foo", "bar", "baz", "quux"] + }, + tfoo=Foo("test food"), + tlist_of_foo=[Foo("1"), Foo("2"), Foo("3")], + tlist_of_maps2int=[ + {"one": 1, "two": 2, "three": 3} + ], + tmap_of_int2foo={ + 1: Foo("One"), + 2: Foo("Two"), + 5: Foo("Five") + }, + tbinary=b"\x01\x0fabc123\x00\x02" + ) + # A request generated by apache thrift that matches the above object + request_data = b"""[1,"test",1,0,{"1":{"rec":{"1":{"tf":0},"2":{"i8":16}, + "3":{"i16":123},"4":{"i32":12345678},"5":{"i64":123123123},"6": + {"dbl":1.234567},"7":{"str":"Testing String"},"8":{"lst":["str",5, + "how","do","i","test","this?"]},"9":{"map":["i32","str",3,{"1":"one", + "2":"two","3":"three"}]},"10":{"set":["i32",5,1,2,3,4,5]}, + "11":{"map":["str","rec",2,{"first":{"1":{"str":"first"}},"2nd": + {"1":{"str":"baz"}}}]},"12":{"map":["str","lst", + 2,{"words":["str",3,"dog","cat","pie"],"other":["str",5,"test", + "foo","bar","baz","quux"]}]},"13":{"map":["str", + "map",1,{"first":["str","rec",1,{"second":{"1":{"str":"testing"}}}]}]}, + "14":{"lst":["rec",3,{"1":{"str":"1"}}, + {"1":{"str":"2"}},{"1":{"str":"3"}}]},"15":{"rec":{"1":{ + "str":"test food"}}},"16":{"lst":["map",1,["str","i32", + 3,{"one":1,"two":2,"three":3}]]},"17":{"map":["str","lst",1,{"test": + ["rec",1,{"1":{"str":"test list entry"}}]}]}, + "18":{"map":["i32","rec",3,{"1":{"1":{"str":"One"}},"2":{"1": + {"str":"Two"}},"5":{"1":{"str":"Five"}}}]}, + "19":{"str":"AQ9hYmMxMjMAAg=="}}}}]""" + + class Handler: + @staticmethod + def test(t): + # t should match the object above + expected_a = recursive_vars(t) + expected_b = recursive_vars(test_object) + + if TType.STRING != TType.BINARY: + assert expected_a == expected_b + return t + + tp2_thrift_processor = TProcessor(test_thrift.TestService, Handler()) + tp2_factory = TApacheJSONProtocolFactory() + iprot = tp2_factory.get_protocol(TMemoryBuffer(request_data)) + obuf = TMemoryBuffer() + oprot = tp2_factory.get_protocol(obuf) + + tp2_thrift_processor.process(iprot, oprot) + + # output buffers should be the same + final_data = obuf.getvalue() + assert json.loads(request_data.decode('utf8'))[4]['1'] == \ + json.loads(final_data.decode('utf8'))[4]['0'] + + +@pytest.mark.parametrize('server_func', [(make_rpc_server, make_rpc_client), + (make_http_server, make_http_client)]) +def test_client(server_func): + test_thrift = thriftpy2.load( + "apache_json_test.thrift", + module_name="test_thrift" + ) + + class Handler: + @staticmethod + def test(t): + return t + + def run_server(): + server = make_http_server( + test_thrift.TestService, + handler=Handler(), + host='localhost', + port=9090, + proto_factory=TApacheJSONProtocolFactory(), + trans_factory=TBufferedTransportFactory() + ) + server.serve() + + proc = Process(target=run_server, ) + proc.start() + time.sleep(0.25) + + try: + test_object = test_thrift.Test( + tdouble=12.3456, + tint=567, + tstr='A test \'{["string', + tmap_of_bool2str={True: "true string", False: "false string"}, + tmap_of_bool2int={True: 0, False: 1} + ) + + client = make_http_client( + test_thrift.TestService, + host='localhost', + port=9090, + proto_factory=TApacheJSONProtocolFactory(), + trans_factory=TBufferedTransportFactory() + ) + res = client.test(test_object) + assert recursive_vars(res) == recursive_vars(test_object) + finally: + proc.terminate() + time.sleep(1) diff --git a/thriftpy2/contrib/aio/protocol/binary.py b/thriftpy2/contrib/aio/protocol/binary.py index fa5e3a7e..dea51c53 100644 --- a/thriftpy2/contrib/aio/protocol/binary.py +++ b/thriftpy2/contrib/aio/protocol/binary.py @@ -22,6 +22,7 @@ from .base import TAsyncProtocolBase +BIN_TYPES = (TType.STRING, TType.BINARY) @asyncio.coroutine def read_message_begin(inbuf, strict=True): @@ -94,6 +95,10 @@ def read_val(inbuf, ttype, spec=None, decode_response=True): elif ttype == TType.DOUBLE: return unpack_double((yield from inbuf.read(8))) + elif ttype == TType.BINARY: + sz = unpack_i32((yield from inbuf.read(4))) + return inbuf.read(sz) + elif ttype == TType.STRING: sz = unpack_i32((yield from inbuf.read(4))) byte_payload = yield from inbuf.read(sz) @@ -116,7 +121,7 @@ def read_val(inbuf, ttype, spec=None, decode_response=True): result = [] r_type, sz = yield from read_list_begin(inbuf) # the v_type is useless here since we already get it from spec - if r_type != v_type: + if r_type != v_type and not (r_type in BIN_TYPES and v_type in BIN_TYPES): for _ in range(sz): yield from skip(inbuf, r_type) return [] @@ -144,6 +149,10 @@ def read_val(inbuf, ttype, spec=None, decode_response=True): result = {} sk_type, sv_type, sz = yield from read_map_begin(inbuf) + if sk_type in BIN_TYPES: + sk_type = k_type + if sv_type in BIN_TYPES: + sv_type = v_type if sk_type != k_type or sv_type != v_type: for _ in range(sz): yield from skip(inbuf, sk_type) @@ -183,8 +192,11 @@ def read_struct(inbuf, obj, decode_response=True): # it really should equal here. but since we already wasted # space storing the duplicate info, let's check it. if f_type != sf_type: - yield from skip(inbuf, f_type) - continue + if f_type in BIN_TYPES: + f_type = sf_type + else: + yield from skip(inbuf, f_type) + continue _buf = yield from read_val( inbuf, f_type, f_container_spec, decode_response) @@ -208,7 +220,7 @@ def skip(inbuf, ftype): elif ftype == TType.DOUBLE: yield from inbuf.read(8) - elif ftype == TType.STRING: + elif ftype in BIN_TYPES: _size = yield from inbuf.read(4) yield from inbuf.read(unpack_i32(_size)) diff --git a/thriftpy2/contrib/aio/protocol/compact.py b/thriftpy2/contrib/aio/protocol/compact.py index b0a5744e..c931f450 100644 --- a/thriftpy2/contrib/aio/protocol/compact.py +++ b/thriftpy2/contrib/aio/protocol/compact.py @@ -15,6 +15,7 @@ from .base import TAsyncProtocolBase +BIN_TYPES = (TType.STRING, TType.BINARY) @asyncio.coroutine def read_varint(trans): @@ -143,6 +144,11 @@ def _read_double(self): val, = unpack(' 2 and isinstance(val, str): + return val.encode() + return val + + +class TApacheJSONProtocolFactory(object): + @staticmethod + def get_protocol(trans): + return TApacheJSONProtocol(trans) + + +class TApacheJSONProtocol(TProtocolBase): + """ + Protocol that implements the Apache JSON Protocol + """ + + def __init__(self, trans): + TProtocolBase.__init__(self, trans) + self._req = None + + def _load_data(self): + data = b"" + l_braces = 0 + in_string = False + while True: + # read(sz) will wait until it has read exactly sz bytes, + # so we must read until we get a balanced json list in absence of knowing + # how long the json string will be + if hasattr(self.trans, 'getvalue'): + try: + data = self.trans.getvalue() + break + except Exception: + pass + new_data = self.trans.read(1) + data += new_data + if new_data == b'"' and not data.endswith(b'\\"'): + in_string = not in_string + if not in_string: + if new_data == b"[": + l_braces += 1 + elif new_data == b"]": + l_braces -= 1 + if l_braces == 0: + break + if data: + self._req = json.loads(data.decode('utf8')) + else: + self._req = None + + def read_message_begin(self): + if not self._req: + self._load_data() + return self._req[1:4] + + def read_message_end(self): + pass + + def skip(self, ttype): + pass + + def write_message_end(self): + pass + + def write_message_begin(self, name, ttype, seqid): + self.api = name + self.ttype = ttype + self.seqid = seqid + + def write_struct(self, obj): + """ + Write json to self.trans following apache style jsonification of `obj` + + :param obj: A thriftpy2 object + :return: + """ + doc = [VERSION, self.api, self.ttype, self.seqid, self._thrift_to_dict(obj)] + json_str = json.dumps(doc, separators=(',', ':')) + self.trans.write(json_str.encode("utf8")) + + def _thrift_to_dict(self, thrift_obj, item_type=None): + """ + Convert a thriftpy2 into an apache conformant dict, eg: + + >>> {0: {'rec': {1: {'str': "304"}, 14: {'rec': {1: {'lst': ["rec", 0]}}}}}} + + >>> {"0":{"rec":{"1":{"str":"284"},"14":{"rec":{"1":{"lst": + >>> ["rec",2,{"1":{"i32":12345.0},"2":{"i32":2.0},"3":{"str":"Testing notifications"},"4":{"tf":1}}, + {"1":{"i32":567809.0},"2":{"i32":2.0},"3":{"str":"Other test"},"4":{"tf":0}}]}}}}}} + + :param thrift_obj: the thing we want to make into a dict + :param item_type: the type of the item we are to convert + :return: + """ + if not hasattr(thrift_obj, 'thrift_spec'): + # use item_type to render it + if item_type is not None: + if isinstance(item_type, tuple) and len(item_type) > 1: + to_type = item_type[1] + flat_key_val = [TType.STRUCT if hasattr(t, 'thrift_spec') else t for t in flatten(to_type)] + if flat_key_val[0] == TType.LIST or isinstance(thrift_obj, list): + return [CTYPES[flat_key_val[1]], len(thrift_obj)] + [self._thrift_to_dict(v, to_type[1]) for v + in thrift_obj] + elif flat_key_val[0] == TType.MAP or isinstance(thrift_obj, dict): + if to_type[0] == TType.MAP: + key_type = flat_key_val[1] + val_type = flat_key_val[2] + else: + key_type = flat_key_val[0] + val_type = flat_key_val[1] + return [CTYPES[key_type], CTYPES[val_type], len(thrift_obj), { + self._thrift_to_dict(k, key_type): + self._thrift_to_dict(v, to_type[1]) for k, v in thrift_obj.items() + }] + if (to_type == TType.BINARY or item_type[-1] == TType.BINARY) and TType.BINARY != TType.STRING: + return base64.b64encode(_ensure_b64_encode(thrift_obj)).decode('ascii') + if isinstance(thrift_obj, bool): + return int(thrift_obj) + if ( + item_type == TType.BINARY + or (isinstance(item_type, tuple) and item_type[0] == TType.BINARY) + ) and TType.BINARY != TType.STRING: + return base64.b64encode(_ensure_b64_encode(thrift_obj)).decode("ascii") + return thrift_obj + result = {} + for field_idx, thrift_spec in thrift_obj.thrift_spec.items(): + ttype, field_name, spec = thrift_spec[:3] + if isinstance(spec, int): + spec = (spec,) + val = getattr(thrift_obj, field_name) + if val is not None: + if ttype == TType.STRUCT: + result[field_idx] = { + CTYPES[ttype]: self._thrift_to_dict(val) + } + elif ttype in [TType.LIST, TType.SET]: + # format is [list_item_type, length, items] + result[field_idx] = { + CTYPES[ttype]: [CTYPES[spec[0]], len(val)] + [self._thrift_to_dict(v, spec) for v in val] + } + elif ttype == TType.MAP: + key_type = CTYPES[spec[0]] + val_type = CTYPES[spec[1][0] if isinstance(spec[1], tuple) else spec[1]] + # format is [key_type, value_type, length, dict] + result[field_idx] = { + CTYPES[ttype]: [key_type, val_type, len(val), + {self._thrift_to_dict(k, spec[0]): + self._thrift_to_dict(v, spec) for k, v in val.items()}] + } + elif ttype == TType.BINARY and TType.BINARY != TType.STRING: + result[field_idx] = { + CTYPES[ttype]: base64.b64encode(_ensure_b64_encode(val)).decode('ascii') + } + elif ttype == TType.BOOL: + result[field_idx] = { + CTYPES[ttype]: int(val) + } + else: + result[field_idx] = { + CTYPES[ttype]: val + } + return result + + def _dict_to_thrift(self, data, base_type): + """ + Convert an apache thrift dict (where key is the type, value is the data) + + :param data: the dict data + :param base_type: the type we are going to convert data to + :return: + """ + # if the result is a python type, return it: + if isinstance(data, (str, int, float, bool, six.string_types, six.binary_type)) or data is None: + if base_type in (TType.I08, TType.I16, TType.I32, TType.I64): + return int(data) + if base_type == TType.BINARY and TType.BINARY != TType.STRING: + return base64.b64decode(data) + if base_type == TType.BOOL: + return { + 'true': True, + 'false': False, + '1': True, + '0': False + }[data.lower()] + if isinstance(data, bool): + return int(data) + return data + + if isinstance(base_type, tuple): + container_type = base_type[0] + item_type = base_type[1] + if container_type == TType.STRUCT: + return self._dict_to_thrift(data, item_type) + elif container_type in (TType.LIST, TType.SET): + return [self._dict_to_thrift(v, item_type) for v in data[2:]] + elif container_type == TType.MAP: + return { + self._dict_to_thrift(k, item_type[0]): + self._dict_to_thrift(v, item_type[1]) for k, v in data[3].items() + } + result = {} + base_spec = base_type.thrift_spec + for field_idx, val in data.items(): + thrift_spec = base_spec[int(field_idx)] + # spec has field type, field name, (sub spec), False + field_name = thrift_spec[1] + for ftype, value in val.items(): + ttype = JTYPES[ftype] + if thrift_spec[0] == TType.BINARY and TType.BINARY != TType.STRING: + bin_data = val.get('str', '') + m = len(bin_data) % 4 + if m != 0: + bin_data += '=' * (4-m) + result[field_name] = base64.b64decode(bin_data) + + elif ttype == TType.STRUCT: + result[field_name] = self._dict_to_thrift(value, thrift_spec[2]) + elif ttype in (TType.LIST, TType.SET): + result[field_name] = [self._dict_to_thrift(v, thrift_spec[2]) for v in value[2:]] + elif ttype == TType.MAP: + key_spec = thrift_spec[2][0] + val_spec = thrift_spec[2][1] + result[field_name] = { + self._dict_to_thrift(k, key_spec): self._dict_to_thrift(v, val_spec) + for k, v in value[3].items() + } + else: + result[field_name] = { + 'tf': bool, + 'i8': int, + 'i16': int, + 'i32': int, + 'i64': int, + 'dbl': float, + 'str': str, + }[ftype](value) + if hasattr(base_type, '__call__'): + return base_type(**result) + else: + for k, v in result.items(): + setattr(base_type, k, v) + return base_type + + def read_struct(self, obj): + """ + Read the next struct into obj, usually the argument from an incoming request + Only really used to read the arguments off a request into whatever we want + see thriftpy2.thrift.TProcessor.process_in for how this class will be used + + Will turn the contents of self.req[4] into the args of obj, + ie. self.req[4]["1"] must be rendered into obj.thrift_spec + + :param obj: + :return: + """ + return self._dict_to_thrift(self._req[4], obj) diff --git a/thriftpy2/protocol/binary.py b/thriftpy2/protocol/binary.py index 8bcc528c..2d399df8 100644 --- a/thriftpy2/protocol/binary.py +++ b/thriftpy2/protocol/binary.py @@ -14,6 +14,7 @@ # VERSION_1 = 0x80010000 VERSION_1 = -2147418112 TYPE_MASK = 0x000000ff +BIN_TYPES = (TType.STRING, TType.BINARY) def pack_i8(byte): @@ -72,6 +73,8 @@ def write_message_begin(outbuf, name, ttype, seqid, strict=True): def write_field_begin(outbuf, ttype, fid): + if ttype == TType.BINARY: + ttype = TType.STRING outbuf.write(pack_i8(ttype) + pack_i16(fid)) @@ -109,7 +112,7 @@ def write_val(outbuf, ttype, val, spec=None): elif ttype == TType.DOUBLE: outbuf.write(pack_double(val)) - elif ttype == TType.STRING: + elif ttype in BIN_TYPES: if not isinstance(val, bytes): val = val.encode('utf-8') outbuf.write(pack_string(val)) @@ -225,6 +228,10 @@ def read_val(inbuf, ttype, spec=None, decode_response=True): elif ttype == TType.DOUBLE: return unpack_double(inbuf.read(8)) + elif ttype == TType.BINARY: + sz = unpack_i32(inbuf.read(4)) + return inbuf.read(sz) + elif ttype == TType.STRING: sz = unpack_i32(inbuf.read(4)) byte_payload = inbuf.read(sz) @@ -247,7 +254,7 @@ def read_val(inbuf, ttype, spec=None, decode_response=True): result = [] r_type, sz = read_list_begin(inbuf) # the v_type is useless here since we already get it from spec - if r_type != v_type: + if r_type != v_type and not (r_type in BIN_TYPES and v_type in BIN_TYPES): for _ in range(sz): skip(inbuf, r_type) return [] @@ -271,6 +278,10 @@ def read_val(inbuf, ttype, spec=None, decode_response=True): result = {} sk_type, sv_type, sz = read_map_begin(inbuf) + if sk_type in BIN_TYPES: + sk_type = k_type + if sv_type in BIN_TYPES: + sv_type = v_type if sk_type != k_type or sv_type != v_type: for _ in range(sz): skip(inbuf, sk_type) @@ -309,8 +320,11 @@ def read_struct(inbuf, obj, decode_response=True): # it really should equal here. but since we already wasted # space storing the duplicate info, let's check it. if f_type != sf_type: - skip(inbuf, f_type) - continue + if f_type in BIN_TYPES: + f_type = sf_type + else: + skip(inbuf, f_type) + continue setattr(obj, f_name, read_val(inbuf, f_type, f_container_spec, decode_response)) @@ -332,7 +346,7 @@ def skip(inbuf, ftype): elif ftype == TType.DOUBLE: inbuf.read(8) - elif ftype == TType.STRING: + elif ftype in BIN_TYPES: inbuf.read(unpack_i32(inbuf.read(4))) elif ftype == TType.SET or ftype == TType.LIST: diff --git a/thriftpy2/protocol/compact.py b/thriftpy2/protocol/compact.py index 34a89f18..7c567b3f 100644 --- a/thriftpy2/protocol/compact.py +++ b/thriftpy2/protocol/compact.py @@ -3,8 +3,11 @@ from __future__ import absolute_import import array +import sys from struct import pack, unpack +import six + from .exc import TProtocolException from .base import TProtocolBase from ..thrift import TException @@ -22,6 +25,8 @@ VALUE_READ = 7 BOOL_READ = 8 +BIN_TYPES = (TType.STRING, TType.BINARY) + def check_integer_limits(i, bits): if bits == 8 and (i < -128 or i > 127): @@ -108,7 +113,8 @@ class CompactType(object): TType.STRUCT: CompactType.STRUCT, TType.LIST: CompactType.LIST, TType.SET: CompactType.SET, - TType.MAP: CompactType.MAP + TType.MAP: CompactType.MAP, + TType.BINARY: CompactType.BINARY, } TTYPES = dict((v, k) for k, v in CTYPES.items()) TTYPES[CompactType.FALSE] = TType.BOOL @@ -227,6 +233,10 @@ def _read_double(self): val, = unpack(' 2: + b = b.encode() + self.trans.write(b) + def _write_string(self, s): if not isinstance(s, bytes): s = s.encode('utf-8') @@ -467,6 +493,9 @@ def _write_val(self, ttype, val, spec=None): elif ttype == TType.DOUBLE: self._write_double(val) + elif ttype == TType.BINARY: + self._write_binary(val) + elif ttype == TType.STRING: self._write_string(val) @@ -520,6 +549,9 @@ def skip(self, ttype): elif ttype == TType.DOUBLE: self._read_double() + elif ttype == TType.BINARY: + self._read_binary() + elif ttype == TType.STRING: self._read_string() diff --git a/thriftpy2/protocol/cybin/cybin.pyx b/thriftpy2/protocol/cybin/cybin.pyx index d8f59351..6927cb68 100644 --- a/thriftpy2/protocol/cybin/cybin.pyx +++ b/thriftpy2/protocol/cybin/cybin.pyx @@ -1,7 +1,11 @@ +import sys + from libc.stdlib cimport free, malloc from libc.stdint cimport int16_t, int32_t, int64_t from cpython cimport bool +import six + from thriftpy2.transport.cybase cimport CyTransportBase, STACK_STRING_LEN from ..thrift import TDecodeException @@ -37,7 +41,10 @@ ctypedef enum TType: T_SET = 14, T_LIST = 15, T_UTF8 = 16, - T_UTF16 = 17 + T_UTF16 = 17, + T_BINARY = 18 + +BIN_TYPES = (T_BINARY, T_STRING) class ProtocolError(Exception): pass @@ -172,7 +179,7 @@ cdef inline read_struct(CyTransportBase buf, obj, decode_response=True): field_spec = field_specs[fid] ttype = field_spec[0] - if field_type != ttype: + if field_type != ttype and not (ttype in BIN_TYPES and field_type in BIN_TYPES): skip(buf, field_type) continue @@ -205,12 +212,14 @@ cdef inline write_struct(CyTransportBase buf, obj): v = getattr(obj, f_name, None) if v is None: continue - - write_i08(buf, f_type) + if f_type == T_BINARY: + write_i08(buf, T_STRING) + else: + write_i08(buf, f_type) write_i16(buf, fid) try: c_write_val(buf, f_type, v, container_spec) - except (TypeError, AttributeError, AssertionError, OverflowError): + except (TypeError, AttributeError, AssertionError, OverflowError) as e: raise TDecodeException(obj.__class__.__name__, fid, f_name, v, f_type, container_spec) @@ -265,6 +274,10 @@ cdef c_read_val(CyTransportBase buf, TType ttype, spec=None, n = read_i64(buf) return ((&n))[0] + elif ttype == T_BINARY: + size = read_i32(buf) + return c_read_binary(buf, size) + elif ttype == T_STRING: size = read_i32(buf) if decode_response: @@ -283,7 +296,7 @@ cdef c_read_val(CyTransportBase buf, TType ttype, spec=None, orig_type = read_i08(buf) size = read_i32(buf) - if orig_type != v_type: + if orig_type != v_type and not (orig_type in BIN_TYPES and v_type in BIN_TYPES): for _ in range(size): skip(buf, orig_type) return [] @@ -311,7 +324,10 @@ cdef c_read_val(CyTransportBase buf, TType ttype, spec=None, orig_key_type = read_i08(buf) orig_type = read_i08(buf) size = read_i32(buf) - + if orig_key_type in BIN_TYPES: + orig_key_type = k_type + if orig_type in BIN_TYPES: + orig_type = v_type if orig_key_type != k_type or orig_type != v_type: for _ in range(size): skip(buf, orig_key_type) @@ -344,8 +360,13 @@ cdef c_write_val(CyTransportBase buf, TType ttype, val, spec=None): elif ttype == T_DOUBLE: write_double(buf, val) + elif ttype == T_BINARY: + if isinstance(val, six.string_types) and sys.version_info[0] > 2: + val = val.encode() + write_string(buf, val) + elif ttype == T_STRING: - if not isinstance(val, bytes): + if not isinstance(val, six.binary_type): try: val = val.encode("utf-8") except Exception: @@ -374,7 +395,7 @@ cpdef skip(CyTransportBase buf, TType ttype): read_i32(buf) elif ttype == T_I64 or ttype == T_DOUBLE: read_i64(buf) - elif ttype == T_STRING: + elif ttype == T_STRING or ttype == T_BINARY: size = read_i32(buf) c_read_binary(buf, size) elif ttype == T_SET or ttype == T_LIST: diff --git a/thriftpy2/protocol/json.py b/thriftpy2/protocol/json.py index 9cf17828..89a57bd3 100644 --- a/thriftpy2/protocol/json.py +++ b/thriftpy2/protocol/json.py @@ -4,8 +4,12 @@ import json import struct +import base64 +import sys from warnings import warn +import six + from thriftpy2._compat import u from thriftpy2.thrift import TType @@ -15,6 +19,12 @@ VERSION = 1 +def encode_binary(data): + if isinstance(data, six.string_types) and sys.version_info[0] > 2: + data = data.encode() + return base64.b64encode(data).decode('ascii') + + def json_value(ttype, val, spec=None): TTYPE_TO_JSONFUNC_MAP = { TType.BYTE: (int, (val, )), @@ -28,6 +38,7 @@ def json_value(ttype, val, spec=None): TType.SET: (list_to_json, (val, spec)), TType.LIST: (list_to_json, (val, spec)), TType.MAP: (map_to_json, (val, spec)), + TType.BINARY: (encode_binary, (val, )), } func, args = TTYPE_TO_JSONFUNC_MAP.get(ttype) if func: @@ -53,6 +64,7 @@ def obj_value(ttype, val, spec=None): TType.SET: (list_to_obj, (val, spec)), TType.LIST: (list_to_obj, (val, spec)), TType.MAP: (map_to_obj, (val, spec)), + TType.BINARY: (base64.b64decode, (val, )), } func, args = TTYPE_TO_OBJFUNC_MAP.get(ttype) if func: diff --git a/thriftpy2/thrift.py b/thriftpy2/thrift.py index 99aa6a3c..f167719b 100644 --- a/thriftpy2/thrift.py +++ b/thriftpy2/thrift.py @@ -97,13 +97,13 @@ class TType(object): I64 = 10 STRING = 11 UTF7 = 11 - BINARY = 11 # This here just for parsing. For all purposes, it's a string STRUCT = 12 MAP = 13 SET = 14 LIST = 15 UTF8 = 16 UTF16 = 17 + BINARY = 18 _VALUES_TO_NAMES = { STOP: 'STOP', @@ -117,13 +117,13 @@ class TType(object): I64: 'I64', STRING: 'STRING', UTF7: 'STRING', - BINARY: 'STRING', STRUCT: 'STRUCT', MAP: 'MAP', SET: 'SET', LIST: 'LIST', UTF8: 'UTF8', - UTF16: 'UTF16' + UTF16: 'UTF16', + BINARY: 'BINARY' }