diff --git a/cascade/auditor.py b/cascade/auditor.py index 64228ba..8379434 100644 --- a/cascade/auditor.py +++ b/cascade/auditor.py @@ -55,16 +55,16 @@ def audit(self, n_audits: int, sort_audits: bool = False, ) -> tuple[float, list[int]]: - """ - Args: - atoms: list of ase atoms. Should have atoms.info['forces_ens'] - set by the cascade EnsembleCalculator - n_audits: number of frames to return - sort_audits: in this case does nothing, since there is no UQ order - Returns: - p_any: drawn from Unif(0,1) - audit_frames: indices of the frams that were sampled - """ + """ + Args: + atoms: list of ase atoms. Should have atoms.info['forces_ens'] + set by the cascade EnsembleCalculator + n_audits: number of frames to return + sort_audits: in this case does nothing, since there is no UQ order + Returns: + p_any: drawn from Unif(0,1) + audit_frames: indices of the frams that were sampled + """ score = self.rng.uniform(0, 1) ix = self.rng.choice(a=len(atoms), @@ -72,6 +72,7 @@ def audit(self, replace=False) return score, ix + class ForceThresholdAuditor(BaseAuditor): """Determines the likelihood all calculations have error below the threshold, based on ensemble variance