jaxpolylog quickstart

This notebook is a condensed tour of jaxpolylog, the JAX-native polylogarithm package used internally by jaxvacua for worldsheet-instanton sums.

It covers:

  1. evaluating \(\mathrm{Li}_s(z)\) with the "inf", "zero", and "patch" approx strategies and comparing the results against mpmath,

  2. the analytical crossover point at which the "patch" strategy switches between series,

  3. vectorised evaluation with jax_polylog_vmap,

  4. arbitrary-order automatic differentiation, verified against the closed-form identity \(\tfrac{\mathrm{d}}{\mathrm{d}z}\mathrm{Li}_s(z) = \mathrm{Li}_{s-1}(z)/z\).

For the full theory and stability story, see the jaxpolylog introduction.

Setup

import warnings

warnings.filterwarnings("ignore")

import numpy as np
import jax
import jax.numpy as jnp

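# Enable 64-bit precision (float64/complex128) before importing jaxpolylog,
# so that anything it computes at import time is already in double precision.
jax.config.update("jax_enable_x64", True)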

import mpmath
import jaxpolylog
from jaxpolylog import jax_polylog, jax_polylog_vmap

1. Direct evaluation against mpmath

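We parametrise points by \(z = \exp(2\pi i\, t)\), so that \(|z| = e^{-2\pi\,\mathrm{Im}\,t}\). Small \(|t|\) (\(z\) close to 1) is the regime where the "zero" Laurent expansion is fastest, large \(\mathrm{Im}\,t\) (\(z\) close to 0) the regime where the "inf" series is fastest, and "patch" switches between them. The examples below use purely imaginary \(t\), for which \(z\) is real and lies in \((0, 1)\).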

def li_table(s, t, p_range=100):
    """Li_s(z) at z = exp(2*pi*i*t) for each strategy, plus the mpmath reference."""
    z = np.exp(2 * np.pi * 1j * t)
    return {
        "z": z,
        "Li_inf": complex(jax_polylog(z, s, p_range, "inf")),
        "Li_zero": complex(jax_polylog(z, s, p_range, "zero")),
        "Li_patch": complex(jax_polylog(z, s, p_range, "patch")),
        "Li_mpmath": complex(mpmath.polylog(s, z)),
    }


for t in (2j, 1j, 1e-5j):
    print(f"t = {t!r}")
    for k, v in li_table(s=3, t=t).items():
        print(f"  {k:<10s} = {v}")
    print()

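Reading the output. All three test points have purely imaginary \(t\), so every \(z\) is real and inside the unit disk; what changes is how close \(z\) is to 1. Any expansion about \(t = 0\), such as the "zero" series, converges only for \(|t| < 1\), because \(\mathrm{Li}_s(e^{2\pi i t})\) is singular again at \(t = \pm 1\) (where \(z = 1\)).

  • At \(t = 2\mathrm{i}\) (\(z = e^{-4\pi} \approx 3.5\times 10^{-6}\), deep inside the disk) the "inf" series converges almost immediately, and "inf" and "patch" agree with mpmath to all displayed digits. Here \(|t| = 2\) lies outside the convergence disk of the "zero" expansion, so its value should not be relied on.

  • At \(t = \mathrm{i}\) (\(z = e^{-2\pi} \approx 1.9\times 10^{-3}\), still well inside the disk rather than on the unit circle) "inf" and "patch" remain accurate, while the "zero" expansion sits exactly on the boundary \(|t| = 1\) of its convergence disk.

  • At \(t = 10^{-5}\mathrm{i}\) (\(z \approx 1 - 6.3\times 10^{-5}\), extremely close to \(z = 1\)) the "inf" series converges too slowly for its truncated sum to reach the mpmath value; "zero" and "patch" reproduce it.

The "patch" strategy handles this automatically by picking, at each point, the series that converges faster.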
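To see how the defining power series behaves at these three points, here is a minimal pure-Python sketch of the truncated sum \(\sum_{k=1}^{N} z^k/k^s\). It assumes this series is what the "inf" strategy truncates (the strategy name suggests so, but jaxpolylog's internals may differ), and the helper li_power_series is hypothetical. For each test point it compares \(N = 100\) terms against \(N = 100\,000\) terms.

```python
import cmath
import math


def li_power_series(z: complex, s: int, n_terms: int) -> complex:
    """Truncated defining series sum_{k=1}^{n_terms} z**k / k**s.

    Hypothetical stand-in for the "inf" strategy, not jaxpolylog's code.
    """
    total = 0j
    zk = 1 + 0j
    for k in range(1, n_terms + 1):
        zk *= z                # zk == z**k; |z| is the per-term geometric ratio
        total += zk / k**s
    return total


diffs = {}
for t in (2j, 1j, 1e-5j):
    z = cmath.exp(2 * math.pi * 1j * t)
    # Difference between a 100-term and a 100000-term truncation.
    diffs[t] = abs(li_power_series(z, 3, 100) - li_power_series(z, 3, 100_000))
    print(f"t = {t!r}:  |z| = {abs(z):.3e},  |S_100 - S_100000| = {diffs[t]:.1e}")
```

Near \(z = 1\) the ratio \(|z|\) is almost 1, so after 100 terms the remaining tail for \(s = 3\) is roughly \(\sum_{k>100} k^{-3} \approx 5\times 10^{-5}\); at the other two points the terms drop below double precision long before \(k = 100\).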

2. The analytical crossover point

jaxpolylog.polylogs._PVAL_OPTIMAL is the unique positive real solution of \(e^{-2\pi t} = t\) (here \(t\) plays the role of \(\mathrm{Im}\,t\) from Section 1); at this point the truncation errors of the "inf" and "zero" series are equal. It is computed once at import time and used as the default crossover for "patch".
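A short estimate shows where this equation comes from. It assumes, as the strategy names suggest, that "inf" truncates the defining series \(\sum_k z^k/k^s\) and "zero" truncates the expansion in \(\mu = \log z = 2\pi i\,t\) about \(z = 1\), whose coefficients are \(\zeta(s-k)/k!\). For \(t = \mathrm{i}\tau\) with \(\tau > 0\), the term magnitudes decay roughly geometrically, up to constants and powers of \(k\):

\[
\left|\frac{z^{k}}{k^{s}}\right| \sim e^{-2\pi\tau k},
\qquad
\left|\zeta(s-k)\,\frac{\mu^{k}}{k!}\right| \sim \left(\frac{|\mu|}{2\pi}\right)^{k} = \tau^{k}.
\]

The second estimate follows from the functional equation of \(\zeta\), which gives \(|\zeta(s-k)| \sim 2\,\Gamma(k+1-s)/(2\pi)^{k+1-s}\). After \(N\) terms the two truncation errors therefore balance when \(e^{-2\pi\tau N} = \tau^{N}\), that is, when \(e^{-2\pi\tau} = \tau\).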

t_star = float(jaxpolylog.polylogs._PVAL_OPTIMAL)
residual = np.exp(-2 * np.pi * t_star) - t_star
print(f"t_star                 = {t_star:.6f}")
print(f"e^(-2*pi*t_star)       = {np.exp(-2 * np.pi * t_star):.6f}")
print(f"|e^(-2*pi*t) - t|      = {abs(residual):.1e}")
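The crossover can be reproduced independently with a few Newton iterations on \(g(t) = e^{-2\pi t} - t\). Since \(g\) is strictly decreasing with \(g(0) = 1 > 0\), the positive root is unique. The sketch also includes a hypothetical choose_strategy rule, our illustration of the idea behind "patch" rather than jaxpolylog's dispatch code, that picks "inf" when \(\mathrm{Im}\,t\) exceeds the crossover and "zero" otherwise.

```python
import math


def crossover(tol: float = 1e-15, max_iter: int = 50) -> float:
    """Positive root of g(t) = exp(-2*pi*t) - t via Newton's method."""
    t = 0.25  # starting guess close to the root
    for _ in range(max_iter):
        e = math.exp(-2 * math.pi * t)
        g = e - t
        dg = -2 * math.pi * e - 1      # g'(t) < 0 everywhere
        step = g / dg
        t -= step
        if abs(step) < tol:
            break
    return t


T_STAR = crossover()


def choose_strategy(t: complex, t_star: float = T_STAR) -> str:
    """Hypothetical dispatch: "inf" above the crossover, "zero" below it.

    Illustrates the idea behind "patch"; jaxpolylog's actual rule may
    differ in details.
    """
    return "inf" if t.imag > t_star else "zero"


print(f"t_star = {T_STAR:.6f}")
for t in (2j, 1j, 1e-5j):
    print(f"t = {t!r}: patch would use {choose_strategy(t)!r}")
```

The root can also be written in closed form as \(W(2\pi)/(2\pi)\), where \(W\) is the Lambert W function, since \(e^{-2\pi t} = t\) is equivalent to \(2\pi t\, e^{2\pi t} = 2\pi\).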

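3. Vectorisation

jax_polylog_vmap is the batched counterpart of jax_polylog: it accepts an array of \(z\) values and returns one value of \(\mathrm{Li}_s\) per entry.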

# 32 real points between 0.1 and 0.95, promoted to complex128 (x64 is enabled above).
zs = jnp.linspace(0.1, 0.95, 32) + 0.0j
vals = jax_polylog_vmap(zs, s=2, p_range=200, approx="patch")
print(f"input  shape: {zs.shape}")
print(f"output shape: {vals.shape}")
print(f"first 4 values: {vals[:4]}")
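What batched evaluation means can be illustrated without JAX. The NumPy sketch below evaluates the truncated defining series for a whole array of \(z\) at once by broadcasting a table of terms and reducing along the term axis. The helper li_power_series_batched is hypothetical, is only accurate for \(|z|\) comfortably below 1, and is not how jax_polylog_vmap is implemented.

```python
import numpy as np


def li_power_series_batched(zs, s: int, n_terms: int) -> np.ndarray:
    """Truncated series sum_{k=1}^{n_terms} z**k / k**s for every z in zs."""
    zs = np.asarray(zs, dtype=np.complex128)
    k = np.arange(1, n_terms + 1, dtype=np.float64)
    terms = zs[..., None] ** k / k**s   # shape: zs.shape + (n_terms,)
    return terms.sum(axis=-1)          # shape: zs.shape


zs = np.linspace(0.1, 0.95, 32) + 0.0j
vals = li_power_series_batched(zs, s=2, n_terms=200)
print(f"input  shape: {zs.shape}")
print(f"output shape: {vals.shape}")
```

Broadcasting trades memory (a len(zs) by n_terms table) for the absence of a Python loop; a function transformed with jax.vmap gives the same one-output-per-input contract without writing the batched form by hand.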

4. Automatic differentiation

jax_polylog carries a custom JVP rule that implements the analytic identity \(\tfrac{\mathrm{d}}{\mathrm{d}z}\mathrm{Li}_s(z) = \mathrm{Li}_{s-1}(z)/z\) through a numerically stable helper that never divides by \(z\) in the series regime. This is what keeps high-order derivatives well defined for \(|z|\) as small as \(10^{-30}\).
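The identity follows from differentiating \(\mathrm{Li}_s(z) = \sum_{k\ge1} z^k/k^s\) term by term: \(\tfrac{\mathrm{d}}{\mathrm{d}z}\mathrm{Li}_s(z) = \sum_{k\ge1} z^{k-1}/k^{s-1}\), which equals \(\mathrm{Li}_{s-1}(z)/z\) but contains no division by \(z\). The pure-Python sketch below shows one way to get that property on the truncated series; li_series and dli_dz are hypothetical helpers, not jaxpolylog's. At \(z = 0\) the term-wise form returns the exact derivative 1, where the literal quotient is \(0/0\).

```python
def li_series(z: complex, s: int, n_terms: int = 200) -> complex:
    """Truncated defining series sum_{k=1}^{n_terms} z**k / k**s."""
    total, zk = 0j, 1 + 0j
    for k in range(1, n_terms + 1):
        zk *= z                      # zk == z**k
        total += zk / k**s
    return total


def dli_dz(z: complex, s: int, n_terms: int = 200) -> complex:
    """d/dz Li_s(z) = sum_k z**(k-1) / k**(s-1), with no division by z."""
    total, zk = 0j, 1 + 0j           # zk == z**(k-1), starting at z**0
    for k in range(1, n_terms + 1):
        total += zk / k ** (s - 1)
        zk *= z
    return total


# Away from z = 0 the two forms agree up to rounding.
print(abs(dli_dz(0.5, 3) - li_series(0.5, 2) / 0.5))

# At z = 0 the term-wise form gives the exact derivative of Li_3 ...
print(dli_dz(0.0, 3))

# ... while the literal quotient Li_2(0) / 0 cannot be evaluated.
try:
    li_series(0.0, 2) / 0.0
except ZeroDivisionError:
    print("Li_2(0) / 0 raises ZeroDivisionError")
```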

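def f(z_real):
    # Li_3 restricted to real z; jax.grad needs a real scalar output.
    z = z_real + 0.0j
    return jax_polylog(z, 3, 200, "patch").real


z_val = 0.5

# First derivative: compare autodiff with Li_2(z) / z.
df_dz = float(jax.grad(f)(z_val))
analytic = float((jax_polylog(z_val + 0.0j, 2, 200, "patch") / z_val).real)
print(f"jax.grad(Li_3)(z) = {df_dz:+.12e}")
print(f"Li_2(z) / z       = {analytic:+.12e}")
print(f"|deviation|       = {abs(df_dz - analytic):.2e}")

# Second derivative: applying the identity twice gives (Li_1(z) - Li_2(z)) / z^2,
# with Li_1(z) = -log(1 - z).
d2 = float(jax.grad(jax.grad(f))(z_val))
li1 = -np.log(1 - z_val)
li2 = float(jax_polylog(z_val + 0.0j, 2, 200, "patch").real)
analytic2 = (li1 - li2) / z_val**2
print(f"jax.grad(jax.grad(Li_3))(z) = {d2:+.12e}")
print(f"(Li_1(z) - Li_2(z)) / z^2   = {analytic2:+.12e}")
print(f"|deviation|                 = {abs(d2 - analytic2):.2e}")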
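Avoiding explicit division by \(z\) matters most at higher order. The second derivative of \(\mathrm{Li}_3\) equals \((\mathrm{Li}_1(z) - \mathrm{Li}_2(z))/z^2\), but at \(z = 10^{-30}\) both polylogarithms round to the same double, so that formula returns 0 instead of the true value, which is \(2/2^3 = 1/4\) at \(z = 0\) up to corrections of order \(z\). The pure-Python sketch below instead differentiates the truncated series term by term, using \(\tfrac{\mathrm{d}^n}{\mathrm{d}z^n} z^k = \tfrac{k!}{(k-n)!} z^{k-n}\). The helpers li_series and d_li_series are hypothetical and illustrate the principle only; they are not jaxpolylog's helper.

```python
import math


def li_series(z: float, s: int, n_terms: int = 200) -> float:
    """Truncated defining series sum_{k=1}^{n_terms} z**k / k**s."""
    return sum(z**k / k**s for k in range(1, n_terms + 1))


def d_li_series(z: float, s: int, order: int, n_terms: int = 200) -> float:
    """order-th derivative of the truncated series, term by term.

    Uses k!/(k-order)! * z**(k-order) / k**s, so no negative powers of z
    (and no division by z) ever appear.
    """
    return sum(
        math.perm(k, order) * z ** (k - order) / k**s
        for k in range(max(order, 1), n_terms + 1)
    )


z = 1e-30
naive = (li_series(z, 1) - li_series(z, 2)) / z**2   # catastrophic cancellation
stable = d_li_series(z, 3, order=2)                  # term-wise derivative
print(f"(Li_1 - Li_2) / z^2 = {naive}")
print(f"term-wise Li_3''    = {stable}")
```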

What next

  • The jaxpolylog introduction covers the mathematics, the four strategies, exact-Bernoulli arithmetic, branch safety, and the deep-LCS stress test.

  • The jaxpolylog API reference documents jax_polylog, jax_polylog_vmap, and the internal helpers.

  • See the jaxvacua quickstart for the worldsheet-instanton sums where this primitive is actually used.