enable_input_require_grads

model.enable_input_require_grads()

问题

在进行大模型微调时,如果同时满足以下两个条件:

  • 使用 PEFT(如 LoRA):仅训练少量增量参数,冻结了底座模型(Embedding 层及大部分层)。
  • 开启 Gradient Checkpointing:为了节省显存,不保存中间激活值。

会发生报错:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

结论:报错的直接原因是 Loss 变成了一个孤立的标量,它丢失了通往参数的“导航图”(grad_fn 为 None)。定位(拷打Gemini)后确认问题与LoRA无关,是源于

在开启 gradient checkpoint 的前提下,仅训练少量参数(非全参)时,反向传播计算图未能正确建立,导致梯度在 checkpoint 边界被切断。

解决方法

显式开启输入张量的梯度追踪,即在训练前加上

1
model.enable_input_require_grads()

其效果等价于:

强制模型 forward 的输入张量 requires_grad=True

原理

Gradient checkpoint会把“是否构建反向传播计算图”的决定,从forward阶段推迟到backward阶段。

在局部参数训练(PEFT)场景下,如果checkpoint的re-forward输入不具备 requires_grad=True,autograd会将整个checkpoint包裹的计算视为“从常量到常量的映射”,从而:

  • 不创建 grad_fn
  • 不保存 SavedTensor
  • 可训练参数(如 LoRA)无法被挂载进 GraphTask
  • 最终表现为参数grad_fn=None, 直接报错

checkpoint机制

普通 forward

1
2
3
4
autograd 立刻:
- 检查 inputs / params
- 决定是否创建 grad_fn
- 保存 SavedTensor

在这种模式下,即使输入不需要梯度,只要算子中使用了 requires_grad=True 的参数,参数梯度也可以被正常计算。


checkpoint forward(第一次)

1
2
3
4
5
❌ 不建 backward graph
❌ 不保存 SavedTensor
✅ 只 stash:
- Python function block
- Tensor inputs

此阶段几乎不涉及 autograd 图的构建。


checkpoint backward

1
2
3
4
5
6
CheckpointBackward.apply(...)

re-forward: block(*inputs)

这一次 forward,才由 autograd 判断:
“要不要构建 backward graph?”

checkpoint本质上是一个“延迟建图”的机制。而在checkpoint的re-forward中,autograd仅依据“输入Tensor是否requires_grad”来决定是否建图

在PEFT场景下:假设我们冻结了Embedding和前30层。这意味着传给第31层(被 Checkpoint 包裹的黑盒块)的输入 Tensor 是没有梯度的。因此Autograd会认为这个黑盒内部的所有计算都不需要回溯,即使黑盒中有需要梯度的LoRA参数,但autograd不会为它们构建grad_fn了。

为什么autograd不能通过检查params来决定是否建图?

  • 动态图的“不透明性”
    Checkpoint 接收的是一个普通的 Python 函数。由于 Python 是动态语言,Autograd 无法在不执行函数的情况下,预知函数内部会访问哪些参数,或者走哪条 if-else 分支。
  • 算子驱动 vs 结构驱动
    Autograd 是由 Tensor 运算 驱动的。它只关心当前的算子输入。如果它每走一步都要去扫描整个 nn.Module 的状态,Checkpoint 的性能开销将变得不可接受。
  • 隔离性
    Checkpoint 的设计初衷是将一段复杂的计算逻辑抽象为一个独立的函数。为了保证逻辑的解耦,它必须依赖输入张量的状态来决定是否开启“追踪模式”。

requires_grad=True究竟改变了什么?

设置了requires_grad=True后autograd 被迫把这条计算路径视为“可微路径”,从而允许(并在需要时)为中间结果创建 grad_fn


enable_input_require_grads
https://jyk-122.github.io/2026/02/26/enable_input_require_grads/
作者
Yikun Jiang
发布于
2026年2月26日
许可协议