# Captured from a GitHub Gist page (profile shown: Seanie Lee, @seanie12).
# Gist: malkin1729/grid.py by @malkin1729, created January 30, 2022 19:28.
# GFlowNet trajectory balance: grid environment example.
import torch as T
import numpy as np
import tqdm
import pickle
# Module-level configuration for this example.
device = T.device("cpu")  # torch device object for this example
horizon = 8  # presumably the grid side length -- TODO confirm against usage
ndim = 2  # presumably the number of grid dimensions -- TODO confirm against usage
def sample_gumbel(shape, eps=1e-20, device=None):
    """Sample i.i.d. standard Gumbel(0, 1) noise.

    Uses inverse-transform sampling: if U ~ Uniform[0, 1), then
    -log(-log(U)) ~ Gumbel(0, 1).

    Args:
        shape: Output shape (an int, tuple, or ``torch.Size``).
        eps: Small constant added inside both logarithms. It keeps
            ``U == 0`` (possible with ``torch.rand``) and a zero inner
            term from producing infinities.
        device: Optional device for the returned tensor. ``None`` uses
            torch's default device.

    Returns:
        A floating-point tensor of the requested shape.
    """
    # The original called TensorFlow (``tf``), which this file never imports.
    # torch.rand(shape) is the equivalent of tf.random_uniform(shape, 0, 1).
    u = T.rand(shape, device=device)
    return -T.log(-T.log(u + eps) + eps)
def gumbel_softmax_sample(logits, temperature):
    """Draw a relaxed sample from the Gumbel-Softmax distribution.

    Adds standard Gumbel noise to ``logits``, divides by ``temperature``,
    and applies a softmax over the last dimension. Smaller temperatures
    sharpen the softmax towards the noisy argmax.

    Args:
        logits: Tensor of unnormalized log-probabilities. The softmax is
            taken over its last dimension.
        temperature: Softmax temperature. It should be positive; zero
            divides by zero.

    Returns:
        A tensor shaped like ``logits`` whose last dimension sums to 1.
    """
    # The original used TensorFlow (``tf``), which this file never imports.
    # ``sample_gumbel`` is expected to return a torch tensor. ``.to(logits)``
    # matches the noise's dtype/device to the logits (e.g. GPU or half inputs).
    noise = sample_gumbel(logits.shape).to(logits)
    y = logits + noise
    # tf.nn.softmax defaults to the last axis; torch requires the dim explicitly.
    return T.softmax(y / temperature, dim=-1)