An IterVar (i.e., an axis) has three kinds of features (a minimal sketch follows the list):
- axis attribute
- arithmetic feature
- touch feature
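The sketch below groups the three feature kinds into one record per axis; the concrete field names are illustrative assumptions, not the exact fields used by the original feature extractor.

from dataclasses import dataclass, field
from typing import Dict


@dataclass
class AxisFeature:
    # Axis attribute: static properties of the axis itself.
    length: int = 0
    is_reduce: bool = False
    annotation: str = "none"   # e.g. "parallel", "vectorize", "unroll"

    # Arithmetic feature: arithmetic ops iterated over by this axis.
    num_add: int = 0
    num_mul: int = 0
    num_div: int = 0

    # Touch feature: per-buffer memory accesses made under this axis.
    bytes_touched: Dict[str, int] = field(default_factory=dict)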
# Run with: mpirun -np 2 python p2p-nonblocking.py
import cupy as cp
import cupy.cuda.nccl as nccl
from mpi4py import MPI
import time
import os

nbytes = 1024 * 1024 * 32   # message size: 32 MiB
data_type = cp.float32
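# What follows is a minimal, illustrative continuation (not the original
# script): bootstrap a NCCL communicator over MPI, then exchange one buffer
# between rank 0 and rank 1, wrapping the point-to-point calls in a NCCL
# group so both are enqueued before either blocks.
mpi_comm = MPI.COMM_WORLD
rank = mpi_comm.Get_rank()
cp.cuda.Device(rank).use()

# Rank 0 creates the NCCL unique id and broadcasts it over MPI.
uid = nccl.get_unique_id() if rank == 0 else None
uid = mpi_comm.bcast(uid, root=0)
comm = nccl.NcclCommunicator(mpi_comm.Get_size(), uid, rank)

count = nbytes // cp.dtype(data_type).itemsize
buf = cp.ones(count, dtype=data_type) if rank == 0 else cp.zeros(count, dtype=data_type)
stream = cp.cuda.Stream(non_blocking=True)

nccl.groupStart()
if rank == 0:
    comm.send(buf.data.ptr, count, nccl.NCCL_FLOAT32, 1, stream.ptr)
elif rank == 1:
    comm.recv(buf.data.ptr, count, nccl.NCCL_FLOAT32, 0, stream.ptr)
nccl.groupEnd()
stream.synchronize()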
from fastapi import FastAPI
from starlette.responses import RedirectResponse

app = FastAPI()


@app.get("/")
async def redirect():
    # Permanently redirect the root path to the Alpa OPT demo page.
    response = RedirectResponse(url="https://alpa-projects.github.io/opt", status_code=301)
    return response
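# A minimal way to serve the app locally, assuming uvicorn is installed;
# the host and port values are illustrative.
import uvicorn

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)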
"""Use huggingface/transformers interface and Alpa backend for distributed inference.""" | |
from transformers import AutoTokenizer | |
from opt_serving.model.wrapper import get_model | |
import numpy as np | |
import torch | |
# Load the tokenizer. We have to use the 30B version because | |
# other versions have some issues. The 30B version works for all OPT models. | |
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-30b", use_fast=False) | |
tokenizer.add_bos_token = False |
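# A minimal, illustrative continuation: load an OPT checkpoint through the
# Alpa wrapper and generate from a prompt. The get_model arguments below
# (model name and weight path) and the HuggingFace-style generate() call are
# assumptions for this sketch; adjust them to your installation.
model = get_model(model_name="alpa/opt-2.7b", path="/path/to/opt_weights")

prompt = "Paris is the capital city of"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

output_ids = model.generate(input_ids=input_ids, max_length=64, do_sample=True)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])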
from flax import linen as nn
from flax.training.train_state import TrainState
import jax
import jax.numpy as jnp
import numpy as np
import optax


def create_train_state_and_batch(batch_size, hidden_size, use_remat):

    class Layer(nn.Module):
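        # The original class body is not shown here; what follows is a
        # minimal, illustrative completion rather than the original code:
        # a single dense layer, wrapped in nn.remat when use_remat is True
        # so activations are recomputed in the backward pass.
        @nn.compact
        def __call__(self, x):
            dense = nn.remat(nn.Dense) if use_remat else nn.Dense
            return dense(features=hidden_size)(x)

    # Build the model, an SGD-based TrainState, and one random input batch.
    rngkey = jax.random.PRNGKey(0)
    x = jnp.asarray(np.random.randn(batch_size, hidden_size), jnp.float32)
    model = Layer()
    params = model.init(rngkey, x)
    state = TrainState.create(apply_fn=model.apply, params=params, tx=optax.sgd(1e-2))
    return state, x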
"""Test embedding table partition in XLA. | |
References: | |
- Introduction to pjit. https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html | |
""" | |
from functools import partial | |
import jax | |
import jax.numpy as jnp |
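# A minimal, illustrative sketch (not the original test body): partition an
# embedding table along the vocabulary dimension across all local devices
# with pjit. It assumes a recent jax where Mesh and PartitionSpec live in
# jax.sharding; older releases spell the pjit arguments
# in_axis_resources/out_axis_resources instead of in_shardings/out_shardings.
import numpy as np
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, PartitionSpec as P

vocab_size, hidden_size = 1024, 128


def lookup(table, ids):
    # Gather one embedding row per token id.
    return table[ids]


devices = np.array(jax.devices())
with Mesh(devices, ("data",)):
    sharded_lookup = pjit(
        lookup,
        in_shardings=(P("data", None), P(None)),  # shard table rows, replicate ids
        out_shardings=P(None, None),              # replicate the gathered rows
    )
    table = jnp.ones((vocab_size, hidden_size), jnp.float32)
    ids = jnp.arange(16)
    emb = sharded_lookup(table, ids)
    print(emb.shape)  # (16, 128)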
// Original: https://github.com/alpa-projects/tensorflow-alpa/blob/d298f84474a04ecce02085332793e6115c0c8e0e/tensorflow/compiler/xla/service/spmd/auto_sharding_strategy.h#L854-L876
if (adj_list.size() > 1) {
  // Merge src to dst.
  //
  // Before:
  //
  //   src ---- adj ---- dst
  //    |                 |
  //    -------------------
HloModule train_step_shard_parallel.3684, input_output_alias={ {0}: (0, {}, may-alias), {1}: (1, {}, may-alias), {2}: (2, {}, may-alias), ... }
import time
import cupy as cp


def benchmark(n, k, m, dtype):
    warmup = 2     # untimed warm-up iterations
    number = 100   # timed iterations
    a = cp.ones((n, k), dtype)
    b = cp.ones((k, m), dtype)
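    # The rest of the original benchmark is not shown; a minimal completion
    # could time `number` matmuls after `warmup` untimed runs, synchronizing
    # the device so GPU execution is included in the measurement.
    for _ in range(warmup):
        cp.matmul(a, b)
    cp.cuda.Device().synchronize()

    tic = time.time()
    for _ in range(number):
        cp.matmul(a, b)
    cp.cuda.Device().synchronize()
    toc = time.time()

    total_flop = 2 * n * k * m
    tflops = total_flop * number / (toc - tic) / 1e12
    print(f"matmul {n}x{k}x{m} ({cp.dtype(dtype).name}): {tflops:.2f} TFLOPS")
    return tflops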