Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ml classifier #1963

Merged
merged 9 commits into from
Jun 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ctapipe/ml/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class ParticleIdClassifier(ClassificationReconstructor):
def __call__(self, event: ArrayEventContainer) -> None:
for tel_id in event.trigger.tels_with_trigger:
features = self._collect_features(event, tel_id)
prediction, valid = self.model.predict(
prediction, valid = self.model.predict_score(
self.subarray.tel[tel_id],
features,
)
Expand Down
26 changes: 24 additions & 2 deletions ctapipe/ml/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
from ctapipe.utils.datasets import resource_file
from ctapipe.core import run_tool

@pytest.fixture(scope='session')

@pytest.fixture(scope="session")
def model_tmp_path(tmp_path_factory):
return tmp_path_factory.mktemp("models")


@pytest.fixture(scope='session')
@pytest.fixture(scope="session")
def energy_regressor_path(model_tmp_path):
from ctapipe.ml.tools.train_energy_regressor import TrainEnergyRegressor

Expand All @@ -25,3 +26,24 @@ def energy_regressor_path(model_tmp_path):
)
assert ret == 0
return out_file


@pytest.fixture(scope="session")
def particle_classifier_path(model_tmp_path):
from ctapipe.ml.tools.train_particle_classifier import TrainParticleIdClassifier

tool = TrainParticleIdClassifier()
config = resource_file("ml-config.yaml")
out_file = model_tmp_path / "particle_classifier.pkl"
ret = run_tool(
tool,
argv=[
"--input-background=dataset://proton_dl2_train_small.dl2.h5",
"--input-signal=dataset://gamma_diffuse_dl2_train_small.dl2.h5",
f"--output={out_file}",
f"--config={config}",
"--log-level=INFO",
],
)
assert ret == 0
return out_file
3 changes: 2 additions & 1 deletion ctapipe/ml/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .train_energy_regressor import TrainEnergyRegressor
from .train_particle_classifier import TrainParticleIdClassifier

__all__ = ["TrainEnergyRegressor"]
__all__ = ["TrainEnergyRegressor", "TrainParticleIdClassifier"]
145 changes: 145 additions & 0 deletions ctapipe/ml/tools/apply_particle_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from astropy.table.operations import vstack
import tables
from ctapipe.core.tool import Tool
from ctapipe.core.traits import Bool, Path, flag, create_class_enum_trait
from ctapipe.io import TableLoader, write_table
from tqdm.auto import tqdm
from ctapipe.io.tableio import TelListToMaskTransform
import numpy as np

from ..sklearn import Classifier
from ..apply import ParticleIdClassifier
from ..stereo_combination import StereoCombiner


class ApplyParticleIdClassifier(Tool):

overwrite = Bool(default_value=False).tag(config=True)

input_url = Path(
default_value=None,
allow_none=False,
directory_ok=False,
exists=True,
).tag(config=True)

model_path = Path(
default_value=None, allow_none=False, exists=True, directory_ok=False
).tag(config=True)

stereo_combiner_type = create_class_enum_trait(
base_class=StereoCombiner, default_value="StereoMeanCombiner"
).tag(config=True)

aliases = {
("i", "input"): "ApplyParticleIdClassifier.input_url",
("m", "model"): "ApplyParticleIdClassifier.model_path",
}

flags = {
**flag(
"overwrite",
"ApplyParticleIdClassifier.overwrite",
"Overwrite tables in output file if it exists",
"Don't overwrite tables in output file if it exists",
),
"f": (
{"ApplyParticleIdClassifier": {"overwrite": True}},
"Overwrite output file if it exists",
),
}

classes = [
TableLoader,
Classifier,
StereoCombiner,
]

def setup(self):
""""""
self.h5file = tables.open_file(self.input_url, mode="r+")
self.loader = TableLoader(
parent=self,
h5file=self.h5file,
load_dl1_images=False,
load_dl1_parameters=True,
load_dl2=True,
load_simulated=True,
load_instrument=True,
)
self.estimator = ParticleIdClassifier.read(
self.model_path,
self.loader.subarray,
parent=self,
)
self.combine = StereoCombiner.from_name(
self.stereo_combiner_type,
combine_property="classification",
algorithm=self.estimator.model.model_cls,
parent=self,
)

def start(self):
self.log.info("Applying model")

tables = []
for tel_id, tel in tqdm(self.loader.subarray.tel.items()):
if tel not in self.estimator.model.models:
self.log.warning(
"No model for telescope type %s, skipping tel %d",
tel,
tel_id,
)
continue

table = self.loader.read_telescope_events([tel_id])
if len(table) == 0:
self.log.warning("No events for telescope %d", tel_id)
continue

prediction, valid = self.estimator.predict(tel, table)
prefix = self.estimator.model.model_cls

class_col = f"{prefix}_prediction"
valid_col = f"{prefix}_is_valid"
table[class_col] = prediction
table[valid_col] = valid

write_table(
table[["obs_id", "event_id", "tel_id", class_col, valid_col]],
self.loader.input_url,
f"/dl2/event/telescope/classification/{prefix}/tel_{tel_id:03d}",
mode="a",
overwrite=self.overwrite,
)
tables.append(table)

if len(tables) == 0:
raise ValueError("No predictions made for any telescope")

