Skip to content

Instantly share code, notes, and snippets.

@rcoup
Created March 26, 2012 20:23
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rcoup/2209392 to your computer and use it in GitHub Desktop.
Save rcoup/2209392 to your computer and use it in GitHub Desktop.
gdal_retile parallel experiment
--- swig/python/scripts/gdal_retile.py 2011-11-30 14:28:34.528230721 +1300
+++ swig/python/scripts/gdal_retile.py 2011-12-02 14:25:54.106882681 +1300
@@ -44,6 +44,32 @@
import os
import math
+from Queue import Queue
+
+class DummyPool(object):
+ def __init__(self, *args, **kwargs):
+ pass
+ def close(self):
+ pass
+ def join(self):
+ pass
+ def apply_async(self, f, args=[], kwargs={}):
+ f(*args, **kwargs)
+
+class DummyLock(object):
+ def __enter__(self):
+ pass
+ def __exit__(self, typ, value, traceback):
+ pass
+
+try:
+ from multiprocessing.dummy import Pool as _Pool
+ from multiprocessing.dummy import Lock as _Lock
+except:
+ # drop-in replacements for systems without multiprocessing.dummy support
+ _Pool = DummyPool
+ _Lock = DummyLock
+
class AffineTransformDecorator:
""" A class providing some usefull methods for affine Transformations """
def __init__(self, transform ):
@@ -74,9 +100,9 @@
class DataSetCache:
""" A class for caching source tiles """
- def __init__(self ):
- self.cacheSize=8
- self.queue=[]
+ def __init__(self, cacheSize=8):
+ self.cacheSize=cacheSize
+ self.queue=Queue()
self.dict={}
def get(self,name ):
@@ -86,19 +112,12 @@
if result is None:
print("Error openenig:%s" % NameError)
sys.exit(1)
- if len(self.queue)==self.cacheSize:
- toRemove = self.queue.pop(0)
+ if self.queue.qsize()>=self.cacheSize:
+ toRemove = self.queue.get_nowait()
del self.dict[toRemove]
- self.queue.append(name)
+ self.queue.put_nowait(name)
self.dict[name]=result
return result
- def __del__(self):
- for name, dataset in self.dict.items():
- del dataset
- del self.queue
- del self.dict
-
-
class tile_info:
""" A class holding info how to tile """
@@ -144,8 +163,8 @@
"""
self.TempDriver=gdal.GetDriverByName("MEM")
self.filename = filename
- self.cache = DataSetCache()
self.ogrTileIndexDS = inputDS
+ self.cache = DataSetCache(8)
self.ogrTileIndexDS.GetLayer().ResetReading()
feature = self.ogrTileIndexDS.GetLayer().GetNextFeature()
@@ -179,10 +198,7 @@
self.xsize = int(round((self.lrx-self.ulx) / self.scaleX))
self.ysize = abs(int(round((self.uly-self.lry) / self.scaleY)))
-
- def __del__(self):
- del self.cache
- del self.ogrTileIndexDS
+ self.cache.cacheSize = int(self.xsize / float(TileWidth) * 6) + 8
def getDataSet(self,minx,miny,maxx,maxy):
@@ -259,7 +275,7 @@
return resultDS
def closeDataSet(self, memDS):
- del memDS
+ pass
#self.TempDriver.Delete("TEMP")
@@ -277,10 +293,8 @@
if Verbose:
from sys import version_info
- if version_info >= (3,0,0):
- exec('print("Building internal Index for %d tile(s) ..." % len(inputTiles), end=" ")')
- else:
- exec('print "Building internal Index for %d tile(s) ..." % len(inputTiles), ')
+ sys.stdout.write("Building internal Index for %d tile(s) ..." % len(inputTiles))
+ sys.stdout.flush()
ogrTileIndexDS = createTileIndex("TileIndex",TileIndexFieldName,None,driverTyp);
for inputTile in inputTiles:
@@ -293,7 +307,6 @@
points = dec.pointsFor(fhInputTile.RasterXSize, fhInputTile.RasterYSize)
addFeature(ogrTileIndexDS,inputTile,points[0],points[1])
- del fhInputTile
if Verbose:
print("finished")
@@ -312,6 +325,24 @@
+def tileImage2(minfo, ti, xIndex, yIndex, mutex, OGRDS):
+ offsetY=(yIndex-1)* ti.tileHeight
+ offsetX=(xIndex-1)* ti.tileWidth
+ if yIndex==ti.countTilesY:
+ height=ti.lastTileHeight
+ else:
+ height=ti.tileHeight
+
+ if xIndex==ti.countTilesX:
+ width=ti.lastTileWidth
+ else:
+ width=ti.tileWidth
+ if UseDirForEachRow :
+ tilename=getTileName(minfo,ti, xIndex, yIndex,0)
+ else:
+ tilename=getTileName(minfo,ti, xIndex, yIndex)
+ createTile(minfo, offsetX, offsetY, width, height,tilename, mutex, OGRDS)
+
def tileImage(minfo, ti ):
"""
@@ -320,7 +351,6 @@
returns list of created tiles
"""
-
global LastRowIndx
LastRowIndx=-1
OGRDS=createTileIndex("TileResult_0", TileIndexFieldName, Source_SRS,TileIndexDriverTyp)
@@ -329,24 +359,13 @@
yRange = list(range(1,ti.countTilesY+1))
xRange = list(range(1,ti.countTilesX+1))
+ pool = Pool()
+ mutex = Lock()
for yIndex in yRange:
for xIndex in xRange:
- offsetY=(yIndex-1)* ti.tileHeight
- offsetX=(xIndex-1)* ti.tileWidth
- if yIndex==ti.countTilesY:
- height=ti.lastTileHeight
- else:
- height=ti.tileHeight
-
- if xIndex==ti.countTilesX:
- width=ti.lastTileWidth
- else:
- width=ti.tileWidth
- if UseDirForEachRow :
- tilename=getTileName(minfo,ti, xIndex, yIndex,0)
- else:
- tilename=getTileName(minfo,ti, xIndex, yIndex)
- createTile(minfo, offsetX, offsetY, width, height,tilename,OGRDS)
+ pool.apply_async(tileImage2, [minfo, ti, xIndex, yIndex, mutex, OGRDS])
+ pool.close()
+ pool.join()
if TileIndexName is not None:
@@ -406,7 +425,7 @@
-def createPyramidTile(levelMosaicInfo, offsetX, offsetY, width, height,tileName,OGRDS):
+def createPyramidTile(levelMosaicInfo, offsetX, offsetY, width, height,tileName,mutex,OGRDS):
sx= levelMosaicInfo.scaleX*2
sy= levelMosaicInfo.scaleY*2
@@ -424,7 +443,8 @@
if OGRDS is not None:
points = dec.pointsFor(width, height)
- addFeature(OGRDS, tileName, points[0], points[1])
+ with mutex:
+ addFeature(OGRDS, tileName, points[0], points[1])
if BandType is None:
@@ -472,7 +492,7 @@
-def createTile( minfo, offsetX,offsetY,width,height, tilename,OGRDS):
+def createTile( minfo, offsetX,offsetY,width,height, tilename,mutex,OGRDS):
"""
Create tile
@@ -488,7 +508,6 @@
dec = AffineTransformDecorator([minfo.ulx,minfo.scaleX,0,minfo.uly,0,minfo.scaleY])
-
s_fh = minfo.getDataSet(dec.ulx+offsetX*dec.scaleX,dec.uly+offsetY*dec.scaleY+height*dec.scaleY,
dec.ulx+offsetX*dec.scaleX+width*dec.scaleX,
dec.uly+offsetY*dec.scaleY)
@@ -502,7 +521,8 @@
if OGRDS is not None:
dec2 = AffineTransformDecorator(geotransform)
points = dec2.pointsFor(width, height)
- addFeature(OGRDS, tilename, points[0], points[1])
+ with mutex:
+ addFeature(OGRDS, tilename, points[0], points[1])
@@ -615,27 +635,34 @@
inputDS=buildPyramidLevel(levelMosaicInfo,levelOutputTileInfo,level)
+def buildPyramidLevel2(levelMosaicInfo,levelOutputTileInfo, xIndex, yIndex, level, mutex, OGRDS):
+ offsetY=(yIndex-1)* levelOutputTileInfo.tileHeight
+ offsetX=(xIndex-1)* levelOutputTileInfo.tileWidth
+ if yIndex==levelOutputTileInfo.countTilesY:
+ height=levelOutputTileInfo.lastTileHeight
+ else:
+ height=levelOutputTileInfo.tileHeight
+
+ if xIndex==levelOutputTileInfo.countTilesX:
+ width=levelOutputTileInfo.lastTileWidth
+ else:
+ width=levelOutputTileInfo.tileWidth
+ tilename=getTileName(levelMosaicInfo,levelOutputTileInfo, xIndex, yIndex,level)
+ createPyramidTile(levelMosaicInfo, offsetX, offsetY, width, height,tilename,mutex,OGRDS)
+
def buildPyramidLevel(levelMosaicInfo,levelOutputTileInfo, level):
yRange = list(range(1,levelOutputTileInfo.countTilesY+1))
xRange = list(range(1,levelOutputTileInfo.countTilesX+1))
OGRDS=createTileIndex("TileResult_"+str(level), TileIndexFieldName, Source_SRS,TileIndexDriverTyp)
+ pool = Pool()
+ mutex = Lock()
for yIndex in yRange:
for xIndex in xRange:
- offsetY=(yIndex-1)* levelOutputTileInfo.tileHeight
- offsetX=(xIndex-1)* levelOutputTileInfo.tileWidth
- if yIndex==levelOutputTileInfo.countTilesY:
- height=levelOutputTileInfo.lastTileHeight
- else:
- height=levelOutputTileInfo.tileHeight
-
- if xIndex==levelOutputTileInfo.countTilesX:
- width=levelOutputTileInfo.lastTileWidth
- else:
- width=levelOutputTileInfo.tileWidth
- tilename=getTileName(levelMosaicInfo,levelOutputTileInfo, xIndex, yIndex,level)
- createPyramidTile(levelMosaicInfo, offsetX, offsetY, width, height,tilename,OGRDS)
+ pool.apply_async(buildPyramidLevel2, [levelMosaicInfo,levelOutputTileInfo, xIndex, yIndex, level, mutex, OGRDS])
+ pool.close()
+ pool.join()
if TileIndexName is not None:
@@ -698,6 +725,7 @@
print(' [-s_srs srs_def] [-pyramidOnly] -levels numberoflevels')
print(' [-r {near/bilinear/cubic/cubicspline/lanczos}]')
print(' [-useDirForEachRow]')
+ print(' [-multi]')
print(' -targetDir TileDirectory input_files')
# =============================================================================
@@ -732,6 +760,7 @@
global Levels
global PyramidOnly
global UseDirForEachRow
+ global Multithreading
gdal.AllRegister()
@@ -831,6 +860,8 @@
CsvDelimiter=argv[i]
elif arg == '-useDirForEachRow':
UseDirForEachRow=True
+ elif arg == '-multi':
+ Multithreading=True
elif arg[:1] == '-':
print('Unrecognised command option: %s' % arg)
Usage()
@@ -854,6 +885,16 @@
Usage()
return 1
+ global Pool
+ global Lock
+
+ if Multithreading is True:
+ Pool = _Pool
+ Lock = _Lock
+ else:
+ Pool = DummyPool
+ Lock = DummyLock
+
# create level 0 directory if needed
if(UseDirForEachRow and PyramidOnly==False) :
leveldir=TargetDir+str(0)+os.sep
@@ -944,6 +985,7 @@
global PyramidOnly
global LastRowIndx
global UseDirForEachRow
+ global Multithreading
Verbose=False
@@ -969,6 +1011,7 @@
PyramidOnly=False
LastRowIndx=-1
UseDirForEachRow=False
+ Multithreading=False
@@ -995,7 +1038,7 @@
PyramidOnly=False
LastRowIndx=-1
UseDirForEachRow=False
-
+Multithreading=False
if __name__ == '__main__':
sys.exit(main(sys.argv))
@rcoup
Copy link
Author

rcoup commented Mar 26, 2012

Not production code, use at your own risk, etc. :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment