如何用 EWC(Elastic Weight Consolidation)计算重要权重矩阵并加入损失?
解读
在国内大模型落地场景中,持续学习(Continual Learning) 是面试高频考点。面试官问“怎么用 EWC 算重要权重并加到损失”时,真正想验证三件事:
- 你是否理解大模型微调后灾难性遗忘的业务痛点;
- 能否把费雪信息矩阵(FIM)算对、存对、用对,不爆显存;
- 能否把 EWC 正则项无缝接入 transformers + DeepSpeed 的实战管线,而不是只背公式。
回答必须给出可落地的 PyTorch 伪代码,并说明在百亿模型分布式训练里如何只保存对角费雪,避免 O(d²) 存储。
知识点
- 费雪信息矩阵对角化:
对预训练权重 θ*,用任务 A 的样本估计 Fᵢ = 𝔼[∂log p(y|x,θ)/∂θᵢ]²,只保留对角线 F_diag,存储量 O(d)。 - 重要权重矩阵:
计算 Ω = F_diag + α(α=1e-7 防止除零),得到每个参数的重要性得分。 - EWC 损失:
L_EWC = L_taskB + λ/2 · Σᵢ Ωᵢ·(θᵢ − θ*ᵢ)²,其中 λ 是业务调参核心,一般 1e3~1e5。 - 大模型工程技巧:
- 用 activation checkpoint + ZeRO-3 把 Ω 拆到各 GPU,不复制完整参数。
- 只计算 一层 LayerNorm 或 Attention 输出层 的 FIM,降低采样成本。
- 把 Ω 存为 FP16 稀疏张量,裁剪尾部 5% 小值,再压缩为 safetensors 文件,方便后续热更新。
答案
步骤一:采样估计 FIM
# 伪代码,兼容 DeepSpeed ZeRO-3
model.eval()
fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
for x, y in data_taskA_loader:
model.zero_grad()
loss = model(x, labels=y).loss
loss.backward()
for n, p in model.named_parameters():
if p.grad is not None:
fisher[n] += p.grad.pow(2) * len(x)
# 平均
num_samples = len(data_taskA_loader.dataset)
for n in fisher:
fisher[n] /= num_samples
fisher[n] = fisher[n].clamp_min(1e-7) # 得到 Ω
步骤二:保存锚点权重
anchor = {n: p.detach().clone() for n, p in model.named_parameters()}
torch.save({"anchor": anchor, "fisher": fisher}, "ewc_taskA.pt")
步骤三:加入损失继续微调任务 B
lambda_ewc = 5e4 # 业务调参
ewc_state = torch.load("ewc_taskA.pt")
anchor, fisher = ewc_state["anchor"], ewc_state["fisher"]
def ewc_loss(model):
loss = 0
for n, p in model.named_parameters():
if n in fisher:
loss += (fisher[n] * (p - anchor[n]) ** 2).sum()
return lambda_ewc * loss
# 训练循环
for x, y in data_taskB_loader:
loss_task = model(x, labels=y).loss
loss_ewc = ewc_loss(model)
(loss_task + loss_ewc).backward()
optimizer.step()
注意:在 百亿模型 场景下,只给 每层前 20% 参数 计算 fisher,可把显存从 800 GB 降到 80 GB,满足国内 A100 80G 八卡机预算。
拓展思考
- 在线 EWC:业务数据每天新增,可用 指数移动平均 更新 Ω,避免全量重算。
- 任务身份路由:当业务同时跑 N 个领域,可把 Ω 拆成 N 个掩码,用 task-id 动态选择正则子集,实现单模型多租户。
- 与 LoRA 结合:只给 LoRA 的 A、B 矩阵 算 fisher,存储量从 100 GB 降到 200 MB,方便 K8s 热更新;此时 λ 需放大 10 倍,防止低秩矩阵过度漂移。
- 国产化适配:在 昇腾 910B 上,fisher 计算需换成 NPU 自定义算子,并用 MindSpeed 的并行 API 替代 DeepSpeed,避免 CUDA 依赖。