1 Star 2 Fork 1

左令君 / seq2seq_keras_translate

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
train.py 7.58 KB
一键复制 编辑 原始数据 按行查看 历史
涛涛 提交于 2019-02-12 18:40 . 完成项目
from keras.layers import Input,LSTM,Dense
from keras.models import Model,load_model
from keras.utils import plot_model
import pandas as pd
import numpy as np
import re
import os
# os.environ["PATH"] += os.pathsep + 'C:/Program Files (x86)/Graphviz2.38/bin'
#我的graphviz环境没配好,为了后面的Plot_model
def create_model(n_input, n_output, n_units):
    """Build the seq2seq training model plus its encoder/decoder inference models.

    Args:
        n_input: width of each encoder time step (the one-hot size of the
            source character vocabulary).
        n_output: width of each decoder time step (the one-hot size of the
            target character vocabulary).
        n_units: number of units in both the encoder and the decoder LSTM.

    Returns:
        (model, encoder_infer, decoder_infer) where
        - model maps [encoder_seq, decoder_seq] to a per-step softmax over the
          target vocabulary (used for training);
        - encoder_infer maps encoder_seq to the encoder's final state [h, c];
        - decoder_infer maps [decoder_seq, h, c] to [softmax, h, c] and is run
          one step at a time during prediction.

    The encoder LSTM, the decoder LSTM and the output Dense layer are shared
    between the training model and the inference models, so fitting `model`
    also updates the weights used for inference.
    """
    # ---------------- training graph ----------------
    # Encoder: only its final state [h, c] is kept (return_state=True); the
    # first returned tensor (the last output) is discarded.
    enc_seq = Input(shape=(None, n_input))
    enc_lstm = LSTM(n_units, return_state=True)
    enc_states = enc_lstm(enc_seq)[1:]

    # Decoder: starts from the encoder's final state. It must return the
    # whole output sequence so every time step contributes to the loss.
    dec_seq = Input(shape=(None, n_output))
    dec_lstm = LSTM(n_units, return_sequences=True, return_state=True)
    dec_seq_out = dec_lstm(dec_seq, initial_state=enc_states)[0]
    dec_softmax = Dense(n_output, activation='softmax')
    model = Model([enc_seq, dec_seq], dec_softmax(dec_seq_out))

    # ---------------- inference graphs (reuse the layers above) ----------------
    encoder_infer = Model(enc_seq, enc_states)
    # h and c from the previous decoding step are fed back in explicitly.
    prev_states = [Input(shape=(n_units,)), Input(shape=(n_units,))]
    step_out, *step_states = dec_lstm(dec_seq, initial_state=prev_states)
    decoder_infer = Model([dec_seq] + prev_states, [dec_softmax(step_out)] + step_states)
    return model, encoder_infer, decoder_infer
# ---- hyper-parameters ----
N_UNITS = 256       # units in the encoder and decoder LSTMs
BATCH_SIZE = 128
EPOCH = 200         # training epochs per run
NUM_SAMPLES = 2048  # max sentence pairs read from data_path (an earlier setting was 10000)

# ---- run mode ----
english_to_chinese = False  # True: translate English -> Chinese; False: Chinese -> English
continue_train = True       # True: continue from the saved model; False: train from scratch
data_path = 'data/translate2048.txt'
# Read at most NUM_SAMPLES tab-separated sentence pairs (column 0 is English).
df = pd.read_table(data_path, header=None).iloc[:NUM_SAMPLES, :]
# Remove the punctuation characters listed in the pattern from every cell.
df.replace('[,.!?,。!?]','', regex=True, inplace=True)
# Lower-case the English column.
df[0] = df[0].apply(lambda x: x.lower())
if english_to_chinese:
    df.columns = ['inputs', 'targets']
else:
    df.columns = ['targets', 'inputs']
# Wrap every target in '\t' (start marker) and '\n' (end marker) for teacher forcing.
# https://www.cnblogs.com/ppqchina/p/6270242.html
df['targets'] = df['targets'].apply(lambda x: '\t' + x + '\n')
input_texts = df.inputs.values.tolist()
target_texts = df.targets.values.tolist()

# Character vocabularies. A single join is linear; summing the unique strings
# concatenated them pairwise, which is quadratic in the corpus size.
input_characters = sorted(set(''.join(input_texts)))
target_characters = sorted(set(''.join(target_texts)))

INPUT_LENGTH = max(len(text) for text in input_texts)
INUPT_LENGTH = INPUT_LENGTH  # legacy misspelled name, kept for compatibility
OUTPUT_LENGTH = max(len(text) for text in target_texts)
INPUT_FEATURE_LENGTH = len(input_characters)
OUTPUT_FEATURE_LENGTH = len(target_characters)

# Size the arrays by the rows actually read, not NUM_SAMPLES: if the file holds
# fewer rows, all-zero padding samples would otherwise end up in training and in
# the validation split.
num_loaded_samples = len(input_texts)
encoder_input = np.zeros((num_loaded_samples, INPUT_LENGTH, INPUT_FEATURE_LENGTH))
decoder_input = np.zeros((num_loaded_samples, OUTPUT_LENGTH, OUTPUT_FEATURE_LENGTH))
decoder_output = np.zeros((num_loaded_samples, OUTPUT_LENGTH, OUTPUT_FEATURE_LENGTH))

input_dict = {char: index for index, char in enumerate(input_characters)}
input_dict_reverse = dict(enumerate(input_characters))
target_dict = {char: index for index, char in enumerate(target_characters)}
target_dict_reverse = dict(enumerate(target_characters))

# One-hot encode the sequences (remaining time steps stay all-zero).
for seq_index, seq in enumerate(input_texts):
    for char_index, char in enumerate(seq):
        encoder_input[seq_index, char_index, input_dict[char]] = 1.0
for seq_index, seq in enumerate(target_texts):
    for char_index, char in enumerate(seq):
        decoder_input[seq_index, char_index, target_dict[char]] = 1.0
        if char_index > 0:
            # decoder_output is decoder_input shifted one step earlier: at step t
            # the decoder must predict the character it will see at step t+1.
            decoder_output[seq_index, char_index - 1, target_dict[char]] = 1.0

# Sanity check: decode the first sample back to text.
print(''.join([input_dict_reverse[np.argmax(i)] for i in encoder_input[0] if max(i) != 0]))
print(''.join([target_dict_reverse[np.argmax(i)] for i in decoder_output[0] if max(i) != 0]))
print(''.join([target_dict_reverse[np.argmax(i)] for i in decoder_input[0] if max(i) != 0]))
# Checkpoint paths, one set per translation direction.
if english_to_chinese:
    model_name = 'model/en2zh_model.h5'
    encoder_name = 'model/en2zh_encoder.h5'
    decoder_name = 'model/en2zh_decoder.h5'
else:
    model_name = 'model/zh2en_model.h5'
    encoder_name = 'model/zh2en_encoder.h5'
    decoder_name = 'model/zh2en_decoder.h5'

# Always build the graph with create_model so that encoder_infer/decoder_infer
# share their layers (and weights) with model_train. Loading the three .h5
# files separately produced three independent models: fitting the loaded
# training model never touched the inference models, so the sample
# translations and the re-saved encoder/decoder files kept their old weights.
model_train, encoder_infer, decoder_infer = create_model(INPUT_FEATURE_LENGTH, OUTPUT_FEATURE_LENGTH, N_UNITS)
if continue_train:  # resume: copy the previously trained weights into the shared layers
    if os.path.exists(model_name):
        model_train.set_weights(load_model(model_name).get_weights())
    else:
        # A fresh checkout has no checkpoint yet; start from scratch instead of crashing.
        print('no saved model at %s, training from scratch' % model_name)

# Inspect the model structure (plot_model needs a working graphviz install).
# plot_model(to_file='model.png',model=model_train,show_shapes=True)
# plot_model(to_file='encoder.png',model=encoder_infer,show_shapes=True)
# plot_model(to_file='decoder.png',model=decoder_infer,show_shapes=True)
model_train.compile(optimizer='RMSprop', loss='categorical_crossentropy')
model_train.summary()
encoder_infer.summary()
decoder_infer.summary()
def predict_chinese(source, encoder_inference, decoder_inference, n_steps, features,
                    *, char_to_index=None, index_to_char=None):
    r"""Greedily decode one source sequence with the inference models.

    Despite the name, this decodes into whatever target language the models
    were trained for (Chinese or English, depending on the training direction).

    Args:
        source: encoder input for one sample; callers pass
            ``encoder_input[i:i+1]`` so it keeps a leading batch axis of 1.
        encoder_inference: model whose ``predict(source)`` returns the final
            encoder state [h, c].
        decoder_inference: model whose ``predict([prev_char_one_hot, h, c])``
            returns (char_probabilities, h, c) for one step.
        n_steps: maximum number of decoding steps (characters generated).
        features: size of the target vocabulary, i.e. the one-hot width.
        char_to_index: target char -> index map; defaults to the module-level
            ``target_dict``.
        index_to_char: target index -> char map; defaults to the module-level
            ``target_dict_reverse``.

    Returns:
        The decoded string. The '\n' end-of-sequence marker is not included.
    """
    if char_to_index is None:
        char_to_index = target_dict
    if index_to_char is None:
        index_to_char = target_dict_reverse

    # Run the encoder once; its final [h, c] seeds the decoder.
    state = list(encoder_inference.predict(source))
    # The first decoder input is the '\t' start-of-sequence marker.
    predict_seq = np.zeros((1, 1, features))
    predict_seq[0, 0, char_to_index['\t']] = 1
    decoded = []
    # Feed each predicted character back in as the next input until the end
    # marker is predicted or n_steps is reached.
    for _ in range(n_steps):
        yhat, h, c = decoder_inference.predict([predict_seq] + state)
        # yhat is the Dense/softmax output, not the LSTM state h.
        char_index = np.argmax(yhat[0, -1, :])
        char = index_to_char[char_index]
        if char == '\n':  # end-of-sequence marker: stop without emitting it
            break
        decoded.append(char)
        state = [h, c]  # this step's state is the next step's initial state
        predict_seq = np.zeros((1, 1, features))
        predict_seq[0, 0, char_index] = 1
    return ''.join(decoded)
# Teacher forcing: decoder_input is the target sequence and decoder_output is
# the same sequence shifted one step earlier.
model_train.fit([encoder_input, decoder_input], decoder_output,
                batch_size=BATCH_SIZE, epochs=EPOCH, validation_split=0.2)
print('saving model.........')
# Saving fails when the target directory (model/) does not exist, which would
# throw away a whole training run on a fresh checkout; create it first.
for path in (model_name, encoder_name, decoder_name):
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
model_train.save(model_name)
encoder_infer.save(encoder_name)
decoder_infer.save(decoder_name)

# Translate a few samples as a sanity check; bounded by the sample count so a
# tiny dataset does not raise IndexError.
for i in range(1, min(10, len(input_texts))):
    test = encoder_input[i:i + 1, :, :]  # i:i+1 keeps the array 3-D (batch of one)
    out = predict_chinese(test, encoder_infer, decoder_infer, OUTPUT_LENGTH, OUTPUT_FEATURE_LENGTH)
    print(input_texts[i])
    print(out)
1
https://gitee.com/ling_jun_zuo/seq2seq_keras_translate.git
git@gitee.com:ling_jun_zuo/seq2seq_keras_translate.git
ling_jun_zuo
seq2seq_keras_translate
seq2seq_keras_translate
master

搜索帮助