Skip to content

Instantly share code, notes, and snippets.

View Raibows's full-sized avatar
👀
Exploring

Chi Zuo Raibows

👀
Exploring
View GitHub Profile
@Raibows
Raibows / split_qkv4vllm.py
Created May 13, 2024 02:05
Split the merged qkv_proj LoRA weights of Phi-3 into separate q_proj, k_proj, and v_proj weights; useful for enabling LoRA in vLLM.
import torch
import os
import json
from safetensors.torch import load_file, save_file
def replicate_lora_a(name: str, weight: "torch.Tensor") -> dict[str, "torch.Tensor"]:
prefix, suffix = name.split('qkv_proj')
res = {}
for t in ['q_proj', 'k_proj', 'v_proj']:
name = f"{prefix}{t}{suffix}"
@Raibows
Raibows / hf_download.py
Last active June 11, 2024 02:03
Download a model from Hugging Face.
import os
import time
import requests
from requests.adapters import HTTPAdapter, Retry
from huggingface_hub import configure_http_backend
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, help='name of huggingface model to download', default='sshleifer/tiny-gpt2')
@Raibows
Raibows / dispatch_openai_requests.py
Created May 8, 2023 09:10 — forked from neubig/dispatch_openai_requests.py
A simple script to get results from the OpenAI Asynchronous API
import openai
import asyncio
from typing import Any
async def dispatch_openai_requests(
messages_list: list[list[dict[str,Any]]],
model: str,
temperature: float,
max_tokens: int,
top_p: float,
@Raibows
Raibows / conda.yaml
Last active April 21, 2023 06:04
conda
channels:
- anaconda
- pytorch
- nvidia
- conda-forge
dependencies:
- python=3.8
- pip
- pytorch=1.13.1
- torchvision
@Raibows
Raibows / simcse.py
Last active November 30, 2021 02:50
SimCSE loss function: a PyTorch implementation.
# note this is a copy from https://paste.ubuntu.com/p/Nx5CcSmhHn/ for convenience
import torch
import torch.nn.functional as F
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def SimCSE_loss(pred, tau=0.05):
ids = torch.arange(0, pred.shape[0], device=device)
y_true = ids + 1 - ids % 2 * 2
similarities = F.cosine_similarity(pred.unsqueeze(1), pred.unsqueeze(0), dim=2)