Skip to content

Instantly share code, notes, and snippets.

@brimston3
Created September 12, 2017 15:08
Show Gist options
  • Star 13 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save brimston3/aaf728d357ff4d6ac76a871e02adb8c6 to your computer and use it in GitHub Desktop.
Save brimston3/aaf728d357ff4d6ac76a871e02adb8c6 to your computer and use it in GitHub Desktop.
paho.mqtt.embedded-c (C++) with mbedTLS example
/*******************************************************************************
* Copyright 2017 Andrew Domaszek
*
* All rights reserved.
* This program made available under BSD-new.
*******************************************************************************/
/**
* This example is designed for Linux, using such calls as setsockopt and gettimeofday.
* On embedded, it's likely that all of the mbedtls_net_* functions would need to be
* handled by the embedded IP stack and the conn_ctx member would need replacing.
*/
#include "MQTTTransport_mbedTLS.h"
#include <assert.h>
#include <stdio.h>
#include <stdarg.h>
#include <stdlib.h>
#include <sys/socket.h>
#include <mbedtls/error.h>
#include <sys/time.h>
// PEM CA certificate file loaded by the constructor (hard-coded location).
#define TLS_CA_CERTIFICATE_PATH "/etc/mosquitto/certs/ca.crt"
// When defined, connect() passes this name to mbedtls_ssl_set_hostname()
// instead of the hostname argument (for brokers whose CN differs from their DNS name).
//#define TLS_NONSTANDARD_SERVER_CN "Example Customer Svr 1 X1"
// Error messages: printf-style, to stderr.
#define tlstrans_LOGERR(...) fprintf(stderr, ## __VA_ARGS__)
// Debug messages: printf-style, to stdout, prefixed with file:line by _static_debug_print().
#define tlstrans_LOGDEBUG(...) _static_debug_print(stdout, 1, __FILE__, __LINE__, ## __VA_ARGS__)
// Informational/progress messages: printf-style, to stdout.
#define tlstrans_LOG(...) printf(__VA_ARGS__)
// Silences unused-parameter warnings.
#define UNUSED_VAR(x) ((void)x)
// When defined, read() times each call with gettimeofday() and asserts if it
// returns 0 bytes in under 1 ms.
//#define ENABLE_TIMECHECKING
/**
 * mbedTLS-style debug sink (the signature passed to mbedtls_ssl_conf_dbg).
 *
 * @param ctx   FILE* to write to.
 * @param level debug level; ignored.
 * @param file  source file name printed as the prefix.
 * @param line  source line, printed zero-padded to 4 digits.
 * @param str   message text, written verbatim after the prefix.
 *
 * The stream is flushed after every message so output is not lost on a crash.
 */
static void _static_debug( void *ctx, int level,
                           const char *file, int line, const char *str )
{
    (void)level;
    FILE * const sink = static_cast<FILE *>( ctx );
    fprintf( sink, "%s:%04d: %s", file, line, str );
    fflush( sink );
}
//! @BUG: 1k buff is not small stack friendly.
static void _static_debug_print( void *ctx, int level, const char * file, int line, const char * format, ... )
{
char buff[1024] = "";
char * ostr = buff;
va_list args,cnt_args;
va_start(args, format);
va_copy(cnt_args, args);
int cnt = vsnprintf(NULL, 0, format, cnt_args);
if (cnt > sizeof(buff)-1) {
ostr = (char*)malloc(cnt+1);
if (!ostr)
{
fputs("heap allocation failure in print\n", (FILE *)ctx);
abort();
}
}
vsnprintf(ostr, cnt+1, format, args);
va_end(args);
_static_debug(ctx, level, file, line, ostr);
if (ostr != buff)
free(ostr);
}
/**
 * Initialise every mbedTLS context owned by the transport and load the CA
 * certificate file used to verify the broker. No network I/O happens here;
 * see connect().
 */
