Skip to content

Instantly share code, notes, and snippets.

@surya501
Created March 14, 2017 18:31
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save surya501/1ad1235bcb75131e25f8ab6da99ddfc7 to your computer and use it in GitHub Desktop.
Save surya501/1ad1235bcb75131e25f8ab6da99ddfc7 to your computer and use it in GitHub Desktop.
Airport embedding in keras
# %matplotlib inline
import pandas as pd
import numpy as np
import keras
import tensorflow as tf
import os
from keras.models import Model
from keras.layers import Input, Embedding, merge
from keras.layers.core import Flatten, Dense, Dropout, Lambda
from keras.optimizers import Adam
from keras.regularizers import l2
# Data source: https://www.transtats.bts.gov/DL_SelectFields.asp?Table_ID=236
# Download the flight data from the URL above, then unzip/rename it to
# 'flights_1.csv'. Select the following fields when exporting:
# Origin, Dest, Actual Elapsed Time, Distance.
flights = pd.read_csv('flights_1.csv')

# Keep only rows whose elapsed time is a finite number (drops empty/NaN
# values); ACTUAL_ELAPSED_TIME is the regression target used below.
flights = flights[np.isfinite(flights['ACTUAL_ELAPSED_TIME'])]

# Drop the columns that are not used by the model. `axis` is passed by keyword
# because newer pandas releases no longer accept it positionally, and
# errors='ignore' keeps the script working when the export does not contain
# the 'Unnamed: 4' column.
flights = flights.drop(['Unnamed: 4', 'DISTANCE'], axis=1, errors='ignore')
flights.head()

# Build the airport vocabulary from BOTH endpoints. An airport that only ever
# appears as a destination would otherwise be missing from the index and the
# DEST lookup below would fail. np.union1d returns the sorted, de-duplicated
# union of the two arrays.
airports = np.union1d(flights.ORIGIN.unique(), flights.DEST.unique())
airports_d = airports
n_airports = len(airports)
n_airports
airports2idx = {code: idx for idx, code in enumerate(airports)}

# Convert all airport codes to their integer index in the flights data.
flights['ORIGIN'] = flights['ORIGIN'].map(airports2idx)
flights['DEST'] = flights['DEST'].map(airports2idx)

# Separate the data into training and validation sets with a random mask:
# rows where the uniform draw is below 0.8 go to training, the rest to
# validation. No seed is set, so the split differs between runs.
msk = np.random.rand(len(flights)) < 0.8
trn = flights[msk]
val = flights[~msk]
def embedding_input(name, n_in, n_out, reg):
    """Create an integer input tensor and its L2-regularised embedding.

    Args:
        name: Name given to the Keras ``Input`` layer.
        n_in: First argument passed to ``Embedding`` (vocabulary size).
        n_out: Second argument passed to ``Embedding`` (embedding width).
        reg: Strength handed to ``l2`` for the embedding weight regulariser.

    Returns:
        A tuple ``(input_tensor, embedded_tensor)`` where ``embedded_tensor``
        is the embedding layer applied to ``input_tensor``.
    """
    # One id per sample, hence shape=(1,) and input_length=1.
    id_input = Input(shape=(1,), dtype='int64', name=name)
    embedding_layer = Embedding(n_in, n_out, input_length=1,
                                W_regularizer=l2(reg))
    return id_input, embedding_layer(id_input)
# Width of each airport embedding vector.
n_factors = 5

# Separate embeddings for the origin and destination airports, each with its
# own input layer.
origin_airports, o = embedding_input('airports', n_airports, n_factors, 1e-4)
dest_airports, d = embedding_input('airports_d', n_airports, n_factors, 1e-4)

# Concatenate both embeddings and feed them through a small dense network
# that regresses a single value (the elapsed time).
x = merge([o, d], mode='concat')
x = Flatten()(x)
# x = Dropout(0.3)(x)
x = Dense(70, activation='relu')(x)
# x = Dropout(0.7)(x)
x = Dense(1)(x)

nn = Model([origin_airports, dest_airports], x)
nn.compile(Adam(0.001), loss='mse')

# First training phase at the learning rate given to Adam above.
nn.fit([trn.ORIGIN, trn.DEST], trn.ACTUAL_ELAPSED_TIME,
       batch_size=64, nb_epoch=8,
       validation_data=([val.ORIGIN, val.DEST], val.ACTUAL_ELAPSED_TIME))

# Change the learning rate for the second phase. The value must be written
# INTO the optimizer's existing backend variable: rebinding the attribute
# (`nn.optimizer.lr = 0.1`) only replaces the Python reference, while the
# training function built by the first fit() keeps using the old variable, so
# the new rate would silently be ignored.
keras.backend.set_value(nn.optimizer.lr, 0.1)

# Second training phase at the updated learning rate.
nn.fit([trn.ORIGIN, trn.DEST], trn.ACTUAL_ELAPSED_TIME,
       batch_size=64, nb_epoch=20,
       validation_data=([val.ORIGIN, val.DEST], val.ACTUAL_ELAPSED_TIME))
# Save the current session's variables as a checkpoint so the embeddings can
# be inspected in TensorBoard.
saver = tf.train.Saver()

# Resolve the log directory relative to the current user's home instead of a
# hard-coded '/home/<user>' path, matching the `~/tensorboard_log` location
# used in the TensorBoard command below.
LOG_DIR = os.path.expanduser('~/tensorboard_log')
os.makedirs(LOG_DIR, exist_ok=True)
saver.save(keras.backend.get_session(), os.path.join(LOG_DIR, "model.ckpt"))

# Start TensorBoard with `tensorboard --logdir=~/tensorboard_log` and explore
# the embeddings tab.
# NOTE(review): the original hint said to look for a 294x2 tensor, but the
# embeddings are created with n_factors = 5, so the tensors are presumably
# n_airports x 5 - confirm against the TensorBoard listing.
def estimate_time(src_airport, dest_airport):
    """Predict the elapsed time for a flight between two airports.

    Args:
        src_airport: Origin airport code; must be a key of ``airports2idx``.
        dest_airport: Destination airport code; must be a key of
            ``airports2idx``.

    Returns:
        The array returned by ``nn.predict`` for this single origin/destination
        pair.

    Raises:
        KeyError: If either code is not present in ``airports2idx``.
    """
    # The model takes one array per input layer, each holding a single index.
    origin_idx = np.array([airports2idx[src_airport]])
    dest_idx = np.array([airports2idx[dest_airport]])
    return nn.predict([origin_idx, dest_idx])


# Example: estimate the elapsed time from ABE to ATL.
estimate_time('ABE', 'ATL')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment