Skip to content

Commit

Permalink
merge: 132 add a new dataset creator processing
Browse files Browse the repository at this point in the history
  • Loading branch information
picsalex authored Jul 5, 2024
1 parent 284cd0d commit 4742453
Show file tree
Hide file tree
Showing 216 changed files with 14,967 additions and 1,574 deletions.
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,9 @@ yolox-detection-test/*

# MacOS files
.DS_Store

# Engine logs folder
/logs

# Ruff
/.ruff_cache
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ repos:
hooks:
- id: mypy
args:
- --non-interactive
- --install-types
- --check-untyped-defs
- --ignore-missing-imports
exclude: ^tests/
Expand Down
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2020 Picsell.ia
Copyright (c) 2020 Picsell.ia

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
10 changes: 7 additions & 3 deletions ViT-classification/experiment/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import logging
import os
from transformers import TrainerCallback

from trainer import VitClassificationTrainer
from transformers import TrainerCallback

os.environ["PICSELLIA_SDK_CUSTOM_LOGGING"] = "True"
os.environ["PICSELLIA_SDK_DOWNLOAD_BAR_MODE"] = "2"
Expand All @@ -11,8 +12,11 @@
class LogMetricsCallback(TrainerCallback):
def on_log(self, args, state, control, logs=None, **kwargs):
if state.is_local_process_zero:
filtered_logs = {metric_name: float(value) for metric_name, value in logs.items() if
metric_name != "total_flos"}
filtered_logs = {
metric_name: float(value)
for metric_name, value in logs.items()
if metric_name != "total_flos"
}
for metric_name, value in filtered_logs.items():
if metric_name in [
"train_loss",
Expand Down
20 changes: 10 additions & 10 deletions ViT-classification/experiment/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,29 @@
from picsellia.exceptions import ResourceNotFoundError
from picsellia.sdk.dataset_version import DatasetVersion
from picsellia.types.enums import InferenceType
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor
from torchvision.transforms import Compose, Normalize, RandomResizedCrop, ToTensor
from transformers import (
pipeline,
AutoImageProcessor,
DefaultDataCollator,
AutoModelForImageClassification,
TrainingArguments,
DefaultDataCollator,
Trainer,
TrainingArguments,
pipeline,
)

from abstract_trainer.trainer import AbstractTrainer
from utils import (
get_train_test_eval_datasets_from_experiment,
prepare_datasets_with_annotation,
split_single_dataset,
_move_all_files_in_class_directories,
compute_metrics,
get_asset_filename_from_path,
find_asset_by_filename,
get_asset_filename_from_path,
get_predicted_label_confidence,
get_train_test_eval_datasets_from_experiment,
log_labelmap,
prepare_datasets_with_annotation,
split_single_dataset,
)

from abstract_trainer.trainer import AbstractTrainer


class VitClassificationTrainer(AbstractTrainer):
checkpoint = "google/vit-base-patch16-224-in21k"
Expand Down
2 changes: 1 addition & 1 deletion ViT-classification/experiment/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def move_image(filename: str, old_location_path: str, new_location_path: str) ->
new_path = os.path.join(new_location_path, filename)
try:
shutil.move(old_path, new_path)
except Exception as e:
except Exception:
logging.info(f"{filename} skipped.")


Expand Down
2 changes: 1 addition & 1 deletion ViT-classification/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ pycocotools==2.0.6
torchvision==0.15.2
scikit-learn==1.3.0
optimum==1.14.1
onnxruntime==1.16.3
onnxruntime==1.16.3
55 changes: 24 additions & 31 deletions ViT-detection/experiment/helpers.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,31 @@
import os
import picsellia

from picsellia.sdk.experiment import Experiment
import picsellia
import transformers
from datasets import DatasetDict, load_dataset
from picsellia.sdk.dataset import DatasetVersion
from utils.picsellia import download_data, evaluate_asset
from picsellia.sdk.experiment import Experiment
from picsellia.types.enums import InferenceType
from datasets import load_dataset, DatasetDict

import transformers
from transformers import (
AutoModelForObjectDetection,
TrainingArguments,
AutoImageProcessor,
AutoModelForObjectDetection,
Trainer,
TrainingArguments,
)

from utils.picsellia import download_data, evaluate_asset
from utils.vit import (
CocoDetection,
run_evaluation,
get_filenames_by_ids,
transform_images_and_annotations,
custom_train_test_eval_split,
get_id2label_mapping,
collate_fn,
save_annotation_file_images,
custom_train_test_eval_split,
format_and_write_annotations,
format_evaluation_results,
get_dataset_image_ids,
format_and_write_annotations,
get_filenames_by_ids,
get_id2label_mapping,
log_labelmap,
run_evaluation,
save_annotation_file_images,
transform_images_and_annotations,
)


Expand Down Expand Up @@ -66,9 +64,7 @@ def get_experiment() -> Experiment:

experiment = client.get_experiment_by_id(experiment_id)
else:
raise Exception(
"You must set the experiment_id"
)
raise Exception("You must set the experiment_id")
return experiment

def prepare_data_for_training(
Expand Down Expand Up @@ -131,18 +127,15 @@ def test(
self,
train_test_valid_dataset: DatasetDict,
) -> tuple[transformers.models, transformers.models]:
image_processor = AutoImageProcessor.from_pretrained(
self.output_model_dir)
model = AutoModelForObjectDetection.from_pretrained(
self.output_model_dir)
image_processor = AutoImageProcessor.from_pretrained(self.output_model_dir)
model = AutoModelForObjectDetection.from_pretrained(self.output_model_dir)

path_output, path_anno = save_annotation_file_images(
dataset=train_test_valid_dataset["test"],
experiment=self.experiment,
id2label=self.id2label,
)
test_ds_coco_format = CocoDetection(
path_output, image_processor, path_anno)
test_ds_coco_format = CocoDetection(path_output, image_processor, path_anno)

results = run_evaluation(
test_ds_coco_format=test_ds_coco_format,
Expand All @@ -161,13 +154,13 @@ def evaluate(
train_test_valid_dataset: DatasetDict,
model: transformers.models,
):
eval_image_ids = get_dataset_image_ids(
train_test_valid_dataset, "eval")
eval_image_ids = get_dataset_image_ids(train_test_valid_dataset, "eval")
id2filename_eval = get_filenames_by_ids(
image_ids=eval_image_ids, annotations=self.annotations, id_list=eval_image_ids
image_ids=eval_image_ids,
annotations=self.annotations,
id_list=eval_image_ids,
)
image_processor = AutoImageProcessor.from_pretrained(
self.output_model_dir)
image_processor = AutoImageProcessor.from_pretrained(self.output_model_dir)

for file_path in list(id2filename_eval.values()):
evaluate_asset(
Expand All @@ -176,7 +169,7 @@ def evaluate(
experiment=self.experiment,
dataset=dataset,
model=model,
image_processor=image_processor
image_processor=image_processor,
)

self.experiment.compute_evaluations_metrics(
Expand Down
3 changes: 1 addition & 2 deletions ViT-detection/experiment/main.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import logging
import os

from transformers import TrainerCallback

from helpers import TrainingPipeline
from transformers import TrainerCallback

os.environ["PICSELLIA_SDK_CUSTOM_LOGGING"] = "True"
os.environ["PICSELLIA_SDK_DOWNLOAD_BAR_MODE"] = "2"
Expand Down
19 changes: 7 additions & 12 deletions ViT-detection/experiment/tests.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,22 @@
import unittest
import uuid
from datasets import DatasetDict

from picsellia.sdk.label import Label
from torch import tensor
from utils.picsellia import (
get_filename_from_fullpath,
create_rectangle_list,
get_filename_from_fullpath,
reformat_box_to_coco,
)
from utils.vit import (
read_annotation_file,
format_coco_annot_to_jsonlines_format,
format_evaluation_results,
get_id2label_mapping,
get_category_names,
write_metadata_file,
formatted_annotations,
get_category_names,
get_id2label_mapping,
read_annotation_file,
)

from torch import tensor
from picsellia.sdk.label import Label
from transformers import AutoImageProcessor, AutoModelForObjectDetection
from picsellia.sdk.connexion import Connexion


class TestDetectionVit(unittest.TestCase):
checkpoint = "facebook/detr-resnet-50"
Expand Down Expand Up @@ -245,5 +240,5 @@ def test_formatted_annotations(self):
self.assertEqual(formatted_annotation_list, expected_results)


if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
20 changes: 12 additions & 8 deletions ViT-detection/experiment/utils/picsellia.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from picsellia.sdk.dataset import DatasetVersion
import os

import torch
import transformers
from picsellia.sdk.asset import Asset
from picsellia.sdk.dataset import DatasetVersion
from picsellia.sdk.experiment import Experiment
from picsellia.sdk.label import Label

import os
import torch
from PIL import Image
import transformers


def download_data(experiment: Experiment) -> DatasetVersion:
Expand All @@ -23,13 +23,17 @@ def evaluate_asset(
experiment: Experiment,
model: transformers.models,
image_processor: transformers.models,
dataset: DatasetVersion
dataset: DatasetVersion,
):
dataset_labels = {label.name: label for label in dataset.list_labels()}
image_path = os.path.join(data_dir, file_path)
asset = find_asset_from_path(image_path=image_path, dataset=dataset)
results = predict_image(image_path=image_path, threshold=0.4,
model=model, image_processor=image_processor)
results = predict_image(
image_path=image_path,
threshold=0.4,
model=model,
image_processor=image_processor,
)
rectangle_list = create_rectangle_list(
results, dataset_labels, model.config.id2label
)
Expand Down
Loading

0 comments on commit 4742453

Please sign in to comment.