import torch
import numpy as np
Tutoriel de différentiation automatique
L’obectif de ce tutoriel est de montrer comment utiliser JAX et PyTorch pour calculer le JVP (Jacobian Vector Product) et le VJP (Vector Jacobian Product) d’une fonction qui n’est disponible dans les primitives fournies par JAX/Torch. Les deux cas d’usage envisagées sont:
- l’utilisation d’une fonction non différentiable pour lesquels on veut écrire une dérivée “non-standard” afin de pouvoir l’utiliser dans JAX/Torch
- l’utilisation d’une fonction donc une approximation analytique de la dérivée est disponible mais qui n’est pas implémentée dans JAX/Torch
Dans ce tutoriel, on considère une fonction jouet \(f\) qui dépend d’une entrée \(x\) et de paramètres \(a, b\).
\[ f: (x, a, b) \in \mathbb{R}^p \times \mathbb{R}^p \times \mathbb{R} \mapsto \tanh(a^\top x + b) \in \mathbb{R} \]
On rappelle que \(tanh'(x) = 1 - \tanh^2(x)\) et que \[ \frac{\partial f}{\partial x} = a.(1 - \tanh^2(a^\top x + b)) \qquad \frac{\partial d}{\partial f} = x.(1 - \tanh^2(a^\top x + b)) \qquad \frac{\partial f}{\partial b} = (1 - \tanh^2(a^\top x + b)) \] ou en mode matriciel \[ \nabla f(x, a, b) = \left(\frac{\partial f}{\partial x}, \frac{\partial f}{\partial a}, \frac{\partial f}{\partial b}\right)^\top \]
On rappelle que la différentiation automatique fait appel à la chain-rule. Pour une fonction \(g\) à valeurs dans \(\mathbb{R}^p \times \mathbb{R}^p \times \mathbb{R}\), et en notant \(h = f \circ g\), on a \[ \begin{align} \nabla h(z) & = \frac{\partial (f \circ g)(z)}{\partial z} = \nabla f(g(z))^\top \nabla g(z) \\ & = \frac{\partial (f \circ g)(z)}{\partial z} = \frac{\partial h(z)}{\partial g(z)} \frac{\partial g(z)}{\partial z} \end{align} \]
On peut calculer \(\nabla h(z)\) de deux façons:
- en mode forward (ou
jvp
): on commence par calculer \(v = \nabla g(z)\) (aussi appelétangents
) et \(g(z)\) et le gradient \(J = \nabla f(g(z))\) avant de calculer le produit scalaire \(J v\). En pratique on écrit une fonctionjvp
\((x, v) \mapsto \nabla f(x) v (\symeq f(x + v) - f(x))\) qui calcule directement le produit scalaire pour éviter d’avoir à matérialiser \(\nabla f\). - en mode reverse (ou
vjp
): on rétro-propage le gradient en calculant \(J = \frac{\partial h(z)}{\partial g(z)} = \nabla f(g(z))\) (⚠️ il faut avoir calculé et stocké \(g(z)\) au préalable) et \(v = \frac{\partial g(z)}{\partial z}\) et on calcule le produit scalaire \(v^\top J\). En pratique, on écrit une fonctionvjp
(oubackward
) \((x, v) \mapsto v^\top \nabla f(x)\) qui calcule directement le produit scalaire pour éviter d’avoir à matérialiser \(\nabla f\).
En pratique il faut écrire jvp
et vjp
pour chaque fonction utilisée dans la composition.
Théorie [XX]
Pour un rapide résumé de ce qu’est l’auto différentiation, et de son intérêt par rapport à d’autres stratégies de calcul numérique ou d’approximation des gradients d’une fonction, voici une vidéo assez complète en 14 min :
Exemple en Torch [MM, LC]
En utilisant les primitives de torch
On definit notre fonction \(f\) en torch.
def f_torch(x, a, b):
return torch.tanh(torch.dot(x, a) + b)
On définit des valeurs pour lesquelles on sait calculer facilement le gradient.
= torch.tensor([2., 3.], requires_grad = True)
x = torch.ones(2, requires_grad = True)
a = torch.tensor(-2., requires_grad = True) b
Et on calcule les dérivées partielles (avec la convention \(\partial f / \partial x =\) x.grad
).
## Définit y par rapport à x
= f_torch(x, a, b)
y
y## Calcule et évalue le graphe de différentiation automatique de y par rapport à x
y.backward()## Renvoie dy/dx
x.grad, a.grad, b.grad
(tensor([0.0099, 0.0099]), tensor([0.0197, 0.0296]), tensor(0.0099))
On peut être plus concis pour calculer notre gradient (ici par rapport à \(x\)) en définissant directement la fonction \((x, a, b) \mapsto \frac{\partial f}{\partial x}(x, a, b)\) dans df_torch_dx
Le paramètre argnums=0
précise qu’on calcule la dérivée par rapport au premier argument de \(f\), en l’occurence \(x\).
= torch.func.grad(f_torch, argnums=0) df_torch_dx
On vérifie que les deux façons de faire donnent le même résultat.
assert torch.allclose(df_torch_dx(x, a, b), x.grad)
En utilisant notre propre fonction
Le code qui suit correspond à l’application des informations disponibles dans la documentation de torch sur notre fonction example. Un autre tutoriel intéressant est le suivant.
On doit définir 4 méthodes: - forward
qui reçoit les entrées et calcule la sortie - setup_context
qui stocke dans un objet ctx
des tenseurs qui peuvent être réutilisés au moment du calcul de la dérivée (dans notre exemple, on a juste besoin de \(x\), \(a\) et \(1 - \tanh^2(a^\top x + b)\). - backward
(ou vjp
) qui reçoit le gradient calculé en aval et renvoie le gradient, pour faire de la différentiation automatique en mode reverse. - jvp
qui reçoit une différentielle calculée en amont et la multiplie en amont avant de la renvoyer, pour faire de la différentiation automatique en mode forward.
Définition de la fonction
class f_torch_manual(torch.autograd.Function):
"""
We can implement our own custom autograd Functions by subclassing
torch.autograd.Function and implementing the forward and backward passes
which operate on Tensors.
"""
@staticmethod
def forward(x, a, b):
"""
In the forward pass we receive a Tensor containing the input and return
a Tensor containing the output.
"""
= torch.tanh(torch.dot(a, x) + b)
output return output
@staticmethod
def setup_context(ctx, inputs, output):
"""
ctx is a context object that can be used
to stash information for backward computation. You can cache tensors for
use in the backward pass using the ``ctx.save_for_backward`` method. Other
objects can be stored directly as attributes on the ctx object, such as
``ctx.my_object = my_object``.
"""
= inputs
x, a, b ## save output to cut computation time
= 1. - output.pow(2) # tanh' = 1 - tanh^2
scaling
ctx.save_for_backward(x, a, scaling)
@staticmethod
def backward(ctx, grad_output):
"""
In the backward pass we receive a Tensor containing the gradient of the loss
with respect to the output, and we need to compute the gradient of the loss
with respect to the input.
It corresponds to a Vector Jacobian Product (vjp), used for reverse auto-differentiation
"""
= ctx.saved_tensors
x, a, scaling = grad_output * a * scaling
grad_x = grad_output * x * scaling
grad_a = grad_output * scaling
grad_b return grad_x, grad_a, grad_b # on doit calculer les grad par rapport à tous les arguments rajouter grad par rapport à a et b
@staticmethod
def jvp(x, a, b, tangents):
"""
It corresponds to a Jacobian Vector Product (jvp), used for forward auto-differentiation
"""
## Vector v of small perturbations
= tangents
tx, ta, tb ## Matrix (in this case vector) of first order gradient
= torch.tanh(torch.dot(a, x) + b)
result = (1. - result.pow(2))
scaling = a * scaling
Jx = x * scaling
Ja = scaling
Jb ## Return J(x, a, b)v
return torch.dot(Jx, tx) + torch.dot(Ja, ta) + Jb * tb
Vérification des dérivées
## Définit f
= f_torch_manual.apply
f = f(x, a, b)
z ## Calcule et évalue le graphe de différentiation automatique de y par rapport à x
## Réinitialise les gradients à zéro avant tout calcul
x.grad.zero_(), a.grad.zero_(), b.grad.zero_()
z.backward()## Renvoie dy/dx
x.grad, a.grad, b.grad
(tensor([0.0099, 0.0099]), tensor([0.0197, 0.0296]), tensor(0.0099))
On vérifie qu’on obtient bien le même résultat qu’en laissant torch
faire le calcul :party:.
On aurait aussi pu utiliser les opérateurs fonctionnels pour calculer la fonction dérivée (en utilisant le mode reverse)
= torch.func.jacrev(func=f, argnums=(0, 1, 2)) f_grad_rev
et vérifier que le résultat coincide avec le calcul fait à la main.
assert all(torch.allclose(f_grad_rev(x, a, b)[i], (x.grad, a.grad, b.grad)[i]) for i in range(2))
On calcule la dérivée par rapport à la première coordonnée de \(x\)
= (torch.tensor([1., 0.]), torch.tensor([0., 0.]), torch.tensor(0.))
tangents = x, a = a, b = b, tangents = tangents) f_torch_manual.jvp(x
tensor(0.0099, grad_fn=<AddBackward0>)
puis par rapport à la deuxième coordonnée de \(a\)
= (torch.tensor([0., 0.]), torch.tensor([0., 1.]), torch.tensor(0.))
tangents = x, a = a, b = b, tangents = tangents) f_torch_manual.jvp(x
tensor(0.0296, grad_fn=<AddBackward0>)
Et on valide que les résultats obtenus coïncident avec ceux obtenus en mode reverse et directement en utilisant torch
😁
En théorie, on pourrait utiliser les opérateurs fonctionnels pour calculer la fonction dérivée (en utilisant le mode forward)
= torch.func.jacfwd(func=f, argnums=(0, 1, 2)) f_grad_fwd
mais il faut définir une méthode statique vmap
et je n’ai pas compris comment faire 😢
Comparaison des temps de calculs
On compare ici les temps de calculs du calcul du gradient en mode forward, reverse pour la version qui utilise les primitives de torch et en mode reverse pour notre version.
= 100
dim = torch.tensor(np.arange(dim)/ (dim*10), requires_grad = True)
a = torch.tensor(0.5, requires_grad = True)
b = torch.tensor( (np.arange(dim) - 37) / (dim*10), requires_grad = True) x
= torch.func.jacrev(func=f_torch, argnums=(0, 1, 2))
f_torch_grad_rev %timeit f_torch_grad_rev(x, a, b)
= torch.func.jacfwd(func=f_torch, argnums=(0, 1, 2))
f_torch_grad_fwd %timeit f_torch_grad_fwd(x, a, b)
%timeit f_grad_rev(x, a, b)
434 μs ± 1.27 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
502 μs ± 1.75 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
788 μs ± 3.72 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Impact de JIT
On essaie de jitter nos fonctions pour vérifier si cela accélère le calcul des gradients.
Plus d’info sur le JIT dans pytorch
sont disponibles dans cette documentation.
= torch.jit.trace(f_torch_grad_rev, (x, a, b))
f_torch_grad_rev_jit = torch.jit.trace(f_torch_grad_fwd, (x, a, b))
f_torch_grad_fwd_jit # f_grad_rev_jit = torch.jit.trace(f_grad_rev, (x, a, b))
%timeit f_torch_grad_rev_jit(x, a, b)
%timeit f_torch_grad_fwd_jit(x, a, b)
# %timeit f_grad_rev_jit(x, a, b)
58.4 μs ± 73.7 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
125 μs ± 139 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
Exemple en JAX [HG, AL]
Pour illustrer la différentiation automatique, nous allons utiliser JAX, sur une function simple dont nous connaissons les gradients analytiques.
Dérivée par rapport à \(x\)
import jax
import jax.numpy as jnp
from jax import grad, jit
from jax import random
## On définit une dimension arbitraire pour nos inputs
= 100
dim
## On initialise les paramètres et le vecteur d'input de la fonction
= jnp.arange(dim)/ (dim*10)
a = 0.5
b = (jnp.arange(dim) - 37) / (dim*10)
x
## On définit une fonction simple dont on connaît les gradients analytiques
def f(x, a, b):
return jnp.tanh(jnp.dot(a, x) + b)
## On affiche la valeur de la fonction pour vérifier que tout est ok
f(x, a, b)
Array(0.56842977, dtype=float32)
Dans un premier temps, on peut définir les gradients exacts de cette fonction à partir d’une formule analytique.
def df_dx(x, a, b):
return a * (1 - jnp.tanh(jnp.dot(a, x) + b) **2)
Puis on définit les gradients via autograd et on vérifie que les résultats sont identiques.
## jax.grad calcule la formule backward par défaut
= jax.grad(lambda x: f(x, a, b), argnums=0)
grad_df_dx ## On peut aussi calculer la formule forward via jax.jacfwd
= jax.jacfwd(lambda x: f(x, a, b), argnums=0)
fwdgrad_df_dx
## On vérifie que les gradients retournent des valeurs identiques
assert jnp.allclose(grad_df_dx(x), df_dx(x, a, b))
assert jnp.allclose(grad_df_dx(x), fwdgrad_df_dx(x))
## Si les assertions ne retournent pas d'erreur, les gradients sont corrects
print('All good, we are ready to go!')
All good, we are ready to go!
On définit également les gradients exacts par rapport aux paramètres \(a\) et \(b\) pour vérifier que l’on pourrait les optimiser dans un algorithme d’apprentissage.
def df_dab(x, a, b):
return x * (1 - jnp.tanh(jnp.dot(a, x) + b) **2), 1 - jnp.tanh(jnp.dot(a, x) + b) **2
Puis on définit ces gradients via autograd et on vérifie que les résultats sont identiques.
## jax.grad calcule la formule backward par défaut
= jax.grad(lambda a_b: f(x, *a_b), argnums=0)
grad_df_dab ## On peut aussi calculer la formule forward via jax.jacfwd
= jax.jacfwd(lambda a_b: f(x, *a_b), argnums=0)
fwdgrad_df_dab
## On vérifie que les gradients retournent des valeurs identiques
assert all(jnp.allclose(grad_df_dab((a,b))[i], df_dab(x, a, b)[i]) for i in range(2))
assert all(jnp.allclose(grad_df_dab((a,b))[i], fwdgrad_df_dab((a,b))[i]) for i in range(2))
## Si les assertions ne retournent pas d'erreur, les gradients sont corrects
print('All good, we are ready to go!')
All good, we are ready to go!
Nous avons donc bien vérifié que les gradients calculés avec JAX sont identiques aux gradients analytiques.
Comparaison des temps de calcul en JAX entre autograd (backward et forward) et le calcul explicite des gradients
Pour mettre en lumière les différences entre l’autograd backward et forward, nous allons définir une nouvelle fonction \(g\) dont les sorties sont de plus grandes dimensions que les entrées.
Comparaison des temps de calcul :
## On lance une quantification du temps de calcul pour les trois méthodes (la commande .block_until_ready() est spécifique à JAX pour forcer l'évaluation des opérations asynchrones)
%timeit df_dx(x, a, b).block_until_ready()
%timeit grad_df_dx(x).block_until_ready()
%timeit fwdgrad_df_dx(x).block_until_ready()
94.7 μs ± 404 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
1.39 ms ± 2.52 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
1.76 ms ± 3.92 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
On définit une nouvelle fonction \(g\) qui retourne un vecteur de dimension \(d\):
## On définit la dimension du vecteur de sortie
= 100
dim_g ## On initialise un point d'évaluation pour la fonction $g$
= 0.5
x_g
## On définit la nouvelle fonction $g$ qui retourne un vecteur de dimension $dim_g$
def g(x):
return jnp.array([jnp.tanh(x + i/dim_g) for i in range(dim_g)])
## On définit le gradient analytique de la fonction $g$
def dg_dx(x):
return jnp.array([1- jnp.tanh(x + i/dim_g)**2 for i in range(dim_g)])
## jax.grad calcule la formule backward par défaut
= jax.jacrev(lambda x: g(x), argnums=0)
grad_dg_dx ## On peut aussi calculer la formule forward via $jax.jacfwd$
= jax.jacfwd(lambda x: g(x), argnums=0) fwdgrad_dg_dx
Comparaison des temps de calcul :
%timeit dg_dx(x_g).block_until_ready()
%timeit grad_dg_dx(x_g).block_until_ready()
%timeit fwdgrad_dg_dx(x_g).block_until_ready()
13.6 ms ± 40.4 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
254 ms ± 1.19 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
106 ms ± 512 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
On onbserve que le calcul explicite des gradients est plus rapide que les deux autres méthodes. Il y a donc un prix à payer pour la différentiation automatique, bien qu’elle soit très utile pour des fonctions complexes. Cependant, nous allons voir qu’en combinaison avec une autre stratégie clef en JAX, le jitting, les choses changent.
JIT compilation avec JAX
La compilation JIT (Just-In-Time) permet d’optimiser les performances des fonctions en les compilant à la volée. JAX fournit la fonction jit
pour cela.
## On redéfinit la fonction $f$ et ses gradients avec JIT
@jit
def f_jit(x, a, b):
return jnp.tanh(jnp.dot(a, x) + b)
@jit
def df_dx_jit(x, a, b):
return a * (1 - jnp.tanh(jnp.dot(a, x) + b) **2)
## On compile les gradients de la fonction $f$ avec JIT
= jax.jit(jax.grad(lambda x: f_jit(x, a, b), argnums=0))
grad_f_jit = jax.jit(jax.jacfwd(lambda x: f_jit(x, a, b), argnums=0))
fwdgrad_f_jit
## On peut maintenant mesurer le temps de calcul des gradients avec JIT
%timeit df_dx_jit(x, a, b).block_until_ready()
%timeit grad_f_jit(x).block_until_ready()
%timeit fwdgrad_f_jit(x).block_until_ready()
8.4 μs ± 11.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
5.79 μs ± 13.3 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
9.78 μs ± 73.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
## On redéfinit la dimension du vecteur de sortie pour la fonction $g$ (les gradients sont sensibles à la dimension du vecteur de sortie, même avec JIT, notamment lorsque l'on compare l'autograd backward et forward)
= 10
dim_g
## On redéfinit la fonction $f$ et ses gradients avec JIT
@jit
def g_jit(x):
return jnp.array([jnp.tanh(x + i/dim_g) for i in range(dim_g)])
@jit
def dg_dx_jit(x):
return jnp.array([1- jnp.tanh(x + i/dim_g)**2 for i in range(dim_g)])
## Pour que le calcul soit rapide aussi, il faut ajouter un 'vmap' avant le 'jit' pour que la fonction soit vectorisée et puisse tirer parti de la compilation JIT.
= jax.jit(jax.vmap(jax.jacrev(lambda x: g_jit(x), argnums=0)))
grad_g_jit = jax.jit(jax.vmap(jax.jacfwd(lambda x: g_jit(x), argnums=0)))
fwdgrad_g_jit
## On peut maintenant mesurer le temps de calcul des gradients avec JIT
%timeit dg_dx_jit(x).block_until_ready()
%timeit grad_g_jit(x).block_until_ready()
%timeit fwdgrad_g_jit(x).block_until_ready()
16.9 μs ± 909 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
24.9 μs ± 1.09 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
23.4 μs ± 2.8 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
On peut constater dans les deux cas que la compilation JIT permet d’accélérer considérablement le calcul des gradients, et rendre l’autograd négligeable d’un point de vue computationel. Dans le cas de la fonction \(g\) qui retourne un vecteur, la vectorisation permet aussi d’accélérer le calcul. Ainsi la combinaison de l’autograd, du jitting et de la vectorisation permet d’obtenir des performances optimales et rendre la différentiation automatique aussi efficace que l’analytique, même pour des fonctions complexes.
Avec VJP et JVP
Dans les sections précédentes, jax.jacrev
et jax.jacfwd
utilisent, respectivement, les VJP et les JVP des opérations élémentaires de la fonction f
. Dans certains cas, nous pouvons avoir besoin de définir nous-mêmes les VJP et JVP comme vu en introduction.
Nous allons alors définir les VJP et JVP pour f
à l’aide de gradients que nous connaissons analytiquement. Les VJP et JVP des fonctions élémentaires sous-jacentes ne seront alors plus utilisés.
Lien vers la documentation de JAX
Un JVP est capable de dévoiler une colonne de la jacobienne à la fois. Ce n’est pas adapté pour cette fonction dont la jacobienne est large. Une passe JVP ne peut dévoiler qu’une seule dérivée partielle : si l’on veut la dérivée par rapport à chaque dimension de \(x\), chaque dimension de \(a\) et \(b\), il nous faut faire \(dim + dim + 1\) fois des JVPs ce qui n’est pas du tout efficace. Nous l’avons vu dans la section précédente où en fait, jax.jacfwd
doit en fait appeler tous ces JVPs (ce qui est fait de manière cachée à l’utilisateur).
Pour définir un custom_jvp
en JAX, il faut attacher à f
, une fonction f_jvp
, qui prend deux entrées primals
le point où l’on calcule le gradient et tangents
le vecteur tangent (à voir aussi comme les gradients en amont du graphe que l’on parcourt en descendant). f_jvp
retourne un tuple de deux vecteurs, f(primals)
et le JVP df_dx @ tangents
, où, bien sûr, df_dx
contient l’expression analytique de la dérivée (c’est une matrice jacobienne mais elle n’est jamais stockée en mémoire car tout de suite réduite par le produit matriciel).
Nous avons vu que si tangents
est un vecteur one-hot encoded nous dévoilons une colonne de la matrice jacobienne (celle où se situe le \(1\)). Dans l’exemple ci-dessous nous calculons de manière forward \(\frac{\partial f}{\partial x_0}\).
@jax.custom_jvp
def f(x, a, b):
return jnp.tanh(jnp.dot(a, x) + b)
@f.defjvp
def f_jvp(primals, tangents):
= primals
x, a, b = tangents
x_dot, a_dot, b_dot = f(x, a, b)
primal_out return primal_out, (jnp.dot(df_dx(x, a, b), x_dot) + jnp.dot((x * (1 - primal_out ** 2)), a_dot) + jnp.dot((1 - primal_out ** 2), b_dot))
= jnp.zeros(dim)
x0_tangents = x0_tangents.at[0].set(1)
x0_tangents = jax.jvp(f, (x, a, b), (x0_tangents, jnp.zeros(dim), 0.))
_, x_dot
assert jnp.allclose(grad_df_dx(x)[0], x_dot)
print("All right!")
All right!
On note que, s’il est défini, le custom_jvp
sera utilisé par JAX, en mode forward et en mode backward. Notons aussi la syntaxe particulière émanant du fait que f
prend trois arguments en entrée.
Lien vers la documentation de JAX
Un VJP est capable de dévoiler une ligne de la jacobienne à la fois. Cela va donc nous permettre de calculer toute la matrice jacobienne de \(f\) en un seul appel à JVP car c’est une matrice à une seule ligne. Nous l’avons vu dans la section précédente où en fait, jax.jacrec
doit en fait appeler tous ces VJPs (ce qui est fait de manière cachée à l’utilisateur).
Si nous souhaitons explicitement définir le VJP, nous devons d’abord écrire une fonction qui décrit la passe forward. C’est ici f_fwd
qui retourne f(primal)
et des valeurs stockées pour le moment de la passe backward (à la manière de save_for_backward
vu dans la section pytorch
!). Il faut ici bien réfléchir à ce qui est nécessaire de stocker et ce qui est superflu, afin d’optimiser au mieux le code. Ici nous stockons f(x,a,b)
, x
et a
car ces valeurs sont réutilisées dans la passe backward où nous calculons \(g. \frac{\partial \mathrm{tanh}(f(x,a,b))}{\partial x}\), \(g.\frac{\partial \mathrm{tanh}(f(x,a,b))}{\partial a}\) et \(g.\frac{\partial \mathrm{tanh}(f(x,a,b))}{\partial b}\). Avec \(g\) le gradient provenant de l’aval du graphe pour calculer les VJPs (rappelons que nous les calculons de manière backward en remontant le graphe).
Nous comprenons à nouveau que si g
est un vecteur one-hot encoded nous dévoilons une ligne de la matrice jacobienne (celle où se situe le \(1\)).
Nous devons également écrire une fonction f_bwd
qui prend en argument les valeurs stockées dans la passe forward ainsi que g
défini dans le paragraphe précédent. Ici g
est scalaire f
a valeurs dans \(\mathbb{R}\). f_bwd
retourne autant de sorties que f
compte d’entrées.
@jax.custom_vjp
def f(x, a, b):
return jnp.tanh(jnp.dot(a, x) + b)
def f_fwd(x, a, b):
= f(x, a, b)
primal_out return primal_out, (x, a, primal_out)
def f_bwd(res, g):
= res
x, a, primal_out return (a * (1 - primal_out **2) * g, x * (1 - primal_out **2) * g, (1 - primal_out **2) * g)
f.defvjp(f_fwd, f_bwd)= jax.vjp(f, x, a, b) # renvoie f(primal) et f_vjp qui est une fonction qui doit être évaluée en `g`
_, f_vjp
assert jnp.allclose(grad_df_dx(x), f_vjp(1.)[0])
assert all(jnp.allclose(grad_df_dab((a,b))[i], f_vjp(1.)[i + 1]) for i in range(2))
print("All right!")
All right!
Notons que la définition d’un custom_vjp
redéfinit la fonction grad
qui utilise donc aussi f_fwd
. Ainsi, nous avons l’équivalence :
assert jnp.allclose(f_vjp(1.)[0], jax.grad(f)(x, a, b))
print("All right!")
All right!