Skip to content

Instantly share code, notes, and snippets.

@gbaptista
Created September 19, 2018 01:51
Show Gist options
  • Save gbaptista/840746bab54a2fa6cb7f603b1fc4292a to your computer and use it in GitHub Desktop.
Save gbaptista/840746bab54a2fa6cb7f603b1fc4292a to your computer and use it in GitHub Desktop.
from Orange.widgets import gui
from Orange.widgets.widget import OWWidget, Input
from Orange.widgets.settings import Setting
from AnyQt import QtWidgets
from AnyQt.QtGui import QColor, QPen, QPalette, QFont
from AnyQt.QtCore import Qt
from Orange.widgets.utils import colorpalette, colorbrewer
import pyqtgraph as pg
from .agents.agent import Agent
class OWBenchmark(OWWidget):
    """Orange widget that lists connected agents and plots one line per agent.

    Agents arrive on a multi-link ``Agent`` input.  Each connected agent is
    shown in a multi-selection list box in the control area, and every
    selected agent gets a colored line in a single plot in the main area.
    The plotted data is currently placeholder data (see the TODO in
    ``render_agents_lines``).
    """

    # Widget metadata read by the Orange canvas.
    id = "orange.widgets.reinforcement.benchmark"
    name = "Benchmark"
    description = """Compare Agents performance."""
    icon = "icons/benchmark.png"
    priority = 80
    category = "Reinforcement"
    keywords = ["OpenAI Gym", "Enviroment", "Info", "Details"]

    want_main_area = True
    resizing_enabled = True

    agent = None
    enviroment_id = None

    # Indices (into ``agent_names``) of the agents selected in the list box.
    selected_agents = Setting([])

    # NOTE(review): no attribute called ``plot`` is assigned in this class;
    # confirm what Orange is expected to resolve through this name.
    graph_name = "plot"

    class Inputs:
        agent = Input("Agent", Agent, multiple=True)

    def __init__(self):
        super().__init__()

        # Connected agents keyed by the id of the input channel they came from.
        self.agents = {}
        # Plot views and plot items keyed by plot-area index.
        self.plot_areas = {}
        self.plot_items = {}
        # Display names backing the list box, in the same order as ``agents``.
        self.agent_names = []
        # Color palette for the agents; built by ``generate_colors``.
        self.colors = None

        self.render_layout()

    def render_layout(self):
        """Build the agents list box (control area) and the first plot area."""
        cbox = gui.vBox(self.controlArea, "Agents:")
        cbox.setFlat(True)

        self.agents_list_box = gui.listBox(
            cbox, self, "selected_agents", "agent_names",
            selectionMode=QtWidgets.QListView.MultiSelection,
            callback=self._on_agents_changed
        )

        self.render_plot_area(0)

    def _on_agents_changed(self):
        """Redraw the plot when the list box selection changes."""
        self.render_agents_lines()

    def render_agents_lines(self):
        """Clear plot 0, then color the list entries and draw selected agents."""
        self.plot_items[0].clear()

        n_agents = len(self.agents)
        if not n_agents:
            # Nothing connected: leave the plot empty.
            return

        self.generate_colors(n_agents)
        selected = set(self.selected_agents)

        for i in range(n_agents):
            item = self.agents_list_box.item(i)
            if item is not None:
                # Icon shows which line color belongs to this agent.
                item.setIcon(colorpalette.ColorPixmap(self.colors[i]))

            # TODO line
            if i in selected:
                self.add_line(0, i, [0, 0.5, 1], [0, 0.5, 1])

    @Inputs.agent
    def set_agent(self, agent, channel):
        """Register, replace or remove the agent arriving on one input link.

        Parameters
        ----------
        agent : Agent or None
            The agent sent on the link.  ``None`` means the link no longer
            provides an agent, so the entry stored for it is dropped.
        channel : sequence
            Link identifier supplied by Orange; its first element is used
            as the key into ``self.agents``.
        """
        channel_id = channel[0]

        if agent is not None:
            self.agents[channel_id] = agent
        else:
            # Without this, a disconnected agent would linger in the list
            # box and the plot forever.
            self.agents.pop(channel_id, None)

        self.agent_names = [
            connected.name for connected in self.agents.values()
        ]
        # Every known agent is selected after any change of the inputs.
        self.selected_agents = list(range(len(self.agents)))

        self.render_agents_lines()

    def add_line(self, plot_area_i, line_i, x_values, y_values):
        """Add one curve to a plot area.

        Parameters
        ----------
        plot_area_i
            Key of the target plot item in ``self.plot_items``.
        line_i
            Index into ``self.colors`` selecting the curve color.
        x_values, y_values
            Data passed to ``pg.PlotDataItem``.
        """
        color = self.colors[line_i]

        pen = QPen(color, 1)
        pen.setCosmetic(True)

        # Wider, lighter pen of the same hue, used as shadow and symbol pen.
        shadow_pen = QPen(pen.color().lighter(160), 2.5)
        shadow_pen.setCosmetic(True)

        line = pg.PlotDataItem(
            x_values, y_values,
            pen=pen, shadowPen=shadow_pen,
            symbol="+", symbolSize=3,
            symbolPen=shadow_pen, antialias=True
        )

        self.plot_items[plot_area_i].addItem(line)

    def generate_colors(self, n):
        """Store a palette of ``n`` colors in ``self.colors``.

        Uses the "Dark2" qualitative scheme, switching to
        ``colorpalette.DefaultRGBColors`` when ``n`` exceeds ``len(scheme)``.
        """
        scheme = colorbrewer.colorSchemes["qualitative"]["Dark2"]
        if n > len(scheme):
            scheme = colorpalette.DefaultRGBColors
        self.colors = colorpalette.ColorPaletteGenerator(n, scheme)

    def render_plot_area(self, i):
        """Create plot area ``i`` (view + plot item) and add it to the main area."""
        view = pg.GraphicsView(background="w")
        view.setFrameStyle(QtWidgets.QFrame.StyledPanel)
        self.plot_areas[i] = view

        plot_item = pg.PlotItem(enableMenu=True)
        plot_item.setMouseEnabled(True, True)
        plot_item.hideButtons()
        self.plot_items[i] = plot_item

        pen = QPen(self.palette().color(QPalette.Text))

        # Tick labels use two thirds of the widget font's pixel size,
        # but never less than 11 px.
        tickfont = QFont(self.font())
        tickfont.setPixelSize(max(int(tickfont.pixelSize() * 2 // 3), 11))

        for side, label in (("bottom", "Episode"), ("left", "Score")):
            axis = plot_item.getAxis(side)
            axis.setTickFont(tickfont)
            axis.setPen(pen)
            axis.setLabel(label)

        plot_item.showGrid(True, True, alpha=0.1)

        view.setCentralItem(plot_item)
        self.mainArea.layout().addWidget(view)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment