关注

100个YOLOv12改进模块 | 即插即用

目录

🔥 一、注意力机制类 (1-25)

通道注意力 (1-8)

空间注意力 (9-16)

混合注意力 (17-25)

🏗️ 二、卷积模块改进 (26-45)

深度可分离卷积 (26-32)

特征提取增强 (33-40)

轻量化设计 (41-45)

🔄 三、特征融合模块 (46-65)

Neck结构改进 (46-55)

跨尺度融合 (56-65)

🎯 四、检测头改进 (66-80)

⚡ 五、损失函数与标签分配 (81-90)

🔧 六、训练策略与辅助模块 (91-100)

📋 使用说明

在YOLOv12中插入模块的示例:

快速替换指南:


🔥 一、注意力机制类 (1-25)

通道注意力 (1-8)

Python

# 1. ECA-Net (Efficient Channel Attention)
class ECA(nn.Module):
    """Efficient Channel Attention: a 1-D conv over the pooled channel descriptor,
    with no dimensionality reduction (ECA-Net, CVPR 2020)."""

    def __init__(self, channel, gamma=2, b=1):
        super().__init__()
        # Kernel size adapts to channel count (paper Eq. 12); forced odd so padding is symmetric.
        k = int(abs((math.log(channel, 2) + b) / gamma))
        if k % 2 == 0:
            k += 1
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        pooled = self.avg_pool(x)                               # (b, c, 1, 1)
        flat = pooled.squeeze(-1).transpose(-1, -2)             # (b, 1, c): channels as a 1-D sequence
        weights = self.conv(flat).transpose(-1, -2).unsqueeze(-1)
        return x * self.sigmoid(weights)

