// X11 protocol snooping proxy, originally published as GitHub gist
// psychon/4f5916f91bde7b1a3a3a8fdc9e9ba5b9 by @psychon (last active June 14, 2020).
//
// Listens on 127.0.0.1:6004, forwards each connection to the X server socket
// /tmp/.X11-unix/X0, and prints the decoded setup, requests, replies, events and errors.
// Any command-line arguments are started as a child process with DISPLAY=:4.
use std::io::{Error as IOError, ErrorKind, Read, Write};
use std::net::{TcpListener, TcpStream};
use std::os::unix::net::UnixStream;
use std::thread;
use std::env::args_os;
use std::process::Command;
use std::convert::{TryFrom, TryInto};
use std::sync::{Arc, Mutex};
use std::collections::VecDeque;
use x11rb::protocol::xproto::{Setup, SetupFailed, SetupAuthenticate, SetupRequest, GE_GENERIC_EVENT};
use x11rb::protocol::{Error, Event, Request, Reply};
use x11rb::x11_utils::{BigRequests, ExtInfoProvider, ExtensionInformation, TryParse, ReplyParsingFunction};
fn forward_impl(mut read: impl Read, mut write: impl Write, mut parser: impl Parser) -> Result<(), IOError> {
let mut buffer = [0; 4096];
let mut buffer2 = Vec::new();
loop {
let amount = match read.read(&mut buffer) {
Ok(amount) => amount,
Err(ref e) if e.kind() == ErrorKind::WouldBlock => continue, // TODO: Use poll()
Err(e) => return Err(e),
};
if amount == 0 {
return Ok(());
}
buffer2.extend(&buffer[..amount]);
let mut to_skip = 0;
loop {
let length = parser.parse(&buffer2[to_skip..]);
to_skip += length;
if length == 0 {
break
}
}
match write.write_all(&buffer[..amount]) {
Err(e) if e.kind() == ErrorKind::BrokenPipe => return Ok(()),
Err(e) => return Err(e),
Ok(()) => {}
}
buffer2.drain(..to_skip);
}
}
/// Thread body for one forwarding direction: runs `forward_impl` to completion.
///
/// A spawned thread has no caller to hand an error to. So an I/O error (for example a
/// reset connection) is reported on stderr instead of being turned into a thread panic via
/// `unwrap()`. The thread then returns normally.
fn forward(read: impl Read, write: impl Write, parser: impl Parser) {
    if let Err(e) = forward_impl(read, write, parser) {
        eprintln!("forwarding stopped with an I/O error: {}", e);
    }
}
/// Proxies one accepted client connection to the local X server.
///
/// Connects to the X server's Unix socket `/tmp/.X11-unix/X0` and spawns two threads.
/// One forwards client -> server through `ParseClient`; the other forwards server -> client
/// through `ParseServer`. Both parsers share one `ConnectionState`. Returns as soon as the
/// threads are started.
///
/// The sockets stay in blocking mode: each direction has its own thread, so a thread simply
/// sleeps in `read()` instead of spinning on `WouldBlock`. When one direction ends, the
/// write half of the other socket is shut down. The peer then sees end-of-file, and the
/// opposite thread finishes too instead of lingering with its sockets open.
///
/// # Errors
/// Fails if the X server socket cannot be opened or a socket handle cannot be duplicated.
fn handle_client(client: TcpStream) -> Result<(), IOError> {
    let server = UnixStream::connect("/tmp/.X11-unix/X0")?;

    // Write handles for each direction, plus handles kept only for shutdown().
    let client_writer = client.try_clone()?;
    let client_ctl = client.try_clone()?;
    let server_writer = server.try_clone()?;
    let server_ctl = server.try_clone()?;

    let state = Arc::new(Mutex::new(ConnectionState::default()));
    let server_state = Arc::clone(&state);

    thread::spawn(move || {
        forward(client, server_writer, ParseClient(false, state));
        // The client stopped sending; let the X server see end-of-file.
        let _ = server_ctl.shutdown(std::net::Shutdown::Write);
    });
    thread::spawn(move || {
        forward(server, client_writer, ParseServer(false, server_state));
        // The X server stopped sending; let the client see end-of-file.
        let _ = client_ctl.shutdown(std::net::Shutdown::Write);
    });
    Ok(())
}
/// Listens on 127.0.0.1:6004 and proxies every accepted connection to the local X server.
///
/// `DISPLAY` is set to `:4` before anything is spawned. If command-line arguments are
/// given, the first is run as a program with the rest as its arguments, inheriting that
/// environment. The child is not waited for.
///
/// # Errors
/// Fails if the listening socket cannot be bound, the child command cannot be spawned, or
/// `accept()` fails. A failure to set up proxying for a single client is reported on stderr
/// and does not stop the proxy.
fn main() -> Result<(), IOError> {
    // Bind before spawning the child so it can connect right away.
    let listener = TcpListener::bind("127.0.0.1:6004")?;
    std::env::set_var("DISPLAY", ":4");

    let mut args = args_os().skip(1);
    if let Some(program) = args.next() {
        Command::new(program).args(args).spawn()?;
    }

    loop {
        let (socket, addr) = listener.accept()?;
        if let Err(e) = handle_client(socket) {
            eprintln!("failed to proxy client {}: {}", addr, e);
        }
    }
}
/// Incremental decoder for one direction of the proxied byte stream.
trait Parser {
    /// Tries to decode one message at the start of `data`.
    ///
    /// Returns the number of bytes that message occupies, or 0 if `data` does not yet hold a
    /// complete message. In that case the caller waits for more input and calls again.
    fn parse(&mut self, data: &[u8]) -> usize;
}
/// Decodes the client -> server direction.
///
/// The first message is the connection setup request; after that, each call decodes one
/// request. Field 0 records whether the setup request has been seen. Field 1 is the state
/// shared with the `ParseServer` of the same connection: requests that expect a reply are
/// queued there so their replies can be decoded.
struct ParseClient(bool, Arc<Mutex<ConnectionState>>);
impl Parser for ParseClient {
    fn parse(&mut self, data: &[u8]) -> usize {
        let mut state = self.1.lock().unwrap();
        if !self.0 {
            // Read a SetupRequest; a parse error means it has not fully arrived yet.
            let (setup, remaining) = match SetupRequest::try_parse(data) {
                Ok(parsed) => parsed,
                Err(_) => return 0,
            };
            println!("client: {:?}", setup);
            if !remaining.is_empty() {
                println!(" trailing: {:?}", remaining);
            }
            self.0 = true;
            assert_eq!(state.next_client_request, 0);
            state.next_client_request = 1;
            data.len() - remaining.len()
        } else {
            // A request starts with a 4-byte header whose bytes 2..4 hold its length in
            // 4-byte units. A length of 0 selects the BIG-REQUESTS form, where the real
            // length follows in bytes 4..8.
            // NOTE(review): lengths are decoded in host byte order, which assumes the client
            // uses the same byte order as this machine.
            if data.len() < 4 {
                // Not enough data yet
                return 0;
            }
            let length = match u16::from_ne_bytes([data[2], data[3]]) {
                0 => {
                    if data.len() < 8 {
                        // The extended length field has not fully arrived yet
                        return 0;
                    }
                    let units = u32::from_ne_bytes([data[4], data[5], data[6], data[7]]);
                    4 * usize::try_from(units).unwrap()
                }
                units => 4 * usize::from(units),
            };
            if data.len() < length {
                // Not enough data yet
                return 0;
            }
            let (header, body) = x11rb::x11_utils::parse_request_header(&data[..length], BigRequests::Enabled).unwrap();
            let seqno = state.next_client_request;
            state.next_client_request = seqno.wrapping_add(1);
            // Request::parse yields only the request, not any unconsumed body bytes, so
            // unlike the setup request there is no "trailing" data to show here.
            let request = Request::parse(header, body, &mut Vec::new(), &state.ext_info).unwrap();
            if let Some(awaiting_reply) = ClientRequest::new(seqno, &request) {
                state.requests.push_back(awaiting_reply);
            }
            println!("client ({}): {:?}", seqno, request);
            length
        }
    }
}
/// Decodes the server -> client direction.
///
/// The first message is the connection setup response; after that, each call decodes one
/// reply, error or event. Field 0 records whether the setup response has been seen. Field 1
/// is the state shared with the `ParseClient` of the same connection. It is used to pair
/// replies with the queued requests and to learn extension numbering from `QueryExtension`
/// replies.
struct ParseServer(bool, Arc<Mutex<ConnectionState>>);
impl Parser for ParseServer {
    fn parse(&mut self, data: &[u8]) -> usize {
        if !self.0 {
            // The first byte selects the kind of setup response. For every kind, a parse
            // error means the response has not fully arrived yet.
            let remaining = match data.first().copied() {
                None => return 0,
                Some(0) => match SetupFailed::try_parse(data) {
                    Ok((s, r)) => {
                        println!("server: {:?}", s);
                        r
                    }
                    Err(_) => return 0,
                },
                Some(1) => match Setup::try_parse(data) {
                    Ok((s, r)) => {
                        println!("server: {:?}", s);
                        r
                    }
                    Err(_) => return 0,
                },
                Some(2) => match SetupAuthenticate::try_parse(data) {
                    Ok((s, r)) => {
                        println!("server: {:?}", s);
                        r
                    }
                    Err(_) => return 0,
                },
                Some(_) => panic!("No idea what to do with {:?}", data),
            };
            self.0 = true;
            data.len() - remaining.len()
        } else {
            // Read a reply/error/event
            if data.len() < 32 {
                // All packets have at least 32 bytes
                return 0;
            }
            let response_type = data[0];
            const ERROR: u8 = 0;
            const REPLY: u8 = 1;
            // Replies and generic events carry a length field (in 4-byte units) for the data
            // beyond the fixed 32 bytes; everything else is exactly 32 bytes.
            let length = 32 + if response_type == REPLY || response_type & 0x7f == GE_GENERIC_EVENT {
                let length_field = data[4..8].try_into().unwrap();
                let length_field = u32::from_ne_bytes(length_field) as usize;
                assert!(length_field <= usize::max_value() / 4);
                4 * length_field
            } else {
                0
            };
            if data.len() < length {
                // There is still some data missing
                return 0;
            }
            let this_buffer = &data[..length];
            let mut state = self.1.lock().unwrap();
            let seqno = u16::from_ne_bytes(this_buffer[2..4].try_into().unwrap());
            match response_type {
                ERROR => {
                    // Remove the request that failed from the pending requests
                    if state.requests.front().map(|r| r.seqno) == Some(seqno) {
                        state.requests.pop_front();
                    }
                    let err = Error::parse(this_buffer, &state.ext_info).unwrap();
                    println!("server: {:?}", err);
                }
                REPLY => {
                    let request = state.requests.pop_front().unwrap();
                    assert_eq!(seqno, request.seqno);
                    let (reply, remaining) = (request.reply_parser)(this_buffer, &mut Vec::new()).unwrap();
                    println!("server: {:?}", reply);
                    if !remaining.is_empty() {
                        println!(" trailing: {:?}", remaining);
                    }
                    // If this is a QueryExtension reply, we have to update our state
                    if let Reply::QueryExtension(reply) = reply {
                        let name = request.queried_extension.unwrap();
                        if reply.present {
                            let info = ExtensionInformation {
                                major_opcode: reply.major_opcode,
                                first_event: reply.first_event,
                                first_error: reply.first_error,
                            };
                            state.ext_info.add_extension(name, info);
                        }
                    }
                }
                _ => {
                    let ev = Event::parse(this_buffer, &state.ext_info).unwrap();
                    println!("server: {:?}", ev);
                }
            }
            this_buffer.len()
        }
    }
}
/// A client request that expects a reply, queued until its reply or error arrives.
#[derive(Clone)]
struct ClientRequest {
    /// Sequence number assigned to the request by `ParseClient`'s counter.
    seqno: u16,
    /// Decodes the reply to this request.
    reply_parser: ReplyParsingFunction,
    /// For `QueryExtension` requests: the name of the queried extension.
    queried_extension: Option<String>,
}
impl ClientRequest {
    /// Builds the queue entry for `request`.
    ///
    /// Returns `None` when `request.reply_parser()` reports that no reply is expected.
    fn new(seqno: u16, request: &Request) -> Option<Self> {
        let reply_parser = request.reply_parser()?;
        // The name comes straight from the client and need not be valid UTF-8; convert it
        // lossily rather than panicking on odd bytes.
        let queried_extension = match request {
            Request::QueryExtension(query) => Some(String::from_utf8_lossy(query.name).into_owned()),
            _ => None,
        };
        Some(Self {
            seqno,
            reply_parser,
            queried_extension,
        })
    }
}
/// Protocol state shared by the two forwarding directions of one proxied connection.
#[derive(Clone, Default)]
struct ConnectionState {
    /// Sequence number the next client request will receive.
    next_client_request: u16,
    /// Requests still waiting for their reply, oldest first.
    requests: VecDeque<ClientRequest>,
    /// Extensions learnt from `QueryExtension` replies, used when decoding traffic.
    ext_info: SnoopingExtInfo,
}
/// Extension table built by watching `QueryExtension` replies on one connection.
#[derive(Clone, Default, Debug)]
struct SnoopingExtInfo {
    /// Extension name and the numbering information the server reported for it.
    exts: Vec<(String, ExtensionInformation)>,
}
impl SnoopingExtInfo {
    /// Records `info` for the extension `name`.
    ///
    /// If the same extension is queried again, its existing entry is replaced rather than
    /// duplicated, so the table does not grow with repeated queries.
    fn add_extension(&mut self, name: String, info: ExtensionInformation) {
        match self.exts.iter_mut().find(|(existing, _)| *existing == name) {
            Some(entry) => entry.1 = info,
            None => self.exts.push((name, info)),
        }
    }
}
impl ExtInfoProvider for SnoopingExtInfo {
    /// Finds the extension whose major opcode is `major_opcode`.
    fn get_from_major_opcode(&self, major_opcode: u8) -> Option<(&str, ExtensionInformation)> {
        self.exts
            .iter()
            .find(|(_, ext)| ext.major_opcode == major_opcode)
            .map(|(name, ext)| (name.as_str(), *ext))
    }

    /// Finds the extension with the largest `first_event` that is `<= event_code`.
    fn get_from_event_code(&self, event_code: u8) -> Option<(&str, ExtensionInformation)> {
        ext_owning_code(&self.exts, event_code, |ext| ext.first_event)
    }

    /// Finds the extension with the largest `first_error` that is `<= error_code`.
    ///
    /// Candidates are now ranked by `first_error`. They were previously ranked by
    /// `first_event`, which could attribute an error to the wrong extension.
    fn get_from_error_code(&self, error_code: u8) -> Option<(&str, ExtensionInformation)> {
        ext_owning_code(&self.exts, error_code, |ext| ext.first_error)
    }
}

/// Returns the entry whose range start (`first(ext)`) is the largest one not exceeding
/// `code`. This is shared by the event and error lookups so both rank by the same field
/// they filter on.
fn ext_owning_code<'a>(
    exts: &'a [(String, ExtensionInformation)],
    code: u8,
    first: impl Fn(&ExtensionInformation) -> u8,
) -> Option<(&'a str, ExtensionInformation)> {
    exts.iter()
        .filter(|(_, ext)| first(ext) <= code)
        .max_by_key(|(_, ext)| first(ext))
        .map(|(name, ext)| (name.as_str(), *ext))
}
// End of gist.