import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import optax
import matplotlib.pyplot as plt
Premiers pas avec jax
Différentiation automatique avec Jax
Le but de cette vignette est d’implémenter une régression logistique et/ou une régression multivariée avec JAX.
Préliminaires
Installation
Conda
Jax est disponible dans le channel conda-forge et peut donc s’installer dans un environnement conda
#| eval: true
conda create -n jax
conda activate jax
## Install the CPU version
conda install jax -c conda-forge
## Install packages necessary for the render
conda install nbclient nbformat ipykernel
Pour des instruction détaillées pour l’installation en mode GPU ou TPU, se référer à la documentation officielle.
Il suffit alors d’activer l’environnement jax pour produire le html à partir du qmd
#| eval: false
conda activate jax
quarto render my_document.qmd --to html
Pip
Si vous préférez une installation via pip (pour une version cpu),
#| eval: false
pip3 install jax jaxlib
Pour une installation GPU (avec cuda 12 par exemple, on vous laisse gérer la compatibilité de vos driver Nvidia and Cie),
#| eval: false
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
L’utilisation de venv
est recommandable (mais non obligatoire).
On installe également optax pour avoir accès à des optimiseurs.
#| eval: false
pip3 install optax
Important Pour utiliser optax
, il vaut mieux utiliser pip pour installer jax
et jaxlib
, les versions disponibles dans les dépôts conda sont en effet trop anciennes pour optax.
Premiers pas
Philosophie
En quelques mots, JAX est une bibliothèque Python développée par Google et initialement utilisée dans TensorFLow. Elle permet de faire de l’algèbre linéaire à la numpy
, avec 2 propriétés clés la rendant extrêmement performante:
- un autograd permettant la différenciation automatique de calcul Python/Numpy
- un compileur pour GPU et autres (XLA), dédié à l’algèbre linéaire qui permet d’optimiser les temps d’exécution grâce à une approche JIT (Just-in Time, c’est-à-dire une optimisation du code à l’exécution et non pas avant l’appel comme avec un compileur classique).
L’objectif de la bibliothèque est de proposer une expérience utilisateur aussi proche que possible de calculs à la Numpy
, notamment à l’aide de décorateurs Python. Néanmoins, pour accéder pleinement aux capacités de JAX, un certain nombre de contraintes d’écriture des programmes s’appliquent, que nous allons essayer de présenter pas à pas.
Import de la bibliothèque
L’import complet/standard est le suivant:
On peut détailler les fonctionnalités des modules comme suit:
- le module
jax.numpy
, akajnp
, porte les opérations matricielles usuelles de manière quasi transparente - le module
random
définit les outils de génération de nombres aléatoires, propres à JAX et très différents de Numpy - le module
grad
gère l’autodifférenciation - le module
jit
gère la “just-in time” compilation (accélération du code) - le module
vmap
permet de vectoriser automatiquement certaines opérations
Jax.numpy: interface Algèbre linéaire haut-niveau
On commence par simuler des données aléatoires via les outils de jax. Attention la gestion de la clé aléatoire est explicite. Après avoir créé une clé et avant chaque appel à une fonction aléatoire, il faut faire évoluer la graine à la main
= random.split(key, 2) key,subkey
et utiliser subkey
(les sous-clés) dans l’appel à la fonction aléatoire (ou aux fonctions aléatoire) comme écrit ci-dessous.
= 10000
n = 100
p = random.PRNGKey(0)
key = random.split(key, 3)
key,sub1,sub2 = jnp.ones((n, 1))
ones = random.normal(sub1, (n, p-1))
x = jnp.concatenate([ones, x], axis = 1)
x = random.normal(sub2, (p,1)) beta_true
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Avant de les multiplier. On utilise ici la fonction block_until_ready()
uniquement pour mesurer le temps effectif de calcul. En effet, JAX fait de l’évaluation asynchrone (comme {future}
en R) pour rendre la main à l’utilisateur après l’envoi de la commande.
%timeit odds = jnp.dot(x, beta_true).block_until_ready() # runs on the CPU
172 µs ± 240 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
On échantillonne ensuite des variables suivant une loi de Bernoulli.
= jnp.dot(x, beta_true)
odds = random.split(key, 2)
key,subkey = random.bernoulli(subkey, odds) y
et une perte logistique
\[\ell(y, x, \theta) = -\log p(y; \sigma(x^{\top}\theta)) = -y (x^\top \theta) + \log(1 + e^{x^\top \theta})\]
def logistic_loss(y, x, theta):
= jnp.dot(x, theta)
odds return -jnp.vdot(y, odds) + jnp.sum(jnp.log(1.0 + jnp.exp(odds)))
Qu’on peut tester sur un exemple simple
## Should be log(2)
True, 1.0, 0) logistic_loss(
Array(0.6931472, dtype=float32)
Just-in-time compilation
La version normale de notre fonction logistique est déjà rapide.
%timeit logistic_loss(y, x, beta_true).block_until_ready()
277 µs ± 895 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
mais on peut l’accélerer en compilant la fonction via le décorateur @jit
ou la fonction jit()
de façon complètement transparente pour l’utilisateur.
## Utilisation du décorateur @jit
@jit
def logistic_loss(y, x, theta):
= jnp.dot(x, theta)
odds return -jnp.vdot(y, odds) + jnp.sum(jnp.log(1.0 + jnp.exp(odds)))
## Utilisation de jit()
= jit(logistic_loss)
logistic_loss_jit %timeit logistic_loss_jit(y, x, beta_true).block_until_ready()
222 µs ± 3.66 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
La différence n’est pas très importante dans cet exemple. jit()
permet des gains d’autant plus importants qu’on travaille sur des fonctions complexes.
Attention, il n’est pas toujours possible de jitter une fonction, en particulier, si cette fonction implique un branchement conditionnel:
def f(x):
if x > 5:
return x
else:
return 2*x
L’erreur provient du fait que la définition de la fonction dépend de la valeur des entrées.
= jit(f)
f_jit ## Renvoie une erreur
1) f_jit(
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function f at /tmp/ipykernel_1050/2514275982.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
Comme l’indique le message d’erreur
This concrete value was not available in Python because it depends on the value of the argument x.
grad: auto-différentiation
JAX permet de calculer le gradient d’une fonction via grad()
. La syntaxe est différente de torch et plus proche de ce qu’on ferait dans une fonction mathématique.
def loss(theta):
return logistic_loss(y, x, theta)
## random start for theta
= random.split(key, 2)
key,subkey = random.normal(key, (p, 1))
theta = grad(loss) grad_loss
= grad(loss)
grad_loss %timeit grad_loss(theta)
7.92 ms ± 644 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
grad()
peut-être combiné à jit()
dans tous les sens à condition que les fonctions s’y prêtent.
= jit(grad(loss))
grad_loss ## Warmup to cache grad loss
grad_loss(theta).shape## Actual time recording
%timeit grad_loss(theta)
368 µs ± 1.66 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Mais ce n’est pas toujours intéressant.
= jit(grad(jit(loss)))
grad_loss ## Warmup to cache grad loss
grad_loss(theta).shape## Actual time recording
%timeit grad_loss(theta)
346 µs ± 1.67 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Vectorisation
JAX permet enfin de vectoriser automatiquement des opérations de façon efficace (en faisant descendre la boucle à l’intérieur de la fonction, au niveau des primitives utilisées pour le calcul).
Considérons un exemple simple où on veut calculer des logs-odds sur mesures répétées.
## Matrice de covariables, répétées en temps
## [temps, individu, variable]
= random.split(key, 2)
key,subkey = random.normal(key, (10, n, p))
X def compute_odds(x, theta):
return jnp.dot(x, theta)
def compute_odds_batched(X, theta):
return jnp.stack([compute_odds(x, theta) for x in X])
Et testons ce qui se passe. On appelle la fonction sur une tranche de X
%timeit compute_odds(X[:1,:, :], beta_true)
2.59 ms ± 9.95 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Puis sur toutes les tranches de X avec notre fonction vectorisée manuellement.
%timeit compute_odds_batched(X, beta_true)
14.3 ms ± 148 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Puis sur toutes les tranches de X avec notre fonction vectorisée via vmap()
.
def compute_odds_batched_vmap(X, theta):
def f(x):
return compute_odds(x, theta)
return vmap(f)(X)
compute_odds_batched_vmap(X, beta_true).shape
(10, 10000, 1)
%timeit compute_odds_batched_vmap(X, beta_true)
12.5 ms ± 33.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
À comparer à la version native jax qui est déjà nativement vectorisée pour cette opération
%timeit compute_odds(X, beta_true)
11.5 ms ± 25.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Le gain n’est pas très important dans cette exemple précis mais on se rapproche quand même de la performance de la version native, par rapport à notre vectorisation manuelle.
Optimisation de la fonction objective
À la main
Contrairement à torch
, on n’a pas d’optimiseur défini clé en main dans JAX. La dérivée est néanmoins une fonction comme les autres et on peut donc écrire très simplement un algorithme simple de descente de gradient.
%%time
= 50
num_iterations = []
loss_vector ## Learning rate
= 0.001
lr ## Initialisation de theta
= jnp.zeros(p)
theta ## Fonction de perte, en mode jit
@jit
def loss(theta):
return logistic_loss(y, x, theta)
## Gradient de la fonction de perte, en mode jit
= jit(grad(loss))
grad_loss
## Descente de gradient
for i in range(num_iterations):
# Suivi de la fonction de perte
loss_vector.append(loss(theta))# Mise à jour du paramètre
= theta - lr * grad_loss(theta) theta
CPU times: user 284 ms, sys: 12 ms, total: 296 ms
Wall time: 294 ms
Et on peut vérifier que la fonction de perte décroit au cours du temps.
range(1, num_iterations + 1), loss_vector)
plt.plot('Iterations')
plt.xlabel('Loss')
plt.ylabel('Loss vs. Iterations')
plt.title( plt.show()
Et que les paramètres de régression estimés se rapprochent des vraies valeurs
='o', linestyle="none")
plt.plot(beta_true, theta, marker-3, 3], [-3, 3], color='r', linestyle='-', linewidth=2)
plt.plot(['True parameter')
plt.xlabel('Estimated parameter')
plt.ylabel('Estimated versus true parameter')
plt.title( plt.show()
Avec Optax
On peut néanmoins utiliser la librairie optax pour définir des optimiseurs comme en torch
. On va utiliser ici Adam.
Fonction objective
On commence par définir la fonction objective avec un ordre précis pour les arguments:
- paramètres à optimiser (typiquement \(\theta\), les coefficients de régression)
- paramètres additionels (pénalités, etc)
- données (avec le mot clé
data
)
def logistic_loss(y, x, theta):
= jnp.dot(x, theta)
odds return -jnp.vdot(y, odds) + jnp.sum(jnp.log(1.0 + jnp.exp(odds)))
def objective_and_grad(params, penalty, data):
= data[:, :-1]
x = data[:, -1]
y def loss(params):
return logistic_loss(y, x, params)
= loss(params)
loss_value = grad(loss)(params)
loss_grad return [loss_value, loss_grad]
Itérateur de données
Adam est un algorithme d’optimisation stochastique. On définit donc un itérateur qui va échantillonner les données.
= 100
batch_size = 1000
n_iter # key, subkey = random.split(key, 2)
def data_iterator(key, data):
return random.choice(key, data, (batch_size, ), replace = False)
Optimisation
On définit enfin une fonction de fit qui travaille sur des batchs.
def fit(data, params, optimizer, key):
= optimizer.init(params)
opt_state = []
loss_vector
@jit
def step(params, opt_state, batch):
= objective_and_grad(params, 0, batch)
loss_value, grads = optimizer.update(grads, opt_state, params)
updates, opt_state = optax.apply_updates(params, updates)
params return params, opt_state, loss_value
for i in range(n_iter):
= random.split(key, 2)
key, subkey = data_iterator(subkey, data)
batch = step(params, opt_state, batch)
params, opt_state, loss_value
loss_vector.append(loss_value.item())if i % 100 == 0:
print(f'step {i}, loss: {loss_value}')
return [params, loss_vector]
Finalement, on peut ajuster notre fonction paramétrée en utilisant l’optimiseur Adam fourni par optax.
%%time
= jnp.zeros((x.shape[1], ))
initial_params = optax.adam(learning_rate=1e-2)
optimizer = jnp.concatenate([x, y], axis = 1)
data = fit(data, initial_params, optimizer, key) params,loss_vector
step 0, loss: 69.3147201538086
step 100, loss: 18.580368041992188
step 200, loss: 12.382568359375
step 300, loss: 11.707901000976562
step 400, loss: 8.840866088867188
step 500, loss: 10.6444091796875
step 600, loss: 8.130218505859375
step 700, loss: 8.6331787109375
step 800, loss: 8.014801025390625
step 900, loss: 5.246734619140625
CPU times: user 9.02 s, sys: 200 ms, total: 9.22 s
Wall time: 8.56 s
On peut vérifier que la fonction objective converge sans décroître systématiquement,
range(1, n_iter+1), loss_vector)
plt.plot('Iterations')
plt.xlabel('Loss')
plt.ylabel('Loss vs. Iterations')
plt.title( plt.show()
et que les paramètres sont proches des bonnes valeurs.
='o', linestyle="none")
plt.plot(beta_true, params, marker-3, 3], [-3, 3], color='r', linestyle='-', linewidth=2)
plt.plot(['True parameter')
plt.xlabel('Estimated parameter')
plt.ylabel('Estimated versus true parameter')
plt.title( plt.show()