Skip to content

Instantly share code, notes, and snippets.

@Crispy13
Created April 21, 2024 07:18
Show Gist options
  • Save Crispy13/320129862c5fe80740a14dcea50519fd to your computer and use it in GitHub Desktop.
Save Crispy13/320129862c5fe80740a14dcea50519fd to your computer and use it in GitHub Desktop.
use rayon::{prelude::*, ThreadPoolBuilder};
use pyo3::{exceptions::PyRuntimeError, prelude::*, pybacked::PyBackedStr, types::PyString};
use anyhow::Error;
/// Owned, GIL-independent handle to a Python `str` object.
type PyStringRef = Py<PyString>;

/// A Python `str` paired with a UTF-8 view of its contents.
///
/// Holds an owned reference to the original string object, so it can be handed back to
/// Python, alongside a [`PyBackedStr`] that dereferences to `&str`.
#[derive(Clone)]
struct PyStringWrapper {
    /// Owned reference to the original Python string object.
    py_string: PyStringRef,
    /// UTF-8 view of the same string's contents.
    backed_str: PyBackedStr,
}

impl PyStringWrapper {
    /// Wraps `bound`, keeping a reference to the object and borrowing its contents as UTF-8.
    ///
    /// # Errors
    ///
    /// Returns the wrapped Python error if `PyBackedStr::try_from` fails, e.g. when the
    /// string cannot be encoded as UTF-8 (such as a string with lone surrogates).
    fn from_bound(bound: Bound<'_, PyString>) -> Result<Self, Error> {
        let py_string = bound.clone().unbind();
        let backed_str = PyBackedStr::try_from(bound)?;
        Ok(Self {
            py_string,
            backed_str,
        })
    }

    /// Returns the string's contents as `&str`.
    fn as_str(&self) -> &str {
        &self.backed_str
    }

    /// Consumes the wrapper and returns the original Python string object.
    ///
    /// The name is historical: the returned value is a `str` object, not `bytes`.
    fn into_py_bytes(self) -> PyStringRef {
        self.py_string
    }
}
#[pyfunction]
#[pyo3(signature = (input,input2))]
pub(crate) fn test<'py>(
input: Vec<[Bound<'py, PyString>; 2]>,
input2: Vec<(Bound<'py, PyString>, i32)>,
) -> Vec<(Py<PyString>, i32)> {
let input_unbound = input
.into_iter()
.map(|e| {
let mut e_iter = e.into_iter();
Ok([
PyStringWrapper::from_bound(e_iter.next().unwrap())?,
PyStringWrapper::from_bound(e_iter.next().unwrap())?,
])
})
.collect::<Result<Vec<_>, Error>>()
.unwrap();
let input2_unbound = input2
.into_iter()
.map(|e| {
// check_first_byte_is_digit(e.0.to_str()?);
Ok((PyStringWrapper::from_bound(e.0)?, e.1))
})
.collect::<Result<Vec<_>, Error>>()
.map_err(|err| PyRuntimeError::new_err(err.to_string()))
.unwrap();
let tp = ThreadPoolBuilder::new()
.num_threads(4 as usize)
.build()
.unwrap();
tp.install(|| {
input_unbound
.iter()
.flat_map(|a| {
input2_unbound
.iter()
.cloned()
.map(move |b| (a, b))
.collect::<Vec<_>>()
})
.map(|(e, r)| {
let res = rust_func(e[0].as_str(), e[1].as_str(), r.0.as_str(), r.1);
(r.0.into_py_bytes(), res)
})
.collect::<Vec<_>>()
})
}
/// Prints all four arguments on one line to stdout and returns `n` multiplied by two.
fn rust_func(x: &str, y: &str, z: &str, n: i32) -> i32 {
    println!("x={x}, y={y}, z={z}, n={n}");
    2 * n
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment