Skip to content

Commit

Permalink
WIP: Refactoring the JointDistribution
Browse files Browse the repository at this point in the history
  • Loading branch information
pawel-czyz committed Jun 28, 2024
1 parent f4fa4f4 commit 507e2f4
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 12 deletions.
18 changes: 9 additions & 9 deletions src/bmi/samplers/_tfp/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ class JointDistribution:
$P_X$ and $P_Y$.
Attributes:
dist: $P_{XY}$
dist_x: $P_X$
dist_y: $P_Y$
dist_joint: $P_{XY}$. Each sample is a *tuple* `(xs, ys)`
where `xs` is of shape `(n_samples, dim_x)` and
`ys` is of shape `(n_samples, dim_y)`.
dist_x: $P_X$. Samples are of shape `(n_samples, dim_x)`.
dist_y: $P_Y$. Samples are of shape `(n_samples, dim_y)`.
dim_x: dimension of the support of $X$
dim_y: dimension of the support of $Y$
analytic_mi: analytical mutual information.
Expand All @@ -43,8 +45,7 @@ def sample(self, n_points: int, key: jax.Array) -> tuple[jnp.ndarray, jnp.ndarra
if n_points < 1:
raise ValueError("n must be positive")

xy = self.dist_joint.sample(seed=key, sample_shape=(n_points,))
return xy[..., : self.dim_x], xy[..., self.dim_x :] # noqa: E203 (formatting discrepancy)
return self.dist_joint.sample(n_points, key)

def pmi(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
"""Calculates pointwise mutual information at specified points.
Expand All @@ -60,7 +61,7 @@ def pmi(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
Note:
This function is vectorized, i.e. it can calculate PMI for multiple points at once.
"""
log_pxy = self.dist_joint.log_prob(jnp.hstack([x, y]))
log_pxy = self.dist_joint.log_prob((x, y))
log_px = self.dist_x.log_prob(x)
log_py = self.dist_y.log_prob(y)

Expand Down Expand Up @@ -136,9 +137,8 @@ def transform(
if y_transform is None:
y_transform = tfb.Identity()

product_bijector = tfb.Blockwise(
bijectors=[x_transform, y_transform], block_sizes=[dist.dim_x, dist.dim_y]
)
product_bijector = tfb.JointMap((x_transform, y_transform))

return JointDistribution(
dim_x=dist.dim_x,
dim_y=dist.dim_y,
Expand Down
8 changes: 7 additions & 1 deletion src/bmi/samplers/_tfp/_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from bmi.samplers._tfp._core import JointDistribution

jtf = tfp.tf2jax
tfb = tfp.bijectors
tfd = tfp.distributions


Expand Down Expand Up @@ -55,7 +56,12 @@ def __init__(
# Now we need to define the TensorFlow Probability distributions
# using the information provided

dist_joint = construct_multivariate_normal_distribution(mean=mean, covariance=covariance)
_dist_joint = construct_multivariate_normal_distribution(mean=mean, covariance=covariance)
dist_joint = tfd.TransformedDistribution(
distribution=_dist_joint,
bijector=tfb.Split((dim_x, dim_y)),
)

dist_x = construct_multivariate_normal_distribution(
mean=mean[:dim_x], covariance=covariance[:dim_x, :dim_x]
)
Expand Down
2 changes: 1 addition & 1 deletion src/bmi/samplers/_tfp/_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self, dist_x: tfd.Distribution, dist_y: tfd.Distribution) -> None:
dim_x = int(dims_x[0])
dim_y = int(dims_y[0])

dist_joint = tfd.Blockwise([dist_x, dist_y])
dist_joint = tfd.JointDistributionSequential((dist_x, dist_y))

super().__init__(
dim_x=dim_x,
Expand Down
8 changes: 7 additions & 1 deletion src/bmi/samplers/_tfp/_student.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from bmi.samplers._tfp._core import JointDistribution

jtf = tfp.tf2jax
tfb = tfp.bijectors
tfd = tfp.distributions


Expand Down Expand Up @@ -78,9 +79,14 @@ def __init__(
# Now we need to define the TensorFlow Probability distributions
# using the information provided

dist_joint = construct_multivariate_student_distribution(
_dist_joint = construct_multivariate_student_distribution(
mean=mean, dispersion=dispersion, df=df
)
dist_joint = tfd.TransformedDistribution(
distribution=_dist_joint,
bijector=tfb.Split((dim_x, dim_y)),
)

dist_x = construct_multivariate_student_distribution(
mean=mean[:dim_x], dispersion=dispersion[:dim_x, :dim_x], df=df
)
Expand Down

0 comments on commit 507e2f4

Please sign in to comment.