Skip to content

Commit

Permalink
Store subscriptions in client
Browse files Browse the repository at this point in the history
  • Loading branch information
Elena Shylko committed Jan 17, 2020
1 parent 90766b0 commit 6f5eb96
Show file tree
Hide file tree
Showing 11 changed files with 87 additions and 26 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def on_message(client, topic, payload, qos, properties):
def on_disconnect(client, packet, exc=None):
print('Disconnected')

def on_subscribe(client, mid, qos):
def on_subscribe(client, mid, qos, properties):
print('SUBSCRIBED')

def ask_exit(*args):
Expand Down
2 changes: 1 addition & 1 deletion examples/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def on_disconnect(client, packet, exc=None):
logging.info('[DISCONNECTED {}]'.format(client._client_id))


def on_subscribe(client, mid, qos):
def on_subscribe(client, mid, qos, properties):
logging.info('[SUBSCRIBED {}] QOS: {}'.format(client._client_id, qos))


Expand Down
2 changes: 1 addition & 1 deletion examples/shared_subscriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def on_disconnect(client, packet, exc=None):
logging.info('[DISCONNECTED {}]'.format(client._client_id))


def on_subscribe(client, mid, qos):
def on_subscribe(client, mid, qos, properties):
logging.info('[SUBSCRIBED {}] QOS: {}'.format(client._client_id, qos))


Expand Down
2 changes: 1 addition & 1 deletion examples/will_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def on_disconnect(client, packet, exc=None):
logging.info('[DISCONNECTED {}]'.format(client._client_id))


def on_subscribe(client, mid, qos):
def on_subscribe(client, mid, qos, properties):
logging.info('[SUBSCRIBED {}] QOS: {}'.format(client._client_id, qos))


Expand Down
2 changes: 1 addition & 1 deletion gmqtt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"Mikhail Turchunovich",
"Elena Nikolaichik"
]
__version__ = "0.5.6"
__version__ = "0.6.0"


