Skip to content
This repository has been archived by the owner on Mar 19, 2024. It is now read-only.

Commit

Permalink
Merge pull request #20 from asfadmin/download-models-command
Browse files Browse the repository at this point in the history
add functionality for downloading models
  • Loading branch information
Id405 authored Jun 17, 2023
2 parents 15f6072 + e963f73 commit d8f9d84
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 5 deletions.
9 changes: 6 additions & 3 deletions docs/user/library.rst
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,14 @@ Generating Masks from Wrapped Interferograms

The :func:`insar_eventnet.inference:mask` can be used to infer masks and presence values.

The following example uses the mask model located in ``models/masking_model`` and the presence model located in ``classification_model`` to infer and plot masks and presence values from the prompted path of a wrapped interferogram.
The following example downloads models and uses them to infer and plot masks and presence values from the prompted path of a wrapped interferogram.

.. code-block:: python
from tensorflow.keras.models import load_model
from insar_eventnet.inference import mask, plot_results
from insar_eventnet.io import initialize
tile_size = 512
crop_size = 512
Expand All @@ -109,6 +110,7 @@ The following example uses the mask model located in ``models/masking_model`` an
image_name = image_path.split('/')[-1].split('.')[0]
output_path = f'masks_inferred/{image_name}_mask.tif'
initialize()
mask_model = load_model(mask_model_path)
pres_model = load_model(pres_model_path)
Expand All @@ -126,4 +128,5 @@ The following example uses the mask model located in ``models/masking_model`` an
print("Negative")
plot_results(wrapped, mask, presence)
.. note::
The initialize function both creates the directory structure and downloads models for the user.
20 changes: 18 additions & 2 deletions insar_eventnet/aievents.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def cli():
@cli.command("setup")
def setup():
"""
Create data directory subtree. This should be run before make-dataset.
Create data directory subtree and download models. This should be run before
make-dataset.
\b
data/
Expand All @@ -80,7 +81,7 @@ def setup():
└──tensorboard/
"""

from insar_eventnet.io import create_directories
from insar_eventnet.io import create_directories, download_models

print("")
create_directories()
Expand All @@ -89,6 +90,21 @@ def setup():
click.echo("Data directory created")
print("")

print("Downloading models... this may take a second")
download_models("data/output")


@cli.command("download-models")
def download_models():
"""
Download models to data/output/models
"""

from insar_eventnet.io import download_models

print("Downloading... this may take a second")
download_models("data/output")


@cli.command("make-simulated-dataset")
@click.argument("name", type=str)
Expand Down
24 changes: 24 additions & 0 deletions insar_eventnet/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
from pathlib import Path
from typing import Tuple
from datetime import datetime
from urllib.request import urlopen
from io import BytesIO
from zipfile import ZipFile

from insar_eventnet.config import (
AOI_DIR,
Expand Down Expand Up @@ -111,6 +114,11 @@ def load_dataset(load_path: Path) -> Tuple[np.ndarray, np.ndarray]:
return dataset_file["mask"], dataset_file["wrapped"], dataset_file["presence"]


def initialize() -> None:
create_directories()
download_models("data/output")


def create_directories() -> None:
"""
Creates the directories for storing our data.
Expand All @@ -132,6 +140,22 @@ def create_directories() -> None:
print(directory.__str__() + " already exists.")


def download_models(path: str) -> None:
"""
Downloads pretrained UNet masking model and EvetNet presence prediction model
Parameters
----------
model_path: str
"""

with urlopen(
"https://eventnetmodels.s3.us-west-2.amazonaws.com/models.zip"
) as response:
with ZipFile(BytesIO(response.read())) as file:
file.extractall(path)


def get_image_array(image_path: str) -> np.ndarray:
"""
Load a interferogram .tif from storage into an array.
Expand Down

0 comments on commit d8f9d84

Please sign in to comment.