1 Star 1 Fork 0

Tim / mandarin-tts

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
loss.py 1.80 KB
一键复制 编辑 原始数据 按行查看 历史
ranch 提交于 2021-03-18 01:43 . bug fixed for 儿话音
import torch
import torch.nn as nn
import hparams as hp
from ipdb import set_trace
class FastSpeech2Loss(nn.Module):
    """FastSpeech2 training loss: masked mel, post-net mel and log-duration terms.

    Only three terms are computed; pitch and energy losses are not part of
    this module.
    """

    def __init__(self, mel_weight=0.1, postnet_weight=1.0, duration_weight=0.01):
        """Create the loss module.

        Args:
            mel_weight: Multiplier applied to the MSE between ``mel`` and the
                target mel (default 0.1).
            postnet_weight: Multiplier applied to the MSE between
                ``mel_postnet`` and the target mel (default 1.0).
            duration_weight: Multiplier applied to the L1 loss between the
                predicted and target log-durations (default 0.01).
        """
        super(FastSpeech2Loss, self).__init__()
        self.mse_loss = nn.MSELoss()
        self.mae_loss = nn.L1Loss()
        self.mel_weight = mel_weight
        self.postnet_weight = postnet_weight
        self.duration_weight = duration_weight

    def forward(self, log_d_predicted, log_d_target, mel, mel_postnet, mel_target, src_mask, mel_mask):
        """Compute the weighted, masked mel, post-net mel and duration losses.

        Only positions where a mask is True contribute: ``masked_select``
        keeps the True entries and flattens them, so each loss is a mean over
        the selected elements only.

        Args:
            log_d_predicted: Predicted log-durations; must be broadcastable
                with ``src_mask``.
            log_d_target: Target log-durations; treated as a constant
                (no gradient flows into it).
            mel: Decoder mel output before the post-net.
            mel_postnet: Mel output after the post-net.
            mel_target: Ground-truth mel; treated as a constant.
            src_mask: Boolean mask selecting the valid source positions.
            mel_mask: Boolean mask over mel frames. It is unsqueezed on the
                last axis so it broadcasts over the last dimension of the mel
                tensors (presumably (batch, frames, n_mels) -- confirm with
                the caller).

        Returns:
            Tuple ``(mel_loss, mel_postnet_loss, d_loss)`` of scalar tensors,
            each already multiplied by its configured weight.

        Raises:
            RuntimeError: If a mask cannot be broadcast against the tensor it
                selects from (raised by ``torch.masked_select``).
        """
        # Targets must not receive gradients. detach() achieves that without
        # mutating the caller's tensors, and it also works for non-leaf
        # tensors, where assigning requires_grad = False would raise.
        log_d_target = log_d_target.detach()
        mel_target = mel_target.detach()

        log_d_predicted = log_d_predicted.masked_select(src_mask)
        log_d_target = log_d_target.masked_select(src_mask)

        # Add a trailing axis so the per-frame mask broadcasts over the last
        # mel dimension.
        frame_mask = mel_mask.unsqueeze(-1)
        mel = mel.masked_select(frame_mask)
        mel_postnet = mel_postnet.masked_select(frame_mask)
        mel_target = mel_target.masked_select(frame_mask)

        mel_loss = self.mse_loss(mel, mel_target) * self.mel_weight
        mel_postnet_loss = self.mse_loss(mel_postnet, mel_target) * self.postnet_weight
        d_loss = self.mae_loss(log_d_predicted, log_d_target) * self.duration_weight
        return mel_loss, mel_postnet_loss, d_loss
Python
1
https://gitee.com/tuxg/mandarin-tts.git
git@gitee.com:tuxg/mandarin-tts.git
tuxg
mandarin-tts
mandarin-tts
master

搜索帮助