import os
os.environ['XLA_FLAGS'] = "--xla_dump_to=./hlo"
import jax
from jax._src import custom_derivatives
import functools


def f_fwd_rule(a, b):
  o = a + b
  # works
  # o = jax._src.ad_checkpoint.checkpoint_name(o, "context")
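# A hedged, self-contained sketch of the pattern the snippet above appears to
# exercise: tagging an intermediate inside a custom_vjp forward rule with
# checkpoint_name so a remat/offload policy can refer to it by name. Only the
# "context" name and o = a + b come from the gist; the rest is illustrative.
import jax
from jax.ad_checkpoint import checkpoint_name


@jax.custom_vjp
def f(a, b):
  return a + b


def f_fwd(a, b):
  o = a + b
  o = checkpoint_name(o, "context")  # tag the residual by name
  return o, (o,)


def f_bwd(res, g):
  del res  # unused: d/da (a + b) = d/db (a + b) = 1
  return g, g


f.defvjp(f_fwd, f_bwd)

print(jax.grad(f)(1.0, 2.0))  # -> 1.0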
import os
os.environ["XLA_FLAGS"] = "--xla_dump_to=./hlo --xla_dump_hlo_as_text"
from functools import partial
import jax
import jax.numpy as jnp
from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention, MaskType
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import shard_map
from jax._src.sharding_impls import (
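# A hedged sketch of what this (truncated) snippet appears to set up: the cuDNN
# fused attention kernel called under shard_map with the batch dimension sharded
# over a 1-D mesh. The shapes, dtype, and the "dp" axis name are assumptions,
# and running it requires a cuDNN-capable GPU build of JAX.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map
from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention

devices = np.array(jax.devices())
mesh = Mesh(devices, ("dp",))

B, T, N, H = 2 * len(devices), 128, 4, 64  # batch is sharded over "dp"
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (B, T, N, H), dtype=jnp.bfloat16)
k = jax.random.normal(kk, (B, T, N, H), dtype=jnp.bfloat16)
v = jax.random.normal(kv, (B, T, N, H), dtype=jnp.bfloat16)

spec = P("dp", None, None, None)
sharded_attn = jax.jit(
    shard_map(dot_product_attention, mesh=mesh,
              in_specs=(spec, spec, spec), out_specs=spec))
print(sharded_attn(q, k, v).shape)  # (B, T, N, H)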
from functools import partial
from absl.testing import absltest
import os
# Dump HLO as text to ./hlo after each optimization pass (--xla_dump_hlo_pass_re=.*),
# with op metadata stripped, and with the float-normalization-bf16 pass disabled.
os.environ["XLA_FLAGS"] = \
    "--xla_dump_to=./hlo --xla_dump_hlo_as_text --xla_dump_hlo_pass_re=.* --xla_dump_disable_metadata --xla_disable_hlo_passes=float-normalization-bf16"
import sys
import numpy as np
import os
os.environ["XLA_FLAGS"] = "--xla_dump_to=./hlo --xla_dump_hlo_as_text"
from functools import partial
import jax
import jax.numpy as jnp
from jax._src import test_util as jtu
from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention, MaskType
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import shard_map
import jax
import jax.numpy as jnp
from jax._src import test_util as jtu
from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention, MaskType
from jax.experimental.pallas.ops.gpu import decode_attention as gpu_pallas_decode_attention


def transpose(tensor):
  # swap axes 1 and 2, e.g. between the BTNH and BNTH attention layouts
  return jnp.transpose(tensor, (0, 2, 1, 3))


def get_encoded_padding_mask(encoded):
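# A hedged sketch of the kind of plain-jnp reference such comparison scripts
# usually check the fused kernels against; the function name, BTNH layout, and
# 1/sqrt(H) scaling are assumptions, not part of the original gist.
import jax
import jax.numpy as jnp


def reference_attention(q, k, v, scale=None):
  # q, k, v: (B, T, N, H); returns (B, T, N, H)
  _, _, _, H = q.shape
  scale = (1.0 / jnp.sqrt(H)) if scale is None else scale
  logits = jnp.einsum("btnh,bsnh->bnts", q, k) * scale
  probs = jax.nn.softmax(logits.astype(jnp.float32), axis=-1).astype(q.dtype)
  return jnp.einsum("bnts,bsnh->btnh", probs, v)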
import jax
import jax.numpy as jnp
from jax._src import test_util as jtu
from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention


def transpose(tensor):
  # swap axes 1 and 2, e.g. between the BTNH and BNTH attention layouts
  return jnp.transpose(tensor, (0, 2, 1, 3))


def normalized_cudnn_attention(q, k1, v1, k2, v2):
  # jax.vjp returns (primal_out, vjp_fn); here the primal output unpacks into
  # the attention output and its softmax statistics.
  (encoded1, stat1), _ = jax.vjp(dot_product_attention, q, k1, v1)
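# A hedged sketch of the state-merging math a "normalized" two-part attention
# like the one above typically builds toward: combining two partial outputs via
# their softmax log-sum-exp statistics. The (B, N, T) stat layout and natural-log
# semantics are assumptions; only the merge formula itself is standard.
import jax.numpy as jnp


def merge_attention_states(o1, lse1, o2, lse2):
  # o1, o2: (B, T, N, H) partial outputs; lse1, lse2: (B, N, T) log-sum-exp stats
  lse = jnp.logaddexp(lse1, lse2)  # combined normalizer
  w1 = jnp.transpose(jnp.exp(lse1 - lse), (0, 2, 1))[..., None]  # -> (B, T, N, 1)
  w2 = jnp.transpose(jnp.exp(lse2 - lse), (0, 2, 1))[..., None]
  return o1 * w1 + o2 * w2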
diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc
index 9c55820c5..fa8492d62 100644
--- a/xla/service/gpu/gpu_compiler.cc
+++ b/xla/service/gpu/gpu_compiler.cc
@@ -55,6 +55,7 @@ limitations under the License.
 #include "llvm/Support/Error.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Transforms/Utils/SplitModule.h"
+#include "llvm/Transforms/Utils/Cloning.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
import wabt from 'wabt';

async function run(watSource: string, config: any): Promise<number> {
  const wabtApi = await wabt();
  const parsed = wabtApi.parseWat("example", watSource);
  const binary = parsed.toBinary({});
  const wasmModule = await WebAssembly.instantiate(binary.buffer, config);
  return (wasmModule.instance.exports as any)._start();
}

# tvm/python/tvm/topi/nn/dense.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#