# Skip to content
#
# Instantly share code, notes, and snippets.
#
# @smothiki
# Last active April 19, 2024 03:39
# Show Gist options
#   • Save smothiki/e265b5b1173fc261cfe5a99f85c0026f to your computer and use it in GitHub Desktop.
# Save smothiki/e265b5b1173fc261cfe5a99f85c0026f to your computer and use it in GitHub Desktop.
import logging
import logging as log
from pathlib import Path
from pyspark.sql import DataFrame
from pyspark.sql.types import StructType
from pyspark.sql.functions import udf
from pyspark.sql.types import StructField
from pyspark.sql.functions import lit
from pyspark.sql.types import StringType
import cml.data_v1 as cmldata
import os
import uuid
import os
# tika-python copies TIKA_CLIENT_ONLY / TIKA_SERVER_ENDPOINT into module-level
# settings when the `tika` package is first imported, so these variables must
# be set BEFORE the import below; assigning them afterwards has no effect.
# Client-only mode: use the Tika server at TIKA_SERVER_ENDPOINT instead of
# having tika-python manage a server of its own.
os.environ['TIKA_CLIENT_ONLY'] = "True"
os.environ['TIKA_SERVER_ENDPOINT'] = 'http://localhost:9998'
# NOTE(review): os.environ changes only affect the current (driver) process.
# Spark executors that run the text-extraction UDF may need the same variables
# (e.g. via spark.executorEnv.*) -- confirm for the target cluster.
# Tika is a library that allows you to extract text from a file in one of the many formats it supports
from tika import parser, detector, language
def pdfcontent(file):
    """
    Parse a document with Tika and return the text it extracted.

    Despite the name, the file may be in any format the Tika server supports.
    :param file: path to the document to parse.
    :return: the "content" entry of Tika's parse result. This can be None
             when nothing was extracted; the caller treats None as "no content".
    """
    parsed = parser.from_file(file)
    return parsed["content"]
#
#def pdfcontent(file):
# from pypdf import PdfReader
# reader = PdfReader(file)
# number_of_pages = len(reader.pages)
# text=''
# for i in range(0,number_of_pages):
# page = reader.pages[i]
# text += page.extract_text()
# return text
class SVError(Exception):
    """
    Base class for the errors raised by this module.

    :param message: the error message (may be None).
    """

    def __init__(self, message=None):
        """
        Constructor
        :param message: the error message.
        """
        # Forward the message to Exception so ``args`` is populated even when
        # the message is passed by keyword (BaseException only records
        # positional constructor arguments on its own).
        super().__init__(message)
        self.message = message

    def __str__(self):
        # Read ``self.message`` at call time so subclasses that fill in a
        # default message after calling this constructor are reported
        # correctly; an absent message renders as an empty string.
        return "" if self.message is None else str(self.message)
class TextExtractionError(SVError):
    """
    Exception raised when text could not be extracted from a file.

    :param path: name of the file from which text could not be extracted
    :param message: explanation of the error; when empty or None, a generic
                    message that includes ``path`` is used instead.
    """

    def __init__(self, path, message=None):
        # Resolve the default before delegating, so the base class is given
        # the final message instead of having it patched in afterwards.
        if not message:
            message = (
                f"Could not extract text from the file. "
                f"Check if it is a Tika-supported document format: {path}"
            )
        super().__init__(message=message)
        self.path = path
class TextExtraction:
    """
    Extracts plain text, document type and language from one file via Tika.

    Attributes populated by the constructor:
      path_     -- the path that was passed in
      subject   -- name of the directory that directly contains the file
      text_     -- extracted text with surrounding whitespace stripped
      doctype_  -- MIME type reported by Tika's detector
      language_ -- language reported by Tika for the extracted text
      id_       -- a random UUID4 string; a new one is drawn on every construction
    """

    def __init__(self, path: str):
        self.path_ = path
        # The parent directory's name is used as the document's subject.
        self.subject = os.path.basename(os.path.dirname(path))
        self.text_ = self.to_text(self.path_)
        self.doctype_ = self.document_type(self.path_)
        self.language_ = language.from_buffer(self.text_)
        self.id_ = str(uuid.uuid4())

    @staticmethod
    def to_text(path: str) -> str:
        """
        Extracts plain-text from a file, in one of the Tika-supported formats
        :param path: path to the document file
        :return: text from the document file, stripped of surrounding whitespace
        :raises TextExtractionError: if parsing fails or Tika returns no content
        """
        # Preconditions check for an existing, readable, non-empty file
        # check_valid_file(path)
        log.info("Parsing file: %s", path)
        # Only the parse itself is guarded: wrapping our own "no content"
        # error below in a second TextExtractionError would discard its message.
        try:
            text_content = pdfcontent(path)
        except Exception as e:
            # Chain the original failure so the root cause stays in the traceback.
            raise TextExtractionError(path, str(e)) from e
        if text_content is None:
            raise TextExtractionError(
                path=path, message=f"No content found in file: {path}"
            )
        return text_content.strip()

    @staticmethod
    def document_type(path: str) -> str:
        """
        Determines the MIME type of the file
        :param path: the filesystem path to the document.
        :return: the MIME-type, such as "application/pdf"
        """
        # Preconditions check for an existing, readable, non-empty file
        # check_valid_file(path)
        return detector.from_file(path)

    def __repr__(self):
        # Slicing already clamps to the string length, so short texts are safe.
        return f" Document type: {self.doctype_}\n Language: {self.language_}\n Text: {self.text_[:100]}..."
#class TextExtractionJob(BootcampComputeJob):
class TextExtractionJob():
    """
    This class is the entry point for the text extraction job.
    Given a directory of documents, it reads all the files in the directory,
    and all the subdirectories recursively, and extracts plain text from each file.
    It then stores the extracted text in a database table.

    NOTE(review): as written, _list_documents returns a single hard-coded path
    and the persist call in run() is commented out, so no table is written yet.
    """

    def __init__(self):
        self.job_name = "TextExtractionJob"
        logging.info('Initializing %s job', self.job_name)
        CONNECTION_NAME = "eng-ml-dev-env-aws-dl"
        conn = cmldata.get_connection(CONNECTION_NAME)
        self.spark = conn.get_spark_session()
        # Return type of the extraction UDF; the field names mirror the keys
        # of the dict built by _udf_text_extraction.
        self.text_struc = StructType([
            StructField("path", StringType(), True),
            StructField("subject", StringType(), True),
            StructField("text", StringType(), True),
            StructField("doctype", StringType(), True),
            StructField("language", StringType(), True),
            StructField("uuid", StringType(), True)
        ])

    @staticmethod
    def _udf_text_extraction(path):
        """
        A function that extracts text, its document-type and language
        from a file, given its path.

        :param path: path of the document to process.
        :return: dict with keys path, subject, text, doctype, language, uuid.
        :raises TextExtractionError: propagated from TextExtraction when no
                                     text can be extracted.
        """
        extraction = TextExtraction(path)
        return {"path": path,
                "subject": extraction.subject,
                "text": extraction.text_,
                "doctype": extraction.doctype_,
                "language": extraction.language_,
                "uuid": extraction.id_
                }

    def run(self) -> None:
        """
        This method is the entry point for the compute job where
        the text is extracted from the documents, and stored in a database table.
        :return: None
        """
        logging.info('Running %s job', self.job_name)
        files_df = self._list_documents()
        logging.info('Extracting text from %d files', files_df.count())
        df = self._extract_text(files_df)
        # self._persist(df=df, table='DOCUMENT')

    def _extract_text(self, files_df: DataFrame) -> DataFrame:
        """
        Extracts plain-text from each file in the DataFrame
        :param files_df: DataFrame containing the list of files (column "value")
        :return: DataFrame containing the extracted text, cached
        """
        # Step 1: Extract text from each file.
        # The UDF draws a fresh uuid4 on every call, so it is not deterministic.
        # Declaring that stops Spark's optimizer from duplicating or eliminating
        # calls, and with them the expensive Tika parse.
        extract_udf = udf(self._udf_text_extraction,
                          self.text_struc).asNondeterministic()
        files_df = files_df.withColumn('extract', extract_udf(files_df.value))
        # Step 2: Extract the columns from the nested structure
        df = files_df.select('extract.language',
                             'extract.path',
                             'extract.subject',
                             'extract.doctype',
                             'extract.text',
                             'extract.uuid')
        # Step 3: Rename the columns
        df = df.withColumnRenamed("language", "LANGUAGE") \
            .withColumnRenamed("uuid", "UUID") \
            .withColumnRenamed("path", "PATH") \
            .withColumnRenamed("subject", "SUBJECT") \
            .withColumnRenamed("doctype", "DOCTYPE") \
            .withColumnRenamed("text", "TEXT")
        # Step 4: Add boolean columns that help in later processing
        df = df.withColumn('CHUNKED', lit(False))
        # Cache so the actions below (and any later persist) reuse a single
        # extraction instead of re-parsing every file and minting new UUIDs.
        df = df.cache()
        print(df.count())
        # Step 5: Show the DataFrame
        df.show()
        return df

    def _list_documents(self) -> DataFrame:
        """
        Lists the documents to process and returns them as a DataFrame.

        NOTE(review): the list is currently a single hard-coded path rather
        than a recursive directory listing.
        :return: single-column ("value") DataFrame of file paths
        """
        # Step 1: Collect the document paths
        all_files = ["/home/cdsw/docs/test.pdf"]
        print("all docs", all_files)
        # Step 2: Read all file-names into a Spark DataFrame
        files = [str(file) for file in all_files]
        files_df = self.spark.createDataFrame(files, StringType())
        files_df.show(truncate=False)
        return files_df

    def describe(self):
        return 'Extracts text from documents in a directory, and stores it in a database table'
if __name__ == '__main__':
    # With no handler configured, the root logger only emits WARNING and
    # above, so the job's logging.info progress messages would be dropped.
    logging.basicConfig(level=logging.INFO)
    job = TextExtractionJob()
    try:
        job.run()
    finally:
        # Release the Spark session even when the job fails.
        job.spark.stop()
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment