1 Star 0 Fork 149

zhouhao / AscendSpeed

forked from Ascend / MindSpeed 
加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
swiglu.md 1.07 KB
一键复制 编辑 原始数据 按行查看 历史
dytc 提交于 2024-03-25 06:21 . !102 rms_norm增加适配

swiglu对外接口

前向接口:

输入:

  • x:必选输入,数据类型float16, bfloat16, float32

输出:

  • y:必选输出,数据类型float16, bfloat16, float32

属性:

  • dim:可选属性,数据类型int32_t,缺省-1。

反向接口:

输入:

  • dy:必选输入,数据类型float16, bfloat16, float32
  • x:必选输入,数据类型float16, bfloat16, float32

输出

  • dx:必选输出,数据类型float16, bfloat16, float32

属性:

  • dim:可选属性,数据类型int32_t,缺省-1。

案例

    import os
    import torch
    import torch_npu
    import numpy as np
    import math
    from ascendspeed.op_builder import SwigluOpBuilder

    x = np.random.uniform(-2, 2, (8192,1,3904))
    x = torch.from_numpy(x).float().npu()
    y_grad = np.random.uniform(-2, 2, (8192,1,1952))
    y_grad = torch.from_numpy(y_grad).float().npu()

    x.requires_grad = True
    // 正向接口案例
    ascendspeed_ops = SwigluOpBuilder.load()
    result = ascendspeed_ops.swiglu(x, dim=-1)
    // 反向接口案例
    result.backward(y_grad)
1
https://gitee.com/tus-aliez/AscendSpeed.git
git@gitee.com:tus-aliez/AscendSpeed.git
tus-aliez
AscendSpeed
AscendSpeed
master

搜索帮助