{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Reinforcement Learning For Self Driving Cars"
]
},
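{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell below is a DQN-style training loop: the agent acts epsilon-greedily in a pygame simulation, stores `(s, a, r, s')` transitions in an experience replay buffer, and, once the buffer is full, fits the network on random minibatches. The back-prop target for each sampled transition is the standard Q-learning target,\n",
"\n",
"$$y = \\begin{cases} r & \\text{if } s' \\text{ is terminal} \\\\\\\\ r + \\gamma \\max_{a'} Q(s', a') & \\text{otherwise,} \\end{cases}$$\n",
"\n",
"written into the network's current output vector at the index of the action actually taken."
]
},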
{
"cell_type": "markdown",
"metadata": {},
"source": [
"At the moment this code does not work because I'm very new to jupyter notebooks and I haven't figured out how to import classes."
]
},
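{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of one way to make the supporting classes importable from the notebook, assuming they live in .py files next to the notebook. The module names below (`constants`, `dqn`, `simulation`) are placeholders for illustration, not the project's actual file layout.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import sys\n",
"\n",
"# Make sibling .py files importable from the notebook's working directory\n",
"sys.path.append('.')\n",
"\n",
"# Hypothetical module names -- substitute the project's real files here\n",
"# import constants as CONST\n",
"# from dqn import DQNN\n",
"# from simulation import Agent, State, initSimulation\n"
]
},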
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"leave_program = False\n",
"total_frames = 0\n",
"epochs = 50000\n",
"epoch_cnt = 0\n",
"gamma = 0.9\n",
"epsilon = 1\n",
"batch_size = 30\n",
"buffer = 30000\n",
"replay = []\n",
"h = 0\n",
"reward = 0\n",
"\n",
"for i in range(epochs):\n",
" \n",
" initSimulation(agent, state, filling_buffer = True if len(replay) < buffer else False)\n",
" collision_detected = False\n",
" frames_this_epoch = 0\n",
"\n",
" while not collision_detected:\n",
" # concatonating a string that I print at the bottom of the loop\n",
" __console_string = \"\"\n",
" __console_string += \"FRAME: {0} -- \".format(frames_this_epoch)\n",
" \n",
" ##### PYGAME HOUSE KEEPING #####\n",
" # Keep loop time constant\n",
" clock.tick(CONST.SCREEN_FPS)\n",
" screen.fill(CONST.COLOR_BLACK)\n",
"\n",
" # Returns quality estimates for all posiable actions, copies to state_0\n",
" qMatrix = dqnn.getQMat(state.state)\n",
" state_0 = copy.deepcopy(state.state)\n",
" \n",
" ##### SELECT ACTION #####\n",
" # Select random action or use best action from qMatrix\n",
" action_idx = 0\n",
" if (random.random() < epsilon):\n",
" action_idx = random.randint(0,len(CONST.ACTION_AND_COSTS)-1)\n",
" __console_string += \"random action: {0} -- \".format(CONST.ACTION_NAMES[action_idx])\n",
" else:\n",
" action_idx = np.argmax(qMatrix)\n",
" __console_string += \"selected action: {0} -- \".format(CONST.ACTION_NAMES[action_idx])\n",
"\n",
" ##### Take action #####\n",
" agent.updateAction(action_idx) # Apply action selected above\n",
" all_sprites.update() \n",
" __console_string += \"speed: {0} -- \".format(agent.speed)\n",
" \n",
" # Check for agent obstacle collisions\n",
" collisions = pygame.sprite.spritecollide(agent, obstacles, False) \n",
" \n",
"\n",
" ##### Observe new state (s') #####\n",
" agent.updateSensors(obstacles) # Sensor update\n",
" state.update(agent.sensor_data) # Update state with new data\n",
"\n",
" ##### GET maxQ' from DQN #####\n",
" next_qMatrix = dqnn.getQMat(state.state)\n",
"\n",
" # Get reward from agent\n",
" reward = agent.reward\n",
"\n",
" if (collisions or agent.out_of_bounds):\n",
" collision_detected = True\n",
" reward = CONST.REWARDS['terminal_crash']\n",
" print(\"terminal_crash*********************************************************************************\")\n",
" \n",
" if agent.isAtGoal():\n",
" collision_detected = True\n",
" reward = CONST.REWARDS['terminal_goal']\n",
" agent.terminal = True\n",
" print(\"terminal_goal!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\")\n",
" \n",
" if frames_this_epoch > CONST.TAKING_TOO_LONG:\n",
" collision_detected = True\n",
" print(\"TAKING TOO LONG :-( :-( :-( :-( :-( :-( :-( :-( :-( :-( :-( :-( :-( :-( :-( :-( :-( :-( :-( :-( \")\n",
" __console_string += \"Reward: {0} -- Epsilon: {1} -- Epoch: {2} -- Total_Frames: {3}\".format(reward, epsilon, epoch_cnt, total_frames)\n",
" \n",
"\t\t# if the buffer is not full, keep filling. Else, overwrite oldest element begin learning\n",
" if len(replay) < buffer:\n",
" replay.append((copy.deepcopy(state_0), copy.deepcopy(action_idx), copy.deepcopy(reward), copy.deepcopy(state.state)))\n",
" \n",
" if log_data:\n",
" # Data Logging \n",
" states0.append(copy.deepcopy(state_0))\n",
" actions.append(copy.deepcopy(action_idx))\n",
" rewards.append(copy.deepcopy(reward))\n",
" states1.append(copy.deepcopy(state.state))\n",
"\n",
" epoch_cnt = 0# keep epochs at zero until buffer if full\n",
" else:\n",
" h = (h+1)%buffer\n",
" replay.append((copy.deepcopy(state_0), copy.deepcopy(action_idx), copy.deepcopy(reward), copy.deepcopy(state.state)))\n",
" \n",
" batch = random.sample(replay, batch_size)\n",
" target_batch = []\n",
" for element in batch:\n",
" # Breakup tuple into more readable elements\n",
"\t\t\t\treplay_state, replay_action_idx, replay_reward, replay_new_state = element\n",
"\t\t\t\t\n",
"\t\t\t\t# First feed forward pass to get qualities of each action in initial state\n",
" q_mat = dqnn.getQMat(replay_state)\n",
"\t\t\t\t\n",
"\t\t\t\t# this will be used as the 'target' for back-prop \n",
" y = copy.deepcopy(q_mat[0])\n",
"\n",
"\t\t\t\t# Second feed forward pass to get qualities of each action in new state\n",
" q_mat_new = dqnn.getQMat(replay_new_state)[0]\n",
" q_val_new = max(q_mat_new)\n",
"\n",
"\t\t\t\t# Negative 1 rewards indicate a terminal state (collision)\n",
" if (replay_reward == -1):\n",
" q_update = replay_reward\n",
" else:\n",
" q_update = replay_reward + (gamma*q_val_new)\n",
" \n",
"\t\t\t\t# Overwrite index of action taken with q_update\n",
"\t\t\t\ty[replay_action_idx] = q_update\n",
" target_batch.append(y)\n",
"\n",
"\t\t\t# writing various things to the screen for debugging \n",
" if total_frames % 100 == 0:\n",
" dqnn.fitBatch([row[0] for row in batch], target_batch, save=False, verbose=True, iteration_count=total_frames-buffer)\n",
" elif total_frames % 10001 == 0:\n",
" dqnn.fitBatch([row[0] for row in batch], target_batch, save=True, verbose=False, iteration_count=total_frames-buffer)\n",
" else:\n",
" dqnn.fitBatch([row[0] for row in batch], target_batch)\n",
"\n",
"\t\t\t# Decreasing epsilon\n",
" if epsilon > 0.1:\n",
" epsilon -= 1/100000\n",
"\n",
"\n",
"##### MORE PYGAME HOUSE KEEPING #####\n",
"# Respawn obstacles if they are - CONST. out of range\n",
" if moving_obstacles:\n",
" for obs in obstacles:\n",
" if obs.out_of_range:\n",
" obs.reInitObs(0, CONST.LANES[random.rand(CONST.CAR_LANE_MIN,CONST.CAR_LANE_MAX)], obstacles)\n",
" #print(\"Reload\")\n",
" \n",
"# Check if agent is out of bounds\n",
" if agent.rect.x > CONST.SCREEN_WIDTH + CONST.SCREEN_PADDING:\n",
" collision_detected = True\n",
"\n",
"# Draw / render\n",
" all_sprites.draw(screen)\n",
"### Drawing lane markers\n",
" center_guard = CONST.LANES[3] + CONST.LANE_WIDTH//2\n",
" color = CONST.COLOR_ORANGE\n",
" for lane in CONST.LANES:\n",
" pygame.draw.line(screen, color, (0, lane-CONST.LANE_WIDTH//2), (CONST.SCREEN_WIDTH, lane-CONST.LANE_WIDTH//2))\n",
" color = CONST.COLOR_WHITE\n",
" pygame.draw.line(screen, CONST.COLOR_ORANGE, (0, CONST.LANES[len(CONST.LANES)-1] + CONST.LANE_WIDTH//2), (CONST.SCREEN_WIDTH, CONST.LANES[len(CONST.LANES)-1] + CONST.LANE_WIDTH//2))\n",
"\n",
"## Draw carrot (what the PID follows track lanes)\n",
" pygame.draw.circle(screen, CONST.COLOR_ORANGE, (agent.carrot), 5)\n",
" pygame.draw.circle(screen, CONST.COLOR_ORANGE, (300, int(CONST.LANES[3] + CONST.LANE_WIDTH//2)), 4)\n",
"#\n",
"## Draw most recent LiD\n",
" for beam in agent.lidar.beams:\n",
" pygame.draw.line(screen, beam.color, (beam.x1, beam.y1), (agent.rect.centerx, agent.rect.centery))\n",
"\n",
"# Plot lidar data in console if state.setLivePlot\n",
" if state.setLivePlot:\n",
" print(\"PLOTTING\")\n",
" state.plotState(True)\n",
" \n",
" \n",
"##### For if I decide to put some user input fuctionality #####\n",
"# Process input (events)\n",
" for event in pygame.event.get():\n",
"# Check for closing window\n",
" if event.type == pygame.QUIT:\n",
" collision_detected = True\n",
" leave_program = True\n",
" dqnn.session.close()\n",
" if event.type == pygame.KEYDOWN:\n",
" if event.key == pygame.K_p:\n",
" state.setLivePlot = not state.setLivePlot #toggle live plotting\n",
" action_idx = 1\n",
" if event.key == pygame.K_UP:\n",
" __console_data_print_frequency += 1\n",
" print(\"Print Frequ every {0} frames\".format(__console_data_print_frequency))\n",
" if event.key == pygame.K_DOWN:\n",
" __console_data_print_frequency -= 1\n",
" print(\"Print Frequ every {0} frames\".format(__console_data_print_frequency))\n",
" if event.key == pygame.K_RIGHT:\n",
" __console_data_print_frequency += 10\n",
" print(\"Print Frequ every {0} frames\".format(__console_data_print_frequency))\n",
" if event.key == pygame.K_LEFT:\n",
" __console_data_print_frequency -= 10\n",
" print(\"Print Frequ every {0} frames\".format(__console_data_print_frequency))\n",
" if event.key == pygame.K_q:\n",
" epsilon += 0.05\n",
" if epsilon > 1: epsilon = 1\n",
" print(\"Epsilon now: {0}\".format(epsilon))\n",
" if event.key == pygame.K_a:\n",
" epsilon -= 0.05\n",
" if epsilon < 0.1: epsilon = 0.1\n",
" print(\"Epsilon now: {0}\".format(epsilon))\n",
" \n",
" if __console_data_print_frequency <= 0: __console_data_print_frequency = 1\n",
" \n",
" if total_frames % __console_data_print_frequency == 0:\n",
" print(__console_string, os.linesep)\n",
" print(\"q_matrix: {0} -- \".format(qMatrix))\n",
" print(\"_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _\", os.linesep)\n",
" frames_this_epoch += 1\n",
" total_frames += 1\n",
"# After everything, flip display\n",
" pygame.display.flip()\n",
"\n",
" epoch_cnt += 1\n",
" if epoch_cnt == epochs-1: epoch_cnt = epochs - 2 \n",
" if leave_program: break\n",
" if log_data:\n",
" log.logData(fileNames, toLog)\n",
" # Data Logging \n",
" states0.clear()\n",
" actions.clear()\n",
" rewards.clear()\n",
" states1.clear()\n",
"\n",
"dqnn.session.close()\n",
"pygame.quit();"
]
}
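,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Once the replay buffer is full, `fitBatch` runs once per frame, with verbose logging every 100 frames and a model save every 10001 frames (10001 presumably so that saves never coincide with the verbose branch, since the `% 100` check runs first)."
]
}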
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda root]",
"language": "python",
"name": "conda-root-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 1
}