1 Star 1 Fork 0

Tim / mandarin-tts

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
optimizer.py 1.01 KB
一键复制 编辑 原始数据 按行查看 历史
ranchlai_insta360 提交于 2021-02-19 13:41 . init
import numpy as np
class ScheduledOptim():
''' A simple wrapper class for learning rate scheduling '''
def __init__(self, optimizer, d_model, n_warmup_steps, current_steps):
self._optimizer = optimizer
self.n_warmup_steps = n_warmup_steps
self.n_current_steps = current_steps
self.init_lr = np.power(d_model, -0.5)
print('init lr',self.init_lr)
def step_and_update_lr(self):
self._update_learning_rate()
self._optimizer.step()
def zero_grad(self):
# print(self.init_lr)
self._optimizer.zero_grad()
def _get_lr_scale(self):
return np.min([
np.power(self.n_current_steps, -0.5),
np.power(self.n_warmup_steps, -1.5) * self.n_current_steps])
def _update_learning_rate(self):
''' Learning rate scheduling per step '''
self.n_current_steps += 1
lr = self.init_lr * self._get_lr_scale()
for param_group in self._optimizer.param_groups:
param_group['lr'] = lr
Python
1
https://gitee.com/tuxg/mandarin-tts.git
git@gitee.com:tuxg/mandarin-tts.git
tuxg
mandarin-tts
mandarin-tts
master

搜索帮助