Last active
September 23, 2017 19:01
-
-
Save edubart/00c09fe14390be4acefa610a01f51e11 to your computer and use it in GitHub Desktop.
Arraymancer Utilities
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Seq of tensors to tensor
proc toTensor*[T](s: openarray[Tensor[T]]): Tensor[T] =
  ## Stacks the tensors of `s` along a new leading axis.
  ##
  ## Each input first gets a length-1 axis inserted at position 0 (via
  ## `unsafeUnsqueeze`), then all of them are concatenated along axis 0.
  ## NOTE(review): `concat` presumably requires every input to share the
  ## same shape; an empty `s` is not handled here — confirm with callers.
  var slices = newSeq[Tensor[T]](s.len)
  for i, item in s:
    slices[i] = item.unsafeUnsqueeze(0)
  result = concat(slices, 0)
# `makeUniversal` does not work with abs, so the element-wise overload is
# written by hand.
proc abs*[T](t: Tensor[T]): Tensor[T] =
  ## Element-wise absolute value: applies the scalar `abs` to every element
  ## of `t` through the tensor `map`, returning a new tensor.
  let elementAbs = proc(v: T): T = abs(v)
  result = t.map(elementAbs)
# Seq to reshaped tensor, no copy
proc unsafeToTensorReshape[T](data: seq[T], shape: openarray[int]): Tensor[T] {.noSideEffect.} =
  ## Wraps `data` in a tensor of the given `shape` without an element copy.
  ##
  ## The buffer is attached with `shallowCopy`, so no element-wise copy is
  ## intended. Strides are derived from `shape` via `shape_to_strides`, and
  ## the offset is 0.
  ##
  ## The product of the dimensions in `shape` must equal `data.len`. This is
  ## checked with `assert`.
  var size = 1
  for dim in shape:
    size *= dim
  # A mismatched shape would give strides that do not describe the buffer
  # (indexing past its end, or ignoring part of it).
  assert size == data.len, "unsafeToTensorReshape: shape does not match data length"
  result.shape = @shape
  result.strides = shape_to_strides(result.shape)
  result.offset = 0
  shallowCopy(result.data, data)
# This is not the full implementation
template unsafeAt[T](t: Tensor[T], x: int): Tensor[T] =
  ## Returns the `x`-th slice along axis 0 of a rank-3 tensor, reshaped to
  ## the 2-D shape `[t.shape[1], t.shape[2]]`.
  ##
  ## Only the rank-3 case is supported: the view uses exactly three indices
  ## and the reshape reads `shape[1]` and `shape[2]`.
  ## NOTE(review): as a template, the `t` argument expression is substituted
  ## at each use (three times here); pass a plain variable, not an expensive
  ## expression.
  unsafeReshape(t.unsafeView(x, _, _), [t.shape[1], t.shape[2]])
# unsafeSqueeze on axis
proc unsafeSqueeze*[T](t: Tensor[T], axis: int): Tensor[T] {.noSideEffect,inline.} =
  ## Removes the length-1 axis at position `axis` from `t`'s shape.
  ##
  ## The reduced shape is passed to `unsafeReshape`, so no data copy is
  ## intended. Asserts that the removed axis really has length 1.
  var shape = t.shape
  assert shape[axis] == 1
  delete(shape, axis)
  result = unsafeReshape(t, shape)
# unsqueeze on axis
proc unsafeUnsqueeze*(t: Tensor, axis: int): Tensor {.noSideEffect, inline.} =
  ## Inserts a new axis of length 1 at position `axis` in `t`'s shape.
  ##
  ## Uses `unsafeReshape` rather than the copying `reshape`, so that the
  ## `unsafe` prefix holds (no data copy intended) and the proc mirrors
  ## `unsafeSqueeze`.
  ## NOTE(review): valid positions are presumably 0..rank of `t`; an
  ## out-of-range `axis` is rejected by `seq.insert`'s bounds checking.
  var shape = t.shape
  shape.insert(1, axis)
  t.unsafeReshape(shape)
proc unsafeTranspose*(t: Tensor): Tensor {.noSideEffect, inline.}=
  ## Transpose a Tensor without copying its data.
  ##
  ## For an N-d Tensor with axes (0, 1, 2 ... n-1), the resulting tensor has
  ## axes (n-1, ... 2, 1, 0). Both shape and strides are reversed, and the
  ## offset is kept unchanged.
  ##
  ## The data buffer is attached with `shallowCopy` rather than copied. The
  ## result is therefore intended as a view sharing storage with `t`, not
  ## an independent copy.
  result.shape = t.shape.reversed
  result.strides = t.strides.reversed
  result.offset = t.offset
  shallowCopy(result.data, t.data)
proc fold2*[T](t1: Tensor[T],
               start_val: T,
               f: (T, T, T)-> T,
               t2: Tensor[T]
               ): T {.noSideEffect.}=
  ## Folds two tensors element-wise in lock-step.
  ##
  ## Starting from `start_val`, an accumulator is updated as
  ## `acc = f(acc, a, b)` for each pair `(a, b)` yielded by zipping
  ## `t1.values` with `t2.values`. The final accumulator is returned.
  ## NOTE(review): shapes are not checked here; behaviour on mismatched
  ## sizes depends on `zip` — confirm callers pass same-sized tensors.
  var acc = start_val
  for ai, bi in zip(t1.values, t2.values):
    acc = f(acc, ai, bi)
  acc
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment