Разбор FixMatch: как обучить нейросеть на 250 примерах вместо 50 тысяч

Yannic Kilcher 20,3 тыс. 20 мин 6 мин 15.04.2020
Главное

В новом видеоролике популярный ИИ-блогер Янник Кильхер (Yannic Kilcher) разбирает научную статью исследователей из Google Research, посвященную алгоритму FixMatch. Этот метод значительно упрощает полуконтролируемое обучение (semi-supervised learning), эффективно сочетая в себе две классические концепции: регуляризацию согласованности и псевдоразметку. Ведущий подробно анализирует архитектуру данного подхода, приводит впечатляющие результаты тестов на известном датасете CIFAR-10, а также делится долей здорового скепсиса относительно реальных источников высокой эффективности алгоритма.

🧩 Суть полуконтролируемого обучения и проблема разметки данных 0:00

В классических задачах машинного обучения специалисты часто сталкиваются с нехваткой размеченных данных. В полуконтролируемом обучении у исследователя есть два набора данных: очень маленький массив размеченных примеров (где для каждого объекта $X$ известна метка $Y$) и огромный массив неразмеченных данных (содержащий только объекты $X$). Главная цель этого подхода — использовать гигантский объем неразмеченной информации, чтобы помочь модели лучше уловить скрытые взаимосвязи и закономерности.

Для иллюстрации проблемы Янник Кильхер приводит понятный пример из медицинской сферы — классификацию снимков легких на предмет наличия опухолей.

Сбор размеченных медицинских данных сопряжен со следующими трудностями:

При этом в базах данных любой больницы можно найти огромное количество депонированных, анонимизированных, но недиагностированных снимков. Аналогичная ситуация наблюдается и в обычном интернете: ручная разметка картинок людьми требует колоссальных трудозатрат, в то время как сеть переполнена неразмещенным графическим контентом. Задача FixMatch — заставить эти «сырые» данные работать на благо точности финальной модели.

⚙️ Архитектура FixMatch: синергия двух подходов 2:21

Алгоритм FixMatch не изобретает компоненты с нуля, а элегантно объединяет два уже известных в индустрии метода:

  1. Регуляризация согласованности (consistency regularization): концепция, согласно которой модель должна выдавать похожие предсказания, если ей на вход подаются слегка измененные (пертурбированные) версии одного и того же изображения.
  2. Псевдоразметка (pseudo-labeling): идея использования самой нейросети для генерации искусственных меток класса для неразмеченных объектов.

Общая функция потерь ($Loss$), которую оптимизирует FixMatch, математически выглядит как сумма двух составляющих:

$$Loss = Loss_{supervised} + \lambda Loss_{unsupervised}$$

Здесь $Loss_{supervised}$ представляет собой стандартную кросс-энтропию для размеченной выборки, а $Loss_{unsupervised}$ — вспомогательную потерю для неразмеченных данных, перед которой стоит коэффициент баланса $\lambda$. Вся магия алгоритма, как подчеркивает Янник Кильхер, скрыта именно внутри вычисления несупервизируемой потери.

🔄 Две параллельные линии: слабая и сильная аугментация 3:52

Процесс обработки неразмеченного изображения в FixMatch разделяется на два параллельных конвейера (пайплайна). Один и тот же исходный снимок подвергается двум разным типам трансформации (аугментации).

Конвейер слабой аугментации (Weak Augmentation)

В этой ветке изображение изменяется незначительно. Исторически аугментация использовалась как легальный «чит» для искусственного расширения тренировочного датасета. Например, если взять фотографию кошки, случайно отрезать край (random crop) и растянуть его до исходного размера, картинка все равно останется кошкой. Нейросеть же воспринимает это как совершенно новый пиксельный паттерн, что помогает ей лучше обобщать данные на тестах. В пайплайне слабой аугментации FixMatch применяются:

Обработанное таким образом изображение пропускается через текущую версию модели, и на выходе получается распределение вероятностей классов. Исследователи берут класс с максимальным значением вероятности и превращают его в жесткую метку — так называемый «псевдолейбл» ($\hat{y}$). При этом в FixMatch внедрено важное условие: псевдометка принимается в расчет только тогда, когда уверенность модели ($P(Y)$) превышает жестко заданный порог. Если сеть сомневается, этот пример просто отбрасывается.

Конвейер сильной аугментации (Strong Augmentation)

Параллельно исходное изображение направляется во вторую ветку, где его буквально «изувечивают». Цель сильной аугментации — исказить картинку до предела, но так, чтобы человек все еще мог распознать исходный объект (например, лошадь). В этом пайплайне алгоритмы (такие как CTAugment и Cutout) ведут себя агрессивно:

