Last active
November 11, 2016 17:56
-
-
Save taotao54321/a72311839e6bd5835250b73b5c7ad2b3 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
import gym | |
class Ram:
    """Address-based reader over a RAM snapshot.

    ``state`` is an indexable sequence of byte values in which index 0
    corresponds to address ``BASE_ADDR`` (0x80); every address handed to the
    readers is shifted down by that base before indexing.
    """

    # Address that maps to state[0].
    BASE_ADDR = 0x80

    def __init__(self, state):
        """Wrap *state*; the sequence is referenced, not copied."""
        self.state = state

    def read_u8(self, addr):
        """Return the byte stored at *addr* as a plain Python ``int``.

        Raises:
            IndexError: if *addr* falls outside the range covered by
                ``state``.
        """
        offset = addr - self.BASE_ADDR
        # A negative offset would not fail by itself: sequence indexing
        # wraps around and would silently return a byte from the END of the
        # state.  Reject it explicitly instead of returning wrong data.
        if not 0 <= offset < len(self.state):
            raise IndexError(
                "RAM address 0x{:02X} is out of range (0x{:02X}-0x{:02X})".format(
                    addr, self.BASE_ADDR, self.BASE_ADDR + len(self.state) - 1
                )
            )
        # The elements of state are numpy.uint8 values (per the original
        # author's note), which do not always behave like int.  For example
        # -ram.read_u8(addr) would not give the expected result if a uint8
        # were returned, so convert to int before returning.
        return int(self.state[offset])

    def read_s8(self, addr):
        """Return the byte at *addr* reinterpreted as a signed 8-bit value.

        Values 0x00-0x7F are returned unchanged; 0x80-0xFF map to -128..-1.
        """
        value = self.read_u8(addr)
        return value if value < 0x80 else value - 0x100
def get_action(num_actions=6, default=0):
    """Prompt on standard input until the user enters a valid action.

    Args:
        num_actions: Number of selectable actions; accepted inputs are the
            integers ``0`` .. ``num_actions - 1``.  The default of 6 keeps
            the original fixed 0-5 range.
        default: Action returned when the user submits an empty line.

    Returns:
        The chosen action as an ``int``.

    Invalid input (non-numeric text or a number outside the accepted range)
    prints the error message and asks again.
    """
    max_action = num_actions - 1
    prompt = "--- (0-{}) [{}] > ".format(max_action, default)
    while True:
        try:
            s = input(prompt)
            if not s:
                return default
            act = int(s)
            if not 0 <= act <= max_action:
                # Reuse the ValueError path so range errors and parse errors
                # are reported and retried the same way.
                raise ValueError("act must be 0-{}".format(max_action))
            return act
        except ValueError as e:
            print(e)
def dump_ram(state):
    """Print a hex dump of *state* followed by a few decoded values.

    The dump covers addresses 0x80-0xFF, sixteen bytes per row, read through
    ``Ram``.  After the table, the bytes labelled "hero y", "ball pos" and
    "ball vel" are printed from fixed addresses.

    Args:
        state: RAM snapshot accepted by ``Ram``.
    """
    ram = Ram(state)

    # The header is padded to five columns so its "|" lines up with the "+"
    # of the separator and with the "0x80 | " prefix of every data row.
    print("     | 00 01 02 03 04 05 06 07 08 09 0A 0B 0C 0D 0E 0F")
    print("-----+------------------------------------------------")
    for addr in range(0x80, 0x100, 0x10):
        row = " ".join("{:02X}".format(ram.read_u8(addr + i)) for i in range(0x10))
        print("0x{:02X} | {}".format(addr, row))
    print()

    print("hero y: {}".format(ram.read_u8(0xBC)))
    print("ball pos: ({}, {})".format(ram.read_u8(0xB1), ram.read_u8(0xB6)))
    # Velocities are read as signed bytes.  The x component is printed
    # negated -- NOTE(review): presumably a sign-convention choice; the
    # reason is not visible in this file, confirm before relying on it.
    print("ball vel: ({}, {})".format(-ram.read_s8(0xBA), ram.read_s8(0xB8)))
    print()
def main():
    """Run an interactive, step-by-step session of ``Pong-ram-v0``.

    Prints the environment's spaces and limits, then loops over episodes
    forever: each step dumps the current RAM, asks the user for an action,
    applies it, renders, and prints the reward and info.  An episode ends
    when the environment reports ``done`` or the step limit is reached.

    The loop has no normal exit; it is left only by an exception (for
    example Ctrl-C or end-of-input at the prompt), so the environment is
    closed in a ``finally`` clause to release its resources on the way out.
    """
    env = gym.make("Pong-ram-v0")
    try:
        step_limit = env.spec.timestep_limit

        print("env.action_space: {}".format(env.action_space))
        print("env.observation_space: {}".format(env.observation_space))
        print("env.reward_range: {}".format(env.reward_range))
        print("env.spec.timestep_limit: {}".format(step_limit))
        print()

        while True:
            state = env.reset()
            env.render()
            for t in range(step_limit):
                print("step {}".format(t))
                dump_ram(state)
                state, reward, done, info = env.step(get_action())
                env.render()
                print("Reward: {}".format(reward))
                print("Info: {}".format(info))
                print()
                if done:
                    break
            print("===== D O N E =====")
            print()
    finally:
        env.close()
# Start the interactive session only when this file is executed as a script.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment