Skip to content

Instantly share code, notes, and snippets.

@steven-peralta
Last active July 29, 2021 15:55
Show Gist options
  • Save steven-peralta/82de425fa08ec32fe2e2a65919eb8638 to your computer and use it in GitHub Desktop.
Save steven-peralta/82de425fa08ec32fe2e2a65919eb8638 to your computer and use it in GitHub Desktop.
import threading
import time
import gym
from gym import spaces
import numpy as np
import requests
from flatten_json import flatten_json
import subprocess
import os
import psutil
# Per-request timeout, in seconds, for every HTTP call made to the game server.
timeout = 60

# Action table. The agent's discrete action is an index into this list, so the
# ORDER of entries is part of the environment's interface and must not change.
# The table is generated rather than hand-written: the former literal list
# repeated 'play 10 5' where 'play 10 0' .. 'play 10 4' belonged, which made
# five actions alias one command.
# NOTE(review): the second number of two-argument commands looks like a target
# index — confirm against the game-side command parser.
commands = (
    # indices 0-4: 'potion use 0' .. 'potion use 4'
    ['potion use {}'.format(slot) for slot in range(5)]
    # indices 5-9: 'potion discard 0' .. 'potion discard 4'
    + ['potion discard {}'.format(slot) for slot in range(5)]
    # indices 10-39: 'potion use <slot> <target>', slot 0-4 varies fastest, target 0-5
    + ['potion use {} {}'.format(slot, target)
       for target in range(6) for slot in range(5)]
    # indices 40-49: 'play 1' .. 'play 10'
    + ['play {}'.format(card) for card in range(1, 11)]
    # indices 50-109: 'play <card> <target>', card 1-10 varies fastest, target 0-5
    + ['play {} {}'.format(card, target)
       for target in range(6) for card in range(1, 11)]
    # indices 110-159: 'choose 0' .. 'choose 49'
    + ['choose {}'.format(choice) for choice in range(50)]
    # indices 160-162
    + ['end', 'proceed', 'return']
)
def get_java_args(java_path):
    """Return the argv list that starts ModTheSpire.jar with ``java``.

    ``java_path`` is prepended verbatim to ``'java'``. Pass ``''`` to use the
    ``java`` found on PATH, or a directory prefix that ends with a separator.
    """
    executable = java_path + 'java'
    jvm_options = ['-Xmx512M', '-Xms128M']  # max / initial heap size
    launcher = ['-jar', 'ModTheSpire.jar']
    mod_the_spire_options = ['--skip-launcher', '--skip-intro', '--profile', 'Default']
    return [executable] + jvm_options + launcher + mod_the_spire_options
def launch_process(port, sts_path, java_bin_path, virtualize):
    """Start ModTheSpire in the background and return the new process's PID.

    Args:
        port: Exported to the child as ``SURVEYOR_PORT``. It is also used to name
            the output files ``sts_<port>.log`` (stdout) and ``sts_<port>.err``
            (stderr), which are created or truncated in the current directory.
        sts_path: Game directory, concatenated onto ``os.getcwd()`` to form the
            child's working directory (for example the default ``'/sts/run'``).
        java_bin_path: Prefix for the ``java`` executable (see ``get_java_args``).
        virtualize: If true, run through ``/opt/VirtualGL/bin/vglrun`` with
            ``DISPLAY=:0`` and ``VGL_LOGO=1`` added to the environment.

    Returns:
        The PID of the process started by ``Popen``. When ``virtualize`` is
        true, that is the ``vglrun`` wrapper.
    """
    args = get_java_args(java_bin_path)
    env = {
        **os.environ,
        'SURVEYOR_HOSTNAME': '0.0.0.0',
        'SURVEYOR_PORT': str(port),
    }
    if virtualize:
        args = ['/opt/VirtualGL/bin/vglrun', *args]
        env['DISPLAY'] = ':0'
        env['VGL_LOGO'] = '1'

    log_name = 'sts_' + str(port) + '.log'
    err_name = 'sts_' + str(port) + '.err'
    # The child receives its own duplicates of these descriptors, so the
    # parent's copies are closed once Popen returns (or raises). Previously
    # they were never closed, which leaked two file handles per launch.
    with open(log_name, 'w') as log_file, open(err_name, 'w') as err_file:
        process = subprocess.Popen(
            args,
            # Plain concatenation on purpose: sts_path starts with a separator,
            # and os.path.join would discard the cwd for such a path.
            cwd=os.getcwd() + sts_path,
            stdout=log_file,
            stderr=err_file,
            env=env,
        )
    return process.pid
def flatten_state_json(state):
    """Turn a (nested) game-state JSON object into a 1-D ``np.float32`` array.

    The mapping returned by the third-party ``flatten_json`` helper is consumed
    value by value, in its iteration order. ``np.fromiter`` raises if any value
    cannot be converted to float32.
    """
    flat_mapping = flatten_json(state)
    return np.fromiter(flat_mapping.values(), dtype=np.float32)
def is_listening(pid, port):
    """Return True if process ``pid`` has a socket in LISTEN state on 0.0.0.0:``port``.

    Only the process's own sockets are inspected; children are not.
    ``psutil.Process`` raises ``psutil.NoSuchProcess`` if ``pid`` does not exist.
    """
    process = psutil.Process(pid)
    # psutil 6.0 renamed Process.connections() to net_connections() and
    # deprecated the old name; use whichever this psutil version provides.
    get_connections = getattr(process, 'net_connections', None) or process.connections
    return any(
        conn.laddr  # empty for some sockets; guards the indexing below
        and conn.laddr[0] == '0.0.0.0'
        and conn.laddr[1] == port
        and conn.status == psutil.CONN_LISTEN
        for conn in get_connections()
    )
class SlayTheSpireGym(gym.Env):
    """Gym environment backed by a live, modded Slay the Spire process.

    The constructor does three things:

    - launches the game with ``launch_process``;
    - waits until that process listens on ``port``;
    - resets the game through the HTTP API at ``localhost:port``.

    Actions are indices into the module-level ``commands`` table. Observations
    are the flattened ``gameState`` object of each JSON response. The reward is
    the change in the response's ``score`` since the previous step.
    """

    def __init__(self, sts_path='/sts/run', java_bin_path='', virtualize=True, port=8008,
                 startup_timeout=None, poll_interval=0.1):
        """Launch the game and block until its server is listening.

        Args:
            sts_path: Game directory passed to ``launch_process``.
            java_bin_path: Prefix for the ``java`` executable ('' uses PATH).
            virtualize: Run the game under VirtualGL's ``vglrun``.
            port: Port the game server is told to listen on and that
                requests are sent to.
            startup_timeout: Seconds to wait for the server before raising
                ``TimeoutError``. ``None`` (the default) waits indefinitely.
            poll_interval: Seconds to sleep between readiness checks.
        """
        # The action index is a position in ``commands``; keep the two in sync.
        self.action_space = spaces.Discrete(len(commands))
        self.observation_space = spaces.Box(-2, 2, (13164,), dtype=np.float32)
        self.port = port
        self.hostname = 'localhost'
        self.score = 0
        self.sts_pid = launch_process(port, sts_path, java_bin_path, virtualize)
        # This handle is created right after launch. psutil uses it to detect
        # PID reuse, so close() cannot signal an unrelated process later.
        self._sts_process = psutil.Process(self.sts_pid)
        self._wait_until_listening(startup_timeout, poll_interval)
        self._reset_game()

    def _wait_until_listening(self, startup_timeout, poll_interval):
        """Poll ``is_listening`` until the game accepts connections on ``self.port``.

        The loop sleeps ``poll_interval`` seconds between checks instead of
        busy-spinning. It raises ``TimeoutError`` once ``startup_timeout``
        seconds have elapsed; ``None`` means wait forever.
        """
        deadline = None if startup_timeout is None else time.monotonic() + startup_timeout
        while not is_listening(self.sts_pid, self.port):
            if deadline is not None and time.monotonic() >= deadline:
                raise TimeoutError(
                    'game process {} did not listen on port {} within {} s'.format(
                        self.sts_pid, self.port, startup_timeout))
            time.sleep(poll_interval)

    def step(self, action=None):
        """Send ``commands[action]`` to the game.

        Returns:
            A tuple ``(observation, reward, done, info)``, where ``done`` is
            true when ``gameState.isGameOver == 1`` and ``info`` is always ``{}``.

        Raises:
            ValueError: ``action`` is None.
        """
        if action is None:
            raise ValueError('step() requires an action index, got None')
        response = self._send_action(commands[action])
        reward = response['score'] - self.score
        self.score = response['score']
        is_done = response['gameState']['isGameOver'] == 1
        observation = flatten_state_json(response['gameState'])
        return observation, reward, is_done, {}

    def reset(self):
        """Reset the game, zero the running score and return the first observation."""
        observation = flatten_state_json(self._reset_game()['gameState'])
        self.score = 0
        return observation

    def render(self, mode='human'):
        """Do nothing; the game draws its own window."""
        pass

    def close(self):
        """Terminate the launched game process and any descendants it spawned.

        Descendants are included because, for example, the ``vglrun`` wrapper
        starts ``java`` beneath it. Processes still alive after 5 seconds are
        killed. Calling this more than once is safe, and processes that have
        already exited are ignored.
        """
        process = getattr(self, '_sts_process', None)
        if process is None:
            return
        self._sts_process = None
        try:
            targets = process.children(recursive=True) + [process]
        except psutil.NoSuchProcess:
            return
        for target in targets:
            try:
                target.terminate()
            except psutil.NoSuchProcess:
                pass
        _, still_alive = psutil.wait_procs(targets, timeout=5)
        for target in still_alive:
            try:
                target.kill()
            except psutil.NoSuchProcess:
                pass

    def _send_action(self, command):
        """POST ``{'command': command}`` to ``/game`` and return the decoded JSON.

        The HTTP status code is not checked.
        """
        return requests.post(self._get_url('/game'), json={'command': command}, timeout=timeout).json()

    def _reset_game(self):
        """POST to ``/game/reset`` and return the decoded JSON; the status is not checked."""
        return requests.post(self._get_url('/game/reset'), timeout=timeout).json()

    def _get_url(self, route):
        """Build ``http://<hostname>:<port><route>``; ``route`` should start with '/'."""
        return 'http://{hostname}:{port}{route}'.format(hostname=self.hostname, port=self.port, route=route)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment