@naotokui
Created January 25, 2017 12:36
DQN Turntable Control for AI-DJ
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Q-Learning in Keras\n",
"\n",
"based on [Keras plays catch](https://gist.github.com/EderSantana/c7222daa328f0e885093)"
]
},
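{
"cell_type": "markdown",
"metadata": {},
"source": [
"The replay buffer below builds its training targets from the standard Q-learning update: for a remembered transition $(s_t, a_t, r_t, s_{t+1})$ the target for the action actually taken is\n",
"\n",
"$$\\mathrm{target}(s_t, a_t) = r_t + \\gamma \\, \\max_{a'} Q(s_{t+1}, a'),$$\n",
"\n",
"while the targets for all other actions are left at the network's current predictions, so only the taken action contributes to the error."
]
},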
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
]
}
],
"source": [
"import json\n",
"import numpy as np\n",
"from keras.models import Sequential\n",
"from keras.layers.core import Dense\n",
"from keras.optimizers import sgd\n",
"\n",
"\n",
"# Store and Retrieve Experiences for Replay\n",
"class ExperienceReplay(object):\n",
" def __init__(self, max_memory=10000, discount=.9):\n",
" self.max_memory = max_memory # max number of experiences\n",
" self.memory = list() \n",
" self.discount = discount # q-learning discout value\n",
"\n",
" def remember(self, states):\n",
" # memory[i] = [[state_t, action_t, reward_t, state_t+1]]\n",
" self.memory.append([states])\n",
" if len(self.memory) > self.max_memory: # remove the oldest one if the memory list is full\n",
" del self.memory[0]\n",
"\n",
" def get_batch(self, model, batch_size=10):\n",
" len_memory = len(self.memory)\n",
" num_actions = model.output_shape[-1]\n",
" env_dim = self.memory[0][0][0].shape[1]\n",
" inputs = np.zeros((min(len_memory, batch_size), env_dim))\n",
" targets = np.zeros((inputs.shape[0], num_actions))\n",
" for i, idx in enumerate(np.random.randint(0, len_memory,\n",
" size=inputs.shape[0])):\n",
" state_t, action_t, reward_t, state_tp1 = self.memory[idx][0]\n",
"\n",
" inputs[i:i+1] = state_t\n",
" # There should be no target values for actions not taken.\n",
" # Thou shalt not correct actions not taken #deep\n",
" targets[i] = model.predict(state_t)[0]\n",
" Q_sa = np.max(model.predict(state_tp1)[0])\n",
" # reward_t + gamma * max_a' Q(s', a')\n",
" targets[i, action_t] = reward_t + self.discount * Q_sa\n",
" return inputs, targets"
]
},
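{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal smoke test for `ExperienceReplay` (not part of the original pipeline): remember a few random transitions against a toy model and draw one training batch. The dimensions, model, and variable names here are placeholders for illustration only."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Illustrative smoke test only -- toy_obs_dim, toy_actions and toy_model are made up\n",
"toy_obs_dim, toy_actions = 4, 3\n",
"toy_model = Sequential()\n",
"toy_model.add(Dense(toy_actions, input_dim=toy_obs_dim))\n",
"toy_model.compile(sgd(lr=.1), \"mse\")\n",
"\n",
"toy_replay = ExperienceReplay(max_memory=100)\n",
"for _ in range(20):\n",
"    s_t = np.random.rand(1, toy_obs_dim)    # state before the action\n",
"    s_tp1 = np.random.rand(1, toy_obs_dim)  # state after the action\n",
"    toy_replay.remember([s_t, np.random.randint(toy_actions), np.random.rand(), s_tp1])\n",
"\n",
"batch_inputs, batch_targets = toy_replay.get_batch(toy_model, batch_size=8)\n",
"print(batch_inputs.shape)   # (8, 4)\n",
"print(batch_targets.shape)  # (8, 3)"
]
},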
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# parameters\n",
"epsilon = .1 # exploration\n",
"epoch = 1000\n",
"max_memory = 500\n",
"\n",
"# If you want to continue training from a previous model, just uncomment the line bellow\n",
"# model.load_weights(\"model.h5\")\n",
"\n",
"# Define environment/game\n",
"# env = Catch(grid_size)\n",
"\n",
"# Initialize experience replay object\n",
"exp_replay = ExperienceReplay(max_memory=max_memory)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import OSC\n",
"import threading, time\n",
"import zmq\n",
"import sys\n",
"import numpy as np\n",
"from IPython.display import clear_output\n",
"\n",
"# Turntable Class\n",
"class TurntableSimulation(object):\n",
" \n",
" def __init__(self, observation_size, observation_shape, action_num):\n",
"\n",
" self.reward = 0\n",
" self.collected_rewards = []\n",
" self.tempo_match_interval = 0\n",
" \n",
" self.observation_size = observation_size # size of state array\n",
" self.observation_shape = observation_shape\n",
" self.num_actions = action_num # number of possible actions\n",
" self.observation = np.array([])\n",
"\n",
" # OSC Setup - python to oF\n",
" self.oscclient = OSC.OSCClient()\n",
" self.oscclient.connect(('localhost', 7710)) # client\n",
" self.oscserver = OSC.ThreadingOSCServer((\"localhost\", 7711)) # server\n",
" self.oscserver.addMsgHandler( \"/experience\", self.experience_osc_callback )\n",
" self.oscserver.addMsgHandler( \"/status\", self.status_osc_callback )\n",
" self.oscserver.addMsgHandler( \"/beatmap\", self.beatmap_osc_callback )\n",
" self.oscserver.addMsgHandler( \"/beatdist\", self.beatdist_osc_callback )\n",
" self.oscserver.addMsgHandler( \"/reward\", self.reward_osc_callback )\n",
" self.oscserver.addMsgHandler( \"/tempo_match_interval\", self.tempomatchinterval_osc_callback )\n",
" self.st = threading.Thread( target = self.oscserver.serve_forever )\n",
" self.st.start()\n",
" \n",
" # send data to visualization app via ZMQ\n",
" context = zmq.Context()\n",
" self.zmqsock = context.socket(zmq.PUB)\n",
" self.zmqsock.bind(\"tcp://*:8506\")\n",
"\n",
" def experience_osc_callback(self, path, tags, args, source):\n",
" assert len(args) == self.observation_size * 2 + 2\n",
" \n",
" pstates = np.array(args[:self.observation_size]).reshape((1, self.observation_size,))\n",
" paction = int(args[self.observation_size])\n",
" reward = args[self.observation_size + 1] \n",
" states = np.array(args[self.observation_size + 2 :]).reshape((1, self.observation_size,))\n",
" \n",
" exp_replay.remember([pstates, paction, reward, states])\n",
" \n",
" # send back a ACK\n",
" msg = OSC.OSCMessage(\"/ack\")\n",
" self.oscclient.send(msg)\n",
"\n",
"\n",
" \n",
" def status_osc_callback(self, path, tags, args, source):\n",
" assert len(args) == self.observation_size\n",
" self.observation = np.array(args).reshape(self.observation_shape)\n",
"# print self.observation \n",
"\n",
" \n",
" # to oF app for visualization \n",
" def beatmap_osc_callback(self, path, tags, args, source):\n",
" self.zmqsock.send_json({\"beatmap\": args}) \n",
" def beatdist_osc_callback(self, path, tags, args, source):\n",
" self.zmqsock.send_json({\"beatdist\": args})\n",
"\n",
" # to visualization / tempo matched flag\n",
" def tempomatchinterval_osc_callback(self, path, tags, args, source):\n",
" self.tempo_match_interval = args[0]\n",
" print \"reward %.3f\" % self.tempo_match_interval\n",
" \n",
" def reward_osc_callback(self, path, tags, args, source):\n",
" assert len(args) == 1\n",
" print \"reward %.3f\" % args[0] \n",
" self.reward = args[0]\n",
" \n",
" # python to oF\n",
" def send_osc(self, path, values):\n",
" oscmsg = OSC.OSCMessage()\n",
" oscmsg.setAddress(path)\n",
" if type(values) == list:\n",
" for v in values:\n",
" oscmsg.append(v)\n",
" else:\n",
" oscmsg.append(values)\n",
" try:\n",
" self.oscclient.send(oscmsg)\n",
" except:\n",
" print \"osc exception\"\n",
" \n",
" def send_action(self, action_id):\n",
" self.last_action = action_id\n",
" self.send_osc(\"/action\", action_id)\n",
"\n",
" def send_action_scores(self, scores):\n",
" self.send_osc(\"/action_scores\", scores)"
]
},
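{
"cell_type": "markdown",
"metadata": {},
"source": [
"As `experience_osc_callback` above shows, the oF app reports each experience as a single `/experience` OSC message laid out as `[prev_state (N floats), action (int), reward (float), new_state (N floats)]`, where `N` is `observation_size`; the learner stores it in `exp_replay` and answers with `/ack`. The cell below is a hypothetical test sender (not part of the gist) that fakes one such message with random states, assuming the gist's port 7711 and `observation_size = 11` from the next cell."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Hypothetical test sender: emits one /experience message in the expected layout.\n",
"N = 11  # must match the learner's observation_size\n",
"test_client = OSC.OSCClient()\n",
"test_client.connect(('localhost', 7711))  # the learner's OSC server port\n",
"\n",
"msg = OSC.OSCMessage(\"/experience\")\n",
"for v in np.random.rand(N):  # previous state\n",
"    msg.append(float(v))\n",
"msg.append(1)                # action index\n",
"msg.append(0.5)              # reward\n",
"for v in np.random.rand(N):  # new state\n",
"    msg.append(float(v))\n",
"test_client.send(msg)"
]
},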
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# try:\n",
"# if turntable.oscserver is not None:\n",
"# turntable.oscserver.close()\n",
"# turntable.oscserver = None\n",
"# except:\n",
"# None\n",
"\n",
"num_actions = 3\n",
"turntable = TurntableSimulation(10+1, (10 + 1, ), num_actions)\n",
"\n",
"hidden_size = 100\n",
"batch_size = 50\n",
"\n",
"num_actions = turntable.num_actions # [move_left, stay, move_right]\n",
"\n",
"model = Sequential()\n",
"model.add(Dense(hidden_size, input_dim=turntable.observation_size, activation='relu'))\n",
"model.add(Dense(hidden_size, activation='relu'))\n",
"model.add(Dense(num_actions, activation='softmax'))\n",
"model.compile(sgd(lr=.2), \"mse\")\n",
"\n"
]
},
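{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check (illustrative, not required for training): ask the network for its action scores on the current observation, the same call the training loop below makes. It assumes the oF app has already sent at least one `/status` message so `turntable.observation` is populated."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Illustrative: Q-values / action scores for the current turntable state\n",
"sample_input = turntable.observation.reshape((1, turntable.observation_size,))\n",
"q_values = model.predict(sample_input)[0]\n",
"print(q_values)             # one score per action: [move_left, stay, move_right]\n",
"print(np.argmax(q_values))  # greedy action index"
]
},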
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 000/999 | Loss 104.3554\n",
"Epoch 001/999 | Loss 103.6451\n",
"Epoch 002/999 | Loss 116.3734\n",
"Epoch 003/999 | Loss 106.6050\n",
"Epoch 004/999 | Loss 115.5968\n",
"Epoch 005/999 | Loss 125.6049\n",
"Epoch 006/999 | Loss 126.7171\n",
"Epoch 007/999 | Loss 129.7448"
]
}
],
"source": [
"\n",
"# Train\n",
"for e in range(epoch):\n",
" loss = 0.\n",
" input_tm1 = None\n",
"\n",
" for i in range(1000):\n",
" \n",
" # get initial input\n",
" input_t = turntable.observation.reshape((1, turntable.observation_size,)) # env.observe()\n",
"\n",
" if not np.array_equal(input_t, input_tm1):\n",
"\n",
" input_tm1 = input_t\n",
" # get next action\n",
" if np.random.rand() <= epsilon:\n",
" action = np.random.randint(0, num_actions, size=1)[0]\n",
" else:\n",
" q = model.predict(input_tm1)\n",
" action = np.argmax(q[0])\n",
" turntable.send_action_scores(q[0])\n",
" turntable.send_action(action) # send action back to turntable\n",
" inputs, targets = exp_replay.get_batch(model, batch_size=batch_size) \n",
" loss += model.train_on_batch(inputs, targets)\n",
" print(\"Epoch {:03d}/999 | Loss {:.4f}\".format(e, loss))\n",
"\n",
"# Save trained model weights and architecture, this will be used by the visualization code\n",
"model.save_weights(\"model.h5\", overwrite=True)\n",
"with open(\"model.json\", \"w\") as outfile:\n",
" json.dump(model.to_json(), outfile)"
]
},
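{
"cell_type": "markdown",
"metadata": {},
"source": [
"A sketch of how the saved files can be reloaded elsewhere (e.g. in the visualization code): `model.json` was written with `json.dump`, so it has to be read back with `json.load` before `model_from_json`, and the weights come from `model.h5`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Sketch: reload the trained model from the files saved above\n",
"from keras.models import model_from_json\n",
"\n",
"with open(\"model.json\") as infile:\n",
"    loaded_model = model_from_json(json.load(infile))  # json.load undoes the json.dump above\n",
"loaded_model.load_weights(\"model.h5\")\n",
"loaded_model.compile(sgd(lr=.2), \"mse\")"
]
},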
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.10"
}
},
"nbformat": 4,
"nbformat_minor": 0
}