2.3K Star 8K Fork 4.2K

GVPMindSpore / mindspore

 / 详情

图算expander支持复数计算及图层流程打通(基于GPU)

DONE
RFC
创建于  
2021-06-18 08:59

一、问题背景

当前的Mindspore框架基本都是实数域的相关应用,缺乏对复数域的支持,而引入复数计算,可以拓宽更多的业务场景,将框架的表达能力进一步提升。而对复数计算支持后,需要保证性能尽可能的优,因此采取图算融合的方式,寻求性能的更优解,所以需要对于一些经典的复数算子在图算侧进行实现并在图层中,流程完整打通(暂时以GPU为主)

二、场景分析

当前仅考虑最简单的场景打通的方式
若只考虑最简单的场景,那么实际上第一个版本仅需要将最基础的加减乘除单算子,进行图层的流程打通。下面我们以add为例进行更进一步的分析:
以调研过tensorflow对复数的支持为参考,它们并没有单独的对复数有专门的complex_add的接口来接收两个复数输入,而是将tf.add的功能做到完备,不仅支持实数域的计算,也同时支持复数域的计算,那么显然mindspore侧也无需要多添加若干个专门的复数接口进行定义,照常使用以后的ms.add即可,只不过需要我们图算侧进行感知适配。而且第一个版本当前只考虑最基础的输入输出均为复数类型的形式,一些虽然常见,但是不是最基础的场景,后面的未来改进可能会有提及。

三、基于最简单场景下的代码实现思路

I.Python侧
py侧的实现相对简单,只需要将一些基本复数类算子的逻辑梳理清楚,AKG侧支撑了这些基本的复数类算子原语实现(已经支持),就可以成功expander展开往下走。

II.C++侧
1.C++侧当前梳理,大概需要在graph_expander.h里创建一个complex_expander类,该类继承于基础的expander类(暂时不想再单独开一个文件来写,有需要也能拆,欢迎给建议)。
2.其次需要定义一些复数类可展开的列表,到时候对复数类进行doexpand的时候,查表,在表内的展开即可。
3.doexpand也得重写,基本的逻辑不发生具体改变,但是在符合原有doexpand条件的基础上,还需要特别判定一下,是否包含复数类输入,因此需要再加一个判断,也即拿到当前node的所有的input_type,若存在一个complex类型(当前没有complex类型,因此用float64代替),则进行正常的expander展开。
4.正如3所说,当前暂时没有complex类型支持,用float64类型代替的话,需要整体修改输入输出的float64类型为complex64让展开后生成的JSON能够正常进行复数编译。进一步分析应该还得在若干文件中加入float64, complex64的类型,以让后面的流程能够识别。
C++侧编写属实不熟,也可能当前最基础的版本依然设计的不是很好,麻烦各位多提意见,多指点一下我,谢谢!

四、未来的改进可能

1.正如前文提到,当前打通流程,都是按照最简单的形式来打通的,比较理想化的想象成输入输出均为complex类型的算子,且暂时只需要实现最基本的加减乘除算子,未来还得去实现FFT之类,只会更加复杂。
2.如果考虑更加泛化的场景,肯定会遇到两个输入,一个是complex类型,而另一个是正常float类型的情况,这种情况,当前的流程处理(仅仅是把所有输入输出改为complex,或者未来已经支持complex表达,单纯的判断输入是不是包含complex就做expander展开)下来肯定还是无法打通的,德仕是建议未来对于这种场景,识别到后,将float类型的输入的value取出,成为实部,再将虚部写成0,重新将该输入包一层,转化为complex类型,继续往下处理。
3.未完待续,大家也可以思考一下未来会遇到什么场景,需要进一步处理。

五、基于打通后的优化:CReal, CImag,Complex的消除

在最基本的复数单算子运算打通后,若在前端构造简易的算子网络---->多个简单算子构成。那么经过图算的cluster的pass处理后,这些算子将会聚合在一起。但是这种简单的聚合会造成非常多的冗余,如下简图所示:
输入图片说明
从截取的一小段IR上来看,也确实如此(重复取Para 55和56的实部显然是冗余的):

 %0([CNode]29) = CReal(%para3_[Parameter]55) primitive_attrs: {IsFeatureMapInputList: (0), IsFeatureMapOutput: true}
      : (<Tensor[Float64]x[const vector][4, 3]>) -> (<Tensor[Float32]x[const vector][4, 3]>)
      : (<Complex64xDefaultFormat[const vector][4, 3]>) -> (<Float32xDefaultFormat[const vector][4, 3]>)
      : (Default/CReal-op21)
  %1([CNode]31) = CReal(%para4_[Parameter]56) primitive_attrs: {IsFeatureMapInputList: (0), IsFeatureMapOutput: true}
      : (<Tensor[Float64]x[const vector][4, 3]>) -> (<Tensor[Float32]x[const vector][4, 3]>)
      : (<Complex64xDefaultFormat[const vector][4, 3]>) -> (<Float32xDefaultFormat[const vector][4, 3]>)
      : (Default/CReal-op22)
 %8([CNode]40) = CReal(%para3_[Parameter]55) primitive_attrs: {IsFeatureMapInputList: (0), IsFeatureMapOutput: true}
      : (<Tensor[Float64]x[const vector][4, 3]>) -> (<Tensor[Float32]x[const vector][4, 3]>)
      : (<Complex64xDefaultFormat[const vector][4, 3]>) -> (<Float32xDefaultFormat[const vector][4, 3]>)
      : (Default/CReal-op30)
  %9([CNode]42) = CReal(%para4_[Parameter]56) primitive_attrs: {IsFeatureMapInputList: (0), IsFeatureMapOutput: true}
      : (<Tensor[Float64]x[const vector][4, 3]>) -> (<Tensor[Float32]x[const vector][4, 3]>)
      : (<Complex64xDefaultFormat[const vector][4, 3]>) -> (<Float32xDefaultFormat[const vector][4, 3]>)
      : (Default/CReal-op31)

