Batch normalization

Batch Normalization: Kỹ thuật chuẩn hóa tăng tốc và ổn định huấn luyện mạng neural

Giới thiệu

Batch Normalization (BN) là một trong những kỹ thuật quan trọng nhất trong lĩnh vực deep learning hiện đại. Được giới thiệu bởi Sergey Ioffe và Christian Szegedy trong bài báo năm 2015 "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift", kỹ thuật này đã trở thành thành phần thiết yếu trong hầu hết các kiến trúc mạng neural hiện đại, từ CNN (ResNet, EfficientNet) đến Transformer.

Bài viết này sẽ đi sâu vào nguyên lý, công thức toán học, cách triển khai và các biến thể của Batch Normalization, cũng như cách kỹ thuật này tăng tốc và ổn định quá trình huấn luyện mạng neural.

Vấn đề mà Batch Normalization giải quyết

Trước khi tìm hiểu BN hoạt động như thế nào, hãy hiểu tại sao chúng ta cần nó.

Internal Covariate Shift

Trong quá trình huấn luyện mạng neural sâu, khi trọng số của các lớp đầu thay đổi, phân phối của đầu vào cho các lớp sau cũng thay đổi. Hiện tượng này được gọi là "Internal Covariate Shift".

Ví dụ: Giả sử lớp 1 tiếp nhận dữ liệu và tính toán các đặc trưng. Nếu trọng số của lớp 1 thay đổi trong quá trình huấn luyện, đầu ra của lớp 1 (chính là đầu vào của lớp 2) cũng thay đổi. Điều này buộc lớp 2 phải liên tục thích nghi với phân phối đầu vào mới, làm chậm quá trình học.

Vanishing/Exploding Gradients

Trong mạng neural sâu, gradient có thể trở nên rất nhỏ (vanishing) hoặc rất lớn (exploding) khi lan truyền ngược qua nhiều lớp. Điều này gây khó khăn cho quá trình học, đặc biệt ở các lớp đầu tiên của mạng.

Phụ thuộc vào Learning Rate

Mạng neural sâu thường rất nhạy cảm với learning rate. Learning rate quá lớn có thể khiến mô hình không hội tụ, trong khi learning rate quá nhỏ làm quá trình học rất chậm.

Batch Normalization giải quyết tất cả các vấn đề trên bằng cách chuẩn hóa các đặc trưng trong quá trình huấn luyện.

Nguyên lý hoạt động của Batch Normalization

Ý tưởng cốt lõi

Ý tưởng chính của Batch Normalization là chuẩn hóa đầu vào của mỗi lớp sao cho có trung bình bằng 0 và phương sai bằng 1. Điều này tương tự như cách chúng ta chuẩn hóa dữ liệu đầu vào ban đầu, nhưng được thực hiện ở mỗi lớp và trong quá trình huấn luyện.

Công thức toán học

Giả sử chúng ta có một mini-batch đầu vào x = {x₁, x₂, ..., xₘ} cho một lớp trong mạng neural. Batch Normalization thực hiện các bước sau:

  1. Tính trung bình mini-batch:

    μᵦ = (1/m) * Σ(i=1 to m) xᵢ
  2. Tính phương sai mini-batch:

    σ²ᵦ = (1/m) * Σ(i=1 to m) (xᵢ - μᵦ)²
  3. Chuẩn hóa:

    x̂ᵢ = (xᵢ - μᵦ) / √(σ²ᵦ + ε)

    Trong đó ε là một hằng số nhỏ (thường 10⁻⁵) để tránh chia cho 0.

  4. Scale và shift (phép biến đổi affine):

    yᵢ = γ * x̂ᵢ + β

    Trong đó γ và β là các tham số học được.

Bước 4 cho phép mạng học lại phân phối tối ưu cho mỗi lớp, thay vì bắt buộc tất cả các lớp phải có đầu ra trung bình 0, phương sai 1.

Vị trí đặt Batch Normalization

Trong mạng neural fully connected, BN thường được đặt sau phép nhân ma trận và trước hàm kích hoạt:

z = Wx + b
z_norm = BN(z)
a = activation(z_norm)

Trong mạng CNN, BN được áp dụng cho mỗi kênh đặc trưng sau phép tích chập:

z = Conv2D(x)
z_norm = BN(z)  # Chuẩn hóa theo từng kênh
a = activation(z_norm)

Lợi ích của Batch Normalization

1. Tăng tốc quá trình huấn luyện

Batch Normalization cho phép sử dụng learning rate lớn hơn vì nó giảm sự thay đổi mạnh của gradient. Điều này có thể giúp tăng tốc huấn luyện lên 14 lần so với mạng không sử dụng BN (theo bài báo gốc).

2. Giảm sự phụ thuộc vào khởi tạo trọng số

BN giúp mạng ít nhạy cảm hơn với việc khởi tạo trọng số, vì nó chuẩn hóa đầu ra của mỗi lớp.

3. Tác dụng regularization

BN có tác dụng regularization nhẹ do cách tính toán thống kê trên mini-batch, giúp giảm overfitting. Tuy nhiên, tác dụng này không mạnh bằng các kỹ thuật regularization chuyên dụng như Dropout.

4. Giảm vanishing gradient

Bằng cách đảm bảo phân phối đầu vào của mỗi lớp không quá lệch, BN giúp giảm vấn đề vanishing gradient trong mạng sâu.

5. Giảm sự phụ thuộc vào các kỹ thuật regularization khác

Mạng sử dụng BN thường ít phụ thuộc vào các kỹ thuật như Dropout, vì BN đã có một số tác dụng regularization.

Triển khai Batch Normalization

PyTorch

import torch.nn as nn

# Trong mạng fully connected
class MyNetwork(nn.Module):
    def __init__(self):
        super(MyNetwork, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.bn1 = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(256, 10)
        
    def forward(self, x):
        x = self.fc1(x)
        x = self.bn1(x)
        x = nn.functional.relu(x)
        x = self.fc2(x)
        return x

# Trong mạng CNN
class MyCNN(nn.Module):
    def __init__(self):
        super(MyCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = nn.functional.relu(x)
        return x

TensorFlow/Keras

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, BatchNormalization, Activation

# Mạng fully connected
model = Sequential([
    Dense(256, input_shape=(784,)),
    BatchNormalization(),
    Activation('relu'),
    Dense(10)
])

# Mạng CNN
model = Sequential([
    Conv2D(64, kernel_size=3, padding='same', input_shape=(32, 32, 3)),
    BatchNormalization(),
    Activation('relu')
])

Batch Normalization trong quá trình Inference

Trong quá trình huấn luyện, BN sử dụng thống kê từ mini-batch hiện tại. Nhưng trong quá trình inference (dự đoán), chúng ta có thể chỉ có một mẫu đầu vào hoặc batch size khác. Điều này được giải quyết bằng cách:

  1. Trong quá trình huấn luyện, BN tính toán và cập nhật trung bình cộng (moving average) của trung bình và phương sai của mỗi đặc trưng.

  2. Trong quá trình inference, BN sử dụng các giá trị trung bình cộng này thay vì tính toán thống kê trên batch hiện tại.

# PyTorch
model.eval()  # Chuyển sang chế độ evaluation, sử dụng running statistics

# TensorFlow/Keras
# Tự động xử lý, không cần thao tác đặc biệt

Các biến thể của Batch Normalization

1. Layer Normalization

Layer Normalization chuẩn hóa tất cả các đơn vị trong một lớp cho mỗi mẫu đầu vào. Điều này rất hữu ích trong RNN và Transformer, nơi độ dài chuỗi có thể thay đổi.

# PyTorch
nn.LayerNorm(normalized_shape)

# TensorFlow
tf.keras.layers.LayerNormalization()

2. Instance Normalization

Instance Normalization chuẩn hóa mỗi kênh đặc trưng cho mỗi mẫu riêng lẻ. Thường được sử dụng trong các tác vụ chuyển đổi hình ảnh như style transfer.

# PyTorch
nn.InstanceNorm2d(num_features)

# TensorFlow
tfa.layers.InstanceNormalization()  # Từ TensorFlow Addons

3. Group Normalization

Group Normalization chia các kênh thành các nhóm và chuẩn hóa trong mỗi nhóm. Phương pháp này hoạt động tốt với batch size nhỏ.

# PyTorch
nn.GroupNorm(num_groups, num_channels)

# TensorFlow
tfa.layers.GroupNormalization(groups=32)  # Từ TensorFlow Addons

Batch Normalization trong các kiến trúc hiện đại

ResNet

class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = nn.functional.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual
        out = nn.functional.relu(out)
        return out

EfficientNet/MobileNet

BN rất quan trọng trong các mạng nhỏ gọn như EfficientNet và MobileNet, giúp chúng học hiệu quả dù có ít tham số.

Hiểu sâu hơn về Batch Normalization

Tranh luận về Internal Covariate Shift

Bài báo gốc cho rằng BN giải quyết vấn đề Internal Covariate Shift. Tuy nhiên, một nghiên cứu năm 2018 "How Does Batch Normalization Help Optimization?" lập luận rằng lợi ích chính của BN không phải là giảm covariate shift mà là làm mịn bề mặt tối ưu hóa, làm cho các gradient có độ lớn hợp lý hơn và ổn định hơn.

Tác động đến bề mặt mất mát

BN làm cho bề mặt mất mát mượt mà hơn và ít biến đổi hơn đối với các thay đổi tham số. Điều này cho phép tối ưu hóa hiệu quả hơn và học với learning rate cao hơn.

Quan hệ với các kỹ thuật regularization

Mặc dù BN có tác dụng regularization, nhưng nó không thay thế hoàn toàn các kỹ thuật như Dropout. Trên thực tế, việc kết hợp BN với Dropout cần được thực hiện cẩn thận vì chúng có thể tương tác không mong muốn.

Các vấn đề và thách thức với Batch Normalization

1. Phụ thuộc vào batch size

BN hoạt động tốt nhất với batch size lớn (32, 64 hoặc lớn hơn). Với batch size nhỏ, thống kê trên mini-batch không đáng tin cậy, dẫn đến hiệu suất kém. Trong trường hợp này, Layer Normalization hoặc Group Normalization có thể là lựa chọn tốt hơn.

2. Tính toán phức tạp

BN tăng chi phí tính toán và bộ nhớ, đặc biệt là trong mạng rất sâu.

3. Không phù hợp với một số tác vụ

BN có thể không phù hợp cho các tác vụ như online learning (học trực tuyến) nơi dữ liệu đến từng mẫu một.

Hướng dẫn thực hành

1. Vị trí đặt Batch Normalization

# Thông thường: Conv/FC -> BN -> ReLU
x = self.conv(x)
x = self.bn(x)
x = F.relu(x)

# ResNet: Conv -> BN -> ReLU -> Conv -> BN -> Add -> ReLU
out = self.conv1(x)
out = self.bn1(out)
out = F.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out += residual
out = F.relu(out)

2. Batch Normalization với Transfer Learning

Khi fine-tuning mô hình pre-trained với BN:

# Đóng băng các lớp BN trong backbone
for module in backbone.modules():
    if isinstance(module, nn.BatchNorm2d):
        module.eval()  # Sử dụng running statistics
        module.weight.requires_grad = False
        module.bias.requires_grad = False

3. Xử lý Batch Normalization trong Recurrent Neural Networks

BN có thể gây khó khăn trong RNN do độ dài chuỗi thay đổi. Layer Normalization thường là lựa chọn tốt hơn:

# PyTorch
class LSTM_LN(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(LSTM_LN, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size)
        self.ln = nn.LayerNorm(hidden_size)
        
    def forward(self, x, hx=None):
        output, (h_n, c_n) = self.lstm(x, hx)
        output = self.ln(output)
        return output, (h_n, c_n)

Kết luận

Batch Normalization là một trong những đột phá quan trọng nhất trong lĩnh vực deep learning, giúp huấn luyện mạng neural nhanh hơn và ổn định hơn. Mặc dù cơ chế chính xác giải thích tại sao BN hiệu quả vẫn đang được nghiên cứu, nhưng lợi ích thực tế của nó đã được chứng minh rõ ràng qua vô số ứng dụng.

Các biến thể như Layer Normalization, Instance Normalization và Group Normalization mở rộng ý tưởng chuẩn hóa để phù hợp với các kiến trúc và tác vụ khác nhau. Hiểu rõ về các kỹ thuật chuẩn hóa này là điều cần thiết để thiết kế và huấn luyện hiệu quả các mạng neural hiện đại.

Last updated