"""
Reimplementation of nanonet using keras.
Follow the instructions at
https://www.tensorflow.org/install/install_linux
to setup an NVIDIA GPU with CUDA8.0 and cuDNN v5.1.
virtualenv venv --python=python3
. venv/bin/activate
pip install numpy
pip install git+https://github.com/nanoporetech/nanonet@e8ff1edf
pip install --upgrade tensorflow-gpu keras numpy
python keras_call.py --help
Reuses bits of nanonet for peripheral calculations and decoding. Overall speed
is limited by decoding step. Reads are chunked into a maximum of 1000 feature
vectors for processing on GPU. Batch sizes are set for GPU with 11GB. Stitching
together of read chunks is not performed. On a AWS K80 GPU, the network performs
at around 140 feature vectors per second.
For training files should have data labelled as nanonet requires, or can be hacked
in a similar fashion to nanonet.
Results should be roughly equivalent to nanonet, only the LSTM implementation
used does not contain peepholes.
This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0.
If a copy of the MPL was not distributed with this file, You can obtain one at
http://mozilla.org/MPL/2.0/.
(c) 2017 Oxford Nanopore Technologies Ltd.
"""
import argparse
import errno
import os
import sys
from glob import glob
import multiprocessing
import itertools
from collections import Counter
import numpy as np
from keras.utils import to_categorical
from keras.models import Sequential, model_from_json
from keras.layers import LSTM, Dense
from keras.layers.wrappers import Bidirectional
from keras.callbacks import ModelCheckpoint, CSVLogger, TensorBoard, EarlyStopping
from nanonet.features import *
from nanonet.nanonetcall import *
from nanonet.fast5 import Fast5
from nanonet.util import all_nmers
from nanonet.eventdetection.filters import minknow_event_detect
from nanonet.util import tang_imap
import logging


def mkdir_p(path, info=None):
    """Make a directory if it doesn't exist."""
    try:
        os.makedirs(path)
    except OSError as exc:  # Python >2.5
        if exc.errno == errno.EEXIST and os.path.isdir(path):
            info = " {}".format(info) if info is not None else ""
            logging.warning("The path {} exists.{}".format(path, info))
        else:
            raise


def grouper(iterable, n):
    """Yield fixed size chunks of an iterable. Remainder is not padded."""
    it = iter(iterable)
    while True:
        chunk = tuple(itertools.islice(it, n))
        if not chunk:
            return
        yield chunk
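
# For example, grouper(range(5), 2) yields (0, 1), (2, 3) and finally (4,).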


def fast5_to_features(fast5_files, section='template', window=[-1, 0, 1], event_detect=True,
                      ed_params={'window_lengths': [3, 6], 'thresholds': [1.4, 1.1], 'peak_height': 0.2},
                      sloika_model=False):
    """Generate features from scratch (not using mapping event data as in training)."""
    skipped = 0
    for f in fast5_files:
        try:
            with Fast5(f) as fh:
                if event_detect:
                    raw = fh.get_read(raw=True)
                    events = minknow_event_detect(
                        raw, fh.sample_rate, **ed_params
                    )
                else:
                    events = fh.get_read()
        except Exception as e:
            skipped += 1
            continue
        try:
            X = events_to_features(events, window=window, sloika_model=sloika_model)
        except TypeError:
            skipped += 1
            continue
        yield f, X
    logging.info("Skipped generating features for {} reads.".format(skipped))


def create_labels(kmer_len, alphabet):
    """Create the ordered list of kmer states and a kmer-to-index lookup."""
    kmers = all_nmers(kmer_len, alpha=alphabet)
    bad_kmer = 'X' * kmer_len
    kmers.append(bad_kmer)
    all_kmers = {k: i for i, k in enumerate(kmers)}
    return kmers, all_kmers
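
# Note: with the default kmer_len=5 and alphabet 'ACGT', create_labels yields
# 4**5 = 1024 kmers plus the junk kmer 'XXXXX', i.e. 1025 output classes
# (cf. num_classes = 1025 in train() below).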


def make_training_input(fast5_files, window=[-1, 0, 1], kmer_len=5, alphabet='ACGT',
                        chunk_size=1000, min_chunk=1000, trim=10,
                        get_events=get_events_ont_mapping, get_labels=get_labels_ont_mapping,
                        callback_kwargs={'section': 'template', 'kmer_len': 5}):
    """Generate training input, adapted from nanonet's equivalent.

    :param fast5_files: list of .fast5 files to process.
    :param window: event window to derive features.
    :param kmer_len: length of kmers to learn.
    :param alphabet: alphabet of kmers.
    :param chunk_size: chunk size to break reads into for SGE batching.
    :param min_chunk: minimum chunk size (used to discard the remainder of reads).
    :param trim: no. of feature vectors to trim (from either end).
    :param get_events: callback to return event data, will be passed the .fast5 filename.
    :param get_labels: callback to return event kmer labels, will be passed the .fast5 filename.
    :param callback_kwargs: kwargs for both `get_events` and `get_labels`.

    :returns: dictionary of structure {filename: (features, labels)}. Labels
        are state indices which will need transforming according to the keras
        loss function that is used.
    """
    # Our state labels are kmers plus a junk kmer
    kmers, all_kmers = create_labels(kmer_len, alphabet)

    data = dict()
    for i, f in enumerate(fast5_files):
        try:
            # Run callbacks to get features and labels
            X = events_to_features(get_events(f, **callback_kwargs), window=window)
            labels = get_labels(f, **callback_kwargs)
        except Exception:
            logging.debug("Couldn't get features/labels: {}".format(f))
            continue
        try:
            X = X[trim:-trim]
            labels = labels[trim:-trim]
            if len(X) != len(labels):
                raise RuntimeError('Length of features and labels not equal.')
        except Exception:
            logging.debug("Features/labels bad: {}".format(f))
            continue
        try:
            # Convert kmers to ints
            y = np.fromiter(
                (all_kmers[k] for k in labels),
                dtype=np.int16, count=len(labels)
            )
        except Exception as e:
            # Checks for an erroneous alphabet or kmer length
            raise RuntimeError(
                'Could not convert kmer labels to ints in file {}. '
                'Check labels are no longer than {} and contain only {}.'.format(f, kmer_len, alphabet)
            )
        else:
            for chunk, (X_chunk, y_chunk) in enumerate(zip(chunker(X, chunk_size), chunker(y, chunk_size))):
                if len(X_chunk) < min_chunk:
                    break
                ident = '{}_{}'.format(f, chunk)
                data[ident] = (X_chunk, y_chunk)
    return data


def generate_features(fast5_files, jobs=multiprocessing.cpu_count()):
    """Generate training features and labels using multi-processing."""
    all_data = dict()
    logging.info("Processing {} files.".format(len(fast5_files)))
    n_processed = 0
    files_per_worker = 100
    file_gen = (list(x) for x in grouper(fast5_files, files_per_worker))
    for i, data in enumerate(tang_imap(make_training_input, file_gen, unordered=True, threads=jobs)):
        all_data.update(data)
        n_processed += len(data)
        logging.info("Processed {} read chunks ({} files).".format(n_processed, (i + 1) * files_per_worker))
    logging.info("Finished generating features.")
    return all_data


def build_model(timesteps, data_dim, num_classes):
    """Build a nanonet-style graph.

    The keras LSTM implementation follows Graves 2013 (with forget gate
    bias equal to 1). Usually we add in peepholes.
    """
    layer_size = 96
    model = Sequential()
    # The Bidirectional wrapper takes a copy of the first argument and reverses
    # the direction. Weights are independent between the two components.
    model.add(Bidirectional(
        LSTM(layer_size, return_sequences=True, name='lstm1', implementation=2),
        input_shape=(timesteps, data_dim)
    ))
    model.add(Dense(128, activation='tanh', name='ff1'))
    model.add(Bidirectional(
        LSTM(layer_size, return_sequences=True, name='lstm2', implementation=2)
    ))
    model.add(Dense(128, activation='tanh', name='ff2'))
    model.add(Dense(num_classes, activation='softmax', name='classify'))
    return model
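
# A minimal sketch of constructing the network (values are illustrative; the
# feature dimension depends on what events_to_features produces for the chosen
# window):
#
#     model = build_model(timesteps=1000, data_dim=12, num_classes=1025)
#
# 1000 matches the chunk_size used when generating features and 1025 the
# number of 5-mer states (1024 kmers plus the junk kmer).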


def save_model(fname, model):
    """Save model definition to a .json file."""
    with open(fname, 'w') as json_file:
        json_file.write(model.to_json())


def load_model(structure, weights):
    """Load a model from a .json file with weights initialized from a .hdf5 file."""
    with open(structure) as json_file:
        model_json = json_file.read()
    model = model_from_json(model_json)
    model.load_weights(weights)
    return model
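
# Example round trip using the filenames written by run_training below
# (the directory name is illustrative):
#
#     save_model('keras_train/model_structure.json', model)
#     model = load_model('keras_train/model_structure.json',
#                        'keras_train/weights.best.val.hdf5')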


def save_feature_file(fname, data):
    """Save feature dictionary."""
    np.save(fname, data)


def load_feature_file(fname):
    """Load the result of `save_feature_file` back to the original
    representation.
    """
    # np.save stores the dict as a 0-d object array, hence the `[()]` indexing.
    data = dict()
    src = np.load(fname)
    for key in src[()].keys():
        data[key] = src[()][key]
    return data
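
# Loading a previously generated feature set, e.g. the file written by train()
# below (path is illustrative):
#
#     data = load_feature_file('keras_train/keras_train_squiggles.npy')
#     features, labels = next(iter(data.values()))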


def run_training(train_name, x_train, y_train, num_classes, model_data=None):
    """Run training."""
    data_dim = x_train.shape[2]
    timesteps = x_train.shape[1]

    if model_data is None:
        model = build_model(timesteps, data_dim, num_classes)
    else:
        model = load_model(*model_data)
        # TODO: should check model data dimensions match data dimensions
    logging.info("data_dim: {}, timesteps: {}, num_classes: {}".format(data_dim, timesteps, num_classes))
    logging.info("\n{}".format(model.summary()))
    save_model(os.path.join(train_name, 'model_structure.json'), model)

    callbacks = [
        # Best model according to training set accuracy
        ModelCheckpoint(os.path.join(train_name, 'weights.best.hdf5'),
                        monitor='acc', verbose=1, save_best_only=True, mode='max'),
        # Best model according to validation set accuracy
        ModelCheckpoint(os.path.join(train_name, 'weights.best.val.hdf5'),
                        monitor='val_acc', verbose=1, save_best_only=True, mode='max'),
        # Checkpoints when training set accuracy improves
        ModelCheckpoint(os.path.join(train_name, 'weights-improvement-{epoch:02d}-{acc:.2f}.hdf5'),
                        monitor='acc', verbose=1, save_best_only=True, mode='max'),
        # Stop when there is no improvement; patience is the number of epochs
        # to allow without improvement.
        EarlyStopping(monitor='val_loss', patience=20),
        # Log of epoch stats
        CSVLogger(os.path.join(train_name, 'training.log')),
        # Allow us to run tensorboard to see how things are going. Some
        # features require validation data, not clear why.
        TensorBoard(log_dir=os.path.join(train_name, 'logs'),
                    histogram_freq=5, batch_size=100, write_graph=True,
                    write_grads=True, write_images=True)
    ]

    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer='rmsprop',
        metrics=['accuracy'],
    )

    # It may be possible to increase batch_size for faster processing.
    model.fit(
        x_train, y_train,
        batch_size=100, epochs=5000,
        validation_split=0.2,
        callbacks=callbacks,
    )


def train(args):
    """Training program."""
    train_name = args.train_name
    mkdir_p(train_name, info='Results will be overwritten and may use pregenerated features.')
    dataset_name = os.path.join(train_name, '{}_squiggles.npy'.format(train_name))
    logging.info("Using {} for feature storage/reading.".format(dataset_name))

    fast5s = glob(os.path.join(args.fast5_path, "*.fast5"))[:2000]
    logging.info("Found {} input files.".format(len(fast5s)))
    if not os.path.isfile(dataset_name):
        logging.info("Creating dataset. This may take a while.")
        data = generate_features(fast5s)
        save_feature_file(dataset_name, data)
    else:
        logging.info("Loading dataset from file.")
        data = load_feature_file(dataset_name)
    logging.info("Got {} squiggle chunks for training.".format(len(data)))

    x_data = []
    y_labels = []
    for fname in data.keys():
        x_data.append(data[fname][0])
        # This is the form required by keras' sparse_categorical_crossentropy:
        # one integer class label per timestep, each wrapped in its own list.
        y_labels.append([[yi] for yi in data[fname][1]])
    # Stack the individual samples into one big tensor
    x_data = np.stack(x_data)
    y_labels = np.stack(y_labels)

    num_classes = 1025  # TODO: obtain this from somewhere
    run_training(train_name, x_data, y_labels, num_classes, model_data=args.model_data)


def post_to_call(post, min_prob=1e-5):
    """Decode a posterior matrix into a basecall using nanonet's decoding."""
    kmers, _ = create_labels(5, 'ACGT')
    post, good_events = clean_post(post, kmers, min_prob)
    if post is None:
        return None
    # Decode kmers
    score, states = decoding.decode_homogenous(post, log=False)
    # Form basecall
    kmers = [x for x in kmers if 'X' not in x]
    qdata = get_qdata(post, kmers)
    seq, qual, kmer_path = form_basecall(qdata, kmers, states)
    return seq, qual


def run_prediction(data, model_structure, model_weights, output_file='basecalls.fasta'):
    """Run the network over feature chunks and decode basecalls to fasta."""
    from timeit import default_timer as now

    model = load_model(model_structure, model_weights)
    logging.info('\n{}'.format(model.summary()))

    t0 = now()
    class_probs = model.predict(data, batch_size=1500, verbose=1)
    t1 = now()
    logging.info('Running network took {}s for data of shape {}.'.format(t1 - t0, data.shape))

    t0 = now()
    count = 0
    with open(output_file, 'w') as fasta:
        for i, result in enumerate(tang_imap(post_to_call, class_probs, unordered=True,
                                             threads=multiprocessing.cpu_count())):
            if result is not None:
                seq, qual = result
                count += 1
                fasta.write(">block_{}\n{}\n".format(i, seq))
    t1 = now()
    logging.info('Decoding took {}s for {} blocks.'.format(t1 - t0, count))


def predict(args):
    """Inference program."""
    if args.fast5_path:
        fast5s = glob(os.path.join(args.fast5_path, "*.fast5"))[:500]
        logging.info("Found {} input files.".format(len(fast5s)))
        logging.info("Creating dataset. This may take a while.")
        data = generate_features(fast5s)
    else:
        logging.info("Loading dataset from file.")
        data = load_feature_file(args.feature_file)
    logging.info("Got {} squiggle chunks for basecalling.".format(len(data)))

    x_data = np.stack([x[0] for x in data.values()])
    run_prediction(x_data, args.model, args.weights)


def main():
    logging.basicConfig(format='[%(asctime)s - %(name)s] %(message)s', datefmt='%H:%M:%S', level=logging.INFO)
    parser = argparse.ArgumentParser(description='Reimplementation of nanonet using keras.')
    subparsers = parser.add_subparsers(title='subcommands', description='valid commands',
                                       help='additional help', dest='command')
    subparsers.required = True

    tparser = subparsers.add_parser('train', help='Train a model from labelled squiggles.')
    tparser.set_defaults(func=train)
    tparser.add_argument('fast5_path', help='Path to training fast5 files.')
    tparser.add_argument('--train_name', type=str, default='keras_train', help='Name for training run.')
    tparser.add_argument('--model_data', nargs=2, metavar=('def.json', 'weights.hdf'),
                         help='Model definition and initial weights.')

    pparser = subparsers.add_parser('predict', help='Basecall squiggles with a trained model.')
    pparser.set_defaults(func=predict)
    pparser.add_argument('model', help='Model structure json file from training.')
    pparser.add_argument('weights', help='Model weights HDF5 file from training.')
    ingroup = pparser.add_mutually_exclusive_group(required=True)
    ingroup.add_argument('--fast5_path', help='Path to fast5 files.')
    ingroup.add_argument('--feature_file', help='Pregenerated features as stored during training.')

    args = parser.parse_args()
    args.func(args)


if __name__ == '__main__':
    main()