Created
November 2, 2021 13:20
-
-
Save jot240/fdd1c5eaf1300795b0884c0c2017c3c4 to your computer and use it in GitHub Desktop.
Basic TF-IDF search in PySpark
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "ca1ac778", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"<SparkContext master=local appName=appName>\n" | |
] | |
} | |
], | |
"source": [ | |
"import findspark\n", | |
"findspark.init()\n", | |
"findspark.find()\n", | |
"import pyspark\n", | |
"from pyspark import SparkContext, SparkConf\n", | |
"from pyspark.sql import SparkSession\n", | |
"#Stop any SparkContext left over from a previous run, then build a fresh one.\n", | |
"#The conf/sc creation is written once below instead of duplicated in both branches.\n", | |
"try:\n", | |
"    sc\n", | |
"except NameError:\n", | |
"    pass\n", | |
"else:\n", | |
"    print(\"getting rid of old spark context\")\n", | |
"    sc.stop()\n", | |
"conf = pyspark.SparkConf().setAppName('appName').setMaster('local')\n", | |
"sc = pyspark.SparkContext(conf=conf)\n", | |
"\n", | |
"print(sc)\n", | |
"spark = SparkSession(sc)\n", | |
"\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "a1257c3c", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from pyspark.sql.types import StringType\n", | |
"from pyspark.sql import Row\n", | |
"from pyspark.sql.functions import lit, col, log\n", | |
"import glob, os\n", | |
"import re\n", | |
"#Tokenize an RDD of text lines: split on spaces, strip non-letter\n", | |
"#characters, and lowercase each token (empty tokens are filtered later\n", | |
"#via the stop-word dataframe).\n", | |
"def tokenize_text(txt):\n", | |
"    non_letters = re.compile('[^a-zA-Z]')\n", | |
"    words = txt.flatMap(lambda line: line.split(' '))\n", | |
"    return words.map(lambda w: non_letters.sub('', w).lower())\n", | |
"\n", | |
"#Build a per-word dataframe for one text file with columns word, count,\n", | |
"#title, total_words and norm_count (term frequency).\n", | |
"#Relies on globals sc (SparkContext) and stop_df (stop-word dataframe).\n", | |
"def make_entry(file_path):\n", | |
" txt = sc.textFile(file_path)\n", | |
" #title = file name without directory (Windows '\\' separator) or extension\n", | |
" article_title = file_path.rsplit('\\\\', 1)[-1].rsplit('.',1)[0]\n", | |
" tok_txt = tokenize_text(txt)\n", | |
" row = Row(\"word\") \n", | |
" tok_df = tok_txt.map(row).toDF().groupBy(\"word\").count()\n", | |
" tok_df = tok_df.withColumn(\"title\", lit(article_title))\n", | |
" #drop stop words: left_anti keeps only rows with no match in stop_df\n", | |
" tok_df =tok_df.join(stop_df, [tok_df.word == stop_df.word], how = 'left_anti')\n", | |
" #NOTE(review): tok_df.count() is the number of DISTINCT non-stop words,\n", | |
" #not the document's total word occurrences -- confirm this is the intended\n", | |
" #TF denominator before comparing scores across documents\n", | |
" tok_df = tok_df.withColumn(\"total_words\", lit(tok_df.count()))\n", | |
" words_df = tok_df.withColumn(\"norm_count\", col(\"count\")/col(\"total_words\"))\n", | |
" return words_df\n", | |
"\n", | |
"\n", | |
"#Build one dataframe of word statistics covering every matching file in\n", | |
"#folder_path (one make_entry dataframe per file, unioned together).\n", | |
"#Returns None (after printing a message) when no files match file_type.\n", | |
"def create_tokens_df(folder_path, stop_df, file_type = \"*.txt\"):\n", | |
"    #glob once and reuse the list (the original scanned the directory twice)\n", | |
"    files = glob.glob(folder_path + file_type)\n", | |
"    if not files:\n", | |
"        print(\"IndexError: No files of given type in directory\")\n", | |
"        return \n", | |
"    #seed the dataframe with the first file, then union in the rest\n", | |
"    words_df = make_entry(files[0])\n", | |
"    for file in files[1:]:\n", | |
"        words_df = words_df.union(make_entry(file))\n", | |
"    return words_df\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "bf62d799", | |
"metadata": {}, | |
"source": [ | |
"Create dataframe with unique words in each article and their count" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "cf34117f", | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"C:\\Users\\Letshopethisworks\\Documents\\Coding projects\\Big Data\n" | |
] | |
} | |
], | |
"source": [ | |
"print(os.getcwd())\n", | |
"folder_string = 'books/test/'\n", | |
"#build the stop-word dataframe that make_entry filters tokens against\n", | |
"stop_words = 'books/stopwords.txt'\n", | |
"stop_tokens = tokenize_text(sc.textFile(stop_words))\n", | |
"stop_df = stop_tokens.map(Row(\"word\")).toDF()\n", | |
"#empty strings produced by tokenization should be filtered out as well\n", | |
"empty_row = spark.createDataFrame([''],StringType())\n", | |
"stop_df = stop_df.union(empty_row)\n", | |
"#get dictionary of tokens\n", | |
"tokens_df = create_tokens_df(folder_string, stop_df)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "580b82c6", | |
"metadata": {}, | |
"source": [ | |
"Calculate the TF-IDF score for the applicable articles" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "417e5d7c", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"#Get N\n", | |
"from pyspark.sql.functions import countDistinct\n", | |
"#N = number of distinct documents (titles) in the corpus\n", | |
"N = tokens_df.select(countDistinct(\"title\")).collect()[0][0]\n", | |
"\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "2dc61f47", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"4\n" | |
] | |
} | |
], | |
"source": [ | |
"import numpy as np\n", | |
"#get df_t: the number of distinct documents containing the search term\n", | |
"search_term = \"stocks\"\n", | |
"search_df= tokens_df.filter(tokens_df.word == search_term)\n", | |
"df_t = search_df.select(countDistinct(\"title\")).collect()[0][0]\n", | |
"print(df_t)\n", | |
"#Calculate w_t_d\n", | |
"#inverse document frequency; NOTE(review): N/df_t raises ZeroDivisionError\n", | |
"#when the term appears in no document -- guard if arbitrary terms are expected\n", | |
"const = np.log(N/df_t)\n", | |
"search_df = search_df.withColumn(\"w_td\", search_df.norm_count*const)\n", | |
"\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "9e108058", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"+------+-----+-----+-----------+--------------------+--------------------+\n", | |
"| word|count|title|total_words| norm_count| w_td|\n", | |
"+------+-----+-----+-----------+--------------------+--------------------+\n", | |
"|stocks| 2| 144| 57| 0.03508771929824561| 0.10682534869204992|\n", | |
"|stocks| 2| 152| 108|0.018518518518518517| 0.05638004514302635|\n", | |
"|stocks| 2| 031| 115|0.017391304347826087|0.052948216308233445|\n", | |
"|stocks| 1| 134| 135|0.007407407407407408|0.022552018057210542|\n", | |
"+------+-----+-----+-----------+--------------------+--------------------+\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"#display the matching articles ranked by descending TF-IDF weight\n", | |
"ranked = search_df.orderBy(col(\"w_td\").desc())\n", | |
"ranked.show()" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.8" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment