@rafbarr
Created May 28, 2020 21:07
Custom TFT scale_by_min_max_per_key
from functools import reduce

import apache_beam as beam
import numpy as np
import tensorflow as tf
import tensorflow_transform as tft


def transform_sparse_values(sp_tensor, trans_fun):
    """Applies trans_fun to the values of a SparseTensor, preserving its structure."""
    return tf.sparse.SparseTensor(
        indices=sp_tensor.indices,
        values=trans_fun(sp_tensor.values),
        dense_shape=sp_tensor.dense_shape
    )


class CollectAll(beam.PTransform):
    """Collects an entire PCollection into a single list-valued element."""

    def expand(self, pcoll):
        pcoll = pcoll | "AddDummyKey" >> beam.Map(lambda v: (None, v))
        pcoll = pcoll | "GroupByDummyKey" >> beam.GroupByKey()
        pcoll = pcoll | "RemoveDummyKey" >> beam.Map(lambda v: v[1])
        return pcoll


class BasicStatsCombiner(beam.CombineFn):
    """Computes min, max, mean, std, and count of a PCollection in one pass."""

    def create_accumulator(self):
        # (min, max, sum, sum of squares, count)
        return (np.inf, -np.inf, 0.0, 0.0, 0)

    def add_input(self, acc, v):
        (curr_min, curr_max, curr_sum, curr_sum_sqr, curr_count) = acc
        new_min = min(curr_min, v)
        new_max = max(curr_max, v)
        new_sum = curr_sum + v
        new_sum_sqr = curr_sum_sqr + v ** 2
        new_count = curr_count + 1
        return (new_min, new_max, new_sum, new_sum_sqr, new_count)

    def merge_accumulators(self, accs):
        min_accs, max_accs, sum_accs, sum_sqr_accs, count_accs = zip(*accs)
        return (
            reduce(min, min_accs),
            reduce(max, max_accs),
            sum(sum_accs),
            sum(sum_sqr_accs),
            sum(count_accs)
        )

    def extract_output(self, acc):
        (final_min, final_max, final_sum, final_sum_sqr, final_count) = acc
        if final_count:
            mean = final_sum / final_count
            # std via E[x^2] - mean^2, clamped to eps so downstream divisions
            # by std never hit zero.
            std = np.maximum(
                np.sqrt(final_sum_sqr / final_count - mean ** 2),
                np.finfo(np.float64).eps
            )
            return {
                'min': final_min,
                'max': final_max,
                'mean': mean,
                'std': std,
                'count': final_count
            }
        else:
            return {
                'min': np.nan,
                'max': np.nan,
                'mean': np.nan,
                'std': np.nan,
                'count': 0
            }


class ComputeBasicStatsPerKey(beam.PTransform):
    """Computes basic stats per key and emits them as parallel numpy arrays."""

    _DTYPE_BY_STAT_NAME = {
        'min': np.float32,
        'max': np.float32,
        'mean': np.float32,
        'std': np.float32,
        'count': np.int64
    }

    def expand(self, keys_and_values):
        # Each element is a (keys batch, values batch) pair of aligned arrays;
        # zip them into individual (key, value) pairs.
        flattened_keys_and_values = (
            keys_and_values |
            "ZipKeysAndValues" >> beam.FlatMap(lambda v: list(zip(v[0], v[1])))
        )
        stats_per_key = (
            flattened_keys_and_values |
            "ComputeStatsPerKey" >> beam.CombinePerKey(BasicStatsCombiner()) |
            "CollectStatsPerKey" >> CollectAll()
        )
        rets = []
        rets.append(
            stats_per_key |
            "ExtractKeys" >> beam.Map(lambda s: np.array([v[0] for v in s]))
        )

        def extract_stat(stats_per_key, stat_name, stat_dtype):
            return (
                stats_per_key |
                "Extract{}Stat".format(stat_name.capitalize()) >> beam.Map(
                    lambda s: np.array([v[1][stat_name] for v in s], dtype=stat_dtype)
                )
            )

        for stat_name, stat_dtype in self._DTYPE_BY_STAT_NAME.items():
            rets.append(extract_stat(stats_per_key, stat_name, stat_dtype))
        return tuple(rets)


def get_basic_stats_per_key(dense_keys, dense_values):
    """Runs ComputeBasicStatsPerKey as a TFT analyzer and names its outputs."""
    stats = tft.ptransform_analyzer(
        [dense_keys, dense_values],
        # keys, min, max, mean, std, count
        [dense_keys.dtype, tf.float32, tf.float32, tf.float32, tf.float32, tf.int64],
        [[None], [None], [None], [None], [None], [None]],
        ptransform=ComputeBasicStatsPerKey(),
        name='ComputeBasicStatsPerKey'
    )
    return {
        'keys': stats[0],
        'min': stats[1],
        'max': stats[2],
        'mean': stats[3],
        'std': stats[4],
        'count': stats[5]
    }


def scale_by_min_max_per_key(
    sp_keys,
    sp_values,
    output_min=0.0,
    output_max=1.0,
    stats_per_key=None,
    name=None
):
    """Scales each value of sp_values to [output_min, output_max] using the
    min/max of its key. sp_keys and sp_values must be SparseTensors with
    aligned indices: the key at each position selects the stats applied to
    the value at the same position.
    """
    with tf.compat.v1.name_scope(name, 'scale_by_min_max_per_key'):
        stats_per_key = stats_per_key or get_basic_stats_per_key(
            sp_keys.values, sp_values.values
        )
        min_lookup = tf.lookup.StaticHashTable(
            tf.lookup.KeyValueTensorInitializer(
                stats_per_key['keys'],
                stats_per_key['min']
            ),
            default_value=np.nan,
            name='min_lookup'
        )
        max_lookup = tf.lookup.StaticHashTable(
            tf.lookup.KeyValueTensorInitializer(
                stats_per_key['keys'],
                stats_per_key['max']
            ),
            default_value=np.nan,
            name='max_lookup'
        )

        def scale(v):
            key_min = min_lookup.lookup(sp_keys.values)
            key_max = max_lookup.lookup(sp_keys.values)
            # Standard min-max scaling; keys with min == max collapse to the
            # midpoint of the output range.
            return tf.where(
                key_min < key_max,
                (v - key_min) / (key_max - key_min) * (output_max - output_min) + output_min,
                tf.fill(tf.shape(v), (output_min + output_max) / 2.0)
            )

        return transform_sparse_values(sp_values, scale)
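

# A minimal usage sketch inside a TFT preprocessing_fn. The feature names
# 'category' and 'value' are hypothetical; both are assumed to be sparse
# features with aligned indices, as scale_by_min_max_per_key requires.
def preprocessing_fn(inputs):
    return {
        'value_scaled': scale_by_min_max_per_key(
            inputs['category'],
            inputs['value'],
            output_min=0.0,
            output_max=1.0
        )
    }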