Created February 12, 2022 14:27
-
-
Save tezansahu/69fa8d1aa2a877cc38d3e014b18cd7ef to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Environment, imports and device setup.
import os

# These environment variables MUST be set before torch / datasets /
# transformers are imported: the Hugging Face libraries resolve their cache
# directories from HF_HOME when they are imported (assigning it afterwards is
# silently ignored), and CUDA_VISIBLE_DEVICES only takes effect if it is set
# before CUDA is initialised.
# SET CACHE FOR HUGGINGFACE TRANSFORMERS + DATASETS (relative to the current working directory)
os.environ['HF_HOME'] = os.path.join(".", "cache")
# SET ONLY 1 GPU DEVICE
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# NOTE: the imports below intentionally come after the environment setup
# above (flake8 E402 is expected here).
from copy import deepcopy
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

from datasets import load_dataset
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from transformers import (
    # Preprocessing / Common
    AutoTokenizer, AutoFeatureExtractor,
    # Text & Image Models (Now, image transformers like ViTModel, DeiTModel, BEiT can also be loaded using AutoModel)
    AutoModel,
    # Training / Evaluation
    TrainingArguments, Trainer,
    # Misc
    logging
)
import nltk

# Fetch the WordNet corpus only when it is not already installed, so the
# script does not contact the NLTK index server on every run.
try:
    nltk.data.find('corpora/wordnet')
except LookupError:
    nltk.download('wordnet')
from nltk.corpus import wordnet
from sklearn.metrics import accuracy_score, f1_score

# Keep `datasets` caching on. Newer releases provide enable_caching(); older
# releases only have set_caching_enabled(bool), which is deprecated in newer
# versions.
try:
    from datasets import enable_caching
except ImportError:
    from datasets import set_caching_enabled

    def enable_caching() -> None:
        """Enable `datasets` caching on releases that predate enable_caching()."""
        set_caching_enabled(True)

enable_caching()

# Only report errors from transformers (suppresses info/warning messages).
logging.set_verbosity_error()

# First visible GPU if CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.