Tutoriel de différentiation automatique - HappyR 9/01/2026

Auteur·rice·s
Affiliations

Mahendra Mariadassou

INRAE - MaIAGE

Hugo Gangloff

INRAE - MIA Paris Saclay

Arthur Leroy

INRAE - GABI & MIA Paris Saclay

Lucia Clarotto

AgroParisTech - MIA Paris Saclay

Date de publication

16 janvier 2026

Modifié

19 janvier 2026

Note

L’objectif de ce tutoriel est triple :

  • présenter les bases mathématiques du calcul de la Jacobienne d’une fonction arbitrairement complexe par différentiation automatique (modes forward et backward),
  • mettre ces concepts en pratique en Python,
  • comparer les performances computationnelles des différentes approches.
Installation package [en amont]

Le seul package nécessaire au tutoriel est JAX. Pour l’installation:

  • pip install jax

Théorie avec illustration sur un exemple simple [45 minutes]

Introduction à la Dérivation Automatique avec JAX

Ce tutoriel introduit les concepts de Jacobian-Vector Product (JVP) et Vector-Jacobian Product (VJP) à travers un exemple concret : une couche intermédiaire d’un réseau de neurones. L’objectif est de poser les notations mathématiques, puis de montrer comment JAX implémente ces opérations.

Contexte et Notations

On considère une couche intermédiaire d’un réseau de neurones :

\[ f: \mathbb{R}^4 \to \mathbb{R}^2, \quad \mathbf{x} \mapsto \tanh(W\mathbf{x} + \mathbf{b}) \]

où :

  • \(\mathbf{x} = (x_0, x_1, x_2, x_3) \in \mathbb{R}^4\) (entrée),
  • \(W \in \mathbb{R}^{2 \times 4} = \begin{pmatrix} \mathbf{w}_0 \\ \mathbf{w}_1 \end{pmatrix}\) (matrice des poids),
  • \(\mathbf{b} = (b_0, b_1)\in \mathbb{R}^2\) (biais),
  • \(\tanh\) est appliquée terme à terme.

Jacobienne de \(f\)

La jacobienne de \(f\) au point \(\mathbf{x}\) notée \(J_f(\mathbf{x})\) (ou parfois \(\partial_\mathbf{x} f\)) est une application de linéaire de \(\mathbb{R}^4\) dans \(\mathbb{R}^2\) de matrice: \[ J_f(\mathbf{x}) = \frac{\partial f}{\partial \mathbf{x}} = \begin{pmatrix} \frac{\partial f_i(\mathbf{x})}{\partial x_j} \end{pmatrix}_{i=1\dots2, j = 1 \dots4} = \begin{bmatrix} 1 - \tanh^2(\mathbf{w}_0^\top \mathbf{x} + b_0) & 0 \\ 0 & 1 - \tanh^2(\mathbf{w}_1^\top \mathbf{x} + b_1) \end{bmatrix} \cdot W \in \mathbb{R}^{2 \times 4} \]

Dérivation Forward (JVP)

Principe

Le JVP (pour Jacobian Vector Product) calcule la dérivée directionnelle de \(f\) dans la direction d’un vecteur \(\mathbf{v} \in \mathbb{R}^4\) : \[ J_f(\mathbf{x}) \cdot \mathbf{v} \in \mathbb{R}^{2} \] \(J_f(\mathbf{x}) \cdot \mathbf{v}\) est le vecteur tangent de \(f\) au point \(\mathbf{x}\) dans la direction \(\mathbf{v}\).

Si on ne considère pas un point \(\mathbf{x}\) en particulier, on peut voir \(J_f\) comme une application \(J_f: \mathbb{R}^4 \to (\mathbb{R}^4 \to \mathbb{R}^2)\). En particulier, si on se donne un point d’intérêt \(\mathbf{x}\) et une direction d’intérêt \(\mathbf{v}\), on peut définir une application \(\texttt{jvp}_f\) de \(\mathbb{R}^4 \times \mathbb{R}^4 \to \mathbb{R}^2\) via:

\[ \texttt{jvp}_f(\mathbf{x}, \mathbf{v}) = J_f(\mathbf{x}). \mathbf{v} \in \mathbb{R}^{2} \]

Dans le langage de l’autodiff:

  • \(\mathbf{x}\) est le vecteur primal (point d’évaluation)
  • \(\mathbf{v}\) est le vecteur tangent (direction d’évaluation de la jacobienne)

Cette écriture et ce vocabulaire sont très utiles pour calculer une dérivée directionnelle de façon séquentielle.

Important

Si \(h = f \circ g\), pour calculer \(\texttt{jvp}_h(\cdot, \cdot)\), il suffit de calculer \((\mathbf{x}, \mathbf{v}) \to (f(\mathbf{x}), \texttt{jvp}_f(\mathbf{x}, \mathbf{v}))\) et \((\mathbf{x}, \mathbf{v}) \to (g(\mathbf{x}), \texttt{jvp}_g(\mathbf{x}, \mathbf{v}))\) en des points arbitraires. En d’autres termes, il suffit de propager les vecteurs primaux et tangents le long de la chaîne (approche forward). On a en effet:

\[ \texttt{jvp}_h(\mathbf{x}, \mathbf{v}) = \texttt{jvp}_{f}(g(\mathbf{x}), \texttt{jvp}_{g}(\mathbf{x}, \mathbf{v})) \]

Application sur notre exemple.

Reprenons notre exemple simple et décomposons le en écrivant \(\mathbf{z} = (z_0, z_1) = z(\mathbf{x}) = W\mathbf{x} + \mathbf{b} = (\mathbf{w}_0^\top \mathbf{x} + b_0, \mathbf{w}_1^\top \mathbf{x} + b_1)\). On peut calculer:

\[ \begin{align} f(\mathbf{x}) & = \tanh(\mathbf{z}) = \tanh(z(\mathbf{x})) \\ J_z(\mathbf{x}) & = W \\ J_{\tanh}(\mathbf{z}) & = \begin{bmatrix} 1 - \tanh^2(z_0) & 0 \\ 0 & 1 - \tanh^2(z_1) \end{bmatrix} \\ \texttt{jvp}_z(\mathbf{x}, \mathbf{v}) & = W \cdot \mathbf{v} \\ \texttt{jvp}_{\tanh}(\mathbf{z}, \mathbf{u}) & = \begin{bmatrix} 1 - \tanh^2(z_0) & 0 \\ 0 & 1 - \tanh^2(z_1) \end{bmatrix} \mathbf{u} \\ \end{align} \] On voit bien que, avec \(\mathbf{z}\) précédemment défini, on a:

\[ J_f(\mathbf{x}) \cdot \mathbf{v} = J_{\tanh}(\mathbf{z}) \cdot J_z(\mathbf{x})\cdot \mathbf{v} \]

qu’on peut réécrire en terme de \(\texttt{jvp}\) comme suit \[ \texttt{jvp}_f(\mathbf{x}, \mathbf{v}) = \texttt{jvp}_{\tanh}(z(\mathbf{x}), \texttt{jvp}_{z}(\mathbf{x}, \mathbf{v})) \]

Note

L’approche JVP est très frugale en mémoire et conceptuellement simple, les vecteurs primaux et tangent sont calculés et propagés à la volée (approche forward) et il n’est pas nécessaire de stocker quoi que ce soit. Le coût d’une évaluation de \((f(x), \texttt{jvp}_f(x, v))\) est à peu près 3 fois celui d’une évaluation de \(f(x)\).

L’approche JVP est utile pour évaluer les colonnes de \(J_f(\mathbf{x})\): il suffit de prendre \(\mathbf{v}\) de la forme \(\mathbf{v} = (0, \dots, 0, 1, 0, \dots, 0)\). Elle fonctionne donc très bien quand \(J_f(\mathbf{x})\) est long (\(f: \mathbb{R}^n \to \mathbb{R}^m\) avec \(n \leq m\)).

En pratique, la fonction jax.jvp est définie par \(\texttt{jax.jvp}: (f, \mathbf{x}, \mathbf{v}) \to (f(\mathbf{x}), J_f(\mathbf{x})\cdot \mathbf{v})\) (elle renvoie à la fois le nouveau vecteur primal et le nouveau vecteur tangent).

Dérivation Backward (VJP)

Principe

Le VJP (pour Vector Jacobian Product) calcule le produit d’un vecteur \(\mathbf{u} \in \mathbb{R}^2\) avec la transposée de la jacobienne : \[ \mathbf{u}^\top \cdot J_f(\mathbf{x}) \]

Si on ne considère pas un point \(\mathbf{x}\) en particulier, on peut voir \(J_f\) comme une application \(J_f: \mathbb{R}^4 \to (\mathbb{R}^2 \to \mathbb{R}^4)\). En particulier, si on se donne un point d’intérêt \(\mathbf{x}\) et une co-direction d’intérêt \(\mathbf{u}\), on peut définir une application \(\texttt{vjp}\) de \(\mathbb{R}^4 \times \mathbb{R}^2 \to \mathbb{R}^4\) via:

\[ \texttt{vjp}_f(\mathbf{x}, \mathbf{u}) = \mathbf{u}^\top J_f(\mathbf{x}) \in \mathbb{R}^{4} \]

Son intérêt est de calculer facilement les lignes de \(J_f(\mathbf{x})\), en prenant \(\mathbf{u}\) de la forme \(\mathbf{u} = (0, \dots, 0, 1, 0, \dots, 0)\), ce qui est plus efficace pour les matrices larges.

Soit une fonction composée \(h = f \circ g\), où :

  • \(g: \mathbb{R}^n \to \mathbb{R}^m\) (fonction interne),
  • \(f: \mathbb{R}^m \to \mathbb{R}^p\) (fonction externe).

On cherche à calculer la dérivée de \(f \circ g\) par rapport à \(\mathbf{x} \in \mathbb{R}^n\), c’est-à-dire la jacobienne de la composition : \[ J_{f \circ g}(\mathbf{x}) = J_f(g(\mathbf{x})) \cdot J_g(\mathbf{x}) = J_f(\mathbf{y}) \cdot J_g(\mathbf{x}), \quad \text{où } \mathbf{y} = g(\mathbf{x}). \]

Le VJP (Vector-Jacobian Product) permet de calculer le produit d’un vecteur cotangent \(\mathbf{u} \in \mathbb{R}^p\) avec \(J_{f \circ g}(\mathbf{x})\) de façon séquentielle : \[ \mathbf{u}^\top \cdot J_{f \circ g}(\mathbf{x}) = \mathbf{u}^\top \cdot \left( J_f(\mathbf{y}) \cdot J_g(\mathbf{x}) \right). \]

  1. Évaluer les vecteurs primaux : Calculer \(\mathbf{y} = g(\mathbf{x})\).
  2. Calculer le VJP de \(f\) : Calculer \(\mathbf{v}^\top = \mathbf{u}^\top \cdot J_f(\mathbf{y})\). Ce vecteur \(\mathbf{v} \in \mathbb{R}^m\) est le gradient adjoint de \(f\) au point \(\mathbf{y}\) pondéré par \(\mathbf{u}\).
  3. Calculer le VJP de \(g\) : Calculer \(\mathbf{v}^\top \cdot J_g(\mathbf{x})\). Ce produit donne le le gradient adjoint de \(g\) au point \(\mathbf{x}\) pondéré par \(\mathbf{v}\) qui correspond au gradient final \(\mathbf{u}^\top \cdot J_{f \circ g}(\mathbf{x})\).

Application sur notre exemple.

On rappelle que

  • \(z(\mathbf{x}) = W\mathbf{x} + \mathbf{b}\)
  • \(f(\mathbf{x}) = \tanh(\mathbf{z})\)

Pour un vecteur cotangent \(\mathbf{u} \in \mathbb{R}^2\):

  • Calcul de la couche linéaire \(z(\mathbf{x}) = W\mathbf{x} + \mathbf{b}\)
  • Première étape de la rétropropagation: \[ \mathbf{v}^\top = \texttt{vjp}_{\tanh}(\mathbf{z}, \mathbf{u}) = \mathbf{u}^\top J_{\tanh}(\mathbf{z}) = \mathbf{u}^\top \begin{bmatrix} 1 - \tanh^2(z_0) & 0 \\ 0 & 1 - \tanh^2(z_1) \end{bmatrix} \in \mathbb{R}^2 \]
  • Deuxième (et dernière) étape de la rétropropagation: \[ \mathbf{u}^T J_f(\mathbf{x}) = \texttt{vjp}_{z}(\mathbf{x}, \mathbf{v}) = \mathbf{v}^T J_z(\mathbf{x}) = \mathbf{v}^\top W \in \mathbb{R}^4 \]
Différence avec le mode forward
  • Le VJP permet de propager les gradients depuis la sortie de \(f\) jusqu’à son entrée.
  • C’est la base du mode reverse de la différentiation automatique où l’on calcule les gradients depuis la sortie vers l’entrée.
  • Contrairement au mode forward, il nécessite de stocker les valeurs des vecteurs primaux lors d’une première passe forward avant de rétropropager et de mettre à jour le gradient lors d’une passe reverse.

Cette approche est beaucoup plus efficace que l’approche forward pour les jacobiennes larges (peu de lignes, beaucoup de colonnes) mais conceptuellement plus sophistiquée et plus difficile à mettre en oeuvre: la profondeur de la pile mémoire augmente en effet linéairement avec le nombre de fonctions composées.

En pratique, la fonction jax.vjp est définie par \(\texttt{jax.jvp}: (f, \mathbf{x}, \mathbf{u}) \to (f(\mathbf{x}), \mathbf{u}^\top J_f(\mathbf{x}))\).

Implémentation de l’exemple en JAX [75 minutes]

Rappel des dérivées analytiques pour l’implémentation

On note

  • \(\mathbf{z} = W\mathbf{x} + \mathbf{b} \in \mathbb{R}^2\)
  • \(f_i(\mathbf{x}) = \tanh(z_i)\)

avec \(\mathbf{x} \in \mathbb{R}^4\), \(W \in \mathbb{R}^{2\times 4}\), \(\mathbf{b} \in \mathbb{R}^2\).

La dérivée de \(\tanh\) est : \[ \frac{d}{dz} \tanh(z) = 1 - \tanh^2(z). \]

La jacobienne de \(f\) en \(\mathbf{x}\) est donc \[ J_f(\mathbf{x}) = \begin{bmatrix} 1 - \tanh^2(z)) & 0 \\ 0 & 1 - \tanh^2(z)) \end{bmatrix} \cdot W \]

Imports et définition de la fonction

import jax
import jax.numpy as jnp
key = jax.random.PRNGKey(0)

dim_entree = 4
dim_sortie = 2
key, subkey = jax.random.split(key)
W = jax.random.normal(subkey, (dim_sortie, dim_entree))
key, subkey = jax.random.split(key)
b = jax.random.normal(subkey, (dim_sortie,))

def f(x, W, b):
  return jnp.tanh(W @ x + b)

jnp est un module de JAX qui reprend l’essentiel des fonctions numpy.

Aléa dans JAX

La gestion de l’aléa dans JAX est assez particulière. L’aléa est fonctionnel et contrôlé par une clé :

  • On crée une clé : key = jax.random.PRNGKey(seed)
  • Chaque tirage consomme la clé \(\rightarrow\) il faut la scinder : key, subkey = jax.random.split(key)
  • Les fonctions aléatoires prennent toujours une clé en argument :
  • x = jax.random.normal(subkey, shape)

On testera les dérivées en un point donné :

x = jnp.array([0.2, -0.1, 0.5, 1.0])
Important

Ici on calcule la dérivée par rapport à \(\mathbf{x}\) en gardant \(W\) et \(\mathbf{b}\) comme des paramètres. Dans l’apprentissage d’un réseau de neurones, on calcule plutôt les dérivées par rapport à \(W\) et \(\mathbf{b}\), mais le principe sera exactement le même!

  1. Sans passer par les fonctions forward et backward de JAX, implémenter la dérivée dans une fonction jacobian_manual()
  2. En utilisant jax.jacfwd et jax.jacbwd (le principe de ces fonctions est vu en Section 1)
  3. En utilisant jax.grad + jax.vmap
  4. Contrôler que toutes les versions donnent le même résultats.
  1. Les fonctions jax.jacrev et jax.jacbwd prennent en entrée au moins une fonction et retourne un callable, donc une autre fonction. Il faut créer une fonction anonmye lambda x: f(x, W, b) pour considérer f uniquement comme fonction de x L’argument argnums de jax.jacrev et jax.jacbwd dit par rapport à quelle variable on veut dériver

  2. jax.grad calcule la dérivée d’une fonction par formule backward, mais elle ne s’applique qu’aux fonctions à valeurs dans \(\mathbb{R}\). Il faut utiliser jax.vmap pour appliquer la même fonction à toutes les composantes du vecteur de sortie.

  3. Vous pouvez utiliser la fonction jnp.allclose().

1. Jacobienne à la main

def jacobian_manual(x, W, b):
    z = W @ x + b              # (2,)
    D = 1.0 - jnp.tanh(z)**2   # (2,)
    return jnp.diag(D) @ W     # (2, 4)
J_manual = jacobian_manual(x, W, b)
J_manual
Array([[-0.2562241 , -0.21355163,  0.02156247, -0.03708893],
       [-0.4760124 , -0.7362525 , -0.7173037 ,  0.18564229]],      dtype=float32)

2. Jacobienne avec Jax autodiff

JAX permet de calculer directement la jacobienne complète.

Différentiation reverse-mode (jax.jacrev)

On peut calculer la formule backward via jax.jacrev (ou son alias jax.jacobian).

Dès le départ on va fixer l’index de l’argument de la fonction f par rapport auquel on veut dériver. Dans notre exemple, on dérive par rapport à x qui correspond à l’index 0 des arguments. Par la suite, on utilisera cet index pour pouvoir dire par rapport à quelle variable on veut dériver, avec l’argument argnum.

idx_jac = 0
J_bwd = jax.jacrev(lambda x: f(x, W, b), argnums=idx_jac) 

J_bwd(x)
Array([[-0.2562241 , -0.21355163,  0.02156247, -0.03708893],
       [-0.47601238, -0.7362524 , -0.7173036 ,  0.18564227]],      dtype=float32)
  • La fonction jax.jacrev prend en entrée au moins une fonction et retourne un callable, donc une autre fonction.

  • lambda x: f(x, W, b) est une fonction anonyme qui :

    • prend un seul argument x
    • appelle f en gardant W et b constants (fermés dans la closure)
  • l’argument argnums dit par rapport à quelle variable on veut dériver, ici x

L’appel du callable J_bwd sur x nous donne la Jacobienne calculée par différentiation automatique.

Différentiation forward-mode (jax.jacfwd)

On peut aussi calculer la formule forward via jax.jacfwd avec le même principe que la précédente.

J_fwd = jax.jacfwd(lambda x: f(x, W, b), argnums=idx_jac)

J_fwd(x)
Array([[-0.25622407, -0.21355158,  0.02156247, -0.03708893],
       [-0.47601235, -0.7362524 , -0.71730363,  0.18564227]],      dtype=float32)

3. Jacobienne avec gradient et vmap (jax.grad + jax.vmap)

jax.grad est une autre fonction qui calcule la dérivée d’une fonction par formule backward, mais elle ne s’applique qu’aux fonctions à valeurs dans \(\mathbb{R}\). Pour pouvoir l’utiliser, on doit donc calculer le gradient de chaque composante de f séparément et ensuite utiliser jax.vmap pour appliquer la même fonction à toutes les valeurs du vecteur de sortie.

L’idée de vmap est la suivante : vectoriser une fonction qui agit sur un seul élément pour qu’elle agisse sur tout un lot d’éléments.

Si on a une fonction \(g\) telle que \[g(i) = \nabla_x f_i(x,W,b),\]

la fonction jax.vmap(g)([0, 1]) construit automatiquement le vecteur

\[\begin{pmatrix} \nabla_x f_0(x,W,b)\\ \nabla_x f_1(x,W,b) \end{pmatrix},\]

qui est la Jacobienne de \(g\).

grad_f_i = jax.grad(lambda i, x, W, b: f(x, W, b)[i], argnums=1+idx_jac)

J_grad_vmap = jax.vmap(grad_f_i, in_axes=(0,None,None,None))(jnp.arange(dim_sortie),x,W,b)
J_grad_vmap
Array([[-0.2562241 , -0.21355163,  0.02156247, -0.03708893],
       [-0.47601238, -0.7362524 , -0.7173036 ,  0.18564227]],      dtype=float32)
  • jax.grad calcule une ligne de la Jacobienne (gradient d’une sortie scalaire), c’est un callable
  • jax.vmap calcule ces lignes de manière vectorisée.

Dans la fonction jax.grad, l’argnum devient 1 + idx_jac car il faut considérer le nouveau premier argument i.

Dans la fonction vmap, on observe un argument en plus de la fonction à mapper :

  • in_axes=(0, None, None, None) est le mapping pattern de vmap :
    • 0 : on vectorise sur la dimension \(0\) de l’argument i
    • None : x, W, b sont partagés (pas vectorisés)
  • jnp.arange(dim_sortie) fournit les indices \(i = 0, 1, ..., \texttt{dim\_sortie}−1\) pour lesquels on veut calculer le gradient.

Au final, la sortie est bien la matrice Jacobienne \((2 \times 4)\).


4. Vérification de l’égalité des résultats

print(jnp.allclose(J_manual, J_bwd(x)))
print(jnp.allclose(J_manual, J_fwd(x)))
print(jnp.allclose(J_manual, J_grad_vmap))
True
True
True
print(jnp.max(jnp.abs(J_manual - J_fwd(x))))
5.9604645e-08

COFFEE BREAK + SOCIAL TIME [15 minutes]

Implémentation des JVP et VJP pour une fonction personnalisée

  1. Calculer la dérivée de notre fonction en utilisant les fonctions jax.jvp et jax.vjp fournies en JAX.
  2. Redéfinissez manuellement les opérateurs jvp et vjp associés à notre exemple, comme si l’on chercher à créer un brique élémentaire pour la chaine de dérivation.

Vérifier à chaque fois que les différentes approches donnent le même résultat que la version analytique.

  1. Pour cela nous pouvons utiliser jax.jvp(f, (x,), (v,)) et jax.vjp(f, x) (puis appliquer la fonction retournée par jax.vjp sur un vecteur cotangent u).
  2. Vous pouvez utiliser les fonctions jax.custom_jvp et jax.custom_vjp pour cette tâche. Plus spécifiquement, vous devez définir une fonction f_custom_jvp et f_custom_vjp qui encapsule la fonction f, puis définir les règles jvp et vjp associées en utilisant les dérivées analytiques que nous avons calculées précédemment.

5. Jacobienne avec jax.jvp et jax.vjp

Tout d’abord, rappelons que le jvp et le vjp permettent de calculer des produits entre la jacobienne et un vecteur, respectivement dans le sens forward (jvp) et backward (vjp). Nous allons donc définir des vecteurs tangent et cotangent pour tester ces fonctions.

# Vecteur tangent pour jvp (dimension d'entrée de f)
v = jnp.array([1.0, 0.0, 0.0, 0.0])  # Exemple de vecteur tangent
# Vecteur cotangent pour vjp (dimension de sortie de f)
u = jnp.array([1.0, 0.0])  # Exemple de vecteur cotangent

# Calcul du JVP
f_jvp_out, jvp_result = jax.jvp(lambda x: f(x, W, b), (x,), (v,))
print("JVP Result:", jvp_result)

# Calcul du VJP
f_vjp_out, vjp_fun = jax.vjp(lambda x: f(x, W, b), x)
vjp_result = vjp_fun(u)[0]
print("VJP Result:", vjp_result)

# Vérification que les résultats sont cohérents avec la jacobienne manuelle
J_manual = jacobian_manual(x, W, b)
jvp_check = J_manual @ v
vjp_check = u @ J_manual
print("JVP Check:", jvp_check)
print("VJP Check:", vjp_check)

print("JVP matches manual:", jnp.allclose(jvp_result, jvp_check))
print("VJP matches manual:", jnp.allclose(vjp_result, vjp_check))
JVP Result: [-0.25622407 -0.47601235]
VJP Result: [-0.2562241  -0.21355163  0.02156247 -0.03708893]
JVP Check: [-0.2562241 -0.4760124]
VJP Check: [-0.2562241  -0.21355163  0.02156247 -0.03708893]
JVP matches manual: True
VJP matches manual: True

6. Redéfinition manuelle de jvp et vjp

Nous allons redéfinir le jvp et le vjp de notre fonction f comme s’il s’agissait d’une brique élémentaire pour la chaîne de différentiation automatique. Pour cela, nous utiliserons les décorateurs @jax.custom_jvp et @jax.custom_vjp.

@jax.custom_jvp
def f_custom_jvp(x, W, b):
    return f(x, W, b)

@f_custom_jvp.defjvp
def f_jvp(primals, tangents):
    x, W, b = primals
    v, _, _ = tangents
    z = W @ x + b
    D = 1.0 - jnp.tanh(z)**2
    jvp_result = jnp.diag(D) @ W @ v
    return f(x, W, b), jvp_result

@jax.custom_vjp
def f_custom_vjp(x, W, b):
  return f(x, W, b)

def f_custom_vjp_fwd(x, W, b):
  y = f(x, W, b)
  z = W @ x + b
  D = 1.0 - jnp.tanh(z)**2
  # residuals needed for backward: x, W, D
  return y, (x, W, D)

def f_custom_vjp_bwd(residuals, g):
  x, W, D = residuals
  # g is the cotangent (u) with shape (dim_sortie,)
  v = g * D  # element-wise
  dx = v @ W
  dW = jnp.outer(v, x)
  db = v
  return dx, dW, db

# Liaison des fonctions forward et backward
f_custom_vjp.defvjp(f_custom_vjp_fwd, f_custom_vjp_bwd)

# Test de la redéfinition
f_jvp_out_custom, jvp_result_custom = jax.jvp(lambda x: f_custom_jvp(x, W, b), (x,), (v,))
f_vjp_out_custom, vjp_fun_custom = jax.vjp(lambda x: f_custom_vjp(x, W, b), x)
vjp_result_custom = vjp_fun_custom(u)[0]
print("Custom JVP Result:", jvp_result_custom)
print("Custom VJP Result:", vjp_result_custom)

print("Custom JVP matches manual:", jnp.allclose(jvp_result_custom, jvp_check))
print("Custom VJP matches manual:", jnp.allclose(vjp_result_custom, vjp_check))
Custom JVP Result: [-0.2562241 -0.4760124]
Custom VJP Result: [-0.2562241  -0.21355163  0.02156247 -0.03708893]
Custom JVP matches manual: True
Custom VJP matches manual: True

Quelques explications sur le code ci-dessus:

  • Nous avons du redéfinir f_custom_jvp et f_custom_vjp en utilisant les décorateurs @jax.custom_jvp et @jax.custom_vjp. Il s’agit de fonctions qui encapsulent notre fonction f originale. On ne peut pas directement décorer f car elle ne suit pas la signature attendue par JAX pour les fonctions différentiables.
  • Nous avons utilisé les décorateurs @jax.custom_jvp et @jax.custom_vjp pour définir des versions personnalisées de jvp et vjp pour notre fonction f. Cela permet de spécifier explicitement comment calculer ces produits en utilisant les dérivées analytiques.
  • Dans le cas du ‘vjp’, nous avons défini une fonction de passage avant (f_custom_vjp_fwd) qui calcule la sortie de la fonction ainsi que les résidus nécessaires pour le calcul du backward. La fonction de passage arrière (f_custom_vjp_bwd) utilise ces résidus pour calculer le produit vectoriel avec la transposée de la jacobienne. Ces résidus sont essentiels pour le calcul correct du gradient.

Cette fonctionnalité est utile lorsque nous voulons contrôler précisément le comportement de la différentiation automatique pour des fonctions spécifiques (potentiellement non implémentées en JAX), notamment lorsque les dérivées analytiques sont disponibles ou lorsque nous souhaitons optimiser les performances.

Performance de l’autodifférentiation

  1. Comparer, sur un grand nombre d’évaluations, les temps de calcul des dérivées analytiques, de l’autodifférentiation en mode forward et backward.
  2. Essayez d’appliquer une compilation JIT (Just In Time) pour comparer à nouveau les différences entre chaque approche.
  1. Vous pouvez utiliser la commande % timeit pour calculer le temps d’execution.

  2. Pour ‘jitter’ une fonction, il suffit soit de la décorer avec @jax.jit ou d’utiliser jax.jit(your_function).

7. Comparaison des temps de calcul

## On utilise %timeit pour mesurer les temps d'exécution
##%timeit jacobian_manual(x, W, b)
##%timeit jax.jacfwd(lambda x: f(x, W, b), argnums=idx_jac)(x)
##%timeit jax.jacrev(lambda x: f(x, W, b), argnums=idx_jac)(x)

On peut constater que la dérivée manuelle est la plus rapide, tandis que jax.jacrev et jax.jacfwd sont comparables dans notre exemple. De manière générale, les différences de temps vont varier en fonction de la taille des entrées et des sorties. Il faut privilégier jacrev lorsque la dimension de sortie est petite ( \(d_\text{output} \ll d_\text{input}\)), et jacfwd lorsque la dimension d’entrée est petite (\(d_\text{output} \gg d_\text{input}\)).

8. Comparaison des temps de calcul avec JIT

Le jitting est une technique d’optimisation qui compile une fonction pour une exécution (parfois beaucoup) plus rapide. En JAX, c’est fait très simplement avec jax.jit.

# On redéfinit la fonction f et ses gradients avec JIT
f_jit = jax.jit(lambda x: f(x, W, b))
jac_manual_jit = jax.jit(lambda x: jacobian_manual(x, W, b))
jac_fwd_jit = jax.jit(jax.jacfwd(lambda x: f_jit(x), argnums=idx_jac))
jac_bwd_jit = jax.jit(jax.jacrev(lambda x: f_jit(x), argnums=idx_jac))

Il est important de faire un “warm-up” des fonctions JIT avant de mesurer les temps d’exécution, car la première exécution inclut le temps de compilation.

# Warm-up from module import names
jac_manual_jit(x)
jac_fwd_jit(x)
jac_bwd_jit(x)
Array([[-0.25622404, -0.21355158,  0.02156247, -0.03708893],
       [-0.47601238, -0.7362524 , -0.7173036 ,  0.18564227]],      dtype=float32)

Puis le benchmarking:

##%timeit jac_manual_jit(x)
##%timeit jac_fwd_jit(x)
##%timeit jac_bwd_jit(x)

On peut constater que le jitting réduit considérablement les temps d’exécution pour toutes les méthodes. La dérivée manuelle reste la plus rapide, mais la différence avec jax.jacrev devient négligeable. Le jitting rend l’autodifférentiation très compétitive en termes de performance, ce qui est un avantage majeur de JAX pour les applications de machine learning et d’optimisation numérique.

Application au calcul de la hessienne [30 minutes]

Calculs de la hessienne

Ainsi si la matrice jacobienne représente l’application linéaire qui donne les variations de \(f\) pour un changement infinitésimal des entrées, alors on peut dire que la matrice hessienne représente l’opérateur qui donne les variations de la matrice jacobienne de \(f\) pour un changement infinitésimal en entrée. C’est donc un opérateur qui prend en entrée un vecteur de l’espace tangent et donne une matrice.

La hessienne de \(f\) (\(f\) telle que \(f\colon\mathbb{R}^4\to\mathbb{R}^2\)) au point \(\mathbf{x}\) notée \(H_f(\mathbf{x})\) est un opérateur linéaire (qui produit une matrice) tel que \(H_f\colon\mathbb{R}^4 \times (\mathbb{R}^4\to \mathbb{R}^2)\). On peut donc lui associer une matrice à trois dimensions, que l’on nommera, peut-être avec un petit abus de langage, matrice hessienne. On peut aussi voir \(H_f\) comme une application bilinéaire : \(H_f\colon(\mathbb{R}^4\times\mathbb{R}^4)\to\mathbb{R}^2\). On peut montrer que les deux premiers vecteurs tangents sont interchangeables, ce qui en fait une forme bilinéaire symétrique.

De même, si on ne considère pas un point \(\mathbf{x}\) en particulier, on peut voir \(H_f\) comme une application \(H_f: \mathbb{R}^4 \to \mathbb{R}^4 \to \mathbb{R}^4 \to \mathbb{R}^2\).

Calculer la matrice hessienne associée \(H_f(\mathbf{x}), \mathbf{x}=(0.2, -0.1, 0.5, 1.0),\) de la manière la plus simple possible avec JAX. Afficher les dimensions de l’objet rétourné.

H_f_x = jax.hessian(f)(x, W, b)

print(H_f_x, H_f_x.shape)
[[[ 1.1841627   0.9869479  -0.0996529   0.17140983]
  [ 0.9869481   0.8225781  -0.08305635  0.14286263]
  [-0.09965291 -0.08305635  0.00838626 -0.01442495]
  [ 0.17140985  0.14286262 -0.01442495  0.0248119 ]]

 [[ 0.44439813  0.6873544   0.66966414 -0.17331289]
  [ 0.68735445  1.0631369   1.0357753  -0.26806453]
  [ 0.6696641   1.0357752   1.0091177  -0.2611654 ]
  [-0.17331289 -0.26806453 -0.2611654   0.06759109]]] (2, 4, 4)

Nous avons dit que la hessienne s’interprétait comme les variations des variations (la dérivée de la jacobienne), essayons alors de reconstruire jax.hessian avec les fonctions plus bas-niveau vues plus tôt.

Calculer la matrice hessienne associée \(H_f(\mathbf{x}), \mathbf{x}=(0.2, -0.1, 0.5, 1.0),\) en utilisant jax.jacfwd et jax.jacrev.

\(\circ\)

H_f_x_0 = jax.jacfwd(jax.jacrev(f))(x, W, b)
H_f_x_1 = jax.jacfwd(jax.jacfwd(f))(x, W, b)
H_f_x_2 = jax.jacrev(jax.jacrev(f))(x, W, b)
H_f_x_3 = jax.jacrev(jax.jacfwd(f))(x, W, b)

assert jnp.allclose(H_f_x, H_f_x_0)
assert jnp.allclose(H_f_x, H_f_x_1)
assert jnp.allclose(H_f_x, H_f_x_2)
assert jnp.allclose(H_f_x, H_f_x_3)

print("All good!")
All good!

En théorie (question papier/crayon), parmi toutes les solutions possibles à la question précédente, quelle devrait être la plus efficace en terme de coût calculatoire ? La moins efficace ?

Se rappeler quand il est préférable d’utiliser les JVP (mécanique derrière jacfwd) et quand il est préférable d’utiliser les VJP (mécanique derrière jacrev).

Un JVP (resp. VJP) avec \(\mathbf{v}\) de la forme \((0, ..., 0, 1, 0, ..., 0)\) est capable de dévoiler une colonne (resp. ligne) de la matrice jacobienne. Il vaut donc mieux vectoriser les appels aux JVP (resp. VJP) lorsque l’on veut calculer une matrice jacobienne dans \(\mathbb{R}^{m\times n}\) avec \(n\leq m\) (resp. \(m\leq n\)).

Souvenons-nous aussi que jacfwd (resp. jacrev) vectorise des appels de JVP (resp. VJP) avec des \(\mathbf{v}\) tels que ci-dessus.

Lors du calcul de la matrice jacobienne, il vaut mieux privilégier une approche reverse car \(J_f(\mathbf{x})\colon\mathbb{R}^4\to\mathbb{R}^2\), c’est une matrice plus large que grande. À l’inverse, lors du calcul de la dérivée de la jacobienne la dimension de sortie est plus grande \(H_f(\mathbf{x})\colon\mathbb{R}^4\to(\mathbb{R}^4\to\mathbb{R}^2)\), il vaut mieux privilégier une approche forward.

Ainsi théoriquement, pour \(f\), on préfère jacfwd\(\circ\)jacrev et on évite jacrev\(\circ\)jacfwd.

NB : Ici les dimensions sont tellement petites que la différence entre les modes de calculs n’est peut-être pas significative. Il faudrait profiler le code.

Regardons ce qu’il se passe pour jax.jacfwd\(\circ\)jax.jacfwd !

Calculer la matrice hessienne associée \(H_f(\mathbf{x}), \mathbf{x}=(0.2, -0.1, 0.5, 1.0),\) en utilisant jax.jvp.

Les lignes suivantes calculent la matrice jacobienne en vectorisant les JVP sur la base de l’espace tangent (\(\mathbb{R}^4\)). Nous proposons de suivre la même logique en composant deux JVP.

# Fix W and b for clarity
f_fixed = lambda x: f(x, W, b)

Jac_f_x_fun = jax.vmap(lambda fun, p, t: jax.jvp(fun, p, t)[1], (None, None, 0))
Jac_f_x = Jac_f_x_fun(f_fixed, (x,), (jnp.eye(4),)).T
print(Jac_f_x.shape)
(2, 4)

Plutôt que deux composer 2 appels à jax.vmap ce qui paraît syntaxiquement complexe, nous créons le tableau du produit cartésien des indices des lignes de 2 matrices identités (2 bases de l’espace tangent). Nous parcourons ensuite ces paires d’indices avec un seul jax.vmap et composons les 2 JVP, où chacun prend un vecteur tangent indicé par un des deux indice de la paire.

Ce qui est intéressant ici c’est que nous voyons assez bien la hessienne sous l’angle application bilinéaire s’appliquant à 2 vecteurs de l’espace tangent (rappel : \(H_f(\mathbf{x})\colon \mathbb{R}^4\to\mathbb{R}^4\to\mathbb{R}^2\)).

# Fix W and b for clarity
f_fixed = lambda x: f(x, W, b)

cartesian_prod = jnp.indices((4, 4)).reshape(2,-1).T

H_f_x_4 = jax.vmap(
    lambda pair: jax.jvp(
        lambda x: jax.jvp(
            f_fixed,
            (x,),
            (jnp.eye(4)[pair[1]],)
        )[1],
        (x,),
        (jnp.eye(4)[pair[0]],)
    )[1], (0,)
)(cartesian_prod).T.reshape(2, 4, 4)

assert jnp.allclose(H_f_x, H_f_x_4)
print(H_f_x_4.shape)
print("All good!")
(2, 4, 4)
All good!

Hessian-vector product et Vector-Hessian product

De manière similaire à la définition du JVP et du VJP, les opérations HVP et VHP sont définies par :

\[ \begin{align*} \texttt{hvp}_f\colon\mathbb{R}^4\times\mathbb{R}^4&\to\mathbb{R}^4\times\mathbb{R}^2\\ (\mathbf{x},\mathbf{v})&\mapsto H_f(\mathbf{x})\mathbf{v} \end{align*} \]

\[ \begin{align*} \texttt{vhp}_f\colon\mathbb{R}^4\times\mathbb{R}^2&\to\mathbb{R}^4\times\mathbb{R}^4\\ (\mathbf{x},\mathbf{u})&\mapsto \mathbf{u}^TH_f(\mathbf{x}) \end{align*} \]

(où il faut définir rigoureuseument le produit \(H_f(\mathbf{x})\mathbf{v}\) comme dans cet article).

Leur intérêt réside dans le fait qu’on n’instancie pas la matrice hessienne dans sa globalité. On a néanmoins besoin d’instancier les matrices jacobiennes des calculs intermédiaires.

HVP

La documentation de JAX donne de exemples concrets d’utilisation du HVP.

  1. Écrire la fonction \(\texttt{hvp}_f\) en composant les dérivations forward.
  2. Calculer \(\texttt{hvp}_f(\mathbf{x},(1, 0, 0, 0))\), avec le même \(\mathbf{x}\) que précédemment.
  3. Quelles parties de la matrice hessienne révèle-t-on avec le calcul précédent ?

Voir les exemples de la documentation de JAX.

f_fixed = lambda x: f(x, W, b)

# 1.
def hvp(f, primals, tangents):
    # on a besoin de calculer une matrice jacobienne
    return jax.jvp(jax.jacrev(f), primals, tangents)[1]

# 2.
H_f_x_5_partial = hvp(f_fixed, (x,), (jnp.array([1., 0., 0., 0.]),))

# 3.
assert jnp.allclose(H_f_x_5_partial[0, :], H_f_x[0, :, 0])
print("On découvre la première colonne du première bloc")
assert jnp.allclose(H_f_x_5_partial[1, :], H_f_x[1, :, 0])
print("On découvre la première colonne du deuxième bloc")
On découvre la première colonne du première bloc
On découvre la première colonne du deuxième bloc

Compatibilité avec la chain rule ? (bonus)

Peut-on composer les HVP et / ou VHP comme on l’a fait pour les VJP et JVP ? Si oui, est-ce avantageux de le faire ?

Notons que avec \(f=f_2\circ f_1\), on a \(J_{{f_2}\circ {f_1}}(\mathbf{x})=J_{f_2}({f_1}(\mathbf{x})) + J_{f_1}(\mathbf{x})\), et on peut voir que :

\[ H_{{f_2}\circ {f_1}}({f_2}{\mathbf{x}}) = (H_{f_2}({f_1}(\mathbf{x})) J_{f_1}(\mathbf{x})) J_{f_1}(\mathbf{x}) + J_{f_2}({f_1}(\mathbf{x}))H_{f_1}(\mathbf{x}) \]

Attention, les produits de l’équation précédente sont définis au sens de la proposition 1 de l’article cité précédemment. De plus on remarque que dans le cas de \(f_1\) application linéaire (c’est notre cas ici), le deuxième terme est nul.


Conclusion [5 minutes]

  • La dérivation manuelle est la plus rapide et la plus contrôlable, mais peu scalable. Ici elle est aussi longue que jacrev.
  • jacrev est souvent préférable quand dim_sortie\(=\dim(f(x))\) est petite (notre cas ici). Ce cadre correspond exactement à ce qui se passe dans une couche de réseau de neurones, et se généralise naturellement à des architectures plus profondes.
  • jacfwd est souvent préférable quand dim_entree\(=\dim(x)\) est petite.

La rédifinition de JVP (Jacobian Vector Product) et VJP (Vector Jacobian Product) d’une fonction qui n’est pas disponible dans les primitives fournies par JAX pourraient être utiles dans les cas suivants :

  • l’utilisation d’une fonction non différentiable pour laquelle on veut écrire une dérivée “non-standard” afin de pouvoir l’utiliser dans JAX
  • l’utilisation d’une fonction dont une approximation analytique de la dérivée est disponible mais qui n’est pas implémentée dans JAX