Задачи классификации

Коротко

Definition

Задача классификации — это задача машинного обучения, в которой модель должна отнести объект к одному из заранее заданных классов.

В классификации целевая переменная является категориальной. Модель не предсказывает произвольное число, как в регрессии, а выбирает класс: например, spam / not spam, cat / dog, disease / healthy, fraud / normal.

Классификация используется, когда нужно ответить на вопрос: к какой категории относится объект.

Что требуется предсказать

В задаче классификации модель получает объект и должна предсказать его класс .

Формально:

где:

  • — входной объект;
  • — предсказанный класс;
  • — истинный класс.

Если классов два, задача называется бинарной классификацией. Если классов больше двух — многоклассовой классификацией. Если один объект может одновременно принадлежать нескольким классам, говорят о multilabel classification.

Примеры задач

Бинарная классификация:

  • письмо является спамом или не спамом;
  • транзакция мошенническая или нормальная;
  • пациент болен или здоров;
  • материал подходит или не подходит под заданный критерий.

Многоклассовая классификация:

  • распознать цифру от 0 до 9;
  • определить породу животного на изображении;
  • классифицировать тип дефекта;
  • определить тему текста.

Multilabel classification:

  • присвоить статье несколько тематических тегов;
  • определить несколько объектов на изображении;
  • отметить несколько токсичных свойств комментария.

Типы входов и выходов

Входы могут быть разными:

Выход модели обычно представлен одним из двух способов:

  • метка класса, например cat;
  • вероятности по классам, например P(cat) = 0.82, P(dog) = 0.18.

На практике вероятности часто важнее самой метки, потому что позволяют управлять порогом принятия решения.

Подходящие модели

Для табличных данных часто используют:

Для изображений часто используют:

Для текстов и последовательностей часто используют:

Выбор модели зависит от типа данных, размера датасета, требований к интерпретируемости и цены ошибки.

Функции потерь

Самая типичная функция потерь для классификации — Кросс-энтропия.

Для бинарной классификации часто используется binary cross-entropy. Для многоклассовой классификации — categorical cross-entropy или softmax cross-entropy.

Интуиция простая: функция потерь штрафует модель, если она присваивает низкую вероятность правильному классу.

Например, если истинный класс — cat, то предсказание:

  • P(cat) = 0.95 — хорошее;
  • P(cat) = 0.55 — неуверенное;
  • P(cat) = 0.01 — плохое.

Метрики оценки

Классификацию нельзя оценивать одной универсальной метрикой. Нужно учитывать баланс классов, цену ложных срабатываний и цену пропусков.

Основные метрики:

  • accuracy — доля правильных ответов;
  • precision — насколько часто положительные предсказания действительно верны;
  • recall — насколько хорошо модель находит все положительные объекты;
  • F1-score — баланс precision и recall;
  • ROC-AUC — качество ранжирования объектов по вероятности положительного класса.

Подробно эти метрики разобраны в Метрики качества классификаторов и ROC-кривая.

Типичные ошибки

Смотреть только на accuracy

Accuracy может быть обманчивой при дисбалансе классов.

Если 99% объектов относятся к классу normal, модель может всегда предсказывать normal и получать accuracy 99%, хотя она вообще не находит редкий важный класс.

Путать precision и recall

Precision отвечает на вопрос: насколько можно доверять положительным предсказаниям модели.

Recall отвечает на вопрос: сколько реальных положительных объектов модель смогла найти.

Например, в медицинской диагностике часто важен высокий recall, чтобы не пропустить больных пациентов. В антиспам-системе может быть важен precision, чтобы не отправлять нормальные письма в спам.

Не настраивать threshold

Многие модели возвращают вероятность, а класс получается после выбора порога.

Например:

Но порог 0.5 не всегда оптимален. Если цена ошибки разная, порог нужно подбирать под задачу.

Игнорировать калибровку вероятностей

Модель может хорошо разделять классы, но плохо оценивать вероятности.

Например, если модель говорит P = 0.9, это не всегда значит, что событие действительно происходит примерно в 90% таких случаев. Для задач, где важны вероятности, нужна проверка калибровки.

Допускать data leakage

Data leakage возникает, когда в обучающие признаки случайно попадает информация, которая в реальном применении модели будет недоступна.

Это может дать очень высокое качество на тесте, но плохую работу в реальности.

Минимальный пример

Допустим, нужно классифицировать письма на два класса:

ПисьмоПризнакиИстинный классПредсказанная вероятность spam
1много ссылок, подозрительная темаspam0.97
2обычная перепискаnot spam0.08
3рекламное письмоspam0.71
4письмо от коллегиnot spam0.42

Если выбрать threshold 0.5, модель предскажет:

ПисьмоВероятность spamПредсказанный класс
10.97spam
20.08not spam
30.71spam
40.42not spam

В этом примере все ответы правильные. Но если для четвёртого письма вероятность была бы 0.55, модель отправила бы нормальное письмо в спам. Поэтому в реальных задачах важно не только обучить модель, но и выбрать подходящий порог.

Практические замечания

Хороший workflow для классификации:

  1. Проверить баланс классов.
  2. Сделать корректное разделение на train, validation и test.
  3. Построить простой baseline.
  4. Обучить модель.
  5. Оценить не только accuracy, но и precision, recall, F1-score, ROC-AUC.
  6. Посмотреть confusion matrix.
  7. Подобрать threshold под практическую задачу.
  8. Проверить ошибки на важных подгруппах данных.

Во многих задачах важнее не максимальная средняя метрика, а контроль конкретного типа ошибок.

Связанные понятия