MQTTTransport_mbedTLS::MQTTTransport_mbedTLS() :
read_timeout_ms(400)
{
    // Install the static read callback and point it at this instance
    // (getfn/sck/state come from the MQTTTransport base, not visible here).
    getfn = MQTTTransport_mbedTLS::tlsFetchData;
    sck = this;
    state = 0;
    mbedtls_net_init( &conn_ctx );
    mbedtls_ssl_init( &ssl );
    mbedtls_ssl_config_init( &conf );
    mbedtls_x509_crt_init( &cacert );
    mbedtls_ctr_drbg_init( &ctr_drbg );
    mbedtls_entropy_init( &entropy );
    //! @BUG: Remove hardcoded path
    // 0 = success, > 0 = number of certificates in the file that failed to
    // parse, < 0 = error. Both non-zero cases used to be ignored silently,
    // leaving the handshake to verify against an incomplete trust store.
    const int rc = mbedtls_x509_crt_parse_file( &cacert, TLS_CA_CERTIFICATE_PATH );
    if (rc != 0)
        tlstrans_LOGERR("mbedtls_x509_crt_parse_file(\"%s\") returned %d\n", TLS_CA_CERTIFICATE_PATH, rc);
}
/**
 * Release everything the transport owns, including the socket if the caller
 * never called disconnect().
 */
MQTTTransport_mbedTLS::~MQTTTransport_mbedTLS()
{
    // Previously the socket leaked unless disconnect() was called first.
    // mbedtls_net_free() does nothing when the context holds no descriptor (fd == -1).
    mbedtls_net_free( &conn_ctx );
    // Free users before the objects they reference: ssl -> conf -> cacert/drbg -> entropy.
    mbedtls_ssl_free( &ssl );
    mbedtls_ssl_config_free( &conf );
    mbedtls_x509_crt_free( &cacert );
    mbedtls_ctr_drbg_free( &ctr_drbg );
    mbedtls_entropy_free( &entropy );
}
int MQTTTransport_mbedTLS::tlsFetchData(void * self, unsigned char * out, int bytesMax)
{
MQTTTransport_mbedTLS * ctx = static_cast<MQTTTransport_mbedTLS*>(self);
int rc = ctx->read(out, bytesMax, ctx->read_timeout_ms);
if (rc == -1)
{
// need to check something here to make sure it's not a terminal TLS error.
}
return rc;
}
int MQTTTransport_mbedTLS::connect(const char * hostname, int port)
{
assert(port > 0 && port < USHRT_MAX);
char port_str[6];
snprintf(port_str, sizeof(port_str), "%d", port); //! @TODO: itoa() instead?
const char * FUNC_NAME = "";
int rc = -1;
// Can provide personalization identifier in arg 4 & 5 for more entropy.
FUNC_NAME = "mbedtls_ctr_drbg_seed";
if( ( rc = mbedtls_ctr_drbg_seed( &ctr_drbg, mbedtls_entropy_func, &entropy,
(const unsigned char *) NULL,
0 ) ) != 0 )
goto error_out;
/*
* Start the connection
*/
tlstrans_LOG( "\n - Connecting to tcp/%s/%s...", hostname, port_str );
fflush( stdout );
FUNC_NAME = "mbedtls_net_connect";
if( ( rc = mbedtls_net_connect( &conn_ctx, hostname,
port_str, MBEDTLS_NET_PROTO_TCP ) ) != 0 )
goto error_out;
tlstrans_LOG( " ok\n" );
FUNC_NAME = "mbedtls_ssl_config_defaults";
if( ( rc = mbedtls_ssl_config_defaults( &conf,
MBEDTLS_SSL_IS_CLIENT,
MBEDTLS_SSL_TRANSPORT_STREAM,
MBEDTLS_SSL_PRESET_DEFAULT ) ) != 0 )
goto error_out;
mbedtls_ssl_conf_authmode( &conf, MBEDTLS_SSL_VERIFY_OPTIONAL ); //! @BUG: NO SECURITY.
mbedtls_ssl_conf_ca_chain( &conf, &cacert, NULL ); // should only be set if VERIFY OPTIONAL or REQUIRED.
mbedtls_ssl_conf_rng( &conf, mbedtls_ctr_drbg_random, &ctr_drbg );
mbedtls_ssl_conf_dbg( &conf, _static_debug, stdout ); // debug callback defined above.
FUNC_NAME = "mbedtls_ssl_setup";
if( ( rc = mbedtls_ssl_setup( &ssl, &conf ) ) != 0 )
goto error_out;
FUNC_NAME = "mbedtls_ssl_set_hostname";
#if defined(TLS_NONSTANDARD_SERVER_CN)
if( ( rc = mbedtls_ssl_set_hostname( &ssl, TLS_NONSTANDARD_SERVER_CN ) ) != 0 ) //! @BUG: hardcoded CN
goto error_out;
#else
if( ( rc = mbedtls_ssl_set_hostname( &ssl, hostname ) ) != 0 ) //! @TODO: Verify this is necessary and not default behavior.
goto error_out;
#endif
mbedtls_ssl_set_bio( &ssl, &conn_ctx, mbedtls_net_send, mbedtls_net_recv, NULL );
/*
* 4. TLS Handshake and verification
*/
tlstrans_LOG( " - Performing the SSL/TLS handshake..." );
fflush(stdout);
FUNC_NAME = "mbedtls_ssl_handshake";
while( ( rc = mbedtls_ssl_handshake( &ssl ) ) != 0 )
{
if( rc != MBEDTLS_ERR_SSL_WANT_READ && rc != MBEDTLS_ERR_SSL_WANT_WRITE )
goto error_out;
}
tlstrans_LOG(" success (%d).\n", rc);
rc = mbedtls_ssl_get_verify_result( &ssl ); //! @BUG: if this is non-zero, it should abort. Maybe instead use VERIFY_REQUIRED?
tlstrans_LOG("CN verify result: %d\n", rc);
return 0;
error_out:
tlstrans_LOGERR( " failed\n ! %s returned %d\n", FUNC_NAME, rc);
return rc;
}
/**
 * Close the TCP socket held in conn_ctx. The SSL context is not touched here.
 * @return always 0.
 */
