Created
January 25, 2017 12:36
-
-
Save naotokui/425928251c163528b3716779319fc95e to your computer and use it in GitHub Desktop.
DQN Turntable Control for AI-DJ
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
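"# Q-Learning in Keras\n", | |
"\n", | |
"A deep Q-network (DQN) that chooses turntable actions for an AI DJ. Observations, rewards and experiences arrive over OSC from an openFrameworks (oF) app, and the chosen actions are sent back to it over OSC.\n", | |
"\n", | |
"Based on [Keras plays catch](https://gist.github.com/EderSantana/c7222daa328f0e885093)." | |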
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Using TensorFlow backend.\n" | |
] | |
} | |
], | |
"source": [ | |
"import json\n", | |
"import numpy as np\n", | |
"from keras.models import Sequential\n", | |
"from keras.layers.core import Dense\n", | |
"from keras.optimizers import sgd\n", | |
"\n", | |
"\n", | |
"# Store and Retrieve Experiences for Replay\n", | |
"class ExperienceReplay(object):\n", | |
"    \"\"\"Bounded FIFO buffer of transitions used to build Q-learning training batches.\"\"\"\n", | |
"\n", | |
"    def __init__(self, max_memory=10000, discount=.9):\n", | |
"        self.max_memory = max_memory  # max number of experiences kept\n", | |
"        self.memory = list()\n", | |
"        self.discount = discount  # q-learning discount factor (gamma)\n", | |
"\n", | |
"    def remember(self, states):\n", | |
"        \"\"\"Store one transition given as [state_t, action_t, reward_t, state_t+1].\n", | |
"\n", | |
"        get_batch expects state_t and state_t+1 to be (1, env_dim) arrays.\n", | |
"        The oldest transition is dropped once more than max_memory are stored.\n", | |
"        \"\"\"\n", | |
"        # memory[i] = [[state_t, action_t, reward_t, state_t+1]]\n", | |
"        self.memory.append([states])\n", | |
"        if len(self.memory) > self.max_memory:  # remove the oldest one if the memory list is full\n", | |
"            del self.memory[0]\n", | |
"\n", | |
"    def get_batch(self, model, batch_size=10):\n", | |
"        \"\"\"Sample up to batch_size stored transitions and build (inputs, targets).\n", | |
"\n", | |
"        targets start as the model's current predictions for each sampled state,\n", | |
"        so actions not taken contribute zero error; the entry of the action taken\n", | |
"        is replaced by reward_t + discount * max_a' Q(state_t+1, a').\n", | |
"\n", | |
"        Raises:\n", | |
"            ValueError: if no experience has been stored yet.\n", | |
"        \"\"\"\n", | |
"        len_memory = len(self.memory)\n", | |
"        if len_memory == 0:\n", | |
"            raise ValueError(\"ExperienceReplay.get_batch: no experiences stored yet\")\n", | |
"        num_actions = model.output_shape[-1]\n", | |
"        env_dim = self.memory[0][0][0].shape[1]\n", | |
"        n_samples = min(len_memory, batch_size)\n", | |
"        samples = [self.memory[idx][0]\n", | |
"                   for idx in np.random.randint(0, len_memory, size=n_samples)]\n", | |
"\n", | |
"        inputs = np.zeros((n_samples, env_dim))\n", | |
"        next_states = np.zeros((n_samples, env_dim))\n", | |
"        for i, (state_t, _, _, state_tp1) in enumerate(samples):\n", | |
"            inputs[i:i+1] = state_t\n", | |
"            next_states[i:i+1] = state_tp1\n", | |
"\n", | |
"        # One batched predict per side instead of two predict() calls per sample.\n", | |
"        targets = np.zeros((n_samples, num_actions))\n", | |
"        targets[:] = model.predict(inputs)\n", | |
"        Q_sa = np.max(model.predict(next_states), axis=1)\n", | |
"        for i, (_, action_t, reward_t, _) in enumerate(samples):\n", | |
"            # reward_t + gamma * max_a' Q(s', a')\n", | |
"            targets[i, action_t] = reward_t + self.discount * Q_sa[i]\n", | |
"        return inputs, targets" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# --- Hyper-parameters ---\n", | |
"epsilon = 0.1     # exploration rate: probability of taking a random action\n", | |
"epoch = 1000      # number of training epochs (outer loop of the Train cell)\n", | |
"max_memory = 500  # capacity of the experience-replay buffer\n", | |
"\n", | |
"# To resume from a previous run, call model.load_weights(\"model.h5\") once the\n", | |
"# model has been built -- it is defined further below, so it does not exist yet here.\n", | |
"\n", | |
"# Replay buffer; filled from the OSC server thread by TurntableSimulation's /experience handler\n", | |
"exp_replay = ExperienceReplay(max_memory=max_memory)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"import OSC\n", | |
"import threading, time\n", | |
"import zmq\n", | |
"import sys\n", | |
"import numpy as np\n", | |
"from IPython.display import clear_output\n", | |
"\n", | |
"# Turntable Class\n", | |
"class TurntableSimulation(object):\n", | |
"    \"\"\"Bridge between the DQN and the oF turntable app.\n", | |
"\n", | |
"    Receives /experience, /status, /reward, /tempo_match_interval, /beatmap and\n", | |
"    /beatdist messages on an OSC server (localhost:7711), sends actions back with\n", | |
"    an OSC client (localhost:7710), and forwards beat data as JSON on a ZMQ PUB\n", | |
"    socket bound to tcp://*:8506 for the visualization app.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, observation_size, observation_shape, action_num):\n", | |
"        self.reward = 0\n", | |
"        self.collected_rewards = []\n", | |
"        self.tempo_match_interval = 0\n", | |
"        self.last_action = None  # updated by send_action\n", | |
"\n", | |
"        self.observation_size = observation_size  # size of state array\n", | |
"        self.observation_shape = observation_shape\n", | |
"        self.num_actions = action_num  # number of possible actions\n", | |
"        self.observation = np.array([])  # stays empty until the first /status message\n", | |
"\n", | |
"        # OSC Setup - python to oF\n", | |
"        self.oscclient = OSC.OSCClient()\n", | |
"        self.oscclient.connect(('localhost', 7710))  # client\n", | |
"        self.oscserver = OSC.ThreadingOSCServer((\"localhost\", 7711))  # server\n", | |
"        self.oscserver.addMsgHandler(\"/experience\", self.experience_osc_callback)\n", | |
"        self.oscserver.addMsgHandler(\"/status\", self.status_osc_callback)\n", | |
"        self.oscserver.addMsgHandler(\"/beatmap\", self.beatmap_osc_callback)\n", | |
"        self.oscserver.addMsgHandler(\"/beatdist\", self.beatdist_osc_callback)\n", | |
"        self.oscserver.addMsgHandler(\"/reward\", self.reward_osc_callback)\n", | |
"        self.oscserver.addMsgHandler(\"/tempo_match_interval\", self.tempomatchinterval_osc_callback)\n", | |
"        self.st = threading.Thread(target=self.oscserver.serve_forever)\n", | |
"        self.st.daemon = True  # do not keep the Python process alive just for the OSC server\n", | |
"        self.st.start()\n", | |
"\n", | |
"        # send data to visualization app via ZMQ\n", | |
"        context = zmq.Context()\n", | |
"        self.zmqsock = context.socket(zmq.PUB)\n", | |
"        self.zmqsock.bind(\"tcp://*:8506\")\n", | |
"\n", | |
"    def close(self):\n", | |
"        \"\"\"Close the OSC server, OSC client and ZMQ socket so their ports can be bound again.\"\"\"\n", | |
"        self.oscserver.close()\n", | |
"        self.oscclient.close()\n", | |
"        self.zmqsock.close()\n", | |
"\n", | |
"    def experience_osc_callback(self, path, tags, args, source):\n", | |
"        \"\"\"Handle /experience: args are [prev_state..., action, reward, state...].\n", | |
"\n", | |
"        Stores the transition in the notebook-global `exp_replay` buffer and\n", | |
"        replies with an /ack message.\n", | |
"        \"\"\"\n", | |
"        n = self.observation_size\n", | |
"        assert len(args) == n * 2 + 2\n", | |
"\n", | |
"        pstates = np.array(args[:n]).reshape((1, n,))\n", | |
"        paction = int(args[n])\n", | |
"        reward = args[n + 1]\n", | |
"        states = np.array(args[n + 2:]).reshape((1, n,))\n", | |
"\n", | |
"        exp_replay.remember([pstates, paction, reward, states])\n", | |
"\n", | |
"        # send back a ACK\n", | |
"        msg = OSC.OSCMessage(\"/ack\")\n", | |
"        self.oscclient.send(msg)\n", | |
"\n", | |
"    def status_osc_callback(self, path, tags, args, source):\n", | |
"        \"\"\"Handle /status: store the latest observation, reshaped to observation_shape.\"\"\"\n", | |
"        assert len(args) == self.observation_size\n", | |
"        self.observation = np.array(args).reshape(self.observation_shape)\n", | |
"\n", | |
"    # to oF app for visualization\n", | |
"    def beatmap_osc_callback(self, path, tags, args, source):\n", | |
"        \"\"\"Forward /beatmap args to the visualization app as {\\\"beatmap\\\": args}.\"\"\"\n", | |
"        self.zmqsock.send_json({\"beatmap\": args})\n", | |
"\n", | |
"    def beatdist_osc_callback(self, path, tags, args, source):\n", | |
"        \"\"\"Forward /beatdist args to the visualization app as {\\\"beatdist\\\": args}.\"\"\"\n", | |
"        self.zmqsock.send_json({\"beatdist\": args})\n", | |
"\n", | |
"    # to visualization / tempo matched flag\n", | |
"    def tempomatchinterval_osc_callback(self, path, tags, args, source):\n", | |
"        \"\"\"Handle /tempo_match_interval: store and log the first argument.\"\"\"\n", | |
"        self.tempo_match_interval = args[0]\n", | |
"        print \"tempo_match_interval %.3f\" % self.tempo_match_interval\n", | |
"\n", | |
"    def reward_osc_callback(self, path, tags, args, source):\n", | |
"        \"\"\"Handle /reward: store and log the single reward value.\"\"\"\n", | |
"        assert len(args) == 1\n", | |
"        print \"reward %.3f\" % args[0]\n", | |
"        self.reward = args[0]\n", | |
"\n", | |
"    # python to oF\n", | |
"    def send_osc(self, path, values):\n", | |
"        \"\"\"Send an OSC message to `path`.\n", | |
"\n", | |
"        `values` may be a scalar or a sequence (list, tuple or numpy array); each\n", | |
"        numpy scalar is converted to the matching built-in Python number first.\n", | |
"        Send failures are printed, not raised, so one dropped message does not\n", | |
"        stop the caller.\n", | |
"        \"\"\"\n", | |
"        oscmsg = OSC.OSCMessage()\n", | |
"        oscmsg.setAddress(path)\n", | |
"        if isinstance(values, (list, tuple, np.ndarray)):\n", | |
"            items = values\n", | |
"        else:\n", | |
"            items = [values]\n", | |
"        for v in items:\n", | |
"            oscmsg.append(v.item() if isinstance(v, np.generic) else v)\n", | |
"        try:\n", | |
"            self.oscclient.send(oscmsg)\n", | |
"        except Exception as exc:  # best-effort send; never swallow KeyboardInterrupt\n", | |
"            print \"osc exception: %s\" % exc\n", | |
"\n", | |
"    def send_action(self, action_id):\n", | |
"        \"\"\"Remember and send the chosen action id on /action.\"\"\"\n", | |
"        self.last_action = action_id\n", | |
"        self.send_osc(\"/action\", action_id)\n", | |
"\n", | |
"    def send_action_scores(self, scores):\n", | |
"        \"\"\"Send per-action Q-value scores on /action_scores.\"\"\"\n", | |
"        self.send_osc(\"/action_scores\", scores)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# Re-running this cell: the previous TurntableSimulation still holds OSC port 7711\n", | |
"# and ZMQ port 8506, so binding them again would fail. Release them first.\n", | |
"if 'turntable' in globals():\n", | |
"    turntable.oscserver.close()\n", | |
"    turntable.zmqsock.close()\n", | |
"\n", | |
"num_actions = 3  # [move_left, stay, move_right]\n", | |
"observation_size = 10 + 1  # length of the /status vector sent by the oF app\n", | |
"turntable = TurntableSimulation(observation_size, (observation_size, ), num_actions)\n", | |
"\n", | |
"hidden_size = 100\n", | |
"batch_size = 50\n", | |
"\n", | |
"model = Sequential()\n", | |
"model.add(Dense(hidden_size, input_dim=turntable.observation_size, activation='relu'))\n", | |
"model.add(Dense(hidden_size, activation='relu'))\n", | |
"# Q-values are unbounded regression targets (reward + discount * max Q), so the\n", | |
"# output layer must be linear; softmax would squash them into a probability simplex.\n", | |
"model.add(Dense(num_actions, activation='linear'))\n", | |
"model.compile(sgd(lr=.2), \"mse\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 000/999 | Loss 104.3554\n", | |
"Epoch 001/999 | Loss 103.6451\n", | |
"Epoch 002/999 | Loss 116.3734\n", | |
"Epoch 003/999 | Loss 106.6050\n", | |
"Epoch 004/999 | Loss 115.5968\n", | |
"Epoch 005/999 | Loss 125.6049\n", | |
"Epoch 006/999 | Loss 126.7171\n", | |
"Epoch 007/999 | Loss 129.7448" | |
] | |
} | |
], | |
"source": [ | |
"# Train: act epsilon-greedily on each new observation and fit on replayed experiences.\n", | |
"poll_interval = 0.001  # seconds to wait when no new observation has arrived (avoids busy-waiting)\n", | |
"\n", | |
"try:\n", | |
"    for e in range(epoch):\n", | |
"        loss = 0.\n", | |
"        input_tm1 = None\n", | |
"\n", | |
"        for i in range(1000):\n", | |
"            # Nothing to act on until the oF app has sent its first /status message.\n", | |
"            if turntable.observation.size == 0:\n", | |
"                time.sleep(poll_interval)\n", | |
"                continue\n", | |
"\n", | |
"            # current observation as a (1, observation_size) batch\n", | |
"            input_t = turntable.observation.reshape((1, turntable.observation_size,))  # env.observe()\n", | |
"\n", | |
"            # Only act when the observation changed since the previous step.\n", | |
"            if np.array_equal(input_t, input_tm1):\n", | |
"                time.sleep(poll_interval)\n", | |
"                continue\n", | |
"            input_tm1 = input_t\n", | |
"\n", | |
"            # Compute Q-values on every step so `q` is always defined and current\n", | |
"            # when the scores are sent, including on exploration steps.\n", | |
"            q = model.predict(input_tm1)\n", | |
"            if np.random.rand() <= epsilon:\n", | |
"                action = np.random.randint(0, num_actions, size=1)[0]\n", | |
"            else:\n", | |
"                action = np.argmax(q[0])\n", | |
"            turntable.send_action_scores(q[0])\n", | |
"            turntable.send_action(action)  # send action back to turntable\n", | |
"\n", | |
"            # Experiences arrive asynchronously over OSC; get_batch needs at least one.\n", | |
"            if len(exp_replay.memory) == 0:\n", | |
"                continue\n", | |
"            inputs, targets = exp_replay.get_batch(model, batch_size=batch_size)\n", | |
"            loss += model.train_on_batch(inputs, targets)\n", | |
"        print(\"Epoch {:03d}/{:03d} | Loss {:.4f}\".format(e, epoch - 1, loss))\n", | |
"except KeyboardInterrupt:\n", | |
"    # Interrupting the kernel stops training; the current weights are still saved below.\n", | |
"    print(\"Training interrupted\")\n", | |
"\n", | |
"# Save trained model weights and architecture, this will be used by the visualization code\n", | |
"model.save_weights(\"model.h5\", overwrite=True)\n", | |
"with open(\"model.json\", \"w\") as outfile:\n", | |
"    json.dump(model.to_json(), outfile)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 2", | |
"language": "python", | |
"name": "python2" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 2 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython2", | |
"version": "2.7.10" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment