本篇會提一些 PyTorch 的 BatchNorm2d
的一些細節。首先先問一個問題:你覺得下列哪個尺寸的張量在使用BatchNorm2d
(以下以"BN"簡稱)時會報錯?(設 BN 的特徵數為 3)
答案是:(1, 3, 1, 1)。
from math import prod
from torch import nn
# Shape 1
bn1 = nn.BatchNorm2d(num_features=3)
shape1 = (1, 3, 1, 1)
tensor1 = torch.arange(float(prod(shape1))).reshape(shape1)
print(tensor1)
print(bn1(tensor1))
# Shape 2
bn2 = nn.BatchNorm2d(num_features=3)
shape2 = (2, 3, 1, 1)
tensor2 = torch.arange(float(prod(shape2))).reshape(shape2)
print(tensor2)
print(bn2(tensor2))
# Shape 3
bn3 = nn.BatchNorm2d(num_features=3)
shape3 = (1, 3, 2, 2)
tensor3 = torch.arange(float(prod(shape3))).reshape(shape3)
print(tensor3)
print(bn3(tensor3))
跑上面程式碼在第一個張量形狀會報錯:ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 3, 1, 1])
,其餘則不會。
一般圖像處理有關的張量會表示為四維,四個維度分別對應:N × C × H × W。N 為 batch size,C 為 channel 數量,H 為高,W 為寬。
回到剛剛的問題,(1, 3, 1, 1) 會報錯是因為它的 batch size 只有 1 嗎?那 (1, 3, 2, 2) 為什麼可以呢?還是因為它的 H × W 只有 1?那 (1, 3, 1, 1) 為什麼不會報錯呢?
參照下圖左上角。進行 Batch Norm 時,N × (H,W flatten) 的平面會在 channel 的維度上移動計算。
圖片來源: arXiv:1903.10520v2 [cs.CV] (https://arxiv.org/abs/1903.10520v2)
所以有幾個 channel 就會得到幾個 mean 和 var,每個 channel 計算所有 N x (H, W flatten) 個值的 mean 和 var。
回到上面的例子,我們逐個來看
(1, 3, 1, 1) 會報錯是因為算 BN 時不能一次只有一個值,
圖片來源:https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html
上面這是 BatchNorm2d
計算的式子,如果只有一個值可以算,那麼 $x - E[x] = 0$ 。Training 時這樣的資料形狀還做 BN 不知有何意義?如果是 testing,應該要 model.eval()
。
https://zhuanlan.zhihu.com/p/69431151 https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html