Skip to content

Commit

Permalink
fixed shape calls to tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
janisshin committed Jun 6, 2024
1 parent f46cd46 commit 57177f1
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/emll/data_model_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def create_noisy_observations_of_computed_values(name:str, computed_tensor:T.ten
raise ValueError("Data shape does not match standard deviation shape!")

# check shape of data == computed_tensor
if any(data.shape != computed_tensor.shape.eval()):
if any(data.shape != computed_tensor.eval().shape):
raise ValueError(f"Data shape {data.shape} does not match computed tensor shape {pytensor.tensor.shape(computed_tensor)}!")

# check that standard deviations are all positive
Expand Down
2 changes: 1 addition & 1 deletion src/emll/linlog_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def steady_state_pytensor(self, Ex, Ey=None, en=None, yn=None, method="scan"):
en = np.atleast_2d(en)
n_exp = en.shape[0]
else:
n_exp = en.shape.eval()[0]
n_exp = en.eval().shape[0]

if isinstance(yn, np.ndarray):
yn = np.atleast_2d(yn)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_data_model_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,8 +394,8 @@ def test_create_pytensor_from_data():
data_tensor = create_pytensor_from_data_naive(input_string, input_dataframe_mixed, input_stdev_dataframe_mixed, input_laplace_dataframe_mixed)

# check that the n-dim and shape are correct
assert np.size(data_tensor.shape.eval()) == np.size(input_dataframe_mixed.shape), f"Expected tensor n dim {np.size(input_dataframe_mixed.shape)}, found {np.size(data_tensor.shape.eval())}"
assert tuple(data_tensor.shape.eval()) == input_dataframe_mixed.shape, f"Expected tensor shape {input_dataframe_mixed.shape}, found {data_tensor.shape.eval()}"
assert np.size(data_tensor.eval().shape) == np.size(input_dataframe_mixed.shape), f"Expected tensor n dim {np.size(input_dataframe_mixed.shape)}, found {np.size(data_tensor.eval().shape)}"
assert tuple(data_tensor.eval().shape) == input_dataframe_mixed.shape, f"Expected tensor shape {input_dataframe_mixed.shape}, found {data_tensor.eval().shape}"

# Check that the steady_state_pytensor method runs
error_occurred = False
Expand All @@ -422,8 +422,8 @@ def test_create_pytensor_from_data():
y_n = pm.Normal("yn_t",
mu=0,
sigma=10,
shape=(data_tensor.shape.eval()[0], lin_log.ny),
initval=0.1 * np.random.randn(data_tensor.shape.eval()[0], lin_log.ny)
shape=(data_tensor.eval().shape[0], lin_log.ny),
initval=0.1 * np.random.randn(data_tensor.eval().shape[0], lin_log.ny)
)

chi_ss, vn_ss = lin_log.steady_state_pytensor(
Expand Down

0 comments on commit 57177f1

Please sign in to comment.