060、超分数据集构建:从 DIV2K 到 REDS 的数据预处理与增强方法
上周帮师弟调一个EDSR的复现,他拿着DIV2K训练了三天,PSNR死活上不去。我一看他的数据加载代码,好家伙,直接把HR图片resize成LR,再用双三次插值放大回来当输入——这哪是超分,这是让网络学双三次插值啊。这种坑我当年也踩过,今天就把数据预处理这块的实战经验掰开揉碎讲清楚。
数据集的“脏活累活”才是决定上限的关键
很多人觉得超分就是堆网络结构,其实数据预处理做不好,再好的模型也白搭。DIV2K和REDS是目前单图像超分和视频超分最常用的两个基准数据集,但它们的原始数据格式、退化方式、增强策略完全不同。
DIV2K:单图超分的“标准试卷”
DIV2K包含800张训练图、100张验证图、100张测试图,全是2K分辨率。拿到手第一件事不是直接切patch,而是检查图片的位深度——DIV2K原始是PNG格式,但有些图是16位深度的,直接读成uint8会丢失细节。我习惯用OpenCV的imread加IMREAD_UNCHANGED标志:
importcv2importnumpyasnp# 这里踩过坑:imread默认转成8位,16位图会截断img=cv2.imread('path/to/image.png',cv2.IMREAD_UNCHANGED)ifimg.dtype==np.uint16:img=(img/256).astype(np.uint8)# 别用//256,会丢失低8位信息DIV2K官方提供了LR版本,但那是用MATLAB的双三次插值生成的,和PyTorch的F.interpolate结果有细微差异。如果你要复现论文,建议直接用官方LR,别自己生成——我见过有人自己生成LR导致和论文结果对不上的惨案。
REDS:视频超分的“时间刺客”
REDS有240个训练视频片段,每个片段100帧,分辨率1280x720。视频超分的数据预处理比单图复杂得多,核心在于时间维度的对齐和光流计算。
REDS的原始帧是PNG序列,文件名格式是00000000.png到00000099.png。这里有个坑:有些帧是纯黑或纯白(相机标定帧),需要提前过滤掉。我写了个简单的帧质量检查:
defis_valid_frame(img_path,threshold=10):img=cv2.imread(img_path,cv2.IMREAD_GRAYSCALE)# 别这样写:if np.mean(img) < 10 or np.mean(img) > 245# 应该检查方差,纯色帧方差接近0returnnp.std(img)>threshold数据增强:别让网络学会“偷懒”
数据增强不是越多越好,关键是让网络学到不变性。超分任务里,旋转和翻转是安全的,但颜色抖动要谨慎——超分本质是恢复高频细节,颜色变化会干扰网络对纹理的学习。
单图超分的增强策略
我常用的增强组合是:随机旋转(0/90/180/270度)、随机水平翻转、随机裁剪。裁剪时有个细节:LR和HR的裁剪位置必须对齐。很多人用random_crop分别裁LR和HR,结果位置对不上,网络学了个寂寞。
defpaired_random_crop(lr,hr,patch_size,scale):# 这里踩过坑:lr和hr的坐标要对应scale倍数h_lr,w_lr=lr.shape[:2]h_hr,w_hr=hr.shape[:2]# 别这样写:直接随机裁lr再放大# 应该先确定hr的裁剪位置,再映射到lrx_hr=np.random.randint(0,w_hr-patch_size+1)y_hr=np.random.randint(0,h_hr-patch_size+1)x_lr=x_hr//scale y_lr=y_hr//scale lr_patch=lr[y_lr:y_lr+patch_size//scale,x_lr:x_lr+patch_size//scale]hr_patch=hr[y_hr:y_hr+patch_size,x_hr:x_hr+patch_size]returnlr_patch,hr_patch视频超分的时序增强
视频超分里,时间反转(倒放)是个很有效的增强手段,能让网络学到双向运动信息。但要注意:光流也要跟着反转。我见过有人只反转帧序列,光流还是正向的,结果训练时梯度乱飞。
deftemporal_reverse(frames,flows):# 别这样写:只反转frames,flows不变# 应该同时反转flows的方向reversed_frames=frames[::-1].copy()reversed_flows=[-fforfinflows[::-1]]# 光流取反returnreversed_frames,reversed_flows退化模型:从“简单粗暴”到“以假乱真”
DIV2K和REDS的官方退化都是双三次下采样,但真实场景的退化复杂得多。如果你想做真实世界超分,需要模拟更复杂的退化:模糊、噪声、下采样、压缩伪影的组合。
经典退化流水线
我常用的退化模型是:先高斯模糊(核大小随机),再加噪声(高斯或泊松),然后双三次下采样,最后JPEG压缩(质量因子随机)。这个流水线在RCAN和SwinIR的论文里都有提到。
defdegrade_hr(hr,scale,blur_kernel=5,noise_std=0.01,jpeg_q=90):# 这里踩过坑:模糊核大小必须是奇数ifblur_kernel%2==0:blur_kernel+=1# 强制奇数blurred=cv2.GaussianBlur(hr,(blur_kernel,blur_kernel),0)# 别这样写:先下采样再加噪声# 应该先加噪声再下采样,更符合物理模型noise=np.random.randn(*hr.shape)*noise_std*255noisy=np.clip(blurred+noise,0,255).astype(np.uint8)lr=cv2.resize(noisy,(noisy.shape[1]//scale,noisy.shape[0]//scale),interpolation=cv2.INTER_CUBIC)# JPEG压缩模拟encode_param=[int(cv2.IMWRITE_JPEG_QUALITY),jpeg_q]_,enc_img=cv2.imencode('.jpg',lr,encode_param)lr=cv2.imdecode(enc_img,1)returnlr视频超分的退化特殊性
视频超分还要考虑帧间的退化一致性。比如模糊核在连续帧之间应该是平滑变化的,不能每帧随机一个模糊核——否则网络会学到帧间闪烁。我通常让模糊核参数在时间轴上做线性插值:
deftemporal_smooth_degrade(frames,scale,kernel_range=(3,7)):n_frames=len(frames)# 首尾帧的模糊核大小k_start=np.random.randint(*kernel_range)k_end=np.random.randint(*kernel_range)# 中间帧线性插值kernels=np.linspace(k_start,k_end,n_frames).astype(int)# 确保奇数kernels=kernels+1-kernels%2return[degrade_hr(f,scale,blur_kernel=k)forf,kinzip(frames,kernels)]数据加载的工程化技巧
内存管理
DIV2K的HR图每张约10MB,800张就是8GB,全加载到内存会爆。我习惯用lmdb或h5py做持久化存储,训练时随机读取。REDS更夸张,240个片段×100帧,全加载要上百GB。
importlmdbdefcreate_lmdb(dataset_path,output_path):# 这里踩过坑:map_size要设大,默认1MB不够env=lmdb.open(output_path,map_size=1099511627776)# 1TBwithenv.begin(write=True)astxn:forimg_pathinsorted(glob(dataset_path+'/*.png')):img=cv2.imread(img_path)key=img_path.split('/')[-1].encode()txn.put(key,cv2.imencode('.png',img)[1].tobytes())多尺度训练
很多超分模型(如EDSR、RCAN)支持多尺度训练,即同一个模型同时学习×2、×3、×4。数据加载时,需要为每张HR图生成多个尺度的LR。这里有个技巧:先下采样到最大尺度(比如×4),再上采样到其他尺度,避免重复计算。
defmulti_scale_lr(hr,scales=[2,3,4]):# 别这样写:对每个scale都从HR下采样# 应该先下采样到最大scale,再上采样max_scale=max(scales)lr_max=cv2.resize(hr,(hr.shape[1]//max_scale,hr.shape[0]//max_scale))lrs={}forsinscales:ifs==max_scale:lrs[s]=lr_maxelse:# 从lr_max上采样到目标尺度target_size=(lr_max.shape[1]*max_scale//s,lr_max.shape[0]*max_scale//s)lrs[s]=cv2.resize(lr_max,target_size,interpolation=cv2.INTER_CUBIC)returnlrs个人经验性建议
别迷信官方数据集:DIV2K和REDS的退化方式太理想化,如果你的应用场景是监控视频或手机拍照,建议自己构建退化模型。我做过一个实验:用DIV2K训练的模型在真实监控视频上PSNR掉了3dB,后来加了运动模糊和噪声模拟才追回来。
数据增强要“对症下药”:如果你的超分模型要处理老照片修复,多加点JPEG压缩和划痕模拟;如果是卫星图像超分,重点加高斯模糊和大气湍流模拟。别一股脑把所有增强都用上,网络会学成“万金油”,什么场景都做不好。
视频超分的数据加载是性能瓶颈:我见过有人用PyTorch的DataLoader直接读PNG序列,训练速度被IO拖慢5倍。建议用
torchvision.io.read_video直接读视频文件,或者用decord库做高效解码。REDS的帧序列可以提前打包成.mp4,读取速度提升明显。验证集和测试集的退化方式必须和训练集一致:这个看似废话,但很多人犯过。比如训练时用了随机模糊核,验证时却用固定模糊核,导致验证结果虚高。我习惯在训练脚本里固定随机种子,确保每次生成的退化参数可复现。
数据预处理代码要版本控制:超分实验的预处理逻辑经常调整,今天加个噪声,明天改个裁剪策略。如果不做版本控制,过两个月你自己都搞不清当时用的什么参数。我每个实验都会在代码里写一个
preprocess_config.yaml,记录所有预处理参数。
最后说句实在话:超分领域现在卷得厉害,模型结构越来越复杂,但很多顶会论文的数据预处理其实很粗糙。你把数据预处理做到极致,哪怕用个简单的SRCNN,效果也能超过那些花里胡哨的模型。数据才是王道,这句话在超分领域尤其适用。