# 2. SE-Net (Squeeze-and-Excitation)
class SE(nn.Module):
    """Squeeze-and-Excitation channel attention: global average pool -> bottleneck MLP -> sigmoid gate."""

    def __init__(self, c1, reduction=16):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(c1, c1 // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(c1 // reduction, c1, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c = x.shape[:2]
        squeezed = self.avgpool(x).view(b, c)       # squeeze: global channel descriptor
        gate = self.fc(squeezed).view(b, c, 1, 1)   # excite: per-channel weight in (0, 1)
        return x * gate

# 3. CBAM (Convolutional Block Attention Module) - channel branch
class CBAM_Channel(nn.Module):
    """CBAM channel attention: shared bottleneck MLP applied to both avg- and max-pooled
    descriptors, summed and squashed into a per-channel gate."""

    def __init__(self, c1, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # 1x1 convs act as a shared MLP over the (b, c, 1, 1) descriptors
        self.fc = nn.Sequential(
            nn.Conv2d(c1, c1//reduction, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(c1//reduction, c1, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        gate = self.sigmoid(self.fc(self.avg_pool(x)) + self.fc(self.max_pool(x)))
        return x * gate

# 4. GCT (Gated Channel Transformation)
class GCT(nn.Module):
    """Gated Channel Transformation (CVPR 2020).

    Fixes vs. the previous version: `self.beta` was defined but never used, `alpha`
    was misused as a subtraction threshold, and `gamma` was divided by the per-channel
    embedding instead of normalizing across channels. This follows the paper:
    embedding = alpha * ||x||_2 (per channel), competition via l2 norm across channels,
    gate = 1 + tanh(gamma * normalized_embedding + beta).
    At init (gamma=0, beta=0) the block is an identity mapping.
    """

    def __init__(self, c1, epsilon=1e-5):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, c1, 1, 1))    # embedding scale
        self.gamma = nn.Parameter(torch.zeros(1, c1, 1, 1))   # gate scale (0 -> identity)
        self.beta = nn.Parameter(torch.zeros(1, c1, 1, 1))    # gate bias
        self.epsilon = epsilon

    def forward(self, x):
        # Per-channel l2 embedding, scaled by alpha.
        embedding = (x.pow(2).sum(dim=(2, 3), keepdim=True) + self.epsilon).pow(0.5) * self.alpha
        # Channel normalization: competition among channels via their l2 norm.
        norm = self.gamma / (embedding.pow(2).mean(dim=1, keepdim=True) + self.epsilon).pow(0.5)
        gate = 1. + torch.tanh(embedding * norm + self.beta)
        return x * gate

# 5. EMA Attention (Efficient Multi-scale Attention, simplified variant)
class EMA(nn.Module):
    """Grouped directional attention: per-group row/column context vectors gate the
    features, followed by GroupNorm and a 3x3 conv.

    NOTE(review): `self.agp` and `self.softmax` are registered but unused in this
    simplified variant; they are kept so checkpoints remain loadable.
    c1 must be divisible by `factor`.
    """

    def __init__(self, c1, factor=32):
        super().__init__()
        self.groups = factor
        self.softmax = nn.Softmax(-1)
        self.agp = nn.AdaptiveAvgPool2d((1, 1))
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))
        self.gn = nn.GroupNorm(c1 // self.groups, c1 // self.groups)
        self.conv1x1 = nn.Conv2d(c1 // self.groups, c1 // self.groups, kernel_size=1)
        self.conv3x3 = nn.Conv2d(c1 // self.groups, c1 // self.groups, kernel_size=3, padding=1)

    def forward(self, x):
        b, c, h, w = x.size()
        grouped = x.reshape(b * self.groups, -1, h, w)
        # Directional context: (h, 1) column summary and (1, w) row summary stacked along dim 2.
        ctx_h = self.pool_h(grouped)
        ctx_w = self.pool_w(grouped).permute(0, 1, 3, 2)
        mixed = self.conv1x1(torch.cat([ctx_h, ctx_w], dim=2))
        ctx_h, ctx_w = torch.split(mixed, [h, w], dim=2)
        gated = grouped * ctx_h.sigmoid() * ctx_w.permute(0, 1, 3, 2).sigmoid()
        out = self.conv3x3(self.gn(gated))
        return out.reshape(b, c, h, w)

# 6. SimAM (Simple, Parameter-Free Attention Module)
class SimAM(nn.Module):
    """SimAM: derives a per-position energy from the squared deviation from the
    channel-wise spatial mean; no learnable parameters."""

    def __init__(self, e_lambda=1e-4):
        super().__init__()
        self.act = nn.Sigmoid()
        self.e_lambda = e_lambda  # regularizer to keep the denominator positive

    def forward(self, x):
        _, _, h, w = x.size()
        denom_count = h * w - 1
        deviation = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
        variance_term = deviation.sum(dim=[2, 3], keepdim=True) / denom_count
        energy = deviation / (4 * (variance_term + self.e_lambda)) + 0.5
        return x * self.act(energy)

# 7. SRM (Style-based Recalibration Module, simplified)
class SRM(nn.Module):
    """Recalibrates channels from a pooled "style" descriptor via a 1x1 conv and sigmoid.

    NOTE(review): this is a simplified variant — the original SRM also uses the
    channel std as a style statistic; here only the mean is used.
    """

    def __init__(self, c1):
        super().__init__()
        self.c1 = c1
        self.style_pool = nn.AdaptiveAvgPool2d(1)
        self.style_conv = nn.Conv2d(c1, c1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        descriptor = self.style_conv(self.style_pool(x))  # (b, c, 1, 1)
        return x * self.sigmoid(descriptor)

# 8. FcaNet (Frequency Channel Attention, simplified)
class FcaNet(nn.Module):
    """Channel attention with fixed per-channel frequency weights feeding an SE-style MLP.

    NOTE(review): the "DCT weights" here are a simplified step function
    (1, 2, 3, 4 across channel quarters), not actual 2-D DCT basis responses as in
    the FcaNet paper — confirm against the upstream implementation if fidelity matters.
    Assumes channels >= 4 (channels // 4 must be non-zero).
    """

    def __init__(self, c1, reduction=16):
        super().__init__()
        self.register_buffer('pre_computed_dct_weights', self._get_dct_weights(c1))
        self.fc = nn.Sequential(
            nn.Linear(c1, c1 // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(c1 // reduction, c1, bias=False),
            nn.Sigmoid()
        )

    def _get_dct_weights(self, channels):
        # Frequency index increments once per quarter of the channels; weight = index + 1.
        c_part = channels // 4
        freq = torch.arange(channels) // c_part
        return (freq + 1).float().view(1, -1, 1, 1)

    def forward(self, x):
        b, c = x.shape[:2]
        pooled = (x * self.pre_computed_dct_weights).sum(dim=(2, 3))  # weighted spatial sum
        gate = self.fc(pooled).view(b, c, 1, 1)
        return x * gate

空间注意力 (9-16)

Python

# 9. Spatial Attention (SAM)
class SpatialAttention(nn.Module):
    """CBAM-style spatial attention: channel-mean and channel-max maps stacked into a
    2-channel tensor, convolved down to a single sigmoid gate per position."""

    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        mean_map = x.mean(dim=1, keepdim=True)
        max_map = x.max(dim=1, keepdim=True).values
        stacked = torch.cat([mean_map, max_map], dim=1)   # (b, 2, h, w)
        return x * self.sigmoid(self.conv(stacked))

# 10. Coordinate Attention (CA)
class CoordAtt(nn.Module):
    """Coordinate Attention: factorizes 2-D attention into an H-direction and a
    W-direction gate computed from pooled strip features (CVPR 2021)."""

    def __init__(self, inp, reduction=32):
        super().__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))   # pool over W -> (h, 1) strips
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))   # pool over H -> (1, w) strips
        mip = max(8, inp // reduction)
        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = nn.Hardswish()
        self.conv_h = nn.Conv2d(mip, inp, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, inp, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        residual = x
        _, _, h, w = x.size()
        strip_h = self.pool_h(x)                             # (n, c, h, 1)
        strip_w = self.pool_w(x).permute(0, 1, 3, 2)         # (n, c, w, 1)
        # Joint transform over the concatenated strips, then split back per direction.
        joint = self.act(self.bn1(self.conv1(torch.cat([strip_h, strip_w], dim=2))))
        strip_h, strip_w = torch.split(joint, [h, w], dim=2)
        attn_h = self.conv_h(strip_h).sigmoid()                              # (n, c, h, 1)
        attn_w = self.conv_w(strip_w.permute(0, 1, 3, 2)).sigmoid()          # (n, c, 1, w)
        return residual * attn_w * attn_h

# 11. Triplet Attention
class TripletAttention(nn.Module):
    """Triplet attention: three rotated branches (C-W, H-C, H-W interactions), each
    squeezed to a single map, convolved, sigmoid-gated, and multiplied onto x.

    Fixes vs. the previous version: the 1-input-channel convs (`cw`, `hc`, `hw`) were
    fed full multi-channel permuted tensors, which raised at runtime; and the raw
    (unbounded) conv output of the hw branch was multiplied instead of gating x.
    Each branch now squeezes its rotated view to one channel (mean over dim 1)
    before the conv, and all branches gate x via sigmoid.
    """

    def __init__(self, no_spatial=False):
        super().__init__()
        self.cw = nn.Conv2d(1, 1, kernel_size=(7, 3), padding=(3, 1), bias=False)
        self.hc = nn.Conv2d(1, 1, kernel_size=(3, 7), padding=(1, 3), bias=False)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.hw = nn.Conv2d(1, 1, 7, padding=3, bias=False)

    def forward(self, x):
        # C-W branch: rotate so H becomes the channel axis, squeeze, conv, rotate back.
        perm_cw = x.permute(0, 2, 1, 3)                                 # (b, h, c, w)
        attn_cw = self.cw(perm_cw.mean(dim=1, keepdim=True))            # (b, 1, c, w)
        attn_cw = attn_cw.sigmoid().permute(0, 2, 1, 3)                 # (b, c, 1, w)

        # H-C branch: rotate so W becomes the channel axis.
        perm_hc = x.permute(0, 3, 2, 1)                                 # (b, w, h, c)
        attn_hc = self.hc(perm_hc.mean(dim=1, keepdim=True))            # (b, 1, h, c)
        attn_hc = attn_hc.sigmoid().permute(0, 3, 2, 1)                 # (b, c, h, 1)

        out = x * attn_cw * attn_hc
        if not self.no_spatial:
            # Plain H-W spatial branch on the channel-squeezed map.
            attn_hw = self.hw(x.mean(dim=1, keepdim=True)).sigmoid()    # (b, 1, h, w)
            out = out * attn_hw
        return out

# 12. Shuffle Attention
class ShuffleAttention(nn.Module):
    """Shuffle Attention (SA-Net): each of G channel groups is split in half — one half
    gets channel attention, the other spatial (per-position affine + sigmoid) attention —
    then channels are shuffled so the two halves mix across groups.

    Fix vs. the previous version: the channel shuffle that gives the module its name was
    missing, so the attended halves stayed blockwise per group and never exchanged
    information. c1 must be divisible by 2*G.
    """

    def __init__(self, c1, G=8):
        super().__init__()
        self.G = G
        self.c1 = c1
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.cweight = nn.Parameter(torch.zeros(1, c1 // (2 * G), 1, 1))
        self.cbias = nn.Parameter(torch.ones(1, c1 // (2 * G), 1, 1))
        self.sweight = nn.Parameter(torch.zeros(1, c1 // (2 * G), 1, 1))
        self.sbias = nn.Parameter(torch.ones(1, c1 // (2 * G), 1, 1))
        self.sigmoid = nn.Sigmoid()

    @staticmethod
    def channel_shuffle(x, groups):
        # ShuffleNet-style interleave: (b, g, c/g, h, w) -> transpose -> flatten.
        b, c, h, w = x.shape
        return x.reshape(b, groups, c // groups, h, w).transpose(1, 2).reshape(b, c, h, w)

    def forward(self, x):
        b, c, h, w = x.size()
        x = x.view(b * self.G, -1, h, w)
        x_0, x_1 = x.chunk(2, dim=1)

        # Channel attention on the first half.
        xn = self.avg_pool(x_0)
        xn = self.cweight * xn + self.cbias
        xn = x_0 * self.sigmoid(xn)

        # Spatial attention on the second half.
        xs = x_1 * torch.sigmoid(self.sweight * x_1 + self.sbias)

        out = torch.cat([xn, xs], dim=1).contiguous().view(b, -1, h, w)
        # Mix the attended halves across groups (the previously missing shuffle).
        return self.channel_shuffle(out, 2)

# 13. RSA (Residual Spatial Attention)
class RSA(nn.Module):
    """Spatial attention via a bottleneck conv stack producing a single saliency map
    that gates the input.

    NOTE(review): despite the "Residual" name there is no additive skip here — the
    identity path is only the multiplicative gate; confirm the intended design.
    """

    def __init__(self, c1):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(c1, c1//2, 1),
            nn.BatchNorm2d(c1//2),
            nn.ReLU(),
            nn.Conv2d(c1//2, 1, 3, padding=1),
            nn.BatchNorm2d(1)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        saliency = self.conv(x)              # (b, 1, h, w)
        return x * self.sigmoid(saliency)

# 14. Polarized Self-Attention (PSA)
class PSA(nn.Module):
    """Polarized self-attention: a channel-only branch (b,c,1,1) and a spatial-only
    branch (b,1,h,w), fused additively into one sigmoid gate.

    Fix vs. the previous version: the spatial branch produced a (b, 1, c//2, 1)
    tensor that could not broadcast against x (b, c, h, w) — it crashed whenever
    c//2 != h. The spatial query is now globally pooled to (b, 1, c//2), softmaxed,
    and matmul'ed with the value map to give a true (b, 1, h, w) spatial map.
    """

    def __init__(self, c1):
        super().__init__()
        self.ch_wv = nn.Conv2d(c1, c1 // 2, kernel_size=(1, 1))
        self.ch_wq = nn.Conv2d(c1, 1, kernel_size=(1, 1))
        self.ch_wz = nn.Conv2d(c1 // 2, c1, kernel_size=(1, 1))
        self.ch_b = nn.Parameter(torch.zeros(1, c1, 1, 1))

        self.sp_wv = nn.Conv2d(c1, c1 // 2, kernel_size=(1, 1))
        self.sp_wq = nn.Conv2d(c1, c1 // 2, kernel_size=(1, 1))
        # NOTE(review): kept for checkpoint compatibility; unused after the fix.
        self.sp_wz = nn.Conv2d(c1 // 2, 1, kernel_size=(1, 1))
        self.sp_b = nn.Parameter(torch.zeros(1, 1, 1, 1))
        self.softmax = nn.Softmax(-1)

    def forward(self, x):
        b, c, h, w = x.size()
        # Channel-only branch: softmax spatial query aggregates values per channel.
        channel_wv = self.ch_wv(x).view(b, c // 2, -1)                       # (b, c/2, hw)
        channel_wq = self.softmax(self.ch_wq(x).view(b, 1, -1))              # (b, 1, hw)
        channel_wz = torch.matmul(channel_wv, channel_wq.permute(0, 2, 1)).unsqueeze(-1)
        channel_out = self.ch_wz(channel_wz) + self.ch_b                     # (b, c, 1, 1)

        # Spatial-only branch: softmax channel query aggregates values per position.
        spatial_wv = self.sp_wv(x).view(b, c // 2, -1)                       # (b, c/2, hw)
        spatial_wq = F.adaptive_avg_pool2d(self.sp_wq(x), 1).view(b, 1, c // 2)
        spatial_wq = self.softmax(spatial_wq)                                # (b, 1, c/2)
        spatial_out = torch.matmul(spatial_wq, spatial_wv).view(b, 1, h, w) + self.sp_b

        return x * torch.sigmoid(channel_out + spatial_out)

# 15. A2 (Area Attention) - native to YOLOv12; improvable variant
class A2_Improved(nn.Module):
    """Area attention: multi-head self-attention computed globally (area=1) or within
    `area` x `area` local windows, plus a depthwise 7x7 conv positional encoding.

    NOTE(review): when area > 1, H and W must be divisible by `area` or the reshapes
    below raise. The final reshape back to (B, H, W, C) also does not invert the
    window permutation, so window contents land in scrambled spatial positions —
    confirm against the upstream YOLOv12 A2 implementation before using area > 1.
    """

    def __init__(self, dim, num_heads=8, area=1):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads          # dim must be divisible by num_heads
        self.scale = self.head_dim ** -0.5        # standard 1/sqrt(d) attention scaling
        self.area = area
        
        self.qkv = nn.Conv2d(dim, dim * 3, 1)     # joint q/k/v projection
        self.proj = nn.Conv2d(dim, dim, 1)        # output projection
        self.pe = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)  # depthwise positional encoding

    def forward(self, x):
        B, C, H, W = x.shape
        # (B, 3C, H, W) -> (3, B, heads, H*W, head_dim)
        qkv = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim, H * W).permute(1, 0, 2, 4, 3)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        # Area attention: partition into local regions
        if self.area > 1:
            q = q.reshape(B, self.num_heads, H // self.area, self.area, W // self.area, self.area, self.head_dim)
            q = q.permute(0, 1, 2, 4, 3, 5, 6).reshape(B, self.num_heads, -1, self.area * self.area, self.head_dim)
            k = k.reshape(B, self.num_heads, H // self.area, self.area, W // self.area, self.area, self.head_dim)
            k = k.permute(0, 1, 2, 4, 3, 5, 6).reshape(B, self.num_heads, -1, self.area * self.area, self.head_dim)
            v = v.reshape(B, self.num_heads, H // self.area, self.area, W // self.area, self.area, self.head_dim)
            v = v.permute(0, 1, 2, 4, 3, 5, 6).reshape(B, self.num_heads, -1, self.area * self.area, self.head_dim)
        
        # Scaled dot-product attention within each window (or globally when area=1).
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, H, W, C).permute(0, 3, 1, 2)
        # Output projection plus depthwise positional-encoding branch.
        x = self.proj(x) + self.pe(x)
        return x

# 16. Criss-Cross Attention (CCA)
class CrissCrossAttention(nn.Module):
    """Criss-Cross Attention (CCNet): every position attends to its full row and column.

    Fixes vs. the previous version: `self.INF` was never defined and the output tensor
    `out` was never computed — the forward pass raised NameError. This implements the
    standard CCNet forward: row/column energies, joint softmax, and value aggregation,
    with -inf on the H-branch diagonal so self-similarity is not counted twice.
    gamma starts at 0, so the block is an identity mapping at init.
    """

    def __init__(self, c1):
        super().__init__()
        self.query_conv = nn.Conv2d(c1, c1 // 8, 1)
        self.key_conv = nn.Conv2d(c1, c1 // 8, 1)
        self.value_conv = nn.Conv2d(c1, c1, 1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=3)

    @staticmethod
    def _inf(b, h, w, device, dtype):
        # (b*w, h, h) with -inf on the diagonal: masks a pixel's attention to itself
        # in the column branch (it is already covered by the row branch).
        diag = torch.diag(torch.full((h,), float("inf"), device=device, dtype=dtype))
        return -diag.unsqueeze(0).repeat(b * w, 1, 1)

    def forward(self, x):
        b, _, h, w = x.size()
        q = self.query_conv(x)
        k = self.key_conv(x)
        v = self.value_conv(x)

        # Column (H) and row (W) views: batch-fold the orthogonal axis.
        q_h = q.permute(0, 3, 1, 2).contiguous().view(b * w, -1, h).permute(0, 2, 1)
        q_w = q.permute(0, 2, 1, 3).contiguous().view(b * h, -1, w).permute(0, 2, 1)
        k_h = k.permute(0, 3, 1, 2).contiguous().view(b * w, -1, h)
        k_w = k.permute(0, 2, 1, 3).contiguous().view(b * h, -1, w)
        v_h = v.permute(0, 3, 1, 2).contiguous().view(b * w, -1, h)
        v_w = v.permute(0, 2, 1, 3).contiguous().view(b * h, -1, w)

        energy_h = (torch.bmm(q_h, k_h) + self._inf(b, h, w, x.device, x.dtype)).view(b, w, h, h).permute(0, 2, 1, 3)
        energy_w = torch.bmm(q_w, k_w).view(b, h, w, w)
        # Joint softmax over the concatenated (column + row) candidates.
        concate = self.softmax(torch.cat([energy_h, energy_w], 3))

        att_h = concate[:, :, :, 0:h].permute(0, 2, 1, 3).contiguous().view(b * w, h, h)
        att_w = concate[:, :, :, h:h + w].contiguous().view(b * h, w, w)
        out_h = torch.bmm(v_h, att_h.permute(0, 2, 1)).view(b, w, -1, h).permute(0, 2, 3, 1)
        out_w = torch.bmm(v_w, att_w.permute(0, 2, 1)).view(b, h, -1, w).permute(0, 2, 1, 3)
        return x + self.gamma * (out_h + out_w)

混合注意力 (17-25)

Python

# 17. CBAM (channel attention followed by spatial attention)
class CBAM(nn.Module):
    """Full CBAM: sequential composition of the channel branch and the spatial branch."""

    def __init__(self, c1, reduction=16, kernel_size=7):
        super().__init__()
        self.channel = CBAM_Channel(c1, reduction)
        self.spatial = SpatialAttention(kernel_size)

    def forward(self, x):
        # Channel refinement first, then spatial refinement — the order used in the paper.
        refined = self.channel(x)
        return self.spatial(refined)

# 18. BAM (Bottleneck Attention Module)
class BAM(nn.Module):
    """BAM: a channel branch (MLP on the pooled descriptor) and a spatial branch
    (dilated conv stack) fused by broadcast addition into one sigmoid gate.

    Fix vs. the previous version: `y1.expand_as(y2)` raised at runtime — a
    non-singleton channel dim (c) cannot be expanded to y2's singleton channel dim.
    Broadcasting `y1 + y2` to (b, c, h, w) is the intended fusion.
    """

    def __init__(self, c1, reduction=16, dilation_val=4):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(c1, c1 // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(c1 // reduction, c1, bias=False)
        )
        self.conv1 = nn.Conv2d(c1, c1 // reduction, 1)
        # Dilated conv enlarges the receptive field of the spatial branch.
        self.conv2 = nn.Conv2d(c1 // reduction, c1 // reduction, 3, padding=dilation_val, dilation=dilation_val)
        self.conv3 = nn.Conv2d(c1 // reduction, 1, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.size()
        # Channel branch -> (b, c, 1, 1)
        ch_gate = self.fc(self.avg_pool(x).view(b, c)).view(b, c, 1, 1)
        # Spatial branch -> (b, 1, h, w)
        sp = F.relu(self.conv2(self.conv1(x)))
        sp_gate = self.conv3(sp)
        # Broadcast-add the two maps to (b, c, h, w), then gate.
        return x * self.sigmoid(ch_gate + sp_gate)

# 19. DA-Net (Dual Attention)
class DANet(nn.Module):
    """Dual Attention head: sums a position (spatial) attention branch and a channel
    attention branch, as in the DANet paper.

    NOTE(review): `PAM` and `CAM` are not defined anywhere in this file — this block
    raises NameError until they are provided/imported. Verify their availability in
    the full project before use.
    """
    def __init__(self, c1):
        super().__init__()
        self.sa = PAM(c1)  # position attention module (expected: self-attention over HxW)
        self.sc = CAM(c1)  # channel attention module (expected: self-attention over channels)

    def forward(self, x):
        # Two parallel branches fused by element-wise sum.
        sa_feat = self.sa(x)
        sc_feat = self.sc(x)
        return sa_feat + sc_feat

# 20. CC-Net (recurrent criss-cross attention)
class CCNet(nn.Module):
    """Two stacked criss-cross attention passes (recurrence R=2): after the second
    pass every position has aggregated context from the whole image."""

    def __init__(self, c1):
        super().__init__()
        self.cca = CrissCrossAttention(c1)
        self.cca2 = CrissCrossAttention(c1)

    def forward(self, x):
        return self.cca2(self.cca(x))

# 21. ANN (non-local attention block)
class ANN(nn.Module):
    """Non-local self-attention over all spatial positions with a learnable residual
    scale (gamma starts at 0, so the block is an identity at init)."""

    def __init__(self, c1):
        super().__init__()
        self.query_conv = nn.Conv2d(c1, c1 // 8, 1)
        self.key_conv = nn.Conv2d(c1, c1 // 8, 1)
        self.value_conv = nn.Conv2d(c1, c1, 1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        b, c, h, w = x.size()
        hw = h * w
        q = self.query_conv(x).view(b, -1, hw).permute(0, 2, 1)   # (b, hw, c/8)
        k = self.key_conv(x).view(b, -1, hw)                      # (b, c/8, hw)
        attn = F.softmax(torch.bmm(q, k), dim=-1)                 # (b, hw, hw)
        v = self.value_conv(x).view(b, -1, hw)                    # (b, c, hw)
        ctx = torch.bmm(v, attn.permute(0, 2, 1)).view(b, c, h, w)
        return self.gamma * ctx + x

# 22. GC-Net (Global Context Network)
class GCNet(nn.Module):
    """Global Context block: a softmax-pooled global context vector is transformed by a
    bottleneck and added to every position.

    Fix vs. the previous version: `torch.matmul(input_x, context_mask.transpose(1, 2))`
    broadcast (b, 1, c, hw) against (b, hw, 1) into a (b, b, c, 1) tensor whose channel
    dim never matched `channel_add_conv` — it crashed for every input. The mask is now
    unsqueezed to (b, 1, hw, 1) and the context reshaped to (b, c, 1, 1).
    """

    def __init__(self, c1, ratio=16):
        super().__init__()
        self.c1 = c1
        self.channel = 1
        self.conv_mask = nn.Conv2d(c1, 1, kernel_size=1)
        self.softmax = nn.Softmax(dim=2)
        self.channel_add_conv = nn.Sequential(
            nn.Conv2d(self.c1, self.c1 // ratio, kernel_size=1),
            nn.LayerNorm([self.c1 // ratio, 1, 1]),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.c1 // ratio, self.c1, kernel_size=1)
        )

    def spatial_pool(self, x):
        """Softmax-weighted global pooling -> (b, c, 1, 1) context vector."""
        batch, channel, height, width = x.size()
        feats = x.view(batch, channel, height * width).unsqueeze(1)     # (b, 1, c, hw)
        mask = self.conv_mask(x).view(batch, 1, height * width)
        mask = self.softmax(mask).unsqueeze(-1)                         # (b, 1, hw, 1)
        context = torch.matmul(feats, mask)                             # (b, 1, c, 1)
        return context.view(batch, channel, 1, 1)

    def forward(self, x):
        # Broadcast-add the transformed global context to every spatial position.
        return x + self.channel_add_conv(self.spatial_pool(x))

# 23. Double Attention
class DoubleAttention(nn.Module):
    """Self-attention in a reduced (c/8) embedding space with an optional 1x1
    reconstruction back to c channels.

    Fix vs. the previous version: `self.scale` (1/sqrt(d)) was computed but never
    applied, so the softmax operated on unnormalized dot products whose magnitude
    grew with the embedding dimension.
    """

    def __init__(self, c1, reconstruct=True):
        super().__init__()
        self.c1 = c1
        self.conv1 = nn.Conv2d(c1, c1 // 8, 1)   # query projection
        self.conv2 = nn.Conv2d(c1, c1 // 8, 1)   # key projection
        self.conv3 = nn.Conv2d(c1, c1 // 8, 1)   # value projection
        self.softmax = nn.Softmax(dim=-1)
        self.scale = (c1 // 8) ** -0.5
        self.reconstruct = reconstruct
        if reconstruct:
            self.conv_reconstruct = nn.Conv2d(c1 // 8, c1, 1)

    def forward(self, x):
        b, c, h, w = x.size()
        f = self.conv1(x).view(b, c // 8, h * w)
        g = self.conv2(x).view(b, c // 8, h * w)
        h_feat = self.conv3(x).view(b, c // 8, h * w)

        # Scaled dot-product attention over spatial positions.
        attention = torch.bmm(f.permute(0, 2, 1), g) * self.scale   # (b, hw, hw)
        attention = self.softmax(attention)

        out = torch.bmm(h_feat, attention.permute(0, 2, 1))
        out = out.view(b, c // 8, h, w)

        if self.reconstruct:
            out = self.conv_reconstruct(out)   # back to c channels
        return out

# 24. SK-Net (Selective Kernel)
class SKNet(nn.Module):
    """Selective Kernel unit: two branches with different receptive fields (3x3 and 5x5)
    fused by learned channel-wise soft attention (SKNet, CVPR 2019).

    Fixes vs. the previous version: the branches projected to d channels while the
    attention weights had c1 channels, so `a.expand_as(u1)` crashed whenever d != c1;
    and both branches were identical 1x1 convs, defeating the "selective kernel" idea.
    Branches now keep c1 channels with distinct kernel sizes, and the gating MLP maps
    c1 -> d -> 2*c1.
    """

    def __init__(self, c1, reduction=16):
        super().__init__()
        self.d = max(c1 // reduction, 32)   # bottleneck width of the gating MLP
        # Two branches with different receptive fields, both preserving c1 channels.
        self.conv1 = nn.Conv2d(c1, c1, 3, padding=1)
        self.conv2 = nn.Conv2d(c1, c1, 5, padding=2)
        self.fc = nn.Sequential(
            nn.Linear(c1, self.d),
            nn.ReLU(inplace=True),
            nn.Linear(self.d, c1 * 2)
        )
        self.softmax = nn.Softmax(dim=1)    # soft selection across the two branches

    def forward(self, x):
        batch_size = x.size(0)
        u1 = self.conv1(x)
        u2 = self.conv2(x)
        s = (u1 + u2).mean(dim=(2, 3))          # (b, c1) fused global descriptor
        z = self.fc(s).view(batch_size, 2, -1)  # (b, 2, c1) per-branch logits
        a_b = self.softmax(z)
        a = a_b[:, 0].view(batch_size, -1, 1, 1)
        b = a_b[:, 1].view(batch_size, -1, 1, 1)
        return a * u1 + b * u2

# 25. DyNet (Dynamic Convolution)
class DyNet(nn.Module):
    """Dynamic convolution: M parallel expert convs whose outputs are blended with
    input-dependent sigmoid gates from a routing head."""

    def __init__(self, c1, c2, k=3, stride=1, M=4):
        super().__init__()
        self.M = M
        self.conv = nn.ModuleList()
        for _ in range(M):
            self.conv.append(nn.Conv2d(c1, c2, k, stride, k//2))
        self.fc = nn.Sequential(
            nn.Linear(c1, M),
            nn.Sigmoid()
        )
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        ctx = self.avg_pool(x).squeeze(-1).squeeze(-1)   # (b, c1) global descriptor
        gates = self.fc(ctx)                             # (b, M) per-expert weights
        out = 0
        for idx in range(self.M):
            weight = gates[:, idx:idx+1].unsqueeze(-1).unsqueeze(-1)
            out = out + weight * self.conv[idx](x)
        return out

🏗️ 二、卷积模块改进 (26-45)

深度可分离卷积 (26-32)

Python

# 26. Ghost Convolution (GhostNet)
class GhostConv(nn.Module):
    """Ghost conv: a standard conv produces half the output channels ("intrinsic"
    features); a cheap depthwise 5x5 conv derives the other half ("ghost" features)."""

    def __init__(self, c1, c2, k=1, s=1, g=1, act=True):
        super().__init__()
        c_ = c2 // 2
        self.cv1 = Conv(c1, c_, k, s, None, g, act)
        self.cv2 = Conv(c_, c_, 5, 1, None, c_, act)  # depthwise (groups=c_): cheap ghost branch

    def forward(self, x):
        primary = self.cv1(x)
        ghost = self.cv2(primary)
        return torch.cat([primary, ghost], 1)

# 27. PConv (Partial Convolution, FasterNet)
class PConv(nn.Module):
    """Partial convolution: convolve only the first 1/4 of the channels, pass the
    remaining 3/4 through untouched, then mix everything with a 1x1 conv.

    Fix vs. the previous version: the untouched channels were written into a
    `torch.zeros_like` buffer, i.e. 3/4 of the input features were discarded instead
    of being passed through — partial conv is defined as identity on those channels.

    NOTE(review): with s > 1 the convolved and passthrough halves have different
    spatial sizes; as in the original code, this module only supports s == 1.
    """

    def __init__(self, c1, c2, k=3, s=1):
        super().__init__()
        self.c_ = c1 // 4   # number of channels actually convolved
        self.cv1 = nn.Conv2d(self.c_, self.c_, k, s, k//2, groups=self.c_)
        self.cv2 = nn.Conv2d(c1, c2, 1)   # pointwise mixing over all channels

    def forward(self, x):
        y = torch.cat([self.cv1(x[:, :self.c_]), x[:, self.c_:]], dim=1)
        return self.cv2(y)

# 28. DSConv (Depthwise Separable Convolution)
class DSConv(nn.Module):
    """Depthwise separable conv: per-channel spatial filtering (depthwise) followed by
    1x1 channel mixing (pointwise), with BN and optional SiLU."""

    def __init__(self, c1, c2, k=3, s=1, act=True):
        super().__init__()
        self.dw = nn.Conv2d(c1, c1, k, s, k//2, groups=c1, bias=False)   # depthwise
        self.pw = nn.Conv2d(c1, c2, 1, bias=False)                       # pointwise
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU() if act else nn.Identity()

    def forward(self, x):
        y = self.dw(x)   # spatial filtering, one filter per channel
        y = self.pw(y)   # channel mixing
        return self.act(self.bn(y))

# 29. MixConv (Mixed Depthwise Convolution)
class MixConv(nn.Module):
    """MixConv: parallel conv branches with different kernel sizes, concatenated.

    Fixes vs. the previous version:
    - branch widths are padded so they always sum to c2 (BatchNorm2d(c2) crashed
      whenever c2 was not divisible by len(k));
    - `groups=c1` required every branch width to be a multiple of c1, which fails for
      typical widths; `math.gcd(c1, width)` keeps the conv grouped when divisibility
      allows and degrades gracefully (groups=1) otherwise.
    """

    def __init__(self, c1, c2, k=(3, 5, 7), stride=1):
        super().__init__()
        self.groups = len(k)
        self.split_channel = c2 // self.groups
        widths = [self.split_channel] * self.groups
        widths[0] += c2 - self.split_channel * self.groups   # remainder to the first branch
        self.m = nn.ModuleList([
            nn.Conv2d(c1, wi, kernel_size=ki, stride=stride, padding=ki // 2,
                      groups=math.gcd(c1, wi))
            for wi, ki in zip(widths, k)
        ])
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU()

    def forward(self, x):
        return self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))

# 30. CondConv (Conditionally Parameterized Convolution)
class CondConv(nn.Module):
    """Conditional conv: expert convs blended per-sample with softmax routing weights
    computed from the globally pooled input."""

    def __init__(self, c1, c2, k=3, s=1, num_experts=3):
        super().__init__()
        self.num_experts = num_experts
        self.conv = nn.ModuleList([nn.Conv2d(c1, c2, k, s, k//2) for _ in range(num_experts)])
        self.routing = nn.Linear(c1, num_experts)

    def forward(self, x):
        b, c = x.shape[:2]
        pooled = F.adaptive_avg_pool2d(x, 1).view(b, c)
        weights = F.softmax(self.routing(pooled), dim=1)   # (b, num_experts)
        out = 0
        for idx, expert in enumerate(self.conv):
            out = out + weights[:, idx].view(b, 1, 1, 1) * expert(x)
        return out

# 31. ODConv (Omni-Dimensional Dynamic Convolution, simplified)
class ODConv(nn.Module):
    """Simplified ODConv: a plain conv whose output is gated by an SE-style channel
    attention computed from the input.

    NOTE(review): full ODConv modulates kernel, spatial, input- and output-channel
    dimensions; this variant only applies a channel gate. c1 should be >= 16 so the
    bottleneck (c1//16) is non-zero.
    """

    def __init__(self, c1, c2, k=3, s=1):
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, k//2)
        self.attn = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(c1, c1//16, 1),
            nn.ReLU(),
            nn.Conv2d(c1//16, c1, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        gate = self.attn(x)          # (b, c1, 1, 1) channel gate
        return self.conv(x) * gate

# 32. Involution
class Involution(nn.Module):
    """Involution: location-specific, channel-shared kernels. c1//2 kernel groups are
    predicted per position; each group's kernel is shared by 2 channels.

    Fix vs. the previous version: `F.unfold` yields c1*k*k values per position, but the
    view assumed (c1//2)*k*k and raised a shape error. The unfolded features are now
    grouped as (b, groups, c1/groups, k*k, h, w) and broadcast against the per-group
    kernels. c1 must be even.
    """

    def __init__(self, c1, k=7, stride=1):
        super().__init__()
        self.kernel_size = k
        self.stride = stride
        self.channel_in = c1 // 2      # number of kernel groups
        self.channel_out = c1
        self.span = k // 2
        self.o_proj = nn.Conv2d(c1, k * k * self.channel_in, 1)   # kernel-generating head
        self.reduce = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
        self.dilation = 1

    def forward(self, x):
        b, c, h, w = x.shape
        h_out, w_out = h // self.stride, w // self.stride
        g = self.channel_in

        # Per-position kernels: (b, groups, 1, k*k, h_out, w_out); the singleton dim
        # broadcasts the group kernel over the channels it serves.
        weight = self.o_proj(self.reduce(x)).view(b, g, 1, self.kernel_size ** 2, h_out, w_out)

        x_unfold = F.unfold(x, self.kernel_size, dilation=self.dilation,
                            padding=self.span, stride=self.stride)
        x_unfold = x_unfold.view(b, g, c // g, self.kernel_size ** 2, h_out, w_out)

        out = (weight * x_unfold).sum(dim=3).view(b, c, h_out, w_out)
        return out

特征提取增强 (33-40)

Python

# 33. Res2Net Module
class Res2Net(nn.Module):
    """Res2Net-style multi-scale residual block: channels are split into `scale` groups;
    each group (after the first) is summed with the previous group's output before its
    3x3 conv, building a hierarchy of receptive fields.

    NOTE(review): the split size is out.size(1)//(nums+1) = c1//scale while the convs
    are sized c2//scale, and the final residual sum needs matching channel counts —
    this block is only consistent when c1 == c2 (and c1 divisible by scale); the
    `downsample` path cannot fix the conv-width mismatch. Confirm intended usage.
    """
    def __init__(self, c1, c2, s=1, scale=4):
        super().__init__()
        width = c2 // scale               # per-group channel width
        self.nums = scale - 1             # number of convolved groups (last group passes through)
        convs = []
        bns = []
        for i in range(self.nums):
            convs.append(nn.Conv2d(width, width, 3, s, 1, bias=False))
            bns.append(nn.BatchNorm2d(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.act = nn.ReLU(inplace=True)
        # Projection shortcut when the shape changes; identity otherwise.
        self.downsample = nn.Conv2d(c1, c2, 1, s, bias=False) if c1 != c2 or s != 1 else None

    def forward(self, x):
        residual = x
        out = self.act(x)
        # Split activations into scale groups of equal width.
        spx = torch.split(out, out.size(1) // (self.nums + 1), dim=1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]          # hierarchical fusion with the previous group
            sp = self.convs[i](sp)
            sp = self.act(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)
        # The last group is concatenated unprocessed (identity branch).
        out = torch.cat((out, spx[self.nums]), 1)
        if self.downsample is not None:
            residual = self.downsample(residual)
        return out + residual

# 34. Dilated Convolution Block
class DilatedConv(nn.Module):
    """Parallel convs with different dilation rates, concatenated, BN'd, and activated.

    Fixes vs. the previous version: when c2 was not divisible by len(dilations) the
    concatenated output had fewer than c2 channels and BatchNorm2d(c2) crashed — the
    remainder is now assigned to the first branch. The mutable-list default was also
    replaced with a tuple (callers passing lists are unaffected).
    """

    def __init__(self, c1, c2, k=3, dilations=(1, 2, 4)):
        super().__init__()
        base = c2 // len(dilations)
        widths = [base] * len(dilations)
        widths[0] += c2 - base * len(dilations)   # remainder so widths sum to c2
        self.convs = nn.ModuleList([
            # padding=d keeps the spatial size constant for kernel 3 at dilation d
            nn.Conv2d(c1, wi, k, padding=d, dilation=d)
            for wi, d in zip(widths, dilations)
        ])
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU()

    def forward(self, x):
        return self.act(self.bn(torch.cat([conv(x) for conv in self.convs], dim=1)))

# 35. ASPP (Atrous Spatial Pyramid Pooling)
class ASPP(nn.Module):
    """ASPP: a 1x1 branch, three atrous 3x3 branches, and a global-pool branch, each
    producing c2//4 channels, concatenated and projected back to c2.

    Fix vs. the previous version: the five branches concatenate to 5*(c2//4) channels,
    but the projection conv expected c2 input channels and crashed on every forward
    pass. The projection now takes 5*(c2//4) channels. c2 must be >= 4.
    """

    def __init__(self, c1, c2, rates=[6, 12, 18]):
        super().__init__()
        self.conv1 = nn.Conv2d(c1, c2 // 4, 1)
        self.conv2 = nn.Conv2d(c1, c2 // 4, 3, padding=rates[0], dilation=rates[0])
        self.conv3 = nn.Conv2d(c1, c2 // 4, 3, padding=rates[1], dilation=rates[1])
        self.conv4 = nn.Conv2d(c1, c2 // 4, 3, padding=rates[2], dilation=rates[2])
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(c1, c2 // 4, 1)
        )
        self.project = nn.Sequential(
            nn.Conv2d(5 * (c2 // 4), c2, 1),   # 5 branches of c2//4 channels each
            nn.BatchNorm2d(c2),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x)
        x3 = self.conv3(x)
        x4 = self.conv4(x)
        # Global-context branch upsampled back to the input's spatial size.
        x5 = F.interpolate(self.global_avg_pool(x), size=x.shape[2:], mode='bilinear', align_corners=True)
        x = torch.cat([x1, x2, x3, x4, x5], dim=1)
        return self.project(x)

# 36. RFB (Receptive Field Block)
class RFB(nn.Module):
    """Receptive Field Block: three parallel branches with progressively larger
    effective receptive fields (1x1; 1x1->3x3; 1x1->3x3->dilated 3x3), concatenated,
    linearly fused, scaled by 0.1 and added to a projection shortcut.

    NOTE(review): this is a reduced RFB variant — the original also uses asymmetric
    (1xk / kx1) convs; confirm against the reference implementation if fidelity matters.
    c1 must be >= map_reduce so inter_planes is non-zero.
    """
    def __init__(self, c1, c2, stride=1, map_reduce=8):
        super().__init__()
        inter_planes = c1 // map_reduce   # bottleneck width shared by all branches
        # Branch 0: plain 1x1 projection (smallest receptive field).
        self.branch0 = nn.Sequential(
            nn.Conv2d(c1, inter_planes, 1, stride, 0),
            nn.BatchNorm2d(inter_planes),
            nn.ReLU(inplace=True)
        )
        # Branch 1: 1x1 then 3x3 (medium receptive field).
        self.branch1 = nn.Sequential(
            nn.Conv2d(c1, inter_planes, 1, 1, 0),
            nn.BatchNorm2d(inter_planes),
            nn.ReLU(inplace=True),
            nn.Conv2d(inter_planes, inter_planes, 3, stride, 1, dilation=1),
            nn.BatchNorm2d(inter_planes),
            nn.ReLU(inplace=True)
        )
        # Branch 2: 1x1 then 3x3 then dilated 3x3 (largest receptive field).
        self.branch2 = nn.Sequential(
            nn.Conv2d(c1, inter_planes, 1, 1, 0),
            nn.BatchNorm2d(inter_planes),
            nn.ReLU(inplace=True),
            nn.Conv2d(inter_planes, inter_planes, 3, 1, 1),
            nn.BatchNorm2d(inter_planes),
            nn.ReLU(inplace=True),
            nn.Conv2d(inter_planes, inter_planes, 3, stride, 3, dilation=3),
            nn.BatchNorm2d(inter_planes),
            nn.ReLU(inplace=True)
        )
        # Fuse the concatenated branches back to c2 channels.
        self.ConvLinear = nn.Sequential(
            nn.Conv2d(inter_planes * 3, c2, 1, 1, 0),
            nn.BatchNorm2d(c2)
        )
        # Projection shortcut matching the main path's channels/stride.
        self.shortcut = nn.Sequential(
            nn.Conv2d(c1, c2, 1, stride, 0),
            nn.BatchNorm2d(c2)
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        out = torch.cat((x0, x1, x2), 1)
        out = self.ConvLinear(out)
        short = self.shortcut(x)
        # 0.1 scaling keeps the multi-branch signal small relative to the shortcut.
        out = out * 0.1 + short
        return self.relu(out)

# 37. DCN (Deformable Convolution v2)
class DCN(nn.Module):
    """Deformable conv v2: offsets and modulation masks are predicted from the input
    and passed to torchvision's deform_conv2d.

    Fix vs. the previous version: `deform_conv2d` was called without stride/padding,
    so it defaulted to padding=0 and the offset's spatial size (computed with padding
    p) disagreed with the deformable conv's output grid — a runtime error for the
    default k=3, p=1. The stride and padding now match the offset/mask convs.
    """

    def __init__(self, c1, c2, k=3, s=1, p=1):
        super().__init__()
        self.stride = s
        self.padding = p
        self.conv_offset = nn.Conv2d(c1, 2 * k * k, k, s, p)   # (dx, dy) per kernel tap
        self.conv_mask = nn.Conv2d(c1, k * k, k, s, p)         # per-tap modulation
        self.conv = nn.Conv2d(c1, c2, k, s, p)                 # holds the deformable weights

    def forward(self, x):
        offset = self.conv_offset(x)
        mask = torch.sigmoid(self.conv_mask(x))
        return torchvision.ops.deform_conv2d(
            x, offset, self.conv.weight, self.conv.bias,
            stride=self.stride, padding=self.padding, mask=mask)

# 38. CARAFE (Content-Aware ReAssembly of FEatures)
class CARAFE(nn.Module):
    """Content-aware upsampling: a per-output-position k x k reassembly kernel is
    predicted from the input and applied to the unfolded (nearest-upsampled) features.

    Fix vs. the previous version: the predicted kernels were kept at low resolution
    with a separate scale*scale axis, which could not broadcast against the unfolded
    high-resolution features (6-D vs 5-D mismatch) — every forward pass crashed.
    Pixel-shuffling the kernel tensor to the upsampled grid (the CARAFE formulation)
    gives one k*k kernel per output position.
    """

    def __init__(self, c1, c2, scale_factor=2, up_kernel=5, encoder_kernel=3):
        super().__init__()
        self.scale_factor = scale_factor
        self.up_kernel = up_kernel
        # Predicts up_kernel^2 kernel weights for each of the scale^2 sub-positions.
        self.encoder = nn.Conv2d(c1, up_kernel * up_kernel * scale_factor * scale_factor, encoder_kernel, 1, encoder_kernel//2)
        self.unfold = nn.Unfold(kernel_size=up_kernel, padding=up_kernel//2)

    def forward(self, x):
        n, c, h, w = x.size()
        kernel = self.encoder(x)                                  # (n, k*k*s*s, h, w)
        kernel = F.pixel_shuffle(kernel, self.scale_factor)       # (n, k*k, s*h, s*w)
        kernel = F.softmax(kernel, dim=1)                         # normalized reassembly weights

        x_up = F.interpolate(x, scale_factor=self.scale_factor, mode='nearest')
        x_unfold = self.unfold(x_up).view(n, c, self.up_kernel ** 2,
                                          self.scale_factor * h, self.scale_factor * w)
        # Weighted sum over the k*k neighborhood for every output position.
        return (x_unfold * kernel.unsqueeze(1)).sum(dim=2)

# 39. DySample (Dynamic Sampling upsampler)
class DySample(nn.Module):
    """Dynamic upsampling: per-output-position 2-D offsets are predicted (via 1x1 conv
    + pixel shuffle) and used to bilinearly resample the input.

    Fixes vs. the previous version: the base sampling grid was built at the *input*
    resolution while the offsets live at the *output* resolution, so the addition
    raised a shape error; and grid_sample was called without align_corners (silently
    defaulting to False while the normalization assumed otherwise). The grid is now
    built at the output size and normalized with align_corners=True semantics.
    """

    def __init__(self, c1, scale=2, style='lp'):
        super().__init__()
        self.scale = scale
        self.style = style
        self.offset_conv = nn.Sequential(
            nn.Conv2d(c1, 2 * scale * scale, 1),
            nn.PixelShuffle(scale)   # -> (b, 2, scale*h, scale*w)
        )

    def forward(self, x):
        offset = self.offset_conv(x)
        return self.sample(x, offset)

    def sample(self, x, offset):
        b, _, out_h, out_w = offset.shape
        offset = offset.permute(0, 2, 3, 1)   # (b, out_h, out_w, 2)
        # Base grid at the output resolution; offsets perturb it in input-pixel units.
        gy, gx = torch.meshgrid(torch.arange(out_h, device=x.device),
                                torch.arange(out_w, device=x.device), indexing='ij')
        grid = torch.stack([gx, gy], dim=-1).float().unsqueeze(0) + offset
        # Normalize to [-1, 1] (align_corners=True maps 0..out-1 onto the full extent).
        grid[..., 0] = grid[..., 0] / max(out_w - 1, 1) * 2 - 1
        grid[..., 1] = grid[..., 1] / max(out_h - 1, 1) * 2 - 1
        return F.grid_sample(x, grid, mode='bilinear', padding_mode='zeros', align_corners=True)

# 40. SPD-Conv (Space-to-Depth Convolution)
class SPDConv(nn.Module):
    """Space-to-depth downsampling: folds each 2x2 spatial tile into the channel dim
    (lossless, unlike strided conv), then applies a conv on the 4x-wide tensor."""

    def __init__(self, c1, c2, k=3, s=1):
        super().__init__()
        self.conv = nn.Conv2d(c1 * 4, c2, k, s, k//2)

    def forward(self, x):
        # The four phase-shifted subsamples of each 2x2 tile, stacked along channels.
        tiles = [
            x[..., ::2, ::2],
            x[..., 1::2, ::2],
            x[..., ::2, 1::2],
            x[..., 1::2, 1::2],
        ]
        return self.conv(torch.cat(tiles, dim=1))

轻量化设计 (41-45)

Python

# 41. MobileNetV3 Block
class MobileNetV3Block(nn.Module):
    """MobileNetV3 inverted residual: expand 1x1 -> depthwise -> (SE) ->
    linear 1x1 projection, with a projecting or identity shortcut."""

    def __init__(self, c1, c2, exp_ratio=4, k=3, se=True):
        super().__init__()
        exp = c1 * exp_ratio
        self.conv = nn.Sequential(
            # pointwise expansion
            nn.Conv2d(c1, exp, 1, bias=False),
            nn.BatchNorm2d(exp),
            nn.Hardswish(),
            # depthwise
            nn.Conv2d(exp, exp, k, 1, k // 2, groups=exp, bias=False),
            nn.BatchNorm2d(exp),
            nn.Hardswish(),  # fix: was skipped entirely when se=False
            SE(exp) if se else nn.Identity(),
            # linear projection (no activation, per MobileNetV3)
            nn.Conv2d(exp, c2, 1, bias=False),
            nn.BatchNorm2d(c2)
        )
        self.shortcut = nn.Conv2d(c1, c2, 1) if c1 != c2 else nn.Identity()

    def forward(self, x):
        return self.conv(x) + self.shortcut(x)

# 42. ShuffleNetV2 Block
class ShuffleV2Block(nn.Module):
    """ShuffleNetV2 unit: channel split + transform + shuffle at stride 1, or
    two parallel downsampling branches at stride > 1."""

    def __init__(self, c1, c2, stride):
        super().__init__()
        self.stride = stride
        half = c2 // 2
        if stride > 1:
            # Spatial-reduction branch: depthwise stride conv then pointwise.
            self.branch1 = nn.Sequential(
                nn.Conv2d(c1, c1, 3, stride, 1, groups=c1, bias=False),
                nn.BatchNorm2d(c1),
                nn.Conv2d(c1, half, 1, bias=False),
                nn.BatchNorm2d(half),
                nn.ReLU(inplace=True),
            )
        else:
            self.branch1 = nn.Sequential()
        branch2_in = c1 if stride > 1 else half
        self.branch2 = nn.Sequential(
            nn.Conv2d(branch2_in, half, 1, bias=False),
            nn.BatchNorm2d(half),
            nn.ReLU(inplace=True),
            nn.Conv2d(half, half, 3, stride, 1, groups=half, bias=False),
            nn.BatchNorm2d(half),
            nn.Conv2d(half, half, 1, bias=False),
            nn.BatchNorm2d(half),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        if self.stride == 1:
            # Pass one half through untouched; transform the other.
            left, right = x.chunk(2, dim=1)
            merged = torch.cat((left, self.branch2(right)), dim=1)
        else:
            merged = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
        # Shuffle so the two halves mix in the next block.
        return channel_shuffle(merged, 2)

def channel_shuffle(x, groups):
    """Interleave channels across ``groups`` (ShuffleNet channel shuffle)."""
    b, c, h, w = x.size()
    return (x.view(b, groups, c // groups, h, w)
             .transpose(1, 2)
             .contiguous()
             .view(b, c, h, w))

# 43. EfficientNet Block
class EfficientNetBlock(nn.Module):
    """MBConv block (EfficientNet): expand -> depthwise -> SE -> project, with
    an identity skip only when stride is 1 and channels match."""

    def __init__(self, c1, c2, expand_ratio, k=3, stride=1, se_ratio=0.25):
        super().__init__()
        mid = round(c1 * expand_ratio)
        stages = [
            nn.Conv2d(c1, mid, 1, bias=False),
            nn.BatchNorm2d(mid),
            nn.SiLU(),
            nn.Conv2d(mid, mid, k, stride, k // 2, groups=mid, bias=False),
            nn.BatchNorm2d(mid),
            nn.SiLU(),
            SE(mid, int(1 / se_ratio)) if se_ratio else nn.Identity(),
            nn.Conv2d(mid, c2, 1, bias=False),
            nn.BatchNorm2d(c2),
        ]
        self.conv = nn.Sequential(*stages)
        self.shortcut = nn.Sequential() if stride == 1 and c1 == c2 else None

    def forward(self, x):
        out = self.conv(x)
        return x + out if self.shortcut is not None else out

# 44. RepVGG Block
class RepVGGBlock(nn.Module):
    """RepVGG block: training-time 3x3 + 1x1 + identity-BN branches that are
    meant to be collapsed into one conv for inference (``deploy=True``)."""

    def __init__(self, c1, c2, k=3, s=1, padding=1, deploy=False):
        super().__init__()
        self.deploy = deploy
        self.nonlinearity = nn.ReLU()
        if deploy:
            # Single fused conv used at inference time.
            self.rbr_reparam = nn.Conv2d(c1, c2, k, s, padding, bias=True)
            return
        # Identity branch exists only when the block preserves shape.
        self.rbr_identity = nn.BatchNorm2d(c1) if c2 == c1 and s == 1 else None
        self.rbr_dense = nn.Sequential(
            nn.Conv2d(c1, c2, k, s, padding, bias=False),
            nn.BatchNorm2d(c2))
        self.rbr_1x1 = nn.Sequential(
            nn.Conv2d(c1, c2, 1, s, 0, bias=False),
            nn.BatchNorm2d(c2))

    def forward(self, x):
        if self.deploy:
            return self.nonlinearity(self.rbr_reparam(x))
        identity = 0 if self.rbr_identity is None else self.rbr_identity(x)
        return self.nonlinearity(self.rbr_dense(x) + self.rbr_1x1(x) + identity)

# 45. FasterNet Block
class FasterNetBlock(nn.Module):
    """FasterNet block: a partial conv (PConv) then a 1x1 conv, with an
    optional identity skip when input and output widths match."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        super().__init__()
        hidden = int(c2 * e)
        self.cv1 = PConv(c1, hidden, 3, 1)
        self.cv2 = Conv(hidden, c2, 1)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        out = self.cv2(self.cv1(x))
        return x + out if self.add else out

🔄 三、特征融合模块 (46-65)

Neck结构改进 (46-55)

Python

# 46. BiFPN (Bidirectional FPN)
class BiFPN(nn.Module):
    """Simplified BiFPN-style cascade: each level is resized to the next one's
    resolution, concatenated with it, and fused by a 1x1 + 3x3 conv pair."""

    def __init__(self, channels_list, num_outs=5):
        super().__init__()
        self.num_outs = num_outs
        self.bifpn_convs = nn.ModuleList(
            nn.Sequential(
                Conv(channels_list[i] + channels_list[i + 1], channels_list[i + 1], 1),
                Conv(channels_list[i + 1], channels_list[i + 1], 3),
            )
            for i in range(num_outs - 1)
        )

    def forward(self, inputs):
        outs = [inputs[0]]
        for i, nxt in enumerate(inputs[1:]):
            up = F.interpolate(outs[-1], size=nxt.shape[2:], mode='nearest')
            outs.append(self.bifpn_convs[i](torch.cat([up, nxt], dim=1)))
        return outs

# 47. NAS-FPN
class NASFPN(nn.Module):
    """Simplified NAS-FPN: repeatedly merges the two deepest features into a
    new level and returns the last three.

    The two merged features must share spatial size and channel count ``c1``.
    fix: the forward pass no longer mutates the caller's feature list.
    """

    def __init__(self, c1, c2, num_layers=7):
        super().__init__()
        self.layers = nn.ModuleList(
            self._build_merge_module(c1, c2) for _ in range(num_layers))

    def _build_merge_module(self, c1, c2):
        # 1x1 reduce after concat, 3x3 refine, then BN + ReLU.
        return nn.Sequential(
            Conv(c1 * 2, c2, 1),
            Conv(c2, c2, 3),
            nn.BatchNorm2d(c2),
            nn.ReLU()
        )

    def forward(self, features):
        feats = list(features)  # work on a copy, not the caller's list
        for layer in self.layers:
            feats.append(layer(torch.cat([feats[-2], feats[-1]], dim=1)))
        return feats[-3:]

# 48. ASFF (Adaptively Spatial Feature Fusion)
class ASFF(nn.Module):
    """Adaptively Spatial Feature Fusion: rescales three pyramid levels to the
    target level's resolution/width and blends them with learned, per-pixel
    softmax weights."""

    def __init__(self, level, rfb=False, vis=False):
        super().__init__()
        self.level = level
        self.dim = [512, 256, 128]  # per-level input channel counts
        self.inter_dim = self.dim[self.level]
        # Resizing convs depend on which level is the fusion target.
        if level == 0:
            self.stride_level_1 = Conv(256, self.inter_dim, 3, 2)
            self.stride_level_2 = Conv(128, self.inter_dim, 3, 2)
        elif level == 1:
            self.compress_level_0 = Conv(512, self.inter_dim, 1, 1)
            self.stride_level_2 = Conv(128, self.inter_dim, 3, 2)
        elif level == 2:
            self.compress_level_0 = Conv(512, self.inter_dim, 1, 1)
            self.compress_level_1 = Conv(256, self.inter_dim, 1, 1)
        self.weight_level_0 = Conv(self.inter_dim, 2, 1, 1)
        self.weight_level_1 = Conv(self.inter_dim, 2, 1, 1)
        self.weight_level_2 = Conv(self.inter_dim, 2, 1, 1)
        self.weights_levels = Conv(6, 3, 1, 1)  # 3 levels x 2 ch -> 3 weights

    def forward(self, x_level_0, x_level_1, x_level_2):
        # Bring every level to the target level's stride and channel count.
        if self.level == 0:
            f0 = x_level_0
            f1 = self.stride_level_1(x_level_1)
            f2 = self.stride_level_2(F.max_pool2d(x_level_2, 3, stride=2, padding=1))
        elif self.level == 1:
            f0 = F.interpolate(self.compress_level_0(x_level_0), scale_factor=2, mode='nearest')
            f1 = x_level_1
            f2 = self.stride_level_2(x_level_2)
        elif self.level == 2:
            f0 = F.interpolate(self.compress_level_0(x_level_0), scale_factor=4, mode='nearest')
            f1 = F.interpolate(self.compress_level_1(x_level_1), scale_factor=2, mode='nearest')
            f2 = x_level_2
        # Per-pixel fusion weights, softmax-normalised over the three levels.
        w = torch.cat((self.weight_level_0(f0),
                       self.weight_level_1(f1),
                       self.weight_level_2(f2)), 1)
        w = F.softmax(self.weights_levels(w), dim=1)
        return f0 * w[:, 0:1, :, :] + f1 * w[:, 1:2, :, :] + f2 * w[:, 2:, :, :]

# 49. PAFPN (Path Aggregation FPN)
class PAFPN(nn.Module):
    """Path Aggregation FPN (PANet): a top-down fusion pass followed by an
    extra bottom-up pass, each output level refined by a 3x3 conv.

    fix: ``pafpn_convs`` previously held ``num_outs - 1`` convs while the
    forward pass indexed one per output level, raising IndexError; it now
    holds one conv per level.
    """

    def __init__(self, in_channels, out_channels, num_outs=3):
        super().__init__()
        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        self.downsample_convs = nn.ModuleList()
        self.pafpn_convs = nn.ModuleList()
        for i in range(num_outs):
            self.lateral_convs.append(Conv(in_channels[i], out_channels, 1))
            self.fpn_convs.append(Conv(out_channels, out_channels, 3))
            self.pafpn_convs.append(Conv(out_channels, out_channels, 3))  # one per level
        for _ in range(num_outs - 1):
            self.downsample_convs.append(Conv(out_channels, out_channels, 3, stride=2))

    def forward(self, inputs):
        laterals = [conv(x) for conv, x in zip(self.lateral_convs, inputs)]

        # Top-down pathway: add upsampled deeper features to shallower ones.
        for i in range(len(laterals) - 1, 0, -1):
            laterals[i - 1] = laterals[i - 1] + F.interpolate(
                laterals[i], size=laterals[i - 1].shape[2:], mode='nearest')

        # Bottom-up pathway: add downsampled shallower features to deeper ones.
        inter_outs = [self.fpn_convs[i](laterals[i]) for i in range(len(laterals))]
        for i in range(len(inter_outs) - 1):
            inter_outs[i + 1] = inter_outs[i + 1] + self.downsample_convs[i](inter_outs[i])

        return [self.pafpn_convs[i](inter_outs[i]) for i in range(len(inter_outs))]

# 50. HRFPN (High Resolution FPN)
class HRFPN(nn.Module):
    """High-Resolution FPN: per-level 1x1 projection, cascaded
    upsample-and-add with the previous output, then a 3x3 refinement."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1x1 = nn.ModuleList(Conv(c, out_channels, 1) for c in in_channels)
        self.conv3x3 = nn.ModuleList(Conv(out_channels, out_channels, 3) for _ in in_channels)

    def forward(self, x):
        outs = []
        for i, (proj, refine) in enumerate(zip(self.conv1x1, self.conv3x3)):
            feat = proj(x[i])
            if outs:  # fuse with the previously refined level
                feat = feat + F.interpolate(outs[-1], size=x[i].shape[2:], mode='nearest')
            outs.append(refine(feat))
        return outs

# 51. FPT (Feature Pyramid Transformer)
class FPT(nn.Module):
    """Feature Pyramid Transformer (simplified): self-attention on the deepest
    level plus cross-attention grounding of shallower levels onto it."""

    def __init__(self, channels):
        super().__init__()
        self.non_local = NonLocal2d(channels)
        self.scaled_dot_product = nn.MultiheadAttention(channels, 8)

    def forward(self, feats):
        deepest = feats[-1]
        # Self-Transformer on the top level.
        self_trans = self.non_local(deepest)
        # Grounding Transformer: queries from each shallow level, keys/values
        # from the deepest. nn.MultiheadAttention expects (L, N, E) here.
        kv = deepest.flatten(2).permute(2, 0, 1)
        ground_trans = []
        for feat in feats[:-1]:
            q = feat.flatten(2).permute(2, 0, 1)
            attn_out, _ = self.scaled_dot_product(q, kv, kv)
            ground_trans.append(attn_out.permute(1, 2, 0).view_as(feat))
        return ground_trans + [self_trans]

# 52. CFP (Centralized Feature Pyramid)
class CFP(nn.Module):
    """Centralized Feature Pyramid: per-level lateral convs, a global channel
    attention computed from the pooled sum of all levels, and top-down fusion.

    ``channels`` may be a list of per-level input widths or a single int
    shared by all levels (the original used the same value as both a list and
    an int, which crashed for list inputs). The common/output width is the
    first entry.
    """

    def __init__(self, channels, num_levels=3):
        super().__init__()
        in_chs = list(channels) if isinstance(channels, (list, tuple)) else [channels] * num_levels
        width = in_chs[0]
        self.lateral_convs = nn.ModuleList(Conv(c, width, 1) for c in in_chs)
        self.fpn_convs = nn.ModuleList(Conv(width, width, 3) for _ in range(num_levels))
        self.global_attn = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(width, width // 4, 1),
            nn.ReLU(),
            nn.Conv2d(width // 4, width, 1),
            nn.Sigmoid()
        )

    def forward(self, inputs):
        laterals = [conv(inp) for conv, inp in zip(self.lateral_convs, inputs)]

        # Centralized attention from the pooled sum of all levels.
        global_feat = sum(F.adaptive_avg_pool2d(l, 1) for l in laterals)
        global_attn = self.global_attn(global_feat)

        outs = []
        for i, lateral in enumerate(laterals):
            if i < len(laterals) - 1:  # fold in the next (deeper) level
                lateral = lateral + F.interpolate(
                    laterals[i + 1], size=lateral.shape[2:], mode='nearest')
            outs.append(self.fpn_convs[i](lateral * global_attn))
        return outs

# 53. AFPN (Asymptotic FPN)
class AFPN(nn.Module):
    """Asymptotic FPN (simplified): levels are fused progressively; each level
    receives the resized previous output before its 1x1 + 3x3 conv pair."""

    def __init__(self, in_channels, out_channels, num_outs=3):
        super().__init__()
        self.asymptotic_fusion = nn.ModuleList(
            nn.Sequential(
                Conv(in_channels[i], out_channels, 1),
                Conv(out_channels, out_channels, 3),
            )
            for i in range(num_outs)
        )

    def forward(self, inputs):
        outs = []
        for inp, fusion in zip(inputs, self.asymptotic_fusion):
            if outs:
                inp = inp + F.interpolate(outs[-1], size=inp.shape[2:], mode='nearest')
            outs.append(fusion(inp))
        return outs

# 54. GFPN (Generalized FPN)
class GFPN(nn.Module):
    """Generalized FPN (simplified): cascaded upsample-concat-fuse between
    adjacent pyramid levels.

    fix: the forward pass no longer mutates the caller's feature list.
    """

    def __init__(self, channels_list):
        super().__init__()
        self.nodes = nn.ModuleList(
            nn.Sequential(
                Conv(channels_list[i] + channels_list[i + 1], channels_list[i + 1], 1),
                Conv(channels_list[i + 1], channels_list[i + 1], 3),
            )
            for i in range(len(channels_list) - 1)
        )

    def forward(self, features):
        feats = list(features)  # work on a copy, not the caller's list
        for i, node in enumerate(self.nodes):
            up = F.interpolate(feats[i], size=feats[i + 1].shape[2:], mode='nearest')
            feats[i + 1] = node(torch.cat([up, feats[i + 1]], dim=1))
        return feats

# 55. C3k2_Improved (improved C3k2)
class C3k2_Improved(nn.Module):
    """CSP-style block: split the projected input in two, run ``n`` bottlenecks
    on one half while keeping every intermediate, then fuse all branches."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, k=3):
        super().__init__()
        self.c = int(c2 * e)
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv((2 + n) * self.c, c2, 1)
        self.m = nn.ModuleList(
            Bottleneck(self.c, self.c, shortcut, g, k=(k, k), e=1.0) for _ in range(n))

    def forward(self, x):
        branches = list(self.cv1(x).chunk(2, 1))
        for block in self.m:
            branches.append(block(branches[-1]))
        return self.cv2(torch.cat(branches, 1))

跨尺度融合 (56-65)

Python

# 56. FRM (Feature Refinement Module)
class FRM(nn.Module):
    """Feature Refinement Module: resizes a high-level map to a low-level
    map's resolution, projects both to ``c2``, sums and refines with 3x3."""

    def __init__(self, c1, c2):
        super().__init__()
        self.conv_high = Conv(c1, c2, 1)
        self.conv_low = Conv(c1, c2, 1)
        self.conv_out = Conv(c2, c2, 3)

    def forward(self, x_high, x_low):
        aligned = F.interpolate(x_high, size=x_low.shape[2:], mode='nearest')
        return self.conv_out(self.conv_high(aligned) + self.conv_low(x_low))

# 57. FFM (Feature Fusion Module)
class FFM(nn.Module):
    """Feature Fusion Module: resize every input to the first one's size,
    concatenate, reduce with 1x1 conv, then apply ECA channel attention."""

    def __init__(self, c1, c2, num_inputs=2):
        super().__init__()
        self.conv = Conv(c1 * num_inputs, c2, 1)
        self.attn = ECA(c2)

    def forward(self, *inputs):
        target = inputs[0].shape[2:]
        aligned = [
            inp if inp.shape[2:] == target
            else F.interpolate(inp, size=target, mode='nearest')
            for inp in inputs
        ]
        return self.attn(self.conv(torch.cat(aligned, dim=1)))

# 58. CCFM (Cross-Scale Feature Fusion)
class CCFM(nn.Module):
    """Cross-scale fusion: project every level to a common width, resize all
    to the middle level's resolution, concatenate and fuse with a 1x1 conv.

    ``channels`` may be a list of per-level input widths or a single int (the
    original used the same value as both a list and an int, which crashed for
    list inputs). The common/output width is the first entry.
    """

    def __init__(self, channels):
        super().__init__()
        in_chs = list(channels) if isinstance(channels, (list, tuple)) else [channels]
        width = in_chs[0]
        self.convs = nn.ModuleList(Conv(c, width, 1) for c in in_chs)
        self.fusion = Conv(width * len(in_chs), width, 1)

    def forward(self, features):
        target = features[len(features) // 2]  # middle scale as the reference
        resized = [
            F.interpolate(conv(f), size=target.shape[2:], mode='nearest')
            for conv, f in zip(self.convs, features)
        ]
        return self.fusion(torch.cat(resized, dim=1))

# 59. SFF (Selective Feature Fusion)
class SFF(nn.Module):
    """Selective Feature Fusion: a softmax-gated weighted sum of two features,
    with the second resized to match the first."""

    def __init__(self, c1, c2):
        super().__init__()
        self.conv1 = Conv(c1, c2, 1)
        self.conv2 = Conv(c1, c2, 1)
        self.selector = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(c2 * 2, 2, 1),
            nn.Softmax(dim=1)
        )

    def forward(self, x1, x2):
        a = self.conv1(x1)
        b = F.interpolate(self.conv2(x2), size=a.shape[2:], mode='nearest')
        w = self.selector(torch.cat([a, b], dim=1))  # two scalar gates per sample
        return a * w[:, 0:1] + b * w[:, 1:2]

# 60. MSFM (Multi-Scale Fusion Module)
class MSFM(nn.Module):
    """Multi-Scale Fusion Module: parallel pooled branches at several scales,
    resized back to the input resolution and fused with a 1x1 conv.

    fix: branch widths now sum exactly to ``c2`` (the remainder goes to the
    last branch), so the fusion conv no longer crashes when ``c2`` is not
    divisible by the number of scales. Also avoids a mutable default arg.
    """

    def __init__(self, c1, c2, scales=(1, 2, 4)):
        super().__init__()
        base = c2 // len(scales)
        widths = [base] * (len(scales) - 1) + [c2 - base * (len(scales) - 1)]
        self.branches = nn.ModuleList(
            nn.Sequential(
                nn.AvgPool2d(s, stride=s) if s > 1 else nn.Identity(),
                Conv(c1, w, 1),
                nn.Upsample(scale_factor=s, mode='nearest') if s > 1 else nn.Identity(),
            )
            for s, w in zip(scales, widths)
        )
        self.fusion = Conv(c2, c2, 1)

    def forward(self, x):
        outs = [branch(x) for branch in self.branches]
        # Guard against rounding when H/W are not divisible by a scale.
        outs = [F.interpolate(o, size=x.shape[2:], mode='nearest') for o in outs]
        return self.fusion(torch.cat(outs, dim=1))

# 61. A2C2f_Improved (improved A2C2f)
class A2C2f_Improved(nn.Module):
    """C2f-style block using ``A2`` modules as the inner blocks: split, run
    each A2 on the latest branch, then fuse everything with a 1x1 conv."""

    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
        super().__init__()
        self.c = int(c2 * e)
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv((2 + n) * self.c, c2, 1)
        self.m = nn.ModuleList(A2(self.c, self.c) for _ in range(n))

    def forward(self, x):
        parts = list(self.cv1(x).chunk(2, 1))
        for blk in self.m:
            parts.append(blk(parts[-1]))
        return self.cv2(torch.cat(parts, 1))

# 62. SimpleFusion (minimal fusion)
class SimpleFusion(nn.Module):
    """Minimal fusion: 2x nearest-upsample the first input, concatenate with
    the second, fuse with one 3x3 conv."""

    def __init__(self, c1, c2):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv = Conv(c1 * 2, c2, 3)

    def forward(self, x1, x2):
        return self.conv(torch.cat([self.up(x1), x2], dim=1))

# 63. DenseFusion (densely connected fusion)
class DenseFusion(nn.Module):
    """Dense fusion: each level is concatenated with ALL previously fused
    levels (resized to match) before its 1x1 reduction conv."""

    def __init__(self, channels_list):
        super().__init__()
        # convs[i] reduces the concat of levels 0..i back to channels_list[i].
        self.convs = nn.ModuleList(
            Conv(sum(channels_list[:i + 1]), channels_list[i], 1)
            for i in range(len(channels_list))
        )

    def forward(self, features):
        outs = [features[0]]  # first level passes through unchanged
        for i in range(1, len(features)):
            target = features[i].shape[2:]
            resized = [F.interpolate(o, size=target, mode='nearest') for o in outs]
            outs.append(self.convs[i](torch.cat(resized + [features[i]], dim=1)))
        return outs

# 64. AttentionFusion (attention fusion)
class AttentionFusion(nn.Module):
    """Gated fusion of two features: a learned sigmoid mask blends them as
    ``mask * x1 + (1 - mask) * x2``."""

    def __init__(self, c1, c2):
        super().__init__()
        self.conv1 = Conv(c1, c2, 1)
        self.conv2 = Conv(c1, c2, 1)
        self.attn = nn.Sequential(
            nn.Conv2d(c2 * 2, c2, 1),
            nn.BatchNorm2d(c2),
            nn.Sigmoid()
        )

    def forward(self, x1, x2):
        a = self.conv1(x1)
        b = F.interpolate(self.conv2(x2), size=a.shape[2:], mode='nearest')
        gate = self.attn(torch.cat([a, b], dim=1))
        return a * gate + b * (1 - gate)

Python

# 65. ScaleAdaptiveFusion (scale-adaptive fusion) - cont.
class ScaleAdaptiveFusion(nn.Module):
    """Fuses multi-stride views of one feature map, weighting each scale by a
    learned softmax-normalised parameter."""

    def __init__(self, c1, c2, num_scales=3):
        super().__init__()
        self.scale_convs = nn.ModuleList(
            Conv(c1, c2, 3, stride=2 ** i) for i in range(num_scales))
        self.fusion = Conv(c2 * num_scales, c2, 1)
        self.scale_weights = nn.Parameter(torch.ones(num_scales))

    def forward(self, x):
        size = x.shape[2:]
        feats = [F.interpolate(conv(x), size=size, mode='nearest')
                 for conv in self.scale_convs]
        w = F.softmax(self.scale_weights, dim=0)
        weighted = [f * wi.view(1, 1, 1, 1) for f, wi in zip(feats, w)]
        return self.fusion(torch.cat(weighted, dim=1))

🎯 四、检测头改进 (66-80)

Python

# 66. Decoupled Head
class DecoupledHead(nn.Module):
    """Decoupled detection head: separate conv towers for classification and
    regression; output is [box(4), obj(1), cls(nc)] per anchor, concatenated."""

    def __init__(self, c1, num_classes=80, anchors=1):
        super().__init__()
        self.stems = nn.Conv2d(c1, c1, 1)
        self.cls_convs = nn.Sequential(Conv(c1, c1, 3), Conv(c1, c1, 3))
        self.reg_convs = nn.Sequential(Conv(c1, c1, 3), Conv(c1, c1, 3))
        self.cls_preds = nn.Conv2d(c1, num_classes * anchors, 1)
        self.reg_preds = nn.Conv2d(c1, 4 * anchors, 1)
        self.obj_preds = nn.Conv2d(c1, anchors, 1)

    def forward(self, x):
        stem = self.stems(x)
        cls_feat = self.cls_convs(stem)
        reg_feat = self.reg_convs(stem)
        outputs = [self.reg_preds(reg_feat),
                   self.obj_preds(reg_feat),
                   self.cls_preds(cls_feat)]
        return torch.cat(outputs, 1)

# 67. Anchor-Free Head
class AnchorFreeHead(nn.Module):
    """Anchor-free head: class logits plus exp-activated (strictly positive)
    box regressions."""

    def __init__(self, c1, num_classes=80):
        super().__init__()
        self.cls_conv = nn.Sequential(
            Conv(c1, c1, 3),
            Conv(c1, c1, 3),
            nn.Conv2d(c1, num_classes, 1)
        )
        self.reg_conv = nn.Sequential(
            Conv(c1, c1, 3),
            Conv(c1, c1, 3),
            nn.Conv2d(c1, 4, 1)
        )

    def forward(self, x):
        # exp keeps predicted box values positive
        return self.cls_conv(x), self.reg_conv(x).exp()

# 68. DyHead (Dynamic Head)
class DyHead(nn.Module):
    """Dynamic Head (simplified): scale-, spatial- and task-aware gates
    applied in sequence before the class/box predictors."""

    def __init__(self, c1, num_classes=80, num_anchors=3):
        super().__init__()
        self.scale_attn = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(c1, 1, 1),
            nn.Sigmoid()
        )
        self.spatial_attn = nn.Sequential(
            nn.Conv2d(c1, 1, 7, padding=3),
            nn.Sigmoid()
        )
        self.task_attn = nn.Sequential(
            nn.Conv2d(c1, c1, 1),
            nn.Sigmoid()
        )
        self.cls_pred = nn.Conv2d(c1, num_classes * num_anchors, 1)
        self.reg_pred = nn.Conv2d(c1, 4 * num_anchors, 1)

    def forward(self, x):
        # Modulate by the global scale gate and the per-pixel spatial gate.
        feat = x * self.scale_attn(x) * self.spatial_attn(x)
        # The task gate splits the feature between the two predictors.
        gate = self.task_attn(feat)
        reg_out = self.reg_pred(feat * (1 - gate))
        cls_out = self.cls_pred(feat * gate)
        return torch.cat([reg_out, cls_out], 1)

# 69. TOOD Head (Task-aligned One-stage Object Detection)
class TOODHead(nn.Module):
    """Task-aligned head (simplified TOOD): a shared conv tower followed by a
    channel gate that decomposes features into cls- and reg-oriented parts."""

    def __init__(self, c1, num_classes=80, num_anchors=1):
        super().__init__()
        self.num_classes = num_classes
        self.inter_convs = nn.ModuleList(Conv(c1, c1, 3) for _ in range(4))
        self.task_decomposition = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(c1, c1, 1),
            nn.ReLU(),
            nn.Conv2d(c1, c1, 1),
            nn.Sigmoid()
        )
        self.cls_prob_module = nn.Sequential(
            nn.Conv2d(c1, c1, 3, padding=1),
            nn.BatchNorm2d(c1),
            nn.ReLU(),
            nn.Conv2d(c1, num_classes * num_anchors, 1)
        )
        self.reg_offset_module = nn.Sequential(
            nn.Conv2d(c1, c1, 3, padding=1),
            nn.BatchNorm2d(c1),
            nn.ReLU(),
            nn.Conv2d(c1, 4 * num_anchors, 1)
        )

    def forward(self, x):
        feat = x
        for conv in self.inter_convs:
            feat = conv(feat)
        gate = self.task_decomposition(feat)
        cls_feat = feat + gate * feat          # emphasise cls-aligned channels
        reg_feat = feat + (1 - gate) * feat    # emphasise reg-aligned channels
        return torch.cat(
            [self.reg_offset_module(reg_feat), self.cls_prob_module(cls_feat)], 1)

# 70. GFL Head (Generalized Focal Loss)
class GFLHead(nn.Module):
    """Generalized Focal Loss head: class scores plus distribution-style box
    logits (17 channels per box side)."""

    def __init__(self, c1, num_classes=80):
        super().__init__()
        self.cls_convs = nn.Sequential(Conv(c1, c1, 3), Conv(c1, c1, 3))
        self.reg_convs = nn.Sequential(Conv(c1, c1, 3), Conv(c1, c1, 3))
        self.gfl_cls = nn.Conv2d(c1, num_classes, 1)
        self.gfl_reg = nn.Conv2d(c1, 4 * (16 + 1), 1)  # 16 DFL bins + 1 per side

    def forward(self, x):
        cls_score = self.gfl_cls(self.cls_convs(x))
        bbox_pred = self.gfl_reg(self.reg_convs(x))
        return cls_score, bbox_pred

# 71. PPYOLOE Head
class PPYOLOEHead(nn.Module):
    """PP-YOLOE-style head: separate cls/reg stems and predictors plus an IoU
    quality branch; returns [box, iou, cls] concatenated."""

    def __init__(self, c1, num_classes=80, num_anchors=3):
        super().__init__()
        self.stem_cls = Conv(c1, c1, 1)
        self.stem_reg = Conv(c1, c1, 1)
        self.pred_cls = nn.Sequential(
            Conv(c1, c1, 3),
            nn.Conv2d(c1, num_classes * num_anchors, 1)
        )
        self.pred_reg = nn.Sequential(
            Conv(c1, c1, 3),
            nn.Conv2d(c1, 4 * num_anchors, 1)
        )
        self.pred_iou = nn.Sequential(
            Conv(c1, c1, 3),
            nn.Conv2d(c1, num_anchors, 1)
        )

    def forward(self, x):
        cls_feat = self.stem_cls(x)
        reg_feat = self.stem_reg(x)
        parts = [self.pred_reg(reg_feat),
                 self.pred_iou(reg_feat),
                 self.pred_cls(cls_feat)]
        return torch.cat(parts, 1)

# 72. RT-DETR Head
class RTDETRHead(nn.Module):
    """Simplified RT-DETR head: learned object queries decoded against the
    flattened feature map; returns class logits and sigmoid-normalised boxes."""

    def __init__(self, c1, num_classes=80, hidden_dim=256, num_queries=300):
        super().__init__()
        self.input_proj = nn.Conv2d(c1, hidden_dim, 1)
        self.transformer = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(hidden_dim, 8, batch_first=True),
            num_layers=6
        )
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        self.class_embed = nn.Linear(hidden_dim, num_classes)
        self.bbox_embed = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 4)
        )

    def forward(self, x):
        bs = x.shape[0]
        # Flatten spatial dims into a (B, HW, D) memory for the decoder.
        memory = self.input_proj(x).flatten(2).permute(0, 2, 1)
        queries = self.query_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
        decoded = self.transformer(queries, memory)
        return self.class_embed(decoded), self.bbox_embed(decoded).sigmoid()

# 73. YOLOX Head
class YOLOXHead(nn.Module):
    """YOLOX decoupled head: shared stem, then separate cls/reg towers;
    returns [box(4), obj(1), cls(nc)] concatenated."""

    def __init__(self, c1, num_classes=80):
        super().__init__()
        self.stem = Conv(c1, c1, 1)
        self.cls_convs = nn.Sequential(Conv(c1, c1, 3), Conv(c1, c1, 3))
        self.reg_convs = nn.Sequential(Conv(c1, c1, 3), Conv(c1, c1, 3))
        self.cls_pred = nn.Conv2d(c1, num_classes, 1)
        self.reg_pred = nn.Conv2d(c1, 4, 1)
        self.obj_pred = nn.Conv2d(c1, 1, 1)

    def forward(self, x):
        stem = self.stem(x)
        cls_feat = self.cls_convs(stem)
        reg_feat = self.reg_convs(stem)
        parts = [self.reg_pred(reg_feat),
                 self.obj_pred(reg_feat),
                 self.cls_pred(cls_feat)]
        return torch.cat(parts, 1)

# 74. NanoDet Head
class NanoDetHead(nn.Module):
    """NanoDet-style GFL head with a configurable number of distribution bins
    (``reg_max`` + 1 channels per box side)."""

    def __init__(self, c1, num_classes=80, reg_max=7):
        super().__init__()
        self.cls_convs = nn.Sequential(Conv(c1, c1, 3), Conv(c1, c1, 3))
        self.reg_convs = nn.Sequential(Conv(c1, c1, 3), Conv(c1, c1, 3))
        self.gfl_cls = nn.Conv2d(c1, num_classes, 1)
        self.gfl_reg = nn.Conv2d(c1, 4 * (reg_max + 1), 1)

    def forward(self, x):
        cls_score = self.gfl_cls(self.cls_convs(x))
        bbox_pred = self.gfl_reg(self.reg_convs(x))
        return cls_score, bbox_pred

# 75. CenterNet Head
class CenterNetHead(nn.Module):
    """CenterNet head: sigmoid center heatmap, ReLU-positive width/height, and
    raw sub-pixel center offsets."""

    def __init__(self, c1, num_classes=80):
        super().__init__()
        self.heatmap = nn.Sequential(Conv(c1, c1, 3), nn.Conv2d(c1, num_classes, 1))
        self.wh = nn.Sequential(Conv(c1, c1, 3), nn.Conv2d(c1, 2, 1))
        self.offset = nn.Sequential(Conv(c1, c1, 3), nn.Conv2d(c1, 2, 1))

    def forward(self, x):
        return (self.heatmap(x).sigmoid(),  # per-class center probabilities
                self.wh(x).relu(),          # non-negative box sizes
                self.offset(x))             # unconstrained offsets

# 76. FCOS Head
class FCOSHead(nn.Module):
    """FCOS head: four-conv classification and regression towers, plus a
    centerness branch fed from the regression tower."""

    def __init__(self, c1, num_classes=80):
        super().__init__()
        self.cls_tower = nn.Sequential(*(Conv(c1, c1, 3) for _ in range(4)))
        self.bbox_tower = nn.Sequential(*(Conv(c1, c1, 3) for _ in range(4)))
        self.cls_logits = nn.Conv2d(c1, num_classes, 3, padding=1)
        self.bbox_pred = nn.Conv2d(c1, 4, 3, padding=1)
        self.centerness = nn.Conv2d(c1, 1, 3, padding=1)

    def forward(self, x):
        cls_feat = self.cls_tower(x)
        reg_feat = self.bbox_tower(x)
        # ReLU keeps the predicted distances non-negative.
        return (self.cls_logits(cls_feat),
                F.relu(self.bbox_pred(reg_feat)),
                self.centerness(reg_feat))

# 77. ATSS Head
class ATSSHead(nn.Module):
    """ATSS head: anchor-based class/box predictions plus a centerness branch
    shared with the regression tower."""

    def __init__(self, c1, num_classes=80, num_anchors=9):
        super().__init__()
        self.num_anchors = num_anchors
        self.cls_convs = nn.Sequential(Conv(c1, c1, 3), Conv(c1, c1, 3))
        self.reg_convs = nn.Sequential(Conv(c1, c1, 3), Conv(c1, c1, 3))
        self.atss_cls = nn.Conv2d(c1, num_classes * num_anchors, 3, padding=1)
        self.atss_reg = nn.Conv2d(c1, 4 * num_anchors, 3, padding=1)
        self.atss_centerness = nn.Conv2d(c1, num_anchors, 3, padding=1)

    def forward(self, x):
        cls_feat = self.cls_convs(x)
        reg_feat = self.reg_convs(x)
        return (self.atss_cls(cls_feat),
                self.atss_reg(reg_feat),
                self.atss_centerness(reg_feat))

# 78. PAFHead (Parallel Attention Fusion Head)
class PAFHead(nn.Module):
    """Parallel fusion head: a local conv branch and a global pooled-context
    branch, fused and fed to one dense predictor (box 4 + obj 1 + cls).

    fix: dropped the hard-coded ``nn.Upsample(scale_factor=64)`` after global
    pooling — the 1x1 pooled map is nearest-resized to the local branch's
    size anyway, so the intermediate 64x blow-up only wasted memory. The
    module has no parameters, so checkpoints are unaffected.
    """

    def __init__(self, c1, num_classes=80):
        super().__init__()
        self.local_conv = nn.Sequential(
            Conv(c1, c1 // 2, 3),
            Conv(c1 // 2, c1 // 2, 3)
        )
        self.global_conv = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            Conv(c1, c1 // 2, 1)
        )
        self.fusion = Conv(c1, c1, 1)
        self.pred = nn.Conv2d(c1, num_classes + 4 + 1, 1)

    def forward(self, x):
        local_feat = self.local_conv(x)
        # Broadcast the pooled global context to the local branch's resolution.
        global_feat = F.interpolate(self.global_conv(x),
                                    size=local_feat.shape[2:], mode='nearest')
        fused = self.fusion(torch.cat([local_feat, global_feat], dim=1))
        return self.pred(fused)

# 79. DoubleBranchHead
class DoubleBranchHead(nn.Module):
    """Two-branch head: an ASPP context branch plus a cheap depthwise-separable detail branch."""

    def __init__(self, c1, num_classes=80):
        super().__init__()
        # Large-receptive-field context path.
        self.context_branch = nn.Sequential(ASPP(c1, c1), Conv(c1, c1, 3))
        # Detail path: depthwise 3x3 followed by pointwise 1x1.
        detail_layers = [
            nn.Conv2d(c1, c1, 3, padding=1, groups=c1),
            nn.BatchNorm2d(c1),
            nn.Conv2d(c1, c1, 1),
            nn.BatchNorm2d(c1),
            nn.ReLU(),
        ]
        self.detail_branch = nn.Sequential(*detail_layers)
        self.fusion = Conv(c1 * 2, c1, 1)
        self.pred = nn.Conv2d(c1, num_classes + 5, 1)  # cls + box(4) + obj(1)

    def forward(self, x):
        merged = torch.cat([self.context_branch(x), self.detail_branch(x)], dim=1)
        return self.pred(self.fusion(merged))

# 80. LiteHead (ultra-lightweight detection head)
class LiteHead(nn.Module):
    """Minimal detection head: one shared conv, then 1x1 class and box predictors."""

    def __init__(self, c1, num_classes=80):
        super().__init__()
        self.shared_conv = Conv(c1, c1, 3)
        self.cls_pred = nn.Conv2d(c1, num_classes, 1)
        self.reg_pred = nn.Conv2d(c1, 4, 1)

    def forward(self, x):
        shared = self.shared_conv(x)
        # Box regression channels first, then class channels.
        return torch.cat((self.reg_pred(shared), self.cls_pred(shared)), dim=1)

⚡ 五、损失函数与标签分配 (81-90)

Python

# 81. TaskAlignedAssigner (task-aligned assigner)
class TaskAlignedAssigner(nn.Module):
    """Task-aligned label assigner (simplified): metric = cls^alpha * IoU^beta, keep top-k."""

    def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0):
        super().__init__()
        self.topk = topk
        self.num_classes = num_classes
        self.alpha = alpha
        self.beta = beta

    def forward(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
        """Return (top-k candidate indices, alignment metric); all-zero outputs when there are no GTs."""
        n_anchors = pd_bboxes.shape[-2]
        if gt_bboxes.shape[0] == 0:
            # No ground truth: everything is background.
            return (torch.zeros(n_anchors, dtype=torch.long),
                    torch.zeros(n_anchors, self.num_classes))

        overlaps = self.iou_calculation(gt_bboxes, pd_bboxes)

        # Alignment metric couples classification confidence with localization quality.
        metric = pd_scores[:, gt_labels].pow(self.alpha) * overlaps.pow(self.beta)

        # Keep the k best-aligned candidates per row.
        _, topk_idxs = torch.topk(metric, self.topk, dim=1)
        return topk_idxs, metric

    def iou_calculation(self, gt_bboxes, pd_bboxes):
        # Placeholder pairwise IoU (random) — stands in for a real IoU computation.
        return torch.rand(gt_bboxes.shape[0], pd_bboxes.shape[0])

# 82. SimOTA (Simplified Optimal Transport Assignment)
class SimOTA:
    """Dynamic-k assignment: cost = cls + 3*(1 - IoU); each GT claims its k cheapest predictions."""

    def __init__(self, num_classes=80, center_radius=2.5):
        self.num_classes = num_classes
        self.center_radius = center_radius

    def __call__(self, pred_scores, pred_bboxes, gt_labels, gt_bboxes):
        num_gt = gt_bboxes.shape[0]
        num_pred = pred_bboxes.shape[0]

        if num_gt == 0:
            # Nothing to assign: all predictions are background.
            return (torch.zeros(num_pred, dtype=torch.bool),
                    torch.zeros(num_pred, dtype=torch.long))

        # Pairwise IoU drives both the cost matrix and the dynamic-k estimate.
        ious = self.bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0))
        cost = (1 - ious) * 3.0 + (1 - pred_scores[:, gt_labels])

        # Dynamic k per GT: sum of its top-10 IoUs, at least 1.
        topk_ious, _ = torch.topk(ious, min(10, num_pred), dim=0)
        dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1)

        # Greedy per-GT assignment of the k lowest-cost predictions.
        matching_matrix = torch.zeros(num_pred, num_gt)
        for gt_idx in range(num_gt):
            _, cheapest = torch.topk(cost[:, gt_idx], k=dynamic_ks[gt_idx].item(), largest=False)
            matching_matrix[cheapest, gt_idx] = 1.0

        return matching_matrix.sum(1) > 0, matching_matrix.argmax(1)

    def bbox_iou(self, box1, box2):
        # Placeholder pairwise IoU (random) — replace with a real implementation.
        return torch.rand(box1.shape[0], box2.shape[0])

# 83. TAL (Task Alignment Learning)
class TAL(nn.Module):
    """Task-alignment assignment: metric = cls^alpha * IoU^beta; top-k predictions per GT are positive."""

    def __init__(self, topk=13, alpha=1.0, beta=6.0):
        super().__init__()
        self.topk = topk
        self.alpha = alpha
        self.beta = beta

    def forward(self, cls_scores, bbox_preds, gt_bboxes, gt_labels):
        num_gt = gt_bboxes.shape[0]
        num_pred = cls_scores.shape[0]

        # Placeholder classification targets (kept for parity, not used below).
        cls_targets = torch.zeros(num_pred, cls_scores.shape[1])

        # Pairwise IoU between predictions and GTs (placeholder random values).
        ious = self.compute_iou(bbox_preds, gt_bboxes)

        # Score of each prediction at every GT's class, fused with localization quality.
        per_gt_scores = cls_scores.gather(1, gt_labels.unsqueeze(0).expand(num_pred, -1))
        alignment_metrics = per_gt_scores.pow(self.alpha) * ious.pow(self.beta)

        # Mark the top-k best-aligned predictions of each GT as positives.
        _, topk_indices = torch.topk(alignment_metrics, self.topk, dim=0)
        positive_mask = torch.zeros_like(alignment_metrics)
        positive_mask.scatter_(0, topk_indices, 1)

        return positive_mask, alignment_metrics

    def compute_iou(self, bboxes1, bboxes2):
        # Placeholder pairwise IoU.
        return torch.rand(bboxes1.shape[0], bboxes2.shape[0])

# 84. ATSS (Adaptive Training Sample Selection)
class ATSS:
    """Per GT, take the top-k closest anchors per level, then threshold by mean+std IoU."""

    def __init__(self, topk=9):
        self.topk = topk

    def __call__(self, anchors, gt_bboxes, num_level_anchors):
        num_gt = gt_bboxes.shape[0]
        num_anchors = anchors.shape[0]

        if num_gt == 0:
            # No ground truth: no positives at all.
            return (torch.zeros(num_anchors, dtype=torch.bool),
                    torch.zeros(num_anchors, dtype=torch.long))

        # Center points of GT boxes and anchors.
        gt_points = torch.stack([(gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0,
                                 (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0], dim=1)
        anchor_points = torch.stack([(anchors[:, 0] + anchors[:, 2]) / 2.0,
                                     (anchors[:, 1] + anchors[:, 3]) / 2.0], dim=1)

        # Pairwise center distances: [num_gt, num_anchors].
        distances = (anchor_points.unsqueeze(0) - gt_points.unsqueeze(1)).pow(2).sum(-1).sqrt()

        # For every pyramid level, take the k anchors closest to each GT.
        candidate_idxs = []
        start_idx = 0
        for level_count in num_level_anchors:
            end_idx = start_idx + level_count
            _, level_topk = torch.topk(distances[:, start_idx:end_idx], self.topk, dim=1, largest=False)
            candidate_idxs.append(level_topk + start_idx)
            start_idx = end_idx
        candidate_idxs = torch.cat(candidate_idxs, dim=1)

        # Adaptive IoU threshold = mean + std over candidates (placeholder random IoUs).
        candidate_ious = torch.rand(num_gt, candidate_idxs.shape[1])
        iou_threshold = candidate_ious.mean(dim=1, keepdim=True) + candidate_ious.std(dim=1, keepdim=True)
        is_pos = candidate_ious >= iou_threshold

        return is_pos.any(dim=0), is_pos.float().argmax(dim=0)

# 85. HungarianMatcher (DETR-style)
class HungarianMatcher(nn.Module):
    """Bipartite matcher between predicted queries and ground-truth objects (DETR).

    Builds a cost matrix from classification probability, L1 box distance and
    GIoU, then solves the assignment per image with the Hungarian algorithm.
    NOTE(review): `generalized_box_iou` below is a random placeholder, so the
    resulting matches are not meaningful until it is replaced with a real GIoU.
    """

    def __init__(self, cost_class=1.0, cost_bbox=5.0, cost_giou=2.0):
        super().__init__()
        # Relative weights of the three cost terms.
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou

    @torch.no_grad()
    def forward(self, outputs, targets):
        """Return, per image, a pair (query_indices, target_indices) of matched pairs.

        outputs: dict with "pred_logits" [bs, num_queries, C] and "pred_boxes"
        [bs, num_queries, 4]; targets: list (length bs) of dicts with "labels"
        and "boxes".
        """
        bs, num_queries = outputs["pred_logits"].shape[:2]
        
        # Flatten batch and query dims to compute one big cost matrix.
        out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)
        out_bbox = outputs["pred_boxes"].flatten(0, 1)
        
        # Concatenate target labels and boxes across the whole batch.
        tgt_ids = torch.cat([v["labels"] for v in targets])
        tgt_bbox = torch.cat([v["boxes"] for v in targets])
        
        # Classification cost: negative probability of each target's class.
        cost_class = -out_prob[:, tgt_ids]
        
        # L1 cost between box coordinates.
        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
        
        # GIoU cost (negated so higher overlap means lower cost).
        cost_giou = -self.generalized_box_iou(out_bbox, tgt_bbox)
        
        # Final weighted cost matrix, reshaped back to per-image slices.
        C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
        C = C.view(bs, num_queries, -1).cpu()
        
        # Hungarian algorithm: split the target axis back per image, then solve
        # each image's [num_queries, num_targets_i] sub-matrix independently
        # (c[i] selects image i's queries from the split chunk).
        from scipy.optimize import linear_sum_assignment
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split([len(v["labels"]) for v in targets], -1))]
        
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]

    def generalized_box_iou(self, boxes1, boxes2):
        # Placeholder: random values instead of a real pairwise GIoU.
        return torch.rand(boxes1.shape[0], boxes2.shape[0])

# 86. QualityFocalLoss
class QualityFocalLoss(nn.Module):
    """Quality focal loss (GFL): negatives get a focal term, positives regress toward an IoU score.

    Fix vs. original: the positive branch used ``score[pos]`` (shape [n_pos]) as the
    BCE target for ``pred[pos]`` (shape [n_pos, C]); ``F.binary_cross_entropy_with_logits``
    requires matching shapes, so the original raised a ValueError whenever any target
    was positive.  The per-sample quality score is now broadcast across the class
    dimension before the loss is computed.
    """

    def __init__(self, beta=2.0):
        super().__init__()
        self.beta = beta  # focusing parameter

    def forward(self, pred, target, score):
        """pred: [N, C] logits; target: [N] class labels (0 = background); score: [N] IoU quality."""
        pred_sigmoid = pred.sigmoid()
        # Every element starts as a negative: BCE against 0, focally scaled by p^beta.
        scale_factor = pred_sigmoid
        zerolabel = scale_factor.new_zeros(pred.shape)
        loss = F.binary_cross_entropy_with_logits(pred, zerolabel, reduction='none') * scale_factor.pow(self.beta)

        pos = target != 0
        if pos.any():
            # Broadcast each positive sample's quality score over its C class logits.
            pos_score = score[pos].unsqueeze(-1).expand_as(pred[pos])
            scale_factor = pos_score - pred_sigmoid[pos]
            loss[pos] = F.binary_cross_entropy_with_logits(pred[pos], pos_score, reduction='none') * scale_factor.abs().pow(self.beta)

        return loss.sum()

# 87. DistributionFocalLoss (DFL)
class DistributionFocalLoss(nn.Module):
    """Distribution focal loss: learn a discrete distribution over bins whose
    expectation matches a (possibly fractional) regression target.

    Fix vs. original: the target was cast to ``long`` *before* the left/right
    interpolation weights were computed, so ``weight_right`` was always 0 and the
    fractional part of the target — the whole point of DFL — was discarded
    (degenerating it into plain cross-entropy).  Targets exactly at the top bin
    edge also received zero total weight.  The target now stays floating-point
    until the weights are derived, and the left bin is clamped so edge targets
    put full weight on the last bin.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred, target):
        """pred: [N, 4, n_bins] logits over discrete offsets; target: [N, 4] continuous bin coords."""
        n, c, n_bins = pred.shape
        target = target.float().clamp(min=0, max=n_bins - 1)

        # Neighbouring integer bins around each continuous target.
        target_left = target.long().clamp(max=n_bins - 2)
        target_right = target_left + 1

        # Linear-interpolation weights; at integer targets all mass lands on one bin.
        weight_left = target_right.float() - target
        weight_right = target - target_left.float()

        # Cross-entropy against each neighbour, mixed by the interpolation weights.
        flat_pred = pred.view(-1, n_bins)
        loss = F.cross_entropy(flat_pred, target_left.view(-1), reduction='none') * weight_left.view(-1) + \
               F.cross_entropy(flat_pred, target_right.view(-1), reduction='none') * weight_right.view(-1)
        return loss.mean()

# 88. GIoULoss
class GIoULoss(nn.Module):
    """Generalized IoU loss: 1 - GIoU, optionally weighted, averaged over boxes."""

    def __init__(self):
        super().__init__()

    def forward(self, pred, target, weight=None):
        """pred/target: [N, 4] boxes as (x1, y1, x2, y2); weight: optional per-box scale."""
        px1, py1, px2, py2 = pred.T
        tx1, ty1, tx2, ty2 = target.T

        area_p = (px2 - px1) * (py2 - py1)
        area_t = (tx2 - tx1) * (ty2 - ty1)

        # Intersection rectangle, clamped to zero when boxes are disjoint.
        iw = (torch.min(px2, tx2) - torch.max(px1, tx1)).clamp(min=0)
        ih = (torch.min(py2, ty2) - torch.max(py1, ty1)).clamp(min=0)
        inter = iw * ih
        union = area_p + area_t - inter
        iou = inter / (union + 1e-7)

        # Smallest enclosing box penalizes far-apart, non-overlapping pairs.
        enclose_w = torch.max(px2, tx2) - torch.min(px1, tx1)
        enclose_h = torch.max(py2, ty2) - torch.min(py1, ty1)
        enclose_area = enclose_w * enclose_h

        giou = iou - (enclose_area - union) / (enclose_area + 1e-7)
        loss = 1 - giou
        if weight is not None:
            loss = loss * weight
        return loss.mean()

# 89. FocalLoss
class FocalLoss(nn.Module):
    """Binary focal loss on logits with alpha class balancing."""

    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, pred, target):
        ce = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
        prob = pred.sigmoid()
        # Probability assigned to the true label of each element.
        p_t = target * prob + (1 - target) * (1 - prob)
        # Class-balance factor: alpha for positives, 1-alpha for negatives.
        alpha_t = target * self.alpha + (1 - target) * (1 - self.alpha)
        focal = alpha_t * (1.0 - p_t).pow(self.gamma) * ce
        return focal.mean()

# 90. VarifocalLoss
class VarifocalLoss(nn.Module):
    """Varifocal loss: focal weighting on positives, (1 - score)-scaled weighting on negatives."""

    def __init__(self, alpha=0.75, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, pred, target, score):
        prob = pred.sigmoid()
        pos = (target > 0).float()
        neg = (target == 0).float()
        # Positives follow a focal form; negative weight shrinks with the IoU score.
        weight = self.alpha * prob.pow(self.gamma) * pos + (1 - self.alpha) * (1 - score) * neg
        ce = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
        return (weight * ce).sum() / (weight.sum() + 1e-6)

🔧 六、训练策略与辅助模块 (91-100)

Python

# 91. EMA (Exponential Moving Average)
class ModelEMA:
    """Keeps an exponential moving average copy of a model's weights."""

    def __init__(self, model, decay=0.9999):
        self.ema = deepcopy(model).eval()  # frozen shadow copy
        self.updates = 0
        # Decay ramps up from ~0 so early updates track the live model closely.
        self.decay = lambda x: decay * (1 - math.exp(-x / 2000))
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        """Blend the model's current weights into the EMA copy (floating-point tensors only)."""
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)
            msd = model.state_dict()
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:
                    # In-place: v = d * v + (1 - d) * model_weight
                    v.mul_(d).add_(msd[k].detach(), alpha=1.0 - d)

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        """Copy plain (non-parameter) attributes from the model onto the EMA copy."""
        for k, v in model.__dict__.items():
            skip = (len(include) and k not in include) or k.startswith('_') or k in exclude
            if not skip:
                setattr(self.ema, k, v)

# 92. SWA (Stochastic Weight Averaging)
class SWA:
    """Maintains a running average of model weights, starting after `swa_start` epochs."""

    def __init__(self, model, swa_start=10):
        self.model = model
        self.swa_model = deepcopy(model).eval()
        self.swa_start = swa_start
        self.n_averaged = 0

    def update(self, model, epoch):
        """Fold the model's current parameters into the running average (no-op before swa_start)."""
        if epoch < self.swa_start:
            return
        # Incremental mean: avg_{k+1} = (k * avg_k + new) / (k + 1).
        for avg_param, param in zip(self.swa_model.parameters(), model.parameters()):
            avg_param.data = (avg_param.data * self.n_averaged + param.data) / (self.n_averaged + 1)
        self.n_averaged += 1

# 93. DropBlock
class DropBlock(nn.Module):
    """Structured dropout: zeroes contiguous block_size x block_size regions during training."""

    def __init__(self, block_size=7, keep_prob=0.9):
        super().__init__()
        self.block_size = block_size
        self.keep_prob = keep_prob

    def forward(self, x):
        # Identity at eval time or when nothing should be dropped.
        if not self.training or self.keep_prob == 1.0:
            return x
        # Seed probability chosen so the expected dropped fraction matches
        # (1 - keep_prob); uses x.shape[-1], i.e. assumes square feature maps.
        feat = x.shape[-1]
        gamma = (1.0 - self.keep_prob) * feat ** 2 / self.block_size ** 2 / \
                (feat - self.block_size + 1) ** 2
        seed_mask = torch.bernoulli(torch.ones_like(x) * gamma)
        # Grow each seed into a full block via max-pooling, then invert: 1 = keep.
        seed_mask = F.max_pool2d(seed_mask, self.block_size, stride=1, padding=self.block_size // 2)
        keep_mask = 1 - seed_mask
        # Rescale so the expected activation magnitude is preserved.
        return x * keep_mask * keep_mask.numel() / keep_mask.sum()

# 94. DropPath (Stochastic Depth)
class DropPath(nn.Module):
    """Randomly zeroes whole samples (residual paths) during training, rescaling survivors."""

    def __init__(self, drop_prob=0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Identity when disabled or at eval time.
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over every remaining dim.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        keep_mask = torch.floor(keep_prob + torch.rand(mask_shape, dtype=x.dtype, device=x.device))
        return x.div(keep_prob) * keep_mask

# 95. LabelSmoothing
class LabelSmoothing(nn.Module):
    """Cross-entropy with label smoothing: mixes the target NLL with a uniform prior."""

    def __init__(self, smoothing=0.1):
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, x, target):
        log_probs = F.log_softmax(x, dim=-1)
        # NLL of the true class.
        nll = -log_probs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
        # Mean negative log-prob over all classes == uniform-target loss.
        uniform = -log_probs.mean(dim=-1)
        return (self.confidence * nll + self.smoothing * uniform).mean()

# 96. CutMix
class CutMix:
    """CutMix augmentation: pastes a random crop from a shuffled batch and mixes labels by area."""

    def __init__(self, alpha=1.0):
        self.alpha = alpha

    def __call__(self, images, labels):
        # Mixing ratio from a symmetric Beta distribution.
        lam = np.random.beta(self.alpha, self.alpha)
        perm = torch.randperm(images.size()[0]).to(images.device)

        target_a, target_b = labels, labels[perm]

        x1, y1, x2, y2 = self.rand_bbox(images.size(), lam)
        # Paste the crop region from the permuted batch (mutates `images` in place).
        images[:, :, x1:x2, y1:y2] = images[perm, :, x1:x2, y1:y2]

        # Recompute lambda from the actual pasted area.
        lam = 1 - ((x2 - x1) * (y2 - y1) / (images.size()[-1] * images.size()[-2]))
        return images, target_a, target_b, lam

    def rand_bbox(self, size, lam):
        # NOTE(review): names size[2] "W" and size[3] "H"; for NCHW these are
        # actually H and W, but the crop above uses the same convention, so the
        # code stays self-consistent.
        W, H = size[2], size[3]
        cut_rat = np.sqrt(1. - lam)
        cut_w, cut_h = int(W * cut_rat), int(H * cut_rat)

        # Random crop center.
        cx, cy = np.random.randint(W), np.random.randint(H)

        bbx1 = np.clip(cx - cut_w // 2, 0, W)
        bby1 = np.clip(cy - cut_h // 2, 0, H)
        bbx2 = np.clip(cx + cut_w // 2, 0, W)
        bby2 = np.clip(cy + cut_h // 2, 0, H)
        return bbx1, bby1, bbx2, bby2

# 97. Mosaic Augmentation
class Mosaic:
    """Mosaic augmentation: tiles four images onto one 2x2 canvas and shifts labels accordingly."""

    def __init__(self, size=640):
        self.size = size

    def __call__(self, images, labels):
        s = self.size
        canvas = torch.zeros(3, s * 2, s * 2)
        merged_labels = []

        # Top-left, top-right, bottom-left, bottom-right (x, y) offsets.
        offsets = [(0, 0), (s, 0), (0, s), (s, s)]

        for (x_off, y_off), img, label in zip(offsets, images[:4], labels[:4]):
            canvas[:, y_off:y_off + s, x_off:x_off + s] = img
            if len(label) > 0:
                # Shift box coordinates into the tile's position (in-place on `label`).
                label[:, [0, 2]] += x_off
                label[:, [1, 3]] += y_off
                merged_labels.append(label)

        if len(merged_labels) > 0:
            merged_labels = torch.cat(merged_labels, dim=0)

        return canvas, merged_labels

# 98. MixUp
class MixUp:
    """MixUp augmentation: convex combination of the batch with a shuffled copy of itself."""

    def __init__(self, alpha=0.5):
        self.alpha = alpha

    def __call__(self, images, labels):
        # Mixing ratio from a symmetric Beta distribution.
        lam = np.random.beta(self.alpha, self.alpha)
        perm = torch.randperm(images.size(0)).to(images.device)
        mixed = lam * images + (1 - lam) * images[perm]
        # Caller mixes the two label sets with weights lam / (1 - lam).
        return mixed, labels, labels[perm], lam

# 99. Test Time Augmentation (TTA)
class TTA:
    """Runs inference over scale/flip variants of the input; merging is left simplified."""

    def __init__(self, scales=[1.0, 1.25, 0.75], flips=[False, True]):
        self.scales = scales
        self.flips = flips

    def __call__(self, model, image):
        predictions = []
        base_h, base_w = image.shape[2:]

        for scale in self.scales:
            for do_flip in self.flips:
                aug = image
                # Rescale only when the scale actually differs.
                if scale != 1.0:
                    aug = F.interpolate(image, size=(int(base_h * scale), int(base_w * scale)),
                                        mode='bilinear', align_corners=False)
                # Horizontal flip.
                if do_flip:
                    aug = torch.flip(aug, dims=[3])

                with torch.no_grad():
                    pred = model(aug)

                # Undo the flip on x coordinates so predictions line up again.
                if do_flip:
                    pred[0][..., 0] = aug.shape[3] - pred[0][..., 0]

                predictions.append(pred)

        # Simplified merge: just return the un-augmented prediction.
        return predictions[0]

# 100. MultiScaleTraining
class MultiScaleTraining:
    """Randomly resizes each batch to one of several square sizes and rescales labels to match."""

    def __init__(self, sizes=[320, 352, 384, 416, 448, 480, 512, 544, 576, 608, 640]):
        self.sizes = sizes
        self.current_size = 640

    def __call__(self, images, labels):
        # Pick a new square training resolution for this batch.
        self.current_size = random.choice(self.sizes)

        resized = F.interpolate(images, size=(self.current_size, self.current_size),
                                mode='bilinear', align_corners=False)

        # Box coordinates scale linearly with the resize factor.
        factor = self.current_size / images.shape[2]
        labels = labels * factor if labels.numel() > 0 else labels
        return resized, labels

📋 使用说明

在YOLOv12中插入模块的示例:

Python

# Modify the backbone to insert attention modules
class ImprovedBackbone(nn.Module):
    """Example backbone showing where channel/hybrid attention blocks slot in."""

    def __init__(self):
        super().__init__()
        self.conv1 = Conv(3, 64, 3, 2)
        self.attn1 = ECA(64)  # insert ECA attention after the first downsample
        self.conv2 = Conv(64, 128, 3, 2)
        self.attn2 = CBAM(128)  # insert CBAM after the second downsample

# Modify the neck to use an improved FPN
class ImprovedNeck(nn.Module):
    """Example neck: swap the plain FPN for BiFPN over the three pyramid levels."""

    def __init__(self):
        super().__init__()
        self.fpn = BiFPN([256, 512, 1024])  # use BiFPN for weighted bidirectional fusion

# Modify the head to use a decoupled head
class ImprovedHead(nn.Module):
    """Example head: replace the coupled predictor with a decoupled cls/reg head."""

    def __init__(self):
        super().__init__()
        self.head = DecoupledHead(256, num_classes=80)  # decouple classification and regression branches

快速替换指南:

表格

| 原模块 | 改进模块 | 效果 |
| --- | --- | --- |
| Conv | GhostConv | 减少50%计算量 |
| C3 | C3k2_Improved | 保持精度,更快 |
| SPPF | ASPP | 多尺度感受野 |
| Concat | ASFF | 自适应特征融合 |
| Detect | DecoupledHead | 分类回归解耦 |

这100个模块覆盖了注意力机制、卷积改进、特征融合、检测头、损失函数、数据增强等各个方面,可以灵活组合使用!

转载自CSDN-专业IT技术社区

原文链接:https://blog.csdn.net/VectorShift/article/details/158586215

评论

赞0

评论列表

微信小程序
QQ小程序

关于作者

点赞数:0
关注数:0
粉丝:0
文章:0
关注标签:0
加入于:--