Skip to content

Instantly share code, notes, and snippets.

@s3rius
Last active June 25, 2024 14:17
Show Gist options
  • Save s3rius/3bf4a0bd6b28ca1ae94376aa290f8f1c to your computer and use it in GitHub Desktop.
Save s3rius/3bf4a0bd6b28ca1ae94376aa290f8f1c to your computer and use it in GitHub Desktop.
PyO3-asyncio async streams
[package]
name = "itertest"
version = "0.1.0"
edition = "2021"
[dependencies]
futures = "0.3.28"
pyo3 = "0.19.2"
pyo3-asyncio = { version = "0.19.0", features = ["tokio-runtime"] }
tokio = { version = "1.32.0", features = ["sync"] }
use std::sync::Arc;
use futures::{Stream, StreamExt};
use pyo3::{
exceptions::PyStopAsyncIteration, pymethods, pymodule, types::PyModule, PyObject, PyRef,
PyResult, Python,
};
/// Counter that, when polled as a `Stream`, produces `1, 2, ..., i` and then ends.
///
/// `i` is the inclusive upper bound; `current` is the most recent value handed
/// out (`0` until the first value is produced).
pub struct RustStreamer {
    i: u32,
    current: u32,
}

impl RustStreamer {
    /// Builds a streamer with upper bound `i` that has not yielded anything yet.
    pub fn new(i: u32) -> Self {
        let current = 0;
        Self { current, i }
    }
}
/// `Stream` implementation for `RustStreamer`.
///
/// Every poll resolves immediately (never `Pending`). While `current` is
/// below `i`, it advances `current` by one and yields the new value. Once
/// `current` has reached `i`, it yields `None`, which marks the stream as
/// finished. Further polls keep returning `None`.
impl Stream for RustStreamer {
    type Item = u32;

    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        let state = self.get_mut();

        // Guard clause: nothing left to hand out.
        if state.current >= state.i {
            return std::task::Poll::Ready(None);
        }

        state.current += 1;
        std::task::Poll::Ready(Some(state.current))
    }
}
/// Here I defined a class that can be used,
/// as an async iterator.
///
/// It's a simple class, that has an inner field,
/// which is an object that implements the Stream trait.
///
/// But I wrap it in Mutex<...> to make it thread safe
/// and shareable between tokio-threads. Arc here, because
/// it's cheap to clone.
///
/// Also, without mutex, it's not possible to mutate
/// the data inside the Arc.
// NOTE(review): `#[pyclass]` presumably exports the `///` text above as the
// Python class's `__doc__`; edits to it are user-visible, so keep that in mind.
#[pyo3::pyclass]
struct TestIterator {
    // Shared handle to the streamer. `__anext__` clones this `Arc` into every
    // future it creates. The tokio `Mutex` gives each of those futures
    // exclusive access to the streamer for a single `next()` call; that lock
    // is awaited rather than blocking.
    pub inner: Arc<tokio::sync::Mutex<RustStreamer>>,
}
#[pymethods]
impl TestIterator {
#[new]
fn new(i: u32) -> Self {
TestIterator {
inner: Arc::new(tokio::sync::Mutex::new(RustStreamer::new(i))),
}
}
/// We don't want to create another classes, we want this
/// class to be iterable. Since we implemented __anext__ method,
/// we can return self here.
fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
slf
}
/// This is an anext implementation.
///
/// Notable thing here is that we return PyResult<Option<PyObject>>.
/// We cannot return &PyAny directly here, because of pyo3 limitations.
/// Here's the issue about it: https://github.com/PyO3/pyo3/issues/3190
fn __anext__<'a>(&self, py: Python<'a>) -> PyResult<Option<PyObject>> {
// Here we clone the inner field, so we can use it
// in our future.
let streamer = self.inner.clone();
let future = pyo3_asyncio::tokio::future_into_py(py, async move {
// Here we lock the mutex to access the data inside
// and call next() method to get the next value.
let val = streamer.lock().await.next().await;
match val {
Some(val) => Ok(val),
// Here we return PyStopAsyncIteration error,
// because python needs exceptions to tell that iterator
// has ended.
None => Err(PyStopAsyncIteration::new_err("The iterator is exhausted")),
}
});
Ok(Some(future?.into()))
}
}
// Initialiser for the `itertest._internal` extension module (the name comes
// from `module-name` in pyproject.toml). Its only job is to register
// `TestIterator`.
#[pymodule]
fn _internal(_py: Python<'_>, pymod: &PyModule) -> PyResult<()> {
    // `add_class` already yields `PyResult<()>`, so its outcome is the
    // module initialiser's outcome.
    pymod.add_class::<TestIterator>()
}
[project]
name = "itertest"
[tool.maturin]
python-source = "python"
module-name = "itertest._internal"
features = ["pyo3/extension-module"]
[build-system]
requires = ["maturin>=1.0,<2.0"]
build-backend = "maturin"
import asyncio
from itertest._internal import TestIterator
async def main():
    """Consume a five-element ``TestIterator``, printing each yielded value."""
    iterator = TestIterator(i=5)
    async for value in iterator:
        print(value)
if __name__ == "__main__":
    # Build the coroutine first, then hand it to a fresh event loop.
    coroutine = main()
    asyncio.run(coroutine)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment