"""
categories: List[str]
category_ids: List[int]
"""
# Pair each category name with its id, e.g. "name:3".
label_details = [f"{category}:{category_id}" for category, category_id in zip(categories, category_ids)]
"""
Updated full list:
https://simpletransformers.ai/docs/usage/#configuring-a-simple-transformers-model
"""
self.args = {
    "output_dir": "outputs/",
    "cache_dir": "cache_dir/",
# Count the total number of NaN values in the dataset
df.isnull().sum().sum()
# Show the rows that contain at least one NaN value
nan_rows = df[df.isnull().any(axis=1)]
print(nan_rows)
# Drop the rows where 'column_name' is NaN
df = df[df['column_name'].notnull()]
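To make the three NaN steps above concrete, here is a minimal sketch on a toy DataFrame. The data and the `other` column are made up for illustration; only `column_name` comes from the snippet above.

```python
import numpy as np
import pandas as pd

# Toy frame (hypothetical data) with three missing cells in total.
df = pd.DataFrame({
    "column_name": ["a", None, "c", "d"],
    "other": [1.0, 2.0, np.nan, np.nan],
})

# Number of missing cells across the whole frame.
total_nan = int(df.isnull().sum().sum())

# Rows that contain at least one missing value.
nan_rows = df[df.isnull().any(axis=1)]

# Keep only the rows where column_name is present.
cleaned = df[df["column_name"].notnull()]

print(total_nan)
print(nan_rows)
print(len(cleaned))
```

Note that the last step only filters on one column, so rows with NaN in other columns can remain in `cleaned`.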
import pandas as pd
# comments = [......]
# true_label = [......]
# predictions = [......]
df = pd.DataFrame({
    'text': comments,
    'true labels': true_label,
    'predicted labels': predictions,
})
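One possible use of such a results table is inspecting the misclassified rows. The sketch below fills the three lists with made-up values; the `misclassified` and `accuracy` names are hypothetical and not part of the original snippet.

```python
import pandas as pd

# Hypothetical inputs standing in for the elided lists above.
comments = ["great product", "terrible service", "it was fine"]
true_label = [1, 0, 1]
predictions = [1, 0, 0]

df = pd.DataFrame({
    "text": comments,
    "true labels": true_label,
    "predicted labels": predictions,
})

# Rows where the prediction disagrees with the true label.
misclassified = df[df["true labels"] != df["predicted labels"]]

# Fraction of rows predicted correctly.
accuracy = (df["true labels"] == df["predicted labels"]).mean()

print(misclassified)
print(accuracy)
```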
# data is our DataFrame
# data['class_id'] is our target column
# Split into 80% train and 20% test, keeping the class proportions (stratified)
from sklearn.model_selection import train_test_split
train_df, test_df = train_test_split(data,
                                     stratify=data['class_id'],
                                     test_size=0.20)
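As a quick check of what `stratify` does, the sketch below runs the same split on a toy frame with an 80/20 class imbalance. The data and the `random_state=42` seed are assumptions added for reproducibility; they are not in the original snippet.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Toy data (hypothetical): 80 rows of class 0 and 20 rows of class 1.
data = pd.DataFrame({
    "text": [f"sample {i}" for i in range(100)],
    "class_id": [0] * 80 + [1] * 20,
})

# Stratifying on class_id keeps the 80/20 class ratio in both splits.
train_df, test_df = train_test_split(
    data,
    stratify=data["class_id"],
    test_size=0.20,
    random_state=42,  # assumed seed, only to make the run repeatable
)

print(train_df["class_id"].value_counts().to_dict())
print(test_df["class_id"].value_counts().to_dict())
```

Without `stratify`, a small minority class can end up over- or under-represented in the test set purely by chance.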
# Randomly sample k elements from each class.
# Classes with fewer than k rows are returned unchanged.
def sampling_k_elements(group, k=3):
    if len(group) < k:
        return group
    return group.sample(k)

balanced = df.groupby('class').apply(sampling_k_elements).reset_index(drop=True)
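The per-class sampling can be tried on a toy frame as sketched below. This variant concatenates the sampled groups with `pd.concat` instead of calling `groupby(...).apply(...)`, which avoids depending on how `apply` treats the grouping column in a given pandas version. The data is made up for illustration.

```python
import pandas as pd

def sampling_k_elements(group, k=3):
    # Classes with fewer than k rows are kept whole.
    if len(group) < k:
        return group
    return group.sample(k)

# Toy data (hypothetical): class sizes 10, 5 and 2.
df = pd.DataFrame({
    "class": ["a"] * 10 + ["b"] * 5 + ["c"] * 2,
    "value": list(range(17)),
})

# Sample each class separately, then stitch the pieces back together.
balanced = pd.concat(
    [sampling_k_elements(group, k=3) for _, group in df.groupby("class")]
).reset_index(drop=True)

print(balanced["class"].value_counts().to_dict())
```

Classes "a" and "b" are cut down to three rows each, while "c" keeps both of its rows because it has fewer than `k`.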
# The values in column_name are mapped to the integer labels 0 and 1
# value_1 class will be converted to 0
# value_2 class will be converted to 1
# Binary Classification
data.loc[data['column_name'] == 'value_1', 'class_id'] = 0
data.loc[data['column_name'] == 'value_2', 'class_id'] = 1
# Multiclass Classification
for i in range(len(data['column_name'].unique())):
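The body of the multiclass loop is not shown above. One plausible reading, offered only as an assumption, is that the i-th unique value of `column_name` receives the id `i`. The sketch below implements that reading on made-up data and compares it with `pd.factorize`, which assigns codes in order of first appearance.

```python
import pandas as pd

# Toy data (hypothetical) with three classes.
data = pd.DataFrame({"column_name": ["cat", "dog", "bird", "dog", "cat"]})

# Assumed loop body: the i-th unique value gets class id i.
for i, value in enumerate(data["column_name"].unique()):
    data.loc[data["column_name"] == value, "class_id"] = i
data["class_id"] = data["class_id"].astype(int)

# Vectorised alternative: factorize numbers values by first appearance.
codes, uniques = pd.factorize(data["column_name"])

print(data["class_id"].tolist())
print(list(codes), list(uniques))
```

Both approaches number the classes by first appearance, so the ids depend on row order; a fixed mapping dictionary may be preferable when ids must stay stable across datasets.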
# Anaconda version
# Create the conda environment with Python 3.6.9
conda create -n envname python=3.6.9
# Then install the following packages at the specified versions
!pip install torch===1.2.0 torchvision===0.4.0 -f https://download.pytorch.org/whl/torch_stable.html
!pip install transformers==2.11.0
!pip install simpletransformers==0.41.1
!git clone --recursive https://github.com/NVIDIA/apex.git
!cd apex && pip install .
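# data is a pandas DataFrame
# Drop the rows where 'ColumnName' is NaN
data = data[data['ColumnName'].notnull()]
# keep=False removes every copy of a duplicated row, including the first
data = data.drop_duplicates(keep=False)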
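Because `keep=False` discards every copy of a duplicated row, it can remove more than expected. The sketch below contrasts it with `keep="first"` on made-up data; the `drop_all` and `keep_first` names are hypothetical.

```python
import pandas as pd

# Toy data (hypothetical): "x" appears twice and one value is missing.
data = pd.DataFrame({"ColumnName": ["x", "x", "y", None, "z"]})
data = data[data["ColumnName"].notnull()]

# keep=False: both "x" rows are dropped.
drop_all = data.drop_duplicates(keep=False)

# keep="first": one "x" row survives.
keep_first = data.drop_duplicates(keep="first")

print(drop_all["ColumnName"].tolist())
print(keep_first["ColumnName"].tolist())
```

If the goal is simply to deduplicate, `keep="first"` is probably the intended behaviour; `keep=False` suits cases where any repeated row is considered suspect.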
# Assuming we have only two classes
# data is a pandas DataFrame whose target column is 'class_id'
print(data['class_id'].value_counts())
# Assuming 1 is the minority class and 0 is the majority class
# Minority class length
minority_class_length = len(data[data['class_id'] == 1])
print(minority_class_length)
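The snippet above stops after measuring the minority class. A common next step, assumed here rather than taken from the source, is to undersample the majority class down to that length. The sketch uses made-up data, a `class_id` target column, and a fixed `random_state` for repeatability.

```python
import pandas as pd

# Toy data (hypothetical): 9 majority rows (0) and 3 minority rows (1).
data = pd.DataFrame({
    "text": [f"row {i}" for i in range(12)],
    "class_id": [0] * 9 + [1] * 3,
})

minority = data[data["class_id"] == 1]
majority = data[data["class_id"] == 0]
minority_class_length = len(minority)

# Randomly keep as many majority rows as there are minority rows.
majority_sample = majority.sample(n=minority_class_length, random_state=0)

# Combine and shuffle so the classes are interleaved.
balanced = (
    pd.concat([majority_sample, minority])
    .sample(frac=1, random_state=0)
    .reset_index(drop=True)
)

print(balanced["class_id"].value_counts().to_dict())
```

Undersampling throws away majority-class rows, so it is mostly appropriate when the majority class is large enough that the discarded rows carry little extra information.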