Last active
August 31, 2015 13:37
-
-
Save mthh/bc71aa67ac0521297a97 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
# The MIT License (MIT) | |
# | |
# « Copyright (c) 2015, mthh | |
# | |
# Permission is hereby granted, free of charge, to any person obtaining a copy | |
# of this software and associated documentation files (the "Software"), to deal | |
# in the Software without restriction, including without limitation the rights | |
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
# copies of the Software, and to permit persons to whom the Software is | |
# furnished to do so, subject to the following conditions: | |
# | |
# The above copyright notice and this permission notice shall be included in all | |
# copies or substantial portions of the Software. | |
# | |
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |
# SOFTWARE. | |
import platform | |
import sys | |
import os.path | |
from sqlite3 import dbapi2 as db | |
import csv | |
from math import cos, sin, radians, asin, sqrt | |
from time import time | |
try: | |
from osgeo import ogr | |
except: | |
import ogr | |
def read_row(input_path, input_type):
    """Read the points to snap from a .csv file or a point shapefile.

    input_path: path to the input file.
    input_type: 1 for a .csv file (the header row is skipped and the first
        three columns are taken as id, x, y), 2 for a shapefile whose first
        layer must be a point layer (the first attribute field is used as
        the id).

    Returns a list of (id, x, y) tuples with x and y as strings, or None
    for any other input_type. Exits the process if the shapefile layer is
    not a point layer.
    """
    if input_type == 1:  # ... if the datasource is a .csv :
        # newline='' is the mode the csv module expects; the `with` block
        # guarantees the file is closed (it used to be leaked).
        with open(input_path, encoding='utf-8', newline='') as f:
            reader = csv.reader(f)
            next(reader, None)  # skip the header (tolerates an empty file)
            return [(row[0].strip(), row[1].strip(), row[2].strip())
                    for row in reader if row]  # ignore blank lines
    elif input_type == 2:  # ... if the datasource is a shapefile :
        datasource = ogr.Open(input_path)
        layer = datasource.GetLayer(0)
        if layer.GetGeomType() != 1:  # 1 == ogr.wkbPoint
            sys.exit('Input must be a POINTS layer\n')
        liste_coord = []
        for feature in layer:  # ... assuming the 1st attribute field is an id
            geom = feature.geometry()
            liste_coord.append((feature.GetField(0),
                                str(geom.GetX()),
                                str(geom.GetY())))
        return liste_coord
def write_row(output_path, liste):
    """Write the (snapped) points to a csv file and report where.

    liste: sequence of rows, either (id, x, y) or (id, x, y, distance);
        the header is chosen from the length of the first row. An empty
        sequence produces a file holding only the 'ID,x,y' header.
    """
    header = ['ID', 'x', 'y']
    if liste and len(liste[0]) != 3:
        header.append('distance')
    # newline='' lets the csv module control line endings (no blank lines
    # between rows on Windows); utf-8 matches the encoding used to read
    # the input file.
    with open(output_path, 'w', encoding='utf-8', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(header)
        writer.writerows(liste)
    print('Results saved in {}'.format(output_path))
def haversine_dist(pt1, pt2, radius=6367000):
    """
    Calculate the great circle distance between two points
    on the earth (specified in decimal degrees).

    pt1, pt2: (lon, lat) pairs in decimal degrees (numbers or numeric
        strings).
    radius: radius of the sphere; the result is expressed in the same
        unit (default 6367000, i.e. meters).
    """
    lon1, lat1, lon2, lat2 = map(radians, (float(pt1[0]), float(pt1[1]),
                                           float(pt2[0]), float(pt2[1])))
    dlon = lon2 - lon1
    dlat = lat2 - lat1
    a = sin(dlat / 2) ** 2 + cos(lat1) * cos(lat2) * sin(dlon / 2) ** 2
    # Floating-point rounding can push `a` a hair above 1 for (nearly)
    # antipodal points, which would make asin(sqrt(a)) raise a
    # "math domain error": clamp it.
    c = 2 * asin(sqrt(min(1.0, a)))
    return c * radius
class DbLineHandler():
    """Snap points on the closest (multi)linestring of a SpatiaLite table.

    Workflow used by the command-line entry point: the constructor opens
    the database and loads SpatiaLite, import_row_db() loads the points,
    find_nearest() computes the snapped coordinates and close_db() drops
    the temporary tables and closes the connection.
    """
    ######################
    # Config options / can be untouched (or not)
    INFO = True  # To display basic info about SQLite/SpatiaLite shared library
    TMP_TABLE0 = 'tmp_pts'  # Table for inputs points
    TMP_TABLE1 = 'tmp_dist_min'  # Intermediate table
    TMP_TABLE2 = 'tmp_closest_points'  # Result table
    BBOX_TABLE = 'tmp_bbox'  # The search_frame for the nearest neighbour query
    ######################

    def __init__(self, options):
        """Connect to the database and load the SpatiaLite extension.

        options: argparse-like namespace; reads table_name, db_path, syst,
            libspatialite and cachesize. A db_path of ':memory:' means the
            network is loaded from the shapefile named by table_name.

        Exits the process if the connection or the extension loading fails.
        """
        self.options = options
        self.TABLE = options.table_name
        try:
            self.conn = db.connect(options.db_path)
        except db.Error:
            sys.exit('Failed to connect to the database')
        try:
            self.conn.enable_load_extension(True)
            if 'windo' in options.syst:
                # The library is loaded by its bare name from its own
                # folder (presumably so that the DLLs it depends on are
                # found next to it). The previous working directory is
                # restored even if the loading fails.
                current_folder = os.getcwd()
                os.chdir(os.path.dirname(options.libspatialite))
                try:
                    self.conn.execute(
                        "SELECT load_extension(?, "
                        "'sqlite3_modspatialite_init');",
                        (os.path.basename(options.libspatialite),))
                finally:
                    os.chdir(current_folder)
            else:
                self.conn.execute(
                    "SELECT load_extension(?, 'sqlite3_modspatialite_init');",
                    (options.libspatialite,))
        except db.OperationalError as err:
            self.conn.close()
            sys.exit('Unable to load spatialite extension\n{}'.format(err))
        if options.db_path == ':memory:':
            self.load_shapefile(options.table_name)
        else:
            self.find_geomcol_spatialindex()
        cur = self.conn.cursor()
        cur.execute('PRAGMA cache_size={};'.format(
            int(self.options.cachesize)))
        if self.INFO:
            cache_size = cur.execute('PRAGMA cache_size;').fetchone()[0]
            page_size = cur.execute('PRAGMA page_size;').fetchone()[0]
            print("SQLite cache_size : {} page_size : {}".format(
                cache_size, page_size))
            sp_version = cur.execute(
                "SELECT spatialite_version()").fetchone()[0]
            geos_version = cur.execute("SELECT geos_version()").fetchone()[0]
            print('Spatialite {} (running against GEOS {})'.format(
                sp_version, geos_version))
        cur.close()
        self.conn.commit()

    def find_geomcol_spatialindex(self):
        """Read the geometry metadata of self.TABLE from geometry_columns.

        Sets self.col_geom, self.type_geom and self.sp_index. Closes the
        connection and exits the process if the table is not registered,
        is not a (multi)linestring layer (geometry_type 2 or 5) or has no
        spatial index.
        """
        conn = self.conn
        row = conn.execute('''SELECT f_geometry_column, geometry_type,
                                     spatial_index_enabled
                              FROM geometry_columns
                              WHERE f_table_name LIKE ?;''',
                           (self.TABLE,)).fetchone()
        # fetchone() returns None for an unknown table: unpacking it
        # directly would raise a TypeError instead of the message below.
        if row is None:
            conn.close()
            sys.exit('No table found, wrong geometry type or no spatial index')
        self.col_geom, self.type_geom, self.sp_index = row
        if self.type_geom not in (2, 5) or self.sp_index != 1:
            # only indexed tables of (multi)linestring are supported here
            conn.close()
            sys.exit('No table found, wrong geometry type or no spatial index')

    def load_shapefile(self, input_shp):
        """Import the line layer of a shapefile into the (memory) database.

        The geometries are stored in table 'tmp_import' (self.TABLE is set
        to that name), registered as a 'Linestring' column with SRID 4326
        and spatially indexed. Closes the connection and exits the process
        if the file cannot be opened or its layer type is not 2 (line).
        """
        conn = self.conn
        cur = conn.cursor()
        self.TABLE = 'tmp_import'
        print('Creating memory database and initializing Spatial MetaData...')
        cur.execute("""SELECT InitSpatialMetaData();""")
        print('Loading {} in memory...'.format(input_shp))
        datasource = ogr.Open(input_shp)
        if datasource is None:  # ogr.Open returns None on failure
            conn.close()
            sys.exit('Unable to open {}'.format(input_shp))
        layer = datasource.GetLayer(0)
        if layer.GetGeomType() != 2:  # 2 == ogr.wkbLineString
            conn.close()
            sys.exit('Input must be a LINE layer\n')
        cur.executescript("""DROP TABLE IF EXISTS tmp_import;
            CREATE TABLE tmp_import
            (PK_UID INTEGER PRIMARY KEY AUTOINCREMENT,
             geometry LINESTRING);""")
        # A single transaction for all the inserts (commit on success,
        # rollback on error).
        with conn:
            conn.executemany(
                "INSERT INTO tmp_import (geometry) "
                "VALUES (GeomFromText(?, 4326));",
                ((feature.geometry().ExportToWkt(),) for feature in layer))
        print('Creating geometry column and spatial index...')
        cur.execute("SELECT RecoverGeometryColumn(?, 'geometry', 4326, "
                    "'Linestring', 'XY');", (self.TABLE,))
        cur.execute("SELECT CreateSpatialIndex(?, 'geometry');",
                    (self.TABLE,))
        print('Ready to work...')
        conn.commit()
        self.col_geom = 'geometry'
        cur.close()

    def close_db(self, skip_vacuum=False):
        """Drop the temporary tables, optionally VACUUM, then close the db.

        skip_vacuum: skip the VACUUM step (it is also skipped for an
            in-memory database). Database errors during the cleanup are
            reported, not raised; the connection is closed in all cases.
        """
        conn = self.conn
        try:
            with conn:
                # Unregister the temporary geometry columns and drop the
                # spatial index tables created by import_row_db ...
                conn.executescript("""DELETE FROM geometry_columns
                    WHERE f_table_name = '{0}';
                    DELETE FROM geometry_columns
                    WHERE f_table_name = '{1}';
                    DROP TABLE idx_{0}_GEOMETRY;
                    DROP TABLE idx_{1}_GEOMETRY;""".format(
                        self.TMP_TABLE0, self.BBOX_TABLE))
                # ... then the temporary tables themselves.
                conn.executescript("""DROP TABLE {0}; DROP TABLE {1};
                    DROP TABLE {2}; DROP TABLE {3}""".format(
                        self.TMP_TABLE0, self.TMP_TABLE1,
                        self.TMP_TABLE2, self.BBOX_TABLE))
            if not skip_vacuum and 'memory' not in self.options.db_path:
                # ... perform a vacuum ...
                timer = time()
                print('Now performing a VACUUM on the database...', end='')
                conn.execute("VACUUM;")
                print(' {:.2f}s'.format(time() - timer))
        except db.Error as err:
            print('An error occured when deleting temporary tables...'
                  '\n{}'.format(err))
        finally:
            conn.commit()
            conn.close()  # ..and close the db

    def import_row_db(self, input_rows, buff_d=0.075):
        """Load the points to snap into two temporary tables.

        input_rows: sequence of (uid, x, y) tuples, e.g.
            [(uid, x, y), (uid, x, y), ...]; x and y may be numbers or
            numeric strings.
        buff_d: radius of the buffer built around each point, in coordinate
            units (degrees for the SRID 4326 used here). The buffers are the
            search frames of the nearest-neighbour query.

        Returns the number of inserted points, or -1 if a database error
        occurred (the transaction is then rolled back).
        """
        conn = self.conn
        pt_rows = []
        bbox_rows = []
        for coord in input_rows:
            uid, x, y = coord[0], coord[1], coord[2]
            geom_pt = ogr.Geometry(ogr.wkbPoint)
            geom_pt.AddPoint(float(x), float(y))
            # Values are bound as query parameters (ids containing quotes
            # used to break the string-built SQL); id and coordinates are
            # kept in their text form in the id_ref/_x/_y columns.
            pt_rows.append((float(x), float(y), str(uid), str(x), str(y)))
            bbox_rows.append((geom_pt.Buffer(buff_d).ExportToWkt(), str(uid)))
        # Both tables are filled in the same order, so a point and its
        # search frame share the same PK_UID (find_nearest joins on it).
        try:
            with conn:
                conn.execute('DROP TABLE IF EXISTS {0};'.format(
                    self.TMP_TABLE0))
                conn.execute(
                    """CREATE TABLE {0} (
                        PK_UID INTEGER PRIMARY KEY AUTOINCREMENT,
                        GEOMETRY POINT, id_ref TEXT, _x TEXT, _y TEXT);"""
                    .format(self.TMP_TABLE0))
                conn.executemany(
                    """INSERT INTO {0} (GEOMETRY, id_ref, _x, _y)
                       VALUES (MakePoint(?, ?, 4326), ?, ?, ?);"""
                    .format(self.TMP_TABLE0), pt_rows)
                conn.execute('DROP TABLE IF EXISTS {0};'.format(
                    self.BBOX_TABLE))
                conn.execute(
                    """CREATE TABLE {0} (
                        PK_UID INTEGER PRIMARY KEY AUTOINCREMENT,
                        GEOMETRY POLYGON, id_ref TEXT);"""
                    .format(self.BBOX_TABLE))
                conn.executemany(
                    """INSERT INTO {0} (GEOMETRY, id_ref)
                       VALUES (GeomFromText(?, 4326), ?);"""
                    .format(self.BBOX_TABLE), bbox_rows)
                conn.commit()
                # Making geometry column valid and creating spatials index :
                trnsac = ("""SELECT RecoverGeometryColumn('{0}', 'GEOMETRY',"""
                          """4326, 'POINT', 'XY');"""
                          """SELECT CreateSpatialIndex('{0}', 'GEOMETRY');"""
                          """SELECT RecoverGeometryColumn('{1}', 'GEOMETRY',"""
                          """4326, 'POLYGON', 'XY');"""
                          """SELECT CreateSpatialIndex('{1}', 'GEOMETRY');"""
                          ).format(self.TMP_TABLE0, self.BBOX_TABLE)
                conn.executescript(trnsac)
                conn.commit()
        except db.OperationalError as er:
            print('Error while creating tables : {}\n'
                  ' A rollback has been performed\n'.format(er))
            return -1
        # len() also works for an empty input (the old `i + 1` raised
        # UnboundLocalError in that case).
        return len(pt_rows)

    def find_nearest(self):
        """Snap every imported point on its closest line.

        Returns (snapped, errors): snapped lists the (id_ref, x_new, y_new)
        rows of the points that could be snapped, in input order; errors
        lists the rows whose snapped coordinates are NULL (typically points
        with no line inside their search frame). On a database error,
        returns ([], -1).
        """
        conn = self.conn
        # Compute the distance between each point and every highway present in
        # ... the point search_frame then take the ID of the closest highway
        # (SQLite returns the bare column a.ROWID from the row holding the
        # Min() of each group) :
        transac_l = ["""BEGIN TRANSACTION;
            DROP TABLE IF EXISTS {0};
            CREATE TABLE {0} AS
            SELECT a.ROWID as l_id, b.PK_UID as pt_id,
                Min(ST_Distance(a.{1}, b.GEOMETRY)) as distance
            FROM {2} b, {3} a, {4} c
            WHERE b.PK_UID = c.PK_UID AND a.ROWID in (
                SELECT ROWID FROM SpatialIndex
                WHERE f_table_name = '{3}' AND search_frame = c.GEOMETRY)
            GROUP BY pt_id
            ORDER BY pt_id;""".format(self.TMP_TABLE1, self.col_geom,
                                      self.TMP_TABLE0, self.TABLE,
                                      self.BBOX_TABLE)]
        # For each input point, compute the coordinates of the closest point
        # ... laying on the highway segment previously identified :
        transac_l.append("""DROP TABLE IF EXISTS {0};
            CREATE TABLE {0} AS
            SELECT t.pt_id as pt_id,
                X(ST_ClosestPoint(a.{1}, b.GEOMETRY)) as x,
                Y(ST_ClosestPoint(a.{1}, b.GEOMETRY)) as y
            FROM {2} t, {3} a, {4} b
            WHERE a.ROWID = t.l_id AND t.pt_id = b.PK_UID
            GROUP BY t.ROWID;""".format(self.TMP_TABLE2, self.col_geom,
                                        self.TMP_TABLE1, self.TABLE,
                                        self.TMP_TABLE0))
        # Populate the table with freshly obtained and line-snapped coordinates
        transac_l.append("""ALTER TABLE {0} ADD COLUMN x_new TEXT;
            ALTER TABLE {0} ADD COLUMN y_new TEXT;
            UPDATE {0} SET x_new =
                (SELECT x FROM {1} WHERE {0}.PK_UID = {1}.pt_id);
            UPDATE {0} SET y_new =
                (SELECT y FROM {1} WHERE {0}.PK_UID = {1}.pt_id);
            COMMIT TRANSACTION;""".format(self.TMP_TABLE0,
                                          self.TMP_TABLE2))
        try:
            with conn:
                conn.executescript(''.join(transac_l))
        except db.OperationalError as er:
            print('Error while creating tables : {}\n'
                  ' A rollback has been performed\n'.format(er))
            # An empty result (not [0]) so callers never try to write a
            # bogus row; -1 still flags the failure.
            return [], -1
        # Retrieve the result :
        query = ("SELECT id_ref, x_new, y_new FROM {}"
                 " WHERE x_new IS NOT NULL"
                 " ORDER BY PK_UID;").format(self.TMP_TABLE0)
        liste_result = conn.execute(query).fetchall()
        # And potential errors :
        query = ("SELECT id_ref, x_new, y_new FROM {}"
                 " WHERE x_new IS NULL OR y_new IS NULL"
                 " ORDER BY PK_UID;").format(self.TMP_TABLE0)
        liste_error = conn.execute(query).fetchall()
        conn.commit()
        return liste_result, liste_error
if __name__ == "__main__": | |
syst = platform.system().lower() | |
if 'windo' in syst: | |
SPATIALITE_LIBRARY = 'mod_spatialite.dll' | |
elif 'inux' in syst: | |
# SPATIALITE_LIBRARY = 'mod_spatialite.so' | |
SPATIALITE_LIBRARY = '/usr/local/lib/mod_spatialite.so' | |
elif 'darwin' in syst: | |
SPATIALITE_LIBRARY = 'mod_spatialite.dylib' | |
import argparse | |
desc = ("Script to snap a set of points (csv/shp)" | |
" on a network (shp/spatialite)") | |
argPr = argparse.ArgumentParser(description=desc) | |
argPr.add_argument("input_path", | |
help="Path to input .csv or .shp file of points") | |
argPr.add_argument("db_path", help=("Path to the SQLite/Shapefile " | |
"containing the reference network")) | |
argPr.add_argument("table_name", help=("Name of the network table (put 'me" | |
"mory' here if using a shapefile)")) | |
argPr.add_argument("-l", "--libspatialite", default=SPATIALITE_LIBRARY, | |
help=("Set manually the path to the spatialite library" | |
" (full path is preferred)")) | |
argPr.add_argument("-c", "--cachesize", default=2048000, | |
help="Set the SQLite cache size (default: 2048000)") | |
argPr.add_argument("-o", "--output", action='store', | |
help=("Chose the name of the ouputfile " | |
"(default : input_name-snapped.csv)")) | |
argPr.add_argument("-s", "--skipvacuum", | |
help="Skip the DB VACUUM after processing the points", | |
action='store_true') | |
argPr.add_argument("-d", "--distance", action='store_true', | |
help=("Fill a field with the distance (meter) between " | |
"the original point and the snapped one")) | |
argPr.add_argument("-f", "--framesearch", action='store', default=0.072, | |
help=("Change the search frame (" | |
"in km) (default : 6.5km)")) | |
options = argPr.parse_args() | |
options.syst = syst | |
if 'csv' in options.input_path and os.path.exists(options.input_path): | |
input_type = 1 | |
elif 'shp' in options.input_path and os.path.exists(options.input_path): | |
input_type = 2 | |
else: | |
sys.exit('Wrong input file (wrong extension or file don\'t exists)') | |
if '.shp' in options.db_path and 'mem' in options.table_name: | |
options.table_name = options.db_path | |
options.db_path = ':memory:' | |
if options.output: | |
output = options.output | |
else: | |
output = options.input_path[:len(options.input_path)-4] + '-snapped.csv' | |
options.framesearch = float(options.framesearch) | |
if options.framesearch != 0.072: | |
options.framesearch = options.framesearch/111.11 | |
# poor meter/degree conversion ! | |
options.libspatialite = os.path.realpath(options.libspatialite) | |
if 'wind' in syst and options.libspatialite != SPATIALITE_LIBRARY: | |
options.libspatialite = str(options.libspatialite).replace('\\', '/') | |
if not os.path.exists(options.libspatialite): | |
sys.exit('Spatialite shared library not found') | |
start_time = time() | |
liste_coord = read_row(options.input_path, input_type) | |
a = DbLineHandler(options) | |
a.import_row_db(liste_coord, buff_d=options.framesearch) | |
result, error = a.find_nearest() | |
print('{}/{} points snapped (performed in {:.2f}s)'.format(len(result), | |
len(liste_coord), time()-start_time)) | |
if options.distance: | |
result = [[pt2[0], pt2[1], pt2[2], | |
haversine_dist([pt1[1], pt1[2]], | |
[pt2[1], pt2[2]])] for pt1, pt2 | |
in zip(liste_coord, result)] | |
if len(result) > 0: | |
write_row(output, result) | |
else: | |
print('Something went wrong : no snapped point to write') | |
a.close_db(skip_vacuum=options.skipvacuum) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment