Skip to content

Commit

Permalink
feat: better raw processing and seathru
Browse files Browse the repository at this point in the history
  • Loading branch information
ccrutchf committed Nov 14, 2024
1 parent 8c6414f commit e4b9c89
Show file tree
Hide file tree
Showing 4 changed files with 1,064 additions and 177 deletions.
258 changes: 180 additions & 78 deletions demo/pipeline.ipynb

Large diffs are not rendered by default.

77 changes: 67 additions & 10 deletions fishsense_lite/commands/label_studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
import torch
from fishsense_common.pluggable_cli import Command, argument
from pyfishsensedev.calibration import LaserCalibration, LensCalibration
from pyfishsensedev.depth_map import DepthAnythingDepthMap, LaserDepthMap
from pyfishsensedev.image import ColorCorrection, ImageRectifier, RawProcessor
from pyfishsensedev.image.image_processors import RawProcessor
from pyfishsensedev.image.image_processors.raw_processor_old import RawProcessorOld
from pyfishsensedev.image.image_rectifier import ImageRectifier
from pyfishsensedev.laser.nn_laser_detector import NNLaserDetector
from pyfishsensedev.segmentation.fish.fish_segmentation_fishial_pytorch import (
Expand Down Expand Up @@ -46,9 +49,11 @@ def execute_nn_laser(
if output_file.exists() and json_file.exists() and not overwrite:
return

dark_raw_processor = RawProcessor(enable_histogram_equalization=False)
dark_raw_processor = RawProcessorOld(
input_file, enable_histogram_equalization=False
)
try:
image_dark = uint16_2_uint8(dark_raw_processor.load_and_process(input_file))
image_dark = uint16_2_uint8(next(dark_raw_processor.__iter__()))
except:
return

Expand Down Expand Up @@ -80,7 +85,13 @@ def execute_nn_laser(

@ray.remote(vram_mb=768)
def execute_fishial(
input_file: Path, root: Path, output: Path, prefix: str, overwrite: bool
input_file: Path,
lens_calibration: LensCalibration,
estimated_laser_calibration: LaserCalibration,
root: Path,
output: Path,
prefix: str,
overwrite: bool,
):
device = "cuda" if torch.cuda.is_available() else "cpu"
output_file = get_output_file(input_file, root, output, "jpg")
Expand All @@ -93,17 +104,46 @@ def execute_fishial(
if output_file.exists() and json_file.exists() and not overwrite:
return

raw_processor = RawProcessor(enable_histogram_equalization=True)
raw_processor = RawProcessor(input_file)
dark_raw_processor = RawProcessorOld(
input_file, enable_histogram_equalization=False
)
try:
image = uint16_2_uint8(raw_processor.load_and_process(input_file))
img = next(raw_processor.__iter__())
img_dark = uint16_2_uint8(next(dark_raw_processor.__iter__()))
except:
return

image_rectifier = ImageRectifier(lens_calibration)
img = image_rectifier.rectify(img)
img_dark = image_rectifier.rectify(img_dark)

laser_detector = NNLaserDetector(
lens_calibration, estimated_laser_calibration, device
)
laser_coords = laser_detector.find_laser(img_dark)

ml_depth_map = DepthAnythingDepthMap(img, device)

if laser_coords:
laser_coords_int = np.round(laser_coords).astype(int)
depth_map = LaserDepthMap(
laser_coords, lens_calibration, estimated_laser_calibration
)
scale = (
depth_map.depth_map[0, 0]
/ ml_depth_map.depth_map[laser_coords_int[1], laser_coords_int[0]]
)
ml_depth_map.rescale(scale)

color_correction = ColorCorrection()
img = color_correction.correct_color(img, ml_depth_map)

fish_segmentation_inference = FishSegmentationFishialPyTorch(device)
segmentations: np.ndarray = fish_segmentation_inference.inference(image)
segmentations: np.ndarray = fish_segmentation_inference.inference(img)

output_file.parent.mkdir(parents=True, exist_ok=True)
cv2.imwrite(output_file.absolute().as_posix(), image)
cv2.imwrite(output_file.absolute().as_posix(), img)

debug_output = (segmentations.astype(float) / segmentations.max() * 255).astype(
np.uint8
Expand Down Expand Up @@ -251,7 +291,9 @@ def __call__(self):
files, lens_calibration, estimated_laser_calibration, root, output
)

self.__build_fishial_json(files, output, root)
self.__build_fishial_json(
files, lens_calibration, estimated_laser_calibration, output, root
)

def __build_nn_laser_json(
self,
Expand Down Expand Up @@ -286,7 +328,14 @@ def __build_nn_laser_json(

list(self.tqdm(futures, total=len(files)))

def __build_fishial_json(self, files: List[Path], output: Path, root: Path):
def __build_fishial_json(
self,
files: List[Path],
lens_calibration: LensCalibration,
estimated_laser_calibration: LaserCalibration,
output: Path,
root: Path,
):
fish_segmentation = FishSegmentationFishialPyTorch("cpu")

output = (
Expand All @@ -296,7 +345,15 @@ def __build_fishial_json(self, files: List[Path], output: Path, root: Path):
output.mkdir(parents=True, exist_ok=True)

futures = [
execute_fishial.remote(f, root, output, self.prefix, self.overwrite)
execute_fishial.remote(
f,
lens_calibration,
estimated_laser_calibration,
root,
output,
self.prefix,
self.overwrite,
)
for f in files
]

Expand Down
Loading

0 comments on commit e4b9c89

Please sign in to comment.