checkpoint
模型或模型的一部分
checkpoint
通过交换计算内存来工作。而不是存储整个计算图的所有中间激活用于向后计算,checkpoint
不会不保存中间激活部分,而是在反向传递中重新计算它们。它可以应用于模型的任何部分。
具体来说,在正向传递中,function
将以torch.no_grad()
方式运行 ,即不存储中间激活。相反,正向传递保存输入元组和 function
参数。在向后计算中,保存的输入变量以及 function
会被回收,并且正向计算被function
再次计算 ,现在跟踪中间激活,然后使用这些激活值来计算梯度。
警告
checkpoint
在torch.autograd.grad()
中不起作用,但仅适用于torch.autograd.backward()
。警告 如果
function
在向后执行和前向执行都不同,例如由于某个全局变量,checkpoint
版本将不等同,并且不幸的是无法检测到。
参数:
返回: attrfunction
开*args
返回类型: 运行输出
用于checkpoint
sequential
模型的辅助函数。
sequential 模型按顺序执行一系列模块/函数(按顺序)。因此,我们可以将这种模型分为不同的部分和
checkpoint。除最后一个段以外的所有段都将以某种
torch.no_grad()方式运行 ,即不存储中间活动。将保存每个
checkpoint`段的输入,以便在向后传递中重新运行段。
关于checkpoint
如何工作可以参考checkpoint()。
警告
checkpoint
在torch.autograd.grad()
中不起作用,但仅适用于torch.autograd.backward()
。
参数:
返回:
functions
按顺序运行的输出*inputs
例:
>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var)
部分地方存在翻译错误,即将修复
用户名 | 头像 | 职能 | 签名 |
---|---|---|---|
Song | 翻译 | 人生总要追求点什么 |
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。