"""Pre-processing for the zh-en translation dataset: build vocabularies, extract validation text from SGM files, and encode training/validation samples."""
import json
import os
import xml.etree.ElementTree
from collections import Counter
import jieba
import nltk
from tqdm import tqdm
from config import output_lang_vocab_size, input_lang_vocab_size, max_len, UNK_token
from config import train_translation_folder, train_translation_zh_filename, train_translation_en_filename
from config import valid_translation_folder, valid_translation_zh_filename, valid_translation_en_filename
from utils import normalizeString, encode_text
def build_wordmap_zh():
    """Build the Chinese (output-language) vocabulary.

    Segments every training sentence with jieba, keeps the
    ``output_lang_vocab_size - 4`` most frequent words (4 ids are reserved
    for the special tokens) and writes the word->id map to
    ``data/WORDMAP_zh.json``.
    """
    translation_path = os.path.join(train_translation_folder, train_translation_zh_filename)
    # Explicit utf-8: the corpus is Chinese text and the platform default
    # encoding may not be utf-8 (e.g. on Windows).
    with open(translation_path, 'r', encoding='utf-8') as f:
        sentences = f.readlines()

    word_freq = Counter()
    for sentence in tqdm(sentences):
        # jieba.cut returns a generator of segmented words; Counter.update
        # consumes any iterable directly.
        word_freq.update(jieba.cut(sentence.strip()))

    # Keep the most common words; ids 0-3 are reserved for special tokens.
    words = word_freq.most_common(output_lang_vocab_size - 4)
    word_map = {word: idx + 4 for idx, (word, _count) in enumerate(words)}
    word_map['<pad>'] = 0
    word_map['<start>'] = 1
    word_map['<end>'] = 2
    word_map['<unk>'] = 3
    print(len(word_map))
    print(words[:10])

    with open('data/WORDMAP_zh.json', 'w') as file:
        json.dump(word_map, file, indent=4)
def build_wordmap_en():
    """Build the English (input-language) vocabulary.

    Lower-cases and tokenizes every training sentence with nltk, keeps the
    ``input_lang_vocab_size - 4`` most frequent normalized tokens (4 ids are
    reserved for the special tokens) and writes the word->id map to
    ``data/WORDMAP_en.json``.
    """
    translation_path = os.path.join(train_translation_folder, train_translation_en_filename)
    with open(translation_path, 'r', encoding='utf-8') as f:
        sentences = f.readlines()

    word_freq = Counter()
    for sentence in tqdm(sentences):
        sentence_en = sentence.strip().lower()
        # Normalize each token once (the original called normalizeString
        # twice per token) and drop tokens that normalize to empty.
        normalized = (normalizeString(s) for s in nltk.word_tokenize(sentence_en))
        word_freq.update(tok for tok in normalized if tok)

    # Keep the most common words; ids 0-3 are reserved for special tokens.
    words = word_freq.most_common(input_lang_vocab_size - 4)
    word_map = {word: idx + 4 for idx, (word, _count) in enumerate(words)}
    word_map['<pad>'] = 0
    word_map['<start>'] = 1
    word_map['<end>'] = 2
    word_map['<unk>'] = 3
    print(len(word_map))
    print(words[:10])

    with open('data/WORDMAP_en.json', 'w') as file:
        json.dump(word_map, file, indent=4)
