1 Star 1 Fork 0

左令君 / Machine-Translation

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
data_gen.py 3.13 KB
一键复制 编辑 原始数据 按行查看 历史
foamliu 提交于 2018-10-11 16:15 . update
# encoding=utf-8
import itertools
import numpy as np
from torch.utils.data import Dataset
from config import *
# Location of the JSON file the samples are read from.
samples_path = 'data/samples_train.json'
# Loaded once at import time. The context manager closes the file handle
# deterministically instead of leaving it to the garbage collector.
with open(samples_path, 'r') as _samples_file:
    samples = json.load(_samples_file)
# np.random.shuffle(samples)
def zeroPadding(l, fillvalue=PAD_token):
    """Transpose a batch of sequences, padding the short ones.

    ``l`` is an iterable of sequences. The result is a list of tuples where
    the k-th tuple holds the k-th element of every sequence; a sequence that
    is shorter than the longest one contributes ``fillvalue`` in its place.
    """
    columns = itertools.zip_longest(*l, fillvalue=fillvalue)
    return list(columns)
def binaryMatrix(l):
    """Build a 0/1 mask with the same nested layout as ``l``.

    Every element equal to ``PAD_token`` becomes 0 and every other element
    becomes 1. Returns a list of lists, one inner list per item of ``l``.
    """
    return [[0 if token == PAD_token else 1 for token in seq] for seq in l]
def inputVar(indexes_batch):
    """Return the padded input tensor and the original sequence lengths.

    ``indexes_batch`` is a list of index sequences. They are padded and
    transposed by ``zeroPadding``, so the resulting ``LongTensor`` has one
    row per position and one column per sequence. ``lengths`` is a tensor
    holding the unpadded length of each sequence, in batch order.
    """
    seq_lengths = [len(indexes) for indexes in indexes_batch]
    lengths = torch.tensor(seq_lengths)
    padVar = torch.LongTensor(zeroPadding(indexes_batch))
    return padVar, lengths
def outputVar(indexes_batch):
    """Return the padded target tensor, its padding mask and the max length.

    ``indexes_batch`` is a list of index sequences. The returned tuple is
    ``(padVar, mask, max_target_len)``:

    * ``padVar`` -- ``LongTensor`` of the padded, transposed sequences.
    * ``mask`` -- ``ByteTensor`` of the same layout, 0 where ``padVar``
      holds ``PAD_token`` and 1 elsewhere.
    * ``max_target_len`` -- length of the longest sequence in the batch.
    """
    max_target_len = max(map(len, indexes_batch))
    padded = zeroPadding(indexes_batch)
    padVar = torch.LongTensor(padded)
    mask = torch.ByteTensor(binaryMatrix(padded))
    return padVar, mask, max_target_len
def batch2TrainData(pair_batch):
    """Turn a list of (input, output) pairs into padded training tensors.

    ``pair_batch`` is sorted IN PLACE by input length, longest first, so the
    caller's list is reordered as a side effect.
    NOTE(review): the descending order is presumably required by whatever
    consumes ``lengths`` downstream -- confirm against the training loop.

    Returns ``(inp, lengths, output, mask, max_target_len)``: the first two
    items come from ``inputVar`` and the last three from ``outputVar``.
    """
    pair_batch.sort(key=lambda pair: len(pair[0]), reverse=True)
    sources = [pair[0] for pair in pair_batch]
    targets = [pair[1] for pair in pair_batch]
    # Flatten both helper results into the single 5-tuple callers unpack.
    return (*inputVar(sources), *outputVar(targets))
class TranslationDataset(Dataset):
    """Serve the module-level ``samples`` as fixed-size, pre-batched chunks.

    The first ``train_split`` fraction of ``samples`` forms the ``'train'``
    split and the remainder forms the ``'valid'`` split. One dataset item is
    a whole chunk of ``chunk_size`` samples already converted to tensors by
    ``batch2TrainData``; trailing samples that do not fill a chunk are not
    counted in ``len()``.
    """

    def __init__(self, split):
        """Select and shuffle the samples belonging to ``split``.

        ``split`` must be ``'train'`` or ``'valid'``.
        """
        self.split = split
        assert self.split in {'train', 'valid'}

        print('loading {} samples'.format(split))

        boundary = int(len(samples) * train_split)
        # Slicing copies the list, so the shuffle below leaves the
        # module-level ``samples`` untouched.
        self.samples = samples[:boundary] if split == 'train' else samples[boundary:]

        self.num_chunks = len(self.samples) // chunk_size

        np.random.shuffle(self.samples)

        print('count: ' + str(len(self.samples)))

    def __getitem__(self, i):
        """Return the tensors for the ``i``-th chunk of ``chunk_size`` samples."""
        start_idx = i * chunk_size
        # Index one sample at a time so that running past the end still
        # raises IndexError rather than yielding a short chunk.
        chunk = (self.samples[pos] for pos in range(start_idx, start_idx + chunk_size))
        pair_batch = [(sample['input'], sample['output']) for sample in chunk]
        return batch2TrainData(pair_batch)

    def __len__(self):
        """Return the number of complete chunks in this split."""
        return self.num_chunks
if __name__ == '__main__':
    # Smoke check: batch the first few samples and print every tensor that
    # batch2TrainData produces.
    print('loading {} samples'.format('train'))
    samples_path = 'data/samples_train.json'
    # Context manager so the file handle is closed as soon as parsing ends.
    with open(samples_path, 'r') as _samples_file:
        samples = json.load(_samples_file)

    # Example for validation
    small_batch_size = 5
    # Slicing (instead of indexing 0..4) keeps the demo working when the
    # file holds fewer than ``small_batch_size`` samples.
    pair_batch = [(sample['input'], sample['output'])
                  for sample in samples[:small_batch_size]]

    batches = batch2TrainData(pair_batch)
    input_variable, lengths, target_variable, mask, max_target_len = batches

    print("input_variable:", input_variable)
    print("lengths:", lengths)
    print("target_variable:", target_variable)
    print("mask:", mask)
    print("max_target_len:", max_target_len)
1
https://gitee.com/ling_jun_zuo/Machine-Translation.git
git@gitee.com:ling_jun_zuo/Machine-Translation.git
ling_jun_zuo
Machine-Translation
Machine-Translation
master

搜索帮助