Merge pull request #2 from pmhalvor/add-classifier
Add classifier stage
Showing 35 changed files with 1,679 additions and 35 deletions.
@@ -1,6 +1,7 @@
# Repo specific ignores
audio/
plots/
model/

# Python basic ignores
# Byte-compiled / optimized / DLL files
@@ -0,0 +1,2 @@
start	end	encounter_ids	classifications
2016-12-21T00:49:30 2016-12-21T00:50:30 ['9182'] [[0.8753612041473389], [0.746759295463562], [0.26265254616737366], [0.45787951350212097], [0.35406064987182617], [0.42348742485046387], [0.4947870969772339], [0.7287474274635315], [0.7099379897117615], [0.2122703194618225], [0.044488538056612015], [0.00849922839552164], [0.024390267208218575], [0.33750119805336], [0.6530888080596924], [0.3057247996330261], [0.1243574470281601], [0.027093390002846718], [0.011367958970367908], [0.004032353404909372], [0.026372192427515984], [0.021978065371513367], [0.006407670211046934], [0.5405446887016296], [0.34207114577293396], [0.6080849766731262], [0.5394770503044128], [0.3662146031856537], [0.16772609949111938], [0.3641503155231476], [0.060217034071683884], [0.008764371275901794], [0.012523961253464222], [0.009186000563204288], [0.022050702944397926], [0.3908870816230774], [0.15179167687892914], [0.3454047441482544], [0.4770602285861969], [0.07589100301265717], [0.5439115166664124], [0.8634722232818604], [0.985602617263794], [0.3311924636363983], [0.8832067847251892], [0.6166273951530457], [0.42301759123802185], [0.03573732450604439], [0.09752023965120316], [0.01426385436207056], [0.022987568750977516], [0.012294118292629719], [0.010207954794168472], [0.00296270614489913]] |
@@ -0,0 +1,68 @@
# Designing the classifier module

During local development, we ran into issues running our pretrained TensorFlow model inside a `beam.DoFn`.
Running the model in an isolated script worked fine, even with large inputs,
but for some reason, running it through Beam was problematic.
Research suggests this is due to either a memory allocation issue or a model serialization issue.

Either way, a workaround is needed to enable local development (for debugging purposes) that stays closely coupled to our expected cloud-based production environment.

## Options

### Option 1: Use a smaller model
I found a quantized model that seemingly condenses the [google/humpback_whale model](https://tfhub.dev/google/humpback_whale/1) enough to run in Beam, made by Oleg A. Golev (oleggolev) at https://github.com/oleggolev/COS598D-Whale/.
The original model is converted to a tflite model with slightly adapted input and output layers.
Example code for handling this model can be found at [examples/quantized_model.py](../../examples/quantized_model.py) and [examples/quantized_inference.py](../../examples/quantized_inference.py).
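
For reference, a minimal sketch of what driving a tflite model looks like. The model path, input shape, and output layout here are illustrative assumptions, not taken from the linked repo:

```python
import numpy as np
import tensorflow as tf

# Load the quantized model (hypothetical path) and allocate its tensors.
interpreter = tf.lite.Interpreter(model_path="model/quantized_humpback.tflite")
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Assumed input layer: 1.5 s of 10 kHz mono audio (see cons below).
dummy_audio = np.random.rand(1, 15000, 1).astype(np.float32)

interpreter.set_tensor(input_details[0]["index"], dummy_audio)
interpreter.invoke()

# Assumed output layer: one classification score per 1.5 s window.
score = interpreter.get_tensor(output_details[0]["index"])
print(score)
```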

#### Pros
- actually works in Beam (on my local machine)
- could speed up inference time and potentially reduce overall costs
- originally quantized to be deployed on small edge devices, so it should be portable to most environments
- model files are easily downloadable (present in the GitHub repo)
- keeps all our processing in one single unit -> cleaner project structure on our end

#### Cons
- initial tests yielded positive classifications on most random arrays of dummy data -> too many false positives (I could be wrong here. Track issue: https://github.com/oleggolev/COS598D-Whale/issues/1)
- committing to this set-up restricts us to a fixed model size
- not easily swapped out for new models or architectures -> requires quantization of each new model used (high maintenance)
- expected input size corresponds to 1.5 seconds of audio, which feels too short to correctly classify a whale call (I may be mistaken here though)
- outputs have to be aggregated for every 1.5 seconds of audio -> more post-processing compute than the original model
- poorly documented repository, doesn't feel easy to trust right off the bat


### Option 2: Model as a service
Host the model on an external resource, and call it via an API.

#### Pros
- model can easily be swapped out, updated, monitored, and maintained
- with an autoscaler, the model server can handle larger inputs or even multiple requests at once
- endpoint can be made easily accessible to other developers (if desired)
- error handling and retries won't immediately break the processing pipeline (e.g. 4 retries w/ exponential backoff, then return no classifications found; see the sketch after this section)
- builds personal experience with exposing models as services
- external compute allows the ML framework (TF, ONNX, Torch, etc.) to manage memory how it wants to, instead of under constraints enforced by Beam
- reduces pipeline dependencies (though project dependencies remain the same)

#### Cons
- fragments the codebase -> pipeline not easily packaged as a single unit, which makes portability and deployment more difficult
- requires running on two resources instead of one
- likely more expensive (though some research into model hosting/serving options may find a cost-effective solution)
- requires integration with more cloud services (a double-edged sword, since this also gives me more experience with other cloud tools)
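
A minimal sketch of the retry-with-backoff behavior mentioned in the pros above, assuming a hypothetical HTTP endpoint and payload format (the real client would read these from pipeline config):

```python
import logging
import time

import requests

# Hypothetical endpoint; the real URL would come from pipeline config.
INFERENCE_URL = "http://localhost:5000/predict"


def classify_with_retries(signal, max_retries=4):
    """POST audio to the model server, retrying with exponential backoff.

    Falls back to an empty list (no classifications found) if all retries
    fail, so a flaky server degrades gracefully instead of breaking the run.
    """
    for attempt in range(max_retries):
        try:
            response = requests.post(INFERENCE_URL, json={"signal": signal}, timeout=60)
            response.raise_for_status()
            return response.json()["predictions"]
        except requests.RequestException:
            logging.warning(f"Inference call failed (attempt {attempt + 1}/{max_retries}).")
            time.sleep(2 ** attempt)  # back off 1, 2, 4, 8 seconds
    return []  # no classifications found
```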

### Option 3: Continue w/o ability for local development
Since the model is intended to run in the cloud anyway, we can use this as motivation to push toward cloud-only development.

#### Pros
- can continue development as already written, following the same structure as the rest of the pipeline
- keeps all processing in one single unit

#### Cons
- debugging is more difficult
- lack of local testing makes development more time-consuming (waiting for deploys, etc.)
- feels very "brute-force" to just throw more resources at the problem instead of reevaluating
- restricts development to high-resource environments -> expensive development

## Decision
I'm going to go with Option 2: Model as a service.
This seems by far the best choice, though I wanted to give a fair chance to the other options.
More ideas can be added along the way, but option 2 is the most flexible and scalable.
Any additional costs can be mitigated by optimizing the model server or implementing an efficient teardown strategy.
@@ -76,5 +76,5 @@ def run():
     plt.show()


-run()
+if __name__ == "__main__":
+    run()
@@ -0,0 +1,194 @@
""" | ||
Beam PTransform for classifying whale audio. | ||
Gets stuck on classification, either due to memroy issues or model serialization. | ||
Kept for reference, but replaced by InferenceClient in classify.py. | ||
""" | ||
from apache_beam.io import filesystems
from datetime import datetime

import apache_beam as beam
import io
import librosa
import logging
import math
import matplotlib.pyplot as plt
import numpy as np
import os
import scipy.signal
import time
import tensorflow_hub as hub
import tensorflow as tf

from config import load_pipeline_config


config = load_pipeline_config()

class BaseClassifier(beam.PTransform):
    name = "BaseClassifier"

    def __init__(self):
        self.source_sample_rate = config.audio.source_sample_rate
        self.model_sample_rate = config.classify.model_sample_rate
        # Assumed config fields: batch duration (seconds) for _preprocess,
        # hydrophone sensitivity (dB) for _plot_spectrogram_scipy.
        self.batch_duration = config.classify.batch_duration
        self.hydrophone_sensitivity = config.classify.hydrophone_sensitivity

        self.model_path = config.classify.model_path

    def _preprocess(self, pcoll):
        signal, start, end, encounter_ids = pcoll
        key = self._build_key(start, end, encounter_ids)

        # Resample to the rate the model expects
        signal = self._resample(signal)

        batch_samples = self.batch_duration * self.model_sample_rate

        if signal.shape[0] > batch_samples:
            logging.debug(f"Signal size exceeds max sample size {batch_samples}.")

            split_indices = [batch_samples*(i+1) for i in range(math.floor(signal.shape[0] / batch_samples))]
            signal_batches = np.array_split(signal, split_indices)
            logging.debug(f"Split signal into {len(signal_batches)} batches of size {batch_samples}.")
            logging.debug(f"Size of final batch: {len(signal_batches[-1])}")

            for batch in signal_batches:
                yield (key, batch)
        else:
            yield (key, signal)

    def _build_key(
        self,
        start_time: datetime,
        end_time: datetime,
        encounter_ids: list,
    ):
        start_str = start_time.strftime('%Y%m%dT%H%M%S')
        end_str = end_time.strftime('%H%M%S')
        encounter_str = "_".join(encounter_ids)
        return f"{start_str}-{end_str}_{encounter_str}"

    def _postprocess(self, pcoll):
        return pcoll

    def _get_model(self):
        model = hub.load(self.model_path)
        return model

    def _resample(self, signal):
        logging.info(
            f"Resampling signal from {self.source_sample_rate} to {self.model_sample_rate}")
        return librosa.resample(
            signal,
            orig_sr=self.source_sample_rate,
            target_sr=self.model_sample_rate
        )

class GoogleHumpbackWhaleClassifier(BaseClassifier):
    """
    Model docs: https://tfhub.dev/google/humpback_whale/1
    """
    name = "GoogleHumpbackWhaleClassifier"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = self._get_model()
        self.score_fn = self.model.signatures['score']
        self.metadata_fn = self.model.signatures['metadata']

    def expand(self, pcoll):
        return (
            pcoll
            # _preprocess is a generator (it may yield several batches per
            # element), so it needs FlatMap rather than Map.
            | "Preprocess" >> beam.FlatMap(self._preprocess)
            | "Classify" >> beam.Map(self._classify)
            | "Postprocess" >> beam.Map(self._postprocess)
        )

    def _classify(self, pcoll):
        key, signal = pcoll

        start_classify = time.time()

        # We specify a 1-sec score resolution:
        context_step_samples = tf.cast(self.model_sample_rate, tf.int64)

        logging.info('\n==> Applying model ...')
        logging.debug(f'  initial input: len(signal) = {len(signal)}')

        waveform1 = np.expand_dims(signal, axis=1)
        waveform_exp = tf.expand_dims(waveform1, 0)  # makes a batch of size 1
        logging.debug(f"  final input: waveform_exp.shape = {waveform_exp.shape}")

        signal_scores = self.score_fn(
            waveform=waveform_exp,
            context_step_samples=context_step_samples
        )
        score_values = signal_scores['scores'].numpy()[0]
        logging.info(f'==> Model applied. Obtained {len(score_values)} score_values')
        logging.info(f'==> Elapsed time: {time.time() - start_classify} seconds')

        return (key, score_values)

    def _plot_spectrogram_scipy(self, signal, epsilon=1e-15):
        # Compute spectrogram:
        w = scipy.signal.get_window('hann', self.model_sample_rate)
        f, t, psd = scipy.signal.spectrogram(
            signal,  # TODO make sure this is the resampled signal
            self.model_sample_rate,
            nperseg=self.model_sample_rate,
            noverlap=0,
            window=w,
            nfft=self.model_sample_rate,
        )
        psd = 10*np.log10(psd+epsilon) - self.hydrophone_sensitivity

        # Plot spectrogram:
        fig = plt.figure(figsize=(20, round(20/3)))  # 3:1 aspect ratio
        plt.imshow(
            psd,
            aspect='auto',
            origin='lower',
            vmin=30,
            vmax=90,
            cmap='Blues',
        )
        plt.yscale('log')
        y_max = self.model_sample_rate / 2
        plt.ylim(10, y_max)

        plt.colorbar()

        plt.xlabel('Seconds')
        plt.ylabel('Frequency (Hz)')
        plt.title(f'Calibrated spectrum levels, {self.model_sample_rate / 1000.0} kHz data')

    def _plot_scores(self, pcoll, scores, med_filt_size=None):
        audio, start, end, encounter_ids = pcoll
        key = self._build_key(start, end, encounter_ids)

        # Repeat the last value to also see a step at the end:
        scores = np.concatenate((scores, scores[-1:]))
        x = range(len(scores))
        plt.step(x, scores, where='post')
        plt.plot(x, scores, 'o', color='lightgrey', markersize=9)

        plt.grid(axis='x', color='0.95')
        plt.xlim(xmin=0, xmax=len(scores) - 1)
        plt.ylabel('Model Score')
        plt.xlabel('Seconds')

        if med_filt_size is not None:
            scores_int = [int(s[0]*1000) for s in scores]
            meds_int = scipy.signal.medfilt(scores_int, kernel_size=med_filt_size)
            meds = [m/1000. for m in meds_int]
            plt.plot(x, meds, 'p', color='black', markersize=9)

        plot_path = config.classify.plot_path_template.format(
            year=start.year,
            month=start.month,
            day=start.day,
            plot_name=key
        )
        plt.savefig(plot_path)
        plt.show()