def extract_valid_data():
    """Extract plain-text validation sentences from the SGM files.

    Escapes bare ampersands in the English SGM so it parses as XML, then
    pulls the text of every ``<seg>`` element from both language files and
    writes them as one-sentence-per-line ``valid.en`` / ``valid.zh`` files.
    """
    valid_translation_path = os.path.join(valid_translation_folder, 'valid.en-zh.en.sgm')
    with open(valid_translation_path, 'r', encoding='utf-8') as f:
        data_en = f.readlines()
    # Raw '&' is not legal XML; escape it so ElementTree can parse the file.
    # (The original line replaced ' & ' with itself — a no-op; the intent
    # was clearly to escape to '&amp;'.)
    data_en = [line.replace(' & ', ' &amp; ') for line in data_en]
    with open(valid_translation_path, 'w', encoding='utf-8') as f:
        f.writelines(data_en)

    root = xml.etree.ElementTree.parse(valid_translation_path).getroot()
    data_en = [elem.text.strip() for elem in root.iter() if elem.tag == 'seg']
    with open(os.path.join(valid_translation_folder, 'valid.en'), 'w', encoding='utf-8') as out_file:
        out_file.write('\n'.join(data_en) + '\n')

    root = xml.etree.ElementTree.parse(os.path.join(valid_translation_folder, 'valid.en-zh.zh.sgm')).getroot()
    data_zh = [elem.text.strip() for elem in root.iter() if elem.tag == 'seg']
    with open(os.path.join(valid_translation_folder, 'valid.zh'), 'w', encoding='utf-8') as out_file:
        out_file.write('\n'.join(data_zh) + '\n')
def build_samples():
    """Encode parallel zh/en sentences into id sequences.

    For both the train and valid splits, segments/tokenizes each sentence
    pair, encodes it with the previously built word maps, keeps only pairs
    where both sides fit in ``max_len`` and contain no ``<unk>`` token, and
    writes the result to ``data/samples_{train,valid}.json``.
    """
    # Use context managers so the JSON file handles are closed promptly
    # (the original json.load(open(...)) leaked them).
    with open('data/WORDMAP_zh.json', 'r', encoding='utf-8') as f:
        word_map_zh = json.load(f)
    with open('data/WORDMAP_en.json', 'r', encoding='utf-8') as f:
        word_map_en = json.load(f)

    for usage in ['train', 'valid']:
        if usage == 'train':
            translation_path_en = os.path.join(train_translation_folder, train_translation_en_filename)
            translation_path_zh = os.path.join(train_translation_folder, train_translation_zh_filename)
            filename = 'data/samples_train.json'
        else:
            translation_path_en = os.path.join(valid_translation_folder, valid_translation_en_filename)
            translation_path_zh = os.path.join(valid_translation_folder, valid_translation_zh_filename)
            filename = 'data/samples_valid.json'

        print('loading {} texts and vocab'.format(usage))
        with open(translation_path_en, 'r', encoding='utf-8') as f:
            data_en = f.readlines()
        with open(translation_path_zh, 'r', encoding='utf-8') as f:
            data_zh = f.readlines()

        print('building {} samples'.format(usage))
        samples = []
        # The two files are parallel corpora: line i of each is one pair.
        for line_en, line_zh in tqdm(zip(data_en, data_zh), total=len(data_en)):
            sentence_zh = line_zh.strip()
            input_zh = encode_text(word_map_zh, list(jieba.cut(sentence_zh)))

            sentence_en = line_en.strip().lower()
            # Normalize each token once and drop empty normalizations.
            normalized = (normalizeString(s) for s in nltk.word_tokenize(sentence_en))
            tokens = [tok for tok in normalized if tok]
            output_en = encode_text(word_map_en, tokens)

            # Keep only pairs that fit the model's max length and contain
            # no out-of-vocabulary token on either side.
            if (len(input_zh) <= max_len and len(output_en) <= max_len
                    and UNK_token not in input_zh and UNK_token not in output_en):
                samples.append({'input': list(input_zh), 'output': list(output_en)})

        with open(filename, 'w') as f:
            json.dump(samples, f, indent=4)
        print('{} {} samples created at: {}.'.format(len(samples), usage, filename))
def main():
    """Run the full preprocessing pipeline in dependency order."""
    # Word maps must exist before build_samples() can encode sentences.
    build_wordmap_zh()
    build_wordmap_en()
    extract_valid_data()
    build_samples()


if __name__ == '__main__':
    main()