Skip to content

Instantly share code, notes, and snippets.

@sergei-mironov
Created November 10, 2023 10:14
Show Gist options
  • Save sergei-mironov/aeebdf7cdeaf7c600ee22b1e0229e621 to your computer and use it in GitHub Desktop.
jax-0.4.19-dynshape.patch
diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index d168d22ab..798d07a6d 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -1461,7 +1461,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
ans, "lowering function returned a bad output", eqn)
assert len(ans) == len(eqn.outvars), (ans, eqn)
map(write, eqn.outvars, out_nodes)
- core.clean_up_dead_vars(eqn, env, last_used)
+ # core.clean_up_dead_vars(eqn, env, last_used)
return map(read, jaxpr.outvars), tokens
# See docstring for lower_multi_platform.
diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py
index 2e3ac7005..82e86080f 100644
--- a/jax/_src/lax/slicing.py
+++ b/jax/_src/lax/slicing.py
@@ -1534,15 +1534,15 @@ def _gather_shape_rule(operand, indices, *, dimension_numbers,
f"output_slice_sizes={offset_dims}, collapsed_slice_dims="
f"{collapsed_slice_dims}.")
- for i in range(len(slice_sizes)):
- slice_size = slice_sizes[i]
- corresponding_input_size = operand.shape[i]
-
- if not (slice_size >= 0 and
- corresponding_input_size >= slice_size):
- raise TypeError(f"Slice size at index {i} in gather op is out of range, "
- f"must be within [0, {corresponding_input_size} + 1), "
- f"got {slice_size}.")
+ # for i in range(len(slice_sizes)):
+ # slice_size = slice_sizes[i]
+ # corresponding_input_size = operand.shape[i]
+
+ # if not (slice_size >= 0 and
+ # corresponding_input_size >= slice_size):
+ # raise TypeError(f"Slice size at index {i} in gather op is out of range, "
+ # f"must be within [0, {corresponding_input_size} + 1), "
+ # f"got {slice_size}.")
for i in range(len(collapsed_slice_dims)):
bound = slice_sizes[collapsed_slice_dims[i]]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment