1 Star 4 Fork 1

Tim / 基于循环神经网络(RNN)的智能聊天机器人系统

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
test.py 3.75 KB
一键复制 编辑 原始数据 按行查看 历史
hhhvvvddd 提交于 2019-04-28 23:48 . Add files via upload
import tensorflow as tf
import numpy as np
import os
from six.moves import xrange

# How many past utterances are kept as conversational context.
convo_hist_limit = 1
max_source_length = 0
max_target_length = 0

flags = tf.app.flags
FLAGS = flags.FLAGS

import datautil
import seq2seq_model

tf.reset_default_graph()

# --- model / decoding hyper-parameters (must match the trained checkpoint) ---
max_train_data_size = 0  # (0: no limit)
dropout = 1.0
grad_clip = 5.0
batch_size = 60
num_layers = 2
learning_rate = 0.5
lr_decay_factor = 0.99
# NOTE(review): the original assigned hidden_size twice (14, then 100); the
# dead first assignment has been removed — 100 is the effective value.
hidden_size = 100
checkpoint_dir = "datacn/checkpoints/"
# Bucketing: (encoder_length, decoder_length) pairs used by the seq2seq model.
# (A dead "_buckets = []" that was immediately shadowed has been removed.)
_buckets = [(5, 5), (10, 10), (20, 20)]
def getdialogInfo():
    """Load the Chinese vocabulary and sanity-check the id-encoded corpus dirs.

    Returns:
        (from_vocab_size, to_vocab_size, vocab, rev_vocab). The chatbot uses a
        single shared vocabulary for both sides, so the two sizes are equal.
    """
    vocabch, rev_vocabch = datautil.initialize_vocabulary(
        os.path.join(datautil.data_dir, datautil.vocabulary_filech))
    vocab_sizech = len(vocabch)
    print("vocab_sizech",vocab_sizech)
    # Listing the id-file directories fails early if the prepared data is
    # missing; the individual file paths themselves are not needed at
    # inference time (the original bound them to unused locals).
    datautil.getRawFileList(datautil.data_dir + "fromids/")
    datautil.getRawFileList(datautil.data_dir + "toids/")
    return vocab_sizech, vocab_sizech, vocabch, rev_vocabch
def main():
    """Interactive chat loop: restore the model, then decode user input turn by turn."""
    vocab_sizeen, vocab_sizech, vocaben, rev_vocabch = getdialogInfo()

    if not os.path.exists(checkpoint_dir):
        # makedirs creates the full nested "datacn/checkpoints/" path; the
        # original os.mkdir fails when the parent "datacn/" does not exist.
        os.makedirs(checkpoint_dir)
    print ("checkpoint_dir is {0}".format(checkpoint_dir))

    with tf.Session() as sess:
        model = createModel(sess, True, vocab_sizeen, vocab_sizech)
        print (_buckets)
        model.batch_size = 1  # decode a single sentence at a time

        conversation_history = []
        while True:
            prompt = "请输入: "
            sentence = input(prompt)
            conversation_history.append(sentence.strip())
            # Keep only the most recent turns as model context.
            conversation_history = conversation_history[-convo_hist_limit:]

            # Encoder input is fed reversed, as in the classic seq2seq recipe.
            token_ids = list(reversed(datautil.sentence_to_ids(
                " ".join(conversation_history), vocaben,
                normalize_digits=True, Isch=True)))
            print(token_ids)

            # Choose the smallest bucket that fits the input.  Fall back to
            # the largest bucket when the input exceeds every bucket — the
            # original min([...]) raised ValueError on an empty list there.
            candidates = [b for b in xrange(len(_buckets))
                          if _buckets[b][0] > len(token_ids)]
            bucket_id = candidates[0] if candidates else len(_buckets) - 1

            encoder_inputs, decoder_inputs, target_weights = model.get_batch(
                {bucket_id: [(token_ids, [])]}, bucket_id)
            _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                             target_weights, bucket_id, True)

            # Greedy decoding: take the argmax token at each step.
            # TODO implement beam search
            outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
            print("outputs",outputs,datautil.EOS_ID)
            if datautil.EOS_ID in outputs:
                # Truncate at the first end-of-sentence token.
                outputs = outputs[:outputs.index(datautil.EOS_ID)]
                convo_output = " ".join(datautil.ids2texts(outputs, rev_vocabch))
                conversation_history.append(convo_output)
                print (convo_output)
            else:
                print("can not translation!")
def createModel(session, forward_only, from_vocab_size, to_vocab_size):
    """Create translation model and initialize or load parameters in session.

    Args:
        session: active tf.Session used to restore or initialize variables.
        forward_only: True builds the inference-only graph (no backward pass).
        from_vocab_size: encoder-side vocabulary size.
        to_vocab_size: decoder-side vocabulary size.

    Returns:
        A ready-to-use seq2seq_model.Seq2SeqModel instance.
    """
    model = seq2seq_model.Seq2SeqModel(
        from_vocab_size,  # from
        to_vocab_size,    # to
        _buckets,
        hidden_size,
        num_layers,
        dropout,
        grad_clip,
        batch_size,
        learning_rate,
        lr_decay_factor,
        forward_only=forward_only,
        dtype=tf.float32)
    print("model is ok")

    # Restore the newest checkpoint if one exists; otherwise start fresh.
    ckpt = tf.train.latest_checkpoint(checkpoint_dir)
    if ckpt is not None:  # identity check, not the original "ckpt != None"
        model.saver.restore(session, ckpt)
        print ("Reading model parameters from {0}".format(ckpt))
    else:
        print ("Created model with fresh parameters.")
        session.run(tf.global_variables_initializer())
    return model
# Script entry point: launch the interactive chat loop.
if __name__ == "__main__":
    main()
Python
1
https://gitee.com/tuxg/rnn.git
git@gitee.com:tuxg/rnn.git
tuxg
rnn
基于循环神经网络(RNN)的智能聊天机器人系统
master

搜索帮助