4-残差连接
约 606 字大约 2 分钟
2026-03-27
经过之前的训练我们的方法已经稳定得到结果,在这一节,我们对模型进行改进,为解码器加入跳跃连接,让编码器的每次分辨率信息与解码器相互联通,接下来将详细介绍,也就是老师提供的3-skipConnect.py
我们现有的方法在前向传播的过程如下:
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.decoder(x)
return x 我们可以发现,对于解码器而言,其获取到的信息只有编码器提取的高维特征,并没有其他尺度的特征,针对此,我们使用跳跃连接的方法,首先我们先定义解码器:
self.up4 = UpBlock(512, 256, 256) # x4 + x3
self.up3 = UpBlock(256, 128, 128) # x + x2
self.up2 = UpBlock(128, 64, 64) # x + x1
self.up1 = UpBlock(64, 64, 64) # x + x0
self.final_up = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2) # 1/2 -> 1
self.final_conv = nn.Sequential(
DoubleConv(32, 32),
nn.Conv2d(32, 1, kernel_size=1)
) 在随后的前向传播中,我们连接相同尺度的信息:
def forward(self, x):
# 原图尺寸
input_h, input_w = x.shape[-2:]
# ===== Encoder =====
x0 = self.conv1(x) # [B, 64, H/2, W/2]
x0 = self.bn1(x0)
x0 = self.relu(x0)
x = self.maxpool(x0) # [B, 64, H/4, W/4]
x1 = self.layer1(x) # [B, 64, H/4, W/4]
x2 = self.layer2(x1) # [B,128, H/8, W/8]
x3 = self.layer3(x2) # [B,256, H/16,W/16]
x4 = self.layer4(x3) # [B,512, H/32,W/32]
# ===== Decoder =====
x = self.up4(x4, x3) # -> 256, 1/16
x = self.up3(x, x2) # -> 128, 1/8
x = self.up2(x, x1) # -> 64, 1/4
x = self.up1(x, x0) # -> 64, 1/2
x = self.final_up(x) # -> 32, 1/1
x = self.final_conv(x) # -> 1, 1/1
if x.shape[-2:] != (input_h, input_w):
x = torch.nn.functional.interpolate(
x, size=(input_h, input_w), mode='bilinear', align_corners=False
)
return x 这样在解码器解码的过程中,可以通过跳跃连接获取相同尺度下的编码器信息。这也就形成了经典网络Unet的基本架构。
其主要结果如下:
Epoch 18/20
Train Loss: 0.3359 | Val Loss: 0.3749
Train F1: 0.8514 | Val F1: 0.7924
Train P/R: 0.8028/0.9063
Val P/R: 0.7614/0.8260
提示:轻度过拟合
学习率保持: 0.00005000
Epoch 19/20
Train Loss: 0.3239 | Val Loss: 0.3420
Train F1: 0.8769 | Val F1: 0.8186
Train P/R: 0.8736/0.8802
Val P/R: 0.8327/0.8050
提示:轻度过拟合
已保存最佳模型:best_model_resnet18skipconnect_简单数据集.pth
学习率保持: 0.00005000
Epoch 20/20
Train Loss: 0.3166 | Val Loss: 0.4180
Train F1: 0.8271 | Val F1: 0.7741
Train P/R: 0.7379/0.9410
Val P/R: 0.7006/0.8648
提示:轻度过拟合
学习率保持: 0.00005000
============================================================
训练完成
============================================================
模型:resnet18skipconnect
最佳验证 F1:0.8186 其与2-augmentation.py比较的结果如图:


