"""NumPyro-specific conversion code."""
import logging
from typing import Callable, Optional
import numpy as np
from .. import utils
from ..rcparams import rcParams
from .base import dict_to_dataset, requires
from .inference_data import InferenceData
_log = logging.getLogger(__name__)
class NestedToMCMCAdapter:
    """
    Adapter to convert a NestedSampler object into an MCMC-compatible interface.

    This class reshapes posterior samples from a NestedSampler into a chain-and-draw
    structure expected by MCMC workflows, providing compatibility with downstream
    tools like ArviZ for posterior analysis.

    Parameters
    ----------
    nested_sampler : numpyro.contrib.nested_sampling.NestedSampler
        The NestedSampler object containing posterior samples.
    rng_key : jax.random.PRNGKey
        The random key used for sampling.
    num_samples : int
        The total number of posterior samples to draw (across all chains).
    num_chains : int, optional
        The number of artificial chains to create for MCMC compatibility (default is 1).
        Must be at least 1.
    *args : tuple
        Additional positional arguments required by the model (e.g., data, labels).
    **kwargs : dict
        Additional keyword arguments required by the model.

    Attributes
    ----------
    samples : dict
        Reshaped posterior samples organized by variable name. Every array has shape
        ``(num_chains, num_samples // num_chains, ...)``.
    thinning : int
        Dummy thinning attribute for compatibility with MCMC.
    sampler : NestedToMCMCAdapter
        Mimics the sampler attribute of an MCMC object.
    model : callable
        The probabilistic model used in the NestedSampler.
    _args : tuple
        Positional arguments passed to the model.
    _kwargs : dict
        Keyword arguments passed to the model.

    Raises
    ------
    ValueError
        If ``num_chains`` is smaller than 1.

    Notes
    -----
    ``num_samples`` is the *total* number of draws. When it is not a multiple of
    ``num_chains`` the trailing ``num_samples % num_chains`` draws are discarded so
    that every artificial chain has the same length.

    Methods
    -------
    get_samples(group_by_chain=True)
        Returns posterior samples reshaped by chain or flattened if `group_by_chain` is False.
    get_extra_fields(group_by_chain=True)
        Provides dummy sampling statistics like accept probabilities, step sizes, and num_steps.
    """

    def __init__(self, nested_sampler, rng_key, num_samples, *args, num_chains=1, **kwargs):
        # Fail early with a clear message instead of a ZeroDivisionError or an
        # opaque reshape error further down.
        if num_chains < 1:
            raise ValueError(f"num_chains must be a positive integer, got {num_chains!r}.")
        self.nested_sampler = nested_sampler
        self.rng_key = rng_key
        self.num_samples = num_samples
        self.num_chains = num_chains
        self.samples = self._reshape_samples()
        self.thinning = 1
        self.sampler = self
        self.model = nested_sampler.model
        self._args = args
        self._kwargs = kwargs

    @property
    def _draws_per_chain(self):
        """Number of draws assigned to each artificial chain."""
        return self.num_samples // self.num_chains

    def _reshape_samples(self):
        """Draw samples from the nested sampler and fold them into (chain, draw, ...)."""
        raw_samples = self.nested_sampler.get_samples(self.rng_key, self.num_samples)
        draws = self._draws_per_chain
        kept = draws * self.num_chains  # drops the remainder when not evenly divisible
        return {
            name: np.reshape(values[:kept], (self.num_chains, draws, *values.shape[1:]))
            for name, values in raw_samples.items()
        }

    def get_samples(self, group_by_chain=True):
        """Return the samples, grouped by chain or with chain and draw axes merged."""
        if group_by_chain:
            return self.samples
        # Flatten chains into a single dimension
        return {name: ary.reshape(-1, *ary.shape[2:]) for name, ary in self.samples.items()}

    def get_extra_fields(self, group_by_chain=True):
        """Return placeholder sampling statistics shaped like the samples.

        NestedSampler does not produce these quantities, so constant dummy values
        are generated to satisfy consumers that expect MCMC extra fields.
        """
        shape = (self.num_chains, self._draws_per_chain)
        extra_fields = {
            "accept_prob": np.full(shape, 1.0),  # Assume all proposals are accepted
            "step_size": np.full(shape, 0.1),  # Dummy step size
            "num_steps": np.full(shape, 10),  # Dummy number of steps
        }
        if not group_by_chain:
            # Flatten the chains into a single dimension
            extra_fields = {
                name: ary.reshape(-1, *ary.shape[2:]) for name, ary in extra_fields.items()
            }
        return extra_fields
class NumPyroConverter:
    """Encapsulate NumPyro specific logic."""

    # pylint: disable=too-many-instance-attributes

    model = None  # type: Optional[Callable]
    nchains = None  # type: int
    ndraws = None  # type: int

    def __init__(
        self,
        *,
        posterior=None,
        prior=None,
        posterior_predictive=None,
        predictions=None,
        constant_data=None,
        predictions_constant_data=None,
        log_likelihood=None,
        index_origin=None,
        coords=None,
        dims=None,
        pred_dims=None,
        num_chains=1,
        rng_key=None,
        num_samples=1000,
        data=None,
        labels=None,
    ):
        """Convert NumPyro data into an InferenceData object.

        Parameters
        ----------
        posterior : numpyro.mcmc.MCMC or numpyro.contrib.nested_sampling.NestedSampler
            Fitted MCMC object from NumPyro. A NestedSampler is wrapped in a
            :class:`NestedToMCMCAdapter` before conversion.
        prior: dict
            Prior samples from a NumPyro model
        posterior_predictive : dict
            Posterior predictive samples for the posterior
        predictions: dict
            Out of sample predictions
        constant_data: dict
            Dictionary containing constant data variables mapped to their values.
        predictions_constant_data: dict
            Constant data used for out-of-sample predictions.
        log_likelihood : bool, optional
            Whether to compute and store the pointwise log likelihood. Defaults to
            ``rcParams["data.log_likelihood"]`` when None.
        index_origin : int, optional
            Defaults to ``rcParams["data.index_origin"]`` when None.
        coords : dict[str] -> list[str]
            Map of dimensions to coordinates
        dims : dict[str] -> list[str]
            Map variable names to their coordinates
        pred_dims: dict
            Dims for predictions data. Map variable names to their coordinates.
        num_chains: int
            Number of chains used for sampling. Ignored if posterior is an MCMC object;
            used as the number of artificial chains if posterior is a NestedSampler.
        rng_key : jax.random.PRNGKey, optional
            Random key used to draw samples. Only used if posterior is a NestedSampler.
        num_samples : int
            Total number of samples to draw. Only used if posterior is a NestedSampler.
        data, labels : optional
            Forwarded as the ``data`` and ``labels`` keyword arguments of the model.
            Only used if posterior is a NestedSampler.

        Raises
        ------
        ValueError
            If none of the supported inputs is provided.
        """
        import sys

        import jax
        import numpyro

        self.posterior = posterior
        self.rng_key = rng_key
        self.num_samples = num_samples

        # Attribute access on ``numpyro.contrib.nested_sampling`` fails unless that
        # submodule has already been imported. Any NestedSampler instance implies
        # its module is loaded, so look the class up in ``sys.modules`` instead.
        nested_module = sys.modules.get("numpyro.contrib.nested_sampling")
        nested_sampler_cls = getattr(nested_module, "NestedSampler", None)
        if nested_sampler_cls is not None and isinstance(posterior, nested_sampler_cls):
            posterior = NestedToMCMCAdapter(
                posterior, rng_key, num_samples, num_chains=num_chains, data=data, labels=labels
            )
            self.posterior = posterior

        self.prior = jax.device_get(prior)
        self.posterior_predictive = jax.device_get(posterior_predictive)
        self.predictions = predictions
        self.constant_data = constant_data
        self.predictions_constant_data = predictions_constant_data
        self.log_likelihood = (
            rcParams["data.log_likelihood"] if log_likelihood is None else log_likelihood
        )
        self.index_origin = rcParams["data.index_origin"] if index_origin is None else index_origin
        self.coords = coords
        self.dims = dims
        self.pred_dims = pred_dims
        self.numpyro = numpyro

        def arbitrary_element(dct):
            return next(iter(dct.values()))

        if posterior is not None:
            samples = jax.device_get(self.posterior.get_samples(group_by_chain=True))
            if hasattr(samples, "_asdict"):
                # In case it is easy to convert to a dictionary, as in the case of namedtuples
                samples = samples._asdict()
            if not isinstance(samples, dict):
                # handle the case we run MCMC with a general potential_fn
                # (instead of a NumPyro model) whose args is not a dictionary
                # (e.g. f(x) = x ** 2)
                tree_flatten_samples = jax.tree_util.tree_flatten(samples)[0]
                samples = {
                    f"Param:{i}": jax.device_get(v) for i, v in enumerate(tree_flatten_samples)
                }
            self._samples = samples
            if isinstance(posterior, NestedToMCMCAdapter):
                # The adapter's ``num_samples`` is the total over all artificial
                # chains, so the per-chain draw count is obtained by dividing by
                # the number of chains.
                draws_per_chain = posterior.num_samples // posterior.num_chains
            else:
                draws_per_chain = posterior.num_samples // posterior.thinning
            self.nchains, self.ndraws = posterior.num_chains, draws_per_chain
            self.model = self.posterior.sampler.model
            # model arguments and keyword arguments
            self._args = self.posterior._args  # pylint: disable=protected-access
            self._kwargs = self.posterior._kwargs  # pylint: disable=protected-access
        else:
            self.nchains = num_chains
            get_from = None
            if predictions is not None:
                get_from = predictions
            elif posterior_predictive is not None:
                get_from = posterior_predictive
            elif prior is not None:
                get_from = prior
            if get_from is None and constant_data is None and predictions_constant_data is None:
                raise ValueError(
                    "When constructing InferenceData must have at least"
                    " one of posterior, prior, posterior_predictive or predictions."
                )
            if get_from is not None:
                aelem = arbitrary_element(get_from)
                self.ndraws = aelem.shape[0] // self.nchains

        observations = {}
        if self.model is not None:
            # we need to use an init strategy to generate random samples for ImproperUniform sites
            seeded_model = numpyro.handlers.substitute(
                numpyro.handlers.seed(self.model, jax.random.PRNGKey(0)),
                substitute_fn=numpyro.infer.init_to_sample,
            )
            trace = numpyro.handlers.trace(seeded_model).get_trace(*self._args, **self._kwargs)
            observations = {
                name: site["value"]
                for name, site in trace.items()
                if site["type"] == "sample" and site["is_observed"]
            }
        self.observations = observations if observations else None

    @requires("posterior")
    def posterior_to_xarray(self):
        """Convert the posterior to an xarray dataset."""
        data = self._samples
        return dict_to_dataset(
            data,
            library=self.numpyro,
            coords=self.coords,
            dims=self.dims,
            index_origin=self.index_origin,
        )

    @requires("posterior")
    def sample_stats_to_xarray(self):
        """Extract sample_stats from NumPyro posterior."""
        rename_key = {
            "potential_energy": "lp",
            "adapt_state.step_size": "step_size",
            "num_steps": "n_steps",
            "accept_prob": "acceptance_rate",
        }
        data = {}
        for stat, value in self.posterior.get_extra_fields(group_by_chain=True).items():
            # nested containers cannot be stored as a single variable
            if isinstance(value, (dict, tuple)):
                continue
            name = rename_key.get(stat, stat)
            value = value.copy()
            data[name] = value
            if stat == "num_steps":
                data["tree_depth"] = np.log2(value).astype(int) + 1
        return dict_to_dataset(
            data,
            library=self.numpyro,
            dims=None,
            coords=self.coords,
            index_origin=self.index_origin,
        )

    @requires("posterior")
    @requires("model")
    def log_likelihood_to_xarray(self):
        """Extract log likelihood from NumPyro posterior."""
        if not self.log_likelihood:
            return None
        data = {}
        if self.observations is not None:
            samples = self.posterior.get_samples(group_by_chain=False)
            if hasattr(samples, "_asdict"):
                samples = samples._asdict()
            log_likelihood_dict = self.numpyro.infer.log_likelihood(
                self.model, samples, *self._args, **self._kwargs
            )
            for obs_name, log_like in log_likelihood_dict.items():
                # the leading flattened sample axis is split back into (chain, draw)
                shape = (self.nchains, self.ndraws) + log_like.shape[1:]
                data[obs_name] = np.reshape(np.asarray(log_like), shape)
        return dict_to_dataset(
            data,
            library=self.numpyro,
            dims=self.dims,
            coords=self.coords,
            index_origin=self.index_origin,
            skip_event_dims=True,
        )

    def translate_posterior_predictive_dict_to_xarray(self, dct, dims):
        """Convert posterior_predictive or prediction samples to xarray."""
        data = {}
        for k, ary in dct.items():
            shape = ary.shape
            # Compare a slice rather than indexing ``shape[1]`` so that 1-d arrays
            # do not raise IndexError.
            if tuple(shape[:2]) == (self.nchains, self.ndraws):
                data[k] = ary
            elif shape[0] == self.nchains * self.ndraws:
                data[k] = ary.reshape((self.nchains, self.ndraws, *shape[1:]))
            else:
                data[k] = utils.expand_dims(ary)
                _log.warning(
                    "posterior predictive shape not compatible with number of chains and draws. "
                    "This can mean that some draws or even whole chains are not represented."
                )
        return dict_to_dataset(
            data,
            library=self.numpyro,
            coords=self.coords,
            dims=dims,
            index_origin=self.index_origin,
        )

    @requires("posterior_predictive")
    def posterior_predictive_to_xarray(self):
        """Convert posterior_predictive samples to xarray."""
        return self.translate_posterior_predictive_dict_to_xarray(
            self.posterior_predictive, self.dims
        )

    @requires("predictions")
    def predictions_to_xarray(self):
        """Convert predictions to xarray."""
        return self.translate_posterior_predictive_dict_to_xarray(self.predictions, self.pred_dims)

    def priors_to_xarray(self):
        """Convert prior samples (and if possible prior predictive too) to xarray."""
        if self.prior is None:
            return {"prior": None, "prior_predictive": None}
        if self.posterior is not None:
            # variables absent from the posterior samples are treated as prior predictive
            prior_vars = list(self._samples.keys())
            prior_predictive_vars = [key for key in self.prior.keys() if key not in prior_vars]
        else:
            prior_vars = self.prior.keys()
            prior_predictive_vars = None
        priors_dict = {
            group: (
                None
                if var_names is None
                else dict_to_dataset(
                    {k: utils.expand_dims(self.prior[k]) for k in var_names},
                    library=self.numpyro,
                    coords=self.coords,
                    dims=self.dims,
                    index_origin=self.index_origin,
                )
            )
            for group, var_names in zip(
                ("prior", "prior_predictive"), (prior_vars, prior_predictive_vars)
            )
        }
        return priors_dict

    @requires("observations")
    @requires("model")
    def observed_data_to_xarray(self):
        """Convert observed data to xarray."""
        return dict_to_dataset(
            self.observations,
            library=self.numpyro,
            dims=self.dims,
            coords=self.coords,
            default_dims=[],
            index_origin=self.index_origin,
        )

    @requires("constant_data")
    def constant_data_to_xarray(self):
        """Convert constant_data to xarray."""
        return dict_to_dataset(
            self.constant_data,
            library=self.numpyro,
            dims=self.dims,
            coords=self.coords,
            default_dims=[],
            index_origin=self.index_origin,
        )

    @requires("predictions_constant_data")
    def predictions_constant_data_to_xarray(self):
        """Convert predictions_constant_data to xarray."""
        return dict_to_dataset(
            self.predictions_constant_data,
            library=self.numpyro,
            dims=self.pred_dims,
            coords=self.coords,
            default_dims=[],
            index_origin=self.index_origin,
        )

    def to_inference_data(self):
        """Convert all available data to an InferenceData object.

        Note that if groups can not be created (i.e., there is no `trace`, so
        the `posterior` and `sample_stats` can not be extracted), then the InferenceData
        will not have those groups.
        """
        return InferenceData(
            **{
                "posterior": self.posterior_to_xarray(),
                "sample_stats": self.sample_stats_to_xarray(),
                "log_likelihood": self.log_likelihood_to_xarray(),
                "posterior_predictive": self.posterior_predictive_to_xarray(),
                "predictions": self.predictions_to_xarray(),
                **self.priors_to_xarray(),
                "observed_data": self.observed_data_to_xarray(),
                "constant_data": self.constant_data_to_xarray(),
                "predictions_constant_data": self.predictions_constant_data_to_xarray(),
            }
        )
# Public API
def from_numpyro(
    posterior=None,
    *,
    prior=None,
    posterior_predictive=None,
    predictions=None,
    constant_data=None,
    predictions_constant_data=None,
    log_likelihood=None,
    index_origin=None,
    coords=None,
    dims=None,
    pred_dims=None,
    num_chains=1,
    rng_key=None,
    num_samples=1000,
    data=None,
    labels=None,
):
    """Convert NumPyro data into an InferenceData object.

    For a usage example read the
    :ref:`Creating InferenceData section on from_numpyro <creating_InferenceData>`

    Every argument is handed unchanged to :class:`NumPyroConverter`, whose
    ``to_inference_data`` result is returned.

    Parameters
    ----------
    posterior : numpyro.mcmc.MCMC
        Fitted MCMC object from NumPyro
    prior: dict
        Prior samples from a NumPyro model
    posterior_predictive : dict
        Posterior predictive samples for the posterior
    predictions: dict
        Out of sample predictions
    constant_data: dict
        Dictionary containing constant data variables mapped to their values.
    predictions_constant_data: dict
        Constant data used for out-of-sample predictions.
    log_likelihood : optional
        Forwarded to the converter as ``log_likelihood``.
    index_origin : int, optional
    coords : dict[str] -> list[str]
        Map of dimensions to coordinates
    dims : dict[str] -> list[str]
        Map variable names to their coordinates
    pred_dims: dict
        Dims for predictions data. Map variable names to their coordinates.
    num_chains: int
        Number of chains used for sampling. Ignored if posterior is present.
    rng_key, num_samples, data, labels : optional
        Forwarded to the converter under the same names.

    Returns
    -------
    InferenceData
    """
    converter_kwargs = {
        "posterior": posterior,
        "prior": prior,
        "posterior_predictive": posterior_predictive,
        "predictions": predictions,
        "constant_data": constant_data,
        "predictions_constant_data": predictions_constant_data,
        "log_likelihood": log_likelihood,
        "index_origin": index_origin,
        "coords": coords,
        "dims": dims,
        "pred_dims": pred_dims,
        "num_chains": num_chains,
        "rng_key": rng_key,
        "num_samples": num_samples,
        "data": data,
        "labels": labels,
    }
    converter = NumPyroConverter(**converter_kwargs)
    return converter.to_inference_data()