425 Star 4.3K Fork 423

GVPPaddlePaddle / Paddle

 / 详情

paddle.nn.TransformerDecoder组网开启amp无法运行

待办的
创建于  
2023-07-25 16:18

请提出你的问题 Please ask your question

cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss() # 定义损失计算函数
model = SimpleNet(emb_dim=512, 
                            n_head=2, 
                            num_encoder_layers=2, 
                            num_decoder_layers=2, 
                            dim_feedforward=1024,
                            max_len=1024)  # 定义 SimpleNet 模型
optimizer = paddle.optimizer.SGD(learning_rate=0.0001, parameters=model.parameters())  # 定义 SGD 优化器

# 逻辑 2:可选,定义 GradScaler,用于缩放 loss 比例,避免浮点数溢出,默认开启动态更新 loss_scaling 机制
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

train_time = 0 # 记录总训练时长
for epoch in range(epochs):
    for i, (scr, tgt, lbl) in enumerate(loader):
        start_time = time.time() # 记录开始训练时刻
        lbl._to(place) # 将 label 数据拷贝到 gpu
        # 逻辑 1:创建 AMP-O1 auto_cast 环境,开启自动混合精度训练,将 add 算子添加到自定义白名单中(custom_white_list),
        # 因此前向计算过程中该算子将采用 float16 数据类型计算
        with paddle.amp.auto_cast(level='O1'):
            pre = model(scr, tgt) # 前向计算(9 层 Linear 网络,每层由 matmul、add 算子组成)
            loss = cross_entropy_loss(pre, lbl) # loss 计算
        # 逻辑 2:使用 GradScaler 完成 loss 的缩放,用缩放后的 loss 进行反向传播
        scaled = scaler.scale(loss) # loss 缩放,乘以系数 loss_scaling
        scaled.backward()           # 反向传播
        scaler.step(optimizer)      # 更新参数(参数梯度先除系数 loss_scaling 再更新参数)
        scaler.update()             # 基于动态 loss_scaling 策略更新 loss_scaling 系数
        optimizer.clear_grad(set_to_zero=False)
        # 记录训练 loss 及训练时长
        train_loss = loss.numpy()
        train_time += time.time() - start_time

print("loss:", train_loss)
print("使用 AMP-O1 模式耗时:{:.3f} sec".format(train_time/(epochs*nums_batch)))

评论 (0)

Zero 创建了任务

登录 后才可以发表评论

状态
负责人
里程碑
Pull Requests
关联的 Pull Requests 被合并后可能会关闭此 issue
分支
开始日期   -   截止日期
-
置顶选项
优先级
参与者(1)
Python
1
https://gitee.com/paddlepaddle/Paddle.git
git@gitee.com:paddlepaddle/Paddle.git
paddlepaddle
Paddle
Paddle

搜索帮助