paper:https://arxiv.org/pdf/2105.05537
GitHub:https://github.com/HuCaoFighting/Swin-Unet
目录
一、摘要
近年来,卷积神经网络(CNN)在医学图像分析领域取得了重要进展。尤其是基于U形架构和跳跃连接的深度神经网络广泛应用于各种医学图像任务。然而,尽管CNN表现优异,但由于卷积操作的局部性,其难以很好地学习全局和长距离语义信息的交互。在本文中,我们提出了Swin-Unet,一种用于医学图像分割的类Unet纯Transformer模型。被切分为Token的图像块输入到基于Transformer的U形编码器-解码器架构中,通过跳跃连接实现局部和全局语义特征的学习。具体来说,我们采用了带有移动窗口的分层Swin Transformer作为编码器以提取上下文特征;同时设计了一个基于对称Swin Transformer的解码器,结合Patch扩展层实现上采样操作,以恢复特征图的空间分辨率。在输入和输出直接下采样和上采样4倍的情况下,针对多器官和心脏分割任务的实验表明,这种纯基于Transformer的U形编码器-解码器网络优于完全卷积方法或Transformer与卷积结合的方法。
二、模型
1、架构概述
Swin-Unet的整体架构如图所示。Swin-Unet由编码器、瓶颈层、解码器和跳跃连接组成。Swin-Unet的基本单元是Swin Transformer块。
编码器:
为了将输入转换为序列嵌入(将二维图像问题转换为序列问题),医学图像被分割为大小为的不重叠块。通过这种分割方式,每个Patch的特征维度变为
。
接着,一个线性嵌入层被应用于将特征维度投射到任意维度(表示为)。
转换后的Patch Token通过多个Swin Transformer块和Patch合并层,生成分层的特征表示。其中:
- Patch合并层负责下采样和特征维度扩展,实现更高效的多层次特征提取。
- Swin Transformer块负责特征表示的学习,捕获局部和全局上下文。
瓶颈层:
连接编码器与解码器,进一步处理从编码器提取的深层次特征。
解码器:
受U-Net的启发,设计了对称的基于Transformer的解码器。解码器由Swin Transformer块和Patch扩展层组成。解码器中提取的上下文特征通过跳跃连接与编码器的多尺度特征融合,以补充因下采样导致的空间信息丢失。
与Patch合并层相对,Patch扩展层专为上采样设计:
- 它将相邻维度的特征图重新排列为更大的特征图,实现分辨率的
上采样。
- 在最后一个Patch扩展层中,执行
上采样以将特征图分辨率恢复到输入图像的分辨率 (
)。
- 最后,对这些上采样后的特征应用一个线性投影层,以输出像素级的分割预测。
2、Swin Transformer块
与传统的多头自注意力(MSA)模块不同,Swin Transformer块基于移动窗口(shifted windows)构建。如图所示,展示了两个连续的Swin Transformer块。每个Swin Transformer块由以下组件组成:
(1)LayerNorm(LN)层:是一种归一化技术,通过对每一层的特征进行标准化(零均值和单位方差),加速模型训练,提高模型的稳定性。
(2)多头自注意力模块(W-MSA):自注意力机制通过计算输入特征中每对元素之间的相关性(注意力权重),实现特征的动态加权组合,捕获全局依赖关系。多头机制(Multi-Head Attention)是自注意力的扩展版本,通过并行多个注意力头,更好地捕获不同子空间中的特征。
(3)残差连接(Residual Connection):通过直接跳过非线性变换,将输入特征与输出特征相加,解决了深层网络中梯度消失的问题。
(4)多层感知机(MLP):是一种简单的前馈神经网络,用于特征转换和非线性映射。在Swin Transformer块中,MLP包含两层全连接层,中间插入GELU激活函数。
总结:
(1)LayerNorm 提供了稳定的特征归一化。
(2)多头自注意力模块 捕获全局上下文。
(3)残差连接 保持梯度稳定,避免信息丢失。
(4)带有GELU的两层MLP 增强了非线性表达能力,补充了注意力机制后的特征变换。
在两个连续的Transformer块中,分别应用了基于窗口的多头自注意力(W-MSA)模块和基于移动窗口的多头自注意力(SW-MSA)模块。基于这样的窗口划分机制,连续的Swin Transformer块可以被公式化为:
其中,和
分别表示第
个块中 (S)W-MSA 模块和 MLP 模块的输出。
自注意力机制的计算公式如下:
这里 分别表示查询矩阵、键矩阵和值矩阵。
表示窗口中的Patch数量,
是查询或键的维度。矩阵
的值取自偏置矩阵
。
3、编码器
在编码器中,分割为 维Token且分辨率为
的输入数据被输入到两个连续的Swin Transformer块中进行表示学习。在这一过程中,特征维度和分辨率保持不变。同时,Patch合并层会将Token数量减少一半(即2倍下采样),并将特征维度增加到原始维度的2倍。编码器中会重复三次这一过程。
Patch合并层:
- 输入的特征被划分为4个子部分,这些子部分通过连接操作整合为一个新的特征表示。
- 连接操作将特征的分辨率缩小了一半(2倍下采样),减少了空间上的计算复杂度。
- 原始连接后的特征维度增加了4倍。为了保持特征表示的一致性,后续通过一个线性层将维度调整为原始特征维度的2倍。
4、瓶颈层
由于Transformer网络过深时难以收敛 ,因此瓶颈层仅使用了两个连续的Swin Transformer块构建,用于学习深层特征表示。在瓶颈层中,特征维度和分辨率保持不变。
作为编码器和解码器之间的连接部分,瓶颈层提取和处理全局特征表示。通过减少网络深度(仅两个Swin Transformer块),避免过深Transformer模型带来的收敛问题,同时保持特征的高质量表达。
5、解码器
与编码器相对应,对称的解码器基于Swin Transformer块构建。为此,与编码器中使用的Patch合并层相对,解码器中使用Patch扩展层对提取的深层特征进行上采样。Patch扩展层将相邻维度的特征图重新排列为具有更高分辨率的特征图(2倍上采样),并将特征维度减少到原始维度的一半。
Patch扩展层:
在上采样之前,对输入特征()应用一个线性层,将特征维度扩展到原始维度的两倍(
)。
使用重排(rearrange)操作将输入特征的分辨率扩展为输入分辨率的两倍,同时将特征维度减少到原始维度的四分之一:
从 扩展到
。
解码器通过逐步上采样恢复特征分辨率,同时保持或减少特征维度,为最终的分割结果提供高分辨率上下文信息。
6、跳跃连接
与U-Net类似,跳跃连接用于将编码器的多尺度特征与解码器的上采样特征融合。通过将浅层特征与深层特征连接在一起,减少了下采样导致的空间信息丢失。随后,使用一个线性层对连接后的特征进行处理,使得其维度与上采样特征的维度保持一致。
作用:
- 编码器在多次下采样过程中可能丢失空间分辨率信息,通过跳跃连接补充高分辨率特征。
- 将编码器的浅层特征(细节信息)与解码器的深层特征(全局语义信息)进行融合,提高分割精度。
三、实验
1、数据集
Synapse多器官分割数据集 (Synapse):
该数据集包含30例病例,共计3779张轴向腹部临床CT图像。其中,18个样本用于训练集,12个样本用于测试集。使用平均Dice相似系数(DSC)和平均Hausdorff距离(HD)作为评估指标,在8个腹部器官(主动脉、胆囊、脾脏、左肾、右肾、肝脏、胰腺、胃)上评估我们的方法。
自动化心脏诊断挑战数据集 (ACDC):
ACDC数据集来自不同患者的MRI扫描仪采集。对于每位患者的MR图像,标注了左心室(LV)、右心室(RV)和心肌(MYO)。数据集分为70个训练样本、10个验证样本和20个测试样本。仅使用平均DSC作为评估指标来评估我们的方法在该数据集上的表现。
Dice相似系数(DSC):衡量分割结果与真实标签之间的重叠程度,范围为0到1,值越大表示分割性能越好。
其中, A 和B 分别是预测分割结果和真实标签。
Hausdorff距离(HD):衡量分割结果边界与真实边界之间的最大距离,用于评估边界细节的精确性。
2、实现细节
Swin-Unet 基于 Python 3.6 和 Pytorch 1.7.0 实现。在所有训练案例中,为了增加数据多样性,使用了数据增强技术,如翻转和旋转。输入图像大小和Patch大小分别设置为 224×224和 4。我们在具有32GB内存的Nvidia V100 GPU上训练模型。模型参数使用在 ImageNet 上预训练的权重进行初始化。
在训练过程中,批量大小设置为24,优化模型使用了具有动量(momentum)0.9和权重衰减(weight decay) 的流行SGD优化器来执行反向传播。
3、 Synapse数据集上的实验结果
表中展示了所提出的Swin-Unet与以往的最新方法在Synapse多器官CT数据集上的对比结果。与TransUnet 不同,我们添加了自己实现的U-Net 和 Att-UNet在Synapse数据集上的测试结果。实验结果表明,我们基于纯Transformer的类Unet方法取得了最优的分割性能,达到了79.13%的分割准确率(DSC)和21.55%的Hausdorff距离(HD)。与Att-Unet 和最新的TransUnet方法相比,尽管我们的方法在DSC评价指标上没有显著提升,但在HD评价指标上分别取得了约4%和10%的精度提升,这表明我们的方法在边界预测方面表现更优。
图中展示了不同方法在Synapse多器官CT数据集上的分割结果。从图中可以看出,基于CNN的方法容易出现过分割问题,这可能是由于卷积操作的局部性导致的。在本研究中,我们通过将Transformer集成到具有跳跃连接的U形架构中,证明了无卷积的纯Transformer方法可以更好地学习全局和长距离语义信息交互,从而实现更优的分割结果。
4、ACDC 数据集上的实验结果
与 Synapse 数据集类似,我们在 ACDC 数据集上训练了所提出的 Swin-Unet 以执行医学图像分割任务。实验结果总结在表中。使用 MRI 模态的图像数据作为输入,Swin-Unet 依然能够取得 90.00% 的优秀分割准确率,这表明我们的方法具有良好的泛化能力和鲁棒性。
5、消融实验
为了探讨不同因素对模型性能的影响,我们在 Synapse 数据集上进行了消融研究,具体讨论了以下因素:上采样方式、跳跃连接数量、输入大小以及模型规模。
(1)上采样方式的影响:
为对应编码器中的 Patch 合并层,我们在解码器中专门设计了一个 Patch 扩展层,用于上采样和特征维度扩展。
为了验证所提出的 Patch 扩展层的有效性,我们分别在 Synapse 数据集上使用双线性插值、反卷积和 Patch 扩展层进行实验。
表中的实验结果表明,结合 Patch 扩展层的 Swin-Unet 可以获得更高的分割准确率。
(2)跳跃连接数量的影响:
Swin-Unet 的跳跃连接添加在分辨率为 1/4、1/8和 1/16的位置。
通过分别将跳跃连接数量设置为 0、1、2 和 3,我们探讨了不同跳跃连接数量对模型分割性能的影响。
表中显示,模型性能随着跳跃连接数量的增加而提高。因此,为了使模型更加鲁棒,本研究中跳跃连接数量设置为 3。
(3) 输入大小的影响:
表展示了 Swin-Unet 使用 224×22和 384×384 输入分辨率时的测试结果。
随着输入大小从 224×224增加到 384×384,且 Patch 大小保持为 4,Transformer 的输入 Token 序列长度增大,从而提高了模型的分割性能。
然而,尽管模型的分割精度略有提升,但整个网络的计算负载显著增加。为了确保算法的运行效率,本研究中实验基于 224×224的输入分辨率。
(4) 模型规模的影响:
从表中可以看出,模型规模的增加几乎未能显著提高模型性能,但却增加了整个网络的计算成本。
综合考虑准确性和速度的平衡,我们选择基于 Tiny 的模型执行医学图像分割任务。
转载自CSDN-专业IT技术社区
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
原文链接:https://blog.csdn.net/weixin_56848903/article/details/144796853