diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 0000000..2fbeced --- /dev/null +++ b/tests/README.md @@ -0,0 +1,3 @@ +# Arduino Client for MQTT Test Suite + + diff --git a/tests/testcases/mqtt_basic.py b/tests/testcases/mqtt_basic.py index e3b0022..1b0cc65 100644 --- a/tests/testcases/mqtt_basic.py +++ b/tests/testcases/mqtt_basic.py @@ -1,12 +1,43 @@ import unittest +import settings + +import time +import mosquitto + +import serial + +def on_message(mosq, obj, msg): + obj.message_queue.append(msg) class mqtt_basic(unittest.TestCase): - def setUp(self): - pass - - - def test_one(self): - self.assertEqual(3,3) + + message_queue = [] + + @classmethod + def setUpClass(self): + self.client = mosquitto.Mosquitto("pubsubclient_ut", clean_session=True,obj=self) + self.client.connect(settings.server_ip) + self.client.on_message = on_message + self.client.subscribe("outTopic",0) + + @classmethod + def tearDownClass(self): + self.client.disconnect() + + def test_one(self): + i=30 + while len(self.message_queue) == 0 and i > 0: + self.client.loop() + time.sleep(0.5) + i -= 1 + self.assertTrue(i>0, "message receive timed-out") + self.assertEqual(len(self.message_queue), 1, "unexpected number of messages received") + msg = self.message_queue[0] + self.assertEqual(msg.mid,0,"message id not 0") + self.assertEqual(msg.topic,"outTopic","message topic incorrect") + self.assertEqual(msg.payload,"hello world") + self.assertEqual(msg.qos,0,"message qos not 0") + self.assertEqual(msg.retain,False,"message retain flag incorrect") + + - def test_two(self): - self.assertEqual(4,4) diff --git a/tests/testcases/mqtt_publish_in_callback.py b/tests/testcases/mqtt_publish_in_callback.py new file mode 100644 index 0000000..7989f7f --- /dev/null +++ b/tests/testcases/mqtt_publish_in_callback.py @@ -0,0 +1,64 @@ +import unittest +import settings + +import time +import mosquitto + +import serial + +def on_message(mosq, obj, msg): + obj.message_queue.append(msg) + +class mqtt_publish_in_callback(unittest.TestCase): + + message_queue = [] + + @classmethod + def setUpClass(self): + self.client = mosquitto.Mosquitto("pubsubclient_ut", clean_session=True,obj=self) + self.client.connect(settings.server_ip) + self.client.on_message = on_message + self.client.subscribe("outTopic",0) + + @classmethod + def tearDownClass(self): + self.client.disconnect() + + def test_connect(self): + i=30 + while len(self.message_queue) == 0 and i > 0: + self.client.loop() + time.sleep(0.5) + i -= 1 + self.assertTrue(i>0, "message receive timed-out") + self.assertEqual(len(self.message_queue), 1, "unexpected number of messages received") + msg = self.message_queue.pop(0) + self.assertEqual(msg.mid,0,"message id not 0") + self.assertEqual(msg.topic,"outTopic","message topic incorrect") + self.assertEqual(msg.payload,"hello world") + self.assertEqual(msg.qos,0,"message qos not 0") + self.assertEqual(msg.retain,False,"message retain flag incorrect") + + + def test_publish(self): + self.assertEqual(len(self.message_queue), 0, "message queue not empty") + payload = "abcdefghij" + self.client.publish("inTopic",payload) + + i=30 + while len(self.message_queue) == 0 and i > 0: + self.client.loop() + time.sleep(0.5) + i -= 1 + + self.assertTrue(i>0, "message receive timed-out") + self.assertEqual(len(self.message_queue), 1, "unexpected number of messages received") + msg = self.message_queue.pop(0) + self.assertEqual(msg.mid,0,"message id not 0") + self.assertEqual(msg.topic,"outTopic","message topic incorrect") + self.assertEqual(msg.payload,payload) + self.assertEqual(msg.qos,0,"message qos not 0") + self.assertEqual(msg.retain,False,"message retain flag incorrect") + + + diff --git a/tests/testcases/settings.py b/tests/testcases/settings.py new file mode 100644 index 0000000..4ad8719 --- /dev/null +++ b/tests/testcases/settings.py @@ -0,0 +1,2 @@ +server_ip = "172.16.0.2" +arduino_ip = "172.16.0.100" diff --git a/tests/testsuite.py b/tests/testsuite.py index 59b726b..c86f707 100644 --- a/tests/testsuite.py +++ b/tests/testsuite.py @@ -6,6 +6,9 @@ import shutil from subprocess import call import importlib import unittest +import re + +from testcases import settings class Workspace(object): @@ -66,7 +69,23 @@ class Sketch(object): def build(self): sys.stdout.write(" Build: ") sys.stdout.flush() - shutil.copy(self.filename,os.path.join(self.w.build_dir,"src","sketch.ino")) + + # Copy sketch over, replacing IP addresses as necessary + fin = open(self.filename,"r") + lines = fin.readlines() + fin.close() + fout = open(os.path.join(self.w.build_dir,"src","sketch.ino"),"w") + for l in lines: + if re.match(r"^byte server\[\]",l): + fout.write("byte server[] = { %s };\n"%(settings.server_ip.replace(".",", "),)) + elif re.match(r"^byte ip\[\]",l): + fout.write("byte ip[] = { %s };\n"%(settings.arduino_ip.replace(".",", "),)) + else: + fout.write(l) + fout.flush() + fout.close() + + # Run build fout = open(self.build_log, "w") ferr = open(self.build_err_log, "w") rc = call(["ino","build"],stdout=fout,stderr=ferr) @@ -104,6 +123,7 @@ class Sketch(object): def test(self): + # import the matching test case, if it exists try: basename = os.path.basename(self.filename)[:-4] i = importlib.import_module("testcases."+basename) @@ -111,25 +131,33 @@ class Sketch(object): sys.stdout.write(" Test: no tests found") sys.stdout.write("\n") return + c = getattr(i,basename) + + testmethods = [m for m in dir(c) if m.startswith("test_")] + testmethods.sort() + tests = [] + for m in testmethods: + tests.append(c(m)) + + result = unittest.TestResult() + c.setUpClass() if self.upload(): sys.stdout.write(" Test: ") sys.stdout.flush() - c = getattr(i,basename) - suite = unittest.makeSuite(c,'test') - result = unittest.TestResult() - suite.run(result) - print "%d/%d"%(result.testsRun-len(result.failures)-len(result.errors),result.testsRun) - if not result.wasSuccessful(): - if len(result.failures) > 0: - for f in result.failures: - print "-- %s"%(str(f[0]),) - print f[1] - if len(result.errors) > 0: - print " Errors:" - for f in result.errors: - print "-- %s"%(str(f[0]),) - print f[1] - + for t in tests: + t.run(result) + print "%d/%d"%(result.testsRun-len(result.failures)-len(result.errors),result.testsRun) + if not result.wasSuccessful(): + if len(result.failures) > 0: + for f in result.failures: + print "-- %s"%(str(f[0]),) + print f[1] + if len(result.errors) > 0: + print " Errors:" + for f in result.errors: + print "-- %s"%(str(f[0]),) + print f[1] + c.tearDownClass() if __name__ == '__main__': run_tests = True