Skip to content

Instantly share code, notes, and snippets.

@Tostino
Last active January 5, 2017 22:02
Show Gist options
  • Save Tostino/4ba74e5366c95ae403c165eddb24f5ed to your computer and use it in GitHub Desktop.
Save Tostino/4ba74e5366c95ae403c165eddb24f5ed to your computer and use it in GitHub Desktop.
rl4j doom
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.space.ObservationSpace;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import oshi.SystemInfo;
import oshi.hardware.GlobalMemory;
import oshi.util.FormatUtil;
import vizdoom.*;
import vizdoom.Button;
import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.color.ColorSpace;
import java.awt.image.*;
import java.io.*;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.IntBuffer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
/**
 * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/28/16.
 * <p>
 * Mother abstract class for all VizDoom scenarios.
 * <p>
 * It is mostly configured by a {@link Configuration}:
 * <p>
 * String scenario; name of the scenario (loaded from {@code DOOM_ROOT/scenarios/<scenario>.wad})
 * double livingReward; additional reward at each step for living
 * double deathPenalty; negative reward when dead
 * int doomSkill; skill of the enemy
 * int timeout; number of tics after which the episode times out
 * int startTime; number of internal tics before the simulation starts (useful to draw the weapon, for example)
 * {@code List<Button>} buttons; the list of inputs one can press for a given scenario (a no-op action is automatically added)
 * <p>
 * Observations are the RGB24 screen buffer converted to one value in [0, 255] per byte,
 * i.e. {@code height * width * 3} values, matching {@link #getObservationSpace()}.
 * No rescaling of the screen is performed.
 */
abstract public class VizDoom implements MDP<VizDoom.MdpGameScreen, Integer, DiscreteSpace>
{
    final public static String DOOM_ROOT = "./vizdoom";

    /** Bytes per pixel of the {@code ScreenFormat.RGB24} format selected in {@link #setupGame()}. */
    private static final int SCREEN_CHANNELS = 3;

    final protected Logger log = LoggerFactory.getLogger("Vizdoom");
    final protected GlobalMemory memory = new SystemInfo().getHardware().getMemory();
    /** Index 0 is the no-op action; index {@code i + 1} presses only button {@code i}. */
    final protected List<int[]> actions;
    final protected DiscreteSpace discreteSpace;
    final protected ObservationSpace<MdpGameScreen> observationSpace;
    final protected boolean render;
    protected DoomGame game;
    /** Multiplier applied to every reward returned by {@link #step(Integer)}. */
    protected double scaleFactor = 1;

    /** Creates a scenario with the game window visible. */
    public VizDoom()
    {
        this(true);
    }

    /**
     * Creates and initialises the underlying {@link DoomGame}.
     *
     * @param render whether the game window is shown
     */
    public VizDoom(boolean render)
    {
        this.render = render;
        actions = new ArrayList<>();
        game = new DoomGame();
        // NOTE: setupGame() calls the abstract getConfiguration() before any subclass constructor
        // body or field initializer has run, so implementations must not rely on subclass fields.
        setupGame();
        // One discrete action per button plus the no-op added in setupGame().
        discreteSpace = new DiscreteSpace(getConfiguration().getButtons().size() + 1);
        observationSpace = new ArrayObservationSpace<>(
                new int[]{game.getScreenHeight(), game.getScreenWidth(), SCREEN_CHANNELS});
    }

    public boolean isRender()
    {
        return render;
    }

    /**
     * @param scaleFactor multiplier applied to every reward returned by {@link #step(Integer)}
     */
    public void setScaleFactor(final double scaleFactor)
    {
        this.scaleFactor = scaleFactor;
    }

    /**
     * Configures {@link #game} from {@link #getConfiguration()}, registers the buttons, fills
     * {@link #actions} and starts the game. Called once from the constructor; calling it again
     * would append duplicate entries to {@link #actions}.
     */
    public void setupGame()
    {
        Configuration conf = getConfiguration();

        game.setViZDoomPath(DOOM_ROOT + "/vizdoom");
        game.setDoomGamePath(DOOM_ROOT + "/scenarios/freedoom2.wad");
        game.setDoomScenarioPath(DOOM_ROOT + "/scenarios/" + conf.getScenario() + ".wad");
        game.setDoomMap("map01");
        game.setScreenFormat(ScreenFormat.RGB24);
        game.setScreenResolution(ScreenResolution.RES_800X600);

        // Sets other rendering options
        game.setRenderHud(false);
        game.setRenderCrosshair(false);
        game.setRenderWeapon(true);
        game.setRenderDecals(false);
        game.setRenderParticles(false);

        GameVariable[] gameVar = new GameVariable[]{
                GameVariable.KILLCOUNT,
                GameVariable.ITEMCOUNT,
                GameVariable.SECRETCOUNT,
                GameVariable.FRAGCOUNT,
                GameVariable.HEALTH,
                GameVariable.ARMOR,
                GameVariable.DEAD,
                GameVariable.ON_GROUND,
                GameVariable.ATTACK_READY,
                GameVariable.ALTATTACK_READY,
                GameVariable.SELECTED_WEAPON,
                GameVariable.SELECTED_WEAPON_AMMO,
                GameVariable.AMMO1,
                GameVariable.AMMO2,
                GameVariable.AMMO3,
                GameVariable.AMMO4,
                GameVariable.AMMO5,
                GameVariable.AMMO6,
                GameVariable.AMMO7,
                GameVariable.AMMO8,
                GameVariable.AMMO9,
                GameVariable.AMMO0
        };
        // Adds game variables that will be included in state.
        for (GameVariable variable : gameVar)
        {
            game.addAvailableGameVariable(variable);
        }

        // Causes episodes to finish after timeout tics
        game.setEpisodeTimeout(conf.getTimeout());
        game.setEpisodeStartTime(conf.getStartTime());
        game.setWindowVisible(render);
        game.setSoundEnabled(false);
        game.setMode(Mode.PLAYER);
        game.setLivingReward(conf.getLivingReward());

        // Adds buttons that will be allowed. Each action vector has one entry per registered button.
        List<Button> buttons = conf.getButtons();
        int size = buttons.size();
        // Index 0: no-op (no button pressed).
        actions.add(new int[size]);
        for (int i = 0; i < size; i++)
        {
            game.addAvailableButton(buttons.get(i));
            int[] action = new int[size];
            action[i] = 1;
            actions.add(action);
        }

        game.setDeathPenalty(conf.getDeathPenalty());
        game.setDoomSkill(conf.getDoomSkill());
        game.init();
    }

    public boolean isDone()
    {
        return game.isEpisodeFinished();
    }

    /**
     * Starts a new episode and returns its first observation.
     */
    public MdpGameScreen reset()
    {
        log.info("free Memory: {}/{}", FormatUtil.formatBytes(memory.getAvailable()),
                FormatUtil.formatBytes(memory.getTotal()));
        game.newEpisode();
        return currentObservation();
    }

    public void close()
    {
        game.close();
    }

    /**
     * Converts a raw screen buffer into one non-negative int per byte.
     * <p>
     * With the RGB24 format each byte is one colour channel, so the result has
     * {@code height * width * 3} elements, matching the declared observation space.
     * Java bytes are signed, so each byte is masked to recover its 0-255 value.
     *
     * @param buffer screen buffer as provided by the game state
     * @return a new array with {@code buffer.length} values in [0, 255]
     * @throws IllegalArgumentException if {@code buffer} is null
     */
    public int[] convertScreenBuffer(byte[] buffer)
    {
        if (buffer == null)
        {
            throw new IllegalArgumentException("buffer must not be null");
        }
        int[] values = new int[buffer.length];
        for (int i = 0; i < buffer.length; i++)
        {
            values[i] = buffer[i] & 0xFF;
        }
        return values;
    }

    /**
     * Performs one action and returns the resulting observation and scaled reward.
     *
     * @param action index into {@link #actions} (0 is the no-op)
     */
    public StepReply<MdpGameScreen> step(Integer action)
    {
        double r = game.makeAction(actions.get(action)) * scaleFactor;
        log.info("{} {} {} ", game.getEpisodeTime(), r, action);
        boolean done = game.isEpisodeFinished();
        return new StepReply<>(currentObservation(), r, done, null);
    }

    /**
     * Builds the observation for the current game state.
     * <p>
     * ViZDoom provides no state once the episode has finished, so in that case an all-zero
     * frame of the declared observation shape is returned instead of dereferencing it.
     */
    private MdpGameScreen currentObservation()
    {
        if (game.isEpisodeFinished())
        {
            return new MdpGameScreen(
                    new int[game.getScreenHeight() * game.getScreenWidth() * SCREEN_CHANNELS]);
        }
        return new MdpGameScreen(convertScreenBuffer(game.getState().screenBuffer));
    }

    public ObservationSpace<MdpGameScreen> getObservationSpace()
    {
        return observationSpace;
    }

    public DiscreteSpace getActionSpace()
    {
        return discreteSpace;
    }

    /** @return the scenario configuration; called from the constructor (see note there). */
    public abstract Configuration getConfiguration();

    public abstract VizDoom newInstance();

    /** Scenario settings consumed by {@link #setupGame()}. */
    public static class Configuration
    {
        String scenario;
        double livingReward;
        double deathPenalty;
        int doomSkill;
        int timeout;
        int startTime;
        List<Button> buttons;

        public Configuration(final String scenario, final double livingReward, final double deathPenalty, final int doomSkill, final int timeout, final int startTime, final List<Button> buttons)
        {
            this.scenario = scenario;
            this.livingReward = livingReward;
            this.deathPenalty = deathPenalty;
            this.doomSkill = doomSkill;
            this.timeout = timeout;
            this.startTime = startTime;
            this.buttons = buttons;
        }

        public String getScenario()
        {
            return scenario;
        }

        public double getLivingReward()
        {
            return livingReward;
        }

        public double getDeathPenalty()
        {
            return deathPenalty;
        }

        public int getDoomSkill()
        {
            return doomSkill;
        }

        public int getTimeout()
        {
            return timeout;
        }

        public int getStartTime()
        {
            return startTime;
        }

        public List<Button> getButtons()
        {
            return buttons;
        }
    }

    /** Observation wrapper: the screen values widened to doubles for encoding. */
    public static class MdpGameScreen implements Encodable
    {
        double[] array;

        public MdpGameScreen(int[] screen)
        {
            array = new double[screen.length];
            for (int i = 0; i < screen.length; i++)
            {
                array[i] = screen[i];
            }
        }

        public double[] toArray()
        {
            return array;
        }
    }
}
@Tostino
Copy link
Author

Tostino commented Jan 5, 2017

test

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment