Кросс-энтропия
Коротко
Definition
Кросс-энтропия — это функция потерь, которая штрафует модель за низкую вероятность, присвоенную правильному классу.
В задачах классификации модель обычно выдаёт вероятности классов. Кросс-энтропия смотрит только на вероятность правильного класса и применяет к ней отрицательный логарифм:
Если модель уверенно дала правильному классу высокую вероятность, loss маленький. Если правильному классу дана низкая вероятность, loss большой.
Кросс-энтропия тесно связана с максимальным правдоподобием: минимизация cross-entropy эквивалентна максимизации log-likelihood на обучающих данных.
Интуиция
Представим, что правильный класс — cat.
Модель может сказать:
| Вероятность правильного класса | Loss |
|---|---|
| Очень маленький | |
| Небольшой | |
| Большой | |
| Очень большой |
Кросс-энтропия не просто проверяет, угадала модель класс или нет. Она учитывает уверенность модели.
Если модель дала правильный ответ, но была неуверенной, loss будет больше, чем у уверенной правильной модели.
Если модель уверенно ошиблась, loss будет очень большим.
Именно поэтому cross-entropy хорошо подходит для обучения классификаторов: она заставляет модель не только выбирать правильный класс, но и распределять вероятности разумно.
Основные идеи
Общая формула
Для независимых обучающих объектов:
общая cross-entropy записывается так:
где:
- — вероятность правильного класса по мнению модели;
- — параметры модели;
- — количество объектов.
Бинарная кросс-энтропия
Для бинарной классификации:
модель обычно предсказывает вероятность положительного класса:
Тогда binary cross-entropy:
Эта формула просто записывает два случая без if.
Если :
Если :
Многоклассовая кросс-энтропия
Для многоклассовой классификации с классами модель выдаёт распределение вероятностей:
Если истинная метка записана в one-hot формате, то cross-entropy:
Так как в one-hot векторе только один элемент равен 1, сумма фактически выбирает вероятность правильного класса:
где — индекс правильного класса.
Связь с максимальным правдоподобием
Правдоподобие обучающей выборки:
Логарифм правдоподобия:
Максимизация log-likelihood эквивалентна минимизации отрицательного среднего log-likelihood:
Это и есть cross-entropy loss для классификации.
Почему используют логарифм:
- произведение многих вероятностей быстро становится численно очень маленьким;
- логарифм превращает произведение в сумму;
- отрицательный знак превращает задачу максимизации в задачу минимизации;
- градиенты становятся удобнее для оптимизации.
Крайние случаи
Идеальный случай:
Тогда:
Худший случай:
Тогда:
Промежуточные примеры:
Чем меньше вероятность правильного класса, тем сильнее штраф.
Когда использовать
Кросс-энтропию используют, когда:
- задача является классификацией;
- модель выдаёт вероятности классов или logits;
- нужно обучить логистическую регрессию;
- нужно обучить нейросетевой классификатор;
- используется softmax для многоклассовой классификации;
- используется sigmoid для бинарной или multilabel-классификации;
- обучается языковая модель предсказывать следующий токен.
Типичные случаи:
- Логистическая регрессия;
- image classification;
- text classification;
- token classification;
- language modeling;
- transformer-based models;
- multiclass и multilabel классификация.
Когда не использовать
Кросс-энтропия не подходит напрямую, если:
- задача является обычной регрессией;
- целевая переменная непрерывная;
- модель не выдаёт вероятностную интерпретацию;
- классы размечены очень шумно;
- требуется оптимизировать не вероятность класса, а ranking-метрику;
- важна специальная стоимость разных типов ошибок.
В регрессии обычно используют MSE, MAE, Huber loss или другие функции потерь.
При сильном дисбалансе классов обычная cross-entropy может быть недостаточной. Тогда используют:
- class weights;
- focal loss;
- resampling;
- threshold tuning;
- специальные метрики вроде PR-AUC.
Минимальный пример
import math
prob_correct_class = 0.7
loss = -math.log(prob_correct_class)
print(loss)Пример для PyTorch:
import torch
from torch import nn
logits = torch.tensor([
[2.0, 0.5, 0.1],
[0.2, 1.5, 0.3],
])
target = torch.tensor([0, 1])
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, target)
print(loss.item())Важно: в PyTorch CrossEntropyLoss ожидает logits, а не вероятности после softmax. Softmax уже встроен внутрь функции потерь в численно стабильном виде.