Skip to content

Instantly share code, notes, and snippets.

@zmjjmz
Last active June 24, 2020 17:51
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zmjjmz/dcfcb889c88115347de691b69067bff1 to your computer and use it in GitHub Desktop.
Save zmjjmz/dcfcb889c88115347de691b69067bff1 to your computer and use it in GitHub Desktop.
tabnet no variation
import os
import shutil
import tensorflow as tf
import tensorflow_datasets as tfds
import tabnet
import pandas as pd
import numpy as np
train_size = 125
BATCH_SIZE = 50
def analyze_decision_steps(tabnet, feature_slices):
# let's try to understand the aggregate feature selection
masks = tabnet.feature_selection_masks
num_steps = len(masks)
# mask is 1 x batch_size x feature_size x 1
attention_df = []
for decision_step in range(num_steps):
step_mask = masks[decision_step].numpy()
print(step_mask.shape)
for feature, index_slice in feature_slices.items():
# we'll do the sum here for the embeddings because we want to see how much time was spent
# attending to a specific feature column, even if that feature column is an embedding
total_attention = np.sum(step_mask[:,:,index_slice,:], axis=2)
attention_df.append({
'step':str(decision_step),
'feature':feature,
'mean_attention':np.squeeze(np.mean(total_attention, axis=1)),
'std_attention':np.squeeze(np.std(total_attention, axis=1))
})
attention_df = pd.DataFrame.from_records(attention_df).sort_values(by=['feature','step'])
return attention_df
def transform(ds):
features = tf.unstack(ds['features'])
labels = ds['label']
x = dict(zip(col_names, features))
y = tf.one_hot(labels, 3)
return x, y
col_names = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
ds_full = tfds.load(name="iris", split=tfds.Split.TRAIN)
ds_full = ds_full.shuffle(150, seed=0)
ds_train = ds_full.take(train_size)
ds_train = ds_train.map(transform)
ds_train = ds_train.batch(BATCH_SIZE)
ds_test = ds_full.skip(train_size)
ds_test = ds_test.map(transform)
ds_test = ds_test.batch(BATCH_SIZE)
feature_columns = []
for col_name in col_names:
feature_columns.append(tf.feature_column.numeric_column(col_name))
# Group Norm does better for small datasets
model = tabnet.TabNetClassifier(feature_columns, num_classes=3,
feature_dim=4, output_dim=4,
num_decision_steps=4, relaxation_factor=1.0,
sparsity_coefficient=1e-5, batch_momentum=0.98,
virtual_batch_size=None, norm_type='group',
num_groups=1)
lr = tf.keras.optimizers.schedules.ExponentialDecay(0.01, decay_steps=100, decay_rate=0.9, staircase=False)
optimizer = tf.keras.optimizers.Adam(lr)
model.compile(optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(ds_train, epochs=100, validation_data=ds_test, verbose=2)
model.summary()
print()
if os.path.exists('logs/'):
shutil.rmtree('logs/')
""" Save the images of the feature masks """
# Force eager execution mode to generate the masks
x, y = next(iter(ds_train))
_ = model(x)
iris_attention_df = analyze_decision_steps(model.tabnet,
{col_name:slice(i, i+1) for i, col_name in enumerate(col_names)})
print(iris_attention_df.sort_values(by=['feature','step']))
writer = tf.summary.create_file_writer("logs/")
with writer.as_default():
for i, mask in enumerate(model.tabnet.feature_selection_masks):
print("Saving mask {} of shape {}".format(i + 1, mask.shape))
tf.summary.image('mask_at_iter_{}'.format(i + 1), step=0, data=mask, max_outputs=1)
writer.flush()
agg_mask = model.tabnet.aggregate_feature_selection_mask
print("Saving aggregate mask of shape", agg_mask.shape)
tf.summary.image("Aggregate Mask", step=0, data=agg_mask, max_outputs=1)
writer.flush()
writer.close()
2020-06-24 13:48:18.277333: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libnvinfer.so.6'; dlerror: libnvinfer.so.6: cannot open shared object file: No such file or directory
2020-06-24 13:48:18.277428: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libnvinfer_plugin.so.6'; dlerror: libnvinfer_plugin.so.6: cannot open shared object file: No such file or directory
2020-06-24 13:48:18.277438: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:30] Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
2020-06-24 13:48:21.230246: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2020-06-24 13:48:21.230278: E tensorflow/stream_executor/cuda/cuda_driver.cc:351] failed call to cuInit: UNKNOWN ERROR (303)
2020-06-24 13:48:21.230300: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (wa1okdba002): /proc/driver/nvidia/version does not exist
2020-06-24 13:48:21.230554: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2020-06-24 13:48:21.241574: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 2297445000 Hz
2020-06-24 13:48:21.244607: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x55ed8ea55730 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2020-06-24 13:48:21.244631: I tensorflow/compiler/xla/service/service.cc:176] StreamExecutor device (0): Host, Default Version
2020-06-24 13:48:29.981937: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
2020-06-24 13:48:31.095752: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 1/100
3/3 - 7s - loss: 1.2241 - accuracy: 0.3280 - val_loss: 0.9655 - val_accuracy: 0.3600
2020-06-24 13:48:31.218065: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 2/100
3/3 - 0s - loss: 1.0026 - accuracy: 0.3200 - val_loss: 0.8839 - val_accuracy: 0.4400
2020-06-24 13:48:31.320807: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 3/100
3/3 - 0s - loss: 0.8767 - accuracy: 0.4880 - val_loss: 0.9062 - val_accuracy: 0.4800
2020-06-24 13:48:31.421993: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 4/100
3/3 - 0s - loss: 0.8091 - accuracy: 0.5680 - val_loss: 0.6892 - val_accuracy: 0.6800
2020-06-24 13:48:31.522592: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 5/100
3/3 - 0s - loss: 0.7398 - accuracy: 0.6000 - val_loss: 0.7740 - val_accuracy: 0.4800
2020-06-24 13:48:31.621376: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 6/100
3/3 - 0s - loss: 0.6998 - accuracy: 0.5920 - val_loss: 0.5235 - val_accuracy: 0.8000
2020-06-24 13:48:31.718801: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 7/100
3/3 - 0s - loss: 0.6680 - accuracy: 0.6000 - val_loss: 0.6909 - val_accuracy: 0.5200
2020-06-24 13:48:31.816848: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 8/100
3/3 - 0s - loss: 0.6004 - accuracy: 0.6080 - val_loss: 0.6326 - val_accuracy: 0.5600
2020-06-24 13:48:31.915430: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 9/100
3/3 - 0s - loss: 0.6188 - accuracy: 0.5600 - val_loss: 0.5685 - val_accuracy: 0.6800
2020-06-24 13:48:32.011544: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 10/100
3/3 - 0s - loss: 0.5498 - accuracy: 0.5920 - val_loss: 0.5441 - val_accuracy: 0.5600
2020-06-24 13:48:32.109571: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 11/100
3/3 - 0s - loss: 0.4918 - accuracy: 0.6400 - val_loss: 0.4846 - val_accuracy: 0.6000
2020-06-24 13:48:32.204874: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 12/100
3/3 - 0s - loss: 0.5593 - accuracy: 0.5840 - val_loss: 0.3896 - val_accuracy: 0.6800
2020-06-24 13:48:32.305525: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 13/100
3/3 - 0s - loss: 0.5047 - accuracy: 0.6160 - val_loss: 0.4087 - val_accuracy: 0.6400
2020-06-24 13:48:32.402713: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 14/100
3/3 - 0s - loss: 0.4755 - accuracy: 0.6320 - val_loss: 0.4403 - val_accuracy: 0.6800
2020-06-24 13:48:32.502776: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 15/100
3/3 - 0s - loss: 0.4525 - accuracy: 0.6000 - val_loss: 0.5042 - val_accuracy: 0.5600
2020-06-24 13:48:32.597086: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 16/100
3/3 - 0s - loss: 0.4178 - accuracy: 0.6400 - val_loss: 0.2869 - val_accuracy: 0.6800
2020-06-24 13:48:32.694386: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 17/100
3/3 - 0s - loss: 0.3787 - accuracy: 0.7120 - val_loss: 0.2880 - val_accuracy: 1.0000
2020-06-24 13:48:32.789293: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 18/100
3/3 - 0s - loss: 0.3517 - accuracy: 0.9360 - val_loss: 0.3167 - val_accuracy: 0.9600
2020-06-24 13:48:32.890152: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 19/100
3/3 - 0s - loss: 0.2876 - accuracy: 0.9600 - val_loss: 0.2135 - val_accuracy: 1.0000
2020-06-24 13:48:32.977722: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 20/100
3/3 - 0s - loss: 0.2497 - accuracy: 0.9680 - val_loss: 0.3315 - val_accuracy: 0.9200
2020-06-24 13:48:33.073054: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 21/100
3/3 - 0s - loss: 0.1404 - accuracy: 0.9840 - val_loss: 0.1453 - val_accuracy: 0.9600
2020-06-24 13:48:33.172585: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 22/100
3/3 - 0s - loss: 0.1532 - accuracy: 0.9680 - val_loss: 0.0990 - val_accuracy: 0.9600
2020-06-24 13:48:33.271512: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 23/100
3/3 - 0s - loss: 0.1668 - accuracy: 0.9600 - val_loss: 0.2982 - val_accuracy: 0.9200
2020-06-24 13:48:33.380562: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 24/100
3/3 - 0s - loss: 0.2369 - accuracy: 0.9440 - val_loss: 0.0913 - val_accuracy: 0.9600
2020-06-24 13:48:33.477740: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 25/100
3/3 - 0s - loss: 0.1404 - accuracy: 0.9600 - val_loss: 0.0528 - val_accuracy: 1.0000
2020-06-24 13:48:33.580174: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 26/100
3/3 - 0s - loss: 0.1242 - accuracy: 0.9600 - val_loss: 0.1270 - val_accuracy: 0.9600
2020-06-24 13:48:33.682646: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 27/100
3/3 - 0s - loss: 0.1571 - accuracy: 0.9360 - val_loss: 0.1864 - val_accuracy: 0.9200
2020-06-24 13:48:33.782106: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 28/100
3/3 - 0s - loss: 0.1119 - accuracy: 0.9680 - val_loss: 0.3250 - val_accuracy: 0.8800
2020-06-24 13:48:33.876085: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 29/100
3/3 - 0s - loss: 0.1152 - accuracy: 0.9680 - val_loss: 0.0496 - val_accuracy: 1.0000
2020-06-24 13:48:33.974056: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 30/100
3/3 - 0s - loss: 0.1179 - accuracy: 0.9520 - val_loss: 0.0738 - val_accuracy: 0.9600
2020-06-24 13:48:34.068217: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 31/100
3/3 - 0s - loss: 0.1135 - accuracy: 0.9440 - val_loss: 0.0312 - val_accuracy: 1.0000
2020-06-24 13:48:34.164732: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 32/100
3/3 - 0s - loss: 0.1171 - accuracy: 0.9680 - val_loss: 0.0429 - val_accuracy: 1.0000
2020-06-24 13:48:34.268862: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 33/100
3/3 - 0s - loss: 0.1743 - accuracy: 0.9200 - val_loss: 0.0938 - val_accuracy: 0.9600
2020-06-24 13:48:34.363582: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 34/100
3/3 - 0s - loss: 0.1538 - accuracy: 0.9520 - val_loss: 0.0527 - val_accuracy: 1.0000
2020-06-24 13:48:34.456848: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 35/100
3/3 - 0s - loss: 0.0892 - accuracy: 0.9680 - val_loss: 0.2995 - val_accuracy: 0.8800
2020-06-24 13:48:34.555958: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 36/100
3/3 - 0s - loss: 0.1701 - accuracy: 0.9120 - val_loss: 0.0839 - val_accuracy: 1.0000
2020-06-24 13:48:34.649702: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 37/100
3/3 - 0s - loss: 0.1243 - accuracy: 0.9600 - val_loss: 0.2027 - val_accuracy: 0.9600
2020-06-24 13:48:34.743426: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 38/100
3/3 - 0s - loss: 0.1254 - accuracy: 0.9360 - val_loss: 0.1850 - val_accuracy: 0.9600
2020-06-24 13:48:34.840332: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 39/100
3/3 - 0s - loss: 0.1060 - accuracy: 0.9680 - val_loss: 0.0290 - val_accuracy: 1.0000
2020-06-24 13:48:34.940817: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 40/100
3/3 - 0s - loss: 0.1042 - accuracy: 0.9760 - val_loss: 0.1057 - val_accuracy: 0.9600
2020-06-24 13:48:35.040291: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 41/100
3/3 - 0s - loss: 0.1283 - accuracy: 0.9360 - val_loss: 0.2976 - val_accuracy: 0.8800
2020-06-24 13:48:35.135240: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 42/100
3/3 - 0s - loss: 0.1302 - accuracy: 0.9600 - val_loss: 0.1171 - val_accuracy: 0.9600
2020-06-24 13:48:35.224268: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 43/100
3/3 - 0s - loss: 0.1049 - accuracy: 0.9600 - val_loss: 0.0311 - val_accuracy: 1.0000
2020-06-24 13:48:35.323427: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 44/100
3/3 - 0s - loss: 0.1173 - accuracy: 0.9680 - val_loss: 0.0446 - val_accuracy: 1.0000
2020-06-24 13:48:35.424129: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 45/100
3/3 - 0s - loss: 0.0843 - accuracy: 0.9840 - val_loss: 0.1678 - val_accuracy: 0.9600
2020-06-24 13:48:35.522035: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 46/100
3/3 - 0s - loss: 0.0965 - accuracy: 0.9680 - val_loss: 0.1145 - val_accuracy: 0.9600
2020-06-24 13:48:35.625074: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 47/100
3/3 - 0s - loss: 0.0878 - accuracy: 0.9760 - val_loss: 0.0687 - val_accuracy: 1.0000
2020-06-24 13:48:35.721504: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 48/100
3/3 - 0s - loss: 0.0938 - accuracy: 0.9680 - val_loss: 0.1954 - val_accuracy: 0.9600
2020-06-24 13:48:35.813329: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 49/100
3/3 - 0s - loss: 0.1076 - accuracy: 0.9680 - val_loss: 0.0943 - val_accuracy: 0.9600
2020-06-24 13:48:35.911998: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 50/100
3/3 - 0s - loss: 0.1780 - accuracy: 0.9280 - val_loss: 0.2292 - val_accuracy: 0.8800
2020-06-24 13:48:36.004937: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 51/100
3/3 - 0s - loss: 0.1263 - accuracy: 0.9440 - val_loss: 0.0801 - val_accuracy: 1.0000
2020-06-24 13:48:36.097917: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 52/100
3/3 - 0s - loss: 0.1325 - accuracy: 0.9440 - val_loss: 0.1316 - val_accuracy: 0.9600
2020-06-24 13:48:36.196925: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 53/100
3/3 - 0s - loss: 0.1236 - accuracy: 0.9520 - val_loss: 0.2554 - val_accuracy: 0.8400
2020-06-24 13:48:36.292070: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 54/100
3/3 - 0s - loss: 0.1060 - accuracy: 0.9680 - val_loss: 0.1010 - val_accuracy: 0.9600
2020-06-24 13:48:36.391013: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 55/100
3/3 - 0s - loss: 0.0941 - accuracy: 0.9600 - val_loss: 0.0213 - val_accuracy: 1.0000
2020-06-24 13:48:36.491951: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 56/100
3/3 - 0s - loss: 0.0910 - accuracy: 0.9680 - val_loss: 0.1553 - val_accuracy: 0.9200
2020-06-24 13:48:36.595972: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 57/100
3/3 - 0s - loss: 0.0810 - accuracy: 0.9680 - val_loss: 0.1236 - val_accuracy: 0.9200
2020-06-24 13:48:36.692941: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 58/100
3/3 - 0s - loss: 0.0901 - accuracy: 0.9680 - val_loss: 0.0154 - val_accuracy: 1.0000
2020-06-24 13:48:36.788585: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 59/100
3/3 - 0s - loss: 0.1351 - accuracy: 0.9520 - val_loss: 0.2164 - val_accuracy: 0.9200
2020-06-24 13:48:36.879246: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 60/100
3/3 - 0s - loss: 0.0805 - accuracy: 0.9600 - val_loss: 0.0326 - val_accuracy: 1.0000
2020-06-24 13:48:36.964945: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 61/100
3/3 - 0s - loss: 0.0847 - accuracy: 0.9680 - val_loss: 0.0763 - val_accuracy: 0.9600
2020-06-24 13:48:37.060777: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 62/100
3/3 - 0s - loss: 0.0855 - accuracy: 0.9680 - val_loss: 0.1844 - val_accuracy: 0.9200
2020-06-24 13:48:37.153086: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 63/100
3/3 - 0s - loss: 0.0731 - accuracy: 0.9840 - val_loss: 0.0590 - val_accuracy: 1.0000
2020-06-24 13:48:37.249702: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 64/100
3/3 - 0s - loss: 0.0866 - accuracy: 0.9680 - val_loss: 0.2506 - val_accuracy: 0.9200
2020-06-24 13:48:37.346924: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 65/100
3/3 - 0s - loss: 0.1085 - accuracy: 0.9680 - val_loss: 0.1177 - val_accuracy: 0.9600
2020-06-24 13:48:37.450567: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 66/100
3/3 - 0s - loss: 0.0753 - accuracy: 0.9760 - val_loss: 0.1381 - val_accuracy: 0.9600
2020-06-24 13:48:37.545836: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 67/100
3/3 - 0s - loss: 0.0545 - accuracy: 0.9920 - val_loss: 0.1071 - val_accuracy: 0.9600
2020-06-24 13:48:37.644233: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 68/100
3/3 - 0s - loss: 0.1187 - accuracy: 0.9600 - val_loss: 0.1174 - val_accuracy: 0.9600
2020-06-24 13:48:37.740285: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 69/100
3/3 - 0s - loss: 0.1037 - accuracy: 0.9600 - val_loss: 0.0166 - val_accuracy: 1.0000
2020-06-24 13:48:37.834712: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 70/100
3/3 - 0s - loss: 0.0888 - accuracy: 0.9760 - val_loss: 0.0078 - val_accuracy: 1.0000
2020-06-24 13:48:37.934857: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 71/100
3/3 - 0s - loss: 0.1493 - accuracy: 0.9440 - val_loss: 0.0207 - val_accuracy: 1.0000
2020-06-24 13:48:38.023489: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 72/100
3/3 - 0s - loss: 0.1086 - accuracy: 0.9600 - val_loss: 0.2320 - val_accuracy: 0.9200
2020-06-24 13:48:38.119768: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 73/100
3/3 - 0s - loss: 0.0993 - accuracy: 0.9680 - val_loss: 0.0498 - val_accuracy: 1.0000
2020-06-24 13:48:38.219272: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 74/100
3/3 - 0s - loss: 0.0634 - accuracy: 0.9840 - val_loss: 0.1018 - val_accuracy: 0.9600
2020-06-24 13:48:38.313702: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 75/100
3/3 - 0s - loss: 0.0460 - accuracy: 0.9840 - val_loss: 0.1757 - val_accuracy: 0.9600
2020-06-24 13:48:38.411229: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 76/100
3/3 - 0s - loss: 0.0626 - accuracy: 0.9760 - val_loss: 0.0904 - val_accuracy: 0.9600
2020-06-24 13:48:38.506356: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 77/100
3/3 - 0s - loss: 0.1134 - accuracy: 0.9520 - val_loss: 0.0111 - val_accuracy: 1.0000
2020-06-24 13:48:38.598625: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 78/100
3/3 - 0s - loss: 0.1303 - accuracy: 0.9520 - val_loss: 0.0933 - val_accuracy: 0.9600
2020-06-24 13:48:38.691488: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 79/100
3/3 - 0s - loss: 0.1150 - accuracy: 0.9520 - val_loss: 0.1235 - val_accuracy: 0.9200
2020-06-24 13:48:38.783369: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 80/100
3/3 - 0s - loss: 0.1316 - accuracy: 0.9440 - val_loss: 0.1862 - val_accuracy: 0.9600
2020-06-24 13:48:38.875318: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 81/100
3/3 - 0s - loss: 0.1370 - accuracy: 0.9520 - val_loss: 0.1054 - val_accuracy: 0.9600
2020-06-24 13:48:38.971616: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 82/100
3/3 - 0s - loss: 0.1179 - accuracy: 0.9600 - val_loss: 0.1493 - val_accuracy: 0.9600
2020-06-24 13:48:39.058560: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 83/100
3/3 - 0s - loss: 0.1352 - accuracy: 0.9440 - val_loss: 0.1442 - val_accuracy: 0.9600
2020-06-24 13:48:39.151436: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 84/100
3/3 - 0s - loss: 0.0900 - accuracy: 0.9760 - val_loss: 0.0424 - val_accuracy: 1.0000
2020-06-24 13:48:39.244804: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 85/100
3/3 - 0s - loss: 0.0921 - accuracy: 0.9680 - val_loss: 0.1253 - val_accuracy: 0.9600
2020-06-24 13:48:39.341263: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 86/100
3/3 - 0s - loss: 0.0874 - accuracy: 0.9520 - val_loss: 0.1412 - val_accuracy: 0.9600
2020-06-24 13:48:39.432252: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 87/100
3/3 - 0s - loss: 0.0958 - accuracy: 0.9600 - val_loss: 0.1112 - val_accuracy: 0.9600
2020-06-24 13:48:39.527059: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 88/100
3/3 - 0s - loss: 0.0648 - accuracy: 0.9840 - val_loss: 0.1341 - val_accuracy: 0.9600
2020-06-24 13:48:39.619279: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 89/100
3/3 - 0s - loss: 0.1251 - accuracy: 0.9520 - val_loss: 0.0090 - val_accuracy: 1.0000
2020-06-24 13:48:39.730230: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 90/100
3/3 - 0s - loss: 0.1153 - accuracy: 0.9680 - val_loss: 0.3754 - val_accuracy: 0.8800
2020-06-24 13:48:39.826340: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 91/100
3/3 - 0s - loss: 0.0993 - accuracy: 0.9600 - val_loss: 0.1535 - val_accuracy: 0.9600
2020-06-24 13:48:39.920253: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 92/100
3/3 - 0s - loss: 0.0911 - accuracy: 0.9760 - val_loss: 0.0518 - val_accuracy: 0.9600
2020-06-24 13:48:40.011077: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 93/100
3/3 - 0s - loss: 0.0926 - accuracy: 0.9760 - val_loss: 0.0757 - val_accuracy: 0.9600
2020-06-24 13:48:40.100091: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 94/100
3/3 - 0s - loss: 0.1069 - accuracy: 0.9520 - val_loss: 0.0425 - val_accuracy: 1.0000
2020-06-24 13:48:40.195484: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 95/100
3/3 - 0s - loss: 0.1219 - accuracy: 0.9760 - val_loss: 0.0731 - val_accuracy: 0.9600
2020-06-24 13:48:40.293798: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 96/100
3/3 - 0s - loss: 0.0924 - accuracy: 0.9680 - val_loss: 0.0398 - val_accuracy: 1.0000
2020-06-24 13:48:40.386873: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 97/100
3/3 - 0s - loss: 0.0957 - accuracy: 0.9600 - val_loss: 0.0937 - val_accuracy: 0.9600
2020-06-24 13:48:40.478837: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 98/100
3/3 - 0s - loss: 0.0856 - accuracy: 0.9760 - val_loss: 0.1121 - val_accuracy: 0.9600
2020-06-24 13:48:40.576556: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 99/100
3/3 - 0s - loss: 0.0870 - accuracy: 0.9600 - val_loss: 0.0763 - val_accuracy: 0.9600
2020-06-24 13:48:40.670007: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
[[{{node IteratorGetNext}}]]
Epoch 100/100
3/3 - 0s - loss: 0.1152 - accuracy: 0.9520 - val_loss: 0.0284 - val_accuracy: 1.0000
Model: "tab_net_classifier"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
tab_net (TabNet) multiple 512
_________________________________________________________________
classifier (Dense) multiple 12
=================================================================
Total params: 524
Trainable params: 524
Non-trainable params: 0
_________________________________________________________________
Model: "tab_net"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_features (DenseFeature multiple 0
_________________________________________________________________
input_gn (GroupNormalization multiple 8
_________________________________________________________________
transform_block (TransformBl multiple 48
_________________________________________________________________
transform_block_1 (Transform multiple 48
_________________________________________________________________
transform_block_2 (Transform multiple 48
_________________________________________________________________
transform_block_3 (Transform multiple 48
_________________________________________________________________
transform_block_4 (Transform multiple 48
_________________________________________________________________
transform_block_5 (Transform multiple 48
_________________________________________________________________
transform_block_6 (Transform multiple 48
_________________________________________________________________
transform_block_7 (Transform multiple 48
_________________________________________________________________
transform_block_8 (Transform multiple 48
_________________________________________________________________
transform_block_9 (Transform multiple 48
_________________________________________________________________
transform_block_10 (Transfor multiple 8
_________________________________________________________________
transform_block_11 (Transfor multiple 8
_________________________________________________________________
transform_block_12 (Transfor multiple 8
=================================================================
Total params: 512
Trainable params: 512
Non-trainable params: 0
_________________________________________________________________
(1, 50, 4, 1)
(1, 50, 4, 1)
(1, 50, 4, 1)
step feature mean_attention std_attention
2 0 petal_length 0.2959992 0.0
6 1 petal_length 0.2555979 0.0
10 2 petal_length 0.23733087 1.4901161e-08
3 0 petal_width 0.21493916 0.0
7 1 petal_width 0.3252291 0.0
11 2 petal_width 0.29720604 2.9802322e-08
0 0 sepal_length 0.2985988 2.9802322e-08
4 1 sepal_length 0.21767889 0.0
8 2 sepal_length 0.24685532 0.0
1 0 sepal_width 0.19046284 1.4901161e-08
5 1 sepal_width 0.20149413 0.0
9 2 sepal_width 0.21860771 0.0
Saving mask 1 of shape (1, 50, 4, 1)
Saving mask 2 of shape (1, 50, 4, 1)
Saving mask 3 of shape (1, 50, 4, 1)
Saving aggregate mask of shape (1, 50, 4, 1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment