Аналітичне обчислення похідних на шаблонах C++

Тут днями писали про аналітичне знаходження похідних, що нагадало мені про моєї маленької бібліотечці на C++, яка робить майже те ж, але під час компіляції.



У чому профіт? Відповідь проста: мені потрібно було запрогать знаходження мінімуму досить складної функції, вважати похідні цієї функції за її параметрами ручкою на папірці було лінь, перевіряти потім, що я не опечатані при написанні коду, і підтримувати цей самий код — лінь подвійно, тому було вирішено написати штуковину, яка це зробить за мене. Ну, щоб в коді можна було написати щось таке:
using Formula_t = decltype (k * (_1 - r0) / (_1 + r0) * (g0 / (alpha0 - logr0 / Num<300>) - _1)); // формула
const auto residual = Formula_t::Eval (datapoint) - knownValue; // регресійний залишок

// похідні за параметрами:
const auto dg0 = VarDerivative_t<Formula_t, decltype (g0)>::Eval (datapoint);
const auto dalpha0 = VarDerivative_t<Formula_t, decltype (alpha0)>::Eval (datapoint);
const auto dk = VarDerivative_t<Formula_t, decltype (k)>::Eval (datapoint);

замість крокодилів, які вийдуть, якщо брати приватні похідні функції на картинці спочатку (вірніше, деякого її спрощеного варіанта, але він виглядає не так страшно).

Ще непогано бути достатньо впевненим, що компілятор це соптімізірует так, як якщо б відповідні похідні і функції були написані руками. А бути впевненим б хотілося — знаходити мінімум потрібно було дуже багато разів (дійсно багато, десь від сотні мільйонів до мільярда, в цьому була суть якогось обчислювального експерименту), тому обчислення похідних було б пляшковим горлечком, відбувається воно під час виконання через яку-небудь рекурсію по древообразной структурі. Якщо ж змусити компілятор обчислювати похідну, власне, під час компіляції, то є шанс, що він отриманого коду ще пройдеться оптимізатором, і ми не втратимо порівняно з ручним випискою всіх похідних. Шанс реалізувався, до речі.

Під катом — невеликий опис, як воно там все працює.



Почнемо з подання функції у програмі. Чомусь так вийшло, що кожна функція — це тип. Функція — це ще й дерево виразів, і вузол дерева подається типом
Node
:
template < typename NodeClass, typename... Args>
struct Node;


Тут
NodeClass
  тип вузла (мінлива, число, унарная функція, бінарна функція),
Args
  параметри цього вузла (індекс змінної, значення числа, дочірні вузли).

Вузли уміють диференціювати себе, друкувати і обчислювати для даних значень вільних змінних і параметрів. Так, якщо визначений тип для представлення вузла з звичайним числом:
using NumberType_t = long long;

template<NumberType_t N>
struct Number {};

то спеціалізація вузла для чисел тривіальна:
template<NumberType_t N>
struct Node<Number<N>>
{
template<char FPrime, int IPrime>
using Derivative_t = Node<Number<0>>;

static std::string Print ()
{
return std::to_string (N);
}

template < typename Vec>
static typename Vec::value_type Eval (const Vec&)
{
return N;
}

constexpr Node () {}
};

Похідна будь-якого числа з будь-якої змінної — нуль (за це відповідає тип
Derivative_t
, залишимо поки його шаблонні параметри). Роздрукувати число — теж просто (див.
Print()
). Обчислити вузол з числом — повернути це число (див.
Eval()
, шаблонний параметр
Vec
обговоримо пізніше).

Змінна видається схожим чином:
template<char Family, int Index>
struct Variable {};

Тут
Family
та
Index
— «сімейство» та індекс змінної. Так, для вони будуть дорівнювати
'w'
та
1
, а  
'x'
та
2
відповідно.

Вузол для змінної визначається трохи цікавіше, ніж для числа:
template<char Family, int Index>
struct Node<Variable<Family, Index>>
{
template<char FPrime, int IPrime>
using Derivative_t = std::conditional_t<FPrime == Family && IPrime == Index,
Node<Number<1>>,
Node<Number<0>>>;

static std::string Print ()
{
return std::string { Family, '_' } + std::to_string (Index);
}

template < typename Vec>
static typename Vec::value_type Eval (const Vec& values)
{
return values (Node {});
}

constexpr Node () {}
};

Так, похідна змінної по їй самій дорівнює одиниці, а з будь-якої іншої — нулю. Власне, параметри
FPrime
та
IPrime
для типу
Derivative_t
— це сімейство і індекс змінної, по якій потрібно взяти похідну.

Обчислення значення функції, що складається з однієї змінної зводиться до її знаходження в словнику значень
values
, який передається у функцію
Eval()
. Словник сам вміє знаходити значення потрібної змінної за її типом, тому йому ми просто передамо тип нашої змінної і повернемо відповідне значення. Як словник це робить, ми розглянемо пізніше.

З унарными функціями все стає ще цікавіше.
enum class UnaryFunction
{
Sin,
Cos,
Ln,
Neg
};

template<UnaryFunction UF>
struct UnaryFunctionWrapper;


В спеціалізації
UnaryFunctionWrapper
ми віддамо логіку щодо взяття похідних кожної конкретної унарной функції. Щоб мінімально дублювати код, будемо брати похідну унарной функції за її аргументу, за подальше диференціювання аргументу по цільової змінної через chain rule буде відповідати викликає код:
template<>
struct UnaryFunctionWrapper<UnaryFunction::Sin>
{
template < typename Child>
using Derivative_t = Node<Cos, Child>;
};

template<>
struct UnaryFunctionWrapper<UnaryFunction::Cos>
{
template < typename Child>
using Derivative_t = Node<Neg, Node<Sin, Child>>;
};

template<>
struct UnaryFunctionWrapper<UnaryFunction::Ln>
{
template < typename Child>
using Derivative_t = Node<Div, Node<Number<1>>, Child>;
};

template<>
struct UnaryFunctionWrapper<UnaryFunction::Neg>
{
template < typename>
using Derivative_t = Node<Number<-1>>;
};


Тоді сам вузол виглядає наступним чином:
template<UnaryFunction UF, typename... ChildArgs>
struct Node<UnaryFunctionWrapper<UF> Node<ChildArgs...>>
{
using Child_t = Node<ChildArgs...>;

template<char FPrime, int IPrime>
using Derivative_t = Node<Mul,
typename UnaryFunctionWrapper<UF>::template Derivative_t<Child_t>,
typename Node<ChildArgs...>::template Derivative_t<FPrime, IPrime>>;

static std::string Print ()
{
return FunctionName (UF) + "(" + Node<ChildArgs...>::Print () + ")";
}

template < typename Vec>
static typename Vec::value_type Eval (const Vec& values)
{
const auto child = Child_t::Eval (values);
return EvalUnary (UnaryFunctionWrapper<UF> {}, child);
}
};

Вважаємо похідну через chain rule — виглядає страшно, ідея проста. Обчислюємо теж просто: вважаємо значення дочірнього вузла, потім обчислюємо значення нашої унарной функції на цьому значенні за допомогою функції
EvalUnary()
. Вірніше, сімейства функцій: першим аргументом функції йде тип, визначає нашу унарную функцію, щоб гарантувати вибір потрібної перевантаження під час компіляції. Так, можна було б передавати саме значення
UF
, і розумний компілятор майже напевно зробив би всі потрібні constant propagation passes, але тут простіше перестрахуватися.

До речі, окрему унарную операцію заперечення можна було б і не писати, замінивши її на множення на мінус одиницю.

З бінарними вузлами все аналогічно, тільки похідні виглядають зовсім страшно. Для розподілу, наприклад:
template<>
struct BinaryFunctionWrapper<BinaryFunction::Div>
{
template<char Family, int Index, typename U, typename V>
using Derivative_t = Node<Div,
Node<Add,
Node<Mul,
typename U::template Derivative_t<Family, Index>,
V
>,
Node<Neg,
Node<Mul,
U,
typename V::template Derivative_t<Family, Index>
>
>
>,
Node<Mul,
V,
V
>
>;
};


Тоді шукана метафункция
VarDerivative_t
визначається досить просто, бо за фактом лише викликає
Derivative_t
в переданого їй сайту:
template < typename Node, typename Var>
struct VarDerivative;

template < typename Expr, char Family, int Index>
struct VarDerivative<Expr, Node<Variable<Family, Index>>>
{
using Result_t = typename Expr::template Derivative_t<Family, Index>;
};

template < typename Node, typename Var>
using VarDerivative_t = typename VarDerivative<Node, std::decay_t<Var>>::Result_t;


Якщо тепер визначити допоміжні змінні і типи, наприклад:
// аліаси для типів унарных і бінарних функцій:
using Sin = UnaryFunctionWrapper<UnaryFunction::Sin>;
using Cos = UnaryFunctionWrapper<UnaryFunction::Cos>;
using Neg = UnaryFunctionWrapper<UnaryFunction::Neg>;
using Ln = UnaryFunctionWrapper<UnaryFunction::Ln>;

using Add = BinaryFunctionWrapper<BinaryFunction::Add>;
using Mul = BinaryFunctionWrapper<BinaryFunction::Mul>;
using Div = BinaryFunctionWrapper<BinaryFunction::Div>;
using Pow = BinaryFunctionWrapper<BinaryFunction::Pow>;

// variable template з C++14 для визначення змінної у загальному вигляді:
template<char Family, int Index = 0>
constexpr Node<Variable<Family, Index>> Var {};

// визначимо змінну x0 для зручності, либонь, їй часто користуватися будуть:
using X0 = Node<Variable<'x', 0>>;
constexpr X0 x0;
// і так далі для інших змінних

// константа для одиниці, одиниця часто зустрічається в формулах:
constexpr Node<Number<1>> _1;

// перевантаження операторів, їм навіть не потрібно тіло, достатньо типу:
template < typename T1, typename T2>
Node<Add, std::decay_t<T1> std::decay_t<T2>> operator+ (T1, T2);

template < typename T1, typename T2>
Node<Mul, std::decay_t<T1> std::decay_t<T2>> operator* (T1, T2);

template < typename T1, typename T2>
Node<Div, std::decay_t<T1> std::decay_t<T2>> operator/ (T1, T2);

template < typename T1, typename T2>
Node<Add, std::decay_t<T1> Node<Neg, std::decay_t<T2>>> operator- (T1, T2);

// не зовсім оператори, але теж щоб зручно було писати:
template < typename T>
Node<Sin, std::decay_t<T>> Sin (T);

template < typename T>
Node<Cos, std::decay_t<T>> Cos (T);

template < typename T>
Node<Ln, std::decay_t<T>> Ln (T);

то можна буде писати код прямо як в самому початку посту.

Що залишилося?

По-перше, розібратися з тим типом, який передається у функцію
Eval()
. По-друге, згадати про можливість перетворень шуканого вирази з заміною одного піддерева на інше. Почнемо з другого, воно простіше.

Мотивація (можна пропустити): якщо трохи попрофилировать код, який вийде з поточної версії, то в очі кинеться, що досить багато часу йде на обчислення , який, взагалі кажучи, один і той же для кожної експериментальної точки. Не біда! Введемо окрему змінну, яку вважатимемо за один раз перед розрахунком значень нашої формули на кожній з експериментальних точок, і замінити всі входження на цю змінну (власне, в мотиваційному коді на самому початку це вже зроблено). Однак, коли ми будемо брати похідну за , нам доведеться згадати, що , взагалі кажучи, не вільний параметр, а функція від . Згадати дуже просто: замінимо (для цього використовується метафункция
ApplyDependency_t
, хоча правильніше було б її назвати
Rewrite_t
або типу того), продиференціюємо, повернемо назад:
using Unwrapped_t = ApplyDependency_t<decltype (logr0), decltype (Ln (r0)), Formula_t>;
using Derivative_t = VarDerivative_t<Unwrapped_t, decltype (r0)>;
using CacheLog_t = ApplyDependency_t<decltype (Ln (r0)), decltype (logr0), Derivative_t>;


Реалізація багатослівна, але ідейно проста. Рекурсивно спускаємося по дереву формули, підміняючи елемент дерева, якщо він в точності збігається з шаблоном, інакше нічого не міняємо. Разом три спеціалізації: для спуску по підсайті унарной функції, для спуску по дочірніх сайтів бінарної функції, і власне для заміни, при цьому спеціалізації для спуску по дочірнім вузлів повинні перевіряти, що шаблон не збігається з поддеревом, відповідним розглянутої підфункції:
template < typename Var, typename Expr, typename Formula, typename Enable = void>
struct ApplyDependency
{
using Result_t = Formula;
};

template < typename Var, typename Expr, typename Formula>
using ApplyDependency_t = typename ApplyDependency<std::decay_t<Var> std::decay_t<Expr>, Formula>::Result_t;

template < typename Var, typename Expr, UnaryFunction UF, typename Child>
struct ApplyDependency<Var, Expr, Node<UnaryFunctionWrapper<UF>, Child>,
std::enable_if_t<!std::is_same<Var, Node<UnaryFunctionWrapper<UF>, Child>>::value>>
{
using Result_t = Node<
UnaryFunctionWrapper<UF>,
ApplyDependency_t<Var, Expr, Child>
>;
};

template < typename Var, typename Expr, BinaryFunction BF, typename FirstNode, typename SecondNode>
struct ApplyDependency<Var, Expr, Node<BinaryFunctionWrapper<BF>, FirstNode, SecondNode>,
std::enable_if_t<!std::is_same<Var, Node<BinaryFunctionWrapper<BF>, FirstNode, SecondNode>>::value>>
{
using Result_t = Node<
BinaryFunctionWrapper<BF>,
ApplyDependency_t<Var, Expr, FirstNode>,
ApplyDependency_t<Var, Expr, SecondNode>
>;
};

template < typename Var, typename Expr>
struct ApplyDependency<Var, Expr, Var>
{
using Result_t = Expr;
};


Ффух. Залишилося розібратися з передачею значень параметрів.

Згадаймо, що кожен параметр має свій власний тип, тому якщо ми побудуємо сімейство функцій, перевантажених за типом параметрів, кожна з яких повертає відповідне значення, то знову (прямо як з обчисленням унарных функцій трохи раніше) є шанс, що компілятор це справа скрутить і соптімізірует (а він, до речі, і соптімізірует, розумник такий). Ну, щось на зразок
auto GetValue (Variable<'x', 0>)
{
return value_for_x0;
}

auto GetValue (Variable<'x', 1>)
{
return value_for_x1;
}

...


Тільки ми хочемо зробити це красиво, щоб можна було написати, наприклад:
BuildFunctor (g0, someValue,
alpha0, anotherValue,
k, yetOneMoreValue,
r0, independentVariable,
logr0, logOfTheIndependentVariable);

де
g0
,
alpha0
і компанія — об'єкти, які мають типи відповідних змінних, а слідом за ними йдуть відповідні значення.

Як ми можемо схрестити вужа і їжака, зробивши в загальному вигляді функцію, тип значення якої задається в компил-таймі, а значення — в рантайме? Лямбды поспішають на допомогу!
template < typename ValueType, typename NodeType>
auto BuildFunctor (NodeType, ValueType val)
{
return [val] (NodeType) { return val; };
}


Нехай у нас є дві такі функції, як ми можемо отримати сімейство функцій в одному просторі імен, щоб потрібна вибиралася перевантаженням? Спадкування поспішає на допомогу!
template < typename F, typename S>
struct Map : F, S
{
using F::operator();
using S::operator();

Map (F f, S)
: F { std::forward<F> (f) }
, S { std::forward<S> (s) }
{
}
};

Ми наследуемся від обох лямбд (адже лямбда розгортається в структуру зі згенерованим компілятором ім'ям, а значить, від неї можна успадковуватись) і приносимо в скоуп їх оператори-круглі-дужки.

Більш того, можна успадковуватись не тільки від лямбд, але й від довільних структур, що мають які-небудь оператори-круглі-дужки. Опа, отримали алгебру. Таким чином, якщо є N лямбд, можна отнаследовать першу
Map
від перших двох лямбд, наступну
Map
— від першої
Map
та наступної лямбды, і так далі. Оформимо це у вигляді коду:
template < typename F>
auto Augment (F&& f)
{
return f;
}

template < typename F, typename S>
auto Augment (F&& f, S&& s)
{
return Map<std::decay_t<F> std::decay_t<S>> { f, s };
}

template < typename ValueType>
auto BuildFunctor ()
{
struct
{
ValueType operator() () const
{
return {};
}

using value_type = ValueType;
} dummy;
return dummy;
}

template < typename ValueType, typename NodeType, typename... Tail>
auto BuildFunctor (NodeType, ValueType val, Tail&&... tail)
{
return detail::Augment ([val] (NodeType) { return val; },
BuildFunctor<ValueType> (std::forward<Tail> (tail)...));
}


Автоматом отримуємо повноту та єдиність: якщо якісь аргументи не будуть задані, це буде помилкою компіляції, так само як і якщо якісь аргументи будуть задані двічі.

Власне, все.

Хіба що, один мій приятель, з яким я показував у свій час, запропонував, на мій погляд, більш елегантне рішення на constexpr-функції, але у мене до нього вже 9 місяців не доходять руки.

Ну і лінк на бібліотеку: I Am Mad. До продакшену не готове, пуллреквесты приймаються, і все таке.

Ну і ще можна поудивляться, наскільки розумні сучасні компілятори, які можуть продертися крізь ось ці всі верстви шаблонів поверх шаблонів поверх лямбд поверх шаблонів і згенерувати досить оптимальний код.
Джерело: Хабрахабр

0 коментарів

Тільки зареєстровані та авторизовані користувачі можуть залишати коментарі.