int MQTTTransport_mbedTLS::disconnect()
{
    mbedtls_net_free( &conn_ctx ); // shuts down and closes the socket
    return 0;
}
/**
 * Read up to `len` bytes of decrypted application data into `buffer`.
 *
 * @param buffer     destination, at least `len` bytes.
 * @param len        number of bytes wanted.
 * @param timeout_ms receive timeout installed with SO_RCVTIMEO. It applies to
 *                   each underlying recv(), so the total wait can exceed it when
 *                   data trickles in. Values below 1 are raised to 1 ms: paho
 *                   uses 0 for "don't block", but a zero SO_RCVTIMEO means
 *                   "block forever".
 * @return `len` on success; fewer bytes (possibly 0) if the timeout expired
 *         first; -1 on a TLS/socket error or when the peer closed the connection.
 */
int MQTTTransport_mbedTLS::read(tlsDataType * buffer, int len, int timeout_ms)
{
    if (timeout_ms < 1)
        timeout_ms = 1;
    struct timeval tv;
    tv.tv_sec = timeout_ms / 1000;
    tv.tv_usec = (timeout_ms % 1000) * 1000;
#if defined(ENABLE_TIMECHECKING)
    struct timeval tv1,tv2;
    gettimeofday(&tv1,NULL);
#endif
    setsockopt(conn_ctx.fd, SOL_SOCKET, SO_RCVTIMEO, (char *)&tv, sizeof(struct timeval));
    int bytes = 0;
    while (bytes < len)
    {
        int rc = mbedtls_ssl_read( &ssl, &buffer[bytes], (size_t)(len - bytes) );
        if (rc > 0)
        {
            bytes += rc;
        }
        else if (rc == MBEDTLS_ERR_NET_RECV_FAILED
                 || rc == MBEDTLS_ERR_SSL_WANT_READ
                 || rc == MBEDTLS_ERR_SSL_WANT_WRITE)
        {
            // No data right now. The socket is blocking with SO_RCVTIMEO, so an
            // expired timeout shows up as MBEDTLS_ERR_NET_RECV_FAILED (the old
            // magic -0x4c). NOTE(review): that code presumably also covers other
            // recv() failures. WANT_READ/WANT_WRITE (e.g. EINTR) mean "call again".
            // Keep the bytes already consumed: the old code reset `bytes` to 0
            // here and silently dropped data already taken off the TLS stream.
            break;
        }
        else if (rc == 0)
        {
            // mbedtls_ssl_read() returns 0 when the transport hit EOF. The peer
            // is gone, so report an error instead of "call again".
            bytes = -1;
            break;
        }
        else
        {
            // Any other error from mbedtls_ssl_read is terminal and the
            // connection must be closed.
            char errbuf[256];
            mbedtls_strerror(rc, errbuf, sizeof(errbuf));
            tlstrans_LOGERR("mbedtls_ssl_read returned error -0x%x: %.*s\n", (unsigned)-rc, (int)sizeof(errbuf), errbuf);
            bytes = -1;
            break;
        }
    }
#if defined(ENABLE_TIMECHECKING)
    gettimeofday(&tv2,NULL);
    float sec = ((tv2.tv_sec - tv1.tv_sec) * 1000.0 + (tv2.tv_usec - tv1.tv_usec) / 1000) / 1000;
    assert(!(bytes == 0 && sec < 0.001)); // This is a weird error with PINGRESP calling a 0-byte read.
#endif
    return bytes;
}
/**
 * Encrypt and send `len` bytes from `buffer`.
 *
 * @param timeout_ms send timeout installed with SO_SNDTIMEO. Values below 1 are
 *                   raised to 1 ms, because a zero SO_SNDTIMEO would block
 *                   forever. When timeout_ms == 0 this call keeps writing until
 *                   everything is sent. Otherwise it makes a single
 *                   mbedtls_ssl_write() call, which may write fewer bytes; the
 *                   caller loops over partial writes.
 * @return number of bytes written, or a negative mbedTLS error code.
 */
int MQTTTransport_mbedTLS::write(tlsDataType * buffer, int len, int timeout_ms)
{
    const int sock_timeout_ms = (timeout_ms < 1) ? 1 : timeout_ms;
    struct timeval tv;
    tv.tv_sec = sock_timeout_ms / 1000;
    tv.tv_usec = (sock_timeout_ms % 1000) * 1000;
    setsockopt(conn_ctx.fd, SOL_SOCKET, SO_SNDTIMEO, (char *)&tv, sizeof(struct timeval));
    if (timeout_ms != 0)
        return mbedtls_ssl_write( &ssl, buffer, len );
    // timeout_ms == 0: handle partial writes here.
    // ref: https://tls.mbed.org/api/ssl_8h.html#a5bbda87d484de82df730758b475f32e5
    int written = 0;
    while (written < len)
    {
        int rc = mbedtls_ssl_write( &ssl, buffer + written, (size_t)(len - written) );
        if (rc < 0)
            return rc;
        if (rc == 0)
            break; // no progress: report the short count instead of spinning
        written += rc;
    }
    // Total bytes sent. The old loop returned the size of the last chunk, so a
    // split write looked like a short write to the caller.
    return written;
}
/*******************************************************************************
* Copyright 2017 Andrew Domaszek
*
* All rights reserved.
* This program made available under BSD-new.
*******************************************************************************/
#pragma once
#ifndef __MQTTTRANSPORT_MBEDTLS_H_
#define __MQTTTRANSPORT_MBEDTLS_H_
#include "MQTTPacket.h"
#include <mbedtls/net_sockets.h>
#include <mbedtls/ssl.h>
#include <mbedtls/entropy.h>
#include <mbedtls/ctr_drbg.h>
#include <mbedtls/debug.h>
#if defined(ENABLE_POLYMORPHIC)
#define VIRTUAL_FCN virtual
#else
#define VIRTUAL_FCN
#endif
/**
 * MQTT transport that carries paho.mqtt.embedded-c traffic over TLS using mbedTLS.
 *
 * Provides connect/disconnect/read/write for use as the Network parameter of
 * MQTT::Client, plus the static tlsFetchData() callback, which the constructor
 * installs through the MQTTTransport base members.
 * Not copyable, because it owns mbedTLS contexts and a socket.
 */
class MQTTTransport_mbedTLS : public MQTTTransport
{
public:
// Initialises the mbedTLS contexts and loads the CA certificate file.
explicit MQTTTransport_mbedTLS();
// Frees the mbedTLS contexts.
VIRTUAL_FCN ~MQTTTransport_mbedTLS();
/* aed.20170531:
* I disabled copy construction because I don't know if it is safe for
* mbedtls contexts. */
MQTTTransport_mbedTLS( const MQTTTransport_mbedTLS& other ) = delete; // non construction-copyable
MQTTTransport_mbedTLS& operator=( const MQTTTransport_mbedTLS& ) = delete; // non copyable
typedef unsigned char tlsDataType;
// Callback: `void *` is the transport instance; reads with read_timeout_ms.
static int tlsFetchData(void *, tlsDataType *, int); /* must return -1 for error, 0 for call again, or the number of bytes read */
// TCP connect + TLS handshake. Returns 0 on success, otherwise the mbedTLS error code of the failing step.
VIRTUAL_FCN int connect(const char * hostname, int port);
// Closes the socket; returns 0.
VIRTUAL_FCN int disconnect();
// Reads up to len bytes; timeout_ms becomes the socket receive timeout. Returns a byte count, or -1 on error.
VIRTUAL_FCN int read(tlsDataType * buffer, int len, int timeout_ms = 0);
// Encrypts and sends len bytes; timeout_ms becomes the socket send timeout. Negative mbedTLS code on failure.
VIRTUAL_FCN int write(tlsDataType * buffer, int len, int timeout_ms = 0);
// Raw socket descriptor, e.g. for poll() in the example's USE_POLLING loop.
VIRTUAL_FCN int pollableFd() { return conn_ctx.fd; }
protected:
// Timeout tlsFetchData() passes to read(); the constructor sets it to 400.
int read_timeout_ms;
/* mbedtls connection information */
mbedtls_net_context conn_ctx;
mbedtls_entropy_context entropy;
mbedtls_ctr_drbg_context ctr_drbg;
mbedtls_ssl_context ssl;
mbedtls_ssl_config conf;
// CA chain used to verify the broker certificate.
mbedtls_x509_crt cacert;
};
#endif
/*******************************************************************************
* Copyright (c) 2012, 2013 IBM Corp.
*
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0
* and Eclipse Distribution License v1.0 which accompany this distribution.
*
* The Eclipse Public License is available at
* http://www.eclipse.org/legal/epl-v10.html
* and the Eclipse Distribution License is available at
* http://www.eclipse.org/org/documents/edl-v10.php.
*
* Contributors:
* Ian Craggs - initial contribution
* Ian Craggs - change delimiter option from char to string
* Andrew Domaszek - Modify to use mbedTLS in a naive way
*******************************************************************************/
/*
* build as (libraries must come after the sources that use them, or GNU ld
* will not resolve the symbols):
* g++ paho.mqtt-mbedtls.cpp MQTTTransport_mbedTLS.cpp -o paho.mqtt-mbedtls -lpaho-embed-mqtt3c -lmbedtls -lmbedx509 -lmbedcrypto
*/
/*
stdout subscriber
compulsory parameters:
topic to subscribe to
defaulted parameters:
--host localhost
--port 1883
--qos 2
--delimiter \n
--clientid stdout-subscriber
--username none
--password none
*/
#include <stdio.h>
#include <memory.h>
//#define MQTT_DEBUG 1
#include "MQTTClient.h"
#include "MQTTTransport_mbedTLS.h"
#define DEFAULT_STACK_SIZE -1
#include "linux.cpp"
#include <signal.h>
#include <sys/time.h>
#include <stdlib.h>
#include <poll.h>
// Set by the SIGINT/SIGTERM handler (cfinish) to make the main loop exit.
// A signal handler may only portably write to a volatile sig_atomic_t object.
volatile sig_atomic_t toStop = 0;
/** Print the command-line help to stdout, then terminate with exit(-1). */
void usage()
{
    static const char * const help_lines[] = {
        "MQTT stdout subscriber",
        "Usage: stdoutsub topicname <options>, where options are:",
        " --host <hostname> (default is localhost)",
        " --port <port> (default is 1883)",
        " --qos <qos> (default is 2)",
        " --delimiter <delim> (default is \\n)",
        " --clientid <clientid> (default is hostname+timestamp)",
        " --username none",
        " --password none",
        " --showtopics <on or off> (default is on if the topic has a wildcard, else off)",
    };
    // puts() appends the newline each original printf() format ended with.
    for (const char * text : help_lines)
        puts(text);
    exit(-1);
}
/**
 * SIGINT/SIGTERM handler: asks the main loop to stop. SIGINT is restored to
 * its default action, so a second Ctrl-C terminates the program immediately.
 */
void cfinish(int sig)
{
    (void)sig;
    // SIG_DFL is the portable way to request the default action; NULL only
    // happens to have the same value on some platforms.
    signal(SIGINT, SIG_DFL);
    toStop = 1;
}
/** Command-line settings: defaults below, overridden by getopts(). */
struct opts_struct
{
    char* clientid;    // MQTT client identifier
    int nodelimiter;   // non-zero: print payloads with no trailing delimiter
    char* delimiter;   // printed after each payload unless nodelimiter is set
    MQTT::QoS qos;     // QoS requested when subscribing
    char* username;    // NULL unless --username is given
    char* password;    // NULL unless --password is given
    char* host;        // broker host name
    int port;          // broker TCP port
    int showtopics;    // non-zero: prefix each message with its topic
} opts =
{
    const_cast<char*>("stdout-subscriber"), // clientid
    0,                                      // nodelimiter
    const_cast<char*>("\n"),                // delimiter
    MQTT::QOS2,                             // qos
    NULL,                                   // username
    NULL,                                   // password
    const_cast<char*>("localhost"),         // host
    1883,                                   // port
    0                                       // showtopics
};
/**
 * Parse the options that follow the topic (argv[2] onwards) into `opts`.
 *
 * Calls usage(), which exits, when a value is missing or when --qos, --port or
 * --showtopics gets an invalid value. A trailing --delimiter with no value sets
 * opts.nodelimiter. Unrecognised arguments are ignored.
 */
void getopts(int argc, char** argv)
{
    int count = 2;
    while (count < argc)
    {
        if (strcmp(argv[count], "--qos") == 0)
        {
            if (++count < argc)
            {
                if (strcmp(argv[count], "0") == 0)
                    opts.qos = MQTT::QOS0;
                else if (strcmp(argv[count], "1") == 0)
                    opts.qos = MQTT::QOS1;
                else if (strcmp(argv[count], "2") == 0)
                    opts.qos = MQTT::QOS2;
                else
                    usage();
            }
            else
                usage();
        }
        else if (strcmp(argv[count], "--host") == 0)
        {
            if (++count < argc)
                opts.host = argv[count];
            else
                usage();
        }
        else if (strcmp(argv[count], "--port") == 0)
        {
            if (++count < argc)
            {
                // atoi() silently turned garbage into port 0. Accept only a
                // whole decimal number in the TCP port range.
                char * end = NULL;
                long port = strtol(argv[count], &end, 10);
                if (end == argv[count] || *end != '\0' || port < 1 || port > 65535)
                    usage();
                opts.port = (int)port;
            }
            else
                usage();
        }
        else if (strcmp(argv[count], "--clientid") == 0)
        {
            if (++count < argc)
                opts.clientid = argv[count];
            else
                usage();
        }
        else if (strcmp(argv[count], "--username") == 0)
        {
            if (++count < argc)
                opts.username = argv[count];
            else
                usage();
        }
        else if (strcmp(argv[count], "--password") == 0)
        {
            if (++count < argc)
                opts.password = argv[count];
            else
                usage();
        }
        else if (strcmp(argv[count], "--delimiter") == 0)
        {
            if (++count < argc)
                opts.delimiter = argv[count];
            else
                opts.nodelimiter = 1;
        }
        else if (strcmp(argv[count], "--showtopics") == 0)
        {
            if (++count < argc)
            {
                if (strcmp(argv[count], "on") == 0)
                    opts.showtopics = 1;
                else if (strcmp(argv[count], "off") == 0)
                    opts.showtopics = 0;
                else
                    usage();
            }
            else
                usage();
        }
        count++;
    }
}
/**
 * Open the transport to opts.host:opts.port, then send MQTT CONNECT with `data`.
 * Exits the process with status -1 if either step fails.
 */
template <class Network>
void myconnect(Network& ipstack, MQTT::Client<Network, Countdown, 1000>& client, MQTTPacket_connectData& data)
{
    printf("Connecting to %s:%d\n", opts.host, opts.port);
    int rc = ipstack.connect(opts.host, opts.port);
    if (rc != 0)
    {
        printf("rc from TCP connect is %d\n", rc);
        // Without a transport there is nothing to send CONNECT over; the old
        // code carried on and failed later with a less helpful message.
        exit(-1);
    }
    rc = client.connect(data);
    if (rc != 0)
    {
        printf("Failed to connect, return code %d\n", rc);
        exit(-1);
    }
    printf("Connected\n");
}
/**
 * Subscription callback. Prints the topic and a tab when opts.showtopics is
 * set, then the payload followed by opts.delimiter (unless opts.nodelimiter is
 * set), then flushes stdout.
 */
void messageArrived(MQTT::MessageData& md)
{
    MQTT::Message & msg = md.message;
    const int payload_len = (int)msg.payloadlen;
    const char * const payload = (char*)msg.payload;
    if (opts.showtopics)
        printf("%.*s\t", md.topicName.lenstring.len, md.topicName.lenstring.data);
    const char * const suffix = opts.nodelimiter ? "" : opts.delimiter;
    printf("%.*s%s", payload_len, payload, suffix);
    fflush(stdout);
}
/**
 * Subscribe to argv[1] over MQTT-over-TLS and print every message received
 * until SIGINT or SIGTERM is delivered. Options are described by usage().
 */
int main(int argc, char** argv)
{
    int rc = 0;
    if (argc < 2)
        usage();
    const char* topic = argv[1];
    getopts(argc, argv);
    // Wildcard subscriptions can match several topics, so show which one matched.
    if (strchr(topic, '#') || strchr(topic, '+'))
        opts.showtopics = 1;
    if (opts.showtopics)
        printf("topic is %s\n", topic);
    // Plain-TCP alternative: IPStack from linux.cpp.
    MQTTTransport_mbedTLS ipstack;
    auto client = MQTT::Client<MQTTTransport_mbedTLS, Countdown, 1000>(ipstack);
    signal(SIGINT, cfinish);
    signal(SIGTERM, cfinish);
    MQTTPacket_connectData data = MQTTPacket_connectData_initializer;
    data.willFlag = 0;
    data.MQTTVersion = 3;
    data.clientID.cstring = opts.clientid;
    data.username.cstring = opts.username;
    data.password.cstring = opts.password;
    data.keepAliveInterval = 20;
    data.cleansession = 1;
    printf("will flag %d\n", data.willFlag);
    myconnect(ipstack, client, data);
    rc = client.subscribe(topic, opts.qos, messageArrived);
    printf("Subscribed[%d] to %s\n", rc, topic);
    while (!toStop)
    {
#if defined(USE_POLLING)
        client.yield(10);
        // Wait up to 1 s for the socket to become readable. The result is not
        // needed: yield() does the actual reading on the next iteration.
        struct pollfd pfds[1] = { { ipstack.pollableFd(), POLLIN, 0 } };
        (void)poll(pfds, 1, 1000);
#else
        client.yield(1000);
#endif
        //if (!client.isconnected)
        //    myconnect(ipstack, client, data);
    }
    printf("Stopping\n");
    client.disconnect();
    ipstack.disconnect();
    return 0;
}
@brimston3
Copy link
Author

There are a few bugs in this implementation:

  1. both tls::read() and tls::write() need to detect if timeout_ms == 0, and if so, set it to 1 before calculating tv. paho.mqtt.embedded-c uses a timeout_ms == 0 to indicate non-blocking, but SO_*TIMEO uses it to mean "wait forever." It could be used instead to toggle fcntl() O_NONBLOCK but it makes more sense to me to just use a very short timeout and save the extra call every time.
  2. tls::write() is overly complex in handling partial writes. paho.mqtt.embedded-c has similar logic in it to do the same thing. I've left it alone, but I noticed the redundancy.
  3. in tls::connect(), when reusing the tls object after disconnect, mbedtls_ssl_session_reset( &ssl ) must be called before mbedtls_ssl_handshake() or it will fail with error -0x50. It is (apparently) safe to call *session_reset on every connect attempt.

@brimston3
Copy link
Author

Also, with USE_POLLING, it's extremely easy to hit paho.mqtt.embedded-c issue #115. The shorter the timeout, the more likely it is to hit that bug.

@wuhaogs
Copy link

wuhaogs commented Apr 8, 2023

对我很有帮助,感谢您,请问可以提供一个C语言版本吗?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment