PyTorch Batch Normalization 1 - 四維張量是如何進行計算?

Chinese Deep Learning

written by LiaoWC on 2021-06-01


本篇會提一些 PyTorch 的 BatchNorm2d 的一些細節。首先先問一個問題:你覺得下列哪個尺寸的張量在使用BatchNorm2d(以下以"BN"簡稱)時會報錯?(設 BN 的特徵數為 3)

  1. (1, 3, 1, 1)
  2. (2, 3, 1, 1)
  3. (1, 3, 2, 2)

答案是:(1, 3, 1, 1)。

from math import prod
from torch import nn

# Shape 1
bn1 = nn.BatchNorm2d(num_features=3)
shape1 = (1, 3, 1, 1)
tensor1 = torch.arange(float(prod(shape1))).reshape(shape1)
print(tensor1)
print(bn1(tensor1))
# Shape 2
bn2 = nn.BatchNorm2d(num_features=3)
shape2 = (2, 3, 1, 1)
tensor2 = torch.arange(float(prod(shape2))).reshape(shape2)
print(tensor2)
print(bn2(tensor2))
# Shape 3
bn3 = nn.BatchNorm2d(num_features=3)
shape3 = (1, 3, 2, 2)
tensor3 = torch.arange(float(prod(shape3))).reshape(shape3)
print(tensor3)
print(bn3(tensor3))

跑上面程式碼在第一個張量形狀會報錯:ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 3, 1, 1]),其餘則不會。

一般圖像處理有關的張量會表示為四維,四個維度分別對應:N × C × H × W。N 為 batch size,C 為 channel 數量,H 為高,W 為寬。

回到剛剛的問題,(1, 3, 1, 1) 會報錯是因為它的 batch size 只有 1 嗎?那 (1, 3, 2, 2) 為什麼可以呢?還是因為它的 H × W 只有 1?那 (1, 3, 1, 1) 為什麼不會報錯呢?

參照下圖左上角。進行 Batch Norm 時,N × (H,W flatten) 的平面會在 channel 的維度上移動計算。

comparing-bn.png

圖片來源: arXiv:1903.10520v2 [cs.CV] (https://arxiv.org/abs/1903.10520v2)

所以有幾個 channel 就會得到幾個 mean 和 var,每個 channel 計算所有 N x (H, W flatten) 個值的 mean 和 var。

回到上面的例子,我們逐個來看

  • (1, 3, 1, 1) ⇒ 3個 channel 各計算 1 × 1 × 1 = 1 個值
  • (2, 3, 1, 1) ⇒ 3個 channel 各計算 2 × 1 × 1 = 2 個值
  • (1, 3, 2, 2) ⇒ 3個 channel 各計算 1 × 2 × 2 = 4 個值

(1, 3, 1, 1) 會報錯是因為算 BN 時不能一次只有一個值,

bn2dformula.png

圖片來源:https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html

上面這是 BatchNorm2d 計算的式子,如果只有一個值可以算,那麼 $x - E[x] = 0$ 。Training 時這樣的資料形狀還做 BN 不知有何意義?如果是 testing,應該要 model.eval()

Reference

https://zhuanlan.zhihu.com/p/69431151 https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html