mono_predictions = vstack(tables)
stereo_predictions = self.combine.predict(mono_predictions)
trafo = TelListToMaskTransform(self.loader.subarray)
for c in filter(
lambda c: c.name.endswith("tel_ids"), stereo_predictions.columns.values()
):
stereo_predictions[c.name] = np.array([trafo(r) for r in c])

write_table(
stereo_predictions,
self.loader.input_url,
f"/dl2/event/subarray/classification/{self.estimator.model.model_cls}",
mode="a",
overwrite=self.overwrite,
)

def finish(self):
self.h5file.close()


def main():
ApplyParticleIdClassifier().run()


if __name__ == "__main__":
main()
45 changes: 40 additions & 5 deletions ctapipe/ml/tools/tests/test_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,13 @@
import shutil


def test_apply_energy_regressor(
energy_regressor_path, dl2_shower_geometry_file, tmp_path
):
def test_apply_energy_regressor(energy_regressor_path, dl1_parameters_file, tmp_path):
from ctapipe.ml.tools.apply_energy_regressor import ApplyEnergyRegressor

input_path = tmp_path / dl2_shower_geometry_file.name
input_path = tmp_path / dl1_parameters_file.name

# create copy to not mutate common test file
shutil.copy2(dl2_shower_geometry_file, input_path)
shutil.copy2(dl1_parameters_file, input_path)

ret = run_tool(
ApplyEnergyRegressor(),
Expand All @@ -38,3 +36,40 @@ def test_apply_energy_regressor(

assert "ExtraTreesRegressor_energy_mono" in events.colnames
assert "ExtraTreesRegressor_is_valid_mono" in events.colnames


def test_apply_particle_classifier(
particle_classifier_path, dl1_parameters_file, tmp_path
):
from ctapipe.ml.tools.apply_particle_classifier import ApplyParticleIdClassifier

input_path = tmp_path / dl1_parameters_file.name

# create copy to not mutate common test file
shutil.copy2(dl1_parameters_file, input_path)

ret = run_tool(
ApplyParticleIdClassifier(),
argv=[
f"--input={input_path}",
f"--model={particle_classifier_path}",
"--ApplyParticleIdClassifier.StereoMeanCombiner.weights=konrad",
],
)
assert ret == 0

loader = TableLoader(input_path, load_dl2=True)
events = loader.read_subarray_events()
assert "ExtraTreesClassifier_prediction" in events.colnames
assert "ExtraTreesClassifier_tel_ids" in events.colnames
assert "ExtraTreesClassifier_is_valid" in events.colnames
assert "ExtraTreesClassifier_goodness_of_fit" in events.colnames

events = loader.read_telescope_events()
assert "ExtraTreesClassifier_prediction" in events.colnames
assert "ExtraTreesClassifier_tel_ids" in events.colnames
assert "ExtraTreesClassifier_is_valid" in events.colnames
assert "ExtraTreesClassifier_goodness_of_fit" in events.colnames

assert "ExtraTreesClassifier_prediction_mono" in events.colnames
assert "ExtraTreesClassifier_is_valid_mono" in events.colnames
56 changes: 54 additions & 2 deletions ctapipe/ml/tools/tests/test_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,64 @@ def test_process_apply_energy(tmp_path, energy_regressor_path):
f"--input={input_url}",
f"--output={output}",
"--write-images",
"--write-stereo-shower",
"--write-mono-shower",
"--write-showers",
f"--energy-regressor={energy_regressor_path}",
f"--config={config_path}",
]
assert run_tool(ProcessorTool(), argv=argv, cwd=tmp_path) == 0

print(read_table(output, "/dl2/event/telescope/energy/ExtraTreesRegressor/tel_004"))
print(read_table(output, "/dl2/event/subarray/energy/ExtraTreesRegressor"))


def test_process_apply_classification(tmp_path, particle_classifier_path):
from ctapipe.tools.process import ProcessorTool
from ctapipe.io import SimTelEventSource

output = tmp_path / "gamma_prod5.dl2_energy.h5"

config_path = tmp_path / "config.json"

input_url = "dataset://gamma_prod5.simtel.zst"

with SimTelEventSource(input_url) as s:
subarray = s.subarray

allowed_tels = subarray.get_tel_ids_for_type(
"LST_LST_LSTCam"
) + subarray.get_tel_ids_for_type("MST_MST_NectarCam")

config = {
"ProcessorTool": {
"EventSource": {
"allowed_tels": allowed_tels,
},
"stereo_combiner_configs": [
{
"type": "StereoMeanCombiner",
"combine_property": "classification",
"algorithm": "ExtraTreesClassifier",
}
],
}
}

with config_path.open("w") as f:
json.dump(config, f)

argv = [
f"--input={input_url}",
f"--output={output}",
"--write-images",
"--write-showers",
f"--particle-classifier={particle_classifier_path}",
f"--config={config_path}",
]
assert run_tool(ProcessorTool(), argv=argv, cwd=tmp_path) == 0

print(
read_table(
output, "/dl2/event/telescope/classification/ExtraTreesClassifier/tel_004"
)
)
print(read_table(output, "/dl2/event/subarray/classification/ExtraTreesClassifier"))
7 changes: 7 additions & 0 deletions ctapipe/ml/tools/tests/test_train.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
def test_train_energy_regressor(energy_regressor_path):
from ctapipe.ml.sklearn import Regressor

Regressor.load(energy_regressor_path)


def test_train_particle_classifier(particle_classifier_path):
from ctapipe.ml.sklearn import Classifier

Classifier.load(particle_classifier_path)
Loading