
API Reference

smartclip: Adaptive gradient clipping algorithms for deep learning frameworks.

This package provides a framework-agnostic core with optional thin integrations for PyTorch, TensorFlow/Keras, and JAX/Flax. Public APIs are typed and designed for fast import times and production use.

AutoClip

Bases: ClipperBase

Adaptive clipping of gradients.

Modes: - "auto" (default): hyperparameter-free threshold using P² median (p=0.5) and Welford variance: T = median + 3 * std. - "percentile": target percentile of recent gradient norms using either EMA quantile estimator (history="ema") or rolling window (history="window").

observe(value, key=None)

Observe a gradient norm for a grouping key.

Backends should call this once per measured norm (global/layer/param) before applying clipping. Values that are non-finite are ignored when guard_nans is True.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| value | float | Gradient norm (L2 norm of the gradients for this group). | required |
| key | Optional[Key] | Grouping key tuple: ("global",) for global scope, ("layer", "conv1") for per-layer scope, ("param", "0") for per-parameter scope. Defaults to ("global",) if None. | None |
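For illustration, the grouping keys from the table above used directly (the norm values are made up):

```python
from smartclip import AutoClip

clipper = AutoClip()
clipper.observe(12.3)                         # same as key=("global",)
clipper.observe(4.7, key=("layer", "conv1"))  # per-layer scope
clipper.observe(0.9, key=("param", "0"))      # per-parameter scope
```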

threshold(key=None)

Return current threshold for a key (default: global).

This does not enforce warmup/min-history gates. Callers should check can_clip() to decide whether clipping should be applied.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| key | Optional[Key] | Grouping key tuple (e.g., ("global",), ("layer", "conv1")). Defaults to ("global",) if None. | None |

Returns:

| Type | Description |
| --- | --- |
| float | Current percentile threshold, lower-bounded by eps. |

threshold_any()

Convenience method for callers that do not track keys.

Returns the global threshold if available, otherwise the threshold for the single key if exactly one exists, otherwise eps.
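A sketch of how a backend-style caller might combine observe(), can_clip() and threshold(). The scaling rule min(1, T / (g + eps)) mirrors the convention described for ZScoreClip below; the eps value and the no-argument can_clip() call are illustrative, since its exact signature is not shown here.

```python
from smartclip import AutoClip

clipper = AutoClip()

def clip_scale(clipper: AutoClip, grad_norm: float, eps: float = 1e-6) -> float:
    """Return the multiplicative factor a backend would apply to the gradients."""
    clipper.observe(grad_norm)             # key defaults to ("global",)
    if not clipper.can_clip():             # respect warmup / min-history gating
        return 1.0
    t = clipper.threshold()                # current global threshold
    return min(1.0, t / (grad_norm + eps))

# Simulated global gradient norms over a few steps.
for g in [3.1, 2.8, 9.4, 3.0]:
    print(clip_scale(clipper, g))
```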

AGC

Bases: ClipperBase

Adaptive Gradient Clipping (NFNets-style).

Scales gradients based on the ratio between the gradient norm and the parameter (weight) norm per group. Given gradient norm g and weight norm w, the target maximum gradient norm is T = clipping * (w + eps) and the applied scale is:

scale = min(1.0, T / (g + eps))

When exclude_bias_bn=True, a simple framework-agnostic heuristic is used to skip parameters with dimensionality <= 1 (bias vectors and affine scale parameters such as BatchNorm/LayerNorm gammas).
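A worked re-derivation of the rule above in plain Python. This is not the library code, and the clipping=0.01 and eps=1e-3 defaults are illustrative rather than smartclip's actual defaults.

```python
def agc_scale(grad_norm: float, weight_norm: float,
              clipping: float = 0.01, eps: float = 1e-3) -> float:
    """Re-derive the documented AGC rule (illustrative defaults)."""
    target = clipping * (weight_norm + eps)       # T = clipping * (w + eps)
    return min(1.0, target / (grad_norm + eps))   # scale = min(1, T / (g + eps))

# With clipping=0.01, a weight of norm 10 tolerates gradient norms up to ~0.1:
print(agc_scale(grad_norm=0.05, weight_norm=10.0))  # 1.0   -> left untouched
print(agc_scale(grad_norm=0.50, weight_norm=10.0))  # ~0.2  -> scaled down to T
```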

observe(grad_norm, weight_norm, key=None)

Record one AGC observation for warmup/min-history gating.

Backends should call this once per group measurement prior to applying scaling. Non-finite values are ignored when guard_nans is True.

scale(grad_norm, weight_norm)

Compute scale factor in [0, 1] for given gradient and weight norms.

scale = min(1, target_norm(weight_norm) / (grad_norm + eps))

Non-finite inputs return 1.0 (no scaling) when guard_nans is True.

should_exclude_param(param)

Return True if a parameter should be excluded from clipping.

Heuristic: when exclude_bias_bn is enabled, exclude parameters whose data has dimensionality <= 1 (bias vectors and affine scales). If shape cannot be determined, do not exclude.
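A plain-Python sketch of that heuristic. Whether the library inspects .shape, .ndim, or something else is an implementation detail not documented here; this only mirrors the described behaviour.

```python
import numpy as np

def should_exclude(param_data) -> bool:
    """Mirror the described heuristic: exclude parameters with ndim <= 1."""
    shape = getattr(param_data, "shape", None)
    if shape is None:
        return False          # shape unknown: do not exclude
    return len(shape) <= 1    # bias vectors and affine scales (BN/LN gammas)

print(should_exclude(np.zeros(64)))         # True  -> bias-like vector
print(should_exclude(np.zeros((64, 128))))  # False -> weight matrix
```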

target_norm(weight_norm)

Return the allowed gradient norm for a given weight norm.

Computes clipping * (weight_norm + eps) and lower-bounds the result by eps.

ZScoreClip

Bases: ClipperBase

Z-score based adaptive clipping using EMA mean/variance.

Tracks exponentially-weighted moving averages of the observed gradient norm (m) and squared norm (m2) per grouping key. The standard deviation is computed as sqrt(max(0, m2 - m^2)). For a new observation with norm g, the Z-score is z = (g - m) / (std + eps) and clipping is recommended when z > zmax. Backends typically implement clipping by scaling gradients by min(1, T / (g + eps)) where the threshold T = m + zmax * std.
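The arithmetic above, written out for a single observation. The EMA decay beta, zmax, and eps values are illustrative rather than the library defaults, and since the docstring does not say whether the running estimates are updated before or after the z-score is taken, this version scores the new norm against the pre-update statistics.

```python
import math

def ema_zscore_step(m: float, m2: float, g: float,
                    beta: float = 0.99, zmax: float = 3.0, eps: float = 1e-6):
    """One ZScoreClip-style update, written out in plain Python."""
    std = math.sqrt(max(0.0, m2 - m * m))    # std from the running estimates
    z = (g - m) / (std + eps)                # z-score of the new observation
    threshold = m + zmax * std               # T = m + zmax * std
    scale = min(1.0, threshold / (g + eps))  # how a backend would scale the grads
    m = beta * m + (1.0 - beta) * g          # then fold g into the EMAs
    m2 = beta * m2 + (1.0 - beta) * g * g
    return m, m2, z, threshold, scale
```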

observe(value, key=None)

Observe a gradient norm for a grouping key.

Backends should call this once per measured norm (global/layer/param) before applying clipping. Values that are non-finite are ignored when guard_nans is True.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| value | float | Gradient norm (L2 norm of the gradients for this group). | required |
| key | Optional[Key] | Grouping key tuple: ("global",) for global scope, ("layer", "conv1") for per-layer scope, ("param", "0") for per-parameter scope. Defaults to ("global",) if None. | None |

stats(key=None)

Return the current (mean, std) estimates for a key.

If the key has not been observed, or estimates are uninitialized, returns (0.0, 0.0).

threshold(key=None)

Return current z-score threshold m + zmax * std for a key.

This does not enforce warmup/min-history gates. Callers should check can_clip() to decide whether clipping should be applied.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| key | Optional[Key] | Grouping key tuple (e.g., ("global",), ("layer", "conv1")). Defaults to ("global",) if None. | None |

Returns:

| Type | Description |
| --- | --- |
| float | Current threshold, lower-bounded by eps. |

threshold_any()

Convenience method for callers that do not track keys.

Returns the global threshold if available, otherwise the threshold for the single key if exactly one exists, otherwise eps.

apply(model, clipper, on_metrics=None)

Apply adaptive clipping to model parameters.

Delegates to the active backend determined from the model instance.

step(model, optimizer, clipper, on_metrics=None)

Clip gradients on the model and then call optimizer.step().

clip_context(model, optimizer=None, clipper=None, on_metrics=None)

Context manager that clips before each optimizer step for the active backend.

Defaults to AutoClip() when clipper is None.
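A minimal PyTorch sketch of these helpers, assuming step, clip_context, and AutoClip are importable from the top-level smartclip package and that calling optimizer.step() inside the context triggers the pre-step clipping (the exact interaction is not spelled out above).

```python
import torch
import smartclip

model = torch.nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
clipper = smartclip.AutoClip()
x, y = torch.randn(8, 16), torch.randn(8, 4)

# Explicit form: clip, then step, via the helper.
torch.nn.functional.mse_loss(model(x), y).backward()
smartclip.step(model, optimizer, clipper)   # clips gradients, then optimizer.step()
optimizer.zero_grad()

# Context-manager form: clipping happens before each optimizer step.
with smartclip.clip_context(model, optimizer, clipper):
    torch.nn.functional.mse_loss(model(x), y).backward()
    optimizer.step()
    optimizer.zero_grad()

# smartclip.apply(model, clipper) would clip in place without stepping.
```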

Core

AGC

Bases: ClipperBase

Adaptive Gradient Clipping (NFNets-style).

Scales gradients based on the ratio between the gradient norm and the parameter (weight) norm per group. Given gradient norm g and weight norm w, the target maximum gradient norm is T = clipping * (w + eps) and the applied scale is:

scale = min(1.0, T / (g + eps))

When exclude_bias_bn=True, a simple framework-agnostic heuristic is used to skip parameters with dimensionality <= 1 (bias vectors and affine scale parameters such as BatchNorm/LayerNorm gammas).

observe(grad_norm, weight_norm, key=None)

Record one AGC observation for warmup/min-history gating.

Backends should call this once per group measurement prior to applying scaling. Non-finite values are ignored when guard_nans is True.

scale(grad_norm, weight_norm)

Compute scale factor in [0, 1] for given gradient and weight norms.

scale = min(1, target_norm(weight_norm) / (grad_norm + eps))

Non-finite inputs return 1.0 (no scaling) when guard_nans is True.

should_exclude_param(param)

Return True if a parameter should be excluded from clipping.

Heuristic: when exclude_bias_bn is enabled, exclude parameters whose data has dimensionality <= 1 (bias vectors and affine scales). If shape cannot be determined, do not exclude.

target_norm(weight_norm)

Return the allowed gradient norm for a given weight norm.

Computes clipping * (weight_norm + eps) and lower-bounds the result by eps.

AutoClip

Bases: ClipperBase

Adaptive clipping of gradients.

Modes: - "auto" (default): hyperparameter-free threshold using P² median (p=0.5) and Welford variance: T = median + 3 * std. - "percentile": target percentile of recent gradient norms using either EMA quantile estimator (history="ema") or rolling window (history="window").

observe(value, key=None)

Observe a gradient norm for a grouping key.

Backends should call this once per measured norm (global/layer/param) before applying clipping. Values that are non-finite are ignored when guard_nans is True.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| value | float | Gradient norm (L2 norm of the gradients for this group). | required |
| key | Optional[Key] | Grouping key tuple: ("global",) for global scope, ("layer", "conv1") for per-layer scope, ("param", "0") for per-parameter scope. Defaults to ("global",) if None. | None |

threshold(key=None)

Return current threshold for a key (default: global).

This does not enforce warmup/min-history gates. Callers should check can_clip() to decide whether clipping should be applied.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| key | Optional[Key] | Grouping key tuple (e.g., ("global",), ("layer", "conv1")). Defaults to ("global",) if None. | None |

Returns:

| Type | Description |
| --- | --- |
| float | Current percentile threshold, lower-bounded by eps. |

threshold_any()

Convenience method for callers that do not track keys.

Returns the global threshold if available, otherwise the threshold for the single key if exactly one exists, otherwise eps.

ClipperBase

Base class for adaptive gradient clippers.

This class manages configuration, numeric stability constants, and minimal state serialization. Subclasses implement algorithm-specific logic.

ParamLike

Bases: Protocol

A minimal protocol representing a trainable parameter.

The parameter stores its data (tensor/array) and an optional gradient.

TensorLike

Bases: Protocol

A minimal protocol representing a tensor/array from any framework.

Intentionally small to avoid importing optional frameworks at type-check time.
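A guess at what these two protocols might declare, based only on the one-line descriptions above; the attribute names are assumptions, not the library's actual definitions.

```python
from typing import Any, Optional, Protocol, Tuple

class TensorLike(Protocol):
    """Kept deliberately small; `shape` is the only attribute assumed here."""
    shape: Tuple[int, ...]

class ParamLike(Protocol):
    """A parameter holding its data and an optional gradient."""
    data: Any             # tensor/array with the parameter values
    grad: Optional[Any]   # gradient, or None if not populated
```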

ZScoreClip

Bases: ClipperBase

Z-score based adaptive clipping using EMA mean/variance.

Tracks exponentially-weighted moving averages of the observed gradient norm (m) and squared norm (m2) per grouping key. The standard deviation is computed as sqrt(max(0, m2 - m^2)). For a new observation with norm g, the Z-score is z = (g - m) / (std + eps) and clipping is recommended when z > zmax. Backends typically implement clipping by scaling gradients by min(1, T / (g + eps)) where the threshold T = m + zmax * std.

observe(value, key=None)

Observe a gradient norm for a grouping key.

Backends should call this once per measured norm (global/layer/param) before applying clipping. Values that are non-finite are ignored when guard_nans is True.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| value | float | Gradient norm (L2 norm of the gradients for this group). | required |
| key | Optional[Key] | Grouping key tuple: ("global",) for global scope, ("layer", "conv1") for per-layer scope, ("param", "0") for per-parameter scope. Defaults to ("global",) if None. | None |

stats(key=None)

Return the current (mean, std) estimates for a key.

If the key has not been observed, or estimates are uninitialized, returns (0.0, 0.0).

threshold(key=None)

Return current z-score threshold m + zmax * std for a key.

This does not enforce warmup/min-history gates. Callers should check can_clip() to decide whether clipping should be applied.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| key | Optional[Key] | Grouping key tuple (e.g., ("global",), ("layer", "conv1")). Defaults to ("global",) if None. | None |

Returns:

| Type | Description |
| --- | --- |
| float | Current threshold, lower-bounded by eps. |

threshold_any()

Convenience method for callers that do not track keys.

Returns the global threshold if available, otherwise the threshold for the single key if exactly one exists, otherwise eps.