PubSub: Allow custom encoder #850

Open · wants to merge 1 commit into base: main
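Summary of the change: every pub/sub manager gains an `encoder` keyword argument (defaulting to `pickle` in the concrete managers) and calls `self.encoder.dumps()` / `self.encoder.loads()` instead of using `pickle` directly, so any module-like object exposing `dumps()` and `loads()` can be plugged in. A minimal usage sketch, not part of the PR; it assumes this branch is installed and uses `json` purely as an illustrative stand-in:

import json
import socketio

# Any object exposing dumps()/loads() (pickle, json, marshal, msgpack, ...)
# appears to satisfy the duck type introduced by this diff.
mgr = socketio.RedisManager('redis://localhost:6379/0', encoder=json)
sio = socketio.Server(client_manager=mgr)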
src/socketio/asyncio_aiopika_manager.py (7 changes: 4 additions & 3 deletions)

@@ -38,7 +38,8 @@ class AsyncAioPikaManager(AsyncPubSubManager):  # pragma: no cover
     name = 'asyncaiopika'
 
     def __init__(self, url='amqp://guest:guest@localhost:5672//',
-                 channel='socketio', write_only=False, logger=None):
+                 channel='socketio', write_only=False, logger=None,
+                 encoder=pickle):
         if aio_pika is None:
             raise RuntimeError('aio_pika package is not installed '
                                '(Run "pip install aio_pika" in your '
@@ -70,7 +71,7 @@ async def _publish(self, data):
         channel = await self._channel(connection)
         exchange = await self._exchange(channel)
         await exchange.publish(
-            aio_pika.Message(body=pickle.dumps(data),
+            aio_pika.Message(body=self.encoder.dumps(data),
                              delivery_mode=aio_pika.DeliveryMode.PERSISTENT),
             routing_key='*'
         )
@@ -94,7 +95,7 @@ async def _listen(self):
                 async with self.listener_queue.iterator() as queue_iter:
                     async for message in queue_iter:
                         with message.process():
-                            yield pickle.loads(message.body)
+                            yield message.body
             except Exception:
                 self._get_logger().error('Cannot receive from rabbitmq... '
                                          'retrying in '
src/socketio/asyncio_pubsub_manager.py (12 changes: 10 additions & 2 deletions)

@@ -24,12 +24,14 @@ class AsyncPubSubManager(AsyncManager):
     """
     name = 'asyncpubsub'
 
-    def __init__(self, channel='socketio', write_only=False, logger=None):
+    def __init__(self, channel='socketio', write_only=False, logger=None,
+                 encoder=pickle):
         super().__init__()
         self.channel = channel
         self.write_only = write_only
         self.host_id = uuid.uuid4().hex
         self.logger = logger
+        self.encoder = encoder
 
     def initialize(self):
         super().initialize()
@@ -153,7 +155,13 @@ async def _thread(self):
             if isinstance(message, dict):
                 data = message
             else:
-                if isinstance(message, bytes):  # pragma: no cover
+                if self.encoder:
+                    try:
+                        data = self.encoder.loads(message)
+                    except:
+                        pass
+                if data is None and \
+                        isinstance(message, bytes):  # pragma: no cover
                     try:
                         data = pickle.loads(message)
                     except:
src/socketio/asyncio_redis_manager.py (10 changes: 7 additions & 3 deletions)

@@ -32,11 +32,14 @@ class AsyncRedisManager(AsyncPubSubManager):  # pragma: no cover
                        and receiving.
     :param redis_options: additional keyword arguments to be passed to
                           ``aioredis.from_url()``.
+    :param encoder: The encoder to use for publishing and decoding data,
+                    defaults to pickle.
     """
     name = 'aioredis'
 
     def __init__(self, url='redis://localhost:6379/0', channel='socketio',
-                 write_only=False, logger=None, redis_options=None):
+                 write_only=False, logger=None, redis_options=None,
+                 encoder=pickle):
         if aioredis is None:
             raise RuntimeError('Redis package is not installed '
                                '(Run "pip install aioredis" in your '
@@ -46,7 +49,8 @@ def __init__(self, url='redis://localhost:6379/0', channel='socketio',
         self.redis_url = url
         self.redis_options = redis_options or {}
         self._redis_connect()
-        super().__init__(channel=channel, write_only=write_only, logger=logger)
+        super().__init__(channel=channel, write_only=write_only, logger=logger,
+                         encoder=encoder)
 
     def _redis_connect(self):
         self.redis = aioredis.Redis.from_url(self.redis_url,
@@ -60,7 +64,7 @@ async def _publish(self, data):
                 if not retry:
                     self._redis_connect()
                 return await self.redis.publish(
-                    self.channel, pickle.dumps(data))
+                    self.channel, self.encoder.dumps(data))
             except aioredis.exceptions.RedisError:
                 if retry:
                     self._get_logger().error('Cannot publish to redis... '
src/socketio/kafka_manager.py (11 changes: 7 additions & 4 deletions)

@@ -33,18 +33,21 @@ class KafkaManager(PubSubManager):  # pragma: no cover
     :param write_only: If set to ``True``, only initialize to emit events. The
                        default of ``False`` initializes the class for emitting
                        and receiving.
+    :param encoder: The encoder to use for publishing and decoding data,
+                    defaults to pickle.
     """
     name = 'kafka'
 
     def __init__(self, url='kafka://localhost:9092', channel='socketio',
-                 write_only=False):
+                 write_only=False, encoder=pickle):
         if kafka is None:
             raise RuntimeError('kafka-python package is not installed '
                                '(Run "pip install kafka-python" in your '
                                'virtualenv).')
 
         super(KafkaManager, self).__init__(channel=channel,
-                                           write_only=write_only)
+                                           write_only=write_only,
+                                           encoder=encoder)
 
         urls = [url] if isinstance(url, str) else url
         self.kafka_urls = [url[8:] if url != 'kafka://' else 'localhost:9092'
@@ -54,7 +57,7 @@ def __init__(self, url='kafka://localhost:9092', channel='socketio',
                                             bootstrap_servers=self.kafka_urls)
 
     def _publish(self, data):
-        self.producer.send(self.channel, value=pickle.dumps(data))
+        self.producer.send(self.channel, value=self.encoder.dumps(data))
         self.producer.flush()
 
     def _kafka_listen(self):
@@ -64,4 +67,4 @@ def _kafka_listen(self):
     def _listen(self):
         for message in self._kafka_listen():
             if message.topic == self.channel:
-                yield pickle.loads(message.value)
+                yield message.value
src/socketio/kombu_manager.py (10 changes: 7 additions & 3 deletions)

@@ -42,20 +42,24 @@ class KombuManager(PubSubManager):  # pragma: no cover
                           ``kombu.Queue()``.
     :param producer_options: additional keyword arguments to be passed to
                              ``kombu.Producer()``.
+    :param encoder: The encoder to use for publishing and decoding data,
+                    defaults to pickle.
     """
     name = 'kombu'
 
     def __init__(self, url='amqp://guest:guest@localhost:5672//',
                  channel='socketio', write_only=False, logger=None,
                  connection_options=None, exchange_options=None,
-                 queue_options=None, producer_options=None):
+                 queue_options=None, producer_options=None,
+                 encoder=pickle):
         if kombu is None:
             raise RuntimeError('Kombu package is not installed '
                                '(Run "pip install kombu" in your '
                                'virtualenv).')
         super(KombuManager, self).__init__(channel=channel,
                                            write_only=write_only,
-                                           logger=logger)
+                                           logger=logger,
+                                           encoder=encoder)
         self.url = url
         self.connection_options = connection_options or {}
         self.exchange_options = exchange_options or {}
@@ -103,7 +107,7 @@ def _publish(self, data):
         connection = self._connection()
         publish = connection.ensure(self.producer, self.producer.publish,
                                     errback=self.__error_callback)
-        publish(pickle.dumps(data))
+        publish(self.encoder.dumps(data))
 
     def _listen(self):
         reader_queue = self._queue()
src/socketio/pubsub_manager.py (12 changes: 10 additions & 2 deletions)

@@ -23,12 +23,14 @@ class PubSubManager(BaseManager):
     """
     name = 'pubsub'
 
-    def __init__(self, channel='socketio', write_only=False, logger=None):
+    def __init__(self, channel='socketio', write_only=False, logger=None,
+                 encoder=None):
         super(PubSubManager, self).__init__()
         self.channel = channel
         self.write_only = write_only
         self.host_id = uuid.uuid4().hex
         self.logger = logger
+        self.encoder = encoder
 
     def initialize(self):
         super(PubSubManager, self).initialize()
@@ -151,7 +153,13 @@ def _thread(self):
             if isinstance(message, dict):
                 data = message
             else:
-                if isinstance(message, bytes):  # pragma: no cover
+                if self.encoder:
+                    try:
+                        data = self.encoder.loads(message)
+                    except:
+                        pass
+                if data is None and \
+                        isinstance(message, bytes):  # pragma: no cover
                     try:
                         data = pickle.loads(message)
                     except:
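For reference, the decoding path in `_thread()` above now tries the message as a dict first, then the configured encoder, and only then the pre-existing pickle and JSON fallbacks (the JSON branch sits just below the lines shown). A standalone sketch of that order, using a hypothetical `decode()` helper that is not part of the PR:

import json
import pickle

def decode(message, encoder=None):
    # Mirrors the fallback order in PubSubManager._thread(): dict passthrough,
    # then the configured encoder, then pickle (bytes only), then JSON.
    if isinstance(message, dict):
        return message
    data = None
    if encoder:
        try:
            data = encoder.loads(message)
        except Exception:
            data = None
    if data is None and isinstance(message, bytes):
        try:
            data = pickle.loads(message)
        except Exception:
            data = None
    if data is None:
        try:
            data = json.loads(message)
        except Exception:
            data = None
    return data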
src/socketio/redis_manager.py (11 changes: 8 additions & 3 deletions)

@@ -36,11 +36,14 @@ class RedisManager(PubSubManager):  # pragma: no cover
                        and receiving.
     :param redis_options: additional keyword arguments to be passed to
                           ``Redis.from_url()``.
+    :param encoder: The encoder to use for publishing and decoding data,
+                    defaults to pickle.
     """
     name = 'redis'
 
     def __init__(self, url='redis://localhost:6379/0', channel='socketio',
-                 write_only=False, logger=None, redis_options=None):
+                 write_only=False, logger=None, redis_options=None,
+                 encoder=pickle):
         if redis is None:
             raise RuntimeError('Redis package is not installed '
                                '(Run "pip install redis" in your '
@@ -50,7 +53,8 @@ def __init__(self, url='redis://localhost:6379/0', channel='socketio',
         self._redis_connect()
         super(RedisManager, self).__init__(channel=channel,
                                            write_only=write_only,
-                                           logger=logger)
+                                           logger=logger,
+                                           encoder=encoder)
 
     def initialize(self):
         super(RedisManager, self).initialize()
@@ -78,7 +82,8 @@ def _publish(self, data):
             try:
                 if not retry:
                     self._redis_connect()
-                return self.redis.publish(self.channel, pickle.dumps(data))
+                return self.redis.publish(self.channel,
+                                          self.encoder.dumps(data))
             except redis.exceptions.RedisError:
                 if retry:
                     logger.error('Cannot publish to redis... retrying')
src/socketio/zmq_manager.py (17 changes: 9 additions & 8 deletions)

@@ -29,6 +29,8 @@ class ZmqManager(PubSubManager):  # pragma: no cover
     :param write_only: If set to ``True``, only initialize to emit events. The
                        default of ``False`` initializes the class for emitting
                        and receiving.
+    :param encoder: The encoder to use for publishing and decoding data,
+                    defaults to pickle.
 
     A zmq message broker must be running for the zmq_manager to work.
     you can write your own or adapt one from the following simple broker
@@ -50,7 +52,8 @@ class ZmqManager(PubSubManager):  # pragma: no cover
     def __init__(self, url='zmq+tcp://localhost:5555+5556',
                  channel='socketio',
                  write_only=False,
-                 logger=None):
+                 logger=None,
+                 encoder=pickle):
         if zmq is None:
             raise RuntimeError('zmq package is not installed '
                                '(Run "pip install pyzmq" in your '
@@ -77,17 +80,18 @@ def __init__(self, url='zmq+tcp://localhost:5555+5556',
         self.channel = channel
         super(ZmqManager, self).__init__(channel=channel,
                                          write_only=write_only,
-                                         logger=logger)
+                                         logger=logger,
+                                         encoder=encoder)
 
     def _publish(self, data):
-        pickled_data = pickle.dumps(
+        encoded_data = self.encoder.dumps(
             {
                 'type': 'message',
                 'channel': self.channel,
                 'data': data
             }
         )
-        return self.sink.send(pickled_data)
+        return self.sink.send(encoded_data)
 
     def zmq_listen(self):
         while True:
@@ -98,10 +102,7 @@ def zmq_listen(self):
     def _listen(self):
         for message in self.zmq_listen():
             if isinstance(message, bytes):
-                try:
-                    message = pickle.loads(message)
-                except Exception:
-                    pass
+                yield message
             if isinstance(message, dict) and \
                     message['type'] == 'message' and \
                     message['channel'] == self.channel and \
tests/common/test_pubsub_manager.py (48 changes: 46 additions & 2 deletions)

@@ -1,5 +1,8 @@
 import functools
 import logging
+import pickle
+import json
+import marshal
 import unittest
 from unittest import mock
 
@@ -365,8 +368,6 @@ def test_background_thread(self):
         self.pm._handle_close_room = mock.MagicMock()
 
         def messages():
-            import pickle
-
             yield {'method': 'emit', 'value': 'foo'}
             yield {'missing': 'method'}
             yield '{"method": "callback", "value": "bar"}'
@@ -394,3 +395,46 @@ def messages():
         self.pm._handle_close_room.assert_called_once_with(
             {'method': 'close_room', 'value': 'baz'}
         )
+
+    def test_background_thread_with_encoder(self):
+        mock_server = mock.MagicMock()
+        pm = pubsub_manager.PubSubManager(encoder=marshal)
+        pm.set_server(mock_server)
+        pm._publish = mock.MagicMock()
+        pm._handle_emit = mock.MagicMock()
+        pm._handle_callback = mock.MagicMock()
+        pm._handle_disconnect = mock.MagicMock()
+        pm._handle_close_room = mock.MagicMock()
+
+        pm.initialize()
+
+        def messages():
+            yield {'method': 'emit', 'value': 'foo'}
+            yield marshal.dumps({'method': 'callback', 'value': 'bar'})
+            yield json.dumps(
+                {'method': 'disconnect', 'sid': '123', 'namespace': '/foo'}
+            )
+            yield pickle.dumps({'method': 'close_room', 'value': 'baz'})
+            yield {'method': 'bogus'}
+            yield 'bad json'
+            yield b'bad encoding'
+
+        pm._listen = mock.MagicMock(side_effect=messages)
+
+        try:
+            pm._thread()
+        except StopIteration:
+            pass
+
+        pm._handle_emit.assert_called_once_with(
+            {'method': 'emit', 'value': 'foo'}
+        )
+        pm._handle_callback.assert_called_once_with(
+            {'method': 'callback', 'value': 'bar'}
+        )
+        pm._handle_disconnect.assert_called_once_with(
+            {'method': 'disconnect', 'sid': '123', 'namespace': '/foo'}
+        )
+        pm._handle_close_room.assert_called_once_with(
+            {'method': 'close_room', 'value': 'baz'}
+        )
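The new test drives the fallback chain with `marshal`, plain JSON text, and a pickled payload in a single run. Since anything with `dumps()`/`loads()` appears to satisfy the interface, a hypothetical wrapper (not part of the PR) for cross-language JSON payloads could look like the sketch below, with the caveat that, unlike pickle, JSON cannot carry binary attachments:

import json

class JSONEncoder:
    """Hypothetical encoder: any object with dumps()/loads() fits the duck type."""

    @staticmethod
    def dumps(obj):
        return json.dumps(obj).encode('utf-8')

    @staticmethod
    def loads(data):
        return json.loads(data)

# e.g. socketio.RedisManager('redis://localhost:6379/0', encoder=JSONEncoder)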