Cormack-Jolly-Seber

Estimating survival with NumPyro
Author
Affiliations

Philip T. Patton

Marine Mammal Research Program

Hawaiʻi Institute of Marine Biology

Published

October 25, 2025

In this notebook, I demonstrate how to estimate survival with Cormack-Jolly-Seber models in NumPyro. This notebook is a near carbon copy of the CJS notebook on the NumPyro documentation page. Nevertheless, I hope that this notebook will be useful in its own right, primarily for folks who are more familiar with CJS models and less familiar with NumPyro.

from jax import random
from jax.scipy.special import expit
from numpyro import handlers
from numpyro.contrib.control_flow import scan
from numpyro.infer import NUTS, MCMC
import arviz as az
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import numpyro
import numpyro.distributions as dist
import seaborn as sns

# plotting defaults
plt.style.use('fivethirtyeight')
plt.rcParams['axes.facecolor'] = 'white'
plt.rcParams['figure.facecolor'] = 'white'
plt.rcParams['axes.spines.left'] = False
plt.rcParams['axes.spines.right'] = False
plt.rcParams['axes.spines.top'] = False
plt.rcParams['axes.spines.bottom'] = False
sns.set_palette("tab10")

# make the labels on arviz plots nicer
labeller = az.labels.MapLabeller(
    var_name_map={
        "psi": r"$\psi$", "gamma": r"$\gamma$", "alpha": r"$\alpha$",
        "epsilon": r"$\epsilon$", "p": r"$p$", "beta": r"$\beta$",
        "phi": r"$\phi$", "alpha_t": r"$\alpha_t$",
    }
)

# hyperparameters
RANDOM_SEED = 1792
CHAIN_COUNT = 4
WARMUP_COUNT = 500
SAMPLE_COUNT = 1000

Model definition

This model is very similar to the one defined in the NumPyro Jolly-Seber notebook. In some ways, this model is actually much simpler because we do not have to worry about recruitment into the population. This model does, however, introduce a new character from the NumPyro-verse: handlers.mask().

handlers.mask() tells NumPyro: do not include the sample of z in the log probability computation when the mask is False. In this case, the mask, has_been_captured, indicates whether the animal has been captured by time t. As such, we essentially ignore the z state until the animal is first captured. We include the mask in the carry and update it with new data after we transition the z state (i.e., after the animal has survived or died between intervals).
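To see the handler in isolation, here is a minimal sketch with a toy Bernoulli model and made-up values (nothing to do with the CJS model): the same observation contributes log(0.5) to the log density when the mask is True and nothing when it is False. It reuses the imports from the setup cell above, plus numpyro.infer.util.log_density.

from numpyro.infer.util import log_density

def toy_model(mask):
    # one observed coin flip; the mask decides whether it counts
    with handlers.mask(mask=mask):
        numpyro.sample("y_toy", dist.Bernoulli(probs=0.5), obs=jnp.array(1))

included, _ = log_density(toy_model, (jnp.array(True),), {}, {})
excluded, _ = log_density(toy_model, (jnp.array(False),), {}, {})
print(included, excluded)  # roughly -0.693 (log 0.5) and 0.0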

The model requires one more trick to run. Essentially, we need to ensure that z[t]=1 during the occasion of the animal’s first capture. To do so, we employ the mask: mu_z_t = has_been_captured * phi * z + (1 - has_been_captured). This forces z[t]=1 until the animal’s first capture. After the animal’s first capture, the z state is included in the log probability calculation, and mu_z_t simplifies to phi * z.
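As a quick numerical illustration, here is the same expression evaluated with made-up toy values for phi, z, and the mask (three hypothetical animals at a single occasion); the not-yet-captured animal is forced to have survival probability 1.

phi_toy = 0.7                        # survival probability
z_toy = jnp.array([1, 1, 0])         # latent alive state
captured_toy = jnp.array([0, 1, 1])  # has_been_captured at time t

mu_z_toy = captured_toy * phi_toy * z_toy + (1 - captured_toy)
print(mu_z_toy)  # [1.  0.7 0. ]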

def p_dot_phi_dot(capture_history):

    capture_count, _ = capture_history.shape
    phi = numpyro.sample("phi", dist.Uniform(0.0, 1.0))
    p = numpyro.sample("p", dist.Uniform(0.0, 1.0))

    def transition_fn(carry, y):

        has_been_captured, z = carry

        with numpyro.plate("animals", capture_count, dim=-1):

            # only compute log probs for animals where has_been_captured is True
            with handlers.mask(mask=has_been_captured):

                # force mu_z_t=1 during the occasion of the animal's first capture
                mu_z_t = has_been_captured * phi * z + (1 - has_been_captured)
                z = numpyro.sample(
                    "z",
                    dist.Bernoulli(dist.util.clamp_probs(mu_z_t)),
                    infer={"enumerate": "parallel"},
                )

                mu_y_t = p * z
                numpyro.sample(
                    "y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)), obs=y
                )

        has_been_captured = has_been_captured | y.astype(bool)
        return (has_been_captured, z), None

    z = jnp.ones(capture_count, dtype=jnp.int32)
    has_been_captured = capture_history[:, 0].astype(bool)
    scan(
        transition_fn,
        (has_been_captured, z),
        jnp.swapaxes(capture_history[:, 1:], 0, 1),
    )

# data
dipper = np.loadtxt('dipper.csv', delimiter=',', dtype=np.int32)

rng_key = random.PRNGKey(RANDOM_SEED)

# specify which sampler you want to use
nuts_kernel = NUTS(p_dot_phi_dot) # 11 seconds

# configure the MCMC run
mcmc = MCMC(nuts_kernel, num_warmup=WARMUP_COUNT, num_samples=SAMPLE_COUNT,
            num_chains=CHAIN_COUNT)

# run the MCMC then inspect the output
mcmc.run(rng_key, dipper)
mcmc.print_summary()
/var/folders/7b/nb0vyhy90mdf30_65xwqzl300000gn/T/ipykernel_40148/284679213.py:10: UserWarning: There are not enough devices to run parallel chains: expected 4 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(4)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.
  mcmc = MCMC(nuts_kernel, num_warmup=WARMUP_COUNT, num_samples=SAMPLE_COUNT,

                mean       std    median      5.0%     95.0%     n_eff     r_hat
         p      0.91      0.02      0.92      0.87      0.95   2539.04      1.00
       phi      0.69      0.03      0.69      0.64      0.73   3249.77      1.00

Number of divergences: 0

samples = mcmc.get_samples(group_by_chain=True)
idata = az.from_dict(samples)

az.plot_trace(idata, figsize=(8,4), var_names=['p', 'phi'], labeller=labeller)
plt.subplots_adjust(hspace=0.4)
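Optionally, ArviZ can also summarize the draws numerically; this is just a sketch of a convenience check that should mirror the mcmc.print_summary() output above.

az.summary(idata, var_names=['p', 'phi'], hdi_prob=0.9)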