# NOTE(review): web-scrape artifact from the hosting page ("code pull complete, page will auto-refresh") — not part of the program; converted to a comment so the file parses.
# import the necessary packages
from models import EncoderRNN, LuongAttnDecoderRNN
from utils import *
if __name__ == '__main__':
    # Evaluation script: load the best seq2seq checkpoint and translate a few
    # validation sentences (source language: zh, target language: en).

    # Load the source and target vocabularies from their word maps.
    input_lang = Lang('data/WORDMAP_zh.json')
    output_lang = Lang('data/WORDMAP_en.json')
    print("input_lang.n_words: " + str(input_lang.n_words))
    print("output_lang.n_words: " + str(output_lang.n_words))

    checkpoint_path = '{}/BEST_checkpoint.tar'.format(save_dir)  # model checkpoint
    print('checkpoint: ' + str(checkpoint_path))

    # Load model. map_location=device lets a GPU-trained checkpoint be
    # restored on a CPU-only machine (and vice versa).
    checkpoint = torch.load(checkpoint_path, map_location=device)
    encoder_sd = checkpoint['en']  # encoder state dict
    decoder_sd = checkpoint['de']  # decoder state dict

    print('Building encoder and decoder ...')
    # Initialize encoder & decoder models, then restore the trained weights.
    encoder = EncoderRNN(input_lang.n_words, hidden_size, encoder_n_layers, dropout)
    decoder = LuongAttnDecoderRNN(attn_model, hidden_size, output_lang.n_words,
                                  decoder_n_layers, dropout)
    encoder.load_state_dict(encoder_sd)
    decoder.load_state_dict(decoder_sd)

    # Use appropriate device
    encoder = encoder.to(device)
    decoder = decoder.to(device)
    print('Models built and ready to go!')

    # Set dropout layers to eval mode
    encoder.eval()
    decoder.eval()

    # Initialize search module
    searcher = GreedySearchDecoder(encoder, decoder)

    # Inference only: disable autograd to save memory and compute.
    with torch.no_grad():
        for input_sentence, target_sentence in pick_n_valid_sentences(input_lang, output_lang, 10):
            decoded_words = evaluate(searcher, input_sentence, input_lang, output_lang)
            print('> {}'.format(input_sentence))   # source sentence
            print('= {}'.format(target_sentence))  # reference translation
            print('< {}'.format(' '.join(decoded_words)))  # model output
# NOTE(review): hosting-platform moderation boilerplate scraped along with the page
# ("content may be unsuitable for display ... you may appeal") — not part of the
# program; converted to a comment so the file parses.