Skip to content

Instantly share code, notes, and snippets.

@jschueller
Last active January 25, 2024 17:21
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jschueller/bb59a3dc019cb9a3159e82db412e3712 to your computer and use it in GitHub Desktop.
Save jschueller/bb59a3dc019cb9a3159e82db412e3712 to your computer and use it in GitHub Desktop.
censored_dist.py
import openturns as ot
class CensoredDistribution(ot.PythonDistribution):
    """Censored version of a 1-d OpenTURNS distribution.

    Realizations of ``distribution`` lying below the lower bound of the
    censoring interval are replaced by that bound, and realizations lying
    above the upper bound are replaced by the upper bound.  The result is
    represented as an ``ot.Mixture`` of:

    * ``distribution`` truncated to ``bounds``, weighted by
      ``P(lb <= X <= ub)``;
    * a Dirac at the lower bound, weighted by ``P(X < lb)``, when the lower
      bound cuts into the range of ``distribution``;
    * a Dirac at the upper bound, weighted by ``P(X > ub)``, when the upper
      bound cuts into the range of ``distribution``.

    Parameters
    ----------
    distribution : ot.Distribution
        Distribution to censor.  Only dimension 1 is currently supported.
    bounds : ot.Interval
        Censoring interval, of the same dimension as ``distribution``.

    Raises
    ------
    ValueError
        If the dimensions of ``distribution`` and ``bounds`` differ, if the
        dimension is greater than 1, or if ``bounds`` carries a null
        probability.
    """

    def __init__(self, distribution, bounds):
        dim = distribution.getDimension()
        super().__init__(dim)
        if dim != bounds.getDimension():
            raise ValueError("distribution/bounds dimension do not match")
        if dim > 1:
            raise ValueError("only 1-d distributions are supported")
        # Probability of the uncensored part, both bounds included.
        w0 = distribution.computeProbability(bounds)
        if w0 <= 0.0:
            raise ValueError("bounds contain a null probability")
        self.distribution = distribution
        self.bounds = bounds

        dist_range = distribution.getRange()
        dist_lb = dist_range.getLowerBound()[0]
        dist_ub = dist_range.getUpperBound()[0]
        intersection = dist_range.intersect(bounds)
        lbi = intersection.getLowerBound()[0]
        ubi = intersection.getUpperBound()[0]
        # Censored values are clamped into [lbi, ubi]: that is the range of
        # the censored variable.
        self._range = intersection

        # The tail masses must EXCLUDE the bounds themselves: an atom located
        # exactly on a bound is already part of w0.  Integrating over the
        # closed intervals [dist_lb, lbi] / [ubi, dist_ub] counted such atoms
        # twice and overweighted the Diracs of discrete distributions.
        w_upper = 0.0
        if ubi < dist_ub:
            w_upper = distribution.computeComplementaryCDF([ubi])  # P(X > ubi)
        w_lower = 0.0
        if lbi > dist_lb:
            # P(X < lbi) = 1 - P(lbi <= X <= ubi) - P(X > ubi); clamp so
            # floating-point round-off can never yield a negative weight.
            w_lower = max(0.0, 1.0 - w0 - w_upper)

        coll = [ot.TruncatedDistribution(distribution, bounds)]
        weights = [w0]
        # Only add a Dirac when its tail actually carries probability mass.
        if w_lower > 0.0:
            coll.append(ot.Dirac(lbi))
            weights.append(w_lower)
        if w_upper > 0.0:
            coll.append(ot.Dirac(ubi))
            weights.append(w_upper)
        self.mixture = ot.Mixture(coll, weights)

    def getRange(self):
        """Return the range of the censored variable.

        This is the intersection of the original distribution range with
        the censoring bounds, since censored values are clamped into it.
        """
        return self._range

    def getRealization(self):
        """Draw one realization from the censored distribution."""
        return self.mixture.getRealization()

    def computeCDF(self, X):
        """Return the CDF of the censored distribution at ``X``."""
        return self.mixture.computeCDF(X)

    def computePDF(self, X):
        """Return the PDF of the underlying mixture at ``X``."""
        return self.mixture.computePDF(X)

    def isContinuous(self):
        """Return whether the censored distribution is continuous.

        Censoring that cuts into the range adds Dirac components, so this is
        delegated to the underlying mixture.
        """
        return self.mixture.isContinuous()

    def isDiscrete(self):
        """Return whether the censored distribution is discrete."""
        return self.mixture.isDiscrete()

    def getMarginal(self, indices):
        """Return the censored marginal as an ``ot.Distribution``.

        Both the underlying distribution and the censoring bounds are
        restricted to ``indices``.
        """
        py_dist = CensoredDistribution(
            self.distribution.getMarginal(indices),
            self.bounds.getMarginal(indices),
        )
        return ot.Distribution(py_dist)
def censor(dist, bounds, size=10000):
    """Censor ``dist`` to ``bounds`` and print a quick sanity check.

    Builds a ``CensoredDistribution``, draws ``size`` realizations from it
    and prints the distribution range next to the sample min/max.

    Parameters
    ----------
    dist : ot.Distribution
        Distribution to censor.
    bounds : ot.Interval
        Censoring interval.
    size : int, optional
        Number of realizations drawn for the check (default 10000).

    Returns
    -------
    ot.Distribution
        The censored distribution wrapped as an ``ot.Distribution``.
    """
    censored = ot.Distribution(CensoredDistribution(dist, bounds))
    sample = censored.getSample(size)
    print(f"range={censored.getRange()}")
    print(f"min={sample.getMin()} max={sample.getMax()}")
    return censored
# Demo: censor a continuous distribution (standard normal) ...
censor(
    dist=ot.Normal(0.0, 1.0),
    bounds=ot.Interval(-1.2, 1.1),
)
# ... and a discrete one (geometric), whose range starts above -1.
censor(
    dist=ot.Geometric(0.5),
    bounds=ot.Interval(-1.0, 4.0),
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment