-
-
Save roee88/4aa7dfeceb2d8c3d8868ed8465ebf561 to your computer and use it in GitHub Desktop.
Arrow Java->Python bridge
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Licensed to the Apache Software Foundation (ASF) under one | |
# or more contributor license agreements. See the NOTICE file | |
# distributed with this work for additional information | |
# regarding copyright ownership. The ASF licenses this file | |
# to you under the Apache License, Version 2.0 (the | |
# "License"); you may not use this file except in compliance | |
# with the License. You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, | |
# software distributed under the License is distributed on an | |
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
# KIND, either express or implied. See the License for the | |
# specific language governing permissions and limitations | |
# under the License. | |
""" | |
Functions to interact with Arrow memory allocated by Arrow Java. | |
These functions convert the objects holding the metadata, the actual | |
data is not copied at all. | |
This will only work with a JVM running in the same process such as provided | |
through jpype. Modules that talk to a remote JVM like py4j will not work as the | |
memory addresses reported by them are not reachable in the python process. | |
""" | |
import pyarrow as pa | |
from pyarrow.cffi import ffi | |
class JvmToPython:
    """Bridge that imports Arrow Java objects into pyarrow via the Arrow C
    Data Interface.

    Only the metadata structs are converted; the underlying data buffers are
    shared between the JVM and Python, never copied. Requires an in-process
    JVM (e.g. jpype) so that the raw addresses are valid in this process.
    """

    def __init__(self, jvm_allocator, package):
        """
        A bridge between Arrow Java and pyarrow.

        Parameters
        ----------
        jvm_allocator: org.apache.arrow.memory.BufferAllocator
            Allocator the Java side uses when exporting structures.
        package: A Python object with the following attributes:
            - ArrowSchema: org.apache.arrow.c.ArrowSchema class
            - ArrowArray: org.apache.arrow.c.ArrowArray class
            - Data: org.apache.arrow.c.Data class
        """
        self.allocator = jvm_allocator
        self.c_package = package

    def field(self, jvm_field, jvm_dictionary_provider=None):
        """
        Construct a Field from a org.apache.arrow.vector.types.pojo.Field
        instance.

        Parameters
        ----------
        jvm_field: org.apache.arrow.vector.types.pojo.Field
        jvm_dictionary_provider: org.apache.arrow.vector.dictionary.DictionaryProvider

        Returns
        -------
        pyarrow.Field
        """
        # Allocate an empty C struct, let Java fill it in through its raw
        # address, then let pyarrow take ownership of the exported schema.
        c_schema = ffi.new("struct ArrowSchema*")
        ptr_schema = int(ffi.cast("uintptr_t", c_schema))
        self.c_package.Data.exportField(
            self.allocator, jvm_field, jvm_dictionary_provider,
            self.c_package.ArrowSchema.wrap(ptr_schema))
        return pa.Field._import_from_c(ptr_schema)

    def schema(self, jvm_schema, jvm_dictionary_provider=None):
        """
        Construct a Schema from a org.apache.arrow.vector.types.pojo.Schema
        instance.

        Parameters
        ----------
        jvm_schema: org.apache.arrow.vector.types.pojo.Schema
        jvm_dictionary_provider: org.apache.arrow.vector.dictionary.DictionaryProvider

        Returns
        -------
        pyarrow.Schema
        """
        c_schema = ffi.new("struct ArrowSchema*")
        ptr_schema = int(ffi.cast("uintptr_t", c_schema))
        self.c_package.Data.exportSchema(
            self.allocator, jvm_schema, jvm_dictionary_provider,
            self.c_package.ArrowSchema.wrap(ptr_schema))
        return pa.Schema._import_from_c(ptr_schema)

    def array(self, jvm_vector, jvm_dictionary_provider=None):
        """
        Construct an (Python) Array from its JVM equivalent.

        Parameters
        ----------
        jvm_vector: org.apache.arrow.vector.FieldVector
        jvm_dictionary_provider: org.apache.arrow.vector.dictionary.DictionaryProvider

        Returns
        -------
        array : Array
        """
        # Arrays need both a schema (type) and an array (data) struct.
        c_schema = ffi.new("struct ArrowSchema*")
        ptr_schema = int(ffi.cast("uintptr_t", c_schema))
        c_array = ffi.new("struct ArrowArray*")
        ptr_array = int(ffi.cast("uintptr_t", c_array))
        self.c_package.Data.exportVector(
            self.allocator, jvm_vector, jvm_dictionary_provider,
            self.c_package.ArrowArray.wrap(ptr_array),
            self.c_package.ArrowSchema.wrap(ptr_schema))
        return pa.Array._import_from_c(ptr_array, ptr_schema)

    def record_batch(self, jvm_vector_schema_root, jvm_dictionary_provider=None):
        """
        Construct a (Python) RecordBatch from a JVM VectorSchemaRoot.

        Parameters
        ----------
        jvm_vector_schema_root : org.apache.arrow.vector.VectorSchemaRoot
        jvm_dictionary_provider: org.apache.arrow.vector.dictionary.DictionaryProvider

        Returns
        -------
        record_batch: pyarrow.RecordBatch
        """
        c_schema = ffi.new("struct ArrowSchema*")
        ptr_schema = int(ffi.cast("uintptr_t", c_schema))
        c_array = ffi.new("struct ArrowArray*")
        ptr_array = int(ffi.cast("uintptr_t", c_array))
        self.c_package.Data.exportVectorSchemaRoot(
            self.allocator, jvm_vector_schema_root, jvm_dictionary_provider,
            self.c_package.ArrowArray.wrap(ptr_array),
            self.c_package.ArrowSchema.wrap(ptr_schema))
        return pa.RecordBatch._import_from_c(ptr_array, ptr_schema)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Licensed to the Apache Software Foundation (ASF) under one | |
# or more contributor license agreements. See the NOTICE file | |
# distributed with this work for additional information | |
# regarding copyright ownership. The ASF licenses this file | |
# to you under the Apache License, Version 2.0 (the | |
# "License"); you may not use this file except in compliance | |
# with the License. You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, | |
# software distributed under the License is distributed on an | |
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
# KIND, either express or implied. See the License for the | |
# specific language governing permissions and limitations | |
# under the License. | |
import json | |
import os | |
from numpy import int32, string_ | |
import pyarrow as pa | |
import jvm as pa_jvm | |
import pytest | |
import sys | |
import xml.etree.ElementTree as ET | |
jpype = pytest.importorskip("jpype") | |
@pytest.fixture(scope="session")
def root_allocator():
    """Start an in-process JVM with the Arrow Java jars on the class path
    and return a session-wide RootAllocator."""
    # This test requires Arrow Java to be built in the same source tree
    try:
        arrow_dir = os.environ["ARROW_SOURCE_DIR"]
    except KeyError:
        arrow_dir = os.path.join(os.path.dirname(
            __file__), '..', '..', '..', '..', '..')
    # Read the Arrow version out of the Maven POM so the jar names match.
    pom_path = os.path.join(arrow_dir, 'java', 'pom.xml')
    tree = ET.parse(pom_path)
    version = tree.getroot().find(
        'POM:version',
        namespaces={
            'POM': 'http://maven.apache.org/POM/4.0.0'
        }).text
    jar_path = os.path.join(
        arrow_dir, 'java', 'tools', 'target',
        'arrow-tools-{}-jar-with-dependencies.jar'.format(version))
    jar_path = os.getenv("ARROW_TOOLS_JAR", jar_path)
    # Use os.pathsep instead of a hard-coded ':' so the class path is also
    # valid on Windows, where Java expects ';' as the separator.
    jar_path += "{}{}".format(
        os.pathsep,
        os.path.join(arrow_dir, "java",
                     "c/target/arrow-c-data-{}.jar".format(version)))
    kwargs = {}
    # This will be the default behaviour in jpype 0.8+
    kwargs['convertStrings'] = False
    jpype.startJVM(jpype.getDefaultJVMPath(), "-Djava.class.path=" + jar_path,
                   **kwargs)
    return jpype.JPackage("org").apache.arrow.memory.RootAllocator(sys.maxsize)
@pytest.fixture(scope="session")
def bridge(root_allocator):
    # Session-wide JvmToPython bridge wired to the org.apache.arrow.c package.
    return pa_jvm.JvmToPython(root_allocator, jpype.JPackage("org").apache.arrow.c)
# def test_jvm_buffer(root_allocator): | |
# # Create a Java buffer | |
# jvm_buffer = root_allocator.buffer(8) | |
# for i in range(8): | |
# jvm_buffer.setByte(i, 8 - i) | |
# orig_refcnt = jvm_buffer.refCnt() | |
# # Convert to Python | |
# buf = pa_jvm.jvm_buffer(jvm_buffer) | |
# # Check its content | |
# assert buf.to_pybytes() == b'\x08\x07\x06\x05\x04\x03\x02\x01' | |
# # Check Java buffer lifetime is tied to PyArrow buffer lifetime | |
# assert jvm_buffer.refCnt() == orig_refcnt + 1 | |
# del buf | |
# assert jvm_buffer.refCnt() == orig_refcnt | |
# def test_jvm_buffer_released(root_allocator): | |
# import jpype.imports # noqa | |
# from java.lang import IllegalArgumentException | |
# jvm_buffer = root_allocator.buffer(8) | |
# jvm_buffer.release() | |
# with pytest.raises(IllegalArgumentException): | |
# pa_jvm.jvm_buffer(jvm_buffer) | |
def _jvm_field(jvm_spec):
    """Deserialize a JSON field spec into a
    org.apache.arrow.vector.types.pojo.Field via Jackson."""
    mapper = jpype.JClass('com.fasterxml.jackson.databind.ObjectMapper')()
    field_cls = jpype.JClass('org.apache.arrow.vector.types.pojo.Field')
    return mapper.readValue(jvm_spec, field_cls)
def _jvm_schema(jvm_spec, metadata=None):
    """Build a single-field org.apache.arrow.vector.types.pojo.Schema from a
    JSON field spec, optionally attaching schema-level metadata."""
    fields = jpype.JClass('java.util.ArrayList')()
    fields.add(_jvm_field(jvm_spec))
    schema_cls = jpype.JClass('org.apache.arrow.vector.types.pojo.Schema')
    if not metadata:
        return schema_cls(fields)
    jvm_meta = jpype.JClass('java.util.HashMap')()
    for key, value in metadata.items():
        jvm_meta.put(key, value)
    return schema_cls(fields, jvm_meta)
# In the following, we use the JSON serialization of the Field objects in Java. | |
# This ensures that we neither rely on the exact mechanics on how to construct | |
# them using Java code as well as enables us to define them as parameters | |
# without to invoke the JVM. | |
# | |
# The specifications were created using: | |
# | |
# om = jpype.JClass('com.fasterxml.jackson.databind.ObjectMapper')() | |
# field = … # Code to instantiate the field | |
# jvm_spec = om.writeValueAsString(field) | |
@pytest.mark.parametrize('pa_type,jvm_spec', [ | |
(pa.null(), {"type": {"name": "null"}}), | |
(pa.bool_(), {"type": {"name":"bool"}}), | |
(pa.int8(), {"type": {"name":"int","bitWidth":8,"isSigned":True}}), | |
(pa.int16(), {"type": {"name":"int","bitWidth":16,"isSigned":True}}), | |
(pa.int32(), {"type": {"name":"int","bitWidth":32,"isSigned":True}}), | |
(pa.int64(), {"type": {"name":"int","bitWidth":64,"isSigned":True}}), | |
(pa.uint8(), {"type": {"name":"int","bitWidth":8,"isSigned":False}}), | |
(pa.uint16(), {"type": {"name":"int","bitWidth":16,"isSigned":False}}), | |
(pa.uint32(), {"type": {"name":"int","bitWidth":32,"isSigned":False}}), | |
(pa.uint64(), {"type": {"name":"int","bitWidth":64,"isSigned":False}}), | |
(pa.float16(), {"type": {"name":"floatingpoint","precision":"HALF"}}), | |
(pa.float32(), {"type": {"name":"floatingpoint","precision":"SINGLE"}}), | |
(pa.float64(), {"type": {"name":"floatingpoint","precision":"DOUBLE"}}), | |
(pa.time32('s'), {"type": {"name":"time","unit":"SECOND","bitWidth":32}}), | |
(pa.time32('ms'), {"type": {"name":"time","unit":"MILLISECOND","bitWidth":32}}), | |
(pa.time64('us'), {"type": {"name":"time","unit":"MICROSECOND","bitWidth":64}}), | |
(pa.time64('ns'), {"type": {"name":"time","unit":"NANOSECOND","bitWidth":64}}), | |
(pa.timestamp('s'), {"type": {"name":"timestamp","unit":"SECOND", "timezone": None}}), | |
(pa.timestamp('ms'), {"type": {"name":"timestamp","unit":"MILLISECOND","timezone":None}}), | |
(pa.timestamp('us'), {"type": {"name":"timestamp","unit":"MICROSECOND","timezone":None}}), | |
(pa.timestamp('ns'), {"type": {"name":"timestamp","unit":"NANOSECOND","timezone":None}}), | |
(pa.timestamp('ns', tz='UTC'), {"type": {"name":"timestamp","unit":"NANOSECOND","timezone":"UTC"}}), | |
(pa.timestamp('ns', tz='Europe/Paris'), {"type": {"name":"timestamp","unit":"NANOSECOND","timezone":"Europe/Paris"}}), | |
(pa.date32(), {"type": {"name":"date","unit":"DAY"}}), | |
(pa.date64(), {"type": {"name":"date","unit":"MILLISECOND"}}), | |
(pa.decimal128(19, 4), {"type": {"name":"decimal","precision":19,"scale":4}}), | |
(pa.string(), {"type": {"name":"utf8"}}), | |
(pa.binary(), {"type": {"name":"binary"}}), | |
(pa.binary(10), {"type": {"name":"fixedsizebinary","byteWidth":10}}), | |
( | |
pa.list_(pa.int32()), | |
{ | |
"type": {"name": "list"}, | |
"children": [ | |
{"name": "item", "nullable": True, "type": {"name": "int", "isSigned": True, "bitWidth": 32}}, | |
] | |
} | |
), | |
( | |
pa.struct([pa.field('a', pa.int32()), pa.field('b', pa.int8()), pa.field('c', pa.string())]), | |
{ | |
"type": {"name": "struct"}, | |
"children": [ | |
{"name": "a", "nullable": True, "type": {"name": "int", "isSigned": True, "bitWidth": 32}}, | |
{"name": "b", "nullable": True, "type": {"name": "int", "isSigned": True, "bitWidth": 8}}, | |
{"name": "c", "nullable": True, "type": {"name": "utf8"}}, | |
] | |
} | |
), | |
( | |
pa.union([pa.field('a', pa.binary(10)), pa.field('b', pa.string())], mode=pa.lib.UnionMode_DENSE), | |
{ | |
"type": {"name": "union", "mode" : "Dense", "typeIds" : [0,1]}, | |
"children": [ | |
{"name": "a", "nullable": True, "type": {"name": "fixedsizebinary", "byteWidth": 10}}, | |
{"name": "b", "nullable": True, "type": {"name": "utf8"}}, | |
] | |
} | |
), | |
( | |
pa.union([pa.field('a', pa.binary(10)), pa.field('b', pa.string())], mode=pa.lib.UnionMode_SPARSE), | |
{ | |
"type": {"name": "union", "mode" : "Sparse", "typeIds" : [0,1]}, | |
"children": [ | |
{"name": "a", "nullable": True, "type": {"name": "fixedsizebinary", "byteWidth": 10}}, | |
{"name": "b", "nullable": True, "type": {"name": "utf8"}}, | |
] | |
} | |
), | |
# TODO: DictionaryType requires a populated Java dictionary provider | |
# ( | |
# pa.dictionary(pa.int32(), pa.utf8(), False), | |
# { | |
# "type" : {"name" : "utf8"}, | |
# "dictionary" : { | |
# "id" : 1, | |
# "isOrdered" : False, | |
# "indexType" : { | |
# "name" : "int", | |
# "bitWidth" : 32, | |
# "isSigned" : True | |
# } | |
# }, | |
# } | |
# ), | |
]) | |
@pytest.mark.parametrize('nullable', [True, False]) | |
def test_jvm_types(root_allocator, bridge, pa_type, jvm_spec, nullable):
    """Round-trip Field and Schema metadata through the Java C Data bridge."""
    if pa_type == pa.null() and not nullable:
        # A null field is always nullable; skip the contradictory combination.
        return
    # Copy the spec: pytest hands the SAME dict object to both the
    # nullable=True and nullable=False runs, so mutating jvm_spec in place
    # (the 'name', 'nullable' and 'metadata' keys added below) would leak
    # state into the next parametrized invocation.
    spec = dict(jvm_spec)
    spec['name'] = 'field_name'
    spec['nullable'] = nullable
    # TODO: DictionaryType requires a populated Java dictionary provider
    provider = None
    jvm_field = _jvm_field(json.dumps(spec))
    result = bridge.field(jvm_field, provider)
    expected_field = pa.field('field_name', pa_type, nullable=nullable)
    assert result == expected_field
    jvm_schema = _jvm_schema(json.dumps(spec))
    result = bridge.schema(jvm_schema)
    assert result == pa.schema([expected_field])
    # Schema with custom metadata
    jvm_schema = _jvm_schema(json.dumps(spec), {'meta': 'data'})
    result = bridge.schema(jvm_schema)
    assert result == pa.schema([expected_field], {'meta': 'data'})
    # Schema with custom field metadata
    spec['metadata'] = [{'key': 'field meta', 'value': 'field data'}]
    jvm_schema = _jvm_schema(json.dumps(spec))
    result = bridge.schema(jvm_schema, provider)
    expected_field = expected_field.with_metadata(
        {'field meta': 'field data'})
    assert result == pa.schema([expected_field])
# These test parameters mostly use an integer range as an input as this is | |
# often the only type that is understood by both Python and Java | |
# implementations of Arrow. | |
@pytest.mark.parametrize('pa_type,py_data,jvm_type,jvm_type_args', [ | |
(pa.bool_(), [True, False, True, True], 'BitVector', None), | |
(pa.uint8(), list(range(128)), 'UInt1Vector', None), | |
(pa.uint16(), list(range(128)), 'UInt2Vector', None), | |
(pa.int32(), list(range(128)), 'IntVector', None), | |
(pa.int64(), list(range(128)), 'BigIntVector', None), | |
(pa.float32(), list(range(128)), 'Float4Vector', None), | |
(pa.float64(), list(range(128)), 'Float8Vector', None), | |
(pa.timestamp('s'), list(range(128)), 'TimeStampSecVector', None), | |
(pa.timestamp('ms'), list(range(128)), 'TimeStampMilliVector', None), | |
(pa.timestamp('us'), list(range(128)), 'TimeStampMicroVector', None), | |
(pa.timestamp('ns'), list(range(128)), 'TimeStampNanoVector', None), | |
(pa.time32('s'), list(range(128)), 'TimeSecVector', None), | |
(pa.time32('ms'), list(range(128)), 'TimeMilliVector', None), | |
(pa.time64('us'), list(range(128)), 'TimeMicroVector', None), | |
(pa.time64('ns'), list(range(128)), 'TimeNanoVector', None), | |
(pa.date32(), list(range(128)), 'DateDayVector', None), | |
(pa.date64(), list(range(128)), 'DateMilliVector', None), | |
(pa.decimal128(19, 4), list(range(128)), 'DecimalVector', [19, 4]), | |
]) | |
def test_jvm_array(root_allocator, bridge, pa_type, py_data, jvm_type, jvm_type_args):
    """A populated JVM FieldVector must import as an equal pyarrow Array."""
    # Instantiate and fill the JVM vector.
    vector_cls = jpype.JClass("org.apache.arrow.vector.{}".format(jvm_type))
    extra_args = jvm_type_args or []
    jvm_vector = vector_cls("vector", root_allocator, *extra_args)
    jvm_vector.allocateNew(len(py_data))
    for idx, value in enumerate(py_data):
        # char and int are ambiguous overloads for these two setSafe calls
        if jvm_type in {'UInt1Vector', 'UInt2Vector'}:
            value = jpype.JInt(value)
        # values for decimal should account for scale
        if jvm_type in {'DecimalVector'}:
            _, scale = extra_args
            value = value * 10**scale
        jvm_vector.setSafe(idx, value)
    jvm_vector.setValueCount(len(py_data))
    # Import through the bridge and compare with the pure-Python array.
    py_array = pa.array(py_data, type=pa_type)
    jvm_array = bridge.array(jvm_vector)
    assert py_array.equals(jvm_array)
def test_jvm_array_empty(root_allocator, bridge):
    """An empty JVM IntVector imports as a zero-length int32 Array."""
    vector_cls = jpype.JClass("org.apache.arrow.vector.{}".format('IntVector'))
    jvm_vector = vector_cls("vector", root_allocator)
    jvm_vector.allocateNew()
    jvm_array = bridge.array(jvm_vector)
    assert len(jvm_array) == 0
    assert jvm_array.type == pa.int32()
