@davipatti
Last active June 2, 2023 20:33
[aesara scan IndexError] Using aesara.scan to set a value to 0 if any of the three preceding (already masked) values in its column is non-zero. The output is as expected, but errors are logged: an IndexError is raised inside the save_mem_new_scan rewrite. #aesara
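For reference, the intended recurrence can be written in plain NumPy. Because the scan taps read the previous *outputs* rather than the raw input rows, an entry that was itself zeroed does not block later rows. In the printed data, column 7 shows this: row 3 is zeroed because of the 1 in row 0, and row 4 keeps its 1 even though the input row 3 is 1. This is a minimal sketch. `mask_prev_three_np` and its `window` parameter are illustrative names that are not part of the gist or of aesara. It was checked by hand against the input and output printed below.

```python
import numpy as np


def mask_prev_three_np(arr: np.ndarray, window: int = 3) -> np.ndarray:
    """Zero arr[t, j] whenever any of the `window` previous *output* rows
    in column j is non-zero. Rows before t = 0 count as zeros."""
    arr = np.asarray(arr)
    out = np.zeros_like(arr)
    for t in range(arr.shape[0]):
        # Up to `window` already-masked rows immediately before row t.
        prev = out[max(0, t - window):t]
        # True in columns where one of those rows is non-zero.
        blocked = prev.any(axis=0)
        out[t] = np.where(blocked, 0, arr[t])
    return out
```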
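```python
import aesara as ae
import numpy as np

print("aesara version", ae.__version__)

np.random.seed(42)
m, n = 10, 12
# Random m x n matrix of 0/1 values (roughly 25% ones).
arr = np.random.choice([0, 1], p=[0.75, 0.25], size=m * n).reshape(m, n)

print("\ninput:")
print(arr)


def mask_prev_three(arr: ae.tensor.TensorLike, n: int) -> ae.tensor.TensorVariable:
    """Scan over the rows of `arr` (n columns), zeroing an entry if any of
    the three previous output rows is non-zero in that column."""
    # Output taps: each step also receives the outputs from 3, 2 and 1 steps back.
    taps = -3, -2, -1
    # Initial state for those taps: three rows of zeros (times -3, -2, -1).
    initial = ae.tensor.zeros((3, n), "int8")
    masked, _ = ae.scan(
        # Arguments: current sequence row, then the output taps in the order listed.
        lambda i0, im3, im2, im1: ae.tensor.switch(im3 | im2 | im1, 0, i0),
        sequences=arr,
        outputs_info=dict(taps=taps, initial=initial),
    )
    return masked


masked = mask_prev_three(ae.tensor.as_tensor(arr, dtype="int8"), n).eval()
print("\noutput:")
print(masked)
```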
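How `taps` and `initial` line up with the step function's arguments can be emulated in plain Python. This is a hypothetical sketch of the semantics, assuming the ordering used by the lambda above: the sequence element comes first, then one argument per output tap in the listed order, with the rows of `initial` standing for times -3, -2 and -1. `emulate_mitsot_scan` is an illustrative name; it is not how aesara implements scan.

```python
import numpy as np


def emulate_mitsot_scan(step, sequence, initial, taps=(-3, -2, -1)):
    """Hypothetical pure-NumPy emulation of a scan with one sequence and one
    output that is fed back through negative `taps`."""
    initial = np.asarray(initial)
    # The rows of `initial` stand for the outputs at times -len(initial) .. -1.
    offset = len(initial)
    if offset < max(-tap for tap in taps):
        raise ValueError("initial needs at least max(-tap) rows")
    history = list(initial)
    for t, x_t in enumerate(np.asarray(sequence)):
        # The output at time t + tap is stored at index offset + t + tap.
        tap_values = [history[offset + t + tap] for tap in taps]
        history.append(np.asarray(step(x_t, *tap_values)))
    # Drop the initial rows and keep one output per sequence step.
    return np.stack(history[offset:])
```

Passing the NumPy analogue of the gist's step function, `lambda i0, im3, im2, im1: np.where(im3 | im2 | im1, 0, i0)`, with three zero rows as `initial`, yields the masked matrix one row per step.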
```
aesara version 2.9.0

input:
[[0 1 0 0 0 0 0 1 0 0 0 1]
 [1 0 0 0 0 0 0 0 0 0 0 0]
 [0 1 0 0 0 0 0 0 0 1 1 1]
 [0 0 0 0 0 0 0 1 0 0 0 0]
 [0 0 1 1 1 1 0 1 0 0 0 0]
 [0 0 1 0 0 0 0 1 0 1 1 0]
 [0 1 0 0 1 0 0 0 1 0 0 0]
 [0 0 0 0 1 0 0 0 1 0 1 0]
 [0 0 0 0 0 0 0 0 1 0 0 1]
 [0 0 0 0 1 1 0 1 1 0 1 0]]
ERROR (aesara.graph.rewriting.basic): Rewrite failure due to: save_mem_new_scan
ERROR (aesara.graph.rewriting.basic): node: for{cpu,scan_fn}(TensorConstant{10}, TensorConstant{[[0 1 0 0 .. 1 0 1 0]]}, IncSubtensor{Set;:int64:}.0)
ERROR (aesara.graph.rewriting.basic): TRACEBACK:
ERROR (aesara.graph.rewriting.basic): Traceback (most recent call last):
  File "/Users/pattinson/.virtualenvs/aesara-env/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py", line 1926, in process_node
    replacements = node_rewriter.transform(fgraph, node)
  File "/Users/pattinson/.virtualenvs/aesara-env/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py", line 1086, in transform
    return self.fn(fgraph, node)
  File "/Users/pattinson/.virtualenvs/aesara-env/lib/python3.10/site-packages/aesara/scan/rewriting.py", line 1431, in save_mem_new_scan
    nw_input = expand_empty(_nw_input, tmp_idx)
  File "/Users/pattinson/.virtualenvs/aesara-env/lib/python3.10/site-packages/aesara/scan/utils.py", line 239, in expand_empty
    new_shape = [size + shapes[0]] + shapes[1:]
IndexError: list index out of range

output:
[[0 1 0 0 0 0 0 1 0 0 0 1]
 [1 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 1 1 0]
 [0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 1 1 1 1 0 1 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0]
 [0 1 0 0 0 0 0 0 1 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 1 0]
 [0 0 0 0 0 0 0 0 0 0 0 1]
 [0 0 0 0 1 1 0 1 0 0 0 0]]
```
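The last traceback frame fails on `new_shape = [size + shapes[0]] + shapes[1:]` in `expand_empty`. Slicing an empty list (`shapes[1:]`) just gives `[]`, so `IndexError: list index out of range` can only come from `shapes[0]`, meaning `shapes` was empty. Assuming `shapes` holds one entry per dimension of the input, as its name suggests, `save_mem_new_scan` appears to have passed `expand_empty` a 0-d value rather than a buffer with a leading time axis. The hypothetical NumPy stand-in below (`expand_empty_sketch` is not an aesara function) reproduces only that line's behaviour. Since the failure surfaces as ERROR log lines and the script still prints its output, the rewriter seems to skip this rewrite rather than abort compilation.

```python
import numpy as np


def expand_empty_sketch(x: np.ndarray, size: int) -> np.ndarray:
    """Hypothetical stand-in: grow `x` by `size` rows along axis 0,
    keeping the original rows first."""
    shapes = list(x.shape)                       # one entry per dimension of x
    new_shape = [size + shapes[0]] + shapes[1:]  # same expression as in the traceback
    out = np.zeros(new_shape, dtype=x.dtype)     # zeros for determinism in this sketch
    out[: shapes[0]] = x
    return out


if __name__ == "__main__":
    print(expand_empty_sketch(np.zeros((3, 2)), 2).shape)  # (5, 2)
    try:
        expand_empty_sketch(np.array(0), 2)  # 0-d input: shapes == []
    except IndexError as err:
        print("IndexError:", err)  # IndexError: list index out of range
```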