__all__ = [
Expand Down
54 changes: 45 additions & 9 deletions gmqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import logging
import uuid
from typing import Union, Sequence

from .mqtt.protocol import MQTTProtocol
from .mqtt.connection import MQTTConnection
Expand Down Expand Up @@ -41,13 +42,20 @@ def __init__(self, topic, payload, qos=0, retain=False, **kwargs):


class Subscription:
def __init__(self, topic, qos=0, no_local=False, retain_as_published=False, retain_handling_options=0):
def __init__(self, topic, qos=0, no_local=False, retain_as_published=False, retain_handling_options=0,
subscription_identifier=None):
self.topic = topic
self.qos = qos
self.no_local = no_local
self.retain_as_published = retain_as_published
self.retain_handling_options = retain_handling_options

self.mid = None
self.acknowledged = False

# this property can be used only in MQTT5.0
self.subscription_identifier = subscription_identifier


class Client(MqttPackageHandler):
def __init__(self, client_id, clean_session=True, optimistic_acknowledgement=True,
Expand Down Expand Up @@ -81,6 +89,14 @@ def __init__(self, client_id, clean_session=True, optimistic_acknowledgement=Tru

self._resend_task = asyncio.ensure_future(self._resend_qos_messages())

self.subscriptions = []

def get_subscription_by_identifier(self, subscription_identifier):
return next((sub for sub in self.subscriptions if sub.subscription_identifier == subscription_identifier), None)

def get_subscriptions_by_mid(self, mid):
return [sub for sub in self.subscriptions if sub.mid == mid]

def _remove_message_from_query(self, mid):
logger.debug('[REMOVE MESSAGE] %s', mid)
asyncio.ensure_future(
Expand Down Expand Up @@ -202,19 +218,39 @@ async def _disconnect(self, reason_code=0, **properties):
self._connection.send_disconnect(reason_code=reason_code, **properties)
await self._connection.close()

def subscribe(self, subscription_or_topic, qos=0, no_local=False, retain_as_published=False,
retain_handling_options=0, **kwargs):
def subscribe(self, subscription_or_topic: Union[str, Subscription, Sequence[Subscription]],
qos=0, no_local=False, retain_as_published=False, retain_handling_options=0, **kwargs):
subscriptions = self.update_subscriptions_with_subscription_or_topic(
subscription_or_topic, qos, no_local, retain_as_published, retain_handling_options, kwargs)
return self._connection.subscribe(subscriptions, **kwargs)

def update_subscriptions_with_subscription_or_topic(
self, subscription_or_topic, qos, no_local, retain_as_published, retain_handling_options, kwargs):
subscription_identifier = kwargs.get('subscription_identifier')
if isinstance(subscription_or_topic, Subscription):
subscription = subscription_or_topic
subscription_or_topic.subscription_identifier = subscription_identifier
subscriptions = [subscription_or_topic]
elif isinstance(subscription_or_topic, (tuple, list)):
subscription = subscription_or_topic
for sub in subscription_or_topic:
sub.subscription_identifier = subscription_identifier
subscriptions = subscription_or_topic
elif isinstance(subscription_or_topic, str):
subscription = Subscription(subscription_or_topic, qos=qos, no_local=no_local,
retain_as_published=retain_as_published,
retain_handling_options=retain_handling_options)
subscriptions = [Subscription(subscription_or_topic, qos=qos, no_local=no_local,
retain_as_published=retain_as_published,
retain_handling_options=retain_handling_options,
subscription_identifier=subscription_identifier)]
else:
raise ValueError('Bad subscription: must be string or Subscription or list of Subscriptions')
return self._connection.subscribe(subscription, **kwargs)
self.subscriptions.extend(subscriptions)
return subscriptions

def resubscribe(self, subscription: Subscription, **kwargs):
# send subscribe packet for subscription,that's already in client's subscription list
if 'subscription_identifier' in kwargs:
subscription.subscription_identifier = kwargs['subscription_identifier']
elif subscription.subscription_identifier is not None:
kwargs['subscription_identifier'] = subscription.subscription_identifier
return self._connection.subscribe([subscription], **kwargs)

def unsubscribe(self, topic, **kwargs):
return self._connection.unsubscribe(topic, **kwargs)
Expand Down
16 changes: 16 additions & 0 deletions gmqtt/mqtt/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,22 @@ class ConnAckReasonCode(enum.IntEnum):
CONNECTION_RATE_EXCEEDED = 159


class SubAckReasonCode(enum.IntEnum):
QOS0 = 0
QOS1 = 1
QOS2 = 2

UNSPECIFIED_ERROR = 128
IMPLEMENTATION_SPECIFIC_ERROR = 131
NOT_AUTHORIZED = 135
TOPIC_FILTER_INVALID = 143
PACKET_IDENTIFIER_IN_USE = 145
QUOTA_EXCEEDED = 151
SHARED_SUBSCRIPTIONS_NOT_SUPPORTED = 158
SUBSCRIPTION_IDENTIFIERS_NOT_SUPPORTED = 161
WILDCARD_SUBSCRIPTIONS_NOT_SUPPORTED = 162


UNLIMITED_RECONNECTS = -1

DEFAULT_CONFIG = {
Expand Down
17 changes: 14 additions & 3 deletions gmqtt/mqtt/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,11 +363,22 @@ def __call__(self, cmd, packet):
def _handle_suback_packet(self, cmd, raw_packet):
pack_format = "!H" + str(len(raw_packet) - 2) + 's'
(mid, packet) = struct.unpack(pack_format, raw_packet)
properties, packet = self._parse_properties(packet)

pack_format = "!" + "B" * len(packet)
granted_qos = struct.unpack(pack_format, packet)
granted_qoses = struct.unpack(pack_format, packet)

subs = self.get_subscriptions_by_mid(mid)
for granted_qos, sub in zip(granted_qoses, subs):
if granted_qos >= 128:
# subscription was not acknowledged
sub.acknowledged = False
else:
sub.acknowledged = True
sub.qos = granted_qos

logger.info('[SUBACK] %s %s', mid, granted_qos)
self.on_subscribe(self, mid, granted_qos)
logger.info('[SUBACK] %s %s', mid, granted_qoses)
self.on_subscribe(self, mid, granted_qoses, properties)

self._id_generator.free_id(mid)

Expand Down
6 changes: 1 addition & 5 deletions gmqtt/mqtt/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,12 +141,8 @@ def build_package(cls, topic, protocol, **kwargs) -> Tuple[int, bytes]:

class SubscribePacket(PackageFactory):
@classmethod
def build_package(cls, subscription, protocol, **kwargs) -> Tuple[int, bytes]:
def build_package(cls, subscriptions, protocol, **kwargs) -> Tuple[int, bytes]:
remaining_length = 2
if not isinstance(subscription, (list, tuple)):
subscriptions = [subscription]
else:
subscriptions = subscription

topics = []
for s in subscriptions:
Expand Down
6 changes: 4 additions & 2 deletions gmqtt/mqtt/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,10 @@ async def send_auth_package(self, client_id, username, password, clean_session,
keepalive, self, will_message=will_message, **kwargs)
self.write_data(pkg)

def send_subscribe_packet(self, subscription, **kwargs):
mid, pkg = package.SubscribePacket.build_package(subscription, self, **kwargs)
def send_subscribe_packet(self, subscriptions, **kwargs):
mid, pkg = package.SubscribePacket.build_package(subscriptions, self, **kwargs)
for sub in subscriptions:
sub.mid = mid
self.write_data(pkg)
return mid

Expand Down
4 changes: 2 additions & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def on_message(self, client, topic, payload, qos, properties):
.format(client._client_id, topic, payload, qos, properties))
self.messages.append((topic, payload, qos, properties))

def on_subscribe(self, client, mid, qos):
logging.info('[SUBSCRIBED {}] QOS: {}'.format(client._client_id, qos))
def on_subscribe(self, client, mid, qos, properties):
logging.info('[SUBSCRIBED {}] QOS: {}, properties: {}'.format(client._client_id, qos, properties))
self.subscribeds.append(mid)

def on_connect(self, client, flags, rc, properties):
Expand Down

0 comments on commit 6f5eb96

Please sign in to comment.