From cfa264701e96e724100d453dc2032c542124bf69 Mon Sep 17 00:00:00 2001 From: Mark Messner Date: Thu, 14 Sep 2023 11:18:33 -0500 Subject: [PATCH] Fix nans with nan_to_num --- pyoptmat/optimize.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/pyoptmat/optimize.py b/pyoptmat/optimize.py index 72b79b3..c4f587c 100644 --- a/pyoptmat/optimize.py +++ b/pyoptmat/optimize.py @@ -189,7 +189,7 @@ class StatisticalModel(PyroModule): entry i represents the noise in test type i """ - def __init__(self, maker, names, locs, scales, eps): + def __init__(self, maker, names, locs, scales, eps, nan_num = False): super().__init__() self.maker = maker @@ -203,6 +203,8 @@ def __init__(self, maker, names, locs, scales, eps): self.type_noise = self.eps.dim() > 0 + self.nan_num = nan_num + def get_params(self): """ Return the sampled parameters for input to the model @@ -235,9 +237,12 @@ def forward(self, exp_data, exp_cycles, exp_types, exp_control, exp_results=None predictions[:, :, 0], exp_cycles, exp_types ) + if self.nan_num: + results = torch.nan_to_num(results) + # Setup the full noise, which can be type specific if self.type_noise: - full_noise = torch.empty(exp_data.shape[-1]) + full_noise = torch.empty(exp_data.shape[-1], device = self.eps.device) for i in experiments.exp_map.values(): full_noise[exp_types == i] = self.eps[i] else: