Skip to content

Instantly share code, notes, and snippets.

@pavanky
Created April 28, 2020 23:56
Show Gist options
  • Save pavanky/180582df8c29eeb7786853f1cfd8617a to your computer and use it in GitHub Desktop.
Save pavanky/180582df8c29eeb7786853f1cfd8617a to your computer and use it in GitHub Desktop.
dummy_keras_estimator.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "dummy_keras_estimator.ipynb",
"provenance": [],
"authorship_tag": "ABX9TyONXy48FZKXIq07rNBST2+W",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/pavanky/180582df8c29eeb7786853f1cfd8617a/dummy_keras_estimator.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Eulxi8liGvIN",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"outputId": "73f275b8-84bb-44a3-91b8-db94c01ef202"
},
"source": [
"import tensorflow as tf\n",
"# Switch the process to TF1-style graph mode: the model_fn cell below relies on\n",
"# tf.compat.v1 graph APIs (tf.compat.v1.metrics, tf.compat.v1.train optimizers\n",
"# and global step).\n",
"tf.compat.v1.disable_eager_execution()\n",
"print(tf.__version__)"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"2.2.0-rc3\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "znlX9xswG3ti",
"colab_type": "code",
"colab": {}
},
"source": [
"import numpy as np"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "tKqx_hglG7JS",
"colab_type": "code",
"colab": {}
},
"source": [
"# TOTAL_BATCHES is the number of synthetic *examples* (rows), not batches;\n",
"# the batch size is chosen later when the tf.data pipeline is built.\n",
"TOTAL_BATCHES = 10240\n",
"MAX_CATEGORIES = 2**16\n",
"NUMERICAL_DIMENSION = 64\n",
"EMBEDDING_DIMENSION = 64\n",
"SEED = 42\n",
"\n",
"# Create dummy data, seeded so that Restart & Run All reproduces it.\n",
"rng = np.random.default_rng(SEED)\n",
"# One token per example, drawn uniformly from MAX_CATEGORIES distinct words\n",
"# (an integer id per row, not a stringified array of floats).\n",
"CATEGORICAL = np.array([\"word\" + str(i) for i in rng.integers(0, MAX_CATEGORIES, size=TOTAL_BATCHES)])\n",
"NUMERICAL = rng.random((TOTAL_BATCHES, NUMERICAL_DIMENSION))\n",
"LABELS = rng.integers(0, 2, size=TOTAL_BATCHES)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Xa-QcSXbG9DH",
"colab_type": "code",
"colab": {}
},
"source": [
"def get_feature_columns():\n",
"    \"\"\"Build the feature columns consumed by the model's DenseFeatures layer.\n",
"\n",
"    Returns a list of two columns, in this order:\n",
"      * a sum-combined embedding, EMBEDDING_DIMENSION wide, of the string\n",
"        feature \"categorical\" hashed into MAX_CATEGORIES buckets;\n",
"      * the float32 feature \"numerical\" with shape (NUMERICAL_DIMENSION,).\n",
"    \"\"\"\n",
"    hashed_tokens = tf.feature_column.categorical_column_with_hash_bucket(\n",
"        key=\"categorical\",\n",
"        hash_bucket_size=MAX_CATEGORIES,\n",
"        dtype=tf.dtypes.string,\n",
"    )\n",
"    token_embedding = tf.feature_column.embedding_column(\n",
"        categorical_column=hashed_tokens,\n",
"        dimension=EMBEDDING_DIMENSION,\n",
"        combiner=\"sum\",\n",
"    )\n",
"    dense_values = tf.feature_column.numeric_column(\n",
"        key=\"numerical\",\n",
"        shape=(NUMERICAL_DIMENSION,),\n",
"        dtype=tf.dtypes.float32,\n",
"    )\n",
"    return [token_embedding, dense_values]\n",
"\n",
"def get_dataset(batch_size, prefetch=4):\n",
"    \"\"\"Build an infinite, shuffled, batched tf.data pipeline over the dummy data.\n",
"\n",
"    Args:\n",
"      batch_size: number of examples per batch.\n",
"      prefetch: number of batches to prefetch (default 4).\n",
"\n",
"    Returns:\n",
"      A tf.data.Dataset yielding ({\"categorical\": ..., \"numerical\": ...}, labels)\n",
"      batches forever; callers bound training with `steps`.\n",
"    \"\"\"\n",
"    features = {\"categorical\": CATEGORICAL, \"numerical\": NUMERICAL}\n",
"    return (\n",
"        tf.data.Dataset.from_tensor_slices((features, LABELS))\n",
"        # Shuffle before repeat so each pass over the data is a full epoch\n",
"        # before the next begins; the buffer spans 10 batches.\n",
"        .shuffle(buffer_size=batch_size * 10)\n",
"        .repeat()\n",
"        .batch(batch_size)\n",
"        .prefetch(prefetch)\n",
"    )"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "pvNrG1ykHDMM",
"colab_type": "code",
"colab": {}
},
"source": [
"def get_model(feature_columns, hidden_units=(64, 64)):\n",
"    \"\"\"Build a small feed-forward binary classifier over feature columns.\n",
"\n",
"    Args:\n",
"      feature_columns: columns fed to a DenseFeatures input layer.\n",
"      hidden_units: width of each ReLU hidden layer, in order. The default\n",
"        reproduces two 64-unit hidden layers.\n",
"\n",
"    Returns:\n",
"      A tf.keras.Sequential whose output is a single un-activated logit per\n",
"      example (pair it with a from_logits=True loss).\n",
"    \"\"\"\n",
"    layers = [tf.keras.layers.DenseFeatures(feature_columns)]\n",
"    layers += [tf.keras.layers.Dense(units, activation=\"relu\") for units in hidden_units]\n",
"    layers.append(tf.keras.layers.Dense(1))\n",
"    return tf.keras.Sequential(layers)\n"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "7G4VmGyxHFbx",
"colab_type": "code",
"colab": {}
},
"source": [
"def model_fn(features, labels, mode, params, config=None):\n",
"    \"\"\"Estimator model_fn wrapping the Keras model from get_model().\n",
"\n",
"    Args:\n",
"      features: dict of input tensors keyed like the feature columns.\n",
"      labels: 0/1 label tensor (None in PREDICT mode).\n",
"      mode: a tf.estimator.ModeKeys value.\n",
"      params: dict that must contain 'feature_columns'.\n",
"      config: unused; accepted for the Estimator model_fn API.\n",
"\n",
"    Returns:\n",
"      A tf.estimator.EstimatorSpec. The train op is built only in TRAIN mode and\n",
"      the accuracy metric only in EVAL mode, so PREDICT needs no labels.\n",
"    \"\"\"\n",
"    model = get_model(params[\"feature_columns\"])\n",
"    logits = model(features)\n",
"\n",
"    if mode == tf.estimator.ModeKeys.PREDICT:\n",
"        return tf.estimator.EstimatorSpec(mode=mode, predictions=logits)\n",
"\n",
"    # (batch, 1) -> (batch,) to line up with the labels; an explicit axis keeps\n",
"    # a batch of size 1 from collapsing to a scalar.\n",
"    logits = tf.squeeze(logits, axis=-1)\n",
"    loss = tf.keras.losses.binary_crossentropy(labels, logits, from_logits=True)\n",
"\n",
"    train_op = None\n",
"    eval_metric_ops = None\n",
"    if mode == tf.estimator.ModeKeys.TRAIN:\n",
"        # Run any Keras layer updates (e.g. batch-norm statistics) with each step.\n",
"        update_ops = model.get_updates_for(None) + model.get_updates_for(features)\n",
"        optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=1E-3)\n",
"        minimize_op = optimizer.minimize(\n",
"            loss,\n",
"            var_list=model.trainable_variables,\n",
"            global_step=tf.compat.v1.train.get_or_create_global_step())\n",
"        train_op = tf.group(minimize_op, update_ops)\n",
"    else:\n",
"        # Logits are unbounded: threshold at 0 (probability 0.5) to get class\n",
"        # predictions instead of comparing labels against raw logits.\n",
"        predicted_classes = tf.cast(logits > 0, labels.dtype)\n",
"        eval_metric_ops = {\n",
"            \"accuracy\": tf.compat.v1.metrics.accuracy(\n",
"                labels=labels, predictions=predicted_classes),\n",
"        }\n",
"\n",
"    return tf.estimator.EstimatorSpec(\n",
"        mode=mode,\n",
"        predictions=logits,\n",
"        loss=loss,\n",
"        train_op=train_op,\n",
"        eval_metric_ops=eval_metric_ops,\n",
"    )"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "KzoEsJC_HXPY",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"outputId": "bccc2d93-ed93-4ca5-a8e5-069fc8645f4d"
},
"source": [
"num_epochs = 2\n",
"num_steps = 1024\n",
"\n",
"# Build the Estimator around model_fn; with no model_dir it checkpoints into a\n",
"# temporary directory.\n",
"estimator = tf.compat.v1.estimator.Estimator(\n",
"    model_fn=model_fn,\n",
"    params=dict(feature_columns=get_feature_columns()),\n",
")\n",
"\n",
"# Train for num_epochs * num_steps steps on batches of 256 examples; the cell's\n",
"# displayed value is the Estimator returned by train().\n",
"estimator.train(\n",
"    input_fn=lambda: get_dataset(256),\n",
"    steps=num_epochs * num_steps,\n",
")"
],
"execution_count": 33,
"outputs": [
{
"output_type": "stream",
"text": [
"INFO:tensorflow:Using default config.\n",
"WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpz5hjlwg9\n",
"INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpz5hjlwg9', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n",
"graph_options {\n",
" rewrite_options {\n",
" meta_optimizer_iterations: ONE\n",
" }\n",
"}\n",
", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n",
"INFO:tensorflow:Calling model_fn.\n",
"WARNING:tensorflow:Layer dense_features is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2. The layer has dtype float32 because it's dtype defaults to floatx.\n",
"\n",
"If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.\n",
"\n",
"To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.\n",
"\n",
"INFO:tensorflow:Done calling model_fn.\n",
"INFO:tensorflow:Create CheckpointSaverHook.\n",
"INFO:tensorflow:Graph was finalized.\n",
"INFO:tensorflow:Running local_init_op.\n",
"INFO:tensorflow:Done running local_init_op.\n",
"INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...\n",
"INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpz5hjlwg9/model.ckpt.\n",
"INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...\n",
"INFO:tensorflow:loss = 0.71594024, step = 0\n",
"INFO:tensorflow:global_step/sec: 385.416\n",
"INFO:tensorflow:loss = 0.68732035, step = 100 (0.260 sec)\n",
"INFO:tensorflow:global_step/sec: 482.751\n",
"INFO:tensorflow:loss = 0.7127993, step = 200 (0.207 sec)\n",
"INFO:tensorflow:global_step/sec: 494.241\n",
"INFO:tensorflow:loss = 0.6979073, step = 300 (0.205 sec)\n",
"INFO:tensorflow:global_step/sec: 491.27\n",
"INFO:tensorflow:loss = 0.6936704, step = 400 (0.203 sec)\n",
"INFO:tensorflow:global_step/sec: 477.82\n",
"INFO:tensorflow:loss = 0.69902116, step = 500 (0.209 sec)\n",
"INFO:tensorflow:global_step/sec: 503.881\n",
"INFO:tensorflow:loss = 0.69956505, step = 600 (0.199 sec)\n",
"INFO:tensorflow:global_step/sec: 483.219\n",
"INFO:tensorflow:loss = 0.69008255, step = 700 (0.206 sec)\n",
"INFO:tensorflow:global_step/sec: 498.528\n",
"INFO:tensorflow:loss = 0.6868509, step = 800 (0.202 sec)\n",
"INFO:tensorflow:global_step/sec: 497.114\n",
"INFO:tensorflow:loss = 0.6931174, step = 900 (0.201 sec)\n",
"INFO:tensorflow:global_step/sec: 494.511\n",
"INFO:tensorflow:loss = 0.69813424, step = 1000 (0.201 sec)\n",
"INFO:tensorflow:global_step/sec: 500.12\n",
"INFO:tensorflow:loss = 0.69822085, step = 1100 (0.199 sec)\n",
"INFO:tensorflow:global_step/sec: 491.584\n",
"INFO:tensorflow:loss = 0.69584596, step = 1200 (0.203 sec)\n",
"INFO:tensorflow:global_step/sec: 452.504\n",
"INFO:tensorflow:loss = 0.6982038, step = 1300 (0.221 sec)\n",
"INFO:tensorflow:global_step/sec: 479.404\n",
"INFO:tensorflow:loss = 0.69895667, step = 1400 (0.209 sec)\n",
"INFO:tensorflow:global_step/sec: 436.022\n",
"INFO:tensorflow:loss = 0.7051493, step = 1500 (0.236 sec)\n",
"INFO:tensorflow:global_step/sec: 443.791\n",
"INFO:tensorflow:loss = 0.70734096, step = 1600 (0.219 sec)\n",
"INFO:tensorflow:global_step/sec: 473.623\n",
"INFO:tensorflow:loss = 0.6926912, step = 1700 (0.212 sec)\n",
"INFO:tensorflow:global_step/sec: 507.997\n",
"INFO:tensorflow:loss = 0.6907356, step = 1800 (0.195 sec)\n",
"INFO:tensorflow:global_step/sec: 509.095\n",
"INFO:tensorflow:loss = 0.69566274, step = 1900 (0.196 sec)\n",
"INFO:tensorflow:global_step/sec: 497.72\n",
"INFO:tensorflow:loss = 0.70572984, step = 2000 (0.201 sec)\n",
"INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 2048...\n",
"INFO:tensorflow:Saving checkpoints for 2048 into /tmp/tmpz5hjlwg9/model.ckpt.\n",
"INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 2048...\n",
"INFO:tensorflow:Loss for final step: 0.6977092.\n"
],
"name": "stdout"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tensorflow_estimator.python.estimator.estimator.Estimator at 0x7f62025d08d0>"
]
},
"metadata": {
"tags": []
},
"execution_count": 33
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment