Shanbin Ke Cjkkkk

Cjkkkk / remat.py
Created July 30, 2025 21:47
jax remat example
import os
os.environ['XLA_FLAGS'] = "--xla_dump_to=./hlo"
import jax
from jax._src import custom_derivatives
import functools
def f_fwd_rule(a, b):
    o = a + b
    # works
    # o = jax._src.ad_checkpoint.checkpoint_name(o, "context")
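The preview above is cut off by the listing; for context, here is a minimal, hedged sketch of the remat pattern it points at, assuming a recent JAX where jax.ad_checkpoint.checkpoint_name and jax.checkpoint_policies.save_only_these_names are available (an illustration, not the gist's full code):

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def f(a, b):
    # tag the intermediate so a checkpoint policy can refer to it by name
    o = checkpoint_name(a + b, "context")
    return jnp.sin(o) * o

# save only tensors tagged "context"; rematerialize everything else in the backward pass
f_remat = jax.checkpoint(
    f, policy=jax.checkpoint_policies.save_only_these_names("context"))
grads = jax.grad(lambda a, b: f_remat(a, b).sum())(jnp.ones(4), jnp.ones(4))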
import os
os.environ["XLA_FLAGS"] = "--xla_dump_to=./hlo --xla_dump_hlo_as_text"
from functools import partial
import jax
import jax.numpy as jnp
from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention, MaskType
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import shard_map
from jax._src.sharding_impls import (
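This snippet is also truncated mid-import; as a hedged sketch of the pattern its imports suggest, here is batch-dimension sharding of an attention call under shard_map. The axis name, shapes, and the use of the public jax.nn.dot_product_attention are my assumptions, not the gist's code:

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), ("data",))
spec = P("data", None, None, None)  # shard the batch dimension only

def attn(q, k, v):
    # each device attends over its local batch shard; no collectives are needed
    return jax.nn.dot_product_attention(q, k, v)

sharded_attn = jax.jit(shard_map(attn, mesh=mesh, in_specs=(spec, spec, spec), out_specs=spec))

batch = 2 * len(jax.devices())
q = k = v = jnp.ones((batch, 128, 8, 64), dtype=jnp.bfloat16)  # [batch, seq, heads, head_dim]
out = sharded_attn(q, k, v)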
Cjkkkk / flex_attn_gym.py
Created May 1, 2025 20:43
flex attn example
from functools import partial
from absl.testing import absltest
import os
# --xla_dump_hlo_pass_re=.*
os.environ["XLA_FLAGS"] = \
    "--xla_dump_to=./hlo --xla_dump_hlo_as_text --xla_dump_hlo_pass_re=.* --xla_dump_disable_metadata --xla_disable_hlo_passes=float-normalization-bf16"
import sys
import numpy as np
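The preview only shows the HLO-dump flags; for the "flex attn" idea the title refers to, here is a rough, backend-agnostic sketch in plain jax.numpy in which a user-supplied score_mod rewrites the attention logits before softmax (the names and shapes are illustrative, not the gist's):

import jax
import jax.numpy as jnp

def flex_attention(q, k, v, score_mod):
    # q, k, v: [batch, heads, seq, head_dim]
    logits = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
    q_idx = jnp.arange(q.shape[2])[:, None]
    k_idx = jnp.arange(k.shape[2])[None, :]
    logits = score_mod(logits, q_idx, k_idx)  # e.g. masking, relative bias, ALiBi
    return jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(logits, axis=-1), v)

# example score_mod: causal masking
causal = lambda scores, qi, ki: jnp.where(ki <= qi, scores, -jnp.inf)
x = jnp.ones((2, 4, 16, 32))
out = flex_attention(x, x, x, causal)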
Cjkkkk / ring_attn_fwd.py
Last active April 25, 2025 21:41
ring attention fwd using JAX SDPA
import os
os.environ["XLA_FLAGS"] = "--xla_dump_to=./hlo --xla_dump_hlo_as_text"
from functools import partial
import jax
import jax.numpy as jnp
from jax._src import test_util as jtu
from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention, MaskType
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import shard_map
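Since the preview above stops at the imports, here is a hedged sketch of the ring-attention forward pass the title describes: each device keeps its query shard fixed, rotates key/value shards around the mesh axis with ppermute, and accumulates with an online softmax. Shapes, the axis name "sp", and the helper name are my assumptions rather than the gist's code, and the function would be wrapped in shard_map over a sequence-sharded mesh axis.

import jax
import jax.numpy as jnp

def ring_attention_fwd(q, k, v, axis_size, axis_name="sp"):
    # q, k, v: local shards of shape [batch, heads, seq_shard, head_dim]
    scale = q.shape[-1] ** -0.5
    perm = [(i, (i + 1) % axis_size) for i in range(axis_size)]

    def step(carry, _):
        o, m, l, k_blk, v_blk = carry
        s = jnp.einsum("bhqd,bhkd->bhqk", q, k_blk) * scale
        m_new = jnp.maximum(m, s.max(axis=-1, keepdims=True))
        p = jnp.exp(s - m_new)
        correction = jnp.exp(m - m_new)
        l_new = l * correction + p.sum(axis=-1, keepdims=True)
        o_new = o * correction + jnp.einsum("bhqk,bhkd->bhqd", p, v_blk)
        # rotate the key/value shard to the next device in the ring
        k_blk = jax.lax.ppermute(k_blk, axis_name, perm)
        v_blk = jax.lax.ppermute(v_blk, axis_name, perm)
        return (o_new, m_new, l_new, k_blk, v_blk), None

    o = jnp.zeros_like(q)
    m = jnp.full(q.shape[:-1] + (1,), -jnp.inf, dtype=q.dtype)
    l = jnp.zeros(q.shape[:-1] + (1,), dtype=q.dtype)
    (o, _, l, _, _), _ = jax.lax.scan(step, (o, m, l, k, v), xs=None, length=axis_size)
    return o / l

# wrapped roughly as (assuming a mesh with a sequence axis "sp" and BHSD inputs):
#   spec = P(None, None, "sp", None)
#   shard_map(partial(ring_attention_fwd, axis_size=mesh.shape["sp"]),
#             mesh=mesh, in_specs=(spec, spec, spec), out_specs=spec)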
import jax
import jax.numpy as jnp
from jax._src import test_util as jtu
from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention, MaskType
from jax.experimental.pallas.ops.gpu import decode_attention as gpu_pallas_decode_attention
def transpose(tensor):
    # swap the axis-1 and axis-2 dimensions (e.g. BSHD <-> BHSD layout)
    return jnp.transpose(tensor, (0, 2, 1, 3))

def get_encoded_padding_mask(encoded):
Cjkkkk / cudnn_normalization.py
Created March 17, 2025 21:23
cudnn_normalization
import jax
import jax.numpy as jnp
from jax._src import test_util as jtu
from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention
def transpose(tensor):
    # swap the axis-1 and axis-2 dimensions (e.g. BSHD <-> BHSD layout)
    return jnp.transpose(tensor, (0, 2, 1, 3))

def normalized_cudnn_attention(q, k1, v1, k2, v2):
    # forward pass against the first key/value set; the primal output is
    # unpacked as (attention output, softmax statistics)
    (encoded1, stat1), _ = jax.vjp(dot_product_attention, q, k1, v1)
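The preview stops after the first vjp call; the step a scheme like this needs afterwards is standard, so here is a hedged sketch of merging two partial attention outputs using their per-query log-sum-exp statistics, as if softmax had been taken over the union of both key/value sets (names and shapes are my assumptions, not the gist's code):

import jax.numpy as jnp

def merge_attention(o1, lse1, o2, lse2):
    # o1, o2: [batch, heads, q_len, head_dim]; lse1, lse2: [batch, heads, q_len]
    lse = jnp.logaddexp(lse1, lse2)
    w1 = jnp.exp(lse1 - lse)[..., None]
    w2 = jnp.exp(lse2 - lse)[..., None]
    return w1 * o1 + w2 * o2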
diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc
index 9c55820c5..fa8492d62 100644
--- a/xla/service/gpu/gpu_compiler.cc
+++ b/xla/service/gpu/gpu_compiler.cc
@@ -55,6 +55,7 @@ limitations under the License.
#include "llvm/Support/Error.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/SplitModule.h"
+#include "llvm/Transforms/Utils/Cloning.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
Cjkkkk / quiz3.ts
Last active April 18, 2022 06:50
quiz3.ts
import wabt from 'wabt';
async function run(watSource: string, config: any): Promise<number> {
  // load the wabt toolkit and assemble the WAT text into a wasm binary
  const wabtApi = await wabt();
  const parsed = wabtApi.parseWat("example", watSource);
  const binary = parsed.toBinary({});
  // instantiate the binary with the given imports and call its exported _start
  const wasmModule = await WebAssembly.instantiate(binary.buffer, config);
  return (wasmModule.instance.exports as any)._start();
}
Cjkkkk / 1.png
Last active March 12, 2022 22:46
1.png
# tvm/python/tvm/topi/nn/dense.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#