# Эффективное неявное дифференцирование: как JAX и Google Research меняют правила игры в ML

Источник: https://www.youtube.com/watch?v=8Oy7o3Yu-Xo
Канал: Yannic Kilcher
Опубликовано: 11.06.2021

---

Исследователи из Google Research представили новый подход к автоматическому дифференцированию, который может избавить разработчиков от необходимости вручную прописывать сложные математические выводы для оптимизационных задач. Янник Кильчер (Yannic Kilcher) разобрал статью, посвященную эффективному и модульному неявному дифференцированию — методу, который позволяет «пробрасывать» градиенты через внутренние циклы оптимизации, не разворачивая их и не требуя, чтобы сам алгоритм решения задачи был дифференцируемым.

## 🧠 Новая эра автоматического дифференцирования
[[JUMP:0:00]]

Автоматическое дифференцирование (autodiff) совершило революцию в машинном обучении [2:13]. По словам Янника Кильчера, в старых работах по глубокому обучению добрая половина статьи могла быть посвящена выводу градиентов предложенной архитектуры, чтобы её вообще можно было реализовать [2:26]. Современные фреймворки, такие как TensorFlow, PyTorch и JAX, сняли это бремя: теперь достаточно просто скомпоновать функции, а система сама вычислит производные [2:39].

Однако до сих пор существовала проблема с многоуровневыми задачами оптимизации, такими как:

*   Подбор гиперпараметров (hyperparameter optimization).
*   Мета-обучение (metalearning).
*   Задачи «оптимизации как слоя» (optimization as a layer).

В этих сценариях необходимо дифференцировать результат работы внутреннего алгоритма оптимизации. Как утверждает Янник Кильчер, предложенный Google Research фреймворк расширяет возможности классического autodiff на этот огромный класс приложений [3:20]. Главное преимущество здесь — отсутствие необходимости «разворачивать» (unroll) внутренний цикл оптимизации [0:25].

## 🔄 Проблема «развертывания» градиентов
[[JUMP:7:37]]

Традиционно, чтобы получить градиент функции, внутри которой сидит другой процесс обучения, фреймворкам приходится отслеживать каждый шаг этого внутреннего процесса. Например, в градиентном спуске каждый шаг обновления весов $w_{t+1} = w_t - \eta \nabla f(w_t)$ должен быть записан в граф вычислений [8:05].

Янник Кильчер выделяет два критических недостатка такого подхода:

1.  **Вычислительная сложность:** Если нейросеть обучается тысячи или миллионы шагов, «развернутое» выражение становится гигантским и крайне медленным [9:16].
2.  **Ограничение реализации:** Сам алгоритм оптимизации (например, решатель линейных систем или специфический оптимизатор в TensorFlow/PyTorch) должен быть написан на дифференцируемом языке фреймворка, что часто не соблюдается по соображениям производительности [9:30].

В качестве примера Янник Кильчер приводит мета-обучение (iMAML) [4:30]. Цель здесь — найти такую инициализацию весов, которая позволит нейросети максимально быстро адаптироваться к любой новой задаче [5:34]. Чтобы найти градиент для этой инициализации, нужно «пройти» через весь процесс обучения на конкретных задачах. Если делать это через развертывание, память быстро закончится [12:02].

## 🛠️ Решение от Google: Неявное дифференцирование
[[JUMP:13:06]]

Вместо того чтобы следить за каждым шагом оптимизатора, исследователи предложили использовать теорему о неявной функции [21:20]. По мнению автора обзора, это превращает сложный математический вывод в модульную инженерную задачу [13:18].

Ключевые этапы работы с новым фреймворком:

*   **Определение решателя (Solver):** Пользователь предоставляет функцию, которая находит решение внутренней задачи (например, `ridge_solver` для гребневой регрессии) [16:27].
*   **Условие оптимальности (Optimality Condition):** Пользователь определяет функцию $f$, которая равна нулю, когда решение оптимально [13:58]. Для задач минимизации потерь такой функцией будет градиент функции потерь (он равен нулю в точке минимума) [19:01].
*   **Декоратор:** С помощью специальной аннотации в коде (например, в JAX) решатель связывается с условием оптимальности [19:29].

Янник Кильчер подчеркивает: теперь не сам оптимизатор должен быть дифференцируемым, а только спецификация условий оптимальности [14:12]. Это колоссальный выигрыш, так как условия оптимальности (например, градиент функции потерь) обычно просты и легко дифференцируемы дважды [24:42].

## 📊 Математика «под капотом»
[[JUMP:21:20]]

В основе метода лежит идея, что если у нас есть корень функции (точка, где $f(x, \theta) = 0$), то градиент этого корня по параметру $\theta$ можно вычислить локально, не зная, как именно мы этот корень искали [22:15].

Процесс вычисления сводится к решению линейной системы вида $Ax = B$ [25:09]:

1.  Матрицы $A$ и $B$ получаются с помощью автоматического дифференцирования функции условий оптимальности [25:21].
2.  Используется стандартный линейный решатель.
3.  Результат и есть искомый градиент через всю внутреннюю процедуру [25:09].

Фреймворк поддерживает два типа условий:

1.  **Custom Root:** Поиск корня функции [26:02].
2.  **Custom Fixed Point:** Когда оптимальное решение является неподвижной точкой функции (например, в проксимальных методах) [26:02].

## 🧪 Практические примеры и применение
[[JUMP:28:41]]

Янник Кильчер перечисляет несколько сценариев, где этот модульный подход уже показал свою эффективность:

### 1. Подбор гиперпараметров в SVM
[[JUMP:28:54]]
В многоклассовых опорных векторах (SVM) есть параметры регуляризации, которые сложно настраивать градиентным спуском, так как внутренняя задача ограничена вероятностным симплексом. Новый метод позволяет эффективно находить градиент гиперпараметра через это сложное решение [29:25].

### 2. Дистилляция датасета (Dataset Distillation)
[[JUMP:30:59]]
Это амбициозная задача: найти, например, 10 идеальных изображений (по одному на класс), обучаясь на которых нейросеть покажет лучший результат на полном тестовом наборе [31:14]. Это двухуровневая оптимизация: мы обновляем пиксели этих 10 картинок, проходя через весь процесс обучения классификатора [31:39].

### 3. Сложные задачи физики и биологии
[[JUMP:30:08]]

*   **Dictionary Learning:** Оптимизация словаря признаков одновременно с функциями отображения [30:08].
*   **Молекулярная динамика:** Расчет того, как изменение размеров молекул влияет на состояние системы в равновесии [32:07].
*   **Состязательные примеры (Adversarial Examples):** Возможность обратного распространения ошибки через процедуру проекции градиента при поиске уязвимостей в моделях [28:15].

## 💻 Интеграция с JAX
[[JUMP:27:07]]

Реализация от Google плотно интегрирована в библиотеку JAX. Она позволяет переопределять стандартное поведение autodiff [27:21]. Вместо того чтобы JAX «прозрачно» дифференцировал итерации решателя, он использует аналитический неявный градиент, что значительно быстрее и точнее [27:35]. Янник Кильчер призывает всех, кто сталкивается с вложенными оптимизациями, попробовать этот инструмент, так как он «разблокирует» пласт исследований, который раньше был слишком трудоемким [1:19].