DiffSynth-Studio训练踩坑记录:PyTorch 2.5.1 + Meta Tensor + 新增模块 + strict=True 导致的加载失败
环境:
- PyTorch 2.5.1
- DiffSynth-Studio / Wan2.1-T2V-1.3B
任务:在官方 WanVideo 模型基础上增加模块,继续训练 LoRA
这篇文章记录一次在 WanVideo 训练过程中遇到的模型加载问题,涉及到:
- PyTorch 2.5.1 的 meta tensor 机制;
- 使用
strict=True加载权重导致的结构不匹配; - 给模型增加新模块后如何正确加载旧 checkpoint。
最后训练已经成功跑起来,这里把整个排查和修复过程整理一下。
1. 场景简介:在 WanVideo 上加模块继续训练
训练命令大致如下:
nohupbashexamples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh\>wan2.1-1.3B.log2>&1&项目中使用:
self.pipe=WanVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16,device=device,model_configs=model_configs,tokenizer_config=tokenizer_config,audio_processor_config=audio_processor_config,)来下载并加载 Wan2.1-T2V-1.3B 模型。
我对模型做了一个改动:在原有模型的基础上增加了一些新的模块(新的层),希望用原模型的权重作为初始化,继续训练 LoRA。
自然地,旧 checkpoint 中没有新模块对应的权重,这为后面的报错埋下了伏笔。
2. 第一层坑:strict=True 导致结构不匹配直接报错
最开始,代码里加载权重用的是默认的strict=True,类似:
model.load_state_dict(state_dict)# 默认 strict=True当我在模型结构上增加模块之后,这些新增层的参数在 checkpoint 中不存在。
在strict=True的情况下,load_state_dict的行为是:
- 模型里有,但
state_dict里没有 → 归为missing_keys,直接报错; state_dict里有,但模型里没有 → 归为unexpected_keys,也会报错。
也就是说,只要你对模型结构进行了增删改,strict=True 会让加载必然失败。
正确做法应该是:
- 改成
strict=False:load_info=model.load_state_dict(state_dict,strict=False)print("missing:",load_info.missing_keys)print("unexpected:",load_info.unexpected_keys) - 允许结构不完全一致;
- 用
missing_keys/unexpected_keys明确看到哪些参数没加载上、哪些是 checkpoint 里多余的。
后面在修复时,我把这一步和 meta 问题一起处理了,最终使用:
load_info=model.load_state_dict(state_dict,assign=True,strict=False)下面先解释第二层坑:meta tensor。
3. 什么是 meta tensor?先理解再排错
在这次问题里,最关键的一条报错是:
NotImplementedError: Cannot copy out of meta tensor; no data! Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() when moving module from meta to a different device.这里的 “meta tensor” 其实是 PyTorch 提供的一种特殊机制:meta device。
3.1 meta 是什么?
可以简单理解为:
“只有形状和 dtype 信息、没有真实数据、不占显存的假张量,用来先搭模型结构、后装权重。”
比如:
withtorch.device("meta"):w=torch.empty(16,16)print(w.device)# device(type='meta')print(w.shape)# torch.Size([16, 16])print(w.is_meta)# True特点:
- 有 shape、有 dtype;
w.is_meta == True;- 不占 GPU/CPU 实际内存;
- 但不能做任何需要真实数据的操作,例如:
w.to("cuda")(要拷贝数据);w + 1(要访问数据);- 卷积、matmul 等计算。
3.2 为什么要有 meta?
大模型构建时,如果立刻在 GPU 上分配全部参数,很容易 OOM。
因此很多框架采用“空权重 / meta device”技术:
withtorch.device("meta"):model=MyBigModel()# 所有参数都是 meta tensor这一步只搭结构,不占真实显存。
后面再按策略,把参数真正“实例化”到 GPU / CPU 并加载权重。
3.3 和这次报错的关系
在 DiffSynth-Studio 里,构建 WanVideo 模型时就用了 meta 技术。
也就是说,模型刚被创建时,参数是is_meta=True,还没有真实数据。
而在diffsynth/core/loader/model.py中,当时的代码类似:
model=model.to(dtype=torch_dtype,device=device)这里直接对 meta 模型调用.to(),PyTorch 会尝试:
从 meta 上“拷贝”数据到目标设备。
但 meta 上根本没有数据,所以就抛出了:
Cannot copy out of meta tensor; no data! Please use torch.nn.Module.to_empty() ...官方的意思是:
对 meta 模型,不要直接
.to(),而要用to_empty(device=...)先在目标设备上创建“空参数容器”,再用load_state_dict填入真正的 checkpoint 数据。
4. 正确的迁移方式:to_empty+load_state_dict
PyTorch 2.5.1 下的推荐做法如下:
在 meta 设备上构建模型:
withtorch.device("meta"):model=MyModel()不要直接
model.to("cuda"),而是:model=model.to_empty(device="cuda")# 只分配空参数容器然后用 checkpoint 填充:
state_dict=torch.load("xx.pth",map_location="cpu")model.load_state_dict(state_dict)
5. 把 meta + strict + 新增模块 这三件事串起来
结合前面的分析,这次问题本质上是三件事叠加:
- 模型是通过 meta 方式构建的;
- 我对模型结构做了修改(增加了模块);
- 一开始用的是
strict=True,随后在 meta 状态下直接调用了.to()。
想要一劳永逸地解决,需要做到:
- 判断是否有 meta 参数;
- 对 meta 模型用
to_empty(device=...)而不是.to(); - 用
load_state_dict(..., strict=False); - 打印
missing_keys/unexpected_keys; - 初始化新增模块。
最终,我在load_model中整理出的核心逻辑大概是这样:
# 1. 判断是否有 meta 参数has_meta_param=any(p.is_metaforpinmodel.parameters())ifhas_meta_param:# 2. meta 模型:使用 to_empty 迁移到目标设备# 注意:PyTorch 2.5.1 要用关键字参数 device=model=model.to_empty(device=device)# 3. 之后再统一 dtype(此时已经不是 meta 了,可以正常 .to)iftorch_dtypeisnotNone:model=model.to(dtype=torch_dtype)else:# 非 meta 模型:直接 .to 即可model=model.to(dtype=torch_dtype,device=device)# 4. 加载 checkpoint(非严格模式 + assign=True)load_info=model.load_state_dict(state_dict,assign=True,strict=False)missing=load_info.missing_keys unexpected=load_info.unexpected_keysifmissing:print("未加载到的参数 (missing_keys):")forkinmissing:print(" ",k)ifunexpected:print("多余的权重 (unexpected_keys)(state_dict 中有,但模型中没有):")forkinunexpected:print(" ",k)ifnotmissingandnotunexpected:print("所有参数均已加载")几点注意:
to_empty在 2.5.1 中不支持dtype=参数,所以用:model=model.to_empty(device=device)model=model.to(dtype=torch_dtype)strict=False是为了允许新增模块不在 checkpoint 中,避免硬报错。assign=True用的是 PyTorch 2.x 的新行为:直接让参数引用指向state_dict中的 tensor,减少一次拷贝。