Сильно искаженное изображение также отправляется в нейросеть для получения прогноза.

В чем главный трюк?

Главная хитрость FixMatch заключается в том, что псевдометка, полученная из «слабой» ветки, начинает принудительно считаться истинной меткой (ground truth) для «сильной» ветки. Модель штрафуют, если ее предсказание на сильно искаженном кадре не совпадает с псевдометкой, выданной на слабо измененном кадре.

По мнению Янника Кильхера, в базовом понимании этот процесс можно свести к сравнению абсолютно чистого изображения с зашумленным. В процессе обучения нейросеть фактически учится игнорировать сильные визуальные искажения. Она понимает, что какие бы безумные цветовые фильтры ни накладывались на картинку, ее финальный класс должен оставаться неизменным.

📈 Фантастические результаты на CIFAR-10 12:00

Эксперименты авторов статьи на популярном бенчмарке CIFAR-10 продемонстрировали результаты, которые Янник Кильхер называет «невероятными» и «сумасшедшими». Для контекста: стандартный объем обучающей выборки CIFAR-10 составляет 50 000 изображений, а самый передовой уровень точности (State of the Art) на полной выборке колеблется в районе 96–97%.

FixMatch показывает следующие результаты при сокращении числа размеченных данных:

Ведущий делает ремарку, что для теста с одним изображением на класс авторы, скорее всего, специально отобрали наиболее репрезентативные, «идеальные» образцы категорий. Тем не менее, удержание точности на уровне 78% всего по одной картинке — это потрясающий показатель.

🔬 Абляционное исследование и «подводные камни» гиперпараметров 15:49

Чтобы докопаться до истины и понять, за счет чего FixMatch работает так хорошо, разработчики провели абляционное исследование (ablation study), поочередно отключая различные функции. Выяснилось, что дьявол кроется в деталях и тонких настройках.

В ходе тестов были определены критически важные факторы успеха:

🤔 Критический взгляд: прорыв или победа грубой силы? 18:11

В финальной части обзора Янник Кильхер переходит к анализу скрытых проблем представленной научной работы. По мнению ведущего, сильная чувствительность алгоритма к гиперпараметрам делает подобные исследования немного сомнительными («sketchy»). Авторы статьи открыто заявляют, что подбор параметра затухания весов (weight decay) в условиях жесткого дефицита меток имеет колоссальное значение. Ошибка в выборе оптимального значения weight decay всего на один порядок в большую или меньшую сторону может мгновенно обрушить точность нейросети на 10 и более процентных пунктов.

Янник Кильхер подчеркивает парадокс: академическое сообщество увлеченно борется за прирост точности в 0,5% или 1%, изобретая новые «изящные архитектуры». При этом один неверный шаг в настройке базового гиперпараметра способен отнять у модели в 10 раз больше, чем дает этот теоретический прорыв.

В итоге ведущий заявляет, что хотя авторы FixMatch и получили беспрецедентные цифры точности, до конца неясно: является ли этот успех следствием действительно гениального архитектурного решения или же это просто результат «закидывания задачи деньгами» — то есть использования огромных вычислительных мощностей Google для идеального перебора и подгонки скрытых параметров.

💬 Цитаты

«Это невероятно — получить почти 95% точности всего со всеми 250 размеченными примерами, когда обычный датасет требует 50 тысяч.»

Янник Кильхер 12:53

«Такого рода исследования, где вы бьетесь за доли процента точности, в то время как один неверный шаг в выборе гиперпараметра стоит вам 10%, выглядят немного сомнительно.»

Янник Кильхер 19:29
👥 Спикер
🔗 Упомянутые сайты и проекты
📖 Термины
Полуконтролируемое обучение
Метод машинного обучения, использующий для тренировки моделей комбинацию из малого количества размеченных и большого количества неразмеченных данных.
Аугментация данных
Процесс создания новых тренировочных примеров путем модификации (поворотов, обрезки, изменения цвета) уже существующих объектов.
Псевдоразметка
Подход, при котором сама обучаемая модель генерирует метки классов для неразмеченных данных, которые затем используются в качестве истинных.
Weight decay
Параметр регуляризации в оптимизаторах нейросетей, который штрафует модель за слишком большие веса, предотвращая переобучение.
📊 Цифры
⚖️ Другая сторона
Искусственный интеллект FixMatch Янник Кильхер полуконтролируемое обучение Google Research аугментация данных