Skip to content
This repository has been archived by the owner on Mar 19, 2023. It is now read-only.

Commit

Permalink
Add support for multiple targets.
Browse files Browse the repository at this point in the history
Adds support for multiple targets provided in a list format in the config. Will match all possible targets above the confidence level and return the total count as the state.
  • Loading branch information
shbatm committed Feb 10, 2020
1 parent 9f81cf7 commit 8c54e5d
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 27 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ Configuration variables:
- **save_file_folder**: (Optional) The folder to save processed images to. Note that folder path should be added to [whitelist_external_dirs](https://www.home-assistant.io/docs/configuration/basic/)
- **save_timestamped_file**: (Optional, default `False`, requires `save_file_folder` to be configured) Save the processed image with the time of detection in the filename.
- **source**: Must be a camera.
- **target**: The target object class, default `person`.
- **target**: The target object class, default `person`. Can also be a list of targets.
- **confidence**: (Optional) The confidence (in %) above which detected targets are counted in the sensor state. Default value: 80
- **name**: (Optional) A custom name for the the entity.

Expand Down
57 changes: 31 additions & 26 deletions custom_components/deepstack_object/image_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
CONF_SAVE_TIMESTAMPTED_FILE = "save_timestamped_file"
DATETIME_FORMAT = "%Y-%m-%d_%H:%M:%S"
DEFAULT_API_KEY = ""
DEFAULT_TARGET = "person"
DEFAULT_TARGET = ["person"]
DEFAULT_TIMEOUT = 10
EVENT_OBJECT_DETECTED = "image_processing.object_detected"
EVENT_FILE_SAVED = "image_processing.file_saved"
Expand All @@ -67,7 +67,9 @@
vol.Required(CONF_PORT): cv.port,
vol.Optional(CONF_API_KEY, default=DEFAULT_API_KEY): cv.string,
vol.Optional(CONF_TIMEOUT, default=DEFAULT_TIMEOUT): cv.positive_int,
vol.Optional(CONF_TARGET, default=DEFAULT_TARGET): cv.string,
vol.Optional(CONF_TARGET, default=DEFAULT_TARGET): vol.All(
cv.ensure_list, [cv.string]
),
vol.Optional(CONF_SAVE_FILE_FOLDER): cv.isdir,
vol.Optional(CONF_SAVE_TIMESTAMPTED_FILE, default=False): cv.boolean,
}
Expand All @@ -76,8 +78,9 @@

def get_box(prediction: dict, img_width: int, img_height: int):
"""
Return the relative bounxing box coordinates
defined by the tuple (y_min, x_min, y_max, x_max)
Return the relative bounxing box coordinates.
Defined by the tuple (y_min, x_min, y_max, x_max)
where the coordinates are floats in the range [0.0, 1.0] and
relative to the width and height of the image.
"""
Expand Down Expand Up @@ -145,7 +148,8 @@ def __init__(
camera_name = split_entity_id(camera_entity)[1]
self._name = "deepstack_object_{}".format(camera_name)
self._state = None
self._targets_confidences = []
self._targets_confidences = [None] * len(self._target)
self._targets_found = [0] * len(self._target)
self._predictions = {}
self._summary = {}
self._last_detection = None
Expand All @@ -161,27 +165,30 @@ def process_image(self, image):
io.BytesIO(bytearray(image))
).size
self._state = None
self._targets_confidences = []
self._targets_confidences = [None] * len(self._target)
self._targets_found = [0] * len(self._target)
self._predictions = {}
self._summary = {}
try:
self._dsobject.detect(image)
except ds.DeepstackException as exc:
_LOGGER.error("Depstack error : %s", exc)
_LOGGER.error("Deepstack error : %s", exc)
return

self._predictions = self._dsobject.predictions.copy()

if len(self._predictions) > 0:
raw_confidences = ds.get_object_confidences(self._predictions, self._target)
self._targets_confidences = [
ds.format_confidence(confidence) for confidence in raw_confidences
]
self._state = len(
ds.get_confidences_above_threshold(
self._targets_confidences, self._confidence
if self._predictions:
for i, target in enumerate(self._target):
raw_confidences = ds.get_object_confidences(self._predictions, target)
self._targets_confidences[i] = [
ds.format_confidence(confidence) for confidence in raw_confidences
]
self._targets_found[i] = len(
ds.get_confidences_above_threshold(
self._targets_confidences[i], self._confidence
)
)
)
self._state = sum(self._targets_found)
if self._state > 0:
self._last_detection = dt_util.now().strftime(DATETIME_FORMAT)
self._summary = ds.get_objects_summary(self._predictions)
Expand All @@ -193,14 +200,13 @@ def process_image(self, image):

def save_image(self, image, predictions, target, directory):
"""Save a timestamped image with bounding boxes around targets."""

img = Image.open(io.BytesIO(bytearray(image))).convert("RGB")
draw = ImageDraw.Draw(img)

for prediction in predictions:
prediction_confidence = ds.format_confidence(prediction["confidence"])
if (
prediction["label"] == target
prediction["label"] in target
and prediction_confidence >= self._confidence
):
box = get_box(prediction, self._image_width, self._image_height)
Expand All @@ -213,12 +219,12 @@ def save_image(self, image, predictions, target, directory):
color=RED,
)

latest_save_path = directory + "{}_latest_{}.jpg".format(self._name, target)
latest_save_path = directory + "{}_latest_{}.jpg".format(self._name, target[0])
img.save(latest_save_path)

if self._save_timestamped_file:
timestamp_save_path = directory + "{}_{}_{}.jpg".format(
self._name, target, self._last_detection
self._name, target[0], self._last_detection
)

out_file = open(timestamp_save_path, "wb")
Expand All @@ -231,7 +237,6 @@ def save_image(self, image, predictions, target, directory):

def fire_prediction_events(self, predictions, confidence):
"""Fire events based on predictions if above confidence threshold."""

for prediction in predictions:
if ds.format_confidence(prediction["confidence"]) > confidence:
box = get_box(prediction, self._image_width, self._image_height)
Expand All @@ -246,7 +251,7 @@ def fire_prediction_events(self, predictions, confidence):
)

def fire_saved_file_event(self, save_path):
"""Fire event when saving a file"""
"""Fire event when saving a file."""
self.hass.bus.fire(
EVENT_FILE_SAVED, {ATTR_ENTITY_ID: self.entity_id, FILE: save_path}
)
Expand All @@ -269,16 +274,16 @@ def name(self):
@property
def unit_of_measurement(self):
"""Return the unit of measurement."""
target = self._target
if self._state != None and self._state > 1:
target += "s"
target = self._target if len(self._target) == 1 else "target"
if self._state is not None and self._state > 1:
return target + "s"
return target

@property
def device_state_attributes(self):
"""Return device specific state attributes."""
attr = {}
if self._last_detection:
attr["last_{}_detection".format(self._target)] = self._last_detection
attr["last_detection"] = self._last_detection
attr["summary"] = self._summary
return attr

0 comments on commit 8c54e5d

Please sign in to comment.