Skip to content

Commit

Permalink
Merge pull request #224 from bacetiner/master
Browse files Browse the repository at this point in the history
Cleaned up ConsTypeClassifier and added the new model URL
  • Loading branch information
bacetiner authored Aug 29, 2024
2 parents d0af236 + d7d787b commit 627e3db
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 31 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# -*- coding: utf-8 -*-
"""Class object to use or create construction type classification models."""
#
# Copyright (c) 2022 The Regents of the University of California
#
Expand Down Expand Up @@ -35,41 +35,91 @@
#
# Contributors:
# Barbaros Cetiner
#
# Last updated:
# 08-29-2024


import os
from brails.modules.ImageClassifier.ImageClassifier import ImageClassifier

import torch
import os


class ConsTypeClassifier(ImageClassifier):
"""
Class for facilitating classification of different construction types.
ConsTypeClassifier is a specialized class for classifying different
construction types (e.g., MAB, MAS, RCC, STL, WOD). It allows for making
predictions using a pre-trained model and retraining the pre-trained model
on new data. Inherits from the ImageClassifier class.
This class loads a default pre-trained model if a custom model path is not
provided during initialization. It can also be retrained using new
datasets.
"""

def __init__(self, modelPath=None):

if modelPath == None:
os.makedirs('tmp/models',exist_ok=True)
modelPath = 'tmp/models/consTypeClassifier_v1.pth'
if not os.path.isfile(modelPath):
print('Loading default construction type classifier model file to tmp/models folder...')
torch.hub.download_url_to_file('https://zenodo.org/record/7271554/files/trained_model_constype.pth',
modelPath, progress=False)
def __init__(self, model_path: str = None) -> None:
"""
Initialize the ConsTypeClassifier.
Args_
model_path (str, optional): Path to the model file. If None, it
will load the default construction type classifier model.
"""
if model_path is None:
os.makedirs('tmp/models', exist_ok=True)
model_path = 'tmp/models/consTypeClassifier_v1.pth'
if not os.path.isfile(model_path):
print('Loading default construction type classifier model ' +
'file to tmp/models folder...')
torch.hub.download_url_to_file('https://zenodo.org/record/' +
'13525814/files/' +
'constype_classifier_v1.pth',
model_path, progress=False)
print('Default construction type classifier model loaded')
else:
print(f"Default construction type classifier model at {modelPath} loaded")
else:
print('Default construction type classifier model at ' +
f"{model_path} loaded")
else:
print(f'Inferences will be performed using the custom model at {modelPath}')

self.modelPath = modelPath
self.classes = ['MAB','MAS','RCC','STL','WOD']

def predict(self, dataDir):
imageClassifier = ImageClassifier()
imageClassifier.predict(self.modelPath,dataDir,self.classes)
self.preds = imageClassifier.preds

def retrain(self, dataDir, batchSize=8, nepochs=100, plotLoss=True):
imageClassifier = ImageClassifier()
imageClassifier.retrain(self.modelPath,dataDir,batchSize,nepochs,plotLoss)

print('Inferences will be performed using the custom model at ' +
f'{model_path}')

self.model_path: str = model_path
self.classes: list[str] = ['MAB', 'MAS',
'RCC', 'STL', 'WOD'] # Construction types

def predict(self, data_dir: str) -> None:
"""
Perform construction type predictions on images in the specified path.
Args__
data_dir (str): Path to the directory containing images to be
classified.
"""
image_classifier = ImageClassifier()
image_classifier.predict(self.model_path, data_dir, self.classes)
self.preds = image_classifier.preds # Store predictions

def retrain(self,
data_dir: str,
batch_size: int = 8,
nepochs: int = 100,
plot_loss: bool = True) -> None:
"""
Retrain the construction type classifier on new data.
Args__
data_dir (str): Path to the directory containing training data.
batch_size (int, optional): Batch size for training. Default is 8.
nepochs (int, optional): Number of epochs for training. Default is
100.
plot_loss (bool, optional): Whether to plot the loss during
training. Default is True.
"""
image_classifier = ImageClassifier()
image_classifier.retrain(self.model_path, data_dir,
batch_size, nepochs, plot_loss)


if __name__ == '__main__':
pass
pass
3 changes: 2 additions & 1 deletion brails/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,12 @@
#
# Contributors:
# Barbaros Cetiner
# Yunhui Guo
# Yunhui Guo
# Sascha Hornauer

from brails.modules.ImageClassifier.ImageClassifier import ImageClassifier
from brails.modules.ImageSegmenter.ImageSegmenter import ImageSegmenter
from brails.modules.ConstructionTypeClassifier.ConstructionTypeClassifier import ConsTypeClassifier
from brails.modules.RoofTypeClassifier.RoofTypeClassifier import RoofClassifier
from brails.modules.OccupancyClassifier.OccupancyClassifier import OccupancyClassifier
from brails.modules.ChimneyDetector.ChimneyDetector import ChimneyDetector
Expand Down

0 comments on commit 627e3db

Please sign in to comment.