Skip to content

Instantly share code, notes, and snippets.

@enakai00
Created July 25, 2022 08:59
Show Gist options
  • Save enakai00/1407e48e0ce1607d48bb6a8906a8e2bc to your computer and use it in GitHub Desktop.
Save enakai00/1407e48e0ce1607d48bb6a8906a8e2bc to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 6,
"id": "2b31bfe9-3b3d-4313-9435-cead3071ee41",
"metadata": {},
"outputs": [],
"source": [
"import gym\n",
"import numpy as np\n",
"import copy, random, time, subprocess, os\n",
"from tensorflow.keras import layers, models"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "6dcb0907-a790-47b0-8ac6-a6c2eb912d61",
"metadata": {},
"outputs": [],
"source": [
"class QValue:\n",
"    \"\"\"Greedy action selector backed by a Q-function model.\n",
"\n",
"    No model is built here: assign a loaded model to ``self.model``\n",
"    before calling ``get_action``.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        # Filled in later by the caller once a checkpoint has been loaded.\n",
"        self.model = None\n",
"\n",
"    def get_action(self, state, num_actions=5):\n",
"        \"\"\"Return the highest-scoring action for ``state`` and its Q-value.\n",
"\n",
"        The model is evaluated once on a batch that pairs ``state`` with\n",
"        the one-hot encoding of every candidate action.\n",
"\n",
"        Args:\n",
"            state: Observation fed to the model; converted with ``np.array``.\n",
"            num_actions: Number of discrete actions to score. Defaults to 5,\n",
"                the value that used to be hard-coded.\n",
"\n",
"        Returns:\n",
"            A tuple ``(action_index, q_value)`` for the best-scoring action.\n",
"\n",
"        Raises:\n",
"            RuntimeError: If ``self.model`` has not been assigned yet.\n",
"        \"\"\"\n",
"        if self.model is None:\n",
"            raise RuntimeError('QValue.model must be set before calling get_action')\n",
"\n",
"        # One copy of the state per candidate action.\n",
"        states = np.repeat(np.array(state)[np.newaxis], num_actions, axis=0)\n",
"        # Row a of the identity matrix is the one-hot vector for action a.\n",
"        actions = np.eye(num_actions)\n",
"\n",
"        # NOTE(review): the indexing below presumes predict() returns one row\n",
"        # per action with a single Q-value column -- confirm against the model.\n",
"        q_values = self.model.predict([states, actions])\n",
"        optimal_action = np.argmax(q_values)\n",
"        return optimal_action, q_values[optimal_action][0]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "cdc3cc3a-04dc-463b-87b0-5669d962e2fa",
"metadata": {},
"outputs": [],
"source": [
"def join_frames(o0, o1):\n",
"    \"\"\"Join two arrays along their last axis.\n",
"\n",
"    Transposing reverses the axis order, so concatenating the transposed\n",
"    arrays on the first axis and transposing back joins the inputs along\n",
"    their last axis. Presumably each input is a single observation frame;\n",
"    verify against the caller.\n",
"    \"\"\"\n",
"    flipped = (o0.T, o1.T)\n",
"    stacked = np.concatenate(flipped)\n",
"    return stacked.T"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "14a92b6f-27a3-4bcf-b3e1-3e9c0dc85389",
"metadata": {},
"outputs": [],
"source": [
"q_value = QValue()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "c9364323-e246-4abd-bfbc-74abe0f2d5e7",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Copying gs://etsuji-car-racing-v2-model01/model01/car-racing-v2-model01-104.hd5...\n",
"- [1 files][292.5 MiB/292.5 MiB] \n",
"Operation completed over 1 objects/292.5 MiB. \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"load model car-racing-v2-model01-104.hd5\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.7/site-packages/gym/core.py:330: DeprecationWarning: \u001b[33mWARN: Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.\u001b[0m\n",
" \"Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.\"\n",
"/opt/conda/lib/python3.7/site-packages/gym/wrappers/step_api_compatibility.py:40: DeprecationWarning: \u001b[33mWARN: Initializing environment in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.\u001b[0m\n",
" \"Initializing environment in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future.\"\n",
"/opt/conda/lib/python3.7/site-packages/gym/core.py:52: DeprecationWarning: \u001b[33mWARN: The argument mode in render method is deprecated; use render_mode during environment initialization instead.\n",
"See here for more information: https://www.gymlibrary.ml/content/api/\u001b[0m\n",
" \"The argument mode in render method is deprecated; \"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"3 11.814814814814826\n",
"3 27.333333333333314\n",
"3 46.555555555555486\n",
"3 65.77777777777771\n",
"3 62.77777777777779\n",
"3 59.77777777777775\n",
"3 56.77777777777771\n",
"3 53.777777777777665\n",
"3 50.77777777777762\n",
"3 47.77777777777758\n",
"3 44.77777777777754\n"
]
}
],
"source": [
"import datetime\n",
"import imageio\n",
"\n",
"checkpoint = 104\n",
"model = 'model01'\n",
"\n",
"BUCKET = 'gs://etsuji-car-racing-v2-{}'.format(model)\n",
"filename = 'car-racing-v2-{}-{}.hd5'.format(model, checkpoint)\n",
"\n",
"# Copy the checkpoint locally. check=True stops the cell if gsutil fails,\n",
"# instead of letting load_model trip over a missing file afterwards.\n",
"subprocess.run(['gsutil', 'cp', '{}/{}/{}'.format(BUCKET, model, filename), './'],\n",
"               check=True)\n",
"print('load model {}'.format(filename))\n",
"try:\n",
"    q_value.model = models.load_model(filename)\n",
"finally:\n",
"    # The local copy is only needed while loading; remove it even on failure.\n",
"    os.remove(filename)\n",
"\n",
"env = gym.make(\"CarRacing-v2\", continuous=False)\n",
"frames = []\n",
"total_r = 0\n",
"c = 0\n",
"\n",
"try:\n",
"    o0 = env.reset()\n",
"    # The first input pairs the initial observation with a copy of itself.\n",
"    o1 = copy.deepcopy(o0)\n",
"    done = False\n",
"\n",
"    while not done:\n",
"        # Act greedily on the two most recent observations joined together.\n",
"        a, _ = q_value.get_action(join_frames(o0, o1))\n",
"        o_new, r, done, _ = env.step(a)\n",
"        total_r += r\n",
"        o0, o1 = o1, o_new\n",
"        c += 1\n",
"        # NOTE: passing the mode to render() is what triggers the\n",
"        # DeprecationWarning captured in this cell's output.\n",
"        frames.append(env.render('rgb_array'))\n",
"        # Progress line (last action, running reward) every 30 steps.\n",
"        if c % 30 == 0:\n",
"            print(a, total_r)\n",
"finally:\n",
"    # Release the environment even if the rollout is interrupted.\n",
"    env.close()\n",
"\n",
"now = datetime.datetime.now()\n",
"gif_name = 'car-racing-v2-{}-{}-{}-{}.gif'.format(\n",
"    model, int(total_r), checkpoint, now.strftime('%Y%m%d-%H%M%S'))\n",
"imageio.mimsave(gif_name, frames, 'GIF', duration=1.0 / 30.0)"
]
}
],
"metadata": {
"environment": {
"kernel": "python3",
"name": "tf2-gpu.2-8.m94",
"type": "gcloud",
"uri": "gcr.io/deeplearning-platform-release/tf2-gpu.2-8:m94"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment