============================= test session starts ============================== | |
platform linux -- Python 3.10.12, pytest-7.3.2, pluggy-1.0.0 | |
rootdir: /papyrax/tests | |
configfile: pytest.ini | |
collected 1 item | |
tests/test_attention.py F [100%] | |
=================================== FAILURES =================================== | |
_____________________ test_triton_flash_attention[384-384] _____________________ | |
ctx = TritonModuleContext(name='mha_forward', ir_context=<triton._C.libtriton.triton.ir.context object at 0x7fbc21d4bef0>, b...9660>, <triton.language.core.tensor object at 0x7fbb28199990>, <triton.language.core.tensor object at 0x7fbb28199b10>]) | |
jaxpr = { lambda ; a:Ref{bfloat16[384,64]} b:Ref{bfloat16[384,64]} c:Ref{bfloat16[384,64]} | |
d:Ref{int32[384]} e:Ref{int32[3...] bl | |
ea:bool[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] p | |
h[m:m+128,:] <- dz | |
in () } | |
block_infos = [BlockInfo(full_shape_dtype=ShapeDtypeStruct(shape=(2, 384, 2, 64), dtype=bfloat16), start_indices=(<triton.language.c...e.tensor object at 0x7fbb28199db0>), block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384)), ...] | |
args = (<triton.language.core.tensor object at 0x7fbb28150c70>, <triton.language.core.tensor object at 0x7fbb281510f0>, <trit...>, <triton.language.core.tensor object at 0x7fbb28198460>, <triton.language.core.tensor object at 0x7fbb28198b20>, ...) | |
read_env = <function lower_jaxpr_to_triton_ir.<locals>.read_env at 0x7fbc21d9ab00> | |
read_block_info_env = <function lower_jaxpr_to_triton_ir.<locals>.read_block_info_env at 0x7fbc21d9ae60> | |
write_env = <function lower_jaxpr_to_triton_ir.<locals>.write_env at 0x7fbc21d9aef0> | |
def lower_jaxpr_to_triton_ir( | |
ctx: TritonModuleContext, | |
jaxpr: jax_core.Jaxpr, | |
block_infos: Optional[Sequence[Optional[BlockInfo]]], | |
*args | |
) -> Sequence[Any]: | |
env = {} | |
block_info_env = {} | |
def read_env(var: jax_core.Atom): | |
if type(var) is jax_core.Literal: | |
t = tl.core._to_tensor(np.array(var.val).tolist(), builder=ctx.builder) | |
dst_ty = code_gen.str_to_ty(get_triton_type(var.aval)).element_ty | |
if t.type.scalar != dst_ty: | |
# _to_tensor(np.array(var.val).tolist()) can be lossy e.g. np.float64 | |
# comes out of .tolist() as list[float], which then comes out of | |
# _to_tensor as a block of f32. | |
t = tl.semantic.cast(t, dst_ty, ctx.builder) | |
return t | |
return env[var] | |
def read_block_info_env(var: jax_core.Atom): | |
if type(var) is jax_core.Literal: | |
return None | |
return block_info_env.get(var, None) | |
def write_env(var: jax_core.Var, val): | |
env[var] = val | |
if block_infos is None: | |
block_infos = [None] * len(jaxpr.invars) | |
for invar, block_info in zip(jaxpr.invars, block_infos): | |
block_info_env[invar] = block_info | |
map(write_env, jaxpr.invars, args) | |
for eqn in jaxpr.eqns: | |
invals = map(read_env, eqn.invars) | |
if eqn.primitive not in triton_lowering_rules: | |
raise NotImplementedError(eqn.primitive) | |
rule = triton_lowering_rules[eqn.primitive] | |
avals_in = [v.aval for v in eqn.invars] | |
avals_out = [v.aval for v in eqn.outvars] | |
eqn_block_infos = map(read_block_info_env, eqn.invars) | |
rule_ctx = TritonLoweringRuleContext( | |
ctx, avals_in, avals_out, eqn_block_infos | |
) | |
try: | |
> outvals = rule(rule_ctx, *invals, **eqn.params) | |
/usr/local/lib/python3.10/dist-packages/jax_triton/pallas/triton_lowering.py:274: | |
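Note: the frame above is the generic lowering driver. It walks the kernel jaxpr equation by equation and dispatches each primitive to an entry in the module-level triton_lowering_rules table, failing when a primitive has no rule or, as happens here, when its rule raises. A small helper sketch (assuming jax_triton.pallas.triton_lowering is importable, as the paths above suggest) for listing top-level primitives with no registered rule:

# Sketch: list primitives in a jaxpr that have no Triton lowering rule registered.
# It only inspects top-level equations, and it would not flag this particular
# failure, where `scan` does have a rule but raises for extensive operands.
from jax_triton.pallas import triton_lowering

def primitives_without_rules(jaxpr):
    return sorted({
        eqn.primitive.name
        for eqn in jaxpr.eqns
        if eqn.primitive not in triton_lowering.triton_lowering_rules
    })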
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
ctx = TritonLoweringRuleContext(context=TritonModuleContext(name='mha_forward', ir_context=<triton._C.libtriton.triton.ir.co...tensor object at 0x7fbb28199db0>), block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384)), None]) | |
jaxpr = { lambda ; a:Ref{int32[384]} b:i32[]. let | |
c:i32[] = mul b 64 | |
d:i32[64] = iota[dimension=0 dtype=int32 shape=(6...32[64] = add d c | |
f:bool[64] = lt e 384 | |
g:i32[64] <- a[c:c+64] | |
h:i32[] = reduce_min[axes=(0,)] g | |
in (h,) } | |
linear = (False, False), length = 6, reverse = False | |
def _scan_lowering_rule( | |
ctx: TritonLoweringRuleContext, | |
*args, | |
jaxpr, | |
linear, | |
length, | |
reverse, | |
unroll, | |
num_consts, | |
num_carry, | |
): | |
# Only implements fori_loop-like scans | |
num_extensive = len(args) - num_consts - num_carry | |
> if num_extensive: raise NotImplementedError | |
E NotImplementedError | |
/usr/local/lib/python3.10/dist-packages/jax_triton/pallas/triton_lowering.py:1314: NotImplementedError | |
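Note: the rejected equation is the length-6 scan shown in the lowering context below. It has num_carry=0, one extensive input (the iota over block indices) and one extensive output (the per-block minima of the kv positions), so the num_extensive check above fires. The standalone sketch below is an assumption about how such a jaxpr typically arises, not the papyrax kernel code: a lax.map-style formulation produces exactly this scan shape, while a fori_loop with static bounds becomes a carry-only scan, the "fori_loop-like" form this rule implements. Whether the rest of the kernel body would lower in the carry-only form is a separate question; the sketch only illustrates the scan-shape distinction, at the jaxpr level outside Pallas.

# Standalone illustration (plain JAX, not papyrax code) of the two scan shapes.
import jax
import jax.numpy as jnp

kv_pos = jnp.arange(384, dtype=jnp.int32)
num_blocks, block = 6, 64

def block_min(i):
    return jnp.min(jax.lax.dynamic_slice(kv_pos, (i * block,), (block,)))

# lax.map is sugar for scan with num_carry=0 and an extensive input/output,
# i.e. exactly the shape _scan_lowering_rule rejects.
mapped = jax.lax.map(block_min, jnp.arange(num_blocks))
print(jax.make_jaxpr(lambda: jax.lax.map(block_min, jnp.arange(num_blocks)))())

# A fori_loop with static bounds is implemented as a scan whose operands all
# live in the carry (num_extensive == 0).
looped = jax.lax.fori_loop(
    0, num_blocks,
    lambda i, acc: acc.at[i].set(block_min(i)),
    jnp.full((num_blocks,), jnp.iinfo(jnp.int32).max, dtype=jnp.int32),
)
assert (mapped == looped).all()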
The above exception was the direct cause of the following exception: | |
> return _run_code(code, main_globals, None, | |
"__main__", mod_spec) | |
/usr/lib/python3.10/runpy.py:196: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
> exec(code, run_globals) | |
/usr/lib/python3.10/runpy.py:86: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
> raise SystemExit(pytest.console_main()) | |
/usr/local/lib/python3.10/dist-packages/pytest/__main__.py:5: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
> code = main() | |
/usr/local/lib/python3.10/dist-packages/_pytest/config/__init__.py:189: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
> ret: Union[ExitCode, int] = config.hook.pytest_cmdline_main( | |
config=config | |
) | |
/usr/local/lib/python3.10/dist-packages/_pytest/config/__init__.py:166: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult) | |
/usr/local/lib/python3.10/dist-packages/pluggy/_hooks.py:265: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult) | |
/usr/local/lib/python3.10/dist-packages/pluggy/_manager.py:80: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
> res = hook_impl.function(*args) | |
/usr/local/lib/python3.10/dist-packages/pluggy/_callers.py:39: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
> return wrap_session(config, _main) | |
/usr/local/lib/python3.10/dist-packages/_pytest/main.py:316: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
> session.exitstatus = doit(config, session) or 0 | |
/usr/local/lib/python3.10/dist-packages/_pytest/main.py:269: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
> config.hook.pytest_runtestloop(session=session) | |
/usr/local/lib/python3.10/dist-packages/_pytest/main.py:323: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult) | |
/usr/local/lib/python3.10/dist-packages/pluggy/_hooks.py:265: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult) | |
/usr/local/lib/python3.10/dist-packages/pluggy/_manager.py:80: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
> res = hook_impl.function(*args) | |
/usr/local/lib/python3.10/dist-packages/pluggy/_callers.py:39: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
> item.config.hook.pytest_runtest_protocol(item=item, nextitem=nextitem) | |
/usr/local/lib/python3.10/dist-packages/_pytest/main.py:348: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult) | |
/usr/local/lib/python3.10/dist-packages/pluggy/_hooks.py:265: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult) | |
/usr/local/lib/python3.10/dist-packages/pluggy/_manager.py:80: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
> res = hook_impl.function(*args) | |
/usr/local/lib/python3.10/dist-packages/pluggy/_callers.py:39: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
> runtestprotocol(item, nextitem=nextitem) | |
/usr/local/lib/python3.10/dist-packages/_pytest/runner.py:114: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
> reports.append(call_and_report(item, "call", log)) | |
/usr/local/lib/python3.10/dist-packages/_pytest/runner.py:133: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
> call = call_runtest_hook(item, when, **kwds) | |
/usr/local/lib/python3.10/dist-packages/_pytest/runner.py:222: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
> return CallInfo.from_call( | |
lambda: ihook(item=item, **kwds), when=when, reraise=reraise | |
) | |
/usr/local/lib/python3.10/dist-packages/_pytest/runner.py:261: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
> result: Optional[TResult] = func() | |
/usr/local/lib/python3.10/dist-packages/_pytest/runner.py:341: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
> lambda: ihook(item=item, **kwds), when=when, reraise=reraise | |
) | |
/usr/local/lib/python3.10/dist-packages/_pytest/runner.py:262: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult) | |
/usr/local/lib/python3.10/dist-packages/pluggy/_hooks.py:265: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult) | |
/usr/local/lib/python3.10/dist-packages/pluggy/_manager.py:80: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
> res = hook_impl.function(*args) | |
/usr/local/lib/python3.10/dist-packages/pluggy/_callers.py:39: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
> item.runtest() | |
/usr/local/lib/python3.10/dist-packages/_pytest/runner.py:169: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
> self.ihook.pytest_pyfunc_call(pyfuncitem=self) | |
/usr/local/lib/python3.10/dist-packages/_pytest/python.py:1799: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult) | |
/usr/local/lib/python3.10/dist-packages/pluggy/_hooks.py:265: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult) | |
/usr/local/lib/python3.10/dist-packages/pluggy/_manager.py:80: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
> res = hook_impl.function(*args) | |
/usr/local/lib/python3.10/dist-packages/pluggy/_callers.py:39: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
> result = testfunction(**testargs) | |
/usr/local/lib/python3.10/dist-packages/_pytest/python.py:194: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
> result_triton = triton_multi_head_attention( | |
q, | |
k, | |
v, | |
q_pos, | |
q_sid, | |
kv_pos, | |
kv_sid, | |
) | |
tests/test_attention.py:92: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
> result = pl.pallas_call( | |
kernel, | |
grid=grid, | |
in_specs=[ | |
pl.BlockSpec( | |
lambda c, b, h: (b, 0, h, 0), (None, query_seq_len, None, head_dim) | |
), | |
pl.BlockSpec( | |
lambda c, b, h: (b, 0, h, 0), (None, kv_seq_len, None, head_dim) | |
), | |
pl.BlockSpec( | |
lambda c, b, h: (b, 0, h, 0), (None, kv_seq_len, None, head_dim) | |
), | |
pl.BlockSpec(lambda c, b, h: (b, 0), (None, query_seq_len)), | |
pl.BlockSpec(lambda c, b, h: (b, 0), (None, query_seq_len)), | |
pl.BlockSpec(lambda c, b, h: (b, 0), (None, kv_seq_len)), | |
pl.BlockSpec(lambda c, b, h: (b, 0), (None, kv_seq_len)), | |
], | |
out_specs=pl.BlockSpec( | |
lambda c, b, h: (b, 0, h, 0), (None, query_seq_len, None, head_dim) | |
), | |
num_warps=num_warps, | |
num_stages=num_stages, | |
out_shape=out_shape, | |
debug=debug, | |
interpret=interpret, | |
name="mha_forward", | |
)( | |
query, | |
key, | |
value, | |
query_positions, | |
query_segment_ids, | |
kv_positions, | |
kv_segment_ids, | |
) | |
papyrax/nn/triton_ops.py:218: | |
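Note on reading the specs in this frame (an illustration, not papyrax or Pallas source): each pl.BlockSpec pairs an index map over the grid coordinates (c, b, h) with a block shape; a None entry means that dimension is indexed directly and squeezed out, while an integer n means the kernel sees a slab of n elements starting at block_index * n. A small NumPy mock of that indexing:

import numpy as np

def block_slab(x, block_index, block_shape):
    # Mirrors the BlockSpec convention used above: None -> take one index and
    # squeeze that dimension; n -> take the half-open range [i*n, (i+1)*n).
    idx = tuple(
        i if n is None else slice(i * n, (i + 1) * n)
        for i, n in zip(block_index, block_shape)
    )
    return x[idx]

query = np.zeros((2, 384, 2, 64), dtype=np.float32)  # (batch, seq, heads, head_dim)
# For pl.BlockSpec(lambda c, b, h: (b, 0, h, 0), (None, 384, None, 64)) and
# program ids (c, b, h) = (0, 1, 1), the kernel ref covers query[1, 0:384, 1, 0:64]:
print(block_slab(query, (1, 0, 1, 0), (None, 384, None, 64)).shape)  # -> (384, 64)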
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
> out_flat = pallas_call_p.bind( | |
*consts, *flat_args, jaxpr=jaxpr, name=name, which_linear=which_linear, | |
in_shapes=tuple(jax.ShapeDtypeStruct(a.shape, a.dtype) | |
for a in flat_args), | |
out_shapes=tuple(flat_out_shapes), debug=debug, | |
interpret=interpret, | |
grid_mapping=grid_mapping, | |
input_output_aliases=tuple(input_output_aliases.items()), | |
**compiler_params) | |
E jax._src.source_info_util.JaxStackTraceBeforeTransformation: jax_triton.pallas.triton_lowering.TritonLoweringException: Exception while lowering eqn: | |
E a:i32[6] = scan[ | |
E jaxpr={ lambda ; b:Ref{int32[384]} c:i32[]. let | |
E d:i32[] = mul c 64 | |
E e:i32[64] = iota[dimension=0 dtype=int32 shape=(64,)] | |
E f:i32[64] = add e d | |
E g:bool[64] = lt f 384 | |
E h:i32[64] <- b[d:d+64] | |
E i:i32[] = reduce_min[axes=(0,)] h | |
E in (i,) } | |
E length=6 | |
E linear=(False, False) | |
E num_carry=0 | |
E num_consts=1 | |
E reverse=False | |
E unroll=1 | |
E ] j k | |
E With context: | |
E TritonLoweringRuleContext(context=TritonModuleContext(name='mha_forward', ir_context=<triton._C.libtriton.triton.ir.context object at 0x7fbc21d4bef0>, builder=<triton._C.libtriton.triton.ir.builder object at 0x7fbb28171e90>, module=<triton._C.libtriton.triton.ir.module object at 0x7fbb28172110>, grid_mapping=GridMapping(grid=(3, 2, 2), block_mappings=(BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384, <jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 64), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384, <jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 64), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384, <jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 64), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384, <jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 64), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) })), mapped_dims=(), num_index_operands=0), program_ids=[<triton.language.core.tensor object at 0x7fbb28199660>, <triton.language.core.tensor object at 0x7fbb28199990>, <triton.language.core.tensor object at 0x7fbb28199b10>]), avals_in=[Ref{int32[384]}, ShapedArray(int32[6])], avals_out=[ShapedArray(int32[6])], block_infos=[BlockInfo(full_shape_dtype=ShapeDtypeStruct(shape=(2, 384), dtype=int32), start_indices=(<triton.language.core.tensor object at 0x7fbb28199990>, <triton.language.core.tensor object at 0x7fbb28199db0>), block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384)), None]) | |
E With inval shapes=[[constexpr[1]], [constexpr[6]]] | |
E With inval types=[pointer<int32>, <[6], int32>] | |
E In jaxpr: | |
E { lambda ; a:Ref{bfloat16[384,64]} b:Ref{bfloat16[384,64]} c:Ref{bfloat16[384,64]} | |
E d:Ref{int32[384]} e:Ref{int32[384]} f:Ref{int32[384]} g:Ref{int32[384]} h:Ref{bfloat16[384,64]}. let | |
E i:i32[] = program_id[axis=0] | |
E j:f32[128,64] = broadcast_in_dim[broadcast_dimensions=() shape=(128, 64)] 0.0 | |
E k:f32[128] = broadcast_in_dim[broadcast_dimensions=() shape=(128,)] -1.0000000200408773e+20 | |
E l:f32[128] = broadcast_in_dim[broadcast_dimensions=() shape=(128,)] 0.0 | |
E m:i32[] = mul i 128 | |
E n:i32[128] = iota[dimension=0 dtype=int32 shape=(128,)] | |
E o:i32[128] = add n m | |
E p:bool[128] = lt o 384 | |
E q:bf16[128,64] <- a[m:m+128,:] | |
E r:i32[128] <- d[m:m+128] | |
E s:i32[128] <- e[m:m+128] | |
E t:bf16[128,64] = mul q 0.125 | |
E u:i32[6] = iota[dimension=0 dtype=int32 shape=(6,)] | |
E v:i32[6] = scan[ | |
E jaxpr={ lambda ; w:Ref{int32[384]} x:i32[]. let | |
E y:i32[] = mul x 64 | |
E z:i32[64] = iota[dimension=0 dtype=int32 shape=(64,)] | |
E ba:i32[64] = add z y | |
E bb:bool[64] = lt ba 384 | |
E bc:i32[64] <- w[y:y+64] | |
E bd:i32[] = reduce_min[axes=(0,)] bc | |
E in (bd,) } | |
E length=6 | |
E linear=(False, False) | |
E num_carry=0 | |
E num_consts=1 | |
E reverse=False | |
E unroll=1 | |
E ] f u | |
E be:i32[] = reduce_max[axes=(0,)] r | |
E bf:i32[6] = iota[dimension=0 dtype=int32 shape=(6,)] | |
E bg:bool[6] = ge be v | |
E bh:i32[6] = convert_element_type[new_dtype=int32 weak_type=False] bg | |
E bi:i32[6] = mul bf bh | |
E bj:i32[] = argmax[axes=(0,) index_dtype=int32] bi | |
E bk:i32[] = add bj 1 | |
E _:i32[] _:i32[] bl:f32[128,64] _:f32[128] _:f32[128] = while[ | |
E body_jaxpr={ lambda ; bm:Ref{bfloat16[384,64]} bn:Ref{int32[384]} bo:Ref{int32[384]} | |
E bp:bf16[128,64] bq:i32[128] br:i32[128] bs:Ref{bfloat16[384,64]} bt:i32[] | |
E bu:i32[] bv:f32[128,64] bw:f32[128] bx:f32[128]. let | |
E by:i32[] = add bt 1 | |
E bz:i32[] = mul bt 64 | |
E ca:i32[64] = iota[dimension=0 dtype=int32 shape=(64,)] | |
E cb:i32[] = convert_element_type[new_dtype=int32 weak_type=False] bz | |
E cc:i32[64] = add ca cb | |
E cd:bool[64] = lt cc 384 | |
E ce:bf16[64,64] <- bm[bz:bz+64,:] | |
E cf:i32[64] <- bn[bz:bz+64] | |
E cg:i32[64] <- bo[bz:bz+64] | |
E ch:bf16[64,64] = transpose[permutation=(1, 0)] ce | |
E ci:bf16[128,64] = dot_general[ | |
E dimension_numbers=(([1], [0]), ([], [])) | |
E ] bp ch | |
E cj:f32[128,64] = convert_element_type[ | |
E new_dtype=float32 | |
E weak_type=False | |
E ] ci | |
E ck:i32[128,1] = broadcast_in_dim[ | |
E broadcast_dimensions=(0,) | |
E shape=(128, 1) | |
E ] bq | |
E cl:i32[1,64] = broadcast_in_dim[ | |
E broadcast_dimensions=(1,) | |
E shape=(1, 64) | |
E ] cg | |
E cm:bool[128,64] = eq ck cl | |
E cn:i32[128,1] = broadcast_in_dim[ | |
E broadcast_dimensions=(0,) | |
E shape=(128, 1) | |
E ] br | |
E co:i32[1,64] = broadcast_in_dim[ | |
E broadcast_dimensions=(1,) | |
E shape=(1, 64) | |
E ] cf | |
E cp:bool[128,64] = ge cn co | |
E cq:bool[128,64] = and cm cp | |
E cr:f32[128,64] = pjit[ | |
E jaxpr={ lambda ; cs:bool[128,64] ct:f32[128,64] cu:f32[]. let | |
E cv:f32[] = convert_element_type[ | |
E new_dtype=float32 | |
E weak_type=False | |
E ] cu | |
E cw:f32[128,64] = broadcast_in_dim[ | |
E broadcast_dimensions=() | |
E shape=(128, 64) | |
E ] cv | |
E cx:f32[128,64] = select_n cs cw ct | |
E in (cx,) } | |
E name=_where | |
E ] cq cj -1e+20 | |
E cy:f32[128] = reduce_max[axes=(1,)] cr | |
E cz:f32[128] = max cy bw | |
E da:f32[128] = sub bw cz | |
E db:f32[128] = exp da | |
E dc:f32[128] = mul bx db | |
E dd:f32[128,1] = broadcast_in_dim[ | |
E broadcast_dimensions=(0,) | |
E shape=(128, 1) | |
E ] cz | |
E de:f32[128,64] = sub cr dd | |
E df:f32[128,64] = exp de | |
E dg:f32[128] = reduce_sum[axes=(1,)] df | |
E dh:f32[128] = add dg dc | |
E di:f32[128] = div 1.0 dh | |
E dj:f32[128,1] = broadcast_in_dim[ | |
E broadcast_dimensions=(0,) | |
E shape=(128, 1) | |
E ] di | |
E dk:f32[128,64] = mul df dj | |
E dl:f32[128] = mul dc di | |
E dm:f32[128,1] = broadcast_in_dim[ | |
E broadcast_dimensions=(0,) | |
E shape=(128, 1) | |
E ] dl | |
E dn:f32[128,64] = mul bv dm | |
E do:bf16[64,64] <- bs[bz:bz+64,:] | |
E dp:bf16[128,64] = convert_element_type[ | |
E new_dtype=bfloat16 | |
E weak_type=False | |
E ] dk | |
E dq:bf16[128,64] = dot_general[ | |
E dimension_numbers=(([1], [0]), ([], [])) | |
E ] dp do | |
E dr:f32[128,64] = convert_element_type[ | |
E new_dtype=float32 | |
E weak_type=False | |
E ] dq | |
E ds:f32[128,64] = add dn dr | |
E in (by, bu, ds, cz, dh) } | |
E body_nconsts=7 | |
E cond_jaxpr={ lambda ; dt:i32[] du:i32[] dv:f32[128,64] dw:f32[128] dx:f32[128]. let | |
E dy:bool[] = lt dt du | |
E in (dy,) } | |
E cond_nconsts=0 | |
E ] b f g t s r c 0 bk j k l | |
E dz:bf16[128,64] = convert_element_type[new_dtype=bfloat16 weak_type=False] bl | |
E ea:bool[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] p | |
E h[m:m+128,:] <- dz | |
E in () } | |
E | |
E The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception. | |
E | |
E -------------------- | |
/usr/local/lib/python3.10/dist-packages/jax_triton/pallas/pallas_call.py:352: JaxStackTraceBeforeTransformation | |
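Note: the failure happens entirely at lowering time. One way to separate kernel correctness from Triton lowering support is Pallas's interpret mode, which evaluates the kernel jaxpr with JAX instead of lowering it; this is a suggestion based on the interpret flag already threaded through the call at triton_ops.py:218, not something this run does. In the sketch below, kernel, grid, in_specs, out_specs and out_shape are hypothetical names for the same objects papyrax builds in that frame; only interpret=True differs.

# Sketch under the assumptions above: run the same kernel through the Pallas
# interpreter, where the scan with extensive operands is not an issue, and
# check the numerics in isolation from the Triton lowering.
result_interpreted = pl.pallas_call(
    kernel,
    grid=grid,
    in_specs=in_specs,    # hypothetical name for the list of pl.BlockSpec shown above
    out_specs=out_specs,  # hypothetical name for the output pl.BlockSpec shown above
    out_shape=out_shape,
    interpret=True,
    name="mha_forward",
)(query, key, value, query_positions, query_segment_ids, kv_positions, kv_segment_ids)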
The above exception was the direct cause of the following exception: | |
"""The pytest entry point.""" | |
import pytest | |
if __name__ == "__main__": | |
> raise SystemExit(pytest.console_main()) | |
/usr/local/lib/python3.10/dist-packages/pytest/__main__.py:5: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
def console_main() -> int: | |
"""The CLI entry point of pytest. | |
This function is not meant for programmable use; use `main()` instead. | |
""" | |
# https://docs.python.org/3/library/signal.html#note-on-sigpipe | |
try: | |
> code = main() | |
/usr/local/lib/python3.10/dist-packages/_pytest/config/__init__.py:189: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
args = None, plugins = None | |
def main( | |
args: Optional[Union[List[str], "os.PathLike[str]"]] = None, | |
plugins: Optional[Sequence[Union[str, _PluggyPlugin]]] = None, | |
) -> Union[int, ExitCode]: | |
"""Perform an in-process test run. | |
:param args: List of command line arguments. | |
:param plugins: List of plugin objects to be auto-registered during initialization. | |
:returns: An exit code. | |
""" | |
try: | |
try: | |
config = _prepareconfig(args, plugins) | |
except ConftestImportFailure as e: | |
exc_info = ExceptionInfo.from_exc_info(e.excinfo) | |
tw = TerminalWriter(sys.stderr) | |
tw.line(f"ImportError while loading conftest '{e.path}'.", red=True) | |
exc_info.traceback = exc_info.traceback.filter( | |
filter_traceback_for_conftest_import_failure | |
) | |
exc_repr = ( | |
exc_info.getrepr(style="short", chain=False) | |
if exc_info.traceback | |
else exc_info.exconly() | |
) | |
formatted_tb = str(exc_repr) | |
for line in formatted_tb.splitlines(): | |
tw.line(line.rstrip(), red=True) | |
return ExitCode.USAGE_ERROR | |
else: | |
try: | |
> ret: Union[ExitCode, int] = config.hook.pytest_cmdline_main( | |
config=config | |
) | |
/usr/local/lib/python3.10/dist-packages/_pytest/config/__init__.py:166: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
self = <_HookCaller 'pytest_cmdline_main'>, args = () | |
kwargs = {'config': <_pytest.config.Config object at 0x7fbcddaa94e0>} | |
argname = 'config', firstresult = True | |
def __call__(self, *args, **kwargs): | |
if args: | |
raise TypeError("hook calling supports only keyword arguments") | |
assert not self.is_historic() | |
# This is written to avoid expensive operations when not needed. | |
if self.spec: | |
for argname in self.spec.argnames: | |
if argname not in kwargs: | |
notincall = tuple(set(self.spec.argnames) - kwargs.keys()) | |
warnings.warn( | |
"Argument(s) {} which are declared in the hookspec " | |
"can not be found in this hook call".format(notincall), | |
stacklevel=2, | |
) | |
break | |
firstresult = self.spec.opts.get("firstresult") | |
else: | |
firstresult = False | |
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult) | |
/usr/local/lib/python3.10/dist-packages/pluggy/_hooks.py:265: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
self = <_pytest.config.PytestPluginManager object at 0x7fbcde41aa10> | |
hook_name = 'pytest_cmdline_main' | |
methods = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/usr/local/lib/python3.10/dist-packages/_pytest/mai...uponly', plugin=<module '_pytest.setuponly' from '/usr/local/lib/python3.10/dist-packages/_pytest/setuponly.py'>>, ...] | |
kwargs = {'config': <_pytest.config.Config object at 0x7fbcddaa94e0>} | |
firstresult = True | |
def _hookexec(self, hook_name, methods, kwargs, firstresult): | |
# called from all hookcaller instances. | |
# enable_tracing will set its own wrapping function at self._inner_hookexec | |
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult) | |
/usr/local/lib/python3.10/dist-packages/pluggy/_manager.py:80: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
hook_name = 'pytest_cmdline_main' | |
hook_impls = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/usr/local/lib/python3.10/dist-packages/_pytest/mai...uponly', plugin=<module '_pytest.setuponly' from '/usr/local/lib/python3.10/dist-packages/_pytest/setuponly.py'>>, ...] | |
caller_kwargs = {'config': <_pytest.config.Config object at 0x7fbcddaa94e0>} | |
firstresult = True | |
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult): | |
"""Execute a call into multiple python functions/methods and return the | |
result(s). | |
``caller_kwargs`` comes from _HookCaller.__call__(). | |
""" | |
__tracebackhide__ = True | |
results = [] | |
excinfo = None | |
try: # run impl and wrapper setup functions in a loop | |
teardowns = [] | |
try: | |
for hook_impl in reversed(hook_impls): | |
try: | |
args = [caller_kwargs[argname] for argname in hook_impl.argnames] | |
except KeyError: | |
for argname in hook_impl.argnames: | |
if argname not in caller_kwargs: | |
raise HookCallError( | |
f"hook call must provide argument {argname!r}" | |
) | |
if hook_impl.hookwrapper: | |
try: | |
gen = hook_impl.function(*args) | |
next(gen) # first yield | |
teardowns.append(gen) | |
except StopIteration: | |
_raise_wrapfail(gen, "did not yield") | |
else: | |
> res = hook_impl.function(*args) | |
/usr/local/lib/python3.10/dist-packages/pluggy/_callers.py:39: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
config = <_pytest.config.Config object at 0x7fbcddaa94e0> | |
def pytest_cmdline_main(config: Config) -> Union[int, ExitCode]: | |
> return wrap_session(config, _main) | |
/usr/local/lib/python3.10/dist-packages/_pytest/main.py:316: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
config = <_pytest.config.Config object at 0x7fbcddaa94e0> | |
doit = <function _main at 0x7fbcddd18280> | |
def wrap_session( | |
config: Config, doit: Callable[[Config, "Session"], Optional[Union[int, ExitCode]]] | |
) -> Union[int, ExitCode]: | |
"""Skeleton command line program.""" | |
session = Session.from_config(config) | |
session.exitstatus = ExitCode.OK | |
initstate = 0 | |
try: | |
try: | |
config._do_configure() | |
initstate = 1 | |
config.hook.pytest_sessionstart(session=session) | |
initstate = 2 | |
> session.exitstatus = doit(config, session) or 0 | |
/usr/local/lib/python3.10/dist-packages/_pytest/main.py:269: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
config = <_pytest.config.Config object at 0x7fbcddaa94e0> | |
session = <Session tests exitstatus=<ExitCode.OK: 0> testsfailed=0 testscollected=1> | |
def _main(config: Config, session: "Session") -> Optional[Union[int, ExitCode]]: | |
"""Default command line protocol for initialization, session, | |
running tests and reporting.""" | |
config.hook.pytest_collection(session=session) | |
> config.hook.pytest_runtestloop(session=session) | |
/usr/local/lib/python3.10/dist-packages/_pytest/main.py:323: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
self = <_HookCaller 'pytest_runtestloop'>, args = () | |
kwargs = {'session': <Session tests exitstatus=<ExitCode.OK: 0> testsfailed=0 testscollected=1>} | |
argname = 'session', firstresult = True | |
def __call__(self, *args, **kwargs): | |
if args: | |
raise TypeError("hook calling supports only keyword arguments") | |
assert not self.is_historic() | |
# This is written to avoid expensive operations when not needed. | |
if self.spec: | |
for argname in self.spec.argnames: | |
if argname not in kwargs: | |
notincall = tuple(set(self.spec.argnames) - kwargs.keys()) | |
warnings.warn( | |
"Argument(s) {} which are declared in the hookspec " | |
"can not be found in this hook call".format(notincall), | |
stacklevel=2, | |
) | |
break | |
firstresult = self.spec.opts.get("firstresult") | |
else: | |
firstresult = False | |
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult) | |
/usr/local/lib/python3.10/dist-packages/pluggy/_hooks.py:265: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
self = <_pytest.config.PytestPluginManager object at 0x7fbcde41aa10> | |
hook_name = 'pytest_runtestloop' | |
methods = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/usr/local/lib/python3.10/dist-packages/_pytest/main.py'>>, <HookImpl plugin_name='logging-plugin', plugin=<_pytest.logging.LoggingPlugin object at 0x7fbcdd963250>>] | |
kwargs = {'session': <Session tests exitstatus=<ExitCode.OK: 0> testsfailed=0 testscollected=1>} | |
firstresult = True | |
def _hookexec(self, hook_name, methods, kwargs, firstresult): | |
# called from all hookcaller instances. | |
# enable_tracing will set its own wrapping function at self._inner_hookexec | |
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult) | |
/usr/local/lib/python3.10/dist-packages/pluggy/_manager.py:80: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
hook_name = 'pytest_runtestloop' | |
hook_impls = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/usr/local/lib/python3.10/dist-packages/_pytest/main.py'>>, <HookImpl plugin_name='logging-plugin', plugin=<_pytest.logging.LoggingPlugin object at 0x7fbcdd963250>>] | |
caller_kwargs = {'session': <Session tests exitstatus=<ExitCode.OK: 0> testsfailed=0 testscollected=1>} | |
firstresult = True | |
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult): | |
"""Execute a call into multiple python functions/methods and return the | |
result(s). | |
``caller_kwargs`` comes from _HookCaller.__call__(). | |
""" | |
__tracebackhide__ = True | |
results = [] | |
excinfo = None | |
try: # run impl and wrapper setup functions in a loop | |
teardowns = [] | |
try: | |
for hook_impl in reversed(hook_impls): | |
try: | |
args = [caller_kwargs[argname] for argname in hook_impl.argnames] | |
except KeyError: | |
for argname in hook_impl.argnames: | |
if argname not in caller_kwargs: | |
raise HookCallError( | |
f"hook call must provide argument {argname!r}" | |
) | |
if hook_impl.hookwrapper: | |
try: | |
gen = hook_impl.function(*args) | |
next(gen) # first yield | |
teardowns.append(gen) | |
except StopIteration: | |
_raise_wrapfail(gen, "did not yield") | |
else: | |
> res = hook_impl.function(*args) | |
/usr/local/lib/python3.10/dist-packages/pluggy/_callers.py:39: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
session = <Session tests exitstatus=<ExitCode.OK: 0> testsfailed=0 testscollected=1> | |
def pytest_runtestloop(session: "Session") -> bool: | |
if session.testsfailed and not session.config.option.continue_on_collection_errors: | |
raise session.Interrupted( | |
"%d error%s during collection" | |
% (session.testsfailed, "s" if session.testsfailed != 1 else "") | |
) | |
if session.config.option.collectonly: | |
return True | |
for i, item in enumerate(session.items): | |
nextitem = session.items[i + 1] if i + 1 < len(session.items) else None | |
> item.config.hook.pytest_runtest_protocol(item=item, nextitem=nextitem) | |
/usr/local/lib/python3.10/dist-packages/_pytest/main.py:348: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
self = <_HookCaller 'pytest_runtest_protocol'>, args = () | |
kwargs = {'item': <Function test_triton_flash_attention[384-384]>, 'nextitem': None} | |
argname = 'nextitem', firstresult = True | |
def __call__(self, *args, **kwargs): | |
if args: | |
raise TypeError("hook calling supports only keyword arguments") | |
assert not self.is_historic() | |
# This is written to avoid expensive operations when not needed. | |
if self.spec: | |
for argname in self.spec.argnames: | |
if argname not in kwargs: | |
notincall = tuple(set(self.spec.argnames) - kwargs.keys()) | |
warnings.warn( | |
"Argument(s) {} which are declared in the hookspec " | |
"can not be found in this hook call".format(notincall), | |
stacklevel=2, | |
) | |
break | |
firstresult = self.spec.opts.get("firstresult") | |
else: | |
firstresult = False | |
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult) | |
/usr/local/lib/python3.10/dist-packages/pluggy/_hooks.py:265: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
self = <_pytest.config.PytestPluginManager object at 0x7fbcde41aa10> | |
hook_name = 'pytest_runtest_protocol' | |
methods = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/usr/local/lib/python3.10/dist-packages/_pytest...ame='warnings', plugin=<module '_pytest.warnings' from '/usr/local/lib/python3.10/dist-packages/_pytest/warnings.py'>>] | |
kwargs = {'item': <Function test_triton_flash_attention[384-384]>, 'nextitem': None} | |
firstresult = True | |
def _hookexec(self, hook_name, methods, kwargs, firstresult): | |
# called from all hookcaller instances. | |
# enable_tracing will set its own wrapping function at self._inner_hookexec | |
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult) | |
/usr/local/lib/python3.10/dist-packages/pluggy/_manager.py:80: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
hook_name = 'pytest_runtest_protocol' | |
hook_impls = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/usr/local/lib/python3.10/dist-packages/_pytest...ame='warnings', plugin=<module '_pytest.warnings' from '/usr/local/lib/python3.10/dist-packages/_pytest/warnings.py'>>] | |
caller_kwargs = {'item': <Function test_triton_flash_attention[384-384]>, 'nextitem': None} | |
firstresult = True | |
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult): | |
"""Execute a call into multiple python functions/methods and return the | |
result(s). | |
``caller_kwargs`` comes from _HookCaller.__call__(). | |
""" | |
__tracebackhide__ = True | |
results = [] | |
excinfo = None | |
try: # run impl and wrapper setup functions in a loop | |
teardowns = [] | |
try: | |
for hook_impl in reversed(hook_impls): | |
try: | |
args = [caller_kwargs[argname] for argname in hook_impl.argnames] | |
except KeyError: | |
for argname in hook_impl.argnames: | |
if argname not in caller_kwargs: | |
raise HookCallError( | |
f"hook call must provide argument {argname!r}" | |
) | |
if hook_impl.hookwrapper: | |
try: | |
gen = hook_impl.function(*args) | |
next(gen) # first yield | |
teardowns.append(gen) | |
except StopIteration: | |
_raise_wrapfail(gen, "did not yield") | |
else: | |
> res = hook_impl.function(*args) | |
/usr/local/lib/python3.10/dist-packages/pluggy/_callers.py:39: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
item = <Function test_triton_flash_attention[384-384]>, nextitem = None | |
def pytest_runtest_protocol(item: Item, nextitem: Optional[Item]) -> bool: | |
ihook = item.ihook | |
ihook.pytest_runtest_logstart(nodeid=item.nodeid, location=item.location) | |
> runtestprotocol(item, nextitem=nextitem) | |
/usr/local/lib/python3.10/dist-packages/_pytest/runner.py:114: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
item = <Function test_triton_flash_attention[384-384]>, log = True | |
nextitem = None | |
def runtestprotocol( | |
item: Item, log: bool = True, nextitem: Optional[Item] = None | |
) -> List[TestReport]: | |
hasrequest = hasattr(item, "_request") | |
if hasrequest and not item._request: # type: ignore[attr-defined] | |
# This only happens if the item is re-run, as is done by | |
# pytest-rerunfailures. | |
item._initrequest() # type: ignore[attr-defined] | |
rep = call_and_report(item, "setup", log) | |
reports = [rep] | |
if rep.passed: | |
if item.config.getoption("setupshow", False): | |
show_test_item(item) | |
if not item.config.getoption("setuponly", False): | |
> reports.append(call_and_report(item, "call", log)) | |
/usr/local/lib/python3.10/dist-packages/_pytest/runner.py:133: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
item = <Function test_triton_flash_attention[384-384]>, when = 'call' | |
log = True, kwds = {} | |
call = <CallInfo when='call' excinfo=<ExceptionInfo TritonLoweringException("Exception while lowering eqn:\n a:i32[6] = scan...ol[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] p\n h[m:m+128,:] <- dz\n in () }") tblen=5>> | |
hook = <_pytest.config.compat.PathAwareHookProxy object at 0x7fbcddaa9660> | |
def call_and_report( | |
item: Item, when: "Literal['setup', 'call', 'teardown']", log: bool = True, **kwds | |
) -> TestReport: | |
> call = call_runtest_hook(item, when, **kwds) | |
/usr/local/lib/python3.10/dist-packages/_pytest/runner.py:222: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
item = <Function test_triton_flash_attention[384-384]>, when = 'call', kwds = {} | |
reraise = (<class '_pytest.outcomes.Exit'>, <class 'KeyboardInterrupt'>) | |
def call_runtest_hook( | |
item: Item, when: "Literal['setup', 'call', 'teardown']", **kwds | |
) -> "CallInfo[None]": | |
if when == "setup": | |
ihook: Callable[..., None] = item.ihook.pytest_runtest_setup | |
elif when == "call": | |
ihook = item.ihook.pytest_runtest_call | |
elif when == "teardown": | |
ihook = item.ihook.pytest_runtest_teardown | |
else: | |
assert False, f"Unhandled runtest hook case: {when}" | |
reraise: Tuple[Type[BaseException], ...] = (Exit,) | |
if not item.config.getoption("usepdb", False): | |
reraise += (KeyboardInterrupt,) | |
> return CallInfo.from_call( | |
lambda: ihook(item=item, **kwds), when=when, reraise=reraise | |
) | |
/usr/local/lib/python3.10/dist-packages/_pytest/runner.py:261: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
cls = <class '_pytest.runner.CallInfo'> | |
func = <function call_runtest_hook.<locals>.<lambda> at 0x7fbcdd9d0310> | |
when = 'call' | |
reraise = (<class '_pytest.outcomes.Exit'>, <class 'KeyboardInterrupt'>) | |
@classmethod | |
def from_call( | |
cls, | |
func: "Callable[[], TResult]", | |
when: "Literal['collect', 'setup', 'call', 'teardown']", | |
reraise: Optional[ | |
Union[Type[BaseException], Tuple[Type[BaseException], ...]] | |
] = None, | |
) -> "CallInfo[TResult]": | |
"""Call func, wrapping the result in a CallInfo. | |
:param func: | |
The function to call. Called without arguments. | |
:param when: | |
The phase in which the function is called. | |
:param reraise: | |
Exception or exceptions that shall propagate if raised by the | |
function, instead of being wrapped in the CallInfo. | |
""" | |
excinfo = None | |
start = timing.time() | |
precise_start = timing.perf_counter() | |
try: | |
> result: Optional[TResult] = func() | |
/usr/local/lib/python3.10/dist-packages/_pytest/runner.py:341: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
> lambda: ihook(item=item, **kwds), when=when, reraise=reraise | |
) | |
/usr/local/lib/python3.10/dist-packages/_pytest/runner.py:262: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
self = <_HookCaller 'pytest_runtest_call'>, args = () | |
kwargs = {'item': <Function test_triton_flash_attention[384-384]>} | |
argname = 'item', firstresult = False | |
def __call__(self, *args, **kwargs): | |
if args: | |
raise TypeError("hook calling supports only keyword arguments") | |
assert not self.is_historic() | |
# This is written to avoid expensive operations when not needed. | |
if self.spec: | |
for argname in self.spec.argnames: | |
if argname not in kwargs: | |
notincall = tuple(set(self.spec.argnames) - kwargs.keys()) | |
warnings.warn( | |
"Argument(s) {} which are declared in the hookspec " | |
"can not be found in this hook call".format(notincall), | |
stacklevel=2, | |
) | |
break | |
firstresult = self.spec.opts.get("firstresult") | |
else: | |
firstresult = False | |
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult) | |
/usr/local/lib/python3.10/dist-packages/pluggy/_hooks.py:265: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
self = <_pytest.config.PytestPluginManager object at 0x7fbcde41aa10> | |
hook_name = 'pytest_runtest_call' | |
methods = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/usr/local/lib/python3.10/dist-packages/_pytest..., plugin=<module '_pytest.threadexception' from '/usr/local/lib/python3.10/dist-packages/_pytest/threadexception.py'>>] | |
kwargs = {'item': <Function test_triton_flash_attention[384-384]>} | |
firstresult = False | |
def _hookexec(self, hook_name, methods, kwargs, firstresult): | |
# called from all hookcaller instances. | |
# enable_tracing will set its own wrapping function at self._inner_hookexec | |
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult) | |
/usr/local/lib/python3.10/dist-packages/pluggy/_manager.py:80: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
hook_name = 'pytest_runtest_call' | |
hook_impls = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/usr/local/lib/python3.10/dist-packages/_pytest..., plugin=<module '_pytest.threadexception' from '/usr/local/lib/python3.10/dist-packages/_pytest/threadexception.py'>>] | |
caller_kwargs = {'item': <Function test_triton_flash_attention[384-384]>} | |
firstresult = False | |
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult): | |
"""Execute a call into multiple python functions/methods and return the | |
result(s). | |
``caller_kwargs`` comes from _HookCaller.__call__(). | |
""" | |
__tracebackhide__ = True | |
results = [] | |
excinfo = None | |
try: # run impl and wrapper setup functions in a loop | |
teardowns = [] | |
try: | |
for hook_impl in reversed(hook_impls): | |
try: | |
args = [caller_kwargs[argname] for argname in hook_impl.argnames] | |
except KeyError: | |
for argname in hook_impl.argnames: | |
if argname not in caller_kwargs: | |
raise HookCallError( | |
f"hook call must provide argument {argname!r}" | |
) | |
if hook_impl.hookwrapper: | |
try: | |
gen = hook_impl.function(*args) | |
next(gen) # first yield | |
teardowns.append(gen) | |
except StopIteration: | |
_raise_wrapfail(gen, "did not yield") | |
else: | |
> res = hook_impl.function(*args) | |
/usr/local/lib/python3.10/dist-packages/pluggy/_callers.py:39: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
item = <Function test_triton_flash_attention[384-384]> | |
def pytest_runtest_call(item: Item) -> None: | |
_update_current_test_var(item, "call") | |
try: | |
del sys.last_type | |
del sys.last_value | |
del sys.last_traceback | |
except AttributeError: | |
pass | |
try: | |
> item.runtest() | |
/usr/local/lib/python3.10/dist-packages/_pytest/runner.py:169: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
self = <Function test_triton_flash_attention[384-384]> | |
def runtest(self) -> None: | |
"""Execute the underlying test function.""" | |
> self.ihook.pytest_pyfunc_call(pyfuncitem=self) | |
/usr/local/lib/python3.10/dist-packages/_pytest/python.py:1799: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
self = <_HookCaller 'pytest_pyfunc_call'>, args = () | |
kwargs = {'pyfuncitem': <Function test_triton_flash_attention[384-384]>} | |
argname = 'pyfuncitem', firstresult = True | |
def __call__(self, *args, **kwargs): | |
if args: | |
raise TypeError("hook calling supports only keyword arguments") | |
assert not self.is_historic() | |
# This is written to avoid expensive operations when not needed. | |
if self.spec: | |
for argname in self.spec.argnames: | |
if argname not in kwargs: | |
notincall = tuple(set(self.spec.argnames) - kwargs.keys()) | |
warnings.warn( | |
"Argument(s) {} which are declared in the hookspec " | |
"can not be found in this hook call".format(notincall), | |
stacklevel=2, | |
) | |
break | |
firstresult = self.spec.opts.get("firstresult") | |
else: | |
firstresult = False | |
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult) | |
/usr/local/lib/python3.10/dist-packages/pluggy/_hooks.py:265: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
self = <_pytest.config.PytestPluginManager object at 0x7fbcde41aa10> | |
hook_name = 'pytest_pyfunc_call' | |
methods = [<HookImpl plugin_name='python', plugin=<module '_pytest.python' from '/usr/local/lib/python3.10/dist-packages/_pytest/python.py'>>] | |
kwargs = {'pyfuncitem': <Function test_triton_flash_attention[384-384]>} | |
firstresult = True | |
def _hookexec(self, hook_name, methods, kwargs, firstresult): | |
# called from all hookcaller instances. | |
# enable_tracing will set its own wrapping function at self._inner_hookexec | |
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult) | |
/usr/local/lib/python3.10/dist-packages/pluggy/_manager.py:80: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
hook_name = 'pytest_pyfunc_call' | |
hook_impls = [<HookImpl plugin_name='python', plugin=<module '_pytest.python' from '/usr/local/lib/python3.10/dist-packages/_pytest/python.py'>>] | |
caller_kwargs = {'pyfuncitem': <Function test_triton_flash_attention[384-384]>} | |
firstresult = True | |
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult): | |
"""Execute a call into multiple python functions/methods and return the | |
result(s). | |
``caller_kwargs`` comes from _HookCaller.__call__(). | |
""" | |
__tracebackhide__ = True | |
results = [] | |
excinfo = None | |
try: # run impl and wrapper setup functions in a loop | |
teardowns = [] | |
try: | |
for hook_impl in reversed(hook_impls): | |
try: | |
args = [caller_kwargs[argname] for argname in hook_impl.argnames] | |
except KeyError: | |
for argname in hook_impl.argnames: | |
if argname not in caller_kwargs: | |
raise HookCallError( | |
f"hook call must provide argument {argname!r}" | |
) | |
if hook_impl.hookwrapper: | |
try: | |
gen = hook_impl.function(*args) | |
next(gen) # first yield | |
teardowns.append(gen) | |
except StopIteration: | |
_raise_wrapfail(gen, "did not yield") | |
else: | |
> res = hook_impl.function(*args) | |
/usr/local/lib/python3.10/dist-packages/pluggy/_callers.py:39: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
pyfuncitem = <Function test_triton_flash_attention[384-384]> | |
@hookimpl(trylast=True) | |
def pytest_pyfunc_call(pyfuncitem: "Function") -> Optional[object]: | |
testfunction = pyfuncitem.obj | |
if is_async_function(testfunction): | |
async_warn_and_skip(pyfuncitem.nodeid) | |
funcargs = pyfuncitem.funcargs | |
testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames} | |
> result = testfunction(**testargs) | |
/usr/local/lib/python3.10/dist-packages/_pytest/python.py:194: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
query_seq_len = 384, kv_seq_len = 384 | |
@pytest.mark.gpu | |
@pytest.mark.parametrize( | |
"query_seq_len,kv_seq_len", | |
[ | |
(384, 384), | |
(400, 400), | |
(384, 256), | |
(384, 300), | |
(400, 256), | |
(400, 300), | |
(1, 256), | |
(1, 300), | |
(256, 1), | |
(300, 1), | |
(1, 3), | |
], | |
) | |
def test_triton_flash_attention(query_seq_len, kv_seq_len): | |
q, k, v, q_pos, q_sid, kv_pos, kv_sid = _make_attention_inputs( | |
query_seq_len=query_seq_len, | |
kv_seq_len=kv_seq_len, | |
) | |
> result_triton = triton_multi_head_attention( | |
q, | |
k, | |
v, | |
q_pos, | |
q_sid, | |
kv_pos, | |
kv_sid, | |
) | |
tests/test_attention.py:92: | |
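Note: a hypothetical reference implementation (not papyrax code, and not necessarily what the test compares against) that mirrors the masking visible in the kernel jaxpr above: attention is restricted to keys in the same segment (q_sid == kv_sid) at positions not after the query (q_pos >= kv_pos), with the 1/sqrt(head_dim) = 0.125 scale applied to the logits.

import jax
import jax.numpy as jnp

def reference_mha(q, k, v, q_pos, q_sid, kv_pos, kv_sid):
    # q, k, v: [batch, seq, heads, head_dim]; positions / segment ids: [batch, seq]
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k).astype(jnp.float32)
    logits = logits * (q.shape[-1] ** -0.5)
    same_segment = q_sid[:, None, :, None] == kv_sid[:, None, None, :]
    not_after = q_pos[:, None, :, None] >= kv_pos[:, None, None, :]
    logits = jnp.where(same_segment & not_after, logits, -1e20)
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v.astype(jnp.float32)).astype(q.dtype)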
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
args = (Array([[[[0.162109, 0.632812, -0.333984, ..., -0.667969, -0.192383, | |
1.375], | |
[1.42969, 0.953125, -0... 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, | |
377, 378, 379, 380, 381, 382, 383]], dtype=int32), ...) | |
kwargs = {}, __tracebackhide__ = True | |
msg = 'jax_triton.pallas.triton_lowering.TritonLoweringException: Exception while lowering eqn:\n a:i32[6] = scan[\n jaxpr...ludes JAX-internal frames.\nThe preceding is the original exception that occurred, unmodified.\n\n--------------------' | |
@functools.wraps(fun) | |
def reraise_with_filtered_traceback(*args, **kwargs): | |
__tracebackhide__ = True | |
try: | |
> return fun(*args, **kwargs) | |
/usr/local/lib/python3.10/dist-packages/jax/_src/traceback_util.py:166: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
args = (Array([[[[0.162109, 0.632812, -0.333984, ..., -0.667969, -0.192383, | |
1.375], | |
[1.42969, 0.953125, -0... 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, | |
377, 378, 379, 380, 381, 382, 383]], dtype=int32), ...) | |
kwargs = {} | |
@api_boundary | |
def cache_miss(*args, **kwargs): | |
> outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper( | |
fun, infer_params_fn, *args, **kwargs) | |
/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py:250: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
fun = <function multi_head_attention at 0x7fbc2a3d9120> | |
infer_params_fn = <function jit.<locals>.infer_params at 0x7fbc2a416950> | |
args = (Array([[[[0.162109, 0.632812, -0.333984, ..., -0.667969, -0.192383, | |
1.375], | |
[1.42969, 0.953125, -0... 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, | |
377, 378, 379, 380, 381, 382, 383]], dtype=int32), ...) | |
kwargs = {} | |
args_flat = [Array([[[[0.162109, 0.632812, -0.333984, ..., -0.667969, -0.192383, | |
1.375], | |
[1.42969, 0.953125, -0... 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, | |
377, 378, 379, 380, 381, 382, 383]], dtype=int32), ...] | |
_ = () | |
params = {'donated_invars': (False, False, False, False, False, False, ...), 'in_shardings': (UnspecifiedValue, UnspecifiedValu...lse, False, False) | |
] i j k l m n o | |
in (p,) } | |
name=wrapped | |
] a b c d e f g | |
in (h,) }, ...} | |
in_tree = PyTreeDef((*, *, *, *, *, *, *)), out_tree = PyTreeDef(*) | |
arg = Array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | |
1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=int32, weak_type=True) | |
def _python_pjit_helper(fun, infer_params_fn, *args, **kwargs): | |
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn( | |
*args, **kwargs) | |
for arg in args_flat: | |
dispatch.check_arg(arg) | |
try: | |
> out_flat = pjit_p.bind(*args_flat, **params) | |
/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py:163: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
self = pjit | |
args = (Array([[[[0.162109, 0.632812, -0.333984, ..., -0.667969, -0.192383, | |
1.375], | |
[1.42969, 0.953125, -0... 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, | |
377, 378, 379, 380, 381, 382, 383]], dtype=int32), ...) | |
params = {'donated_invars': (False, False, False, False, False, False, ...), 'in_shardings': (UnspecifiedValue, UnspecifiedValu...lse, False, False) | |
] i j k l m n o | |
in (p,) } | |
name=wrapped | |
] a b c d e f g | |
in (h,) }, ...} | |
top_trace = EvalTrace(level=0/0), axis_main = None | |
def bind(self, *args, **params): | |
top_trace = find_top_trace(args) | |
axis_main = max((axis_frame(a).main_trace for a in used_axis_names(self, params)), | |
default=None, key=lambda t: getattr(t, 'level', -1)) | |
top_trace = (top_trace if not axis_main or axis_main.level < top_trace.level | |
else axis_main.with_cur_sublevel()) | |
> return self.bind_with_trace(top_trace, args, params) | |
/usr/local/lib/python3.10/dist-packages/jax/_src/core.py:2578: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
self = pjit, trace = EvalTrace(level=0/0) | |
args = (Array([[[[0.162109, 0.632812, -0.333984, ..., -0.667969, -0.192383, | |
1.375], | |
[1.42969, 0.953125, -0... 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, | |
377, 378, 379, 380, 381, 382, 383]], dtype=int32), ...) | |
params = {'donated_invars': (False, False, False, False, False, False, ...), 'in_shardings': (UnspecifiedValue, UnspecifiedValu...lse, False, False) | |
] i j k l m n o | |
in (p,) } | |
name=wrapped | |
] a b c d e f g | |
in (h,) }, ...} | |
def bind_with_trace(self, trace, args, params): | |
> out = trace.process_primitive(self, map(trace.full_raise, args), params) | |
/usr/local/lib/python3.10/dist-packages/jax/_src/core.py:382: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
self = EvalTrace(level=0/0), primitive = pjit | |
tracers = [Array([[[[0.162109, 0.632812, -0.333984, ..., -0.667969, -0.192383, | |
1.375], | |
[1.42969, 0.953125, -0... 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, | |
377, 378, 379, 380, 381, 382, 383]], dtype=int32), ...] | |
params = {'donated_invars': (False, False, False, False, False, False, ...), 'in_shardings': (UnspecifiedValue, UnspecifiedValu...lse, False, False) | |
] i j k l m n o | |
in (p,) } | |
name=wrapped | |
] a b c d e f g | |
in (h,) }, ...} | |
def process_primitive(self, primitive, tracers, params): | |
> return primitive.impl(*tracers, **params) | |
/usr/local/lib/python3.10/dist-packages/jax/_src/core.py:814: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
jaxpr = { lambda ; a:bf16[2,384,2,64] b:bf16[2,384,2,64] c:bf16[2,384,2,64] d:i32[2,384] | |
e:i32[2,384] f:i32[2,384] g:i32[2...se, False, False, False) | |
] i j k l m n o | |
in (p,) } | |
name=wrapped | |
] a b c d e f g | |
in (h,) } | |
in_shardings = (UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, ...) | |
out_shardings = (UnspecifiedValue,), resource_env = None | |
def _pjit_call_impl(*args, jaxpr, | |
in_shardings, out_shardings, resource_env, | |
donated_invars, name, keep_unused, inline): | |
def call_impl_cache_miss(*args_, **kwargs_): | |
out_flat, compiled = _pjit_call_impl_python( | |
*args, jaxpr=jaxpr, in_shardings=in_shardings, | |
out_shardings=out_shardings, resource_env=resource_env, | |
donated_invars=donated_invars, name=name, keep_unused=keep_unused, | |
inline=inline) | |
fastpath_data = _get_fastpath_data( | |
compiled, tree_structure(out_flat), args, out_flat) | |
return out_flat, fastpath_data | |
f = _get_jaxpr_as_fun( | |
jaxpr, tuple(getattr(i, '_original_sharding', i) for i in in_shardings), | |
tuple(getattr(o, '_original_sharding', o) for o in out_shardings), | |
resource_env, donated_invars, name, keep_unused, inline) | |
donated_argnums = [i for i, d in enumerate(donated_invars) if d] | |
has_explicit_sharding = _pjit_explicit_sharding( | |
in_shardings, out_shardings, None, None) | |
> return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums, | |
_get_cpp_global_cache(has_explicit_sharding))(*args) | |
/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py:1223: | |
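The two frames above are jit's cache-miss path: bind dispatches the pjit primitive to _pjit_call_impl, which lowers and compiles the jaxpr before the compiled executable is ever called. A minimal sketch of the same trace/lower/compile stages through the public AOT API, assuming a recent jax install and an invented toy function:

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * 2.0

x = jnp.ones((4,), jnp.float32)

# jit traces f to a jaxpr, lowers it to an MLIR module, and compiles it with
# XLA; .lower()/.compile() expose the stages the frames above walk through.
lowered = jax.jit(f).lower(x)
compiled = lowered.compile()
print(compiled(x))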
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
args_ = (Array([[[[0.162109, 0.632812, -0.333984, ..., -0.667969, -0.192383, | |
1.375], | |
[1.42969, 0.953125, -0... 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, | |
377, 378, 379, 380, 381, 382, 383]], dtype=int32), ...) | |
kwargs_ = {} | |
def call_impl_cache_miss(*args_, **kwargs_): | |
> out_flat, compiled = _pjit_call_impl_python( | |
*args, jaxpr=jaxpr, in_shardings=in_shardings, | |
out_shardings=out_shardings, resource_env=resource_env, | |
donated_invars=donated_invars, name=name, keep_unused=keep_unused, | |
inline=inline) | |
/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py:1207: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
jaxpr = { lambda ; a:bf16[2,384,2,64] b:bf16[2,384,2,64] c:bf16[2,384,2,64] d:i32[2,384] | |
e:i32[2,384] f:i32[2,384] g:i32[2...se, False, False, False) | |
] i j k l m n o | |
in (p,) } | |
name=wrapped | |
] a b c d e f g | |
in (h,) } | |
in_shardings = (UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, ...) | |
out_shardings = (UnspecifiedValue,), resource_env = None | |
def _pjit_call_impl_python( | |
*args, jaxpr, in_shardings, out_shardings, resource_env, donated_invars, | |
name, keep_unused, inline): | |
global _most_recent_pjit_call_executable | |
in_shardings = _resolve_in_shardings( | |
args, in_shardings, out_shardings, | |
resource_env.physical_mesh if resource_env is not None else None) | |
> compiled = _pjit_lower( | |
jaxpr, in_shardings, out_shardings, resource_env, | |
donated_invars, name, keep_unused, inline, | |
always_lower=False, lowering_platform=None).compile() | |
/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py:1140: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
jaxpr = { lambda ; a:bf16[2,384,2,64] b:bf16[2,384,2,64] c:bf16[2,384,2,64] d:i32[2,384] | |
e:i32[2,384] f:i32[2,384] g:i32[2...se, False, False, False) | |
] i j k l m n o | |
in (p,) } | |
name=wrapped | |
] a b c d e f g | |
in (h,) } | |
in_shardings = SameDeviceAssignmentTuple(shardings=(UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue), device_assignment=None) | |
out_shardings = SameDeviceAssignmentTuple(shardings=(UnspecifiedValue,), device_assignment=None) | |
args = (None, (False, False, False, False, False, False, ...), 'multi_head_attention', False, False) | |
kwargs = {'always_lower': False, 'lowering_platform': None}, da = None | |
def _pjit_lower( | |
jaxpr: core.ClosedJaxpr, | |
in_shardings, | |
out_shardings, | |
*args, **kwargs): | |
da = _fast_path_get_device_assignment(it.chain(in_shardings, out_shardings)) | |
in_shardings = SameDeviceAssignmentTuple(tuple(in_shardings), da) | |
out_shardings = SameDeviceAssignmentTuple(tuple(out_shardings), da) | |
> return _pjit_lower_cached(jaxpr, in_shardings, out_shardings, *args, **kwargs) | |
/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py:1269: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
jaxpr = { lambda ; a:bf16[2,384,2,64] b:bf16[2,384,2,64] c:bf16[2,384,2,64] d:i32[2,384] | |
e:i32[2,384] f:i32[2,384] g:i32[2...se, False, False, False) | |
] i j k l m n o | |
in (p,) } | |
name=wrapped | |
] a b c d e f g | |
in (h,) } | |
sdat_in_shardings = SameDeviceAssignmentTuple(shardings=(UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue), device_assignment=None) | |
sdat_out_shardings = SameDeviceAssignmentTuple(shardings=(UnspecifiedValue,), device_assignment=None) | |
resource_env = None | |
donated_invars = (False, False, False, False, False, False, ...) | |
name = 'multi_head_attention', keep_unused = False, inline = False | |
always_lower = False | |
@weakref_lru_cache | |
def _pjit_lower_cached( | |
jaxpr: core.ClosedJaxpr, | |
sdat_in_shardings: SameDeviceAssignmentTuple, | |
sdat_out_shardings: SameDeviceAssignmentTuple, | |
resource_env, | |
donated_invars, | |
name: str, | |
keep_unused: bool, | |
inline: bool, | |
always_lower: bool, | |
*, | |
lowering_platform: Optional[str], | |
override_lowering_rules: Optional[ | |
tuple[tuple[core.Primitive, mlir.LoweringRule]]] = None): | |
in_shardings: tuple[PjitShardingMinusUnspecified, ...] = cast( | |
tuple[PjitShardingMinusUnspecified, ...], sdat_in_shardings.shardings) | |
out_shardings: tuple[PjitSharding, ...] = sdat_out_shardings.shardings | |
if resource_env is not None: | |
pxla.resource_typecheck(jaxpr, resource_env, {}, lambda: "pjit") | |
if resource_env is not None: | |
mesh = resource_env.physical_mesh | |
api_name = 'pjit' | |
else: | |
# resource_env is `None` in the jit wrapper around pjit. | |
mesh = None | |
api_name = 'jit' | |
# For `pjit(xmap)` cases, it needs to take the `lower_mesh_computation` path | |
# because `xmap` only supports SPMDAxisContext right now. | |
if dispatch.jaxpr_has_primitive(jaxpr.jaxpr, 'xmap'): | |
return pxla.lower_mesh_computation( | |
jaxpr, api_name, name, mesh, | |
in_shardings, out_shardings, donated_invars, | |
True, jaxpr.in_avals, tiling_method=None, | |
lowering_platform=lowering_platform) | |
else: | |
> return pxla.lower_sharding_computation( | |
jaxpr, api_name, name, in_shardings, out_shardings, | |
tuple(donated_invars), tuple(jaxpr.in_avals), | |
keep_unused=keep_unused, inline=inline, always_lower=always_lower, | |
devices_from_context=( | |
None if mesh is None or mesh.empty else list(mesh.devices.flat)), | |
lowering_platform=lowering_platform, | |
override_lowering_rules=override_lowering_rules, | |
) | |
/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py:1311: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
args = ({ lambda ; a:bf16[2,384,2,64] b:bf16[2,384,2,64] c:bf16[2,384,2,64] d:i32[2,384] | |
e:i32[2,384] f:i32[2,384] g:i32[...e, UnspecifiedValue, UnspecifiedValue, ...), (UnspecifiedValue,), (False, False, False, False, False, False, ...), ...) | |
kwargs = {'always_lower': False, 'devices_from_context': None, 'inline': False, 'keep_unused': False, ...} | |
@wraps(func) | |
def wrapper(*args, **kwargs): | |
with TraceAnnotation(name, **decorator_kwargs): | |
> return func(*args, **kwargs) | |
/usr/local/lib/python3.10/dist-packages/jax/_src/profiler.py:314: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
fun_or_jaxpr = { lambda ; a:bf16[2,384,2,64] b:bf16[2,384,2,64] c:bf16[2,384,2,64] d:i32[2,384] | |
e:i32[2,384] f:i32[2,384] g:i32[2...se, False, False, False) | |
] i j k l m n o | |
in (p,) } | |
name=wrapped | |
] a b c d e f g | |
in (h,) } | |
api_name = 'jit', fun_name = 'multi_head_attention' | |
in_shardings = (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), ...) | |
out_shardings = (UnspecifiedValue,) | |
donated_invars = (False, False, False, False, False, False, ...) | |
global_in_avals = (ShapedArray(bfloat16[2,384,2,64]), ShapedArray(bfloat16[2,384,2,64]), ShapedArray(bfloat16[2,384,2,64]), ShapedArray(int32[2,384]), ShapedArray(int32[2,384], weak_type=True), ShapedArray(int32[2,384]), ...) | |
@profiler.annotate_function | |
def lower_sharding_computation( | |
fun_or_jaxpr: Union[lu.WrappedFun, core.ClosedJaxpr], | |
api_name: str, | |
fun_name: str, | |
in_shardings: Sequence[MaybeSharding], | |
out_shardings: Union[Sequence[MaybeSharding], UnspecifiedValue], | |
donated_invars: Sequence[bool], | |
global_in_avals: Sequence[core.ShapedArray], | |
*, | |
keep_unused: bool, | |
inline: bool, | |
always_lower: bool, | |
devices_from_context: Optional[Sequence[xc.Device]] = None, | |
lowering_platform: Optional[str], | |
override_lowering_rules: Optional[ | |
tuple[tuple[core.Primitive, mlir.LoweringRule]]] = None, | |
) -> MeshComputation: | |
"""Lowers a computation to XLA. It can take arbitrary shardings as input. | |
The caller of this code can pass in a singleton UNSPECIFIED because the | |
number of out_avals might not be known at that time and | |
lower_sharding_computation calculates the number of out_avals so it can apply | |
the singleton UNSPECIFIED to all out_avals. | |
""" | |
# 1. Trace to jaxpr and preprocess/verify it | |
auto_spmd_lowering = ( | |
check_if_any_auto(in_shardings) if is_unspecified(out_shardings) else | |
check_if_any_auto(it.chain.from_iterable([in_shardings, out_shardings]))) # type: ignore | |
(closed_jaxpr, global_in_avals, global_out_avals, donated_invars, | |
kept_var_idx, name_stack) = _trace_to_jaxpr_and_dce( | |
fun_or_jaxpr, global_in_avals, api_name, fun_name, keep_unused, | |
donated_invars, auto_spmd_lowering) | |
jaxpr = closed_jaxpr.jaxpr | |
in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx) | |
if is_unspecified(out_shardings): | |
out_shardings = (UNSPECIFIED,) * len(global_out_avals) | |
assert isinstance(out_shardings, tuple) | |
assert len(out_shardings) == len(global_out_avals), ( | |
len(out_shardings), len(global_out_avals)) | |
# Device assignment across all inputs, outputs and shardings inside jaxpr | |
# should be the same. | |
jaxpr_sharding = list(dispatch.jaxpr_shardings(jaxpr)) | |
backend, device_assignment = _get_and_check_device_assignment( | |
it.chain([(i, MismatchType.ARG_SHARDING, None) for i in in_shardings], | |
[(o, MismatchType.OUT_SHARDING, None) for o in out_shardings], | |
[(js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info) | |
for js, source_info in jaxpr_sharding]), | |
devices_from_context) | |
committed = bool( | |
devices_from_context or | |
len(device_assignment) > 1 or | |
any(not is_unspecified(i) for i in in_shardings) or | |
any(not is_unspecified(js) for js, _ in jaxpr_sharding) or | |
any(not is_unspecified(o) for o in out_shardings)) | |
gs = sharding_impls.GSPMDSharding.get_replicated(device_assignment) | |
in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings) | |
da_object = _create_da_object(tuple(device_assignment)) | |
if not da_object.is_fully_addressable: | |
check_multihost_collective_allowlist(jaxpr) | |
if inline and config.jax_spmd_mode != 'allow_all': | |
raise RuntimeError( | |
"Running operations on `Array`s that are not fully addressable by this " | |
"process (i.e. `Array`s with data sharded across multiple devices and " | |
"processes.) is dangerous. It’s very important that all processes run " | |
"the same cross-process computations in the same order otherwise it " | |
"can lead to hangs. " | |
"If you’re not already familiar with JAX’s multi-process " | |
"programming model, please read " | |
"https://jax.readthedocs.io/en/latest/multi_process.html. " | |
"To fix this error, run your `jitted` computation inside " | |
"`with jax.spmd_mode('allow_all'):` context manager.") | |
has_outfeed = core.jaxpr_uses_outfeed(jaxpr) | |
kept_outputs = [True] * len(global_out_avals) | |
# Computations that only produce constants and/or only rearrange their inputs, | |
# which are often produced from partial evaluation, don't need compilation, | |
# and don't need to evaluate their arguments. | |
if (not always_lower and not (jaxpr.effects or has_outfeed) and | |
(not jaxpr.eqns and all(kept_outputs) or not jaxpr.outvars) and | |
all(is_unspecified(o) for o in out_shardings)): | |
return MeshComputation( | |
str(name_stack), None, True, donated_invars, jaxpr=jaxpr, | |
consts=closed_jaxpr.consts, global_in_avals=global_in_avals, | |
global_out_avals=global_out_avals, in_shardings=in_shardings, | |
backend=backend, da_object=da_object, | |
committed=committed, kept_var_idx=kept_var_idx, keepalive=None) | |
# 2. Build up the HLO | |
semantic_in_shardings = SemanticallyEqualShardings(in_shardings) # type: ignore | |
semantic_out_shardings = SemanticallyEqualShardings(out_shardings) | |
(module, keepalive, host_callbacks, unordered_effects, ordered_effects, | |
> nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo( | |
closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings, | |
semantic_out_shardings, da_object, lowering_platform, | |
donated_invars, name_stack, override_lowering_rules) | |
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py:2083: | |
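Every input in this frame arrives with an UnspecifiedValue sharding, so lower_sharding_computation replaces it with a replicated GSPMDSharding over the single-GPU device assignment (the gs = ... line above). For contrast, a small sketch of attaching an explicit sharding to an input, assuming the jax.sharding API and an invented 1-D mesh:

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Shard the leading axis over all local devices; jit then sees a concrete
# NamedSharding instead of an UnspecifiedValue for this argument.
ndev = len(jax.devices())
mesh = Mesh(jax.devices(), ("data",))
x = jax.device_put(jnp.ones((4 * ndev, 64)), NamedSharding(mesh, P("data", None)))

y = jax.jit(lambda a: a * 2)(x)
print(y.sharding)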
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
closed_jaxpr = { lambda ; a:bf16[2,384,2,64] b:bf16[2,384,2,64] c:bf16[2,384,2,64] d:i32[2,384] | |
e:i32[2,384] f:i32[2,384] g:i32[2...se, False, False, False) | |
] i j k l m n o | |
in (p,) } | |
name=wrapped | |
] a b c d e f g | |
in (h,) } | |
api_name = 'jit', fun_name = 'multi_head_attention' | |
backend = <jaxlib.xla_extension.Client object at 0x7fbc21ccc070> | |
semantic_in_shardings = SemanticallyEqualShardings(shardings=(GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replica...), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}))) | |
semantic_out_shardings = SemanticallyEqualShardings(shardings=(UnspecifiedValue,)) | |
da_object = _DeviceAssignment(device_assignment=(gpu(id=0),)) | |
lowering_platform = None | |
donated_invars = (False, False, False, False, False, False, ...) | |
name_stack = NameStack(stack=(Scope(name='jit(multi_head_attention)'),)) | |
override_lowering_rules = None | |
@weakref_lru_cache | |
def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, | |
semantic_in_shardings, semantic_out_shardings, | |
da_object, lowering_platform, | |
donated_invars, name_stack, override_lowering_rules): | |
jaxpr = closed_jaxpr.jaxpr | |
in_shardings = semantic_in_shardings.shardings | |
out_shardings = semantic_out_shardings.shardings | |
global_in_avals = closed_jaxpr.in_avals | |
global_out_avals = closed_jaxpr.out_avals | |
device_assignment = da_object.device_assignment | |
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG | |
if logger.isEnabledFor(log_priority): | |
logger.log(log_priority, | |
"Compiling %s for with global shapes and types %s. " | |
"Argument mapping: %s.", | |
fun_name, global_in_avals, in_shardings) | |
# Look at the number of replicas present in the jaxpr. In | |
# lower_sharding_computation, nreps > 1 during `jit(pmap)` cases. This is | |
# handled here so as to deprecate the lower_xla_callable codepath when | |
# `jax.Array` is turned on by default. | |
# TODO(yashkatariya): Remove this when `jit(pmap)` is removed. | |
nreps = dispatch.jaxpr_replicas(jaxpr) | |
dispatch.raise_warnings_or_errors_for_jit_of_pmap( | |
nreps, backend, fun_name, jaxpr) | |
in_mlir_shardings: Optional[list[Optional[sharding_impls.XLACompatibleSharding]]] | |
out_mlir_shardings: Optional[list[Optional[sharding_impls.XLACompatibleSharding]]] | |
axis_ctx: mlir.AxisContext | |
if nreps == 1: | |
in_mlir_shardings = map(_to_logical_sharding, global_in_avals, in_shardings) | |
out_mlir_shardings = map(_to_logical_sharding, global_out_avals, out_shardings) | |
replicated_args = [False] * len(global_in_avals) | |
axis_ctx = sharding_impls.ShardingContext(device_assignment) | |
num_partitions = len(device_assignment) | |
else: | |
# This path is triggered for `jit(pmap)` cases. | |
replicated_args = None | |
in_mlir_shardings = None | |
out_mlir_shardings = None | |
axis_env = sharding_impls.AxisEnv(nreps, (), ()) | |
axis_ctx = sharding_impls.ReplicaAxisContext(axis_env) | |
num_partitions = 1 | |
module_name = f"{api_name}_{fun_name}" | |
if len(device_assignment) > 1: | |
if any(effects.ordered_effects.contains(eff) for eff | |
in closed_jaxpr.effects): | |
raise ValueError("Ordered effects are not supported for more than 1 device.") | |
ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects)) | |
with dispatch.log_elapsed_time( | |
"Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time} sec", | |
fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT): | |
> lowering_result = mlir.lower_jaxpr_to_module( | |
module_name, | |
closed_jaxpr, | |
ordered_effects, | |
backend, | |
# Optionally, override the lowering platform | |
lowering_platform or backend.platform, | |
axis_ctx, | |
name_stack, | |
donated_invars, | |
replicated_args=replicated_args, | |
arg_shardings=in_mlir_shardings, | |
result_shardings=out_mlir_shardings, | |
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names, | |
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths, | |
num_replicas=nreps, | |
num_partitions=num_partitions, | |
override_lowering_rules=override_lowering_rules) | |
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py:1923: | |
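The config.jax_log_compiles branch above is the simplest hook for watching these lowerings happen: with the flag on, every cache miss emits the "Compiling ..." line at WARNING level. A minimal sketch, assuming a recent jax config API:

import jax
import jax.numpy as jnp

# Log a "Compiling <fun> ..." message (the logger.log call above) on every
# cache miss, which makes unexpected retracing or recompilation easy to spot.
jax.config.update("jax_log_compiles", True)

f = jax.jit(lambda x: x * 2)
f(jnp.ones((4,)))   # compiles and logs
f(jnp.ones((4,)))   # cache hit, no log
f(jnp.ones((8,)))   # new shape, recompiles and logs again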
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
module_name = 'jit_multi_head_attention' | |
jaxpr = { lambda ; a:bf16[2,384,2,64] b:bf16[2,384,2,64] c:bf16[2,384,2,64] d:i32[2,384] | |
e:i32[2,384] f:i32[2,384] g:i32[2...se, False, False, False) | |
] i j k l m n o | |
in (p,) } | |
name=wrapped | |
] a b c d e f g | |
in (h,) } | |
ordered_effects = [] | |
backend_or_name = <jaxlib.xla_extension.Client object at 0x7fbc21ccc070> | |
platform = 'cuda' | |
axis_context = ShardingContext(device_assignment=(gpu(id=0),)) | |
name_stack = NameStack(stack=(Scope(name='jit(multi_head_attention)'),)) | |
donated_args = [False, False, False, False, False, False, ...] | |
replicated_args = [False, False, False, False, False, False, ...] | |
arg_shardings = [GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), ...] | |
result_shardings = [None] | |
arg_names = ('query', 'key', 'value', 'query_positions', 'query_segment_ids', 'kv_positions', ...) | |
result_names = ('',), num_replicas = 1, num_partitions = 1 | |
override_lowering_rules = None | |
def lower_jaxpr_to_module( | |
module_name: str, | |
jaxpr: core.ClosedJaxpr, | |
ordered_effects: list[core.Effect], | |
backend_or_name: Optional[Union[str, xb.XlaBackend]], | |
platform: str, | |
axis_context: AxisContext, | |
name_stack: source_info_util.NameStack, | |
donated_args: Sequence[bool], | |
replicated_args: Optional[Sequence[bool]] = None, | |
arg_shardings: Optional[Sequence[Optional[XLACompatibleSharding]]] = None, | |
result_shardings: Optional[Sequence[Optional[XLACompatibleSharding]]] = None, | |
arg_names: Optional[Sequence[Optional[str]]] = None, | |
result_names: Optional[Sequence[Optional[str]]] = None, | |
num_replicas: int = 1, | |
num_partitions: int = 1, | |
override_lowering_rules: Optional[ | |
tuple[tuple[core.Primitive, LoweringRule]]] = None, | |
) -> LoweringResult: | |
"""Lowers a top-level jaxpr to an MLIR module. | |
Handles the quirks of the argument/return value passing conventions of the | |
runtime. | |
""" | |
platform = xb.canonicalize_platform(platform) | |
if not xb.is_known_platform(platform): | |
raise ValueError(f"Unknown platform {platform}") | |
input_output_aliases = None | |
in_avals = (jaxpr.in_avals if arg_shardings is None else | |
map(sharded_aval, jaxpr.in_avals, arg_shardings)) | |
out_avals = (jaxpr.out_avals if result_shardings is None else | |
map(sharded_aval, jaxpr.out_avals, result_shardings)) | |
if platform in _platforms_with_donation: | |
input_output_aliases, donated_args = _set_up_aliases( | |
in_avals, out_avals, donated_args) | |
unlowerable_effects = lowerable_effects.filter_not_in(jaxpr.effects) | |
if unlowerable_effects: | |
raise ValueError(f'Cannot lower jaxpr with effects: {jaxpr.effects}') | |
if any(donated_args): | |
unused_donations = [str(a) for a, d in zip(in_avals, donated_args) if d] | |
msg = "See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation." | |
if platform not in _platforms_with_donation: | |
msg = f"Donation is not implemented for {platform}.\n{msg}" | |
warnings.warn(f"Some donated buffers were not usable: {', '.join(unused_donations)}.\n{msg}") | |
# HLO channels need to start at 1 | |
channel_iter = itertools.count(1) | |
# Create a keepalives list that will be mutated during the lowering. | |
keepalives: list[Any] = [] | |
host_callbacks: list[Any] = [] | |
dim_vars: Sequence[str] | |
if not config.jax_dynamic_shapes: | |
# Find the dimension variables | |
all_dim_poly = [d for aval in jaxpr.in_avals if hasattr(aval, "shape") | |
for d in aval.shape if not core.is_constant_dim(d)] | |
dim_vars = tuple(sorted(functools.reduce(lambda acc, new: acc.union(new.get_vars()), | |
all_dim_poly, set()))) | |
else: | |
dim_vars = () | |
arg_op_shardings = ( | |
map(_to_logical_op_sharding, jaxpr.in_avals, arg_shardings) | |
if arg_shardings is not None else arg_shardings) | |
result_op_shardings = ( | |
map(_to_logical_op_sharding, jaxpr.out_avals, result_shardings) | |
if result_shardings is not None else result_shardings) | |
ctx = ModuleContext(backend_or_name, platform, axis_context, name_stack, | |
keepalives, channel_iter, host_callbacks, | |
override_lowering_rules=override_lowering_rules, | |
shape_poly_state=ShapePolyLoweringState(dim_vars)) | |
with ctx.context, ir.Location.unknown(ctx.context): | |
# Remove module name characters that XLA would alter. This ensures that | |
# XLA computation preserves the module name. | |
attrs = ctx.module.operation.attributes | |
module_name = _module_name_regex.sub("_", module_name) | |
attrs["sym_name"] = ir.StringAttr.get(module_name) | |
attrs["mhlo.num_replicas"] = i32_attr(num_replicas) | |
attrs["mhlo.num_partitions"] = i32_attr(num_partitions) | |
> lower_jaxpr_to_fun( | |
ctx, "main", jaxpr, ordered_effects, public=True, create_tokens=True, | |
replace_tokens_with_dummy=True, | |
num_output_tokens=0, | |
replicated_args=replicated_args, | |
arg_shardings=arg_op_shardings, | |
result_shardings=result_op_shardings, | |
input_output_aliases=input_output_aliases, | |
arg_names=arg_names, | |
result_names=result_names) | |
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:728: | |
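The donation handling in this frame (_set_up_aliases plus the "Some donated buffers were not usable" warning) is driven by donate_argnums at the jit call site; a short sketch with invented shapes, assuming a recent jax:

import jax
import jax.numpy as jnp

# Donating x lets XLA alias the input buffer to the output; during lowering
# this becomes the input_output_aliases ("tf.aliasing_output") attributes
# attached to the MLIR function arguments further down the trace.
scale = jax.jit(lambda x: x * 0.5, donate_argnums=0)

x = jnp.ones((1024, 1024), jnp.float32)
y = scale(x)   # on GPU/TPU the buffer of x may be reused; x is then invalid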
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
ctx = ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7fbb2812df80>, module=<jax...object at 0x7fbb281097e0>, cached_primitive_lowerings={}, cached_call_jaxpr_lowerings={}, override_lowering_rules=None) | |
name = 'main' | |
jaxpr = { lambda ; a:bf16[2,384,2,64] b:bf16[2,384,2,64] c:bf16[2,384,2,64] d:i32[2,384] | |
e:i32[2,384] f:i32[2,384] g:i32[2...se, False, False, False) | |
] i j k l m n o | |
in (p,) } | |
name=wrapped | |
] a b c d e f g | |
in (h,) } | |
effects = [] | |
def lower_jaxpr_to_fun( | |
ctx: ModuleContext, | |
name: str, | |
jaxpr: core.ClosedJaxpr, | |
effects: Sequence[core.Effect], | |
*, | |
create_tokens: bool = False, | |
public: bool = False, | |
replace_tokens_with_dummy: bool = False, | |
replicated_args: Optional[Sequence[bool]] = None, | |
arg_shardings: Optional[Sequence[Optional[xc.HloSharding]]] = None, | |
result_shardings: Optional[Sequence[Optional[xc.HloSharding]]] = None, | |
use_sharding_annotations: bool = True, | |
input_output_aliases: Optional[Sequence[Optional[int]]] = None, | |
num_output_tokens: int = 0, | |
api_name: str = "jit", | |
arg_names: Optional[Sequence[Optional[str]]] = None, | |
result_names: Optional[Sequence[Optional[str]]] = None, | |
) -> func_dialect.FuncOp: | |
"""Lowers jaxpr and its callees to an IR function. | |
Assumes that an MLIR context, location, and insertion point are set. | |
Args: | |
ctx: the lowering context. | |
name: the function name. The name will be uniquified by the symbol table, | |
so it is ok to use the same name multiple times. | |
jaxpr: the jaxpr to lower. | |
effects: a sequence of `core.Effect`s corresponding to an ordering of tokens | |
that will be created in or used by the lowered function. | |
create_tokens: if true, the HLO will create tokens and ignore dummy input tokens. | |
public: if true, the function's visibility is set to "public". | |
replace_tokens_with_dummy: if true, token arguments/return values are | |
replaced with bool arrays of size [0]. | |
replicated_args: if present, annotates arguments as replicated. | |
arg_shardings: sharding annotations for each argument (optional). | |
result_shardings: sharding annotations for each result (optional). | |
use_sharding_annotations: if True, use "mhlo.sharding" annotations on | |
parameters and return values to express sharding. If False, use | |
hlo.custom_call operators with sharding annotations. | |
TODO(b/228598865): remove this option when "mhlo.sharding" annotations are | |
propagated on non-entry functions during MLIR->HLO conversion. | |
input_output_aliases: optional sequence that maps argument numbers to the | |
corresponding output that should alias them. | |
api_name: The name of the higher level primitive which should show up in the | |
name stack. | |
Returns: | |
MLIR func op | |
""" | |
def aval_to_types(aval): | |
if replace_tokens_with_dummy and aval is core.abstract_token: | |
aval = core.ShapedArray((), np.dtype(np.bool_)) | |
return aval_to_ir_types(aval) | |
num_dim_vars = len(ctx.shape_poly_state.dim_vars) | |
dim_var_avals = [core.ShapedArray((), dtypes.canonicalize_dtype(np.int64))] * num_dim_vars | |
dim_var_types = map(aval_to_types, dim_var_avals) | |
# Function inputs: *dim_var_values, *tokens, *actual_inputs | |
input_types = map(aval_to_types, jaxpr.in_avals) | |
output_types = map(aval_to_types, jaxpr.out_avals) | |
num_tokens = len(effects) | |
if create_tokens: | |
# If we create the tokens they won't be inputs to the MLIR function. | |
token_types = [dummy_token_type() for _ in effects] | |
output_token_types = [dummy_token_type() for _ in range(num_output_tokens)] | |
else: | |
# If we aren't creating tokens they will be the initial inputs to the | |
# MLIR function. | |
output_token_types = [] | |
token_types = [token_type() for _ in effects] | |
token_avals = [core.AbstractToken] * num_tokens | |
# Order of arguments: dim vars, tokens, array inputs | |
input_avals = dim_var_avals + token_avals + jaxpr.in_avals | |
input_types = [*dim_var_types, *token_types, *input_types] | |
output_avals = [core.AbstractToken] * (len(output_token_types) + num_tokens) + jaxpr.out_avals | |
output_types = [*output_token_types, *token_types, *output_types] | |
if input_output_aliases is not None: | |
token_input_output_aliases = [None] * (num_dim_vars + num_tokens) | |
input_output_aliases = [*token_input_output_aliases, *input_output_aliases] | |
# Update the existing aliases to account for the new output values | |
input_output_aliases = [None if a is None | |
else a + num_output_tokens + num_tokens | |
for a in input_output_aliases] | |
if arg_shardings is not None: | |
token_shardings = [None] * (num_dim_vars + num_tokens) | |
arg_shardings = [*token_shardings, *arg_shardings] | |
if result_shardings is not None: | |
token_shardings = [None] * (num_tokens + num_output_tokens) | |
result_shardings = [*token_shardings, *result_shardings] | |
if replicated_args is not None: | |
token_replicated_args = [False] * (num_dim_vars + num_tokens) | |
replicated_args = [*token_replicated_args, *replicated_args] | |
flat_input_types = util.flatten(input_types) | |
flat_output_types = util.flatten(output_types) | |
ftype = ir.FunctionType.get(flat_input_types, flat_output_types) | |
func_op = func_dialect.FuncOp(name, ftype, ip=ctx.ip) | |
func_op.attributes["sym_visibility"] = ir.StringAttr.get( | |
"public" if public else "private") | |
ctx.symbol_table.insert(func_op) | |
ir_arg_shardings = None | |
if arg_shardings is not None: | |
in_avals = [None] * (num_dim_vars + num_tokens) + list(jaxpr.in_avals) | |
ir_arg_shardings = util.flatten( | |
[[_to_physical_op_sharding(a, s)] * len(types) | |
for a, s, types in zip(in_avals, arg_shardings, input_types)]) | |
del in_avals | |
ir_result_shardings = None | |
if result_shardings is not None: | |
out_avals = [None] * (num_tokens + num_output_tokens) + list(jaxpr.out_avals) | |
ir_result_shardings = util.flatten( | |
[[_to_physical_op_sharding(a, s)] * len(types) | |
for a, s, types in zip(out_avals, result_shardings, output_types)]) | |
del out_avals | |
if ( | |
replicated_args is not None | |
or ir_arg_shardings is not None | |
or input_output_aliases is not None | |
or arg_names is not None | |
or num_tokens > 0 | |
): | |
arg_attrs: list[dict[str, ir.Attribute]] = [ | |
{} for _ in range(len(flat_input_types))] | |
if replicated_args is not None: | |
replicated_ir_args = [[replicated] * len(types) for replicated, types | |
in zip(replicated_args, input_types)] | |
for attrs, replicated in zip(arg_attrs, util.flatten(replicated_ir_args)): | |
if replicated: | |
attrs["mhlo.is_same_data_across_replicas"] = ir.UnitAttr.get() | |
if use_sharding_annotations and ir_arg_shardings is not None: | |
for attrs, sharding in zip(arg_attrs, ir_arg_shardings): | |
if sharding is not None: | |
attrs["mhlo.sharding"] = get_sharding_attr(sharding) | |
if input_output_aliases is not None: | |
output_ids = util.unflatten(list(range(len(flat_output_types))), | |
map(len, output_types)) | |
aliases: list[Optional[int]] = [] | |
for types, alias in zip(input_types, input_output_aliases): | |
if alias is None: | |
aliases.extend([None] * len(types)) | |
else: | |
aliases.extend(output_ids[alias]) | |
for attrs, alias in zip(arg_attrs, aliases): | |
if alias is not None: | |
attrs["tf.aliasing_output"] = i32_attr(alias) | |
if num_tokens > 0: | |
token_arg_attrs = arg_attrs[num_dim_vars:num_dim_vars + num_tokens] | |
for attrs in token_arg_attrs: | |
attrs["jax.token"] = ir.BoolAttr.get(True) | |
if arg_names: | |
named_arg_attrs = arg_attrs[num_dim_vars + num_tokens:] | |
for attrs, name_ in zip(named_arg_attrs, arg_names): | |
if name_: | |
attrs['jax.arg_info'] = ir.StringAttr.get(name_) | |
func_op.arg_attrs = ir.ArrayAttr.get( | |
[ir.DictAttr.get(attrs) for attrs in arg_attrs]) | |
result_attrs: list[dict[str, ir.Attribute]] = [ | |
{} for _ in range(len(flat_output_types))] | |
if num_tokens > 0: | |
token_result_attrs = result_attrs[:num_tokens] | |
for attrs in token_result_attrs: | |
attrs["jax.token"] = ir.BoolAttr.get(True) | |
if result_names: | |
named_result_attrs = result_attrs[num_tokens:] | |
if len(named_result_attrs) == len(result_names): | |
for attrs, name_ in zip(named_result_attrs, result_names): | |
attrs['jax.result_info'] = ir.StringAttr.get(name_) | |
if use_sharding_annotations and ir_result_shardings is not None: | |
for attrs, sharding in zip(result_attrs, ir_result_shardings): | |
if sharding is not None: | |
attrs['mhlo.sharding'] = get_sharding_attr(sharding) | |
func_op.result_attrs = ir.ArrayAttr.get( | |
[ir.DictAttr.get(attrs) for attrs in result_attrs]) | |
entry_block = func_op.add_entry_block() | |
with ir.InsertionPoint(entry_block): | |
flat_args = entry_block.arguments | |
# We separate out the dimension variable inputs, the token inputs and | |
# the regular inputs. The dimension variables and token inputs | |
# will be passed to `jaxpr_subcomp` separately from the `args`. | |
dim_var_values, _, _ = util.split_list(flat_args, [num_dim_vars, num_tokens]) | |
# A lowering context just for function body entry/exit code. | |
entry_lowering_ctx = LoweringRuleContext( | |
ctx, None, [], None, TokenSet.create([]), None, None, dim_var_values) | |
if not use_sharding_annotations and ir_arg_shardings is not None: | |
flat_args = [ | |
a if s is None else wrap_with_sharding_op(entry_lowering_ctx, a, a_aval, s) | |
for a, s, a_aval in zip(flat_args, ir_arg_shardings, input_avals)] | |
_, token_args, unflattened_args = util.split_list(util.unflatten(flat_args, map(len, input_types)), | |
[num_dim_vars, num_tokens]) | |
if create_tokens: | |
tokens_in = TokenSet.create(effects) | |
else: | |
tokens_in = TokenSet(zip(effects, token_args)) | |
args: list[list[ir.Value]] = [] | |
for aval, arg in zip(jaxpr.in_avals, unflattened_args): | |
if replace_tokens_with_dummy and aval is core.abstract_token: | |
args.append(hlo.CreateTokenOp().results) | |
else: | |
args.append(arg) | |
callee_name_stack = ctx.name_stack.extend(util.wrap_name(name, api_name)) | |
> out_vals, tokens_out = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack), | |
jaxpr.jaxpr, tokens_in, map(ir_constants, jaxpr.consts), | |
*args, dim_var_values=dim_var_values) | |
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:1060: | |
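lower_jaxpr_to_fun and jaxpr_subcomp below consume the closed jaxpr shown (truncated) in the locals; jax.make_jaxpr prints the same { lambda ; ... let ... in (...) } form for any function, e.g. a toy attention-score computation with arbitrarily chosen shapes:

import jax
import jax.numpy as jnp

def scores(q, k):
    # q, k: [seq, head_dim] -> [seq, seq] scaled dot-product scores
    return jnp.einsum("qd,kd->qk", q, k) / jnp.sqrt(q.shape[-1]).astype(q.dtype)

q = jnp.ones((384, 64), jnp.bfloat16)
k = jnp.ones((384, 64), jnp.bfloat16)
print(jax.make_jaxpr(scores)(q, k))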
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
ctx = ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7fbb2812df80>, module=<jax...object at 0x7fbb281097e0>, cached_primitive_lowerings={}, cached_call_jaxpr_lowerings={}, override_lowering_rules=None) | |
jaxpr = { lambda ; a:bf16[2,384,2,64] b:bf16[2,384,2,64] c:bf16[2,384,2,64] d:i32[2,384] | |
e:i32[2,384] f:i32[2,384] g:i32[2...se, False, False, False) | |
] i j k l m n o | |
in (p,) } | |
name=wrapped | |
] a b c d e f g | |
in (h,) } | |
tokens = <jax._src.interpreters.mlir.TokenSet object at 0x7fbb28109e70> | |
consts = [], dim_var_values = [] | |
args = ([<jaxlib.mlir._mlir_libs._mlir.ir.BlockArgument object at 0x7fbb28116430>], [<jaxlib.mlir._mlir_libs._mlir.ir.BlockAr...ckArgument object at 0x7fbb281178f0>], [<jaxlib.mlir._mlir_libs._mlir.ir.BlockArgument object at 0x7fbb28115fb0>], ...) | |
aval = <function jaxpr_subcomp.<locals>.aval at 0x7fbc21d9a710> | |
write = <function jaxpr_subcomp.<locals>.write at 0x7fbc21d9a440> | |
def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, | |
tokens: TokenSet, | |
consts: Sequence[Sequence[ir.Value]], | |
*args: Sequence[ir.Value], | |
dim_var_values: Sequence[ir.Value] | |
) -> tuple[Sequence[Sequence[ir.Value]], TokenSet]: | |
"""Lowers a jaxpr into MLIR, inlined into an existing function. | |
Assumes that an MLIR context, location, and insertion point are set. | |
dim_var_values: the list of dimension variable values in the current | |
IR function, in the order of ctx.shape_poly_state.dim_vars. | |
""" | |
assert ctx.platform != "gpu" | |
def read(v: core.Atom) -> Sequence[ir.Value]: | |
if type(v) is core.Literal: | |
return ir_constants(v.val, canonicalize_types=True) | |
else: | |
assert isinstance(v, core.Var) | |
return env[v] | |
def aval(v: core.Atom) -> core.AbstractValue: | |
if type(v) is core.Literal: | |
return xla.abstractify(v.val) | |
else: | |
return v.aval | |
def write(v: core.Var, node: Sequence[ir.Value]): | |
assert node is not None | |
env[v] = tuple(node) | |
def get_lowering(primitive: core.Primitive) -> Optional[LoweringRule]: | |
if ctx.override_lowering_rules is None: | |
return None | |
for p, rule in ctx.override_lowering_rules: | |
if primitive is p: | |
return rule | |
return None | |
env: dict[core.Var, tuple[ir.Value, ...]] = {} | |
assert len(args) == len(jaxpr.invars), (jaxpr, args) | |
assert len(consts) == len(jaxpr.constvars), (jaxpr, consts) | |
assert all(isinstance(v, ir.Value) for vs in consts for v in vs), consts | |
assert len(ctx.shape_poly_state.dim_vars) == len(dim_var_values), (ctx.shape_poly_state.dim_vars, dim_var_values) | |
map(write, jaxpr.constvars, consts) | |
map(write, jaxpr.invars, args) | |
last_used = core.last_used(jaxpr) | |
for eqn in jaxpr.eqns: | |
in_nodes = map(read, eqn.invars) | |
assert isinstance(ctx.name_stack, source_info_util.NameStack), type(ctx.name_stack) | |
source_info = eqn.source_info.replace( | |
name_stack=ctx.name_stack + eqn.source_info.name_stack) | |
loc = _source_info_to_location(eqn.primitive, eqn.params, source_info, | |
ctx.name_stack) | |
with source_info_util.user_context(eqn.source_info.traceback), loc: | |
override_rule = get_lowering(eqn.primitive) | |
if override_rule is not None: | |
rule = override_rule | |
elif eqn.primitive in _platform_specific_lowerings[ctx.platform]: | |
rule = _platform_specific_lowerings[ctx.platform][eqn.primitive] | |
elif eqn.primitive in xla._backend_specific_translations[ctx.platform]: | |
rule = xla_fallback_lowering(eqn.primitive) | |
elif eqn.primitive in _lowerings: | |
rule = _lowerings[eqn.primitive] | |
elif eqn.primitive in xla._translations: | |
rule = xla_fallback_lowering(eqn.primitive) | |
else: | |
raise NotImplementedError( | |
f"MLIR translation rule for primitive '{eqn.primitive.name}' not " | |
f"found for platform {ctx.platform}") | |
eqn_ctx = ctx.replace(name_stack=source_info.name_stack) | |
effects = list(effects_lib.ordered_effects.filter_in(eqn.effects)) | |
tokens_in = tokens.subset(effects) | |
avals_in = map(aval, eqn.invars) | |
rule_ctx = LoweringRuleContext( | |
module_context=eqn_ctx, primitive=eqn.primitive, avals_in=avals_in, | |
avals_out=map(aval, eqn.outvars), tokens_in=tokens_in, | |
tokens_out=None, dim_var_values=dim_var_values) | |
if config.jax_dynamic_shapes: | |
axis_size_env = {d: read(d)[0] | |
for a in avals_in if type(a) is core.DShapedArray | |
for d in a.shape if type(d) is core.Var} | |
rule_ctx = rule_ctx.replace(axis_size_env=axis_size_env) | |
> ans = rule(rule_ctx, *map(_unwrap_singleton_ir_values, in_nodes), | |
**eqn.params) | |
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:1217: | |
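The NotImplementedError branch above fires when a primitive has no entry in any lowering table; rules normally land in _lowerings via mlir.register_lowering. A minimal sketch for a hypothetical primitive (names invented; assumes a 2023-era jax where jax.core.Primitive and jax.interpreters.mlir are still the public homes of these helpers):

import jax
import jax.numpy as jnp
from jax import core
from jax.interpreters import mlir

# Hypothetical primitive, only to show the dispatch table jaxpr_subcomp consults.
double_p = core.Primitive("double")
double_p.def_impl(lambda x: x * 2)
double_p.def_abstract_eval(lambda aval: aval)

# lower_fun expresses the lowering in terms of existing primitives;
# register_lowering inserts the rule into the _lowerings table used above.
mlir.register_lowering(
    double_p, mlir.lower_fun(lambda x: x * 2, multiple_results=False))

print(jax.jit(double_p.bind)(jnp.arange(4.0)))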
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
ctx = LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context obj...<jax._src.interpreters.mlir.TokenSet object at 0x7fbb28109c90>, tokens_out=None, axis_size_env=None, dim_var_values=[]) | |
name = 'wrapped' | |
jaxpr = { lambda ; a:bf16[2,384,2,64] b:bf16[2,384,2,64] c:bf16[2,384,2,64] d:i32[2,384] | |
e:i32[2,384] f:i32[2,384] g:i32[2... dtype=bfloat16),) | |
which_linear=(False, False, False, False, False, False, False) | |
] a b c d e f g | |
in (h,) } | |
in_shardings = (UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, ...) | |
out_shardings = (UnspecifiedValue,) | |
def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings, | |
out_shardings, resource_env, donated_invars, | |
keep_unused, inline): | |
effects = list(ctx.tokens_in.effects()) | |
output_types = map(mlir.aval_to_ir_types, ctx.avals_out) | |
output_types = [mlir.token_type()] * len(effects) + output_types | |
flat_output_types = flatten(output_types) | |
arg_shardings = [None if is_unspecified(i) else | |
i._to_xla_hlo_sharding(aval.ndim) | |
for aval, i in zip(ctx.avals_in, in_shardings)] | |
result_shardings = [None if is_unspecified(o) else | |
o._to_xla_hlo_sharding(aval.ndim) | |
for aval, o in zip(ctx.avals_out, out_shardings)] | |
# TODO(b/228598865): inlined calls cannot have shardings set directly on the | |
# inputs or outputs because they are lost during MLIR->HLO conversion. | |
# using_sharding_annotation=False means we add an identity operation instead. | |
> func = mlir.lower_jaxpr_to_fun( | |
ctx.module_context, name, jaxpr, effects, arg_shardings=arg_shardings, | |
result_shardings=result_shardings, use_sharding_annotations=False, | |
api_name=('jit' if resource_env is None else 'pjit')) | |
/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py:1395: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
ctx = ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7fbb2812df80>, module=<jax...object at 0x7fbb281097e0>, cached_primitive_lowerings={}, cached_call_jaxpr_lowerings={}, override_lowering_rules=None) | |
name = 'wrapped' | |
jaxpr = { lambda ; a:bf16[2,384,2,64] b:bf16[2,384,2,64] c:bf16[2,384,2,64] d:i32[2,384] | |
e:i32[2,384] f:i32[2,384] g:i32[2... dtype=bfloat16),) | |
which_linear=(False, False, False, False, False, False, False) | |
] a b c d e f g | |
in (h,) } | |
effects = [] | |
def lower_jaxpr_to_fun( | |
ctx: ModuleContext, | |
name: str, | |
jaxpr: core.ClosedJaxpr, | |
effects: Sequence[core.Effect], | |
*, | |
create_tokens: bool = False, | |
public: bool = False, | |
replace_tokens_with_dummy: bool = False, | |
replicated_args: Optional[Sequence[bool]] = None, | |
arg_shardings: Optional[Sequence[Optional[xc.HloSharding]]] = None, | |
result_shardings: Optional[Sequence[Optional[xc.HloSharding]]] = None, | |
use_sharding_annotations: bool = True, | |
input_output_aliases: Optional[Sequence[Optional[int]]] = None, | |
num_output_tokens: int = 0, | |
api_name: str = "jit", | |
arg_names: Optional[Sequence[Optional[str]]] = None, | |
result_names: Optional[Sequence[Optional[str]]] = None, | |
) -> func_dialect.FuncOp: | |
"""Lowers jaxpr and its callees to an IR function. | |
Assumes that an MLIR context, location, and insertion point are set. | |
Args: | |
ctx: the lowering context. | |
name: the function name. The name will be uniquified by the symbol table, | |
so it is ok to use the same name multiple times. | |
jaxpr: the jaxpr to lower. | |
effects: a sequence of `core.Effect`s corresponding to an ordering of tokens | |
that will be created in or used by the lowered function. | |
create_tokens: if true, the HLO will create tokens and ignore dummy input tokens. | |
public: if true, the function's visibility is set to "public". | |
replace_tokens_with_dummy: if true, token arguments/return values are | |
replaced with bool arrays of size [0]. | |
replicated_args: if present, annotates arguments as replicated. | |
arg_shardings: sharding annotations for each argument (optional). | |
result_shardings: sharding annotations for each result (optional). | |
use_sharding_annotations: if True, use "mhlo.sharding" annotations on | |
parameters and return values to express sharding. If False, use | |
hlo.custom_call operators with sharding annotations. | |
TODO(b/228598865): remove this option when "mhlo.sharding" annotations are | |
propagated on non-entry functions during MLIR->HLO conversion. | |
input_output_aliases: optional sequence that maps argument numbers to the | |
corresponding output that should alias them. | |
api_name: The name of the higher level primitive which should show up in the | |
name stack. | |
Returns: | |
MLIR func op | |
""" | |
def aval_to_types(aval): | |
if replace_tokens_with_dummy and aval is core.abstract_token: | |
aval = core.ShapedArray((), np.dtype(np.bool_)) | |
return aval_to_ir_types(aval) | |
num_dim_vars = len(ctx.shape_poly_state.dim_vars) | |
dim_var_avals = [core.ShapedArray((), dtypes.canonicalize_dtype(np.int64))] * num_dim_vars | |
dim_var_types = map(aval_to_types, dim_var_avals) | |
# Function inputs: *dim_var_values, *tokens, *actual_inputs | |
input_types = map(aval_to_types, jaxpr.in_avals) | |
output_types = map(aval_to_types, jaxpr.out_avals) | |
num_tokens = len(effects) | |
if create_tokens: | |
# If we create the tokens they won't be inputs to the MLIR function. | |
token_types = [dummy_token_type() for _ in effects] | |
output_token_types = [dummy_token_type() for _ in range(num_output_tokens)] | |
else: | |
# If we aren't creating tokens they will be the initial inputs to the | |
# MLIR function. | |
output_token_types = [] | |
token_types = [token_type() for _ in effects] | |
token_avals = [core.AbstractToken] * num_tokens | |
# Order of arguments: dim vars, tokens, array inputs | |
input_avals = dim_var_avals + token_avals + jaxpr.in_avals | |
input_types = [*dim_var_types, *token_types, *input_types] | |
output_avals = [core.AbstractToken] * (len(output_token_types) + num_tokens) + jaxpr.out_avals | |
output_types = [*output_token_types, *token_types, *output_types] | |
if input_output_aliases is not None: | |
token_input_output_aliases = [None] * (num_dim_vars + num_tokens) | |
input_output_aliases = [*token_input_output_aliases, *input_output_aliases] | |
# Update the existing aliases to account for the new output values | |
input_output_aliases = [None if a is None | |
else a + num_output_tokens + num_tokens | |
for a in input_output_aliases] | |
if arg_shardings is not None: | |
token_shardings = [None] * (num_dim_vars + num_tokens) | |
arg_shardings = [*token_shardings, *arg_shardings] | |
if result_shardings is not None: | |
token_shardings = [None] * (num_tokens + num_output_tokens) | |
result_shardings = [*token_shardings, *result_shardings] | |
if replicated_args is not None: | |
token_replicated_args = [False] * (num_dim_vars + num_tokens) | |
replicated_args = [*token_replicated_args, *replicated_args] | |
flat_input_types = util.flatten(input_types) | |
flat_output_types = util.flatten(output_types) | |
ftype = ir.FunctionType.get(flat_input_types, flat_output_types) | |
func_op = func_dialect.FuncOp(name, ftype, ip=ctx.ip) | |
func_op.attributes["sym_visibility"] = ir.StringAttr.get( | |
"public" if public else "private") | |
ctx.symbol_table.insert(func_op) | |
ir_arg_shardings = None | |
if arg_shardings is not None: | |
in_avals = [None] * (num_dim_vars + num_tokens) + list(jaxpr.in_avals) | |
ir_arg_shardings = util.flatten( | |
[[_to_physical_op_sharding(a, s)] * len(types) | |
for a, s, types in zip(in_avals, arg_shardings, input_types)]) | |
del in_avals | |
ir_result_shardings = None | |
if result_shardings is not None: | |
out_avals = [None] * (num_tokens + num_output_tokens) + list(jaxpr.out_avals) | |
ir_result_shardings = util.flatten( | |
[[_to_physical_op_sharding(a, s)] * len(types) | |
for a, s, types in zip(out_avals, result_shardings, output_types)]) | |
del out_avals | |
if ( | |
replicated_args is not None | |
or ir_arg_shardings is not None | |
or input_output_aliases is not None | |
or arg_names is not None | |
or num_tokens > 0 | |
): | |
arg_attrs: list[dict[str, ir.Attribute]] = [ | |
{} for _ in range(len(flat_input_types))] | |
if replicated_args is not None: | |
replicated_ir_args = [[replicated] * len(types) for replicated, types | |
in zip(replicated_args, input_types)] | |
for attrs, replicated in zip(arg_attrs, util.flatten(replicated_ir_args)): | |
if replicated: | |
attrs["mhlo.is_same_data_across_replicas"] = ir.UnitAttr.get() | |
if use_sharding_annotations and ir_arg_shardings is not None: | |
for attrs, sharding in zip(arg_attrs, ir_arg_shardings): | |
if sharding is not None: | |
attrs["mhlo.sharding"] = get_sharding_attr(sharding) | |
if input_output_aliases is not None: | |
output_ids = util.unflatten(list(range(len(flat_output_types))), | |
map(len, output_types)) | |
aliases: list[Optional[int]] = [] | |
for types, alias in zip(input_types, input_output_aliases): | |
if alias is None: | |
aliases.extend([None] * len(types)) | |
else: | |
aliases.extend(output_ids[alias]) | |
for attrs, alias in zip(arg_attrs, aliases): | |
if alias is not None: | |
attrs["tf.aliasing_output"] = i32_attr(alias) | |
if num_tokens > 0: | |
token_arg_attrs = arg_attrs[num_dim_vars:num_dim_vars + num_tokens] | |
for attrs in token_arg_attrs: | |
attrs["jax.token"] = ir.BoolAttr.get(True) | |
if arg_names: | |
named_arg_attrs = arg_attrs[num_dim_vars + num_tokens:] | |
for attrs, name_ in zip(named_arg_attrs, arg_names): | |
if name_: | |
attrs['jax.arg_info'] = ir.StringAttr.get(name_) | |
func_op.arg_attrs = ir.ArrayAttr.get( | |
[ir.DictAttr.get(attrs) for attrs in arg_attrs]) | |
result_attrs: list[dict[str, ir.Attribute]] = [ | |
{} for _ in range(len(flat_output_types))] | |
if num_tokens > 0: | |
token_result_attrs = result_attrs[:num_tokens] | |
for attrs in token_result_attrs: | |
attrs["jax.token"] = ir.BoolAttr.get(True) | |
if result_names: | |
named_result_attrs = result_attrs[num_tokens:] | |
if len(named_result_attrs) == len(result_names): | |
for attrs, name_ in zip(named_result_attrs, result_names): | |
attrs['jax.result_info'] = ir.StringAttr.get(name_) | |
if use_sharding_annotations and ir_result_shardings is not None: | |
for attrs, sharding in zip(result_attrs, ir_result_shardings): | |
if sharding is not None: | |
attrs['mhlo.sharding'] = get_sharding_attr(sharding) | |
func_op.result_attrs = ir.ArrayAttr.get( | |
[ir.DictAttr.get(attrs) for attrs in result_attrs]) | |
entry_block = func_op.add_entry_block() | |
with ir.InsertionPoint(entry_block): | |
flat_args = entry_block.arguments | |
# We separate out the dimension variable inputs, the token inputs and | |
# the regular inputs. The dimension variables and token inputs | |
# will be passed to `jaxpr_subcomp` separately from the `args`. | |
dim_var_values, _, _ = util.split_list(flat_args, [num_dim_vars, num_tokens]) | |
# A lowering context just for function body entry/exit code. | |
entry_lowering_ctx = LoweringRuleContext( | |
ctx, None, [], None, TokenSet.create([]), None, None, dim_var_values) | |
if not use_sharding_annotations and ir_arg_shardings is not None: | |
flat_args = [ | |
a if s is None else wrap_with_sharding_op(entry_lowering_ctx, a, a_aval, s) | |
for a, s, a_aval in zip(flat_args, ir_arg_shardings, input_avals)] | |
_, token_args, unflattened_args = util.split_list(util.unflatten(flat_args, map(len, input_types)), | |
[num_dim_vars, num_tokens]) | |
if create_tokens: | |
tokens_in = TokenSet.create(effects) | |
else: | |
tokens_in = TokenSet(zip(effects, token_args)) | |
args: list[list[ir.Value]] = [] | |
for aval, arg in zip(jaxpr.in_avals, unflattened_args): | |
if replace_tokens_with_dummy and aval is core.abstract_token: | |
args.append(hlo.CreateTokenOp().results) | |
else: | |
args.append(arg) | |
callee_name_stack = ctx.name_stack.extend(util.wrap_name(name, api_name)) | |
> out_vals, tokens_out = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack), | |
jaxpr.jaxpr, tokens_in, map(ir_constants, jaxpr.consts), | |
*args, dim_var_values=dim_var_values) | |
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:1060: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
ctx = ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7fbb2812df80>, module=<jax...object at 0x7fbb281097e0>, cached_primitive_lowerings={}, cached_call_jaxpr_lowerings={}, override_lowering_rules=None) | |
jaxpr = { lambda ; a:bf16[2,384,2,64] b:bf16[2,384,2,64] c:bf16[2,384,2,64] d:i32[2,384] | |
e:i32[2,384] f:i32[2,384] g:i32[2... dtype=bfloat16),) | |
which_linear=(False, False, False, False, False, False, False) | |
] a b c d e f g | |
in (h,) } | |
tokens = <jax._src.interpreters.mlir.TokenSet object at 0x7fbb28109ed0> | |
consts = [], dim_var_values = [] | |
args = ([<jaxlib.mlir._mlir_libs._mlir.ir.BlockArgument object at 0x7fbb2811f770>], [<jaxlib.mlir._mlir_libs._mlir.ir.BlockAr...ckArgument object at 0x7fbb2811d330>], [<jaxlib.mlir._mlir_libs._mlir.ir.BlockArgument object at 0x7fbb2815b430>], ...) | |
aval = <function jaxpr_subcomp.<locals>.aval at 0x7fbc21d9a8c0> | |
write = <function jaxpr_subcomp.<locals>.write at 0x7fbc21d9a830> | |
def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, | |
tokens: TokenSet, | |
consts: Sequence[Sequence[ir.Value]], | |
*args: Sequence[ir.Value], | |
dim_var_values: Sequence[ir.Value] | |
) -> tuple[Sequence[Sequence[ir.Value]], TokenSet]: | |
"""Lowers a jaxpr into MLIR, inlined into an existing function. | |
Assumes that an MLIR context, location, and insertion point are set. | |
dim_var_values: the list of dimension variable values in the current | |
IR function, in the order of ctx.shape_poly_state.dim_vars. | |
""" | |
assert ctx.platform != "gpu" | |
def read(v: core.Atom) -> Sequence[ir.Value]: | |
if type(v) is core.Literal: | |
return ir_constants(v.val, canonicalize_types=True) | |
else: | |
assert isinstance(v, core.Var) | |
return env[v] | |
def aval(v: core.Atom) -> core.AbstractValue: | |
if type(v) is core.Literal: | |
return xla.abstractify(v.val) | |
else: | |
return v.aval | |
def write(v: core.Var, node: Sequence[ir.Value]): | |
assert node is not None | |
env[v] = tuple(node) | |
def get_lowering(primitive: core.Primitive) -> Optional[LoweringRule]: | |
if ctx.override_lowering_rules is None: | |
return None | |
for p, rule in ctx.override_lowering_rules: | |
if primitive is p: | |
return rule | |
return None | |
env: dict[core.Var, tuple[ir.Value, ...]] = {} | |
assert len(args) == len(jaxpr.invars), (jaxpr, args) | |
assert len(consts) == len(jaxpr.constvars), (jaxpr, consts) | |
assert all(isinstance(v, ir.Value) for vs in consts for v in vs), consts | |
assert len(ctx.shape_poly_state.dim_vars) == len(dim_var_values), (ctx.shape_poly_state.dim_vars, dim_var_values) | |
map(write, jaxpr.constvars, consts) | |
map(write, jaxpr.invars, args) | |
last_used = core.last_used(jaxpr) | |
for eqn in jaxpr.eqns: | |
in_nodes = map(read, eqn.invars) | |
assert isinstance(ctx.name_stack, source_info_util.NameStack), type(ctx.name_stack) | |
source_info = eqn.source_info.replace( | |
name_stack=ctx.name_stack + eqn.source_info.name_stack) | |
loc = _source_info_to_location(eqn.primitive, eqn.params, source_info, | |
ctx.name_stack) | |
with source_info_util.user_context(eqn.source_info.traceback), loc: | |
override_rule = get_lowering(eqn.primitive) | |
if override_rule is not None: | |
rule = override_rule | |
elif eqn.primitive in _platform_specific_lowerings[ctx.platform]: | |
rule = _platform_specific_lowerings[ctx.platform][eqn.primitive] | |
elif eqn.primitive in xla._backend_specific_translations[ctx.platform]: | |
rule = xla_fallback_lowering(eqn.primitive) | |
elif eqn.primitive in _lowerings: | |
rule = _lowerings[eqn.primitive] | |
elif eqn.primitive in xla._translations: | |
rule = xla_fallback_lowering(eqn.primitive) | |
else: | |
raise NotImplementedError( | |
f"MLIR translation rule for primitive '{eqn.primitive.name}' not " | |
f"found for platform {ctx.platform}") | |
eqn_ctx = ctx.replace(name_stack=source_info.name_stack) | |
effects = list(effects_lib.ordered_effects.filter_in(eqn.effects)) | |
tokens_in = tokens.subset(effects) | |
avals_in = map(aval, eqn.invars) | |
rule_ctx = LoweringRuleContext( | |
module_context=eqn_ctx, primitive=eqn.primitive, avals_in=avals_in, | |
avals_out=map(aval, eqn.outvars), tokens_in=tokens_in, | |
tokens_out=None, dim_var_values=dim_var_values) | |
if config.jax_dynamic_shapes: | |
axis_size_env = {d: read(d)[0] | |
for a in avals_in if type(a) is core.DShapedArray | |
for d in a.shape if type(d) is core.Var} | |
rule_ctx = rule_ctx.replace(axis_size_env=axis_size_env) | |
> ans = rule(rule_ctx, *map(_unwrap_singleton_ir_values, in_nodes), | |
**eqn.params) | |
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:1217: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
ctx = LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context obj...<jax._src.interpreters.mlir.TokenSet object at 0x7fbb2810ad70>, tokens_out=None, axis_size_env=None, dim_var_values=[]) | |
jaxpr = { lambda ; a:Ref{bfloat16[384,64]} b:Ref{bfloat16[384,64]} c:Ref{bfloat16[384,64]} | |
d:Ref{int32[384]} e:Ref{int32[3...] bl | |
ea:bool[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] p | |
h[m:m+128,:] <- dz | |
in () } | |
name = 'mha_forward' | |
in_shapes = (ShapeDtypeStruct(shape=(2, 384, 2, 64), dtype=bfloat16), ShapeDtypeStruct(shape=(2, 384, 2, 64), dtype=bfloat16), Sha... 384), dtype=int32), ShapeDtypeStruct(shape=(2, 384), dtype=int32), ShapeDtypeStruct(shape=(2, 384), dtype=int32), ...) | |
out_shapes = (ShapeDtypeStruct(shape=(2, 384, 2, 64), dtype=bfloat16),) | |
which_linear = (False, False, False, False, False, False, ...) | |
interpret = False, debug = False, input_output_aliases = () | |
grid_mapping = GridMapping(grid=(3, 2, 2), block_mappings=(BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2...4), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) })), mapped_dims=(), num_index_operands=0) | |
in_nodes = (<jaxlib.mlir._mlir_libs._mlir.ir.BlockArgument object at 0x7fbb2811f770>, <jaxlib.mlir._mlir_libs._mlir.ir.BlockArgum...BlockArgument object at 0x7fbb2811d330>, <jaxlib.mlir._mlir_libs._mlir.ir.BlockArgument object at 0x7fbb2815b430>, ...) | |
compiler_params = {'num_stages': 2, 'num_warps': 4}, num_warps = 4 | |
def pallas_call_lowering( | |
ctx: mlir.LoweringRuleContext, | |
*in_nodes, | |
jaxpr: jax_core.Jaxpr, | |
name: str, | |
in_shapes: Tuple[jax.ShapeDtypeStruct, ...], | |
out_shapes: Tuple[jax.ShapeDtypeStruct, ...], | |
which_linear: Tuple[bool, ...], | |
interpret: bool, | |
debug: bool, | |
input_output_aliases: Tuple[Tuple[int, int], ...], | |
grid_mapping: GridMapping, | |
**compiler_params: Any | |
): | |
if interpret: | |
return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)( | |
ctx, | |
*in_nodes, | |
jaxpr=jaxpr, | |
name=name, | |
out_shapes=out_shapes, | |
in_shapes=in_shapes, | |
which_linear=which_linear, | |
interpret=interpret, | |
debug=debug, | |
input_output_aliases=input_output_aliases, | |
grid_mapping=grid_mapping, | |
**compiler_params | |
) | |
num_warps = compiler_params.get("num_warps", 4) | |
num_stages = compiler_params.get("num_stages", 3) | |
if debug: | |
print(jaxpr) | |
print(grid_mapping) | |
> compilation_result = compile_jaxpr( | |
jaxpr, | |
tuple((*in_shapes, *out_shapes)), | |
grid_mapping, | |
name, | |
num_warps, | |
num_stages, | |
debug=debug, | |
) | |
/usr/local/lib/python3.10/dist-packages/jax_triton/pallas/triton_lowering.py:1661: | |
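Note on the frame above: pallas_call_lowering pulls num_warps and num_stages out of **compiler_params (here {'num_stages': 2, 'num_warps': 4}) before handing the kernel jaxpr to compile_jaxpr. A minimal sketch, assuming the jax_triton-era Pallas API (import path jax_triton.pallas, pallas_call forwarding extra keywords as compiler params); the trivial kernel is a stand-in for mha_forward, not the author's code:

    import jax
    import jax.numpy as jnp
    from jax_triton import pallas as pl

    def double_kernel(x_ref, o_ref):
        o_ref[...] = x_ref[...] * 2.0      # stand-in body, not mha_forward

    x = jnp.ones((384, 64), jnp.bfloat16)
    out = pl.pallas_call(
        double_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        num_warps=4,                       # forwarded to lowering via **compiler_params
        num_stages=2,
    )(x)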
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
jaxpr = { lambda ; a:Ref{bfloat16[384,64]} b:Ref{bfloat16[384,64]} c:Ref{bfloat16[384,64]} | |
d:Ref{int32[384]} e:Ref{int32[3...] bl | |
ea:bool[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] p | |
h[m:m+128,:] <- dz | |
in () } | |
in_shapes = (ShapeDtypeStruct(shape=(2, 384, 2, 64), dtype=bfloat16), ShapeDtypeStruct(shape=(2, 384, 2, 64), dtype=bfloat16), Sha... 384), dtype=int32), ShapeDtypeStruct(shape=(2, 384), dtype=int32), ShapeDtypeStruct(shape=(2, 384), dtype=int32), ...) | |
grid_mapping = GridMapping(grid=(3, 2, 2), block_mappings=(BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2...4), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) })), mapped_dims=(), num_index_operands=0) | |
name = 'mha_forward', num_warps = 4, num_stages = 2, debug = False | |
@weakref_lru_cache | |
def compile_jaxpr( | |
jaxpr: jax_core.Jaxpr, | |
in_shapes, | |
grid_mapping: GridMapping, | |
name: str, | |
num_warps: int, | |
num_stages: int, | |
debug: bool, | |
) -> TritonCompilationResult: | |
> lowering_result = lower_jaxpr_to_triton_module( | |
jaxpr, in_shapes, grid_mapping, name | |
) | |
/usr/local/lib/python3.10/dist-packages/jax_triton/pallas/triton_lowering.py:1610: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
jaxpr = { lambda ; a:Ref{bfloat16[384,64]} b:Ref{bfloat16[384,64]} c:Ref{bfloat16[384,64]} | |
d:Ref{int32[384]} e:Ref{int32[3...] bl | |
ea:bool[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] p | |
h[m:m+128,:] <- dz | |
in () } | |
in_shapes = (ShapeDtypeStruct(shape=(2, 384, 2, 64), dtype=bfloat16), ShapeDtypeStruct(shape=(2, 384, 2, 64), dtype=bfloat16), Sha... 384), dtype=int32), ShapeDtypeStruct(shape=(2, 384), dtype=int32), ShapeDtypeStruct(shape=(2, 384), dtype=int32), ...) | |
grid_mapping = GridMapping(grid=(3, 2, 2), block_mappings=(BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2...4), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) })), mapped_dims=(), num_index_operands=0) | |
name = 'mha_forward' | |
def lower_jaxpr_to_triton_module( | |
jaxpr: jax_core.Jaxpr, in_shapes, grid_mapping: GridMapping, name: str | |
) -> tl_ir.module: | |
jaxpr, _ = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars), instantiate=True) | |
ir_context = tl_ir.context() | |
ir_context.load_triton() | |
builder = tl_ir.builder(ir_context) | |
# TODO(sharadmv): handle multiple devices, right now we assume device 0 | |
# which is fine when we have multiple of the same GPU but this won't work in | |
# general. | |
device = 0 | |
builder.arch = triton_kernel_call_lib.get_compute_capability(device) | |
module = builder.create_module() | |
in_avals = [var.aval for var in jaxpr.invars] | |
triton_types = [get_triton_type(x) for x in in_avals] | |
arg_types = [code_gen.str_to_ty(arg) for arg in triton_types] | |
assert len(jaxpr.outvars) == 0 | |
prototype = tl.function_type([], arg_types) | |
out = prototype.to_ir(builder) | |
fn = builder.get_or_insert_function(module, name, out, "public", False) | |
module.push_back(fn) | |
entry = fn.add_entry_block() | |
args = [] | |
for i in range(len(in_avals)): | |
fn.set_arg_attr(i, "tt.divisibility", 16) | |
ptr = tl.tensor(fn.args(i), prototype.param_types[i]) | |
args.append(ptr) | |
builder.set_insertion_point_to_start(entry) | |
new_grid, program_ids = _process_grid_to_3d_grid(builder, grid_mapping) | |
local_program_ids = [ | |
pid for i, pid in enumerate(program_ids) if i not in grid_mapping.mapped_dims | |
] | |
ctx = TritonModuleContext( | |
name, ir_context, builder, module, grid_mapping, local_program_ids | |
) | |
if grid_mapping.num_index_operands: | |
raise NotImplementedError( | |
"Scalar prefetch not supported in Triton lowering.") | |
start_indices = map( | |
partial(_eval_index_map, ctx, program_ids), grid_mapping.block_mappings | |
) | |
block_infos = [ | |
BlockInfo( | |
jax.ShapeDtypeStruct(shape_dtype.shape, shape_dtype.dtype), | |
start_idx, | |
block_mapping.block_shape, | |
) | |
if block_mapping is not None | |
else None | |
for shape_dtype, block_mapping, start_idx in zip( | |
in_shapes, grid_mapping.block_mappings, start_indices | |
) | |
] | |
> () = lower_jaxpr_to_triton_ir(ctx, jaxpr, block_infos, *args) | |
/usr/local/lib/python3.10/dist-packages/jax_triton/pallas/triton_lowering.py:222: | |
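Note on the frame above: lower_jaxpr_to_triton_module turns each BlockMapping of the grid_mapping into a BlockInfo (full shape/dtype, start indices evaluated from the index map, block shape), and the Mapped entries mark dimensions squeezed out by the grid. A hedged sketch of the kind of input spec that would produce the BlockMapping repeated in the locals above (block_shape=(Mapped, 384, Mapped, 64), index map (a, b, c) -> (b, 0, c, 0) over grid (3, 2, 2)); the argument order and names follow the jax_triton Pallas API of this period and are assumptions, not the author's kernel:

    from jax_triton import pallas as pl

    q_spec = pl.BlockSpec(
        lambda seq_blk, batch, head: (batch, 0, head, 0),  # index map over the 3-D grid
        (None, 384, None, 64),                             # None -> mapped (squeezed) dimension
    )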
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
ctx = TritonModuleContext(name='mha_forward', ir_context=<triton._C.libtriton.triton.ir.context object at 0x7fbc21d4bef0>, b...9660>, <triton.language.core.tensor object at 0x7fbb28199990>, <triton.language.core.tensor object at 0x7fbb28199b10>]) | |
jaxpr = { lambda ; a:Ref{bfloat16[384,64]} b:Ref{bfloat16[384,64]} c:Ref{bfloat16[384,64]} | |
d:Ref{int32[384]} e:Ref{int32[3...] bl | |
ea:bool[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] p | |
h[m:m+128,:] <- dz | |
in () } | |
block_infos = [BlockInfo(full_shape_dtype=ShapeDtypeStruct(shape=(2, 384, 2, 64), dtype=bfloat16), start_indices=(<triton.language.c...e.tensor object at 0x7fbb28199db0>), block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384)), ...] | |
args = (<triton.language.core.tensor object at 0x7fbb28150c70>, <triton.language.core.tensor object at 0x7fbb281510f0>, <trit...>, <triton.language.core.tensor object at 0x7fbb28198460>, <triton.language.core.tensor object at 0x7fbb28198b20>, ...) | |
read_env = <function lower_jaxpr_to_triton_ir.<locals>.read_env at 0x7fbc21d9ab00> | |
read_block_info_env = <function lower_jaxpr_to_triton_ir.<locals>.read_block_info_env at 0x7fbc21d9ae60> | |
write_env = <function lower_jaxpr_to_triton_ir.<locals>.write_env at 0x7fbc21d9aef0> | |
def lower_jaxpr_to_triton_ir( | |
ctx: TritonModuleContext, | |
jaxpr: jax_core.Jaxpr, | |
block_infos: Optional[Sequence[Optional[BlockInfo]]], | |
*args | |
) -> Sequence[Any]: | |
env = {} | |
block_info_env = {} | |
def read_env(var: jax_core.Atom): | |
if type(var) is jax_core.Literal: | |
t = tl.core._to_tensor(np.array(var.val).tolist(), builder=ctx.builder) | |
dst_ty = code_gen.str_to_ty(get_triton_type(var.aval)).element_ty | |
if t.type.scalar != dst_ty: | |
# _to_tensor(np.array(var.val).tolist()) can be lossy e.g. np.float64 | |
# comes out of .tolist() as list[float], which then comes out of | |
# _to_tensor as a block of f32. | |
t = tl.semantic.cast(t, dst_ty, ctx.builder) | |
return t | |
return env[var] | |
def read_block_info_env(var: jax_core.Atom): | |
if type(var) is jax_core.Literal: | |
return None | |
return block_info_env.get(var, None) | |
def write_env(var: jax_core.Var, val): | |
env[var] = val | |
if block_infos is None: | |
block_infos = [None] * len(jaxpr.invars) | |
for invar, block_info in zip(jaxpr.invars, block_infos): | |
block_info_env[invar] = block_info | |
map(write_env, jaxpr.invars, args) | |
for eqn in jaxpr.eqns: | |
invals = map(read_env, eqn.invars) | |
if eqn.primitive not in triton_lowering_rules: | |
raise NotImplementedError(eqn.primitive) | |
rule = triton_lowering_rules[eqn.primitive] | |
avals_in = [v.aval for v in eqn.invars] | |
avals_out = [v.aval for v in eqn.outvars] | |
eqn_block_infos = map(read_block_info_env, eqn.invars) | |
rule_ctx = TritonLoweringRuleContext( | |
ctx, avals_in, avals_out, eqn_block_infos | |
) | |
try: | |
outvals = rule(rule_ctx, *invals, **eqn.params) | |
except TritonLoweringException: | |
raise # We only add the extra info to the innermost exception. | |
except Exception as e: | |
> raise TritonLoweringException( | |
f"Exception while lowering eqn:\n {eqn}\n" | |
f"With context:\n {rule_ctx}\n" | |
f"With inval shapes={map(lambda t: t.shape, invals)}\n" | |
f"With inval types={map(lambda t: t.type, invals)}\n" | |
f"In jaxpr:\n{jaxpr}") from e | |
E jax._src.traceback_util.UnfilteredStackTrace: jax_triton.pallas.triton_lowering.TritonLoweringException: Exception while lowering eqn: | |
E a:i32[6] = scan[ | |
E jaxpr={ lambda ; b:Ref{int32[384]} c:i32[]. let | |
E d:i32[] = mul c 64 | |
E e:i32[64] = iota[dimension=0 dtype=int32 shape=(64,)] | |
E f:i32[64] = add e d | |
E g:bool[64] = lt f 384 | |
E h:i32[64] <- b[d:d+64] | |
E i:i32[] = reduce_min[axes=(0,)] h | |
E in (i,) } | |
E length=6 | |
E linear=(False, False) | |
E num_carry=0 | |
E num_consts=1 | |
E reverse=False | |
E unroll=1 | |
E ] j k | |
E With context: | |
E TritonLoweringRuleContext(context=TritonModuleContext(name='mha_forward', ir_context=<triton._C.libtriton.triton.ir.context object at 0x7fbc21d4bef0>, builder=<triton._C.libtriton.triton.ir.builder object at 0x7fbb28171e90>, module=<triton._C.libtriton.triton.ir.module object at 0x7fbb28172110>, grid_mapping=GridMapping(grid=(3, 2, 2), block_mappings=(BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384, <jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 64), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384, <jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 64), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384, <jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 64), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384, <jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 64), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) })), mapped_dims=(), num_index_operands=0), program_ids=[<triton.language.core.tensor object at 0x7fbb28199660>, <triton.language.core.tensor object at 0x7fbb28199990>, <triton.language.core.tensor object at 0x7fbb28199b10>]), avals_in=[Ref{int32[384]}, ShapedArray(int32[6])], avals_out=[ShapedArray(int32[6])], block_infos=[BlockInfo(full_shape_dtype=ShapeDtypeStruct(shape=(2, 384), dtype=int32), start_indices=(<triton.language.core.tensor object at 0x7fbb28199990>, <triton.language.core.tensor object at 0x7fbb28199db0>), block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384)), None]) | |
E With inval shapes=[[constexpr[1]], [constexpr[6]]] | |
E With inval types=[pointer<int32>, <[6], int32>] | |
E In jaxpr: | |
E { lambda ; a:Ref{bfloat16[384,64]} b:Ref{bfloat16[384,64]} c:Ref{bfloat16[384,64]} | |
E d:Ref{int32[384]} e:Ref{int32[384]} f:Ref{int32[384]} g:Ref{int32[384]} h:Ref{bfloat16[384,64]}. let | |
E i:i32[] = program_id[axis=0] | |
E j:f32[128,64] = broadcast_in_dim[broadcast_dimensions=() shape=(128, 64)] 0.0 | |
E k:f32[128] = broadcast_in_dim[broadcast_dimensions=() shape=(128,)] -1.0000000200408773e+20 | |
E l:f32[128] = broadcast_in_dim[broadcast_dimensions=() shape=(128,)] 0.0 | |
E m:i32[] = mul i 128 | |
E n:i32[128] = iota[dimension=0 dtype=int32 shape=(128,)] | |
E o:i32[128] = add n m | |
E p:bool[128] = lt o 384 | |
E q:bf16[128,64] <- a[m:m+128,:] | |
E r:i32[128] <- d[m:m+128] | |
E s:i32[128] <- e[m:m+128] | |
E t:bf16[128,64] = mul q 0.125 | |
E u:i32[6] = iota[dimension=0 dtype=int32 shape=(6,)] | |
E v:i32[6] = scan[ | |
E jaxpr={ lambda ; w:Ref{int32[384]} x:i32[]. let | |
E y:i32[] = mul x 64 | |
E z:i32[64] = iota[dimension=0 dtype=int32 shape=(64,)] | |
E ba:i32[64] = add z y | |
E bb:bool[64] = lt ba 384 | |
E bc:i32[64] <- w[y:y+64] | |
E bd:i32[] = reduce_min[axes=(0,)] bc | |
E in (bd,) } | |
E length=6 | |
E linear=(False, False) | |
E num_carry=0 | |
E num_consts=1 | |
E reverse=False | |
E unroll=1 | |
E ] f u | |
E be:i32[] = reduce_max[axes=(0,)] r | |
E bf:i32[6] = iota[dimension=0 dtype=int32 shape=(6,)] | |
E bg:bool[6] = ge be v | |
E bh:i32[6] = convert_element_type[new_dtype=int32 weak_type=False] bg | |
E bi:i32[6] = mul bf bh | |
E bj:i32[] = argmax[axes=(0,) index_dtype=int32] bi | |
E bk:i32[] = add bj 1 | |
E _:i32[] _:i32[] bl:f32[128,64] _:f32[128] _:f32[128] = while[ | |
E body_jaxpr={ lambda ; bm:Ref{bfloat16[384,64]} bn:Ref{int32[384]} bo:Ref{int32[384]} | |
E bp:bf16[128,64] bq:i32[128] br:i32[128] bs:Ref{bfloat16[384,64]} bt:i32[] | |
E bu:i32[] bv:f32[128,64] bw:f32[128] bx:f32[128]. let | |
E by:i32[] = add bt 1 | |
E bz:i32[] = mul bt 64 | |
E ca:i32[64] = iota[dimension=0 dtype=int32 shape=(64,)] | |
E cb:i32[] = convert_element_type[new_dtype=int32 weak_type=False] bz | |
E cc:i32[64] = add ca cb | |
E cd:bool[64] = lt cc 384 | |
E ce:bf16[64,64] <- bm[bz:bz+64,:] | |
E cf:i32[64] <- bn[bz:bz+64] | |
E cg:i32[64] <- bo[bz:bz+64] | |
E ch:bf16[64,64] = transpose[permutation=(1, 0)] ce | |
E ci:bf16[128,64] = dot_general[ | |
E dimension_numbers=(([1], [0]), ([], [])) | |
E ] bp ch | |
E cj:f32[128,64] = convert_element_type[ | |
E new_dtype=float32 | |
E weak_type=False | |
E ] ci | |
E ck:i32[128,1] = broadcast_in_dim[ | |
E broadcast_dimensions=(0,) | |
E shape=(128, 1) | |
E ] bq | |
E cl:i32[1,64] = broadcast_in_dim[ | |
E broadcast_dimensions=(1,) | |
E shape=(1, 64) | |
E ] cg | |
E cm:bool[128,64] = eq ck cl | |
E cn:i32[128,1] = broadcast_in_dim[ | |
E broadcast_dimensions=(0,) | |
E shape=(128, 1) | |
E ] br | |
E co:i32[1,64] = broadcast_in_dim[ | |
E broadcast_dimensions=(1,) | |
E shape=(1, 64) | |
E ] cf | |
E cp:bool[128,64] = ge cn co | |
E cq:bool[128,64] = and cm cp | |
E cr:f32[128,64] = pjit[ | |
E jaxpr={ lambda ; cs:bool[128,64] ct:f32[128,64] cu:f32[]. let | |
E cv:f32[] = convert_element_type[ | |
E new_dtype=float32 | |
E weak_type=False | |
E ] cu | |
E cw:f32[128,64] = broadcast_in_dim[ | |
E broadcast_dimensions=() | |
E shape=(128, 64) | |
E ] cv | |
E cx:f32[128,64] = select_n cs cw ct | |
E in (cx,) } | |
E name=_where | |
E ] cq cj -1e+20 | |
E cy:f32[128] = reduce_max[axes=(1,)] cr | |
E cz:f32[128] = max cy bw | |
E da:f32[128] = sub bw cz | |
E db:f32[128] = exp da | |
E dc:f32[128] = mul bx db | |
E dd:f32[128,1] = broadcast_in_dim[ | |
E broadcast_dimensions=(0,) | |
E shape=(128, 1) | |
E ] cz | |
E de:f32[128,64] = sub cr dd | |
E df:f32[128,64] = exp de | |
E dg:f32[128] = reduce_sum[axes=(1,)] df | |
E dh:f32[128] = add dg dc | |
E di:f32[128] = div 1.0 dh | |
E dj:f32[128,1] = broadcast_in_dim[ | |
E broadcast_dimensions=(0,) | |
E shape=(128, 1) | |
E ] di | |
E dk:f32[128,64] = mul df dj | |
E dl:f32[128] = mul dc di | |
E dm:f32[128,1] = broadcast_in_dim[ | |
E broadcast_dimensions=(0,) | |
E shape=(128, 1) | |
E ] dl | |
E dn:f32[128,64] = mul bv dm | |
E do:bf16[64,64] <- bs[bz:bz+64,:] | |
E dp:bf16[128,64] = convert_element_type[ | |
E new_dtype=bfloat16 | |
E weak_type=False | |
E ] dk | |
E dq:bf16[128,64] = dot_general[ | |
E dimension_numbers=(([1], [0]), ([], [])) | |
E ] dp do | |
E dr:f32[128,64] = convert_element_type[ | |
E new_dtype=float32 | |
E weak_type=False | |
E ] dq | |
E ds:f32[128,64] = add dn dr | |
E in (by, bu, ds, cz, dh) } | |
E body_nconsts=7 | |
E cond_jaxpr={ lambda ; dt:i32[] du:i32[] dv:f32[128,64] dw:f32[128] dx:f32[128]. let | |
E dy:bool[] = lt dt du | |
E in (dy,) } | |
E cond_nconsts=0 | |
E ] b f g t s r c 0 bk j k l | |
E dz:bf16[128,64] = convert_element_type[new_dtype=bfloat16 weak_type=False] bl | |
E ea:bool[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] p | |
E h[m:m+128,:] <- dz | |
E in () } | |
E | |
E The stack trace below excludes JAX-internal frames. | |
E The preceding is the original exception that occurred, unmodified. | |
E | |
E -------------------- | |
/usr/local/lib/python3.10/dist-packages/jax_triton/pallas/triton_lowering.py:278: UnfilteredStackTrace | |
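Aside: the block above is the unfiltered trace (JAX-internal frames included); the filtered version follows. How aggressively JAX hides internal frames is configurable; a minimal sketch, assuming the jax_traceback_filtering option available in JAX 0.4.x:

    import jax
    jax.config.update("jax_traceback_filtering", "off")  # keep all frames; "auto" restores the default filtering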
The above exception was the direct cause of the following exception: | |
query_seq_len = 384, kv_seq_len = 384 | |
@pytest.mark.gpu | |
@pytest.mark.parametrize( | |
"query_seq_len,kv_seq_len", | |
[ | |
(384, 384), | |
(400, 400), | |
(384, 256), | |
(384, 300), | |
(400, 256), | |
(400, 300), | |
(1, 256), | |
(1, 300), | |
(256, 1), | |
(300, 1), | |
(1, 3), | |
], | |
) | |
def test_triton_flash_attention(query_seq_len, kv_seq_len): | |
q, k, v, q_pos, q_sid, kv_pos, kv_sid = _make_attention_inputs( | |
query_seq_len=query_seq_len, | |
kv_seq_len=kv_seq_len, | |
) | |
> result_triton = triton_multi_head_attention( | |
q, | |
k, | |
v, | |
q_pos, | |
q_sid, | |
kv_pos, | |
kv_sid, | |
) | |
tests/test_attention.py:92: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
/usr/local/lib/python3.10/dist-packages/jax_triton/pallas/triton_lowering.py:1661: in pallas_call_lowering | |
compilation_result = compile_jaxpr( | |
/usr/local/lib/python3.10/dist-packages/jax_triton/pallas/triton_lowering.py:1610: in compile_jaxpr | |
lowering_result = lower_jaxpr_to_triton_module( | |
/usr/local/lib/python3.10/dist-packages/jax_triton/pallas/triton_lowering.py:222: in lower_jaxpr_to_triton_module | |
() = lower_jaxpr_to_triton_ir(ctx, jaxpr, block_infos, *args) | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
ctx = TritonModuleContext(name='mha_forward', ir_context=<triton._C.libtriton.triton.ir.context object at 0x7fbc21d4bef0>, b...9660>, <triton.language.core.tensor object at 0x7fbb28199990>, <triton.language.core.tensor object at 0x7fbb28199b10>]) | |
jaxpr = { lambda ; a:Ref{bfloat16[384,64]} b:Ref{bfloat16[384,64]} c:Ref{bfloat16[384,64]} | |
d:Ref{int32[384]} e:Ref{int32[3...] bl | |
ea:bool[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] p | |
h[m:m+128,:] <- dz | |
in () } | |
block_infos = [BlockInfo(full_shape_dtype=ShapeDtypeStruct(shape=(2, 384, 2, 64), dtype=bfloat16), start_indices=(<triton.language.c...e.tensor object at 0x7fbb28199db0>), block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384)), ...] | |
args = (<triton.language.core.tensor object at 0x7fbb28150c70>, <triton.language.core.tensor object at 0x7fbb281510f0>, <trit...>, <triton.language.core.tensor object at 0x7fbb28198460>, <triton.language.core.tensor object at 0x7fbb28198b20>, ...) | |
read_env = <function lower_jaxpr_to_triton_ir.<locals>.read_env at 0x7fbc21d9ab00> | |
read_block_info_env = <function lower_jaxpr_to_triton_ir.<locals>.read_block_info_env at 0x7fbc21d9ae60> | |
write_env = <function lower_jaxpr_to_triton_ir.<locals>.write_env at 0x7fbc21d9aef0> | |
def lower_jaxpr_to_triton_ir( | |
ctx: TritonModuleContext, | |
jaxpr: jax_core.Jaxpr, | |
block_infos: Optional[Sequence[Optional[BlockInfo]]], | |
*args | |
) -> Sequence[Any]: | |
env = {} | |
block_info_env = {} | |
def read_env(var: jax_core.Atom): | |
if type(var) is jax_core.Literal: | |
t = tl.core._to_tensor(np.array(var.val).tolist(), builder=ctx.builder) | |
dst_ty = code_gen.str_to_ty(get_triton_type(var.aval)).element_ty | |
if t.type.scalar != dst_ty: | |
# _to_tensor(np.array(var.val).tolist()) can be lossy e.g. np.float64 | |
# comes out of .tolist() as list[float], which then comes out of | |
# _to_tensor as a block of f32. | |
t = tl.semantic.cast(t, dst_ty, ctx.builder) | |
return t | |
return env[var] | |
def read_block_info_env(var: jax_core.Atom): | |
if type(var) is jax_core.Literal: | |
return None | |
return block_info_env.get(var, None) | |
def write_env(var: jax_core.Var, val): | |
env[var] = val | |
if block_infos is None: | |
block_infos = [None] * len(jaxpr.invars) | |
for invar, block_info in zip(jaxpr.invars, block_infos): | |
block_info_env[invar] = block_info | |
map(write_env, jaxpr.invars, args) | |
for eqn in jaxpr.eqns: | |
invals = map(read_env, eqn.invars) | |
if eqn.primitive not in triton_lowering_rules: | |
raise NotImplementedError(eqn.primitive) | |
rule = triton_lowering_rules[eqn.primitive] | |
avals_in = [v.aval for v in eqn.invars] | |
avals_out = [v.aval for v in eqn.outvars] | |
eqn_block_infos = map(read_block_info_env, eqn.invars) | |
rule_ctx = TritonLoweringRuleContext( | |
ctx, avals_in, avals_out, eqn_block_infos | |
) | |
try: | |
outvals = rule(rule_ctx, *invals, **eqn.params) | |
except TritonLoweringException: | |
raise # We only add the extra info to the innermost exception. | |
except Exception as e: | |
> raise TritonLoweringException( | |
f"Exception while lowering eqn:\n {eqn}\n" | |
f"With context:\n {rule_ctx}\n" | |
f"With inval shapes={map(lambda t: t.shape, invals)}\n" | |
f"With inval types={map(lambda t: t.type, invals)}\n" | |
f"In jaxpr:\n{jaxpr}") from e | |
E jax_triton.pallas.triton_lowering.TritonLoweringException: Exception while lowering eqn: | |
E a:i32[6] = scan[ | |
E jaxpr={ lambda ; b:Ref{int32[384]} c:i32[]. let | |
E d:i32[] = mul c 64 | |
E e:i32[64] = iota[dimension=0 dtype=int32 shape=(64,)] | |
E f:i32[64] = add e d | |
E g:bool[64] = lt f 384 | |
E h:i32[64] <- b[d:d+64] | |
E i:i32[] = reduce_min[axes=(0,)] h | |
E in (i,) } | |
E length=6 | |
E linear=(False, False) | |
E num_carry=0 | |
E num_consts=1 | |
E reverse=False | |
E unroll=1 | |
E ] j k | |
E With context: | |
E TritonLoweringRuleContext(context=TritonModuleContext(name='mha_forward', ir_context=<triton._C.libtriton.triton.ir.context object at 0x7fbc21d4bef0>, builder=<triton._C.libtriton.triton.ir.builder object at 0x7fbb28171e90>, module=<triton._C.libtriton.triton.ir.module object at 0x7fbb28172110>, grid_mapping=GridMapping(grid=(3, 2, 2), block_mappings=(BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384, <jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 64), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384, <jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 64), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384, <jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 64), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384, <jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 64), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) })), mapped_dims=(), num_index_operands=0), program_ids=[<triton.language.core.tensor object at 0x7fbb28199660>, <triton.language.core.tensor object at 0x7fbb28199990>, <triton.language.core.tensor object at 0x7fbb28199b10>]), avals_in=[Ref{int32[384]}, ShapedArray(int32[6])], avals_out=[ShapedArray(int32[6])], block_infos=[BlockInfo(full_shape_dtype=ShapeDtypeStruct(shape=(2, 384), dtype=int32), start_indices=(<triton.language.core.tensor object at 0x7fbb28199990>, <triton.language.core.tensor object at 0x7fbb28199db0>), block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384)), None]) | |
E With inval shapes=[[constexpr[1]], [constexpr[6]]] | |
E With inval types=[pointer<int32>, <[6], int32>] | |
E In jaxpr: | |
E { lambda ; a:Ref{bfloat16[384,64]} b:Ref{bfloat16[384,64]} c:Ref{bfloat16[384,64]} | |
E d:Ref{int32[384]} e:Ref{int32[384]} f:Ref{int32[384]} g:Ref{int32[384]} h:Ref{bfloat16[384,64]}. let | |
E i:i32[] = program_id[axis=0] | |
E j:f32[128,64] = broadcast_in_dim[broadcast_dimensions=() shape=(128, 64)] 0.0 | |
E k:f32[128] = broadcast_in_dim[broadcast_dimensions=() shape=(128,)] -1.0000000200408773e+20 | |
E l:f32[128] = broadcast_in_dim[broadcast_dimensions=() shape=(128,)] 0.0 | |
E m:i32[] = mul i 128 | |
E n:i32[128] = iota[dimension=0 dtype=int32 shape=(128,)] | |
E o:i32[128] = add n m | |
E p:bool[128] = lt o 384 | |
E q:bf16[128,64] <- a[m:m+128,:] | |
E r:i32[128] <- d[m:m+128] | |
E s:i32[128] <- e[m:m+128] | |
E t:bf16[128,64] = mul q 0.125 | |
E u:i32[6] = iota[dimension=0 dtype=int32 shape=(6,)] | |
E v:i32[6] = scan[ | |
E jaxpr={ lambda ; w:Ref{int32[384]} x:i32[]. let | |
E y:i32[] = mul x 64 | |
E z:i32[64] = iota[dimension=0 dtype=int32 shape=(64,)] | |
E ba:i32[64] = add z y | |
E bb:bool[64] = lt ba 384 | |
E bc:i32[64] <- w[y:y+64] | |
E bd:i32[] = reduce_min[axes=(0,)] bc | |
E in (bd,) } | |
E length=6 | |
E linear=(False, False) | |
E num_carry=0 | |
E num_consts=1 | |
E reverse=False | |
E unroll=1 | |
E ] f u | |
E be:i32[] = reduce_max[axes=(0,)] r | |
E bf:i32[6] = iota[dimension=0 dtype=int32 shape=(6,)] | |
E bg:bool[6] = ge be v | |
E bh:i32[6] = convert_element_type[new_dtype=int32 weak_type=False] bg | |
E bi:i32[6] = mul bf bh | |
E bj:i32[] = argmax[axes=(0,) index_dtype=int32] bi | |
E bk:i32[] = add bj 1 | |
E _:i32[] _:i32[] bl:f32[128,64] _:f32[128] _:f32[128] = while[ | |
E body_jaxpr={ lambda ; bm:Ref{bfloat16[384,64]} bn:Ref{int32[384]} bo:Ref{int32[384]} | |
E bp:bf16[128,64] bq:i32[128] br:i32[128] bs:Ref{bfloat16[384,64]} bt:i32[] | |
E bu:i32[] bv:f32[128,64] bw:f32[128] bx:f32[128]. let | |
E by:i32[] = add bt 1 | |
E bz:i32[] = mul bt 64 | |
E ca:i32[64] = iota[dimension=0 dtype=int32 shape=(64,)] | |
E cb:i32[] = convert_element_type[new_dtype=int32 weak_type=False] bz | |
E cc:i32[64] = add ca cb | |
E cd:bool[64] = lt cc 384 | |
E ce:bf16[64,64] <- bm[bz:bz+64,:] | |
E cf:i32[64] <- bn[bz:bz+64] | |
E cg:i32[64] <- bo[bz:bz+64] | |
E ch:bf16[64,64] = transpose[permutation=(1, 0)] ce | |
E ci:bf16[128,64] = dot_general[ | |
E dimension_numbers=(([1], [0]), ([], [])) | |
E ] bp ch | |
E cj:f32[128,64] = convert_element_type[ | |
E new_dtype=float32 | |
E weak_type=False | |
E ] ci | |
E ck:i32[128,1] = broadcast_in_dim[ | |
E broadcast_dimensions=(0,) | |
E shape=(128, 1) | |
E ] bq | |
E cl:i32[1,64] = broadcast_in_dim[ | |
E broadcast_dimensions=(1,) | |
E shape=(1, 64) | |
E ] cg | |
E cm:bool[128,64] = eq ck cl | |
E cn:i32[128,1] = broadcast_in_dim[ | |
E broadcast_dimensions=(0,) | |
E shape=(128, 1) | |
E ] br | |
E co:i32[1,64] = broadcast_in_dim[ | |
E broadcast_dimensions=(1,) | |
E shape=(1, 64) | |
E ] cf | |
E cp:bool[128,64] = ge cn co | |
E cq:bool[128,64] = and cm cp | |
E cr:f32[128,64] = pjit[ | |
E jaxpr={ lambda ; cs:bool[128,64] ct:f32[128,64] cu:f32[]. let | |
E cv:f32[] = convert_element_type[ | |
E new_dtype=float32 | |
E weak_type=False | |
E ] cu | |
E cw:f32[128,64] = broadcast_in_dim[ | |
E broadcast_dimensions=() | |
E shape=(128, 64) | |
E ] cv | |
E cx:f32[128,64] = select_n cs cw ct | |
E in (cx,) } | |
E name=_where | |
E ] cq cj -1e+20 | |
E cy:f32[128] = reduce_max[axes=(1,)] cr | |
E cz:f32[128] = max cy bw | |
E da:f32[128] = sub bw cz | |
E db:f32[128] = exp da | |
E dc:f32[128] = mul bx db | |
E dd:f32[128,1] = broadcast_in_dim[ | |
E broadcast_dimensions=(0,) | |
E shape=(128, 1) | |
E ] cz | |
E de:f32[128,64] = sub cr dd | |
E df:f32[128,64] = exp de | |
E dg:f32[128] = reduce_sum[axes=(1,)] df | |
E dh:f32[128] = add dg dc | |
E di:f32[128] = div 1.0 dh | |
E dj:f32[128,1] = broadcast_in_dim[ | |
E broadcast_dimensions=(0,) | |
E shape=(128, 1) | |
E ] di | |
E dk:f32[128,64] = mul df dj | |
E dl:f32[128] = mul dc di | |
E dm:f32[128,1] = broadcast_in_dim[ | |
E broadcast_dimensions=(0,) | |
E shape=(128, 1) | |
E ] dl | |
E dn:f32[128,64] = mul bv dm | |
E do:bf16[64,64] <- bs[bz:bz+64,:] | |
E dp:bf16[128,64] = convert_element_type[ | |
E new_dtype=bfloat16 | |
E weak_type=False | |
E ] dk | |
E dq:bf16[128,64] = dot_general[ | |
E dimension_numbers=(([1], [0]), ([], [])) | |
E ] dp do | |
E dr:f32[128,64] = convert_element_type[ | |
E new_dtype=float32 | |
E weak_type=False | |
E ] dq | |
E ds:f32[128,64] = add dn dr | |
E in (by, bu, ds, cz, dh) } | |
E body_nconsts=7 | |
E cond_jaxpr={ lambda ; dt:i32[] du:i32[] dv:f32[128,64] dw:f32[128] dx:f32[128]. let | |
E dy:bool[] = lt dt du | |
E in (dy,) } | |
E cond_nconsts=0 | |
E ] b f g t s r c 0 bk j k l | |
E dz:bf16[128,64] = convert_element_type[new_dtype=bfloat16 weak_type=False] bl | |
E ea:bool[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] p | |
E h[m:m+128,:] <- dz | |
E in () } | |
/usr/local/lib/python3.10/dist-packages/jax_triton/pallas/triton_lowering.py:278: TritonLoweringException | |
=========================== short test summary info ============================ | |
FAILED tests/test_attention.py::test_triton_flash_attention[384-384] - jax_tr... | |
============================== 1 failed in 13.12s ============================== |
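Root-cause note: the lowering fails on the `v:i32[6] = scan[...] f u` equation in the kernel jaxpr, i.e. a `scan` (typically produced by jax.lax.scan or jax.lax.map) over 64-element blocks of the kv-position Ref whose body loads `w[y:y+64]` and takes a reduce_min. The scan's rule raises inside the Pallas Triton lowering (wrapped as the TritonLoweringException above, with inval types pointer<int32> and i32[6]), suggesting that scans which carry a Ref operand are not handled by this jax_triton build. Below is a plain-JAX reconstruction of what that scan computes, as a readability aid only (the shapes 384 and 64 come from the jaxpr, the bounds mask the body computes is ignored, and this is not the author's kernel code):

    import jax.numpy as jnp

    def per_block_min(kv_pos, block_size=64):
        # kv_pos: int32[384] -> int32[6] of per-64-block minima, matching
        # `v:i32[6] = scan[...] f u` in the jaxpr printed above.
        num_blocks = kv_pos.shape[0] // block_size
        return jnp.min(kv_pos.reshape(num_blocks, block_size), axis=1)

Common workarounds when a structured-control-flow primitive fails to lower inside a Pallas kernel are to unroll a short loop like this one in Python at trace time (so only already-supported primitives are emitted) or to rewrite it as jax.lax.fori_loop with an explicit carry; whether either lowers cleanly under this jax_triton version is untested here.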