Skip to content

Instantly share code, notes, and snippets.

@void-elf
Last active June 20, 2017 07:08
Show Gist options
  • Select an option

  • Save void-elf/25b23942615b8e8d45dfc426ac199cb3 to your computer and use it in GitHub Desktop.

Select an option

Save void-elf/25b23942615b8e8d45dfc426ac199cb3 to your computer and use it in GitHub Desktop.
Multipanel Atari graphs!
def _plot_measure(axis, game, m, scale):
    """Scatter-plot and label one measure, adding error bars if it is uncertain."""
    colour = "r"
    # presumably improvement(target, value) >= 0 means the result meets or beats the target
    if game.target and game.scale.improvement(game.target, m.value) >= 0:
        colour = "b"
    point_style = {"c": colour}
    if m.withdrawn:
        # Not every matplotlib version has the filled "X" marker, hence the check.
        if "X" in markers.MarkerStyle().markers:
            point_style["marker"] = "X"
        point_style["c"] = "#aaaaaa"
    axis.plot_date([m.date], [m.value], **point_style)

    label = m.name
    if m.withdrawn and "withdrawn" not in label.lower():
        label = "WITHDRAWN " + label
    if len(label) >= 28:
        label = label[:25] + "..."
    axis.annotate(label, xy=(m.date, m.value), xytext=m.metric.scale.offset,
                  fontsize=scale * 6, textcoords='offset points')

    # Error bars for results whose publication date or value is uncertain.
    err_style = {"c": "#80cf80", "linewidth": scale * 1.0, "capsize": scale * 1.5,
                 "capthick": scale * 0.5, "dash_capstyle": 'projecting'}
    if m.min_date or m.max_date:
        before = (m.date - m.min_date) if m.min_date else datetime.timedelta(0)
        after = (m.max_date - m.date) if m.max_date else datetime.timedelta(0)
        err_style["xerr"] = numpy.array([[before], [after]])
    if m.value != m.minval:
        err_style["yerr"] = numpy.array([[m.value - m.minval], [m.maxval - m.value]])
    if "xerr" in err_style or "yerr" in err_style:
        axis.errorbar(m.date, m.value, **err_style)


def _plot_frontier(axis, game):
    """Draw a green line through each non-withdrawn result at least as good as the best so far."""
    best = None  # seeded by the first non-withdrawn result, never a withdrawn one
    frontier_x, frontier_y = [], []
    for m in game.measures:
        if m.withdrawn:
            continue
        if best is None or game.scale.improvement(best, m.value) >= 0:
            frontier_x.append(m.date)
            frontier_y.append(m.value)
            best = m.value
    axis.plot_date(frontier_x, frontier_y, "g-")


def _plot_target(axis, game):
    """Draw a red dashed line at game.target spanning every measure's date range."""
    if game.target_label:
        target_label = game.target_label
    elif game.parent and "agi" in game.parent.attributes:
        target_label = "Human performance"
    else:
        target_label = "Target"
    dates = [m.date for m in game.measures]
    start = min(dates + [m.min_date for m in game.measures if m.min_date])
    end = max(dates + [m.max_date for m in game.measures if m.max_date])
    axis.plot_date([start, end], 2 * [game.target], "r--", label=target_label)


def produce_plots(axes_array, games, plotable_data=None, scale=1.0):
    """Draw each game in ``games`` onto the axis at the same position in ``axes_array``.

    Axes and games are paired with zip(), so any surplus on either side is
    ignored. For each game this draws:
      * a labelled scatter point per measure, with error bars where the
        measure has a min/max date or a value differing from ``minval``;
      * the frontier of best non-withdrawn results, unless ``game.changeable``;
      * a dashed line at ``game.target``, if set.
    A game with no measures only gets its title and y-axis label.

    Args:
        axes_array: iterable of matplotlib-style axes (needs set_title,
            set_ylabel, plot_date, annotate and errorbar).
        games: metric objects to plot, one per axis.
        plotable_data: unused; accepted for compatibility with existing callers.
        scale: multiplier for label font size and error-bar line widths
            (1.0 was previously hard-coded).
    """
    for axis, game in zip(axes_array, games):
        axis.set_title(game.name)
        axis.set_ylabel(game.axis_label)
        for m in game.measures:
            _plot_measure(axis, game, m, scale)
        if not game.measures:
            # The frontier and target line both need at least one measure.
            continue
        if not game.changeable:
            _plot_frontier(axis, game)
        if game.target:
            _plot_target(axis, game)


def graph_atari_games(scale=1.0):
    """Plot every game in ``simple_games.metrics`` with more than two measures, two per figure.

    Games are taken in pairs and drawn side by side on one figure with a
    shared y axis; an odd game out gets a figure of its own. Each game's
    ``measures`` list is sorted in place (by date, then pseudolinear value)
    before plotting, because the frontier walk in produce_plots visits
    measures in list order.

    Figures are created but not shown, saved or closed here; presumably the
    caller (e.g. a notebook) handles that.

    Args:
        scale: multiplier for the 17x5-inch figure size and, via
            produce_plots, for label fonts and error-bar widths.
    """
    # A list, not filter(): on Python 3 filter() is a lazy iterator with no
    # len() and no indexing.
    well_populated_games = [game for game in simple_games.metrics
                            if len(game.measures) > 2]
    for first in range(0, len(well_populated_games), 2):
        games = well_populated_games[first:first + 2]  # one or two metrics
        if len(games) == 2:
            fig, axes_array = plt.subplots(1, 2, sharey=True)
        else:
            # A 1x1 plt.subplots returns a bare Axes rather than an array;
            # wrap it so produce_plots can zip over it.
            fig, single_axis = plt.subplots(1, 1)
            axes_array = [single_axis]
        fig.set_size_inches((17 * scale, 5 * scale))
        for game in games:
            game.measures.sort(key=lambda x: (x.date, x.metric.scale.pseudolinear(x.value)))
        produce_plots(axes_array, games, scale=scale)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment