1 Star 1 Fork 137

walker / AscendSpeed

forked from Ascend / MindSpeed 
加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
rms_norm.md 1.17 KB
一键复制 编辑 原始数据 按行查看 历史
dytc 提交于 2024-03-25 06:21 . !102 rms_norm增加适配

rms_norm对外接口

前向接口:

输入:

  • x:必选输入,数据类型float16, bfloat16, float32
  • gamma:必选输入,数据类型float16, bfloat16, float32 输出
  • rstd:中间变量,数据类型float32
  • out:必选输出,数据类型float16, bfloat16, float32

属性:

  • epsilon:可选属性,数据类型float,缺省1e-6。

反向接口:

输入:

  • dy:必选输入,数据类型float16, bfloat16, float32
  • x:必选输入,数据类型float16, bfloat16, float32
  • rstd:必选输入,数据类型float32
  • gamma:必选输入,数据类型float16, bfloat16, float32 输出
  • dx:必选输出,数据类型float16, bfloat16, float32
  • dgammma:必选输出,数据类型float32

案例

    import os
    import torch
    import torch_npu
    import numpy as np
    import math
	from ascendspeed.op_builder import RmsNormOpBuilder
    
    x = torch.rand([2,4]).to(torch.float).npu()
    gamma = torch.rand([4]).to(torch.float).npu()
    // 正向接口案例
    ascendspeed_ops = RmsNormOpBuilder.load()
    out = ascendspeed_ops.rms_norm(x, gamma) 
    // 反向接口案例
    out.backward(torch.ones(out.shape).npu())
1
https://gitee.com/walker963/AscendSpeed.git
git@gitee.com:walker963/AscendSpeed.git
walker963
AscendSpeed
AscendSpeed
master

搜索帮助