Skip to content

Instantly share code, notes, and snippets.

@spaanse
Created November 6, 2023 08:46
Show Gist options
  • Save spaanse/9089b4a177449974b969f4027b566911 to your computer and use it in GitHub Desktop.
Save spaanse/9089b4a177449974b969f4027b566911 to your computer and use it in GitHub Desktop.
Takes an extracted GTFS feed and for each route produces a stop order that matches all trips
import pandas as pd
import sys
class union_find:
    """Disjoint-set forest over the integers ``0 .. n-1``.

    ``parent[x]`` holds the index of x's parent, or, when x is a root, the
    negated size of its set (so every entry starts out as -1).
    """

    def __init__(self, n):
        """Create ``n`` singleton sets."""
        self.parent = [-1] * n

    def find(self, x):
        """Return the root of the set containing ``x``, compressing the path.

        Implemented iteratively: ``join`` never balances the trees, so parent
        chains can grow long enough to exhaust the recursion limit.
        """
        root = x
        while self.parent[root] >= 0:
            root = self.parent[root]
        # Second pass: point every node on the walked path straight at the root.
        while x != root:
            above = self.parent[x]
            self.parent[x] = root
            x = above
        return root

    def size(self, x):
        """Return the number of elements in the set containing ``x``."""
        return -self.parent[self.find(x)]

    def same(self, x, y):
        """Return True when ``x`` and ``y`` belong to the same set."""
        return self.find(x) == self.find(y)

    def join(self, x, y):
        """Merge the sets containing ``x`` and ``y``.

        The root of ``y``'s set always becomes the root of the merged set
        (there is no union by size); ``tarjan`` in this file relies on that
        direction when it tests ``find(cur) == cur``.
        """
        x = self.find(x)
        y = self.find(y)
        if x == y:
            return
        self.parent[y] += self.parent[x]  # accumulate the (negated) sizes
        self.parent[x] = y
def tarjan(adj, components):
    """Group the nodes of a directed graph into strongly connected components.

    Args:
        adj: adjacency structure; ``adj[u]`` is an iterable of the successors
            of node ``u``, with nodes numbered ``0 .. len(adj)-1``.
        components: object offering ``find(x)`` and ``join(x, y)`` where the
            root of ``y`` becomes the root of the merged set (``union_find``
            in this file). It is expected to start with every node in its own
            set; on return, nodes of the same component share a root.

    Returns:
        The root node of every component, in the order the components were
        completed by the depth-first search: a component is listed only after
        every component reachable from it (``scs`` reverses this list).

    The search uses an explicit stack instead of recursion so that long
    chains of nodes cannot exhaust the interpreter's recursion limit.
    """
    n = len(adj)
    index = n  # discovery numbers count down from n, so earlier nodes rank higher
    visited = [False] * n
    # high[root] is the discovery number of a still-open component's root;
    # it is reset to 0 once the component has been emitted.
    high = [0] * n
    order = []

    def absorb(node, child):
        # Merge node's group into child's group when child's group was
        # discovered earlier and is still open (closed groups have high == 0).
        if high[components.find(child)] > high[components.find(node)]:
            components.join(node, child)

    for start in range(n):
        if visited[start]:
            continue
        visited[start] = True
        high[start] = index
        index -= 1
        stack = [(start, iter(adj[start]))]
        while stack:
            node, children = stack[-1]
            for child in children:
                if not visited[child]:
                    # Descend; the (node, child) check runs when child is popped.
                    visited[child] = True
                    high[child] = index
                    index -= 1
                    stack.append((child, iter(adj[child])))
                    break
                absorb(node, child)
            else:
                # All successors handled: node is finished.
                stack.pop()
                if components.find(node) == node:
                    order.append(node)
                    high[node] = 0
                if stack:
                    absorb(stack[-1][0], node)
    return order
def scs_iter(routes, dp, cur):
    """Search for a shortest stop sequence that contains every route's remainder.

    Args:
        routes: list of stop sequences (lists).
        dp: memo dictionary mapping ``tuple(cur)`` to the result computed for
            that state; pass an empty dict on the first call.
        cur: list with, per route, the position of its next unconsumed stop.

    Returns:
        A list of stops such that the suffix ``routes[i][cur[i]:]`` of every
        route is a subsequence of it. Every possible next stop is tried and
        the shortest outcome kept (the first one found on ties).

    The number of states grows with the product of the route lengths, so this
    is only practical for small inputs.
    """
    if cur == [len(route) for route in routes]:
        return []
    key = tuple(cur)
    if key in dp:
        return dp[key]
    best = None
    for route, pos in zip(routes, cur):
        if pos == len(route):
            continue
        stop = route[pos]
        # Emitting `stop` advances every route that is currently waiting on it.
        advanced = [
            p + 1 if p != len(r) and r[p] == stop else p
            for r, p in zip(routes, cur)
        ]
        candidate = [stop] + scs_iter(routes, dp, advanced)
        if best is None or len(candidate) < len(best):
            best = candidate
    dp[key] = best
    return best
def scs(trips):
    """Produce one stop order for a collection of trips.

    Args:
        trips: iterable of trips, each an iterable of stop names (strings).
            The input is copied first, so one-shot iterables such as
            ``reversed(...)`` objects are accepted and nothing is mutated.

    Returns:
        A list of strings. Stops are ordered so that the "u is directly
        followed by v" edges taken from the trips point forward. Stops that
        end up in a cycle that cannot be broken are reported together as a
        single ``"(count: name; name; ...)"`` entry.

    Method: stops are numbered, a successor graph is built, and its strongly
    connected components are computed with ``tarjan``. Inside a component of
    more than one stop, a stop that never occurs between two other stops of
    the same component (in any trip) is split into two nodes, after which the
    components are recomputed. This repeats until no split happens.
    """
    # The trips are traversed more than once; materialise them so that
    # iterators are not silently exhausted after the first pass.
    trips = [list(trip) for trip in trips]

    index = 0  # number of graph nodes allocated so far
    ids = {}
    names = []  # names[node] is the stop name shown for that node
    for trip in trips:
        for stop in trip:
            if stop not in ids:
                ids[stop] = index
                names.append(stop)
                index += 1
    trips = [[ids[stop] for stop in trip] for trip in trips]

    graph = [set() for _ in range(index)]
    for trip in trips:
        for u, v in zip(trip, trip[1:]):
            if u != v:  # repeated consecutive stops add no edge
                graph[u].add(v)

    repeat = True
    while repeat:
        uf = union_find(index)
        order = list(reversed(tarjan(graph, uf)))
        repeat = False
        for stop in order:
            if uf.size(stop) == 1:
                continue
            group = [i for i in range(index) if uf.find(i) == stop]
            members = set(group)

            # Nodes seen as the middle of three consecutive in-group stops.
            middles = set()
            for trip in trips:
                for u, v, w in zip(trip, trip[1:], trip[2:]):
                    if uf.find(u) != stop:
                        continue
                    if uf.find(v) != stop:
                        continue
                    if uf.find(w) != stop:
                        continue
                    middles.add(v)

            for u in group:
                if u in middles:
                    continue
                # Split u. The new node `index` takes over the edges arriving
                # from inside the group and the edges leaving the group; u
                # keeps only its edges into the group.
                names.append(names[u])
                for v in group:
                    if u in graph[v]:
                        graph[v].remove(u)
                        graph[v].add(index)
                graph.append({v for v in graph[u] if v not in members})
                graph[u] = {v for v in graph[u] if v in members}
                uf.parent.append(-1)  # keep uf indexable for the new node
                for trip in trips:
                    for i, v in enumerate(trip):
                        if i == 0:
                            continue
                        if v != u:
                            continue
                        # NOTE(review): this compares the predecessor's
                        # component root with u itself, so it can only match
                        # when u is the root of its component; `stop` may have
                        # been intended - confirm before changing.
                        if uf.find(trip[i - 1]) != v:
                            continue
                        trip[i] = index
                index += 1
                repeat = True
                break  # the group changed; handle it again on the next pass

    result = []
    for stop in order:
        if uf.size(stop) == 1:
            result.append(names[stop])
        else:
            # Remaining cycles are reported as one group. An exact ordering of
            # the group via scs_iter existed only as disabled code.
            grouped = [names[i] for i in range(index) if uf.find(i) == stop]
            result.append("({:d}: {})".format(uf.size(stop), '; '.join(grouped)))
    return result
# Column dtypes for the GTFS tables read below.
stop_types = {
    'stop_id': str,
    'stop_code': str,
    'stop_name': str,
    'stop_lat': float,
    'stop_lon': float,
    'location_type': int,
    'parent_station': str,
    'stop_timezone': str,
    'wheelchair_boarding': "boolean",
    'platform_code': str,
    'zone_id': str
}
stops = pd.read_csv('../gtfs-nl/stops.txt', dtype=stop_types, index_col='stop_id')
stops.sort_index(inplace=True)

stop_times_types = {
    'trip_id': str,
    'stop_sequence': int,
    'stop_id': str,
    'stop_headsign': str,
    'arrival_time': str,
    'departure_time': str,
    'pickup_type': int,
    'drop_off_type': int,
    'timepoint': "boolean",
    'shape_dist_traveled': float,
    'fare_units_traveled': float,
}
stop_times = pd.read_csv('../gtfs-nl/stop_times.txt', dtype=stop_times_types, index_col='trip_id')
# Order the rows by trip and, within a trip, by stop_sequence.
stop_times.sort_values(['trip_id', 'stop_sequence'], inplace=True)
stop_times.sort_index(inplace=True, kind='stable')

trip_types = {
    'route_id': str,
    'service_id': str,
    'trip_id': str,
    'realtime_trip_id': str,
    'trip_headsign': str,
    'trip_short_name': str,
    'trip_long_name': str,
    'direction_id': int,
    'block_id': str,
    'shape_id': str,
    'wheelchair_accessible': "Int8",
    'bikes_allowed': "Int8",
}
trips = pd.read_csv('../gtfs-nl/trips.txt', dtype=trip_types, index_col='route_id')
trips.sort_index(inplace=True)

route_types = {
    'route_id': str,
    'agency_id': str,
    'route_short_name': str,
    'route_long_name': str,
    'route_desc': str,
    'route_type': int,
    'route_color': str,
    'route_text_color': str,
    'route_url': str,
}
routes = pd.read_csv('../gtfs-nl/routes.txt', dtype=route_types, index_col='route_id')
routes.sort_index(inplace=True)
print("load done", file=sys.stderr)


def _rows_for(frame, key):
    """Return the rows of ``frame`` whose index label equals ``key``.

    ``Index.get_loc`` yields a bare integer when the label occurs once and a
    slice or boolean mask otherwise. The integer is widened to a slice so the
    result is always a DataFrame. Raises KeyError when the label is absent.
    """
    loc = frame.index.get_loc(key)
    if pd.api.types.is_integer(loc):
        loc = slice(loc, loc + 1)
    return frame.iloc[loc]


remaining = []  # ids (suffixed with 'E') of routes that have no trips
for route_id, route in routes.iterrows():
    print(route_id, route['route_short_name'], route['route_long_name'])
    try:
        route_trips = _rows_for(trips, route_id)
    except KeyError as e:
        print(e)
        remaining.append(route_id + 'E')
        # Skip the route; otherwise the previous route's trips would be reused.
        continue
    route_stops = []
    for trip_id, direction_id in zip(route_trips['trip_id'], route_trips['direction_id']):
        trip_stops = []
        trip_stop_times = _rows_for(stop_times, trip_id)
        for stop_id in trip_stop_times['stop_id']:
            stop = stops.loc[stop_id]
            # Report a platform/child stop under its parent station's name.
            if not pd.isna(stop['parent_station']):
                stop = stops.loc[stop['parent_station']]
            trip_stops.append(stop['stop_name'])
        if direction_id == 1:
            # Reverse in place so a real list (not a one-shot iterator) is
            # handed to scs, which traverses each trip more than once.
            trip_stops.reverse()
        route_stops.append(trip_stops)
    stop_order = scs(route_stops)
    print(route_id, ' -> '.join(stop_order))
print(len(remaining), file=sys.stderr)
print('; '.join(remaining), file=sys.stderr)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment