Skip to content

Instantly share code, notes, and snippets.

@nishidy
Last active October 10, 2022 04:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nishidy/ca20b2746d97c43713485be3fb699434 to your computer and use it in GitHub Desktop.
Save nishidy/ca20b2746d97c43713485be3fb699434 to your computer and use it in GitHub Desktop.
Segment Tree in Python
class SegmentTree:
    """Segment tree with lazy range assignment and range min/max queries.

    Every internal node holds ``compare`` (``max`` by default; see
    ``setcompare``) of its two children, and the leaves hold the data.
    The tree is stored heap-style in ``segtree``: node ``i`` has children
    ``2*i`` and ``2*i+1``, node 1 is the root, index 0 is unused, and the
    leaves start at ``firstleaf``.  ``lazytree[i]`` holds a value that still
    has to be assigned to every position under node ``i`` (``None`` when
    nothing is pending).

    ``query`` and ``lazyupdate`` take inclusive, 0-based ranges ``[ql, qr]``;
    positions outside ``[0, datalength-1]`` are ignored.  ``INIT`` is what
    ``query`` returns for an empty range and what fills the unused leaf
    slots: -1 for ``max`` and ``INF`` (10**9) for ``min``.
    """

    def __init__(self):
        self.segtree = []      # node values, heap layout (index 0 unused)
        self.lazytree = []     # pending assignment per node, None if none
        self.treelayer = 0     # smallest k with 2**k >= datalength
        self.capacity = 0      # len(segtree) == 2**(treelayer+1)
        self.datalength = 0    # number of real data positions
        self.compare = max
        self.firstleaf = 0     # index of the leaf holding data[0]
        self.numleafs = 0      # alias to firstleaf (number of leaf slots)
        self.INIT = -1
        self.INF = 10**9

    def setdata(self, data):
        """Build the tree over ``data``; refuses to run a second time.

        Leaf slots beyond ``len(data)`` are filled with ``INIT``, so call
        ``setcompare`` first when a comparison other than ``max`` is wanted.
        """
        if self.segtree:
            print("### Your data has been already set in the segment tree. Exit.")
            return
        print("### Make sure that the function to compare is {0}.".format(self.compare))
        self.datalength = len(data)
        while 2**self.treelayer < self.datalength:
            self.treelayer += 1
        self.capacity = 2**(self.treelayer + 1)
        self.segtree = [self.INIT] * self.capacity
        # None (not INIT) marks "no pending assignment", so that assigning a
        # value equal to INIT is not mistaken for "nothing to do".
        self.lazytree = [None] * self.capacity
        self.firstleaf = self.capacity // 2
        self.numleafs = self.firstleaf
        for offset, value in enumerate(data):
            self.segtree[self.firstleaf + offset] = value
        self.segtree[0] = 0  # unused slot
        # Fill internal nodes bottom-up; children always have larger indices.
        for node in range(self.firstleaf - 1, 0, -1):
            self.segtree[node] = self.compare(self.segtree[2 * node], self.segtree[2 * node + 1])

    def setcompare(self, func):
        """Set the combining function (e.g. ``min`` or ``max``) and ``INIT``.

        ``func`` is probed with ``(-1, INF)``: a min-like answer (-1) makes
        ``INIT`` = ``INF``, anything else makes it -1.  A tree that is already
        built is NOT rebuilt, hence the reminder printed in that case.
        """
        self.compare = func
        self.INIT = self.INF if func(-1, self.INF) == -1 else -1
        if self.segtree:
            print("### Make sure that the function to compare is {0}.".format(self.compare))

    def eval(self, i):
        """Apply node ``i``'s pending assignment and hand it to its children."""
        pending = self.lazytree[i]
        if pending is None:
            return
        if i < self.firstleaf:  # internal node: children inherit the assignment
            self.lazytree[2 * i] = pending
            self.lazytree[2 * i + 1] = pending
        self.segtree[i] = pending  # every position below now holds `pending`
        self.lazytree[i] = None

    def __inner_query(self, ql, qr, l, r, i):
        """Combine positions [ql, qr] inside node ``i``, which covers [l, r].

        Returns None when the ranges do not overlap, so that no sentinel value
        ever takes part in ``compare``.
        """
        self.eval(i)
        if qr < l or r < ql:
            return None
        if ql <= l and r <= qr:
            return self.segtree[i]
        mid = (l + r) // 2
        left = self.__inner_query(ql, qr, l, mid, 2 * i)
        right = self.__inner_query(ql, qr, mid + 1, r, 2 * i + 1)
        if left is None:
            return right
        if right is None:
            return left
        return self.compare(left, right)

    def query(self, ql, qr):
        """Return ``compare`` folded over positions ``ql..qr`` (inclusive).

        The range is clipped to the stored data; ``INIT`` is returned when
        nothing is left.  Clipping also keeps the padding leaves (which hold
        INIT) out of every answer.
        """
        ql = max(ql, 0)
        qr = min(qr, self.datalength - 1)
        if ql > qr:
            return self.INIT
        return self.__inner_query(ql, qr, 0, self.numleafs - 1, 1)

    def __inner_update(self, ql, qr, l, r, i, v):
        """Assign ``v`` to positions [ql, qr] inside node ``i``, which covers [l, r]."""
        self.eval(i)
        if qr < l or r < ql:
            return
        if ql <= l and r <= qr:
            self.lazytree[i] = v
            self.eval(i)  # update this node now, defer its children
            return
        mid = (l + r) // 2
        self.__inner_update(ql, qr, l, mid, 2 * i, v)
        self.__inner_update(ql, qr, mid + 1, r, 2 * i + 1, v)
        self.segtree[i] = self.compare(self.segtree[2 * i], self.segtree[2 * i + 1])

    def lazyupdate(self, ql, qr, v):
        """Assign ``v`` (not None) to positions ``ql..qr`` (inclusive, clipped to the data)."""
        ql = max(ql, 0)
        qr = min(qr, self.datalength - 1)
        if ql > qr:
            return
        self.__inner_update(ql, qr, 0, self.numleafs - 1, 1, v)

    def showtree(self):
        """Print both trees level by level ("INF" marks INIT / no-pending slots)."""
        print("=== Segment Tree ===")
        for depth in range(self.treelayer + 1):
            start = 2**depth
            print(start, ["INF" if x == self.INIT else x for x in self.segtree[start:2 * start]])
        print("=== END ===")
        print("=== Lazy Segment Tree ===")
        for depth in range(self.treelayer + 1):
            start = 2**depth
            print(start, ["INF" if x is None else x for x in self.lazytree[start:2 * start]])
        print("=== END ===")
import random
def main():
    """Demo: build a max segment tree over a shuffled 1..30 and query it.

    Prints the tree, two range queries, then range-assigns 5 to positions
    5..6 and repeats the second query.
    """
    N = 30
    data = list(range(1, N + 1))
    random.shuffle(data)
    sg = SegmentTree()
    sg.setdata(data)
    sg.showtree()
    # The original referenced an undefined name `s_g` here (NameError).
    print(sg.query(10, 20))
    print(sg.query(4, 7))
    sg.lazyupdate(5, 6, 5)
    print(sg.query(4, 7))


if __name__ == "__main__":
    main()
@nishidy
Copy link
Author

nishidy commented Oct 10, 2022

Test code (run as segment_test.py; it imports the gist above as the module `segment`).

import segment
import sys,random

def create_segtree(N):
    """Build a min segment tree over a random permutation of 1..N.

    The permutation comes from the module-level ``random`` state, so it is
    reproducible once the caller has seeded ``random``.
    """
    data = list(range(1, N + 1))
    random.shuffle(data)

    sg = segment.SegmentTree()
    # setcompare must come before setdata: setdata fills unused slots with
    # INIT, which setcompare(min) switches to INF, and an already built tree
    # is not rebuilt.
    sg.setcompare(min)
    sg.setdata(data)
    return sg

def test1(sg, N):
    """Point-assign base+n to each position n and verify every assignment.

    Afterwards position 0 holds the smallest value (base), so the result over
    [0, N-1] must equal base and the result over [1, N-1] must differ from it.
    Assumes ``sg`` combines with ``min`` (as built by create_segtree).
    """
    base = 10000
    smallest = base  # value assigned to position 0, the minimum of base+n
    for n in range(N):
        renew = base + n
        v0 = sg.query(n, n)
        sg.lazyupdate(n, n, renew)
        v1 = sg.query(n, n)
        assert v1 == renew, "{0} NG: {1} {2}".format(n, v0, v1)

    # query ranges are inclusive, so the last position is N-1, not N.
    v = sg.query(0, N - 1)
    assert v == smallest, "NG: {0} {1}".format(v, smallest)

    v = sg.query(1, N - 1)
    assert v != smallest, "NG: {0} {1}".format(v, smallest)

    print("### test1 successfully completed.")

def test2(sg, N):
    """Lower every window [start, start+width] by one and check the result.

    For each window, the current combined value (the minimum when ``sg`` was
    built by create_segtree) is read.  The whole window is then assigned
    max(0, value-1), and reading the window again must give exactly that.
    Returns immediately without output for N > 500, since it performs a
    quadratic number of range operations.
    """
    if N > 500:
        return
    for width in range(1, N):
        for start in range(N - width):
            end = start + width
            before = sg.query(start, end)
            expected = max(0, before - 1)
            sg.lazyupdate(start, end, expected)
            after = sg.query(start, end)
            assert expected == after, "{0} NG: {1} {2}".format(start, before, after)

    print("### test2 successfully completed.")

def debug1(sg, N):
    """Debug helper: assign 0 to position 2, report OK/NG and dump the tree.

    ``N`` is unused; it is accepted so the call matches test1/test2.
    """
    n = 2
    renew = 0
    v0 = sg.query(n, n)
    print("lazyupdate---")
    sg.lazyupdate(n, n, renew)
    print("---lazyupdate")
    sg.showtree()
    v1 = sg.query(n, n)
    # OK only if the value actually changed and now equals the assigned one.
    if v0 != v1 and v1 == renew:
        print("{0} OK: {1} {2}".format(n, v0, v1))
    else:
        print("{0} NG: {1} {2}".format(n, v0, v1))
    sg.showtree()

if __name__ == "__main__":
    # Usage: python3 segment_test.py N   (N = data size, also the random seed)
    if len(sys.argv) != 2:
        sys.exit("usage: {0} N".format(sys.argv[0]))
    N = int(sys.argv[1])
    random.seed(N)  # seeding with N makes each run for a given N reproducible
    sg = create_segtree(N)
    test1(sg, N)
    test2(sg, N)

@nishidy
Copy link
Author

nishidy commented Oct 10, 2022

Test result (run with N=499):

python3 segment_test.py 499
### Make sure that the function to compare is <built-in function min>.
### test1 successfully completed.
### test2 successfully completed.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment