Skip to content

Instantly share code, notes, and snippets.

View merrymercy's full-sized avatar
:octocat:

Lianmin Zheng merrymercy

:octocat:
View GitHub Profile
import cupy as cp
import cupy.cuda.nccl as nccl
from mpi4py import MPI
import time
import os
import jax.numpy as jnp
# 32 MiB written as a shift (32 << 20 == 1024 * 1024 * 32 == 33554432).
# Presumably the byte size of the communicated buffer — confirm against usage.
nbytes = 32 << 20
# CuPy float32 dtype, presumably the element type of the buffers — confirm.
data_type = cp.float32
import cupy as cp
import cupy.cuda.nccl as nccl
from mpi4py import MPI
import time
import os
import jax.numpy as jnp
# 32 MiB written as a shift (32 << 20 == 1024 * 1024 * 32 == 33554432).
# Presumably the byte size of the communicated buffer — confirm against usage.
nbytes = 32 << 20
# CuPy float32 dtype, presumably the element type of the buffers — confirm.
data_type = cp.float32
# mpirun -np 2 python p2p-nonblocking.py
import cupy as cp
import cupy.cuda.nccl as nccl
from mpi4py import MPI
import time
import os
import jax.numpy as jnp
# mpirun -np 2 python p2p-nonblocking.py
import cupy as cp
import cupy.cuda.nccl as nccl
from mpi4py import MPI
import time
import os
import jax.numpy as jnp
# mpirun -np 2 python p2p-nonblocking.py
import cupy as cp
import cupy.cuda.nccl as nccl
from mpi4py import MPI
import time
import os
# 32 MiB written as a shift (32 << 20 == 1024 * 1024 * 32 == 33554432).
# Presumably the byte size of the communicated buffer — confirm against usage.
nbytes = 32 << 20
# CuPy float32 dtype, presumably the element type of the buffers — confirm.
data_type = cp.float32
# mpirun -np 2 python p2p-nonblocking.py
import cupy as cp
import cupy.cuda.nccl as nccl
from mpi4py import MPI
import time
import os
# 32 MiB written as a shift (32 << 20 == 1024 * 1024 * 32 == 33554432).
# Presumably the byte size of the communicated buffer — confirm against usage.
nbytes = 32 << 20
# CuPy float32 dtype, presumably the element type of the buffers — confirm.
data_type = cp.float32
@merrymercy
merrymercy / redirect.py
Created March 21, 2023 00:57
Permanently redirect to another URL.
from fastapi import FastAPI
from starlette.responses import RedirectResponse
# FastAPI application that hosts the root redirect route below.
app = FastAPI()


@app.get("/")
async def redirect():
    """Answer GET requests for "/" with an HTTP 301 (Moved Permanently).

    Clients are sent to https://alpa-projects.github.io/opt. A 301 is a
    permanent redirect, so browsers may cache it.
    """
    target_url = "https://alpa-projects.github.io/opt"
    return RedirectResponse(status_code=301, url=target_url)
@merrymercy
merrymercy / test_pad.py
Last active August 21, 2022 18:24
test padding in the middle
"""Use huggingface/transformers interface and Alpa backend for distributed inference."""
from transformers import AutoTokenizer
from opt_serving.model.wrapper import get_model
import numpy as np
import torch
# Load the tokenizer from the 30B checkpoint deliberately. Per the original
# author, the tokenizers shipped with other OPT sizes have some issues, while
# the 30B one works for every OPT model. use_fast=False asks AutoTokenizer
# for the non-"fast" tokenizer class.
tokenizer = AutoTokenizer.from_pretrained(
    "facebook/opt-30b",
    use_fast=False,
)
# Disable the tokenizer's add_bos_token flag; presumably this stops a BOS
# token from being prepended automatically — confirm for the transformers
# version in use.
tokenizer.add_bos_token = False
from flax import linen as nn
from flax.training.train_state import TrainState
import jax
import jax.numpy as jnp
import numpy as np
import optax
def create_train_state_and_batch(batch_size, hidden_size, use_remat):
class Layer(nn.Module):
@merrymercy
merrymercy / jax_pjit_embedding.py
Created June 17, 2022 00:40
Use jax.pjit to partition an embedding table
"""Test embedding table partition in XLA.
References:
- Introduction to pjit. https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html
"""
from functools import partial
import jax
import jax.numpy as jnp