JackieWu (wkcn)
🐳 Tell Your World 🎵
China
import math
import torch
from torch import nn
from torch.nn import LayerNorm
import numpy as np
from megatron.enums import AttnMaskType
from megatron.model.fused_layer_norm import MixedFusedLayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.utils import attention_mask_func
from megatron.global_vars import _parse_args
from flash_attn.flash_attention import FlashAttention

class Attention(nn.Module):
    use_flash_attn: bool = False

    def __init__(
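The preview above wires FlashAttention into a Megatron-style attention block but is cut off at __init__. For reference, the fused kernel computes the same softmax(QK^T / sqrt(d))V as the plain PyTorch below; this is a minimal comparison sketch, not the gist's implementation, and the (batch, heads, seq, head_dim) layout is an assumption.

import math
import torch

def reference_attention(q, k, v, causal=False):
    # q, k, v: (batch, heads, seq, head_dim) -- assumed layout for this sketch
    scale = 1.0 / math.sqrt(q.size(-1))
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    if causal:
        seq = scores.size(-1)
        # mask out positions to the right of the query (upper triangle)
        mask = torch.triu(torch.ones(seq, seq, dtype=torch.bool, device=q.device), diagonal=1)
        scores = scores.masked_fill(mask, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, v)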
@wkcn
wkcn / synset_words.txt
Created January 14, 2023 11:37
ImageNet-1k classification names
n01440764 tench, Tinca tinca
n01443537 goldfish, Carassius auratus
n01484850 great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias
n01491361 tiger shark, Galeocerdo cuvieri
n01494475 hammerhead, hammerhead shark
n01496331 electric ray, crampfish, numbfish, torpedo
n01498041 stingray
n01514668 cock
n01514859 hen
n01518878 ostrich, Struthio camelus
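Each line of synset_words.txt maps a WordNet ID to its comma-separated ImageNet-1k class names. A small parsing sketch (the path and dict layout here are illustrative, not part of the gist):

def load_synset_words(path="synset_words.txt"):
    # returns {wnid: [name, ...]}, e.g. "n01440764" -> ["tench", "Tinca tinca"]
    mapping = {}
    with open(path) as f:
        for line in f:
            wnid, _, names = line.strip().partition(" ")
            mapping[wnid] = [n.strip() for n in names.split(",")]
    return mapping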
@wkcn
wkcn / fp8_gemm.py
Created October 17, 2022 09:50
FP8GEMM
import torch
import transformer_engine.pytorch.cpp_extensions as texcpp
from transformer_engine.pytorch.module import get_workspace
import transformer_engine_extensions as tex
scale = 1.0
meta = tex.FP8TensorMeta()
meta.scale = torch.ones(1, dtype=torch.float32, device="cuda") * scale
meta.scale_inv = torch.ones(1, dtype=torch.float32, device="cuda") / scale
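The scale / scale_inv pair in FP8TensorMeta implements per-tensor scaling: inputs are multiplied by scale before the cast to FP8 and results are multiplied by scale_inv afterwards. A conceptual emulation in plain PyTorch (an illustration of the scaling idea only, not transformer_engine's kernel; the E4M3 maximum of 448 is the assumed format limit):

import torch

E4M3_MAX = 448.0  # largest finite value in the FP8 E4M3 format

def fake_fp8_quant(x, scale):
    # scale up and clamp to the FP8 dynamic range, approximating the cast
    return torch.clamp(x * scale, -E4M3_MAX, E4M3_MAX)

def fake_fp8_dequant(y, scale_inv):
    # undo the scaling after the low-precision GEMM
    return y * scale_inv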
@wkcn
wkcn / measure_fp8_speed.py
Created October 16, 2022 15:42
measure FP8 speed
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
import transformer_engine_extensions as tex
import copy
import math
from typing import Callable, Optional
def speedometer(
    module: torch.nn.Module,
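The speedometer helper is cut off by the preview; the usual pattern for this kind of benchmark is to warm up, then time repeated passes with CUDA events. A generic sketch of that pattern (argument names and iteration counts are made up here, not taken from the gist):

import torch

def time_module(module, inp, warmup=10, iters=100):
    # warm-up so lazy initialization and autotuning do not pollute the timing
    for _ in range(warmup):
        module(inp)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        module(inp)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # mean milliseconds per iteration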
@wkcn
wkcn / op.h
Last active July 11, 2019 08:35
MXNet CPP Op
This file has been truncated.
/*!
* Copyright (c) 2019 by Contributors
* \file op.h
* \brief definition of all the operators
* \author Chuntao Hong, Xin Li
*/
#ifndef MXNET_CPP_OP_H_
#define MXNET_CPP_OP_H_

Type these commands in CMD, in order:

bcdedit /create {0cb3b571-2f2e-4343-a879-d86a476d7215} /d "DebugTool" /application osloader

bcdedit /set {0cb3b571-2f2e-4343-a879-d86a476d7215} path "\EFI\Microsoft\Boot\SecConfig.efi"

bcdedit /set {bootmgr} bootsequence {0cb3b571-2f2e-4343-a879-d86a476d7215}

@wkcn
wkcn / custom_op_design.cpp
Created April 27, 2019 03:40
Custom Operator Design
#include <iostream>
#include <initializer_list>
using namespace std;
typedef union {
  int64_t v_int64;
  double v_float64;
  void* v_handle;
} TValue;
@wkcn
wkcn / del_thread.py
Last active April 20, 2019 01:54
Deleter Thread
import mxnet as mx
from mxnet.base import check_call, _LIB
from multiprocessing.pool import ThreadPool
import time
num_workers = 16
old_deleter = mx.nd.NDArray.__del__
del_pool = ThreadPool(num_workers)
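The preview stops after setting up the thread pool; the idea is to replace NDArray.__del__ so the actual free happens on a worker thread instead of blocking the caller. A rough sketch of how the pieces might be wired together (a guess at the rest of the gist, not its actual code; MXNDArrayFree is the C-API call the stock __del__ makes synchronously):

def _threaded_del(self):
    handle = getattr(self, "handle", None)
    if handle is None:
        return
    self.handle = None  # detach so nothing frees the handle twice
    # push the expensive free onto the pool; __del__ itself returns immediately
    del_pool.apply_async(lambda h=handle: check_call(_LIB.MXNDArrayFree(h)))

mx.nd.NDArray.__del__ = _threaded_del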
@wkcn
wkcn / dataloader.py
Created April 20, 2019 01:31
GluonDataloader
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#