Skip to content

Commit

Permalink
fix rng seeds and add nans
Browse files Browse the repository at this point in the history
  • Loading branch information
phinate committed Sep 20, 2024
1 parent a1d6ec9 commit c32d074
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
"""Test if metrics match the legacy metrics"""

import inspect
from functools import partial
from typing import cast

import jax.numpy as jnp
import jax.random as jr
import numpy as np
import pytest
from jaxtyping import Array, Float32
Expand All @@ -29,6 +32,10 @@ def apply_pix_metric(metric_func, y_hat, y) -> Float32[Array, "batch channels ti
y_jax = jnp.array(y).reshape(-1, *y.shape[-2:])[..., np.newaxis]
y_hat_jax = jnp.array(y_hat).reshape(-1, *y_hat.shape[-2:])[..., np.newaxis]

sig = inspect.signature(metric_func)
if "ignore_nans" in sig.parameters:
metric_func = partial(metric_func, ignore_nans=True)

# we reshape the result back into [batch, channels, time],
# then take the mean over the batch
return cast(Float32[Array, "batch channels time"], metric_func(y_hat_jax, y_jax)).reshape(
Expand All @@ -47,8 +54,14 @@ def apply_pix_metric(metric_func, y_hat, y) -> Float32[Array, "batch channels ti
def test_metrics(metric_func, legacy_func):
"""Test if metrics match the legacy metrics"""
# Create a sample input batch
y_hat = np.random.rand(1, 3, 10, 100, 100)
y = np.random.rand(1, 3, 10, 100, 100)
shape = (1, 3, 10, 100, 100)
key = jr.key(321)
key, k1, k2 = jr.split(key, 3)
y_hat = jr.uniform(k1, shape, minval=0, maxval=1)
y = jr.uniform(k2, shape, minval=0, maxval=1)

# Add NaNs to the input
y = y.at[:, :, :, 0, 0].set(np.nan)

# Call the metric function
metric = apply_pix_metric(metric_func, y_hat, y).mean(axis=0)
Expand All @@ -59,7 +72,7 @@ def test_metrics(metric_func, legacy_func):
# Check the values of the output
legacy_res = legacy_func(y_hat, y)

# Lower tolerance for ssim
# Lower tolerance for ssim (differences in implementation)
rtol = 0.001 if metric_func == ssim else 1e-5

assert np.allclose(
Expand Down

0 comments on commit c32d074

Please sign in to comment.