Skip to content

Commit

Permalink
Merge pull request #588 from alex-rakowski/fcu-net
Browse files Browse the repository at this point in the history
Making FCU-Net compatible with 14.9
  • Loading branch information
sezelt authored Jan 4, 2024
2 parents 4243c33 + 300aaa3 commit 3dd902e
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 15 deletions.
22 changes: 15 additions & 7 deletions py4DSTEM/braggvectors/diskdetection_aiml.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import json
import shutil
import numpy as np
from pathlib import Path


from scipy.ndimage import gaussian_filter
from time import time
Expand Down Expand Up @@ -437,9 +439,9 @@ def find_Bragg_disks_aiml_serial(
raise ImportError("Import Error: Please install crystal4D before proceeding")

# Make the peaks PointListArray
# dtype = [('qx',float),('qy',float),('intensity',float)]
peaks = BraggVectors(datacube.Rshape, datacube.Qshape)

dtype = [("qx", float), ("qy", float), ("intensity", float)]
# peaks = BraggVectors(datacube.Rshape, datacube.Qshape)
peaks = PointListArray(dtype=dtype, shape=(datacube.R_Nx, datacube.R_Ny))
# check that the filtered DP is the right size for the probe kernel:
if filter_function:
assert callable(filter_function), "filter_function must be callable"
Expand Down Expand Up @@ -518,7 +520,7 @@ def find_Bragg_disks_aiml_serial(
subpixel=subpixel,
upsample_factor=upsample_factor,
filter_function=filter_function,
peaks=peaks.vectors_uncal.get_pointlist(Rx, Ry),
peaks=peaks.get_pointlist(Rx, Ry),
model_path=model_path,
)
t2 = time() - t0
Expand Down Expand Up @@ -884,7 +886,7 @@ def _get_latest_model(model_path=None):
+ "https://www.tensorflow.org/install"
+ "for more information"
)
from py4DSTEM.io.google_drive_downloader import download_file_from_google_drive
from py4DSTEM.io.google_drive_downloader import gdrive_download

tf.keras.backend.clear_session()

Expand All @@ -894,7 +896,12 @@ def _get_latest_model(model_path=None):
except:
pass
# download the json file with the meta data
download_file_from_google_drive("FCU-Net", "./tmp/model_metadata.json")
gdrive_download(
"FCU-Net",
destination="./tmp/",
filename="model_metadata.json",
overwrite=True,
)
with open("./tmp/model_metadata.json") as f:
metadata = json.load(f)
file_id = metadata["file_id"]
Expand All @@ -918,7 +925,8 @@ def _get_latest_model(model_path=None):
else:
print("Checking the latest model on the cloud... \n")
filename = file_path + file_type
download_file_from_google_drive(file_id, filename)
filename = Path(filename)
gdrive_download(file_id, destination="./tmp", filename=filename.name)
try:
shutil.unpack_archive(filename, "./tmp", format="zip")
except:
Expand Down
7 changes: 4 additions & 3 deletions py4DSTEM/braggvectors/diskdetection_aiml_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,9 @@ def find_Bragg_disks_aiml_CUDA(
"""

# Make the peaks PointListArray
# dtype = [('qx',float),('qy',float),('intensity',float)]
peaks = BraggVectors(datacube.Rshape, datacube.Qshape)
dtype = [("qx", float), ("qy", float), ("intensity", float)]
# peaks = BraggVectors(datacube.Rshape, datacube.Qshape)
peaks = PointListArray(dtype=dtype, shape=(datacube.R_Nx, datacube.R_Ny))

# check that the filtered DP is the right size for the probe kernel:
if filter_function:
Expand Down Expand Up @@ -221,7 +222,7 @@ def find_Bragg_disks_aiml_CUDA(
subpixel=subpixel,
upsample_factor=upsample_factor,
filter_function=filter_function,
peaks=peaks.vectors_uncal.get_pointlist(Rx, Ry),
peaks=peaks.get_pointlist(Rx, Ry),
get_maximal_points=get_maximal_points,
blocks=blocks,
threads=threads,
Expand Down
5 changes: 3 additions & 2 deletions py4DSTEM/io/google_drive_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
),
"small_dm3_3Dstack": ("small_dm3_3Dstack.dm3", "1B-xX3F65JcWzAg0v7f1aVwnawPIfb5_o"),
"FCU-Net": (
"filename.name",
"model_metadata.json",
"1-KX0saEYfhZ9IJAOwabH38PCVtfXidJi",
),
"small_datacube": (
Expand Down Expand Up @@ -221,7 +221,8 @@ def gdrive_download(
kwargs = {"fuzzy": True}
if id_ in file_ids:
f = file_ids[id_]
filename = f[0]
# Use the name in the collection filename passed
filename = filename if filename is not None else f[0]
kwargs["id"] = f[1]

# if its not in the list of files we expect
Expand Down
12 changes: 9 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,18 @@
"ipyparallel": ["ipyparallel >= 6.2.4", "dill >= 0.3.3"],
"cuda": ["cupy >= 10.0.0"],
"acom": ["pymatgen >= 2022", "mp-api == 0.24.1"],
"aiml": ["tensorflow == 2.4.1", "tensorflow-addons <= 0.14.0", "crystal4D"],
"aiml": [
"tensorflow <= 2.10.0",
"tensorflow-addons <= 0.16.1",
"crystal4D",
"typeguard == 2.7",
],
"aiml-cuda": [
"tensorflow == 2.4.1",
"tensorflow-addons <= 0.14.0",
"tensorflow <= 2.10.0",
"tensorflow-addons <= 0.16.1",
"crystal4D",
"cupy >= 10.0.0",
"typeguard == 2.7",
],
"numba": ["numba >= 0.49.1"],
},
Expand Down

0 comments on commit 3dd902e

Please sign in to comment.