Skip to content

Commit

Permalink
feat: generate fishial masks
Browse files Browse the repository at this point in the history
  • Loading branch information
ccrutchf committed Nov 12, 2024
1 parent 9655568 commit 531ab0c
Show file tree
Hide file tree
Showing 5 changed files with 488 additions and 129 deletions.
137 changes: 76 additions & 61 deletions fishsense_lite/commands/label_studio.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
"""Module which represents the FishSense Lite Label Studio CLI."""

import importlib
import importlib.metadata
import json
import random
import string
from glob import glob
from pathlib import Path
from typing import List, Tuple
Expand All @@ -17,60 +16,21 @@
from pyfishsensedev.image.image_processors import RawProcessor
from pyfishsensedev.image.image_rectifier import ImageRectifier
from pyfishsensedev.laser.nn_laser_detector import NNLaserDetector

from pyfishsensedev.segmentation.fish.fish_segmentation_fishial_pytorch import (
FishSegmentationFishialPyTorch,
)

from fishsense_lite.commands.label_studio_models.laser_label_studio_json import (
LaserLabelStudioJSON,
)
from fishsense_lite.commands.label_studio_models.segmentation_label_studio_json import (
SegmentationLabelStudioJSON,
)
from fishsense_lite.utils import get_output_file, get_root, uint16_2_uint8


class Data:
def __init__(self, img: str):
self.img = img


class LaserValue:
def __init__(self, x: float, y: float, width: int, height: int):
self.x = x / float(width) * 100
self.y = y / float(height) * 100
self.width = 0.25
self.keypointlabels = ["Red Laser"]


class LaserResult:
def __init__(self, laser_image_coord: np.ndarray, width: int, height: int):
self.original_width = width
self.original_height = height
self.image_rotation = 0
self.value = LaserValue(
laser_image_coord[0], laser_image_coord[1], width, height
)

letters_and_numbers = string.ascii_letters + string.digits

self.id = "".join(random.choice(letters_and_numbers) for i in range(10))
self.from_name = "kp-1"
self.to_name = "img-1"
self.type = "keypointlabels"


class LaserPrediction:
def __init__(self, laser_image_coord: np.ndarray, width: int, height: int):
self.model_version = importlib.metadata.version("fishsense_lite")
self.result = [LaserResult(laser_image_coord, width, height)]


class LaserLabelStudioJSON:
def __init__(
self, img: str, laser_image_coord: np.ndarray, width: int, height: int
):
self.data = Data(img)
self.predictions = (
[LaserPrediction(laser_image_coord, width, height)]
if laser_image_coord is not None
else []
)


@ray.remote(vram_mb=1536)
def execute_laser(
@ray.remote(vram_mb=768)
def execute_nn_laser(
input_file: Path,
lens_calibration: LensCalibration,
estimated_laser_calibration: LaserCalibration,
Expand Down Expand Up @@ -110,6 +70,40 @@ def execute_laser(
laser_image_coord,
width,
height,
laser_detector.name,
)

with open(json_file, "w") as f:
f.write(json.dumps(json_objects, default=vars))


@ray.remote(vram_mb=768)
def execute_fishial(
input_file: Path, root: Path, output: Path, prefix: str, overwrite: bool
):
device = "cuda" if torch.cuda.is_available() else "cpu"
output_file = get_output_file(input_file, root, output, "jpg")
json_file = output_file.with_suffix(".json")

if output_file.exists() and json_file.exists() and not overwrite:
return

raw_processor = RawProcessor(enable_histogram_equalization=True)
try:
image = uint16_2_uint8(raw_processor.load_and_process(input_file))
except:
return

fish_segmentation_inference = FishSegmentationFishialPyTorch(device)
segmentations: np.ndarray = fish_segmentation_inference.inference(image)

output_file.parent.mkdir(parents=True, exist_ok=True)
cv2.imwrite(output_file.absolute().as_posix(), image)

json_objects = SegmentationLabelStudioJSON(
f"{prefix}{output_file.relative_to(output.absolute()).as_posix()}",
segmentations,
fish_segmentation_inference.name,
)

with open(json_file, "w") as f:
Expand Down Expand Up @@ -242,27 +236,32 @@ def __call__(self):

output = Path(self.output_path)

self.__build_laser_json(
self.__build_nn_laser_json(
files, lens_calibration, estimated_laser_calibration, root, output
)

def __build_laser_json(
self.__build_fishial_json(files, output)

def __build_nn_laser_json(
self,
files: List[Path],
lens_calibration: LensCalibration,
estimated_laser_calibration: LaserCalibration,
root: Path,
output: Path,
):
output = output / "laser"
output.mkdir(parents=True, exist_ok=True)
laser_detector = NNLaserDetector(
lens_calibration, estimated_laser_calibration, "cpu"
)

laser_json_path = output / "label_studio.json"
if laser_json_path.exists() and not self.overwrite:
return
output = (
output
/ f"{laser_detector.name}.{importlib.metadata.version("pyfishsensedev")}"
)
output.mkdir(parents=True, exist_ok=True)

futures = [
execute_laser.remote(
execute_nn_laser.remote(
f,
lens_calibration,
estimated_laser_calibration,
Expand All @@ -275,3 +274,19 @@ def __build_laser_json(
]

list(self.tqdm(futures, total=len(files)))

def __build_fishial_json(self, files: List[Path], output: Path, root: Path):
fish_segmentation = FishSegmentationFishialPyTorch("cpu")

output = (
output
/ f"{fish_segmentation.name}.{importlib.metadata.version("pyfishsensedev")}"
)
output.mkdir(parents=True, exist_ok=True)

futures = [
execute_fishial.remote(f, root, output, self.prefix, self.overwrite)
for f in files
]

list(self.tqdm(futures, total=len(files)))
3 changes: 3 additions & 0 deletions fishsense_lite/commands/label_studio_models/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
class Data:
def __init__(self, img: str):
self.img = img
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import importlib
import random
import string

import numpy as np

from fishsense_lite.commands.label_studio_models.data import Data


class LaserValue:
def __init__(self, x: float, y: float, width: int, height: int):
self.x = x / float(width) * 100
self.y = y / float(height) * 100
self.width = 0.25
self.keypointlabels = ["Red Laser"]


class LaserResult:
def __init__(self, laser_image_coord: np.ndarray, width: int, height: int):
self.original_width = width
self.original_height = height
self.image_rotation = 0
self.value = LaserValue(
laser_image_coord[0], laser_image_coord[1], width, height
)

letters_and_numbers = string.ascii_letters + string.digits

self.id = "".join(random.choice(letters_and_numbers) for _ in range(10))
self.from_name = "kp-1"
self.to_name = "img-1"
self.type = "keypointlabels"


class LaserPrediction:
def __init__(
self, laser_image_coord: np.ndarray, width: int, height: int, model_name: str
):
self.model_version = (
f"{model_name}.{importlib.metadata.version("pyfishsensedev")}"
)
self.result = [LaserResult(laser_image_coord, width, height)]


class LaserLabelStudioJSON:
def __init__(
self,
img: str,
laser_image_coord: np.ndarray,
width: int,
height: int,
model_name: str,
):
self.data = Data(img)
self.predictions = (
[LaserPrediction(laser_image_coord, width, height, model_name)]
if laser_image_coord is not None
else []
)
Loading

0 comments on commit 531ab0c

Please sign in to comment.