目录
🔥 一、注意力机制类 (1-25)
通道注意力 (1-8)
Python
# 1. ECA-Net (Efficient Channel Attention)
class ECA(nn.Module):
    """Efficient Channel Attention: a 1-D conv over channel-wise GAP descriptors.

    The 1-D kernel size is chosen adaptively from the channel count,
    following the ECA-Net paper.
    """

    def __init__(self, channel, gamma=2, b=1):
        super().__init__()
        # Adaptive odd kernel size derived from log2(channel).
        k = int(abs((math.log(channel, 2) + b) / gamma))
        k = k if k % 2 else k + 1
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # (b, c, 1, 1) -> (b, 1, c): treat the channel axis as a 1-D sequence.
        pooled = self.avg_pool(x).squeeze(-1).transpose(-1, -2)
        weights = self.conv(pooled).transpose(-1, -2).unsqueeze(-1)
        return x * self.sigmoid(weights)
# 2. SE-Net (Squeeze-and-Excitation)
class SE(nn.Module):
def __init__(self, c1, reduction=16):
super().__init__()
self.avgpool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(c1, c1 // reduction, bias=False),
nn.ReLU(inplace=True),
nn.Linear(c1 // reduction, c1, bias=False),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.avgpool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y.expand_as(x)
# 3. CBAM (Convolutional Block Attention Module) - Channel
class CBAM_Channel(nn.Module):
def __init__(self, c1, reduction=16):
super().__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.fc = nn.Sequential(
nn.Conv2d(c1, c1//reduction, 1, bias=False),
nn.ReLU(),
nn.Conv2d(c1//reduction, c1, 1, bias=False)
)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
avg_out = self.fc(self.avg_pool(x))
max_out = self.fc(self.max_pool(x))
return x * self.sigmoid(avg_out + max_out)
# 4. GCT (Gated Channel Transformation)
class GCT(nn.Module):
    """Gated Channel Transformation: gates channels by their L2-norm embedding.

    Fix: `beta` was declared as a parameter but never used; it now serves as
    the additive bias inside the tanh gate, matching the GCT formulation
    (gate = 1 + tanh(scaled embedding + beta)). With the zero initialisation
    the initial behaviour is unchanged.
    """

    def __init__(self, c1, epsilon=1e-5):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, c1, 1, 1))   # embedding reference scale
        self.gamma = nn.Parameter(torch.zeros(1, c1, 1, 1))  # gate scale (zero-init -> identity)
        self.beta = nn.Parameter(torch.zeros(1, c1, 1, 1))   # gate bias
        self.epsilon = epsilon  # numerical stability for the norm

    def forward(self, x):
        # Per-channel L2 norm over the spatial dimensions.
        embedding = torch.norm(x, p=2, dim=(2, 3), keepdim=True)
        norm = self.gamma / (embedding + self.epsilon).pow(0.5)
        gate = 1. + torch.tanh(norm * (embedding - self.alpha) + self.beta)
        return x * gate
# 5. EMA (Efficient Multi-scale Attention), simplified
class EMA(nn.Module):
    """Grouped coordinate-style gating followed by a 3x3 refinement conv.

    Channels are folded into the batch in groups of c1 // factor; each group
    is gated by sigmoid maps pooled along H and W.
    """

    def __init__(self, c1, factor=32):
        super().__init__()
        self.groups = factor
        self.softmax = nn.Softmax(-1)
        self.agp = nn.AdaptiveAvgPool2d((1, 1))
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))  # pool over width
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))  # pool over height
        self.gn = nn.GroupNorm(c1 // self.groups, c1 // self.groups)
        self.conv1x1 = nn.Conv2d(c1 // self.groups, c1 // self.groups, kernel_size=1)
        self.conv3x3 = nn.Conv2d(c1 // self.groups, c1 // self.groups, kernel_size=3, padding=1)

    def forward(self, x):
        b, c, h, w = x.size()
        # Fold the group dimension into the batch.
        grouped = x.reshape(b * self.groups, -1, h, w)
        pooled_h = self.pool_h(grouped)
        pooled_w = self.pool_w(grouped).permute(0, 1, 3, 2)
        joint = self.conv1x1(torch.cat([pooled_h, pooled_w], dim=2))
        attn_h, attn_w = torch.split(joint, [h, w], dim=2)
        gated = grouped * attn_h.sigmoid() * attn_w.permute(0, 1, 3, 2).sigmoid()
        refined = self.conv3x3(self.gn(gated))
        return refined.reshape(b, c, h, w)
# 6. SimAM (Simple Attention Module)
class SimAM(nn.Module):
def __init__(self, e_lambda=1e-4):
super().__init__()
self.activaton = nn.Sigmoid()
self.e_lambda = e_lambda
def forward(self, x):
b, c, h, w = x.size()
n = w * h - 1
x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5
return x * self.activaton(y)
# 7. SRM-style channel recalibration (simplified)
class SRM(nn.Module):
    """Channel recalibration via GAP + 1x1 conv gate.

    NOTE(review): this is a simplified stand-in — the original SRM pools
    mean *and* std "style" statistics per channel; only the mean is used here.
    """

    def __init__(self, c1):
        super().__init__()
        self.c1 = c1
        self.style_pool = nn.AdaptiveAvgPool2d(1)
        self.style_conv = nn.Conv2d(c1, c1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        gate = self.sigmoid(self.style_conv(self.style_pool(x)))
        return x * gate
# 8. FcaNet (Frequency Channel Attention)
class FcaNet(nn.Module):
def __init__(self, c1, reduction=16):
super().__init__()
self.register_buffer('pre_computed_dct_weights', self._get_dct_weights(c1))
self.fc = nn.Sequential(
nn.Linear(c1, c1 // reduction, bias=False),
nn.ReLU(inplace=True),
nn.Linear(c1 // reduction, c1, bias=False),
nn.Sigmoid()
)
def _get_dct_weights(self, channels):
c_part = channels // 4
dct_weights = torch.zeros(channels)
for i in range(channels):
freq = i // c_part
dct_weights[i] = freq + 1
return dct_weights.view(1, -1, 1, 1)
def forward(self, x):
b, c, _, _ = x.shape
y = torch.sum(x * self.pre_computed_dct_weights, dim=[2,3])
y = self.fc(y).view(b, c, 1, 1)
return x * y
空间注意力 (9-16)
Python
# 9. Spatial Attention (SAM)
class SpatialAttention(nn.Module):
    """CBAM spatial gate: a conv over channel-wise mean and max maps."""

    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        mean_map = torch.mean(x, dim=1, keepdim=True)
        max_map = torch.max(x, dim=1, keepdim=True)[0]
        gate = self.sigmoid(self.conv(torch.cat([mean_map, max_map], dim=1)))
        return x * gate
# 10. Coordinate Attention (CA)
class CoordAtt(nn.Module):
    """Coordinate attention: factorises spatial attention into H and W strips."""

    def __init__(self, inp, reduction=32):
        super().__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))
        mip = max(8, inp // reduction)  # bottleneck width, floored at 8
        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = nn.Hardswish()
        self.conv_h = nn.Conv2d(mip, inp, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, inp, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        _, _, h, w = x.size()
        # Strip-pool along each axis; rotate the W strip so both concat on dim 2.
        strip_h = self.pool_h(x)
        strip_w = self.pool_w(x).permute(0, 1, 3, 2)
        mixed = self.act(self.bn1(self.conv1(torch.cat([strip_h, strip_w], dim=2))))
        part_h, part_w = torch.split(mixed, [h, w], dim=2)
        gate_h = self.conv_h(part_h).sigmoid()
        gate_w = self.conv_w(part_w.permute(0, 1, 3, 2)).sigmoid()
        return x * gate_w * gate_h
# 11. Triplet Attention
class TripletAttention(nn.Module):
    """Triplet attention: gates along the C-W, C-H and (optionally) H-W planes.

    Fix: the branch convs were declared with in_channels=1 but applied
    directly to permuted feature maps whose dim-1 size is H / W / C, which
    raises at runtime. Each branch now Z-pools its rotated input
    (channel-wise max + mean -> 2 maps) before the conv, as in the Triplet
    Attention paper, and the branch outputs are averaged.
    """

    def __init__(self, no_spatial=False):
        super().__init__()
        self.cw = nn.Conv2d(2, 1, kernel_size=(7, 3), padding=(3, 1), bias=False)
        self.hc = nn.Conv2d(2, 1, kernel_size=(3, 7), padding=(1, 3), bias=False)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.hw = nn.Conv2d(2, 1, 7, padding=3, bias=False)

    @staticmethod
    def _zpool(t):
        # Compress dim 1 into [max, mean] statistics -> 2 channels.
        return torch.cat([t.max(dim=1, keepdim=True)[0],
                          t.mean(dim=1, keepdim=True)], dim=1)

    def forward(self, x):
        # Branch 1: rotate H into the channel slot, gate, rotate back.
        x_perm1 = x.permute(0, 2, 1, 3).contiguous()
        gate1 = self.cw(self._zpool(x_perm1)).sigmoid()
        out1 = (x_perm1 * gate1).permute(0, 2, 1, 3).contiguous()
        # Branch 2: rotate W into the channel slot, gate, rotate back.
        x_perm2 = x.permute(0, 3, 2, 1).contiguous()
        gate2 = self.hc(self._zpool(x_perm2)).sigmoid()
        out2 = (x_perm2 * gate2).permute(0, 3, 2, 1).contiguous()
        if not self.no_spatial:
            gate3 = self.hw(self._zpool(x)).sigmoid()
            return (out1 + out2 + x * gate3) / 3
        return (out1 + out2) / 2
# 12. Shuffle Attention
class ShuffleAttention(nn.Module):
def __init__(self, c1, G=8):
super().__init__()
self.G = G
self.c1 = c1
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.cweight = nn.Parameter(torch.zeros(1, c1 // (2 * G), 1, 1))
self.cbias = nn.Parameter(torch.ones(1, c1 // (2 * G), 1, 1))
self.sweight = nn.Parameter(torch.zeros(1, c1 // (2 * G), 1, 1))
self.sbias = nn.Parameter(torch.ones(1, c1 // (2 * G), 1, 1))
self.sigmoid = nn.Sigmoid()
def forward(self, x):
b, c, h, w = x.size()
x = x.view(b * self.G, -1, h, w)
x_0, x_1 = x.chunk(2, dim=1)
# channel attention
xn = self.avg_pool(x_0)
xn = self.cweight * xn + self.cbias
xn = x_0 * self.sigmoid(xn)
# spatial attention
xs = torch.sigmoid(self.sweight * x_1 + self.sbias)
xs = x_1 * xs
out = torch.cat([xn, xs], dim=1)
out = out.contiguous().view(b, -1, h, w)
return out
# 13. Residual Spatial Attention (RSA)
class RSA(nn.Module):
    """Spatial gate predicted by a small bottleneck conv stack."""

    def __init__(self, c1):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(c1, c1 // 2, 1),
            nn.BatchNorm2d(c1 // 2),
            nn.ReLU(),
            nn.Conv2d(c1 // 2, 1, 3, padding=1),
            nn.BatchNorm2d(1)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        return x * self.sigmoid(self.conv(x))
# 14. Polarized Self-Attention (PSA)
class PSA(nn.Module):
    """Polarized self-attention: a channel branch and a spatial branch fused as gates.

    Fix: the spatial branch previously produced a (b, 1, c//2, 1) tensor that
    could not broadcast against the (b, c, h, w) input; it now yields a
    proper (b, 1, h, w) map (query pooled globally, softmax over channels,
    value weighted and projected down to one channel).
    """

    def __init__(self, c1):
        super().__init__()
        # Channel branch.
        self.ch_wv = nn.Conv2d(c1, c1 // 2, kernel_size=(1, 1))
        self.ch_wq = nn.Conv2d(c1, 1, kernel_size=(1, 1))
        self.ch_wz = nn.Conv2d(c1 // 2, c1, kernel_size=(1, 1))
        self.ch_b = nn.Parameter(torch.zeros(1, c1, 1, 1))
        # Spatial branch.
        self.sp_wv = nn.Conv2d(c1, c1 // 2, kernel_size=(1, 1))
        self.sp_wq = nn.Conv2d(c1, c1 // 2, kernel_size=(1, 1))
        self.sp_wz = nn.Conv2d(c1 // 2, 1, kernel_size=(1, 1))
        self.sp_b = nn.Parameter(torch.zeros(1, 1, 1, 1))
        self.softmax = nn.Softmax(-1)

    def forward(self, x):
        b, c, h, w = x.size()
        # --- channel branch: (b, c, 1, 1) gate ---
        channel_wv = self.ch_wv(x).view(b, c // 2, -1)             # (b, c/2, hw)
        channel_wq = self.softmax(self.ch_wq(x).view(b, 1, -1))    # (b, 1, hw)
        channel_wz = torch.matmul(channel_wv, channel_wq.permute(0, 2, 1)).unsqueeze(-1)
        channel_out = self.ch_wz(channel_wz) + self.ch_b           # (b, c, 1, 1)
        # --- spatial branch: (b, 1, h, w) gate ---
        spatial_wv = self.sp_wv(x)                                 # (b, c/2, h, w)
        spatial_wq = F.adaptive_avg_pool2d(self.sp_wq(x), 1)       # (b, c/2, 1, 1)
        spatial_wq = torch.softmax(spatial_wq.view(b, c // 2), dim=1).view(b, c // 2, 1, 1)
        spatial_out = self.sp_wz(spatial_wv * spatial_wq) + self.sp_b  # (b, 1, h, w)
        return x * torch.sigmoid(channel_out + spatial_out)
# 15. A2 (Area Attention) - YOLOv12-style, corrected
class A2_Improved(nn.Module):
    """Multi-head self-attention restricted to `area` x `area` local windows.

    Fix: with area > 1 the original reshaped the windowed attention output
    straight back to (B, H, W, C), scrambling the spatial layout; the window
    partition is now inverted explicitly. H and W must be divisible by `area`.
    """

    def __init__(self, dim, num_heads=8, area=1):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.area = area
        self.qkv = nn.Conv2d(dim, dim * 3, 1)
        self.proj = nn.Conv2d(dim, dim, 1)
        self.pe = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)  # depthwise positional encoding

    def _partition(self, t, B, H, W):
        # (B, nh, H*W, hd) -> (B, nh, windows, area*area, hd)
        a = self.area
        t = t.reshape(B, self.num_heads, H // a, a, W // a, a, self.head_dim)
        return t.permute(0, 1, 2, 4, 3, 5, 6).reshape(B, self.num_heads, -1, a * a, self.head_dim)

    def _unpartition(self, t, B, H, W):
        # Inverse of _partition: back to (B, nh, H*W, hd).
        a = self.area
        t = t.reshape(B, self.num_heads, H // a, W // a, a, a, self.head_dim)
        return t.permute(0, 1, 2, 4, 3, 5, 6).reshape(B, self.num_heads, H * W, self.head_dim)

    def forward(self, x):
        B, C, H, W = x.shape
        qkv = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim, H * W).permute(1, 0, 2, 4, 3)
        q, k, v = qkv[0], qkv[1], qkv[2]
        if self.area > 1:
            q, k, v = (self._partition(t, B, H, W) for t in (q, k, v))
        attn = (q @ k.transpose(-2, -1)) * self.scale
        out = attn.softmax(dim=-1) @ v
        if self.area > 1:
            out = self._unpartition(out, B, H, W)
        x = out.transpose(1, 2).reshape(B, H, W, C).permute(0, 3, 1, 2)
        return self.proj(x) + self.pe(x)
# 16. Criss-Cross Attention (CCA)
class CrissCrossAttention(nn.Module):
    """Criss-cross attention: each pixel attends to its full row and column.

    Fix: the original referenced an undefined `self.INF` and returned an
    undefined `out`; the row/column attention aggregation is now implemented
    completely, following the CCNet paper's reference implementation.
    """

    def __init__(self, c1):
        super().__init__()
        self.query_conv = nn.Conv2d(c1, c1 // 8, 1)
        self.key_conv = nn.Conv2d(c1, c1 // 8, 1)
        self.value_conv = nn.Conv2d(c1, c1, 1)
        self.gamma = nn.Parameter(torch.zeros(1))  # residual scale, starts as identity
        self.softmax = nn.Softmax(dim=3)

    @staticmethod
    def _inf(b, h, w, device):
        # -inf on the H diagonal so each pixel is not counted twice
        # (it appears in both the row pass and the column pass).
        diag = torch.diag(torch.full((h,), float('-inf'), device=device))
        return diag.unsqueeze(0).repeat(b * w, 1, 1)

    def forward(self, x):
        b, _, h, w = x.size()
        query, key, value = self.query_conv(x), self.key_conv(x), self.value_conv(x)
        q_h = query.permute(0, 3, 1, 2).contiguous().view(b * w, -1, h).permute(0, 2, 1)
        q_w = query.permute(0, 2, 1, 3).contiguous().view(b * h, -1, w).permute(0, 2, 1)
        k_h = key.permute(0, 3, 1, 2).contiguous().view(b * w, -1, h)
        k_w = key.permute(0, 2, 1, 3).contiguous().view(b * h, -1, w)
        v_h = value.permute(0, 3, 1, 2).contiguous().view(b * w, -1, h)
        v_w = value.permute(0, 2, 1, 3).contiguous().view(b * h, -1, w)
        energy_h = (torch.bmm(q_h, k_h) + self._inf(b, h, w, x.device)).view(b, w, h, h).permute(0, 2, 1, 3)
        energy_w = torch.bmm(q_w, k_w).view(b, h, w, w)
        # Joint softmax over the concatenated row+column similarities.
        attn = self.softmax(torch.cat([energy_h, energy_w], 3))
        attn_h = attn[:, :, :, 0:h].permute(0, 2, 1, 3).contiguous().view(b * w, h, h)
        attn_w = attn[:, :, :, h:h + w].contiguous().view(b * h, w, w)
        out_h = torch.bmm(v_h, attn_h.permute(0, 2, 1)).view(b, w, -1, h).permute(0, 2, 3, 1)
        out_w = torch.bmm(v_w, attn_w.permute(0, 2, 1)).view(b, h, -1, w).permute(0, 2, 1, 3)
        return self.gamma * (out_h + out_w) + x
混合注意力 (17-25)
Python
# 17. CBAM (Channel + Spatial)
class CBAM(nn.Module):
    """Full CBAM: channel gate followed by spatial gate."""

    def __init__(self, c1, reduction=16, kernel_size=7):
        super().__init__()
        self.channel = CBAM_Channel(c1, reduction)
        self.spatial = SpatialAttention(kernel_size)

    def forward(self, x):
        return self.spatial(self.channel(x))
# 18. BAM (Bottleneck Attention)
class BAM(nn.Module):
    """Bottleneck attention: channel gate + dilated spatial gate, fused additively.

    Fix: `y1.expand_as(y2)` tried to expand the c-sized channel dim onto a
    1-channel spatial map, which raises at runtime; plain broadcasting of
    (b, c, 1, 1) + (b, 1, h, w) is used instead.
    """

    def __init__(self, c1, reduction=16, dilation_val=4):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(c1, c1 // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(c1 // reduction, c1, bias=False)
        )
        # Spatial branch uses a dilated 3x3 to enlarge the receptive field.
        self.conv1 = nn.Conv2d(c1, c1 // reduction, 1)
        self.conv2 = nn.Conv2d(c1 // reduction, c1 // reduction, 3, padding=dilation_val, dilation=dilation_val)
        self.conv3 = nn.Conv2d(c1 // reduction, 1, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.size()
        chan = self.fc(self.avg_pool(x).view(b, c)).view(b, c, 1, 1)  # (b, c, 1, 1)
        spat = self.conv3(F.relu(self.conv2(self.conv1(x))))          # (b, 1, h, w)
        gate = self.sigmoid(chan + spat)                              # broadcasts to (b, c, h, w)
        return x * gate
# 19. DA-Net (Dual Attention)
class DANet(nn.Module):
    """Dual attention head: sums a spatial (position) and a channel branch.

    NOTE(review): `PAM` and `CAM` are not defined anywhere in this file —
    this block assumes they are provided elsewhere (the position / channel
    attention modules from the DANet paper). Verify before use.
    """

    def __init__(self, c1):
        super().__init__()
        self.sa = PAM(c1)  # position/spatial attention branch (external)
        self.sc = CAM(c1)  # channel attention branch (external)

    def forward(self, x):
        sa_feat = self.sa(x)
        sc_feat = self.sc(x)
        # Element-wise fusion of the two branches.
        return sa_feat + sc_feat
# 20. CC-Net (Criss-Cross)
class CCNet(nn.Module):
    """Two stacked criss-cross attention passes (the R=2 recurrence from CCNet)."""

    def __init__(self, c1):
        super().__init__()
        self.cca = CrissCrossAttention(c1)
        self.cca2 = CrissCrossAttention(c1)

    def forward(self, x):
        return self.cca2(self.cca(x))
# 21. ANN (Asymmetric Non-local)
class ANN(nn.Module):
def __init__(self, c1):
super().__init__()
self.query_conv = nn.Conv2d(c1, c1 // 8, 1)
self.key_conv = nn.Conv2d(c1, c1 // 8, 1)
self.value_conv = nn.Conv2d(c1, c1, 1)
self.gamma = nn.Parameter(torch.zeros(1))
def forward(self, x):
batchsize, C, height, width = x.size()
proj_query = self.query_conv(x).view(batchsize, -1, width * height).permute(0, 2, 1)
proj_key = self.key_conv(x).view(batchsize, -1, width * height)
energy = torch.bmm(proj_query, proj_key)
attention = F.softmax(energy, dim=-1)
proj_value = self.value_conv(x).view(batchsize, -1, width * height)
out = torch.bmm(proj_value, attention.permute(0, 2, 1))
out = out.view(batchsize, C, height, width)
return self.gamma * out + x
# 22. GC-Net (Global Context)
class GCNet(nn.Module):
    """Global context block: attention-pooled context added to every position.

    Fixes: (1) the pooled context was left as (b, 1, c, 1) and fed to a conv
    expecting c1 input channels — it is now viewed as (b, c, 1, 1) as in the
    GCNet reference code; (2) the matmul used an ambiguous batched transpose,
    replaced by an explicit (b, 1, c, hw) @ (b, 1, hw, 1) product.
    """

    def __init__(self, c1, ratio=16):
        super().__init__()
        self.c1 = c1
        self.channel = 1
        self.conv_mask = nn.Conv2d(c1, 1, kernel_size=1)
        self.softmax = nn.Softmax(dim=2)
        self.channel_add_conv = nn.Sequential(
            nn.Conv2d(self.c1, self.c1 // ratio, kernel_size=1),
            nn.LayerNorm([self.c1 // ratio, 1, 1]),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.c1 // ratio, self.c1, kernel_size=1)
        )

    def spatial_pool(self, x):
        """Attention-weighted global pooling -> (b, c, 1, 1) context vector."""
        batch, channel, height, width = x.size()
        flat = x.view(batch, channel, height * width).unsqueeze(1)        # (b, 1, c, hw)
        mask = self.softmax(self.conv_mask(x).view(batch, 1, height * width))
        context = torch.matmul(flat, mask.unsqueeze(-1))                  # (b, 1, c, 1)
        return context.view(batch, channel, 1, 1)

    def forward(self, x):
        context = self.spatial_pool(x)
        # Broadcast-add the transformed context to every spatial location.
        return x + self.channel_add_conv(context)
# 23. Double Attention (simplified non-local variant)
class DoubleAttention(nn.Module):
    """Gathers global pairwise relations in a c/8 subspace, optionally projecting back."""

    def __init__(self, c1, reconstruct=True):
        super().__init__()
        self.c1 = c1
        self.conv1 = nn.Conv2d(c1, c1 // 8, 1)
        self.conv2 = nn.Conv2d(c1, c1 // 8, 1)
        self.conv3 = nn.Conv2d(c1, c1 // 8, 1)
        self.softmax = nn.Softmax(dim=-1)
        self.scale = (c1 // 8) ** -0.5  # NOTE(review): declared but never applied, kept for parity
        self.reconstruct = reconstruct
        if reconstruct:
            self.conv_reconstruct = nn.Conv2d(c1 // 8, c1, 1)

    def forward(self, x):
        b, c, h, w = x.size()
        sub = c // 8
        feat_a = self.conv1(x).view(b, sub, h * w)
        feat_b = self.conv2(x).view(b, sub, h * w)
        feat_v = self.conv3(x).view(b, sub, h * w)
        # (hw, hw) pairwise relation map.
        relation = self.softmax(torch.bmm(feat_a.permute(0, 2, 1), feat_b))
        gathered = torch.bmm(feat_v, relation.permute(0, 2, 1)).view(b, sub, h, w)
        if self.reconstruct:
            gathered = self.conv_reconstruct(gathered)
        return gathered
# 24. SK-Net (Selective Kernel), simplified two-branch variant
class SKNet(nn.Module):
    """Selective-kernel fusion of two 1x1 branches via softmax channel weights.

    Fix: the fuse MLP emitted c1*2 weights while the branches have
    d = max(c1 // reduction, 32) channels, so `a.expand_as(u1)` failed
    whenever d != c1; the MLP now emits d*2 weights to match the branches.
    Output has d channels (unchanged from the only configuration that
    previously worked, c1 == d).
    """

    def __init__(self, c1, reduction=16):
        super().__init__()
        self.d = max(c1 // reduction, 32)  # branch width
        self.conv1 = nn.Conv2d(c1, self.d, 1)
        self.conv2 = nn.Conv2d(c1, self.d, 1)
        self.fc = nn.Sequential(
            nn.Linear(self.d, c1 // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(c1 // reduction, self.d * 2)  # one weight per branch channel
        )
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        batch_size = x.size(0)
        u1 = self.conv1(x)
        u2 = self.conv2(x)
        s = (u1 + u2).mean(-1).mean(-1)             # (b, d) global descriptor
        z = self.fc(s).view(batch_size, 2, self.d)  # per-branch selection logits
        a_b = self.softmax(z)                       # softmax across the 2 branches
        a = a_b[:, 0, :].view(batch_size, -1, 1, 1)
        b = a_b[:, 1, :].view(batch_size, -1, 1, 1)
        return a.expand_as(u1) * u1 + b.expand_as(u2) * u2
# 25. DyNet (Dynamic Convolution)
class DyNet(nn.Module):
    """Mixture of M parallel convs, weighted per-sample by a sigmoid gating head."""

    def __init__(self, c1, c2, k=3, stride=1, M=4):
        super().__init__()
        self.M = M
        self.conv = nn.ModuleList()
        for _ in range(M):
            self.conv.append(nn.Conv2d(c1, c2, k, stride, k // 2))
        self.fc = nn.Sequential(
            nn.Linear(c1, M),
            nn.Sigmoid()  # independent (non-normalised) expert weights
        )
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        gates = self.fc(self.avg_pool(x).squeeze(-1).squeeze(-1))  # (b, M)
        out = 0
        for idx, expert in enumerate(self.conv):
            out = out + gates[:, idx:idx + 1].unsqueeze(-1).unsqueeze(-1) * expert(x)
        return out
🏗️ 二、卷积模块改进 (26-45)
深度可分离卷积 (26-32)
Python
# 26. Ghost Convolution
class GhostConv(nn.Module):
    """Ghost conv: a primary conv plus a cheap depthwise op generating 'ghost' maps.

    Depends on the project-level `Conv` block (conv + BN + act).
    """

    def __init__(self, c1, c2, k=1, s=1, g=1, act=True):
        super().__init__()
        c_ = c2 // 2  # primary channels; the other half are cheap ghosts
        self.cv1 = Conv(c1, c_, k, s, None, g, act)
        self.cv2 = Conv(c_, c_, 5, 1, None, c_, act)  # depthwise 5x5 over the primaries

    def forward(self, x):
        primary = self.cv1(x)
        return torch.cat([primary, self.cv2(primary)], 1)
# 27. PConv (Partial Convolution, FasterNet-style)
class PConv(nn.Module):
    """Applies a depthwise conv to the first quarter of the channels only,
    then mixes all channels with a pointwise conv.

    Fix: the untouched 3/4 of the channels were overwritten with zeros;
    FasterNet's partial conv passes them through unchanged (identity),
    which is restored here.
    """

    def __init__(self, c1, c2, k=3, s=1):
        super().__init__()
        self.cv1 = nn.Conv2d(c1 // 4, c1 // 4, k, s, k // 2, groups=c1 // 4)
        self.cv2 = nn.Conv2d(c1, c2, 1)  # pointwise mix across all channels

    def forward(self, x):
        part = x.size(1) // 4
        y = x.clone()  # identity for the untouched channels
        y[:, :part] = self.cv1(x[:, :part])
        return self.cv2(y)
# 28. DSConv (Depthwise Separable)
class DSConv(nn.Module):
    """Depthwise conv followed by pointwise conv, then BN and SiLU."""

    def __init__(self, c1, c2, k=3, s=1, act=True):
        super().__init__()
        self.dw = nn.Conv2d(c1, c1, k, s, k // 2, groups=c1, bias=False)
        self.pw = nn.Conv2d(c1, c2, 1, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU() if act else nn.Identity()

    def forward(self, x):
        depthwise = self.dw(x)
        return self.act(self.bn(self.pw(depthwise)))
# 29. MixConv (Mixed Depthwise)
class MixConv(nn.Module):
    """Parallel convs with mixed kernel sizes, concatenated channel-wise.

    Fix: each branch used groups=c1, which requires the branch's output
    width (c2 // len(k)) to be a multiple of c1 and crashed for most channel
    configurations; the group count is now gcd(c1, branch_width), the
    largest always-valid grouping.
    NOTE(review): c2 should be divisible by len(k) or the BN width will not
    match the concatenated channels.
    """

    def __init__(self, c1, c2, k=(3, 5, 7), stride=1):
        super().__init__()
        self.groups = len(k)
        self.split_channel = c2 // self.groups
        g = math.gcd(c1, self.split_channel)  # largest valid group count
        self.m = nn.ModuleList([
            nn.Conv2d(c1, self.split_channel, kernel_size=ki, stride=stride, padding=ki // 2, groups=g)
            for ki in k
        ])
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU()

    def forward(self, x):
        return self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
# 30. CondConv (Conditional Convolution)
class CondConv(nn.Module):
    """Per-sample softmax routing over a set of expert convolutions."""

    def __init__(self, c1, c2, k=3, s=1, num_experts=3):
        super().__init__()
        self.num_experts = num_experts
        self.conv = nn.ModuleList([nn.Conv2d(c1, c2, k, s, k // 2) for _ in range(num_experts)])
        self.routing = nn.Linear(c1, num_experts)

    def forward(self, x):
        b, c = x.shape[:2]
        pooled = F.adaptive_avg_pool2d(x, 1).view(b, c)
        weights = F.softmax(self.routing(pooled), dim=1)  # (b, num_experts)
        # Weighted sum of expert outputs, weights broadcast per sample.
        out = sum(w.view(b, 1, 1, 1) * expert(x) for w, expert in zip(weights.T, self.conv))
        return out
# 31. ODConv (simplified: channel-attention-modulated conv)
class ODConv(nn.Module):
    """Conv output scaled by an SE-style gate computed from the input.

    NOTE(review): full ODConv modulates kernels along four dimensions; this
    is a lightweight approximation. The final broadcast assumes c1 == c2 and
    s == 1 (gate is (b, c1, 1, 1), conv output is (b, c2, h/s, w/s)) — verify
    callers.
    """

    def __init__(self, c1, c2, k=3, s=1):
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, k // 2)
        self.attn = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(c1, c1 // 16, 1),
            nn.ReLU(),
            nn.Conv2d(c1 // 16, c1, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.conv(x) * self.attn(x)
# 32. Involution
class Involution(nn.Module):
    """Involution: per-location kernels generated from the (reduced) input.

    Fix: the original unfolded all c1 channels but reshaped them into
    c1 // 2 channels (a size mismatch that raises at runtime) and silently
    halved the channel count. Kernels are now generated per group and shared
    across that group's channels, preserving c1 output channels as in the
    Involution paper. `groups` is a new optional parameter (default 1,
    backward compatible); c1 must be divisible by it.
    """

    def __init__(self, c1, k=7, stride=1, groups=1):
        super().__init__()
        self.kernel_size = k
        self.stride = stride
        self.groups = groups
        self.channels = c1
        self.span = k // 2
        # One k*k kernel per group per spatial location.
        self.o_proj = nn.Conv2d(c1, k * k * groups, 1)
        self.reduce = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

    def forward(self, x):
        b, c, h, w = x.shape
        h_out, w_out = h // self.stride, w // self.stride
        k2 = self.kernel_size * self.kernel_size
        # (b, groups, 1, k*k, h_out, w_out): one kernel shared within each group.
        weight = self.o_proj(self.reduce(x)).view(b, self.groups, 1, k2, h_out, w_out)
        # F.unfold orders its output channel dim as (C, k*k), so this view is valid.
        patches = F.unfold(x, self.kernel_size, padding=self.span, stride=self.stride)
        patches = patches.view(b, self.groups, c // self.groups, k2, h_out, w_out)
        out = (patches * weight).sum(dim=3)
        return out.view(b, c, h_out, w_out)
特征提取增强 (33-40)
Python
# 33. Res2Net Module
class Res2Net(nn.Module):
    """Res2Net-style block: hierarchical channel splits refined by 3x3 convs.

    NOTE(review): forward() splits the *input* channels (c1) while the convs
    are sized from c2 // scale — this only lines up when c1 == c2 and c1 is
    divisible by scale; verify callers. There is also no initial conv, so
    `out = self.act(x)` feeds the raw (ReLU'd) input into the splits.
    """

    def __init__(self, c1, c2, s=1, scale=4):
        super().__init__()
        width = c2 // scale  # channels per split branch
        self.nums = scale - 1  # number of conv branches (last split is passthrough)
        convs = []
        bns = []
        for i in range(self.nums):
            convs.append(nn.Conv2d(width, width, 3, s, 1, bias=False))
            bns.append(nn.BatchNorm2d(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.act = nn.ReLU(inplace=True)
        # 1x1 projection for the residual when the shape changes.
        self.downsample = nn.Conv2d(c1, c2, 1, s, bias=False) if c1 != c2 or s != 1 else None

    def forward(self, x):
        residual = x
        out = self.act(x)
        spx = torch.split(out, out.size(1) // (self.nums + 1), dim=1)
        for i in range(self.nums):
            # Hierarchical connection: each branch adds the previous branch's output.
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.convs[i](sp)
            sp = self.act(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)
        # The last split passes through untouched (multi-scale identity path).
        out = torch.cat((out, spx[self.nums]), 1)
        if self.downsample is not None:
            residual = self.downsample(residual)
        return out + residual
# 34. Dilated Convolution Block
class DilatedConv(nn.Module):
    """Parallel dilated convs concatenated channel-wise.

    NOTE(review): c2 should be divisible by len(dilations) or the BN width
    will not match the concatenated channels.
    """

    def __init__(self, c1, c2, k=3, dilations=[1, 2, 4]):
        super().__init__()
        self.convs = nn.ModuleList([
            nn.Conv2d(c1, c2 // len(dilations), k, padding=d, dilation=d)
            for d in dilations
        ])
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU()

    def forward(self, x):
        branches = [conv(x) for conv in self.convs]
        return self.act(self.bn(torch.cat(branches, dim=1)))
# 35. ASPP (Atrous Spatial Pyramid Pooling)
class ASPP(nn.Module):
    """ASPP: a 1x1 branch, three dilated 3x3 branches and a global-pool branch,
    concatenated and projected back to c2 channels.

    Fix: the projection conv expected c2 input channels, but the five
    branches concatenate to 5 * (c2 // 4) channels, so forward() raised for
    every input; the projection now matches the concatenated width.
    """

    def __init__(self, c1, c2, rates=[6, 12, 18]):
        super().__init__()
        branch_c = c2 // 4  # width of each branch
        self.conv1 = nn.Conv2d(c1, branch_c, 1)
        self.conv2 = nn.Conv2d(c1, branch_c, 3, padding=rates[0], dilation=rates[0])
        self.conv3 = nn.Conv2d(c1, branch_c, 3, padding=rates[1], dilation=rates[1])
        self.conv4 = nn.Conv2d(c1, branch_c, 3, padding=rates[2], dilation=rates[2])
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(c1, branch_c, 1)
        )
        self.project = nn.Sequential(
            nn.Conv2d(branch_c * 5, c2, 1),  # five branches concatenated
            nn.BatchNorm2d(c2),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        feats = [self.conv1(x), self.conv2(x), self.conv3(x), self.conv4(x)]
        # Global context branch, upsampled back to the input resolution.
        pooled = F.interpolate(self.global_avg_pool(x), size=x.shape[2:],
                               mode='bilinear', align_corners=True)
        feats.append(pooled)
        return self.project(torch.cat(feats, dim=1))
# 36. RFB (Receptive Field Block)
class RFB(nn.Module):
    """Receptive Field Block: parallel branches with growing dilation, fused by 1x1.

    Branch outputs (3 * inter_planes channels) are concatenated, projected to
    c2, scaled by 0.1 and added to a 1x1 shortcut before the final ReLU.
    """

    def __init__(self, c1, c2, stride=1, map_reduce=8):
        super().__init__()
        inter_planes = c1 // map_reduce  # per-branch width
        # Branch 0: plain 1x1 (smallest receptive field).
        self.branch0 = nn.Sequential(
            nn.Conv2d(c1, inter_planes, 1, stride, 0),
            nn.BatchNorm2d(inter_planes),
            nn.ReLU(inplace=True)
        )
        # Branch 1: 1x1 reduce then 3x3 (dilation 1).
        self.branch1 = nn.Sequential(
            nn.Conv2d(c1, inter_planes, 1, 1, 0),
            nn.BatchNorm2d(inter_planes),
            nn.ReLU(inplace=True),
            nn.Conv2d(inter_planes, inter_planes, 3, stride, 1, dilation=1),
            nn.BatchNorm2d(inter_planes),
            nn.ReLU(inplace=True)
        )
        # Branch 2: 1x1, 3x3, then dilated 3x3 (largest receptive field).
        self.branch2 = nn.Sequential(
            nn.Conv2d(c1, inter_planes, 1, 1, 0),
            nn.BatchNorm2d(inter_planes),
            nn.ReLU(inplace=True),
            nn.Conv2d(inter_planes, inter_planes, 3, 1, 1),
            nn.BatchNorm2d(inter_planes),
            nn.ReLU(inplace=True),
            nn.Conv2d(inter_planes, inter_planes, 3, stride, 3, dilation=3),
            nn.BatchNorm2d(inter_planes),
            nn.ReLU(inplace=True)
        )
        # Linear fusion of the concatenated branches.
        self.ConvLinear = nn.Sequential(
            nn.Conv2d(inter_planes * 3, c2, 1, 1, 0),
            nn.BatchNorm2d(c2)
        )
        self.shortcut = nn.Sequential(
            nn.Conv2d(c1, c2, 1, stride, 0),
            nn.BatchNorm2d(c2)
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        out = torch.cat((x0, x1, x2), 1)
        out = self.ConvLinear(out)
        short = self.shortcut(x)
        # Small residual scale (0.1) as in the RFB-Net design.
        out = out * 0.1 + short
        return self.relu(out)
# 37. DCN (Deformable Convolution v2)
class DCN(nn.Module):
    """Deformable conv: learned offsets + modulation mask feeding torchvision's op.

    Fix: `deform_conv2d` was called without stride/padding, so its defaults
    (stride=1, padding=0) disagreed with the offset/mask map resolution
    whenever s != 1 or p != 0; they are now forwarded explicitly.
    Requires `torchvision` to be importable at call time.
    """

    def __init__(self, c1, c2, k=3, s=1, p=1):
        super().__init__()
        self.conv_offset = nn.Conv2d(c1, 2 * k * k, k, s, p)  # (dy, dx) per kernel tap
        self.conv_mask = nn.Conv2d(c1, k * k, k, s, p)        # modulation scalars in (0, 1)
        self.conv = nn.Conv2d(c1, c2, k, s, p)                # carries the weight/bias
        self.stride = s
        self.padding = p

    def forward(self, x):
        offset = self.conv_offset(x)
        mask = torch.sigmoid(self.conv_mask(x))
        return torchvision.ops.deform_conv2d(
            x, offset, self.conv.weight, self.conv.bias,
            stride=self.stride, padding=self.padding, mask=mask,
        )
# 38. CARAFE (Content-Aware ReAssembly of FEatures)
class CARAFE(nn.Module):
    """Content-aware upsampling: per-location reassembly kernels predicted from x.

    Fix: the predicted kernels were kept at input resolution with a separate
    s*s axis and then broadcast against a 5-D neighbour tensor — a shape
    mismatch that raised at runtime. Kernels are now pixel-shuffled onto the
    upsampled grid and softmax-normalised over the k*k taps, matching the
    CARAFE paper's reassembly step.
    """

    def __init__(self, c1, c2, scale_factor=2, up_kernel=5, encoder_kernel=3):
        super().__init__()
        self.scale_factor = scale_factor
        self.up_kernel = up_kernel
        # Predicts k*k kernels for each of the s*s sub-positions per location.
        self.encoder = nn.Conv2d(c1, up_kernel * up_kernel * scale_factor * scale_factor,
                                 encoder_kernel, 1, encoder_kernel // 2)
        self.unfold = nn.Unfold(kernel_size=up_kernel, padding=up_kernel // 2)

    def forward(self, x):
        n, c, h, w = x.size()
        s, k2 = self.scale_factor, self.up_kernel * self.up_kernel
        # Kernels on the upsampled grid: (n, k*k, s*h, s*w), normalised per tap.
        kernel = F.pixel_shuffle(self.encoder(x), s)
        kernel = F.softmax(kernel, dim=1)
        up = F.interpolate(x, scale_factor=s, mode='nearest')
        neighbours = self.unfold(up).view(n, c, k2, s * h, s * w)
        return (neighbours * kernel.unsqueeze(1)).sum(dim=2)
# 39. DySample (Dynamic Sampling)
class DySample(nn.Module):
    """Upsampling by predicting per-pixel sampling offsets, then grid_sample.

    Fix: the base sampling grid was built at the *input* resolution (h, w)
    while the predicted offsets live at the *output* resolution (s*h, s*w),
    so the addition failed for any scale > 1. The base grid is now built on
    the output grid, mapped into input pixel coordinates.
    """

    def __init__(self, c1, scale=2, style='lp'):
        super().__init__()
        self.scale = scale
        self.style = style  # kept for interface parity; not branched on
        self.offset_conv = nn.Sequential(
            nn.Conv2d(c1, 2 * scale * scale, 1),
            nn.PixelShuffle(scale)  # -> (b, 2, s*h, s*w)
        )

    def forward(self, x):
        offset = self.offset_conv(x)
        return self.sample(x, offset)

    def sample(self, x, offset):
        b, c, h, w = x.shape
        out_h, out_w = offset.shape[2], offset.shape[3]
        offset = offset.permute(0, 2, 3, 1).contiguous()  # (b, s*h, s*w, 2)
        # Base grid: output pixel centres mapped back to input coordinates.
        ys = (torch.arange(out_h, device=x.device, dtype=x.dtype) + 0.5) / self.scale - 0.5
        xs = (torch.arange(out_w, device=x.device, dtype=x.dtype) + 0.5) / self.scale - 0.5
        gy, gx = torch.meshgrid(ys, xs, indexing='ij')
        grid = torch.stack([gx, gy], dim=-1).unsqueeze(0).repeat(b, 1, 1, 1) + offset
        # Normalise to [-1, 1] for grid_sample (align_corners=True convention).
        gx_n = grid[..., 0] / max(w - 1, 1) * 2 - 1
        gy_n = grid[..., 1] / max(h - 1, 1) * 2 - 1
        norm_grid = torch.stack([gx_n, gy_n], dim=-1)
        return F.grid_sample(x, norm_grid, mode='bilinear',
                             padding_mode='zeros', align_corners=True)
# 40. SPD-Conv (Space-to-Depth)
class SPDConv(nn.Module):
    """Space-to-depth (2x2) rearrangement followed by a conv — halves the
    resolution without discarding information the way striding does."""

    def __init__(self, c1, c2, k=3, s=1):
        super().__init__()
        self.conv = nn.Conv2d(c1 * 4, c2, k, s, k // 2)

    def forward(self, x):
        quadrants = [
            x[..., ::2, ::2],
            x[..., 1::2, ::2],
            x[..., ::2, 1::2],
            x[..., 1::2, 1::2],
        ]
        return self.conv(torch.cat(quadrants, dim=1))
轻量化设计 (41-45)
Python
# 41. MobileNetV3 Block (inverted residual with optional SE)
class MobileNetV3Block(nn.Module):
    """Expand -> depthwise -> (SE) -> project, with a 1x1 shortcut when c1 != c2.

    Depends on the `SE` module defined earlier in this file.
    NOTE(review): the shortcut is always added (there is no stride handling),
    and the extra Hardswish placed before SE is unusual for MobileNetV3 —
    verify intent against the reference architecture.
    """

    def __init__(self, c1, c2, exp_ratio=4, k=3, se=True):
        super().__init__()
        exp = c1 * exp_ratio  # expanded width
        self.conv = nn.Sequential(
            nn.Conv2d(c1, exp, 1, bias=False),
            nn.BatchNorm2d(exp),
            nn.Hardswish(),
            nn.Conv2d(exp, exp, k, 1, k//2, groups=exp, bias=False),  # depthwise
            nn.BatchNorm2d(exp),
            nn.Hardswish() if se else nn.Identity(),
            SE(exp) if se else nn.Identity(),
            nn.Conv2d(exp, c2, 1, bias=False),  # linear projection
            nn.BatchNorm2d(c2)
        )
        self.shortcut = nn.Conv2d(c1, c2, 1) if c1 != c2 else nn.Identity()

    def forward(self, x):
        return self.conv(x) + self.shortcut(x)
# 42. ShuffleNetV2 Block
class ShuffleV2Block(nn.Module):
    """ShuffleNetV2 unit: split/concat with depthwise convs, then channel shuffle.

    stride == 1: the input is split in half and only one half passes through
    branch2; stride > 1: both branches downsample and are concatenated.
    Relies on the module-level `channel_shuffle` helper defined below.
    """

    def __init__(self, c1, c2, stride):
        super().__init__()
        self.stride = stride
        branch_features = c2 // 2
        if self.stride > 1:
            # Downsampling branch: depthwise stride-s conv + pointwise.
            self.branch1 = nn.Sequential(
                nn.Conv2d(c1, c1, 3, stride, 1, groups=c1, bias=False),
                nn.BatchNorm2d(c1),
                nn.Conv2d(c1, branch_features, 1, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True)
            )
        else:
            self.branch1 = nn.Sequential()
        self.branch2 = nn.Sequential(
            nn.Conv2d(c1 if stride > 1 else branch_features, branch_features, 1, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            nn.Conv2d(branch_features, branch_features, 3, stride, 1, groups=branch_features, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, 1, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        if self.stride == 1:
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
        # Interleave the channels of the two branches.
        return channel_shuffle(out, 2)
def channel_shuffle(x, groups):
    """Interleave channels across `groups` (ShuffleNetV2 channel shuffle)."""
    b, c, h, w = x.size()
    per_group = c // groups
    shuffled = x.view(b, groups, per_group, h, w).transpose(1, 2).contiguous()
    return shuffled.view(b, c, h, w)
# 43. EfficientNet (MBConv) Block
class EfficientNetBlock(nn.Module):
    """MBConv: expand -> depthwise -> SE -> project, residual when shape-preserving.

    Depends on the `SE` module defined earlier in this file.
    NOTE(review): `SE(hidden_dim, int(1/se_ratio))` turns se_ratio=0.25 into
    reduction=4 relative to the *expanded* width — confirm this matches the
    intended EfficientNet SE sizing (which reduces relative to the input width).
    """

    def __init__(self, c1, c2, expand_ratio, k=3, stride=1, se_ratio=0.25):
        super().__init__()
        hidden_dim = round(c1 * expand_ratio)  # expanded width
        self.conv = nn.Sequential(
            nn.Conv2d(c1, hidden_dim, 1, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.SiLU(),
            nn.Conv2d(hidden_dim, hidden_dim, k, stride, k//2, groups=hidden_dim, bias=False),  # depthwise
            nn.BatchNorm2d(hidden_dim),
            nn.SiLU(),
            SE(hidden_dim, int(1/se_ratio)) if se_ratio else nn.Identity(),
            nn.Conv2d(hidden_dim, c2, 1, bias=False),  # linear projection
            nn.BatchNorm2d(c2)
        )
        # Residual only when spatial and channel shapes are preserved.
        self.shortcut = nn.Sequential() if stride == 1 and c1 == c2 else None

    def forward(self, x):
        if self.shortcut is not None:
            return x + self.conv(x)
        else:
            return self.conv(x)
# 44. RepVGG Block (multi-branch at train time; single conv at deploy)
class RepVGGBlock(nn.Module):
    """3x3 + 1x1 + identity(BN) branches summed and ReLU'd; re-parameterisable."""

    def __init__(self, c1, c2, k=3, s=1, padding=1, deploy=False):
        super().__init__()
        self.deploy = deploy
        self.nonlinearity = nn.ReLU()
        if deploy:
            # Single fused conv used after re-parameterisation.
            self.rbr_reparam = nn.Conv2d(c1, c2, k, s, padding, bias=True)
        else:
            # Identity branch only exists when shapes are preserved.
            self.rbr_identity = nn.BatchNorm2d(c1) if c2 == c1 and s == 1 else None
            self.rbr_dense = nn.Sequential(
                nn.Conv2d(c1, c2, k, s, padding, bias=False),
                nn.BatchNorm2d(c2)
            )
            self.rbr_1x1 = nn.Sequential(
                nn.Conv2d(c1, c2, 1, s, 0, bias=False),
                nn.BatchNorm2d(c2)
            )

    def forward(self, x):
        if self.deploy:
            return self.nonlinearity(self.rbr_reparam(x))
        identity = 0 if self.rbr_identity is None else self.rbr_identity(x)
        return self.nonlinearity(self.rbr_dense(x) + self.rbr_1x1(x) + identity)
# 45. FasterNet Block
class FasterNetBlock(nn.Module):
    """PConv followed by a 1x1 expansion, with an optional residual add.

    Depends on `PConv` and the project-level `Conv` block.
    """

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        super().__init__()
        c_ = int(c2 * e)  # hidden width
        self.cv1 = PConv(c1, c_, 3, 1)
        self.cv2 = Conv(c_, c2, 1)
        self.add = shortcut and c1 == c2  # residual only when shapes match

    def forward(self, x):
        out = self.cv2(self.cv1(x))
        return x + out if self.add else out
🔄 三、特征融合模块 (46-65)
Neck结构改进 (46-55)
Python
# 46. BiFPN (Bidirectional FPN)
class BiFPN(nn.Module):
    """Top-down feature fusion with per-level 1x1 + 3x3 convs.

    Depends on the project-level `Conv` block (conv + BN + act).
    NOTE(review): despite the name, only the top-down path is implemented —
    a true BiFPN also has a bottom-up pass and learned fusion weights.
    Assumes `inputs` is ordered coarsest-first so upsampling to the next
    level's size is meaningful.
    """

    def __init__(self, channels_list, num_outs=5):
        super().__init__()
        self.num_outs = num_outs
        self.bifpn_convs = nn.ModuleList()
        for i in range(num_outs - 1):
            self.bifpn_convs.append(
                nn.Sequential(
                    Conv(channels_list[i] + channels_list[i+1], channels_list[i+1], 1),
                    Conv(channels_list[i+1], channels_list[i+1], 3)
                )
            )

    def forward(self, inputs):
        outs = [inputs[0]]
        for i in range(len(inputs) - 1):
            # Upsample the running output to the next level and fuse by concat.
            up = F.interpolate(outs[-1], size=inputs[i+1].shape[2:], mode='nearest')
            out = torch.cat([up, inputs[i+1]], dim=1)
            out = self.bifpn_convs[i](out)
            outs.append(out)
        return outs
# 47. NAS-FPN (greatly simplified)
class NASFPN(nn.Module):
    """Repeatedly merges the two deepest features and appends the result.

    Depends on the project-level `Conv` block.
    NOTE(review): this mutates the caller's `features` list in place, and
    assumes the two deepest maps share spatial size and channel count c1.
    """

    def __init__(self, c1, c2, num_layers=7):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(self._build_merge_module(c1, c2))

    def _build_merge_module(self, c1, c2):
        # Concat-merge cell: 1x1 reduce, 3x3 refine, then BN + ReLU.
        return nn.Sequential(
            Conv(c1 * 2, c2, 1),
            Conv(c2, c2, 3),
            nn.BatchNorm2d(c2),
            nn.ReLU()
        )

    def forward(self, features):
        # Simplified NAS-FPN structure
        for layer in self.layers:
            new_feat = layer(torch.cat([features[-2], features[-1]], dim=1))
            features.append(new_feat)
        return features[-3:]
# 48. ASFF (Adaptively Spatial Feature Fusion)
class ASFF(nn.Module):
    """Fuses three pyramid levels with learned per-pixel softmax weights.

    Depends on the project-level `Conv` block.
    `level` selects which pyramid level the output is aligned to
    (0 = deepest). Channel counts are hard-coded in `self.dim`.
    NOTE(review): `self.dim = [512, 256, 128]` is listed as "P3, P4, P5" but
    512 channels usually corresponds to the deepest level — confirm the
    ordering matches the caller's feature list.
    """

    def __init__(self, level, rfb=False, vis=False):
        super().__init__()
        self.level = level
        self.dim = [512, 256, 128]  # P3, P4, P5 channels
        self.inter_dim = self.dim[self.level]
        # Resize/compress convs so all three inputs reach inter_dim channels
        # and the target level's spatial size.
        if level == 0:
            self.stride_level_1 = Conv(256, self.inter_dim, 3, 2)
            self.stride_level_2 = Conv(128, self.inter_dim, 3, 2)
        elif level == 1:
            self.compress_level_0 = Conv(512, self.inter_dim, 1, 1)
            self.stride_level_2 = Conv(128, self.inter_dim, 3, 2)
        elif level == 2:
            self.compress_level_0 = Conv(512, self.inter_dim, 1, 1)
            self.compress_level_1 = Conv(256, self.inter_dim, 1, 1)
        # Per-level weight heads (2 channels each) + a 6->3 fusion conv.
        self.weight_level_0 = Conv(self.inter_dim, 2, 1, 1)
        self.weight_level_1 = Conv(self.inter_dim, 2, 1, 1)
        self.weight_level_2 = Conv(self.inter_dim, 2, 1, 1)
        self.weights_levels = Conv(6, 3, 1, 1)

    def forward(self, x_level_0, x_level_1, x_level_2):
        # Bring all three inputs to the target level's resolution/channels.
        if self.level == 0:
            level_0_resized = x_level_0
            level_1_resized = self.stride_level_1(x_level_1)
            level_2_downsampled_inter = F.max_pool2d(x_level_2, 3, stride=2, padding=1)
            level_2_resized = self.stride_level_2(level_2_downsampled_inter)
        elif self.level == 1:
            level_0_compressed = self.compress_level_0(x_level_0)
            level_0_resized = F.interpolate(level_0_compressed, scale_factor=2, mode='nearest')
            level_1_resized = x_level_1
            level_2_resized = self.stride_level_2(x_level_2)
        elif self.level == 2:
            level_0_compressed = self.compress_level_0(x_level_0)
            level_0_resized = F.interpolate(level_0_compressed, scale_factor=4, mode='nearest')
            level_1_compressed = self.compress_level_1(x_level_1)
            level_1_resized = F.interpolate(level_1_compressed, scale_factor=2, mode='nearest')
            level_2_resized = x_level_2
        # Per-pixel fusion weights, softmaxed over the 3 levels.
        level_0_weight_v = self.weight_level_0(level_0_resized)
        level_1_weight_v = self.weight_level_1(level_1_resized)
        level_2_weight_v = self.weight_level_2(level_2_resized)
        levels_weight_v = torch.cat((level_0_weight_v, level_1_weight_v, level_2_weight_v), 1)
        levels_weight = self.weights_levels(levels_weight_v)
        levels_weight = F.softmax(levels_weight, dim=1)
        fused_out_reduced = level_0_resized * levels_weight[:, 0:1, :, :] + \
            level_1_resized * levels_weight[:, 1:2, :, :] + \
            level_2_resized * levels_weight[:, 2:, :, :]
        return fused_out_reduced
# 49. PAFPN (Path Aggregation FPN)
class PAFPN(nn.Module):
    """FPN with an extra bottom-up aggregation path (PANet style).

    Fix vs. original: ``pafpn_convs`` holds ``num_outs - 1`` modules, but the
    original forward indexed it once per level and raised IndexError on the
    last level for every call. Following the reference PAFPN, the first level
    passes through unchanged and the remaining levels use ``pafpn_convs``.
    """

    def __init__(self, in_channels, out_channels, num_outs=3):
        super().__init__()
        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        self.downsample_convs = nn.ModuleList()
        self.pafpn_convs = nn.ModuleList()
        for i in range(num_outs):
            self.lateral_convs.append(Conv(in_channels[i], out_channels, 1))
            self.fpn_convs.append(Conv(out_channels, out_channels, 3))
        for _ in range(num_outs - 1):
            # Stride passed positionally, consistent with e.g. ASFF's Conv calls.
            self.downsample_convs.append(Conv(out_channels, out_channels, 3, 2))
            self.pafpn_convs.append(Conv(out_channels, out_channels, 3))

    def forward(self, inputs):
        laterals = [conv(x) for conv, x in zip(self.lateral_convs, inputs)]
        # Top-down pathway: add the upsampled coarser level into each finer one.
        for i in range(len(laterals) - 1, 0, -1):
            prev_shape = laterals[i - 1].shape[2:]
            laterals[i - 1] = laterals[i - 1] + F.interpolate(laterals[i], size=prev_shape, mode='nearest')
        inter_outs = [conv(lat) for conv, lat in zip(self.fpn_convs, laterals)]
        # Bottom-up pathway: add the downsampled finer level into each coarser one.
        for i in range(len(inter_outs) - 1):
            inter_outs[i + 1] = inter_outs[i + 1] + self.downsample_convs[i](inter_outs[i])
        return [inter_outs[0]] + [self.pafpn_convs[i](inter_outs[i + 1]) for i in range(len(inter_outs) - 1)]
# 50. HRFPN (High Resolution FPN)
class HRFPN(nn.Module):
    """Top-down chain: each level adds the upsampled previous output, then refines."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1x1 = nn.ModuleList([Conv(c, out_channels, 1) for c in in_channels])
        self.conv3x3 = nn.ModuleList([Conv(out_channels, out_channels, 3) for _ in in_channels])

    def forward(self, x):
        outs = []
        for i, (project, refine) in enumerate(zip(self.conv1x1, self.conv3x3)):
            feat = project(x[i])
            if outs:
                feat = feat + F.interpolate(outs[-1], size=x[i].shape[2:], mode='nearest')
            outs.append(refine(feat))
        return outs
# 51. FPT (Feature Pyramid Transformer)
class FPT(nn.Module):
    """Simplified Feature Pyramid Transformer.

    Applies a non-local block to the deepest level (self-transformer) and
    cross-attention from each shallower level onto the deepest one (grounding
    transformer). Assumes all levels share `channels` channels.
    NOTE(review): `NonLocal2d` is defined elsewhere in this file/project.
    """

    def __init__(self, channels):
        super().__init__()
        self.non_local = NonLocal2d(channels)
        self.scaled_dot_product = nn.MultiheadAttention(channels, 8)

    def forward(self, feats):
        # Self-Transformer on the deepest feature map.
        self_trans = self.non_local(feats[-1])
        # Grounding Transformer: each shallower level queries the deepest.
        ground_trans = []
        for i in range(len(feats) - 1):
            # Flatten NCHW -> (HW, N, C): nn.MultiheadAttention default is seq-first.
            q = feats[i].flatten(2).permute(2, 0, 1)
            k = feats[-1].flatten(2).permute(2, 0, 1)
            v = feats[-1].flatten(2).permute(2, 0, 1)
            attn_out, _ = self.scaled_dot_product(q, k, v)
            # Restore (N, C, H, W) of the querying level.
            ground_trans.append(attn_out.permute(1, 2, 0).view_as(feats[i]))
        return ground_trans + [self_trans]
# 52. CFP (Centralized Feature Pyramid)
class CFP(nn.Module):
    """FPN variant gated by a shared, globally pooled channel attention.

    Fix vs. original: `channels` was iterated as a list while simultaneously
    being used as the int output width of every conv, which cannot both hold.
    It is now the per-level input channel list (or a single int used for all
    levels); every level is projected to the width of the last entry.
    """

    def __init__(self, channels, num_levels=3):
        super().__init__()
        if isinstance(channels, int):
            in_channels = [channels] * num_levels
        else:
            in_channels = list(channels)
            num_levels = len(in_channels)
        out_channels = in_channels[-1]  # unify on the deepest level's width
        self.lateral_convs = nn.ModuleList(Conv(c, out_channels, 1) for c in in_channels)
        self.fpn_convs = nn.ModuleList(Conv(out_channels, out_channels, 3) for _ in range(num_levels))
        # Squeeze-excite style gate computed from the summed pooled laterals.
        self.global_attn = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(out_channels, out_channels // 4, 1),
            nn.ReLU(),
            nn.Conv2d(out_channels // 4, out_channels, 1),
            nn.Sigmoid()
        )

    def forward(self, inputs):
        laterals = [conv(inp) for conv, inp in zip(self.lateral_convs, inputs)]
        # Centralized context: pool every level to 1x1 and sum.
        global_feat = sum(F.adaptive_avg_pool2d(lat, 1) for lat in laterals)
        global_attn = self.global_attn(global_feat)
        outs = []
        for i, lateral in enumerate(laterals):
            if i < len(laterals) - 1:
                # Standard top-down addition from the next (coarser) level.
                upsampled = F.interpolate(laterals[i + 1], size=lateral.shape[2:], mode='nearest')
                lateral = lateral + upsampled
            outs.append(self.fpn_convs[i](lateral * global_attn))
        return outs
# 53. AFPN (Asymptotic FPN)
class AFPN(nn.Module):
    """Progressive fusion: each stage folds in the upsampled previous output."""

    def __init__(self, in_channels, out_channels, num_outs=3):
        super().__init__()
        self.asymptotic_fusion = nn.ModuleList(
            nn.Sequential(
                Conv(in_channels[i], out_channels, 1),
                Conv(out_channels, out_channels, 3),
            )
            for i in range(num_outs)
        )

    def forward(self, inputs):
        outs = []
        for feat, stage in zip(inputs, self.asymptotic_fusion):
            if outs:
                feat = feat + F.interpolate(outs[-1], size=feat.shape[2:], mode='nearest')
            outs.append(stage(feat))
        return outs
# 54. GFPN (Generalized FPN)
class GFPN(nn.Module):
    """Generalized FPN: successive top-down concat-and-conv fusion nodes."""

    def __init__(self, channels_list):
        super().__init__()
        self.nodes = nn.ModuleList(
            nn.Sequential(
                Conv(channels_list[i] + channels_list[i + 1], channels_list[i + 1], 1),
                Conv(channels_list[i + 1], channels_list[i + 1], 3),
            )
            for i in range(len(channels_list) - 1)
        )

    def forward(self, features):
        # Fix: work on a copy instead of overwriting the caller's list in place
        # (the original rebound `features[i+1]` as a hidden side effect).
        feats = list(features)
        for i, node in enumerate(self.nodes):
            up = F.interpolate(feats[i], size=feats[i + 1].shape[2:], mode='nearest')
            feats[i + 1] = node(torch.cat([up, feats[i + 1]], dim=1))
        return feats
# 55. C3k2_Improved (improved C3k2)
class C3k2_Improved(nn.Module):
    """CSP-style block: split into two, run n Bottlenecks on one half, fuse all.

    Args:
        c1, c2: input/output channels.
        n: number of Bottleneck units.
        shortcut, g, e, k: residual add, groups, hidden expansion, kernel size;
            forwarded to Bottleneck with k=(k, k).
    NOTE(review): `Bottleneck` is defined elsewhere in this file/project.
    """

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, k=3):
        super().__init__()
        self.c = int(c2 * e)  # hidden width per split
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv((2 + n) * self.c, c2, 1)
        self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=(k, k), e=1.0) for _ in range(n))

    def forward(self, x):
        # Two-way split; append each Bottleneck output and fuse with a 1x1 conv.
        y = list(self.cv1(x).chunk(2, 1))
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))
跨尺度融合 (56-65)
Python
# 56. FRM (Feature Refinement Module)
class FRM(nn.Module):
    """Adds an upsampled high-level feature to a low-level one, then refines."""

    def __init__(self, c1, c2):
        super().__init__()
        self.conv_high = Conv(c1, c2, 1)
        self.conv_low = Conv(c1, c2, 1)
        self.conv_out = Conv(c2, c2, 3)

    def forward(self, x_high, x_low):
        high = self.conv_high(F.interpolate(x_high, size=x_low.shape[2:], mode='nearest'))
        low = self.conv_low(x_low)
        return self.conv_out(high + low)
# 57. FFM (Feature Fusion Module)
class FFM(nn.Module):
    """Concat-fuses several same-width inputs, then applies ECA channel attention."""

    def __init__(self, c1, c2, num_inputs=2):
        super().__init__()
        self.conv = Conv(c1 * num_inputs, c2, 1)
        self.attn = ECA(c2)

    def forward(self, *inputs):
        # Align every input to the first input's spatial size.
        ref = inputs[0].shape[2:]
        aligned = []
        for inp in inputs:
            if inp.shape[2:] != ref:
                inp = F.interpolate(inp, size=ref, mode='nearest')
            aligned.append(inp)
        return self.attn(self.conv(torch.cat(aligned, dim=1)))
# 58. CCFM (Cross-Scale Feature Fusion)
class CCFM(nn.Module):
    """Resizes every level to the middle level's size and fuses by concat.

    Fix vs. original: `channels` was used both as a list (iterated to build the
    1x1 convs) and as an int (conv output width), which cannot both hold. It is
    now the per-level input channel list; all levels are projected to the width
    of the last entry before fusion.
    """

    def __init__(self, channels):
        super().__init__()
        in_channels = [channels] if isinstance(channels, int) else list(channels)
        out_channels = in_channels[-1]
        self.convs = nn.ModuleList(Conv(c, out_channels, 1) for c in in_channels)
        self.fusion = Conv(out_channels * len(in_channels), out_channels, 1)

    def forward(self, features):
        # Resize all levels to the middle level's spatial size.
        target = features[len(features) // 2]
        resized = [F.interpolate(self.convs[i](f), size=target.shape[2:], mode='nearest')
                   for i, f in enumerate(features)]
        return self.fusion(torch.cat(resized, dim=1))
# 59. SFF (Selective Feature Fusion)
class SFF(nn.Module):
    """Blends two features using global-pooled softmax selection weights."""

    def __init__(self, c1, c2):
        super().__init__()
        self.conv1 = Conv(c1, c2, 1)
        self.conv2 = Conv(c1, c2, 1)
        self.selector = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(c2 * 2, 2, 1),
            nn.Softmax(dim=1)
        )

    def forward(self, x1, x2):
        a = self.conv1(x1)
        b = F.interpolate(self.conv2(x2), size=a.shape[2:], mode='nearest')
        # Two scalar weights per sample, broadcast over space and channels.
        w = self.selector(torch.cat([a, b], dim=1))
        return a * w[:, 0:1] + b * w[:, 1:2]
# 60. MSFM (Multi-Scale Fusion Module)
class MSFM(nn.Module):
    """Fuses parallel branches computed at several pooled scales.

    Fixes vs. original: the mutable-list default argument is replaced with a
    tuple, and branch widths are chosen so their concatenation is exactly `c2`
    channels even when `c2` is not divisible by the number of scales
    (previously the fusion conv's input width was wrong in that case).
    """

    def __init__(self, c1, c2, scales=(1, 2, 4)):
        super().__init__()
        scales = tuple(scales)
        base = c2 // len(scales)
        # Last branch absorbs the remainder so the concat width is exactly c2.
        widths = [base] * (len(scales) - 1) + [c2 - base * (len(scales) - 1)]
        self.branches = nn.ModuleList(
            nn.Sequential(
                nn.AvgPool2d(scale, stride=scale) if scale > 1 else nn.Identity(),
                Conv(c1, w, 1),
                nn.Upsample(scale_factor=scale, mode='nearest') if scale > 1 else nn.Identity(),
            )
            for scale, w in zip(scales, widths)
        )
        self.fusion = Conv(c2, c2, 1)

    def forward(self, x):
        # Branch outputs may differ by a pixel when the input size is not a
        # multiple of the pooling scale; resize all back to x's size.
        outs = [F.interpolate(branch(x), size=x.shape[2:], mode='nearest') for branch in self.branches]
        return self.fusion(torch.cat(outs, dim=1))
# 61. A2C2f_Improved (improved A2C2f)
class A2C2f_Improved(nn.Module):
    """C2f-style block whose inner units are A2 attention modules.

    `shortcut` and `g` are accepted for signature parity but unused here.
    NOTE(review): `A2` is defined elsewhere in this file/project.
    """

    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
        super().__init__()
        self.c = int(c2 * e)  # hidden width per split
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv((2 + n) * self.c, c2, 1)
        self.m = nn.ModuleList(A2(self.c, self.c) for _ in range(n))

    def forward(self, x):
        # Two-way split; append each A2 output and fuse everything with a 1x1 conv.
        y = list(self.cv1(x).chunk(2, 1))
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))
# 62. SimpleFusion (minimal fusion)
class SimpleFusion(nn.Module):
    """Upsamples x1 by 2x, concatenates with x2, fuses with a 3x3 conv."""

    def __init__(self, c1, c2):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv = Conv(c1 * 2, c2, 3)

    def forward(self, x1, x2):
        return self.conv(torch.cat([self.up(x1), x2], dim=1))
# 63. DenseFusion (densely connected fusion)
class DenseFusion(nn.Module):
    """Each level is fused with ALL previous outputs (resized), DenseNet style."""

    def __init__(self, channels_list):
        super().__init__()
        self.convs = nn.ModuleList(
            Conv(sum(channels_list[:i + 1]), channels_list[i], 1)
            for i in range(len(channels_list))
        )

    def forward(self, features):
        outs = [features[0]]
        for i, feat in enumerate(features[1:], start=1):
            prev = [F.interpolate(o, size=feat.shape[2:], mode='nearest') for o in outs]
            outs.append(self.convs[i](torch.cat(prev + [feat], dim=1)))
        return outs
# 64. AttentionFusion (attention-gated fusion)
class AttentionFusion(nn.Module):
    """Blends two features with a learned sigmoid gate: a*g + b*(1-g)."""

    def __init__(self, c1, c2):
        super().__init__()
        self.conv1 = Conv(c1, c2, 1)
        self.conv2 = Conv(c1, c2, 1)
        self.attn = nn.Sequential(
            nn.Conv2d(c2 * 2, c2, 1),
            nn.BatchNorm2d(c2),
            nn.Sigmoid()
        )

    def forward(self, x1, x2):
        a = self.conv1(x1)
        b = F.interpolate(self.conv2(x2), size=a.shape[2:], mode='nearest')
        gate = self.attn(torch.cat([a, b], dim=1))
        return a * gate + b * (1 - gate)
Python
# 65. ScaleAdaptiveFusion (scale-adaptive fusion)
class ScaleAdaptiveFusion(nn.Module):
    """Fuses strided multi-scale views of x with learned softmax scale weights."""

    def __init__(self, c1, c2, num_scales=3):
        super().__init__()
        self.scale_convs = nn.ModuleList([Conv(c1, c2, 3, stride=2 ** i) for i in range(num_scales)])
        self.fusion = Conv(c2 * num_scales, c2, 1)
        self.scale_weights = nn.Parameter(torch.ones(num_scales))

    def forward(self, x):
        w = F.softmax(self.scale_weights, dim=0)
        weighted = [
            F.interpolate(conv(x), size=x.shape[2:], mode='nearest') * wi.view(1, 1, 1, 1)
            for conv, wi in zip(self.scale_convs, w)
        ]
        return self.fusion(torch.cat(weighted, dim=1))
🎯 四、检测头改进 (66-80)
Python
# 66. Decoupled Head
class DecoupledHead(nn.Module):
    """Decoupled head: separate conv towers for classification and regression."""

    def __init__(self, c1, num_classes=80, anchors=1):
        super().__init__()
        self.stems = nn.Conv2d(c1, c1, 1)
        self.cls_convs = nn.Sequential(Conv(c1, c1, 3), Conv(c1, c1, 3))
        self.reg_convs = nn.Sequential(Conv(c1, c1, 3), Conv(c1, c1, 3))
        self.cls_preds = nn.Conv2d(c1, num_classes * anchors, 1)
        self.reg_preds = nn.Conv2d(c1, 4 * anchors, 1)
        self.obj_preds = nn.Conv2d(c1, anchors, 1)

    def forward(self, x):
        stem = self.stems(x)
        cls_feat = self.cls_convs(stem)
        reg_feat = self.reg_convs(stem)
        # Output layout: [box(4), objectness(1), class scores].
        parts = [self.reg_preds(reg_feat), self.obj_preds(reg_feat), self.cls_preds(cls_feat)]
        return torch.cat(parts, 1)
# 67. Anchor-Free Head
class AnchorFreeHead(nn.Module):
    """Anchor-free head; the box branch is exponentiated to stay positive."""

    def __init__(self, c1, num_classes=80):
        super().__init__()
        self.cls_conv = nn.Sequential(
            Conv(c1, c1, 3), Conv(c1, c1, 3), nn.Conv2d(c1, num_classes, 1)
        )
        self.reg_conv = nn.Sequential(
            Conv(c1, c1, 3), Conv(c1, c1, 3), nn.Conv2d(c1, 4, 1)
        )

    def forward(self, x):
        return self.cls_conv(x), self.reg_conv(x).exp()
# 68. DyHead (Dynamic Head)
class DyHead(nn.Module):
    """Simplified Dynamic Head: scale-, spatial- and task-aware attentions.

    Output: [box, cls] predictions concatenated along channels.
    """

    def __init__(self, c1, num_classes=80, num_anchors=3):
        super().__init__()
        # Scale-aware: one global scalar gate per feature map.
        self.scale_attn = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(c1, 1, 1),
            nn.Sigmoid()
        )
        # Spatial-aware: single-channel 7x7 gate over locations.
        self.spatial_attn = nn.Sequential(
            nn.Conv2d(c1, 1, 7, padding=3),
            nn.Sigmoid()
        )
        # Task-aware: per-channel gate splitting the cls and reg pathways.
        self.task_attn = nn.Sequential(
            nn.Conv2d(c1, c1, 1),
            nn.Sigmoid()
        )
        self.cls_pred = nn.Conv2d(c1, num_classes * num_anchors, 1)
        self.reg_pred = nn.Conv2d(c1, 4 * num_anchors, 1)

    def forward(self, x):
        # Scale-aware attention
        scale = self.scale_attn(x)
        # Spatial-aware attention
        spatial = self.spatial_attn(x)
        # Task-aware attention
        feat = x * scale * spatial
        task = self.task_attn(feat)
        cls_feat = feat * task        # cls pathway gets the gated share
        reg_feat = feat * (1 - task)  # reg pathway gets the complement
        return torch.cat([self.reg_pred(reg_feat), self.cls_pred(cls_feat)], 1)
# 69. TOOD Head (Task-aligned One-stage Object Detection)
class TOODHead(nn.Module):
    """Simplified TOOD head: shared conv tower + task-decomposition gate.

    Output: [box, cls] predictions concatenated along channels.
    """

    def __init__(self, c1, num_classes=80, num_anchors=1):
        super().__init__()
        self.num_classes = num_classes
        # Shared feature tower applied before task decomposition.
        self.inter_convs = nn.ModuleList([
            Conv(c1, c1, 3) for _ in range(4)
        ])
        # Channel gate splitting shared features into task-specific parts.
        self.task_decomposition = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(c1, c1, 1),
            nn.ReLU(),
            nn.Conv2d(c1, c1, 1),
            nn.Sigmoid()
        )
        self.cls_prob_module = nn.Sequential(
            nn.Conv2d(c1, c1, 3, padding=1),
            nn.BatchNorm2d(c1),
            nn.ReLU(),
            nn.Conv2d(c1, num_classes * num_anchors, 1)
        )
        self.reg_offset_module = nn.Sequential(
            nn.Conv2d(c1, c1, 3, padding=1),
            nn.BatchNorm2d(c1),
            nn.ReLU(),
            nn.Conv2d(c1, 4 * num_anchors, 1)
        )

    def forward(self, x):
        inter_feat = x
        for inter_conv in self.inter_convs:
            inter_feat = inter_conv(inter_feat)
        task_feat = self.task_decomposition(inter_feat)
        # Residual gating: emphasize gated channels for cls, complement for reg.
        cls_feat = inter_feat + task_feat * inter_feat
        reg_feat = inter_feat + (1 - task_feat) * inter_feat
        cls_score = self.cls_prob_module(cls_feat)
        bbox_pred = self.reg_offset_module(reg_feat)
        return torch.cat([bbox_pred, cls_score], 1)
# 70. GFL Head (Generalized Focal Loss)
class GFLHead(nn.Module):
    """Head for Generalized Focal Loss; regression outputs 17 bins per box side."""

    def __init__(self, c1, num_classes=80):
        super().__init__()
        self.cls_convs = nn.Sequential(Conv(c1, c1, 3), Conv(c1, c1, 3))
        self.reg_convs = nn.Sequential(Conv(c1, c1, 3), Conv(c1, c1, 3))
        self.gfl_cls = nn.Conv2d(c1, num_classes, 1)
        self.gfl_reg = nn.Conv2d(c1, 4 * (16 + 1), 1)  # 16 for DFL + 1 for bbox

    def forward(self, x):
        return self.gfl_cls(self.cls_convs(x)), self.gfl_reg(self.reg_convs(x))
# 71. PPYOLOE Head
class PPYOLOEHead(nn.Module):
    """PP-YOLOE head with separate stems and cls/box/IoU prediction branches."""

    def __init__(self, c1, num_classes=80, num_anchors=3):
        super().__init__()
        self.stem_cls = Conv(c1, c1, 1)
        self.stem_reg = Conv(c1, c1, 1)
        self.pred_cls = nn.Sequential(Conv(c1, c1, 3), nn.Conv2d(c1, num_classes * num_anchors, 1))
        self.pred_reg = nn.Sequential(Conv(c1, c1, 3), nn.Conv2d(c1, 4 * num_anchors, 1))
        self.pred_iou = nn.Sequential(Conv(c1, c1, 3), nn.Conv2d(c1, num_anchors, 1))

    def forward(self, x):
        cls_stem = self.stem_cls(x)
        reg_stem = self.stem_reg(x)
        # Output layout: [box, IoU quality, class logits].
        return torch.cat([self.pred_reg(reg_stem), self.pred_iou(reg_stem), self.pred_cls(cls_stem)], 1)
# 72. RT-DETR Head
class RTDETRHead(nn.Module):
    """Simplified RT-DETR decoder head: learned queries attend to flattened features.

    Returns:
        outputs_class: [B, num_queries, num_classes] logits.
        outputs_coord: [B, num_queries, 4] sigmoid-normalized boxes in [0, 1].
    """

    def __init__(self, c1, num_classes=80, hidden_dim=256, num_queries=300):
        super().__init__()
        self.input_proj = nn.Conv2d(c1, hidden_dim, 1)
        self.transformer = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(hidden_dim, 8, batch_first=True),
            num_layers=6
        )
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        self.class_embed = nn.Linear(hidden_dim, num_classes)
        self.bbox_embed = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 4)
        )

    def forward(self, x):
        # Simplified version
        bs, c, h, w = x.shape
        # Project and flatten NCHW -> [B, HW, hidden] (batch_first memory).
        x = self.input_proj(x).flatten(2).permute(0, 2, 1)
        # One copy of the learned queries per batch element.
        query = self.query_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
        tgt = self.transformer(query, x)
        outputs_class = self.class_embed(tgt)
        outputs_coord = self.bbox_embed(tgt).sigmoid()
        return outputs_class, outputs_coord
# 73. YOLOX Head
class YOLOXHead(nn.Module):
    """YOLOX decoupled head with a shared 1x1 stem."""

    def __init__(self, c1, num_classes=80):
        super().__init__()
        self.stem = Conv(c1, c1, 1)
        self.cls_convs = nn.Sequential(Conv(c1, c1, 3), Conv(c1, c1, 3))
        self.reg_convs = nn.Sequential(Conv(c1, c1, 3), Conv(c1, c1, 3))
        self.cls_pred = nn.Conv2d(c1, num_classes, 1)
        self.reg_pred = nn.Conv2d(c1, 4, 1)
        self.obj_pred = nn.Conv2d(c1, 1, 1)

    def forward(self, x):
        stem = self.stem(x)
        cls_feat = self.cls_convs(stem)
        reg_feat = self.reg_convs(stem)
        # Output layout: [box(4), objectness(1), class logits].
        return torch.cat([self.reg_pred(reg_feat), self.obj_pred(reg_feat), self.cls_pred(cls_feat)], 1)
# 74. NanoDet Head
class NanoDetHead(nn.Module):
    """NanoDet GFL-style head; regression outputs (reg_max+1) bins per box side."""

    def __init__(self, c1, num_classes=80, reg_max=7):
        super().__init__()
        self.cls_convs = nn.Sequential(Conv(c1, c1, 3), Conv(c1, c1, 3))
        self.reg_convs = nn.Sequential(Conv(c1, c1, 3), Conv(c1, c1, 3))
        self.gfl_cls = nn.Conv2d(c1, num_classes, 1)
        self.gfl_reg = nn.Conv2d(c1, 4 * (reg_max + 1), 1)

    def forward(self, x):
        return self.gfl_cls(self.cls_convs(x)), self.gfl_reg(self.reg_convs(x))
# 75. CenterNet Head
class CenterNetHead(nn.Module):
    """CenterNet head: center heatmap (sigmoid), box size (relu), sub-pixel offset."""

    def __init__(self, c1, num_classes=80):
        super().__init__()
        self.heatmap = nn.Sequential(Conv(c1, c1, 3), nn.Conv2d(c1, num_classes, 1))
        self.wh = nn.Sequential(Conv(c1, c1, 3), nn.Conv2d(c1, 2, 1))
        self.offset = nn.Sequential(Conv(c1, c1, 3), nn.Conv2d(c1, 2, 1))

    def forward(self, x):
        return self.heatmap(x).sigmoid(), self.wh(x).relu(), self.offset(x)
# 76. FCOS Head
class FCOSHead(nn.Module):
    """FCOS head: 4-conv towers; centerness is predicted off the box tower."""

    def __init__(self, c1, num_classes=80):
        super().__init__()
        self.cls_tower = nn.Sequential(*[Conv(c1, c1, 3) for _ in range(4)])
        self.bbox_tower = nn.Sequential(*[Conv(c1, c1, 3) for _ in range(4)])
        self.cls_logits = nn.Conv2d(c1, num_classes, 3, padding=1)
        self.bbox_pred = nn.Conv2d(c1, 4, 3, padding=1)
        self.centerness = nn.Conv2d(c1, 1, 3, padding=1)

    def forward(self, x):
        cls_feat = self.cls_tower(x)
        box_feat = self.bbox_tower(x)
        # ReLU keeps the predicted side distances non-negative.
        return self.cls_logits(cls_feat), F.relu(self.bbox_pred(box_feat)), self.centerness(box_feat)
# 77. ATSS Head
class ATSSHead(nn.Module):
    """Anchor-based head for ATSS: cls, box and centerness per anchor."""

    def __init__(self, c1, num_classes=80, num_anchors=9):
        super().__init__()
        self.num_anchors = num_anchors
        self.cls_convs = nn.Sequential(Conv(c1, c1, 3), Conv(c1, c1, 3))
        self.reg_convs = nn.Sequential(Conv(c1, c1, 3), Conv(c1, c1, 3))
        self.atss_cls = nn.Conv2d(c1, num_classes * num_anchors, 3, padding=1)
        self.atss_reg = nn.Conv2d(c1, 4 * num_anchors, 3, padding=1)
        self.atss_centerness = nn.Conv2d(c1, num_anchors, 3, padding=1)

    def forward(self, x):
        cls_feat = self.cls_convs(x)
        reg_feat = self.reg_convs(x)
        return self.atss_cls(cls_feat), self.atss_reg(reg_feat), self.atss_centerness(reg_feat)
# 78. PAFHead (Parallel Attention Fusion Head)
class PAFHead(nn.Module):
    """Head fusing a local conv branch with a global pooled-context branch.

    Fix vs. original: the global branch ended with a hard-coded
    ``nn.Upsample(scale_factor=64)`` even though ``forward`` immediately
    re-interpolates to the local branch's size. Because the 1x1 pooled map is
    constant per channel, that fixed upsample did no useful work and baked in
    an arbitrary feature-size assumption; removing it leaves outputs unchanged
    (Upsample has no parameters, so state dicts are compatible).
    """

    def __init__(self, c1, num_classes=80):
        super().__init__()
        self.local_conv = nn.Sequential(
            Conv(c1, c1 // 2, 3),
            Conv(c1 // 2, c1 // 2, 3)
        )
        # Global context: pool to 1x1 and project; broadcast back in forward.
        self.global_conv = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            Conv(c1, c1 // 2, 1)
        )
        self.fusion = Conv(c1, c1, 1)
        self.pred = nn.Conv2d(c1, num_classes + 4 + 1, 1)  # cls + box + obj

    def forward(self, x):
        local_feat = self.local_conv(x)
        global_feat = self.global_conv(x)
        # Broadcast the 1x1 context map to the local branch's spatial size.
        global_feat = F.interpolate(global_feat, size=local_feat.shape[2:], mode='nearest')
        fused = self.fusion(torch.cat([local_feat, global_feat], dim=1))
        return self.pred(fused)
# 79. DoubleBranchHead
class DoubleBranchHead(nn.Module):
    """Two-branch head: dilated-context branch (ASPP) + depthwise detail branch.

    Output: (num_classes + 5) channels per location.
    NOTE(review): `ASPP` is defined elsewhere in this file/project.
    """

    def __init__(self, c1, num_classes=80):
        super().__init__()
        self.context_branch = nn.Sequential(
            ASPP(c1, c1),
            Conv(c1, c1, 3)
        )
        # Depthwise-separable conv stack for fine spatial detail.
        self.detail_branch = nn.Sequential(
            nn.Conv2d(c1, c1, 3, padding=1, groups=c1),
            nn.BatchNorm2d(c1),
            nn.Conv2d(c1, c1, 1),
            nn.BatchNorm2d(c1),
            nn.ReLU()
        )
        self.fusion = Conv(c1 * 2, c1, 1)
        self.pred = nn.Conv2d(c1, num_classes + 5, 1)

    def forward(self, x):
        context = self.context_branch(x)
        detail = self.detail_branch(x)
        fused = self.fusion(torch.cat([context, detail], dim=1))
        return self.pred(fused)
# 80. LiteHead (ultra-light detection head)
class LiteHead(nn.Module):
    """Minimal head: one shared 3x3 conv feeding 1x1 cls and box predictors."""

    def __init__(self, c1, num_classes=80):
        super().__init__()
        self.shared_conv = Conv(c1, c1, 3)
        self.cls_pred = nn.Conv2d(c1, num_classes, 1)
        self.reg_pred = nn.Conv2d(c1, 4, 1)

    def forward(self, x):
        shared = self.shared_conv(x)
        return torch.cat([self.reg_pred(shared), self.cls_pred(shared)], 1)
⚡ 五、损失函数与标签分配 (81-90)
Python
# 81. TaskAlignedAssigner (task-aligned assigner)
class TaskAlignedAssigner(nn.Module):
    """Simplified TOOD-style assigner ranking anchors by cls^alpha * IoU^beta.

    NOTE(review): `iou_calculation` is a random-stub placeholder, so results
    are NOT deterministic — replace it with a real IoU before use.
    """

    def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0):
        super().__init__()
        self.topk = topk
        self.num_classes = num_classes
        self.alpha = alpha  # exponent on the classification score
        self.beta = beta    # exponent on the IoU

    def forward(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
        # Simplified implementation
        # Align metric = cls_score^alpha * iou^beta
        na = pd_bboxes.shape[-2]  # number of predictions/anchors
        num_gt = gt_bboxes.shape[0]
        if num_gt == 0:
            # No targets: empty assignment.
            return torch.zeros(na, dtype=torch.long), torch.zeros(na, self.num_classes)
        # Calculate IoU
        overlaps = self.iou_calculation(gt_bboxes, pd_bboxes)
        # Calculate alignment metric
        align_metric = pd_scores[:, gt_labels].pow(self.alpha) * overlaps.pow(self.beta)
        # Select top-k candidates
        topk_metrics, topk_idxs = torch.topk(align_metric, self.topk, dim=1)
        return topk_idxs, align_metric

    def iou_calculation(self, gt_bboxes, pd_bboxes):
        # Simplified IoU calculation — placeholder returning random values.
        return torch.rand(gt_bboxes.shape[0], pd_bboxes.shape[0])
# 82. SimOTA (Simplified Optimal Transport Assignment)
class SimOTA:
    """SimOTA label assignment (YOLOX style): dynamic-k matching by lowest cost.

    NOTE(review): `bbox_iou` is a random-stub placeholder, so assignments are
    not meaningful/deterministic until it is replaced with a real IoU.
    """

    def __init__(self, num_classes=80, center_radius=2.5):
        self.num_classes = num_classes
        self.center_radius = center_radius  # accepted but unused in this simplified version

    def __call__(self, pred_scores, pred_bboxes, gt_labels, gt_bboxes):
        # Dynamic k estimation based on IoU
        num_gt = gt_bboxes.shape[0]
        num_pred = pred_bboxes.shape[0]
        if num_gt == 0:
            # No targets: nothing is matched.
            return torch.zeros(num_pred, dtype=torch.bool), torch.zeros(num_pred, dtype=torch.long)
        # Cost matrix: cls_loss + 3.0 * reg_loss
        ious = self.bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0))
        cost = (1 - ious) * 3.0 + (1 - pred_scores[:, gt_labels])
        # Dynamic k: each GT gets k = round(sum of its top-10 IoUs), at least 1.
        ious_in_boxes_matrix = ious
        topk_ious, _ = torch.topk(ious_in_boxes_matrix, min(10, num_pred), dim=0)
        dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1)
        # Optimal assignment: the k lowest-cost predictions per GT become positives.
        matching_matrix = torch.zeros(num_pred, num_gt)
        for gt_idx in range(num_gt):
            _, pos_idx = torch.topk(cost[:, gt_idx], k=dynamic_ks[gt_idx].item(), largest=False)
            matching_matrix[pos_idx, gt_idx] = 1.0
        # Returns (foreground mask, per-prediction matched GT index).
        return matching_matrix.sum(1) > 0, matching_matrix.argmax(1)

    def bbox_iou(self, box1, box2):
        # Simplified IoU — placeholder returning random values.
        return torch.rand(box1.shape[0], box2.shape[0])
# 83. TAL (Task Alignment Learning)
class TAL(nn.Module):
    """Task-alignment metric (cls^alpha * IoU^beta) with top-k positive selection.

    NOTE(review): `compute_iou` is a random-stub placeholder; outputs are not
    deterministic until it is replaced with a real IoU.
    """

    def __init__(self, topk=13, alpha=1.0, beta=6.0):
        super().__init__()
        self.topk = topk
        self.alpha = alpha  # exponent on the classification score
        self.beta = beta    # exponent on the IoU

    def forward(self, cls_scores, bbox_preds, gt_bboxes, gt_labels):
        # Task alignment metric
        num_gt = gt_bboxes.shape[0]
        num_pred = cls_scores.shape[0]
        # Classification score target (computed but unused in this sketch).
        cls_targets = torch.zeros(num_pred, cls_scores.shape[1])
        # IoU calculation
        ious = self.compute_iou(bbox_preds, gt_bboxes)
        # Alignment metric: per-(pred, gt) class score ^ alpha * IoU ^ beta.
        alignment_metrics = cls_scores.gather(1, gt_labels.unsqueeze(0).expand(num_pred, -1))
        alignment_metrics = alignment_metrics.pow(self.alpha) * ious.pow(self.beta)
        # Positive mask: top-k predictions per GT along the prediction axis.
        topk_metrics, topk_indices = torch.topk(alignment_metrics, self.topk, dim=0)
        positive_mask = torch.zeros_like(alignment_metrics)
        positive_mask.scatter_(0, topk_indices, 1)
        return positive_mask, alignment_metrics

    def compute_iou(self, bboxes1, bboxes2):
        # Placeholder: random IoU matrix.
        return torch.rand(bboxes1.shape[0], bboxes2.shape[0])
# 84. ATSS (Adaptive Training Sample Selection)
class ATSS:
    """ATSS assigner: per-level top-k closest anchors, thresholded at mean+std IoU.

    NOTE(review): the candidate IoUs below are a random-stub placeholder —
    replace with real anchor-GT IoUs for meaningful assignments.
    """

    def __init__(self, topk=9):
        self.topk = topk  # candidates kept per GT per pyramid level

    def __call__(self, anchors, gt_bboxes, num_level_anchors):
        num_gt = gt_bboxes.shape[0]
        num_anchors = anchors.shape[0]
        if num_gt == 0:
            return torch.zeros(num_anchors, dtype=torch.bool), torch.zeros(num_anchors, dtype=torch.long)
        # Compute center-to-center distances between GTs and anchors.
        gt_cx = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0
        gt_cy = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0
        gt_points = torch.stack([gt_cx, gt_cy], dim=1)
        anchor_cx = (anchors[:, 0] + anchors[:, 2]) / 2.0
        anchor_cy = (anchors[:, 1] + anchors[:, 3]) / 2.0
        anchor_points = torch.stack([anchor_cx, anchor_cy], dim=1)
        distances = (anchor_points.unsqueeze(0) - gt_points.unsqueeze(1)).pow(2).sum(-1).sqrt()
        # Select the topk nearest anchors per GT, independently per pyramid level.
        candidate_idxs = []
        start_idx = 0
        for num in num_level_anchors:
            end_idx = start_idx + num
            _, topk_idxs = torch.topk(distances[:, start_idx:end_idx], self.topk, dim=1, largest=False)
            candidate_idxs.append(topk_idxs + start_idx)
            start_idx = end_idx
        candidate_idxs = torch.cat(candidate_idxs, dim=1)
        # Adaptive threshold: mean + std of candidate IoUs per GT.
        candidate_ious = torch.rand(num_gt, candidate_idxs.shape[1])  # placeholder IoUs
        iou_threshold = candidate_ious.mean(dim=1, keepdim=True) + candidate_ious.std(dim=1, keepdim=True)
        # Final positive mask
        is_pos = candidate_ious >= iou_threshold
        return is_pos.any(dim=0), is_pos.float().argmax(dim=0)
# 85. HungarianMatcher (DETR-style)
class HungarianMatcher(nn.Module):
    """DETR-style bipartite matcher between predictions and targets.

    Cost = cost_bbox * L1 + cost_class * (-prob) + cost_giou * (-GIoU).
    NOTE(review): `generalized_box_iou` is a random-stub placeholder; matches
    are not meaningful until it is replaced with a real GIoU.
    """

    def __init__(self, cost_class=1.0, cost_bbox=5.0, cost_giou=2.0):
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou

    @torch.no_grad()
    def forward(self, outputs, targets):
        """Return one (pred_idx, tgt_idx) index-pair tuple per image."""
        bs, num_queries = outputs["pred_logits"].shape[:2]
        # Flatten to compute cost matrix
        out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)
        out_bbox = outputs["pred_boxes"].flatten(0, 1)
        # Concat target labels and boxes
        tgt_ids = torch.cat([v["labels"] for v in targets])
        tgt_bbox = torch.cat([v["boxes"] for v in targets])
        # Classification cost
        cost_class = -out_prob[:, tgt_ids]
        # L1 cost
        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
        # GIoU cost
        cost_giou = -self.generalized_box_iou(out_bbox, tgt_bbox)
        # Final cost matrix
        C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
        C = C.view(bs, num_queries, -1).cpu()
        # Hungarian algorithm (scipy imported locally to keep it optional)
        from scipy.optimize import linear_sum_assignment
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split([len(v["labels"]) for v in targets], -1))]
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]

    def generalized_box_iou(self, boxes1, boxes2):
        # Placeholder: random values instead of a real GIoU matrix.
        return torch.rand(boxes1.shape[0], boxes2.shape[0])
# 86. QualityFocalLoss
class QualityFocalLoss(nn.Module):
    """Quality Focal Loss: BCE modulated by |quality - p|^beta on positives."""

    def __init__(self, beta=2.0):
        super().__init__()
        self.beta = beta

    def forward(self, pred, target, score):
        # pred: [N, C] logits; target: labels (0 = negative); score: IoU quality.
        prob = pred.sigmoid()
        # Negatives: push all logits toward zero, weighted by p^beta.
        zeros = prob.new_zeros(pred.shape)
        loss = F.binary_cross_entropy_with_logits(pred, zeros, reduction='none') * prob.pow(self.beta)
        pos = target != 0
        if pos.any():
            # Positives: regress the quality score, weighted by |q - p|^beta.
            gap = score[pos] - prob[pos]
            loss[pos] = F.binary_cross_entropy_with_logits(pred[pos], score[pos], reduction='none') * gap.abs().pow(self.beta)
        return loss.sum()
# 87. DistributionFocalLoss (DFL)
class DistributionFocalLoss(nn.Module):
    """Distribution Focal Loss over discretized box-side distributions.

    Fix vs. original: the target was cast to ``long`` BEFORE computing the
    interpolation weights, which made ``weight_right`` identically zero and
    collapsed DFL to plain cross-entropy on the left bin. The target is now
    kept fractional so the loss interpolates between the two adjacent bins,
    as in the GFL paper. The shape unpacking also mislabeled the bin count as
    ``reg_max`` (it is ``reg_max + 1``).
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred, target):
        # pred: [N, 4, reg_max+1] bin logits; target: [N, 4] continuous values
        # in [0, reg_max].
        n, c, bins = pred.shape
        # Left/right neighboring bins of each continuous target.
        target_left = target.floor().long().clamp(min=0, max=bins - 2)
        target_right = target_left + 1
        # Linear interpolation weights: closer bin gets the larger weight.
        weight_left = target_right.float() - target
        weight_right = target - target_left.float()
        # Cross entropy against both bins, combined with the fractional weights.
        loss = F.cross_entropy(pred.reshape(-1, bins), target_left.reshape(-1), reduction='none') * weight_left.reshape(-1) + \
               F.cross_entropy(pred.reshape(-1, bins), target_right.reshape(-1), reduction='none') * weight_right.reshape(-1)
        return loss.mean()
# 88. GIoULoss
class GIoULoss(nn.Module):
    """Generalized IoU loss: mean of (1 - GIoU), optionally weighted per box."""

    def __init__(self):
        super().__init__()

    def forward(self, pred, target, weight=None):
        # Unpack [N, 4] xyxy boxes column-wise.
        px1, py1, px2, py2 = pred.T
        tx1, ty1, tx2, ty2 = target.T
        area_p = (px2 - px1) * (py2 - py1)
        area_t = (tx2 - tx1) * (ty2 - ty1)
        # Intersection (clamped so disjoint boxes contribute zero).
        iw = (torch.min(px2, tx2) - torch.max(px1, tx1)).clamp(min=0)
        ih = (torch.min(py2, ty2) - torch.max(py1, ty1)).clamp(min=0)
        inter = iw * ih
        union = area_p + area_t - inter
        iou = inter / (union + 1e-7)
        # Smallest enclosing (convex) box.
        cw = torch.max(px2, tx2) - torch.min(px1, tx1)
        ch = torch.max(py2, ty2) - torch.min(py1, ty1)
        enclose = cw * ch
        giou = iou - (enclose - union) / (enclose + 1e-7)
        loss = 1 - giou
        if weight is not None:
            loss = loss * weight
        return loss.mean()
# 89. FocalLoss
class FocalLoss(nn.Module):
    """Binary focal loss: alpha-balanced, (1 - p_t)^gamma modulated BCE."""

    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, pred, target):
        base = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
        prob = pred.sigmoid()
        # Probability assigned to the true class per element.
        p_t = target * prob + (1 - target) * (1 - prob)
        alpha_t = target * self.alpha + (1 - target) * (1 - self.alpha)
        return (alpha_t * (1.0 - p_t).pow(self.gamma) * base).mean()
# 90. VarifocalLoss
class VarifocalLoss(nn.Module):
    """Varifocal-style loss: weighted BCE normalized by the total weight."""

    def __init__(self, alpha=0.75, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, pred, target, score):
        prob = pred.sigmoid()
        pos = (target > 0).float()
        neg = (target == 0).float()
        # Positives weighted by alpha * p^gamma; negatives by (1-alpha)*(1-score).
        weight = self.alpha * prob.pow(self.gamma) * pos + \
                 (1 - self.alpha) * (1 - score) * neg
        bce = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
        return (weight * bce).sum() / (weight.sum() + 1e-6)
🔧 六、训练策略与辅助模块 (91-100)
Python
# 91. EMA (Exponential Moving Average)
class ModelEMA:
    """Keeps an exponential moving average copy of a model's weights.

    The decay ramps from 0 toward `decay` over ~2000 updates so the EMA tracks
    the model closely early in training.
    NOTE(review): relies on `deepcopy` and `math` imported at file level.
    """

    def __init__(self, model, decay=0.9999):
        self.ema = deepcopy(model).eval()
        self.updates = 0  # number of update() calls so far
        self.decay = lambda x: decay * (1 - math.exp(-x / 2000))
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        """Blend the live model's state into the EMA copy (float tensors only)."""
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)
            msd = model.state_dict()
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:
                    # In-place: ema = d * ema + (1 - d) * model
                    v *= d
                    v += (1.0 - d) * msd[k].detach()

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        """Copy selected non-private attributes from `model` onto the EMA model."""
        for k, v in model.__dict__.items():
            if len(include) and k not in include or k.startswith('_') or k in exclude:
                continue
            else:
                setattr(self.ema, k, v)
# 92. SWA (Stochastic Weight Averaging)
class SWA:
    """Maintains a running average of model weights after `swa_start` epochs."""

    def __init__(self, model, swa_start=10):
        self.model = model
        self.swa_model = deepcopy(model).eval()
        self.swa_start = swa_start
        self.n_averaged = 0  # how many snapshots are in the average

    def update(self, model, epoch):
        """Fold `model`'s parameters into the running average (no-op before swa_start)."""
        if epoch < self.swa_start:
            return
        count = self.n_averaged
        for avg_p, cur_p in zip(self.swa_model.parameters(), model.parameters()):
            avg_p.data = (avg_p.data * count + cur_p.data) / (count + 1)
        self.n_averaged = count + 1
# 93. DropBlock
class DropBlock(nn.Module):
    """DropBlock regularization: zeroes contiguous block_size x block_size regions."""

    def __init__(self, block_size=7, keep_prob=0.9):
        super().__init__()
        self.block_size = block_size
        self.keep_prob = keep_prob

    def forward(self, x):
        # Identity at eval time or when nothing would be dropped.
        if not self.training or self.keep_prob == 1.0:
            return x
        # Seed probability chosen so the expected drop rate matches 1 - keep_prob.
        side = x.shape[-1]
        gamma = (1.0 - self.keep_prob) * (side ** 2) / (self.block_size ** 2) / \
                ((side - self.block_size + 1) ** 2)
        seeds = torch.bernoulli(torch.ones_like(x) * gamma)
        # Grow each seed into a full block, then invert to get the keep-mask.
        block_mask = F.max_pool2d(seeds, self.block_size, stride=1, padding=self.block_size // 2)
        keep = 1 - block_mask
        # Rescale so the activation magnitude is preserved in expectation.
        return x * keep * keep.numel() / keep.sum()
# 94. DropPath (Stochastic Depth)
class DropPath(nn.Module):
    """Randomly zeroes whole samples in a batch, rescaling the survivors."""

    def __init__(self, drop_prob=0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if not self.training or self.drop_prob == 0.0:
            return x
        keep_prob = 1 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over the remaining dims.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = torch.rand(mask_shape, dtype=x.dtype, device=x.device).add_(keep_prob).floor_()
        return x.div(keep_prob) * mask
# 95. LabelSmoothing
class LabelSmoothing(nn.Module):
    """Cross-entropy with label smoothing: the target distribution places
    ``1 - smoothing`` on the true class and spreads ``smoothing`` uniformly
    over all classes."""

    def __init__(self, smoothing=0.1):
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, x, target):
        logp = F.log_softmax(x, dim=-1)
        # Negative log-likelihood of the true class for each sample.
        nll = -logp.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
        # Uniform component: mean of -log p over classes.
        uniform = -logp.mean(dim=-1)
        return (self.confidence * nll + self.smoothing * uniform).mean()
# 96. CutMix
class CutMix:
def __init__(self, alpha=1.0):
self.alpha = alpha
def __call__(self, images, labels):
lam = np.random.beta(self.alpha, self.alpha)
rand_index = torch.randperm(images.size()[0]).to(images.device)
target_a = labels
target_b = labels[rand_index]
bbx1, bby1, bbx2, bby2 = self.rand_bbox(images.size(), lam)
images[:, :, bbx1:bbx2, bby1:bby2] = images[rand_index, :, bbx1:bbx2, bby1:bby2]
lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (images.size()[-1] * images.size()[-2]))
return images, target_a, target_b, lam
def rand_bbox(self, size, lam):
W = size[2]
H = size[3]
cut_rat = np.sqrt(1. - lam)
cut_w = int(W * cut_rat)
cut_h = int(H * cut_rat)
cx = np.random.randint(W)
cy = np.random.randint(H)
bbx1 = np.clip(cx - cut_w // 2, 0, W)
bby1 = np.clip(cy - cut_h // 2, 0, H)
bbx2 = np.clip(cx + cut_w // 2, 0, W)
bby2 = np.clip(cy + cut_h // 2, 0, H)
return bbx1, bby1, bbx2, bby2
# 97. Mosaic Augmentation
class Mosaic:
    """Simplified mosaic augmentation: tile the first four images into one
    2s x 2s canvas and offset their box labels accordingly."""

    def __init__(self, size=640):
        self.size = size  # side length of each quadrant

    def __call__(self, images, labels):
        s = self.size
        canvas = torch.zeros(3, s * 2, s * 2)
        collected = []
        # Quadrant origins: top-left, top-right, bottom-left, bottom-right.
        origins = ((0, 0), (s, 0), (0, s), (s, s))
        for (x_off, y_off), img, label in zip(origins, images[:4], labels[:4]):
            canvas[:, y_off:y_off + s, x_off:x_off + s] = img
            if len(label) > 0:
                # Shift box coordinates into canvas space.
                # NOTE(review): assumes label columns [x1, y1, x2, y2, ...];
                # mutates the caller's label tensors in place.
                label[:, [0, 2]] += x_off
                label[:, [1, 3]] += y_off
                collected.append(label)
        if len(collected) > 0:
            collected = torch.cat(collected, dim=0)
        return canvas, collected
# 98. MixUp
class MixUp:
    """MixUp augmentation: convex-combine each image with a randomly chosen
    partner from the same batch; returns both label sets and the mix ratio
    so the caller can mix the loss correspondingly."""

    def __init__(self, alpha=0.5):
        self.alpha = alpha  # Beta-distribution parameter

    def __call__(self, images, labels):
        lam = np.random.beta(self.alpha, self.alpha)
        n = images.size(0)
        perm = torch.randperm(n).to(images.device)
        partner = images[perm]
        blended = lam * images + (1 - lam) * partner
        return blended, labels, labels[perm], lam
# 99. Test Time Augmentation (TTA)
class TTA:
    """Runs a model over scaled/flipped variants of an image.

    NOTE(review): merging is simplified — only the first (unscaled,
    unflipped) prediction is returned; the other augmented predictions are
    computed but discarded.
    """

    def __init__(self, scales=[1.0, 1.25, 0.75], flips=[False, True]):
        self.scales = scales
        self.flips = flips

    def __call__(self, model, image):
        preds = []
        base_h, base_w = image.shape[2:]
        for scale in self.scales:
            for flip in self.flips:
                aug = image
                # Resize only when the scale actually changes the size.
                if scale != 1.0:
                    size = (int(base_h * scale), int(base_w * scale))
                    aug = F.interpolate(image, size=size,
                                        mode='bilinear', align_corners=False)
                # Horizontal flip along the width axis.
                if flip:
                    aug = torch.flip(aug, dims=[3])
                with torch.no_grad():
                    pred = model(aug)
                if flip:
                    # Map the x coordinate back into original orientation.
                    # NOTE(review): assumes pred[0][..., 0] holds x coords
                    # in the augmented image's pixel space — confirm against
                    # the model's output format.
                    pred[0][..., 0] = aug.shape[3] - pred[0][..., 0]
                preds.append(pred)
        # Simplified merge: return the first augmentation's prediction.
        return preds[0]
# 100. MultiScaleTraining
class MultiScaleTraining:
    """Randomly resizes each batch to one of several square sizes and scales
    the labels by the same factor.

    Args:
        sizes: candidate square sizes to sample from on each call.
    """

    # Immutable default avoids the shared mutable-default-argument pitfall.
    DEFAULT_SIZES = (320, 352, 384, 416, 448, 480, 512, 544, 576, 608, 640)

    def __init__(self, sizes=DEFAULT_SIZES):
        self.sizes = tuple(sizes)
        # Most recently chosen size; initialized to the last candidate.
        self.current_size = self.sizes[-1]

    def __call__(self, images, labels):
        """Resize ``images`` to a randomly chosen square size; scale labels.

        NOTE(review): multiplies *all* label columns uniformly — if labels
        carry a class-id column it would be scaled too; confirm the label
        layout at the call site.
        """
        self.current_size = random.choice(self.sizes)
        resized = F.interpolate(images, size=(self.current_size, self.current_size),
                                mode='bilinear', align_corners=False)
        # Coordinates scale with the spatial resize factor.
        factor = self.current_size / images.shape[2]
        scaled = labels * factor if labels.numel() > 0 else labels
        return resized, scaled
📋 使用说明
在YOLOv12中插入模块的示例:
Python
# Modify the backbone to insert attention modules
class ImprovedBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = Conv(3, 64, 3, 2)
        self.attn1 = ECA(64)        # insert ECA attention
        self.conv2 = Conv(64, 128, 3, 2)
        self.attn2 = CBAM(128)      # insert CBAM attention
# Modify the neck to use an improved FPN
class ImprovedNeck(nn.Module):
    def __init__(self):
        super().__init__()
        self.fpn = BiFPN([256, 512, 1024])  # use BiFPN
# Modify the head to use a decoupled head
class ImprovedHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.head = DecoupledHead(256, num_classes=80)  # use a decoupled head
快速替换指南:
表格
| 原模块 | 改进模块 | 效果 |
|---|---|---|
| Conv | GhostConv | 减少50%计算量 |
| C3 | C3k2_Improved | 保持精度,更快 |
| SPPF | ASPP | 多尺度感受野 |
| Concat | ASFF | 自适应特征融合 |
| Detect | DecoupledHead | 分类回归解耦 |
这100个模块覆盖了注意力机制、卷积改进、特征融合、检测头、损失函数、数据增强等各个方面,可以灵活组合使用!
转载自CSDN-专业IT技术社区
原文链接:https://blog.csdn.net/VectorShift/article/details/158586215



