import os
import tempfile
from typing import Optional, Text

import apache_beam as beam
import numpy as np
import pandas as pd
import pyarrow as pa
import tensorflow as tf
from apache_beam.io.filesystem import CompressionTypes
from apache_beam.options.pipeline_options import PipelineOptions
from tensorflow_data_validation import load_statistics
from tensorflow_data_validation.api import stats_api
from tensorflow_data_validation.statistics import stats_options as options
from tensorflow_metadata.proto.v0 import statistics_pb2

def convert_table(table):
    """Converts every column of an Arrow Table to a ListArray; columns that
    cannot be converted are dropped with a warning."""
    arrays = []
    col_names = []
    for col in table.column_names:
        try:
            arrays.append(convert_to_list_array(table[col]))
            col_names.append(col)
        except Exception as e:
            print(f'column {col} could not be converted: {e}')
    return pa.table(arrays, names=col_names)

def generate_statistics_from_arrow(
    data_location: Text,
    output_path: Optional[Text] = None,
    stats_options: options.StatsOptions = options.StatsOptions(),
    pipeline_options: Optional[PipelineOptions] = None,
    compression_type: Text = CompressionTypes.AUTO,
) -> statistics_pb2.DatasetFeatureStatisticsList:
    """Computes TFDV statistics over Parquet files read as Arrow record batches."""
    if output_path is None:
        output_path = os.path.join(tempfile.mkdtemp(), 'data_stats.tfrecord')
    output_dir_path = os.path.dirname(output_path)
    if not tf.io.gfile.exists(output_dir_path):
        tf.io.gfile.makedirs(output_dir_path)
    with beam.Pipeline(options=pipeline_options) as p:
        _ = (
            p
            | 'ReadData' >> beam.io.parquetio.ReadFromParquetBatched(
                file_pattern=data_location)
            | 'ConvertToValidArrow' >> beam.Map(convert_table)
            | 'GenerateStatistics' >> stats_api.GenerateStatistics(stats_options)
            | 'WriteStatsOutput' >> beam.io.WriteToTFRecord(
                output_path,
                shard_name_template='',
                coder=beam.coders.ProtoCoder(
                    statistics_pb2.DatasetFeatureStatisticsList)))
    return load_statistics(output_path)

def get_null_mask(arr):
    # Note: round-tripping through pandas just to find nulls is costly.
    return arr.to_pandas().isna().values

def create_offset(null_mask):
    # To get nulls in a pyarrow ListArray you need offset[j] = None,
    # which pyarrow turns into a null element at position j.
    # The cumulative count of non-null values gives each element's start
    # offset; if the first value is null, its -1 is replaced by None below.
    offset = (~null_mask).cumsum() - 1
    # Append the end offset (the total number of non-null values).
    offset = np.concatenate([offset, [offset[-1] + 1]]).astype(object)
    offset[np.where(null_mask)[0]] = None
    return offset

def get_values(arr, null_mask):
    # Keep only the non-null values; these become the flat values buffer.
    return arr.take(pa.array(np.where(~null_mask)[0]))

def transform_null_to_list_list(arr):
    """Wraps a flat array that contains nulls into a ListArray of
    single-element lists, preserving the nulls via None offsets."""
    null_mask = get_null_mask(arr)
    offset = create_offset(null_mask)
    values = get_values(arr, null_mask)
    return pa.ListArray.from_arrays(offset, values)
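
# Illustrative trace of the helpers above (the values are an example, not from
# the gist): for a chunk pa.array([1.0, None, 3.0]), get_null_mask returns
# [False, True, False], create_offset returns [0, None, 1, 2] and get_values
# returns [1.0, 3.0], so transform_null_to_list_list should yield
# [[1], None, [3]].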

def convert_to_list_array(array):
    """Converts a ChunkedArray column into a ListArray that TFDV accepts."""
    if isinstance(array.type, pa.lib.ListType):
        return array
    type_ = array.type.to_pandas_dtype()
    if hasattr(type_, 'kind') and type_.kind == 'M':
        # TFDV does not support dates and pyarrow cannot cast from datetime
        # to string natively, so format the dates through pandas instead.
        return pa.ListArray.from_arrays(
            np.arange(len(array) + 1),
            pa.array(pd.to_datetime(array.to_pandas()).dt.strftime('%Y-%m-%d')))
    assert array.num_chunks == 1, 'Function is not compatible with arrays of more than one chunk'
    if array.null_count == 0:
        # Without nulls, every value simply becomes a one-element list.
        return pa.ListArray.from_arrays(np.arange(len(array) + 1), array.chunks[0])
    return transform_null_to_list_list(array.chunks[0])
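
A minimal usage sketch of generate_statistics_from_arrow; the Parquet glob, output path, and DirectRunner flag below are illustrative placeholders, not part of the original gist:

if __name__ == '__main__':
    # Hypothetical paths; point these at your own Parquet shards.
    stats = generate_statistics_from_arrow(
        data_location='/tmp/my_dataset/*.parquet',
        output_path='/tmp/my_dataset_stats/data_stats.tfrecord',
        pipeline_options=PipelineOptions(['--runner=DirectRunner']),
    )
    # The result is a DatasetFeatureStatisticsList proto.
    print(stats.datasets[0].num_examples)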