Skip to content

Instantly share code, notes, and snippets.

View merrymercy's full-sized avatar
:octocat:

Lianmin Zheng merrymercy

:octocat:
View GitHub Profile
@merrymercy
merrymercy / graph_format.md
Last active August 25, 2020 08:53
Graph Format
struct Edge {
  int src;
  int dst;
  float feature[100];
}

struct Node {
  int node_type;
  int id;
import numpy as np
import tvm
from tvm import te, auto_scheduler, topi
@auto_scheduler.register_workload
def dense_layer(in_dim, out_dim):
data = te.placeholder((1, in_dim), name="data")
weight = te.placeholder((out_dim, in_dim), name="weight")
bias = te.placeholder((out_dim,), name="bias")
out = topi.nn.dense(data, weight)
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
def split(a, axis, factor):
assert a.shape[axis] % factor == 0
new_shape = a.shape[:axis] + (factor, a.shape[axis] // factor) + a.shape[axis+1:]
a = a.reshape(new_shape)
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax.nn import relu
from jax.experimental import PartitionSpec as P
from jax.experimental.maps import mesh
from jax.experimental.pjit import pjit, with_sharding_constraint
# Style 1
# Design sketch, not runnable as-is: `auto_parallel`, `grad`, `loss_func`, and
# `optimier_step` are not defined or imported in this snippet (presumably
# provided elsewhere -- TODO confirm). `optimier_step` looks like a misspelling
# of `optimizer_step`.
@auto_parallel
def step(batch, weight):
# Compute gradients of `loss_func` evaluated at (batch, weight).
# NOTE(review): if `grad` is `jax.grad`, it differentiates only with respect to
# the first positional argument (`batch`) by default; gradients w.r.t. `weight`
# would need `argnums=1` or a swapped argument order -- confirm intent.
grads = grad(loss_func)(batch, weight)
# do not know where to insert pmean
# (Open question for this style: where the cross-device gradient reduction,
# presumably `jax.lax.pmean`, should go when parallelization is automatic.)
new_weight = optimier_step(grads)
# NOTE(review): the REQUIREMENT comment below appears to be cut off in this copy.
return new_weight # REQUIREMENT: new_weight and weight maps
import time
import cupy as cp
def benchmark(n, k, m, dtype):
warmup = 2
number = 100
a = cp.ones((n, k), dtype)
b = cp.ones((k, m), dtype)
HloModule train_step_shard_parallel.3684, input_output_alias={ {0}: (0, {}, may-alias), {1}: (1, {}, may-alias), {2}: (2, {}, may-alias), {3}: (3, {}, may-alias), {4}: (4, {}, may-alias), {5}: (5, {}, may-alias), {6}: (6, {}, may-alias), {7}: (7, {}, may-alias), {8}: (8, {}, may-alias), {9}: (9, {}, may-alias), {10}: (10, {}, may-alias), {11}: (11, {}, may-alias), {12}: (12, {}, may-alias), {13}: (13, {}, may-alias), {14}: (14, {}, may-alias), {15}: (15, {}, may-alias), {16}: (16, {}, may-alias), {17}: (17, {}, may-alias), {18}: (18, {}, may-alias), {19}: (19, {}, may-alias), {20}: (20, {}, may-alias), {21}: (21, {}, may-alias), {22}: (22, {}, may-alias), {23}: (23, {}, may-alias), {24}: (24, {}, may-alias), {25}: (25, {}, may-alias), {26}: (26, {}, may-alias), {27}: (27, {}, may-alias), {28}: (28, {}, may-alias), {29}: (29, {}, may-alias), {30}: (30, {}, may-alias), {31}: (31, {}, may-alias), {32}: (32, {}, may-alias), {33}: (33, {}, may-alias), {34}: (34, {}, may-alias), {35}: (35, {}, may-alias), {36}: (36
// Original : https://github.com/alpa-projects/tensorflow-alpa/blob/d298f84474a04ecce02085332793e6115c0c8e0e/tensorflow/compiler/xla/service/spmd/auto_sharding_strategy.h#L854-L876
if (adj_list.size() > 1) {
// Merge src to dst.
//
// Before:
//
// src ---- adj ---- dst
// | |
// -------------------
@merrymercy
merrymercy / feature.md
Last active May 24, 2022 07:10
Feature description for AutoTVM

Features

Loop-based Feature

Each IterVar (or axis) has three kinds of features:

  • axis attribute
  • arithmetic feature
  • touch feature

Axis Attribute

@merrymercy
merrymercy / jax_pjit_embedding.py
Created June 17, 2022 00:40
Use jax.pjit to partition an embedding table
"""Test embedding table partition in XLA.
References:
- Introduction to pjit. https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html
"""
from functools import partial
import jax
import jax.numpy as jnp