Yaoyao Ding (yaoyaoding): public gists
from typing import List
import os
import torch
import hidet
from hidet.apps.llm import create_llm
from hidet.apps.llm.sampler import SamplingParams
from hidet.apps.llm.nn.attention import DefaultAttnState
from hidet.apps.llm.tokenizer import Tokenizer
from hidet.apps.llm.modeling.llama import LlamaForCausalLM
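These imports are all the preview shows. A hypothetical sketch of how the pieces could fit together; the argument to create_llm and the generate/SamplingParams signatures are assumptions, not the gist's actual code:

prompts = ['What is the capital of Canada?']
llm = create_llm('meta-llama/Llama-2-7b-hf')   # assumed: create_llm takes a model name
params = SamplingParams(temperature=0.8)       # assumed parameter name
outputs = llm.generate(prompts, params)        # assumed generate(prompts, params) signature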
yaoyaoding / batch_block_update.py
Created October 16, 2023 17:39
Batch update of blocks
from typing import List
from functools import lru_cache
import torch
import hidet
hidet.option.cache_dir('./outs/cache')
@lru_cache(maxsize=None)
def build_kernel():
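The preview stops at the signature, but the decorator already tells the story: lru_cache(maxsize=None) memoizes the builder, so the expensive kernel compilation runs once per distinct argument tuple and later calls return the cached result. A generic sketch of the pattern, with a stand-in body rather than the gist's actual build code:

@lru_cache(maxsize=None)
def build_kernel_sketch(block_size: int):
    # Stand-in for an expensive compile step; runs once per block_size.
    print(f'compiling kernel for block_size={block_size}')
    return f'kernel<{block_size}>'

build_kernel_sketch(128)  # compiles
build_kernel_sketch(128)  # cache hit; no recompilation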
yaoyaoding / trition_utils.py
Created August 18, 2023 18:54
Dump Triton IR
import atexit
import os
import shutil

# Clean the cache dir given by the environment variable TRITON_CACHE_DIR.
if 'TRITON_CACHE_DIR' not in os.environ:
    os.environ['TRITON_CACHE_DIR'] = './triton_cache'
cache_dir = os.environ['TRITON_CACHE_DIR']
if os.path.exists(cache_dir):
    shutil.rmtree(cache_dir)
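The preview ends before the part that actually uses atexit. One plausible way the "dump" could work is an exit hook that walks the fresh cache directory and prints whatever IR files Triton wrote during the run; the file extensions below are an assumption about what Triton stores in its cache, not taken from the gist:

@atexit.register
def dump_triton_ir():
    # Hypothetical: print every IR artifact found in the cache at exit.
    for root, _, files in os.walk(cache_dir):
        for name in files:
            if name.endswith(('.ttir', '.ttgir', '.llir', '.ptx')):
                path = os.path.join(root, name)
                print(f'=== {path} ===')
                with open(path) as f:
                    print(f.read())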

(Beta) Hidet: a dynamo backend focused on inference acceleration

With torch dynamo, we can dispatch a PyTorch model to other deep learning frameworks and compilers for acceleration. Hidet is one such deep learning compiler; it accelerates your model with a bunch of optimizations (e.g., subgraph fusion, rewriting, and kernel tuning). To use Hidet, first install it via

$ pip install hidet

Then you can enable it via torch.compile(model, backend='hidet') as shown in the code snippet below:

import torch
import hidet 
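The gist preview stops after the imports. A minimal end-to-end sketch of the usage described above; the resnet18 model and the input shape are illustrative choices, not from the gist:

import torch
import torchvision
import hidet  # importing hidet registers the 'hidet' dynamo backend

model = torchvision.models.resnet18().cuda().eval()
model_opt = torch.compile(model, backend='hidet')

x = torch.randn(1, 3, 224, 224, device='cuda')
with torch.no_grad():
    y = model_opt(x)  # first call triggers compilation; later calls reuse the optimized kernels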
from typing import List, Union
import json
from collections import defaultdict

class Model:
    def __init__(self, graph, description="", author="", company="", license="", domain="", source=""):
        self.graphs: List[Graph] = [graph]
        self.description: str = description
        self.author: str = author
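The constructor signature is fully visible even though the rest of the class is cut off, so construction looks like the following; Graph here stands for whatever graph type the gist defines elsewhere:

# Hypothetical usage; 'graph' must be an instance of the gist's Graph type.
model = Model(graph, description='an example model', author='yaoyaoding')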
#include <cassert>
#include <cstdio>

extern "C" {

__device__ __forceinline__ void matmul_bt128x128_bsz256_s128x128_block_c_init_warp(float out[64]) {
    int32_t lane_id = (threadIdx.x % 32);
    // Zero-initialize the thread's 64-element accumulator for the C tile (unrolled).
    out[0] = 0.0;
    out[1] = 0.0;
    out[2] = 0.0;
    out[3] = 0.0;
yaoyaoding / extract_launch_config.py
Last active November 28, 2021 04:49
Extract the launch configuration from PrimFunc.
import tvm
from tvm import te, tir

def extract_launch_config(prim_func: tir.PrimFunc):
    """
    Extract the launch configuration of the given prim_func.

    Parameters
    ----------
    prim_func : tir.PrimFunc
        The function to extract the launch configuration from.
    """
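The preview cuts off inside the docstring. For reference, one common way to do this kind of extraction over TVM's TIR is to walk the body and collect the thread_extent attributes; this is a sketch of the general technique, not the gist's actual implementation:

def extract_launch_config_sketch(prim_func: tir.PrimFunc) -> dict:
    config = {}

    def visit(stmt):
        # Launch bounds appear as AttrStmt nodes with key 'thread_extent';
        # stmt.node is an IterVar (e.g. blockIdx.x) and stmt.value its extent.
        if isinstance(stmt, tir.AttrStmt) and stmt.attr_key == 'thread_extent':
            config[stmt.node.thread_tag] = int(stmt.value)

    tir.stmt_functor.post_order_visit(prim_func.body, visit)
    return config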
yaoyaoding / taso_inception_v3.py
Last active February 28, 2021 21:05
Inception V3 defined in the TASO API.
import taso

def get_pads(kernel, padding):
    # TASO takes padding as a string: no padding on a >1x1 kernel means "VALID".
    if sum(padding) == 0 and sum(kernel) > 2:
        pads = "VALID"
    else:
        pads = "SAME"
    return pads

def conv2d(graph, v, out_channels, kernel=(1, 1), stride=(1, 1), padding=(0, 0)):
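    # A plausible completion based on TASO's published examples
    # (graph.new_weight / graph.conv2d with string padding);
    # this body is an assumption, not the gist's actual code.
    pads = get_pads(kernel, padding)
    weight = graph.new_weight(dims=(out_channels, v.dim(1), kernel[0], kernel[1]))
    return graph.conv2d(input=v, weight=weight, strides=stride, padding=pads, activation="RELU")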
Name        GFLOPs        MParams
resnet50    4.113562624   25.557032
resnext50   4.261431296   25.028904
isetnet_0   2.473464832   15.339816
isetnet_1   2.778018304   17.054104
isetnet_2   6.111159296   34.855528
isetnet_3   7.098017792   39.636352
isetnet_4   2.746732544   16.71396
isetnet_5   5.106156032   29.349224
isetnet_6   3.028927488   18.108968