import torch
import torch_npu

from .ascend_turbo_cfg import ascend_turbo_cfg
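
# MC2 (merged compute-communication) linear layers for Megatron-style tensor
# parallelism with sequence parallelism on Ascend NPUs: ColumnSeqParallelLinear
# fuses the all-gather with the matmul, RowSeqParallelLinear fuses the matmul
# with the reduce-scatter, via the torch_npu kernels invoked below.
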

class ColumnSeqParallelLinear(torch.autograd.Function):
    """Column-parallel linear with sequence parallelism.

    Fuses the sequence-dimension all-gather with the matmul through the
    torch_npu MC2 kernel instead of issuing them as separate ops.
    """

    @staticmethod
    def forward(ctx, input_, weight, bias, group):
        ctx.save_for_backward(input_, weight)
        ctx.use_bias = bias is not None

        # Resolve the HCCL communicator name for the fused kernel; the
        # lookup API differs between torch 2.x and earlier releases.
        rank = torch.distributed.get_rank(group)
        if torch.__version__ > "2.0":
            global_rank = torch.distributed.get_global_rank(group, rank)
            hcomm_info = group._get_backend(torch.device("npu")).get_hccl_comm_name(
                global_rank
            )
        else:
            hcomm_info = group.get_hccl_comm_name(rank)

        # Flatten [s, b, h] to [s * b, h] for the fused matmul.
        x = input_.reshape(input_.shape[0] * input_.shape[1], input_.shape[2])

        world_size = ascend_turbo_cfg.get_world_size()
        # Fused all-gather + matmul: gathers x along the sequence dimension
        # and multiplies by weight.t() in one kernel. The gathered input is
        # returned only when backward will not recompute it.
        output, all_gather_grad_output = torch_npu.npu_all_gather_base_mm(
            x,
            weight.t(),
            hcomm_info,
            world_size,
            bias=bias,
            gather_index=0,
            gather_output=(not ascend_turbo_cfg.all_gather_recomputation),
        )

        # Restore the [s * world_size, b, h_out] layout.
        output = output.view(
            output.shape[0] // input_.shape[1], input_.shape[1], output.shape[1]
        )

        ctx.all_gather_output = all_gather_grad_output
        ctx.world_size = world_size
        ctx.group = group
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, weight = ctx.saved_tensors

        # Flatten [s * world_size, b, h_out] to 2-D for the matmuls.
        grad_output_ = grad_output.reshape(
            grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]
        )

        if ascend_turbo_cfg.all_gather_recomputation:
            # The gathered input was not kept in forward: re-gather it
            # asynchronously so the collective overlaps with dL/dx below.
            # Note: torch.cuda here assumes the torch_npu device mapping.
            dim_size = list(input_.size())
            dim_size[0] = dim_size[0] * ctx.world_size
            all_gather_output = torch.empty(
                dim_size,
                dtype=input_.dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )
            all_gather_work = torch.distributed._all_gather_base(
                all_gather_output, input_.contiguous(), group=ctx.group, async_op=True
            )
        else:
            all_gather_output = ctx.all_gather_output

        # dL/dx over the full (gathered) sequence dimension.
        grad_input = grad_output_.matmul(weight)
        grad_input = grad_input.reshape(
            grad_output.shape[0], grad_output.shape[1], weight.shape[1]
        )

        # Scatter the input gradient back to this rank's sequence chunk,
        # overlapping the collective with the weight-gradient matmul.
        sub_grad_input = torch.empty(
            list(input_.size()), dtype=input_.dtype, device=torch.cuda.current_device()
        )
        reduce_scatter_work = torch.distributed._reduce_scatter_base(
            sub_grad_input, grad_input, group=ctx.group, async_op=True
        )

        if ascend_turbo_cfg.all_gather_recomputation:
            all_gather_work.wait()
            all_gather_output = all_gather_output.reshape(
                all_gather_output.shape[0] * all_gather_output.shape[1],
                all_gather_output.shape[2],
            )

        grad_weight = grad_output_.t().matmul(all_gather_output)

        is_grad_bias_needed = ctx.needs_input_grad[2]
        if is_grad_bias_needed and ctx.use_bias:
            grad_bias = (
                grad_output_.sum(dim=0)
                if grad_output_.is_contiguous()
                else grad_output_.t().sum(dim=1)
            )
        else:
            grad_bias = None

        # dL/dx is complete only after the scatter finishes.
        reduce_scatter_work.wait()
        return sub_grad_input, grad_weight, grad_bias, None
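

# ---------------------------------------------------------------------------
# Usage sketch (illustrative; not part of the original module). The helper
# name and the `tp_group` argument are hypothetical; a Megatron-style
# column-parallel layer would invoke the autograd function like this.
# ---------------------------------------------------------------------------
def _column_seq_parallel_example(input_, weight, bias, tp_group):
    """Hypothetical helper showing the expected call shape.

    `input_` is this rank's [s_local, b, h] sequence shard, `weight` the
    [h_out_shard, h] column shard; the result is the gathered-sequence
    activation of shape [s_local * world_size, b, h_out_shard].
    """
    return ColumnSeqParallelLinear.apply(input_, weight, bias, tp_group)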


class RowSeqParallelLinear(torch.autograd.Function):
    """Row-parallel linear with sequence parallelism.

    Fuses the matmul with the sequence-dimension reduce-scatter through the
    torch_npu MC2 kernel instead of issuing them as separate ops.
    """

    @staticmethod
    def forward(ctx, input_, weight, bias, group):
        ctx.save_for_backward(input_, weight)
        ctx.use_bias = bias is not None

        # Resolve the HCCL communicator name for the fused kernel; the
        # lookup API differs between torch 2.x and earlier releases.
        rank = torch.distributed.get_rank(group)
        world_size = ascend_turbo_cfg.get_world_size()
        if torch.__version__ > "2.0":
            global_rank = torch.distributed.get_global_rank(group, rank)
            hcomm_info = group._get_backend(torch.device("npu")).get_hccl_comm_name(
                global_rank
            )
        else:
            hcomm_info = group.get_hccl_comm_name(rank)

        # Flatten [s, b, h] to [s * b, h] for the fused matmul.
        x = input_.reshape(input_.shape[0] * input_.shape[1], input_.shape[2])

        # Fused matmul + reduce-scatter: each rank receives its own chunk
        # of the sequence dimension of x @ weight.t().
        output = torch_npu.npu_mm_reduce_scatter_base(
            x, weight.t(), hcomm_info, world_size, reduce_op="sum", bias=bias
        )

        ctx.hcomm_info = hcomm_info
        ctx.world_size = world_size

        # Restore the [s // world_size, b, h_out] layout.
        output = output.view(
            output.shape[0] // input_.shape[1], input_.shape[1], output.shape[1]
        )
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, weight = ctx.saved_tensors
        hcomm_info = ctx.hcomm_info
        world_size = ctx.world_size

        # Flatten [s // world_size, b, h_out] to 2-D for the fused kernel.
        grad_output_ = grad_output.reshape(
            grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]
        )

        # Fused all-gather + matmul: gathers grad_output across ranks and
        # multiplies by weight; the gathered tensor is reused for dL/dW.
        grad_input, all_gather_grad_output = torch_npu.npu_all_gather_base_mm(
            grad_output_, weight, hcomm_info, world_size, bias=None, gather_index=0
        )
        grad_input = grad_input.view_as(input_)

        x = input_.reshape(input_.shape[0] * input_.shape[1], input_.shape[2])
        grad_weight = all_gather_grad_output.t().matmul(x)

        is_grad_bias_needed = ctx.needs_input_grad[2]
        if is_grad_bias_needed and ctx.use_bias:
            # Sum the flattened 2-D gradient, matching the column branch:
            # reducing the 3-D grad_output over dim 0 would produce a
            # [b, h_out] tensor instead of the bias-shaped [h_out].
            grad_bias = (
                grad_output_.sum(dim=0)
                if grad_output_.is_contiguous()
                else grad_output_.t().sum(dim=1)
            )
        else:
            grad_bias = None

        return grad_input, grad_weight, grad_bias, None
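

# ---------------------------------------------------------------------------
# Usage sketch (illustrative; not part of the original module). In a
# Megatron-style MLP the two functions are typically paired: the column
# branch gathers the sequence dimension on entry and the row branch
# scatters it again on exit. Names below are hypothetical.
# ---------------------------------------------------------------------------
def _row_seq_parallel_example(input_, weight, bias, tp_group):
    """Hypothetical helper showing the expected call shape.

    `input_` holds full-sequence activations [s, b, h_in_shard] on this
    rank, `weight` the [h_out, h_in_shard] row shard; the fused kernel
    returns this rank's [s // world_size, b, h_out] sequence chunk.
    """
    return RowSeqParallelLinear.apply(input_, weight, bias, tp_group)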