不过针对以上场景,CSE的PASS会将这类问题解决,问题可以规约为如雄哥评论所说的包含Complex + CReal, Complex + CImg的场景,如下所示:

%0([CNode]29) = CReal(%para3_[Parameter]55) 
%1([CNode]31) = CReal(%para4_[Parameter]56) 
%2([CNode]33) = Add(%0, %1) 
%5([CNode]36) = Add(%3, %4) 
%6([CNode]37) = Complex(%2, %5)
%7([CNode]20) = CReal(%6) 
%12([CNode]24) = Sub(%7, %11)

显然%6 %7和是可以消除的,因为%12里Sub用到的%7实际上就是%2,因此我们需要做的就是找到相关的Pattern,然后在图上重新连边即可做到消除的功能。化简后的场景如下图所示:
输入图片说明

不太清楚考虑的场景是否完善,不过对遍历图后具体相关操作,我还不是很熟练,需要参考代码,多讨论再来代码的编写。
简单讨论后分为以下四步来做消除:
1.遍历CNode节点,确认是否为graph_kernel node,如果是,获取到graph_kernel子图
2.基于获取到的graph_kernel子图,遍历节点找到所有Complex节点
3.遍历所有Complex对应的全部users,找到CReal, CImg的CNode
4.直接替换?Replace(CReal, complex->input[1]), Replace(CImag, complex->input[2])

评论 (12)

ZengZitao 创建了RFC
ZengZitao 关联仓库设置为MindSpore/mindspore
展开全部操作日志

Please add labels (comp or sig), also you can visit "https://gitee.com/mindspore/community/blob/master/sigs/dx/docs/labels.md" to find more.
为了让问题更快得到响应,请您为该issue打上**组件(comp)或兴趣组(sig)**标签,打上标签的问题可以直接推送给责任人进行处理。更多的标签可以查看 https://gitee.com/mindspore/community/blob/master/sigs/dx/docs/labels.md"
以组件问题为例,如果你发现问题是data组件造成的,你可以这样评论:
//comp/data
当然你也可以向data SIG组求助,可以这样写:
//comp/data
//sig/data
如果是一个简单的问题,你可以留给刚进入社区的小伙伴来回答,这时候你可以这样写:
//good-first-issue
恭喜你,你已经学会了使用命令来打标签,接下来就在下面的评论里打上标签吧!

mindspore-dx-bot 负责人设置为ZengZitao
mindspore-dx-bot 添加了kind/feature(已删除)标签
mindspore-dx-bot 添加了
 
stat/wait-response
标签
ZengZitao 修改了描述

//kind/feature
//comp/akg
//sig/akg
//stat/discuss-welcome
//stat/wait-response

mindspore-dx-bot 添加了comp/akg(已删除)标签

//kind/discuss-welcome

mindspore-dx-bot 添加了kind/discuss-welcome(已删除)标签
mindspore-dx-bot 添加了
 
sig/akg
标签
ZengZitao 修改了描述
ZengZitao 修改了描述
ZengZitao 修改了描述

Good job, it's a good practice. :+1:

//good-example-for-all

mindspore-dx-bot 添加了
 
good-example-for-all
标签
zhunaipan 置顶等级设置为

I will set to the top for you.

ZengZitao 任务状态TODO 修改为ACCEPTED
ZengZitao 移除了kind/discuss-welcome(已删除)标签
ZengZitao 修改了描述
ZengZitao 修改了描述
ZengZitao 修改了描述
ZengZitao 修改了描述
ZengZitao 修改了描述
ZengZitao 修改了描述
ZengZitao 修改了描述

关于CReal, CImag,Complex的消除:
1) CReal CImg的消除: 这个应该不需要单独做,因为后面的CSE应该可以直接处理的;
2) Complex的消除: 本质是消除 Complex->Real/Imag对。所以按照这个pattern的进行搜索可能更简单?

ZengZitao 修改了描述
ZengZitao 修改了描述

Complex + Real/Imag的消除,类似make_tuple+getitem消除,参考 GetitemTuple 写一个?

ZengZitao 修改了描述

Complex + Real/Imag的消除,类似make_tuple+getitem消除,参考 GetitemTuple 写一个?

这个后面也可以放到代数化简里面做。写两条规则就行了。

Real(Complex(A,B))=A
Imag(Complex(A,B))=B

@jiaoy1224

ZengZitao 任务状态ACCEPTED 修改为DONE
ZengZitao 置顶等级 修改为不置顶
ZengZitao 移除了kind/feature(已删除)标签
ZengZitao 移除了comp/akg(已删除)标签

登录 后才可以发表评论

状态
负责人
项目
里程碑
Pull Requests
关联的 Pull Requests 被合并后可能会关闭此 issue
分支
开始日期   -   截止日期
-
置顶选项
优先级
预计工期 (小时)
参与者(5)
8777557 test bot 1617846881 6560119 panza 1584156773
Python
1
https://gitee.com/mindspore/mindspore.git
git@gitee.com:mindspore/mindspore.git
mindspore
mindspore
mindspore

搜索帮助