diff --git a/deepforest/dataset.py b/deepforest/dataset.py
index e723c154..fe3f3900 100644
--- a/deepforest/dataset.py
+++ b/deepforest/dataset.py
@@ -24,6 +24,8 @@ from PIL import Image
 import rasterio as rio
 
 from deepforest import preprocess
+from rasterio.windows import Window
+from torchvision import transforms
 
 
 def get_transform(augment):
@@ -178,3 +180,65 @@ def __getitem__(self, idx):
             crop = preprocess.preprocess_image(crop)
 
         return crop
+
+
+def bounding_box_transform(augment=False):
+    data_transforms = []
+    data_transforms.append(transforms.ToTensor())
+    data_transforms.append(resnet_normalize)
+    data_transforms.append(transforms.Resize([224, 224]))
+    if augment:
+        data_transforms.append(transforms.RandomHorizontalFlip(0.5))
+    return transforms.Compose(data_transforms)
+
+
+resnet_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                        std=[0.229, 0.224, 0.225])
+
+
+class BoundingBoxDataset(Dataset):
+    """An in-memory dataset for bounding box predictions.
+
+    Args:
+        df: a pandas dataframe with image_path and xmin, xmax, ymin, ymax columns
+        transform: a function to apply to the image
+        root_dir: the directory where the image is stored
+
+    Returns:
+        rgb: a tensor of shape (3, height, width)
+    """
+
+    def __init__(self, df, root_dir, transform=None, augment=False):
+        self.df = df
+
+        if transform is None:
+            self.transform = bounding_box_transform(augment=augment)
+        else:
+            self.transform = transform
+
+        unique_image = self.df['image_path'].unique()
+        assert len(unique_image
+                  ) == 1, "There should be only one unique image for this class object"
+
+        # Open the image using rasterio
+        self.src = rio.open(os.path.join(root_dir, unique_image[0]))
+
+    def __len__(self):
+        return len(self.df)
+
+    def __getitem__(self, idx):
+        row = self.df.iloc[idx]
+        xmin = row['xmin']
+        xmax = row['xmax']
+        ymin = row['ymin']
+        ymax = row['ymax']
+
+        # Read the RGB data inside the box window (col_off, row_off, width, height)
+        box = self.src.read(window=Window(xmin, ymin, xmax - xmin, ymax - ymin))
+        box = np.rollaxis(box, 0, 3)
+
+        if self.transform:
+            image = self.transform(box)
+        else:
+            image = box
+
+        return image
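The dataset above is the in-memory path used when a crop model is attached at prediction time. A minimal sketch of constructing it directly, assuming the single-image `OSBS_029` sample data that ships with DeepForest (this mirrors `test_BoundingBoxDataset` below):

```python
import os
import pandas as pd
from deepforest import get_data
from deepforest.dataset import BoundingBoxDataset

# The class asserts a single unique image_path per dataset object
df = pd.read_csv(get_data("OSBS_029.csv"))
ds = BoundingBoxDataset(df, root_dir=os.path.dirname(get_data("OSBS_029.png")))
crop = ds[0]  # a (3, 224, 224) tensor after the default resnet transform
```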
diff --git a/deepforest/main.py b/deepforest/main.py
index a7238f38..c86d4ace 100644
--- a/deepforest/main.py
+++ b/deepforest/main.py
@@ -429,7 +429,10 @@ def predict_tile(self,
                      sigma=0.5,
                      thresh=0.001,
                      color=None,
-                     thickness=1):
+                     thickness=1,
+                     crop_model=None,
+                     crop_transform=None,
+                     crop_augment=False):
         """For images too large to input into the model, predict_tile cuts the
         image into overlapping windows, predicts trees on each window and
         reassembles into a single array.
@@ -449,6 +452,9 @@ def predict_tile(self,
             thresh: the score thresh used to filter bboxes after soft-nms performed
             color: color of the bounding box as a tuple of BGR color, e.g. orange annotations is (0, 165, 255)
             thickness: thickness of the rectangle border line in px
+            crop_model: a deepforest.model.CropModel object to predict on crops
+            crop_transform: a torchvision.transforms object to apply to crops
+            crop_augment: a boolean to apply augmentations to crops
 
         Returns:
             boxes (array): if return_plot, an image.
@@ -521,6 +527,15 @@ def predict_tile(self,
 
             return list(zip(results, self.crops))
 
+        if crop_model:
+            # If a crop model is provided, classify each predicted box
+            results = predict._predict_crop_model_(crop_model=crop_model,
+                                                   results=results,
+                                                   raster_path=raster_path,
+                                                   trainer=self.trainer,
+                                                   transform=crop_transform,
+                                                   augment=crop_augment)
+
         return results
 
     def training_step(self, batch, batch_idx):
diff --git a/deepforest/model.py b/deepforest/model.py
index bde76c0a..3faa0582 100644
--- a/deepforest/model.py
+++ b/deepforest/model.py
@@ -1,6 +1,19 @@
 # Model - common class
 from deepforest.models import *
 import torch
+import os
+import cv2
+import numpy as np
+import rasterio
+import torchmetrics
+import torch.nn.functional as F
+from pytorch_lightning import LightningModule, Trainer
+from torchvision import models, transforms
+from torchvision.datasets import ImageFolder
 
 
 class Model():
@@ -49,3 +62,197 @@ def check_model(self):
         model_keys = list(predictions[1].keys())
         model_keys.sort()
         assert model_keys == ['boxes', 'labels', 'scores']
+
+
+def simple_resnet_50(num_classes=2):
+    m = models.resnet50(weights="DEFAULT")
+    num_ftrs = m.fc.in_features
+    m.fc = torch.nn.Linear(num_ftrs, num_classes)
+
+    return m
+
+
+class CropModel(LightningModule):
+    """A pytorch lightning module for classifying image crops from object detections."""
+
+    def __init__(self, num_classes=2, batch_size=4, num_workers=0, lr=0.0001, model=None):
+        super().__init__()
+
+        # Model
+        self.num_classes = num_classes
+        if model is None:
+            self.model = simple_resnet_50(num_classes=num_classes)
+        else:
+            self.model = model
+
+        # Metrics
+        self.accuracy = torchmetrics.Accuracy(average='none',
+                                              num_classes=num_classes,
+                                              task="multiclass")
+        self.total_accuracy = torchmetrics.Accuracy(num_classes=num_classes,
+                                                    task="multiclass")
+        self.precision_metric = torchmetrics.Precision(num_classes=num_classes,
+                                                       task="multiclass")
+        self.metrics = torchmetrics.MetricCollection({
+            "Class Accuracy": self.accuracy,
+            "Accuracy": self.total_accuracy,
+            "Precision": self.precision_metric
+        })
+
+        # Training Hyperparameters
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.lr = lr
+
+    def create_trainer(self, **kwargs):
+        """Create a pytorch lightning trainer object"""
+        self.trainer = Trainer(**kwargs)
+
+    def load_from_disk(self, train_dir, val_dir):
+        self.train_ds = ImageFolder(root=train_dir,
+                                    transform=self.get_transform(augment=True))
+        self.val_ds = ImageFolder(root=val_dir,
+                                  transform=self.get_transform(augment=False))
+
+    def get_transform(self, augment):
+        """Return the data transformation pipeline for the model.
+
+        Args:
+            augment (bool): Flag indicating whether to apply data augmentation.
+
+        Returns:
+            torchvision.transforms.Compose: The composed data transformation pipeline.
+        """
+        data_transforms = []
+        data_transforms.append(transforms.ToTensor())
+        data_transforms.append(self.normalize())
+        data_transforms.append(transforms.Resize([224, 224]))
+        if augment:
+            data_transforms.append(transforms.RandomHorizontalFlip(0.5))
+        return transforms.Compose(data_transforms)
+
+    def write_crops(self, root_dir, images, boxes, labels, savedir):
+        """Write crops to disk.
+
+        Args:
+            root_dir (str): The root directory where the images are located.
+            images (list): A list of image filenames.
+            boxes (list): A list of bounding box coordinates in the format [xmin, ymin, xmax, ymax].
+            labels (list): A list of labels corresponding to each bounding box.
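A hedged sketch of the new `predict_tile` arguments, assuming the release detector weights and the packaged `SOAP_061.png` raster; the untrained `CropModel` here is purely for illustration:

```python
from deepforest import main, model, get_data

m = main.deepforest()
m.use_release()
m.create_trainer()  # the crop classification step reuses self.trainer

crop_model = model.CropModel(num_classes=2)
result = m.predict_tile(raster_path=get_data("SOAP_061.png"),
                        patch_size=400,
                        mosaic=True,
                        crop_model=crop_model)

# Detection columns plus the two new crop model columns
print(result[["label", "cropmodel_label", "cropmodel_score"]].head())
```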
+            savedir (str): The directory where the cropped images will be saved.
+
+        Returns:
+            None
+        """
+
+        # Create a directory for each label
+        for label in labels:
+            os.makedirs(os.path.join(savedir, label), exist_ok=True)
+
+        # Use rasterio to read and crop each image
+        for index, box in enumerate(boxes):
+            xmin, ymin, xmax, ymax = box
+            label = labels[index]
+            image = images[index]
+            with rasterio.open(os.path.join(root_dir, image)) as src:
+                # Crop the image using the bounding box coordinates
+                img = src.read(window=((ymin, ymax), (xmin, xmax)))
+                # Save the cropped image as a PNG file using opencv
+                img_path = os.path.join(savedir, label, f"crop_{index}.png")
+                img = np.rollaxis(img, 0, 3)
+                cv2.imwrite(img_path, img)
+
+    def normalize(self):
+        return transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                    std=[0.229, 0.224, 0.225])
+
+    def forward(self, x):
+        output = self.model(x)
+        output = torch.sigmoid(output)
+
+        return output
+
+    def train_dataloader(self):
+        train_loader = torch.utils.data.DataLoader(self.train_ds,
+                                                   batch_size=self.batch_size,
+                                                   shuffle=True,
+                                                   num_workers=self.num_workers)
+
+        return train_loader
+
+    def predict_dataloader(self, ds):
+        loader = torch.utils.data.DataLoader(ds,
+                                             batch_size=self.batch_size,
+                                             shuffle=False,
+                                             num_workers=self.num_workers)
+
+        return loader
+
+    def val_dataloader(self):
+        val_loader = torch.utils.data.DataLoader(self.val_ds,
+                                                 batch_size=self.batch_size,
+                                                 shuffle=False,
+                                                 num_workers=self.num_workers)
+
+        return val_loader
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        outputs = self.forward(x)
+        loss = F.cross_entropy(outputs, y)
+        self.log("train_loss", loss)
+
+        return loss
+
+    def predict_step(self, batch, batch_idx):
+        outputs = self.forward(batch)
+        yhat = F.softmax(outputs, 1)
+
+        return yhat
+
+    def validation_step(self, batch, batch_idx):
+        x, y = batch
+        outputs = self(x)
+        loss = F.cross_entropy(outputs, y)
+        self.log("val_loss", loss)
+        metric_dict = self.metrics(outputs, y)
+        for key, value in metric_dict.items():
+            if isinstance(value, torch.Tensor) and value.numel() > 1:
+                # Per-class metrics log one value per class
+                for i, v in enumerate(value):
+                    self.log(f"{key}_{i}", v, on_step=False, on_epoch=True)
+            else:
+                self.log(key, value, on_step=False, on_epoch=True)
+
+        return loss
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
+        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
+                                                               mode='min',
+                                                               factor=0.5,
+                                                               patience=10,
+                                                               verbose=True,
+                                                               threshold=0.0001,
+                                                               threshold_mode='rel',
+                                                               cooldown=0,
+                                                               min_lr=0,
+                                                               eps=1e-08)
+
+        # Monitor the validation loss to schedule the learning rate
+        return {'optimizer': optimizer, 'lr_scheduler': scheduler, "monitor": 'val_loss'}
+
+    def dataset_confusion(self, loader):
+        """Create one-hot true labels and predicted class scores from a data loader"""
+        true_class = []
+        predicted_class = []
+        self.eval()
+        for batch in loader:
+            x, y = batch
+            true_class.append(F.one_hot(y, num_classes=self.num_classes).detach().numpy())
+            prediction = self(x)
+            predicted_class.append(prediction.detach().numpy())
+
+        true_class = np.concatenate(true_class)
+        predicted_class = np.concatenate(predicted_class)
+
+        return true_class, predicted_class
diff --git a/deepforest/predict.py b/deepforest/predict.py
index fb92308e..1f486766 100644
--- a/deepforest/predict.py
+++ b/deepforest/predict.py
@@ -10,7 +10,8 @@
 from torchvision.ops import nms
 import typing
 
-from deepforest import visualize
+from deepforest import visualize, dataset
+import os
 
 
 def _predict_image_(model,
@@ -188,3 +189,38 @@ def _dataloader_wrapper_(model,
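Note that `dataset_confusion` returns one-hot ground truth and per-class scores rather than a finished matrix. A short sketch of reducing them to counts, assuming a `crop_model` trained and loaded from disk as in `test_crop_model_train` below:

```python
import numpy as np

true_class, predicted_class = crop_model.dataset_confusion(crop_model.val_dataloader())

# Rows are observed classes, columns are predicted classes
confusion = np.zeros((crop_model.num_classes, crop_model.num_classes), dtype=int)
for t, p in zip(true_class.argmax(axis=1), predicted_class.argmax(axis=1)):
    confusion[t, p] += 1
print(confusion)
```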
                               thickness=thickness)
 
     return results
+
+
+def _predict_crop_model_(crop_model,
+                         trainer,
+                         results,
+                         raster_path,
+                         transform=None,
+                         augment=False):
+    """Predict a crop model on each detection in a raster file.
+
+    Args:
+        crop_model: The crop model to be used for prediction.
+        trainer: The pytorch lightning trainer object for prediction.
+        results: The results dataframe to store the predicted labels and scores.
+        raster_path: The path to the raster file.
+        transform: An optional torchvision.transforms object to apply to each crop.
+        augment: A boolean to apply augmentations to crops.
+
+    Returns:
+        The updated results dataframe with predicted labels and scores.
+    """
+    bounding_box_dataset = dataset.BoundingBoxDataset(
+        results,
+        root_dir=os.path.dirname(raster_path),
+        transform=transform,
+        augment=augment)
+    crop_dataloader = crop_model.predict_dataloader(bounding_box_dataset)
+    crop_results = trainer.predict(crop_model, crop_dataloader)
+    stacked_outputs = np.vstack(np.concatenate(crop_results))
+    label = np.argmax(stacked_outputs, 1)
+    score = np.max(stacked_outputs, 1)
+
+    results["cropmodel_label"] = label
+    results["cropmodel_score"] = score
+
+    return results
diff --git a/docs/CropModel.md b/docs/CropModel.md
new file mode 100644
index 00000000..70284a14
--- /dev/null
+++ b/docs/CropModel.md
@@ -0,0 +1,118 @@
+# The CropModel: Classifying objects after object detection
+
+One of the most requested features since the early days of DeepForest has been the ability to apply a follow-up model to predicted bounding boxes. For example, we may use the 'tree' or 'bird' backbone but want to classify each of those detections with a custom model, without retraining the upstream detector. Beginning in version 1.4.0, the CropModel class can be used in conjunction with the predict_tile and predict_image methods. The general workflow is: the object detection model is applied first, the predicted locations are extracted into image crops, optionally saved to disk, and a second model is applied to each crop. New columns, 'cropmodel_label' and 'cropmodel_score', will appear alongside the object detection model's label and score.
+
+## Benefits
+
+Why would you want to apply a model directly to each crop? Why not train a multi-class object detection model? That is certainly a reasonable approach, but a CropModel has a few benefits for common use cases.
+
+* Object detection models require that all objects of a particular class are annotated within an image. This is often impossible for detailed category labels. For example, you might have bounding boxes for all 'trees' in an image, but species or health labels for only a small portion of them based on ground surveys. Training a multi-class object detection model would invariably mean training on only a portion of your available data.
+
+* CropModels are simpler and more extendable. By decoupling the detection and classification workflows, you can separately handle challenges like class imbalance and incomplete labels without reducing the quality of the detections. We have found training two-stage object detection models to be finicky for many similar classes, and it requires reasonable knowledge of managing learning rates.
+
+* New data and multi-sensor learning. For many applications, the data needed for detection and classification may differ. The CropModel concept provides an extendable piece that allows others to build more advanced pipelines.
+
+## Considerations
+
+* Using a CropModel will be slower, since for each detection the sensor data needs to be cropped and passed to the classifier. This is less efficient than a combined classification/detection system like the multi-class detection models. With modern GPUs this often matters less, but it is something to be mindful of.
+
+* The model knows only about the pixels that exist inside the crop and cannot use features outside the bounding box. This lack of spatial awareness is a major limitation. It is possible, but untested, that a multi-class detection model is better at this kind of task. Designing a genuine box attention mechanism is probably a better long-term approach (https://arxiv.org/abs/2111.13087).
+
+## Use
+
+Consider a testfile with tree boxes and an 'Alive/Dead' label that comes with all DeepForest installations.
+
+```
+import os
+import pandas as pd
+import torch
+from deepforest import get_data, model
+
+df = pd.read_csv(get_data("testfile_multi.csv"))
+crop_model = model.CropModel(num_classes=2)
+```
+
+This is a pytorch-lightning object and can be used like any other model.
+
+```
+# Test forward pass
+x = torch.rand(4, 3, 224, 224)
+output = crop_model.forward(x)
+assert output.shape == (4, 2)
+```
+
+The only difference is that we no longer have boxes; we are classifying entire crops. We can do this in memory, or by writing a set of crops to disk. Let's start by writing to disk.
+
+```
+boxes = df[['xmin', 'ymin', 'xmax', 'ymax']].values.tolist()
+root_dir = os.path.dirname(get_data("SOAP_061.png"))
+crop_model.write_crops(boxes=boxes,
+                       labels=df.label.values,
+                       root_dir=root_dir,
+                       images=df.image_path.values,
+                       savedir=tmpdir)
+```
+
+This crops each box location and saves it in a folder named for its label. Now we have two folders in the savedir location, an 'Alive' and a 'Dead' folder.
+
+## Training
+
+We could train a new model from here in typical pytorch-lightning syntax.
+
+```
+crop_model.create_trainer(fast_dev_run=True)
+# Get the data stored from the write_crops step above.
+crop_model.load_from_disk(train_dir=tmpdir, val_dir=tmpdir)
+crop_model.trainer.fit(crop_model)
+crop_model.trainer.validate(crop_model)
+```
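+
+Because the CropModel is a standard pytorch-lightning module, checkpointing follows the usual lightning pattern. A minimal sketch, assuming the trainer above and a hypothetical checkpoint path; pass any non-default constructor arguments to load_from_checkpoint, since the class does not save its hyperparameters.
+
+```
+crop_model.trainer.save_checkpoint("cropmodel.ckpt")
+crop_model = model.CropModel.load_from_checkpoint("cropmodel.ckpt", num_classes=2)
+```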
+
+## Customizing
+
+The CropModel makes very few assumptions about the architecture and simply provides a container to make predictions at each detection. To specify a custom model, use the model argument.
+
+```
+from deepforest.model import CropModel
+from torchvision.models import resnet101
+backbone = resnet101(weights='DEFAULT')
+crop_model = CropModel(num_classes=2, model=backbone)
+```
+
+One detail is that the preprocessing transform will differ among backbones. Make sure to check the final lines of
+
+```
+print(crop_model.get_transform(augment=True))
+...
+...
+)>
+    Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=None)
+    RandomHorizontalFlip(p=0.5)
+)
+```
+
+to see the torchvision transforms.Compose statement. You can overwrite get_transform if needed, for example for the torchvision ImageFolder reader when reading existing images.
+
+```
+from types import MethodType
+from torchvision import transforms
+
+def custom_transform(self, augment):
+    data_transforms = []
+    data_transforms.append(transforms.ToTensor())
+    data_transforms.append(self.normalize())
+    data_transforms.append(transforms.Resize([224, 224]))
+    if augment:
+        data_transforms.append(transforms.RandomHorizontalFlip(0.5))
+    return transforms.Compose(data_transforms)
+
+# Bind the function so that `self` is passed automatically
+crop_model.get_transform = MethodType(custom_transform, crop_model)
+```
+
+Or, when predicting on in-memory crops, you can pass a composed transform and the augment flag to the predict methods.
+
+```
+m.predict_tile(..., crop_transform=crop_model.get_transform(augment=False), crop_augment=False)
+```
+
+This allows full flexibility over the preprocessing steps. For further customization, you can subclass the CropModel object and change methods such as the learning rate optimization, evaluation steps, and all other pytorch lightning hooks.
+
+```
+import torch.nn.functional as F
+
+class CustomCropModel(CropModel):
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        outputs = self.forward(x)
+        loss = F.cross_entropy(outputs, y)
+        # Add custom logic here
+        return loss
+
+# Create an instance of the custom CropModel
+model = CustomCropModel()
+```
diff --git a/docs/prebuilt.md b/docs/prebuilt.md
index 89354fda..4b561045 100644
--- a/docs/prebuilt.md
+++ b/docs/prebuilt.md
@@ -38,6 +38,23 @@ We have created a [GPU colab tutorial](https://colab.research.google.com/drive/1
 
 For more information, or specific questions about the bird detection, please create issues on the [BirdDetector repo](https://github.com/weecology/BirdDetector)
 
+# Crop Classifiers
+
+## Alive/Dead trees
+
+To provide a simple filter for trees that appear dead in the RGB data, we collected 6,342 image crops from the prediction landscape, as well as other NEON sites, and hand-annotated them as either alive or dead. We fine-tuned a resnet-50 pre-trained on ImageNet to classify trees as alive or dead before passing them to the species classification model. The model was trained with an ADAM optimizer with a learning rate of 0.001 and a batch size of 128 for 40 epochs, and was evaluated on a randomly held-out 10% of the crops. The evaluation accuracy of the alive-dead model was 95.8% (Table S1).
+
+Table S1. Confusion matrix for the Alive/Dead model in Weinstein et al. 2023
+
+| Observed \ Predicted | Alive | Dead |
+|----------------------|-------|------|
+| Alive                | 527   | 9    |
+| Dead                 | 10    | 89   |
+
+*Note:* due to the small training set, the confidence scores are overfit and not smooth. We do not recommend using the confidence scores from this model until it is trained on more diverse data.
+
+Citation: Weinstein, Ben G., et al. "Capturing long‐tailed individual tree diversity using an airborne imaging and a multi‐temporal hierarchical model." Remote Sensing in Ecology and Conservation 9.5 (2023): 656-670.
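+
+A hedged sketch of applying such a classifier during tile prediction; the checkpoint and raster paths below are hypothetical, since the alive/dead weights are not currently bundled with the package:
+
+```
+from deepforest import main, model
+
+m = main.deepforest()
+m.use_release()
+m.create_trainer()
+
+# Load your own trained alive/dead weights (hypothetical path)
+crop_model = model.CropModel.load_from_checkpoint("alive_dead.ckpt", num_classes=2)
+df = m.predict_tile(raster_path="tile.tif", crop_model=crop_model)
+df[["cropmodel_label", "cropmodel_score"]].head()
+```
+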
 ## Want more pretrained models?
 
 Please consider contributing your data to open source repositories, such as zenodo or lila.science. The more data we gather, the more we can combine the annotation and data collection efforts of hundreds of researchers to build models available to everyone. We welcome suggestions on what models and data are most urgently [needed](https://github.com/weecology/DeepForest/discussions).
\ No newline at end of file
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index 1bfdd524..2495f80a 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -8,7 +8,8 @@
 import pandas as pd
 import numpy as np
 import tempfile
-import rasterio as rio
+import rasterio as rio
+from deepforest.dataset import BoundingBoxDataset
 
 def single_class():
     csv_file = get_data("example.csv")
@@ -155,4 +156,20 @@ def test_TileDataset(preload_images):
     #assert crop shape
     assert ds[1].shape == (3, 100, 100)
-
+
+def test_BoundingBoxDataset():
+    # Create a sample dataframe
+    df = pd.read_csv(get_data("OSBS_029.csv"))
+
+    # Create the BoundingBoxDataset object
+    ds = BoundingBoxDataset(df, root_dir=os.path.dirname(get_data("OSBS_029.png")))
+
+    # Check the length of the dataset
+    assert len(ds) == df.shape[0]
+
+    # Get an item from the dataset
+    item = ds[0]
+
+    # Check the shape of the RGB tensor
+    assert item.shape == (3, 224, 224)
diff --git a/tests/test_main.py b/tests/test_main.py
index e25f20f7..bca6b2fc 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -13,7 +13,7 @@
 import albumentations as A
 from albumentations.pytorch import ToTensorV2
 
-from deepforest import main, get_data, dataset
+from deepforest import main, get_data, dataset, model
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.loggers import TensorBoardLogger
@@ -555,6 +555,7 @@ def test_existing_predict_dataloader(m, tmpdir):
     batches = m.trainer.predict(m, existing_loader)
     len(batches[0]) == m.config["batch_size"] + 1
 
+
 # Test train with each scheduler
 @pytest.mark.parametrize("scheduler,expected",[("cosine","CosineAnnealingLR"),
                                                ("lambdaLR","LambdaLR"),
@@ -613,4 +614,35 @@ def test_configure_optimizers(scheduler, expected):
     m.trainer.fit(m)
 
     # Assert the scheduler type
-    assert type(m.trainer.lr_scheduler_configs[0].scheduler).__name__ == scheduler_config["expected"], f"Scheduler type mismatch for {scheduler_config['type']}"
\ No newline at end of file
+    assert type(m.trainer.lr_scheduler_configs[0].scheduler).__name__ == scheduler_config["expected"], f"Scheduler type mismatch for {scheduler_config['type']}"
+
+@pytest.fixture()
+def crop_model():
+    return model.CropModel()
+
+def test_predict_tile_with_crop_model(m, crop_model):
+    raster_path = get_data("SOAP_061.png")
+    patch_size = 400
+    patch_overlap = 0.05
+    iou_threshold = 0.15
+    return_plot = False
+    mosaic = True
+
+    # Call the predict_tile method with the crop_model
+    m.config["train"]["fast_dev_run"] = False
+    m.create_trainer()
+    result = m.predict_tile(raster_path=raster_path,
+                            patch_size=patch_size,
+                            patch_overlap=patch_overlap,
+                            iou_threshold=iou_threshold,
+                            return_plot=return_plot,
+                            mosaic=mosaic,
+                            crop_model=crop_model)
+
+    # Assert the result
+    assert isinstance(result, pd.DataFrame)
+    assert set(result.columns) == {
+        "xmin", "ymin", "xmax", "ymax", "label", "score", "cropmodel_label",
+        "cropmodel_score", "image_path"
+    }
\ No newline at end of file
diff --git a/tests/test_model.py b/tests/test_model.py
index d9251879..52b0fa0c 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -1,10 +1,84 @@
 import pytest
 import torch
 from deepforest import model
+from deepforest import get_data
+import pandas as pd
+import os
+import numpy as np
+from types import MethodType
+from torchvision import transforms
+import cv2
 
 
 # The model object is an architecture-agnostic container.
 def test_model_no_args(config):
     with pytest.raises(ValueError):
         model.Model(config)
-
\ No newline at end of file
+
+
+@pytest.fixture()
+def crop_model():
+    crop_model = model.CropModel(num_classes=2)
+
+    return crop_model
+
+
+def test_crop_model(crop_model):
+    # Test forward pass
+    x = torch.rand(4, 3, 224, 224)
+    output = crop_model.forward(x)
+    assert output.shape == (4, 2)
+
+    # Test training step
+    batch = (x, torch.tensor([0, 1, 0, 1]))
+    loss = crop_model.training_step(batch, batch_idx=0)
+    assert isinstance(loss, torch.Tensor)
+
+    # Test validation step
+    val_batch = (x, torch.tensor([0, 1, 0, 1]))
+    val_loss = crop_model.validation_step(val_batch, batch_idx=0)
+    assert isinstance(val_loss, torch.Tensor)
+
+
+def test_crop_model_train(crop_model, tmpdir):
+    df = pd.read_csv(get_data("testfile_multi.csv"))
+    boxes = df[['xmin', 'ymin', 'xmax', 'ymax']].values.tolist()
+    root_dir = os.path.dirname(get_data("SOAP_061.png"))
+    images = df.image_path.values
+    crop_model.write_crops(boxes=boxes,
+                           labels=df.label.values,
+                           root_dir=root_dir,
+                           images=images,
+                           savedir=tmpdir)
+
+    # Create a trainer
+    crop_model.create_trainer(fast_dev_run=True)
+    crop_model.load_from_disk(train_dir=tmpdir, val_dir=tmpdir)
+
+    # Test training dataloader
+    train_loader = crop_model.train_dataloader()
+    assert isinstance(train_loader, torch.utils.data.DataLoader)
+
+    # Test validation dataloader
+    val_loader = crop_model.val_dataloader()
+    assert isinstance(val_loader, torch.utils.data.DataLoader)
+
+    crop_model.trainer.fit(crop_model)
+    crop_model.trainer.validate(crop_model)
+
+
+def test_crop_model_custom_transform():
+    # Create a dummy instance of CropModel
+    crop_model = model.CropModel(num_classes=2)
+
+    def custom_transform(self, augment):
+        data_transforms = []
+        data_transforms.append(transforms.ToTensor())
+        data_transforms.append(self.normalize())
+        # Add transforms here
+        data_transforms.append(transforms.Resize([300, 300]))
+        if augment:
+            data_transforms.append(transforms.RandomHorizontalFlip(0.5))
+        return transforms.Compose(data_transforms)
+
+    # Test custom transform, bound as a method so `self` is passed automatically
+    x = torch.rand(4, 3, 300, 300)
+    crop_model.get_transform = MethodType(custom_transform, crop_model)
+    output = crop_model.forward(x)
+    assert output.shape == (4, 2)