Задачи классификации
Коротко
Definition
Задача классификации — это задача машинного обучения, в которой модель должна отнести объект к одному из заранее заданных классов.
В классификации целевая переменная является категориальной. Модель не предсказывает произвольное число, как в регрессии, а выбирает класс: например, spam / not spam, cat / dog, disease / healthy, fraud / normal.
Классификация используется, когда нужно ответить на вопрос: к какой категории относится объект.
Что требуется предсказать
В задаче классификации модель получает объект и должна предсказать его класс .
Формально:
где:
- — входной объект;
- — предсказанный класс;
- — истинный класс.
Если классов два, задача называется бинарной классификацией. Если классов больше двух — многоклассовой классификацией. Если один объект может одновременно принадлежать нескольким классам, говорят о multilabel classification.
Примеры задач
Бинарная классификация:
- письмо является спамом или не спамом;
- транзакция мошенническая или нормальная;
- пациент болен или здоров;
- материал подходит или не подходит под заданный критерий.
Многоклассовая классификация:
- распознать цифру от 0 до 9;
- определить породу животного на изображении;
- классифицировать тип дефекта;
- определить тему текста.
Multilabel classification:
- присвоить статье несколько тематических тегов;
- определить несколько объектов на изображении;
- отметить несколько токсичных свойств комментария.
Типы входов и выходов
Входы могут быть разными:
- Табличные данные — признаки клиента, пациента, материала или объекта;
- Изображения — фотографии, медицинские снимки, микроскопия;
- Последовательности — текст, временные ряды, сигналы;
- Графовые данные — молекулы, кристаллы, социальные связи.
Выход модели обычно представлен одним из двух способов:
- метка класса, например
cat; - вероятности по классам, например
P(cat) = 0.82,P(dog) = 0.18.
На практике вероятности часто важнее самой метки, потому что позволяют управлять порогом принятия решения.
Подходящие модели
Для табличных данных часто используют:
Для изображений часто используют:
Для текстов и последовательностей часто используют:
Выбор модели зависит от типа данных, размера датасета, требований к интерпретируемости и цены ошибки.
Функции потерь
Самая типичная функция потерь для классификации — Кросс-энтропия.
Для бинарной классификации часто используется binary cross-entropy. Для многоклассовой классификации — categorical cross-entropy или softmax cross-entropy.
Интуиция простая: функция потерь штрафует модель, если она присваивает низкую вероятность правильному классу.
Например, если истинный класс — cat, то предсказание:
P(cat) = 0.95— хорошее;P(cat) = 0.55— неуверенное;P(cat) = 0.01— плохое.
Метрики оценки
Классификацию нельзя оценивать одной универсальной метрикой. Нужно учитывать баланс классов, цену ложных срабатываний и цену пропусков.
Основные метрики:
- accuracy — доля правильных ответов;
- precision — насколько часто положительные предсказания действительно верны;
- recall — насколько хорошо модель находит все положительные объекты;
- F1-score — баланс precision и recall;
- ROC-AUC — качество ранжирования объектов по вероятности положительного класса.
Подробно эти метрики разобраны в Метрики качества классификаторов и ROC-кривая.
Типичные ошибки
Смотреть только на accuracy
Accuracy может быть обманчивой при дисбалансе классов.
Если 99% объектов относятся к классу normal, модель может всегда предсказывать normal и получать accuracy 99%, хотя она вообще не находит редкий важный класс.
Путать precision и recall
Precision отвечает на вопрос: насколько можно доверять положительным предсказаниям модели.
Recall отвечает на вопрос: сколько реальных положительных объектов модель смогла найти.
Например, в медицинской диагностике часто важен высокий recall, чтобы не пропустить больных пациентов. В антиспам-системе может быть важен precision, чтобы не отправлять нормальные письма в спам.
Не настраивать threshold
Многие модели возвращают вероятность, а класс получается после выбора порога.
Например:
Но порог 0.5 не всегда оптимален. Если цена ошибки разная, порог нужно подбирать под задачу.
Игнорировать калибровку вероятностей
Модель может хорошо разделять классы, но плохо оценивать вероятности.
Например, если модель говорит P = 0.9, это не всегда значит, что событие действительно происходит примерно в 90% таких случаев. Для задач, где важны вероятности, нужна проверка калибровки.
Допускать data leakage
Data leakage возникает, когда в обучающие признаки случайно попадает информация, которая в реальном применении модели будет недоступна.
Это может дать очень высокое качество на тесте, но плохую работу в реальности.
Минимальный пример
Допустим, нужно классифицировать письма на два класса:
| Письмо | Признаки | Истинный класс | Предсказанная вероятность spam |
|---|---|---|---|
| 1 | много ссылок, подозрительная тема | spam | 0.97 |
| 2 | обычная переписка | not spam | 0.08 |
| 3 | рекламное письмо | spam | 0.71 |
| 4 | письмо от коллеги | not spam | 0.42 |
Если выбрать threshold 0.5, модель предскажет:
| Письмо | Вероятность spam | Предсказанный класс |
|---|---|---|
| 1 | 0.97 | spam |
| 2 | 0.08 | not spam |
| 3 | 0.71 | spam |
| 4 | 0.42 | not spam |
В этом примере все ответы правильные. Но если для четвёртого письма вероятность была бы 0.55, модель отправила бы нормальное письмо в спам. Поэтому в реальных задачах важно не только обучить модель, но и выбрать подходящий порог.
Практические замечания
Хороший workflow для классификации:
- Проверить баланс классов.
- Сделать корректное разделение на train, validation и test.
- Построить простой baseline.
- Обучить модель.
- Оценить не только accuracy, но и precision, recall, F1-score, ROC-AUC.
- Посмотреть confusion matrix.
- Подобрать threshold под практическую задачу.
- Проверить ошибки на важных подгруппах данных.
Во многих задачах важнее не максимальная средняя метрика, а контроль конкретного типа ошибок.