Skip to content

Commit

Permalink
Unsubscribe functionality + pytest environment (#21)
Browse files Browse the repository at this point in the history
* Unsubscribe functionality + flow control minor fix

* Pytest environment

* Update README.md

* Codecoverage on travis

* Codecov badge in README
  • Loading branch information
Lenka42 authored and wialon committed Sep 22, 2018
1 parent 3e2d514 commit e30baec
Show file tree
Hide file tree
Showing 14 changed files with 350 additions and 4 deletions.
9 changes: 9 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,12 @@
.idea
gmqtt.egg-info
dist/

# virtualenvs
env/
pyenv/

# pytest
.coverage
.pytest_cache/
htmlcov/
10 changes: 10 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
language: python
python:
- "3.6"
install:
- pip install -r requirements.txt
# command to run tests
script:
- pytest --cov=gmqtt
after_success:
- codecov
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
[![Build Status](https://travis-ci.com/Lenka42/gmqtt.svg?branch=master)](https://travis-ci.com/Lenka42/gmqtt) [![codecov](https://codecov.io/gh/Lenka42/gmqtt/branch/master/graph/badge.svg)](https://codecov.io/gh/Lenka42/gmqtt)

### Python MQTT client implementation.


Expand Down
7 changes: 6 additions & 1 deletion gmqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ def __init__(self, client_id, clean_session=True, transport='tcp', optimistic_ac
self._retry_deliver_timeout = kwargs.pop('retry_deliver_timeout', 5)
self._persistent_storage = kwargs.pop('persistent_storage', HeapPersistentStorage(self._retry_deliver_timeout))

self._topic_alias_maximum = kwargs.get('topic_alias_maximum', 0)

asyncio.ensure_future(self._resend_qos_messages())

def _remove_message_from_query(self, mid):
Expand Down Expand Up @@ -168,7 +170,10 @@ async def disconnect(self, reason_code=0, **properties):
await self._connection.close()

def subscribe(self, topic, qos=0, **kwargs):
self._connection.subsribe(topic, qos, **kwargs)
self._connection.subscribe(topic, qos, **kwargs)

def unsubscribe(self, topic, **kwargs):
self._connection.unsubscribe(topic, **kwargs)

def publish(self, message_or_topic, payload=None, qos=0, retain=False, **kwargs):
loop = asyncio.get_event_loop()
Expand Down
5 changes: 4 additions & 1 deletion gmqtt/mqtt/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,12 @@ def publish(self, message):
def send_disconnect(self, reason_code=0, **properties):
self._protocol.send_disconnect(reason_code=reason_code, **properties)

def subsribe(self, topic, qos, **kwargs):
def subscribe(self, topic, qos, **kwargs):
self._protocol.send_subscribe_packet(topic, qos, **kwargs)

def unsubscribe(self, topic, **kwargs):
self._protocol.send_unsubscribe_packet(topic, **kwargs)

def send_simple_command(self, cmd):
self._protocol.send_simple_command_packet(cmd)

Expand Down
2 changes: 1 addition & 1 deletion gmqtt/mqtt/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def __init__(self, *args, **kwargs):
self._error = None
self._connection = None

self._id_generator = IdGenerator()
self._id_generator = IdGenerator(max=kwargs.get('receive_maximum', 65535))

if self.protocol_version == MQTTv50:
self._optimistic_acknowledgement = kwargs.get('optimistic_acknowledgement', True)
Expand Down
30 changes: 30 additions & 0 deletions gmqtt/mqtt/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,36 @@ def build_package(cls, client_id, username, password, clean_session, keepalive,
return packet


class UnsubscribePacket(PackageFactory):
@classmethod
def build_package(cls, topic, protocol, **kwargs) -> bytes:
remaining_length = 2
if not isinstance(topic, (list, tuple)):
topics = [topic]
else:
topics = topic

for t in topics:
remaining_length += 2 + len(t)

properties = cls._build_properties_data(kwargs, protocol.proto_ver)
remaining_length += len(properties)

command = MQTTCommands.UNSUBSCRIBE | 0x2
packet = bytearray()
packet.append(command)
packet.extend(pack_variable_byte_integer(remaining_length))
local_mid = cls.id_generator.next_id()
packet.extend(struct.pack("!H", local_mid))
packet.extend(properties)
for t in topics:
cls._pack_str16(packet, t)

logger.info('[SEND UNSUB] %s', topics)

return packet


class SubscribePacket(PackageFactory):
@classmethod
def build_package(cls, topic, qos, protocol, **kwargs) -> bytes:
Expand Down
2 changes: 1 addition & 1 deletion gmqtt/mqtt/property.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def factory(cls, id_=None, name=None):
Property(9, 'b', 'correlation_data', ['PUBLISH']),
Property(11, 'vbi', 'subscription_identifier', ['PUBLISH', 'SUBSCRIBE']),
Property(17, '!L', 'session_expiry_interval', ['CONNECT', ]),
Property(18, 'u8', 'client_id', ['CONNACK', ]),
Property(18, 'u8', 'assigned_client_identifier', ['CONNACK', ]),
Property(19, '!H', 'server_keep_alive', ['CONNACK']),
Property(21, 'u8', 'auth_method', ['CONNECT', 'CONNACK', 'AUTH']),
Property(23, '!B', 'request_problem_info', ['CONNECT']),
Expand Down
4 changes: 4 additions & 0 deletions gmqtt/mqtt/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ def send_subscribe_packet(self, topic, qos, **kwargs):
pkg = package.SubscribePacket.build_package(topic, qos, self, **kwargs)
self.write_data(pkg)

def send_unsubscribe_packet(self, topic, **kwargs):
pkg = package.UnsubscribePacket.build_package(topic, self, **kwargs)
self.write_data(pkg)

def send_simple_command_packet(self, cmd):
pkg = package.SimpleCommandPacket.build_package(cmd)
self.write_data(pkg)
Expand Down
3 changes: 3 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[pytest]
norecursedirs = env pyenv

12 changes: 12 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
atomicwrites==1.2.1
attrs==18.2.0
coverage==4.5.1
more-itertools==4.3.0
pluggy==0.7.1
py==1.6.0
pytest==3.8.0
pytest-asyncio==0.9.0
pytest-cov==2.6.0
six==1.11.0
uvloop==0.11.2
codecov==2.0.15
Empty file added tests/__init__.py
Empty file.
183 changes: 183 additions & 0 deletions tests/test_mqtt5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
import asyncio

import os
import pytest

import gmqtt
from tests.utils import Callbacks, cleanup, clean_retained

host = 'mqtt.flespi.io'
port = 1883
username = os.getenv('USERNAME', 'fake_token')

TOPICS = ("TopicA", "TopicA/B", "TopicA/C", "TopicA/D", "/TopicA")
WILDTOPICS = ("TopicA/+", "+/C", "#", "/#", "/+", "+/+", "TopicA/#")
NOSUBSCRIBE_TOPICS = ("test/nosubscribe",)


@pytest.fixture()
async def init_clients():
await cleanup(host, port, username)

aclient = gmqtt.Client("myclientid", clean_session=True)
aclient.set_auth_credentials(username)
callback = Callbacks()
callback.register_for_client(aclient)

bclient = gmqtt.Client("myclientid2", clean_session=True)
bclient.set_auth_credentials(username)
callback2 = Callbacks()
callback2.register_for_client(bclient)

yield aclient, callback, bclient, callback2

await aclient.disconnect()
await bclient.disconnect()


@pytest.mark.asyncio
async def test_basic(init_clients):
aclient, callback, bclient, callback2 = init_clients

await aclient.connect(host=host, port=port, version=4)
await bclient.connect(host=host, port=port, version=4)
bclient.subscribe(TOPICS[0])
await asyncio.sleep(1)

aclient.publish(TOPICS[0], b"qos 0")
aclient.publish(TOPICS[0], b"qos 1", 1)
aclient.publish(TOPICS[0], b"qos 2", 2)
await asyncio.sleep(1)
assert len(callback2.messages) == 3


@pytest.mark.asyncio
async def test_retained_message(init_clients):
aclient, callback, bclient, callback2 = init_clients

await aclient.connect(host=host, port=port)
aclient.publish(TOPICS[1], b"ret qos 0", 0, retain=True, user_property=("a", "2"))
aclient.publish(TOPICS[2], b"ret qos 1", 1, retain=True, user_property=("c", "3"))
aclient.publish(TOPICS[3], b"ret qos 2", 2, retain=True, user_property=("a", "2"))

await asyncio.sleep(1)
await aclient.disconnect()
await asyncio.sleep(1)

await bclient.connect(host=host, port=port)
bclient.subscribe(WILDTOPICS[0], qos=2)
await asyncio.sleep(1)

assert len(callback2.messages) == 3

await clean_retained(host, port, username)


@pytest.mark.asyncio
async def test_will_message(init_clients):
aclient, callback, bclient, callback2 = init_clients

# re-initialize aclient with will message
will_message = gmqtt.Message(TOPICS[2], "I'm dead finally")
aclient = gmqtt.Client("myclientid3", clean_session=True, will_message=will_message)
aclient.set_auth_credentials(username)

await aclient.connect(host, port=port)

await bclient.connect(host=host, port=port)
bclient.subscribe(TOPICS[2])

await asyncio.sleep(1)
await aclient.disconnect(reason_code=4)
await asyncio.sleep(1)
assert len(callback2.messages) == 1


@pytest.mark.asyncio
async def test_shared_subscriptions(init_clients):
aclient, callback, bclient, callback2 = init_clients

shared_sub_topic = '$share/sharename/x'
shared_pub_topic = 'x'

await aclient.connect(host=host, port=port)
aclient.subscribe(shared_sub_topic)
aclient.subscribe(TOPICS[0])

await bclient.connect(host=host, port=port)
bclient.subscribe(shared_sub_topic)
bclient.subscribe(TOPICS[0])

pubclient = gmqtt.Client("myclient3", clean_session=True)
pubclient.set_auth_credentials(username)
await pubclient.connect(host, port)

count = 10
for i in range(count):
pubclient.publish(TOPICS[0], "message " + str(i), 0)
j = 0
while len(callback.messages) + len(callback2.messages) < 2 * count and j < 20:
await asyncio.sleep(1)
j += 1
await asyncio.sleep(1)
assert len(callback.messages) == count
assert len(callback2.messages) == count

callback.clear()
callback2.clear()

count = 10
for i in range(count):
pubclient.publish(shared_pub_topic, "message " + str(i), 0)
j = 0
while len(callback.messages) + len(callback2.messages) < count and j < 20:
await asyncio.sleep(1)
j += 1
await asyncio.sleep(1)
# Each message should only be received once
assert len(callback.messages) + len(callback2.messages) == count
assert len(callback.messages) > 0
assert len(callback2.messages) > 0


@pytest.mark.asyncio
async def test_assigned_clientid():
noidclient = gmqtt.Client("", clean_session=True)
noidclient.set_auth_credentials(username)
callback = Callbacks()
callback.register_for_client(noidclient)
await noidclient.connect(host=host, port=port)
await noidclient.disconnect()
assert callback.connack[2]['assigned_client_identifier'][0] != ""


@pytest.mark.asyncio
async def test_unsubscribe(init_clients):
aclient, callback, bclient, callback2 = init_clients
await bclient.connect(host=host, port=port)
await aclient.connect(host=host, port=port)

bclient.subscribe(TOPICS[1])
bclient.subscribe(TOPICS[2])
bclient.subscribe(TOPICS[3])
await asyncio.sleep(1)
print(callback2.messages)

aclient.publish(TOPICS[1], b"topic 0 - subscribed", 1, retain=False)
aclient.publish(TOPICS[2], b"topic 1", 1, retain=False)
aclient.publish(TOPICS[3], b"topic 2", 1, retain=False)
await asyncio.sleep(1)
print(callback2.messages)
assert len(callback2.messages) == 3
callback2.clear()
# Unsubscribe from one topic
bclient.unsubscribe(TOPICS[1])
await asyncio.sleep(3)

aclient.publish(TOPICS[1], b"topic 0 - unsubscribed", 1, retain=False)
aclient.publish(TOPICS[2], b"topic 1", 1, retain=False)
aclient.publish(TOPICS[3], b"topic 2", 1, retain=False)
await asyncio.sleep(1)

assert len(callback2.messages) == 2

Loading

0 comments on commit e30baec

Please sign in to comment.