想二次开发?FFT NPainting LaMa项目结构先了解
本文面向希望基于
fft npainting lama镜像做定制化开发的工程师,不讲原理、不堆参数,只带你一层层拆开项目骨架,看清每个目录、每个文件的真实职责——让你改得明白、调得顺手、扩得安心。
1. 项目定位与二次开发前提
1.1 这不是个“黑盒WebUI”,而是一套可插拔的修复流水线
很多用户把fft npainting lama当成普通AI工具:上传→画笔→点击→下载。但如果你打开终端执行ls -R /root/cv_fft_inpainting_lama/,会发现它本质是一个前后端分离、模型即服务、逻辑高度解耦的工程系统:
- 前端是轻量级Gradio WebUI(非React/Vue单页应用)
- 后端是Python推理服务(非Flask/FastAPI全栈框架)
- 核心修复能力封装在独立模块中(非胶水代码拼接)
这意味着:你不需要重写整个界面,也不必动模型训练代码,只需找准接口缝合点,就能快速注入新功能。
1.2 二次开发前必须确认的三件事
| 事项 | 检查方式 | 不满足后果 |
|---|---|---|
| Python环境纯净性 | python3 -c "import torch; print(torch.__version__)"; pip list | grep lama | 模型加载失败、CUDA报错、依赖冲突 |
| 模型权重已就位 | ls /root/cv_fft_inpainting_lama/models/应含big-lama或places2目录 | 启动时报FileNotFoundError: models/big-lama/ |
| WebUI入口清晰可定位 | grep -r "gradio.App" /root/cv_fft_inpainting_lama/应返回app.py路径 | 修改后无法生效、热重载失效 |
提示:所有检查命令均可直接粘贴到服务器终端执行,无需额外安装工具。
2. 项目根目录结构全景图
执行tree -L 2 /root/cv_fft_inpainting_lama/得到如下精简视图(已过滤.git、__pycache__等无关项):
/root/cv_fft_inpainting_lama/ ├── app.py # WebUI主入口:定义界面布局、事件绑定、调用链路 ├── inference.py # 核心推理门面:统一接收输入、调度模型、返回结果 ├── models/ # 模型权重与配置 │ ├── big-lama/ # 主力修复模型(LaMa + FFT增强) │ └── places2/ # 备用模型(传统LaMa) ├── src/ # 功能模块化核心(重点!) │ ├── base/ # 基础工具:图像预处理、mask生成、后处理 │ ├── fft/ # FFT频域增强模块(项目特色) │ ├── lama/ # LaMa模型封装(加载、推理、输出解析) │ └── utils/ # 通用函数:路径管理、日志、配置读取 ├── outputs/ # 自动保存修复结果(无需手动创建) ├── assets/ # 前端静态资源(图标、CSS、JS片段) ├── start_app.sh # 启动脚本:环境检查 + 服务拉起 └── requirements.txt # 依赖声明(注意:含torch版本锁定)这个结构不是随意组织的——它严格遵循关注点分离原则:app.py只管“怎么展示”,inference.py只管“怎么调用”,src/下各子模块只管“自己那块事”。
3. 关键文件逐行解析:从启动到修复的完整链路
3.1start_app.sh:服务启动的“总开关”
#!/bin/bash cd /root/cv_fft_inpainting_lama # 1. 环境自检(关键!) if ! command -v python3 &> /dev/null; then echo " Python3未安装,请先配置环境" exit 1 fi # 2. 检查模型是否存在(避免运行时崩溃) if [ ! -d "models/big-lama" ]; then echo " 模型目录缺失:models/big-lama" echo " 请从官方仓库下载并解压至该路径" exit 1 fi # 3. 拉起WebUI(核心命令) echo " 正在启动WebUI..." python3 app.py --server-port 7860 --server-name 0.0.0.0二次开发提示:
- 若需修改端口,直接改
--server-port参数即可,无需动app.py - 若需添加GPU设备控制,可在
python3 app.py前加CUDA_VISIBLE_DEVICES=0 - 所有环境检查逻辑都集中在此脚本,新增依赖检查也应加在这里
3.2app.py:WebUI的“神经中枢”
这是你最常修改的文件,但改动必须克制。其核心结构如下:
import gradio as gr from inference import run_inpainting # ← 关键:所有业务逻辑都在inference.py里 def process_image(image, mask): # 1. 输入校验(尺寸、格式) # 2. 调用inference.run_inpainting()执行修复 # 3. 返回结果图像+状态文本 return result_img, f" 完成!已保存至: {output_path}" # 定义Gradio界面 with gr.Blocks(title=" 图像修复系统") as demo: gr.Markdown("## 图像修复系统 | webUI二次开发 by 科哥") with gr.Row(): with gr.Column(): input_img = gr.Image(type="pil", label=" 图像编辑区") # ← 前端输入组件 mask_img = gr.Image(type="pil", label="🖌 Mask标注", visible=False) # ← 隐藏的mask通道 run_btn = gr.Button(" 开始修复") with gr.Column(): output_img = gr.Image(type="pil", label="📷 修复结果") # ← 前端输出组件 status_text = gr.Textbox(label=" 处理状态") # 绑定事件(核心!) run_btn.click( fn=process_image, inputs=[input_img, mask_img], # ← 输入来自前端组件 outputs=[output_img, status_text] # ← 输出传给前端组件 ) if __name__ == "__main__": demo.launch(server_port=7860, server_name="0.0.0.0")二次开发安全区:
可安全修改:gr.Markdown()中的标题文案、按钮文字、组件label值
可安全扩展:在with gr.Row():内新增列,添加自定义控件(如风格选择下拉框)
绝对禁止:修改fn=process_image的函数签名或调用逻辑——这会破坏与inference.py的契约
3.3inference.py:修复能力的“唯一出口”
这是整个项目最核心的文件,也是你二次开发的主战场。其设计哲学是:一个函数,完成全部修复流程。
import os import numpy as np from PIL import Image from src.lama.lama_model import LaMaModel from src.fft.fft_enhancer import apply_fft_enhance from src.base.preprocess import prepare_input from src.base.postprocess import save_result def run_inpainting(input_pil: Image.Image, mask_pil: Image.Image) -> tuple[Image.Image, str]: """ 执行完整修复流程 :param input_pil: 原图(PIL.Image) :param mask_pil: 修复区域mask(白色为修复区,PIL.Image) :return: (修复后图像, 状态描述) """ # 步骤1:预处理(统一尺寸、归一化、转tensor) input_tensor, mask_tensor = prepare_input(input_pil, mask_pil) # 步骤2:加载模型(单例模式,避免重复加载) model = LaMaModel.get_instance(model_path="models/big-lama") # 步骤3:基础修复(LaMa原生推理) pred_tensor = model.inpaint(input_tensor, mask_tensor) # ← 核心推理调用 # 步骤4:FFT频域增强(项目特色!) enhanced_tensor = apply_fft_enhance(pred_tensor) # ← 关键增强点 # 步骤5:后处理(转PIL、保存、生成状态文本) result_pil = save_result(enhanced_tensor) status = f" 完成!已保存至: {os.path.abspath('outputs/')}" return result_pil, status二次开发黄金位置:
🔹 在步骤3后插入自定义后处理(如:添加水印、调整色温)
🔹 在步骤4替换apply_fft_enhance为你自己的增强函数(保持输入/输出tensor格式一致)
🔹 在步骤5修改save_result()的保存逻辑(如:按时间分目录、生成JSON元数据)
提示:所有模型加载、设备分配(CPU/GPU)、精度控制(FP16/FP32)都封装在
LaMaModel类中,不要在run_inpainting里硬编码torch.cuda.set_device()。
4.src/模块深度拆解:你的能力扩展包
4.1src/base/:稳定可靠的地基
| 文件 | 职责 | 二次开发建议 |
|---|---|---|
preprocess.py | 输入标准化:尺寸裁剪/填充、RGB/BGR转换、mask二值化 | 如需支持透明通道(PNG alpha),在此添加if img.mode == 'RGBA': ...分支 |
postprocess.py | 输出规范化:tensor→PIL、自动命名、路径创建、日志记录 | 如需导出多格式(JPG+WEBP),扩展save_result()的保存逻辑 |
config.py | 全局配置:默认尺寸、最大内存限制、超时阈值 | 修改MAX_IMAGE_SIZE = 2048可提升大图支持能力 |
4.2src/fft/:项目技术亮点所在
fft_enhancer.py是区别于标准LaMa的关键模块:
import torch import torch.fft as fft def apply_fft_enhance(pred_tensor: torch.Tensor) -> torch.Tensor: """ 对预测结果进行FFT频域增强 1. 将图像转至频域 2. 对低频分量进行幅度提升(强化主体结构) 3. 对高频分量进行相位微调(优化纹理细节) 4. 逆变换回空域 """ # pred_tensor: [1, 3, H, W] → 转频域 freq = fft.fftn(pred_tensor, dim=(-2, -1)) # 构建增强掩膜(中心低频强,边缘高频弱) h, w = freq.shape[-2:] y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij') center_y, center_x = h // 2, w // 2 dist_sq = (y - center_y) ** 2 + (x - center_x) ** 2 mask = torch.exp(-dist_sq / (2 * (min(h, w) // 8) ** 2)) # 高斯低通 # 增强:低频放大 + 高频相位扰动 enhanced_freq = freq * (1.2 * mask + 0.8 * (1 - mask)) # 幅度调制 phase_noise = torch.randn_like(freq).to(freq.device) * 0.05 enhanced_freq = enhanced_freq * torch.exp(1j * phase_noise) # 相位扰动 # 逆变换 enhanced_tensor = fft.ifftn(enhanced_freq, dim=(-2, -1)).real return torch.clamp(enhanced_tensor, 0, 1)可扩展方向:
- 添加
mode参数支持不同增强策略('structure','texture','balance') - 将
0.05噪声系数改为可配置参数,通过inference.py传入 - 实现
apply_fft_enhance_batch()支持批量处理,提升吞吐量
4.3src/lama/:模型能力的封装体
lama_model.py采用懒加载+单例模式,确保模型只加载一次:
class LaMaModel: _instance = None _model = None _device = None @classmethod def get_instance(cls, model_path: str): if cls._instance is None: cls._instance = cls() cls._instance._load_model(model_path) return cls._instance def _load_model(self, model_path: str): # 1. 加载配置 config = load_config(os.path.join(model_path, "config.yaml")) # 2. 构建模型 self._model = build_model(config) # 3. 加载权重 weights = torch.load(os.path.join(model_path, "best.ckpt"), map_location='cpu') self._model.load_state_dict(weights) # 4. 设备迁移(自动选择GPU/CPU) self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self._model.to(self._device) self._model.eval() def inpaint(self, image: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: # 核心推理:image/mask → pred with torch.no_grad(): image = image.to(self._device) mask = mask.to(self._device) pred = self._model(image, mask) return pred.cpu()二次开发安全操作:
可继承LaMaModel创建新类(如MyEnhancedLaMaModel),覆盖inpaint()方法
可在_load_model()中添加自定义权重初始化逻辑(如LoRA适配器加载)
禁止修改_device判断逻辑——它已自动适配环境
5. 二次开发实战:3个高频需求的改造方案
5.1 需求:支持批量图像修复(非单张)
改造点:inference.py+ 新增batch_inference.py
- 在
inference.py中新增函数:
def run_batch_inpainting(image_paths: list[str], mask_paths: list[str]) -> list[str]: """批量修复入口,返回结果路径列表""" results = [] for img_path, mask_path in zip(image_paths, mask_paths): input_pil = Image.open(img_path) mask_pil = Image.open(mask_path) _, status = run_inpainting(input_pil, mask_pil) results.append(status.split("已保存至: ")[-1]) return results- 创建
batch_inference.py(独立脚本):
import argparse from inference import run_batch_inpainting parser = argparse.ArgumentParser() parser.add_argument("--images", nargs="+", required=True) parser.add_argument("--masks", nargs="+", required=True) args = parser.parse_args() results = run_batch_inpainting(args.images, args.masks) for r in results: print(r)- 使用方式:
python batch_inference.py \ --images img1.jpg img2.png \ --masks mask1.png mask2.png5.2 需求:增加“修复强度”滑块控制
改造点:app.py+inference.py
- 修改
app.py的Gradio界面:
with gr.Row(): strength_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.8, step=0.1, label="🔧 修复强度") run_btn = gr.Button(" 开始修复") # 修改事件绑定 run_btn.click( fn=process_image, inputs=[input_img, mask_img, strength_slider], # ← 新增输入 outputs=[output_img, status_text] )- 修改
inference.py的run_inpainting函数签名:
def run_inpainting(input_pil: Image.Image, mask_pil: Image.Image, strength: float = 0.8) -> tuple[Image.Image, str]: # ... 原有逻辑 pred_tensor = model.inpaint(input_tensor, mask_tensor) # 关键:将strength传给FFT增强 enhanced_tensor = apply_fft_enhance(pred_tensor, strength=strength) # ...- 修改
fft_enhancer.py:
def apply_fft_enhance(pred_tensor: torch.Tensor, strength: float = 0.8) -> torch.Tensor: # 将原代码中的 1.2/0.8 等系数替换为 strength 相关计算 low_boost = 1.0 + 0.5 * strength # strength越大,低频增强越强 high_phase = 0.01 + 0.09 * strength # strength越大,高频扰动越强 # ... 后续计算使用 new_boost, new_phase5.3 需求:修复结果自动同步到云存储
改造点:src/base/postprocess.py
- 在
save_result()函数末尾添加:
def save_result(tensor: torch.Tensor) -> Image.Image: # ... 原有保存逻辑 local_path = os.path.join("outputs", filename) # 新增:云同步(以阿里云OSS为例) try: from oss2 import Auth, Bucket auth = Auth('your-access-key', 'your-secret-key') bucket = Bucket(auth, 'https://oss-cn-wlcb.aliyuncs.com', 'your-bucket-name') with open(local_path, 'rb') as f: bucket.put_object(f'webui/{filename}', f) print(f"☁ 已同步至OSS: webui/{filename}") except ImportError: print(" oss2未安装,跳过云同步") except Exception as e: print(f" 云同步失败: {e}") return result_pil- 在
requirements.txt中追加:
oss2>=2.15.06. 调试与验证:让修改立刻可见
6.1 快速验证修改是否生效
| 场景 | 操作 | 预期现象 |
|---|---|---|
修改了app.py文案 | 重启服务bash start_app.sh | 浏览器刷新后标题/按钮文字立即更新 |
修改了inference.py逻辑 | 无需重启,Gradio自动热重载 | 点击修复后,状态栏显示新文案或结果变化 |
修改了src/模块 | 重启服务(因模块被import缓存) | ps aux | grep app.py查看进程PID变化 |
6.2 日志定位问题的黄金路径
所有关键操作均输出日志到终端,重点关注三类信息:
- `` 开头:成功节点(如
模型加载完成) - `` 开头:警告但可继续(如
输入尺寸过大,已自动缩放) - `` 开头:致命错误(如
CUDA out of memory)
日志文件化(可选):
修改start_app.sh中的启动命令:
python3 app.py --server-port 7860 --server-name 0.0.0.0 2>&1 | tee -a /var/log/fft-lama.log7. 总结:二次开发的四条铁律
7.1 铁律一:永远优先修改配置,而非硬编码
- 想改端口?改
start_app.sh - 想改模型路径?改
inference.py中的字符串,或提取为config.py变量 - 想调参?加
strength参数,而非在fft_enhancer.py里写死0.05
7.2 铁律二:WebUI只是“皮肤”,核心在inference.py
- 90%的功能扩展,只需改
inference.py的输入/输出和中间处理 - Gradio界面只是调用它的“遥控器”,不要在界面里写业务逻辑
7.3 铁律三:src/是你的乐高积木库
base/提供稳定接口(别动它,用它)fft/是特色模块(可替换、可增强)lama/是能力底座(可继承、可包装)
7.4 铁律四:每次修改,只动一个点,立即验证
- 改完一行代码,就点一次“ 开始修复”
- 看不到效果?先检查终端日志有没有 ``
- 日志没报错?用
print()在关键位置打点(记得删掉)
你不是在维护一个“AI项目”,而是在调试一条精密的图像修复流水线。理解每个齿轮的咬合位置,比记住所有参数更重要。
获取更多AI镜像
想探索更多AI镜像和应用场景?访问 CSDN星图镜像广场,提供丰富的预置镜像,覆盖大模型推理、图像生成、视频生成、模型微调等多个领域,支持一键部署。