|
#Copyright 2019 Google LLC |
|
# |
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
# you may not use this file except in compliance with the License. |
|
# You may obtain a copy of the License at |
|
# |
|
# https://www.apache.org/licenses/LICENSE-2.0 |
|
# |
|
# Unless required by applicable law or agreed to in writing, software |
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
# See the License for the specific language governing permissions and |
|
# limitations under the License. |
|
"""Dataset Pipeline for mesh_segmentation_demo.ipynb. |
|
|
|
The shorthands used in parameter descriptions below are |
|
'B': Batch size. |
|
'E': Number of unique directed edges in a mesh. |
|
'V': Number of vertices in a mesh. |
|
'T': Number of triangles in a mesh. |
|
""" |
|
from typing import Tuple |
|
import numpy as np |
|
import tensorflow as tf |
|
|
|
# Shared tuning sentinel: `preprocess` passes it as the parallelism argument of
# `Dataset.map` and as the buffer size of `Dataset.prefetch`.
AUTOTUNE = tf.data.experimental.AUTOTUNE
|
|
|
|
|
def block_diagonalize(vertices: tf.RaggedTensor,
                      neighbor_indices: tf.RaggedTensor
                      ) -> Tuple[tf.Tensor, tf.Tensor]:
    """Block diagonalize the ragged batched features.

    Flattens a ragged batch of meshes into a single disjoint graph: the vertex
    rows of all meshes are concatenated, and the neighbor indices of each mesh
    are shifted by the position of that mesh's first vertex in the flattened
    vertex tensor.

    Args:
      vertices: [B, V?, D] ragged tensor of vertex features.
      neighbor_indices: [B, E?, 2] ragged tensor of neighbor indices, local to
        each mesh.

    Returns:
      block_vertices: [BV, D] flattened tensor of vertex values.
      block_indices: [BE, 2] block-diagonalized indices, with the same dtype
        as `neighbor_indices.values`.

    Raises:
      TypeError: If `vertices` or `neighbor_indices` is not a
        `tf.RaggedTensor`.
    """
    # Explicit raises instead of `assert`, which is stripped under `python -O`.
    if not isinstance(vertices, tf.RaggedTensor):
        raise TypeError(
            "'vertices' must be a tf.RaggedTensor, got {}".format(
                type(vertices).__name__))
    if not isinstance(neighbor_indices, tf.RaggedTensor):
        raise TypeError(
            "'neighbor_indices' must be a tf.RaggedTensor, got {}".format(
                type(neighbor_indices).__name__))

    flat_indices = neighbor_indices.values
    # [BE, 1] id of the mesh that each flattened edge belongs to; the trailing
    # axis lets the gathered offsets broadcast over both index columns.
    mesh_ids = tf.expand_dims(neighbor_indices.value_rowids(), axis=-1)
    # row_splits[b] is the position of mesh b's first vertex in
    # `vertices.values`. The row-splits dtype (int64 by default) need not match
    # the index dtype, so cast before adding.
    offsets = tf.cast(tf.gather(vertices.row_splits, mesh_ids),
                      flat_indices.dtype)
    return vertices.values, flat_indices + offsets
|
|
|
|
|
def extract_unique_edges_from_triangular_mesh(faces,
                                              num_vertices,
                                              directed_edges=False):
    """Extracts all the unique edges using the faces of a mesh.

    Args:
      faces: An integer `tf.Tensor` of shape [T, 3], where T is the number of
        triangular faces in the mesh. Each entry describes the index of a
        vertex in the mesh.
      num_vertices: int scalar (Python int or scalar tensor) giving the number
        of vertices in the corresponding mesh.
      directed_edges: A boolean flag, whether to treat an edge as directed or
        undirected. If (i, j) is an edge in the mesh and directed_edges is
        True, then both (i, j) and (j, i) are returned in the list of edges.
        If (i, j) is an edge in the mesh and directed_edges is False, then
        only one of them is returned, as (min(i, j), max(i, j)).

    Returns:
      A tensor of shape [E, 2] with the same dtype as `faces`, where E is the
      number of edges in the mesh. Every row is distinct; the row order is
      unspecified.

      For example, given faces = [[0, 1, 2], [0, 1, 3]], then
      for directed_edges = False, one valid output is
        [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3]]
      for directed_edges = True, one valid output is
        [[0, 1], [0, 2], [0, 3], [1, 0], [1, 2], [1, 3],
         [2, 0], [2, 1], [3, 0], [3, 1]]

    Raises:
      ValueError: If `faces` is not a tensor or if its dtype or shape is not
        supported.
    """
    if not isinstance(faces, tf.Tensor):
        raise ValueError("'faces' must be a tf.Tensor.")
    if not faces.dtype.is_integer:
        raise ValueError("'faces' must be of integer type")
    faces.shape.assert_has_rank(2)
    if faces.shape[1] != 3:
        raise ValueError("'faces' must have shape [T, 3], got {}".format(
            faces.shape))

    # Encode each (i, j) pair as the single integer i * num_vertices + j so
    # that the 1-D `tf.unique` can de-duplicate pairs. The encoding is done in
    # int64 because num_vertices**2 overflows int32 for large meshes.
    n = tf.cast(num_vertices, tf.int64)
    src = tf.reshape(tf.cast(faces, tf.int64), (-1,))
    dst = tf.reshape(
        tf.cast(tf.roll(faces, shift=1, axis=1), tf.int64), (-1,))

    if directed_edges:
        # Add both orientations BEFORE de-duplicating: on a consistently wound
        # mesh, neighbouring faces already contribute (i, j) and (j, i), so
        # appending the reversed pairs after `tf.unique` would duplicate them.
        ravelled = tf.concat((src * n + dst, dst * n + src), axis=0)
    else:
        # Canonicalise each pair as (min, max) so that the two orientations of
        # a shared edge collapse to a single entry.
        ravelled = tf.minimum(src, dst) * n + tf.maximum(src, dst)

    unique, _ = tf.unique(ravelled)
    indices = tf.unravel_index(unique, tf.stack((n, n)))  # [2, E]
    indices = tf.transpose(indices, (1, 0))  # [E, 2]
    return tf.cast(indices, faces.dtype)
|
|
|
|
|
def get_weighted_edges(faces, num_vertices, self_edges=True):
    r"""Gets unique edges and degree weights from a triangular mesh.

    The shorthands used below are:
      `T`: The number of triangles in the mesh.
      `E`: The number of unique directed edges in the mesh.

    Args:
      faces: A [T, 3] integer `tf.Tensor` of triangle vertex indices.
      num_vertices: int scalar giving the number of vertices in the mesh.
      self_edges: A `bool` flag. If true, then for every vertex 'i' an edge
        [i, i] is added to edge list.

    Returns:
      edges: A [E, 2] integer tensor of directed edges.
      weights: A [E] `float32` tensor denoting edge weights.

    The degree of a vertex is the number of edges incident on the vertex,
    including any self-edges. The weight for an edge $w_{ij}$ connecting vertex
    $v_i$ and vertex $v_j$ is defined as,
    $$
    w_{ij} = 1.0 / degree(v_i)
    $$
    so that, for every vertex $v_i$,
    $$
    \sum_{j} w_{ij} = 1
    $$
    """
    edges = extract_unique_edges_from_triangular_mesh(
        faces, num_vertices, directed_edges=True)
    if self_edges:
        # One [i, i] row per vertex.
        vertex_ids = tf.range(num_vertices, dtype=edges.dtype)
        loops = tf.stack((vertex_ids, vertex_ids), axis=-1)
        edges = tf.concat((edges, loops), axis=0)
    # `source_slot[e]` is the position of edge e's source vertex among the
    # unique sources and `slot_counts` the number of edges leaving each of
    # them, so the gather yields the out-degree of every edge's source vertex.
    _, source_slot, slot_counts = tf.unique_with_counts(edges[:, 0])
    weights = 1. / tf.cast(tf.gather(slot_counts, source_slot), tf.float32)
    return edges, weights
|
|
|
|
|
def _parse_tfex_proto(example_proto):
    """Decodes one serialized tf.Example into a raw mesh_data dictionary.

    Args:
      example_proto: A tf.Example proto storing the encoded mesh data.

    Returns:
      A mesh data dictionary with the following fields:
        'num_vertices': The `int64` number of vertices in mesh.
        'num_triangles': The `int64` number of triangles in mesh.
        'vertices': A serialized tensor of vertex positions.
        'triangles': A serialized tensor of triangle vertex indices.
        'labels': A serialized tensor of per vertex class labels.
    """
    schema = {}
    # Scalar element counts; a missing field parses as 0.
    for count_key in ('num_vertices', 'num_triangles'):
        schema[count_key] = tf.io.FixedLenFeature(
            [], tf.int64, default_value=0)
    # Serialized tensors; a missing field parses as the empty string.
    for blob_key in ('vertices', 'triangles', 'labels'):
        schema[blob_key] = tf.io.FixedLenFeature(
            [], tf.string, default_value='')
    return tf.io.parse_single_example(serialized=example_proto,
                                      features=schema)
|
|
|
|
|
def _parse_mesh_data(mesh_data, mean_center=True):
    """Deserializes a raw mesh_data dictionary and derives its edge data.

    Args:
      mesh_data: A mesh data dictionary with serialized data tensors,
        as output from _parse_tfex_proto()
      mean_center: If true, centers the mesh vertices to mean(vertices).

    Returns:
      A mesh data dictionary with following fields:
        'vertices': A [V, 3] `float32` of vertex positions.
        'labels': A [V] `int32` tensor of per vertex class labels.
        'edges': A [E, 2] `int32` tensor of unique directed edges in mesh.
        'edge_weights': A [E] `float32` tensor of vertex degree based edge
          weights.
    """

    def _decode(key, dtype, static_shape):
        # `parse_tensor` yields an unknown static shape; pin the known part.
        tensor = tf.io.parse_tensor(mesh_data[key], dtype)
        tensor.set_shape(static_shape)
        return tensor

    labels = _decode('labels', tf.int32, (None, ))
    vertices = _decode('vertices', tf.float32, (None, 3))
    triangles = _decode('triangles', tf.int32, (None, 3))

    if mean_center:
        centroid = tf.reduce_mean(
            input_tensor=vertices, axis=0, keepdims=True)
        vertices = vertices - centroid

    num_vertices = tf.shape(vertices)[0]
    edges, edge_weights = get_weighted_edges(triangles, num_vertices)

    return {
        'vertices': vertices,
        'labels': labels,
        'edges': edges,
        'edge_weights': edge_weights,
    }
|
|
|
|
|
def get_base_dataset(tfrecords, num_parallel_reads: int = 16):
    """Creates a dataset of serialized examples from TFRecord file(s).

    Args:
      tfrecords: A single TFRecord path, or a list/tuple of paths.
      num_parallel_reads: Upper bound on the number of files read in parallel.
        It is clamped to the number of files, since extra readers would have
        nothing to read. `None` is forwarded unchanged to
        `tf.data.TFRecordDataset`.

    Returns:
      A `tf.data.TFRecordDataset` yielding serialized example protos.
    """
    if isinstance(tfrecords, tuple):
        # A tuple nested inside a list would be treated as one 2-D filename
        # entry, so flatten it into a plain list of paths.
        tfrecords = list(tfrecords)
    elif not isinstance(tfrecords, list):
        tfrecords = [tfrecords]
    if num_parallel_reads is not None:
        # `min`, not `max`: the argument is a cap, and the file count must not
        # override it. Never go below one reader, even for an empty list.
        num_parallel_reads = max(1, min(len(tfrecords), num_parallel_reads))
    return tf.data.TFRecordDataset(tfrecords,
                                   num_parallel_reads=num_parallel_reads)
|
|
|
|
|
def preprocess(dataset: tf.data.Dataset,
               mean_center: bool = True,
               batch_size: int = 8,
               shuffle_buffer: int = 100):
    """Turns a dataset of serialized mesh examples into batched model inputs.

    Args:
      dataset: Dataset of serialized tf.Example protos, as consumed by
        `_parse_tfex_proto`.
      mean_center: Forwarded to `_parse_mesh_data`.
      batch_size: Number of meshes per ragged batch.
      shuffle_buffer: Size of the shuffle buffer applied to the serialized
        examples; `None` disables shuffling.

    Returns:
      A dataset yielding `((vertices, neighbors), labels)`, where `vertices`
      and `labels` are the flattened values of the batch and `neighbors` is a
      `tf.SparseTensor` of edge weights whose dense shape is square in the
      total number of vertices in the batch.
    """

    def _decode_example(serialized):
        raw = _parse_tfex_proto(serialized)
        return _parse_mesh_data(raw, mean_center=mean_center)

    def _collate(vertices, edges, edge_weights, labels):
        flat_vertices, flat_edges = block_diagonalize(vertices, edges)
        total_vertices = tf.shape(flat_vertices)[0]
        adjacency = tf.SparseTensor(tf.cast(flat_edges, tf.int64),
                                    edge_weights.values,
                                    (total_vertices, total_vertices))
        adjacency = tf.sparse.reorder(adjacency)
        return (flat_vertices, adjacency), labels.values

    if shuffle_buffer is not None:
        dataset = dataset.shuffle(shuffle_buffer)

    ragged_batcher = tf.data.experimental.dense_to_ragged_batch(
        batch_size, row_splits_dtype=tf.int32)
    return (dataset
            .map(_decode_example, AUTOTUNE)
            .apply(ragged_batcher)
            # Each batch is a dict; its keys match `_collate`'s parameters.
            .map(lambda batch: _collate(**batch), AUTOTUNE)
            .prefetch(AUTOTUNE))
|
|
|
|
|
def get_batched_dataset(tfrecords,
                        batch_size=8,
                        shuffle_buffer=100,
                        mean_center=True,
                        num_parallel_reads=16,
                        shuffle_files=True):
    """Builds the batched mesh dataset directly from TFRecord file(s).

    Args:
      tfrecords: A single TFRecord path, or a list/tuple of paths.
      batch_size: Number of meshes per batch, forwarded to `preprocess`.
      shuffle_buffer: Example shuffle buffer size, forwarded to `preprocess`;
        `None` disables example shuffling.
      mean_center: Forwarded to `preprocess`.
      num_parallel_reads: Forwarded to `get_base_dataset`.
      shuffle_files: If true and several files are given, the file order is
        shuffled once (with Python's `random` module) when the dataset is
        built. Pass `False` to keep the given file order.

    Returns:
      The dataset produced by `preprocess`.
    """
    import random

    # `shuffle_files` used to be accepted but ignored; honour it here on a
    # copy so the caller's sequence is never mutated.
    if (shuffle_files and isinstance(tfrecords, (list, tuple))
            and len(tfrecords) > 1):
        tfrecords = random.sample(list(tfrecords), len(tfrecords))

    base = get_base_dataset(tfrecords, num_parallel_reads=num_parallel_reads)
    return preprocess(
        base,
        batch_size=batch_size,
        shuffle_buffer=shuffle_buffer,
        mean_center=mean_center,
    )