TensorFlow Graphics Keras port for PR #155

Get my forked tensorflow graphics repo and switch to the appropriate branch:

git clone https://github.com/jackd/graphics.git
cd graphics
git checkout sparse-feastnet
pip install -e .
cd ..

Get this gist:

git clone https://gist.github.com/jackd/f62cfc5e548876fc05af5610587f0704.git
cd f62cfc5e548876fc05af5610587f0704

Train the variants (gather_sum by default, the proposed sparse_matmul with --sparse):

python fit.py
python fit.py --sparse

View the training summaries in TensorBoard:

tensorboard --logdir=/tmp/graphics
dataio.py

# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataset pipeline for mesh_segmentation_demo.ipynb.

The shorthands used in parameter descriptions below are:
  'B': Batch size.
  'E': Number of unique directed edges in a mesh.
  'V': Number of vertices in a mesh.
  'T': Number of triangles in a mesh.
"""
from typing import Tuple

import tensorflow as tf

AUTOTUNE = tf.data.experimental.AUTOTUNE

def block_diagonalize(vertices: tf.RaggedTensor,
                      neighbor_indices: tf.RaggedTensor
                      ) -> Tuple[tf.Tensor, tf.Tensor]:
  """Block-diagonalizes ragged batched features.

  Args:
    vertices: [B, V?, D] ragged tensor of vertex features.
    neighbor_indices: [B, E?, 2] ragged tensor of neighbor indices.

  Returns:
    block_vertices: [BV, D] flattened tensor of vertex values.
    block_indices: [BE, 2] block-diagonalized indices.
  """
  assert isinstance(vertices, tf.RaggedTensor)
  assert isinstance(neighbor_indices, tf.RaggedTensor)
  # Offset each graph's neighbor indices by the row at which that graph's
  # vertices start in the flattened vertex tensor.
  b = tf.expand_dims(neighbor_indices.value_rowids(), axis=-1)
  offset = vertices.row_splits
  vertices = vertices.values
  neighbor_indices = neighbor_indices.values + tf.gather(offset, b)
  return vertices, neighbor_indices
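

# A minimal usage sketch (illustrative, not part of the original gist): two
# graphs with 2 and 1 vertices become one disjoint graph with 3 vertices, and
# the second graph's edge indices are offset by the first graph's vertex
# count.
def _block_diagonalize_example():
  vertices = tf.ragged.constant(
      [[[0.0, 0.0], [1.0, 0.0]], [[2.0, 2.0]]], ragged_rank=1)  # [B=2, V?, 2]
  neighbor_indices = tf.ragged.constant(
      [[[0, 1], [1, 0]], [[0, 0]]], ragged_rank=1, dtype=tf.int64)
  block_vertices, block_indices = block_diagonalize(vertices, neighbor_indices)
  # block_vertices has shape [3, 2]; block_indices is
  # [[0, 1], [1, 0], [2, 2]] -- the final self-edge is offset by 2.
  return block_vertices, block_indices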

def extract_unique_edges_from_triangular_mesh(faces,
                                              num_vertices,
                                              directed_edges=False):
  """Extracts all the unique edges using the faces of a mesh.

  Args:
    faces: A tensor of shape [T, 3], where T is the number of triangular faces
      in the mesh. Each entry in this array describes the index of a vertex in
      the mesh.
    num_vertices: int scalar giving the number of vertices in the
      corresponding mesh.
    directed_edges: A boolean flag, whether to treat an edge as directed or
      undirected. If (i, j) is an edge in the mesh and directed_edges is True,
      then both (i, j) and (j, i) are returned in the list of edges. If (i, j)
      is an edge in the mesh and directed_edges is False, then one of (i, j)
      or (j, i) is returned.

  Returns:
    A tensor of shape [E, 2], where E is the number of edges in the mesh.
    For example, given faces = [[0, 1, 2], [0, 1, 3]]:
      for directed_edges = False, one valid output is
        [[0, 1], [0, 2], [0, 3], [1, 2], [3, 1]]
      for directed_edges = True, one valid output is
        [[0, 1], [0, 2], [0, 3], [1, 0], [1, 2], [1, 3],
         [2, 0], [2, 1], [3, 0], [3, 1]]

  Raises:
    ValueError: If `faces` is not a tensor or if its shape is not supported.
  """
  if not isinstance(faces, tf.Tensor):
    raise ValueError("'faces' must be a tf.Tensor.")
  if not faces.dtype.is_integer:
    raise ValueError("'faces' must be of integer type.")
  faces.shape.assert_has_rank(2)
  if faces.shape[1] != 3:
    raise ValueError("'faces' must have shape [T, 3], got {}".format(
        faces.shape))
  rolled_faces = tf.roll(faces, shift=1, axis=1)
  # We could build index pairs by stacking faces with rolled faces, but
  # tf.unique requires a 1D tensor, so we ravel each (i, j) pair into a single
  # integer instead; this removes the need to stack in the first place.
  i = tf.reshape(faces, (-1,))
  j = tf.reshape(rolled_faces, (-1,))
  ravelled = i * num_vertices + j
  unique, _ = tf.unique(ravelled)
  indices = tf.unravel_index(unique, (num_vertices, num_vertices))
  indices = tf.transpose(indices, (1, 0))
  if directed_edges:
    indices = tf.concat((indices, tf.reverse(indices, axis=[1])), axis=0)
  return indices
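

# Illustrative check (a sketch, not part of the original gist): the
# two-triangle mesh from the docstring above.
def _extract_edges_example():
  faces = tf.constant([[0, 1, 2], [0, 1, 3]], dtype=tf.int32)
  directed = extract_unique_edges_from_triangular_mesh(
      faces, num_vertices=4, directed_edges=True)
  # `directed` contains both orientations of every mesh edge, e.g. both
  # [0, 1] and [1, 0]; the exact row order depends on `tf.unique`.
  return directed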

def get_weighted_edges(faces, num_vertices, self_edges=True):
  """Gets unique edges and degree weights from a triangular mesh.

  The shorthands used below are:
    `T`: The number of triangles in the mesh.
    `E`: The number of unique directed edges in the mesh.

  Args:
    faces: A [T, 3] `int32` tensor of triangle vertex indices.
    num_vertices: int scalar giving the number of vertices in the mesh.
    self_edges: A `bool` flag. If true, then for every vertex `i` an edge
      [i, i] is added to the edge list.

  Returns:
    edges: A [E, 2] `int32` tensor of directed edges.
    weights: A [E] `float32` tensor of edge weights.

  The degree of a vertex is the number of edges incident on the vertex,
  including any self-edges. The weight for an edge connecting vertex $v_i$ to
  vertex $v_j$ is defined as

  $$
  w_{ij} = 1 / \mathrm{degree}(v_i),
  $$

  so that $\sum_{j} w_{ij} = 1$ for every vertex.
  """
  edges = extract_unique_edges_from_triangular_mesh(faces,
                                                    num_vertices,
                                                    directed_edges=True)
  if self_edges:
    identity = tf.tile(
        tf.expand_dims(tf.range(num_vertices, dtype=edges.dtype), axis=-1),
        (1, 2))
    edges = tf.concat((edges, identity), axis=0)
  # All edges leaving a vertex share the weight 1 / out-degree.
  _, index, counts = tf.unique_with_counts(edges[:, 0])
  weights = 1. / tf.cast(tf.gather(counts, index), tf.float32)
  return edges, weights
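

# Degree-weight sketch (illustrative, not part of the original gist): for a
# single triangle with self-edges, every vertex has out-degree 3 (its two
# neighbors plus itself), so all nine edges get weight 1/3 and each vertex's
# outgoing weights sum to 1.
def _weighted_edges_example():
  faces = tf.constant([[0, 1, 2]], dtype=tf.int32)
  edges, weights = get_weighted_edges(faces, num_vertices=3)
  return edges, weights  # edges: [9, 2], weights: nine values of 1/3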

def _parse_tfex_proto(example_proto):
  """Parses a tf.Example proto into a raw mesh_data dictionary.

  Args:
    example_proto: A tf.Example proto storing the encoded mesh data.

  Returns:
    A mesh data dictionary with the following fields:
      'num_vertices': The `int64` number of vertices in the mesh.
      'num_triangles': The `int64` number of triangles in the mesh.
      'vertices': A serialized tensor of vertex positions.
      'triangles': A serialized tensor of triangle vertex indices.
      'labels': A serialized tensor of per-vertex class labels.
  """
  feature_description = {
      'num_vertices': tf.io.FixedLenFeature([], tf.int64, default_value=0),
      'num_triangles': tf.io.FixedLenFeature([], tf.int64, default_value=0),
      'vertices': tf.io.FixedLenFeature([], tf.string, default_value=''),
      'triangles': tf.io.FixedLenFeature([], tf.string, default_value=''),
      'labels': tf.io.FixedLenFeature([], tf.string, default_value=''),
  }
  return tf.io.parse_single_example(serialized=example_proto,
                                    features=feature_description)

def _parse_mesh_data(mesh_data, mean_center=True):
  """Parses a raw mesh_data dictionary read from tf.Examples.

  Args:
    mesh_data: A mesh data dictionary with serialized data tensors, as output
      from _parse_tfex_proto().
    mean_center: If true, centers the mesh vertices at mean(vertices).

  Returns:
    A mesh data dictionary with the following fields:
      'vertices': A [V, 3] `float32` tensor of vertex positions.
      'labels': A [V] `int32` tensor of per-vertex class labels.
      'edges': A [E, 2] `int32` tensor of unique directed edges in the mesh.
      'edge_weights': A [E] `float32` tensor of vertex-degree-based edge
        weights.
  """
  labels = tf.io.parse_tensor(mesh_data['labels'], tf.int32)
  vertices = tf.io.parse_tensor(mesh_data['vertices'], tf.float32)
  triangles = tf.io.parse_tensor(mesh_data['triangles'], tf.int32)
  labels.set_shape((None,))
  vertices.set_shape((None, 3))
  triangles.set_shape((None, 3))
  if mean_center:
    vertices = vertices - tf.reduce_mean(
        input_tensor=vertices, axis=0, keepdims=True)
  edges, weights = get_weighted_edges(triangles, tf.shape(vertices)[0])
  mesh_data = dict(vertices=vertices,
                   labels=labels,
                   edges=edges,
                   edge_weights=weights)
  return mesh_data

def get_base_dataset(tfrecords, num_parallel_reads: int = 16):
  """Creates a TFRecordDataset reading the given file(s) in parallel."""
  if not isinstance(tfrecords, list):
    tfrecords = [tfrecords]
  # Ensure at least one parallel read per file.
  num_parallel_reads = max(len(tfrecords), num_parallel_reads)
  return tf.data.TFRecordDataset(tfrecords,
                                 num_parallel_reads=num_parallel_reads)


def preprocess(dataset: tf.data.Dataset,
               mean_center: bool = True,
               batch_size: int = 8,
               shuffle_buffer: int = 100):
  """Parses, shuffles, ragged-batches and block-diagonalizes mesh data."""

  def pre_batch_map(example_proto):
    mesh_data = _parse_tfex_proto(example_proto)
    mesh_data = _parse_mesh_data(mesh_data, mean_center=mean_center)
    return mesh_data

  def post_batch_map(vertices, edges, edge_weights, labels):
    # Combine the ragged batch of graphs into a single block-diagonal graph.
    vertices, edges = block_diagonalize(vertices, edges)
    v = tf.shape(vertices)[0]
    # SparseTensor indices must be int64.
    neighbors = tf.SparseTensor(tf.cast(edges, tf.int64),
                                edge_weights.values, (v, v))
    neighbors = tf.sparse.reorder(neighbors)
    labels = labels.values
    return (vertices, neighbors), labels

  if shuffle_buffer is not None:
    dataset = dataset.shuffle(shuffle_buffer)
  dataset = dataset.map(pre_batch_map, AUTOTUNE)
  dataset = dataset.apply(
      tf.data.experimental.dense_to_ragged_batch(batch_size,
                                                 row_splits_dtype=tf.int32))
  dataset = dataset.map(lambda kwargs: post_batch_map(**kwargs), AUTOTUNE)
  dataset = dataset.prefetch(AUTOTUNE)
  return dataset
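

# Batching sketch (illustrative, not part of the original gist):
# `dense_to_ragged_batch` turns variable-size per-mesh tensors into ragged
# batches, which `post_batch_map` then flattens into one block-diagonal graph.
#
#   ds = tf.data.Dataset.from_tensor_slices(tf.ragged.constant([[1], [2, 3]]))
#   ds = ds.apply(tf.data.experimental.dense_to_ragged_batch(batch_size=2))
#   next(iter(ds))  # -> <tf.RaggedTensor [[1], [2, 3]]>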

def get_batched_dataset(tfrecords,
                        batch_size=8,
                        shuffle_buffer=100,
                        mean_center=True,
                        num_parallel_reads=16,
                        shuffle_files=True):
  return preprocess(
      get_base_dataset(tfrecords, num_parallel_reads=num_parallel_reads),
      batch_size=batch_size,
      shuffle_buffer=shuffle_buffer,
      mean_center=mean_center,
  )
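

# End-to-end usage sketch (illustrative; the tfrecord path is a placeholder):
#
#   dataset = get_batched_dataset('/path/to/meshes.tfrecords', batch_size=8)
#   (vertices, neighbors), labels = next(iter(dataset))
#   # vertices: [BV, 3] float32, neighbors: [BV, BV] float32 SparseTensor,
#   # labels: [BV] int32 -- one block-diagonalized graph per batch.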

fit.py

import glob
import os

import tensorflow as tf
from absl import app, flags

from dataio import get_batched_dataset
from model import mesh_model, compile_classifier

flags.DEFINE_boolean('sparse',
                     default=False,
                     help='use proposed sparse implementation')
flags.DEFINE_string('tb_dir',
                    default='/tmp/graphics/',
                    help='root directory to store tensorboard data')
flags.DEFINE_integer('epochs', default=10, help='number of epochs to train')

NUM_CLASSES = 16


def get_datasets():
  url = ('https://storage.googleapis.com/tensorflow-graphics/notebooks/'
         'mesh_segmentation/{}')
  path_to_data_zip = tf.keras.utils.get_file('data.zip',
                                             origin=url.format('data.zip'),
                                             extract=True)
  test_data_files = [
      os.path.join(os.path.dirname(path_to_data_zip),
                   'data/Dancer_test_sequence.tfrecords')
  ]
  test_dataset = get_batched_dataset(test_data_files)
  path_to_train_data_zip = tf.keras.utils.get_file(
      'train_data.zip', origin=url.format('train_data.zip'), extract=True)
  train_data_files = glob.glob(
      os.path.join(os.path.dirname(path_to_train_data_zip),
                   '*train*.tfrecords'))
  train_dataset = get_batched_dataset(train_data_files)
  return train_dataset, test_dataset


def main(_):
  FLAGS = flags.FLAGS
  sparse = FLAGS.sparse
  sparse_impl = 'sparse_matmul' if sparse else 'gather_sum'
  train_dataset, test_dataset = get_datasets()
  # Infer the vertex feature dimension (D, 3 for xyz positions) from the
  # dataset's element spec.
  initial_vertex_feature_dim = train_dataset.element_spec[0][0].shape[-1]
  model = mesh_model(num_classes=NUM_CLASSES,
                     initial_vertex_feature_dim=initial_vertex_feature_dim,
                     sparse_impl=sparse_impl)
  compile_classifier(model)
  tb_dir = os.path.join(FLAGS.tb_dir, sparse_impl)
  model.fit(train_dataset,
            validation_data=test_dataset,
            epochs=FLAGS.epochs,
            callbacks=[tf.keras.callbacks.TensorBoard(tb_dir)])


if __name__ == '__main__':
  app.run(main)

model.py

from typing import Sequence

import tensorflow as tf
from tensorflow_graphics.geometry.convolution.graph_convolution import SparseImplementation
from tensorflow_graphics.nn.layer import graph_convolution as gc


def mesh_model(num_classes: int = 16,
               initial_vertex_feature_dim: int = 3,
               num_weight_matrices: int = 8,
               encoder_filter_dims: Sequence[int] = (32, 64, 128),
               sparse_impl=SparseImplementation.GATHER_SUM) -> tf.keras.Model:
  """Gets a mesh encoder keras model.

  This model operates on one graph at a time. To achieve the effect of
  operating on multiple graphs, concatenate vertex features and block
  diagonalize weight matrices. See `dataio.block_diagonalize`.

  The shorthands used below are:
    `V`: The maximum number of vertices over all meshes in the batch.
    `D`: The number of dimensions of input vertex features; D = 3 if vertex
      positions are used as features.

  Args:
    num_classes: number of classes to infer logit values for.
    initial_vertex_feature_dim: number of input features per vertex, D.
    num_weight_matrices: The number of weight matrices to be used in the
      feature-steered graph convolutions.
    encoder_filter_dims: A list of output dimensions, one per graph
      convolution layer.
    sparse_impl: The sparse implementation used in the graph convolution
      layers.

  Returns:
    A keras model that maps (vertices, neighbors) -> logits:
      'vertices': A [V, D] `float32` tensor of vertex features.
      'neighbors': A [V, V] `float32` sparse tensor of edge weights.
      'logits': A [V, num_classes] `float32` tensor of per-vertex logits.
  """
  vertices = tf.keras.Input(shape=(initial_vertex_feature_dim,),
                            dtype=tf.float32)
  neighbors = tf.keras.Input(shape=(None,), dtype=tf.float32, sparse=True)
  vertex_features = tf.keras.layers.Dense(16, name='lin16')(vertices)
  for i, dim in enumerate(encoder_filter_dims):
    vertex_features = gc.FeatureSteeredConvolutionKerasLayer(
        num_weight_matrices=num_weight_matrices,
        num_output_channels=dim,
        sparse_impl=sparse_impl,
        name=f'conv_{i}-{dim}')((vertex_features, neighbors))
    vertex_features = tf.nn.relu(vertex_features)
  vertex_features = tf.keras.layers.Dense(256,
                                          activation='relu',
                                          name='lin256')(vertex_features)
  logits = tf.keras.layers.Dense(num_classes, name='logits')(vertex_features)
  return tf.keras.Model((vertices, neighbors), logits)


def compile_classifier(model,
                       init_learning_rate=1e-3,
                       lr_decay_steps=10000,
                       lr_decay_rate=0.95,
                       beta=0.9,
                       adam_epsilon=1e-8):
  """Compiles the model with Adam, exponential lr decay and sparse CE loss."""
  learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(
      init_learning_rate,
      decay_steps=lr_decay_steps,
      decay_rate=lr_decay_rate)
  optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate,
                                       beta_1=beta,
                                       epsilon=adam_epsilon)
  model.compile(
      optimizer=optimizer,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
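

# Usage sketch (illustrative, not part of the original gist): build and
# compile the model, then run a forward pass on a random single-graph input.
# The vertex count and the identity adjacency are arbitrary choices here.
def _model_example():
  model = mesh_model(num_classes=16, initial_vertex_feature_dim=3)
  compile_classifier(model)
  num_vertices = 10
  vertices = tf.random.normal((num_vertices, 3))
  neighbors = tf.sparse.eye(num_vertices)  # each vertex neighbors only itself
  logits = model((vertices, neighbors))  # -> [num_vertices, 16] logits
  return logits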
ddofer commented Dec 1, 2020:

Any updates or plans on porting this to the TF/Keras example notebooks? :)
