# Gist by @tanguycdls, last active March 29, 2020.
# Compute TFDV statistics over Parquet files: read them as Arrow tables with
# Beam, then convert each column to the list encoding TFDV expects.
import os
import tempfile
from typing import Optional, Text

import apache_beam as beam
import numpy as np
import pandas as pd
import pyarrow as pa
import tensorflow as tf
from apache_beam.io.filesystem import CompressionTypes
from apache_beam.options.pipeline_options import PipelineOptions
from tensorflow_data_validation import load_statistics
from tensorflow_data_validation.api import stats_api
from tensorflow_data_validation.statistics import stats_options as options
from tensorflow_metadata.proto.v0 import statistics_pb2
def convert_table(table):
    """Convert every column of `table` to a ListArray, skipping columns that fail."""
    arrays = []
    col_names = []
    for col in table.column_names:
        try:
            arrays.append(convert_to_list_array(table[col]))
            col_names.append(col)
        except Exception as e:
            print(f'column {col} could not be converted: {e}')
    return pa.table(arrays, names=col_names)
def generate_statistics_from_arrow(
    data_location: Text,
    output_path: Optional[Text] = None,
    stats_options: options.StatsOptions = options.StatsOptions(),
    pipeline_options: Optional[PipelineOptions] = None,
    compression_type: Text = CompressionTypes.AUTO,  # kept for API parity; unused here
) -> statistics_pb2.DatasetFeatureStatisticsList:
    """Compute TFDV statistics over the Parquet files matching `data_location`."""
    if output_path is None:
        output_path = os.path.join(tempfile.mkdtemp(), 'data_stats.tfrecord')
    output_dir_path = os.path.dirname(output_path)
    if not tf.io.gfile.exists(output_dir_path):
        tf.io.gfile.makedirs(output_dir_path)
    with beam.Pipeline(options=pipeline_options) as p:
        _ = (
            p
            | 'ReadData' >> beam.io.parquetio.ReadFromParquetBatched(
                file_pattern=data_location)
            | 'ConvertToValidArrow' >> beam.Map(convert_table)
            | 'GenerateStatistics' >> stats_api.GenerateStatistics(stats_options)
            | 'WriteStatsOutput' >> beam.io.WriteToTFRecord(
                output_path,
                shard_name_template='',
                coder=beam.coders.ProtoCoder(
                    statistics_pb2.DatasetFeatureStatisticsList)))
    return load_statistics(output_path)
def get_null_mask(arr):
    # to_pandas() is costly, but gives a per-element null mask directly
    return arr.to_pandas().isna().values
def create_offset(null_mask):
    # To get nulls into a pyarrow ListArray, you mark the *offset* as null:
    # offset[j] = None  ->  arr[j] = None (see _demo_null_offsets below).
    # If the first value is null, its -1 offset gets replaced by None anyway;
    # otherwise the first offset is 0 as required.
    offset = (~null_mask).cumsum() - 1
    offset = np.concatenate([offset, [offset[-1] + 1]]).astype(object)
    offset[np.where(null_mask)[0]] = None
    return offset
def get_values(arr, null_mask):
    # drop the null slots: the values array holds only the non-null entries
    return arr.take(pa.array(np.where(~null_mask)[0]))
def transform_null_to_list_list(arr):
    """Wrap a flat array with nulls into a ListArray whose null entries mirror them."""
    null_mask = get_null_mask(arr)
    offset = create_offset(null_mask)
    values = get_values(arr, null_mask)
    return pa.ListArray.from_arrays(offset, values)
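# A minimal sketch of the offsets trick above (illustration only;
# `_demo_null_offsets` is a hypothetical helper that nothing in the
# pipeline calls):
def _demo_null_offsets():
    chunk = pa.array([1.0, None, 2.0])
    # null_mask is [False, True, False], so create_offset yields
    # [0, None, 1, 2] and get_values yields [1.0, 2.0]; the null offset
    # becomes a null list entry in the result.
    as_lists = transform_null_to_list_list(chunk)
    assert as_lists.to_pylist() == [[1.0], None, [2.0]]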
def convert_to_list_array(array):
    """Convert a ChunkedArray into the ListArray encoding TFDV expects."""
    if isinstance(array.type, pa.ListType):
        return array  # already a list array, nothing to do
    type_ = array.type.to_pandas_dtype()
    if hasattr(type_, 'kind') and type_.kind == 'M':
        # tfdv does not support dates and pyarrow cannot cast from datetime to
        # string natively, so format through pandas instead
        return pa.ListArray.from_arrays(
            np.arange(len(array) + 1),
            pd.to_datetime(array.to_pandas()).dt.strftime('%Y-%m-%d'))
    assert array.num_chunks == 1, 'Function is not compatible with arrays of more than one chunk'
    if array.null_count == 0:
        # trivial offsets [0, 1, ..., n]: every value becomes a one-element list
        return pa.ListArray.from_arrays(np.arange(len(array) + 1), array.chunks[0])
    else:
        return transform_null_to_list_list(array.chunks[0])
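# Example usage (a sketch: the glob pattern and output path below are
# hypothetical, and with pipeline_options left unset the pipeline runs on
# Beam's default DirectRunner):
if __name__ == '__main__':
    stats = generate_statistics_from_arrow(
        data_location='data/*.parquet',
        output_path='/tmp/data_stats.tfrecord')
    print(stats.datasets[0].num_examples)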