Created
August 29, 2023 17:01
-
-
Save dlozeve/5c5c000c46acf06507b6b3577cbeb70b to your computer and use it in GitHub Desktop.
Visualization of travel time by train from Paris to major French cities, 1950-2019
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import polars as pl | |
import altair as alt | |
# https://data.sncf.com/explore/dataset/meilleurs-temps-des-parcours-des-trains/information/
# Local semicolon-separated export of the SNCF dataset linked above.
DATA_FILE = "meilleurs-temps-des-parcours-des-trains.csv"

# Opening year of each French high-speed line (LGV, "ligne à grande vitesse"),
# read off https://en.wikipedia.org/wiki/File:France_TGV.png
# Kept as line -> year pairs so a name and its year cannot drift out of step
# (the previous form used two parallel lists).
_LGV_OPENING_YEARS = {
    "LGV Nord": 1993,
    "LGV Est": 2007,
    "LGV Sud-Est": 1983,
    "LGV Rhône-Alpes": 1994,
    "LGV Méditerranée": 2001,
    "LGV Atlantique": 1990,
    "LGV Sud Europe Atlantique": 2017,
    "LGV Bretagne-Pays de la Loire": 2017,
}

# One row per line, in the order above. `year` is turned into a pl.Date by
# parsing the four-digit year with format "%Y" (the same parsing load_data
# applies to the "Année" column). Currently only referenced by a
# commented-out overlay in plot_durations.
LGVs = pl.DataFrame(
    {
        "line": list(_LGV_OPENING_YEARS),
        "year": list(_LGV_OPENING_YEARS.values()),
    }
).select("line", pl.col("year").cast(str).str.strptime(pl.Date, format="%Y"))
def load_data(path: "str | None" = None) -> pl.DataFrame:
    """Load the SNCF best-travel-time CSV into a tidy DataFrame.

    Args:
        path: CSV file to read (semicolon-separated). When omitted, the
            module-level ``DATA_FILE`` is used, so ``load_data()`` behaves
            exactly as before.

    Returns:
        A DataFrame with columns, in order:

        - ``start`` / ``end``: the "Relations" column split on " - " into
          at most two parts;
        - ``year``: the "Année" column parsed as a ``pl.Date`` with format
          "%Y";
        - ``duration``: the "Temps estimé en minutes" column, unchanged.
    """
    # Resolve the default lazily so the constant is looked up at call time.
    if path is None:
        path = DATA_FILE
    return (
        pl.read_csv(path, separator=";")
        .select(
            # "A - B" -> struct {start: "A", end: "B"}; split_exact(..., 1)
            # yields exactly two fields.
            pl.col("Relations")
            .str.split_exact(" - ", 1)
            .struct.rename_fields(["start", "end"]),
            pl.col("Année").cast(str).str.strptime(pl.Date, format="%Y").alias("year"),
            pl.col("Temps estimé en minutes").alias("duration"),
        )
        # Replace the struct column with its two fields, start and end.
        .unnest("Relations")
    )
def plot_durations(durations: pl.DataFrame, start: str, ends: list[str]):
    """Chart the travel time from ``start`` to each city in ``ends`` over time.

    Args:
        durations: Table with ``start``, ``end``, ``year`` and ``duration``
            (minutes) columns, e.g. the output of ``load_data``.
        start: Departure city. It is upper-cased before matching the
            ``start`` column.
        ends: Destination cities. Each is upper-cased before matching the
            ``end`` column.

    Returns:
        An Altair layered chart with three layers:

        - one line per destination, showing duration in hours by year;
        - a dot at each destination's latest year;
        - the destination's name next to that dot.

        The chart title names the departure city.
    """
    # Both arguments are upper-cased before filtering; presumably the dataset
    # stores city names in upper case (TODO confirm against the CSV).
    start = start.upper()
    ends = [end.upper() for end in ends]
    df = (
        durations.filter((pl.col("start") == start) & (pl.col("end").is_in(ends)))
        # Title-case destinations for display and convert minutes to hours.
        .select(pl.col("end").str.to_titlecase(), "year", pl.col("duration") / 60)
        .sort("end", "year")
    )

    # Shared base: one colour per destination. The legend is hidden because
    # every line is labelled directly at its last point (see `names`).
    base = alt.Chart(df.to_pandas()).encode(
        alt.Color("end", title="Destination").legend(None)
    )
    # clip=True hides any part of a line outside the fixed x/y domains below.
    lines = base.mark_line(clip=True).encode(
        alt.X("year", title="Year").scale(domain=("1948", "2019")),
        alt.Y("duration", title="Duration (hours)")
        .scale(domain=(0, 11))
        .axis(values=list(range(12))),
    )
    # argmax(year) keeps, per destination, the row with the latest year; its
    # fields are then read back as last_year['year'] / last_year['duration'].
    last_duration = (
        base.mark_circle()
        .encode(
            alt.X("last_year['year']:T"),
            alt.Y("last_year['duration']:Q"),
        )
        .transform_aggregate(last_year="argmax(year)", groupby=["end"])
    )
    # Destination name drawn just to the right of each end-point dot.
    names = last_duration.mark_text(align="left", dx=4, fontSize=14).encode(text="end")
    # NOTE: disabled overlay of LGV opening dates (module-level `LGVs`) as
    # dashed vertical rules; add `lgvs` to the layer below to re-enable.
    # lgvs = (
    #     alt.Chart(LGVs.to_pandas())
    #     .encode(alt.X("year"))
    #     .mark_rule(strokeDash=(8, 4), opacity=0.5)
    # )

    # Name the actual departure city instead of hard-coding "Paris";
    # "PARIS".title() == "Paris", so the default call's title is unchanged.
    start_label = start.title()
    return (
        (lines + last_duration + names)
        .properties(
            title=alt.Title(
                f"Travel time by train from {start_label} to major French cities",
                fontSize=20,
            ),
            width=800,
            height=600,
        )
        .configure_axis(
            titleFontSize=20,
            labelFontSize=18,
        )
        .configure_legend(
            titleFontSize=20,
            labelFontSize=18,
        )
    )
def main():
    """Export the parsed data to CSV and render the Paris chart to all.html."""
    table = load_data()
    # Export the parsed table next to the chart.
    table.write_csv("durations.csv")

    destinations = ["Lille", "Strasbourg", "Lyon"]
    destinations += ["Marseille", "Rennes", "Bordeaux"]

    plot_durations(table, start="Paris", ends=destinations).save("all.html")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment