
@tezansahu
Created March 7, 2022 21:28
{
"cells": [
{
"cell_type": "markdown",
"id": "90d028fc",
"metadata": {},
"source": [
"# Financial Sentiment Analysis with Simple Transformers and MLFoundry"
]
},
{
"cell_type": "markdown",
"id": "76cee7f1",
"metadata": {},
"source": [
"## Preliminaries\n",
"\n",
"First, we need to import the required libraries."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "8536b639",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-02-25 14:48:17.131 INFO streamlit_gradio.networking: Hashes generated for all static assets.\n"
]
}
],
"source": [
"import os\n",
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import f1_score, accuracy_score\n",
"\n",
"import torch\n",
"from simpletransformers.classification import ClassificationModel\n",
"\n",
"import mlfoundry as mlf"
]
},
{
"cell_type": "markdown",
"id": "640f21c0",
"metadata": {},
"source": [
"Optionally, we can also clear the CUDA cache to free up GPU memory before training."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "0437c17e",
"metadata": {},
"outputs": [],
"source": [
"torch.cuda.empty_cache()"
]
},
{
"cell_type": "markdown",
"id": "f69d03e3",
"metadata": {},
"source": [
"## Exploring & Processing the Dataset"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "098d8ae9",
"metadata": {},
"outputs": [],
"source": [
"df = pd.read_csv(\"all-data.csv\", header=None, names=[\"sentiment\", \"headline\"], encoding='ISO-8859-1')"
]
},
{
"cell_type": "markdown",
"id": "90a7f746",
"metadata": {},
"source": [
"We inspect the distribution of the number of words in the headlines to choose a sensible `max_seq_length` for the tokenizer."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "9233f571",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([2.200e+02, 1.402e+03, 1.504e+03, 9.470e+02, 4.890e+02, 2.310e+02,\n",
" 5.000e+01, 2.000e+00, 0.000e+00, 1.000e+00]),\n",
" array([ 2. , 9.9, 17.8, 25.7, 33.6, 41.5, 49.4, 57.3, 65.2, 73.1, 81. ]),\n",
" <BarContainer object of 10 artists>)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD4CAYAAAAAczaOAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAASnUlEQVR4nO3df6zd9X3f8edruCGFdtjAHaO20+s1ViIWNQm7IkSpqgw6YiCK+SONQNHiZpasaXRNmkiJaaWitqpEtKo0kTokL7ghU0SS0XRYQEs9QhVtGiTXhBCDQ7kjDrYF+Cb8yFbUJm7f++N8vJw6Nva95/qc432eD+nofL/v7+ec7/vec/w6X3/O95ybqkKS1Id/NOkGJEnjY+hLUkcMfUnqiKEvSR0x9CWpI6sm3cCrufDCC2t2dnbSbUjSGWXPnj3fraqZ422b6tCfnZ1lfn5+0m1I0hklyXdOtM3pHUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6shUfyJXSze7/d6J7Hf/LddOZL+SlsYjfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdOWnoJ9mZ5HCSvcfZ9tEkleTCtp4kn0qykOSxJJcOjd2S5Kl22bKyP4Yk6VScypH+Z4BNxxaTrAeuAp4ZKl8NbGyXbcBtbez5wM3A24DLgJuTrBmlcUnS0p009KvqK8ALx9l0K/AxoIZqm4HP1sBDwOokFwPvAnZX1QtV9SKwm+O8kEiSTq9lfctmks3Aoar6RpLhTWuBA0PrB1vtRPXj3fc2Bv9L4HWve91y2pu4SX3TpSSdzJLfyE1yDvAbwG+tfDtQVTuqaq6q5mZmZk7HLiSpW8s5e+fngA3AN5LsB9YBjyT5p8AhYP3Q2HWtdqK6JGmMlhz6VfXNqvonVTVbVbMMpmourarngF3AB9pZPJcDL1fVs8D9wFVJ1rQ3cK9qNUnSGJ3KKZt3Av8TeEOSg0m2vsrw+4CngQXgPwH/DqCqXgB+F/hau/xOq0mSxuikb+RW1Q0n2T47tFzAjScYtxPYucT+JEkryE/kSlJHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqyKn8jdydSQ4n2TtU+w9JvpXksSR/mmT10LabkiwkeTLJu4bqm1ptIcn2Ff9JJEkndSpH+p8BNh1T2w28qap+Hvgr4CaAJJcA1wP/vN3mPyY5K8lZwB8BVwOXADe0sZKkMTpp6FfVV4AXjqn9RVUdaasPAeva8mbg81X1t1X1bWABuKxdFqrq6ar6AfD5NlaSNEYrMaf/b4A/a8trgQND2w622onqkqQxGin0k/wmcAT43Mq0A0m2JZlPMr+4uLhSdytJYoTQT/IrwLuB91dVtfIhYP3QsHWtdqL6j6mqHVU1V1VzMzMzy21PknQcywr9JJuAjwHvqapXhjbtAq5PcnaSDcBG4KvA14CNSTYkeQ2DN3t3jda6JGmpVp1sQJI7gXcCFyY5CNzM4Gyds4HdSQAeqqp/W1WPJ/ki8ASDaZ8bq+rv2v38KnA/cBaws6oePw0/jyTpVZw09KvqhuOUb3+V8b8H/N5x6vcB9y2pO0nSivITuZLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHThr6SXYmOZxk71Dt/CS7kzzVrte0epJ8KslCkseSXDp0my1t/FNJtpyeH0eS9GpO5Uj/M8CmY2rbgQeqaiPwQFsHuBrY2C7bgNtg8CIB3Ay8DbgMuPnoC4UkaXxOGvpV9RXghWPKm4E72vIdwHVD9c/WwEPA6iQXA+8CdlfVC1
X1IrCbH38hkSSdZsud07+oqp5ty88BF7XltcCBoXEHW+1E9R+TZFuS+STzi4uLy2xPknQ8I7+RW1UF1Ar0cvT+dlTVXFXNzczMrNTdSpJYfug/36ZtaNeHW/0QsH5o3LpWO1FdkjRGyw39XcDRM3C2AHcP1T/QzuK5HHi5TQPdD1yVZE17A/eqVpMkjdGqkw1IcifwTuDCJAcZnIVzC/DFJFuB7wDva8PvA64BFoBXgA8CVNULSX4X+Fob9ztVdeybw5Kk0+ykoV9VN5xg05XHGVvAjSe4n53AziV1J0laUX4iV5I6ctIjfelUzG6/dyL73X/LtRPZr3Sm8khfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktSRkUI/ya8neTzJ3iR3Jnltkg1JHk6ykOQLSV7Txp7d1hfa9tkV+QkkSads2aGfZC3wa8BcVb0JOAu4HvgEcGtVvR54EdjabrIVeLHVb23jJEljNOr0zirgJ5OsAs4BngWuAO5q2+8ArmvLm9s6bfuVSTLi/iVJS7Ds0K+qQ8DvA88wCPuXgT3AS1V1pA07CKxty2uBA+22R9r4C5a7f0nS0o0yvbOGwdH7BuBngHOBTaM2lGRbkvkk84uLi6PenSRpyCjTO78EfLuqFqvqh8CXgHcAq9t0D8A64FBbPgSsB2jbzwO+d+ydVtWOqpqrqrmZmZkR2pMkHWuU0H8GuDzJOW1u/krgCeBB4L1tzBbg7ra8q63Ttn+5qmqE/UuSlmiUOf2HGbwh+wjwzXZfO4CPAx9JssBgzv72dpPbgQta/SPA9hH6liQtw6qTDzmxqroZuPmY8tPAZccZ+zfAL4+yP0nSaPxEriR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktSRkUI/yeokdyX5VpJ9Sd6e5Pwku5M81a7XtLFJ8qkkC0keS3LpyvwIkqRTNeqR/ieBP6+qNwJvBvYB24EHqmoj8EBbB7ga2Ngu24DbRty3JGmJlh36Sc4DfhG4HaCqflBVLwGbgTvasDuA69ryZuCzNfAQsDrJxcvdvyRp6UY50t8ALAJ/nOTrST6d5Fzgoqp6to15DrioLa8FDgzd/mCr/QNJtiWZTzK/uLg4QnuSpGONEvqrgEuB26rqrcBf86OpHACqqoBayp1W1Y6qmququZmZmRHakyQda5TQPwgcrKqH2/pdDF4Enj86bdOuD7fth4D1Q7df12qSpDFZduhX1XPAgSRvaKUrgSeAXcCWVtsC3N2WdwEfaGfxXA68PDQNJEkag1Uj3v7fA59L8hrgaeCDDF5IvphkK/Ad4H1t7H3ANcAC8EobK0kao5FCv6oeBeaOs+nK44wt4MZR9idJGo2fyJWkjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI6M+jUM0kTNbr93Yvvef8u1E9u3tFwe6UtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkdGDv0kZyX5epJ72vqGJA8nWUjyhfZH00lydltfaNtnR923JGlpVuJI/0PAvqH1TwC3VtXrgReBra2+FXix1W9t4yRJYzRS6CdZB1wLfLqtB7gCuKsNuQO4ri1vbuu07Ve28ZKkMRn1SP8PgY8Bf9/WLwBeqqojbf0gsLYtrwUOALTtL7fx/0CSbUnmk8wvLi6O2J4kadiyQz/Ju4HDVbVnBfuhqnZU1VxVzc3MzKzkXUtS90b5auV3AO9Jcg3wWuAfA58EVidZ1Y7m1wGH2vhDwHrgYJJVwHnA90bYvyRpiZZ9pF9VN1XVuqqaBa4HvlxV7wceBN7bhm0B7m7Lu9o6bfuXq6qWu39J0tKdjvP0Pw
58JMkCgzn721v9duCCVv8IsP007FuS9CpW5C9nVdVfAn/Zlp8GLjvOmL8Bfnkl9idJWh4/kStJHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdWZG/nCX1aHb7vRPZ7/5brp3IfvX/h2Uf6SdZn+TBJE8keTzJh1r9/CS7kzzVrte0epJ8KslCkseSXLpSP4Qk6dSMcqR/BPhoVT2S5KeBPUl2A78CPFBVtyTZzuAPoH8cuBrY2C5vA25r16fNpI7EJGlaLftIv6qerapH2vL/BvYBa4HNwB1t2B3AdW15M/DZGngIWJ3k4uXuX5K0dCvyRm6SWeCtwMPARVX1bNv0HHBRW14LHBi62cFWO/a+tiWZTzK/uLi4Eu1JkpqRQz/JTwF/Any4qr4/vK2qCqil3F9V7aiquaqam5mZGbU9SdKQkUI/yU8wCPzPVdWXWvn5o9M27fpwqx8C1g/dfF2rSZLGZJSzdwLcDuyrqj8Y2rQL2NKWtwB3D9U/0M7iuRx4eWgaSJI0BqOcvfMO4F8D30zyaKv9BnAL8MUkW4HvAO9r2+4DrgEWgFeAD46wb0nSMiw79KvqvwM5weYrjzO+gBuXuz9J0uj8GgZJ6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6MsofUZE0AbPb753Yvvffcu3E9q2V4ZG+JHXE0Jekjhj6ktSRsYd+kk1JnkyykGT7uPcvST0ba+gnOQv4I+Bq4BLghiSXjLMHSerZuI/0LwMWqurpqvoB8Hlg85h7kKRujfuUzbXAgaH1g8Dbhgck2QZsa6v/J8mTr3J/FwLfXdEOV8a09gX2thzT2heMubd84pSH+jtbupXs62dPtGHqztOvqh3AjlMZm2S+quZOc0tLNq19gb0tx7T2BdPb27T2BdPb27j6Gvf0ziFg/dD6ulaTJI3BuEP/a8DGJBuSvAa4Htg15h4kqVtjnd6pqiNJfhW4HzgL2FlVj49wl6c0DTQB09oX2NtyTGtfML29TWtfML29jaWvVNU49iNJmgJ+IleSOmLoS1JHzsjQn6avckiyM8nhJHuHaucn2Z3kqXa9ZgJ9rU/yYJInkjye5ENT1Ntrk3w1yTdab7/d6huSPNwe1y+0N/vHLslZSb6e5J4p62t/km8meTTJfKtN/PFsfaxOcleSbyXZl+Ttk+4tyRva7+ro5ftJPjzpvob6+/X2/N+b5M727+K0P9fOuNCfwq9y+Ayw6ZjaduCBqtoIPNDWx+0I8NGqugS4HLix/Z6mobe/Ba6oqjcDbwE2Jbkc+ARwa1W9HngR2DqB3gA+BOwbWp+WvgD+ZVW9Zeh87ml4PAE+Cfx5Vb0ReDOD399Ee6uqJ9vv6i3AvwBeAf500n0BJFkL/BowV1VvYnBiy/WM47lWVWfUBXg7cP/Q+k3ATRPuaRbYO7T+JHBxW74YeHIKfm93A/9q2noDzgEeYfDJ7O8Cq473OI+xn3UMguAK4B4g09BX2/d+4MJjahN/PIHzgG/TTgyZpt6GerkK+B/T0hc/+naC8xmcRXkP8K5xPNfOuCN9jv9VDmsn1MuJXFRVz7bl54CLJtlMklngrcDDTElvbQrlUeAwsBv4X8BLVXWkDZnU4/qHwMeAv2/rF0xJXwAF/EWSPe3rSmA6Hs8NwCLwx21a7NNJzp2S3o66HrizLU+8r6o6BPw+8AzwLPAysIcxPNfOxNA/o9TgJXti58Um+SngT4APV9X3h7dNsreq+rsa/Ld7HYMv4nvjJPoYluTdwOGq2jPpXk7gF6rqUgZTmzcm+cXhjRN8PFcBlwK3VdVbgb/mmCmTST7X2rz4e4D/cuy2SfXV3kfYzOAF82eAc/nxaeLT4kwM/TPhqxyeT3IxQLs+PIkmkv
wEg8D/XFV9aZp6O6qqXgIeZPBf2dVJjn5gcBKP6zuA9yTZz+AbYK9gMFc96b6A/3d0SFUdZjA3fRnT8XgeBA5W1cNt/S4GLwLT0BsMXiQfqarn2/o09PVLwLerarGqfgh8icHz77Q/187E0D8TvsphF7ClLW9hMJ8+VkkC3A7sq6o/mLLeZpKsbss/yeC9hn0Mwv+9k+qtqm6qqnVVNcvgefXlqnr/pPsCSHJukp8+usxgjnovU/B4VtVzwIEkb2ilK4EnpqG35gZ+NLUD09HXM8DlSc5p/1aP/s5O/3NtUm+sjPgmyDXAXzGYB/7NCfdyJ4M5uR8yOOLZymAe+AHgKeC/AedPoK9fYPDf1seAR9vlminp7eeBr7fe9gK/1er/DPgqsMDgv+JnT/BxfSdwz7T01Xr4Rrs8fvR5Pw2PZ+vjLcB8e0z/K7BmGnpjMG3yPeC8odrE+2p9/DbwrfZv4D8DZ4/juebXMEhSR87E6R1J0jIZ+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakj/xccRKwY0LFbzAAAAABJRU5ErkJggg==",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"num_words = df[\"headline\"].apply(lambda x: len(x.split()))\n",
"plt.hist(num_words)"
]
},
{
"cell_type": "markdown",
"id": "67448b20",
"metadata": {},
"source": [
"Since the Simple Transformers library requires the data to be in a Pandas DataFrame with at least two columns, named `text` (of type `str`) and `labels` (of type `int`), we do the required processing before splitting the data into training & evaluation sets."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "103b2d6b",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>text</th>\n",
" <th>labels</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>according to gran , the company has no plans t...</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>technopolis plans to develop in stages an area...</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>the international electronic industry company ...</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>with the new production plant the company woul...</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>according to the company 's updated strategy f...</td>\n",
" <td>2</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" text labels\n",
"0 according to gran , the company has no plans t... 1\n",
"1 technopolis plans to develop in stages an area... 1\n",
"2 the international electronic industry company ... 0\n",
"3 with the new production plant the company woul... 2\n",
"4 according to the company 's updated strategy f... 2"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"labels = {\n",
" \"negative\": 0,\n",
" \"neutral\": 1,\n",
" \"positive\": 2\n",
"}\n",
"\n",
"df_new = pd.DataFrame({\n",
" \"text\": df[\"headline\"].apply(lambda x: x.lower()),\n",
" \"labels\": df[\"sentiment\"].apply(lambda x: labels[x])\n",
"})\n",
"\n",
"df_new.head()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "fbcb85fa",
"metadata": {},
"outputs": [],
"source": [
"train_df, eval_df = train_test_split(df_new, test_size=0.2)"
]
},
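{
"cell_type": "markdown",
"id": "5a1f2b3c",
"metadata": {},
"source": [
"Optionally, the split can be made reproducible and stratified. The cell below is a sketch (assuming `df_new` from above): `random_state` fixes the shuffle so the split is repeatable, and `stratify` keeps the label proportions the same in both splits."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6b2e3c4d",
"metadata": {},
"outputs": [],
"source": [
"# Optional alternative: reproducible, stratified split\n",
"train_df, eval_df = train_test_split(\n",
" df_new,\n",
" test_size=0.2,\n",
" random_state=42, # fix the seed so the split is repeatable\n",
" stratify=df_new[\"labels\"] # preserve label proportions across splits\n",
")"
]
},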
{
"cell_type": "markdown",
"id": "438eaea4",
"metadata": {},
"source": [
"We also look at the distribution of labels in the training set."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "abf5a6d8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([ 491., 2306., 1079.]),\n",
" array([0. , 0.66666667, 1.33333333, 2. ]),\n",
" <BarContainer object of 3 artists>)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD4CAYAAAAAczaOAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAPtklEQVR4nO3df6xkZX3H8fenoBh/RJfuSglQF5pNzJJUpBuklrRYGn4Zu5gmBtLWldKsttBo2jRZS1KMxpT+0dqQWhqqGyGxIPVH3SoWt0hjWrPIxSA/VGRFKGyQXVmKEhJazLd/zHPt4Xrv3rl7Z2Z3fd6vZDJnnvOcc77z3LOfOXPOzGyqCklSH37mUBcgSZodQ1+SOmLoS1JHDH1J6oihL0kdOfpQF3Aga9eurfXr1x/qMiTpiHLXXXd9v6rWLTbvsA799evXMzc3d6jLkKQjSpJHlprn6R1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SerIYf2NXP30Wb/t84e6BC3i4avffKhL0Ix4pC9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSR5YN/SQnJbk9yTeS3J/k3a392CQ7kzzY7te09iS5JsnuJPckOX2wri2t/4NJtkzvaUmSFjPOkf7zwJ9U1UbgTODyJBuBbcBtVbUBuK09BrgA2NBuW4FrYfQiAVwFvAE4A7hq/oVCkjQby4Z+VT1eVV9r0z8EvgmcAGwGrm/drgcuatObgRtqZBfwqiTHA+cBO6tqf1U9BewEzp/kk5EkHdiKzuknWQ+8HrgDOK6qHm+zvgcc16ZPAB4dLPZYa1uqfeE2tiaZSzK3b9++lZQnSVrG2KGf5OXAp4D3VNUPhvOqqoCaREFVdV1VbaqqTevWrZvEKiVJzVihn+RFjAL/41X16db8RDttQ7vf29r3ACcNFj+xtS3VLkmakXE+vRPgo8A3q+qvB7N2APOfwNkCfHbQ/vb2KZ4zgafbaaBbgXOTrGkXcM9tbZKkGTl6jD6/AvwucG+Su1vbnwFXAzcnuQx4BHhbm3cLcCGwG3gWuBSgqvYn+QBwZ+v3/qraP4knIUkaz7KhX1X/AWSJ2ecs0r+Ay5dY13Zg+0oKlCRNjt/IlaSOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHlg39JNuT7E1y36DtfUn2JLm73S4czHtvkt1JHkhy3qD9/Na2O8m2yT8VSdJyxjnS/xhw/iLtH6qq09rtFoAkG4GLgVPbMn+X5KgkRwEfBi4ANgKXtL6SpBk6erkOVfXlJOvHXN9m4Kaqeg74bpLdwBlt3u6qegggyU2t7zdWXrIk6WCt5pz+FUnuaad/1rS2E4BHB30ea21Ltf+EJFuTzCWZ27dv3yrKkyQtdLChfy3wC8BpwOPAX02qoKq6rqo2VdWmdevWTWq1kiTGOL2zmKp6Yn46yT8An2sP9wAnDbqe2No4QLskaUYO6kg/yfGDh28F5j/ZswO4OMkxSU4GNgBfBe4ENiQ5OcmLGV3s3XHwZUuSDsayR/pJbgTOBtYmeQy4Cjg7yWlAAQ8D7wSoqvuT3MzoAu3zwOVV9aO2niuAW4GjgO1Vdf+kn4wk6cDG+fTOJYs0f/QA/T8IfHCR9luAW1ZUnSRpovxGriR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHT
H0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0JakjRx/qAiQdeuu3ff5Ql6AFHr76zVNZr0f6ktQRQ1+SOmLoS1JHlg39JNuT7E1y36Dt2CQ7kzzY7te09iS5JsnuJPckOX2wzJbW/8EkW6bzdCRJBzLOkf7HgPMXtG0DbquqDcBt7THABcCGdtsKXAujFwngKuANwBnAVfMvFJKk2Vk29Kvqy8D+Bc2bgevb9PXARYP2G2pkF/CqJMcD5wE7q2p/VT0F7OQnX0gkSVN2sOf0j6uqx9v094Dj2vQJwKODfo+1tqXaf0KSrUnmkszt27fvIMuTJC1m1Rdyq6qAmkAt8+u7rqo2VdWmdevWTWq1kiQOPvSfaKdtaPd7W/se4KRBvxNb21LtkqQZOtjQ3wHMfwJnC/DZQfvb26d4zgSebqeBbgXOTbKmXcA9t7VJkmZo2Z9hSHIjcDawNsljjD6FczVwc5LLgEeAt7XutwAXAruBZ4FLAapqf5IPAHe2fu+vqoUXhyVJU7Zs6FfVJUvMOmeRvgVcvsR6tgPbV1SdJGmi/EauJHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjqyqtBP8nCSe5PcnWSutR2bZGeSB9v9mtaeJNck2Z3kniSnT+IJSJLGN4kj/TdV1WlVtak93gbcVlUbgNvaY4ALgA3tthW4dgLbliStwDRO72wGrm/T1wMXDdpvqJFdwKuSHD+F7UuSlrDa0C/gi0nuSrK1tR1XVY+36e8Bx7XpE4BHB8s+1tpeIMnWJHNJ5vbt27fK8iRJQ0evcvmzqmpPklcDO5N8azizqipJrWSFVXUdcB3Apk2bVrTsQuu3fX41i0vST51VHelX1Z52vxf4DHAG8MT8aZt2v7d13wOcNFj8xNYmSZqRgw79JC9L8or5aeBc4D5gB7ClddsCfLZN7wDe3j7Fcybw9OA0kCRpBlZzeuc44DNJ5tfzj1X1r0nuBG5OchnwCPC21v8W4EJgN/AscOkqti1JOggHHfpV9RDwukXanwTOWaS9gMsPdnuSpNXzG7mS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6sjMQz/J+UkeSLI7ybZZb1+SejbT0E9yFPBh4AJgI3BJko2zrEGSejbrI/0zgN1V9VBV/Q9wE7B5xjVIUreOnvH2TgAeHTx+DHjDsEOSrcDW9vCZJA+sYntrge+vYvlpsa6Vsa6Vsa6VOSzryl+uqq7XLDVj1qG/rKq6DrhuEutKMldVmyaxrkmyrpWxrpWxrpXpra5Zn97ZA5w0eHxia5MkzcCsQ/9OYEOSk5O8GLgY2DHjGiSpWzM9vVNVzye5ArgVOArYXlX3T3GTEzlNNAXWtTLWtTLWtTJd1ZWqmsZ6JUmHIb+RK0kdMfQlqSNHZOgv91MOSY5J8ok2/44k6wfz3tvaH0hy3ozr+uMk30hyT5LbkrxmMO9HSe5ut4le3B6jrnck2TfY/u8P5m1J8mC7bZlxXR8a1PTtJP89mDfN8dqeZG
+S+5aYnyTXtLrvSXL6YN40x2u5un671XNvkq8ked1g3sOt/e4kczOu6+wkTw/+Xn8+mDe1n2UZo64/HdR0X9unjm3zpjleJyW5vWXB/UnevUif6e1jVXVE3RhdAP4OcArwYuDrwMYFff4Q+Ps2fTHwiTa9sfU/Bji5reeoGdb1JuClbfoP5utqj585hOP1DuBvF1n2WOChdr+mTa+ZVV0L+v8Rowv/Ux2vtu5fBU4H7lti/oXAF4AAZwJ3THu8xqzrjfPbY/RTJ3cM5j0MrD1E43U28LnV7gOTrmtB37cAX5rReB0PnN6mXwF8e5F/k1Pbx47EI/1xfsphM3B9m/4kcE6StPabquq5qvousLutbyZ1VdXtVfVse7iL0fcUpm01P31xHrCzqvZX1VPATuD8Q1TXJcCNE9r2AVXVl4H9B+iyGbihRnYBr0pyPNMdr2XrqqqvtO3C7PavccZrKVP9WZYV1jXL/evxqvpam/4h8E1Gv1YwNLV97EgM/cV+ymHhgP24T1U9DzwN/OyYy06zrqHLGL2Sz3tJkrkku5JcNKGaVlLXb7W3kZ9MMv8FusNivNppsJOBLw2apzVe41iq9mmO10ot3L8K+GKSuzL6qZNZ++UkX0/yhSSntrbDYrySvJRRcH5q0DyT8cro1PPrgTsWzJraPnbY/QxDD5L8DrAJ+LVB82uqak+SU4AvJbm3qr4zo5L+Bbixqp5L8k5G75J+fUbbHsfFwCer6keDtkM5Xoe1JG9iFPpnDZrPauP1amBnkm+1I+FZ+Bqjv9czSS4E/hnYMKNtj+MtwH9W1fBdwdTHK8nLGb3QvKeqfjDJdR/IkXikP85POfy4T5KjgVcCT4657DTrIslvAFcCv1lVz823V9Wedv8Q8O+MXv1nUldVPTmo5SPAL4277DTrGriYBW+9pzhe41iq9kP+MyNJfpHR33BzVT053z4Yr73AZ5jcac1lVdUPquqZNn0L8KIkazkMxqs50P41lfFK8iJGgf/xqvr0Il2mt49N40LFNG+M3p08xOjt/vzFn1MX9LmcF17IvblNn8oLL+Q+xOQu5I5T1+sZXbjasKB9DXBMm14LPMiELmiNWdfxg+m3Arvq/y8afbfVt6ZNHzurulq/1zK6qJZZjNdgG+tZ+sLkm3nhRbavTnu8xqzr5xldp3rjgvaXAa8YTH8FOH+Gdf3c/N+PUXj+Vxu7sfaBadXV5r+S0Xn/l81qvNpzvwH4mwP0mdo+NrHBneWN0ZXtbzMK0Ctb2/sZHT0DvAT4p/YP4KvAKYNlr2zLPQBcMOO6/g14Ari73Xa09jcC97ad/l7gshnX9RfA/W37twOvHSz7e20cdwOXzrKu9vh9wNULlpv2eN0IPA78L6NzppcB7wLe1eaH0X8G9J22/U0zGq/l6voI8NRg/5pr7ae0sfp6+ztfOeO6rhjsX7sYvCgttg/Mqq7W5x2MPtwxXG7a43UWo2sG9wz+VhfOah/zZxgkqSNH4jl9SdJBMvQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSR/4Puv/ft8kFpX0AAAAASUVORK5CYII=",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.hist(train_df[\"labels\"], bins=3)"
]
},
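{
"cell_type": "markdown",
"id": "7c3d4e5f",
"metadata": {},
"source": [
"The histogram shows that the classes are imbalanced, with `neutral` dominating. As a quick sanity check (a sketch using `train_df` from above), we can print the exact per-class counts and proportions:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8d4e5f6a",
"metadata": {},
"outputs": [],
"source": [
"# Per-class counts and proportions in the training set\n",
"counts = train_df[\"labels\"].value_counts().sort_index()\n",
"print(counts)\n",
"print((counts / counts.sum()).round(3))"
]
},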
{
"cell_type": "markdown",
"id": "1c86a57c",
"metadata": {},
"source": [
"## Logging Experiment Details with MLFoundry"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "d0abe5f1",
"metadata": {},
"outputs": [],
"source": [
"# Initialize the API\n",
"mlf_api = mlf.get_client()"
]
},
{
"cell_type": "markdown",
"id": "ebde12d3",
"metadata": {},
"source": [
"### Training a Simple Transformers Model and Tracking Relevant Information\n",
"\n",
"We define a function to create, train, and evaluate our Simple Transformers model. It also accepts an MLFoundry run as input and logs all the required information (parameters, datasets, metrics, and dataset stats)."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "71006ddf",
"metadata": {},
"outputs": [],
"source": [
"def trainModel(model_params, training_args, run, run_name):\n",
" # Log the train/eval datasets and all hyperparameters to the MLFoundry run\n",
" run.log_dataset(train_df, data_slice=mlf.DataSlice.TRAIN, fileformat=mlf.FileFormat.CSV)\n",
" run.log_dataset(eval_df, data_slice=mlf.DataSlice.TEST, fileformat=mlf.FileFormat.CSV)\n",
" run.log_params({**model_params, **training_args})\n",
" \n",
" # Save each run's model outputs to its own directory so runs don't overwrite each other\n",
" training_args['output_dir'] = os.path.join('outputs', run_name)\n",
" training_args['overwrite_output_dir'] = True\n",
"\n",
" model = ClassificationModel(\n",
" model_params['model_type'], \n",
" model_params['model_name'], \n",
" num_labels=3, \n",
" args=training_args\n",
" )\n",
" \n",
" print(\"Training the model...\")\n",
" model.train_model(train_df)\n",
" \n",
" def f1_multiclass(labels, preds):\n",
" return f1_score(labels, preds, average='micro')\n",
" \n",
" print(\"Evaluating the model...\")\n",
" result, model_outputs, wrong_predictions = model.eval_model(eval_df, f1=f1_multiclass, acc=accuracy_score)\n",
" \n",
" run.log_metrics(result)\n",
" \n",
" def labelToSentiment(label):\n",
" if label == 0: \n",
" return \"negative\"\n",
" elif label == 1: \n",
" return \"neutral\"\n",
" else:\n",
" return \"positive\"\n",
" \n",
" eval_df_toLog = pd.DataFrame({\n",
" \"headline\": eval_df.text,\n",
" \"sentiment\": [labelToSentiment(label) for label in eval_df.labels.to_list()],\n",
" \"prediction\": [labelToSentiment(np.argmax(x)) for x in model_outputs]\n",
" })\n",
" run.log_dataset_stats(\n",
" eval_df_toLog,\n",
" data_slice=mlf.DataSlice.TEST,\n",
" data_schema=mlf.Schema(\n",
" feature_column_names=['headline'],\n",
" prediction_column_name='prediction',\n",
" actual_column_name='sentiment'\n",
" ),\n",
" model_type=mlf.ModelType.MULTICLASS_CLASSIFICATION\n",
" )\n",
" \n",
" return model, result, model_outputs, wrong_predictions\n",
" "
]
},
{
"cell_type": "markdown",
"id": "dedfb435",
"metadata": {},
"source": [
"### Training a BERT Model for FSA (for 3 epochs)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "7a027a92",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-02-25 14:48:19.006 WARNING mlfoundry.mlfoundry_run: failed to log git info due to \n",
"2022-02-25 14:48:19.006 INFO mlfoundry.mlfoundry_api: Run is created with id 4db8eebdc8dc421181c40dfe92682c75 and name bert_3epochs\n",
"2022-02-25 14:48:19.285 INFO mlfoundry.mlfoundry_run: Parameters logged successfully\n",
"Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias']\n",
"- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training the model...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-02-25 14:48:27.253 INFO simpletransformers.classification.classification_utils: Converting to features started. Cache is not used.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "332a4e5dd8464a8cb3b34749a4582d5d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/3876 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-02-25 14:48:32.895 INFO simpletransformers.classification.classification_utils: Saving features into cached file cache_dir/cached_train_bert_60_3_2\n",
"/data/tezansahu/venvs/simpletransformers/lib/python3.8/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use thePyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
" warnings.warn(\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b759f26907334e9699fed047793f3b10",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch: 0%| | 0/3 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ca7cb81ca85d4b9ebc578c367cd0b74f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Running Epoch 0 of 3: 0%| | 0/485 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2d1ae6de661548ffa954bbd13bab69db",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Running Epoch 1 of 3: 0%| | 0/485 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2e17c574a94641afb75490ea75e76da4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Running Epoch 2 of 3: 0%| | 0/485 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-02-25 14:49:55.687 INFO simpletransformers.classification.classification_model: Training of bert model complete. Saved to outputs/bert_3epochs.\n",
"2022-02-25 14:49:55.693 INFO simpletransformers.classification.classification_utils: Converting to features started. Cache is not used.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluating the model...\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e64442c8766b46349c9b213c8ce63aea",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/970 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-02-25 14:50:01.778 INFO simpletransformers.classification.classification_utils: Saving features into cached file cache_dir/cached_dev_bert_60_3_2\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e2cd6c36723b4277bdfef3b502f6d41d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Running Evaluation: 0%| | 0/122 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-02-25 14:50:03.382 INFO simpletransformers.classification.classification_model: {'mcc': 0.6659641718576687, 'f1': 0.8164948453608247, 'acc': 0.8164948453608247, 'eval_loss': 0.4939497337966669}\n",
"2022-02-25 14:50:03.450 INFO mlfoundry.mlfoundry_run: Metrics logged successfully\n",
"/data/tezansahu/venvs/simpletransformers/lib/python3.8/site-packages/mlfoundry/mlfoundry_run.py:314: FutureWarning: Passing a set as an indexer is deprecated and will raise in a future version. Use a list instead.\n",
" self.__compute_whylogs_stats(df[set(data_schema.feature_column_names)])\n",
"2022-02-25 14:50:03.553 INFO whylogs.app.config: No config file loaded\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARN: Missing config\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/data/tezansahu/venvs/simpletransformers/lib/python3.8/site-packages/mlfoundry/mlfoundry_run.py:404: FutureWarning: Passing a set as an indexer is deprecated and will raise in a future version. Use a list instead.\n",
" df[set(data_schema.feature_column_names)],\n",
"2022-02-25 14:50:05.259 INFO mlfoundry.mlfoundry_run: Metrics logged successfully\n",
"2022-02-25 14:50:05.386 INFO mlfoundry.mlfoundry_run: Dataset stats have been successfully computed and logged\n"
]
}
],
"source": [
"training_args = {\n",
" 'train_batch_size': 8,\n",
" 'gradient_accumulation_steps': 16,\n",
" 'learning_rate': 2e-5,\n",
" 'num_train_epochs': 3,\n",
" 'max_seq_length': 60 # based on the histogram of word counts in the headlines\n",
"}\n",
"\n",
"model_params = {\n",
" 'model_type': 'bert',\n",
" 'model_name': 'bert-base-uncased'\n",
"}\n",
"\n",
"mlf_run = mlf_api.create_run(project_name='financial-sentiment-analysis', run_name='bert_3epochs')\n",
"model, result, model_outputs, wrong_predictions = trainModel(model_params, training_args, mlf_run, 'bert_3epochs')"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "a2134f52-0602-41c4-887e-e8dae7f46ccf",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'mcc': 0.6659641718576687,\n",
" 'f1': 0.8164948453608247,\n",
" 'acc': 0.8164948453608247,\n",
" 'eval_loss': 0.4939497337966669}"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result"
]
},
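{
"cell_type": "markdown",
"id": "9e5f6a7b",
"metadata": {},
"source": [
"Beyond these aggregate metrics, a confusion matrix shows which sentiment classes the model confuses. The cell below is a sketch assuming `eval_df` and `model_outputs` from the run above (`model_outputs` holds one row of raw scores per evaluation example):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "af6a7b8c",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import confusion_matrix\n",
"\n",
"# Rows = true labels, columns = predictions (0=negative, 1=neutral, 2=positive)\n",
"preds = np.argmax(model_outputs, axis=1)\n",
"confusion_matrix(eval_df[\"labels\"], preds)"
]
},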
{
"cell_type": "markdown",
"id": "6c7aeb3a",
"metadata": {},
"source": [
"## Efficient Tracking and Comparison of Multiple Experiments with MLFoundry"
]
},
{
"cell_type": "markdown",
"id": "c0ff4f1d",
"metadata": {},
"source": [
"To demonstrate experimentation by changing hyperparameters, we create a new run called `bert_5epochs` and pass it to the `trainModel()` function along with the `training_args` dictionary (this time, the value of `num_train_epochs` in this dictionary is set to `5` instead of `3`) and `model_params` (no change)."
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "fa70edb5-a0f3-4abf-b72e-a7d81c9cc9ca",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-02-25 14:50:05.562 WARNING mlfoundry.mlfoundry_run: failed to log git info due to \n",
"2022-02-25 14:50:05.563 INFO mlfoundry.mlfoundry_api: Run is created with id a3403c310d6741a9af20160c667c2d5a and name bert_5epochs\n",
"2022-02-25 14:50:05.804 INFO mlfoundry.mlfoundry_run: Parameters logged successfully\n",
"Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias']\n",
"- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
"2022-02-25 14:50:10.773 INFO simpletransformers.classification.classification_utils: Converting to features started. Cache is not used.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training the model...\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "338310ed832649f7baaf7bbe960fb8cd",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/3876 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-02-25 14:50:17.694 INFO simpletransformers.classification.classification_utils: Saving features into cached file cache_dir/cached_train_bert_60_3_2\n",
"/data/tezansahu/venvs/simpletransformers/lib/python3.8/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use thePyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
" warnings.warn(\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c1fce46951b846f4bf7ea16bc2d6ced2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch: 0%| | 0/5 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "73d8f73516c7414eb71c8f43bc8e7580",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Running Epoch 0 of 5: 0%| | 0/485 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "005e1562ad3547e09193859882f86dda",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Running Epoch 1 of 5: 0%| | 0/485 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e7eec921e32e47d687f7408971f000a5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Running Epoch 2 of 5: 0%| | 0/485 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "df192596e84642d5a0fe5b3657e48a31",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Running Epoch 3 of 5: 0%| | 0/485 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "67cf8679e44342c98a3df60e4ab39e48",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Running Epoch 4 of 5: 0%| | 0/485 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-02-25 14:52:43.013 INFO simpletransformers.classification.classification_model: Training of bert model complete. Saved to outputs/bert_5epochs.\n",
"2022-02-25 14:52:43.021 INFO simpletransformers.classification.classification_utils: Converting to features started. Cache is not used.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluating the model...\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "03450e9b85a144aaa3e22edfdbf31680",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/970 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-02-25 14:52:49.967 INFO simpletransformers.classification.classification_utils: Saving features into cached file cache_dir/cached_dev_bert_60_3_2\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "26e857edca1940a695c46df8ad70c9f3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Running Evaluation: 0%| | 0/122 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-02-25 14:52:51.601 INFO simpletransformers.classification.classification_model: {'mcc': 0.6977812725208506, 'f1': 0.831958762886598, 'acc': 0.831958762886598, 'eval_loss': 0.41605544481121126}\n",
"2022-02-25 14:52:51.667 INFO mlfoundry.mlfoundry_run: Metrics logged successfully\n",
"/data/tezansahu/venvs/simpletransformers/lib/python3.8/site-packages/mlfoundry/mlfoundry_run.py:314: FutureWarning: Passing a set as an indexer is deprecated and will raise in a future version. Use a list instead.\n",
" self.__compute_whylogs_stats(df[set(data_schema.feature_column_names)])\n",
"/data/tezansahu/venvs/simpletransformers/lib/python3.8/site-packages/mlfoundry/mlfoundry_run.py:404: FutureWarning: Passing a set as an indexer is deprecated and will raise in a future version. Use a list instead.\n",
" df[set(data_schema.feature_column_names)],\n",
"2022-02-25 14:52:53.352 INFO mlfoundry.mlfoundry_run: Metrics logged successfully\n",
"2022-02-25 14:52:53.472 INFO mlfoundry.mlfoundry_run: Dataset stats have been successfully computed and logged\n"
]
}
],
"source": [
"training_args = {\n",
" 'train_batch_size':8, \n",
" 'gradient_accumulation_steps':16, \n",
" 'learning_rate': 2e-5, \n",
" 'num_train_epochs': 5, \n",
" 'max_seq_length': 60 # based on the histogram of number of words present in the headline\n",
"}\n",
"\n",
"model_params = {\n",
" 'model_type': 'bert',\n",
" 'model_name': 'bert-base-uncased'\n",
"}\n",
"\n",
"mlf_run_2 = mlf_api.create_run(project_name='financial-sentiment-analysis', run_name='bert_5epochs')\n",
"model, result, model_outputs, wrong_predictions = trainModel(model_params, training_args, mlf_run_2, 'bert_5epochs')"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "ca5f0c62-49cb-4076-82a3-324c3ba22b68",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'mcc': 0.6977812725208506,\n",
" 'f1': 0.831958762886598,\n",
" 'acc': 0.831958762886598,\n",
" 'eval_loss': 0.41605544481121126}"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result"
]
},
{
"cell_type": "markdown",
"id": "f7b6db8c",
"metadata": {},
"source": [
"To illustrate the use of different transformer architectures for experimentation, we create a new run called `roberta`. This time, we keep the `training_args` unchanged from the previous run (with 5 epochs) and change the `model_params` to use the weights of pre-trained RoBERTa instead of BERT."
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "36c2037b",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-02-25 14:52:53.717 WARNING mlfoundry.mlfoundry_run: failed to log git info due to \n",
"2022-02-25 14:52:53.718 INFO mlfoundry.mlfoundry_api: Run is created with id ac89ad4f2a5b4f06883658937e4b1067 and name roberta\n",
"2022-02-25 14:52:53.966 INFO mlfoundry.mlfoundry_run: Parameters logged successfully\n",
"Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.bias', 'lm_head.layer_norm.weight', 'lm_head.decoder.weight', 'roberta.pooler.dense.weight', 'lm_head.layer_norm.bias', 'roberta.pooler.dense.bias', 'lm_head.dense.bias', 'lm_head.dense.weight']\n",
"- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.weight', 'classifier.dense.bias', 'classifier.out_proj.weight', 'classifier.out_proj.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
"2022-02-25 14:53:22.528 INFO simpletransformers.classification.classification_utils: Converting to features started. Cache is not used.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training the model...\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cbf8070c7584495ba6a6c19e68f8215d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/3876 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-02-25 14:53:30.085 INFO simpletransformers.classification.classification_utils: Saving features into cached file cache_dir/cached_train_roberta_60_3_2\n",
"/data/tezansahu/venvs/simpletransformers/lib/python3.8/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use thePyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
" warnings.warn(\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3daa09d59d034c7182a01164b4857e7b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch: 0%| | 0/5 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f1d966d3c2ee4bd09b60871ac0d5249a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Running Epoch 0 of 5: 0%| | 0/485 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0257333b7c5f42b8bc99468ae96e1d4e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Running Epoch 1 of 5: 0%| | 0/485 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9fe228a903b44447b9f800ed02a6a6cb",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Running Epoch 2 of 5: 0%| | 0/485 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e9983426df834598b1ad04a8bdfdded4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Running Epoch 3 of 5: 0%| | 0/485 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4ae7ea3c40494caa8e8ba4bd45b5013b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Running Epoch 4 of 5: 0%| | 0/485 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-02-25 14:56:05.698 INFO simpletransformers.classification.classification_model: Training of roberta model complete. Saved to outputs/roberta.\n",
"2022-02-25 14:56:05.706 INFO simpletransformers.classification.classification_utils: Converting to features started. Cache is not used.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluating the model...\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0a17f8f63f3545adadb705700953f6cb",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/970 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-02-25 14:56:13.060 INFO simpletransformers.classification.classification_utils: Saving features into cached file cache_dir/cached_dev_roberta_60_3_2\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d0afb9b2c04847b2b5d31cefee8a8465",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Running Evaluation: 0%| | 0/122 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-02-25 14:56:14.731 INFO simpletransformers.classification.classification_model: {'mcc': 0.7625885046479652, 'f1': 0.8659793814432989, 'acc': 0.865979381443299, 'eval_loss': 0.35005076205144164}\n",
"2022-02-25 14:56:14.800 INFO mlfoundry.mlfoundry_run: Metrics logged successfully\n",
"/data/tezansahu/venvs/simpletransformers/lib/python3.8/site-packages/mlfoundry/mlfoundry_run.py:314: FutureWarning: Passing a set as an indexer is deprecated and will raise in a future version. Use a list instead.\n",
" self.__compute_whylogs_stats(df[set(data_schema.feature_column_names)])\n",
"/data/tezansahu/venvs/simpletransformers/lib/python3.8/site-packages/mlfoundry/mlfoundry_run.py:404: FutureWarning: Passing a set as an indexer is deprecated and will raise in a future version. Use a list instead.\n",
" df[set(data_schema.feature_column_names)],\n",
"2022-02-25 14:56:16.480 INFO mlfoundry.mlfoundry_run: Metrics logged successfully\n",
"2022-02-25 14:56:16.594 INFO mlfoundry.mlfoundry_run: Dataset stats have been successfully computed and logged\n"
]
}
],
"source": [
"training_args = {\n",
" 'train_batch_size':8, \n",
" 'gradient_accumulation_steps':16, \n",
" 'learning_rate': 2e-5, \n",
" 'num_train_epochs': 5, \n",
" 'max_seq_length': 60 # based on the histogram of number of words present in the headline\n",
"}\n",
"\n",
"model_params = {\n",
" 'model_type': 'roberta',\n",
" 'model_name': 'roberta-base'\n",
"}\n",
"\n",
"mlf_run_3 = mlf_api.create_run(project_name='financial-sentiment-analysis', run_name='roberta')\n",
"model, result, model_outputs, wrong_predictions = trainModel(model_params, training_args, mlf_run_3, 'roberta')"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "8a4bb3ba-60e7-47ba-9d0f-cd0d1a801cf0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'mcc': 0.7625885046479652,\n",
" 'f1': 0.8659793814432989,\n",
" 'acc': 0.865979381443299,\n",
" 'eval_loss': 0.35005076205144164}"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result"
]
},
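{
"cell_type": "markdown",
"id": "cmp-bert-roberta-md",
"metadata": {},
"source": [
"RoBERTa outperforms the 5-epoch BERT run on every metric (MCC 0.763 vs. 0.698, accuracy 0.866 vs. 0.832, and a lower eval loss). As a quick side-by-side comparison, we can tabulate the two results (values copied from the evaluation outputs above):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cmp-bert-roberta-code",
"metadata": {},
"outputs": [],
"source": [
"# Metrics for the two 5-epoch runs, copied from the evaluation outputs above\n",
"bert_result = {'mcc': 0.6978, 'f1': 0.8320, 'acc': 0.8320, 'eval_loss': 0.4161}\n",
"roberta_result = {'mcc': 0.7626, 'f1': 0.8660, 'acc': 0.8660, 'eval_loss': 0.3501}\n",
"\n",
"pd.DataFrame([bert_result, roberta_result], index=['bert_5epochs', 'roberta'])"
]
},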
{
"cell_type": "code",
"execution_count": 16,
"id": "f02d1468-2829-435a-971f-ce0f846c30b7",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>run_id</th>\n",
" <th>run_name</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>ac89ad4f2a5b4f06883658937e4b1067</td>\n",
" <td>roberta</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>a3403c310d6741a9af20160c667c2d5a</td>\n",
" <td>bert_5epochs</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>4db8eebdc8dc421181c40dfe92682c75</td>\n",
" <td>bert_3epochs</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" run_id run_name\n",
"0 ac89ad4f2a5b4f06883658937e4b1067 roberta\n",
"1 a3403c310d6741a9af20160c667c2d5a bert_5epochs\n",
"2 4db8eebdc8dc421181c40dfe92682c75 bert_3epochs"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mlf_api.get_all_runs(\"financial-sentiment-analysis\")"
]
},
{
"cell_type": "markdown",
"id": "978e6ff3",
"metadata": {},
"source": [
"## Model Demo using MLFoundry Web App"
]
},
{
"cell_type": "markdown",
"id": "d70ec8f3",
"metadata": {},
"source": [
"Create a standalone web app file that will be registered with the appropriate run. In this file, we first need to write a function that can load a saved Simple Transformer model and predict the sentiment using it for an input financial news headline. Then, we need to initialize the MLFoundry client, create a run and call `webapp()` from the run by supplying it with the prediction function, type(s) of inputs and outputs (here, we have just one input and one output, each of type `text`). This defines a model demo interface on the dashboard."
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "deb251dd-c226-4f52-83f3-97e597debb81",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Overwriting streamlit_roberta.py\n"
]
}
],
"source": [
"%%writefile streamlit_roberta.py\n",
"\n",
"import mlfoundry as mlf\n",
"import random\n",
"import os\n",
"from simpletransformers.classification import ClassificationModel\n",
"\n",
"def predict_model(model_params, input_headline):\n",
" try:\n",
" class_name_map = {\n",
" 0: \"negative\", \n",
" 1: \"neutral\",\n",
" 2: \"positive\"\n",
" }\n",
" \n",
" model_loaded = ClassificationModel(model_params[\"model_type\"], os.path.join(\"outputs\", model_params[\"model_name\"]))\n",
" \n",
" model_output = model_loaded.predict([input_headline.lower()])[0][0]\n",
" return class_name_map[model_output]\n",
" except Exception:\n",
" return random.choice([\"negative\", \"neutral\", \"positive\"])\n",
"\n",
"def predict_roberta(input_headline: str) -> str: \n",
" model_params = {\n",
" \"model_type\": \"roberta\",\n",
" \"model_name\": \"roberta\"\n",
" }\n",
" return predict_model(model_params, input_headline)\n",
"\n",
"mlf_api = mlf.get_client()\n",
"mlf_run = mlf_api.create_run(project_name=\"financial_sentiment_analysis_webapp\")\n",
"raw_in, raw_out = mlf_run.webapp(\n",
" fn=predict_roberta, inputs=\"text\", outputs=\"text\"\n",
")"
]
},
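{
"cell_type": "markdown",
"id": "sanity-check-md",
"metadata": {},
"source": [
"Before registering the file, we can sanity-check the prediction logic locally. The snippet below is a minimal check, assuming the trained RoBERTa model was saved under `outputs/roberta` during training (as the logs above indicate); the sample headline is illustrative only."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sanity-check-code",
"metadata": {},
"outputs": [],
"source": [
"# Minimal local sanity check (assumes the trained model exists at outputs/roberta)\n",
"model_loaded = ClassificationModel('roberta', os.path.join('outputs', 'roberta'))\n",
"\n",
"# predict() returns (predictions, raw_outputs); class ids: 0 = negative, 1 = neutral, 2 = positive\n",
"predictions, _ = model_loaded.predict(['profits rose sharply in the fourth quarter'])\n",
"print(predictions[0])"
]
},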
{
"cell_type": "markdown",
"id": "f27fa125",
"metadata": {},
"source": [
"Register this file with mlf_run_3, which tracks all the information for our RoBERTa-based experiment."
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "d66f2781-c16a-440e-a1f5-2a38c431419f",
"metadata": {},
"outputs": [],
"source": [
"mlf_run_3.log_webapp_file('streamlit_roberta.py')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}