Skip to content

Commit

Permalink
t
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderDokuchaev committed Nov 15, 2024
1 parent 63fdf19 commit c93b019
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
import nncf
from nncf.common.logging.track_progress import track

MODEL_NAME = "yolov8n-seg"

ROOT = Path(__file__).parent.resolve()
DATASET_PATH = Path().home() / ".cache" / "nncf" / "datasets"

Expand Down Expand Up @@ -110,6 +112,10 @@ def prepare_validation(model: YOLO, args: Any) -> Tuple[SegmentationValidator, t
validator: SegmentationValidator = model.task_map[model.task]["validator"](args=args)
validator.data = check_det_dataset(args.data)
validator.stride = 32

coco_data_path = DATASET_PATH / "coco128-seg"
data_loader = validator.get_dataloader(coco_data_path.as_posix(), 1)

validator.is_coco = True
validator.class_map = coco80_to_coco91_class()
validator.names = model.model.names
Expand All @@ -118,9 +124,6 @@ def prepare_validation(model: YOLO, args: Any) -> Tuple[SegmentationValidator, t
validator.process = ops.process_mask
validator.plot_masks = []

coco_data_path = DATASET_PATH / "coco128-seg"
data_loader = validator.get_dataloader(coco_data_path.as_posix(), 1)

return validator, data_loader


Expand Down Expand Up @@ -218,8 +221,6 @@ def validation_ac(


def run_example():
MODEL_NAME = "yolov8n-seg"

model = YOLO(ROOT / f"{MODEL_NAME}.pt")
args = get_cfg(cfg=DEFAULT_CFG)
args.data = "coco128-seg.yaml"
Expand Down
7 changes: 4 additions & 3 deletions examples/post_training_quantization/openvino/yolov8/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,16 @@ def prepare_validation(model: YOLO, args: Any) -> Tuple[DetectionValidator, torc
validator: DetectionValidator = model.task_map[model.task]["validator"](args=args)
validator.data = check_det_dataset(args.data)
validator.stride = 32

coco_data_path = DATASET_PATH / "coco128"
data_loader = validator.get_dataloader(coco_data_path.as_posix(), 1)

validator.is_coco = True
validator.class_map = coco80_to_coco91_class()
validator.names = model.model.names
validator.metrics.names = validator.names
validator.nc = model.model.model[-1].nc

coco_data_path = DATASET_PATH / "coco128"
data_loader = validator.get_dataloader(coco_data_path.as_posix(), 1)

return validator, data_loader


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ def prepare_validation(model: YOLO, args: Any) -> Tuple[SegmentationValidator, t
validator: SegmentationValidator = model.task_map[model.task]["validator"](args=args)
validator.data = check_det_dataset(args.data)
validator.stride = 32

coco_data_path = DATASET_PATH / "coco128-seg"
data_loader = validator.get_dataloader(coco_data_path.as_posix(), 1)

validator.is_coco = True
validator.class_map = coco80_to_coco91_class()
validator.names = model.model.names
Expand All @@ -108,9 +112,6 @@ def prepare_validation(model: YOLO, args: Any) -> Tuple[SegmentationValidator, t
validator.process = ops.process_mask
validator.plot_masks = []

coco_data_path = DATASET_PATH / "coco128-seg"
data_loader = validator.get_dataloader(coco_data_path.as_posix(), 1)

return validator, data_loader


Expand Down

0 comments on commit c93b019

Please sign in to comment.