Кросс-энтропия

Коротко

Definition

Кросс-энтропия — это функция потерь, которая штрафует модель за низкую вероятность, присвоенную правильному классу.

В задачах классификации модель обычно выдаёт вероятности классов. Кросс-энтропия смотрит только на вероятность правильного класса и применяет к ней отрицательный логарифм:

Если модель уверенно дала правильному классу высокую вероятность, loss маленький. Если правильному классу дана низкая вероятность, loss большой.

Кросс-энтропия тесно связана с максимальным правдоподобием: минимизация cross-entropy эквивалентна максимизации log-likelihood на обучающих данных.

Интуиция

Представим, что правильный класс — cat.

Модель может сказать:

Вероятность правильного классаLoss
Очень маленький
Небольшой
Большой
Очень большой

Кросс-энтропия не просто проверяет, угадала модель класс или нет. Она учитывает уверенность модели.

Если модель дала правильный ответ, но была неуверенной, loss будет больше, чем у уверенной правильной модели.

Если модель уверенно ошиблась, loss будет очень большим.

Именно поэтому cross-entropy хорошо подходит для обучения классификаторов: она заставляет модель не только выбирать правильный класс, но и распределять вероятности разумно.

Основные идеи

Общая формула

Для независимых обучающих объектов:

общая cross-entropy записывается так:

где:

  • — вероятность правильного класса по мнению модели;
  • — параметры модели;
  • — количество объектов.

Бинарная кросс-энтропия

Для бинарной классификации:

модель обычно предсказывает вероятность положительного класса:

Тогда binary cross-entropy:

Эта формула просто записывает два случая без if.

Если :

Если :

Многоклассовая кросс-энтропия

Для многоклассовой классификации с классами модель выдаёт распределение вероятностей:

Если истинная метка записана в one-hot формате, то cross-entropy:

Так как в one-hot векторе только один элемент равен 1, сумма фактически выбирает вероятность правильного класса:

где — индекс правильного класса.

Связь с максимальным правдоподобием

Правдоподобие обучающей выборки:

Логарифм правдоподобия:

Максимизация log-likelihood эквивалентна минимизации отрицательного среднего log-likelihood:

Это и есть cross-entropy loss для классификации.

Почему используют логарифм:

  • произведение многих вероятностей быстро становится численно очень маленьким;
  • логарифм превращает произведение в сумму;
  • отрицательный знак превращает задачу максимизации в задачу минимизации;
  • градиенты становятся удобнее для оптимизации.

Крайние случаи

Идеальный случай:

Тогда:

Худший случай:

Тогда:

Промежуточные примеры:

Чем меньше вероятность правильного класса, тем сильнее штраф.

Когда использовать

Кросс-энтропию используют, когда:

  • задача является классификацией;
  • модель выдаёт вероятности классов или logits;
  • нужно обучить логистическую регрессию;
  • нужно обучить нейросетевой классификатор;
  • используется softmax для многоклассовой классификации;
  • используется sigmoid для бинарной или multilabel-классификации;
  • обучается языковая модель предсказывать следующий токен.

Типичные случаи:

Когда не использовать

Кросс-энтропия не подходит напрямую, если:

  • задача является обычной регрессией;
  • целевая переменная непрерывная;
  • модель не выдаёт вероятностную интерпретацию;
  • классы размечены очень шумно;
  • требуется оптимизировать не вероятность класса, а ranking-метрику;
  • важна специальная стоимость разных типов ошибок.

В регрессии обычно используют MSE, MAE, Huber loss или другие функции потерь.

При сильном дисбалансе классов обычная cross-entropy может быть недостаточной. Тогда используют:

  • class weights;
  • focal loss;
  • resampling;
  • threshold tuning;
  • специальные метрики вроде PR-AUC.

Минимальный пример

import math
 
prob_correct_class = 0.7
loss = -math.log(prob_correct_class)
 
print(loss)

Пример для PyTorch:

import torch
from torch import nn
 
logits = torch.tensor([
    [2.0, 0.5, 0.1],
    [0.2, 1.5, 0.3],
])
 
target = torch.tensor([0, 1])
 
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, target)
 
print(loss.item())

Важно: в PyTorch CrossEntropyLoss ожидает logits, а не вероятности после softmax. Softmax уже встроен внутрь функции потерь в численно стабильном виде.

Связанные понятия

Что знать перед этим

Связанные заметки