Батч-нормализация
Коротко
Definition
Батч-нормализация — это слой нейронной сети, который нормализует активации внутри mini-batch, а затем обучаемо масштабирует и сдвигает их с помощью параметров и .
Batch Normalization, или BatchNorm, обычно используется для стабилизации и ускорения обучения глубоких нейронных сетей.
Идея: привести активации слоя к более стабильному распределению, чтобы следующие слои не получали входы с резко меняющимся масштабом.
Упрощённо:
где:
- — исходные активации;
- — нормализованные активации;
- — результат после обучаемого scale и shift.
Зачем нужен
Батч-нормализация помогает сделать обучение более устойчивым.
Она может:
- уменьшать чувствительность к инициализации весов;
- ускорять обучение;
- стабилизировать распределения активаций;
- позволять использовать более высокий learning rate;
- частично работать как регуляризация;
- улучшать прохождение градиентов в глубоких сетях.
Исторически BatchNorm объясняли как способ борьбы с Internal Covariate Shift: входы слоёв во время обучения постоянно меняются, потому что параметры предыдущих слоёв обновляются.
Современное понимание осторожнее: BatchNorm действительно стабилизирует обучение, но его польза не сводится только к Internal Covariate Shift. Практически важно, что слой делает оптимизацию более удобной и менее нестабильной.
Как работает
Пусть есть mini-batch:
где — размер batch.
BatchNorm выполняет четыре шага.
1. Среднее batch
Это среднее значение активаций внутри текущего batch.
2. Дисперсия batch
Это дисперсия активаций внутри текущего batch.
3. Нормализация
После этого активации внутри batch примерно имеют:
Параметр нужен для численной устойчивости, чтобы не делить на ноль.
4. Scale и shift
где:
- — обучаемый параметр масштаба;
- — обучаемый параметр сдвига.
Итоговая запись:
Обучаемые параметры
У BatchNorm есть два обучаемых параметра:
- — scale;
- — shift.
Они нужны, потому что не всегда оптимально оставлять активации строго с нулевым средним и единичной дисперсией.
Important
Если задаче нужен специфический масштаб или сдвиг, модель может восстановить их через и . BatchNorm делает параметры распределения активаций управляемыми, а не случайными.
Во время обратного распространения ошибки обновляются:
- параметры обычных слоёв;
- ;
- .
Статистики batch — среднее и дисперсия — сами по себе не являются обучаемыми параметрами, но во время обучения слой также накапливает running mean и running variance для инференса.
Режим обучения и режим инференса
BatchNorm ведёт себя по-разному во время обучения и во время инференса.
Training mode
Во время обучения BatchNorm считает среднее и дисперсию по текущему mini-batch:
Также он обновляет скользящие оценки:
- running mean;
- running variance.
Inference mode
Во время инференса BatchNorm обычно не считает статистики по текущему batch. Вместо этого используются running mean и running variance, накопленные во время обучения.
Tip
При инференсе мы больше не рассчитываем среднее и дисперсию batch, поданного в модель. Вместо этого применяются скользящие оценки статистик обучающего датасета.
Это важно: если забыть перевести модель в inference/eval mode, качество может стать нестабильным.
Что происходит между batch
Внутри одного batch нормализация сохраняет относительную структуру объектов, но между batch статистики могут немного плавать.
Чем меньше batch, тем выше шум в оценках среднего и дисперсии. Иногда этот шум полезен: он работает как регуляризация и заставляет модель искать более устойчивые признаки.
Но слишком маленький batch может сделать BatchNorm нестабильным. В таких случаях часто используют другие нормализации:
- LayerNorm;
- GroupNorm;
- InstanceNorm;
- running statistics с осторожной настройкой.
Где используется
BatchNorm часто используется:
- в CNN;
- в глубоких полносвязных сетях;
- в residual networks со Skip connection;
- в некоторых генеративных моделях;
- в классификаторах изображений;
- в segmentation-моделях;
- в старых или классических deep learning архитектурах.
В трансформерах чаще используют LayerNorm, а не BatchNorm, потому что трансформеры обычно работают с последовательностями, переменной длиной и другими режимами batching.
Связанные архитектуры
BatchNorm часто встречается рядом с:
- свёрточными слоями;
- функциями активации;
- pooling;
- Skip connection;
- CNN;
- residual blocks.
Типичный блок в CNN может выглядеть так:
Иногда порядок может отличаться, но такая схема очень распространена.
Связь с регуляризацией
BatchNorm не является регуляризацией в таком же прямом смысле, как L2 или dropout, но может давать регуляризующий эффект.
Причина: статистики mini-batch немного шумные. Один и тот же объект может нормализоваться чуть по-разному в зависимости от других объектов в batch.
Этот шум мешает модели слишком сильно полагаться на точные абсолютные значения активаций.
Тем не менее BatchNorm не стоит считать полной заменой регуляризации. В сложных моделях всё равно могут понадобиться:
- weight decay;
- dropout;
- аугментации;
- early stopping;
- правильная validation-схема.
Типичные ошибки понимания
Думать, что BatchNorm просто нормализует входные данные
BatchNorm нормализует не сырые признаки датасета, а активации внутри нейронной сети.
Для нормализации исходных признаков см. Предобработка данных.
Забывать про training/eval mode
Во время обучения BatchNorm использует статистики mini-batch. Во время инференса — running statistics.
Если модель оставить в training mode на инференсе, предсказания могут зависеть от состава batch.
Использовать BatchNorm при слишком маленьком batch
При маленьком batch оценки среднего и дисперсии становятся шумными. Это может ухудшить качество и стабильность обучения.
Считать и лишними
Без и BatchNorm всегда заставлял бы активации иметь фиксированный масштаб и сдвиг. Обучаемые scale и shift позволяют модели восстановить нужное распределение.
Считать BatchNorm универсально лучшим выбором
BatchNorm хорошо работает во многих CNN, но не всегда подходит для трансформеров, маленьких batch, autoregressive inference и некоторых последовательных задач.
Минимальный пример
Пусть в batch есть четыре значения активации одного нейрона:
| Объект | Активация |
|---|---|
| 1 | 2 |
| 2 | 4 |
| 3 | 6 |
| 4 | 8 |
Среднее:
Дисперсия:
Нормализованное значение для первого объекта:
После этого применяется scale и shift:
Если и , выход остаётся просто нормализованным. Если модель обучит другие значения и , она изменит масштаб и сдвиг нормализованных активаций.
Практические замечания
При использовании BatchNorm важно помнить:
- Для train и inference используются разные статистики.
- Маленький batch может ухудшить стабильность.
- BatchNorm часто ставят после linear/conv слоя и перед activation.
- В CNN BatchNorm обычно работает хорошо.
- В transformer-based моделях чаще используют LayerNorm.
- BatchNorm может уменьшать потребность в некоторых видах регуляризации, но не заменяет их полностью.
- При transfer learning нужно аккуратно решать, обновлять ли BatchNorm-статистики.
В задачах с маленькими batch или сильно разными объектами иногда лучше использовать GroupNorm или LayerNorm.