jaxpolylog quickstart#
This notebook is a condensed tour of
jaxpolylog, the JAX-native
polylogarithm package used internally by jaxvacua for
worldsheet-instanton sums.
It covers:
- evaluating \(\mathrm{Li}_s(z)\) with each of the four approx strategies and benchmarking against mpmath,
- the auto-patched "patch" strategy at the analytical crossover,
- arbitrary-order automatic differentiation, verified against the closed-form identity \(\tfrac{\mathrm{d}}{\mathrm{d}z}\mathrm{Li}_s(z) = \mathrm{Li}_{s-1}(z)/z\).
For the full theory and stability story, see the jaxpolylog introduction.
Setup#
import warnings
warnings.filterwarnings("ignore")
import numpy as np
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
import mpmath
import jaxpolylog
from jaxpolylog import jax_polylog, jax_polylog_vmap
1. Direct evaluation against mpmath#
We parametrise points by \(z = \exp(2\pi i\, t)\): small \(|t|\) is the
regime where the "zero" Laurent expansion is fastest, large \(|t|\)
the regime where the "inf" series is fastest, and "patch"
switches between them.
def li_table(s, t, p_range=100):
    # Map the parameter t to the polylogarithm argument z = exp(2*pi*i*t).
    z = np.exp(2 * np.pi * 1j * t)
    return {
        "z": z,
        "Li_inf": complex(jax_polylog(z, s, p_range, "inf")),
        "Li_zero": complex(jax_polylog(z, s, p_range, "zero")),
        "Li_patch": complex(jax_polylog(z, s, p_range, "patch")),
        "Li_mpmath": complex(mpmath.polylog(s, z)),
    }

for t in (2j, 1j, 1e-5j):
    print(f"t = {t!r}")
    for k, v in li_table(s=3, t=t).items():
        print(f" {k:<10s} = {v}")
    print()
Reading the output.
- At \(t = 2\mathrm{i}\) (\(|z| = e^{-4\pi}\), well inside the unit disk) every strategy agrees with mpmath to all displayed digits.
- At \(t = \mathrm{i}\) (\(|t| = 1\), so \(|z| = e^{-2\pi} \approx 1.9 \times 10^{-3}\)) the "inf" series is at its precision edge; "zero" and "patch" remain accurate.
- At \(t = 10^{-5}\mathrm{i}\) (extremely close to \(z = 1\)) the "inf" series stalls completely; "zero" and "patch" reproduce the mpmath value.
This is exactly what the "patch" dispatch handles automatically.
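To see why a fixed-length series stalls near \(z = 1\), here is a minimal sketch in plain Python using the defining power series \(\mathrm{Li}_s(z) = \sum_{k \ge 1} z^k / k^s\). It is an illustration only: whether the package's "inf" strategy is exactly this series is an assumption, and `polylog_power_series` is a hypothetical helper, not part of jaxpolylog.

```python
import cmath
import math


def polylog_power_series(z, s, n_terms):
    """Partial sum of sum_{k>=1} z**k / k**s (hypothetical helper, not jaxpolylog API)."""
    total = 0j
    power = 1 + 0j
    for k in range(1, n_terms + 1):
        power *= z  # power now equals z**k
        total += power / k**s
    return total


errors = {}
for t in (2j, 1j, 1e-5j):
    z = cmath.exp(2 * math.pi * 1j * t)
    short = polylog_power_series(z, 3, 100)       # truncated sum
    long = polylog_power_series(z, 3, 200_000)    # stand-in for the converged value
    errors[t] = abs(short - long)
    print(f"t = {t!r}: |z| = {abs(z):.6g}, truncation error = {errors[t]:.2e}")
```

For the two points with small \(|z|\) the 100-term sum is already converged in double precision, while near \(z = 1\) the truncation error is larger by many orders of magnitude, which is the behaviour a different expansion is needed for.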
2. The analytical crossover point#
jaxpolylog.polylogs._PVAL_OPTIMAL is the unique positive solution of
\(e^{-2\pi t} = t\), where the truncation errors of the "inf" and
"zero" series are equal. It is pre-computed once at import time
and used as the default crossover.
t_star = jaxpolylog.polylogs._PVAL_OPTIMAL
print(f"t_star = {t_star:.6f}")
print(f"check e^(-2*pi*t_star) = {np.exp(-2 * np.pi * t_star):.6f}")
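The defining equation \(e^{-2\pi t} = t\) can also be checked independently of the package. The sketch below solves it with a plain Newton iteration; `solve_crossover` is a hypothetical helper written for this illustration, and nothing here is meant to describe how jaxpolylog computes the constant internally.

```python
import math


def solve_crossover(t0=0.2, tol=1e-15, max_iter=50):
    """Newton iteration for g(t) = exp(-2*pi*t) - t = 0 (illustrative helper)."""
    t = t0
    for _ in range(max_iter):
        g = math.exp(-2 * math.pi * t) - t
        dg = -2 * math.pi * math.exp(-2 * math.pi * t) - 1  # g'(t), always negative
        step = g / dg
        t -= step
        if abs(step) < tol:
            break
    return t


t_check = solve_crossover()
residual = math.exp(-2 * math.pi * t_check) - t_check
print(f"t_check  = {t_check:.12f}")
print(f"residual = {residual:.2e}")
```

Multiplying the equation by \(2\pi e^{2\pi t}\) gives \((2\pi t)\,e^{2\pi t} = 2\pi\), so the same root can be written as \(t = W_0(2\pi)/(2\pi)\) with \(W_0\) the principal branch of the Lambert W function.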
3. Vectorisation#
zs = jnp.linspace(0.1, 0.95, 32) + 0.0j
vals = jax_polylog_vmap(zs, s=2, p_range=200, approx="patch")
print(f"input shape: {zs.shape}")
print(f"output shape: {vals.shape}")
print(f"first 4 values: {vals[:4]}")
4. Automatic differentiation#
jax_polylog carries a custom JVP rule that implements the analytic
identity \(\tfrac{\mathrm{d}}{\mathrm{d}z}\mathrm{Li}_s(z) =
\mathrm{Li}_{s-1}(z)/z\) via a numerically stable helper that never
divides by \(z\) in the series regime. This is what keeps high-order
derivatives well-defined even for \(|z| \le 10^{-30}\).
def f(z_real):
    z = z_real + 0.0j
    return jax_polylog(z, 3, 200, "patch").real
z_val = 0.5
df_dz = float(jax.grad(f)(z_val))
ana = float((jax_polylog(z_val + 0.0j, 2, 200, "patch") / z_val).real)
print(f"jax.grad(Li_3)(z) = {df_dz:+.12e}")
print(f"Li_2(z) / z = {ana:+.12e}")
print(f"|deviation| = {abs(df_dz - ana):.2e}")
# Second derivative
d2 = float(jax.grad(jax.grad(f))(z_val))
print(f"jax.grad(jax.grad(Li_3))(z) = {d2:+.12e}")
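The idea of differentiating without dividing by \(z\) can be illustrated on the plain power series: differentiating \(z^k/k^s\) term by term gives \(z^{k-1}/k^{s-1}\), which is finite at \(z = 0\). The sketch below is plain Python and does not reproduce the package's custom JVP; `li_series` and `dli_series` are hypothetical helpers, and the use of the truncated power series is an assumption made for the illustration.

```python
import math


def li_series(z, s, n_terms=200):
    """Truncated power series sum_{k=1}^{n} z**k / k**s (illustrative helper)."""
    return sum(z**k / k**s for k in range(1, n_terms + 1))


def dli_series(z, s, n_terms=200):
    """Term-by-term derivative sum_{k=1}^{n} z**(k-1) / k**(s-1); no division by z."""
    return sum(z ** (k - 1) / k ** (s - 1) for k in range(1, n_terms + 1))


z0 = 0.5
termwise = dli_series(z0, 3)          # derivative of Li_3 from the series
identity = li_series(z0, 2) / z0      # Li_2(z) / z, the closed-form identity
h = 1e-6
finite_diff = (li_series(z0 + h, 3) - li_series(z0 - h, 3)) / (2 * h)
at_origin = dli_series(0.0, 3)        # Li_2(z) / z would raise ZeroDivisionError here

print(f"term-by-term     = {termwise:+.12e}")
print(f"Li_2(z) / z      = {identity:+.12e}")
print(f"central diff     = {finite_diff:+.12e}")
print(f"derivative at 0  = {at_origin}")
```

The three values at \(z = 0.5\) agree to within the accuracy of the finite difference, and the term-by-term form still returns a finite value at \(z = 0\), where the quotient form cannot be evaluated directly.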
What next#
- The jaxpolylog introduction covers the mathematics, the four strategies, exact-Bernoulli arithmetic, branch safety, and the deep-LCS stress test.
- The jaxpolylog API reference documents jax_polylog, jax_polylog_vmap, and the internal helpers.
- See the jaxvacua quickstart for the worldsheet-instanton sums where this primitive is actually used.