@vtslab
Last active December 30, 2020 14:24
Converts spark-sql dtypes to a python-friendly format

pysql_dtypes.py
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import re
import string

import pyparsing


def pysql_dtypes(dtypes):
    """Represents the spark-sql dtypes in terms of python [], {} and Row()
    constructs.

    :param dtypes: [(string, string)] result from pyspark.sql.DataFrame.dtypes
    :return: [(string, string)]
    """
    def assemble(nested):
        # Walk the parse tree and rebuild the schema string, rewriting the
        # array<...>, map<...> and struct<...> wrappers as [...], {...} and
        # Row(...).
        cur = 0
        assembled = ''
        while cur < len(nested):
            parts = re.findall(r'[^:,]+', nested[cur])
            if not parts:
                parts = [nested[cur]]
            tail = parts[-1]
            if tail == 'array':
                assembled += nested[cur][:-5] + '['
                assembled += assemble(nested[cur + 1])
                assembled += ']'
                cur += 2
            elif tail == 'map':
                assembled += nested[cur][:-3] + '{'
                assembled += assemble(nested[cur + 1])
                assembled += '}'
                cur += 2
            elif tail == 'struct':
                assembled += nested[cur][:-6] + 'Row('
                assembled += assemble(nested[cur + 1])
                assembled += ')'
                cur += 2
            else:
                assembled += nested[cur]
                cur += 1
        return assembled

    # Parse the dtype string into a nested list, treating <...> as brackets.
    chars = ''.join([x for x in string.printable if x not in ['<', '>']])
    word = pyparsing.Word(chars)
    parens = pyparsing.nestedExpr('<', '>', content=word)
    dtype = word + pyparsing.Optional(parens)

    result = []
    for name, schema in dtypes:
        tree = dtype.parseString(schema).asList()
        # Add a space after each comma, then collapse any doubled spaces.
        pyschema = assemble(tree).replace(',', ', ').replace(',  ', ', ')
        result.append((name, pyschema))
    return result
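A quick illustration of the conversion. The column names and nested schema below are made up for the example and are not part of the gist; the output follows the same conventions the tests below assert:

>>> from pysql_dtypes import pysql_dtypes
>>> dtypes = [('id', 'bigint'),
...           ('tags', 'array<string>'),
...           ('address', 'struct<street:string,zip:map<string,string>>')]
>>> pysql_dtypes(dtypes)
[('id', 'bigint'), ('tags', '[string]'), ('address', 'Row(street:string, zip:{string, string})')]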
test_pysql_dtypes.py
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from pyspark.sql import types as t

from pysql_dtypes import pysql_dtypes


class TestPySqlDTypes:
    """Run with:

    export SPARK_HOME=/opt/spark3-client
    export PYTHONPATH=`echo $SPARK_HOME/python/lib/py4j-*-src.zip`
    export PYTHONPATH=.:$PYTHONPATH:$SPARK_HOME/python/lib/pyspark.zip
    pytest -vv test_pysql_dtypes.py
    """
    def test_atomic(self):
        data = []
        for atomic in t._atomic_types:
            type_name = atomic.typeName()
            data.append(('field_' + type_name, type_name))
        assert pysql_dtypes(data) == data

    def test_array(self):
        data = [('field', 'array<bigint>')]
        assert pysql_dtypes(data) == [('field', '[bigint]')]

    def test_map(self):
        data = [('field', 'map<string,bigint>')]
        assert pysql_dtypes(data) == [('field', '{string, bigint}')]

    def test_struct_with_atom_atom(self):
        data = [('field', 'struct<x:bigint,y:string>')]
        assert pysql_dtypes(data) == [('field', 'Row(x:bigint, y:string)')]

    def test_struct_with_atom_map(self):
        data = [(
            'field', 'struct<x:bigint,y:map<string,bigint>>')]
        assert pysql_dtypes(data) == [(
            'field', 'Row(x:bigint, y:{string, bigint})')]

    def test_struct_with_atom_atom_map(self):
        data = [(
            'field', 'struct<x:bigint,y:bigint,z:map<string,bigint>>')]
        assert pysql_dtypes(data) == [(
            'field', 'Row(x:bigint, y:bigint, z:{string, bigint})')]

    def test_struct_with_atom_array_map(self):
        data = [(
            'field', 'struct<x:bigint,y:array<bigint>,z:map<string,bigint>>')]
        assert pysql_dtypes(data) == [(
            'field', 'Row(x:bigint, y:[bigint], z:{string, bigint})')]

    def test_array_struct_with_atom_atom(self):
        data = [(
            'field', 'array<struct<x:string,y.z:array<string>>>')]
        assert pysql_dtypes(data) == [(
            'field', '[Row(x:string, y.z:[string])]')]

    def test_array_struct_with_atom_map(self):
        data = [(
            'field', 'array<struct<x.y:string,z:map<string,array<string>>>>')]
        assert pysql_dtypes(data) == [(
            'field', '[Row(x.y:string, z:{string, [string]})]')]

    def test_array_struct_with_arraystruct_atom_atom(self):
        data = [(
            'field', 'array<struct<x:array<struct<x1:string,x2:string>>,' +
            'y:string, z:string>>'
        )]
        assert pysql_dtypes(data) == [(
            'field', '[Row(x:[Row(x1:string, x2:string)], y:string, z:string)]'
        )]
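For use against a live DataFrame, a small wrapper along these lines works. This is a hypothetical helper, not part of the gist; only df.dtypes (the standard pyspark.sql.DataFrame property) is assumed:

# Hypothetical helper: print a DataFrame's schema in the python-friendly
# notation produced by pysql_dtypes.
from pysql_dtypes import pysql_dtypes

def print_pyschema(df):
    for name, pyschema in pysql_dtypes(df.dtypes):
        print('{}: {}'.format(name, pyschema))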