diff --git a/PubSubClient/PubSubClient.cpp b/PubSubClient/PubSubClient.cpp index f26ac4c..b2764b9 100755 --- a/PubSubClient/PubSubClient.cpp +++ b/PubSubClient/PubSubClient.cpp @@ -29,21 +29,21 @@ PubSubClient::PubSubClient(char* domain, uint16_t port, void (*callback)(char*,u this->stream = NULL; } -PubSubClient::PubSubClient(uint8_t *ip, uint16_t port, void (*callback)(char*,uint8_t*,unsigned int), Client& client, Stream *stream) { +PubSubClient::PubSubClient(uint8_t *ip, uint16_t port, void (*callback)(char*,uint8_t*,unsigned int), Client& client, Stream& stream) { this->_client = &client; this->callback = callback; this->ip = ip; this->port = port; this->domain = NULL; - this->stream = stream; + this->stream = &stream; } -PubSubClient::PubSubClient(char* domain, uint16_t port, void (*callback)(char*,uint8_t*,unsigned int), Client& client, Stream *stream) { +PubSubClient::PubSubClient(char* domain, uint16_t port, void (*callback)(char*,uint8_t*,unsigned int), Client& client, Stream& stream) { this->_client = &client; this->callback = callback; this->domain = domain; this->port = port; - this->stream = stream; + this->stream = &stream; } boolean PubSubClient::connect(char *id) { @@ -144,6 +144,7 @@ uint8_t PubSubClient::readByte() { uint16_t PubSubClient::readPacket(uint8_t* lengthLength) { uint16_t len = 0; buffer[len++] = readByte(); + bool isPublish = (buffer[0]&0xF0) == MQTTPUBLISH; uint32_t multiplier = 1; uint16_t length = 0; uint8_t digit = 0; @@ -158,7 +159,7 @@ uint16_t PubSubClient::readPacket(uint8_t* lengthLength) { } while ((digit & 128) != 0); *lengthLength = len-1; - if ((buffer[0]&0xF0) == MQTTPUBLISH) { + if (isPublish) { // Read in topic length to calculate bytes to skip over for Stream writing buffer[len++] = readByte(); buffer[len++] = readByte(); @@ -172,18 +173,21 @@ uint16_t PubSubClient::readPacket(uint8_t* lengthLength) { for (uint16_t i = start;istream && ((buffer[0]&0xF0) == MQTTPUBLISH) && len-*lengthLength-2>skip) { - this->stream->write(digit); + if (this->stream) { + if (isPublish && len-*lengthLength-2>skip) { + this->stream->write(digit); + } } if (len < MQTT_MAX_PACKET_SIZE) { - buffer[len++] = digit; - } else if (!this->stream) { - len = 0; // This will cause the packet to be ignored. + buffer[len] = digit; } + len++; + } + + if (!this->stream && len > MQTT_MAX_PACKET_SIZE) { + len = 0; // This will cause the packet to be ignored. } - // If a stream has been provided, indicate that we wrote the whole length, - // else return 0 if the length exceed the max packet size return len; } @@ -220,11 +224,18 @@ boolean PubSubClient::loop() { } topic[tl] = 0; // msgId only present for QOS>0 - if (buffer[0]&MQTTQOS1) { + if ((buffer[0]&0x06) == MQTTQOS1) { msgId = (buffer[llen+3+tl]<<8)+buffer[llen+3+tl+1]; payload = buffer+llen+3+tl+2; callback(topic,payload,len-llen-3-tl-2); - puback(msgId); + + buffer[0] = MQTTPUBACK; + buffer[1] = 2; + buffer[2] = (msgId >> 8); + buffer[3] = (msgId & 0xFF); + _client->write(buffer,4); + lastOutActivity = t; + } else { payload = buffer+llen+3+tl; callback(topic,payload,len-llen-3-tl); @@ -366,17 +377,6 @@ boolean PubSubClient::subscribe(char* topic, uint8_t qos) { return false; } -boolean PubSubClient::puback(uint16_t msgId) { - if(connected()) { - // Leave room in the buffer for header and variable length field - uint16_t length = 5; - buffer[length++] = (msgId >> 8); - buffer[length++] = (msgId & 0xFF); - return write(MQTTPUBACK,buffer,length-5); - } - return false; -} - boolean PubSubClient::unsubscribe(char* topic) { if (connected()) { uint16_t length = 5; diff --git a/PubSubClient/PubSubClient.h b/PubSubClient/PubSubClient.h index 28f4f67..215be95 100755 --- a/PubSubClient/PubSubClient.h +++ b/PubSubClient/PubSubClient.h @@ -51,7 +51,6 @@ private: uint8_t readByte(); boolean write(uint8_t header, uint8_t* buf, uint16_t length); uint16_t writeString(char* string, uint8_t* buf, uint16_t pos); - boolean puback(uint16_t msgId); uint8_t *ip; char* domain; uint16_t port; @@ -59,9 +58,9 @@ private: public: PubSubClient(); PubSubClient(uint8_t *, uint16_t, void(*)(char*,uint8_t*,unsigned int),Client& client); - PubSubClient(uint8_t *, uint16_t, void(*)(char*,uint8_t*,unsigned int),Client& client, Stream*); + PubSubClient(uint8_t *, uint16_t, void(*)(char*,uint8_t*,unsigned int),Client& client, Stream&); PubSubClient(char*, uint16_t, void(*)(char*,uint8_t*,unsigned int),Client& client); - PubSubClient(char*, uint16_t, void(*)(char*,uint8_t*,unsigned int),Client& client, Stream*); + PubSubClient(char*, uint16_t, void(*)(char*,uint8_t*,unsigned int),Client& client, Stream&); boolean connect(char *); boolean connect(char *, char *, char *); boolean connect(char *, char *, uint8_t, uint8_t, char *); diff --git a/PubSubClient/examples/mqtt_stream/mqtt_stream.ino b/PubSubClient/examples/mqtt_stream/mqtt_stream.ino index 0bc4279..a416ee2 100644 --- a/PubSubClient/examples/mqtt_stream/mqtt_stream.ino +++ b/PubSubClient/examples/mqtt_stream/mqtt_stream.ino @@ -35,7 +35,7 @@ void callback(char* topic, byte* payload, unsigned int length) { } EthernetClient ethClient; -PubSubClient client(server, 1883, callback, ethClient, &sram); +PubSubClient client(server, 1883, callback, ethClient, sram); void setup() { diff --git a/tests/src/keepalive_spec.cpp b/tests/src/keepalive_spec.cpp new file mode 100644 index 0000000..ec51868 --- /dev/null +++ b/tests/src/keepalive_spec.cpp @@ -0,0 +1,177 @@ +#include "PubSubClient.h" +#include "ShimClient.h" +#include "Buffer.h" +#include "BDDTest.h" +#include "trace.h" + + +byte server[] = { 172, 16, 0, 2 }; + +void callback(char* topic, byte* payload, unsigned int length) { + // handle message arrived +} + + +int test_keepalive_pings_idle() { + IT("keeps an idle connection alive"); + + ShimClient shimClient; + shimClient.setAllowConnect(true); + + byte connack[] = { 0x20, 0x02, 0x00, 0x00 }; + shimClient.respond(connack,4); + + PubSubClient client(server, 1883, callback, shimClient); + int rc = client.connect((char*)"client_test1"); + IS_TRUE(rc); + + byte pingreq[] = { 0xC0,0x0 }; + shimClient.expect(pingreq,2); + byte pingresp[] = { 0xD0,0x0 }; + shimClient.respond(pingresp,2); + + for (int i = 0; i < 50; i++) { + sleep(1); + rc = client.loop(); + IS_TRUE(rc); + } + + IS_FALSE(shimClient.error()); + + END_IT +} + +int test_keepalive_pings_with_outbound_qos0() { + IT("keeps a connection alive that only sends qos0"); + + ShimClient shimClient; + shimClient.setAllowConnect(true); + + byte connack[] = { 0x20, 0x02, 0x00, 0x00 }; + shimClient.respond(connack,4); + + PubSubClient client(server, 1883, callback, shimClient); + int rc = client.connect((char*)"client_test1"); + IS_TRUE(rc); + + byte publish[] = {0x30,0xe,0x0,0x5,0x74,0x6f,0x70,0x69,0x63,0x70,0x61,0x79,0x6c,0x6f,0x61,0x64}; + + for (int i = 0; i < 50; i++) { + TRACE(i<<":"); + shimClient.expect(publish,16); + rc = client.publish((char*)"topic",(char*)"payload"); + IS_TRUE(rc); + IS_FALSE(shimClient.error()); + sleep(1); + if ( i == 15 || i == 31 || i == 47) { + byte pingreq[] = { 0xC0,0x0 }; + shimClient.expect(pingreq,2); + byte pingresp[] = { 0xD0,0x0 }; + shimClient.respond(pingresp,2); + } + rc = client.loop(); + IS_TRUE(rc); + IS_FALSE(shimClient.error()); + } + + END_IT +} + +int test_keepalive_pings_with_inbound_qos0() { + IT("keeps a connection alive that only receives qos0"); + + ShimClient shimClient; + shimClient.setAllowConnect(true); + + byte connack[] = { 0x20, 0x02, 0x00, 0x00 }; + shimClient.respond(connack,4); + + PubSubClient client(server, 1883, callback, shimClient); + int rc = client.connect((char*)"client_test1"); + IS_TRUE(rc); + + byte publish[] = {0x30,0xe,0x0,0x5,0x74,0x6f,0x70,0x69,0x63,0x70,0x61,0x79,0x6c,0x6f,0x61,0x64}; + + for (int i = 0; i < 50; i++) { + TRACE(i<<":"); + sleep(1); + if ( i == 15 || i == 31 || i == 47) { + byte pingreq[] = { 0xC0,0x0 }; + shimClient.expect(pingreq,2); + byte pingresp[] = { 0xD0,0x0 }; + shimClient.respond(pingresp,2); + } + shimClient.respond(publish,16); + rc = client.loop(); + IS_TRUE(rc); + IS_FALSE(shimClient.error()); + } + + END_IT +} + +int test_keepalive_no_pings_inbound_qos1() { + IT("does not send pings for connections with inbound qos1"); + + ShimClient shimClient; + shimClient.setAllowConnect(true); + + byte connack[] = { 0x20, 0x02, 0x00, 0x00 }; + shimClient.respond(connack,4); + + PubSubClient client(server, 1883, callback, shimClient); + int rc = client.connect((char*)"client_test1"); + IS_TRUE(rc); + + byte publish[] = {0x32,0x10,0x0,0x5,0x74,0x6f,0x70,0x69,0x63,0x12,0x34,0x70,0x61,0x79,0x6c,0x6f,0x61,0x64}; + byte puback[] = {0x40,0x2,0x12,0x34}; + + for (int i = 0; i < 50; i++) { + shimClient.respond(publish,18); + shimClient.expect(puback,4); + sleep(1); + rc = client.loop(); + IS_TRUE(rc); + IS_FALSE(shimClient.error()); + } + + END_IT +} + +int test_keepalive_disconnects_hung() { + IT("disconnects a hung connection"); + + ShimClient shimClient; + shimClient.setAllowConnect(true); + + byte connack[] = { 0x20, 0x02, 0x00, 0x00 }; + shimClient.respond(connack,4); + + PubSubClient client(server, 1883, callback, shimClient); + int rc = client.connect((char*)"client_test1"); + IS_TRUE(rc); + + byte pingreq[] = { 0xC0,0x0 }; + shimClient.expect(pingreq,2); + + for (int i = 0; i < 32; i++) { + sleep(1); + rc = client.loop(); + } + IS_FALSE(rc); + + IS_FALSE(shimClient.error()); + + END_IT +} + +int main() +{ + test_keepalive_pings_idle(); + test_keepalive_pings_with_outbound_qos0(); + test_keepalive_pings_with_inbound_qos0(); + test_keepalive_no_pings_inbound_qos1(); + test_keepalive_disconnects_hung(); + + FINISH +} diff --git a/tests/src/lib/Stream.cpp b/tests/src/lib/Stream.cpp new file mode 100644 index 0000000..b0ecbb4 --- /dev/null +++ b/tests/src/lib/Stream.cpp @@ -0,0 +1,39 @@ +#include "Stream.h" +#include "trace.h" +#include +#include + +Stream::Stream() { + this->expectBuffer = new Buffer(); + this->_error = false; + this->_written = 0; +} + +size_t Stream::write(uint8_t b) { + this->_written++; + TRACE(std::hex << (unsigned int)b); + if (this->expectBuffer->available()) { + uint8_t expected = this->expectBuffer->next(); + if (expected != b) { + this->_error = true; + TRACE("!=" << (unsigned int)expected); + } + } else { + this->_error = true; + } + TRACE("\n"<< std::dec); + return 1; +} + + +bool Stream::error() { + return this->_error; +} + +void Stream::expect(uint8_t *buf, size_t size) { + this->expectBuffer->add(buf,size); +} + +uint16_t Stream::length() { + return this->_written; +} diff --git a/tests/src/lib/Stream.h b/tests/src/lib/Stream.h index 4447d16..4e41f86 100644 --- a/tests/src/lib/Stream.h +++ b/tests/src/lib/Stream.h @@ -1,93 +1,22 @@ -/* - Stream.h - base class for character-based streams. - Copyright (c) 2010 David A. Mellis. All right reserved. - - This library is free software; you can redistribute it and/or - modify it under the terms of the GNU Lesser General Public - License as published by the Free Software Foundation; either - version 2.1 of the License, or (at your option) any later version. - - This library is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU - Lesser General Public License for more details. - - You should have received a copy of the GNU Lesser General Public - License along with this library; if not, write to the Free Software - Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA - - parsing functions based on TextFinder library by Michael Margolis -*/ - #ifndef Stream_h #define Stream_h -#include +#include "Arduino.h" +#include "Buffer.h" -// compatability macros for testing -/* -#define getInt() parseInt() -#define getInt(skipChar) parseInt(skipchar) -#define getFloat() parseFloat() -#define getFloat(skipChar) parseFloat(skipChar) -#define getString( pre_string, post_string, buffer, length) -readBytesBetween( pre_string, terminator, buffer, length) -*/ +class Stream { +private: + Buffer* expectBuffer; + bool _error; + uint16_t _written; -class Stream -{ - protected: - unsigned long _timeout; // number of milliseconds to wait for the next char before aborting timed read - unsigned long _startMillis; // used for timeout measurement - int timedRead(); // private method to read stream with timeout - int timedPeek(); // private method to peek stream with timeout - int peekNextDigit(); // returns the next numeric digit in the stream or -1 if timeout - - public: - virtual int available() = 0; - virtual int read() = 0; - virtual int peek() = 0; - virtual void flush() = 0; - virtual size_t write(uint8_t) = 0; +public: + Stream(); + virtual size_t write(uint8_t); - Stream() {_timeout=1000;} - -// parsing methods - - void setTimeout(unsigned long timeout); // sets maximum milliseconds to wait for stream data, default is 1 second - - bool find(char *target); // reads data from the stream until the target string is found - // returns true if target string is found, false if timed out (see setTimeout) - - bool find(char *target, size_t length); // reads data from the stream until the target string of given length is found - // returns true if target string is found, false if timed out - - bool findUntil(char *target, char *terminator); // as find but search ends if the terminator string is found - - bool findUntil(char *target, size_t targetLen, char *terminate, size_t termLen); // as above but search ends if the terminate string is found - - - long parseInt(); // returns the first valid (long) integer value from the current position. - // initial characters that are not digits (or the minus sign) are skipped - // integer is terminated by the first character that is not a digit. - - float parseFloat(); // float version of parseInt - - size_t readBytes( char *buffer, size_t length); // read chars from stream into buffer - // terminates if length characters have been read or timeout (see setTimeout) - // returns the number of characters placed in the buffer (0 means no valid data found) - - size_t readBytesUntil( char terminator, char *buffer, size_t length); // as readBytes with terminator character - // terminates if length characters have been read, timeout, or if the terminator character detected - // returns the number of characters placed in the buffer (0 means no valid data found) - - - protected: - long parseInt(char skipChar); // as above but the given skipChar is ignored - // as above but the given skipChar is ignored - // this allows format characters (typically commas) in values to be ignored - - float parseFloat(char skipChar); // as above but the given skipChar is ignored + virtual bool error(); + virtual void expect(uint8_t *buf, size_t size); + virtual uint16_t length(); }; #endif diff --git a/tests/src/receive_spec.cpp b/tests/src/receive_spec.cpp index a92852e..42ace54 100644 --- a/tests/src/receive_spec.cpp +++ b/tests/src/receive_spec.cpp @@ -58,9 +58,36 @@ int test_receive_callback() { } int test_receive_stream() { - IT("receives a stream payload"); - // TODO - IS_FALSE(true); + IT("receives a streamed callback message"); + reset_callback(); + + Stream stream; + stream.expect((uint8_t*)"payload",7); + + ShimClient shimClient; + shimClient.setAllowConnect(true); + + byte connack[] = { 0x20, 0x02, 0x00, 0x00 }; + shimClient.respond(connack,4); + + PubSubClient client(server, 1883, callback, shimClient, stream); + int rc = client.connect((char*)"client_test1"); + IS_TRUE(rc); + + byte publish[] = {0x30,0xe,0x0,0x5,0x74,0x6f,0x70,0x69,0x63,0x70,0x61,0x79,0x6c,0x6f,0x61,0x64}; + shimClient.respond(publish,16); + + rc = client.loop(); + + IS_TRUE(rc); + + IS_TRUE(callback_called); + IS_TRUE(strcmp(lastTopic,"topic")==0); + IS_TRUE(lastLength == 7); + + IS_FALSE(stream.error()); + IS_FALSE(shimClient.error()); + END_IT } @@ -82,6 +109,7 @@ int test_receive_max_sized_message() { byte publish[] = {0x30,length-2,0x0,0x5,0x74,0x6f,0x70,0x69,0x63,0x70,0x61,0x79,0x6c,0x6f,0x61,0x64}; byte bigPublish[length]; memset(bigPublish,'A',length); + bigPublish[length] = 'B'; memcpy(bigPublish,publish,16); shimClient.respond(bigPublish,length); @@ -117,8 +145,9 @@ int test_receive_oversized_message() { byte publish[] = {0x30,length-2,0x0,0x5,0x74,0x6f,0x70,0x69,0x63,0x70,0x61,0x79,0x6c,0x6f,0x61,0x64}; byte bigPublish[length]; memset(bigPublish,'A',length); + bigPublish[length] = 'B'; memcpy(bigPublish,publish,16); - shimClient.respond(publish,length); + shimClient.respond(bigPublish,length); rc = client.loop(); @@ -131,6 +160,46 @@ int test_receive_oversized_message() { END_IT } +int test_receive_oversized_stream_message() { + IT("drops an oversized message"); + reset_callback(); + + Stream stream; + + ShimClient shimClient; + shimClient.setAllowConnect(true); + + byte connack[] = { 0x20, 0x02, 0x00, 0x00 }; + shimClient.respond(connack,4); + + PubSubClient client(server, 1883, callback, shimClient, stream); + int rc = client.connect((char*)"client_test1"); + IS_TRUE(rc); + + byte length = MQTT_MAX_PACKET_SIZE+1; + byte publish[] = {0x30,length-2,0x0,0x5,0x74,0x6f,0x70,0x69,0x63,0x70,0x61,0x79,0x6c,0x6f,0x61,0x64}; + + byte bigPublish[length]; + memset(bigPublish,'A',length); + bigPublish[length] = 'B'; + memcpy(bigPublish,publish,16); + + shimClient.respond(bigPublish,length); + stream.expect(bigPublish+9,length-9); + + rc = client.loop(); + + IS_TRUE(rc); + + IS_TRUE(callback_called); + IS_TRUE(strcmp(lastTopic,"topic")==0); + IS_TRUE(lastLength == length-9); + + IS_FALSE(stream.error()); + IS_FALSE(shimClient.error()); + + END_IT +} int test_receive_qos1() { IT("receives a qos1 message"); @@ -172,6 +241,7 @@ int main() test_receive_stream(); test_receive_max_sized_message(); test_receive_oversized_message(); + test_receive_oversized_stream_message(); test_receive_qos1(); FINISH