Lianmin Zheng (merrymercy)
@merrymercy
merrymercy / redirect.py
Created March 21, 2023 00:57
Permanently redirect to another URL.
from fastapi import FastAPI
from starlette.responses import RedirectResponse

app = FastAPI()

@app.get("/")
async def redirect():
    response = RedirectResponse(url="https://alpa-projects.github.io/opt", status_code=301)
    return response
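A quick way to verify the redirect (hypothetical usage: assumes the file is saved as redirect.py; follow_redirects is the kwarg of the newer httpx-based TestClient, older versions use allow_redirects):

from fastapi.testclient import TestClient
from redirect import app

client = TestClient(app)
resp = client.get("/", follow_redirects=False)
assert resp.status_code == 301
assert resp.headers["location"] == "https://alpa-projects.github.io/opt"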
@merrymercy
merrymercy / test_pad.py
Last active August 21, 2022 18:24
test padding in the middle
"""Use huggingface/transformers interface and Alpa backend for distributed inference."""
from transformers import AutoTokenizer
from opt_serving.model.wrapper import get_model
import numpy as np
import torch
# Load the tokenizer. We have to use the 30B version because
# other versions have some issues. The 30B version works for all OPT models.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-30b", use_fast=False)
tokenizer.add_bos_token = False
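The preview cuts off here. A minimal sketch of what "padding in the middle" could mean (an assumption, not the gist's actual test): insert pad tokens inside the prompt and mask them out, so a correct model should ignore them.

prompt = "Paris is the capital city of"
ids = tokenizer(prompt).input_ids
pad_id = tokenizer.pad_token_id

# place a run of pad tokens between the two halves of the prompt
mid = len(ids) // 2
padded_ids = ids[:mid] + [pad_id] * 4 + ids[mid:]
attention_mask = [1] * mid + [0] * 4 + [1] * (len(ids) - mid)

input_ids = torch.tensor([padded_ids])
mask = torch.tensor([attention_mask])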
from flax import linen as nn
from flax.training.train_state import TrainState
import jax
import jax.numpy as jnp
import numpy as np
import optax
def create_train_state_and_batch(batch_size, hidden_size, use_remat):
    class Layer(nn.Module):
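The definition is truncated above. A hedged sketch of how it might continue, assuming a single Dense layer and SGD (nn.remat is Flax's rematerialization wrapper; everything below the class line is an assumption):

def create_train_state_and_batch(batch_size, hidden_size, use_remat):
    class Layer(nn.Module):
        @nn.compact
        def __call__(self, x):
            return nn.Dense(hidden_size)(x)

    # optionally rematerialize the layer to trade compute for memory
    layer_cls = nn.remat(Layer) if use_remat else Layer
    model = layer_cls()

    batch = jnp.ones((batch_size, hidden_size))
    params = model.init(jax.random.PRNGKey(0), batch)["params"]
    state = TrainState.create(apply_fn=model.apply, params=params, tx=optax.sgd(1e-2))
    return state, batch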
@merrymercy
merrymercy / jax_pjit_embedding.py
Created June 17, 2022 00:40
Use jax.pjit to partition embedding table
"""Test embedding table partition in XLA.
References:
- Introduction to pjit. https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html
"""
from functools import partial
import jax
import jax.numpy as jnp
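The preview stops at the imports. A hedged sketch of what the partitioned embedding lookup might look like with this legacy pjit API (mesh axis name, sizes, and the lookup body are assumptions; vocab_size must be divisible by the device count):

import numpy as np
from jax.experimental import PartitionSpec as P
from jax.experimental.maps import mesh
from jax.experimental.pjit import pjit

vocab_size, hidden_size = 1024, 128
table = jnp.ones((vocab_size, hidden_size))
ids = jnp.array([3, 14, 159, 265])

# shard the table rows across the "x" mesh axis; replicate ids and the output
lookup = pjit(lambda tab, i: tab[i],
              in_axis_resources=(P("x", None), None),
              out_axis_resources=None)

with mesh(np.array(jax.devices()), ("x",)):
    embeddings = lookup(table, ids)   # shape (4, hidden_size)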
// Original: https://github.com/alpa-projects/tensorflow-alpa/blob/d298f84474a04ecce02085332793e6115c0c8e0e/tensorflow/compiler/xla/service/spmd/auto_sharding_strategy.h#L854-L876
if (adj_list.size() > 1) {
  // Merge src to dst.
  //
  // Before:
  //
  //   src ---- adj ---- dst
  //    |                 |
  //    -------------------
HloModule train_step_shard_parallel.3684, input_output_alias={ {0}: (0, {}, may-alias), {1}: (1, {}, may-alias), {2}: (2, {}, may-alias), {3}: (3, {}, may-alias), {4}: (4, {}, may-alias), {5}: (5, {}, may-alias), {6}: (6, {}, may-alias), {7}: (7, {}, may-alias), {8}: (8, {}, may-alias), {9}: (9, {}, may-alias), {10}: (10, {}, may-alias), {11}: (11, {}, may-alias), {12}: (12, {}, may-alias), {13}: (13, {}, may-alias), {14}: (14, {}, may-alias), {15}: (15, {}, may-alias), {16}: (16, {}, may-alias), {17}: (17, {}, may-alias), {18}: (18, {}, may-alias), {19}: (19, {}, may-alias), {20}: (20, {}, may-alias), {21}: (21, {}, may-alias), {22}: (22, {}, may-alias), {23}: (23, {}, may-alias), {24}: (24, {}, may-alias), {25}: (25, {}, may-alias), {26}: (26, {}, may-alias), {27}: (27, {}, may-alias), {28}: (28, {}, may-alias), {29}: (29, {}, may-alias), {30}: (30, {}, may-alias), {31}: (31, {}, may-alias), {32}: (32, {}, may-alias), {33}: (33, {}, may-alias), {34}: (34, {}, may-alias), {35}: (35, {}, may-alias), {36}: (36
import time
import cupy as cp

def benchmark(n, k, m, dtype):
    warmup = 2
    number = 100
    a = cp.ones((n, k), dtype)
    b = cp.ones((k, m), dtype)
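The timing loop is cut off above. A hedged completion (the synchronization points and the returned metrics are assumptions, not the gist's actual code):

def benchmark(n, k, m, dtype):
    warmup = 2
    number = 100
    a = cp.ones((n, k), dtype)
    b = cp.ones((k, m), dtype)

    for _ in range(warmup):
        cp.matmul(a, b)
    cp.cuda.Device().synchronize()   # finish warmup kernels before timing

    tic = time.time()
    for _ in range(number):
        cp.matmul(a, b)
    cp.cuda.Device().synchronize()   # wait for all timed kernels
    cost = (time.time() - tic) / number

    tflops = 2 * n * k * m / cost / 1e12   # 2*n*k*m flops per matmul
    return cost, tflops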
# Style 1
@auto_parallel
def step(batch, weight):
    grads = grad(loss_func)(batch, weight)
    # do not know where to insert pmean
    new_weight = optimizer_step(grads)
    return new_weight  # REQUIREMENT: new_weight and weight maps
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax.nn import relu
from jax.experimental import PartitionSpec as P
from jax.experimental.maps import mesh
from jax.experimental.pjit import pjit, with_sharding_constraint
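This preview also ends at the imports. A hedged sketch of how with_sharding_constraint might be used here (shapes and the "x" axis name are assumptions; the hidden dimension must be divisible by the device count):

def mlp(x, w1, w2):
    h = relu(x @ w1)
    # constrain the hidden activations to be sharded over the "x" mesh axis
    h = with_sharding_constraint(h, P(None, "x"))
    return h @ w2

mlp_sharded = pjit(mlp,
                   in_axis_resources=(None, P(None, "x"), P("x", None)),
                   out_axis_resources=None)

x = np.ones((8, 16), dtype=np.float32)
w1 = np.ones((16, 32), dtype=np.float32)
w2 = np.ones((32, 16), dtype=np.float32)

with mesh(np.array(jax.devices()), ("x",)):
    y = mlp_sharded(x, w1, w2)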
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
def split(a, axis, factor):
    # reshape axis `axis` of length L into two axes (factor, L // factor)
    assert a.shape[axis] % factor == 0
    new_shape = a.shape[:axis] + (factor, a.shape[axis] // factor) + a.shape[axis + 1:]
    a = a.reshape(new_shape)
    return a
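A quick illustrative check of split (the array below is just an example):

x = np.arange(24).reshape(4, 6)
print(split(x, axis=1, factor=3).shape)   # (4, 3, 2)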