# These test parameters mostly use an integer range as an input as this is | |
# often the only type that is understood by both Python and Java | |
# implementations of Arrow. | |
@pytest.mark.parametrize('pa_type,py_data,jvm_type,jvm_type_args,jvm_spec', [ | |
# TODO: null | |
# ( | |
# pa.null(), | |
# [], | |
# 'NullVector', | |
# None, | |
# '{"name":"null"}' | |
# ), | |
( | |
pa.bool_(), | |
[True, False, True, True], | |
'BitVector', | |
None, | |
'{"name":"bool"}' | |
), | |
( | |
pa.uint8(), | |
list(range(128)), | |
'UInt1Vector', | |
None, | |
'{"name":"int","bitWidth":8,"isSigned":false}' | |
), | |
( | |
pa.uint16(), | |
list(range(128)), | |
'UInt2Vector', | |
None, | |
'{"name":"int","bitWidth":16,"isSigned":false}' | |
), | |
( | |
pa.uint32(), | |
list(range(128)), | |
'UInt4Vector', | |
None, | |
'{"name":"int","bitWidth":32,"isSigned":false}' | |
), | |
( | |
pa.uint64(), | |
list(range(128)), | |
'UInt8Vector', | |
None, | |
'{"name":"int","bitWidth":64,"isSigned":false}' | |
), | |
( | |
pa.int8(), | |
list(range(128)), | |
'TinyIntVector', | |
None, | |
'{"name":"int","bitWidth":8,"isSigned":true}' | |
), | |
( | |
pa.int16(), | |
list(range(128)), | |
'SmallIntVector', | |
None, | |
'{"name":"int","bitWidth":16,"isSigned":true}' | |
), | |
( | |
pa.int32(), | |
list(range(128)), | |
'IntVector', | |
None, | |
'{"name":"int","bitWidth":32,"isSigned":true}' | |
), | |
( | |
pa.int64(), | |
list(range(128)), | |
'BigIntVector', | |
None, | |
'{"name":"int","bitWidth":64,"isSigned":true}' | |
), | |
# TODO: float16 | |
( | |
pa.float32(), | |
list(range(128)), | |
'Float4Vector', | |
None, | |
'{"name":"floatingpoint","precision":"SINGLE"}' | |
), | |
( | |
pa.float64(), | |
list(range(128)), | |
'Float8Vector', | |
None, | |
'{"name":"floatingpoint","precision":"DOUBLE"}' | |
), | |
( | |
pa.timestamp('s'), | |
list(range(128)), | |
'TimeStampSecVector', | |
None, | |
'{"name":"timestamp","unit":"SECOND","timezone":null}' | |
), | |
( | |
pa.timestamp('ms'), | |
list(range(128)), | |
'TimeStampMilliVector', | |
None, | |
'{"name":"timestamp","unit":"MILLISECOND","timezone":null}' | |
), | |
( | |
pa.timestamp('us'), | |
list(range(128)), | |
'TimeStampMicroVector', | |
None, | |
'{"name":"timestamp","unit":"MICROSECOND","timezone":null}' | |
), | |
( | |
pa.timestamp('ns'), | |
list(range(128)), | |
'TimeStampNanoVector', | |
None, | |
'{"name":"timestamp","unit":"NANOSECOND","timezone":null}' | |
), | |
( | |
pa.time32('s'), | |
list(range(128)), | |
'TimeSecVector', | |
None, | |
'{"name":"time","unit":"SECOND", "bitWidth":32}' | |
), | |
( | |
pa.time32('ms'), | |
list(range(128)), | |
'TimeMilliVector', | |
None, | |
'{"name":"time","unit":"MILLISECOND", "bitWidth":32}' | |
), | |
( | |
pa.time64('us'), | |
list(range(128)), | |
'TimeMicroVector', | |
None, | |
'{"name":"time","unit":"MICROSECOND", "bitWidth":64}' | |
), | |
( | |
pa.time64('ns'), | |
list(range(128)), | |
'TimeNanoVector', | |
None, | |
'{"name":"time","unit":"NANOSECOND", "bitWidth":64}' | |
), | |
( | |
pa.date32(), | |
list(range(128)), | |
'DateDayVector', | |
None, | |
'{"name":"date","unit":"DAY"}' | |
), | |
( | |
pa.date64(), | |
list(range(128)), | |
'DateMilliVector', | |
None, | |
'{"name":"date","unit":"MILLISECOND"}' | |
), | |
( | |
pa.decimal128(19, 4), | |
list(range(128)), | |
'DecimalVector', | |
(19, 4), | |
'{"name":"decimal","bitWidth":128,"precision":19,"scale":4}' | |
), | |
]) | |
def test_jvm_record_batch(root_allocator, bridge, pa_type, py_data, jvm_type,
                          jvm_type_args, jvm_spec):
    """A JVM VectorSchemaRoot must import as an equal pyarrow RecordBatch."""
    # Build the constructor argument list; NullVector takes no allocator.
    vector_cls = jpype.JClass("org.apache.arrow.vector.{}".format(jvm_type))
    extra_args = jvm_type_args or []
    ctor_args = ["vector"]
    if jvm_type not in {'NullVector'}:
        ctor_args.append(root_allocator)
    if extra_args:
        ctor_args += extra_args
    jvm_vector = vector_cls(*ctor_args)
    if jvm_type not in {'NullVector'}:
        jvm_vector.allocateNew(len(py_data))
    else:
        jvm_vector.allocateNew()
    # Fill the vector, working around JVM overload ambiguity and decimal scale.
    for idx, value in enumerate(py_data):
        if jvm_type in {'UInt1Vector', 'UInt2Vector'}:
            value = jpype.JInt(value)
        if jvm_type in {'DecimalVector'}:
            _, scale = extra_args
            value = value * 10**scale
        jvm_vector.setSafe(idx, value)
    jvm_vector.setValueCount(len(py_data))
    # Build the matching Field from its JSON type spec.
    spec = {
        'name': 'field_name',
        'nullable': False,
        'type': json.loads(jvm_spec),
        # TODO: This needs to be set for complex types
        'children': []
    }
    jvm_field = _jvm_field(json.dumps(spec))
    # Assemble a single-column VectorSchemaRoot.
    jvm_fields = jpype.JClass('java.util.ArrayList')()
    jvm_fields.add(jvm_field)
    jvm_vectors = jpype.JClass('java.util.ArrayList')()
    jvm_vectors.add(jvm_vector)
    vsr_cls = jpype.JClass('org.apache.arrow.vector.VectorSchemaRoot')
    jvm_vsr = vsr_cls(jvm_fields, jvm_vectors, len(py_data))
    py_record_batch = pa.RecordBatch.from_arrays(
        [pa.array(py_data, type=pa_type)],
        ['col']
    )
    jvm_record_batch = bridge.record_batch(jvm_vsr)
    assert py_record_batch.equals(jvm_record_batch)
def _string_to_varchar_holder(ra, string):
    """Pack a Python string (or None) into a JVM NullableVarCharHolder whose
    buffer is allocated from *ra*."""
    nvch_cls = "org.apache.arrow.vector.holders.NullableVarCharHolder"
    holder = jpype.JClass(nvch_cls)()
    if string is None:
        # Null value: only the validity flag needs to be cleared.
        holder.isSet = 0
        return holder
    holder.isSet = 1
    java_string = jpype.JClass("java.lang.String")(string)
    charsets = jpype.JClass("java.nio.charset.StandardCharsets")
    encoded = java_string.getBytes(charsets.UTF_8)
    holder.buffer = ra.buffer(len(encoded))
    holder.buffer.setBytes(0, encoded, 0, len(encoded))
    holder.start = 0
    holder.end = len(encoded)
    return holder
def test_jvm_string_array(root_allocator, bridge):
    """A JVM VarCharVector with a null and non-ASCII data must import as an
    equal pyarrow string Array."""
    data = ["string", None, "töst"]
    cls = "org.apache.arrow.vector.VarCharVector"
    jvm_vector = jpype.JClass(cls)("vector", root_allocator)
    jvm_vector.allocateNew()
    for i, string in enumerate(data):
        holder = _string_to_varchar_holder(root_allocator, string)
        jvm_vector.setSafe(i, holder)
    # Set the value count once after the loop; the original called
    # setValueCount(i + 1) on every iteration, which is redundant and ends
    # in the same final state.
    jvm_vector.setValueCount(len(data))
    py_array = pa.array(data, type=pa.string())
    jvm_array = bridge.array(jvm_vector)
    assert py_array.equals(jvm_array)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment