Skip to content

Instantly share code, notes, and snippets.

@rayheberer
Created August 17, 2018 00:25
Show Gist options
  • Save rayheberer/a74d35873b01a5e2e7097c4c591f5370 to your computer and use it in GitHub Desktop.
def preprocess_spatial_features(features, screen=True):
    """Embed categorical spatial features, log transform numeric features.

    Each feature plane is processed independently. The per-feature results
    are concatenated along the last axis:

    * CATEGORICAL features are one-hot encoded to ``scale`` classes along a
      new last axis. They are then passed through a 1x1, stride-1, "SAME"
      ``tf.layers.conv2d`` with a single filter. On a one-hot input this
      acts as a learned per-category projection down to one channel.
    * All other feature types are cast to float32 and mapped through
      ``log(x + 1)``. They are then given a trailing axis of size 1.

    Every branch yields exactly one channel, so the output has one channel
    per entry of ``feature_specs``. Its leading three dimensions match those
    of the per-feature slices ``transposed[:, :, :, index]``.

    Args:
        features: Raw spatial feature tensor. It is consumed only by the
            code elided from this snippet. NOTE(review): presumably that code
            transposes it into ``transposed``, with features on axis 3 —
            confirm.
        screen: Not referenced in the visible code. NOTE(review): presumably
            it selects between screen and minimap feature specs in the
            elided section — confirm.

    Returns:
        The ``tf.concat`` of all per-feature results on axis -1.
    """
    # NOTE(review): code elided in this snippet. It must define
    # `feature_specs` (an iterable of (feature_type, scale) pairs, per the
    # unpacking below) and `transposed` (sliced on axis 3 below).
    # ...
    # ...
    # Collects one single-channel tensor per feature, in feature order.
    preprocess_ops = []
    for index, (feature_type, scale) in enumerate(feature_specs):
        # Slice out the index-th feature plane along axis 3.
        layer = transposed[:, :, :, index]
        if feature_type == sc2_features.FeatureType.CATEGORICAL:
            # one-hot encode in channel dimension -> 1x1 convolution
            # tf.one_hot needs integer indices. NOTE(review): presumably
            # `transposed` is int-typed — confirm. Values outside
            # [0, scale) become all-zero vectors.
            one_hot = tf.one_hot(
                layer,
                depth=scale,
                axis=-1,
                name="one_hot")
            # No `name` is passed, so each call builds a new layer with its
            # own weights (one embedding per categorical feature).
            embed = tf.layers.conv2d(
                inputs=one_hot,
                filters=1,
                kernel_size=[1, 1],
                strides=[1, 1],
                padding="SAME")
            preprocess_ops.append(embed)
        else:
            # The +1 keeps zero finite (log(1) = 0). NOTE(review): this
            # assumes non-negative values; -1 gives -inf and values below -1
            # give NaN.
            transform = tf.log(
                tf.cast(layer, tf.float32) + 1.,
                name="log")
            # Add a size-1 trailing axis so the shape matches the conv
            # output for the concat below.
            preprocess_ops.append(tf.expand_dims(transform, -1))
    # Stack the single-channel results along the last (channel) axis.
    preprocessed = tf.concat(preprocess_ops, -1)
    return preprocessed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.