Last active
June 2, 2023 20:33
-
-
Save davipatti/e56b99f44da3487eb48d55119f814f57 to your computer and use it in GitHub Desktop.
[aesara scan IndexError] Using aesara.scan to make values 0 if any three preceding values in a column are non-zero. Output is as expected, but errors are logged, and IndexError gets raised. #aesara
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import aesara as ae | |
import numpy as np | |
print("aesara version", ae.__version__) | |
np.random.seed(42) | |
m, n = 10, 12 | |
arr = np.random.choice([0, 1], p=[0.75, 0.25], size=m * n).reshape(m, n) | |
print("\ninput:") | |
print(arr) | |
def mask_prev_three(arr: ae.tensor.TensorLike, n: int) -> ae.tensor.TensorVariable: | |
taps = -3, -2, -1 | |
initial = ae.tensor.zeros((3, n), "int8") | |
masked, _ = ae.scan( | |
lambda i0, im3, im2, im1: ae.tensor.switch(im3 | im2 | im1, 0, i0), | |
sequences=arr, | |
outputs_info=dict(taps=taps, initial=initial), | |
) | |
return masked | |
masked = mask_prev_three(ae.tensor.as_tensor(arr, dtype="int8"), n).eval() | |
print("\noutput:") | |
print(masked) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
aesara version 2.9.0 | |
input: | |
[[0 1 0 0 0 0 0 1 0 0 0 1] | |
[1 0 0 0 0 0 0 0 0 0 0 0] | |
[0 1 0 0 0 0 0 0 0 1 1 1] | |
[0 0 0 0 0 0 0 1 0 0 0 0] | |
[0 0 1 1 1 1 0 1 0 0 0 0] | |
[0 0 1 0 0 0 0 1 0 1 1 0] | |
[0 1 0 0 1 0 0 0 1 0 0 0] | |
[0 0 0 0 1 0 0 0 1 0 1 0] | |
[0 0 0 0 0 0 0 0 1 0 0 1] | |
[0 0 0 0 1 1 0 1 1 0 1 0]] | |
ERROR (aesara.graph.rewriting.basic): Rewrite failure due to: save_mem_new_scan | |
ERROR (aesara.graph.rewriting.basic): node: for{cpu,scan_fn}(TensorConstant{10}, TensorConstant{[[0 1 0 0 .. 1 0 1 0]]}, IncSubtensor{Set;:int64:}.0) | |
ERROR (aesara.graph.rewriting.basic): TRACEBACK: | |
ERROR (aesara.graph.rewriting.basic): Traceback (most recent call last): | |
File "/Users/pattinson/.virtualenvs/aesara-env/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py", line 1926, in process_node | |
replacements = node_rewriter.transform(fgraph, node) | |
File "/Users/pattinson/.virtualenvs/aesara-env/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py", line 1086, in transform | |
return self.fn(fgraph, node) | |
File "/Users/pattinson/.virtualenvs/aesara-env/lib/python3.10/site-packages/aesara/scan/rewriting.py", line 1431, in save_mem_new_scan | |
nw_input = expand_empty(_nw_input, tmp_idx) | |
File "/Users/pattinson/.virtualenvs/aesara-env/lib/python3.10/site-packages/aesara/scan/utils.py", line 239, in expand_empty | |
new_shape = [size + shapes[0]] + shapes[1:] | |
IndexError: list index out of range | |
output: | |
[[0 1 0 0 0 0 0 1 0 0 0 1] | |
[1 0 0 0 0 0 0 0 0 0 0 0] | |
[0 0 0 0 0 0 0 0 0 1 1 0] | |
[0 0 0 0 0 0 0 0 0 0 0 0] | |
[0 0 1 1 1 1 0 1 0 0 0 0] | |
[0 0 0 0 0 0 0 0 0 0 0 0] | |
[0 1 0 0 0 0 0 0 1 0 0 0] | |
[0 0 0 0 0 0 0 0 0 0 1 0] | |
[0 0 0 0 0 0 0 0 0 0 0 1] | |
[0 0 0 0 1 1 0 1 0 0 0 0]] |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment