Instantly share code, notes, and snippets.

@ayende ayende/openssl.rs
Created Jan 4, 2019

Embed
What would you like to do?
extern crate custom_error;
extern crate memmem;
extern crate openssl;
extern crate openssl_sys;
extern crate hex;
extern crate foreign_types_shared;
#[macro_use] extern crate lazy_static;
use std::io;
use std::io::Write;
use std::sync::Arc;
use std::collections::HashMap;
use std::collections::HashSet;
use custom_error::custom_error;
use std::net::{TcpListener, TcpStream};
use memmem::{Searcher, TwoWaySearcher};
use foreign_types_shared::ForeignTypeRef;
// Error type shared by every fallible function in this file.
//
// Variants that carry a `source` field wrap an underlying error; the rest of
// the file relies on `?` to convert those underlying errors (I/O, UTF-8,
// OpenSSL, TLS handshake) into `ConnectionError`. The remaining variants are
// raised directly by this file:
//   * `Parse`                 - a command line or header line could not be parsed
//   * `MessageTooBig`         - a message grew past the 8KB limit while being read
//   * `InvalidTimeFormat`     - comparing certificate times failed
//   * `ClientCertExpired`     - the peer certificate's not-after time has passed
//   * `ClientCertNotYetValid` - the peer certificate's not-before time is in the future
custom_error! {
ConnectionError
Io{source: io::Error} = "unable to read from the network",
Utf8{source: std::str::Utf8Error} = "Invalid UTF8 character sequence",
Parse{origin: String} = "Unable to parse command: {origin}",
MessageTooBig = "Message length was over 8KB",
SslIssue{source : openssl::error::ErrorStack} = "OpenSSL error {source}",
Handshake{source: openssl::ssl::HandshakeError<std::net::TcpStream>} = "Handshake error {source}",
InvalidTimeFormat = "Unable to understand certificate time",
ClientCertExpired{date: String} = "The client certificate has already expired: {date}",
ClientCertNotYetValid{date: String} = "The client certificate is not yet valid: {date}"
}
impl ConnectionError {
    /// Builds a `Parse` error that remembers the text which could not be parsed.
    ///
    /// `origin` is copied into the error so the error does not borrow from the
    /// input buffer.
    fn parsing(origin: &str) -> ConnectionError {
        let origin = String::from(origin);
        ConnectionError::Parse { origin }
    }
}
/// A parsed client command; every string borrows from the original message text.
struct Cmd<'a> {
/// The first line of the message, split on single spaces.
args: Vec<&'a str>,
/// `name: value` pairs from the remaining lines, with both sides trimmed.
headers: HashMap<&'a str, &'a str>,
}
// Lazily-built searcher for the byte sequence "\r\n\r\n".
// `read_full_message` uses it to locate the end of a message in the receive buffer.
lazy_static! {
static ref msg_break : TwoWaySearcher<'static> = {
TwoWaySearcher::new(b"\r\n\r\n")
};
}
/// Everything needed to accept and authenticate TLS connections.
struct Server {
/// TLS acceptor configured with the server certificate and private key.
tls_config: Arc<openssl::ssl::SslAcceptor>,
/// Listening socket on which incoming TCP connections arrive.
tcp_listener: TcpListener,
/// Lower-cased thumbprints of the client certificates that may connect
/// (compared against the hex-encoded SHA-1 digest of the peer certificate).
allowed_certs_thumbprints: HashSet<String>
}
impl Server {
    /// Creates a server that listens on `listen_uri` and speaks TLS using the
    /// certificate chain at `cert_path` and the PEM private key at `key_path`.
    ///
    /// `allowed_certs_thumbprints` lists the client certificate thumbprints
    /// that are allowed to connect; they are stored lower-cased.
    ///
    /// # Errors
    /// Returns a `ConnectionError` if the TLS configuration cannot be built,
    /// the key or certificate cannot be loaded, the key does not match the
    /// certificate, or the listening socket cannot be bound.
    fn new(cert_path: &str, key_path: &str, listen_uri: &str, allowed_certs_thumbprints: &[&str]) -> Result<Server, ConnectionError> {
        let allowed: HashSet<String> = allowed_certs_thumbprints
            .iter()
            .map(|thumbprint| thumbprint.to_lowercase())
            .collect();

        let mut builder = openssl::ssl::SslAcceptor::mozilla_modern(openssl::ssl::SslMethod::tls())?;
        builder.set_private_key_file(key_path, openssl::ssl::SslFiletype::PEM)?;
        builder.set_certificate_chain_file(cert_path)?;
        builder.check_private_key()?;
        // Let every client certificate through the TLS layer; the thumbprint
        // and validity checks are performed by this program after the handshake.
        builder.set_verify_callback(openssl::ssl::SslVerifyMode::PEER, |_, _| true);

        let tcp_listener = TcpListener::bind(listen_uri)?;
        let tls_config = Arc::new(builder.build());
        Ok(Server {
            tls_config,
            tcp_listener,
            allowed_certs_thumbprints: allowed,
        })
    }
}
/// Parses one message (without its trailing blank-line terminator) into a `Cmd`.
///
/// The first line is split on single spaces to form the arguments. Every
/// following line must contain a `:`; the text before the first `:` becomes
/// the header name and the text after it the value, both trimmed.
///
/// # Errors
/// Returns `ConnectionError::Parse` when `cmd_str` has no lines at all, or
/// when a header line does not contain a `:`.
fn parse_cmd<'a>(cmd_str: &'a str) -> Result<Cmd, ConnectionError> {
    let mut remaining = cmd_str.lines();
    let first_line = remaining
        .next()
        .ok_or_else(|| ConnectionError::parsing(cmd_str))?;

    let args: Vec<&str> = first_line.split(' ').collect();
    let mut headers = HashMap::new();

    for header_line in remaining {
        let mut pieces = header_line.splitn(2, ':');
        match (pieces.next(), pieces.next()) {
            (Some(name), Some(value)) => {
                headers.insert(name.trim(), value.trim());
            }
            // No ':' on the line, so it is not a header.
            _ => return Err(ConnectionError::parsing(header_line)),
        }
    }

    Ok(Cmd { args, headers })
}
/// Reads from `stream` into `buffer` until `buffer` contains a complete
/// message, i.e. one terminated by `\r\n\r\n`.
///
/// Any bytes already in `buffer` (left over from a previous read) are
/// searched first, so no read is issued when a full message is already
/// buffered. The returned slice starts at the beginning of `buffer` and
/// includes the 4-byte terminator; bytes after it stay in `buffer`.
///
/// # Errors
/// * `ConnectionError::Io` if the read fails, or if the stream reaches end of
///   file before a complete message has arrived.
/// * `ConnectionError::MessageTooBig` if buffering the next chunk would push
///   `buffer` past 8192 bytes.
fn read_full_message<'a>(stream: &mut dyn io::Read, buffer: &'a mut Vec<u8>) -> Result<&'a [u8], ConnectionError> {
    let mut to_scan = 0;
    let mut tmp_buf = [0; 256];
    loop {
        match msg_break.search_in(&buffer[to_scan..]) {
            // Not found: next time resume 3 bytes before the current end so a
            // terminator split across two reads is still detected.
            None => to_scan = buffer.len().saturating_sub(3),
            Some(msg_end) => return Ok(&buffer[0..(to_scan + msg_end + 4)]),
        }
        let read = stream.read(&mut tmp_buf)?;
        if read == 0 {
            // A zero-length read means the peer closed the connection. Without
            // this check the loop would spin forever on an unchanged buffer.
            return Err(ConnectionError::Io {
                source: io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "connection closed before a complete message was received",
                ),
            });
        }
        if read + buffer.len() > 8192 {
            return Err(ConnectionError::MessageTooBig);
        }
        buffer.extend_from_slice(&tmp_buf[..read]);
    }
}
/// Handles a parsed command by echoing its first argument back to the client
/// and flushing the stream.
///
/// # Errors
/// Returns the underlying I/O error if writing or flushing fails.
///
/// # Panics
/// Panics if `cmd.args` is empty. `parse_cmd` always yields at least one
/// argument because splitting a line always produces at least one piece.
fn dispatch_cmd<S>(stream: &mut S, cmd: Cmd) -> io::Result<()>
where
    S: io::Write + io::Read,
{
    // `write_all` rather than `write`: a single `write` may send only part of
    // the buffer and the number of bytes written was previously discarded.
    stream.write_all(cmd.args[0].as_bytes())?;
    stream.flush()?;
    Ok(())
}
/// Decides whether the client on `stream` may use the server, based on the
/// certificate it presented during the TLS handshake.
///
/// The client is accepted when it presented a certificate whose hex-encoded
/// SHA-1 digest is in `server.allowed_certs_thumbprints` and whose validity
/// period contains the current time.
///
/// Returns `Ok(true)` when the client is accepted. When it is rejected, an
/// `ERR ...\r\n` line describing the reason is written to `stream` and
/// `Ok(false)` is returned.
///
/// # Errors
/// Returns a `ConnectionError` if writing the rejection line fails or if
/// computing the certificate digest fails.
fn authenticate_certificate(stream: &mut openssl::ssl::SslStream<TcpStream>, server: &Server) -> Result<bool, ConnectionError> {
    /// Returns a human-readable label for the certificate: the UTF-8 text of
    /// the last entry of its subject name, an empty string if that entry is
    /// not convertible to UTF-8, or `<Unknown>` if the subject has no entries.
    fn get_friendly_name(peer: &openssl::x509::X509) -> String {
        peer.subject_name() // can't figure out how to get the real friendly name
            .entries()
            .last()
            .map(|entry| {
                entry
                    .data()
                    .as_utf8()
                    .map(|text| text.to_string())
                    .unwrap_or_else(|_| String::new())
            })
            .unwrap_or_else(|| "<Unknown>".to_string())
    }

    extern "C" {
        // OpenSSL routine used below to compare two ASN.1 times; a return
        // value of 0 is treated as failure.
        fn ASN1_TIME_diff(
            pday: *mut std::os::raw::c_int,
            psec: *mut std::os::raw::c_int,
            from: *const openssl_sys::ASN1_TIME,
            to: *const openssl_sys::ASN1_TIME) -> std::os::raw::c_int;
    }

    /// Returns `Ok(true)` when `x` is strictly earlier than `y`, i.e. when
    /// the day or second difference going from `x` to `y` is positive.
    fn is_before(x: &openssl::asn1::Asn1TimeRef, y: &openssl::asn1::Asn1TimeRef) -> Result<bool, ConnectionError> {
        let mut day: std::os::raw::c_int = 0;
        let mut sec: std::os::raw::c_int = 0;
        // SAFETY: `x` and `y` are live references, so the pointers obtained
        // from them are valid for the duration of the call; `day` and `sec`
        // are local integers that outlive the call.
        let status = unsafe { ASN1_TIME_diff(&mut day, &mut sec, x.as_ptr(), y.as_ptr()) };
        if status == 0 {
            return Err(ConnectionError::InvalidTimeFormat);
        }
        Ok(day > 0 || sec > 0)
    }

    /// Checks that the current time lies within the certificate's
    /// not-before / not-after validity window.
    fn is_valid_time(peer: &openssl::x509::X509) -> Result<(), ConnectionError> {
        let now = openssl::asn1::Asn1Time::days_from_now(0)?;
        if is_before(&now, peer.not_before())? {
            return Err(ConnectionError::ClientCertNotYetValid { date: peer.not_before().to_string() });
        }
        if is_before(peer.not_after(), &now)? {
            return Err(ConnectionError::ClientCertExpired { date: peer.not_after().to_string() });
        }
        Ok(())
    }

    // All rejection lines are sent with `write_all`: a plain `write` may send
    // only part of the message and its byte count was previously ignored.
    let peer = match stream.ssl().peer_certificate() {
        Some(peer) => peer,
        None => {
            stream.write_all(b"ERR No certificate was provided\r\n")?;
            return Ok(false);
        }
    };

    let thumbprint = hex::encode(peer.digest(openssl::hash::MessageDigest::sha1())?);
    if !server.allowed_certs_thumbprints.contains(&thumbprint) {
        let msg = format!("ERR certificate ({}) thumbprint '{}' is unknown\r\n",
            get_friendly_name(&peer),
            thumbprint);
        stream.write_all(msg.as_bytes())?;
        return Ok(false);
    }

    if let Err(e) = is_valid_time(&peer) {
        let msg = format!("ERR certificate ({}) thumbprint '{}' cannot be used: {}\r\n",
            get_friendly_name(&peer),
            thumbprint,
            e);
        stream.write_all(msg.as_bytes())?;
        return Ok(false);
    }

    Ok(true)
}
/// Serves a single client connection.
///
/// Performs the TLS handshake on `socket`, authenticates the client
/// certificate, sends `OK\r\n`, and then repeatedly reads one message, parses
/// it and dispatches it. Bytes belonging to a handled message are removed from
/// the receive buffer; any extra bytes already received stay for the next
/// iteration.
///
/// Returns `Ok(())` only when the client fails authentication (the reason has
/// already been sent to it). Otherwise the command loop runs until some step
/// returns an error, which is propagated to the caller.
fn handle_connection(socket: TcpStream, server: &Server) -> Result<(), ConnectionError> {
    let mut stream = server.tls_config.accept(socket)?;
    if !authenticate_certificate(&mut stream, server)? {
        return Ok(()); // error already sent to client
    }
    // `write_all` rather than `write`: a single `write` may send only part of
    // the greeting and the number of bytes written was previously discarded.
    stream.write_all(b"OK\r\n")?;

    let mut cmd_buffer = Vec::new();
    loop {
        // The inner scope ends the borrow of `cmd_buffer` held by `msg` before
        // the consumed bytes are drained from it.
        let consumed_bytes = {
            let msg = read_full_message(&mut stream, &mut cmd_buffer)?;
            // Strip the trailing 4-byte terminator before decoding.
            let cmd_str = std::str::from_utf8(&msg[..msg.len() - 4])?;
            let cmd = parse_cmd(cmd_str)?;
            dispatch_cmd(&mut stream, cmd)?;
            msg.len()
        };
        cmd_buffer.drain(..consumed_bytes);
    }
}
/// Entry point: builds the server from hard-coded settings and then serves
/// incoming connections one at a time, printing the outcome of each.
///
/// Returns an error if the server cannot be created or if accepting a
/// connection fails; errors from an individual connection are only printed.
fn main() -> Result<(), ConnectionError> {
    // allowed thumbprints
    let allowed_thumbprints = ["1776821db1002b0e2a9b4ee3d5ee14133d367009"];
    let server = Server::new(
        "C:\\Work\\temp\\example-com.cert.pem",
        "C:\\Work\\temp\\example-com.key.pem",
        "127.0.0.1:4888",
        &allowed_thumbprints,
    )?;
    println!("Started");

    for incoming in server.tcp_listener.incoming() {
        let socket = incoming?;
        let outcome = handle_connection(socket, &server);
        println!("{:?}", outcome);
    }
    Ok(())
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment