Skip to content

Instantly share code, notes, and snippets.

@GleasonK
Created April 4, 2024 21:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save GleasonK/ad7bc21bd0c923763cb0423cc6424f83 to your computer and use it in GitHub Desktop.
Save GleasonK/ad7bc21bd0c923763cb0423cc6424f83 to your computer and use it in GitHub Desktop.
// import jax
// import jax.numpy as jnp
// import numpy as np
// from jax.experimental.export import export
//
// def f(a, b):
//     return jnp.add(a, b)
//
// lhs = jax.ShapeDtypeStruct(export.symbolic_shape("a"), np.float32)
// rhs = jax.ShapeDtypeStruct(export.symbolic_shape("2,a"), np.float32)
// export.export(f)(lhs, rhs)
//
// StableHLO module produced by jax.export for f(a, b) = a + b under shape
// polymorphism: args[0] has symbolic shape (a,) and args[1] has shape (2, a),
// per the shape-assertion error messages below.
module @jit_f attributes {jax.uses_shape_polymorphism = true, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  // Public entry point: reads the runtime value of dimension variable 'a'
  // from both inputs, checks it against the polymorphic-shape spec via
  // @shape_assertion custom calls, then forwards to the wrapped computation.
  func.func public @main(%arg0: tensor<?xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, %arg1: tensor<2x?xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}) -> (tensor<2x?xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    // Runtime sizes of the dynamic dimension in each argument.
    %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?xf32>) -> tensor<i32>
    %1 = stablehlo.get_dimension_size %arg1, dim = 1 : (tensor<2x?xf32>) -> tensor<i32>
    %2 = stablehlo.constant dense<1> : tensor<i32>
    // Assert 'a' >= 1 (dimension variables must be positive).
    %3 = stablehlo.compare  GE, %0, %2,  SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
    stablehlo.custom_call @shape_assertion(%3, %0) {api_version = 2 : i32, error_message = "Input shapes do not match the polymorphic shapes specification. Expected value >= 1 for dimension variable 'a'. Using the following polymorphic shapes specifications: args[0].shape = (a,),args[1].shape = (2, a). Obtained dimension variables: 'a' = {0} from specification 'a' for dimension args[0].shape[0] (= {0}), . Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details.", has_side_effect = true} : (tensor<i1>, tensor<i32>) -> ()
    // Assert both inputs agree on the value of 'a'.
    %4 = stablehlo.compare  EQ, %1, %0,  SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
    stablehlo.custom_call @shape_assertion(%4, %1, %0) {api_version = 2 : i32, error_message = "Input shapes do not match the polymorphic shapes specification. Found inconsistency between dimension size args[1].shape[1] (= {0}) and the specification 'a' (= {1}). Using the following polymorphic shapes specifications: args[0].shape = (a,),args[1].shape = (2, a). Obtained dimension variables: 'a' = {1} from specification 'a' for dimension args[0].shape[0] (= {1}), . Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details.", has_side_effect = true} : (tensor<i1>, tensor<i32>, tensor<i32>) -> ()
    // NOTE(review): removed an unused `stablehlo.constant dense<> : tensor<0xi1>`
    // that the exporter emitted here; it had no uses.
    %5 = call @_wrapped_jax_export_main(%0, %arg0, %arg1) : (tensor<i32>, tensor<?xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
    return %5 : tensor<2x?xf32>
  }
  // The actual computation. %arg0 carries the runtime value of 'a' as a
  // global constant so the body can materialize dynamic result shapes.
  func.func private @_wrapped_jax_export_main(%arg0: tensor<i32> {jax.global_constant = "a"}, %arg1: tensor<?xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, %arg2: tensor<2x?xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}) -> (tensor<2x?xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    // Build target shape [1, a] and broadcast %arg1 from (a,) to (1, a).
    %0 = stablehlo.constant dense<1> : tensor<1xi32>
    %1 = stablehlo.reshape %arg0 : (tensor<i32>) -> tensor<1xi32>
    %2 = stablehlo.concatenate %0, %1, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %3 = stablehlo.dynamic_broadcast_in_dim %arg1, %2, dims = [1] : (tensor<?xf32>, tensor<2xi32>) -> tensor<1x?xf32>
    // Build target shape [2, a] and broadcast (1, a) up to (2, a) to match %arg2.
    %4 = stablehlo.constant dense<2> : tensor<1xi32>
    %5 = stablehlo.reshape %arg0 : (tensor<i32>) -> tensor<1xi32>
    %6 = stablehlo.concatenate %4, %5, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %7 = stablehlo.dynamic_broadcast_in_dim %3, %6, dims = [0, 1] : (tensor<1x?xf32>, tensor<2xi32>) -> tensor<2x?xf32>
    // Elementwise add now that both operands are (2, a).
    %8 = stablehlo.add %7, %arg2 : tensor<2x?xf32>
    return %8 : tensor<2x?xf32>
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment