Skip to content

Commit

Permalink
Hm/all deterministic (#1914)
Browse files Browse the repository at this point in the history
* Skip printing summary if empty.

* Post-process when no sample sites present.

Current post-processing behaviour skips models with only deterministic variables. Applying this change will return consistent samples regardless of whether `sample` sites are present.
  • Loading branch information
hessammehr authored Nov 22, 2024
1 parent 0e7bd20 commit f87f40e
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 4 deletions.
4 changes: 4 additions & 0 deletions numpyro/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,8 @@ def summary(

summary_dict = {}
for name, value in samples.items():
if len(value) == 0:
continue
value = device_get(value)
value_flat = np.reshape(value, (-1,) + value.shape[2:])
mean = value_flat.mean(axis=0)
Expand Down Expand Up @@ -307,6 +309,8 @@ def print_summary(
"Param:{}".format(i): v for i, v in enumerate(jax.tree.flatten(samples)[0])
}
summary_dict = summary(samples, prob, group_by_chain=True)
if not summary_dict:
return

row_names = {
k: k + "[" + ",".join(map(lambda x: str(x - 1), v.shape[2:])) + "]"
Expand Down
20 changes: 16 additions & 4 deletions numpyro/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,7 @@ def collect_and_postprocess(x):
if collect_fields:
fields = nested_attrgetter(*collect_fields)(x[0])
fields = [fields] if len(collect_fields) == 1 else list(fields)
site_values = jax.tree.flatten(fields[0])[0]
if len(site_values) > 0:
fields[0] = postprocess_fn(fields[0], *x[1:])
fields[0] = postprocess_fn(fields[0], *x[1:])

if remove_sites != ():
assert isinstance(fields[0], dict)
Expand Down Expand Up @@ -400,13 +398,27 @@ def _get_cached_fns(self):
fns, key = None, None
if fns is None:

def ensure_vmap(fn, batch_size=None):
def wrapper(x):
x_arrays = jax.tree.flatten(x)[0]
if len(x_arrays) > 0:
return vmap(fn)(x)
else:
assert batch_size is not None
return jax.tree.map(
lambda x: jnp.broadcast_to(x, (batch_size,) + jnp.shape(x)),
fn(x),
)

return wrapper

def _postprocess_fn(state, args, kwargs):
if self.postprocess_fn is None:
body_fn = self.sampler.postprocess_fn(args, kwargs)
else:
body_fn = self.postprocess_fn
if self.chain_method == "vectorized" and self.num_chains > 1:
body_fn = vmap(body_fn)
body_fn = ensure_vmap(body_fn, batch_size=self.num_chains)

return body_fn(state)

Expand Down
26 changes: 26 additions & 0 deletions test/infer/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1208,3 +1208,29 @@ def model():
mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
mcmc.run(random.PRNGKey(0), extra_fields=("z.x",))
assert_allclose(mcmc.get_samples()["x"], jnp.exp(mcmc.get_extra_fields()["z.x"]))


def test_all_deterministic():
def model1():
numpyro.deterministic("x", 1.0)

def model2():
numpyro.deterministic("x", jnp.array([1.0, 2.0]))

num_samples = 10
shapes = {model1: (), model2: (2,)}

for model, shape in shapes.items():
mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=num_samples)
mcmc.run(random.PRNGKey(0))
assert mcmc.get_samples()["x"].shape == (num_samples,) + shape


def test_empty_summary():
def model():
pass

mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
mcmc.run(random.PRNGKey(0))

mcmc.print_summary()

0 comments on commit f87f40e

Please sign in to comment.