Script to generate Venn diagrams and UpSet plots of movie genres from IMDb data.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import gzip | |
import io | |
import pprint | |
import upsetplot | |
import pandas as pd | |
from collections import defaultdict | |
from matplotlib import pyplot as plt | |
from matplotlib_venn import venn2, venn3 | |
from urllib.request import Request, urlopen | |
def load_movie_data(sample=True, *, frac=0.02, random_state=None,
                    url='https://datasets.imdbws.com/title.basics.tsv.gz'):
    """Download the IMDb ``title.basics`` TSV and load it into a DataFrame.

    Only titles whose ``isAdult`` flag equals 0 are kept.

    Warning: the compressed file is about 130 MB; depending on the speed of
    your connection the download may take some time.

    Args:
        sample: If true, return a random subset of the non-adult titles
            instead of all of them.
        frac: Fraction of rows kept when ``sample`` is true.
        random_state: Seed forwarded to ``DataFrame.sample`` so a sample can
            be reproduced; ``None`` draws a different sample on every run.
        url: Location of the gzipped TSV. Anything ``urlopen`` accepts works,
            including a ``file://`` URL pointing at a locally cached copy.

    Returns:
        A DataFrame with the TSV's columns, restricted to non-adult titles.
    """
    req = Request(url)
    req.add_header('Accept-Encoding', 'gzip')
    # Use the response as a context manager so the connection is closed
    # deterministically instead of being leaked.
    with urlopen(req) as response:
        content = gzip.decompress(response.read())
    data = pd.read_csv(io.BytesIO(content), encoding='utf8', sep="\t")
    # Coerce before comparing: if pandas infers ``isAdult`` as object dtype
    # (e.g. mixed or malformed values), a plain ``== 0`` would match nothing.
    # Values that cannot be parsed become NaN and are dropped.
    is_adult = pd.to_numeric(data["isAdult"], errors="coerce")
    data = data[is_adult == 0]
    if sample:
        # Two percent might seem low, but there are approx. 7 million
        # titles without the Adult category.
        return data.sample(frac=frac, random_state=random_state)
    return data
def plot_upset(genres_movies_set, movie_categories, filename):
    """Draw an UpSet plot for the selected genres and save it to *filename*.

    Args:
        genres_movies_set: Mapping from genre name to the set of titles
            tagged with that genre.
        movie_categories: A genre name, or a tuple of genre names, to include.
            Names are matched exactly, so e.g. ``"Music"`` does not also
            select ``"Musical"``.
        filename: Output path passed to ``savefig``.
    """
    # Accept a single genre name as well as a tuple of names.
    if isinstance(movie_categories, str):
        movie_categories = (movie_categories,)
    wanted = set(movie_categories)
    upset_data_sub = upsetplot.from_contents(
        {k: v for k, v in genres_movies_set.items() if k in wanted})
    upsetplot.plot(upset_data_sub)
    fig = plt.gcf()
    fig.savefig(filename)
    # upsetplot.plot draws on a new figure when none is given; close it so
    # repeated calls do not accumulate open figures.
    plt.close(fig)
def main():
    """Load a sample of IMDb titles and save genre Venn and UpSet plots."""
    data = load_movie_data()

    # Get a general sense of the data
    pp = pprint.PrettyPrinter(indent=4)
    print("Data column names: ")
    pp.pprint(list(data.columns))
    print("Data shape: " + str(data.shape))
    print("Genres column examples: ")
    pp.pprint(data.genres.sample(5))

    # Reshape the data so that every genre maps to the titles tagged with it.
    # Titles are identified by their IMDb id (tconst), not primaryTitle:
    # distinct titles can share a name, and once converted to sets they
    # would be merged, distorting the Venn/UpSet intersections.
    genres_movies = defaultdict(list)
    for title_id, genres in zip(data["tconst"], data["genres"]):
        # Missing genres show up as NaN or as IMDb's "\N" placeholder (not a
        # default pandas NA marker); skip them explicitly rather than
        # silencing every error with a bare except.
        if not isinstance(genres, str) or genres == "\\N":
            continue
        for genre in genres.split(','):
            genres_movies[genre].append(title_id)

    pp = pprint.PrettyPrinter(indent=4, depth=1)
    print("Data structure: ")
    pp.pprint(genres_movies)

    # Plot Venn diagrams and save them to file
    venn2([set(genres_movies['Action']), set(genres_movies['Romance'])],
          set_labels=('Action', 'Romance'))
    plt.tight_layout()
    plt.savefig("./simple_venn.png")
    plt.clf()

    venn3([set(genres_movies['Action']), set(genres_movies['Romance']),
           set(genres_movies['Drama'])],
          set_labels=('Action', 'Romance', 'Drama'))
    plt.tight_layout()
    plt.savefig("./large_venn.png")
    plt.clf()

    genres_movies_set = {genre: set(ids) for genre, ids in genres_movies.items()}
    plot_upset(genres_movies_set, ('Action', 'Romance'), "./simple_upset.png")
    plot_upset(genres_movies_set, ('Action', 'Romance', 'Drama', 'Sci-Fi'),
               "./large_upset.png")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment