diff --git a/casanovo/denovo/model_runner.py b/casanovo/denovo/model_runner.py
index 713760a4..c101af4c 100644
--- a/casanovo/denovo/model_runner.py
+++ b/casanovo/denovo/model_runner.py
@@ -419,7 +419,19 @@ def _get_index(
         Index = AnnotatedSpectrumIndex if annotated else SpectrumIndex
         valid_charge = np.arange(1, self.config.max_charge + 1)
-        return Index(index_fname, filenames, valid_charge=valid_charge)
+
+        try:
+            return Index(index_fname, filenames, valid_charge=valid_charge)
+        except TypeError as e:
+            if Index == AnnotatedSpectrumIndex:
+                raise TypeError(
+                    "Error creating annotated spectrum index. "
+                    "This may be the result of having an unannotated MGF file "
+                    "present in the validation peak file path list.\n"
+                    f"Original error message: {e}"
+                )
+
+            raise e
 
     def _get_strategy(self) -> Union[str, DDPStrategy]:
         """Get the strategy for the Trainer.
 
diff --git a/tests/unit_tests/test_runner.py b/tests/unit_tests/test_runner.py
index c10f035f..db21725a 100644
--- a/tests/unit_tests/test_runner.py
+++ b/tests/unit_tests/test_runner.py
@@ -198,20 +198,21 @@ def test_evaluate(
     assert result_file.is_file()
     result_file.unlink()
 
-    # Test evaluation with unannotated peak files
-    # NOTE: Depth Charge raises a TypeError exception when an unannotated
-    #       peak file is in the peak file list when initializing a
-    #       AnnotatedSpectrumIndex
-    # TODO: Reexamine after Depth Charge v0.4 release
+    exception_string = (
+        "Error creating annotated spectrum index. "
+        "This may be the result of having an unannotated MGF file "
+        "present in the validation peak file path list.\n"
+    )
+
     with pytest.raises(FileNotFoundError):
         with ModelRunner(config, model_filename=str(model_file)) as runner:
             runner.predict([mzml_small], result_file, evaluate=True)
 
-    with pytest.raises(TypeError):
+    with pytest.raises(TypeError, match=exception_string):
         with ModelRunner(config, model_filename=str(model_file)) as runner:
             runner.predict([mgf_small_unannotated], result_file, evaluate=True)
 
-    with pytest.raises(TypeError):
+    with pytest.raises(TypeError, match=exception_string):
         with ModelRunner(config, model_filename=str(model_file)) as runner:
             runner.predict(
                 [mgf_small_unannotated, mzml_small], result_file, evaluate=True
@@ -230,7 +231,7 @@
     assert result_file.is_file()
     result_file.unlink()
 
-    with pytest.raises(TypeError):
+    with pytest.raises(TypeError, match=exception_string):
         with ModelRunner(config, model_filename=str(model_file)) as runner:
             runner.predict(
                 [mgf_small, mgf_small_unannotated], result_file, evaluate=True
@@ -239,7 +240,7 @@
     assert result_file.is_file()
     result_file.unlink()
 
-    with pytest.raises(TypeError):
+    with pytest.raises(TypeError, match=exception_string):
         with ModelRunner(config, model_filename=str(model_file)) as runner:
             runner.predict(
                 [mgf_small, mgf_small_unannotated, mzml_small],
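
Note on the match= argument used in the tests above: pytest.raises(match=...) applies the pattern with re.search, so exception_string is interpreted as a regular expression rather than a literal substring (the "." characters in the message happen to match themselves, so these tests pass either way). Below is a minimal, self-contained sketch of that matching behavior; the helper _raise_index_error is hypothetical and merely stands in for the AnnotatedSpectrumIndex failure path, it is not Casanovo code.

import re

import pytest


def _raise_index_error() -> None:
    # Hypothetical stand-in for AnnotatedSpectrumIndex rejecting an
    # unannotated MGF file; illustration only, not part of Casanovo.
    raise TypeError(
        "Error creating annotated spectrum index. "
        "This may be the result of having an unannotated MGF file "
        "present in the validation peak file path list.\n"
        "Original error message: spectra are missing annotations"
    )


def test_error_message_match() -> None:
    # pytest.raises(match=...) runs re.search against str(excinfo.value),
    # so the pattern is a regex; re.escape keeps the comparison literal
    # when the message contains metacharacters such as ".".
    with pytest.raises(TypeError, match=re.escape("annotated spectrum index.")):
        _raise_index_error()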