xiaoxiaokeke 2020-01-14
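Recently I had a task at work: use trained models to produce predictions for every record in a database and save the results to another table. Each record is an article. My models classify four particular paragraphs of an article, so there are four models. For each text I extract those four paragraphs, predict the class of each one with its corresponding model, and write the results to the database. The models were trained with Keras on the TensorFlow backend. Because the data volume is large, multiprocessing was the obvious choice, but in practice every worker process hung at model.predict and never returned. On Windows the same code ran without any problem; the hang only happened on the Linux server.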
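The Windows/Linux difference most likely comes down to how `multiprocessing` starts its worker processes. Windows only supports the spawn start method, which launches a fresh Python interpreter for each worker. On Linux the default (in the Python 3 versions of that time) is fork, which copies the parent process, including the TensorFlow graph and session already created when the model modules were imported, and TensorFlow 1.x is not reliably safe to use across a fork. A small sketch for checking which start method a machine uses by default; `describe_start_methods` is a hypothetical helper, and only `get_start_method` and `get_all_start_methods` are standard-library calls:

```python
import multiprocessing
import sys


def describe_start_methods():
    """Return (platform, default start method, start methods available here)."""
    default = multiprocessing.get_start_method()
    available = multiprocessing.get_all_start_methods()
    return sys.platform, default, available


if __name__ == "__main__":
    # The result depends on the OS and the Python version.
    print(describe_start_methods())
```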
The model is used roughly like this:

```python
# -*- coding: utf-8 -*-
import json

import jieba
import numpy as np
import tensorflow as tf
from keras.models import load_model
from keras.preprocessing import sequence

from config import Config

config_file = 'data/config.ini'
model_path = Config(config_file).get_value_str('cnn', 'model_path')

# Load the model into a dedicated graph and session (TF 1.x API);
# predict() switches back to them before running the model.
graph = tf.Graph()
with graph.as_default():
    session = tf.Session()
    with session.as_default():
        model = load_model(model_path)

# Read the vocabulary once at import time instead of on every call.
with open('data/words.json', encoding='utf-8') as f:
    words_json = json.load(f)
words = set(words_json['words'])
word_to_id = words_json['word_to_id']
max_length = words_json['max_length']


def sentence_process(sentence):
    """Segment the sentence with jieba and map each word to its id."""
    segs = [x for x in jieba.lcut(sentence) if x]  # drop empty segments
    # Words that are not in the vocabulary are mapped to id 4999.
    vector = [word_to_id[seg] if seg in words else 4999 for seg in segs]
    return vector, max_length


def predict(sentence):
    vector, max_length = sentence_process(sentence)
    x_vector = sequence.pad_sequences(np.array([vector]), max_length)
    # Predict inside the graph and session the model was loaded into.
    with graph.as_default():
        with session.as_default():
            y = model.predict_proba(x_vector)
    # y[0][1] is the predicted probability of class 1.
    return 1 if y[0][1] > 0.5 else 0
```

The multiprocessing part looks roughly like this:
```python
from multiprocessing import Pool

from classifaction.classify1 import predict1
from classifaction.classify2 import predict2
from classifaction.classify3 import predict3
from classifaction.classify4 import predict4


def main():
    texts = ...  # get texts: one item per article, holding its four paragraphs (omitted)
    pool = Pool(processes=4, maxtasksperchild=1)
    pool.map(save_to_database, texts)
    pool.close()
    pool.join()


def save_to_database(paragraphs):
    text1, text2, text3, text4 = paragraphs[:4]
    label1 = predict1(text1)
    label2 = predict2(text2)
    label3 = predict3(text3)
    label4 = predict4(text4)
    # ... write label1..label4 to the result table (omitted)


if __name__ == '__main__':
    main()
```

When I ran it, every worker process hung at model.predict.
After some googling I found that many people have run into this problem, and I eventually found a fix. See this issue:
https://github.com/keras-team/keras/issues/9964
One of the suggestions there:
As of TF 1.10, the library seems to be somewhat forkable. So you will have to test what you can do. Also, something you can try is: multiprocessing.set_start_method('spawn', force=True) if you're on UNIX and using Python3.
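The suggestion above changes the global default start method. The standard library also lets a program pick the start method for a single pool through a context object returned by `multiprocessing.get_context()`, which leaves the global default untouched. A minimal sketch; `run_with_spawn` is a hypothetical helper and `math.factorial` only stands in for the real per-article prediction function:

```python
import math
import multiprocessing as mp


def run_with_spawn(values, processes=2):
    """Map a function over values in a pool whose workers are started with spawn."""
    # get_context("spawn") returns a context whose Pool uses spawn,
    # without changing the global default start method.
    ctx = mp.get_context("spawn")
    # math.factorial stands in for the real prediction function.
    with ctx.Pool(processes=processes) as pool:
        return pool.map(math.factorial, values)


if __name__ == "__main__":
    print(run_with_spawn([0, 3, 5]))  # [1, 6, 120]
```

As with `set_start_method('spawn')`, the pool must be created under the `if __name__ == '__main__':` guard, because spawned workers re-import the main module.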
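So the fix is to set the start method before multiprocessing is used. It has to be called before the Pool is created: Pool takes the start method that is current when it is constructed and starts its worker processes right away. The modified code:

```python
import multiprocessing
from multiprocessing import Pool

from classifaction.classify1 import predict1
from classifaction.classify2 import predict2
from classifaction.classify3 import predict3
from classifaction.classify4 import predict4


def main():
    # Switch to spawn before any Pool is created; force=True overrides
    # a start method that has already been set.
    multiprocessing.set_start_method('spawn', force=True)
    texts = ...  # get texts: one item per article, holding its four paragraphs (omitted)
    pool = Pool(processes=4, maxtasksperchild=1)
    pool.map(save_to_database, texts)
    pool.close()
    pool.join()


def save_to_database(paragraphs):
    text1, text2, text3, text4 = paragraphs[:4]
    label1 = predict1(text1)
    label2 = predict2(text2)
    label3 = predict3(text3)
    label4 = predict4(text4)
    # ... write label1..label4 to the result table (omitted)


if __name__ == '__main__':
    main()
```

With this change, multiprocessing works.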
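One side effect of spawn: each worker re-imports the main module, so the top-level `from classifaction.classifyN import ...` lines run again and all four models are loaded in every worker. With `maxtasksperchild=1`, every worker exits after a single task (with `pool.map`, one chunk of articles) and its replacement has to load the models again. A possible alternative, sketched under assumptions: load the models once per worker in a Pool initializer and let each worker handle many articles. `_load_models`, `init_worker`, `classify_article` and `classify_all` are hypothetical names, and the "models" are dummy functions standing in for the Keras models so the sketch runs without TensorFlow:

```python
import multiprocessing as mp

# Per-process cache of the loaded models; filled in by init_worker.
_MODELS = None


def _load_models():
    """Hypothetical stand-in for loading the four Keras models.

    In the real project this would call load_model() in its own
    tf.Graph/tf.Session for each model. Here each dummy "model" labels a
    paragraph 1 if its length is even, else 0.
    """
    return [lambda text: int(len(text) % 2 == 0) for _ in range(4)]


def init_worker():
    """Pool initializer: runs once in each worker process after it starts."""
    global _MODELS
    _MODELS = _load_models()


def classify_article(paragraphs):
    """Label each of an article's four paragraphs with its own model."""
    return [model(text) for model, text in zip(_MODELS, paragraphs)]


def classify_all(articles, processes=4):
    ctx = mp.get_context("spawn")
    # No maxtasksperchild: each worker keeps its models for many articles.
    with ctx.Pool(processes=processes, initializer=init_worker) as pool:
        return pool.map(classify_article, articles)


if __name__ == "__main__":
    articles = [["ab", "c", "de", "f"], ["xy", "zz", "q", "rs"]]
    print(classify_all(articles, processes=2))  # [[1, 0, 1, 0], [1, 1, 0, 1]]
```

Whether dropping `maxtasksperchild=1` is acceptable depends on why it was set (for example, to release memory after each task); the original does not say.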