Last active: July 29, 2021 15:55
-
-
Save steven-peralta/82de425fa08ec32fe2e2a65919eb8638 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import threading | |
import time | |
import gym | |
from gym import spaces | |
import numpy as np | |
import requests | |
from flatten_json import flatten_json | |
import subprocess | |
import os | |
import psutil | |
# Seconds to wait for each HTTP request to the game's control server
# (passed as ``timeout=`` to ``requests.post``).
timeout = 60

# Discrete action space: action index -> text command POSTed to the game.
# The ORDER is part of the environment's contract (an agent's action index is
# looked up here directly), so it must never be reshuffled.
#   0-4      potion use <slot>                 slots 0-4
#   5-9      potion discard <slot>             slots 0-4
#   10-39    potion use <slot> <target>        target-major: targets 0-5, slots 0-4
#   40-49    play <card>                       cards 1-10
#   50-109   play <card> <target>              target-major: targets 0-5, cards 1-10
#   110-159  choose <option>                   options 0-49
#   160-162  end, proceed, return
# Generating the targeted entries avoids the hand-typed duplicates that used to
# map indices 59/69/79/89/99 to 'play 10 5' instead of 'play 10 0'..'play 10 4'.
commands = (
    ['potion use {}'.format(slot) for slot in range(5)]
    + ['potion discard {}'.format(slot) for slot in range(5)]
    + ['potion use {} {}'.format(slot, target)
       for target in range(6) for slot in range(5)]
    + ['play {}'.format(card) for card in range(1, 11)]
    + ['play {} {}'.format(card, target)
       for target in range(6) for card in range(1, 11)]
    + ['choose {}'.format(option) for option in range(50)]
    + ['end', 'proceed', 'return']
)
def get_java_args(java_path):
    """Build the argv list that launches the game through ModTheSpire.jar.

    :param java_path: directory containing the ``java`` executable, with or
        without a trailing separator, or ``''`` to use a bare ``java`` (which
        ``subprocess.Popen`` resolves via ``PATH``).
    :return: list of command-line arguments for ``subprocess.Popen``.
    """
    # os.path.join accepts the directory with or without a trailing separator;
    # plain concatenation turned '/usr/bin' into '/usr/binjava'.
    java_executable = os.path.join(java_path, 'java')
    return [java_executable, '-Xmx512M', '-Xms128M',
            '-jar', 'ModTheSpire.jar', '--skip-launcher', '--skip-intro',
            '--profile', 'Default']
def launch_process(port, sts_path, java_bin_path, virtualize):
    """Start the game via ModTheSpire in the background and return its PID.

    stdout/stderr of the child are written to ``sts_<port>.log`` and
    ``sts_<port>.err`` in the current working directory. The child receives
    the caller's environment plus ``SURVEYOR_HOSTNAME=0.0.0.0`` and
    ``SURVEYOR_PORT=<port>``.

    :param port: port number exported to the child as ``SURVEYOR_PORT``.
    :param sts_path: appended verbatim to ``os.getcwd()`` to form the child's
        working directory, so it should start with a path separator
        (e.g. ``'/sts/run'``).
    :param java_bin_path: forwarded to ``get_java_args``.
    :param virtualize: if true, run under ``/opt/VirtualGL/bin/vglrun`` with
        ``DISPLAY=:0`` and ``VGL_LOGO=1`` added to the environment.
    :return: PID of the spawned process (the ``vglrun`` wrapper when
        *virtualize* is true). The ``Popen`` object itself is not kept.
    """
    args = get_java_args(java_bin_path)
    env = {
        **os.environ,
        'SURVEYOR_HOSTNAME': '0.0.0.0',
        'SURVEYOR_PORT': str(port),
    }
    if virtualize:
        args = ['/opt/VirtualGL/bin/vglrun', *args]
        env['DISPLAY'] = ':0'
        env['VGL_LOGO'] = '1'

    log_stem = 'sts_' + str(port)
    # The child gets its own duplicates of these descriptors, so the parent's
    # handles can be closed as soon as Popen returns (previously they leaked).
    with open(log_stem + '.log', 'w') as log_file, \
            open(log_stem + '.err', 'w') as err_file:
        process = subprocess.Popen(args,
                                   stdout=log_file,
                                   stderr=err_file,
                                   cwd=os.getcwd() + sts_path,
                                   env=env)
    return process.pid
def flatten_state_json(state):
    """Turn a nested JSON game state into a 1-D ``float32`` NumPy vector.

    The nested structure is collapsed to a single-level dict by the
    ``flatten_json`` package, and its values are read in that dict's iteration
    order. Every value must be convertible to ``float32``, otherwise
    ``np.fromiter`` raises.
    """
    flat_values = flatten_json(state).values()
    return np.fromiter(flat_values, np.float32)
def is_listening(pid, port):
    """Return True if process *pid* has a socket in LISTEN state on 0.0.0.0:*port*.

    Only the process's own sockets are inspected. ``psutil.Process`` raises
    ``psutil.NoSuchProcess`` if *pid* no longer exists.
    """
    process = psutil.Process(pid)
    # psutil 6.0 renamed Process.connections() to net_connections() and
    # deprecated the old name; use the new one when it is available.
    list_connections = getattr(process, 'net_connections', None) or process.connections
    return any(
        conn.status == psutil.CONN_LISTEN
        # laddr can be an empty tuple for unbound sockets; skip those.
        and bool(conn.laddr)
        and conn.laddr.ip == '0.0.0.0'
        and conn.laddr.port == port
        for conn in list_connections()
    )
class SlayTheSpireGym(gym.Env):
    """Gym environment that drives a Slay the Spire process over HTTP.

    Construction launches the game with ``launch_process``, waits until that
    process listens on ``0.0.0.0:<port>``, then POSTs ``/game/reset``. Each
    action index selects a text command from the module-level ``commands``
    list, which is POSTed to ``/game``. The JSON reply supplies the next
    observation (flattened ``gameState``), the reward (change in ``score``) and
    the done flag (``gameState.isGameOver == 1``).
    """

    # Seconds to sleep between "is the server listening yet?" probes at start-up.
    _startup_poll_interval = 0.5

    def __init__(self, sts_path='/sts/run', java_bin_path='', virtualize=True, port=8008):
        # One discrete action per entry of ``commands`` (163 entries).
        self.action_space = spaces.Discrete(len(commands))
        # NOTE(review): assumes the flattened ``gameState`` always has exactly
        # 13164 numeric entries within [-2, 2]; flatten_state_json neither pads
        # nor clips, so this must hold on the server side -- TODO confirm.
        self.observation_space = spaces.Box(-2, 2, (13164,), dtype=np.float32)
        self.sts_pid = launch_process(port, sts_path, java_bin_path, virtualize)
        self.port = port
        self.hostname = 'localhost'
        self.score = 0
        # Sleep between probes instead of re-probing in a tight ``pass`` loop,
        # which kept a CPU core busy for the whole game start-up.
        while not is_listening(self.sts_pid, port):
            time.sleep(self._startup_poll_interval)
        self._reset_game()

    def step(self, action=None):
        """Send one command to the game and return ``(obs, reward, done, info)``.

        :param action: index into ``commands``.
        :return: flattened ``gameState`` vector, score gained since the previous
            step/reset, whether ``gameState.isGameOver == 1``, and an empty dict.
        :raises ValueError: if *action* is None.
        """
        if action is None:
            raise ValueError('step() requires an action index, got None')
        response = self._send_action(commands[action])
        # Reward is the score delta since the previous step (or reset).
        reward = response['score'] - self.score
        self.score = response['score']
        is_done = response['gameState']['isGameOver'] == 1
        observation = flatten_state_json(response['gameState'])
        return observation, reward, is_done, {}

    def reset(self):
        """Start a new run and return its flattened initial ``gameState``."""
        observation = flatten_state_json(self._reset_game()['gameState'])
        self.score = 0
        return observation

    def render(self, mode='human'):
        """No-op: the game renders in its own window."""
        pass

    def _send_action(self, command):
        """POST ``{'command': command}`` to ``/game`` and return the decoded JSON reply.

        The HTTP status code is not checked; a non-JSON body makes ``.json()`` raise.
        """
        return requests.post(self._get_url('/game'), json={'command': command}, timeout=timeout).json()

    def _reset_game(self):
        """POST to ``/game/reset`` and return the decoded JSON reply."""
        return requests.post(self._get_url('/game/reset'), timeout=timeout).json()

    def _get_url(self, route):
        """Build ``http://<hostname>:<port><route>``; *route* should start with '/'."""
        return 'http://{hostname}:{port}{route}'.format(hostname=self.hostname, port=self.port, route=route)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment