@KexinFeng
KexinFeng / trace_model_with_past_key_values.py
Last active December 7, 2023 05:03
GPT2 model tracing
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
model_name = 'gpt2-large'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# add the EOS token as PAD token to avoid warnings
model = GPT2LMHeadModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id, torchscript=True)
# %% model_inputs
@KexinFeng
KexinFeng / TransferFreshFruitTrainExperiment.java
Last active December 16, 2022 17:24
Transfer learning experiment with the FreshFruit dataset: reducing the training data size.
import ai.djl.Model;
import ai.djl.ModelException;
import ai.djl.basicdataset.cv.classification.FruitsFreshAndRotten;
import ai.djl.engine.Engine;
import ai.djl.metric.Metrics;
import ai.djl.modality.cv.transform.CenterCrop;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.OneHot;
import ai.djl.modality.cv.transform.RandomFlipLeftRight;
import ai.djl.modality.cv.transform.RandomFlipTopBottom;
@KexinFeng
KexinFeng / README.md
Last active January 26, 2023 03:12 — forked from Carkham/README.md

This Gist contains Python scripts for the DJL timeseries package, including scripts for plotting predictions and producing coarse-grained data. For more information, check out the GluonTS example.

plot

To visualize your forecasts, run:

python plot.py -p YOUR_PRED_LENGTH -f YOUR_FREQUENCY -t YOUR_START_TIME -target-path YOUR_TARGET_PATH -samples-path YOUR_SAMPLES_PATH
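The source of plot.py is not shown here, so the following is only a minimal sketch of how a script could accept the flags in the command above. The flag names are taken from that command; the function name `build_parser`, the `dest` names, the argument types, the help text, and the sample values are all assumptions made for illustration.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser that accepts the same flags as the command above."""
    parser = argparse.ArgumentParser(
        description="Plot forecasts (illustrative sketch, not the real plot.py)"
    )
    # Flag names come from the command above; types and dest names are assumed.
    parser.add_argument("-p", dest="prediction_length", type=int, required=True,
                        help="number of time steps in the forecast")
    parser.add_argument("-f", dest="freq", required=True,
                        help="frequency of the time series")
    parser.add_argument("-t", dest="start_time", required=True,
                        help="timestamp of the first observation")
    parser.add_argument("-target-path", dest="target_path", required=True,
                        help="path to the file holding the target series")
    parser.add_argument("-samples-path", dest="samples_path", required=True,
                        help="path to the file holding the forecast samples")
    return parser


# Parse an explicit argument list instead of sys.argv so the sketch can be
# run as-is. The values below are placeholders, not real files.
example_argv = [
    "-p", "4",
    "-f", "D",
    "-t", "2011-01-29",
    "-target-path", "target.npz",
    "-samples-path", "samples.npz",
]
args = build_parser().parse_args(example_argv)
print(args.prediction_length, args.freq, args.start_time)
```

In a real script, `parse_args()` would be called without arguments so that the values come from the command line.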

plot result


@KexinFeng
KexinFeng / BertTranslator.java
Last active January 4, 2024 08:41
Deploying HuggingFace QA model in Java
import ai.djl.modality.nlp.DefaultVocabulary;
import ai.djl.modality.nlp.Vocabulary;
import ai.djl.modality.nlp.bert.BertToken;
import ai.djl.modality.nlp.bert.BertTokenizer;
import ai.djl.modality.nlp.qa.QAInput;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;