Created
November 6, 2023 08:46
-
-
Save spaanse/9089b4a177449974b969f4027b566911 to your computer and use it in GitHub Desktop.
Takes an extracted GTFS feed and, for each route, produces a single stop order that is consistent with all of the route's trips.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
import sys | |
class union_find:
    """Disjoint-set forest over the integers ``0 .. n-1``.

    ``parent[x]`` holds the parent of ``x``, or ``-size`` when ``x`` is the
    root of a set containing ``size`` elements.  ``find`` compresses paths.
    ``join`` has no union-by-size: the root of ``x``'s set is always hung
    under the root of ``y``'s set, so after ``join(x, y)`` the representative
    is ``y``'s former root.
    """

    def __init__(self, n):
        # Every element starts out as the root of its own singleton set.
        self.parent = [-1] * n

    def find(self, x):
        """Return the root of ``x``'s set and point every node on the way at it."""
        root = x
        while self.parent[root] >= 0:
            root = self.parent[root]
        # Second pass: path compression.
        node = x
        while node != root:
            nxt = self.parent[node]
            self.parent[node] = root
            node = nxt
        return root

    def size(self, x):
        """Return how many elements share a set with ``x`` (including ``x``)."""
        return -self.parent[self.find(x)]

    def same(self, x, y):
        """Return True when ``x`` and ``y`` belong to the same set."""
        return self.find(x) == self.find(y)

    def join(self, x, y):
        """Merge the sets of ``x`` and ``y``; ``y``'s root remains the root."""
        root_x, root_y = self.find(x), self.find(y)
        if root_x != root_y:
            self.parent[root_y] += self.parent[root_x]
            self.parent[root_x] = root_y
def tarjan(adj, components):
    """Find the strongly connected components of a graph.

    This is a recursive, union-find based variant of Tarjan's algorithm.

    ``adj[u]`` is an iterable of the successors of node ``u``; nodes are
    ``0 .. len(adj)-1``.  Nodes found to lie on a common cycle are merged in
    ``components`` through ``components.join(node, successor)``.  The
    algorithm relies on ``join`` keeping the successor's root as the root,
    as ``union_find.join`` does.

    Returns the representative (root) of every component, in the order the
    components are completed.  A component is only completed after every
    component reachable from it, so reversing the list gives a topological
    order.  Recursion depth can reach ``len(adj)``.
    """
    node_count = len(adj)
    seen = [False] * node_count
    # Visit stamps count down from node_count, so a larger stamp means the
    # node was entered earlier.  The stamp of a completed component's root is
    # reset to 0 so it can no longer absorb anything.
    stamp = [0] * node_count
    counter = node_count
    finished = []

    def visit(node):
        nonlocal counter
        if seen[node]:
            return
        seen[node] = True
        stamp[node] = counter
        counter -= 1
        for nxt in adj[node]:
            visit(nxt)
            # nxt's component is still open and was entered before node's,
            # so node lies on a cycle with it: fold node's component in.
            if stamp[components.find(nxt)] > stamp[components.find(node)]:
                components.join(node, nxt)
        if components.find(node) == node:
            finished.append(node)
            stamp[node] = 0

    for start in range(node_count):
        visit(start)
    return finished
def scs_iter(routes, dp, cur):
    """Return a shortest common supersequence of the remaining route suffixes.

    ``cur[i]`` is the number of stops of ``routes[i]`` already covered (a
    list of the same length as ``routes``).  Each step tries, for every
    unfinished route, its next stop as the next element.  All routes whose
    next stop equals it advance together.  On ties, the candidate from the
    lowest route index wins.

    Results are memoised in ``dp`` under ``tuple(cur)``.  The returned lists
    are shared with ``dp`` and must not be mutated.
    """
    done = [len(route) for route in routes]
    if cur == done:
        return []
    key = tuple(cur)
    if key in dp:
        return dp[key]

    shortest = None
    for i, route in enumerate(routes):
        if cur[i] == done[i]:
            continue  # route i is fully covered already
        stop = route[cur[i]]
        advanced = list(cur)
        for j, other in enumerate(routes):
            if cur[j] != done[j] and other[cur[j]] == stop:
                advanced[j] += 1
        candidate = [stop] + scs_iter(routes, dp, advanced)
        if shortest is None or len(candidate) < len(shortest):
            shortest = candidate
    dp[key] = shortest
    return shortest
def scs(trips):
    """Merge the stop sequences of several trips into one stop order.

    Every distinct stop is numbered, and a directed "follows" graph gets an
    edge ``u -> v`` whenever ``v`` comes directly after ``u`` in some trip.

    Cycles in that graph are the strongly connected components found by
    ``tarjan``.  A cycle is broken by splitting a member that no trip visits
    strictly between two other members of the same cycle.  A duplicate node
    with the same name takes over two sets of edges: the edges into the
    member from inside the cycle, and the member's edges leading out of the
    cycle.  At most one stop per cycle is split per pass, and passes repeat
    until one splits nothing.

    Args:
        trips: iterable of trips; each trip is an iterable of stop names
            (strings).  One-shot iterators such as ``reversed(...)`` are
            accepted.  The caller's lists are not modified.

    Returns:
        One entry per component of the final graph, in topological order.
        An entry is the stop name for a single stop.  A cycle that could not
        be broken is a single string ``"(<size>: <name>; <name>; ...)"``.
    """
    # Materialise every trip up front: the trips are walked more than once
    # below, which would silently empty any one-shot iterator.
    trips = [list(trip) for trip in trips]

    # Number the distinct stops in order of first appearance.
    ids = {}
    names = []
    for trip in trips:
        for stop in trip:
            if stop not in ids:
                ids[stop] = len(names)
                names.append(stop)
    index = len(names)  # next free node id; kept equal to len(names)
    trips = [[ids[stop] for stop in trip] for trip in trips]

    # graph[u] = every node that directly follows u in some trip.
    graph = [set() for _ in range(index)]
    for trip in trips:
        for u, v in zip(trip, trip[1:]):
            if u != v:
                graph[u].add(v)

    repeat = True
    while repeat:
        repeat = False
        uf = union_find(index)
        # tarjan completes components after their successors; reverse it so
        # that every component precedes the ones that follow it.
        order = tarjan(graph, uf)[::-1]
        for stop in order:
            if uf.size(stop) == 1:
                continue
            group = [i for i in range(index) if uf.find(i) == stop]
            members = set(group)

            # "Middles" are members visited strictly between two other
            # members of this cycle by some trip; they cannot be split off.
            middles = set()
            for trip in trips:
                for u, v, w in zip(trip, trip[1:], trip[2:]):
                    if u in members and v in members and w in members:
                        middles.add(v)

            for u in group:
                if u in middles:
                    continue
                # Split u.  The new node is the copy of u reached from inside
                # the cycle: it receives the cycle's edges into u and u's
                # edges out of the cycle.  u keeps its remaining edges.
                new = index
                names.append(names[u])
                for v in group:
                    if u in graph[v]:
                        graph[v].remove(u)
                        graph[v].add(new)
                graph.append({v for v in graph[u] if v not in members})
                graph[u] = {v for v in graph[u] if v in members}
                # Register the new node before uf.find() can be asked about it.
                uf.parent.append(-1)
                # Visits of u that come straight from a member of the cycle
                # now visit the copy.  The test compares against the cycle's
                # root `stop` (the original compared against u itself, which
                # only worked when u happened to be the root).
                for trip in trips:
                    for i in range(1, len(trip)):
                        if trip[i] == u and uf.find(trip[i - 1]) == stop:
                            trip[i] = new
                index += 1
                repeat = True
                break  # one split per cycle per pass; recompute next pass

    # The last pass split nothing, so `uf` and `order` describe the final graph.
    result = []
    for stop in order:
        if uf.size(stop) == 1:
            result.append(names[stop])
        else:
            # A cycle that could not be broken: report all its members.
            group = [names[i] for i in range(index) if uf.find(i) == stop]
            result.append("({:d}: {})".format(uf.size(stop), '; '.join(group)))
    # NOTE: scs_iter() is an exact, memoised shortest-common-supersequence
    # search that could order such members instead; it is not used here.
    return result
# Directory holding the extracted GTFS feed.  It defaults to '../gtfs-nl';
# a different directory can be passed as the first command-line argument.
GTFS_DIR = sys.argv[1] if len(sys.argv) > 1 else '../gtfs-nl'

# Column dtypes for stops.txt.  Identifier-like columns are read as str so
# pandas does not reinterpret them as numbers.
stop_types = {
    'stop_id': str,
    'stop_code': str,
    'stop_name': str,
    'stop_lat': float,
    'stop_lon': float,
    'location_type': int,
    'parent_station': str,
    'stop_timezone': str,
    'wheelchair_boarding': "boolean",
    'platform_code': str,
    'zone_id': str,
}
stops = pd.read_csv(
    GTFS_DIR + '/stops.txt', dtype=stop_types, index_col='stop_id'
)
stops.sort_index(inplace=True)

stop_times_types = {
    'trip_id': str,
    'stop_sequence': int,
    'stop_id': str,
    'stop_headsign': str,
    'arrival_time': str,
    'departure_time': str,
    'pickup_type': int,
    'drop_off_type': int,
    'timepoint': "boolean",
    'shape_dist_traveled': float,
    'fare_units_traveled': float,
}
stop_times = pd.read_csv(
    GTFS_DIR + '/stop_times.txt', dtype=stop_times_types, index_col='trip_id'
)
# Order each trip's rows by stop_sequence, then sort by the trip_id index
# with a stable sort so that order is kept.  On a sorted, non-unique index,
# index.get_loc(trip_id) returns a slice covering all rows of that trip.
stop_times.sort_values(['trip_id', 'stop_sequence'], inplace=True)
stop_times.sort_index(inplace=True, kind='stable')

trip_types = {
    'route_id': str,
    'service_id': str,
    'trip_id': str,
    'realtime_trip_id': str,
    'trip_headsign': str,
    'trip_short_name': str,
    'trip_long_name': str,
    'direction_id': int,
    'block_id': str,
    'shape_id': str,
    # Spelled as in the GTFS reference; the previous misspelling
    # ('wheelchair_accesible') matched no column, so its dtype was ignored.
    'wheelchair_accessible': "Int8",
    'bikes_allowed': "Int8",
}
trips = pd.read_csv(
    GTFS_DIR + '/trips.txt', dtype=trip_types, index_col='route_id'
)
# Sorted so that index.get_loc(route_id) yields a contiguous slice of trips.
trips.sort_index(inplace=True)

route_types = {
    'route_id': str,
    'agency_id': str,
    'route_short_name': str,
    'route_long_name': str,
    'route_desc': str,
    'route_type': int,
    'route_color': str,
    'route_text_color': str,
    'route_url': str,
}
routes = pd.read_csv(
    GTFS_DIR + '/routes.txt', dtype=route_types, index_col='route_id'
)
routes.sort_index(inplace=True)
print("load done", file=sys.stderr)
# For every route, collect the station names visited by each of its trips and
# merge them into one stop order with scs().  Routes whose trips cannot be
# looked up are collected in `remaining` (suffixed with 'E') and reported on
# stderr at the end.
remaining = []
for route_id, route in routes.iterrows():
    print(route_id, route['route_short_name'], route['route_long_name'])
    try:
        # `trips` is sorted by route_id, so get_loc returns a slice of rows.
        route_trips = trips[trips.index.get_loc(route_id)]
    except KeyError as e:
        print(e)
        remaining.append(route_id + 'E')
        # Skip the route: without this, route_trips would be undefined (first
        # route) or stale (left over from the previous route).
        continue

    route_stops = []
    for _, trip in route_trips.iterrows():
        trip_stops = []
        trip_stop_times = stop_times[stop_times.index.get_loc(trip['trip_id'])]
        for _, stop_time in trip_stop_times.iterrows():
            stop = stops.loc[stop_time['stop_id']]
            # Use the parent station's record when the stop has one.
            if not pd.isna(stop['parent_station']):
                stop = stops.loc[stop['parent_station']]
            trip_stops.append(stop['stop_name'])
        if trip['direction_id'] == 1:
            # Reverse in place so every trip runs in the same direction.  A
            # bare reversed() iterator would be exhausted after scs()'s first
            # pass over it.
            trip_stops.reverse()
        route_stops.append(trip_stops)

    stop_order = scs(route_stops)
    print(route_id, ' -> '.join(stop_order))

print(len(remaining), file=sys.stderr)
print('; '.join(remaining), file=sys.stderr)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment