ExclusiveTensor providing hopefully-safe views of tch::Tensor data
// This file is dual-licensed under the terms of the MIT or Apache 2.0 licenses.
//
// == MIT ==
// Copyright © 2021 Eric Langlois <eric@langlois.xyz>
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included
// in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
// OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
// IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
// TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE
// OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
//
// == Apache 2.0 ==
// Copyright © 2021 Eric Langlois <eric@langlois.xyz>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use ndarray::{ArrayView, ArrayViewMut, Dim, Dimension, IntoDimension, Ix, IxDyn};
use std::marker::PhantomData;
use std::ptr::NonNull;
use std::{mem, slice};
use tch::{kind::Element, Device, Kind, Tensor};
/// An exclusive owner of a [`Tensor`] and its data.
///
/// Given an ordinary `Tensor`, it is impossible to reason about the lifetime of the data at
/// [`Tensor::data_ptr`]. Copies created by [`Tensor::shallow_clone`] share the same underlying
/// tensor object and can cause the data memory to be moved or reallocated at any time (for
/// example, by calling [`Tensor::resize_`]).
///
/// To avoid this issue, `ExclusiveTensor` manages the creation of the tensor such that it has
/// exclusive access to the underlying data.
///
/// The managed tensor always lives on the CPU.
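///
/// # Example
///
/// A sketch of typical usage: allocate, fill through a mutable view, then hand the
/// tensor off to torch.
///
/// ```ignore
/// let mut t = ExclusiveTensor::<f32, _>::zeros([2, 3]);
/// for (i, x) in t.as_slice_mut().iter_mut().enumerate() {
///     *x = i as f32;
/// }
/// let tensor: Tensor = t.into_tensor();
/// assert_eq!(tensor.size(), vec![2, 3]);
/// ```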
#[derive(Debug)]
pub struct ExclusiveTensor<E, D>
where
D: Dimension,
{
tensor: Tensor,
/// Track shape to avoid runtime checks
shape: D,
/// Number of elements in the tensor
num_elements: usize,
/// Track element type
element_type: PhantomData<E>,
}
impl<E, D> ExclusiveTensor<E, D>
where
E: Element,
D: Dimension + IntoTorchShape,
{
/// Create a zero-initialized tensor.
pub fn zeros<Sh: IntoDimension<Dim = D>>(shape: Sh) -> Self {
unsafe {
Self::from_tensor_fn(shape, |shape, kind| {
Tensor::zeros(shape, (kind, Device::Cpu))
})
}
}
/// Create a one-initialized tensor.
pub fn ones<Sh: IntoDimension<Dim = D>>(shape: Sh) -> Self {
unsafe {
Self::from_tensor_fn(shape, |shape, kind| {
Tensor::ones(shape, (kind, Device::Cpu))
})
}
}
/// Initialize given a tensor construction function.
///
/// # Safety
/// The constructed tensor must
    /// * have a number of elements corresponding to `shape`,
/// * have elements of type `E`,
/// * use `Device::Cpu`, and
/// * exclusively manage its own memory (e.g. no `shallow_clone`).
///
/// # Panics
/// If the total size of all elements exceeds `isize::MAX`.
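    ///
    /// For example (a sketch), `Tensor::randn` freshly allocates CPU memory of the
    /// requested kind and so satisfies this contract for floating-point `E`; a
    /// normal-sample constructor could be layered on top:
    ///
    /// ```ignore
    /// pub fn randn<Sh: IntoDimension<Dim = D>>(shape: Sh) -> Self {
    ///     unsafe {
    ///         Self::from_tensor_fn(shape, |shape, kind| {
    ///             Tensor::randn(shape, (kind, Device::Cpu))
    ///         })
    ///     }
    /// }
    /// ```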
unsafe fn from_tensor_fn<Sh, F>(shape: Sh, f: F) -> Self
where
Sh: IntoDimension<Dim = D>,
F: FnOnce(&[i64], Kind) -> Tensor,
{
let shape = shape.into_dimension();
        let num_elements = match shape.size_checked() {
            Some(size) if size <= isize::MAX as usize => size,
            _ => panic!("number of elements must not exceed isize::MAX"),
        };
        match num_elements.checked_mul(mem::size_of::<E>()) {
            Some(size) if size <= isize::MAX as usize => {}
            _ => panic!("size of allocated memory must not exceed isize::MAX"),
        }
let tensor = f(shape.clone().into_torch_shape().as_ref(), E::KIND);
Self {
tensor,
shape,
num_elements,
element_type: PhantomData,
}
}
}
impl<E, D: Dimension> ExclusiveTensor<E, D> {
/// Convert into the inner tensor.
pub fn into_tensor(self) -> Tensor {
self.tensor
}
}
impl<E, D> ExclusiveTensor<E, D>
where
E: Element,
D: Dimension,
{
/// View the tensor data as a slice.
pub fn as_slice(&self) -> &[E] {
// # Safety
// ✓ **data must be valid for reads for `len * mem::size_of::<T>()` many bytes,
// and it must be properly aligned.**
        // The tensor stores that amount of data at the pointer, so long as the tensor
        // is non-empty. For empty tensors the pointer is `NonNull::dangling()`, which
        // is aligned and valid for zero-length reads.
//
// ✓ **data must point to len consecutive properly initialized values of type T.**
// The tensor has been fully initialized with valid data.
//
// ✓ **The memory referenced by the returned slice must not be mutated for the duration of
// lifetime 'a, except inside an UnsafeCell.**
// Managed by the lifetime of self, which has exclusive access to the tensor memory.
//
// ✓ **The total size len * mem::size_of::<T>() must be no larger than isize::MAX.**
        // Asserted in construction and probably must hold for Tensor anyway.
unsafe { slice::from_raw_parts(self.data_ptr().as_ptr(), self.num_elements) }
}
/// View the tensor data as a mutable slice.
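    ///
    /// For example (a sketch), the slice can be used to fill the tensor in place:
    ///
    /// ```ignore
    /// let mut t = ExclusiveTensor::<i64, _>::zeros([4]);
    /// t.as_slice_mut().copy_from_slice(&[1, 2, 3, 4]);
    /// ```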
pub fn as_slice_mut(&mut self) -> &mut [E] {
// # Safety
// See `Self::as_slice` implementation
unsafe { slice::from_raw_parts_mut(self.data_ptr().as_ptr(), self.num_elements) }
}
/// View as an n-dimensional array.
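    ///
    /// The view borrows `self`, so the underlying memory cannot be resized or freed
    /// while the view is alive. A sketch:
    ///
    /// ```ignore
    /// let t = ExclusiveTensor::<f32, _>::ones([2, 3]);
    /// assert_eq!(t.array_view().sum(), 6.0);
    /// ```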
pub fn array_view(&self) -> ArrayView<E, D> {
// # Safety
//
// ✓ **Elements must live as long as 'a (in ArrayView<'a, E, D>).**
// Managed by the lifetime of self, which has exclusive access to the tensor memory.
//
// ✓ **ptr must be non-null and aligned, and it must be safe to .offset() ptr by zero.**
// This is up to torch but it should be true for non-empty tensors since data is being
// stored at this pointer value.
// In the case of empty tensors, the data pointer is NonNull::dangling.
//
// ? **It must be safe to .offset() the pointer repeatedly along all axes and calculate the
// counts for the .offset() calls without overflow, even if the array is empty or the
// elements are zero-sized.**
// Up to pytorch but again it should be true since the full tensor's worth of data is
// being stored at this pointer value.
//
// ✓ **The product of non-zero axis lengths must not exceed isize::MAX.**
        // Asserted in the constructor, and a similar constraint probably applies to
        // tensor creation by pytorch anyway.
//
// ✓ **Strides must be non-negative.**
        // `from_shape_ptr` called with a plain `Dimension` uses default C-style
        // (row-major) strides, which are computed from the shape and are always
        // non-negative.
unsafe { ArrayView::from_shape_ptr(self.shape.clone(), self.data_ptr().as_ptr()) }
}
/// View as a mutable n-dimensional array.
pub fn array_view_mut(&mut self) -> ArrayViewMut<E, D> {
// # Safety
// See `Self::array_view` implementation
unsafe { ArrayViewMut::from_shape_ptr(self.shape.clone(), self.data_ptr().as_ptr()) }
}
/// The current tensor data pointer; may be dangling if the tensor is empty.
///
/// This is not cached in case additional methods are added that can cause the tensor to
/// re-allocate.
fn data_ptr(&self) -> NonNull<E> {
if self.num_elements == 0 {
NonNull::dangling()
} else {
NonNull::new(self.tensor.data_ptr() as _).expect("unexpected null data_ptr")
}
}
}
impl<E, D: Dimension> From<ExclusiveTensor<E, D>> for Tensor {
fn from(exclusive: ExclusiveTensor<E, D>) -> Self {
exclusive.into_tensor()
}
}
impl<'a, E, D> From<&'a ExclusiveTensor<E, D>> for ArrayView<'a, E, D>
where
E: Element,
D: Dimension,
{
fn from(exclusive: &'a ExclusiveTensor<E, D>) -> Self {
exclusive.array_view()
}
}
fn to_i64(x: Ix) -> i64 {
x.try_into().expect("dimension too large")
}
/// Convert an ndarray-style dimension into the shape type used by [`tch`].
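///
/// Fixed-size dimensions convert to fixed-size arrays (for example, `Ix2` yields
/// `[i64; 2]`), avoiding a heap allocation; the dynamic `IxDyn` falls back to
/// `Vec<i64>`.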
pub trait IntoTorchShape {
type TorchDim: AsRef<[i64]>;
fn into_torch_shape(self) -> Self::TorchDim;
}
impl IntoTorchShape for IxDyn {
type TorchDim = Vec<i64>;
fn into_torch_shape(self) -> Self::TorchDim {
self.as_array_view()
.into_iter()
.map(|&x| to_i64(x))
.collect()
}
}
impl IntoTorchShape for Dim<[Ix; 0]> {
type TorchDim = [i64; 0];
fn into_torch_shape(self) -> Self::TorchDim {
[]
}
}
impl IntoTorchShape for Dim<[Ix; 1]> {
type TorchDim = [i64; 1];
    fn into_torch_shape(self) -> Self::TorchDim {
        // Use the checked conversion for consistency with the other impls;
        // `as _` would silently wrap on overflow.
        [to_i64(self.into_pattern())]
    }
}
impl IntoTorchShape for Dim<[Ix; 2]> {
type TorchDim = [i64; 2];
fn into_torch_shape(self) -> Self::TorchDim {
let (a, b) = self.into_pattern();
[to_i64(a), to_i64(b)]
}
}
impl IntoTorchShape for Dim<[Ix; 3]> {
type TorchDim = [i64; 3];
fn into_torch_shape(self) -> Self::TorchDim {
let (a, b, c) = self.into_pattern();
[to_i64(a), to_i64(b), to_i64(c)]
}
}
impl IntoTorchShape for Dim<[Ix; 4]> {
type TorchDim = [i64; 4];
fn into_torch_shape(self) -> Self::TorchDim {
let (a, b, c, d) = self.into_pattern();
[to_i64(a), to_i64(b), to_i64(c), to_i64(d)]
}
}
impl IntoTorchShape for Dim<[Ix; 5]> {
type TorchDim = [i64; 5];
#[allow(clippy::many_single_char_names)]
fn into_torch_shape(self) -> Self::TorchDim {
let (a, b, c, d, e) = self.into_pattern();
[to_i64(a), to_i64(b), to_i64(c), to_i64(d), to_i64(e)]
}
}
impl IntoTorchShape for Dim<[Ix; 6]> {
type TorchDim = [i64; 6];
#[allow(clippy::many_single_char_names)]
fn into_torch_shape(self) -> Self::TorchDim {
let (a, b, c, d, e, f) = self.into_pattern();
[
to_i64(a),
to_i64(b),
to_i64(c),
to_i64(d),
to_i64(e),
to_i64(f),
]
}
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::{arr2, Array};
#[test]
fn zeros() {
let u = ExclusiveTensor::<f32, _>::zeros([2, 4, 3]);
let tensor: Tensor = u.into();
assert_eq!(tensor.size(), vec![2, 4, 3]);
assert_eq!(tensor.kind(), Kind::Float);
assert_eq!(tensor.device(), Device::Cpu);
assert_eq!(
tensor,
Tensor::zeros(&[2, 4, 3], (Kind::Float, Device::Cpu))
);
}
#[test]
fn ones() {
let u = ExclusiveTensor::<f32, _>::ones([2, 4, 3]);
let tensor: Tensor = u.into();
assert_eq!(tensor.size(), vec![2, 4, 3]);
assert_eq!(tensor.kind(), Kind::Float);
assert_eq!(tensor.device(), Device::Cpu);
assert_eq!(tensor, Tensor::ones(&[2, 4, 3], (Kind::Float, Device::Cpu)));
}
#[test]
#[allow(clippy::float_cmp)]
fn slice_f64() {
let u = ExclusiveTensor::<f64, _>::ones([3, 1, 2]);
assert_eq!(u.as_slice().len(), 6);
assert_eq!(u.as_slice(), &[1.0, 1.0, 1.0, 1.0, 1.0, 1.0]);
}
#[test]
fn slice_mut_i16() {
let mut u = ExclusiveTensor::<i16, _>::ones([3, 1, 2]);
assert_eq!(u.as_slice_mut().len(), 6);
for (i, x) in u.as_slice_mut().iter_mut().enumerate() {
*x = i.try_into().unwrap()
}
assert_eq!(u.as_slice(), &[0, 1, 2, 3, 4, 5]);
let tensor: Tensor = u.into();
assert_eq!(
tensor,
Tensor::of_slice(&[0, 1, 2, 3, 4, 5]).reshape(&[3, 1, 2])
);
}
#[test]
fn array_view_f32() {
let u = ExclusiveTensor::<f32, _>::ones([2, 4, 3]);
let view = u.array_view();
assert_eq!(view.dim(), (2, 4, 3));
assert_eq!(view, Array::ones((2, 4, 3)));
}
#[test]
#[allow(clippy::unit_cmp)]
fn array_view_i64_scalar() {
let u = ExclusiveTensor::<i64, _>::ones([]);
let view = u.array_view();
assert_eq!(view.dim(), ());
assert_eq!(view.into_scalar(), &1);
}
#[test]
fn array_view_f32_empty() {
let u = ExclusiveTensor::<f32, _>::ones([0]);
let view = u.array_view();
assert_eq!(view.dim(), 0);
assert!(view.as_slice().unwrap().is_empty());
}
#[test]
fn array_view_mut() {
let mut u = ExclusiveTensor::<i32, _>::ones([3, 4]);
let mut view = u.array_view_mut();
for (i, mut row) in view.rows_mut().into_iter().enumerate() {
for (j, cell) in row.iter_mut().enumerate() {
*cell = (i * 10 + j).try_into().unwrap();
}
}
let expected = arr2(&[[0, 1, 2, 3], [10, 11, 12, 13], [20, 21, 22, 23]]);
assert_eq!(view, expected); // Compare as arrays
let t: Tensor = u.into();
let expected: Tensor = expected.try_into().unwrap();
assert_eq!(t, expected); // Compare as tensors
}
#[test]
fn array_view_mut_empty() {
let mut u = ExclusiveTensor::<f32, _>::ones([2, 0, 3]);
let mut view = u.array_view_mut();
assert!(view.as_slice_mut().unwrap().is_empty());
}
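    // A small sketch-test of the shape conversion trait, using the `Ix2` and
    // `IxDyn` constructor functions re-exported by ndarray.
    #[test]
    fn into_torch_shape_dims() {
        assert_eq!(ndarray::Ix2(2, 3).into_torch_shape(), [2_i64, 3]);
        assert_eq!(
            ndarray::IxDyn(&[4, 5, 6]).into_torch_shape(),
            vec![4_i64, 5, 6]
        );
    }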
}