Last active
April 4, 2024 22:11
-
-
Save orellabac/82061f3b449ae7861bbe4930d6728fb8 to your computer and use it in GitHub Desktop.
Function to extract XML fragments from an XML document. The input file may be compressed (ZIP, GZIP, and TAR.GZ formats are supported).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- EXTRACT_XML: Python table function (handler class XmlProcess) that returns
-- one row per extracted XML fragment in the single VARIANT column DATA.
--   FILE_PATH : file to read. Paths ending in ".zip", ".tar.gz"/".tgz" or ".gz"
--               are decompressed by the handler; any other path is read as XML.
--   QUERY     : XPath expression when ITER is FALSE; the element tag name given
--               to lxml's iterparse when ITER is TRUE.
--   ITER      : selects the iterparse-based extraction instead of XPath.
--   OPTIONS   : optional settings read by the handler: "redacted",
--               "array_elements", "namespaces", "omit_namespaces", "text",
--               "tag", "namespace_sep", "tag_case", "attribute_case".
CREATE OR REPLACE FUNCTION EXTRACT_XML(FILE_PATH STRING, QUERY STRING, ITER BOOLEAN DEFAULT FALSE, OPTIONS OBJECT DEFAULT {}) | |
RETURNS TABLE(DATA VARIANT) | |
LANGUAGE PYTHON | |
RUNTIME_VERSION='3.8' | |
PACKAGES = ('snowflake-snowpark-python','lxml') | |
HANDLER = 'XmlProcess' | |
AS | |
$$ | |
import gzip | |
import tarfile | |
import zipfile | |
import re | |
from snowflake.snowpark.files import SnowflakeFile | |
from lxml import etree | |
class XmlProcess:
    """UDTF handler that extracts XML fragments and converts them to Python values.

    ``process`` opens the input file (plain, ``.zip``, ``.tar.gz``/``.tgz`` or
    ``.gz``), locates fragments either with an XPath query or with lxml's
    ``iterparse``, and yields each fragment as a 1-tuple holding a dict, list
    or string built by ``element_to_dict``.
    """

    # Matches the ``encoding`` pseudo-attribute inside a leading XML
    # declaration.  Group 1 is the declaration text before it, group 2 the
    # closing ``?>`` when ``encoding`` is the last pseudo-attribute.
    _RE_XML_ENCODING = re.compile(
        r'^(<\?xml[^>]+)\s+encoding\s*=\s*["\'][^"\']*["\'](\s*\?>|)', re.U)

    def __init__(self):
        pass

    @staticmethod
    def _strip_encoding_declaration(xml_text):
        """Return *xml_text* with the ``encoding`` attribute of its XML declaration removed.

        The rest of the declaration (``version``, ``standalone``, ``?>``) is
        kept, so the text stays well-formed whatever follows ``encoding``.
        Text without a matching declaration is returned unchanged.
        """
        return XmlProcess._RE_XML_ENCODING.sub(r"\g<1>\g<2>", xml_text, count=1)

    def opts(self, options):
        """Read the OPTIONS object and configure the conversion helpers.

        Recognised keys:
          * ``redacted``: keys removed from every produced dict.
          * ``array_elements``: (mapped) tag names whose children are
            returned as a list.
          * ``namespaces``: prefix -> URI map, used for XPath and for
            ``tag == "with_namespace"``.
          * ``omit_namespaces``: namespace URIs whose tags are emitted
            without a prefix in ``with_namespace`` mode.
          * ``text``: key used for an element's own text (default ``#text``).
          * ``tag``: ``"with_namespace"`` or ``"remove_namespace"``.
          * ``namespace_sep``: separator between prefix and local name.
          * ``tag_case`` / ``attribute_case``: ``"upper"`` upper-cases names.
        """
        options = options or {}  # tolerate a NULL / missing OPTIONS argument
        self.redacted = options.get("redacted", [])
        self.array_elements = options.get("array_elements", [])
        self.namespaces = options.get("namespaces")
        self.omit_namespaces = options.get("omit_namespaces", [])
        self.text_attr = options.get("text") or "#text"
        self.get_attr = lambda attr_name: "@" + attr_name
        self.get_tag = lambda tag_name: tag_name

        tag_mode = (options.get("tag") or "").lower()
        if tag_mode == "with_namespace":
            nm_map = self.namespaces
            nm_separator = options.get("namespace_sep", "")
            # Without a namespace map the tag is left exactly as the parser
            # reports it.
            if nm_map:
                inv_nm_map = {v: k for k, v in nm_map.items()}

                def map_tag(tag_name):
                    qn = etree.QName(tag_name)
                    ns = qn.namespace
                    # Elements without a namespace have ns None; there is
                    # nothing to prefix, so return the bare local name.
                    if ns is None or ns in self.omit_namespaces:
                        return qn.localname
                    # Unknown URIs fall back to the URI itself as the prefix.
                    return inv_nm_map.get(ns, ns) + nm_separator + qn.localname

                self.get_tag = map_tag
        elif tag_mode == "remove_namespace":
            def remove_ns(tag_name):
                return etree.QName(tag_name).localname

            self.get_tag = remove_ns

        if (options.get("tag_case") or "").lower() == "upper":
            previous_get_tag = self.get_tag
            self.get_tag = lambda x: previous_get_tag(x).upper()
        if (options.get("attribute_case") or "").lower() == "upper":
            # NOTE(review): unlike the default mapper this does not add the
            # "@" prefix to attribute names -- confirm that is intended.
            self.get_attr = lambda attr_name: attr_name.upper()

    def element_to_dict(self, element):
        """Convert *element* (and its subtree) into a dict, list or string.

        * A tag listed in ``array_elements`` becomes the list of its
          converted children.
        * Otherwise attributes and children are collected in a dict; the
          element's stripped text is stored under ``text_attr``.
        * An element with no attributes and no children becomes its
          stripped text (possibly ``""``).

        The element is cleared after conversion to release its content.
        """
        element_tag = self.get_tag(element.tag)
        if element_tag in self.array_elements:
            return [self.element_to_dict(child) for child in element]

        result = {self.get_attr(k): v for k, v in element.attrib.items()}
        for child in element:
            # NOTE(review): siblings sharing a tag overwrite each other here
            # unless the parent is listed in "array_elements".
            result[self.get_tag(child.tag)] = self.element_to_dict(child)

        # Read the text before clear(), which resets it.
        inner_text = (element.text or "").strip()
        element.clear()

        if not result:
            # Nothing collected: the inner text is the whole value.
            return inner_text
        if inner_text:
            result[self.text_attr] = inner_text
        for entry in self.redacted:
            result.pop(entry, None)
        return result

    def extract_from_xml_xpath(self, xml_data, query):
        """Parse *xml_data* (bytes) and yield one 1-tuple per XPath *query* match.

        The bytes are decoded as UTF-8 and parsed as text, so the declared
        encoding is removed first (lxml rejects ``str`` input that carries an
        encoding declaration).
        """
        # "utf-8-sig" equals "utf-8" but drops a leading BOM, which would
        # otherwise stop the declaration regex from matching at position 0.
        xml_text = str(xml_data, "utf-8-sig")
        xml_data = None  # Free memory
        xml_text = self._strip_encoding_declaration(xml_text)
        root = etree.fromstring(xml_text)
        for node in root.xpath(query, namespaces=self.namespaces):
            yield (self.element_to_dict(node),)

    def extract_from_xml_iter(self, xml_data, element_name):
        """Yield one 1-tuple per *element_name* element found via ``iterparse``."""
        from io import BytesIO
        context = etree.iterparse(BytesIO(xml_data), tag=element_name)
        for _event, element in context:
            yield (self.element_to_dict(element),)

    def extract_from_xml(self, xml_data, query, iter):
        """Dispatch to the iterparse (``iter`` truthy) or XPath extraction."""
        if iter:
            yield from self.extract_from_xml_iter(xml_data, query)
        else:
            yield from self.extract_from_xml_xpath(xml_data, query)

    def process(self, file_path, query, iter, options):
        """UDTF entry point: yield a 1-tuple for every fragment in *file_path*.

        The container format is chosen from the path suffix:
          * ``.zip``: every member whose name ends in ``.xml``.
          * ``.tar.gz`` / ``.tgz``: every regular ``.xml`` member that is
            non-empty and does not start with a NUL byte.
          * ``.gz``: the single decompressed stream.
          * anything else: the file content itself.
        """
        self.opts(options)
        with SnowflakeFile.open(file_path, 'rb', require_scoped_url=False) as fhandle:
            if file_path.endswith(".zip"):
                with zipfile.ZipFile(fhandle, 'r') as zip_file:
                    for file_info in zip_file.infolist():
                        if not file_info.filename.endswith(".xml"):
                            continue
                        with zip_file.open(file_info) as open_zip:
                            xml_data = open_zip.read()
                        yield from self.extract_from_xml(xml_data, query, iter)
            elif file_path.endswith((".tar.gz", ".tgz")):
                with gzip.open(fhandle, 'rb') as gz_file:
                    with tarfile.open(fileobj=gz_file, mode='r:*') as tar:
                        for member in tar.getmembers():
                            if not (member.isfile() and member.name.endswith(".xml")):
                                continue
                            member_file = tar.extractfile(member)
                            if member_file is None:
                                continue
                            with member_file:
                                xml_data = member_file.read()
                            # Skip empty members and members that look binary
                            # (first byte is NUL).
                            if xml_data and xml_data[0] != 0:
                                yield from self.extract_from_xml(xml_data, query, iter)
            elif file_path.endswith(".gz"):
                with gzip.open(fhandle, 'rb') as gz_file:
                    xml_data = gz_file.read()
                yield from self.extract_from_xml(xml_data, query, iter)
            else:
                xml_data = fhandle.read()
                yield from self.extract_from_xml(xml_data, query, iter)
$$; |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment