Skip to content

Commit

Permalink
Merge pull request #675 from weecology/crop_classifier
Browse files Browse the repository at this point in the history
Create a crop classifier workflow
  • Loading branch information
henrykironde authored Jun 24, 2024
2 parents a16b146 + 8e0c111 commit 1237965
Show file tree
Hide file tree
Showing 9 changed files with 587 additions and 7 deletions.
64 changes: 64 additions & 0 deletions deepforest/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from PIL import Image
import rasterio as rio
from deepforest import preprocess
from rasterio.windows import Window
from torchvision import transforms


def get_transform(augment):
Expand Down Expand Up @@ -178,3 +180,65 @@ def __getitem__(self, idx):
crop = preprocess.preprocess_image(crop)

return crop


def bounding_box_transform(augment=False):
    """Build the torchvision pipeline applied to bounding box crops.

    Args:
        augment: when truthy, a random horizontal flip (p=0.5) is appended
            as the last step.
    Returns:
        A transforms.Compose running, in order: ToTensor, resnet_normalize,
        a resize to 224 x 224 and (optionally) the horizontal flip.
    """
    steps = [
        transforms.ToTensor(),
        resnet_normalize,
        transforms.Resize([224, 224]),
    ]
    if augment:
        steps.append(transforms.RandomHorizontalFlip(0.5))
    return transforms.Compose(steps)


# Per-channel mean/std normalisation step used by bounding_box_transform.
# NOTE(review): these look like the standard ImageNet statistics expected by
# torchvision's pretrained models -- confirm against the backbone in use.
resnet_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])


class BoundingBoxDataset(Dataset):
    """Dataset yielding one crop per bounding box of a single raster image.

    The raster is opened once with rasterio when the dataset is created and
    each item is read from it on demand, using the box as the read window.

    Args:
        df: a pandas dataframe with image_path and xmin,xmax,ymin,ymax columns;
            all rows must share one image_path
        root_dir: the directory where the image is stored
        transform: a function to apply to the image; when None the pipeline
            from bounding_box_transform(augment=augment) is used
        augment: forwarded to bounding_box_transform when transform is None
    Returns:
        rgb: the window read from the raster with its first axis moved to the
            end, passed through transform when one is set
    """

    def __init__(self, df, root_dir, transform=None, augment=False):
        self.df = df
        self.transform = (bounding_box_transform(augment=augment)
                          if transform is None else transform)

        # This dataset keeps a single open raster, so every box has to come
        # from the same image.
        image_names = self.df['image_path'].unique()
        assert len(image_names) == 1, \
            "There should be only one unique image for this class object"

        # Keep the rasterio handle open; crops are read lazily in __getitem__
        self.src = rio.open(os.path.join(root_dir, image_names[0]))

    def __len__(self):
        """Number of boxes (rows of the dataframe)."""
        return len(self.df)

    def __getitem__(self, idx):
        """Read the crop for row ``idx`` and apply the transform, if any."""
        record = self.df.iloc[idx]
        xmin, xmax, ymin, ymax = (record[column]
                                  for column in ('xmin', 'xmax', 'ymin', 'ymax'))

        # Window takes (col_off, row_off, width, height)
        window = Window(xmin, ymin, xmax - xmin, ymax - ymin)
        crop = np.rollaxis(self.src.read(window=window), 0, 3)

        if not self.transform:
            return crop
        return self.transform(crop)
17 changes: 16 additions & 1 deletion deepforest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,10 @@ def predict_tile(self,
sigma=0.5,
thresh=0.001,
color=None,
thickness=1):
thickness=1,
crop_model=None,
crop_transform=None,
crop_augment=False):
"""For images too large to input into the model, predict_tile cuts the
image into overlapping windows, predicts trees on each window and
reassembles into a single array.
Expand All @@ -449,6 +452,9 @@ def predict_tile(self,
thresh: the score thresh used to filter bboxes after soft-nms performed
color: color of the bounding box as a tuple of BGR color, e.g. orange annotations is (0, 165, 255)
thickness: thickness of the rectangle border line in px
crop_model: a deepforest.model.CropModel object to predict on crops
crop_transform: a torchvision.transforms object to apply to crops
crop_augment: a boolean to apply augmentations to crops
Returns:
boxes (array): if return_plot, an image.
Expand Down Expand Up @@ -521,6 +527,15 @@ def predict_tile(self,

return list(zip(results, self.crops))

if crop_model:
# If a crop model is provided, predict on each crop
results = predict._predict_crop_model_(crop_model=crop_model,
results=results,
raster_path=raster_path,
trainer=self.trainer,
transform=crop_transform,
augment=crop_augment)

return results

def training_step(self, batch, batch_idx):
Expand Down
207 changes: 207 additions & 0 deletions deepforest/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,19 @@
# Model - common class
from deepforest.models import *
import torch
from pytorch_lightning import LightningModule, Trainer
import os
import torch
import torchmetrics
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
import numpy as np
import rasterio
import numpy as np
import torch
from torch.utils.data import Dataset
import torch.nn.functional as F
import cv2


class Model():
Expand Down Expand Up @@ -49,3 +62,197 @@ def check_model(self):
model_keys = list(predictions[1].keys())
model_keys.sort()
assert model_keys == ['boxes', 'labels', 'scores']


def simple_resnet_50(num_classes=2):
    """Create a pretrained ResNet-50 with its final layer resized.

    Args:
        num_classes: number of outputs of the replacement ``fc`` layer.
    Returns:
        The torchvision ResNet-50 whose ``fc`` is a new torch.nn.Linear
        mapping the original feature width to ``num_classes``.
    """
    backbone = models.resnet50(pretrained=True)
    backbone.fc = torch.nn.Linear(backbone.fc.in_features, num_classes)
    return backbone


class CropModel(LightningModule):
    """Lightning module that classifies image crops.

    Bundles a classification backbone with its dataloaders, metrics and
    optimizer configuration.

    Args:
        num_classes: number of classes; used for the default backbone, the
            metrics and the one-hot encoding in dataset_confusion.
        batch_size: batch size of every dataloader built by this class.
        num_workers: number of workers of every dataloader built by this class.
        lr: learning rate handed to the Adam optimizer.
        model: backbone to use; when None, simple_resnet_50 is created.
    """

    def __init__(self, num_classes=2, batch_size=4, num_workers=0, lr=0.0001, model=None):
        super().__init__()

        # Model
        self.num_classes = num_classes
        if model is None:
            self.model = simple_resnet_50(num_classes=num_classes)
        else:
            self.model = model

        # Metrics
        self.accuracy = torchmetrics.Accuracy(average='none',
                                              num_classes=num_classes,
                                              task="multiclass")
        self.total_accuracy = torchmetrics.Accuracy(num_classes=num_classes,
                                                    task="multiclass")
        self.precision_metric = torchmetrics.Precision(num_classes=num_classes,
                                                       task="multiclass")
        self.metrics = torchmetrics.MetricCollection({
            "Class Accuracy": self.accuracy,
            "Accuracy": self.total_accuracy,
            "Precision": self.precision_metric
        })

        # Training Hyperparameters
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.lr = lr

    def create_trainer(self, **kwargs):
        """Create a pytorch lightning trainer object and store it on self.trainer.

        Args:
            **kwargs: forwarded unchanged to pytorch_lightning.Trainer.
        """
        self.trainer = Trainer(**kwargs)

    def load_from_disk(self, train_dir, val_dir):
        """Create the train and validation ImageFolder datasets.

        Args:
            train_dir: root folder of the training images; read with the
                augmenting transform.
            val_dir: root folder of the validation images; read without
                augmentation.
        """
        self.train_ds = ImageFolder(root=train_dir,
                                    transform=self.get_transform(augment=True))
        self.val_ds = ImageFolder(root=val_dir,
                                  transform=self.get_transform(augment=False))

    def get_transform(self, augment):
        """
        Returns the data transformation pipeline for the model.
        Parameters:
            augment (bool): Flag indicating whether to apply data augmentation.
        Returns:
            torchvision.transforms.Compose: The composed data transformation pipeline.
        """
        data_transforms = [
            transforms.ToTensor(),
            self.normalize(),
            transforms.Resize([224, 224]),
        ]
        if augment:
            data_transforms.append(transforms.RandomHorizontalFlip(0.5))
        return transforms.Compose(data_transforms)

    def write_crops(self, root_dir, images, boxes, labels, savedir):
        """
        Write crops to disk.
        Args:
            root_dir (str): The root directory where the images are located.
            images (list): A list of image filenames.
            boxes (list): A list of bounding box coordinates in the format [xmin, ymin, xmax, ymax].
            labels (list): A list of labels corresponding to each bounding box.
            savedir (str): The directory where the cropped images will be saved.
        Returns:
            None
        """

        # Create a directory for each distinct label (once, not once per box)
        for label in set(labels):
            os.makedirs(os.path.join(savedir, label), exist_ok=True)

        # Use rasterio to read the image
        for index, box in enumerate(boxes):
            xmin, ymin, xmax, ymax = box
            label = labels[index]
            image = images[index]
            with rasterio.open(os.path.join(root_dir, image)) as src:
                # Crop the image using the bounding box coordinates;
                # the window is ((row_start, row_stop), (col_start, col_stop))
                img = src.read(window=((ymin, ymax), (xmin, xmax)))
                # Save the cropped image as a PNG file using opencv
                img_path = os.path.join(savedir, label, f"crop_{index}.png")
                img = np.rollaxis(img, 0, 3)
                # NOTE(review): cv2.imwrite treats 3-channel arrays as BGR. If
                # the raster bands are ordered RGB the saved file has red and
                # blue swapped -- confirm the intended channel order.
                cv2.imwrite(img_path, img)

    def normalize(self):
        """Return the per-channel mean/std normalisation used by get_transform."""
        return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def forward(self, x):
        """Run the backbone on ``x`` and squash its outputs with a sigmoid."""
        output = self.model(x)
        # NOTE(review): these sigmoid outputs are later passed to
        # F.cross_entropy (training/validation) and F.softmax (predict), both
        # of which normally take raw logits. Left unchanged so existing
        # checkpoints and scores keep their meaning -- confirm intent.
        output = F.sigmoid(output)

        return output

    def train_dataloader(self):
        """Shuffled dataloader over the dataset created by load_from_disk."""
        train_loader = torch.utils.data.DataLoader(self.train_ds,
                                                   batch_size=self.batch_size,
                                                   shuffle=True,
                                                   num_workers=self.num_workers)

        return train_loader

    def predict_dataloader(self, ds):
        """Unshuffled dataloader over ``ds`` using this model's batch settings.

        Args:
            ds: the dataset to wrap.
        """
        loader = torch.utils.data.DataLoader(ds,
                                             batch_size=self.batch_size,
                                             shuffle=False,
                                             num_workers=self.num_workers)

        return loader

    def val_dataloader(self):
        """Dataloader over the validation dataset created by load_from_disk."""
        val_loader = torch.utils.data.DataLoader(self.val_ds,
                                                 batch_size=self.batch_size,
                                                 shuffle=True,
                                                 num_workers=self.num_workers)

        return val_loader

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss of one training batch."""
        x, y = batch
        outputs = self.forward(x)
        loss = F.cross_entropy(outputs, y)
        self.log("train_loss", loss)

        return loss

    def predict_step(self, batch, batch_idx):
        """Return the softmax over the forward outputs of one batch."""
        outputs = self.forward(batch)
        yhat = F.softmax(outputs, 1)

        return yhat

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and every metric of the collection."""
        x, y = batch
        outputs = self(x)
        loss = F.cross_entropy(outputs, y)
        self.log("val_loss", loss)
        metric_dict = self.metrics(outputs, y)
        # Each metric is logged exactly once per step; metrics that return
        # more than one value are split into one entry per element.
        for key, value in metric_dict.items():
            if isinstance(value, torch.Tensor) and value.numel() > 1:
                for i, v in enumerate(value):
                    self.log(f"{key}_{i}", v, on_step=False, on_epoch=True)
            else:
                self.log(key, value, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """Adam optimizer with a ReduceLROnPlateau schedule driven by val_loss."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                               mode='min',
                                                               factor=0.5,
                                                               patience=10,
                                                               verbose=True,
                                                               threshold=0.0001,
                                                               threshold_mode='rel',
                                                               cooldown=0,
                                                               min_lr=0,
                                                               eps=1e-08)

        # The scheduler steps on the monitored validation loss
        return {'optimizer': optimizer, 'lr_scheduler': scheduler, "monitor": 'val_loss'}

    def dataset_confusion(self, loader):
        """Collect one-hot labels and model outputs over a data loader.

        Args:
            loader: iterable yielding (x, y) batches.
        Returns:
            (true_class, predicted_class): numpy arrays holding the one-hot
            encoded labels and the forward outputs, concatenated over batches.
        """
        true_class = []
        predicted_class = []
        self.eval()
        # Inference only: skip autograd bookkeeping, and move tensors to the
        # CPU before converting so accelerator tensors do not break .numpy()
        with torch.no_grad():
            for x, y in loader:
                one_hot = F.one_hot(y, num_classes=self.num_classes)
                true_class.append(one_hot.detach().cpu().numpy())
                prediction = self(x)
                predicted_class.append(prediction.detach().cpu().numpy())

        true_class = np.concatenate(true_class)
        predicted_class = np.concatenate(predicted_class)

        return true_class, predicted_class
38 changes: 37 additions & 1 deletion deepforest/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from torchvision.ops import nms
import typing

from deepforest import visualize
from deepforest import visualize, dataset
import rasterio


def _predict_image_(model,
Expand Down Expand Up @@ -188,3 +189,38 @@ def _dataloader_wrapper_(model,
thickness=thickness)

return results


def _predict_crop_model_(crop_model,
                         trainer,
                         results,
                         raster_path,
                         transform=None,
                         augment=False):
    """
    Run a crop model over every box in ``results`` for one raster file.
    Args:
        crop_model: The crop model to be used for prediction.
        trainer: The pytorch lightning trainer object for prediction.
        results: The results dataframe; the columns "cropmodel_label" and
            "cropmodel_score" are written to it in place.
        raster_path: The path to the raster file; its directory is used as the
            root_dir of the crop dataset.
        transform: Passed to dataset.BoundingBoxDataset as ``transform``.
        augment: Passed to dataset.BoundingBoxDataset as ``augment``.
    Returns:
        The updated results dataframe with predicted labels and scores.
    """
    crops = dataset.BoundingBoxDataset(results,
                                       root_dir=os.path.dirname(raster_path),
                                       transform=transform,
                                       augment=augment)
    batch_outputs = trainer.predict(crop_model, crop_model.predict_dataloader(crops))

    # One row per box, one column per class
    stacked = np.vstack(np.concatenate(batch_outputs))
    labels, scores = np.argmax(stacked, 1), np.max(stacked, 1)

    results["cropmodel_label"] = labels
    results["cropmodel_score"] = scores
    return results
Loading

0 comments on commit 1237965

Please sign in to comment.