Skip to content

Commit

Permalink
Autoformat with black
Browse files Browse the repository at this point in the history
  • Loading branch information
brainless-bot[bot] committed Aug 20, 2024
1 parent c56f2ee commit 6aaf04a
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 17 deletions.
12 changes: 9 additions & 3 deletions examples/example_spine_semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@

def main():
with cProfile.Profile() as pr:
result, intermediate_steps_data = evaluator.evaluate(prediction_mask, reference_mask)["ungrouped"]
result, intermediate_steps_data = evaluator.evaluate(
prediction_mask, reference_mask
)["ungrouped"]

# To print the results, just call print
print(result)
Expand All @@ -35,8 +37,12 @@ def main():
intermediate_steps_data.original_prediction_arr # Input prediction array, untouched
intermediate_steps_data.original_reference_arr # Input reference array, untouched

intermediate_steps_data.prediction_arr(InputType.MATCHED_INSTANCE) # Prediction array after instances have been matched
intermediate_steps_data.reference_arr(InputType.MATCHED_INSTANCE) # Reference array after instances have been matched
intermediate_steps_data.prediction_arr(
InputType.MATCHED_INSTANCE
) # Prediction array after instances have been matched
intermediate_steps_data.reference_arr(
InputType.MATCHED_INSTANCE
) # Reference array after instances have been matched

pr.dump_stats(directory + "/semantic_example.log")
return result, intermediate_steps_data
Expand Down
20 changes: 16 additions & 4 deletions panoptica/panoptica_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,11 +226,19 @@ def panoptic_evaluate(
# Crops away unecessary space of zeroes
input_pair.crop_data()

processing_pair: SemanticPair | UnmatchedInstancePair | MatchedInstancePair | EvaluateInstancePair | PanopticaResult = input_pair.copy()
processing_pair: (
SemanticPair
| UnmatchedInstancePair
| MatchedInstancePair
| EvaluateInstancePair
| PanopticaResult
) = input_pair.copy()

# First Phase: Instance Approximation
if isinstance(processing_pair, SemanticPair):
intermediate_steps_data.add_intermediate_arr_data(processing_pair.copy(), InputType.SEMANTIC)
intermediate_steps_data.add_intermediate_arr_data(
processing_pair.copy(), InputType.SEMANTIC
)
assert (
instance_approximator is not None
), "Got SemanticPair but not InstanceApproximator"
Expand All @@ -243,7 +251,9 @@ def panoptic_evaluate(

# Second Phase: Instance Matching
if isinstance(processing_pair, UnmatchedInstancePair):
intermediate_steps_data.add_intermediate_arr_data(processing_pair.copy(), InputType.UNMATCHED_INSTANCE)
intermediate_steps_data.add_intermediate_arr_data(
processing_pair.copy(), InputType.UNMATCHED_INSTANCE
)
processing_pair = _handle_zero_instances_cases(
processing_pair,
eval_metrics=instance_metrics,
Expand All @@ -266,7 +276,9 @@ def panoptic_evaluate(

# Third Phase: Instance Evaluation
if isinstance(processing_pair, MatchedInstancePair):
intermediate_steps_data.add_intermediate_arr_data(processing_pair.copy(), InputType.MATCHED_INSTANCE)
intermediate_steps_data.add_intermediate_arr_data(
processing_pair.copy(), InputType.MATCHED_INSTANCE
)
processing_pair = _handle_zero_instances_cases(
processing_pair,
eval_metrics=instance_metrics,
Expand Down
24 changes: 18 additions & 6 deletions panoptica/utils/processing_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,9 @@ def __init__(self, original_input: _ProcessingPair | None):
self._original_input = original_input
self._intermediatesteps: dict[str, _ProcessingPair] = {}

def add_intermediate_arr_data(self, processing_pair: _ProcessingPair, inputtype: InputType):
def add_intermediate_arr_data(
self, processing_pair: _ProcessingPair, inputtype: InputType
):
type_name = inputtype.name
self.add_intermediate_data(type_name, processing_pair)

Expand All @@ -353,26 +355,36 @@ def add_intermediate_data(self, key, value):

@property
def original_prediction_arr(self):
assert self._original_input is not None, "Original prediction_arr is None, there are no intermediate steps"
assert (
self._original_input is not None
), "Original prediction_arr is None, there are no intermediate steps"
return self._original_input.prediction_arr

@property
def original_reference_arr(self):
assert self._original_input is not None, "Original reference_arr is None, there are no intermediate steps"
assert (
self._original_input is not None
), "Original reference_arr is None, there are no intermediate steps"
return self._original_input.reference_arr

def prediction_arr(self, inputtype: InputType):
type_name = inputtype.name
procpair = self[type_name]
assert isinstance(procpair, _ProcessingPair), f"step {type_name} is not a processing pair, error"
assert isinstance(
procpair, _ProcessingPair
), f"step {type_name} is not a processing pair, error"
return procpair.prediction_arr

def reference_arr(self, inputtype: InputType):
type_name = inputtype.name
procpair = self[type_name]
assert isinstance(procpair, _ProcessingPair), f"step {type_name} is not a processing pair, error"
assert isinstance(
procpair, _ProcessingPair
), f"step {type_name} is not a processing pair, error"
return procpair.reference_arr

def __getitem__(self, key):
assert key in self._intermediatesteps, f"key {key} not in intermediate steps, maybe the step was skipped?"
assert (
key in self._intermediatesteps
), f"key {key} not in intermediate steps, maybe the step was skipped?"
return self._intermediatesteps[key]
10 changes: 8 additions & 2 deletions unit_tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@
MetricZeroTPEdgeCaseHandling,
EdgeCaseHandler,
)
from panoptica import ConnectedComponentsInstanceApproximator, NaiveThresholdMatching, Panoptica_Evaluator
from panoptica import (
ConnectedComponentsInstanceApproximator,
NaiveThresholdMatching,
Panoptica_Evaluator,
)
from pathlib import Path
import numpy as np
import random
Expand Down Expand Up @@ -105,7 +109,9 @@ def test_SegmentationClassGroups_config_by_name(self):

configname = "test_file.yaml"
t.save_to_config_by_name(configname)
d: SegmentationClassGroups = SegmentationClassGroups.load_from_config_name(configname)
d: SegmentationClassGroups = SegmentationClassGroups.load_from_config_name(
configname
)

testfile_d = config_by_name(configname)
os.remove(testfile_d)
Expand Down
10 changes: 8 additions & 2 deletions unit_tests/test_datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
MetricMode,
MetricCouldNotBeComputedException,
)
from panoptica.utils.edge_case_handling import EdgeCaseResult, EdgeCaseHandler, MetricZeroTPEdgeCaseHandling
from panoptica.utils.edge_case_handling import (
EdgeCaseResult,
EdgeCaseHandler,
MetricZeroTPEdgeCaseHandling,
)


class Test_EdgeCaseHandler(unittest.TestCase):
Expand All @@ -24,7 +28,9 @@ def test_edgecasehandler_simple(self):

print()
# print(handler.get_metric_zero_tp_handle(ListMetric.IOU))
r = handler.handle_zero_tp(Metric.IOU, tp=0, num_pred_instances=1, num_ref_instances=1)
r = handler.handle_zero_tp(
Metric.IOU, tp=0, num_pred_instances=1, num_ref_instances=1
)
print(r)

iou_test = MetricZeroTPEdgeCaseHandling(
Expand Down

0 comments on commit 6aaf04a

Please sign in to comment.