关注

Pytorch图像去噪实战(六):CBDNet真实噪声去噪实战,解决合成噪声模型落地效果差的问题

Pytorch图像去噪实战(六):CBDNet真实噪声去噪实战,解决合成噪声模型落地效果差的问题


一、问题场景:合成噪声训练很好,真实图片效果却很差

前面我们训练的模型大多基于一个假设:

噪声是高斯噪声。

也就是训练时这样造数据:

noise = torch.randn_like(clean) * sigma / 255.0

这在论文实验里很常见,也方便复现。
但在真实工程里,我遇到一个非常现实的问题:

模型在合成噪声测试集上PSNR很高,但处理真实手机照片、截图、扫描件时效果明显变差。

真实噪声往往不是简单高斯噪声,它可能包含:

  • 传感器噪声
  • JPEG压缩噪声
  • 低光噪声
  • 颜色偏移
  • 局部噪声不均匀
  • 锐化产生的伪影

因此,只靠合成高斯噪声训练出来的模型,很容易出现泛化不足。

这一篇我们参考 CBDNet 的思路,做一个更接近真实噪声场景的去噪模型。


二、CBDNet解决什么问题?

CBDNet的核心思想可以概括为:

先估计噪声,再根据噪声分布进行去噪。

它不是假设整张图噪声强度一样,而是认为不同区域噪声可能不同。

这非常符合真实情况。

比如一张夜景照片:

  • 暗部噪声很重
  • 亮部噪声较轻
  • 边缘区域可能有压缩伪影

普通模型无法区分这些区域,而 CBDNet 会先预测一张 noise map。


三、整体架构设计

我们实现一个简化版 CBDNet,分成两个子网络:

1. Noise Estimation Network

输入 noisy image,输出 noise map。

2. Denoising Network

输入 noisy image + noise map,输出 clean image。

整体流程:

noisy -> noise estimation -> noise map
noisy + noise map -> denoising network -> clean

四、工程目录结构

cbdnet_denoise/
├── data/
│   ├── train/
│   └── val/
├── models/
│   └── cbdnet.py
├── dataset.py
├── train.py
├── eval.py
└── utils.py

五、数据构建:模拟更真实的噪声

真实噪声比高斯噪声复杂。
这里我们用一个工程上常见的简化方式:

  • 随机高斯噪声
  • 随机JPEG压缩
  • 随机噪声强度
  • 局部噪声变化

dataset.py

import os
import random
import io
import torch
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms


class RealisticNoiseDataset(Dataset):
    def __init__(self, root_dir, patch_size=128):
        self.paths = [
            os.path.join(root_dir, name)
            for name in os.listdir(root_dir)
            if name.lower().endswith((".jpg", ".png", ".jpeg"))
        ]

        self.patch_size = patch_size
        self.to_tensor = transforms.ToTensor()

    def jpeg_compress(self, img):
        quality = random.randint(30, 95)
        buffer = io.BytesIO()
        img.save(buffer, format="JPEG", quality=quality)
        buffer.seek(0)
        return Image.open(buffer).convert("L")

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("L")

        w, h = img.size
        if w >= self.patch_size and h >= self.patch_size:
            x = random.randint(0, w - self.patch_size)
            y = random.randint(0, h - self.patch_size)
            img = img.crop((x, y, x + self.patch_size, y + self.patch_size))
        else:
            img = img.resize((self.patch_size, self.patch_size))

        clean = self.to_tensor(img)

        if random.random() < 0.5:
            img = self.jpeg_compress(img)

        base = self.to_tensor(img)

        sigma = random.uniform(5, 50) / 255.0
        noise = torch.randn_like(base) * sigma

        noisy = torch.clamp(base + noise, 0.0, 1.0)

        noise_map = torch.ones_like(clean) * sigma

        return noisy, noise_map, clean

六、CBDNet模型实现

models/cbdnet.py

import torch
import torch.nn as nn


class NoiseEstimationNet(nn.Module):
    def __init__(self):
        super().__init__()

        self.net = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(32, 32, 3, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(32, 1, 3, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.net(x)


class DenoisingNet(nn.Module):
    def __init__(self):
        super().__init__()

        self.net = nn.Sequential(
            nn.Conv2d(2, 64, 3, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, 1, 3, padding=1)
        )

    def forward(self, noisy, noise_map):
        x = torch.cat([noisy, noise_map], dim=1)
        residual = self.net(x)
        return noisy - residual


class CBDNet(nn.Module):
    def __init__(self):
        super().__init__()

        self.noise_estimator = NoiseEstimationNet()
        self.denoiser = DenoisingNet()

    def forward(self, noisy):
        noise_map = self.noise_estimator(noisy)
        clean = self.denoiser(noisy, noise_map)
        return clean, noise_map

七、训练代码

import torch
from torch.utils.data import DataLoader
from dataset import RealisticNoiseDataset
from models.cbdnet import CBDNet


def train():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = RealisticNoiseDataset("data/train")
    loader = DataLoader(dataset, batch_size=12, shuffle=True, num_workers=4)

    model = CBDNet().to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    image_loss = torch.nn.L1Loss()
    noise_loss = torch.nn.L1Loss()

    for epoch in range(1, 61):
        model.train()
        total_loss = 0

        for noisy, gt_noise_map, clean in loader:
            noisy = noisy.to(device)
            gt_noise_map = gt_noise_map.to(device)
            clean = clean.to(device)

            pred_clean, pred_noise_map = model(noisy)

            loss_img = image_loss(pred_clean, clean)
            loss_noise = noise_loss(pred_noise_map, gt_noise_map)

            loss = loss_img + 0.2 * loss_noise

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch}, Loss: {total_loss / len(loader):.6f}")

        if epoch % 10 == 0:
            torch.save(model.state_dict(), f"cbdnet_epoch_{epoch}.pth")


if __name__ == "__main__":
    train()

八、为什么要监督noise map?

很多人实现类似结构时,只训练最终输出,不监督 noise map。

这样会导致一个问题:

noise estimator 学不到明确含义,只变成一个中间黑盒特征。

我们这里加入:

loss_noise = L1(pred_noise_map, gt_noise_map)

目的不是让 noise map 完全精确,而是给它一个稳定训练方向。


九、推理代码

import torch
from PIL import Image
import torchvision.transforms as transforms
import torchvision.utils as vutils
from models.cbdnet import CBDNet


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = CBDNet().to(device)
model.load_state_dict(torch.load("cbdnet_epoch_60.pth", map_location=device))
model.eval()

img = Image.open("real_noisy.png").convert("L")
to_tensor = transforms.ToTensor()

noisy = to_tensor(img).unsqueeze(0).to(device)

with torch.no_grad():
    pred, noise_map = model(noisy)
    pred = torch.clamp(pred, 0.0, 1.0)

vutils.save_image(pred.cpu(), "cbdnet_result.png")
vutils.save_image(noise_map.cpu(), "estimated_noise_map.png")

十、真实噪声任务中的重要经验

1. 不要只训练高斯噪声

高斯噪声只是最干净的实验设定,不能代表真实图片。

2. 压缩噪声必须加入

现实中的图片大部分都经历过压缩。

3. 不要过度追PSNR

真实噪声下,没有干净GT时,PSNR不一定可用。
肉眼效果、业务指标更重要。

比如 OCR 场景要看识别率,而不是只看图像指标。


十一、踩坑记录

坑1:noise map全变成常数

原因:

  • noise loss权重太小
  • 数据噪声变化不足

解决:

loss = loss_img + 0.2 * loss_noise

坑2:真实图片去噪后发糊

原因:

  • 训练噪声太单一
  • L1Loss仍然偏平滑

解决:

  • 加JPEG噪声
  • 加随机噪声强度
  • 加边缘损失

坑3:噪声估计图没有意义

noise map不是一定要和真实噪声完全一致,它的价值在于给 denoiser 提供区域性噪声提示。

不要把它当成最终产品,而是中间引导。


十二、效果验证

在合成噪声测试中,CBDNet不一定显著超过UNet。
但在真实噪声场景中,它通常更稳。

场景UNetCBDNet
高斯噪声表现好表现好
JPEG压缩一般更稳
低光噪声容易残留噪点更自然
真实截图有伪影更干净

十三、适合收藏总结

CBDNet完整流程

  1. 输入真实带噪图
  2. 先预测 noise map
  3. 拼接 noisy 和 noise map
  4. 再进行去噪
  5. 输出 clean image

避坑清单

  • 不要只用高斯噪声
  • 加入JPEG压缩增强
  • noise map需要辅助监督
  • 真实场景不要迷信PSNR
  • 业务指标更重要

十四、优化建议

可以继续改进:

  • Noise Estimator改成UNet结构
  • Denoiser改成ResUNet
  • 加感知损失
  • 加OCR识别损失
  • 使用真实噪声数据集微调

结尾总结

CBDNet真正解决的是一个非常工程化的问题:

模型在实验数据上很好,但真实图片不好用。

它的关键不是某个复杂模块,而是建模方式变了:

先估计噪声,再根据噪声去恢复图像。

这是图像去噪从“实验室模型”走向“真实工程”的重要一步。


下一篇预告

Pytorch图像去噪实战(七):Noise2Noise自监督去噪实战,没有干净图也能训练模型

转载自CSDN-专业IT技术社区

原文链接:https://blog.csdn.net/sundong_3/article/details/160630182

评论

赞0

评论列表

微信小程序
QQ小程序

关于作者

点赞数:0
关注数:0
粉丝:0
文章:0
关注标签:0
加入于:--