Skip to content

Instantly share code, notes, and snippets.

@xinyuanwylb19
Last active June 28, 2018 20:33
Show Gist options
  • Save xinyuanwylb19/3a21db8ac1f9a569bef643805c8f9872 to your computer and use it in GitHub Desktop.
Save xinyuanwylb19/3a21db8ac1f9a569bef643805c8f9872 to your computer and use it in GitHub Desktop.
Number of Required Sample Points
#Number of Sample Points
#Version 1.0
#Xinyuan Wei 01/01/2016
'''
This program uses repeated stratified random sampling to estimate the minimum
number of sample points required to reliably estimate the fire cycle. It
reports the mean and median of that number across all simulations.
'''
import sys
import os
import gdal
import ogr
from gdalconst import *
import math
from math import sqrt
import random
import copy
import pylab
import numpy as np
import matplotlib.pyplot as plt
# --- Configuration -------------------------------------------------------
# Name of the input raster; gdal.Open resolves it relative to the current
# working directory.
dataname = 'Clipped_TSLF_3.img'
# Number of independent sampling simulations.
nsimulation = 1000
# Relative tolerance: a running mean counts as reliable once it stays within
# ci * |final mean| of that simulation's final mean.
ci = 0.1
# Total number of subgrids (strata). The map is split into an lr x lr grid
# with lr = sqrt(stratify), so this should be a perfect square.
stratify = 100
# Width of the excluded edge strip, in kilometers.
edge_effect_width = 0
# Total number of points drawn per simulation (onetime * stratify).
total_points = 1000
# Raster cells per kilometer. The factor 10 presumably means 100 m cells;
# confirm against the raster's geotransform.
cells_per_km = 10
# Number of sampling rounds per simulation. Each round draws one point from
# every subgrid.
onetime = total_points // stratify
# Width of the edge strip in cells. It is forced to int because it is used as
# a pixel offset, and random.randint rejects floats on Python 3.12+, so a
# fractional edge_effect_width would otherwise crash the sampler.
edge = int(round(edge_effect_width * cells_per_km))
# Report the working directory, because dataname is resolved relative to it.
print(os.getcwd())
# Open the raster read-only. In GDAL's default (non-exception) mode,
# gdal.Open signals failure by returning None rather than raising.
rasterdata = gdal.Open(dataname, GA_ReadOnly)
if rasterdata is None:
    # Write the error to stderr so it is not mixed into the report on stdout.
    sys.stderr.write('Could not open ' + dataname + '\n')
    sys.exit(1)
# --- Describe the study-area raster --------------------------------------
cols = rasterdata.RasterXSize
rows = rasterdata.RasterYSize
bands = rasterdata.RasterCount
driver = rasterdata.GetDriver().LongName
print('')
print('[Description of the study area map:]')
print('Size of the map: ' + str(cols) + '*' + str(rows) + ' ha' + '.')
print('Number of bands: ' + str(bands) + '.')
print('Format of this image: ' + str(driver) + '.')
geotransform = rasterdata.GetGeoTransform()
print('')
# GDAL geotransform order: (x origin, pixel width, row rotation,
# y origin, column rotation, pixel height). The 5th term is the y-rotation.
print('xOrigin,pixelWidth,xrotation,yOrigin,yrotation,pixelHeight')
print(geotransform)
# Only band 1 is used; any further bands are ignored.
band = rasterdata.GetRasterBand(1)
bandtype = gdal.GetDataTypeName(band.DataType)
print('')
print('[Band type is]' + bandtype + '.')
# Use the singular only for exactly one kilometer. Every other width,
# including 0 and fractional widths, gets the plural.
if edge_effect_width == 1:
    unit = ' kilometer.'
else:
    unit = ' kilometers.'
print('Width of edge is ' + str(edge_effect_width) + unit)
print('')
# --- Monte Carlo estimate of the required number of sample points -------
# Each simulation draws onetime * lr * lr random cells by stratified sampling:
# one cell per subgrid per round. It tracks the running mean after every
# point and records the smallest sample size from which the running mean
# stays within ci (relative) of that simulation's final mean.

# Pixel offsets must be integers for random.randint.
edge_px = int(edge)
# Sampling area left after removing the edge strip on every side.
new_cols = cols - 2 * edge_px
new_rows = rows - 2 * edge_px
# Number of subgrids along each side of the square stratification grid.
lr = int(round(sqrt(stratify)))
if lr * lr != stratify:
    sys.exit('stratify must be a perfect square, got ' + str(stratify))
# Subgrid size in pixels. Columns and rows are sized separately so that a
# non-square map never produces row offsets outside the raster.
lrpixel = new_cols // lr
lrpixel_rows = new_rows // lr
if lrpixel < 1 or lrpixel_rows < 1:
    sys.exit('Edge too wide or too many subgrids for this map.')
print('Study area map is divided to ' + str(stratify) + ' subgrids')
print('')


def _sample_running_means():
    """Draw one simulation's stratified random sample.

    Returns a list with the running mean of the sampled pixel values after
    each point, of length onetime * lr * lr.
    """
    running_means = []
    running_total = 0.0
    count = 0
    for _ in range(onetime):
        for i in range(lr):
            for j in range(lr):
                x = random.randint(edge_px + i * lrpixel,
                                   edge_px + (i + 1) * lrpixel - 1)
                y = random.randint(edge_px + j * lrpixel_rows,
                                   edge_px + (j + 1) * lrpixel_rows - 1)
                # ReadAsArray decodes the pixel using the band's real data
                # type. The first byte of the raw ReadRaster buffer is only
                # correct for single-byte bands.
                value = band.ReadAsArray(x, y, 1, 1)[0, 0]
                # Accumulate in a Python float. This avoids NumPy small-int
                # overflow and Python 2 integer division, and a running total
                # replaces re-summing every value for each point.
                running_total += float(value)
                count += 1
                running_means.append(running_total / count)
    return running_means


def _required_points(running_means):
    """Return the smallest sample size n whose running means stay reliable.

    The running means after n, n+1, ..., N points must all lie within
    ci * |final mean| of the final mean. Returns 1 when every running mean
    is already within tolerance.
    """
    final_mean = running_means[-1]
    # Compare against a product instead of dividing by the final mean, which
    # would raise ZeroDivisionError when that mean is 0.
    threshold = ci * abs(final_mean)
    last_violation = -1
    for index, mean in enumerate(running_means):
        if abs(mean - final_mean) > threshold:
            last_violation = index
    # running_means[k] is the mean of k + 1 points, so the first reliable
    # sample size is one point past the last out-of-tolerance estimate.
    return last_violation + 2


# One required-sample-size value per simulation. Each simulation starts
# fresh, so no value carries over from the previous one.
number_points = [_required_points(_sample_running_means())
                 for _ in range(nsimulation)]
# --- Summarise the per-simulation sample sizes ----------------------------
if not number_points:
    sys.exit('No simulations were run (nsimulation is 0); nothing to summarise.')
# Mean number of points required. float() forces true division on Python 2.
a = float(sum(number_points)) / len(number_points)
# Number of simulations summarised.
b = len(number_points)
# Median number of points required. np.median averages the two middle
# values when the count is even instead of taking the upper-middle one.
c = np.median(number_points)
print('The mean number of points is ' + str(int(a)))
print('The median number of points is ' + str(int(c)))
# Save one integer count per line rather than scientific-notation floats.
np.savetxt('number_points.csv', number_points, delimiter=',', fmt='%d')
# --- Plot per-simulation sample sizes (top) and their distribution --------
plt.figure(1)
plt.subplot(211)
plt.plot(number_points, '.b-')
plt.xlabel('Simulation Time Step', fontsize=12)
plt.ylabel('Number of Points', fontsize=14)
plt.subplot(212)
# density=True replaces normed=1, which Matplotlib deprecated in 2.1 and
# removed in 3.1 (passing it now raises). The bars are a normalized density.
plt.hist(number_points, 100, density=True, facecolor='b', alpha=0.5)
plt.xlabel('Number of Points', fontsize=12)
plt.ylabel('Frequency', fontsize=12)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment