Skip to content

Instantly share code, notes, and snippets.

@myers
Created February 9, 2012 05:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save myers/1777621 to your computer and use it in GitHub Desktop.
Save myers/1777621 to your computer and use it in GitHub Desktop.
=== modified file 'OpenSSL/ssl/context.c'
--- OpenSSL/ssl/context.c 2011-09-11 13:35:32 +0000
+++ OpenSSL/ssl/context.c 2012-02-09 05:35:09 +0000
@@ -238,6 +238,66 @@
}
/*
+ * Globally defined next proto callback. This is called from OpenSSL internally.
+ * The GIL will not be held when this function is invoked. It must not be held
+ * when the function returns.
+ *
+ * Arguments: ssl - The Connection
+ * **out - handle to where we can put a new string to be returned
+ * *outlen - pointer to where we can put the len of the string we want to return
+ * Returns: SSL_TLSEXT_ERR_OK
+ */
+static int
+global_next_proto_callback(SSL *ssl, const unsigned char **out, unsigned int *outlen, void *arg)
+{
+ ssl_ConnectionObj *conn = (ssl_ConnectionObj *)SSL_get_app_data(ssl);
+ PyObject *argv, *ret, *item;
+ Py_ssize_t length, i, strlength;
+ unsigned char *outptr;
+
+ /*
+ * GIL isn't held yet. First things first - acquire it, or any Python API
+ * we invoke might segfault or blow up the sun. The reverse will be done
+ * before returning.
+ */
+ MY_END_ALLOW_THREADS(conn->tstate);
+
+ argv = Py_BuildValue("(O)", (PyObject *)conn);
+ ret = PyEval_CallObject(conn->context->next_protos_advertised_callback, argv);
+ Py_DECREF(argv);
+
+ length = PyList_Size(ret);
+ for (i = 0; i < length; i++) {
+ *outlen += PyBytes_Size(PyList_GetItem(ret, i)) + 1;
+ /* FIXME: need to test that all strings are less than 255 */
+ }
+ *out = outptr = OPENSSL_malloc(*outlen);
+ for (i = 0; i < length; i++) {
+ item = PyList_GetItem(ret, i);
+ strlength = PyBytes_Size(item);
+ *outptr = Py_SAFE_DOWNCAST(strlength, Py_ssize_t, char);
+ outptr++;
+ strncpy((char *)outptr, PyString_AsString(item), strlength);
+ outptr += strlength;
+ }
+
+ if (ret == NULL) {
+ /*
+ * XXX - This should be reported somehow. -exarkun
+ */
+ PyErr_Clear();
+ } else {
+ Py_DECREF(ret);
+ }
+
+ /*
+ * This function is returning into OpenSSL. Release the GIL again.
+ */
+ MY_BEGIN_ALLOW_THREADS(conn->tstate);
+ return SSL_TLSEXT_ERR_OK;
+}
+
+/*
* Globally defined TLS extension server name callback. This is called from
* OpenSSL internally. The GIL will not be held when this function is invoked.
* It must not be held when the function returns.
@@ -1030,6 +1090,35 @@
return Py_None;
}
+static char ssl_Context_set_next_protos_advertised_callback_doc[] = "\n\
+Set the callback a server uses to produce the protocols it advertises\n\
+for Next Protocol Negotiation (NPN)\n\
+\n\
+:param callback: A callable which is invoked with the server-side\n\
+                 Connection and must return a list of byte strings\n\
+                 naming the protocols to advertise\n\
+:return: None\n\
+";
+/*
+ * Python: Context.set_next_protos_advertised_callback(callback)
+ *
+ * Stores callback on the context and installs global_next_proto_callback
+ * as this SSL_CTX's NPN advertisement hook.  Raises TypeError if callback
+ * is not callable.
+ */
+static PyObject *
+ssl_Context_set_next_protos_advertised_callback(ssl_ContextObj *self, PyObject *args)
+{
+    PyObject *callback, *old;
+
+    if (!PyArg_ParseTuple(args, "O:set_next_protos_advertised_callback", &callback))
+        return NULL;
+
+    if (!PyCallable_Check(callback))
+    {
+        PyErr_SetString(PyExc_TypeError, "expected PyCallable");
+        return NULL;
+    }
+
+    /*
+     * Store the new callback before releasing the old one: dropping the last
+     * reference can run arbitrary Python code (__del__, weakref callbacks),
+     * which must never observe this context holding a freed object.  The
+     * slot may be NULL after tp_clear, hence Py_XDECREF.
+     */
+    old = self->next_protos_advertised_callback;
+    Py_INCREF(callback);
+    self->next_protos_advertised_callback = callback;
+    Py_XDECREF(old);
+
+    SSL_CTX_set_next_protos_advertised_cb(self->ctx, global_next_proto_callback, NULL);
+
+    Py_INCREF(Py_None);
+    return Py_None;
+}
+
static char ssl_Context_get_app_data_doc[] = "\n\
Get the application data (supplied via set_app_data())\n\
\n\
@@ -1187,6 +1276,7 @@
ADD_METHOD(set_timeout),
ADD_METHOD(get_timeout),
ADD_METHOD(set_info_callback),
+ ADD_METHOD(set_next_protos_advertised_callback),
ADD_METHOD(get_app_data),
ADD_METHOD(set_app_data),
ADD_METHOD(get_cert_store),
@@ -1241,6 +1331,9 @@
self->info_callback = Py_None;
Py_INCREF(Py_None);
+ self->next_protos_advertised_callback = Py_None;
+
+ Py_INCREF(Py_None);
self->tlsext_servername_callback = Py_None;
Py_INCREF(Py_None);
@@ -1320,6 +1413,8 @@
ret = visit((PyObject *)self->verify_callback, arg);
if (ret == 0 && self->info_callback != NULL)
ret = visit((PyObject *)self->info_callback, arg);
+ if (ret == 0 && self->next_protos_advertised_callback != NULL)
+ ret = visit((PyObject *)self->next_protos_advertised_callback, arg);
if (ret == 0 && self->app_data != NULL)
ret = visit(self->app_data, arg);
return ret;
@@ -1342,6 +1437,8 @@
self->verify_callback = NULL;
Py_XDECREF(self->info_callback);
self->info_callback = NULL;
+ Py_XDECREF(self->next_protos_advertised_callback);
+ self->next_protos_advertised_callback = NULL;
Py_XDECREF(self->app_data);
self->app_data = NULL;
return 0;
=== modified file 'OpenSSL/ssl/context.h'
--- OpenSSL/ssl/context.h 2011-05-26 22:47:00 +0000
+++ OpenSSL/ssl/context.h 2012-02-09 05:09:32 +0000
@@ -29,6 +29,7 @@
*passphrase_userdata,
*verify_callback,
*info_callback,
+ *next_protos_advertised_callback,
*tlsext_servername_callback,
*app_data;
PyThreadState *tstate;
=== modified file 'OpenSSL/test/test_ssl.py'
--- OpenSSL/test/test_ssl.py 2011-09-11 14:01:31 +0000
+++ OpenSSL/test/test_ssl.py 2012-02-09 05:13:06 +0000
@@ -587,6 +587,41 @@
# Kind of lame. Just make sure it got called somehow.
self.assertTrue(called)
+ def FAIL_test_set_next_protos_advertised_callback_callback(self):
+ """
+ :py:obj:`Context.set_info_callback` accepts a callable which will be invoked
+ when certain information about an SSL connection is available.
+ """
+ (server, client) = socket_pair()
+
+ clientSSL = Connection(Context(SSLv3_METHOD), client)
+ clientSSL.set_connect_state()
+
+ called = []
+ def next_protos_advertised(conn):
+ called.append(conn)
+ return ['spdy/2', 'http/1.1']
+ context = Context(SSLv3_METHOD)
+ context.set_next_protos_advertised_callback(next_protos_advertised)
+ context.use_certificate(
+ load_certificate(FILETYPE_PEM, cleartextCertificatePEM))
+ context.use_privatekey(
+ load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM))
+
+ serverSSL = Connection(context, server)
+ serverSSL.set_accept_state()
+
+ while not called:
+ for ssl in clientSSL, serverSSL:
+ try:
+ ssl.do_handshake()
+ except WantReadError:
+ print 'r'
+ pass
+
+ # Kind of lame. Just make sure it got called somehow.
+ self.assertTrue(called)
+
def _load_verify_locations_test(self, *args):
"""
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment