Last active
June 14, 2020 15:54
-
-
Save psychon/4f5916f91bde7b1a3a3a8fdc9e9ba5b9 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::io::{Error as IOError, ErrorKind, Read, Write}; | |
use std::net::{TcpListener, TcpStream}; | |
use std::os::unix::net::UnixStream; | |
use std::thread; | |
use std::env::args_os; | |
use std::process::Command; | |
use std::convert::{TryFrom, TryInto}; | |
use std::sync::{Arc, Mutex}; | |
use std::collections::VecDeque; | |
use x11rb::protocol::xproto::{Setup, SetupFailed, SetupAuthenticate, SetupRequest, GE_GENERIC_EVENT}; | |
use x11rb::protocol::{Error, Event, Request, Reply}; | |
use x11rb::x11_utils::{BigRequests, ExtInfoProvider, ExtensionInformation, TryParse, ReplyParsingFunction}; | |
fn forward_impl(mut read: impl Read, mut write: impl Write, mut parser: impl Parser) -> Result<(), IOError> { | |
let mut buffer = [0; 4096]; | |
let mut buffer2 = Vec::new(); | |
loop { | |
let amount = match read.read(&mut buffer) { | |
Ok(amount) => amount, | |
Err(ref e) if e.kind() == ErrorKind::WouldBlock => continue, // TODO: Use poll() | |
Err(e) => return Err(e), | |
}; | |
if amount == 0 { | |
return Ok(()); | |
} | |
buffer2.extend(&buffer[..amount]); | |
let mut to_skip = 0; | |
loop { | |
let length = parser.parse(&buffer2[to_skip..]); | |
to_skip += length; | |
if length == 0 { | |
break | |
} | |
} | |
match write.write_all(&buffer[..amount]) { | |
Err(e) if e.kind() == ErrorKind::BrokenPipe => return Ok(()), | |
Err(e) => return Err(e), | |
Ok(()) => {} | |
} | |
buffer2.drain(..to_skip); | |
} | |
} | |
fn forward(read: impl Read, write: impl Write, parser: impl Parser) { | |
forward_impl(read, write, parser).unwrap(); | |
} | |
fn handle_client(client: TcpStream) -> Result<(), IOError> { | |
let client2 = client.try_clone()?; | |
let server = UnixStream::connect("/tmp/.X11-unix/X0")?; | |
let server2 = server.try_clone()?; | |
client.set_nonblocking(true)?; | |
server.set_nonblocking(true)?; | |
let ext_info = Arc::new(Mutex::new(ConnectionState::default())); | |
let ext_info2 = ext_info.clone(); | |
thread::spawn(move || forward(client, server2, ParseClient(false, ext_info))); | |
thread::spawn(|| forward(server, client2, ParseServer(false, ext_info2))); | |
Ok(()) | |
} | |
fn main() -> Result<(), IOError> { | |
let listener = TcpListener::bind("127.0.0.1:6004")?; | |
std::env::set_var("DISPLAY", ":4"); | |
let mut args = args_os().skip(1); | |
if let Some(command) = args.next() { | |
Command::new(command) | |
.args(args) | |
.spawn() | |
.unwrap(); | |
} | |
loop { | |
let (socket, _addr) = listener.accept()?; | |
handle_client(socket)?; | |
} | |
} | |
/// Incremental decoder for one direction of an X11 connection.
trait Parser {
    /// Inspects the buffered bytes that have not been consumed yet.
    ///
    /// Returns how many bytes at the front of `pending` were consumed as one
    /// unit, or 0 when more bytes are needed before anything can be consumed.
    fn parse(&mut self, pending: &[u8]) -> usize;
}
struct ParseClient(bool, Arc<Mutex<ConnectionState>>); | |
impl Parser for ParseClient { | |
fn parse(&mut self, data: &[u8]) -> usize { | |
let mut state = self.1.lock().unwrap(); | |
if !self.0 { | |
// Read a SetupRequest | |
let (setup, remaining) = SetupRequest::try_parse(data).unwrap(); | |
println!("client: {:?}", setup); | |
if !remaining.is_empty() { | |
println!(" trailing: {:?}", remaining); | |
} | |
self.0 = true; | |
assert_eq!(state.next_client_request, 0); | |
state.next_client_request = 1; | |
data.len() - remaining.len() | |
} else { | |
if data.len() < 4 { | |
// Not enough data yet | |
return 0; | |
} | |
// Read a request | |
let length = 4 * usize::from(u16::from_ne_bytes(data[2..4].try_into().unwrap())); | |
let length = if length == 0 { | |
4 * usize::try_from(u32::from_ne_bytes(data[4..8].try_into().unwrap())).unwrap() | |
} else { | |
length | |
}; | |
if data.len() < length { | |
// Not enough data yet | |
return 0; | |
} | |
let (header, remaining) = x11rb::x11_utils::parse_request_header(&data[..length], BigRequests::Enabled).unwrap(); | |
let seqno = state.next_client_request; | |
state.next_client_request = seqno.wrapping_add(1); | |
let request = Request::parse(header, remaining, &mut Vec::new(), &state.ext_info).unwrap(); | |
let remaining: &[u8; 0] = &[]; // FIXME? | |
if let Some(request) = ClientRequest::new(seqno, &request) { | |
state.requests.push_back(request); | |
} | |
println!("client ({}): {:?}", seqno, request); | |
if !remaining.is_empty() { | |
println!(" trailing: {:?}", remaining); | |
} | |
length | |
} | |
} | |
} | |
struct ParseServer(bool, Arc<Mutex<ConnectionState>>); | |
impl Parser for ParseServer { | |
fn parse(&mut self, data: &[u8]) -> usize { | |
if !self.0 { | |
let remaining = match data[0] { | |
0 => { | |
let (s, r) = SetupFailed::try_parse(data).unwrap(); | |
println!("server: {:?}", s); | |
r | |
}, | |
1 => { | |
let (s, r) = match Setup::try_parse(data) { | |
Ok((s, r)) => (s, r), | |
Err(_) => return 0, | |
}; | |
println!("server: {:?}", s); | |
r | |
}, | |
2 => { | |
let (s, r) = SetupAuthenticate::try_parse(data).unwrap(); | |
println!("server: {:?}", s); | |
r | |
}, | |
_ => panic!("No idea what to do with {:?}", data), | |
}; | |
self.0 = true; | |
data.len() - remaining.len() | |
} else { | |
// Read a reply/error/event | |
if data.len() < 32 { | |
// All packets have at least 32 bytes | |
return 0; | |
} | |
let response_type = data[0]; | |
const ERROR: u8 = 0; | |
const REPLY: u8 = 1; | |
let length = 32 + if response_type == REPLY || response_type & 0x7f == GE_GENERIC_EVENT { | |
let length_field = data[4..8].try_into().unwrap(); | |
let length_field = u32::from_ne_bytes(length_field) as usize; | |
assert!(length_field <= usize::max_value() / 4); | |
4 * length_field | |
} else { | |
0 | |
}; | |
if data.len() < length { | |
// There is still some data missing | |
return 0; | |
} | |
let this_buffer = &data[..length]; | |
let mut state = self.1.lock().unwrap(); | |
let seqno = u16::from_ne_bytes(this_buffer[2..4].try_into().unwrap()); | |
match response_type { | |
ERROR => { | |
// Remove the request that failed from the pending requests | |
if state.requests.front().map(|r| r.seqno) == Some(seqno) { | |
state.requests.pop_front(); | |
} | |
let err = Error::parse(this_buffer, &state.ext_info).unwrap(); | |
println!("server: {:?}", err); | |
}, | |
REPLY => { | |
let request = state.requests.pop_front().unwrap(); | |
assert_eq!(seqno, request.seqno); | |
let (reply, remaining) = (request.reply_parser)(this_buffer, &mut Vec::new()).unwrap(); | |
println!("server: {:?}", reply); | |
if !remaining.is_empty() { | |
println!(" trailing: {:?}", remaining); | |
} | |
// If this is a QueryExtension reply, we have to update our state | |
if let Reply::QueryExtension(reply) = reply { | |
let name = request.queried_extension.unwrap(); | |
if reply.present { | |
let info = ExtensionInformation { | |
major_opcode: reply.major_opcode, | |
first_event: reply.first_event, | |
first_error: reply.first_error, | |
}; | |
state.ext_info.add_extension(name, info); | |
} | |
} | |
}, | |
_ => { | |
let ev = Event::parse(this_buffer, &state.ext_info).unwrap(); | |
println!("server: {:?}", ev); | |
} | |
} | |
this_buffer.len() | |
} | |
} | |
} | |
#[derive(Clone)] | |
struct ClientRequest { | |
seqno: u16, | |
reply_parser: ReplyParsingFunction, | |
queried_extension: Option<String>, | |
} | |
impl ClientRequest { | |
fn new(seqno: u16, request: &Request) -> Option<Self> { | |
let queried_extension = if let Request::QueryExtension(query) = &request { | |
Some(std::str::from_utf8(query.name).unwrap().to_string()) | |
} else { | |
None | |
}; | |
request.reply_parser() | |
.map(|reply_parser| Self { | |
seqno, | |
queried_extension, | |
reply_parser | |
}) | |
} | |
} | |
#[derive(Clone, Default)] | |
struct ConnectionState { | |
ext_info: SnoopingExtInfo, | |
next_client_request: u16, | |
requests: VecDeque<ClientRequest>, | |
} | |
#[derive(Clone, Default, Debug)] | |
struct SnoopingExtInfo { | |
exts: Vec<(String, ExtensionInformation)>, | |
} | |
impl SnoopingExtInfo { | |
fn add_extension(&mut self, name: String, info: ExtensionInformation) { | |
self.exts.push((name, info)) | |
} | |
} | |
impl ExtInfoProvider for SnoopingExtInfo { | |
fn get_from_major_opcode(&self, major_opcode: u8) -> Option<(&str, ExtensionInformation)> { | |
self.exts | |
.iter() | |
.find(|(_, ext)| ext.major_opcode == major_opcode) | |
.map(|(s, ext)| (s.as_ref(), *ext)) | |
} | |
fn get_from_event_code(&self, event_code: u8) -> Option<(&str, ExtensionInformation)> { | |
self.exts | |
.iter() | |
.filter(|(_, ext)| ext.first_event <= event_code) | |
.max_by_key(|(_, ext)| ext.first_event) | |
.map(|(s, ext)| (s.as_ref(), *ext)) | |
} | |
fn get_from_error_code(&self, error_code: u8) -> Option<(&str, ExtensionInformation)> { | |
self.exts | |
.iter() | |
.filter(|(_, ext)| ext.first_error <= error_code) | |
.max_by_key(|(_, ext)| ext.first_event) | |
.map(|(s, ext)| (s.as_ref(), *ext)) | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment