Improve examples #3091

Open
wants to merge 15 commits into base: develop
3 changes: 1 addition & 2 deletions .github/workflows/examples.yml
@@ -44,9 +44,8 @@ jobs:
cache: pip
- name: cpuinfo
run: cat /proc/cpuinfo
- name: Install NNCF and test requirements
- name: Install test requirements
run: |
pip install -e .
pip install -r tests/cross_fw/examples/requirements.txt
- name: Print installed modules
run: pip list
24 changes: 10 additions & 14 deletions .gitignore
@@ -119,23 +119,19 @@ nncf_debug/

# NNCF examples
examples/torch/object_detection/eval/
examples/post_training_quantization/onnx/mobilenet_v2/mobilenet_v2_*
examples/post_training_quantization/openvino/mobilenet_v2/mobilenet_v2_*
examples/post_training_quantization/tensorflow/mobilenet_v2/mobilenet_v2_*
examples/post_training_quantization/torch/mobilenet_v2/mobilenet_v2_*
examples/post_training_quantization/torch/ssd300_vgg16/ssd300_vgg16_*
examples/post_training_quantization/openvino/anomaly_stfpm_quantize_with_accuracy_control/stfpm_*
examples/post_training_quantization/openvino/yolov8/yolov8n*
examples/post_training_quantization/openvino/yolov8_quantize_with_accuracy_control/yolov8n*
examples/**/runs/**
examples/**/results/**
examples/llm_compression/openvino/tiny_llama_find_hyperparams/statistics
compressed_graph.dot
original_graph.dot
datasets/**
Collaborator

Why was datasets/** removed?
When running tests locally, a datasets directory may appear and be unintentionally committed.

Collaborator Author

The datasets directory is created by Ultralytics; this only happens if you have a YOLO config file .config/Ultralytics/settings.json that was generated some time ago or changed manually.

With the current version of YOLO, when .config/Ultralytics/settings.json does not exist, the default path for datasets_dir is generated as nncf/../datasets if you run an example from NNCF.
If the datasets directory has already been created on your host, you need to update the config with yolo settings datasets_dir=new_path or a text editor, or remove the config file so that it is regenerated with the new path.
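
For reference, here is a minimal sketch of the text-editor route (assuming the config lives at the path mentioned above; the target directory below is only an illustration):

import json
from pathlib import Path

# Ultralytics settings file on Linux, as discussed above; adjust for your OS.
settings_file = Path.home() / ".config" / "Ultralytics" / "settings.json"
# Any directory outside the NNCF repository will do; this one is just an example.
new_datasets_dir = Path.home() / "datasets"

if settings_file.exists():
    settings = json.loads(settings_file.read_text())
    settings["datasets_dir"] = str(new_datasets_dir)
    settings_file.write_text(json.dumps(settings, indent=2))

After that, subsequent runs download datasets into the new location instead of the repository root.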

Collaborator

I do not agree.
I've run the following command from the nncf root (as usual, for testing purposes):
pytest -s tests/cross_fw/examples/test_examples.py -k "post_training_quantization_openvino_yolo8_quantize"
And after it passes, I observe the following directory:
[screenshot: datasets/ directory created in the nncf root]
Thus, removing datasets/** from .gitignore might cause issues.

Collaborator

From a high-level perspective, it would be more helpful to update the examples so that they create the datasets folder in a corresponding or temporary directory during the run, to avoid such incidents in the future. In that case, we would be able to remove this path from .gitignore and safely run the tests.
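
As an illustration, a sketch of what an example could do at startup (this assumes the settings object exposed by recent ultralytics 8.x releases; the cache path is only an example):

from pathlib import Path

from ultralytics import settings

# Redirect Ultralytics' dataset storage to a per-user cache directory
# before the example triggers any dataset download.
datasets_dir = Path.home() / ".cache" / "nncf" / "datasets"
datasets_dir.mkdir(parents=True, exist_ok=True)
settings.update({"datasets_dir": datasets_dir.as_posix()})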

Collaborator Author

Have you updated the ultralytics config file as described above?

Collaborator

No, I did nothing with the ultralytics package. I've reproduced a simple scenario with the following steps:

  1. Create a clean virtual env;
  2. Clone nncf from GitHub;
  3. Install nncf & pytest;
  4. Run the test;

From a contributor's perspective, I would not even know anything about managing Ultralytics at all.

Collaborator Author

You have a config file ~/.config/Ultralytics/settings.json on your host that was generated the first time you ran any code with Ultralytics, and creating or removing a virtual environment does not affect this config in any way.

Your config contains a line like "datasets_dir": "/home/<username>/nncf/datasets".
This line is only generated if you run NNCF examples with an old version of Ultralytics; the current default value looks like /home/<username>/datasets, i.e. GIT_ROOT.parent / "datasets".

There is only one combination of conditions that produces a datasets_dir like .../nncf/datasets:

  • You have not run any code with Ultralytics before
  • And the first time you do, you run an NNCF YOLO example with an old version of Ultralytics

In all other cases, datasets_dir will not be set to nncf/datasets.

My point: if a config file of some third-party package in your ~/.config directory has non-default values and uses the NNCF folder as storage for anything, that is not NNCF's responsibility; it is an incorrect configuration that should be resolved by the user.

Moreover, a datasets directory in the NNCF git root folder breaks isort by changing the import order of the datasets module.
So keeping datasets in .gitignore hides this problem by making the existence of this directory look expected, when it is not.

@alexsu52, as you added datasets to .gitignore, I think the last word is yours.

Collaborator

it is not NNCF's responsibility; it is an incorrect configuration that should be resolved by the user.

If I run NNCF tests from the root and NNCF creates a datasets folder in the root, this is NNCF's responsibility, not the user's, since it is not possible to configure the location of this folder via the tests.
As you said,

a datasets directory in the NNCF git root folder breaks isort

So this is another reason to change the placement of the datasets folder on NNCF's side, since the tests were designed by the NNCF team.
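
To illustrate the idea, a hypothetical pytest fixture (not the current NNCF test code) that keeps the dataset out of the repository root for the duration of a test, again assuming the ultralytics settings object:

import pytest
from ultralytics import settings


@pytest.fixture(autouse=True)
def isolated_ultralytics_datasets(tmp_path):
    # Point datasets_dir at a temporary directory for this test only,
    # then restore the previous value afterwards.
    original = settings["datasets_dir"]
    settings.update({"datasets_dir": str(tmp_path / "datasets")})
    yield
    settings.update({"datasets_dir": original})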

Contributor

I read your discussion and did the following experiment:

  • remove ~/.config/Ultralytics/settings.json
  • create a clean env using
pip install -r examples/post_training_quantization/onnx/yolov8_quantize_with_accuracy_control/requirements.txt

Ultralytics version:
[screenshot: pip-reported Ultralytics version]

  • run example
python examples/post_training_quantization/onnx/yolov8_quantize_with_accuracy_control/main.py

The dataset was downloaded into nncf/datasets:
[screenshot: datasets/ directory in the nncf root]

Where did I go wrong?

examples/**/*.xml
examples/**/*.bin
examples/**/*.pt
examples/**/*.onnx
examples/**/statistics
examples/**/runs
examples/**/results
examples/**/metrics.json

# Tests
tests/**/runs/**
tests/**/tmp*/**
open_model_zoo/
nncf-tests.xml
compressed_graph.dot
original_graph.dot
37 changes: 21 additions & 16 deletions examples/post_training_quantization/onnx/mobilenet_v2/main.py
@@ -12,7 +12,7 @@
import re
import subprocess
from pathlib import Path
from typing import List, Optional
from typing import List

import numpy as np
import onnx
@@ -23,25 +23,26 @@
from sklearn.metrics import accuracy_score
from torchvision import datasets
from torchvision import transforms
from tqdm import tqdm

import nncf
from nncf.common.logging.track_progress import track

ROOT = Path(__file__).parent.resolve()
MODEL_URL = "https://huggingface.co/alexsu52/mobilenet_v2_imagenette/resolve/main/mobilenet_v2_imagenette.onnx"
DATASET_URL = "https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz"
DATASET_PATH = "~/.cache/nncf/datasets"
MODEL_PATH = "~/.cache/nncf/models"
DATASET_PATH = Path().home() / ".cache" / "nncf" / "datasets"
MODEL_PATH = Path().home() / ".cache" / "nncf" / "models"
DATASET_CLASSES = 10


def download_dataset() -> Path:
downloader = FastDownload(base=DATASET_PATH, archive="downloaded", data="extracted")
downloader = FastDownload(base=DATASET_PATH.as_posix(), archive="downloaded", data="extracted")
return downloader.get(DATASET_URL)


def download_model() -> Path:
return download_url(MODEL_URL, Path(MODEL_PATH).resolve())
MODEL_PATH.mkdir(exist_ok=True, parents=True)
return download_url(MODEL_URL, MODEL_PATH.resolve())


def validate(path_to_model: Path, validation_loader: torch.utils.data.DataLoader) -> float:
@@ -51,7 +52,7 @@ def validate(path_to_model: Path, validation_loader: torch.utils.data.DataLoader
compiled_model = ov.compile_model(path_to_model, device_name="CPU")
output = compiled_model.outputs[0]

for images, target in tqdm(validation_loader):
for images, target in track(validation_loader, description="Validating"):
pred = compiled_model(images)[output]
predictions.append(np.argmax(pred, axis=1))
references.append(target)
@@ -61,13 +62,17 @@ def validate(path_to_model: Path, validation_loader: torch.utils.data.DataLoader
return accuracy_score(predictions, references)


def run_benchmark(path_to_model: Path, shape: Optional[List[int]] = None, verbose: bool = True) -> float:
command = f"benchmark_app -m {path_to_model} -d CPU -api async -t 15"
if shape is not None:
command += f' -shape [{",".join(str(x) for x in shape)}]'
cmd_output = subprocess.check_output(command, shell=True) # nosec
if verbose:
print(*str(cmd_output).split("\\n")[-9:-1], sep="\n")
def run_benchmark(path_to_model: Path, shape: List[int]) -> float:
command = [
"benchmark_app",
"-m", path_to_model.as_posix(),
"-d", "CPU",
"-api", "async",
"-t", "15",
"-shape", str(shape),
] # fmt: skip
cmd_output = subprocess.check_output(command, text=True)
print(*cmd_output.splitlines()[-8:], sep="\n")
match = re.search(r"Throughput\: (.+?) FPS", str(cmd_output))
return float(match.group(1))

@@ -136,9 +141,9 @@ def transform_fn(data_item):
print(f"[2/7] Save INT8 model: {int8_model_path}")

print("[3/7] Benchmark FP32 model:")
fp32_fps = run_benchmark(fp32_model_path, shape=[1, 3, 224, 224], verbose=True)
fp32_fps = run_benchmark(fp32_model_path, shape=[1, 3, 224, 224])
print("[4/7] Benchmark INT8 model:")
int8_fps = run_benchmark(int8_model_path, shape=[1, 3, 224, 224], verbose=True)
int8_fps = run_benchmark(int8_model_path, shape=[1, 3, 224, 224])

print("[5/7] Validate ONNX FP32 model in OpenVINO:")
fp32_top1 = validate(fp32_model_path, val_loader)
@@ -16,15 +16,15 @@

import openvino as ov
import torch
from tqdm import tqdm
from ultralytics.cfg import get_cfg
from ultralytics.engine.validator import BaseValidator as Validator
from ultralytics.models.yolo import YOLO
from ultralytics.models.yolo.segment.val import SegmentationValidator
from ultralytics.utils import DEFAULT_CFG
from ultralytics.utils.metrics import ConfusionMatrix

from examples.post_training_quantization.onnx.yolov8_quantize_with_accuracy_control.main import prepare_validation
from examples.post_training_quantization.onnx.yolov8_quantize_with_accuracy_control.main import print_statistics
from nncf.common.logging.track_progress import track

ROOT = Path(__file__).parent.resolve()
MODEL_NAME = "yolov8n-seg"
@@ -37,7 +37,7 @@
def validate_ov_model(
ov_model: ov.Model,
data_loader: torch.utils.data.DataLoader,
validator: Validator,
validator: SegmentationValidator,
num_samples: Optional[int] = None,
) -> Tuple[Dict, int, int]:
validator.seen = 0
@@ -47,7 +47,7 @@ def validate_ov_model(
validator.confusion_matrix = ConfusionMatrix(nc=validator.nc)
compiled_model = ov.compile_model(ov_model, device_name="CPU")
num_outputs = len(compiled_model.outputs)
for batch_i, batch in enumerate(data_loader):
for batch_i, batch in enumerate(track(data_loader, description="Validating")):
if num_samples is not None and batch_i == num_samples:
break
batch = validator.preprocess(batch)
@@ -65,12 +65,17 @@ def validate_ov_model(
return stats, validator.seen, validator.nt_per_class.sum()


def run_benchmark(model_path: str, config) -> float:
command = f"benchmark_app -m {model_path} -d CPU -api async -t 30"
command += f' -shape "[1,3,{config.imgsz},{config.imgsz}]"'
cmd_output = subprocess.check_output(command, shell=True) # nosec

match = re.search(r"Throughput\: (.+?) FPS", str(cmd_output))
def run_benchmark(model_path: Path, config) -> float:
command = [
"benchmark_app",
"-m", model_path.as_posix(),
"-d", "CPU",
"-api", "async",
"-t", "30",
"-shape", str([1, 3, config.imgsz, config.imgsz]),
] # fmt: skip
cmd_output = subprocess.check_output(command, text=True)
match = re.search(r"Throughput\: (.+?) FPS", cmd_output)
return float(match.group(1))


@@ -96,11 +101,11 @@ def run_benchmark(model_path: str, config) -> float:
validator, data_loader = prepare_validation(YOLO(ROOT / f"{MODEL_NAME}.pt"), args)

print("[5/7] Validate OpenVINO FP32 model:")
fp32_stats, total_images, total_objects = validate_ov_model(fp32_ov_model, tqdm(data_loader), validator)
fp32_stats, total_images, total_objects = validate_ov_model(fp32_ov_model, data_loader, validator)
print_statistics(fp32_stats, total_images, total_objects)

print("[6/7] Validate OpenVINO INT8 model:")
int8_stats, total_images, total_objects = validate_ov_model(int8_ov_model, tqdm(data_loader), validator)
int8_stats, total_images, total_objects = validate_ov_model(int8_ov_model, data_loader, validator)
print_statistics(int8_stats, total_images, total_objects)

print("[7/7] Report:")
@@ -13,28 +13,32 @@
from pathlib import Path
from typing import Any, Dict, Tuple

import numpy as np
import onnx
import onnxruntime
import torch
from tqdm import tqdm
from ultralytics.cfg import get_cfg
from ultralytics.data.converter import coco80_to_coco91_class
from ultralytics.data.utils import check_det_dataset
from ultralytics.engine.validator import BaseValidator as Validator
from ultralytics.models.yolo import YOLO
from ultralytics.models.yolo.segment.val import SegmentationValidator
from ultralytics.utils import DATASETS_DIR
from ultralytics.utils import DEFAULT_CFG
from ultralytics.utils import ops
from ultralytics.utils.metrics import ConfusionMatrix

import nncf
from nncf.common.logging.track_progress import track

MODEL_NAME = "yolov8n-seg"

ROOT = Path(__file__).parent.resolve()


def validate(
model: onnx.ModelProto, data_loader: torch.utils.data.DataLoader, validator: Validator, num_samples: int = None
model: onnx.ModelProto,
data_loader: torch.utils.data.DataLoader,
validator: SegmentationValidator,
num_samples: int = None,
) -> Tuple[Dict, int, int]:
validator.seen = 0
validator.jdict = []
@@ -49,7 +53,7 @@ def validate(
output_names = [output.name for output in session.get_outputs()]
num_outputs = len(output_names)

for batch_i, batch in enumerate(data_loader):
for batch_i, batch in enumerate(track(data_loader, description="Validating")):
if num_samples is not None and batch_i == num_samples:
break
batch = validator.preprocess(batch)
@@ -71,7 +75,7 @@ def validate(
return stats, validator.seen, validator.nt_per_class.sum()


def print_statistics(stats: np.ndarray, total_images: int, total_objects: int) -> None:
def print_statistics(stats: Dict[str, float], total_images: int, total_objects: int) -> None:
print("Metrics(Box):")
mp, mr, map50, mean_ap = (
stats["metrics/precision(B)"],
@@ -84,38 +88,35 @@ def print_statistics(stats: np.ndarray, total_images: int, total_objects: int) -
pf = "%20s" + "%12i" * 2 + "%12.3g" * 4 # print format
print(pf % ("all", total_images, total_objects, mp, mr, map50, mean_ap))

# print the mask metrics for segmentation
if "metrics/precision(M)" in stats:
print("Metrics(Mask):")
s_mp, s_mr, s_map50, s_mean_ap = (
stats["metrics/precision(M)"],
stats["metrics/recall(M)"],
stats["metrics/mAP50(M)"],
stats["metrics/mAP50-95(M)"],
)
# Print results
s = ("%20s" + "%12s" * 6) % ("Class", "Images", "Labels", "Precision", "Recall", "[email protected]", "[email protected]:.95")
print(s)
pf = "%20s" + "%12i" * 2 + "%12.3g" * 4 # print format
print(pf % ("all", total_images, total_objects, s_mp, s_mr, s_map50, s_mean_ap))


def prepare_validation(model: YOLO, args: Any) -> Tuple[Validator, torch.utils.data.DataLoader]:
validator = model.task_map[model.task]["validator"](args=args)
validator.data = check_det_dataset(args.data)
validator.stride = 32
print("Metrics(Mask):")
s_mp, s_mr, s_map50, s_mean_ap = (
stats["metrics/precision(M)"],
stats["metrics/recall(M)"],
stats["metrics/mAP50(M)"],
stats["metrics/mAP50-95(M)"],
)
# Print results
s = ("%20s" + "%12s" * 6) % ("Class", "Images", "Labels", "Precision", "Recall", "[email protected]", "[email protected]:.95")
print(s)
pf = "%20s" + "%12i" * 2 + "%12.3g" * 4 # print format
print(pf % ("all", total_images, total_objects, s_mp, s_mr, s_map50, s_mean_ap))

data_loader = validator.get_dataloader(f"{DATASETS_DIR}/coco128-seg", 1)

def prepare_validation(model: YOLO, args: Any) -> Tuple[SegmentationValidator, torch.utils.data.DataLoader]:
validator: SegmentationValidator = model.task_map[model.task]["validator"](args=args)
validator.data = check_det_dataset(args.data)
validator.stride = 32
validator.is_coco = True
validator.class_map = coco80_to_coco91_class()
validator.names = model.model.names
validator.metrics.names = validator.names
validator.nc = model.model.model[-1].nc
validator.nm = 32
validator.process = ops.process_mask
validator.plot_masks = []

coco_data_path = DATASETS_DIR / "coco128-seg"
data_loader = validator.get_dataloader(coco_data_path.as_posix(), 1)

return validator, data_loader


@@ -129,7 +130,7 @@ def prepare_onnx_model(model: YOLO, model_name: str) -> Tuple[onnx.ModelProto, P


def quantize_ac(
model: onnx.ModelProto, data_loader: torch.utils.data.DataLoader, validator_ac: Validator
model: onnx.ModelProto, data_loader: torch.utils.data.DataLoader, validator_ac: SegmentationValidator
) -> onnx.ModelProto:
input_name = model.graph.input[0].name

@@ -140,7 +141,7 @@ def transform_fn(data_item: Dict):
def validation_ac(
val_model: onnx.ModelProto,
validation_loader: torch.utils.data.DataLoader,
validator: Validator,
validator: SegmentationValidator,
num_samples: int = None,
) -> float:
validator.seen = 0
@@ -155,7 +156,6 @@ def validation_ac(
output_names = [output.name for output in session.get_outputs()]
num_outputs = len(output_names)

counter = 0
for batch_i, batch in enumerate(validation_loader):
if num_samples is not None and batch_i == num_samples:
break
@@ -172,13 +172,12 @@
]
preds = validator.postprocess(preds)
validator.update_metrics(preds, batch)
counter += 1

stats = validator.get_stats()
if num_outputs == 1:
stats_metrics = stats["metrics/mAP50-95(B)"]
else:
stats_metrics = stats["metrics/mAP50-95(M)"]
print(f"Validate: dataset length = {counter}, metric value = {stats_metrics:.3f}")
return stats_metrics, None

quantization_dataset = nncf.Dataset(data_loader, transform_fn)
@@ -213,8 +212,6 @@ def validation_ac(


def run_example():
MODEL_NAME = "yolov8n-seg"

model = YOLO(ROOT / f"{MODEL_NAME}.pt")
args = get_cfg(cfg=DEFAULT_CFG)
args.data = "coco128-seg.yaml"
@@ -231,11 +228,11 @@ def run_example():
print(f"[2/5] Save INT8 model: {int8_model_path}")

print("[3/5] Validate ONNX FP32 model:")
fp_stats, total_images, total_objects = validate(fp32_model, tqdm(data_loader), validator)
fp_stats, total_images, total_objects = validate(fp32_model, data_loader, validator)
print_statistics(fp_stats, total_images, total_objects)

print("[4/5] Validate ONNX INT8 model:")
q_stats, total_images, total_objects = validate(int8_model, tqdm(data_loader), validator)
q_stats, total_images, total_objects = validate(int8_model, data_loader, validator)
print_statistics(q_stats, total_images, total_objects)

print("[5/5] Report:")
@@ -1,3 +1,4 @@
ultralytics==8.3.22
onnx==1.17.0
onnxruntime==1.19.2
openvino==2024.5