使用卷积神经网络进行心电图心律失常分类

sallyyoungsh 2018-07-05

在本文中,我们将实现这篇论文https://arxiv.org/pdf/1804.06812.pdf,其中我们将心电图分为7类,一种是正常的,另外六种是不同类型的心律失常。将一维心电图信号转化为二维心电图图像,不再需要噪声滤波和特征提取。这是很重要的,因为一些ECG节拍在噪声滤波和特征提取中被忽略。此外,可以通过扩大心电图图像来扩大训练数据,提高分类精度。由于一维心电信号的畸变会降低分类器的性能,因此在一维信号中很难进行数据增强。然而,用不同的裁剪方法增加二维心电图图像有助于CNN模型对单个心电图图像进行不同视点的训练。使用ECG图像作为ECG心律失常分类的输入数据,也有利于鲁棒性。

获取数据

我使用过MIT-BIH心律失常数据库(https://www.physionet.org/physiobank/database/mitdb/)进行CNN模型训练和测试,这在论文中已经提到过。MIT-BIH心律失常数据库包含48个半小时的双通道动态心电图记录摘录,这些记录来自1975年至1979年期间BIH心律失常实验室研究的47名受试者。每条记录有三个文件:1.注释文件 2.信号文件 3.头文件。

在这个实现中,我使用了lead2信号。下面Python代码片段中的get_records()函数创建数据集中所有记录的列表。beat_annotations()函数查找属于特定类别的beats指标(在下面的代码中,我找到了正常beats的指标)。segmentation()函数用于分割每个类别的beat。

def get_records():

""" Get paths for data in data/mit/ directory """

# Download if doesn't exist

# There are 3 files for each record

# *.atr is one of them

paths = glob('/path/to/MITDB/dataset/*.atr')

# Get rid of the extension

paths = [path[:-4] for path in paths]

paths.sort()

return paths

def beat_annotations(annotation):

""" Get rid of non-beat markers """

"""'N' for normal beats. Similarly we can give the input 'L' for left bundle branch block beats. 'R' for right bundle branch block

beats. 'A' for Atrial premature contraction. 'V' for ventricular premature contraction. '/' for paced beat. 'E' for Ventricular

escape beat."""

good = ['N']

ids = np.in1d(annotation.symbol, good)

# We want to know only the positions

beats = annotation.sample[ids]

return beats

def segmentation(records):

Normal = []

for e in records:

signals, fields = wfdb.rdsamp(e, channels = [0])

ann = wfdb.rdann(e, 'atr')

good = ['N']

ids = np.in1d(ann.symbol, good)

imp_beats = ann.sample[ids]

beats = (ann.sample)

for i in imp_beats:

beats = list(beats)

j = beats.index(i)

if(j!=0 and j!=(len(beats)-1)):

x = beats[j-1]

y = beats[j+1]

diff1 = abs(x - beats[j])//2

diff2 = abs(y - beats[j])//2

Normal.append(signals[beats[j] - diff1: beats[j] + diff2, 0])

return Normal

方法

使用卷积神经网络进行心电图心律失常分类

由于CNN模型将二维图像作为输入数据处理,所以在心电图数据预处理步骤中将心电图信号转换为心电图图像。利用这些获得的心电图图像,在CNN分类器步骤中对7种心电图类型进行分类。七类:房性早搏,正常,左束支传导阻滞,起搏心搏,室性早搏,右束支传导阻滞和心室逃逸搏动。

我已经通过绘制每个ECG beat将ECG信号转换成ECG图像。我首先使用Python的Biosppy模块检测ECG信号中的R峰。一旦找到R-peak,为了分割beat,我采用当前的R-peak值和最后的R-peak值,占据了两者之间的距离的一半,并将这些信号包括在当前beat中。同样地,我为下一个beat做了这个。

data = np.array(csv_data)

signals = []

count = 1

peaks = biosppy.signals.ecg.christov_segmenter(signal=data, sampling_rate = 200)[0]

for i in (peaks[1:-1]):

diff1 = abs(peaks[count - 1] - i)

diff2 = abs(peaks[count + 1]- i)

x = peaks[count - 1] + diff1//2

y = peaks[count + 1] - diff2//2

signal = data[x:y]

signals.append(signal)

count += 1

return signals

为了将这些分段信号转换为图像,我使用了Matplotlib和OpenCV。因为在ECG信号中,颜色并不重要,所以我将它们转换为灰度图像。

for count, i in enumerate(array):

fig = plt.figure(frameon=False)

plt.plot(i)

plt.xticks([]), plt.yticks([])

for spine in plt.gca().spines.values():

spine.set_visible(False)

filename = directory + '/' + str(count)+'.png'

fig.savefig(filename)

im_gray = cv2.imread(filename, cv2.IMREAD_GRAYSCALE)

im_gray = cv2.resize(im_gray, (128, 128), interpolation = cv2.INTER_LANCZOS4)

cv2.imwrite(filename, im_gray)

使用卷积神经网络进行心电图心律失常分类

数据增强

数据增强意味着增加数据点的数量。就图像而言,这可能意味着增加数据集中的图像数量。通过增加和平衡输入数据,我们可以实现高特异性和灵敏度。我增加了六种心电图心律失常beats(PVC,PAB,RBB,LBB,APC,VEB),采用九种不同的裁剪方法:左上,中上,右上,左中,中,右中,左下,中底,右下角。每种裁剪方法产生三种尺寸的ECG图像中的两种,即96×96。然后,将这些增强图像调整为原始尺寸,即128×128。

def cropping(image, filename):

#Left Top Crop

crop = image[:96, :96]

crop = cv2.resize(crop, (128, 128))

cv2.imwrite(filename[:-4] + 'leftTop' + '.png', crop)

#Center Top Crop

crop = image[:96, 16:112]

crop = cv2.resize(crop, (128, 128))

cv2.imwrite(filename[:-4] + 'centerTop' + '.png', crop)

#Right Top Crop

crop = image[:96, 32:]

crop = cv2.resize(crop, (128, 128))

cv2.imwrite(filename[:-4] + 'rightTop' + '.png', crop)

#Left Center Crop

crop = image[16:112, :96]

crop = cv2.resize(crop, (128, 128))

cv2.imwrite(filename[:-4] + 'leftCenter' + '.png', crop)

#Center Center Crop

crop = image[16:112, 16:112]

crop = cv2.resize(crop, (128, 128))

cv2.imwrite(filename[:-4] + 'centerCenter' + '.png', crop)

#Right Center Crop

crop = image[16:112, 32:]

crop = cv2.resize(crop, (128, 128))

cv2.imwrite(filename[:-4] + 'rightCenter' + '.png', crop)

#Left Bottom Crop

crop = image[32:, :96]

crop = cv2.resize(crop, (128, 128))

cv2.imwrite(filename[:-4] + 'leftBottom' + '.png', crop)

#Center Bottom Crop

crop = image[32:, 16:112]

crop = cv2.resize(crop, (128, 128))

cv2.imwrite(filename[:-4] + 'centerBottom' + '.png', crop)

#Right Bottom Crop

crop = image[32:, 32:]

crop = cv2.resize(crop, (128, 128))

cv2.imwrite(filename[:-4] + 'rightBottom' + '.png', crop)

CNN模型的架构

CNN模型的体系结构与本文中使用的体系结构相同:https://arxiv.org/pdf/1804.06812.pdf。已经使用了11层模型。该模型的主要结构与VGGNet非常相似。CNN模型对所有层使用Xavier初始化。对于激活函数,他们使用了指数线性单位(ELU),与ReLU不同,它已在VGGNet中使用。

使用卷积神经网络进行心电图心律失常分类

CNN模型的架构

使用卷积神经网络进行心电图心律失常分类

model = Sequential()

model.add(Conv2D(64, (3,3),strides = (1,1), input_shape = IMAGE_SIZE + [3],kernel_initializer='glorot_uniform'))

model.add(keras.layers.ELU())

model.add(BatchNormalization())

model.add(Conv2D(64, (3,3),strides = (1,1),kernel_initializer='glorot_uniform'))

model.add(keras.layers.ELU())

model.add(BatchNormalization())

model.add(MaxPool2D(pool_size=(2, 2), strides= (2,2)))

model.add(Conv2D(128, (3,3),strides = (1,1),kernel_initializer='glorot_uniform'))

model.add(keras.layers.ELU())

model.add(BatchNormalization())

model.add(Conv2D(128, (3,3),strides = (1,1),kernel_initializer='glorot_uniform'))

model.add(keras.layers.ELU())

model.add(BatchNormalization())

model.add(MaxPool2D(pool_size=(2, 2), strides= (2,2)))

model.add(Conv2D(256, (3,3),strides = (1,1),kernel_initializer='glorot_uniform'))

model.add(keras.layers.ELU())

model.add(BatchNormalization())

model.add(Conv2D(256, (3,3),strides = (1,1),kernel_initializer='glorot_uniform'))

model.add(keras.layers.ELU())

model.add(BatchNormalization())

model.add(MaxPool2D(pool_size=(2, 2), strides= (2,2)))

model.add(Flatten())

model.add(Dense(2048))

model.add(keras.layers.ELU())

model.add(BatchNormalization())

model.add(Dropout(0.5))

model.add(Dense(7, activation='softmax'))

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

验证

为了验证,我使用了skLearn的train_test_split分割训练和测试集,训练集中有352295张图像,测试集中有44041张图像,该模型在给定的测试集中的准确率为99.31。

相关推荐