Open spatial capture-recapture

Estimating survival and density with NumPyro
Author
Affiliations

Philip T. Patton

Marine Mammal Research Program

Hawaiʻi Institute of Marine Biology

Published

January 14, 2026

In this notebook, I introduce how to estimate survival and abundance with open spatial capture-recapture (SCR) models. Open SCR models are a combination of closed population SCR models and Jolly-Seber models. The basic open SCR model assumes that animals enter and leave the population according to a Jolly-Seber process, each animal has an activity center, and the probability of encountering the animal declines with distance from its activity center. In this example, we will assume that each animal’s activity center is constant across sampling occasions, and that activity centers are uniformly distributed across the state space.

In this example, I will use simulated data to ensure that we’re getting reasonable results. That said, the ovenbird dataset from the PyMC SCR notebook would be a great example here!

%config InlineBackend.figure_format = 'retina'

import os

# # this suppresses a warning about the cuda backend
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# settings for cpu
# Both environment variables are assigned before `numpyro` is imported on the
# lines below — presumably so they are in place when the backend initialises.
# NOTE(review): the second token in XLA_FLAGS has no leading `--`, unlike the
# first one; confirm that XLA actually honours it.
os.environ['XLA_FLAGS'] = '--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=4'
# leave JAX's 64-bit mode switched off
os.environ['JAX_ENABLE_X64'] = 'False'
import numpyro
# ask NumPyro for 4 host devices — the same number as CHAIN_COUNT further down
numpyro.set_host_device_count(4)

from itertools import product

from jax import random, debug
from jax.scipy.special import logit
from numpyro.contrib.control_flow import scan
from numpyro.infer import NUTS, MCMC, Predictive
import arviz as az
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import numpyro
import numpyro.distributions as dist
import pandas as pd
import seaborn as sns

# plotting defaults: fivethirtyeight theme on a white canvas with no spines
plt.style.use('fivethirtyeight')
plt.rcParams.update({
    'axes.facecolor': 'white',
    'figure.facecolor': 'white',
    'axes.spines.left': False,
    'axes.spines.right': False,
    'axes.spines.top': False,
    'axes.spines.bottom': False,
})
sns.set_palette("tab10")

# mcmc hyperparameters
WARMUP_COUNT = 500   # warmup iterations per chain
SAMPLE_COUNT = 1000  # retained draws per chain
CHAIN_COUNT = 4      # number of chains
RANDOM_SEED = 1792   # seeds both the simulation and the sampler

# simulation hyperparameters
SESSION_COUNT = 7           # sessions over which animals enter and die
SUPERPOPULATION_SIZE = 400  # individuals simulated in total
APPARENT_SURVIVAL = 0.7     # survival probability between sessions
INITIAL_PI = 0.34           # entry probability for the first session
BASELINE_RATE = 0.25        # capture probability at zero distance
SIGMA = 0.5                 # scale of the half-normal detection function
M = 1000                    # size of the augmented data set
GRID_WIDTH = 10             # traps along each side of the square grid
OCCASION_COUNT = 20         # binomial trials per trap within a session
BUFFER = 2                  # margin added around the grid for the state space
U_SIGMA = 5                 # upper bound of the uniform prior on sigma

# make the labels on arviz plots nicer
_VAR_LABELS = {
    "psi": r"$\psi$",
    'gamma': r"$\gamma$",
    'alpha': r'$\alpha$',
    'epsilon': r"$\epsilon$",
    'p': r"$p$",
    'beta': r'$\beta$',
    'phi': r'$\phi$',
    'alpha_t': r'$\alpha_t$',
}
labeller = az.labels.MapLabeller(var_name_map=_VAR_LABELS)

def sim_js():
    """Simulate capture histories with a spatial Jolly-Seber model.

    Simulation code adapted from Kery and Schaub (2012), Chapter 10 and Gardner
    et al (2017) via Olivier Gimenez:
    https://oliviergimenez.github.io/basics_spatial_capturerecapture/

    All settings are read from the module-level constants RANDOM_SEED,
    SESSION_COUNT, SUPERPOPULATION_SIZE, APPARENT_SURVIVAL, INITIAL_PI,
    BASELINE_RATE, SIGMA, M, GRID_WIDTH, OCCASION_COUNT and BUFFER. A short
    summary of the simulated population is printed.

    Returns
    -------
    dict
        'capture_history' : int8 array, shape (SESSION_COUNT, M, trap_count).
            Capture counts of the detected individuals followed by all-zero
            rows that pad the individual axis up to M.
        'N_t' : number of individuals alive in each session.
        'B' : number of individuals entering in each session; always has
            SESSION_COUNT entries, including zeros.
        'trapxy' : integer trap coordinates, shape (trap_count, 2).
        'trap_count' : number of traps.
        'center_coords' : true activity centers, shape
            (SUPERPOPULATION_SIZE, 2).
        'z' : int8 array, shape (SUPERPOPULATION_SIZE, SESSION_COUNT); 1 where
            the individual has entered and is still alive.
        'minima', 'maxima' : lower / upper corner of the state space.
        'super_size' : M.
        'occasion_count' : OCCASION_COUNT.
        'session_count' : SESSION_COUNT.

        Rows of 'z' and 'center_coords' cover every simulated individual and
        are therefore not aligned with the rows of 'capture_history'.

    Raises
    ------
    ValueError
        If more individuals were detected than the augmented size M can hold.
    """

    rng = np.random.default_rng(RANDOM_SEED)
    interval_count = SESSION_COUNT - 1

    # simulate entry into the population: INITIAL_PI for the first session and
    # the remaining probability split evenly over the later sessions
    pi_rest = (1 - INITIAL_PI) / interval_count
    pi = np.concatenate([[INITIAL_PI], np.full(interval_count, pi_rest)])

    # which occasion did the animal enter in? (one-hot row per individual)
    entry_matrix = rng.multinomial(n=1, pvals=pi, size=SUPERPOPULATION_SIZE)
    entry_occasion = entry_matrix.nonzero()[1]

    # column sums of the one-hot matrix give one count per session, even for a
    # session in which nobody entered (np.unique would drop such sessions)
    entrant_count = entry_matrix.sum(axis=0)

    # zero if the animal has not yet entered and one after it enters
    entry_trajectory = np.maximum.accumulate(entry_matrix, axis=1)

    # flip coins for survival between occasions
    survival_draws = rng.binomial(
        1, APPARENT_SURVIVAL, (SUPERPOPULATION_SIZE, interval_count)
    )

    # add column such that survival between t and t+1 implies alive at t+1
    survival_draws = np.column_stack([np.ones(SUPERPOPULATION_SIZE), survival_draws])

    # ensure that the animal survives until it enters; the entry session itself
    # is included (<=) so the animal is alive in the session it enters
    is_yet_to_enter = np.arange(SESSION_COUNT) <= entry_occasion[:, None]
    survival_draws[is_yet_to_enter] = 1

    # once the survival_draws flips to zero the remaining row stays 0
    survival_trajectory = np.cumprod(survival_draws, axis=1)

    # animal has entered AND is still alive
    state = entry_trajectory * survival_trajectory
    state = state.astype(np.int8)

    # create the trapping grid (square, integer coordinates)
    grid_height = GRID_WIDTH
    grid = product(range(GRID_WIDTH), range(grid_height))
    grid = np.array(list(grid))
    trap_count = len(grid)

    # define the state space: the grid extent padded by BUFFER on every side
    minima = grid.min(axis=0) - BUFFER
    maxima = grid.max(axis=0) + BUFFER

    # simulate activity centers for every individual in the superpopulation
    center_coords = rng.uniform(minima, maxima, size=(SUPERPOPULATION_SIZE, 2))

    # true distance between the trap and the activity centers; computed with
    # NumPy like the rest of the simulation, so it stays in float64 and does
    # not touch the JAX backend
    distance = np.linalg.norm(
        center_coords[:, None, :] - grid[None, :, :],
        axis=2
    )

    # simulate potential captures with a half-normal detection function
    p = BASELINE_RATE * np.exp(-(distance**2) / (2 * SIGMA**2))
    capture = rng.binomial(OCCASION_COUNT, p[..., None],
                           size=(SUPERPOPULATION_SIZE, trap_count, SESSION_COUNT))

    # calculate actual captures, shape: (super_size, trap_count, session_count)
    capture_history = capture * state[:, np.newaxis, :]

    # remove the non-detected individuals
    was_captured = capture_history.sum(axis=(1, 2)) > 0
    capture_history = capture_history[was_captured]

    # augment the history with nz animals
    n = was_captured.sum()
    if n > M:
        raise ValueError(
            f"{n} individuals were detected but the augmented size M is only "
            f"{M}; increase M"
        )
    nz = M - n
    all_zero_history = np.zeros((nz, trap_count, SESSION_COUNT), dtype=np.int8)
    # NOTE: int8 holds the counts only while OCCASION_COUNT <= 127
    capture_history = np.vstack([capture_history, all_zero_history]).astype(np.int8)

    # return a dict with relevant summary stats
    N_t = state.sum(axis=0)

    # deaths at time t: alive in session t but not in session t + 1
    D_t = (state[:, :-1] * (1 - state[:, 1:])).sum(axis=0)

    # Print simulation summary
    print(f"Simulation complete: {n}/{SUPERPOPULATION_SIZE} individuals detected")
    print(f"Population trajectory: N = {N_t}")
    print(f"Recruits per session: B = {entrant_count}")
    print(f"Deaths per session: D = {D_t}")

    # put the session axis first: (session_count, M, trap_count)
    capture_history = np.moveaxis(capture_history, 2, 0)
    return {
        'capture_history': capture_history,
        'N_t': N_t,
        'B': entrant_count,
        'trapxy': grid,
        'trap_count': trap_count,
        'center_coords': center_coords,
        'z': state,
        'minima': minima,
        'maxima': maxima,
        'super_size': M,
        'occasion_count': OCCASION_COUNT,
        'session_count': SESSION_COUNT
    }

Like the simulation code, the NumPyro model is essentially the same as the NumPyro Jolly-Seber model. The difference is that the detection process is more complex: we need to sample activity centers, compute the pairwise distance matrix between the activity centers and the traps, then compute the detection probability for each individual at each trap. Again, we will marginalize out the latent existential state, \(z\), and scan over the sampling sessions.

def spatial_js(data):
    """NumPyro model for open spatial capture-recapture data.

    Entry and survival follow a three-state process (0 = not yet entered,
    1 = alive, 2 = dead) that is stepped through the sessions with `scan`;
    capture counts at each trap are binomial with a probability that decays
    with squared distance from the individual's activity center.

    Parameters
    ----------
    data : dict
        Must provide 'capture_history', 'trapxy', 'trap_count', 'minima',
        'maxima', 'super_size', 'occasion_count' and 'session_count' — the
        layout returned by `sim_js`.

    Returns
    -------
    None
        The function only registers sites: samples 'psi', 'pi', 'phi',
        'alpha0', 'sigma', 'sx', 'sy' and 'state', the deterministic 'gamma',
        and the observed site 'photo_id_data'.
    """

    # unpack everything in the data dictionary
    history = data['capture_history']
    trapxy = jnp.array(data['trapxy'])
    trap_count = data['trap_count']
    xmin, ymin = jnp.array(data['minima'])
    xmax, ymax = jnp.array(data['maxima'])
    super_size = data['super_size']
    occasion_count = data['occasion_count']
    session_count = data['session_count']

    ### Priors ###

    # specify gamma prior as a function of pi and psi
    psi = numpyro.sample('psi', dist.Beta(0.001, 1))
    pi = numpyro.sample('pi', dist.Dirichlet(jnp.ones(session_count)))

    # compute the removal probabilities as a function of psi and pi:
    # gamma[t] = psi * pi[t] / prod_{j < t} (1 - gamma[j])
    gamma = jnp.zeros(session_count)

    # the `vector.at[0].set(1)` notation is jax for `vector[0] = 1`
    gamma = gamma.at[0].set(psi * pi[0])
    for t in range(1, session_count):
        denominator = jnp.prod(1 - gamma[:t])
        gamma = gamma.at[t].set(psi * pi[t] / denominator)
    gamma = numpyro.deterministic('gamma', gamma)

    # apparent survival
    phi = numpyro.sample('phi', dist.Beta(2, 2))

    # baseline encounter rate (the site is named 'alpha0')
    g0 = numpyro.sample('alpha0', dist.Beta(4, 1))

    # detection range
    sigma = numpyro.sample('sigma', dist.Uniform(0, U_SIGMA))

    # sample activity center locations and compute capture probability at traps
    with numpyro.plate("activity_centers", super_size):
        sx = numpyro.sample('sx', dist.Uniform(xmin, xmax))
        sy = numpyro.sample('sy', dist.Uniform(ymin, ymax))
        center_coords = jnp.stack([sx, sy], axis=1)

        # pairwise distance between activity centers and the traps, using
        # |a - b|^2 = |a|^2 + |b|^2 - 2 a.b so no explicit loop is needed
        x_norm = jnp.sum(center_coords**2, axis=1, keepdims=True)
        y_norm = jnp.sum(trapxy**2, axis=1, keepdims=True)
        distance2 = x_norm + y_norm.T - 2 * center_coords @ trapxy.T

        # half-normal detection function: g0 * exp(-d^2 / (2 * sigma^2));
        # one row per individual, one column per trap
        p = g0 * jnp.exp( - distance2 / (2 * sigma ** 2))


    #### Transition existential states and compute scr likelihood ####


    def transition_and_capture(carry, y_current):
        """One `scan` step: advance the states and score one slice of `history`.

        `carry` holds the previous states and the step counter `t`;
        `y_current` is the slice of `history` for this step — with the data
        from `sim_js` that is one session, laid out (individual, trap).
        """

        # this is what the previous iteration of transition_and_capture returned
        z_previous, t = carry

        # transition probability matrix; rows are the state being left,
        # columns the state being entered
        trans_probs = jnp.array([
            [1 - gamma[t], gamma[t],     0.0],  # From not yet entered
            [         0.0,      phi, 1 - phi],  # From alive
            [         0.0,      0.0,     1.0]   # From dead
        ])

        with numpyro.plate("animals", super_size, dim=-1):

            # probability of transitioning depends on the previous state
            mu_z_current = trans_probs[z_previous]
            # the site is flagged for parallel enumeration
            z_current = numpyro.sample(
                "state",
                dist.Categorical(dist.util.clamp_probs(mu_z_current)),
                infer={"enumerate": "parallel"}
            )

            # probability of encountering the individual based on the state:
            # only state 1 (alive) keeps a non-zero capture probability
            mu_y_current = jnp.where(z_current == 1, 1, 0) * p.T
            mu_y_current = dist.util.clamp_probs(mu_y_current)

            # likelihood of the scr data; the observation is transposed so
            # traps sit on dim -2 and animals on dim -1, matching the plates
            with numpyro.plate('traps', trap_count, dim=-2):
                numpyro.sample(
                    "photo_id_data",
                    dist.Binomial(occasion_count, mu_y_current),
                    obs=y_current.T
                )

        return (z_current, t + 1), None

    # start everyone in the not yet entered state
    state_init = jnp.zeros(super_size, dtype=jnp.int32)
    # scan steps along the leading axis of `history`
    scan(
        transition_and_capture,
        (state_init, 0),
        history
    )

# run the simulation once; the returned dict is what the model receives as `data`
sim_data = sim_js()
Simulation complete: 300/400 individuals detected
Population trajectory: N = [133 124 128 145 138 162 154]
Recruits per session: B = [133  33  44  54  37  57  42]
Deaths per session: D = [42 40 37 44 33 50]

This is the first NumPyro model where the difference in performance between CPU and GPU is vast. As such, I will use the GPU to sample the model. To use a GPU with NumPyro, your computer needs a CUDA-compatible (i.e., NVIDIA) GPU. These typically come with laptops and PCs designed for gaming. Alternatively, almost all university or governmental high performance computing (HPC) clusters will have GPUs. Another option would be Google Colab, which allows people to temporarily rent GPU resources.

This model takes about 2 minutes to sample on GPU. While I am not quite patient enough to run it on CPU, it likely takes over an hour to sample on CPU.

from numpyro.infer import init_to_feasible

# NUTS kernel for the spatial Jolly-Seber model
nuts_kernel = NUTS(spatial_js)

# MCMC driver: warmup length, retained draws and number of chains
mcmc = MCMC(
    nuts_kernel,
    num_warmup=WARMUP_COUNT,
    num_samples=SAMPLE_COUNT,
    num_chains=CHAIN_COUNT,
)

# seed the sampler and draw from the posterior given the simulated data
rng_key = random.PRNGKey(RANDOM_SEED)
mcmc.run(rng_key, sim_data)

# keep the chain dimension so ArviZ can tell the chains apart
samples = mcmc.get_samples(group_by_chain=True)
idata = az.from_dict(samples)

# derive N from psi and the augmented size M
idata.posterior['N'] = idata.posterior['psi'] * M

# entry probabilities used by the simulation, drawn as reference lines below;
# np.concatenate matches the call in sim_js and also exists before NumPy 2.0
pi_rest = (1 - INITIAL_PI) / (SESSION_COUNT - 1)
pi = np.concatenate([[INITIAL_PI], np.full(SESSION_COUNT - 1, pi_rest)])

# trace plot of the main parameters with the simulated truth overlaid
parameters = ['N', 'phi', 'alpha0', 'pi', 'sigma']
axes = az.plot_trace(
    idata,
    figsize=(8,10),
    var_names=parameters,
    labeller=labeller,
    lines=[("phi", {}, [APPARENT_SURVIVAL]), ("alpha0", {}, [BASELINE_RATE]),
           ("pi", {}, pi), ("N", {}, [SUPERPOPULATION_SIZE]),
           ('sigma', {}, [SIGMA])],
)
plt.subplots_adjust(hspace=0.4)

# switch the grid off in the left-hand panel of each parameter row
for left_ax in axes[:len(parameters), 0]:
    left_ax.grid(False)
plt.show()

As we can see, the model recovers the parameters well, although some of the entry probabilities, \(\pi_t\), are clearly biased. Again, we can recover the discrete latent states with Predictive.

def sample_z(model, posterior_samples, data):
    '''Samples the latent states z given the histories and the posterior draws.

    Parameters
    ----------
    model : callable
        NumPyro model containing an enumerated discrete site named "state".
    posterior_samples : dict
        Posterior draws of the continuous parameters, e.g. from
        `mcmc.get_samples()`.
    data : dict
        The data dictionary that is passed on to `model`.

    Returns
    -------
    array
        Draws of the "state" site, one per posterior draw.
    '''

    # initialize the posterior predictive distribution; infer_discrete=True
    # makes Predictive draw the enumerated "state" site conditional on the
    # observed histories instead of simulating it forward from the
    # transition probabilities alone
    predictive = Predictive(
        model,
        posterior_samples=posterior_samples,
        return_sites=["state"],
        infer_discrete=True,
    )

    # sample z
    rng_key = random.PRNGKey(RANDOM_SEED)
    latent_samples = predictive(rng_key, data)
    return latent_samples["state"]

# draws of the latent states, using the posterior pooled across chains
samples = mcmc.get_samples()
z = sample_z(spatial_js, samples, sim_data)

# N per draw: individuals whose state rose above 0 at some point along axis 1
# (presumably the session axis — confirm against the shape of z)
ever_alive = z.max(axis=1) > 0
samples['N'] = ever_alive.sum(axis=1)

# abundance per draw: individuals in state 1, counted along axis 2
N_t_samps = (z == 1).sum(axis=2)
t = np.arange(SESSION_COUNT)

# true trajectory as diamonds, posterior draws as violins
fig, ax = plt.subplots(figsize=(6,4))
ax.scatter(t, sim_data['N_t'], color='C0', marker='d', label='Truth', s=60)
ax.violinplot(N_t_samps, t, showmedians=False, showextrema=False)

ax.legend()
ax.set(xlabel=r'Season', ylabel=r'Abundance', title=r'Population trajectory')

plt.show()

The model is able to recreate the population trajectory over time. While some estimates are slightly off, they are all well within the estimates of uncertainty.