Skip to content

Instantly share code, notes, and snippets.

@hannes-brt
Last active April 26, 2018 11:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save hannes-brt/137a3ad081fbd891eafb24e1ae9a6aad to your computer and use it in GitHub Desktop.
Save hannes-brt/137a3ad081fbd891eafb24e1ae9a6aad to your computer and use it in GitHub Desktop.
Function to decode a COSSMO training example in tfrecord format
def read_single_cossmo_example(serialized_example, n_tissues=1, coord_sys='rna1'):
    """Decode a single serialized COSSMO ``SequenceExample``.

    Parameters
    ----------
    serialized_example : Scalar string tensor holding one serialized
        ``tf.train.SequenceExample`` (e.g. the value returned by
        ``tf.TFRecordReader.read``).
    n_tissues : Length of the per-site ``psi`` and ``psi_std`` vectors.
    coord_sys : One of 'rna1' or 'dna0'. With 'dna0' the record must also
        contain a scalar string context feature 'strand', which is decoded too.

    Returns
    -------
    (context, sequence) : The tuple of two dicts of tensors returned by
        ``tf.parse_single_sequence_example``.
        ``context`` holds 'n_alt_ss', 'event_type', 'const_seq',
        'const_site_id', 'const_site_position' and, for 'dna0', 'strand'.
        ``sequence`` holds 'alt_seq', 'psi', 'psi_std', 'alt_ss_position' and
        'alt_ss_type', one entry per alternative splice site.

    Raises
    ------
    ValueError
        If ``coord_sys`` is not 'rna1' or 'dna0'.
    """
    # Validate before touching TensorFlow; unlike ``assert`` this check is
    # not stripped when Python runs with ``-O``.
    if coord_sys not in ('dna0', 'rna1'):
        raise ValueError(
            "coord_sys must be one of 'rna1' or 'dna0', got %r" % (coord_sys,))

    # The module only imports TensorFlow inside its ``__main__`` guard, so the
    # import is needed here for the function to work when imported elsewhere.
    import tensorflow as tf

    # Per-example (context) features.
    context_features = {
        'n_alt_ss': tf.FixedLenFeature([], tf.int64),
        'event_type': tf.FixedLenFeature([], tf.string),
        'const_seq': tf.FixedLenFeature([2], tf.string),
        'const_site_id': tf.FixedLenFeature([], tf.string),
        'const_site_position': tf.FixedLenFeature([], tf.int64),
    }
    if coord_sys == 'dna0':
        context_features['strand'] = tf.FixedLenFeature([], tf.string)

    # Per-alternative-splice-site (sequence) features.
    sequence_features = {
        'alt_seq': tf.FixedLenSequenceFeature([2], tf.string),
        'psi': tf.FixedLenSequenceFeature([n_tissues], tf.float32),
        'psi_std': tf.FixedLenSequenceFeature([n_tissues], tf.float32),
        'alt_ss_position': tf.FixedLenSequenceFeature([], tf.int64),
        'alt_ss_type': tf.FixedLenSequenceFeature([], tf.string),
    }

    decoded_features = tf.parse_single_sequence_example(
        serialized_example,
        context_features=context_features,
        sequence_features=sequence_features,
    )
    return decoded_features
def read_data_files(alt_ss_type, input_files, n_tissues=1,
num_epochs=None, shuffle=False, sort=True):
"""Read and decode a list of COSSMO tfrecord files.
Parameters
----------
alt_ss_type : One of 'acceptor' or 'donor'.
input_files : A list of paths to the tfrecord files.
n_tissues : Number of tissues (is always one in our dataset.
num_epochs : Number of epochs for which to repeat, or None to repeat forever.
shuffle : Whether to shuffle training examples.
sort : Whether to sort alternative splice sites from 5' to 3'. Should be true
when training recurrent models or whenever the order of splice sites is
important.
Returns
-------
decoded_example : dict
Fields
------
tfrecord_key : (tf.Tensor, shape=(), dtype=string) A unique key for every
training example.
event_type : (tf.Tensor, shape=(), dtype=string) The event type
('acceptor' or 'donor').
const_site_id : (tf.Tensor, shape=(), dtype=string) The constitutive site as
'chromosome:strand:position'.
const_site_position : (tf.Tensor, shape=(), dtype=int64) The position of the
constitutive site. Positions are in "RNA1" format, i.e. forward strand
positions are positive and one based and reverse strand positions are
negative numbers.
const_seq : (tf.Tensor, shape=(2,), dtype=object) The first element is 40nt of
the intronic sequence of the constitutive site and the second one is the
the exonic sequence. I.e. when event type is 'acceptor', the first element
is 40nt upstream of the donor, and the second element is 40nt downstream.
When event type is 'donor', the first element is 40nt downstream of the
constitutive acceptor and the second element is 40nt upstream of it. Sequence
is according to the coding strand.
const_dna_seq : (tf.Tensor, shape=(80,), dtype=uint8) This is the sequence from
a symmetric window of 80nt around the constitutive splice site, encoded as
the ASCII code of the nucleotide (uppercase only).
n_alt_ss : (tf.Tensor, shape=(), dtype=int64) The number of alternative splice
sites (K).
alt_ss_position : (tf.Tensor, shape=(K,), dtype=int64) Position, in RNA1
coordinatesFor each of K alternative splice sites.
alt_ss_type : (tf.Tensor, shape=(K,), dtype=string) The type of each of K
alternative splice sites. Value is one of:
- 'annotated': Splice site is from Gencode v19 annotations.
- 'gtex': A de-novo splice site found in GTEx RNA-Seq data.
- 'maxent': A "decoy" splice site that has MaxEntScore score >= 3.0,
but without RNA-Seq evidence to be used as a splice site.
- 'hard_negative': A random genomic location.
alt_seq : (tf.Tensor, shape=(K, 2), dtype=string) The first dimension corresponds to
K alternative splice site. In each row, the first element is the intronic sequence
of the splice site and the second one is the exonic one. I.e. when event type is
'acceptor', the first element is 40nt upstream of the acceptor and the second element
is 40nt downstream. When event type is 'donor', the first element is 40nt downstream
of the donor site and the second one is 40nt upstream.
alt_dna_seq : (tf.Tensor, shape=(K, 80), dtype=uint8) For each of K alternative sites, the
sequence from a 80nt window around it, encoded as the ASCII code of the nucleotide
(uppercase only).
rna_seq : (tf.Tensor, shape=(K, 80), dtype=uint8) For each of K alternative sites, this
is the "post-splicing" (mRNA) sequence, encoded as the ASCII code of the nucleotide
(uppercase only). When event type is 'acceptor' this is the exonic sequence of the
constitutive site/donor with the exonic sequence of the alternative site/acceptor,
or the reverse when the event type is 'donor'.
psi : (tf.Tensor, shape=(K, 1), dtype=float32) The PSI estimated by the positional bootstrap
procedure for each alternative splice site .
psi_std : (tf.Tensor, shape=(K, 2), dtype=float32) The standard deviation of the PSI estimated
by the positional bootstrap procedure for each alternative splice site.
"""
with tf.name_scope('data_pipeline'):
assert (alt_ss_type in ('acceptor', 'donor'))
filename_queue = tf.train.string_input_producer(
input_files, num_epochs=num_epochs, shuffle=shuffle)
file_reader = tf.TFRecordReader()
tf_record_key, serialized_example = file_reader.read(filename_queue)
_decoded_example = read_single_cossmo_example(serialized_example,
n_tissues)
decoded_example = _decoded_example[0]
decoded_example.update(_decoded_example[1])
if sort:
sorted_distance_indices = tf.nn.top_k(-tf.abs(decoded_example['alt_ss_position'] -
decoded_example['const_site_position']),
k=tf.cast(decoded_example['n_alt_ss'], tf.int32),
sorted=True).indices
decoded_example['alt_seq'] = tf.gather(decoded_example['alt_seq'], sorted_distance_indices)
decoded_example['psi'] = tf.gather(decoded_example['psi'], sorted_distance_indices)
decoded_example['psi_std'] = tf.gather(decoded_example['psi_std'], sorted_distance_indices)
decoded_example['alt_ss_position'] = tf.gather(decoded_example['alt_ss_position'], sorted_distance_indices)
decoded_example['alt_ss_type'] = tf.gather(decoded_example['alt_ss_type'], sorted_distance_indices)
decoded_example['tfrecord_key'] = tf_record_key
const_exonic_seq, const_intronic_seq = \
tf.split(axis=0, num_or_size_splits=2, value=decoded_example['const_seq'])
alt_exonic_seq, alt_intronic_seq = \
tf.split(axis=1, num_or_size_splits=2, value=decoded_example['alt_seq'])
alt_exonic_seq = tf.squeeze(alt_exonic_seq, [1])
alt_intronic_seq = tf.squeeze(alt_intronic_seq, [1])
const_exonic_seq = tf.decode_raw(const_exonic_seq, tf.uint8)
const_intronic_seq = tf.decode_raw(const_intronic_seq, tf.uint8)
alt_exonic_seq = tf.decode_raw(alt_exonic_seq, tf.uint8)
alt_intronic_seq = tf.decode_raw(alt_intronic_seq, tf.uint8)
tile_multiples = tf.stack(
[tf.to_int32(decoded_example['n_alt_ss']), 1])
const_exonic_seq_tiled = tf.tile(
const_exonic_seq, tile_multiples
)
if alt_ss_type == 'acceptor':
rna_seq = tf.concat(axis=1, values=[const_exonic_seq_tiled, alt_exonic_seq])
const_dna = tf.squeeze(
tf.concat(axis=1, values=[const_exonic_seq, const_intronic_seq]),
[0])
alt_dna = tf.concat(axis=1, values=[alt_intronic_seq, alt_exonic_seq])
elif alt_ss_type == 'donor':
rna_seq = tf.concat(axis=1, values=[alt_exonic_seq, const_exonic_seq_tiled])
const_dna = tf.squeeze(
tf.concat(axis=1, values=[const_intronic_seq, const_exonic_seq]),
[0])
alt_dna = tf.concat(axis=1, values=[alt_exonic_seq, alt_intronic_seq])
decoded_example['rna_seq'] = rna_seq
decoded_example['const_dna_seq'] = const_dna
decoded_example['alt_dna_seq'] = alt_dna
return decoded_example
if __name__ == '__main__':
    import os

    import tensorflow as tf

    # Collect every tfrecord file in the input directory. The list is sorted
    # so the read order is reproducible when shuffle=False.
    tfrecord_dir = 'local/path/to/tfrecords'
    files = sorted(os.path.join(tfrecord_dir, f)
                   for f in os.listdir(tfrecord_dir) if f.endswith('tfrecord'))
    if not files:
        raise SystemExit('No tfrecord files found in %r' % (tfrecord_dir,))

    # Build the decoding graph with the pipeline defined in this module.
    decoded_examples_tensor = read_data_files('acceptor', files)

    sess = tf.Session()
    coordinator = tf.train.Coordinator()
    # Local variables must be initialized *before* the queue-runner threads
    # start. string_input_producer creates one when num_epochs is set.
    sess.run(tf.local_variables_initializer())
    threads = tf.train.start_queue_runners(sess=sess, coord=coordinator)
    try:
        # Training examples can now be read.
        decoded_examples_values = sess.run(decoded_examples_tensor)
        # ...continue with batching etc.
    finally:
        # Stop and join the reader threads, then release the session, so the
        # process can exit cleanly even if reading fails.
        coordinator.request_stop()
        coordinator.join(threads)
        sess.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment