Skip to content

Instantly share code, notes, and snippets.

@reu
Last active December 16, 2023 13:44
Show Gist options
  • Save reu/8a703c0e57927050abcc3387b61c3c79 to your computer and use it in GitHub Desktop.
Save reu/8a703c0e57927050abcc3387b61c3c79 to your computer and use it in GitHub Desktop.
Simple TCP proxy using one thread per connection and non-blocking I/O
use std::{
io::{self, Read, Write},
net::{TcpListener, TcpStream},
os::fd::{AsRawFd, RawFd},
thread,
};
/// `poll(2)` event bit: normal data can be read without blocking.
const READ: libc::c_short = libc::POLLRDNORM;
/// `poll(2)` event bit: normal data can be written without blocking.
const WRITE: libc::c_short = libc::POLLWRNORM;

/// One `struct pollfd` entry: a file descriptor, the events to wait for (`events`) and
/// the events reported by the last `poll(2)` call (`revents`).
///
/// The field types and order mirror [`libc::pollfd`] because arrays of `PollFd` are handed
/// to `libc::poll` through a pointer cast; keep the two layouts identical.
#[repr(C)]
#[derive(Debug)]
pub struct PollFd {
    fd: RawFd,
    events: libc::c_short,
    revents: libc::c_short,
}

// Compile-time guard for the `PollFd` -> `libc::pollfd` pointer cast used when calling
// `libc::poll`: fail the build instead of passing poll(2) a mis-sized array if this struct
// and the platform's `pollfd` ever diverge.
const _: () = {
    assert!(std::mem::size_of::<PollFd>() == std::mem::size_of::<libc::pollfd>());
    assert!(std::mem::align_of::<PollFd>() == std::mem::align_of::<libc::pollfd>());
};

impl PollFd {
    /// Creates an entry for `fd` that initially asks only for readability (`READ`).
    fn new(fd: RawFd) -> Self {
        Self {
            fd,
            events: READ,
            revents: 0,
        }
    }

    /// Whether the last poll reported `READ` in `revents`.
    ///
    /// Only the `READ` bit is tested; `POLLHUP`/`POLLERR` on their own do not count.
    pub fn is_readable(&self) -> bool {
        self.revents & READ != 0
    }

    /// Whether the last poll reported `WRITE` in `revents`.
    ///
    /// Only the `WRITE` bit is tested; `POLLHUP`/`POLLERR` on their own do not count.
    pub fn is_writable(&self) -> bool {
        self.revents & WRITE != 0
    }

    /// Adds `events` to the set of requested events, keeping the existing bits.
    pub fn set(&mut self, events: libc::c_short) {
        self.events |= events;
    }

    /// Removes `events` from the set of requested events, keeping the other bits.
    pub fn unset(&mut self, events: libc::c_short) {
        self.events &= !events;
    }
}
/// Accepts TCP clients on port 10000 and relays each one to a backend on port 4444.
///
/// Every accepted connection is served on its own thread by [`proxy`]. Failures that only
/// affect a single client (a failed `accept`, an unreachable upstream, or an error while
/// relaying) are logged to stderr and the server keeps accepting.
///
/// # Errors
///
/// Returns an error only if the listening socket cannot be bound.
fn main() -> io::Result<()> {
    let listener = TcpListener::bind(("0.0.0.0", 10000))?;

    for incoming in listener.incoming() {
        let downstream = match incoming {
            Ok(stream) => stream,
            Err(err) => {
                // A failed accept (e.g. a client that reset before being accepted, or running
                // out of file descriptors) must not stop the whole server. Back off briefly
                // so a persistent failure does not turn into a busy loop.
                eprintln!("accept failed: {err}");
                thread::sleep(std::time::Duration::from_millis(100));
                continue;
            }
        };

        thread::spawn(move || {
            let upstream = match TcpStream::connect(("0.0.0.0", 4444)) {
                Ok(stream) => stream,
                Err(err) => {
                    // Dropping `downstream` here closes the client connection.
                    eprintln!("upstream connect failed: {err}");
                    return;
                }
            };
            if let Err(err) = proxy(downstream, upstream) {
                eprintln!("{err}");
            }
        });
    }

    Ok(())
}
fn proxy(mut downstream: TcpStream, mut upstream: TcpStream) -> io::Result<()> {
downstream.set_nonblocking(true)?;
upstream.set_nonblocking(true)?;
let mut downstream_buf = vec![0; 1024];
let mut upstream_buf = vec![0; 1024];
let mut write_to_upstream: &[u8] = &[];
let mut write_to_downstream: &[u8] = &[];
let mut fds = [
PollFd::new(downstream.as_raw_fd()),
PollFd::new(upstream.as_raw_fd()),
];
loop {
unsafe {
libc::poll(
fds.as_mut_ptr() as *mut libc::pollfd,
2 as libc::nfds_t,
-1 as libc::c_int,
);
};
let (fds1, fds2) = fds.split_at_mut(1);
let poll_downstream = &mut fds1[0];
let poll_upstream = &mut fds2[0];
if poll_downstream.is_readable() {
let start = write_to_upstream.len();
let end = downstream_buf.len();
if !(start..end).is_empty() {
write_to_upstream = match downstream.read(&mut downstream_buf[start..end])? {
0 => break,
bytes => &downstream_buf[0..start + bytes],
};
poll_upstream.set(WRITE);
}
}
if poll_downstream.is_writable() && !write_to_downstream.is_empty() {
let written = downstream.write(write_to_downstream)?;
write_to_downstream = &write_to_downstream[written..write_to_downstream.len()];
if write_to_downstream.is_empty() {
poll_downstream.unset(WRITE);
}
}
if poll_upstream.is_readable() {
let start = write_to_downstream.len();
let end = upstream_buf.len();
if !(start..end).is_empty() {
write_to_downstream = match upstream.read(&mut upstream_buf[start..end])? {
0 => break,
bytes => &upstream_buf[0..start + bytes],
};
poll_downstream.set(WRITE);
}
}
if poll_upstream.is_writable() && !write_to_upstream.is_empty() {
let written = upstream.write(write_to_upstream)?;
write_to_upstream = &write_to_upstream[written..write_to_upstream.len()];
if write_to_upstream.is_empty() {
poll_upstream.unset(WRITE);
}
}
}
io::Result::Ok(())
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment