Skip to content

Instantly share code, notes, and snippets.

@robertknight
Last active January 24, 2024 07:52
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save robertknight/ad54cc02a79d0824e6e576401d3d433e to your computer and use it in GitHub Desktop.
Save robertknight/ad54cc02a79d0824e6e576401d3d433e to your computer and use it in GitHub Desktop.
rten-ndarray conversion
use ndarray::{Array, Array2, ArrayView, Dim, Dimension, Ix, StrideShape};
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, NdTensorView};
/// Reinterpret an N-dimensional ndarray view as an [NdTensorView] over the
/// same elements.
///
/// The returned view borrows from `view`'s underlying storage for the
/// lifetime `'a`.
///
/// Returns `None` when `view` cannot be exposed as a single slice, i.e. it is
/// not in the standard layout (see [ArrayView::is_standard_layout]).
fn as_ndtensor_view<'a, T, const N: usize>(
    view: ArrayView<'a, T, Dim<[Ix; N]>>,
) -> Option<NdTensorView<'a, T, N>>
where
    Dim<[Ix; N]>: Dimension,
{
    // Bail out early for views that ndarray cannot hand back as one slice.
    let elements = view.to_slice()?;

    // The dimension type fixes the rank at N, so the slice of extents is
    // expected to convert into an N-element array.
    let dims: [usize; N] = view.shape().try_into().unwrap();

    Some(NdTensorView::from_data(dims, elements))
}
/// Reinterpret an N-dimensional [NdTensorView] as an ndarray [ArrayView] over
/// the same elements.
///
/// Returns `None` when the tensor view does not expose its elements as a
/// slice, i.e. it is not in "standard layout" (see
/// [ArrayView::is_standard_layout]).
///
/// # Panics
///
/// Panics if [ArrayView::from_shape] rejects the tensor's shape for the
/// slice it was given.
fn as_array_view<'a, T, const N: usize>(
    view: NdTensorView<'a, T, N>,
) -> Option<ArrayView<'a, T, Dim<[Ix; N]>>>
where
    Dim<[Ix; N]>: Dimension,
    [usize; N]: Into<StrideShape<Dim<[Ix; N]>>>,
{
    // No slice available means there is nothing to wrap.
    let elements = view.data()?;

    Some(ArrayView::from_shape(view.shape(), elements).unwrap())
}
/// Consume an N-dimensional [NdTensor] and build an owned ndarray with the
/// same shape from its data.
///
/// # Panics
///
/// Panics if [Array::from_shape_vec] rejects the tensor's data for its shape.
///
/// NOTE(review): this presumably relies on `into_data` yielding the elements
/// in row-major order for `shape` — confirm for non-contiguous tensors.
fn into_array<T, const N: usize>(tensor: NdTensor<T, N>) -> Array<T, Dim<[Ix; N]>>
where
    Dim<[Ix; N]>: Dimension,
    [usize; N]: Into<StrideShape<Dim<[Ix; N]>>>,
{
    // Read the shape before the tensor is consumed for its data.
    let shape = tensor.shape();
    let data = tensor.into_data();

    Array::from_shape_vec(shape, data).unwrap()
}
/// Demo: round-trips a small 2x2 matrix through each conversion helper and
/// prints the results.
fn main() {
    // Owned ndarray => NdTensorView
    let cells: [([usize; 2], f32); 4] = [
        ([0, 0], 1.),
        ([0, 1], 2.),
        ([1, 0], 3.),
        ([1, 1], 4.),
    ];
    let mut array: Array2<f32> = Array2::zeros([2, 2]);
    for (position, value) in cells {
        array[position] = value;
    }

    let tensor_view = as_ndtensor_view(array.view()).expect("non-contiguous view");
    tensor_view
        .indices()
        .zip(tensor_view.iter())
        .for_each(|(idx, el)| println!("index {:?} element {}", idx, el));

    // NdTensor => ArrayView
    let transposed = tensor_view.permuted([1, 0]).to_tensor();
    let borrowed_array = as_array_view(transposed.view()).expect("non-contiguous view");
    println!("ndarray_view {:?}", borrowed_array);

    // NdTensor => Array
    let owned_array = into_array(transposed);
    println!("ndarray {:?}", owned_array);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment