# JAX Guide
Clip gradients explicitly with the backend's `apply_grads`, or use the clip context to wrap your Optax update.
## Clip context (Optax update wrapped)
```python
import jax
import optax

import smartclip as sc

tx = optax.adam(3e-4)

with sc.clip_context(model, optimizer=tx):  # AutoClip auto mode by default
    grads = jax.grad(loss_fn)(params, batch)
    updates, opt_state = tx.update(grads, opt_state, params)  # clipped automatically
    params = optax.apply_updates(params, updates)
```
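For comparison, if a fixed threshold is all you need, stock Optax can clip without smartclip. A minimal sketch using Optax's built-in `clip_by_global_norm` (the `1.0` threshold is an arbitrary example value); smartclip's clippers adapt the threshold over training instead:

```python
# Static global-norm clipping with plain Optax, no smartclip involved.
tx = optax.chain(
    optax.clip_by_global_norm(1.0),  # fixed threshold, applied before Adam
    optax.adam(3e-4),
)
```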
## Training step with Optax (explicit `apply_grads`)
### AutoClip (auto mode)
```python
import jax
import jax.numpy as jnp
import optax
from flax.training.train_state import TrainState

import smartclip as sc
from smartclip.backends import jax as sc_jax

clipper = sc.AutoClip()

def loss_fn(params, batch):
    # `model` and `cross_entropy` come from your own setup.
    logits = model.apply(params, batch["x"], train=True)
    return jnp.mean(cross_entropy(logits, batch["y"]))

@jax.jit
def train_step(state: TrainState, batch):
    grads = jax.grad(loss_fn)(state.params, batch)

    def on_metrics(rec: dict) -> None:
        pass  # log to W&B/TensorBoard if desired

    clipped = sc_jax.apply_grads(grads, state.params, clipper, on_metrics=on_metrics)
    updates, new_opt_state = state.tx.update(clipped, state.opt_state, state.params)
    new_params = optax.apply_updates(state.params, updates)
    return state.replace(step=state.step + 1, params=new_params, opt_state=new_opt_state)
```
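To run the step, build a `TrainState` and iterate over your data. A minimal usage sketch; `model`, `init_params`, and `batches` are placeholders for your own Flax module, initialized parameters, and data iterator:

```python
tx = optax.adam(3e-4)
state = TrainState.create(apply_fn=model.apply, params=init_params, tx=tx)

for batch in batches:
    state = train_step(state, batch)
```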
### AGC
```python
clipper = sc.AGC(clipping=0.01)

grads = jax.grad(loss_fn)(params, batch)
clipped = sc_jax.apply_grads(grads, params, clipper, on_metrics=lambda rec: None)
updates, opt_state = tx.update(clipped, opt_state, params)
params = optax.apply_updates(params, updates)
```
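For intuition, here is the per-leaf rule AGC applies (Brock et al., 2021), written out in plain JAX. This is an illustrative sketch, not smartclip's actual implementation:

```python
import jax
import jax.numpy as jnp

def agc_leaf(g, w, clipping=0.01, eps=1e-3):
    # Shrink the gradient when its norm exceeds `clipping` times the
    # corresponding weight norm; `eps` guards against tiny weight norms.
    w_norm = jnp.maximum(jnp.linalg.norm(w), eps)
    g_norm = jnp.linalg.norm(g)
    scale = jnp.minimum(1.0, clipping * w_norm / (g_norm + 1e-6))
    return g * scale

clipped = jax.tree_util.tree_map(agc_leaf, grads, params)
```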
### Z-Score
```python
clipper = sc.ZScoreClip(zmax=3.0)

grads = jax.grad(loss_fn)(params, batch)
clipped = sc_jax.apply_grads(grads, params, clipper, on_metrics=lambda rec: None)
updates, opt_state = tx.update(clipped, opt_state, params)
params = optax.apply_updates(params, updates)
```
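The statistic behind z-score clipping, sketched in plain JAX: track recent global gradient norms and rescale when the current norm sits more than `zmax` standard deviations above the running mean. An illustrative sketch of the general technique, not smartclip's internals:

```python
import collections

import jax
import jax.numpy as jnp
import optax

norm_history = collections.deque(maxlen=1000)  # recent global gradient norms

def zscore_clip(grads, zmax=3.0, eps=1e-6):
    # Mutates host-side state, so call it outside `jax.jit`.
    norm = float(optax.global_norm(grads))
    norm_history.append(norm)
    hist = jnp.asarray(norm_history)
    threshold = jnp.mean(hist) + zmax * jnp.std(hist)
    scale = jnp.minimum(1.0, threshold / (norm + eps))
    return jax.tree_util.tree_map(lambda g: g * scale, grads)
```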
Notes:
- Global and per-leaf clipping are supported; AGC uses weight norms per leaf (see the sketch after this list).
- The clip context wraps the optimizer's `update` for convenience; alternatively, use `apply_grads` explicitly.
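To make the global/per-leaf distinction concrete, here is the difference in plain JAX with a fixed, arbitrary threshold (smartclip's clippers compute their thresholds adaptively):

```python
max_norm = 1.0  # illustrative fixed threshold

# Global: one scale factor computed from the norm of all leaves combined.
g_norm = optax.global_norm(grads)
globally_clipped = jax.tree_util.tree_map(
    lambda g: g * jnp.minimum(1.0, max_norm / (g_norm + 1e-6)), grads
)

# Per-leaf: each parameter array is scaled by its own norm.
def clip_leaf(g):
    n = jnp.linalg.norm(g)
    return g * jnp.minimum(1.0, max_norm / (n + 1e-6))

per_leaf_clipped = jax.tree_util.tree_map(clip_leaf, grads)
```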