Created
May 5, 2022 07:26
-
-
Save yknishidate/32cd63ef648217ca7fc787b5b98a8696 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
import random | |
import matplotlib.pyplot as plt | |
class UniformDistribution:
    """Continuous uniform distribution between the bounds ``a`` and ``b``.

    The bounds may be given in either order, mirroring ``random.uniform``,
    which also accepts ``a > b``.
    """

    def __init__(self, a: float, b: float) -> None:
        """Store the two bounds of the interval."""
        self.a = a
        self.b = b

    def sample(self) -> float:
        """Draw one random value between ``a`` and ``b``."""
        return random.uniform(self.a, self.b)

    def pdf(self, x: float) -> float:
        """Return the probability density at ``x``.

        The density is ``1 / |b - a|`` inside the interval (end points
        included) and ``0.0`` everywhere else.

        Raises:
            ZeroDivisionError: if ``a == b`` and ``x`` equals that value,
                because the interval then has zero width.
        """
        # Order the bounds so that reversed bounds do not yield a negative
        # density.
        lo, hi = min(self.a, self.b), max(self.a, self.b)
        if not lo <= x <= hi:
            # Outside the support the density is zero.
            return 0.0
        return 1.0 / (hi - lo)
class SinDistribution:
    """Target distribution whose density is given by ``sin(x)``.

    Only ``pdf`` is usable; ``sample`` always fails. The value returned by
    ``pdf`` is the raw sine, so it is not normalized (over [0, pi] it
    integrates to 2) and it is negative wherever ``sin(x)`` is negative.
    """

    def sample(self) -> None:
        """Always fail, since this distribution cannot be sampled directly.

        Raises:
            AssertionError: always.
        """
        # Raise explicitly instead of ``assert False``: assert statements are
        # removed when Python runs with -O, which would turn this method into
        # a silent ``return None``.
        raise AssertionError('cannot sample from this distribution!')

    def pdf(self, x: float) -> float:
        """Return ``sin(x)``, the (unnormalized) density at ``x``."""
        return math.sin(x)
class RISReservoir:
    """Single-sample weighted reservoir for streaming RIS.

    Candidates are fed one at a time through ``update``; the reservoir keeps
    one of them, chosen with probability proportional to its weight.

    Attributes are plain public fields:
        y: the currently selected sample, or None if nothing was selected yet
        weight_sum: running sum of all candidate weights seen so far
        target_pdf: target density value stored alongside ``y``
        num_candidates: number of candidates seen so far
    """

    def __init__(self) -> None:
        """Create an empty reservoir."""
        self.y = None
        self.weight_sum = 0.0
        self.target_pdf = 0.0
        self.num_candidates = 0

    def update(self, x, w, target_pdf) -> bool:
        """Stream one candidate into the reservoir.

        Args:
            x: the candidate sample
            w: the resampling weight of the candidate
            target_pdf: target density value of the candidate, stored if the
                candidate is selected

        Returns:
            True if the candidate replaced the stored sample, else False.
        """
        self.weight_sum += w
        self.num_candidates += 1
        if self.weight_sum == 0.0:
            # Every weight so far is zero: w / weight_sum would be 0 / 0.
            # The candidate still counts, but it cannot be selected.
            return False
        accept = random.random() < w / self.weight_sum
        if accept:
            self.y = x
            self.target_pdf = target_pdf
        return accept

    def calc_reservoir_weight(self) -> float:
        """Return the weight to apply to the stored sample.

        Computed as ``(1 / target_pdf) * (weight_sum / num_candidates)``.
        Returns 0.0 when no candidate has been seen or the stored target
        density is zero, instead of dividing by zero.
        """
        if self.num_candidates == 0 or self.target_pdf == 0:
            return 0.0
        return (1.0 / self.target_pdf) * (self.weight_sum / self.num_candidates)
def actual_f(x):
    """Integrand used by the example: the sine of ``x``."""
    value = math.sin(x)
    return value
def perform_streaming_RIS(num_candidates) -> float:
    """Return one streaming-RIS estimate of the integral of ``actual_f``.

    ``num_candidates`` samples are drawn uniformly from [0, pi]. Each one is
    weighted by ``target_pdf / source_pdf`` and streamed through a
    ``RISReservoir``. The estimate is ``actual_f(y)`` multiplied by the
    reservoir weight, where ``y`` is the sample left in the reservoir.

    Args:
        num_candidates: number of candidates to stream through the reservoir

    Returns:
        The estimate, or 0.0 when the reservoir holds no sample (for example
        when ``num_candidates`` is 0).
    """
    reservoir = RISReservoir()
    source = UniformDistribution(0.0, math.pi)
    target = SinDistribution()

    for _ in range(num_candidates):
        candidate = source.sample()
        target_pdf = target.pdf(candidate)
        resampling_weight = target_pdf / source.pdf(candidate)
        reservoir.update(candidate, resampling_weight, target_pdf)

    if reservoir.y is None:
        # Nothing was selected, so there is no sample to evaluate.
        return 0.0

    actual_value = actual_f(reservoir.y)
    reservoir_weight = reservoir.calc_reservoir_weight()
    return actual_value * reservoir_weight
if __name__ == "__main__":
    # Run many independent estimates with M candidates each and show how
    # they are distributed as a histogram.
    num_candidates = 32
    samples = []
    for _ in range(50000):
        samples.append(perform_streaming_RIS(num_candidates))

    plt.hist(samples, bins=20)
    plt.title("Streaming RIS, M=" + str(num_candidates))
    plt.show()
    plt.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment