@hr0nix
Created July 18, 2023 13:53
============================= test session starts ==============================
platform linux -- Python 3.10.12, pytest-7.3.2, pluggy-1.0.0
rootdir: /papyrax/tests
configfile: pytest.ini
collected 1 item
tests/test_attention.py F [100%]
=================================== FAILURES ===================================
_____________________ test_triton_flash_attention[384-384] _____________________
ctx = TritonModuleContext(name='mha_forward', ir_context=<triton._C.libtriton.triton.ir.context object at 0x7fbc21d4bef0>, b...9660>, <triton.language.core.tensor object at 0x7fbb28199990>, <triton.language.core.tensor object at 0x7fbb28199b10>])
jaxpr = { lambda ; a:Ref{bfloat16[384,64]} b:Ref{bfloat16[384,64]} c:Ref{bfloat16[384,64]}
d:Ref{int32[384]} e:Ref{int32[3...] bl
ea:bool[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] p
h[m:m+128,:] <- dz
in () }
block_infos = [BlockInfo(full_shape_dtype=ShapeDtypeStruct(shape=(2, 384, 2, 64), dtype=bfloat16), start_indices=(<triton.language.c...e.tensor object at 0x7fbb28199db0>), block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384)), ...]
args = (<triton.language.core.tensor object at 0x7fbb28150c70>, <triton.language.core.tensor object at 0x7fbb281510f0>, <trit...>, <triton.language.core.tensor object at 0x7fbb28198460>, <triton.language.core.tensor object at 0x7fbb28198b20>, ...)
read_env = <function lower_jaxpr_to_triton_ir.<locals>.read_env at 0x7fbc21d9ab00>
read_block_info_env = <function lower_jaxpr_to_triton_ir.<locals>.read_block_info_env at 0x7fbc21d9ae60>
write_env = <function lower_jaxpr_to_triton_ir.<locals>.write_env at 0x7fbc21d9aef0>
def lower_jaxpr_to_triton_ir(
    ctx: TritonModuleContext,
    jaxpr: jax_core.Jaxpr,
    block_infos: Optional[Sequence[Optional[BlockInfo]]],
    *args
) -> Sequence[Any]:
  env = {}
  block_info_env = {}
  def read_env(var: jax_core.Atom):
    if type(var) is jax_core.Literal:
      t = tl.core._to_tensor(np.array(var.val).tolist(), builder=ctx.builder)
      dst_ty = code_gen.str_to_ty(get_triton_type(var.aval)).element_ty
      if t.type.scalar != dst_ty:
        # _to_tensor(np.array(var.val).tolist()) can be lossy e.g. np.float64
        # comes out of .tolist() as list[float], which then comes out of
        # _to_tensor as a block of f32.
        t = tl.semantic.cast(t, dst_ty, ctx.builder)
      return t
    return env[var]
  def read_block_info_env(var: jax_core.Atom):
    if type(var) is jax_core.Literal:
      return None
    return block_info_env.get(var, None)
  def write_env(var: jax_core.Var, val):
    env[var] = val
  if block_infos is None:
    block_infos = [None] * len(jaxpr.invars)
  for invar, block_info in zip(jaxpr.invars, block_infos):
    block_info_env[invar] = block_info
  map(write_env, jaxpr.invars, args)
  for eqn in jaxpr.eqns:
    invals = map(read_env, eqn.invars)
    if eqn.primitive not in triton_lowering_rules:
      raise NotImplementedError(eqn.primitive)
    rule = triton_lowering_rules[eqn.primitive]
    avals_in = [v.aval for v in eqn.invars]
    avals_out = [v.aval for v in eqn.outvars]
    eqn_block_infos = map(read_block_info_env, eqn.invars)
    rule_ctx = TritonLoweringRuleContext(
        ctx, avals_in, avals_out, eqn_block_infos
    )
    try:
>     outvals = rule(rule_ctx, *invals, **eqn.params)
/usr/local/lib/python3.10/dist-packages/jax_triton/pallas/triton_lowering.py:274:
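
Note (not part of the pytest output): the frame above is the generic Pallas-to-Triton lowering loop. It walks the kernel jaxpr equation by equation and dispatches each primitive to a handler registered in triton_lowering_rules; a primitive with no handler, or whose handler rejects its inputs (as happens below), surfaces as a lowering failure. A minimal, hypothetical sketch of that dispatch pattern, for orientation only (not the jax_triton source; real handlers emit Triton IR through ctx.builder):

    # Hypothetical registry-and-dispatch sketch mirroring the loop above.
    triton_lowering_rules = {}

    def register_lowering(primitive):
        # Decorator that records a lowering handler for one JAX primitive.
        def wrap(rule):
            triton_lowering_rules[primitive] = rule
            return rule
        return wrap

    def lower_eqn(ctx, eqn, *invals):
        # Same failure mode as the real loop: unknown primitives are rejected.
        if eqn.primitive not in triton_lowering_rules:
            raise NotImplementedError(eqn.primitive)
        return triton_lowering_rules[eqn.primitive](ctx, *invals, **eqn.params)
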
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
ctx = TritonLoweringRuleContext(context=TritonModuleContext(name='mha_forward', ir_context=<triton._C.libtriton.triton.ir.co...tensor object at 0x7fbb28199db0>), block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384)), None])
jaxpr = { lambda ; a:Ref{int32[384]} b:i32[]. let
c:i32[] = mul b 64
d:i32[64] = iota[dimension=0 dtype=int32 shape=(6...32[64] = add d c
f:bool[64] = lt e 384
g:i32[64] <- a[c:c+64]
h:i32[] = reduce_min[axes=(0,)] g
in (h,) }
linear = (False, False), length = 6, reverse = False
def _scan_lowering_rule(
    ctx: TritonLoweringRuleContext,
    *args,
    jaxpr,
    linear,
    length,
    reverse,
    unroll,
    num_consts,
    num_carry,
):
  # Only implements fori_loop-like scans
  num_extensive = len(args) - num_consts - num_carry
> if num_extensive: raise NotImplementedError
E NotImplementedError
/usr/local/lib/python3.10/dist-packages/jax_triton/pallas/triton_lowering.py:1314: NotImplementedError
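
Note (not part of the pytest output): _scan_lowering_rule only handles "fori_loop-like" scans, i.e. scans whose inputs are all constants or carried state. Any per-iteration ("extensive") input makes num_extensive nonzero and raises the NotImplementedError seen here. A minimal, hedged illustration of the distinction in plain JAX (standalone, not Pallas-specific):

    import jax
    import jax.numpy as jnp

    xs = jnp.arange(6)

    # A scan that maps over xs: xs is an extensive input, so a scan of this
    # shape is exactly what the rule above refuses (num_extensive == 1).
    _, ys = jax.lax.scan(lambda carry, x: (carry + x, 2 * x), jnp.int32(0), xs)

    # A fori_loop keeps all state in the carry and has no extensive inputs;
    # this is the loop shape the rule does lower.
    total = jax.lax.fori_loop(0, 6, lambda i, acc: acc + i, jnp.int32(0))
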
The above exception was the direct cause of the following exception:
> return _run_code(code, main_globals, None,
"__main__", mod_spec)
/usr/lib/python3.10/runpy.py:196:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> exec(code, run_globals)
/usr/lib/python3.10/runpy.py:86:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> raise SystemExit(pytest.console_main())
/usr/local/lib/python3.10/dist-packages/pytest/__main__.py:5:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> code = main()
/usr/local/lib/python3.10/dist-packages/_pytest/config/__init__.py:189:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> ret: Union[ExitCode, int] = config.hook.pytest_cmdline_main(
config=config
)
/usr/local/lib/python3.10/dist-packages/_pytest/config/__init__.py:166:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/usr/local/lib/python3.10/dist-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/usr/local/lib/python3.10/dist-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> res = hook_impl.function(*args)
/usr/local/lib/python3.10/dist-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> return wrap_session(config, _main)
/usr/local/lib/python3.10/dist-packages/_pytest/main.py:316:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> session.exitstatus = doit(config, session) or 0
/usr/local/lib/python3.10/dist-packages/_pytest/main.py:269:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> config.hook.pytest_runtestloop(session=session)
/usr/local/lib/python3.10/dist-packages/_pytest/main.py:323:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/usr/local/lib/python3.10/dist-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/usr/local/lib/python3.10/dist-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> res = hook_impl.function(*args)
/usr/local/lib/python3.10/dist-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> item.config.hook.pytest_runtest_protocol(item=item, nextitem=nextitem)
/usr/local/lib/python3.10/dist-packages/_pytest/main.py:348:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/usr/local/lib/python3.10/dist-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/usr/local/lib/python3.10/dist-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> res = hook_impl.function(*args)
/usr/local/lib/python3.10/dist-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> runtestprotocol(item, nextitem=nextitem)
/usr/local/lib/python3.10/dist-packages/_pytest/runner.py:114:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> reports.append(call_and_report(item, "call", log))
/usr/local/lib/python3.10/dist-packages/_pytest/runner.py:133:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> call = call_runtest_hook(item, when, **kwds)
/usr/local/lib/python3.10/dist-packages/_pytest/runner.py:222:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> return CallInfo.from_call(
lambda: ihook(item=item, **kwds), when=when, reraise=reraise
)
/usr/local/lib/python3.10/dist-packages/_pytest/runner.py:261:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> result: Optional[TResult] = func()
/usr/local/lib/python3.10/dist-packages/_pytest/runner.py:341:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> lambda: ihook(item=item, **kwds), when=when, reraise=reraise
)
/usr/local/lib/python3.10/dist-packages/_pytest/runner.py:262:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/usr/local/lib/python3.10/dist-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/usr/local/lib/python3.10/dist-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> res = hook_impl.function(*args)
/usr/local/lib/python3.10/dist-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> item.runtest()
/usr/local/lib/python3.10/dist-packages/_pytest/runner.py:169:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> self.ihook.pytest_pyfunc_call(pyfuncitem=self)
/usr/local/lib/python3.10/dist-packages/_pytest/python.py:1799:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/usr/local/lib/python3.10/dist-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/usr/local/lib/python3.10/dist-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> res = hook_impl.function(*args)
/usr/local/lib/python3.10/dist-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> result = testfunction(**testargs)
/usr/local/lib/python3.10/dist-packages/_pytest/python.py:194:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
>   result_triton = triton_multi_head_attention(
        q,
        k,
        v,
        q_pos,
        q_sid,
        kv_pos,
        kv_sid,
    )
tests/test_attention.py:92:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
>   result = pl.pallas_call(
        kernel,
        grid=grid,
        in_specs=[
            pl.BlockSpec(
                lambda c, b, h: (b, 0, h, 0), (None, query_seq_len, None, head_dim)
            ),
            pl.BlockSpec(
                lambda c, b, h: (b, 0, h, 0), (None, kv_seq_len, None, head_dim)
            ),
            pl.BlockSpec(
                lambda c, b, h: (b, 0, h, 0), (None, kv_seq_len, None, head_dim)
            ),
            pl.BlockSpec(lambda c, b, h: (b, 0), (None, query_seq_len)),
            pl.BlockSpec(lambda c, b, h: (b, 0), (None, query_seq_len)),
            pl.BlockSpec(lambda c, b, h: (b, 0), (None, kv_seq_len)),
            pl.BlockSpec(lambda c, b, h: (b, 0), (None, kv_seq_len)),
        ],
        out_specs=pl.BlockSpec(
            lambda c, b, h: (b, 0, h, 0), (None, query_seq_len, None, head_dim)
        ),
        num_warps=num_warps,
        num_stages=num_stages,
        out_shape=out_shape,
        debug=debug,
        interpret=interpret,
        name="mha_forward",
    )(
        query,
        key,
        value,
        query_positions,
        query_segment_ids,
        kv_positions,
        kv_segment_ids,
    )
papyrax/nn/triton_ops.py:218:
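
Note (not part of the pytest output): each pl.BlockSpec above pairs an index_map over the grid axes (c, b, h) with a block_shape; None entries mark dimensions that are indexed away by the grid (batch and head here), so the kernel body sees [seq, head_dim] and [seq] blocks. A minimal, hedged sketch of the same pattern on a toy kernel; the names, shapes, and kernel body are made up and are not the gist's attention kernel:

    import jax
    from jax_triton import pallas as pl

    def double_kernel(x_ref, o_ref):
        # The None batch axis in the BlockSpecs below is squeezed away, so the
        # refs seen here are one-dimensional [seq] blocks.
        o_ref[:] = x_ref[:] * 2.0

    def double_rows(x):  # x: f32[batch, seq]
        batch, seq = x.shape
        return pl.pallas_call(
            double_kernel,
            grid=(batch,),  # one program instance per batch row
            in_specs=[pl.BlockSpec(lambda b: (b, 0), (None, seq))],
            out_specs=pl.BlockSpec(lambda b: (b, 0), (None, seq)),
            out_shape=jax.ShapeDtypeStruct((batch, seq), x.dtype),
        )(x)

Whether this exact toy compiles on a given jax-triton version is not something the log can confirm; it only mirrors the grid/BlockSpec structure used above.
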
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
>   out_flat = pallas_call_p.bind(
        *consts, *flat_args, jaxpr=jaxpr, name=name, which_linear=which_linear,
        in_shapes=tuple(jax.ShapeDtypeStruct(a.shape, a.dtype)
                        for a in flat_args),
        out_shapes=tuple(flat_out_shapes), debug=debug,
        interpret=interpret,
        grid_mapping=grid_mapping,
        input_output_aliases=tuple(input_output_aliases.items()),
        **compiler_params)
E jax._src.source_info_util.JaxStackTraceBeforeTransformation: jax_triton.pallas.triton_lowering.TritonLoweringException: Exception while lowering eqn:
E a:i32[6] = scan[
E jaxpr={ lambda ; b:Ref{int32[384]} c:i32[]. let
E d:i32[] = mul c 64
E e:i32[64] = iota[dimension=0 dtype=int32 shape=(64,)]
E f:i32[64] = add e d
E g:bool[64] = lt f 384
E h:i32[64] <- b[d:d+64]
E i:i32[] = reduce_min[axes=(0,)] h
E in (i,) }
E length=6
E linear=(False, False)
E num_carry=0
E num_consts=1
E reverse=False
E unroll=1
E ] j k
E With context:
E TritonLoweringRuleContext(context=TritonModuleContext(name='mha_forward', ir_context=<triton._C.libtriton.triton.ir.context object at 0x7fbc21d4bef0>, builder=<triton._C.libtriton.triton.ir.builder object at 0x7fbb28171e90>, module=<triton._C.libtriton.triton.ir.module object at 0x7fbb28172110>, grid_mapping=GridMapping(grid=(3, 2, 2), block_mappings=(BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384, <jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 64), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384, <jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 64), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384, <jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 64), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384, <jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 64), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) })), mapped_dims=(), num_index_operands=0), program_ids=[<triton.language.core.tensor object at 0x7fbb28199660>, <triton.language.core.tensor object at 0x7fbb28199990>, <triton.language.core.tensor object at 0x7fbb28199b10>]), avals_in=[Ref{int32[384]}, ShapedArray(int32[6])], avals_out=[ShapedArray(int32[6])], block_infos=[BlockInfo(full_shape_dtype=ShapeDtypeStruct(shape=(2, 384), dtype=int32), start_indices=(<triton.language.core.tensor object at 0x7fbb28199990>, <triton.language.core.tensor object at 0x7fbb28199db0>), block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384)), None])
E With inval shapes=[[constexpr[1]], [constexpr[6]]]
E With inval types=[pointer<int32>, <[6], int32>]
E In jaxpr:
E { lambda ; a:Ref{bfloat16[384,64]} b:Ref{bfloat16[384,64]} c:Ref{bfloat16[384,64]}
E d:Ref{int32[384]} e:Ref{int32[384]} f:Ref{int32[384]} g:Ref{int32[384]} h:Ref{bfloat16[384,64]}. let
E i:i32[] = program_id[axis=0]
E j:f32[128,64] = broadcast_in_dim[broadcast_dimensions=() shape=(128, 64)] 0.0
E k:f32[128] = broadcast_in_dim[broadcast_dimensions=() shape=(128,)] -1.0000000200408773e+20
E l:f32[128] = broadcast_in_dim[broadcast_dimensions=() shape=(128,)] 0.0
E m:i32[] = mul i 128
E n:i32[128] = iota[dimension=0 dtype=int32 shape=(128,)]
E o:i32[128] = add n m
E p:bool[128] = lt o 384
E q:bf16[128,64] <- a[m:m+128,:]
E r:i32[128] <- d[m:m+128]
E s:i32[128] <- e[m:m+128]
E t:bf16[128,64] = mul q 0.125
E u:i32[6] = iota[dimension=0 dtype=int32 shape=(6,)]
E v:i32[6] = scan[
E jaxpr={ lambda ; w:Ref{int32[384]} x:i32[]. let
E y:i32[] = mul x 64
E z:i32[64] = iota[dimension=0 dtype=int32 shape=(64,)]
E ba:i32[64] = add z y
E bb:bool[64] = lt ba 384
E bc:i32[64] <- w[y:y+64]
E bd:i32[] = reduce_min[axes=(0,)] bc
E in (bd,) }
E length=6
E linear=(False, False)
E num_carry=0
E num_consts=1
E reverse=False
E unroll=1
E ] f u
E be:i32[] = reduce_max[axes=(0,)] r
E bf:i32[6] = iota[dimension=0 dtype=int32 shape=(6,)]
E bg:bool[6] = ge be v
E bh:i32[6] = convert_element_type[new_dtype=int32 weak_type=False] bg
E bi:i32[6] = mul bf bh
E bj:i32[] = argmax[axes=(0,) index_dtype=int32] bi
E bk:i32[] = add bj 1
E _:i32[] _:i32[] bl:f32[128,64] _:f32[128] _:f32[128] = while[
E body_jaxpr={ lambda ; bm:Ref{bfloat16[384,64]} bn:Ref{int32[384]} bo:Ref{int32[384]}
E bp:bf16[128,64] bq:i32[128] br:i32[128] bs:Ref{bfloat16[384,64]} bt:i32[]
E bu:i32[] bv:f32[128,64] bw:f32[128] bx:f32[128]. let
E by:i32[] = add bt 1
E bz:i32[] = mul bt 64
E ca:i32[64] = iota[dimension=0 dtype=int32 shape=(64,)]
E cb:i32[] = convert_element_type[new_dtype=int32 weak_type=False] bz
E cc:i32[64] = add ca cb
E cd:bool[64] = lt cc 384
E ce:bf16[64,64] <- bm[bz:bz+64,:]
E cf:i32[64] <- bn[bz:bz+64]
E cg:i32[64] <- bo[bz:bz+64]
E ch:bf16[64,64] = transpose[permutation=(1, 0)] ce
E ci:bf16[128,64] = dot_general[
E dimension_numbers=(([1], [0]), ([], []))
E ] bp ch
E cj:f32[128,64] = convert_element_type[
E new_dtype=float32
E weak_type=False
E ] ci
E ck:i32[128,1] = broadcast_in_dim[
E broadcast_dimensions=(0,)
E shape=(128, 1)
E ] bq
E cl:i32[1,64] = broadcast_in_dim[
E broadcast_dimensions=(1,)
E shape=(1, 64)
E ] cg
E cm:bool[128,64] = eq ck cl
E cn:i32[128,1] = broadcast_in_dim[
E broadcast_dimensions=(0,)
E shape=(128, 1)
E ] br
E co:i32[1,64] = broadcast_in_dim[
E broadcast_dimensions=(1,)
E shape=(1, 64)
E ] cf
E cp:bool[128,64] = ge cn co
E cq:bool[128,64] = and cm cp
E cr:f32[128,64] = pjit[
E jaxpr={ lambda ; cs:bool[128,64] ct:f32[128,64] cu:f32[]. let
E cv:f32[] = convert_element_type[
E new_dtype=float32
E weak_type=False
E ] cu
E cw:f32[128,64] = broadcast_in_dim[
E broadcast_dimensions=()
E shape=(128, 64)
E ] cv
E cx:f32[128,64] = select_n cs cw ct
E in (cx,) }
E name=_where
E ] cq cj -1e+20
E cy:f32[128] = reduce_max[axes=(1,)] cr
E cz:f32[128] = max cy bw
E da:f32[128] = sub bw cz
E db:f32[128] = exp da
E dc:f32[128] = mul bx db
E dd:f32[128,1] = broadcast_in_dim[
E broadcast_dimensions=(0,)
E shape=(128, 1)
E ] cz
E de:f32[128,64] = sub cr dd
E df:f32[128,64] = exp de
E dg:f32[128] = reduce_sum[axes=(1,)] df
E dh:f32[128] = add dg dc
E di:f32[128] = div 1.0 dh
E dj:f32[128,1] = broadcast_in_dim[
E broadcast_dimensions=(0,)
E shape=(128, 1)
E ] di
E dk:f32[128,64] = mul df dj
E dl:f32[128] = mul dc di
E dm:f32[128,1] = broadcast_in_dim[
E broadcast_dimensions=(0,)
E shape=(128, 1)
E ] dl
E dn:f32[128,64] = mul bv dm
E do:bf16[64,64] <- bs[bz:bz+64,:]
E dp:bf16[128,64] = convert_element_type[
E new_dtype=bfloat16
E weak_type=False
E ] dk
E dq:bf16[128,64] = dot_general[
E dimension_numbers=(([1], [0]), ([], []))
E ] dp do
E dr:f32[128,64] = convert_element_type[
E new_dtype=float32
E weak_type=False
E ] dq
E ds:f32[128,64] = add dn dr
E in (by, bu, ds, cz, dh) }
E body_nconsts=7
E cond_jaxpr={ lambda ; dt:i32[] du:i32[] dv:f32[128,64] dw:f32[128] dx:f32[128]. let
E dy:bool[] = lt dt du
E in (dy,) }
E cond_nconsts=0
E ] b f g t s r c 0 bk j k l
E dz:bf16[128,64] = convert_element_type[new_dtype=bfloat16 weak_type=False] bl
E ea:bool[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] p
E h[m:m+128,:] <- dz
E in () }
E
E The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
E
E --------------------
/usr/local/lib/python3.10/dist-packages/jax_triton/pallas/pallas_call.py:352: JaxStackTraceBeforeTransformation
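
Note (not part of the pytest output): in the failing equation the scan's per-iteration input u is an iota over the 6 KV blocks, and each step takes the reduce_min of one 64-element slice of the KV-position ref f; the resulting block minima v are then compared with the maximum query position to pick the while-loop bound bk. Mapping over an index array is precisely an extensive scan input, which _scan_lowering_rule rejects. Below is a hedged sketch of (a) the kind of kernel-side code that traces to such a scan and (b) a carry-only rewrite that computes the loop bound without any extensive input. Ref and variable names are illustrative, not taken from papyrax, and the rewrite is not verified against this jax_triton version:

    import jax
    import jax.numpy as jnp
    from jax_triton import pallas as pl

    block_k, num_kv_blocks = 64, 6  # matches the 384 = 6 * 64 shapes in the log

    # (a) Likely shape of the failing code: lax.map is a scan over its inputs,
    # so the jnp.arange below becomes an extensive scan input (the iota above).
    def block_minima(kv_pos_ref):
        return jax.lax.map(
            lambda i: jnp.min(pl.load(kv_pos_ref, (pl.dslice(i * block_k, block_k),))),
            jnp.arange(num_kv_blocks),
        )

    # (b) Carry-only alternative: fold over block indices with fori_loop and
    # keep only the quantity needed downstream (the exclusive index of the last
    # KV block this query block must visit), so no extensive inputs remain.
    def last_needed_kv_block(kv_pos_ref, q_pos_max):
        def body(i, hi):
            block_min = jnp.min(pl.load(kv_pos_ref, (pl.dslice(i * block_k, block_k),)))
            return jnp.where(q_pos_max >= block_min, i + 1, hi)
        return jax.lax.fori_loop(0, num_kv_blocks, body, jnp.int32(1))
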
The above exception was the direct cause of the following exception:
"""The pytest entry point."""
import pytest
if __name__ == "__main__":
> raise SystemExit(pytest.console_main())
/usr/local/lib/python3.10/dist-packages/pytest/__main__.py:5:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def console_main() -> int:
"""The CLI entry point of pytest.
This function is not meant for programmable use; use `main()` instead.
"""
# https://docs.python.org/3/library/signal.html#note-on-sigpipe
try:
> code = main()
/usr/local/lib/python3.10/dist-packages/_pytest/config/__init__.py:189:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = None, plugins = None
def main(
args: Optional[Union[List[str], "os.PathLike[str]"]] = None,
plugins: Optional[Sequence[Union[str, _PluggyPlugin]]] = None,
) -> Union[int, ExitCode]:
"""Perform an in-process test run.
:param args: List of command line arguments.
:param plugins: List of plugin objects to be auto-registered during initialization.
:returns: An exit code.
"""
try:
try:
config = _prepareconfig(args, plugins)
except ConftestImportFailure as e:
exc_info = ExceptionInfo.from_exc_info(e.excinfo)
tw = TerminalWriter(sys.stderr)
tw.line(f"ImportError while loading conftest '{e.path}'.", red=True)
exc_info.traceback = exc_info.traceback.filter(
filter_traceback_for_conftest_import_failure
)
exc_repr = (
exc_info.getrepr(style="short", chain=False)
if exc_info.traceback
else exc_info.exconly()
)
formatted_tb = str(exc_repr)
for line in formatted_tb.splitlines():
tw.line(line.rstrip(), red=True)
return ExitCode.USAGE_ERROR
else:
try:
> ret: Union[ExitCode, int] = config.hook.pytest_cmdline_main(
config=config
)
/usr/local/lib/python3.10/dist-packages/_pytest/config/__init__.py:166:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_cmdline_main'>, args = ()
kwargs = {'config': <_pytest.config.Config object at 0x7fbcddaa94e0>}
argname = 'config', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/usr/local/lib/python3.10/dist-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7fbcde41aa10>
hook_name = 'pytest_cmdline_main'
methods = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/usr/local/lib/python3.10/dist-packages/_pytest/mai...uponly', plugin=<module '_pytest.setuponly' from '/usr/local/lib/python3.10/dist-packages/_pytest/setuponly.py'>>, ...]
kwargs = {'config': <_pytest.config.Config object at 0x7fbcddaa94e0>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/usr/local/lib/python3.10/dist-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_cmdline_main'
hook_impls = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/usr/local/lib/python3.10/dist-packages/_pytest/mai...uponly', plugin=<module '_pytest.setuponly' from '/usr/local/lib/python3.10/dist-packages/_pytest/setuponly.py'>>, ...]
caller_kwargs = {'config': <_pytest.config.Config object at 0x7fbcddaa94e0>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/usr/local/lib/python3.10/dist-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7fbcddaa94e0>
def pytest_cmdline_main(config: Config) -> Union[int, ExitCode]:
> return wrap_session(config, _main)
/usr/local/lib/python3.10/dist-packages/_pytest/main.py:316:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7fbcddaa94e0>
doit = <function _main at 0x7fbcddd18280>
def wrap_session(
config: Config, doit: Callable[[Config, "Session"], Optional[Union[int, ExitCode]]]
) -> Union[int, ExitCode]:
"""Skeleton command line program."""
session = Session.from_config(config)
session.exitstatus = ExitCode.OK
initstate = 0
try:
try:
config._do_configure()
initstate = 1
config.hook.pytest_sessionstart(session=session)
initstate = 2
> session.exitstatus = doit(config, session) or 0
/usr/local/lib/python3.10/dist-packages/_pytest/main.py:269:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7fbcddaa94e0>
session = <Session tests exitstatus=<ExitCode.OK: 0> testsfailed=0 testscollected=1>
def _main(config: Config, session: "Session") -> Optional[Union[int, ExitCode]]:
"""Default command line protocol for initialization, session,
running tests and reporting."""
config.hook.pytest_collection(session=session)
> config.hook.pytest_runtestloop(session=session)
/usr/local/lib/python3.10/dist-packages/_pytest/main.py:323:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtestloop'>, args = ()
kwargs = {'session': <Session tests exitstatus=<ExitCode.OK: 0> testsfailed=0 testscollected=1>}
argname = 'session', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/usr/local/lib/python3.10/dist-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7fbcde41aa10>
hook_name = 'pytest_runtestloop'
methods = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/usr/local/lib/python3.10/dist-packages/_pytest/main.py'>>, <HookImpl plugin_name='logging-plugin', plugin=<_pytest.logging.LoggingPlugin object at 0x7fbcdd963250>>]
kwargs = {'session': <Session tests exitstatus=<ExitCode.OK: 0> testsfailed=0 testscollected=1>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/usr/local/lib/python3.10/dist-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtestloop'
hook_impls = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/usr/local/lib/python3.10/dist-packages/_pytest/main.py'>>, <HookImpl plugin_name='logging-plugin', plugin=<_pytest.logging.LoggingPlugin object at 0x7fbcdd963250>>]
caller_kwargs = {'session': <Session tests exitstatus=<ExitCode.OK: 0> testsfailed=0 testscollected=1>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/usr/local/lib/python3.10/dist-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
session = <Session tests exitstatus=<ExitCode.OK: 0> testsfailed=0 testscollected=1>
def pytest_runtestloop(session: "Session") -> bool:
if session.testsfailed and not session.config.option.continue_on_collection_errors:
raise session.Interrupted(
"%d error%s during collection"
% (session.testsfailed, "s" if session.testsfailed != 1 else "")
)
if session.config.option.collectonly:
return True
for i, item in enumerate(session.items):
nextitem = session.items[i + 1] if i + 1 < len(session.items) else None
> item.config.hook.pytest_runtest_protocol(item=item, nextitem=nextitem)
/usr/local/lib/python3.10/dist-packages/_pytest/main.py:348:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtest_protocol'>, args = ()
kwargs = {'item': <Function test_triton_flash_attention[384-384]>, 'nextitem': None}
argname = 'nextitem', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/usr/local/lib/python3.10/dist-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7fbcde41aa10>
hook_name = 'pytest_runtest_protocol'
methods = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/usr/local/lib/python3.10/dist-packages/_pytest...ame='warnings', plugin=<module '_pytest.warnings' from '/usr/local/lib/python3.10/dist-packages/_pytest/warnings.py'>>]
kwargs = {'item': <Function test_triton_flash_attention[384-384]>, 'nextitem': None}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/usr/local/lib/python3.10/dist-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtest_protocol'
hook_impls = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/usr/local/lib/python3.10/dist-packages/_pytest...ame='warnings', plugin=<module '_pytest.warnings' from '/usr/local/lib/python3.10/dist-packages/_pytest/warnings.py'>>]
caller_kwargs = {'item': <Function test_triton_flash_attention[384-384]>, 'nextitem': None}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/usr/local/lib/python3.10/dist-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_triton_flash_attention[384-384]>, nextitem = None
def pytest_runtest_protocol(item: Item, nextitem: Optional[Item]) -> bool:
ihook = item.ihook
ihook.pytest_runtest_logstart(nodeid=item.nodeid, location=item.location)
> runtestprotocol(item, nextitem=nextitem)
/usr/local/lib/python3.10/dist-packages/_pytest/runner.py:114:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_triton_flash_attention[384-384]>, log = True
nextitem = None
def runtestprotocol(
item: Item, log: bool = True, nextitem: Optional[Item] = None
) -> List[TestReport]:
hasrequest = hasattr(item, "_request")
if hasrequest and not item._request: # type: ignore[attr-defined]
# This only happens if the item is re-run, as is done by
# pytest-rerunfailures.
item._initrequest() # type: ignore[attr-defined]
rep = call_and_report(item, "setup", log)
reports = [rep]
if rep.passed:
if item.config.getoption("setupshow", False):
show_test_item(item)
if not item.config.getoption("setuponly", False):
> reports.append(call_and_report(item, "call", log))
/usr/local/lib/python3.10/dist-packages/_pytest/runner.py:133:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_triton_flash_attention[384-384]>, when = 'call'
log = True, kwds = {}
call = <CallInfo when='call' excinfo=<ExceptionInfo TritonLoweringException("Exception while lowering eqn:\n a:i32[6] = scan...ol[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] p\n h[m:m+128,:] <- dz\n in () }") tblen=5>>
hook = <_pytest.config.compat.PathAwareHookProxy object at 0x7fbcddaa9660>
def call_and_report(
item: Item, when: "Literal['setup', 'call', 'teardown']", log: bool = True, **kwds
) -> TestReport:
> call = call_runtest_hook(item, when, **kwds)
/usr/local/lib/python3.10/dist-packages/_pytest/runner.py:222:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_triton_flash_attention[384-384]>, when = 'call', kwds = {}
reraise = (<class '_pytest.outcomes.Exit'>, <class 'KeyboardInterrupt'>)
def call_runtest_hook(
item: Item, when: "Literal['setup', 'call', 'teardown']", **kwds
) -> "CallInfo[None]":
if when == "setup":
ihook: Callable[..., None] = item.ihook.pytest_runtest_setup
elif when == "call":
ihook = item.ihook.pytest_runtest_call
elif when == "teardown":
ihook = item.ihook.pytest_runtest_teardown
else:
assert False, f"Unhandled runtest hook case: {when}"
reraise: Tuple[Type[BaseException], ...] = (Exit,)
if not item.config.getoption("usepdb", False):
reraise += (KeyboardInterrupt,)
> return CallInfo.from_call(
lambda: ihook(item=item, **kwds), when=when, reraise=reraise
)
/usr/local/lib/python3.10/dist-packages/_pytest/runner.py:261:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
cls = <class '_pytest.runner.CallInfo'>
func = <function call_runtest_hook.<locals>.<lambda> at 0x7fbcdd9d0310>
when = 'call'
reraise = (<class '_pytest.outcomes.Exit'>, <class 'KeyboardInterrupt'>)
@classmethod
def from_call(
cls,
func: "Callable[[], TResult]",
when: "Literal['collect', 'setup', 'call', 'teardown']",
reraise: Optional[
Union[Type[BaseException], Tuple[Type[BaseException], ...]]
] = None,
) -> "CallInfo[TResult]":
"""Call func, wrapping the result in a CallInfo.
:param func:
The function to call. Called without arguments.
:param when:
The phase in which the function is called.
:param reraise:
Exception or exceptions that shall propagate if raised by the
function, instead of being wrapped in the CallInfo.
"""
excinfo = None
start = timing.time()
precise_start = timing.perf_counter()
try:
> result: Optional[TResult] = func()
/usr/local/lib/python3.10/dist-packages/_pytest/runner.py:341:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> lambda: ihook(item=item, **kwds), when=when, reraise=reraise
)
/usr/local/lib/python3.10/dist-packages/_pytest/runner.py:262:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtest_call'>, args = ()
kwargs = {'item': <Function test_triton_flash_attention[384-384]>}
argname = 'item', firstresult = False
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/usr/local/lib/python3.10/dist-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7fbcde41aa10>
hook_name = 'pytest_runtest_call'
methods = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/usr/local/lib/python3.10/dist-packages/_pytest..., plugin=<module '_pytest.threadexception' from '/usr/local/lib/python3.10/dist-packages/_pytest/threadexception.py'>>]
kwargs = {'item': <Function test_triton_flash_attention[384-384]>}
firstresult = False
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/usr/local/lib/python3.10/dist-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtest_call'
hook_impls = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/usr/local/lib/python3.10/dist-packages/_pytest..., plugin=<module '_pytest.threadexception' from '/usr/local/lib/python3.10/dist-packages/_pytest/threadexception.py'>>]
caller_kwargs = {'item': <Function test_triton_flash_attention[384-384]>}
firstresult = False
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/usr/local/lib/python3.10/dist-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_triton_flash_attention[384-384]>
def pytest_runtest_call(item: Item) -> None:
_update_current_test_var(item, "call")
try:
del sys.last_type
del sys.last_value
del sys.last_traceback
except AttributeError:
pass
try:
> item.runtest()
/usr/local/lib/python3.10/dist-packages/_pytest/runner.py:169:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <Function test_triton_flash_attention[384-384]>
def runtest(self) -> None:
"""Execute the underlying test function."""
> self.ihook.pytest_pyfunc_call(pyfuncitem=self)
/usr/local/lib/python3.10/dist-packages/_pytest/python.py:1799:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_pyfunc_call'>, args = ()
kwargs = {'pyfuncitem': <Function test_triton_flash_attention[384-384]>}
argname = 'pyfuncitem', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/usr/local/lib/python3.10/dist-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7fbcde41aa10>
hook_name = 'pytest_pyfunc_call'
methods = [<HookImpl plugin_name='python', plugin=<module '_pytest.python' from '/usr/local/lib/python3.10/dist-packages/_pytest/python.py'>>]
kwargs = {'pyfuncitem': <Function test_triton_flash_attention[384-384]>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/usr/local/lib/python3.10/dist-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_pyfunc_call'
hook_impls = [<HookImpl plugin_name='python', plugin=<module '_pytest.python' from '/usr/local/lib/python3.10/dist-packages/_pytest/python.py'>>]
caller_kwargs = {'pyfuncitem': <Function test_triton_flash_attention[384-384]>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/usr/local/lib/python3.10/dist-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
pyfuncitem = <Function test_triton_flash_attention[384-384]>
@hookimpl(trylast=True)
def pytest_pyfunc_call(pyfuncitem: "Function") -> Optional[object]:
testfunction = pyfuncitem.obj
if is_async_function(testfunction):
async_warn_and_skip(pyfuncitem.nodeid)
funcargs = pyfuncitem.funcargs
testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
> result = testfunction(**testargs)
/usr/local/lib/python3.10/dist-packages/_pytest/python.py:194:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
query_seq_len = 384, kv_seq_len = 384
@pytest.mark.gpu
@pytest.mark.parametrize(
    "query_seq_len,kv_seq_len",
    [
        (384, 384),
        (400, 400),
        (384, 256),
        (384, 300),
        (400, 256),
        (400, 300),
        (1, 256),
        (1, 300),
        (256, 1),
        (300, 1),
        (1, 3),
    ],
)
def test_triton_flash_attention(query_seq_len, kv_seq_len):
    q, k, v, q_pos, q_sid, kv_pos, kv_sid = _make_attention_inputs(
        query_seq_len=query_seq_len,
        kv_seq_len=kv_seq_len,
    )
>   result_triton = triton_multi_head_attention(
        q,
        k,
        v,
        q_pos,
        q_sid,
        kv_pos,
        kv_sid,
    )
tests/test_attention.py:92:
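
Note (not part of the pytest output): the parametrized test calls the Triton kernel; the kernel jaxpr above scales queries by 1/sqrt(64) = 0.125 and masks scores with (q_sid == kv_sid) & (q_pos >= kv_pos) before a streaming softmax. A hedged sketch of a plain-JAX reference with the same masking, the kind of oracle such a test could compare against; reference_attention and its signature are made up here, and the gist does not show the project's actual reference implementation:

    import jax
    import jax.numpy as jnp

    def reference_attention(q, k, v, q_pos, q_sid, kv_pos, kv_sid, *, scale=0.125):
        # q, k, v: [batch, seq, num_heads, head_dim]; positions/segment ids: [batch, seq].
        scores = jnp.einsum(
            "bqhd,bkhd->bhqk", q.astype(jnp.float32), k.astype(jnp.float32)
        ) * scale
        # Same mask as the kernel jaxpr: attend only within a segment and only
        # to keys whose position does not exceed the query's position.
        mask = (q_sid[:, None, :, None] == kv_sid[:, None, None, :]) & (
            q_pos[:, None, :, None] >= kv_pos[:, None, None, :]
        )
        scores = jnp.where(mask, scores, -1e20)
        weights = jax.nn.softmax(scores, axis=-1)
        out = jnp.einsum("bhqk,bkhd->bqhd", weights, v.astype(jnp.float32))
        return out.astype(q.dtype)
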
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Array([[[[0.162109, 0.632812, -0.333984, ..., -0.667969, -0.192383,
1.375],
[1.42969, 0.953125, -0... 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376,
377, 378, 379, 380, 381, 382, 383]], dtype=int32), ...)
kwargs = {}, __tracebackhide__ = True
msg = 'jax_triton.pallas.triton_lowering.TritonLoweringException: Exception while lowering eqn:\n a:i32[6] = scan[\n jaxpr...ludes JAX-internal frames.\nThe preceding is the original exception that occurred, unmodified.\n\n--------------------'
@functools.wraps(fun)
def reraise_with_filtered_traceback(*args, **kwargs):
__tracebackhide__ = True
try:
> return fun(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/jax/_src/traceback_util.py:166:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Array([[[[0.162109, 0.632812, -0.333984, ..., -0.667969, -0.192383,
1.375],
[1.42969, 0.953125, -0... 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376,
377, 378, 379, 380, 381, 382, 383]], dtype=int32), ...)
kwargs = {}
@api_boundary
def cache_miss(*args, **kwargs):
> outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
fun, infer_params_fn, *args, **kwargs)
/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py:250:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = <function multi_head_attention at 0x7fbc2a3d9120>
infer_params_fn = <function jit.<locals>.infer_params at 0x7fbc2a416950>
args = (Array([[[[0.162109, 0.632812, -0.333984, ..., -0.667969, -0.192383,
1.375],
[1.42969, 0.953125, -0... 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376,
377, 378, 379, 380, 381, 382, 383]], dtype=int32), ...)
kwargs = {}
args_flat = [Array([[[[0.162109, 0.632812, -0.333984, ..., -0.667969, -0.192383,
1.375],
[1.42969, 0.953125, -0... 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376,
377, 378, 379, 380, 381, 382, 383]], dtype=int32), ...]
_ = ()
params = {'donated_invars': (False, False, False, False, False, False, ...), 'in_shardings': (UnspecifiedValue, UnspecifiedValu...lse, False, False)
] i j k l m n o
in (p,) }
name=wrapped
] a b c d e f g
in (h,) }, ...}
in_tree = PyTreeDef((*, *, *, *, *, *, *)), out_tree = PyTreeDef(*)
arg = Array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=int32, weak_type=True)
def _python_pjit_helper(fun, infer_params_fn, *args, **kwargs):
args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
*args, **kwargs)
for arg in args_flat:
dispatch.check_arg(arg)
try:
> out_flat = pjit_p.bind(*args_flat, **params)
/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py:163:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = pjit
args = (Array([[[[0.162109, 0.632812, -0.333984, ..., -0.667969, -0.192383,
1.375],
[1.42969, 0.953125, -0... 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376,
377, 378, 379, 380, 381, 382, 383]], dtype=int32), ...)
params = {'donated_invars': (False, False, False, False, False, False, ...), 'in_shardings': (UnspecifiedValue, UnspecifiedValu...lse, False, False)
] i j k l m n o
in (p,) }
name=wrapped
] a b c d e f g
in (h,) }, ...}
top_trace = EvalTrace(level=0/0), axis_main = None
def bind(self, *args, **params):
top_trace = find_top_trace(args)
axis_main = max((axis_frame(a).main_trace for a in used_axis_names(self, params)),
default=None, key=lambda t: getattr(t, 'level', -1))
top_trace = (top_trace if not axis_main or axis_main.level < top_trace.level
else axis_main.with_cur_sublevel())
> return self.bind_with_trace(top_trace, args, params)
/usr/local/lib/python3.10/dist-packages/jax/_src/core.py:2578:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = pjit, trace = EvalTrace(level=0/0)
args = (Array([[[[0.162109, 0.632812, -0.333984, ..., -0.667969, -0.192383,
1.375],
[1.42969, 0.953125, -0... 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376,
377, 378, 379, 380, 381, 382, 383]], dtype=int32), ...)
params = {'donated_invars': (False, False, False, False, False, False, ...), 'in_shardings': (UnspecifiedValue, UnspecifiedValu...lse, False, False)
] i j k l m n o
in (p,) }
name=wrapped
] a b c d e f g
in (h,) }, ...}
def bind_with_trace(self, trace, args, params):
> out = trace.process_primitive(self, map(trace.full_raise, args), params)
/usr/local/lib/python3.10/dist-packages/jax/_src/core.py:382:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = EvalTrace(level=0/0), primitive = pjit
tracers = [Array([[[[0.162109, 0.632812, -0.333984, ..., -0.667969, -0.192383,
1.375],
[1.42969, 0.953125, -0... 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376,
377, 378, 379, 380, 381, 382, 383]], dtype=int32), ...]
params = {'donated_invars': (False, False, False, False, False, False, ...), 'in_shardings': (UnspecifiedValue, UnspecifiedValu...lse, False, False)
] i j k l m n o
in (p,) }
name=wrapped
] a b c d e f g
in (h,) }, ...}
def process_primitive(self, primitive, tracers, params):
> return primitive.impl(*tracers, **params)
/usr/local/lib/python3.10/dist-packages/jax/_src/core.py:814:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxpr = { lambda ; a:bf16[2,384,2,64] b:bf16[2,384,2,64] c:bf16[2,384,2,64] d:i32[2,384]
e:i32[2,384] f:i32[2,384] g:i32[2...se, False, False, False)
] i j k l m n o
in (p,) }
name=wrapped
] a b c d e f g
in (h,) }
in_shardings = (UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, ...)
out_shardings = (UnspecifiedValue,), resource_env = None
def _pjit_call_impl(*args, jaxpr,
in_shardings, out_shardings, resource_env,
donated_invars, name, keep_unused, inline):
def call_impl_cache_miss(*args_, **kwargs_):
out_flat, compiled = _pjit_call_impl_python(
*args, jaxpr=jaxpr, in_shardings=in_shardings,
out_shardings=out_shardings, resource_env=resource_env,
donated_invars=donated_invars, name=name, keep_unused=keep_unused,
inline=inline)
fastpath_data = _get_fastpath_data(
compiled, tree_structure(out_flat), args, out_flat)
return out_flat, fastpath_data
f = _get_jaxpr_as_fun(
jaxpr, tuple(getattr(i, '_original_sharding', i) for i in in_shardings),
tuple(getattr(o, '_original_sharding', o) for o in out_shardings),
resource_env, donated_invars, name, keep_unused, inline)
donated_argnums = [i for i, d in enumerate(donated_invars) if d]
has_explicit_sharding = _pjit_explicit_sharding(
in_shardings, out_shardings, None, None)
> return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
_get_cpp_global_cache(has_explicit_sharding))(*args)
/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py:1223:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args_ = (Array([[[[0.162109, 0.632812, -0.333984, ..., -0.667969, -0.192383,
1.375],
[1.42969, 0.953125, -0... 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376,
377, 378, 379, 380, 381, 382, 383]], dtype=int32), ...)
kwargs_ = {}
def call_impl_cache_miss(*args_, **kwargs_):
> out_flat, compiled = _pjit_call_impl_python(
*args, jaxpr=jaxpr, in_shardings=in_shardings,
out_shardings=out_shardings, resource_env=resource_env,
donated_invars=donated_invars, name=name, keep_unused=keep_unused,
inline=inline)
/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py:1207:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxpr = { lambda ; a:bf16[2,384,2,64] b:bf16[2,384,2,64] c:bf16[2,384,2,64] d:i32[2,384]
e:i32[2,384] f:i32[2,384] g:i32[2...se, False, False, False)
] i j k l m n o
in (p,) }
name=wrapped
] a b c d e f g
in (h,) }
in_shardings = (UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, ...)
out_shardings = (UnspecifiedValue,), resource_env = None
def _pjit_call_impl_python(
*args, jaxpr, in_shardings, out_shardings, resource_env, donated_invars,
name, keep_unused, inline):
global _most_recent_pjit_call_executable
in_shardings = _resolve_in_shardings(
args, in_shardings, out_shardings,
resource_env.physical_mesh if resource_env is not None else None)
> compiled = _pjit_lower(
jaxpr, in_shardings, out_shardings, resource_env,
donated_invars, name, keep_unused, inline,
always_lower=False, lowering_platform=None).compile()
/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py:1140:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxpr = { lambda ; a:bf16[2,384,2,64] b:bf16[2,384,2,64] c:bf16[2,384,2,64] d:i32[2,384]
e:i32[2,384] f:i32[2,384] g:i32[2...se, False, False, False)
] i j k l m n o
in (p,) }
name=wrapped
] a b c d e f g
in (h,) }
in_shardings = SameDeviceAssignmentTuple(shardings=(UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue), device_assignment=None)
out_shardings = SameDeviceAssignmentTuple(shardings=(UnspecifiedValue,), device_assignment=None)
args = (None, (False, False, False, False, False, False, ...), 'multi_head_attention', False, False)
kwargs = {'always_lower': False, 'lowering_platform': None}, da = None
def _pjit_lower(
jaxpr: core.ClosedJaxpr,
in_shardings,
out_shardings,
*args, **kwargs):
da = _fast_path_get_device_assignment(it.chain(in_shardings, out_shardings))
in_shardings = SameDeviceAssignmentTuple(tuple(in_shardings), da)
out_shardings = SameDeviceAssignmentTuple(tuple(out_shardings), da)
> return _pjit_lower_cached(jaxpr, in_shardings, out_shardings, *args, **kwargs)
/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py:1269:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxpr = { lambda ; a:bf16[2,384,2,64] b:bf16[2,384,2,64] c:bf16[2,384,2,64] d:i32[2,384]
e:i32[2,384] f:i32[2,384] g:i32[2...se, False, False, False)
] i j k l m n o
in (p,) }
name=wrapped
] a b c d e f g
in (h,) }
sdat_in_shardings = SameDeviceAssignmentTuple(shardings=(UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue), device_assignment=None)
sdat_out_shardings = SameDeviceAssignmentTuple(shardings=(UnspecifiedValue,), device_assignment=None)
resource_env = None
donated_invars = (False, False, False, False, False, False, ...)
name = 'multi_head_attention', keep_unused = False, inline = False
always_lower = False
@weakref_lru_cache
def _pjit_lower_cached(
jaxpr: core.ClosedJaxpr,
sdat_in_shardings: SameDeviceAssignmentTuple,
sdat_out_shardings: SameDeviceAssignmentTuple,
resource_env,
donated_invars,
name: str,
keep_unused: bool,
inline: bool,
always_lower: bool,
*,
lowering_platform: Optional[str],
override_lowering_rules: Optional[
tuple[tuple[core.Primitive, mlir.LoweringRule]]] = None):
in_shardings: tuple[PjitShardingMinusUnspecified, ...] = cast(
tuple[PjitShardingMinusUnspecified, ...], sdat_in_shardings.shardings)
out_shardings: tuple[PjitSharding, ...] = sdat_out_shardings.shardings
if resource_env is not None:
pxla.resource_typecheck(jaxpr, resource_env, {}, lambda: "pjit")
if resource_env is not None:
mesh = resource_env.physical_mesh
api_name = 'pjit'
else:
# resource_env is `None` in the jit wrapper around pjit.
mesh = None
api_name = 'jit'
# For `pjit(xmap)` cases, it needs to take the `lower_mesh_computation` path
# because `xmap` only supports SPMDAxisContext right now.
if dispatch.jaxpr_has_primitive(jaxpr.jaxpr, 'xmap'):
return pxla.lower_mesh_computation(
jaxpr, api_name, name, mesh,
in_shardings, out_shardings, donated_invars,
True, jaxpr.in_avals, tiling_method=None,
lowering_platform=lowering_platform)
else:
> return pxla.lower_sharding_computation(
jaxpr, api_name, name, in_shardings, out_shardings,
tuple(donated_invars), tuple(jaxpr.in_avals),
keep_unused=keep_unused, inline=inline, always_lower=always_lower,
devices_from_context=(
None if mesh is None or mesh.empty else list(mesh.devices.flat)),
lowering_platform=lowering_platform,
override_lowering_rules=override_lowering_rules,
)
/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py:1311:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = ({ lambda ; a:bf16[2,384,2,64] b:bf16[2,384,2,64] c:bf16[2,384,2,64] d:i32[2,384]
e:i32[2,384] f:i32[2,384] g:i32[...e, UnspecifiedValue, UnspecifiedValue, ...), (UnspecifiedValue,), (False, False, False, False, False, False, ...), ...)
kwargs = {'always_lower': False, 'devices_from_context': None, 'inline': False, 'keep_unused': False, ...}
@wraps(func)
def wrapper(*args, **kwargs):
with TraceAnnotation(name, **decorator_kwargs):
> return func(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/jax/_src/profiler.py:314:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun_or_jaxpr = { lambda ; a:bf16[2,384,2,64] b:bf16[2,384,2,64] c:bf16[2,384,2,64] d:i32[2,384]
e:i32[2,384] f:i32[2,384] g:i32[2...se, False, False, False)
] i j k l m n o
in (p,) }
name=wrapped
] a b c d e f g
in (h,) }
api_name = 'jit', fun_name = 'multi_head_attention'
in_shardings = (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), ...)
out_shardings = (UnspecifiedValue,)
donated_invars = (False, False, False, False, False, False, ...)
global_in_avals = (ShapedArray(bfloat16[2,384,2,64]), ShapedArray(bfloat16[2,384,2,64]), ShapedArray(bfloat16[2,384,2,64]), ShapedArray(int32[2,384]), ShapedArray(int32[2,384], weak_type=True), ShapedArray(int32[2,384]), ...)
@profiler.annotate_function
def lower_sharding_computation(
fun_or_jaxpr: Union[lu.WrappedFun, core.ClosedJaxpr],
api_name: str,
fun_name: str,
in_shardings: Sequence[MaybeSharding],
out_shardings: Union[Sequence[MaybeSharding], UnspecifiedValue],
donated_invars: Sequence[bool],
global_in_avals: Sequence[core.ShapedArray],
*,
keep_unused: bool,
inline: bool,
always_lower: bool,
devices_from_context: Optional[Sequence[xc.Device]] = None,
lowering_platform: Optional[str],
override_lowering_rules: Optional[
tuple[tuple[core.Primitive, mlir.LoweringRule]]] = None,
) -> MeshComputation:
"""Lowers a computation to XLA. It can take arbitrary shardings as input.
The caller of this code can pass in a singleton UNSPECIFIED because the
number of out_avals might not be known at that time and
lower_sharding_computation calculates the number of out_avals so it can apply
the singleton UNSPECIFIED to all out_avals.
"""
# 1. Trace to jaxpr and preprocess/verify it
auto_spmd_lowering = (
check_if_any_auto(in_shardings) if is_unspecified(out_shardings) else
check_if_any_auto(it.chain.from_iterable([in_shardings, out_shardings]))) # type: ignore
(closed_jaxpr, global_in_avals, global_out_avals, donated_invars,
kept_var_idx, name_stack) = _trace_to_jaxpr_and_dce(
fun_or_jaxpr, global_in_avals, api_name, fun_name, keep_unused,
donated_invars, auto_spmd_lowering)
jaxpr = closed_jaxpr.jaxpr
in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx)
if is_unspecified(out_shardings):
out_shardings = (UNSPECIFIED,) * len(global_out_avals)
assert isinstance(out_shardings, tuple)
assert len(out_shardings) == len(global_out_avals), (
len(out_shardings), len(global_out_avals))
# Device assignment across all inputs, outputs and shardings inside jaxpr
# should be the same.
jaxpr_sharding = list(dispatch.jaxpr_shardings(jaxpr))
backend, device_assignment = _get_and_check_device_assignment(
it.chain([(i, MismatchType.ARG_SHARDING, None) for i in in_shardings],
[(o, MismatchType.OUT_SHARDING, None) for o in out_shardings],
[(js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info)
for js, source_info in jaxpr_sharding]),
devices_from_context)
committed = bool(
devices_from_context or
len(device_assignment) > 1 or
any(not is_unspecified(i) for i in in_shardings) or
any(not is_unspecified(js) for js, _ in jaxpr_sharding) or
any(not is_unspecified(o) for o in out_shardings))
gs = sharding_impls.GSPMDSharding.get_replicated(device_assignment)
in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)
da_object = _create_da_object(tuple(device_assignment))
if not da_object.is_fully_addressable:
check_multihost_collective_allowlist(jaxpr)
if inline and config.jax_spmd_mode != 'allow_all':
raise RuntimeError(
"Running operations on `Array`s that are not fully addressable by this "
"process (i.e. `Array`s with data sharded across multiple devices and "
"processes.) is dangerous. It’s very important that all processes run "
"the same cross-process computations in the same order otherwise it "
"can lead to hangs. "
"If you’re not already familiar with JAX’s multi-process "
"programming model, please read "
"https://jax.readthedocs.io/en/latest/multi_process.html. "
"To fix this error, run your `jitted` computation inside "
"`with jax.spmd_mode('allow_all'):` context manager.")
has_outfeed = core.jaxpr_uses_outfeed(jaxpr)
kept_outputs = [True] * len(global_out_avals)
# Computations that only produce constants and/or only rearrange their inputs,
# which are often produced from partial evaluation, don't need compilation,
# and don't need to evaluate their arguments.
if (not always_lower and not (jaxpr.effects or has_outfeed) and
(not jaxpr.eqns and all(kept_outputs) or not jaxpr.outvars) and
all(is_unspecified(o) for o in out_shardings)):
return MeshComputation(
str(name_stack), None, True, donated_invars, jaxpr=jaxpr,
consts=closed_jaxpr.consts, global_in_avals=global_in_avals,
global_out_avals=global_out_avals, in_shardings=in_shardings,
backend=backend, da_object=da_object,
committed=committed, kept_var_idx=kept_var_idx, keepalive=None)
# 2. Build up the HLO
semantic_in_shardings = SemanticallyEqualShardings(in_shardings) # type: ignore
semantic_out_shardings = SemanticallyEqualShardings(out_shardings)
(module, keepalive, host_callbacks, unordered_effects, ordered_effects,
> nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
semantic_out_shardings, da_object, lowering_platform,
donated_invars, name_stack, override_lowering_rules)
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py:2083:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
closed_jaxpr = { lambda ; a:bf16[2,384,2,64] b:bf16[2,384,2,64] c:bf16[2,384,2,64] d:i32[2,384]
e:i32[2,384] f:i32[2,384] g:i32[2...se, False, False, False)
] i j k l m n o
in (p,) }
name=wrapped
] a b c d e f g
in (h,) }
api_name = 'jit', fun_name = 'multi_head_attention'
backend = <jaxlib.xla_extension.Client object at 0x7fbc21ccc070>
semantic_in_shardings = SemanticallyEqualShardings(shardings=(GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replica...), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})))
semantic_out_shardings = SemanticallyEqualShardings(shardings=(UnspecifiedValue,))
da_object = _DeviceAssignment(device_assignment=(gpu(id=0),))
lowering_platform = None
donated_invars = (False, False, False, False, False, False, ...)
name_stack = NameStack(stack=(Scope(name='jit(multi_head_attention)'),))
override_lowering_rules = None
@weakref_lru_cache
def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
semantic_in_shardings, semantic_out_shardings,
da_object, lowering_platform,
donated_invars, name_stack, override_lowering_rules):
jaxpr = closed_jaxpr.jaxpr
in_shardings = semantic_in_shardings.shardings
out_shardings = semantic_out_shardings.shardings
global_in_avals = closed_jaxpr.in_avals
global_out_avals = closed_jaxpr.out_avals
device_assignment = da_object.device_assignment
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
if logger.isEnabledFor(log_priority):
logger.log(log_priority,
"Compiling %s for with global shapes and types %s. "
"Argument mapping: %s.",
fun_name, global_in_avals, in_shardings)
# Look at the number of replicas present in the jaxpr. In
# lower_sharding_computation, nreps > 1 during `jit(pmap)` cases. This is
# handled here so as to deprecate the lower_xla_callable codepath when
# `jax.Array` is turned on by default.
# TODO(yashkatariya): Remove this when `jit(pmap)` is removed.
nreps = dispatch.jaxpr_replicas(jaxpr)
dispatch.raise_warnings_or_errors_for_jit_of_pmap(
nreps, backend, fun_name, jaxpr)
in_mlir_shardings: Optional[list[Optional[sharding_impls.XLACompatibleSharding]]]
out_mlir_shardings: Optional[list[Optional[sharding_impls.XLACompatibleSharding]]]
axis_ctx: mlir.AxisContext
if nreps == 1:
in_mlir_shardings = map(_to_logical_sharding, global_in_avals, in_shardings)
out_mlir_shardings = map(_to_logical_sharding, global_out_avals, out_shardings)
replicated_args = [False] * len(global_in_avals)
axis_ctx = sharding_impls.ShardingContext(device_assignment)
num_partitions = len(device_assignment)
else:
# This path is triggered for `jit(pmap)` cases.
replicated_args = None
in_mlir_shardings = None
out_mlir_shardings = None
axis_env = sharding_impls.AxisEnv(nreps, (), ())
axis_ctx = sharding_impls.ReplicaAxisContext(axis_env)
num_partitions = 1
module_name = f"{api_name}_{fun_name}"
if len(device_assignment) > 1:
if any(effects.ordered_effects.contains(eff) for eff
in closed_jaxpr.effects):
raise ValueError("Ordered effects are not supported for more than 1 device.")
ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))
with dispatch.log_elapsed_time(
"Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time} sec",
fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
> lowering_result = mlir.lower_jaxpr_to_module(
module_name,
closed_jaxpr,
ordered_effects,
backend,
# Optionally, override the lowering platform
lowering_platform or backend.platform,
axis_ctx,
name_stack,
donated_invars,
replicated_args=replicated_args,
arg_shardings=in_mlir_shardings,
result_shardings=out_mlir_shardings,
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
num_replicas=nreps,
num_partitions=num_partitions,
override_lowering_rules=override_lowering_rules)
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py:1923:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
module_name = 'jit_multi_head_attention'
jaxpr = { lambda ; a:bf16[2,384,2,64] b:bf16[2,384,2,64] c:bf16[2,384,2,64] d:i32[2,384]
e:i32[2,384] f:i32[2,384] g:i32[2...se, False, False, False)
] i j k l m n o
in (p,) }
name=wrapped
] a b c d e f g
in (h,) }
ordered_effects = []
backend_or_name = <jaxlib.xla_extension.Client object at 0x7fbc21ccc070>
platform = 'cuda'
axis_context = ShardingContext(device_assignment=(gpu(id=0),))
name_stack = NameStack(stack=(Scope(name='jit(multi_head_attention)'),))
donated_args = [False, False, False, False, False, False, ...]
replicated_args = [False, False, False, False, False, False, ...]
arg_shardings = [GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), ...]
result_shardings = [None]
arg_names = ('query', 'key', 'value', 'query_positions', 'query_segment_ids', 'kv_positions', ...)
result_names = ('',), num_replicas = 1, num_partitions = 1
override_lowering_rules = None
def lower_jaxpr_to_module(
module_name: str,
jaxpr: core.ClosedJaxpr,
ordered_effects: list[core.Effect],
backend_or_name: Optional[Union[str, xb.XlaBackend]],
platform: str,
axis_context: AxisContext,
name_stack: source_info_util.NameStack,
donated_args: Sequence[bool],
replicated_args: Optional[Sequence[bool]] = None,
arg_shardings: Optional[Sequence[Optional[XLACompatibleSharding]]] = None,
result_shardings: Optional[Sequence[Optional[XLACompatibleSharding]]] = None,
arg_names: Optional[Sequence[Optional[str]]] = None,
result_names: Optional[Sequence[Optional[str]]] = None,
num_replicas: int = 1,
num_partitions: int = 1,
override_lowering_rules: Optional[
tuple[tuple[core.Primitive, LoweringRule]]] = None,
) -> LoweringResult:
"""Lowers a top-level jaxpr to an MLIR module.
Handles the quirks of the argument/return value passing conventions of the
runtime.
"""
platform = xb.canonicalize_platform(platform)
if not xb.is_known_platform(platform):
raise ValueError(f"Unknown platform {platform}")
input_output_aliases = None
in_avals = (jaxpr.in_avals if arg_shardings is None else
map(sharded_aval, jaxpr.in_avals, arg_shardings))
out_avals = (jaxpr.out_avals if result_shardings is None else
map(sharded_aval, jaxpr.out_avals, result_shardings))
if platform in _platforms_with_donation:
input_output_aliases, donated_args = _set_up_aliases(
in_avals, out_avals, donated_args)
unlowerable_effects = lowerable_effects.filter_not_in(jaxpr.effects)
if unlowerable_effects:
raise ValueError(f'Cannot lower jaxpr with effects: {jaxpr.effects}')
if any(donated_args):
unused_donations = [str(a) for a, d in zip(in_avals, donated_args) if d]
msg = "See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation."
if platform not in _platforms_with_donation:
msg = f"Donation is not implemented for {platform}.\n{msg}"
warnings.warn(f"Some donated buffers were not usable: {', '.join(unused_donations)}.\n{msg}")
# HLO channels need to start at 1
channel_iter = itertools.count(1)
# Create a keepalives list that will be mutated during the lowering.
keepalives: list[Any] = []
host_callbacks: list[Any] = []
dim_vars: Sequence[str]
if not config.jax_dynamic_shapes:
# Find the dimension variables
all_dim_poly = [d for aval in jaxpr.in_avals if hasattr(aval, "shape")
for d in aval.shape if not core.is_constant_dim(d)]
dim_vars = tuple(sorted(functools.reduce(lambda acc, new: acc.union(new.get_vars()),
all_dim_poly, set())))
else:
dim_vars = ()
arg_op_shardings = (
map(_to_logical_op_sharding, jaxpr.in_avals, arg_shardings)
if arg_shardings is not None else arg_shardings)
result_op_shardings = (
map(_to_logical_op_sharding, jaxpr.out_avals, result_shardings)
if result_shardings is not None else result_shardings)
ctx = ModuleContext(backend_or_name, platform, axis_context, name_stack,
keepalives, channel_iter, host_callbacks,
override_lowering_rules=override_lowering_rules,
shape_poly_state=ShapePolyLoweringState(dim_vars))
with ctx.context, ir.Location.unknown(ctx.context):
# Remove module name characters that XLA would alter. This ensures that
# XLA computation preserves the module name.
attrs = ctx.module.operation.attributes
module_name = _module_name_regex.sub("_", module_name)
attrs["sym_name"] = ir.StringAttr.get(module_name)
attrs["mhlo.num_replicas"] = i32_attr(num_replicas)
attrs["mhlo.num_partitions"] = i32_attr(num_partitions)
> lower_jaxpr_to_fun(
ctx, "main", jaxpr, ordered_effects, public=True, create_tokens=True,
replace_tokens_with_dummy=True,
num_output_tokens=0,
replicated_args=replicated_args,
arg_shardings=arg_op_shardings,
result_shardings=result_op_shardings,
input_output_aliases=input_output_aliases,
arg_names=arg_names,
result_names=result_names)
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:728:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
ctx = ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7fbb2812df80>, module=<jax...object at 0x7fbb281097e0>, cached_primitive_lowerings={}, cached_call_jaxpr_lowerings={}, override_lowering_rules=None)
name = 'main'
jaxpr = { lambda ; a:bf16[2,384,2,64] b:bf16[2,384,2,64] c:bf16[2,384,2,64] d:i32[2,384]
e:i32[2,384] f:i32[2,384] g:i32[2...se, False, False, False)
] i j k l m n o
in (p,) }
name=wrapped
] a b c d e f g
in (h,) }
effects = []
def lower_jaxpr_to_fun(
ctx: ModuleContext,
name: str,
jaxpr: core.ClosedJaxpr,
effects: Sequence[core.Effect],
*,
create_tokens: bool = False,
public: bool = False,
replace_tokens_with_dummy: bool = False,
replicated_args: Optional[Sequence[bool]] = None,
arg_shardings: Optional[Sequence[Optional[xc.HloSharding]]] = None,
result_shardings: Optional[Sequence[Optional[xc.HloSharding]]] = None,
use_sharding_annotations: bool = True,
input_output_aliases: Optional[Sequence[Optional[int]]] = None,
num_output_tokens: int = 0,
api_name: str = "jit",
arg_names: Optional[Sequence[Optional[str]]] = None,
result_names: Optional[Sequence[Optional[str]]] = None,
) -> func_dialect.FuncOp:
"""Lowers jaxpr and its callees to an IR function.
Assumes that an MLIR context, location, and insertion point are set.
Args:
ctx: the lowering context.
name: the function name. The name will be uniquified by the symbol table,
so it is ok to use the same name multiple times.
jaxpr: the jaxpr to lower.
effects: a sequence of `core.Effect`s corresponding to an ordering of tokens
that will be created in or used by the lowered function.
create_tokens: if true, the HLO will create tokens and ignore dummy input tokens.
public: if true, the function's visibility is set to "public".
replace_tokens_with_dummy: if true, token arguments/return values are
replaced with bool arrays of size [0].
replicated_args: if present, annotates arguments as replicated.
arg_shardings: sharding annotations for each argument (optional).
result_shardings: sharding annotations for each result (optional).
use_sharding_annotations: if True, use "mhlo.sharding" annotations on
parameters and return values to express sharding. If False, use
hlo.custom_call operators with sharding annotations.
TODO(b/228598865): remove this option when "mhlo.sharding" annotations are
propagated on non-entry functions during MLIR->HLO conversion.
input_output_aliases: optional sequence that maps argument numbers to the
corresponding output that should alias them.
api_name: The name of the higher level primitive which should show up in the
name stack.
Returns:
MLIR func op
"""
def aval_to_types(aval):
if replace_tokens_with_dummy and aval is core.abstract_token:
aval = core.ShapedArray((), np.dtype(np.bool_))
return aval_to_ir_types(aval)
num_dim_vars = len(ctx.shape_poly_state.dim_vars)
dim_var_avals = [core.ShapedArray((), dtypes.canonicalize_dtype(np.int64))] * num_dim_vars
dim_var_types = map(aval_to_types, dim_var_avals)
# Function inputs: *dim_var_values, *tokens, *actual_inputs
input_types = map(aval_to_types, jaxpr.in_avals)
output_types = map(aval_to_types, jaxpr.out_avals)
num_tokens = len(effects)
if create_tokens:
# If we create the tokens they won't be inputs to the MLIR function.
token_types = [dummy_token_type() for _ in effects]
output_token_types = [dummy_token_type() for _ in range(num_output_tokens)]
else:
# If we aren't creating tokens they will be the initial inputs to the
# MLIR function.
output_token_types = []
token_types = [token_type() for _ in effects]
token_avals = [core.AbstractToken] * num_tokens
# Order of arguments: dim vars, tokens, array inputs
input_avals = dim_var_avals + token_avals + jaxpr.in_avals
input_types = [*dim_var_types, *token_types, *input_types]
output_avals = [core.AbstractToken] * (len(output_token_types) + num_tokens) + jaxpr.out_avals
output_types = [*output_token_types, *token_types, *output_types]
if input_output_aliases is not None:
token_input_output_aliases = [None] * (num_dim_vars + num_tokens)
input_output_aliases = [*token_input_output_aliases, *input_output_aliases]
# Update the existing aliases to account for the new output values
input_output_aliases = [None if a is None
else a + num_output_tokens + num_tokens
for a in input_output_aliases]
if arg_shardings is not None:
token_shardings = [None] * (num_dim_vars + num_tokens)
arg_shardings = [*token_shardings, *arg_shardings]
if result_shardings is not None:
token_shardings = [None] * (num_tokens + num_output_tokens)
result_shardings = [*token_shardings, *result_shardings]
if replicated_args is not None:
token_replicated_args = [False] * (num_dim_vars + num_tokens)
replicated_args = [*token_replicated_args, *replicated_args]
flat_input_types = util.flatten(input_types)
flat_output_types = util.flatten(output_types)
ftype = ir.FunctionType.get(flat_input_types, flat_output_types)
func_op = func_dialect.FuncOp(name, ftype, ip=ctx.ip)
func_op.attributes["sym_visibility"] = ir.StringAttr.get(
"public" if public else "private")
ctx.symbol_table.insert(func_op)
ir_arg_shardings = None
if arg_shardings is not None:
in_avals = [None] * (num_dim_vars + num_tokens) + list(jaxpr.in_avals)
ir_arg_shardings = util.flatten(
[[_to_physical_op_sharding(a, s)] * len(types)
for a, s, types in zip(in_avals, arg_shardings, input_types)])
del in_avals
ir_result_shardings = None
if result_shardings is not None:
out_avals = [None] * (num_tokens + num_output_tokens) + list(jaxpr.out_avals)
ir_result_shardings = util.flatten(
[[_to_physical_op_sharding(a, s)] * len(types)
for a, s, types in zip(out_avals, result_shardings, output_types)])
del out_avals
if (
replicated_args is not None
or ir_arg_shardings is not None
or input_output_aliases is not None
or arg_names is not None
or num_tokens > 0
):
arg_attrs: list[dict[str, ir.Attribute]] = [
{} for _ in range(len(flat_input_types))]
if replicated_args is not None:
replicated_ir_args = [[replicated] * len(types) for replicated, types
in zip(replicated_args, input_types)]
for attrs, replicated in zip(arg_attrs, util.flatten(replicated_ir_args)):
if replicated:
attrs["mhlo.is_same_data_across_replicas"] = ir.UnitAttr.get()
if use_sharding_annotations and ir_arg_shardings is not None:
for attrs, sharding in zip(arg_attrs, ir_arg_shardings):
if sharding is not None:
attrs["mhlo.sharding"] = get_sharding_attr(sharding)
if input_output_aliases is not None:
output_ids = util.unflatten(list(range(len(flat_output_types))),
map(len, output_types))
aliases: list[Optional[int]] = []
for types, alias in zip(input_types, input_output_aliases):
if alias is None:
aliases.extend([None] * len(types))
else:
aliases.extend(output_ids[alias])
for attrs, alias in zip(arg_attrs, aliases):
if alias is not None:
attrs["tf.aliasing_output"] = i32_attr(alias)
if num_tokens > 0:
token_arg_attrs = arg_attrs[num_dim_vars:num_dim_vars + num_tokens]
for attrs in token_arg_attrs:
attrs["jax.token"] = ir.BoolAttr.get(True)
if arg_names:
named_arg_attrs = arg_attrs[num_dim_vars + num_tokens:]
for attrs, name_ in zip(named_arg_attrs, arg_names):
if name_:
attrs['jax.arg_info'] = ir.StringAttr.get(name_)
func_op.arg_attrs = ir.ArrayAttr.get(
[ir.DictAttr.get(attrs) for attrs in arg_attrs])
result_attrs: list[dict[str, ir.Attribute]] = [
{} for _ in range(len(flat_output_types))]
if num_tokens > 0:
token_result_attrs = result_attrs[:num_tokens]
for attrs in token_result_attrs:
attrs["jax.token"] = ir.BoolAttr.get(True)
if result_names:
named_result_attrs = result_attrs[num_tokens:]
if len(named_result_attrs) == len(result_names):
for attrs, name_ in zip(named_result_attrs, result_names):
attrs['jax.result_info'] = ir.StringAttr.get(name_)
if use_sharding_annotations and ir_result_shardings is not None:
for attrs, sharding in zip(result_attrs, ir_result_shardings):
if sharding is not None:
attrs['mhlo.sharding'] = get_sharding_attr(sharding)
func_op.result_attrs = ir.ArrayAttr.get(
[ir.DictAttr.get(attrs) for attrs in result_attrs])
entry_block = func_op.add_entry_block()
with ir.InsertionPoint(entry_block):
flat_args = entry_block.arguments
# We separate out the dimension variable inputs, the token inputs and
# the regular inputs. The dimension variables and token inputs
# will be passed to `jaxpr_subcomp` separately from the `args`.
dim_var_values, _, _ = util.split_list(flat_args, [num_dim_vars, num_tokens])
# A lowering context just for function body entry/exit code.
entry_lowering_ctx = LoweringRuleContext(
ctx, None, [], None, TokenSet.create([]), None, None, dim_var_values)
if not use_sharding_annotations and ir_arg_shardings is not None:
flat_args = [
a if s is None else wrap_with_sharding_op(entry_lowering_ctx, a, a_aval, s)
for a, s, a_aval in zip(flat_args, ir_arg_shardings, input_avals)]
_, token_args, unflattened_args = util.split_list(util.unflatten(flat_args, map(len, input_types)),
[num_dim_vars, num_tokens])
if create_tokens:
tokens_in = TokenSet.create(effects)
else:
tokens_in = TokenSet(zip(effects, token_args))
args: list[list[ir.Value]] = []
for aval, arg in zip(jaxpr.in_avals, unflattened_args):
if replace_tokens_with_dummy and aval is core.abstract_token:
args.append(hlo.CreateTokenOp().results)
else:
args.append(arg)
callee_name_stack = ctx.name_stack.extend(util.wrap_name(name, api_name))
> out_vals, tokens_out = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
jaxpr.jaxpr, tokens_in, map(ir_constants, jaxpr.consts),
*args, dim_var_values=dim_var_values)
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:1060:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
ctx = ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7fbb2812df80>, module=<jax...object at 0x7fbb281097e0>, cached_primitive_lowerings={}, cached_call_jaxpr_lowerings={}, override_lowering_rules=None)
jaxpr = { lambda ; a:bf16[2,384,2,64] b:bf16[2,384,2,64] c:bf16[2,384,2,64] d:i32[2,384]
e:i32[2,384] f:i32[2,384] g:i32[2...se, False, False, False)
] i j k l m n o
in (p,) }
name=wrapped
] a b c d e f g
in (h,) }
tokens = <jax._src.interpreters.mlir.TokenSet object at 0x7fbb28109e70>
consts = [], dim_var_values = []
args = ([<jaxlib.mlir._mlir_libs._mlir.ir.BlockArgument object at 0x7fbb28116430>], [<jaxlib.mlir._mlir_libs._mlir.ir.BlockAr...ckArgument object at 0x7fbb281178f0>], [<jaxlib.mlir._mlir_libs._mlir.ir.BlockArgument object at 0x7fbb28115fb0>], ...)
aval = <function jaxpr_subcomp.<locals>.aval at 0x7fbc21d9a710>
write = <function jaxpr_subcomp.<locals>.write at 0x7fbc21d9a440>
def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
tokens: TokenSet,
consts: Sequence[Sequence[ir.Value]],
*args: Sequence[ir.Value],
dim_var_values: Sequence[ir.Value]
) -> tuple[Sequence[Sequence[ir.Value]], TokenSet]:
"""Lowers a jaxpr into MLIR, inlined into an existing function.
Assumes that an MLIR context, location, and insertion point are set.
dim_var_values: the list of dimension variable values in the current
IR function, in the order of ctx.shape_poly_state.dim_vars.
"""
assert ctx.platform != "gpu"
def read(v: core.Atom) -> Sequence[ir.Value]:
if type(v) is core.Literal:
return ir_constants(v.val, canonicalize_types=True)
else:
assert isinstance(v, core.Var)
return env[v]
def aval(v: core.Atom) -> core.AbstractValue:
if type(v) is core.Literal:
return xla.abstractify(v.val)
else:
return v.aval
def write(v: core.Var, node: Sequence[ir.Value]):
assert node is not None
env[v] = tuple(node)
def get_lowering(primitive: core.Primitive) -> Optional[LoweringRule]:
if ctx.override_lowering_rules is None:
return None
for p, rule in ctx.override_lowering_rules:
if primitive is p:
return rule
return None
env: dict[core.Var, tuple[ir.Value, ...]] = {}
assert len(args) == len(jaxpr.invars), (jaxpr, args)
assert len(consts) == len(jaxpr.constvars), (jaxpr, consts)
assert all(isinstance(v, ir.Value) for vs in consts for v in vs), consts
assert len(ctx.shape_poly_state.dim_vars) == len(dim_var_values), (ctx.shape_poly_state.dim_vars, dim_var_values)
map(write, jaxpr.constvars, consts)
map(write, jaxpr.invars, args)
last_used = core.last_used(jaxpr)
for eqn in jaxpr.eqns:
in_nodes = map(read, eqn.invars)
assert isinstance(ctx.name_stack, source_info_util.NameStack), type(ctx.name_stack)
source_info = eqn.source_info.replace(
name_stack=ctx.name_stack + eqn.source_info.name_stack)
loc = _source_info_to_location(eqn.primitive, eqn.params, source_info,
ctx.name_stack)
with source_info_util.user_context(eqn.source_info.traceback), loc:
override_rule = get_lowering(eqn.primitive)
if override_rule is not None:
rule = override_rule
elif eqn.primitive in _platform_specific_lowerings[ctx.platform]:
rule = _platform_specific_lowerings[ctx.platform][eqn.primitive]
elif eqn.primitive in xla._backend_specific_translations[ctx.platform]:
rule = xla_fallback_lowering(eqn.primitive)
elif eqn.primitive in _lowerings:
rule = _lowerings[eqn.primitive]
elif eqn.primitive in xla._translations:
rule = xla_fallback_lowering(eqn.primitive)
else:
raise NotImplementedError(
f"MLIR translation rule for primitive '{eqn.primitive.name}' not "
f"found for platform {ctx.platform}")
eqn_ctx = ctx.replace(name_stack=source_info.name_stack)
effects = list(effects_lib.ordered_effects.filter_in(eqn.effects))
tokens_in = tokens.subset(effects)
avals_in = map(aval, eqn.invars)
rule_ctx = LoweringRuleContext(
module_context=eqn_ctx, primitive=eqn.primitive, avals_in=avals_in,
avals_out=map(aval, eqn.outvars), tokens_in=tokens_in,
tokens_out=None, dim_var_values=dim_var_values)
if config.jax_dynamic_shapes:
axis_size_env = {d: read(d)[0]
for a in avals_in if type(a) is core.DShapedArray
for d in a.shape if type(d) is core.Var}
rule_ctx = rule_ctx.replace(axis_size_env=axis_size_env)
> ans = rule(rule_ctx, *map(_unwrap_singleton_ir_values, in_nodes),
**eqn.params)
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:1217:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
ctx = LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context obj...<jax._src.interpreters.mlir.TokenSet object at 0x7fbb28109c90>, tokens_out=None, axis_size_env=None, dim_var_values=[])
name = 'wrapped'
jaxpr = { lambda ; a:bf16[2,384,2,64] b:bf16[2,384,2,64] c:bf16[2,384,2,64] d:i32[2,384]
e:i32[2,384] f:i32[2,384] g:i32[2... dtype=bfloat16),)
which_linear=(False, False, False, False, False, False, False)
] a b c d e f g
in (h,) }
in_shardings = (UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, UnspecifiedValue, ...)
out_shardings = (UnspecifiedValue,)
def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
out_shardings, resource_env, donated_invars,
keep_unused, inline):
effects = list(ctx.tokens_in.effects())
output_types = map(mlir.aval_to_ir_types, ctx.avals_out)
output_types = [mlir.token_type()] * len(effects) + output_types
flat_output_types = flatten(output_types)
arg_shardings = [None if is_unspecified(i) else
i._to_xla_hlo_sharding(aval.ndim)
for aval, i in zip(ctx.avals_in, in_shardings)]
result_shardings = [None if is_unspecified(o) else
o._to_xla_hlo_sharding(aval.ndim)
for aval, o in zip(ctx.avals_out, out_shardings)]
# TODO(b/228598865): inlined calls cannot have shardings set directly on the
# inputs or outputs because they are lost during MLIR->HLO conversion.
# using_sharding_annotation=False means we add an identity operation instead.
> func = mlir.lower_jaxpr_to_fun(
ctx.module_context, name, jaxpr, effects, arg_shardings=arg_shardings,
result_shardings=result_shardings, use_sharding_annotations=False,
api_name=('jit' if resource_env is None else 'pjit'))
/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py:1395:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
ctx = ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7fbb2812df80>, module=<jax...object at 0x7fbb281097e0>, cached_primitive_lowerings={}, cached_call_jaxpr_lowerings={}, override_lowering_rules=None)
name = 'wrapped'
jaxpr = { lambda ; a:bf16[2,384,2,64] b:bf16[2,384,2,64] c:bf16[2,384,2,64] d:i32[2,384]
e:i32[2,384] f:i32[2,384] g:i32[2... dtype=bfloat16),)
which_linear=(False, False, False, False, False, False, False)
] a b c d e f g
in (h,) }
effects = []
def lower_jaxpr_to_fun(
ctx: ModuleContext,
name: str,
jaxpr: core.ClosedJaxpr,
effects: Sequence[core.Effect],
*,
create_tokens: bool = False,
public: bool = False,
replace_tokens_with_dummy: bool = False,
replicated_args: Optional[Sequence[bool]] = None,
arg_shardings: Optional[Sequence[Optional[xc.HloSharding]]] = None,
result_shardings: Optional[Sequence[Optional[xc.HloSharding]]] = None,
use_sharding_annotations: bool = True,
input_output_aliases: Optional[Sequence[Optional[int]]] = None,
num_output_tokens: int = 0,
api_name: str = "jit",
arg_names: Optional[Sequence[Optional[str]]] = None,
result_names: Optional[Sequence[Optional[str]]] = None,
) -> func_dialect.FuncOp:
"""Lowers jaxpr and its callees to an IR function.
Assumes that an MLIR context, location, and insertion point are set.
Args:
ctx: the lowering context.
name: the function name. The name will be uniquified by the symbol table,
so it is ok to use the same name multiple times.
jaxpr: the jaxpr to lower.
effects: a sequence of `core.Effect`s corresponding to an ordering of tokens
that will be created in or used by the lowered function.
create_tokens: if true, the HLO will create tokens and ignore dummy input tokens.
public: if true, the function's visibility is set to "public".
replace_tokens_with_dummy: if true, token arguments/return values are
replaced with bool arrays of size [0].
replicated_args: if present, annotates arguments as replicated.
arg_shardings: sharding annotations for each argument (optional).
result_shardings: sharding annotations for each result (optional).
use_sharding_annotations: if True, use "mhlo.sharding" annotations on
parameters and return values to express sharding. If False, use
hlo.custom_call operators with sharding annotations.
TODO(b/228598865): remove this option when "mhlo.sharding" annotations are
propagated on non-entry functions during MLIR->HLO conversion.
input_output_aliases: optional sequence that maps argument numbers to the
corresponding output that should alias them.
api_name: The name of the higher level primitive which should show up in the
name stack.
Returns:
MLIR func op
"""
def aval_to_types(aval):
if replace_tokens_with_dummy and aval is core.abstract_token:
aval = core.ShapedArray((), np.dtype(np.bool_))
return aval_to_ir_types(aval)
num_dim_vars = len(ctx.shape_poly_state.dim_vars)
dim_var_avals = [core.ShapedArray((), dtypes.canonicalize_dtype(np.int64))] * num_dim_vars
dim_var_types = map(aval_to_types, dim_var_avals)
# Function inputs: *dim_var_values, *tokens, *actual_inputs
input_types = map(aval_to_types, jaxpr.in_avals)
output_types = map(aval_to_types, jaxpr.out_avals)
num_tokens = len(effects)
if create_tokens:
# If we create the tokens they won't be inputs to the MLIR function.
token_types = [dummy_token_type() for _ in effects]
output_token_types = [dummy_token_type() for _ in range(num_output_tokens)]
else:
# If we aren't creating tokens they will be the initial inputs to the
# MLIR function.
output_token_types = []
token_types = [token_type() for _ in effects]
token_avals = [core.AbstractToken] * num_tokens
# Order of arguments: dim vars, tokens, array inputs
input_avals = dim_var_avals + token_avals + jaxpr.in_avals
input_types = [*dim_var_types, *token_types, *input_types]
output_avals = [core.AbstractToken] * (len(output_token_types) + num_tokens) + jaxpr.out_avals
output_types = [*output_token_types, *token_types, *output_types]
if input_output_aliases is not None:
token_input_output_aliases = [None] * (num_dim_vars + num_tokens)
input_output_aliases = [*token_input_output_aliases, *input_output_aliases]
# Update the existing aliases to account for the new output values
input_output_aliases = [None if a is None
else a + num_output_tokens + num_tokens
for a in input_output_aliases]
if arg_shardings is not None:
token_shardings = [None] * (num_dim_vars + num_tokens)
arg_shardings = [*token_shardings, *arg_shardings]
if result_shardings is not None:
token_shardings = [None] * (num_tokens + num_output_tokens)
result_shardings = [*token_shardings, *result_shardings]
if replicated_args is not None:
token_replicated_args = [False] * (num_dim_vars + num_tokens)
replicated_args = [*token_replicated_args, *replicated_args]
flat_input_types = util.flatten(input_types)
flat_output_types = util.flatten(output_types)
ftype = ir.FunctionType.get(flat_input_types, flat_output_types)
func_op = func_dialect.FuncOp(name, ftype, ip=ctx.ip)
func_op.attributes["sym_visibility"] = ir.StringAttr.get(
"public" if public else "private")
ctx.symbol_table.insert(func_op)
ir_arg_shardings = None
if arg_shardings is not None:
in_avals = [None] * (num_dim_vars + num_tokens) + list(jaxpr.in_avals)
ir_arg_shardings = util.flatten(
[[_to_physical_op_sharding(a, s)] * len(types)
for a, s, types in zip(in_avals, arg_shardings, input_types)])
del in_avals
ir_result_shardings = None
if result_shardings is not None:
out_avals = [None] * (num_tokens + num_output_tokens) + list(jaxpr.out_avals)
ir_result_shardings = util.flatten(
[[_to_physical_op_sharding(a, s)] * len(types)
for a, s, types in zip(out_avals, result_shardings, output_types)])
del out_avals
if (
replicated_args is not None
or ir_arg_shardings is not None
or input_output_aliases is not None
or arg_names is not None
or num_tokens > 0
):
arg_attrs: list[dict[str, ir.Attribute]] = [
{} for _ in range(len(flat_input_types))]
if replicated_args is not None:
replicated_ir_args = [[replicated] * len(types) for replicated, types
in zip(replicated_args, input_types)]
for attrs, replicated in zip(arg_attrs, util.flatten(replicated_ir_args)):
if replicated:
attrs["mhlo.is_same_data_across_replicas"] = ir.UnitAttr.get()
if use_sharding_annotations and ir_arg_shardings is not None:
for attrs, sharding in zip(arg_attrs, ir_arg_shardings):
if sharding is not None:
attrs["mhlo.sharding"] = get_sharding_attr(sharding)
if input_output_aliases is not None:
output_ids = util.unflatten(list(range(len(flat_output_types))),
map(len, output_types))
aliases: list[Optional[int]] = []
for types, alias in zip(input_types, input_output_aliases):
if alias is None:
aliases.extend([None] * len(types))
else:
aliases.extend(output_ids[alias])
for attrs, alias in zip(arg_attrs, aliases):
if alias is not None:
attrs["tf.aliasing_output"] = i32_attr(alias)
if num_tokens > 0:
token_arg_attrs = arg_attrs[num_dim_vars:num_dim_vars + num_tokens]
for attrs in token_arg_attrs:
attrs["jax.token"] = ir.BoolAttr.get(True)
if arg_names:
named_arg_attrs = arg_attrs[num_dim_vars + num_tokens:]
for attrs, name_ in zip(named_arg_attrs, arg_names):
if name_:
attrs['jax.arg_info'] = ir.StringAttr.get(name_)
func_op.arg_attrs = ir.ArrayAttr.get(
[ir.DictAttr.get(attrs) for attrs in arg_attrs])
result_attrs: list[dict[str, ir.Attribute]] = [
{} for _ in range(len(flat_output_types))]
if num_tokens > 0:
token_result_attrs = result_attrs[:num_tokens]
for attrs in token_result_attrs:
attrs["jax.token"] = ir.BoolAttr.get(True)
if result_names:
named_result_attrs = result_attrs[num_tokens:]
if len(named_result_attrs) == len(result_names):
for attrs, name_ in zip(named_result_attrs, result_names):
attrs['jax.result_info'] = ir.StringAttr.get(name_)
if use_sharding_annotations and ir_result_shardings is not None:
for attrs, sharding in zip(result_attrs, ir_result_shardings):
if sharding is not None:
attrs['mhlo.sharding'] = get_sharding_attr(sharding)
func_op.result_attrs = ir.ArrayAttr.get(
[ir.DictAttr.get(attrs) for attrs in result_attrs])
entry_block = func_op.add_entry_block()
with ir.InsertionPoint(entry_block):
flat_args = entry_block.arguments
# We separate out the dimension variable inputs, the token inputs and
# the regular inputs. The dimension variables and token inputs
# will be passed to `jaxpr_subcomp` separately from the `args`.
dim_var_values, _, _ = util.split_list(flat_args, [num_dim_vars, num_tokens])
# A lowering context just for function body entry/exit code.
entry_lowering_ctx = LoweringRuleContext(
ctx, None, [], None, TokenSet.create([]), None, None, dim_var_values)
if not use_sharding_annotations and ir_arg_shardings is not None:
flat_args = [
a if s is None else wrap_with_sharding_op(entry_lowering_ctx, a, a_aval, s)
for a, s, a_aval in zip(flat_args, ir_arg_shardings, input_avals)]
_, token_args, unflattened_args = util.split_list(util.unflatten(flat_args, map(len, input_types)),
[num_dim_vars, num_tokens])
if create_tokens:
tokens_in = TokenSet.create(effects)
else:
tokens_in = TokenSet(zip(effects, token_args))
args: list[list[ir.Value]] = []
for aval, arg in zip(jaxpr.in_avals, unflattened_args):
if replace_tokens_with_dummy and aval is core.abstract_token:
args.append(hlo.CreateTokenOp().results)
else:
args.append(arg)
callee_name_stack = ctx.name_stack.extend(util.wrap_name(name, api_name))
> out_vals, tokens_out = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
jaxpr.jaxpr, tokens_in, map(ir_constants, jaxpr.consts),
*args, dim_var_values=dim_var_values)
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:1060:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
ctx = ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7fbb2812df80>, module=<jax...object at 0x7fbb281097e0>, cached_primitive_lowerings={}, cached_call_jaxpr_lowerings={}, override_lowering_rules=None)
jaxpr = { lambda ; a:bf16[2,384,2,64] b:bf16[2,384,2,64] c:bf16[2,384,2,64] d:i32[2,384]
e:i32[2,384] f:i32[2,384] g:i32[2... dtype=bfloat16),)
which_linear=(False, False, False, False, False, False, False)
] a b c d e f g
in (h,) }
tokens = <jax._src.interpreters.mlir.TokenSet object at 0x7fbb28109ed0>
consts = [], dim_var_values = []
args = ([<jaxlib.mlir._mlir_libs._mlir.ir.BlockArgument object at 0x7fbb2811f770>], [<jaxlib.mlir._mlir_libs._mlir.ir.BlockAr...ckArgument object at 0x7fbb2811d330>], [<jaxlib.mlir._mlir_libs._mlir.ir.BlockArgument object at 0x7fbb2815b430>], ...)
aval = <function jaxpr_subcomp.<locals>.aval at 0x7fbc21d9a8c0>
write = <function jaxpr_subcomp.<locals>.write at 0x7fbc21d9a830>
def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
tokens: TokenSet,
consts: Sequence[Sequence[ir.Value]],
*args: Sequence[ir.Value],
dim_var_values: Sequence[ir.Value]
) -> tuple[Sequence[Sequence[ir.Value]], TokenSet]:
"""Lowers a jaxpr into MLIR, inlined into an existing function.
Assumes that an MLIR context, location, and insertion point are set.
dim_var_values: the list of dimension variable values in the current
IR function, in the order of ctx.shape_poly_state.dim_vars.
"""
assert ctx.platform != "gpu"
def read(v: core.Atom) -> Sequence[ir.Value]:
if type(v) is core.Literal:
return ir_constants(v.val, canonicalize_types=True)
else:
assert isinstance(v, core.Var)
return env[v]
def aval(v: core.Atom) -> core.AbstractValue:
if type(v) is core.Literal:
return xla.abstractify(v.val)
else:
return v.aval
def write(v: core.Var, node: Sequence[ir.Value]):
assert node is not None
env[v] = tuple(node)
def get_lowering(primitive: core.Primitive) -> Optional[LoweringRule]:
if ctx.override_lowering_rules is None:
return None
for p, rule in ctx.override_lowering_rules:
if primitive is p:
return rule
return None
env: dict[core.Var, tuple[ir.Value, ...]] = {}
assert len(args) == len(jaxpr.invars), (jaxpr, args)
assert len(consts) == len(jaxpr.constvars), (jaxpr, consts)
assert all(isinstance(v, ir.Value) for vs in consts for v in vs), consts
assert len(ctx.shape_poly_state.dim_vars) == len(dim_var_values), (ctx.shape_poly_state.dim_vars, dim_var_values)
map(write, jaxpr.constvars, consts)
map(write, jaxpr.invars, args)
last_used = core.last_used(jaxpr)
for eqn in jaxpr.eqns:
in_nodes = map(read, eqn.invars)
assert isinstance(ctx.name_stack, source_info_util.NameStack), type(ctx.name_stack)
source_info = eqn.source_info.replace(
name_stack=ctx.name_stack + eqn.source_info.name_stack)
loc = _source_info_to_location(eqn.primitive, eqn.params, source_info,
ctx.name_stack)
with source_info_util.user_context(eqn.source_info.traceback), loc:
override_rule = get_lowering(eqn.primitive)
if override_rule is not None:
rule = override_rule
elif eqn.primitive in _platform_specific_lowerings[ctx.platform]:
rule = _platform_specific_lowerings[ctx.platform][eqn.primitive]
elif eqn.primitive in xla._backend_specific_translations[ctx.platform]:
rule = xla_fallback_lowering(eqn.primitive)
elif eqn.primitive in _lowerings:
rule = _lowerings[eqn.primitive]
elif eqn.primitive in xla._translations:
rule = xla_fallback_lowering(eqn.primitive)
else:
raise NotImplementedError(
f"MLIR translation rule for primitive '{eqn.primitive.name}' not "
f"found for platform {ctx.platform}")
eqn_ctx = ctx.replace(name_stack=source_info.name_stack)
effects = list(effects_lib.ordered_effects.filter_in(eqn.effects))
tokens_in = tokens.subset(effects)
avals_in = map(aval, eqn.invars)
rule_ctx = LoweringRuleContext(
module_context=eqn_ctx, primitive=eqn.primitive, avals_in=avals_in,
avals_out=map(aval, eqn.outvars), tokens_in=tokens_in,
tokens_out=None, dim_var_values=dim_var_values)
if config.jax_dynamic_shapes:
axis_size_env = {d: read(d)[0]
for a in avals_in if type(a) is core.DShapedArray
for d in a.shape if type(d) is core.Var}
rule_ctx = rule_ctx.replace(axis_size_env=axis_size_env)
> ans = rule(rule_ctx, *map(_unwrap_singleton_ir_values, in_nodes),
**eqn.params)
/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py:1217:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
ctx = LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context obj...<jax._src.interpreters.mlir.TokenSet object at 0x7fbb2810ad70>, tokens_out=None, axis_size_env=None, dim_var_values=[])
jaxpr = { lambda ; a:Ref{bfloat16[384,64]} b:Ref{bfloat16[384,64]} c:Ref{bfloat16[384,64]}
d:Ref{int32[384]} e:Ref{int32[3...] bl
ea:bool[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] p
h[m:m+128,:] <- dz
in () }
name = 'mha_forward'
in_shapes = (ShapeDtypeStruct(shape=(2, 384, 2, 64), dtype=bfloat16), ShapeDtypeStruct(shape=(2, 384, 2, 64), dtype=bfloat16), Sha... 384), dtype=int32), ShapeDtypeStruct(shape=(2, 384), dtype=int32), ShapeDtypeStruct(shape=(2, 384), dtype=int32), ...)
out_shapes = (ShapeDtypeStruct(shape=(2, 384, 2, 64), dtype=bfloat16),)
which_linear = (False, False, False, False, False, False, ...)
interpret = False, debug = False, input_output_aliases = ()
grid_mapping = GridMapping(grid=(3, 2, 2), block_mappings=(BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2...4), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) })), mapped_dims=(), num_index_operands=0)
in_nodes = (<jaxlib.mlir._mlir_libs._mlir.ir.BlockArgument object at 0x7fbb2811f770>, <jaxlib.mlir._mlir_libs._mlir.ir.BlockArgum...BlockArgument object at 0x7fbb2811d330>, <jaxlib.mlir._mlir_libs._mlir.ir.BlockArgument object at 0x7fbb2815b430>, ...)
compiler_params = {'num_stages': 2, 'num_warps': 4}, num_warps = 4
def pallas_call_lowering(
ctx: mlir.LoweringRuleContext,
*in_nodes,
jaxpr: jax_core.Jaxpr,
name: str,
in_shapes: Tuple[jax.ShapeDtypeStruct, ...],
out_shapes: Tuple[jax.ShapeDtypeStruct, ...],
which_linear: Tuple[bool, ...],
interpret: bool,
debug: bool,
input_output_aliases: Tuple[Tuple[int, int], ...],
grid_mapping: GridMapping,
**compiler_params: Any
):
if interpret:
return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
ctx,
*in_nodes,
jaxpr=jaxpr,
name=name,
out_shapes=out_shapes,
in_shapes=in_shapes,
which_linear=which_linear,
interpret=interpret,
debug=debug,
input_output_aliases=input_output_aliases,
grid_mapping=grid_mapping,
**compiler_params
)
num_warps = compiler_params.get("num_warps", 4)
num_stages = compiler_params.get("num_stages", 3)
if debug:
print(jaxpr)
print(grid_mapping)
> compilation_result = compile_jaxpr(
jaxpr,
tuple((*in_shapes, *out_shapes)),
grid_mapping,
name,
num_warps,
num_stages,
debug=debug,
)
/usr/local/lib/python3.10/dist-packages/jax_triton/pallas/triton_lowering.py:1661:
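Note: the pallas_call_lowering frame above only reaches compile_jaxpr (where this failure is raised) when interpret=False; with interpret=True it falls back to mlir.lower_fun(pallas_call_p.impl, ...) and skips the Triton lowering entirely. Below is a minimal sketch of that interpreter fallback using a toy kernel, not the test's mha_forward kernel; the kernel body, shapes, and the pl alias are illustrative assumptions rather than anything taken from the failing test.

import jax
import jax.numpy as jnp
from jax_triton import pallas as pl  # same Pallas frontend whose Triton lowering fails above

def add_kernel(x_ref, y_ref, o_ref):
    # Toy kernel body: element-wise add, just enough to exercise pallas_call.
    o_ref[...] = x_ref[...] + y_ref[...]

x = jnp.arange(8, dtype=jnp.float32)
y = jnp.ones_like(x)
out = pl.pallas_call(
    add_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,  # takes the mlir.lower_fun(pallas_call_p.impl, ...) branch shown above
)(x, y)

Interpret mode emulates the kernel with ordinary JAX operations, so it is a debugging aid only; it does not exercise the compile_jaxpr / Triton path that raises the TritonLoweringException below.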
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxpr = { lambda ; a:Ref{bfloat16[384,64]} b:Ref{bfloat16[384,64]} c:Ref{bfloat16[384,64]}
d:Ref{int32[384]} e:Ref{int32[3...] bl
ea:bool[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] p
h[m:m+128,:] <- dz
in () }
in_shapes = (ShapeDtypeStruct(shape=(2, 384, 2, 64), dtype=bfloat16), ShapeDtypeStruct(shape=(2, 384, 2, 64), dtype=bfloat16), Sha... 384), dtype=int32), ShapeDtypeStruct(shape=(2, 384), dtype=int32), ShapeDtypeStruct(shape=(2, 384), dtype=int32), ...)
grid_mapping = GridMapping(grid=(3, 2, 2), block_mappings=(BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2...4), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) })), mapped_dims=(), num_index_operands=0)
name = 'mha_forward', num_warps = 4, num_stages = 2, debug = False
@weakref_lru_cache
def compile_jaxpr(
jaxpr: jax_core.Jaxpr,
in_shapes,
grid_mapping: GridMapping,
name: str,
num_warps: int,
num_stages: int,
debug: bool,
) -> TritonCompilationResult:
> lowering_result = lower_jaxpr_to_triton_module(
jaxpr, in_shapes, grid_mapping, name
)
/usr/local/lib/python3.10/dist-packages/jax_triton/pallas/triton_lowering.py:1610:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
jaxpr = { lambda ; a:Ref{bfloat16[384,64]} b:Ref{bfloat16[384,64]} c:Ref{bfloat16[384,64]}
d:Ref{int32[384]} e:Ref{int32[3...] bl
ea:bool[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] p
h[m:m+128,:] <- dz
in () }
in_shapes = (ShapeDtypeStruct(shape=(2, 384, 2, 64), dtype=bfloat16), ShapeDtypeStruct(shape=(2, 384, 2, 64), dtype=bfloat16), Sha... 384), dtype=int32), ShapeDtypeStruct(shape=(2, 384), dtype=int32), ShapeDtypeStruct(shape=(2, 384), dtype=int32), ...)
grid_mapping = GridMapping(grid=(3, 2, 2), block_mappings=(BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2...4), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) })), mapped_dims=(), num_index_operands=0)
name = 'mha_forward'
def lower_jaxpr_to_triton_module(
jaxpr: jax_core.Jaxpr, in_shapes, grid_mapping: GridMapping, name: str
) -> tl_ir.module:
jaxpr, _ = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars), instantiate=True)
ir_context = tl_ir.context()
ir_context.load_triton()
builder = tl_ir.builder(ir_context)
# TODO(sharadmv): handle multiple devices; right now we assume device 0,
# which is fine when we have multiples of the same GPU, but this won't work
# in general.
device = 0
builder.arch = triton_kernel_call_lib.get_compute_capability(device)
module = builder.create_module()
in_avals = [var.aval for var in jaxpr.invars]
triton_types = [get_triton_type(x) for x in in_avals]
arg_types = [code_gen.str_to_ty(arg) for arg in triton_types]
assert len(jaxpr.outvars) == 0
prototype = tl.function_type([], arg_types)
out = prototype.to_ir(builder)
fn = builder.get_or_insert_function(module, name, out, "public", False)
module.push_back(fn)
entry = fn.add_entry_block()
args = []
for i in range(len(in_avals)):
fn.set_arg_attr(i, "tt.divisibility", 16)
ptr = tl.tensor(fn.args(i), prototype.param_types[i])
args.append(ptr)
builder.set_insertion_point_to_start(entry)
new_grid, program_ids = _process_grid_to_3d_grid(builder, grid_mapping)
local_program_ids = [
pid for i, pid in enumerate(program_ids) if i not in grid_mapping.mapped_dims
]
ctx = TritonModuleContext(
name, ir_context, builder, module, grid_mapping, local_program_ids
)
if grid_mapping.num_index_operands:
raise NotImplementedError(
"Scalar prefetch not supported in Triton lowering.")
start_indices = map(
partial(_eval_index_map, ctx, program_ids), grid_mapping.block_mappings
)
block_infos = [
BlockInfo(
jax.ShapeDtypeStruct(shape_dtype.shape, shape_dtype.dtype),
start_idx,
block_mapping.block_shape,
)
if block_mapping is not None
else None
for shape_dtype, block_mapping, start_idx in zip(
in_shapes, grid_mapping.block_mappings, start_indices
)
]
> () = lower_jaxpr_to_triton_ir(ctx, jaxpr, block_infos, *args)
/usr/local/lib/python3.10/dist-packages/jax_triton/pallas/triton_lowering.py:222:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
ctx = TritonModuleContext(name='mha_forward', ir_context=<triton._C.libtriton.triton.ir.context object at 0x7fbc21d4bef0>, b...9660>, <triton.language.core.tensor object at 0x7fbb28199990>, <triton.language.core.tensor object at 0x7fbb28199b10>])
jaxpr = { lambda ; a:Ref{bfloat16[384,64]} b:Ref{bfloat16[384,64]} c:Ref{bfloat16[384,64]}
d:Ref{int32[384]} e:Ref{int32[3...] bl
ea:bool[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] p
h[m:m+128,:] <- dz
in () }
block_infos = [BlockInfo(full_shape_dtype=ShapeDtypeStruct(shape=(2, 384, 2, 64), dtype=bfloat16), start_indices=(<triton.language.c...e.tensor object at 0x7fbb28199db0>), block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384)), ...]
args = (<triton.language.core.tensor object at 0x7fbb28150c70>, <triton.language.core.tensor object at 0x7fbb281510f0>, <trit...>, <triton.language.core.tensor object at 0x7fbb28198460>, <triton.language.core.tensor object at 0x7fbb28198b20>, ...)
read_env = <function lower_jaxpr_to_triton_ir.<locals>.read_env at 0x7fbc21d9ab00>
read_block_info_env = <function lower_jaxpr_to_triton_ir.<locals>.read_block_info_env at 0x7fbc21d9ae60>
write_env = <function lower_jaxpr_to_triton_ir.<locals>.write_env at 0x7fbc21d9aef0>
def lower_jaxpr_to_triton_ir(
ctx: TritonModuleContext,
jaxpr: jax_core.Jaxpr,
block_infos: Optional[Sequence[Optional[BlockInfo]]],
*args
) -> Sequence[Any]:
env = {}
block_info_env = {}
def read_env(var: jax_core.Atom):
if type(var) is jax_core.Literal:
t = tl.core._to_tensor(np.array(var.val).tolist(), builder=ctx.builder)
dst_ty = code_gen.str_to_ty(get_triton_type(var.aval)).element_ty
if t.type.scalar != dst_ty:
# _to_tensor(np.array(var.val).tolist()) can be lossy e.g. np.float64
# comes out of .tolist() as list[float], which then comes out of
# _to_tensor as a block of f32.
t = tl.semantic.cast(t, dst_ty, ctx.builder)
return t
return env[var]
def read_block_info_env(var: jax_core.Atom):
if type(var) is jax_core.Literal:
return None
return block_info_env.get(var, None)
def write_env(var: jax_core.Var, val):
env[var] = val
if block_infos is None:
block_infos = [None] * len(jaxpr.invars)
for invar, block_info in zip(jaxpr.invars, block_infos):
block_info_env[invar] = block_info
map(write_env, jaxpr.invars, args)
for eqn in jaxpr.eqns:
invals = map(read_env, eqn.invars)
if eqn.primitive not in triton_lowering_rules:
raise NotImplementedError(eqn.primitive)
rule = triton_lowering_rules[eqn.primitive]
avals_in = [v.aval for v in eqn.invars]
avals_out = [v.aval for v in eqn.outvars]
eqn_block_infos = map(read_block_info_env, eqn.invars)
rule_ctx = TritonLoweringRuleContext(
ctx, avals_in, avals_out, eqn_block_infos
)
try:
outvals = rule(rule_ctx, *invals, **eqn.params)
except TritonLoweringException:
raise # We only add the extra info to the innermost exception.
except Exception as e:
> raise TritonLoweringException(
f"Exception while lowering eqn:\n {eqn}\n"
f"With context:\n {rule_ctx}\n"
f"With inval shapes={map(lambda t: t.shape, invals)}\n"
f"With inval types={map(lambda t: t.type, invals)}\n"
f"In jaxpr:\n{jaxpr}") from e
E jax._src.traceback_util.UnfilteredStackTrace: jax_triton.pallas.triton_lowering.TritonLoweringException: Exception while lowering eqn:
E a:i32[6] = scan[
E jaxpr={ lambda ; b:Ref{int32[384]} c:i32[]. let
E d:i32[] = mul c 64
E e:i32[64] = iota[dimension=0 dtype=int32 shape=(64,)]
E f:i32[64] = add e d
E g:bool[64] = lt f 384
E h:i32[64] <- b[d:d+64]
E i:i32[] = reduce_min[axes=(0,)] h
E in (i,) }
E length=6
E linear=(False, False)
E num_carry=0
E num_consts=1
E reverse=False
E unroll=1
E ] j k
E With context:
E TritonLoweringRuleContext(context=TritonModuleContext(name='mha_forward', ir_context=<triton._C.libtriton.triton.ir.context object at 0x7fbc21d4bef0>, builder=<triton._C.libtriton.triton.ir.builder object at 0x7fbb28171e90>, module=<triton._C.libtriton.triton.ir.module object at 0x7fbb28172110>, grid_mapping=GridMapping(grid=(3, 2, 2), block_mappings=(BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384, <jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 64), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384, <jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 64), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384, <jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 64), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384, <jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 64), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) })), mapped_dims=(), num_index_operands=0), program_ids=[<triton.language.core.tensor object at 0x7fbb28199660>, <triton.language.core.tensor object at 0x7fbb28199990>, <triton.language.core.tensor object at 0x7fbb28199b10>]), avals_in=[Ref{int32[384]}, ShapedArray(int32[6])], avals_out=[ShapedArray(int32[6])], block_infos=[BlockInfo(full_shape_dtype=ShapeDtypeStruct(shape=(2, 384), dtype=int32), start_indices=(<triton.language.core.tensor object at 0x7fbb28199990>, <triton.language.core.tensor object at 0x7fbb28199db0>), block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384)), None])
E With inval shapes=[[constexpr[1]], [constexpr[6]]]
E With inval types=[pointer<int32>, <[6], int32>]
E In jaxpr:
E { lambda ; a:Ref{bfloat16[384,64]} b:Ref{bfloat16[384,64]} c:Ref{bfloat16[384,64]}
E d:Ref{int32[384]} e:Ref{int32[384]} f:Ref{int32[384]} g:Ref{int32[384]} h:Ref{bfloat16[384,64]}. let
E i:i32[] = program_id[axis=0]
E j:f32[128,64] = broadcast_in_dim[broadcast_dimensions=() shape=(128, 64)] 0.0
E k:f32[128] = broadcast_in_dim[broadcast_dimensions=() shape=(128,)] -1.0000000200408773e+20
E l:f32[128] = broadcast_in_dim[broadcast_dimensions=() shape=(128,)] 0.0
E m:i32[] = mul i 128
E n:i32[128] = iota[dimension=0 dtype=int32 shape=(128,)]
E o:i32[128] = add n m
E p:bool[128] = lt o 384
E q:bf16[128,64] <- a[m:m+128,:]
E r:i32[128] <- d[m:m+128]
E s:i32[128] <- e[m:m+128]
E t:bf16[128,64] = mul q 0.125
E u:i32[6] = iota[dimension=0 dtype=int32 shape=(6,)]
E v:i32[6] = scan[
E jaxpr={ lambda ; w:Ref{int32[384]} x:i32[]. let
E y:i32[] = mul x 64
E z:i32[64] = iota[dimension=0 dtype=int32 shape=(64,)]
E ba:i32[64] = add z y
E bb:bool[64] = lt ba 384
E bc:i32[64] <- w[y:y+64]
E bd:i32[] = reduce_min[axes=(0,)] bc
E in (bd,) }
E length=6
E linear=(False, False)
E num_carry=0
E num_consts=1
E reverse=False
E unroll=1
E ] f u
E be:i32[] = reduce_max[axes=(0,)] r
E bf:i32[6] = iota[dimension=0 dtype=int32 shape=(6,)]
E bg:bool[6] = ge be v
E bh:i32[6] = convert_element_type[new_dtype=int32 weak_type=False] bg
E bi:i32[6] = mul bf bh
E bj:i32[] = argmax[axes=(0,) index_dtype=int32] bi
E bk:i32[] = add bj 1
E _:i32[] _:i32[] bl:f32[128,64] _:f32[128] _:f32[128] = while[
E body_jaxpr={ lambda ; bm:Ref{bfloat16[384,64]} bn:Ref{int32[384]} bo:Ref{int32[384]}
E bp:bf16[128,64] bq:i32[128] br:i32[128] bs:Ref{bfloat16[384,64]} bt:i32[]
E bu:i32[] bv:f32[128,64] bw:f32[128] bx:f32[128]. let
E by:i32[] = add bt 1
E bz:i32[] = mul bt 64
E ca:i32[64] = iota[dimension=0 dtype=int32 shape=(64,)]
E cb:i32[] = convert_element_type[new_dtype=int32 weak_type=False] bz
E cc:i32[64] = add ca cb
E cd:bool[64] = lt cc 384
E ce:bf16[64,64] <- bm[bz:bz+64,:]
E cf:i32[64] <- bn[bz:bz+64]
E cg:i32[64] <- bo[bz:bz+64]
E ch:bf16[64,64] = transpose[permutation=(1, 0)] ce
E ci:bf16[128,64] = dot_general[
E dimension_numbers=(([1], [0]), ([], []))
E ] bp ch
E cj:f32[128,64] = convert_element_type[
E new_dtype=float32
E weak_type=False
E ] ci
E ck:i32[128,1] = broadcast_in_dim[
E broadcast_dimensions=(0,)
E shape=(128, 1)
E ] bq
E cl:i32[1,64] = broadcast_in_dim[
E broadcast_dimensions=(1,)
E shape=(1, 64)
E ] cg
E cm:bool[128,64] = eq ck cl
E cn:i32[128,1] = broadcast_in_dim[
E broadcast_dimensions=(0,)
E shape=(128, 1)
E ] br
E co:i32[1,64] = broadcast_in_dim[
E broadcast_dimensions=(1,)
E shape=(1, 64)
E ] cf
E cp:bool[128,64] = ge cn co
E cq:bool[128,64] = and cm cp
E cr:f32[128,64] = pjit[
E jaxpr={ lambda ; cs:bool[128,64] ct:f32[128,64] cu:f32[]. let
E cv:f32[] = convert_element_type[
E new_dtype=float32
E weak_type=False
E ] cu
E cw:f32[128,64] = broadcast_in_dim[
E broadcast_dimensions=()
E shape=(128, 64)
E ] cv
E cx:f32[128,64] = select_n cs cw ct
E in (cx,) }
E name=_where
E ] cq cj -1e+20
E cy:f32[128] = reduce_max[axes=(1,)] cr
E cz:f32[128] = max cy bw
E da:f32[128] = sub bw cz
E db:f32[128] = exp da
E dc:f32[128] = mul bx db
E dd:f32[128,1] = broadcast_in_dim[
E broadcast_dimensions=(0,)
E shape=(128, 1)
E ] cz
E de:f32[128,64] = sub cr dd
E df:f32[128,64] = exp de
E dg:f32[128] = reduce_sum[axes=(1,)] df
E dh:f32[128] = add dg dc
E di:f32[128] = div 1.0 dh
E dj:f32[128,1] = broadcast_in_dim[
E broadcast_dimensions=(0,)
E shape=(128, 1)
E ] di
E dk:f32[128,64] = mul df dj
E dl:f32[128] = mul dc di
E dm:f32[128,1] = broadcast_in_dim[
E broadcast_dimensions=(0,)
E shape=(128, 1)
E ] dl
E dn:f32[128,64] = mul bv dm
E do:bf16[64,64] <- bs[bz:bz+64,:]
E dp:bf16[128,64] = convert_element_type[
E new_dtype=bfloat16
E weak_type=False
E ] dk
E dq:bf16[128,64] = dot_general[
E dimension_numbers=(([1], [0]), ([], []))
E ] dp do
E dr:f32[128,64] = convert_element_type[
E new_dtype=float32
E weak_type=False
E ] dq
E ds:f32[128,64] = add dn dr
E in (by, bu, ds, cz, dh) }
E body_nconsts=7
E cond_jaxpr={ lambda ; dt:i32[] du:i32[] dv:f32[128,64] dw:f32[128] dx:f32[128]. let
E dy:bool[] = lt dt du
E in (dy,) }
E cond_nconsts=0
E ] b f g t s r c 0 bk j k l
E dz:bf16[128,64] = convert_element_type[new_dtype=bfloat16 weak_type=False] bl
E ea:bool[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] p
E h[m:m+128,:] <- dz
E in () }
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
/usr/local/lib/python3.10/dist-packages/jax_triton/pallas/triton_lowering.py:278: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
query_seq_len = 384, kv_seq_len = 384
@pytest.mark.gpu
@pytest.mark.parametrize(
"query_seq_len,kv_seq_len",
[
(384, 384),
(400, 400),
(384, 256),
(384, 300),
(400, 256),
(400, 300),
(1, 256),
(1, 300),
(256, 1),
(300, 1),
(1, 3),
],
)
def test_triton_flash_attention(query_seq_len, kv_seq_len):
q, k, v, q_pos, q_sid, kv_pos, kv_sid = _make_attention_inputs(
query_seq_len=query_seq_len,
kv_seq_len=kv_seq_len,
)
> result_triton = triton_multi_head_attention(
q,
k,
v,
q_pos,
q_sid,
kv_pos,
kv_sid,
)
tests/test_attention.py:92:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
/usr/local/lib/python3.10/dist-packages/jax_triton/pallas/triton_lowering.py:1661: in pallas_call_lowering
compilation_result = compile_jaxpr(
/usr/local/lib/python3.10/dist-packages/jax_triton/pallas/triton_lowering.py:1610: in compile_jaxpr
lowering_result = lower_jaxpr_to_triton_module(
/usr/local/lib/python3.10/dist-packages/jax_triton/pallas/triton_lowering.py:222: in lower_jaxpr_to_triton_module
() = lower_jaxpr_to_triton_ir(ctx, jaxpr, block_infos, *args)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
ctx = TritonModuleContext(name='mha_forward', ir_context=<triton._C.libtriton.triton.ir.context object at 0x7fbc21d4bef0>, b...9660>, <triton.language.core.tensor object at 0x7fbb28199990>, <triton.language.core.tensor object at 0x7fbb28199b10>])
jaxpr = { lambda ; a:Ref{bfloat16[384,64]} b:Ref{bfloat16[384,64]} c:Ref{bfloat16[384,64]}
d:Ref{int32[384]} e:Ref{int32[3...] bl
ea:bool[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] p
h[m:m+128,:] <- dz
in () }
block_infos = [BlockInfo(full_shape_dtype=ShapeDtypeStruct(shape=(2, 384, 2, 64), dtype=bfloat16), start_indices=(<triton.language.c...e.tensor object at 0x7fbb28199db0>), block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384)), ...]
args = (<triton.language.core.tensor object at 0x7fbb28150c70>, <triton.language.core.tensor object at 0x7fbb281510f0>, <trit...>, <triton.language.core.tensor object at 0x7fbb28198460>, <triton.language.core.tensor object at 0x7fbb28198b20>, ...)
read_env = <function lower_jaxpr_to_triton_ir.<locals>.read_env at 0x7fbc21d9ab00>
read_block_info_env = <function lower_jaxpr_to_triton_ir.<locals>.read_block_info_env at 0x7fbc21d9ae60>
write_env = <function lower_jaxpr_to_triton_ir.<locals>.write_env at 0x7fbc21d9aef0>
def lower_jaxpr_to_triton_ir(
ctx: TritonModuleContext,
jaxpr: jax_core.Jaxpr,
block_infos: Optional[Sequence[Optional[BlockInfo]]],
*args
) -> Sequence[Any]:
env = {}
block_info_env = {}
def read_env(var: jax_core.Atom):
if type(var) is jax_core.Literal:
t = tl.core._to_tensor(np.array(var.val).tolist(), builder=ctx.builder)
dst_ty = code_gen.str_to_ty(get_triton_type(var.aval)).element_ty
if t.type.scalar != dst_ty:
# _to_tensor(np.array(var.val).tolist()) can be lossy e.g. np.float64
# comes out of .tolist() as list[float], which then comes out of
# _to_tensor as a block of f32.
t = tl.semantic.cast(t, dst_ty, ctx.builder)
return t
return env[var]
def read_block_info_env(var: jax_core.Atom):
if type(var) is jax_core.Literal:
return None
return block_info_env.get(var, None)
def write_env(var: jax_core.Var, val):
env[var] = val
if block_infos is None:
block_infos = [None] * len(jaxpr.invars)
for invar, block_info in zip(jaxpr.invars, block_infos):
block_info_env[invar] = block_info
map(write_env, jaxpr.invars, args)
for eqn in jaxpr.eqns:
invals = map(read_env, eqn.invars)
if eqn.primitive not in triton_lowering_rules:
raise NotImplementedError(eqn.primitive)
rule = triton_lowering_rules[eqn.primitive]
avals_in = [v.aval for v in eqn.invars]
avals_out = [v.aval for v in eqn.outvars]
eqn_block_infos = map(read_block_info_env, eqn.invars)
rule_ctx = TritonLoweringRuleContext(
ctx, avals_in, avals_out, eqn_block_infos
)
try:
outvals = rule(rule_ctx, *invals, **eqn.params)
except TritonLoweringException:
raise # We only add the extra info to the innermost exception.
except Exception as e:
> raise TritonLoweringException(
f"Exception while lowering eqn:\n {eqn}\n"
f"With context:\n {rule_ctx}\n"
f"With inval shapes={map(lambda t: t.shape, invals)}\n"
f"With inval types={map(lambda t: t.type, invals)}\n"
f"In jaxpr:\n{jaxpr}") from e
E jax_triton.pallas.triton_lowering.TritonLoweringException: Exception while lowering eqn:
E a:i32[6] = scan[
E jaxpr={ lambda ; b:Ref{int32[384]} c:i32[]. let
E d:i32[] = mul c 64
E e:i32[64] = iota[dimension=0 dtype=int32 shape=(64,)]
E f:i32[64] = add e d
E g:bool[64] = lt f 384
E h:i32[64] <- b[d:d+64]
E i:i32[] = reduce_min[axes=(0,)] h
E in (i,) }
E length=6
E linear=(False, False)
E num_carry=0
E num_consts=1
E reverse=False
E unroll=1
E ] j k
E With context:
E TritonLoweringRuleContext(context=TritonModuleContext(name='mha_forward', ir_context=<triton._C.libtriton.triton.ir.context object at 0x7fbc21d4bef0>, builder=<triton._C.libtriton.triton.ir.builder object at 0x7fbb28171e90>, module=<triton._C.libtriton.triton.ir.module object at 0x7fbb28172110>, grid_mapping=GridMapping(grid=(3, 2, 2), block_mappings=(BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384, <jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 64), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384, <jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 64), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384, <jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 64), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0) }), BlockMapping(block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384, <jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 64), index_map_jaxpr={ lambda ; a:i32[] b:i32[] c:i32[]. let in (b, 0, c, 0) })), mapped_dims=(), num_index_operands=0), program_ids=[<triton.language.core.tensor object at 0x7fbb28199660>, <triton.language.core.tensor object at 0x7fbb28199990>, <triton.language.core.tensor object at 0x7fbb28199b10>]), avals_in=[Ref{int32[384]}, ShapedArray(int32[6])], avals_out=[ShapedArray(int32[6])], block_infos=[BlockInfo(full_shape_dtype=ShapeDtypeStruct(shape=(2, 384), dtype=int32), start_indices=(<triton.language.core.tensor object at 0x7fbb28199990>, <triton.language.core.tensor object at 0x7fbb28199db0>), block_shape=(<jax_triton.pallas.core.Mapped object at 0x7fbc2a3d4f10>, 384)), None])
E With inval shapes=[[constexpr[1]], [constexpr[6]]]
E With inval types=[pointer<int32>, <[6], int32>]
E In jaxpr:
E { lambda ; a:Ref{bfloat16[384,64]} b:Ref{bfloat16[384,64]} c:Ref{bfloat16[384,64]}
E d:Ref{int32[384]} e:Ref{int32[384]} f:Ref{int32[384]} g:Ref{int32[384]} h:Ref{bfloat16[384,64]}. let
E i:i32[] = program_id[axis=0]
E j:f32[128,64] = broadcast_in_dim[broadcast_dimensions=() shape=(128, 64)] 0.0
E k:f32[128] = broadcast_in_dim[broadcast_dimensions=() shape=(128,)] -1.0000000200408773e+20
E l:f32[128] = broadcast_in_dim[broadcast_dimensions=() shape=(128,)] 0.0
E m:i32[] = mul i 128
E n:i32[128] = iota[dimension=0 dtype=int32 shape=(128,)]
E o:i32[128] = add n m
E p:bool[128] = lt o 384
E q:bf16[128,64] <- a[m:m+128,:]
E r:i32[128] <- d[m:m+128]
E s:i32[128] <- e[m:m+128]
E t:bf16[128,64] = mul q 0.125
E u:i32[6] = iota[dimension=0 dtype=int32 shape=(6,)]
E v:i32[6] = scan[
E jaxpr={ lambda ; w:Ref{int32[384]} x:i32[]. let
E y:i32[] = mul x 64
E z:i32[64] = iota[dimension=0 dtype=int32 shape=(64,)]
E ba:i32[64] = add z y
E bb:bool[64] = lt ba 384
E bc:i32[64] <- w[y:y+64]
E bd:i32[] = reduce_min[axes=(0,)] bc
E in (bd,) }
E length=6
E linear=(False, False)
E num_carry=0
E num_consts=1
E reverse=False
E unroll=1
E ] f u
E be:i32[] = reduce_max[axes=(0,)] r
E bf:i32[6] = iota[dimension=0 dtype=int32 shape=(6,)]
E bg:bool[6] = ge be v
E bh:i32[6] = convert_element_type[new_dtype=int32 weak_type=False] bg
E bi:i32[6] = mul bf bh
E bj:i32[] = argmax[axes=(0,) index_dtype=int32] bi
E bk:i32[] = add bj 1
E _:i32[] _:i32[] bl:f32[128,64] _:f32[128] _:f32[128] = while[
E body_jaxpr={ lambda ; bm:Ref{bfloat16[384,64]} bn:Ref{int32[384]} bo:Ref{int32[384]}
E bp:bf16[128,64] bq:i32[128] br:i32[128] bs:Ref{bfloat16[384,64]} bt:i32[]
E bu:i32[] bv:f32[128,64] bw:f32[128] bx:f32[128]. let
E by:i32[] = add bt 1
E bz:i32[] = mul bt 64
E ca:i32[64] = iota[dimension=0 dtype=int32 shape=(64,)]
E cb:i32[] = convert_element_type[new_dtype=int32 weak_type=False] bz
E cc:i32[64] = add ca cb
E cd:bool[64] = lt cc 384
E ce:bf16[64,64] <- bm[bz:bz+64,:]
E cf:i32[64] <- bn[bz:bz+64]
E cg:i32[64] <- bo[bz:bz+64]
E ch:bf16[64,64] = transpose[permutation=(1, 0)] ce
E ci:bf16[128,64] = dot_general[
E dimension_numbers=(([1], [0]), ([], []))
E ] bp ch
E cj:f32[128,64] = convert_element_type[
E new_dtype=float32
E weak_type=False
E ] ci
E ck:i32[128,1] = broadcast_in_dim[
E broadcast_dimensions=(0,)
E shape=(128, 1)
E ] bq
E cl:i32[1,64] = broadcast_in_dim[
E broadcast_dimensions=(1,)
E shape=(1, 64)
E ] cg
E cm:bool[128,64] = eq ck cl
E cn:i32[128,1] = broadcast_in_dim[
E broadcast_dimensions=(0,)
E shape=(128, 1)
E ] br
E co:i32[1,64] = broadcast_in_dim[
E broadcast_dimensions=(1,)
E shape=(1, 64)
E ] cf
E cp:bool[128,64] = ge cn co
E cq:bool[128,64] = and cm cp
E cr:f32[128,64] = pjit[
E jaxpr={ lambda ; cs:bool[128,64] ct:f32[128,64] cu:f32[]. let
E cv:f32[] = convert_element_type[
E new_dtype=float32
E weak_type=False
E ] cu
E cw:f32[128,64] = broadcast_in_dim[
E broadcast_dimensions=()
E shape=(128, 64)
E ] cv
E cx:f32[128,64] = select_n cs cw ct
E in (cx,) }
E name=_where
E ] cq cj -1e+20
E cy:f32[128] = reduce_max[axes=(1,)] cr
E cz:f32[128] = max cy bw
E da:f32[128] = sub bw cz
E db:f32[128] = exp da
E dc:f32[128] = mul bx db
E dd:f32[128,1] = broadcast_in_dim[
E broadcast_dimensions=(0,)
E shape=(128, 1)
E ] cz
E de:f32[128,64] = sub cr dd
E df:f32[128,64] = exp de
E dg:f32[128] = reduce_sum[axes=(1,)] df
E dh:f32[128] = add dg dc
E di:f32[128] = div 1.0 dh
E dj:f32[128,1] = broadcast_in_dim[
E broadcast_dimensions=(0,)
E shape=(128, 1)
E ] di
E dk:f32[128,64] = mul df dj
E dl:f32[128] = mul dc di
E dm:f32[128,1] = broadcast_in_dim[
E broadcast_dimensions=(0,)
E shape=(128, 1)
E ] dl
E dn:f32[128,64] = mul bv dm
E do:bf16[64,64] <- bs[bz:bz+64,:]
E dp:bf16[128,64] = convert_element_type[
E new_dtype=bfloat16
E weak_type=False
E ] dk
E dq:bf16[128,64] = dot_general[
E dimension_numbers=(([1], [0]), ([], []))
E ] dp do
E dr:f32[128,64] = convert_element_type[
E new_dtype=float32
E weak_type=False
E ] dq
E ds:f32[128,64] = add dn dr
E in (by, bu, ds, cz, dh) }
E body_nconsts=7
E cond_jaxpr={ lambda ; dt:i32[] du:i32[] dv:f32[128,64] dw:f32[128] dx:f32[128]. let
E dy:bool[] = lt dt du
E in (dy,) }
E cond_nconsts=0
E ] b f g t s r c 0 bk j k l
E dz:bf16[128,64] = convert_element_type[new_dtype=bfloat16 weak_type=False] bl
E ea:bool[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] p
E h[m:m+128,:] <- dz
E in () }
/usr/local/lib/python3.10/dist-packages/jax_triton/pallas/triton_lowering.py:278: TritonLoweringException
=========================== short test summary info ============================
FAILED tests/test_attention.py::test_triton_flash_attention[384-384] - jax_tr...
============================== 1 failed in 13.12s ==============================
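The equation that fails to lower (see "Exception while lowering eqn" above) is a jax.lax.scan whose body loads 64-element blocks from a Ref; the Ref reaches the Triton lowering rule as a scan constant of type pointer<int32> (see "With inval types=[pointer<int32>, <[6], int32>]"). The sketch below is a minimal, hypothetical reproduction of that pattern in a standalone Pallas kernel. The names (block_min_kernel, pos_ref, block_min) and the out_shape are illustrative only and are not taken from tests/test_attention.py or the mha_forward kernel.

    import jax
    import jax.numpy as jnp
    from jax_triton import pallas as pl

    BLOCK = 64
    SEQ_LEN = 384
    NUM_BLOCKS = SEQ_LEN // BLOCK  # 6, matching length=6 in the failing scan

    def block_min_kernel(pos_ref, out_ref):
        # Scan over block indices, loading one 64-element slice of the Ref per
        # step. pos_ref is closed over by the body, so scan treats it as a
        # constant (num_consts=1 in the dumped eqn), and it reaches the Triton
        # lowering as a pointer<int32> operand.
        def body(carry, block_idx):
            start = block_idx * BLOCK
            block = pl.load(pos_ref, (pl.dslice(start, BLOCK),))
            return carry, jnp.min(block)

        _, mins = jax.lax.scan(body, 0, jnp.arange(NUM_BLOCKS, dtype=jnp.int32))
        pl.store(out_ref, (pl.dslice(0, NUM_BLOCKS),), mins)

    block_min = pl.pallas_call(
        block_min_kernel,
        out_shape=jax.ShapeDtypeStruct((NUM_BLOCKS,), jnp.int32),
    )

    # Expected to exercise the same lowering path as the failing test:
    # block_min(jnp.arange(SEQ_LEN, dtype=jnp.int32))

Under the assumption that the scan-over-a-Ref pattern is what fails to lower, rewriting the per-block reduction with jax.lax.fori_loop (or an unrolled Python loop over the six blocks) so the Ref is not passed as a scan constant is one possible workaround; this is a sketch for debugging, not a verified fix.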