Pytorch图像去噪实战(六):CBDNet真实噪声去噪实战,解决合成噪声模型落地效果差的问题
一、问题场景:合成噪声训练很好,真实图片效果却很差
前面我们训练的模型大多基于一个假设:
噪声是高斯噪声。
也就是训练时这样造数据:
noise = torch.randn_like(clean) * sigma / 255.0
这在论文实验里很常见,也方便复现。
但在真实工程里,我遇到一个非常现实的问题:
模型在合成噪声测试集上PSNR很高,但处理真实手机照片、截图、扫描件时效果明显变差。
真实噪声往往不是简单高斯噪声,它可能包含:
- 传感器噪声
- JPEG压缩噪声
- 低光噪声
- 颜色偏移
- 局部噪声不均匀
- 锐化产生的伪影
因此,只靠合成高斯噪声训练出来的模型,很容易出现泛化不足。
这一篇我们参考 CBDNet 的思路,做一个更接近真实噪声场景的去噪模型。
二、CBDNet解决什么问题?
CBDNet的核心思想可以概括为:
先估计噪声,再根据噪声分布进行去噪。
它不是假设整张图噪声强度一样,而是认为不同区域噪声可能不同。
这非常符合真实情况。
比如一张夜景照片:
- 暗部噪声很重
- 亮部噪声较轻
- 边缘区域可能有压缩伪影
普通模型无法区分这些区域,而 CBDNet 会先预测一张 noise map。
三、整体架构设计
我们实现一个简化版 CBDNet,分成两个子网络:
1. Noise Estimation Network
输入 noisy image,输出 noise map。
2. Denoising Network
输入 noisy image + noise map,输出 clean image。
整体流程:
noisy -> noise estimation -> noise map
noisy + noise map -> denoising network -> clean
四、工程目录结构
cbdnet_denoise/
├── data/
│ ├── train/
│ └── val/
├── models/
│ └── cbdnet.py
├── dataset.py
├── train.py
├── eval.py
└── utils.py
五、数据构建:模拟更真实的噪声
真实噪声比高斯噪声复杂。
这里我们用一个工程上常见的简化方式:
- 随机高斯噪声
- 随机JPEG压缩
- 随机噪声强度
- 局部噪声变化
dataset.py
import os
import random
import io
import torch
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
class RealisticNoiseDataset(Dataset):
def __init__(self, root_dir, patch_size=128):
self.paths = [
os.path.join(root_dir, name)
for name in os.listdir(root_dir)
if name.lower().endswith((".jpg", ".png", ".jpeg"))
]
self.patch_size = patch_size
self.to_tensor = transforms.ToTensor()
def jpeg_compress(self, img):
quality = random.randint(30, 95)
buffer = io.BytesIO()
img.save(buffer, format="JPEG", quality=quality)
buffer.seek(0)
return Image.open(buffer).convert("L")
def __len__(self):
return len(self.paths)
def __getitem__(self, idx):
img = Image.open(self.paths[idx]).convert("L")
w, h = img.size
if w >= self.patch_size and h >= self.patch_size:
x = random.randint(0, w - self.patch_size)
y = random.randint(0, h - self.patch_size)
img = img.crop((x, y, x + self.patch_size, y + self.patch_size))
else:
img = img.resize((self.patch_size, self.patch_size))
clean = self.to_tensor(img)
if random.random() < 0.5:
img = self.jpeg_compress(img)
base = self.to_tensor(img)
sigma = random.uniform(5, 50) / 255.0
noise = torch.randn_like(base) * sigma
noisy = torch.clamp(base + noise, 0.0, 1.0)
noise_map = torch.ones_like(clean) * sigma
return noisy, noise_map, clean
六、CBDNet模型实现
models/cbdnet.py
import torch
import torch.nn as nn
class NoiseEstimationNet(nn.Module):
def __init__(self):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(1, 32, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(32, 32, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(32, 1, 3, padding=1),
nn.Sigmoid()
)
def forward(self, x):
return self.net(x)
class DenoisingNet(nn.Module):
def __init__(self):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(2, 64, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(64, 1, 3, padding=1)
)
def forward(self, noisy, noise_map):
x = torch.cat([noisy, noise_map], dim=1)
residual = self.net(x)
return noisy - residual
class CBDNet(nn.Module):
def __init__(self):
super().__init__()
self.noise_estimator = NoiseEstimationNet()
self.denoiser = DenoisingNet()
def forward(self, noisy):
noise_map = self.noise_estimator(noisy)
clean = self.denoiser(noisy, noise_map)
return clean, noise_map
七、训练代码
import torch
from torch.utils.data import DataLoader
from dataset import RealisticNoiseDataset
from models.cbdnet import CBDNet
def train():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = RealisticNoiseDataset("data/train")
loader = DataLoader(dataset, batch_size=12, shuffle=True, num_workers=4)
model = CBDNet().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
image_loss = torch.nn.L1Loss()
noise_loss = torch.nn.L1Loss()
for epoch in range(1, 61):
model.train()
total_loss = 0
for noisy, gt_noise_map, clean in loader:
noisy = noisy.to(device)
gt_noise_map = gt_noise_map.to(device)
clean = clean.to(device)
pred_clean, pred_noise_map = model(noisy)
loss_img = image_loss(pred_clean, clean)
loss_noise = noise_loss(pred_noise_map, gt_noise_map)
loss = loss_img + 0.2 * loss_noise
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f"Epoch {epoch}, Loss: {total_loss / len(loader):.6f}")
if epoch % 10 == 0:
torch.save(model.state_dict(), f"cbdnet_epoch_{epoch}.pth")
if __name__ == "__main__":
train()
八、为什么要监督noise map?
很多人实现类似结构时,只训练最终输出,不监督 noise map。
这样会导致一个问题:
noise estimator 学不到明确含义,只变成一个中间黑盒特征。
我们这里加入:
loss_noise = L1(pred_noise_map, gt_noise_map)
目的不是让 noise map 完全精确,而是给它一个稳定训练方向。
九、推理代码
import torch
from PIL import Image
import torchvision.transforms as transforms
import torchvision.utils as vutils
from models.cbdnet import CBDNet
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CBDNet().to(device)
model.load_state_dict(torch.load("cbdnet_epoch_60.pth", map_location=device))
model.eval()
img = Image.open("real_noisy.png").convert("L")
to_tensor = transforms.ToTensor()
noisy = to_tensor(img).unsqueeze(0).to(device)
with torch.no_grad():
pred, noise_map = model(noisy)
pred = torch.clamp(pred, 0.0, 1.0)
vutils.save_image(pred.cpu(), "cbdnet_result.png")
vutils.save_image(noise_map.cpu(), "estimated_noise_map.png")
十、真实噪声任务中的重要经验
1. 不要只训练高斯噪声
高斯噪声只是最干净的实验设定,不能代表真实图片。
2. 压缩噪声必须加入
现实中的图片大部分都经历过压缩。
3. 不要过度追PSNR
真实噪声下,没有干净GT时,PSNR不一定可用。
肉眼效果、业务指标更重要。
比如 OCR 场景要看识别率,而不是只看图像指标。
十一、踩坑记录
坑1:noise map全变成常数
原因:
- noise loss权重太小
- 数据噪声变化不足
解决:
loss = loss_img + 0.2 * loss_noise
坑2:真实图片去噪后发糊
原因:
- 训练噪声太单一
- L1Loss仍然偏平滑
解决:
- 加JPEG噪声
- 加随机噪声强度
- 加边缘损失
坑3:噪声估计图没有意义
noise map不是一定要和真实噪声完全一致,它的价值在于给 denoiser 提供区域性噪声提示。
不要把它当成最终产品,而是中间引导。
十二、效果验证
在合成噪声测试中,CBDNet不一定显著超过UNet。
但在真实噪声场景中,它通常更稳。
| 场景 | UNet | CBDNet |
|---|---|---|
| 高斯噪声 | 表现好 | 表现好 |
| JPEG压缩 | 一般 | 更稳 |
| 低光噪声 | 容易残留噪点 | 更自然 |
| 真实截图 | 有伪影 | 更干净 |
十三、适合收藏总结
CBDNet完整流程
- 输入真实带噪图
- 先预测 noise map
- 拼接 noisy 和 noise map
- 再进行去噪
- 输出 clean image
避坑清单
- 不要只用高斯噪声
- 加入JPEG压缩增强
- noise map需要辅助监督
- 真实场景不要迷信PSNR
- 业务指标更重要
十四、优化建议
可以继续改进:
- Noise Estimator改成UNet结构
- Denoiser改成ResUNet
- 加感知损失
- 加OCR识别损失
- 使用真实噪声数据集微调
结尾总结
CBDNet真正解决的是一个非常工程化的问题:
模型在实验数据上很好,但真实图片不好用。
它的关键不是某个复杂模块,而是建模方式变了:
先估计噪声,再根据噪声去恢复图像。
这是图像去噪从“实验室模型”走向“真实工程”的重要一步。
下一篇预告
Pytorch图像去噪实战(七):Noise2Noise自监督去噪实战,没有干净图也能训练模型
转载自CSDN-专业IT技术社区
原文链接:https://blog.csdn.net/sundong_3/article/details/160630182



