使用TensorFlow进行高效DNA嵌入

Kindle君 2019-04-05

介绍

利用DNA或其他生物序列(如RNA或蛋白质序列)进行深度学习在过去几年里取得了很大进展。大多数从DNA序列中学习的模型使用one-hot编码方案,该方案使用四个通道来表示四个可能的核苷酸A、C、G和T。

使用TensorFlow进行高效DNA嵌入

长度为N的DNA序列的one-hot编码是(N×4)矩阵,其中列对应于字母A,C,G和T,并且每行恰好具有一个等于一的项,其他条目为零。

任何此类模型中的第一步是以one-hot方案对输入DNA序列进行编码,该方案可以由以下神经网络层处理,无论它们是全连接的,卷积的还是循环的。

通常,one-hot编码被视为在将数据馈送到模型之前发生的预处理步骤。例如,可以对整个训练数据集进行预编码,然后将其存储在磁盘上以用于训练。在其他情况下,模型包含在一些Python预处理层中,该层在Python-land中执行编码,而不是在TensorFlow图中执行。例如,Selene包是一个在PyTorch中处理生物序列的框架,它将编码作为预处理步骤实现,并在Cython中实现。

然而,纯粹从界面设计的角度来看,从DNA序列进行预测的模型应该接受序列作为字符串,而one-hot编码表示应该被视为一个实现细节,不需要对用户可见。这也使得将模型打包为tf.SavedModel并共享或部署更容易,因为该图本身接受DNA序列而无需用户执行他们自己的编码。

编码器只需要执行一个简单的工作:取一个序列,对于每个核苷酸,根据以下映射输出一个向量。

使用TensorFlow进行高效DNA嵌入

正常情况下,与运行模型本身的成本相比,DNA嵌入将是一项非常廉价的操作,但一个糟糕的嵌入实现可能最终成为一个严重的瓶颈。

事实证明,有很多方法可以使用本机TensorFlow操作来实现DNA one-hot编码,这里我将介绍三种方法。

我将首先介绍我的三个实现,然后我将对它们进行基准测试,看看是否存在相当大的差异。

使用lookup table

使用DNA序列学习在某些方面类似于自然语言处理,通常使用lookup table将字符串键从词汇表映射到整数id。

TensorFlow在tf.contrib.lookup模块中有一个lookup table类。该contrib模块将在即将推出的TensorFlow 2.0中弃用,但可能会有非常接近的替代品。TensorFlow有tf.one_hot函数,它可以将这些整数ID转换为单热嵌入的功能。

以下Python函数将字符串格式的DNA输入映射到整数ID。我现在将省略one-hot 编码步骤,我们可以稍后单独对这两个步骤进行基准测试。

使用TensorFlow进行高效DNA嵌入

由于我们需要首先将字符串拆分为单个字符,因此我们使用该tf.string_split函数,因此该函数稍微复杂一些。tf.string_split但是,由于返回稀疏张量,我们需要将其转换回密集向量(lookup table只接受密集向量)。

这个函数稍微有点复杂,因为我们需要先将字符串分割成单独的字符,然后使用tfstring_split函数处理这些字符。因为tfstring_split返回一个稀疏张量,所以我们需要将它转换回一个dense向量(lookup table只接受dense向量)。

最后,table.lookup(seq)将结果作为整数id的张量返回。

使用位操作来计算整数索引

使用只有四个键的lookup table似乎有点过火。但是,有什么更简单的方法可以将DNA字母表映射到tf.one_hot函数的索引呢?一种方法是使用基本的位函数直接计算索引。我们需要执行的映射如下:

使用TensorFlow进行高效DNA嵌入

我们所需要做的就是使用基本的位操作符&、|、^、~、<<和>>找到一个操作序列,将左边的值转换为右边的值。

为了简单起见,我首先在纯Python中实现了以下一个这样的转换:

使用TensorFlow进行高效DNA嵌入

对于序列中的每个字母,这个Python代码片段首先清除第5位和第7位最低有效位(从右边开始),然后向右移1位,除了G和T的值交换之外,几乎得到了正确的位模式。因此,剩下的唯一步骤是用2替换3,用3替换2。第三行使用表达式(nt & 1 << 1)作为掩码,只影响右边第二位的值,然后使用xor翻转最右边的位。下表显示了每个输入的逐步转换。

使用TensorFlow进行高效DNA嵌入

总而言之,此操作将DNA alphabete映射ACGT到索引0,1,2,3。

TensorFlow中此函数的Python实现如下:

def dna_encode_bit_manipulation(seq, name='dna_encode'):
 with tf.name_scope(name):
 bytes = tf.decode_raw(seq, tf.uint8)
 bytes = tf.bitwise.bitwise_and(bytes, ~((1 << 6) | (1 << 4))
 bytes = tf.bitwise.right_shift(bytes, 1)
 mask = tf.bitwise.bitwise_and(bytes, 2)
 mask = tf.bitwise.right_shift(mask, 1)
 bytes = tf.bitwise.bitwise_xor(bytes, mask)
 return bytes

此函数可以替换上面的lookup table,然后可以使用tf.one_hot函数来获取最终编码。由于此

使用TensorFlow进行高效DNA嵌入

函数仅使用元素运算,因此它也可以在GPU上非常有效地运行。

使用嵌入表

自然语言处理的另一个概念是嵌入表,它将整数id作为输入并输出包含该id的嵌入的向量。在自然语言处理中,这些嵌入是随机初始化和训练的,但我们可以利用相同的工具将我们的序列映射到固定的one-hot编码。

在上面的两种方法中,我们将DNA字母映射到整数0,…,3,以便我们可以使用tf.one_hot函数。但是如果我们直接使用A、C、G和T的整数ASCII码作为索引呢?我们只需要把表做大一些,这样我们就可以使用tf.nn.embedding_lookup函数。

使用TensorFlow进行高效DNA嵌入

需要分配一个包含84行的嵌入表,因为这是ASCII码的T。

这种方法还有一些其他优点:它可以很容易地适应编码其他字母,如氨基酸序列,或者我们可以使用这种方法来解释IUPAC核苷酸通配符,例如定义R为嘌呤(A或G)或B定义为“除A之外的任何项”(C、G或T)。如果这是我们想要的,我们可以定义以下嵌入表:

使用TensorFlow进行高效DNA嵌入

基准

为了验证这些方法,我反复使用它们对人类基因组DMD (dystrophine)中最长的基因(约2.24 Mbp)进行one_hot编码。我使用twobitreader模块从hg38 2bit文件(http://hgdownload.soe.ucsc.edu/goldenPath/hg38/bigZips/)中提取序列,计算反补码,因为DMD在负链上:

使用TensorFlow进行高效DNA嵌入

基准测试使用的软件是Python 3.6.8(Anaconda),TensorFlow 1.12.0(conda包,支持GPU和MKL扩展)和CUDA 9.2。

我首先运行了dna_encode_lookup_table和dna_encode_bit_manipulation,没有使用tf.one_hot步骤,以便分别对这两个步骤进行基准测试。

我用timeit.repeat(number=10和repeats=10)。对于位操作方法,执行编码10次的10次重复中的最小值为14ms,对于lookup table方法,为2.17s。

dna_encode_embedding_table函数花了38毫秒直接计算最终的one-hot编码。

如果对前两个函数进行基准测试,然后对tf.one_hot应用程序进行基准测试,就会发现总时间主要由从索引中计算one-hot编码决定。位操作方法和one-hot编码的总时间为41 ms,lookup table方法和one-hot编码的总时间为2.18 s。

最后,下图显示了所有结果:

使用TensorFlow进行高效DNA嵌入

上图为运行时间用于DNA序列的one-hot编码的不同实现。任务是将DMD基因的完整序列(2.24 Mbp)嵌入 10次​​。蓝条仅显示计算整数索引的时间,橙条显示计算整数索引的总时间,然后是one-hot编码。

lookup table显然是一个糟糕的选择,非常低效。在计算整数索引时,位操作非常快,但如果我们包括计算one-hot编码所需的时间,则其大致与嵌入方法一样快。

总的来说,位操作方法具有一些简洁的魅力,但是当需要嵌入IUPAC通配符值时,或者当您使用不同的字母表(如蛋白质序列)时,它不能被使用。在大多数情况下,嵌入表可能是最实用的,因为它可以与任何序列字母表一起使用,并且运行速度与位操作方法一样快。

完整Python示例代码

import numpy as np
import tensorflow as tf
import twobitreader
import timeit
def tf_dna_encode_lookup_table(seq, name="dna_encode"):
 """Map DNA string inputs to integer ids using a lookup table."""
 
 with tf.name_scope(name):
 # Defining the lookup table
 mapping_strings = tf.constant(["A", "C", "G", "T"])
 table = tf.contrib.lookup.index_table_from_tensor(
 mapping=mapping_strings, num_oov_buckets=0, default_value=-1)
 
 # Splitting the string into single characters
 seq = tf.squeeze(
 tf.sparse.to_dense(
 tf.string_split([seq], delimiter=""),
 default_value=""), 0)
 return table.lookup(seq)
def tf_dna_encode_bit_manipulation(seq, name='dna_encode'):
 with tf.name_scope(name):
 bytes = tf.decode_raw(seq, tf.uint8)
 bytes = tf.bitwise.bitwise_and(bytes, ~(1 << 6))
 bytes = tf.bitwise.bitwise_and(bytes, ~(1 << 4))
 bytes = tf.bitwise.right_shift(bytes, 1)
 mask = tf.bitwise.bitwise_and(bytes, 2)
 mask = tf.bitwise.right_shift(mask, 1)
 bytes = tf.bitwise.bitwise_xor(bytes, mask)
 return bytes
#%%
def tf_dna_encode_embedding_table(dna_input, name="dna_encode"):
 """Map DNA sequence to one-hot encoding using an embedding table."""
 
 # Define the embedding table
 _embedding_values = np.zeros([89, 4], np.float32)
 _embedding_values[ord('A')] = np.array([1, 0, 0, 0])
 _embedding_values[ord('C')] = np.array([0, 1, 0, 0])
 _embedding_values[ord('G')] = np.array([0, 0, 1, 0])
 _embedding_values[ord('T')] = np.array([0, 0, 0, 1])
 _embedding_values[ord('W')] = np.array([.5, 0, 0, .5])
 _embedding_values[ord('S')] = np.array([0, .5, .5, 0])
 _embedding_values[ord('M')] = np.array([.5, .5, 0, 0])
 _embedding_values[ord('K')] = np.array([0, 0, .5, .5])
 _embedding_values[ord('R')] = np.array([.5, 0, .5, 0])
 _embedding_values[ord('Y')] = np.array([0, .5, 0, .5])
 _embedding_values[ord('B')] = np.array([0, 1. / 3, 1. / 3, 1. / 3])
 _embedding_values[ord('D')] = np.array([1. / 3, 0, 1. / 3, 1. / 3])
 _embedding_values[ord('H')] = np.array([1. / 3, 1. / 3, 0, 1. / 3])
 _embedding_values[ord('V')] = np.array([1. / 3, 1. / 3, 1. / 3, 0])
 _embedding_values[ord('N')] = np.array([.25, .25, .25, .25])
 embedding_table = tf.get_variable(
 'dna_lookup_table', _embedding_values.shape,
 initializer=tf.constant_initializer(_embedding_values),
 trainable=False) # Ensure that embedding table is not trained
 with tf.name_scope(name):
 dna_input = tf.decode_raw(dna_input, tf.uint8) # Interpret string as bytes
 dna_32 = tf.cast(dna_input, tf.int32)
 encoded_dna = tf.nn.embedding_lookup(embedding_table, dna_32)
 return encoded_dna
#%%
if __name__ == "__main__":
 import argparse
 parser = argparse.ArgumentParser()
 parser.add_argument(
 "genome_file", help="Location to genome 2bit file (hg38)")
 parser.add_argument(
 "-N", type=int, help="Number of iterations for each method")
 parser.add_argument("-r", type=int, help="Number of repeats")
 args = parser.parse_args()
 # Extract DMD sequence and compute reverse complement
 genome = twobitreader.TwoBitFile(args.genome_file)
 dmd_sequence = genome['chrX'][31097676:33339441].upper()
 def reverse_complement(seq):
 return "".join("TGCA"["ACGT".index(s)] for s in seq[::-1])
 dmd_sequence_r = reverse_complement(dmd_sequence)
 # Set up TensorFlow graph
 seq_t = tf.constant(dmd_sequence_r, tf.string)
 seq_encoded_bit_manip_t = tf.one_hot(tf_dna_encode_bit_manipulation(seq_t), 4)
 seq_encoded_lookup_t = tf.one_hot(tf_dna_encode_lookup_table(seq_t), 4)
 seq_encoded_embedding_table_t = tf_dna_encode_embedding_table(seq_t)
 # TensorFlow boilerplate
 session = tf.Session()
 with session.as_default():
 tf.tables_initializer().run()
 tf.global_variables_initializer().run()
 # Now benchmark each method
 print("### Benchmarking bit manipulation method ###")
 results = timeit.repeat(lambda: session.run(seq_encoded_bit_manip_t),
 number=args.N, repeat=args.r)
 print("""Bit manipulation method ({} iterations, {} repeats):
 Total time: {}
 Best time: {}
 """.format(args.N, args.r, sum(results), min(results)))
 print("### Benchmarking embedding table method ###")
 results = timeit.repeat(lambda: session.run(seq_encoded_embedding_table_t),
 number=args.N, repeat=args.r)
 print("""Embedding table method ({} iterations, {} repeats):
 Total time: {}
 Best time: {}
 """.format(args.N, args.r, sum(results), min(results)))
 print("### Benchmarking lookup table method ###")
 results = timeit.repeat(lambda: session.run(seq_encoded_lookup_t),
 number=args.N, repeat=args.r)
 print("""Lookup table method ({} iterations, {} repeats):
 Total time: {}
 Best time: {}
 """.format(args.N, args.r, sum(results), min(results)))

使用TensorFlow进行高效DNA嵌入

使用TensorFlow进行高效DNA嵌入

使用TensorFlow进行高效DNA嵌入

相关推荐