Skip to content

Instantly share code, notes, and snippets.

@Mythra
Created January 28, 2021 19:55
Show Gist options
  • Save Mythra/71a88e40b5a9bd0a5964a96b59fc19ec to your computer and use it in GitHub Desktop.
Save Mythra/71a88e40b5a9bd0a5964a96b59fc19ec to your computer and use it in GitHub Desktop.
//! Connectors and URIs for using Hyper over Unix Domain Sockets. Ported from
//! `hyperlocal`; its licenses are in LICENSE.
use color_eyre::Result;
use futures_util::future::BoxFuture;
use hex::{encode as HexEncode, FromHex};
use hyper::{
client::connect::{Connected, Connection},
service::Service,
Uri as HyperUri,
};
use pin_project::pin_project;
use std::{
io::{Error as IoError, Result as IoResult},
path::{Path, PathBuf},
pin::Pin,
task::{Context, Poll},
};
use tokio::{
io::{AsyncRead, AsyncWrite, ReadBuf},
net::UnixStream,
};
/// Convenience wrapper for building URIs that address a Unix Domain Socket.
///
/// Convert it into a `hyper::Uri` through `From`/`Into`.
#[derive(Debug, Clone)]
pub struct Uri {
    inner: HyperUri,
}
impl Uri {
    /// Create a new [`Uri`] from a socket address and a request path.
    ///
    /// The socket path is hex-encoded into the URI's host component, with a
    /// placeholder port of `0`, so that arbitrary filesystem paths survive URI
    /// parsing. `path` is appended directly after the port, so it should begin
    /// with `/`.
    ///
    /// # Errors
    ///
    /// Returns an error if the assembled string is not a valid URI.
    pub fn new(socket: impl AsRef<Path>, path: &str) -> Result<Self> {
        // On Unix a path is an arbitrary byte sequence. Encode those bytes
        // directly instead of via `to_string_lossy`, which would replace
        // non-UTF-8 bytes and make such sockets unreachable.
        let raw = std::os::unix::ffi::OsStrExt::as_bytes(socket.as_ref().as_os_str());
        let host = HexEncode(raw);
        let host_str = format!("unix://{}:0{}", host, path);
        let inner: HyperUri = host_str.parse()?;
        Ok(Self { inner })
    }
}
impl From<Uri> for HyperUri {
fn from(uri: Uri) -> Self {
uri.inner
}
}
/// A stream wrapper around a [`tokio::net::UnixStream`] that hyper can use as
/// a connection.
#[pin_project]
#[derive(Debug)]
pub struct UdsStream {
    /// The connected socket. It is structurally pinned so the I/O trait impls
    /// can forward through a pin projection.
    #[pin]
    unix_stream: UnixStream,
}
impl AsyncWrite for UdsStream {
    /// Write `buf` to the underlying socket.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, IoError>> {
        self.project().unix_stream.poll_write(cx, buf)
    }

    /// Forward vectored writes to the socket. The trait's default would write
    /// only a single buffer per call.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<Result<usize, IoError>> {
        self.project().unix_stream.poll_write_vectored(cx, bufs)
    }

    /// Report the inner socket's vectored-write support. The trait default
    /// reports `false`, which would hide the capability from callers such as
    /// hyper.
    fn is_write_vectored(&self) -> bool {
        self.unix_stream.is_write_vectored()
    }

    /// Flush the underlying socket.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
        self.project().unix_stream.poll_flush(cx)
    }

    /// Shut down the write half of the underlying socket.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
        self.project().unix_stream.poll_shutdown(cx)
    }
}
impl AsyncRead for UdsStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<IoResult<()>> {
self.project().unix_stream.poll_read(cx, buf)
}
}
impl Connection for UdsStream {
    /// Report connection info to hyper. No extra metadata is attached, so a
    /// default `Connected` value is returned.
    fn connected(&self) -> Connected {
        Connected::new()
    }
}
/// Extract the socket path from a URI in the format produced by [`Uri::new`].
///
/// The URI must use the `unix` scheme and carry the hex-encoded socket path as
/// its host.
///
/// # Errors
///
/// Returns an [`std::io::ErrorKind::InvalidInput`] error if the scheme is not
/// `unix`, if the URI has no host, or if the host is not valid hex.
pub fn parse_socket_path(uri: HyperUri) -> Result<std::path::PathBuf, IoError> {
    if uri.scheme_str() != Some("unix") {
        return Err(IoError::new(
            std::io::ErrorKind::InvalidInput,
            "invalid URL, scheme must be unix",
        ));
    }
    let host = uri.host().ok_or_else(|| {
        IoError::new(
            std::io::ErrorKind::InvalidInput,
            "invalid URL, host must be present",
        )
    })?;
    let bytes = Vec::from_hex(host).map_err(|_| {
        IoError::new(
            std::io::ErrorKind::InvalidInput,
            "invalid URL, host must be a hex-encoded path",
        )
    })?;
    // Unix paths are raw bytes. Rebuild the `OsString` directly rather than via
    // `String::from_utf8_lossy`, which would replace non-UTF-8 bytes and yield
    // the wrong socket path.
    let os_path = <std::ffi::OsString as std::os::unix::ffi::OsStringExt>::from_vec(bytes);
    Ok(PathBuf::from(os_path))
}
/// The [`UnixConnector`] can be used to construct a [`hyper::Client`] which can
/// speak to a unix domain socket.
///
/// Request URIs must use the format produced by [`Uri::new`]: a `unix` scheme
/// whose host is the hex-encoded socket path.
// No manual `Unpin` impl is needed: a unit struct is automatically `Unpin`.
#[derive(Clone, Copy, Debug, Default)]
pub struct UnixConnector;
impl Service<HyperUri> for UnixConnector {
    type Response = UdsStream;
    type Error = IoError;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Always ready; the connector holds no state that could apply backpressure.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Decode the socket path from `req` and connect to it.
    fn call(&mut self, req: HyperUri) -> Self::Future {
        Box::pin(async move {
            let socket_path = parse_socket_path(req)?;
            UnixStream::connect(socket_path)
                .await
                .map(|unix_stream| UdsStream { unix_stream })
        })
    }
}
#[cfg(test)]
mod unit_tests {
    use super::Uri;
    use crate::assert_eq;
    use hyper::Uri as HyperUri;

    /// `Uri::new` hex-encodes the socket path into the host and adds port `0`.
    #[test]
    fn test_unix_uri_into_hyper_uri() {
        let _ = *crate::test_helper::REGISTER_HIJACK;
        let expected: HyperUri = "unix://666f6f2e736f636b:0/".parse().unwrap();
        let built = Uri::new("foo.sock", "/").expect("Failed to parse uri!");
        let actual = HyperUri::from(built);
        assert_eq!(actual, expected);
    }
}
//! Connectors and URIs for using Hyper over Windows Named Pipes.
use color_eyre::Result;
use futures_util::future::BoxFuture;
use hex::{encode as HexEncode, FromHex};
use hyper::{
client::connect::{Connected, Connection},
service::Service,
Uri as HyperUri,
};
use pin_project::pin_project;
use std::{
io::{Error as IoError, Result as IoResult},
path::{Path, PathBuf},
pin::Pin,
task::{Context, Poll},
};
use tokio::{
io::{AsyncRead, AsyncWrite, ReadBuf},
net::windows::named_pipe::NamedPipe,
};
/// A convenience type that can be used to construct Windows named pipe URIs
/// (`npipe://` scheme).
///
/// This type implements `Into<hyper::Uri>`.
#[derive(Clone, Debug)]
pub struct Uri {
    inner: HyperUri,
}
impl Uri {
    /// Build a [`Uri`] that addresses the named pipe `socket` and requests `path`.
    ///
    /// The pipe name is hex-encoded into the host component (with port `0`) of
    /// an `npipe://` URI, and `path` is appended right after the port.
    ///
    /// # Errors
    ///
    /// Fails if the assembled string is not a valid URI.
    pub fn new(socket: impl AsRef<Path>, path: &str) -> Result<Self> {
        let pipe_name = socket.as_ref().to_string_lossy();
        let encoded = HexEncode(pipe_name.as_bytes());
        let inner = format!("npipe://{}:0{}", encoded, path).parse::<HyperUri>()?;
        Ok(Self { inner })
    }
}
impl From<Uri> for HyperUri {
fn from(uri: Uri) -> Self {
uri.inner
}
}
/// Recover the named pipe path from an `npipe://` URI in the format produced by
/// [`Uri::new`].
///
/// # Errors
///
/// Returns an `InvalidInput` I/O error when the scheme is not `npipe`, the host
/// is missing, or the host is not valid hex.
pub fn parse_socket_path(uri: HyperUri) -> Result<std::path::PathBuf, IoError> {
    let invalid = |msg: &'static str| IoError::new(std::io::ErrorKind::InvalidInput, msg);

    if uri.scheme_str() != Some("npipe") {
        return Err(invalid("invalid URL, scheme must be npipe"));
    }
    let host = match uri.host() {
        Some(host) => host,
        None => return Err(invalid("invalid URL, host must be present")),
    };
    let decoded = Vec::from_hex(host)
        .map_err(|_| invalid("invalid URL, host must be a hex-encoded path"))?;
    Ok(PathBuf::from(String::from_utf8_lossy(&decoded).into_owned()))
}
/// Client-side stream over a connected named pipe, wrapping tokio's
/// `NamedPipe` so hyper can use it as a connection.
#[pin_project]
#[derive(Debug)]
pub struct NamedPipeStream {
    /// The connected pipe handle, structurally pinned for the I/O trait impls.
    #[pin]
    io_stream: NamedPipe,
}
impl AsyncWrite for NamedPipeStream {
    /// Write `buf` to the pipe.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, IoError>> {
        let pipe = self.project().io_stream;
        pipe.poll_write(cx, buf)
    }

    /// Flush any buffered data on the pipe.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
        let pipe = self.project().io_stream;
        pipe.poll_flush(cx)
    }

    /// Shut down the write side of the pipe.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), IoError>> {
        let pipe = self.project().io_stream;
        pipe.poll_shutdown(cx)
    }
}
impl AsyncRead for NamedPipeStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<IoResult<()>> {
self.project().io_stream.poll_read(cx, buf)
}
}
impl Connection for NamedPipeStream {
    /// Report connection info to hyper. No extra metadata is attached, so a
    /// default `Connected` value is returned.
    fn connected(&self) -> Connected {
        Connected::new()
    }
}
/// A connector that lets a `hyper::Client` speak HTTP over a Windows named pipe.
///
/// Request URIs must use the format produced by [`Uri::new`]: an `npipe` scheme
/// whose host is the hex-encoded pipe path.
// No manual `Unpin` impl is needed: a unit struct is automatically `Unpin`.
#[derive(Clone, Copy, Debug, Default)]
pub struct NamedPipeConnector;
impl Service<HyperUri> for NamedPipeConnector {
    type Response = NamedPipeStream;
    type Error = IoError;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Always ready; the connector holds no state that could apply backpressure.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Decode the pipe path from `req` and open a connection to it.
    fn call(&mut self, req: HyperUri) -> Self::Future {
        Box::pin(async move {
            let pipe_path = parse_socket_path(req)?;
            let io_stream = NamedPipe::connect(pipe_path).await?;
            Ok(NamedPipeStream { io_stream })
        })
    }
}
#[cfg(test)]
mod unit_tests {
    use super::Uri;
    use crate::assert_eq;
    use hyper::Uri as HyperUri;

    /// `Uri::new` hex-encodes the pipe name into the host of an `npipe://` URI
    /// and adds port `0`.
    #[test]
    fn test_npipe_uri_into_hyper_uri() {
        let _ = *crate::test_helper::REGISTER_HIJACK;
        let npipe: HyperUri = Uri::new(r#"\\.\pipe\docker_engine"#, "/")
            .expect("Failed to parse uri!")
            .into();
        let expected: HyperUri = "npipe://5c5c2e5c706970655c646f636b65725f656e67696e65:0/"
            .parse()
            .unwrap();
        assert_eq!(npipe, expected);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment