Skip to content

Instantly share code, notes, and snippets.

@ariG23498
Created July 13, 2023 09:40
Show Gist options
  • Save ariG23498/b8b4c0912a0a19dfe2ef8b29b3160943 to your computer and use it in GitHub Desktop.
s2s.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"toc_visible": true,
"machine_shape": "hm",
"gpuType": "V100",
"authorship_tag": "ABX9TyPw9oAVuXYy8M35S5dMFK9e",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/ariG23498/b8b4c0912a0a19dfe2ef8b29b3160943/s2s.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"## Download the dataset"
],
"metadata": {
"id": "zHMPQhKiAqR5"
}
},
{
"cell_type": "code",
"source": [
"!wget http://www.manythings.org/anki/fra-eng.zip\n",
"!unzip fra-eng.zip"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "74ioFGuBlF3N",
"outputId": "8dc4a93a-3d2a-4337-da10-1b147c9a8bbd"
},
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"--2023-07-13 09:39:37-- http://www.manythings.org/anki/fra-eng.zip\n",
"Resolving www.manythings.org (www.manythings.org)... 173.254.30.110\n",
"Connecting to www.manythings.org (www.manythings.org)|173.254.30.110|:80... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 7420323 (7.1M) [application/zip]\n",
"Saving to: ‘fra-eng.zip’\n",
"\n",
"fra-eng.zip 100%[===================>] 7.08M 6.00MB/s in 1.2s \n",
"\n",
"2023-07-13 09:39:39 (6.00 MB/s) - ‘fra-eng.zip’ saved [7420323/7420323]\n",
"\n",
"Archive: fra-eng.zip\n",
" inflating: _about.txt \n",
" inflating: fra.txt \n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## Imports and Setups"
],
"metadata": {
"id": "EN5FBnnGSv_D"
}
},
{
"cell_type": "code",
"source": [
"# Use %pip (not !pip) so the install targets the running kernel's environment.\n",
"# %pip install -q wandb\n",
"%pip install -q keras-core"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gCNwvMzMSnvy",
"outputId": "64c15765-be25-4eb9-a599-2d69213740b5"
},
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/728.0 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m728.0/728.0 kB\u001b[0m \u001b[31m27.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h"
]
}
]
},
{
"cell_type": "code",
"source": [
"# !wandb login"
],
"metadata": {
"id": "VmQxlrAgsGBb"
},
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import os\n",
"os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n",
"\n",
"import keras_core as keras\n",
"\n",
"import numpy as np\n",
"# import wandb"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "CbqUfljTRfP7",
"outputId": "889e42f6-eca5-4737-9c6c-eeb4cafb9ebb"
},
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Using TensorFlow backend\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(f\"{keras.backend.backend()=}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ZHIL5PloXFeL",
"outputId": "d23fba24-ca37-4d3c-c071-c515ac2c6678"
},
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"keras.backend.backend()='tensorflow'\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## Configurations"
],
"metadata": {
"id": "dPMOTsz5BuT1"
}
},
{
"cell_type": "code",
"source": [
"batch_size = 64 # Batch size for training.\n",
"epochs = 100 # Number of epochs to train for.\n",
"latent_dim = 256 # Latent dimensionality of the encoding space.\n",
"num_samples = 10000 # Number of samples to train on.\n",
"# Path to the data txt file on disk.\n",
"data_path = \"fra.txt\"\n",
"\n",
"print(f\"{data_path=}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "J0S8tb7WRkDM",
"outputId": "5e6bc2e5-ed9f-47d0-ac90-b57c7dd37a8a"
},
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"data_path='fra.txt'\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## Dataset"
],
"metadata": {
"id": "BCZdVuReBw0W"
}
},
{
"cell_type": "code",
"source": [
"# Vectorize the data.\n",
"input_texts = []\n",
"target_texts = []\n",
"\n",
"input_characters = set()\n",
"target_characters = set()\n",
"\n",
"with open(data_path, \"r\", encoding=\"utf-8\") as f:\n",
" lines = f.read().split(\"\\n\")"
],
"metadata": {
"id": "ObqzdjgyRqzf"
},
"execution_count": 7,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Print the first 5 samples\n",
"for line in lines[:5]:\n",
" print(line)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "66JFApr1lsA2",
"outputId": "e53285f5-b009-4ad0-d910-07dfc7efe22f"
},
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Go.\tVa !\tCC-BY 2.0 (France) Attribution: tatoeba.org #2877272 (CM) & #1158250 (Wittydev)\n",
"Go.\tMarche.\tCC-BY 2.0 (France) Attribution: tatoeba.org #2877272 (CM) & #8090732 (Micsmithel)\n",
"Go.\tEn route !\tCC-BY 2.0 (France) Attribution: tatoeba.org #2877272 (CM) & #8267435 (felix63)\n",
"Go.\tBouge !\tCC-BY 2.0 (France) Attribution: tatoeba.org #2877272 (CM) & #9022935 (Micsmithel)\n",
"Hi.\tSalut !\tCC-BY 2.0 (France) Attribution: tatoeba.org #538123 (CM) & #509819 (Aiji)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"import random\n",
"\n",
"# Seed before shuffling so the num_samples subset taken below is\n",
"# reproducible under Restart & Run All.\n",
"random.seed(1234)\n",
"random.shuffle(lines)"
],
"metadata": {
"id": "eEs3AxNXB--t"
},
"execution_count": 9,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Print the first 5 samples\n",
"for line in lines[:5]:\n",
" print(line)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "f0VUpQ-KCNbF",
"outputId": "bf4b4255-aaf8-48ec-f01e-6d55b18176fb"
},
"execution_count": 10,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"I found Tom's diary.\tJ'ai trouvé le journal intime de Tom.\tCC-BY 2.0 (France) Attribution: tatoeba.org #2327230 (CK) & #5396516 (pititnatole)\n",
"The people in the room all know one another.\tLes personnes dans la salle se connaissent toutes.\tCC-BY 2.0 (France) Attribution: tatoeba.org #44226 (CK) & #11274105 (lbdx)\n",
"How old is this zoo?\tQuel âge a ce zoo ?\tCC-BY 2.0 (France) Attribution: tatoeba.org #436249 (lukaszpp) & #590081 (qdii)\n",
"Do you give lessons?\tEnseignes-tu ?\tCC-BY 2.0 (France) Attribution: tatoeba.org #3151577 (CK) & #7581847 (Micsmithel)\n",
"We haven't lost much.\tNous n'avons pas beaucoup perdu.\tCC-BY 2.0 (France) Attribution: tatoeba.org #5789467 (CK) & #5800044 (Toynop)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"for line in lines[: min(num_samples, len(lines) - 1)]:\n",
" input_text, target_text, _ = line.split(\"\\t\")\n",
" # \"\\t\" marks \"start of sequence\" and \"\\n\" marks \"end of sequence\"\n",
" # for the decoder targets.\n",
" target_text = \"\\t\" + target_text + \"\\n\"\n",
" input_texts.append(input_text)\n",
" target_texts.append(target_text)\n",
" # Sets deduplicate on their own, so update() replaces the explicit\n",
" # per-character membership checks.\n",
" input_characters.update(input_text)\n",
" target_characters.update(target_text)"
],
"metadata": {
"id": "zElLeLxoVrlV"
},
"execution_count": 11,
"outputs": []
},
{
"cell_type": "code",
"source": [
"input_characters = sorted(list(input_characters))\n",
"target_characters = sorted(list(target_characters))\n",
"num_encoder_tokens = len(input_characters)\n",
"num_decoder_tokens = len(target_characters)\n",
"max_encoder_seq_length = max([len(txt) for txt in input_texts])\n",
"max_decoder_seq_length = max([len(txt) for txt in target_texts])\n",
"\n",
"print(\"Number of samples:\", len(input_texts))\n",
"print(\"Number of unique input tokens:\", num_encoder_tokens)\n",
"print(\"Number of unique output tokens:\", num_decoder_tokens)\n",
"print(\"Max sequence length for inputs:\", max_encoder_seq_length)\n",
"print(\"Max sequence length for outputs:\", max_decoder_seq_length)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "kjlu6-q2mGLh",
"outputId": "c9515194-fb46-4dc5-9071-fd54563ead77"
},
"execution_count": 12,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Number of samples: 10000\n",
"Number of unique input tokens: 73\n",
"Number of unique output tokens: 96\n",
"Max sequence length for inputs: 128\n",
"Max sequence length for outputs: 166\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Map each vocabulary character to an integer index.\n",
"input_token_index = {char: i for i, char in enumerate(input_characters)}\n",
"target_token_index = {char: i for i, char in enumerate(target_characters)}"
],
"metadata": {
"id": "3Vp935oWm3lS"
},
"execution_count": 13,
"outputs": []
},
{
"cell_type": "code",
"source": [
"encoder_input_data = np.zeros(\n",
" (len(input_texts), max_encoder_seq_length, num_encoder_tokens),\n",
" dtype=\"float32\",\n",
")\n",
"\n",
"# Changes (@ariG23498): Target texts length used\n",
"decoder_input_data = np.zeros(\n",
" (len(target_texts), max_decoder_seq_length, num_decoder_tokens),\n",
" dtype=\"float32\",\n",
")\n",
"decoder_target_data = np.zeros(\n",
" (len(target_texts), max_decoder_seq_length, num_decoder_tokens),\n",
" dtype=\"float32\",\n",
")"
],
"metadata": {
"id": "BLxf-XcIm9H_"
},
"execution_count": 14,
"outputs": []
},
{
"cell_type": "code",
"source": [
"print(f\"{encoder_input_data.shape=}\")\n",
"print(f\"{decoder_input_data.shape=}\")\n",
"print(f\"{decoder_target_data.shape=}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "PMG4c8kpn0Fe",
"outputId": "e8c10e7c-69a6-4391-adcf-b0b323997027"
},
"execution_count": 15,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"encoder_input_data.shape=(10000, 128, 73)\n",
"decoder_input_data.shape=(10000, 166, 96)\n",
"decoder_target_data.shape=(10000, 166, 96)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):\n",
" # One Hot encode the encoder input data\n",
" for t, char in enumerate(input_text):\n",
" encoder_input_data[i, t, input_token_index[char]] = 1.0\n",
" # Pad the rest of the places\n",
" encoder_input_data[i, t + 1 :, input_token_index[\" \"]] = 1.0\n",
"\n",
" # One Hot encode the decoder input data\n",
" for t, char in enumerate(target_text):\n",
" # decoder_target_data is ahead of decoder_input_data by one timestep\n",
" decoder_input_data[i, t, target_token_index[char]] = 1.0\n",
" if t > 0:\n",
" # decoder_target_data will be ahead by one timestep\n",
" # and will not include the start character.\n",
" decoder_target_data[i, t - 1, target_token_index[char]] = 1.0\n",
" # Pad the rest of the places\n",
" decoder_input_data[i, t + 1 :, target_token_index[\" \"]] = 1.0\n",
" decoder_target_data[i, t:, target_token_index[\" \"]] = 1.0"
],
"metadata": {
"id": "ceNRvi8VmquY"
},
"execution_count": 16,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Define an input sequence and process it.\n",
"encoder_inputs = keras.Input(shape=(None, num_encoder_tokens))\n",
"encoder = keras.layers.LSTM(latent_dim, return_state=True)\n",
"encoder_outputs, state_h, state_c = encoder(encoder_inputs)\n",
"\n",
"print(f\"{encoder_outputs.shape=}\")\n",
"print(f\"{state_h.shape=}\")\n",
"print(f\"{state_c.shape=}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "BEf_ZhvpRud3",
"outputId": "4d0fba81-6b12-4cb1-a89f-c128575e687d"
},
"execution_count": 17,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"encoder_outputs.shape=(None, 256)\n",
"state_h.shape=(None, 256)\n",
"state_c.shape=(None, 256)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# We discard `encoder_outputs` and only keep the states.\n",
"encoder_states = [state_h, state_c]"
],
"metadata": {
"id": "BpWdTvEYpYK_"
},
"execution_count": 18,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Set up the decoder, using `encoder_states` as initial state.\n",
"decoder_inputs = keras.Input(shape=(None, num_decoder_tokens))\n",
"\n",
"# We set up our decoder to return full output sequences,\n",
"# and to return internal states as well. We don't use the\n",
"# return states in the training model, but we will use them in inference.\n",
"decoder_lstm = keras.layers.LSTM(\n",
" latent_dim, return_sequences=True, return_state=True\n",
")\n",
"decoder_outputs, _, _ = decoder_lstm(\n",
" decoder_inputs, initial_state=encoder_states\n",
")\n",
"decoder_dense = keras.layers.Dense(num_decoder_tokens)\n",
"decoder_outputs = decoder_dense(decoder_outputs)"
],
"metadata": {
"id": "LpZqFrjRpnzn"
},
"execution_count": 19,
"outputs": []
},
{
"cell_type": "code",
"source": [
"print(f\"{decoder_outputs.shape=}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "PsQJKEbfpxpv",
"outputId": "8f6749ac-80fa-492c-cc5e-29f544c40d27"
},
"execution_count": 20,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"decoder_outputs.shape=(None, None, 96)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Define the model that will turn\n",
"# `encoder_input_data` & `decoder_input_data` into `decoder_target_data`\n",
"model = keras.Model([encoder_inputs, decoder_inputs], decoder_outputs)\n",
"\n",
"model.summary()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 417
},
"id": "UMmlicLgpruh",
"outputId": "049656e4-7aec-422e-bca2-3136ef21a6b9"
},
"execution_count": 21,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"\u001b[1mModel: \"functional_1\"\u001b[0m\n"
],
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"functional_1\"</span>\n",
"</pre>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mParam #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩\n",
"│ input_layer │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m73\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼─────────┼──────────────────────┤\n",
"│ input_layer_1 │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m96\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n",
"├─────────────────────┼───────────────────┼─────────┼──────────────────────┤\n",
"│ lstm (\u001b[38;5;33mLSTM\u001b[0m) │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ \u001b[38;5;34m337,920\u001b[0m │ input_layer[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m), │ │ │\n",
"│ │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m)] │ │ │\n",
"├─────────────────────┼───────────────────┼─────────┼──────────────────────┤\n",
"│ lstm_1 (\u001b[38;5;33mLSTM\u001b[0m) │ [(\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, │ \u001b[38;5;34m361,472\u001b[0m │ input_layer_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n",
"│ │ \u001b[38;5;34m256\u001b[0m), (\u001b[38;5;45mNone\u001b[0m, │ │ lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m1\u001b[0m], │\n",
"│ │ \u001b[38;5;34m256\u001b[0m), (\u001b[38;5;45mNone\u001b[0m, │ │ lstm[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m2\u001b[0m] │\n",
"│ │ \u001b[38;5;34m256\u001b[0m)] │ │ │\n",
"├─────────────────────┼───────────────────┼─────────┼──────────────────────┤\n",
"│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m96\u001b[0m) │ \u001b[38;5;34m24,672\u001b[0m │ lstm_1[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"└─────────────────────┴───────────────────┴─────────┴──────────────────────┘\n"
],
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓\n",
"┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃<span style=\"font-weight: bold\"> Connected to </span>┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩\n",
"│ input_layer │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">73</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼─────────┼──────────────────────┤\n",
"│ input_layer_1 │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">96</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
"│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ │ │ │\n",
"├─────────────────────┼───────────────────┼─────────┼──────────────────────┤\n",
"│ lstm (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">LSTM</span>) │ [(<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>), │ <span style=\"color: #00af00; text-decoration-color: #00af00\">337,920</span> │ input_layer[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"│ │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>), │ │ │\n",
"│ │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>)] │ │ │\n",
"├─────────────────────┼───────────────────┼─────────┼──────────────────────┤\n",
"│ lstm_1 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">LSTM</span>) │ [(<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, │ <span style=\"color: #00af00; text-decoration-color: #00af00\">361,472</span> │ input_layer_1[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>], │\n",
"│ │ <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>), (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, │ │ lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">1</span>], │\n",
"│ │ <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>), (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, │ │ lstm[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">2</span>] │\n",
"│ │ <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>)] │ │ │\n",
"├─────────────────────┼───────────────────┼─────────┼──────────────────────┤\n",
"│ dense (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">96</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">24,672</span> │ lstm_1[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
"└─────────────────────┴───────────────────┴─────────┴──────────────────────┘\n",
"</pre>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"\u001b[1m Total params: \u001b[0m\u001b[38;5;34m724,064\u001b[0m (22.10 MB)\n"
],
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">724,064</span> (22.10 MB)\n",
"</pre>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m724,064\u001b[0m (22.10 MB)\n"
],
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">724,064</span> (22.10 MB)\n",
"</pre>\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
],
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
"</pre>\n"
]
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"# wandb.init(\n",
"# entity=\"ariG23498\",\n",
"# project=\"s2s\",\n",
"# config={\n",
"# \"backend\": keras.backend.backend(),\n",
"# \"batch_size\": batch_size,\n",
"# \"epochs\": epochs,\n",
"# \"latent_dim\": latent_dim,\n",
"# \"num_samples\": num_samples,\n",
"# }\n",
"# )"
],
"metadata": {
"id": "4zbWNn46tBHQ"
},
"execution_count": 22,
"outputs": []
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "GIEBcYexRZ8k",
"outputId": "98a98cae-0adb-49e0-9492-5ab846bc671f"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch 1/100\n"
]
},
{
"output_type": "error",
"ename": "InvalidArgumentError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mInvalidArgumentError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-23-aedeb93827cb>\u001b[0m in \u001b[0;36m<cell line: 7>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 5\u001b[0m )\n\u001b[1;32m 6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m model.fit(\n\u001b[0m\u001b[1;32m 8\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mencoder_input_data\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdecoder_input_data\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mdecoder_target_data\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/keras_core/src/utils/traceback_utils.py\u001b[0m in \u001b[0;36merror_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 122\u001b[0m \u001b[0;31m# To get the full stack trace, call:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[0;31m# `keras_core.config.disable_traceback_filtering()`\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 124\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwith_traceback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfiltered_tb\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 125\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 126\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0mfiltered_tb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/tensorflow/python/eager/execute.py\u001b[0m in \u001b[0;36mquick_execute\u001b[0;34m(op_name, num_outputs, inputs, attrs, ctx, name)\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mensure_initialized\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 52\u001b[0;31m tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,\n\u001b[0m\u001b[1;32m 53\u001b[0m inputs, attrs, num_outputs)\n\u001b[1;32m 54\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_NotOkStatusException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mInvalidArgumentError\u001b[0m: Graph execution error:\n\nDetected unsupported operations when trying to compile graph __inference_one_step_on_data_2665[] on XLA_GPU_JIT: CudnnRNN (No registered 'CudnnRNN' OpKernel for XLA_GPU_JIT devices compatible with node {{node functional_1/lstm/CudnnRNN}}){{node functional_1/lstm/CudnnRNN}}\nThe op is created at: \nFile \"/usr/lib/python3.10/runpy.py\", line 196, in _run_module_as_main\n return _run_code(code, main_globals, None,\nFile \"/usr/lib/python3.10/runpy.py\", line 86, in _run_code\n exec(code, run_globals)\nFile \"/usr/local/lib/python3.10/dist-packages/ipykernel_launcher.py\", line 16, in <module>\n app.launch_new_instance()\nFile \"/usr/local/lib/python3.10/dist-packages/traitlets/config/application.py\", line 992, in launch_instance\n app.start()\nFile \"/usr/local/lib/python3.10/dist-packages/ipykernel/kernelapp.py\", line 619, in start\n self.io_loop.start()\nFile \"/usr/local/lib/python3.10/dist-packages/tornado/platform/asyncio.py\", line 195, in start\n self.asyncio_loop.run_forever()\nFile \"/usr/lib/python3.10/asyncio/base_events.py\", line 603, in run_forever\n self._run_once()\nFile \"/usr/lib/python3.10/asyncio/base_events.py\", line 1909, in _run_once\n handle._run()\nFile \"/usr/lib/python3.10/asyncio/events.py\", line 80, in _run\n self._context.run(self._callback, *self._args)\nFile \"/usr/local/lib/python3.10/dist-packages/tornado/ioloop.py\", line 685, in <lambda>\n lambda f: self._run_callback(functools.partial(callback, future))\nFile \"/usr/local/lib/python3.10/dist-packages/tornado/ioloop.py\", line 738, in _run_callback\n ret = callback()\nFile \"/usr/local/lib/python3.10/dist-packages/tornado/gen.py\", line 825, in inner\n self.ctx_run(self.run)\nFile \"/usr/local/lib/python3.10/dist-packages/tornado/gen.py\", line 786, in run\n yielded = self.gen.send(value)\nFile \"/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py\", line 377, in dispatch_queue\n yield 
self.process_one()\nFile \"/usr/local/lib/python3.10/dist-packages/tornado/gen.py\", line 250, in wrapper\n runner = Runner(ctx_run, result, future, yielded)\nFile \"/usr/local/lib/python3.10/dist-packages/tornado/gen.py\", line 748, in __init__\n self.ctx_run(self.run)\nFile \"/usr/local/lib/python3.10/dist-packages/tornado/gen.py\", line 786, in run\n yielded = self.gen.send(value)\nFile \"/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py\", line 361, in process_one\n yield gen.maybe_future(dispatch(*args))\nFile \"/usr/local/lib/python3.10/dist-packages/tornado/gen.py\", line 234, in wrapper\n yielded = ctx_run(next, result)\nFile \"/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py\", line 261, in dispatch_shell\n yield gen.maybe_future(handler(stream, idents, msg))\nFile \"/usr/local/lib/python3.10/dist-packages/tornado/gen.py\", line 234, in wrapper\n yielded = ctx_run(next, result)\nFile \"/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py\", line 539, in execute_request\n self.do_execute(\nFile \"/usr/local/lib/python3.10/dist-packages/tornado/gen.py\", line 234, in wrapper\n yielded = ctx_run(next, result)\nFile \"/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py\", line 302, in do_execute\n res = shell.run_cell(code, store_history=store_history, silent=silent)\nFile \"/usr/local/lib/python3.10/dist-packages/ipykernel/zmqshell.py\", line 539, in run_cell\n return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)\nFile \"/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py\", line 2975, in run_cell\n result = self._run_cell(\nFile \"/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py\", line 3030, in _run_cell\n return runner(coro)\nFile \"/usr/local/lib/python3.10/dist-packages/IPython/core/async_helpers.py\", line 78, in _pseudo_sync_runner\n coro.send(None)\nFile \"/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py\", line 3257, in 
run_cell_async\n has_raised = await self.run_ast_nodes(code_ast.body, cell_name,\nFile \"/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py\", line 3473, in run_ast_nodes\n if (await self.run_code(code, result, async_=asy)):\nFile \"/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py\", line 3553, in run_code\n exec(code_obj, self.user_global_ns, self.user_ns)\nFile \"<ipython-input-23-aedeb93827cb>\", line 7, in <cell line: 7>\n model.fit(\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/utils/traceback_utils.py\", line 119, in error_handler\n return fn(*args, **kwargs)\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/backend/tensorflow/trainer.py\", line 306, in fit\n logs = self.train_function(iterator)\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/backend/tensorflow/trainer.py\", line 111, in one_step_on_iterator\n outputs = self.distribute_strategy.run(\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/backend/tensorflow/trainer.py\", line 98, in one_step_on_data\n return self.train_step(data)\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/backend/tensorflow/trainer.py\", line 51, in train_step\n y_pred = self(x, training=True)\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/utils/traceback_utils.py\", line 119, in error_handler\n return fn(*args, **kwargs)\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/layers/layer.py\", line 703, in __call__\n outputs = super().__call__(*args, **kwargs)\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/utils/traceback_utils.py\", line 119, in error_handler\n return fn(*args, **kwargs)\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/ops/operation.py\", line 41, in __call__\n return call_fn(*args, **kwargs)\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/utils/traceback_utils.py\", line 154, in error_handler\n return fn(*args, 
**kwargs)\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/models/functional.py\", line 181, in call\n outputs = self._run_through_graph(\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/ops/function.py\", line 127, in _run_through_graph\n outputs = operation_fn(node.operation)(*args, **kwargs)\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/models/functional.py\", line 549, in call\n return operation(*args, **kwargs)\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/utils/traceback_utils.py\", line 119, in error_handler\n return fn(*args, **kwargs)\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/layers/layer.py\", line 703, in __call__\n outputs = super().__call__(*args, **kwargs)\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/utils/traceback_utils.py\", line 119, in error_handler\n return fn(*args, **kwargs)\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/ops/operation.py\", line 41, in __call__\n return call_fn(*args, **kwargs)\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/utils/traceback_utils.py\", line 154, in error_handler\n return fn(*args, **kwargs)\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/layers/rnn/lstm.py\", line 526, in call\n return super().call(\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/layers/rnn/rnn.py\", line 390, in call\n last_output, outputs, states = self.inner_loop(\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/layers/rnn/lstm.py\", line 505, in inner_loop\n return backend.lstm(\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/backend/tensorflow/rnn.py\", line 815, in lstm\n return _cudnn_lstm(\nFile \"/usr/local/lib/python3.10/dist-packages/keras_core/src/backend/tensorflow/rnn.py\", line 926, in _cudnn_lstm\n outputs, h, c, _ = tf.raw_ops.CudnnRNN(\n\t [[StatefulPartitionedCall]] [Op:__inference_one_step_on_iterator_2730]"
]
}
],
"source": [
"model.compile(\n",
" optimizer=keras.optimizers.AdamW(3e-3),\n",
" loss=keras.losses.CategoricalCrossentropy(from_logits=True),\n",
" metrics=[\"accuracy\"],\n",
")\n",
"\n",
"model.fit(\n",
" [encoder_input_data, decoder_input_data],\n",
" decoder_target_data,\n",
" batch_size=batch_size,\n",
" epochs=epochs,\n",
" validation_split=0.2,\n",
" # callbacks=[wandb.keras.WandbMetricsLogger()],\n",
")"
]
},
{
"cell_type": "code",
"source": [
"# Close the Weights & Biases run -- uncomment together with the\n",
"# WandbMetricsLogger callback in the training cell above.\n",
"# wandb.finish()"
],
"metadata": {
"id": "J2XmgVpPuIc5"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Save the trained model in the native Keras format so the\n",
"# inference-only encoder/decoder can be reconstructed from it below.\n",
"model.save(\"s2s_model.keras\")"
],
"metadata": {
"id": "p1BFMAFjR2Yo"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Define sampling models\n",
"# Restore the trained model; the inference encoder and decoder are\n",
"# constructed from its layers in the next cell.\n",
"model = keras.models.load_model(\"s2s_model.keras\")"
],
"metadata": {
"id": "jVMW4G97R6H-"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Rebuild inference-only encoder and decoder models from the trained\n",
"# model's layers.\n",
"# NOTE(review): this relies on fixed layer positions (model.layers[2..4]);\n",
"# retrieving layers by name would be more robust if the architecture\n",
"# ever changes.\n",
"encoder_inputs = model.input[0]  # input_1\n",
"encoder_outputs, state_h_enc, state_c_enc = model.layers[2].output  # lstm_1\n",
"encoder_states = [state_h_enc, state_c_enc]\n",
"encoder_model = keras.Model(encoder_inputs, encoder_states)\n",
"\n",
"decoder_inputs = model.input[1]  # input_2\n",
"decoder_state_input_h = keras.Input(shape=(latent_dim,))\n",
"decoder_state_input_c = keras.Input(shape=(latent_dim,))\n",
"decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]\n",
"decoder_lstm = model.layers[3]\n",
"decoder_outputs, state_h_dec, state_c_dec = decoder_lstm(\n",
"    decoder_inputs, initial_state=decoder_states_inputs\n",
")\n",
"decoder_states = [state_h_dec, state_c_dec]\n",
"decoder_dense = model.layers[4]\n",
"decoder_outputs = decoder_dense(decoder_outputs)\n",
"decoder_model = keras.Model(\n",
"    [decoder_inputs] + decoder_states_inputs, [decoder_outputs] + decoder_states\n",
")"
],
"metadata": {
"id": "z5p0sER-R8rx"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Reverse-lookup token index to decode sequences back to\n",
"# something readable (index -> character, inverting char -> index).\n",
"reverse_input_char_index = {i: char for char, i in input_token_index.items()}\n",
"reverse_target_char_index = {i: char for char, i in target_token_index.items()}"
],
"metadata": {
"id": "ZAPRidq5R-rW"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def decode_sequence(input_seq):\n",
"    \"\"\"Greedy-decode one one-hot-encoded input sequence to a string.\n",
"\n",
"    Assumes a batch of exactly one sequence. Returns the decoded\n",
"    sentence, which includes the trailing newline stop character when\n",
"    one is produced before the length limit is reached.\n",
"    \"\"\"\n",
"    # Encode the input as state vectors.\n",
"    states_value = encoder_model.predict(input_seq, verbose=0)\n",
"\n",
"    # Generate empty target sequence of length 1.\n",
"    target_seq = np.zeros((1, 1, num_decoder_tokens))\n",
"    # Populate the first character of target sequence with the start character.\n",
"    target_seq[0, 0, target_token_index[\"\\t\"]] = 1.0\n",
"\n",
"    # Sampling loop for a batch of sequences\n",
"    # (to simplify, here we assume a batch of size 1).\n",
"    stop_condition = False\n",
"    decoded_sentence = \"\"\n",
"    while not stop_condition:\n",
"        output_tokens, h, c = decoder_model.predict(\n",
"            [target_seq] + states_value, verbose=0\n",
"        )\n",
"\n",
"        # Sample a token (greedy argmax -- no beam search).\n",
"        sampled_token_index = np.argmax(output_tokens[0, -1, :])\n",
"        sampled_char = reverse_target_char_index[sampled_token_index]\n",
"        decoded_sentence += sampled_char\n",
"\n",
"        # Exit condition: either hit max length\n",
"        # or find stop character.\n",
"        if (\n",
"            sampled_char == \"\\n\"\n",
"            or len(decoded_sentence) > max_decoder_seq_length\n",
"        ):\n",
"            stop_condition = True\n",
"\n",
"        # Update the target sequence (of length 1).\n",
"        target_seq = np.zeros((1, 1, num_decoder_tokens))\n",
"        target_seq[0, 0, sampled_token_index] = 1.0\n",
"\n",
"        # Update states\n",
"        states_value = [h, c]\n",
"    return decoded_sentence"
],
"metadata": {
"id": "9S4PEz68SCHI"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Spot-check the decoder on the first 20 sequences.\n",
"# NOTE(review): these come from the training set, so this is a sanity\n",
"# check, not an evaluation of generalization.\n",
"for seq_index in range(20):\n",
"    # Take one sequence (part of the training set)\n",
"    # for trying out decoding.\n",
"    input_seq = encoder_input_data[seq_index : seq_index + 1]\n",
"    decoded_sentence = decode_sequence(input_seq)\n",
"    print(\"-\")\n",
"    print(\"Input sentence:\", input_texts[seq_index])\n",
"    print(\"Decoded sentence:\", decoded_sentence)"
],
"metadata": {
"id": "Vb_sRnYrRzqK"
},
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment