Created
January 4, 2019 09:12
-
-
Save ayende/9e36c8f04e4830b154a59de9ca169c8f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
extern crate custom_error; | |
extern crate memmem; | |
extern crate openssl; | |
extern crate openssl_sys; | |
extern crate hex; | |
extern crate foreign_types_shared; | |
#[macro_use] extern crate lazy_static; | |
use std::io; | |
use std::io::Write; | |
use std::sync::Arc; | |
use std::collections::HashMap; | |
use std::collections::HashSet; | |
use custom_error::custom_error; | |
use std::net::{TcpListener, TcpStream}; | |
use memmem::{Searcher, TwoWaySearcher}; | |
use foreign_types_shared::ForeignTypeRef; | |
custom_error! { | |
ConnectionError | |
Io{source: io::Error} = "unable to read from the network", | |
Utf8{source: std::str::Utf8Error} = "Invalid UTF8 character sequence", | |
Parse{origin: String} = "Unable to parse command: {origin}", | |
MessageTooBig = "Message length was over 8KB", | |
SslIssue{source : openssl::error::ErrorStack} = "OpenSSL error {source}", | |
Handshake{source: openssl::ssl::HandshakeError<std::net::TcpStream>} = "Handshake error {source}", | |
InvalidTimeFormat = "Unable to understand certificate time", | |
ClientCertExpired{date: String} = "The client certificate has already expired: {date}", | |
ClientCertNotYetValid{date: String} = "The client certificate is not yet valid: {date}" | |
} | |
impl ConnectionError { | |
fn parsing(origin: &str) -> ConnectionError { | |
ConnectionError::Parse{ origin: origin.to_string() } | |
} | |
} | |
/// A single parsed client command.
///
/// Both fields borrow from the text the command was parsed from, so a `Cmd`
/// cannot outlive that text.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Cmd<'a> {
    /// Tokens of the command line (the first line of the message).
    args: Vec<&'a str>,
    /// `name: value` pairs taken from the lines that follow the command line.
    headers: HashMap<&'a str, &'a str>,
}
lazy_static! { | |
static ref msg_break : TwoWaySearcher<'static> = { | |
TwoWaySearcher::new(b"\r\n\r\n") | |
}; | |
} | |
struct Server { | |
tls_config: Arc<openssl::ssl::SslAcceptor>, | |
tcp_listener: TcpListener, | |
allowed_certs_thumbprints: HashSet<String> | |
} | |
impl Server { | |
fn new(cert_path: &str, key_path: &str, listen_uri: &str, allowed_certs_thumbprints: &[&str]) -> Result<Server, ConnectionError> { | |
let mut allowed = HashSet::new(); | |
for thumbprint in allowed_certs_thumbprints { | |
allowed.insert(thumbprint.to_lowercase()); | |
} | |
let mut sslb = openssl::ssl::SslAcceptor::mozilla_modern(openssl::ssl::SslMethod::tls())?; | |
sslb.set_private_key_file(key_path, openssl::ssl::SslFiletype::PEM)?; | |
sslb.set_certificate_chain_file(cert_path)?; | |
sslb.check_private_key()?; | |
// accept all certificates, we'll do our own validation on them | |
sslb.set_verify_callback(openssl::ssl::SslVerifyMode::PEER, |_, _| true); | |
let listener = TcpListener::bind(listen_uri)?; | |
Ok(Server { tls_config: Arc::new(sslb.build()), tcp_listener: listener, allowed_certs_thumbprints: allowed }) | |
} | |
} | |
fn parse_cmd<'a>(cmd_str: &'a str) -> Result<Cmd, ConnectionError> { | |
let mut lines = cmd_str.lines(); | |
let cmd_line = match lines.next() { | |
None => { | |
return Err(ConnectionError::parsing(cmd_str)); | |
} | |
Some(v) => v, | |
}; | |
let mut cmd = Cmd { | |
args: cmd_line.split(' ').collect(), | |
headers: HashMap::new(), | |
}; | |
for line in lines { | |
let parts: Vec<&str> = line.splitn(2, ':').collect(); | |
if parts.len() != 2 { | |
return Err(ConnectionError::parsing(line)); | |
} | |
cmd.headers.insert(parts[0].trim(), parts[1].trim()); | |
} | |
Ok(cmd) | |
} | |
fn read_full_message<'a>(stream: &mut io::Read, buffer: &'a mut Vec<u8>) -> Result<&'a [u8], ConnectionError> { | |
let mut to_scan = 0; | |
let mut tmp_buf = [0; 256]; | |
loop { | |
match msg_break.search_in(&buffer[to_scan..]) { | |
None => to_scan = if buffer.len() > 3 { buffer.len() - 3} else { 0 }, | |
Some(msg_end) => return Ok(&buffer[0..(to_scan + msg_end + 4)]) | |
} | |
let read = stream.read(&mut tmp_buf)?; | |
if read + buffer.len() > 8192 { | |
return Err(ConnectionError::MessageTooBig) | |
} | |
buffer.extend_from_slice(&tmp_buf[0..read]); | |
} | |
} | |
fn dispatch_cmd<S>(stream: &mut S, cmd : Cmd) -> io::Result<()> | |
where S : io::Write + io::Read { | |
stream.write(&cmd.args[0].as_bytes())?; | |
stream.flush()?; | |
Ok(()) | |
} | |
fn authenticate_certificate(stream: &mut openssl::ssl::SslStream<TcpStream>, server: &Server) -> Result<bool, ConnectionError> { | |
fn get_friendly_name(peer: &openssl::x509::X509) -> String { | |
peer.subject_name() // can't figure out how to get the real friendly name | |
.entries() | |
.last() | |
.map( |it| it.data() | |
.as_utf8() | |
.and_then(|s| Ok(s.to_string())) | |
.unwrap_or("".to_string()) | |
) | |
.unwrap_or("<Unknown>".to_string()) | |
} | |
extern "C" { | |
fn ASN1_TIME_diff( | |
pday: *mut std::os::raw::c_int, | |
psec: *mut std::os::raw::c_int, | |
from: *const openssl_sys::ASN1_TIME, | |
to: *const openssl_sys::ASN1_TIME) -> std::os::raw::c_int; | |
} | |
fn is_before(x: &openssl::asn1::Asn1TimeRef, y: &openssl::asn1::Asn1TimeRef) -> Result<bool, ConnectionError> { | |
unsafe { | |
let mut day : std::os::raw::c_int = 0; | |
let mut sec : std::os::raw::c_int = 0; | |
match ASN1_TIME_diff(&mut day, &mut sec, x.as_ptr(), y.as_ptr() ) { | |
0 => Err(ConnectionError::InvalidTimeFormat), | |
_ => Ok(day > 0 || sec > 0) | |
} | |
} | |
} | |
fn is_valid_time(peer: &openssl::x509::X509) -> Result<(), ConnectionError> { | |
let now = openssl::asn1::Asn1Time::days_from_now(0)?; | |
if is_before(&now, peer.not_before())? { | |
return Err(ConnectionError::ClientCertNotYetValid { date: peer.not_before().to_string() }); | |
} | |
if is_before(peer.not_after(), &now)? { | |
return Err(ConnectionError::ClientCertExpired { date: peer.not_after().to_string() } ); | |
} | |
Ok(()) | |
} | |
match stream.ssl().peer_certificate() { | |
None => { | |
stream.write(b"ERR No certificate was provided\r\n")?; | |
return Ok(false); | |
} | |
Some(peer) => { | |
let thumbprint = hex::encode(peer.digest(openssl::hash::MessageDigest::sha1())?); | |
if server.allowed_certs_thumbprints.contains(&thumbprint) == false { | |
let msg = format!("ERR certificate ({}) thumbprint '{}' is unknown\r\n", | |
get_friendly_name(&peer), | |
thumbprint); | |
stream.write(msg.as_bytes())?; | |
return Ok(false); | |
} | |
if let Err(e) = is_valid_time(&peer) { | |
let msg = format!("ERR certificate ({}) thumbprint '{}' cannot be used: {}\r\n", | |
get_friendly_name(&peer), | |
thumbprint, | |
e); | |
stream.write(msg.as_bytes())?; | |
return Ok(false); | |
} | |
} | |
}; | |
return Ok(true); | |
} | |
fn handle_connection(socket: TcpStream, server: &Server) -> Result<(), ConnectionError> { | |
let acceptor = server.tls_config.clone(); | |
let mut stream = acceptor.accept(socket)?; | |
if authenticate_certificate(&mut stream, server)? == false{ | |
return Ok(());// error already sent to client | |
} | |
stream.write(b"OK\r\n")?; | |
let mut cmd_buffer = Vec::new(); | |
loop { | |
let consumed_bytes = { | |
let msg = read_full_message(&mut stream, &mut cmd_buffer)?; | |
let cmd_str = std::str::from_utf8(&msg[0..msg.len()-4])?; | |
let cmd = parse_cmd(&cmd_str)?; | |
dispatch_cmd(&mut stream, cmd)?; | |
msg.len() | |
}; | |
cmd_buffer.drain(0 .. consumed_bytes); | |
} | |
} | |
fn main() -> Result<(), ConnectionError> { | |
let server = Server::new( | |
"C:\\Work\\temp\\example-com.cert.pem", | |
"C:\\Work\\temp\\example-com.key.pem", | |
"127.0.0.1:4888", | |
// allowed thumprints | |
&["1776821db1002b0e2a9b4ee3d5ee14133d367009"] | |
)?; | |
println!("Started"); | |
for stream in server.tcp_listener.incoming() { | |
let x = handle_connection(stream?, &server); | |
println!("{:?}", x); | |
} | |
Ok(()) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment