@lamhoangtung
Created November 29, 2018 08:30
Basic TTA
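```python
import numpy as np


class TTA_ModelWrapper:
    """A simple TTA (test-time augmentation) wrapper for Keras computer vision models.
    Args:
        model (keras model): A fitted Keras model with a predict method.
    """
    def __init__(self, model):
        self.model = model

    def predict(self, X):
        """Wraps the predict method of the provided model.
        Augments the test data with horizontal and vertical flips and
        averages the predictions on the original and flipped copies.
        Args:
            X (numpy array of dim 4): The data to get predictions for,
                assumed channels-last, i.e. shape (N, H, W, C).
        """
        p0 = self.model.predict(X, batch_size=128, verbose=1)
        # Horizontal flip reverses the width axis (axis 2). np.flipud would
        # reverse axis 0 instead, i.e. the sample order, misaligning the average.
        p1 = self.model.predict(np.flip(X, axis=2), batch_size=128, verbose=1)
        # Vertical flip reverses the height axis (axis 1).
        p2 = self.model.predict(np.flip(X, axis=1), batch_size=128, verbose=1)
        # Assumes per-image outputs (e.g. class scores) that need no un-flipping.
        p = (p0 + p1 + p2) / 3
        return np.array(p)
```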
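The choice of flip axis is easy to get wrong, so here is a minimal sketch for checking it without a trained network. It swaps in a hypothetical `MeanPixelModel` whose "prediction" is just each image's mean pixel value, which does not change under horizontal or vertical flips; for such a model, correct TTA should reproduce the plain predictions exactly. `tta_predict` and `buggy_tta_predict` are illustrative helper names, not part of the gist. The second one averages over `np.flipud(X)`, which for an (N, H, W, C) array reverses axis 0 (the order of the samples) rather than flipping each image.

```python
import numpy as np


class MeanPixelModel:
    """Hypothetical stand-in for a fitted Keras model.

    Its "prediction" is the mean pixel value of each image, which is unchanged
    by horizontal or vertical flips. It accepts (and ignores) the
    batch_size/verbose keywords that the wrapper passes to Keras.
    """

    def predict(self, X, batch_size=None, verbose=0):
        return X.mean(axis=(1, 2, 3)).reshape(-1, 1)


def tta_predict(model, X):
    """Hypothetical helper: averages over original, horizontal and vertical flips."""
    p0 = model.predict(X, batch_size=128, verbose=0)
    p1 = model.predict(np.flip(X, axis=2), batch_size=128, verbose=0)  # horizontal
    p2 = model.predict(np.flip(X, axis=1), batch_size=128, verbose=0)  # vertical
    return (p0 + p1 + p2) / 3


def buggy_tta_predict(model, X):
    """Hypothetical helper: np.flipud reverses axis 0, i.e. the batch order."""
    p0 = model.predict(X, batch_size=128, verbose=0)
    p1 = model.predict(np.flipud(X), batch_size=128, verbose=0)
    return (p0 + p1) / 2


# Two 4x4 single-channel images: image 0 has a bright left half, image 1 is black.
X = np.zeros((2, 4, 4, 1), dtype=np.float32)
X[0, :, :2, :] = 1.0

model = MeanPixelModel()
print(model.predict(X).ravel())              # plain predictions: [0.5 0. ]
print(tta_predict(model, X).ravel())         # identical, since the model is flip-invariant
print(buggy_tta_predict(model, X).ravel())   # [0.25 0.25]: predictions mixed across samples
```

With the batch-axis flip, the black image inherits half of the bright image's score, which is the misalignment that flipping axes 1 and 2 avoids.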