Skip to content

Instantly share code, notes, and snippets.

@tenderlove
Created July 6, 2015 21:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tenderlove/b19bdea0d98fd1c0b655 to your computer and use it in GitHub Desktop.
Save tenderlove/b19bdea0d98fd1c0b655 to your computer and use it in GitHub Desktop.
diff --git a/ext/openssl/ossl_ssl.c b/ext/openssl/ossl_ssl.c
index 4496d46..c6b51af 100644
--- a/ext/openssl/ossl_ssl.c
+++ b/ext/openssl/ossl_ssl.c
@@ -645,6 +645,32 @@ ssl_npn_select_cb(SSL *s, unsigned char **out, unsigned char *outlen, const unsi
return SSL_TLSEXT_ERR_OK;
}
+
+/*
+ * OpenSSL ALPN selection callback, installed by ossl_sslctx_setup when
+ * @alpn_select_cb is set on the SSLContext (passed here as +arg+).
+ *
+ * Decodes the client's offered protocol list into an Array of Strings,
+ * passes it to @alpn_select_cb, and hands the returned String back to
+ * OpenSSL as the selected protocol.
+ *
+ * Returns SSL_TLSEXT_ERR_OK on success, SSL_TLSEXT_ERR_NOACK when the
+ * callback returns nil (no protocol selected), and
+ * SSL_TLSEXT_ERR_ALERT_FATAL when the returned name is not 1..255 bytes.
+ *
+ * NOTE(review): an exception raised by the Ruby callback (or by StringValue)
+ * longjmps through OpenSSL's stack; wrapping the call in rb_protect and
+ * re-raising after SSL_accept returns would avoid that.
+ */
+static int
+ssl_alpn_select_cb(SSL *ssl, const unsigned char **out, unsigned char *outlen, const unsigned char *in, unsigned int inlen, void *arg)
+{
+    unsigned int i = 0;
+    long len;
+    VALUE sslctx_obj, cb, protocols, selected;
+
+    sslctx_obj = (VALUE) arg;
+    cb = rb_iv_get(sslctx_obj, "@alpn_select_cb");
+    protocols = rb_ary_new();
+
+    /*
+     * The format is len_1|proto_1|...|len_n|proto_n and is exactly inlen
+     * bytes long. It is NOT NUL-terminated, so walk it by inlen instead of
+     * stopping at a zero byte (which would read past the end of +in+).
+     */
+    while (i < inlen) {
+        unsigned int plen = in[i];
+
+        /* Stop at a truncated entry rather than reading out of bounds. */
+        if (plen > inlen - i - 1)
+            break;
+        rb_ary_push(protocols, rb_str_new((const char *) &in[i + 1], plen));
+        i += plen + 1;
+    }
+
+    selected = rb_funcall(cb, rb_intern("call"), 1, protocols);
+
+    /* nil means "no protocol": let the handshake continue without ALPN. */
+    if (NIL_P(selected))
+        return SSL_TLSEXT_ERR_NOACK;
+
+    StringValue(selected);
+    len = RSTRING_LEN(selected);
+
+    /*
+     * *outlen is an unsigned char, so a protocol name must be 1..255 bytes;
+     * refuse the handshake instead of silently truncating the length.
+     */
+    if (len < 1 || len > 255)
+        return SSL_TLSEXT_ERR_ALERT_FATAL;
+
+    /*
+     * NOTE(review): *out points into the Ruby string +selected+, which nothing
+     * keeps referenced after this callback returns; this relies on OpenSSL
+     * copying the selection before the GC can run — confirm for supported
+     * OpenSSL versions.
+     */
+    *out = (const unsigned char *) RSTRING_PTR(selected);
+    *outlen = (unsigned char) len;
+
+    return SSL_TLSEXT_ERR_OK;
+}
+
#endif
/* This function may serve as the entry point to support further
@@ -654,6 +680,10 @@ ssl_info_cb(const SSL *ssl, int where, int val)
{
int state = SSL_state(ssl);
+ if (where & SSL_CB_ALERT) {
+ printf("wtf: %s\n", SSL_alert_desc_string(val));
+ }
+
if ((where & SSL_CB_HANDSHAKE_START) &&
(state & SSL_ST_ACCEPT)) {
ssl_renegotiation_cb(ssl);
@@ -789,6 +819,10 @@ ossl_sslctx_setup(VALUE self)
SSL_CTX_set_next_proto_select_cb(ctx, ssl_npn_select_cb, (void *) self);
OSSL_Debug("SSL NPN select callback added");
}
+ if (RTEST(rb_iv_get(self, "@alpn_select_cb"))) {
+ SSL_CTX_set_alpn_select_cb(ctx, ssl_alpn_select_cb, (void *) self);
+ OSSL_Debug("SSL ALPN select callback added");
+ }
#endif
rb_obj_freeze(self);
@@ -1425,6 +1459,43 @@ ossl_ssl_accept_nonblock(int argc, VALUE *argv, VALUE self)
return ossl_start_ssl(self, SSL_accept, "SSL_accept", 1, no_exception);
}
+
+/*
+ * call-seq:
+ *    ssl.accept_state => nil
+ *
+ * Switches the underlying SSL object into server (accept) mode with
+ * SSL_set_accept_state(), so that a later #handshake drives the server side
+ * of the TLS handshake. No I/O is performed here.
+ */
+static VALUE
+ossl_ssl_accept_state(VALUE self)
+{
+    SSL *ssl;
+
+    GetSSL(self, ssl);
+    SSL_set_accept_state(ssl);
+    return Qnil;
+}
+
+/*
+ * call-seq:
+ *    ssl.handshake => true, nil, :wait_readable or :wait_writable
+ *
+ * Performs one step of the TLS handshake with SSL_do_handshake(), reporting
+ * would-block conditions as symbols instead of raising.
+ *
+ * Returns true once the handshake has completed, :wait_readable or
+ * :wait_writable when it must be retried after the socket becomes ready,
+ * and nil when the peer closed the connection (SSL_ERROR_ZERO_RETURN, or an
+ * EOF reported as SSL_ERROR_SYSCALL with an empty OpenSSL error queue).
+ * Raises SSLError for any other failure.
+ */
+static VALUE
+ossl_ssl_handshake(VALUE self)
+{
+    SSL *ssl;
+    int retval;
+
+    GetSSL(self, ssl);
+    retval = SSL_do_handshake(ssl);
+
+    switch (ssl_get_error(ssl, retval)) {
+      case SSL_ERROR_NONE:
+        return Qtrue;
+      case SSL_ERROR_ZERO_RETURN:
+        return Qnil;
+      case SSL_ERROR_WANT_WRITE:
+        return sym_wait_writable;
+      case SSL_ERROR_WANT_READ:
+        return sym_wait_readable;
+      case SSL_ERROR_SYSCALL:
+        /* EOF from the peer with nothing queued: treat it like a close. */
+        if (ERR_peek_error() == 0 && retval == 0) {
+            return Qnil;
+        }
+        /* fallthrough: any other syscall failure is raised below */
+      default:
+        ossl_raise(eSSLError, "SSL_do_handshake");
+    }
+}
+
static VALUE
ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock)
{
@@ -1462,6 +1533,7 @@ ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock)
case SSL_ERROR_NONE:
goto end;
case SSL_ERROR_ZERO_RETURN:
+ printf("ZERO RETURN\n");
if (no_exception) { return Qnil; }
rb_eof_error();
case SSL_ERROR_WANT_WRITE:
@@ -1909,6 +1981,29 @@ ossl_ssl_npn_protocol(VALUE self)
else
return rb_str_new((const char *) out, outlen);
}
+
+/*
+ * call-seq:
+ *    ssl.alpn_protocol => String or nil
+ *
+ * Returns the ALPN protocol string negotiated during the handshake, or nil
+ * if no protocol was negotiated. With ALPN it is the server that picks the
+ * protocol from the list offered by the client.
+ *
+ * NOTE(review): SSL_get0_alpn_selected() requires OpenSSL >= 1.0.2, but this
+ * method is registered under the NPN feature guard; a dedicated ALPN
+ * configure check may be needed — confirm.
+ */
+static VALUE
+ossl_ssl_alpn_protocol(VALUE self)
+{
+    SSL *ssl;
+    const unsigned char *out;
+    unsigned int outlen;
+
+    ossl_ssl_data_get_struct(self, ssl);
+
+    SSL_get0_alpn_selected(ssl, &out, &outlen);
+    if (outlen == 0)
+        return Qnil; /* nothing was selected */
+
+    return rb_str_new((const char *) out, outlen);
+}
# endif
#endif /* !defined(OPENSSL_NO_SOCK) */
@@ -2159,6 +2254,7 @@ Init_ossl_ssl(void)
*/
rb_attr(cSSLContext, rb_intern("npn_select_cb"), 1, 1, Qfalse);
#endif
+ rb_attr(cSSLContext, rb_intern("alpn_select_cb"), 1, 1, Qfalse);
rb_define_alias(cSSLContext, "ssl_timeout", "timeout");
rb_define_alias(cSSLContext, "ssl_timeout=", "timeout=");
@@ -2254,6 +2350,8 @@ Init_ossl_ssl(void)
rb_define_method(cSSLSocket, "connect", ossl_ssl_connect, 0);
rb_define_method(cSSLSocket, "connect_nonblock", ossl_ssl_connect_nonblock, -1);
rb_define_method(cSSLSocket, "accept", ossl_ssl_accept, 0);
+ rb_define_method(cSSLSocket, "accept_state", ossl_ssl_accept_state, 0);
+ rb_define_method(cSSLSocket, "handshake", ossl_ssl_handshake, 0);
rb_define_method(cSSLSocket, "accept_nonblock", ossl_ssl_accept_nonblock, -1);
rb_define_method(cSSLSocket, "sysread", ossl_ssl_read, -1);
rb_define_private_method(cSSLSocket, "sysread_nonblock", ossl_ssl_read_nonblock, -1);
@@ -2274,6 +2372,7 @@ Init_ossl_ssl(void)
rb_define_method(cSSLSocket, "client_ca", ossl_ssl_get_client_ca_list, 0);
# ifdef HAVE_OPENSSL_NPN_NEGOTIATED
rb_define_method(cSSLSocket, "npn_protocol", ossl_ssl_npn_protocol, 0);
+ rb_define_method(cSSLSocket, "alpn_protocol", ossl_ssl_alpn_protocol, 0);
# endif
#endif
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment