Skip to content

Instantly share code, notes, and snippets.

@mthh
Last active August 31, 2015 13:37
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mthh/bc71aa67ac0521297a97 to your computer and use it in GitHub Desktop.
Save mthh/bc71aa67ac0521297a97 to your computer and use it in GitHub Desktop.
# -*- coding: utf-8 -*-
# The MIT License (MIT)
#
# « Copyright (c) 2015, mthh
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import platform
import sys
import os.path
from sqlite3 import dbapi2 as db
import csv
from math import cos, sin, radians, asin, sqrt
from time import time
try:
from osgeo import ogr
except:
import ogr
def read_row(input_path, input_type):
    """Read the input point layer and return its rows as (id, x, y) tuples.

    Args:
        input_path: Path to the file holding the points.
        input_type: 1 for a .csv file (one header line, then the id, x and y
            columns), 2 for a point layer readable by OGR (e.g. a shapefile).

    Returns:
        A list of ``(id, x, y)`` tuples. x and y are always strings. The id
        is a stripped string for csv input, and the raw value of the first
        attribute field for OGR input.

    Raises:
        ValueError: If ``input_type`` is neither 1 nor 2.
        SystemExit: If the OGR datasource cannot be opened or is not a
            point layer.
    """
    if input_type == 1:  # The datasource is a .csv
        # newline='' is what the csv module expects, and the context manager
        # guarantees the file handle is released.
        with open(input_path, encoding='utf-8', newline='') as f:
            reader = csv.reader(f)
            next(reader, None)  # Skip the header; tolerate an empty file
            # csv.reader yields [] for blank lines: ignore them.
            return [(row[0].strip(), row[1].strip(), row[2].strip())
                    for row in reader if row]

    if input_type == 2:  # The datasource is a shapefile
        datasource = ogr.Open(input_path)
        if datasource is None:  # ogr.Open signals failure by returning None
            sys.exit('Unable to open {}\n'.format(input_path))
        layer = datasource.GetLayer(0)
        if layer.GetGeomType() != ogr.wkbPoint:
            sys.exit('Input must be a POINTS layer\n')
        liste_coord = []
        for feature in layer:  # Assuming the 1st attribute field is an id
            geom = feature.geometry()
            liste_coord.append((feature.GetField(0),
                                str(geom.GetX()), str(geom.GetY())))
        return liste_coord

    raise ValueError('Unknown input_type: {!r}'.format(input_type))
def write_row(output_path, liste):
    """Write the snapped points to a csv file and report where it was saved.

    Args:
        output_path: Path of the csv file to create (overwritten if present).
        liste: Sequence of rows, either ``(id, x, y)`` or
            ``(id, x, y, distance)``. The header is chosen from the length
            of the first row; an empty sequence produces a header-only file.
    """
    header = ['ID', 'x', 'y']
    if liste and len(liste[0]) != 3:
        header.append('distance')
    # newline='' prevents the csv module from emitting blank lines between
    # rows on Windows; utf-8 mirrors the encoding used when reading the input.
    with open(output_path, 'w', encoding='utf-8', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(header)
        writer.writerows(liste)
    print('Results saved in {}'.format(output_path))
def haversine_dist(pt1, pt2, radius=6367000):
    """Return the great circle distance between two points on a sphere.

    Args:
        pt1: ``(longitude, latitude)`` of the first point, in decimal
            degrees; each value may be a number or a numeric string.
        pt2: ``(longitude, latitude)`` of the second point, same format.
        radius: Radius of the sphere. The result is expressed in the same
            unit (the default is an Earth radius in meters).

    Returns:
        The distance along the sphere, as a float.
    """
    lon1, lat1, lon2, lat2 = (radians(float(value))
                              for value in (pt1[0], pt1[1], pt2[0], pt2[1]))
    dlon = lon2 - lon1
    dlat = lat2 - lat1
    a = sin(dlat/2)**2 + cos(lat1) * cos(lat2) * sin(dlon/2)**2
    # Rounding can push sqrt(a) marginally above 1 for (near-)antipodal
    # points, which would make asin() raise a ValueError: clamp it.
    c = 2 * asin(min(1.0, sqrt(a)))
    return c * radius
class DbLineHandler():
    """Snap points on the lines of a SpatiaLite table.

    The handler opens the SQLite database, loads the SpatiaLite extension,
    then works through temporary tables: the input points, a search polygon
    around each point, the closest line of each point and finally the
    coordinates of the closest location on that line.
    """
    ######################
    # Config options / can be untouched (or not)
    INFO = True  # To display basic info about SQLite/SpatiaLite shared library
    TMP_TABLE0 = 'tmp_pts'  # Table for inputs points
    TMP_TABLE1 = 'tmp_dist_min'  # Intermediate table
    TMP_TABLE2 = 'tmp_closest_points'  # Result table
    BBOX_TABLE = 'tmp_bbox'  # The search_frame for the nearest neighbour query
    ######################

    def __init__(self, options):
        """Connect to the database and prepare the network table.

        Args:
            options: Namespace providing ``db_path``, ``table_name``,
                ``libspatialite``, ``syst`` (lower-cased platform name) and
                ``cachesize``. When ``db_path`` is ``':memory:'``,
                ``table_name`` is the path of the shapefile to load.

        Raises:
            SystemExit: If the connection fails, the extension cannot be
                loaded or the network table is not usable.
        """
        self.options = options
        self.TABLE = options.table_name
        try:
            self.conn = db.connect(options.db_path)
        except db.Error:
            sys.exit('Failed to connect to the database')

        try:
            self._load_spatialite(options)
        except (db.OperationalError, AttributeError) as err:
            # AttributeError: this sqlite3 module was built without support
            # for loadable extensions (enable_load_extension is missing).
            self.conn.close()
            sys.exit('Unable to load spatialite extension\n{}'.format(err))

        if options.db_path == ':memory:':
            self.load_shapefile(options.table_name)
        else:
            self.find_geomcol_spatialindex()

        c = self.conn.cursor()
        c.execute('''PRAGMA cache_size={};'''.format(
            int(self.options.cachesize)))
        if self.INFO:
            print("SQLite cache_size : {} page_size : {}".format(
                [row[0] for row in c.execute('''PRAGMA cache_size;''')],
                [row[0] for row in c.execute('''PRAGMA page_size;''')]))
            vs = c.execute("""SELECT spatialite_version()""").fetchone()[0]
            vgeos = c.execute("""SELECT geos_version()""").fetchone()[0]
            print('Spatialite {} (running against GEOS {})'.format(vs, vgeos))
        c.close()
        self.conn.commit()

    def _load_spatialite(self, options):
        """Load the SpatiaLite shared library into the current connection."""
        load_sql = ("SELECT load_extension(?, "
                    "'sqlite3_modspatialite_init');")
        self.conn.enable_load_extension(True)
        if 'windo' not in options.syst:
            self.conn.execute(load_sql, (options.libspatialite,))
            return
        # On Windows the library is loaded by file name from its own folder
        # (presumably so that the DLLs shipped next to it are found).
        current_folder = os.getcwd()
        fold, lib = os.path.split(options.libspatialite)
        if fold:
            os.chdir(fold)
        try:
            self.conn.execute(load_sql, (lib,))
        finally:
            # Restore the working directory even when the load fails.
            os.chdir(current_folder)

    def find_geomcol_spatialindex(self):
        """Look up the geometry column of the network table.

        Sets ``col_geom``, ``type_geom`` and ``sp_index`` from the
        ``geometry_columns`` table.

        Raises:
            SystemExit: If the table is not registered, is not a
                (multi)linestring table or has no spatial index.
        """
        conn = self.conn
        row = conn.execute('''SELECT f_geometry_column, geometry_type,
                              spatial_index_enabled
                              FROM geometry_columns
                              WHERE f_table_name LIKE ?;''',
                           (self.TABLE,)).fetchone()
        if row is not None:  # None means the table is not registered
            self.col_geom, self.type_geom, self.sp_index = row
            # Only indexed tables of (multi)linestring are supported here.
            if self.type_geom in (2, 5) and self.sp_index == 1:
                return
        conn.close()
        sys.exit('No table found, wrong geometry type or no spatial index')

    def load_shapefile(self, input_shp):
        """Load a line shapefile into the table ``tmp_import``.

        The geometries are stored with SRID 4326, registered as a geometry
        column and indexed. ``self.TABLE`` and ``self.col_geom`` are updated
        to point at the imported table.

        Args:
            input_shp: Path to the shapefile holding the network.

        Raises:
            SystemExit: If the file cannot be opened or is not a line layer.
        """
        conn = self.conn
        c = conn.cursor()
        self.TABLE = 'tmp_import'
        print('Creating memory database and initializing Spatial MetaData...')
        c.execute("""SELECT InitSpatialMetaData();""")
        print('Loading {} in memory...'.format(input_shp))
        datasource = ogr.Open(input_shp)
        if datasource is None:  # ogr.Open signals failure by returning None
            conn.close()
            sys.exit('Unable to open {}\n'.format(input_shp))
        layer = datasource.GetLayer(0)
        if layer.GetGeomType() != ogr.wkbLineString:
            conn.close()
            sys.exit('Input must be a LINE layer\n')
        c.executescript("""DROP TABLE IF EXISTS tmp_import;
                           CREATE TABLE tmp_import
                           (PK_UID INTEGER PRIMARY KEY AUTOINCREMENT,
                           geometry LINESTRING);""")
        # Bound parameters avoid building one huge SQL script in memory.
        c.executemany("""INSERT INTO tmp_import (geometry)
                         VALUES (GeomFromText(?, 4326));""",
                      ((feature.geometry().ExportToWkt(),)
                       for feature in layer))
        conn.commit()
        print('Creating geometry column and spatial index...')
        c.execute("""SELECT RecoverGeometryColumn('{}',
                     'geometry', 4326, 'Linestring', 'XY');"""
                  .format(self.TABLE))
        c.execute("""SELECT CreateSpatialIndex('{}',
                     'geometry');""".format(self.TABLE))
        print('Ready to work...')
        conn.commit()
        self.col_geom = 'geometry'
        c.close()

    def close_db(self, skip_vacuum=False):
        """Drop the temporary tables, optionally VACUUM, then close the db.

        Args:
            skip_vacuum: When true, no VACUUM is run. It is never run on an
                in-memory database.
        """
        conn = self.conn
        try:
            with conn:
                # IF EXISTS lets the cleanup go through even when a previous
                # step failed before creating every temporary table.
                conn.executescript("""DELETE FROM geometry_columns
                                      WHERE f_table_name = '{0}';
                                      DELETE FROM geometry_columns
                                      WHERE f_table_name = '{1}';
                                      DROP TABLE IF EXISTS idx_{0}_GEOMETRY;
                                      DROP TABLE IF EXISTS idx_{1}_GEOMETRY;
                                      """.format(self.TMP_TABLE0,
                                                 self.BBOX_TABLE))
                conn.executescript("""DROP TABLE IF EXISTS {0};
                                      DROP TABLE IF EXISTS {1};
                                      DROP TABLE IF EXISTS {2};
                                      DROP TABLE IF EXISTS {3};""".format(
                                          self.TMP_TABLE0, self.TMP_TABLE1,
                                          self.TMP_TABLE2, self.BBOX_TABLE))
            if not skip_vacuum and self.options.db_path != ':memory:':
                timer = time()
                print('Now performing a VACUUM on the database...', end='')
                conn.execute("VACUUM;")
                print(' {:.2f}s'.format(time()-timer))
        except db.Error:
            print('An error occured when deleting temporary tables...')
        finally:
            conn.commit()
            conn.close()  # ..and close the db

    def import_row_db(self, input_rows, buff_d=0.075):
        """Insert the input points and their search polygons in the db.

        Args:
            input_rows: Iterable of tuples ``(uid, x, y)``.
            buff_d: Radius of the buffer built around each point, in the
                unit of the coordinates; the buffer is what the nearest
                line query uses as its search frame.

        Returns:
            The number of inserted rows, or -1 if a database error occurred.
        """
        conn = self.conn
        points = []
        frames = []
        for coord in input_rows:
            uid = str(coord[0])
            points.append((float(coord[1]), float(coord[2]),
                           uid, str(coord[1]), str(coord[2])))
            geom_pt = ogr.Geometry(ogr.wkbPoint)
            geom_pt.AddPoint(float(coord[1]), float(coord[2]))
            frames.append((geom_pt.Buffer(buff_d).ExportToWkt(), uid))

        create_tables = """DROP TABLE IF EXISTS {0};
            CREATE TABLE {0} (PK_UID INTEGER PRIMARY KEY AUTOINCREMENT,
            GEOMETRY POINT, id_ref TEXT, _x TEXT, _y TEXT);
            DROP TABLE IF EXISTS {1};
            CREATE TABLE {1} (PK_UID INTEGER PRIMARY KEY AUTOINCREMENT,
            GEOMETRY POLYGON, id_ref TEXT);""".format(self.TMP_TABLE0,
                                                      self.BBOX_TABLE)
        # Making geometry column valid and creating spatials index :
        recover_geom = ("""SELECT RecoverGeometryColumn('{0}', 'GEOMETRY',"""
                        """4326, 'POINT', 'XY');"""
                        """SELECT CreateSpatialIndex('{0}', 'GEOMETRY');"""
                        """SELECT RecoverGeometryColumn('{1}', 'GEOMETRY',"""
                        """4326, 'POLYGON', 'XY');"""
                        """SELECT CreateSpatialIndex('{1}', 'GEOMETRY');"""
                        ).format(self.TMP_TABLE0, self.BBOX_TABLE)
        try:
            with conn:
                conn.executescript(create_tables)
                # Values are bound rather than formatted into the SQL text,
                # so an id holding a quote cannot break the statement.
                conn.executemany(
                    """INSERT INTO {} (GEOMETRY, id_ref, _x, _y)
                       VALUES (MakePoint(?, ?, 4326), ?, ?, ?);"""
                    .format(self.TMP_TABLE0), points)
                conn.executemany(
                    """INSERT INTO {} (GEOMETRY, id_ref)
                       VALUES (GeomFromText(?, 4326), ?);"""
                    .format(self.BBOX_TABLE), frames)
            conn.executescript(recover_geom)
            conn.commit()
        except db.Error as er:
            print('Error while creating tables : {} '
                  '\n A rollback has been performed\n'.format(er))
            return -1
        return len(points)

    def find_nearest(self):
        """Snap every imported point on its closest line.

        Returns:
            A tuple ``(snapped, errors)``: ``snapped`` lists the
            ``(id_ref, x_new, y_new)`` rows that got new coordinates and
            ``errors`` the rows that did not. On a database error the tuple
            is ``([], -1)``.
        """
        conn = self.conn
        # Compute the distance between each point and every line found in
        # the point search_frame, then keep the ID of the closest line :
        transac_l = ["""BEGIN TRANSACTION;
            DROP TABLE IF EXISTS {0};
            CREATE TABLE {0} AS
            SELECT a.ROWID as l_id, b.PK_UID as pt_id,
            Min(ST_Distance(a.{1}, b.GEOMETRY)) as distance
            FROM {2} b, {3} a, {4} c
            WHERE b.PK_UID = c.PK_UID AND a.ROWID in (
            SELECT ROWID FROM SpatialIndex
            WHERE f_table_name = '{3}' AND search_frame = c.GEOMETRY)
            GROUP BY pt_id
            ORDER BY pt_id;""".format(self.TMP_TABLE1, self.col_geom,
                                      self.TMP_TABLE0, self.TABLE,
                                      self.BBOX_TABLE)]
        # For each input point, compute the coordinates of the closest point
        # laying on the line previously identified :
        transac_l.append("""DROP TABLE IF EXISTS {0};
            CREATE TABLE {0} AS
            SELECT t.pt_id as pt_id,
            X(ST_ClosestPoint(a.{1}, b.GEOMETRY)) as x,
            Y(ST_ClosestPoint(a.{1}, b.GEOMETRY)) as y
            FROM {2} t, {3} a, {4} b
            WHERE a.ROWID = t.l_id AND t.pt_id = b.PK_UID
            GROUP BY t.ROWID;""".format(self.TMP_TABLE2, self.col_geom,
                                        self.TMP_TABLE1, self.TABLE,
                                        self.TMP_TABLE0))
        # Populate the points table with the line-snapped coordinates :
        transac_l.append("""ALTER TABLE {0} ADD COLUMN x_new TEXT;
            ALTER TABLE {0} ADD COLUMN y_new TEXT;
            UPDATE {0} SET x_new =
            (SELECT x FROM {1} WHERE {0}.PK_UID = {1}.pt_id);
            UPDATE {0} SET y_new =
            (SELECT y FROM {1} WHERE {0}.PK_UID = {1}.pt_id);
            COMMIT TRANSACTION;""".format(self.TMP_TABLE0,
                                          self.TMP_TABLE2))
        try:
            with conn:
                conn.executescript(''.join(transac_l))
        except db.Error as er:
            print('Error while creating tables : {} '
                  '\n A rollback has been performed\n'.format(er))
            # An empty list, so that callers do not mistake a failure for
            # one snapped point.
            return [], -1

        # Retrieve the result :
        query = ("SELECT id_ref, x_new, y_new"
                 " FROM {} WHERE x_new is Not NULL;").format(self.TMP_TABLE0)
        liste_result = conn.execute(query).fetchall()
        # And potential errors :
        query = ("SELECT id_ref, x_new, y_new"
                 " FROM {} WHERE x_new is NULL"
                 " OR y_new is NULL;").format(self.TMP_TABLE0)
        liste_error = conn.execute(query).fetchall()
        conn.commit()
        return liste_result, liste_error
if __name__ == "__main__":
    import argparse

    # Default SpatiaLite shared library, depending on the platform.
    syst = platform.system().lower()
    if 'windo' in syst:
        SPATIALITE_LIBRARY = 'mod_spatialite.dll'
    elif 'inux' in syst:
        SPATIALITE_LIBRARY = '/usr/local/lib/mod_spatialite.so'
    elif 'darwin' in syst:
        SPATIALITE_LIBRARY = 'mod_spatialite.dylib'
    else:
        # Unknown platform: keep a default so that -l can still be used
        # instead of failing with a NameError below.
        SPATIALITE_LIBRARY = 'mod_spatialite.so'

    desc = ("Script to snap a set of points (csv/shp)"
            " on a network (shp/spatialite)")
    argPr = argparse.ArgumentParser(description=desc)
    argPr.add_argument("input_path",
                       help="Path to input .csv or .shp file of points")
    argPr.add_argument("db_path", help=("Path to the SQLite/Shapefile "
                                        "containing the reference network"))
    argPr.add_argument("table_name", help=("Name of the network table (put 'me"
                                           "mory' here if using a shapefile)"))
    argPr.add_argument("-l", "--libspatialite", default=SPATIALITE_LIBRARY,
                       help=("Set manually the path to the spatialite library"
                             " (full path is preferred)"))
    argPr.add_argument("-c", "--cachesize", default=2048000, type=int,
                       help="Set the SQLite cache size (default: 2048000)")
    argPr.add_argument("-o", "--output", action='store',
                       help=("Chose the name of the ouputfile "
                             "(default : input_name-snapped.csv)"))
    argPr.add_argument("-s", "--skipvacuum",
                       help="Skip the DB VACUUM after processing the points",
                       action='store_true')
    argPr.add_argument("-d", "--distance", action='store_true',
                       help=("Fill a field with the distance (meter) between "
                             "the original point and the snapped one"))
    # No numeric default here: None tells "not provided" apart from any
    # value typed by the user.
    argPr.add_argument("-f", "--framesearch", action='store', type=float,
                       default=None,
                       help=("Change the search frame (in km) "
                             "(default : 0.072 degree, about 8km)"))
    options = argPr.parse_args()
    options.syst = syst

    # The input type is taken from the file extension, not from a substring
    # of the whole path (a folder named "csv" must not matter).
    wrong_input = 'Wrong input file (wrong extension or file don\'t exists)'
    if not os.path.exists(options.input_path):
        sys.exit(wrong_input)
    input_ext = os.path.splitext(options.input_path)[1].lower()
    if input_ext == '.csv':
        input_type = 1
    elif input_ext == '.shp':
        input_type = 2
    else:
        sys.exit(wrong_input)

    # A shapefile network is loaded in an in-memory database.
    if (options.db_path.lower().endswith('.shp')
            and 'mem' in options.table_name):
        options.table_name = options.db_path
        options.db_path = ':memory:'

    if options.output:
        output = options.output
    else:
        output = os.path.splitext(options.input_path)[0] + '-snapped.csv'

    if options.framesearch is None:
        options.framesearch = 0.072  # Already expressed in degrees
    elif options.framesearch <= 0:
        argPr.error('the search frame must be a positive number of km')
    else:
        # Poor km/degree conversion !
        options.framesearch = options.framesearch/111.11

    options.libspatialite = os.path.realpath(options.libspatialite)
    if 'wind' in syst and options.libspatialite != SPATIALITE_LIBRARY:
        options.libspatialite = str(options.libspatialite).replace('\\', '/')
    if not os.path.exists(options.libspatialite):
        sys.exit('Spatialite shared library not found')

    start_time = time()
    liste_coord = read_row(options.input_path, input_type)
    if not liste_coord:
        sys.exit('No point found in {}'.format(options.input_path))

    a = DbLineHandler(options)
    try:
        if a.import_row_db(liste_coord, buff_d=options.framesearch) == -1:
            sys.exit('Unable to import the input points in the database')
        result, error = a.find_nearest()
        if error == -1:  # The query failed: nothing was snapped
            result = []
        print('{}/{} points snapped (performed in {:.2f}s)'.format(
            len(result), len(liste_coord), time()-start_time))

        if options.distance:
            # Points that could not be snapped are absent from result, so
            # the original coordinates are matched by id, not by position.
            origins = {}
            for pt in liste_coord:
                origins.setdefault(str(pt[0]), []).append((pt[1], pt[2]))
            with_distance = []
            for snapped in result:
                candidates = origins.get(str(snapped[0]))
                if candidates:
                    dist = haversine_dist(candidates.pop(0),
                                          (snapped[1], snapped[2]))
                else:  # Should not happen: no original point for this id
                    dist = ''
                with_distance.append(
                    [snapped[0], snapped[1], snapped[2], dist])
            result = with_distance

        if len(result) > 0:
            write_row(output, result)
        else:
            print('Something went wrong : no snapped point to write')
    finally:
        # Always drop the temporary tables and close the connection.
        a.close_db(skip_vacuum=options.skipvacuum)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment