Skip to content

Instantly share code, notes, and snippets.

@cameronabrams
Created March 3, 2020 22:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save cameronabrams/6e044b91a0768f7387fc3a99fcd5391b to your computer and use it in GitHub Desktop.
Save cameronabrams/6e044b91a0768f7387fc3a99fcd5391b to your computer and use it in GitHub Desktop.
An iterable stars-and-bars class in Python
import math
import scipy.special as sci
class stars_bars:
    """Iterable enumeration of every stars-and-bars arrangement.

    Each arrangement distributes ``nstars`` indistinguishable stars among
    ``nbins`` ordered bins; there are C(nstars + nbins - 1, nbins - 1) of them.
    Iterating yields the *same* object after every step, with its state
    (``bars``, ``bins``, ``stars``, ``index``, ``npermutations``) updated in
    place -- copy any attribute you want to keep between steps.

    Attributes:
        nbins: number of bins (>= 1).
        nstars: number of stars (>= 0).
        bars: position of each of the ``nbins - 1`` bars, expressed as the
            number of stars to its left; kept non-decreasing.
        bins: number of stars in each bin.
        stars: bin index of each star (stars are numbered left to right).
        index: 0-based counter of the current arrangement (-1 right after
            ``iter()``).
        narrangements: total number of distinct arrangements (exact int).
        npermutations: multinomial coefficient nstars! / prod(b! for b in bins)
            for the current arrangement (exact int).
    """

    def __init__(self, nbins=4, nstars=40):
        """Prepare an enumerator for ``nstars`` stars in ``nbins`` bins.

        Raises:
            ValueError: if ``nbins < 1`` or ``nstars < 0``.
        """
        if nbins < 1:
            raise ValueError('nbins must be at least 1')
        if nstars < 0:
            raise ValueError('nstars must be non-negative')
        self.nbins = nbins
        self.nstars = nstars
        self.bars = [0] * (nbins - 1)  # bar positions; there are nbins-1 bars
        self.bins = [0] * nbins        # number of stars in each bin
        self.stars = [0] * nstars      # bin index of each star
        self.index = 0                 # arrangement counter
        # exact=True yields an exact Python int; the float default silently
        # loses precision once the count exceeds 2**53.
        self.narrangements = int(sci.comb(nstars + nbins - 1, nbins - 1, exact=True))
        self.npermutations = 0  # multinomial count of the current arrangement

    def __iter__(self):
        """Reset to the starting state and return ``self`` as the iterator.

        Every star starts in the right-most bin.  The last bar is parked at
        position -1 so that the first ``next`` slides it to 0, which makes
        "all stars in the last bin" arrangement number 0.
        """
        self.index = -1
        self.bars[:] = [0] * (self.nbins - 1)
        if self.bars:  # nbins == 1 has no bars at all
            self.bars[-1] = -1
        self.bins[:] = [0] * self.nbins
        self.bins[-1] = self.nstars
        self.stars[:] = [self.nbins - 1] * self.nstars
        return self

    def compute_hist(self):
        """Recompute ``bins``, ``stars`` and ``npermutations`` from ``bars``.

        Bin k holds the stars between bar k-1 and bar k, with the two ends of
        the row acting as bars fixed at positions 0 and ``nstars``.  Every bin
        and every star entry is rewritten, so no stale values survive.
        """
        edges = [0] + self.bars + [self.nstars]
        for b in range(self.nbins):
            lo, hi = edges[b], edges[b + 1]
            self.bins[b] = hi - lo
            # stars are numbered left to right, so stars lo..hi-1 sit in bin b
            self.stars[lo:hi] = [b] * (hi - lo)
        self.npermutations = self.compute_npermut()

    def compute_npermut(self):
        """Return nstars! / prod(b! for b in bins) computed in exact integers.

        Integer arithmetic keeps the result exact for any size and avoids the
        float overflow that ``math.factorial(n)`` hits beyond n = 170.
        """
        den = 1
        for b in self.bins:
            den *= math.factorial(b)
        return math.factorial(self.nstars) // den

    def slide_bar(self, i):
        """Advance bar ``i`` one position, carrying into bar ``i - 1`` on overflow.

        When bar ``i`` passes the last position (``nstars``), bar ``i - 1`` is
        advanced (recursively) and bar ``i`` restarts at bar ``i - 1``'s new
        position, which keeps ``bars`` non-decreasing.  Negative ``i`` is a
        no-op.
        """
        if i < 0:
            return
        self.bars[i] += 1
        if self.bars[i] == self.nstars + 1:
            self.slide_bar(i - 1)
            self.bars[i] = self.bars[i - 1]

    def __next__(self):
        """Step to the next arrangement and return ``self``.

        Termination is driven by the exact arrangement count, which also
        covers ``nstars == 0`` (one empty arrangement) and ``nbins == 1``.

        Raises:
            StopIteration: once all ``narrangements`` have been produced.
        """
        if self.index + 1 >= self.narrangements:
            raise StopIteration
        self.slide_bar(self.nbins - 2)
        self.compute_hist()
        self.index += 1
        return self

    def __str__(self):
        """Render ``index/total: picture [bins] (stars) npermutations``."""
        retval = '{:7d}/{:<7d}: '.format(self.index, self.narrangements)
        retval += '*' * self.bins[0]
        for count in self.bins[1:]:
            retval += '|' + '*' * count
        retval += ' [' + ','.join('{:>2d}'.format(count) for count in self.bins) + '] ('
        retval += ','.join('{:d}'.format(star) for star in self.stars)
        return retval + ') ' + '{:<20d}'.format(int(self.npermutations))

    def mean(self):
        """Return the star-weighted mean bin index, rounded to an int.

        Uses Python's ``round`` (ties go to the even integer).  Raises
        ZeroDivisionError when the bins hold no stars.
        """
        tally = sum(i * count for i, count in enumerate(self.bins))
        denom = sum(self.bins)
        return round(tally / denom)
if __name__ == '__main__':
    # Demo: list every way to put 10 stars into 4 bins.
    print('Stars and bars: an example')
    for arrangement in stars_bars(nbins=4, nstars=10):
        print(arrangement)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment