diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
new file mode 100644
index 000000000..682203bf3
--- /dev/null
+++ b/.devcontainer/devcontainer.json
@@ -0,0 +1,61 @@
+// For format details, see https://aka.ms/devcontainer.json. For config options, see the
+// README at: https://github.com/devcontainers/templates/tree/main/src/miniconda
+{
+ "name": "py4dstem-dev",
+ "image": "mcr.microsoft.com/vscode/devcontainers/miniconda:0-3",
+ // "build": {
+ // "context": "..",
+ // "dockerfile": "Dockerfile"
+ // },
+
+ // Features to add to the dev container. More info: https://containers.dev/features.
+ // "features": {},
+
+ // Use 'forwardPorts' to make a list of ports inside the container available locally.
+ // "forwardPorts": []
+
+ // Use 'postCreateCommand' to run commands after the container is created.
+ "postCreateCommand": "/opt/conda/bin/conda init && /opt/conda/bin/pip install -e /workspaces/py4DSTEM/ && /opt/conda/bin/pip install ipython ipykernel jupyter",
+
+ // Configure tool-specific properties.
+ "customizations": {
+ "vscode": {
+ "settings": {
+ "python.defaultInterpreterPath": "/opt/conda/bin/python",
+ "python.analysis.autoFormatStrings": true,
+ "python.analysis.completeFunctionParens": true,
+ "ruff.showNotifications": "onWarning",
+ "workbench.colorTheme": "Monokai",
+ // "editor.defaultFormatter": "ms-python.black-formatter",
+ "editor.fontFamily": "Menlo, Monaco, 'Courier New', monospace",
+ "editor.bracketPairColorization.enabled": true,
+ "editor.guides.bracketPairs": "active",
+ "editor.minimap.renderCharacters": false,
+ "editor.minimap.autohide": true,
+ "editor.minimap.scale": 2,
+ "[python]": {
+ "editor.defaultFormatter": "ms-python.black-formatter",
+ "editor.codeActionsOnSave": {
+ "source.organizeImports": false
+ }
+ }
+ },
+ "extensions": [
+ "ms-python.python",
+ "donjayamanne.python-extension-pack",
+ "ms-python.vscode-pylance",
+ "ms-toolsai.jupyter",
+ "GitHub.codespaces",
+ "ms-azuretools.vscode-docker",
+ "DavidAnson.vscode-markdownlint",
+ "ms-vsliveshare.vsliveshare",
+ "charliermarsh.ruff",
+ "eamodio.gitlens",
+ "ms-python.black-formatter"
+ ]
+ }
+ }
+
+ // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
+ // "remoteUser": "root"
+}
\ No newline at end of file
diff --git a/.flake8 b/.flake8
new file mode 100644
index 000000000..4fc1beac0
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,17 @@
+[flake8]
+extend-ignore:
+ E114,
+ E115,
+ E116,
+ E201,
+ E202,
+ E203,
+ E204,
+ E231,
+ E265,
+ E266,
+ E303,
+ E402,
+ E501
+
+
diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md
new file mode 100644
index 000000000..9054d12c5
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/bug_report.md
@@ -0,0 +1,47 @@
+---
+name: Bug report
+about: Create a report to help us improve
+title: ''
+labels: ''
+assignees: ''
+
+---
+
+**Describe the bug**
+A clear and concise description of what the bug is.
+
+**To Reproduce**
+Steps to reproduce the behavior; please be as specific as possible, and ideally provide a minimal reproducible example:
+
+**Expected behavior**
+A clear and concise description of what you expected to happen.
+
+
+**py4DSTEM version**
+It can be accessed by running `py4DSTEM.__version__`
+**Python version**
+It can be accessed using `sys.version`
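+
+For convenience, a snippet like the following prints both version numbers:
+
+```python
+import sys
+import py4DSTEM
+print("py4DSTEM:", py4DSTEM.__version__)
+print("Python:", sys.version)
+```
+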
+**Operating system**
+Windows, Mac (Intel or ARM), Linux (Distro)
+
+**GPU**
+If GPU related please provide:
+- CUDA driver version, which can be checked with `nvidia-smi`
+- CuPy version, which can be checked with `cupy.__version__`
+
+**Screenshots**
+If applicable, add screenshots to help explain your problem.
+
+
+**Additional context**
+Please feel free to add any other context about the problem here.
diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md
new file mode 100644
index 000000000..26f72eadb
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/feature_request.md
@@ -0,0 +1,20 @@
+---
+name: Feature request
+about: Suggest an idea for this project
+title: ''
+labels: ''
+assignees: ''
+
+---
+
+**Is your feature request related to a problem? Please describe.**
+A clear and concise description of what the problem is.
+
+**Describe the solution you'd like**
+A clear and concise description of what you want to happen.
+
+**Describe alternatives you've considered**
+A clear and concise description of any alternative solutions or features you've considered.
+
+**Additional context**
+Add any other context or screenshots about the feature request here.
diff --git a/.github/scripts/update_version.py b/.github/scripts/update_version.py
new file mode 100644
index 000000000..635cf8268
--- /dev/null
+++ b/.github/scripts/update_version.py
@@ -0,0 +1,30 @@
+"""
+Script to update the patch version number of the py4DSTEM package.
+"""
+
+version_file_path = "py4DSTEM/version.py"
+
+with open(version_file_path, "r") as f:
+ lines = f.readlines()
+
+line_split = lines[0].split(".")
+patch_number = line_split[2].split("'")[0].split('"')[0]
+
+# Increment patch number
+patch_number = str(int(patch_number) + 1) + "'"
+
+
+new_line = line_split[0] + "." + line_split[1] + "." + patch_number
+
+with open(version_file_path, "w") as f:
+ f.write(new_line)
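+
+# For example (hypothetical contents): if version.py reads  __version__ = '0.14.2'
+# this script rewrites the line to                          __version__ = '0.14.3'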
diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml
new file mode 100644
index 000000000..09b2a0fba
--- /dev/null
+++ b/.github/workflows/black.yml
@@ -0,0 +1,14 @@
+name: Check code style
+
+on:
+ push:
+ branches: [ "dev" ]
+ pull_request:
+ branches: [ "dev" ]
+
+jobs:
+ lint:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v3
+ - uses: psf/black@stable
\ No newline at end of file
diff --git a/.github/workflows/check_install_dev.yml b/.github/workflows/check_install_dev.yml
new file mode 100644
index 000000000..4e9d16f77
--- /dev/null
+++ b/.github/workflows/check_install_dev.yml
@@ -0,0 +1,45 @@
+name: Install Checker Dev
+on:
+ push:
+ branches: [ "dev" ]
+ pull_request:
+ branches: [ "dev" ]
+jobs:
+
+ test-python-os-versions:
+ name: Check Python ${{ matrix.python-version }} on ${{ matrix.runs-on }} (${{ matrix.architecture }})
+ continue-on-error: ${{ matrix.allow_failure }}
+ runs-on: ${{ matrix.runs-on }}
+ strategy:
+ fail-fast: false
+ matrix:
+ allow_failure: [false]
+ runs-on: [ubuntu-latest]
+ architecture: [x86_64]
+ python-version: ["3.9", "3.10", "3.11", "3.12"]
+ # include:
+ # - python-version: "3.12.0-beta.4"
+ # runs-on: ubuntu-latest
+ # allow_failure: true
+ # Currently no public runners available for this but this or arm64 should work next time
+ # include:
+ # - python-version: "3.10"
+ # architecture: [aarch64]
+ # runs-on: macos-latest
+ # allow_failure: true
+ steps:
+ - uses: actions/checkout@v3
+
+ - name: Setup Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install repo
+ run: |
+ python -m pip install .
+ - name: Check installation
+ run: |
+ python -c "import py4DSTEM; print(py4DSTEM.__version__)"
+ # - name: Check machine arch
+ # run: |
+ # python -c "import platform; print(platform.machine())"
diff --git a/.github/workflows/check_install_main.yml b/.github/workflows/check_install_main.yml
new file mode 100644
index 000000000..a276cab17
--- /dev/null
+++ b/.github/workflows/check_install_main.yml
@@ -0,0 +1,45 @@
+name: Install Checker Main
+on:
+ push:
+ branches: [ "main" ]
+ pull_request:
+ branches: [ "main" ]
+jobs:
+
+ test-python-os-versions:
+ name: Check Python ${{ matrix.python-version }} on ${{ matrix.runs-on }} (${{ matrix.architecture }})
+ continue-on-error: ${{ matrix.allow_failure }}
+ runs-on: ${{ matrix.runs-on }}
+ strategy:
+ fail-fast: false
+ matrix:
+ allow_failure: [false]
+ runs-on: [ubuntu-latest, windows-latest, macos-latest]
+ architecture: [x86_64]
+ python-version: ["3.9", "3.10", "3.11", "3.12"]
+ #include:
+ # - python-version: "3.12.0-beta.4"
+ # runs-on: ubuntu-latest
+ # allow_failure: true
+ # Currently no public runners available for this but this or arm64 should work next time
+ # include:
+ # - python-version: "3.10"
+ # architecture: [aarch64]
+ # runs-on: macos-latest
+ # allow_failure: true
+ steps:
+ - uses: actions/checkout@v3
+
+ - name: Setup Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install repo
+ run: |
+ python -m pip install .
+ - name: Check installation
+ run: |
+ python -c "import py4DSTEM; print(py4DSTEM.__version__)"
+ - name: Check machine arch
+ run: |
+ python -c "import platform; print(platform.machine())"
diff --git a/.github/workflows/check_install_quick.yml b/.github/workflows/check_install_quick.yml
new file mode 100644
index 000000000..f83ee0b73
--- /dev/null
+++ b/.github/workflows/check_install_quick.yml
@@ -0,0 +1,45 @@
+name: Install Checker Quick
+on:
+ push:
+ branches-ignore:
+ - main
+ - dev
+ pull_request:
+ branches-ignore:
+ - main
+ - dev
+jobs:
+
+ test-python-os-versions:
+ name: Check Python ${{ matrix.python-version }} on ${{ matrix.runs-on }} (${{ matrix.architecture }})
+ continue-on-error: ${{ matrix.allow_failure }}
+ runs-on: ${{ matrix.runs-on }}
+ strategy:
+ fail-fast: false
+ matrix:
+ allow_failure: [false]
+ runs-on: [ubuntu-latest]
+ architecture: [x86_64]
+ python-version: ["3.9", "3.12"]
+ # Currently no public runners available for this but this or arm64 should work next time
+ # include:
+ # - python-version: "3.10"
+ # architecture: [aarch64]
+ # runs-on: macos-latest
+ # allow_failure: true
+ steps:
+ - uses: actions/checkout@v3
+
+ - name: Setup Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install repo
+ run: |
+ python -m pip install .
+ - name: Check installation
+ run: |
+ python -c "import py4DSTEM; print(py4DSTEM.__version__)"
+ # - name: Check machine arch
+ # run: |
+ # python -c "import platform; print(platform.machine())"
\ No newline at end of file
diff --git a/.github/workflows/linter.yml b/.github/workflows/linter.yml
new file mode 100644
index 000000000..2537fc983
--- /dev/null
+++ b/.github/workflows/linter.yml
@@ -0,0 +1,30 @@
+name: Lint with super-linter@v5-slim
+
+on:
+ push:
+ branches: [ "dev" ]
+ pull_request:
+ branches: [ "dev" ]
+
+jobs:
+ run-lint:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v3
+ with:
+ # Full git history is needed to get a proper list of changed files within `super-linter`
+ fetch-depth: 0
+
+ - name: Lint Code Base
+ uses: super-linter/super-linter/slim@v5 # the slim image is quicker to download
+ env:
+ VALIDATE_ALL_CODEBASE: false # only check changes
+ VALIDATE_PYTHON_FLAKE8: true # lint with flake8
+ DEFAULT_BRANCH: "dev" # set default branch to dev
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # for github things
+ # FILTER_REGEX_EXCLUDE: .*test/.* # exclude test dirs
+ FILTER_REGEX_EXCLUDE: .*__init__\.py.* # exclude __init__.py files
+ FILTER_REGEX_INCLUDE: .*py4DSTEM/.* # only look for py4DSTEM
+ LINTER_RULES_PATH: / # set toplevel dir as the path to look for rules
+ PYTHON_FLAKE8_CONFIG_FILE: .flake8 # set specific config file
diff --git a/.github/workflows/pypi_upload.yml b/.github/workflows/pypi_upload.yml
new file mode 100644
index 000000000..264c69030
--- /dev/null
+++ b/.github/workflows/pypi_upload.yml
@@ -0,0 +1,85 @@
+# Action to check the version of the package and upload it to PyPI
+# if the version is higher than the one on PyPI
+name: PyPI Upload
+
+on:
+ push:
+ branches:
+ - main
+ pull_request:
+ branches:
+ - main
+
+jobs:
+ update_version:
+ runs-on: ubuntu-latest
+ name: Increment the version number if py4DSTEM/version.py was not changed
+ steps:
+ - uses: actions/checkout@v3
+ with:
+ fetch-depth: 0
+ token: ${{ secrets.GH_ACTION_VERSION_UPDATE }}
+ - name: Get changed files
+ id: changed-files-specific
+ uses: tj-actions/changed-files@v39
+ with:
+ files: |
+ py4DSTEM/version.py
+ - name: Debug version file change checker
+ run: |
+ echo "Checking variable..."
+ echo ${{ steps.changed-files-specific.outputs.any_changed }}
+ echo "Done"
+ - name: Update the version number if py4DSTEM/version.py was not changed
+ if: steps.changed-files-specific.outputs.any_changed == 'false'
+ run: |
+ echo "Version file unchanged; incrementing the patch version number."
+ #git remote set-url origin https://x-access-token:${{ secrets.GITHUB_TOKEN }}@github.com/${{ github.repository }}
+ python .github/scripts/update_version.py
+ git config --global user.email "ben.savitzky@gmail.com"
+ git config --global user.name "bsavitzky"
+ git commit -a -m "Auto-update version number (GH Action)"
+ git push origin main
+ sync_with_dev:
+ needs: update_version
+ runs-on: ubuntu-latest
+ name: Sync main with dev
+ steps:
+ - name: Sync main with dev
+ uses: actions/checkout@v3
+ with:
+ ref: dev
+ fetch-depth: 0
+ token: ${{ secrets.GH_ACTION_VERSION_UPDATE }}
+ - run: |
+ # set strategy to default merge
+ git config pull.rebase false
+ git config --global user.email "ben.savitzky@gmail.com"
+ git config --global user.name "bsavitzky"
+ git pull origin main --commit --no-edit
+ git push origin dev
+ deploy:
+ needs: sync_with_dev
+ runs-on: ubuntu-latest
+ name: Deploy to PyPI
+ steps:
+ - uses: actions/checkout@v3
+ with:
+ ref: dev
+ - name: Set up Python
+ uses: actions/setup-python@v2
+ with:
+ python-version: 3.8
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install build
+ - name: Build package
+ run: python -m build
+ - name: Publish package
+ uses: pypa/gh-action-pypi-publish@release/v1
+ with:
+ user: __token__
+ password: ${{ secrets.PYPI_API_TOKEN }}
+
+
diff --git a/.gitignore b/.gitignore
index 3fc678c9a..24587a3b3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,82 @@
*.pyc
.project
.pydevproject
+*.swp
+*.ipynb_checkpoints*
+.vscode/
+pyrightconfig.json
+
+# Folders #
+.idea/
+__pycache__/
+*.egg-info/
+build/
+dist/
+test/unit_test_data/
+sample_code/data/*h5
+sample_code/data/*dm3
+# Data Specific #
+#################
+#*.png
+#*.gif
+*.emd
+*.emf
+*.dm3
+*.dm4
+*.tiff
+*.tif
+*.jpg
+*.avi
+*.mp4
+*.pdf
+*.bin
+*.mat
+*.doc
+*.docx
+*.xlsx
+*.asv
+*.eps
+*.fig
+*.ai
+*.m~
+*.db
+#*.svg
+
+# Compiled source #
+###################
+*.com
+*.class
+*.dll
+*.exe
+*.o
+*.so
+*.mex*
+
+# Packages #
+############
+# it's better to unpack these files and commit the raw source
+# git has its own built in compression methods
+*.7z
+*.dmg
+*.gz
+*.iso
+*.jar
+*.rar
+*.tar
+*.zip
+
+# Logs and databases #
+######################
+*.log
+*.sql
+*.sqlite
+
+# OS generated files #
+######################
+.DS_Store
+.DS_Store?
+._*
+.Spotlight-V100
+.Trashes
+ehthumbs.db
+Thumbs.db
\ No newline at end of file
diff --git a/.readthedocs.yaml b/.readthedocs.yaml
new file mode 100644
index 000000000..5a18e5dd1
--- /dev/null
+++ b/.readthedocs.yaml
@@ -0,0 +1,17 @@
+version: 2
+
+build:
+ os: "ubuntu-20.04"
+ tools:
+ python: "3.10"
+
+sphinx:
+ configuration: docs/source/conf.py
+
+python:
+ install:
+ - method: pip
+ path: .
+ - requirements: docs/requirements.txt
+
+formats: all
diff --git a/CITATION.cff b/CITATION.cff
new file mode 100644
index 000000000..89d614c5f
--- /dev/null
+++ b/CITATION.cff
@@ -0,0 +1,104 @@
+cff-version: 1.1.0
+message: "If you use this software, please cite the accompanying paper."
+abstract: "Scanning transmission electron microscopy (STEM) allows for imaging, diffraction, and spectroscopy of materials on length scales ranging from microns to atoms. By using a high-speed, direct electron detector, it is now possible to record a full two-dimensional (2D) image of the diffracted electron beam at each probe position, typically a 2D grid of probe positions. These 4D-STEM datasets are rich in information, including signatures of the local structure, orientation, deformation, electromagnetic fields, and other sample-dependent properties. However, extracting this information requires complex analysis pipelines that include data wrangling, calibration, analysis, and visualization, all while maintaining robustness against imaging distortions and artifacts. In this paper, we present py4DSTEM, an analysis toolkit for measuring material properties from 4D-STEM datasets, written in the Python language and released with an open-source license. We describe the algorithmic steps for dataset calibration and various 4D-STEM property measurements in detail and present results from several experimental datasets. We also implement a simple and universal file format appropriate for electron microscopy data in py4DSTEM, which uses the open-source HDF5 standard. We hope this tool will benefit the research community and help improve the standards for data and computational methods in electron microscopy, and we invite the community to contribute to this ongoing project."
+authors:
+ -
+ affiliation: "National Center for Electron Microscopy, Molecular Foundry, Lawrence Berkeley National Laboratory, 1 Cyclotron Road, Berkeley, CA 94720, USA"
+ family-names: Savitzky
+ given-names: "Benjamin H."
+ -
+ affiliation: "Department of Materials Science and Engineering, University of California, Berkeley, CA 94720, USA"
+ family-names: Zeltmann
+ given-names: "Steven E. "
+ -
+ affiliation: "National Center for Electron Microscopy, Molecular Foundry, Lawrence Berkeley National Laboratory, 1 Cyclotron Road, Berkeley, CA 94720, USA"
+ family-names: Hughes
+ given-names: "Lauren A."
+ -
+ affiliation: "National Center for Electron Microscopy, Molecular Foundry, Lawrence Berkeley National Laboratory, 1 Cyclotron Road, Berkeley, CA 94720, USA"
+ family-names: Brown
+ given-names: "Hamish G."
+ -
+ affiliation: "National Center for Electron Microscopy, Molecular Foundry, Lawrence Berkeley National Laboratory, 1 Cyclotron Road, Berkeley, CA 94720, USA and Department of Materials Science and Engineering, University of California, Berkeley, CA 94720, USA"
+ family-names: Zhao
+ given-names: Shiteng
+ -
+ affiliation: "National Center for Electron Microscopy, Molecular Foundry, Lawrence Berkeley National Laboratory, 1 Cyclotron Road, Berkeley, CA 94720, USA and Department of Materials Science and Engineering, University of California, Berkeley, CA 94720, USA"
+ family-names: Pelz
+ given-names: "Philipp M."
+ -
+ affiliation: "Institut für Physik, Humboldt-Universität zu Berlin, Newtonstraße 15, 12489 Berlin, Germany"
+ family-names: Pekin
+ given-names: "Thomas C."
+ -
+ affiliation: "Molecular Foundry, Lawrence Berkeley National Laboratory, 1 Cyclotron Road, Berkeley, CA 94720, USA"
+ family-names: Barnard
+ given-names: "Edward S."
+ -
+ affiliation: "National Center for Electron Microscopy, Molecular Foundry, Lawrence Berkeley National Laboratory, 1 Cyclotron Road, Berkeley, CA 94720, USA Department of Materials Science and Engineering, University of California, Berkeley, CA 94720, USA"
+ family-names: Donohue
+ given-names: Jennifer
+ -
+ affiliation: "National Center for Electron Microscopy, Molecular Foundry, Lawrence Berkeley National Laboratory, 1 Cyclotron Road, Berkeley, CA 94720, USA Department of Materials Science and Engineering, University of California, Berkeley, CA 94720, USA"
+ family-names: "Rangel DaCosta"
+ given-names: Luis
+ -
+ affiliation: "National Center for Electron Microscopy, Molecular Foundry, Lawrence Berkeley National Laboratory, 1 Cyclotron Road, Berkeley, CA 94720, USA Department of Materials Science and Engineering, University of California, Berkeley, CA 94720, USA"
+ family-names: Kennedy
+ given-names: Ellis
+ -
+ affiliation: "National Center for Electron Microscopy, Molecular Foundry, Lawrence Berkeley National Laboratory, 1 Cyclotron Road, Berkeley, CA 94720, USA"
+ family-names: Xie
+ given-names: Yujun
+ -
+ affiliation: "Los Alamos National Laboratory, Los Alamos, NM 87545, USA"
+ family-names: Janish
+ given-names: "Matthew T."
+ -
+ affiliation: "Los Alamos National Laboratory, Los Alamos, NM 87545, USA"
+ family-names: Schneider
+ given-names: "Matthew M."
+ -
+ affiliation: "Toyota Research Institute, Los Altos, CA 94022, USA"
+ family-names: Herring
+ given-names: Patrick
+ -
+ affiliation: "Toyota Research Institute, Los Altos, CA 94022, USA"
+ family-names: Gopal
+ given-names: Chirranjeevi
+ -
+ affiliation: "Toyota Research Institute, Los Altos, CA 94022, USA"
+ family-names: Anapolsky
+ given-names: Abraham
+ -
+ affiliation: "National Center for Electron Microscopy, Molecular Foundry, Lawrence Berkeley National Laboratory, 1 Cyclotron Road, Berkeley, CA 94720, USA"
+ family-names: Dhall
+ given-names: Rohan
+ -
+ affiliation: "National Center for Electron Microscopy, Molecular Foundry, Lawrence Berkeley National Laboratory, 1 Cyclotron Road, Berkeley, CA 94720, USA"
+ family-names: Bustillo
+ given-names: "Karen C."
+ -
+ affiliation: "National Center for Electron Microscopy, Molecular Foundry, Lawrence Berkeley National Laboratory, 1 Cyclotron Road, Berkeley, CA 94720, USA"
+ family-names: Ercius
+ given-names: Peter
+ -
+ affiliation: "National Center for Electron Microscopy, Molecular Foundry, Lawrence Berkeley National Laboratory, 1 Cyclotron Road, Berkeley, CA 94720, USA Department of Materials Science and Engineering, University of California, Berkeley, CA 94720, USA"
+ family-names: Scott
+ given-names: "Mary C."
+ -
+ affiliation: "National Center for Electron Microscopy, Molecular Foundry, Lawrence Berkeley National Laboratory, 1 Cyclotron Road, Berkeley, CA 94720, USA"
+ family-names: Ciston
+ given-names: Jim
+ -
+ affiliation: "National Center for Electron Microscopy, Molecular Foundry, Lawrence Berkeley National Laboratory, 1 Cyclotron Road, Berkeley, CA 94720, USA Department of Materials Science and Engineering, University of California, Berkeley, CA 94720, USA"
+ family-names: Minor
+ given-names: "Andrew M."
+ -
+ affiliation: "National Center for Electron Microscopy, Molecular Foundry, Lawrence Berkeley National Laboratory, 1 Cyclotron Road, Berkeley, CA 94720, USA"
+ family-names: Ophus
+ given-names: Colin
+title: "py4DSTEM: A Software Package for Four-Dimensional Scanning Transmission Electron Microscopy Data Analysis"
+version: 0.12.6
+doi: 10.1017/S1431927621000477
+date-released: 2021-05-21
diff --git a/LICENSE.txt b/LICENSE.txt
new file mode 100644
index 000000000..72f3b9d31
--- /dev/null
+++ b/LICENSE.txt
@@ -0,0 +1,675 @@
+ GNU GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+ Preamble
+
+ The GNU General Public License is a free, copyleft license for
+software and other kinds of works.
+
+ The licenses for most software and other practical works are designed
+to take away your freedom to share and change the works. By contrast,
+the GNU General Public License is intended to guarantee your freedom to
+share and change all versions of a program--to make sure it remains free
+software for all its users. We, the Free Software Foundation, use the
+GNU General Public License for most of our software; it applies also to
+any other work released this way by its authors. You can apply it to
+your programs, too.
+
+ When we speak of free software, we are referring to freedom, not
+price. Our General Public Licenses are designed to make sure that you
+have the freedom to distribute copies of free software (and charge for
+them if you wish), that you receive source code or can get it if you
+want it, that you can change the software or use pieces of it in new
+free programs, and that you know you can do these things.
+
+ To protect your rights, we need to prevent others from denying you
+these rights or asking you to surrender the rights. Therefore, you have
+certain responsibilities if you distribute copies of the software, or if
+you modify it: responsibilities to respect the freedom of others.
+
+ For example, if you distribute copies of such a program, whether
+gratis or for a fee, you must pass on to the recipients the same
+freedoms that you received. You must make sure that they, too, receive
+or can get the source code. And you must show them these terms so they
+know their rights.
+
+ Developers that use the GNU GPL protect your rights with two steps:
+(1) assert copyright on the software, and (2) offer you this License
+giving you legal permission to copy, distribute and/or modify it.
+
+ For the developers' and authors' protection, the GPL clearly explains
+that there is no warranty for this free software. For both users' and
+authors' sake, the GPL requires that modified versions be marked as
+changed, so that their problems will not be attributed erroneously to
+authors of previous versions.
+
+ Some devices are designed to deny users access to install or run
+modified versions of the software inside them, although the manufacturer
+can do so. This is fundamentally incompatible with the aim of
+protecting users' freedom to change the software. The systematic
+pattern of such abuse occurs in the area of products for individuals to
+use, which is precisely where it is most unacceptable. Therefore, we
+have designed this version of the GPL to prohibit the practice for those
+products. If such problems arise substantially in other domains, we
+stand ready to extend this provision to those domains in future versions
+of the GPL, as needed to protect the freedom of users.
+
+ Finally, every program is threatened constantly by software patents.
+States should not allow patents to restrict development and use of
+software on general-purpose computers, but in those that do, we wish to
+avoid the special danger that patents applied to a free program could
+make it effectively proprietary. To prevent this, the GPL assures that
+patents cannot be used to render the program non-free.
+
+ The precise terms and conditions for copying, distribution and
+modification follow.
+
+ TERMS AND CONDITIONS
+
+ 0. Definitions.
+
+ "This License" refers to version 3 of the GNU General Public License.
+
+ "Copyright" also means copyright-like laws that apply to other kinds of
+works, such as semiconductor masks.
+
+ "The Program" refers to any copyrightable work licensed under this
+License. Each licensee is addressed as "you". "Licensees" and
+"recipients" may be individuals or organizations.
+
+ To "modify" a work means to copy from or adapt all or part of the work
+in a fashion requiring copyright permission, other than the making of an
+exact copy. The resulting work is called a "modified version" of the
+earlier work or a work "based on" the earlier work.
+
+ A "covered work" means either the unmodified Program or a work based
+on the Program.
+
+ To "propagate" a work means to do anything with it that, without
+permission, would make you directly or secondarily liable for
+infringement under applicable copyright law, except executing it on a
+computer or modifying a private copy. Propagation includes copying,
+distribution (with or without modification), making available to the
+public, and in some countries other activities as well.
+
+ To "convey" a work means any kind of propagation that enables other
+parties to make or receive copies. Mere interaction with a user through
+a computer network, with no transfer of a copy, is not conveying.
+
+ An interactive user interface displays "Appropriate Legal Notices"
+to the extent that it includes a convenient and prominently visible
+feature that (1) displays an appropriate copyright notice, and (2)
+tells the user that there is no warranty for the work (except to the
+extent that warranties are provided), that licensees may convey the
+work under this License, and how to view a copy of this License. If
+the interface presents a list of user commands or options, such as a
+menu, a prominent item in the list meets this criterion.
+
+ 1. Source Code.
+
+ The "source code" for a work means the preferred form of the work
+for making modifications to it. "Object code" means any non-source
+form of a work.
+
+ A "Standard Interface" means an interface that either is an official
+standard defined by a recognized standards body, or, in the case of
+interfaces specified for a particular programming language, one that
+is widely used among developers working in that language.
+
+ The "System Libraries" of an executable work include anything, other
+than the work as a whole, that (a) is included in the normal form of
+packaging a Major Component, but which is not part of that Major
+Component, and (b) serves only to enable use of the work with that
+Major Component, or to implement a Standard Interface for which an
+implementation is available to the public in source code form. A
+"Major Component", in this context, means a major essential component
+(kernel, window system, and so on) of the specific operating system
+(if any) on which the executable work runs, or a compiler used to
+produce the work, or an object code interpreter used to run it.
+
+ The "Corresponding Source" for a work in object code form means all
+the source code needed to generate, install, and (for an executable
+work) run the object code and to modify the work, including scripts to
+control those activities. However, it does not include the work's
+System Libraries, or general-purpose tools or generally available free
+programs which are used unmodified in performing those activities but
+which are not part of the work. For example, Corresponding Source
+includes interface definition files associated with source files for
+the work, and the source code for shared libraries and dynamically
+linked subprograms that the work is specifically designed to require,
+such as by intimate data communication or control flow between those
+subprograms and other parts of the work.
+
+ The Corresponding Source need not include anything that users
+can regenerate automatically from other parts of the Corresponding
+Source.
+
+ The Corresponding Source for a work in source code form is that
+same work.
+
+ 2. Basic Permissions.
+
+ All rights granted under this License are granted for the term of
+copyright on the Program, and are irrevocable provided the stated
+conditions are met. This License explicitly affirms your unlimited
+permission to run the unmodified Program. The output from running a
+covered work is covered by this License only if the output, given its
+content, constitutes a covered work. This License acknowledges your
+rights of fair use or other equivalent, as provided by copyright law.
+
+ You may make, run and propagate covered works that you do not
+convey, without conditions so long as your license otherwise remains
+in force. You may convey covered works to others for the sole purpose
+of having them make modifications exclusively for you, or provide you
+with facilities for running those works, provided that you comply with
+the terms of this License in conveying all material for which you do
+not control copyright. Those thus making or running the covered works
+for you must do so exclusively on your behalf, under your direction
+and control, on terms that prohibit them from making any copies of
+your copyrighted material outside their relationship with you.
+
+ Conveying under any other circumstances is permitted solely under
+the conditions stated below. Sublicensing is not allowed; section 10
+makes it unnecessary.
+
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+ No covered work shall be deemed part of an effective technological
+measure under any applicable law fulfilling obligations under article
+11 of the WIPO copyright treaty adopted on 20 December 1996, or
+similar laws prohibiting or restricting circumvention of such
+measures.
+
+ When you convey a covered work, you waive any legal power to forbid
+circumvention of technological measures to the extent such circumvention
+is effected by exercising rights under this License with respect to
+the covered work, and you disclaim any intention to limit operation or
+modification of the work as a means of enforcing, against the work's
+users, your or third parties' legal rights to forbid circumvention of
+technological measures.
+
+ 4. Conveying Verbatim Copies.
+
+ You may convey verbatim copies of the Program's source code as you
+receive it, in any medium, provided that you conspicuously and
+appropriately publish on each copy an appropriate copyright notice;
+keep intact all notices stating that this License and any
+non-permissive terms added in accord with section 7 apply to the code;
+keep intact all notices of the absence of any warranty; and give all
+recipients a copy of this License along with the Program.
+
+ You may charge any price or no price for each copy that you convey,
+and you may offer support or warranty protection for a fee.
+
+ 5. Conveying Modified Source Versions.
+
+ You may convey a work based on the Program, or the modifications to
+produce it from the Program, in the form of source code under the
+terms of section 4, provided that you also meet all of these conditions:
+
+ a) The work must carry prominent notices stating that you modified
+ it, and giving a relevant date.
+
+ b) The work must carry prominent notices stating that it is
+ released under this License and any conditions added under section
+ 7. This requirement modifies the requirement in section 4 to
+ "keep intact all notices".
+
+ c) You must license the entire work, as a whole, under this
+ License to anyone who comes into possession of a copy. This
+ License will therefore apply, along with any applicable section 7
+ additional terms, to the whole of the work, and all its parts,
+ regardless of how they are packaged. This License gives no
+ permission to license the work in any other way, but it does not
+ invalidate such permission if you have separately received it.
+
+ d) If the work has interactive user interfaces, each must display
+ Appropriate Legal Notices; however, if the Program has interactive
+ interfaces that do not display Appropriate Legal Notices, your
+ work need not make them do so.
+
+ A compilation of a covered work with other separate and independent
+works, which are not by their nature extensions of the covered work,
+and which are not combined with it such as to form a larger program,
+in or on a volume of a storage or distribution medium, is called an
+"aggregate" if the compilation and its resulting copyright are not
+used to limit the access or legal rights of the compilation's users
+beyond what the individual works permit. Inclusion of a covered work
+in an aggregate does not cause this License to apply to the other
+parts of the aggregate.
+
+ 6. Conveying Non-Source Forms.
+
+ You may convey a covered work in object code form under the terms
+of sections 4 and 5, provided that you also convey the
+machine-readable Corresponding Source under the terms of this License,
+in one of these ways:
+
+ a) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by the
+ Corresponding Source fixed on a durable physical medium
+ customarily used for software interchange.
+
+ b) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by a
+ written offer, valid for at least three years and valid for as
+ long as you offer spare parts or customer support for that product
+ model, to give anyone who possesses the object code either (1) a
+ copy of the Corresponding Source for all the software in the
+ product that is covered by this License, on a durable physical
+ medium customarily used for software interchange, for a price no
+ more than your reasonable cost of physically performing this
+ conveying of source, or (2) access to copy the
+ Corresponding Source from a network server at no charge.
+
+ c) Convey individual copies of the object code with a copy of the
+ written offer to provide the Corresponding Source. This
+ alternative is allowed only occasionally and noncommercially, and
+ only if you received the object code with such an offer, in accord
+ with subsection 6b.
+
+ d) Convey the object code by offering access from a designated
+ place (gratis or for a charge), and offer equivalent access to the
+ Corresponding Source in the same way through the same place at no
+ further charge. You need not require recipients to copy the
+ Corresponding Source along with the object code. If the place to
+ copy the object code is a network server, the Corresponding Source
+ may be on a different server (operated by you or a third party)
+ that supports equivalent copying facilities, provided you maintain
+ clear directions next to the object code saying where to find the
+ Corresponding Source. Regardless of what server hosts the
+ Corresponding Source, you remain obligated to ensure that it is
+ available for as long as needed to satisfy these requirements.
+
+ e) Convey the object code using peer-to-peer transmission, provided
+ you inform other peers where the object code and Corresponding
+ Source of the work are being offered to the general public at no
+ charge under subsection 6d.
+
+ A separable portion of the object code, whose source code is excluded
+from the Corresponding Source as a System Library, need not be
+included in conveying the object code work.
+
+ A "User Product" is either (1) a "consumer product", which means any
+tangible personal property which is normally used for personal, family,
+or household purposes, or (2) anything designed or sold for incorporation
+into a dwelling. In determining whether a product is a consumer product,
+doubtful cases shall be resolved in favor of coverage. For a particular
+product received by a particular user, "normally used" refers to a
+typical or common use of that class of product, regardless of the status
+of the particular user or of the way in which the particular user
+actually uses, or expects or is expected to use, the product. A product
+is a consumer product regardless of whether the product has substantial
+commercial, industrial or non-consumer uses, unless such uses represent
+the only significant mode of use of the product.
+
+ "Installation Information" for a User Product means any methods,
+procedures, authorization keys, or other information required to install
+and execute modified versions of a covered work in that User Product from
+a modified version of its Corresponding Source. The information must
+suffice to ensure that the continued functioning of the modified object
+code is in no case prevented or interfered with solely because
+modification has been made.
+
+ If you convey an object code work under this section in, or with, or
+specifically for use in, a User Product, and the conveying occurs as
+part of a transaction in which the right of possession and use of the
+User Product is transferred to the recipient in perpetuity or for a
+fixed term (regardless of how the transaction is characterized), the
+Corresponding Source conveyed under this section must be accompanied
+by the Installation Information. But this requirement does not apply
+if neither you nor any third party retains the ability to install
+modified object code on the User Product (for example, the work has
+been installed in ROM).
+
+ The requirement to provide Installation Information does not include a
+requirement to continue to provide support service, warranty, or updates
+for a work that has been modified or installed by the recipient, or for
+the User Product in which it has been modified or installed. Access to a
+network may be denied when the modification itself materially and
+adversely affects the operation of the network or violates the rules and
+protocols for communication across the network.
+
+ Corresponding Source conveyed, and Installation Information provided,
+in accord with this section must be in a format that is publicly
+documented (and with an implementation available to the public in
+source code form), and must require no special password or key for
+unpacking, reading or copying.
+
+ 7. Additional Terms.
+
+ "Additional permissions" are terms that supplement the terms of this
+License by making exceptions from one or more of its conditions.
+Additional permissions that are applicable to the entire Program shall
+be treated as though they were included in this License, to the extent
+that they are valid under applicable law. If additional permissions
+apply only to part of the Program, that part may be used separately
+under those permissions, but the entire Program remains governed by
+this License without regard to the additional permissions.
+
+ When you convey a copy of a covered work, you may at your option
+remove any additional permissions from that copy, or from any part of
+it. (Additional permissions may be written to require their own
+removal in certain cases when you modify the work.) You may place
+additional permissions on material, added by you to a covered work,
+for which you have or can give appropriate copyright permission.
+
+ Notwithstanding any other provision of this License, for material you
+add to a covered work, you may (if authorized by the copyright holders of
+that material) supplement the terms of this License with terms:
+
+ a) Disclaiming warranty or limiting liability differently from the
+ terms of sections 15 and 16 of this License; or
+
+ b) Requiring preservation of specified reasonable legal notices or
+ author attributions in that material or in the Appropriate Legal
+ Notices displayed by works containing it; or
+
+ c) Prohibiting misrepresentation of the origin of that material, or
+ requiring that modified versions of such material be marked in
+ reasonable ways as different from the original version; or
+
+ d) Limiting the use for publicity purposes of names of licensors or
+ authors of the material; or
+
+ e) Declining to grant rights under trademark law for use of some
+ trade names, trademarks, or service marks; or
+
+ f) Requiring indemnification of licensors and authors of that
+ material by anyone who conveys the material (or modified versions of
+ it) with contractual assumptions of liability to the recipient, for
+ any liability that these contractual assumptions directly impose on
+ those licensors and authors.
+
+ All other non-permissive additional terms are considered "further
+restrictions" within the meaning of section 10. If the Program as you
+received it, or any part of it, contains a notice stating that it is
+governed by this License along with a term that is a further
+restriction, you may remove that term. If a license document contains
+a further restriction but permits relicensing or conveying under this
+License, you may add to a covered work material governed by the terms
+of that license document, provided that the further restriction does
+not survive such relicensing or conveying.
+
+ If you add terms to a covered work in accord with this section, you
+must place, in the relevant source files, a statement of the
+additional terms that apply to those files, or a notice indicating
+where to find the applicable terms.
+
+ Additional terms, permissive or non-permissive, may be stated in the
+form of a separately written license, or stated as exceptions;
+the above requirements apply either way.
+
+ 8. Termination.
+
+ You may not propagate or modify a covered work except as expressly
+provided under this License. Any attempt otherwise to propagate or
+modify it is void, and will automatically terminate your rights under
+this License (including any patent licenses granted under the third
+paragraph of section 11).
+
+ However, if you cease all violation of this License, then your
+license from a particular copyright holder is reinstated (a)
+provisionally, unless and until the copyright holder explicitly and
+finally terminates your license, and (b) permanently, if the copyright
+holder fails to notify you of the violation by some reasonable means
+prior to 60 days after the cessation.
+
+ Moreover, your license from a particular copyright holder is
+reinstated permanently if the copyright holder notifies you of the
+violation by some reasonable means, this is the first time you have
+received notice of violation of this License (for any work) from that
+copyright holder, and you cure the violation prior to 30 days after
+your receipt of the notice.
+
+ Termination of your rights under this section does not terminate the
+licenses of parties who have received copies or rights from you under
+this License. If your rights have been terminated and not permanently
+reinstated, you do not qualify to receive new licenses for the same
+material under section 10.
+
+ 9. Acceptance Not Required for Having Copies.
+
+ You are not required to accept this License in order to receive or
+run a copy of the Program. Ancillary propagation of a covered work
+occurring solely as a consequence of using peer-to-peer transmission
+to receive a copy likewise does not require acceptance. However,
+nothing other than this License grants you permission to propagate or
+modify any covered work. These actions infringe copyright if you do
+not accept this License. Therefore, by modifying or propagating a
+covered work, you indicate your acceptance of this License to do so.
+
+ 10. Automatic Licensing of Downstream Recipients.
+
+ Each time you convey a covered work, the recipient automatically
+receives a license from the original licensors, to run, modify and
+propagate that work, subject to this License. You are not responsible
+for enforcing compliance by third parties with this License.
+
+ An "entity transaction" is a transaction transferring control of an
+organization, or substantially all assets of one, or subdividing an
+organization, or merging organizations. If propagation of a covered
+work results from an entity transaction, each party to that
+transaction who receives a copy of the work also receives whatever
+licenses to the work the party's predecessor in interest had or could
+give under the previous paragraph, plus a right to possession of the
+Corresponding Source of the work from the predecessor in interest, if
+the predecessor has it or can get it with reasonable efforts.
+
+ You may not impose any further restrictions on the exercise of the
+rights granted or affirmed under this License. For example, you may
+not impose a license fee, royalty, or other charge for exercise of
+rights granted under this License, and you may not initiate litigation
+(including a cross-claim or counterclaim in a lawsuit) alleging that
+any patent claim is infringed by making, using, selling, offering for
+sale, or importing the Program or any portion of it.
+
+ 11. Patents.
+
+ A "contributor" is a copyright holder who authorizes use under this
+License of the Program or a work on which the Program is based. The
+work thus licensed is called the contributor's "contributor version".
+
+ A contributor's "essential patent claims" are all patent claims
+owned or controlled by the contributor, whether already acquired or
+hereafter acquired, that would be infringed by some manner, permitted
+by this License, of making, using, or selling its contributor version,
+but do not include claims that would be infringed only as a
+consequence of further modification of the contributor version. For
+purposes of this definition, "control" includes the right to grant
+patent sublicenses in a manner consistent with the requirements of
+this License.
+
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
+patent license under the contributor's essential patent claims, to
+make, use, sell, offer for sale, import and otherwise run, modify and
+propagate the contents of its contributor version.
+
+ In the following three paragraphs, a "patent license" is any express
+agreement or commitment, however denominated, not to enforce a patent
+(such as an express permission to practice a patent or covenant not to
+sue for patent infringement). To "grant" such a patent license to a
+party means to make such an agreement or commitment not to enforce a
+patent against the party.
+
+ If you convey a covered work, knowingly relying on a patent license,
+and the Corresponding Source of the work is not available for anyone
+to copy, free of charge and under the terms of this License, through a
+publicly available network server or other readily accessible means,
+then you must either (1) cause the Corresponding Source to be so
+available, or (2) arrange to deprive yourself of the benefit of the
+patent license for this particular work, or (3) arrange, in a manner
+consistent with the requirements of this License, to extend the patent
+license to downstream recipients. "Knowingly relying" means you have
+actual knowledge that, but for the patent license, your conveying the
+covered work in a country, or your recipient's use of the covered work
+in a country, would infringe one or more identifiable patents in that
+country that you have reason to believe are valid.
+
+ If, pursuant to or in connection with a single transaction or
+arrangement, you convey, or propagate by procuring conveyance of, a
+covered work, and grant a patent license to some of the parties
+receiving the covered work authorizing them to use, propagate, modify
+or convey a specific copy of the covered work, then the patent license
+you grant is automatically extended to all recipients of the covered
+work and works based on it.
+
+ A patent license is "discriminatory" if it does not include within
+the scope of its coverage, prohibits the exercise of, or is
+conditioned on the non-exercise of one or more of the rights that are
+specifically granted under this License. You may not convey a covered
+work if you are a party to an arrangement with a third party that is
+in the business of distributing software, under which you make payment
+to the third party based on the extent of your activity of conveying
+the work, and under which the third party grants, to any of the
+parties who would receive the covered work from you, a discriminatory
+patent license (a) in connection with copies of the covered work
+conveyed by you (or copies made from those copies), or (b) primarily
+for and in connection with specific products or compilations that
+contain the covered work, unless you entered into that arrangement,
+or that patent license was granted, prior to 28 March 2007.
+
+ Nothing in this License shall be construed as excluding or limiting
+any implied license or other defenses to infringement that may
+otherwise be available to you under applicable patent law.
+
+ 12. No Surrender of Others' Freedom.
+
+ If conditions are imposed on you (whether by court order, agreement or
+otherwise) that contradict the conditions of this License, they do not
+excuse you from the conditions of this License. If you cannot convey a
+covered work so as to satisfy simultaneously your obligations under this
+License and any other pertinent obligations, then as a consequence you may
+not convey it at all. For example, if you agree to terms that obligate you
+to collect a royalty for further conveying from those to whom you convey
+the Program, the only way you could satisfy both those terms and this
+License would be to refrain entirely from conveying the Program.
+
+ 13. Use with the GNU Affero General Public License.
+
+ Notwithstanding any other provision of this License, you have
+permission to link or combine any covered work with a work licensed
+under version 3 of the GNU Affero General Public License into a single
+combined work, and to convey the resulting work. The terms of this
+License will continue to apply to the part which is the covered work,
+but the special requirements of the GNU Affero General Public License,
+section 13, concerning interaction through a network will apply to the
+combination as such.
+
+ 14. Revised Versions of this License.
+
+ The Free Software Foundation may publish revised and/or new versions of
+the GNU General Public License from time to time. Such new versions will
+be similar in spirit to the present version, but may differ in detail to
+address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Program specifies that a certain numbered version of the GNU General
+Public License "or any later version" applies to it, you have the
+option of following the terms and conditions either of that numbered
+version or of any later version published by the Free Software
+Foundation. If the Program does not specify a version number of the
+GNU General Public License, you may choose any version ever published
+by the Free Software Foundation.
+
+ If the Program specifies that a proxy can decide which future
+versions of the GNU General Public License can be used, that proxy's
+public statement of acceptance of a version permanently authorizes you
+to choose that version for the Program.
+
+ Later license versions may give you additional or different
+permissions. However, no additional obligations are imposed on any
+author or copyright holder as a result of your choosing to follow a
+later version.
+
+ 15. Disclaimer of Warranty.
+
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
+APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
+HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
+OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
+THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
+IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
+ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
+
+ 16. Limitation of Liability.
+
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
+THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
+GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
+USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
+DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
+PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
+EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
+SUCH DAMAGES.
+
+ 17. Interpretation of Sections 15 and 16.
+
+ If the disclaimer of warranty and limitation of liability provided
+above cannot be given local legal effect according to their terms,
+reviewing courts shall apply local law that most closely approximates
+an absolute waiver of all civil liability in connection with the
+Program, unless a warranty or assumption of liability accompanies a
+copy of the Program in return for a fee.
+
+ END OF TERMS AND CONDITIONS
+
+ How to Apply These Terms to Your New Programs
+
+ If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these terms.
+
+ To do so, attach the following notices to the program. It is safest
+to attach them to the start of each source file to most effectively
+state the exclusion of warranty; and each file should have at least
+the "copyright" line and a pointer to where the full notice is found.
+
+ <one line to give the program's name and a brief idea of what it does.>
+ Copyright (C) <year> <name of author>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+Also add information on how to contact you by electronic and paper mail.
+
+ If the program does terminal interaction, make it output a short
+notice like this when it starts in an interactive mode:
+
+ <program> Copyright (C) <year> <name of author>
+ This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
+ This is free software, and you are welcome to redistribute it
+ under certain conditions; type `show c' for details.
+
+The hypothetical commands `show w' and `show c' should show the appropriate
+parts of the General Public License. Of course, your program's commands
+might be different; for a GUI interface, you would use an "about box".
+
+ You should also get your employer (if you work as a programmer) or school,
+if any, to sign a "copyright disclaimer" for the program, if necessary.
+For more information on this, and how to apply and follow the GNU GPL, see
+<https://www.gnu.org/licenses/>.
+
+ The GNU General Public License does not permit incorporating your program
+into proprietary programs. If your program is a subroutine library, you
+may consider it more useful to permit linking proprietary applications with
+the library. If this is what you want to do, use the GNU Lesser General
+Public License instead of this License. But first, please read
+<https://www.gnu.org/licenses/why-not-lgpl.html>.
+
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 000000000..42eb4101e
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1 @@
+include LICENSE.txt
diff --git a/README.md b/README.md
new file mode 100644
index 000000000..aa102542a
--- /dev/null
+++ b/README.md
@@ -0,0 +1,225 @@
+
+> :warning: **py4DSTEM version 0.14 update** :warning: This is a major update, and we expect some workflows to break. You can still install previous versions of py4DSTEM [as discussed here](#legacyinstall).
+
+
+
+![py4DSTEM logo](/images/py4DSTEM_logo.png)
+
+**py4DSTEM** is an open-source set of Python tools for processing and analysis of four-dimensional scanning transmission electron microscopy (4D-STEM) data.
+Additional information:
+
+- [Installation instructions](#install)
+- [The py4DSTEM documentation pages](https://py4dstem.readthedocs.io/en/latest/index.html).
+- [Tutorials and example code](https://github.com/py4dstem/py4DSTEM_tutorials)
+- [Our open access py4DSTEM publication in Microscopy and Microanalysis](https://doi.org/10.1017/S1431927621000477) describing this project and demonstrating a variety of applications.
+- [Our open access 4D-STEM review in Microscopy and Microanalysis](https://doi.org/10.1017/S1431927619000497) surveying 4D-STEM methods and applications across the field.
+
+
+
+# What is 4D-STEM?
+
+In a traditional STEM experiment, a beam of high-energy electrons is focused to a very fine probe - on the order of, or even smaller than, the spacing between atoms - and rastered across the surface of the sample. A conventional two-dimensional STEM image is formed by populating the value of each pixel with the electron flux through a detector at the corresponding beam position. In 4D-STEM, a pixelated detector is used instead, recording a 2D image of the diffracted probe at every rastered probe position. A 4D-STEM scan thus results in a 4D data array.
+
+
+4D-STEM data is information-rich.
+A datacube can be collapsed in real space to yield information comparable to a nanobeam electron diffraction experiment, or in diffraction space to yield a variety of virtual images, corresponding to both traditional STEM imaging modes as well as more exotic virtual imaging modalities.
+The structure, symmetries, and spacings of Bragg disks can be used to extract spatially resolved maps of crystallinity, grain orientations, and lattice strain.
+Redundant information in overlapping Bragg disks can be leveraged to calculate the sample potential.
+Structure in the diffracted halos of amorphous systems can be used to describe the short- and medium-range order.
+
+**py4DSTEM** supports many different modes of 4D-STEM analysis.
+The tutorials, sample code, module, and function documentation all provide more detailed discussion on some of the analytical methods possible with this code.
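+
+To make this concrete, below is a minimal sketch of virtual imaging using plain NumPy - the array names and sizes are hypothetical, chosen only for illustration. It collapses diffraction space at every probe position through a central-disk "virtual detector", as described above (py4DSTEM itself provides dedicated routines for these analyses; this only illustrates the data layout):
+
+```
+import numpy as np
+
+# A synthetic 4D-STEM datacube: (scan rows, scan cols, detector rows, detector cols)
+datacube = np.random.poisson(1.0, size=(32, 32, 64, 64)).astype(float)
+
+# A virtual detector: a boolean mask selecting a central disk in diffraction space
+qy, qx = np.indices((64, 64))
+mask = (qx - 32) ** 2 + (qy - 32) ** 2 < 8 ** 2
+
+# Collapse diffraction space at each probe position to get a 2D virtual image
+virtual_image = (datacube * mask).sum(axis=(2, 3))
+print(virtual_image.shape)  # (32, 32)
+```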
+
+
+
+<a name="install"></a>
+# py4DSTEM Installation
+
+[![PyPI version](https://badge.fury.io/py/py4dstem.svg)](https://badge.fury.io/py/py4dstem)
+[![Anaconda-Server Badge](https://anaconda.org/conda-forge/py4dstem/badges/version.svg)](https://anaconda.org/conda-forge/py4dstem)
+[![Anaconda-Server Badge](https://anaconda.org/conda-forge/py4dstem/badges/latest_release_date.svg)](https://anaconda.org/conda-forge/py4dstem)
+[![Anaconda-Server Badge](https://anaconda.org/conda-forge/py4dstem/badges/platforms.svg)](https://anaconda.org/conda-forge/py4dstem)
+[![Anaconda-Server Badge](https://anaconda.org/conda-forge/py4dstem/badges/downloads.svg)](https://anaconda.org/conda-forge/py4dstem)
+
+The recommended installation for **py4DSTEM** uses the Anaconda Python distribution.
+First, download and install Anaconda: www.anaconda.com/download.
+If you prefer a more lightweight conda client, you can instead install Miniconda: https://docs.conda.io/en/latest/miniconda.html.
+Then open a conda terminal and run one of the following sets of commands to ensure everything is up-to-date and create a new environment for your py4DSTEM installation:
+
+```
+conda update conda
+conda create -n py4dstem
+conda activate py4dstem
+conda install -c conda-forge py4dstem pymatgen jupyterlab
+```
+
+In order, these commands
+- ensure your `conda` package manager is up-to-date
+- make a virtual environment (see below)
+- enter the environment
+- install py4DSTEM, as well as pymatgen (used for crystal structure calculations) and JupyterLab (an interface for running Python notebooks like those in the [py4DSTEM tutorials repository](https://github.com/py4dstem/py4DSTEM_tutorials))
+
+
+We've had some recent reports of `conda` getting stuck trying to solve the environment with the installation above. If you run into this problem, you can install py4DSTEM using `pip` instead of `conda` by running:
+
+```
+conda update conda
+conda create -n py4dstem python=3.10
+conda activate py4dstem
+pip install py4dstem pymatgen
+```
+
+Both `conda` and `pip` are programs which manage package installations, i.e. they make sure the different packages you're installing, which may depend on one another, use mutually compatible versions. Each has advantages and disadvantages; `pip` is a little more bare-bones, and we've seen this install work when `conda` doesn't. If you also want to use JupyterLab, you can then run either `pip install jupyterlab` or `conda install jupyterlab`.
+
+If you would prefer to install only the base modules of **py4DSTEM**, and skip pymatgen and JupyterLab, you can instead run:
+
+```
+conda install -c conda-forge py4dstem
+```
+
+Finally, regardless of which of the above approaches you used, on Windows you should then also run:
+
+```
+conda install pywin32
+```
+
+which enables Python to talk to the Windows API.
+
+Please note that virtual environments are used in the instructions above to make sure packages with different dependencies don't conflict with one another.
+Because these directions install py4DSTEM to its own virtual environment, each time you want to use py4DSTEM you'll need to activate this environment.
+You can do this in the command line by running `conda activate py4dstem`, or, if you're using the Anaconda Navigator, by clicking on the Environments tab and then clicking on `py4dstem`.
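+
+For example, a typical command-line session (assuming you installed JupyterLab as above) might look like:
+
+```
+conda activate py4dstem
+jupyter lab
+```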
+
+Lastly - as of the version 0.14.4 update, we've had a few reports of problems upgrading to the newest version. We're not sure what's causing the issue yet, but we have found that in these cases the new version can be installed successfully using a fresh Anaconda installation.
+
+
+<a name="legacyinstall"></a>
+## Legacy installations (version <0.14)
+
+The latest version of py4DSTEM (v0.14) makes changes to the classes and functions which may not be compatible with code written for prior versions.
+We are working to ensure better backwards-compatibility in the future.
+For now, if you have code from earlier versions, you can either (1) install the legacy version of your choice, or (2) update legacy code to use the version 0.14 methods.
+To update your code to the new syntax, check out the examples in the [py4DSTEM tutorials repository](https://github.com/py4dstem/py4DSTEM_tutorials) and the docstrings for the classes and functions you're using.
+To install a legacy version of py4DSTEM, you can call
+
+```
+pip install py4dstem==0.XX.XX
+```
+
+substituting the desired version for `XX.XX`. For instance, you can install the last version 13 release with
+
+```
+pip install py4dstem==0.13.17
+```
+
+or the last version 12 release with
+
+```
+pip install py4dstem==0.12.24
+```
+
+
+
+
+## Advanced installations - GPU acceleration and ML functionality
+
+To install py4DSTEM with AI/ML functionality, follow the steps below.
+If you're using a machine with an Nvidia GPU and CUDA capability, run:
+
+```
+conda update conda
+conda create -n py4dstem-aiml
+conda activate py4dstem-aiml
+conda install -c conda-forge cudatoolkit=11.0 cudnn=8.1 cupy
+pip install "py4dstem[aiml-cuda]"
+```
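+
+As an optional sanity check, you can confirm that CuPy can see the GPU (assuming a working Nvidia driver) by running:
+
+```
+python -c "import cupy; print(cupy.cuda.runtime.getDeviceCount())"
+```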
+
+If your machine does not have a CUDA-capable device, run:
+```
+conda update conda
+conda create -n py4dstem
+conda activate py4dstem
+conda install pip
+pip install "py4dstem[aiml]"
+```
+
+
+
+# The py4DSTEM GUI
+
+The py4DSTEM GUI data browser has been moved to a separate repository.
+You can [find that repository here](https://github.com/py4dstem/py4D-browser).
+You can install the GUI from the command line with:
+
+```
+pip install py4D-browser
+```
+
+The py4D-browser can then be launched from the command line by calling:
+
+```
+py4DGUI
+```
+
+
+
+
+# Running the code
+
+The Anaconda Navigator can be used to launch various Python interfaces, including Jupyter Notebooks, JupyterLab, PyCharm, and others.
+
+Once you're inside the conda environment where you installed py4DSTEM and you've launched an interface to the Python interpreter, you can import **py4DSTEM** to access all its modules and functions using `import py4DSTEM`.
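+
+For example, from a Python interpreter or notebook running in that environment, you can confirm the installation and check which version you have with:
+
+```
+import py4DSTEM
+print(py4DSTEM.__version__)
+```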
+
+
+## Example code and tutorials
+
+At this point you'll need code, and data!
+Sample code demonstrating a variety of workflows can be found in [the py4DSTEM tutorials repository](https://github.com/py4dstem/py4DSTEM_tutorials) in the `/notebooks` directory.
+These sample files are provided as Jupyter notebooks.
+Links to the data used in each notebook are provided in its introductory cell.
+
+
+
+
+# More information
+
+## Documentation
+
+Our documentation pages are [available here](https://py4dstem.readthedocs.io/en/latest/index.html).
+
+
+
+## For contributors
+
+Please see [here](https://gist.github.com/bsavitzky/8b1ee4c1244814940e7cff4500535dba).
+
+
+## Scientific papers which use **py4DSTEM**
+
+See a list [here](docs/papers.md).
+
+
+
+
+# Acknowledgements
+
+If you use py4DSTEM for a scientific study, please cite [our open access py4DSTEM publication in Microscopy and Microanalysis](https://doi.org/10.1017/S1431927621000477). You are also free to use the py4DSTEM [logo in PDF format](images/py4DSTEM_logo_54.pdf) or [logo in PNG format](images/py4DSTEM_logo_54_export.png) for presentations or posters.
+
+
+[![TRI logo](/images/toyota_research_institute.png)](https://www.tri.global/)
+
+
+The developers gratefully acknowledge the financial support of the Toyota Research Institute for the research and development time which made this project possible.
+
+[![DOE logo](/images/DOE_logo.png)](https://www.energy.gov/science/bes/basic-energy-sciences/)
+
+Additional funding has been provided by the US Department of Energy, Office of Science, Basic Energy Sciences.
+
+
+
+# License
+
+GNU GPLv3
+
+**py4DSTEM** is open source software distributed under a GPLv3 license.
+It is free to use, alter, or build on, provided that any work derived from **py4DSTEM** is also kept free and open under a GPLv3 license.
+
diff --git a/ScopeFoundry/__init__.py b/ScopeFoundry/__init__.py
deleted file mode 100644
index 2017aa413..000000000
--- a/ScopeFoundry/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from .base_gui import BaseMicroscopeApp, BaseApp
-from .measurement import Measurement
-from .hardware import HardwareComponent
-from .logged_quantity import LoggedQuantity, LQRange, LQCollection
diff --git a/ScopeFoundry/base_gui.py b/ScopeFoundry/base_gui.py
deleted file mode 100644
index dcfe26ab2..000000000
--- a/ScopeFoundry/base_gui.py
+++ /dev/null
@@ -1,374 +0,0 @@
-'''
-Created on Jul 23, 2014
-
-'''
-
-import sys, os
-import time
-import datetime
-import numpy as np
-import collections
-from collections import OrderedDict
-import configparser
-
-from PySide2 import QtCore, QtGui, QtUiTools
-import pyqtgraph as pg
-#import pyqtgraph.console
-import IPython
-if IPython.version_info[0] < 4:
- from IPython.qt.console.rich_ipython_widget import RichIPythonWidget as RichJupyterWidget
- from IPython.qt.inprocess import QtInProcessKernelManager
-else:
- from qtconsole.rich_jupyter_widget import RichJupyterWidget
- from qtconsole.inprocess import QtInProcessKernelManager
-
-import matplotlib
-matplotlib.rcParams['backend.qt4'] = 'PySide'
-from matplotlib.backends.backend_qt4agg import FigureCanvasQTAgg as FigureCanvas
-from matplotlib.backends.backend_qt4agg import NavigationToolbar2QT as NavigationToolbar2
-
-from matplotlib.figure import Figure
-
-from .logged_quantity import LoggedQuantity, LQCollection
-
-from .helper_funcs import confirm_on_close, load_qt_ui_file, OrderedAttrDict
-
-#from equipment.image_display import ImageDisplay
-
-from .h5_io import h5_base_file, h5_create_measurement_group
-
-
-class BaseApp(QtCore.QObject):
-
- def __init__(self, argv):
-
- self.this_dir, self.this_filename = os.path.split(__file__)
-
-
- self.qtapp = QtGui.QApplication.instance()
- if not self.qtapp:
- self.qtapp = QtGui.QApplication(argv)
-
-
-
- self.settings = LQCollection()
-
- self.setup_console_widget()
- self.setup()
-
- if not hasattr(self, 'name'):
- self.name = "ScopeFoundry"
- self.qtapp.setApplicationName(self.name)
-
-
- def exec_(self):
- return self.qtapp.exec_()
-
- def setup_console_widget(self):
- # Console
- #self.console_widget = pyqtgraph.console.ConsoleWidget(namespace={'gui':self, 'pg':pg, 'np':np}, text="ScopeFoundry GUI console")
- # https://github.com/ipython/ipython-in-depth/blob/master/examples/Embedding/inprocess_qtconsole.py
- self.kernel_manager = QtInProcessKernelManager()
- self.kernel_manager.start_kernel()
- self.kernel = self.kernel_manager.kernel
- self.kernel.gui = 'qt4'
- self.kernel.shell.push({'np': np, 'app': self})
- self.kernel_client = self.kernel_manager.client()
- self.kernel_client.start_channels()
-
- #self.console_widget = RichIPythonWidget()
- self.console_widget = RichJupyterWidget()
- self.console_widget.setWindowTitle("ScopeFoundry IPython Console")
- self.console_widget.kernel_manager = self.kernel_manager
- self.console_widget.kernel_client = self.kernel_client
-
- return self.console_widget
-
- def setup(self):
- pass
-
-
- def settings_save_ini(self, fname, save_ro=True):
- config = ConfigParser.ConfigParser()
- config.optionxform = str
- config.add_section('app')
- config.set('app', 'name', self.name)
- for lqname, lq in self.settings.as_dict().items():
- if not lq.ro or save_ro:
- config.set('app', lqname, lq.ini_string_value())
-
- with open(fname, 'wb') as configfile:
- config.write(configfile)
-
- print("ini settings saved to", fname, config.optionxform)
-
- def settings_load_ini(self, fname):
- print("ini settings loading from", fname)
-
- def str2bool(v):
- return v.lower() in ("yes", "true", "t", "1")
-
- config = ConfigParser.ConfigParser()
- config.optionxform = str
- config.read(fname)
-
- if 'app' in config.sections():
- for lqname, new_val in config.items('app'):
- print(lqname)
- lq = self.settings.as_dict().get(lqname)
- if lq:
- if lq.dtype == bool:
- new_val = str2bool(new_val)
- lq.update_value(new_val)
-
- def settings_save_ini_ask(self, dir=None, save_ro=True):
- # TODO add default directory, etc
- fname, _ = QtGui.QFileDialog.getSaveFileName(self.ui, caption=u'Save Settings', dir=u"", filter=u"Settings (*.ini)")
- print(repr(fname))
- if fname:
- self.settings_save_ini(fname, save_ro=save_ro)
- return fname
-
- def settings_load_ini_ask(self, dir=None):
- # TODO add default directory, etc
- fname, _ = QtGui.QFileDialog.getOpenFileName(None, "Settings (*.ini)")
- print(repr(fname))
- if fname:
- self.settings_load_ini(fname)
- return fname
-
-class BaseMicroscopeApp(BaseApp):
-
-
-
- name = "ScopeFoundry"
-
- def __del__ ( self ):
- self.ui = None
-
- def show(self):
- #self.ui.exec_()
- self.ui.show()
-
- def __init__(self, argv):
- BaseApp.__init__(self, argv)
- ui_filename = os.path.join(self.this_dir,"base_microscope_app.ui")
-
- self.hardware = OrderedAttrDict()
- self.measurements = OrderedAttrDict()
-
- # Load Qt UI from .ui file
- self.ui = load_qt_ui_file(self.ui_filename)
-
- confirm_on_close(self.ui, title="Close %s?" % self.name, message="Do you wish to shut down %s?" % self.name)
-
- # Run the subclass setup function
- self.setup()
-
- self.setup_default_ui()
-
-
- def setup_default_ui(self):
- self.ui.hardware_treeWidget.setColumnWidth(0,175)
- self.ui.measurements_treeWidget.setColumnWidth(0,175)
-
- # Setup the figures
- for name, measure in self.measurement_components.items():
- print("setting up figures for", name, "measurement", measure.name)
- measure.setup_figure()
-
- if hasattr(self.ui, 'console_pushButton'):
- self.ui.console_pushButton.clicked.connect(self.console_widget.show)
- self.ui.console_pushButton.clicked.connect(self.console_widget.activateWindow)
-
- #settings events
- if hasattr(self.ui, "settings_autosave_pushButton"):
- self.ui.settings_autosave_pushButton.clicked.connect(self.settings_auto_save)
- if hasattr(self.ui, "settings_load_last_pushButton"):
- self.ui.settings_load_last_pushButton.clicked.connect(self.settings_load_last)
- if hasattr(self.ui, "settings_save_pushButton"):
- self.ui.settings_save_pushButton.clicked.connect(self.settings_save_dialog)
- if hasattr(self.ui, "settings_load_pushButton"):
- self.ui.settings_load_pushButton.clicked.connect(self.settings_load_dialog)
-
-
- def setup(self):
- """ Override to add Hardware and Measurement Components"""
- #raise NotImplementedError()
- pass
-
-
- """def add_image_display(self,name,widget):
- print "---adding figure", name, widget
- if name in self.figs:
- return self.figs[name]
- else:
- disp=ImageDisplay(name,widget)
- self.figs[name]=disp
- return disp
- """
-
- def add_pg_graphics_layout(self, name, widget):
- print("---adding pg GraphicsLayout figure", name, widget)
- if name in self.figs:
- return self.figs[name]
- else:
- disp=pg.GraphicsLayoutWidget(border=(100,100,100))
- widget.layout().addWidget(disp)
- self.figs[name]=disp
- return disp
-
- # IDEA: write an abstract function to add pg.imageItem() for maps,
- # which haddels, pixelscale, ROI ....
- # could also be implemented in the base_2d class?
-
-
-
- def add_figure_mpl(self,name, widget):
- """creates a matplotlib figure attaches it to the qwidget specified
- (widget needs to have a layout set (preferably verticalLayout)
- adds a figure to self.figs"""
- print("---adding figure", name, widget)
- if name in self.figs:
- return self.figs[name]
- else:
- fig = Figure()
- fig.patch.set_facecolor('w')
- canvas = FigureCanvas(fig)
- nav = NavigationToolbar2(canvas, self.ui)
- widget.layout().addWidget(canvas)
- widget.layout().addWidget(nav)
- canvas.setFocusPolicy( QtCore.Qt.ClickFocus )
- canvas.setFocus()
- self.figs[name] = fig
- return fig
-
- def add_figure(self,name,widget):
- return self.add_figure_mpl(name,widget)
-
-
- def add_hardware_component(self,hc):
- self.hardware_components[hc.name] = hc
- return hc
-
- def add_measurement_component(self, measure):
- assert not measure.name in self.measurement_components.keys()
- self.measurement_components[measure.name] = measure
- return measure
-
- def settings_save_h5(self, fname):
- with h5_base_file(self, fname) as h5_file:
- for measurement in self.measurements.values():
- h5_create_measurement_group(measurement, h5_file)
- print("settings saved to", h5_file.filename)
-
- def settings_save_ini(self, fname, save_ro=True, save_gui=True, save_hardware=True, save_measurements=True):
- import ConfigParser
- config = ConfigParser.ConfigParser()
- config.optionxform = str
- if save_gui:
- config.add_section('app')
- for lqname, lq in self.settings.items():
- config.set('app', lqname, lq.val)
- if save_hardware:
- for hc_name, hc in self.hardware.items():
- section_name = 'hardware/'+hc_name
- config.add_section(section_name)
- for lqname, lq in hc.settings.items():
- if not lq.ro or save_ro:
- config.set(section_name, lqname, lq.val)
- if save_measurements:
- for meas_name, measurement in self.measurements.items():
- section_name = 'measurement/'+meas_name
- config.add_section(section_name)
- for lqname, lq in measurement.settings.items():
- if not lq.ro or save_ro:
- config.set(section_name, lqname, lq.val)
- with open(fname, 'wb') as configfile:
- config.write(configfile)
-
- print("ini settings saved to", fname, config.optionxform)
-
-
-
- def settings_load_ini(self, fname):
- print("ini settings loading from", fname)
-
- def str2bool(v):
- return v.lower() in ("yes", "true", "t", "1")
-
-
- import ConfigParser
- config = ConfigParser.ConfigParser()
- config.optionxform = str
- config.read(fname)
-
- if 'app' in config.sections():
- for lqname, new_val in config.items('app'):
- lq = self.settings[lqname]
- if lq.dtype == bool:
- new_val = str2bool(new_val)
- lq.update_value(new_val)
-
- for hc_name, hc in self.hardware_components.items():
- section_name = 'hardware/'+hc_name
- print(section_name)
- if section_name in config.sections():
- for lqname, new_val in config.items(section_name):
- try:
- lq = hc.settings[lqname]
- if lq.dtype == bool:
- new_val = str2bool(new_val)
- if not lq.ro:
- lq.update_value(new_val)
- except Exception as err:
- print("-->Failed to load config for {}/{}, new val {}: {}".format(section_name, lqname, new_val, repr(err)))
-
- for meas_name, measurement in self.measurement_components.items():
- section_name = 'measurement/'+meas_name
- if section_name in config.sections():
- for lqname, new_val in config.items(section_name):
- lq = measurement.logged_quantities[lqname]
- if lq.dtype == bool:
- new_val = str2bool(new_val)
- if not lq.ro:
- lq.update_value(new_val)
-
- print("ini settings loaded from", fname)
-
- def settings_load_h5(self, fname):
- import h5py
- with h5py.File(fname) as h5_file:
- pass
-
- def settings_auto_save(self):
- #fname = "%i_settings.h5" % time.time()
- #self.settings_save_h5(fname)
- self.settings_save_ini("%i_settings.ini" % time.time())
-
- def settings_load_last(self):
- import glob
- #fname = sorted(glob.glob("*_settings.h5"))[-1]
- #self.settings_load_h5(fname)
- fname = sorted(glob.glob("*_settings.ini"))[-1]
- self.settings_load_ini(fname)
-
-
- def settings_save_dialog(self):
- fname, selectedFilter = QtGui.QFileDialog.getSaveFileName(self.ui, "Save Settings file", "", "Settings File (*.ini)")
- if fname:
- self.settings_save_ini(fname)
-
- def settings_load_dialog(self):
- fname, selectedFilter = QtGui.QFileDialog.getOpenFileName(self.ui,"Open Settings file", "", "Settings File (*.ini *.h5)")
- self.settings_load_ini(fname)
-
-
-
-
-
-if __name__ == '__main__':
-
- app = BaseMicroscopeApp(sys.argv)
-
- sys.exit(app.exec_())
diff --git a/ScopeFoundry/base_microscope_app.ui b/ScopeFoundry/base_microscope_app.ui
deleted file mode 100644
index 3a92c7653..000000000
--- a/ScopeFoundry/base_microscope_app.ui
+++ /dev/null
@@ -1,111 +0,0 @@
-[Qt Designer XML, tags stripped in this diff: a 438x600 "ScopeFoundry" MainWindow containing a Logged Quantities dock with Hardware and Measurements tree widgets, Settings buttons (Load Last, Auto Save, Save..., Load...), a Console... button, and a menu bar.]
diff --git a/ScopeFoundry/examples/__init__.py b/ScopeFoundry/examples/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/ScopeFoundry/examples/example_gui.py b/ScopeFoundry/examples/example_gui.py
deleted file mode 100644
index 228543490..000000000
--- a/ScopeFoundry/examples/example_gui.py
+++ /dev/null
@@ -1,40 +0,0 @@
-import sys
-from PySide import QtGui
-
-from base_gui import BaseMicroscopeGUI
-
-# Import Hardware Components
-#from hardware_components.picoharp import PicoHarpHardwareComponent
-
-# Import Measurement Components
-#from measurement_components.powermeter_optimizer import PowerMeterOptimizerMeasurement
-
-class ExampleMicroscopeGUI(BaseMicroscopeGUI):
-
- ui_filename = "example_gui.ui"
-
- def setup(self):
- #Add hardware components
- print "Adding Hardware Components"
-
- #self.picoharp_hc = self.add_hardware_component(PicoHarpHardwareComponent(self))
-
- #Add measurement components
- print "Create Measurement objects"
- #self.apd_optimizer_measure = self.add_measurement_component(APDOptimizerMeasurement(self))
-
-
- #Add additional logged quantities
-
- # Connect to custom gui
-
-
-
-if __name__ == '__main__':
- app = QtGui.QApplication(sys.argv)
- app.setApplicationName("Example Foundry Scope App")
-
- gui = ExampleMicroscopeGUI()
- gui.show()
-
- sys.exit(app.exec_())
\ No newline at end of file
diff --git a/ScopeFoundry/examples/example_gui.ui b/ScopeFoundry/examples/example_gui.ui
deleted file mode 100644
index d03ebaa32..000000000
--- a/ScopeFoundry/examples/example_gui.ui
+++ /dev/null
@@ -1,338 +0,0 @@
-[Qt Designer XML, tags stripped in this diff: a 1395x1028 "Microscope" MainWindow with a "2D Scan Area" control frame (H, V, dH, dV spin boxes in µm plus Start/Stop buttons), a tab widget holding an "Example Scan" tab (Start/Interrupt Example Scan and Clear Figure controls with a Plot area) and scrollable Hardware and Measurement tabs, and a menu bar.]
diff --git a/ScopeFoundry/examples/example_xy_slowscan.py b/ScopeFoundry/examples/example_xy_slowscan.py
deleted file mode 100644
index cfb68dde3..000000000
--- a/ScopeFoundry/examples/example_xy_slowscan.py
+++ /dev/null
@@ -1,48 +0,0 @@
-import sys
-from PySide import QtGui
-
-from ScopeFoundry import BaseMicroscopeGUI
-
-# Import Hardware Components
-from hardware_components.apd_counter import APDCounterHardwareComponent
-from ScopeFoundry.examples.hardware.dummy_xy_stage import DummyXYStage
-
-# Import Measurement Components
-from measurement_components.apd_optimizer_simple import APDOptimizerMeasurement
-from measurement_components.simple_xy_scan import SimpleXYScan
-
-
-class ExampleXYSlowscanGUI(BaseMicroscopeGUI):
-
- ui_filename = "../../ScopeFoundry/base_gui.ui"
-
- def setup(self):
- #Add hardware components
- print "Adding Hardware Components"
- self.add_hardware_component(APDCounterHardwareComponent(self))
- self.add_hardware_component(DummyXYStage(self))
-
- #Add measurement components
- print "Create Measurement objects"
- self.add_measurement_component(APDOptimizerMeasurement(self))
- self.add_measurement_component(SimpleXYScan(self))
-
- #set some default logged quantities
- self.hardware_components['apd_counter'].debug_mode.update_value(True)
- self.hardware_components['apd_counter'].dummy_mode.update_value(True)
- self.hardware_components['apd_counter'].connected.update_value(True)
-
- #Add additional logged quantities
-
- # Connect to custom gui
-
-
-
-if __name__ == '__main__':
- app = QtGui.QApplication(sys.argv)
- app.setApplicationName("Example XY slowscan App")
-
- gui = ExampleXYSlowscanGUI(app)
- gui.show()
-
- sys.exit(app.exec_())
diff --git a/ScopeFoundry/examples/hardware/__init__.py b/ScopeFoundry/examples/hardware/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/ScopeFoundry/examples/hardware/dummy_xy_stage.py b/ScopeFoundry/examples/hardware/dummy_xy_stage.py
deleted file mode 100644
index 73abc635c..000000000
--- a/ScopeFoundry/examples/hardware/dummy_xy_stage.py
+++ /dev/null
@@ -1,85 +0,0 @@
-from ScopeFoundry import HardwareComponent
-import random
-
-class DummmyXYStageEquipment(object):
-
- def __init__(self, debug=False):
- self.x = 0
- self.y = 0
- self.debug=debug
- # communicate with hardware here
-
- def read_x(self):
- self.x = self.x + self.noise()
- if self.debug: print "read_x", self.x
- return self.x
-
- def read_y(self):
- self.y = self.y + self.noise()
- if self.debug: print "read_y", self.y
- return self.y
-
- def write_x(self, x):
- self.x = x
- if self.debug: print "write_x", self.x
-
-
- def write_y(self, y):
- self.y = y
- if self.debug: print "write_y", self.y
-
- def close(self):
- print "dummy_xy_stage_equipment close"
-
- def noise(self):
- return (random.random()-0.5)*10e-3
-
-class DummyXYStage(HardwareComponent):
-
- name = "dummy_xy_stage"
-
- def setup(self):
- lq_params = dict( dtype=float, ro=False,
- initial = -1,
- vmin=-1,
- vmax=100,
- si = False,
- unit='um')
- self.x_position = self.add_logged_quantity("x_position", **lq_params)
- self.y_position = self.add_logged_quantity("y_position", **lq_params)
-
- self.x_position.reread_from_hardware_after_write = True
- self.x_position.spinbox_decimals = 3
-
- self.y_position.reread_from_hardware_after_write = True
- self.y_position.spinbox_decimals = 3
-
- def connect(self):
- if self.debug_mode.val: print "connecting to dummy_xy_stage"
-
- # Open connection to hardware
- self.stage_equip = DummmyXYStageEquipment(debug=True)
-
- # connect logged quantities
- self.x_position.hardware_read_func = self.stage_equip.read_x
- self.y_position.hardware_read_func = self.stage_equip.read_y
-
- self.x_position.hardware_set_func = self.stage_equip.write_x
- self.y_position.hardware_set_func = self.stage_equip.write_y
-
- def disconnect(self):
- if self.debug_mode.val: print "disconnecting to dummy_xy_stage"
-
- #disconnect logged quantities from hardware
- for lq in self.logged_quantities.values():
- lq.hardware_read_func = None
- lq.hardware_set_func = None
-
- #disconnect hardware
- self.stage_equip.close()
-
- # clean up hardware object
- del self.stage_equip
-
-
-
diff --git a/ScopeFoundry/h5_io.py b/ScopeFoundry/h5_io.py
deleted file mode 100644
index c9309130d..000000000
--- a/ScopeFoundry/h5_io.py
+++ /dev/null
@@ -1,211 +0,0 @@
-import h5py
-import time
-
-"""
-recommended HDF5 file format for ScopeFoundry
-* = group
-- = attr
-D = data_set
-
-* /
- - scope_foundry_version = 100
- - emd_version = 102
- * gui
- - log_quant_1
- - log_quant_1_unit
- - ...
- * hardware
- * hardware_component_1
- - ScopeFoundry_Type = Hardware
- - name = hardware_component_1
- - log_quant_1
- - log_quant_1_unit
- - ...
- * units
- - log_quant_1 = '[n_m]'
- * ...
- * measurement_1
- - ScopeFoundry_Type = Measurement
- - name = measurement_1
- - log_quant_1
- - ...
- * units
- - log_quant_1 = '[n_m]'
- * image_like_data_set_1
- - emd_group_type = 1
- D data
- D dim0
- - name = 'x'
- - unit = '[n_m]'
- D ...
- D dimN
- D simple_data_set_2
- D ...
-
-other thoughts:
- store git revision of code
- store git revision of ScopeFoundry
-
-"""
-
-def h5_base_file(gui, fname):
- h5_file = h5py.File(fname)
- root = h5_file['/']
- root.attrs["ScopeFoundry_version"] = 100
- t0 = time.time()
- root.attrs['time_id'] = t0
-
- h5_save_gui_lq(gui, root)
- h5_save_hardware_lq(gui, root)
- return h5_file
-
-def h5_save_gui_lq(gui, h5group):
- h5_gui_group = h5group.create_group('gui/')
- h5_gui_group.attrs['ScopeFoundry_type'] = "Gui"
- settings_group = h5_gui_group.create_group('settings')
- h5_save_lqs_to_attrs(gui.logged_quantities, settings_group)
-
-
-def h5_save_hardware_lq(gui, h5group):
- h5_hardware_group = h5group.create_group('hardware/')
- h5_hardware_group.attrs['ScopeFoundry_type'] = "HardwareList"
- for hc_name, hc in gui.hardware_components.items():
- h5_hc_group = h5_hardware_group.create_group(hc_name)
- h5_hc_group.attrs['name'] = hc.name
- h5_hc_group.attrs['ScopeFoundry_type'] = "Hardware"
- h5_hc_settings_group = h5_hc_group.create_group("settings")
- h5_save_lqs_to_attrs(hc.logged_quantities, h5_hc_settings_group)
- return h5_hardware_group
-
-def h5_save_lqs_to_attrs(logged_quantities, h5group):
- """
- Take a dictionary of logged_quantities
- and create attributes inside h5group
-
- :param logged_quantities:
- :param h5group:
- :return: None
- """
- unit_group = h5group.create_group('units')
- # TODO decide if we should specify h5 attr data type based on LQ dtype
- for lqname, lq in logged_quantities.items():
- h5group.attrs[lqname] = lq.val
- if lq.unit:
- unit_group.attrs[lqname] = lq.unit
-
-
-def h5_create_measurement_group(measurement, h5group):
- h5_meas_group = h5group.create_group('measurement/' + measurement.name)
- h5_save_measurement_settings(measurement, h5_meas_group)
- return h5_meas_group
-
-def h5_save_measurement_settings(measurement, h5_meas_group):
- h5_meas_group.attrs['name'] = measurement.name
- h5_meas_group.attrs['ScopeFoundry_type'] = "Measurement"
- settings_group = h5_meas_group.create_group("settings")
- h5_save_lqs_to_attrs(measurement.logged_quantities, settings_group)
-
-
-def h5_create_emd_dataset(name, h5parent, shape=None, data = None, maxshape = None,
- dim_arrays = None, dim_names= None, dim_units = None, **kwargs):
- """
- create an EMD dataset v0.2 inside h5parent
- returns an h5 group emd_grp
-
- to access N-dim dataset: emd_grp['data']
- to access a specific dimension array: emd_grp['dim1']
-
- HDF5 Hierarchy:
- ---------------
- * h5parent
- * name [emd_grp] (<--returned)
- - emd_group_type = 1
- D data [shape = shape]
- D dim1 [shape = shape[0]]
- - name
- - units
- ...
- D dimN [shape = shape[-1]]
-
- Parameters
- ----------
-
- h5parent : parent HDF5 group
-
- shape : Dataset shape of N dimensions. Required if "data" isn't provided.
-
- data : Provide data to initialize the dataset. If used, you can omit
- shape and dtype arguments.
-
- Keyword Args:
-
- dtype : Numpy dtype or string. If omitted, dtype('f') will be used.
- Required if "data" isn't provided; otherwise, overrides data
- array's dtype.
-
- dim_arrays : optional, a list of N dimension arrays
-
- dim_names : optional, a list of N strings naming the dataset dimensions
-
- dim_units : optional, a list of N strings specifying units of dataset dimensions
-
- Other keyword arguments follow from h5py.File.create_dataset
-
- Returns
- -------
- emd_grp : h5 group containing dataset and dimension arrays, see hierarchy below
-
- """
- #set the emd version tag at root of h5 file
- h5parent.file['/'].attrs['version_major'] = 0
- h5parent.file['/'].attrs['version_minor'] = 2
-
- from matplotlib import pyplot
- pyplot.acorr
-
- # create the EMD data group
- emd_grp = h5parent.create_group(name)
- emd_grp.attrs['emd_group_type'] = 1
-
- if data is not None:
- shape = data.shape
-
- # data set where the N-dim data is stored
- data_dset = emd_grp.create_dataset("data", shape=shape, maxshape=maxshape, data=data, **kwargs)
-
- if dim_arrays is not None: assert len(dim_arrays) == len(shape)
- if dim_names is not None: assert len(dim_names) == len(shape)
- if dim_units is not None: assert len(dim_units) == len(shape)
- if maxshape is not None: assert len(maxshape) == len(shape)
-
- # Create the dimension array datasets
- for ii in range(len(shape)):
- if dim_arrays is not None:
- dim_array = dim_arrays[ii]
- dim_dtype = dim_array.dtype
- else:
- dim_array = None
- dim_dtype = float
- if dim_names is not None:
- dim_name = dim_names[ii]
- else:
- dim_name = "dim" + str(ii+1)
- if dim_units is not None:
- dim_unit = dim_units[ii]
- else:
- dim_unit = None
- if maxshape is not None:
- dim_maxshape = (maxshape[ii],)
- else:
- dim_maxshape = None
-
- # create dimension array dataset
- dim_dset = emd_grp.create_dataset("dim" + str(ii+1), shape=(shape[ii],),
- dtype=dim_dtype, data=dim_array,
- maxshape=dim_maxshape)
- dim_dset.attrs['name'] = dim_name
- if dim_unit is not None:
- dim_dset.attrs['unit'] = dim_unit
-
- return emd_grp
-
diff --git a/ScopeFoundry/hardware.py b/ScopeFoundry/hardware.py
deleted file mode 100644
index bd8b1da38..000000000
--- a/ScopeFoundry/hardware.py
+++ /dev/null
@@ -1,177 +0,0 @@
-from PySide2 import QtCore, QtGui
-from .logged_quantity import LQCollection#, LoggedQuantity
-from collections import OrderedDict
-import pyqtgraph as pg
-
-class HardwareComponent(QtCore.QObject):
-
- def add_logged_quantity(self, name, **kwargs):
- #lq = LoggedQuantity(name=name, **kwargs)
- #self.logged_quantities[name] = lq
- #return lq
- return self.settings.New(name, **kwargs)
-
- def add_operation(self, name, op_func):
- """type name: str
- type op_func: QtCore.Slot
- """
- self.operations[name] = op_func
-
- def __init__(self, app, debug=False):
- QtCore.QObject.__init__(self)
-
- self.app = app
-
- #self.logged_quantities = OrderedDict()
- self.settings = LQCollection()
- self.operations = OrderedDict()
-
- self.connected = self.add_logged_quantity("connected", dtype=bool)
- self.connected.updated_value.connect(self.enable_connection)
-
- self.debug_mode = self.add_logged_quantity("debug_mode", dtype=bool, initial=debug)
-
- self.setup()
-
- try:
- self._add_control_widgets_to_hardware_tab()
- except Exception as err:
- print("HardwareComponent: could not add to hardware tab", self.name, err)
- try:
- self._add_control_widgets_to_hardware_tree()
- except Exception as err:
- print("HardwareComponent: could not add to hardware tree", self.name, err)
-
- self.has_been_connected_once = False
-
- self.is_connected = False
-
- def setup(self):
- """
- Runs during __init__, before the hardware connection is established
- Should generate desired LoggedQuantities, operations
- """
- raise NotImplementedError()
-
- def _add_control_widgets_to_hardware_tab(self):
- cwidget = self.app.ui.hardware_tab_scrollArea_content_widget
-
- self.controls_groupBox = QtGui.QGroupBox(self.name)
- self.controls_formLayout = QtGui.QFormLayout()
- self.controls_groupBox.setLayout(self.controls_formLayout)
-
- cwidget.layout().addWidget(self.controls_groupBox)
-
- #self.connect_hardware_checkBox = QtGui.QCheckBox("Connect to Hardware")
- #self.controls_formLayout.addRow("Connect", self.connect_hardware_checkBox)
- #self.connect_hardware_checkBox.stateChanged.connect(self.enable_connection)
-
-
- self.control_widgets = OrderedDict()
- for lqname, lq in self.settings.as_dict().items():
- #: :type lq: LoggedQuantity
- if lq.choices is not None:
- widget = QtGui.QComboBox()
- elif lq.dtype in [int, float]:
- if lq.si:
- widget = pg.SpinBox()
- else:
- widget = QtGui.QDoubleSpinBox()
- elif lq.dtype in [bool]:
- widget = QtGui.QCheckBox()
- elif lq.dtype in [str]:
- widget = QtGui.QLineEdit()
- lq.connect_bidir_to_widget(widget)
-
- # Add to formlayout
- self.controls_formLayout.addRow(lqname, widget)
- self.control_widgets[lqname] = widget
-
-
- self.op_buttons = OrderedDict()
- for op_name, op_func in self.operations.items():
- op_button = QtGui.QPushButton(op_name)
- op_button.clicked.connect(op_func)
- self.controls_formLayout.addRow(op_name, op_button)
-
- self.read_from_hardware_button = QtGui.QPushButton("Read From Hardware")
- self.read_from_hardware_button.clicked.connect(self.read_from_hardware)
- self.controls_formLayout.addRow("Logged Quantities:", self.read_from_hardware_button)
-
- def _add_control_widgets_to_hardware_tree(self):
- tree = self.app.ui.hardware_treeWidget
- #tree = QTreeWidget()
- tree.setColumnCount(2)
- tree.setHeaderLabels(["Hardware", "Value"])
-
- self.tree_item = QtGui.QTreeWidgetItem(tree, [self.name, "="*10])
- tree.insertTopLevelItem(0, self.tree_item)
- self.tree_item.setFirstColumnSpanned(True)
-
- for lqname, lq in self.settings.as_dict().items():
- #: :type lq: LoggedQuantity
- if lq.choices is not None:
- widget = QtGui.QComboBox()
- elif lq.dtype in [int, float]:
- if lq.si:
- widget = pg.SpinBox()
- else:
- widget = QtGui.QDoubleSpinBox()
- elif lq.dtype in [bool]:
- widget = QtGui.QCheckBox()
- elif lq.dtype in [str]:
- widget = QtGui.QLineEdit()
- lq.connect_bidir_to_widget(widget)
-
- # Add to formlayout
- #self.controls_formLayout.addRow(lqname, widget)
- lq_tree_item = QtGui.QTreeWidgetItem(self.tree_item, [lqname, ""])
- self.tree_item.addChild(lq_tree_item)
- lq.hardware_tree_widget = widget
- tree.setItemWidget(lq_tree_item, 1, lq.hardware_tree_widget)
- #self.control_widgets[lqname] = widget
-
- self.op_buttons = OrderedDict()
- for op_name, op_func in self.operations.items():
- op_button = QtGui.QPushButton(op_name)
- op_button.clicked.connect(op_func)
- self.op_buttons[op_name] = op_button
- #self.controls_formLayout.addRow(op_name, op_button)
- op_tree_item = QtGui.QTreeWidgetItem(self.tree_item, [op_name, ""])
- tree.setItemWidget(op_tree_item, 1, op_button)
-
- self.tree_read_from_hardware_button = QtGui.QPushButton("Read From\nHardware")
- self.tree_read_from_hardware_button.clicked.connect(self.read_from_hardware)
- #self.controls_formLayout.addRow("Logged Quantities:", self.read_from_hardware_button)
- self.read_from_hardware_button_tree_item = QtGui.QTreeWidgetItem(self.tree_item, ["Logged Quantities:", ""])
- self.tree_item.addChild(self.read_from_hardware_button_tree_item)
- tree.setItemWidget(self.read_from_hardware_button_tree_item, 1, self.tree_read_from_hardware_button)
-
- @QtCore.Slot()
- def read_from_hardware(self):
- for name, lq in self.settings.as_dict().items():
- if self.debug_mode.val: print("read_from_hardware", name)
- lq.read_from_hardware()
-
-
- def connect(self):
- """
- Opens a connection to hardware
- and connects hardware to associated LoggedQuantities
- """
- raise NotImplementedError()
-
-
- def disconnect(self):
- """
- Disconnects the hardware and severs hardware--LoggedQuantities link
- """
-
- raise NotImplementedError()
-
- @QtCore.Slot(bool)
- def enable_connection(self, enable=True):
- if enable:
- self.connect()
- else:
- self.disconnect()
diff --git a/ScopeFoundry/helper_funcs.py b/ScopeFoundry/helper_funcs.py
deleted file mode 100644
index 870fff5eb..000000000
--- a/ScopeFoundry/helper_funcs.py
+++ /dev/null
@@ -1,64 +0,0 @@
-from PySide2 import QtCore, QtGui, QtUiTools
-from collections import OrderedDict
-import os
-
-class OrderedAttrDict(object):
-
- def __init__(self):
- self._odict = OrderedDict()
-
- def add(self, name, obj):
- self._odict[name] = obj
- self.__dict__[name] = obj
- return obj
-
- def keys(self):
- return self._odict.keys()
- def values(self):
- return self._dict.values()
- def items(self):
- return self._odict.items()
-
- def __len__(self):
- return len(self._odict)
-
-def sibling_path(a, b):
- return os.path.join(os.path.dirname(a), b)
-
-
-def load_qt_ui_file(ui_filename):
- ui_loader = QtUiTools.QUiLoader()
- ui_file = QtCore.QFile(ui_filename)
- ui_file.open(QtCore.QFile.ReadOnly)
- ui = ui_loader.load(ui_file)
- ui_file.close()
- return ui
-
-def confirm_on_close(widget, title="Close ScopeFoundry?", message="Do you wish to shut down ScopeFoundry?"):
- widget.closeEventEater = CloseEventEater(title, message)
- widget.installEventFilter(widget.closeEventEater)
-
-class CloseEventEater(QtCore.QObject):
-
- def __init__(self, title="Close ScopeFoundry?", message="Do you wish to shut down ScopeFoundry?"):
- QtCore.QObject.__init__(self)
- self.title = title
- self.message = message
-
- def eventFilter(self, obj, event):
- if event.type() == QtCore.QEvent.Close:
- # eat close event
- print("close")
- reply = QtGui.QMessageBox.question(None,
- self.title,
- self.message,
- QtGui.QMessageBox.Yes, QtGui.QMessageBox.No)
- if reply == QtGui.QMessageBox.Yes:
- QtGui.QApplication.quit()
- event.accept()
- else:
- event.ignore()
- return True
- else:
- # standard event processing
- return QtCore.QObject.eventFilter(self,obj, event)
diff --git a/ScopeFoundry/logged_quantity.py b/ScopeFoundry/logged_quantity.py
deleted file mode 100644
index ccfd1c53b..000000000
--- a/ScopeFoundry/logged_quantity.py
+++ /dev/null
@@ -1,525 +0,0 @@
-from PySide2 import QtCore, QtWidgets, QtGui
-import pyqtgraph
-import numpy as np
-from collections import OrderedDict
-import json
-
-class LoggedQuantity(QtCore.QObject):
-
- updated_value = QtCore.Signal((float,),(int,),(bool,), (), (str,),) # signal sent when value has been updated
- updated_text_value = QtCore.Signal(str)
- updated_choice_index_value = QtCore.Signal(int) # emits the index of the value in self.choices
-
- updated_min_max = QtCore.Signal((float,float),(int,int), (),) # signal sent when min max range updated
- updated_readonly = QtCore.Signal((bool,), (),)
-
- def __init__(self, name, dtype=float,
- hardware_read_func=None, hardware_set_func=None,
- initial=0, fmt="%g", si=True,
- ro = False,
- unit = None,
- spinbox_decimals = 2,
- spinbox_step=0.1,
- vmin=-1e12, vmax=+1e12, choices=None):
- QtCore.QObject.__init__(self)
-
- self.name = name
- self.dtype = dtype
- self.val = dtype(initial)
- self.hardware_read_func = hardware_read_func
- self.hardware_set_func = hardware_set_func
- self.fmt = fmt # string formatting string. This is ignored if dtype==str
- self.si = si # will use pyqtgraph SI Spinbox if True
- self.unit = unit
- self.vmin = vmin
- self.vmax = vmax
- self.choices = choices # must be tuple [ ('name', val) ... ]
- self.ro = ro # Read-Only?
-
- if self.dtype == int:
- self.spinbox_decimals = 0
- else:
- self.spinbox_decimals = spinbox_decimals
- self.reread_from_hardware_after_write = False
-
- if self.dtype == int:
- self.spinbox_step = 1
- else:
- self.spinbox_step = spinbox_step
-
- self.oldval = None
-
- self._in_reread_loop = False # flag to prevent reread from hardware loops
-
- self.widget_list = []
-
- def coerce_to_type(self, x):
- return self.dtype(x)
-
- def read_from_hardware(self, send_signal=True):
- if self.hardware_read_func:
- self.oldval = self.val
- val = self.hardware_read_func()
- #print "read_from_hardware", self.name, val
- self.val = self.coerce_to_type(val)
- if send_signal:
- self.send_display_updates()
- return self.val
-
- @QtCore.Slot(str)
- @QtCore.Slot(float)
- @QtCore.Slot(int)
- @QtCore.Slot(bool)
- @QtCore.Slot()
- def update_value(self, new_val=None, update_hardware=True, send_signal=True, reread_hardware=None):
- #print "LQ update_value", self.name, self.val, "-->", new_val
- if new_val is None:
- #print "update_value {} new_val is None. From Sender {}".format(self.name, self.sender())
- new_val = self.sender().text()
-
- self.oldval = self.coerce_to_type(self.val)
- new_val = self.coerce_to_type(new_val)
-
- #print "LQ update_value1", self.name
-
- if self.same_values(self.oldval, new_val):
- #print "same_value so returning", self.oldval, new_val
- self._in_reread_loop = False #once value has settled in the event loop, re-enable reading from hardware
- return
-
- self.val = new_val
-
- #print "LQ update_value2", self.name
-
- if reread_hardware is None:
- reread_hardware = self.reread_from_hardware_after_write
-
- #print "called update_value", self.name, new_val, reread_hardware
- if update_hardware and self.hardware_set_func and not self._in_reread_loop:
- self.hardware_set_func(self.val)
- if reread_hardware:
- # re-reading from hardware can set off a loop of setting
- # and re-reading from hardware if hardware readout is not
- # exactly the requested value. temporarily disable rereading
- # from hardware until value in LoggedQuantity has settled
- self._in_reread_loop = True
- self.read_from_hardware(send_signal=False) # changed send_signal to false (ESB 2015-08-05)
- if send_signal:
- self.send_display_updates()
-
- def send_display_updates(self, force=False):
- #print "send_display_updates: {} force={}".format(self.name, force)
- if (not self.same_values(self.oldval, self.val)) or (force):
-
- #print "send display updates", self.name, self.val, self.oldval
- str_val = self.string_value()
- self.updated_value[str].emit(str_val)
- self.updated_text_value.emit(str_val)
-
- self.updated_value[float].emit(self.val)
- if self.dtype != float:
- self.updated_value[int].emit(self.val)
- self.updated_value[bool].emit(self.val)
- self.updated_value[()].emit()
-
- if self.choices is not None:
- choice_vals = [c[1] for c in self.choices]
- if self.val in choice_vals:
- self.updated_choice_index_value.emit(choice_vals.index(self.val) )
- self.oldval = self.val
- else:
- pass
- #print "\t no updates sent", (self.oldval != self.val) , (force), self.oldval, self.val
-
- def same_values(self, v1, v2):
- return v1 == v2
-
- def string_value(self):
- if self.dtype == str:
- return self.val
- else:
- return self.fmt % self.val
-
- def ini_string_value(self):
- return str(self.val)
-
-
- def update_choice_index_value(self, new_choice_index, **kwargs):
- self.update_value(self.choices[new_choice_index][1], **kwargs)
-
-
- def connect_bidir_to_widget(self, widget):
- print(type(widget))
- if type(widget) == QtWidgets.QDoubleSpinBox:
- #self.updated_value[float].connect(widget.setValue )
- #widget.valueChanged[float].connect(self.update_value)
- widget.setKeyboardTracking(False)
- if self.vmin is not None:
- widget.setMinimum(self.vmin)
- if self.vmax is not None:
- widget.setMaximum(self.vmax)
- if self.unit is not None:
- widget.setSuffix(" "+self.unit)
- widget.setDecimals(self.spinbox_decimals)
- widget.setSingleStep(self.spinbox_step)
- widget.setValue(self.val)
- #events
- self.updated_value[float].connect(widget.setValue)
- #if not self.ro:
- widget.valueChanged[float].connect(self.update_value)
-
- elif type(widget) == QtWidgets.QSpinBox:
- #self.updated_value[float].connect(widget.setValue )
- #widget.valueChanged[float].connect(self.update_value)
- widget.setKeyboardTracking(False)
- #if self.vmin is not None:
- # widget.setMinimum(self.vmin)
- #if self.vmax is not None:
- # widget.setMaximum(self.vmax)
- #if self.unit is not None:
- # widget.setSuffix(" "+self.unit)
- #widget.setDecimals(self.spinbox_decimals)
- widget.setSingleStep(self.spinbox_step)
- widget.setValue(self.val)
- #events
- self.updated_value[int].connect(widget.setValue)
- #if not self.ro:
- widget.valueChanged[int].connect(self.update_value)
-
- elif type(widget) == QtWidgets.QCheckBox:
- print(self.name)
- #self.updated_value[bool].connect(widget.checkStateSet)
- #widget.stateChanged[int].connect(self.update_value)
- # Ed's version
- print("connecting checkbox widget")
- self.updated_value[bool].connect(widget.setChecked)
- widget.toggled[bool].connect(self.update_value)
- if self.ro:
- #widget.setReadOnly(True)
- widget.setEnabled(False)
- elif type(widget) == QtWidgets.QLineEdit:
- self.updated_text_value[str].connect(widget.setText)
- if self.ro:
- widget.setReadOnly(True) # FIXME
- def on_edit_finished():
- print("on_edit_finished", self.name)
- self.update_value(widget.text())
- widget.editingFinished.connect(on_edit_finished)
- elif type(widget) == QtWidgets.QPlainTextEdit:
- # FIXME doesn't quite work right: a signal character resets cursor position
- self.updated_text_value[str].connect(widget.setPlainText)
- # TODO Read only
- def set_from_plaintext():
- self.update_value(widget.toPlainText())
- widget.textChanged.connect(set_from_plaintext)
-
- elif type(widget) == QtWidgets.QComboBox:
- # need to have a choice list to connect to a QComboBox
- assert self.choices is not None
- widget.clear() # removes all old choices
- for choice_name, choice_value in self.choices:
- widget.addItem(choice_name, choice_value)
- self.updated_choice_index_value[int].connect(widget.setCurrentIndex)
- widget.currentIndexChanged.connect(self.update_choice_index_value)
-
- elif type(widget) == pyqtgraph.widgets.SpinBox.SpinBox:
- #widget.setFocusPolicy(QtCore.Qt.StrongFocus)
- suffix = self.unit
- if self.unit is None:
- suffix = ""
- if self.dtype == int:
- integer = True
- minStep=1
- step=1
- else:
- integer = False
- minStep=.1
- step=.1
- widget.setOpts(
- suffix=suffix,
- siPrefix=True,
- dec=True,
- step=step,
- minStep=minStep,
- bounds=[self.vmin, self.vmax],
- int=integer)
- if self.ro:
- widget.setEnabled(False)
- widget.setButtonSymbols(QtWidgets.QAbstractSpinBox.NoButtons)
- widget.setReadOnly(True)
- widget.setDecimals(self.spinbox_decimals)
- widget.setSingleStep(self.spinbox_step)
- self.updated_value[float].connect(widget.setValue)
- #if not self.ro:
- #widget.valueChanged[float].connect(self.update_value)
- widget.valueChanged.connect(self.update_value)
- elif type(widget) == QtWidgets.QLabel:
- self.updated_text_value.connect(widget.setText)
- else:
- raise ValueError("Unknown widget type")
-
- self.send_display_updates(force=True)
- #self.widget = widget
- self.widget_list.append(widget)
- self.change_readonly(self.ro)
-
- def change_choice_list(self, choices):
- #widget = self.widget
- self.choices = choices
-
- for widget in self.widget_list:
- if type(widget) == QtWidgets.QComboBox:
- # need to have a choice list to connect to a QComboBox
- assert self.choices is not None
- widget.clear() # removes all old choices
- for choice_name, choice_value in self.choices:
- widget.addItem(choice_name, choice_value)
- else:
- raise RuntimeError("Invalid widget type.")
-
- def change_min_max(self, vmin=-1e12, vmax=+1e12):
- self.vmin = vmin
- self.vmax = vmax
- for widget in self.widget_list: # may not work for certain widget types
- widget.setRange(vmin, vmax)
- self.updated_min_max.emit(vmin,vmax)
-
- def change_readonly(self, ro=True):
- self.ro = ro
- for widget in self.widget_list:
- if type(widget) in [QtWidgets.QDoubleSpinBox, pyqtgraph.widgets.SpinBox.SpinBox]:
- widget.setReadOnly(self.ro)
- #elif
- self.updated_readonly.emit(self.ro)
-
-
-
-
-class FileLQ(LoggedQuantity):
-
- def __init__(self, name, default_dir=None, **kwargs):
- kwargs.pop('dtype', None)
-
- LoggedQuantity.__init__(self, name, dtype=str, **kwargs)
-
- self.default_dir = default_dir
-
- def connect_to_browse_widgets(self, lineEdit, pushButton):
- assert type(lineEdit) == QtWidgets.QLineEdit
- self.connect_bidir_to_widget(lineEdit)
-
- assert type(pushButton) == QtWidgets.QPushButton
- pushButton.clicked.connect(self.file_browser)
-
- def file_browser(self):
- # TODO add default directory, etc
- fname, _ = QtWidgets.QFileDialog.getOpenFileName(None)
- print(repr(fname))
- if fname:
- self.update_value(fname)
-
-class ArrayLQ(LoggedQuantity):
- updated_shape = QtCore.Signal(str)
-
- def __init__(self, name, dtype=float,
- hardware_read_func=None, hardware_set_func=None,
- initial=[], fmt="%g", si=True,
- ro = False,
- unit = None,
- vmin=-1e12, vmax=+1e12, choices=None):
- QtCore.QObject.__init__(self)
-
- self.name = name
- self.dtype = dtype
- self.val = np.array(initial, dtype=dtype)
- self.hardware_read_func = hardware_read_func
- self.hardware_set_func = hardware_set_func
- self.fmt = fmt # % string formatting string. This is ignored if dtype==str
- self.unit = unit
- self.vmin = vmin
- self.vmax = vmax
- self.ro = ro # Read-Only
-
- if self.dtype == int:
- self.spinbox_decimals = 0
- else:
- self.spinbox_decimals = 2
- self.reread_from_hardware_after_write = False
-
- self.oldval = None
-
- self._in_reread_loop = False # flag to prevent reread from hardware loops
-
- self.widget_list = []
-
- def same_values(self, v1, v2):
- if v1.shape == v2.shape:
- return np.all(v1 == v2)
- print("same_values", v2-v1, np.all(v1 == v2))
- else:
- return False
-
-
-
-
- def change_shape(self, newshape):
- #TODO
- pass
-
- def string_value (self):
- return json.dumps(self.val.tolist())
-
- def ini_string_value(self):
- return json.dumps(self.val.tolist())
-
- def coerce_to_type(self, x):
- #print type(x)
- if type(x) in (unicode, str):
- x = json.loads(x)
- #print repr(x)
- return np.array(x, dtype=self.dtype)
-
- def send_display_updates(self, force=False):
- print(self.name, 'send_display_updates')
- #print "send_display_updates: {} force={}".format(self.name, force)
- if force or np.any(self.oldval != self.val):
-
- #print "send display updates", self.name, self.val, self.oldval
- str_val = self.string_value()
- self.updated_value[str].emit(str_val)
- self.updated_text_value.emit(str_val)
-
- #self.updated_value[float].emit(self.val)
- #if self.dtype != float:
- # self.updated_value[int].emit(self.val)
- #self.updated_value[bool].emit(self.val)
- self.updated_value[()].emit()
-
- self.oldval = self.val
- else:
- pass
- #print "\t no updates sent", (self.oldval != self.val) , (force), self.oldval, self.val
-
-
-class LQRange(QtCore.QObject):
-    """
-    LQRange is a collection of logged quantities that together describe
-    the inputs to a numpy.linspace array.
-    Four LQs are defined -- min, max, num, and step -- and are connected
-    by signals/slots that keep the quantities in sync.
-    LQRange.array is the linspace array itself, and is kept up to date
-    with changes to the four LQs. (A worked example follows the class.)
-    """
- updated_range = QtCore.Signal((),)# (float,),(int,),(bool,), (), (str,),) # signal sent when value has been updated
-
- def __init__(self, min_lq,max_lq,step_lq, num_lq):
- QtCore.QObject.__init__(self)
-
- self.min = min_lq
- self.max = max_lq
- self.num = num_lq
- self.step = step_lq
-
- assert self.num.dtype == int
-
- self.array = np.linspace(self.min.val, self.max.val, self.num.val)
- step = self.array[1]-self.array[0]
- self.step.update_value(step)
-
- self.num.updated_value[int].connect(self.recalc_with_new_num)
- self.min.updated_value.connect(self.recalc_with_new_min_max)
- self.max.updated_value.connect(self.recalc_with_new_min_max)
- self.step.updated_value.connect(self.recalc_with_new_step)
-
- def recalc_with_new_num(self, new_num):
- print("recalc_with_new_num", new_num)
- self.array = np.linspace(self.min.val, self.max.val, int(new_num))
- if len(self.array) > 1:
- new_step = self.array[1]-self.array[0]
- print(" new_step inside new_num", new_step)
- self.step.update_value(new_step)#, send_signal=True, update_hardware=False)
- self.step.send_display_updates(force=True)
- self.updated_range.emit()
-
- def recalc_with_new_min_max(self, x):
- self.array = np.linspace(self.min.val, self.max.val, self.num.val)
- step = self.array[1]-self.array[0]
- self.step.update_value(step)#, send_signal=True, update_hardware=False)
- self.updated_range.emit()
-
- def recalc_with_new_step(self,new_step):
- print("-->recalc_with_new_step")
- if len(self.array) > 1:
- old_step = self.array[1]-self.array[0]
- else:
- old_step = np.nan
- diff = np.abs(old_step - new_step)
- print("step diff", diff)
- if diff < 10**(-self.step.spinbox_decimals):
- print("steps close enough, no more recalc")
- return
- else:
- new_num = int((((self.max.val - self.min.val)/new_step)+1))
- self.array = np.linspace(self.min.val, self.max.val, new_num)
- new_step1 = self.array[1]-self.array[0]
- print("recalc_with_new_step", new_step, new_num, new_step1)
- #self.step.val = new_step1
- #self.num.val = new_num
- #self.step.update_value(new_step1, send_signal=False)
- #if np.abs(self.step.val - new_step1)/self.step.val > 1e-2:
- self.step.val = new_step1
- self.num.update_value(new_num)
- #self.num.send_display_updates(force=True)
- #self.step.update_value(new_step1)
-
- #print "sending step display Updates"
- #self.step.send_display_updates(force=True)
- self.updated_range.emit()
-
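-# Worked example (a sketch with illustrative values, not from the original source):
-# with min=0.0, max=10.0, num=11, the array is linspace(0, 10, 11) and step is
-# recalculated as array[1] - array[0] = 1.0. Writing step=2.0 triggers
-# recalc_with_new_step, which computes num = int((10.0 - 0.0)/2.0 + 1) = 6 and
-# rebuilds the array as linspace(0, 10, 6), whose spacing is again exactly 2.0.
-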
-class LQCollection(object):
-
- def __init__(self):
- self._logged_quantities = OrderedDict()
-
- def New(self, name, dtype=float, **kwargs):
- is_array = kwargs.pop('array', False)
- print(name, 'is_array', is_array)
- if is_array:
- lq = ArrayLQ(name=name, dtype=dtype, **kwargs)
- else:
- if dtype == 'file':
- lq = FileLQ(name=name, **kwargs)
- else:
- lq = LoggedQuantity(name=name, dtype=dtype, **kwargs)
- self._logged_quantities[name] = lq
- self.__dict__[name] = lq
- return lq
-
- def as_list(self):
- return self._logged_quantities.values()
-
- def as_dict(self):
- return self._logged_quantities
-
- """def __getattr__(self, name):
- return self.logged_quantities[name]
-
- def __getitem__(self, key):
- return self.logged_quantities[key]
-
- def __getattribute__(self,name):
- if name in self.logged_quantities.keys():
- return self.logged_quantities[name]
- else:
- return object.__getattribute__(self, name)
- """
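-
-# Usage sketch (hypothetical names and values, following New()'s dispatch rules):
-#
-#     settings = LQCollection()
-#     exposure = settings.New('exposure', dtype=float)      # -> LoggedQuantity
-#     data_file = settings.New('data_file', dtype='file')   # -> FileLQ
-#     dark_ref = settings.New('dark_ref', array=True)       # -> ArrayLQ
-#     settings.exposure is exposure  # True: New() also binds the LQ as an attribute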
-
-
-def print_signals_and_slots(obj):
-    for i in range(obj.metaObject().methodCount()):
- m = obj.metaObject().method(i)
- if m.methodType() == QtCore.QMetaMethod.MethodType.Signal:
- print("SIGNAL: sig=", m.signature(), "hooked to nslots=",obj.receivers(QtCore.SIGNAL(m.signature())))
- elif m.methodType() == QtCore.QMetaMethod.MethodType.Slot:
- print("SLOT: sig=", m.signature())
diff --git a/ScopeFoundry/measurement.py b/ScopeFoundry/measurement.py
deleted file mode 100644
index fc5f66638..000000000
--- a/ScopeFoundry/measurement.py
+++ /dev/null
@@ -1,264 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-Created on Tue Apr 1 09:25:48 2014
-
-@author: esbarnard
-"""
-
-from PySide2 import QtCore, QtWidgets  # widget classes live in QtWidgets under Qt5/PySide2
-import threading
-import time
-from ScopeFoundry.logged_quantity import LQCollection
-from ScopeFoundry.helper_funcs import load_qt_ui_file
-from collections import OrderedDict
-import pyqtgraph as pg
-
-class Measurement(QtCore.QObject):
-
- measurement_sucessfully_completed = QtCore.Signal(()) # signal sent when full measurement is complete
- measurement_interrupted = QtCore.Signal(()) # signal sent when measurement is complete due to an interruption
- #measurement_state_changed = QtCore.Signal(bool) # signal sent when measurement started or stopped
-
- def __init__(self, app):
-        """:type app: MicroscopeApp
- """
-
- QtCore.QObject.__init__(self)
-
- self.app = app
-
- self.display_update_period = 0.1 # seconds
- self.display_update_timer = QtCore.QTimer(self)
- self.display_update_timer.timeout.connect(self.on_display_update_timer)
- self.acq_thread = None
-
- self.interrupt_measurement_called = False
-
- #self.logged_quantities = OrderedDict()
- self.settings = LQCollection()
- self.operations = OrderedDict()
-
-
-        self.activation = self.settings.New('activation', dtype=bool, ro=False)  # does the user want the thread to be running?
- self.running = self.settings.New('running', dtype=bool, ro=True) # is the thread actually running?
- self.progress = self.settings.New('progress', dtype=float, unit="%", si=False, ro=True)
-
- self.activation.updated_value.connect(self.start_stop)
-
- self.add_operation("start", self.start)
- self.add_operation("interrupt", self.interrupt)
- self.add_operation("setup", self.setup)
- self.add_operation("setup_figure", self.setup_figure)
- self.add_operation("update_display", self.update_display)
- self.add_operation('show_ui', self.show_ui)
-
- if hasattr(self, 'ui_filename'):
- self.load_ui()
-
- self.setup()
-
- try:
- self._add_control_widgets_to_measurements_tab()
- except Exception as err:
- print("MeasurementComponent: could not add to measurement tab", self.name, err)
- try:
- self._add_control_widgets_to_measurements_tree()
- except Exception as err:
- print("MeasurementComponent: could not add to measurement tree", self.name, err)
-
-
- def setup(self):
- """Override this to set up logged quantities and gui connections
- Runs during __init__, before the hardware connection is established
- Should generate desired LoggedQuantities"""
- pass
- #raise NotImplementedError()
-
- def setup_figure(self):
- print("Empty setup_figure called")
- pass
-
- @QtCore.Slot()
- def start(self):
- print("measurement", self.name, "start")
- self.interrupt_measurement_called = False
- if (self.acq_thread is not None) and self.is_measuring():
- raise RuntimeError("Cannot start a new measurement while still measuring")
- self.acq_thread = threading.Thread(target=self._thread_run)
- #self.measurement_state_changed.emit(True)
- self.running.update_value(True)
- self.acq_thread.start()
- self.t_start = time.time()
- self.display_update_timer.start(self.display_update_period*1000)
-
- def _run(self):
- raise NotImplementedError("Measurement {}._run() not defined".format(self.name))
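-
-    # A minimal subclass might look like this (sketch; names are hypothetical):
-    #
-    #     class CounterMeasurement(Measurement):
-    #         name = "counter"
-    #         def _run(self):
-    #             for i in range(100):
-    #                 if self.interrupt_measurement_called:
-    #                     break
-    #                 self.set_progress(i)
-    #                 time.sleep(0.05)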
-
- def _thread_run(self):
- #self.progress_updated.emit(50) # set progress bars to default run position at 50%
- self.set_progress(50)
- try:
- self._run()
-
- #except Exception as err:
- # self.interrupt_measurement_called = True
- # raise err
- finally:
- self.running.update_value(False)
- self.set_progress(0) # set progress bars back to zero
- #self.measurement_state_changed.emit(False)
- if self.interrupt_measurement_called:
- self.measurement_interrupted.emit()
- else:
- self.measurement_sucessfully_completed.emit()
-
- def set_progress(self, pct):
- self.progress.update_value(pct)
-
- @QtCore.Slot()
- def interrupt(self):
- print("measurement", self.name, "interrupt")
- self.interrupt_measurement_called = True
- self.activation.update_value(False)
- #Make sure display is up to date
- #self.on_display_update_timer()
-
-
- def start_stop(self, start):
- print(self.name, "start_stop", start)
- if start:
- self.start()
- else:
- self.interrupt()
-
-
- def is_measuring(self):
- if self.acq_thread is None:
- self.running.update_value(False)
- return False
- else:
- resp = self.acq_thread.is_alive()
- self.running.update_value(resp)
- return resp
-
- def update_display(self):
- "Override this function to provide figure updates when the display timer runs"
- pass
-
-
- @QtCore.Slot()
- def on_display_update_timer(self):
- try:
- self.update_display()
- except Exception as err:
-            print(self.name, "Failed to update figure:", err)
- finally:
- if not self.is_measuring():
- self.display_update_timer.stop()
-
- def add_logged_quantity(self, name, **kwargs):
- lq = self.settings.New(name=name, **kwargs)
- return lq
-
- def add_operation(self, name, op_func):
-        """:type name: str
-        :type op_func: QtCore.Slot
-        """
- self.operations[name] = op_func
-
- def load_ui(self, ui_fname=None):
- # TODO destroy and rebuild UI if it already exists
- if ui_fname is not None:
- self.ui_filename = ui_fname
- # Load Qt UI from .ui file
- self.ui = load_qt_ui_file(self.ui_filename)
- self.show_ui()
-
- def show_ui(self):
- self.ui.show()
- self.ui.activateWindow()
- #self.ui.raise() #just to be sure it's on top
-
-    def _add_control_widgets_to_measurements_tab(self):
-        cwidget = self.app.ui.measurements_tab_scrollArea_content_widget
-
-        self.controls_groupBox = QtWidgets.QGroupBox(self.name)
-        self.controls_formLayout = QtWidgets.QFormLayout()
-        self.controls_groupBox.setLayout(self.controls_formLayout)
-
-        cwidget.layout().addWidget(self.controls_groupBox)
-
-        self.control_widgets = OrderedDict()
-        for lqname, lq in self.settings.as_dict().items():  # settings replaced the old logged_quantities dict
-            #: :type lq: LoggedQuantity
-            if lq.choices is not None:
-                widget = QtWidgets.QComboBox()
-            elif lq.dtype in [int, float]:
-                if lq.si:
-                    widget = pg.SpinBox()
-                else:
-                    widget = QtWidgets.QDoubleSpinBox()
-            elif lq.dtype in [bool]:
-                widget = QtWidgets.QCheckBox()
-            elif lq.dtype in [str]:
-                widget = QtWidgets.QLineEdit()
-            lq.connect_bidir_to_widget(widget)
-
-            # Add to formlayout
-            self.controls_formLayout.addRow(lqname, widget)
-            self.control_widgets[lqname] = widget
-
-        self.op_buttons = OrderedDict()
-        for op_name, op_func in self.operations.items():
-            op_button = QtWidgets.QPushButton(op_name)
-            op_button.clicked.connect(op_func)
-            self.controls_formLayout.addRow(op_name, op_button)
-
-
-    def _add_control_widgets_to_measurements_tree(self, tree=None):
-        if tree is None:
-            tree = self.app.ui.measurements_treeWidget
-
-        tree.setColumnCount(2)
-        tree.setHeaderLabels(["Measurements", "Value"])
-
-        self.tree_item = QtWidgets.QTreeWidgetItem(tree, [self.name, ""])
-        tree.insertTopLevelItem(0, self.tree_item)
-        #self.tree_item.setFirstColumnSpanned(True)
-        self.tree_progressBar = QtWidgets.QProgressBar()
-        tree.setItemWidget(self.tree_item, 1, self.tree_progressBar)
-        self.progress.updated_value.connect(self.tree_progressBar.setValue)
-
-        # Add logged quantities to tree
-        for lqname, lq in self.settings.as_dict().items():  # settings replaced the old logged_quantities dict
-            #: :type lq: LoggedQuantity
-            if lq.choices is not None:
-                widget = QtWidgets.QComboBox()
-            elif lq.dtype in [int, float]:
-                if lq.si:
-                    widget = pg.SpinBox()
-                else:
-                    widget = QtWidgets.QDoubleSpinBox()
-            elif lq.dtype in [bool]:
-                widget = QtWidgets.QCheckBox()
-            elif lq.dtype in [str]:
-                widget = QtWidgets.QLineEdit()
-            lq.connect_bidir_to_widget(widget)
-
-            lq_tree_item = QtWidgets.QTreeWidgetItem(self.tree_item, [lqname, ""])
-            self.tree_item.addChild(lq_tree_item)
-            lq.hardware_tree_widget = widget
-            tree.setItemWidget(lq_tree_item, 1, lq.hardware_tree_widget)
-            #self.control_widgets[lqname] = widget
-
-        # Add operation buttons to tree
-        self.op_buttons = OrderedDict()
-        for op_name, op_func in self.operations.items():
-            op_button = QtWidgets.QPushButton(op_name)
-            op_button.clicked.connect(op_func)
-            self.op_buttons[op_name] = op_button
-            #self.controls_formLayout.addRow(op_name, op_button)
-            op_tree_item = QtWidgets.QTreeWidgetItem(self.tree_item, [op_name, ""])
-            tree.setItemWidget(op_tree_item, 1, op_button)
diff --git a/ScopeFoundry/ndarray_interactive.py b/ScopeFoundry/ndarray_interactive.py
deleted file mode 100644
index 71e9f740b..000000000
--- a/ScopeFoundry/ndarray_interactive.py
+++ /dev/null
@@ -1,105 +0,0 @@
-from PySide import QtCore, QtGui
-from PySide.QtCore import Qt
-import numpy as np
-
-
-#https://www.mail-archive.com/pyqt@riverbankcomputing.com/msg17575.html
-# plus more
-class NumpyQTableModel(QtCore.QAbstractTableModel):
- def __init__(self, narray, col_names=None, row_names=None, fmt="%g", copy=True, parent=None):
- QtCore.QAbstractTableModel.__init__(self, parent)
- self.copy = copy
- if self.copy:
- self._array = narray.copy()
- else:
- self._array = narray
- self.col_names = col_names
- self.row_names = row_names
- self.fmt=fmt
-
- def rowCount(self, parent=None):
- return self._array.shape[0]
-
- def columnCount(self, parent=None):
- return self._array.shape[1]
-
- def data(self, index, role=Qt.DisplayRole):
- if index.isValid():
- if role == Qt.DisplayRole or role==Qt.EditRole:
- row = index.row()
- col = index.column()
- return self.fmt % self._array[row, col]
- return None
-
-    def setData(self, index, value, role=Qt.EditRole):
-        print(index, value, role)
-        jj, ii = index.row(), index.column()
-
-        print('setData', ii, jj)
-        #return QtCore.QAbstractTableModel.setData(self, *args, **kwargs)
-
-        try:
-            self._array[jj, ii] = value
-            self.dataChanged.emit(index, index)  # topLeft, bottomRight indexes of change
-            return True
-        except Exception as err:
-            print("setData err:", err)
-            return False
-
- def set_array(self, narray):
- #print "set_array"
- if self.copy:
- self._array = narray.copy()
- else:
- self._array = narray
- self.layoutChanged.emit()
-        self.dataChanged.emit(self.index(0, 0),
-                              self.index(self.rowCount() - 1, self.columnCount() - 1))  # dataChanged expects QModelIndex, not tuples
-
-
- def flags(self, *args, **kwargs):
- #return QtCore.QAbstractTableModel.flags(self, *args, **kwargs)
- return Qt.ItemIsEditable | Qt.ItemIsEnabled | Qt.ItemIsSelectable
-
- def headerData(self, section, orientation, role=Qt.DisplayRole):
- if role == Qt.DisplayRole and orientation == Qt.Horizontal and self.col_names:
- return self.col_names[section]
- if role == Qt.DisplayRole and orientation == Qt.Vertical and self.row_names:
- return self.row_names[section]
- return QtCore.QAbstractTableModel.headerData(self, section-1, orientation, role)
-
-class ArrayLQ_QTableModel(NumpyQTableModel):
- def __init__(self, lq, col_names=None, row_names=None, parent=None):
-        print(lq.val)
- NumpyQTableModel.__init__(self, lq.val, col_names=col_names, row_names=row_names, parent=parent)
- self.lq = lq
- self.lq.updated_value[()].connect(self.on_lq_updated_value)
- self.dataChanged.connect(self.on_dataChanged)
-
- def on_lq_updated_value(self):
- #print "ArrayLQ_QTableModel", self.lq.name, 'on_lq_updated_value'
- self.set_array(self.lq.val)
-
- def on_dataChanged(self,topLeft=None, bottomRight=None):
- #print "ArrayLQ_QTableModel", self.lq.name, 'on_dataChanged'
- self.lq.update_value(np.array(self._array))
- #self.lq.send_display_updates(force=True)
-
-
-
-if __name__ == '__main__':
- qtapp = QtGui.QApplication([])
-
- import numpy as np
-
- A = np.random.rand(10,5)
- B = np.random.rand(12,4)
-
- table_view = QtGui.QTableView()
- table_view_model = NumpyQTableModel(narray=A, col_names=['Peak', 'FWHM','center', 'asdf', '!__!'])
- table_view.setModel(table_view_model)
- table_view.show()
- table_view.raise_()
-
- table_view_model.set_array(B)
-
- qtapp.exec_()
diff --git a/_run.py b/_run.py
deleted file mode 100644
index d27e3e721..000000000
--- a/_run.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from viewer import DataViewer
-import sys
-
-if __name__ == '__main__':
- app = DataViewer(sys.argv)
-
- sys.exit(app.exec_())
diff --git a/control_panel.py b/control_panel.py
deleted file mode 100644
index 3cb91707c..000000000
--- a/control_panel.py
+++ /dev/null
@@ -1,168 +0,0 @@
-#!/Users/Ben/Code/anaconda2/envs/py3/bin/python
-
-import sys
-from PySide2 import QtCore, QtWidgets, QtGui
-
-
-class ControlPanel(QtWidgets.QWidget):
- def __init__(self):
- QtWidgets.QWidget.__init__(self)
-
- # Container widget
- scrollableWidget = QtWidgets.QWidget()
-        layout = QtWidgets.QVBoxLayout()  # no parent here: this layout is installed on scrollableWidget below
-
- ##### Make sub-widgets #####
- # For each, provide handles to connect to their widgets
-
- # File loading
- dataLoader = DataLoadingWidget()
- self.lineEdit_LoadFile = dataLoader.lineEdit_LoadFile
- self.pushButton_BrowseFiles = dataLoader.pushButton_BrowseFiles
-
- # Data cube size and shape
- sizeAndShapeEditor = DataCubeSizeAndShapeWidget()
- self.spinBox_Nx = sizeAndShapeEditor.spinBox_Nx
- self.spinBox_Ny = sizeAndShapeEditor.spinBox_Ny
- self.lineEdit_Binning = sizeAndShapeEditor.lineEdit_Binning
- self.pushButton_Binning = sizeAndShapeEditor.pushButton_Binning
- self.pushButton_SetCropWindow = sizeAndShapeEditor.pushButton_SetCropWindow
- self.pushButton_CropData = sizeAndShapeEditor.pushButton_CropData
-
-
-
- # Create and set layout
- layout.addWidget(dataLoader)
- layout.addWidget(sizeAndShapeEditor)
- scrollableWidget.setLayout(layout)
-
- # Scroll Area Properties
- scrollArea = QtWidgets.QScrollArea()
- scrollArea.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOn)
- scrollArea.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff)
- scrollArea.setWidgetResizable(True)
- scrollArea.setWidget(scrollableWidget)
- scrollArea.setFrameStyle(QtWidgets.QFrame.NoFrame)
-
- # Set the scroll area container to fill the layout of the entire ControlPanel widget
- vLayout = QtWidgets.QVBoxLayout(self)
- vLayout.addWidget(scrollArea)
- self.setLayout(vLayout)
-
- # Set geometry
- #self.setFixedHeight(600)
- #self.setFixedWidth(300)
-
-
-
-class DataLoadingWidget(QtWidgets.QWidget):
- def __init__(self):
- QtWidgets.QWidget.__init__(self)
-
- # Label, Line Edit, Browse Button
- self.label_Filename = QtWidgets.QLabel("Filename")
- self.lineEdit_LoadFile = QtWidgets.QLineEdit("")
- self.pushButton_BrowseFiles = QtWidgets.QPushButton("Browse")
-
- # Title
- title_row = QtWidgets.QLabel("Load Data")
- titleFont = QtGui.QFont()
- titleFont.setBold(True)
- title_row.setFont(titleFont)
-
- # Layout
- top_row = QtWidgets.QHBoxLayout()
- top_row.addWidget(self.label_Filename, stretch=0)
- top_row.addWidget(self.lineEdit_LoadFile, stretch=5)
-
- layout = QtWidgets.QVBoxLayout()
- layout.addWidget(title_row,0,QtCore.Qt.AlignCenter)
- #verticalSpacer = QtWidgets.QSpacerItem(0, 10, QtWidgets.QSizePolicy.Fixed, QtWidgets.QSizePolicy.Fixed)
- #layout.addItem(verticalSpacer)
- layout.addLayout(top_row)
- layout.addWidget(self.pushButton_BrowseFiles,0,QtCore.Qt.AlignRight)
-
- self.setLayout(layout)
- self.setFixedWidth(260)
- self.setSizePolicy(QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Fixed,QtWidgets.QSizePolicy.Fixed))
-
-class DataCubeSizeAndShapeWidget(QtWidgets.QWidget):
- def __init__(self):
- QtWidgets.QWidget.__init__(self)
-
- # Reshaping - Nx and Ny
- self.spinBox_Nx = QtWidgets.QSpinBox()
- self.spinBox_Ny = QtWidgets.QSpinBox()
- self.spinBox_Nx.setMaximum(100000)
- self.spinBox_Ny.setMaximum(100000)
-
- layout_spinBoxRow = QtWidgets.QHBoxLayout()
- layout_spinBoxRow.addWidget(QtWidgets.QLabel("Nx"),0,QtCore.Qt.AlignRight)
- layout_spinBoxRow.addWidget(self.spinBox_Nx)
- layout_spinBoxRow.addWidget(QtWidgets.QLabel("Ny"),0,QtCore.Qt.AlignRight)
- layout_spinBoxRow.addWidget(self.spinBox_Ny)
-
- layout_Reshaping = QtWidgets.QVBoxLayout()
- layout_Reshaping.addWidget(QtWidgets.QLabel("Scan shape"),0,QtCore.Qt.AlignCenter)
- layout_Reshaping.addLayout(layout_spinBoxRow)
-
- # Binning
- self.lineEdit_Binning = QtWidgets.QLineEdit("")
- self.pushButton_Binning = QtWidgets.QPushButton("Bin Data")
-
- layout_binningRow = QtWidgets.QHBoxLayout()
- layout_binningRow.addWidget(QtWidgets.QLabel("Bin by:"))
- layout_binningRow.addWidget(self.lineEdit_Binning)
- layout_binningRow.addWidget(self.pushButton_Binning)
-
- layout_Binning = QtWidgets.QVBoxLayout()
- layout_Binning.addWidget(QtWidgets.QLabel("Binning"),0,QtCore.Qt.AlignCenter)
- layout_Binning.addLayout(layout_binningRow)
-
- # Cropping
- self.pushButton_SetCropWindow = QtWidgets.QPushButton("Set Crop Window")
- self.pushButton_CropData = QtWidgets.QPushButton("Crop Data")
-
- layout_croppingRow = QtWidgets.QHBoxLayout()
- layout_croppingRow.addWidget(self.pushButton_SetCropWindow)
- layout_croppingRow.addWidget(self.pushButton_CropData)
-
- layout_Cropping = QtWidgets.QVBoxLayout()
- layout_Cropping.addWidget(QtWidgets.QLabel("Cropping"),0,QtCore.Qt.AlignCenter)
- layout_Cropping.addLayout(layout_croppingRow)
-
- # Title
- title_row = QtWidgets.QLabel("Reshape, bin, and crop")
- titleFont = QtGui.QFont()
- titleFont.setBold(True)
- title_row.setFont(titleFont)
-
- # Layout
- layout = QtWidgets.QVBoxLayout()
- layout.addWidget(title_row,0,QtCore.Qt.AlignCenter)
- layout.addLayout(layout_Reshaping)
- layout.addLayout(layout_Binning)
- layout.addLayout(layout_Cropping)
-
- self.setLayout(layout)
- self.setFixedWidth(260)
- self.setSizePolicy(QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Fixed,QtWidgets.QSizePolicy.Fixed))
-
-if __name__ == '__main__':
- app = QtWidgets.QApplication(sys.argv)
-
- controlPanel = ControlPanel()
- controlPanel.show()
-
- app.exec_()
-
-
-
-
-
-#app = QtWidgets.QApplication(sys.argv)
-#controlPanel = ControlPanel()
-#controlPanel.show()
-#sys.exit(app.exec_())
-
-
diff --git a/datacube.py b/datacube.py
deleted file mode 100644
index 851044cd9..000000000
--- a/datacube.py
+++ /dev/null
@@ -1,61 +0,0 @@
-# Defines a class - DataCube - for storing / accessing / manipulating the 4D-STEM data
-
-# For now, let's assume the data we're loading...
-# -is 3D data, with the real space dimensions flattened.
-# -does not have the scan shape stored in metadata
-#
-# Once we have other kinds of data, we can implement more complex loading functions which
-# catch all the possibilities.
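-#
-# Illustrative example (hypothetical numbers): a raw array of shape
-# (16384, 512, 512) -- 16384 flattened probe positions on a 512x512 detector --
-# acquired as a 128x128 scan becomes a (128, 128, 512, 512) datacube via
-# set_scan_shape(128, 128).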
-
-
-import hyperspy.api as hs
-import numpy as np
-
-class DataCube(object):
-
- def __init__(self, filename):
- self.read_data(filename)
-        self.set_scan_shape(self.R_Ny, self.R_Nx)  # set_scan_shape requires the scan shape; read_data() sets defaults
-
- def read_data(self,filename):
- #Load data
- try:
- hyperspy_file = hs.load(filename)
- self.raw_data = hyperspy_file.data
- self.metadata = hyperspy_file.metadata
- self.original_metadata = hyperspy_file.original_metadata
- except Exception as err:
- print("Failed to load", err)
- self.raw_data = np.random.rand(100,512,512)
- # Get shape of raw data
- if len(self.raw_data.shape)==3:
- self.R_N, self.Q_Ny, self.Q_Nx = self.raw_data.shape
- self.R_Nx, self.R_Ny = 1, self.R_N
- elif len(self.raw_data.shape)==4:
- self.R_Ny, self.R_Nx, self.Q_Ny, self.Q_Nx = self.raw_data.shape
- self.R_N = self.R_Ny*self.R_Nx
- else:
- print("Error: unexpected raw data shape of {}".format(self.raw_data.shape))
-
- def set_scan_shape(self,R_Ny,R_Nx):
- """
-        Reshape the data given the real space scan shape.
-        TODO: insert catch for 4D data being reshaped. Presently only 3D data is supported.
- """
- try:
- self.data4D = self.raw_data.reshape(R_Ny,R_Nx,self.Q_Ny,self.Q_Nx)
- self.R_Ny,self.R_Nx = R_Ny, R_Nx
- except ValueError:
- pass
-
- def set_diffraction_space_view(self,R_Ny,R_Nx):
- """
-        Set the image in diffraction space.
-        NOTE: this currently duplicates set_scan_shape; the diffraction-space view is not yet implemented.
- """
- try:
- self.data4D = self.raw_data.reshape(R_Ny,R_Nx,self.Q_Ny,self.Q_Nx)
- self.R_Ny,self.R_Nx = R_Ny, R_Nx
- except ValueError:
- pass
-
-
diff --git a/dm3_lib/.DS_Store b/dm3_lib/.DS_Store
deleted file mode 100644
index fa99ebca4..000000000
Binary files a/dm3_lib/.DS_Store and /dev/null differ
diff --git a/dm3_lib/__init__.py b/dm3_lib/__init__.py
deleted file mode 100644
index c34330ae5..000000000
--- a/dm3_lib/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from ._dm3_lib import VERSION
-from ._dm3_lib import DM3
-from ._dm3_lib import SUPPORTED_DATA_TYPES
diff --git a/dm3_lib/_dm3_lib.py b/dm3_lib/_dm3_lib.py
deleted file mode 100755
index eb0719d44..000000000
--- a/dm3_lib/_dm3_lib.py
+++ /dev/null
@@ -1,855 +0,0 @@
-#!/usr/bin/python
-"""Python module for parsing GATAN DM3 files"""
-
-################################################################################
-## Python script for parsing GATAN DM3 (DigitalMicrograph) files
-## --
-## based on the DM3_Reader plug-in (v 1.3.4) for ImageJ
-## by Greg Jefferis
-## http://rsb.info.nih.gov/ij/plugins/DM3_Reader.html
-## --
-## Python adaptation: Pierre-Ivan Raynal
-## http://microscopies.med.univ-tours.fr/
-################################################################################
-
-from __future__ import print_function
-
-import sys
-import os
-import struct
-
-from PIL import Image
-
-import numpy
-
-__all__ = ["DM3", "VERSION", "SUPPORTED_DATA_TYPES"]
-
-VERSION = '1.2'
-
-debugLevel = 0 # 0=none, 1-3=basic, 4-5=simple, 6-10 verbose
-
-## check for Python version
-PY3 = (sys.version_info[0] == 3)
-
-## - adjust for Python3
-if PY3:
- # unicode() deprecated in Python 3
- unicode_str = str
-else:
- unicode_str = unicode
-
-### utility functions ###
-
-### binary data reading functions ###
-
-def readLong(f):
- """Read 4 bytes as integer in file f"""
- read_bytes = f.read(4)
- return struct.unpack('>l', read_bytes)[0]
-
-def readShort(f):
- """Read 2 bytes as integer in file f"""
- read_bytes = f.read(2)
- return struct.unpack('>h', read_bytes)[0]
-
-def readByte(f):
- """Read 1 byte as integer in file f"""
- read_bytes = f.read(1)
- return struct.unpack('>b', read_bytes)[0]
-
-def readBool(f):
- """Read 1 byte as boolean in file f"""
- read_val = readByte(f)
- return (read_val!=0)
-
-def readChar(f):
- """Read 1 byte as char in file f"""
- read_bytes = f.read(1)
- return struct.unpack('c', read_bytes)[0]
-
-def readString(f, len_=1):
- """Read len_ bytes as a string in file f"""
- read_bytes = f.read(len_)
- str_fmt = '>'+str(len_)+'s'
- return struct.unpack( str_fmt, read_bytes )[0]
-
-def readLEShort(f):
- """Read 2 bytes as *little endian* integer in file f"""
- read_bytes = f.read(2)
-    return struct.unpack('<h', read_bytes)[0]
-
-def readLELong(f):
-    """Read 4 bytes as *little endian* integer in file f"""
-    read_bytes = f.read(4)
-    return struct.unpack('<l', read_bytes)[0]
-
-def readLEUShort(f):
-    """Read 2 bytes as *little endian* unsigned integer in file f"""
-    read_bytes = f.read(2)
-    return struct.unpack('<H', read_bytes)[0]
-
-def readLEULong(f):
-    """Read 4 bytes as *little endian* unsigned integer in file f"""
-    read_bytes = f.read(4)
-    return struct.unpack('<L', read_bytes)[0]
-
-def readLEFloat(f):
-    """Read 4 bytes as *little endian* float in file f"""
-    read_bytes = f.read(4)
-    return struct.unpack('<f', read_bytes)[0]
-
-def readLEDouble(f):
-    """Read 8 bytes as *little endian* double in file f"""
-    read_bytes = f.read(8)
-    return struct.unpack('<d', read_bytes)[0]
-
-## constants for encoded data types ##
-SHORT = 2
-LONG = 3
-USHORT = 4
-ULONG = 5
-FLOAT = 6
-DOUBLE = 7
-BOOLEAN = 8
-CHAR = 9
-OCTET = 10
-STRUCT = 15
-STRING = 18
-ARRAY = 20
-
-## encoded type -> reading function
-readFunc = {
- SHORT: readLEShort,
- LONG: readLELong,
- USHORT: readLEUShort,
- ULONG: readLEULong,
- FLOAT: readLEFloat,
- DOUBLE: readLEDouble,
- BOOLEAN: readBool,
- CHAR: readChar,
- OCTET: readChar, # difference with char???
-}
-
-## list of image DataTypes ##
-dataTypes = {
- 0: 'NULL_DATA',
- 1: 'SIGNED_INT16_DATA',
- 2: 'REAL4_DATA',
- 3: 'COMPLEX8_DATA',
- 4: 'OBSELETE_DATA',
- 5: 'PACKED_DATA',
- 6: 'UNSIGNED_INT8_DATA',
- 7: 'SIGNED_INT32_DATA',
- 8: 'RGB_DATA',
- 9: 'SIGNED_INT8_DATA',
- 10: 'UNSIGNED_INT16_DATA',
- 11: 'UNSIGNED_INT32_DATA',
- 12: 'REAL8_DATA',
- 13: 'COMPLEX16_DATA',
- 14: 'BINARY_DATA',
- 15: 'RGB_UINT8_0_DATA',
- 16: 'RGB_UINT8_1_DATA',
- 17: 'RGB_UINT16_DATA',
- 18: 'RGB_FLOAT32_DATA',
- 19: 'RGB_FLOAT64_DATA',
- 20: 'RGBA_UINT8_0_DATA',
- 21: 'RGBA_UINT8_1_DATA',
- 22: 'RGBA_UINT8_2_DATA',
- 23: 'RGBA_UINT8_3_DATA',
- 24: 'RGBA_UINT16_DATA',
- 25: 'RGBA_FLOAT32_DATA',
- 26: 'RGBA_FLOAT64_DATA',
- 27: 'POINT2_SINT16_0_DATA',
- 28: 'POINT2_SINT16_1_DATA',
- 29: 'POINT2_SINT32_0_DATA',
- 30: 'POINT2_FLOAT32_0_DATA',
- 31: 'RECT_SINT16_1_DATA',
- 32: 'RECT_SINT32_1_DATA',
- 33: 'RECT_FLOAT32_1_DATA',
- 34: 'RECT_FLOAT32_0_DATA',
- 35: 'SIGNED_INT64_DATA',
- 36: 'UNSIGNED_INT64_DATA',
- 37: 'LAST_DATA',
- }
-
-## supported Data Types
-dT_supported = [1, 2, 6, 7, 9, 10, 11, 14]
-SUPPORTED_DATA_TYPES = {i: dataTypes[i] for i in dT_supported}
-
-## other constants ##
-IMGLIST = "root.ImageList."
-OBJLIST = "root.DocumentObjectList."
-MAXDEPTH = 64
-
-DEFAULTCHARSET = 'utf-8'
-## END constants ##
-
-
-class DM3(object):
- """DM3 object. """
-
- ## utility functions
- def _makeGroupString(self):
- tString = str(self._curGroupAtLevelX[0])
- for i in range( 1, self._curGroupLevel+1 ):
- tString += '.{}'.format(self._curGroupAtLevelX[i])
- return tString
-
- def _makeGroupNameString(self):
- tString = self._curGroupNameAtLevelX[0]
- for i in range( 1, self._curGroupLevel+1 ):
- tString += '.' + str( self._curGroupNameAtLevelX[i] )
- return tString
-
- def _readTagGroup(self):
- # go down a level
- self._curGroupLevel += 1
- # increment group counter
- self._curGroupAtLevelX[self._curGroupLevel] += 1
- # set number of current tag to -1
- # --- readTagEntry() pre-increments => first gets 0
- self._curTagAtLevelX[self._curGroupLevel] = -1
- if ( debugLevel > 5):
- print("rTG: Current Group Level:", self._curGroupLevel)
- # is the group sorted?
- sorted_ = readByte(self._f)
- isSorted = (sorted_ == 1)
- # is the group open?
- opened = readByte(self._f)
- isOpen = (opened == 1)
- # number of Tags
- nTags = readLong(self._f)
- if ( debugLevel > 5):
- print("rTG: Iterating over the", nTags, "tag entries in this group")
- # read Tags
- for i in range( nTags ):
- self._readTagEntry()
- # go back up one level as reading group is finished
- self._curGroupLevel += -1
- return 1
-
- def _readTagEntry(self):
- # is data or a new group?
- data = readByte(self._f)
- isData = (data == 21)
- self._curTagAtLevelX[self._curGroupLevel] += 1
- # get tag label if exists
- lenTagLabel = readShort(self._f)
- if ( lenTagLabel != 0 ):
- tagLabel = readString(self._f, lenTagLabel).decode('latin-1')
- else:
- tagLabel = str( self._curTagAtLevelX[self._curGroupLevel] )
- if ( debugLevel > 5):
- print("{}|{}:".format(self._curGroupLevel, self._makeGroupString()),
- end=' ')
- print("Tag label = "+tagLabel)
- elif ( debugLevel > 1 ):
- print(str(self._curGroupLevel)+": Tag label = "+tagLabel)
- if isData:
- # give it a name
- self._curTagName = self._makeGroupNameString()+"."+tagLabel
- # read it
- self._readTagType()
- else:
- # it is a tag group
- self._curGroupNameAtLevelX[self._curGroupLevel+1] = tagLabel
- self._readTagGroup() # increments curGroupLevel
- return 1
-
- def _readTagType(self):
- delim = readString(self._f, 4).decode('latin-1')
- if ( delim != '%%%%' ):
- raise Exception(hex( self._f.tell() )
- + ": Tag Type delimiter not %%%%")
- nInTag = readLong(self._f)
- self._readAnyData()
- return 1
-
- def _encodedTypeSize(self, eT):
- # returns the size in bytes of the data type
- if eT == 0:
- width = 0
- elif eT in (BOOLEAN, CHAR, OCTET):
- width = 1
- elif eT in (SHORT, USHORT):
- width = 2
- elif eT in (LONG, ULONG, FLOAT):
- width = 4
- elif eT == DOUBLE:
- width = 8
- else:
- # returns -1 for unrecognised types
- width = -1
- return width
-
- def _readAnyData(self):
- ## higher level function dispatching to handling data types
- ## to other functions
- # - get Type category (short, long, array...)
- encodedType = readLong(self._f)
- # - calc size of encodedType
- etSize = self._encodedTypeSize(encodedType)
- if ( debugLevel > 5):
- print("rAnD, " + hex( self._f.tell() ) + ":", end=' ')
- print("Tag Type = " + str(encodedType) + ",", end=' ')
- print("Tag Size = " + str(etSize))
- if ( etSize > 0 ):
- self._storeTag( self._curTagName,
- self._readNativeData(encodedType, etSize) )
- elif ( encodedType == STRING ):
- stringSize = readLong(self._f)
- self._readStringData(stringSize)
- elif ( encodedType == STRUCT ):
- # does not store tags yet
- structTypes = self._readStructTypes()
- self._readStructData(structTypes)
- elif ( encodedType == ARRAY ):
- # does not store tags yet
- # indicates size of skipped data blocks
- arrayTypes = self._readArrayTypes()
- self._readArrayData(arrayTypes)
- else:
- raise Exception("rAnD, " + hex(self._f.tell())
- + ": Can't understand encoded type")
- return 1
-
- def _readNativeData(self, encodedType, etSize):
- # reads ordinary data types
- if encodedType in readFunc:
- val = readFunc[encodedType](self._f)
- else:
- raise Exception("rND, " + hex(self._f.tell())
- + ": Unknown data type " + str(encodedType))
- if ( debugLevel > 3 ):
- print("rND, " + hex(self._f.tell()) + ": " + str(val))
- elif ( debugLevel > 1 ):
- print(val)
- return val
-
- def _readStringData(self, stringSize):
- # reads string data
- if ( stringSize <= 0 ):
- rString = ""
- else:
- if ( debugLevel > 3 ):
- print("rSD @ " + str(self._f.tell()) + "/" + hex(self._f.tell()) +" :", end=' ')
- rString = readString(self._f, stringSize)
- # /!\ UTF-16 unicode string => convert to Python unicode str
- rString = rString.decode('utf-16-le')
- if ( debugLevel > 3 ):
- print(rString + " <" + repr( rString ) + ">")
- if ( debugLevel > 1 ):
- print("StringVal:", rString)
- self._storeTag( self._curTagName, rString )
- return rString
-
- def _readArrayTypes(self):
- # determines the data types in an array data type
- arrayType = readLong(self._f)
- itemTypes = []
- if ( arrayType == STRUCT ):
- itemTypes = self._readStructTypes()
- elif ( arrayType == ARRAY ):
- itemTypes = self._readArrayTypes()
- else:
- itemTypes.append( arrayType )
- return itemTypes
-
- def _readArrayData(self, arrayTypes):
- # reads array data
-
- arraySize = readLong(self._f)
-
- if ( debugLevel > 3 ):
- print("rArD, " + hex( self._f.tell() ) + ":", end=' ')
- print("Reading array of size = " + str(arraySize))
-
- itemSize = 0
- encodedType = 0
-
- for i in range( len(arrayTypes) ):
- encodedType = int( arrayTypes[i] )
- etSize = self._encodedTypeSize(encodedType)
- itemSize += etSize
- if ( debugLevel > 5 ):
- print("rArD: Tag Type = " + str(encodedType) + ",", end=' ')
- print("Tag Size = " + str(etSize))
- ##! readNativeData( encodedType, etSize ) !##
-
- if ( debugLevel > 5 ):
- print("rArD: Array Item Size = " + str(itemSize))
-
- bufSize = arraySize * itemSize
-
- if ( (not self._curTagName.endswith("ImageData.Data"))
- and ( len(arrayTypes) == 1 )
- and ( encodedType == USHORT )
- and ( arraySize < 256 ) ):
- # treat as string
- val = self._readStringData( bufSize )
- else:
- # treat as binary data
- # - store data size and offset as tags
- self._storeTag( self._curTagName + ".Size", bufSize )
- self._storeTag( self._curTagName + ".Offset", self._f.tell() )
- # - skip data w/o reading
- self._f.seek( self._f.tell() + bufSize )
-
- return 1
-
- def _readStructTypes(self):
- # analyses data types in a struct
-
- if ( debugLevel > 3 ):
- print("Reading Struct Types at Pos = " + hex(self._f.tell()))
-
- structNameLength = readLong(self._f)
- nFields = readLong(self._f)
-
- if ( debugLevel > 5 ):
- print("nFields = ", nFields)
-
- if ( nFields > 100 ):
- raise Exception(hex(self._f.tell())+": Too many fields")
-
- fieldTypes = []
- nameLength = 0
- for i in range( nFields ):
- nameLength = readLong(self._f)
- if ( debugLevel > 9 ):
- print("{}th nameLength = {}".format(i, nameLength))
- fieldType = readLong(self._f)
- fieldTypes.append( fieldType )
-
- return fieldTypes
-
- def _readStructData(self, structTypes):
- # reads struct data based on type info in structType
- for i in range( len(structTypes) ):
- encodedType = structTypes[i]
- etSize = self._encodedTypeSize(encodedType)
-
- if ( debugLevel > 5 ):
- print("Tag Type = " + str(encodedType) + ",", end=' ')
- print("Tag Size = " + str(etSize))
-
- # get data
- self._readNativeData(encodedType, etSize)
-
- return 1
-
- def _storeTag(self, tagName, tagValue):
- # store Tags as list and dict
- # NB: all tag values (and names) stored as unicode objects;
- # => can then be easily converted to any encoding
- if ( debugLevel == 1 ):
- print(" - storing Tag:")
- print(" -- name: ", tagName)
- print(" -- value: ", tagValue, type(tagValue))
- # - convert tag value to unicode if not already unicode object
- self._storedTags.append( tagName + " = " + unicode_str(tagValue) )
- self._tagDict[tagName] = unicode_str(tagValue)
-
- ### END utility functions ###
-
- def __init__(self, filename, debug=0):
- """DM3 object: parses DM3 file."""
-
- ## initialize variables ##
- self._debug = debug
- self._outputcharset = DEFAULTCHARSET
- self._filename = filename
- self._chosenImage = 1
- # - track currently read group
- self._curGroupLevel = -1
- self._curGroupAtLevelX = [ 0 for x in range(MAXDEPTH) ]
- self._curGroupNameAtLevelX = [ '' for x in range(MAXDEPTH) ]
- # - track current tag
- self._curTagAtLevelX = [ '' for x in range(MAXDEPTH) ]
- self._curTagName = ''
- # - open file for reading
- self._f = open( self._filename, 'rb' )
- # - create Tags repositories
- self._storedTags = []
- self._tagDict = {}
-
- isDM3 = True
- ## read header (first 3 4-byte int)
- # get version
- fileVersion = readLong(self._f)
- if ( fileVersion != 3 ):
- isDM3 = False
- # get indicated file size
- fileSize = readLong(self._f)
- # get byte-ordering
- lE = readLong(self._f)
- littleEndian = (lE == 1)
- if not littleEndian:
- isDM3 = False
- # check file header, raise Exception if not DM3
- if not isDM3:
- raise Exception("%s does not appear to be a DM3 file."
- % os.path.split(self._filename)[1])
- elif self._debug > 0:
- print("%s appears to be a DM3 file" % (self._filename))
-
- if ( debugLevel > 5 or self._debug > 1):
- print("Header info.:")
- print("- file version:", fileVersion)
- print("- lE:", lE)
- print("- file size:", fileSize, "bytes")
-
- # set name of root group (contains all data)...
- self._curGroupNameAtLevelX[0] = "root"
- # ... then read it
- self._readTagGroup()
- if self._debug > 0:
- print("-- %s Tags read --" % len(self._storedTags))
-
- # fetch image characteristics
- tag_root = 'root.ImageList.1'
- self._data_type = int( self.tags["%s.ImageData.DataType" % tag_root] )
- self._im_width = int( self.tags["%s.ImageData.Dimensions.0" % tag_root] )
- self._im_height = int( self.tags["%s.ImageData.Dimensions.1" % tag_root] )
- try:
- self._im_depth = int( self.tags['root.ImageList.1.ImageData.Dimensions.2'] )
- except KeyError:
- self._im_depth = 1
-
- if self._debug > 0:
- print("Notice: image size: %sx%s px" % (self._im_width, self._im_height))
- if self._im_depth>1:
- print("Notice: %s image stack" % (self._im_depth))
-
- @property
- def data_type(self):
- """Returns image DataType."""
- return self._data_type
-
- @property
- def data_type_str(self):
- """Returns image DataType string."""
- return dataTypes[self._data_type]
-
- @property
- def width(self):
- """Returns image width (px)."""
- return self._im_width
-
- @property
- def height(self):
- """Returns image height (px)."""
- return self._im_height
-
- @property
- def depth(self):
- """Returns image depth (i.e. number of images in stack)."""
- return self._im_depth
-
- @property
- def size(self):
- """Returns image size (width,height[,depth])."""
- if self._im_depth > 1:
- return (self._im_width, self._im_height, self._im_depth)
- else:
- return (self._im_width, self._im_height)
-
- @property
- def outputcharset(self):
- """Returns Tag dump/output charset."""
- return self._outputcharset
-
- @outputcharset.setter
- def outputcharset(self, value):
- """Set Tag dump/output charset."""
- self._outputcharset = value
-
- @property
- def filename(self):
- """Returns full file path."""
- return self._filename
-
- @property
- def tags(self):
- """Returns all image Tags."""
- return self._tagDict
-
- def dumpTags(self, dump_dir='/tmp'):
- """Dumps image Tags in a txt file."""
- dump_file = os.path.join(dump_dir,
- os.path.split(self._filename)[1]
- + ".tagdump.txt")
- try:
-            dumpf = open( dump_file, 'w', encoding=self._outputcharset )
-        except:
-            print("Warning: cannot generate dump file.")
-        else:
-            for tag in self._storedTags:
-                dumpf.write( tag + "\n" )
-            dumpf.close()
-
- @property
- def info(self):
- """Extracts useful experiment info from DM3 file."""
- # define useful information
- tag_root = 'root.ImageList.1'
- info_keys = {
- 'descrip': "%s.Description" % tag_root,
- 'acq_date': "%s.ImageTags.DataBar.Acquisition Date" % tag_root,
- 'acq_time': "%s.ImageTags.DataBar.Acquisition Time" % tag_root,
- 'name': "%s.ImageTags.Microscope Info.Name" % tag_root,
- 'micro': "%s.ImageTags.Microscope Info.Microscope" % tag_root,
- 'hv': "%s.ImageTags.Microscope Info.Voltage" % tag_root,
- 'mag': "%s.ImageTags.Microscope Info.Indicated Magnification" % tag_root,
- 'mode': "%s.ImageTags.Microscope Info.Operation Mode" % tag_root,
- 'operator': "%s.ImageTags.Microscope Info.Operator" % tag_root,
- 'specimen': "%s.ImageTags.Microscope Info.Specimen" % tag_root,
- # 'image_notes': "root.DocumentObjectList.10.Text' # = Image Notes
- }
- # get experiment information
- infoDict = {}
- for key, tag_name in info_keys.items():
- if tag_name in self.tags:
- # tags supplied as Python unicode str; convert to chosen charset
- # (typically latin-1 or utf-8)
- infoDict[key] = self.tags[tag_name].encode(self._outputcharset)
- # return experiment information
- return infoDict
-
- @property
- def imagedata(self):
- """Extracts image data as numpy.array"""
-
- # numpy dtype strings associated to the various image dataTypes
- dT_str = {
-            1: '<i2',   # 16-bit LE signed integer
-            2: '<f4',   # 32-bit LE floating point
-            6: 'u1',    # 8-bit unsigned integer
-            7: '<i4',   # 32-bit LE signed integer
-            9: 'i1',    # 8-bit signed integer
-            10: '<u2',  # 16-bit LE unsigned integer
-            11: '<u4',  # 32-bit LE unsigned integer
-            14: 'u1',   # binary
-        }
-
-        # get relevant Tags
-        tag_root = 'root.ImageList.1'
-        data_offset = int( self.tags["%s.ImageData.Data.Offset" % tag_root] )
-        data_size = int( self.tags["%s.ImageData.Data.Size" % tag_root] )
-        data_type = self._data_type
-        im_width = self._im_width
-        im_height = self._im_height
-        im_depth = self._im_depth
-
-        if self._debug > 0:
- print("Notice: image data in %s starts at %s" % (
- os.path.split(self._filename)[1], hex(data_offset)
- ))
-
- # check if image DataType is implemented, then read
- if data_type in dT_str:
- np_dt = numpy.dtype( dT_str[data_type] )
- if self._debug > 0:
- print("Notice: image data type: %s ('%s'), read as %s" % (
- data_type, dataTypes[data_type], np_dt
- ))
- self._f.seek( data_offset )
- # - fetch image data
- rawdata = self._f.read(data_size)
- # - convert raw to numpy array w/ correct dtype
- ima = numpy.fromstring(rawdata, dtype=np_dt)
- # - reshape to matrix or stack
- if im_depth > 1:
- ima = ima.reshape(im_depth, im_height, im_width)
- else:
- ima = ima.reshape(im_height, im_width)
- else:
- raise Exception(
- "Cannot extract image data from %s: unimplemented DataType (%s:%s)." %
- (os.path.split(self._filename)[1], data_type, dataTypes[data_type])
- )
-
- # if image dataType is BINARY, binarize image
- # (i.e., px_value>0 is True)
- if data_type == 14:
- ima[ima>0] = 1
-
- return ima
-
-
- @property
- def Image(self):
- """Returns image data as PIL Image"""
-
- # define PIL Image mode for the various (supported) image dataTypes,
- # among:
- # - '1': 1-bit pixels, black and white, stored as 8-bit pixels
- # - 'L': 8-bit pixels, gray levels
- # - 'I': 32-bit integer pixels
- # - 'F': 32-bit floating point pixels
- dT_modes = {
- 1: 'I', # 16-bit LE signed integer
- 2: 'F', # 32-bit LE floating point
- 6: 'L', # 8-bit unsigned integer
- 7: 'I', # 32-bit LE signed integer
- 9: 'I', # 8-bit signed integer
- 10: 'I', # 16-bit LE unsigned integer
- 11: 'I', # 32-bit LE unsigned integer
- 14: 'L', # "binary"
- }
-
- # define loaded array dtype if has to be fixed to match Image mode
- dT_newdtypes = {
- 1: 'int32', # 16-bit LE integer to 32-bit int
- 2: 'float32', # 32-bit LE float to 32-bit float
- 9: 'int32', # 8-bit signed integer to 32-bit int
- 10: 'int32', # 16-bit LE u. integer to 32-bit int
- }
-
- # get relevant Tags
- data_type = self._data_type
- im_width = self._im_width
- im_height = self._im_height
- im_depth = self._im_depth
-
- # fetch image data array
- ima = self.imagedata
- # assign Image mode
- mode_ = dT_modes[data_type]
-
- # reshape array if image stack
- if im_depth > 1:
- ima = ima.reshape(im_height*im_depth, im_width)
-
- # load image data array into Image object (recast array if necessary)
- if data_type in dT_newdtypes:
- im = Image.fromarray(ima.astype(dT_newdtypes[data_type]),mode_)
- else:
- im = Image.fromarray(ima,mode_)
-
- return im
-
-
- @property
- def contrastlimits(self):
- """Returns display range (cuts)."""
- tag_root = 'root.DocumentObjectList.0'
- low = int(float(self.tags["%s.ImageDisplayInfo.LowLimit" % tag_root]))
- high = int(float(self.tags["%s.ImageDisplayInfo.HighLimit" % tag_root]))
- cuts = (low, high)
- return cuts
-
- @property
- def cuts(self):
- """Returns display range (cuts)."""
- return self.contrastlimits
-
- @property
- def pxsize(self):
- """Returns pixel size and unit."""
- tag_root = 'root.ImageList.1'
- pixel_size = float(
- self.tags["%s.ImageData.Calibrations.Dimension.0.Scale" % tag_root])
- unit = self.tags["%s.ImageData.Calibrations.Dimension.0.Units" %
- tag_root]
- if unit == u'\xb5m':
- unit = 'micron'
- else:
- unit = unit.encode('ascii')
- if self._debug > 0:
- print("pixel size = %s %s" % (pixel_size, unit))
- return (pixel_size, unit)
-
-
- @property
- def tnImage(self):
- """Returns thumbnail as PIL Image."""
- # get thumbnail
- tag_root = 'root.ImageList.0'
- tn_size = int( self.tags["%s.ImageData.Data.Size" % tag_root] )
- tn_offset = int( self.tags["%s.ImageData.Data.Offset" % tag_root] )
- tn_width = int( self.tags["%s.ImageData.Dimensions.0" % tag_root] )
- tn_height = int( self.tags["%s.ImageData.Dimensions.1" % tag_root] )
-
- if self._debug > 0:
- print("Notice: tn data in %s starts at %s" % (
- os.path.split(self._filename)[1], hex(tn_offset)
- ))
- print("Notice: tn size: %sx%s px" % (tn_width, tn_height))
-
- if (tn_width*tn_height*4) != tn_size:
- raise Exception("Cannot extract thumbnail from %s"
- % os.path.split(self._filename)[1])
- else:
- self._f.seek( tn_offset )
- rawdata = self._f.read(tn_size)
- # - read as 32-bit LE unsigned integer
- tn = Image.frombytes( 'F', (tn_width, tn_height), rawdata,
- 'raw', 'F;32' )
- # - rescale and convert px data
- tn = tn.point(lambda x: x * (1./65536) + 0)
- tn = tn.convert('L')
- # - return image
- return tn
-
- @property
- def thumbnaildata(self):
- """Fetch thumbnail image data as numpy.array"""
-
- # get useful thumbnail Tags
- tag_root = 'root.ImageList.0'
- tn_size = int( self.tags["%s.ImageData.Data.Size" % tag_root] )
- tn_offset = int( self.tags["%s.ImageData.Data.Offset" % tag_root] )
- tn_width = int( self.tags["%s.ImageData.Dimensions.0" % tag_root] )
- tn_height = int( self.tags["%s.ImageData.Dimensions.1" % tag_root] )
-
- if self._debug > 0:
- print("Notice: tn data in %s starts at %s" % (
- os.path.split(self._filename)[1], hex(tn_offset)
- ))
- print("Notice: tn size: %sx%s px" % (tn_width, tn_height))
-
- # get thumbnail data
- if (tn_width*tn_height*4) == tn_size:
- self._f.seek(tn_offset)
- rawtndata = self._f.read(tn_size)
- print('## rawdata:', len(rawtndata))
- # - read as 32-bit LE unsigned integer
-            np_dt_tn = numpy.dtype('<u4')
-            tndata = numpy.fromstring(rawtndata, dtype=np_dt_tn)
-            # - reshape to matrix
-            tndata = tndata.reshape(tn_height, tn_width)
-            # - rescale and return thumbnail data
-            return tndata / 65536.
-        else:
-            raise Exception("Cannot extract thumbnail from %s"
-                            % os.path.split(self._filename)[1])
-
-    def makePNGThumbnail(self, tn_file=''):
-        """Save thumbnail as PNG file."""
-        # - cleanup file name
-        if tn_file == '':
-            tn_path = os.path.join('./', os.path.split(self.filename)[1] + '.tn.png')
-        else:
-            if os.path.splitext(tn_file)[1] != '.png':
-                tn_path = os.path.splitext(tn_file)[0] + '.png'
-            else:
-                tn_path = tn_file
-        # - save tn file
-        try:
-            self.tnImage.save(tn_path, 'PNG')
-            if self._debug > 0:
- print("Thumbnail saved as '%s'." % tn_path)
- except:
- print("Warning: could not save thumbnail.")
-
-
-## MAIN ##
-if __name__ == '__main__':
- print("dm3_lib %s" % VERSION)
-
diff --git a/dm3_lib/demo/demo.py b/dm3_lib/demo/demo.py
deleted file mode 100755
index 1d15c98e1..000000000
--- a/dm3_lib/demo/demo.py
+++ /dev/null
@@ -1,122 +0,0 @@
-#!/usr/bin/python
-
-from __future__ import print_function, division
-
-import os.path
-import argparse
-
-import numpy as np
-import matplotlib.pyplot as plt
-
-from PIL import Image
-
-import dm3_lib as dm3
-
-from utilities import calcHistogram, calcDisplayRange
-
-# CONSTANTS
-
-savedir = os.path.expanduser("~/Desktop")
-debug = 0
-
-# define command line arguments
-parser = argparse.ArgumentParser()
-
-parser.add_argument("file", help="path to DM3 file to parse")
-parser.add_argument("-v", "--verbose", help="increase output verbosity",
- action="store_true")
-parser.add_argument("--dump", help="dump DM3 tags in text file",
- action="store_true")
-parser.add_argument("--convert", help="save image in various formats",
- action="store_true")
-
-# parse command line arguments
-args = parser.parse_args()
-if args.verbose:
- debug = 1
-filepath = args.file
-
-# get filename
-filename = os.path.split(filepath)[1]
-fileref = os.path.splitext(filename)[0]
-
-# pyplot interactive mode
-plt.ion()
-plt.close('all')
-
-# parse DM3 file
-dm3f = dm3.DM3(filepath, debug=debug)
-
-# get some useful tag data and print
-print("file:", dm3f.filename)
-print("file info.:")
-print(dm3f.info)
-print("scale: %.3g %s/px"%dm3f.pxsize)
-cuts = dm3f.cuts
-print("cuts:",cuts)
-
-# dump image Tags in txt file
-if args.dump:
- dm3f.dumpTags(savedir)
-
-# get image data
-aa = dm3f.imagedata
-
-# display image
-# - w/o cuts
-if args.verbose:
- plt.matshow(aa, cmap=plt.cm.pink)
- plt.title("%s (w/o cuts)"%filename)
- plt.colorbar(shrink=.8)
-# - w/ cuts (if cut values different)
-if cuts[0] != cuts[1]:
- plt.matshow(aa, cmap=plt.cm.pink, vmin=cuts[0], vmax=cuts[1])
- plt.title("%s"%filename)
- plt.colorbar(shrink=.8)
-
-# - display image histogram
-if args.verbose:
- hh,bb = calcHistogram(aa)
- plt.figure('Image histogram')
- plt.plot(bb[:-1],hh,drawstyle='steps')
- plt.xlim(bb[0],bb[-1])
- plt.xlabel('Intensity')
- plt.ylabel('Number')
-
-# convert image to various formats
-if args.convert:
- # save image as TIFF
- tif_file = os.path.join(savedir,fileref+'.tif')
- im = Image.fromarray(aa)
- im.save(tif_file)
- # check TIFF dynamic range
- tim = Image.open(tif_file)
- if tim.mode == 'L':
- tif_range = "8-bit"
- else:
- tif_range = "32-bit"
- print("Image saved as %s TIFF."%tif_range)
-
- # save image as PNG and JPG files
- # - normalize image for conversion to 8-bit
- aa_norm = aa.copy()
- # -- apply cuts (optional)
- if cuts[0] != cuts[1]:
- aa_norm[ (aa <= min(cuts)) ] = float(min(cuts))
- aa_norm[ (aa >= max(cuts)) ] = float(max(cuts))
- # -- normalize
- aa_norm = (aa_norm - np.min(aa_norm)) / (np.max(aa_norm) - np.min(aa_norm))
- # -- scale to 0--255, convert to (8-bit) integer
- aa_norm = np.uint8(np.round( aa_norm * 255 ))
-
- if args.verbose:
- # - display normalized image
- plt.matshow(aa_norm, cmap=plt.cm.Greys_r)
- plt.title("%s [8-bit display]"%filename)
- plt.colorbar(shrink=.8)
-
- # - save as PNG and JPG
- im_dsp = Image.fromarray(aa_norm)
- im_dsp.save(os.path.join(savedir,fileref+'.png'))
- im_dsp.save(os.path.join(savedir,fileref+'.jpg'))
-
diff --git a/dm3_lib/demo/utilities.py b/dm3_lib/demo/utilities.py
deleted file mode 100644
index 50bc80efb..000000000
--- a/dm3_lib/demo/utilities.py
+++ /dev/null
@@ -1,35 +0,0 @@
-#!/usr/bin/python
-
-import numpy as np
-
-# histogram, re-compute cuts
-
-def calcHistogram(imdata, bins_=256):
- '''Compute image histogram.'''
- im_values = np.ravel(imdata)
- hh, bins_ = np.histogram( im_values, bins=bins_ )
- return hh, bins_
-
-def calcDisplayRange(imdata, cutoff=.1, bins_=512):
- '''Compute display range, i.e., cuts.
- (ignore the 'cutoff'% lowest/highest value pixels)'''
- # compute image histogram
- hh, bins_ = calcHistogram(imdata, bins_)
- # check histogram format
- if len(bins_)==len(hh):
- bb = bins_
- else:
- bb = bins_[:-1] # 'bins' == bin_edges
- # number of pixels
- Npx = np.sum(hh)
- # calc. lower limit :
- i = 1
- while np.sum( hh[:i] ) < Npx*cutoff/100.:
- i += 1
- cut0 = round( bb[i] )
- # calc. higher limit
- j = 1
- while np.sum( hh[-j:] ) < Npx*cutoff/100.:
- j += 1
- cut1 = round( bb[-j] )
- return cut0,cut1
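-
-# Example (sketch, illustrative numbers): for an 8-bit image, calcDisplayRange(im, cutoff=.1)
-# walks the histogram in from both ends and returns (cut0, cut1) such that roughly
-# 0.1% of pixel values lie below cut0 and roughly 0.1% lie above cut1.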
diff --git a/docs/Makefile b/docs/Makefile
new file mode 100644
index 000000000..d0c3cbf10
--- /dev/null
+++ b/docs/Makefile
@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS ?=
+SPHINXBUILD ?= sphinx-build
+SOURCEDIR = source
+BUILDDIR = build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/docs/_static/Advanced_plots.png b/docs/_static/Advanced_plots.png
new file mode 100644
index 000000000..407dbf678
Binary files /dev/null and b/docs/_static/Advanced_plots.png differ
diff --git a/docs/_static/DOE_logo.png b/docs/_static/DOE_logo.png
new file mode 100644
index 000000000..33681cbc0
Binary files /dev/null and b/docs/_static/DOE_logo.png differ
diff --git a/docs/_static/DOI-BADGE-978-3-319-76207-4_15.svg b/docs/_static/DOI-BADGE-978-3-319-76207-4_15.svg
new file mode 100644
index 000000000..49ac8fcd5
--- /dev/null
+++ b/docs/_static/DOI-BADGE-978-3-319-76207-4_15.svg
@@ -0,0 +1,35 @@
+
\ No newline at end of file
diff --git a/docs/_static/demo.gif b/docs/_static/demo.gif
new file mode 100644
index 000000000..97f4e5a34
Binary files /dev/null and b/docs/_static/demo.gif differ
diff --git a/docs/_static/dp.png b/docs/_static/dp.png
new file mode 100644
index 000000000..6ed560714
Binary files /dev/null and b/docs/_static/dp.png differ
diff --git a/docs/_static/py4DSTEM_logo.png b/docs/_static/py4DSTEM_logo.png
new file mode 100644
index 000000000..dc749a225
Binary files /dev/null and b/docs/_static/py4DSTEM_logo.png differ
diff --git a/docs/_static/py4DSTEM_logo_small.ico b/docs/_static/py4DSTEM_logo_small.ico
new file mode 100644
index 000000000..f9719d08c
Binary files /dev/null and b/docs/_static/py4DSTEM_logo_small.ico differ
diff --git a/docs/_static/py4DSTEM_logo_vsmall.ico b/docs/_static/py4DSTEM_logo_vsmall.ico
new file mode 100644
index 000000000..f9719d08c
Binary files /dev/null and b/docs/_static/py4DSTEM_logo_vsmall.ico differ
diff --git a/docs/_static/toyota_research_institute.png b/docs/_static/toyota_research_institute.png
new file mode 100644
index 000000000..d06126fd6
Binary files /dev/null and b/docs/_static/toyota_research_institute.png differ
diff --git a/docs/make.bat b/docs/make.bat
new file mode 100644
index 000000000..6247f7e23
--- /dev/null
+++ b/docs/make.bat
@@ -0,0 +1,35 @@
+@ECHO OFF
+
+pushd %~dp0
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+ set SPHINXBUILD=sphinx-build
+)
+set SOURCEDIR=source
+set BUILDDIR=build
+
+if "%1" == "" goto help
+
+%SPHINXBUILD% >NUL 2>NUL
+if errorlevel 9009 (
+ echo.
+ echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+ echo.installed, then set the SPHINXBUILD environment variable to point
+ echo.to the full path of the 'sphinx-build' executable. Alternatively you
+ echo.may add the Sphinx directory to PATH.
+ echo.
+ echo.If you don't have Sphinx installed, grab it from
+ echo.http://sphinx-doc.org/
+ exit /b 1
+)
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+goto end
+
+:help
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+
+:end
+popd
diff --git a/docs/papers.md b/docs/papers.md
new file mode 100644
index 000000000..4b1bc6b8d
--- /dev/null
+++ b/docs/papers.md
@@ -0,0 +1,61 @@
+![py4DSTEM logo](/images/py4DSTEM_logo.png)
+
+
+## Papers which have used py4DSTEM
+
+Please email clophus@lbl.gov if you have used py4DSTEM for analysis and your paper is not listed below!
+
+### 2022 (9)
+
+[Correlative image learning of chemo-mechanics in phase-transforming solids](https://www.nature.com/articles/s41563-021-01191-0), Nature Materials (2022)
+
+[Correlative analysis of structure and chemistry of LixFePO4 platelets using 4D-STEM and X-ray ptychography](https://doi.org/10.1016/j.mattod.2021.10.031), Materials Today 52, 102 (2022).
+
+[Visualizing Grain Statistics in MOCVD WSe2 through Four-Dimensional Scanning Transmission Electron Microscopy](https://doi.org/10.1021/acs.nanolett.1c04315), Nano Letters 22, 2578 (2022).
+
+[Electric field control of chirality](https://doi.org/10.1126/sciadv.abj8030), Science Advances 8 (2022).
+
+[Real-Time Interactive 4D-STEM Phase-Contrast Imaging From Electron Event Representation Data: Less computation with the right representation](https://doi.org/10.1109/MSP.2021.3120981), IEEE Signal Processing Magazine 39, 25 (2022).
+
+[Microstructural dependence of defect formation in iron-oxide thin films](https://doi.org/10.1016/j.apsusc.2022.152844), Applied Surface Science 589, 152844 (2022).
+
+[Chemical and Structural Alterations in the Amorphous Structure of Obsidian due to Nanolites](https://doi.org/10.1017/S1431927621013957), Microscopy and Microanalysis 28, 289 (2022).
+
+[Nanoscale characterization of crystalline and amorphous phases in silicon oxycarbide ceramics using 4D-STEM](https://doi.org/10.1016/j.matchar.2021.111512), Materials Characterization 181, 111512 (2021).
+
+[Disentangling multiple scattering with deep learning: application to strain mapping from electron diffraction patterns](https://arxiv.org/abs/2202.00204), arXiv:2202.00204 (2022).
+
+
+
+### 2021 (10)
+
+[Cryoforged nanotwinned titanium with ultrahigh strength and ductility](https://doi.org/10.1126/science.abe7252), Science 373, 1363 (2021).
+
+[Strain fields in twisted bilayer graphene](https://doi.org/10.1038/s41563-021-00973-w), Nature Materials 20, 956 (2021).
+
+[Determination of Grain-Boundary Structure and Electrostatic Characteristics in a SrTiO3 Bicrystal by Four-Dimensional Electron Microscopy](https://doi.org/10.1021/acs.nanolett.1c02960), Nano Letters 21, 9138 (2021).
+
+[Local Lattice Deformation of Tellurene Grain Boundaries by Four-Dimensional Electron Microscopy](https://pubs.acs.org/doi/10.1021/acs.jpcc.1c00308), Journal of Physical Chemistry C 125, 3396 (2021).
+
+[Extreme mixing in nanoscale transition metal alloys](https://doi.org/10.1016/j.matt.2021.04.014), Matter 4, 2340 (2021).
+
+[Multibeam Electron Diffraction](https://doi.org/10.1017/S1431927620024770), Microscopy and Microanalysis 27, 129 (2021).
+
+[A Fast Algorithm for Scanning Transmission Electron Microscopy Imaging and 4D-STEM Diffraction Simulations](https://doi.org/10.1017/S1431927621012083), Microscopy and Microanalysis 27, 835 (2021).
+
+[Fast Grain Mapping with Sub-Nanometer Resolution Using 4D-STEM with Grain Classification by Principal Component Analysis and Non-Negative Matrix Factorization](https://doi.org/10.1017/S1431927621011946), Microscopy and Microanalysis 27, 794 (2021).
+
+[Prismatic 2.0 – Simulation software for scanning and high resolution transmission electron microscopy (STEM and HRTEM)](https://doi.org/10.1016/j.micron.2021.103141), Micron 151, 103141 (2021).
+
+[4D-STEM of Beam-Sensitive Materials](https://doi.org/10.1021/acs.accounts.1c00073), Accounts of Chemical Research 54, 2543 (2021).
+
+
+### 2020 (3)
+
+[Patterned probes for high precision 4D-STEM Bragg measurements](https://doi.org/10.1016/j.ultramic.2019.112890), Ultramicroscopy 209, 112890 (2020).
+
+
+[Tilted fluctuation electron microscopy](https://doi.org/10.1063/5.0015532), Applied Physics Letters 117, 091903 (2020).
+
+[4D-STEM elastic stress state characterisation of a TWIP steel nanotwin](https://arxiv.org/abs/2004.03982), arXiv:2004.03982 (2020).
+
diff --git a/docs/requirements.txt b/docs/requirements.txt
new file mode 100644
index 000000000..03ecc7e26
--- /dev/null
+++ b/docs/requirements.txt
@@ -0,0 +1,3 @@
+emdfile
+sphinx_rtd_theme
+# py4dstem
diff --git a/docs/source/4dstem.rst b/docs/source/4dstem.rst
new file mode 100644
index 000000000..d6c927979
--- /dev/null
+++ b/docs/source/4dstem.rst
@@ -0,0 +1,27 @@
+.. _4dstem:
+
+What is 4D-STEM?
+================
+
+**S**\ canning **T**\ ransmission **E**\ lectron **M**\ icroscopy (STEM) is a powerful tool for materials characterization.
+In a traditional STEM experiment, a beam of high energy electrons is focused to a very fine probe - on the order of, or even smaller than, the spacing between atoms - and rastered across the surface of the sample.
+A conventional two-dimensional STEM image is formed by populating the value of each pixel with the electron flux through a detector at the corresponding beam position.
+In a high resolution tool, this enables imaging at the level of atoms.
+
+
+Four-dimensional scanning transmission electron microscopy (4D-STEM) uses a fast, pixelated electron detector to collect far more information than a traditional STEM experiment.
+In 4D-STEM, a pixelated detector is used to record a 2D diffraction image at every raster position of the beam.
+A 4D-STEM scan thus results in a 4D data array: two dimensions in diffraction space (i.e. the detector pixels), and two dimensions in real space (i.e. the rastering of the beam).
+
+
+4D-STEM data is information rich.
+A 4D datacube can be collapsed in real space to yield information comparable to a nanobeam electron diffraction experiment, or in diffraction space to yield a variety of virtual images, corresponding both to traditional STEM imaging modes and to more exotic virtual imaging modalities.
+The structure, symmetries, and spacings of Bragg disks can be used to extract spatially resolved maps of crystallinity, grain orientations, and lattice strain.
+Redundant information in overlapping Bragg disks can be leveraged to calculate the sample potential.
+Structure in the diffracted halos of amorphous systems can be used to describe short- and medium-range order.
+
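+To make the datacube geometry concrete, here is a minimal sketch of both "collapse" operations in plain NumPy. The synthetic datacube, its dimensions, and the disk-shaped detector below are illustrative assumptions, not py4DSTEM API:
+
+.. code-block:: python
+
+ import numpy as np
+
+ # A toy 4D-STEM datacube: (R_Nx, R_Ny) real-space scan positions,
+ # (Q_Nx, Q_Ny) diffraction-space detector pixels
+ datacube = np.random.poisson(1.0, size=(32, 32, 64, 64)).astype(float)
+
+ # Collapse real space: the mean diffraction pattern over all probe positions
+ mean_dp = datacube.mean(axis=(0, 1))                    # shape (64, 64)
+
+ # Collapse diffraction space: a virtual bright-field image, summing the
+ # flux through a disk-shaped detector centered on the unscattered beam
+ qx, qy = np.mgrid[0:64, 0:64]
+ detector = (qx - 32) ** 2 + (qy - 32) ** 2 < 10 ** 2    # boolean disk mask
+ virtual_image = datacube[:, :, detector].sum(axis=-1)   # shape (32, 32)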
+
+py4DSTEM supports many different modes of 4D-STEM analysis.
+
+
+
diff --git a/docs/source/acknowledgements.rst b/docs/source/acknowledgements.rst
new file mode 100644
index 000000000..efc1c589a
--- /dev/null
+++ b/docs/source/acknowledgements.rst
@@ -0,0 +1,31 @@
+Acknowledgements
+================
+
+
+* If you use py4DSTEM for a scientific study, please cite our open access `py4DSTEM publication`_ in Microscopy and Microanalysis.
+
+
+ * py4DSTEM: A Software Package for Four-Dimensional Scanning Transmission Electron Microscopy Data Analysis
+
+ .. image:: ../_static/DOI-BADGE-978-3-319-76207-4_15.svg
+ :target: https://doi.org/10.1017/S1431927621000477
+
+* Check out the `Py4DSTEM Github`_
+
+* The developers gratefully acknowledge the financial support of the Toyota Research Institute for the research and development time which made this project possible.
+
+.. image:: ../_static/toyota_research_institute.png
+
+* Additional funding has been provided by the US Department of Energy, Office of Science, Basic Energy Sciences.
+
+.. image:: ../_static/DOE_logo.png
+
+* You are also free to use the py4DSTEM logo, in PDF or PNG format, for presentations or posters.
+
+**********
+References
+**********
+
+.. target-notes::
+.. _`py4DSTEM publication`: https://doi.org/10.1017/S1431927621000477
+.. _`Py4DSTEM Github`: http://github.com/py4DSTEM/py4DSTEM
diff --git a/docs/source/api.rst b/docs/source/api.rst
new file mode 100644
index 000000000..764aa5e1a
--- /dev/null
+++ b/docs/source/api.rst
@@ -0,0 +1,17 @@
+API
+===
+
+For a full index of py4DSTEM functions and classes, check out :ref:`apiindex`.
+
+.. toctree::
+ :maxdepth: 2
+
+ api/py4DSTEM
+ api/classes
+ api/io
+ api/preprocess
+ api/process
+ api/utils
+ api/version
+ api/visualize
+ api/emd
diff --git a/docs/source/api/classes.rst b/docs/source/api/classes.rst
new file mode 100644
index 000000000..9370eaa64
--- /dev/null
+++ b/docs/source/api/classes.rst
@@ -0,0 +1,84 @@
+Classes
+=======
+.. contents:: Table of Contents
+ :depth: 2
+
+Array
+-----
+.. autoclass:: py4DSTEM.Array
+ :members:
+ :inherited-members:
+
+BraggVectors
+------------
+.. autoclass:: py4DSTEM.BraggVectors
+ :members:
+ :inherited-members:
+
+Calibration
+-----------
+.. autoclass:: py4DSTEM.Calibration
+ :members:
+ :inherited-members:
+
+Custom
+------
+.. autoclass:: py4DSTEM.Custom
+ :members:
+ :inherited-members:
+
+Data
+----
+.. autoclass:: py4DSTEM.Data
+ :members:
+ :inherited-members:
+
+DataCube
+--------
+.. autoclass:: py4DSTEM.DataCube
+ :members:
+ :inherited-members:
+
+DiffractionSlice
+----------------
+.. autoclass:: py4DSTEM.DiffractionSlice
+ :members:
+ :inherited-members:
+
+Metadata
+--------
+.. autoclass:: py4DSTEM.Metadata
+ :members:
+ :inherited-members:
+
+Node
+----
+.. autoclass:: py4DSTEM.Node
+ :members:
+ :inherited-members:
+
+PointList
+---------
+.. autoclass:: py4DSTEM.PointList
+ :members:
+ :inherited-members:
+
+PointListArray
+--------------
+.. autoclass:: py4DSTEM.PointListArray
+ :members:
+ :inherited-members:
+
+Probe
+-----
+.. autoclass:: py4DSTEM.Probe
+ :members:
+ :inherited-members:
+
+QPoints
+-------
+.. autoclass:: py4DSTEM.QPoints
+ :members:
+ :inherited-members:
+
+RealSlice
+---------
+.. autoclass:: py4DSTEM.RealSlice
+ :members:
+ :inherited-members:
+
+VirtualDiffraction
+------------------
+.. autoclass:: py4DSTEM.VirtualDiffraction
+ :members:
+ :inherited-members:
+
+VirtualImage
+------------
+.. autoclass:: py4DSTEM.VirtualImage
+ :members:
+ :inherited-members:
diff --git a/docs/source/api/emd.rst b/docs/source/api/emd.rst
new file mode 100644
index 000000000..8e4223496
--- /dev/null
+++ b/docs/source/api/emd.rst
@@ -0,0 +1,28 @@
+emd
+=======
+.. contents:: Table of Contents
+ :depth: 2
+
+Classes
+-------
+.. autoclass:: emdfile.Array
+.. autoclass:: emdfile.Custom
+.. autoclass:: emdfile.Metadata
+.. autoclass:: emdfile.Node
+.. autoclass:: emdfile.PointList
+.. autoclass:: emdfile.PointListArray
+.. autoclass:: emdfile.Root
+
+
+Functions
+---------
+.. autofunction:: emdfile._get_EMD_version
+.. autofunction:: emdfile._is_EMD_file
+.. autofunction:: emdfile._version_is_geq
+.. autofunction:: emdfile.dirname
+.. autofunction:: emdfile.join
+.. autofunction:: emdfile.print_h5_tree
+.. autofunction:: emdfile.read
+.. autofunction:: emdfile.save
+.. autofunction:: emdfile.set_author
+.. autofunction:: emdfile.tqdmnd
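+
+As a brief sketch of the core I/O cycle these functions provide (file names here are placeholders, and exact signatures should be checked against the emdfile documentation):
+
+.. code-block:: python
+
+ import emdfile
+
+ emdfile.print_h5_tree("data.h5")  # print the tree of an EMD (HDF5) file
+ data = emdfile.read("data.h5")    # read data from the file
+ emdfile.save("copy.h5", data)     # write it back to a new EMD file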
diff --git a/docs/source/api/io.rst b/docs/source/api/io.rst
new file mode 100644
index 000000000..031025f7a
--- /dev/null
+++ b/docs/source/api/io.rst
@@ -0,0 +1,48 @@
+io
+===
+
+.. contents:: Table of Contents
+ :depth: 2
+
+.. automodule:: py4DSTEM.io
+ :members:
+
+filereaders
+-----------
+.. automodule:: py4DSTEM.io.filereaders
+ :members:
+
+.. automodule:: py4DSTEM.io.filereaders.empad
+ :members:
+
+.. automodule:: py4DSTEM.io.filereaders.read_K2
+ :members:
+
+.. automodule:: py4DSTEM.io.filereaders.read_mib
+ :members:
+
+google_drive_downloader
+-----------------------
+.. automodule:: py4DSTEM.io.google_drive_downloader
+ :members:
+
+.. automodule:: py4DSTEM.io.google_drive_downloader.gdown
+ :members:
+
+importfile
+----------
+.. automodule:: py4DSTEM.io.importfile
+ :members:
+
+legacy
+------
+.. automodule:: py4DSTEM.io.legacy
+ :members:
+
+.. automodule:: py4DSTEM.io.legacy.h5py
+ :members:
+
+.. automodule:: py4DSTEM.io.legacy.legacy12
+ :members:
+
+.. automodule:: py4DSTEM.io.legacy.legacy13
+ :members:
+
+.. automodule:: py4DSTEM.io.legacy.read_legacy_12
+ :members:
+
+.. automodule:: py4DSTEM.io.legacy.read_legacy_13
+ :members:
+
+.. automodule:: py4DSTEM.io.legacy.read_utils
+ :members:
+
+parsefiletype
+-------------
+.. automodule:: py4DSTEM.io.parsefiletype
+ :members:
\ No newline at end of file
diff --git a/docs/source/api/io.rst.bk b/docs/source/api/io.rst.bk
new file mode 100644
index 000000000..da7d45a32
--- /dev/null
+++ b/docs/source/api/io.rst.bk
@@ -0,0 +1,22 @@
+io
+===
+
+
+.. contents:: Table of Contents
+ :depth: 2
+
+Reading Native Files
+--------------------
+.. autofunction:: py4DSTEM.io.read
+
+Reading External Files (e.g. dm4)
+---------------------------------
+.. autofunction:: py4DSTEM.io.import_file
+
+Saving Files
+------------
+.. autofunction:: py4DSTEM.io.save
+
+Downloading Files
+-----------------
+.. autofunction:: py4DSTEM.io.download_file_from_google_drive
+
+Get Available Files
+-------------------
+.. autofunction:: py4DSTEM.io.get_sample_data_ids
\ No newline at end of file
diff --git a/docs/source/api/preprocess.rst b/docs/source/api/preprocess.rst
new file mode 100644
index 000000000..06f483a8f
--- /dev/null
+++ b/docs/source/api/preprocess.rst
@@ -0,0 +1,25 @@
+preprocess
+==========
+
+.. contents:: Table of Contents
+ :depth: 2
+
+darkreference
+-------------
+.. automodule:: py4DSTEM.preprocess.darkreference
+ :members:
+
+electroncount
+-------------
+.. automodule:: py4DSTEM.preprocess.electroncount
+ :members:
+
+preprocess
+----------
+.. automodule:: py4DSTEM.preprocess.preprocess
+ :members:
+
+radialbkgrd
+-----------
+.. automodule:: py4DSTEM.preprocess.radialbkgrd
+ :members:
+
+utils
+-----
+.. automodule:: py4DSTEM.preprocess.utils
+ :members:
\ No newline at end of file
diff --git a/docs/source/api/process.rst b/docs/source/api/process.rst
new file mode 100644
index 000000000..87b00523b
--- /dev/null
+++ b/docs/source/api/process.rst
@@ -0,0 +1,159 @@
+process
+=======
+.. contents:: Table of Contents
+ :depth: 2
+
+.. automodule:: py4DSTEM.process
+ :members:
+
+calibration
+-----------
+.. automodule:: py4DSTEM.process.calibration
+ :members:
+
+.. automodule:: py4DSTEM.process.calibration.braggvectors
+ :members:
+
+.. automodule:: py4DSTEM.process.calibration.ellipse
+ :members:
+
+.. automodule:: py4DSTEM.process.calibration.origin
+ :members:
+
+.. automodule:: py4DSTEM.process.calibration.probe
+ :members:
+
+.. automodule:: py4DSTEM.process.calibration.qpixelsize
+ :members:
+
+.. automodule:: py4DSTEM.process.calibration.rotation
+ :members:
+
+classification
+--------------
+.. automodule:: py4DSTEM.process.classification
+ :members:
+
+.. automodule:: py4DSTEM.process.classification.braggvectorclassification
+ :members:
+
+.. automodule:: py4DSTEM.process.classification.classutils
+ :members:
+
+.. automodule:: py4DSTEM.process.classification.featurization
+ :members:
+
+diffraction
+-----------
+.. automodule:: py4DSTEM.process.diffraction
+ :members:
+
+.. automodule:: py4DSTEM.process.diffraction.WK_scattering_factors
+ :members:
+
+.. automodule:: py4DSTEM.process.diffraction.crystal
+ :members:
+
+.. automodule:: py4DSTEM.process.diffraction.crystal_ACOM
+ :members:
+
+.. automodule:: py4DSTEM.process.diffraction.crystal_bloch
+ :members:
+
+.. automodule:: py4DSTEM.process.diffraction.crystal_calibrate
+ :members:
+
+.. automodule:: py4DSTEM.process.diffraction.crystal_phase
+ :members:
+
+.. automodule:: py4DSTEM.process.diffraction.crystal_viz
+ :members:
+
+.. automodule:: py4DSTEM.process.diffraction.flowlines
+ :members:
+
+.. automodule:: py4DSTEM.process.diffraction.sys
+ :members:
+
+.. automodule:: py4DSTEM.process.diffraction.utils
+ :members:
+
+diskdetection
+-------------
+.. automodule:: py4DSTEM.process.diskdetection
+ :members:
+
+.. automodule:: py4DSTEM.process.diskdetection.braggvectormap
+ :members:
+
+.. automodule:: py4DSTEM.process.diskdetection.diskdetection
+ :members:
+
+.. automodule:: py4DSTEM.process.diskdetection.diskdetection_aiml
+ :members:
+
+.. automodule:: py4DSTEM.process.diskdetection.threshold
+ :members:
+
+fit
+---
+.. automodule:: py4DSTEM.process.fit
+ :members:
+
+.. automodule:: py4DSTEM.process.fit.fit
+ :members:
+
+latticevectors
+--------------
+.. automodule:: py4DSTEM.process.latticevectors
+ :members:
+
+.. automodule:: py4DSTEM.process.latticevectors.fit
+ :members:
+
+.. automodule:: py4DSTEM.process.latticevectors.index
+ :members:
+
+.. automodule:: py4DSTEM.process.latticevectors.initialguess
+ :members:
+
+.. automodule:: py4DSTEM.process.latticevectors.strain
+ :members:
+
+phase
+-----
+.. automodule:: py4DSTEM.process.phase
+ :members:
+
+.. automodule:: py4DSTEM.process.phase.iterative_base_class
+ :members:
+
+.. automodule:: py4DSTEM.process.phase.iterative_constraints
+ :members:
+
+.. automodule:: py4DSTEM.process.phase.iterative_dpc
+ :members:
+
+.. automodule:: py4DSTEM.process.phase.iterative_mixedstate_ptychography
+ :members:
+
+.. automodule:: py4DSTEM.process.phase.iterative_multislice_ptychography
+ :members:
+
+.. automodule:: py4DSTEM.process.phase.iterative_overlap_tomography
+ :members:
+
+.. automodule:: py4DSTEM.process.phase.iterative_parallax
+ :members:
+
+.. automodule:: py4DSTEM.process.phase.iterative_ptychography
+ :members:
+
+.. automodule:: py4DSTEM.process.phase.iterative_simultaneous_ptychography
+ :members:
+
+.. automodule:: py4DSTEM.process.phase.utils
+ :members:
+
+probe
+-----
+.. automodule:: py4DSTEM.process.probe
+ :members:
+
+.. automodule:: py4DSTEM.process.probe.kernel
+ :members:
+
+.. automodule:: py4DSTEM.process.probe.probe
+ :members:
+
+rdf
+---
+.. automodule:: py4DSTEM.process.rdf.amorph
+ :members:
+
+.. automodule:: py4DSTEM.process.rdf.rdf
+ :members:
+
+utils
+-----
+.. automodule:: py4DSTEM.process.utils
+ :members:
+
+.. automodule:: py4DSTEM.process.utils.cross_correlate
+ :members:
+
+.. automodule:: py4DSTEM.process.utils.elliptical_coords
+ :members:
+
+.. automodule:: py4DSTEM.process.utils.masks
+ :members:
+
+.. automodule:: py4DSTEM.process.utils.multicorr
+ :members:
+
+.. automodule:: py4DSTEM.process.utils.utils
+ :members:
+
+virtualdiffraction
+------------------
+.. automodule:: py4DSTEM.process.virtualdiffraction
+ :members:
+
+virtualimage
+------------
+.. automodule:: py4DSTEM.process.virtualimage
+ :members:
+
+wholepatternfit
+---------------
+.. automodule:: py4DSTEM.process.wholepatternfit
+ :members:
+
+.. automodule:: py4DSTEM.process.wholepatternfit.wp_models
+ :members:
+
+.. automodule:: py4DSTEM.process.wholepatternfit.wpf
+ :members:
+
+.. automodule:: py4DSTEM.process.wholepatternfit.wpf_viz
+ :members:
\ No newline at end of file
diff --git a/docs/source/api/py4DSTEM.rst b/docs/source/api/py4DSTEM.rst
new file mode 100644
index 000000000..1584d87dc
--- /dev/null
+++ b/docs/source/api/py4DSTEM.rst
@@ -0,0 +1,24 @@
+py4DSTEM
+========
+
+Shortcuts to regularly used functions and utilities are available at the top level of the package.
+
+.. contents:: Table of Contents
+ :depth: 2
+
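+A brief sketch of these shortcuts in a typical session (the file name and ``data_id`` below are placeholders, not shipped sample data):
+
+.. code-block:: python
+
+ import py4DSTEM
+
+ py4DSTEM.check_config()                    # check the installation and dependencies
+ py4DSTEM.print_h5_tree("experiment.h5")    # inspect the contents of a native file
+ datacube = py4DSTEM.read(                  # load one block of data from the file
+     "experiment.h5",
+     data_id="datacube",
+ )
+ py4DSTEM.show(datacube[0, 0])              # show the pattern at scan position (0, 0)
+ py4DSTEM.save("output.h5", datacube)       # write data back out to a native file
+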
+IO
+----
+.. autofunction:: py4DSTEM.read
+.. autofunction:: py4DSTEM.import_file
+.. autofunction:: py4DSTEM.save
+.. autofunction:: py4DSTEM.print_h5_tree
+
+Plotting
+--------
+.. autofunction:: py4DSTEM.show
+
+Utilities
+---------
+.. autofunction:: py4DSTEM.check_config
+.. autofunction:: py4DSTEM.join
+.. autofunction:: py4DSTEM.tqdmnd
\ No newline at end of file
diff --git a/docs/source/api/visualize.rst b/docs/source/api/visualize.rst
new file mode 100644
index 000000000..b9a1d6b47
--- /dev/null
+++ b/docs/source/api/visualize.rst
@@ -0,0 +1,101 @@
+visualize
+=========
+
+.. contents:: Table of Contents
+ :depth: 2
+
+show
+----
+.. autofunction:: py4DSTEM.visualize.show
+.. autofunction:: py4DSTEM.visualize.show_hist
+.. autofunction:: py4DSTEM.visualize.show_Q
+.. autofunction:: py4DSTEM.visualize.show_rectangles
+.. autofunction:: py4DSTEM.visualize.show_circles
+.. autofunction:: py4DSTEM.visualize.show_ellipses
+.. autofunction:: py4DSTEM.visualize.show_annuli
+.. autofunction:: py4DSTEM.visualize.show_points
+
+overlay
+-------
+.. autofunction:: py4DSTEM.visualize.overlay.add_annuli
+.. autofunction:: py4DSTEM.visualize.overlay.add_bragg_index_labels
+.. autofunction:: py4DSTEM.visualize.overlay.add_cartesian_grid
+.. autofunction:: py4DSTEM.visualize.overlay.add_circles
+.. autofunction:: py4DSTEM.visualize.overlay.add_ellipses
+.. autofunction:: py4DSTEM.visualize.overlay.add_grid_overlay
+.. autofunction:: py4DSTEM.visualize.overlay.add_pointlabels
+.. autofunction:: py4DSTEM.visualize.overlay.add_points
+.. autofunction:: py4DSTEM.visualize.overlay.add_polarelliptical_grid
+.. autofunction:: py4DSTEM.visualize.overlay.add_rectangles
+.. autofunction:: py4DSTEM.visualize.overlay.add_rtheta_grid
+.. autofunction:: py4DSTEM.visualize.overlay.add_scalebar
+.. autofunction:: py4DSTEM.visualize.overlay.add_vector
+.. autofunction:: py4DSTEM.visualize.overlay.get_nice_spacing
+.. autofunction:: py4DSTEM.visualize.overlay.is_color_like
+virtualimage
+------------
+.. autofunction:: py4DSTEM.visualize.virtualimage.position_detector
+.. autofunction:: py4DSTEM.visualize.virtualimage.show
+vis_RQ
+------
+.. autofunction:: py4DSTEM.visualize.vis_RQ.ax_addaxes
+.. autofunction:: py4DSTEM.visualize.vis_RQ.ax_addaxes_QtoR
+.. autofunction:: py4DSTEM.visualize.vis_RQ.ax_addaxes_RtoQ
+.. autofunction:: py4DSTEM.visualize.vis_RQ.ax_addvector
+.. autofunction:: py4DSTEM.visualize.vis_RQ.ax_addvector_QtoR
+.. autofunction:: py4DSTEM.visualize.vis_RQ.ax_addvector_RtoQ
+.. autofunction:: py4DSTEM.visualize.vis_RQ.get_Qvector_from_Rvector
+.. autofunction:: py4DSTEM.visualize.vis_RQ.get_Rvector_from_Qvector
+.. autofunction:: py4DSTEM.visualize.vis_RQ.show
+.. autofunction:: py4DSTEM.visualize.vis_RQ.show_RQ
+.. autofunction:: py4DSTEM.visualize.vis_RQ.show_RQ_axes
+.. autofunction:: py4DSTEM.visualize.vis_RQ.show_RQ_vector
+.. autofunction:: py4DSTEM.visualize.vis_RQ.show_RQ_vectors
+.. autofunction:: py4DSTEM.visualize.vis_RQ.show_points
+.. autofunction:: py4DSTEM.visualize.vis_RQ.show_selected_dp
+vis_grid
+--------
+.. autofunction:: py4DSTEM.visualize.vis_grid._show_grid_overlay
+.. autofunction:: py4DSTEM.visualize.vis_grid.add_grid_overlay
+.. autofunction:: py4DSTEM.visualize.vis_grid.show
+.. autofunction:: py4DSTEM.visualize.vis_grid.show_DP_grid
+.. autofunction:: py4DSTEM.visualize.vis_grid.show_grid_overlay
+.. autofunction:: py4DSTEM.visualize.vis_grid.show_image_grid
+.. autofunction:: py4DSTEM.visualize.vis_grid.show_points
+vis_special
+-----------
+.. autofunction:: py4DSTEM.visualize.vis_special.Complex2RGB
+.. autofunction:: py4DSTEM.visualize.vis_special.add_bragg_index_labels
+.. autofunction:: py4DSTEM.visualize.vis_special.add_ellipses
+.. autofunction:: py4DSTEM.visualize.vis_special.add_pointlabels
+.. autofunction:: py4DSTEM.visualize.vis_special.add_points
+.. autofunction:: py4DSTEM.visualize.vis_special.add_scalebar
+.. autofunction:: py4DSTEM.visualize.vis_special.add_vector
+.. autofunction:: py4DSTEM.visualize.vis_special.ax_addaxes
+.. autofunction:: py4DSTEM.visualize.vis_special.ax_addaxes_QtoR
+.. autofunction:: py4DSTEM.visualize.vis_special.convert_ellipse_params
+.. autofunction:: py4DSTEM.visualize.vis_special.double_sided_gaussian
+.. autofunction:: py4DSTEM.visualize.vis_special.get_selected_lattice_vectors
+.. autofunction:: py4DSTEM.visualize.vis_special.get_voronoi_vertices
+.. autofunction:: py4DSTEM.visualize.vis_special.hsv_to_rgb
+.. autofunction:: py4DSTEM.visualize.vis_special.make_axes_locatable
+.. autofunction:: py4DSTEM.visualize.vis_special.select_lattice_vectors
+.. autofunction:: py4DSTEM.visualize.vis_special.select_point
+.. autofunction:: py4DSTEM.visualize.vis_special.show
+.. autofunction:: py4DSTEM.visualize.vis_special.show_amorphous_ring_fit
+.. autofunction:: py4DSTEM.visualize.vis_special.show_bragg_indexing
+.. autofunction:: py4DSTEM.visualize.vis_special.show_class_BPs
+.. autofunction:: py4DSTEM.visualize.vis_special.show_class_BPs_grid
+.. autofunction:: py4DSTEM.visualize.vis_special.show_complex
+.. autofunction:: py4DSTEM.visualize.vis_special.show_elliptical_fit
+.. autofunction:: py4DSTEM.visualize.vis_special.show_image_grid
+.. autofunction:: py4DSTEM.visualize.vis_special.show_kernel
+.. autofunction:: py4DSTEM.visualize.vis_special.show_lattice_vectors
+.. autofunction:: py4DSTEM.visualize.vis_special.show_max_peak_spacing
+.. autofunction:: py4DSTEM.visualize.vis_special.show_origin_fit
+.. autofunction:: py4DSTEM.visualize.vis_special.show_origin_meas
+.. autofunction:: py4DSTEM.visualize.vis_special.show_pointlabels
+.. autofunction:: py4DSTEM.visualize.vis_special.show_qprofile
+.. autofunction:: py4DSTEM.visualize.vis_special.show_selected_dps
+.. autofunction:: py4DSTEM.visualize.vis_special.show_strain
+.. autofunction:: py4DSTEM.visualize.vis_special.show_voronoi
\ No newline at end of file
diff --git a/docs/source/apiindex.rst b/docs/source/apiindex.rst
new file mode 100644
index 000000000..40d39d5d9
--- /dev/null
+++ b/docs/source/apiindex.rst
@@ -0,0 +1,17 @@
+.. _apiindex:
+
+API Index
+=========
+
+.. toctree::
+ :maxdepth: 5
+
+ api/py4DSTEM
+ api/classes
+ api/io
+ api/preprocess
+ api/process
+ api/utils
+ api/version
+ api/visualize
+ api/emd
\ No newline at end of file
diff --git a/docs/source/conf.py b/docs/source/conf.py
new file mode 100644
index 000000000..6da66611e
--- /dev/null
+++ b/docs/source/conf.py
@@ -0,0 +1,102 @@
+# Configuration file for the Sphinx documentation builder.
+#
+# This file only contains a selection of the most common options. For a full
+# list see the documentation:
+# https://www.sphinx-doc.org/en/master/usage/configuration.html
+
+# -- Path setup --------------------------------------------------------------
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+#
+# import os
+# sys.path.insert(0, os.path.abspath('.'))
+
+import os
+import sys
+
+sys.path.insert(0, os.path.dirname(os.getcwd()))
+from py4DSTEM import __version__
+from datetime import datetime
+
+# -- Project information -----------------------------------------------------
+
+project = "py4dstem"
+copyright = f"{datetime.today().year}, py4DSTEM Development Team"
+author = "Ben Savitzky & Alex Rakowski"
+
+# The full version, including alpha/beta/rc tags
+# release = '0.14.0'
+release = f"{__version__}"
+
+
+# -- General configuration ---------------------------------------------------
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+extensions = [
+ "sphinx.ext.autodoc",
+ "sphinx.ext.napoleon",
+ "sphinx.ext.intersphinx",
+ "sphinx_rtd_theme",
+]
+
+# Other useful extensions
+# sphinx_copybutton
+# sphinx_toggleprompt
+# sphinx.ext.mathjax
+
+# Specify a standard user agent, as Sphinx default is blocked on some sites
+user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/108.0.0.0 Safari/537.36 Edg/108.0.1462.54"
+
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ["_templates"]
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# This pattern also affects html_static_path and html_extra_path.
+exclude_patterns = []
+
+# Set autodoc defaults
+autodoc_default_options = {
+ "members": True,
+ "member-order": "bysource",
+ "special-members": "__init__",
+}
+
+# Include todo items/lists
+todo_include_todos = True
+
+# autodoc_member_order = 'bysource'
+
+
+# intersphinx options
+
+# intersphinx_mapping = {
+# 'emdfile': ('https://pypi.org/project/emdfile/0.0.4/', None)
+# }
+
+# -- Options for HTML output -------------------------------------------------
+
+# The theme to use for HTML and HTML Help pages. See the documentation for
+# a list of builtin themes.
+#
+html_theme = "sphinx_rtd_theme"
+
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+html_static_path = ["../_static"]
+
+
+# The name of an image file (relative to this directory) to place at the top
+# of the sidebar.
+html_logo = "../_static/py4DSTEM_logo.png"
+
+# The name of an image file (within the static path) to use as favicon of the
+# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
+# pixels large.
+html_favicon = "../_static/py4DSTEM_logo_vsmall.ico"
diff --git a/docs/source/examples.rst b/docs/source/examples.rst
new file mode 100644
index 000000000..7501545e4
--- /dev/null
+++ b/docs/source/examples.rst
@@ -0,0 +1,69 @@
+.. _examples:
+
+Examples
+========
+
+First Steps
+-----------
+
+Once py4DSTEM has been successfully installed, you can start using it in Python the usual way. Jupyter notebooks are the most popular environment, but py4DSTEM can also be run from Python scripts, IPython, Spyder, etc.
+
+
+.. code-block:: python
+ :linenos:
+ :caption: Your first py4DSTEM script
+
+ # Import the needed packages
+ import py4DSTEM
+
+ # This line displays the current version of py4DSTEM:
+ py4DSTEM.__version__
+
+ # download the dataset
+ py4DSTEM.io.download_file_from_google_drive(
+ '1PmbCYosA1eYydWmmZebvf6uon9k_5g_S',
+ 'simulatedAuNanoplatelet_binned.h5'
+ )
+ file_data = "simulatedAuNanoplatelet_binned.h5"
+
+ # Load the data
+ datacube = py4DSTEM.io.read(
+ file_data,
+ data_id = 'polyAu_4DSTEM' # The file above has several blocks of data inside
+ )
+
+ # plot a diffraction pattern
+ py4DSTEM.show(
+ datacube[10,30],
+ intensity_range='absolute',
+ vmin=20,
+ vmax=200,
+ cmap='viridis',
+ )
+
+.. image:: ../_static/dp.png
+ :width: 400
+ :height: 400
+ :align: center
+
+|
+| Congratulations, you've just plotted your first diffraction pattern.
+| If you run into trouble, refer back to the installation instructions: :ref:`installation`. Remember to make sure you've activated the right :ref:`Python environment`.
+
+
+Next Steps
+----------
+
+.. image:: https://img.shields.io/badge/launch-binder-579aca.svg?logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAFkAAABZCAMAAABi1XidAAAB8lBMVEX///9XmsrmZYH1olJXmsr1olJXmsrmZYH1olJXmsr1olJXmsrmZYH1olL1olJXmsr1olJXmsrmZYH1olL1olJXmsrmZYH1olJXmsr1olL1olJXmsrmZYH1olL1olJXmsrmZYH1olL1olL0nFf1olJXmsrmZYH1olJXmsq8dZb1olJXmsrmZYH1olJXmspXmspXmsr1olL1olJXmsrmZYH1olJXmsr1olL1olJXmsrmZYH1olL1olLeaIVXmsrmZYH1olL1olL1olJXmsrmZYH1olLna31Xmsr1olJXmsr1olJXmsrmZYH1olLqoVr1olJXmsr1olJXmsrmZYH1olL1olKkfaPobXvviGabgadXmsqThKuofKHmZ4Dobnr1olJXmsr1olJXmspXmsr1olJXmsrfZ4TuhWn1olL1olJXmsqBi7X1olJXmspZmslbmMhbmsdemsVfl8ZgmsNim8Jpk8F0m7R4m7F5nLB6jbh7jbiDirOEibOGnKaMhq+PnaCVg6qWg6qegKaff6WhnpKofKGtnomxeZy3noG6dZi+n3vCcpPDcpPGn3bLb4/Mb47UbIrVa4rYoGjdaIbeaIXhoWHmZYHobXvpcHjqdHXreHLroVrsfG/uhGnuh2bwj2Hxk17yl1vzmljzm1j0nlX1olL3AJXWAAAAbXRSTlMAEBAQHx8gICAuLjAwMDw9PUBAQEpQUFBXV1hgYGBkcHBwcXl8gICAgoiIkJCQlJicnJ2goKCmqK+wsLC4usDAwMjP0NDQ1NbW3Nzg4ODi5+3v8PDw8/T09PX29vb39/f5+fr7+/z8/Pz9/v7+zczCxgAABC5JREFUeAHN1ul3k0UUBvCb1CTVpmpaitAGSLSpSuKCLWpbTKNJFGlcSMAFF63iUmRccNG6gLbuxkXU66JAUef/9LSpmXnyLr3T5AO/rzl5zj137p136BISy44fKJXuGN/d19PUfYeO67Znqtf2KH33Id1psXoFdW30sPZ1sMvs2D060AHqws4FHeJojLZqnw53cmfvg+XR8mC0OEjuxrXEkX5ydeVJLVIlV0e10PXk5k7dYeHu7Cj1j+49uKg7uLU61tGLw1lq27ugQYlclHC4bgv7VQ+TAyj5Zc/UjsPvs1sd5cWryWObtvWT2EPa4rtnWW3JkpjggEpbOsPr7F7EyNewtpBIslA7p43HCsnwooXTEc3UmPmCNn5lrqTJxy6nRmcavGZVt/3Da2pD5NHvsOHJCrdc1G2r3DITpU7yic7w/7Rxnjc0kt5GC4djiv2Sz3Fb2iEZg41/ddsFDoyuYrIkmFehz0HR2thPgQqMyQYb2OtB0WxsZ3BeG3+wpRb1vzl2UYBog8FfGhttFKjtAclnZYrRo9ryG9uG/FZQU4AEg8ZE9LjGMzTmqKXPLnlWVnIlQQTvxJf8ip7VgjZjyVPrjw1te5otM7RmP7xm+sK2Gv9I8Gi++BRbEkR9EBw8zRUcKxwp73xkaLiqQb+kGduJTNHG72zcW9LoJgqQxpP3/Tj//c3yB0tqzaml05/+orHLksVO+95kX7/7qgJvnjlrfr2Ggsyx0eoy9uPzN5SPd86aXggOsEKW2Prz7du3VID3/tzs/sSRs2w7ovVHKtjrX2pd7ZMlTxAYfBAL9jiDwfLkq55Tm7ifhMlTGPyCAs7RFRhn47JnlcB9RM5T97ASuZXIcVNuUDIndpDbdsfrqsOppeXl5Y+XVKdjFCTh+zGaVuj0d9zy05PPK3QzBamxdwtTCrzyg/2Rvf2EstUjordGwa/kx9mSJLr8mLLtCW8HHGJc2R5hS219IiF6PnTusOqcMl57gm0Z8kanKMAQg0qSyuZfn7zItsbGyO9QlnxY0eCuD1XL2ys/MsrQhltE7Ug0uFOzufJFE2PxBo/YAx8XPPdDwWN0MrDRYIZF0mSMKCNHgaIVFoBbNoLJ7tEQDKxGF0kcLQimojCZopv0OkNOyWCCg9XMVAi7ARJzQdM2QUh0gmBozjc3Skg6dSBRqDGYSUOu66Zg+I2fNZs/M3/f/Grl/XnyF1Gw3VKCez0PN5IUfFLqvgUN4C0qNqYs5YhPL+aVZYDE4IpUk57oSFnJm4FyCqqOE0jhY2SMyLFoo56zyo6becOS5UVDdj7Vih0zp+tcMhwRpBeLyqtIjlJKAIZSbI8SGSF3k0pA3mR5tHuwPFoa7N7reoq2bqCsAk1HqCu5uvI1n6JuRXI+S1Mco54YmYTwcn6Aeic+kssXi8XpXC4V3t7/ADuTNKaQJdScAAAAAElFTkSuQmCC
+ :target: https://mybinder.org/v2/gh/py4dstem/py4DSTEM_tutorials/main
+
+For a more extensive overview, check out the `tutorial GitHub repository <https://github.com/py4dstem/py4DSTEM_tutorials>`_ to see example notebooks demonstrating the features of py4DSTEM. These can be downloaded and run locally, or run in the browser using `binder <https://mybinder.org/v2/gh/py4dstem/py4DSTEM_tutorials/main>`_. Here are some example plots from different analyses you'll learn by running the tutorials.
+
+
+
+.. image:: ../_static/Advanced_plots.png
+ :width: 600
+ :align: center
+
+|
+|
\ No newline at end of file
diff --git a/docs/source/gplv3.txt b/docs/source/gplv3.txt
new file mode 100644
index 000000000..f288702d2
--- /dev/null
+++ b/docs/source/gplv3.txt
@@ -0,0 +1,674 @@
+ GNU GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc.
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+ Preamble
+
+ The GNU General Public License is a free, copyleft license for
+software and other kinds of works.
+
+ The licenses for most software and other practical works are designed
+to take away your freedom to share and change the works. By contrast,
+the GNU General Public License is intended to guarantee your freedom to
+share and change all versions of a program--to make sure it remains free
+software for all its users. We, the Free Software Foundation, use the
+GNU General Public License for most of our software; it applies also to
+any other work released this way by its authors. You can apply it to
+your programs, too.
+
+ When we speak of free software, we are referring to freedom, not
+price. Our General Public Licenses are designed to make sure that you
+have the freedom to distribute copies of free software (and charge for
+them if you wish), that you receive source code or can get it if you
+want it, that you can change the software or use pieces of it in new
+free programs, and that you know you can do these things.
+
+ To protect your rights, we need to prevent others from denying you
+these rights or asking you to surrender the rights. Therefore, you have
+certain responsibilities if you distribute copies of the software, or if
+you modify it: responsibilities to respect the freedom of others.
+
+ For example, if you distribute copies of such a program, whether
+gratis or for a fee, you must pass on to the recipients the same
+freedoms that you received. You must make sure that they, too, receive
+or can get the source code. And you must show them these terms so they
+know their rights.
+
+ Developers that use the GNU GPL protect your rights with two steps:
+(1) assert copyright on the software, and (2) offer you this License
+giving you legal permission to copy, distribute and/or modify it.
+
+ For the developers' and authors' protection, the GPL clearly explains
+that there is no warranty for this free software. For both users' and
+authors' sake, the GPL requires that modified versions be marked as
+changed, so that their problems will not be attributed erroneously to
+authors of previous versions.
+
+ Some devices are designed to deny users access to install or run
+modified versions of the software inside them, although the manufacturer
+can do so. This is fundamentally incompatible with the aim of
+protecting users' freedom to change the software. The systematic
+pattern of such abuse occurs in the area of products for individuals to
+use, which is precisely where it is most unacceptable. Therefore, we
+have designed this version of the GPL to prohibit the practice for those
+products. If such problems arise substantially in other domains, we
+stand ready to extend this provision to those domains in future versions
+of the GPL, as needed to protect the freedom of users.
+
+ Finally, every program is threatened constantly by software patents.
+States should not allow patents to restrict development and use of
+software on general-purpose computers, but in those that do, we wish to
+avoid the special danger that patents applied to a free program could
+make it effectively proprietary. To prevent this, the GPL assures that
+patents cannot be used to render the program non-free.
+
+ The precise terms and conditions for copying, distribution and
+modification follow.
+
+ TERMS AND CONDITIONS
+
+ 0. Definitions.
+
+ "This License" refers to version 3 of the GNU General Public License.
+
+ "Copyright" also means copyright-like laws that apply to other kinds of
+works, such as semiconductor masks.
+
+ "The Program" refers to any copyrightable work licensed under this
+License. Each licensee is addressed as "you". "Licensees" and
+"recipients" may be individuals or organizations.
+
+ To "modify" a work means to copy from or adapt all or part of the work
+in a fashion requiring copyright permission, other than the making of an
+exact copy. The resulting work is called a "modified version" of the
+earlier work or a work "based on" the earlier work.
+
+ A "covered work" means either the unmodified Program or a work based
+on the Program.
+
+ To "propagate" a work means to do anything with it that, without
+permission, would make you directly or secondarily liable for
+infringement under applicable copyright law, except executing it on a
+computer or modifying a private copy. Propagation includes copying,
+distribution (with or without modification), making available to the
+public, and in some countries other activities as well.
+
+ To "convey" a work means any kind of propagation that enables other
+parties to make or receive copies. Mere interaction with a user through
+a computer network, with no transfer of a copy, is not conveying.
+
+ An interactive user interface displays "Appropriate Legal Notices"
+to the extent that it includes a convenient and prominently visible
+feature that (1) displays an appropriate copyright notice, and (2)
+tells the user that there is no warranty for the work (except to the
+extent that warranties are provided), that licensees may convey the
+work under this License, and how to view a copy of this License. If
+the interface presents a list of user commands or options, such as a
+menu, a prominent item in the list meets this criterion.
+
+ 1. Source Code.
+
+ The "source code" for a work means the preferred form of the work
+for making modifications to it. "Object code" means any non-source
+form of a work.
+
+ A "Standard Interface" means an interface that either is an official
+standard defined by a recognized standards body, or, in the case of
+interfaces specified for a particular programming language, one that
+is widely used among developers working in that language.
+
+ The "System Libraries" of an executable work include anything, other
+than the work as a whole, that (a) is included in the normal form of
+packaging a Major Component, but which is not part of that Major
+Component, and (b) serves only to enable use of the work with that
+Major Component, or to implement a Standard Interface for which an
+implementation is available to the public in source code form. A
+"Major Component", in this context, means a major essential component
+(kernel, window system, and so on) of the specific operating system
+(if any) on which the executable work runs, or a compiler used to
+produce the work, or an object code interpreter used to run it.
+
+ The "Corresponding Source" for a work in object code form means all
+the source code needed to generate, install, and (for an executable
+work) run the object code and to modify the work, including scripts to
+control those activities. However, it does not include the work's
+System Libraries, or general-purpose tools or generally available free
+programs which are used unmodified in performing those activities but
+which are not part of the work. For example, Corresponding Source
+includes interface definition files associated with source files for
+the work, and the source code for shared libraries and dynamically
+linked subprograms that the work is specifically designed to require,
+such as by intimate data communication or control flow between those
+subprograms and other parts of the work.
+
+ The Corresponding Source need not include anything that users
+can regenerate automatically from other parts of the Corresponding
+Source.
+
+ The Corresponding Source for a work in source code form is that
+same work.
+
+ 2. Basic Permissions.
+
+ All rights granted under this License are granted for the term of
+copyright on the Program, and are irrevocable provided the stated
+conditions are met. This License explicitly affirms your unlimited
+permission to run the unmodified Program. The output from running a
+covered work is covered by this License only if the output, given its
+content, constitutes a covered work. This License acknowledges your
+rights of fair use or other equivalent, as provided by copyright law.
+
+ You may make, run and propagate covered works that you do not
+convey, without conditions so long as your license otherwise remains
+in force. You may convey covered works to others for the sole purpose
+of having them make modifications exclusively for you, or provide you
+with facilities for running those works, provided that you comply with
+the terms of this License in conveying all material for which you do
+not control copyright. Those thus making or running the covered works
+for you must do so exclusively on your behalf, under your direction
+and control, on terms that prohibit them from making any copies of
+your copyrighted material outside their relationship with you.
+
+ Conveying under any other circumstances is permitted solely under
+the conditions stated below. Sublicensing is not allowed; section 10
+makes it unnecessary.
+
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+ No covered work shall be deemed part of an effective technological
+measure under any applicable law fulfilling obligations under article
+11 of the WIPO copyright treaty adopted on 20 December 1996, or
+similar laws prohibiting or restricting circumvention of such
+measures.
+
+ When you convey a covered work, you waive any legal power to forbid
+circumvention of technological measures to the extent such circumvention
+is effected by exercising rights under this License with respect to
+the covered work, and you disclaim any intention to limit operation or
+modification of the work as a means of enforcing, against the work's
+users, your or third parties' legal rights to forbid circumvention of
+technological measures.
+
+ 4. Conveying Verbatim Copies.
+
+ You may convey verbatim copies of the Program's source code as you
+receive it, in any medium, provided that you conspicuously and
+appropriately publish on each copy an appropriate copyright notice;
+keep intact all notices stating that this License and any
+non-permissive terms added in accord with section 7 apply to the code;
+keep intact all notices of the absence of any warranty; and give all
+recipients a copy of this License along with the Program.
+
+ You may charge any price or no price for each copy that you convey,
+and you may offer support or warranty protection for a fee.
+
+ 5. Conveying Modified Source Versions.
+
+ You may convey a work based on the Program, or the modifications to
+produce it from the Program, in the form of source code under the
+terms of section 4, provided that you also meet all of these conditions:
+
+ a) The work must carry prominent notices stating that you modified
+ it, and giving a relevant date.
+
+ b) The work must carry prominent notices stating that it is
+ released under this License and any conditions added under section
+ 7. This requirement modifies the requirement in section 4 to
+ "keep intact all notices".
+
+ c) You must license the entire work, as a whole, under this
+ License to anyone who comes into possession of a copy. This
+ License will therefore apply, along with any applicable section 7
+ additional terms, to the whole of the work, and all its parts,
+ regardless of how they are packaged. This License gives no
+ permission to license the work in any other way, but it does not
+ invalidate such permission if you have separately received it.
+
+ d) If the work has interactive user interfaces, each must display
+ Appropriate Legal Notices; however, if the Program has interactive
+ interfaces that do not display Appropriate Legal Notices, your
+ work need not make them do so.
+
+ A compilation of a covered work with other separate and independent
+works, which are not by their nature extensions of the covered work,
+and which are not combined with it such as to form a larger program,
+in or on a volume of a storage or distribution medium, is called an
+"aggregate" if the compilation and its resulting copyright are not
+used to limit the access or legal rights of the compilation's users
+beyond what the individual works permit. Inclusion of a covered work
+in an aggregate does not cause this License to apply to the other
+parts of the aggregate.
+
+ 6. Conveying Non-Source Forms.
+
+ You may convey a covered work in object code form under the terms
+of sections 4 and 5, provided that you also convey the
+machine-readable Corresponding Source under the terms of this License,
+in one of these ways:
+
+ a) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by the
+ Corresponding Source fixed on a durable physical medium
+ customarily used for software interchange.
+
+ b) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by a
+ written offer, valid for at least three years and valid for as
+ long as you offer spare parts or customer support for that product
+ model, to give anyone who possesses the object code either (1) a
+ copy of the Corresponding Source for all the software in the
+ product that is covered by this License, on a durable physical
+ medium customarily used for software interchange, for a price no
+ more than your reasonable cost of physically performing this
+ conveying of source, or (2) access to copy the
+ Corresponding Source from a network server at no charge.
+
+ c) Convey individual copies of the object code with a copy of the
+ written offer to provide the Corresponding Source. This
+ alternative is allowed only occasionally and noncommercially, and
+ only if you received the object code with such an offer, in accord
+ with subsection 6b.
+
+ d) Convey the object code by offering access from a designated
+ place (gratis or for a charge), and offer equivalent access to the
+ Corresponding Source in the same way through the same place at no
+ further charge. You need not require recipients to copy the
+ Corresponding Source along with the object code. If the place to
+ copy the object code is a network server, the Corresponding Source
+ may be on a different server (operated by you or a third party)
+ that supports equivalent copying facilities, provided you maintain
+ clear directions next to the object code saying where to find the
+ Corresponding Source. Regardless of what server hosts the
+ Corresponding Source, you remain obligated to ensure that it is
+ available for as long as needed to satisfy these requirements.
+
+ e) Convey the object code using peer-to-peer transmission, provided
+ you inform other peers where the object code and Corresponding
+ Source of the work are being offered to the general public at no
+ charge under subsection 6d.
+
+ A separable portion of the object code, whose source code is excluded
+from the Corresponding Source as a System Library, need not be
+included in conveying the object code work.
+
+ A "User Product" is either (1) a "consumer product", which means any
+tangible personal property which is normally used for personal, family,
+or household purposes, or (2) anything designed or sold for incorporation
+into a dwelling. In determining whether a product is a consumer product,
+doubtful cases shall be resolved in favor of coverage. For a particular
+product received by a particular user, "normally used" refers to a
+typical or common use of that class of product, regardless of the status
+of the particular user or of the way in which the particular user
+actually uses, or expects or is expected to use, the product. A product
+is a consumer product regardless of whether the product has substantial
+commercial, industrial or non-consumer uses, unless such uses represent
+the only significant mode of use of the product.
+
+ "Installation Information" for a User Product means any methods,
+procedures, authorization keys, or other information required to install
+and execute modified versions of a covered work in that User Product from
+a modified version of its Corresponding Source. The information must
+suffice to ensure that the continued functioning of the modified object
+code is in no case prevented or interfered with solely because
+modification has been made.
+
+ If you convey an object code work under this section in, or with, or
+specifically for use in, a User Product, and the conveying occurs as
+part of a transaction in which the right of possession and use of the
+User Product is transferred to the recipient in perpetuity or for a
+fixed term (regardless of how the transaction is characterized), the
+Corresponding Source conveyed under this section must be accompanied
+by the Installation Information. But this requirement does not apply
+if neither you nor any third party retains the ability to install
+modified object code on the User Product (for example, the work has
+been installed in ROM).
+
+ The requirement to provide Installation Information does not include a
+requirement to continue to provide support service, warranty, or updates
+for a work that has been modified or installed by the recipient, or for
+the User Product in which it has been modified or installed. Access to a
+network may be denied when the modification itself materially and
+adversely affects the operation of the network or violates the rules and
+protocols for communication across the network.
+
+ Corresponding Source conveyed, and Installation Information provided,
+in accord with this section must be in a format that is publicly
+documented (and with an implementation available to the public in
+source code form), and must require no special password or key for
+unpacking, reading or copying.
+
+ 7. Additional Terms.
+
+ "Additional permissions" are terms that supplement the terms of this
+License by making exceptions from one or more of its conditions.
+Additional permissions that are applicable to the entire Program shall
+be treated as though they were included in this License, to the extent
+that they are valid under applicable law. If additional permissions
+apply only to part of the Program, that part may be used separately
+under those permissions, but the entire Program remains governed by
+this License without regard to the additional permissions.
+
+ When you convey a copy of a covered work, you may at your option
+remove any additional permissions from that copy, or from any part of
+it. (Additional permissions may be written to require their own
+removal in certain cases when you modify the work.) You may place
+additional permissions on material, added by you to a covered work,
+for which you have or can give appropriate copyright permission.
+
+ Notwithstanding any other provision of this License, for material you
+add to a covered work, you may (if authorized by the copyright holders of
+that material) supplement the terms of this License with terms:
+
+ a) Disclaiming warranty or limiting liability differently from the
+ terms of sections 15 and 16 of this License; or
+
+ b) Requiring preservation of specified reasonable legal notices or
+ author attributions in that material or in the Appropriate Legal
+ Notices displayed by works containing it; or
+
+ c) Prohibiting misrepresentation of the origin of that material, or
+ requiring that modified versions of such material be marked in
+ reasonable ways as different from the original version; or
+
+ d) Limiting the use for publicity purposes of names of licensors or
+ authors of the material; or
+
+ e) Declining to grant rights under trademark law for use of some
+ trade names, trademarks, or service marks; or
+
+ f) Requiring indemnification of licensors and authors of that
+ material by anyone who conveys the material (or modified versions of
+ it) with contractual assumptions of liability to the recipient, for
+ any liability that these contractual assumptions directly impose on
+ those licensors and authors.
+
+ All other non-permissive additional terms are considered "further
+restrictions" within the meaning of section 10. If the Program as you
+received it, or any part of it, contains a notice stating that it is
+governed by this License along with a term that is a further
+restriction, you may remove that term. If a license document contains
+a further restriction but permits relicensing or conveying under this
+License, you may add to a covered work material governed by the terms
+of that license document, provided that the further restriction does
+not survive such relicensing or conveying.
+
+ If you add terms to a covered work in accord with this section, you
+must place, in the relevant source files, a statement of the
+additional terms that apply to those files, or a notice indicating
+where to find the applicable terms.
+
+ Additional terms, permissive or non-permissive, may be stated in the
+form of a separately written license, or stated as exceptions;
+the above requirements apply either way.
+
+ 8. Termination.
+
+ You may not propagate or modify a covered work except as expressly
+provided under this License. Any attempt otherwise to propagate or
+modify it is void, and will automatically terminate your rights under
+this License (including any patent licenses granted under the third
+paragraph of section 11).
+
+ However, if you cease all violation of this License, then your
+license from a particular copyright holder is reinstated (a)
+provisionally, unless and until the copyright holder explicitly and
+finally terminates your license, and (b) permanently, if the copyright
+holder fails to notify you of the violation by some reasonable means
+prior to 60 days after the cessation.
+
+ Moreover, your license from a particular copyright holder is
+reinstated permanently if the copyright holder notifies you of the
+violation by some reasonable means, this is the first time you have
+received notice of violation of this License (for any work) from that
+copyright holder, and you cure the violation prior to 30 days after
+your receipt of the notice.
+
+ Termination of your rights under this section does not terminate the
+licenses of parties who have received copies or rights from you under
+this License. If your rights have been terminated and not permanently
+reinstated, you do not qualify to receive new licenses for the same
+material under section 10.
+
+ 9. Acceptance Not Required for Having Copies.
+
+ You are not required to accept this License in order to receive or
+run a copy of the Program. Ancillary propagation of a covered work
+occurring solely as a consequence of using peer-to-peer transmission
+to receive a copy likewise does not require acceptance. However,
+nothing other than this License grants you permission to propagate or
+modify any covered work. These actions infringe copyright if you do
+not accept this License. Therefore, by modifying or propagating a
+covered work, you indicate your acceptance of this License to do so.
+
+ 10. Automatic Licensing of Downstream Recipients.
+
+ Each time you convey a covered work, the recipient automatically
+receives a license from the original licensors, to run, modify and
+propagate that work, subject to this License. You are not responsible
+for enforcing compliance by third parties with this License.
+
+ An "entity transaction" is a transaction transferring control of an
+organization, or substantially all assets of one, or subdividing an
+organization, or merging organizations. If propagation of a covered
+work results from an entity transaction, each party to that
+transaction who receives a copy of the work also receives whatever
+licenses to the work the party's predecessor in interest had or could
+give under the previous paragraph, plus a right to possession of the
+Corresponding Source of the work from the predecessor in interest, if
+the predecessor has it or can get it with reasonable efforts.
+
+ You may not impose any further restrictions on the exercise of the
+rights granted or affirmed under this License. For example, you may
+not impose a license fee, royalty, or other charge for exercise of
+rights granted under this License, and you may not initiate litigation
+(including a cross-claim or counterclaim in a lawsuit) alleging that
+any patent claim is infringed by making, using, selling, offering for
+sale, or importing the Program or any portion of it.
+
+ 11. Patents.
+
+ A "contributor" is a copyright holder who authorizes use under this
+License of the Program or a work on which the Program is based. The
+work thus licensed is called the contributor's "contributor version".
+
+ A contributor's "essential patent claims" are all patent claims
+owned or controlled by the contributor, whether already acquired or
+hereafter acquired, that would be infringed by some manner, permitted
+by this License, of making, using, or selling its contributor version,
+but do not include claims that would be infringed only as a
+consequence of further modification of the contributor version. For
+purposes of this definition, "control" includes the right to grant
+patent sublicenses in a manner consistent with the requirements of
+this License.
+
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
+patent license under the contributor's essential patent claims, to
+make, use, sell, offer for sale, import and otherwise run, modify and
+propagate the contents of its contributor version.
+
+ In the following three paragraphs, a "patent license" is any express
+agreement or commitment, however denominated, not to enforce a patent
+(such as an express permission to practice a patent or covenant not to
+sue for patent infringement). To "grant" such a patent license to a
+party means to make such an agreement or commitment not to enforce a
+patent against the party.
+
+ If you convey a covered work, knowingly relying on a patent license,
+and the Corresponding Source of the work is not available for anyone
+to copy, free of charge and under the terms of this License, through a
+publicly available network server or other readily accessible means,
+then you must either (1) cause the Corresponding Source to be so
+available, or (2) arrange to deprive yourself of the benefit of the
+patent license for this particular work, or (3) arrange, in a manner
+consistent with the requirements of this License, to extend the patent
+license to downstream recipients. "Knowingly relying" means you have
+actual knowledge that, but for the patent license, your conveying the
+covered work in a country, or your recipient's use of the covered work
+in a country, would infringe one or more identifiable patents in that
+country that you have reason to believe are valid.
+
+ If, pursuant to or in connection with a single transaction or
+arrangement, you convey, or propagate by procuring conveyance of, a
+covered work, and grant a patent license to some of the parties
+receiving the covered work authorizing them to use, propagate, modify
+or convey a specific copy of the covered work, then the patent license
+you grant is automatically extended to all recipients of the covered
+work and works based on it.
+
+ A patent license is "discriminatory" if it does not include within
+the scope of its coverage, prohibits the exercise of, or is
+conditioned on the non-exercise of one or more of the rights that are
+specifically granted under this License. You may not convey a covered
+work if you are a party to an arrangement with a third party that is
+in the business of distributing software, under which you make payment
+to the third party based on the extent of your activity of conveying
+the work, and under which the third party grants, to any of the
+parties who would receive the covered work from you, a discriminatory
+patent license (a) in connection with copies of the covered work
+conveyed by you (or copies made from those copies), or (b) primarily
+for and in connection with specific products or compilations that
+contain the covered work, unless you entered into that arrangement,
+or that patent license was granted, prior to 28 March 2007.
+
+ Nothing in this License shall be construed as excluding or limiting
+any implied license or other defenses to infringement that may
+otherwise be available to you under applicable patent law.
+
+ 12. No Surrender of Others' Freedom.
+
+ If conditions are imposed on you (whether by court order, agreement or
+otherwise) that contradict the conditions of this License, they do not
+excuse you from the conditions of this License. If you cannot convey a
+covered work so as to satisfy simultaneously your obligations under this
+License and any other pertinent obligations, then as a consequence you may
+not convey it at all. For example, if you agree to terms that obligate you
+to collect a royalty for further conveying from those to whom you convey
+the Program, the only way you could satisfy both those terms and this
+License would be to refrain entirely from conveying the Program.
+
+ 13. Use with the GNU Affero General Public License.
+
+ Notwithstanding any other provision of this License, you have
+permission to link or combine any covered work with a work licensed
+under version 3 of the GNU Affero General Public License into a single
+combined work, and to convey the resulting work. The terms of this
+License will continue to apply to the part which is the covered work,
+but the special requirements of the GNU Affero General Public License,
+section 13, concerning interaction through a network will apply to the
+combination as such.
+
+ 14. Revised Versions of this License.
+
+ The Free Software Foundation may publish revised and/or new versions of
+the GNU General Public License from time to time. Such new versions will
+be similar in spirit to the present version, but may differ in detail to
+address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Program specifies that a certain numbered version of the GNU General
+Public License "or any later version" applies to it, you have the
+option of following the terms and conditions either of that numbered
+version or of any later version published by the Free Software
+Foundation. If the Program does not specify a version number of the
+GNU General Public License, you may choose any version ever published
+by the Free Software Foundation.
+
+ If the Program specifies that a proxy can decide which future
+versions of the GNU General Public License can be used, that proxy's
+public statement of acceptance of a version permanently authorizes you
+to choose that version for the Program.
+
+ Later license versions may give you additional or different
+permissions. However, no additional obligations are imposed on any
+author or copyright holder as a result of your choosing to follow a
+later version.
+
+ 15. Disclaimer of Warranty.
+
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
+APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
+HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
+OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
+THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
+IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
+ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
+
+ 16. Limitation of Liability.
+
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
+THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
+GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
+USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
+DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
+PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
+EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
+SUCH DAMAGES.
+
+ 17. Interpretation of Sections 15 and 16.
+
+ If the disclaimer of warranty and limitation of liability provided
+above cannot be given local legal effect according to their terms,
+reviewing courts shall apply local law that most closely approximates
+an absolute waiver of all civil liability in connection with the
+Program, unless a warranty or assumption of liability accompanies a
+copy of the Program in return for a fee.
+
+ END OF TERMS AND CONDITIONS
+
+ How to Apply These Terms to Your New Programs
+
+ If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these terms.
+
+ To do so, attach the following notices to the program. It is safest
+to attach them to the start of each source file to most effectively
+state the exclusion of warranty; and each file should have at least
+the "copyright" line and a pointer to where the full notice is found.
+
+    <one line to give the program's name and a brief idea of what it does.>
+    Copyright (C) <year>  <name of author>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+    along with this program.  If not, see <https://www.gnu.org/licenses/>.
+
+Also add information on how to contact you by electronic and paper mail.
+
+ If the program does terminal interaction, make it output a short
+notice like this when it starts in an interactive mode:
+
+    <program>  Copyright (C) <year>  <name of author>
+ This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
+ This is free software, and you are welcome to redistribute it
+ under certain conditions; type `show c' for details.
+
+The hypothetical commands `show w' and `show c' should show the appropriate
+parts of the General Public License. Of course, your program's commands
+might be different; for a GUI interface, you would use an "about box".
+
+ You should also get your employer (if you work as a programmer) or school,
+if any, to sign a "copyright disclaimer" for the program, if necessary.
+For more information on this, and how to apply and follow the GNU GPL, see
+<https://www.gnu.org/licenses/>.
+
+ The GNU General Public License does not permit incorporating your program
+into proprietary programs. If your program is a subroutine library, you
+may consider it more useful to permit linking proprietary applications with
+the library. If this is what you want to do, use the GNU Lesser General
+Public License instead of this License. But first, please read
+<https://www.gnu.org/licenses/why-not-lgpl.html>.
diff --git a/docs/source/gui.rst b/docs/source/gui.rst
new file mode 100644
index 000000000..df4f89d1b
--- /dev/null
+++ b/docs/source/gui.rst
@@ -0,0 +1,28 @@
+.. _gui:
+
+Graphical User Interface
+========================
+
+Overview
+--------
+
+There is a GUI for viewing and performing some basic analysis of your 4D-STEM dataset.
+This feature is currently in development and must be installed separately. For more details, check out the git repository `here <https://github.com/sezelt/py4D-browser>`_.
+
+.. image:: ../_static/demo.gif
+ :width: 800
+ :alt: py4DSTEM-Browser
+
+
+Installation
+------------
+
+Currently there are no pip or conda packages, so it must be installed in one of two ways::
+
+ git clone https://github.com/sezelt/py4D-browser.git
+ cd py4D-browser
+    python setup.py install
+
+Alternatively, ::
+
+ pip install git+https://github.com/sezelt/py4D-browser
diff --git a/docs/source/index.rst b/docs/source/index.rst
new file mode 100644
index 000000000..778cee2cd
--- /dev/null
+++ b/docs/source/index.rst
@@ -0,0 +1,39 @@
+Welcome to the py4DSTEM documentation
+=====================================
+
+``py4DSTEM`` is an open source set of python tools for processing and analysis of :ref:`four-dimensional scanning transmission electron microscopy (4D-STEM)<4dstem>` data.
+
+
+
+Contents
+^^^^^^^^
+
+.. toctree::
+ :maxdepth: 2
+
+ 4dstem
+ installation
+ examples
+ api
+ apiindex
+ gui
+ supportAndContributions
+ license
+ acknowledgements
+
+
+Indices and tables
+^^^^^^^^^^^^^^^^^^
+
+* :ref:`genindex`
+* :ref:`modindex`
+* :ref:`search`
+
+
+
+.. image:: ../../images/py4DSTEM_logo.png
+ :width: 250
+
+.. image:: ../../images/toyota_research_institute.png
+
+
diff --git a/docs/source/installation.rst b/docs/source/installation.rst
new file mode 100644
index 000000000..f3b3e31de
--- /dev/null
+++ b/docs/source/installation.rst
@@ -0,0 +1,487 @@
+.. _installation:
+
+Installation
+============
+
+.. contents:: Table of Contents
+ :depth: 4
+
+
+
+Setting up Python
+-----------------
+
+The recommended installation for py4DSTEM uses the `Anaconda <https://www.anaconda.com/>`_ Python distribution. Alternatives such as `Miniconda <https://docs.conda.io/en/latest/miniconda.html>`_, `Mamba <https://mamba.readthedocs.io/>`_, `pip virtualenv <https://virtualenv.pypa.io/>`_, and `poetry <https://python-poetry.org/>`_ will work, but here we assume the use of Anaconda. See :ref:`virtualenvironments` for more details.
+The instructions to download and install Anaconda can be found `here <https://docs.anaconda.com/anaconda/install/>`_.
+
+
+
+
+.. The overview of installation process is:
+
+.. * make a virtual environment (see below)
+.. * enter the environment
+.. * install py4DSTEM
+
+Recommended Installation
+------------------------
+
+There are three ways to install py4DSTEM:
+
+#. Anaconda (miniconda / mamba)
+#. Pip
+#. Installing from Source
+
+The easiest way to install py4DSTEM is to use the pre-packaged Anaconda version. Below is an overview of the installation process; OS-specific instructions follow.
+
+Anaconda
+********
+
+Windows
+^^^^^^^
+
+.. code-block:: shell
+ :linenos:
+ :caption: Windows base install
+
+ conda create -n py4dstem python=3.9
+ conda activate py4dstem
+ conda install -c conda-forge py4dstem
+ conda install -c conda-forge pywin32
+    # optional but recommended
+ conda install jupyterlab pymatgen
+
+Linux
+^^^^^
+
+.. code-block:: shell
+ :linenos:
+ :caption: Linux base install
+
+
+ conda create -n py4dstem python=3.9
+ conda activate py4dstem
+ conda install -c conda-forge py4dstem
+    # optional but recommended
+ conda install jupyterlab pymatgen
+
+Mac (Intel)
+^^^^^^^^^^^
+.. code-block:: shell
+ :linenos:
+ :caption: Intel Mac base install
+
+
+ conda create -n py4dstem python=3.9
+ conda activate py4dstem
+ conda install -c conda-forge py4dstem
+    # optional but recommended
+ conda install jupyterlab pymatgen
+
+Mac (Apple Silicon M1/M2)
+^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. code-block:: shell
+ :linenos:
+ :caption: Apple Silicon Mac base install
+
+
+ conda create -n py4dstem python=3.9
+ conda activate py4dstem
+ conda install pyqt hdf5
+ conda install -c conda-forge py4dstem
+    # optional but recommended
+ conda install jupyterlab pymatgen
+
+
+Advanced Installation
+---------------------
+
+Installing optional dependencies:
+*********************************
+
+Some of the features and modules require extra dependencies which can easily be installed using either Anaconda or Pip.
+
+Anaconda
+********
+
+Windows
+^^^^^^^
+
+.. code-block:: shell
+ :linenos:
+ :caption: Windows Anaconda install ACOM
+
+ conda create -n py4dstem python=3.9
+ conda activate py4dstem
+ conda install -c conda-forge py4dstem pymatgen
+ conda install -c conda-forge pywin32
+
+Running py4DSTEM code with GPU acceleration requires an NVIDIA GPU (AMD GPUs have beta support but haven't been tested) and NVIDIA drivers installed on the system.
+
+.. code-block:: shell
+ :linenos:
+ :caption: Windows Anaconda install GPU
+
+ conda create -n py4dstem python=3.9
+ conda activate py4dstem
+ conda install -c conda-forge py4dstem cupy cudatoolkit
+ conda install -c conda-forge pywin32
+
+
+To run the ML-AI features you must install TensorFlow; this can be done with CPU-only or GPU support.
+
+.. code-block:: shell
+ :linenos:
+ :caption: Windows Anaconda install ML-AI CPU
+
+ conda create -n py4dstem python=3.9
+ conda activate py4dstem
+ conda install -c conda-forge py4dstem
+    pip install tensorflow==2.4.1 "tensorflow-addons<=0.14" crystal4D
+ conda install -c conda-forge pywin32
+
+.. code-block:: shell
+ :linenos:
+ :caption: Windows Anaconda install ML-AI GPU
+
+ conda create -n py4dstem python=3.9
+ conda activate py4dstem
+ conda install -c conda-forge py4dstem
+ conda install -c conda-forge cupy cudatoolkit=11.0
+    pip install tensorflow==2.4.1 "tensorflow-addons<=0.14" crystal4D
+ conda install -c conda-forge pywin32
+
+
+
+Linux
+^^^^^
+
+.. code-block:: shell
+ :linenos:
+ :caption: Linux Anaconda install ACOM
+
+ conda create -n py4dstem python=3.9
+ conda activate py4dstem
+ conda install -c conda-forge py4dstem pymatgen
+
+Running py4DSTEM code with GPU acceleration requires an NVIDIA GPU (AMD GPUs have beta support but haven't been tested) and NVIDIA drivers installed on the system.
+
+.. code-block:: shell
+ :linenos:
+ :caption: Linux Anaconda install GPU
+
+ conda create -n py4dstem python=3.9
+ conda activate py4dstem
+ conda install -c conda-forge py4dstem cupy cudatoolkit
+
+
+To run the ML-AI features you must install TensorFlow; this can be done with CPU-only or GPU support.
+
+.. code-block:: shell
+ :linenos:
+ :caption: Linux Anaconda install ML-AI CPU
+
+ conda create -n py4dstem python=3.9
+ conda activate py4dstem
+ conda install -c conda-forge py4dstem
+    pip install tensorflow==2.4.1 "tensorflow-addons<=0.14" crystal4D
+
+.. code-block:: shell
+ :linenos:
+ :caption: Linux Anaconda install ML-AI GPU
+
+ conda create -n py4dstem python=3.9
+ conda activate py4dstem
+ conda install -c conda-forge py4dstem
+ conda install -c conda-forge cupy cudatoolkit=11.0
+    pip install tensorflow==2.4.1 "tensorflow-addons<=0.14" crystal4D
+
+
+
+Mac (Intel)
+^^^^^^^^^^^
+.. code-block:: shell
+ :linenos:
+ :caption: Intel Mac Anaconda install ACOM
+
+ conda create -n py4dstem python=3.9
+ conda activate py4dstem
+ conda install -c conda-forge py4dstem pymatgen
+
+
+TensorFlow does not support AMD GPUs, so while the ML-AI features can be run on an Intel Mac, they are not GPU accelerated.
+
+.. code-block:: shell
+ :linenos:
+ :caption: Intel Mac Anaconda install ML-AI CPU
+
+ conda create -n py4dstem python=3.9
+ conda activate py4dstem
+ conda install -c conda-forge py4dstem
+    pip install tensorflow==2.4.1 "tensorflow-addons<=0.14" crystal4D
+
+Mac (Apple Silicon M1/M2)
+^^^^^^^^^^^^^^^^^^^^^^^^^
+.. code-block:: shell
+ :linenos:
+ :caption: Apple Silicon Mac Anaconda install ACOM
+
+ conda create -n py4dstem python=3.9
+ conda activate py4dstem
+ conda install -c conda-forge py4dstem pymatgen
+
+
+
+TensorFlow's support for Apple silicon GPUs is limited. The steps below should enable GPU acceleration but have not been tested; the CPU-only install has been tested.
+
+.. code-block:: shell
+ :linenos:
+ :caption: Apple Silicon Mac Anaconda install ML-AI CPU
+
+ conda create -n py4dstem python=3.9
+ conda activate py4dstem
+ conda install -c conda-forge py4dstem
+    pip install tensorflow==2.4.1 "tensorflow-addons<=0.14" crystal4D
+
+.. Attention:: **GPU Accelerated Tensorflow on Apple Silicon**
+
+    This is an untested install method and it may not work. If you try it and face issues, please post an issue on `github <https://github.com/py4DSTEM/py4DSTEM/issues>`_.
+
+
+.. code-block:: shell
+ :linenos:
+ :caption: Apple Silicon Mac Anaconda install ML-AI GPU
+
+
+ conda create -n py4dstem python=3.9
+ conda activate py4dstem
+ conda install -c apple tensorflow-deps
+    pip install tensorflow-macos==2.5.0 "tensorflow-addons<=0.14" crystal4D tensorflow-metal
+ conda install -c conda-forge py4dstem
+
+
+
+Pip
+***
+
+Windows
+^^^^^^^
+
+.. code-block:: shell
+ :linenos:
+ :caption: Windows pip install ACOM
+
+ conda create -n py4dstem python=3.9
+ conda activate py4dstem
+    pip install "py4dstem[acom]"
+ conda install -c conda-forge pywin32
+
+Running py4DSTEM code with GPU acceleration requires an NVIDIA GPU (AMD GPUs have beta support but haven't been tested) and NVIDIA drivers installed on the system.
+
+.. code-block:: shell
+ :linenos:
+ :caption: Windows pip install GPU
+
+ conda create -n py4dstem python=3.9
+ conda activate py4dstem
+    pip install "py4dstem[cuda]"
+ conda install -c conda-forge pywin32
+
+
+To run the ML-AI features you must install TensorFlow; this can be done with CPU-only or GPU support.
+
+.. code-block:: shell
+ :linenos:
+ :caption: Windows pip install ML-AI CPU
+
+ conda create -n py4dstem python=3.9
+ conda activate py4dstem
+    pip install "py4dstem[aiml]"
+ conda install -c conda-forge pywin32
+
+.. code-block:: shell
+ :linenos:
+ :caption: Windows pip install ML-AI GPU
+
+ conda create -n py4dstem python=3.9
+ conda activate py4dstem
+ conda install -c conda-forge cudatoolkit=11.0
+    pip install "py4dstem[aiml-cuda]"
+ conda install -c conda-forge pywin32
+
+Linux
+^^^^^
+
+.. code-block:: shell
+ :linenos:
+ :caption: Linux pip install ACOM
+
+ conda create -n py4dstem python=3.9
+ conda activate py4dstem
+    pip install "py4dstem[acom]"
+
+Running py4DSTEM code with GPU acceleration requires an NVIDIA GPU (AMD GPUs have beta support but haven't been tested) and NVIDIA drivers installed on the system.
+
+.. code-block:: shell
+ :linenos:
+ :caption: Linux pip install GPU
+
+ conda create -n py4dstem python=3.9
+ conda activate py4dstem
+    pip install "py4dstem[cuda]"
+
+
+To run the ML-AI features you must install TensorFlow; this can be done with CPU-only or GPU support.
+
+.. code-block:: shell
+ :linenos:
+ :caption: Linux pip install ML-AI CPU
+
+ conda create -n py4dstem python=3.9
+ conda activate py4dstem
+    pip install "py4dstem[aiml]"
+
+.. code-block:: shell
+ :linenos:
+ :caption: Linux pip install ML-AI GPU
+
+ conda create -n py4dstem python=3.9
+ conda activate py4dstem
+ conda install -c conda-forge cudatoolkit=11.0
+    pip install "py4dstem[aiml-cuda]"
+
+Mac (Intel)
+^^^^^^^^^^^
+.. code-block:: shell
+ :linenos:
+ :caption: Intel Mac pip install ACOM
+
+ conda create -n py4dstem python=3.9
+ conda activate py4dstem
+    pip install "py4dstem[acom]"
+
+
+TensorFlow does not support AMD GPUs, so while the ML-AI features can be run on an Intel Mac, they are not GPU accelerated.
+
+.. code-block:: shell
+ :linenos:
+ :caption: Intel Mac pip install ML-AI CPU
+
+ conda create -n py4dstem python=3.9
+ conda activate py4dstem
+    pip install "py4dstem[aiml]"
+
+Mac (Apple Silicon M1/M2)
+^^^^^^^^^^^^^^^^^^^^^^^^^
+.. code-block:: shell
+ :linenos:
+ :caption: Apple Silicon Mac pip install ACOM
+
+ conda create -n py4dstem python=3.9
+ conda activate py4dstem
+    pip install "py4dstem[acom]"
+
+
+
+TensorFlow's support for Apple silicon GPUs is limited. The steps below should enable GPU acceleration but have not been tested; the CPU-only install has been tested.
+
+.. code-block:: shell
+ :linenos:
+    :caption: Apple Silicon Mac pip install ML-AI CPU
+
+ conda create -n py4dstem python=3.9
+ conda activate py4dstem
+    pip install "py4dstem[aiml]"
+
+.. Attention:: **GPU Accelerated Tensorflow on Apple Silicon**
+
+    This is an untested install method and it may not work. If you try it and face issues, please post an issue on `github <https://github.com/py4DSTEM/py4DSTEM/issues>`_.
+
+
+.. code-block:: shell
+ :linenos:
+    :caption: Apple Silicon Mac pip install ML-AI GPU
+
+
+ conda create -n py4dstem python=3.9
+ conda activate py4dstem
+ conda install -c apple tensorflow-deps
+    pip install tensorflow-macos==2.5.0 "tensorflow-addons<=0.14" crystal4D tensorflow-metal py4dstem
+
+
+Installing from Source
+**********************
+
+To check out the latest bleeding-edge features, or to contribute your own, you'll need to install py4DSTEM from source. Luckily this is easy and can be done by simply running:
+
+.. code-block:: shell
+ :linenos:
+
+    git clone https://github.com/py4DSTEM/py4DSTEM.git
+    cd py4DSTEM
+    git checkout <branch>  # e.g. git checkout dev
+ pip install -e .
+
+Alternatively, you can try the single-step method:
+
+.. code-block:: shell
+ :linenos:
+
+ pip install git+https://github.com/py4DSTEM/py4DSTEM.git@dev # install the dev branch
+
+
+Docker
+******
+
+Overview
+^^^^^^^^
+    "Docker is an open platform for developing, shipping, and running applications. Docker enables you to separate your applications from your infrastructure so you can deliver software quickly. With Docker, you can manage your infrastructure in the same ways you manage your applications. By taking advantage of Docker’s methodologies for shipping, testing, and deploying code quickly, you can significantly reduce the delay between writing code and running it in production."
+    cf. the `Docker overview <https://docs.docker.com/get-started/overview/>`_
+
+Installation
+^^^^^^^^^^^^
+
+There are py4DSTEM Docker images available on Docker Hub, which can be pulled and run or built upon. Check out the `Docker Hub repository <https://hub.docker.com/r/arakowsk/py4dstem>`_ to see all the available versions, or simply run the commands below to get the latest version.
+While Docker is extremely powerful and aims to greatly simplify deploying software, it is also a complex and nuanced topic. If you are interested in using it and are having trouble getting it to work, please file an issue on GitHub.
+To use Docker you'll first need to `install Docker <https://docs.docker.com/get-docker/>`_, after which you can run the images with the following commands.
+
+.. code-block:: shell
+ :linenos:
+
+    docker pull arakowsk/py4dstem:latest
+    docker run -it arakowsk/py4dstem:latest
+
+Alternatively, you can use `Docker Desktop <https://www.docker.com/products/docker-desktop/>`_, a GUI for Docker, which may be an easier way to run the images for less experienced users.
+
+
+Troubleshooting
+---------------
+
+If you face any issues, see the common errors below; if there's no solution, please file an issue on the `git repository <https://github.com/py4DSTEM/py4DSTEM/issues>`_.
+
+Some common errors:
+
+- make sure you've activated the right environment
+- when installing optional sub-packages, the quotation marks can be tricky depending on your OS, terminal, etc.; see the example below
+- GPU drivers: make sure your NVIDIA driver, CUDA toolkit, and cupy versions are compatible with one another
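+
+For example, quoting stops the shell from treating ``[ ]`` as globbing characters (in zsh, the macOS default shell) or ``<=`` as a redirection (in bash and zsh). The package names below are just illustrations:
+
+.. code-block:: shell
+    :linenos:
+    :caption: Quoting optional dependencies
+
+    pip install "py4dstem[acom]"
+    pip install "tensorflow-addons<=0.14"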
+
+
+
+.. _virtualenvironments:
+
+Virtual Environments
+--------------------
+
+.. Attention:: **Virtual environments**
+
+ A Python virtual environment is its own siloed version of Python, with its own set of packages and modules, kept separate from any other Python installations on your system.
+ In the instructions above, we created a virtual environment to make sure packages that have different dependencies don't conflict with one another.
+ For instance, as of this writing, some of the scientific Python packages don't work well with Python 3.9 - but you might have some other applications on your computer that *need* Python 3.9.
+ Using virtual environments solves this problem.
+ In this example, we're creating and navigating virtual environments using Anaconda.
+
+ Because these directions install py4DSTEM to its own virtual environment, each time you want to use py4DSTEM, you'll need to activate this environment.
+
+ * In the command line, you can do this with ``conda activate py4dstem``.
+ * In the Anaconda Navigator, you can do this by clicking on the Environments tab and then clicking on ``py4dstem``.
\ No newline at end of file
diff --git a/docs/source/license.rst b/docs/source/license.rst
new file mode 100644
index 000000000..c640aaa3e
--- /dev/null
+++ b/docs/source/license.rst
@@ -0,0 +1,16 @@
+.. _license:
+
+License
+=======
+
+py4DSTEM is released under the GNU GPL version 3 license.
+
+
+GPLv3
+^^^^^
+
+.. include:: gplv3.txt
+ :literal:
+
+
+
diff --git a/docs/source/references.rst b/docs/source/references.rst
new file mode 100644
index 000000000..e02d5c128
--- /dev/null
+++ b/docs/source/references.rst
@@ -0,0 +1,37 @@
+References
+==========
+
+.. * If you use py4DSTEM for a scientific study, please cite our open access `py4DSTEM publication`_ in Microscopy and Microanalysis.
+
+
+.. * py4DSTEM: A Software Package for Four-Dimensional Scanning Transmission Electron Microscopy Data Analysis
+
+.. .. image:: ../_static/DOI-BADGE-978-3-319-76207-4_15.svg
+.. :target: https://doi.org/10.1017/S1431927621000477
+
+.. * Check out the `Py4DSTEM Github`_
+
+.. * We'd like to thank The developers gratefully acknowledge the financial support of the Toyota Research Institute for the research and development time which made this project possible.
+
+.. .. image:: ../_static/toyota_research_institute.png
+
+.. * Additional funding has been provided by the US Department of Energy, Office of Science, Basic Energy Sciences.
+
+.. .. image:: ../_static/DOE_logo.png
+
+.. * You are also free to use the py4DSTEM logo in PDF format or logo in PNG format for presentations or posters.
+
+.. **********
+.. References
+.. **********
+
+.. .. target-notes::
+.. .. _`py4DSTEM publication`: https://doi.org/10.1017/S1431927621000477
+.. .. _`Py4DSTEM Github`: http://github.com/py4DSTEM/py4DSTEM
+
+
diff --git a/docs/source/supportAndContributions.rst b/docs/source/supportAndContributions.rst
new file mode 100644
index 000000000..f78c788d7
--- /dev/null
+++ b/docs/source/supportAndContributions.rst
@@ -0,0 +1,57 @@
+.. _supportAndContributions:
+
+Support & Contributions
+=======================
+
+
+Support
+-------
+
+Think you've found a bug or are facing issues using a feature? Please let us know by creating an issue on `github <https://github.com/py4DSTEM/py4DSTEM/issues>`_.
+
+
+Contributions
+-------------
+Looking to contribute? Awesome, we love people contributing, and it's a simple process.
+
+#. Submit a feature request on `github <https://github.com/py4DSTEM/py4DSTEM/issues>`_
+#. Follow the developer install instructions (see the sketch below)
+#. Make your alterations and document all functions (all code should be readable, so clarity beats cleverness)
+#. Submit a PR on github.
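+
+A minimal sketch of the developer install from step 2, assuming you clone the main repository (in practice you may want to clone your own fork instead)::
+
+    git clone https://github.com/py4DSTEM/py4DSTEM.git
+    cd py4DSTEM
+    pip install -e .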
+
+
diff --git a/images/DOE_logo.png b/images/DOE_logo.png
new file mode 100644
index 000000000..33681cbc0
Binary files /dev/null and b/images/DOE_logo.png differ
diff --git a/images/py4DSTEM_logo.png b/images/py4DSTEM_logo.png
new file mode 100644
index 000000000..1f9263f1e
Binary files /dev/null and b/images/py4DSTEM_logo.png differ
diff --git a/images/py4DSTEM_logo_54_export.png b/images/py4DSTEM_logo_54_export.png
new file mode 100644
index 000000000..b2ee84715
Binary files /dev/null and b/images/py4DSTEM_logo_54_export.png differ
diff --git a/images/toyota_research_institute.png b/images/toyota_research_institute.png
new file mode 100644
index 000000000..d06126fd6
Binary files /dev/null and b/images/toyota_research_institute.png differ
diff --git a/ncem_4D_stem_quickview.py b/ncem_4D_stem_quickview.py
deleted file mode 100644
index b25fb9fda..000000000
--- a/ncem_4D_stem_quickview.py
+++ /dev/null
@@ -1,170 +0,0 @@
-######## Viewer for 4D STEM data ########
-#
-# Defines a class -- Interactive4DSTEMDataViewer - enabling a simple GUI for
-# interacting with 4D STEM datasets.
-#
-# Relevant documentation for lower level code:
-# ScopeFoundry is a flexible package for both scientific data visualization and control of labrotory experiments. See http://www.scopefoundry.org/.
-# Qt is being run through Pyside/PySide2/PyQt/Qt for Python. See https://www.qt.io/qt-for-python.
-# pyqtgraph is a library which facilitates fast-running scientific visualization. See http://pyqtgraph.org/.
-
-
-from __future__ import division, print_function
-import numpy as np
-import sys
-from ScopeFoundry import BaseApp
-from ScopeFoundry.helper_funcs import load_qt_ui_file, sibling_path
-import pyqtgraph as pg
-import dm3_lib as dm3
-
-class Interactive4DSTEMDataViewer(BaseApp):
- """
- Interactive4DSTEMDataViewer objects inherit from the ScopeFoundry.BaseApp class.
- ScopeFoundry.BaseApp objects inherit from the QtCore.QObject class.
- Additional functionality is provided by pyqtgraph widgets.
-
- The class is used by instantiating and then entering the main Qt loop with, e.g.:
- app = Interactive4DSTEMDataViewer(sys.argv)
- app.exec_()
- """
- def setup(self):
-
- """
- Sets up the interface.
-
- Includes three primary windows:
- -Diffraction space view (detector space)
- -Real space view (scan positions + virtual detectors)
- -iPython Console
-
- Note that the diffraction space window also contains dialogs for basic user inputs.
- (i.e. file loading, etc.)
- """
-
- # Load the main user interface window
- self.ui = load_qt_ui_file(sibling_path(__file__, "quick_view_gui.ui"))
- self.ui.show()
- self.ui.raise_()
-
- # Create new self.settings fields
- self.settings.New('data_filename',dtype='file')
- self.settings.New('stem_Nx', dtype=int, initial=1)
- self.settings.New('stem_Ny', dtype=int, initial=1)
-
- # Methods to be run when UI widgets are changed
- self.settings.data_filename.updated_value.connect(self.on_change_data_filename)
- self.settings.stem_Nx.updated_value.connect(self.on_change_stem_nx)
- self.settings.stem_Ny.updated_value.connect(self.on_change_stem_ny)
-
- # Connect UI changes to updates in self.settings
- self.settings.data_filename.connect_to_browse_widgets(self.ui.data_filename_lineEdit, self.ui.data_filename_browse_pushButton)
- self.settings.stem_Nx.connect_bidir_to_widget(self.ui.stem_Nx_doubleSpinBox)
- self.settings.stem_Ny.connect_bidir_to_widget(self.ui.stem_Ny_doubleSpinBox)
-
- # Create and set up display of diffraction patterns
- self.stack_imv = pg.ImageView()
- self.stack_imv.setImage(self.stack_data.swapaxes(1,2))
- self.ui.stack_groupBox.layout().addWidget(self.stack_imv)
-
- # Create and set up display in real space
- self.stem_imv = pg.ImageView()
- self.stem_imv.setImage(self.data4D.sum(axis=(2,3)).T)
- self.stem_pt_roi = pg_point_roi(self.stem_imv.getView())
- self.stem_pt_roi.sigRegionChanged.connect(self.on_stem_pt_roi_change)
- self.virtual_aperture_roi = pg.RectROI([self.ccd_Nx/2, self.ccd_Ny/2], [50,50], pen=(3,9))
- self.stack_imv.getView().addItem(self.virtual_aperture_roi)
- self.virtual_aperture_roi.sigRegionChanged.connect(self.on_virtual_aperture_roi_change)
- self.stem_imv.setWindowTitle('STEM image')
- self.stem_imv.show()
-
- # Make a iPython Console widget
- self.console_widget.show()
-
- # Arrange windows and set their geometries
- px = 600
- self.ui.setGeometry(0,0,px,2*px)
- self.stem_imv.setGeometry(px,0,px,px)
- self.console_widget.setGeometry(px,1.11*px,px,px)
- self.stack_imv.activateWindow()
- self.stack_imv.raise_()
- self.stem_imv.raise_()
- self.console_widget.raise_()
-
-
- #### Methods controlling responses to user inputs ####
-
- def on_change_data_filename(self):
- fname = self.settings.data_filename.val
- print("Loading file",fname)
-
- try:
- self.dm3f = dm3.DM3(fname, debug=True)
- self.stack_data = self.dm3f.imagedata
- except Exception as err:
- print("Failed to load", err)
- self.stack_data = np.random.rand(100,512,512)
- self.stem_N, self.ccd_Ny, self.ccd_Nx = self.stack_data.shape
- if hasattr(self, 'stem_pt_roi'):
- self.on_stem_pt_roi_change()
-
- self.stack_imv.setImage(self.stack_data.swapaxes(1,2))
-
- self.settings.stem_Nx.update_value(1)
- self.settings.stem_Ny.update_value(self.stem_N)
-
- def on_change_stem_nx(self):
- stem_Nx = self.settings.stem_Nx.val
- self.settings.stem_Ny.update_value(int(self.stem_N/stem_Nx))
- stem_Ny = self. settings.stem_Ny.val
- self.data4D = self.stack_data.reshape(stem_Ny,stem_Nx,self.ccd_Ny,self.ccd_Nx)
- if hasattr(self, "virtual_aperture_roi"):
- self.on_virtual_aperture_roi_change()
-
- def on_change_stem_ny(self):
- stem_Ny = self.settings.stem_Ny.val
- self.settings.stem_Nx.update_value(int(self.stem_N/stem_Ny))
- stem_Nx = self.settings.stem_Nx.val
- self.data4D = self.stack_data.reshape(stem_Ny,stem_Nx,self.ccd_Ny,self.ccd_Nx)
- if hasattr(self, "virtual_aperture_roi"):
- self.on_virtual_aperture_roi_change()
-
- def on_stem_pt_roi_change(self):
- roi_state = self.stem_pt_roi.saveState()
- x0,y0 = roi_state['pos']
- xc,yc = x0+1, y0+1
- stack_num = self.settings.stem_Nx.val*int(yc)+int(xc)
- self.stack_imv.setCurrentIndex(stack_num)
-
- def on_virtual_aperture_roi_change(self):
- roi_state = self.virtual_aperture_roi.saveState()
- x0,y0 = roi_state['pos']
- slices, transforms = self.virtual_aperture_roi.getArraySlice(self.stack_data, self.stack_imv.getImageItem())
- slice_x, slice_y, slice_z = slices
- self.stem_imv.setImage(self.data4D[:,:,slice_y, slice_x].sum(axis=(2,3)).T)
-
-############### End of class ###############
-
-
-def pg_point_roi(view_box):
- """
- Utility function for point selection.
- Based in pyqtgraph, and returns a pyqtgraph CircleROI object.
- This object has a sigRegionChanged.connect() signal method to connect to other functions.
- """
- circ_roi = pg.CircleROI( (0,0), (2,2), movable=True, pen=(0,9))
- h = circ_roi.addTranslateHandle((0.5,0.5))
- h.pen = pg.mkPen('r')
- h.update()
- view_box.addItem(circ_roi)
- circ_roi.removeHandle(0)
- return circ_roi
-
-
-
-if __name__=="__main__":
- app = Interactive4DSTEMDataViewer(sys.argv)
-
- sys.exit(app.exec_())
-
-
-
diff --git a/py4DSTEM/__init__.py b/py4DSTEM/__init__.py
new file mode 100644
index 000000000..d5df63f5e
--- /dev/null
+++ b/py4DSTEM/__init__.py
@@ -0,0 +1,93 @@
+from py4DSTEM.version import __version__
+from emdfile import tqdmnd
+
+
+### io
+
+# substructure
+from emdfile import (
+ Node,
+ Root,
+ Metadata,
+ Array,
+ PointList,
+ PointListArray,
+ Custom,
+ print_h5_tree,
+)
+
+_emd_hook = True
+
+# structure
+from py4DSTEM import io
+from py4DSTEM.io import import_file, read, save
+
+
+### basic data classes
+
+# data
+from py4DSTEM.data import (
+ Data,
+ Calibration,
+ DiffractionSlice,
+ RealSlice,
+ QPoints,
+)
+
+# datacube
+from py4DSTEM.datacube import DataCube, VirtualImage, VirtualDiffraction
+
+
+### visualization
+
+from py4DSTEM import visualize
+from py4DSTEM.visualize import show, show_complex
+
+### analysis classes
+
+# braggvectors
+from py4DSTEM.braggvectors import (
+ Probe,
+ BraggVectors,
+ BraggVectorMap,
+)
+
+from py4DSTEM.process import classification
+
+
+# diffraction
+from py4DSTEM.process.diffraction import Crystal, Orientation
+
+
+# ptycho
+from py4DSTEM.process import phase
+
+
+# polar
+from py4DSTEM.process.polar import PolarDatacube
+
+
+# strain
+from py4DSTEM.process.strain.strain import StrainMap
+
+from py4DSTEM.process import wholepatternfit
+
+
+### more submodules
+# TODO
+
+from py4DSTEM import preprocess
+from py4DSTEM import process
+
+
+### utilities
+
+# config
+from py4DSTEM.utils.configuration_checker import check_config
+
+# TODO - config .toml
+
+# testing
+from os.path import dirname, join
+
+_TESTPATH = join(dirname(__file__), "../test/unit_test_data")
diff --git a/py4DSTEM/braggvectors/__init__.py b/py4DSTEM/braggvectors/__init__.py
new file mode 100644
index 000000000..482b1f31e
--- /dev/null
+++ b/py4DSTEM/braggvectors/__init__.py
@@ -0,0 +1,8 @@
+from py4DSTEM.braggvectors.probe import Probe
+from py4DSTEM.braggvectors.braggvectors import BraggVectors
+from py4DSTEM.braggvectors.braggvector_methods import BraggVectorMap
+from py4DSTEM.braggvectors.diskdetection import *
+from py4DSTEM.braggvectors.probe import *
+
+# from .diskdetection_aiml import *
+# from .diskdetection_parallel_new import *
diff --git a/py4DSTEM/braggvectors/braggvector_methods.py b/py4DSTEM/braggvectors/braggvector_methods.py
new file mode 100644
index 000000000..70a36dec1
--- /dev/null
+++ b/py4DSTEM/braggvectors/braggvector_methods.py
@@ -0,0 +1,867 @@
+# BraggVectors methods
+
+import inspect
+from warnings import warn
+
+import matplotlib.pyplot as plt
+import numpy as np
+from emdfile import Array, Metadata, _read_metadata, tqdmnd
+from py4DSTEM import show
+from py4DSTEM.datacube import VirtualImage
+from scipy.ndimage import gaussian_filter
+
+
+class BraggVectorMethods:
+ """
+ A container for BraggVector object instance methods
+ """
+
+ # 2D histogram
+
+ def histogram(
+ self,
+ mode="cal",
+ sampling=1,
+ weights=None,
+ weights_thresh=0.005,
+ ):
+ """
+ Returns a 2D histogram of Bragg vector intensities in diffraction space,
+ aka a Bragg vector map.
+
+ Parameters
+ ----------
+ mode : str
+ Must be 'cal' or 'raw'. Use the calibrated or raw vector positions.
+ sampling : number
+ The sampling rate of the histogram, in units of the camera's sampling.
+ `sampling = 2` upsamples and `sampling = 0.5` downsamples, each by a
+ factor of 2.
+ weights : None or array
+ If None, use all real space scan positions. Otherwise must be a real
+ space shaped array representing a weighting factor applied to vector
+            intensities from each scan position. If weights is boolean, uses beam
+            positions where weights is True. If weights is number-like, scales
+            the intensities by the values; positions where
+            weights < weights_thresh are skipped.
+
+ Returns
+ -------
+ BraggVectorHistogram
+ An Array with .data representing the data, and .dim[0] and .dim[1]
+ representing the sampling grid.
+ """
+ # get vectors
+ assert mode in ("cal", "raw"), f"Invalid mode {mode}!"
+ if mode == "cal":
+ v = self.cal
+ else:
+ v = self.raw
+
+ # condense vectors into a single array for speed,
+ # handling any weight factors
+ if weights is None:
+ vects = np.concatenate(
+ [
+ v[i, j].data
+ for i in range(self.Rshape[0])
+ for j in range(self.Rshape[1])
+ ]
+ )
+ elif weights.dtype == bool:
+ x, y = np.nonzero(weights)
+ vects = np.concatenate([v[i, j].data for i, j in zip(x, y)])
+ else:
+ l = []
+ x, y = np.nonzero(weights > weights_thresh)
+ for i, j in zip(x, y):
+ d = v[i, j].data
+ d["intensity"] *= weights[i, j]
+ l.append(d)
+ vects = np.concatenate(l)
+ # get the vectors
+ qx = vects["qx"]
+ qy = vects["qy"]
+ I = vects["intensity"]
+
+ # Set up bin grid
+ Q_Nx = np.round(self.Qshape[0] * sampling).astype(int)
+ Q_Ny = np.round(self.Qshape[1] * sampling).astype(int)
+
+ # transform vects onto bin grid
+ if mode == "raw":
+ qx *= sampling
+ qy *= sampling
+ # calibrated vects
+ # to tranform to the bingrid we ~undo the calibrations,
+ # then scale by the sampling factor
+ else:
+ # get pixel calibration
+ if self.calstate["pixel"] is True:
+ qpix = self.calibration.get_Q_pixel_size()
+ qx /= qpix
+ qy /= qpix
+ # origin calibration
+ if self.calstate["center"] is True:
+ origin = self.calibration.get_origin_mean()
+ qx += origin[0]
+ qy += origin[1]
+ # resample
+ qx *= sampling
+ qy *= sampling
+
+        # get the floor and ceiling integer coordinates bracketing each vector
+ floorx = np.floor(qx).astype(np.int64)
+ ceilx = np.ceil(qx).astype(np.int64)
+ floory = np.floor(qy).astype(np.int64)
+ ceily = np.ceil(qy).astype(np.int64)
+
+ # Remove any points outside the bin grid
+ mask = np.logical_and.reduce(
+ ((floorx >= 0), (floory >= 0), (ceilx < Q_Nx), (ceily < Q_Ny))
+ )
+ qx = qx[mask]
+ qy = qy[mask]
+ I = I[mask]
+ floorx = floorx[mask]
+ floory = floory[mask]
+ ceilx = ceilx[mask]
+ ceily = ceily[mask]
+
+ # Interpolate values
+ dx = qx - floorx
+ dy = qy - floory
+ # Compute indices of the 4 neighbors to (qx,qy)
+ # floor x, floor y
+ inds00 = np.ravel_multi_index([floorx, floory], (Q_Nx, Q_Ny))
+ # floor x, ceil y
+ inds01 = np.ravel_multi_index([floorx, ceily], (Q_Nx, Q_Ny))
+ # ceil x, floor y
+ inds10 = np.ravel_multi_index([ceilx, floory], (Q_Nx, Q_Ny))
+ # ceil x, ceil y
+ inds11 = np.ravel_multi_index([ceilx, ceily], (Q_Nx, Q_Ny))
+
+ # Compute the histogram by accumulating intensity in each
+ # neighbor weighted by linear interpolation
+ hist = (
+ np.bincount(inds00, I * (1.0 - dx) * (1.0 - dy), minlength=Q_Nx * Q_Ny)
+ + np.bincount(inds01, I * (1.0 - dx) * dy, minlength=Q_Nx * Q_Ny)
+ + np.bincount(inds10, I * dx * (1.0 - dy), minlength=Q_Nx * Q_Ny)
+ + np.bincount(inds11, I * dx * dy, minlength=Q_Nx * Q_Ny)
+ ).reshape(Q_Nx, Q_Ny)
+
+ # determine the resampled grid center and pixel size
+ if mode == "cal" and self.calstate["center"] is True:
+ x0 = sampling * origin[0]
+ y0 = sampling * origin[1]
+ else:
+ x0, y0 = 0, 0
+ if mode == "cal" and self.calstate["pixel"] is True:
+ pixelsize = qpix / sampling
+ else:
+ pixelsize = 1 / sampling
+ # find the dim vectors
+ dimx = (np.arange(Q_Nx) - x0) * pixelsize
+ dimy = (np.arange(Q_Ny) - y0) * pixelsize
+ dim_units = self.calibration.get_Q_pixel_units()
+
+ # wrap in a class
+ ans = BraggVectorMap(
+ name=f"2Dhist_{self.name}_{mode}_s={sampling}",
+ data=hist,
+ weights=weights,
+ dims=[dimx, dimy],
+ dim_units=dim_units,
+ origin=(x0, y0),
+ pixelsize=pixelsize,
+ )
+
+ # return
+ return ans
+
+ # aliases
+ get_bvm = get_bragg_vector_map = histogram
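+
+    # A minimal usage sketch (hypothetical names; assumes `bv` is a
+    # BraggVectors instance): a calibrated Bragg vector map, upsampled 2x,
+    # can be computed with
+    #
+    #     bvm = bv.histogram(mode="cal", sampling=2)
+    #
+    # where `bvm.data` holds the 2D histogram.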
+
+ # bragg virtual imaging
+
+ def get_virtual_image(
+ self,
+ mode=None,
+ geometry=None,
+ name="bragg_virtual_image",
+ returncalc=True,
+ center=True,
+ ellipse=True,
+ pixel=True,
+ rotate=True,
+ ):
+ """
+ Calculates a virtual image based on the values of the Braggvectors
+ integrated over some detector function determined by `mode` and
+ `geometry`.
+
+ Parameters
+ ----------
+ mode : str
+ defines the type of detector used. Options:
+ - 'circular', 'circle': uses round detector, like bright field
+ - 'annular', 'annulus': uses annular detector, like dark field
+ geometry : variable
+ expected value depends on the value of `mode`, as follows:
+ - 'circle', 'circular': nested 2-tuple, ((qx,qy),radius)
+ - 'annular' or 'annulus': nested 2-tuple,
+ ((qx,qy),(radius_i,radius_o))
+ Values can be in pixels or calibrated units. Note that (qx,qy)
+ can be skipped, which assumes peaks centered at (0,0).
+ center: bool
+ Apply calibration - center coordinate.
+ ellipse: bool
+ Apply calibration - elliptical correction.
+ pixel: bool
+ Apply calibration - pixel size.
+ rotate: bool
+ Apply calibration - QR rotation.
+
+ Returns
+ -------
+ virtual_im : VirtualImage
+ """
+
+ # parse inputs
+ circle_modes = ["circular", "circle"]
+ annulus_modes = ["annular", "annulus"]
+ modes = circle_modes + annulus_modes + [None]
+ assert mode in modes, f"Unrecognized mode {mode}"
+
+ # set geometry
+ if mode is None:
+ if geometry is None:
+ qxy_center = None
+ radial_range = np.array((0, np.inf))
+ else:
+ if len(geometry[0]) == 0:
+ qxy_center = None
+ else:
+ qxy_center = np.array(geometry[0])
+ if isinstance(geometry[1], int) or isinstance(geometry[1], float):
+ radial_range = np.array((0, geometry[1]))
+ elif len(geometry[1]) == 0:
+ radial_range = None
+ else:
+ radial_range = np.array(geometry[1])
+ elif mode == "circular" or mode == "circle":
+ radial_range = np.array((0, geometry[1]))
+ if len(geometry[0]) == 0:
+ qxy_center = None
+ else:
+ qxy_center = np.array(geometry[0])
+ elif mode == "annular" or mode == "annulus":
+ radial_range = np.array(geometry[1])
+ if len(geometry[0]) == 0:
+ qxy_center = None
+ else:
+ qxy_center = np.array(geometry[0])
+
+ # allocate space
+ im_virtual = np.zeros(self.shape)
+
+ # generate image
+ for rx, ry in tqdmnd(
+ self.shape[0],
+ self.shape[1],
+ ):
+ # Get user-specified Bragg vectors
+ p = self.get_vectors(
+ rx,
+ ry,
+ center=center,
+ ellipse=ellipse,
+ pixel=pixel,
+ rotate=rotate,
+ )
+
+ if p.data.shape[0] > 0:
+ if radial_range is None:
+ im_virtual[rx, ry] = np.sum(p.I)
+ else:
+ if qxy_center is None:
+ qr = np.hypot(p.qx, p.qy)
+ else:
+ qr = np.hypot(p.qx - qxy_center[0], p.qy - qxy_center[1])
+ sub = np.logical_and(qr >= radial_range[0], qr < radial_range[1])
+ if np.sum(sub) > 0:
+ im_virtual[rx, ry] = np.sum(p.I[sub])
+
+ # wrap in Virtual Image class
+ ans = VirtualImage(data=im_virtual, name=name)
+        # add generating params as metadata
+ ans.metadata = Metadata(
+ name="gen_params",
+ data={
+ "_calling_method": inspect.stack()[0][3],
+ "_calling_class": __class__.__name__,
+ "mode": mode,
+ "geometry": geometry,
+ "name": name,
+ "returncalc": returncalc,
+ },
+ )
+ # attach to the tree
+ self.attach(ans)
+
+ # return
+ if returncalc:
+ return ans
+
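+    # A minimal usage sketch (hypothetical names and values): a dark-field-like
+    # image, integrating Bragg intensities inside an annular detector centered
+    # at (qx0, qy0) with inner and outer radii (ri, ro):
+    #
+    #     im = bv.get_virtual_image(mode="annulus", geometry=((qx0, qy0), (ri, ro)))
+    #
+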
+ # calibration measurements
+
+ def measure_origin(
+ self,
+ center_guess=None,
+ score_method=None,
+ findcenter="max",
+ ):
+ """
+ Finds the diffraction shifts of the center beam using the raw Bragg
+ vector measurements.
+
+        If a center guess is not specified, a guess at the unscattered beam
+        position is first determined, either by taking the CoM of the Bragg vector
+        map or by taking its maximal pixel. Once an unscattered beam position is
+        determined, the Bragg peak closest to this position is identified. The
+        shifts in these peaks' positions from their average are returned as the
+        diffraction shifts.
+
+ Parameters
+ ----------
+ center_guess : 2-tuple
+ initial guess for the center
+ score_method : str
+ Method used to find center peak
+ - 'intensity': finds the most intense Bragg peak near the center
+ - 'distance': finds the closest Bragg peak to the center
+ - 'intensity weighted distance': determines center through a
+ combination of weighting distance and intensity
+        findcenter : str
+            Specifies the method for determining the unscattered beam position.
+            Options are 'CoM' or 'max'. Only used if center_guess is None.
+            'CoM' finds the center of mass of the Bragg vector map; 'max' uses
+            its brightest pixel.
+
+ Returns:
+ (3-tuple): A 3-tuple comprised of:
+
+ * **qx0** *((R_Nx,R_Ny)-shaped array)*: the origin x-coord
+ * **qy0** *((R_Nx,R_Ny)-shaped array)*: the origin y-coord
+ * **braggvectormap** *((R_Nx,R_Ny)-shaped array)*: the Bragg vector map of only
+ the Bragg peaks identified with the unscattered beam. Useful for diagnostic
+ purposes.
+ """
+        assert findcenter in ["CoM", "max"], "findcenter must be either 'CoM' or 'max'"
+ assert score_method in [
+ "distance",
+ "intensity",
+ "intensity weighted distance",
+ None,
+        ], "score_method must be 'distance', 'intensity', or 'intensity weighted distance'"
+
+ R_Nx, R_Ny = self.Rshape
+ Q_Nx, Q_Ny = self.Qshape
+
+ # Default scoring method
+ if score_method is None:
+ if center_guess is None:
+ score_method = "intensity"
+ else:
+ score_method = "distance"
+
+ # Get guess at position of unscattered beam (x0,y0)
+ if center_guess is None:
+ bvm = self.histogram(mode="raw")
+ if findcenter == "max":
+ x0, y0 = np.unravel_index(
+ np.argmax(gaussian_filter(bvm, 10)), (Q_Nx, Q_Ny)
+ )
+ else:
+ from py4DSTEM.process.utils import get_CoM
+
+ x0, y0 = get_CoM(bvm)
+ else:
+ x0, y0 = center_guess
+
+ # Get Bragg peak closest to unscattered beam at each scan position
+ qx0 = np.zeros(self.Rshape)
+ qy0 = np.zeros(self.Rshape)
+ mask = np.ones(self.Rshape, dtype=bool)
+ for Rx in range(R_Nx):
+ for Ry in range(R_Ny):
+ vects = self.raw[Rx, Ry].data
+ if len(vects) > 0:
+ if score_method == "distance":
+ r2 = (vects["qx"] - x0) ** 2 + (vects["qy"] - y0) ** 2
+ index = np.argmin(r2)
+ elif score_method == "intensity":
+ index = np.argmax(vects["intensity"])
+ elif score_method == "intensity weighted distance":
+ r2 = vects["intensity"] / (
+ 1 + ((vects["qx"] - x0) ** 2 + (vects["qy"] - y0) ** 2)
+ )
+ index = np.argmax(r2)
+ qx0[Rx, Ry] = vects["qx"][index]
+ qy0[Rx, Ry] = vects["qy"][index]
+ else:
+                    mask[Rx, Ry] = False
+ qx0[Rx, Ry] = x0
+ qy0[Rx, Ry] = y0
+
+ # set calibration metadata
+ self.calibration.set_origin_meas((qx0, qy0))
+ self.calibration.set_origin_meas_mask(mask)
+
+ # return
+ return qx0, qy0, mask
+
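+    # A minimal usage sketch (hypothetical name `bv`): measure per-position
+    # origins from the raw vectors, taking the brightest pixel of the Bragg
+    # vector map as the unscattered-beam guess:
+    #
+    #     qx0, qy0, mask = bv.measure_origin(findcenter="max")
+    #
+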
+ def measure_origin_beamstop(
+ self, center_guess, radii, max_dist=None, max_iter=1, **kwargs
+ ):
+ """
+ Find the origin from a set of braggpeaks assuming there is a beamstop, by identifying
+ pairs of conjugate peaks inside an annular region and finding their centers of mass.
+
+ Args:
+ center_guess (2-tuple): qx0,qy0
+ radii (2-tuple): the inner and outer radii of the specified annular region
+ max_dist (number): the maximum allowed distance between the reflection of two
+ peaks to consider them conjugate pairs
+ max_iter (integer): for values >1, repeats the algorithm after updating center_guess
+
+ Returns:
+ (2d masked array): the origins
+ """
+ R_Nx, R_Ny = self.Rshape
+ braggpeaks = self._v_uncal
+
+ if max_dist is None:
+ max_dist = radii[1]
+
+ # remove peaks outside the annulus
+ braggpeaks_masked = braggpeaks.copy()
+ for rx in range(R_Nx):
+ for ry in range(R_Ny):
+ pl = braggpeaks_masked[rx, ry]
+ qr = np.hypot(
+ pl.data["qx"] - center_guess[0], pl.data["qy"] - center_guess[1]
+ )
+ rm = np.logical_not(np.logical_and(qr >= radii[0], qr <= radii[1]))
+ pl.remove(rm)
+
+ # Find all matching conjugate pairs of peaks
+ center_curr = center_guess
+ for ii in range(max_iter):
+ centers = np.zeros((R_Nx, R_Ny, 2))
+ found_center = np.zeros((R_Nx, R_Ny), dtype=bool)
+ for rx in range(R_Nx):
+ for ry in range(R_Ny):
+ # Get data
+ pl = braggpeaks_masked[rx, ry]
+ is_paired = np.zeros(len(pl.data), dtype=bool)
+ matches = []
+
+ # Find matching pairs
+ for i in range(len(pl.data)):
+ if not is_paired[i]:
+ x, y = pl.data["qx"][i], pl.data["qy"][i]
+ x_r = -x + 2 * center_curr[0]
+ y_r = -y + 2 * center_curr[1]
+ dists = np.hypot(x_r - pl.data["qx"], y_r - pl.data["qy"])
+ dists[is_paired] = max_dist
+ matched = dists <= max_dist
+ if any(matched):
+ match = np.argmin(dists)
+ matches.append((i, match))
+ is_paired[i], is_paired[match] = True, True
+
+ # Find the center
+ if len(matches) > 0:
+ x0, y0 = [], []
+ for i in range(len(matches)):
+ x0.append(np.mean(pl.data["qx"][list(matches[i])]))
+ y0.append(np.mean(pl.data["qy"][list(matches[i])]))
+ x0, y0 = np.mean(x0), np.mean(y0)
+ centers[rx, ry, :] = x0, y0
+ found_center[rx, ry] = True
+ else:
+ found_center[rx, ry] = False
+
+ # Update current center guess
+ x0_curr = np.mean(centers[found_center, 0])
+ y0_curr = np.mean(centers[found_center, 1])
+ center_curr = x0_curr, y0_curr
+
+ # collect answers
+ mask = found_center
+ qx0, qy0 = centers[:, :, 0], centers[:, :, 1]
+
+ # set calibration metadata
+ self.calibration.set_origin_meas((qx0, qy0))
+ self.calibration.set_origin_meas_mask(mask)
+
+ # return
+ return qx0, qy0, mask
+
+ def fit_origin(
+ self,
+ mask=None,
+ fitfunction="plane",
+ robust=False,
+ robust_steps=3,
+ robust_thresh=2,
+ mask_check_data=True,
+ plot=True,
+ plot_range=None,
+ cmap="RdBu_r",
+ returncalc=True,
+ **kwargs,
+ ):
+ """
+ Fit origin of bragg vectors.
+
+ Args:
+ mask (2b boolean array, optional): ignore points where mask=True
+ fitfunction (str, optional): must be 'plane' or 'parabola' or 'bezier_two'
+ robust (bool, optional): If set to True, fit will be repeated with outliers
+ removed.
+ robust_steps (int, optional): Optional parameter. Number of robust iterations
+ performed after initial fit.
+ robust_thresh (int, optional): Threshold for including points, in units of
+ root-mean-square (standard deviations) error of the predicted values after
+ fitting.
+ mask_check_data (bool): Get mask from origin measurements equal to zero. (TODO - replace)
+ plot (bool, optional): plot results
+ plot_range (float): min and max color range for plot (pixels)
+ cmap (colormap): plotting colormap
+
+ Returns:
+            (variable): Return value depends on ``returncalc``. If ``returncalc``
+            is True (the default), returns a 4-tuple containing:
+
+ * **qx0_fit**: *(ndarray)* the fit origin x-position
+ * **qy0_fit**: *(ndarray)* the fit origin y-position
+ * **qx0_residuals**: *(ndarray)* the x-position fit residuals
+ * **qy0_residuals**: *(ndarray)* the y-position fit residuals
+ """
+ q_meas = self.calibration.get_origin_meas()
+
+ from py4DSTEM.process.calibration import fit_origin
+
+ if mask_check_data is True:
+ data_mask = np.logical_not(q_meas[0] == 0)
+ if mask is None:
+ mask = data_mask
+ else:
+ mask = np.logical_and(mask, data_mask)
+
+ qx0_fit, qy0_fit, qx0_residuals, qy0_residuals = fit_origin(
+ tuple(q_meas),
+ mask=mask,
+ fitfunction=fitfunction,
+ robust=robust,
+ robust_steps=robust_steps,
+ robust_thresh=robust_thresh,
+ )
+
+        # try to update calibration metadata
+ try:
+ self.calibration.set_origin((qx0_fit, qy0_fit))
+ self.setcal()
+ except AttributeError:
+ warn(
+ "No calibration found on this datacube - fit values are not being stored"
+ )
+ pass
+
+ # show
+ if plot:
+ self.show_origin_fit(
+ q_meas[0],
+ q_meas[1],
+ qx0_fit,
+ qy0_fit,
+ qx0_residuals,
+ qy0_residuals,
+ mask=mask,
+ plot_range=plot_range,
+ cmap=cmap,
+ **kwargs,
+ )
+
+ # return
+ if returncalc:
+ return qx0_fit, qy0_fit, qx0_residuals, qy0_residuals
+
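+    # A minimal usage sketch (hypothetical name `bv`): fit a plane to the
+    # measured origins with robust outlier rejection; when a calibration is
+    # attached, the fit is also stored in the calibration metadata:
+    #
+    #     qx0_fit, qy0_fit, qx0_res, qy0_res = bv.fit_origin(robust=True)
+    #
+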
+ def show_origin_fit(
+ self,
+ qx0_meas,
+ qy0_meas,
+ qx0_fit,
+ qy0_fit,
+ qx0_residuals,
+ qy0_residuals,
+ mask=None,
+ plot_range=None,
+ cmap="RdBu_r",
+ **kwargs,
+ ):
+ # apply mask
+ if mask is not None:
+ qx0_meas = np.ma.masked_array(qx0_meas, mask=np.logical_not(mask))
+ qy0_meas = np.ma.masked_array(qy0_meas, mask=np.logical_not(mask))
+ qx0_residuals = np.ma.masked_array(qx0_residuals, mask=np.logical_not(mask))
+ qy0_residuals = np.ma.masked_array(qy0_residuals, mask=np.logical_not(mask))
+ qx0_mean = np.mean(qx0_fit)
+ qy0_mean = np.mean(qy0_fit)
+
+ # set range
+ if plot_range is None:
+ plot_range = max(
+ (
+ 1.5 * np.max(np.abs(qx0_fit - qx0_mean)),
+ 1.5 * np.max(np.abs(qy0_fit - qy0_mean)),
+ )
+ )
+
+ # set figsize
+ imsize_ratio = np.sqrt(qx0_meas.shape[1] / qx0_meas.shape[0])
+ axsize = (3 * imsize_ratio, 3 / imsize_ratio)
+ axsize = kwargs.pop("axsize", axsize)
+
+ # plot
+ fig, ax = show(
+ [
+ [qx0_meas - qx0_mean, qx0_fit - qx0_mean, qx0_residuals],
+ [qy0_meas - qy0_mean, qy0_fit - qy0_mean, qy0_residuals],
+ ],
+ cmap=cmap,
+ axsize=axsize,
+ title=[
+ "measured origin, x",
+                "fit origin, x",
+ "residuals, x",
+ "measured origin, y",
+                "fit origin, y",
+ "residuals, y",
+ ],
+ vmin=-1 * plot_range,
+ vmax=1 * plot_range,
+ intensity_range="absolute",
+ show_cbar=True,
+ returnfig=True,
+ **kwargs,
+ )
+ plt.tight_layout()
+
+ return
+
+ def fit_p_ellipse(
+ self, bvm, center, fitradii, mask=None, returncalc=False, **kwargs
+ ):
+ """
+ Args:
+ bvm (BraggVectorMap): a 2D array used for ellipse fitting
+ center (2-tuple of floats): the center (x0,y0) of the annular fitting region
+ fitradii (2-tuple of floats): inner and outer radii (ri,ro) of the fit region
+            mask (2D array of bools, same shape as bvm): ignore data wherever mask==True
+
+ Returns:
+            p_ellipse if returncalc is True
+ """
+ from py4DSTEM.process.calibration import fit_ellipse_1D
+
+ p_ellipse = fit_ellipse_1D(bvm, center, fitradii, mask)
+
+ scaling = kwargs.get("scaling", "log")
+ kwargs.pop("scaling", None)
+ from py4DSTEM.visualize import show_elliptical_fit
+
+ show_elliptical_fit(bvm, fitradii, p_ellipse, scaling=scaling, **kwargs)
+
+ self.calibration.set_p_ellipse(p_ellipse)
+ self.setcal()
+
+ if returncalc:
+ return p_ellipse
+
+ def mask_in_Q(self, mask, update_inplace=False, returncalc=True):
+ """
+ Remove peaks which fall inside the diffraction shaped boolean array
+ `mask`, in raw (uncalibrated) peak positions.
+
+ Parameters
+ ----------
+ mask : 2d boolean array
+ The mask. Must be diffraction space shaped
+ update_inplace : bool
+ If False (default) copies this BraggVectors instance and
+ removes peaks from the copied instance. If True, removes
+ peaks from this instance.
+ returncalc : bool
+ Toggles returning the answer
+
+ Returns
+ -------
+ bvects : BraggVectors
+ """
+ # Copy peaks, if requested
+ if update_inplace:
+ v = self._v_uncal
+ else:
+ v = self._v_uncal.copy(name="_v_uncal")
+
+ # Loop and remove masked peaks
+ for rx in range(v.shape[0]):
+ for ry in range(v.shape[1]):
+ p = v[rx, ry]
+ xs = np.round(p.data["qx"]).astype(int)
+ ys = np.round(p.data["qy"]).astype(int)
+ sub = mask[xs, ys]
+ p.remove(sub)
+
+ # assign the return value
+ if update_inplace:
+ ans = self
+ else:
+ ans = self.copy(name=self.name + "_masked")
+ ans.set_raw_vectors(v)
+
+ # return
+ if returncalc:
+ return ans
+ else:
+ return
+
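+    # A minimal usage sketch (hypothetical names): drop peaks falling under a
+    # diffraction-space boolean mask (e.g. a beamstop region), returning a
+    # masked copy and leaving this instance unchanged:
+    #
+    #     bv_masked = bv.mask_in_Q(beamstop_mask, update_inplace=False)
+    #
+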
+ # alias
+ def get_masked_peaks(self, mask, update_inplace=False, returncalc=True):
+ """
+ Alias for `mask_in_Q`.
+ """
+ warn(
+ "`.get_masked_peaks` is deprecated and will be removed in a future version. Use `.mask_in_Q`"
+ )
+ return self.mask_in_Q(
+ mask=mask, update_inplace=update_inplace, returncalc=returncalc
+ )
+
+ def mask_in_R(self, mask, update_inplace=False, returncalc=True):
+ """
+ Remove peaks which fall inside the real space shaped boolean array
+ `mask`.
+
+ Parameters
+ ----------
+ mask : 2d boolean array
+ The mask. Must be real space shaped
+ update_inplace : bool
+ If False (default) copies this BraggVectors instance and
+ removes peaks from the copied instance. If True, removes
+ peaks from this instance.
+ returncalc : bool
+ Toggles returning the answer
+
+ Returns
+ -------
+ bvects : BraggVectors
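+
+ Examples
+ --------
+ A minimal sketch; the mask below is a placeholder:
+
+ >>> mask = np.zeros(braggvectors.Rshape, dtype=bool)
+ >>> mask[0, :] = True
+ >>> bvects = braggvectors.mask_in_R(mask)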
+ """
+ # Copy peaks, if requested
+ if update_inplace:
+ v = self._v_uncal
+ else:
+ v = self._v_uncal.copy(name="_v_uncal")
+
+ # Loop and remove masked peaks
+ for rx in range(v.shape[0]):
+ for ry in range(v.shape[1]):
+ if mask[rx, ry]:
+ p = v[rx, ry]
+ p.remove(np.ones(len(p), dtype=bool))
+
+ # assign the return value
+ if update_inplace:
+ ans = self
+ else:
+ ans = self.copy(name=self.name + "_masked")
+ ans.set_raw_vectors(v)
+
+ # return
+ if returncalc:
+ return ans
+ else:
+ return
+
+ def to_strainmap(self, name: str = None):
+ """
+ Generate a StrainMap object from the BraggVectors
+ equivalent to py4DSTEM.StrainMap(braggvectors=braggvectors)
+
+ Args:
+ name (str, optional): The name of the strainmap. Defaults to None which reverts to default name 'strainmap'.
+
+ Returns:
+ py4DSTEM.StrainMap: A py4DSTEM StrainMap object generated from the BraggVectors
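+
+ Example (illustrative):
+
+ >>> strainmap = braggvectors.to_strainmap()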
+ """
+ from py4DSTEM.process.strain import StrainMap
+
+ return StrainMap(self, name) if name else StrainMap(self)
+
+
+######### END BraggVectorMethods CLASS ########
+
+
+class BraggVectorMap(Array):
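+ """
+ An Array subclass holding a 2D map of Bragg vectors (e.g. a 2D histogram
+ of Bragg vector positions), with the grid origin, pixelsize, and sampling
+ weights stored in its metadata.
+ """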
+ def __init__(self, name, data, weights, dims, dim_units, origin, pixelsize):
+ Array.__init__(
+ self,
+ name=name,
+ data=data,
+ dims=dims,
+ dim_units=[dim_units, dim_units],
+ )
+ self.metadata = Metadata(
+ name="grid",
+ data={"origin": origin, "pixelsize": pixelsize, "weights": weights},
+ )
+
+ @property
+ def origin(self):
+ return self.metadata["grid"]["origin"]
+
+ @property
+ def pixelsize(self):
+ return self.metadata["grid"]["pixelsize"]
+
+ @property
+ def pixelunits(self):
+ return self.dim_units[0]
+
+ @property
+ def weights(self):
+ return self.metadata["grid"]["weights"]
+
+ # read
+ @classmethod
+ def _get_constructor_args(cls, group):
+ """
+ Returns a dictionary of args/values to pass to the class constructor
+ """
+ constr_args = Array._get_constructor_args(group)
+ metadata = _read_metadata(group, "grid")
+ args = {
+ "name": constr_args["name"],
+ "data": constr_args["data"],
+ "weights": metadata["weights"],
+ "dims": constr_args["dims"],
+ "dim_units": constr_args["dim_units"],
+ "origin": metadata["origin"],
+ "pixelsize": metadata["pixelsize"],
+ }
+ return args
diff --git a/py4DSTEM/braggvectors/braggvectors.py b/py4DSTEM/braggvectors/braggvectors.py
new file mode 100644
index 000000000..e81eeb62f
--- /dev/null
+++ b/py4DSTEM/braggvectors/braggvectors.py
@@ -0,0 +1,500 @@
+# Defines the BraggVectors class
+
+from py4DSTEM.data import Data
+from emdfile import Custom, PointListArray, PointList, Metadata
+from py4DSTEM.braggvectors.braggvector_methods import BraggVectorMethods
+from os.path import basename
+import numpy as np
+from warnings import warn
+
+
+class BraggVectors(Custom, BraggVectorMethods, Data):
+ """
+ Stores localized Bragg scattering positions and intensities
+ for a 4D-STEM datacube.
+
+ Raw (detector coordinate) vectors are accessible as
+
+ >>> braggvectors.raw[ scan_x, scan_y ]
+
+ and calibrated vectors as
+
+ >>> braggvectors.cal[ scan_x, scan_y ]
+
+ To set which calibrations are being applied, call
+
+ >>> braggvectors.setcal(
+ >>> center = bool,
+ >>> ellipse = bool,
+ >>> pixel = bool,
+ >>> rotate = bool
+ >>> )
+
+ If .setcal is not called, calibrations will be automatically selected
+ based on the contents of the instance's `calibrations` property. The
+ calibrations performed in the last call to `braggvectors.cal` are exposed as
+
+ >>> braggvectors.calstate
+
+ After grabbing some vectors
+
+ >>> vects = braggvectors.raw[ scan_x,scan_y ]
+
+ the values themselves are accessible as
+
+ >>> vects.qx,vects.qy,vects.I
+ >>> vects['qx'],vects['qy'],vects['intensity']
+
+ Alternatively, you can retrieve vectors with an arbitrary calibration
+ state using
+
+ >>> braggvectors.get_vectors(
+ >>> scan_x,
+ >>> scan_y,
+ >>> center = bool,
+ >>> ellipse = bool,
+ >>> pixel = bool,
+ >>> rotate = bool
+ >>> )
+
+ which will return the vectors at scan position (scan_x,scan_y) with the
+ requested calibrations applied.
+ """
+
+ def __init__(
+ self, Rshape, Qshape, name="braggvectors", verbose=False, calibration=None
+ ):
+ Custom.__init__(self, name=name)
+ Data.__init__(self, calibration=calibration)
+
+ self.Rshape = Rshape
+ self.Qshape = Qshape
+ self.verbose = verbose
+
+ self._v_uncal = PointListArray(
+ dtype=[("qx", np.float64), ("qy", np.float64), ("intensity", np.float64)],
+ shape=Rshape,
+ name="_v_uncal",
+ )
+
+ # initial calibration state
+ self._calstate = {
+ "center": False,
+ "ellipse": False,
+ "pixel": False,
+ "rotate": False,
+ }
+
+ # register with calibrations
+ self.calibration.register_target(self)
+
+ # setup vector getters
+ self._set_raw_vector_getter()
+ self._set_cal_vector_getter()
+
+ # set new raw vectors
+ def set_raw_vectors(self, x):
+ """Given some PointListArray x of the correct shape, sets this to the raw vectors"""
+ assert isinstance(
+ x, PointListArray
+ ), f"Raw vectors must be set to a PointListArray, not type {type(x)}"
+ assert x.shape == self.Rshape, "Shapes don't match!"
+ self._v_uncal = x
+ self._set_raw_vector_getter()
+ self._set_cal_vector_getter()
+
+ # calibration state, vector getters
+
+ @property
+ def calstate(self):
+ return self._calstate
+
+ def _set_raw_vector_getter(self):
+ self._raw_vector_getter = RawVectorGetter(braggvects=self)
+
+ def _set_cal_vector_getter(self):
+ self._cal_vector_getter = CalibratedVectorGetter(braggvects=self)
+
+ # shape
+ @property
+ def shape(self):
+ return self.Rshape
+
+ # raw vectors
+
+ @property
+ def raw(self):
+ """
+ Calling
+
+ >>> raw[ scan_x, scan_y ]
+
+ returns the raw Bragg vectors at that scan position.
+ """
+ # use the vector getter to grab the vector
+ return self._raw_vector_getter
+
+ # calibrated vectors
+
+ @property
+ def cal(self):
+ """
+ Calling
+
+ >>> cal[ scan_x, scan_y ]
+
+ retrieves data. Use `.setcal` to set the calibrations to be applied, or
+ `.calstate` to see which calibrations are currently set. Calibrations
+ are initially all set to False. Call `.setcal()` (with no arguments)
+ to automatically detect which calibrations are present and apply those.
+ """
+ # retrieve the getter and return
+ return self._cal_vector_getter
+
+ # set calibration state
+
+ def setcal(
+ self,
+ center=None,
+ ellipse=None,
+ pixel=None,
+ rotate=None,
+ ):
+ """
+ Calling
+
+ >>> braggvectors.setcal(
+ >>> center = bool,
+ >>> ellipse = bool,
+ >>> pixel = bool,
+ >>> rotate = bool,
+ >>> )
+
+ sets the calibrations that will be applied to vectors subsequently
+ retrieved with
+
+ >>> braggvectors.cal[ scan_x, scan_y ]
+
+ Any arguments left as `None` will be automatically set based on
+ the calibration measurements available.
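+
+ For example, to apply only the origin correction (illustrative):
+
+ >>> braggvectors.setcal(
+ >>> center = True,
+ >>> ellipse = False,
+ >>> pixel = False,
+ >>> rotate = False,
+ >>> )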
+ """
+
+ # check for calibrations
+ try:
+ c = self.calibration
+ # if no calibrations are found, print a warning and set all to False
+ except Exception:
+ warn("No calibrations found at .calibration; setting all cals to False")
+ self._calstate = {
+ "center": False,
+ "ellipse": False,
+ "pixel": False,
+ "rotate": False,
+ }
+ return
+
+ # autodetect
+ if center is None:
+ center = c.get_origin() is not None
+ if ellipse is None:
+ ellipse = c.get_ellipse() is not None
+ if pixel is None:
+ pixel = c.get_Q_pixel_size() != 1
+ if rotate is None:
+ rotate = c.get_QR_rotation() is not None
+
+ # validate requested state
+ if center:
+ assert c.get_origin() is not None, "Requested calibration not found"
+ if ellipse:
+ assert c.get_ellipse() is not None, "Requested calibration not found"
+ if pixel:
+ assert c.get_Q_pixel_size() is not None, "Requested calibration not found"
+ if rotate:
+ assert c.get_QR_rotation() is not None, "Requested calibration not found"
+
+ # set the calibrations
+ self._calstate = {
+ "center": center,
+ "ellipse": ellipse,
+ "pixel": pixel,
+ "rotate": rotate,
+ }
+ if self.verbose:
+ print("current calibration state: ", self.calstate)
+
+ def calibrate(self):
+ """
+ Autoupdate the calstate when relevant calibrations are set
+ """
+ self.setcal()
+
+ # vector getter method
+
+ def get_vectors(self, scan_x, scan_y, center, ellipse, pixel, rotate):
+ """
+ Returns the bragg vectors at the specified scan position with
+ the specified calibration state.
+
+ Parameters
+ ----------
+ scan_x : int
+ scan_y : int
+ center : bool
+ ellipse : bool
+ pixel : bool
+ rotate : bool
+
+ Returns
+ -------
+ vectors : BVects
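+
+ Example (illustrative):
+
+ >>> vects = braggvectors.get_vectors(
+ >>> scan_x = 0,
+ >>> scan_y = 0,
+ >>> center = True,
+ >>> ellipse = False,
+ >>> pixel = False,
+ >>> rotate = False,
+ >>> )
+ >>> vects.qx, vects.qy, vects.I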
+ """
+
+ ans = self._v_uncal[scan_x, scan_y].data
+ ans = self.cal._transform(
+ data=ans,
+ cal=self.calibration,
+ scanxy=(scan_x, scan_y),
+ center=center,
+ ellipse=ellipse,
+ pixel=pixel,
+ rotate=rotate,
+ )
+ return BVects(ans)
+
+ # copy
+ def copy(self, name=None):
+ name = name if name is not None else self.name + "_copy"
+ braggvector_copy = BraggVectors(
+ self.Rshape, self.Qshape, name=name, calibration=self.calibration.copy()
+ )
+
+ braggvector_copy.set_raw_vectors(self._v_uncal.copy())
+ for k in self.metadata.keys():
+ braggvector_copy.metadata = self.metadata[k].copy()
+ braggvector_copy.setcal()
+ return braggvector_copy
+
+ # write
+
+ def to_h5(self, group):
+ """Constructs the group, adds the bragg vector pointlists,
+ and adds metadata describing the shape
+ """
+ md = Metadata(name="_braggvectors_shape")
+ md["Rshape"] = self.Rshape
+ md["Qshape"] = self.Qshape
+ self.metadata = md
+ grp = Custom.to_h5(self, group)
+ return grp
+
+ # read
+
+ @classmethod
+ def _get_constructor_args(cls, group):
+ """ """
+ # Get shape metadata from the metadatabundle group
+ assert (
+ "metadatabundle" in group.keys()
+ ), "No metadata found, can't get Rshape and Qshape"
+ grp_metadata = group["metadatabundle"]
+ assert (
+ "_braggvectors_shape" in grp_metadata.keys()
+ ), "No _braggvectors_shape metadata found"
+ md = Metadata.from_h5(grp_metadata["_braggvectors_shape"])
+ # Populate args and return
+ kwargs = {
+ "name": basename(group.name),
+ "Rshape": md["Rshape"],
+ "Qshape": md["Qshape"],
+ }
+ return kwargs
+
+ def _populate_instance(self, group):
+ """ """
+ # Get the vectors
+ dic = self._get_emd_attr_data(group)
+ assert "_v_uncal" in dic.keys(), "Uncalibrated bragg vectors not found!"
+ self._v_uncal = dic["_v_uncal"]
+ # Point the vector getters to the vectors
+ self._set_raw_vector_getter()
+ self._set_cal_vector_getter()
+
+ # standard output display
+
+ def __repr__(self):
+ string = f"{self.__class__.__name__}( "
+ string += f"A {self.shape}-shaped array of lists of Bragg vectors )"
+ return string
+
+
+# Vector access classes
+
+
+class BVects:
+ """
+ Enables
+
+ >>> v.qx,v.qy,v.I
+
+ -like access to a collection of Bragg vectors.
+ """
+
+ def __init__(self, data):
+ """pointlist must have fields 'qx', 'qy', and 'intensity'"""
+ self._data = data
+
+ @property
+ def qx(self):
+ return self._data["qx"]
+
+ @property
+ def qy(self):
+ return self._data["qy"]
+
+ @property
+ def I(self):
+ return self._data["intensity"]
+
+ @property
+ def data(self):
+ return self._data
+
+ def __repr__(self):
+ string = f"{self.__class__.__name__}( "
+ string += f"A set of {len(self.data)} Bragg vectors."
+ string += " Access data with .qx, .qy, .I, or .data.)"
+ return string
+
+
+class RawVectorGetter:
+ def __init__(
+ self,
+ braggvects,
+ ):
+ self._bvects = braggvects
+ self._data = braggvects._v_uncal
+
+ def __getitem__(self, pos):
+ x, y = pos
+ ans = self._data[x, y].data
+ return BVects(ans)
+
+ def __repr__(self):
+ string = f"{self.__class__.__name__}( "
+ string += "Retrieves raw Bragg vectors. Get vectors for scan position x,y with [x,y]. )"
+ return string
+
+
+class CalibratedVectorGetter:
+ def __init__(
+ self,
+ braggvects,
+ ):
+ self._bvects = braggvects
+ self._data = braggvects._v_uncal
+
+ def __getitem__(self, pos):
+ x, y = pos
+ ans = self._data[x, y].data
+ ans = self._transform(
+ data=ans,
+ cal=self._bvects.calibration,
+ scanxy=(x, y),
+ center=self._bvects.calstate["center"],
+ ellipse=self._bvects.calstate["ellipse"],
+ pixel=self._bvects.calstate["pixel"],
+ rotate=self._bvects.calstate["rotate"],
+ )
+ return BVects(ans)
+
+ def __repr__(self):
+ space = " " * len(self.__class__.__name__) + " "
+ string = f"{self.__class__.__name__}( "
+ string += "Retrieves calibrated Bragg vectors. Get vectors for scan position x,y with [x,y]."
+ string += (
+ "\n"
+ + space
+ + "Set which calibrations to apply with braggvectors.setcal(...). )"
+ )
+ return string
+
+ def _transform(
+ self,
+ data,
+ cal,
+ scanxy,
+ center,
+ ellipse,
+ pixel,
+ rotate,
+ ):
+ """
+ Return a transformed copy of stractured data `data` with fields
+ with fields 'qx','qy','intensity', applying calibrating transforms
+ according to the values of center, ellipse, pixel, using the
+ measurements found in Calibration instance cal for scan position scanxy.
+ """
+
+ ans = data.copy()
+ x, y = scanxy
+
+ # origin
+
+ if center:
+ origin = cal.get_origin(x, y)
+ assert origin is not None, "Requested calibration was not found!"
+ ans["qx"] -= origin[0]
+ ans["qy"] -= origin[1]
+
+ # ellipse
+ if ellipse:
+ ell = cal.get_ellipse(x, y)
+ assert ell is not None, "Requested calibration was not found!"
+ a, b, theta = ell
+ # Get the transformation matrix
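+ # T scales q-space by the ratio e = b/a along one principal axis of the
+ # measured ellipse and leaves the perpendicular axis unchanged, mapping
+ # the measured ellipse back onto a circle; sint/cost encode the ellipse
+ # orientation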
+ e = b / a
+ sint, cost = np.sin(theta - np.pi / 2.0), np.cos(theta - np.pi / 2.0)
+ T = np.array(
+ [
+ [e * sint**2 + cost**2, sint * cost * (1 - e)],
+ [sint * cost * (1 - e), sint**2 + e * cost**2],
+ ]
+ )
+ # apply it
+ xyar_i = np.vstack([ans["qx"], ans["qy"]])
+ xyar_f = np.matmul(T, xyar_i)
+ ans["qx"] = xyar_f[0, :]
+ ans["qy"] = xyar_f[1, :]
+
+ # pixel size
+ if pixel:
+ qpix = cal.get_Q_pixel_size()
+ assert qpix is not None, "Requested calibration was not found!"
+ ans["qx"] *= qpix
+ ans["qy"] *= qpix
+
+ # Q/R rotation
+ if rotate:
+ theta = cal.get_QR_rotation()
+ assert theta is not None, "Requested calibration was not found!"
+ flip = cal.get_QR_flip()
+ flip = False if flip is None else flip
+ # rotation matrix
+ R = np.array(
+ [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
+ )
+ # rotate and flip
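+ # a Q/R flip means the detector axes are transposed relative to the
+ # scan axes, so qx and qy are swapped before rotating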
+ if flip:
+ positions = R @ np.vstack((ans["qy"], ans["qx"]))
+ else:
+ positions = R @ np.vstack((ans["qx"], ans["qy"]))
+ # update
+ ans["qx"] = positions[0, :]
+ ans["qy"] = positions[1, :]
+
+ # return
+ return ans
diff --git a/py4DSTEM/braggvectors/diskdetection.py b/py4DSTEM/braggvectors/diskdetection.py
new file mode 100644
index 000000000..99818b75e
--- /dev/null
+++ b/py4DSTEM/braggvectors/diskdetection.py
@@ -0,0 +1,794 @@
+# Functions for finding Bragg scattering by cross correlative template matching
+# with a vacuum probe.
+
+import numpy as np
+from scipy.ndimage import gaussian_filter
+
+from emdfile import tqdmnd
+from py4DSTEM.braggvectors.braggvectors import BraggVectors
+from py4DSTEM.data import QPoints
+from py4DSTEM.datacube import DataCube
+from py4DSTEM.preprocess.utils import get_maxima_2D
+from py4DSTEM.process.utils.cross_correlate import get_cross_correlation_FT
+from py4DSTEM.braggvectors.diskdetection_aiml import find_Bragg_disks_aiml
+
+
+def find_Bragg_disks(
+ data,
+ template,
+ radial_bksb=False,
+ filter_function=None,
+ corrPower=1,
+ sigma=None,
+ sigma_dp=0,
+ sigma_cc=2,
+ subpixel="multicorr",
+ upsample_factor=16,
+ minAbsoluteIntensity=0,
+ minRelativeIntensity=0.005,
+ relativeToPeak=0,
+ minPeakSpacing=60,
+ edgeBoundary=20,
+ maxNumPeaks=70,
+ CUDA=False,
+ CUDA_batched=True,
+ distributed=None,
+ ML=False,
+ ml_model_path=None,
+ ml_num_attempts=1,
+ ml_batch_size=8,
+):
+ """
+ Finds the Bragg disks in the diffraction patterns represented by `data` by
+ cross/phase correlation with `template`.
+
+ Behavior depends on `data`. If it is
+
+ - a DataCube: runs on all its diffraction patterns, and returns a
+ BraggVectors instance
+ - a 2D array: runs on this array, and returns a QPoints instance
+ - a 3D array: runs on the ar[i,:,:] slices of this array, and returns
+ a list of ar.shape[0] QPoints instances
+ - a 3-tuple (DataCube, rx, ry), for numbers or length-N arrays (rx,ry):
+ runs on the diffraction patterns in DataCube at positions (rx,ry),
+ and returns a QPoints instance or a length-N list of QPoints instances
+
+ For disk detection on a full DataCube, the calculation can be performed
+ on the CPU, GPU or a cluster. By default the CPU is used. If `CUDA` is set
+ to True, tries to use the GPU. If `CUDA_batched` is also set to True,
+ batches the FFT/IFFT computations on the GPU. For distribution to a cluster,
+ distributed must be set to a dictionary, with contents describing how
+ distributed processing should be performed - see below for details.
+
+
+ For each diffraction pattern, the algorithm works in 4 steps:
+
+ (1) any pre-processing is performed to the diffraction image. This is
+ accomplished by passing a callable function to the argument
+ `filter_function`, a bool to the argument `radial_bksb`, or a value >0
+ to `sigma_dp`. If none of these are passed, this step is skipped.
+ (2) the diffraction image is cross correlated with the template.
+ Phase/hybrid correlations can be used instead by setting the
+ `corrPower` argument. Cross correlation can be skipped entirely,
+ and the subsequent steps performed directly on the diffraction
+ image instead of the cross correlation, by passing None to
+ `template`.
+ (3) the maxima of the cross correlation are located and their
+ positions and intensities stored. The cross correlation may be
+ passed through a gaussian filter first by passing the `sigma_cc`
+ argument. The method for maximum detection can be set with
+ the `subpixel` parameter. Options, from fastest/least precise to
+ slowest/most precise, are 'none', 'poly', and 'multicorr'.
+ (4) filtering is applied to remove untrusted or undesired positive counts,
+ based on their intensity (`minRelativeIntensity`,`relativeToPeak`,
+ `minAbsoluteIntensity`) their proximity to one another or the
+ image edge (`minPeakSpacing`, `edgeBoundary`), and the total
+ number of peaks per pattern (`maxNumPeaks`).
+
+
+ Parameters
+ ----------
+ data : variable
+ see above
+ template : 2D array
+ the vacuum probe template, in real space. For Probe instances,
+ this is `probe.kernel`. If None, does not perform a cross
+ correlation.
+ radial_bksb : bool
+ if True, computes a radial background given by the median of the
+ (circular) polar transform of each each diffraction pattern, and
+ subtracts this background from the pattern before applying any
+ filter function and computing the cross correlation. The origin
+ position must be set in the datacube's calibrations. Currently
+ only supported for full datacubes on the CPU.
+ filter_function : callable
+ filtering function to apply to each diffraction pattern before
+ peak finding. Must be a function of only one argument (the
+ diffraction pattern) and return the filtered diffraction pattern.
+ The shape of the returned DP must match the shape of the probe
+ kernel (but does not need to match the shape of the input
+ diffraction pattern, e.g. the filter can be used to bin the
+ diffraction pattern). If using distributed disk detection, the
+ function must be able to be pickled by dill.
+ corrPower : float between 0 and 1, inclusive
+ the cross correlation power. A value of 1 corresponds to a cross
+ correlation, 0 corresponds to a phase correlation, and intermediate
+ values correspond to hybrid correlations.
+ sigma : float
+ alias for `sigma_cc`
+ sigma_dp : float
+ if >0, a gaussian smoothing filter with this standard deviation
+ is applied to the diffraction pattern before maxima are detected
+ sigma_cc : float
+ if >0, a gaussian smoothing filter with this standard deviation
+ is applied to the cross correlation before maxima are detected
+ subpixel : str
+ Whether to use subpixel fitting, and which algorithm to use.
+ Must be in ('none','poly','multicorr').
+ * 'none': performs no subpixel fitting
+ * 'poly': polynomial interpolation of correlogram peaks
+ * 'multicorr': uses the multicorr algorithm with DFT upsampling
+ upsample_factor : int
+ upsampling factor for subpixel fitting (only used when
+ subpixel='multicorr')
+ minAbsoluteIntensity : float
+ the minimum acceptable correlation peak intensity, on an absolute scale
+ minRelativeIntensity : float
+ the minimum acceptable correlation peak intensity, relative to the
+ intensity of the brightest peak
+ relativeToPeak : int
+ specifies the peak against which the minimum relative intensity is
+ measured -- 0=brightest maximum. 1=next brightest, etc.
+ minPeakSpacing : float
+ the minimum acceptable spacing between detected peaks
+ edgeBoundary : int
+ minimum acceptable distance for detected peaks from
+ the diffraction image edge, in pixels
+ maxNumPeaks : int
+ the maximum number of peaks to return
+ CUDA : bool
+ If True, import cupy and use an NVIDIA GPU to perform disk detection
+ CUDA_batched : bool
+ If True, and CUDA is selected, the FFT and IFFT steps of disk detection
+ are performed in batches to better utilize GPU resources.
+ distributed : dict
+ contains information for parallel processing using an IPyParallel or
+ Dask distributed cluster. Valid keys are:
+ * ipyparallel (dict):
+ * client_file (str): path to client json for connecting to your
+ existing IPyParallel cluster
+ * dask (dict): client (object): a dask client that connects to
+ your existing Dask cluster
+ * data_file (str): the absolute path to your original data
+ file containing the datacube
+ * cluster_path (str): defaults to the working directory during
+ processing
+ if distributed is None, which is the default, processing will be in
+ serial
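+
+ For example (illustrative; the paths are placeholders):
+
+ >>> distributed = {
+ >>> "ipyparallel": {"client_file": "/path/to/ipcontroller-client.json"},
+ >>> "data_file": "/path/to/datacube.h5",
+ >>> }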
+
+ Returns
+ -------
+ variable
+ the Bragg peak positions and correlation intensities. If `data` is:
+ * a DataCube, returns a BraggVectors instance
+ * a 2D array, returns a QPoints instance
+ * a 3D array, returns a list of QPoints instances
+ * a (DataCube,rx,ry) 3-tuple, returns a list of QPoints
+ instances
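+
+ Examples
+ --------
+ A typical call on a full datacube (illustrative; assumes `probe` is a
+ Probe instance computed elsewhere, whose kernel is used as the cross
+ correlation template):
+
+ >>> braggvectors = find_Bragg_disks(
+ >>> datacube,
+ >>> template = probe.kernel,
+ >>> )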
+ """
+
+ # parse args
+ sigma_cc = sigma if sigma is not None else sigma_cc
+
+ # `data` type
+ if isinstance(data, DataCube):
+ mode = "datacube"
+ elif isinstance(data, np.ndarray):
+ if data.ndim == 2:
+ mode = "dp"
+ elif data.ndim == 3:
+ mode = "dp_stack"
+ else:
+ er = f"if `data` is an array, must be 2- or 3-D, not {data.ndim}-D"
+ raise Exception(er)
+ else:
+ try:
+ # when a position (rx,ry) is passed, get those patterns
+ # and put them in a stack
+ dc, rx, ry = data[0], data[1], data[2]
+
+ # h5py datasets have different rules for slicing than
+ # numpy arrays, so we have to do this manually
+ if "h5py" in str(type(dc.data)):
+ data = np.zeros((len(rx), dc.Q_Nx, dc.Q_Ny))
+ # no background subtraction
+ if not radial_bksb:
+ for i, (x, y) in enumerate(zip(rx, ry)):
+ data[i] = dc.data[x, y]
+ # with bksubtr
+ else:
+ for i, (x, y) in enumerate(zip(rx, ry)):
+ data[i] = dc.get_radial_bksb_dp(x, y)
+ else:
+ # no background subtraction
+ if not radial_bksb:
+ data = dc.data[np.array(rx), np.array(ry), :, :]
+ # with bksubtr
+ else:
+ data = np.zeros((len(rx), dc.Q_Nx, dc.Q_Ny))
+ for i, (x, y) in enumerate(zip(rx, ry)):
+ data[i] = dc.get_radial_bksb_dp(x, y)
+ if data.ndim == 2:
+ mode = "dp"
+ elif data.ndim == 3:
+ mode = "dp_stack"
+ except Exception:
+ er = f"entry {data} for `data` could not be parsed"
+ raise Exception(er)
+
+ # CPU/GPU/cluster/ML-AI
+
+ if ML:
+ mode = "dc_ml"
+
+ elif mode == "datacube":
+ if distributed is None and CUDA is False:
+ mode = "dc_CPU"
+ elif distributed is None and CUDA is True:
+ if CUDA_batched is False:
+ mode = "dc_GPU"
+ else:
+ mode = "dc_GPU_batched"
+ else:
+ x = _parse_distributed(distributed)
+ connect, data_file, cluster_path, distributed_mode = x
+ if distributed_mode == "dask":
+ mode = "dc_dask"
+ elif distributed_mode == "ipyparallel":
+ mode = "dc_ipyparallel"
+ else:
+ er = f"unrecognized distributed mode {distributed_mode}"
+ raise Exception(er)
+
+ # select a function
+ fn_dict = {
+ "dp": _find_Bragg_disks_single,
+ "dp_stack": _find_Bragg_disks_stack,
+ "dc_CPU": _find_Bragg_disks_CPU,
+ "dc_GPU": _find_Bragg_disks_CUDA_unbatched,
+ "dc_GPU_batched": _find_Bragg_disks_CUDA_batched,
+ "dc_dask": _find_Bragg_disks_dask,
+ "dc_ipyparallel": _find_Bragg_disks_ipp,
+ "dc_ml": find_Bragg_disks_aiml,
+ }
+ fn = fn_dict[mode]
+
+ # prepare kwargs
+ kws = {}
+ # distributed kwargs
+ if distributed is not None:
+ kws["connect"] = connect
+ kws["data_file"] = data_file
+ kws["cluster_path"] = cluster_path
+ # ML arguments
+ if ML is True:
+ kws["CUDA"] = CUDA
+ kws["model_path"] = ml_model_path
+ kws["num_attempts"] = ml_num_attempts
+ kws["batch_size"] = ml_batch_size
+
+ # if radial background subtraction is requested, add to args
+ if radial_bksb and mode == "dc_CPU":
+ kws["radial_bksb"] = radial_bksb
+
+ # run and return
+ ans = fn(
+ data,
+ template,
+ filter_function=filter_function,
+ corrPower=corrPower,
+ sigma_dp=sigma_dp,
+ sigma_cc=sigma_cc,
+ subpixel=subpixel,
+ upsample_factor=upsample_factor,
+ minAbsoluteIntensity=minAbsoluteIntensity,
+ minRelativeIntensity=minRelativeIntensity,
+ relativeToPeak=relativeToPeak,
+ minPeakSpacing=minPeakSpacing,
+ edgeBoundary=edgeBoundary,
+ maxNumPeaks=maxNumPeaks,
+ **kws,
+ )
+ return ans
+
+
+# Single diffraction pattern
+
+
+def _find_Bragg_disks_single(
+ DP,
+ template,
+ filter_function=None,
+ corrPower=1,
+ sigma_dp=0,
+ sigma_cc=2,
+ subpixel="poly",
+ upsample_factor=16,
+ minAbsoluteIntensity=0,
+ minRelativeIntensity=0,
+ relativeToPeak=0,
+ minPeakSpacing=0,
+ edgeBoundary=1,
+ maxNumPeaks=100,
+ _return_cc=False,
+ _template_space="real",
+):
+ # apply filter function
+ er = "filter_function must be callable"
+ if filter_function:
+ assert callable(filter_function), er
+ DP = DP if filter_function is None else filter_function(DP)
+
+ # check for a template
+ if template is None:
+ # no template - work directly on the diffraction pattern. Store its
+ # Fourier transform, so the inverse FFT below recovers the pattern
+ # and subpixel fitting receives a valid FT
+ cc = np.fft.fft2(DP)
+ else:
+ # fourier transform the template
+ assert _template_space in ("real", "fourier")
+ if _template_space == "real":
+ template_FT = np.conj(np.fft.fft2(template))
+ else:
+ template_FT = template
+
+ # apply any smoothing to the data
+ if sigma_dp > 0:
+ DP = gaussian_filter(DP, sigma_dp)
+
+ # Compute cross correlation
+ # _returnval = 'fourier' if subpixel == 'multicorr' else 'real'
+ cc = get_cross_correlation_FT(
+ DP,
+ template_FT,
+ corrPower,
+ "fourier",
+ )
+
+ # Get maxima
+ maxima = get_maxima_2D(
+ np.maximum(np.real(np.fft.ifft2(cc)), 0),
+ subpixel=subpixel,
+ upsample_factor=upsample_factor,
+ sigma=sigma_cc,
+ minAbsoluteIntensity=minAbsoluteIntensity,
+ minRelativeIntensity=minRelativeIntensity,
+ relativeToPeak=relativeToPeak,
+ minSpacing=minPeakSpacing,
+ edgeBoundary=edgeBoundary,
+ maxNumPeaks=maxNumPeaks,
+ _ar_FT=cc,
+ )
+
+ # Wrap as QPoints instance
+ maxima = QPoints(maxima)
+
+ # Return
+ if _return_cc is True:
+ return maxima, cc
+ return maxima
+
+
+# def _get_cross_correlation_FT(
+# DP,
+# template_FT,
+# corrPower = 1,
+# _returnval = 'real'
+# ):
+# """
+# if _returnval is 'real', returns the real-valued cross-correlation.
+# otherwise, returns the complex valued result.
+# """
+#
+# m = np.fft.fft2(DP) * template_FT
+# cc = np.abs(m)**(corrPower) * np.exp(1j*np.angle(m))
+# if _returnval == 'real':
+# cc = np.maximum(np.real(np.fft.ifft2(cc)),0)
+# return cc
+
+
+# 3D stack of DPs
+
+
+def _find_Bragg_disks_stack(
+ dp_stack,
+ template,
+ filter_function=None,
+ corrPower=1,
+ sigma_dp=0,
+ sigma_cc=2,
+ subpixel="poly",
+ upsample_factor=16,
+ minAbsoluteIntensity=0,
+ minRelativeIntensity=0,
+ relativeToPeak=0,
+ minPeakSpacing=0,
+ edgeBoundary=1,
+ maxNumPeaks=100,
+ _template_space="real",
+):
+ ans = []
+
+ for idx in range(dp_stack.shape[0]):
+ dp = dp_stack[idx, :, :]
+ peaks = _find_Bragg_disks_single(
+ dp,
+ template,
+ filter_function=filter_function,
+ corrPower=corrPower,
+ sigma_dp=sigma_dp,
+ sigma_cc=sigma_cc,
+ subpixel=subpixel,
+ upsample_factor=upsample_factor,
+ minAbsoluteIntensity=minAbsoluteIntensity,
+ minRelativeIntensity=minRelativeIntensity,
+ relativeToPeak=relativeToPeak,
+ minPeakSpacing=minPeakSpacing,
+ edgeBoundary=edgeBoundary,
+ maxNumPeaks=maxNumPeaks,
+ _template_space=_template_space,
+ _return_cc=False,
+ )
+ ans.append(peaks)
+
+ return ans
+
+
+# Whole datacube, CPU
+
+
+def _find_Bragg_disks_CPU(
+ datacube,
+ probe,
+ filter_function=None,
+ corrPower=1,
+ sigma_dp=0,
+ sigma_cc=2,
+ subpixel="multicorr",
+ upsample_factor=16,
+ minAbsoluteIntensity=0,
+ minRelativeIntensity=0.005,
+ relativeToPeak=0,
+ minPeakSpacing=60,
+ edgeBoundary=20,
+ maxNumPeaks=70,
+ radial_bksb=False,
+):
+ # Make the BraggVectors instance
+ braggvectors = BraggVectors(datacube.Rshape, datacube.Qshape)
+
+ # Get the template's Fourier Transform
+ probe_kernel_FT = np.conj(np.fft.fft2(probe)) if probe is not None else None
+
+ # Loop over all diffraction patterns
+ # Compute and populate BraggVectors data
+ for rx, ry in tqdmnd(
+ datacube.R_Nx,
+ datacube.R_Ny,
+ desc="Finding Bragg Disks",
+ unit="DP",
+ unit_scale=True,
+ ):
+ # Get a diffraction pattern
+
+ # without background subtraction
+ if not radial_bksb:
+ dp = datacube.data[rx, ry, :, :]
+ # and with
+ else:
+ dp = datacube.get_radial_bksb_dp(rx, ry)
+
+ # Compute
+ peaks = _find_Bragg_disks_single(
+ dp,
+ template=probe_kernel_FT,
+ filter_function=filter_function,
+ corrPower=corrPower,
+ sigma_dp=sigma_dp,
+ sigma_cc=sigma_cc,
+ subpixel=subpixel,
+ upsample_factor=upsample_factor,
+ minAbsoluteIntensity=minAbsoluteIntensity,
+ minRelativeIntensity=minRelativeIntensity,
+ relativeToPeak=relativeToPeak,
+ minPeakSpacing=minPeakSpacing,
+ edgeBoundary=edgeBoundary,
+ maxNumPeaks=maxNumPeaks,
+ _return_cc=False,
+ _template_space="fourier",
+ )
+
+ # Populate data
+ braggvectors._v_uncal[rx, ry] = peaks
+
+ # Return
+ return braggvectors
+
+
+# CUDA - unbatched
+
+
+def _find_Bragg_disks_CUDA_unbatched(
+ datacube,
+ probe,
+ filter_function=None,
+ corrPower=1,
+ sigma_dp=0,
+ sigma_cc=2,
+ subpixel="multicorr",
+ upsample_factor=16,
+ minAbsoluteIntensity=0,
+ minRelativeIntensity=0.005,
+ relativeToPeak=0,
+ minPeakSpacing=60,
+ edgeBoundary=20,
+ maxNumPeaks=70,
+):
+ # compute
+ from py4DSTEM.braggvectors.diskdetection_cuda import find_Bragg_disks_CUDA
+
+ peaks = find_Bragg_disks_CUDA(
+ datacube,
+ probe,
+ filter_function=filter_function,
+ corrPower=corrPower,
+ sigma=sigma_cc,
+ subpixel=subpixel,
+ upsample_factor=upsample_factor,
+ minAbsoluteIntensity=minAbsoluteIntensity,
+ minRelativeIntensity=minRelativeIntensity,
+ relativeToPeak=relativeToPeak,
+ minPeakSpacing=minPeakSpacing,
+ edgeBoundary=edgeBoundary,
+ maxNumPeaks=maxNumPeaks,
+ batching=False,
+ )
+
+ # Populate a BraggVectors instance and return
+ braggvectors = BraggVectors(datacube.Rshape, datacube.Qshape)
+ braggvectors._v_uncal = peaks
+ braggvectors._set_raw_vector_getter()
+ braggvectors._set_cal_vector_getter()
+ return braggvectors
+
+
+# CUDA - batched
+
+
+def _find_Bragg_disks_CUDA_batched(
+ datacube,
+ probe,
+ filter_function=None,
+ corrPower=1,
+ sigma_dp=0,
+ sigma_cc=2,
+ subpixel="multicorr",
+ upsample_factor=16,
+ minAbsoluteIntensity=0,
+ minRelativeIntensity=0.005,
+ relativeToPeak=0,
+ minPeakSpacing=60,
+ edgeBoundary=20,
+ maxNumPeaks=70,
+):
+ # compute
+ from py4DSTEM.braggvectors.diskdetection_cuda import find_Bragg_disks_CUDA
+
+ peaks = find_Bragg_disks_CUDA(
+ datacube,
+ probe,
+ filter_function=filter_function,
+ corrPower=corrPower,
+ sigma=sigma_cc,
+ subpixel=subpixel,
+ upsample_factor=upsample_factor,
+ minAbsoluteIntensity=minAbsoluteIntensity,
+ minRelativeIntensity=minRelativeIntensity,
+ relativeToPeak=relativeToPeak,
+ minPeakSpacing=minPeakSpacing,
+ edgeBoundary=edgeBoundary,
+ maxNumPeaks=maxNumPeaks,
+ batching=True,
+ )
+
+ # Populate a BraggVectors instance and return
+ braggvectors = BraggVectors(datacube.Rshape, datacube.Qshape)
+ braggvectors._v_uncal = peaks
+ braggvectors._set_raw_vector_getter()
+ braggvectors._set_cal_vector_getter()
+ return braggvectors
+
+
+# Distributed - ipyparallel
+
+
+def _find_Bragg_disks_ipp(
+ datacube,
+ probe,
+ connect,
+ data_file,
+ cluster_path,
+ filter_function=None,
+ corrPower=1,
+ sigma_dp=0,
+ sigma_cc=2,
+ subpixel="multicorr",
+ upsample_factor=16,
+ minAbsoluteIntensity=0,
+ minRelativeIntensity=0.005,
+ relativeToPeak=0,
+ minPeakSpacing=60,
+ edgeBoundary=20,
+ maxNumPeaks=70,
+):
+ # compute
+ from py4DSTEM.braggvectors.diskdetection_parallel import find_Bragg_disks_ipp
+
+ peaks = find_Bragg_disks_ipp(
+ datacube,
+ probe,
+ filter_function=filter_function,
+ corrPower=corrPower,
+ sigma=sigma_cc,
+ subpixel=subpixel,
+ upsample_factor=upsample_factor,
+ minAbsoluteIntensity=minAbsoluteIntensity,
+ minRelativeIntensity=minRelativeIntensity,
+ relativeToPeak=relativeToPeak,
+ minPeakSpacing=minPeakSpacing,
+ edgeBoundary=edgeBoundary,
+ maxNumPeaks=maxNumPeaks,
+ ipyparallel_client_file=connect,
+ data_file=data_file,
+ cluster_path=cluster_path,
+ )
+
+ # Populate a BraggVectors instance and return
+ braggvectors = BraggVectors(datacube.Rshape, datacube.Qshape)
+ braggvectors._v_uncal = peaks
+ braggvectors._set_raw_vector_getter()
+ braggvectors._set_cal_vector_getter()
+ return braggvectors
+
+
+# Distributed - dask
+
+
+def _find_Bragg_disks_dask(
+ datacube,
+ probe,
+ connect,
+ data_file,
+ cluster_path,
+ filter_function=None,
+ corrPower=1,
+ sigma_dp=0,
+ sigma_cc=2,
+ subpixel="multicorr",
+ upsample_factor=16,
+ minAbsoluteIntensity=0,
+ minRelativeIntensity=0.005,
+ relativeToPeak=0,
+ minPeakSpacing=60,
+ edgeBoundary=20,
+ maxNumPeaks=70,
+):
+ # compute
+ from py4DSTEM.braggvectors.diskdetection_parallel import find_Bragg_disks_dask
+
+ peaks = find_Bragg_disks_dask(
+ datacube,
+ probe,
+ filter_function=filter_function,
+ corrPower=corrPower,
+ sigma=sigma_cc,
+ subpixel=subpixel,
+ upsample_factor=upsample_factor,
+ minAbsoluteIntensity=minAbsoluteIntensity,
+ minRelativeIntensity=minRelativeIntensity,
+ relativeToPeak=relativeToPeak,
+ minPeakSpacing=minPeakSpacing,
+ edgeBoundary=edgeBoundary,
+ maxNumPeaks=maxNumPeaks,
+ dask_client_file=connect,
+ data_file=data_file,
+ cluster_path=cluster_path,
+ )
+
+ # Populate a BraggVectors instance and return
+ braggvectors = BraggVectors(datacube.Rshape, datacube.Qshape)
+ braggvectors._v_uncal = peaks
+ braggvectors._set_raw_vector_getter()
+ braggvectors._set_cal_vector_getter()
+ return braggvectors
+
+
+def _parse_distributed(distributed):
+ """
+ Parse the `distributed` dict argument to determine distribution behavior
+ """
+ import os
+
+ # parse mode (ipyparallel or dask)
+ if "ipyparallel" in distributed:
+ mode = "ipyparallel"
+ if "client_file" in distributed["ipyparallel"]:
+ connect = distributed["ipyparallel"]["client_file"]
+ else:
+ er = 'Within distributed["ipyparallel"], '
+ er += 'missing key for "client_file"'
+ raise KeyError(er)
+
+ try:
+ import ipyparallel as ipp
+
+ c = ipp.Client(url_file=connect, timeout=30)
+
+ if len(c.ids) == 0:
+ er = "No IPyParallel engines attached to cluster!"
+ raise RuntimeError(er)
+ except ImportError:
+ raise ImportError("Unable to import module ipyparallel!")
+
+ elif "dask" in distributed:
+ mode = "dask"
+ if "client" in distributed["dask"]:
+ connect = distributed["dask"]["client"]
+ else:
+ er = 'Within distributed["dask"], missing key for "client"'
+ raise KeyError(er)
+
+ else:
+ er = "Within distributed, you must specify 'ipyparallel' or 'dask'!"
+ raise KeyError(er)
+
+ # parse data file
+ if "data_file" not in distributed:
+ er = "Missing input data file path to distributed! "
+ er += "Required key 'data_file'"
+ raise KeyError(er)
+
+ data_file = distributed["data_file"]
+
+ if not isinstance(data_file, str):
+ er = "Expected string for distributed key 'data_file', "
+ er += f"received {type(data_file)}"
+ raise TypeError(er)
+ if len(data_file.strip()) == 0:
+ er = "Empty data file path from distributed key 'data_file'"
+ raise ValueError(er)
+ elif not os.path.exists(data_file):
+ raise FileNotFoundError(f"Data file not found: {data_file}")
+
+ # parse cluster path
+ if "cluster_path" in distributed:
+ cluster_path = distributed["cluster_path"]
+
+ if not isinstance(cluster_path, str):
+ er = "distributed key 'cluster_path' must be of type str, "
+ er += f"received {type(cluster_path)}"
+ raise TypeError(er)
+
+ if len(cluster_path.strip()) == 0:
+ er = "distributed key 'cluster_path' cannot be an empty string!"
+ raise ValueError(er)
+ elif not os.path.exists(cluster_path):
+ er = f"distributed key 'cluster_path' does not exist: {cluster_path}"
+ raise FileNotFoundError(er)
+ elif not os.path.isdir(cluster_path):
+ er = "distributed key 'cluster_path' is not a directory: "
+ er += f"{cluster_path}"
+ raise NotADirectoryError(er)
+ else:
+ cluster_path = None
+
+ # return
+ return connect, data_file, cluster_path, mode
diff --git a/py4DSTEM/braggvectors/diskdetection_aiml.py b/py4DSTEM/braggvectors/diskdetection_aiml.py
new file mode 100644
index 000000000..4d23ebf6c
--- /dev/null
+++ b/py4DSTEM/braggvectors/diskdetection_aiml.py
@@ -0,0 +1,941 @@
+# Functions for finding Bragg disks using AI/ML pipeline
+"""
+Functions for finding Bragg disks using an AI/ML method built on TensorFlow
+"""
+
+import os
+import glob
+import json
+import shutil
+import numpy as np
+
+from scipy.ndimage import gaussian_filter
+from time import time
+from numbers import Number
+
+from emdfile import tqdmnd, PointList, PointListArray
+from py4DSTEM.braggvectors.braggvectors import BraggVectors
+from py4DSTEM.data import QPoints
+from py4DSTEM.process.utils import get_maxima_2D
+
+# from py4DSTEM.braggvectors import universal_threshold
+
+
+def find_Bragg_disks_aiml_single_DP(
+ DP,
+ probe,
+ num_attempts=5,
+ int_window_radius=1,
+ predict=True,
+ sigma=0,
+ edgeBoundary=20,
+ minRelativeIntensity=0.005,
+ minAbsoluteIntensity=0,
+ relativeToPeak=0,
+ minPeakSpacing=60,
+ maxNumPeaks=70,
+ subpixel="multicorr",
+ upsample_factor=16,
+ filter_function=None,
+ peaks=None,
+ model_path=None,
+):
+ """
+ Finds the Bragg disks in single DP by AI/ML method. This method utilizes FCU-Net
+ to predict Bragg disks from diffraction images.
+
+ The input DP and probe need to be aligned before the prediction. Detected peaks within
+ edgeBoundary pixels of the diffraction plane edges are then discarded. Next, peaks
+ with intensities less than minRelativeIntensity of the brightest peak in the
+ correlation are discarded. Then peaks which are within a distance of minPeakSpacing
+ of their nearest neighbor peak are found, and in each such pair the peak with the
+ lesser correlation intensities is removed. Finally, if the number of peaks remaining
+ exceeds maxNumPeaks, only the maxNumPeaks peaks with the highest correlation
+ intensity are retained.
+
+ Args:
+ DP (ndarray): a diffraction pattern
+ probe (ndarray): the vacuum probe template
+ num_attempts (int): Number of attempts to predict the Bragg disks. Recommended: 5
+ Ideally, the more num_attempts the better (confident) the prediction will be
+ as the ML prediction utilizes Monte Carlo Dropout technique to estimate model
+ uncertainty using Bayesian approach. Note: increasing num_attempts will increase
+ the compute time significantly and it is advised to use GPU (CUDA) enabled environment
+ for fast prediction with num_attempts > 1
+ int_window_radius (int): window radius (in pixels) for disk intensity integration over the
+ predicted atomic potentials array
+ predict (bool): Flag to determine if ML prediction is opted.
+ edgeBoundary (int): minimum acceptable distance from the DP edge, in pixels
+ minRelativeIntensity (float): the minimum acceptable correlation peak intensity,
+ relative to the intensity of the relativeToPeak'th peak
+ minAbsoluteIntensity (float): the minimum acceptable correlation peak intensity,
+ on an absolute scale
+ relativeToPeak (int): specifies the peak against which the minimum relative
+ intensity is measured -- 0=brightest maximum. 1=next brightest, etc.
+ minPeakSpacing (float): the minimum acceptable spacing between detected peaks
+ maxNumPeaks (int): the maximum number of peaks to return
+ subpixel (str): Whether to use subpixel fitting, and which algorithm to use.
+ Must be in ('none','poly','multicorr').
+ * 'none': performs no subpixel fitting
+ * 'poly': polynomial interpolation of correlogram peaks
+ * 'multicorr': uses the multicorr algorithm with DFT upsampling
+ upsample_factor (int): upsampling factor for subpixel fitting (only used when
+ subpixel='multicorr')
+ filter_function (callable): filtering function to apply to each diffraction
+ pattern before peakfinding. Must be a function of only one argument (the
+ diffraction pattern) and return the filtered diffraction pattern. The shape
+ of the returned DP must match the shape of the probe kernel (but does not
+ need to match the shape of the input diffraction pattern, e.g. the filter
+ can be used to bin the diffraction pattern). If using distributed disk
+ detection, the function must be able to be pickled by dill.
+ peaks (PointList): For internal use. If peaks is None, the PointList of peak
+ positions is created here. If peaks is not None, it is the PointList that
+ detected peaks are added to, and must have the appropriate coords
+ ('qx','qy','intensity').
+ model_path (str): filepath for the model weights (Tensorflow model) to load from.
+ By default, if the model_path is not provided, py4DSTEM will search for the
+ latest model stored on the cloud using a metadata json file. It is not recommended to
+ keep track of the model path and advised to keep this argument unchanged (None)
+ to always search for the latest updated training model weights.
+
+ Returns:
+ (PointList): the Bragg peak positions and correlation intensities
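+
+ Example (illustrative; assumes `dp` and `probe` are 2D arrays of the same
+ shape, and that crystal4D and tensorflow are installed):
+
+ >>> peaks = find_Bragg_disks_aiml_single_DP(dp, probe, num_attempts=5)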
+ """
+ try:
+ import crystal4D
+ except Exception:
+ raise ImportError("Please install crystal4D before proceeding")
+ try:
+ import tensorflow as tf
+ except Exception:
+ raise ImportError(
+ "Please install tensorflow before proceeding - please check "
+ + "https://www.tensorflow.org/install "
+ + "for more information"
+ )
+
+ assert subpixel in [
+ "none",
+ "poly",
+ "multicorr",
+ ], "Unrecognized subpixel option {}, subpixel must be 'none', 'poly', or 'multicorr'".format(
+ subpixel
+ )
+
+ # Perform any prefiltering
+ if filter_function:
+ assert callable(filter_function), "filter_function must be callable"
+ DP = DP if filter_function is None else filter_function(DP)
+
+ if predict:
+ assert (
+ len(DP.shape) == 2
+ ), "Dimension of single diffraction should be 2 (Qx, Qy)"
+ assert len(probe.shape) == 2, "Dimension of probe should be 2 (Qx, Qy)"
+ model = _get_latest_model(model_path=model_path)
+ DP = tf.expand_dims(tf.expand_dims(DP, axis=0), axis=-1)
+ probe = tf.expand_dims(tf.expand_dims(probe, axis=0), axis=-1)
+ prediction = np.zeros(shape=(1, DP.shape[1], DP.shape[2], 1))
+
+ for i in tqdmnd(
+ num_attempts,
+ desc="Neural network is predicting atomic potential",
+ unit="ATTEMPTS",
+ unit_scale=True,
+ ):
+ prediction += model.predict([DP, probe])
+ print("Averaging over {} attempts \n".format(num_attempts))
+ pred = prediction[0, :, :, 0] / num_attempts
+ else:
+ assert (
+ len(DP.shape) == 2
+ ), "Dimension of single diffraction should be 2 (Qx, Qy)"
+ pred = DP
+
+ maxima = get_maxima_2D(
+ pred,
+ sigma=sigma,
+ minRelativeIntensity=minRelativeIntensity,
+ minAbsoluteIntensity=minAbsoluteIntensity,
+ edgeBoundary=edgeBoundary,
+ relativeToPeak=relativeToPeak,
+ maxNumPeaks=maxNumPeaks,
+ minSpacing=minPeakSpacing,
+ subpixel=subpixel,
+ upsample_factor=upsample_factor,
+ )
+
+ # maxima_x, maxima_y, maxima_int = _integrate_disks(pred, maxima_x,maxima_y,maxima_int,int_window_radius=int_window_radius)
+
+ # # Make peaks PointList
+ # if peaks is None:
+ # coords = [('qx',float),('qy',float),('intensity',float)]
+ # peaks = PointList(coordinates=coords)
+ # else:
+ # assert(isinstance(peaks,PointList))
+ # peaks.add_tuple_of_nparrays((maxima_x,maxima_y,maxima_int))
+ maxima = QPoints(maxima)
+ return maxima
+
+
+def find_Bragg_disks_aiml_selected(
+ datacube,
+ probe,
+ Rx,
+ Ry,
+ num_attempts=5,
+ int_window_radius=1,
+ batch_size=1,
+ predict=True,
+ sigma=0,
+ edgeBoundary=20,
+ minRelativeIntensity=0.005,
+ minAbsoluteIntensity=0,
+ relativeToPeak=0,
+ minPeakSpacing=60,
+ maxNumPeaks=70,
+ subpixel="multicorr",
+ upsample_factor=16,
+ filter_function=None,
+ model_path=None,
+):
+ """
+ Finds the Bragg disks in the diffraction patterns of datacube at scan positions
+ (Rx,Ry) by AI/ML method. This method utilizes FCU-Net to predict Bragg
+ disks from diffraction images.
+
+ Args:
+ datacube (datacube): a diffraction datacube
+ probe (ndarray): the vacuum probe template
+ num_attempts (int): Number of attempts to predict the Bragg disks. Recommended: 5
+ Ideally, the more num_attempts the better (confident) the prediction will be
+ as the ML prediction utilizes Monte Carlo Dropout technique to estimate model
+ uncertainty using Bayesian approach. Note: increasing num_attempts will increase
+ the compute time significantly and it is advised to use GPU (CUDA) enabled environment
+ for fast prediction with num_attempts > 1
+ int_window_radius (int): window radius (in pixels) for disk intensity integration over the
+ predicted atomic potentials array
+ predict (bool): Flag to determine if ML prediction is opted.
+ edgeBoundary (int): minimum acceptable distance from the DP edge, in pixels
+ minRelativeIntensity (float): the minimum acceptable correlation peak intensity,
+ relative to the intensity of the relativeToPeak'th peak
+ minAbsoluteIntensity (float): the minimum acceptable correlation peak intensity,
+ on an absolute scale
+ relativeToPeak (int): specifies the peak against which the minimum relative
+ intensity is measured -- 0=brightest maximum. 1=next brightest, etc.
+ minPeakSpacing (float): the minimum acceptable spacing between detected peaks
+ maxNumPeaks (int): the maximum number of peaks to return
+ subpixel (str): Whether to use subpixel fitting, and which algorithm to use.
+ Must be in ('none','poly','multicorr').
+ * 'none': performs no subpixel fitting
+ * 'poly': polynomial interpolation of correlogram peaks
+ * 'multicorr': uses the multicorr algorithm with DFT upsampling
+ upsample_factor (int): upsampling factor for subpixel fitting (only used when
+ subpixel='multicorr')
+ filter_function (callable): filtering function to apply to each diffraction
+ pattern before peakfinding. Must be a function of only one argument (the
+ diffraction pattern) and return the filtered diffraction pattern. The shape
+ of the returned DP must match the shape of the probe kernel (but does not
+ need to match the shape of the input diffraction pattern, e.g. the filter
+ can be used to bin the diffraction pattern). If using distributed disk
+ detection, the function must be able to be pickled by dill.
+ peaks (PointList): For internal use. If peaks is None, the PointList of peak
+ positions is created here. If peaks is not None, it is the PointList that
+ detected peaks are added to, and must have the appropriate coords
+ ('qx','qy','intensity').
+ model_path (str): filepath for the model weights (Tensorflow model) to load from.
+ By default, if the model_path is not provided, py4DSTEM will search for the
+ latest model stored on cloud using metadata json file. It is not recommended to
+ keep track of the model path and advised to keep this argument unchanged (None)
+ to always search for the latest updated training model weights.
+
+ Returns:
+ (n-tuple of PointLists, n=len(Rx)): the Bragg peak positions and
+ correlation intensities at each scan position (Rx,Ry).
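+
+ Example (illustrative):
+
+ >>> peaks = find_Bragg_disks_aiml_selected(
+ >>> datacube,
+ >>> probe,
+ >>> Rx = [0, 3, 5],
+ >>> Ry = [0, 3, 5],
+ >>> )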
+ """
+
+ try:
+ import crystal4D
+ except Exception:
+ raise ImportError("Please install crystal4D before proceeding")
+
+ assert len(Rx) == len(Ry)
+ peaks = []
+
+ if predict:
+ model = _get_latest_model(model_path=model_path)
+ t0 = time()
+ probe = np.expand_dims(
+ np.repeat(np.expand_dims(probe, axis=0), len(Rx), axis=0), axis=-1
+ )
+ DP = np.expand_dims(
+ np.expand_dims(datacube.data[Rx[0], Ry[0], :, :], axis=0), axis=-1
+ )
+ total_DP = len(Rx)
+ for i in range(1, len(Rx)):
+ DP_ = np.expand_dims(
+ np.expand_dims(datacube.data[Rx[i], Ry[i], :, :], axis=0), axis=-1
+ )
+ DP = np.concatenate([DP, DP_], axis=0)
+
+ prediction = np.zeros(shape=(total_DP, datacube.Q_Nx, datacube.Q_Ny, 1))
+
+ image_num = len(Rx)
+ batch_num = int(image_num // batch_size)
+
+ for att in tqdmnd(
+ num_attempts,
+ desc="Neural network is predicting structure factors",
+ unit="ATTEMPTS",
+ unit_scale=True,
+ ):
+ for i in range(batch_num):
+ prediction[i * batch_size : (i + 1) * batch_size] += model.predict(
+ [
+ DP[i * batch_size : (i + 1) * batch_size],
+ probe[i * batch_size : (i + 1) * batch_size],
+ ],
+ verbose=0,
+ )
+ # after the batch loop, predict any remainder patterns which did not
+ # fill a complete batch (once per attempt, not once per batch)
+ if batch_num * batch_size < image_num:
+ prediction[batch_num * batch_size :] += model.predict(
+ [DP[batch_num * batch_size :], probe[batch_num * batch_size :]],
+ verbose=0,
+ )
+
+ prediction = prediction / num_attempts
+
+ # Loop over selected diffraction patterns
+ for idx in tqdmnd(
+ image_num, desc="Finding Bragg Disks using AI/ML", unit="DP", unit_scale=True
+ ):
+ DP = prediction[idx, :, :, 0]
+ _peaks = find_Bragg_disks_aiml_single_DP(
+ DP,
+ probe,
+ int_window_radius=int_window_radius,
+ predict=False,
+ sigma=sigma,
+ edgeBoundary=edgeBoundary,
+ minRelativeIntensity=minRelativeIntensity,
+ minAbsoluteIntensity=minAbsoluteIntensity,
+ relativeToPeak=relativeToPeak,
+ minPeakSpacing=minPeakSpacing,
+ maxNumPeaks=maxNumPeaks,
+ subpixel=subpixel,
+ upsample_factor=upsample_factor,
+ filter_function=filter_function,
+ model_path=model_path,
+ )
+ peaks.append(_peaks)
+ t2 = time() - t0
+ print(
+ "Analyzed {} diffraction patterns in {}h {}m {}s".format(
+ image_num, int(t2 // 3600), int((t2 % 3600) // 60), int(t2 % 60)
+ )
+ )
+
+ peaks = tuple(peaks)
+ return peaks
+
+
+def find_Bragg_disks_aiml_serial(
+ datacube,
+ probe,
+ num_attempts=5,
+ int_window_radius=1,
+ predict=True,
+ batch_size=2,
+ sigma=0,
+ edgeBoundary=20,
+ minRelativeIntensity=0.005,
+ minAbsoluteIntensity=0,
+ relativeToPeak=0,
+ minPeakSpacing=60,
+ maxNumPeaks=70,
+ subpixel="multicorr",
+ upsample_factor=16,
+ global_threshold=False,
+ minGlobalIntensity=0.005,
+ metric="mean",
+ filter_function=None,
+ name="braggpeaks_raw",
+ model_path=None,
+):
+ """
+ Finds the Bragg disks in all diffraction patterns of datacube from AI/ML method.
+ When hist = True, returns histogram of intensities in the entire datacube.
+
+ Args:
+ datacube (datacube): a diffraction datacube
+ probe (ndarray): the vacuum probe template
+ num_attempts (int): Number of attempts to predict the Bragg disks. Recommended: 5.
+ Ideally, the more num_attempts the better (confident) the prediction will be
+ as the ML prediction utilizes Monte Carlo Dropout technique to estimate model
+ uncertainty using Bayesian approach. Note: increasing num_attempts will increase
+ the compute time significantly and it is advised to use GPU (CUDA) enabled environment
+ for fast prediction with num_attempts > 1
+ int_window_radius (int): window radius (in pixels) for disk intensity integration over the
+ predicted atomic potentials array
+ predict (bool): Flag to determine if ML prediction is opted.
+ batch_size (int): batch size for the TensorFlow model.predict() function; by default batch_size = 2.
+ Note: if you are using a CPU for model.predict(), please use batch_size < 2. A future version
+ will provide a Dask-parallelized implementation of the serial function to boost the
+ performance of TensorFlow CPU predictions. Keep in mind that this function will take
+ a significant amount of time to predict for all the DPs in a datacube.
+ edgeBoundary (int): minimum acceptable distance from the DP edge, in pixels
+ minRelativeIntensity (float): the minimum acceptable correlation peak intensity,
+ relative to the intensity of the relativeToPeak'th peak
+ minAbsoluteIntensity (float): the minimum acceptable correlation peak intensity,
+ on an absolute scale
+ relativeToPeak (int): specifies the peak against which the minimum relative
+ intensity is measured -- 0=brightest maximum. 1=next brightest, etc.
+ minPeakSpacing (float): the minimum acceptable spacing between detected peaks
+ maxNumPeaks (int): the maximum number of peaks to return
+ subpixel (str): Whether to use subpixel fitting, and which algorithm to use.
+ Must be in ('none','poly','multicorr').
+ * 'none': performs no subpixel fitting
+ * 'poly': polynomial interpolation of correlogram peaks
+ * 'multicorr': uses the multicorr algorithm with DFT upsampling
+ upsample_factor (int): upsampling factor for subpixel fitting (only used when
+ subpixel='multicorr')
+ global_threshold (bool): if True, applies global threshold based on
+ minGlobalIntensity and metric
+ minGlobalIntensity (float): the minimum allowed peak intensity, relative to the
+ selected metric (0-1), except in the case of the 'manual' metric, for which
+ it is an absolute intensity value below which peaks are thresholded out
+ metric (string): the metric used to compare intensities. 'average' compares peak
+ intensity relative to the average of the maximum intensity in each
+ diffraction pattern. 'max' compares peak intensity relative to the maximum
+ intensity value out of all the diffraction patterns. 'median' compares peak
+ intensity relative to the median of the maximum intensity peaks in each
+ diffraction pattern. 'manual' Allows the user to threshold based on a
+ predetermined intensity value manually determined. In this case,
+ minIntensity should be an int.
+ name (str): name for the returned PointListArray
+ filter_function (callable): filtering function to apply to each diffraction
+ pattern before peakfinding. Must be a function of only one argument (the
+ diffraction pattern) and return the filtered diffraction pattern. The
+ shape of the returned DP must match the shape of the probe kernel (but does
+ not need to match the shape of the input diffraction pattern, e.g. the filter
+ can be used to bin the diffraction pattern). If using distributed disk
+ detection, the function must be able to be pickled by dill.
+ model_path (str): filepath for the model weights (Tensorflow model) to load from.
+ By default, if the model_path is not provided, py4DSTEM will search for the
+ latest model stored on cloud using metadata json file. It is not recommended to
+ keep track of the model path and advised to keep this argument unchanged (None)
+ to always search for the latest updated training model weights.
+
+ Returns:
+ (PointListArray): the Bragg peak positions and correlation intensities
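+
+ Example (illustrative):
+
+ >>> braggvectors = find_Bragg_disks_aiml_serial(datacube, probe, num_attempts=5)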
+ """
+
+ try:
+ import crystal4D
+ except Exception:
+ raise ImportError("Please install crystal4D before proceeding")
+
+ # Make the peaks PointListArray
+ # dtype = [('qx',float),('qy',float),('intensity',float)]
+ peaks = BraggVectors(datacube.Rshape, datacube.Qshape)
+
+ # check that the filtered DP is the right size for the probe kernel:
+ if filter_function:
+ assert callable(filter_function), "filter_function must be callable"
+ DP = (
+ datacube.data[0, 0, :, :]
+ if filter_function is None
+ else filter_function(datacube.data[0, 0, :, :])
+ )
+ # assert np.all(DP.shape == probe.shape), 'Probe kernel shape must match filtered DP shape'
+
+ if predict:
+ t0 = time()
+ model = _get_latest_model(model_path=model_path)
+ probe = np.expand_dims(
+ np.repeat(np.expand_dims(probe, axis=0), datacube.R_N, axis=0), axis=-1
+ )
+ DP = np.expand_dims(
+ np.reshape(datacube.data, (datacube.R_N, datacube.Q_Nx, datacube.Q_Ny)),
+ axis=-1,
+ )
+
+ prediction = np.zeros(shape=(datacube.R_N, datacube.Q_Nx, datacube.Q_Ny, 1))
+
+ image_num = datacube.R_N
+ batch_num = int(image_num // batch_size)
+
+ for att in tqdmnd(
+ num_attempts,
+ desc="Neural network is predicting structure factors",
+ unit="ATTEMPTS",
+ unit_scale=True,
+ ):
+ for i in range(batch_num):
+ prediction[i * batch_size : (i + 1) * batch_size] += model.predict(
+ [
+ DP[i * batch_size : (i + 1) * batch_size],
+ probe[i * batch_size : (i + 1) * batch_size],
+ ],
+ verbose=0,
+ )
+ # after the batch loop, predict any remainder patterns which did not
+ # fill a complete batch (once per attempt, not once per batch)
+ if batch_num * batch_size < image_num:
+ prediction[batch_num * batch_size :] += model.predict(
+ [DP[batch_num * batch_size :], probe[batch_num * batch_size :]],
+ verbose=0,
+ )
+
+ prediction = prediction / num_attempts
+
+ prediction = np.reshape(
+ np.transpose(prediction, (0, 3, 1, 2)),
+ (datacube.R_Nx, datacube.R_Ny, datacube.Q_Nx, datacube.Q_Ny),
+ )
+
+ # Loop over all diffraction patterns
+ for Rx, Ry in tqdmnd(
+ datacube.R_Nx,
+ datacube.R_Ny,
+ desc="Finding Bragg Disks using AI/ML",
+ unit="DP",
+ unit_scale=True,
+ ):
+ DP_ = prediction[Rx, Ry, :, :]
+ # store the detected peaks for this scan position in the BraggVectors
+ # instance (the single-DP function returns a QPoints pointlist rather
+ # than appending to `peaks`)
+ peaks._v_uncal[Rx, Ry] = find_Bragg_disks_aiml_single_DP(
+ DP_,
+ probe,
+ num_attempts=num_attempts,
+ int_window_radius=int_window_radius,
+ predict=False,
+ sigma=sigma,
+ edgeBoundary=edgeBoundary,
+ minRelativeIntensity=minRelativeIntensity,
+ minAbsoluteIntensity=minAbsoluteIntensity,
+ relativeToPeak=relativeToPeak,
+ minPeakSpacing=minPeakSpacing,
+ maxNumPeaks=maxNumPeaks,
+ subpixel=subpixel,
+ upsample_factor=upsample_factor,
+ filter_function=filter_function,
+ model_path=model_path,
+ )
+ t2 = time() - t0
+ print(
+ "Analyzed {} diffraction patterns in {}h {}m {}s".format(
+                datacube.R_N, int(t2 // 3600), int(t2 % 3600 // 60), int(t2 % 60)
+ )
+ )
+
+    if global_threshold:
+ from py4DSTEM.braggvectors import universal_threshold
+
+ peaks = universal_threshold(
+ peaks, minGlobalIntensity, metric, minPeakSpacing, maxNumPeaks
+ )
+ peaks.name = name
+ return peaks
+
+
+def find_Bragg_disks_aiml(
+ datacube,
+ probe,
+ num_attempts=5,
+ int_window_radius=1,
+ predict=True,
+ batch_size=8,
+ sigma=0,
+ edgeBoundary=20,
+ minRelativeIntensity=0.005,
+ minAbsoluteIntensity=0,
+ relativeToPeak=0,
+ minPeakSpacing=60,
+ maxNumPeaks=70,
+ subpixel="multicorr",
+ upsample_factor=16,
+ name="braggpeaks_raw",
+ filter_function=None,
+ model_path=None,
+ distributed=None,
+ CUDA=True,
+ **kwargs,
+):
+ """
+ Finds the Bragg disks in all diffraction patterns of datacube by AI/ML method. This method
+ utilizes FCU-Net to predict Bragg disks from diffraction images.
+
+    Args:
+        datacube (datacube): a diffraction datacube
+        probe (ndarray): the vacuum probe template
+        num_attempts (int): Number of attempts to predict the Bragg disks. Recommended: 5.
+            More attempts produce a more confident prediction, as the ML prediction
+            uses the Monte Carlo dropout technique to estimate model uncertainty with
+            a Bayesian approach. Note: increasing num_attempts will increase the
+            compute time significantly, and a GPU (CUDA) enabled environment is
+            advised for fast prediction with num_attempts > 1
+        int_window_radius (int): window radius (in pixels) for disk intensity integration
+            over the predicted atomic potentials array
+        predict (bool): Flag to determine whether ML prediction is performed.
+        batch_size (int): batch size for the Tensorflow model.predict() function, by
+            default batch_size = 8. Note: if you are using a CPU for model.predict(),
+            please use a small batch size. A future version will implement Dask
+            parallelization of the serial function to boost the performance of
+            Tensorflow CPU predictions. Keep in mind that this function will take a
+            significant amount of time to predict for all the DPs in a datacube.
+        edgeBoundary (int): minimum acceptable distance from the DP edge, in pixels
+        minRelativeIntensity (float): the minimum acceptable correlation peak intensity,
+            relative to the intensity of the relativeToPeak'th peak
+        minAbsoluteIntensity (float): the minimum acceptable correlation peak intensity,
+            on an absolute scale
+        relativeToPeak (int): specifies the peak against which the minimum relative
+            intensity is measured -- 0=brightest maximum. 1=next brightest, etc.
+        minPeakSpacing (float): the minimum acceptable spacing between detected peaks
+        maxNumPeaks (int): the maximum number of peaks to return
+        subpixel (str): Whether to use subpixel fitting, and which algorithm to use.
+            Must be in ('none','poly','multicorr').
+            * 'none': performs no subpixel fitting
+            * 'poly': polynomial interpolation of correlogram peaks
+            * 'multicorr': uses the multicorr algorithm with DFT upsampling (default)
+        upsample_factor (int): upsampling factor for subpixel fitting (only used when
+            subpixel='multicorr')
+        global_threshold (bool): if True, applies a global threshold based on
+            minGlobalIntensity and metric
+        minGlobalIntensity (float): the minimum allowed peak intensity, relative to the
+            selected metric (0-1); for the 'manual' metric, this is the absolute
+            minimum intensity threshold.
+        metric (string): the metric used to compare intensities. 'average' compares peak
+            intensity relative to the average of the maximum intensity in each
+            diffraction pattern. 'max' compares peak intensity relative to the maximum
+            intensity value out of all the diffraction patterns. 'median' compares peak
+            intensity relative to the median of the maximum intensity peaks in each
+            diffraction pattern. 'manual' allows the user to threshold based on a
+            predetermined intensity value. In this case, minGlobalIntensity should be
+            set to that absolute intensity value.
+        name (str): name for the returned PointListArray
+        filter_function (callable): filtering function to apply to each diffraction
+            pattern before peakfinding. Must be a function of only one argument (the
+            diffraction pattern) and return the filtered diffraction pattern. The
+            shape of the returned DP must match the shape of the probe kernel (but does
+            not need to match the shape of the input diffraction pattern, e.g. the filter
+            can be used to bin the diffraction pattern). If using distributed disk
+            detection, the function must be able to be pickled with dill.
+        model_path (str): filepath for the model weights (Tensorflow model) to load from.
+            By default, if the model_path is not provided, py4DSTEM will search for the
+            latest model stored on the cloud using a metadata json file. It is
+            recommended to keep this argument unchanged (None) so that the latest
+            trained model weights are always used.
+        distributed (dict): contains information for parallel processing using an
+            IPyParallel or Dask distributed cluster. Valid keys are:
+            * ipyparallel (dict):
+                * client_file (str): path to client json for connecting to your
+                    existing IPyParallel cluster
+            * dask (dict):
+                * client (object): a dask client that connects to your
+                    existing Dask cluster
+            * data_file (str): the absolute path to your original data
+                file containing the datacube
+            * cluster_path (str): defaults to the working directory during processing
+            If distributed is None (the default), processing runs in serial.
+        CUDA (bool): When True, py4DSTEM will use the CUDA-enabled
+            find_Bragg_disks_aiml_CUDA function
+
+ Returns:
+ (PointListArray): the Bragg peak positions and correlation intensities
+ """
+ try:
+ import crystal4D
+    except ImportError:
+ raise ImportError("Please install crystal4D before proceeding")
+
+ def _parse_distributed(distributed):
+ import os
+
+ if "ipyparallel" in distributed:
+ if "client_file" in distributed["ipyparallel"]:
+ connect = distributed["ipyparallel"]["client_file"]
+ else:
+ raise KeyError(
+ 'Within distributed["ipyparallel"], missing key for "client_file"'
+ )
+
+ try:
+ import ipyparallel as ipp
+
+ c = ipp.Client(url_file=connect, timeout=30)
+
+ if len(c.ids) == 0:
+ raise RuntimeError("No IPyParallel engines attached to cluster!")
+ except ImportError:
+ raise ImportError("Unable to import module ipyparallel!")
+ elif "dask" in distributed:
+ if "client" in distributed["dask"]:
+ connect = distributed["dask"]["client"]
+ else:
+ raise KeyError('Within distributed["dask"], missing key for "client"')
+ else:
+ raise KeyError(
+ "Within distributed, you must specify 'ipyparallel' or 'dask'!"
+ )
+
+ if "data_file" not in distributed:
+ raise KeyError(
+ "Missing input data file path to distributed! Required key 'data_file'"
+ )
+
+ data_file = distributed["data_file"]
+
+ if not isinstance(data_file, str):
+ raise TypeError(
+ "Expected string for distributed key 'data_file', received {}".format(
+ type(data_file)
+ )
+ )
+ if len(data_file.strip()) == 0:
+ raise ValueError("Empty data file path from distributed key 'data_file'")
+ elif not os.path.exists(data_file):
+            raise FileNotFoundError("File not found: {}".format(data_file))
+
+ if "cluster_path" in distributed:
+ cluster_path = distributed["cluster_path"]
+
+ if not isinstance(cluster_path, str):
+ raise TypeError(
+ "distributed key 'cluster_path' must be of type str, received {}".format(
+ type(cluster_path)
+ )
+ )
+
+ if len(cluster_path.strip()) == 0:
+ raise ValueError(
+ "distributed key 'cluster_path' cannot be an empty string!"
+ )
+ elif not os.path.exists(cluster_path):
+ raise FileNotFoundError(
+ "distributed key 'cluster_path' does not exist: {}".format(
+ cluster_path
+ )
+ )
+ elif not os.path.isdir(cluster_path):
+ raise NotADirectoryError(
+ "distributed key 'cluster_path' is not a directory: {}".format(
+ cluster_path
+ )
+ )
+ else:
+ cluster_path = None
+
+ return connect, data_file, cluster_path
+
+ if distributed is None:
+ import warnings
+
+ if not CUDA:
+ if _check_cuda_device_available():
+                warnings.warn(
+                    "WARNING: CUDA = False is selected but py4DSTEM found an available CUDA device. Proceeding in non-CUDA (CPU only) mode. You may want to abort and switch to CUDA = True to speed things up. \n"
+                )
+ if num_attempts > 1:
+ warnings.warn(
+ "WARNING: num_attempts > 1 will take significant amount of time with Non-CUDA mode ..."
+ )
+ return find_Bragg_disks_aiml_serial(
+ datacube,
+ probe,
+ num_attempts=num_attempts,
+ int_window_radius=int_window_radius,
+ predict=predict,
+ batch_size=batch_size,
+ sigma=sigma,
+ edgeBoundary=edgeBoundary,
+ minRelativeIntensity=minRelativeIntensity,
+ minAbsoluteIntensity=minAbsoluteIntensity,
+ relativeToPeak=relativeToPeak,
+ minPeakSpacing=minPeakSpacing,
+ maxNumPeaks=maxNumPeaks,
+ subpixel=subpixel,
+ upsample_factor=upsample_factor,
+ model_path=model_path,
+ name=name,
+ filter_function=filter_function,
+ )
+ elif _check_cuda_device_available():
+ from py4DSTEM.braggvectors.diskdetection_aiml_cuda import (
+ find_Bragg_disks_aiml_CUDA,
+ )
+
+ return find_Bragg_disks_aiml_CUDA(
+ datacube,
+ probe,
+ num_attempts=num_attempts,
+ int_window_radius=int_window_radius,
+ predict=predict,
+ batch_size=batch_size,
+ sigma=sigma,
+ edgeBoundary=edgeBoundary,
+ minRelativeIntensity=minRelativeIntensity,
+ minAbsoluteIntensity=minAbsoluteIntensity,
+ relativeToPeak=relativeToPeak,
+ minPeakSpacing=minPeakSpacing,
+ maxNumPeaks=maxNumPeaks,
+ subpixel=subpixel,
+ upsample_factor=upsample_factor,
+ model_path=model_path,
+ name=name,
+ filter_function=filter_function,
+ )
+ else:
+ import warnings
+
+ warnings.warn(
+ "WARNING: py4DSTEM attempted to speed up the process using GPUs but no CUDA enabled devices are found. Switching back to Non-CUDA (CPU only) mode (Note it will take significant amount of time to get AIML predictions for disk detection using CPUs!!!!) \n"
+ )
+ if num_attempts > 1:
+ warnings.warn(
+ "WARNING: num_attempts > 1 will take significant amount of time with Non-CUDA mode ..."
+ )
+ return find_Bragg_disks_aiml_serial(
+ datacube,
+ probe,
+ num_attempts=num_attempts,
+ int_window_radius=int_window_radius,
+ predict=predict,
+ batch_size=batch_size,
+ sigma=sigma,
+ edgeBoundary=edgeBoundary,
+ minRelativeIntensity=minRelativeIntensity,
+ minAbsoluteIntensity=minAbsoluteIntensity,
+ relativeToPeak=relativeToPeak,
+ minPeakSpacing=minPeakSpacing,
+ maxNumPeaks=maxNumPeaks,
+ subpixel=subpixel,
+ upsample_factor=upsample_factor,
+ model_path=model_path,
+ name=name,
+ filter_function=filter_function,
+ )
+
+ elif isinstance(distributed, dict):
+        raise NotImplementedError(
+            "distributed processing is not yet implemented for the AI/ML pipeline"
+        )
+    else:
+        raise TypeError(
+            "Expected type dict or None for distributed, instead found: {}".format(
+                type(distributed)
+            )
+        )
+
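+# A minimal usage sketch for the dispatcher above (not part of the library API;
+# `dc` and `probe_kernel` are assumed to be a loaded DataCube and a vacuum probe
+# template, respectively):
+#
+#     peaks = find_Bragg_disks_aiml(
+#         dc,
+#         probe_kernel,
+#         num_attempts=5,
+#         batch_size=8,
+#         CUDA=True,  # falls back to the serial CPU path if no GPU is found
+#     )
+#     pl = peaks.vectors_uncal.get_pointlist(0, 0)  # peaks at scan position (0, 0)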
+
+def _integrate_disks(DP, maxima_x, maxima_y, maxima_int, int_window_radius=1):
+ """
+    Integrate the DP over a circular patch of pixels of radius int_window_radius
+    around each detected maximum.
+ """
+ disks = []
+ img_size = DP.shape[0]
+ for x, y, i in zip(maxima_x, maxima_y, maxima_int):
+ r1, r2 = np.ogrid[-x : img_size - x, -y : img_size - y]
+ mask = r1**2 + r2**2 <= int_window_radius**2
+ mask_arr = np.zeros((img_size, img_size))
+ mask_arr[mask] = 1
+ disk = DP * mask_arr
+ disks.append(np.average(disk))
+ try:
+ disks = disks / max(disks)
+    except ValueError:
+        # max() raises on an empty list when no disks were detected
+        pass
+ return (maxima_x, maxima_y, disks)
+
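+# The integration above builds, for each maximum, a boolean disk mask via
+# np.ogrid and averages DP over it. A small self-contained check of that mask
+# logic (illustrative only; a 1-pixel radius at the center of a 5x5 array):
+#
+#     r1, r2 = np.ogrid[-2:3, -2:3]
+#     mask = r1**2 + r2**2 <= 1  # True at the center pixel and its 4 neighbors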
+
+def _check_cuda_device_available():
+ """
+ Check if GPU is available to use by python/tensorflow.
+ """
+
+ import tensorflow as tf
+
+ tf_recog_gpus = tf.config.experimental.list_physical_devices("GPU")
+
+    return len(tf_recog_gpus) > 0
+
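+# _check_cuda_device_available relies on TensorFlow's device enumeration. To
+# force the CPU-only code path without changing this module, one can hide GPUs
+# from TensorFlow before calling into py4DSTEM (standard tf.config API):
+#
+#     import tensorflow as tf
+#     tf.config.set_visible_devices([], "GPU")  # TensorFlow will report no GPUs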
+
+def _get_latest_model(model_path=None):
+ """
+ get the latest tensorflow model and model weights for disk detection
+
+ Args:
+ model_path (filepath string): File path for the tensorflow models stored in local system,
+ if provided, disk detection will be performed loading the model provided by user.
+ By default, there is no need to provide any file path unless specifically required for
+            development/debug purposes. If None, _get_latest_model() will look up the
+            latest model on the cloud, then download and load it.
+
+ Returns:
+ model: Trained tensorflow model for disk detection
+ """
+ import crystal4D
+
+ try:
+ import tensorflow as tf
+    except ImportError:
+        raise ImportError(
+            "Please install tensorflow before proceeding - please check "
+            + "https://www.tensorflow.org/install "
+            + "for more information"
+        )
+ from py4DSTEM.io.google_drive_downloader import download_file_from_google_drive
+
+ tf.keras.backend.clear_session()
+
+ if model_path is None:
+        # make a scratch directory for the model metadata, if needed
+        os.makedirs("./tmp", exist_ok=True)
+ # download the json file with the meta data
+ download_file_from_google_drive("FCU-Net", "./tmp/model_metadata.json")
+ with open("./tmp/model_metadata.json") as f:
+ metadata = json.load(f)
+ file_id = metadata["file_id"]
+ file_path = metadata["file_path"]
+ file_type = metadata["file_type"]
+
+ try:
+ with open("./tmp/model_metadata_old.json") as f_old:
+ metaold = json.load(f_old)
+ file_id_old = metaold["file_id"]
+        except (FileNotFoundError, KeyError, json.JSONDecodeError):
+ file_id_old = file_id
+
+ if os.path.exists(file_path) and file_id == file_id_old:
+ print(
+ "Latest model weight is already available in the local system. Loading the model... \n"
+ )
+ model_path = file_path
+            if os.path.exists("./tmp/model_metadata_old.json"):
+                os.remove("./tmp/model_metadata_old.json")
+ os.rename("./tmp/model_metadata.json", "./tmp/model_metadata_old.json")
+ else:
+ print("Checking the latest model on the cloud... \n")
+ filename = file_path + file_type
+ download_file_from_google_drive(file_id, filename)
+ try:
+ shutil.unpack_archive(filename, "./tmp", format="zip")
+            except shutil.ReadError:
+                # the downloaded file may not be a zip archive
+                pass
+ model_path = file_path
+ os.rename("./tmp/model_metadata.json", "./tmp/model_metadata_old.json")
+ print("Loading the model... \n")
+
+ model = tf.keras.models.load_model(
+ model_path,
+ custom_objects={"lrScheduler": crystal4D.utils.utils.lrScheduler(128)},
+ )
+ else:
+ print("Loading the user provided model... \n")
+ model = tf.keras.models.load_model(
+ model_path,
+ custom_objects={"lrScheduler": crystal4D.utils.utils.lrScheduler(128)},
+ )
+
+ return model
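+
+# Usage sketch: loading a locally stored model instead of the cloud lookup
+# (hypothetical path; any SavedModel trained with crystal4D's lrScheduler works):
+#
+#     model = _get_latest_model(model_path="./my_fcunet_model")
+#     pred = model.predict([dp_batch, probe_batch])  # shapes (N, Q_Nx, Q_Ny, 1)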
diff --git a/py4DSTEM/braggvectors/diskdetection_aiml_cuda.py b/py4DSTEM/braggvectors/diskdetection_aiml_cuda.py
new file mode 100644
index 000000000..d23770586
--- /dev/null
+++ b/py4DSTEM/braggvectors/diskdetection_aiml_cuda.py
@@ -0,0 +1,738 @@
+# Functions for finding Bragg disks using AI/ML pipeline (CUDA version)
+"""
+Functions for finding Bragg disks (AI/ML) using cupy and tensorflow-gpu
+"""
+
+import numpy as np
+from time import time
+
+from emdfile import tqdmnd
+from py4DSTEM.braggvectors.braggvectors import BraggVectors
+from emdfile import PointList, PointListArray
+from py4DSTEM.data import QPoints
+from py4DSTEM.braggvectors.kernels import kernels
+from py4DSTEM.braggvectors.diskdetection_aiml import _get_latest_model
+
+# from py4DSTEM.braggvectors.diskdetection import universal_threshold
+
+try:
+ import cupy as cp
+except ModuleNotFoundError:
+ raise ImportError("AIML CUDA Requires cupy")
+
+try:
+ import tensorflow as tf
+except ImportError:
+    raise ImportError(
+        "Please install tensorflow before proceeding - please check "
+        + "https://www.tensorflow.org/install "
+        + "for more information"
+    )
+
+from cupyx.scipy.ndimage import gaussian_filter
+
+
+def find_Bragg_disks_aiml_CUDA(
+ datacube,
+ probe,
+ num_attempts=5,
+ int_window_radius=1,
+ predict=True,
+ batch_size=8,
+ sigma=0,
+ edgeBoundary=20,
+ minRelativeIntensity=0.005,
+ minAbsoluteIntensity=0,
+ relativeToPeak=0,
+ minPeakSpacing=60,
+ maxNumPeaks=70,
+ subpixel="multicorr",
+ upsample_factor=16,
+ global_threshold=False,
+ minGlobalIntensity=0.005,
+ metric="mean",
+ filter_function=None,
+ name="braggpeaks_raw",
+ model_path=None,
+):
+ """
+ Finds the Bragg disks in all diffraction patterns of datacube by AI/ML method (CUDA version)
+ This method utilizes FCU-Net to predict Bragg disks from diffraction images.
+
+ Args:
+ datacube (datacube): a diffraction datacube
+ probe (ndarray): the vacuum probe template
+ num_attempts (int): Number of attempts to predict the Bragg disks. Recommended: 5.
+            More attempts produce a more confident prediction, as the ML prediction
+            uses the Monte Carlo dropout technique to estimate model uncertainty with
+            a Bayesian approach. Note: increasing num_attempts will increase the
+            compute time significantly, and a GPU (CUDA) enabled environment is
+            advised for fast prediction with num_attempts > 1
+        int_window_radius (int): window radius (in pixels) for disk intensity integration
+            over the predicted atomic potentials array
+        predict (bool): Flag to determine whether ML prediction is performed.
+        batch_size (int): batch size for the Tensorflow model.predict() function, by
+            default batch_size = 8. Note: if you are using a CPU for model.predict(),
+            please use a small batch size. A future version will implement Dask
+            parallelization of the serial function to boost the performance of
+            Tensorflow CPU predictions. Keep in mind that this function will take a
+            significant amount of time to predict for all the DPs in a datacube.
+ edgeBoundary (int): minimum acceptable distance from the DP edge, in pixels
+ minRelativeIntensity (float): the minimum acceptable correlation peak intensity,
+ relative to the intensity of the relativeToPeak'th peak
+ minAbsoluteIntensity (float): the minimum acceptable correlation peak intensity,
+ on an absolute scale
+ relativeToPeak (int): specifies the peak against which the minimum relative
+ intensity is measured -- 0=brightest maximum. 1=next brightest, etc.
+ minPeakSpacing (float): the minimum acceptable spacing between detected peaks
+ maxNumPeaks (int): the maximum number of peaks to return
+ subpixel (str): Whether to use subpixel fitting, and which algorithm to use.
+ Must be in ('none','poly','multicorr').
+ * 'none': performs no subpixel fitting
+            * 'poly': polynomial interpolation of correlogram peaks
+            * 'multicorr': uses the multicorr algorithm with DFT upsampling (default)
+ upsample_factor (int): upsampling factor for subpixel fitting (only used when
+ subpixel='multicorr')
+ global_threshold (bool): if True, applies global threshold based on
+ minGlobalIntensity and metric
+        minGlobalIntensity (float): the minimum allowed peak intensity, relative to the
+            selected metric (0-1); for the 'manual' metric, this is the absolute
+            minimum intensity threshold.
+ metric (string): the metric used to compare intensities. 'average' compares peak
+ intensity relative to the average of the maximum intensity in each
+ diffraction pattern. 'max' compares peak intensity relative to the maximum
+ intensity value out of all the diffraction patterns. 'median' compares peak
+ intensity relative to the median of the maximum intensity peaks in each
+            diffraction pattern. 'manual' allows the user to threshold based on a
+            predetermined intensity value. In this case, minGlobalIntensity should be
+            set to that absolute intensity value.
+ name (str): name for the returned PointListArray
+ filter_function (callable): filtering function to apply to each diffraction
+ pattern before peakfinding. Must be a function of only one argument (the
+ diffraction pattern) and return the filtered diffraction pattern. The
+ shape of the returned DP must match the shape of the probe kernel (but does
+ not need to match the shape of the input diffraction pattern, e.g. the filter
+ can be used to bin the diffraction pattern). If using distributed disk
+            detection, the function must be able to be pickled with dill.
+ model_path (str): filepath for the model weights (Tensorflow model) to load from.
+ By default, if the model_path is not provided, py4DSTEM will search for the
+            latest model stored on the cloud using a metadata json file. It is
+            recommended to keep this argument unchanged (None) so that the latest
+            trained model weights are always used.
+
+ Returns:
+ (PointListArray): the Bragg peak positions and correlation intensities
+ """
+
+ # Make the peaks PointListArray
+ # dtype = [('qx',float),('qy',float),('intensity',float)]
+ peaks = BraggVectors(datacube.Rshape, datacube.Qshape)
+
+ # check that the filtered DP is the right size for the probe kernel:
+ if filter_function:
+ assert callable(filter_function), "filter_function must be callable"
+ DP = (
+ datacube.data[0, 0, :, :]
+ if filter_function is None
+ else filter_function(datacube.data[0, 0, :, :])
+ )
+ assert np.all(
+ DP.shape == probe.shape
+ ), "Probe kernel shape must match filtered DP shape"
+
+ get_maximal_points = kernels["maximal_pts_float64"]
+
+ if get_maximal_points.max_threads_per_block < DP.shape[1]:
+ blocks = ((np.prod(DP.shape) // get_maximal_points.max_threads_per_block + 1),)
+        threads = (get_maximal_points.max_threads_per_block,)
+ else:
+ blocks = (DP.shape[0],)
+ threads = (DP.shape[1],)
+
+ if predict:
+ t0 = time()
+ model = _get_latest_model(model_path=model_path)
+ prediction = np.zeros(shape=(datacube.R_N, datacube.Q_Nx, datacube.Q_Ny, 1))
+
+ image_num = datacube.R_N
+        batch_num = int(np.ceil(image_num / batch_size))  # include the final partial batch
+
+ datacube_flattened = datacube.data.view()
+ datacube_flattened = datacube_flattened.reshape(
+ datacube.R_N, datacube.Q_Nx, datacube.Q_Ny
+ )
+
+ for att in tqdmnd(
+ num_attempts,
+ desc="Neural network is predicting structure factors",
+ unit="ATTEMPTS",
+ unit_scale=True,
+ ):
+ for batch_idx in range(batch_num):
+ # the final batch may be smaller than the other ones:
+ probes_remaining = datacube.R_N - (batch_idx * batch_size)
+ this_batch_size = (
+ probes_remaining if probes_remaining < batch_size else batch_size
+ )
+ DP = tf.expand_dims(
+ datacube_flattened[
+ batch_idx * batch_size : batch_idx * batch_size
+ + this_batch_size
+ ],
+ axis=-1,
+ )
+ _probe = tf.expand_dims(
+ tf.repeat(tf.expand_dims(probe, axis=0), this_batch_size, axis=0),
+ axis=-1,
+ )
+ prediction[
+ batch_idx * batch_size : batch_idx * batch_size + this_batch_size
+ ] += model.predict([DP, _probe])
+
+ print("Averaging over {} attempts \n".format(num_attempts))
+ prediction = prediction / num_attempts
+
+ prediction = np.reshape(
+ np.transpose(prediction, (0, 3, 1, 2)),
+ (datacube.R_Nx, datacube.R_Ny, datacube.Q_Nx, datacube.Q_Ny),
+ )
+
+ # Loop over all diffraction patterns
+ for Rx, Ry in tqdmnd(
+ datacube.R_Nx,
+ datacube.R_Ny,
+ desc="Finding Bragg Disks using AI/ML CUDA",
+ unit="DP",
+ unit_scale=True,
+ ):
+ DP = prediction[Rx, Ry, :, :]
+ _find_Bragg_disks_aiml_single_DP_CUDA(
+ DP,
+ probe,
+ num_attempts=num_attempts,
+ int_window_radius=int_window_radius,
+ predict=False,
+ sigma=sigma,
+ edgeBoundary=edgeBoundary,
+ minRelativeIntensity=minRelativeIntensity,
+ minAbsoluteIntensity=minAbsoluteIntensity,
+ relativeToPeak=relativeToPeak,
+ minPeakSpacing=minPeakSpacing,
+ maxNumPeaks=maxNumPeaks,
+ subpixel=subpixel,
+ upsample_factor=upsample_factor,
+ filter_function=filter_function,
+ peaks=peaks.vectors_uncal.get_pointlist(Rx, Ry),
+ get_maximal_points=get_maximal_points,
+ blocks=blocks,
+ threads=threads,
+ )
+ t2 = time() - t0
+ print(
+ "Analyzed {} diffraction patterns in {}h {}m {}s".format(
+            datacube.R_N, int(t2 // 3600), int(t2 % 3600 // 60), int(t2 % 60)
+ )
+ )
+    if global_threshold:
+ from py4DSTEM.braggvectors import universal_threshold
+
+ peaks = universal_threshold(
+ peaks, minGlobalIntensity, metric, minPeakSpacing, maxNumPeaks
+ )
+ peaks.name = name
+ return peaks
+
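+# Usage sketch (illustrative; normally reached via find_Bragg_disks_aiml with
+# CUDA=True rather than called directly -- `dc` and `probe_kernel` are assumed
+# to be a loaded DataCube and a vacuum probe template):
+#
+#     peaks = find_Bragg_disks_aiml_CUDA(dc, probe_kernel, num_attempts=5)
+#     print(peaks.vectors_uncal.get_pointlist(0, 0).data)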
+
+def _find_Bragg_disks_aiml_single_DP_CUDA(
+ DP,
+ probe,
+ num_attempts=5,
+ int_window_radius=1,
+ predict=True,
+ sigma=0,
+ edgeBoundary=20,
+ minRelativeIntensity=0.005,
+ minAbsoluteIntensity=0,
+ relativeToPeak=0,
+ minPeakSpacing=60,
+ maxNumPeaks=70,
+ subpixel="multicorr",
+ upsample_factor=16,
+ filter_function=None,
+ return_cc=False,
+ peaks=None,
+ get_maximal_points=None,
+ blocks=None,
+ threads=None,
+ model_path=None,
+ **kwargs,
+):
+ """
+ Finds the Bragg disks in single DP by AI/ML method. This method utilizes FCU-Net
+ to predict Bragg disks from diffraction images.
+
+ The input DP and Probes need to be aligned before the prediction. Detected peaks within
+ edgeBoundary pixels of the diffraction plane edges are then discarded. Next, peaks
+ with intensities less than minRelativeIntensity of the brightest peak in the
+ correlation are discarded. Then peaks which are within a distance of minPeakSpacing
+ of their nearest neighbor peak are found, and in each such pair the peak with the
+ lesser correlation intensities is removed. Finally, if the number of peaks remaining
+ exceeds maxNumPeaks, only the maxNumPeaks peaks with the highest correlation
+ intensity are retained.
+
+ Args:
+ DP (ndarray): a diffraction pattern
+ probe (ndarray): the vacuum probe template
+ num_attempts (int): Number of attempts to predict the Bragg disks. Recommended: 5
+            More attempts produce a more confident prediction, as the ML prediction
+            uses the Monte Carlo dropout technique to estimate model uncertainty with
+            a Bayesian approach. Note: increasing num_attempts will increase the
+            compute time significantly, and a GPU (CUDA) enabled environment is
+            advised for fast prediction with num_attempts > 1
+ int_window_radius (int): window radius (in pixels) for disk intensity integration over the
+ predicted atomic potentials array
+        predict (bool): Flag to determine whether ML prediction is performed.
+ edgeBoundary (int): minimum acceptable distance from the DP edge, in pixels
+ minRelativeIntensity (float): the minimum acceptable correlation peak intensity,
+ relative to the intensity of the relativeToPeak'th peak
+ minAbsoluteIntensity (float): the minimum acceptable correlation peak intensity,
+ on an absolute scale
+ relativeToPeak (int): specifies the peak against which the minimum relative
+ intensity is measured -- 0=brightest maximum. 1=next brightest, etc.
+ minPeakSpacing (float): the minimum acceptable spacing between detected peaks
+ maxNumPeaks (int): the maximum number of peaks to return
+ subpixel (str): Whether to use subpixel fitting, and which algorithm to use.
+ Must be in ('none','poly','multicorr').
+ * 'none': performs no subpixel fitting
+            * 'poly': polynomial interpolation of correlogram peaks
+            * 'multicorr': uses the multicorr algorithm with DFT upsampling (default)
+ upsample_factor (int): upsampling factor for subpixel fitting (only used when
+ subpixel='multicorr')
+ filter_function (callable): filtering function to apply to each diffraction
+ pattern before peakfinding. Must be a function of only one argument (the
+ diffraction pattern) and return the filtered diffraction pattern. The shape
+ of the returned DP must match the shape of the probe kernel (but does not
+ need to match the shape of the input diffraction pattern, e.g. the filter
+ can be used to bin the diffraction pattern). If using distributed disk
+            detection, the function must be able to be pickled with dill.
+ peaks (PointList): For internal use. If peaks is None, the PointList of peak
+ positions is created here. If peaks is not None, it is the PointList that
+ detected peaks are added to, and must have the appropriate coords
+ ('qx','qy','intensity').
+ model_path (str): filepath for the model weights (Tensorflow model) to load from.
+ By default, if the model_path is not provided, py4DSTEM will search for the
+            latest model stored on the cloud using a metadata json file. It is
+            recommended to keep this argument unchanged (None) so that the latest
+            trained model weights are always used.
+
+ Returns:
+ peaks (PointList) the Bragg peak positions and correlation intensities
+ """
+ assert subpixel in [
+ "none",
+ "poly",
+ "multicorr",
+ ], "Unrecognized subpixel option {}, subpixel must be 'none', 'poly', or 'multicorr'".format(
+ subpixel
+ )
+
+ if predict:
+ assert (
+ len(DP.shape) == 2
+ ), "Dimension of single diffraction should be 2 (Qx, Qy)"
+ assert len(probe.shape) == 2, "Dimension of Probe should be 2 (Qx, Qy)"
+
+ model = _get_latest_model(model_path=model_path)
+ DP = tf.expand_dims(tf.expand_dims(DP, axis=0), axis=-1)
+ probe = tf.expand_dims(tf.expand_dims(probe, axis=0), axis=-1)
+ prediction = np.zeros(shape=(1, DP.shape[1], DP.shape[2], 1))
+
+ for att in tqdmnd(
+ num_attempts,
+ desc="Neural network is predicting structure factors",
+ unit="ATTEMPTS",
+ unit_scale=True,
+ ):
+ print("attempt {} \n".format(att + 1))
+ prediction += model.predict([DP, probe])
+ print("Averaging over {} attempts \n".format(num_attempts))
+ pred = cp.array(prediction[0, :, :, 0] / num_attempts, dtype="float64")
+ else:
+ assert (
+ len(DP.shape) == 2
+ ), "Dimension of single diffraction should be 2 (Qx, Qy)"
+ pred = cp.array(
+ DP if filter_function is None else filter_function(DP), dtype="float64"
+ )
+
+ # Find the maxima
+ maxima_x, maxima_y, maxima_int = get_maxima_2D_cp(
+ pred,
+ sigma=sigma,
+ edgeBoundary=edgeBoundary,
+ minRelativeIntensity=minRelativeIntensity,
+ minAbsoluteIntensity=minAbsoluteIntensity,
+ relativeToPeak=relativeToPeak,
+ minSpacing=minPeakSpacing,
+ maxNumPeaks=maxNumPeaks,
+ subpixel=subpixel,
+ upsample_factor=upsample_factor,
+ get_maximal_points=get_maximal_points,
+ blocks=blocks,
+ threads=threads,
+ )
+
+ maxima_x, maxima_y, maxima_int = _integrate_disks_cp(
+ pred, maxima_x, maxima_y, maxima_int, int_window_radius=int_window_radius
+ )
+
+ # Make peaks PointList
+ if peaks is None:
+ coords = [("qx", float), ("qy", float), ("intensity", float)]
+ peaks = PointList(coordinates=coords)
+ else:
+ assert isinstance(peaks, PointList)
+ peaks.add_data_by_field((maxima_x, maxima_y, maxima_int))
+
+ return peaks
+
+
+def get_maxima_2D_cp(
+ ar,
+ sigma=0,
+ edgeBoundary=0,
+ minSpacing=0,
+ minRelativeIntensity=0,
+ minAbsoluteIntensity=0,
+ relativeToPeak=0,
+ maxNumPeaks=0,
+ subpixel="poly",
+ ar_FT=None,
+ upsample_factor=16,
+ get_maximal_points=None,
+ blocks=None,
+ threads=None,
+):
+ """
+ Finds the indices where the 2D array ar is a local maximum.
+ Optional parameters allow blurring of the array and filtering of the output;
+ setting each of these to 0 (default) turns off these functions.
+
+ Accepts:
+ ar (ndarray) a 2D array
+        sigma (float) gaussian blur std to apply to ar before finding the maxima
+ edgeBoundary (int) ignore maxima within edgeBoundary of the array edge
+ minSpacing (float) if two maxima are found within minSpacing, the dimmer one
+ is removed
+ minRelativeIntensity (float) maxima dimmer than minRelativeIntensity compared to the
+ relativeToPeak'th brightest maximum are removed
+ minAbsoluteIntensity (float) the minimum acceptable correlation peak intensity,
+ on an absolute scale
+ relativeToPeak (int) 0=brightest maximum. 1=next brightest, etc.
+ maxNumPeaks (int) return only the first maxNumPeaks maxima
+ subpixel (str) 'none': no subpixel fitting
+ (default) 'poly': polynomial interpolation of correlogram peaks
+ (fairly fast but not very accurate)
+ 'multicorr': uses the multicorr algorithm with
+ DFT upsampling
+ ar_FT (None or complex array) if subpixel=='multicorr' the
+ fourier transform of the image is required. It may be
+ passed here as a complex array. Otherwise, if ar_FT is None,
+ it is computed
+ upsample_factor (int) required iff subpixel=='multicorr'
+
+ Returns
+ maxima_x (ndarray) x-coords of the local maximum, sorted by intensity.
+ maxima_y (ndarray) y-coords of the local maximum, sorted by intensity.
+ maxima_intensity (ndarray) intensity of the local maxima
+ """
+ assert subpixel in [
+ "none",
+ "poly",
+ "multicorr",
+ ], "Unrecognized subpixel option {}, subpixel must be 'none', 'poly', or 'multicorr'".format(
+ subpixel
+ )
+
+ # Get maxima
+ ar = gaussian_filter(ar, sigma)
+ maxima_bool = cp.zeros_like(ar, dtype=bool)
+ sizex = ar.shape[0]
+ sizey = ar.shape[1]
+ N = sizex * sizey
+ get_maximal_points(
+ blocks, threads, (ar, maxima_bool, minAbsoluteIntensity, sizex, sizey, N)
+ )
+ # get_maximal_points(blocks,threads,(ar,maxima_bool,sizex,sizey,N))
+
+ # Remove edges
+ if edgeBoundary > 0:
+ assert isinstance(edgeBoundary, (int, np.integer))
+ maxima_bool[:edgeBoundary, :] = False
+ maxima_bool[-edgeBoundary:, :] = False
+ maxima_bool[:, :edgeBoundary] = False
+ maxima_bool[:, -edgeBoundary:] = False
+    elif subpixel != "none":  # keep a 1-pixel border free for subpixel fitting
+ maxima_bool[:1, :] = False
+ maxima_bool[-1:, :] = False
+ maxima_bool[:, :1] = False
+ maxima_bool[:, -1:] = False
+
+ # Get indices, sorted by intensity
+ maxima_x, maxima_y = cp.nonzero(maxima_bool)
+ maxima_x = maxima_x.get()
+ maxima_y = maxima_y.get()
+ dtype = np.dtype([("x", float), ("y", float), ("intensity", float)])
+ maxima = np.zeros(len(maxima_x), dtype=dtype)
+ maxima["x"] = maxima_x
+ maxima["y"] = maxima_y
+
+ ar = ar.get()
+ maxima["intensity"] = ar[maxima_x, maxima_y]
+ maxima = np.sort(maxima, order="intensity")[::-1]
+
+ if len(maxima) > 0:
+ # Remove maxima which are too close
+ if minSpacing > 0:
+ deletemask = np.zeros(len(maxima), dtype=bool)
+ for i in range(len(maxima)):
+                if not deletemask[i]:
+ tooClose = (
+ (maxima["x"] - maxima["x"][i]) ** 2
+ + (maxima["y"] - maxima["y"][i]) ** 2
+ ) < minSpacing**2
+ tooClose[: i + 1] = False
+ deletemask[tooClose] = True
+ maxima = np.delete(maxima, np.nonzero(deletemask)[0])
+
+ # Remove maxima which are too dim
+ if (minRelativeIntensity > 0) & (len(maxima) > relativeToPeak):
+ assert isinstance(relativeToPeak, (int, np.integer))
+ deletemask = (
+ maxima["intensity"] / maxima["intensity"][relativeToPeak]
+ < minRelativeIntensity
+ )
+ maxima = np.delete(maxima, np.nonzero(deletemask)[0])
+
+ # Remove maxima which are too dim, absolute scale
+ if minAbsoluteIntensity > 0:
+ deletemask = maxima["intensity"] < minAbsoluteIntensity
+ maxima = np.delete(maxima, np.nonzero(deletemask)[0])
+
+ # Remove maxima in excess of maxNumPeaks
+ if maxNumPeaks > 0:
+ assert isinstance(maxNumPeaks, (int, np.integer))
+ if len(maxima) > maxNumPeaks:
+ maxima = maxima[:maxNumPeaks]
+
+ # Subpixel fitting
+ # For all subpixel fitting, first fit 1D parabolas in x and y to 3 points (maximum, +/- 1 pixel)
+ if subpixel != "none":
+ for i in range(len(maxima)):
+ Ix1_ = ar[int(maxima["x"][i]) - 1, int(maxima["y"][i])]
+ Ix0 = ar[int(maxima["x"][i]), int(maxima["y"][i])]
+ Ix1 = ar[int(maxima["x"][i]) + 1, int(maxima["y"][i])]
+ Iy1_ = ar[int(maxima["x"][i]), int(maxima["y"][i]) - 1]
+ Iy0 = ar[int(maxima["x"][i]), int(maxima["y"][i])]
+ Iy1 = ar[int(maxima["x"][i]), int(maxima["y"][i]) + 1]
+ deltax = (Ix1 - Ix1_) / (4 * Ix0 - 2 * Ix1 - 2 * Ix1_)
+ deltay = (Iy1 - Iy1_) / (4 * Iy0 - 2 * Iy1 - 2 * Iy1_)
+ maxima["x"][i] += deltax
+ maxima["y"][i] += deltay
+ maxima["intensity"][i] = linear_interpolation_2D_cp(
+ ar, maxima["x"][i], maxima["y"][i]
+ )
+ # Further refinement with fourier upsampling
+ if subpixel == "multicorr":
+ if ar_FT is None:
+ ar_FT = cp.conj(cp.fft.fft2(cp.array(ar)))
+ else:
+ ar_FT = cp.conj(ar_FT)
+ for ipeak in range(len(maxima["x"])):
+ xyShift = np.array((maxima["x"][ipeak], maxima["y"][ipeak]))
+ # we actually have to lose some precision and go down to half-pixel
+ # accuracy. this could also be done by a single upsampling at factor 2
+ # instead of get_maxima_2D_cp.
+ xyShift[0] = np.round(xyShift[0] * 2) / 2
+ xyShift[1] = np.round(xyShift[1] * 2) / 2
+
+ subShift = upsampled_correlation_cp(ar_FT, upsample_factor, xyShift)
+ maxima["x"][ipeak] = subShift[0]
+ maxima["y"][ipeak] = subShift[1]
+
+ return maxima["x"], maxima["y"], maxima["intensity"]
+
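+# Sanity check of the 1D parabolic refinement used above (pure-numpy sketch,
+# values chosen for illustration): for three samples around a maximum, the
+# vertex offset pulls the estimate toward the brighter neighbor.
+#
+#     Ix1_, Ix0, Ix1 = 1.0, 2.0, 1.5
+#     deltax = (Ix1 - Ix1_) / (4 * Ix0 - 2 * Ix1 - 2 * Ix1_)  # 0.5 / 3 ~ 0.167
+#     # for the symmetric triple (1, 2, 1), deltax is exactly 0 (no shift)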
+
+def upsampled_correlation_cp(imageCorr, upsampleFactor, xyShift):
+ """
+ Refine the correlation peak of imageCorr around xyShift by DFT upsampling using cupy.
+
+ Args:
+ imageCorr (complex valued ndarray):
+ Complex product of the FFTs of the two images to be registered
+ i.e. m = np.fft.fft2(DP) * probe_kernel_FT;
+ imageCorr = np.abs(m)**(corrPower) * np.exp(1j*np.angle(m))
+ upsampleFactor (int):
+ Upsampling factor. Must be greater than 2. (To do upsampling
+ with factor 2, use upsampleFFT, which is faster.)
+ xyShift:
+ Location in original image coordinates around which to upsample the
+ FT. This should be given to exactly half-pixel precision to
+ replicate the initial FFT step that this implementation skips
+
+ Returns:
+ (2-element np array): Refined location of the peak in image coordinates.
+ """
+
+ # -------------------------------------------------------------------------------------
+ # There are two approaches to Fourier upsampling for subpixel refinement: (a) one
+ # can pad an (appropriately shifted) FFT with zeros and take the inverse transform,
+ # or (b) one can compute the DFT by matrix multiplication using modified
+ # transformation matrices. The former approach is straightforward but requires
+ # performing the FFT algorithm (which is fast) on very large data. The latter method
+ # trades one speedup for a slowdown elsewhere: the matrix multiply steps are expensive
+ # but we operate on smaller matrices. Since we are only interested in a very small
+ # region of the FT around a peak of interest, we use the latter method to get
+ # a substantial speedup and enormous decrease in memory requirement. This
+ # "DFT upsampling" approach computes the transformation matrices for the matrix-
+ # multiply DFT around a small 1.5px wide region in the original `imageCorr`.
+
+ # Following the matrix multiply DFT we use parabolic subpixel fitting to
+ # get even more precision! (below 1/upsampleFactor pixels)
+
+ # NOTE: previous versions of multiCorr operated in two steps: using the zero-
+ # padding upsample method for a first-pass factor-2 upsampling, followed by the
+ # DFT upsampling (at whatever user-specified factor). I have implemented it
+ # differently, to better support iterating over multiple peaks. **The DFT is always
+ # upsampled around xyShift, which MUST be specified to HALF-PIXEL precision
+ # (no more, no less) to replicate the behavior of the factor-2 step.**
+ # (It is possible to refactor this so that peak detection is done on a Fourier
+ # upsampled image rather than using the parabolic subpixel and rounding as now...
+ # I like keeping it this way because all of the parameters and logic will be identical
+ # to the other subpixel methods.)
+ # -------------------------------------------------------------------------------------
+
+ assert upsampleFactor > 2
+
+ xyShift[0] = np.round(xyShift[0] * upsampleFactor) / upsampleFactor
+ xyShift[1] = np.round(xyShift[1] * upsampleFactor) / upsampleFactor
+
+ globalShift = np.fix(np.ceil(upsampleFactor * 1.5) / 2)
+
+ upsampleCenter = globalShift - upsampleFactor * xyShift
+
+ imageCorrUpsample = cp.conj(
+ dftUpsample_cp(imageCorr, upsampleFactor, upsampleCenter)
+ ).get()
+
+ xySubShift = np.unravel_index(imageCorrUpsample.argmax(), imageCorrUpsample.shape)
+
+ # add a subpixel shift via parabolic fitting
+ try:
+ icc = np.real(
+ imageCorrUpsample[
+ xySubShift[0] - 1 : xySubShift[0] + 2,
+ xySubShift[1] - 1 : xySubShift[1] + 2,
+ ]
+ )
+ dx = (icc[2, 1] - icc[0, 1]) / (4 * icc[1, 1] - 2 * icc[2, 1] - 2 * icc[0, 1])
+ dy = (icc[1, 2] - icc[1, 0]) / (4 * icc[1, 1] - 2 * icc[1, 2] - 2 * icc[1, 0])
+    except IndexError:
+        # the peak is near the edge and one of the values above does not exist
+        dx, dy = 0, 0
+
+ xySubShift = xySubShift - globalShift
+
+ xyShift = xyShift + (xySubShift + np.array([dx, dy])) / upsampleFactor
+
+ return xyShift
+
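+# Note on the half-pixel precondition documented above: callers should snap the
+# peak estimate before invoking upsampled_correlation_cp, e.g.
+#
+#     xyShift = np.array([12.3, 40.8])
+#     xyShift = np.round(xyShift * 2) / 2  # -> [12.5, 41.0], half-pixel precision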
+
+def dftUpsample_cp(imageCorr, upsampleFactor, xyShift):
+ """
+    This performs a matrix multiply DFT around a small neighboring region of the initial
+    correlation peak. By using the matrix multiply DFT to do the Fourier upsampling, the
+    efficiency is greatly improved. This is adapted from the subfunction dftups found in
+    the dftregistration function on the Matlab File Exchange.
+
+ https://www.mathworks.com/matlabcentral/fileexchange/18401-efficient-subpixel-image-registration-by-cross-correlation
+
+ The matrix multiplication DFT is from:
+
+ Manuel Guizar-Sicairos, Samuel T. Thurman, and James R. Fienup, "Efficient subpixel
+ image registration algorithms," Opt. Lett. 33, 156-158 (2008).
+ http://www.sciencedirect.com/science/article/pii/S0045790612000778
+
+ Args:
+ imageCorr (complex valued ndarray):
+ Correlation image between two images in Fourier space.
+ upsampleFactor (int):
+ Scalar integer of how much to upsample.
+ xyShift (list of 2 floats):
+ Coordinates in the UPSAMPLED GRID around which to upsample.
+ These must be single-pixel IN THE UPSAMPLED GRID
+
+ Returns:
+ (ndarray):
+ Upsampled image from region around correlation peak.
+ """
+ imageSize = imageCorr.shape
+ pixelRadius = 1.5
+ numRow = np.ceil(pixelRadius * upsampleFactor)
+ numCol = numRow
+
+ colKern = cp.exp(
+ (-1j * 2 * cp.pi / (imageSize[1] * upsampleFactor))
+ * cp.outer(
+ (cp.fft.ifftshift((cp.arange(imageSize[1]))) - cp.floor(imageSize[1] / 2)),
+ (cp.arange(numCol) - xyShift[1]),
+ )
+ )
+
+ rowKern = cp.exp(
+ (-1j * 2 * cp.pi / (imageSize[0] * upsampleFactor))
+ * cp.outer(
+ (cp.arange(numRow) - xyShift[0]),
+ (cp.fft.ifftshift(cp.arange(imageSize[0])) - cp.floor(imageSize[0] / 2)),
+ )
+ )
+
+ imageUpsample = cp.real(rowKern @ imageCorr @ colKern)
+ return imageUpsample
+
+
+def linear_interpolation_2D_cp(ar, x, y):
+ """
+ Calculates the 2D linear interpolation of array ar at position x,y using the four
+ nearest array elements.
+ """
+ x0, x1 = int(np.floor(x)), int(np.ceil(x))
+ y0, y1 = int(np.floor(y)), int(np.ceil(y))
+ dx = x - x0
+ dy = y - y0
+ return (
+ (1 - dx) * (1 - dy) * ar[x0, y0]
+ + (1 - dx) * dy * ar[x0, y1]
+ + dx * (1 - dy) * ar[x1, y0]
+ + dx * dy * ar[x1, y1]
+ )
+
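+# Worked example for the bilinear interpolation above: with
+# ar = np.array([[0.0, 1.0], [2.0, 3.0]]), linear_interpolation_2D_cp(ar, 0.5, 0.5)
+# weights the four corners by 0.25 each and returns 1.5.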
+
+def _integrate_disks_cp(DP, maxima_x, maxima_y, maxima_int, int_window_radius=1):
+ disks = []
+ DP = cp.asnumpy(DP)
+ img_size = DP.shape[0]
+ for x, y, i in zip(maxima_x, maxima_y, maxima_int):
+ r1, r2 = np.ogrid[-x : img_size - x, -y : img_size - y]
+ mask = r1**2 + r2**2 <= int_window_radius**2
+ mask_arr = np.zeros((img_size, img_size))
+ mask_arr[mask] = 1
+ disk = DP * mask_arr
+ disks.append(np.average(disk))
+ try:
+ disks = disks / max(disks)
+ except:
+ pass
+ return (maxima_x, maxima_y, disks)
diff --git a/py4DSTEM/braggvectors/diskdetection_cuda.py b/py4DSTEM/braggvectors/diskdetection_cuda.py
new file mode 100644
index 000000000..670361b3b
--- /dev/null
+++ b/py4DSTEM/braggvectors/diskdetection_cuda.py
@@ -0,0 +1,715 @@
+"""
+Functions for finding Bragg disks using cupy
+"""
+
+import numpy as np
+import cupy as cp
+from cupyx.scipy.ndimage import gaussian_filter
+import cupyx.scipy.fft as cufft
+from time import time
+import numba
+
+from emdfile import tqdmnd
+from py4DSTEM import PointList, PointListArray
+from py4DSTEM.braggvectors.kernels import kernels
+
+
+def find_Bragg_disks_CUDA(
+ datacube,
+ probe,
+ corrPower=1,
+ sigma=2,
+ edgeBoundary=20,
+ minRelativeIntensity=0.005,
+ minAbsoluteIntensity=0.0,
+ relativeToPeak=0,
+ minPeakSpacing=60,
+ maxNumPeaks=70,
+ subpixel="multicorr",
+ upsample_factor=16,
+ filter_function=None,
+ name="braggpeaks_raw",
+ batching=True,
+):
+ """
+    Finds the Bragg disks in all diffraction patterns of datacube by cross, hybrid, or
+    phase correlation with probe.
+
+    Args:
+        datacube (datacube): a diffraction datacube
+ probe (ndarray): the vacuum probe template, in real space.
+ corrPower (float between 0 and 1, inclusive): the cross correlation power. A
+            value of 1 corresponds to a cross correlation, and 0 corresponds to a
+ phase correlation, with intermediate values giving various hybrids.
+ sigma (float): the standard deviation for the gaussian smoothing applied to
+ the cross correlation
+ edgeBoundary (int): minimum acceptable distance from the DP edge, in pixels
+        minRelativeIntensity (float): the minimum acceptable correlation peak intensity,
+            relative to the intensity of the relativeToPeak'th peak
+        minAbsoluteIntensity (float): the minimum acceptable correlation peak intensity,
+            on an absolute scale
+        relativeToPeak (int): specifies the peak against which the minimum relative
+            intensity is measured -- 0=brightest maximum. 1=next brightest, etc.
+ minPeakSpacing (float): the minimum acceptable spacing between detected peaks
+ maxNumPeaks (int): the maximum number of peaks to return
+ subpixel (str): Whether to use subpixel fitting, and which algorithm to use.
+ Must be in ('none','poly','multicorr').
+ * 'none': performs no subpixel fitting
+            * 'poly': polynomial interpolation of correlogram peaks
+            * 'multicorr': uses the multicorr algorithm with DFT upsampling (default)
+ upsample_factor (int): upsampling factor for subpixel fitting (only used when
+ subpixel='multicorr')
+ name (str): name for the returned PointListArray
+ filter_function (callable): filtering function to apply to each diffraction
+ pattern before peakfinding. Must be a function of only one argument (the
+ diffraction pattern) and return the filtered diffraction pattern. The
+ shape of the returned DP must match the shape of the probe kernel (but does
+ not need to match the shape of the input diffraction pattern, e.g. the filter
+ can be used to bin the diffraction pattern). If using distributed disk
+            detection, the function must be able to be pickled with dill.
+ batching (bool): Whether to batch the FFT cross correlation steps.
+
+ Returns:
+ (PointListArray): the Bragg peak positions and correlation intensities
+ """
+
+ # Make the peaks PointListArray
+ coords = [("qx", float), ("qy", float), ("intensity", float)]
+ peaks = PointListArray(dtype=coords, shape=(datacube.R_Nx, datacube.R_Ny))
+
+ # check that the filtered DP is the right size for the probe kernel:
+ if filter_function:
+ assert callable(filter_function), "filter_function must be callable"
+ DP = (
+ datacube.data[0, 0, :, :]
+ if filter_function is None
+ else filter_function(datacube.data[0, 0, :, :])
+ )
+ assert np.all(
+ DP.shape == probe.shape
+ ), "Probe kernel shape must match filtered DP shape"
+
+ # Get the probe kernel FT as a cupy array
+ probe_kernel_FT = cp.conj(cp.fft.fft2(cp.array(probe))).astype(cp.complex64)
+ bytes_per_pattern = probe_kernel_FT.nbytes
+
+ # get the maximal array kernel
+ # if probe_kernel_FT.dtype == 'float64':
+ # get_maximal_points = kernels['maximal_pts_float64']
+ # elif probe_kernel_FT.dtype == 'float32':
+ # get_maximal_points = kernels['maximal_pts_float32']
+ # else:
+ # raise TypeError("Maximal kernel only valid for float32 and float64 types...")
+ get_maximal_points = kernels["maximal_pts_float32"]
+
+ if get_maximal_points.max_threads_per_block < DP.shape[1]:
+ # naive blocks/threads will not work, figure out an OK distribution
+ blocks = ((np.prod(DP.shape) // get_maximal_points.max_threads_per_block + 1),)
+ threads = (get_maximal_points.max_threads_per_block,)
+ else:
+ blocks = (DP.shape[0],)
+ threads = (DP.shape[1],)
+
+ t0 = time()
+ if batching:
+ # compute the batch size based on available VRAM:
+ max_num_bytes = cp.cuda.Device().mem_info[0]
+ # use a fudge factor to leave room for the fourier transformed data
+ # I have set this at 10, which results in underutilization of
+ # VRAM, because this yielded better performance in my testing
+ batch_size = max_num_bytes // (bytes_per_pattern * 10)
+ num_batches = datacube.R_N // batch_size + 1
+
+ print(f"Using {num_batches} batches of {batch_size} patterns each...")
+
+ # allocate array for batch of DPs
+ batched_subcube = cp.zeros(
+ (batch_size, datacube.Q_Nx, datacube.Q_Ny), dtype=cp.float32
+ )
+
+ for batch_idx in tqdmnd(
+ range(num_batches), desc="Finding Bragg disks in batches", unit="batch"
+ ):
+ # the final batch may be smaller than the other ones:
+ probes_remaining = datacube.R_N - (batch_idx * batch_size)
+ this_batch_size = (
+ probes_remaining if probes_remaining < batch_size else batch_size
+ )
+
+ # fill in diffraction patterns, with filtering
+ for subbatch_idx in range(this_batch_size):
+ patt_idx = batch_idx * batch_size + subbatch_idx
+ rx, ry = np.unravel_index(patt_idx, (datacube.R_Nx, datacube.R_Ny))
+ batched_subcube[subbatch_idx, :, :] = cp.array(
+ datacube.data[rx, ry, :, :]
+ if filter_function is None
+ else filter_function(datacube.data[rx, ry, :, :]),
+ dtype=cp.float32,
+ )
+
+ # Perform the FFT and multiplication by probe_kernel on the batched array
+ batched_crosscorr = (
+ cufft.fft2(batched_subcube, overwrite_x=True)
+ * probe_kernel_FT[None, :, :]
+ )
+
+ # Iterate over the patterns in the batch and do the Bragg disk stuff
+ for subbatch_idx in range(this_batch_size):
+ patt_idx = batch_idx * batch_size + subbatch_idx
+ rx, ry = np.unravel_index(patt_idx, (datacube.R_Nx, datacube.R_Ny))
+
+ subFFT = batched_crosscorr[subbatch_idx]
+ ccc = cp.abs(subFFT) ** corrPower * cp.exp(1j * cp.angle(subFFT))
+ cc = cp.maximum(cp.real(cp.fft.ifft2(ccc)), 0)
+
+ _find_Bragg_disks_single_DP_FK_CUDA(
+ None,
+ None,
+ ccc=ccc,
+ cc=cc,
+ corrPower=corrPower,
+ sigma=sigma,
+ edgeBoundary=edgeBoundary,
+ minRelativeIntensity=minRelativeIntensity,
+ minAbsoluteIntensity=minAbsoluteIntensity,
+ relativeToPeak=relativeToPeak,
+ minPeakSpacing=minPeakSpacing,
+ maxNumPeaks=maxNumPeaks,
+ subpixel=subpixel,
+ upsample_factor=upsample_factor,
+ filter_function=filter_function,
+ peaks=peaks.get_pointlist(rx, ry),
+ get_maximal_points=get_maximal_points,
+ blocks=blocks,
+ threads=threads,
+ )
+
+ # clean up
+ del batched_subcube, batched_crosscorr, subFFT, cc, ccc
+ cp.get_default_memory_pool().free_all_blocks()
+
+ else:
+ # Loop over all diffraction patterns
+ for Rx, Ry in tqdmnd(
+ datacube.R_Nx,
+ datacube.R_Ny,
+ desc="Finding Bragg Disks",
+ unit="DP",
+ unit_scale=True,
+ ):
+ DP = datacube.data[Rx, Ry, :, :]
+ _find_Bragg_disks_single_DP_FK_CUDA(
+ DP,
+ probe_kernel_FT,
+ corrPower=corrPower,
+ sigma=sigma,
+ edgeBoundary=edgeBoundary,
+ minRelativeIntensity=minRelativeIntensity,
+ minAbsoluteIntensity=minAbsoluteIntensity,
+ relativeToPeak=relativeToPeak,
+ minPeakSpacing=minPeakSpacing,
+ maxNumPeaks=maxNumPeaks,
+ subpixel=subpixel,
+ upsample_factor=upsample_factor,
+ filter_function=filter_function,
+ peaks=peaks.get_pointlist(Rx, Ry),
+ get_maximal_points=get_maximal_points,
+ blocks=blocks,
+ threads=threads,
+ )
+ t = time() - t0
+ print(
+ f"Analyzed {datacube.R_N} diffraction patterns in {t//3600}h {t % 3600 // 60}m {t % 60:.2f}s\n(avg. speed {datacube.R_N/t:0.4f} patterns per second)".format()
+ )
+ peaks.name = name
+ return peaks
+
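+# Usage sketch (assumes `dc` is a DataCube and `probe_kernel` a real-space probe
+# kernel; the function takes the Fourier transform of the kernel internally):
+#
+#     peaks = find_Bragg_disks_CUDA(
+#         dc,
+#         probe_kernel,
+#         corrPower=1,        # pure cross correlation
+#         subpixel="multicorr",
+#         upsample_factor=16,
+#         batching=True,      # batch FFTs to amortize GPU transfer cost
+#     )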
+
+def _find_Bragg_disks_single_DP_FK_CUDA(
+ DP,
+ probe_kernel_FT,
+ corrPower=1,
+ sigma=2,
+ edgeBoundary=20,
+ minRelativeIntensity=0.005,
+ minAbsoluteIntensity=0.0,
+ relativeToPeak=0,
+ minPeakSpacing=60,
+ maxNumPeaks=70,
+ subpixel="multicorr",
+ upsample_factor=16,
+ filter_function=None,
+ return_cc=False,
+ peaks=None,
+ get_maximal_points=None,
+ blocks=None,
+ threads=None,
+ ccc=None,
+ cc=None,
+):
+ """
+ Finds the Bragg disks in DP by cross, hybrid, or phase correlation with probe_kernel_FT.
+
+ After taking the cross/hybrid/phase correlation, a gaussian smoothing is applied
+ with standard deviation sigma, and all local maxima are found. Detected peaks within
+ edgeBoundary pixels of the diffraction plane edges are then discarded. Next, peaks with
+    intensities less than minRelativeIntensity of the brightest peak in the correlation are
+ discarded. Then peaks which are within a distance of minPeakSpacing of their nearest neighbor
+ peak are found, and in each such pair the peak with the lesser correlation intensities is
+ removed. Finally, if the number of peaks remaining exceeds maxNumPeaks, only the maxNumPeaks
+ peaks with the highest correlation intensity are retained.
+
+ IMPORTANT NOTE: the argument probe_kernel_FT is related to the probe kernels generated by
+ functions like get_probe_kernel() by:
+
+ probe_kernel_FT = np.conj(np.fft.fft2(probe_kernel))
+
+ if this function is simply passed a probe kernel, the results will not be meaningful! To run
+ on a single DP while passing the real space probe kernel as an argument, use
+ find_Bragg_disks_single_DP().
+
+ Accepts:
+ DP (ndarray) a diffraction pattern
+ probe_kernel_FT (cparray) the vacuum probe template, in Fourier space. Related to the
+ real space probe kernel by probe_kernel_FT = F(probe_kernel)*, where F
+ indicates a Fourier Transform and * indicates complex conjugation.
+ corrPower (float between 0 and 1, inclusive) the cross correlation power. A
+            value of 1 corresponds to a cross correlation, and 0 corresponds to a
+ phase correlation, with intermediate values giving various hybrids.
+ sigma (float) the standard deviation for the gaussian smoothing applied to
+ the cross correlation
+ edgeBoundary (int) minimum acceptable distance from the DP edge, in pixels
+ minRelativeIntensity (float) the minimum acceptable correlation peak intensity, relative to
+ the intensity of the relativeToPeak'th peak
+ relativeToPeak (int) specifies the peak against which the minimum relative intensity
+ is measured -- 0=brightest maximum. 1=next brightest, etc.
+ minPeakSpacing (float) the minimum acceptable spacing between detected peaks
+ maxNumPeaks (int) the maximum number of peaks to return
+ subpixel (str) 'none': no subpixel fitting
+                    'poly': polynomial interpolation of correlogram peaks
+                        (fairly fast but not very accurate)
+                    (default) 'multicorr': uses the multicorr algorithm with
+                            DFT upsampling
+ upsample_factor (int) upsampling factor for subpixel fitting (only used when subpixel='multicorr')
+ filter_function (callable) filtering function to apply to each diffraction pattern before peakfinding.
+ Must be a function of only one argument (the diffraction pattern) and return
+ the filtered diffraction pattern.
+ The shape of the returned DP must match the shape of the probe kernel (but does
+ not need to match the shape of the input diffraction pattern, e.g. the filter
+ can be used to bin the diffraction pattern). If using distributed disk detection,
+            the function must be able to be pickled with dill.
+ return_cc (bool) if True, return the cross correlation
+ peaks (PointList) For internal use.
+ If peaks is None, the PointList of peak positions is created here.
+ If peaks is not None, it is the PointList that detected peaks are added
+ to, and must have the appropriate coords ('qx','qy','intensity').
+ ccc and cc: Precomputed complex and real-IFFT cross correlations. Used when called
+ in batched mode only, causing local calculation of those to be skipped
+
+ Returns:
+ peaks (PointList) the Bragg peak positions and correlation intensities
+ """
+
+ # if we are in batching mode, cc and ccc will be provided. else, compute it
+ if ccc is None:
+ # Perform any prefiltering
+ DP = cp.array(
+ DP if filter_function is None else filter_function(DP), dtype=cp.float32
+ )
+
+ # Get the cross correlation
+ if subpixel in ("none", "poly"):
+ cc = get_cross_correlation_fk(DP, probe_kernel_FT, corrPower)
+ ccc = None
+ # for multicorr subpixel fitting, we need both the real and complex cross correlation
+ else:
+ ccc = get_cross_correlation_fk(
+ DP, probe_kernel_FT, corrPower, returnval="fourier"
+ )
+ cc = cp.maximum(cp.real(cp.fft.ifft2(ccc)), 0)
+
+ # Find the maxima
+ maxima_x, maxima_y, maxima_int = get_maxima_2D(
+ cc,
+ sigma=sigma,
+ edgeBoundary=edgeBoundary,
+ minRelativeIntensity=minRelativeIntensity,
+ minAbsoluteIntensity=minAbsoluteIntensity,
+ relativeToPeak=relativeToPeak,
+ minSpacing=minPeakSpacing,
+ maxNumPeaks=maxNumPeaks,
+ subpixel=subpixel,
+ ar_FT=ccc,
+ upsample_factor=upsample_factor,
+ get_maximal_points=get_maximal_points,
+ blocks=blocks,
+ threads=threads,
+ )
+
+ # Make peaks PointList
+ if peaks is None:
+ coords = [("qx", float), ("qy", float), ("intensity", float)]
+ peaks = PointList(coordinates=coords)
+ peaks.add_data_by_field((maxima_x, maxima_y, maxima_int))
+
+ if return_cc:
+ return peaks, gaussian_filter(cc, sigma)
+ else:
+ return peaks
+
+
+def get_cross_correlation_fk(ar, fourierkernel, corrPower=1, returnval="cc"):
+ """
+ Calculates the cross correlation of ar with fourierkernel.
+ Here, fourierkernel = np.conj(np.fft.fft2(kernel)); speeds up computation when the same
+ kernel is to be used for multiple cross correlations.
+ corrPower specifies the correlation type, where 1 is a cross correlation, 0 is a phase
+ correlation, and values in between are hybrids.
+
+ The return value depends on the argument `returnval`:
+ if returnval=='cc' (default), returns the real part of the cross correlation in real
+ space.
+ if returnval=='fourier', returns the output in Fourier space, before taking the
+ inverse transform.
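+
+ A minimal usage sketch (the arrays here are illustrative placeholders):
+
+ >>> import cupy as cp
+ >>> ar = cp.random.rand(64, 64).astype(cp.float32)
+ >>> kernel = cp.random.rand(64, 64).astype(cp.float32)
+ >>> fourierkernel = cp.conj(cp.fft.fft2(kernel))
+ >>> cc = get_cross_correlation_fk(ar, fourierkernel, corrPower=1)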
+ """
+ m = cp.fft.fft2(ar) * fourierkernel
+ ccc = cp.abs(m) ** (corrPower) * cp.exp(1j * cp.angle(m))
+ if returnval == "fourier":
+ return ccc
+ else:
+ return cp.real(cp.fft.ifft2(ccc))
+
+
+def get_maxima_2D(
+ ar,
+ sigma=0,
+ edgeBoundary=0,
+ minSpacing=0,
+ minRelativeIntensity=0,
+ minAbsoluteIntensity=0,
+ relativeToPeak=0,
+ maxNumPeaks=0,
+ subpixel="poly",
+ ar_FT=None,
+ upsample_factor=16,
+ get_maximal_points=None,
+ blocks=None,
+ threads=None,
+):
+ """
+ Finds the indices where the 2D array ar is a local maximum.
+ Optional parameters allow blurring of the array and filtering of the output;
+ setting each of these to 0 (default) turns off these functions.
+
+ Accepts:
+ ar (ndarray) a 2D array
+ sigma (float) gaussian blur std to apply to ar before finding the maxima
+ edgeBoundary (int) ignore maxima within edgeBoundary of the array edge
+ minSpacing (float) if two maxima are found within minSpacing, the dimmer one
+ is removed
+ minRelativeIntensity (float) maxima dimmer than minRelativeIntensity compared to the
+ relativeToPeak'th brightest maximum are removed
+ relativeToPeak (int) 0=brightest maximum. 1=next brightest, etc.
+ maxNumPeaks (int) return only the first maxNumPeaks maxima
+ subpixel (str) 'none': no subpixel fitting
+ (default) 'poly': polynomial interpolation of correlogram peaks
+ (fairly fast but not very accurate)
+ 'multicorr': uses the multicorr algorithm with
+ DFT upsampling
+ ar_FT (None or complex array) if subpixel=='multicorr' the
+ fourier transform of the image is required. It may be
+ passed here as a complex array. Otherwise, if ar_FT is None,
+ it is computed
+ upsample_factor (int) required iff subpixel=='multicorr'
+
+ Returns:
+ maxima_x (ndarray) x-coords of the local maxima, sorted by intensity.
+ maxima_y (ndarray) y-coords of the local maxima, sorted by intensity.
+ maxima_intensity (ndarray) intensity of the local maxima
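+
+ A hedged launch sketch (assuming `cc` is a float32 cupy array and the compiled
+ kernel comes from py4DSTEM.braggvectors.kernels; all names are illustrative):
+
+ >>> k = kernels["maximal_pts_float32"]
+ >>> blocks = ((cc.size // k.max_threads_per_block + 1),)
+ >>> threads = (k.max_threads_per_block,)
+ >>> x, y, I = get_maxima_2D(cc, sigma=1, edgeBoundary=2, subpixel="poly",
+ ... get_maximal_points=k, blocks=blocks, threads=threads)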
+ """
+
+ # Get maxima
+ ar = gaussian_filter(ar, sigma)
+ maxima_bool = cp.zeros_like(ar, dtype=bool)
+ sizex = ar.shape[0]
+ sizey = ar.shape[1]
+ N = sizex * sizey
+ get_maximal_points(
+ blocks, threads, (ar, maxima_bool, minAbsoluteIntensity, sizex, sizey, N)
+ )
+
+ # Remove edges
+ if edgeBoundary > 0:
+ maxima_bool[:edgeBoundary, :] = False
+ maxima_bool[-edgeBoundary:, :] = False
+ maxima_bool[:, :edgeBoundary] = False
+ maxima_bool[:, -edgeBoundary:] = False
+ # subpixel is a string here; trim a one-pixel border whenever subpixel fitting is on
+ elif subpixel != "none":
+ maxima_bool[:1, :] = False
+ maxima_bool[-1:, :] = False
+ maxima_bool[:, :1] = False
+ maxima_bool[:, -1:] = False
+
+ # Get indices, sorted by intensity
+ maxima_x, maxima_y = cp.nonzero(maxima_bool)
+ maxima_x = maxima_x.get()
+ maxima_y = maxima_y.get()
+ dtype = np.dtype([("x", float), ("y", float), ("intensity", float)])
+ maxima = np.zeros(len(maxima_x), dtype=dtype)
+ maxima["x"] = maxima_x
+ maxima["y"] = maxima_y
+
+ ar = ar.get()
+ maxima["intensity"] = ar[maxima_x, maxima_y]
+ maxima = np.sort(maxima, order="intensity")[::-1]
+
+ if len(maxima) > 0:
+ # Remove maxima which are too close
+ if minSpacing > 0:
+ deletemask = np.zeros(len(maxima), dtype=bool)
+ for i in range(len(maxima)):
+ # deletemask holds numpy bools, which are never `is False`; test with `not` instead
+ if not deletemask[i]:
+ tooClose = (
+ (maxima["x"] - maxima["x"][i]) ** 2
+ + (maxima["y"] - maxima["y"][i]) ** 2
+ ) < minSpacing**2
+ tooClose[: i + 1] = False
+ deletemask[tooClose] = True
+ maxima = np.delete(maxima, np.nonzero(deletemask)[0])
+
+ # Remove maxima which are too dim
+ if (minRelativeIntensity > 0) & (len(maxima) > relativeToPeak):
+ deletemask = (
+ maxima["intensity"] / maxima["intensity"][relativeToPeak]
+ < minRelativeIntensity
+ )
+ maxima = np.delete(maxima, np.nonzero(deletemask)[0])
+
+ # Remove maxima which are too dim, absolute scale
+ if minAbsoluteIntensity > 0:
+ deletemask = maxima["intensity"] < minAbsoluteIntensity
+ maxima = np.delete(maxima, np.nonzero(deletemask)[0])
+
+ # Remove maxima in excess of maxNumPeaks
+ if maxNumPeaks is not None and maxNumPeaks > 0:
+ if len(maxima) > maxNumPeaks:
+ maxima = maxima[:maxNumPeaks]
+
+ # Subpixel fitting
+ # For all subpixel fitting, first fit 1D parabolas in x and y to 3 points (maximum, +/- 1 pixel)
+ if subpixel != "none":
+ for i in range(len(maxima)):
+ Ix1_ = ar[int(maxima["x"][i]) - 1, int(maxima["y"][i])]
+ Ix0 = ar[int(maxima["x"][i]), int(maxima["y"][i])]
+ Ix1 = ar[int(maxima["x"][i]) + 1, int(maxima["y"][i])]
+ Iy1_ = ar[int(maxima["x"][i]), int(maxima["y"][i]) - 1]
+ Iy0 = ar[int(maxima["x"][i]), int(maxima["y"][i])]
+ Iy1 = ar[int(maxima["x"][i]), int(maxima["y"][i]) + 1]
+ deltax = (Ix1 - Ix1_) / (4 * Ix0 - 2 * Ix1 - 2 * Ix1_)
+ deltay = (Iy1 - Iy1_) / (4 * Iy0 - 2 * Iy1 - 2 * Iy1_)
+ maxima["x"][i] += deltax if np.abs(deltax) <= 1.0 else 0.0
+ maxima["y"][i] += deltay if np.abs(deltay) <= 1.0 else 0.0
+ maxima["intensity"][i] = linear_interpolation_2D(
+ ar, maxima["x"][i], maxima["y"][i]
+ )
+ # Further refinement with fourier upsampling
+ if subpixel == "multicorr":
+ ar_FT = cp.conj(ar_FT)
+
+ xyShift = np.vstack((maxima["x"], maxima["y"])).T
+ # we actually have to lose some precision and go down to half-pixel
+ # accuracy. this could also be done by a single upsampling at factor 2
+ # instead of the parabolic fit in get_maxima_2D.
+ xyShift = cp.array(np.round(xyShift * 2.0) / 2)
+
+ subShift = upsampled_correlation(ar_FT, upsample_factor, xyShift).get()
+ maxima["x"] = subShift[:, 0]
+ maxima["y"] = subShift[:, 1]
+
+ return maxima["x"], maxima["y"], maxima["intensity"]
+
+
+def upsampled_correlation(imageCorr, upsampleFactor, xyShift):
+ """
+ Refine the correlation peak of imageCorr around xyShift by DFT upsampling.
+
+ There are two approaches to Fourier upsampling for subpixel refinement: (a) one
+ can pad an (appropriately shifted) FFT with zeros and take the inverse transform,
+ or (b) one can compute the DFT by matrix multiplication using modified
+ transformation matrices. The former approach is straightforward but requires
+ performing the FFT algorithm (which is fast) on very large data. The latter method
+ trades one speedup for a slowdown elsewhere: the matrix multiply steps are expensive
+ but we operate on smaller matrices. Since we are only interested in a very small
+ region of the FT around a peak of interest, we use the latter method to get
+ a substantial speedup and enormous decrease in memory requirement. This
+ "DFT upsampling" approach computes the transformation matrices for the matrix-
+ multiply DFT around a small 1.5px wide region in the original `imageCorr`.
+
+ Following the matrix multiply DFT we use parabolic subpixel fitting to
+ get even more precision! (below 1/upsampleFactor pixels)
+
+ NOTE: previous versions of multiCorr operated in two steps: using the zero-
+ padding upsample method for a first-pass factor-2 upsampling, followed by the
+ DFT upsampling (at whatever user-specified factor). I have implemented it
+ differently, to better support iterating over multiple peaks. **The DFT is always
+ upsampled around xyShift, which MUST be specified to HALF-PIXEL precision
+ (no more, no less) to replicate the behavior of the factor-2 step.**
+ (It is possible to refactor this so that peak detection is done on a Fourier
+ upsampled image rather than using the parabolic subpixel and rounding as now...
+ I like keeping it this way because all of the parameters and logic will be identical
+ to the other subpixel methods.)
+
+
+ Args:
+ imageCorr (complex valued ndarray):
+ Complex product of the FFTs of the two images to be registered
+ i.e. m = np.fft.fft2(DP) * probe_kernel_FT;
+ imageCorr = np.abs(m)**(corrPower) * np.exp(1j*np.angle(m))
+ upsampleFactor (int):
+ Upsampling factor. Must be greater than 2. (To do upsampling
+ with factor 2, use upsampleFFT, which is faster.)
+ xyShift:
+ Array of points around which to upsample, with shape [N-points, 2]
+
+ Returns:
+ (N_points, 2) cupy ndarray: Refined locations of the peaks in image coordinates.
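+
+ A hedged usage sketch (assuming `ccc` is the complex Fourier-space correlation;
+ note that get_maxima_2D passes its conjugate, and the peak guesses must be
+ given to half-pixel precision):
+
+ >>> xyShift = cp.array([[12.5, 33.0]], dtype=cp.float32)
+ >>> refined = upsampled_correlation(cp.conj(ccc), 16, xyShift)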
+ """
+
+ xyShift = (cp.round(xyShift * upsampleFactor) / upsampleFactor).astype(cp.float32)
+
+ globalShift = np.fix(np.ceil(upsampleFactor * 1.5) / 2)
+
+ upsampleCenter = globalShift - upsampleFactor * xyShift
+
+ imageCorrUpsample = dftUpsample(imageCorr, upsampleFactor, upsampleCenter).get()
+
+ xSubShift, ySubShift = np.unravel_index(
+ imageCorrUpsample.reshape(imageCorrUpsample.shape[0], -1).argmax(axis=1),
+ imageCorrUpsample.shape[1:3],
+ )
+
+ # add a subpixel shift via parabolic fitting, serially for each peak
+ for idx in range(xSubShift.shape[0]):
+ try:
+ icc = np.real(
+ imageCorrUpsample[
+ idx,
+ xSubShift[idx] - 1 : xSubShift[idx] + 2,
+ ySubShift[idx] - 1 : ySubShift[idx] + 2,
+ ]
+ )
+ dx = (icc[2, 1] - icc[0, 1]) / (
+ 4 * icc[1, 1] - 2 * icc[2, 1] - 2 * icc[0, 1]
+ )
+ dy = (icc[1, 2] - icc[1, 0]) / (
+ 4 * icc[1, 1] - 2 * icc[1, 2] - 2 * icc[1, 0]
+ )
+ except IndexError:
+ # the peak is near the edge, so part of the 3x3 neighborhood used for
+ # the parabolic fit does not exist; skip the sub-pixel shift
+ dx, dy = 0.0, 0.0
+
+ xyShift[idx] = (
+ xyShift[idx]
+ + (cp.array([xSubShift[idx] + dx, ySubShift[idx] + dy]) - globalShift)
+ / upsampleFactor
+ )
+
+ return xyShift
+
+
+def dftUpsample(imageCorr, upsampleFactor, xyShift):
+ """
+ This performs a matrix multiply DFT around a small neighboring region of the initial
+ correlation peak. By using the matrix multiply DFT to do the Fourier upsampling, the
+ efficiency is greatly improved. This is adapted from the subfunction dftups found in
+ the dftregistration function on the Matlab File Exchange.
+
+ https://www.mathworks.com/matlabcentral/fileexchange/18401-efficient-subpixel-image-registration-by-cross-correlation
+
+ The matrix multiplication DFT is from:
+
+ Manuel Guizar-Sicairos, Samuel T. Thurman, and James R. Fienup, "Efficient subpixel
+ image registration algorithms," Opt. Lett. 33, 156-158 (2008).
+ http://www.sciencedirect.com/science/article/pii/S0045790612000778
+
+ Args:
+ imageCorr (complex valued ndarray):
+ Correlation image between two images in Fourier space.
+ upsampleFactor (int):
+ Scalar integer of how much to upsample.
+ xyShift (N_points,2) cp.ndarray, locations to upsample around:
+ Coordinates in the UPSAMPLED GRID around which to upsample.
+ These must be single-pixel IN THE UPSAMPLED GRID
+
+ Returns:
+ (ndarray):
+ Stack of upsampled images from region around correlation peak.
+ """
+ N_pts = xyShift.shape[0]
+ imageSize = imageCorr.shape
+ pixelRadius = 1.5
+ kernel_size = int(np.ceil(pixelRadius * upsampleFactor))
+
+ colKern = cp.zeros(
+ (N_pts, imageSize[1], kernel_size), dtype=cp.complex64
+ ) # N_pts * image_size[1] * kernel_size
+ rowKern = cp.zeros(
+ (N_pts, kernel_size, imageSize[0]), dtype=cp.complex64
+ ) # N_pts * kernel_size * image_size[0]
+
+ # Fill in the DFT arrays using the CUDA kernels
+ multicorr_col_kernel = kernels["multicorr_col_kernel"]
+ blocks = (
+ (np.prod(colKern.shape) // multicorr_col_kernel.max_threads_per_block + 1),
+ )
+ threads = (multicorr_col_kernel.max_threads_per_block,)
+ multicorr_col_kernel(
+ blocks, threads, (colKern, xyShift, N_pts, *imageSize, upsampleFactor)
+ )
+
+ multicorr_row_kernel = kernels["multicorr_row_kernel"]
+ blocks = (
+ (np.prod(rowKern.shape) // multicorr_row_kernel.max_threads_per_block + 1),
+ )
+ threads = (multicorr_row_kernel.max_threads_per_block,)
+ multicorr_row_kernel(
+ blocks, threads, (rowKern, xyShift, N_pts, *imageSize, upsampleFactor)
+ )
+
+ # Apply the DFT arrays to the correlation image
+ imageUpsample = cp.real(rowKern @ imageCorr @ colKern)
+ return imageUpsample
+
+
+@numba.jit(nopython=True)
+def linear_interpolation_2D(ar, x, y):
+ """
+ Calculates the 2D linear interpolation of array ar at position x,y using the four
+ nearest array elements.
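+
+ For example, with `ar` an illustrative 2D array:
+
+ >>> val = linear_interpolation_2D(ar, 2.5, 3.25)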
+ """
+ x0, x1 = int(np.floor(x)), int(np.ceil(x))
+ y0, y1 = int(np.floor(y)), int(np.ceil(y))
+ dx = x - x0
+ dy = y - y0
+ return (
+ (1 - dx) * (1 - dy) * ar[x0, y0]
+ + (1 - dx) * dy * ar[x0, y1]
+ + dx * (1 - dy) * ar[x1, y0]
+ + dx * dy * ar[x1, y1]
+ )
diff --git a/py4DSTEM/braggvectors/diskdetection_parallel.py b/py4DSTEM/braggvectors/diskdetection_parallel.py
new file mode 100644
index 000000000..a1c5dc6f4
--- /dev/null
+++ b/py4DSTEM/braggvectors/diskdetection_parallel.py
@@ -0,0 +1,577 @@
+# stdlib
+import os
+import tempfile
+from time import time
+
+# 3rd party
+import numpy as np
+import dill
+
+# local
+import py4DSTEM
+from emdfile import PointListArray
+
+
+def _find_Bragg_disks_single_DP_FK(
+ DP,
+ probe_kernel_FT,
+ corrPower=1,
+ sigma=2,
+ edgeBoundary=20,
+ minRelativeIntensity=0.005,
+ minAbsoluteIntensity=0,
+ relativeToPeak=0,
+ minPeakSpacing=60,
+ maxNumPeaks=70,
+ subpixel="multicorr",
+ upsample_factor=16,
+ filter_function=None,
+ return_cc=False,
+ peaks=None,
+):
+ """
+ Mirror of diskdetection.find_Bragg_disks_single_DP_FK with explicit imports for
+ remote execution.
+
+ Finds the Bragg disks in DP by cross, hybrid, or phase correlation with
+ probe_kernel_FT.
+
+ After taking the cross/hybrid/phase correlation, a gaussian smoothing is applied
+ with standard deviation sigma, and all local maxima are found. Detected peaks within
+ edgeBoundary pixels of the diffraction plane edges are then discarded. Next, peaks
+ with intensities less than minRelativeIntensity of the brightest peak in the
+ correlation are discarded. Then peaks which are within a distance of minPeakSpacing
+ of their nearest neighbor peak are found, and in each such pair the peak with the
+ lesser correlation intensity is removed. Finally, if the number of peaks remaining
+ exceeds maxNumPeaks, only the maxNumPeaks peaks with the highest correlation
+ intensity are retained.
+
+ IMPORTANT NOTE: the argument probe_kernel_FT is related to the probe kernels
+ generated by functions like get_probe_kernel() by:
+
+ >>> probe_kernel_FT = np.conj(np.fft.fft2(probe_kernel))
+
+ if this function is simply passed a probe kernel, the results will not be meaningful!
+ To run on a single DP while passing the real space probe kernel as an argument, use
+ find_Bragg_disks_single_DP().
+
+ Args:
+ DP (ndarray): a diffraction pattern
+ probe_kernel_FT (ndarray): the vacuum probe template, in Fourier space. Related
+ to the real space probe kernel by probe_kernel_FT = F(probe_kernel)*, where
+ F indicates a Fourier Transform and * indicates complex conjugation.
+ corrPower (float between 0 and 1, inclusive): the cross correlation power. A
+ value of 1 corresponds to a cross correlation, and 0 corresponds to a
+ phase correlation, with intermediate values giving various hybrids.
+ sigma (float): the standard deviation for the gaussian smoothing applied to
+ the cross correlation
+ edgeBoundary (int): minimum acceptable distance from the DP edge, in pixels
+ minRelativeIntensity (float): the minimum acceptable correlation peak intensity,
+ relative to the intensity of the relativeToPeak'th peak
+ relativeToPeak (int): specifies the peak against which the minimum relative
+ intensity is measured -- 0=brightest maximum. 1=next brightest, etc.
+ minPeakSpacing (float): the minimum acceptable spacing between detected peaks
+ maxNumPeaks (int): the maximum number of peaks to return
+ subpixel (str): Whether to use subpixel fitting, and which algorithm to use.
+ Must be in ('none','poly','multicorr').
+ * 'none': performs no subpixel fitting
+ * 'poly': polynomial interpolation of correlogram peaks (default)
+ * 'multicorr': uses the multicorr algorithm with DFT upsampling
+ upsample_factor (int): upsampling factor for subpixel fitting (only used when
+ subpixel='multicorr')
+ filter_function (callable): filtering function to apply to each diffraction
+ pattern before peakfinding. Must be a function of only one argument (the
+ diffraction pattern) and return the filtered diffraction pattern. The shape
+ of the returned DP must match the shape of the probe kernel (but does not
+ need to match the shape of the input diffraction pattern, e.g. the filter
+ can be used to bin the diffraction pattern). If using distributed disk
+ detection, the function must be able to be pickled by dill.
+ return_cc (bool): if True, return the cross correlation
+ peaks (PointList): For internal use. If peaks is None, the PointList of peak
+ positions is created here. If peaks is not None, it is the PointList that
+ detected peaks are added to, and must have the appropriate coords
+ ('qx','qy','intensity').
+
+ Returns:
+ (PointList) the Bragg peak positions and correlation intensities
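+
+ A hedged usage sketch (`dp` and `kern` are illustrative numpy arrays holding a
+ diffraction pattern and a real-space probe kernel):
+
+ >>> probe_kernel_FT = np.conj(np.fft.fft2(kern))
+ >>> peaks = _find_Bragg_disks_single_DP_FK(dp, probe_kernel_FT, sigma=2,
+ ... subpixel='multicorr', upsample_factor=16)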
+ """
+ assert subpixel in [
+ "none",
+ "poly",
+ "multicorr",
+ ], "Unrecognized subpixel option {}, subpixel must be 'none', 'poly', or 'multicorr'".format(
+ subpixel
+ )
+
+ import numpy
+ import scipy.ndimage.filters
+ import py4DSTEM.process.utils.multicorr
+
+ # apply filter function:
+ DP = DP if filter_function is None else filter_function(DP)
+
+ if subpixel == "none":
+ cc = py4DSTEM.process.utils.get_cross_correlation_fk(
+ DP, probe_kernel_FT, corrPower
+ )
+ cc = numpy.maximum(cc, 0)
+ maxima_x, maxima_y, maxima_int = py4DSTEM.process.utils.get_maxima_2D(
+ cc,
+ sigma=sigma,
+ edgeBoundary=edgeBoundary,
+ minRelativeIntensity=minRelativeIntensity,
+ minAbsoluteIntensity=minAbsoluteIntensity,
+ relativeToPeak=relativeToPeak,
+ minSpacing=minPeakSpacing,
+ maxNumPeaks=maxNumPeaks,
+ subpixel=False,
+ )
+ elif subpixel == "poly":
+ cc = py4DSTEM.process.utils.get_cross_correlation_fk(
+ DP, probe_kernel_FT, corrPower
+ )
+ cc = numpy.maximum(cc, 0)
+ maxima_x, maxima_y, maxima_int = py4DSTEM.process.utils.get_maxima_2D(
+ cc,
+ sigma=sigma,
+ edgeBoundary=edgeBoundary,
+ minRelativeIntensity=minRelativeIntensity,
+ minAbsoluteIntensity=minAbsoluteIntensity,
+ relativeToPeak=relativeToPeak,
+ minSpacing=minPeakSpacing,
+ maxNumPeaks=maxNumPeaks,
+ subpixel=True,
+ )
+ else:
+ # Multicorr subpixel:
+ m = numpy.fft.fft2(DP) * probe_kernel_FT
+ ccc = numpy.abs(m) ** corrPower * numpy.exp(1j * numpy.angle(m))
+
+ cc = numpy.maximum(numpy.real(numpy.fft.ifft2(ccc)), 0)
+
+ maxima_x, maxima_y, maxima_int = py4DSTEM.process.utils.get_maxima_2D(
+ cc,
+ sigma=sigma,
+ edgeBoundary=edgeBoundary,
+ minRelativeIntensity=minRelativeIntensity,
+ minAbsoluteIntensity=minAbsoluteIntensity,
+ relativeToPeak=relativeToPeak,
+ minSpacing=minPeakSpacing,
+ maxNumPeaks=maxNumPeaks,
+ subpixel=True,
+ )
+
+ # Use the DFT upsample to refine the detected peaks (but not the intensity)
+ for ipeak in range(len(maxima_x)):
+ xyShift = numpy.array((maxima_x[ipeak], maxima_y[ipeak]))
+ # we actually have to lose some precision and go down to half-pixel
+ # accuracy. this could also be done by a single upsampling at factor 2
+ # instead of the parabolic fit in get_maxima_2D.
+ xyShift[0] = numpy.round(xyShift[0] * 2) / 2
+ xyShift[1] = numpy.round(xyShift[1] * 2) / 2
+
+ subShift = py4DSTEM.process.utils.multicorr.upsampled_correlation(
+ ccc, upsample_factor, xyShift
+ )
+ maxima_x[ipeak] = subShift[0]
+ maxima_y[ipeak] = subShift[1]
+
+ # Make peaks PointList
+ if peaks is None:
+ coords = [("qx", float), ("qy", float), ("intensity", float)]
+ peaks = py4DSTEM.PointList(coordinates=coords)
+ else:
+ assert isinstance(peaks, py4DSTEM.PointList)
+ peaks.add_tuple_of_nparrays((maxima_x, maxima_y, maxima_int))
+
+ if return_cc:
+ return peaks, scipy.ndimage.filters.gaussian_filter(cc, sigma)
+ else:
+ return peaks
+
+
+def _process_chunk(_f, start, end, path_to_static, coords, path_to_data, cluster_path):
+ import os
+ import dill
+
+ with open(path_to_static, "rb") as infile:
+ inputs = dill.load(infile)
+
+ # Always try to memory map the data file, if possible
+ if path_to_data.rsplit(".", 1)[-1].startswith("dm"):
+ datacube = py4DSTEM.read(path_to_data, load="dmmmap")
+ elif path_to_data.rsplit(".", 1)[-1].startswith("gt"):
+ datacube = py4DSTEM.read(path_to_data, load="gatan_bin")
+ else:
+ datacube = py4DSTEM.read(path_to_data)
+
+ results = []
+ for x in coords:
+ results.append((x[0], x[1], _f(datacube.data[x[0], x[1], :, :], *inputs).data))
+
+ # release memory
+ datacube = None
+
+ path_to_output = os.path.join(cluster_path, "{}_{}.data".format(start, end))
+ with open(path_to_output, "wb") as data_file:
+ dill.dump(results, data_file)
+
+ return path_to_output
+
+
+def find_Bragg_disks_ipp(
+ DP,
+ probe,
+ corrPower=1,
+ sigma=2,
+ edgeBoundary=20,
+ minRelativeIntensity=0.005,
+ minAbsoluteIntensity=0,
+ relativeToPeak=0,
+ minPeakSpacing=60,
+ maxNumPeaks=70,
+ subpixel="poly",
+ upsample_factor=4,
+ filter_function=None,
+ ipyparallel_client_file=None,
+ data_file=None,
+ cluster_path=None,
+):
+ """
+ Distributed compute using IPyParallel.
+
+ Finds the Bragg disks in all diffraction patterns of datacube by cross, hybrid, or
+ phase correlation with probe.
+
+ Args:
+ DP (DataCube): the datacube whose diffraction patterns will be searched
+ probe (ndarray): the vacuum probe template, in real space.
+ corrPower (float between 0 and 1, inclusive): the cross correlation power. A
+ value of 1 corresponds to a cross correaltion, and 0 corresponds to a
+ phase correlation, with intermediate values giving various hybrids.
+ sigma (float): the standard deviation for the gaussian smoothing applied to
+ the cross correlation
+ edgeBoundary (int): minimum acceptable distance from the DP edge, in pixels
+ minRelativeIntensity (float): the minimum acceptable correlation peak intensity,
+ relative to the intensity of the brightest peak
+ relativeToPeak (int): specifies the peak against which the minimum relative
+ intensity is measured -- 0=brightest maximum. 1=next brightest, etc.
+ minPeakSpacing (float): the minimum acceptable spacing between detected peaks
+ maxNumPeaks (int): the maximum number of peaks to return
+ subpixel (str): Whether to use subpixel fitting, and which algorithm to use.
+ Must be in ('none','poly','multicorr').
+ * 'none': performs no subpixel fitting
+ * 'poly': polynomial interpolation of correlogram peaks (default)
+ * 'multicorr': uses the multicorr algorithm with DFT upsampling
+ upsample_factor (int): upsampling factor for subpixel fitting (only used when
+ subpixel='multicorr')
+ filter_function (callable): filtering function to apply to each diffraction
+ pattern before peakfinding. Must be a function of only one argument (the
+ diffraction pattern) and return the filtered diffraction pattern. The shape
+ of the returned DP must match the shape of the probe kernel (but does not
+ need to match the shape of the input diffraction pattern, e.g. the filter
+ can be used to bin the diffraction pattern). If using distributed disk
+ detection, the function must be able to be pickled by dill.
+ ipyparallel_client_file (str): absolute path to ipyparallel client JSON file for
+ connecting to a cluster
+ data_file (str): absolute path to the data file containing the datacube for
+ processing remotely
+ cluster_path (str): working directory for cluster processing, defaults to current
+ directory
+
+ Returns:
+ (PointListArray): the Bragg peak positions and correlation intensities
+ """
+ import ipyparallel as ipp
+
+ R_Nx = DP.R_Nx
+ R_Ny = DP.R_Ny
+ R_N = DP.R_N
+ DP = None
+
+ # Make the peaks PointListArray
+ coords = [("qx", float), ("qy", float), ("intensity", float)]
+ peaks = PointListArray(coordinates=coords, shape=(R_Nx, R_Ny))
+
+ # Get the probe kernel FT
+ probe_kernel_FT = np.conj(np.fft.fft2(probe))
+
+ if ipyparallel_client_file is None:
+ raise RuntimeError("ipyparallel_client_file is None, no IPyParallel cluster")
+ elif data_file is None:
+ raise RuntimeError("data_file is None, needs path to datacube")
+
+ t0 = time()
+ c = ipp.Client(url_file=ipyparallel_client_file, timeout=30)
+
+ inputs_list = [
+ probe_kernel_FT,
+ corrPower,
+ sigma,
+ edgeBoundary,
+ minRelativeIntensity,
+ minAbsoluteIntensity,
+ relativeToPeak,
+ minPeakSpacing,
+ maxNumPeaks,
+ subpixel,
+ upsample_factor,
+ filter_function,
+ ]
+
+ if cluster_path is None:
+ cluster_path = os.getcwd()
+
+ tmpdir = tempfile.TemporaryDirectory(dir=cluster_path)
+
+ t_00 = time()
+ # write out static inputs
+ path_to_inputs = os.path.join(tmpdir.name, "inputs")
+ with open(path_to_inputs, "wb") as inputs_file:
+ dill.dump(inputs_list, inputs_file)
+ t_inputs_save = time() - t_00
+ print("Serialize input values : {}".format(t_inputs_save))
+
+ results = []
+ t1 = time()
+ total = int(R_Nx * R_Ny)
+ chunkSize = int(total / len(c.ids))
+
+ while chunkSize * len(c.ids) < total:
+ chunkSize += 1
+
+ indices = [(Rx, Ry) for Rx in range(R_Nx) for Ry in range(R_Ny)]
+
+ start = 0
+ for engine in c.ids:
+ if start + chunkSize < total - 1:
+ end = start + chunkSize
+ else:
+ end = total
+
+ results.append(
+ c[engine].apply(
+ _process_chunk,
+ _find_Bragg_disks_single_DP_FK,
+ start,
+ end,
+ path_to_inputs,
+ indices[start:end],
+ data_file,
+ tmpdir.name,
+ )
+ )
+
+ if end == total:
+ break
+ else:
+ start = end
+ t_submit = time() - t1
+ print("Submit phase : {}".format(t_submit))
+
+ t2 = time()
+ c.wait(jobs=results)
+ t_wait = time() - t2
+ print("Gather phase : {}".format(t_wait))
+
+ t3 = time()
+ for i in range(len(results)):
+ with open(results[i].get(), "rb") as f:
+ data_chunk = dill.load(f)
+
+ for Rx, Ry, data in data_chunk:
+ peaks.get_pointlist(Rx, Ry).add_dataarray(data)
+ t_copy = time() - t3
+ print("Copy results : {}".format(t_copy))
+
+ # clean up temp files
+ try:
+ tmpdir.cleanup()
+ except OSError as e:
+ print("Error when cleaning up temporary files: {}".format(e))
+
+ t = time() - t0
+ print(
+ "Analyzed {} diffraction patterns in {}h {}m {}s".format(
+ R_N, int(t / 3600), int(t % 3600 / 60), int(t % 60)
+ )
+ )
+
+ return peaks
+
+
+def find_Bragg_disks_dask(
+ DP,
+ probe,
+ corrPower=1,
+ sigma=2,
+ edgeBoundary=20,
+ minRelativeIntensity=0.005,
+ minAbsoluteIntensity=0,
+ relativeToPeak=0,
+ minPeakSpacing=60,
+ maxNumPeaks=70,
+ subpixel="poly",
+ upsample_factor=4,
+ filter_function=None,
+ dask_client=None,
+ data_file=None,
+ cluster_path=None,
+):
+ """
+ Distributed compute using Dask.
+
+ Finds the Bragg disks in all diffraction patterns of datacube by cross, hybrid, or
+ phase correlation with probe.
+
+ Args:
+ DP (DataCube): the datacube whose diffraction patterns will be searched
+ probe (ndarray): the vacuum probe template, in real space.
+ corrPower (float between 0 and 1, inclusive): the cross correlation power. A
+ value of 1 corresponds to a cross correlation, and 0 corresponds to a
+ phase correlation, with intermediate values giving various hybrids.
+ sigma (float): the standard deviation for the gaussian smoothing applied to
+ the cross correlation
+ edgeBoundary (int): minimum acceptable distance from the DP edge, in pixels
+ minRelativeIntensity (float): the minimum acceptable correlation peak intensity,
+ relative to the intensity of the brightest peak
+ relativeToPeak (int): specifies the peak against which the minimum relative
+ intensity is measured -- 0=brightest maximum. 1=next brightest, etc.
+ minPeakSpacing (float): the minimum acceptable spacing between detected peaks
+ maxNumPeaks (int): the maximum number of peaks to return
+ subpixel (str): Whether to use subpixel fitting, and which algorithm to use.
+ Must be in ('none','poly','multicorr').
+ * 'none': performs no subpixel fitting
+ * 'poly': polynomial interpolation of correlogram peaks (default)
+ * 'multicorr': uses the multicorr algorithm with DFT upsampling
+ upsample_factor (int): upsampling factor for subpixel fitting (only used when
+ subpixel='multicorr')
+ filter_function (callable): filtering function to apply to each diffraction
+ pattern before peakfinding. Must be a function of only one argument (the
+ diffraction pattern) and return the filtered diffraction pattern. The shape
+ of the returned DP must match the shape of the probe kernel (but does not
+ need to match the shape of the input diffraction pattern, e.g. the filter
+ can be used to bin the diffraction pattern). If using distributed disk
+ detection, the function must be able to be pickled by dill.
+ dask_client (obj): dask client for connecting to a cluster
+ data_file (str): absolute path to the data file containing the datacube for
+ processing remotely
+ cluster_path (str): working directory for cluster processing, defaults to current
+ directory
+
+ Returns:
+ (PointListArray) the Bragg peak positions and correlation intensities
+ """
+ import distributed
+
+ R_Nx = DP.R_Nx
+ R_Ny = DP.R_Ny
+ R_N = DP.R_N
+ DP = None
+
+ # Make the peaks PointListArray
+ coords = [("qx", float), ("qy", float), ("intensity", float)]
+ peaks = PointListArray(coordinates=coords, shape=(R_Nx, R_Ny))
+
+ # Get the probe kernel FT
+ probe_kernel_FT = np.conj(np.fft.fft2(probe))
+
+ if dask_client is None:
+ raise RuntimeError("dask_client is None, no Dask cluster!")
+ elif data_file is None:
+ raise RuntimeError("data_file is None, needs path to datacube")
+
+ t0 = time()
+
+ inputs_list = [
+ probe_kernel_FT,
+ corrPower,
+ sigma,
+ edgeBoundary,
+ minRelativeIntensity,
+ minAbsoluteIntensity,
+ relativeToPeak,
+ minPeakSpacing,
+ maxNumPeaks,
+ subpixel,
+ upsample_factor,
+ filter_function,
+ ]
+
+ if cluster_path is None:
+ cluster_path = os.getcwd()
+
+ tmpdir = tempfile.TemporaryDirectory(dir=cluster_path)
+
+ # write out static inputs
+ path_to_inputs = os.path.join(tmpdir.name, "{}.inputs".format(dask_client.id))
+ with open(path_to_inputs, "wb") as inputs_file:
+ dill.dump(inputs_list, inputs_file)
+ t_inputs_save = time() - t0
+ print("Serialize input values : {}".format(t_inputs_save))
+
+ cores = len(dask_client.ncores())
+
+ submits = []
+ t1 = time()
+ total = int(R_Nx * R_Ny)
+ chunkSize = int(total / cores)
+
+ while (chunkSize * cores) < total:
+ chunkSize += 1
+
+ indices = [(Rx, Ry) for Rx in range(R_Nx) for Ry in range(R_Ny)]
+
+ start = 0
+ for engine in range(cores):
+ if start + chunkSize < total - 1:
+ end = start + chunkSize
+ else:
+ end = total
+
+ submits.append(
+ dask_client.submit(
+ _process_chunk,
+ _find_Bragg_disks_single_DP_FK,
+ start,
+ end,
+ path_to_inputs,
+ indices[start:end],
+ data_file,
+ tmpdir.name,
+ )
+ )
+
+ if end == total:
+ break
+ else:
+ start = end
+ t_submit = time() - t1
+ print("Submit phase : {}".format(t_submit))
+
+ t2 = time()
+ # collect results
+ for batch in distributed.as_completed(submits, with_results=True).batches():
+ for future, result in batch:
+ with open(result, "rb") as f:
+ data_chunk = dill.load(f)
+
+ for Rx, Ry, data in data_chunk:
+ peaks.get_pointlist(Rx, Ry).add_dataarray(data)
+ t_copy = time() - t2
+ print("Gather phase : {}".format(t_copy))
+
+ # clean up temp files
+ try:
+ tmpdir.cleanup()
+ except OSError as e:
+ print("Error when cleaning up temporary files: {}".format(e))
+
+ t = time() - t0
+ print(
+ "Analyzed {} diffraction patterns in {}h {}m {}s".format(
+ R_N, int(t / 3600), int(t % 3600 / 60), int(t % 60)
+ )
+ )
+
+ return peaks
diff --git a/py4DSTEM/braggvectors/diskdetection_parallel_new.py b/py4DSTEM/braggvectors/diskdetection_parallel_new.py
new file mode 100644
index 000000000..dccc0dd4b
--- /dev/null
+++ b/py4DSTEM/braggvectors/diskdetection_parallel_new.py
@@ -0,0 +1,272 @@
+import numpy as np
+import matplotlib.pyplot as plt
+import h5py
+import time
+import dill
+
+import dask
+import dask.array as da
+import dask.config
+from dask import delayed
+from dask.distributed import Client, LocalCluster
+from dask.diagnostics import ProgressBar
+
+# import dask.bag as db
+
+# import distributed
+from distributed.protocol.serialize import register_serialization_family
+import distributed
+
+import py4DSTEM
+from emdfile import PointListArray, PointList
+from py4DSTEM.braggvectors.diskdetection import _find_Bragg_disks_single_DP_FK
+
+
+#### SERIALISERS ####
+# Define Serialiser
+# these are functions which allow the hdf5 objects to be passed. May not be required anymore
+
+
+def dill_dumps(x):
+ header = {"serializer": "dill"}
+ frames = [dill.dumps(x)]
+ return header, frames
+
+
+def dill_loads(header, frames):
+ if len(frames) > 1:
+ frame = "".join(frames)
+ else:
+ frame = frames[0]
+
+ return dill.loads(frame)
+
+
+# register the serialization method
+# register_serialization_family('dill', dill_dumps, dill_loads)
+
+
+def register_dill_serializer():
+ """
+ This function registers the dill serializer, allowing dask to work on h5py objects.
+ It is unclear whether, or how often, this needs to be run; keeping it in for now.
+ Args:
+ None
+ Returns:
+ None
+ """
+ register_serialization_family("dill", dill_dumps, dill_loads)
+ return None
+
+
+#### END OF SERIALISERS ####
+
+
+#### DASK WRAPPER FUNCTION ####
+
+
+# Each delayed objected is passed a 4D array, currently implementing only on 2D slices.
+# TODO add batching with fancy indexing - needs to run a for loop over the batch of arrays
+# TODO add cuda accelerated version
+# TODO add ML-AI version
+def _find_Bragg_disks_single_DP_FK_dask_wrapper(arr, *args, **kwargs):
+ # This is needed because _find_Bragg_disks_single_DP_FK takes a 2D array, while the
+ # delayed dask chunks arrive with shape (1, 1, Q_Nx, Q_Ny)
+ return _find_Bragg_disks_single_DP_FK(arr[0, 0], *args, **kwargs)
+
+
+#### END OF DASK WRAPPER FUNCTIONS ####
+
+
+#### MAIN FUNCTION
+# TODO add batching with fancy indexing - needs batch size, fancy indexing method
+# TODO add cuda accelerated function - needs dask GPU cluster.
+
+
+def beta_parallel_disk_detection(
+ dataset,
+ probe,
+ # rxmin=None, # these would allow selecting a sub section
+ # rxmax=None,
+ # rymin=None,
+ # rymax=None,
+ # qxmin=None,
+ # qxmax=None,
+ # qymin=None,
+ # qymax=None,
+ probe_type="FT",
+ dask_client=None,
+ dask_client_params: dict = None,
+ restart_dask_client=True,
+ close_dask_client=False,
+ return_dask_client=True,
+ *args,
+ **kwargs,
+):
+ """
+ This is not fully validated currently so may not work; please report bugs on the py4DSTEM github page.
+
+ This parallelises the disk detection over all probe positions. It can operate on either in-memory or out-of-memory datasets.
+
+ There is an assumption that, unless you specify otherwise, you are parallelising on a single local machine.
+ If this is not the case it is probably best to pass the dask_client into the function, although you can just pass the required arguments to dask_client_params.
+ If no dask_client arguments are passed it will create a dask_client for a local machine.
+
+ Note:
+ Do not pass the "peaks" argument as a kwarg, like you might in "_find_Bragg_disks_single_DP_FK", as the results will be unreliable and may cause the calculation to crash.
+ Args:
+ dataset (py4dSTEM datacube): 4DSTEM dataset
+ probe (ndarray): can be a regular probe kernel or its Fourier transform
+ probe_type (str): "FT" or None
+ dask_client (distributed.client.Client): dask client
+ dask_client_params (dict): parameters to pass to dask client or dask cluster
+ restart_dask_client (bool): if True, function will attempt to restart the dask_client.
+ close_dask_client (bool): if True, function will attempt to close the dask_client.
+ return_dask_client (bool): if True, function will return the dask_client.
+ *args, **kwargs: passed to "_find_Bragg_disks_single_DP_FK", e.g. corrPower, sigma, edgeBoundary...
+
+ Returns:
+ peaks (PointListArray): the Bragg peak positions and the correlation intensities
+ dask_client(optional) (distributed.client.Client): dask_client for use later.
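+
+ A hedged usage sketch (names and parameter values are illustrative only):
+
+ >>> from dask.distributed import Client, LocalCluster
+ >>> client = Client(LocalCluster(n_workers=4))
+ >>> peaks, client = beta_parallel_disk_detection(
+ ... dataset, probe, probe_type="FT", dask_client=client,
+ ... restart_dask_client=False, corrPower=1, sigma=2)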
+ """
+ # TODO add asserts about peaks not being passed
+ # Dask Client stuff
+ # TODO how to guess at default params for client, sqrt no. cores. Something to do with the size of the diffraction pattern
+ # write a function which can do this.
+ # TODO replace dask part with a with statement for easier clean up e.g.
+ # with LocalCluster(params) as cluster, Client(cluster) as client:
+ #     ... dask stuff.
+ # TODO add assert statements and other checks. Think about reordering operations
+
+ if dask_client is None:
+ if dask_client_params is not None:
+ dask.config.set(
+ {
+ "distributed.worker.memory.spill": False,
+ "distributed.worker.memory.target": False,
+ }
+ )
+ cluster = LocalCluster(**dask_client_params)
+ dask_client = Client(cluster, **dask_client_params)
+ else:
+ # AUTO MAGICALLY SET?
+ # LET DASK SET?
+ # HAVE A FUNCTION WHICH RUNS ON A SUBSET OF THE DATA TO PICK OPTIMAL VALUE?
+ # psutil could be used to count cores.
+ dask.config.set(
+ {
+ "distributed.worker.memory.spill": False, # stops spilling to disk
+ "distributed.worker.memory.target": False,
+ }
+ ) # stops spilling to disk and erroring out
+ cluster = LocalCluster()
+ dask_client = Client(cluster)
+
+ else:
+ assert isinstance(dask_client, distributed.client.Client)
+ if restart_dask_client:
+ try:
+ dask_client.restart()
+ except Exception as e:
+ print(
+ 'Could not restart dask client. Try manually restarting outside or passing "restart_dask_client=False"'
+ ) # WARNING STATEMENT
+ return e
+ else:
+ pass
+
+ # Probe stuff
+ assert (
+ probe.shape == dataset.data.shape[2:]
+ ), "Probe and Diffraction Pattern Shapes are Mismatched"
+ if probe_type != "FT":
+ # TODO clean up and pull out redundant parts
+ # if probe.dtype != (np.complex128 or np.complex64 or np.complex256):
+ # DO FFT SHIFT THING
+ probe_kernel_FT = np.conj(np.fft.fft2(probe))
+ dask_probe_array = da.from_array(
+ probe_kernel_FT, chunks=(dataset.Q_Nx, dataset.Q_Ny)
+ )
+ dask_probe_delayed = dask_probe_array.to_delayed()
+ # delayed_probe_kernel_FT = delayed(probe_kernel_FT)
+ else:
+ probe_kernel_FT = probe
+ dask_probe_array = da.from_array(
+ probe_kernel_FT, chunks=(dataset.Q_Nx, dataset.Q_Ny)
+ )
+ dask_probe_delayed = dask_probe_array.to_delayed()
+
+ # GET DATA
+ # TODO add another elif if it is a dask array then pass
+ if isinstance(dataset.data, np.ndarray):
+ dask_data = da.from_array(
+ dataset.data, chunks=(1, 1, dataset.Q_Nx, dataset.Q_Ny)
+ )
+ elif dataset.stack_pointer is not None:
+ dask_data = da.from_array(
+ dataset.stack_pointer, chunks=(1, 1, dataset.Q_Nx, dataset.Q_Ny)
+ )
+ else:
+ print("Couldn't access the data")
+ return None
+
+ # Convert the data to delayed
+ dataset_delayed = dask_data.to_delayed()
+ # TODO Trim data e.g. rx,ry,qx,qy
+ # I can pass the index values in here I should trim the probe and diffraction pattern first
+
+ # Into the meat of the function
+
+ # create an empty list to which we will append the delayed function calls
+ res = []
+ # loop over the delayed dataset chunks and create a delayed call for each
+ for x in np.ndindex(dataset_delayed.shape):
+ temp = delayed(_find_Bragg_disks_single_DP_FK_dask_wrapper)(
+ dataset_delayed[x],
+ probe_kernel_FT=dask_probe_delayed[0, 0],
+ # probe_kernel_FT=delayed_probe_kernel_FT,
+ *args,
+ **kwargs,
+ ) # passing through args from earlier or should I use
+ # corrPower=corrPower,
+ # sigma=sigma_gaussianFilter,
+ # edgeBoundary=edgeBoundary,
+ # minRelativeIntensity=minRelativeIntensity,
+ # minPeakSpacing=minPeakSpacing,
+ # maxNumPeaks=maxNumPeaks,
+ # subpixel='poly')
+ res.append(temp)
+ _temp_peaks = dask_client.compute(
+ res, optimize_graph=True
+ ) # creates futures and starts computing
+
+ output = dask_client.gather(_temp_peaks) # gather the future objects
+
+ coords = [("qx", float), ("qy", float), ("intensity", float)]
+ peaks = PointListArray(coordinates=coords, shape=dataset.data.shape[:-2])
+
+
+ # operating over a list, so we need the flat index (0->count) and the
+ # corresponding probe position (rx, ry)
+ for count, (rx, ry) in zip(
+ range(dataset.data[..., 0, 0].size),
+ np.ndindex(dataset.data.shape[:-2]),
+ ):
+ # peaks.get_pointlist(rx, ry).add_pointlist(temp_peaks[0][count])
+ # peaks.get_pointlist(rx, ry).add_pointlist(output[count][0])
+ peaks.get_pointlist(rx, ry).add_pointlist(output[count])
+
+ # Clean up
+ dask_client.cancel(_temp_peaks) # removes from the dask workers
+ del _temp_peaks # deletes the object
+ if close_dask_client:
+ dask_client.close()
+ return peaks
+ elif close_dask_client is False and return_dask_client is True:
+ return peaks, dask_client
+ elif close_dask_client and return_dask_client is False:
+ return peaks
+ else:
+ print(
+ "Dask Client in unknown state, this may result in unpredicitable behaviour later"
+ )
+ return peaks
diff --git a/py4DSTEM/braggvectors/kernels.py b/py4DSTEM/braggvectors/kernels.py
new file mode 100644
index 000000000..d36ae172b
--- /dev/null
+++ b/py4DSTEM/braggvectors/kernels.py
@@ -0,0 +1,97 @@
+import cupy as cp
+
+__all__ = ["kernels"]
+
+kernels = {}
+
+############################# multicorr kernels #################################
+
+import os
+
+with open(os.path.join(os.path.dirname(__file__), "multicorr_row_kernel.cu"), "r") as f:
+ kernels["multicorr_row_kernel"] = cp.RawKernel(f.read(), "multicorr_row_kernel")
+
+with open(os.path.join(os.path.dirname(__file__), "multicorr_col_kernel.cu"), "r") as f:
+ kernels["multicorr_col_kernel"] = cp.RawKernel(f.read(), "multicorr_col_kernel")
+
+
+############################# get_maximal_points ################################
+
+"""
+These kernels are approximately 50x faster than the np.roll approach used in the CPU version,
+per my testing with 1024x1024 pixels and float64 on a Jetson Xavier NX.
+The boundary conditions are slightly different in this version, in that pixels on the edge
+of the frame are always false. This simplifies the indexing, and since in the Braggdisk
+detection application an edgeBoundary is always applied in the case of subpixel detection,
+this is not considered a problem.
+"""
+
+maximal_pts_float32 = r"""
+extern "C" __global__
+void maximal_pts(const float *ar, bool *out, const double minAbsoluteIntensity, const long long sizex, const long long sizey, const long long N){
+ int tid = blockDim.x * blockIdx.x + threadIdx.x;
+ int x = tid / sizey;
+ int y = tid % sizey;
+ if (tid < N && x>0 && x<(sizex-1) && y>0 && y<(sizey-1)) {
+ float val = ar[tid];
+
+ out[tid] = ( val > ar[tid + sizey]) &&
+ (val > ar[tid - sizey]) &&
+ (val > ar[tid + 1]) &&
+ (val > ar[tid - 1]) &&
+ (val > ar[tid - sizey - 1]) &&
+ (val > ar[tid - sizey + 1]) &&
+ (val > ar[tid + sizey - 1]) &&
+ (val > ar[tid + sizey + 1]) &&
+ (val >= minAbsoluteIntensity);
+ }
+}
+"""
+
+kernels["maximal_pts_float32"] = cp.RawKernel(maximal_pts_float32, "maximal_pts")
+
+maximal_pts_float64 = r"""
+extern "C" __global__
+void maximal_pts(const double *ar, bool *out, const double minAbsoluteIntensity, const long long sizex, const long long sizey, const long long N){
+ int tid = blockDim.x * blockIdx.x + threadIdx.x;
+ int x = tid / sizey;
+ int y = tid % sizey;
+ if (tid < N && x>0 && x<(sizex-1) && y>0 && y<(sizey-1)) {
+ double val = ar[tid];
+
+ out[tid] = ( val > ar[tid + sizey]) &&
+ (val > ar[tid - sizey]) &&
+ (val > ar[tid + 1]) &&
+ (val > ar[tid - 1]) &&
+ (val > ar[tid - sizey - 1]) &&
+ (val > ar[tid - sizey + 1]) &&
+ (val > ar[tid + sizey - 1]) &&
+ (val > ar[tid + sizey + 1]) &&
+ (val >= minAbsoluteIntensity);
+ }
+}
+"""
+
+kernels["maximal_pts_float64"] = cp.RawKernel(maximal_pts_float64, "maximal_pts")
+
+
+################################ edge_boundary ######################################
+
+edge_boundary = r"""
+extern "C" __global__
+void edge_boundary(bool *ar, const long long edgeBoundary,
+ const long long sizex, const long long sizey, const long long N){
+ int tid = blockDim.x * blockIdx.x + threadIdx.x;
+ int x = tid % sizex;
+ int y = tid / sizey; // Floor divide
+ if (tid < N) {
+ if (x < edgeBoundary || x > (sizex-1-edgeBoundary) || y < edgeBoundary || y > (sizey-1-edgeBoundary)){
+ ar[tid] = false;
+ }
+ }
+}
+"""
+
+kernels["edge_boundary"] = cp.RawKernel(edge_boundary, "edge_boundary")
diff --git a/py4DSTEM/braggvectors/multicorr_col_kernel.cu b/py4DSTEM/braggvectors/multicorr_col_kernel.cu
new file mode 100644
index 000000000..d2c9cc4ca
--- /dev/null
+++ b/py4DSTEM/braggvectors/multicorr_col_kernel.cu
@@ -0,0 +1,61 @@
+#include <cupy/complex.cuh>
+#define PI 3.14159265359
+extern "C" __global__
+void multicorr_col_kernel(
+ complex<float> *ar,
+ const float *xyShifts,
+ const long long N_pts,
+ const long long image_size_x,
+ const long long image_size_y,
+ const long long upsample_factor) {
+ /*
+ Fill in the entries of the multicorr column kernel.
+ Inputs (C++ type/Python type):
+ ar (complex<float>* / cp.complex64): Array of size N_pts x image_size[1] x kernel_size
+ to hold the column kernels
+ xyShifts (const float* / cp.float32): (N_pts x 2) array of center points to build kernels for
+ N_pts (const long long/int) number of center points we are
+ building kernels for
+ image_size_x (const long long/int): x size of the correlation image
+ image_size_y (const long long/int): y size of correlation image
+ upsample_factor (const long long/int): note, kernel_width = ceil(1.5*upsample_factor)
+ */
+ int kernel_size = ceil(1.5 * upsample_factor);
+
+ int tid = blockDim.x * blockIdx.x + threadIdx.x;
+
+ // Using strides to compute indices:
+ int stride_0 = image_size_y * kernel_size; // Stride along 0-th dimension of ar
+ int stride_1 = kernel_size; // Stride along 1-th dimension of ar
+
+ // Which kernel in the stack (first index of ar)
+ int kernel_idx = tid / stride_0;
+ // Which row in the kernel (second index of ar)
+ int row_idx = (tid % stride_0) / stride_1;
+ // Which column in the kernel (last index of ar)
+ int col_idx = (tid % stride_0) % stride_1;
+
+ complex<float> prefactor = complex<float>(0, -2.0 * PI) / float(image_size_y * upsample_factor);
+
+ // Now do the actual calculation
+ if (tid < N_pts * image_size_y * kernel_size) {
+ // np.fft.ifftshift(np.arange(imageSize[1])) - np.floor(imageSize[1]/2)
+ // modresult is necessary to get the Pythonic behavior of mod of negative numbers
+ int modresult = int(row_idx - ceil((float)image_size_y / 2.)) % image_size_y;
+ modresult = modresult < 0 ? modresult + image_size_y : modresult;
+ float columnEntry = float(modresult) - floor((float)image_size_y/2.) ;
+
+
+ // np.arange(numColumns) - xyShift[idx,1]
+ float rowEntry = (float)col_idx - xyShifts[kernel_idx*2 + 1];
+
+ ar[tid] = exp(prefactor * columnEntry * rowEntry);
+
+ // Use these for testing the indexing:
+ // ar[tid] = complex(0,(float)tid);
+ // ar[tid] = complex(0,(float)kernel_idx);
+ // ar[tid] = complex(0,(float)row_idx);
+ // ar[tid] = complex(0,(float)col_idx);
+ }
+
+}
diff --git a/py4DSTEM/braggvectors/multicorr_row_kernel.cu b/py4DSTEM/braggvectors/multicorr_row_kernel.cu
new file mode 100644
index 000000000..a5a0f352f
--- /dev/null
+++ b/py4DSTEM/braggvectors/multicorr_row_kernel.cu
@@ -0,0 +1,60 @@
+#include <cupy/complex.cuh>
+#define PI 3.14159265359
+extern "C" __global__
+void multicorr_row_kernel(
+ complex<float> *ar,
+ const float *xyShifts,
+ const long long N_pts,
+ const long long image_size_x,
+ const long long image_size_y,
+ const long long upsample_factor) {
+ /*
+ Fill in the entries of the multicorr row kernel.
+ Inputs (C++ type/Python type):
+ ar (complex<float>* / cp.complex64): Array of size N_pts x kernel_size x image_size[0]
+ to hold the row kernels
+ xyShifts (const float* / cp.float32): (N_pts x 2) array of center points to build kernels for
+ N_pts (const long long/int) number of center points we are
+ building kernels for
+ image_size_x (const long long/int): x size of the correlation image
+ image_size_y (const long long/int): y size of correlation image
+ upsample_factor (const long long/int): note, kernel_width = ceil(1.5*upsample_factor)
+ */
+ int kernel_size = ceil(1.5 * upsample_factor);
+
+ int tid = blockDim.x * blockIdx.x + threadIdx.x;
+
+ // Using strides to compute indices:
+ int stride_0 = image_size_x * kernel_size; // Stride along 0-th dimension of ar
+ int stride_1 = image_size_x; // Stride along 1-th dimension of ar
+
+ // Which kernel in the stack (first index of ar)
+ int kernel_idx = tid / stride_0;
+ // Which row in the kernel (second index of ar)
+ int row_idx = (tid % stride_0) / stride_1;
+ // Which column in the kernel (last index of ar)
+ int col_idx = (tid % stride_0) % stride_1;
+
+ complex<float> prefactor = complex<float>(0, -2.0 * PI) / float(image_size_x * upsample_factor);
+
+ // Now do the actual calculation
+ if (tid < N_pts * image_size_x * kernel_size) {
+ // np.arange(numColumns) - xyShift[idx,0]
+ float columnEntry = (float)row_idx - xyShifts[kernel_idx*2];
+
+ // np.fft.ifftshift(np.arange(imageSize[0])) - np.floor(imageSize[0]/2)
+ // modresult is necessary to get the Pythonic behavior of mod of negative numbers
+ int modresult = int(col_idx - ceil((float)image_size_x / 2.)) % image_size_x;
+ modresult = modresult < 0 ? modresult + image_size_x : modresult;
+ float rowEntry = float(modresult) - floor((float)image_size_x/2.) ;
+
+ ar[tid] = exp(prefactor * columnEntry * rowEntry);
+
+ // Use these for testing the indexing:
+ //ar[tid] = complex(0,(float)tid);
+ //ar[tid] = complex(0,(float)kernel_idx);
+ //ar[tid] = complex(0,(float)row_idx);
+ //ar[tid] = complex(0,(float)col_idx);
+ }
+
+}
\ No newline at end of file
diff --git a/py4DSTEM/braggvectors/probe.py b/py4DSTEM/braggvectors/probe.py
new file mode 100644
index 000000000..464c2f2a4
--- /dev/null
+++ b/py4DSTEM/braggvectors/probe.py
@@ -0,0 +1,559 @@
+# Defines the Probe class
+
+import numpy as np
+from typing import Optional
+from warnings import warn
+
+from py4DSTEM.data import DiffractionSlice, Data
+from scipy.ndimage import binary_opening, binary_dilation, distance_transform_edt
+
+
+class Probe(DiffractionSlice, Data):
+ """
+ Stores a vacuum probe.
+
+ Both a vacuum probe and a kernel for cross-correlative template matching
+ derived from that probe are stored and can be accessed at
+
+ >>> p.probe
+ >>> p.kernel
+
+ respectively, for some Probe instance `p`. If a kernel has not been computed
+ the latter expression returns None.
+
+
+ """
+
+ def __init__(self, data: np.ndarray, name: Optional[str] = "probe"):
+ """
+ Accepts:
+ data (2D or 3D np.ndarray): the vacuum probe, or
+ the vacuum probe + kernel
+ name (str): a name
+
+ Returns:
+ (Probe)
+ """
+ # if only the probe is passed, make space for the kernel
+ if data.ndim == 2:
+ data = np.stack([data, np.zeros_like(data)])
+
+ # initialize as a DiffractionSlice
+ DiffractionSlice.__init__(
+ self, name=name, data=data, slicelabels=["probe", "kernel"]
+ )
+
+ ## properties
+
+ @property
+ def probe(self):
+ return self.get_slice("probe").data
+
+ @probe.setter
+ def probe(self, x):
+ assert x.shape == (self.data.shape[1:])
+ self.data[0, :, :] = x
+
+ @property
+ def kernel(self):
+ return self.get_slice("kernel").data
+
+ @kernel.setter
+ def kernel(self, x):
+ assert x.shape == (self.data.shape[1:])
+ self.data[1, :, :] = x
+
+ # read
+ @classmethod
+ def _get_constructor_args(cls, group):
+ """
+ Returns a dictionary of args/values to pass to the class constructor
+ """
+ ar_constr_args = DiffractionSlice._get_constructor_args(group)
+ args = {
+ "data": ar_constr_args["data"],
+ "name": ar_constr_args["name"],
+ }
+ return args
+
+ # generation methods
+
+ @classmethod
+ def from_vacuum_data(cls, data, mask=None, threshold=0.2, expansion=12, opening=3):
+ """
+ Generates and returns a vacuum probe Probe instance from either a
+ 2D vacuum image or a 3D stack of vacuum diffraction patterns.
+
+ The probe is multiplied by `mask`, if it's passed. An additional
+ masking step zeros values outside of a mask determined by `threshold`,
+ `expansion`, and `opening`, generated by first computing the binary image
+ probe > max(probe)*threshold, then applying a binary expansion and
+ then opening to this image. No alignment is performed - i.e. it is assumed
+ that the beam was stationary during acquisition of the stack. To align
+ the images, use the DataCube .get_vacuum_probe method.
+
+ Parameters
+ ----------
+ data : 2D or 3D array
+ the vacuum diffraction data. For 3D stacks, use shape (N,Q_Nx,Q_Ny)
+ mask : boolean array, optional
+ mask applied to the probe
+ threshold : float
+ threshold determining mask which zeros values outside of probe
+ expansion : int
+ number of pixels by which the zeroing mask is expanded to capture
+ the full probe
+ opening : int
+ size of binary opening used to eliminate stray bright pixels
+
+ Returns
+ -------
+ probe : Probe
+ the vacuum probe
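+
+ For example, with `stack` an illustrative 3D numpy array of vacuum patterns:
+
+ >>> probe = Probe.from_vacuum_data(stack, threshold=0.2, expansion=12)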
+ """
+ assert isinstance(data, np.ndarray)
+ if data.ndim == 3:
+ probe = np.average(data, axis=0)
+ elif data.ndim == 2:
+ probe = data
+ else:
+ raise Exception(f"data must be 2- or 3-D, not {data.ndim}-D")
+
+ if mask is not None:
+ probe *= mask
+
+ mask = probe > np.max(probe) * threshold
+ mask = binary_opening(mask, iterations=opening)
+ mask = binary_dilation(mask, iterations=1)
+ mask = (
+ np.cos(
+ (np.pi / 2)
+ * np.minimum(
+ distance_transform_edt(np.logical_not(mask)) / expansion, 1
+ )
+ )
+ ** 2
+ )
+
+ probe = cls(probe * mask)
+ return probe
+
+ @classmethod
+ def generate_synthetic_probe(cls, radius, width, Qshape):
+ """
+ Makes a synthetic probe, with the functional form of a disk blurred by a
+ sigmoid (a logistic function).
+
+ Parameters
+ ----------
+ radius : float
+ the probe radius
+ width : float
+ the blurring of the probe edge. width represents the
+ full width of the blur, with x=-w/2 to x=+w/2 about the edge
+ spanning values of ~0.12 to 0.88
+ Qshape : 2 tuple
+ the diffraction plane dimensions
+
+ Returns
+ -------
+ probe : Probe
+ the probe
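+
+ For example:
+
+ >>> probe = Probe.generate_synthetic_probe(radius=30, width=5, Qshape=(256, 256))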
+ """
+ # Make coords
+ Q_Nx, Q_Ny = Qshape
+ qy, qx = np.meshgrid(np.arange(Q_Ny), np.arange(Q_Nx))
+ qy, qx = qy - Q_Ny / 2.0, qx - Q_Nx / 2.0
+ qr = np.sqrt(qx**2 + qy**2)
+
+ # Shift zero to disk edge
+ qr = qr - radius
+
+ # Calculate logistic function
+ probe = 1 / (1 + np.exp(4 * qr / width))
+
+ return cls(probe)
+
+ # calibration methods
+
+ def measure_disk(
+ self,
+ thresh_lower=0.01,
+ thresh_upper=0.99,
+ N=100,
+ returncalc=True,
+ data=None,
+ ):
+ """
+ Finds the center and radius of an average probe image.
+
+ A naive algorithm. Creates a series of N binary masks by thresholding
+ the probe image at a linspace of N thresholds from thresh_lower to
+ thresh_upper, relative to the image max/min. For each mask, we find the
+ square root of the number of True valued pixels divided by pi to
+ estimate a radius. Because the central disk is intense relative to the
+ remainder of the image, the computed radii are expected to vary very
+ little over a wide range of threshold values. A range of r values
+ considered trustworthy is estimated by taking the derivative
+ dr(thresh)/dthresh and identifying where it is small, and the mean of this
+ range is returned as the radius. A center is estimated using a binary
+ thresholded image in combination with the center of mass operator.
+
+ Parameters
+ ----------
+ thresh_lower : float, 0 to 1
+ the lower limit of threshold values
+ thresh_upper : float, 0 to 1
+ the upper limit of threshold values
+ N : int
+ the number of thresholds / masks to use
+ returncalc : bool
+ toggles returning the answer
+ data : 2d array, optional
+ if passed, uses this 2D array in place of the probe image when
+ performing the computation. This also suppresses storing the
+ results in the Probe's calibration metadata
+
+ Returns
+ -------
+ r, x0, y0 : (3-tuple)
+ the radius and origin
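+
+ For example, for some Probe instance `p`:
+
+ >>> r, x0, y0 = p.measure_disk(thresh_lower=0.01, thresh_upper=0.99, N=100)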
+ """
+ from py4DSTEM.process.utils import get_CoM
+
+ # set the image
+ im = self.probe if data is None else data
+
+ # define the thresholds
+ thresh_vals = np.linspace(thresh_lower, thresh_upper, N)
+ r_vals = np.zeros(N)
+
+ # get binary images and compute a radius for each
+ immax = np.max(im)
+ for i, val in enumerate(thresh_vals):
+ mask = im > immax * val
+ r_vals[i] = np.sqrt(np.sum(mask) / np.pi)
+
+ # Get the derivative of r with respect to threshold and determine trustworthy r-values
+ dr_dthresh = np.gradient(r_vals)
+ mask = (dr_dthresh <= 0) * (dr_dthresh >= 2 * np.median(dr_dthresh))
+ r = np.mean(r_vals[mask])
+
+ # Get origin
+ thresh = np.mean(thresh_vals[mask])
+ mask = im > immax * thresh
+ x0, y0 = get_CoM(im * mask)
+
+ # Store metadata and return
+ ans = r, x0, y0
+ if data is None:
+ try:
+ self.calibration.set_probe_param(ans)
+ except AttributeError:
+ warn(
+ f"Couldn't store the probe parameters in metadata as no calibration was found for this Probe instance, {self}"
+ )
+ if returncalc:
+ return ans
+
+ # Kernel generation methods
+
+ def get_kernel(
+ self, mode="flat", origin=None, data=None, returncalc=True, **kwargs
+ ):
+ """
+ Creates a cross-correlation kernel from the vacuum probe.
+
+ Specific behavior and valid keyword arguments depend on the `mode`
+        specified. In each case, the center of the probe is shifted to the
+        origin and the kernel is normalized such that it sums to 1. This is the
+        only processing performed if mode is 'flat'. Otherwise, a centrosymmetric
+        region of negative intensity, intended to promote edge-filtering-like
+        behavior during cross correlation, is added around the probe, with the
+ functional form of the subtracted region defined by `mode` and the
+ relevant **kwargs. For normalization, flat probes integrate to 1, and the
+ remaining probes integrate to 1 before subtraction and 0 after. Required
+ keyword arguments are:
+
+ - 'flat': No required arguments. This mode is recommended for bullseye
+ or other structured probes
+ - 'gaussian': Required arg `sigma` (number), the width (standard
+ deviation) of a centered gaussian to be subtracted.
+ - 'sigmoid': Required arg `radii` (2-tuple), the inner and outer radii
+ (ri,ro) of an annular region with a sine-squared sigmoidal radial
+ profile to be subtracted.
+ - 'sigmoid_log': Required arg `radii` (2-tuple), the inner and outer radii
+ (ri,ro) of an annular region with a logistic sigmoidal radial
+ profile to be subtracted.
+
+ Parameters
+ ----------
+ mode : str
+ must be in 'flat','gaussian','sigmoid','sigmoid_log'
+ origin : 2-tuple, optional
+ specify the origin. If not passed, looks for a value for the probe
+ origin in metadata. If not found there, calls .measure_disk.
+ data : 2d array, optional
+ if specified, uses this array instead of the probe image to compute
+ the kernel
+ **kwargs
+ see descriptions above
+
+ Returns
+ -------
+ kernel : 2D array
+ """
+
+ modes = ["flat", "gaussian", "sigmoid", "sigmoid_log"]
+
+ # parse args
+ assert mode in modes, f"mode must be in {modes}. Received {mode}"
+
+ # get function
+ function_dict = {
+ "flat": self.get_probe_kernel_flat,
+ "gaussian": self.get_probe_kernel_edge_gaussian,
+ "sigmoid": self._get_probe_kernel_edge_sigmoid_sine_squared,
+ "sigmoid_log": self._get_probe_kernel_edge_sigmoid_sine_squared,
+ }
+ fn = function_dict[mode]
+
+ # check for the origin
+ if origin is None:
+ try:
+                x = self.calibration.get_probe_param()
+ except AttributeError:
+ x = None
+ finally:
+ if x is None:
+ origin = None
+ else:
+ r, x, y = x
+ origin = (x, y)
+
+ # get the data
+ probe = data if data is not None else self.probe
+
+ # compute
+ kern = fn(probe, origin=origin, **kwargs)
+
+ # add to the Probe
+ self.kernel = kern
+
+ # return
+ if returncalc:
+ return kern
+
+ @staticmethod
+ def get_probe_kernel_flat(probe, origin=None, bilinear=False):
+ """
+ Creates a cross-correlation kernel from the vacuum probe by normalizing
+ and shifting the center.
+
+ Parameters
+ ----------
+ probe : 2d array
+ the vacuum probe
+        origin : 2-tuple (optional)
+            the origin of diffraction space. If not specified, finds the origin
+            using get_probe_size.
+        bilinear : bool (optional)
+            By default the probe is shifted via a Fourier transform. Setting
+            this to True uses bilinear shifting instead. Not recommended!
+
+ Returns
+ -------
+ kernel : ndarray
+ the cross-correlation kernel corresponding to the probe, in real
+ space
+ """
+ from py4DSTEM.process.utils import get_shifted_ar
+
+ Q_Nx, Q_Ny = probe.shape
+
+ # Get CoM
+ if origin is None:
+ from py4DSTEM.process.calibration import get_probe_size
+
+ _, xCoM, yCoM = get_probe_size(probe)
+ else:
+ xCoM, yCoM = origin
+
+ # Normalize
+ probe = probe / np.sum(probe)
+
+ # Shift center to corners of array
+ probe_kernel = get_shifted_ar(probe, -xCoM, -yCoM, bilinear=bilinear)
+
+ # Return
+ return probe_kernel
+
+ @staticmethod
+ def get_probe_kernel_edge_gaussian(
+ probe,
+ sigma,
+ origin=None,
+ bilinear=True,
+ ):
+ """
+ Creates a cross-correlation kernel from the probe, subtracting a
+ gaussian from the normalized probe such that the kernel integrates to
+ zero, then shifting the center of the probe to the array corners.
+
+ Parameters
+ ----------
+ probe : ndarray
+ the diffraction pattern corresponding to the probe over vacuum
+ sigma : float
+ the width of the gaussian to subtract, relative to the standard
+ deviation of the probe
+        origin : 2-tuple (optional)
+            the origin of diffraction space. If not specified, finds the origin
+            using get_probe_size.
+        bilinear : bool
+            if True (the default here), the probe is shifted with bilinear
+            interpolation; if False, a Fourier shift is used
+
+ Returns
+ -------
+ kernel : ndarray
+ the cross-correlation kernel
+ """
+ from py4DSTEM.process.utils import get_shifted_ar
+
+ Q_Nx, Q_Ny = probe.shape
+
+ # Get CoM
+ if origin is None:
+ from py4DSTEM.process.calibration import get_probe_size
+
+ _, xCoM, yCoM = get_probe_size(probe)
+ else:
+ xCoM, yCoM = origin
+
+ # Shift probe to origin
+ probe_kernel = get_shifted_ar(probe, -xCoM, -yCoM, bilinear=bilinear)
+
+ # Generate normalization kernel
+ # Coordinates
+ qy, qx = np.meshgrid(
+ np.mod(np.arange(Q_Ny) + Q_Ny // 2, Q_Ny) - Q_Ny // 2,
+ np.mod(np.arange(Q_Nx) + Q_Nx // 2, Q_Nx) - Q_Nx // 2,
+ )
+ qr2 = qx**2 + qy**2
+ # Calculate Gaussian normalization kernel
+ qstd2 = np.sum(qr2 * probe_kernel) / np.sum(probe_kernel)
+ kernel_norm = np.exp(-qr2 / (2 * qstd2 * sigma**2))
+
+ # Output normalized kernel
+ probe_kernel = probe_kernel / np.sum(probe_kernel) - kernel_norm / np.sum(
+ kernel_norm
+ )
+
+ return probe_kernel
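
The normalization can be checked in isolation: the positive probe part and the subtracted gaussian are each normalized to unit sum, so the final kernel sums to ~0 for any probe and sigma. A minimal standalone sketch:

```python
import numpy as np

# Stand-in probe and the same corner-centered coordinates used above
rng = np.random.default_rng(0)
probe_kernel = rng.random((64, 64))
qy, qx = np.meshgrid(
    np.mod(np.arange(64) + 32, 64) - 32, np.mod(np.arange(64) + 32, 64) - 32
)
qr2 = qx**2 + qy**2
qstd2 = np.sum(qr2 * probe_kernel) / np.sum(probe_kernel)
kernel_norm = np.exp(-qr2 / (2 * qstd2 * 1.0**2))  # sigma = 1
kern = probe_kernel / probe_kernel.sum() - kernel_norm / kernel_norm.sum()
print(kern.sum())  # ~0 to floating point precision
```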
+
+ @staticmethod
+ def get_probe_kernel_edge_sigmoid(
+ probe,
+ radii,
+ origin=None,
+ type="sine_squared",
+ bilinear=True,
+ ):
+ """
+ Creates a convolution kernel from an average probe, subtracting an annular
+ trench about the probe such that the kernel integrates to zero, then
+ shifting the center of the probe to the array corners.
+
+ Parameters
+ ----------
+ probe : ndarray
+ the diffraction pattern corresponding to the probe over vacuum
+ radii : 2-tuple
+ the sigmoid inner and outer radii
+        origin : 2-tuple (optional)
+            the origin of diffraction space. If not specified, finds the origin
+            using get_probe_size.
+        type : string
+            must be 'logistic' or 'sine_squared'
+        bilinear : bool
+            if True (the default here), the probe is shifted with bilinear
+            interpolation; if False, a Fourier shift is used
+
+ Returns
+ -------
+ kernel : 2d array
+ the cross-correlation kernel
+ """
+ from py4DSTEM.process.utils import get_shifted_ar
+
+ # parse inputs
+ if isinstance(probe, Probe):
+ probe = probe.probe
+
+ valid_types = ("logistic", "sine_squared")
+ assert type in valid_types, "type must be in {}".format(valid_types)
+ Q_Nx, Q_Ny = probe.shape
+ ri, ro = radii
+
+ # Get CoM
+ if origin is None:
+ from py4DSTEM.process.calibration import get_probe_size
+
+ _, xCoM, yCoM = get_probe_size(probe)
+ else:
+ xCoM, yCoM = origin
+
+ # Shift probe to origin
+ probe_kernel = get_shifted_ar(probe, -xCoM, -yCoM, bilinear=bilinear)
+
+ # Generate normalization kernel
+ # Coordinates
+ qy, qx = np.meshgrid(
+ np.mod(np.arange(Q_Ny) + Q_Ny // 2, Q_Ny) - Q_Ny // 2,
+ np.mod(np.arange(Q_Nx) + Q_Nx // 2, Q_Nx) - Q_Nx // 2,
+ )
+ qr = np.sqrt(qx**2 + qy**2)
+ # Calculate sigmoid
+ if type == "logistic":
+ r0 = 0.5 * (ro + ri)
+ sigma = 0.25 * (ro - ri)
+ sigmoid = 1 / (1 + np.exp((qr - r0) / sigma))
+ elif type == "sine_squared":
+ sigmoid = (qr - ri) / (ro - ri)
+ sigmoid = np.minimum(np.maximum(sigmoid, 0.0), 1.0)
+ sigmoid = np.cos((np.pi / 2) * sigmoid) ** 2
+ else:
+ raise Exception("type must be in {}".format(valid_types))
+
+ # Output normalized kernel
+ probe_kernel = probe_kernel / np.sum(probe_kernel) - sigmoid / np.sum(sigmoid)
+
+ return probe_kernel
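
For reference, this is a standalone evaluation of the sine-squared radial profile used above; it falls smoothly from 1 at the inner radius ri to 0 at the outer radius ro:

```python
import numpy as np

# s(r) = cos^2( (pi/2) * clip((r - ri)/(ro - ri), 0, 1) )
ri, ro = 10.0, 20.0
r = np.array([5.0, 10.0, 15.0, 20.0, 25.0])
s = np.cos((np.pi / 2) * np.clip((r - ri) / (ro - ri), 0, 1)) ** 2
print(s)  # [1.0, 1.0, 0.5, 0.0, 0.0]
```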
+
+ def _get_probe_kernel_edge_sigmoid_sine_squared(
+ self,
+ probe,
+ radii,
+ origin=None,
+ **kwargs,
+ ):
+ return self.get_probe_kernel_edge_sigmoid(
+ probe,
+ radii,
+ origin=origin,
+ type="sine_squared",
+ **kwargs,
+ )
+
+ def _get_probe_kernel_edge_sigmoid_logistic(
+ self,
+ probe,
+ radii,
+ origin=None,
+ **kwargs,
+ ):
+ return self.get_probe_kernel_edge_sigmoid(
+ probe, radii, origin=origin, type="logistic", **kwargs
+ )
diff --git a/py4DSTEM/braggvectors/threshold.py b/py4DSTEM/braggvectors/threshold.py
new file mode 100644
index 000000000..c6c1c4afc
--- /dev/null
+++ b/py4DSTEM/braggvectors/threshold.py
@@ -0,0 +1,225 @@
+# Bragg peaks thresholding fns
+
+import numpy as np
+
+from emdfile import tqdmnd, PointListArray
+
+
+def threshold_Braggpeaks(
+ pointlistarray, minRelativeIntensity, relativeToPeak, minPeakSpacing, maxNumPeaks
+):
+ """
+ Takes a PointListArray of detected Bragg peaks and applies additional
+ thresholding, returning the thresholded PointListArray. To skip a threshold,
+ set that parameter to False.
+
+ Args:
+ pointlistarray (PointListArray): The Bragg peaks. Must have
+ coords=('qx','qy','intensity')
+ minRelativeIntensity (float): the minimum allowed peak intensity,
+ relative to the brightest peak in each diffraction pattern
+ relativeToPeak (int): specifies the peak against which the minimum
+ relative intensity is measured -- 0=brightest maximum. 1=next
+ brightest, etc.
+ minPeakSpacing (int): the minimum allowed spacing between adjacent peaks
+ maxNumPeaks (int): maximum number of allowed peaks per diffraction
+ pattern
+ """
+ assert all(
+ [item in pointlistarray.dtype.fields for item in ["qx", "qy", "intensity"]]
+ ), "pointlistarray must include the coordinates 'qx', 'qy', and 'intensity'."
+ for Rx, Ry in tqdmnd(
+ pointlistarray.shape[0],
+ pointlistarray.shape[1],
+ desc="Thresholding Bragg disks",
+ unit="DP",
+ unit_scale=True,
+ ):
+ pointlist = pointlistarray.get_pointlist(Rx, Ry)
+ pointlist.sort(coordinate="intensity", order="descending")
+
+ # Remove peaks below minRelativeIntensity threshold
+ if minRelativeIntensity is not False:
+ deletemask = (
+ pointlist.data["intensity"]
+ / pointlist.data["intensity"][relativeToPeak]
+ < minRelativeIntensity
+ )
+ pointlist.remove_points(deletemask)
+
+        # Remove peaks that are too close together
+        # (`is False` on a numpy bool never matches, so test with `not`;
+        # this block is gated on minPeakSpacing, not maxNumPeaks)
+        if minPeakSpacing is not False:
+            r2 = minPeakSpacing**2
+            deletemask = np.zeros(pointlist.length, dtype=bool)
+            for i in range(pointlist.length):
+                if not deletemask[i]:
+                    tooClose = (
+                        (pointlist.data["qx"] - pointlist.data["qx"][i]) ** 2
+                        + (pointlist.data["qy"] - pointlist.data["qy"][i]) ** 2
+                    ) < r2
+                    tooClose[: i + 1] = False
+                    deletemask[tooClose] = True
+            pointlist.remove_points(deletemask)
+
+ # Keep only up to maxNumPeaks
+ if maxNumPeaks is not False:
+ if maxNumPeaks < pointlist.length:
+ deletemask = np.zeros(pointlist.length, dtype=bool)
+ deletemask[maxNumPeaks:] = True
+ pointlist.remove_points(deletemask)
+
+ return pointlistarray
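
The spacing filter walks the intensity-sorted peaks and deletes any weaker peak within `minPeakSpacing` of a survivor. A minimal standalone sketch of that inner loop, using plain numpy and illustrative names:

```python
import numpy as np

# Three peaks, sorted by intensity; the first two are ~0.5 px apart
qx = np.array([10.0, 10.5, 30.0])
qy = np.array([10.0, 10.0, 30.0])
minPeakSpacing = 2
r2 = minPeakSpacing**2
deletemask = np.zeros(3, dtype=bool)
for i in range(3):
    if not deletemask[i]:
        tooClose = (qx - qx[i]) ** 2 + (qy - qy[i]) ** 2 < r2
        tooClose[: i + 1] = False   # never delete self or stronger peaks
        deletemask[tooClose] = True
print(deletemask)  # [False  True False]: the weaker near-duplicate is dropped
```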
+
+
+def universal_threshold(
+ pointlistarray,
+ thresh,
+ metric="maximum",
+ minPeakSpacing=False,
+ maxNumPeaks=False,
+ name=None,
+):
+ """
+ Takes a PointListArray of detected Bragg peaks and applies universal
+ thresholding, returning the thresholded PointListArray. To skip a threshold,
+ set that parameter to False.
+
+ Args:
+ pointlistarray (PointListArray): The Bragg peaks. Must have
+ coords=('qx','qy','intensity')
+ thresh (float): the minimum allowed peak intensity. The meaning of this
+ threshold value is determined by the value of the 'metric' argument,
+ below
+ metric (string): the metric used to compare intensities. Must be in
+ ('maximum','average','median','manual'). In each case aside from
+ 'manual', the intensity threshold is set to Val*thresh, where Val is
+ given by
+            * 'maximum' - the maximum intensity in the entire pointlistarray
+            * 'average' - the average of the maximum intensities of each
+              scan position in the pointlistarray
+            * 'median' - the median of the maximum intensities of each
+              scan position in the pointlistarray
+            If metric is 'manual', the threshold is exactly thresh
+ minPeakSpacing (int): the minimum allowed spacing between adjacent peaks.
+ optional, default is false
+ maxNumPeaks (int): maximum number of allowed peaks per diffraction pattern.
+ optional, default is false
+ name (str, optional): a name for the returned PointListArray. If
+ unspecified, takes the old PLA name and appends '_unithresh'.
+
+ Returns:
+ (PointListArray): Bragg peaks thresholded by intensity.
+ """
+ assert isinstance(pointlistarray, PointListArray)
+ assert metric in ("maximum", "average", "median", "manual")
+ assert all(
+ [item in pointlistarray.dtype.fields for item in ["qx", "qy", "intensity"]]
+ ), "pointlistarray must include the coordinates 'qx', 'qy', and 'intensity'."
+    _pointlistarray = pointlistarray.copy()
+    if name is None:
+        _pointlistarray.name = pointlistarray.name + "_unithresh"
+    else:
+        _pointlistarray.name = name
+
+ HI_array = np.zeros((_pointlistarray.shape[0], _pointlistarray.shape[1]))
+ for Rx, Ry in tqdmnd(
+ _pointlistarray.shape[0],
+ _pointlistarray.shape[1],
+ desc="Thresholding Bragg disks",
+ unit="DP",
+ unit_scale=True,
+ ):
+ pointlist = _pointlistarray.get_pointlist(Rx, Ry)
+        if pointlist.data.shape[0] == 0:
+            # record empty patterns as NaN so they don't skew the metrics
+            HI_array[Rx, Ry] = np.nan
+        else:
+            HI_array[Rx, Ry] = np.max(pointlist.data["intensity"])
+
+ if metric == "maximum":
+ _thresh = np.max(HI_array) * thresh
+ elif metric == "average":
+ _thresh = np.nanmean(HI_array) * thresh
+ elif metric == "median":
+ _thresh = np.median(HI_array) * thresh
+ else:
+ _thresh = thresh
+
+ for Rx, Ry in tqdmnd(
+ _pointlistarray.shape[0],
+ _pointlistarray.shape[1],
+ desc="Thresholding Bragg disks",
+ unit="DP",
+ unit_scale=True,
+ ):
+ pointlist = _pointlistarray.get_pointlist(Rx, Ry)
+
+        # Remove peaks below the intensity threshold
+        deletemask = pointlist.data["intensity"] < _thresh
+ pointlist.remove(deletemask)
+
+        # Remove peaks that are too close together
+        # (same fixes as above: gate on minPeakSpacing, test with `not`)
+        if minPeakSpacing is not False:
+            r2 = minPeakSpacing**2
+            deletemask = np.zeros(pointlist.length, dtype=bool)
+            for i in range(pointlist.length):
+                if not deletemask[i]:
+                    tooClose = (
+                        (pointlist.data["qx"] - pointlist.data["qx"][i]) ** 2
+                        + (pointlist.data["qy"] - pointlist.data["qy"][i]) ** 2
+                    ) < r2
+                    tooClose[: i + 1] = False
+                    deletemask[tooClose] = True
+            pointlist.remove_points(deletemask)
+
+ # Keep only up to maxNumPeaks
+ if maxNumPeaks is not False:
+ if maxNumPeaks < pointlist.length:
+ deletemask = np.zeros(pointlist.length, dtype=bool)
+ deletemask[maxNumPeaks:] = True
+ pointlist.remove_points(deletemask)
+ return _pointlistarray
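
The three automatic metrics reduce the per-pattern maxima to a single threshold; a standalone sketch with a toy `HI_array`, with NaN marking an empty pattern, matching the NaN-aware reductions above:

```python
import numpy as np

HI_array = np.array([[100.0, 80.0], [60.0, np.nan]])
thresh = 0.5
print(np.nanmax(HI_array) * thresh)     # 'maximum' -> 50.0
print(np.nanmean(HI_array) * thresh)    # 'average' -> 40.0
print(np.nanmedian(HI_array) * thresh)  # 'median'  -> 40.0
```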
+
+
+def get_pointlistarray_intensities(pointlistarray):
+ """
+    Concatenates the Bragg peak intensities from a PointListArray of Bragg peak
+ positions into one array and returns the intensities. This output can be used
+ for understanding the distribution of intensities in your dataset for
+ universal thresholding.
+
+ Args:
+ pointlistarray (PointListArray):
+
+ Returns:
+ (ndarray): all detected peak intensities
+ """
+    assert np.all(
+        [name in pointlistarray.dtype.names for name in ["qx", "qy", "intensity"]]
+    ), "pointlistarray coords must include coordinates: 'qx', 'qy', 'intensity'."
+
+    intensities = []
+    for Rx, Ry in tqdmnd(
+        pointlistarray.shape[0],
+        pointlistarray.shape[1],
+        desc="Getting disk intensities",
+        unit="DP",
+        unit_scale=True,
+    ):
+        pointlist = pointlistarray.get_pointlist(Rx, Ry)
+        # gather each pattern's intensity field and concatenate once at the
+        # end, rather than growing an array with np.append per peak
+        intensities.append(np.atleast_1d(pointlist.data["intensity"]))
+    peak_intensities = (
+        np.concatenate(intensities) if len(intensities) > 0 else np.array([])
+    )
+ return peak_intensities
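
A hedged usage sketch: the returned array is typically inspected, e.g. as a histogram, before choosing `thresh` for universal_threshold. Here `braggpeaks` stands in for a real PointListArray:

```python
>>> ints = get_pointlistarray_intensities(braggpeaks)
>>> import matplotlib.pyplot as plt
>>> plt.hist(np.log10(ints), bins=100)   # inspect the distribution
>>> plt.show()
```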
diff --git a/py4DSTEM/data/__init__.py b/py4DSTEM/data/__init__.py
new file mode 100644
index 000000000..ac697918d
--- /dev/null
+++ b/py4DSTEM/data/__init__.py
@@ -0,0 +1,7 @@
+_emd_hook = True
+
+from py4DSTEM.data.calibration import Calibration
+from py4DSTEM.data.data import Data
+from py4DSTEM.data.diffractionslice import DiffractionSlice
+from py4DSTEM.data.realslice import RealSlice
+from py4DSTEM.data.qpoints import QPoints
diff --git a/py4DSTEM/data/calibration.py b/py4DSTEM/data/calibration.py
new file mode 100644
index 000000000..408f977cc
--- /dev/null
+++ b/py4DSTEM/data/calibration.py
@@ -0,0 +1,920 @@
+# Defines the Calibration class, which stores calibration metadata
+
+import numpy as np
+from numbers import Number
+from typing import Optional
+from warnings import warn
+
+from emdfile import Metadata, Root
+from py4DSTEM.data.propagating_calibration import call_calibrate
+
+
+class Calibration(Metadata):
+ """
+ Stores calibration measurements.
+
+ Usage
+ -----
+ For some calibration instance `c`
+
+ >>> c['x'] = y
+
+ will set the value of some calibration item called 'x' to y, and
+
+ >>> _y = c['x']
+
+ will return the value currently stored as 'x' and assign it to _y.
+ Additionally, for calibration items in the list `l` given below,
+ the syntax
+
+ >>> c.set_p(p)
+ >>> p = c.get_p()
+
+ is equivalent to
+
+ >>> c.p = p
+ >>> p = c.p
+
+ is equivalent to
+
+ >>> c['p'] = p
+ >>> p = c['p']
+
+ where in the first line of each couplet the parameter `p` is set and in
+ the second it's retrieved, for parameters p in the list
+
+ calibrate
+ ---------
+ l = [
+ Q_pixel_size, *
+ R_pixel_size, *
+ Q_pixel_units, *
+ R_pixel_units, *
+ qx0,
+ qy0,
+ qx0_mean,
+ qy0_mean,
+ qx0shift,
+ qy0shift,
+ origin, *
+ origin_meas,
+ origin_meas_mask,
+ origin_shift,
+ a, *
+ b, *
+ theta, *
+ p_ellipse, *
+ ellipse, *
+ QR_rotation_degrees, *
+ QR_flip, *
+ QR_rotflip, *
+ probe_semiangle,
+ probe_param,
+ probe_center,
+ probe_convergence_semiangle_pixels,
+ probe_convergence_semiangle_mrad,
+ ]
+
+ There are two advantages to using the getter/setter syntax for parameters
+ in `l` (e.g. either c.set_p or c.p) instead of the normal dictionary-like
+ getter/setter syntax (i.e. c['p']). These are (1) enabling retrieving
+ parameters by beam scan position, and (2) enabling propagation of any
+ calibration changes to downstream data objects which are affected by the
+ altered calibrations. See below.
+
+ Get a parameter by beam scan position
+ -------------------------------------
+ Some parameters support retrieval by beam scan position. In these cases,
+ calling
+
+ >>> c.get_p(rx,ry)
+
+ will return the value of parameter p at beam position (rx,ry). This works
+ only for the above syntax. Using either of
+
+ >>> c.p
+ >>> c['p']
+
+ will return an R-space shaped array.
+
+ Trigger downstream calibrations
+ -------------------------------
+ Some objects store their own internal calibration state, which depends on
+ the calibrations stored here. For example, a DataCube stores dimension
+ vectors which calibrate its 4 dimensions, and which depend on the pixel
+ sizes and the origin position.
+
+ Modifying certain parameters therefore can trigger other objects which
+ depend on these parameters to re-calibrate themselves by calling their
+ .calibrate() method, if the object has one. Methods marked with a * in the
+ list `l` above have this property. Only objects registered with the
+ Calibration instance will have their .calibrate method triggered by changing
+ these parameters. An object `data` can be registered by calling
+
+ >>> c.register_target( data )
+
+    and deregistered with
+
+    >>> c.unregister_target( data )
+
+    If an object without a .calibrate method is registered when a * method is
+    called, nothing happens.
+
+ The .calibrate methods are triggered by setting some parameter `p` using
+ either
+
+ >>> c.set_p( val )
+
+ or
+
+ >>> c.p = val
+
+ syntax. Setting the parameter with
+
+ >>> c['p'] = val
+
+ will not trigger re-calibrations.
+
+ Calibration + Data
+ ------------------
+ Data in py4DSTEM is stored in filetree like representations, and
+ Calibration instances are the top-level objects in these trees,
+ in that they live here:
+
+ Root
+ |--metadata
+ | |-- *****---> calibration <---*****
+ |
+ |--some_object(e.g.datacube)
+ | |--another_object(e.g.max_dp)
+ | |--etc.
+ |--etc.
+ :
+
+ Every py4DSTEM Data object has a tree with a calibration, and calling
+
+ >>> data.calibration
+
+    will return that Calibration instance. See also the docstring
+ for the `Data` class.
+
+ Attaching an object to a different Calibration
+ ----------------------------------------------
+ To modify the calibration associated with some object `data`, use
+
+ >>> c.attach( data )
+
+ where `c` is the new calibration instance. This (1) moves `data` into the
+ top level of `c`'s data tree, which means the new calibration will now be
+ accessible normally at
+
+ >>> data.calibration
+
+ and (2) if and only if `data` was registered with its old calibration,
+ de-registers it there and registers it with the new calibration. If
+ `data` was not registered with the old calibration and it should be
+ registered with the new one, `c.register_target( data )` should be
+ called.
+
+ To attach `data` to a different location in the calibration instance's
+ tree, use `node.attach( data )`. See the Data.attach docstring.
+ """
+
+ def __init__(
+ self,
+ name: Optional[str] = "calibration",
+ root: Optional[Root] = None,
+ ):
+ """
+ Args:
+ name (optional, str):
+ """
+ Metadata.__init__(self, name=name)
+
+ # Set the root
+ if root is None:
+ root = Root(name="py4DSTEM_root")
+ self.set_root(root)
+
+ # List to hold objects that will re-`calibrate` when
+ # certain properties are changed
+ self._targets = []
+
+ # set initial pixel values
+ self["Q_pixel_size"] = 1
+ self["R_pixel_size"] = 1
+ self["Q_pixel_units"] = "pixels"
+ self["R_pixel_units"] = "pixels"
+ self["QR_flip"] = False
+
+ # EMD root property
+ @property
+ def root(self):
+ return self._root
+
+ @root.setter
+    def root(self, x):
+ raise Exception(
+ "Calibration.root does not support assignment; to change the root, use self.set_root"
+ )
+
+ def set_root(self, root):
+ assert isinstance(root, Root), f"root must be a Root, not type {type(root)}"
+ self._root = root
+
+ # Attach data to the calibration instance
+ def attach(self, data):
+ """
+ Attach `data` to this calibration instance, placing it in the top
+ level of the Calibration instance's tree. If `data` was in a
+ different data tree, remove it. If `data` was registered with
+ a different calibration instance, de-register it there and
+        register it here. If `data` was not previously registered and it
+ should be, after attaching it run `self.register_target(data)`.
+ """
+ from py4DSTEM.data import Data
+
+ assert isinstance(data, Data), "data must be a Data instance"
+ self.root.attach(data)
+
+ # Register for auto-calibration
+ def register_target(self, new_target):
+ """
+        Register an object to receive calls to its `calibrate`
+        method when certain calibrations get updated
+ """
+ if new_target not in self._targets:
+ self._targets.append(new_target)
+
+ def unregister_target(self, target):
+ """
+        Unlink an object from receiving calls to `calibrate` when
+        certain calibration values are changed
+ """
+ if target in self._targets:
+ self._targets.remove(target)
+
+ @property
+ def targets(self):
+ return tuple(self._targets)
+
+ ######### Begin Calibration Metadata Params #########
+
+ # pixel size/units
+
+ @call_calibrate
+ def set_Q_pixel_size(self, x):
+ self._params["Q_pixel_size"] = x
+
+ def get_Q_pixel_size(self):
+ return self._get_value("Q_pixel_size")
+
+ # aliases
+ @property
+ def Q_pixel_size(self):
+ return self.get_Q_pixel_size()
+
+ @Q_pixel_size.setter
+ def Q_pixel_size(self, x):
+ self.set_Q_pixel_size(x)
+
+ @property
+ def qpixsize(self):
+ return self.get_Q_pixel_size()
+
+ @qpixsize.setter
+ def qpixsize(self, x):
+ self.set_Q_pixel_size(x)
+
+ @call_calibrate
+ def set_R_pixel_size(self, x):
+ self._params["R_pixel_size"] = x
+
+ def get_R_pixel_size(self):
+ return self._get_value("R_pixel_size")
+
+ # aliases
+ @property
+ def R_pixel_size(self):
+ return self.get_R_pixel_size()
+
+ @R_pixel_size.setter
+ def R_pixel_size(self, x):
+ self.set_R_pixel_size(x)
+
+    @property
+    def rpixsize(self):
+        return self.get_R_pixel_size()
+
+    @rpixsize.setter
+    def rpixsize(self, x):
+        self.set_R_pixel_size(x)
+
+ @call_calibrate
+ def set_Q_pixel_units(self, x):
+ assert x in (
+ "pixels",
+ "A^-1",
+ "mrad",
+ ), "Q pixel units must be 'A^-1', 'mrad' or 'pixels'."
+ self._params["Q_pixel_units"] = x
+
+ def get_Q_pixel_units(self):
+ return self._get_value("Q_pixel_units")
+
+ # aliases
+ @property
+ def Q_pixel_units(self):
+ return self.get_Q_pixel_units()
+
+ @Q_pixel_units.setter
+ def Q_pixel_units(self, x):
+ self.set_Q_pixel_units(x)
+
+ @property
+ def qpixunits(self):
+ return self.get_Q_pixel_units()
+
+ @qpixunits.setter
+ def qpixunits(self, x):
+ self.set_Q_pixel_units(x)
+
+ @call_calibrate
+ def set_R_pixel_units(self, x):
+ self._params["R_pixel_units"] = x
+
+ def get_R_pixel_units(self):
+ return self._get_value("R_pixel_units")
+
+ # aliases
+ @property
+ def R_pixel_units(self):
+ return self.get_R_pixel_units()
+
+ @R_pixel_units.setter
+ def R_pixel_units(self, x):
+ self.set_R_pixel_units(x)
+
+ @property
+ def rpixunits(self):
+ return self.get_R_pixel_units()
+
+ @rpixunits.setter
+ def rpixunits(self, x):
+ self.set_R_pixel_units(x)
+
+ # origin
+
+ # qx0,qy0
+ def set_qx0(self, x):
+ self._params["qx0"] = x
+ x = np.asarray(x)
+ qx0_mean = np.mean(x)
+ qx0_shift = x - qx0_mean
+ self._params["qx0_mean"] = qx0_mean
+ self._params["qx0_shift"] = qx0_shift
+
+ def set_qx0_mean(self, x):
+ self._params["qx0_mean"] = x
+
+ def get_qx0(self, rx=None, ry=None):
+ return self._get_value("qx0", rx, ry)
+
+ def get_qx0_mean(self):
+ return self._get_value("qx0_mean")
+
+ def get_qx0shift(self, rx=None, ry=None):
+ return self._get_value("qx0_shift", rx, ry)
+
+ def set_qy0(self, x):
+ self._params["qy0"] = x
+ x = np.asarray(x)
+ qy0_mean = np.mean(x)
+ qy0_shift = x - qy0_mean
+ self._params["qy0_mean"] = qy0_mean
+ self._params["qy0_shift"] = qy0_shift
+
+ def set_qy0_mean(self, x):
+ self._params["qy0_mean"] = x
+
+ def get_qy0(self, rx=None, ry=None):
+ return self._get_value("qy0", rx, ry)
+
+ def get_qy0_mean(self):
+ return self._get_value("qy0_mean")
+
+ def get_qy0shift(self, rx=None, ry=None):
+ return self._get_value("qy0_shift", rx, ry)
+
+ def set_qx0_meas(self, x):
+ self._params["qx0_meas"] = x
+
+ def get_qx0_meas(self, rx=None, ry=None):
+ return self._get_value("qx0_meas", rx, ry)
+
+ def set_qy0_meas(self, x):
+ self._params["qy0_meas"] = x
+
+ def get_qy0_meas(self, rx=None, ry=None):
+ return self._get_value("qy0_meas", rx, ry)
+
+ def set_origin_meas_mask(self, x):
+ self._params["origin_meas_mask"] = x
+
+ def get_origin_meas_mask(self, rx=None, ry=None):
+ return self._get_value("origin_meas_mask", rx, ry)
+
+ # aliases
+ @property
+ def qx0(self):
+ return self.get_qx0()
+
+ @qx0.setter
+ def qx0(self, x):
+ self.set_qx0(x)
+
+ @property
+ def qx0_mean(self):
+ return self.get_qx0_mean()
+
+ @qx0_mean.setter
+ def qx0_mean(self, x):
+ self.set_qx0_mean(x)
+
+ @property
+ def qx0shift(self):
+ return self.get_qx0shift()
+
+ @property
+ def qy0(self):
+ return self.get_qy0()
+
+ @qy0.setter
+ def qy0(self, x):
+ self.set_qy0(x)
+
+ @property
+ def qy0_mean(self):
+ return self.get_qy0_mean()
+
+ @qy0_mean.setter
+ def qy0_mean(self, x):
+ self.set_qy0_mean(x)
+
+    @property
+    def qy0shift(self):
+        return self.get_qy0shift()
+
+ @property
+ def qx0_meas(self):
+ return self.get_qx0_meas()
+
+ @qx0_meas.setter
+ def qx0_meas(self, x):
+ self.set_qx0_meas(x)
+
+ @property
+ def qy0_meas(self):
+ return self.get_qy0_meas()
+
+ @qy0_meas.setter
+ def qy0_meas(self, x):
+ self.set_qy0_meas(x)
+
+ @property
+ def origin_meas_mask(self):
+ return self.get_origin_meas_mask()
+
+ @origin_meas_mask.setter
+ def origin_meas_mask(self, x):
+ self.set_origin_meas_mask(x)
+
+ # origin = (qx0,qy0)
+
+ @call_calibrate
+ def set_origin(self, x):
+ """
+ Args:
+ x (2-tuple of numbers or of 2D, R-shaped arrays): the origin
+ """
+ qx0, qy0 = x
+ self.set_qx0(qx0)
+ self.set_qy0(qy0)
+
+ def get_origin(self, rx=None, ry=None):
+ qx0 = self._get_value("qx0", rx, ry)
+ qy0 = self._get_value("qy0", rx, ry)
+ ans = (qx0, qy0)
+ if any([x is None for x in ans]):
+ ans = None
+ return ans
+
+ def get_origin_mean(self):
+ qx0 = self._get_value("qx0_mean")
+ qy0 = self._get_value("qy0_mean")
+ return qx0, qy0
+
+ def get_origin_shift(self, rx=None, ry=None):
+ qx0 = self._get_value("qx0_shift", rx, ry)
+ qy0 = self._get_value("qy0_shift", rx, ry)
+ ans = (qx0, qy0)
+ if any([x is None for x in ans]):
+ ans = None
+ return ans
+
+ def set_origin_meas(self, x):
+ """
+ Args:
+            x (2-tuple or 3-tuple of 2D R-shaped arrays): qx0,qy0,[mask]
+ """
+ qx0, qy0 = x[0], x[1]
+ self.set_qx0_meas(qx0)
+ self.set_qy0_meas(qy0)
+ try:
+ m = x[2]
+ self.set_origin_meas_mask(m)
+ except IndexError:
+ pass
+
+ def get_origin_meas(self, rx=None, ry=None):
+ qx0 = self._get_value("qx0_meas", rx, ry)
+ qy0 = self._get_value("qy0_meas", rx, ry)
+ ans = (qx0, qy0)
+ if any([x is None for x in ans]):
+ ans = None
+ return ans
+
+ # aliases
+ @property
+ def origin(self):
+ return self.get_origin()
+
+ @origin.setter
+ def origin(self, x):
+ self.set_origin(x)
+
+ @property
+ def origin_meas(self):
+ return self.get_origin_meas()
+
+ @origin_meas.setter
+ def origin_meas(self, x):
+ self.set_origin_meas(x)
+
+ @property
+ def origin_shift(self):
+ return self.get_origin_shift()
+
+ # ellipse
+
+ @call_calibrate
+ def set_a(self, x):
+ self._params["a"] = x
+
+ def get_a(self, rx=None, ry=None):
+ return self._get_value("a", rx, ry)
+
+ @call_calibrate
+ def set_b(self, x):
+ self._params["b"] = x
+
+ def get_b(self, rx=None, ry=None):
+ return self._get_value("b", rx, ry)
+
+ @call_calibrate
+ def set_theta(self, x):
+ self._params["theta"] = x
+
+ def get_theta(self, rx=None, ry=None):
+ return self._get_value("theta", rx, ry)
+
+ @call_calibrate
+ def set_ellipse(self, x):
+ """
+ Args:
+ x (3-tuple): (a,b,theta)
+ """
+ a, b, theta = x
+ self._params["a"] = a
+ self._params["b"] = b
+ self._params["theta"] = theta
+
+ @call_calibrate
+ def set_p_ellipse(self, x):
+ """
+ Args:
+ x (5-tuple): (qx0,qy0,a,b,theta) NOTE: does *not* change qx0,qy0!
+ """
+ _, _, a, b, theta = x
+ self._params["a"] = a
+ self._params["b"] = b
+ self._params["theta"] = theta
+
+ def get_ellipse(self, rx=None, ry=None):
+ a = self.get_a(rx, ry)
+ b = self.get_b(rx, ry)
+ theta = self.get_theta(rx, ry)
+ ans = (a, b, theta)
+ if any([x is None for x in ans]):
+ ans = None
+ return ans
+
+ def get_p_ellipse(self, rx=None, ry=None):
+ qx0, qy0 = self.get_origin(rx, ry)
+ a, b, theta = self.get_ellipse(rx, ry)
+ return (qx0, qy0, a, b, theta)
+
+ # aliases
+ @property
+ def a(self):
+ return self.get_a()
+
+ @a.setter
+ def a(self, x):
+ self.set_a(x)
+
+ @property
+ def b(self):
+ return self.get_b()
+
+ @b.setter
+ def b(self, x):
+ self.set_b(x)
+
+ @property
+ def theta(self):
+ return self.get_theta()
+
+ @theta.setter
+ def theta(self, x):
+ self.set_theta(x)
+
+ @property
+ def p_ellipse(self):
+ return self.get_p_ellipse()
+
+ @p_ellipse.setter
+ def p_ellipse(self, x):
+ self.set_p_ellipse(x)
+
+ @property
+ def ellipse(self):
+ return self.get_ellipse()
+
+ @ellipse.setter
+ def ellipse(self, x):
+ self.set_ellipse(x)
+
+ # Q/R-space rotation and flip
+
+ @call_calibrate
+ def set_QR_rotation(self, x):
+ self._params["QR_rotation"] = x
+ self._params["QR_rotation_degrees"] = np.degrees(x)
+
+ def get_QR_rotation(self):
+ return self._get_value("QR_rotation")
+
+ @call_calibrate
+ def set_QR_rotation_degrees(self, x):
+ self._params["QR_rotation"] = np.radians(x)
+ self._params["QR_rotation_degrees"] = x
+
+ def get_QR_rotation_degrees(self):
+ return self._get_value("QR_rotation_degrees")
+
+ @call_calibrate
+ def set_QR_flip(self, x):
+ self._params["QR_flip"] = x
+
+ def get_QR_flip(self):
+ return self._get_value("QR_flip")
+
+ @call_calibrate
+ def set_QR_rotflip(self, rot_flip):
+ """
+ Args:
+ rot_flip (tuple), (rot, flip) where:
+                rot (number): rotation in radians
+ flip (bool): True indicates a Q/R axes flip
+ """
+ rot, flip = rot_flip
+ self._params["QR_rotation"] = rot
+ self._params["QR_rotation_degrees"] = np.degrees(rot)
+ self._params["QR_flip"] = flip
+
+ @call_calibrate
+ def set_QR_rotflip_degrees(self, rot_flip):
+ """
+ Args:
+ rot_flip (tuple), (rot, flip) where:
+ rot (number): rotation in degrees
+ flip (bool): True indicates a Q/R axes flip
+ """
+ rot, flip = rot_flip
+ self._params["QR_rotation"] = np.radians(rot)
+ self._params["QR_rotation_degrees"] = rot
+ self._params["QR_flip"] = flip
+
+ def get_QR_rotflip(self):
+ rot = self.get_QR_rotation()
+ flip = self.get_QR_flip()
+ if rot is None or flip is None:
+ return None
+ return (rot, flip)
+
+ def get_QR_rotflip_degrees(self):
+ rot = self.get_QR_rotation_degrees()
+ flip = self.get_QR_flip()
+ if rot is None or flip is None:
+ return None
+ return (rot, flip)
+
+ # aliases
+ @property
+ def QR_rotation_degrees(self):
+ return self.get_QR_rotation_degrees()
+
+ @QR_rotation_degrees.setter
+ def QR_rotation_degrees(self, x):
+ self.set_QR_rotation_degrees(x)
+
+ @property
+ def QR_flip(self):
+ return self.get_QR_flip()
+
+ @QR_flip.setter
+ def QR_flip(self, x):
+ self.set_QR_flip(x)
+
+ @property
+ def QR_rotflip(self):
+ return self.get_QR_rotflip()
+
+ @QR_rotflip.setter
+ def QR_rotflip(self, x):
+ self.set_QR_rotflip(x)
+
+ # probe
+
+ def set_probe_semiangle(self, x):
+ self._params["probe_semiangle"] = x
+
+ def get_probe_semiangle(self):
+ return self._get_value("probe_semiangle")
+
+ def set_probe_param(self, x):
+ """
+ Args:
+ x (3-tuple): (probe size, x0, y0)
+ """
+ probe_semiangle, qx0, qy0 = x
+ self.set_probe_semiangle(probe_semiangle)
+ self.set_qx0_mean(qx0)
+ self.set_qy0_mean(qy0)
+
+ def get_probe_param(self):
+ probe_semiangle = self._get_value("probe_semiangle")
+ qx0 = self._get_value("qx0")
+ qy0 = self._get_value("qy0")
+ ans = (probe_semiangle, qx0, qy0)
+ if any([x is None for x in ans]):
+ ans = None
+ return ans
+
+ def set_convergence_semiangle_pixels(self, x):
+ self._params["convergence_semiangle_pixels"] = x
+
+ def get_convergence_semiangle_pixels(self):
+ return self._get_value("convergence_semiangle_pixels")
+
+ def set_convergence_semiangle_mrad(self, x):
+ self._params["convergence_semiangle_mrad"] = x
+
+ def get_convergence_semiangle_mrad(self):
+ return self._get_value("convergence_semiangle_mrad")
+
+ def set_probe_center(self, x):
+ self._params["probe_center"] = x
+
+ def get_probe_center(self):
+ return self._get_value("probe_center")
+
+ # aliases
+ @property
+ def probe_semiangle(self):
+ return self.get_probe_semiangle()
+
+ @probe_semiangle.setter
+ def probe_semiangle(self, x):
+ self.set_probe_semiangle(x)
+
+ @property
+ def probe_param(self):
+ return self.get_probe_param()
+
+ @probe_param.setter
+ def probe_param(self, x):
+ self.set_probe_param(x)
+
+ @property
+ def probe_center(self):
+ return self.get_probe_center()
+
+ @probe_center.setter
+ def probe_center(self, x):
+ self.set_probe_center(x)
+
+    @property
+    def probe_convergence_semiangle_pixels(self):
+        return self.get_convergence_semiangle_pixels()
+
+    @probe_convergence_semiangle_pixels.setter
+    def probe_convergence_semiangle_pixels(self, x):
+        self.set_convergence_semiangle_pixels(x)
+
+    @property
+    def probe_convergence_semiangle_mrad(self):
+        return self.get_convergence_semiangle_mrad()
+
+    @probe_convergence_semiangle_mrad.setter
+    def probe_convergence_semiangle_mrad(self, x):
+        self.set_convergence_semiangle_mrad(x)
+
+ ######## End Calibration Metadata Params ########
+
+ # calibrate targets
+ @call_calibrate
+ def calibrate(self):
+ pass
+
+ # For parameters which can have 2D or (2+n)D array values,
+ # this function enables returning the value(s) at a 2D position,
+ # rather than the whole array
+ def _get_value(self, p, rx=None, ry=None):
+ """Enables returning the value of a pixel (rx,ry),
+ if these are passed and `p` is an appropriate array
+ """
+ v = self._params.get(p)
+
+ if v is None:
+ return v
+
+ if (rx is None) or (ry is None) or (not isinstance(v, np.ndarray)):
+ return v
+
+ else:
+ er = f"`rx` and `ry` must be ints; got values {rx} and {ry}"
+ assert np.all([isinstance(i, (int, np.integer)) for i in (rx, ry)]), er
+ return v[rx, ry]
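
A simplified standalone sketch of this scalar-vs-array behavior (the real method also type-checks rx and ry; `params` here is a stand-in for `self._params`):

```python
import numpy as np

params = {"qx0": np.arange(12.0).reshape(3, 4), "Q_pixel_size": 0.01}

def get_value(p, rx=None, ry=None):
    v = params.get(p)
    if v is None or rx is None or ry is None or not isinstance(v, np.ndarray):
        return v
    return v[rx, ry]

print(get_value("qx0").shape)            # (3, 4): the whole R-shaped array
print(get_value("qx0", 2, 1))            # 9.0: value at scan position (2, 1)
print(get_value("Q_pixel_size", 2, 1))   # 0.01: scalars ignore (rx, ry)
```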
+
+ def copy(self, name=None):
+ """ """
+ if name is None:
+ name = self.name + "_copy"
+ cal = Calibration(name=name)
+ cal._params.update(self._params)
+ return cal
+
+ # HDF5 i/o
+
+ # write is inherited from Metadata
+ def to_h5(self, group):
+ """
+        Saves the metadata dictionary _params to group, then adds the
+        calibration's targets list
+ """
+ # Add targets list to metadata
+ targets = [x._treepath for x in self.targets]
+ self["_target_paths"] = targets
+ # Save the metadata
+ Metadata.to_h5(self, group)
+ del self._params["_target_paths"]
+
+ # read
+ @classmethod
+ def from_h5(cls, group):
+ """
+ Takes a valid group for an HDF5 file object which is open in
+ read mode. Determines if it's a valid Metadata representation, and
+ if so loads and returns it as a Calibration instance. Otherwise,
+ raises an exception.
+
+ Accepts:
+ group (HDF5 group)
+
+ Returns:
+ A Calibration instance
+ """
+ # load the group as a Metadata instance
+ metadata = Metadata.from_h5(group)
+
+ # convert it to a Calibration instance
+ cal = Calibration(name=metadata.name)
+ cal._params.update(metadata._params)
+
+ # return
+ return cal
+
+
+########## End of class ##########
diff --git a/py4DSTEM/data/data.py b/py4DSTEM/data/data.py
new file mode 100644
index 000000000..ed5db2852
--- /dev/null
+++ b/py4DSTEM/data/data.py
@@ -0,0 +1,161 @@
+# Base class for all py4DSTEM data
+# which adds an EMD root and a pointer to 'calibration' metadata
+
+import warnings
+
+from emdfile import Node, Root
+from py4DSTEM.data import Calibration
+
+
+class Data:
+ """
+    The purpose of the `Data` class is to ensure calibrations are linked
+    to data-containing class instances, while allowing multiple objects
+    to share a single Calibration. The calibration of a Data instance
+    `data` is accessible as
+
+ >>> data.calibration
+
+ In py4DSTEM, Data containing objects are stored internally in filetree
+ like representations, defined by the EMD1.0 and `emdfile` specifications,
+ e.g.
+
+ Root
+ |--metadata
+ | |--calibration
+ |
+ |--some_object(e.g.datacube)
+ | |--another_object(e.g.max_dp)
+ | |--etc.
+ |
+ |--one_more_object(e.g.crystal)
+ | |--etc.
+ :
+
+ Calibrations are metadata which always live in the root of such a tree.
+    Running `data.calibration` returns the calibration from the tree root,
+    and therefore the same calibration instance is referred to by all objects
+    in the same tree. The root itself is accessible from any Data instance
+ as
+
+ >>> data.root
+
+ To examine the tree of a Data instance, in a Python interpreter do
+
+ >>> data.tree(True)
+
+ to display the whole data tree, and
+
+ >>> data.tree()
+
+    to display the tree from the current node on, i.e. the branch
+ downstream of `data`.
+
+ Calling
+
+ >>> data.calibration
+
+ will raise a warning and return None if no root calibrations are found.
+
+    Some objects should be modified when the calibrations change - these
+    objects must have a .calibrate() method, which is called any time relevant
+    calibration parameters change, provided the object has been registered
+    with the calibrations.
+
+    To transfer `data` from its current tree to another existing tree, use
+
+ >>> data.attach(some_other_data)
+
+ which will move the data to the new tree. If the data was registered with
+    its old calibrations, this will also de-register it there and register
+ it with the new calibrations such that .calibrate() is called when it
+ should be.
+
+ See also the Calibration docstring.
+ """
+
+ def __init__(self, calibration=None):
+ assert isinstance(self, Node), "Data instances must inherit from Node"
+ assert calibration is None or isinstance(
+ calibration, Calibration
+ ), f"calibration must be None or a Calibration instance, not type {type(calibration)}"
+
+ # set up calibration + EMD tree
+ if calibration is None:
+ if self.root is None:
+ root = Root(name=self.name + "_root")
+ root.tree(self)
+ self.calibration = Calibration()
+ elif "calibration" not in self.root.metadata:
+ self.calibration = Calibration()
+ else:
+ pass
+ elif calibration.root is None:
+ if self.root is None:
+ root = Root(name=self.name + "_root")
+ root.tree(self)
+ self.calibration = calibration
+ elif "calibration" not in self.root.metadata:
+ self.calibration = calibration
+ else:
+ warnings.warn(
+ "A calibration was passed to instantiate a new Data instance, but the instance already has a calibration. The passed calibration *WAS NOT* attached. To attach the new calibration and overwrite the existing calibration, use `data.calibration = new_calibration`"
+ )
+ else:
+ if self.root is None:
+ calibration.root.tree(self)
+ self.calibration = calibration
+ elif "calibration" not in self.root.metadata:
+ self.calibration = calibration
+ warnings.warn(
+ "A calibration was passed to instantiate a new Data instance. The Data already had a root but no calibration, and the calibration already exists in a different root. The calibration has been added and now lives in both roots, and can therefore be modified from either place!"
+ )
+ else:
+ warnings.warn(
+ "A calibration was passed to instantiate a new Data instance, however the Data already has a root and calibration, and the calibration already has a root!! The passed calibration *WAS NOT* attached. To attach the new calibration and overwrite the existing calibration, use `data.calibration = new_calibration."
+ )
+
+ # calibration property
+
+ @property
+ def calibration(self):
+ try:
+ return self.root.metadata["calibration"]
+ except KeyError:
+ warnings.warn("No calibration metadata found in root, returning None")
+ return None
+ except AttributeError:
+ warnings.warn("No root or root metadata found, returning None")
+ return None
+
+ @calibration.setter
+ def calibration(self, x):
+ assert isinstance(x, Calibration)
+ if "calibration" in self.root.metadata.keys():
+ warnings.warn(
+ "A 'calibration' key already exists in root.metadata - overwriting..."
+ )
+ x.name = "calibration"
+ self.root.metadata["calibration"] = x
+
+ # transfer trees
+
+ def attach(self, node):
+ """
+        Attach `node` to the current object's tree, attaching and detaching
+        calibrations as needed.
+ """
+ assert isinstance(node, Node), f"node must be a Node, not type {type(node)}"
+ register = False
+ if hasattr(node, "calibration"):
+ if node.calibration is not None:
+ if node in node.calibration._targets:
+ register = True
+ node.calibration.unregister_target(node)
+ if node.root is None:
+ self.tree(node)
+ else:
+ self.graft(node)
+ if register:
+ self.calibration.register_target(node)
diff --git a/py4DSTEM/data/diffractionslice.py b/py4DSTEM/data/diffractionslice.py
new file mode 100644
index 000000000..4a6d1b9c2
--- /dev/null
+++ b/py4DSTEM/data/diffractionslice.py
@@ -0,0 +1,53 @@
+# Defines the DiffractionSlice class, which stores 2(+1)D,
+# diffraction-shaped data
+
+from emdfile import Array
+from py4DSTEM.data import Data
+
+from typing import Optional, Union
+import numpy as np
+
+
+class DiffractionSlice(Array, Data):
+ """
+ Stores a diffraction-space shaped 2D data array.
+ """
+
+ def __init__(
+ self,
+ data: np.ndarray,
+ name: Optional[str] = "diffractionslice",
+ units: Optional[str] = "intensity",
+ slicelabels: Optional[Union[bool, list]] = None,
+ calibration=None,
+ ):
+ """
+ Accepts:
+ data (np.ndarray): the data
+ name (str): the name of the diffslice
+ units (str): units of the pixel values
+ slicelabels(None or list): names for slices if this is a 3D stack
+
+ Returns:
+ (DiffractionSlice instance)
+ """
+
+ # initialize as an Array
+ Array.__init__(self, data=data, name=name, units=units, slicelabels=slicelabels)
+ # initialize as Data
+ Data.__init__(self, calibration)
+
+ # read
+ @classmethod
+ def _get_constructor_args(cls, group):
+ """
+ Returns a dictionary of args/values to pass to the class constructor
+ """
+ ar_constr_args = Array._get_constructor_args(group)
+ args = {
+ "data": ar_constr_args["data"],
+ "name": ar_constr_args["name"],
+ "units": ar_constr_args["units"],
+ "slicelabels": ar_constr_args["slicelabels"],
+ }
+ return args
diff --git a/py4DSTEM/data/propagating_calibration.py b/py4DSTEM/data/propagating_calibration.py
new file mode 100644
index 000000000..4de0c8d96
--- /dev/null
+++ b/py4DSTEM/data/propagating_calibration.py
@@ -0,0 +1,98 @@
+# Define decorators call_* which, when used to decorate class methods,
+# cause all objects in a list `_targets` to call some method *.
+
+import warnings
+
+
+# This is the abstract pattern:
+
+
+class call_method(object):
+ """
+ A decorator which, when attached to a method of SomeClass,
+ causes `method` to be called on any objects in the
+ instance's `_targets` list, following execution of
+ the decorated function.
+ """
+
+ def __init__(self, func):
+ self.func = func
+
+ def __call__(self, *args, **kwargs):
+ """
+        Update the parameters the caller wanted by calling the wrapped
+        function, then loop through the list of targets and call their
+        `method` methods.
+ """
+ self.func(*args, **kwargs)
+ some_object = args[0]
+ assert hasattr(
+ some_object, "_targets"
+ ), "SomeObject object appears to be in an invalid state. _targets attribute is missing."
+ for target in some_object._targets:
+ if hasattr(target, "method") and callable(target.method):
+ try:
+ target.method()
+ except Exception as err:
+ print(
+ f"Attempted to call .method(), but this raised an error: {err}"
+ )
+ else:
+ # warn or pass or error out here, as needs be
+ # pass
+ warnings.warn(
+ f"{target} is registered as a target but does not appear to have a .method() callable"
+ )
+
+ def __get__(self, instance, owner):
+ """
+ This is some magic to make sure that the Calibration instance
+ on which the decorator was called gets passed through and
+ everything dispatches correctly (by making sure `instance`,
+ the Calibration instance to which the call was directed, gets
+ placed in the `self` slot of the wrapped method (which is *not*
+ actually bound to the instance due to this decoration.) using
+ partial application of the method.)
+ """
+ from functools import partial
+
+ return partial(self.__call__, instance)
+
+
+# This is a functional decorator, @call_calibrate:
+
+# calls: calibrate()
+# targets: _targets
+
+
+class call_calibrate(object):
+ """
+ Decorated methods cause all targets in _targets to call .calibrate().
+ """
+
+ def __init__(self, func):
+ self.func = func
+
+ def __call__(self, *args, **kwargs):
+ """ """
+ self.func(*args, **kwargs)
+ calibration = args[0]
+ assert hasattr(
+ calibration, "_targets"
+ ), "Calibration object appears to be in an invalid state. _targets attribute is missing."
+ for target in calibration._targets:
+ if hasattr(target, "calibrate") and callable(target.calibrate):
+ try:
+ target.calibrate()
+ except Exception as err:
+ print(
+ f"Attempted to calibrate object {target} but this raised an error: {err}"
+ )
+ else:
+ pass
+
+ def __get__(self, instance, owner):
+ """ """
+ from functools import partial
+
+ return partial(self.__call__, instance)
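
To make the dispatch concrete, here is a toy demonstration of the pattern, assuming `call_calibrate` from this module is in scope; the class names are hypothetical:

```python
# Toy classes exercising call_calibrate; names are illustrative only
class Target:
    def calibrate(self):
        print("recalibrated")

class Cal:
    def __init__(self):
        self._targets = [Target()]

    @call_calibrate
    def set_pixel_size(self, x):
        self.pixel_size = x

c = Cal()
c.set_pixel_size(0.1)  # stores the value, then prints "recalibrated"
```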
diff --git a/py4DSTEM/data/qpoints.py b/py4DSTEM/data/qpoints.py
new file mode 100644
index 000000000..8eabd3eb4
--- /dev/null
+++ b/py4DSTEM/data/qpoints.py
@@ -0,0 +1,70 @@
+# Defines the QPoints class, which stores PointLists with fields 'qx','qy','intensity'
+
+from emdfile import PointList
+from py4DSTEM.data import Data
+
+from typing import Optional
+import numpy as np
+
+
+class QPoints(PointList, Data):
+ """
+ Stores a set of diffraction space points,
+ with fields 'qx', 'qy' and 'intensity'
+ """
+
+ def __init__(
+ self,
+ data: np.ndarray,
+ name: Optional[str] = "qpoints",
+ ):
+ """
+ Accepts:
+ data (structured numpy ndarray): should have three fields, which
+ will be renamed 'qx','qy','intensity'
+ name (str): the name of the QPoints instance
+
+ Returns:
+ A new QPoints instance
+ """
+
+ # initialize as a PointList
+ PointList.__init__(
+ self,
+ data=data,
+ name=name,
+ )
+
+ # rename fields
+ self.fields = "qx", "qy", "intensity"
+
+ # properties
+
+ @property
+ def qx(self):
+ return self.data["qx"]
+
+ @property
+ def qy(self):
+ return self.data["qy"]
+
+ @property
+ def intensity(self):
+ return self.data["intensity"]
+
+ # aliases
+ I = intensity
+
+ # read
+ # this method is not necessary but is kept for consistency of structure!
+ @classmethod
+ def _get_constructor_args(cls, group):
+ """
+ Returns a dictionary of args/values to pass to the class constructor
+ """
+ pl_constr_args = PointList._get_constructor_args(group)
+ args = {
+ "data": pl_constr_args["data"],
+ "name": pl_constr_args["name"],
+ }
+ return args
diff --git a/py4DSTEM/data/realslice.py b/py4DSTEM/data/realslice.py
new file mode 100644
index 000000000..2c834df4d
--- /dev/null
+++ b/py4DSTEM/data/realslice.py
@@ -0,0 +1,53 @@
+# Defines the RealSlice class, which stores 2(+1)D real-space shaped data
+
+from emdfile import Array
+from py4DSTEM.data import Data
+
+from typing import Optional, Union
+import numpy as np
+
+
+class RealSlice(Array, Data):
+ """
+ Stores a real-space shaped 2D data array.
+ """
+
+ def __init__(
+ self,
+ data: np.ndarray,
+ name: Optional[str] = "realslice",
+ units: Optional[str] = "intensity",
+ slicelabels: Optional[Union[bool, list]] = None,
+ calibration=None,
+ ):
+ """
+ Accepts:
+            data (np.ndarray): the data
+            name (str): the name of the realslice
+            units (str): units of the pixel values
+            slicelabels (None or list): names for slices if this is a stack of
+                realslices
+
+ Returns:
+ A new RealSlice instance
+ """
+        # initialize as an Array
+        Array.__init__(
+            self, data=data, name=name, units=units, slicelabels=slicelabels
+        )
+ # initialize as Data
+ Data.__init__(self, calibration)
+
+ # read
+ @classmethod
+ def _get_constructor_args(cls, group):
+ """
+ Returns a dictionary of args/values to pass to the class constructor
+ """
+ ar_constr_args = Array._get_constructor_args(group)
+ args = {
+ "data": ar_constr_args["data"],
+ "name": ar_constr_args["name"],
+ "units": ar_constr_args["units"],
+ "slicelabels": ar_constr_args["slicelabels"],
+ }
+ return args
diff --git a/py4DSTEM/datacube/__init__.py b/py4DSTEM/datacube/__init__.py
new file mode 100644
index 000000000..883961fcb
--- /dev/null
+++ b/py4DSTEM/datacube/__init__.py
@@ -0,0 +1,5 @@
+_emd_hook = True
+
+from py4DSTEM.datacube.datacube import DataCube
+from py4DSTEM.datacube.virtualimage import VirtualImage
+from py4DSTEM.datacube.virtualdiffraction import VirtualDiffraction
diff --git a/py4DSTEM/datacube/datacube.py b/py4DSTEM/datacube/datacube.py
new file mode 100644
index 000000000..4d87afdd5
--- /dev/null
+++ b/py4DSTEM/datacube/datacube.py
@@ -0,0 +1,1328 @@
+# Defines the DataCube class, which stores 4D-STEM datacubes
+
+import numpy as np
+from scipy.interpolate import interp1d
+from scipy.ndimage import (
+ binary_opening,
+ binary_dilation,
+ distance_transform_edt,
+ binary_fill_holes,
+ gaussian_filter1d,
+ gaussian_filter,
+)
+from typing import Optional, Union
+
+from emdfile import Array, Metadata, Node, Root, tqdmnd
+from py4DSTEM.data import Data, Calibration
+from py4DSTEM.datacube.virtualimage import DataCubeVirtualImager
+from py4DSTEM.datacube.virtualdiffraction import DataCubeVirtualDiffraction
+
+
+class DataCube(
+ Array,
+ Data,
+ DataCubeVirtualImager,
+ DataCubeVirtualDiffraction,
+):
+ """
+ Storage and processing methods for 4D-STEM datasets.
+ """
+
+ def __init__(
+ self,
+ data: np.ndarray,
+ name: Optional[str] = "datacube",
+ slicelabels: Optional[Union[bool, list]] = None,
+        calibration: Optional[Calibration] = None,
+ ):
+ """
+ Accepts:
+ data (np.ndarray): the data
+ name (str): the name of the datacube
+            calibration (None or Calibration): if None (the default),
+                creates and attaches a new Calibration instance to the root
+                metadata; passing a Calibration instance uses that instead.
+ slicelabels (None or list): names for slices if this is a
+ stack of datacubes
+
+ Returns:
+ A new DataCube instance.
+ """
+ # initialize as an Array
+ Array.__init__(
+ self,
+ data=data,
+ name=name,
+ units="pixel intensity",
+ dim_names=["Rx", "Ry", "Qx", "Qy"],
+ slicelabels=slicelabels,
+ )
+
+ # initialize as Data
+ Data.__init__(self, calibration)
+
+ # register with calibration
+ self.calibration.register_target(self)
+
+ # cartesian coords
+ self.calibrate()
+
+ # polar coords
+ self.polar = None
+
+ def calibrate(self):
+ """
+ Calibrate the coordinate axes of the datacube. Using the calibrations
+ at self.calibration, sets the 4 dim vectors (Qx,Qy,Rx,Ry) according
+ to the pixel size, units and origin positions, then updates the
+ meshgrids representing Q and R space.
+ """
+ assert self.calibration is not None, "No calibration found!"
+
+ # Get calibration values
+ rpixsize = self.calibration.get_R_pixel_size()
+ rpixunits = self.calibration.get_R_pixel_units()
+ qpixsize = self.calibration.get_Q_pixel_size()
+ qpixunits = self.calibration.get_Q_pixel_units()
+ origin = self.calibration.get_origin_mean()
+ if origin is None or origin == (None, None):
+ origin = (0, 0)
+
+ # Calc dim vectors
+ dim_rx = np.arange(self.R_Nx) * rpixsize
+ dim_ry = np.arange(self.R_Ny) * rpixsize
+ dim_qx = -origin[0] + np.arange(self.Q_Nx) * qpixsize
+ dim_qy = -origin[1] + np.arange(self.Q_Ny) * qpixsize
+
+ # Set dim vectors
+ self.set_dim(0, dim_rx, units=rpixunits)
+ self.set_dim(1, dim_ry, units=rpixunits)
+ self.set_dim(2, dim_qx, units=qpixunits)
+ self.set_dim(3, dim_qy, units=qpixunits)
+
+ # Set meshgrids
+ self._qxx, self._qyy = np.meshgrid(dim_qx, dim_qy)
+ self._rxx, self._ryy = np.meshgrid(dim_rx, dim_ry)
+
+ self._qyy_raw, self._qxx_raw = np.meshgrid(
+ np.arange(self.Q_Ny), np.arange(self.Q_Nx)
+ )
+ self._ryy_raw, self._rxx_raw = np.meshgrid(
+ np.arange(self.R_Ny), np.arange(self.R_Nx)
+ )
+
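A worked example of the dim-vector arithmetic above (standalone sketch; it assumes the stored origin is expressed in the same units as the dim vector):

```python
import numpy as np

# With a 256-pixel Qx axis, 0.02 units/pixel, and the origin at pixel 128
# (i.e. 2.56 in calibrated units), the axis is centered on the origin:
Q_Nx, qpixsize = 256, 0.02
origin_qx = 128 * qpixsize
dim_qx = -origin_qx + np.arange(Q_Nx) * qpixsize
print(dim_qx[0], dim_qx[128], dim_qx[-1])  # -2.56, 0.0, ~2.54
```
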
+ # coordinate meshgrids
+ @property
+ def rxx(self):
+ return self._rxx
+
+ @property
+ def ryy(self):
+ return self._ryy
+
+ @property
+ def qxx(self):
+ return self._qxx
+
+ @property
+ def qyy(self):
+ return self._qyy
+
+ @property
+ def rxx_raw(self):
+ return self._rxx_raw
+
+ @property
+ def ryy_raw(self):
+ return self._ryy_raw
+
+ @property
+ def qxx_raw(self):
+ return self._qxx_raw
+
+ @property
+ def qyy_raw(self):
+ return self._qyy_raw
+
+ # coordinate meshgrids with shifted origin
+ def qxxs(self, rx, ry):
+ qx0_shift = self.calibration.get_qx0shift(rx, ry)
+ if qx0_shift is None:
+ raise Exception(
+ "Can't compute shifted meshgrid - origin shift is not defined"
+ )
+ return self.qxx - qx0_shift
+
+ def qyys(self, rx, ry):
+ qy0_shift = self.calibration.get_qy0shift(rx, ry)
+ if qy0_shift is None:
+ raise Exception(
+ "Can't compute shifted meshgrid - origin shift is not defined"
+ )
+ return self.qyy - qy0_shift
+
+ # shape properties
+
+ ## shape
+
+ # FOV
+ @property
+ def R_Nx(self):
+ return self.data.shape[0]
+
+ @property
+ def R_Ny(self):
+ return self.data.shape[1]
+
+ @property
+ def Q_Nx(self):
+ return self.data.shape[2]
+
+ @property
+ def Q_Ny(self):
+ return self.data.shape[3]
+
+ @property
+ def Rshape(self):
+ return (self.data.shape[0], self.data.shape[1])
+
+ @property
+ def Qshape(self):
+ return (self.data.shape[2], self.data.shape[3])
+
+ @property
+ def R_N(self):
+ return self.R_Nx * self.R_Ny
+
+ # aliases
+ qnx = Q_Nx
+ qny = Q_Ny
+ rnx = R_Nx
+ rny = R_Ny
+ rshape = Rshape
+ qshape = Qshape
+ rn = R_N
+
+ ## pixel size / units
+
+ # Q
+ @property
+ def Q_pixel_size(self):
+ return self.calibration.get_Q_pixel_size()
+
+ @property
+ def Q_pixel_units(self):
+ return self.calibration.get_Q_pixel_units()
+
+ # R
+ @property
+ def R_pixel_size(self):
+ return self.calibration.get_R_pixel_size()
+
+ @property
+ def R_pixel_units(self):
+ return self.calibration.get_R_pixel_units()
+
+ # aliases
+ qpixsize = Q_pixel_size
+ qpixunit = Q_pixel_units
+ rpixsize = R_pixel_size
+ rpixunit = R_pixel_units
+
+ def copy(self):
+ """
+        Copies the datacube.
+ """
+ from py4DSTEM import DataCube
+
+ new_datacube = DataCube(
+ data=self.data.copy(),
+ name=self.name,
+ calibration=self.calibration.copy(),
+ slicelabels=self.slicelabels,
+ )
+
+ Qpixsize = new_datacube.calibration.get_Q_pixel_size()
+ Qpixunits = new_datacube.calibration.get_Q_pixel_units()
+ Rpixsize = new_datacube.calibration.get_R_pixel_size()
+ Rpixunits = new_datacube.calibration.get_R_pixel_units()
+
+ new_datacube.set_dim(0, [0, Rpixsize], units=Rpixunits, name="Rx")
+ new_datacube.set_dim(1, [0, Rpixsize], units=Rpixunits, name="Ry")
+
+ new_datacube.set_dim(2, [0, Qpixsize], units=Qpixunits, name="Qx")
+ new_datacube.set_dim(3, [0, Qpixsize], units=Qpixunits, name="Qy")
+
+ return new_datacube
+
+ # I/O
+
+ # to_h5 is inherited from Array
+
+ # read
+ @classmethod
+ def _get_constructor_args(cls, group):
+ """Construct a datacube with no calibration / metadata"""
+ # We only need some of the Array constructors;
+ # dim vector/units are passed through when Calibration
+ # is loaded, and the runtime dim vectors are then set
+ # in _add_root_links
+ ar_args = Array._get_constructor_args(group)
+
+ args = {
+ "data": ar_args["data"],
+ "name": ar_args["name"],
+ "slicelabels": ar_args["slicelabels"],
+ "calibration": None,
+ }
+
+ return args
+
+ def _add_root_links(self, group):
+ """When reading from file, link to calibration metadata,
+ then use it to populate the datacube dim vectors
+ """
+ # Link to the datacube
+ self.calibration._datacube = self
+
+ # Populate dim vectors
+ self.calibration.set_Q_pixel_size(self.calibration.get_Q_pixel_size())
+ self.calibration.set_R_pixel_size(self.calibration.get_R_pixel_size())
+ self.calibration.set_Q_pixel_units(self.calibration.get_Q_pixel_units())
+ self.calibration.set_R_pixel_units(self.calibration.get_R_pixel_units())
+
+ return
+
+ # Class methods
+
+ def add(self, data, name=""):
+ """
+ Adds a block of data to the DataCube's tree. If `data` is an instance of
+ an EMD/py4DSTEM class, add it to the tree. If it's a numpy array,
+ turn it into an Array instance, then save to the tree.
+ """
+ if isinstance(data, np.ndarray):
+ data = Array(data=data, name=name)
+ self.attach(data)
+
+ def set_scan_shape(self, Rshape):
+ """
+ Reshape the data given the real space scan shape.
+
+ Accepts:
+ Rshape (2-tuple)
+ """
+ from py4DSTEM.preprocess import set_scan_shape
+
+ assert len(Rshape) == 2, "Rshape must have a length of 2"
+ d = set_scan_shape(self, Rshape[0], Rshape[1])
+ return d
+
+ def swap_RQ(self):
+ """
+ Swaps the first and last two dimensions of the 4D datacube.
+ """
+ from py4DSTEM.preprocess import swap_RQ
+
+ d = swap_RQ(self)
+ return d
+
+ def swap_Rxy(self):
+ """
+ Swaps the real space x and y coordinates.
+ """
+ from py4DSTEM.preprocess import swap_Rxy
+
+ d = swap_Rxy(self)
+ return d
+
+ def swap_Qxy(self):
+ """
+ Swaps the diffraction space x and y coordinates.
+ """
+ from py4DSTEM.preprocess import swap_Qxy
+
+ d = swap_Qxy(self)
+ return d
+
+ def crop_Q(self, ROI):
+ """
+ Crops the data in diffraction space about the region specified by ROI.
+
+ Accepts:
+ ROI (4-tuple): Specifies (Qx_min,Qx_max,Qy_min,Qy_max)
+ """
+ from py4DSTEM.preprocess import crop_data_diffraction
+
+ assert len(ROI) == 4, "Crop region `ROI` must have length 4"
+ d = crop_data_diffraction(self, ROI[0], ROI[1], ROI[2], ROI[3])
+ return d
+
+ def crop_R(self, ROI):
+ """
+ Crops the data in real space about the region specified by ROI.
+
+ Accepts:
+ ROI (4-tuple): Specifies (Rx_min,Rx_max,Ry_min,Ry_max)
+ """
+ from py4DSTEM.preprocess import crop_data_real
+
+ assert len(ROI) == 4, "Crop region `ROI` must have length 4"
+ d = crop_data_real(self, ROI[0], ROI[1], ROI[2], ROI[3])
+ return d
+
+ def bin_Q(self, N, dtype=None):
+ """
+ Bins the data in diffraction space by bin factor N
+
+ Parameters
+ ----------
+ N : int
+ The binning factor
+ dtype : a datatype (optional)
+ Specify the datatype for the output. If not passed, the datatype
+ is left unchanged
+
+ Returns
+ -------
+ datacube : DataCube
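+
+ Examples
+ --------
+ A minimal usage sketch (hypothetical bin factor; assumes `datacube`
+ is a populated DataCube):
+
+ >>> datacube = datacube.bin_Q(N=2)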
+ """
+ from py4DSTEM.preprocess import bin_data_diffraction
+
+ d = bin_data_diffraction(self, N, dtype)
+ return d
+
+ def pad_Q(self, N=None, output_size=None):
+ """
+ Pads the data in diffraction space by pad factor N, or to match output_size.
+
+ Accepts:
+ N (float, or Sequence[float]): the padding factor
+ output_size ((int,int)): the padded output size
+ """
+ from py4DSTEM.preprocess import pad_data_diffraction
+
+ d = pad_data_diffraction(self, pad_factor=N, output_size=output_size)
+ return d
+
+ def resample_Q(self, N=None, output_size=None, method="bilinear"):
+ """
+ Resamples the data in diffraction space by resampling factor N, or to match output_size,
+ using either 'fourier' or 'bilinear' interpolation.
+
+ Accepts:
+ N (float, or Sequence[float]): the resampling factor
+ output_size ((int,int)): the resampled output size
+ method (str): 'fourier' or 'bilinear' (default)
+ """
+ from py4DSTEM.preprocess import resample_data_diffraction
+
+ d = resample_data_diffraction(
+ self, resampling_factor=N, output_size=output_size, method=method
+ )
+ return d
+
+ def bin_Q_mmap(self, N, dtype=np.float32):
+ """
+ Bins the data in diffraction space by bin factor N for memory mapped data
+
+ Accepts:
+ N (int): the binning factor
+ dtype: the data type
+ """
+ from py4DSTEM.preprocess import bin_data_mmap
+
+ d = bin_data_mmap(self, N, dtype=dtype)
+ return d
+
+ def bin_R(self, N):
+ """
+ Bins the data in real space by bin factor N
+
+ Accepts:
+ N (int): the binning factor
+ """
+ from py4DSTEM.preprocess import bin_data_real
+
+ d = bin_data_real(self, N)
+ return d
+
+ def thin_R(self, N):
+ """
+ Reduces the data in real space by keeping only every Nth diffraction pattern along the x and y directions.
+
+ Accepts:
+ N (int): the thinning factor
+ """
+ from py4DSTEM.preprocess import thin_data_real
+
+ d = thin_data_real(self, N)
+ return d
+
+ def filter_hot_pixels(self, thresh, ind_compare=1, return_mask=False):
+ """
+ Performs pixel filtering to remove hot / bright pixels. A moving local
+ ordering filter is first applied to the mean diffraction image. This
+ filter returns a single value from the locally sorted intensity values,
+ selected by `ind_compare`: ind_compare=0 returns the highest local
+ intensity, =1 the second highest, etc. Next, a mask is generated for all
+ pixels which are at least `thresh` counts higher than the local ordering
+ filter output. Finally, we loop through all diffraction images, and any
+ pixels covered by the mask are replaced by their 3x3 local median.
+
+ Args:
+ thresh (float): threshold for replacing hot pixels; a pixel is
+ replaced if its value minus the local ordering filter output
+ exceeds this value.
+ ind_compare (int): which sorted value to compare against. 0 = brightest
+ pixel, 1 = next brightest, etc.
+ return_mask (bool): if True, also returns the filter mask
+
+ Returns:
+ datacube (DataCube)
+ mask (optional, boolean Array) the bad pixel mask
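+
+ Example:
+ A minimal usage sketch (hypothetical threshold; assumes `datacube`
+ is a populated DataCube):
+
+ >>> datacube, mask = datacube.filter_hot_pixels(
+ ... thresh=4, ind_compare=1, return_mask=True
+ ... )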
+ """
+ from py4DSTEM.preprocess import filter_hot_pixels
+
+ d = filter_hot_pixels(
+ self,
+ thresh,
+ ind_compare,
+ return_mask,
+ )
+ return d
+
+ # Probe
+
+ def get_vacuum_probe(
+ self,
+ ROI=None,
+ align=True,
+ mask=None,
+ threshold=0.2,
+ expansion=12,
+ opening=3,
+ verbose=False,
+ returncalc=True,
+ ):
+ """
+ Computes a vacuum probe.
+
+ Which diffraction patterns are included in the calculation is specified
+ by the `ROI` parameter. Diffraction patterns are aligned before averaging
+ if `align` is True (default). A global mask is applied to each diffraction
+ pattern before aligning/averaging if `mask` is specified. After averaging,
+ a final masking step is applied according to the parameters `threshold`,
+ `expansion`, and `opening`.
+
+ Parameters
+ ----------
+ ROI : optional, boolean array or len 4 list/tuple
+ If unspecified, uses the whole datacube. If a boolean array is
+ passed, it must be real-space shaped, and True pixels are used. If a
+ 4-tuple is passed, uses the region inside the limits
+ (rx_min,rx_max,ry_min,ry_max)
+ align : optional, bool
+ if True, aligns the probes before averaging
+ mask : optional, array
+ mask applied to each diffraction pattern before alignment and
+ averaging
+ threshold : float
+ in the final masking step, values less than max(probe)*threshold
+ are considered outside the probe
+ expansion : int
+ number of pixels by which the final mask is expanded after
+ thresholding
+ opening : int
+ size of binary opening applied to the final mask to eliminate stray
+ bright pixels
+ verbose : bool
+ toggles verbose output
+ returncalc : bool
+ if True, returns the answer
+
+ Returns
+ -------
+ probe : Probe, optional
+ the vacuum probe
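+
+ Examples
+ --------
+ A minimal usage sketch (hypothetical ROI; assumes the scan region
+ (0:5, 0:5) contains vacuum):
+
+ >>> probe = datacube.get_vacuum_probe(ROI=(0, 5, 0, 5))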
+ """
+ from py4DSTEM.process.utils import get_shifted_ar, get_shift
+ from py4DSTEM.braggvectors import Probe
+
+ # parse region to use
+ if ROI is None:
+ ROI = np.ones(self.Rshape, dtype=bool)
+ elif isinstance(ROI, tuple):
+ assert len(ROI) == 4, "if ROI is a tuple must be length 4"
+ _ROI = np.zeros(self.Rshape, dtype=bool)
+ _ROI[ROI[0] : ROI[1], ROI[2] : ROI[3]] = True
+ ROI = _ROI
+ else:
+ assert isinstance(ROI, np.ndarray)
+ assert ROI.shape == self.Rshape
+ xy = np.vstack(np.nonzero(ROI))
+ length = xy.shape[1]
+
+ # setup global mask
+ if mask is None:
+ mask = 1
+ else:
+ assert mask.shape == self.Qshape
+
+ # compute average probe
+ probe = self.data[xy[0, 0], xy[1, 0], :, :] * mask
+ for n in tqdmnd(range(1, length)):
+ curr_DP = self.data[xy[0, n], xy[1, n], :, :] * mask
+ if align:
+ xshift, yshift = get_shift(probe, curr_DP)
+ curr_DP = get_shifted_ar(curr_DP, xshift, yshift)
+ # running mean of the n+1 patterns seen so far
+ probe = probe * n / (n + 1) + curr_DP / (n + 1)
+
+ # mask
+ mask = probe > np.max(probe) * threshold
+ mask = binary_opening(mask, iterations=opening)
+ mask = binary_dilation(mask, iterations=1)
+ mask = (
+ np.cos(
+ (np.pi / 2)
+ * np.minimum(
+ distance_transform_edt(np.logical_not(mask)) / expansion, 1
+ )
+ )
+ ** 2
+ )
+ probe *= mask
+
+ # make a probe, add to tree, and return
+ probe = Probe(probe)
+ self.attach(probe)
+ if returncalc:
+ return probe
+
+ def get_probe_size(
+ self,
+ dp=None,
+ thresh_lower=0.01,
+ thresh_upper=0.99,
+ N=100,
+ plot=False,
+ returncal=True,
+ write_to_cal=True,
+ **kwargs,
+ ):
+ """
+ Gets the center and radius of the probe in the diffraction plane.
+
+ The algorithm is as follows:
+ First, create a series of N binary masks, by thresholding the diffraction
+ pattern DP with a linspace of N thresholds from thresh_lower to
+ thresh_upper, measured relative to the maximum intensity in DP.
+ Using the area of each binary mask, calculate the radius r of a circular
+ probe. Because the central disk is typically very intense relative to
+ the rest of the DP, r should change very little over a wide range of
+ intermediate values of the threshold. The range in which r is trustworthy
+ is found by taking the derivative of r(thresh) and identifying where
+ it is small. The radius is taken to be the mean of these r values.
+ Using the threshold corresponding to this r, a mask is created, and
+ the CoM of the DP times this mask is computed. This is taken to be
+ the origin x0,y0.
+
+ Args:
+ dp (str or array): specifies the diffraction pattern in which to
+ find the central disk. A position averaged, or shift-corrected
+ and averaged, DP works best. If mode is None, the diffraction
+ pattern stored in the tree from 'get_dp_mean' is used. If mode
+ is a string it specifies the name of another virtual diffraction
+ pattern in the tree. If mode is an array, the array is used to
+ calculate probe size.
+ thresh_lower (float, 0 to 1): the lower limit of threshold values
+ thresh_upper (float, 0 to 1): the upper limit of threshold values
+ N (int): the number of thresholds / masks to use
+ plot (bool): if True plots results; additional keyword arguments
+ are passed on to the plotting function
+ returncal (bool): if True returns the 3-tuple described below
+ write_to_cal (bool): if True, looks for a Calibration instance
+ and writes the measured probe radius there
+
+ Returns:
+ (3-tuple): A 3-tuple containing:
+
+ * **r**: *(float)* the central disk radius, in pixels
+ * **x0**: *(float)* the x position of the central disk center
+ * **y0**: *(float)* the y position of the central disk center
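+
+ Example:
+ A minimal usage sketch (assumes `datacube.get_dp_mean()` has been
+ run, so the stored mean pattern is used):
+
+ >>> r, x0, y0 = datacube.get_probe_size(plot=True)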
+ """
+ # perform computation
+ from py4DSTEM.process.calibration import get_probe_size
+
+ if dp is None:
+ assert (
+ "dp_mean" in self.treekeys
+ ), "calculate .get_dp_mean() or pass a `dp` arg"
+ DP = self.tree("dp_mean").data
+ elif isinstance(dp, str):
+ assert dp in self.treekeys, f"mode {dp} not found in the tree"
+ DP = self.tree(dp).data
+ elif isinstance(dp, np.ndarray):
+ assert dp.shape == self.Qshape, "must be a diffraction space shape 2D array"
+ DP = dp
+ else:
+ raise Exception(f"invalid type for `dp`: {type(dp)}")
+
+ x = get_probe_size(
+ DP,
+ thresh_lower=thresh_lower,
+ thresh_upper=thresh_upper,
+ N=N,
+ )
+
+ # try to add to calibration
+ if write_to_cal:
+ try:
+ self.calibration.set_probe_param(x)
+ except AttributeError:
+ raise Exception(
+ "writing to calibration was requested, but could not be completed"
+ )
+
+ # plot results
+ if plot:
+ from py4DSTEM.visualize import show_circles
+
+ show_circles(DP, (x[1], x[2]), x[0], vmin=0, vmax=1, **kwargs)
+
+ # return
+ if returncal:
+ return x
+
+ # Bragg disks
+
+ def find_Bragg_disks(
+ self,
+ template,
+ data=None,
+ radial_bksb=False,
+ filter_function=None,
+ corrPower=1,
+ sigma=None,
+ sigma_dp=0,
+ sigma_cc=2,
+ subpixel="multicorr",
+ upsample_factor=16,
+ minAbsoluteIntensity=0,
+ minRelativeIntensity=0.005,
+ relativeToPeak=0,
+ minPeakSpacing=60,
+ edgeBoundary=20,
+ maxNumPeaks=70,
+ CUDA=False,
+ CUDA_batched=True,
+ distributed=None,
+ ML=False,
+ ml_model_path=None,
+ ml_num_attempts=1,
+ ml_batch_size=8,
+ name="braggvectors",
+ returncalc=True,
+ ):
+ """
+ Finds the Bragg disks in the diffraction patterns represented by `data` by
+ cross/phase correlation with `template`.
+
+ Behavior depends on `data`. If it is None (default), runs on the whole DataCube,
+ and stores the output in its tree. Otherwise, nothing is stored in tree,
+ but some value is returned. Valid entries are:
+
+ - a 2-tuple of numbers (rx,ry): run on this diffraction image,
+ and return a QPoints instance
+ - a 2-tuple of arrays (rx,ry): run on these diffraction images,
+ and return a list of QPoints instances
+ - an R-space shaped 2D boolean array: run on the diffraction images
+ specified by the True pixels and return a list of QPoints
+ instances
+
+ For disk detection on a full DataCube, the calculation can be performed
+ on the CPU, GPU or a cluster. By default the CPU is used. If `CUDA` is set
+ to True, tries to use the GPU. If `CUDA_batched` is also set to True,
+ batches the FFT/IFFT computations on the GPU. For distribution to a cluster,
+ distributed must be set to a dictionary, with contents describing how
+ distributed processing should be performed - see below for details.
+
+
+ For each diffraction pattern, the algorithm works in 4 steps:
+
+ (1) any pre-processing is performed to the diffraction image. This is
+ accomplished by passing a callable function to the argument
+ `filter_function`, a bool to the argument `radial_bksb`, or a value >0
+ to `sigma_dp`. If none of these are passed, this step is skipped.
+ (2) the diffraction image is cross correlated with the template.
+ Phase/hybrid correlations can be used instead by setting the
+ `corrPower` argument. Cross correlation can be skipped entirely,
+ and the subsequent steps performed directly on the diffraction
+ image instead of the cross correlation, by passing None to
+ `template`.
+ (3) the maxima of the cross correlation are located and their
+ positions and intensities stored. The cross correlation may be
+ passed through a gaussian filter first by passing the `sigma_cc`
+ argument. The method for maximum detection can be set with
+ the `subpixel` parameter. Options, from fastest/least precise
+ to slowest/most precise, are 'none', 'poly', and 'multicorr'.
+ (4) filtering is applied to remove untrusted or undesired peaks,
+ based on their intensity (`minRelativeIntensity`,`relativeToPeak`,
+ `minAbsoluteIntensity`) their proximity to one another or the
+ image edge (`minPeakSpacing`, `edgeBoundary`), and the total
+ number of peaks per pattern (`maxNumPeaks`).
+
+
+ Parameters
+ ----------
+ template : 2D array
+ the vacuum probe template, in real space. For Probe instances,
+ this is `probe.kernel`. If None, does not perform a cross
+ correlation.
+ data : variable
+ see above
+ radial_bksb : bool
+ if True, computes a radial background given by the median of the
+ (circular) polar transform of each each diffraction pattern, and
+ subtracts this background from the pattern before applying any
+ filter function and computing the cross correlation. The origin
+ position must be set in the datacube's calibrations. Currently
+ only supported for full datacubes on the CPU.
+ filter_function : callable
+ filtering function to apply to each diffraction pattern before
+ peak finding. Must be a function of only one argument (the
+ diffraction pattern) and return the filtered diffraction pattern.
+ The shape of the returned DP must match the shape of the probe
+ kernel (but does not need to match the shape of the input
+ diffraction pattern, e.g. the filter can be used to bin the
+ diffraction pattern). If using distributed disk detection, the
+ function must be able to be pickled by dill.
+ corrPower : float between 0 and 1, inclusive
+ the cross correlation power. A value of 1 corresponds to a cross
+ correlation, 0 corresponds to a phase correlation, and intermediate
+ values correspond to hybrid correlations.
+ sigma : float
+ alias for `sigma_cc`
+ sigma_dp : float
+ if >0, a gaussian smoothing filter with this standard deviation
+ is applied to the diffraction pattern before maxima are detected
+ sigma_cc : float
+ if >0, a gaussian smoothing filter with this standard deviation
+ is applied to the cross correlation before maxima are detected
+ subpixel : str
+ Whether to use subpixel fitting, and which algorithm to use.
+ Must be in ('none','poly','multicorr').
+ * 'none': performs no subpixel fitting
+ * 'poly': polynomial interpolation of correlogram peaks
+ * 'multicorr': uses the multicorr algorithm with DFT upsampling
+ (default)
+ upsample_factor : int
+ upsampling factor for subpixel fitting (only used when
+ subpixel='multicorr')
+ minAbsoluteIntensity : float
+ the minimum acceptable correlation peak intensity, on an absolute scale
+ minRelativeIntensity : float
+ the minimum acceptable correlation peak intensity, relative to the
+ intensity of the brightest peak
+ relativeToPeak : int
+ specifies the peak against which the minimum relative intensity is
+ measured -- 0=brightest maximum. 1=next brightest, etc.
+ minPeakSpacing : float
+ the minimum acceptable spacing between detected peaks
+ edgeBoundary : int
+ minimum acceptable distance for detected peaks from the
+ diffraction image edge, in pixels.
+ maxNumPeaks : int
+ the maximum number of peaks to return
+ CUDA : bool
+ If True, import cupy and use an NVIDIA GPU to perform disk detection
+ CUDA_batched : bool
+ If True, and CUDA is selected, the FFT and IFFT steps of disk detection
+ are performed in batches to better utilize GPU resources.
+ distributed : dict
+ contains information for parallel processing using an IPyParallel or
+ Dask distributed cluster. Valid keys are:
+ * ipyparallel (dict):
+ * client_file (str): path to client json for connecting to your
+ existing IPyParallel cluster
+ * dask (dict):
+ * client (object): a dask client that connects to your
+ existing Dask cluster
+ * data_file (str): the absolute path to your original data
+ file containing the datacube
+ * cluster_path (str): defaults to the working directory during
+ processing
+ if distributed is None, which is the default, processing will be in
+ serial
+ name : str
+ name for the output BraggVectors
+ returncalc : bool
+ if True, returns the answer
+
+ Returns
+ -------
+ variable
+ See above.
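+
+ Examples
+ --------
+ A minimal usage sketch (hypothetical parameter values; assumes a
+ vacuum probe `probe` from get_vacuum_probe, with its cross
+ correlation kernel at `probe.kernel`):
+
+ >>> braggvectors = datacube.find_Bragg_disks(
+ ... template=probe.kernel,
+ ... minRelativeIntensity=0.005,
+ ... maxNumPeaks=70,
+ ... )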
+ """
+ from py4DSTEM.braggvectors import find_Bragg_disks
+
+ sigma_cc = sigma if sigma is not None else sigma_cc
+
+ # parse args
+ if data is None:
+ x = self
+ elif isinstance(data, tuple):
+ x = self, data[0], data[1]
+ elif isinstance(data, np.ndarray):
+ assert data.dtype == bool, "array must be boolean"
+ assert data.shape == self.Rshape, "array must be Rspace shaped"
+ x = self.data[data, :, :]
+ else:
+ raise Exception(f"unexpected type for `data` {type(data)}")
+
+ # compute
+ peaks = find_Bragg_disks(
+ data=x,
+ template=template,
+ radial_bksb=radial_bksb,
+ filter_function=filter_function,
+ corrPower=corrPower,
+ sigma_dp=sigma_dp,
+ sigma_cc=sigma_cc,
+ subpixel=subpixel,
+ upsample_factor=upsample_factor,
+ minAbsoluteIntensity=minAbsoluteIntensity,
+ minRelativeIntensity=minRelativeIntensity,
+ relativeToPeak=relativeToPeak,
+ minPeakSpacing=minPeakSpacing,
+ edgeBoundary=edgeBoundary,
+ maxNumPeaks=maxNumPeaks,
+ CUDA=CUDA,
+ CUDA_batched=CUDA_batched,
+ distributed=distributed,
+ ML=ML,
+ ml_model_path=ml_model_path,
+ ml_num_attempts=ml_num_attempts,
+ ml_batch_size=ml_batch_size,
+ )
+
+ if isinstance(peaks, Node):
+ # add metadata
+ peaks.name = name
+ peaks.metadata = Metadata(
+ name="gen_params",
+ data={
+ #'gen_func' :
+ "template": template,
+ "filter_function": filter_function,
+ "corrPower": corrPower,
+ "sigma_dp": sigma_dp,
+ "sigma_cc": sigma_cc,
+ "subpixel": subpixel,
+ "upsample_factor": upsample_factor,
+ "minAbsoluteIntensity": minAbsoluteIntensity,
+ "minRelativeIntensity": minRelativeIntensity,
+ "relativeToPeak": relativeToPeak,
+ "minPeakSpacing": minPeakSpacing,
+ "edgeBoundary": edgeBoundary,
+ "maxNumPeaks": maxNumPeaks,
+ "CUDA": CUDA,
+ "CUDA_batched": CUDA_batched,
+ "distributed": distributed,
+ "ML": ML,
+ "ml_model_path": ml_model_path,
+ "ml_num_attempts": ml_num_attempts,
+ "ml_batch_size": ml_batch_size,
+ },
+ )
+
+ # add to tree
+ if data is None:
+ self.attach(peaks)
+
+ # return
+ if returncalc:
+ return peaks
+
+ def get_beamstop_mask(
+ self,
+ threshold=0.25,
+ distance_edge=2.0,
+ include_edges=True,
+ sigma=0,
+ use_max_dp=False,
+ scale_radial=None,
+ name="mask_beamstop",
+ returncalc=True,
+ ):
+ """
+ This function uses the mean diffraction pattern plus a threshold to
+ create a beamstop mask.
+
+ Args:
+ threshold (float): Value from 0 to 1 defining initial threshold for
+ beamstop mask, taken from the sorted intensity values - 0 is the
+ dimmest pixel, while 1 uses the brightest pixels.
+ distance_edge (float): How many pixels to expand the mask.
+ include_edges (bool): If set to True, edge pixels will be included
+ in the mask.
+ sigma (float):
+ Gaussian blur std to apply to image before thresholding.
+ use_max_dp (bool):
+ Use the max DP instead of the mean DP.
+ scale_radial (float):
+ Scale from center of image by this factor (can help with edge)
+ name (string): Name of the output array.
+ returncalc (bool): Set to true to return the result.
+
+ Returns:
+ (Optional): if returncalc is True, returns the beamstop mask
+
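+ Example:
+ A minimal usage sketch (hypothetical parameter values):
+
+ >>> mask_beamstop = datacube.get_beamstop_mask(threshold=0.25, sigma=1)
+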
+ """
+
+ if scale_radial is not None:
+ x = np.arange(self.data.shape[2]) * 2.0 / self.data.shape[2]
+ y = np.arange(self.data.shape[3]) * 2.0 / self.data.shape[3]
+ ya, xa = np.meshgrid(y - np.mean(y), x - np.mean(x))
+ im_scale = 1.0 + np.sqrt(xa**2 + ya**2) * scale_radial
+
+ # Get image for beamstop mask
+ if use_max_dp:
+ # if not "dp_mean" in self.tree.keys():
+ # self.get_dp_max();
+ # im = self.tree["dp_max"].data.astype('float')
+ if not "dp_max" in self._branch.keys():
+ self.get_dp_max()
+ im = self.tree("dp_max").data.copy().astype("float")
+ else:
+ if not "dp_mean" in self._branch.keys():
+ self.get_dp_mean()
+ im = self.tree("dp_mean").data.copy()
+
+ # smooth and scale if needed
+ if sigma > 0.0:
+ im = gaussian_filter(im, sigma, mode="nearest")
+ if scale_radial is not None:
+ im *= im_scale
+
+ # Calculate beamstop mask
+ int_sort = np.sort(im.ravel())
+ ind = np.round(
+ np.clip(int_sort.shape[0] * threshold, 0, int_sort.shape[0] - 1)
+ ).astype("int")
+ intensity_threshold = int_sort[ind]
+ mask_beamstop = im >= intensity_threshold
+
+ # clean up mask
+ mask_beamstop = np.logical_not(binary_fill_holes(np.logical_not(mask_beamstop)))
+ mask_beamstop = binary_fill_holes(mask_beamstop)
+
+ # Edges
+ if include_edges:
+ mask_beamstop[0, :] = False
+ mask_beamstop[:, 0] = False
+ mask_beamstop[-1, :] = False
+ mask_beamstop[:, -1] = False
+
+ # Expand mask
+ mask_beamstop = distance_transform_edt(mask_beamstop) < distance_edge
+
+ # Wrap beamstop mask in a class
+ x = Array(data=mask_beamstop, name=name)
+
+ # Add metadata
+ x.metadata = Metadata(
+ name="gen_params",
+ data={
+ #'gen_func' :
+ "threshold": threshold,
+ "distance_edge": distance_edge,
+ "include_edges": include_edges,
+ "name": "mask_beamstop",
+ "returncalc": returncalc,
+ },
+ )
+
+ # Add to tree
+ self.tree(x)
+
+ # return
+ if returncalc:
+ return mask_beamstop
+
+ def get_radial_bkgrnd(self, rx, ry, sigma=2):
+ """
+ Computes and returns a background image for the diffraction
+ pattern at (rx,ry), populated by radial rings of constant intensity
+ about the origin, with the value of each ring given by the median
+ value of the diffraction pattern at that radial distance.
+
+ Parameters
+ ----------
+ rx : int
+ The x-coord of the beam position
+ ry : int
+ The y-coord of the beam position
+ sigma : number
+ If >0, applies a gaussian smoothing in the radial direction
+ before returning
+
+ Returns
+ -------
+ background : ndarray
+ The radial background
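+
+ Examples
+ --------
+ A minimal sketch (assumes a polar datacube and a calibrated origin,
+ as required above):
+
+ >>> background = datacube.get_radial_bkgrnd(rx=0, ry=0, sigma=2)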
+ """
+ # ensure a polar cube and origin exist
+ assert self.polar is not None, "No polar datacube found!"
+ assert self.calibration.get_origin() is not None, "No origin found!"
+
+ # get the 1D median background
+ bkgrd_ma_1d = np.ma.median(self.polar.data[rx, ry], axis=0)
+ bkgrd_1d = bkgrd_ma_1d.data
+ bkgrd_1d[bkgrd_ma_1d.mask] = 0
+
+ # smooth
+ if sigma > 0:
+ bkgrd_1d = gaussian_filter1d(bkgrd_1d, sigma)
+
+ # define the 2D cartesian coordinate system
+ origin = self.calibration.get_origin()
+ origin = origin[0][rx, ry], origin[1][rx, ry]
+ qxx, qyy = self.qxx_raw - origin[0], self.qyy_raw - origin[1]
+
+ # get distance qr in polar-elliptical coords
+ ellipse = self.calibration.get_ellipse()
+ ellipse = (1, 1, 0) if ellipse is None else ellipse
+ a, b, theta = ellipse
+
+ qrr = np.sqrt(
+ ((qxx * np.cos(theta)) + (qyy * np.sin(theta))) ** 2
+ + ((qxx * np.sin(theta)) - (qyy * np.cos(theta))) ** 2 / (b / a) ** 2
+ )
+
+ # make an interpolation function and get the 2D background
+ f = interp1d(self.polar.radial_bins, bkgrd_1d, fill_value="extrapolate")
+ background = f(qrr)
+
+ # return
+ return background
+
+ def get_radial_bksb_dp(self, rx, ry, sigma=2):
+ """
+ Computes and returns the diffraction pattern at beam position (rx,ry)
+ with a radial background subtracted. See the docstring for
+ datacube.get_radial_bkgrnd for more info.
+
+ Parameters
+ ----------
+ rx : int
+ The x-coord of the beam position
+ ry : int
+ The y-coord of the beam position
+ sigma : number
+ If >0, applies a gaussian smoothing in the radial direction
+ before returning
+
+ Returns
+ -------
+ data : ndarray
+ The radial background subtracted diffraction image
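+
+ Examples
+ --------
+ A minimal sketch (same requirements as get_radial_bkgrnd):
+
+ >>> dp = datacube.get_radial_bksb_dp(rx=0, ry=0, sigma=2)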
+ """
+ # get 2D background
+ background = self.get_radial_bkgrnd(rx, ry, sigma)
+
+ # subtract, zero negative values, return
+ ans = self.data[rx, ry] - background
+ ans[ans < 0] = 0
+ return ans
+
+ def get_local_ave_dp(
+ self,
+ rx,
+ ry,
+ radial_bksb=False,
+ sigma=2,
+ braggmask=False,
+ braggvectors=None,
+ braggmask_radius=None,
+ ):
+ """
+ Computes and returns the diffraction pattern at beam position (rx,ry)
+ after weighted local averaging with its nearest-neighbor patterns,
+ using a 3x3 gaussian kernel for the weightings.
+
+ Parameters
+ ----------
+ rx : int
+ The x-coord of the beam position
+ ry : int
+ The y-coord of the beam position
+ radial_bksb : bool
+ If True, apply a radial background subtraction to each pattern
+ before averaging
+ sigma : number
+ If radial_bksb is True, use this sigma for radial smoothing of
+ the background
+ braggmask : bool
+ If True, masks bragg scattering at each scan position before
+ averaging. `braggvectors` and `braggmask_radius` must be
+ specified.
+ braggvectors : BraggVectors
+ The Bragg vectors to use for masking
+ braggmask_radius : number
+ The radius about each Bragg point to mask
+
+ Returns
+ -------
+ data : ndarray
+ The locally averaged (and optionally background subtracted)
+ diffraction image
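+
+ Examples
+ --------
+ A minimal sketch (hypothetical masking radius; assumes Bragg
+ vectors `braggvectors` found with find_Bragg_disks):
+
+ >>> dp = datacube.get_local_ave_dp(
+ ... rx=5, ry=5, braggmask=True,
+ ... braggvectors=braggvectors, braggmask_radius=8,
+ ... )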
+ """
+ # define the kernel
+ kernel = np.array([[1, 2, 1], [2, 4, 2], [1, 2, 1]]) / 16.0
+
+ # get shape and check for valid inputs
+ nx, ny = self.data.shape[:2]
+ assert rx >= 0 and rx < nx, "rx outside of scan range"
+ assert ry >= 0 and ry < ny, "ry outside of scan range"
+
+ # get the subcube, checking for edge patterns
+ # and modifying the kernel as needed
+ if rx != 0 and rx != (nx - 1) and ry != 0 and ry != (ny - 1):
+ subcube = self.data[rx - 1 : rx + 2, ry - 1 : ry + 2, :, :]
+ elif rx == 0 and ry == 0:
+ subcube = self.data[:2, :2, :, :]
+ kernel = kernel[1:, 1:]
+ elif rx == 0 and ry == (ny - 1):
+ subcube = self.data[:2, -2:, :, :]
+ kernel = kernel[1:, :-1]
+ elif rx == (nx - 1) and ry == 0:
+ subcube = self.data[-2:, :2, :, :]
+ kernel = kernel[:-1, 1:]
+ elif rx == (nx - 1) and ry == (ny - 1):
+ subcube = self.data[-2:, -2:, :, :]
+ kernel = kernel[:-1, :-1]
+ elif rx == 0:
+ subcube = self.data[:2, ry - 1 : ry + 2, :, :]
+ kernel = kernel[1:, :]
+ elif rx == (nx - 1):
+ subcube = self.data[-2:, ry - 1 : ry + 2, :, :]
+ kernel = kernel[:-1, :]
+ elif ry == 0:
+ subcube = self.data[rx - 1 : rx + 2, :2, :, :]
+ kernel = kernel[:, 1:]
+ elif ry == (ny - 1):
+ subcube = self.data[rx - 1 : rx + 2, -2:, :, :]
+ kernel = kernel[:, :-1]
+ else:
+ raise Exception(f"Invalid (rx,ry) = ({rx},{ry})...")
+
+ # normalize the kernel
+ kernel /= np.sum(kernel)
+
+ # compute...
+
+ # ...in the simple case
+ if not (radial_bksb) and not (braggmask):
+ ans = np.tensordot(subcube, kernel, axes=((0, 1), (0, 1)))
+
+ # ...with radial background subtration
+ elif radial_bksb and not (braggmask):
+ # get position of (rx,ry) relative to kernel
+ _xs = 1 if rx != 0 else 0
+ _ys = 1 if ry != 0 else 0
+ x0 = rx - _xs
+ y0 = ry - _ys
+ # compute
+ ans = np.zeros(self.Qshape)
+ for (i, j), w in np.ndenumerate(kernel):
+ x = x0 + i
+ y = y0 + j
+ ans += self.get_radial_bksb_dp(x, y, sigma) * w
+
+ # ...with bragg masking
+ elif not (radial_bksb) and braggmask:
+ assert (
+ braggvectors is not None
+ ), "`braggvectors` must be specified or `braggmask` must be turned off!"
+ assert (
+ braggmask_radius is not None
+ ), "`braggmask_radius` must be specified or `braggmask` must be turned off!"
+ # get position of (rx,ry) relative to kernel
+ _xs = 1 if rx != 0 else 0
+ _ys = 1 if ry != 0 else 0
+ x0 = rx - _xs
+ y0 = ry - _ys
+ # compute
+ ans = np.zeros(self.Qshape)
+ weights = np.zeros(self.Qshape)
+ for (i, j), w in np.ndenumerate(kernel):
+ x = x0 + i
+ y = y0 + j
+ mask = self.get_braggmask(braggvectors, x, y, braggmask_radius)
+ weights_curr = mask * w
+ ans += self.data[x, y] * weights_curr
+ weights += weights_curr
+ # normalize
+ out = np.full_like(ans, np.nan)
+ ans_mask = weights > 0
+ ans = np.divide(ans, weights, out=out, where=ans_mask)
+ # make masked array
+ ans = np.ma.array(data=ans, mask=np.logical_not(ans_mask))
+
+ # ...with both radial background subtraction and bragg masking
+ else:
+ assert (
+ braggvectors is not None
+ ), "`braggvectors` must be specified or `braggmask` must be turned off!"
+ assert (
+ braggmask_radius is not None
+ ), "`braggmask_radius` must be specified or `braggmask` must be turned off!"
+ # get position of (rx,ry) relative to kernel
+ _xs = 1 if rx != 0 else 0
+ _ys = 1 if ry != 0 else 0
+ x0 = rx - _xs
+ y0 = ry - _ys
+ # compute
+ ans = np.zeros(self.Qshape)
+ weights = np.zeros(self.Qshape)
+ for (i, j), w in np.ndenumerate(kernel):
+ x = x0 + i
+ y = y0 + j
+ mask = self.get_braggmask(braggvectors, x, y, braggmask_radius)
+ weights_curr = mask * w
+ ans += self.get_radial_bksb_dp(x, y, sigma) * weights_curr
+ weights += weights_curr
+ # normalize
+ out = np.full_like(ans, np.nan)
+ ans_mask = weights > 0
+ ans = np.divide(ans, weights, out=out, where=ans_mask)
+ # make masked array
+ ans = np.ma.array(data=ans, mask=np.logical_not(ans_mask))
+
+ # return
+ return ans
+
+ def get_braggmask(self, braggvectors, rx, ry, radius):
+ """
+ Returns a boolean mask which is False in a radius of `radius` around
+ each bragg scattering vector at scan position (rx,ry).
+
+ Parameters
+ ----------
+ braggvectors : BraggVectors
+ The bragg vectors
+ rx : int
+ The x-coord of the beam position
+ ry : int
+ The y-coord of the beam position
+ radius : number
+ mask pixels about each bragg vector to this radial distance
+
+ Returns
+ -------
+ mask : boolean ndarray
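+
+ Examples
+ --------
+ A minimal sketch (hypothetical radius):
+
+ >>> mask = datacube.get_braggmask(braggvectors, rx=0, ry=0, radius=8)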
+ """
+ # allocate space
+ mask = np.ones(self.Qshape, dtype=bool)
+ # get the vectors
+ vects = braggvectors.raw[rx, ry]
+ # loop
+ for idx in range(len(vects.data)):
+ qr = np.hypot(self.qxx_raw - vects.qx[idx], self.qyy_raw - vects.qy[idx])
+ mask = np.logical_and(mask, qr > radius)
+ return mask
diff --git a/py4DSTEM/datacube/virtualdiffraction.py b/py4DSTEM/datacube/virtualdiffraction.py
new file mode 100644
index 000000000..23b151d58
--- /dev/null
+++ b/py4DSTEM/datacube/virtualdiffraction.py
@@ -0,0 +1,393 @@
+# Virtual diffraction from a datacube. Includes:
+# * VirtualDiffraction - a container for virtual diffraction data + metadata
+# * DataCubeVirtualDiffraction - methods inherited by DataCube for virtual diffraction
+
+import numpy as np
+from typing import Optional
+import inspect
+
+from emdfile import tqdmnd, Metadata
+from py4DSTEM.data import DiffractionSlice, Data
+from py4DSTEM.preprocess import get_shifted_ar
+
+# Virtual diffraction container class
+
+
+class VirtualDiffraction(DiffractionSlice, Data):
+ """
+ Stores a diffraction-space shaped 2D image with metadata
+ indicating how this image was generated from a self.
+ """
+
+ def __init__(
+ self,
+ data: np.ndarray,
+ name: Optional[str] = "virtualdiffraction",
+ ):
+ """
+ Args:
+ data (np.ndarray) : the 2D data
+ name (str) : the name
+
+ Returns:
+ A new VirtualDiffraction instance
+ """
+ # initialize as a DiffractionSlice
+ DiffractionSlice.__init__(
+ self,
+ data=data,
+ name=name,
+ )
+
+ # read
+ @classmethod
+ def _get_constructor_args(cls, group):
+ """
+ Returns a dictionary of args/values to pass to the class constructor
+ """
+ ar_constr_args = DiffractionSlice._get_constructor_args(group)
+ args = {
+ "data": ar_constr_args["data"],
+ "name": ar_constr_args["name"],
+ }
+ return args
+
+
+# DataCube virtual diffraction methods
+
+
+class DataCubeVirtualDiffraction:
+ def __init__(self):
+ pass
+
+ def get_virtual_diffraction(
+ self,
+ method,
+ mask=None,
+ shift_center=False,
+ subpixel=False,
+ verbose=True,
+ name="virtual_diffraction",
+ returncalc=True,
+ ):
+ """
+ Function to calculate virtual diffraction images.
+
+ Parameters
+ ----------
+ method : str
+ defines method used for averaging/combining diffraction patterns.
+ Options are ('mean', 'median', 'max')
+ mask : None or 2D array
+ if None (default), all pixels are used. Otherwise, must be a boolean
+ or floating point or complex array with the same shape as real space.
+ For bool arrays, only True pixels are used in the computation.
+ Otherwise a weighted average is performed.
+ shift_center : bool
+ toggles shifting the diffraction patterns to account for beam shift.
+ Currently only supported for 'max' and 'mean' modes. Default is
+ False.
+ subpixel : bool
+ if shift_center is True, toggles subpixel shifts via Fourier
+ interpolation. Ignored if shift_center is False.
+ verbose : bool
+ toggles progress bar
+ name : string
+ name for the output VirtualDiffraction instance
+ returncalc : bool
+ toggles returning the output
+
+ Returns
+ -------
+ diff_im : VirtualDiffraction
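+
+ Examples
+ --------
+ A minimal sketch (hypothetical boolean real-space mask selecting a
+ block of scan positions):
+
+ >>> mask = np.zeros(datacube.Rshape, dtype=bool)
+ >>> mask[10:20, 10:20] = True
+ >>> dp = datacube.get_virtual_diffraction("mean", mask=mask)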
+ """
+ # parse inputs
+ assert method in (
+ "max",
+ "median",
+ "mean",
+ ), "check doc strings for supported types"
+ assert (
+ mask is None or mask.shape == self.Rshape
+ ), "mask must be None or real-space shaped"
+
+ # Calculate
+
+ # ...with no center shifting
+ if shift_center is False:
+ # ...for the whole pattern
+ if mask is None:
+ if method == "mean":
+ virtual_diffraction = np.mean(self.data, axis=(0, 1))
+ elif method == "max":
+ virtual_diffraction = np.max(self.data, axis=(0, 1))
+ else:
+ virtual_diffraction = np.median(self.data, axis=(0, 1))
+
+ # ...for boolean masks
+ elif mask.dtype == bool:
+ mask_indices = np.nonzero(mask)
+ if method == "mean":
+ virtual_diffraction = np.mean(
+ self.data[mask_indices[0], mask_indices[1], :, :], axis=0
+ )
+ elif method == "max":
+ virtual_diffraction = np.max(
+ self.data[mask_indices[0], mask_indices[1], :, :], axis=0
+ )
+ else:
+ virtual_diffraction = np.median(
+ self.data[mask_indices[0], mask_indices[1], :, :], axis=0
+ )
+
+ # ...for complex and floating point masks
+ else:
+ # allocate space
+ if mask.dtype == "complex":
+ virtual_diffraction = np.zeros(self.Qshape, dtype="complex")
+ else:
+ virtual_diffraction = np.zeros(self.Qshape)
+ # set computation method
+ if method == "mean":
+ fn = np.sum
+ elif method == "max":
+ fn = np.max
+ else:
+ fn = np.median
+ # loop
+ for qx, qy in tqdmnd(
+ self.Q_Nx,
+ self.Q_Ny,
+ disable=not verbose,
+ ):
+ virtual_diffraction[qx, qy] = fn(
+ np.squeeze(self.data[:, :, qx, qy]) * mask
+ )
+ # normalize weighted means
+ if method == "mean":
+ virtual_diffraction /= np.sum(mask)
+
+ # ...with center shifting
+ else:
+ assert method in (
+ "max",
+ "mean",
+ ), "only 'mean' and 'max' are supported for center-shifted virtual diffraction"
+
+ # Get calibration metadata
+ assert self.calibration.get_origin() is not None, "origin is not calibrated"
+ x0, y0 = self.calibration.get_origin()
+ x0_mean, y0_mean = self.calibration.get_origin_mean()
+
+ # get shifts
+ qx_shift = x0_mean - x0
+ qy_shift = y0_mean - y0
+
+ if subpixel is False:
+ # round shifts -> int
+ qx_shift = qx_shift.round().astype(int)
+ qy_shift = qy_shift.round().astype(int)
+
+ # ...for boolean masks and unmasked
+ if mask is None or mask.dtype == bool:
+ # get scan points
+ mask = np.ones(self.Rshape, dtype=bool) if mask is None else mask
+ mask_indices = np.nonzero(mask)
+ # allocate space
+ virtual_diffraction = np.zeros(self.Qshape)
+ # loop
+ for rx, ry in zip(mask_indices[0], mask_indices[1]):
+ # get shifted DP
+ if subpixel:
+ DP = get_shifted_ar(
+ self.data[
+ rx,
+ ry,
+ :,
+ :,
+ ],
+ qx_shift[rx, ry],
+ qy_shift[rx, ry],
+ )
+ else:
+ DP = np.roll(
+ self.data[
+ rx,
+ ry,
+ :,
+ :,
+ ],
+ (qx_shift[rx, ry], qy_shift[rx, ry]),
+ axis=(0, 1),
+ )
+ # compute
+ if method == "mean":
+ virtual_diffraction += DP
+ elif method == "max":
+ virtual_diffraction = np.maximum(virtual_diffraction, DP)
+ # normalize means
+ if method == "mean":
+ virtual_diffraction /= len(mask_indices[0])
+
+ # ...for floating point and complex masks
+ else:
+ # allocate space
+ if mask.dtype == "complex":
+ virtual_diffraction = np.zeros(self.Qshape, dtype="complex")
+ else:
+ virtual_diffraction = np.zeros(self.Qshape)
+ # loop
+ for rx, ry in tqdmnd(
+ self.R_Nx,
+ self.R_Ny,
+ disable=not verbose,
+ ):
+ # get shifted DP
+ if subpixel:
+ DP = get_shifted_ar(
+ self.data[
+ rx,
+ ry,
+ :,
+ :,
+ ],
+ qx_shift[rx, ry],
+ qy_shift[rx, ry],
+ )
+ else:
+ DP = np.roll(
+ self.data[
+ rx,
+ ry,
+ :,
+ :,
+ ],
+ (qx_shift[rx, ry], qy_shift[rx, ry]),
+ axis=(0, 1),
+ )
+
+ # compute
+ w = mask[rx, ry]
+ if method == "mean":
+ virtual_diffraction += DP * w
+ elif method == "max":
+ virtual_diffraction = np.maximum(virtual_diffraction, DP * w)
+ if method == "mean":
+ virtual_diffraction /= np.sum(mask)
+
+ # wrap, add to tree, and return
+
+ # wrap in VirtualDiffraction
+ ans = VirtualDiffraction(data=virtual_diffraction, name=name)
+
+ # add the args used to gen this dp as metadata
+ ans.metadata = Metadata(
+ name="gen_params",
+ data={
+ "_calling_method": inspect.stack()[0][3],
+ "_calling_class": __class__.__name__,
+ "method": method,
+ "mask": mask,
+ "shift_center": shift_center,
+ "subpixel": subpixel,
+ "verbose": verbose,
+ "name": name,
+ "returncalc": returncalc,
+ },
+ )
+
+ # add to the tree
+ self.attach(ans)
+
+ # return
+ if returncalc:
+ return ans
+
+ # additional interfaces
+
+ def get_dp_max(
+ self,
+ returncalc=True,
+ ):
+ """
+ Calculates the max diffraction pattern.
+
+ Calls `DataCube.get_virtual_diffraction` - see that method's docstring
+ for more customizable virtual diffraction.
+
+ Parameters
+ ----------
+ returncalc : bool
+ toggles returning the answer
+
+ Returns
+ -------
+ max_dp : VirtualDiffraction
+ """
+ return self.get_virtual_diffraction(
+ method="max",
+ mask=None,
+ shift_center=False,
+ subpixel=False,
+ verbose=True,
+ name="dp_max",
+ returncalc=returncalc,
+ )
+
+ def get_dp_mean(
+ self,
+ returncalc=True,
+ ):
+ """
+ Calculates the mean diffraction pattern.
+
+ Calls `DataCube.get_virtual_diffraction` - see that method's docstring
+ for more customizable virtual diffraction.
+
+ Parameters
+ ----------
+ returncalc : bool
+ toggles returning the answer
+
+ Returns
+ -------
+ mean_dp : VirtualDiffraction
+ """
+ return self.get_virtual_diffraction(
+ method="mean",
+ mask=None,
+ shift_center=False,
+ subpixel=False,
+ verbose=True,
+ name="dp_mean",
+ returncalc=returncalc,
+ )
+
+ def get_dp_median(
+ self,
+ returncalc=True,
+ ):
+ """
+ Calculates the median diffraction pattern.
+
+ Calls `DataCube.get_virtual_diffraction` - see that method's docstring
+ for more customizable virtual diffraction.
+
+ Parameters
+ ----------
+ returncalc : bool
+ toggles returning the answer
+
+ Returns
+ -------
+ median_dp : VirtualDiffraction
+ """
+ return self.get_virtual_diffraction(
+ method="median",
+ mask=None,
+ shift_center=False,
+ subpixel=False,
+ verbose=True,
+ name="dp_median",
+ returncalc=returncalc,
+ )
diff --git a/py4DSTEM/datacube/virtualimage.py b/py4DSTEM/datacube/virtualimage.py
new file mode 100644
index 000000000..87aeae8b1
--- /dev/null
+++ b/py4DSTEM/datacube/virtualimage.py
@@ -0,0 +1,737 @@
+# Virtual imaging from a datacube. Includes:
+# * VirtualImage - a container for virtual image data + metadata
+# * DataCubeVirtualImager - methods inherited by DataCube for virt imaging
+#
+# for Bragg virtual imaging methods, see diskdetection.virtualimage.py
+
+import numpy as np
+import dask.array as da
+from typing import Optional
+import inspect
+
+from emdfile import tqdmnd, Metadata
+from py4DSTEM.data import Calibration, RealSlice, Data, DiffractionSlice
+from py4DSTEM.preprocess import get_shifted_ar
+from py4DSTEM.visualize import show
+
+
+# Virtual image container class
+
+
+class VirtualImage(RealSlice, Data):
+ """
+ A container for storing virtual image data and metadata,
+ including the real-space shaped 2D image and metadata
+ indicating how this image was generated from a datacube.
+ """
+
+ def __init__(
+ self,
+ data: np.ndarray,
+ name: Optional[str] = "virtualimage",
+ ):
+ """
+ Parameters
+ ----------
+ data : np.ndarray
+ the 2D data
+ name : str
+ the name
+ """
+ # initialize as a RealSlice
+ RealSlice.__init__(
+ self,
+ data=data,
+ name=name,
+ )
+
+ # read
+ @classmethod
+ def _get_constructor_args(cls, group):
+ """
+ Returns a dictionary of args/values to pass to the class constructor
+ """
+ ar_constr_args = RealSlice._get_constructor_args(group)
+ args = {
+ "data": ar_constr_args["data"],
+ "name": ar_constr_args["name"],
+ }
+ return args
+
+
+# DataCube virtual imaging methods
+
+
+class DataCubeVirtualImager:
+ def __init__(self):
+ pass
+
+ def get_virtual_image(
+ self,
+ mode,
+ geometry,
+ centered=False,
+ calibrated=False,
+ shift_center=False,
+ subpixel=False,
+ verbose=True,
+ dask=False,
+ return_mask=False,
+ name="virtual_image",
+ returncalc=True,
+ test_config=False,
+ ):
+ """
+ Calculate a virtual image.
+
+ The detector is determined by the combination of the `mode` and
+ `geometry` arguments, supporting point, circular, rectangular,
+ annular, and custom mask detectors. The values passed to geometry
+ may be given with respect to an origin at the corner of the detector
+ array or with respect to the calibrated center position, and in units of
+ pixels or real calibrated units, depending on the values of the
+ `centered` and `calibrated` arguments, respectively. The mask may be
+ shifted pattern-by-pattern to account for diffraction scan shifts using
+ the `shift_center` argument.
+
+ The computed virtual image is stored in the datacube's tree, and is
+ also returned by default.
+
+ Parameters
+ ----------
+ mode : str
+ defines geometry mode for calculating virtual image, and the
+ expected input for the `geometry` argument. options:
+ - 'point': uses a single pixel detector
+ - 'circle', 'circular': uses a round detector, like bright
+ field
+ - 'annular', 'annulus': uses an annular detector, like dark
+ field
+ - 'rectangle', 'square', 'rectangular': uses rectangular
+ detector
+ - 'mask': any diffraction-space shaped 2D array, representing
+ a flexible detector
+ geometry : variable
+ the expected value of this argument is determined by `mode` as
+ follows:
+ - 'point': 2-tuple, (qx,qy), ints
+ - 'circle', 'circular': nested 2-tuple, ((qx,qy),radius),
+ - 'annular', 'annulus': nested 2-tuple,
+ ((qx,qy),(radius_i,radius_o)),
+ - 'rectangle', 'square', 'rectangular': 4-tuple,
+ (xmin,xmax,ymin,ymax)
+ - `mask`: any boolean or floating point 2D array with the same
+ size as datacube.Qshape
+ centered : bool
+ if False, the origin is in the upper left corner. If True, the origin
+ is set to the mean origin in the datacube calibrations, so that a
+ bright-field image could be specified with, e.g., geometry=((0,0),R).
+ The origin can be set with datacube.calibration.set_origin(). For
+ `mode="mask"`, has no effect. Default is False.
+ calibrated : bool
+ if True, geometry is specified in units of 'A^-1' instead of pixels.
+ The datacube's calibrations must have its `"Q_pixel_units"` parameter
+ set to "A^-1". For `mode="mask"`, has no effect. Default is False.
+ shift_center : bool
+ if True, the mask is shifted at each real space position to account
+ for any shifting of the origin of the diffraction images. The
+ datacube's calibration['origin'] parameter must be set. The shift
+ applied to each pattern is the difference between the local origin
+ position and the mean origin position over all patterns, rounded to
+ the nearest integer for speed. Default is False. If `shift_center` is
+ True, `centered` is automatically set to True.
+ subpixel : bool
+ if True, applies subpixel shifts to virtual image
+ verbose : bool
+ toggles a progress bar
+ dask : bool
+ if True, use dask to distribute the calculation
+ return_mask : bool
+ if False (default) returns a virtual image as usual. Otherwise does
+ *not* compute or return a virtual image, instead finding and
+ returning the mask that will be used in subsequent calls to this
+ function using these same parameters. In this case, must be either
+ `True` or a 2-tuple of integers corresponding to `(rx,ry)`. If True
+ is passed, returns the mask used if `shift_center` is set to False.
+ If a 2-tuple is passed, returns the mask used at scan position
+ (rx,ry) if `shift_center` is set to True. Nothing is added to the
+ datacube's tree.
+ name : str
+ the output object's name
+ returncalc : bool
+ if True, returns the output
+ test_config : bool
+ if True, prints the Boolean values of
+ (`centered`,`calibrated`,`shift_center`). Does not compute the
+ virtual image.
+
+ Returns
+ -------
+ virt_im : VirtualImage (optional, if returncalc is True)
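+
+ Examples
+ --------
+ A minimal sketch (hypothetical geometry values; a centered annular
+ detector spanning radii 40-120 pixels, i.e. a dark-field image):
+
+ >>> adf = datacube.get_virtual_image(
+ ... mode="annulus",
+ ... geometry=((0, 0), (40, 120)),
+ ... centered=True,
+ ... )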
+ """
+ # parse inputs
+ assert mode in (
+ "point",
+ "circle",
+ "circular",
+ "annulus",
+ "annular",
+ "rectangle",
+ "square",
+ "rectangular",
+ "mask",
+ ), "check doc strings for supported modes"
+
+ if test_config:
+ for x, y in zip(
+ ["centered", "calibrated", "shift_center"],
+ [centered, calibrated, shift_center],
+ ):
+ print(f"{x} = {y}")
+
+ # Get geometry
+ g = self.get_calibrated_detector_geometry(
+ self.calibration, mode, geometry, centered, calibrated
+ )
+
+ # Get mask
+ mask = self.make_detector(self.Qshape, mode, g)
+ # if return_mask is True, skip computation
+ if return_mask is True and shift_center is False:
+ return mask
+
+ # Calculate virtual image
+
+ # no center shifting
+ if shift_center is False:
+ # single CPU
+ if not dask:
+ # allocate space
+ if mask.dtype == "complex":
+ virtual_image = np.zeros(self.Rshape, dtype="complex")
+ else:
+ virtual_image = np.zeros(self.Rshape)
+ # compute
+ for rx, ry in tqdmnd(
+ self.R_Nx,
+ self.R_Ny,
+ disable=not verbose,
+ ):
+ virtual_image[rx, ry] = np.sum(self.data[rx, ry] * mask)
+
+ # dask
+ if dask is True:
+ # set up a generalized universal function for dask distribution
+ def _apply_mask_dask(data, mask):
+ # sum the masked diffraction pattern and return the result
+ return np.sum(np.multiply(data, mask), dtype=np.float64)
+
+ apply_mask_dask = da.as_gufunc(
+ _apply_mask_dask,
+ signature="(i,j),(i,j)->()",
+ output_dtypes=np.float64,
+ axes=[(2, 3), (0, 1), ()],
+ vectorize=True,
+ )
+
+ # compute
+ virtual_image = apply_mask_dask(self.data, mask)
+
+ # with center shifting
+ else:
+ # get shifts
+ assert (
+ self.calibration.get_origin_shift() is not None
+ ), "origin need to be calibrated"
+ qx_shift, qy_shift = self.calibration.get_origin_shift()
+ if subpixel is False:
+ qx_shift = qx_shift.round().astype(int)
+ qy_shift = qy_shift.round().astype(int)
+
+ # if return_mask is True, get+return the mask and skip the computation
+ if return_mask is not False:
+ try:
+ rx, ry = return_mask
+ except TypeError:
+ raise Exception(
+ f"if `shift_center=True`, return_mask must be a 2-tuple of \
+ ints or False, but revieced inpute value of {return_mask}"
+ )
+ if subpixel:
+ _mask = get_shifted_ar(
+ mask, qx_shift[rx, ry], qy_shift[rx, ry], bilinear=True
+ )
+ else:
+ _mask = np.roll(
+ mask, (qx_shift[rx, ry], qy_shift[rx, ry]), axis=(0, 1)
+ )
+ return _mask
+
+ # allocate space
+ if mask.dtype == "complex":
+ virtual_image = np.zeros(self.Rshape, dtype="complex")
+ else:
+ virtual_image = np.zeros(self.Rshape)
+
+ # loop
+ for rx, ry in tqdmnd(
+ self.R_Nx,
+ self.R_Ny,
+ disable=not verbose,
+ ):
+ # get shifted mask
+ if subpixel:
+ _mask = get_shifted_ar(
+ mask, qx_shift[rx, ry], qy_shift[rx, ry], bilinear=True
+ )
+ else:
+ _mask = np.roll(
+ mask, (qx_shift[rx, ry], qy_shift[rx, ry]), axis=(0, 1)
+ )
+ # add to output array
+ virtual_image[rx, ry] = np.sum(self.data[rx, ry] * _mask)
+
+ # data handling
+
+ # wrap with a py4dstem class
+ ans = VirtualImage(
+ data=virtual_image,
+ name=name,
+ )
+
+ # add generating params as metadata
+ ans.metadata = Metadata(
+ name="gen_params",
+ data={
+ "_calling_method": inspect.stack()[0][3],
+ "_calling_class": __class__.__name__,
+ "mode": mode,
+ "geometry": geometry,
+ "centered": centered,
+ "calibrated": calibrated,
+ "shift_center": shift_center,
+ "subpixel": subpixel,
+ "verbose": verbose,
+ "dask": dask,
+ "return_mask": return_mask,
+ "name": name,
+ "returncalc": True,
+ "test_config": test_config,
+ },
+ )
+
+ # add to the tree
+ self.attach(ans)
+
+ # return
+ if returncalc:
+ return ans
+
+ # Position detector
+
+ def position_detector(
+ self,
+ mode,
+ geometry,
+ data=None,
+ centered=None,
+ calibrated=None,
+ shift_center=False,
+ subpixel=True,
+ scan_position=None,
+ invert=False,
+ color="r",
+ alpha=0.7,
+ **kwargs,
+ ):
+ """
+ Position a virtual detector by displaying a mask over a diffraction
+ space image. Calling `.get_virtual_image()` using the same `mode`
+ and `geometry` parameters will compute a virtual image using this
+ detector.
+
+ Parameters
+ ----------
+ mode : str
+ see the DataCube.get_virtual_image docstring
+ geometry : variable
+ see the DataCube.get_virtual_image docstring
+ data : None or 2d-array or DiffractionSlice or 2-tuple of ints
+ The diffraction image to overlay the mask on. If `None` (default),
+ looks for a mean, max, or median diffraction image, in that order,
+ and if found, uses it; otherwise, uses the diffraction pattern at
+ scan position (0,0). If a 2d array or DiffractionSlice is passed,
+ it must be diffraction-space shaped. If a 2-tuple is passed, uses
+ the diffraction pattern at scan position (rx,ry).
+ centered : bool
+ see the DataCube.get_virtual_image docstring
+ calibrated : bool
+ see the DataCube.get_virtual_image docstring
+ shift_center : None or bool or 2-tuple of ints
+ If `None` (default) and `data` is either None or an array, the mask
+ is not shifted. If `None` and `data` is a 2-tuple, shifts the mask
+ according to the origin at the scan position (rx,ry) specified in
+ `data`. If False, does not shift the mask. If True and `data` is
+ a 2-tuple, shifts the mask accordingly, and if True and `data` is
+ any other value, raises an error. If `shift_center` is a 2-tuple,
+ shifts the mask according to the origin value at this 2-tuple
+ regardless of the value of `data` (enabling e.g. overlaying the
+ mask for a specific scan position on a max or mean diffraction
+ image.)
+ subpixel : bool
+ if True, applies subpixel shifts to virtual image
+ invert : bool
+ if True, invert the mask (i.e. pixels *outside* the detector
+ are overlaid with a mask)
+ color : any matplotlib color specification
+ the mask color
+ alpha : number
+ the mask transparency
+ kwargs : dict
+ Any additional arguments are passed on to the show() function
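+
+ Examples
+ --------
+ A minimal sketch (same hypothetical geometry as the
+ get_virtual_image example above):
+
+ >>> datacube.position_detector(
+ ... mode="annulus",
+ ... geometry=((0, 0), (40, 120)),
+ ... centered=True,
+ ... )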
+ """
+ # parse inputs
+
+ # mode
+ assert mode in (
+ "point",
+ "circle",
+ "circular",
+ "annulus",
+ "annular",
+ "rectangle",
+ "square",
+ "rectangular",
+ "mask",
+ ), "check doc strings for supported modes"
+
+ # data
+ if data is None:
+ image = None
+ keys = ["dp_mean", "dp_max", "dp_median"]
+ for k in keys:
+ try:
+ image = self.tree(k)
+ break
+ except Exception:
+ pass
+ if image is None:
+ image = self[0, 0]
+ elif isinstance(data, np.ndarray):
+ assert (
+ data.shape == self.Qshape
+ ), f"Can't position a detector over an image with a shape that is different \
+ from diffraction space. Diffraction space in this dataset has shape {self.Qshape} \
+ but the image passed has shape {data.shape}"
+ image = data
+ elif isinstance(data, DiffractionSlice):
+ assert (
+ data.shape == self.Qshape
+ ), f"Can't position a detector over an image with a shape that is different \
+ from diffraction space. Diffraction space in this dataset has shape {self.Qshape} \
+ but the image passed has shape {data.shape}"
+ image = data.data
+ elif isinstance(data, tuple):
+ rx, ry = data[:2]
+ image = self[rx, ry]
+ else:
+ raise Exception(
+ f"Invalid argument passed to `data`. Expected None or np.ndarray or \
+ tuple, not type {type(data)}"
+ )
+
+ # shift center
+ if shift_center is None:
+ shift_center = False
+ elif shift_center is True:
+ assert isinstance(
+ data, tuple
+ ), "If shift_center is set to True, `data` should be a 2-tuple (rx,ry). \
+ To shift the detector mask while using some other input for `data`, \
+ set `shift_center` to a 2-tuple (rx,ry)"
+ elif isinstance(shift_center, tuple):
+ rx, ry = shift_center[:2]
+ shift_center = True
+ else:
+ shift_center = False
+
+ # Get the mask
+
+ # Get geometry
+ g = self.get_calibrated_detector_geometry(
+ calibration=self.calibration,
+ mode=mode,
+ geometry=geometry,
+ centered=centered,
+ calibrated=calibrated,
+ )
+
+ # Get mask
+ mask = self.make_detector(image.shape, mode, g)
+ if not (invert):
+ mask = np.logical_not(mask)
+
+ # Shift center
+ if shift_center:
+ try:
+ rx, ry
+ except NameError:
+ raise Exception(
+ "if `shift_center` is True then `data` must be a 2-tuple (rx,ry)"
+ )
+ # get shifts
+ assert (
+ self.calibration.get_origin_shift() is not None
+ ), "origin shifts need to be calibrated"
+ qx_shift, qy_shift = self.calibration.get_origin_shift()
+ if subpixel:
+ mask = get_shifted_ar(
+ mask, qx_shift[rx, ry], qy_shift[rx, ry], bilinear=True
+ )
+ else:
+ qx_shift = int(np.round(qx_shift[rx, ry]))
+ qy_shift = int(np.round(qy_shift[rx, ry]))
+ mask = np.roll(mask, (qx_shift, qy_shift), axis=(0, 1))
+
+ # Show
+ show(image, mask=mask, mask_color=color, mask_alpha=alpha, **kwargs)
+ return
+
+ @staticmethod
+ def get_calibrated_detector_geometry(
+ calibration, mode, geometry, centered, calibrated
+ ):
+ """
+ Determine the detector geometry in pixels, given some mode and geometry
+ in calibrated units, where the calibration state is specified by {
+ centered, calibrated}
+
+ Parameters
+ ----------
+ calibration : Calibration
+ Used to retrieve the center positions. If `None`, confirms that
+ centered and calibrated are False then passes, otherwise raises
+ an exception
+ mode : str
+ see the DataCube.get_virtual_image docstring
+ geometry : variable
+ see the DataCube.get_virtual_image docstring
+ centered : bool
+ see the DataCube.get_virtual_image docstring
+ calibrated : bool
+ see the DataCube.get_virtual_image docstring
+
+ Returns
+ -------
+ geo : tuple
+ the geometry in detector pixels
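+
+ Examples
+ --------
+ A minimal sketch (hypothetical values, mirroring the call made in
+ get_virtual_image; converts a centered circular geometry given in
+ A^-1 into detector pixels):
+
+ >>> g = datacube.get_calibrated_detector_geometry(
+ ... datacube.calibration, "circle", ((0, 0), 0.5),
+ ... centered=True, calibrated=True,
+ ... )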
+ """
+ # Parse inputs
+ g = geometry
+ if calibration is None:
+ assert (
+ calibrated is False and centered is False
+ ), "No calibration found - set a calibration or set `centered` and `calibrated` to False"
+ return g
+ else:
+ assert isinstance(calibration, Calibration)
+ cal = calibration
+
+ # Get calibration metadata
+ if centered:
+ assert cal.get_qx0_mean() is not None, "origin needs to be calibrated"
+ x0_mean, y0_mean = cal.get_origin_mean()
+
+ if calibrated:
+ assert (
+ cal["Q_pixel_units"] == "A^-1"
+ ), "check calibration - must be calibrated in A^-1 to use `calibrated=True`"
+ unit_conversion = cal.get_Q_pixel_size()
+
+ # Convert units into detector pixels
+
+ # Shift center
+ if centered is True:
+ if mode == "point":
+ g = (g[0] + x0_mean, g[1] + y0_mean)
+ if mode in ("circle", "circular", "annulus", "annular"):
+ g = ((g[0][0] + x0_mean, g[0][1] + y0_mean), g[1])
+ if mode in ("rectangle", "square", "rectangular"):
+ g = (g[0] + x0_mean, g[1] + x0_mean, g[2] + y0_mean, g[3] + y0_mean)
+
+ # Scale by the detector pixel size
+ if calibrated is True:
+ if mode == "point":
+ g = (g[0] / unit_conversion, g[1] / unit_conversion)
+ if mode in ("circle", "circular"):
+ g = (
+ (g[0][0] / unit_conversion, g[0][1] / unit_conversion),
+ (g[1] / unit_conversion),
+ )
+ if mode in ("annulus", "annular"):
+ g = (
+ (g[0][0] / unit_conversion, g[0][1] / unit_conversion),
+ (g[1][0] / unit_conversion, g[1][1] / unit_conversion),
+ )
+ if mode in ("rectangle", "square", "rectangular"):
+ g = (
+ g[0] / unit_conversion,
+ g[1] / unit_conversion,
+ g[2] / unit_conversion,
+ g[3] / unit_conversion,
+ )
+
+ return g
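+
+ # A minimal usage sketch (assuming, per the docstrings above, that this
+ # method lives on DataCube, and assuming a Calibration instance `cal`
+ # with Q_pixel_units of "A^-1" and a Q_pixel_size of 0.01): a circular
+ # geometry given in calibrated units,
+ #
+ # g = DataCube.get_calibrated_detector_geometry(
+ # calibration=cal,
+ # mode="circle",
+ # geometry=((0.1, 0.2), 0.05),
+ # centered=False,
+ # calibrated=True,
+ # )
+ #
+ # returns ((10.0, 20.0), 5.0) - the same geometry in detector pixels.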
+
+ @staticmethod
+ def make_detector(
+ shape,
+ mode,
+ geometry,
+ ):
+ """
+ Generate a 2D mask representing a detector function.
+
+ Parameters
+ ----------
+ shape : 2-tuple
+ defines shape of mask. Should be the shape of diffraction space.
+ mode : str
+ defines geometry mode for calculating virtual image. See the
+ docstring for DataCube.get_virtual_image
+ geometry : variable
+ defines geometry for calculating virtual image. See the
+ docstring for DataCube.get_virtual_image
+
+ Returns
+ -------
+ detector_mask : 2d array
+ """
+ g = geometry
+
+ # point mask
+ if mode == "point":
+ assert (
+ isinstance(g, tuple) and len(g) == 2
+ ), "specify qx and qy as tuple (qx, qy)"
+ mask = np.zeros(shape, dtype=bool)
+
+ qx = int(g[0])
+ qy = int(g[1])
+
+ mask[qx, qy] = 1
+
+ # circular mask
+ if mode in ("circle", "circular"):
+ assert (
+ isinstance(g, tuple)
+ and len(g) == 2
+ and len(g[0]) == 2
+ and isinstance(g[1], (float, int))
+ ), "specify qx, qy, radius_i as ((qx, qy), radius)"
+
+ qxa, qya = np.indices(shape)
+ mask = (qxa - g[0][0]) ** 2 + (qya - g[0][1]) ** 2 < g[1] ** 2
+
+ # annular mask
+ if mode in ("annulus", "annular"):
+ assert (
+ isinstance(g, tuple)
+ and len(g) == 2
+ and len(g[0]) == 2
+ and len(g[1]) == 2
+ ), "specify qx, qy, radius_i, radius_0 as ((qx, qy), (radius_i, radius_o))"
+
+ assert g[1][1] > g[1][0], "Inner radius must be smaller than outer radius"
+
+ qxa, qya = np.indices(shape)
+ mask1 = (qxa - g[0][0]) ** 2 + (qya - g[0][1]) ** 2 > g[1][0] ** 2
+ mask2 = (qxa - g[0][0]) ** 2 + (qya - g[0][1]) ** 2 < g[1][1] ** 2
+ mask = np.logical_and(mask1, mask2)
+
+ # rectangle mask
+ if mode in ("rectangle", "square", "rectangular"):
+ assert (
+ isinstance(g, tuple) and len(g) == 4
+ ), "specify x_min, x_max, y_min, y_max as (x_min, x_max, y_min, y_max)"
+ mask = np.zeros(shape, dtype=bool)
+
+ xmin = int(np.round(g[0]))
+ xmax = int(np.round(g[1]))
+ ymin = int(np.round(g[2]))
+ ymax = int(np.round(g[3]))
+
+ mask[xmin:xmax, ymin:ymax] = 1
+
+ # flexible mask
+ if mode == "mask":
+ assert isinstance(g, np.ndarray), "`geometry` type should be `np.ndarray`"
+ assert g.shape == shape, "mask and diffraction pattern shapes do not match"
+ mask = g
+ return mask
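+
+ # A minimal usage sketch: as a staticmethod this can be called directly
+ # (assuming, per the docstrings above, that it lives on DataCube) to
+ # build boolean masks, e.g. an annular detector on a 128x128 pattern
+ # centered at (64, 64) with inner/outer radii of 10 and 30 pixels:
+ #
+ # mask = DataCube.make_detector(
+ # shape=(128, 128),
+ # mode="annulus",
+ # geometry=((64, 64), (10, 30)),
+ # )
+ # # mask is a (128, 128) boolean array, True between the two radii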
+
+ # TODO where should this go?
+ def make_bragg_mask(
+ self,
+ Qshape,
+ g1,
+ g2,
+ radius,
+ origin,
+ max_q,
+ return_sum=True,
+ **kwargs,
+ ):
+ """
+ Creates and returns a mask consisting of circular disks
+ about the points of a 2D lattice.
+
+ Args:
+ Qshape (2 tuple): the shape of diffraction space
+ g1,g2 (len 2 array or tuple): the lattice vectors
+ radius (number): the disk radius
+ origin (len 2 array or tuple): the origin
+ max_q (number): the maximum distance to tile to
+ return_sum (bool): if True, return a single 2D mask of all
+ disks; if False, return a 3D array, where each slice
+ contains a single disk
+
+ Returns:
+ (2 or 3D array) the mask
+ """
+ nas = np.asarray
+ g1, g2, origin = nas(g1), nas(g2), nas(origin)
+
+ # Get N,M, the maximum indices to tile out to
+ L1 = np.sqrt(np.sum(g1**2))
+ H = int(max_q / L1) + 1
+ L2 = np.hypot(-g2[0] * g1[1], g2[1] * g1[0]) / np.sqrt(np.sum(g1**2))
+ K = int(max_q / L2) + 1
+
+ # Compute number of points
+ N = 0
+ for h in range(-H, H + 1):
+ for k in range(-K, K + 1):
+ v = h * g1 + k * g2
+ if np.sqrt(v.dot(v)) < max_q:
+ N += 1
+
+ # create mask
+ mask = np.zeros((Qshape[0], Qshape[1], N), dtype=bool)
+ N = 0
+ for h in range(-H, H + 1):
+ for k in range(-K, K + 1):
+ v = h * g1 + k * g2
+ if np.sqrt(v.dot(v)) < max_q:
+ center = origin + v
+ mask[:, :, N] = self.make_detector(
+ Qshape,
+ mode="circle",
+ geometry=(center, radius),
+ )
+ N += 1
+
+ if return_sum:
+ mask = np.sum(mask, axis=2)
+ return mask
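+
+
+# A minimal usage sketch (assuming `datacube` is a DataCube with lattice
+# vectors g1, g2 and an origin already measured, in diffraction pixels):
+# build a single 2D mask with radius-5 disks on every lattice point within
+# 100 pixels of the origin,
+#
+# bragg_mask = datacube.make_bragg_mask(
+# Qshape=datacube.Qshape,
+# g1=(42.0, 0.5),
+# g2=(0.5, 42.0),
+# radius=5,
+# origin=(64, 64),
+# max_q=100,
+# return_sum=True,
+# )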
diff --git a/py4DSTEM/io/__init__.py b/py4DSTEM/io/__init__.py
new file mode 100644
index 000000000..fa7cd099e
--- /dev/null
+++ b/py4DSTEM/io/__init__.py
@@ -0,0 +1,8 @@
+# read / write
+from py4DSTEM.io.importfile import import_file
+from py4DSTEM.io.read import read
+from py4DSTEM.io.save import save
+
+
+# google downloader
+from py4DSTEM.io.google_drive_downloader import gdrive_download, get_sample_file_ids
diff --git a/py4DSTEM/io/filereaders/README.md b/py4DSTEM/io/filereaders/README.md
new file mode 100644
index 000000000..4aa58a05c
--- /dev/null
+++ b/py4DSTEM/io/filereaders/README.md
@@ -0,0 +1,13 @@
+# `io.filereaders`: reading non-native filetypes
+
+Implemented filetypes:
+- .dm3/.dm4
+- empad
+- gatan K2 binary
+- mib
+- arina
+- abTEM
+
+
+Todo:
+- kitware's e- counted data
+- mrc relativity
+
+
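+## Usage
+
+These readers are not usually called directly - `py4DSTEM.io.import_file`
+parses the filetype and dispatches to the appropriate reader. A minimal
+sketch, assuming a Digital Micrograph file at a hypothetical path:
+
+```python
+from py4DSTEM.io import import_file
+
+# filetype is detected automatically; returns a DataCube for 4D data
+datacube = import_file("experiment.dm4", mem="RAM", binfactor=1)
+```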
diff --git a/py4DSTEM/io/filereaders/__init__.py b/py4DSTEM/io/filereaders/__init__.py
new file mode 100644
index 000000000..b6f4eb0a2
--- /dev/null
+++ b/py4DSTEM/io/filereaders/__init__.py
@@ -0,0 +1,6 @@
+from py4DSTEM.io.filereaders.read_dm import read_dm
+from py4DSTEM.io.filereaders.read_K2 import read_gatan_K2_bin
+from py4DSTEM.io.filereaders.empad import read_empad
+from py4DSTEM.io.filereaders.read_mib import load_mib
+from py4DSTEM.io.filereaders.read_arina import read_arina
+from py4DSTEM.io.filereaders.read_abTEM import read_abTEM
diff --git a/py4DSTEM/io/filereaders/empad.py b/py4DSTEM/io/filereaders/empad.py
new file mode 100644
index 000000000..25c0a113b
--- /dev/null
+++ b/py4DSTEM/io/filereaders/empad.py
@@ -0,0 +1,107 @@
+# Reads an EMPAD file
+#
+# Created on Tue Jan 15 13:06:03 2019
+# @author: percius
+# Edited on 20190409 by bsavitzky and rdhall
+# Edited on 20210628 by sez
+
+import numpy as np
+from pathlib import Path
+from emdfile import tqdmnd
+from py4DSTEM.datacube import DataCube
+from py4DSTEM.preprocess.utils import bin2D
+
+
+def read_empad(filename, mem="RAM", binfactor=1, metadata=False, **kwargs):
+ """
+ Reads the EMPAD file at filename, returning a DataCube.
+
+ EMPAD frames are stored as 130x128 arrays: 128x128 pixels of data followed by
+ two rows of metadata. The scan shape is inferred from the file size, assuming a
+ square scan, or can be set manually with the EMPAD_shape keyword argument. The
+ full dataset is then loaded and cropped to the 128x128 valid region of each frame.
+
+ Accepts:
+ filename (str) path to the EMPAD file
+ EMPAD_shape (kwarg, tuple) Manually specify the shape of the data for files that do not
+ contain metadata in the .raw file. This will typically be:
+ (# scan pixels x, # scan pixels y, 130, 128)
+
+ Returns:
+ data (DataCube) the 4D datacube, excluding the metadata rows.
+ """
+ assert isinstance(
+ filename, (str, Path)
+ ), "Error: filepath fp must be a string or pathlib.Path"
+ assert mem in [
+ "RAM",
+ "MEMMAP",
+ ], 'Error: argument mem must be either "RAM" or "MEMMAP"'
+ assert isinstance(binfactor, int), "Error: argument binfactor must be an integer"
+ assert binfactor >= 1, "Error: binfactor must be >= 1"
+ assert metadata is False, "Error: EMPAD Reader does not support metadata."
+
+ row = 130
+ col = 128
+ fPath = Path(filename)
+
+ if "EMPAD_shape" in kwargs.keys():
+ data_shape = kwargs["EMPAD_shape"]
+ else:
+ import os
+
+ filesize = os.path.getsize(fPath)
+ pattern_size = row * col * 4 # 4 bytes per pixel
+ N_patterns = filesize / pattern_size
+ Nxy = np.sqrt(N_patterns)
+
+ # Check that it's reasonably square
+ assert (
+ np.abs(Nxy - np.round(Nxy)) <= 1e-10
+ ), "Automatically detected shape seems wrong... Try specifying it manually with the EMPAD_shape keyword argument"
+
+ data_shape = (int(Nxy), int(Nxy), row, col)
+
+ # # Parse the EMPAD metadata for first and last images
+ # empadDTYPE = np.dtype([("data", "16384float32"), ("metadata", "256float32")])
+ # with open(fPath, "rb") as fid:
+ # imFirst = np.fromfile(fid, dtype=empadDTYPE, count=1)
+ # fid.seek(-128 * 130 * 4, 2)
+ # imLast = np.fromfile(fid, dtype=empadDTYPE, count=1)
+
+ # # Get the scan shape
+ # shape0 = imFirst["metadata"][0][128 + 12 : 128 + 16]
+ # shape1 = imLast["metadata"][0][128 + 12 : 128 + 16]
+ # rShape = 1 + shape1[0:2] - shape0[0:2] # scan shape
+ # data_shape = (int(rShape[0]), int(rShape[1]), row, col)
+
+ # Load the data
+ if (mem, binfactor) == ("RAM", 1):
+ with open(fPath, "rb") as fid:
+ data = np.fromfile(fid, np.float32).reshape(data_shape)[:, :, :128, :]
+ elif (mem, binfactor) == ("MEMMAP", 1):
+ data = np.memmap(fPath, dtype=np.float32, mode="r", shape=data_shape)[
+ :, :, :128, :
+ ]
+ elif (mem) == ("RAM"):
+ # binned read into RAM
+ memmap = np.memmap(fPath, dtype=np.float32, mode="r", shape=data_shape)[
+ :, :, :128, :
+ ]
+ R_Nx, R_Ny, Q_Nx, Q_Ny = memmap.shape
+ Q_Nx, Q_Ny = Q_Nx // binfactor, Q_Ny // binfactor
+ data = np.zeros((R_Nx, R_Ny, Q_Nx, Q_Ny), dtype=np.float32)
+ for Rx, Ry in tqdmnd(
+ R_Nx, R_Ny, desc="Binning data", unit="DP", unit_scale=True
+ ):
+ data[Rx, Ry, :, :] = bin2D(
+ memmap[Rx, Ry, :, :], binfactor, dtype=np.float32
+ )
+ else:
+ # memory mapping + bin-on-load is not supported
+ raise Exception(
+ "Memory mapping and on-load binning together is not supported. Either set binfactor=1 or mem='RAM'."
+ )
+
+ return DataCube(data=data)
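+
+
+# A minimal usage sketch (hypothetical paths): the scan shape is inferred
+# from the file size for square scans, and can be given explicitly with the
+# EMPAD_shape keyword for non-square scans:
+#
+# from py4DSTEM.io.filereaders import read_empad
+# datacube = read_empad("square_scan.raw")
+# datacube = read_empad("rect_scan.raw", EMPAD_shape=(32, 128, 130, 128))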
diff --git a/py4DSTEM/io/filereaders/read_K2.py b/py4DSTEM/io/filereaders/read_K2.py
new file mode 100644
index 000000000..e0a5dae1f
--- /dev/null
+++ b/py4DSTEM/io/filereaders/read_K2.py
@@ -0,0 +1,578 @@
+# Open an interface to a Gatan K2 binary fileset, loading frames from disk as called.
+# While slicing (i.e. calling dc.data4D[__,__,__,__]) returns a numpy ndarray, the
+# object is not itself a numpy array, so most numpy functions do not operate on this.
+
+from collections.abc import Sequence
+import numpy as np
+
+try:
+ import numba as nb
+except ImportError:
+ pass
+from emdfile import tqdmnd
+from py4DSTEM.datacube import DataCube
+
+
+def read_gatan_K2_bin(fp, mem="MEMMAP", binfactor=1, metadata=False, **kwargs):
+ """
+ Read a K2 binary 4D-STEM file.
+
+ Args:
+ fp: str Path to the file
+ mem (str, optional): Specifies how the data should be stored; must be "RAM"
+ or "MEMMAP". See docstring for py4DSTEM.file.io.read. Default is "MEMMAP".
+ binfactor: (int, optional): Bin the data, in diffraction space, as it's loaded.
+ See docstring for py4DSTEM.file.io.read. Must be 1, retained only for
+ compatibility.
+ metadata (bool, optional): if True, returns the file metadata as a Metadata
+ instance.
+
+ Returns:
+ (variable): The return value depends on usage:
+
+
+ * if metadata==False, returns the 4D-STEM dataset as a DataCube
+ * if metadata==True, returns None; metadata extraction is not
+ currently implemented for this reader
+ """
+ assert mem == "MEMMAP", "K2 files can only be memory-mapped, sorry."
+ assert binfactor == 1, "K2 files can only be read at full resolution, sorry."
+
+ if metadata is True:
+ return None
+
+ block_sync = kwargs.get("K2_sync_block_IDs", True)
+ NR = kwargs.get("K2_hidden_stripe_noise_reduction", True)
+ return DataCube(
+ data=K2DataArray(
+ fp, sync_block_IDs=block_sync, hidden_stripe_noise_reduction=NR
+ )
+ )
+
+
+class K2DataArray(Sequence):
+ """
+ K2DataArray provides an interface to a set of Gatan K2IS binary output files.
+ This object behaves *similarly* to a numpy memmap into the data, and supports 4-D indexing
+ and slicing. Slices into this object return np.ndarray objects.
+
+ The object is created by passing the path to any of: (i) the folder containing the
+ raw data, (ii) the *.gtg metadata file, or (iii) one of the raw data *.bin files.
+ In any case, there should be only one dataset (8 *.bin's and a *.gtg) in the folder.
+
+ ===== Filtering and Noise Reduction =====
+ This object is read-only---you cannot edit the data on disk, which means that some
+ DataCube functions like swap_RQ() will not work.
+
+ The K2IS has a "resolution" of 1920x1792, but actually saves hidden stripes in the raw data.
+ By setting the hidden_stripe_noise_reduction flag to True, the electronic noise in these
+ stripes is used to reduce the readout noise. (This is on by default.)
+
+ If you want to take a separate background to subtract, set `dark_reference` to specify this
+ background. This is then subtracted from the frames as they are called out (no matter where
+ the object is referenced! So, for instance, Bragg disk detection will operate on the background-
+ subtracted diffraction patterns!). However, mixing the auto-background and specified background
+ is potentially dangerous and (currently!) not allowed. To switch back from user-background to
+ auto-background, just delete the user background, i.e. `del(dc.data4D.dark_reference)`
+
+ Note:
+ If you call dc.data4D[:,:,:,:] on a DataCube with a K2DataArray this will read the entire stack
+ into memory. To reduce RAM pressure, only call small slices or loop over each diffraction pattern.
+ """
+
+ def __init__(
+ self, filepath, sync_block_IDs=True, hidden_stripe_noise_reduction=True
+ ):
+ from ncempy.io import dm
+ import os
+ import glob
+
+ # first parse the input and get the path to the *.gtg
+ if not os.path.isdir(filepath):
+ filepath = os.path.dirname(filepath)
+
+ assert (
+ len(glob.glob(os.path.join(filepath, "*.bin"))) == 8
+ ), "Wrong path, or wrong number of bin files."
+ assert (
+ len(glob.glob(os.path.join(filepath, "*.gtg"))) == 1
+ ), "Wrong path, or wrong number of gtg files."
+
+ gtgpath = os.path.join(filepath, glob.glob(os.path.join(filepath, "*.gtg"))[0])
+ binprefix = gtgpath[:-4]
+
+ self._gtg_file = gtgpath
+ self._bin_prefix = binprefix
+
+ # open the *.gtg and read the metadata
+ gtg = dm.fileDM(gtgpath)
+ gtg.parseHeader()
+
+ # get the important metadata
+ try:
+ R_Ny = gtg.allTags[".SI Dimensions.Size Y"]
+ R_Nx = gtg.allTags[".SI Dimensions.Size X"]
+ except ValueError:
+ print("Warning: scan shape not detected. Please check/set manually.")
+ R_Nx = self._guess_number_frames() // 32
+ R_Ny = 1
+
+ try:
+ # this may be wrong for binned data... in which case the reader doesn't work anyway!
+ Q_Nx = gtg.allTags[".SI Image Tags.Acquisition.Parameters.Detector.height"]
+ Q_Ny = gtg.allTags[".SI Image Tags.Acquisition.Parameters.Detector.width"]
+ except Exception:
+ print("Warning: diffraction pattern shape not detected!")
+ print("Assuming 1920x1792 as the diffraction pattern size!")
+ Q_Nx = 1792
+ Q_Ny = 1920
+
+ self.shape = (int(R_Nx), int(R_Ny), int(Q_Nx), int(Q_Ny))
+ self._hidden_stripe_noise_reduction = hidden_stripe_noise_reduction
+ self.sync_block_IDs = sync_block_IDs
+
+ self._stripe_dtype = np.dtype(
+ [
+ ("sync", ">u4"),
+ ("pad1", np.void, 5),
+ ("shutter", ">u1"),
+ ("pad2", np.void, 6),
+ (
+ "block",
+ ">u4",
+ ),
+ ("pad4", np.void, 4),
+ ("frame", ">u4"),
+ ("coords", ">u2", (4,)),
+ ("pad3", np.void, 4),
+ ("data", ">u1", (22320,)),
+ ]
+ )
+
+ self._attach_to_files()
+
+ self._shutter_offsets = np.zeros((8,), dtype=np.uint32)
+ self._find_offsets()
+ print("Shutter flags are:", self._shutter_offsets)
+
+ self._gtg_meta = gtg.allTags
+
+ self._user_noise_reduction = False
+
+ self._temp = np.zeros((32,), dtype=self._stripe_dtype)
+ self._Qx, self._Qy = self._parse_slices(
+ (slice(None), slice(None)), "diffraction"
+ )
+
+ # needed for Dask support:
+ self.ndim = 4
+ self.ndims = 4
+ self.dtype = np.int16
+
+ super().__init__()
+
+ # ======== HANDLE SLICING AND len CALLS =========#
+ def __getitem__(self, i):
+ # first check that the slicing is valid:
+ assert (
+ len(i) == 4
+ ), f"Incorrect number of indices given. {len(i)} given, 4 required."
+ # take the input and parse it into coordinate arrays
+ if isinstance(i[0], slice) | isinstance(i[1], slice):
+ Rx, Ry = self._parse_slices(i[:2], "real")
+ R_Nx = Rx.shape[0]
+ R_Ny = Rx.shape[1]
+ assert Rx.max() < self.shape[0], "index out of range"
+ assert Ry.max() < self.shape[1], "index out of range"
+ else: # skip _parse_slices for single input
+ Rx = np.array([i[0]], ndmin=2)
+ Ry = np.array([i[1]], ndmin=2)
+ R_Nx = 1
+ R_Ny = 1
+ assert Rx < self.shape[0], "index out of range"
+ assert Ry < self.shape[1], "index out of range"
+ if (i[2] == slice(None)) & (i[3] == slice(None)):
+ Qx, Qy = self._Qx, self._Qy
+ else:
+ Qx, Qy = self._parse_slices(i[2:], "diffraction")
+
+ assert Qx.max() < self.shape[2], "index out of range"
+ assert Qy.max() < self.shape[3], "index out of range"
+
+ # preallocate the output data array
+ outdata = np.zeros((R_Nx, R_Ny, Qx.shape[0], Qx.shape[1]), dtype=np.int16)
+
+ # loop over all the requested frames
+ for sy in range(R_Ny):
+ for sx in range(R_Nx):
+ scanx = Rx[sx, sy]
+ scany = Ry[sx, sy]
+
+ frame = np.ravel_multi_index(
+ (scanx, scany), (self.shape[0], self.shape[1]), order="F"
+ )
+ DP = self._grab_frame(frame)
+ if self._hidden_stripe_noise_reduction:
+ self._subtract_readout_noise(DP)
+ elif self._user_noise_reduction:
+ DP = DP - self._user_dark_reference
+
+ outdata[sx, sy, :, :] = DP[Qx, Qy].reshape([Qx.shape[0], Qx.shape[1]])
+
+ return np.squeeze(outdata)
+
+ def __len__(self):
+ return np.prod(self.shape)
+
+ # ====== DUCK-TYPED NUMPY FUNCTIONS ======#
+
+ def mean(self, axis=None, dtype=None, out=None, keepdims=False):
+ assert axis in [(0, 1), (2, 3)], "Only average DP and average image supported."
+
+ # handle average DP
+ if axis == (0, 1):
+ avgDP = np.zeros((self.shape[2], self.shape[3]))
+ for Ry, Rx in tqdmnd(self.shape[1], self.shape[0]):
+ avgDP += self[Rx, Ry, :, :]
+
+ return avgDP / (self.shape[0] * self.shape[1])
+
+ # handle average image
+ if axis == (2, 3):
+ avgImg = np.zeros((self.shape[0], self.shape[1]))
+ for Ry, Rx in tqdmnd(self.shape[1], self.shape[0]):
+ avgImg[Rx, Ry] = np.mean(self[Rx, Ry, :, :])
+ return avgImg
+
+ def sum(self, axis=None, dtype=None, out=None, keepdims=False):
+ assert axis in [(0, 1), (2, 3)], "Only sum DP and sum image supported."
+
+ # handle average DP
+ if axis == (0, 1):
+ sumDP = np.zeros((self.shape[2], self.shape[3]))
+ for Ry, Rx in tqdmnd(self.shape[1], self.shape[0]):
+ sumDP += self[Rx, Ry, :, :]
+
+ return sumDP
+
+ # handle average image
+ if axis == (2, 3):
+ sumImg = np.zeros((self.shape[0], self.shape[1]))
+ for Ry, Rx in tqdmnd(self.shape[1], self.shape[0]):
+ sumImg[Rx, Ry] = np.sum(self[Rx, Ry, :, :])
+ return sumImg
+
+ def max(self, axis=None, out=None):
+ assert axis in [(0, 1), (2, 3)], "Only max DP and max image supported."
+
+ # handle average DP
+ if axis == (0, 1):
+ maxDP = np.zeros((self.shape[2], self.shape[3]))
+ for Ry, Rx in tqdmnd(self.shape[1], self.shape[0]):
+ maxDP = np.maximum(maxDP, self[Rx, Ry, :, :])
+
+ return maxDP
+
+ # handle average image
+ if axis == (2, 3):
+ maxImg = np.zeros((self.shape[0], self.shape[1]))
+ for Ry, Rx in tqdmnd(self.shape[1], self.shape[0]):
+ maxImg[Rx, Ry] = np.max(self[Rx, Ry, :, :])
+ return maxImg
+
+ # ====== READING FROM BINARY AND NOISE REDUCTION ======#
+ def _attach_to_files(self):
+ self._bin_files = np.empty(8, dtype=object)
+ for i in range(8):
+ binName = self._bin_prefix + str(i + 1) + ".bin"
+
+ # Synchronize to the magic sync word
+ # First, open the file in binary mode and read ~1 MB
+ with open(binName, "rb") as f:
+ s = f.read(1_000_000)
+
+ # Scan the chunk and find everywhere the sync word appears
+ sync = [
+ s.find(b"\xff\xff\x00\x55"),
+ ]
+ while sync[-1] >= 0:
+ sync.append(s.find(b"\xff\xff\x00\x55", sync[-1] + 1))
+
+ # Since the sync word can conceivably occur within the data region,
+ # check that there is another sync word 22360 bytes away
+ sync_idx = 0
+ while 0 not in [s - sync[sync_idx] - 22360 for s in sync]:
+ sync_idx += 1
+
+ if sync_idx > 0:
+ print(
+ f"Beginning file {i} at offset {sync[sync_idx]} due to incomplete data block!"
+ )
+
+ # Start the memmap at the offset of the sync byte
+ self._bin_files[i] = np.memmap(
+ binName,
+ dtype=self._stripe_dtype,
+ mode="r",
+ shape=(self._guess_number_frames(),),
+ offset=sync[sync_idx],
+ )
+
+ def _find_offsets(self):
+ # first, line up the block counts (LiberTEM calls this sync_sectors)
+ if self.sync_block_IDs:
+ print("Synchronizing block IDs.")
+ first_blocks = np.zeros((8,), dtype=np.uint32)
+ for i in range(8):
+ binfile = self._bin_files[i]
+ first_blocks[i] = binfile[0]["block"]
+
+ # find the first frame in each with the starting block
+ block_id = np.max(first_blocks)
+ print("First block syncs to block #", block_id)
+ for i in range(8):
+ sync = False
+ frame = 0
+ while sync is False:
+ sync = self._bin_files[i][frame]["block"] == block_id
+ if sync is False:
+ frame += 1
+ self._shutter_offsets[i] += frame
+ print("Offsets are currently ", self._shutter_offsets)
+ else:
+ print("Skipping block ID synchronization step...")
+
+ first_frame = self._bin_files[0][self._shutter_offsets[0]]["frame"]
+ # next, check if the frames are complete (the next 32 blocks should have the same block #)
+ print("Checking if first frame is complete...")
+ sync = True
+ for i in range(8):
+ stripe = self._bin_files[i][
+ self._shutter_offsets[i] : self._shutter_offsets[i] + 32
+ ]
+ for j in range(32):
+ if stripe[j]["frame"] != first_frame:
+ sync = False
+ next_frame = stripe[j]["frame"]
+
+ if sync is False:
+ # the first frame is incomplete, so we need to seek the next one
+ print(
+ f"First frame ({first_frame}) incomplete, seeking frame {next_frame}..."
+ )
+ for i in range(8):
+ sync = False
+ frame = 0
+ while sync is False:
+ sync = (
+ self._bin_files[i][self._shutter_offsets[i] + frame]["frame"]
+ == next_frame
+ )
+ if sync is False:
+ frame += 1
+ self._shutter_offsets[i] += frame
+ print("Offsets are now ", self._shutter_offsets)
+
+ # JUST TO BE SAFE, CHECK AGAIN THAT FRAME IS COMPLETE
+ print("Checking if new frame is complete...")
+ first_frame = np.max(stripe["frame"])
+ # check if the frames are complete (the next 32 blocks should have the same block #)
+ sync = True
+ for i in range(8):
+ stripe = self._bin_files[i][
+ self._shutter_offsets[i] : self._shutter_offsets[i] + 32
+ ]
+ if np.any(stripe[:]["frame"] != first_frame):
+ sync = False
+ if sync is True:
+ print("New frame is complete!")
+ else:
+ print("Next frame also incomplete!!!! Data may be corrupt?")
+
+ # in each file, find the first frame with open shutter (LiberTEM calls this sync_to_first_frame)
+ print("Synchronizing to shutter open...")
+ for i in range(8):
+ shutter = False
+ frame = 0
+ while shutter is False:
+ offset = self._shutter_offsets[i] + (frame * 32)
+ stripe = self._bin_files[i][offset : offset + 32]
+ shutter = stripe[0]["shutter"]
+ if shutter == 0:
+ frame += 1
+
+ self._shutter_offsets[i] += frame * 32
+
+ def _grab_frame(self, frame):
+ fullImage = np.zeros([1860, 2048], dtype=np.int16)
+ for ii in range(8):
+ xOffset = ii * 256 # the x location of the sector for each BIN file
+ # read a set of stripes:
+ start = self._shutter_offsets[ii] + (frame * 32)
+ np.copyto(self._temp, self._bin_files[ii][start : start + 32])
+
+ if np.any(self._temp[:]["sync"] != 0xFFFF0055):
+ print(
+ "The binary file is unsynchronized and cannot be read. You must use Digital Micrograph to extract to *.dm4."
+ )
+ break # stop reading if the sync byte is not correct. Ideally, this would read the next byte, etc... until this exact value is found
+ # parse the stripes
+ for jj in range(0, 32):
+ coords = self._temp[jj][
+ "coords"
+ ] # first x, first y, last x, last y; ref to 0;inclusive;should indicate 16x930 pixels
+ # place the data in the image
+ fullImage[
+ coords[1] : coords[3] + 1,
+ coords[0] + xOffset : coords[2] + xOffset + 1,
+ ] = _convert_uint12(self._temp[jj]["data"]).reshape([930, 16])
+ return fullImage
+
+ @staticmethod
+ def _subtract_readout_noise(DP):
+ # subtract readout noise using the hidden stripes
+ darkref = np.floor_divide(
+ np.sum(DP[1792:, :], axis=0, dtype=np.int16),
+ np.int16(1860 - 1792),
+ dtype=np.int16,
+ )
+ DP -= darkref[np.newaxis, :]
+
+ # Handle the user specifying a dark reference (fix the size and make sure auto gets turned off)
+ @property
+ def dark_reference(self):
+ return self._user_dark_reference[:1792, :1920]
+
+ @dark_reference.setter
+ def dark_reference(self, dr):
+ assert dr.shape == (
+ 1792,
+ 1920,
+ ), "Dark reference must be the size of an active frame"
+ # assert dr.dtype == np.uint16, "Dark reference must be 16 bit unsigned"
+ self._user_dark_reference = np.zeros((1860, 2048), dtype=np.int16)
+ self._user_dark_reference[:1792, :1920] = dr
+
+ # disable auto noise reduction
+ self._hidden_stripe_noise_reduction = False
+ self._user_noise_reduction = True
+
+ @dark_reference.deleter
+ def dark_reference(self):
+ del self._user_dark_reference
+ self._user_noise_reduction = False
+
+ # ======== UTILITY FUNCTIONS ========#
+ def _parse_slices(self, i, mode):
+ assert len(i) == 2, "Wrong size input"
+
+ if mode == "real":
+ xMax = self.shape[0] # R_Nx
+ yMax = self.shape[1] # R_Ny
+ elif mode == "diffraction":
+ xMax = self.shape[2] # Q_Nx
+ yMax = self.shape[3] # Q_Ny
+ else:
+ raise ValueError("incorrect slice mode")
+
+ if isinstance(i[0], slice):
+ xInds = np.arange(i[0].start or 0, (i[0].stop or xMax), (i[0].step or 1))
+ else:
+ xInds = i[0]
+
+ if isinstance(i[1], slice):
+ yInds = np.arange(i[1].start or 0, (i[1].stop or yMax), (i[1].step or 1))
+ else:
+ yInds = i[1]
+
+ if mode == "diffraction":
+ x, y = np.meshgrid(xInds, yInds, indexing="ij")
+ elif mode == "real":
+ x, y = np.meshgrid(xInds, yInds, indexing="xy")
+ return x, y
+
+ def _guess_number_frames(self):
+ import os
+
+ nbytes = np.array(
+ [os.path.getsize(self._bin_prefix + f"{n}.bin") for n in range(1, 9)]
+ )
+ return np.min(nbytes) // 0x5758
+
+ def _write_to_hdf5(self, group):
+ """
+ Write the entire dataset to an HDF5 file.
+ group should be an HDF5 Group object.
+ ( This function is normally called via py4DSTEM.file.io.save() )
+ """
+ dset = group.create_dataset("data", (self.shape), "i2")
+
+ for sy in range(self.shape[1]):
+ for sx in range(self.shape[0]):
+ dset[sx, sy, :, :] = self[sx, sy, :, :]
+
+ return dset
+
+
+# ======= UTILITIES OUTSIDE THE CLASS ======#
+import sys
+
+if "numba" in sys.modules:
+
+ @nb.njit(nb.int16[::1](nb.uint8[::1]), fastmath=False, parallel=False)
+ def _convert_uint12(data_chunk):
+ """
+ data_chunk is a contiguous 1D array of uint8 data,
+ e.g. data_chunk = np.frombuffer(data_chunk, dtype=np.uint8)
+ """
+
+ # ensure that the data_chunk has the right length
+ assert np.mod(data_chunk.shape[0], 3) == 0
+
+ out = np.empty(data_chunk.shape[0] // 3 * 2, dtype=np.uint16)
+
+ for i in nb.prange(data_chunk.shape[0] // 3):
+ fst_uint8 = np.uint16(data_chunk[i * 3])
+ mid_uint8 = np.uint16(data_chunk[i * 3 + 1])
+ lst_uint8 = np.uint16(data_chunk[i * 3 + 2])
+
+ out[i * 2] = (
+ fst_uint8 | (mid_uint8 & 0x0F) << 8
+ ) # (fst_uint8 << 4) + (mid_uint8 >> 4)
+ out[i * 2 + 1] = (
+ mid_uint8 & 0xF0
+ ) >> 4 | lst_uint8 << 4 # ((mid_uint8 % 16) << 8) + lst_uint8
+
+ DP = out.astype(np.int16)
+ return DP
+
+else:
+
+ def _convert_uint12(data_chunk):
+ """
+ data_chunk is a contiguous 1D array of uint8 data,
+ e.g. data_chunk = np.frombuffer(data_chunk, dtype=np.uint8)
+ """
+
+ # ensure that the data_chunk has the right length
+ assert np.mod(data_chunk.shape[0], 3) == 0
+
+ out = np.empty(data_chunk.shape[0] // 3 * 2, dtype=np.uint16)
+
+ for i in range(data_chunk.shape[0] // 3):
+ fst_uint8 = np.uint16(data_chunk[i * 3])
+ mid_uint8 = np.uint16(data_chunk[i * 3 + 1])
+ lst_uint8 = np.uint16(data_chunk[i * 3 + 2])
+
+ out[i * 2] = (
+ fst_uint8 | (mid_uint8 & 0x0F) << 8
+ ) # (fst_uint8 << 4) + (mid_uint8 >> 4)
+ out[i * 2 + 1] = (
+ mid_uint8 & 0xF0
+ ) >> 4 | lst_uint8 << 4 # ((mid_uint8 % 16) << 8) + lst_uint8
+
+ DP = out.astype(np.int16)
+ return DP
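+
+
+# Worked example of the 12-bit unpacking above: each group of 3 bytes packs
+# two 12-bit values. For data_chunk = [0xAB, 0xCD, 0xEF]:
+# out[0] = 0xAB | (0xCD & 0x0F) << 8 = 0x0DAB
+# out[1] = (0xCD & 0xF0) >> 4 | 0xEF << 4 = 0x0EFC
+#
+# vals = _convert_uint12(np.array([0xAB, 0xCD, 0xEF], dtype=np.uint8))
+# assert list(vals) == [0x0DAB, 0x0EFC]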
diff --git a/py4DSTEM/io/filereaders/read_abTEM.py b/py4DSTEM/io/filereaders/read_abTEM.py
new file mode 100644
index 000000000..805439023
--- /dev/null
+++ b/py4DSTEM/io/filereaders/read_abTEM.py
@@ -0,0 +1,81 @@
+import h5py
+from py4DSTEM.data import DiffractionSlice, RealSlice
+from py4DSTEM.datacube import DataCube
+
+
+def read_abTEM(
+ filename,
+ mem="RAM",
+ binfactor: int = 1,
+):
+ """
+ File reader for abTEM datasets
+ Args:
+ filename: str with path to file
+ mem (str): Must be "RAM" or "MEMMAP". Specifies how the data is
+ loaded; "RAM" transfer the data from storage to RAM, while "MEMMAP"
+ leaves the data in storage and creates a memory map which points to
+ the diffraction patterns, allowing them to be retrieved individually
+ from storage.
+ binfactor (int): Diffraction space binning factor for bin-on-load.
+
+ Returns:
+ DataCube
+ """
+ assert mem == "RAM", "read_abTEM does not support memory mapping"
+ assert binfactor == 1, "abTEM files can only be read at full resolution"
+
+ with h5py.File(filename, "r") as f:
+ datasets = {}
+ for key in f.keys():
+ datasets[key] = f.get(key)[()]
+
+ data = datasets["array"]
+
+ sampling = datasets["sampling"]
+ units = datasets["units"]
+
+ assert len(data.shape) in (2, 4), "the abTEM reader supports only 4D and 2D data"
+
+ if len(data.shape) == 4:
+ datacube = DataCube(data=data)
+
+ datacube.calibration.set_R_pixel_size(sampling[0])
+ if sampling[0] != sampling[1]:
+ print(
+ "Warning: py4DSTEM currently only handles uniform x,y sampling. Setting sampling with x calibration"
+ )
+ datacube.calibration.set_Q_pixel_size(sampling[2])
+ if sampling[2] != sampling[3]:
+ print(
+ "Warning: py4DSTEM currently only handles uniform qx,qy sampling. Setting sampling with qx calibration"
+ )
+
+ if units[0] == b"\xc3\x85":
+ datacube.calibration.set_R_pixel_units("A")
+ else:
+ datacube.calibration.set_R_pixel_units(units[0].decode("utf-8"))
+
+ datacube.calibration.set_Q_pixel_units(units[2].decode("utf-8"))
+
+ return datacube
+
+ else:
+ if units[0] == b"mrad":
+ diffraction = DiffractionSlice(data=data)
+ if sampling[0] != sampling[1]:
+ print(
+ "Warning: py4DSTEM currently only handles uniform qx,qy sampling. Setting sampling with x calibration"
+ )
+ diffraction.calibration.set_Q_pixel_units(units[0].decode("utf-8"))
+ diffraction.calibration.set_Q_pixel_size(sampling[0])
+ return diffraction
+ else:
+ image = RealSlice(data=data)
+ if sampling[0] != sampling[1]:
+ print(
+ "Warning: py4DSTEM currently only handles uniform x,y sampling. Setting sampling with x calibration"
+ )
+ image.calibration.set_Q_pixel_units("A")
+ image.calibration.set_Q_pixel_size(sampling[0])
+ return image
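+
+
+# A minimal usage sketch (hypothetical path): the return type depends on the
+# stored data - a DataCube for 4D arrays, and a DiffractionSlice or RealSlice
+# for 2D arrays, depending on the stored units:
+#
+# from py4DSTEM.io.filereaders import read_abTEM
+# data = read_abTEM("abtem_simulation.h5")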
diff --git a/py4DSTEM/io/filereaders/read_arina.py b/py4DSTEM/io/filereaders/read_arina.py
new file mode 100644
index 000000000..6f7c463d2
--- /dev/null
+++ b/py4DSTEM/io/filereaders/read_arina.py
@@ -0,0 +1,114 @@
+import h5py
+import hdf5plugin
+import numpy as np
+from py4DSTEM.datacube import DataCube
+from py4DSTEM.preprocess.utils import bin2D
+
+
+def read_arina(
+ filename,
+ scan_width=1,
+ mem="RAM",
+ binfactor: int = 1,
+ dtype_bin: float = None,
+ flatfield: np.ndarray = None,
+):
+ """
+ File reader for arina 4D-STEM datasets
+ Args:
+ filename: str with path to master file
+ scan_width: x dimension of scan
+ mem (str): Must be "RAM" or "MEMMAP". Specifies how the data is
+ loaded; "RAM" transfer the data from storage to RAM, while "MEMMAP"
+ leaves the data in storage and creates a memory map which points to
+ the diffraction patterns, allowing them to be retrieved individually
+ from storage.
+ binfactor (int): Diffraction space binning factor for bin-on-load.
+ dtype_bin(float): specify datatype for bin on load if need something
+ other than uint16
+ flatfield (np.ndarray):
+ flatfield forcorrection factors
+
+ Returns:
+ DataCube
+ """
+ assert mem == "RAM", "read_arina does not support memory mapping"
+
+ f = h5py.File(filename, "r")
+ nimages = 0
+
+ # Count the number of images in all datasets
+ for dset in f["entry"]["data"]:
+ nimages = nimages + f["entry"]["data"][dset].shape[0]
+ height = f["entry"]["data"][dset].shape[1]
+ width = f["entry"]["data"][dset].shape[2]
+ dtype = f["entry"]["data"][dset].dtype
+
+ width = width // binfactor
+ height = height // binfactor
+
+ assert (
+ nimages % scan_width == 0
+ ), "the total number of images must be an integer multiple of scan_width"
+
+ if dtype.type is np.uint32:
+ print("Dataset is uint32 but will be converted to uint16")
+ dtype = np.dtype(np.uint16)
+
+ if dtype_bin:
+ array_3D = np.empty((nimages, width, height), dtype=dtype_bin)
+ else:
+ array_3D = np.empty((nimages, width, height), dtype=dtype)
+
+ image_index = 0
+
+ if flatfield is None:
+ correction_factors = 1
+ else:
+ # Avoid division by zero: pixels with value 0 are set to 1 before computing the median-normalized correction factors
+ flatfield[flatfield == 0] = 1
+ correction_factors = np.median(flatfield) / flatfield
+
+ for dset in f["entry"]["data"]:
+ image_index = _processDataSet(
+ f["entry"]["data"][dset],
+ image_index,
+ array_3D,
+ binfactor,
+ correction_factors,
+ )
+
+ if f:
+ f.close()
+
+ scan_height = int(nimages / scan_width)
+
+ datacube = DataCube(
+ np.flip(
+ array_3D.reshape(
+ scan_width, scan_height, array_3D.data.shape[1], array_3D.data.shape[2]
+ ),
+ 0,
+ )
+ )
+
+ return datacube
+
+
+def _processDataSet(dset, start_index, array_3D, binfactor, correction_factors):
+ image_index = start_index
+ nimages_dset = dset.shape[0]
+
+ for i in range(nimages_dset):
+ if binfactor == 1:
+ array_3D[image_index] = np.multiply(
+ dset[i].astype(array_3D.dtype), correction_factors
+ )
+ else:
+ array_3D[image_index] = bin2D(
+ np.multiply(dset[i].astype(array_3D.dtype), correction_factors),
+ binfactor,
+ )
+
+ image_index = image_index + 1
+ return image_index
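+
+
+# A minimal usage sketch (hypothetical path and scan size): the master file
+# points to the data files, and scan_width sets the x dimension of the scan,
+# so e.g. a 16384-frame acquisition with scan_width=128 is reshaped to a
+# (128, 128) scan:
+#
+# from py4DSTEM.io.filereaders import read_arina
+# datacube = read_arina("STO_bench_master.h5", scan_width=128, binfactor=2)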
diff --git a/py4DSTEM/io/filereaders/read_dm.py b/py4DSTEM/io/filereaders/read_dm.py
new file mode 100644
index 000000000..617529708
--- /dev/null
+++ b/py4DSTEM/io/filereaders/read_dm.py
@@ -0,0 +1,194 @@
+# Reads a digital micrograph 4D-STEM dataset
+
+import numpy as np
+from pathlib import Path
+from ncempy.io import dm
+from emdfile import tqdmnd, Array
+
+from py4DSTEM.datacube import DataCube
+from py4DSTEM.preprocess.utils import bin2D
+
+
+def read_dm(filepath, name="dm_dataset", mem="RAM", binfactor=1, **kwargs):
+ """
+ Read a digital micrograph 4D-STEM file.
+
+ Args:
+ filepath: str or Path Path to the file
+ mem (str, optional): Specifies how the data is stored. Must be
+ "RAM", or "MEMMAP". See docstring for py4DSTEM.file.io.read. Default
+ is "RAM".
+ binfactor (int, optional): Bin the data, in diffraction space, as it's
+ loaded. See docstring for py4DSTEM.file.io.read. Default is 1.
+ metadata (bool, optional): if True, returns the file metadata as a
+ Metadata instance.
+ kwargs:
+ "dtype": a numpy dtype specifier to use for data binned on load,
+ defaults to the data's current dtype
+
+ Returns:
+ DataCube if a 4D dataset is found, else an ND Array
+ """
+
+ # open the file
+ with dm.fileDM(filepath, on_memory=False) as dmFile:
+ # loop through datasets looking for one with more than 2D
+ # This is needed because:
+ # NCEM TitanX files store 4D data in a 3D array
+ # K3 sometimes stores 2D images alongside the 4D array
+ thumbnail_count = 1 if dmFile.thumbnail else 0
+ dataset_index = 0
+ for i in range(dmFile.numObjects - thumbnail_count):
+ temp_data = dmFile.getMemmap(i)
+ if len(np.squeeze(temp_data).shape) > 2:
+ dataset_index = i
+ break
+
+ # We will only try to read pixel sizes for 4D data for now
+ pixel_size_found = False
+ if dmFile.dataShape[dataset_index + thumbnail_count] > 2:
+ # The pixel sizes of all datasets are chained together, so
+ # we have to figure out the right offset
+ try:
+ scale_offset = (
+ sum(dmFile.dataShape[:dataset_index]) + 2 * thumbnail_count
+ )
+ pixelsize = dmFile.scale[scale_offset:]
+ pixelunits = dmFile.scaleUnit[scale_offset:]
+
+ # Get the calibration pixel sizes
+ Q_pixel_size = pixelsize[0]
+ Q_pixel_units = "pixels" if pixelunits[0] == "" else pixelunits[0]
+ R_pixel_size = pixelsize[2]
+ R_pixel_units = "pixels" if pixelunits[2] == "" else pixelunits[2]
+
+ # Check that the units are sensible
+ # On microscopes that do not have live communication with the detector
+ # the calibrations can be invalid
+ if Q_pixel_units in ("nm", "µm"):
+ Q_pixel_units = "pixels"
+ Q_pixel_size = 1
+ R_pixel_units = "pixels"
+ R_pixel_size = 1
+
+ # Convert mrad to Å^-1 if possible
+ if Q_pixel_units == "mrad":
+ voltage = [
+ v
+ for t, v in dmFile.allTags.items()
+ if "Microscope Info.Voltage" in t
+ ]
+ if len(voltage) >= 1:
+ from py4DSTEM.process.utils import electron_wavelength_angstrom
+
+ wavelength = electron_wavelength_angstrom(voltage[0])
+ Q_pixel_units = "A^-1"
+ Q_pixel_size = (
+ Q_pixel_size / wavelength / 1000.0
+ ) # convert mrad to 1/Å
+ elif Q_pixel_units == "1/nm":
+ Q_pixel_units = "A^-1"
+ Q_pixel_size /= 10
+
+ pixel_size_found = True
+ except Exception:
+ pass
+
+ # Handle 3D NCEM TitanX data
+ titan_shape = _process_NCEM_TitanX_Tags(dmFile)
+
+ if mem == "RAM":
+ if binfactor == 1:
+ _data = dmFile.getDataset(dataset_index)["data"]
+ else:
+ # get a memory map
+ _mmap = dmFile.getMemmap(dataset_index)
+
+ # get the dtype for the binned data
+ dtype = kwargs.get("dtype", _mmap[0, 0].dtype)
+
+ if titan_shape is not None:
+ # NCEM TitanX tags were found
+ _mmap = np.reshape(_mmap, titan_shape + _mmap.shape[-2:])
+ new_shape = (
+ *_mmap.shape[:2],
+ _mmap.shape[2] // binfactor,
+ _mmap.shape[3] // binfactor,
+ )
+ _data = np.zeros(new_shape, dtype=dtype)
+
+ for rx, ry in tqdmnd(*_data.shape[:2]):
+ _data[rx, ry] = bin2D(_mmap[rx, ry], binfactor, dtype=dtype)
+
+ elif (mem, binfactor) == ("MEMMAP", 1):
+ _data = dmFile.getMemmap(dataset_index)
+ else:
+ raise Exception(
+ "Memory mapping and on-load binning together is not supported. Either set binfactor=1 or mem='RAM'."
+ )
+
+ if titan_shape is not None:
+ # NCEM TitanX tags were found
+ _data = np.reshape(_data, titan_shape + _data.shape[-2:])
+
+ if len(_data.shape) == 4:
+ data = DataCube(_data, name=name)
+ if pixel_size_found:
+ try:
+ data.calibration.set_Q_pixel_size(Q_pixel_size * binfactor)
+ data.calibration.set_Q_pixel_units(Q_pixel_units)
+ data.calibration.set_R_pixel_size(R_pixel_size)
+ data.calibration.set_R_pixel_units(R_pixel_units)
+ except Exception as err:
+ print(
+ f"Setting pixel sizes of the datacube failed with error {err}"
+ )
+ else:
+ data = Array(_data, name=name)
+
+ return data
+
+
+def _process_NCEM_TitanX_Tags(dmFile, dc=None):
+ """
+ Check the metadata in the DM File for certain tags which are added by
+ the NCEM TitanX, and reshape the 3D datacube into 4D using these tags.
+ Also fixes the two-pixel roll issue present in TitanX data. If no datacube
+ is passed, return R_Nx and R_Ny
+ """
+ scanx = [v for k, v in dmFile.allTags.items() if "4D STEM Tags.Scan shape X" in k]
+ scany = [v for k, v in dmFile.allTags.items() if "4D STEM Tags.Scan shape Y" in k]
+ if len(scanx) >= 1 and len(scany) >= 1:
+ # TitanX tags found!
+ R_Nx = int(scany[0]) # need to flip x/y
+ R_Ny = int(scanx[0])
+
+ if dc is not None:
+ dc.set_scan_shape(R_Nx, R_Ny)
+ dc.data = np.roll(dc.data, shift=-2, axis=1)
+ else:
+ return R_Nx, R_Ny
+
+
+# def get_metadata_from_dmFile(fp):
+# """ Accepts a filepath to a dm file and returns a Metadata instance
+# """
+# metadata = Metadata()
+#
+# with dm.fileDM(fp, on_memory=False) as dmFile:
+# pixelSizes = dmFile.scale
+# pixelUnits = dmFile.scaleUnit
+# assert pixelSizes[0] == pixelSizes[1], "Rx and Ry pixel sizes don't match"
+# assert pixelSizes[2] == pixelSizes[3], "Qx and Qy pixel sizes don't match"
+# assert pixelUnits[0] == pixelUnits[1], "Rx and Ry pixel units don't match"
+# assert pixelUnits[2] == pixelUnits[3], "Qx and Qy pixel units don't match"
+# for i in range(len(pixelUnits)):
+# if pixelUnits[i] == "":
+# pixelUnits[i] = "pixels"
+# metadata.set_R_pixel_size__microscope(pixelSizes[0])
+# metadata.set_R_pixel_size_units__microscope(pixelUnits[0])
+# metadata.set_Q_pixel_size__microscope(pixelSizes[2])
+# metadata.set_Q_pixel_size_units__microscope(pixelUnits[2])
+#
+# return metadata
diff --git a/py4DSTEM/io/filereaders/read_mib.py b/py4DSTEM/io/filereaders/read_mib.py
new file mode 100644
index 000000000..7456bd594
--- /dev/null
+++ b/py4DSTEM/io/filereaders/read_mib.py
@@ -0,0 +1,355 @@
+# Read the mib file captured using the Merlin detector
+# Author: Tara Mishra, tara.matsci@gmail.
+# Based on the PyXEM load_mib module https://github.com/pyxem/pyxem/blob/563a3bb5f3233f46cd3e57f3cd6f9ddf7af55ad0/pyxem/utils/io_utils.py
+
+import numpy as np
+from py4DSTEM.datacube import DataCube
+import os
+
+
+def load_mib(
+ file_path,
+ mem="MEMMAP",
+ binfactor=1,
+ reshape=True,
+ flip=True,
+ scan=(256, 256),
+ **kwargs,
+):
+ """
+ Read a MIB file and return as py4DSTEM DataCube.
+
+ The scan size is not encoded in the MIB metadata - by default it is
+ set to (256,256), and can be modified by passing the keyword `scan`.
+ """
+
+ assert binfactor == 1, "MIB does not support bin-on-load... yet?"
+
+ # Get scan info from kwargs
+ header = parse_hdr(file_path)
+ width = header["width"]
+ height = header["height"]
+ width_height = width * height
+
+ data = get_mib_memmap(file_path)
+ depth = get_mib_depth(header, file_path)
+ hdr_bits = get_hdr_bits(header)
+
+ if header["Counter Depth (number)"] == 1:
+ # RAW 1 bit data: the header bits are written as uint8 but the frames
+ # are binary and need to be unpacked as such.
+ data = data.reshape(-1, int(width_height / 8 + hdr_bits))
+ data = data[:, hdr_bits:]
+ # get the shape before unpackbits
+ s0 = data.shape[0]
+ s1 = data.shape[1]
+ data = np.unpackbits(data)
+ data = data.reshape(s0, s1 * 8)
+ else:
+ data = data.reshape(-1, int(width_height + hdr_bits))
+ data = data[:, hdr_bits:]
+
+ if header["raw"] == "MIB":
+ data = data.reshape(depth, width, height)
+ else:
+ print("Data type not supported as MIB reader")
+
+ if reshape:
+ data = data.reshape(scan[0], scan[1], width, height)
+
+ if mem == "RAM":
+ data = np.array(data) # Load entire dataset into RAM
+
+ py4dstem_data = DataCube(data=data)
+ return py4dstem_data
+
+
+def manageHeader(fname):
+ """Get necessary information from the header of the .mib file.
+ Parameters
+ ----------
+ fname : str
+ Filename for header file.
+ Returns
+ -------
+ hdr : tuple
+ (DataOffset,NChips,PixelDepthInFile,sensorLayout,Timestamp,shuttertime,bitdepth)
+ Examples
+ --------
+ #Output for 6bit 256*256 data:
+ #(768, 4, 'R64', '2x2', '2019-06-14 11:46:12.607836', 0.0002, 6)
+ #Output for 12bit single frame, not RAW:
+ #(768, 4, 'U16', '2x2', '2019-06-06 11:12:42.001309', 0.001, 12)
+ """
+ Header = str()
+ with open(fname, "rb") as input:
+ aByte = input.read(1)
+ Header += str(aByte.decode("ascii"))
+ # This gets rid of the header
+ while aByte and ord(aByte) != 0:
+ aByte = input.read(1)
+ Header += str(aByte.decode("ascii"))
+
+ elements_in_header = Header.split(",")
+
+ DataOffset = int(elements_in_header[2])
+
+ NChips = int(elements_in_header[3])
+
+ PixelDepthInFile = elements_in_header[6]
+ sensorLayout = elements_in_header[7].strip()
+ Timestamp = elements_in_header[9]
+ shuttertime = float(elements_in_header[10])
+
+ if PixelDepthInFile == "R64":
+ bitdepth = int(elements_in_header[18]) # RAW
+ elif PixelDepthInFile == "U16":
+ bitdepth = 12
+ elif PixelDepthInFile == "U08":
+ bitdepth = 6
+ elif PixelDepthInFile == "U32":
+ bitdepth = 24
+
+ hdr = (
+ DataOffset,
+ NChips,
+ PixelDepthInFile,
+ sensorLayout,
+ Timestamp,
+ shuttertime,
+ bitdepth,
+ )
+
+ return hdr
+
+
+def parse_hdr(fp):
+ """Parse information from mib file header info from _manageHeader function.
+ Parameters
+ ----------
+ fp : str
+ Filepath to .mib file.
+ Returns
+ -------
+ hdr_info : dict
+ Dictionary containing header info extracted from .mib file.
+ The entries of the dictionary are as follows:
+ 'width': int
+ pixels, detector number of pixels in x direction,
+ 'height': int
+ pixels detector number of pixels in y direction,
+ 'Assembly Size': str
+ configuration of the detector chips, e.g. '2x2' for quad,
+ 'offset': int
+ number of characters in the header before the first frame starts,
+ 'data-type': str
+ always 'unsigned',
+ 'data-length': str
+ identifying dtype,
+ 'Counter Depth (number)': int
+ counter bit depth,
+ 'raw': str
+ regular binary 'MIB' or raw binary 'R64',
+ 'byte-order': str
+ always 'dont-care',
+ 'record-by': str
+ 'image' or 'vector' - only 'image' encountered,
+ 'title': str
+ path of the mib file without extension, e.g. '/dls/e02/data/2020/cm26481-1/Merlin/testing/20200204 115306/test',
+ 'date': str
+ date created, e.g. '20200204',
+ 'time': str
+ time created, e.g. '11:53:32.295336',
+ 'data offset': int
+ number of characters at the header.
+ """
+ hdr_info = {}
+
+ read_hdr = manageHeader(fp)
+
+ # Set the array size of the chip
+
+ if read_hdr[3] == "1x1":
+ hdr_info["width"] = 256
+ hdr_info["height"] = 256
+ elif read_hdr[3] == "2x2":
+ hdr_info["width"] = 512
+ hdr_info["height"] = 512
+
+ hdr_info["Assembly Size"] = read_hdr[3]
+
+ # Set mib offset
+ hdr_info["offset"] = read_hdr[0]
+ # Set data-type
+ hdr_info["data-type"] = "unsigned"
+ # Set data-length
+ if read_hdr[6] == "1":
+ # Binary data recorded as 8 bit numbers
+ hdr_info["data-length"] = "8"
+ else:
+ # Changes 6 to 8 , 12 to 16 and 24 to 32 bit
+ cd_int = int(read_hdr[6])
+ hdr_info["data-length"] = str(int((cd_int + cd_int / 3)))
+
+ hdr_info["Counter Depth (number)"] = int(read_hdr[6])
+ if read_hdr[2] == "R64":
+ hdr_info["raw"] = "R64"
+ else:
+ hdr_info["raw"] = "MIB"
+ # Set byte order
+ hdr_info["byte-order"] = "dont-care"
+ # Set record by to stack of images
+ hdr_info["record-by"] = "image"
+
+ # Set title to file name
+ hdr_info["title"] = fp.split(".")[0]
+ # Set time and date
+ # Adding the try argument to accommodate the new hdr formatting as of April 2018
+ try:
+ year, month, day_time = read_hdr[4].split("-")
+ day, time = day_time.split(" ")
+ hdr_info["date"] = year + month + day
+ hdr_info["time"] = time
+ except Exception:
+ day, month, year_time = read_hdr[4].split("/")
+ year, time = year_time.split(" ")
+ hdr_info["date"] = year + month + day
+ hdr_info["time"] = time
+
+ hdr_info["data offset"] = read_hdr[0]
+
+ return hdr_info
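+
+
+# Worked example of the counter-depth -> data-length mapping above: the
+# stored bit width is cd + cd/3, so counter depths 6, 12 and 24 map to
+# 8, 16 and 32 bit dtypes, while 1-bit counting data is recorded as
+# 8-bit numbers:
+#
+# for cd in (6, 12, 24):
+# print(cd, int(cd + cd / 3)) # -> 6 8, 12 16, 24 32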
+
+
+def get_mib_memmap(fp, mmap_mode="r"):
+ """Reads the binary mib file into a numpy memmap object and returns as dask array object.
+ Parameters
+ ----------
+ fp: str
+ MIB file name / path
+ mmap_mode: str
+ memmpap read mode - default is 'r'
+ Returns
+ -------
+ data_da: dask array
+ data as a dask array object
+ """
+ hdr_info = parse_hdr(fp)
+ data_length = hdr_info["data-length"]
+ data_type = hdr_info["data-type"]
+ endian = hdr_info["byte-order"]
+ read_offset = 0
+
+ if data_type == "signed":
+ data_type = "int"
+ elif data_type == "unsigned":
+ data_type = "uint"
+ elif data_type == "float":
+ pass
+ else:
+ raise TypeError('Unknown "data-type" string.')
+
+ # mib data always big-endian
+ endian = ">"
+ data_type += str(int(data_length))
+ # uint1 not a valid dtype
+ if data_type == "uint1":
+ data_type = "uint8"
+ data_type = np.dtype(data_type)
+ else:
+ data_type = np.dtype(data_type)
+ data_type = data_type.newbyteorder(endian)
+
+ data_mem = np.memmap(fp, offset=read_offset, dtype=data_type, mode=mmap_mode)
+ return data_mem
+
+
+def get_mib_depth(hdr_info, fp):
+ """Determine the total number of frames based on .mib file size.
+ Parameters
+ ----------
+ hdr_info : dict
+ Dictionary containing header info extracted from .mib file.
+ fp : filepath
+ Path to .mib file.
+ Returns
+ -------
+ depth : int
+ Number of frames in the stack
+ """
+ # Define standard frame sizes for quad and single medipix chips
+ if hdr_info["Assembly Size"] == "2x2":
+ mib_file_size_dict = {
+ "1": 33536,
+ "6": 262912,
+ "12": 525056,
+ "24": 1049344,
+ }
+ if hdr_info["Assembly Size"] == "1x1":
+ mib_file_size_dict = {
+ "1": 8576,
+ "6": 65920,
+ "12": 131456,
+ "24": 262528,
+ }
+
+ file_size = os.path.getsize(fp[:-3] + "mib")
+ if hdr_info["raw"] == "R64":
+ single_frame = mib_file_size_dict.get(str(hdr_info["Counter Depth (number)"]))
+ depth = int(file_size / single_frame)
+ elif hdr_info["raw"] == "MIB":
+ if hdr_info["Counter Depth (number)"] == "1":
+ # 1 bit and 6 bit non-raw frames have the same size
+ single_frame = mib_file_size_dict.get("6")
+ depth = int(file_size / single_frame)
+ else:
+ single_frame = mib_file_size_dict.get(
+ str(hdr_info["Counter Depth (number)"])
+ )
+ depth = int(file_size / single_frame)
+
+ return depth
+
+
+def get_hdr_bits(hdr_info):
+ """Gets the number of character bits for the header for each frame given the data type.
+ Parameters
+ ----------
+ hdr_info: dict
+ output of the parse_hdr function
+ Returns
+ -------
+ hdr_bits: int
+ number of elements of the detected dtype occupied by the header
+ """
+ data_length = hdr_info["data-length"]
+ data_type = hdr_info["data-type"]
+
+ if data_type == "signed":
+ data_type = "int"
+ elif data_type == "unsigned":
+ data_type = "uint"
+ elif data_type == "float":
+ pass
+ else:
+ raise TypeError('Unknown "data-type" string.')
+
+ # mib data always big-endian
+ endian = ">"
+ data_type += str(int(data_length))
+ # uint1 not a valid dtype
+ if data_type == "uint1":
+ data_type = "uint8"
+ data_type = np.dtype(data_type)
+ else:
+ data_type = np.dtype(data_type)
+ data_type = data_type.newbyteorder(endian)
+
+ if data_length == "1":
+ hdr_multiplier = 1
+ else:
+ hdr_multiplier = (int(data_length) / 8) ** -1
+
+ hdr_bits = int(hdr_info["data offset"] * hdr_multiplier)
+
+ return hdr_bits
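+
+
+# Worked example of the header-size arithmetic above: for a 768-character
+# data offset and 12-bit counter depth (stored as uint16, data-length "16"),
+# hdr_multiplier = (16 / 8) ** -1 = 0.5, so hdr_bits = int(768 * 0.5) = 384
+# elements of the memmap's uint16 dtype precede each frame.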
diff --git a/py4DSTEM/io/google_drive_downloader.py b/py4DSTEM/io/google_drive_downloader.py
new file mode 100644
index 000000000..5b53f19ae
--- /dev/null
+++ b/py4DSTEM/io/google_drive_downloader.py
@@ -0,0 +1,283 @@
+import gdown
+import os
+import warnings
+
+
+### File IDs
+
+# single files
+file_ids = {
+ "sample_diffraction_pattern": (
+ "a_diffraction_pattern.h5",
+ "1ymYMnuDC0KV6dqduxe2O1qafgSd0jjnU",
+ ),
+ "Au_sim": (
+ "Au_sim.h5",
+ "1PmbCYosA1eYydWmmZebvf6uon9k_5g_S",
+ ),
+ "carbon_nanotube": (
+ "carbon_nanotube.h5",
+ "1bHv3u61Cr-y_GkdWHrJGh1lw2VKmt3UM",
+ ),
+ "Si_SiGe_exp": (
+ "Si_SiGe_exp.h5",
+ "1fXNYSGpe6w6E9RBA-Ai_owZwoj3w8PNC",
+ ),
+ "Si_SiGe_probe": (
+ "Si_SiGe_probe.h5",
+ "141Tv0YF7c5a-MCrh3CkY_w4FgWtBih80",
+ ),
+ "Si_SiGe_EELS_strain": (
+ "Si_SiGe_EELS_strain.h5",
+ "1klkecq8IuEOYB-bXchO7RqOcgCl4bmDJ",
+ ),
+ "AuAgPd_wire": (
+ "AuAgPd_wire.h5",
+ "1OQYW0H6VELsmnLTcwicP88vo2V5E3Oyt",
+ ),
+ "AuAgPd_wire_probe": (
+ "AuAgPd_wire_probe.h5",
+ "17OduUKpxVBDumSK_VHtnc2XKkaFVN8kq",
+ ),
+ "polycrystal_2D_WS2": (
+ "polycrystal_2D_WS2.h5",
+ "1AWB3-UTPiTR9dgrEkNFD7EJYsKnbEy0y",
+ ),
+ "WS2cif": (
+ "WS2.cif",
+ "13zBl6aFExtsz_sew-L0-_ALYJfcgHKjo",
+ ),
+ "polymers": (
+ "polymers.h5",
+ "1lK-TAMXN1MpWG0Q3_4vss_uEZgW2_Xh7",
+ ),
+ "vac_probe": (
+ "vac_probe.h5",
+ "1QTcSKzZjHZd1fDimSI_q9_WsAU25NIXe",
+ ),
+ "small_dm3_3Dstack": ("small_dm3_3Dstack.dm3", "1B-xX3F65JcWzAg0v7f1aVwnawPIfb5_o"),
+ "FCU-Net": (
+ "filename.name",
+ "1-KX0saEYfhZ9IJAOwabH38PCVtfXidJi",
+ ),
+ "small_datacube": (
+ "small_datacube.dm4",
+ # TODO - change this file to something smaller - ideally e.g. shape (4,8,256,256) ~= 4.2MB'
+ "1QTcSKzZjHZd1fDimSI_q9_WsAU25NIXe",
+ ),
+ "legacy_v0.9": (
+ "legacy_v0.9_simAuNanoplatelet_bin.h5",
+ "1AIRwpcj87vK3ubLaKGj1UiYXZByD2lpu",
+ ),
+ "legacy_v0.13": ("legacy_v0.13.h5", "1VEqUy0Gthama7YAVkxwbjQwdciHpx8rA"),
+ "legacy_v0.14": (
+ "legacy_v0.14.h5",
+ "1eOTEJrpHnNv9_DPrWgZ4-NTN21UbH4aR",
+ ),
+ "test_realslice_io": ("test_realslice_io.h5", "1siH80-eRJwG5R6AnU4vkoqGWByrrEz1y"),
+ "test_arina_master": (
+ "STO_STEM_bench_20us_master.h5",
+ "1q_4IjFuWRkw5VM84NhxrNTdIq4563BOC",
+ ),
+ "test_arina_01": (
+ "STO_STEM_bench_20us_data_000001.h5",
+ "1_3Dbm22-hV58iffwK9x-3vqJUsEXZBFQ",
+ ),
+ "test_arina_02": (
+ "STO_STEM_bench_20us_data_000002.h5",
+ "1x29RzHLnCzP0qthLhA1kdlUQ09ENViR8",
+ ),
+ "test_arina_03": (
+ "STO_STEM_bench_20us_data_000003.h5",
+ "1qsbzdEVD8gt4DYKnpwjfoS_Mg4ggObAA",
+ ),
+ "test_arina_04": (
+ "STO_STEM_bench_20us_data_000004.h5",
+ "1Lcswld0Y9fNBk4-__C9iJbc854BuHq-h",
+ ),
+ "test_arina_05": (
+ "STO_STEM_bench_20us_data_000005.h5",
+ "13YTO2ABsTK5nObEr7RjOZYCV3sEk3gt9",
+ ),
+ "test_arina_06": (
+ "STO_STEM_bench_20us_data_000006.h5",
+ "1RywPXt6HRbCvjgjSuYFf60QHWlOPYXwy",
+ ),
+ "test_arina_07": (
+ "STO_STEM_bench_20us_data_000007.h5",
+ "1GRoBecCvAUeSIujzsPywv1vXKSIsNyoT",
+ ),
+ "test_arina_08": (
+ "STO_STEM_bench_20us_data_000008.h5",
+ "1sTFuuvgKbTjZz1lVUfkZbbTDTQmwqhuU",
+ ),
+ "test_arina_09": (
+ "STO_STEM_bench_20us_data_000009.h5",
+ "1JmBiMg16iMVfZ5wz8z_QqcNPVRym1Ezh",
+ ),
+ "test_arina_10": (
+ "STO_STEM_bench_20us_data_000010.h5",
+ "1_90xAfclNVwMWwQ-YKxNNwBbfR1nfHoB",
+ ),
+ "test_strain": (
+ "downsample_Si_SiGe_analysis_braggdisks_cal.h5",
+ "1bYgDdAlnWHyFmY-SwN3KVpMutWBI5MhP",
+ ),
+}
+
+# collections of files
+collection_ids = {
+ "tutorials": (
+ "Au_sim",
+ "carbon_nanotube",
+ "Si_SiGe_exp",
+ "Si_SiGe_probe",
+ "Si_SiGe_EELS_strain",
+ "AuAgPd_wire",
+ "AuAgPd_wire_probe",
+ "polycrystal_2D_WS2",
+ "WS2cif",
+ "polymers",
+ "vac_probe",
+ ),
+ "test_io": (
+ "small_dm3_3Dstack",
+ "vac_probe",
+ "legacy_v0.9",
+ "legacy_v0.13",
+ "legacy_v0.14",
+ "test_realslice_io",
+ ),
+ "test_arina": (
+ "test_arina_master",
+ "test_arina_01",
+ "test_arina_02",
+ "test_arina_03",
+ "test_arina_04",
+ "test_arina_05",
+ "test_arina_06",
+ "test_arina_07",
+ "test_arina_08",
+ "test_arina_09",
+ "test_arina_10",
+ ),
+ "test_braggvectors": ("Au_sim",),
+ "strain": ("test_strain",),
+}
+
+
+def get_sample_file_ids():
+ return {"files": file_ids.keys(), "collections": collection_ids.keys()}
+
+
+### Downloader
+
+
+def gdrive_download(
+ id_,
+ destination=None,
+ overwrite=False,
+ filename=None,
+ verbose=True,
+):
+ """
+ Downloads a file or collection of files from google drive.
+
+ Parameters
+ ----------
+ id_ : str
+ File ID for the desired file. May be either a key from the list
+ of files and collections of files accessible at get_sample_file_ids(),
+ or a complete url, or the portions of a google drive link specifying
+ its google file ID, i.e. for the address
+ https://drive.google.com/file/d/1bHv3u61Cr-y_GkdWHrJGh1lw2VKmt3UM/,
+ the id string '1bHv3u61Cr-y_GkdWHrJGh1lw2VKmt3UM'.
+ destination : None or str
+ The location files are downloaded to. If a collection of files has been
+ specified, creates a new directory at the specified destination and
+ downloads the collection there. If None, downloads to the current
+ working directory. Otherwise, must be a string or Path pointing to
+ a valid location on the filesystem.
+ overwrite : bool
+ Turns overwrite protection on/off.
+ filename : None or str
+ Used only if `id_` is a url or gdrive id. In these cases, specifies
+ the name of the output file. If left as None, saves to
+ 'gdrivedownload.file'. If `id_` is a key from the sample file id list,
+ this parameter is ignored.
+ verbose : bool
+ Toggles verbose output
+ """
+ # parse destination
+ if destination is None:
+ destination = os.getcwd()
+ assert os.path.exists(
+ destination
+ ), f"`destination` must exist on filesystem. Received {destination}"
+
+ # download single files
+ if id_ not in collection_ids:
+ # assign the name and id
+ kwargs = {"fuzzy": True}
+ if id_ in file_ids:
+ f = file_ids[id_]
+ filename = f[0]
+ kwargs["id"] = f[1]
+
+ # if it's not in the list of files we expect
+
+ # TODO simplify the logic here
+ else:
+ filename = "gdrivedownload.file" if filename is None else filename
+ # check if it's a url
+ if id_.startswith("http"):
+ # check the url is the correct format i.e. https://drive.google.com/uc?id=
+ # and not https://drive.google.com/file/d/
+ # if correct format
+ if "uc?id=" in id_:
+ kwargs["url"] = id_
+ # if incorrect format, strip the google ID from the URL
+ # making http/https agnostic
+ elif "drive.google.com/file/d/" in id_:
+ # warn the user that the url syntax was incorrect and that this is a guess
+ warnings.warn(
+ f"URL provided {id_} was not in the correct format https://drive.google.com/uc?id=, attempting to interpret link and download the file. Most likely a URL with this format was provided https://drive.google.com/file/d/"
+ )
+ # try stripping
+ stripped_id = id_.split("/")[-1]
+ # Google Drive file IDs currently appear to always be 33 characters long;
+ # check the length, and if the guessed ID appears malformed, warn and report the guess
+ if len(stripped_id) != 33:
+ warnings.warn(
+ f"Guessed ID {stripped_id} appears to be the wrong length (expected 33 characters); attempting download anyway"
+ )
+ kwargs["id"] = stripped_id
+ # if it's just a Google Drive ID string
+ else:
+ kwargs["id"] = id_
+
+ # download
+ kwargs["output"] = os.path.join(destination, filename)
+ if not overwrite and os.path.exists(kwargs["output"]):
+ if verbose:
+ print(f"A file already exists at {kwargs['output']}, skipping...")
+ else:
+ gdown.download(**kwargs)
+
+ # download a collection of files
+ else:
+ # set destination
+ destination = os.path.join(destination, id_)
+ if not os.path.exists(destination):
+ os.mkdir(destination)
+
+ # loop
+ for x in collection_ids[id_]:
+ file_name, file_id = file_ids[x]
+ output = os.path.join(destination, file_name)
+ # download
+ if not overwrite and os.path.exists(output):
+ if verbose:
+ print(f"A file already exists at {output}, skipping...")
+ else:
+ gdown.download(id=file_id, output=output, fuzzy=True)
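+
+
+# Illustrative usage of the downloader above -- a minimal sketch, assuming the
+# sample-file keys defined in this module are current; the destination path
+# and output filename below are hypothetical:
+#
+#   gdrive_download("test_strain")                          # single file, to cwd
+#   gdrive_download("tutorials", destination="/tmp/data")   # whole collection
+#   gdrive_download(
+#       "https://drive.google.com/file/d/1bHv3u61Cr-y_GkdWHrJGh1lw2VKmt3UM/",
+#       filename="sample.h5",
+#   )                                                       # raw Drive link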
diff --git a/py4DSTEM/io/importfile.py b/py4DSTEM/io/importfile.py
new file mode 100644
index 000000000..20a3759a2
--- /dev/null
+++ b/py4DSTEM/io/importfile.py
@@ -0,0 +1,99 @@
+# Reader functions for non-native file types
+
+import pathlib
+from os.path import exists
+from typing import Optional, Union
+
+from py4DSTEM.io.filereaders import (
+ load_mib,
+ read_abTEM,
+ read_arina,
+ read_dm,
+ read_empad,
+ read_gatan_K2_bin,
+)
+from py4DSTEM.io.parsefiletype import _parse_filetype
+
+
+def import_file(
+ filepath: Union[str, pathlib.Path],
+ mem: Optional[str] = "RAM",
+ binfactor: Optional[int] = 1,
+ filetype: Optional[str] = None,
+ **kwargs,
+):
+ """
+ Reader for non-native file formats.
+ Parses the filetype, and calls the appropriate reader.
+ Supports Gatan DM3/4, some EMPAD file versions, Gatan K2 bin/gtg, mib,
+ arina, and abTEM formats.
+
+ Args:
+ filepath (str or Path): Path to the file.
+ mem (str): Must be "RAM" or "MEMMAP". Specifies how the data is
+ loaded; "RAM" transfer the data from storage to RAM, while "MEMMAP"
+ leaves the data in storage and creates a memory map which points to
+ the diffraction patterns, allowing them to be retrieved individually
+ from storage.
+ binfactor (int): Diffraction space binning factor for bin-on-load.
+ filetype (str): Used to override automatic filetype detection.
+ Options include "dm", "empad", "gatan_K2_bin", "mib", "arina", and "abTEM".
+ **kwargs: any additional kwargs are passed to the downstream reader -
+ refer to the individual filetype reader function call signatures
+ and docstrings for more details.
+
+ Returns:
+ (DataCube or Array) returns a DataCube if 4D data is found, otherwise
+ returns an Array
+
+ """
+
+ assert isinstance(
+ filepath, (str, pathlib.Path)
+ ), f"filepath must be a string or Path, not {type(filepath)}"
+ assert exists(filepath), f"The given filepath: '{filepath}' \ndoes not exist"
+ assert mem in [
+ "RAM",
+ "MEMMAP",
+ ], 'Error: argument mem must be either "RAM" or "MEMMAP"'
+ assert isinstance(binfactor, int), "Error: argument binfactor must be an integer"
+ assert binfactor >= 1, "Error: binfactor must be >= 1"
+ if binfactor > 1:
+ assert (
+ mem != "MEMMAP"
+ ), "Error: binning is not supported for memory mapping. Either set binfactor=1 or set mem='RAM'"
+
+ filetype = _parse_filetype(filepath) if filetype is None else filetype
+
+ if filetype in ("emd", "legacy"):
+ raise Exception(
+ "EMD file or py4DSTEM detected - use py4DSTEM.read, not py4DSTEM.import_file!"
+ )
+ assert filetype in [
+ "dm",
+ "empad",
+ "gatan_K2_bin",
+ "mib",
+ "arina",
+ "abTEM"
+ # "kitware_counted",
+ ], "Error: filetype not recognized"
+
+ if filetype == "dm":
+ data = read_dm(filepath, mem=mem, binfactor=binfactor, **kwargs)
+ elif filetype == "empad":
+ data = read_empad(filepath, mem=mem, binfactor=binfactor, **kwargs)
+ elif filetype == "gatan_K2_bin":
+ data = read_gatan_K2_bin(filepath, mem=mem, binfactor=binfactor, **kwargs)
+ # elif filetype == "kitware_counted":
+ # data = read_kitware_counted(filepath, mem, binfactor, metadata=metadata, **kwargs)
+ elif filetype == "mib":
+ data = load_mib(filepath, mem=mem, binfactor=binfactor, **kwargs)
+ elif filetype == "arina":
+ data = read_arina(filepath, mem=mem, binfactor=binfactor, **kwargs)
+ elif filetype == "abTEM":
+ data = read_abTEM(filepath, mem=mem, binfactor=binfactor, **kwargs)
+ else:
+ raise Exception("Bad filetype!")
+
+ return data
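+
+
+# A minimal usage sketch (the file paths below are hypothetical):
+#
+#   import py4DSTEM
+#   dc = py4DSTEM.import_file("scan.dm4")                # filetype auto-detected
+#   dc = py4DSTEM.import_file("scan.dm4", mem="MEMMAP")  # memory-mapped load
+#   dc = py4DSTEM.import_file("data.raw", filetype="empad")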
diff --git a/py4DSTEM/io/legacy/__init__.py b/py4DSTEM/io/legacy/__init__.py
new file mode 100644
index 000000000..ee340a7d4
--- /dev/null
+++ b/py4DSTEM/io/legacy/__init__.py
@@ -0,0 +1,3 @@
+from py4DSTEM.io.legacy.read_legacy_13 import *
+from py4DSTEM.io.legacy.read_legacy_12 import *
+from py4DSTEM.io.legacy.read_utils import *
diff --git a/py4DSTEM/io/legacy/legacy12/__init__.py b/py4DSTEM/io/legacy/legacy12/__init__.py
new file mode 100644
index 000000000..2370b8ca6
--- /dev/null
+++ b/py4DSTEM/io/legacy/legacy12/__init__.py
@@ -0,0 +1,5 @@
+from .read_v0_5 import read_v0_5
+from .read_v0_6 import read_v0_6
+from .read_v0_7 import read_v0_7
+from .read_v0_9 import read_v0_9
+from .read_v0_12 import read_v0_12
diff --git a/py4DSTEM/io/legacy/legacy12/read_utils_v0_12.py b/py4DSTEM/io/legacy/legacy12/read_utils_v0_12.py
new file mode 100644
index 000000000..a8a646e8d
--- /dev/null
+++ b/py4DSTEM/io/legacy/legacy12/read_utils_v0_12.py
@@ -0,0 +1,80 @@
+# Utility functions
+
+import h5py
+import numpy as np
+from py4DSTEM.io.legacy.read_utils import is_py4DSTEM_file
+
+
+def get_py4DSTEM_dataobject_info(filepath, topgroup="4DSTEM_experiment"):
+ """Returns a numpy structured array with basic metadata for all contained dataobjects.
+ Keys for the info array are: 'index','type','shape','name'.
+ """
+ assert is_py4DSTEM_file(filepath), "Error: not recognized as a py4DSTEM file"
+ with h5py.File(filepath, "r") as f:
+ assert topgroup in f.keys(), "Error: unrecognized topgroup"
+ i = 0
+ with h5py.File(filepath, "r") as f:
+ grp_dc = f[topgroup + "/data/datacubes/"]
+ grp_cdc = f[topgroup + "/data/counted_datacubes/"]
+ grp_ds = f[topgroup + "/data/diffractionslices/"]
+ grp_rs = f[topgroup + "/data/realslices/"]
+ grp_pl = f[topgroup + "/data/pointlists/"]
+ grp_pla = f[topgroup + "/data/pointlistarrays/"]
+ grp_coords = f[topgroup + "/data/coordinates/"]
+ N = (
+ len(grp_dc)
+ + len(grp_cdc)
+ + len(grp_ds)
+ + len(grp_rs)
+ + len(grp_pl)
+ + len(grp_pla)
+ + len(grp_coords)
+ )
+ info = np.zeros(
+ N,
+ dtype=[("index", int), ("type", "U16"), ("shape", tuple), ("name", "U64")],
+ )
+ for name in sorted(grp_dc.keys()):
+ shape = grp_dc[name + "/data/"].shape
+ dtype = "DataCube"
+ info[i] = i, dtype, shape, name
+ i += 1
+ for name in sorted(grp_cdc.keys()):
+ # TODO
+ shape = grp_cdc[name + "/data/"].shape
+ dtype = "CountedDataCube"
+ info[i] = i, dtype, shape, name
+ i += 1
+ for name in sorted(grp_ds.keys()):
+ shape = grp_ds[name + "/data/"].shape
+ dtype = "DiffractionSlice"
+ info[i] = i, dtype, shape, name
+ i += 1
+ for name in sorted(grp_rs.keys()):
+ shape = grp_rs[name + "/data/"].shape
+ dtype = "RealSlice"
+ info[i] = i, dtype, shape, name
+ i += 1
+ for name in sorted(grp_pl.keys()):
+ coordinates = list(grp_pl[name].keys())
+ length = grp_pl[name + "/" + coordinates[0] + "/data"].shape[0]
+ shape = (len(coordinates), length)
+ dtype = "PointList"
+ info[i] = i, dtype, shape, name
+ i += 1
+ for name in sorted(grp_pla.keys()):
+ ar_shape = grp_pla[name + "/data"].shape
+ pla_dtype = h5py.check_vlen_dtype(grp_pla[name + "/data"].dtype)
+ N_coords = len(pla_dtype)
+ shape = (ar_shape[0], ar_shape[1], N_coords, -1)
+ dtype = "PointListArray"
+ info[i] = i, dtype, shape, name
+ i += 1
+ for name in sorted(grp_coords.keys()):
+ shape = 0 # TODO?
+ dtype = "Coordinates"
+ info[i] = i, dtype, shape, name
+ i += 1
+
+ return info
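+
+
+# Since the returned `info` is a numpy structured array, entries can be
+# filtered by field -- a sketch, with a hypothetical filepath:
+#
+#   info = get_py4DSTEM_dataobject_info("legacy_file.h5")
+#   datacube_names = info["name"][info["type"] == "DataCube"]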
diff --git a/py4DSTEM/io/legacy/legacy12/read_utils_v0_5.py b/py4DSTEM/io/legacy/legacy12/read_utils_v0_5.py
new file mode 100644
index 000000000..1e868c4b7
--- /dev/null
+++ b/py4DSTEM/io/legacy/legacy12/read_utils_v0_5.py
@@ -0,0 +1,60 @@
+# Utility functions
+
+import h5py
+import numpy as np
+from py4DSTEM.io.legacy.read_utils import is_py4DSTEM_file
+
+
+def get_py4DSTEM_dataobject_info(fp, topgroup="4DSTEM_experiment"):
+ """Returns a numpy structured array with basic metadata for all contained dataobjects.
+ Keys for the info array are: 'index','type','shape','name'.
+ """
+ assert is_py4DSTEM_file(fp), "Error: not recognized as a py4DSTEM file"
+ with h5py.File(fp, "r") as f:
+ assert topgroup in f.keys(), "Error: unrecognized topgroup"
+ i = 0
+ with h5py.File(fp, "r") as f:
+ grp_dc = f[topgroup + "/data/datacubes/"]
+ grp_ds = f[topgroup + "/data/diffractionslices/"]
+ grp_rs = f[topgroup + "/data/realslices/"]
+ grp_pl = f[topgroup + "/data/pointlists/"]
+ grp_pla = f[topgroup + "/data/pointlistarrays/"]
+ N = len(grp_dc) + len(grp_ds) + len(grp_rs) + len(grp_pl) + len(grp_pla)
+ info = np.zeros(
+ N,
+ dtype=[("index", int), ("type", "U16"), ("shape", tuple), ("name", "U64")],
+ )
+ for name in sorted(grp_dc.keys()):
+ shape = grp_dc[name + "/datacube/"].shape
+ dtype = "DataCube"
+ info[i] = i, dtype, shape, name
+ i += 1
+ for name in sorted(grp_ds.keys()):
+ shape = grp_ds[name + "/diffractionslice/"].shape
+ dtype = "DiffractionSlice"
+ info[i] = i, dtype, shape, name
+ i += 1
+ for name in sorted(grp_rs.keys()):
+ shape = grp_rs[name + "/realslice/"].shape
+ dtype = "RealSlice"
+ info[i] = i, dtype, shape, name
+ i += 1
+ for name in sorted(grp_pl.keys()):
+ coordinates = list(grp_pl[name].keys())
+ length = grp_pl[name + "/" + coordinates[0] + "/pointlist"].shape[0]
+ shape = (len(coordinates), length)
+ dtype = "PointList"
+ info[i] = i, dtype, shape, name
+ i += 1
+ for name in sorted(grp_pla.keys()):
+ l = list(grp_pla[name])
+ ar = np.array([l[j].split("_") for j in range(len(l))]).astype(int)
+ ar_shape = (np.max(ar[:, 0]) + 1, np.max(ar[:, 1]) + 1)
+ N_coords = len(list(grp_pla[name + "/0_0"]))
+ shape = (ar_shape[0], ar_shape[1], N_coords, -1)
+ dtype = "PointListArray"
+ info[i] = i, dtype, shape, name
+ i += 1
+
+ return info
diff --git a/py4DSTEM/io/legacy/legacy12/read_utils_v0_6.py b/py4DSTEM/io/legacy/legacy12/read_utils_v0_6.py
new file mode 100644
index 000000000..79cd7f048
--- /dev/null
+++ b/py4DSTEM/io/legacy/legacy12/read_utils_v0_6.py
@@ -0,0 +1,60 @@
+# Utility functions
+
+import h5py
+import numpy as np
+from py4DSTEM.io.legacy.read_utils import is_py4DSTEM_file
+
+
+def get_py4DSTEM_dataobject_info(fp, topgroup="4DSTEM_experiment"):
+ """Returns a numpy structured array with basic metadata for all contained dataobjects.
+ Keys for the info array are: 'index','type','shape','name'.
+ """
+ assert is_py4DSTEM_file(fp), "Error: not recognized as a py4DSTEM file"
+ with h5py.File(fp, "r") as f:
+ assert topgroup in f.keys(), "Error: unrecognized topgroup"
+ i = 0
+ with h5py.File(fp, "r") as f:
+ grp_dc = f[topgroup + "/data/datacubes/"]
+ grp_ds = f[topgroup + "/data/diffractionslices/"]
+ grp_rs = f[topgroup + "/data/realslices/"]
+ grp_pl = f[topgroup + "/data/pointlists/"]
+ grp_pla = f[topgroup + "/data/pointlistarrays/"]
+ N = len(grp_dc) + len(grp_ds) + len(grp_rs) + len(grp_pl) + len(grp_pla)
+ info = np.zeros(
+ N,
+ dtype=[("index", int), ("type", "U16"), ("shape", tuple), ("name", "U64")],
+ )
+ for name in sorted(grp_dc.keys()):
+ shape = grp_dc[name + "/data/"].shape
+ dtype = "DataCube"
+ info[i] = i, dtype, shape, name
+ i += 1
+ for name in sorted(grp_ds.keys()):
+ shape = grp_ds[name + "/data/"].shape
+ dtype = "DiffractionSlice"
+ info[i] = i, dtype, shape, name
+ i += 1
+ for name in sorted(grp_rs.keys()):
+ shape = grp_rs[name + "/data/"].shape
+ dtype = "RealSlice"
+ info[i] = i, dtype, shape, name
+ i += 1
+ for name in sorted(grp_pl.keys()):
+ coordinates = list(grp_pl[name].keys())
+ length = grp_pl[name + "/" + coordinates[0] + "/data"].shape[0]
+ shape = (len(coordinates), length)
+ dtype = "PointList"
+ info[i] = i, dtype, shape, name
+ i += 1
+ for name in sorted(grp_pla.keys()):
+ l = list(grp_pla[name])
+ ar = np.array([l[j].split("_") for j in range(len(l))]).astype(int)
+ ar_shape = (np.max(ar[:, 0]) + 1, np.max(ar[:, 1]) + 1)
+ N_coords = len(list(grp_pla[name + "/0_0"]))
+ shape = (ar_shape[0], ar_shape[1], N_coords, -1)
+ dtype = "PointListArray"
+ info[i] = i, dtype, shape, name
+ i += 1
+
+ return info
diff --git a/py4DSTEM/io/legacy/legacy12/read_utils_v0_7.py b/py4DSTEM/io/legacy/legacy12/read_utils_v0_7.py
new file mode 100644
index 000000000..56b09059d
--- /dev/null
+++ b/py4DSTEM/io/legacy/legacy12/read_utils_v0_7.py
@@ -0,0 +1,59 @@
+# Utility functions
+
+import h5py
+import numpy as np
+from py4DSTEM.io.legacy.read_utils import is_py4DSTEM_file
+
+
+def get_py4DSTEM_dataobject_info(fp, topgroup="4DSTEM_experiment"):
+ """Returns a numpy structured array with basic metadata for all contained dataobjects.
+ Keys for the info array are: 'index','type','shape','name'.
+ """
+ assert is_py4DSTEM_file(fp), "Error: not recognized as a py4DSTEM file"
+ with h5py.File(fp, "r") as f:
+ assert topgroup in f.keys(), "Error: unrecognized topgroup"
+ i = 0
+ with h5py.File(fp, "r") as f:
+ grp_dc = f[topgroup + "/data/datacubes/"]
+ grp_ds = f[topgroup + "/data/diffractionslices/"]
+ grp_rs = f[topgroup + "/data/realslices/"]
+ grp_pl = f[topgroup + "/data/pointlists/"]
+ grp_pla = f[topgroup + "/data/pointlistarrays/"]
+ N = len(grp_dc) + len(grp_ds) + len(grp_rs) + len(grp_pl) + len(grp_pla)
+ info = np.zeros(
+ N,
+ dtype=[("index", int), ("type", "U16"), ("shape", tuple), ("name", "U64")],
+ )
+ for name in sorted(grp_dc.keys()):
+ shape = grp_dc[name + "/data/"].shape
+ dtype = "DataCube"
+ info[i] = i, dtype, shape, name
+ i += 1
+ for name in sorted(grp_ds.keys()):
+ shape = grp_ds[name + "/data/"].shape
+ dtype = "DiffractionSlice"
+ info[i] = i, dtype, shape, name
+ i += 1
+ for name in sorted(grp_rs.keys()):
+ shape = grp_rs[name + "/data/"].shape
+ dtype = "RealSlice"
+ info[i] = i, dtype, shape, name
+ i += 1
+ for name in sorted(grp_pl.keys()):
+ coordinates = list(grp_pl[name].keys())
+ length = grp_pl[name + "/" + coordinates[0] + "/data"].shape[0]
+ shape = (len(coordinates), length)
+ dtype = "PointList"
+ info[i] = i, dtype, shape, name
+ i += 1
+ for name in sorted(grp_pla.keys()):
+ ar = np.array(grp_pla[name + "/data"])
+ ar_shape = ar.shape
+ N_coords = len(ar[0, 0].dtype)
+ shape = (ar_shape[0], ar_shape[1], N_coords, -1)
+ dtype = "PointListArray"
+ info[i] = i, dtype, shape, name
+ i += 1
+
+ return info
diff --git a/py4DSTEM/io/legacy/legacy12/read_utils_v0_9.py b/py4DSTEM/io/legacy/legacy12/read_utils_v0_9.py
new file mode 100644
index 000000000..56b09059d
--- /dev/null
+++ b/py4DSTEM/io/legacy/legacy12/read_utils_v0_9.py
@@ -0,0 +1,59 @@
+# Utility functions
+
+import h5py
+import numpy as np
+from py4DSTEM.io.legacy.read_utils import is_py4DSTEM_file
+
+
+def get_py4DSTEM_dataobject_info(fp, topgroup="4DSTEM_experiment"):
+ """Returns a numpy structured array with basic metadata for all contained dataobjects.
+ Keys for the info array are: 'index','type','shape','name'.
+ """
+ assert is_py4DSTEM_file(fp), "Error: not recognized as a py4DSTEM file"
+ with h5py.File(fp, "r") as f:
+ assert topgroup in f.keys(), "Error: unrecognized topgroup"
+ i = 0
+ with h5py.File(fp, "r") as f:
+ grp_dc = f[topgroup + "/data/datacubes/"]
+ grp_ds = f[topgroup + "/data/diffractionslices/"]
+ grp_rs = f[topgroup + "/data/realslices/"]
+ grp_pl = f[topgroup + "/data/pointlists/"]
+ grp_pla = f[topgroup + "/data/pointlistarrays/"]
+ N = len(grp_dc) + len(grp_ds) + len(grp_rs) + len(grp_pl) + len(grp_pla)
+ info = np.zeros(
+ N,
+ dtype=[("index", int), ("type", "U16"), ("shape", tuple), ("name", "U64")],
+ )
+ for name in sorted(grp_dc.keys()):
+ shape = grp_dc[name + "/data/"].shape
+ dtype = "DataCube"
+ info[i] = i, dtype, shape, name
+ i += 1
+ for name in sorted(grp_ds.keys()):
+ shape = grp_ds[name + "/data/"].shape
+ dtype = "DiffractionSlice"
+ info[i] = i, dtype, shape, name
+ i += 1
+ for name in sorted(grp_rs.keys()):
+ shape = grp_rs[name + "/data/"].shape
+ dtype = "RealSlice"
+ info[i] = i, dtype, shape, name
+ i += 1
+ for name in sorted(grp_pl.keys()):
+ coordinates = list(grp_pl[name].keys())
+ length = grp_pl[name + "/" + coordinates[0] + "/data"].shape[0]
+ shape = (len(coordinates), length)
+ dtype = "PointList"
+ info[i] = i, dtype, shape, name
+ i += 1
+ for name in sorted(grp_pla.keys()):
+ ar = np.array(grp_pla[name + "/data"])
+ ar_shape = ar.shape
+ N_coords = len(ar[0, 0].dtype)
+ shape = (ar_shape[0], ar_shape[1], N_coords, -1)
+ dtype = "PointListArray"
+ info[i] = i, dtype, shape, name
+ i += 1
+
+ return info
diff --git a/py4DSTEM/io/legacy/legacy12/read_v0_12.py b/py4DSTEM/io/legacy/legacy12/read_v0_12.py
new file mode 100644
index 000000000..44aa86b6a
--- /dev/null
+++ b/py4DSTEM/io/legacy/legacy12/read_v0_12.py
@@ -0,0 +1,428 @@
+# Reader for py4DSTEM v0.12 files
+
+import h5py
+import numpy as np
+from os.path import splitext, exists
+from py4DSTEM.io.legacy.read_utils import (
+ is_py4DSTEM_file,
+ get_py4DSTEM_topgroups,
+ get_py4DSTEM_version,
+ version_is_geq,
+)
+from py4DSTEM.io.legacy.legacy12.read_utils_v0_12 import get_py4DSTEM_dataobject_info
+from emdfile import PointList, PointListArray
+from py4DSTEM.data import (
+ DiffractionSlice,
+ RealSlice,
+)
+from py4DSTEM.datacube import DataCube
+from emdfile import tqdmnd
+
+
+def read_v0_12(fp, **kwargs):
+ """
+ File reader for files written by py4DSTEM v0.12. Precise behavior is determined by which
+ arguments are passed -- see below.
+
+ Accepts:
+ filepath str or Path When passed a filepath only, this function checks if the path
+ points to a valid py4DSTEM file, then prints its contents to screen.
+ data_id int/str/list Specifies which data to load. Use integers to specify the
+ data index, or strings to specify data names. A list or
+ tuple returns a list of DataObjects. Returns the specified data.
+ topgroup str Strictly, a py4DSTEM file is considered to be
+ everything inside a toplevel subdirectory within the
+ HDF5 file, so that if desired one can place many py4DSTEM
+ files inside a single H5. In this case, when loading
+ data, the topgroup argument is passed to indicate which
+ py4DSTEM file to load. If an H5 containing multiple
+ py4DSTEM files is passed without a topgroup specified,
+ the topgroup names are printed to screen.
+ metadata bool If True, returns a dictionary with the file metadata.
+ log bool If True, writes the processing log to a plaintext file
+ called splitext(fp)[0]+'.log'.
+ mem str Only used if a single DataCube is loaded. In this case, mem
+ specifies how the data should be stored; must be "RAM"
+ or "MEMMAP" or "DASK". See docstring for py4DSTEM.file.io.read. Default
+ is "RAM".
+ binfactor int Only used if a single DataCube is loaded. In this case,
+ a binfactor of > 1 causes the data to be binned by this amount
+ as it's loaded.
+ dtype dtype Used when binning data, ignored otherwise. Defaults to whatever
+ the type of the raw data is, to avoid enlarging data size. May be
+ useful to avoid 'wraparound' errors.
+
+ Returns:
+ data,md The function always returns a length 2 tuple corresponding
+ to data and md. If no arguments which return values were passed
+ (i.e. data_id or metadata), these are None. Otherwise, their return
+ values are as described above. E.g. passing data_id=[0,1,2],metadata=True
+ will return a length two tuple, the first element being a list of 3
+ DataObject instances and the second a MetaData instance.
+ """
+ assert exists(fp), "Error: specified filepath does not exist"
+ assert is_py4DSTEM_file(
+ fp
+ ), "Error: {} isn't recognized as a py4DSTEM file.".format(fp)
+
+ # For HDF5 files containing multiple valid EMD type 2 files, disambiguate desired data
+ tgs = get_py4DSTEM_topgroups(fp)
+ if "topgroup" in kwargs.keys():
+ tg = kwargs["topgroup"]
+ assert tg in tgs, "Error: specified topgroup, {}, not found.".format(tg)
+ else:
+ if len(tgs) == 1:
+ tg = tgs[0]
+ else:
+ print(
+ "Multiple topgroups detected. Please specify one by passing the 'topgroup' keyword argument."
+ )
+ print("")
+ print("Topgroups found:")
+ for tg in tgs:
+ print(tg)
+ return None, None
+
+ version = get_py4DSTEM_version(fp, tg)
+ assert version_is_geq(version, (0, 12, 0)), "File must be v0.12+"
+ _data_id = "data_id" in kwargs.keys() # Flag indicating if data was requested
+
+ # If metadata is requested
+ if "metadata" in kwargs.keys():
+ if kwargs["metadata"]:
+ raise NotImplementedError("Legacy metadata reader missing...")
+ # return metadata_from_h5(fp, tg)
+
+ # If data is requested
+ elif "data_id" in kwargs.keys():
+ data_id = kwargs["data_id"]
+ assert isinstance(
+ data_id, (int, np.int_, str, list, tuple)
+ ), "Error: data must be specified with strings or integers only."
+ if not isinstance(data_id, (int, np.int_, str)):
+ assert all(
+ [isinstance(d, (int, np.int_, str)) for d in data_id]
+ ), "Error: data must be specified with strings or integers only."
+
+ # Parse optional arguments
+ if "mem" in kwargs.keys():
+ mem = kwargs["mem"]
+ assert mem in ("RAM", "MEMMAP", "DASK")
+ else:
+ mem = "RAM"
+ if "binfactor" in kwargs.keys():
+ binfactor = kwargs["binfactor"]
+ assert isinstance(binfactor, (int, np.int_))
+ else:
+ binfactor = 1
+ if "dtype" in kwargs.keys():
+ bindtype = kwargs["dtype"]
+ assert isinstance(bindtype, type)
+ else:
+ bindtype = None
+
+ return get_data(fp, tg, data_id, mem, binfactor, bindtype)
+
+ # If no data is requested
+ else:
+ print_py4DSTEM_file(fp, tg)
+ return
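+
+
+# Illustrative calls (the filepath and data name are hypothetical); when data
+# is requested, the loaded object(s) are returned directly:
+#
+#   read_v0_12("legacy.h5")                              # print file contents
+#   dc = read_v0_12("legacy.h5", data_id=0)              # load by index
+#   dc = read_v0_12("legacy.h5", data_id="datacube_0", mem="MEMMAP")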
+
+
+###### Get data ######
+
+
+def get_data(filepath, tg, data_id, mem="RAM", binfactor=1, bindtype=None):
+ """Accepts a filepath to a valid py4DSTEM file and an int/str/list specifying data, and returns the data."""
+ if isinstance(data_id, (int, np.int_)):
+ return get_data_from_int(
+ filepath, tg, data_id, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+ elif isinstance(data_id, str):
+ return get_data_from_str(
+ filepath, tg, data_id, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+ else:
+ return get_data_from_list(filepath, tg, data_id)
+
+
+def get_data_from_int(filepath, tg, data_id, mem="RAM", binfactor=1, bindtype=None):
+ """Accepts a filepath to a valid py4DSTEM file and an integer specifying data, and returns the data."""
+ assert isinstance(data_id, (int, np.int_))
+ with h5py.File(filepath, "r") as f:
+ grp_dc = f[tg + "/data/datacubes/"]
+ grp_cdc = f[tg + "/data/counted_datacubes/"]
+ grp_ds = f[tg + "/data/diffractionslices/"]
+ grp_rs = f[tg + "/data/realslices/"]
+ grp_pl = f[tg + "/data/pointlists/"]
+ grp_pla = f[tg + "/data/pointlistarrays/"]
+ grp_coords = f[tg + "/data/coordinates/"]
+ grps = [grp_dc, grp_cdc, grp_ds, grp_rs, grp_pl, grp_pla, grp_coords]
+
+ Ns = np.cumsum([len(grp.keys()) for grp in grps])
+ i = np.nonzero(data_id < Ns)[0][0]
+ grp = grps[i]
+ N = data_id - Ns[i]
+ name = sorted(grp.keys())[N]
+
+ group_name = grp.name + "/" + name
+
+ if mem == "RAM":
+ grp_data = f[group_name]
+ data = get_data_from_grp(
+ grp_data, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+
+ if mem == "MEMMAP":
+ f = h5py.File(filepath, "r")
+ grp_data = f[group_name]
+ data = get_data_from_grp(
+ grp_data, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+ # as with MEMMAP, the file is reopened outside the context manager
+ # so the handle stays open for lazy (DASK-backed) access
+ elif mem == "DASK":
+ f = h5py.File(filepath, "r")
+ grp_data = f[group_name]
+ data = get_data_from_grp(
+ grp_data, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+
+ return data
+
+
+def get_data_from_str(filepath, tg, data_id, mem="RAM", binfactor=1, bindtype=None):
+ """Accepts a filepath to a valid py4DSTEM file and a string specifying data, and returns the data."""
+ assert isinstance(data_id, str)
+ with h5py.File(filepath, "r") as f:
+ grp_dc = f[tg + "/data/datacubes/"]
+ grp_cdc = f[tg + "/data/counted_datacubes/"]
+ grp_ds = f[tg + "/data/diffractionslices/"]
+ grp_rs = f[tg + "/data/realslices/"]
+ grp_pl = f[tg + "/data/pointlists/"]
+ grp_pla = f[tg + "/data/pointlistarrays/"]
+ grp_coords = f[tg + "/data/coordinates/"]
+ grps = [grp_dc, grp_cdc, grp_ds, grp_rs, grp_pl, grp_pla, grp_coords]
+
+ l_dc = list(grp_dc.keys())
+ l_cdc = list(grp_cdc.keys())
+ l_ds = list(grp_ds.keys())
+ l_rs = list(grp_rs.keys())
+ l_pl = list(grp_pl.keys())
+ l_pla = list(grp_pla.keys())
+ l_coords = list(grp_coords.keys())
+ names = l_dc + l_cdc + l_ds + l_rs + l_pl + l_pla + l_coords
+
+ inds = [i for i, name in enumerate(names) if name == data_id]
+ assert len(inds) != 0, "Error: no data named {} found.".format(data_id)
+ assert len(inds) < 2, "Error: multiple data blocks named {} found.".format(
+ data_id
+ )
+ ind = inds[0]
+
+ Ns = np.cumsum([len(grp.keys()) for grp in grps])
+ i_grp = np.nonzero(ind < Ns)[0][0]
+ grp = grps[i_grp]
+ group_name = grp.name + "/" + data_id
+
+ if mem == "RAM":
+ grp_data = f[group_name]
+ data = get_data_from_grp(
+ grp_data, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+
+ if mem == "MEMMAP":
+ f = h5py.File(filepath, "r")
+ grp_data = f[group_name]
+ data = get_data_from_grp(
+ grp_data, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+ elif mem == "DASK":
+ f = h5py.File(filepath, "r")
+ grp_data = f[group_name]
+ data = get_data_from_grp(
+ grp_data, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+
+ return data
+
+
+def get_data_from_list(filepath, tg, data_id, mem="RAM", binfactor=1, bindtype=None):
+ """Accepts a filepath to a valid py4DSTEM file and a list or tuple specifying data, and returns the data."""
+ assert isinstance(data_id, (list, tuple))
+ assert all([isinstance(d, (int, np.int_, str)) for d in data_id])
+ data = []
+ for el in data_id:
+ if isinstance(el, (int, np.int_)):
+ data.append(
+ get_data_from_int(
+ filepath,
+ tg,
+ data_id=el,
+ mem=mem,
+ binfactor=binfactor,
+ bindtype=bindtype,
+ )
+ )
+ elif isinstance(el, str):
+ data.append(
+ get_data_from_str(
+ filepath,
+ tg,
+ data_id=el,
+ mem=mem,
+ binfactor=binfactor,
+ bindtype=bindtype,
+ )
+ )
+ else:
+ raise Exception("Data must be specified with strings or integers only.")
+ return data
+
+
+def get_data_from_grp(g, mem="RAM", binfactor=1, bindtype=None):
+ """Accepts an h5py Group corresponding to a single dataobject in an open, correctly formatted H5 file,
+ and returns a py4DSTEM DataObject.
+ """
+ dtype = g.name.split("/")[-2]
+ if dtype == "datacubes":
+ return get_datacube_from_grp(g, mem, binfactor, bindtype)
+ elif dtype == "counted_datacubes":
+ raise NotImplementedError(
+ "CountedDataCube objects are not available in py4DSTEM v0.13"
+ )
+ # return get_counted_datacube_from_grp(g)
+ elif dtype == "diffractionslices":
+ return get_diffractionslice_from_grp(g)
+ elif dtype == "realslices":
+ return get_realslice_from_grp(g)
+ elif dtype == "pointlists":
+ return get_pointlist_from_grp(g)
+ elif dtype == "pointlistarrays":
+ return get_pointlistarray_from_grp(g)
+ elif dtype == "coordinates":
+ raise NotImplementedError(
+ "Conversion from legacy Coordinates object to v0.13 Calibration is not available..."
+ )
+ # return get_coordinates_from_grp(g)
+ else:
+ raise Exception("Unrecognized data object type {}".format(dtype))
+
+
+# Print to screen
+
+
+def print_py4DSTEM_file(filepath, tg):
+ """Accepts a filepath to a valid py4DSTEM file and prints to screen the file contents."""
+ info = get_py4DSTEM_dataobject_info(filepath, tg)
+
+ version = get_py4DSTEM_version(filepath, tg)
+ print(f"py4DSTEM file version {version[0]}.{version[1]}.{version[2]}")
+
+ print("{:10}{:18}{:24}{:54}".format("Index", "Type", "Shape", "Name"))
+ print("{:10}{:18}{:24}{:54}".format("-----", "----", "-----", "----"))
+ for el in info:
+ print(
+ " {:8}{:18}{:24}{:54}".format(
+ str(el["index"]), str(el["type"]), str(el["shape"]), str(el["name"])
+ )
+ )
+ return
+
+
+#####################################################
+# READERS SCRAPED FROM DATASTRUCTURE FILES IN v0.12 #
+#####################################################
+
+
+def get_datacube_from_grp(g, mem="RAM", binfactor=1, bindtype=None):
+ """Accepts an h5py Group corresponding to a single datacube in an open, correctly formatted H5 file,
+ and returns a DataCube.
+ """
+ # TODO: add binning
+ assert binfactor == 1, "Bin on load is currently unsupported for EMD files."
+
+ if (mem, binfactor) == ("RAM", 1):
+ stack_pointer = g["data"]
+ data = np.array(g["data"])
+ elif (mem, binfactor) == ("MEMMAP", 1):
+ data = g["data"]
+ stack_pointer = None
+ name = g.name.split("/")[-1]
+ return DataCube(data=data, name=name)
+
+
+def get_diffractionslice_from_grp(g):
+ """Accepts an h5py Group corresponding to a diffractionslice in an open, correctly formatted H5 file,
+ and returns a DiffractionSlice.
+ """
+ data = np.array(g["data"])
+ name = g.name.split("/")[-1]
+ Q_Nx, Q_Ny = data.shape[:2]
+ if len(data.shape) == 2:
+ return DiffractionSlice(data=data, name=name)
+ else:
+ lbls = g["dim3"]
+ if "S" in lbls.dtype.str: # Checks if dim3 is composed of fixed width C strings
+ with lbls.astype("S64"):
+ lbls = lbls[:]
+ lbls = [lbl.decode("UTF-8") for lbl in lbls]
+ return DiffractionSlice(data=data, name=name, slicelabels=lbls)
+
+
+def get_realslice_from_grp(g):
+ """Accepts an h5py Group corresponding to a realslice in an open, correctly formatted H5 file,
+ and returns a RealSlice.
+ """
+ data = np.array(g["data"])
+ name = g.name.split("/")[-1]
+ R_Nx, R_Ny = data.shape[:2]
+ if len(data.shape) == 2:
+ return RealSlice(data=data, name=name)
+ else:
+ lbls = g["dim3"]
+ if "S" in lbls.dtype.str: # Checks if dim3 is composed of fixed width C strings
+ with lbls.astype("S64"):
+ lbls = lbls[:]
+ lbls = [lbl.decode("UTF-8") for lbl in lbls]
+ return RealSlice(data=data, name=name, slicelabels=lbls)
+
+
+def get_pointlist_from_grp(g):
+ """Accepts an h5py Group corresponding to a pointlist in an open, correctly formatted H5 file,
+ and returns a PointList.
+ """
+ name = g.name.split("/")[-1]
+ coordinates = []
+ coord_names = list(g.keys())
+ length = len(g[coord_names[0] + "/data"])
+ if length == 0:
+ for coord in coord_names:
+ coordinates.append((coord, None))
+ else:
+ for coord in coord_names:
+ dtype = type(g[coord + "/data"][0])
+ coordinates.append((coord, dtype))
+ data = np.zeros(length, dtype=coordinates)
+ for coord in coord_names:
+ data[coord] = np.array(g[coord + "/data"])
+ return PointList(data=data, name=name)
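+
+
+# The `coordinates` list of (name, dtype) pairs doubles as a numpy structured
+# dtype -- a quick sketch of the same construction (values are illustrative):
+#
+#   coords = [("qx", np.float64), ("qy", np.float64)]
+#   pts = np.zeros(3, dtype=coords)
+#   pts["qx"] = [0.1, 0.2, 0.3]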
+
+
+def get_pointlistarray_from_grp(g):
+ """Accepts an h5py Group corresponding to a pointlistarray in an open, correctly formatted H5 file,
+ and returns a PointListArray.
+ """
+ name = g.name.split("/")[-1]
+ dset = g["data"]
+ shape = dset.shape
+ coordinates = h5py.check_vlen_dtype(dset.dtype)
+ pla = PointListArray(dtype=coordinates, shape=shape, name=name)
+ for i, j in tqdmnd(
+ shape[0], shape[1], desc="Reading PointListArray", unit="PointList"
+ ):
+ try:
+ pla.get_pointlist(i, j).data = dset[i, j]
+ except ValueError:
+ pass
+ return pla
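+
+
+# The v0.12 format stores a PointListArray as one ragged dataset of h5py
+# variable-length records; h5py.check_vlen_dtype recovers the underlying
+# structured dtype. A sketch of reading it directly (the file and group
+# names are hypothetical):
+#
+#   with h5py.File("legacy.h5", "r") as f:
+#       dset = f["4DSTEM_experiment/data/pointlistarrays/braggpeaks/data"]
+#       fields = h5py.check_vlen_dtype(dset.dtype)  # e.g. [('qx', '<f8'), ...]
+#       first = dset[0, 0]                          # points at scan position (0, 0)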
diff --git a/py4DSTEM/io/legacy/legacy12/read_v0_5.py b/py4DSTEM/io/legacy/legacy12/read_v0_5.py
new file mode 100644
index 000000000..de7108b02
--- /dev/null
+++ b/py4DSTEM/io/legacy/legacy12/read_v0_5.py
@@ -0,0 +1,415 @@
+# Reader for py4DSTEM v0.5 files
+
+import h5py
+import numpy as np
+from os.path import splitext
+from py4DSTEM.io.legacy.read_utils import (
+ is_py4DSTEM_file,
+ get_py4DSTEM_topgroups,
+ get_py4DSTEM_version,
+ version_is_geq,
+)
+from py4DSTEM.io.legacy.legacy12.read_utils_v0_5 import get_py4DSTEM_dataobject_info
+from emdfile import PointList, PointListArray
+from py4DSTEM.data import (
+ DiffractionSlice,
+ RealSlice,
+)
+from py4DSTEM.datacube import DataCube
+from emdfile import tqdmnd
+
+
+def read_v0_5(fp, **kwargs):
+ """
+ File reader for files written by py4DSTEM v0.5. Precise behavior is determined by which
+ arguments are passed -- see below.
+
+ ***NOTE: this function has not yet been tested on all legacy py4DSTEM formats. Please report
+ any problems by filing an issue on our GitHub!
+
+ Accepts:
+ filepath str or Path When passed a filepath only, this function checks if the path
+ points to a valid py4DSTEM file, then prints its contents to screen.
+ data_id int/str/list Specifies which data to load. Use integers to specify the
+ data index, or strings to specify data names. A list or
+ tuple returns a list of DataObjects. Returns the specified data.
+ topgroup str Strictly, a py4DSTEM file is considered to be
+ everything inside a toplevel subdirectory within the
+ HDF5 file, so that if desired one can place many py4DSTEM
+ files inside a single H5. In this case, when loading
+ data, the topgroup argument is passed to indicate which
+ py4DSTEM file to load. If an H5 containing multiple
+ py4DSTEM files is passed without a topgroup specified,
+ the topgroup names are printed to screen.
+ metadata bool If True, returns a dictionary with the file metadata.
+ log bool If True, writes the processing log to a plaintext file
+ called splitext(fp)[0]+'.log'.
+ mem str Only used if a single DataCube is loaded. In this case, mem
+ specifies how the data should be stored; must be "RAM"
+ or "MEMMAP". See docstring for py4DSTEM.file.io.read. Default
+ is "RAM".
+ binfactor int Only used if a single DataCube is loaded. In this case,
+ a binfactor of > 1 causes the data to be binned by this amount
+ as it's loaded.
+ dtype dtype Used when binning data, ignored otherwise. Defaults to whatever
+ the type of the raw data is, to avoid enlarging data size. May be
+ useful to avoid 'wraparound' errors.
+
+ Returns:
+ data,md The function always returns a length 2 tuple corresponding
+ to data and md. If no arguments which return values were passed
+ (i.e. data_id or metadata), these are None. Otherwise, their return
+ values are as described above. E.g. passing data_id=[0,1,2],metadata=True
+ will return a length two tuple, the first element being a list of 3
+ DataObject instances and the second a MetaData instance.
+ """
+ assert is_py4DSTEM_file(
+ fp
+ ), "Error: {} isn't recognized as a py4DSTEM file.".format(fp)
+
+ # For HDF5 files containing multiple valid EMD type 2 files, disambiguate desired data
+ tgs = get_py4DSTEM_topgroups(fp)
+ if "topgroup" in kwargs.keys():
+ tg = kwargs["topgroup"]
+ assert tg in tgs, "Error: specified topgroup, {}, not found.".format(tg)
+ else:
+ if len(tgs) == 1:
+ tg = tgs[0]
+ else:
+ print(
+ "Multiple topgroups detected. Please specify one by passing the 'topgroup' keyword argument."
+ )
+ print("")
+ print("Topgroups found:")
+ for tg in tgs:
+ print(tg)
+ return None, None
+
+ version = get_py4DSTEM_version(fp, tg)
+ assert version == (0, 5, 0), "File must be v0.5.0."
+ _data_id = "data_id" in kwargs.keys() # Flag indicating if data was requested
+
+ # Validate inputs
+ if _data_id:
+ data_id = kwargs["data_id"]
+ assert isinstance(
+ data_id, (int, str, list, tuple)
+ ), "Error: data must be specified with strings or integers only."
+ if not isinstance(data_id, (int, str)):
+ assert all(
+ [isinstance(d, (int, str)) for d in data_id]
+ ), "Error: data must be specified with strings or integers only."
+
+ # Parse optional arguments
+ if "mem" in kwargs.keys():
+ mem = kwargs["mem"]
+ assert mem in ("RAM", "MEMMAP")
+ else:
+ mem = "RAM"
+ if "binfactor" in kwargs.keys():
+ binfactor = kwargs["binfactor"]
+ assert isinstance(binfactor, int)
+ else:
+ binfactor = 1
+ if "dtype" in kwargs.keys():
+ bindtype = kwargs["dtype"]
+ assert isinstance(bindtype, type)
+ else:
+ bindtype = None
+
+ # Perform requested operations
+ if not _data_id:
+ print_py4DSTEM_file(fp, tg)
+ return
+ else:
+ return get_data(fp, tg, data_id, mem, binfactor, bindtype)
+
+
+############ Helper functions ############
+
+
+def print_py4DSTEM_file(fp, tg):
+ """Accepts a fp to a valid py4DSTEM file and prints to screen the file contents."""
+ info = get_py4DSTEM_dataobject_info(fp, tg)
+
+ version = get_py4DSTEM_version(fp, tg)
+ print(f"py4DSTEM file version {version[0]}.{version[1]}.{version[2]}")
+
+ print("{:10}{:18}{:24}{:54}".format("Index", "Type", "Shape", "Name"))
+ print("{:10}{:18}{:24}{:54}".format("-----", "----", "-----", "----"))
+ for el in info:
+ print(
+ " {:8}{:18}{:24}{:54}".format(
+ str(el["index"]), str(el["type"]), str(el["shape"]), str(el["name"])
+ )
+ )
+
+ return
+
+
+def get_data(fp, tg, data_id, mem="RAM", binfactor=1, bindtype=None):
+ """Accepts a fp to a valid py4DSTEM file and an int/str/list specifying data, and returns the data."""
+ if isinstance(data_id, int):
+ return get_data_from_int(
+ fp, tg, data_id, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+ elif isinstance(data_id, str):
+ return get_data_from_str(
+ fp, tg, data_id, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+ else:
+ return get_data_from_list(fp, tg, data_id)
+
+
+def get_data_from_int(fp, tg, data_id, mem="RAM", binfactor=1, bindtype=None):
+ """Accepts a fp to a valid py4DSTEM file and an integer specifying data, and returns the data."""
+ assert isinstance(data_id, int)
+ with h5py.File(fp, "r") as f:
+ grp_dc = f[tg + "/data/datacubes/"]
+ grp_ds = f[tg + "/data/diffractionslices/"]
+ grp_rs = f[tg + "/data/realslices/"]
+ grp_pl = f[tg + "/data/pointlists/"]
+ grp_pla = f[tg + "/data/pointlistarrays/"]
+ grps = [grp_dc, grp_ds, grp_rs, grp_pl, grp_pla]
+
+ Ns = np.cumsum([len(grp.keys()) for grp in grps])
+ i = np.nonzero(data_id < Ns)[0][0]
+ grp = grps[i]
+ N = data_id - Ns[i]
+ name = sorted(grp.keys())[N]
+
+ group_name = grp.name + "/" + name
+
+ if mem == "RAM":
+ grp_data = f[group_name]
+ data = get_data_from_grp(
+ grp_data, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+
+ if mem == "MEMMAP":
+ f = h5py.File(fp, "r")
+ grp_data = f[group_name]
+ data = get_data_from_grp(
+ grp_data, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+
+ return data
+
+
+def get_data_from_str(fp, tg, data_id, mem="RAM", binfactor=1, bindtype=None):
+ """Accepts a fp to a valid py4DSTEM file and a string specifying data, and returns the data."""
+ assert isinstance(data_id, str)
+ with h5py.File(fp, "r") as f:
+ grp_dc = f[tg + "/data/datacubes/"]
+ grp_ds = f[tg + "/data/diffractionslices/"]
+ grp_rs = f[tg + "/data/realslices/"]
+ grp_pl = f[tg + "/data/pointlists/"]
+ grp_pla = f[tg + "/data/pointlistarrays/"]
+ grps = [grp_dc, grp_ds, grp_rs, grp_pl, grp_pla]
+
+ l_dc = list(grp_dc.keys())
+ l_ds = list(grp_ds.keys())
+ l_rs = list(grp_rs.keys())
+ l_pl = list(grp_pl.keys())
+ l_pla = list(grp_pla.keys())
+ names = l_dc + l_ds + l_rs + l_pl + l_pla
+
+ inds = [i for i, name in enumerate(names) if name == data_id]
+ assert len(inds) != 0, "Error: no data named {} found.".format(data_id)
+ assert len(inds) < 2, "Error: multiple data blocks named {} found.".format(
+ data_id
+ )
+ ind = inds[0]
+
+ Ns = np.cumsum([len(grp.keys()) for grp in grps])
+ i_grp = np.nonzero(ind < Ns)[0][0]
+ grp = grps[i_grp]
+
+ group_name = grp.name + "/" + data_id
+
+ if mem == "RAM":
+ grp_data = f[group_name]
+ data = get_data_from_grp(
+ grp_data, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+
+ # if using MEMMAP, file cannot be accessed from the context manager
+ # or else it will be closed before the data is accessed
+ if mem == "MEMMAP":
+ f = h5py.File(fp, "r")
+ grp_data = f[group_name]
+ data = get_data_from_grp(
+ grp_data, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+
+ return data
+
+
+def get_data_from_list(fp, tg, data_id, mem="RAM", binfactor=1, bindtype=None):
+ """Accepts a fp to a valid py4DSTEM file and a list or tuple specifying data, and returns the data."""
+ assert isinstance(data_id, (list, tuple))
+ assert all([isinstance(d, (int, str)) for d in data_id])
+ data = []
+ for el in data_id:
+ if isinstance(el, int):
+ data.append(
+ get_data_from_int(
+ fp, tg, data_id=el, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+ )
+ elif isinstance(el, str):
+ data.append(
+ get_data_from_str(
+ fp, tg, data_id=el, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+ )
+ else:
+ raise Exception("Data must be specified with strings or integers only.")
+ return data
+
+
+def get_data_from_grp(g, mem="RAM", binfactor=1, bindtype=None):
+ """Accepts an h5py Group corresponding to a single dataobject in an open, correctly formatted H5 file,
+ and returns a py4DSTEM DataObject.
+ """
+ dtype = g.name.split("/")[-2]
+ if dtype == "datacubes":
+ return get_datacube_from_grp(g, mem, binfactor, bindtype)
+ elif dtype == "counted_datacubes":
+ return get_counted_datacube_from_grp(g)
+ elif dtype == "diffractionslices":
+ return get_diffractionslice_from_grp(g)
+ elif dtype == "realslices":
+ return get_realslice_from_grp(g)
+ elif dtype == "pointlists":
+ return get_pointlist_from_grp(g)
+ elif dtype == "pointlistarrays":
+ return get_pointlistarray_from_grp(g)
+ else:
+ raise Exception("Unrecognized data object type {}".format(dtype))
+
+
+def get_datacube_from_grp(g, mem="RAM", binfactor=1, bindtype=None):
+ """Accepts an h5py Group corresponding to a single datacube in an open, correctly formatted H5 file,
+ and returns a DataCube.
+ """
+ assert binfactor == 1, "Bin on load is currently unsupported for EMD files."
+
+ if (mem, binfactor) == ("RAM", 1):
+ data = np.array(g["datacube"])
+ elif (mem, binfactor) == ("MEMMAP", 1):
+ data = g["datacube"]
+
+ name = g.name.split("/")[-1]
+ return DataCube(data=data, name=name)
+
+
+def get_counted_datacube_from_grp(g):
+ """Accepts an h5py Group corresponding to a counted datacube in an open, correctly formatted H5 file,
+ and returns a CountedDataCube.
+ """
+ return # TODO
+
+
+def get_diffractionslice_from_grp(g):
+ """Accepts an h5py Group corresponding to a diffractionslice in an open, correctly formatted H5 file,
+ and returns a DiffractionSlice.
+ """
+ data = np.array(g["diffractionslice"])
+ name = g.name.split("/")[-1]
+ Q_Nx, Q_Ny = data.shape[:2]
+ if len(data.shape) == 2:
+ return DiffractionSlice(data=data, Q_Nx=Q_Nx, Q_Ny=Q_Ny, name=name)
+ else:
+ lbls = g["dim3"]
+ if "S" in lbls.dtype.str: # Checks if dim3 is composed of fixed width C strings
+ with lbls.astype("S64"):
+ lbls = lbls[:]
+ lbls = [lbl.decode("UTF-8") for lbl in lbls]
+ return DiffractionSlice(
+ data=data, Q_Nx=Q_Nx, Q_Ny=Q_Ny, name=name, slicelabels=lbls
+ )
+
+
+def get_realslice_from_grp(g):
+ """Accepts an h5py Group corresponding to a realslice in an open, correctly formatted H5 file,
+ and returns a RealSlice.
+ """
+ data = np.array(g["realslice"])
+ name = g.name.split("/")[-1]
+ R_Nx, R_Ny = data.shape[:2]
+ if len(data.shape) == 2:
+ return RealSlice(data=data, R_Nx=R_Nx, R_Ny=R_Ny, name=name)
+ else:
+ lbls = g["dim3"]
+ if "S" in lbls.dtype.str: # Checks if dim3 is composed of fixed width C strings
+ with lbls.astype("S64"):
+ lbls = lbls[:]
+ lbls = [lbl.decode("UTF-8") for lbl in lbls]
+ return RealSlice(data=data, R_Nx=R_Nx, R_Ny=R_Ny, name=name, slicelabels=lbls)
+
+
+def get_pointlist_from_grp(g):
+ """Accepts an h5py Group corresponding to a pointlist in an open, correctly formatted H5 file,
+ and returns a PointList.
+ """
+ name = g.name.split("/")[-1]
+ coordinates = []
+ coord_names = list(g.keys())
+ length = len(g[coord_names[0] + "/pointlist"])
+ if length == 0:
+ for coord in coord_names:
+ coordinates.append((coord, None))
+ else:
+ for coord in coord_names:
+ dtype = type(g[coord + "/pointlist"][0])
+ coordinates.append((coord, dtype))
+ data = np.zeros(length, dtype=coordinates)
+ for coord in coord_names:
+ data[coord] = np.array(g[coord + "/pointlist"])
+ return PointList(data=data, name=name)
+
+
+def get_pointlistarray_from_grp(g):
+ """Accepts an h5py Group corresponding to a pointlistarray in an open, correctly formatted H5 file,
+ and returns a PointListArray.
+ """
+ name = g.name.split("/")[-1]
+ l = list(g)
+ ar = np.array([l[i].split("_") for i in range(len(l))]).astype(int)
+ shape = (np.max(ar[:, 0]) + 1, np.max(ar[:, 1]) + 1)
+ coord_names = list(g["0_0"])
+ N = len(coord_names)
+ coord_types = [
+ type(np.array(g["0_0/" + coord_names[i] + "/pointlistarray"])[0])
+ for i in range(N)
+ ]
+ coordinates = [(coord_names[i], coord_types[i]) for i in range(N)]
+ pla = PointListArray(dtype=coordinates, shape=shape, name=name)
+ for i, j in tqdmnd(
+ range(shape[0]),
+ range(shape[1]),
+ desc="Reading PointListArray",
+ unit="PointList",
+ ):
+ g_pl = g[str(i) + "_" + str(j)]
+ L = len(np.array(g_pl[coordinates[0][0] + "/pointlistarray"]))
+ data = np.zeros(L, dtype=coordinates)
+ # use a separate index here so the scan-position index i is not shadowed
+ for k in range(N):
+ coord = coordinates[k][0]
+ data[coord] = np.array(g_pl[coord + "/pointlistarray"])
+ pla.get_pointlist(i, j).data = data
+ return pla
+
+
+########### Metadata and log ############
+
+
+def get_metadata(fp, tg):
+ """Accepts a fp to a valid py4DSTEM file, and return a dictionary with its metadata."""
+ return # TODO
+
+
+def write_log(fp, tg):
+ """Accepts a fp to a valid py4DSTEM file, then prints its processing log to splitext(fp)[0]+'.log'."""
+ return # TODO
diff --git a/py4DSTEM/io/legacy/legacy12/read_v0_6.py b/py4DSTEM/io/legacy/legacy12/read_v0_6.py
new file mode 100644
index 000000000..f746548ca
--- /dev/null
+++ b/py4DSTEM/io/legacy/legacy12/read_v0_6.py
@@ -0,0 +1,395 @@
+# Reader for py4DSTEM v0.6 files
+
+import h5py
+import numpy as np
+from os.path import splitext
+from py4DSTEM.io.legacy.read_utils import (
+ is_py4DSTEM_file,
+ get_py4DSTEM_topgroups,
+ get_py4DSTEM_version,
+ version_is_geq,
+)
+from py4DSTEM.io.legacy.legacy12.read_utils_v0_6 import get_py4DSTEM_dataobject_info
+from emdfile import PointList, PointListArray
+from py4DSTEM.data import (
+ DiffractionSlice,
+ RealSlice,
+)
+from py4DSTEM.datacube import DataCube
+from emdfile import tqdmnd
+
+
+def read_v0_6(fp, **kwargs):
+ """
+ File reader for files written by py4DSTEM v0.6. Precise behavior is determined by which
+ arguments are passed -- see below.
+
+ ***NOTE: this function has not yet been tested on all legacy py4DSTEM formats. Please report
+ any problems by filing an issue on our GitHub!
+
+ Accepts:
+ filepath str or Path When passed a filepath only, this function checks if the path
+ points to a valid py4DSTEM file, then prints its contents to screen.
+ data_id int/str/list Specifies which data to load. Use integers to specify the
+ data index, or strings to specify data names. A list or
+ tuple returns a list of DataObjects. Returns the specified data.
+ topgroup str Strictly, a py4DSTEM file is considered to be
+ everything inside a toplevel subdirectory within the
+ HDF5 file, so that if desired one can place many py4DSTEM
+ files inside a single H5. In this case, when loading
+ data, the topgroup argument is passed to indicate which
+ py4DSTEM file to load. If an H5 containing multiple
+ py4DSTEM files is passed without a topgroup specified,
+ the topgroup names are printed to screen.
+ metadata bool If True, returns a dictionary with the file metadata.
+ log bool If True, writes the processing log to a plaintext file
+ called splitext(fp)[0]+'.log'.
+ mem str Only used if a single DataCube is loaded. In this case, mem
+ specifies how the data should be stored; must be "RAM"
+ or "MEMMAP". See docstring for py4DSTEM.file.io.read. Default
+ is "RAM".
+ binfactor int Only used if a single DataCube is loaded. In this case,
+ a binfactor of > 1 causes the data to be binned by this amount
+ as it's loaded.
+ dtype dtype Used when binning data, ignored otherwise. Defaults to whatever
+ the type of the raw data is, to avoid enlarging data size. May be
+ useful to avoid 'wraparound' errors.
+
+ Returns:
+ data,md The function always returns a length 2 tuple corresponding
+ to data and md. If no arguments which return values were passed
+ (i.e. data_id or metadata), these are None. Otherwise, their return
+ values are as described above. E.g. passing data_id=[0,1,2],metadata=True
+ will return a length two tuple, the first element being a list of 3
+ DataObject instances and the second a MetaData instance.
+ """
+ assert is_py4DSTEM_file(
+ fp
+ ), "Error: {} isn't recognized as a py4DSTEM file.".format(fp)
+
+ # For HDF5 files containing multiple valid EMD type 2 files, disambiguate desired data
+ tgs = get_py4DSTEM_topgroups(fp)
+ if "topgroup" in kwargs.keys():
+ tg = kwargs["topgroup"]
+ assert tg in tgs, "Error: specified topgroup, {}, not found.".format(tg)
+ else:
+ if len(tgs) == 1:
+ tg = tgs[0]
+ else:
+ print(
+ "Multiple topgroups detected. Please specify one by passing the 'topgroup' keyword argument."
+ )
+ print("")
+ print("Topgroups found:")
+ for tg in tgs:
+ print(tg)
+ return None, None
+
+ version = get_py4DSTEM_version(fp, tg)
+ assert version == (0, 6, 0), "File must be v0.6.0."
+ _data_id = "data_id" in kwargs.keys() # Flag indicating if data was requested
+
+ # Validate inputs
+ if _data_id:
+ data_id = kwargs["data_id"]
+ assert isinstance(
+ data_id, (int, str, list, tuple)
+ ), "Error: data must be specified with strings or integers only."
+ if not isinstance(data_id, (int, str)):
+ assert all(
+ [isinstance(d, (int, str)) for d in data_id]
+ ), "Error: data must be specified with strings or integers only."
+
+ # Parse optional arguments
+ if "mem" in kwargs.keys():
+ mem = kwargs["mem"]
+ assert mem in ("RAM", "MEMMAP")
+ else:
+ mem = "RAM"
+ if "binfactor" in kwargs.keys():
+ binfactor = kwargs["binfactor"]
+ assert isinstance(binfactor, int)
+ else:
+ binfactor = 1
+ if "dtype" in kwargs.keys():
+ bindtype = kwargs["dtype"]
+ assert isinstance(bindtype, type)
+ else:
+ bindtype = None
+
+ # Perform requested operations
+ if not _data_id:
+ print_py4DSTEM_file(fp, tg)
+ return
+ else:
+ return get_data(fp, tg, data_id, mem, binfactor, bindtype)
+
+
+############ Helper functions ############
+
+
+def print_py4DSTEM_file(fp, tg):
+ """Accepts a fp to a valid py4DSTEM file and prints to screen the file contents."""
+ info = get_py4DSTEM_dataobject_info(fp, tg)
+
+ version = get_py4DSTEM_version(fp, tg)
+ print(f"py4DSTEM file version {version[0]}.{version[1]}.{version[2]}")
+
+ print("{:10}{:18}{:24}{:54}".format("Index", "Type", "Shape", "Name"))
+ print("{:10}{:18}{:24}{:54}".format("-----", "----", "-----", "----"))
+ for el in info:
+ print(
+ " {:8}{:18}{:24}{:54}".format(
+ str(el["index"]), str(el["type"]), str(el["shape"]), str(el["name"])
+ )
+ )
+
+ return
+
+
+def get_data(fp, tg, data_id, mem="RAM", binfactor=1, bindtype=None):
+ """Accepts a fp to a valid py4DSTEM file and an int/str/list specifying data, and returns the data."""
+ if isinstance(data_id, int):
+ return get_data_from_int(
+ fp, tg, data_id, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+ elif isinstance(data_id, str):
+ return get_data_from_str(
+ fp, tg, data_id, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+ else:
+ return get_data_from_list(fp, tg, data_id)
+
+
+def get_data_from_int(fp, tg, data_id, mem="RAM", binfactor=1, bindtype=None):
+ """Accepts a fp to a valid py4DSTEM file and an integer specifying data, and returns the data."""
+ assert isinstance(data_id, int)
+ with h5py.File(fp, "r") as f:
+ grp_dc = f[tg + "/data/datacubes/"]
+ grp_ds = f[tg + "/data/diffractionslices/"]
+ grp_rs = f[tg + "/data/realslices/"]
+ grp_pl = f[tg + "/data/pointlists/"]
+ grp_pla = f[tg + "/data/pointlistarrays/"]
+ grps = [grp_dc, grp_ds, grp_rs, grp_pl, grp_pla]
+
+ Ns = np.cumsum([len(grp.keys()) for grp in grps])
+ i = np.nonzero(data_id < Ns)[0][0]
+ grp = grps[i]
+ N = data_id - Ns[i]
+ name = sorted(grp.keys())[N]
+
+ group_name = grp.name + "/" + name
+
+ if mem == "RAM":
+ grp_data = f[group_name]
+ data = get_data_from_grp(
+ grp_data, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+
+ if mem == "MEMMAP":
+ f = h5py.File(fp, "r")
+ grp_data = f[group_name]
+ data = get_data_from_grp(
+ grp_data, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+
+ return data
+
+
+def get_data_from_str(fp, tg, data_id, mem="RAM", binfactor=1, bindtype=None):
+ """Accepts a fp to a valid py4DSTEM file and a string specifying data, and returns the data."""
+ assert isinstance(data_id, str)
+ with h5py.File(fp, "r") as f:
+ grp_dc = f[tg + "/data/datacubes/"]
+ grp_ds = f[tg + "/data/diffractionslices/"]
+ grp_rs = f[tg + "/data/realslices/"]
+ grp_pl = f[tg + "/data/pointlists/"]
+ grp_pla = f[tg + "/data/pointlistarrays/"]
+ grps = [grp_dc, grp_ds, grp_rs, grp_pl, grp_pla]
+
+ l_dc = list(grp_dc.keys())
+ l_ds = list(grp_ds.keys())
+ l_rs = list(grp_rs.keys())
+ l_pl = list(grp_pl.keys())
+ l_pla = list(grp_pla.keys())
+ names = l_dc + l_ds + l_rs + l_pl + l_pla
+
+ inds = [i for i, name in enumerate(names) if name == data_id]
+ assert len(inds) != 0, "Error: no data named {} found.".format(data_id)
+ assert len(inds) < 2, "Error: multiple data blocks named {} found.".format(
+ data_id
+ )
+ ind = inds[0]
+
+ Ns = np.cumsum([len(grp.keys()) for grp in grps])
+ i_grp = np.nonzero(ind < Ns)[0][0]
+ grp = grps[i_grp]
+
+ group_name = grp.name + "/" + data_id
+
+ if mem == "RAM":
+ grp_data = f[group_name]
+ data = get_data_from_grp(
+ grp_data, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+
+ # if using MEMMAP, file cannot be accessed from the context manager
+ # or else it will be closed before the data is accessed
+ if mem == "MEMMAP":
+ f = h5py.File(fp, "r")
+ grp_data = f[group_name]
+ data = get_data_from_grp(
+ grp_data, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+
+ return data
+
+
+def get_data_from_list(fp, tg, data_id, mem="RAM", binfactor=1, bindtype=None):
+ """Accepts a fp to a valid py4DSTEM file and a list or tuple specifying data, and returns the data."""
+ assert isinstance(data_id, (list, tuple))
+ assert all([isinstance(d, (int, str)) for d in data_id])
+ data = []
+ for el in data_id:
+ if isinstance(el, int):
+ data.append(
+ get_data_from_int(
+ fp, tg, data_id=el, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+ )
+ elif isinstance(el, str):
+ data.append(
+ get_data_from_str(
+ fp, tg, data_id=el, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+ )
+ else:
+ raise Exception("Data must be specified with strings or integers only.")
+ return data
+
+
+def get_data_from_grp(g, mem="RAM", binfactor=1, bindtype=None):
+ """Accepts an h5py Group corresponding to a single dataobject in an open, correctly formatted H5 file,
+ and returns a py4DSTEM DataObject.
+ """
+ dtype = g.name.split("/")[-2]
+ if dtype == "datacubes":
+ return get_datacube_from_grp(g, mem, binfactor, bindtype)
+ elif dtype == "counted_datacubes":
+ return get_counted_datacube_from_grp(g)
+ elif dtype == "diffractionslices":
+ return get_diffractionslice_from_grp(g)
+ elif dtype == "realslices":
+ return get_realslice_from_grp(g)
+ elif dtype == "pointlists":
+ return get_pointlist_from_grp(g)
+ elif dtype == "pointlistarrays":
+ return get_pointlistarray_from_grp(g)
+ else:
+ raise Exception("Unrecognized data object type {}".format(dtype))
+
+
+def get_datacube_from_grp(g, mem="RAM", binfactor=1, bindtype=None):
+ """Accepts an h5py Group corresponding to a single datacube in an open, correctly formatted H5 file,
+ and returns a DataCube.
+ """
+ assert binfactor == 1, "Bin on load is currently unsupported for EMD files."
+
+ if (mem, binfactor) == ("RAM", 1):
+ data = np.array(g["data"])
+ elif (mem, binfactor) == ("MEMMAP", 1):
+ data = g["data"]
+
+ name = g.name.split("/")[-1]
+ return DataCube(data=data, name=name)
+
+
+def get_counted_datacube_from_grp(g):
+ """Accepts an h5py Group corresponding to a counted datacube in an open, correctly formatted H5 file,
+ and returns a CountedDataCube.
+ """
+ return # TODO
+
+
+def get_diffractionslice_from_grp(g):
+ """Accepts an h5py Group corresponding to a diffractionslice in an open, correctly formatted H5 file,
+ and returns a DiffractionSlice.
+ """
+ data = np.array(g["data"])
+ name = g.name.split("/")[-1]
+ if len(data.shape) == 2:
+ return DiffractionSlice(data=data, name=name)
+ else:
+ lbls = g["dim3"]
+ if "S" in lbls.dtype.str: # Checks if dim3 is composed of fixed width C strings
+ lbls = lbls.astype("S64")[:]
+ lbls = [lbl.decode("UTF-8") for lbl in lbls]
+ return DiffractionSlice(data=data, name=name, slicelabels=lbls)
+
+
+def get_realslice_from_grp(g):
+ """Accepts an h5py Group corresponding to a realslice in an open, correctly formatted H5 file,
+ and returns a RealSlice.
+ """
+ data = np.array(g["data"])
+ name = g.name.split("/")[-1]
+ if len(data.shape) == 2:
+ return RealSlice(data=data, name=name)
+ else:
+ lbls = g["dim3"]
+ if "S" in lbls.dtype.str: # Checks if dim3 is composed of fixed width C strings
+ lbls = lbls.astype("S64")[:]
+ lbls = [lbl.decode("UTF-8") for lbl in lbls]
+ return RealSlice(data=data, name=name, slicelabels=lbls)
+
+
+def get_pointlist_from_grp(g):
+ """Accepts an h5py Group corresponding to a pointlist in an open, correctly formatted H5 file,
+ and returns a PointList.
+ """
+ name = g.name.split("/")[-1]
+ coordinates = []
+ coord_names = list(g.keys())
+ length = len(g[coord_names[0] + "/data"])
+ if length == 0:
+ for coord in coord_names:
+ coordinates.append((coord, None))
+ else:
+ for coord in coord_names:
+ dtype = type(g[coord + "/data"][0])
+ coordinates.append((coord, dtype))
+ data = np.zeros(length, dtype=coordinates)
+ for coord in coord_names:
+ data[coord] = np.array(g[coord + "/data"])
+ return PointList(data=data, name=name)
+
+
+def get_pointlistarray_from_grp(g):
+ """Accepts an h5py Group corresponding to a pointlistarray in an open, correctly formatted H5 file,
+ and returns a PointListArray.
+ """
+ name = g.name.split("/")[-1]
+ keys = list(g)
+ ar = np.array([key.split("_") for key in keys]).astype(int)
+ shape = (np.max(ar[:, 0]) + 1, np.max(ar[:, 1]) + 1)
+ coord_names = list(g["0_0"])
+ N = len(coord_names)
+ coord_types = [
+ type(np.array(g["0_0/" + coord_names[i] + "/data"])[0]) for i in range(N)
+ ]
+ coordinates = [(coord_names[i], coord_types[i]) for i in range(N)]
+ pla = PointListArray(dtype=coordinates, shape=shape, name=name)
+ for i, j in tqdmnd(
+ range(shape[0]),
+ range(shape[1]),
+ desc="Reading PointListArray",
+ unit="PointList",
+ ):
+ g_pl = g[str(i) + "_" + str(j)]
+ L = len(np.array(g_pl[coordinates[0][0] + "/data"]))
+ data = np.zeros(L, dtype=coordinates)
+ # iterate with a distinct index so the scan-position variables i, j aren't shadowed
+ for k in range(N):
+ coord = coordinates[k][0]
+ data[coord] = np.array(g_pl[coord + "/data"])
+ pla.get_pointlist(i, j).data = data
+ return pla
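+ # A minimal sketch (pure numpy, hypothetical field names) of the structured-
+ # dtype round trip the pointlist readers above rely on: a list of (name, type)
+ # pairs is a valid numpy dtype, and fields are filled by name.
+ #
+ # import numpy as np
+ # coordinates = [("qx", np.float64), ("qy", np.float64)]
+ # data = np.zeros(3, dtype=coordinates)
+ # data["qx"] = np.array([0.0, 0.1, 0.2]) # assign one field at a time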
diff --git a/py4DSTEM/io/legacy/legacy12/read_v0_7.py b/py4DSTEM/io/legacy/legacy12/read_v0_7.py
new file mode 100644
index 000000000..fac779d64
--- /dev/null
+++ b/py4DSTEM/io/legacy/legacy12/read_v0_7.py
@@ -0,0 +1,394 @@
+# Reader for py4DSTEM v0.7 - v0.8 files
+
+import h5py
+import numpy as np
+from os.path import splitext
+from py4DSTEM.io.legacy.read_utils import (
+ is_py4DSTEM_file,
+ get_py4DSTEM_topgroups,
+ get_py4DSTEM_version,
+ version_is_geq,
+)
+from py4DSTEM.io.legacy.legacy12.read_utils_v0_7 import get_py4DSTEM_dataobject_info
+from emdfile import PointList, PointListArray
+from py4DSTEM.data import (
+ DiffractionSlice,
+ RealSlice,
+)
+from py4DSTEM.datacube import DataCube
+from emdfile import tqdmnd
+
+
+def read_v0_7(fp, **kwargs):
+ """
+ File reader for files written by py4DSTEM v0.7 or v0.8. Precise behavior is determined by which
+ arguments are passed -- see below.
+
+ ***NOTE: this function has not yet been tested on all legacy py4DSTEM formats. Please report
+ any problems by filing an issue on our GitHub!
+
+ Accepts:
+ filepath str or Path When passed a filepath only, this function checks if the path
+ points to a valid py4DSTEM file, then prints its contents to screen.
+ data_id int/str/list Specifies which data to load. Use integers to specify the
+ data index, or strings to specify data names. A list or
+ tuple returns a list of DataObjects. Returns the specified data.
+ topgroup str Strictly, a py4DSTEM file is considered to be
+ everything inside a toplevel subdirectory within the
+ HDF5 file, so that if desired one can place many py4DSTEM
+ files inside a single H5. In this case, when loading
+ data, the topgroup argument is passed to indicate which
+ py4DSTEM file to load. If an H5 containing multiple
+ py4DSTEM files is passed without a topgroup specified,
+ the topgroup names are printed to screen.
+ metadata bool If True, returns a dictionary with the file metadata.
+ log bool If True, writes the processing log to a plaintext file
+ called splitext(fp)[0]+'.log'.
+ mem str Only used if a single DataCube is loaded. In this case, mem
+ specifies how the data should be stored; must be "RAM"
+ or "MEMMAP". See docstring for py4DSTEM.file.io.read. Default
+ is "RAM".
+ binfactor int Only used if a single DataCube is loaded. In this case,
+ a binfactor of > 1 causes the data to be binned by this amount
+ as it's loaded.
+ dtype dtype Used when binning data, ignored otherwise. Defaults to whatever
+ the type of the raw data is, to avoid enlarging data size. May be
+ useful to avoid 'wraparound' errors.
+
+ Returns:
+ data,md The function always returns a length 2 tuple corresponding
+ to data and md. If no input arguments with return values (i.e.
+ data, metadata) are passed, these are None. Otherwise, their return
+ values are as described above. E.g. passing data=[0,1,2],metadata=True
+ will return a length two tuple, the first element being a list of 3
+ DataObject instances and the second a MetaData instance.
+ """
+ assert is_py4DSTEM_file(
+ fp
+ ), "Error: {} isn't recognized as a py4DSTEM file.".format(fp)
+
+ # For HDF5 files containing multiple valid EMD type 2 files, disambiguate desired data
+ tgs = get_py4DSTEM_topgroups(fp)
+ if "topgroup" in kwargs.keys():
+ tg = kwargs["topgroup"]
+ assert tg in tgs, "Error: specified topgroup, {}, not found.".format(tg)
+ else:
+ if len(tgs) == 1:
+ tg = tgs[0]
+ else:
+ print(
+ "Multiple topgroups detected. Please specify one by passing the 'topgroup' keyword argument."
+ )
+ print("")
+ print("Topgroups found:")
+ for tg in tgs:
+ print(tg)
+ return None, None
+
+ version = get_py4DSTEM_version(fp, tg)
+ assert version_is_geq(version, (0, 7, 0)) and not version_is_geq(
+ version, (0, 9, 0)
+ ), "File must be v0.7-v0.8."
+ _data_id = "data_id" in kwargs.keys() # Flag indicating if data was requested
+
+ # Validate inputs
+ if _data_id:
+ data_id = kwargs["data_id"]
+ assert isinstance(
+ data_id, (int, str, list, tuple)
+ ), "Error: data must be specified with strings or integers only."
+ if not isinstance(data_id, (int, str)):
+ assert all(
+ [isinstance(d, (int, str)) for d in data_id]
+ ), "Error: data must be specified with strings or integers only."
+
+ # Parse optional arguments
+ if "mem" in kwargs.keys():
+ mem = kwargs["mem"]
+ assert mem in ("RAM", "MEMMAP")
+ else:
+ mem = "RAM"
+ if "binfactor" in kwargs.keys():
+ binfactor = kwargs["binfactor"]
+ assert isinstance(binfactor, int)
+ else:
+ binfactor = 1
+ if "dtype" in kwargs.keys():
+ bindtype = kwargs["dtype"]
+ assert isinstance(bindtype, type)
+ else:
+ bindtype = None
+
+ # Perform requested operations
+ if not _data_id:
+ print_py4DSTEM_file(fp, tg)
+ return
+ else:
+ return get_data(fp, tg, data_id, mem, binfactor, bindtype)
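+ # Usage sketch (hypothetical filename and data names; assumes a single-topgroup
+ # v0.7/v0.8 file):
+ #
+ # from py4DSTEM.io.legacy.legacy12.read_v0_7 import read_v0_7
+ # read_v0_7("legacy.h5") # print the file contents
+ # dc = read_v0_7("legacy.h5", data_id=0) # load the first dataobject
+ # dc = read_v0_7("legacy.h5", data_id="datacube_0", mem="MEMMAP")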
+
+
+############ Helper functions ############
+
+
+def print_py4DSTEM_file(fp, tg):
+ """Accepts a fp to a valid py4DSTEM file and prints to screen the file contents."""
+ info = get_py4DSTEM_dataobject_info(fp, tg)
+
+ version = get_py4DSTEM_version(fp, tg)
+ print(f"py4DSTEM file version {version[0]}.{version[1]}.{version[2]}")
+
+ print("{:10}{:18}{:24}{:54}".format("Index", "Type", "Shape", "Name"))
+ print("{:10}{:18}{:24}{:54}".format("-----", "----", "-----", "----"))
+ for el in info:
+ print(
+ " {:8}{:18}{:24}{:54}".format(
+ str(el["index"]), str(el["type"]), str(el["shape"]), str(el["name"])
+ )
+ )
+
+ return
+
+
+def get_data(fp, tg, data_id, mem="RAM", binfactor=1, bindtype=None):
+ """Accepts a fp to a valid py4DSTEM file and an int/str/list specifying data, and returns the data."""
+ if isinstance(data_id, int):
+ return get_data_from_int(
+ fp, tg, data_id, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+ elif isinstance(data_id, str):
+ return get_data_from_str(
+ fp, tg, data_id, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+ else:
+ return get_data_from_list(
+ fp, tg, data_id, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+
+
+def get_data_from_int(fp, tg, data_id, mem="RAM", binfactor=1, bindtype=None):
+ """Accepts a fp to a valid py4DSTEM file and an integer specifying data, and returns the data."""
+ assert isinstance(data_id, int)
+ with h5py.File(fp, "r") as f:
+ grp_dc = f[tg + "/data/datacubes/"]
+ grp_ds = f[tg + "/data/diffractionslices/"]
+ grp_rs = f[tg + "/data/realslices/"]
+ grp_pl = f[tg + "/data/pointlists/"]
+ grp_pla = f[tg + "/data/pointlistarrays/"]
+ grps = [grp_dc, grp_ds, grp_rs, grp_pl, grp_pla]
+
+ Ns = np.cumsum([len(grp.keys()) for grp in grps])
+ i = np.nonzero(data_id < Ns)[0][0]
+ grp = grps[i]
+ N = data_id - Ns[i]
+ name = sorted(grp.keys())[N]
+
+ group_name = grp.name + "/" + name
+
+ if mem == "RAM":
+ grp_data = f[group_name]
+ data = get_data_from_grp(
+ grp_data, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+
+ if mem == "MEMMAP":
+ f = h5py.File(fp, "r")
+ grp_data = f[group_name]
+ data = get_data_from_grp(
+ grp_data, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+
+ return data
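+ # Indexing note for get_data_from_int: Ns holds cumulative object counts, so
+ # for e.g. group sizes [3, 2, 4], Ns = [3, 5, 9] and data_id = 4 selects
+ # i = 1 and N = 4 - 5 = -1; the negative index into sorted(grp.keys()) lands
+ # on position data_id - Ns[i-1] within that group, as intended.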
+
+
+def get_data_from_str(fp, tg, data_id, mem="RAM", binfactor=1, bindtype=None):
+ """Accepts a fp to a valid py4DSTEM file and a string specifying data, and returns the data."""
+ assert isinstance(data_id, str)
+ with h5py.File(fp, "r") as f:
+ grp_dc = f[tg + "/data/datacubes/"]
+ grp_ds = f[tg + "/data/diffractionslices/"]
+ grp_rs = f[tg + "/data/realslices/"]
+ grp_pl = f[tg + "/data/pointlists/"]
+ grp_pla = f[tg + "/data/pointlistarrays/"]
+ grps = [grp_dc, grp_ds, grp_rs, grp_pl, grp_pla]
+
+ l_dc = list(grp_dc.keys())
+ l_ds = list(grp_ds.keys())
+ l_rs = list(grp_rs.keys())
+ l_pl = list(grp_pl.keys())
+ l_pla = list(grp_pla.keys())
+ names = l_dc + l_ds + l_rs + l_pl + l_pla
+
+ inds = [i for i, name in enumerate(names) if name == data_id]
+ assert len(inds) != 0, "Error: no data named {} found.".format(data_id)
+ assert len(inds) < 2, "Error: multiple data blocks named {} found.".format(
+ data_id
+ )
+ ind = inds[0]
+
+ Ns = np.cumsum([len(grp.keys()) for grp in grps])
+ i_grp = np.nonzero(ind < Ns)[0][0]
+ grp = grps[i_grp]
+
+ group_name = grp.name + "/" + data_id
+
+ if mem == "RAM":
+ grp_data = f[group_name]
+ data = get_data_from_grp(
+ grp_data, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+
+ # if using MEMMAP, file cannot be accessed from the context manager
+ # or else it will be closed before the data is accessed
+ if mem == "MEMMAP":
+ f = h5py.File(fp, "r")
+ grp_data = f[group_name]
+ data = get_data_from_grp(
+ grp_data, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+
+ return data
+
+
+def get_data_from_list(fp, tg, data_id, mem="RAM", binfactor=1, bindtype=None):
+ """Accepts a fp to a valid py4DSTEM file and a list or tuple specifying data, and returns the data."""
+ assert isinstance(data_id, (list, tuple))
+ assert all([isinstance(d, (int, str)) for d in data_id])
+ data = []
+ for el in data_id:
+ if isinstance(el, int):
+ data.append(
+ get_data_from_int(
+ fp, tg, data_id=el, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+ )
+ elif isinstance(el, str):
+ data.append(
+ get_data_from_str(
+ fp, tg, data_id=el, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+ )
+ else:
+ raise Exception("Data must be specified with strings or integers only.")
+ return data
+
+
+def get_data_from_grp(g, mem="RAM", binfactor=1, bindtype=None):
+ """Accepts an h5py Group corresponding to a single dataobject in an open, correctly formatted H5 file,
+ and returns a py4DSTEM DataObject.
+ """
+ dtype = g.name.split("/")[-2]
+ if dtype == "datacubes":
+ return get_datacube_from_grp(g, mem, binfactor, bindtype)
+ elif dtype == "counted_datacubes":
+ return get_counted_datacube_from_grp(g)
+ elif dtype == "diffractionslices":
+ return get_diffractionslice_from_grp(g)
+ elif dtype == "realslices":
+ return get_realslice_from_grp(g)
+ elif dtype == "pointlists":
+ return get_pointlist_from_grp(g)
+ elif dtype == "pointlistarrays":
+ return get_pointlistarray_from_grp(g)
+ else:
+ raise Exception("Unrecognized data object type {}".format(dtype))
+
+
+def get_datacube_from_grp(g, mem="RAM", binfactor=1, bindtype=None):
+ """Accepts an h5py Group corresponding to a single datacube in an open, correctly formatted H5 file,
+ and returns a DataCube.
+ """
+ assert binfactor == 1, "Bin on load is currently unsupported for EMD files."
+
+ if (mem, binfactor) == ("RAM", 1):
+ data = np.array(g["data"])
+ elif (mem, binfactor) == ("MEMMAP", 1):
+ data = g["data"]
+
+ name = g.name.split("/")[-1]
+ return DataCube(data=data, name=name)
+
+
+def get_counted_datacube_from_grp(g):
+ """Accepts an h5py Group corresponding to a counted datacube in an open, correctly formatted H5 file,
+ and returns a CountedDataCube.
+ """
+ return # TODO
+
+
+def get_diffractionslice_from_grp(g):
+ """Accepts an h5py Group corresponding to a diffractionslice in an open, correctly formatted H5 file,
+ and returns a DiffractionSlice.
+ """
+ data = np.array(g["data"])
+ name = g.name.split("/")[-1]
+ # note: DiffractionSlice infers its shape from data, as in the other readers in this module
+ if len(data.shape) == 2:
+ return DiffractionSlice(data=data, name=name)
+ else:
+ lbls = g["dim3"]
+ if "S" in lbls.dtype.str: # Checks if dim3 is composed of fixed width C strings
+ lbls = lbls.astype("S64")[:]
+ lbls = [lbl.decode("UTF-8") for lbl in lbls]
+ return DiffractionSlice(data=data, name=name, slicelabels=lbls)
+
+
+def get_realslice_from_grp(g):
+ """Accepts an h5py Group corresponding to a realslice in an open, correctly formatted H5 file,
+ and returns a RealSlice.
+ """
+ data = np.array(g["data"])
+ name = g.name.split("/")[-1]
+ if len(data.shape) == 2:
+ return RealSlice(data=data, name=name)
+ else:
+ lbls = g["dim3"]
+ if "S" in lbls.dtype.str: # Checks if dim3 is composed of fixed width C strings
+ lbls = lbls.astype("S64")[:]
+ lbls = [lbl.decode("UTF-8") for lbl in lbls]
+ return RealSlice(data=data, name=name, slicelabels=lbls)
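+ # Decoding sketch for the dim3 handling above (pure numpy; fixed-width C
+ # strings come back as bytes, which must be decoded to str):
+ #
+ # import numpy as np
+ # lbls = np.array([b"slice_a", b"slice_b"], dtype="S64")
+ # [lbl.decode("UTF-8") for lbl in lbls] # -> ["slice_a", "slice_b"]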
+
+
+def get_pointlist_from_grp(g):
+ """Accepts an h5py Group corresponding to a pointlist in an open, correctly formatted H5 file,
+ and returns a PointList.
+ """
+ name = g.name.split("/")[-1]
+ coordinates = []
+ coord_names = list(g.keys())
+ length = len(g[coord_names[0] + "/data"])
+ if length == 0:
+ for coord in coord_names:
+ coordinates.append((coord, None))
+ else:
+ for coord in coord_names:
+ dtype = type(g[coord + "/data"][0])
+ coordinates.append((coord, dtype))
+ data = np.zeros(length, dtype=coordinates)
+ for coord in coord_names:
+ data[coord] = np.array(g[coord + "/data"])
+ return PointList(data=data, name=name)
+
+
+def get_pointlistarray_from_grp(g):
+ """Accepts an h5py Group corresponding to a pointlistarray in an open, correctly formatted H5 file,
+ and returns a PointListArray.
+ """
+ name = g.name.split("/")[-1]
+ ar = np.array(g["data"])
+ shape = ar.shape
+ coord_names = ar[0, 0].dtype.names
+ N = len(coord_names)
+ coord_types = [ar[0, 0].dtype.fields[fname][0] for fname in coord_names]
+ coordinates = [(coord_names[i], coord_types[i]) for i in range(N)]
+ pla = PointListArray(dtype=coordinates, shape=shape, name=name)
+ for i, j in tqdmnd(
+ range(shape[0]),
+ range(shape[1]),
+ desc="Reading PointListArray",
+ unit="PointList",
+ ):
+ pla.get_pointlist(i, j).data = ar[i, j]
+ return pla
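+ # Storage note: v0.7/v0.8 files hold a PointListArray as a single structured
+ # "data" dataset indexed [i, j], whereas the v0.5/v0.6 reader above walks one
+ # "i_j" subgroup per scan position.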
diff --git a/py4DSTEM/io/legacy/legacy12/read_v0_9.py b/py4DSTEM/io/legacy/legacy12/read_v0_9.py
new file mode 100644
index 000000000..0cf186ffd
--- /dev/null
+++ b/py4DSTEM/io/legacy/legacy12/read_v0_9.py
@@ -0,0 +1,402 @@
+# Reader for py4DSTEM v0.9 - v0.11 files
+
+import h5py
+import numpy as np
+from os.path import splitext, exists
+from py4DSTEM.io.legacy.read_utils import (
+ is_py4DSTEM_file,
+ get_py4DSTEM_topgroups,
+ get_py4DSTEM_version,
+ version_is_geq,
+)
+from py4DSTEM.io.legacy.legacy12.read_utils_v0_9 import get_py4DSTEM_dataobject_info
+from emdfile import PointList, PointListArray
+from py4DSTEM.data import (
+ DiffractionSlice,
+ RealSlice,
+)
+from py4DSTEM.datacube import DataCube
+from emdfile import tqdmnd
+
+
+def read_v0_9(fp, **kwargs):
+ """
+ File reader for files written by py4DSTEM v0.9-0.11. Precise behavior is determined by which
+ arguments are passed -- see below.
+
+ Accepts:
+ filepath str or Path When passed a filepath only, this function checks if the path
+ points to a valid py4DSTEM file, then prints its contents to screen.
+ data_id int/str/list Specifies which data to load. Use integers to specify the
+ data index, or strings to specify data names. A list or
+ tuple returns a list of DataObjects. Returns the specified data.
+ topgroup str Strictly, a py4DSTEM file is considered to be
+ everything inside a toplevel subdirectory within the
+ HDF5 file, so that if desired one can place many py4DSTEM
+ files inside a single H5. In this case, when loading
+ data, the topgroup argument is passed to indicate which
+ py4DSTEM file to load. If an H5 containing multiple
+ py4DSTEM files is passed without a topgroup specified,
+ the topgroup names are printed to screen.
+ metadata bool If True, returns a dictionary with the file metadata.
+ log bool If True, writes the processing log to a plaintext file
+ called splitext(fp)[0]+'.log'.
+ mem str Only used if a single DataCube is loaded. In this case, mem
+ specifies how the data should be stored; must be "RAM"
+ or "MEMMAP". See docstring for py4DSTEM.file.io.read. Default
+ is "RAM".
+ binfactor int Only used if a single DataCube is loaded. In this case,
+ a binfactor of > 1 causes the data to be binned by this amount
+ as it's loaded.
+ dtype dtype Used when binning data, ignored otherwise. Defaults to whatever
+ the type of the raw data is, to avoid enlarging data size. May be
+ useful to avoid 'wraparound' errors.
+
+ Returns:
+ data,md The function always returns a length 2 tuple corresponding
+ to data and md. If no input arguments with return values (i.e.
+ data, metadata) are passed, these are None. Otherwise, their return
+ values are as described above. E.g. passing data=[0,1,2],metadata=True
+ will return a length two tuple, the first element being a list of 3
+ DataObject instances and the second a MetaData instance.
+ """
+ assert exists(fp), "Error: specified filepath does not exist"
+ assert is_py4DSTEM_file(
+ fp
+ ), "Error: {} isn't recognized as a py4DSTEM file.".format(fp)
+
+ # For HDF5 files containing multiple valid EMD type 2 files, disambiguate desired data
+ tgs = get_py4DSTEM_topgroups(fp)
+ if "topgroup" in kwargs.keys():
+ tg = kwargs["topgroup"]
+ assert tg in tgs, "Error: specified topgroup, {}, not found.".format(tg)
+ else:
+ if len(tgs) == 1:
+ tg = tgs[0]
+ else:
+ print(
+ "Multiple topgroups detected. Please specify one by passing the 'topgroup' keyword argument."
+ )
+ print("")
+ print("Topgroups found:")
+ for tg in tgs:
+ print(tg)
+ return None, None
+
+ version = get_py4DSTEM_version(fp, tg)
+ assert version_is_geq(version, (0, 9, 0)) and not version_is_geq(
+ version, (0, 12, 0)
+ ), "File must be v0.9-0.11"
+
+ # If metadata is requested
+ if "metadata" in kwargs.keys():
+ if kwargs["metadata"]:
+ raise NotImplementedError("Legacy metadata reader missing...")
+ # return metadata_from_h5(fp, tg)
+
+ # If data is requested
+ elif "data_id" in kwargs.keys():
+ data_id = kwargs["data_id"]
+ assert isinstance(
+ data_id, (int, np.int_, str, list, tuple)
+ ), "Error: data must be specified with strings or integers only."
+ if not isinstance(data_id, (int, np.int_, str)):
+ assert all(
+ [isinstance(d, (int, np.int_, str)) for d in data_id]
+ ), "Error: data must be specified with strings or integers only."
+
+ # Parse optional arguments
+ if "mem" in kwargs.keys():
+ mem = kwargs["mem"]
+ assert mem in ("RAM", "MEMMAP")
+ else:
+ mem = "RAM"
+ if "binfactor" in kwargs.keys():
+ binfactor = kwargs["binfactor"]
+ assert isinstance(binfactor, (int, np.int_))
+ else:
+ binfactor = 1
+ if "dtype" in kwargs.keys():
+ bindtype = kwargs["dtype"]
+ assert isinstance(bindtype, type)
+ else:
+ bindtype = None
+
+ return get_data(fp, tg, data_id, mem, binfactor, bindtype)
+
+ # If no data is requested
+ else:
+ print_py4DSTEM_file(fp, tg)
+ return
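+ # Usage sketch (hypothetical filename and data names; assumes a single-topgroup
+ # v0.9-v0.11 file):
+ #
+ # from py4DSTEM.io.legacy.legacy12.read_v0_9 import read_v0_9
+ # read_v0_9("legacy.h5") # print the file contents
+ # data = read_v0_9("legacy.h5", data_id=["datacube_0", 2]) # mixed names/indices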
+
+
+############ Helper functions ############
+
+
+def print_py4DSTEM_file(filepath, tg):
+ """Accepts a filepath to a valid py4DSTEM file and prints to screen the file contents."""
+ info = get_py4DSTEM_dataobject_info(filepath, tg)
+
+ version = get_py4DSTEM_version(filepath, tg)
+ print(f"py4DSTEM file version {version[0]}.{version[1]}.{version[2]}")
+
+ print("{:10}{:18}{:24}{:54}".format("Index", "Type", "Shape", "Name"))
+ print("{:10}{:18}{:24}{:54}".format("-----", "----", "-----", "----"))
+ for el in info:
+ print(
+ " {:8}{:18}{:24}{:54}".format(
+ str(el["index"]), str(el["type"]), str(el["shape"]), str(el["name"])
+ )
+ )
+ return
+
+
+def get_data(filepath, tg, data_id, mem="RAM", binfactor=1, bindtype=None):
+ """Accepts a filepath to a valid py4DSTEM file and an int/str/list specifying data, and returns the data."""
+ if isinstance(data_id, (int, np.int_)):
+ return get_data_from_int(
+ filepath, tg, data_id, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+ elif isinstance(data_id, str):
+ return get_data_from_str(
+ filepath, tg, data_id, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+ else:
+ return get_data_from_list(
+ filepath, tg, data_id, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+
+
+def get_data_from_int(filepath, tg, data_id, mem="RAM", binfactor=1, bindtype=None):
+ """Accepts a filepath to a valid py4DSTEM file and an integer specifying data, and returns the data."""
+ assert isinstance(data_id, (int, np.int_))
+ with h5py.File(filepath, "r") as f:
+ grp_dc = f[tg + "/data/datacubes/"]
+ grp_cdc = f[tg + "/data/counted_datacubes/"]
+ grp_ds = f[tg + "/data/diffractionslices/"]
+ grp_rs = f[tg + "/data/realslices/"]
+ grp_pl = f[tg + "/data/pointlists/"]
+ grp_pla = f[tg + "/data/pointlistarrays/"]
+ grps = [grp_dc, grp_cdc, grp_ds, grp_rs, grp_pl, grp_pla]
+
+ Ns = np.cumsum([len(grp.keys()) for grp in grps])
+ i = np.nonzero(data_id < Ns)[0][0]
+ grp = grps[i]
+ N = data_id - Ns[i]
+ name = sorted(grp.keys())[N]
+
+ group_name = grp.name + "/" + name
+
+ if mem == "RAM":
+ grp_data = f[group_name]
+ data = get_data_from_grp(
+ grp_data, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+
+ if mem == "MEMMAP":
+ f = h5py.File(filepath, "r")
+ grp_data = f[group_name]
+ data = get_data_from_grp(
+ grp_data, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+
+ return data
+
+
+def get_data_from_str(filepath, tg, data_id, mem="RAM", binfactor=1, bindtype=None):
+ """Accepts a filepath to a valid py4DSTEM file and a string specifying data, and returns the data."""
+ assert isinstance(data_id, str)
+ with h5py.File(filepath, "r") as f:
+ grp_dc = f[tg + "/data/datacubes/"]
+ grp_cdc = f[tg + "/data/counted_datacubes/"]
+ grp_ds = f[tg + "/data/diffractionslices/"]
+ grp_rs = f[tg + "/data/realslices/"]
+ grp_pl = f[tg + "/data/pointlists/"]
+ grp_pla = f[tg + "/data/pointlistarrays/"]
+ grps = [grp_dc, grp_cdc, grp_ds, grp_rs, grp_pl, grp_pla]
+
+ l_dc = list(grp_dc.keys())
+ l_cdc = list(grp_cdc.keys())
+ l_ds = list(grp_ds.keys())
+ l_rs = list(grp_rs.keys())
+ l_pl = list(grp_pl.keys())
+ l_pla = list(grp_pla.keys())
+ names = l_dc + l_cdc + l_ds + l_rs + l_pl + l_pla
+
+ inds = [i for i, name in enumerate(names) if name == data_id]
+ assert len(inds) != 0, "Error: no data named {} found.".format(data_id)
+ assert len(inds) < 2, "Error: multiple data blocks named {} found.".format(
+ data_id
+ )
+ ind = inds[0]
+
+ Ns = np.cumsum([len(grp.keys()) for grp in grps])
+ i_grp = np.nonzero(ind < Ns)[0][0]
+ grp = grps[i_grp]
+ group_name = grp.name + "/" + data_id
+
+ if mem == "RAM":
+ grp_data = f[group_name]
+ data = get_data_from_grp(
+ grp_data, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+
+ # if using MEMMAP, file cannot be accessed from the context manager
+ # or else it will be closed before the data is accessed
+ if mem == "MEMMAP":
+ f = h5py.File(filepath, "r")
+ grp_data = f[group_name]
+ data = get_data_from_grp(
+ grp_data, mem=mem, binfactor=binfactor, bindtype=bindtype
+ )
+
+ return data
+
+
+def get_data_from_list(filepath, tg, data_id, mem="RAM", binfactor=1, bindtype=None):
+ """Accepts a filepath to a valid py4DSTEM file and a list or tuple specifying data, and returns the data."""
+ assert isinstance(data_id, (list, tuple))
+ assert all([isinstance(d, (int, np.int_, str)) for d in data_id])
+ data = []
+ for el in data_id:
+ if isinstance(el, (int, np.int_)):
+ data.append(
+ get_data_from_int(
+ filepath,
+ tg,
+ data_id=el,
+ mem=mem,
+ binfactor=binfactor,
+ bindtype=bindtype,
+ )
+ )
+ elif isinstance(el, str):
+ data.append(
+ get_data_from_str(
+ filepath,
+ tg,
+ data_id=el,
+ mem=mem,
+ binfactor=binfactor,
+ bindtype=bindtype,
+ )
+ )
+ else:
+ raise Exception("Data must be specified with strings or integers only.")
+ return data
+
+
+def get_data_from_grp(g, mem="RAM", binfactor=1, bindtype=None):
+ """Accepts an h5py Group corresponding to a single dataobject in an open, correctly formatted H5 file,
+ and returns a py4DSTEM DataObject.
+ """
+ dtype = g.name.split("/")[-2]
+ if dtype == "datacubes":
+ return get_datacube_from_grp(g, mem, binfactor, bindtype)
+ elif dtype == "counted_datacubes":
+ return get_counted_datacube_from_grp(g)
+ elif dtype == "diffractionslices":
+ return get_diffractionslice_from_grp(g)
+ elif dtype == "realslices":
+ return get_realslice_from_grp(g)
+ elif dtype == "pointlists":
+ return get_pointlist_from_grp(g)
+ elif dtype == "pointlistarrays":
+ return get_pointlistarray_from_grp(g)
+ else:
+ raise Exception("Unrecognized data object type {}".format(dtype))
+
+
+def get_datacube_from_grp(g, mem="RAM", binfactor=1, bindtype=None):
+ """Accepts an h5py Group corresponding to a single datacube in an open, correctly formatted H5 file,
+ and returns a DataCube.
+ """
+ assert binfactor == 1, "Bin on load is currently unsupported for EMD files."
+
+ if (mem, binfactor) == ("RAM", 1):
+ data = np.array(g["data"])
+ elif (mem, binfactor) == ("MEMMAP", 1):
+ data = g["data"]
+
+ name = g.name.split("/")[-1]
+ return DataCube(data=data, name=name)
+
+
+def get_counted_datacube_from_grp(g):
+ """Accepts an h5py Group corresponding to a counted datacube in an open, correctly formatted H5 file,
+ and returns a CountedDataCube.
+ """
+ raise NotImplementedError("CountedDataCubes are not available in py4DSTEM v0.13")
+
+
+def get_diffractionslice_from_grp(g):
+ """Accepts an h5py Group corresponding to a diffractionslice in an open, correctly formatted H5 file,
+ and returns a DiffractionSlice.
+ """
+ data = np.array(g["data"])
+ name = g.name.split("/")[-1]
+ if len(data.shape) == 2:
+ return DiffractionSlice(data=data, name=name)
+ else:
+ lbls = g["dim3"]
+ if "S" in lbls.dtype.str: # Checks if dim3 is composed of fixed width C strings
+ lbls = lbls.astype("S64")[:]
+ lbls = [lbl.decode("UTF-8") for lbl in lbls]
+ return DiffractionSlice(data=data, name=name, slicelabels=lbls)
+
+
+def get_realslice_from_grp(g):
+ """Accepts an h5py Group corresponding to a realslice in an open, correctly formatted H5 file,
+ and returns a RealSlice.
+ """
+ data = np.array(g["data"])
+ name = g.name.split("/")[-1]
+ if len(data.shape) == 2:
+ return RealSlice(data=data, name=name)
+ else:
+ lbls = g["dim3"]
+ if "S" in lbls.dtype.str: # Checks if dim3 is composed of fixed width C strings
+ lbls = lbls.astype("S64")[:]
+ lbls = [lbl.decode("UTF-8") for lbl in lbls]
+ return RealSlice(data=data, name=name, slicelabels=lbls)
+
+
+def get_pointlist_from_grp(g):
+ """Accepts an h5py Group corresponding to a pointlist in an open, correctly formatted H5 file,
+ and returns a PointList.
+ """
+ name = g.name.split("/")[-1]
+ coordinates = []
+ coord_names = list(g.keys())
+ length = len(g[coord_names[0] + "/data"])
+ if length == 0:
+ for coord in coord_names:
+ coordinates.append((coord, None))
+ else:
+ for coord in coord_names:
+ dtype = type(g[coord + "/data"][0])
+ coordinates.append((coord, dtype))
+ data = np.zeros(length, dtype=coordinates)
+ for coord in coord_names:
+ data[coord] = np.array(g[coord + "/data"])
+ return PointList(data=data, name=name)
+
+
+def get_pointlistarray_from_grp(g):
+ """Accepts an h5py Group corresponding to a pointlistarray in an open, correctly formatted H5 file,
+ and returns a PointListArray.
+ """
+ name = g.name.split("/")[-1]
+ dset = g["data"]
+ shape = dset.shape
+ coordinates = dset[0, 0].dtype
+ pla = PointListArray(dtype=coordinates, shape=shape, name=name)
+ for i, j in tqdmnd(
+ shape[0], shape[1], desc="Reading PointListArray", unit="PointList"
+ ):
+ pla.get_pointlist(i, j).data = dset[i, j]
+ return pla
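+ # Layout sketch for the vlen storage read above (an assumed minimal example;
+ # it mirrors PointListArray_to_h5 in the v13 io module later in this diff):
+ #
+ # import h5py
+ # import numpy as np
+ # dt = np.dtype([("qx", float), ("qy", float)])
+ # with h5py.File("scratch.h5", "w") as f:
+ # dset = f.create_dataset("data", (2, 2), dtype=h5py.special_dtype(vlen=dt))
+ # dset[0, 0] = np.zeros(3, dtype=dt) # one ragged record array per cell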
diff --git a/py4DSTEM/io/legacy/legacy13/__init__.py b/py4DSTEM/io/legacy/legacy13/__init__.py
new file mode 100644
index 000000000..b0f919cc0
--- /dev/null
+++ b/py4DSTEM/io/legacy/legacy13/__init__.py
@@ -0,0 +1,13 @@
+from .v13_emd_classes import Root, Metadata, Array, PointList, PointListArray
+from .v13_py4dstem_classes import (
+ Calibration,
+ DataCube,
+ DiffractionSlice,
+ VirtualDiffraction,
+ RealSlice,
+ VirtualImage,
+ Probe,
+ QPoints,
+ BraggVectors,
+)
+from .v13_to_14 import v13_to_14
diff --git a/py4DSTEM/io/legacy/legacy13/v13_emd_classes/__init__.py b/py4DSTEM/io/legacy/legacy13/v13_emd_classes/__init__.py
new file mode 100644
index 000000000..0c2b530e3
--- /dev/null
+++ b/py4DSTEM/io/legacy/legacy13/v13_emd_classes/__init__.py
@@ -0,0 +1,7 @@
+from py4DSTEM.io.legacy.legacy13.v13_emd_classes.tree import *
+from py4DSTEM.io.legacy.legacy13.v13_emd_classes.root import *
+from py4DSTEM.io.legacy.legacy13.v13_emd_classes.metadata import *
+from py4DSTEM.io.legacy.legacy13.v13_emd_classes.array import *
+from py4DSTEM.io.legacy.legacy13.v13_emd_classes.pointlist import *
+from py4DSTEM.io.legacy.legacy13.v13_emd_classes.pointlistarray import *
+from py4DSTEM.io.legacy.legacy13.v13_emd_classes.io import *
diff --git a/py4DSTEM/io/legacy/legacy13/v13_emd_classes/array.py b/py4DSTEM/io/legacy/legacy13/v13_emd_classes/array.py
new file mode 100644
index 000000000..a5192ffa6
--- /dev/null
+++ b/py4DSTEM/io/legacy/legacy13/v13_emd_classes/array.py
@@ -0,0 +1,498 @@
+# Defines the Array class, which stores any N-dimensional array-like data.
+# Implements the EMD file standard - https://emdatasets.com/format
+
+from typing import Optional, Union
+import numpy as np
+import h5py
+from numbers import Number
+
+from py4DSTEM.io.legacy.legacy13.v13_emd_classes.tree import Tree
+from py4DSTEM.io.legacy.legacy13.v13_emd_classes.metadata import Metadata
+
+
+class Array:
+ """
+ A class which stores any N-dimensional array-like data, plus basic metadata:
+ a name and units, as well as calibrations for each axis of the array, and names
+ and units for those axis calibrations.
+ In the simplest usage, only a data array is passed:
+ >>> ar = Array(np.ones((20,20,256,256)))
+ will create an array instance whose data is the numpy array passed, and with
+ automatically populated dimension calibrations in units of pixels.
+ Additional arguments may be passed to populate the object metadata:
+ >>> ar = Array(
+ >>> np.ones((20,20,256,256)),
+ >>> name = 'test_array',
+ >>> units = 'intensity',
+ >>> dims = [
+ >>> [0,5],
+ >>> [0,5],
+ >>> [0,0.01],
+ >>> [0,0.01]
+ >>> ],
+ >>> dim_units = [
+ >>> 'nm',
+ >>> 'nm',
+ >>> 'A^-1',
+ >>> 'A^-1'
+ >>> ],
+ >>> dim_names = [
+ >>> 'rx',
+ >>> 'ry',
+ >>> 'qx',
+ >>> 'qy'
+ >>> ],
+ >>> )
+ will create an array with a name and units for its data, where its first two
+ dimensions are in units of nanometers, have pixel sizes of 5nm, and are
+ described by the handles 'rx' and 'ry', and where its last two dimensions
+ are in units of inverse Angstroms, have pixel sizes of 0.01A^-1, and are
+ described by the handles 'qx' and 'qy'.
+ Arrays in which the length of each pixel is non-constant are also
+ supported. For instance,
+ >>> x = np.logspace(0,1,100)
+ >>> y = np.sin(x)
+ >>> ar = Array(
+ >>> y,
+ >>> dims = [
+ >>> x
+ >>> ]
+ >>> )
+ generates an array representing the values of the sine function sampled
+ 100 times along a logarithmic interval from 1 to 10. In this example,
+ this data could then be plotted with, e.g.
+ >>> plt.scatter(ar.dims[0], ar.data)
+ If the `slicelabels` keyword is passed, the first N-1 dimensions of the
+ array are treated normally, while the final dimension is used to represent
+ distinct arrays which share a common shape and set of dim vectors. Thus
+ >>> ar = Array(
+ >>> np.ones((50,50,4)),
+ >>> name = 'test_array_stack',
+ >>> units = 'intensity',
+ >>> dims = [
+ >>> [0,2],
+ >>> [0,2]
+ >>> ],
+ >>> dim_units = [
+ >>> 'nm',
+ >>> 'nm'
+ >>> ],
+ >>> dim_names = [
+ >>> 'rx',
+ >>> 'ry'
+ >>> ],
+ >>> slicelabels = [
+ >>> 'a',
+ >>> 'b',
+ >>> 'c',
+ >>> 'd'
+ >>> ]
+ >>> )
+ will generate a single Array instance containing 4 arrays which each have
+ a shape (50,50) and a common set of dim vectors ['rx','ry'], and which
+ can be indexed into with the names assigned in `slicelabels` using
+ >>> ar.get_slice('a')
+ which will return a 2D (non-stack-like) Array instance with shape (50,50)
+ and the dims assigned above. The Array attribute .rank is equal to the
+ number of dimensions for a non-stack-like Array, and is equal to N-1
+ for stack-like arrays.
+ """
+
+ def __init__(
+ self,
+ data: np.ndarray,
+ name: Optional[str] = "array",
+ units: Optional[str] = "",
+ dims: Optional[list] = None,
+ dim_names: Optional[list] = None,
+ dim_units: Optional[list] = None,
+ slicelabels=None,
+ ):
+ """
+ Accepts:
+ data (np.ndarray): the data
+ name (str): the name of the Array
+ units (str): units for the pixel values
+ dims (variable): calibration vectors for each of the axes of the data
+ array. Valid values for each element of the list are None,
+ a number, a 2-element list/array, or an M-element list/array
+ where M is the length of that axis of the data array. If None is passed, the dim will be
+ populated with integer values starting at 0 and its units will
+ be set to pixels. If a number is passed, the dim is populated
+ with a vector beginning at zero and increasing linearly by this
+ step size. If a 2-element list/array is passed, the dim is
+ populated with a linear vector with these two numbers as the first
+ two elements. If a list/array of length M is passed, this is used
+ as the dim vector, (and must therefore match this dimension's
+ length). If dims receives a list of fewer than N arguments for an
+ N-dimensional data array, the extra dimensions are populated as if
+ None were passed, using integer pixel values. If the `dims`
+ parameter is not passed, all dim vectors are populated this way.
+ dim_units (list): the units for the calibration dim vectors. If
+ nothing is passed, dims vectors which have been populated
+ automatically with integers corresponding to pixel numbers
+ will be assigned units of 'pixels', and any other dim vectors
+ will be assigned units of 'unknown'. If a list with length <
+ the array dimensions, the passed values are assumed to apply
+ to the first N dimensions, and the remaining values are
+ populated with 'pixels' or 'unknown' as above.
+ dim_names (list): labels for each axis of the data array. Values
+ which are not passed, following the same logic as described
+ above, will be autopopulated with the name "dim#" where #
+ is the axis number.
+ slicelabels (None or True or list): if not None, must be True or a
+ list of strings, indicating a "stack-like" array. In this case,
+ the first N-1 dimensions of the array are treated normally, in
+ the sense of populating dims, dim_names, and dim_units, while the
+ final dimension is treated distinctly: it indexes into
+ distinct arrays which share a set of dimension attributes, and
+ can be sliced into using the string labels from the `slicelabels`
+ list, with the syntax array['label'] or array.get_slice('label').
+ If `slicelabels` is `True` or is a list with length less than the
+ final dimension length, unassigned dimensions are autopopulated
+ with labels `array{i}`. The flag array.is_stack is set to True
+ and the array.rank attribute is set to N-1.
+ Returns:
+ A new Array instance
+ """
+ self.data = data
+ self.name = name
+ self.units = units
+ self.dims = dims
+ self.dim_names = dim_names
+ self.dim_units = dim_units
+
+ self.tree = Tree()
+ if not hasattr(self, "_metadata"):
+ self._metadata = {}
+
+ ## Handle array stacks
+
+ if slicelabels is None:
+ self.is_stack = False
+
+ else:
+ self.is_stack = True
+
+ # Populate labels
+ if slicelabels is True:
+ slicelabels = [f"array{i}" for i in range(self.depth)]
+ elif len(slicelabels) < self.depth:
+ slicelabels = np.concatenate(
+ (
+ slicelabels,
+ [f"array{i}" for i in range(len(slicelabels), self.depth)],
+ )
+ )
+ else:
+ slicelabels = slicelabels[: self.depth]
+ slicelabels = Labels(slicelabels)
+
+ self.slicelabels = slicelabels
+
+ ## Set dim vectors
+
+ dim_in_pixels = np.zeros(
+ self.rank, dtype=bool
+ ) # flag to help assign dim names and units
+ # if none were passed
+ if self.dims is None:
+ self.dims = [self._unpack_dim(1, self.shape[n]) for n in range(self.rank)]
+ dim_in_pixels[:] = True
+
+ # if some but not all were passed
+ elif len(self.dims) < self.rank:
+ _dims = self.dims
+ N = len(_dims)
+ self.dims = []
+ for n in range(N):
+ dim = self._unpack_dim(_dims[n], self.shape[n])
+ self.dims.append(dim)
+ for n in range(N, self.rank):
+ self.dims.append(self._unpack_dim(1, self.shape[n]))
+ dim_in_pixels[n] = True
+
+ # if all were passed
+ elif len(self.dims) == self.rank:
+ _dims = self.dims
+ self.dims = []
+ for n in range(self.rank):
+ dim = self._unpack_dim(_dims[n], self.shape[n])
+ self.dims.append(dim)
+
+ # otherwise
+ else:
+ raise Exception(
+ f"too many dim vectors were passed - expected {self.rank}, received {len(self.dims)}"
+ )
+
+ ## set dim vector names
+
+ # if none were passed
+ if self.dim_names is None:
+ self.dim_names = [f"dim{n}" for n in range(self.rank)]
+
+ # if some but not all were passed
+ elif len(self.dim_names) < self.rank:
+ N = len(self.dim_names)
+ self.dim_names = [name for name in self.dim_names] + [
+ f"dim{n}" for n in range(N, self.rank)
+ ]
+
+ # if all were passed
+ elif len(self.dim_names) == self.rank:
+ pass
+
+ # otherwise
+ else:
+ raise Exception(
+ f"too many dim names were passed - expected {self.rank}, received {len(self.dim_names)}"
+ )
+
+ ## set dim vector units
+
+ # if none were passed
+ if self.dim_units is None:
+ self.dim_units = [["unknown", "pixels"][int(i)] for i in dim_in_pixels]
+
+ # if some but not all were passed
+ elif len(self.dim_units) < self.rank:
+ N = len(self.dim_units)
+ self.dim_units = [units for units in self.dim_units] + [
+ ["unknown", "pixels"][int(dim_in_pixels[i])]
+ for i in range(N, self.rank)
+ ]
+
+ # if all were passed
+ elif len(self.dim_units) == self.rank:
+ pass
+
+ # otherwise
+ else:
+ raise Exception(
+ f"too many dim units were passed - expected {self.rank}, received {len(self.dim_units)}"
+ )
+
+ # Shape properties
+
+ @property
+ def shape(self):
+ if not self.is_stack:
+ return self.data.shape
+ else:
+ return self.data.shape[:-1]
+
+ @property
+ def depth(self):
+ if not self.is_stack:
+ return 0
+ else:
+ return self.data.shape[-1]
+
+ @property
+ def rank(self):
+ if not self.is_stack:
+ return self.data.ndim
+ else:
+ return self.data.ndim - 1
+
+ ## Slicing
+
+ def get_slice(self, label, name=None):
+ idx = self.slicelabels._dict[label]
+ return Array(
+ data=self.data[..., idx],
+ name=name if name is not None else self.name + "_" + label,
+ # dims, dim_units and dim_names already exclude the stack axis (see __init__),
+ # and units is a plain string, so none of these should be sliced
+ units=self.units,
+ dims=self.dims,
+ dim_units=self.dim_units,
+ dim_names=self.dim_names,
+ )
+
+ def __getitem__(self, x):
+ if isinstance(x, str):
+ return self.get_slice(x)
+ elif isinstance(x, tuple) and isinstance(x[0], str):
+ return self.get_slice(x[0])[x[1:]]
+ else:
+ return self.data[x]
+
+ ## Dim vectors
+
+ def set_dim(
+ self,
+ n: int,
+ dim: Union[list, np.ndarray],
+ units: Optional[str] = None,
+ name: Optional[str] = None,
+ ):
+ """
+ Sets the n'th dim vector, using `dim` as described in the Array
+ documentation. If `units` and/or `name` are passed, sets these
+ values for the n'th dim vector.
+ Accepts:
+ n (int): specifies which dim vector
+ dim (number, list, or array): a single number specifies the step
+ size of a dim vector starting at 0; a 2-element list/array gives
+ its first two values, from which the vector is linearly extended;
+ a list/array whose length equals the n'th axis of the data array
+ is used as-is.
+ units (Optional, str):
+ name: (Optional, str):
+ """
+ length = self.shape[n]
+ _dim = self._unpack_dim(dim, length)
+ self.dims[n] = _dim
+ if units is not None:
+ self.dim_units[n] = units
+ if name is not None:
+ self.dim_names[n] = name
+
+ @staticmethod
+ def _unpack_dim(dim, length):
+ """
+ Given a dim vector as passed at instantiation and the expected length
+ of this dimension of the array, this function checks the passed dim
+ vector length, and checks the dim vector type. For number-like dim-
+ vectors:
+ -if it is a number, turns it into the list [0,number] and proceeds
+ as below
+ -if it has length 2, linearly extends the vector to its full length
+ -if it has length `length`, returns the vector as is
+ -if it has any other length, raises an Exception.
+ For string-like dim vectors, the length must match the array dimension
+ length.
+ Accepts:
+ dim (list or array)
+ length (int)
+ Returns
+ the unpacked dim vector
+ """
+ # Expand single numbers
+ if isinstance(dim, Number):
+ dim = [0, dim]
+
+ N = len(dim)
+
+ # for string dimensions:
+ if not isinstance(dim[0], Number):
+ assert (
+ N == length
+ ), f"For non-numerical dims, the dim vector length must match the array dimension length. Recieved a dim vector of length {N} for an array dimension length of {length}."
+
+ # For number-like dimensions:
+ if N == length:
+ return dim
+ elif N == 2:
+ start, step = dim[0], dim[1] - dim[0]
+ stop = start + step * length
+ return np.arange(start, stop, step)
+ else:
+ raise Exception(
+ f"dim vector length must be either 2 or equal to the length of the corresponding array dimension; dim vector length was {N} and the array dimension length was {length}"
+ )
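+ # Behavior sketch for _unpack_dim: a scalar 5 becomes [0, 5] and is expanded
+ # to np.arange(0, 5 * length, 5); a pair [10, 12] expands to a linear vector
+ # starting at 10 with step 2; a full-length vector passes through unchanged.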
+
+ def _dim_is_linear(self, dim, length):
+ """
+ Returns True if a dim is linear, else returns False
+ """
+ dim_expanded = self._unpack_dim(dim[:2], length)
+ return np.array_equal(dim, dim_expanded)
+
+ # set up metadata property
+
+ @property
+ def metadata(self):
+ return self._metadata
+
+ @metadata.setter
+ def metadata(self, x):
+ assert isinstance(x, Metadata)
+ self._metadata[x.name] = x
+
+ ## Representation to standard output
+
+ def __repr__(self):
+ if not self.is_stack:
+ space = " " * len(self.__class__.__name__) + " "
+ string = f"{self.__class__.__name__}( A {self.rank}-dimensional array of shape {self.shape} called '{self.name}',"
+ string += "\n" + space + "with dimensions:"
+ string += "\n"
+ for n in range(self.rank):
+ # need to handle the edge case of a single value in dims, i.e. line scans, (1,512,256,256)
+ # check there is more than a single probe position
+ if self.dims[n].size < 2:
+ string += (
+ "\n"
+ + space
+ + f" {self.dim_names[n]} = [{self.dims[n][0]}] {self.dim_units[n]}"
+ )
+ else:
+ string += (
+ "\n"
+ + space
+ + f" {self.dim_names[n]} = [{self.dims[n][0]},{self.dims[n][1]},...] {self.dim_units[n]}"
+ )
+ string += "\n)"
+
+ else:
+ space = " " * len(self.__class__.__name__) + " "
+ string = f"{self.__class__.__name__}( A stack of {self.depth} Arrays with {self.rank}-dimensions and shape {self.shape}, called '{self.name}'"
+ string += "\n"
+ string += "\n" + space + "The labels are:"
+ for label in self.slicelabels:
+ string += "\n" + space + f" {label}"
+ string += "\n"
+ string += "\n"
+ string += "\n" + space + "The Array dimensions are:"
+ for n in range(self.rank):
+ # need to handle the edge case of a single value in dims, i.e. line scans, (1,512,256,256)
+ # check there is more than a single probe position
+ if self.dims[n].size < 2:
+ string += (
+ "\n"
+ + space
+ + f" {self.dim_names[n]} = [{self.dims[n][0]}] {self.dim_units[n]}"
+ )
+ else:
+ string += (
+ "\n"
+ + space
+ + f" {self.dim_names[n]} = [{self.dims[n][0]},{self.dims[n][1]},...] {self.dim_units[n]}"
+ )
+ if not self._dim_is_linear(self.dims[n], self.shape[n]):
+ string += " (*non-linear*)"
+ string += "\n)"
+
+ return string
+
+ # HDF5 read/write
+
+ def to_h5(self, group):
+ from py4DSTEM.io.legacy.legacy13.v13_emd_classes.io import Array_to_h5
+
+ Array_to_h5(self, group)
+
+ @staticmethod
+ def from_h5(group):
+ from py4DSTEM.io.legacy.legacy13.v13_emd_classes.io import Array_from_h5
+
+ return Array_from_h5(group)
+
+
+########### END OF CLASS ###########
+
+
+# List subclass for accessing data slices with a dict
+class Labels(list):
+ def __init__(self, x=[]):
+ list.__init__(self, x)
+ self.setup_labels_dict()
+
+ def __setitem__(self, idx, label):
+ label_old = self[idx]
+ del self._dict[label_old]
+ list.__setitem__(self, idx, label)
+ self._dict[label] = idx
+
+ def setup_labels_dict(self):
+ self._dict = {}
+ for idx, label in enumerate(self):
+ self._dict[label] = idx
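+ # Usage sketch for Labels: a list whose _dict maps label -> index, kept in
+ # sync on item assignment.
+ #
+ # lbls = Labels(["a", "b", "c"])
+ # lbls._dict["b"] # -> 1
+ # lbls[1] = "B" # updates both the list and _dict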
diff --git a/py4DSTEM/io/legacy/legacy13/v13_emd_classes/io.py b/py4DSTEM/io/legacy/legacy13/v13_emd_classes/io.py
new file mode 100644
index 000000000..e1b7ab241
--- /dev/null
+++ b/py4DSTEM/io/legacy/legacy13/v13_emd_classes/io.py
@@ -0,0 +1,642 @@
+# Functions for reading and writing the base EMD types between
+# HDF5 and python classes
+
+import numpy as np
+import h5py
+from numbers import Number
+from emdfile import tqdmnd
+
+
+# Define the EMD group types
+
+EMD_group_types = {
+ "Root": "root",
+ "Metadata": 0,
+ "Array": 1,
+ "PointList": 2,
+ "PointListArray": 3,
+ "Custom": 4,
+}
+
+
+# Utility functions for finding and validating EMD groups
+
+
+def find_EMD_groups(group: h5py.Group, emd_group_type):
+ """
+ Takes a valid HDF5 group for an HDF5 file object which is open in read mode,
+ and finds all groups inside this group at its top level matching
+ `emd_group_type`. Does not do a nested search. Returns the names of all
+ groups found.
+ Accepts:
+ group (HDF5 group):
+ emd_group_type (int)
+ """
+ keys = [k for k in group.keys() if "emd_group_type" in group[k].attrs.keys()]
+ return [k for k in keys if group[k].attrs["emd_group_type"] == emd_group_type]
+
+
+def EMD_group_exists(group: h5py.Group, emd_group_type, name: str):
+ """
+ Takes a valid HDF5 group for an HDF5 file object which is open in read mode,
+ and a name. Determines if an object of this `emd_group_type` and name exists
+ inside this group, and returns a boolean.
+ Accepts:
+ group (HDF5 group):
+ emd_group_type (int):
+ name (string):
+ Returns:
+ bool
+ """
+ if name in group.keys():
+ if "emd_group_type" in group[name].attrs.keys():
+ if group[name].attrs["emd_group_type"] == emd_group_type:
+ return True
+ return False
+ return False
+ return False
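+ # Usage sketch (hypothetical file and group names; requires an open h5py.File):
+ #
+ # with h5py.File("example.h5", "r") as f:
+ # arrays = find_EMD_groups(f["4DSTEM_experiment"], EMD_group_types["Array"])
+ # exists = EMD_group_exists(f["4DSTEM_experiment"], EMD_group_types["Array"], "datacube")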
+
+
+# Read and write for base EMD types
+
+
+## ROOT
+
+
+# write
+def Root_to_h5(root, group):
+ """
+ Takes a valid HDF5 group for an HDF5 file object which is open
+ in write or append mode. Writes a new group with a name given by
+ this Root instance's .name field nested inside the passed
+ group, and saves the data there.
+ Accepts:
+ group (HDF5 group)
+ """
+
+ ## Write
+ grp = group.create_group(root.name)
+ grp.attrs.create("emd_group_type", EMD_group_types["Root"])
+ grp.attrs.create("py4dstem_class", root.metadata.__class__.__name__)
+
+
+# read
+def Root_from_h5(group: h5py.Group):
+ """
+ Takes a valid HDF5 group for an HDF5 file object which is open in read mode.
+ Determines if this group represents a valid Root object, and if it does,
+ loads and returns it. If it doesn't, raises an exception.
+ Accepts:
+ group (HDF5 group)
+ Returns:
+ A Root instance
+ """
+ from py4DSTEM.io.legacy.legacy13.v13_emd_classes.root import Root
+ from os.path import basename
+
+ er = f"Group {group} is not a valid EMD Metadata group"
+ assert "emd_group_type" in group.attrs.keys(), er
+ assert group.attrs["emd_group_type"] == EMD_group_types["Root"], er
+
+ root = Root(basename(group.name))
+ return root
+
+
+## METADATA
+
+
+# write
+def Metadata_to_h5(metadata, group):
+ """
+ Takes a valid HDF5 group for an HDF5 file object which is open
+ in write or append mode. Writes a new group with a name given by
+ this Metadata instance's .name field nested inside the passed
+ group, and saves the data there.
+ Accepts:
+ group (HDF5 group)
+ """
+
+ ## Write
+ grp = group.create_group(metadata.name)
+ grp.attrs.create("emd_group_type", EMD_group_types["Metadata"])
+ grp.attrs.create("py4dstem_class", metadata.__class__.__name__)
+
+ # Save data
+ for k, v in metadata._params.items():
+ # None
+ if v is None:
+ v = "_None"
+ v = np.string_(v) # convert to byte string
+ dset = grp.create_dataset(k, data=v)
+ dset.attrs["type"] = np.string_("None")
+
+ # strings
+ elif isinstance(v, str):
+ v = np.string_(v) # convert to byte string
+ dset = grp.create_dataset(k, data=v)
+ dset.attrs["type"] = np.string_("string")
+
+ # bools
+ elif isinstance(v, bool):
+ dset = grp.create_dataset(k, data=v, dtype=bool)
+ dset.attrs["type"] = np.string_("bool")
+
+ # numbers
+ elif isinstance(v, Number):
+ dset = grp.create_dataset(k, data=v, dtype=type(v))
+ dset.attrs["type"] = np.string_("number")
+
+ # arrays
+ elif isinstance(v, np.ndarray):
+ dset = grp.create_dataset(k, data=v, dtype=v.dtype)
+ dset.attrs["type"] = np.string_("array")
+
+ # tuples
+ elif isinstance(v, tuple):
+ # of numbers
+ if isinstance(v[0], Number):
+ dset = grp.create_dataset(k, data=v)
+ dset.attrs["type"] = np.string_("tuple")
+
+ # of tuples
+ elif any([isinstance(v[i], tuple) for i in range(len(v))]):
+ dset_grp = grp.create_group(k)
+ dset_grp.attrs["type"] = np.string_("tuple_of_tuples")
+ dset_grp.attrs["length"] = len(v)
+ for i, x in enumerate(v):
+ dset_grp.create_dataset(str(i), data=x)
+
+ # of arrays
+ elif isinstance(v[0], np.ndarray):
+ dset_grp = grp.create_group(k)
+ dset_grp.attrs["type"] = np.string_("tuple_of_arrays")
+ dset_grp.attrs["length"] = len(v)
+ for i, ar in enumerate(v):
+ dset_grp.create_dataset(str(i), data=ar, dtype=ar.dtype)
+
+ # of strings
+ elif isinstance(v[0], str):
+ dset_grp = grp.create_group(k)
+ dset_grp.attrs["type"] = np.string_("tuple_of_strings")
+ dset_grp.attrs["length"] = len(v)
+ for i, s in enumerate(v):
+ dset_grp.create_dataset(str(i), data=np.string_(s))
+
+ else:
+ er = f"Metadata only supports writing tuples with numeric and array-like arguments; found type {type(v[0])}"
+ raise Exception(er)
+
+ # lists
+ elif isinstance(v, list):
+ # of numbers
+ if isinstance(v[0], Number):
+ dset = grp.create_dataset(k, data=v)
+ dset.attrs["type"] = np.string_("list")
+
+ # of arrays
+ elif isinstance(v[0], np.ndarray):
+ dset_grp = grp.create_group(k)
+ dset_grp.attrs["type"] = np.string_("list_of_arrays")
+ dset_grp.attrs["length"] = len(v)
+ for i, ar in enumerate(v):
+ dset_grp.create_dataset(str(i), data=ar, dtype=ar.dtype)
+
+ # of strings
+ elif isinstance(v[0], str):
+ dset_grp = grp.create_group(k)
+ dset_grp.attrs["type"] = np.string_("list_of_strings")
+ dset_grp.attrs["length"] = len(v)
+ for i, s in enumerate(v):
+ dset_grp.create_dataset(str(i), data=np.string_(s))
+
+ else:
+ er = f"Metadata only supports writing lists with numeric and array-like arguments; found type {type(v[0])}"
+ raise Exception(er)
+
+ else:
+ er = f"Metadata supports writing numbers, bools, strings, arrays, tuples of numbers or arrays, and lists of numbers or arrays. Found an unsupported type {type(v[0])}"
+ raise Exception(er)
+
+
+# read
+def Metadata_from_h5(group: h5py.Group):
+ """
+ Takes a valid HDF5 group for an HDF5 file object which is open in read mode.
+ Determines if this group represents a valid Metadata object, and if it does,
+ loads and returns it. If it doesn't, raises an exception.
+ Accepts:
+ group (HDF5 group)
+ Returns:
+ A Metadata instance
+ """
+ from py4DSTEM.io.legacy.legacy13.v13_emd_classes.metadata import Metadata
+ from os.path import basename
+
+ er = f"Group {group} is not a valid EMD Metadata group"
+ assert "emd_group_type" in group.attrs.keys(), er
+ assert group.attrs["emd_group_type"] == EMD_group_types["Metadata"], er
+
+ # Get data
+ data = {}
+ for k, v in group.items():
+ # get type
+ try:
+ t = group[k].attrs["type"].decode("utf-8")
+ except KeyError:
+ raise Exception(f"Metadata entry '{k}' is missing its 'type' attribute")
+
+ # None
+ if t == "None":
+ v = None
+
+ # strings
+ elif t == "string":
+ v = v[...].item()
+ v = v.decode("utf-8")
+ v = v if v != "_None" else None
+
+ # numbers
+ elif t == "number":
+ v = v[...].item()
+
+ # bools
+ elif t == "bool":
+ v = v[...].item()
+
+ # array
+ elif t == "array":
+ v = np.array(v)
+
+ # tuples of numbers
+ elif t == "tuple":
+ v = tuple(v[...])
+
+ # tuples of arrays
+ elif t == "tuple_of_arrays":
+ L = group[k].attrs["length"]
+ tup = []
+ for l in range(L):
+ tup.append(np.array(v[str(l)]))
+ v = tuple(tup)
+
+ # tuples of tuples
+ elif t == "tuple_of_tuples":
+ L = group[k].attrs["length"]
+ tup = []
+ for l in range(L):
+ x = v[str(l)][...]
+ if x.ndim == 0:
+ x = x.item()
+ else:
+ x = tuple(x)
+ tup.append(x)
+ v = tuple(tup)
+
+ # tuples of strings
+ elif t == "tuple_of_strings":
+ L = group[k].attrs["length"]
+ tup = []
+ for l in range(L):
+ s = v[str(l)][...].item().decode("utf-8")
+ tup.append(s)
+ v = tuple(tup)
+
+ # lists of numbers
+ elif t == "list":
+ v = list(v[...])
+
+ # lists of arrays
+ elif t == "list_of_arrays":
+ L = group[k].attrs["length"]
+ _list = []
+ for l in range(L):
+ _list.append(np.array(v[str(l)]))
+ v = _list
+
+ # list of strings
+ elif t == "list_of_strings":
+ L = group[k].attrs["length"]
+ _list = []
+ for l in range(L):
+ s = v[str(l)][...].item().decode("utf-8")
+ _list.append(s)
+ v = _list
+
+ else:
+ raise Exception(f"unrecognized Metadata value type {t}")
+
+ # add data
+ data[k] = v
+
+ # make Metadata instance, add data, and return
+ md = Metadata(basename(group.name))
+ md._params.update(data)
+ return md
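+ # Round-trip sketch (hypothetical names; uses the Metadata class from this
+ # module and its private _params dict, as the readers above do):
+ #
+ # md = Metadata("calibration")
+ # md._params.update({"Q_pixel_size": 0.01, "mode": "diffraction"})
+ # with h5py.File("scratch.h5", "a") as f:
+ # Metadata_to_h5(md, f.require_group("root"))
+ # md2 = Metadata_from_h5(f["root/calibration"])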
+
+
+## ARRAY
+
+# write
+
+
+def Array_to_h5(array, group):
+ """
+ Takes a valid HDF5 group for an HDF5 file object which is open in
+ write or append mode. Writes a new group with a name given by this
+ Array's .name field nested inside the passed group, and saves the
+ data there.
+ Accepts:
+ group (HDF5 group)
+ """
+
+ ## Write
+
+ grp = group.create_group(array.name)
+ grp.attrs.create("emd_group_type", 1) # this tag indicates an Array
+ grp.attrs.create("py4dstem_class", array.__class__.__name__)
+
+ # add the data
+ data = grp.create_dataset(
+ "data",
+ shape=array.data.shape,
+ data=array.data,
+ )
+ data.attrs.create(
+ "units", array.units
+ ) # save 'units' but not 'name' - 'name' is the group name
+
+ # Add the normal dim vectors
+ for n in range(array.rank):
+ # unpack info
+ dim = array.dims[n]
+ name = array.dim_names[n]
+ units = array.dim_units[n]
+ is_linear = array._dim_is_linear(dim, array.shape[n])
+
+ # compress the dim vector if it's linear
+ if is_linear:
+ dim = dim[:2]
+
+ # write
+ dset = grp.create_dataset(f"dim{n}", data=dim)
+ dset.attrs.create("name", name)
+ dset.attrs.create("units", units)
+
+ # Add stack dim vector, if present
+ if array.is_stack:
+ n = array.rank
+ name = "_labels_"
+ dim = [s.encode("utf-8") for s in array.slicelabels]
+
+ # write
+ dset = grp.create_dataset(f"dim{n}", data=dim)
+ dset.attrs.create("name", name)
+
+ # Add metadata
+ _write_metadata(array, grp)
+
+
+## read
+
+
+def Array_from_h5(group: h5py.Group):
+ """
+ Takes a valid HDF5 group for an HDF5 file object which is open in read mode.
+ Determines if this group represents an Array object and if it does, loads
+ and returns it. If it doesn't, raises an exception.
+ Accepts:
+ group (HDF5 group)
+ Returns:
+ An Array instance
+ """
+ from py4DSTEM.io.legacy.legacy13.v13_emd_classes.array import Array
+ from os.path import basename
+
+ er = f"Group {group} is not a valid EMD Array group"
+ assert "emd_group_type" in group.attrs.keys(), er
+ assert group.attrs["emd_group_type"] == EMD_group_types["Array"], er
+
+ # get data
+ dset = group["data"]
+ data = dset[:]
+ units = dset.attrs["units"]
+ rank = len(data.shape)
+
+ # determine if this is a stack array
+ last_dim = group[f"dim{rank-1}"]
+ if last_dim.attrs["name"] == "_labels_":
+ is_stack = True
+ normal_dims = rank - 1
+ else:
+ is_stack = False
+ normal_dims = rank
+
+ # get dim vectors
+ dims = []
+ dim_units = []
+ dim_names = []
+ for n in range(normal_dims):
+ dim_dset = group[f"dim{n}"]
+ dims.append(dim_dset[:])
+ dim_units.append(dim_dset.attrs["units"])
+ dim_names.append(dim_dset.attrs["name"])
+
+ # if it's a stack array, get the labels
+ if is_stack:
+ slicelabels = last_dim[:]
+ slicelabels = [s.decode("utf-8") for s in slicelabels]
+ else:
+ slicelabels = None
+
+ # make Array
+ ar = Array(
+ data=data,
+ name=basename(group.name),
+ units=units,
+ dims=dims,
+ dim_names=dim_names,
+ dim_units=dim_units,
+ slicelabels=slicelabels,
+ )
+
+ # add metadata
+ _read_metadata(ar, group)
+
+ return ar
+
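+# Round-trip sketch (illustrative): writing an Array and reading it back.
+# Assumes `ar` is a v13 Array instance; the file name is hypothetical.
+#
+#     >>> import h5py
+#     >>> with h5py.File("tmp.h5", "w") as f:
+#     ...     Array_to_h5(ar, f)
+#     >>> with h5py.File("tmp.h5", "r") as f:
+#     ...     ar2 = Array_from_h5(f[ar.name])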
+
+## POINTLIST
+
+
+# write
+def PointList_to_h5(pointlist, group):
+ """
+ Takes a valid HDF5 group for an HDF5 file object which is open in
+ write or append mode. Writes a new group with a name given by this
+ PointList's .name field nested inside the passed group, and saves
+ the data there.
+ Accepts:
+ pointlist (PointList)
+ group (HDF5 group)
+ """
+
+ ## Write
+ grp = group.create_group(pointlist.name)
+ grp.attrs.create("emd_group_type", 2) # this tag indicates a PointList
+ grp.attrs.create("py4dstem_class", pointlist.__class__.__name__)
+
+ # Add data
+ for f, t in zip(pointlist.fields, pointlist.types):
+ group_current_field = grp.create_dataset(f, data=pointlist.data[f])
+ group_current_field.attrs.create("dtype", np.string_(t))
+ # group_current_field.create_dataset(
+ # "data",
+ # data = pointlist.data[f]
+ # )
+
+ # Add metadata
+ _write_metadata(pointlist, grp)
+
+
+# read
+def PointList_from_h5(group: h5py.Group):
+ """
+ Takes a valid HDF5 group for an HDF5 file object which is open in read mode.
+ Determines if the group holds a valid PointList object, and if it does,
+ loads and returns it. If it doesn't, raises an exception.
+ Accepts:
+ group (HDF5 group)
+ Returns:
+ A PointList instance
+ """
+ from py4DSTEM.io.legacy.legacy13.v13_emd_classes.pointlist import PointList
+ from os.path import basename
+
+ er = f"Group {group} is not a valid EMD PointList group"
+ assert "emd_group_type" in group.attrs.keys(), er
+ assert group.attrs["emd_group_type"] == EMD_group_types["PointList"], er
+
+ # Get the fields and dtype
+ fields = list(group.keys())
+ if "_metadata" in fields:
+ fields.remove("_metadata")
+ dtype = []
+ for field in fields:
+ curr_dtype = group[field].attrs["dtype"].decode("utf-8")
+ dtype.append((field, curr_dtype))
+ length = len(group[fields[0]])
+
+ # Get data
+ data = np.zeros(length, dtype=dtype)
+ if length > 0:
+ for field in fields:
+ data[field] = np.array(group[field])
+
+ # Make the PointList
+ pl = PointList(data=data, name=basename(group.name))
+
+ # Add additional metadata
+ _read_metadata(pl, group)
+
+ return pl
+
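+# Sketch (illustrative): PointList_to_h5 writes one dataset per field, so a
+# round trip reconstructs the structured dtype field by field. Assumes `pl`
+# is a PointList and `f` an open h5py.File:
+#
+#     >>> PointList_to_h5(pl, f)
+#     >>> pl2 = PointList_from_h5(f[pl.name])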
+
+## POINTLISTARRAY
+
+
+# write
+def PointListArray_to_h5(pointlistarray, group):
+ """
+ Takes a valid HDF5 group for an HDF5 file object which is open in
+ write or append mode. Writes a new group with a name given by this
+ PointListArray's .name field nested inside the passed group, and
+ saves the data there.
+ Accepts:
+ pointlistarray (PointListArray)
+ group (HDF5 group)
+ """
+
+ ## Write
+ grp = group.create_group(pointlistarray.name)
+ grp.attrs.create("emd_group_type", 3) # this tag indicates a PointListArray
+ grp.attrs.create("py4dstem_class", pointlistarray.__class__.__name__)
+
+ # Create the data-holding vlen dataset
+ dtype = h5py.special_dtype(vlen=pointlistarray.dtype)
+ dset = grp.create_dataset("data", pointlistarray.shape, dtype)
+
+ # Add data
+ for i, j in tqdmnd(dset.shape[0], dset.shape[1]):
+ dset[i, j] = pointlistarray[i, j].data
+
+ # Add additional metadata
+ _write_metadata(pointlistarray, grp)
+
+
+# read
+def PointListArray_from_h5(group: h5py.Group):
+ """
+ Takes a valid HDF5 group for an HDF5 file object which is open in read mode.
+ Determines if the group holds a valid PointListArray object, and if it does,
+ loads and returns it. If it doesn't, raises an exception.
+ Accepts:
+ group (HDF5 group)
+ Returns:
+ A PointListArray instance
+ """
+ from py4DSTEM.io.legacy.legacy13.v13_emd_classes.pointlistarray import (
+ PointListArray,
+ )
+ from os.path import basename
+
+ er = f"Group {group} is not a valid EMD PointListArray group"
+ assert "emd_group_type" in group.attrs.keys(), er
+ assert group.attrs["emd_group_type"] == EMD_group_types["PointListArray"], er
+
+ # Get the DataSet
+ dset = group["data"]
+ dtype = h5py.check_vlen_dtype(dset.dtype)
+ shape = dset.shape
+
+ # Initialize a PointListArray
+ pla = PointListArray(dtype=dtype, shape=shape, name=basename(group.name))
+
+ # Add data
+ for i, j in tqdmnd(
+ shape[0], shape[1], desc="Reading PointListArray", unit="PointList"
+ ):
+ try:
+ pla[i, j].add(dset[i, j])
+ except ValueError:
+ pass
+
+ # Add metadata
+ _read_metadata(pla, group)
+
+ return pla
+
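+# Storage note (illustrative): the PointListArray is saved as a single 2D
+# dataset of variable-length elements, one ragged row of points per scan
+# position. Reading one position back by hand, assuming `f` is an open
+# h5py.File containing a group named "pla":
+#
+#     >>> dset = f["pla"]["data"]
+#     >>> points_00 = dset[0, 0]    # the structured array at position (0,0)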
+
+# Metadata helper functions
+
+
+def _write_metadata(obj, grp):
+ items = obj._metadata.items()
+ if len(items) > 0:
+ grp_metadata = grp.create_group("_metadata")
+ for name, md in items:
+ obj._metadata[name].name = name
+ obj._metadata[name].to_h5(grp_metadata)
+
+
+def _read_metadata(obj, grp):
+ try:
+ grp_metadata = grp["_metadata"]
+ for key in grp_metadata.keys():
+ obj.metadata = Metadata_from_h5(grp_metadata[key])
+ except KeyError:
+ pass
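+
+# Sketch (illustrative): any object carrying a ._metadata dict of Metadata
+# instances can be passed to these helpers, which write to and read from a
+# "_metadata" subgroup:
+#
+#     >>> _write_metadata(ar, grp)    # ar: e.g. an Array; grp: an HDF5 group
+#     >>> _read_metadata(ar2, grp)    # repopulates ar2.metadata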
diff --git a/py4DSTEM/io/legacy/legacy13/v13_emd_classes/metadata.py b/py4DSTEM/io/legacy/legacy13/v13_emd_classes/metadata.py
new file mode 100644
index 000000000..d430528e1
--- /dev/null
+++ b/py4DSTEM/io/legacy/legacy13/v13_emd_classes/metadata.py
@@ -0,0 +1,80 @@
+import numpy as np
+from numbers import Number
+from typing import Optional
+import h5py
+
+from py4DSTEM.io.legacy.legacy13.v13_emd_classes.tree import Tree
+
+
+class Metadata:
+ """
+ Stores metadata in the form of a flat (non-nested) dictionary.
+ Keys are arbitrary strings. Values may be strings, numbers, arrays,
+ or lists of the above types.
+ Usage:
+ >>> meta = Metadata()
+ >>> meta['param'] = value
+ >>> val = meta['param']
+ Accessing a parameter which has not been set raises a KeyError.
+ """
+
+ def __init__(self, name: Optional[str] = "metadata"):
+ """
+ Args:
+ name (Optional, string):
+ """
+ self.name = name
+ self.tree = Tree()
+
+ # create parameter dictionary
+ self._params = {}
+
+ ### __get/setitem__
+
+ def __getitem__(self, x):
+ return self._params[x]
+
+ def __setitem__(self, k, v):
+ self._params[k] = v
+
+ @property
+ def keys(self):
+ return self._params.keys()
+
+ def copy(self, name=None):
+ """ """
+ if name is None:
+ name = self.name + "_copy"
+ md = Metadata(name=name)
+ md._params.update(self._params)
+ return md
+
+ def __repr__(self):
+ space = " " * len(self.__class__.__name__) + " "
+ string = f"{self.__class__.__name__}( A Metadata instance called '{self.name}', containing the following fields:"
+ string += "\n"
+
+ maxlen = 0
+ for k in self._params.keys():
+ if len(k) > maxlen:
+ maxlen = len(k)
+
+ for k, v in self._params.items():
+ if isinstance(v, np.ndarray):
+ v = f"{v.ndim}D-array"
+ string += "\n" + space + f"{k}:{(maxlen-len(k)+3)*' '}{str(v)}"
+ string += "\n)"
+
+ return string
+
+ # HDF5 read/write
+
+ def to_h5(self, group):
+ from py4DSTEM.io.legacy.legacy13.v13_emd_classes.io import Metadata_to_h5
+
+ Metadata_to_h5(self, group)
+
+ def from_h5(group):
+ from py4DSTEM.io.legacy.legacy13.v13_emd_classes.io import Metadata_from_h5
+
+ return Metadata_from_h5(group)
diff --git a/py4DSTEM/io/legacy/legacy13/v13_emd_classes/pointlist.py b/py4DSTEM/io/legacy/legacy13/v13_emd_classes/pointlist.py
new file mode 100644
index 000000000..c7f0c7fc1
--- /dev/null
+++ b/py4DSTEM/io/legacy/legacy13/v13_emd_classes/pointlist.py
@@ -0,0 +1,193 @@
+# Defines a class, PointList, for storing / accessing / manipulating data
+# in the form of lists of vectors in named dimensions. Wraps numpy
+# structured arrays.
+
+import numpy as np
+import h5py
+from copy import copy
+from typing import Optional
+
+from py4DSTEM.io.legacy.legacy13.v13_emd_classes.tree import Tree
+from py4DSTEM.io.legacy.legacy13.v13_emd_classes.metadata import Metadata
+
+
+class PointList:
+ """
+ A wrapper around structured numpy arrays, with read/write functionality in/out of
+ py4DSTEM formatted HDF5 files.
+ """
+
+ def __init__(
+ self,
+ data: np.ndarray,
+ name: Optional[str] = "pointlist",
+ ):
+ """
+ Instantiate a PointList.
+ Args:
+ data (structured numpy ndarray): the data; the dtype of this array will
+ specify the fields of the PointList.
+ name (str): name for the PointList
+ Returns:
+ a PointList instance
+ """
+ self.data = data
+ self.name = name
+
+ self._dtype = self.data.dtype
+ self._fields = self.data.dtype.names
+ self._types = tuple([self.data.dtype.fields[f][0] for f in self.fields])
+
+ self.tree = Tree()
+ if not hasattr(self, "_metadata"):
+ self._metadata = {}
+
+ # properties
+ @property
+ def dtype(self):
+ return self._dtype
+
+ @dtype.setter
+ def dtype(self, dtype):
+ self._dtype = dtype
+
+ @property
+ def fields(self):
+ return self._fields
+
+ @fields.setter
+ def fields(self, x):
+ self.data.dtype.names = x
+ self._fields = x
+
+ @property
+ def types(self):
+ return self._types
+
+ @property
+ def length(self):
+ return np.atleast_1d(self.data).shape[0]
+
+ ## Add, remove, sort data
+
+ def add(self, data):
+ """
+ Appends a numpy structured array. Its dtype must agree with the existing data.
+ """
+ assert self.dtype == data.dtype, "Error: dtypes must agree"
+ self.data = np.append(self.data, data)
+
+ def remove(self, mask):
+ """Removes points wherever mask==True"""
+ assert (
+ np.atleast_1d(mask).shape[0] == self.length
+ ), "deletemask must be same length as the data"
+ inds = mask.nonzero()[0]
+ self.data = np.delete(self.data, inds)
+
+ def sort(self, field, order="descending"):
+ """
+ Sorts the point list according to field,
+ which must be a field in self.dtype.
+ order should be 'descending' or 'ascending'.
+ """
+ assert field in self.fields
+ assert (order == "descending") or (order == "ascending")
+ if order == "ascending":
+ self.data = np.sort(self.data, order=field)
+ else:
+ self.data = np.sort(self.data, order=field)[::-1]
+
+ ## Copy, copy+modify PointList
+
+ def copy(self, name=None):
+ """Returns a copy of the PointList. If name=None, sets to `{name}_copy`"""
+ name = name if name is not None else self.name + "_copy"
+
+ pl = PointList(data=np.copy(self.data), name=name)
+
+ for k, v in self.metadata.items():
+ pl.metadata = v.copy(name=k)
+
+ return pl
+
+ def add_fields(self, new_fields, name=""):
+ """
+ Creates a copy of the PointList, but with additional fields given by new_fields.
+ Args:
+ new_fields: a list of 2-tuples, ('name', dtype)
+ name: a name for the new pointlist
+ """
+ dtype = []
+ for f, t in zip(self.fields, self.types):
+ dtype.append((f, t))
+ for f, t in new_fields:
+ dtype.append((f, t))
+
+ data = np.zeros(self.length, dtype=dtype)
+ for f in self.fields:
+ data[f] = np.copy(self.data[f])
+
+ return PointList(data=data, name=name)
+
+ def add_data_by_field(self, data, fields=None):
+ """
+ Add a list of data arrays to the PointList, in the fields
+ given by `fields`. If `fields` is not specified, assumes the data
+ arrays are in the same order as self.fields
+ Args:
+ data (list): arrays of data to add to each field
+ fields (list, optional): field names corresponding to the data arrays;
+ defaults to self.fields, in order
+ """
+
+ if data[0].ndim == 0:
+ L = (1,)
+ else:
+ L = data[0].shape[0]
+ newdata = np.zeros(L, dtype=self.dtype)
+
+ _fields = self.fields if fields is None else fields
+
+ for d, f in zip(data, _fields):
+ newdata[f] = d
+
+ self.data = np.append(self.data, newdata)
+
+ # set up metadata property
+
+ @property
+ def metadata(self):
+ return self._metadata
+
+ @metadata.setter
+ def metadata(self, x):
+ assert isinstance(x, Metadata)
+ self._metadata[x.name] = x
+
+ ## Representation to standard output
+ def __repr__(self):
+ space = " " * len(self.__class__.__name__) + " "
+ string = f"{self.__class__.__name__}( A length {self.length} PointList called '{self.name}',"
+ string += "\n" + space + f"with {len(self.fields)} fields:"
+ string += "\n"
+ space2 = max([len(field) for field in self.fields]) + 3
+ for f, t in zip(self.fields, self.types):
+ string += "\n" + space + f"{f}{(space2-len(f))*' '}({str(t)})"
+ string += "\n)"
+
+ return string
+
+ # Slicing
+ def __getitem__(self, v):
+ return self.data[v]
+
+ # HDF5 read/write
+
+ def to_h5(self, group):
+ from py4DSTEM.io.legacy.legacy13.v13_emd_classes.io import PointList_to_h5
+
+ PointList_to_h5(self, group)
+
+ def from_h5(group):
+ from py4DSTEM.io.legacy.legacy13.v13_emd_classes.io import PointList_from_h5
+
+ return PointList_from_h5(group)
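+
+
+# Usage sketch (illustrative): a PointList wraps a numpy structured array,
+# so construction starts from a structured dtype:
+#
+#     >>> import numpy as np
+#     >>> dtype = [("qx", float), ("qy", float), ("intensity", float)]
+#     >>> pl = PointList(data=np.zeros(5, dtype=dtype), name="peaks")
+#     >>> pl.sort("intensity")                   # descending by default
+#     >>> pl2 = pl.add_fields([("h", int)], name="indexed_peaks")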
diff --git a/py4DSTEM/io/legacy/legacy13/v13_emd_classes/pointlistarray.py b/py4DSTEM/io/legacy/legacy13/v13_emd_classes/pointlistarray.py
new file mode 100644
index 000000000..c246672bd
--- /dev/null
+++ b/py4DSTEM/io/legacy/legacy13/v13_emd_classes/pointlistarray.py
@@ -0,0 +1,160 @@
+import numpy as np
+from copy import copy
+from typing import Optional
+import h5py
+
+from py4DSTEM.io.legacy.legacy13.v13_emd_classes.tree import Tree
+from py4DSTEM.io.legacy.legacy13.v13_emd_classes.metadata import Metadata
+from py4DSTEM.io.legacy.legacy13.v13_emd_classes.pointlist import PointList
+
+
+class PointListArray:
+ """
+ A 2D array of PointLists which share common coordinates.
+ """
+
+ def __init__(
+ self,
+ dtype,
+ shape,
+ name: Optional[str] = "pointlistarray",
+ ):
+ """
+ Creates an empty PointListArray.
+ Args:
+ dtype: the dtype of the numpy structured arrays which will comprise
+ the data of each PointList
+ shape (2-tuple of ints): the shape of the array of PointLists
+ name (str): a name for the PointListArray
+ Returns:
+ a PointListArray instance
+ """
+ assert len(shape) == 2, "Shape must have length 2."
+
+ self.name = name
+ self.shape = shape
+
+ self.dtype = np.dtype(dtype)
+ self.fields = self.dtype.names
+ self.types = tuple([self.dtype.fields[f][0] for f in self.fields])
+
+ self.tree = Tree()
+ if not hasattr(self, "_metadata"):
+ self._metadata = {}
+
+ # Populate with empty PointLists
+ self._pointlists = [
+ [
+ PointList(data=np.zeros(0, dtype=self.dtype), name=f"{i},{j}")
+ for j in range(self.shape[1])
+ ]
+ for i in range(self.shape[0])
+ ]
+
+ ## get/set pointlists
+
+ def get_pointlist(self, i, j, name=None):
+ """
+ Returns the pointlist at i,j
+ """
+ pl = self._pointlists[i][j]
+ if name is not None:
+ pl = pl.copy(name=name)
+ return pl
+
+ def __getitem__(self, tup):
+ l = len(tup) if isinstance(tup, tuple) else 1
+ assert l == 2, f"Expected 2 slice values, received {l}"
+ return self.get_pointlist(tup[0], tup[1])
+
+ def __setitem__(self, tup, pointlist):
+ l = len(tup) if isinstance(tup, tuple) else 1
+ assert l == 2, f"Expected 2 slice values, received {l}"
+ assert pointlist.fields == self.fields, "fields must match"
+ self._pointlists[tup[0]][tup[1]] = pointlist
+
+ ## Make copies
+
+ def copy(self, name=""):
+ """
+ Returns a copy of itself.
+ """
+ new_pla = PointListArray(dtype=self.dtype, shape=self.shape, name=name)
+
+ for i in range(new_pla.shape[0]):
+ for j in range(new_pla.shape[1]):
+ pl = new_pla.get_pointlist(i, j)
+ pl.add(np.copy(self.get_pointlist(i, j).data))
+
+ for k, v in self.metadata.items():
+ new_pla.metadata = v.copy(name=k)
+
+ return new_pla
+
+ def add_fields(self, new_fields, name=""):
+ """
+ Creates a copy of the PointListArray, but with additional fields given
+ by new_fields.
+ Args:
+ new_fields: a list of 2-tuples, ('name', dtype)
+ name: a name for the new pointlist
+ """
+ dtype = []
+ for f, t in zip(self.fields, self.types):
+ dtype.append((f, t))
+ for f, t in new_fields:
+ dtype.append((f, t))
+
+ new_pla = PointListArray(dtype=dtype, shape=self.shape, name=name)
+
+ for i in range(new_pla.shape[0]):
+ for j in range(new_pla.shape[1]):
+ # Copy old data into a new structured array
+ pl_old = self.get_pointlist(i, j)
+ data = np.zeros(pl_old.length, np.dtype(dtype))
+ for f in self.fields:
+ data[f] = np.copy(pl_old.data[f])
+
+ # Write into new pointlist
+ pl_new = new_pla.get_pointlist(i, j)
+ pl_new.add(data)
+
+ return new_pla
+
+ # set up metadata property
+
+ @property
+ def metadata(self):
+ return self._metadata
+
+ @metadata.setter
+ def metadata(self, x):
+ assert isinstance(x, Metadata)
+ self._metadata[x.name] = x
+
+ ## Representation to standard output
+ def __repr__(self):
+ space = " " * len(self.__class__.__name__) + " "
+ string = f"{self.__class__.__name__}( A shape {self.shape} PointListArray called '{self.name}',"
+ string += "\n" + space + f"with {len(self.fields)} fields:"
+ string += "\n"
+ space2 = max([len(field) for field in self.fields]) + 3
+ for f, t in zip(self.fields, self.types):
+ string += "\n" + space + f"{f}{(space2-len(f))*' '}({str(t)})"
+ string += "\n)"
+
+ return string
+
+ # HDF5 read/write
+
+ def to_h5(self, group):
+ from py4DSTEM.io.legacy.legacy13.v13_emd_classes.io import PointListArray_to_h5
+
+ PointListArray_to_h5(self, group)
+
+ def from_h5(group):
+ from py4DSTEM.io.legacy.legacy13.v13_emd_classes.io import (
+ PointListArray_from_h5,
+ )
+
+ return PointListArray_from_h5(group)
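+
+
+# Usage sketch (illustrative): a PointListArray is created empty, then
+# filled position by position:
+#
+#     >>> import numpy as np
+#     >>> pla = PointListArray(dtype=[("qx", float), ("qy", float)], shape=(10, 10))
+#     >>> pla[0, 0].add(np.zeros(3, dtype=pla.dtype))
+#     >>> pla[0, 0].length
+#     3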
diff --git a/py4DSTEM/io/legacy/legacy13/v13_emd_classes/root.py b/py4DSTEM/io/legacy/legacy13/v13_emd_classes/root.py
new file mode 100644
index 000000000..c5137d9f4
--- /dev/null
+++ b/py4DSTEM/io/legacy/legacy13/v13_emd_classes/root.py
@@ -0,0 +1,53 @@
+import numpy as np
+from numbers import Number
+from typing import Optional
+import h5py
+
+from py4DSTEM.io.legacy.legacy13.v13_emd_classes.tree import Tree
+
+
+class Root:
+ """
+ A class serving as a container for Trees
+ """
+
+ def __init__(self, name: Optional[str] = "root"):
+ """
+ Args:
+ name (Optional, string):
+ """
+ self.name = name
+ self.tree = Tree()
+
+ ### __get/setitem__
+
+ def __getitem__(self, x):
+ return self.tree[x]
+
+ def __setitem__(self, k, v):
+ self.tree[k] = v
+
+ @property
+ def keys(self):
+ return self.tree.keys()
+
+ def __repr__(self):
+ space = " " * len(self.__class__.__name__) + " "
+ string = f"{self.__class__.__name__}( A Root instance called '{self.name}', containing the following top-level object instances:"
+ string += "\n"
+ for k, v in self.tree._tree.items():
+ string += "\n" + space + f" {k} \t\t ({v.__class__.__name__})"
+ string += "\n)"
+ return string
+
+ # HDF5 read/write
+
+ def to_h5(self, group):
+ from py4DSTEM.io.legacy.legacy13.v13_emd_classes.io import Root_to_h5
+
+ Root_to_h5(self, group)
+
+ def from_h5(group):
+ from py4DSTEM.io.legacy.legacy13.v13_emd_classes.io import Root_from_h5
+
+ return Root_from_h5(group)
diff --git a/py4DSTEM/io/legacy/legacy13/v13_emd_classes/tree.py b/py4DSTEM/io/legacy/legacy13/v13_emd_classes/tree.py
new file mode 100644
index 000000000..51f124122
--- /dev/null
+++ b/py4DSTEM/io/legacy/legacy13/v13_emd_classes/tree.py
@@ -0,0 +1,79 @@
+# Defines the Tree class, which stores the tree of child data
+# class instances contained by a given class instance
+
+
+class Tree:
+ def __init__(self):
+ self._tree = {}
+
+ def __setitem__(self, key, value):
+ self._tree[key] = value
+
+ def __getitem__(self, x):
+ l = x.split("/")
+ try:
+ l.remove("")
+ l.remove("")
+ except ValueError:
+ pass
+ return self._getitem_from_list(l)
+
+ def _getitem_from_list(self, x):
+ if len(x) == 0:
+ raise Exception("invalid slice value to tree")
+
+ k = x.pop(0)
+ er = f"{k} not found in tree - check keys"
+ assert k in self._tree.keys(), er
+
+ if len(x) == 0:
+ return self._tree[k]
+ else:
+ tree = self._tree[k].tree
+ return tree._getitem_from_list(x)
+
+ def __repr__(self):
+ space = " " * len(self.__class__.__name__) + " "
+ string = f"{self.__class__.__name__}( An object tree containing the following top-level object instances:"
+ string += "\n"
+ for k, v in self._tree.items():
+ string += "\n" + space + f" {k} \t\t ({v.__class__.__name__})"
+ string += "\n)"
+ return string
+
+ def keys(self):
+ return self._tree.keys()
+
+ def print(self):
+ """
+ Prints the tree contents to screen.
+ """
+ print("/")
+ self._print_tree_to_screen(self)
+ print("\n")
+
+ def _print_tree_to_screen(self, tree, tablevel=0, linelevels=[]):
+ """ """
+ if tablevel not in linelevels:
+ linelevels.append(tablevel)
+ keys = [k for k in tree.keys()]
+ # keys = [k for k in keys if k != 'metadata']
+ N = len(keys)
+ for i, k in enumerate(keys):
+ string = ""
+ string += "|" if 0 in linelevels else ""
+ for idx in range(tablevel):
+ l = "|" if idx + 1 in linelevels else ""
+ string += "\t" + l
+ # print(string)
+ print(string + "--" + k)
+ if i == N - 1:
+ linelevels.remove(tablevel)
+ try:
+ self._print_tree_to_screen(
+ tree[k].tree, tablevel=tablevel + 1, linelevels=linelevels
+ )
+ except AttributeError:
+ pass
+
+ pass
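+
+
+# Usage sketch (illustrative): nested items can be retrieved with a single
+# '/'-delimited path, which __getitem__ splits and walks recursively.
+# Assumes `child` and `grandchild` are objects carrying their own .tree:
+#
+#     >>> tree = Tree()
+#     >>> tree["child"] = child
+#     >>> child.tree["grandchild"] = grandchild
+#     >>> tree["child/grandchild"] is grandchild
+#     True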
diff --git a/py4DSTEM/io/legacy/legacy13/v13_py4dstem_classes/__init__.py b/py4DSTEM/io/legacy/legacy13/v13_py4dstem_classes/__init__.py
new file mode 100644
index 000000000..19317f040
--- /dev/null
+++ b/py4DSTEM/io/legacy/legacy13/v13_py4dstem_classes/__init__.py
@@ -0,0 +1,10 @@
+from py4DSTEM.io.legacy.legacy13.v13_py4dstem_classes.calibration import *
+from py4DSTEM.io.legacy.legacy13.v13_py4dstem_classes.datacube import *
+from py4DSTEM.io.legacy.legacy13.v13_py4dstem_classes.probe import *
+from py4DSTEM.io.legacy.legacy13.v13_py4dstem_classes.diffractionslice import *
+from py4DSTEM.io.legacy.legacy13.v13_py4dstem_classes.virtualdiffraction import *
+from py4DSTEM.io.legacy.legacy13.v13_py4dstem_classes.realslice import *
+from py4DSTEM.io.legacy.legacy13.v13_py4dstem_classes.virtualimage import *
+from py4DSTEM.io.legacy.legacy13.v13_py4dstem_classes.qpoints import *
+from py4DSTEM.io.legacy.legacy13.v13_py4dstem_classes.braggvectors import *
+from py4DSTEM.io.legacy.legacy13.v13_py4dstem_classes.io import *
diff --git a/py4DSTEM/io/legacy/legacy13/v13_py4dstem_classes/braggvectors.py b/py4DSTEM/io/legacy/legacy13/v13_py4dstem_classes/braggvectors.py
new file mode 100644
index 000000000..4e51bdebf
--- /dev/null
+++ b/py4DSTEM/io/legacy/legacy13/v13_py4dstem_classes/braggvectors.py
@@ -0,0 +1,95 @@
+# Defines the BraggVectors class
+
+
+from typing import Optional, Union
+import numpy as np
+import h5py
+
+from py4DSTEM.io.legacy.legacy13.v13_emd_classes import PointListArray
+from py4DSTEM.io.legacy.legacy13.v13_emd_classes.tree import Tree
+from py4DSTEM.io.legacy.legacy13.v13_emd_classes.metadata import Metadata
+
+
+class BraggVectors:
+ """
+ Stores Bragg scattering information for a 4D datacube.
+ >>> braggvectors = BraggVectors( datacube.Rshape, datacube.Qshape )
+ initializes an instance of the appropriate shape for a DataCube `datacube`.
+ >>> braggvectors.vectors[rx,ry]
+ >>> braggvectors.vectors_uncal[rx,ry]
+ retrieve, respectively, the calibrated and uncalibrated Bragg vectors at
+ scan position [rx,ry], and
+ >>> braggvectors.vectors[rx,ry]['qx']
+ >>> braggvectors.vectors[rx,ry]['qy']
+ >>> braggvectors.vectors[rx,ry]['intensity']
+ retrieve the positions and intensities of the scattering.
+ """
+
+ def __init__(self, Rshape, Qshape, name="braggvectors"):
+ self.name = name
+ self.Rshape = Rshape
+ self.shape = self.Rshape
+ self.Qshape = Qshape
+
+ self.tree = Tree()
+ if not hasattr(self, "_metadata"):
+ self._metadata = {}
+ if "braggvectors" not in self._metadata.keys():
+ self.metadata = Metadata(name="braggvectors")
+ self.metadata["braggvectors"]["Qshape"] = self.Qshape
+
+ self._v_uncal = PointListArray(
+ dtype=[("qx", np.float64), ("qy", np.float64), ("intensity", np.float64)],
+ shape=Rshape,
+ name="_v_uncal",
+ )
+
+ @property
+ def vectors(self):
+ try:
+ return self._v_cal
+ except AttributeError:
+ er = "No calibrated bragg vectors found. Try running .calibrate()!"
+ raise Exception(er)
+
+ @property
+ def vectors_uncal(self):
+ return self._v_uncal
+
+ @property
+ def metadata(self):
+ return self._metadata
+
+ @metadata.setter
+ def metadata(self, x):
+ assert isinstance(x, Metadata)
+ self._metadata[x.name] = x
+
+ ## Representation to standard output
+
+ def __repr__(self):
+ space = " " * len(self.__class__.__name__) + " "
+ string = f"{self.__class__.__name__}( "
+ string += f"A {self.shape}-shaped array of lists of bragg vectors )"
+ return string
+
+ # HDF5 read/write
+
+ # write
+ def to_h5(self, group):
+ from py4DSTEM.io.legacy.legacy13.v13_py4dstem_classes.io import (
+ BraggVectors_to_h5,
+ )
+
+ BraggVectors_to_h5(self, group)
+
+ # read
+ def from_h5(group):
+ from py4DSTEM.io.legacy.legacy13.v13_py4dstem_classes.io import (
+ BraggVectors_from_h5,
+ )
+
+ return BraggVectors_from_h5(group)
+
+
+############ END OF CLASS ###########
diff --git a/py4DSTEM/io/legacy/legacy13/v13_py4dstem_classes/calibration.py b/py4DSTEM/io/legacy/legacy13/v13_py4dstem_classes/calibration.py
new file mode 100644
index 000000000..cbf4cd1fe
--- /dev/null
+++ b/py4DSTEM/io/legacy/legacy13/v13_py4dstem_classes/calibration.py
@@ -0,0 +1,64 @@
+# Defines the Calibration class, which stores calibration metadata
+
+from typing import Optional
+from py4DSTEM.io.legacy.legacy13.v13_emd_classes.metadata import Metadata
+
+
+class Calibration(Metadata):
+ """ """
+
+ def __init__(
+ self,
+ name: Optional[str] = "calibration",
+ ):
+ """
+ Args:
+ name (optional, str):
+ """
+ Metadata.__init__(self, name=name)
+
+ self.set_Q_pixel_size(1)
+ self.set_R_pixel_size(1)
+ self.set_Q_pixel_units("pixels")
+ self.set_R_pixel_units("pixels")
+
+ # helper for the getters; unset parameters read back as None
+ def _get_value(self, p):
+ return self._params.get(p)
+
+ def set_Q_pixel_size(self, x):
+ self._params["Q_pixel_size"] = x
+
+ def get_Q_pixel_size(self):
+ return self._get_value("Q_pixel_size")
+
+ def set_R_pixel_size(self, x):
+ self._params["R_pixel_size"] = x
+
+ def get_R_pixel_size(self):
+ return self._get_value("R_pixel_size")
+
+ def set_Q_pixel_units(self, x):
+ pix = ("pixels", "A^-1", "mrad")
+ assert x in pix, f"{x} must be in {pix}"
+ self._params["Q_pixel_units"] = x
+
+ def get_Q_pixel_units(self):
+ return self._get_value("Q_pixel_units")
+
+ def set_R_pixel_units(self, x):
+ self._params["R_pixel_units"] = x
+
+ def get_R_pixel_units(self):
+ return self._get_value("R_pixel_units")
+
+ # HDF5 read/write
+
+ # write inherited from Metadata
+
+ # read
+ def from_h5(group):
+ from py4DSTEM.io.legacy.legacy13.v13_py4dstem_classes.io import (
+ Calibration_from_h5,
+ )
+
+ return Calibration_from_h5(group)
+
+
+########## End of class ##########
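+
+
+# Usage sketch (illustrative): Calibration behaves like a Metadata instance
+# with dedicated getter/setter pairs; unset parameters read back as None:
+#
+#     >>> cal = Calibration()
+#     >>> cal.set_Q_pixel_size(0.02)
+#     >>> cal.set_Q_pixel_units("A^-1")
+#     >>> cal.get_Q_pixel_size()
+#     0.02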
diff --git a/py4DSTEM/io/legacy/legacy13/v13_py4dstem_classes/datacube.py b/py4DSTEM/io/legacy/legacy13/v13_py4dstem_classes/datacube.py
new file mode 100644
index 000000000..422d47bc6
--- /dev/null
+++ b/py4DSTEM/io/legacy/legacy13/v13_py4dstem_classes/datacube.py
@@ -0,0 +1,180 @@
+# Defines the DataCube class, which stores 4D-STEM datacubes
+
+from py4DSTEM.io.legacy.legacy13.v13_emd_classes.array import Array
+from py4DSTEM.io.legacy.legacy13.v13_py4dstem_classes.calibration import Calibration
+from py4DSTEM.io.legacy.legacy13.v13_py4dstem_classes.parenttree import ParentTree
+
+from typing import Optional, Union
+import numpy as np
+import h5py
+
+
+class DataCube(Array):
+ """
+ Stores 4D-STEM datasets.
+ """
+
+ def __init__(
+ self,
+ data: np.ndarray,
+ name: Optional[str] = "datacube",
+ R_pixel_size: Optional[Union[float, list]] = 1,
+ R_pixel_units: Optional[Union[str, list]] = "pixels",
+ Q_pixel_size: Optional[Union[float, list]] = 1,
+ Q_pixel_units: Optional[Union[str, list]] = "pixels",
+ slicelabels: Optional[Union[bool, list]] = None,
+ calibration: Optional[Calibration] = None,
+ ):
+ """
+ Accepts:
+ data (np.ndarray): the data
+ name (str): the name of the datacube
+ R_pixel_size (float or length 2 list of floats): the real space
+ pixel size
+ R_pixel_units (str or length 2 list of str): the real space
+ pixel units
+ Q_pixel_size (float or length 2 list of floats): the diffraction space
+ pixel size
+ Q_pixel_units (str or length 2 list of str): the diffraction space
+ pixel units. Must be 'pixels', 'A^-1', or 'mrad'.
+ slicelabels (None or list): names for slices if this is a
+ stack of datacubes
+ calibration (Calibration):
+ Returns:
+ A new DataCube instance
+ """
+
+ # initialize as an Array
+ Array.__init__(
+ self,
+ data=data,
+ name=name,
+ units="pixel intensity",
+ dims=[R_pixel_size, R_pixel_size, Q_pixel_size, Q_pixel_size],
+ dim_units=[R_pixel_units, R_pixel_units, Q_pixel_units, Q_pixel_units],
+ dim_names=["Rx", "Ry", "Qx", "Qy"],
+ slicelabels=slicelabels,
+ )
+
+ # make a tree
+ # we overwrite the emd Tree with the py4DSTEM ParentTree, which knows
+ # how to track the parent datacube and also adds calibration to the tree
+ self.tree = ParentTree(self, Calibration())
+
+ # set size/units
+ self.tree["calibration"].set_R_pixel_size(R_pixel_size)
+ self.tree["calibration"].set_R_pixel_units(R_pixel_units)
+ self.tree["calibration"].set_Q_pixel_size(Q_pixel_size)
+ self.tree["calibration"].set_Q_pixel_units(Q_pixel_units)
+
+ ## properties
+
+ # FOV
+ @property
+ def R_Nx(self):
+ return self.data.shape[0]
+
+ @property
+ def R_Ny(self):
+ return self.data.shape[1]
+
+ @property
+ def Q_Nx(self):
+ return self.data.shape[2]
+
+ @property
+ def Q_Ny(self):
+ return self.data.shape[3]
+
+ @property
+ def Rshape(self):
+ return (self.data.shape[0], self.data.shape[1])
+
+ @property
+ def Qshape(self):
+ return (self.data.shape[2], self.data.shape[3])
+
+ @property
+ def R_N(self):
+ return self.R_Nx * self.R_Ny
+
+ # pixel sizes/units
+
+ # R
+ @property
+ def R_pixel_size(self):
+ return self.calibration.get_R_pixel_size()
+
+ @R_pixel_size.setter
+ def R_pixel_size(self, x):
+ if type(x) is not list:
+ x = [x, x]
+ self.set_dim(0, [0, x[0]])
+ self.set_dim(1, [0, x[1]])
+ self.calibration.set_R_pixel_size(x)
+
+ @property
+ def R_pixel_units(self):
+ return self.calibration.get_R_pixel_units()
+
+ @R_pixel_units.setter
+ def R_pixel_units(self, x):
+ if type(x) is not list:
+ x = [x, x]
+ self.dim_units[0] = x[0]
+ self.dim_units[1] = x[1]
+ self.calibration.set_R_pixel_units(x)
+
+ # Q
+ @property
+ def Q_pixel_size(self):
+ return self.calibration.get_Q_pixel_size()
+
+ @Q_pixel_size.setter
+ def Q_pixel_size(self, x):
+ if type(x) is not list:
+ x = [x, x]
+ self.set_dim(2, [0, x[0]])
+ self.set_dim(3, [0, x[1]])
+ self.calibration.set_Q_pixel_size(x)
+
+ @property
+ def Q_pixel_units(self):
+ return self.calibration.get_Q_pixel_units()
+
+ @Q_pixel_units.setter
+ def Q_pixel_units(self, x):
+ if type(x) is not list:
+ x = [x, x]
+ self.dim_units[2] = x[0]
+ self.dim_units[3] = x[1]
+ self.calibration.set_Q_pixel_units(x)
+
+ # calibration
+ @property
+ def calibration(self):
+ return self.tree["calibration"]
+
+ @calibration.setter
+ def calibration(self, x):
+ assert isinstance(x, Calibration)
+ self.tree["calibration"] = x
+
+ # for parent datacube tracking
+ def track_parent(self, x):
+ x._parent = self
+ x.calibration = self.calibration
+
+ # HDF5 read/write
+
+ # write is inherited from Array
+
+ # read
+ def from_h5(group):
+ from py4DSTEM.io.legacy.legacy13.v13_py4dstem_classes.io import DataCube_from_h5
+
+ return DataCube_from_h5(group)
+
+
+############ END OF CLASS ###########
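+
+
+# Usage sketch (illustrative): wrapping a 4D numpy array. The pixel sizes
+# and units flow into both the dim vectors and the attached Calibration:
+#
+#     >>> import numpy as np
+#     >>> dc = DataCube(
+#     ...     data=np.zeros((8, 8, 64, 64)),
+#     ...     R_pixel_size=2.0, R_pixel_units="A",
+#     ...     Q_pixel_size=0.02, Q_pixel_units="A^-1",
+#     ... )
+#     >>> dc.Rshape, dc.Qshape
+#     ((8, 8), (64, 64))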
diff --git a/py4DSTEM/io/legacy/legacy13/v13_py4dstem_classes/diffractionslice.py b/py4DSTEM/io/legacy/legacy13/v13_py4dstem_classes/diffractionslice.py
new file mode 100644
index 000000000..b32877a4a
--- /dev/null
+++ b/py4DSTEM/io/legacy/legacy13/v13_py4dstem_classes/diffractionslice.py
@@ -0,0 +1,49 @@
+# Defines the DiffractionSlice class, which stores 2(+1)D,
+# diffraction-shaped data
+
+from py4DSTEM.io.legacy.legacy13.v13_emd_classes.array import Array
+
+from typing import Optional, Union
+import numpy as np
+import h5py
+
+
+class DiffractionSlice(Array):
+ """
+ Stores a diffraction-space shaped 2D data array.
+ """
+
+ def __init__(
+ self,
+ data: np.ndarray,
+ name: Optional[str] = "diffractionslice",
+ slicelabels: Optional[Union[bool, list]] = None,
+ ):
+ """
+ Accepts:
+ data (np.ndarray): the data
+ name (str): the name of the diffslice
+ slicelabels (None or list): names for slices if this is a 3D stack
+ Returns:
+ (DiffractionSlice instance)
+ """
+
+ # initialize as an Array
+ Array.__init__(
+ self, data=data, name=name, units="intensity", slicelabels=slicelabels
+ )
+
+ # HDF5 read/write
+
+ # write inherited from Array
+
+ # read
+ def from_h5(group):
+ from py4DSTEM.io.legacy.legacy13.v13_py4dstem_classes.io import (
+ DiffractionSlice_from_h5,
+ )
+
+ return DiffractionSlice_from_h5(group)
+
+
+############ END OF CLASS ###########
diff --git a/py4DSTEM/io/legacy/legacy13/v13_py4dstem_classes/io.py b/py4DSTEM/io/legacy/legacy13/v13_py4dstem_classes/io.py
new file mode 100644
index 000000000..2556ebe8f
--- /dev/null
+++ b/py4DSTEM/io/legacy/legacy13/v13_py4dstem_classes/io.py
@@ -0,0 +1,477 @@
+# Functions for reading and writing subclasses of the base EMD types
+
+import numpy as np
+import h5py
+from os.path import basename
+
+from py4DSTEM.io.legacy.legacy13.v13_emd_classes.io import (
+ Array_from_h5,
+ Metadata_from_h5,
+)
+from py4DSTEM.io.legacy.legacy13.v13_emd_classes.io import PointList_from_h5
+from py4DSTEM.io.legacy.legacy13.v13_emd_classes.io import (
+ PointListArray_from_h5,
+ PointListArray_to_h5,
+)
+from py4DSTEM.io.legacy.legacy13.v13_emd_classes.io import (
+ _write_metadata,
+ _read_metadata,
+)
+
+
+# Calibration
+
+
+# read
+def Calibration_from_h5(group: h5py.Group):
+ """
+ Takes a valid HDF5 group for an HDF5 file object which is open in
+ read mode. Determines if it's a valid Metadata representation, and
+ if so loads and returns it as a Calibration instance. Otherwise,
+ raises an exception.
+ Accepts:
+ group (HDF5 group)
+ Returns:
+ A Calibration instance
+ """
+ cal = Metadata_from_h5(group)
+ cal = Calibration_from_Metadata(cal)
+ return cal
+
+
+def Calibration_from_Metadata(metadata):
+ """
+ Constructs a Calibration object with the dict entries of a Metadata object
+ Accepts:
+ metadata (Metadata)
+ Returns:
+ (Calibration)
+ """
+ from py4DSTEM.io.legacy.legacy13.v13_py4dstem_classes.calibration import Calibration
+
+ cal = Calibration(name=metadata.name)
+ cal._params.update(metadata._params)
+
+ return cal
+
+
+# DataCube
+
+# read
+
+
+def DataCube_from_h5(group: h5py.Group):
+ """
+ Takes a valid HDF5 group for an HDF5 file object which is open in read
+ mode. Determines if the group represents a valid Array object, and if it
+ does, loads and returns it as a DataCube. If it doesn't, or if it does but
+ the Array does not have rank 4, raises an exception.
+ Accepts:
+ group (HDF5 group)
+ Returns:
+ A DataCube instance
+ """
+ datacube = Array_from_h5(group)
+ datacube = DataCube_from_Array(datacube)
+ return datacube
+
+
+def DataCube_from_Array(array):
+ """
+ Converts an Array to a DataCube.
+ Accepts:
+ array (Array)
+ Returns:
+ datacube (DataCube)
+ """
+ from py4DSTEM.io.legacy.legacy13.v13_py4dstem_classes.datacube import DataCube
+
+ assert array.rank == 4, "Array must have 4 dimensions"
+ array.__class__ = DataCube
+ try:
+ R_pixel_size = array.dims[0][1] - array.dims[0][0]
+ except IndexError:
+ R_pixel_size = 1
+ try:
+ Q_pixel_size = array.dims[2][1] - array.dims[2][0]
+ except IndexError:
+ Q_pixel_size = 1
+ array.__init__(
+ data=array.data,
+ name=array.name,
+ R_pixel_size=R_pixel_size,
+ R_pixel_units=array.dim_units[0],
+ Q_pixel_size=Q_pixel_size,
+ Q_pixel_units=array.dim_units[2],
+ slicelabels=array.slicelabels,
+ )
+ return array
+
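+# Note (illustrative): linear dim vectors are stored compressed to their
+# first two elements (see Array_to_h5), so the pixel size is recovered
+# above as dims[n][1] - dims[n][0]; a length-1 dim falls back to a pixel
+# size of 1 via the IndexError handler.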
+
+# DiffractionSlice
+
+# read
+
+
+def DiffractionSlice_from_h5(group: h5py.Group):
+ """
+ Takes a valid HDF5 group for an HDF5 file object which is open in
+ read mode. Determines if it's a valid Array, and if so loads and
+ returns it as a DiffractionSlice. Otherwise, raises an exception.
+ Accepts:
+ group (HDF5 group)
+ Returns:
+ A DiffractionSlice instance
+ """
+ diffractionslice = Array_from_h5(group)
+ diffractionslice = DiffractionSlice_from_Array(diffractionslice)
+ return diffractionslice
+
+
+def DiffractionSlice_from_Array(array):
+ """
+ Converts an Array to a DiffractionSlice.
+ Accepts:
+ array (Array)
+ Returns:
+ (DiffractionSlice)
+ """
+ from py4DSTEM.io.legacy.legacy13.v13_py4dstem_classes.diffractionslice import (
+ DiffractionSlice,
+ )
+
+ assert array.rank == 2, "Array must have 2 dimensions"
+ array.__class__ = DiffractionSlice
+ array.__init__(data=array.data, name=array.name, slicelabels=array.slicelabels)
+ return array
+
+
+# RealSlice
+
+# read
+
+
+def RealSlice_from_h5(group: h5py.Group):
+ """
+ Takes a valid HDF5 group for an HDF5 file object which is open in
+ read mode. Determines if it's a valid Array, and if so loads and
+ returns it as a RealSlice. Otherwise, raises an exception.
+
+ Accepts:
+ group (HDF5 group)
+
+ Returns:
+ A RealSlice instance
+ """
+ realslice = Array_from_h5(group)
+ realslice = RealSlice_from_Array(realslice)
+ return realslice
+
+
+def RealSlice_from_Array(array):
+ """
+ Converts an Array to a RealSlice.
+
+ Accepts:
+ array (Array)
+
+ Returns:
+ (RealSlice)
+ """
+ from py4DSTEM.io.legacy.legacy13.v13_py4dstem_classes.realslice import RealSlice
+
+ assert array.rank == 2, "Array must have 2 dimensions"
+ array.__class__ = RealSlice
+ array.__init__(data=array.data, name=array.name, slicelabels=array.slicelabels)
+ return array
+
+
+# VirtualDiffraction
+
+# read
+
+
+def VirtualDiffraction_from_h5(group: h5py.Group):
+ """
+ Takes a valid HDF5 group for an HDF5 file object which is open in
+ read mode. Determines if it's a valid Array, and if so loads and
+ returns it as a VirtualDiffraction. Otherwise, raises an exception.
+ Accepts:
+ group (HDF5 group)
+ Returns:
+ A VirtualDiffraction instance
+ """
+ virtualdiffraction = Array_from_h5(group)
+ virtualdiffraction = VirtualDiffraction_from_Array(virtualdiffraction)
+ return virtualdiffraction
+
+
+def VirtualDiffraction_from_Array(array):
+ """
+ Converts an Array to a VirtualDiffraction.
+ Accepts:
+ array (Array)
+ Returns:
+ (VirtualDiffraction)
+ """
+ from py4DSTEM.io.legacy.legacy13.v13_py4dstem_classes.virtualdiffraction import (
+ VirtualDiffraction,
+ )
+
+ assert array.rank == 2, "Array must have 2 dimensions"
+
+ # get diffraction image metadata
+ try:
+ md = array.metadata["virtualdiffraction"]
+ method = md["method"]
+ mode = md["mode"]
+ geometry = md["geometry"]
+ shift_center = md["shift_center"]
+ except KeyError:
+ print("Warning: VirtualDiffraction metadata could not be found")
+ method = ""
+ mode = ""
+ geometry = ""
+ shift_center = ""
+
+ # instantiate as a VirtualDiffraction
+ array.__class__ = VirtualDiffraction
+ array.__init__(
+ data=array.data,
+ name=array.name,
+ method=method,
+ mode=mode,
+ geometry=geometry,
+ shift_center=shift_center,
+ )
+ return array
+
+
+# VirtualImage
+
+# read
+
+
+def VirtualImage_from_h5(group: h5py.Group):
+ """
+ Takes a valid HDF5 group for an HDF5 file object which is open in
+ read mode. Determines if it's a valid Array, and if so loads and
+ returns it as a VirtualImage. Otherwise, raises an exception.
+ Accepts:
+ group (HDF5 group)
+ Returns:
+ A VirtualImage instance
+ """
+ image = Array_from_h5(group)
+ image = VirtualImage_from_Array(image)
+ return image
+
+
+def VirtualImage_from_Array(array):
+ """
+ Converts an Array to a VirtualImage.
+ Accepts:
+ array (Array)
+ Returns:
+ (VirtualImage)
+ """
+ from py4DSTEM.io.legacy.legacy13.v13_py4dstem_classes.virtualimage import (
+ VirtualImage,
+ )
+
+ assert array.rank == 2, "Array must have 2 dimensions"
+
+ # get diffraction image metadata
+ try:
+ md = array.metadata["virtualimage"]
+ mode = md["mode"]
+ geo = md["geometry"]
+ centered = md._params.get("centered", None)
+ calibrated = md._params.get("calibrated", None)
+ shift_center = md._params.get("shift_center", None)
+ dask = md._params.get("dask", None)
+ except KeyError:
+ er = "VirtualImage metadata could not be found"
+ raise Exception(er)
+
+ # instantiate as a VirtualImage
+ array.__class__ = VirtualImage
+ array.__init__(
+ data=array.data,
+ name=array.name,
+ mode=mode,
+ geometry=geo,
+ centered=centered,
+ calibrated=calibrated,
+ shift_center=shift_center,
+ dask=dask,
+ )
+ return array
+
+
+# Probe
+
+# read
+
+
+def Probe_from_h5(group: h5py.Group):
+ """
+ Takes a valid HDF5 group for an HDF5 file object which is open in
+ read mode. Determines if it's a valid Array, and if so loads and
+ returns it as a Probe. Otherwise, raises an exception.
+ Accepts:
+ group (HDF5 group)
+ Returns:
+ A Probe instance
+ """
+ probe = Array_from_h5(group)
+ probe = Probe_from_Array(probe)
+ return probe
+
+
+def Probe_from_Array(array):
+ """
+ Converts an Array to a Probe.
+ Accepts:
+ array (Array)
+ Returns:
+ (Probe)
+ """
+ from py4DSTEM.io.legacy.legacy13.v13_py4dstem_classes.probe import Probe
+
+ assert array.rank == 2, "Array must have 2 dimensions"
+ # get diffraction image metadata
+ try:
+ md = array.metadata["probe"]
+ kwargs = {}
+ for k in md.keys:
+ v = md[k]
+ kwargs[k] = v
+ except KeyError:
+ er = "Probe metadata could not be found"
+ raise Exception(er)
+
+ # instantiate as a Probe
+ array.__class__ = Probe
+ array.__init__(data=array.data, name=array.name, **kwargs)
+ return array
+
+
+# QPoints
+
+# Reading
+
+
+def QPoints_from_h5(group: h5py.Group):
+ """
+ Takes a valid HDF5 group for an HDF5 file object which is open in
+ read mode. Determines if it's a valid QPoints instance, and if so
+ loads and returns it. Otherwise, raises an exception.
+ Accepts:
+ group (HDF5 group)
+ Returns:
+ A QPoints instance
+ """
+ qpoints = PointList_from_h5(group)
+ qpoints = QPoints_from_PointList(qpoints)
+ return qpoints
+
+
+def QPoints_from_PointList(pointlist):
+ """
+ Converts a PointList to QPoints.
+ Accepts:
+ pointlist (PointList)
+ Returns:
+ (QPoints)
+ """
+ from py4DSTEM.io.legacy.legacy13.v13_py4dstem_classes.qpoints import QPoints
+
+ pointlist.__class__ = QPoints
+ pointlist.__init__(
+ data=pointlist.data,
+ name=pointlist.name,
+ )
+ return pointlist
+
+
+# BraggVectors
+
+
+# write
+def BraggVectors_to_h5(braggvectors, group):
+ """
+ Takes a valid HDF5 group for an HDF5 file object which is open in
+ write or append mode. Writes a new group with a name given by this
+ BraggVectors' .name field nested inside the passed group, and saves
+ the data there.
+ Accepts:
+ braggvectors (BraggVectors)
+ group (HDF5 group)
+ """
+
+ ## Write
+ grp = group.create_group(braggvectors.name)
+ grp.attrs.create("emd_group_type", 4) # this tag indicates a Custom type
+ grp.attrs.create("py4dstem_class", braggvectors.__class__.__name__)
+
+ # Ensure that the PointListArrays have the appropriate names
+ braggvectors._v_uncal.name = "_v_uncal"
+
+ # Add vectors
+ PointListArray_to_h5(braggvectors._v_uncal, grp)
+ try:
+ braggvectors._v_cal.name = "_v_cal"
+ PointListArray_to_h5(braggvectors._v_cal, grp)
+ except AttributeError:
+ pass
+
+ # Add metadata
+ _write_metadata(braggvectors, grp)
+
+
+# read
+def BraggVectors_from_h5(group: h5py.Group):
+ """
+ Takes a valid HDF5 group for an HDF5 file object which is open in read mode.
+ Determines if the group holds a valid BraggVectors object, and if it does,
+ loads and returns it. If it doesn't, raises an exception.
+ Accepts:
+ group (HDF5 group)
+ Returns:
+ A BraggVectors instance
+ """
+ from py4DSTEM.io.legacy.legacy13.v13_py4dstem_classes.braggvectors import (
+ BraggVectors,
+ )
+
+ er = f"Group {group} is not a valid BraggVectors group"
+ assert "emd_group_type" in group.attrs.keys(), er
+ assert group.attrs["emd_group_type"] == 4, er
+
+ # Get uncalibrated peak
+ v_uncal = PointListArray_from_h5(group["_v_uncal"])
+
+ # Get Qshape metadata
+ try:
+ grp_metadata = group["_metadata"]
+ Qshape = Metadata_from_h5(grp_metadata["braggvectors"])["Qshape"]
+ except KeyError:
+ raise Exception("could not read Qshape")
+
+ # Set up BraggVectors
+ braggvectors = BraggVectors(v_uncal.shape, Qshape=Qshape, name=basename(group.name))
+ braggvectors._v_uncal = v_uncal
+
+ # Add calibrated peaks, if they're there
+ try:
+ v_cal = PointListArray_from_h5(group["_v_cal"])
+ braggvectors._v_cal = v_cal
+ except KeyError:
+ pass
+
+ # Add remaining metadata
+ _read_metadata(braggvectors, group)
+
+ return braggvectors
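+
+
+# Round-trip sketch (illustrative): assumes `bv` is a BraggVectors instance
+# and `f` an h5py.File open for writing. The group holds one subgroup per
+# PointListArray plus a "_metadata" subgroup recording Qshape:
+#
+#     >>> BraggVectors_to_h5(bv, f)
+#     >>> bv2 = BraggVectors_from_h5(f[bv.name])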
diff --git a/py4DSTEM/io/legacy/legacy13/v13_py4dstem_classes/parenttree.py b/py4DSTEM/io/legacy/legacy13/v13_py4dstem_classes/parenttree.py
new file mode 100644
index 000000000..cb3c7853e
--- /dev/null
+++ b/py4DSTEM/io/legacy/legacy13/v13_py4dstem_classes/parenttree.py
@@ -0,0 +1,98 @@
+# Defines the ParentTree class, which inherits from emd.Tree, and
+# adds the ability to track a parent datacube.
+
+from py4DSTEM.io.legacy.legacy13.v13_emd_classes.tree import Tree
+from py4DSTEM.io.legacy.legacy13.v13_emd_classes.array import Array
+from py4DSTEM.io.legacy.legacy13.v13_py4dstem_classes.calibration import Calibration
+from numpy import ndarray
+
+
+class ParentTree(Tree):
+ def __init__(self, parent, calibration):
+ """
+ Creates a tree which is aware of its parent and associated calibration,
+ and which points objects added to it at both.
+ `parent` is typically a DataCube, but need not be.
+ `calibration` should be a Calibration instance.
+ """
+ assert isinstance(calibration, Calibration)
+
+ Tree.__init__(self)
+ self._tree["calibration"] = calibration
+ self._parent = parent
+
+ def __setitem__(self, key, value):
+ if isinstance(value, ndarray):
+ value = Array(data=value, name=key)
+ self._tree[key] = value
+ value._parent = self._parent
+ value.calibration = self._parent.calibration
+
+ def __getitem__(self, x):
+ l = x.split("/")
+ try:
+ l.remove("")
+ l.remove("")
+ except ValueError:
+ pass
+ return self._getitem_from_list(l)
+
+ def _getitem_from_list(self, x):
+ if len(x) == 0:
+ raise Exception("invalid slice value to tree")
+
+ k = x.pop(0)
+ er = f"{k} not found in tree - check keys"
+ assert k in self._tree.keys(), er
+
+ if len(x) == 0:
+ return self._tree[k]
+ else:
+ tree = self._tree[k].tree
+ return tree._getitem_from_list(x)
+
+ def __repr__(self):
+ space = " " * len(self.__class__.__name__) + " "
+ string = f"{self.__class__.__name__}( An object tree containing the following top-level object instances:"
+ string += "\n"
+ for k, v in self._tree.items():
+ string += "\n" + space + f" {k} \t\t ({v.__class__.__name__})"
+ string += "\n)"
+ return string
+
+ def keys(self):
+ return self._tree.keys()
+
+ def print(self):
+ """
+ Prints the tree contents to screen.
+ """
+ print("/")
+ self._print_tree_to_screen(self)
+ print("\n")
+
+ def _print_tree_to_screen(self, tree, tablevel=0, linelevels=[]):
+ """ """
+ if tablevel not in linelevels:
+ linelevels.append(tablevel)
+ keys = [k for k in tree.keys()]
+ # keys = [k for k in keys if k != 'metadata']
+ N = len(keys)
+ for i, k in enumerate(keys):
+ string = ""
+ string += "|" if 0 in linelevels else ""
+ for idx in range(tablevel):
+ l = "|" if idx + 1 in linelevels else ""
+ string += "\t" + l
+ # print(string)
+ print(string + "--" + k)
+ if i == N - 1:
+ linelevels.remove(tablevel)
+ try:
+ self._print_tree_to_screen(
+ tree[k].tree, tablevel=tablevel + 1, linelevels=linelevels
+ )
+ except AttributeError:
+ pass
+
+ pass
diff --git a/py4DSTEM/io/legacy/legacy13/v13_py4dstem_classes/probe.py b/py4DSTEM/io/legacy/legacy13/v13_py4dstem_classes/probe.py
new file mode 100644
index 000000000..cd1c7d9d9
--- /dev/null
+++ b/py4DSTEM/io/legacy/legacy13/v13_py4dstem_classes/probe.py
@@ -0,0 +1,74 @@
+# Defines the Probe class, which stores vacuum probes
+# and cross-correlation kernels derived from them
+
+from py4DSTEM.io.legacy.legacy13.v13_emd_classes.array import Array, Metadata
+from py4DSTEM.io.legacy.legacy13.v13_py4dstem_classes.diffractionslice import (
+ DiffractionSlice,
+)
+
+from typing import Optional, Union
+import numpy as np
+import h5py
+
+
+class Probe(DiffractionSlice):
+ """
+ Stores a vacuum probe.
+ """
+
+ def __init__(self, data: np.ndarray, name: Optional[str] = "probe", **kwargs):
+ """
+ Accepts:
+ data (2D or 3D np.ndarray): the vacuum probe, or
+ the vacuum probe + kernel
+ name (str): a name
+ Returns:
+ (Probe)
+ """
+ # if only the probe is passed, make space for the kernel
+ if data.ndim == 2:
+ data = np.dstack([data, np.zeros_like(data)])
+
+ # initialize as a DiffractionSlice
+ DiffractionSlice.__init__(
+ self, name=name, data=data, slicelabels=["probe", "kernel"]
+ )
+
+ # Set metadata
+ md = Metadata(name="probe")
+ for k, v in kwargs.items():
+ md[k] = v
+ self.metadata = md
+
+ ## properties
+
+ @property
+ def probe(self):
+ return self.get_slice("probe").data
+
+ @probe.setter
+ def probe(self, x):
+ assert x.shape == (self.data.shape[:2])
+ self.data[:, :, 0] = x
+
+ @property
+ def kernel(self):
+ return self.get_slice("kernel").data
+
+ @kernel.setter
+ def kernel(self, x):
+ assert x.shape == (self.data.shape[:2])
+ self.data[:, :, 1] = x
+
+ # HDF5 read/write
+
+ # write inherited from Array
+
+ # read
+ def from_h5(group):
+ from py4DSTEM.io.legacy.legacy13.v13_py4dstem_classes.io import Probe_from_h5
+
+ return Probe_from_h5(group)
+
+
+############ END OF CLASS ###########
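+
+
+# Usage sketch (illustrative): passing a 2D vacuum probe allocates an empty
+# kernel slice alongside it, both accessible by name:
+#
+#     >>> import numpy as np
+#     >>> pr = Probe(data=np.ones((64, 64)))
+#     >>> pr.probe.shape, pr.kernel.shape
+#     ((64, 64), (64, 64))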
diff --git a/py4DSTEM/io/legacy/legacy13/v13_py4dstem_classes/qpoints.py b/py4DSTEM/io/legacy/legacy13/v13_py4dstem_classes/qpoints.py
new file mode 100644
index 000000000..3429c4c8d
--- /dev/null
+++ b/py4DSTEM/io/legacy/legacy13/v13_py4dstem_classes/qpoints.py
@@ -0,0 +1,63 @@
+# Defines the QPoints class, which stores PointLists with fields 'qx','qy','intensity'
+
+from py4DSTEM.io.legacy.legacy13.v13_emd_classes.pointlist import PointList
+
+from typing import Optional, Union
+import numpy as np
+import h5py
+
+
+class QPoints(PointList):
+ """
+ Stores a set of diffraction space points,
+ with fields 'qx', 'qy' and 'intensity'
+ """
+
+ def __init__(
+ self,
+ data: np.ndarray,
+ name: Optional[str] = "qpoints",
+ ):
+ """
+ Accepts:
+ data (structured numpy ndarray): should have three fields, which
+ will be renamed 'qx','qy','intensity'
+ name (str): the name of the QPoints instance
+ Returns:
+ A new QPoints instance
+ """
+
+ # initialize as a PointList
+ PointList.__init__(
+ self,
+ data=data,
+ name=name,
+ )
+
+ # rename fields
+ self.fields = "qx", "qy", "intensity"
+
+ @property
+ def qx(self):
+ return self.data["qx"]
+
+ @property
+ def qy(self):
+ return self.data["qy"]
+
+ @property
+ def intensity(self):
+ return self.data["intensity"]
+
+ # HDF5 read/write
+
+ # write inherited from PointList
+
+ # read
+ def from_h5(group):
+ from py4DSTEM.io.legacy.legacy13.v13_py4dstem_classes.io import QPoints_from_h5
+
+ return QPoints_from_h5(group)
+
+
+############ END OF CLASS ###########
diff --git a/py4DSTEM/io/legacy/legacy13/v13_py4dstem_classes/realslice.py b/py4DSTEM/io/legacy/legacy13/v13_py4dstem_classes/realslice.py
new file mode 100644
index 000000000..367401055
--- /dev/null
+++ b/py4DSTEM/io/legacy/legacy13/v13_py4dstem_classes/realslice.py
@@ -0,0 +1,97 @@
+from py4DSTEM.io.legacy.legacy13.v13_emd_classes.array import Array
+
+from typing import Optional, Union
+import numpy as np
+import h5py
+
+
+class RealSlice(Array):
+ """
+ Stores a real-space shaped 2D data array.
+ """
+
+ def __init__(
+ self,
+ data: np.ndarray,
+ name: Optional[str] = "realslice",
+ pixel_size: Optional[Union[float, list]] = 1,
+ pixel_units: Optional[Union[str, list]] = "pixels",
+ slicelabels: Optional[Union[bool, list]] = None,
+ ):
+ """
+ Accepts:
+ data (np.ndarray): the data
+ name (str): the name of the realslice
+ pixel_size (float or length 2 list of floats): the pixel size
+ pixel_units (str or length 2 list of str): the pixel units
+ slicelabels (None or list): names for slices if this is a stack of
+ realslices
+ Returns:
+ A new RealSlice instance
+ """
+ # expand pixel inputs to include 2 dimensions
+ if type(pixel_size) is not list:
+ pixel_size = [pixel_size, pixel_size]
+ if type(pixel_units) is not list:
+ pixel_units = [pixel_units, pixel_units]
+
+ # initialize as an Array
+ Array.__init__(
+ self,
+ data=data,
+ name=name,
+ units="intensity",
+ dims=[
+ pixel_size[0],
+ pixel_size[1],
+ ],
+ dim_units=[
+ pixel_units[0],
+ pixel_units[1],
+ ],
+ dim_names=["Rx", "Ry"],
+ slicelabels=slicelabels,
+ )
+
+ # setup the size/units with getter/setters
+ self._pixel_size = pixel_size
+ self._pixel_units = pixel_units
+
+ @property
+ def pixel_size(self):
+ return self._pixel_size
+
+ @pixel_size.setter
+ def pixel_size(self, x):
+ if type(x) is not list:
+ x = [x, x]
+ self.set_dim(0, [0, x[0]])
+ self.set_dim(1, [0, x[1]])
+ self._pixel_size = x
+
+ @property
+ def pixel_units(self):
+ return self._pixel_units
+
+ @pixel_units.setter
+ def pixel_units(self, x):
+ if type(x) is not list:
+ x = [x, x]
+ self.dim_units[0] = x[0]
+ self.dim_units[1] = x[1]
+ self._pixel_units = x
+
+ # HDF5 read/write
+
+ # write inherited from Array
+
+ # read
+ def from_h5(group):
+ from py4DSTEM.io.legacy.legacy13.v13_py4dstem_classes.io import (
+ RealSlice_from_h5,
+ )
+
+ return RealSlice_from_h5(group)
+
+
+############ END OF CLASS ###########
diff --git a/py4DSTEM/io/legacy/legacy13/v13_py4dstem_classes/virtualdiffraction.py b/py4DSTEM/io/legacy/legacy13/v13_py4dstem_classes/virtualdiffraction.py
new file mode 100644
index 000000000..188f1d646
--- /dev/null
+++ b/py4DSTEM/io/legacy/legacy13/v13_py4dstem_classes/virtualdiffraction.py
@@ -0,0 +1,95 @@
+# Defines the VirtualDiffraction class, which stores 2D, diffraction-shaped data
+# with metadata about how it was created
+
+from py4DSTEM.io.legacy.legacy13.v13_py4dstem_classes.diffractionslice import (
+ DiffractionSlice,
+)
+from py4DSTEM.io.legacy.legacy13.v13_emd_classes.metadata import Metadata
+
+from typing import Optional, Union
+import numpy as np
+import h5py
+
+
+class VirtualDiffraction(DiffractionSlice):
+ """
+ Stores a diffraction-space shaped 2D image with metadata
+ indicating how this image was generated from a datacube.
+ """
+
+ def __init__(
+ self,
+ data: np.ndarray,
+ name: Optional[str] = "diffractionimage",
+ method: Optional[str] = None,
+ mode: Optional[str] = None,
+ geometry: Optional[Union[tuple, np.ndarray]] = None,
+ calibrated: Optional[bool] = False,
+ shift_center: bool = False,
+ ):
+ """
+ Args:
+ data (np.ndarray) : the 2D data
+ name (str) : the name
+ method (str) : defines the method used to compute the virtual diffraction
+ pattern; options are 'mean', 'median', and 'max'
+ mode (str) : defines mode for selecting area in real space to use for
+ virtual diffraction. The default is None, which means no
+ geometry will be applied and the whole datacube will be used
+ for the calculation. Options:
+ - 'point' uses singular point as detector
+ - 'circle' or 'circular' uses round detector, like bright field
+ - 'annular' or 'annulus' uses annular detector, like dark field
+ - 'rectangle', 'square', 'rectangular', uses rectangular detector
+ - 'mask' flexible detector, any 2D array
+ geometry (variable) : valid entries are determined by the `mode` argument,
+ with values given in pixels, as follows. The default is None, which means
+ no geometry will be applied and the whole datacube will be used for the
+ calculation.
+ - 'point': 2-tuple, (rx,ry),
+ rx and ry are each a single float or int defining the center
+ - 'circle' or 'circular': nested 2-tuple, ((rx,ry),radius),
+ rx, ry and radius are each a single float or int
+ - 'annular' or 'annulus': nested 2-tuple, ((rx,ry),(radius_i,radius_o)),
+ rx, ry, radius_i, and radius_o are each a single float or int
+ - 'rectangle', 'square', 'rectangular': 4-tuple, (xmin,xmax,ymin,ymax)
+ - `mask`: flexible detector, any boolean or floating point 2D array with
+ the same shape as datacube.Rshape
+ calibrated (bool) : if True, geometry is specified in units of 'A' instead of pixels.
+ The datacube's calibrations must have its `"R_pixel_units"` parameter set to "A".
+ If mode is None the geometry and calibration will not be applied.
+ shift_center (bool) : if True, the diffraction pattern is shifted to account for beam shift
+ or the changing of the origin through the scan. The datacube's calibration['origin']
+ parameter must be set. Only 'max' and 'mean' are supported for this option.
+ Returns:
+ A new VirtualDiffraction instance
+ """
+ # initialize as a DiffractionSlice
+ DiffractionSlice.__init__(
+ self,
+ data=data,
+ name=name,
+ )
+
+ # Set metadata
+ md = Metadata(name="virtualdiffraction")
+ md["method"] = method
+ md["mode"] = mode
+ md["geometry"] = geometry
+ md["shift_center"] = shift_center
+ self.metadata = md
+
+ # HDF5 read/write
+
+ # write inherited from Array
+
+ # read
+ def from_h5(group):
+ from py4DSTEM.io.legacy.legacy13.v13_py4dstem_classes.io import (
+ VirtualDiffraction_from_h5,
+ )
+
+ return VirtualDiffraction_from_h5(group)
+
+
+############ END OF CLASS ###########
diff --git a/py4DSTEM/io/legacy/legacy13/v13_py4dstem_classes/virtualimage.py b/py4DSTEM/io/legacy/legacy13/v13_py4dstem_classes/virtualimage.py
new file mode 100644
index 000000000..4d6c38845
--- /dev/null
+++ b/py4DSTEM/io/legacy/legacy13/v13_py4dstem_classes/virtualimage.py
@@ -0,0 +1,100 @@
+# Defines the VirtualImage class, which stores 2D, real-shaped data
+# with metadata about how it was created
+
+from py4DSTEM.io.legacy.legacy13.v13_py4dstem_classes.realslice import RealSlice
+from py4DSTEM.io.legacy.legacy13.v13_emd_classes.metadata import Metadata
+
+from typing import Optional, Union
+import numpy as np
+import h5py
+
+
+class VirtualImage(RealSlice):
+ """
+ Stores a real-space shaped 2D image with metadata
+ indicating how this image was generated from a datacube.
+ """
+
+ def __init__(
+ self,
+ data: np.ndarray,
+ name: Optional[str] = "virtualimage",
+ mode: Optional[str] = None,
+ geometry: Optional[Union[tuple, np.ndarray]] = None,
+ centered: Optional[bool] = False,
+ calibrated: Optional[bool] = False,
+ shift_center: Optional[bool] = False,
+ dask: Optional[bool] = False,
+ ):
+ """
+ Args:
+ data (np.ndarray) : the 2D data
+ name (str) : the name
+ mode (str) : defines geometry mode for calculating virtual image.
+ Options:
+ - 'point' uses a single point as the detector
+ - 'circle' or 'circular' uses a circular detector, as in bright field
+ - 'annular' or 'annulus' uses an annular detector, as in dark field
+ - 'rectangle', 'square', or 'rectangular' uses a rectangular detector
+ - 'mask' is a flexible detector, specified by any 2D array
+ geometry (variable) : valid entries are determined by the value of the `mode`
+ argument, with values given in pixels, as follows:
+ - 'point': 2-tuple, (qx,qy),
+ qx and qy are each a single float or int defining the center
+ - 'circle' or 'circular': nested 2-tuple, ((qx,qy),radius),
+ qx, qy, and radius are each a single float or int
+ - 'annular' or 'annulus': nested 2-tuple, ((qx,qy),(radius_i,radius_o)),
+ qx, qy, radius_i, and radius_o are each a single float or int
+ - 'rectangle', 'square', 'rectangular': 4-tuple, (xmin,xmax,ymin,ymax)
+ - `mask`: flexible detector, any boolean or floating point 2D array with
+ the same shape as datacube.Qshape
+ centered (bool) : if False (default), the origin is in the upper left corner.
+ If True, the mean measured origin in the datacube calibrations
+ is set as center. The measured origin is set with datacube.calibration.set_origin()
+ In this case, for example, a centered bright field image could be defined
+ by geometry = ((0,0), R). For `mode="mask"`, has no effect.
+ calibrated (bool) : if True, geometry is specified in units of 'A^-1' instead of pixels.
+ The datacube's calibrations must have its `"Q_pixel_units"` parameter set to "A^-1".
+ For `mode="mask"`, has no effect.
+ shift_center (bool) : if True, the mask is shifted at each real space position to
+ account for any shifting of the origin of the diffraction images. The datacube's
+ calibration['origin'] parameter must be set. The shift applied to each pattern is
+ the difference between the local origin position and the mean origin position
+ over all patterns, rounded to the nearest integer for speed.
+ dask (bool) : if True, use dask arrays
+
+ Returns:
+ A new VirtualImage instance
+ """
+ # initialize as a RealSlice
+ RealSlice.__init__(
+ self,
+ data=data,
+ name=name,
+ )
+
+ # Set metadata
+ md = Metadata(name="virtualimage")
+ md["mode"] = mode
+ md["geometry"] = geometry
+ md["centered"] = centered
+ md["calibrated"] = calibrated
+ md["shift_center"] = shift_center
+ md["dask"] = dask
+ self.metadata = md
+
+ # HDF5 read/write
+
+ # write inherited from Array
+
+ # read
+ def from_h5(group):
+ from py4DSTEM.io.legacy.legacy13.v13_py4dstem_classes.io import (
+ VirtualImage_from_h5,
+ )
+
+ return VirtualImage_from_h5(group)
+
+
+############ END OF CLASS ###########
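+
+
+# A minimal usage sketch (illustration only): assuming `dc` is a DataCube, a
+# bright-field style virtual image could be computed and wrapped in this class.
+# The circular geometry below -- radius 10 px centered at (32, 32) -- is
+# hypothetical.
+#
+#     >>> import numpy as np
+#     >>> (cx, cy), r = (32, 32), 10
+#     >>> qx, qy = np.ogrid[: dc.data.shape[2], : dc.data.shape[3]]
+#     >>> mask = (qx - cx) ** 2 + (qy - cy) ** 2 < r**2
+#     >>> im = np.sum(dc.data * mask, axis=(2, 3))
+#     >>> vi = VirtualImage(data=im, mode="circle", geometry=((cx, cy), r))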
diff --git a/py4DSTEM/io/legacy/legacy13/v13_to_14.py b/py4DSTEM/io/legacy/legacy13/v13_to_14.py
new file mode 100644
index 000000000..650529b22
--- /dev/null
+++ b/py4DSTEM/io/legacy/legacy13/v13_to_14.py
@@ -0,0 +1,221 @@
+# Convert v13 to v14 classes
+
+import numpy as np
+from emdfile import tqdmnd
+
+
+# v13 imports
+
+from py4DSTEM.io.legacy.legacy13.v13_emd_classes import (
+ Root as Root13,
+ Metadata as Metadata13,
+ Array as Array13,
+ PointList as PointList13,
+ PointListArray as PointListArray13,
+)
+from py4DSTEM.io.legacy.legacy13.v13_py4dstem_classes import (
+ Calibration as Calibration13,
+ DataCube as DataCube13,
+ DiffractionSlice as DiffractionSlice13,
+ VirtualDiffraction as VirtualDiffraction13,
+ RealSlice as RealSlice13,
+ VirtualImage as VirtualImage13,
+ Probe as Probe13,
+ QPoints as QPoints13,
+ BraggVectors as BraggVectors13,
+)
+
+
+# v14 imports
+
+from emdfile import Root, Metadata, Array, PointList, PointListArray
+
+from py4DSTEM.data import (
+ Calibration,
+ DiffractionSlice,
+ RealSlice,
+ QPoints,
+)
+from py4DSTEM.datacube import (
+ DataCube,
+ VirtualImage,
+ VirtualDiffraction,
+)
+
+
+def v13_to_14(v13tree, v13cal):
+ """
+ Converts a v13 data tree to a v14 data tree
+ """
+ # if a list of root names was returned, pass it through
+ if isinstance(v13tree, list):
+ return v13tree
+
+ # convert the selected node
+ node = _v13_to_14_cls(v13tree)
+
+ # handle the root
+ if isinstance(node, Root):
+ root = node
+ elif node.root is None:
+ root = Root(name=node.name)
+ root.tree(node)
+ else:
+ root = node.root
+
+ # populate tree
+ _populate_tree(v13tree, node, root)
+
+ # add calibration
+ if v13cal is not None:
+ cal = _v13_to_14_cls(v13cal)
+ root.metadata = cal
+
+ # return
+ return node
+
+
+def _populate_tree(node13, node14, root14):
+ for key in node13.tree.keys():
+ newnode13 = node13.tree[key]
+ newnode14 = _v13_to_14_cls(newnode13)
+ # skip calibrations and metadata
+ if isinstance(newnode14, Metadata):
+ pass
+ else:
+ node14.tree(newnode14, force=True)
+ _populate_tree(newnode13, newnode14, root14)
+
+
+def _v13_to_14_cls(obj):
+ """
+ Convert a single version 13 object instance to the equivalent version 14 object,
+ including metadata.
+ """
+
+ assert isinstance(
+ obj,
+ (
+ Root13,
+ Metadata13,
+ Array13,
+ PointList13,
+ PointListArray13,
+ Calibration13,
+ DataCube13,
+ DiffractionSlice13,
+ VirtualDiffraction13,
+ RealSlice13,
+ VirtualImage13,
+ Probe13,
+ QPoints13,
+ BraggVectors13,
+ ),
+ ), f"obj must be a v13 class instance, not type {type(obj)}"
+
+ if isinstance(obj, Root13):
+ x = Root(name=obj.name)
+
+ elif isinstance(obj, Calibration13):
+ x = Calibration(name=obj.name)
+ x._params.update(obj._params)
+
+ elif isinstance(obj, DataCube13):
+ x = DataCube(name=obj.name, data=obj.data, slicelabels=obj.slicelabels)
+
+ elif isinstance(obj, DiffractionSlice13):
+ if obj.is_stack:
+ data = np.rollaxis(obj.data, axis=2)
+ else:
+ data = obj.data
+ x = DiffractionSlice(
+ name=obj.name, data=data, units=obj.units, slicelabels=obj.slicelabels
+ )
+
+ elif isinstance(obj, VirtualDiffraction13):
+ x = VirtualDiffraction(name=obj.name, data=obj.data)
+
+ elif isinstance(obj, RealSlice13):
+ if obj.is_stack:
+ data = np.rollaxis(obj.data, axis=2)
+ else:
+ data = obj.data
+ x = RealSlice(
+ name=obj.name, data=data, units=obj.units, slicelabels=obj.slicelabels
+ )
+
+ elif isinstance(obj, VirtualImage13):
+ x = VirtualImage(name=obj.name, data=obj.data)
+
+ elif isinstance(obj, Probe13):
+ from py4DSTEM.braggvectors import Probe
+
+ x = Probe(name=obj.name, data=obj.data)
+
+ elif isinstance(obj, QPoints13):
+ x = PointList(name=obj.name, data=obj.data)
+
+ elif isinstance(obj, BraggVectors13):
+ from py4DSTEM.braggvectors import BraggVectors
+
+ x = BraggVectors(name=obj.name, Rshape=obj.Rshape, Qshape=obj.Qshape)
+ x._v_uncal = obj._v_uncal
+ if hasattr(obj, "_v_cal"):
+ x._v_cal = obj._v_cal
+
+ elif isinstance(obj, Metadata13):
+ x = Metadata(name=obj.name)
+ x._params.update(obj._params)
+
+ elif isinstance(obj, Array13):
+ # prepare arguments
+ if obj.is_stack:
+ data = np.rollaxis(obj.data, axis=2)
+ else:
+ data = obj.data
+ args = {"name": obj.name, "data": data}
+ if hasattr(obj, "units"):
+ args["units"] = obj.units
+ if hasattr(obj, "dim_names"):
+ args["dim_names"] = obj.dim_names
+ if hasattr(obj, "dim_units"):
+ args["dim_units"] = obj.dim_units
+ if hasattr(obj, "slicelabels"):
+ args["slicelabels"] = obj.slicelabels
+ if hasattr(obj, "dims"):
+ dims = []
+ for dim in obj.dims:
+ dims.append(dim)
+ args["dims"] = dims
+
+ # get the array
+ x = Array(**args)
+
+ elif isinstance(obj, PointList13):
+ x = PointList(name=obj.name, data=obj.data)
+
+ elif isinstance(obj, PointListArray13):
+ x = PointListArray(name=obj.name, dtype=obj.dtype, shape=obj.shape)
+ for idx, jdx in tqdmnd(
+ x.shape[0],
+ x.shape[1],
+ desc="transferring PointListArray v13->14",
+ unit="foolishness",
+ ):
+ x[idx, jdx] = obj[idx, jdx]
+
+ else:
+ raise Exception(f"Unexpected object type {type(obj)}")
+
+ # Handle metadata
+ if hasattr(obj, "metadata"):
+ for key in obj.metadata.keys():
+ md = obj.metadata[key]
+ dm = Metadata(name=md.name)
+ dm._params.update(md._params)
+ x.metadata = dm
+
+ # Return
+ return x
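+
+
+# Illustrative usage (a sketch; `tree13` and `cal13` are hypothetical names for
+# a v13 data tree and its Calibration, as returned by the v13 readers):
+#
+#     >>> node14 = v13_to_14(tree13, cal13)   # returns the converted v14 node
+#     >>> node14.root                         # carries the converted Calibration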
diff --git a/py4DSTEM/io/legacy/read_legacy_12.py b/py4DSTEM/io/legacy/read_legacy_12.py
new file mode 100644
index 000000000..40bfcfc94
--- /dev/null
+++ b/py4DSTEM/io/legacy/read_legacy_12.py
@@ -0,0 +1,93 @@
+# File reader for py4DSTEM files
+
+import h5py
+import numpy as np
+from os.path import splitext, exists
+from py4DSTEM.io.legacy.read_utils import is_py4DSTEM_file, get_py4DSTEM_topgroups
+from py4DSTEM.io.legacy.read_utils import get_py4DSTEM_version, version_is_geq
+from py4DSTEM.io.legacy.legacy12 import (
+ read_v0_12,
+ read_v0_9,
+ read_v0_7,
+ read_v0_6,
+ read_v0_5,
+)
+
+
+def read_legacy12(filepath, **kwargs):
+ """
+ File reader for older legacy py4DSTEM (v<0.13) formatted HDF5 files.
+
+ Precise behavior is determined by which arguments are passed -- see below.
+
+ Args:
+ filepath (str or pathlib.Path): When passed a filepath only, this function checks
+ if the path points to a valid py4DSTEM file, then prints its contents to
+ screen.
+ data_id (int/str/list, optional): Specifies which data to load. Use integers to
+ specify the data index, or strings to specify data names. A list or tuple
+ returns a list of DataObjects. Returns the specified data.
+ topgroup (str, optional): Strictly, a py4DSTEM file is considered to be everything
+ inside a toplevel subdirectory within the HDF5 file, so that if desired one
+ can place many py4DSTEM files inside a single H5. In this case, when loading
+ data, the topgroup argument is passed to indicate which py4DSTEM file to
+ load. If an H5 containing multiple py4DSTEM files is passed without a
+ topgroup specified, the topgroup names are printed to screen.
+ mem (str, optional): Only used if a single DataCube is loaded. In this case,
+ mem specifies how the data should be stored; must be "RAM" or "MEMMAP". See
+ docstring for py4DSTEM.file.io.read. Default is "RAM".
+ binfactor (int, optional): Only used if a single DataCube is loaded. In this
+ case, a binfactor of > 1 causes the data to be binned by this amount as it's
+ loaded.
+ dtype (dtype, optional): Used when binning data, ignored otherwise. Defaults to
+ whatever the type of the raw data is, to avoid enlarging data size. May be
+ useful to avoid 'wraparound' errors.
+
+ Returns:
+ (variable): The output depends on usage:
+
+ * If no input arguments with return values (i.e. data_id or metadata) are
+ passed, nothing is returned.
+ * Otherwise, a single DataObject or list of DataObjects are returned, based
+ on the value of the argument data_id.
+ """
+ assert exists(filepath), "Error: specified filepath does not exist"
+ assert is_py4DSTEM_file(
+ filepath
+ ), "Error: {} isn't recognized as a py4DSTEM file.".format(filepath)
+
+ # For HDF5 files containing multiple valid EMD type 2 files (i.e. py4DSTEM files),
+ # disambiguate desired data
+ tgs = get_py4DSTEM_topgroups(filepath)
+ if "topgroup" in kwargs.keys():
+ tg = kwargs["topgroup"]
+ # assert(tg in tgs), "Error: specified topgroup, {}, not found.".format(tg)
+ else:
+ if len(tgs) == 1:
+ tg = tgs[0]
+ else:
+ print("Multiple topgroups were found -- please specify one:")
+ print("")
+ for tg in tgs:
+ print(tg)
+ return
+
+ # Get py4DSTEM version and call the appropriate read function
+ version = get_py4DSTEM_version(filepath, tg)
+ if version_is_geq(version, (0, 12, 0)):
+ return read_v0_12(filepath, **kwargs)
+ elif version_is_geq(version, (0, 9, 0)):
+ return read_v0_9(filepath, **kwargs)
+ elif version_is_geq(version, (0, 7, 0)):
+ return read_v0_7(filepath, **kwargs)
+ elif version_is_geq(version, (0, 6, 0)):
+ return read_v0_6(filepath, **kwargs)
+ elif version_is_geq(version, (0, 5, 0)):
+ return read_v0_5(filepath, **kwargs)
+ else:
+ raise Exception(
+ "Support for legacy v{}.{}.{} files is no longer available.".format(
+ version[0], version[1], version[2]
+ )
+ )
diff --git a/py4DSTEM/io/legacy/read_legacy_13.py b/py4DSTEM/io/legacy/read_legacy_13.py
new file mode 100644
index 000000000..04da1e65a
--- /dev/null
+++ b/py4DSTEM/io/legacy/read_legacy_13.py
@@ -0,0 +1,255 @@
+# File reader for py4DSTEM v13 files
+
+import h5py
+import numpy as np
+import warnings
+from os.path import exists, basename, dirname, join
+from typing import Optional, Union
+
+from py4DSTEM.io.legacy.read_utils import is_py4DSTEM_version13
+from py4DSTEM.io.legacy.legacy13 import (
+ Calibration,
+ DataCube,
+ DiffractionSlice,
+ VirtualDiffraction,
+ RealSlice,
+ VirtualImage,
+ Probe,
+ QPoints,
+ BraggVectors,
+)
+from py4DSTEM.io.legacy.legacy13 import Root, Metadata, Array, PointList, PointListArray
+from py4DSTEM.io.legacy.legacy13 import v13_to_14
+
+
+def read_legacy13(
+ filepath,
+ root: Optional[str] = None,
+ tree: Optional[Union[bool, str]] = True,
+):
+ """
+ File reader for legacy py4DSTEM (v=0.13.x) formatted HDF5 files.
+
+ Args:
+ filepath (str or Path): the file path
+ root (str): the path to the data group in the HDF5 file
+ to read from. To examine an HDF5 file written by py4DSTEM
+ in order to determine this path, call
+ `py4DSTEM.print_h5_tree(filepath)`. If left unspecified,
+ looks in the file and if it finds a single top-level
+ object, loads it. If it finds multiple top-level objects,
+ prints a warning and returns a list of root paths to the
+ top-level object found.
+ tree (bool or str): indicates what data should be loaded,
+ relative to the root group specified above. Must be
+ `True`, `False`, or `'noroot'`. If set to `False`, only
+ the data in the root group is loaded, plus any
+ associated calibrations. If set to `True`, loads the root
+ group, and all other data groups nested underneath it
+ in the file tree. If set to `'noroot'`, loads all other
+ data groups nested under the root group in the file tree,
+ but does *not* load the data inside the root group (allowing,
+ e.g., loading all the data nested under a DataCube13 without
+ loading the whole datacube).
+ Returns:
+ (the data)
+ """
+ # Check that filepath is valid
+ assert exists(filepath), "Error: specified filepath does not exist"
+ assert is_py4DSTEM_version13(
+ filepath
+ ), f"Error: {filepath} isn't recognized as a v13 py4DSTEM file."
+
+ if root is None:
+ # check if there is a single object in the file
+ # if so, set root to that object's path; otherwise raise an Exception or issue a Warning
+
+ with h5py.File(filepath, "r") as f:
+ l1keys = list(f.keys())
+ if len(l1keys) == 0:
+ raise Exception("No top level groups found in this HDF5 file!")
+ elif len(l1keys) > 1:
+ warnings.warn(
+ "Multiple top level groups found; please specify. Returning group names."
+ )
+ return l1keys
+ else:
+ l2keys = list(f[l1keys[0]].keys())
+ if len(l2keys) == 0:
+ raise Exception("No top level data blocks found in this HDF5 file!")
+ elif len(l2keys) > 1:
+ warnings.warn(
+ "Multiple top level data blocks found; please specify. Returning h5 paths to top level data blocks."
+ )
+ return [join(l1keys[0], k) for k in l2keys]
+ else:
+ root = join(l1keys[0], l2keys[0])
+ # this is a windows fix
+ root = root.replace("\\", "/")
+
+ # Open file
+ with h5py.File(filepath, "r") as f:
+ # open the selected group
+ try:
+ group_data = f[root]
+ except KeyError:
+ raise Exception(
+ f"the provided root {root} is not a valid path to a recognized data group"
+ )
+
+ # Read data
+ if tree is True:
+ data = _read_with_tree(group_data)
+
+ elif tree is False:
+ data = _read_without_tree(group_data)
+
+ elif tree == "noroot":
+ data = _read_without_root(group_data)
+
+ else:
+ raise Exception(f"Unexpected value {tree} for `tree`")
+
+ # Read calibration
+ cal = _read_calibration(group_data)
+
+ # convert version 13 -> 14
+ data = v13_to_14(data, cal)
+ return data
+
+
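+# Illustrative calls (a sketch; the file paths are hypothetical):
+#
+#     >>> data = read_legacy13("file_v13.h5")               # load the whole tree
+#     >>> node = read_legacy13("file_v13.h5", tree=False)   # root node + calibration only
+#     >>> paths = read_legacy13("several_roots.h5")         # multiple roots: returns their paths
+
+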
+# utilities
+
+
+def _read_without_tree(grp):
+ # handle empty datasets
+ if grp.attrs["emd_group_type"] == "root":
+ data = Root(
+ name=basename(grp.name),
+ )
+ return data
+
+ # read data as v13 objects
+ __class__ = _get_v13_class(grp)
+ data = __class__.from_h5(grp)
+
+ return data
+
+
+def _read_with_tree(grp):
+ data = _read_without_tree(grp)
+ _populate_tree(data.tree, grp)
+ return data
+
+
+def _read_without_root(grp):
+ root = Root()
+ _populate_tree(root.tree, grp)
+ return root
+
+
+def _read_calibration(grp):
+ keys = [k for k in grp.keys() if isinstance(grp[k], h5py.Group)]
+ keys = [k for k in keys if (_get_v13_class(grp[k]) == Calibration)]
+ if len(keys) > 0:
+ k = keys[0]
+ cal = Calibration.from_h5(grp[k])
+ return cal
+ else:
+ name = dirname(grp.name)
+ if name != "/":
+ grp_upstream = grp.file[dirname(grp.name)]
+ return _read_calibration(grp_upstream)
+ else:
+ return None
+
+
+def _populate_tree(tree, grp):
+ keys = [k for k in grp.keys() if isinstance(grp[k], h5py.Group)]
+ keys = [
+ k for k in keys if (k[0] != "_" and not _get_v13_class(grp[k]) == Calibration)
+ ]
+ for key in keys:
+ tree[key] = _read_without_tree(grp[key])
+ _populate_tree(tree[key].tree, grp[key])
+
+
+def print_v13h5_tree(filepath, show_metadata=False):
+ """
+ Prints the contents of an h5 file from a filepath.
+ """
+
+ with h5py.File(filepath, "r") as f:
+ print("/")
+ print_v13h5pyFile_tree(f, show_metadata=show_metadata)
+ print("\n")
+
+
+def print_v13h5pyFile_tree(f, tablevel=0, linelevels=[], show_metadata=False):
+ """
+ Prints the contents of an h5 file from an open h5py File instance.
+ """
+ if tablevel not in linelevels:
+ linelevels.append(tablevel)
+ keys = [k for k in f.keys() if isinstance(f[k], h5py.Group)]
+ if not show_metadata:
+ keys = [k for k in keys if k != "_metadata"]
+ N = len(keys)
+ for i, k in enumerate(keys):
+ string = ""
+ string += "|" if 0 in linelevels else ""
+ for idx in range(tablevel):
+ l = "|" if idx + 1 in linelevels else ""
+ string += "\t" + l
+ print(string + "--" + k)
+ if i == N - 1:
+ linelevels.remove(tablevel)
+ print_v13h5pyFile_tree(
+ f[k],
+ tablevel=tablevel + 1,
+ linelevels=linelevels,
+ show_metadata=show_metadata,
+ )
+
+
+
+def _get_v13_class(grp):
+ lookup = {
+ "Metadata": Metadata,
+ "Array": Array,
+ "PointList": PointList,
+ "PointListArray": PointListArray,
+ "Calibration": Calibration,
+ "DataCube": DataCube,
+ "DiffractionSlice": DiffractionSlice,
+ "VirtualDiffraction": VirtualDiffraction,
+ "DiffractionImage": VirtualDiffraction,
+ "RealSlice": RealSlice,
+ "VirtualImage": VirtualImage,
+ "Probe": Probe,
+ "QPoints": QPoints,
+ "BraggVectors": BraggVectors,
+ }
+
+ if "py4dstem_class" in grp.attrs:
+ classname = grp.attrs["py4dstem_class"]
+ elif "emd_group_type" in grp.attrs:
+ emd_group_type = grp.attrs["emd_group_type"]
+ # map EMD group type codes to class names, so the lookup below succeeds
+ classname = {
+ "root": "root",
+ 0: "Metadata",
+ 1: "Array",
+ 2: "PointList",
+ 3: "PointListArray",
+ }[emd_group_type]
+ else:
+ warnings.warn(f"Can't determine class type of H5 group {grp}; skipping...")
+ return None
+ try:
+ __class__ = lookup[classname]
+ return __class__
+ except KeyError:
+ warnings.warn(f"Can't determine class type of H5 group {grp}; skipping...")
+ return None
diff --git a/py4DSTEM/io/legacy/read_utils.py b/py4DSTEM/io/legacy/read_utils.py
new file mode 100644
index 000000000..7cd48cde7
--- /dev/null
+++ b/py4DSTEM/io/legacy/read_utils.py
@@ -0,0 +1,106 @@
+# Utility functions
+
+import h5py
+import numpy as np
+
+
+def get_py4DSTEM_topgroups(filepath):
+ """Returns a list of toplevel groups in an HDF5 file which are valid py4DSTEM file trees."""
+ topgroups = []
+ with h5py.File(filepath, "r") as f:
+ for key in f.keys():
+ if "emd_group_type" in f[key].attrs:
+ topgroups.append(key)
+ return topgroups
+
+
+def is_py4DSTEM_version13(filepath):
+ """Returns True for data written by a py4DSTEM v0.13.x release."""
+ with h5py.File(filepath, "r") as f:
+ for k in f.keys():
+ if "emd_group_type" in f[k].attrs:
+ if f[k].attrs["emd_group_type"] == "root":
+ if all(
+ [x in f[k].attrs for x in ("version_major", "version_minor")]
+ ):
+ if (
+ int(f[k].attrs["version_major"]),
+ int(f[k].attrs["version_minor"]),
+ ) == (0, 13):
+ return True
+ return False
+
+
+def is_py4DSTEM_file(filepath):
+ """Returns True iff filepath points to a py4DSTEM formatted (EMD type 2) file."""
+ if is_py4DSTEM_version13(filepath):
+ return True
+ else:
+ try:
+ topgroups = get_py4DSTEM_topgroups(filepath)
+ if len(topgroups) > 0:
+ return True
+ else:
+ return False
+ except OSError:
+ return False
+
+
+def get_py4DSTEM_version(filepath, topgroup="4DSTEM_experiment"):
+ """Returns the version (major,minor,release) of a py4DSTEM file."""
+ assert is_py4DSTEM_file(filepath), "Error: not recognized as a py4DSTEM file"
+ with h5py.File(filepath, "r") as f:
+ version_major = int(f[topgroup].attrs["version_major"])
+ version_minor = int(f[topgroup].attrs["version_minor"])
+ if "version_release" in f[topgroup].attrs.keys():
+ version_release = int(f[topgroup].attrs["version_release"])
+ else:
+ version_release = 0
+ return version_major, version_minor, version_release
+
+
+def get_UUID(filepath, topgroup="4DSTEM_experiment"):
+ """Returns the UUID of a py4DSTEM file, or if unavailable returns -1."""
+ assert is_py4DSTEM_file(filepath), "Error: not recognized as a py4DSTEM file"
+ with h5py.File(filepath, "r") as f:
+ if topgroup in f.keys():
+ if "UUID" in f[topgroup].attrs:
+ return f[topgroup].attrs["UUID"]
+ return -1
+
+
+def version_is_geq(current, minimum):
+ """Returns True iff current version (major,minor,release) is greater than or equal to minimum." """
+ if current[0] > minimum[0]:
+ return True
+ elif current[0] == minimum[0]:
+ if current[1] > minimum[1]:
+ return True
+ elif current[1] == minimum[1]:
+ if current[2] >= minimum[2]:
+ return True
+ else:
+ return False
+ else:
+ return False
+
+
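+# For example (a quick check of the comparison logic):
+#
+#     >>> version_is_geq((0, 12, 3), (0, 12, 0))
+#     True
+#     >>> version_is_geq((0, 9, 0), (0, 12, 0))
+#     False
+
+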
+def get_N_dataobjects(filepath, topgroup="4DSTEM_experiment"):
+ """Returns a 7-tuple of ints with the numbers of: DataCubes, CountedDataCubes,
+ DiffractionSlices, RealSlices, PointLists, PointListArrays, total DataObjects.
+ """
+ assert is_py4DSTEM_file(filepath), "Error: not recognized as a py4DSTEM file"
+ with h5py.File(filepath, "r") as f:
+ assert topgroup in f.keys(), "Error: unrecognized topgroup"
+ N_dc = len(f[topgroup]["data/datacubes"].keys())
+ N_cdc = len(f[topgroup]["data/counted_datacubes"].keys())
+ N_ds = len(f[topgroup]["data/diffractionslices"].keys())
+ N_rs = len(f[topgroup]["data/realslices"].keys())
+ N_pl = len(f[topgroup]["data/pointlists"].keys())
+ N_pla = len(f[topgroup]["data/pointlistarrays"].keys())
+ try:
+ N_coords = len(f[topgroup]["data/coordinates"].keys())
+ except KeyError:
+ N_coords = 0
+ N_do = N_dc + N_cdc + N_ds + N_rs + N_pl + N_pla + N_coords
+ return N_dc, N_cdc, N_ds, N_rs, N_pl, N_pla, N_coords, N_do
diff --git a/py4DSTEM/io/parsefiletype.py b/py4DSTEM/io/parsefiletype.py
new file mode 100644
index 000000000..84b53e4dc
--- /dev/null
+++ b/py4DSTEM/io/parsefiletype.py
@@ -0,0 +1,112 @@
+# File parser utility
+
+from os.path import splitext
+import py4DSTEM.io.legacy as legacy
+import emdfile as emd
+import h5py
+
+
+
+def _parse_filetype(fp):
+ """
+ Accepts a path to a data file, and returns the file type as a string.
+ """
+ _, fext = splitext(fp)
+ fext = fext.lower()
+ if fext in [
+ ".h5",
+ ".hdf5",
+ ".py4dstem",
+ ".emd",
+ ]:
+ if emd._is_EMD_file(fp):
+ return "emd"
+
+ elif legacy.is_py4DSTEM_file(fp):
+ return "legacy"
+
+ elif _is_arina(fp):
+ return "arina"
+
+ elif _is_abTEM(fp):
+ return "abTEM"
+ else:
+ raise Exception("not supported `h5` data type")
+
+ elif fext in [
+ ".dm",
+ ".dm3",
+ ".dm4",
+ ]:
+ return "dm"
+ elif fext in [".raw"]:
+ return "empad"
+ elif fext in [".mrc"]:
+ return "mrc_relativity"
+ elif fext in [".gtg", ".bin"]:
+ return "gatan_K2_bin"
+ elif fext in [".kitware_counted"]:
+ return "kitware_counted"
+ elif fext in [".mib", ".MIB"]:
+ return "mib"
+ else:
+ raise Exception(f"Unrecognized file extension {fext}.")
+
+
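+# Illustrative dispatch (a sketch; the file names are hypothetical):
+#
+#     >>> _parse_filetype("scan.dm4")
+#     'dm'
+#     >>> _parse_filetype("counted.mib")
+#     'mib'
+#     >>> _parse_filetype("data.emd")  # inspects the HDF5 contents to decide
+
+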
+def _is_arina(filepath):
+ """
+ Check if an h5 file is an Arina file.
+ """
+ with h5py.File(filepath, "r") as f:
+ try:
+ assert "entry" in f.keys()
+ except AssertionError:
+ return False
+ try:
+ assert "NX_class" in f["entry"].attrs.keys()
+ except AssertionError:
+ return False
+ return True
+
+
+def _is_abTEM(filepath):
+ """
+ Check if an h5 file is an abTEM file.
+ """
+ with h5py.File(filepath, "r") as f:
+ try:
+ assert "array" in f.keys()
+ except AssertionError:
+ return False
+ return True
diff --git a/py4DSTEM/io/read.py b/py4DSTEM/io/read.py
new file mode 100644
index 000000000..6dc4ce37e
--- /dev/null
+++ b/py4DSTEM/io/read.py
@@ -0,0 +1,194 @@
+# Reader for native files
+
+import warnings
+from os.path import exists
+from pathlib import Path
+from typing import Optional, Union
+
+import emdfile as emd
+import py4DSTEM.io.legacy as legacy
+from py4DSTEM.data import Data
+from py4DSTEM.io.parsefiletype import _parse_filetype
+
+
+def read(
+ filepath: Union[str, Path],
+ datapath: Optional[str] = None,
+ tree: Optional[Union[bool, str]] = True,
+ verbose: Optional[bool] = False,
+ **kwargs,
+):
+ """
+ A file reader for native py4DSTEM / EMD files. To read non-native
+ formats, use `py4DSTEM.import_file`.
+
+ For files written by py4DSTEM version 0.14+, the function arguments
+ are those listed here - filepath, datapath, and tree. See below for
+ descriptions.
+
+ Files written by py4DSTEM v0.14+ are EMD 1.0 files, an HDF5 based
+ format. For a description and complete file specification, see
+ https://emdatasets.com/format/. For the Python implementation of
+ EMD 1.0 read-write routines which py4DSTEM is built on top of, see
+ https://github.com/py4dstem/emdfile.
+
+ To read files written by older versions of py4DSTEM, different keyword
+ arguments should be passed. See the docstring for
+ `py4DSTEM.io.native.legacy.read_py4DSTEM_legacy` for a complete list.
+ For example, `data_id` may need to be specified to select a dataset.
+
+ Args:
+ filepath (str or Path): the file path
+ datapath (str or None): the path within the H5 file to the data
+ group to read from. If there is a single EMD data tree in the
+ file, `datapath` may be left as None, and the path will
+ be set to the root node of that tree. If `datapath` is None
+ and there are multiple EMD trees, this function will issue a
+ warning and return a list of paths to the root nodes of all
+ EMD trees it finds. Otherwise, should be a '/' delimited path
+ to the data node of interest, for example passing
+ 'rootnode/somedata/someotherdata' will set the node called
+ 'someotherdata' as the point to read from. To print the tree
+ of data nodes present in a file to the screen, use
+ `py4DSTEM.print_h5_tree(filepath)`.
+ tree (True or False or 'noroot'): indicates what data should be loaded,
+ relative to the target data group specified with `datapath`.
+ Enables reading the target data node only if `tree` is False,
+ reading the target node as well as recursively reading the tree
+ of data underneath it if `tree` is True, or recursively reading
+ the tree of data underneath the target node but excluding the
+ target node itself if `tree` is set to 'noroot'.
+ Returns:
+ (the data)
+ """
+
+ # parse filetype
+ er1 = f"filepath must be a string or Path, not {type(filepath)}"
+ er2 = f"specified filepath '{filepath}' does not exist"
+ assert isinstance(filepath, (str, Path)), er1
+ assert exists(filepath), er2
+
+ filetype = _parse_filetype(filepath)
+ assert filetype in (
+ "emd",
+ "legacy",
+ ), f"`py4DSTEM.read` loads native HDF5 formatted files, but a file of type {filetype} was detected. Try loading it using py4DSTEM.import_file"
+
+ # support older `root` input
+ if datapath is None:
+ if "root" in kwargs:
+ datapath = kwargs["root"]
+
+ # EMD 1.0 formatted files (py4DSTEM v0.14+)
+ if filetype == "emd":
+ # check version
+ version = emd._get_EMD_version(filepath)
+ if verbose:
+ print(
+ f"EMD version {version[0]}.{version[1]}.{version[2]} detected. Reading..."
+ )
+ assert emd._version_is_geq(
+ version, (1, 0, 0)
+ ), f"EMD version {version} detected. Expected version >= 1.0.0"
+
+ # read
+ data = emd.read(filepath, emdpath=datapath, tree=tree)
+ if verbose:
+ print("Data was read from file. Adding calibration links...")
+
+ # add calibration links
+ if isinstance(data, Data):
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ cal = data.calibration
+ elif isinstance(data, emd.Root):
+ try:
+ cal = data.metadata["calibration"]
+ except KeyError:
+ cal = None
+ else:
+ cal = None
+ if cal is not None:
+ try:
+ root_treepath = cal["_root_treepath"]
+ target_paths = cal["_target_paths"]
+ del cal._params["_target_paths"]
+ for p in target_paths:
+ try:
+ p = p.replace(root_treepath, "")
+ d = data.root.tree(p)
+ cal.register_target(d)
+ if hasattr(d, "setcal"):
+ d.setcal()
+ except AssertionError:
+ pass
+ except KeyError:
+ pass
+ cal.calibrate()
+
+ # return
+ if verbose:
+ print("Done.")
+ return data
+
+ # legacy py4DSTEM files (v <= 0.13)
+ else:
+ assert (
+ filetype == "legacy"
+ ), "path points to an H5 file which is neither an EMD 1.0+ file, nor a recognized legacy py4DSTEM file."
+
+ # read v13
+ if legacy.is_py4DSTEM_version13(filepath):
+ # load the data
+ if verbose:
+ print("Legacy py4DSTEM version 13 file detected. Reading...")
+ kwargs["root"] = datapath
+ kwargs["tree"] = tree
+ data = legacy.read_legacy13(
+ filepath=filepath,
+ **kwargs,
+ )
+ if verbose:
+ print("Done.")
+ return data
+
+ # read <= v12
+ else:
+ # parse the root/data_id from the datapath arg
+ if datapath is not None:
+ datapath = datapath.split("/")
+ try:
+ datapath.remove("")
+ except ValueError:
+ pass
+ rootgroup = datapath[0]
+ if len(datapath) > 1:
+ datapath = "/".join(rootgroup[1:])
+ else:
+ datapath = None
+ else:
+ rootgroups = legacy.get_py4DSTEM_topgroups(filepath)
+ if len(rootgroups) > 1:
+ print(
+ "multiple root groups in a legacy file found - returning list of root names; please pass one as `datapath`"
+ )
+ return rootgroups
+ elif len(rootgroups) == 0:
+ raise Exception("No rootgroups found")
+ else:
+ rootgroup = rootgroups[0]
+ datapath = None
+
+ # load the data
+ if verbose:
+ print("Legacy py4DSTEM version <= 12 file detected. Reading...")
+ kwargs["topgroup"] = rootgroup
+ if datapath is not None:
+ kwargs["data_id"] = datapath
+ data = legacy.read_legacy12(
+ filepath=filepath,
+ **kwargs,
+ )
+ if verbose:
+ print("Done.")
+ return data
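+
+
+# Illustrative calls (a sketch; file paths and data paths are hypothetical):
+#
+#     >>> d = read("data.h5")                                        # single-tree EMD file
+#     >>> d = read("data.h5", datapath="root/datacube", tree=False)  # one node only
+#     >>> d = read("old_v12.h5", data_id="datacube_0")               # legacy (<= v12) file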
diff --git a/py4DSTEM/io/save.py b/py4DSTEM/io/save.py
new file mode 100644
index 000000000..148f246d5
--- /dev/null
+++ b/py4DSTEM/io/save.py
@@ -0,0 +1,26 @@
+from emdfile import save as _save
+import warnings
+
+
+def save(filepath, data, mode="w", emdpath=None, tree=True):
+ """
+ Saves data to an EMD 1.0 formatted HDF5 file at filepath.
+
+ For the full docstring, see py4DSTEM.emdfile.save.
+ """
+ # This function wraps emdfile's save and adds a small piece
+ # of metadata to the calibration to allow linking to calibrated
+ # data items on read
+
+ cal = None
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ if hasattr(data, "calibration") and data.calibration is not None:
+ cal = data.calibration
+ rp = "/".join(data._treepath.split("/")[:-1])
+ cal["_root_treepath"] = rp
+
+ _save(filepath, data=data, mode=mode, emdpath=emdpath, tree=tree)
+
+ if cal is not None:
+ del cal._params["_root_treepath"]
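+
+
+# Illustrative round trip (a sketch; `datacube` and the path are hypothetical):
+#
+#     >>> from py4DSTEM.io.read import read
+#     >>> save("out.h5", datacube)  # writes EMD 1.0 and records the calibration link
+#     >>> dc2 = read("out.h5")      # read() restores the calibration links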
diff --git a/py4DSTEM/preprocess/__init__.py b/py4DSTEM/preprocess/__init__.py
new file mode 100644
index 000000000..4ba0bdee0
--- /dev/null
+++ b/py4DSTEM/preprocess/__init__.py
@@ -0,0 +1,5 @@
+from py4DSTEM.preprocess.utils import *
+from py4DSTEM.preprocess.preprocess import *
+from py4DSTEM.preprocess.darkreference import *
+from py4DSTEM.preprocess.electroncount import *
+from py4DSTEM.preprocess.radialbkgrd import *
diff --git a/py4DSTEM/preprocess/darkreference.py b/py4DSTEM/preprocess/darkreference.py
new file mode 100644
index 000000000..a23d4271b
--- /dev/null
+++ b/py4DSTEM/preprocess/darkreference.py
@@ -0,0 +1,200 @@
+# Functions for background fitting and subtraction.
+
+import numpy as np
+
+#### Subtract darkreference from datacube frame at (Rx,Ry) ####
+
+
+def get_bksbtr_DP(datacube, darkref, Rx, Ry):
+ """
+ Returns a background subtracted diffraction pattern.
+
+ Args:
+ datacube (DataCube): data to background subtract
+ darkref (ndarray): dark reference. must have shape (datacube.Q_Nx, datacube.Q_Ny)
+ Rx,Ry (int): the scan position of the diffraction pattern of interest
+
+ Returns:
+ (ndarray) the background subtracted diffraction pattern
+ """
+ assert darkref.shape == (
+ datacube.Q_Nx,
+ datacube.Q_Ny,
+ ), "background must have shape (datacube.Q_Nx, datacube.Q_Ny)"
+ return datacube.data[Rx, Ry, :, :].astype(float) - darkref.astype(float)
+
+
+#### Get dark reference ####
+
+
+def get_darkreference(
+ datacube, N_frames, width_x=0, width_y=0, side_x="end", side_y="end"
+):
+ """
+ Gets a dark reference image.
+
+ Select N_frames random frames (DPs) from datacube. Find streaking noise in the
+ horizontal and vertical directions, by finding the average values along a thin strip
+ of width_x/width_y pixels along the detector edges. Which edges are used is
+ controlled by side_x/side_y, which must be 'start' or 'end'. To correct streaks
+ along only one direction, set the other width (width_x or width_y) to 0, which
+ disables streak correction in that direction.
+
+ Note that the data is cast to float before computing the background, and should
+ similarly be cast to float before performing a subtraction. This avoids integer
+ clipping and wraparound errors.
+
+ Args:
+ datacube (DataCube): data to background subtract
+ N_frames (int): number of random diffraction patterns to use
+ width_x (int): width of the ROI strip for finding streaking in x
+ width_y (int): see above
+ side_x (str): use a strip from the start or end of the array. Must be 'start' or
+ 'end', defaults to 'end'
+ side_y (str): see above
+
+ Returns:
+ (ndarray): a 2D ndarray of shape (datacube.Q_Nx, datacube.Q_Ny) giving the
+ background.
+ """
+ if width_x == 0 and width_y == 0:
+ print(
+ "Warning: either width_x or width_y should be a positive integer. Returning an empty dark reference."
+ )
+ return np.zeros((datacube.Q_Nx, datacube.Q_Ny))
+ elif width_x == 0:
+ return get_background_streaks_y(
+ datacube=datacube, N_frames=N_frames, width=width_y, side=side_y
+ )
+ elif width_y == 0:
+ return get_background_streaks_x(
+ datacube=datacube, N_frames=N_frames, width=width_x, side=side_x
+ )
+ else:
+ darkref_x = get_background_streaks_x(
+ datacube=datacube, N_frames=N_frames, width=width_x, side=side_x
+ )
+ darkref_y = get_background_streaks_y(
+ datacube=datacube, N_frames=N_frames, width=width_y, side=side_y
+ )
+ return (
+ darkref_x
+ + darkref_y
+ - (np.mean(darkref_x) * width_x + np.mean(darkref_y) * width_y)
+ / (width_x + width_y)
+ )
+ # Mean has been added twice; subtract one off
+
+
+def get_background_streaks(datacube, N_frames, width, side="end", direction="x"):
+ """
+ Gets background streaking in either the x- or y-direction, by finding the average of
+ a strip of pixels along the edge of the detector over a random selection of
+ diffraction patterns, and returns a dark reference array.
+
+ Note that the data is cast to float before computing the background, and should
+ similarly be cast to float before performing a subtraction. This avoids integer
+ clipping and wraparound errors.
+
+ Args:
+ datacube (DataCube): data to background subtract
+ N_frames (int): number of random frames to use
+ width (int): width of the ROI strip for background identification
+ side (str, optional): use a strip from the start or end of the array. Must be
+ 'start' or 'end', defaults to 'end'
+ direction (str): the direction of background streaks to find. Must be either
+ 'x' or 'y'; defaults to 'x'
+
+ Returns:
+ (ndarray): a 2D ndarray of shape (datacube.Q_Nx,datacube.Q_Ny), giving the
+ x- or y-direction background streaking.
+ """
+ assert (direction == "x") or (direction == "y"), "direction must be 'x' or 'y'."
+ if direction == "x":
+ return get_background_streaks_x(
+ datacube=datacube, N_frames=N_frames, width=width, side=side
+ )
+ else:
+ return get_background_streaks_y(
+ datacube=datacube, N_frames=N_frames, width=width, side=side
+ )
+
+
+def get_background_streaks_x(datacube, width, N_frames, side="start"):
+ """
+ Gets background streaking, by finding the average of a strip of pixels along the
+ y-edge of the detector over a random selection of diffraction patterns.
+
+ See docstring for get_background_streaks() for more info.
+ """
+ assert (
+ N_frames <= datacube.R_Nx * datacube.R_Ny
+ ), "N_frames must be less than or equal to the total number of diffraction patterns."
+ assert (side == "start") or (side == "end"), "side must be 'start' or 'end'."
+
+ # Get random subset of DPs
+ indices = np.arange(datacube.R_Nx * datacube.R_Ny)
+ np.random.shuffle(indices)
+ indices = indices[:N_frames]
+ indices_x, indices_y = np.unravel_index(indices, (datacube.R_Nx, datacube.R_Ny))
+
+ # Make a reference strip array
+ refstrip = np.zeros((width, datacube.Q_Ny))
+ if side == "start":
+ for i in range(N_frames):
+ refstrip += datacube.data[indices_x[i], indices_y[i], :width, :].astype(
+ float
+ )
+ else:
+ for i in range(N_frames):
+ refstrip += datacube.data[indices_x[i], indices_y[i], -width:, :].astype(
+ float
+ )
+
+ # Calculate the mean via integer division; gives a 1D array of streak values
+ bkgrnd_streaks = np.sum(refstrip, axis=0) // width // N_frames
+
+ # Broadcast to 2D array
+ darkref = np.zeros((datacube.Q_Nx, datacube.Q_Ny))
+ darkref += bkgrnd_streaks[np.newaxis, :]
+ return darkref
+
+
+def get_background_streaks_y(datacube, N_frames, width, side="start"):
+ """
+ Gets background streaking, by finding the average of a strip of pixels along the
+ x-edge of the detector over a random selection of diffraction patterns.
+
+ See docstring for get_background_streaks() for more info.
+ """
+ assert (
+ N_frames <= datacube.R_Nx * datacube.R_Ny
+ ), "N_frames must be less than or equal to the total number of diffraction patterns."
+ assert (side == "start") or (side == "end"), "side must be 'start' or 'end'."
+
+ # Get random subset of DPs
+ indices = np.arange(datacube.R_Nx * datacube.R_Ny)
+ np.random.shuffle(indices)
+ indices = indices[:N_frames]
+ indices_x, indices_y = np.unravel_index(indices, (datacube.R_Nx, datacube.R_Ny))
+
+ # Make a reference strip array
+ refstrip = np.zeros((datacube.Q_Nx, width))
+ if side == "start":
+ for i in range(N_frames):
+ refstrip += datacube.data[indices_x[i], indices_y[i], :, :width].astype(
+ float
+ )
+ else:
+ for i in range(N_frames):
+ refstrip += datacube.data[indices_x[i], indices_y[i], :, -width:].astype(
+ float
+ )
+
+ # Calculate the mean via integer division; gives a 1D array of streak values
+ bkgrnd_streaks = np.sum(refstrip, axis=1) // width // N_frames
+
+ # Broadcast to 2D array
+ darkref = np.zeros((datacube.Q_Nx, datacube.Q_Ny))
+ darkref += bkgrnd_streaks[:, np.newaxis]
+ return darkref
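+
+
+# Illustrative usage (a sketch; `datacube` is a hypothetical DataCube with
+# streaking along the last 10 rows of the detector):
+#
+#     >>> darkref = get_darkreference(datacube, N_frames=50, width_x=10, side_x="end")
+#     >>> dp = get_bksbtr_DP(datacube, darkref, Rx=0, Ry=0)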
diff --git a/py4DSTEM/preprocess/electroncount.py b/py4DSTEM/preprocess/electroncount.py
new file mode 100644
index 000000000..e3fc68e05
--- /dev/null
+++ b/py4DSTEM/preprocess/electroncount.py
@@ -0,0 +1,460 @@
+# Electron counting
+#
+# Includes functions for electron counting on either the CPU (electron_count)
+# or the GPU (electron_count_GPU). For GPU electron counting, pytorch is used
+# to interface between numpy and the GPU, and the datacube is expected in
+# numpy.memmap (memory mapped) form.
+
+import numpy as np
+from scipy import optimize
+
+from emdfile import PointListArray
+from py4DSTEM.preprocess.utils import get_maxima_2D, bin2D
+
+
+def electron_count(
+ datacube,
+ darkreference,
+ Nsamples=40,
+ thresh_bkgrnd_Nsigma=4,
+ thresh_xray_Nsigma=10,
+ binfactor=1,
+ sub_pixel=True,
+ output="pointlist",
+):
+ """
+ Performs electron counting.
+
+ The algorithm is as follows:
+ From a random sampling of frames, calculate an x-ray and background
+ threshold value. In each frame, subtract the dark reference, then apply the
+ two thresholds. Find all local maxima with respect to the nearest neighbor
+ pixels. These are considered electron strike events.
+
+ Thresholds are specified in units of standard deviations, either of a
+ gaussian fit to the histogram background noise (for thresh_bkgrnd) or of
+ the histogram itself (for thresh_xray). The background (lower) threshold is
+ more important; we will always be missing some real electron counts and
+ incorrectly counting some noise as electron strikes - this threshold
+ controls their relative balance. The x-ray threshold may be set fairly high.
+
+ Args:
+ datacube: a 4D numpy.ndarray pointing to the datacube. Note: the R/Q axes are
+ flipped with respect to py4DSTEM DataCube objects
+ darkreference: a 2D numpy.ndarray with the dark reference
+ Nsamples: the number of frames to use in dark reference and threshold
+ calculation.
+ thresh_bkgrnd_Nsigma: the background threshold is
+ ``mean(gaussian fit) + (this #)*std(gaussian fit)``
+ where the gaussian fit is to the background noise.
+ thresh_xray_Nsigma: the X-ray threshold is
+ ``mean(hist) +/- (this #)*std(hist)``
+ where hist is the histogram of all pixel values in the Nsamples random frames
+ binfactor: the binning factor
+ sub_pixel (bool): controls whether subpixel refinement is performed
+ output (str): controls output format; must be 'datacube' or 'pointlist'
+
+ Returns:
+ (variable) if output=='pointlist', returns a PointListArray of all electron
+ counts in each frame. If output=='datacube', returns a 4D array of bools, with
+ True indicating electron strikes
+ """
+ assert isinstance(output, str), "output must be a str"
+ assert output in [
+ "pointlist",
+ "datacube",
+ ], "output must be 'pointlist' or 'datacube'"
+
+ # Get dimensions
+ R_Nx, R_Ny, Q_Nx, Q_Ny = np.shape(datacube)
+
+ # Get threshholds
+ print("Calculating threshholds")
+ thresh_bkgrnd, thresh_xray = calculate_thresholds(
+ datacube,
+ darkreference,
+ Nsamples=Nsamples,
+ thresh_bkgrnd_Nsigma=thresh_bkgrnd_Nsigma,
+ thresh_xray_Nsigma=thresh_xray_Nsigma,
+ )
+
+ # Save to a new datacube
+ if output == "datacube":
+ counted = np.ones((R_Nx, R_Ny, Q_Nx // binfactor, Q_Ny // binfactor))
+ # Loop through frames
+ for Rx in range(R_Nx):
+ for Ry in range(R_Ny):
+ frame = datacube[Rx, Ry, :, :].astype(np.int16) # Get frame from file
+ workingarray = frame - darkreference # Subtract dark ref from frame
+ events = workingarray > thresh_bkgrnd # Threshold electron events
+ events *= thresh_xray > workingarray
+
+ ## Keep events which are greater than all NN pixels ##
+ events = get_maxima_2D(workingarray * events)
+
+ if binfactor > 1:
+ # Perform binning
+ counted[Rx, Ry, :, :] = bin2D(events, factor=binfactor)
+ else:
+ counted[Rx, Ry, :, :] = events
+ return counted
+
+ # Save to a PointListArray
+ else:
+ coordinates = [("qx", int), ("qy", int)]
+ pointlistarray = PointListArray(coordinates=coordinates, shape=(R_Nx, R_Ny))
+ # Loop through frames
+ for Rx in range(R_Nx):
+ for Ry in range(R_Ny):
+ frame = datacube[Rx, Ry, :, :].astype(np.int16) # Get frame from file
+ workingarray = frame - darkreference # Subtract dark ref from frame
+ events = workingarray > thresh_bkgrnd # Threshold electron events
+ events *= thresh_xray > workingarray
+
+ ## Keep events which are greater than all NN pixels ##
+ events = get_maxima_2D(workingarray * events)
+
+ # Perform binning
+ if binfactor > 1:
+ events = bin2D(events, factor=binfactor)
+
+ # Save to PointListArray
+ x, y = np.nonzero(events)
+ pointlist = pointlistarray.get_pointlist(Rx, Ry)
+ pointlist.add_tuple_of_nparrays((x, y))
+
+ return pointlistarray
+
+
+def electron_count_GPU(
+ datacube,
+ darkreference,
+ Nsamples=40,
+ thresh_bkgrnd_Nsigma=4,
+ thresh_xray_Nsigma=10,
+ binfactor=1,
+ sub_pixel=True,
+ output="pointlist",
+):
+ """
+ Performs electron counting on the GPU.
+
+ Uses pytorch to interface between numpy and cuda. Requires cuda and pytorch.
+ This function expects datacube to be a np.memmap object.
+ See electron_count() for additional documentation.
+ """
+ import torch
+
+ assert isinstance(output, str), "output must be a str"
+ assert output in [
+ "pointlist",
+ "datacube",
+ ], "output must be 'pointlist' or 'datacube'"
+
+ # Get dimensions
+ R_Nx, R_Ny, Q_Nx, Q_Ny = np.shape(datacube)
+
+ # Get threshholds
+ print("Calculating threshholds")
+ thresh_bkgrnd, thresh_xray = calculate_thresholds(
+ datacube,
+ darkreference,
+ Nsamples=Nsamples,
+ thresh_bkgrnd_Nsigma=thresh_bkgrnd_Nsigma,
+ thresh_xray_Nsigma=thresh_xray_Nsigma,
+ )
+
+ # Make a torch device object, to interface numpy with the GPU
+ # Put a few arrays on it - dark reference, counted image
+ device = torch.device("cuda")
+ darkref = torch.from_numpy(darkreference.astype(np.int16)).to(device)
+ counted = torch.ones(
+ R_Nx, R_Ny, Q_Nx // binfactor, Q_Ny // binfactor, dtype=torch.short
+ ).to(device)
+
+ # Loop through frames
+ for Rx in range(R_Nx):
+ for Ry in range(R_Ny):
+ frame = datacube[Rx, Ry, :, :].astype(np.int16) # Get frame from file
+ gframe = torch.from_numpy(frame).to(device) # Move frame to GPU
+ workingarray = gframe - darkref # Subtract dark ref from frame
+ events = workingarray > thresh_bkgrnd # Threshold electron events
+ events = events & (thresh_xray > workingarray) # exclude X-ray strikes, as in the CPU version
+
+ ## Keep events which are greater than all NN pixels ##
+
+ # Check pixel is greater than all adjacent pixels
+ log = workingarray[1:-1, :] > workingarray[0:-2, :]
+ events[1:-1, :] = events[1:-1, :] & log
+ log = workingarray[0:-2, :] > workingarray[1:-1, :]
+ events[0:-2, :] = events[0:-2, :] & log
+ log = workingarray[:, 1:-1] > workingarray[:, 0:-2]
+ events[:, 1:-1] = events[:, 1:-1] & log
+ log = workingarray[:, 0:-2] > workingarray[:, 1:-1]
+ events[:, 0:-2] = events[:, 0:-2] & log
+ # Check pixel is greater than adjacent diagonal pixels
+ log = workingarray[1:-1, 1:-1] > workingarray[0:-2, 0:-2]
+ events[1:-1, 1:-1] = events[1:-1, 1:-1] & log
+ log = workingarray[0:-2, 1:-1] > workingarray[1:-1, 0:-2]
+ events[0:-2, 1:-1] = events[0:-2, 1:-1] & log
+ log = workingarray[1:-1, 0:-2] > workingarray[0:-2, 1:-1]
+ events[1:-1, 0:-2] = events[1:-1, 0:-2] & log
+ log = workingarray[0:-2, 0:-2] > workingarray[1:-1, 1:-1]
+ events[0:-2, 0:-2] = events[0:-2, 0:-2] & log
+
+ if binfactor > 1:
+ # Perform binning on GPU in torch_bin function
+ counted[Rx, Ry, :, :] = (
+ torch.transpose(
+ torch_bin(
+ events.type(torch.cuda.ShortTensor),
+ device,
+ factor=binfactor,
+ ),
+ 0,
+ 1,
+ )
+ .flip(0)
+ .flip(1)
+ )
+ else:
+ # I'm not sure I understand this - we're flipping coordinates to match what?
+ # TODO: check array flipping - may vary by camera
+ counted[Rx, Ry, :, :] = (
+ torch.transpose(events.type(torch.cuda.ShortTensor), 0, 1)
+ .flip(0)
+ .flip(1)
+ )
+
+ if output == "datacube":
+ return counted.cpu().numpy()
+ else:
+ return counted_datacube_to_pointlistarray(counted)
+
+
+####### Support functions ########
+
+
+def calculate_thresholds(
+ datacube,
+ darkreference,
+ Nsamples=20,
+ thresh_bkgrnd_Nsigma=4,
+ thresh_xray_Nsigma=10,
+ return_params=False,
+):
+ """
+ Calculate the upper and lower thresholds for thresholding what to register as
+ an electron count.
+
+ Both thresholds are determined from the histogram of detector pixel values summed
+ over Nsamples frames. The thresholds are set to::
+
+ thresh_xray = mean(histogram) + thresh_xray_Nsigma * std(histogram)
+ thresh_bkgrnd = mean(gaussian fit) + thresh_bkgrnd_Nsigma * std(gaussian fit)
+
+ For more info, see the electron_count docstring.
+
+ Args:
+ datacube: a 4D numpy.ndarray pointing to the datacube
+ darkreference: a 2D numpy.ndarray with the dark reference
+ Nsamples: the number of frames to use in dark reference and threshold
+ calculation.
+ thresh_bkgrnd_Nsigma: the background threshold is
+ ``mean(gaussian fit) + (this #)*std(gaussian fit)``
+ where the gaussian fit is to the background noise.
+ thresh_xray_Nsigma: the X-ray threshold is
+ ``mean(hist) + (this #)*std(hist)``
+ where hist is the histogram of all pixel values in the Nsamples random frames
+ return_params: bool, if True also return the histogram (n, bins) and the
+ gaussian fit parameters popt
+
+ Returns:
+ (5-tuple): A 5-tuple containing:
+
+ * **thresh_bkgrnd**: the background threshold
+ * **thresh_xray**: the X-ray threshold
+ * **n**: returned iff return_params==True. The histogram values
+ * **bins**: returned iff return_params==True. The histogram bin edges
+ * **popt**: returned iff return_params==True. The gaussian fit parameters,
+ (A, mu, sigma).
+ """
+ R_Nx, R_Ny, Q_Nx, Q_Ny = datacube.shape
+
+ # Select random set of frames
+ nframes = R_Nx * R_Ny
+ samples = np.arange(nframes)
+ np.random.shuffle(samples)
+ samples = samples[:Nsamples]
+
+ # Get frames and subtract dark references
+ sample = np.zeros((Q_Nx, Q_Ny, Nsamples), dtype=np.int16)
+ for i in range(Nsamples):
+ sample[:, :, i] = datacube[samples[i] // R_Ny, samples[i] % R_Ny, :, :] # row-major unravel of the flat frame index
+ sample[:, :, i] -= darkreference
+ sample = np.ravel(sample) # Flatten array
+
+ # Get upper (X-ray) threshold
+ mean = np.mean(sample)
+ stddev = np.std(sample)
+ thresh_xray = mean + thresh_xray_Nsigma * stddev
+
+ # Make a histogram
+ binmax = min(int(np.ceil(np.amax(sample))), int(mean + thresh_xray_Nsigma * stddev))
+ binmin = max(int(np.ceil(np.amin(sample))), int(mean - thresh_xray_Nsigma * stddev))
+ step = max(1, (binmax - binmin) // 1000)
+ bins = np.arange(binmin, binmax, step=step, dtype=np.int16)
+ n, bins = np.histogram(sample, bins=bins)
+
+ # Define the gaussian to fit to, with parameters p:
+ # p[0] is amplitude
+ # p[1] is the mean
+ # p[2] is std deviation
+ fitfunc = lambda p, x: p[0] * np.exp(-0.5 * np.square((x - p[1]) / p[2]))
+ errfunc = lambda p, x, y: fitfunc(p, x) - y # Error for scipy's optimize routine
+
+ # Get initial guess
+ p0 = [n.max(), (bins[n.argmax() + 1] + bins[n.argmax()]) / 2, np.std(sample)] # peak height, center of the peak bin, width
+ p1, success = optimize.leastsq(
+ errfunc, p0[:], args=(bins[:-1], n)
+ ) # Use the scipy optimize routine
+ p1[1] += 0.5 # Add a half to account for integer bin width
+
+ # Set lower threshhold for electron counts to count
+ thresh_bkgrnd = p1[1] + p1[2] * thresh_bkgrnd_Nsigma
+
+ if return_params:
+ return thresh_bkgrnd, thresh_xray, n, bins, p1
+ else:
+ return thresh_bkgrnd, thresh_xray
+
+
+def torch_bin(array, device, factor=2):
+ """
+ Bin data on the GPU using torch.
+
+ Args:
+ array: a 2D torch tensor (on `device`)
+ device: a torch device class instance
+ factor (int): the binning factor
+
+ Returns:
+ (array): the binned array
+ """
+
+ import torch
+
+ x, y = array.shape
+ binx, biny = x // factor, y // factor
+ xx, yy = binx * factor, biny * factor
+
+ # Make a binned array on the device
+ binned_ar = torch.zeros(binx, biny, device=device, dtype=array.dtype) # matches the (binx, biny) slices summed below
+
+ # Collect pixel sums into new bins
+ for ix in range(factor):
+ for iy in range(factor):
+ binned_ar += array[0 + ix : xx + ix : factor, 0 + iy : yy + iy : factor]
+ return binned_ar
+
+
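+# The strided-sum binning above has a direct numpy analogue, shown here as an
+# illustrative sketch (equivalent up to device placement and dtype handling):
+#
+#     >>> import numpy as np
+#     >>> def numpy_bin(array, factor=2):
+#     ...     x, y = array.shape
+#     ...     xx, yy = (x // factor) * factor, (y // factor) * factor
+#     ...     out = np.zeros((x // factor, y // factor), dtype=array.dtype)
+#     ...     for ix in range(factor):
+#     ...         for iy in range(factor):
+#     ...             out += array[ix:xx:factor, iy:yy:factor]
+#     ...     return out
+
+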
+def counted_datacube_to_pointlistarray(counted_datacube, subpixel=False):
+ """
+ Converts an electron counted datacube to PointListArray.
+
+ Args:
+ counted_datacube: a 4D array of bools, with true indicating an electron strike.
+ subpixel (bool): controls if subpixel electron strike positions are expected
+
+ Returns:
+ (PointListArray): a PointListArray of electron strike events
+ """
+ # Get shape, initialize PointListArray
+ R_Nx, R_Ny, Q_Nx, Q_Ny = counted_datacube.shape
+ if subpixel:
+ coordinates = [("qx", float), ("qy", float)]
+ else:
+ coordinates = [("qx", int), ("qy", int)]
+ pointlistarray = PointListArray(coordinates=coordinates, shape=(R_Nx, R_Ny))
+
+ # Loop through frames, adding electron counts to the PointListArray for each.
+ for Rx in range(R_Nx):
+ for Ry in range(R_Ny):
+ frame = counted_datacube[Rx, Ry, :, :]
+ x, y = np.nonzero(frame)
+ pointlist = pointlistarray.get_pointlist(Rx, Ry)
+ pointlist.add_tuple_of_nparrays((x, y))
+
+ return pointlistarray
+
+
+def counted_pointlistarray_to_datacube(counted_pointlistarray, shape, subpixel=False):
+ """
+ Converts an electron counted PointListArray to a datacube.
+
+ Args:
+ counted_pointlistarray (PointListArray): a PointListArray of electron strike
+ events
+ shape (4-tuple): a length 4 tuple of ints containing (R_Nx,R_Ny,Q_Nx,Q_Ny)
+ subpixel (bool): controls if subpixel electron strike positions are expected
+
+ Returns:
+ (4D array of bools): a 4D array of bools, with true indicating an electron strike.
+ """
+ assert len(shape) == 4
+ assert subpixel is False, "subpixel mode not presently supported."
+ R_Nx, R_Ny, Q_Nx, Q_Ny = shape
+ counted_datacube = np.zeros((R_Nx, R_Ny, Q_Nx, Q_Ny), dtype=bool)
+
+ # Loop through frames, adding electron counts to the datacube for each.
+ for Rx in range(R_Nx):
+ for Ry in range(R_Ny):
+ pointlist = counted_pointlistarray.get_pointlist(Rx, Ry)
+ counted_datacube[Rx, Ry, pointlist.data["qx"], pointlist.data["qy"]] = True
+
+ return counted_datacube
+
+
+if __name__ == "__main__":
+ from py4DSTEM.preprocess.darkreference import get_darkreference
+ from py4DSTEM.io import save
+ from py4DSTEM.datacube import DataCube
+ from ncempy.io import dm
+
+ dm4_filepath = "Capture25.dm4"
+
+ # Parameters for dark reference determination
+ drwidth = 100
+
+ # Parameters for electron counting
+ Nsamples = 40
+ thresh_bkgrnd_Nsigma = 4
+ thresh_xray_Nsigma = 30
+ binfactor = 1
+ subpixel = False
+ output = "pointlist"
+
+ # Get memory mapped 4D datacube from dm file
+ datacube = dm.dmReader(dm4_filepath, dSetNum=0, verbose=False)["data"]
+ datacube = np.moveaxis(datacube, (0, 1), (2, 3))
+
+ # Get dark reference
+ darkreference = 1 # TODO: get_darkreference(datacube = ...!
+
+ electron_counted_data = electron_count(
+ datacube,
+ darkreference,
+ Nsamples=Nsamples,
+ thresh_bkgrnd_Nsigma=thresh_bkgrnd_Nsigma,
+ thresh_xray_Nsigma=thresh_xray_Nsigma,
+ binfactor=binfactor,
+ sub_pixel=True,
+ output="pointlist",
+ )
+
+ # For outputting datacubes, wrap counted into a py4DSTEM DataCube
+ if output == "datacube":
+ electron_counted_data = DataCube(data=electron_counted_data)
+
+ output_path = dm4_filepath.replace(".dm4", ".h5")
+ save(output_path, electron_counted_data) # save(filepath, data, ...)
diff --git a/py4DSTEM/preprocess/preprocess.py b/py4DSTEM/preprocess/preprocess.py
new file mode 100644
index 000000000..fb4983622
--- /dev/null
+++ b/py4DSTEM/preprocess/preprocess.py
@@ -0,0 +1,655 @@
+# Preprocessing functions
+#
+# These functions generally accept DataCube objects as arguments, and return a new, modified
+# DataCube.
+# Most of these functions are also included as DataCube class methods. Thus
+# datacube = preprocess_function(datacube, *args)
+# will be identical to
+# datacube.preprocess_function(*args)
+
+import warnings
+import numpy as np
+from py4DSTEM.preprocess.utils import bin2D, get_shifted_ar
+from emdfile import tqdmnd
+from scipy.ndimage import median_filter
+
+### Editing datacube shape ###
+
+
+def set_scan_shape(datacube, R_Nx, R_Ny):
+ """
+ Reshape the data given the real space scan shape.
+ """
+ try:
+ # reshape
+ datacube.data = datacube.data.reshape(
+ datacube.R_N, datacube.Q_Nx, datacube.Q_Ny
+ ).reshape(R_Nx, R_Ny, datacube.Q_Nx, datacube.Q_Ny)
+
+ # set dim vectors
+ Rpixsize = datacube.calibration.get_R_pixel_size()
+ Rpixunits = datacube.calibration.get_R_pixel_units()
+ datacube.set_dim(0, [0, Rpixsize], units=Rpixunits)
+ datacube.set_dim(1, [0, Rpixsize], units=Rpixunits)
+
+ # return
+ return datacube
+
+ except ValueError:
+ print(
+ "Can't reshape {} scan positions into a {}x{} array.".format(
+ datacube.R_N, R_Nx, R_Ny
+ )
+ )
+ return datacube
+ except AttributeError:
+ print(f"Can't reshape {datacube.data.__class__.__name__} datacube.")
+ return datacube
+
+
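+# Illustrative call (a sketch; the shapes are hypothetical): a datacube read in
+# as a linear sequence of 16384 patterns can be reshaped to a 128x128 scan grid.
+#
+#     >>> datacube = set_scan_shape(datacube, R_Nx=128, R_Ny=128)
+
+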
+def swap_RQ(datacube):
+ """
+ Swaps real and reciprocal space coordinates, so that if
+
+ >>> datacube.data.shape
+ (Rx,Ry,Qx,Qy)
+
+ Then
+
+ >>> swap_RQ(datacube).data.shape
+ (Qx,Qy,Rx,Ry)
+ """
+ # swap
+ datacube.data = np.transpose(datacube.data, axes=(2, 3, 0, 1))
+
+ # set dim vectors
+ Rpixsize = datacube.calibration.get_R_pixel_size()
+ Rpixunits = datacube.calibration.get_R_pixel_units()
+ Qpixsize = datacube.calibration.get_Q_pixel_size()
+ Qpixunits = datacube.calibration.get_Q_pixel_units()
+ datacube.set_dim(0, [0, Rpixsize], units=Rpixunits, name="Rx")
+ datacube.set_dim(1, [0, Rpixsize], units=Rpixunits, name="Ry")
+ datacube.set_dim(2, [0, Qpixsize], units=Qpixunits, name="Qx")
+ datacube.set_dim(3, [0, Qpixsize], units=Qpixunits, name="Qy")
+
+ # return
+ return datacube
+
+
+def swap_Rxy(datacube):
+ """
+ Swaps real space x and y coordinates, so that if
+
+ >>> datacube.data.shape
+ (Ry,Rx,Qx,Qy)
+
+ Then
+
+ >>> swap_Rxy(datacube).data.shape
+ (Rx,Ry,Qx,Qy)
+ """
+ # swap
+ datacube.data = np.moveaxis(datacube.data, 1, 0)
+
+ # set dim vectors
+ Rpixsize = datacube.calibration.get_R_pixel_size()
+ Rpixunits = datacube.calibration.get_R_pixel_units()
+ datacube.set_dim(0, [0, Rpixsize], units=Rpixunits, name="Rx")
+ datacube.set_dim(1, [0, Rpixsize], units=Rpixunits, name="Ry")
+
+ # return
+ return datacube
+
+
+def swap_Qxy(datacube):
+ """
+ Swaps reciprocal space x and y coordinates, so that if
+
+ >>> datacube.data.shape
+ (Rx,Ry,Qy,Qx)
+
+ Then
+
+ >>> swap_Qxy(datacube).data.shape
+ (Rx,Ry,Qx,Qy)
+ """
+ datacube.data = np.moveaxis(datacube.data, 3, 2)
+ return datacube
+
+
+### Cropping and binning ###
+
+
+def crop_data_diffraction(datacube, crop_Qx_min, crop_Qx_max, crop_Qy_min, crop_Qy_max):
+ # crop
+ datacube.data = datacube.data[
+ :, :, crop_Qx_min:crop_Qx_max, crop_Qy_min:crop_Qy_max
+ ]
+
+ # set dim vectors
+ Qpixsize = datacube.calibration.get_Q_pixel_size()
+ Qpixunits = datacube.calibration.get_Q_pixel_units()
+ datacube.set_dim(2, [0, Qpixsize], units=Qpixunits, name="Qx")
+ datacube.set_dim(3, [0, Qpixsize], units=Qpixunits, name="Qy")
+
+ # return
+ return datacube
+
+
+def crop_data_real(datacube, crop_Rx_min, crop_Rx_max, crop_Ry_min, crop_Ry_max):
+ # crop
+ datacube.data = datacube.data[
+ crop_Rx_min:crop_Rx_max, crop_Ry_min:crop_Ry_max, :, :
+ ]
+
+ # set dim vectors
+ Rpixsize = datacube.calibration.get_R_pixel_size()
+ Rpixunits = datacube.calibration.get_R_pixel_units()
+ datacube.set_dim(0, [0, Rpixsize], units=Rpixunits, name="Rx")
+ datacube.set_dim(1, [0, Rpixsize], units=Rpixunits, name="Ry")
+
+ # return
+ return datacube
+
+
+def bin_data_diffraction(datacube, bin_factor, dtype=None):
+ """
+ Performs diffraction space binning of data by bin_factor.
+
+ Parameters
+ ----------
+ N : int
+ The binning factor
+ dtype : a datatype (optional)
+ Specify the datatype for the output. If not passed, the datatype
+ is left unchanged
+
+ """
+ # validate inputs
+ assert type(bin_factor) is int, f"Error: binning factor {bin_factor} is not an int."
+ if bin_factor == 1:
+ return datacube
+ if dtype is None:
+ dtype = datacube.data.dtype
+
+ # get shape
+ R_Nx, R_Ny, Q_Nx, Q_Ny = (
+ datacube.R_Nx,
+ datacube.R_Ny,
+ datacube.Q_Nx,
+ datacube.Q_Ny,
+ )
+ # crop edges if necessary
+ if (Q_Nx % bin_factor == 0) and (Q_Ny % bin_factor == 0):
+ pass
+ elif Q_Nx % bin_factor == 0:
+ datacube.data = datacube.data[:, :, :, : -(Q_Ny % bin_factor)]
+ elif Q_Ny % bin_factor == 0:
+ datacube.data = datacube.data[:, :, : -(Q_Nx % bin_factor), :]
+ else:
+ datacube.data = datacube.data[
+ :, :, : -(Q_Nx % bin_factor), : -(Q_Ny % bin_factor)
+ ]
+
+ # bin
+ datacube.data = (
+ datacube.data.reshape(
+ R_Nx,
+ R_Ny,
+ int(Q_Nx / bin_factor),
+ bin_factor,
+ int(Q_Ny / bin_factor),
+ bin_factor,
+ )
+ .sum(axis=(3, 5))
+ .astype(dtype)
+ )
+
+ # set dim vectors
+ Qpixsize = datacube.calibration.get_Q_pixel_size() * bin_factor
+ Qpixunits = datacube.calibration.get_Q_pixel_units()
+
+ datacube.set_dim(2, [0, Qpixsize], units=Qpixunits, name="Qx")
+ datacube.set_dim(3, [0, Qpixsize], units=Qpixunits, name="Qy")
+
+ # set calibration pixel size
+ datacube.calibration.set_Q_pixel_size(Qpixsize)
+
+ # return
+ return datacube
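+
+# A minimal usage sketch for the function above (illustrative only; assumes
+# a DataCube `dc` has already been loaded):
+#
+#     dc = bin_data_diffraction(dc, bin_factor=2, dtype=np.uint32)
+#
+# Counts in each 2x2 diffraction-space block are summed, so a wider dtype
+# than the detector's native one may be needed to avoid overflow.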
+
+
+def bin_data_mmap(datacube, bin_factor, dtype=np.float32):
+ """
+ Performs diffraction space binning of data by bin_factor.
+
+ """
+ # validate inputs
+ assert type(bin_factor) is int, f"Error: binning factor {bin_factor} is not an int."
+ if bin_factor == 1:
+ return datacube
+
+ # get shape
+ R_Nx, R_Ny, Q_Nx, Q_Ny = (
+ datacube.R_Nx,
+ datacube.R_Ny,
+ datacube.Q_Nx,
+ datacube.Q_Ny,
+ )
+ # allocate space
+ data = np.zeros(
+ (
+ datacube.R_Nx,
+ datacube.R_Ny,
+ datacube.Q_Nx // bin_factor,
+ datacube.Q_Ny // bin_factor,
+ ),
+ dtype=dtype,
+ )
+ # bin
+    for Rx, Ry in tqdmnd(datacube.R_Nx, datacube.R_Ny):
+ data[Rx, Ry, :, :] = bin2D(datacube.data[Rx, Ry, :, :], bin_factor, dtype=dtype)
+ datacube.data = data
+
+ # set dim vectors
+ Qpixsize = datacube.calibration.get_Q_pixel_size() * bin_factor
+ Qpixunits = datacube.calibration.get_Q_pixel_units()
+ datacube.set_dim(2, [0, Qpixsize], units=Qpixunits, name="Qx")
+ datacube.set_dim(3, [0, Qpixsize], units=Qpixunits, name="Qy")
+ # set calibration pixel size
+ datacube.calibration.set_Q_pixel_size(Qpixsize)
+
+ # return
+ return datacube
+
+
+def bin_data_real(datacube, bin_factor):
+ """
+    Performs real space binning of data by bin_factor.
+ """
+ # validate inputs
+ assert type(bin_factor) is int, f"Bin factor {bin_factor} is not an int."
+ if bin_factor <= 1:
+ return datacube
+
+ # set shape
+ R_Nx, R_Ny, Q_Nx, Q_Ny = (
+ datacube.R_Nx,
+ datacube.R_Ny,
+ datacube.Q_Nx,
+ datacube.Q_Ny,
+ )
+ # crop edges if necessary
+ if (R_Nx % bin_factor == 0) and (R_Ny % bin_factor == 0):
+ pass
+ elif R_Nx % bin_factor == 0:
+ datacube.data = datacube.data[:, : -(R_Ny % bin_factor), :, :]
+ elif R_Ny % bin_factor == 0:
+ datacube.data = datacube.data[: -(R_Nx % bin_factor), :, :, :]
+ else:
+ datacube.data = datacube.data[
+ : -(R_Nx % bin_factor), : -(R_Ny % bin_factor), :, :
+ ]
+ # bin
+ datacube.data = datacube.data.reshape(
+ int(R_Nx / bin_factor),
+ bin_factor,
+ int(R_Ny / bin_factor),
+ bin_factor,
+ Q_Nx,
+ Q_Ny,
+ ).sum(axis=(1, 3))
+
+ # set dim vectors
+ Rpixsize = datacube.calibration.get_R_pixel_size() * bin_factor
+ Rpixunits = datacube.calibration.get_R_pixel_units()
+ datacube.set_dim(0, [0, Rpixsize], units=Rpixunits, name="Rx")
+ datacube.set_dim(1, [0, Rpixsize], units=Rpixunits, name="Ry")
+ # set calibration pixel size
+ datacube.calibration.set_R_pixel_size(Rpixsize)
+
+ # return
+ return datacube
+
+
+def thin_data_real(datacube, thinning_factor):
+ """
+ Reduces data size by a factor of `thinning_factor`^2 by skipping every `thinning_factor` beam positions in both x and y.
+ """
+ # get shapes
+ Rshape0 = datacube.Rshape
+ Rshapef = tuple([x // thinning_factor for x in Rshape0])
+
+ # allocate memory
+ data = np.empty(
+ (Rshapef[0], Rshapef[1], datacube.Qshape[0], datacube.Qshape[1]),
+ dtype=datacube.data.dtype,
+ )
+
+ # populate data
+ for rx, ry in tqdmnd(Rshapef[0], Rshapef[1]):
+ rx0 = rx * thinning_factor
+ ry0 = ry * thinning_factor
+ data[rx, ry, :, :] = datacube[rx0, ry0, :, :]
+
+ datacube.data = data
+
+ # set dim vectors
+ Rpixsize = datacube.calibration.get_R_pixel_size() * thinning_factor
+ Rpixunits = datacube.calibration.get_R_pixel_units()
+ datacube.set_dim(0, [0, Rpixsize], units=Rpixunits, name="Rx")
+ datacube.set_dim(1, [0, Rpixsize], units=Rpixunits, name="Ry")
+ # set calibration pixel size
+ datacube.calibration.set_R_pixel_size(Rpixsize)
+
+ # return
+ return datacube
+
+
+def filter_hot_pixels(datacube, thresh, ind_compare=1, return_mask=False):
+ """
+ This function performs pixel filtering to remove hot / bright pixels.
+ A mean diffraction pattern is calculated, then a moving local ordering filter
+ is applied to it, finding and sorting the intensities of the 21 pixels nearest
+ each pixel (where 21 = (the pixel itself) + (nearest neighbors) + (next
+ nearest neighbors) = (1) + (8) + (12) = 21; the next nearest neighbors
+ exclude the corners of the NNN square of pixels). This filter then returns
+ a single value at each pixel given by the N'th highest value of these 21
+ sorted values, where N is specified by `ind_compare`. ind_compare=0
+    specifies the highest intensity, =1 is the second highest, etc. Next, a mask
+    is generated which is True for all pixels which are at least a value `thresh`
+    higher than the local ordering filter output. Thus for the default
+    `ind_compare` value of 1, the mask will be True wherever the mean diffraction
+    pattern is higher than the second brightest pixel in its local window by
+ at least a value of `thresh`. Finally, we loop through all diffraction
+ images, and any pixels defined by mask are replaced by their 3x3 local
+ median.
+
+ Parameters
+ ----------
+ datacube : DataCube
+        The 4D datacube
+ thresh : float
+ Threshold for replacing hot pixels, if pixel value minus local ordering
+ filter exceeds it.
+ ind_compare : int
+ Which median filter value to compare against. 0 = brightest pixel,
+ 1 = next brightest, etc.
+ return_mask : bool
+ If True, returns the filter mask
+
+ Returns
+ -------
+ datacube : Datacube
+ mask : bool
+ (optional) the bad pixel mask
+ """
+
+ # Mean image over all probe positions
+ diff_mean = np.mean(datacube.data, axis=(0, 1))
+ shape = diff_mean.shape
+
+ # Moving local ordered pixel values
+ diff_local_med = np.sort(
+ np.vstack(
+ [
+ np.roll(diff_mean, (-1, -1), axis=(0, 1)).ravel(),
+ np.roll(diff_mean, (0, -1), axis=(0, 1)).ravel(),
+ np.roll(diff_mean, (1, -1), axis=(0, 1)).ravel(),
+ np.roll(diff_mean, (-1, 0), axis=(0, 1)).ravel(),
+ np.roll(diff_mean, (0, 0), axis=(0, 1)).ravel(),
+ np.roll(diff_mean, (1, 0), axis=(0, 1)).ravel(),
+ np.roll(diff_mean, (-1, 1), axis=(0, 1)).ravel(),
+ np.roll(diff_mean, (0, 1), axis=(0, 1)).ravel(),
+ np.roll(diff_mean, (1, 1), axis=(0, 1)).ravel(),
+ np.roll(diff_mean, (-1, -2), axis=(0, 1)).ravel(),
+ np.roll(diff_mean, (0, -2), axis=(0, 1)).ravel(),
+ np.roll(diff_mean, (1, -2), axis=(0, 1)).ravel(),
+ np.roll(diff_mean, (-1, 2), axis=(0, 1)).ravel(),
+ np.roll(diff_mean, (0, 2), axis=(0, 1)).ravel(),
+ np.roll(diff_mean, (1, 2), axis=(0, 1)).ravel(),
+ np.roll(diff_mean, (-2, -1), axis=(0, 1)).ravel(),
+ np.roll(diff_mean, (-2, 0), axis=(0, 1)).ravel(),
+ np.roll(diff_mean, (-2, 1), axis=(0, 1)).ravel(),
+ np.roll(diff_mean, (2, -1), axis=(0, 1)).ravel(),
+ np.roll(diff_mean, (2, 0), axis=(0, 1)).ravel(),
+ np.roll(diff_mean, (2, 1), axis=(0, 1)).ravel(),
+ ]
+ ),
+ axis=0,
+ )
+    # array of the ind_compare'th pixel intensity
+ diff_compare = np.reshape(diff_local_med[-ind_compare - 1, :], shape)
+
+ # Generate mask
+ mask = diff_mean - diff_compare > thresh
+
+ # If the mask is empty, return
+ if np.sum(mask) == 0:
+ print("No hot pixels detected")
+ if return_mask is True:
+ return datacube, mask
+ else:
+ return datacube
+
+ # Otherwise, apply filtering
+
+ # Get masked indices
+ x_ma, y_ma = np.nonzero(mask)
+
+ # Get local windows for each masked pixel
+ xslices, yslices = [], []
+ for xm, ym in zip(x_ma, y_ma):
+ xslice, yslice = slice(xm - 1, xm + 2), slice(ym - 1, ym + 2)
+ if xslice.start < 0:
+ xslice = slice(0, xslice.stop)
+ elif xslice.stop > shape[0]:
+ xslice = slice(xslice.start, shape[0])
+ if yslice.start < 0:
+ yslice = slice(0, yslice.stop)
+ elif yslice.stop > shape[1]:
+ yslice = slice(yslice.start, shape[1])
+ xslices.append(xslice)
+ yslices.append(yslice)
+
+ # Loop and replace pixels
+ for ax, ay in tqdmnd(
+ *(datacube.R_Nx, datacube.R_Ny), desc="Cleaning pixels", unit=" images"
+ ):
+ for xm, ym, xs, ys in zip(x_ma, y_ma, xslices, yslices):
+ datacube.data[ax, ay, xm, ym] = np.median(datacube.data[ax, ay, xs, ys])
+
+ # Calculate local 3x3 median images
+ # im_med = median_filter(datacube.data[ax, ay, :, :], size=3, mode="nearest")
+ # datacube.data[ax, ay, :, :][mask] = im_med[mask]
+
+ # Return
+ if return_mask is True:
+ return datacube, mask
+ else:
+ return datacube
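+
+# Usage sketch for the function above (illustrative; `dc` is assumed to be
+# a DataCube, and the threshold value is data-dependent):
+#
+#     dc, hot_mask = filter_hot_pixels(dc, thresh=0.1, return_mask=True)
+#     n_hot = hot_mask.sum()  # number of detector pixels flagged as hot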
+
+
+def datacube_diffraction_shift(
+ datacube,
+ xshifts,
+ yshifts,
+ periodic=True,
+ bilinear=False,
+):
+ """
+ This function shifts each 2D diffraction image by the values defined by
+ (xshifts,yshifts). The shift values can be scalars (same shift for all
+ images) or arrays with the same dimensions as the probe positions in
+ datacube.
+
+ Args:
+ datacube (DataCube): py4DSTEM DataCube
+ xshifts (float): Array or scalar value for the x dim shifts
+ yshifts (float): Array or scalar value for the y dim shifts
+        periodic (bool): Flag for periodic boundary conditions. If set to False,
+            boundaries are treated as non-periodic and shifted-in regions are zeroed.
+ bilinear (bool): Flag for bilinear image shifts. If set to False, Fourier shifting is used.
+
+ Returns:
+ datacube (DataCube): py4DSTEM DataCube
+ """
+
+ # if the shift values are constant, expand to arrays
+ xshifts = np.array(xshifts)
+ yshifts = np.array(yshifts)
+ if xshifts.ndim == 0:
+ xshifts = xshifts * np.ones((datacube.R_Nx, datacube.R_Ny))
+ if yshifts.ndim == 0:
+ yshifts = yshifts * np.ones((datacube.R_Nx, datacube.R_Ny))
+
+ # Loop over all images
+ for ax, ay in tqdmnd(
+ *(datacube.R_Nx, datacube.R_Ny), desc="Shifting images", unit=" images"
+ ):
+ datacube.data[ax, ay, :, :] = get_shifted_ar(
+ datacube.data[ax, ay, :, :],
+ xshifts[ax, ay],
+ yshifts[ax, ay],
+ periodic=periodic,
+ bilinear=bilinear,
+ )
+
+ return datacube
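+
+# Usage sketch (illustrative; assumes qx0/qy0 are (R_Nx,R_Ny) arrays of
+# measured origin positions for the DataCube `dc`): recentering every
+# pattern onto the mean origin amounts to applying per-pattern shifts,
+#
+#     xshifts = qx0.mean() - qx0
+#     yshifts = qy0.mean() - qy0
+#     dc = datacube_diffraction_shift(dc, xshifts, yshifts, bilinear=False)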
+
+
+def resample_data_diffraction(
+ datacube, resampling_factor=None, output_size=None, method="bilinear"
+):
+ """
+ Performs diffraction space resampling of data by resampling_factor or to match output_size.
+ """
+ if method == "fourier":
+ from py4DSTEM.process.utils import fourier_resample
+
+ if np.size(resampling_factor) != 1:
+ warnings.warn(
+ (
+ "Fourier resampling currently only accepts a scalar resampling_factor. "
+ f"'resampling_factor' set to {resampling_factor[0]}."
+ ),
+ UserWarning,
+ )
+ resampling_factor = resampling_factor[0]
+
+ old_size = datacube.data.shape
+
+ datacube.data = fourier_resample(
+ datacube.data, scale=resampling_factor, output_size=output_size
+ )
+
+ if not resampling_factor:
+ resampling_factor = output_size[0] / old_size[2]
+ if datacube.calibration.get_Q_pixel_size() is not None:
+ datacube.calibration.set_Q_pixel_size(
+ datacube.calibration.get_Q_pixel_size() / resampling_factor
+ )
+
+ elif method == "bilinear":
+ from scipy.ndimage import zoom
+
+ if resampling_factor is not None:
+ if output_size is not None:
+ raise ValueError(
+ "Only one of 'resampling_factor' or 'output_size' can be specified."
+ )
+
+ resampling_factor = np.array(resampling_factor)
+ if resampling_factor.shape == ():
+ resampling_factor = np.tile(resampling_factor, 2)
+
+ else:
+ if output_size is None:
+ raise ValueError(
+ "At-least one of 'resampling_factor' or 'output_size' must be specified."
+ )
+
+ if len(output_size) != 2:
+ raise ValueError(
+ f"'output_size' must have length 2, not {len(output_size)}"
+ )
+
+ resampling_factor = np.array(output_size) / np.array(datacube.shape[-2:])
+
+ resampling_factor = np.concatenate(((1, 1), resampling_factor))
+ datacube.data = zoom(datacube.data, resampling_factor, order=1)
+ datacube.calibration.set_Q_pixel_size(
+ datacube.calibration.get_Q_pixel_size() / resampling_factor[2]
+ )
+
+ else:
+ raise ValueError(
+ f"'method' needs to be one of 'bilinear' or 'fourier', not {method}."
+ )
+
+ return datacube
+
+
+def pad_data_diffraction(datacube, pad_factor=None, output_size=None):
+ """
+ Performs diffraction space padding of data by pad_factor or to match output_size.
+ """
+ Qx, Qy = datacube.shape[-2:]
+
+ if pad_factor is not None:
+ if output_size is not None:
+ raise ValueError(
+ "Only one of 'pad_factor' or 'output_size' can be specified."
+ )
+
+ pad_factor = np.array(pad_factor)
+ if pad_factor.shape == ():
+ pad_factor = np.tile(pad_factor, 2)
+
+ if np.any(pad_factor < 1):
+ raise ValueError("'pad_factor' needs to be larger than 1.")
+
+ pad_kx = np.round(Qx * (pad_factor[0] - 1) / 2).astype("int")
+ pad_kx = (pad_kx, pad_kx)
+ pad_ky = np.round(Qy * (pad_factor[1] - 1) / 2).astype("int")
+ pad_ky = (pad_ky, pad_ky)
+
+ else:
+ if output_size is None:
+ raise ValueError(
+ "At-least one of 'pad_factor' or 'output_size' must be specified."
+ )
+
+ if len(output_size) != 2:
+ raise ValueError(
+ f"'output_size' must have length 2, not {len(output_size)}"
+ )
+
+ Sx, Sy = output_size
+
+ if Sx < Qx or Sy < Qy:
+ raise ValueError(f"'output_size' must be at-least as large as {(Qx,Qy)}.")
+
+ pad_kx = Sx - Qx
+ pad_kx = (pad_kx // 2, pad_kx // 2 + pad_kx % 2)
+
+ pad_ky = Sy - Qy
+ pad_ky = (pad_ky // 2, pad_ky // 2 + pad_ky % 2)
+
+ pad_width = (
+ (0, 0),
+ (0, 0),
+ pad_kx,
+ pad_ky,
+ )
+
+ datacube.data = np.pad(datacube.data, pad_width=pad_width, mode="constant")
+
+ Qpixsize = datacube.calibration.get_Q_pixel_size()
+ Qpixunits = datacube.calibration.get_Q_pixel_units()
+
+ datacube.set_dim(2, [0, Qpixsize], units=Qpixunits, name="Qx")
+ datacube.set_dim(3, [0, Qpixsize], units=Qpixunits, name="Qy")
+
+ datacube.calibrate()
+
+ return datacube
diff --git a/py4DSTEM/preprocess/radialbkgrd.py b/py4DSTEM/preprocess/radialbkgrd.py
new file mode 100644
index 000000000..da225ed74
--- /dev/null
+++ b/py4DSTEM/preprocess/radialbkgrd.py
@@ -0,0 +1,174 @@
+"""
+Functions for generating radially averaged backgrounds
+"""
+
+import numpy as np
+from scipy.interpolate import interp1d
+from scipy.signal import savgol_filter
+
+
+## Create look up table for background subtraction
+def get_1D_polar_background(
+ data,
+ p_ellipse,
+ center=None,
+ maskUpdateIter=3,
+ min_relative_threshold=4,
+ smoothing=False,
+ smoothingWindowSize=3,
+ smoothingPolyOrder=4,
+ smoothing_log=True,
+ min_background_value=1e-3,
+ return_polararr=False,
+):
+ """
+ Gets the median polar background for a diffraction pattern
+
+ Parameters
+ ----------
+ data : ndarray
+        the data for which to find the polar elliptical background,
+ usually a diffraction pattern
+ p_ellipse : 5-tuple
+ the ellipse parameters (qx0,qy0,a,b,theta)
+ center : 2-tuple or None
+ if None, the center point from `p_ellipse` is used. Otherwise,
+ the center point in `p_ellipse` is ignored, and this argument
+ is used as (qx0,qy0) instead.
+    maskUpdateIter : integer
+        number of iterations of the peak-masking loop
+    min_relative_threshold : float
+        pixels brighter than this multiple of the current background
+        estimate are masked off at each iteration
+ smoothing : bool
+ if true, applies a Savitzky-Golay smoothing filter
+ smoothingWindowSize : integer
+ size of the smoothing window, must be odd number
+ smoothingPolyOrder : number
+ order of the polynomial smoothing to be applied
+ smoothing_log : bool
+ if true log smoothing is performed
+ min_background_value : float
+ if log smoothing is true, a zero value will be replaced with a
+ small nonzero float
+    return_polararr : bool
+ if True the polar transform with the masked high intensity peaks
+ will be returned
+
+ Returns
+ -------
+ 2- or 3-tuple of ndarrays
+ * **background1D**: 1D polar elliptical background
+ * **r_bins**: the elliptically transformed radius associated with
+ background1D
+ * **polarData** (optional): the masked polar transform from which the
+        background is computed, returned iff `return_polararr==True`
+ """
+ from py4DSTEM.process.utils import cartesian_to_polarelliptical_transform
+
+ # assert data is proper form
+ assert isinstance(smoothing, bool), "Smoothing must be bool"
+ assert smoothingWindowSize % 2 == 1, "Smoothing window must be odd"
+ assert isinstance(return_polararr, bool), "return_polararr must be bool"
+
+ # Prepare ellipse params
+ if center is not None:
+        p_ellipse = (
+            center[0], center[1], p_ellipse[2], p_ellipse[3], p_ellipse[4]
+        )
+
+ # Compute Polar Transform
+ polarData, rr, tt = cartesian_to_polarelliptical_transform(data, p_ellipse)
+
+ # Crop polar data to maximum distance which contains information from original image
+ if (polarData.mask.sum(axis=(0)) == polarData.shape[0]).any():
+ ii = polarData.data.shape[1] - 1
+        while polarData.mask[:, ii].all():
+ ii = ii - 1
+ maximalDistance = ii
+ polarData = polarData[:, 0:maximalDistance]
+ r_bins = rr[0, 0:maximalDistance]
+ else:
+ r_bins = rr[0, :]
+
+ # Iteratively mask off high intensity peaks
+ maskPolar = np.copy(polarData.mask)
+ background1D = np.ma.median(polarData, axis=0)
+ for ii in range(maskUpdateIter + 1):
+ if ii > 0:
+ maskUpdate = np.logical_or(
+ maskPolar, polarData / background1D > min_relative_threshold
+ )
+ # Prevent entire columns from being masked off
+ colMaskMin = np.all(maskUpdate, axis=0) # Detect columns that are empty
+ maskUpdate[:, colMaskMin] = polarData.mask[
+ :, colMaskMin
+ ] # reset empty columns to values of previous iterations
+ polarData.mask = maskUpdate # Update Mask
+
+ background1D = np.maximum(background1D, min_background_value)
+
+ if smoothing is True:
+ if smoothing_log is True:
+ background1D = np.log(background1D)
+
+ background1D = savgol_filter(
+ background1D, smoothingWindowSize, smoothingPolyOrder
+ )
+ if smoothing_log is True:
+ background1D = np.exp(background1D)
+ if return_polararr is True:
+ return (background1D, r_bins, polarData)
+ else:
+ return (background1D, r_bins)
+
+
+# Create 2D Background
+def get_2D_polar_background(data, background1D, r_bins, p_ellipse, center=None):
+ """
+ Gets 2D polar elliptical background from linear 1D background
+
+ Parameters
+ ----------
+ data : ndarray
+        the data for which to find the polar elliptical background,
+ usually a diffraction pattern
+ background1D : ndarray
+ a vector representing the radial elliptical background
+ r_bins : ndarray
+ a vector of the elliptically transformed radius associated with
+ background1D
+ p_ellipse : 5-tuple
+ the ellipse parameters (qx0,qy0,a,b,theta)
+ center : 2-tuple or None
+ if None, the center point from `p_ellipse` is used. Otherwise,
+ the center point in `p_ellipse` is ignored, and this argument
+ is used as (qx0,qy0) instead.
+
+ Returns
+ -------
+ ndarray
+ 2D polar elliptical median background image
+ """
+ assert (
+ r_bins.shape == background1D.shape
+ ), "1D background and r_bins must be same length"
+
+ # Prepare ellipse params
+ qx0, qy0, a, b, theta = p_ellipse
+ if center is not None:
+ qx0, qy0 = center
+
+ # Define centered 2D cartesian coordinate system
+ yc, xc = np.meshgrid(
+ np.arange(0, data.shape[1]) - qy0, np.arange(0, data.shape[0]) - qx0
+ )
+
+ # Calculate the semimajor axis distance for each point in the 2D array
+ r = np.sqrt(
+ ((xc * np.cos(theta) + yc * np.sin(theta)) ** 2)
+ + (((xc * np.sin(theta) - yc * np.cos(theta)) ** 2) / ((b / a) ** 2))
+ )
+
+    # Create a 2D elliptical background using linear interpolation
+ f = interp1d(r_bins, background1D, fill_value="extrapolate")
+ background2D = f(r)
+
+ return background2D
diff --git a/py4DSTEM/preprocess/utils.py b/py4DSTEM/preprocess/utils.py
new file mode 100644
index 000000000..829f66608
--- /dev/null
+++ b/py4DSTEM/preprocess/utils.py
@@ -0,0 +1,327 @@
+# Preprocessing utility functions
+
+import numpy as np
+from scipy.ndimage import gaussian_filter
+
+try:
+ import cupy as cp
+except ModuleNotFoundError:
+ cp = np
+
+
+def bin2D(array, factor, dtype=np.float64):
+ """
+ Bin a 2D ndarray by binfactor.
+
+ Args:
+ array (2D numpy array):
+ factor (int): the binning factor
+ dtype (numpy dtype): datatype for binned array. default is numpy default for
+ np.zeros()
+
+ Returns:
+ the binned array
+ """
+ x, y = array.shape
+ binx, biny = x // factor, y // factor
+ xx, yy = binx * factor, biny * factor
+
+ # Make a binned array on the device
+ binned_ar = np.zeros((binx, biny), dtype=dtype)
+ array = array.astype(dtype)
+
+ # Collect pixel sums into new bins
+ for ix in range(factor):
+ for iy in range(factor):
+ binned_ar += array[0 + ix : xx + ix : factor, 0 + iy : yy + iy : factor]
+ return binned_ar
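+
+# Quick numeric check of the function above (illustrative):
+#
+#     ar = np.arange(16).reshape(4, 4)
+#     bin2D(ar, 2)  # -> [[10., 18.], [42., 50.]]: each 2x2 block is summed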
+
+
+def make_Fourier_coords2D(Nx, Ny, pixelSize=1):
+ """
+ Generates Fourier coordinates for a (Nx,Ny)-shaped 2D array.
+ Specifying the pixelSize argument sets a unit size.
+ """
+ if hasattr(pixelSize, "__len__"):
+ assert len(pixelSize) == 2, "pixelSize must either be a scalar or have length 2"
+ pixelSize_x = pixelSize[0]
+ pixelSize_y = pixelSize[1]
+ else:
+ pixelSize_x = pixelSize
+ pixelSize_y = pixelSize
+
+ qx = np.fft.fftfreq(Nx, pixelSize_x)
+ qy = np.fft.fftfreq(Ny, pixelSize_y)
+ qy, qx = np.meshgrid(qy, qx)
+ return qx, qy
+
+
+def get_shifted_ar(ar, xshift, yshift, periodic=True, bilinear=False, device="cpu"):
+ """
+    Shifts array ar by the shift vector (xshift,yshift), using either
+ the Fourier shift theorem (i.e. with sinc interpolation), or bilinear
+ resampling. Boundary conditions can be periodic or not.
+
+ Args:
+ ar (float): input array
+ xshift (float): shift along axis 0 (x) in pixels
+ yshift (float): shift along axis 1 (y) in pixels
+ periodic (bool): flag for periodic boundary conditions
+ bilinear (bool): flag for bilinear image shifts
+        device (str): device on which the calculation is performed. Must be 'cpu' or 'gpu'
+ Returns:
+ (array) the shifted array
+ """
+ if device == "cpu":
+ xp = np
+
+ elif device == "gpu":
+ xp = cp
+
+ ar = xp.asarray(ar)
+
+ # Apply image shift
+ if bilinear is False:
+        nx, ny = ar.shape
+ qx, qy = make_Fourier_coords2D(nx, ny, 1)
+ qx = xp.asarray(qx)
+ qy = xp.asarray(qy)
+
+ w = xp.exp(-(2j * xp.pi) * ((yshift * qy) + (xshift * qx)))
+ shifted_ar = xp.real(xp.fft.ifft2((xp.fft.fft2(ar)) * w))
+
+ else:
+ xF = xp.floor(xshift).astype(int).item()
+ yF = xp.floor(yshift).astype(int).item()
+ wx = xshift - xF
+ wy = yshift - yF
+
+ shifted_ar = (
+ xp.roll(ar, (xF, yF), axis=(0, 1)) * ((1 - wx) * (1 - wy))
+ + xp.roll(ar, (xF + 1, yF), axis=(0, 1)) * ((wx) * (1 - wy))
+ + xp.roll(ar, (xF, yF + 1), axis=(0, 1)) * ((1 - wx) * (wy))
+ + xp.roll(ar, (xF + 1, yF + 1), axis=(0, 1)) * ((wx) * (wy))
+ )
+
+ if periodic is False:
+ # Rounded coordinates for boundaries
+ xR = (xp.round(xshift)).astype(int)
+ yR = (xp.round(yshift)).astype(int)
+
+ if xR > 0:
+ shifted_ar[0:xR, :] = 0
+ elif xR < 0:
+ shifted_ar[xR:, :] = 0
+ if yR > 0:
+ shifted_ar[:, 0:yR] = 0
+ elif yR < 0:
+ shifted_ar[:, yR:] = 0
+
+ return shifted_ar
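+
+# Usage sketch (illustrative; `ar` is assumed to be any 2D float array):
+#
+#     ar_f = get_shifted_ar(ar, 0.5, 0.0)              # sinc (Fourier) shift
+#     ar_b = get_shifted_ar(ar, 3, -2, bilinear=True)  # bilinear shift
+#
+# With periodic=False, pixels wrapped in across a boundary are zeroed out.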
+
+
+def get_maxima_2D(
+ ar,
+ subpixel="poly",
+ upsample_factor=16,
+ sigma=0,
+ minAbsoluteIntensity=0,
+ minRelativeIntensity=0,
+ relativeToPeak=0,
+ minSpacing=0,
+ edgeBoundary=1,
+ maxNumPeaks=1,
+ _ar_FT=None,
+):
+ """
+ Finds the maximal points of a 2D array.
+
+ Args:
+        ar (array): the 2D array
+        subpixel (str): specifies the subpixel resolution algorithm to use.
+            must be in ('pixel','poly','multicorr'), which correspond
+            to pixel resolution, subpixel resolution by fitting a
+            parabola, and subpixel resolution by Fourier upsampling.
+ upsample_factor: the upsampling factor for the 'multicorr'
+ algorithm
+ sigma: if >0, applies a gaussian filter
+ maxNumPeaks: the maximum number of maxima to return
+ minAbsoluteIntensity, minRelativeIntensity, relativeToPeak,
+ minSpacing, edgeBoundary, maxNumPeaks: filtering applied
+ after maximum detection and before subpixel refinement
+        _ar_FT (complex array): if 'multicorr' is used and this is not
+ None, uses this argument as the Fourier transform of `ar`,
+ instead of recomputing it
+
+ Returns:
+ a structured array with fields 'x','y','intensity'
+ """
+ from py4DSTEM.process.utils.multicorr import upsampled_correlation
+
+ subpixel_modes = ("pixel", "poly", "multicorr")
+ er = f"Unrecognized subpixel option {subpixel}. Must be in {subpixel_modes}"
+ assert subpixel in subpixel_modes, er
+
+ # gaussian filtering
+ ar = ar if sigma <= 0 else gaussian_filter(ar, sigma)
+
+ # local pixelwise maxima
+ maxima_bool = (
+ (ar >= np.roll(ar, (-1, 0), axis=(0, 1)))
+ & (ar > np.roll(ar, (1, 0), axis=(0, 1)))
+ & (ar >= np.roll(ar, (0, -1), axis=(0, 1)))
+ & (ar > np.roll(ar, (0, 1), axis=(0, 1)))
+ & (ar >= np.roll(ar, (-1, -1), axis=(0, 1)))
+ & (ar > np.roll(ar, (-1, 1), axis=(0, 1)))
+ & (ar >= np.roll(ar, (1, -1), axis=(0, 1)))
+ & (ar > np.roll(ar, (1, 1), axis=(0, 1)))
+ )
+
+ # remove edges
+ assert isinstance(edgeBoundary, (int, np.integer))
+ if edgeBoundary < 1:
+ edgeBoundary = 1
+ maxima_bool[:edgeBoundary, :] = False
+ maxima_bool[-edgeBoundary:, :] = False
+ maxima_bool[:, :edgeBoundary] = False
+ maxima_bool[:, -edgeBoundary:] = False
+
+ # get indices
+ # sort by intensity
+ maxima_x, maxima_y = np.nonzero(maxima_bool)
+ dtype = np.dtype([("x", float), ("y", float), ("intensity", float)])
+ maxima = np.zeros(len(maxima_x), dtype=dtype)
+ maxima["x"] = maxima_x
+ maxima["y"] = maxima_y
+ maxima["intensity"] = ar[maxima_x, maxima_y]
+ maxima = np.sort(maxima, order="intensity")[::-1]
+
+ if len(maxima) == 0:
+ return maxima
+
+ # filter
+ maxima = filter_2D_maxima(
+ maxima,
+ minAbsoluteIntensity=minAbsoluteIntensity,
+ minRelativeIntensity=minRelativeIntensity,
+ relativeToPeak=relativeToPeak,
+ minSpacing=minSpacing,
+ edgeBoundary=edgeBoundary,
+ maxNumPeaks=maxNumPeaks,
+ )
+
+ if subpixel == "pixel":
+ return maxima
+
+ # Parabolic subpixel refinement
+ for i in range(len(maxima)):
+ Ix1_ = ar[int(maxima["x"][i]) - 1, int(maxima["y"][i])].astype(np.float64)
+ Ix0 = ar[int(maxima["x"][i]), int(maxima["y"][i])].astype(np.float64)
+ Ix1 = ar[int(maxima["x"][i]) + 1, int(maxima["y"][i])].astype(np.float64)
+ Iy1_ = ar[int(maxima["x"][i]), int(maxima["y"][i]) - 1].astype(np.float64)
+ Iy0 = ar[int(maxima["x"][i]), int(maxima["y"][i])].astype(np.float64)
+ Iy1 = ar[int(maxima["x"][i]), int(maxima["y"][i]) + 1].astype(np.float64)
+ deltax = (Ix1 - Ix1_) / (4 * Ix0 - 2 * Ix1 - 2 * Ix1_)
+ deltay = (Iy1 - Iy1_) / (4 * Iy0 - 2 * Iy1 - 2 * Iy1_)
+ maxima["x"][i] += deltax
+ maxima["y"][i] += deltay
+ maxima["intensity"][i] = linear_interpolation_2D(
+ ar, maxima["x"][i], maxima["y"][i]
+ )
+
+ if subpixel == "poly":
+ return maxima
+
+ # Fourier upsampling
+ if _ar_FT is None:
+ _ar_FT = np.fft.fft2(ar)
+ for ipeak in range(len(maxima["x"])):
+ xyShift = np.array((maxima["x"][ipeak], maxima["y"][ipeak]))
+ # we actually have to lose some precision and go down to half-pixel
+ # accuracy for multicorr
+ xyShift[0] = np.round(xyShift[0] * 2) / 2
+ xyShift[1] = np.round(xyShift[1] * 2) / 2
+
+ subShift = upsampled_correlation(_ar_FT, upsample_factor, xyShift)
+ maxima["x"][ipeak] = subShift[0]
+ maxima["y"][ipeak] = subShift[1]
+
+ maxima = np.sort(maxima, order="intensity")[::-1]
+ return maxima
+
+
+def filter_2D_maxima(
+ maxima,
+ minAbsoluteIntensity=0,
+ minRelativeIntensity=0,
+ relativeToPeak=0,
+ minSpacing=0,
+ edgeBoundary=1,
+ maxNumPeaks=1,
+):
+ """
+ Args:
+ maxima : a numpy structured array with fields 'x', 'y', 'intensity'
+ minAbsoluteIntensity : delete counts with intensity below this value
+ minRelativeIntensity : delete counts with intensity below this value times
+ the intensity of the i'th peak, where i is given by `relativeToPeak`
+ relativeToPeak : see above
+ minSpacing : if two peaks are within this euclidean distance from one
+ another, delete the less intense of the two
+ edgeBoundary : delete peaks within this distance of the image edge
+ maxNumPeaks : an integer. defaults to 1
+
+ Returns:
+ a numpy structured array with fields 'x', 'y', 'intensity'
+ """
+
+ # Remove maxima which are too dim
+ if minAbsoluteIntensity > 0:
+ deletemask = maxima["intensity"] < minAbsoluteIntensity
+ maxima = maxima[~deletemask]
+
+ # Remove maxima which are too dim, compared to the n-th brightest
+ if (minRelativeIntensity > 0) & (len(maxima) > relativeToPeak):
+ assert isinstance(relativeToPeak, (int, np.integer))
+ deletemask = (
+ maxima["intensity"] / maxima["intensity"][relativeToPeak]
+ < minRelativeIntensity
+ )
+ maxima = maxima[~deletemask]
+
+ # Remove maxima which are too close
+ if minSpacing > 0:
+ deletemask = np.zeros(len(maxima), dtype=bool)
+ for i in range(len(maxima)):
+            if not deletemask[i]:
+ tooClose = (
+ (maxima["x"] - maxima["x"][i]) ** 2
+ + (maxima["y"] - maxima["y"][i]) ** 2
+ ) < minSpacing**2
+ tooClose[: i + 1] = False
+ deletemask[tooClose] = True
+ maxima = maxima[~deletemask]
+
+ # Remove maxima in excess of maxNumPeaks
+ if maxNumPeaks is not None:
+ if len(maxima) > maxNumPeaks:
+ maxima = maxima[:maxNumPeaks]
+
+ return maxima
+
+
+def linear_interpolation_2D(ar, x, y):
+ """
+ Calculates the 2D linear interpolation of array ar at position x,y using the four
+ nearest array elements.
+ """
+ x0, x1 = int(np.floor(x)), int(np.ceil(x))
+ y0, y1 = int(np.floor(y)), int(np.ceil(y))
+ dx = x - x0
+ dy = y - y0
+ return (
+ (1 - dx) * (1 - dy) * ar[x0, y0]
+ + (1 - dx) * dy * ar[x0, y1]
+ + dx * (1 - dy) * ar[x1, y0]
+ + dx * dy * ar[x1, y1]
+ )
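+
+# Quick numeric check of the function above (illustrative):
+#
+#     ar = np.array([[0.0, 1.0], [2.0, 3.0]])
+#     linear_interpolation_2D(ar, 0.5, 0.5)  # -> 1.5, the mean of the 4 corners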
diff --git a/py4DSTEM/process/__init__.py b/py4DSTEM/process/__init__.py
new file mode 100644
index 000000000..0509d181e
--- /dev/null
+++ b/py4DSTEM/process/__init__.py
@@ -0,0 +1,9 @@
+from py4DSTEM.process.polar import PolarDatacube
+from py4DSTEM.process.strain.strain import StrainMap
+
+from py4DSTEM.process import phase
+from py4DSTEM.process import calibration
+from py4DSTEM.process import utils
+from py4DSTEM.process import classification
+from py4DSTEM.process import diffraction
+from py4DSTEM.process import wholepatternfit
diff --git a/py4DSTEM/process/calibration/__init__.py b/py4DSTEM/process/calibration/__init__.py
new file mode 100644
index 000000000..2f8de9c0d
--- /dev/null
+++ b/py4DSTEM/process/calibration/__init__.py
@@ -0,0 +1,5 @@
+from py4DSTEM.process.calibration.qpixelsize import *
+from py4DSTEM.process.calibration.origin import *
+from py4DSTEM.process.calibration.ellipse import *
+from py4DSTEM.process.calibration.rotation import *
+from py4DSTEM.process.calibration.probe import *
diff --git a/py4DSTEM/process/calibration/ellipse.py b/py4DSTEM/process/calibration/ellipse.py
new file mode 100644
index 000000000..e6a216cf1
--- /dev/null
+++ b/py4DSTEM/process/calibration/ellipse.py
@@ -0,0 +1,320 @@
+"""
+Functions related to elliptical calibration, such as fitting elliptical
+distortions.
+
+The user-facing representation of ellipses is in terms of the following 5
+parameters:
+
+ x0,y0 the center of the ellipse
+ a the semimajor axis length
+ b the semiminor axis length
+ theta the (positive, right handed) tilt of the a-axis
+ to the x-axis, in radians
+
+More details about the elliptical parameterization used can be found in
+the module docstring for process/utils/elliptical_coords.py.
+"""
+
+import numpy as np
+from scipy.optimize import leastsq
+from scipy.ndimage import gaussian_filter
+
+from py4DSTEM.process.utils import convert_ellipse_params, convert_ellipse_params_r
+from py4DSTEM.process.utils import get_CoM, radial_integral
+
+###### Fitting a 1d elliptical curve to a 2d array, e.g. a Bragg vector map ######
+
+
+def fit_ellipse_1D(ar, center=None, fitradii=None, mask=None):
+ """
+ For a 2d array ar, fits a 1d elliptical curve to the data inside an annulus centered
+    at `center` with inner and outer radii at `fitradii`. The data to fit may optionally
+    be further masked with the boolean array mask. See module docstring for more info.
+
+ Args:
+ ar (ndarray): array containing the data to fit
+ center (2-tuple of floats): the center (x0,y0) of the annular fitting region
+ fitradii (2-tuple of floats): inner and outer radii (ri,ro) of the fit region
+ mask (ar-shaped ndarray of bools): ignore data wherever mask==True
+
+ Returns:
+ (5-tuple of floats): A 5-tuple containing the ellipse parameters:
+ * **x0**: the center x-position
+ * **y0**: the center y-position
+ * **a**: the semimajor axis length
+ * **b**: the semiminor axis length
+ * **theta**: the tilt of the ellipse semimajor axis with respect to the
+ x-axis, in radians
+ """
+
+ # Default values
+ if center is None:
+ center = (ar.shape[0] / 2, ar.shape[1] / 2)
+
+ if fitradii is None:
+        fitradii = (0, np.min(ar.shape) / 2)
+
+ # Unpack inputs
+ x0, y0 = center
+ ri, ro = fitradii
+
+ # Get the datapoints to fit
+ yy, xx = np.meshgrid(np.arange(ar.shape[1]), np.arange(ar.shape[0]))
+ rr = np.sqrt((xx - x0) ** 2 + (yy - y0) ** 2)
+ _mask = (rr > ri) * (rr <= ro)
+ if mask is not None:
+        _mask *= mask == False  # noqa E712
+ xs, ys = np.nonzero(_mask)
+ vals = ar[_mask]
+
+ # Get initial parameters guess
+ p0 = [x0, y0, (2 / (ri + ro)) ** 2, 0, (2 / (ri + ro)) ** 2]
+
+ # Fit
+ x, y, A, B, C = leastsq(ellipse_err, p0, args=(xs, ys, vals))[0]
+
+ # Convert ellipse params
+ a, b, theta = convert_ellipse_params(A, B, C)
+
+ return x, y, a, b, theta
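+
+# Usage sketch (illustrative; `bvm` is assumed to be a 2D Bragg vector map
+# with an approximately known center and an annulus containing one ring):
+#
+#     x0, y0, a, b, theta = fit_ellipse_1D(bvm, center=(qx0, qy0), fitradii=(ri, ro))
+#     # a/b measures the elliptical distortion; theta orients the semimajor axis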
+
+
+def ellipse_err(p, x, y, val):
+ """
+ For a point (x,y) in a 2d cartesian space, and a function taking the value
+ val at point (x,y), and some 1d ellipse in this space given by
+ ``A(x-x0)^2 + B(x-x0)(y-y0) + C(y-y0)^2 = 1``
+ this function computes the error associated with the function's value at (x,y)
+ given by its deviation from the ellipse times val.
+
+ Note that this function is for internal use, and uses ellipse parameters `p`
+ given in canonical form (x0,y0,A,B,C), which is different from the ellipse
+ parameterization used in all the user-facing functions, for reasons of
+ numerical stability.
+ """
+ x, y = x - p[0], y - p[1]
+ return (p[2] * x**2 + p[3] * x * y + p[4] * y**2 - 1) * val
+
+
+###### Fitting from amorphous diffraction rings ######
+
+
+def fit_ellipse_amorphous_ring(data, center, fitradii, p0=None, mask=None):
+ """
+ Fit the amorphous halo of a diffraction pattern, including any elliptical distortion.
+
+ The fit function is::
+
+ f(x,y; I0,I1,sigma0,sigma1,sigma2,c_bkgd,x0,y0,A,B,C) =
+ Norm(r; I0,sigma0,0) +
+ Norm(r; I1,sigma1,R)*Theta(r-R)
+ Norm(r; I1,sigma2,R)*Theta(R-r) + c_bkgd
+
+ where
+
+ * (x,y) are cartesian coordinates,
+ * r is the radial coordinate,
+    * (I0,I1,sigma0,sigma1,sigma2,c_bkgd,x0,y0,A,B,C) are parameters,
+ * Norm(x;I,s,u) is a gaussian in the variable x with maximum amplitude I,
+ standard deviation s, and mean u
+    * Theta(x) is a Heaviside step function
+ * R is the radial center of the double sided gaussian, derived from (A,B,C)
+ and set to the mean of the semiaxis lengths
+
+ The function thus contains a pair of gaussian-shaped peaks along the radial
+ direction of a polar-elliptical parametrization of a 2D plane. The first gaussian is
+ centered at the origin. The second gaussian is centered about some finite R, and is
+ 'two-faced': it's comprised of two half-gaussians of different standard deviations,
+ stitched together at their mean value of R. This Janus (two-faced ;p) gaussian thus
+ comprises an elliptical ring with different inner and outer widths.
+
+ The parameters of the fit function are
+
+ * I0: the intensity of the first gaussian function
+ * I1: the intensity of the Janus gaussian
+ * sigma0: std of first gaussian
+ * sigma1: inner std of Janus gaussian
+ * sigma2: outer std of Janus gaussian
+ * c_bkgd: a constant offset
+ * x0,y0: the origin
+ * A,B,C: The ellipse parameters, in the form Ax^2 + Bxy + Cy^2 = 1
+
+ Args:
+ data (2d array): the data
+ center (2-tuple of numbers): the center (x0,y0)
+ fitradii (2-tuple of numbers): the inner and outer radii of the fitting annulus
+ p0 (11-tuple): initial guess parameters. If p0 is None, the function will compute
+ a guess at all parameters. If p0 is a 11-tuple it must be populated by some
+ mix of numbers and None; any parameters which are set to None will be guessed
+ by the function. The parameters are the 11 parameters of the fit function
+ described above, p0 = (I0,I1,sigma0,sigma1,sigma2,c_bkgd,x0,y0,A,B,C).
+ Note that x0,y0 are redundant; their guess values are the x0,y0 values passed
+ to the main function, but if they are passed as elements of p0 these will
+            take precedence.
+ mask (2d array of bools): only fit to datapoints where mask is True
+
+ Returns:
+ (2-tuple comprised of a 5-tuple and an 11-tuple): Returns a 2-tuple.
+
+        The first element is the ellipse parameters needed to elliptically parametrize
+ diffraction space, and is itself a 5-tuple:
+
+ * **x0**: x center
+ * **y0**: y center,
+ * **a**: the semimajor axis length
+ * **b**: the semiminor axis length
+ * **theta**: tilt of a-axis w.r.t x-axis, in radians
+
+ The second element is the full set of fit parameters to the double sided gaussian
+ function, described above, and is an 11-tuple
+ """
+ if mask is None:
+ mask = np.ones_like(data).astype(bool)
+ assert data.shape == mask.shape, "data and mask must have same shapes."
+ x0, y0 = center
+ ri, ro = fitradii
+
+ # Get data mask
+ Nx, Ny = data.shape
+ yy, xx = np.meshgrid(np.arange(Ny), np.arange(Nx))
+ rr = np.hypot(xx - x0, yy - y0)
+ _mask = ((rr > ri) * (rr < ro)).astype(bool)
+ _mask *= mask
+
+ # Make coordinates, get data values
+ x_inds, y_inds = np.nonzero(_mask)
+ vals = data[_mask]
+
+ # Get initial parameter guesses
+ I0 = np.max(data)
+ I1 = np.max(data * mask)
+ sigma0 = ri / 2.0
+ sigma1 = (ro - ri) / 4.0
+ sigma2 = (ro - ri) / 4.0
+ c_bkgd = np.min(data)
+ # To guess R, we take a radial integral
+ q, radial_profile = radial_integral(data, x0, y0, 1)
+ R = q[(q > ri) * (q < ro)][np.argmax(radial_profile[(q > ri) * (q < ro)])]
+ # Initial guess at A,B,C
+ A, B, C = convert_ellipse_params_r(R, R, 0)
+
+ # Populate initial parameters
+ p0_guess = tuple([I0, I1, sigma0, sigma1, sigma2, c_bkgd, x0, y0, A, B, C])
+ if p0 is None:
+ _p0 = p0_guess
+ else:
+ assert len(p0) == 11
+ _p0 = tuple([p0_guess[i] if p0[i] is None else p0[i] for i in range(len(p0))])
+
+ # Perform fit
+ p = leastsq(double_sided_gaussian_fiterr, _p0, args=(x_inds, y_inds, vals))[0]
+
+ # Return
+ _x0, _y0 = p[6], p[7]
+ _A, _B, _C = p[8], p[9], p[10]
+ _a, _b, _theta = convert_ellipse_params(_A, _B, _C)
+ return (_x0, _y0, _a, _b, _theta), p
+
+
+def double_sided_gaussian_fiterr(p, x, y, val):
+ """
+ Returns the fit error associated with a point (x,y) with value val, given parameters p.
+ """
+ return double_sided_gaussian(p, x, y) - val
+
+
+def double_sided_gaussian(p, x, y):
+ """
+ Return the value of the double-sided gaussian function at point (x,y) given
+ parameters p, described in detail in the fit_ellipse_amorphous_ring docstring.
+ """
+ # Unpack parameters
+ I0, I1, sigma0, sigma1, sigma2, c_bkgd, x0, y0, A, B, C = p
+ a, b, theta = convert_ellipse_params(A, B, C)
+ R = np.mean((a, b))
+ R2 = R**2
+ A, B, C = A * R2, B * R2, C * R2
+ r2 = A * (x - x0) ** 2 + B * (x - x0) * (y - y0) + C * (y - y0) ** 2
+ r = np.sqrt(r2) - R
+
+ return (
+ I0 * np.exp(-r2 / (2 * sigma0**2))
+ + I1 * np.exp(-(r**2) / (2 * sigma1**2)) * np.heaviside(-r, 0.5)
+ + I1 * np.exp(-(r**2) / (2 * sigma2**2)) * np.heaviside(r, 0.5)
+ + c_bkgd
+ )
+
+
+### Fit an ellipse to crystalline scattering with a known angle between peaks
+
+
+def constrain_degenerate_ellipse(
+ data, p_ellipse, r_inner, r_outer, phi_known, fitrad=6
+):
+ """
+ When fitting an ellipse to data containing 4 diffraction spots in a narrow annulus
+ about the central beam, the answer is degenerate: an infinite number of ellipses
+ correctly fit this data. Starting from one ellipse in the degenerate family of
+ ellipses, this function selects the ellipse which will yield a final angle of
+ phi_known between a pair of the diffraction peaks after performing elliptical
+ distortion correction.
+
+ Note that there are two possible angles which phi_known might refer to, because the
+ angle of interest is well defined up to a complementary angle. This function is
+ written such that phi_known should be the smaller of these two angles.
+
+ Args:
+        data (ndarray): the data to fit, typically a Bragg vector map
+ p_ellipse (5-tuple): the ellipse parameters (x0,y0,a,b,theta)
+ r_inner (float): the fitting annulus inner radius
+ r_outer (float): the fitting annulus outer radius
+ phi_known (float): the known angle between a pair of diffraction peaks, in
+ radians
+ fitrad (float): the region about the fixed data point used to refine its position
+
+ Returns:
+ (2-tuple): A 2-tuple containing:
+
+ * **a_constrained**: *(float)* the first semiaxis of the selected ellipse
+ * **b_constrained**: *(float)* the second semiaxis of the selected ellipse
+ """
+ # Unpack ellipse params
+ x, y, a, b, theta = p_ellipse
+
+ # Get 4 constraining points
+ xs, ys = np.zeros(4), np.zeros(4)
+ yy, xx = np.meshgrid(np.arange(data.shape[1]), np.arange(data.shape[0]))
+ rr = np.sqrt((xx - x) ** 2 + (yy - y) ** 2)
+ annular_mask = (rr > r_inner) * (rr <= r_outer)
+ data_temp = np.zeros_like(data)
+ data_temp[annular_mask] = data[annular_mask]
+ for i in range(4):
+ x_constr, y_constr = np.unravel_index(
+ np.argmax(gaussian_filter(data_temp, 2)), (data.shape[0], data.shape[1])
+ )
+ rr = np.sqrt((xx - x_constr) ** 2 + (yy - y_constr) ** 2)
+ mask = rr < fitrad
+ xs[i], ys[i] = get_CoM(data * mask)
+ data_temp[mask] = 0
+
+ # Transform constraining points coordinate system
+ xs -= x
+ ys -= y
+ T = np.squeeze(
+ np.array([[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]])
+ )
+ xs, ys = np.matmul(T, np.array([xs, ys]))
+
+ # Get symmetrized constraining point
+ angles = np.arctan2(ys, xs)
+ distances = np.hypot(xs, ys)
+ angle = np.mean(np.min(np.vstack([np.abs(angles), np.pi - np.abs(angles)]), axis=0))
+ distance = np.mean(distances)
+ x_fixed, y_fixed = distance * np.cos(angle), distance * np.sin(angle)
+
+ # Get semiaxes a,b for the specified theta
+ t = x_fixed / (a * np.cos(phi_known / 2.0))
+ a_constrained = a * t
+ b_constrained = np.sqrt(y_fixed**2 / (1 - (x_fixed / (a_constrained)) ** 2))
+
+ return a_constrained, b_constrained
diff --git a/py4DSTEM/process/calibration/origin.py b/py4DSTEM/process/calibration/origin.py
new file mode 100644
index 000000000..a0717e321
--- /dev/null
+++ b/py4DSTEM/process/calibration/origin.py
@@ -0,0 +1,366 @@
+# Find the origin of diffraction space
+
+import functools
+import numpy as np
+from scipy.ndimage import gaussian_filter
+from scipy.optimize import leastsq
+
+from emdfile import tqdmnd, PointListArray
+from py4DSTEM.datacube import DataCube
+from py4DSTEM.process.calibration.probe import get_probe_size
+from py4DSTEM.process.fit import plane, parabola, bezier_two, fit_2D
+from py4DSTEM.process.utils import get_CoM, add_to_2D_array_from_floats, get_maxima_2D
+
+
+#
+# # origin setting decorators
+#
+# def set_measured_origin(fun):
+# """
+# This is intended as a decorator function to wrap other functions which measure
+# the position of the origin. If some function `get_the_origin` returns the
+# position of the origin as a tuple of two (R_Nx,R_Ny)-shaped arrays, then
+# decorating the function definition like
+#
+# >>> @measure_origin
+# >>> def get_the_origin(...):
+#
+# will make the function also save those arrays as the measured origin in the
+# calibration associated with the data used for the measurement. Any existing
+# measured origin value will be overwritten.
+#
+# For the wrapper to work, the decorated function's first argument must have
+# a .calibration property, and its first two return values must be qx0,qy0.
+# """
+# @functools.wraps(fun)
+# def wrapper(*args,**kwargs):
+# ans = fun(*args,**kwargs)
+# data = args[0]
+# cali = data.calibration
+# cali.set_origin_meas((ans[0],ans[1]))
+# return ans
+# return wrapper
+#
+#
+# def set_fit_origin(fun):
+# """
+# See docstring for `set_measured_origin`
+# """
+# @functools.wraps(fun)
+# def wrapper(*args,**kwargs):
+# ans = fun(*args,**kwargs)
+# data = args[0]
+# cali = data.calibration
+# cali.set_origin((ans[0],ans[1]))
+# return ans
+# return wrapper
+#
+
+
+# fit the origin
+
+
+def fit_origin(
+ data,
+ mask=None,
+ fitfunction="plane",
+ returnfitp=False,
+ robust=False,
+ robust_steps=3,
+ robust_thresh=2,
+):
+ """
+ Fits the position of the origin of diffraction space to a plane or parabola,
+ given some 2D arrays (qx0_meas,qy0_meas) of measured center positions,
+ optionally masked by the Boolean array `mask`. The 2D data arrays may be
+ passed directly as a 2-tuple to the arg `data`, or, if `data` is either a
+    DataCube or Calibration instance, they will be retrieved automatically. If a
+ DataCube or Calibration are passed, fitted origin and residuals are stored
+ there directly.
+
+ Args:
+ data (2-tuple of 2d arrays): the measured origin position (qx0,qy0)
+        mask (2D boolean array, optional): ignore points where mask=False
+ fitfunction (str, optional): must be 'plane' or 'parabola' or 'bezier_two'
+ or 'constant'
+ returnfitp (bool, optional): if True, returns the fit parameters
+ robust (bool, optional): If set to True, fit will be repeated with outliers
+ removed.
+ robust_steps (int, optional): Optional parameter. Number of robust iterations
+ performed after initial fit.
+ robust_thresh (int, optional): Threshold for including points, in units of
+ root-mean-square (standard deviations) error of the predicted values after
+ fitting.
+
+ Returns:
+ (variable): Return value depends on returnfitp. If ``returnfitp==False``
+ (default), returns a 4-tuple containing:
+
+ * **qx0_fit**: *(ndarray)* the fit origin x-position
+ * **qy0_fit**: *(ndarray)* the fit origin y-position
+ * **qx0_residuals**: *(ndarray)* the x-position fit residuals
+ * **qy0_residuals**: *(ndarray)* the y-position fit residuals
+
+ If ``returnfitp==True``, returns a 2-tuple. The first element is the 4-tuple
+ described above. The second element is a 4-tuple (popt_x,popt_y,pcov_x,pcov_y)
+ giving fit parameters and covariance matrices with respect to the chosen
+ fitting function.
+ """
+ assert isinstance(data, tuple) and len(data) == 2
+ qx0_meas, qy0_meas = data
+    assert isinstance(qx0_meas, np.ndarray) and len(qx0_meas.shape) == 2
+    assert isinstance(qy0_meas, np.ndarray) and len(qy0_meas.shape) == 2
+ assert qx0_meas.shape == qy0_meas.shape
+    assert mask is None or (mask.shape == qx0_meas.shape and mask.dtype == bool)
+ assert fitfunction in ("plane", "parabola", "bezier_two", "constant")
+ if fitfunction == "constant":
+ qx0_fit = np.mean(qx0_meas) * np.ones_like(qx0_meas)
+ qy0_fit = np.mean(qy0_meas) * np.ones_like(qy0_meas)
+ else:
+ if fitfunction == "plane":
+ f = plane
+ elif fitfunction == "parabola":
+ f = parabola
+ elif fitfunction == "bezier_two":
+ f = bezier_two
+ else:
+ raise Exception("Invalid fitfunction '{}'".format(fitfunction))
+
+ # Check if mask for data is stored in (qx0_meax,qy0_meas) as a masked array
+ if isinstance(qx0_meas, np.ma.MaskedArray):
+ mask = np.ma.getmask(qx0_meas)
+
+ # Fit data
+ if mask is None:
+ popt_x, pcov_x, qx0_fit, _ = fit_2D(
+ f,
+ qx0_meas,
+ robust=robust,
+ robust_steps=robust_steps,
+ robust_thresh=robust_thresh,
+ )
+ popt_y, pcov_y, qy0_fit, _ = fit_2D(
+ f,
+ qy0_meas,
+ robust=robust,
+ robust_steps=robust_steps,
+ robust_thresh=robust_thresh,
+ )
+
+ else:
+ popt_x, pcov_x, qx0_fit, _ = fit_2D(
+ f,
+ qx0_meas,
+ robust=robust,
+ robust_steps=robust_steps,
+ robust_thresh=robust_thresh,
+ data_mask=mask == True, # noqa E712
+ )
+ popt_y, pcov_y, qy0_fit, _ = fit_2D(
+ f,
+ qy0_meas,
+ robust=robust,
+ robust_steps=robust_steps,
+ robust_thresh=robust_thresh,
+ data_mask=mask == True, # noqa E712
+ )
+
+ # Compute residuals
+ qx0_residuals = qx0_meas - qx0_fit
+ qy0_residuals = qy0_meas - qy0_fit
+
+ # Return
+ ans = (qx0_fit, qy0_fit, qx0_residuals, qy0_residuals)
+ if returnfitp:
+ return ans, (popt_x, popt_y, pcov_x, pcov_y)
+ else:
+ return ans
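+
+# Usage sketch (illustrative; qx0_meas/qy0_meas are assumed to be 2D arrays
+# of measured origin positions, e.g. from get_origin below):
+#
+#     qx0_fit, qy0_fit, qx0_res, qy0_res = fit_origin(
+#         (qx0_meas, qy0_meas), fitfunction="plane", robust=True
+#     )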
+
+
+### Functions for finding the origin
+
+# for a diffraction pattern
+
+
+def get_origin_single_dp(dp, r, rscale=1.2):
+ """
+ Find the origin for a single diffraction pattern, assuming (a) there is no beam stop,
+ and (b) the center beam contains the highest intensity.
+
+ Args:
+ dp (ndarray): the diffraction pattern
+ r (number): the approximate disk radius
+ rscale (number): factor by which `r` is scaled to generate a mask
+
+ Returns:
+ (2-tuple): The origin
+ """
+ Q_Nx, Q_Ny = dp.shape
+ _qx0, _qy0 = np.unravel_index(np.argmax(gaussian_filter(dp, r)), (Q_Nx, Q_Ny))
+ qyy, qxx = np.meshgrid(np.arange(Q_Ny), np.arange(Q_Nx))
+ mask = np.hypot(qxx - _qx0, qyy - _qy0) < r * rscale
+ qx0, qy0 = get_CoM(dp * mask)
+ return qx0, qy0
+
+
+# for a datacube
+
+
+def get_origin(
+ datacube,
+ r=None,
+ rscale=1.2,
+ dp_max=None,
+ mask=None,
+ fast_center=False,
+):
+ """
+ Find the origin for all diffraction patterns in a datacube, assuming (a) there is no
+ beam stop, and (b) the center beam contains the highest intensity. Stores the origin
+ positions in the Calibration associated with datacube, and optionally also returns
+ them.
+
+ Args:
+ datacube (DataCube): the data
+ r (number or None): the approximate radius of the center disk. If None (default),
+ tries to compute r using the get_probe_size method. The data used for this
+ is controlled by dp_max.
+ rscale (number): expand 'r' by this amount to form a mask about the center disk
+ when taking its center of mass
+ dp_max (ndarray or None): the diffraction pattern or dp-shaped array used to
+ compute the center disk radius, if r is left unspecified. Behavior depends
+ on type:
+
+ * if ``dp_max==None`` (default), computes and uses the maximal
+ diffraction pattern. Note that for a large datacube, this may be a
+ slow operation.
+ * otherwise, this should be a (Q_Nx,Q_Ny) shaped array
+ mask (ndarray or None): if not None, should be an (R_Nx,R_Ny) shaped
+ boolean array. Origin is found only where mask==True, and masked
+ arrays are returned for qx0,qy0
+ fast_center: (bool)
+ Skip the center of mass refinement step.
+
+ Returns:
+ (2-tuple of (R_Nx,R_Ny)-shaped ndarrays): the origin, (x,y) at each scan position
+ """
+ if r is None:
+ if dp_max is None:
+ dp_max = np.max(datacube.data, axis=(0, 1))
+ else:
+ assert dp_max.shape == (datacube.Q_Nx, datacube.Q_Ny)
+ r, _, _ = get_probe_size(dp_max)
+
+ qx0 = np.zeros((datacube.R_Nx, datacube.R_Ny))
+ qy0 = np.zeros((datacube.R_Nx, datacube.R_Ny))
+ qyy, qxx = np.meshgrid(np.arange(datacube.Q_Ny), np.arange(datacube.Q_Nx))
+
+ if mask is None:
+ for rx, ry in tqdmnd(
+ datacube.R_Nx,
+ datacube.R_Ny,
+ desc="Finding origins",
+ unit="DP",
+ unit_scale=True,
+ ):
+ dp = datacube.data[rx, ry, :, :]
+ _qx0, _qy0 = np.unravel_index(
+ np.argmax(gaussian_filter(dp, r, mode="nearest")),
+ (datacube.Q_Nx, datacube.Q_Ny),
+ )
+ if fast_center:
+ qx0[rx, ry], qy0[rx, ry] = _qx0, _qy0
+ else:
+ _mask = np.hypot(qxx - _qx0, qyy - _qy0) < r * rscale
+ qx0[rx, ry], qy0[rx, ry] = get_CoM(dp * _mask)
+
+ else:
+ assert mask.shape == (datacube.R_Nx, datacube.R_Ny)
+ assert mask.dtype == bool
+ qx0 = np.ma.array(
+ data=qx0, mask=np.zeros((datacube.R_Nx, datacube.R_Ny), dtype=bool)
+ )
+ qy0 = np.ma.array(
+ data=qy0, mask=np.zeros((datacube.R_Nx, datacube.R_Ny), dtype=bool)
+ )
+ for rx, ry in tqdmnd(
+ datacube.R_Nx,
+ datacube.R_Ny,
+ desc="Finding origins",
+ unit="DP",
+ unit_scale=True,
+ ):
+ if mask[rx, ry]:
+ dp = datacube.data[rx, ry, :, :]
+ _qx0, _qy0 = np.unravel_index(
+ np.argmax(gaussian_filter(dp, r, mode="nearest")),
+ (datacube.Q_Nx, datacube.Q_Ny),
+ )
+ if fast_center:
+ qx0[rx, ry], qy0[rx, ry] = _qx0, _qy0
+ else:
+ _mask = np.hypot(qxx - _qx0, qyy - _qy0) < r * rscale
+ qx0.data[rx, ry], qy0.data[rx, ry] = get_CoM(dp * _mask)
+ else:
+                qx0.mask[rx, ry], qy0.mask[rx, ry] = True, True
+
+ # return
+    if mask is None:
+        mask = np.ones(datacube.Rshape, dtype=bool)
+ return qx0, qy0, mask
+
+
+def get_origin_single_dp_beamstop(DP: np.ndarray, mask: np.ndarray, **kwargs):
+ """
+ Find the origin for a single diffraction pattern, assuming there is a beam stop.
+
+ Args:
+ DP (np array): diffraction pattern
+ mask (np array): boolean mask which is False under the beamstop and True
+ in the diffraction pattern. One approach to generating this mask
+ is to apply a suitable threshold on the average diffraction pattern
+        and use binary opening/closing to remove any holes
+
+ Returns:
+ qx0, qy0 (tuple) measured center position of diffraction pattern
+ """
+
+ imCorr = np.real(
+ np.fft.ifft2(
+ np.fft.fft2(DP * mask)
+ * np.conj(np.fft.fft2(np.rot90(DP, 2) * np.rot90(mask, 2)))
+ )
+ )
+
+ xp, yp = np.unravel_index(np.argmax(imCorr), imCorr.shape)
+
+ dx = ((xp + DP.shape[0] / 2) % DP.shape[0]) - DP.shape[0] / 2
+ dy = ((yp + DP.shape[1] / 2) % DP.shape[1]) - DP.shape[1] / 2
+
+ return (DP.shape[0] + dx) / 2, (DP.shape[1] + dy) / 2
+
+
+def get_origin_beamstop(datacube: DataCube, mask: np.ndarray, **kwargs):
+ """
+ Find the origin for each diffraction pattern, assuming there is a beam stop.
+
+ Args:
+ datacube (DataCube)
+ mask (np array): boolean mask which is False under the beamstop and True
+ in the diffraction pattern. One approach to generating this mask
+ is to apply a suitable threshold on the average diffraction pattern
+ and use binary opening/closing to remove any holes
+
+ Returns:
+ qx0, qy0 (tuple of np arrays) measured center position of each diffraction pattern
+ """
+
+ qx0 = np.zeros(datacube.data.shape[:2])
+ qy0 = np.zeros_like(qx0)
+
+ for rx, ry in tqdmnd(datacube.R_Nx, datacube.R_Ny):
+ x, y = get_origin_single_dp_beamstop(datacube.data[rx, ry, :, :], mask)
+
+ qx0[rx, ry] = x
+ qy0[rx, ry] = y
+
+ return qx0, qy0
diff --git a/py4DSTEM/process/calibration/probe.py b/py4DSTEM/process/calibration/probe.py
new file mode 100644
index 000000000..dc0a38949
--- /dev/null
+++ b/py4DSTEM/process/calibration/probe.py
@@ -0,0 +1,62 @@
+import numpy as np
+
+from py4DSTEM.process.utils import get_CoM
+
+
+def get_probe_size(DP, thresh_lower=0.01, thresh_upper=0.99, N=100):
+ """
+ Gets the center and radius of the probe in the diffraction plane.
+
+ The algorithm is as follows:
+ First, create a series of N binary masks, by thresholding the diffraction pattern
+ DP with a linspace of N thresholds from thresh_lower to thresh_upper, measured
+ relative to the maximum intensity in DP.
+ Using the area of each binary mask, calculate the radius r of a circular probe.
+ Because the central disk is typically very intense relative to the rest of the DP, r
+ should change very little over a wide range of intermediate values of the threshold.
+    The range in which r is trustworthy is found by taking the derivative of r(thresh)
+    and identifying where it is small. The radius is taken to be the mean of these
+    r values. Using the threshold corresponding to this r, a mask is created, and
+    the CoM of the DP times this mask is computed. This CoM is taken to be the
+    origin (x0,y0).
+
+ Args:
+ DP (2D array): the diffraction pattern in which to find the central disk.
+ A position averaged, or shift-corrected and averaged, DP works best.
+ thresh_lower (float, 0 to 1): the lower limit of threshold values
+ thresh_upper (float, 0 to 1): the upper limit of threshold values
+ N (int): the number of thresholds / masks to use
+
+ Returns:
+ (3-tuple): A 3-tuple containing:
+
+ * **r**: *(float)* the central disk radius, in pixels
+ * **x0**: *(float)* the x position of the central disk center
+ * **y0**: *(float)* the y position of the central disk center
+ """
+ from py4DSTEM.braggvectors import Probe
+
+ # parse input
+ if isinstance(DP, Probe):
+ DP = DP.probe
+
+ thresh_vals = np.linspace(thresh_lower, thresh_upper, N)
+ r_vals = np.zeros(N)
+
+ # Get r for each mask
+ DPmax = np.max(DP)
+ for i in range(len(thresh_vals)):
+ thresh = thresh_vals[i]
+ mask = DP > DPmax * thresh
+ r_vals[i] = np.sqrt(np.sum(mask) / np.pi)
+
+ # Get derivative and determine trustworthy r-values
+ dr_dtheta = np.gradient(r_vals)
+ mask = (dr_dtheta <= 0) * (dr_dtheta >= 2 * np.median(dr_dtheta))
+ r = np.mean(r_vals[mask])
+
+ # Get origin
+ thresh = np.mean(thresh_vals[mask])
+ mask = DP > DPmax * thresh
+ x0, y0 = get_CoM(DP * mask)
+
+ return r, x0, y0
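+
+
+if __name__ == "__main__":
+    # Editor's usage sketch (not part of the library API): a quick check of
+    # get_probe_size on a synthetic binary probe disk of radius 10 pixels
+    # centered at (64, 64).
+    xx, yy = np.indices((128, 128))
+    DP = (np.hypot(xx - 64, yy - 64) < 10).astype(float)
+    r, x0, y0 = get_probe_size(DP)
+    print(r, x0, y0)  # expect r close to 10 and (x0, y0) close to (64, 64)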
diff --git a/py4DSTEM/process/calibration/qpixelsize.py b/py4DSTEM/process/calibration/qpixelsize.py
new file mode 100644
index 000000000..0510fad06
--- /dev/null
+++ b/py4DSTEM/process/calibration/qpixelsize.py
@@ -0,0 +1,65 @@
+# Functions for calibrating the pixel size in the diffraction plane.
+
+import numpy as np
+from scipy.optimize import leastsq
+from typing import Union, Optional
+
+from emdfile import tqdmnd
+from py4DSTEM.process.utils import get_CoM
+
+
+def get_Q_pixel_size(q_meas, q_known, units="A"):
+ """
+ Computes the size of the Q-space pixels.
+
+ Args:
+ q_meas (number): a measured distance in q-space in pixels
+ q_known (number): the corresponding known *real space* distance
+        units (str): the units of the real space value of `q_known`
+
+ Returns:
+ (number,str): the detector pixel size, the associated units
+ """
+ return 1.0 / (q_meas * q_known), units + "^-1"
+
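+# Worked example (editor's note): a distance measured at q_meas = 100 pixels
+# whose known real space spacing is q_known = 2.0 (in A) implies a pixel size
+# of 1.0 / (100 * 2.0) = 0.005 A^-1 per pixel.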
+
+def get_dq_from_indexed_peaks(qs, hkl, a):
+ """
+ Get dq, the size of the detector pixels in the diffraction plane, in inverse length
+ units, using a set of measured peak distances from the optic axis, their Miller
+ indices, and the known unit cell size.
+
+ Args:
+ qs (array): the measured peak positions
+ hkl (list/tuple of length-3 tuples): the Miller indices of the peak positions qs.
+            The length of qs and hkl must be the same. To ignore a peak, set its
+            (h,k,l) to (0,0,0).
+ a (number): the unit cell size
+
+ Returns:
+        (3-tuple): A 3-tuple containing:
+
+        * **dq**: *(number)* the detector pixel size
+        * **qs_fit**: *(array)* the fit positions of the peaks
+        * **hkl_fit**: *(list of length-3 tuples)* the Miller indices of the
+            fit peaks; peaks assigned (h,k,l)=(0,0,0) are excluded
+ """
+ assert len(qs) == len(hkl), "qs and hkl must have same length"
+
+ # Get spacings
+ d_inv = np.array([np.sqrt(a**2 + b**2 + c**2) for (a, b, c) in hkl])
+ mask = d_inv != 0
+
+ # Get scaling factor
+ c0 = np.average(qs[mask] / d_inv[mask])
+ fiterr = lambda c: qs[mask] - c * d_inv[mask]
+ popt, _ = leastsq(fiterr, c0)
+ c = popt[0]
+
+ # Get pixel size
+ dq = 1 / (c * a)
+ qs_fit = d_inv[mask] / a
+    # note: `mask[i] is True` fails for numpy booleans; test truthiness directly
+    hkl_fit = [hkl[i] for i in range(len(hkl)) if mask[i]]
+
+ return dq, qs_fit, hkl_fit
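+
+
+if __name__ == "__main__":
+    # Editor's usage sketch (not part of the library API): recover a known
+    # pixel size from synthetic indexed peaks. With unit cell a = 4.05 and a
+    # true scaling of 20 pixels per 1/a, dq should come out near
+    # 1 / (20 * 4.05) ~ 0.0123, and the (0,0,0)-indexed peak is ignored.
+    a = 4.05
+    hkl = [(1, 1, 1), (2, 0, 0), (2, 2, 0), (0, 0, 0)]
+    d_inv = np.array([np.sqrt(h**2 + k**2 + l**2) for (h, k, l) in hkl])
+    qs = 20 * d_inv
+    qs[-1] = 7.0  # a spurious, unindexed peak
+    dq, qs_fit, hkl_fit = get_dq_from_indexed_peaks(qs, hkl, a)
+    print(dq, hkl_fit)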
diff --git a/py4DSTEM/process/calibration/rotation.py b/py4DSTEM/process/calibration/rotation.py
new file mode 100644
index 000000000..aaf8a49ce
--- /dev/null
+++ b/py4DSTEM/process/calibration/rotation.py
@@ -0,0 +1,237 @@
+# Rotational calibrations
+
+import numpy as np
+from typing import Optional
+import matplotlib.pyplot as plt
+from py4DSTEM import show
+
+
+def compare_QR_rotation(
+ im_R,
+ im_Q,
+ QR_rotation,
+ R_rotation=0,
+ R_position=None,
+ Q_position=None,
+ R_pos_anchor="center",
+ Q_pos_anchor="center",
+ R_length=0.33,
+ Q_length=0.33,
+ R_width=0.001,
+ Q_width=0.001,
+ R_head_length_adjust=1,
+ Q_head_length_adjust=1,
+ R_head_width_adjust=1,
+ Q_head_width_adjust=1,
+ R_color="r",
+ Q_color="r",
+ figsize=(10, 5),
+ returnfig=False,
+):
+ """
+ Visualize a rotational offset between an image in real space, e.g. a STEM
+ virtual image, and an image in diffraction space, e.g. a defocused CBED
+ shadow image of the same region, by displaying an arrow overlaid over each
+ of these two images with the specified QR rotation applied. The QR rotation
+ is defined as the counter-clockwise rotation from real space to diffraction
+ space, in degrees.
+
+ Parameters
+ ----------
+ im_R : numpy array or other 2D image-like object (e.g. a VirtualImage)
+ A real space image, e.g. a STEM virtual image
+ im_Q : numpy array or other 2D image-like object
+ A diffraction space image, e.g. a defocused CBED image
+ QR_rotation : number
+ The counterclockwise rotation from real space to diffraction space,
+ in degrees
+ R_rotation : number
+ The orientation of the arrow drawn in real space, in degrees
+ R_position : None or 2-tuple
+ The position of the anchor point for the R-space arrow. If None, defaults
+ to the center of the image
+ Q_position : None or 2-tuple
+ The position of the anchor point for the Q-space arrow. If None, defaults
+ to the center of the image
+ R_pos_anchor : 'center' or 'tail' or 'head'
+ The anchor point for the R-space arrow, i.e. the point being specified by
+ the `R_position` parameter
+ Q_pos_anchor : 'center' or 'tail' or 'head'
+ The anchor point for the Q-space arrow, i.e. the point being specified by
+ the `Q_position` parameter
+ R_length : number or None
+ The length of the R-space arrow, as a fraction of the mean size of the
+ image
+ Q_length : number or None
+ The length of the Q-space arrow, as a fraction of the mean size of the
+ image
+ R_width : number
+ The width of the R-space arrow
+ Q_width : number
+        The width of the Q-space arrow
+ R_head_length_adjust : number
+ Scaling factor for the R-space arrow head length
+ Q_head_length_adjust : number
+ Scaling factor for the Q-space arrow head length
+ R_head_width_adjust : number
+ Scaling factor for the R-space arrow head width
+ Q_head_width_adjust : number
+ Scaling factor for the Q-space arrow head width
+ R_color : color
+ Color of the R-space arrow
+ Q_color : color
+ Color of the Q-space arrow
+ figsize : 2-tuple
+ The figure size
+ returnfig : bool
+ Toggles returning the figure and axes
+ """
+ # parse inputs
+ if R_position is None:
+ R_position = (
+ im_R.shape[0] / 2,
+ im_R.shape[1] / 2,
+ )
+ if Q_position is None:
+ Q_position = (
+ im_Q.shape[0] / 2,
+ im_Q.shape[1] / 2,
+ )
+ R_length = np.mean(im_R.shape) * R_length
+ Q_length = np.mean(im_Q.shape) * Q_length
+ assert R_pos_anchor in ("center", "tail", "head")
+ assert Q_pos_anchor in ("center", "tail", "head")
+
+ # compute positions
+ rpos_x, rpos_y = R_position
+ qpos_x, qpos_y = Q_position
+ R_rot_rad = np.radians(R_rotation)
+ Q_rot_rad = np.radians(R_rotation + QR_rotation)
+ rvecx = np.cos(R_rot_rad)
+ rvecy = np.sin(R_rot_rad)
+ qvecx = np.cos(Q_rot_rad)
+ qvecy = np.sin(Q_rot_rad)
+ if R_pos_anchor == "center":
+ x0_r = rpos_x - rvecx * R_length / 2
+ y0_r = rpos_y - rvecy * R_length / 2
+ x1_r = rpos_x + rvecx * R_length / 2
+ y1_r = rpos_y + rvecy * R_length / 2
+ elif R_pos_anchor == "tail":
+ x0_r = rpos_x
+ y0_r = rpos_y
+ x1_r = rpos_x + rvecx * R_length
+ y1_r = rpos_y + rvecy * R_length
+ elif R_pos_anchor == "head":
+ x0_r = rpos_x - rvecx * R_length
+ y0_r = rpos_y - rvecy * R_length
+ x1_r = rpos_x
+ y1_r = rpos_y
+ else:
+ raise Exception(f"Invalid value for R_pos_anchor {R_pos_anchor}")
+ if Q_pos_anchor == "center":
+ x0_q = qpos_x - qvecx * Q_length / 2
+ y0_q = qpos_y - qvecy * Q_length / 2
+ x1_q = qpos_x + qvecx * Q_length / 2
+ y1_q = qpos_y + qvecy * Q_length / 2
+ elif Q_pos_anchor == "tail":
+ x0_q = qpos_x
+ y0_q = qpos_y
+ x1_q = qpos_x + qvecx * Q_length
+ y1_q = qpos_y + qvecy * Q_length
+ elif Q_pos_anchor == "head":
+ x0_q = qpos_x - qvecx * Q_length
+ y0_q = qpos_y - qvecy * Q_length
+ x1_q = qpos_x
+ y1_q = qpos_y
+ else:
+ raise Exception(f"Invalid value for Q_pos_anchor {Q_pos_anchor}")
+
+ # make the figure
+ axsize = (figsize[0] / 2, figsize[1])
+ fig, axs = show([im_R, im_Q], returnfig=True, axsize=axsize)
+ axs[0, 0].arrow(
+ x=y0_r,
+ y=x0_r,
+ dx=y1_r - y0_r,
+ dy=x1_r - x0_r,
+ color=R_color,
+ length_includes_head=True,
+ width=R_width,
+ head_width=R_length * R_head_width_adjust * 0.072,
+ head_length=R_length * R_head_length_adjust * 0.1,
+ )
+ axs[0, 1].arrow(
+ x=y0_q,
+ y=x0_q,
+ dx=y1_q - y0_q,
+ dy=x1_q - x0_q,
+ color=Q_color,
+ length_includes_head=True,
+ width=Q_width,
+ head_width=Q_length * Q_head_width_adjust * 0.072,
+ head_length=Q_length * Q_head_length_adjust * 0.1,
+ )
+ if returnfig:
+ return fig, axs
+ else:
+ plt.show()
+
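+# A hypothetical call (editor's note): given a real space virtual image im_R
+# and a defocused CBED image im_Q of the same region, a candidate rotation of
+# 20 degrees could be checked visually with:
+#
+#     compare_QR_rotation(im_R, im_Q, QR_rotation=20)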
+
+def get_Qvector_from_Rvector(vx, vy, QR_rotation):
+ """
+ For some vector (vx,vy) in real space, and some rotation QR between real and
+ reciprocal space, determine the corresponding orientation in diffraction space.
+ Returns both R and Q vectors, normalized.
+
+ Args:
+ vx,vy (numbers): the (x,y) components of a real space vector
+ QR_rotation (number): the offset angle between real and reciprocal space.
+ Specifically, the counterclockwise rotation of real space with respect to
+ diffraction space. In degrees.
+
+ Returns:
+ (4-tuple): 4-tuple consisting of:
+
+ * **vx_R**: the x component of the normalized real space vector
+ * **vy_R**: the y component of the normalized real space vector
+ * **vx_Q**: the x component of the normalized reciprocal space vector
+ * **vy_Q**: the y component of the normalized reciprocal space vector
+ """
+ phi = np.radians(QR_rotation)
+ vL = np.hypot(vx, vy)
+ vx_R, vy_R = vx / vL, vy / vL
+
+ vx_Q = np.cos(phi) * vx_R + np.sin(phi) * vy_R
+ vy_Q = -np.sin(phi) * vx_R + np.cos(phi) * vy_R
+
+ return vx_R, vy_R, vx_Q, vy_Q
+
+
+def get_Rvector_from_Qvector(vx, vy, QR_rotation):
+ """
+ For some vector (vx,vy) in diffraction space, and some rotation QR between real and
+    reciprocal space, determine the corresponding orientation in real space.
+ Returns both R and Q vectors, normalized.
+
+ Args:
+ vx,vy (numbers): the (x,y) components of a reciprocal space vector
+ QR_rotation (number): the offset angle between real and reciprocal space.
+ Specifically, the counterclockwise rotation of real space with respect to
+ diffraction space. In degrees.
+
+ Returns:
+ (4-tuple): 4-tuple consisting of:
+
+ * **vx_R**: the x component of the normalized real space vector
+ * **vy_R**: the y component of the normalized real space vector
+ * **vx_Q**: the x component of the normalized reciprocal space vector
+ * **vy_Q**: the y component of the normalized reciprocal space vector
+ """
+ phi = np.radians(QR_rotation)
+ vL = np.hypot(vx, vy)
+ vx_Q, vy_Q = vx / vL, vy / vL
+
+ vx_R = np.cos(phi) * vx_Q - np.sin(phi) * vy_Q
+ vy_R = np.sin(phi) * vx_Q + np.cos(phi) * vy_Q
+
+ return vx_R, vy_R, vx_Q, vy_Q
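+
+
+if __name__ == "__main__":
+    # Editor's usage sketch (not part of the library API): the two conversions
+    # are inverses, so mapping a real space unit vector into diffraction space
+    # and back should recover the original direction.
+    _, _, vx_Q, vy_Q = get_Qvector_from_Rvector(1, 0, QR_rotation=30)
+    vx_R, vy_R, _, _ = get_Rvector_from_Qvector(vx_Q, vy_Q, QR_rotation=30)
+    print(vx_R, vy_R)  # expect approximately (1.0, 0.0)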
diff --git a/py4DSTEM/process/classification/__init__.py b/py4DSTEM/process/classification/__init__.py
new file mode 100644
index 000000000..42a8a6d4a
--- /dev/null
+++ b/py4DSTEM/process/classification/__init__.py
@@ -0,0 +1,3 @@
+from py4DSTEM.process.classification.braggvectorclassification import *
+from py4DSTEM.process.classification.classutils import *
+from py4DSTEM.process.classification.featurization import *
diff --git a/py4DSTEM/process/classification/braggvectorclassification.py b/py4DSTEM/process/classification/braggvectorclassification.py
new file mode 100644
index 000000000..4956f7630
--- /dev/null
+++ b/py4DSTEM/process/classification/braggvectorclassification.py
@@ -0,0 +1,950 @@
+# Functions for classification using a method which initially classifies the detected Bragg
+# vectors, then uses these labels to classify real space scan positions
+
+import numpy as np
+from numpy.linalg import lstsq
+from itertools import permutations
+from scipy.ndimage import gaussian_filter
+from scipy.ndimage import (
+ binary_opening,
+ binary_closing,
+ binary_dilation,
+ binary_erosion,
+)
+from skimage.measure import label
+from sklearn.decomposition import NMF
+
+from emdfile import PointListArray
+
+
+class BraggVectorClassification(object):
+ """
+ A class for classifying 4D-STEM data based on which Bragg peaks are found at each
+ diffraction pattern.
+
+ A BraggVectorClassification instance enables classification using several methods; a brief
+ overview is provided here, with more details in each individual method's documentation.
+
+ Initialization methods:
+
+ __init__:
+ Determine the initial classes. The approach here involves first segmenting diffraction
+ space, using maxima of a Bragg vector map.
+
+    get_initial_classes_by_cooccurrence:
+        Populate the initial classes from the co-occurrence statistics of the
+        Bragg peaks across the scan.
+    get_initial_classes_from_images:
+        Populate the initial classes from a set of user-defined class images.
+
+ Class refinement methods:
+ Each of these methods creates a new set of candidate classes, *but does not yet overwrite the
+ old classes*. This enables the new classes to be viewed and compared to the old classes before
+ deciding whether to accept or reject them. Thus running two of these methods in succession,
+ without accepting changes in between, simply discards the first set of candidate classes.
+
+ nmf:
+ Nonnegative matrix factorization (X = WH) to refine the classes. Briefly, after
+ constructing a matrix X which describes which Bragg peaks were observed in each
+ diffraction pattern, we factor X into two smaller matrices, W and H. Physically, W and H
+ describe a small set of classes, each of which corresponds to some subset of (or, more
+ strictly, weights for) the Bragg peaks and the scan positions. We additionally impose
+        the constraint that, on physical grounds, all the elements of X, W, and H must be
+ nonnegative.
+ split:
+ If any classes contain multiple non-contiguous segments in real space, divide these into
+ distinct classes.
+ merge:
+ If any classes contain sufficient overlap in both scan positions and BPs, merge them
+ into a single class.
+
+ Accepting/rejecting changes:
+
+ accept:
+ Updates classes (the W and H matrices) with the current candidate classes.
+ reject:
+ Discard the current candidate classes.
+
+ Class examination methods:
+
+ get_class:
+ get a single class, returning both its BP weights and scan position weights
+ get_class_BPs:
+ get the BP weights for a single class
+ get_class_image:
+ get the image, i.e. scan position weights, associated with a single class
+ get_candidate_class:
+ as above, for the current candidate class
+ get_candidate_class_BPs:
+ as above, for the current candidate class
+ get_candidate_class_image:
+ as above, for the current candidate class
+
+ Args:
+ braggpeaks (PointListArray): Bragg peaks; must have coords 'qx' and 'qy'
+ Qx (ndarray of floats): x-coords of the voronoi points
+ Qy (ndarray of floats): y-coords of the voronoi points
+ X_is_boolean (bool): if True, populate X with bools (BP is or is not present).
+ if False, populate X with floats (BP c.c. intensities)
+ max_dist (None or number): maximum distance from a given voronoi point a peak
+ can be and still be associated with this label
+ """
+
+ def __init__(self, braggpeaks, Qx, Qy, X_is_boolean=True, max_dist=None):
+ """
+ Initializes a BraggVectorClassification instance.
+
+ This method:
+ 1. Gets integer labels for all of the detected Bragg peaks, according to which
+ (Qx,Qy) is closest, then generating a corresponding set of integers for each scan
+ position. See get_braggpeak_labels_by_scan_position() docstring for more info.
+ 2. Generates the data matrix X. See the nmf() method docstring for more info.
+
+ This method should be followed by one of the methods which populates the initial classes -
+ currently, either get_initial_classes_by_cooccurrence() or get_initial_classes_from_images.
+ These methods generate the W and H matrices -- i.e. the decompositions of the X matrix in
+ terms of scan positions and Bragg peaks -- which are necessary for any subsequent
+ processing.
+
+ Args:
+ braggpeaks (PointListArray): Bragg peaks; must have coords 'qx' and 'qy'
+ Qx (ndarray of floats): x-coords of the voronoi points
+ Qy (ndarray of floats): y-coords of the voronoi points
+ X_is_boolean (bool): if True, populate X with bools (BP is or is not present).
+ if False, populate X with floats (BP c.c. intensities)
+ max_dist (None or number): maximum distance from a given voronoi point a peak
+ can be and still be associated with this label
+ """
+ assert isinstance(
+ braggpeaks, PointListArray
+ ), "braggpeaks must be a PointListArray"
+ assert np.all(
+ [name in braggpeaks.dtype.names for name in ("qx", "qy")]
+ ), "braggpeaks must contain coords 'qx' and 'qy'"
+ assert len(Qx) == len(Qy), "Qx and Qy must have same length"
+ self.braggpeaks = braggpeaks
+ self.R_Nx = braggpeaks.shape[0] #: shape of real space (x)
+ self.R_Ny = braggpeaks.shape[1] #: shape of real space (y)
+ self.Qx = Qx #: x-coordinates of the voronoi points
+ self.Qy = Qy #: y-coordinates of the voronoi points
+
+ #: the sets of Bragg peaks present at each scan position
+ self.braggpeak_labels = get_braggpeak_labels_by_scan_position(
+ braggpeaks, Qx, Qy, max_dist
+ )
+
+ # Construct X matrix
+ #: first dimension of the data matrix; the number of bragg peaks
+ self.N_feat = len(self.Qx)
+ #: second dimension of the data matrix; the number of scan positions
+ self.N_meas = self.R_Nx * self.R_Ny
+
+ self.X = np.zeros((self.N_feat, self.N_meas)) #: the data matrix
+ for Rx in range(self.R_Nx):
+ for Ry in range(self.R_Ny):
+ R = Rx * self.R_Ny + Ry
+ s = self.braggpeak_labels[Rx][Ry]
+ pointlist = self.braggpeaks.get_pointlist(Rx, Ry)
+ for i in s:
+ if X_is_boolean:
+ self.X[i, R] = True
+ else:
+ ind = np.argmin(
+ np.hypot(
+ pointlist.data["qx"] - Qx[i],
+ pointlist.data["qy"] - Qy[i],
+ )
+ )
+ self.X[i, R] = pointlist.data["intensity"][ind]
+
+ return
+
+ def get_initial_classes_by_cooccurrence(
+ self,
+ thresh=0.3,
+ BP_fraction_thresh=0.1,
+ max_iterations=200,
+ X_is_boolean=True,
+ n_corr_init=2,
+ ):
+ """
+        Populate the initial classes by finding sets of Bragg peaks that tend to
+        co-occur in the same diffraction patterns.
+
+ Beginning from the sets of Bragg peaks labels for each scan position (determined
+ in __init__), this method gets initial classes by determining which labels are
+ most likely to co-occur with each other -- see get_initial_classes() docstring
+        for more info. Then the matrices W and H are generated -- see nmf() docstring
+ for discussion.
+
+ Args:
+ thresh (float in [0,1]): threshold for adding new BPs to a class
+            BP_fraction_thresh (float in [0,1]): algorithm terminates once the
+                fraction of BPs not yet assigned to a class falls below this value
+ max_iterations (int): algorithm terminates after this many iterations
+ n_corr_init (int): seed new classes by finding maxima of the n-point joint
+ probability function. Must be 2 or 3.
+ """
+ assert isinstance(X_is_boolean, bool)
+ assert isinstance(max_iterations, (int, np.integer))
+ assert n_corr_init in (2, 3)
+
+ # Get sets of integers representing the initial classes
+ BP_sets = get_initial_classes(
+ self.braggpeak_labels,
+ N=len(self.Qx),
+ thresh=thresh,
+ BP_fraction_thresh=BP_fraction_thresh,
+ max_iterations=max_iterations,
+ n_corr_init=n_corr_init,
+ )
+
+ # Construct W, H matrices
+ self.N_c = len(BP_sets)
+
+ # W
+ self.W = np.zeros((self.N_feat, self.N_c))
+ for i in range(self.N_c):
+ BP_set = BP_sets[i]
+ for j in BP_set:
+ self.W[j, i] = 1
+
+ # H
+ self.H = lstsq(self.W, self.X, rcond=None)[0]
+ self.H = np.where(self.H < 0, 0, self.H)
+
+ self.W_next = None
+ self.H_next = None
+ self.N_c_next = None
+
+ return
+
+ def get_initial_classes_from_images(self, class_images):
+ """
+ Populate the initial classes using a set of user-defined class images.
+
+ Args:
+ class_images (ndarray): must have shape (R_Nx,R_Ny,N_c), where N_c is the
+ number of classes, and class_images[:,:,i] is the image of class i.
+ """
+ assert class_images.shape[0] == self.R_Nx
+ assert class_images.shape[1] == self.R_Ny
+
+ # Construct W, H matrices
+ self.N_c = class_images.shape[2]
+
+ # H
+ H = np.zeros((self.N_c, self.N_meas))
+ for i in range(self.N_c):
+ H[i, :] = class_images[:, :, i].ravel()
+ self.H = np.copy(H, order="C")
+
+ # W
+ W = lstsq(self.H.T, self.X.T, rcond=None)[0].T
+ W = np.where(W < 0, 0, W)
+ self.W = np.copy(W, order="C")
+
+ self.W_next = None
+ self.H_next = None
+ self.N_c_next = None
+
+ return
+
+ def nmf(self, max_iterations=1):
+ """
+ Nonnegative matrix factorization to refine the classes.
+
+ The data matrix ``X`` is factored into two smaller matrices, ``W`` and ``H``::
+
+ X = WH
+
+ Here,
+
+        * ``X`` is the data matrix. It has shape (N_feat,N_meas), where N_feat is the
+ number of Bragg peak integer labels (i.e. len(Qx)) and N_meas is the number
+ of diffraction patterns (i.e. R_Nx*R_Ny). Element X[i,j] represents the
+ value of the i'th BP in the j'th DP. The values depend on the flag
+ datamatrix_is_boolean: if True, X[i,j] is 1 if this BP was present in this
+ DP, or 0 if not; if False, X[i,j] is the cross correlation intensity of
+ this BP in this DP.
+ * ``W`` is the class matrix. It has shape (N_feat,N_c), where N_c is the
+ number of classes. The i'th column vector, w_i = W[:,i], describes the
+ weight of each Bragg peak in the i'th class. w_i has length N_feat, and
+ w_i[j] describes how strongly the j'th BP is associated with the i'th
+ class.
+ * ``H`` is the coefficient matrix. It has shape (N_c,N_meas). The i'th
+ column vector H[:,i] describes the contribution of each class to scan
+ position i.
+
+ Alternatively, we can completely equivalently think of H as a class matrix,
+        and W as a coefficient matrix. In this picture, the i'th row vector of H,
+ h_i = H[i,:], describes the weight of each scan position in the i'th class.
+ h_i has length N_meas, and h_i[j] describes how strongly the j'th scan
+ position is associated with the i'th class. The row vector W[i,:] is then
+ a coefficient vector, which gives the contributions each of the (H) classes
+ to the measured values of the i'th BP. These pictures are related by a
+ transpose: X = WH is equivalent to X.T = (H.T)(W.T).
+
+        In nonnegative matrix factorization we impose the constraint, here on
+ physical grounds, that all elements of X, W, and H should be nonnegative.
+
+ The computation itself is performed using the sklearn nmf class. When this method
+ is called, the three relevant matrices should already be defined. This method
+ refines W and H, with up to max_iterations NMF steps.
+
+ Args:
+ max_iterations (int): the maximum number of NMF steps to take
+ """
+ sklearn_nmf = NMF(n_components=self.N_c, init="custom", max_iter=max_iterations)
+ self.W_next = sklearn_nmf.fit_transform(self.X, W=self.W, H=self.H)
+ self.H_next = sklearn_nmf.components_
+ self.N_c_next = self.W_next.shape[1]
+
+ return
+
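+    # A typical refinement loop (editor's note; `bvc` names a hypothetical
+    # BraggVectorClassification instance): compute candidate classes, inspect
+    # them, then accept or reject:
+    #
+    #     bvc.nmf(max_iterations=5)
+    #     im = bvc.get_candidate_class_image(0)  # inspect a candidate class
+    #     bvc.accept()                           # or bvc.reject()
+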
+ def split(self, sigma=2, threshold_split=0.25, expand_mask=1, minimum_pixels=1):
+ """
+ If any classes contain multiple non-contiguous segments in real space, divide
+ these regions into distinct classes.
+
+ Algorithm is as follows:
+ First, an image of each class is obtained from its scan position weights.
+ Then, the image is convolved with a gaussian of std sigma.
+ This is then turned into a binary mask, by thresholding with threshold_split.
+ Stray pixels are eliminated by performing a one pixel binary closing, then binary
+ opening.
+ The mask is then expanded by expand_mask pixels.
+ Finally, the contiguous regions of the resulting mask are found. These become the
+ new class components by scan position.
+
+ The splitting itself involves creating two classes - i.e. adding a column to W
+ and a row to H. The new BP classes (W columns) have exactly the same values as
+ the old BP class. The two new scan position classes (H rows) divide up the
+ non-zero entries of the old scan position class into two or more non-intersecting
+ subsets, each of which becomes its own new class.
+
+ Args:
+ sigma (float): std of gaussian kernel used to smooth the class images before
+ thresholding and splitting.
+ threshold_split (float): used to threshold the class image to create a binary mask.
+ expand_mask (int): number of pixels by which to expand the mask before separating
+ into contiguous regions.
+ minimum_pixels (int): if, after splitting, a potential new class contains fewer than
+ this number of pixels, ignore it
+ """
+ assert isinstance(expand_mask, (int, np.integer))
+ assert isinstance(minimum_pixels, (int, np.integer))
+
+ W_next = np.zeros((self.N_feat, 1))
+ H_next = np.zeros((1, self.N_meas))
+ for i in range(self.N_c):
+ # Get the class in real space
+ class_image = self.get_class_image(i)
+
+ # Turn into a binary mask
+ class_image = gaussian_filter(class_image, sigma)
+ mask = class_image > (np.max(class_image) * threshold_split)
+ mask = binary_opening(mask, iterations=1)
+ mask = binary_closing(mask, iterations=1)
+ mask = binary_dilation(mask, iterations=expand_mask)
+
+ # Get connected regions
+ labels, nlabels = label(mask, background=0, return_num=True, connectivity=2)
+
+ # Add each region to the new W and H matrices
+ for j in range(nlabels):
+ mask = labels == (j + 1)
+ mask = binary_erosion(mask, iterations=expand_mask)
+
+ if np.sum(mask) >= minimum_pixels:
+ # Leave the Bragg peak weightings the same
+ W_next = np.hstack((W_next, self.W[:, i, np.newaxis]))
+
+ # Use the existing real space pixel weightings
+ h_i = np.zeros(self.N_meas)
+ h_i[mask.ravel()] = self.H[i, :][mask.ravel()]
+ H_next = np.vstack((H_next, h_i[np.newaxis, :]))
+
+ self.W_next = W_next[:, 1:]
+ self.H_next = H_next[1:, :]
+ self.N_c_next = self.W_next.shape[1]
+
+ return
+
+ def merge(self, threshBPs=0.1, threshScanPosition=0.1, return_params=True):
+ """
+ If any classes contain sufficient overlap in both scan positions and BPs, merge
+ them into a single class.
+
+ The algorithm is as follows:
+ First, the Pearson correlation coefficient matrix is calculated for the classes
+ according to both their diffraction space, Bragg peak representations (i.e. the
+ correlations of the columns of W) and according to their real space, scan
+ position representations (i.e. the correlations of the rows of H). Class pairs
+ whose BP correlation coefficient exceeds threshBPs and whose scan position
+ correlation coefficient exceed threshScanPosition are deemed 'sufficiently
+ overlapped', and are marked as merge candidates. To account for intransitivity
+ issues (e.g. class pairs 1/2 and 2/3 are merge candidates, but class pair 1/3 is
+ not), merging is then performed beginning with candidate pairs with the greatest
+ product of the two correlation coefficients, skipping later merge candidate pairs
+ if one of the two classes has already been merged.
+
+ The algorithm can be looped until no more merge candidates satisfying the
+ specified thresholds remain with the merge_iterative method.
+
+ The merging itself involves turning two classes into one by combining a pair of
+ W columns (i.e. the BP representations of the classes) and the corresponding pair
+ of H rows (i.e. the scan position representation of the class) into a single W
+ column / H row. In terms of scan positions, the new row of H is generated by
+ simply adding the two old H rows. In terms of Bragg peaks, the new column of W is
+ generated by adding the two old columns of W, while weighting each by its total
+ intensity in real space (i.e. the sum of its H row).
+
+ Args:
+ threshBPs (float): the threshold for the bragg peaks correlation coefficient,
+ above which the two classes are considered candidates for merging
+ threshScanPosition (float): the threshold for the scan position correlation
+ coefficient, above which two classes are considered candidates for
+ merging
+            return_params (bool): if True, returns W_corr, H_corr, and
+                merge_candidates. Otherwise, returns nothing.
+ """
+
+ def merge_by_class_index(self, i, j):
+ """
+ Merge classes i and j into a single class.
+
+        Columns i and j of W (i.e. the BP representations of the classes) and
+        the corresponding pair of H rows (i.e. the scan position representations
+        of the classes) are merged into a single W column / H row. In terms of scan positions, the
+ new row of H is generated by simply adding the two old H rows. In terms of Bragg
+ peaks, the new column of W is generated by adding the two old columns of W, while
+ weighting each by its total intensity in real space (i.e. the sum of its H row).
+
+ Args:
+ i (int): index of the first class to merge
+ j (int): index of the second class to merge
+ """
+ assert np.all(
+ [isinstance(ind, (int, np.integer)) for ind in [i, j]]
+ ), "i and j must be ints"
+
+ # Get merged class
+ weight_i = np.sum(self.H[i, :])
+ weight_j = np.sum(self.H[j, :])
+ W_new = (self.W[:, i] * weight_i + self.W[:, j] * weight_j) / (
+ weight_i + weight_j
+ )
+ H_new = self.H[i, :] + self.H[j, :]
+
+ # Remove old classes and add in new class
+ self.W_next = np.delete(self.W, j, axis=1)
+ self.H_next = np.delete(self.H, j, axis=0)
+ self.W_next[:, i] = W_new
+ self.H_next[i, :] = H_new
+ self.N_c_next = self.N_c - 1
+
+ return
+
+ def split_by_class_index(
+ self, i, sigma=2, threshold_split=0.25, expand_mask=1, minimum_pixels=1
+ ):
+ """
+ If class i contains multiple non-contiguous segments in real space, divide these
+ regions into distinct classes.
+
+ Algorithm is as described in the docstring for self.split.
+
+ Args:
+ i (int): index of the class to split
+ sigma (float): std of gaussian kernel used to smooth the class images before
+ thresholding and splitting.
+ threshold_split (float): used to threshold the class image to create a binary
+ mask.
+ expand_mask (int): number of pixels by which to expand the mask before
+ separating into contiguous regions.
+ minimum_pixels (int): if, after splitting, a potential new class contains
+ fewer than this number of pixels, ignore it
+ """
+ assert isinstance(i, (int, np.integer))
+ assert isinstance(expand_mask, (int, np.integer))
+ assert isinstance(minimum_pixels, (int, np.integer))
+ W_next = np.zeros((self.N_feat, 1))
+ H_next = np.zeros((1, self.N_meas))
+
+ # Get the class in real space
+ class_image = self.get_class_image(i)
+
+ # Turn into a binary mask
+ class_image = gaussian_filter(class_image, sigma)
+ mask = class_image > (np.max(class_image) * threshold_split)
+ mask = binary_opening(mask, iterations=1)
+ mask = binary_closing(mask, iterations=1)
+ mask = binary_dilation(mask, iterations=expand_mask)
+
+ # Get connected regions
+ labels, nlabels = label(mask, background=0, return_num=True, connectivity=2)
+
+ # Add each region to the new W and H matrices
+ for j in range(nlabels):
+ mask = labels == (j + 1)
+ mask = binary_erosion(mask, iterations=expand_mask)
+
+ if np.sum(mask) >= minimum_pixels:
+ # Leave the Bragg peak weightings the same
+ W_next = np.hstack((W_next, self.W[:, i, np.newaxis]))
+
+ # Use the existing real space pixel weightings
+ h_i = np.zeros(self.N_meas)
+ h_i[mask.ravel()] = self.H[i, :][mask.ravel()]
+ H_next = np.vstack((H_next, h_i[np.newaxis, :]))
+
+ W_prev = np.delete(self.W, i, axis=1)
+ H_prev = np.delete(self.H, i, axis=0)
+ self.W_next = np.concatenate((W_next[:, 1:], W_prev), axis=1)
+ self.H_next = np.concatenate((H_next[1:, :], H_prev), axis=0)
+ self.N_c_next = self.W_next.shape[1]
+
+ return
+
+ def remove_class(self, i):
+ """
+ Remove class i.
+
+ Args:
+ i (int): index of the class to remove
+ """
+ assert isinstance(i, (int, np.integer))
+
+ self.W_next = np.delete(self.W, i, axis=1)
+ self.H_next = np.delete(self.H, i, axis=0)
+ self.N_c_next = self.W_next.shape[1]
+
+ return
+
+ def merge_iterative(self, threshBPs=0.1, threshScanPosition=0.1):
+ """
+ If any classes contain sufficient overlap in both scan positions and BPs, merge
+ them into a single class.
+
+ Identical to the merge method, with the addition of iterating until no new merge
+ pairs are found.
+
+ Args:
+ threshBPs (float): the threshold for the bragg peaks correlation coefficient,
+ above which the two classes are considered candidates for merging
+ threshScanPosition (float): the threshold for the scan position correlation
+ coefficient, above which two classes are considered candidates for
+ merging
+ """
+ proceed = True
+ W_ = np.copy(self.W)
+ H_ = np.copy(self.H)
+ Nc_ = W_.shape[1]
+
+ while proceed:
+ # Get correlation coefficients
+ W_corr = np.corrcoef(W_.T)
+ H_corr = np.corrcoef(H_)
+
+ # Get merge candidate pairs
+ mask_BPs = W_corr > threshBPs
+ mask_ScanPosition = H_corr > threshScanPosition
+ mask_upperright = np.zeros((Nc_, Nc_), dtype=bool)
+ for i in range(Nc_):
+ mask_upperright[i, i + 1 :] = 1
+ merge_mask = mask_BPs * mask_ScanPosition * mask_upperright
+ merge_i, merge_j = np.nonzero(merge_mask)
+
+ # Sort merge candidate pairs
+ merge_candidates = np.zeros(
+ len(merge_i),
+ dtype=[
+ ("i", int),
+ ("j", int),
+ ("cc_w", float),
+ ("cc_h", float),
+ ("score", float),
+ ],
+ )
+ merge_candidates["i"] = merge_i
+ merge_candidates["j"] = merge_j
+ merge_candidates["cc_w"] = W_corr[merge_i, merge_j]
+ merge_candidates["cc_h"] = H_corr[merge_i, merge_j]
+ merge_candidates["score"] = (
+ W_corr[merge_i, merge_j] * H_corr[merge_i, merge_j]
+ )
+ merge_candidates = np.sort(merge_candidates, order="score")[::-1]
+
+ # Perform merging
+ merged = np.zeros(Nc_, dtype=bool)
+ W_merge = np.zeros((self.N_feat, 1))
+ H_merge = np.zeros((1, self.N_meas))
+ for index in range(len(merge_candidates)):
+ i = merge_candidates["i"][index]
+ j = merge_candidates["j"][index]
+ if not (merged[i] or merged[j]):
+ weight_i = np.sum(H_[i, :])
+ weight_j = np.sum(H_[j, :])
+ W_new = (W_[:, i] * weight_i + W_[:, j] * weight_j) / (
+ weight_i + weight_j
+ )
+ H_new = H_[i, :] + H_[j, :]
+ W_merge = np.hstack((W_merge, W_new[:, np.newaxis]))
+ H_merge = np.vstack((H_merge, H_new[np.newaxis, :]))
+ merged[i] = True
+ merged[j] = True
+ W_merge = W_merge[:, 1:]
+ H_merge = H_merge[1:, :]
+
+            # note: `merged is False` is an identity test and is always False
+            # for an array; invert the boolean mask to keep unmerged classes
+            W_ = np.hstack((W_[:, ~merged], W_merge))
+            H_ = np.vstack((H_[~merged, :], H_merge))
+ Nc_ = W_.shape[1]
+
+ if len(merge_candidates) == 0:
+ proceed = False
+
+ self.W_next = W_
+ self.H_next = H_
+ self.N_c_next = self.W_next.shape[1]
+
+ return
+
+ def accept(self):
+ """
+ Updates classes (the W and H matrices) with the current candidate classes.
+ """
+ if self.W_next is None or self.H_next is None:
+ return
+ else:
+ self.W = self.W_next
+ self.H = self.H_next
+ self.N_c = self.N_c_next
+ self.W_next = None
+ self.H_next = None
+ self.N_c_next = None
+
+ def reject(self):
+ """
+ Discard the current candidate classes.
+ """
+ self.W_next = None
+ self.H_next = None
+ self.N_c_next = None
+
+ def get_class(self, i):
+ """
+ Get a single class, returning both its BP weights and scan position weights.
+
+ Args:
+ i (int): the class index
+
+ Returns:
+ (2-tuple): A 2-tuple containing:
+
+ * **class_BPs**: *(length N_feat array of floats)* the weights of the
+ N_feat Bragg peaks for this class
+ * **class_image**: *(shape (R_Nx,R_Ny) array of floats)* the weights of
+ each scan position in this class
+ """
+ class_BPs = self.W[:, i]
+ class_image = self.H[i, :].reshape((self.R_Nx, self.R_Ny))
+ return class_BPs, class_image
+
+ def get_class_BPs(self, i):
+ """
+ Get a single class, returning its BP weights.
+
+ Args:
+ i (int): the class index
+
+ Returns:
+ (length N_feat array of floats): the weights of the N_feat Bragg peaks for
+ this class
+ """
+ return self.W[:, i]
+
+ def get_class_image(self, i):
+ """
+ Get a single class, returning its scan position weights.
+
+ Args:
+ i (int): the class index
+
+ Returns:
+ (shape (R_Nx,R_Ny) array of floats): the weights of each scan position in
+ this class
+ """
+ return self.H[i, :].reshape((self.R_Nx, self.R_Ny))
+
+ def get_candidate_class(self, i):
+ """
+ Get a single candidate class, returning both its BP weights and scan position weights.
+
+ Args:
+            i (int): the class index
+
+ Returns:
+ (2-tuple): A 2-tuple containing:
+
+ * **class_BPs**: *(length N_feat array of floats)* the weights of the
+ N_feat Bragg peaks for this class
+ * **class_image**: *(shape (R_Nx,R_Ny) array of floats)* the weights of
+ each scan position in this class
+ """
+ assert self.W_next is not None, "W_next is not assigned."
+ assert self.H_next is not None, "H_next is not assigned."
+
+ class_BPs = self.W_next[:, i]
+ class_image = self.H_next[i, :].reshape((self.R_Nx, self.R_Ny))
+ return class_BPs, class_image
+
+ def get_candidate_class_BPs(self, i):
+ """
+ Get a single candidate class, returning its BP weights.
+
+        Args:
+            i (int): the class index
+
+        Returns:
+            (length N_feat array of floats): the weights of the N_feat Bragg
+            peaks for this class
+ """
+ assert self.W_next is not None, "W_next is not assigned."
+
+ return self.W_next[:, i]
+
+ def get_candidate_class_image(self, i):
+ """
+ Get a single candidate class, returning its scan position weights.
+
+ Args:
+ i (int): the class index
+
+ Returns:
+ (shape (R_Nx,R_Ny) array of floats): the weights of each scan position in
+ this class
+ """
+ assert self.H_next is not None, "H_next is not assigned."
+
+ return self.H_next[i, :].reshape((self.R_Nx, self.R_Ny))
+
+
+### Functions for initial class determination ###
+
+
+def get_braggpeak_labels_by_scan_position(braggpeaks, Qx, Qy, max_dist=None):
+ """
+ For each scan position, gets a set of integers, specifying the bragg peaks at this
+ scan position.
+
+ From a set of positions in diffraction space (Qx,Qy), assign each detected bragg peak
+ in the PointListArray braggpeaks a label corresponding to the index of the closest
+ position; thus for a bragg peak at (qx,qy), if the closest position in (Qx,Qy) is
+ (Qx[i],Qy[i]), assign this peak the label i. This is equivalent to assigning each
+ bragg peak (qx,qy) a label according to the Voronoi region it lives in, given a
+ voronoi tesselation seeded from the points (Qx,Qy).
+
+ For each scan position, get the set of all indices i for all bragg peaks found at
+ this scan position.
+
+ Args:
+ braggpeaks (PointListArray): Bragg peaks; must have coords 'qx' and 'qy'
+ Qx (ndarray of floats): x-coords of the voronoi points
+ Qy (ndarray of floats): y-coords of the voronoi points
+ max_dist (None or number): maximum distance from a given voronoi point a peak
+ can be and still be associated with this label
+
+ Returns:
+ (list of lists of sets) the labels found at each scan position. Scan position
+ (Rx,Ry) is accessed via braggpeak_labels[Rx][Ry]
+ """
+ assert isinstance(braggpeaks, PointListArray), "braggpeaks must be a PointListArray"
+ assert np.all(
+ [name in braggpeaks.dtype.names for name in ("qx", "qy")]
+ ), "braggpeaks must contain coords 'qx' and 'qy'"
+
+ braggpeak_labels = [
+ [set() for i in range(braggpeaks.shape[1])] for j in range(braggpeaks.shape[0])
+ ]
+ for Rx in range(braggpeaks.shape[0]):
+ for Ry in range(braggpeaks.shape[1]):
+ s = braggpeak_labels[Rx][Ry]
+ pointlist = braggpeaks.get_pointlist(Rx, Ry)
+ for i in range(len(pointlist.data)):
+ label = np.argmin(
+ np.hypot(Qx - pointlist.data["qx"][i], Qy - pointlist.data["qy"][i])
+ )
+ if max_dist is not None:
+ if (
+ np.hypot(
+ Qx[label] - pointlist.data["qx"][i],
+ Qy[label] - pointlist.data["qy"][i],
+ )
+ < max_dist
+ ):
+ s.add(label)
+ else:
+ s.add(label)
+
+ return braggpeak_labels
+
+
+def get_initial_classes(
+ braggpeak_labels,
+ N,
+ thresh=0.3,
+ BP_fraction_thresh=0.1,
+ max_iterations=200,
+ n_corr_init=2,
+):
+ """
+    From the sets of Bragg peaks present at each scan position, get an initial guess
+    at which Bragg peaks should be grouped together into classes.
+
+ The algorithm is as follows:
+    1. Calculate an n-point correlation function, i.e. the joint probability of any given
+    n BPs coexisting in a diffraction pattern -- e.g. for n=3, the probability that
+    peaks i, j, and k are all in the same DP. n is controlled by n_corr_init, and must
+    be 2 or 3.
+    2. Find the BP pair or triplet maximizing the n-point function; include these BPs
+    in a class.
+ 3. Get all DPs containing the class BPs. From these, find the next most likely BP to
+ also be present. If its probability of coexisting with the known class BPs is
+ greater than thresh, add it to the class and repeat this step. Otherwise, proceed to
+ the next step.
+ 4. Check: if the new class is the same as a class that has already been found, OR if
+ the fraction of BPs which have not yet been placed in a class is less than
+ BP_fraction_thresh, or more than max_iterations have been attempted, finish,
+    returning all classes. Otherwise, set all slices of the n-point function containing
+    the BPs in the new class to zero, and begin a new iteration, starting at step 2 using
+    the new, altered n-point function.
+
+ Args:
+        braggpeak_labels (list of lists of sets): the Bragg peak labels found at each
+            scan position; see get_braggpeak_labels_by_scan_position().
+        N (int): the total number of indexed Bragg peaks in the 4D-STEM dataset
+        thresh (float in [0,1]): threshold for adding new BPs to a class
+        BP_fraction_thresh (float in [0,1]): algorithm terminates once the fraction
+            of BPs not yet assigned to a class falls below this value
+ max_iterations (int): algorithm terminates after this many iterations
+ n_corr_init (int): seed new classes by finding maxima of the n-point joint
+ probability function. Must be 2 or 3.
+
+ Returns:
+ (list of sets): the sets of Bragg peaks constituting the classes
+ """
+ assert isinstance(braggpeak_labels[0][0], set)
+ assert thresh >= 0 and thresh <= 1
+ assert BP_fraction_thresh >= 0 and BP_fraction_thresh <= 1
+ assert isinstance(max_iterations, (int, np.integer))
+ assert n_corr_init in (2, 3)
+ R_Nx = len(braggpeak_labels)
+ R_Ny = len(braggpeak_labels[0])
+
+ if n_corr_init == 2:
+ # Get two-point function
+ n_point_function = np.zeros((N, N))
+ for Rx in range(R_Nx):
+ for Ry in range(R_Ny):
+ s = braggpeak_labels[Rx][Ry]
+ perms = permutations(s, 2)
+ for perm in perms:
+ n_point_function[perm[0], perm[1]] += 1
+ n_point_function /= R_Nx * R_Ny
+
+ # Main loop
+ BP_sets = []
+ iteration = 0
+ unused_BPs = np.ones(N, dtype=bool)
+ seed_new_class = True
+ while seed_new_class:
+ ind1, ind2 = np.unravel_index(np.argmax(n_point_function), (N, N))
+ BP_set = set([ind1, ind2])
+ grow_class = True
+ while grow_class:
+ frequencies = np.zeros(N)
+ N_elements = 0
+ for Rx in range(R_Nx):
+ for Ry in range(R_Ny):
+ s = braggpeak_labels[Rx][Ry]
+ if BP_set.issubset(s):
+ N_elements += 1
+ for i in s:
+ frequencies[i] += 1
+ frequencies /= N_elements
+ for i in BP_set:
+ frequencies[i] = 0
+ ind_new = np.argmax(frequencies)
+ if frequencies[ind_new] > thresh:
+ BP_set.add(ind_new)
+ else:
+ grow_class = False
+
+ # Modify 2-point function, add new BP set to list, and decide to continue or stop
+ for i in BP_set:
+ n_point_function[i, :] = 0
+ n_point_function[:, i] = 0
+ unused_BPs[i] = 0
+ for s in BP_sets:
+ if len(s) == len(s.union(BP_set)):
+ seed_new_class = False
+ if seed_new_class is True:
+ BP_sets.append(BP_set)
+ iteration += 1
+ N_unused_BPs = np.sum(unused_BPs)
+ if iteration > max_iterations or N_unused_BPs < N * BP_fraction_thresh:
+ seed_new_class = False
+
+ else:
+ # Get three-point function
+ n_point_function = np.zeros((N, N, N))
+ for Rx in range(R_Nx):
+ for Ry in range(R_Ny):
+ s = braggpeak_labels[Rx][Ry]
+ perms = permutations(s, 3)
+ for perm in perms:
+ n_point_function[perm[0], perm[1], perm[2]] += 1
+ n_point_function /= R_Nx * R_Ny
+
+ # Main loop
+ BP_sets = []
+ iteration = 0
+ unused_BPs = np.ones(N, dtype=bool)
+ seed_new_class = True
+ while seed_new_class:
+ ind1, ind2, ind3 = np.unravel_index(np.argmax(n_point_function), (N, N, N))
+ BP_set = set([ind1, ind2, ind3])
+ grow_class = True
+ while grow_class:
+ frequencies = np.zeros(N)
+ N_elements = 0
+ for Rx in range(R_Nx):
+ for Ry in range(R_Ny):
+ s = braggpeak_labels[Rx][Ry]
+ if BP_set.issubset(s):
+ N_elements += 1
+ for i in s:
+ frequencies[i] += 1
+ frequencies /= N_elements
+ for i in BP_set:
+ frequencies[i] = 0
+ ind_new = np.argmax(frequencies)
+ if frequencies[ind_new] > thresh:
+ BP_set.add(ind_new)
+ else:
+ grow_class = False
+
+ # Modify 3-point function, add new BP set to list, and decide to continue or stop
+ for i in BP_set:
+ n_point_function[i, :, :] = 0
+ n_point_function[:, i, :] = 0
+ n_point_function[:, :, i] = 0
+ unused_BPs[i] = 0
+ for s in BP_sets:
+ if len(s) == len(s.union(BP_set)):
+ seed_new_class = False
+ if seed_new_class is True:
+ BP_sets.append(BP_set)
+ iteration += 1
+ N_unused_BPs = np.sum(unused_BPs)
+ if iteration > max_iterations or N_unused_BPs < N * BP_fraction_thresh:
+ seed_new_class = False
+
+ return BP_sets
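+
+
+if __name__ == "__main__":
+    # Editor's usage sketch (not part of the library API): two groups of
+    # Bragg peaks, {0,1,2} and {3,4,5}, each co-occurring in its own half of
+    # a 2x2 scan, should be recovered as two classes.
+    labels = [
+        [{0, 1, 2}, {0, 1, 2}],
+        [{3, 4, 5}, {3, 4, 5}],
+    ]
+    print(get_initial_classes(labels, N=6))  # expect [{0, 1, 2}, {3, 4, 5}]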
diff --git a/py4DSTEM/process/classification/classutils.py b/py4DSTEM/process/classification/classutils.py
new file mode 100644
index 000000000..51762a090
--- /dev/null
+++ b/py4DSTEM/process/classification/classutils.py
@@ -0,0 +1,186 @@
+# Utility functions for classification routines
+
+import numpy as np
+
+from emdfile import tqdmnd, PointListArray
+from py4DSTEM.datacube import DataCube
+from py4DSTEM.process.utils import get_shifted_ar
+
+
+def get_class_DP(
+ datacube,
+ class_image,
+ thresh=0.01,
+ xshifts=None,
+ yshifts=None,
+ darkref=None,
+ intshifts=True,
+):
+ """
+ Get the average diffraction pattern for the class described in real space by
+ class_image.
+
+ Args:
+ datacube (DataCube): a datacube
+ class_image (2D array): the weight of the class at each position in real space
+ thresh (float): only include diffraction patterns for scan positions with a value
+ greater than or equal to thresh in class_image
+ xshifts (2D array, or None): the x diffraction shifts at each real space pixel.
+ If None, no shifting is performed.
+ yshifts (2D array, or None): the y diffraction shifts at each real space pixel.
+ If None, no shifting is performed.
+ darkref (2D array, or None): background to remove from each diffraction pattern
+ intshifts (bool): if True, round shifts to the nearest integer to speed up
+ computation
+
+ Returns:
+ (2D array): the average diffraction pattern for the class
+ """
+ assert isinstance(datacube, DataCube)
+ assert class_image.shape == (datacube.R_Nx, datacube.R_Ny)
+ if xshifts is not None:
+ assert xshifts.shape == (datacube.R_Nx, datacube.R_Ny)
+ if yshifts is not None:
+ assert yshifts.shape == (datacube.R_Nx, datacube.R_Ny)
+ if darkref is not None:
+ assert darkref.shape == (datacube.Q_Nx, datacube.Q_Ny)
+ assert isinstance(intshifts, bool)
+
+ class_DP = np.zeros((datacube.Q_Nx, datacube.Q_Ny))
+ for Rx, Ry in tqdmnd(
+ datacube.R_Nx,
+ datacube.R_Ny,
+ desc="Computing class diffraction pattern",
+ unit="DP",
+ unit_scale=True,
+ ):
+ if class_image[Rx, Ry] >= thresh:
+ curr_DP = class_image[Rx, Ry] * datacube.data[Rx, Ry, :, :]
+ if xshifts is not None and yshifts is not None:
+ xshift = xshifts[Rx, Ry]
+ yshift = yshifts[Rx, Ry]
+ if intshifts is True:
+ xshift = int(np.round(xshift))
+ yshift = int(np.round(yshift))
+ curr_DP = np.roll(curr_DP, -xshift, axis=0)
+ curr_DP = np.roll(curr_DP, -yshift, axis=1)
+ else:
+ curr_DP = get_shifted_ar(curr_DP, -xshift, -yshift)
+ class_DP += curr_DP
+ if darkref is not None:
+ class_DP -= darkref * class_image[Rx, Ry]
+ class_DP /= np.sum(class_image[class_image >= thresh])
+ class_DP = np.where(class_DP > 0, class_DP, 0)
+ return class_DP
+
+
+def get_class_DP_without_Bragg_scattering(
+ datacube,
+ class_image,
+ braggpeaks,
+ radius,
+ x0,
+ y0,
+ thresh=0.01,
+ xshifts=None,
+ yshifts=None,
+ darkref=None,
+ intshifts=True,
+):
+ """
+ Get the average diffraction pattern, removing any Bragg scattering, for the class
+ described in real space by class_image.
+
+ Bragg scattering is eliminated by masking circles of size radius about each of the
+ detected peaks in braggpeaks in each diffraction pattern before adding to the average
+ image. Importantly, braggpeaks refers to the peak positions in the raw data - i.e.
+ BEFORE any shift correction is applied. Passing shifted Bragg peaks will yield
+ incorrect results. For speed, the Bragg peaks are removed with a binary mask, rather
+ than a continuous sigmoid, so selecting a radius that is slightly (~1 pix) larger
+ than the disk size is recommended.
+
+ Args:
+ datacube (DataCube): a datacube
+ class_image (2D array): the weight of the class at each position in real space
+ braggpeaks (PointListArray): the detected Bragg peak positions, with respect to
+ the raw data (i.e. not diffraction shift or ellipse corrected)
+ radius (number): the radius to mask about each detected Bragg peak - should be
+ slightly larger than the disk radius
+ x0 (number): x-position of the optic axis
+ y0 (number): y-position of the optic axis
+ thresh (float): only include diffraction patterns for scan positions with a value
+ greater than or equal to thresh in class_image
+ xshifts (2D array, or None): the x diffraction shifts at each real space pixel.
+ If None, no shifting is performed.
+ yshifts (2D array, or None): the y diffraction shifts at each real space pixel.
+ If None, no shifting is performed.
+ darkref (2D array, or None): background to remove from each diffraction pattern
+ intshifts (bool): if True, round shifts to the nearest integer to speed up
+ computation
+
+ Returns:
+        class_DP (2D array): the average diffraction pattern for the class
+ """
+ assert isinstance(datacube, DataCube)
+ assert class_image.shape == (datacube.R_Nx, datacube.R_Ny)
+ assert isinstance(braggpeaks, PointListArray)
+ if xshifts is not None:
+ assert xshifts.shape == (datacube.R_Nx, datacube.R_Ny)
+ if yshifts is not None:
+ assert yshifts.shape == (datacube.R_Nx, datacube.R_Ny)
+ if darkref is not None:
+ assert darkref.shape == (datacube.Q_Nx, datacube.Q_Ny)
+ assert isinstance(intshifts, bool)
+
+ class_DP = np.zeros((datacube.Q_Nx, datacube.Q_Ny))
+ mask_weights = np.zeros((datacube.Q_Nx, datacube.Q_Ny))
+ yy, xx = np.meshgrid(np.arange(datacube.Q_Ny), np.arange(datacube.Q_Nx))
+ for Rx, Ry in tqdmnd(
+ datacube.R_Nx,
+ datacube.R_Ny,
+ desc="Computing class diffraction pattern",
+ unit="DP",
+ unit_scale=True,
+ ):
+ weight = class_image[Rx, Ry]
+ if weight >= thresh:
+ braggpeaks_curr = braggpeaks.get_pointlist(Rx, Ry)
+ mask = np.ones((datacube.Q_Nx, datacube.Q_Ny))
+ if braggpeaks_curr.length > 1:
+ center_index = np.argmin(
+ np.hypot(
+ braggpeaks_curr.data["qx"] - x0, braggpeaks_curr.data["qy"] - y0
+ )
+ )
+ for i in range(braggpeaks_curr.length):
+ if i != center_index:
+ mask_ = (
+ (xx - braggpeaks_curr.data["qx"][i]) ** 2
+ + (yy - braggpeaks_curr.data["qy"][i]) ** 2
+ ) >= radius**2
+ mask = np.logical_and(mask, mask_)
+ curr_DP = datacube.data[Rx, Ry, :, :] * mask * weight
+ if xshifts is not None and yshifts is not None:
+ xshift = xshifts[Rx, Ry]
+ yshift = yshifts[Rx, Ry]
+ if intshifts:
+ xshift = int(np.round(xshift))
+ yshift = int(np.round(yshift))
+ curr_DP = np.roll(curr_DP, -xshift, axis=0)
+ curr_DP = np.roll(curr_DP, -yshift, axis=1)
+ mask = np.roll(mask, -xshift, axis=0)
+ mask = np.roll(mask, -yshift, axis=1)
+ else:
+ curr_DP = get_shifted_ar(curr_DP, -xshift, -yshift)
+ mask = get_shifted_ar(mask, -xshift, -yshift)
+ if darkref is not None:
+ curr_DP -= darkref * weight
+ class_DP += curr_DP
+ mask_weights += mask * weight
+ class_DP = np.divide(
+ class_DP,
+ mask_weights,
+ where=mask_weights != 0,
+ out=np.zeros((datacube.Q_Nx, datacube.Q_Ny)),
+ )
+ return class_DP
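+
+
+if __name__ == "__main__":
+    # Editor's usage sketch (not part of the library API): with a uniform
+    # class image and no shifts, the class-averaged DP is just the mean
+    # diffraction pattern of the datacube.
+    dc = DataCube(np.random.rand(4, 4, 8, 8))
+    class_DP = get_class_DP(dc, class_image=np.ones((4, 4)))
+    print(np.allclose(class_DP, dc.data.mean(axis=(0, 1))))  # expect True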
diff --git a/py4DSTEM/process/classification/featurization.py b/py4DSTEM/process/classification/featurization.py
new file mode 100644
index 000000000..b462ea1eb
--- /dev/null
+++ b/py4DSTEM/process/classification/featurization.py
@@ -0,0 +1,1003 @@
+import numpy as np
+from sklearn.preprocessing import MinMaxScaler, RobustScaler
+from sklearn.decomposition import NMF, PCA, FastICA
+from sklearn.mixture import GaussianMixture
+from sklearn.utils._testing import ignore_warnings
+from sklearn.exceptions import ConvergenceWarning
+from skimage.filters import threshold_otsu, threshold_yen
+from skimage.measure import label
+from skimage.morphology import closing, square, remove_small_objects
+
+from emdfile import tqdmnd
+
+
+class Featurization(object):
+ """
+    A class for feature selection, modification, and classification of 4D-STEM data based on a user defined
+    array of input features for each pattern. Features are stored under Featurization.features and can be
+    used for a variety of unsupervised classification tasks.
+
+ Initialization methods:
+ __init__:
+ Creates instance of featurization
+ concatenate_features:
+ Creates instance of featurization from a list of featurization instances
+ from_braggvectors:
+ Creates instance of featurization from a BraggVectors instance
+
+ Feature Dictionary Modification Methods
+ add_feature:
+ Adds features to the features array
+ remove_feature:
+        Removes features from the features array
+
+ Feature Preprocessing Methods
+ MinMaxScaler:
+ Performs sklearn MinMaxScaler operation on features stored at a key
+ RobustScaler:
+ Performs sklearn RobustScaler operation on features stored at a key
+ mean_feature:
+ Takes the rowwise average of a matrix stored at a key, such that only one column is left,
+ reducing a set of n features down to 1 feature per pattern.
+ median_feature:
+ Takes the rowwise median of a matrix stored at a key, such that only one column is left,
+ reducing a set of n features down to 1 feature per pattern.
+ max_feature:
+ Takes the rowwise max of a matrix stored at a key, such that only one column is left,
+ reducing a set of n features down to 1 feature per pattern.
+
+ Classification Methods
+ PCA:
+ Principal Component Analysis to refine features.
+ ICA:
+ Independent Component Analysis to refine features.
+ NMF:
+ Performs either traditional or iterative Nonnegative Matrix Factorization (NMF) to refine features.
+ GMM:
+ Gaussian mixture model to predict class labels. Fits a gaussian based on covariance of features.
+
+ Class Examination Methods
+ get_class_DPs:
+ Gets weighted class diffraction patterns (DPs) for an NMF or GMM operation
+ get_class_ims:
+ Gets weighted class images (ims) for an NMF or GMM operation
+ """
+
+ def __init__(self, features, R_Nx, R_Ny, name):
+ """
+ Initializes classification instance.
+
+ This method:
+ 1. Generates key:value pair to access input features
+ 2. Initializes the empty dictionaries for feature modification and classification
+
+ Args:
+ features (list): A list of ndarrays which will each be associated with value stored at the key in the same index within the list
+ R_Nx (int): The real space x dimension of the dataset
+ R_Ny (int): The real space y dimension of the dataset
+ name (str): The name of the featurization object
+
+ Returns:
+ new_instance: New Featurization instance
+ """
+ self.R_Nx = R_Nx
+ self.R_Ny = R_Ny
+ self.name = name
+
+ if isinstance(features, np.ndarray):
+ if len(features.shape) == 3:
+ self.features = features.reshape(R_Nx * R_Ny, features.shape[-1])
+ elif len(features.shape) == 2:
+ self.features = features
+ else:
+ raise ValueError(
+ "feature array must be of dimensions (R_Nx*R_Ny, num_features) or (R_Nx, R_Ny, num_features)"
+ )
+ elif isinstance(features, list):
+ if all(isinstance(f, np.ndarray) for f in features):
+                for i in range(len(features)):
+                    # compare the number of dimensions (a shape tuple can
+                    # never equal 3) and use this array's own feature count
+                    if len(features[i].shape) == 3:
+                        features[i] = features[i].reshape(
+                            R_Nx * R_Ny, features[i].shape[-1]
+                        )
+ if len(features[i].shape) != 2:
+ raise ValueError(
+ "feature array(s) in list must be of dimensions (R_Nx*R_Ny, num_features) or (R_Nx, R_Ny, num_features)"
+ )
+ self.features = np.concatenate(features, axis=1)
+ elif all(isinstance(f, Featurization) for f in features):
+ raise TypeError(
+ "List of Featurization instances must be initialized using the concatenate_features method."
+ )
+ else:
+ raise TypeError(
+ "Entries in list must be np.ndarrays for initialization of the Featurization instance."
+ )
+ else:
+ raise TypeError(
+ "Features must be either a single np.ndarray of shape 2 or 3 or a list of np.ndarrays or featurization instances."
+ )
+ return
+
+ def from_braggvectors(
+ braggvectors,
+ bins_x,
+ bins_y,
+ intensity_scale,
+ name,
+ mask=None,
+ ):
+ """
+ Initialize a featurization instance from a BraggVectors instance
+
+ Args:
+ braggvectors (BraggVectors): BraggVectors instance containing calibrations
+ bins_x (int): Number of pixels per bin in x direction
+ bins_y (int): Number of pixels per bin in y direction
+ intensity_scale (float): Value to scale intensity of detected disks by
+ name (str): Name of featurization instance
+            mask (np.ndarray of bool): Optional mask to remove disks in unwanted positions in diffraction space
+
+ Returns:
+ new_instance: Featurization instance
+
+ Details:
+ Transforms the calibrated pointlistarray in BraggVectors instance into a numpy array
+ that can be clustered using the methods in featurization.
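+
+        Example:
+            A minimal sketch (assumes a calibrated BraggVectors instance named
+            braggvectors; bin sizes and intensity scaling are illustrative):
+
+                >>> f = Featurization.from_braggvectors(
+                ...     braggvectors, bins_x=4, bins_y=4, intensity_scale=1.0, name="bragg"
+                ... )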
+ """
+
+ Q_Nx, Q_Ny = braggvectors.Qshape[0], braggvectors.Qshape[1]
+ nx_bins, ny_bins = int(np.ceil(Q_Nx / bins_x)), int(np.ceil(Q_Ny / bins_y))
+ n_bins = nx_bins * ny_bins
+
+ try:
+ pointlistarray = braggvectors._v_cal.copy()
+ except AttributeError:
+ er = "No calibrated bragg vectors found. Try running .calibrate()!"
+ raise Exception(er)
+ try:
+ q_pixel_size = braggvectors.calibration.get_Q_pixel_size()
+ except AttributeError:
+ er = "No q_pixel_size found. Please set value and recalibrate before continuing."
+ raise Exception(er)
+
+ peak_data = np.zeros((pointlistarray.shape[0], pointlistarray.shape[1], n_bins))
+
+ # Create Bragg Disk Features
+ for Rx, Ry in tqdmnd(pointlistarray.shape[0], pointlistarray.shape[1]):
+ pointlist = pointlistarray.get_pointlist(Rx, Ry)
+ if pointlist.data.shape[0] == 0:
+ continue
+
+            if mask is not None:
+                # remove any peaks landing on masked-out detector pixels
+                inds_x = np.rint(
+                    (pointlist.data["qx"] / q_pixel_size) + Q_Nx / 2
+                ).astype(int)
+                inds_y = np.rint(
+                    (pointlist.data["qy"] / q_pixel_size) + Q_Ny / 2
+                ).astype(int)
+                deletemask = np.logical_not(mask[inds_x, inds_y])
+                pointlist.remove(deletemask)
+
+ for i in range(pointlist.data.shape[0]):
+ floor_x = np.rint(
+ (pointlist.data[i][0] / q_pixel_size + Q_Nx / 2) / bins_x
+ )
+ floor_y = np.rint(
+ (pointlist.data[i][1] / q_pixel_size + Q_Ny / 2) / bins_y
+ )
+ binval_ff = int((floor_x * ny_bins) + floor_y)
+ binval_cf = int(((floor_x + 1) * ny_bins) + floor_y)
+
+ # Distribute Peaks
+ if intensity_scale == 0:
+ try:
+ peak_data[Rx, Ry, binval_ff] += 1
+ peak_data[Rx, Ry, binval_ff + 1] += 1
+ peak_data[Rx, Ry, binval_cf] += 1
+ peak_data[Rx, Ry, binval_cf + 1] += 1
+ except IndexError:
+ continue
+ else:
+ try:
+ peak_data[Rx, Ry, binval_ff] += (
+ pointlist.data[i][2] * intensity_scale
+ )
+ peak_data[Rx, Ry, binval_ff + 1] += (
+ pointlist.data[i][2] * intensity_scale
+ )
+ peak_data[Rx, Ry, binval_cf] += (
+ pointlist.data[i][2] * intensity_scale
+ )
+ peak_data[Rx, Ry, binval_cf + 1] += (
+ pointlist.data[i][2] * intensity_scale
+ )
+ except IndexError:
+ continue
+
+        peak_data = peak_data.reshape(
+            pointlistarray.shape[0] * pointlistarray.shape[1], n_bins
+        )
+ new_instance = Featurization(
+ peak_data, pointlistarray.shape[0], pointlistarray.shape[1], name
+ )
+ return new_instance
+
+ def concatenate_features(features, name):
+ """
+ Concatenates featurization instances (features) and outputs a new Featurization instance
+ containing the concatenated features from each featurization instance. R_Nx, R_Ny will be
+ inherited from the featurization instances and must be consistent across objects.
+
+ Args:
+            features (list): A list of Featurization instances whose feature arrays will be concatenated
+ name (str): The name of the featurization instance
+
+ Returns:
+ new_instance: Featurization instance
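+
+        Example:
+            A minimal sketch (assumes f1 and f2 are Featurization instances
+            sharing the same R_Nx and R_Ny):
+
+                >>> combined = Featurization.concatenate_features([f1, f2], name="combined")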
+ """
+ R_Nxs = [features[i].R_Nx for i in range(len(features))]
+ R_Nys = [features[i].R_Ny for i in range(len(features))]
+ if len(np.unique(R_Nxs)) != 1 or len(np.unique(R_Nys)) != 1:
+ raise ValueError(
+ "Can only concatenate Featurization instances with same R_Nx and R_Ny"
+ )
+ new_instance = Featurization(
+ np.concatenate(
+ [features[i].features for i in range(len(features))], axis=1
+ ),
+ R_Nx=R_Nxs[0],
+ R_Ny=R_Nys[0],
+ name=name,
+ )
+ return new_instance
+
+ def add_features(self, feature):
+ """
+ Add a feature to the end of the features array
+
+        Args:
+            feature (ndarray): Feature column(s) to append, with R_Nx * R_Ny rows
+        """
+        self.features = np.concatenate([self.features, feature], axis=1)
+ return
+
+ def delete_features(self, index):
+ """
+ Deletes feature columns from the feature array
+
+ Args:
+            index (int or list): Column index or indices to remove from the feature array
+ """
+ self.features = np.delete(self.features, index, axis=1)
+ return
+
+ def mean_feature(self, index):
+ """
+        Takes the rowwise mean over the feature columns in 'index' and replaces them
+        with the single averaged column, which is appended as the final column of self.features.
+
+        Args:
+            index (list of int): Indices of feature columns to take the mean of.
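+
+        Example:
+            Replace feature columns 0-2 with their per-pattern mean:
+
+                >>> f.mean_feature([0, 1, 2])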
+ """
+ mean_features = np.mean(self.features[:, index], axis=1)
+ mean_features = mean_features.reshape(mean_features.shape[0], 1)
+ cleaned_features = np.delete(self.features, index, axis=1)
+ self.features = np.concatenate([cleaned_features, mean_features], axis=1)
+ return
+
+ def median_feature(self, index):
+ """
+        Takes the rowwise median over the feature columns in 'index' and replaces them
+        with the single median column, which is appended as the final column of self.features.
+
+        Args:
+            index (list of int): Indices of feature columns to take the median of.
+ """
+ median_features = np.median(self.features[:, index], axis=1)
+ median_features = median_features.reshape(median_features.shape[0], 1)
+ cleaned_features = np.delete(self.features, index, axis=1)
+ self.features = np.concatenate([cleaned_features, median_features], axis=1)
+ return
+
+ def max_feature(self, index):
+ """
+        Takes the rowwise max over the feature columns in 'index' and replaces them
+        with the single max column, which is appended as the final column of self.features.
+
+        Args:
+            index (list of int): Indices of feature columns to take the max of.
+ """
+ max_features = np.max(self.features[:, index], axis=1)
+ max_features = max_features.reshape(max_features.shape[0], 1)
+ cleaned_features = np.delete(self.features, index, axis=1)
+ self.features = np.concatenate([cleaned_features, max_features], axis=1)
+ return
+
+ def MinMaxScaler(self, return_scaled=True):
+ """
+        Uses the sklearn MinMaxScaler to rescale all features to the range [0, 1].
+        Replaces self.features with the scaled array.
+
+        Args:
+            return_scaled (bool): Whether to return the scaled feature array
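+
+        Example:
+            A minimal sketch, rescaling all features to [0, 1]:
+
+                >>> scaled = f.MinMaxScaler()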
+ """
+ mms = MinMaxScaler()
+ self.features = mms.fit_transform(self.features)
+ if return_scaled is True:
+ return self.features
+ else:
+ return
+
+ def RobustScaler(self, return_scaled=True):
+ """
+        Uses the sklearn RobustScaler to scale all features using statistics that are
+        robust to outliers. Replaces self.features with the scaled array.
+
+        Args:
+            return_scaled (bool): Whether to return the scaled feature array
+ """
+ rs = RobustScaler()
+ self.features = rs.fit_transform(self.features)
+ if return_scaled is True:
+ return self.features
+ else:
+ return
+
+ def shift_positive(self, return_scaled=True):
+ """
+        Shifts the entire feature array up by the absolute value of its minimum,
+        so that all features are nonnegative.
+
+        Args:
+            return_scaled (bool): Whether to return the shifted feature array
+ """
+ self.features += np.abs(self.features.min())
+ if return_scaled is True:
+ return self.features
+ else:
+ return
+
+ def PCA(self, components, return_results=False):
+ """
+        Performs PCA on the feature array, storing the result in self.pca.
+
+        Args:
+            components (int): Number of principal components to keep
+            return_results (bool): Whether to return the PCA-transformed features
+ """
+ pca = PCA(n_components=components)
+ self.pca = pca.fit_transform(self.features)
+ if return_results is True:
+ return self.pca
+ return
+
+ def ICA(self, components, return_results=True):
+ """
+        Performs ICA on the feature array, storing the result in self.ica.
+
+        Args:
+            components (int): Number of independent components to keep
+            return_results (bool): Whether to return the ICA-transformed features
+ """
+ ica = FastICA(n_components=components)
+ self.ica = ica.fit_transform(self.features)
+ if return_results is True:
+ return self.ica
+ return
+
+ def NMF(
+ self,
+ max_components,
+ num_models,
+ merge_thresh=1,
+ max_iterations=1,
+ random_seed=None,
+ save_all_models=True,
+ return_results=False,
+ ):
+ """
+        Performs either traditional or iterative Nonnegative Matrix Factorization (NMF) on the input features.
+        For traditional NMF:
+            set either merge_thresh = 1, max_iterations = 1, or both (these are the defaults).
+
+ Args:
+ max_components (int): Number of initial components to start the first NMF iteration
+ merge_thresh (float): Correlation threshold to merge features
+ num_models (int): Number of independent models to run (number of learners that will be combined in consensus).
+ max_iterations (int): Number of iterations. Default 1, which runs traditional NMF
+ random_seed (int): Random seed.
+            save_all_models (bool): Whether or not to save all of the models - default is to save all outputs for consensus clustering.
+                If False, only the model with the lowest NMF reconstruction error is returned.
+ return_results (bool): Whether or not to return the final class weights
+
+ Details:
+ This method may require trial and error for proper selection of parameters. To perform traditional NMF, the
+ defaults should be used:
+ merge_thresh = 1
+ max_iterations = 1
+ Note that the max_components in this case will be equivalent to the number of classes the NMF model identifies.
+
+ Iterative NMF calculates the correlation between all of the output columns from an NMF iteration, merges the
+ features correlated above the merge_thresh, and performs NMF until either max_iterations is reached or until
+ no more columns are correlated above merge_thresh.
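+
+        Example:
+            A minimal sketch of both modes (parameter values are illustrative):
+
+                >>> f.NMF(max_components=6, num_models=5)  # traditional NMF
+                >>> f.NMF(max_components=12, num_models=5, merge_thresh=0.8, max_iterations=4)  # iterative NMF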
+ """
+ self.W = _nmf_single(
+ self.features,
+ max_components=max_components,
+ merge_thresh=merge_thresh,
+ num_models=num_models,
+ max_iterations=max_iterations,
+ random_seed=random_seed,
+ save_all_models=save_all_models,
+ )
+ if return_results is True:
+ return self.W
+ return
+
+ def GMM(self, cv, components, num_models, random_seed=None, return_results=False):
+ """
+        Fits Gaussian mixture models to the input features.
+
+        Args:
+            cv (list of str): Covariance types to fit - each must be 'spherical', 'tied', 'diag', or 'full'
+            components (list of int): Numbers of components to fit
+            num_models (int): Number of models to run
+            random_seed (int): Random seed
+            return_results (bool): Whether to return the fitted models
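+
+        Example:
+            A minimal sketch (parameter values are illustrative):
+
+                >>> f.GMM(cv=["full"], components=[4], num_models=10)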
+ """
+ self.gmm, self.gmm_labels, self.gmm_proba = _gmm_single(
+ self.features,
+ cv=cv,
+ components=components,
+ num_models=num_models,
+ random_seed=random_seed,
+ )
+ if return_results is True:
+ return self.gmm
+ return
+
+ def get_class_DPs(self, datacube, method, thresh):
+ """
+        Returns weighted class diffraction patterns based on the classification results.
+        The datacube is temporarily reshaped to (R_Nx * R_Ny, Q_Nx, Q_Ny) in real space.
+
+        Args:
+            datacube (py4DSTEM datacube): Datacube whose real space dimensions match those
+                of the Featurization instance
+            method (str): Classification results to use - 'nmf', 'gmm', 'pca',
+                'spatially_separated_ims', or 'consensus_clusters'
+            thresh (float): Minimum class weight for a pattern to contribute to the
+                weighted class pattern
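+
+        Example:
+            A minimal sketch (assumes an NMF model has already been run and a
+            matching datacube is available):
+
+                >>> f.get_class_DPs(datacube, method="nmf", thresh=0.1)
+                >>> class_patterns = f.class_DPs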
+ """
+ class_patterns = []
+ datacube_shape = datacube.data.shape
+ if len(datacube.data.shape) != 3:
+ try:
+ datacube.data = datacube.data.reshape(
+ self.R_Nx * self.R_Ny,
+ datacube.data.shape[2],
+ datacube.data.shape[3],
+ )
+            except ValueError:
+ raise ValueError(
+ "Datacube must have same R_Nx and R_Ny dimensions as Featurization instance."
+ )
+ if method == "nmf":
+ if self.W == list:
+ return ValueError(
+ "Method not implmented for multiple NMF models, either return 1 model or perform spatial separation first."
+ )
+ for l in range(self.W.shape[1]):
+ class_pattern = np.zeros(
+ (datacube.data.shape[1], datacube.data.shape[2])
+ )
+ x_ = np.where(self.W[:, l] > thresh)[0]
+ for x in range(x_.shape[0]):
+ class_pattern += datacube.data[x_[x], 0] * self.W[x_[x], l]
+ class_patterns.append(class_pattern / np.sum(self.W[x_, l]))
+ elif method == "gmm":
+ if self.gmm_labels == list:
+ return ValueError(
+ "Method not implmented for multiple GMM models, either return 1 model or perform spatial separation first."
+ )
+ for l in range(np.max(self.gmm_labels)):
+ class_pattern = np.zeros(
+ (datacube.data.shape[1], datacube.data.shape[2])
+ )
+ x_ = np.where(self.gmm_proba[:, l] > thresh)[0]
+ for x in range(x_.shape[0]):
+ class_pattern += datacube.data[x_[x], 0] * self.gmm_proba[x_[x], l]
+ class_patterns.append(class_pattern / np.sum(self.gmm_proba[x_, l]))
+ elif method == "pca":
+ for l in range(self.pca.shape[1]):
+ class_pattern = np.zeros(
+ (datacube.data.shape[1], datacube.data.shape[2])
+ )
+ x_ = np.where(self.pca[:, l] > thresh)[0]
+ for x in range(x_.shape[0]):
+ class_pattern += datacube.data[x_[x]] * self.pca[x_[x], l]
+ class_patterns.append(class_pattern / np.sum(self.pca[x_, l]))
+ class_patterns = [class_patterns]
+ elif method == "spatially_separated_ims":
+ for l in range(len(self.spatially_separated_ims)):
+ small_class_patterns = []
+ for j in range(len(self.spatially_separated_ims[l])):
+ class_pattern = np.zeros(
+ (datacube.data.shape[1], datacube.data.shape[2])
+ )
+ x_ = np.where(
+ self.spatially_separated_ims[l][j].reshape(
+ self.R_Nx * self.R_Ny, 1
+ )
+ > thresh
+ )[0]
+ for x in range(x_.shape[0]):
+ class_pattern += (
+ datacube.data[x_[x]]
+ * self.spatially_separated_ims[l][j].reshape(
+ self.R_Nx * self.R_Ny, 1
+ )[x_[x]]
+ )
+ small_class_patterns.append(
+ class_pattern
+ / np.sum(
+ self.spatially_separated_ims[l][j].reshape(
+ self.R_Nx * self.R_Ny, 1
+ )[x_]
+ )
+ )
+ class_patterns.append(small_class_patterns)
+ elif method == "consensus_clusters":
+ for j in range(len(self.consensus_clusters)):
+ class_pattern = np.zeros(
+ (datacube.data.shape[1], datacube.data.shape[2])
+ )
+ x_ = np.where(
+ self.consensus_clusters[j].reshape(self.R_Nx * self.R_Ny, 1)
+ > thresh
+ )[0]
+ for x in range(x_.shape[0]):
+ class_pattern += (
+ datacube.data[x_[x]]
+ * self.consensus_clusters[j].reshape(self.R_Nx * self.R_Ny, 1)[
+ x_[x]
+ ]
+ )
+ class_patterns.append(
+ class_pattern
+ / np.sum(
+ self.consensus_clusters[j].reshape(self.R_Nx * self.R_Ny, 1)[x_]
+ )
+ )
+ class_patterns = [class_patterns]
+ else:
+            raise ValueError(
+                "method not accepted. Try 'nmf', 'gmm', 'pca', 'spatially_separated_ims', or 'consensus_clusters'."
+            )
+ datacube.data = datacube.data.reshape(datacube_shape)
+ self.class_DPs = class_patterns
+ return
+
+ def get_class_ims(self, classification_method):
+ """
+ Returns weighted class maps based on classification instance
+
+ Args:
+            classification_method (str): Classification results to retrieve class images from - 'NMF', 'GMM', 'PCA', or 'ICA'
+ """
+ class_maps = []
+ if classification_method == "NMF":
+ if type(self.W) == list:
+ for l in range(len(self.W)):
+ small_class_maps = []
+ for k in range(self.W[l].shape[1]):
+ small_class_maps.append(
+ self.W[l][:, k].reshape(self.R_Nx, self.R_Ny)
+ )
+ class_maps.append(small_class_maps)
+ else:
+ for l in range(self.W.shape[1]):
+ class_maps.append(self.W[:, l].reshape(self.R_Nx, self.R_Ny))
+ class_maps = [class_maps]
+ elif classification_method == "GMM":
+ if type(self.gmm_labels) == list:
+ for l in range(len(self.gmm_labels)):
+ small_class_maps = []
+                    for k in range(np.max(self.gmm_labels[l]) + 1):
+ R_vals = np.where(
+ self.gmm_labels[l].reshape(self.R_Nx, self.R_Ny) == k, 1, 0
+ )
+ small_class_maps.append(
+ R_vals
+ * self.gmm_proba[l][:, k].reshape(self.R_Nx, self.R_Ny)
+ )
+ class_maps.append(small_class_maps)
+ else:
+                for l in range(np.max(self.gmm_labels) + 1):
+                    R_vals = np.where(
+                        self.gmm_labels.reshape(self.R_Nx, self.R_Ny) == l, 1, 0
+                    )
+ class_maps.append(
+ R_vals * self.gmm_proba[:, l].reshape(self.R_Nx, self.R_Ny)
+ )
+ class_maps = [class_maps]
+ elif classification_method == "PCA":
+ for i in range(self.pca.shape[1]):
+ class_maps.append(self.pca[:, i].reshape(self.R_Nx, self.R_Ny))
+ class_maps = [class_maps]
+ elif classification_method == "ICA":
+ for i in range(self.ica.shape[1]):
+ class_maps.append(self.ica[:, i].reshape(self.R_Nx, self.R_Ny))
+ class_maps = [class_maps]
+ else:
+ raise ValueError(
+ "classification_method not accepted. Try NMF, GMM, PCA, or ICA."
+ )
+ self.class_ims = class_maps
+ return
+
+ def spatial_separation(self, size, threshold=0, method=None, clean=True):
+ """
+ Identify spatially distinct regions from class images and separate based on a threshold and size.
+
+ Args:
+            size (int): Minimum number of pixels to keep a region - all spatially distinct regions with
+                fewer than 'size' pixels will be removed
+            threshold (float): Minimum intensity weight for a pixel to be included in a region
+ method (str): (Optional) Filter method, default None. Accepts options 'yen' and 'otsu'.
+ clean (bool): Whether or not to 'clean' cluster sets based on overlap, i.e. remove clusters that do not have
+ any unique components
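+
+        Example:
+            A minimal sketch (size and threshold values are illustrative; assumes
+            class images have been computed):
+
+                >>> f.get_class_ims("NMF")
+                >>> f.spatial_separation(size=20, threshold=0.05, method="yen")
+                >>> separated = f.spatially_separated_ims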
+ """
+ # Prepare for separation
+ labelled = []
+ stacked = []
+
+ # Loop through all models
+ for j in range(len(self.class_ims)):
+ separated_temp = []
+
+            # Loop through class images in each model to filter and separate class images
+ for l in range(len(self.class_ims[j])):
+ image = np.where(
+ self.class_ims[j][l] > threshold, self.class_ims[j][l], 0
+ )
+ if method == "yen":
+ t = threshold_yen(image)
+ bw = closing(image > t, square(2))
+ labelled_image = label(bw)
+ if np.sum(labelled_image) > size:
+ large_labelled_image = remove_small_objects(
+ labelled_image, size
+ )
+ else:
+ large_labelled_image = labelled_image
+ elif method == "otsu":
+ t = threshold_otsu(image)
+ bw = closing(image > t, square(2))
+ labelled_image = label(bw)
+ if np.sum(labelled_image) > size:
+ large_labelled_image = remove_small_objects(
+ labelled_image, size
+ )
+ else:
+ large_labelled_image = labelled_image
+ elif method is None:
+ labelled_image = label(image)
+ if np.sum(labelled_image) > size:
+ large_labelled_image = remove_small_objects(
+ labelled_image, size
+ )
+ else:
+ large_labelled_image = labelled_image
+
+ else:
+ raise ValueError(
+ method
+ + " method is not supported. Please use yen, otsu, or None instead."
+ )
+ unique_labels = np.unique(large_labelled_image)
+ separated_temp.extend(
+ [
+ (
+ np.where(
+ large_labelled_image == unique_labels[k + 1], image, 0
+ )
+ )
+ for k in range(len(unique_labels) - 1)
+ ]
+ )
+
+ if len(separated_temp) > 0:
+ if clean is True:
+ data_ndarray = np.dstack(separated_temp)
+ data_hard = (
+ data_ndarray.max(axis=2, keepdims=1) == data_ndarray
+ ) * data_ndarray
+ data_list = [
+ data_ndarray[:, :, x] for x in range(data_ndarray.shape[2])
+ ]
+ data_list_hard = [
+ np.where(data_hard[:, :, n] > threshold, 1, 0)
+ for n in range(data_hard.shape[2])
+ ]
+ labelled.append(
+ [
+ data_list[n]
+ for n in range(len(data_list_hard))
+ if (np.sum(data_list_hard[n]) > size)
+ ]
+ )
+ else:
+ labelled.append(separated_temp)
+ else:
+ continue
+
+ if len(labelled) > 0:
+ self.spatially_separated_ims = labelled
+ else:
+ raise ValueError(
+ "No distinct regions found in any models. Try modifying threshold, size, or method."
+ )
+
+ return
+
+ def consensus(
+ self,
+ threshold=0,
+ location="spatially_separated_ims",
+ split=0,
+ method="mean",
+ drop_bins=0,
+ ):
+ """
+ Consensus Clustering takes the outcome of a prepared set of 2D images from each cluster and averages the outcomes.
+
+ Args:
+            threshold (float): Weights below this value are zeroed, default 0
+            location (str): Which results to build the consensus from - currently only
+                'spatially_separated_ims' (i.e. after spatial separation) is supported
+            split (float): Threshold at which to start a new class bin during label correspondence
+                (default 0). This should be proportional to the expected class weights - the sum of the
+                weights in the current class image that match nonzero values in each bin is calculated
+                and then checked against this value for splitting.
+            method (str): Method for combining the consensus clusters - either 'mean' or 'median'.
+            drop_bins (int): Minimum number of clusters a bin must contain to be kept in the
+                consensus. Default 0, meaning no bins are dropped.
+
+ Details:
+ This method involves 2 steps: Label correspondence and consensus clustering.
+
+ Label correspondence sorts the classes found by the independent models into bins based on class overlap in real space.
+            Arguments related to label correspondence are threshold and split. The threshold is related
+            to the weights of the independent classes. If the weight of the observation in the class is less than the threshold, it
+            will be set to 0. The split value indicates the extent of similarity the independent classes must have before initializing
+            a new bin. The default is 0 - this means if the class of interest has 0 overlap with the identified bins, a new bin will
+            be created. The value is based on the sum of the weights in the current class image that match the nonzero values in the
+            current bins.
+
+            Consensus clustering combines these sorted bins into one class each, using the selected method (either 'mean', which takes
+            the average of the bin, or 'median', which takes the median of the bin). Bins containing drop_bins or fewer clusters will
+            not be included in the final results.
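+
+        Example:
+            A minimal sketch of the full consensus workflow (values are illustrative):
+
+                >>> f.spatial_separation(size=20, threshold=0.05)
+                >>> f.consensus(threshold=0.05, split=0, method="mean", drop_bins=2)
+                >>> clusters = f.consensus_clusters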
+ """
+ # Set up for consensus clustering
+ class_dict = {}
+ consensus_clusters = []
+
+ if location != "spatially_separated_ims":
+ raise ValueError(
+ "Consensus clustering only supported for location = spatially_separated_ims."
+ )
+
+ # Find model with largest number of clusters for label correspondence
+ ncluster = [
+ len(self.spatially_separated_ims[j])
+ for j in range(len(self.spatially_separated_ims))
+ ]
+ max_cluster_ind = np.where(ncluster == np.max(ncluster))[0][0]
+
+ # Label Correspondence
+ for k in range(len(self.spatially_separated_ims[max_cluster_ind])):
+ class_dict["c" + str(k)] = [
+ np.where(
+ self.spatially_separated_ims[max_cluster_ind][k] > threshold,
+ self.spatially_separated_ims[max_cluster_ind][k],
+ 0,
+ )
+ ]
+ for j in range(len(self.spatially_separated_ims)):
+ if j == max_cluster_ind:
+ continue
+ for m in range(len(self.spatially_separated_ims[j])):
+ class_im = np.where(
+ self.spatially_separated_ims[j][m] > threshold,
+ self.spatially_separated_ims[j][m],
+ 0,
+ )
+ best_sum = -np.inf
+ for l in range(len(class_dict.keys())):
+ current_sum = np.sum(
+ np.where(class_dict["c" + str(l)][0] > threshold, class_im, 0)
+ )
+ if current_sum >= best_sum:
+ best_sum = current_sum
+ cvalue = l
+ if best_sum > split:
+ class_dict["c" + str(cvalue)].append(class_im)
+ else:
+ class_dict["c" + str(len(list(class_dict.keys())))] = [class_im]
+ key_list = list(class_dict.keys())
+
+ # Consensus clustering
+ if method == "mean":
+ for n in range(len(key_list)):
+ if drop_bins > 0:
+ if len(class_dict[key_list[n]]) <= drop_bins:
+ continue
+ consensus_clusters.append(
+ np.mean(np.dstack(class_dict[key_list[n]]), axis=2)
+ )
+ elif method == "median":
+ for n in range(len(key_list)):
+ if drop_bins > 0:
+ if len(class_dict[key_list[n]]) <= drop_bins:
+ continue
+ consensus_clusters.append(
+ np.median(np.dstack(class_dict[key_list[n]]), axis=2)
+ )
+ else:
+ raise ValueError(
+ "Only mean and median consensus methods currently supported."
+ )
+ self.consensus_dict = class_dict
+ self.consensus_clusters = consensus_clusters
+
+ return
+
+
+@ignore_warnings(category=ConvergenceWarning)
+def _nmf_single(
+ x,
+ max_components,
+ merge_thresh,
+ num_models,
+ max_iterations,
+ random_seed=None,
+ save_all_models=True,
+):
+ """
+ Performs NMF on single feature matrix, which is an nd.array
+
+ Args:
+ x (np.ndarray): Feature array
+ max_components (int): Number of initial components to start the first NMF iteration
+ merge_thresh (float): Correlation threshold to merge features
+ num_models (int): Number of independent models to run (number of learners that will be combined in consensus)
+        max_iterations (int): Number of iterations. Default 1, which runs traditional NMF
+ random_seed (int): Random seed
+ save_all_models (bool): Whether or not to return all of the models - default is to save
+ all outputs for consensus clustering
+ """
+ # Prepare error, random seed
+ err = np.inf
+ if random_seed is None:
+ rng = np.random.RandomState(seed=42)
+ else:
+ seed = random_seed
+ if save_all_models is True:
+ W = []
+
+ # Big loop through all models
+ for i in range(num_models):
+ if random_seed is None:
+ seed = rng.randint(5000)
+ n_comps = max_components
+ recon_error, counter = 0, 0
+ Hs, Ws = [], []
+
+ # Inner loop for iterative NMF
+ for z in range(max_iterations):
+ nmf = NMF(n_components=n_comps, random_state=seed)
+
+ if counter == 0:
+ nmf_temp = nmf.fit_transform(x)
+ else:
+ with np.errstate(invalid="raise", divide="raise"):
+ try:
+ nmf_temp_2 = nmf.fit_transform(nmf_temp)
+ except FloatingPointError:
+ print("Warning encountered in NMF: Returning last result")
+ break
+ Ws.append(nmf_temp)
+ Hs.append(np.transpose(nmf.components_))
+ recon_error += nmf.reconstruction_err_
+ counter += 1
+ if counter >= max_iterations:
+ break
+ elif counter > 1:
+ with np.errstate(invalid="raise", divide="raise"):
+ try:
+ tril = np.tril(np.corrcoef(nmf_temp_2, rowvar=False), k=-1)
+ nmf_temp = nmf_temp_2
+ except FloatingPointError:
+ print(
+ "Warning encountered in correlation: Returning last result. Try larger merge_thresh."
+ )
+ break
+ else:
+ tril = np.tril(np.corrcoef(nmf_temp, rowvar=False), k=-1)
+
+ # Merge correlated features
+ if np.nanmax(tril) >= merge_thresh:
+ inds = np.argwhere(tril >= merge_thresh)
+                for n in range(inds.shape[0]):
+                    nmf_temp[:, inds[n, 0]] += nmf_temp[:, inds[n, 1]]
+                # delete merged-in columns, highest index first so earlier
+                # deletions do not shift the indices of later ones
+                ys_sorted = np.sort(np.unique(inds[:, 1]))[::-1]
+                for n in range(ys_sorted.shape[0]):
+                    nmf_temp = np.delete(nmf_temp, ys_sorted[n], axis=1)
+ else:
+ break
+ n_comps = nmf_temp.shape[1] - 1
+ if n_comps <= 2:
+ break
+
+ if save_all_models is True:
+ W.append(nmf_temp)
+
+ elif (recon_error / counter) < err:
+ err = recon_error / counter
+ W = nmf_temp
+ return W
+
+
+@ignore_warnings(category=ConvergenceWarning)
+def _gmm_single(x, cv, components, num_models, random_seed=None, return_all=True):
+ """
+ Runs GMM several times and saves value with best BIC score
+
+ Args:
+ x (np.ndarray): Data
+ cv (list of str): Covariance, must be 'spherical', 'tied', 'diag', or 'full'
+ components (list of ints): Number of output clusters
+        num_models (int): Number of models to run; all are returned if return_all is True, otherwise only the best
+ random_seed (int): Random seed
+ return_all (bool): Whether or not to return all models.
+
+ Returns:
+        gmm_list OR best_gmm: All fitted GaussianMixture models, or the model with the lowest BIC
+        gmm_labels OR best_gmm_labels: Label arrays for all models, or labels for the best model
+        gmm_proba OR best_gmm_proba: Class membership probabilities for all models, or for the best model
+ """
+ if return_all is True:
+ gmm_list = []
+ gmm_labels = []
+ gmm_proba = []
+    lowest_bic = np.inf
+ bic_temp = 0
+ if random_seed is None:
+ rng = np.random.RandomState(seed=42)
+ else:
+ seed = random_seed
+ for n in range(num_models):
+ if random_seed is None:
+ seed = rng.randint(5000)
+ for j in range(len(components)):
+ for cv_type in cv:
+ gmm = GaussianMixture(
+ n_components=components[j],
+ covariance_type=cv_type,
+ random_state=seed,
+ )
+ labels = gmm.fit_predict(x)
+ bic_temp = gmm.bic(x)
+
+ if return_all is True:
+ gmm_list.append(gmm)
+ gmm_labels.append(labels)
+ gmm_proba.append(gmm.predict_proba(x))
+
+ elif return_all is False:
+ if bic_temp < lowest_bic:
+ lowest_bic = bic_temp
+ best_gmm = gmm
+ best_gmm_labels = labels
+ best_gmm_proba = gmm.predict_proba(x)
+
+ if return_all is True:
+ return gmm_list, gmm_labels, gmm_proba
+ return best_gmm, best_gmm_labels, best_gmm_proba
diff --git a/py4DSTEM/process/diffraction/WK_scattering_factors.py b/py4DSTEM/process/diffraction/WK_scattering_factors.py
new file mode 100644
index 000000000..70110a977
--- /dev/null
+++ b/py4DSTEM/process/diffraction/WK_scattering_factors.py
@@ -0,0 +1,585 @@
+import numpy as np
+from scipy.special import expi
+
+# from functools import lru_cache
+
+from py4DSTEM.process.utils import electron_wavelength_angstrom
+
+"""
+Weickenmeier-Kohl absorptive scattering factors, adapted by SE Zeltmann from EMsoftLib/others.f90
+by Mark De Graef, who adapted it from Weickenmeier's original f77 code.
+"""
+
+
+def compute_WK_factor(
+ g: np.ndarray,
+ Z: int,
+ accelerating_voltage: float,
+ thermal_sigma: float = None,
+ include_core: bool = True,
+ include_phonon: bool = True,
+ verbose=False,
+) -> np.complex128:
+ """
+ Compute the Weickenmeier-Kohl atomic scattering factors, using the parameterization
+ of the elastic part and computation of the inelastic part found in EMsoftLib/others.f90.
+ Return value should be in Å.
+
+ This implementation always returns the absorptive, relativistically corrected factors.
+
+    Currently this is mostly a direct translation of the Fortran code, along with
+    the accompanying comments from the original in quotation marks. Colin Ophus
+    vectorized it around v0.13.17. It is only vectorized over `g`; `Z` and all
+    other arguments must be single values.
+
+ This method uses an 8-parameter fit to the elastic form factors, and then computes the
+ absorptive form factors using an analytic solution based on that fitting function.
+
+    Args:
+ g (float/ndarray): Scattering vector magnitude in the crystallographic/py4DSTEM
+ convention, 1/d_hkl in units of 1/Å
+ Z (int): Atomic number. Data are available for H thru Cf (1 thru 98)
+ accelerating_voltage (float): Accelerating voltage in eV.
+ thermal_sigma (float): RMS atomic displacement for TDS, in Å
+            (this is often written as 〈u〉 in papers)
+ include_core (bool): If True, include the core loss contribution to the absorptive
+ form factors.
+ include_phonon (bool): If True, include the phonon/TDS contribution to the
+ absorptive form factors.
+ Returns:
+ Fscatt (np.complex128): The computed atomic form factor
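+
+    Example:
+        A minimal sketch (parameter values are illustrative):
+
+            >>> g = np.linspace(0.1, 2.0, 50)  # scattering vector magnitudes, 1/Å
+            >>> f_Si = compute_WK_factor(
+            ...     g, Z=14, accelerating_voltage=300e3, thermal_sigma=0.08
+            ... )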
+ """
+
+ # the WK Fortran code works in weird units:
+ # lowercase "g", our input, is the standard crystallographic quantity, in Å^-1
+ # uppercase "G" is the "G" in others.f90:FSCATT, g * 2π
+ # uppercase "S" is the "S" in others.f90:FSCATT, G / 4π = g / 2
+ G = g * 2.0 * np.pi
+ S = g / 2.0
+
+ if verbose:
+ print(f"S:{S}")
+
+ accelerating_voltage_kV = accelerating_voltage / 1.0e3
+
+ if thermal_sigma is not None:
+ UL = thermal_sigma
+ DWF = np.exp(-0.5 * UL**2 * G**2)
+ else:
+ UL = 0.0
+ DWF = 1.0
+
+ if verbose:
+ print(f"DWF:{DWF}")
+
+ A = WK_A_param[int(Z) - 1]
+ B = WK_B_param[int(Z) - 1]
+
+ if verbose:
+ print(f"A:{A}")
+ print(f"B:{B}")
+
+ # WEKO(A,B,S)
+ WK = np.zeros_like(S)
+ for i in range(4):
+ argu = B[i] * S**2
+ sub = argu < 1.0
+ WK[sub] += A[i] * B[i] * (1.0 - 0.5 * argu[sub])
+ sub = np.logical_and(argu >= 1.0, argu <= 20.0)
+ WK[sub] += A[i] * (1.0 - np.exp(-argu[sub])) / S[sub] ** 2
+ sub = argu > 20.0
+ WK[sub] += A[i] / S[sub] ** 2
+
+ Freal = 4.0 * np.pi * DWF * WK
+
+ if verbose:
+ print(f"Freal:{Freal}")
+
+ #################################################
+ # calculate "core" contribution, following FCORE:
+ k0 = (
+ 2.0 * np.pi / electron_wavelength_angstrom(accelerating_voltage)
+ ) # remember, physicist units here
+
+ if include_core:
+ # "CALCULATE CHARACTERISTIC ENERGY LOSS AND ANGLE"
+ DE = 6.0e-3 * Z
+ theta_e = (
+ DE
+ / (2.0 * accelerating_voltage_kV)
+ * (2.0 * accelerating_voltage_kV + 1022.0)
+ / (accelerating_voltage_kV + 1022.0)
+ )
+
+ # "SCREENING PARAMETER OF YUKAWA POTENTIAL"
+ R = 0.885 * 0.5289 / Z ** (1.0 / 3.0)
+
+ # "CALCULATE NORMALISING ANGLE"
+ TA = 1.0 / (k0 * R)
+
+ # "CALCULATE BRAGG ANGLE"
+ TB = G / (2.0 * k0)
+
+ # "NORMALIZE"
+ OMEGA = 2.0 * TB / TA
+ KAPPA = theta_e / TA
+
+ K2 = KAPPA * KAPPA
+ O2 = OMEGA * OMEGA
+
+ X1 = (
+ OMEGA
+ / ((1.0 + O2) * np.sqrt(O2 + 4.0 * K2))
+ * np.log((OMEGA + np.sqrt(O2 + 4.0 * K2)) / (2.0 * KAPPA))
+ )
+ X2 = (
+ 1.0
+ / np.sqrt((1.0 + O2) * (1.0 + O2) + 4.0 * K2 * O2)
+ * np.log(
+ (1.0 + 2.0 * K2 + O2 + np.sqrt((1.0 + O2) * (1.0 + O2) + 4.0 * K2 * O2))
+ / (2.0 * KAPPA * np.sqrt(1.0 + K2))
+ )
+ )
+
+ X3 = np.zeros_like(OMEGA)
+ sub = OMEGA > 1e-2
+ X3[sub] = (
+ 1.0
+ / (OMEGA[sub] * np.sqrt(O2[sub] + 4.0 * (1.0 + K2)))
+ * np.log(
+ (OMEGA[sub] + np.sqrt(O2[sub] + 4.0 * (1.0 + K2)))
+ / (2.0 * np.sqrt(1.0 + K2))
+ )
+ )
+ sub = np.logical_not(sub)
+ X3[sub] = 1.0 / (4.0 * (1.0 + K2))
+
+ HI = 2 * Z / (TA * TA) * (-X1 + X2 - X3)
+
+ A0 = 0.5289
+ Fcore = 4.0 / (A0 * A0) * 2.0 * np.pi / (k0 * k0) * HI
+
+ if verbose:
+ print(f"Fcore:{Fcore}")
+ else:
+ Fcore = 0.0
+
+ ##########################################################
+ # calculate phonon contribution, following FPHON(G,UL,A,B)
+ Fphon = 0.0
+ if include_phonon:
+ U2 = UL**2
+
+ A1 = A * (4.0 * np.pi) ** 2
+ B1 = B / (4.0 * np.pi) ** 2
+
+ for jj in range(4):
+ Fphon += (
+ A1[jj]
+ * A1[jj]
+ * (DWF * RI1(B1[jj], B1[jj], G) - RI2(B1[jj], B1[jj], G, UL))
+ )
+ for ii in range(jj + 1):
+ Fphon += (
+ 2.0
+ * A1[jj]
+ * A1[ii]
+ * (DWF * RI1(B1[ii], B1[jj], G) - RI2(B1[ii], B1[jj], G, UL))
+ )
+ if verbose:
+ print(f"Fphon:{Fphon}")
+
+ Fimag = (Fcore * DWF) + Fphon
+
+ # perform relativistic correction
+ gamma = (accelerating_voltage_kV + 511.0) / (511.0)
+
+ if verbose:
+ print(f"gamma:{gamma}")
+
+ Fscatt = np.complex128((Freal * gamma) + (1.0j * (Fimag * gamma**2 / k0)))
+
+ if verbose:
+ print(f"Fscatt:{Fscatt}")
+
+ return (
+ Fscatt * 0.4787801 * 0.664840340614319 / (4.0 * np.pi)
+ ) # convert to Å, and remove extra physicist factors, as performed in diffraction.f90:427,576,630
+
+
+##############################################
+# Helper integral functions for DW calculation
+
+
+def RI1(BI, BJ, G):
+ # "ERSTES INTEGRAL FUER DIE ABSORPTIONSPOTENTIALE"
+ eps = np.max([BI, BJ]) * G**2
+
+ ri1 = np.zeros_like(G)
+
+ sub = eps <= 0.1
+ ri1[sub] = np.pi * (BI * np.log((BI + BJ) / BI) + BJ * np.log((BI + BJ) / BJ))
+
+ sub = np.logical_and(eps <= 0.1, G > 0.0)
+ temp = 0.5 * BI**2 * np.log(BI / (BI + BJ)) + 0.5 * BJ**2 * np.log(
+ BJ / (BI + BJ)
+ )
+ temp += 0.75 * (BI**2 + BJ**2) - 0.25 * (BI + BJ) ** 2
+ temp -= 0.5 * (BI - BJ) ** 2
+ ri1[sub] += np.pi * G[sub] ** 2 * temp
+
+ sub = eps > 0.1
+ ri1[sub] = (
+ 2.0 * 0.5772157
+ + np.log(BI * G[sub] ** 2)
+ + np.log(BJ * G[sub] ** 2)
+ - 2.0 * expi(-BI * BJ * G[sub] ** 2 / (BI + BJ))
+ )
+
+ ri1[sub] += RIH1(
+ BI * G[sub] ** 2, BI * G[sub] ** 2 * BI / (BI + BJ), BI * G[sub] ** 2
+ )
+
+ ri1[sub] += RIH1(
+ BJ * G[sub] ** 2, BJ * G[sub] ** 2 * BJ / (BI + BJ), BJ * G[sub] ** 2
+ )
+ ri1[sub] *= np.pi / G[sub] ** 2
+
+ return ri1
+
+
+def RI2(BI, BJ, G, U):
+ # "ZWEITES INTEGRAL FUER DIE ABSORPTIONSPOTENTIALE"
+ U2 = U**2
+ U22 = 0.5 * U2
+ G2 = G**2
+ BIUH = BI + 0.5 * U2
+ BJUH = BJ + 0.5 * U2
+ BIU = BI + U2
+ BJU = BJ + U2
+
+ # "IST DIE ASYMPTOTISCHE ENTWICKLUNG ANWENDBAR?""
+ EPS = np.max([BI, BJ, U2])
+ EPS = EPS * G2
+
+ ri2 = np.zeros_like(G)
+
+ sub = EPS <= 0.1
+ ri2[sub] = (BI + U2) * np.log((BI + BJ + U2) / (BI + U2)) + BJ * np.log(
+ (BI + BJ + U2) / (BJ + U2)
+ )
+ if U2 > 0.0:
+ ri2[sub] += U2 * np.log(U2 / (BJ + U2))
+ ri2[sub] *= np.pi
+
+ if U2 > 0.0:
+ TEMP = 0.5 * U22 * U22 * np.log(BIU * BJU / (U2 * U2))
+ else:
+ TEMP = 0.0
+ TEMP = TEMP + 0.5 * BIUH * BIUH * np.log(BIU / (BIUH + BJUH))
+ TEMP = TEMP + 0.5 * BJUH * BJUH * np.log(BJU / (BIUH + BJUH))
+ TEMP = TEMP + 0.25 * BIU * BIU + 0.5 * BI * BI
+ TEMP = TEMP + 0.25 * BJU * BJU + 0.5 * BJ * BJ
+ TEMP = TEMP - 0.25 * (BIUH + BJUH) * (BIUH + BJUH)
+ TEMP = TEMP - 0.5 * ((BI * BIU - BJ * BJU) / (BIUH + BJUH)) ** 2
+ TEMP = TEMP - U22 * U22
+ ri2[sub] += np.pi * G2[sub] * TEMP
+
+ sub = EPS > 0.1
+ ri2[sub] = expi(-0.5 * U2 * G2[sub] * BIUH / BIU) + expi(
+ -0.5 * U2 * G2[sub] * BJUH / BJU
+ )
+ ri2[sub] -= expi(-BIUH * BJUH * G2[sub] / (BIUH + BJUH)) + expi(
+ -0.25 * U2 * G2[sub]
+ )
+ ri2[sub] *= 2.0
+ X1 = 0.5 * U2 * G2[sub]
+ X2 = 0.25 * U2 * G2[sub]
+ X3 = 0.25 * U2 * U2 * G2[sub] / BIU
+ ri2[sub] += RIH1(X1, X2, X3)
+
+ X1 = 0.5 * U2 * G2[sub]
+ X2 = 0.25 * U2 * G2[sub]
+ X3 = 0.25 * U2 * U2 * G2[sub] / BJU
+ ri2[sub] += RIH1(X1, X2, X3)
+
+ X1 = BIUH * G2[sub]
+ X2 = BIUH * BIUH * G2[sub] / (BIUH + BJUH)
+ X3 = BIUH * BIUH * G2[sub] / BIU
+ ri2[sub] += RIH1(X1, X2, X3)
+
+ X1 = BJUH * G2[sub]
+ X2 = BJUH * BJUH * G2[sub] / (BIUH + BJUH)
+ X3 = BJUH * BJUH * G2[sub] / BJU
+ ri2[sub] += RIH1(X1, X2, X3)
+
+ ri2[sub] *= np.pi / G2[sub]
+
+ return ri2
+
+
+def RIH1(X1, X2, X3):
+ # "WERTET DEN AUSDRUCK EXP(-X1) * ( EI(X2)-EI(X3) ) AUS"
+ rih1 = np.zeros(X1.shape)
+
+ sub = np.logical_and(X2 <= 20.0, X3 <= 20.0)
+ rih1[sub] = np.exp(-X1[sub]) * (expi(X2[sub]) - expi(X3[sub]))
+
+ sub = np.logical_and(X2 > 20.0, X3 <= 20.0)
+ rih1[sub] = np.exp(X2[sub] - X1[sub]) * RIH2(X2[sub]) / X2[sub] - np.exp(
+ -X1[sub]
+ ) * expi(X3[sub])
+
+ sub = np.logical_and(X2 <= 20.0, X3 > 20.0)
+ rih1[sub] = (
+ np.exp(-X1[sub]) * expi(X2[sub])
+ - np.exp(X3[sub] - X1[sub]) * RIH2(X3[sub]) / X3[sub]
+ )
+
+ sub = np.logical_and(X2 > 20.0, X3 > 20.0)
+ rih1[sub] = (
+ np.exp(X2[sub] - X1[sub]) * RIH2(X2[sub]) / X2[sub]
+ - np.exp(X3[sub] - X1[sub]) * RIH2(X3[sub]) / X3[sub]
+ )
+
+ return rih1
+
+
+def RIH2(X):
+ """
+ WERTET X*EXP(-X)*EI(X) AUS FUER GROSSE X
+ DURCH INTERPOLATION DER TABELLE ... AUS ABRAMOWITZ
+ """
+ idx = np.floor(200.0 / X).astype("int")
+
+ sig = RIH2_tabulated_data[idx] + 200.0 * (
+ RIH2_tabulated_data[idx + 1] - RIH2_tabulated_data[idx]
+ ) * ((1.0 / X) - 0.5e-3 * idx)
+
+ return sig
+
+
+# NOTE - This function is present in EMSoftLib but apparently not used.
+def RIH3(X):
+ # "WERTET DEN AUSDRUCK EXP(-X) * EI(X) AUS"
+ if X <= 20.0:
+ return np.exp(-X) * expi(X)
+ else:
+ return RIH2(X) / X
+
+
+##################
+# TABULATED DATA #
+##################
+
+# fmt:off
+
+RIH2_tabulated_data = np.array([1.000000,1.005051,1.010206,1.015472,1.020852,
+ 1.026355,1.031985,1.037751,1.043662,1.049726,
+ 1.055956,1.062364,1.068965,1.075780,1.082830,
+ 1.090140,1.097737,1.105647,1.113894,1.122497,
+ 1.131470])
+
+
+WK_A_param = np.array([
+ 0.00427, 0.00957, 0.00802, 0.00209,
+ 0.01217, 0.02616,-0.00884, 0.01841,
+ 0.00251, 0.03576, 0.00988, 0.02370,
+ 0.01596, 0.02959, 0.04024, 0.01001,
+ 0.03652, 0.01140, 0.05677, 0.01506,
+ 0.04102, 0.04911, 0.05296, 0.00061,
+ 0.04123, 0.05740, 0.06529, 0.00373,
+ 0.03547, 0.03133, 0.10865, 0.01615,
+ 0.03957, 0.07225, 0.09581, 0.00792,
+ 0.02597, 0.02197, 0.13762, 0.05394,
+ 0.03283, 0.08858, 0.11688, 0.02516,
+ 0.03833, 0.17124, 0.03649, 0.04134,
+ 0.04388, 0.17743, 0.05047, 0.03957,
+ 0.03812, 0.17833, 0.06280, 0.05605,
+ 0.04166, 0.17817, 0.09479, 0.04463,
+ 0.04003, 0.18346, 0.12218, 0.03753,
+ 0.04245, 0.17645, 0.15814, 0.03011,
+ 0.05011, 0.16667, 0.17074, 0.04358,
+ 0.04058, 0.17582, 0.20943, 0.02922,
+ 0.04001, 0.17416, 0.20986, 0.05497,
+ 0.09685, 0.14777, 0.20981, 0.04852,
+ 0.06667, 0.17356, 0.22710, 0.05957,
+ 0.05118, 0.16791, 0.26700, 0.06476,
+ 0.03204, 0.18460, 0.30764, 0.05052,
+ 0.03866, 0.17782, 0.31329, 0.06898,
+ 0.05455, 0.16660, 0.33208, 0.06947,
+ 0.05942, 0.17472, 0.34423, 0.06828,
+ 0.06049, 0.16600, 0.37302, 0.07109,
+ 0.08034, 0.15838, 0.40116, 0.05467,
+ 0.02948, 0.19200, 0.42222, 0.07480,
+ 0.16157, 0.32976, 0.18964, 0.06148,
+ 0.16184, 0.35705, 0.17618, 0.07133,
+ 0.06190, 0.18452, 0.41600, 0.12793,
+ 0.15913, 0.41583, 0.13385, 0.10549,
+ 0.16514, 0.41202, 0.12900, 0.13209,
+ 0.15798, 0.41181, 0.14254, 0.14987,
+ 0.16535, 0.44674, 0.24245, 0.03161,
+ 0.16039, 0.44470, 0.24661, 0.05840,
+ 0.16619, 0.44376, 0.25613, 0.06797,
+ 0.16794, 0.44505, 0.27188, 0.07313,
+ 0.16552, 0.45008, 0.30474, 0.06161,
+ 0.17327, 0.44679, 0.32441, 0.06143,
+ 0.16424, 0.45046, 0.33749, 0.07766,
+ 0.18750, 0.44919, 0.36323, 0.05388,
+ 0.16081, 0.45211, 0.40343, 0.06140,
+ 0.16599, 0.43951, 0.41478, 0.08142,
+ 0.16547, 0.44658, 0.45401, 0.05959,
+ 0.17154, 0.43689, 0.46392, 0.07725,
+ 0.15752, 0.44821, 0.48186, 0.08596,
+ 0.15732, 0.44563, 0.48507, 0.10948,
+ 0.16971, 0.42742, 0.48779, 0.13653,
+ 0.14927, 0.43729, 0.49444, 0.16440,
+ 0.18053, 0.44724, 0.48163, 0.15995,
+ 0.13141, 0.43855, 0.50035, 0.22299,
+ 0.31397, 0.55648, 0.39828, 0.04852,
+ 0.32756, 0.53927, 0.39830, 0.07607,
+ 0.30887, 0.53804, 0.42265, 0.09559,
+ 0.28398, 0.53568, 0.46662, 0.10282,
+ 0.35160, 0.56889, 0.42010, 0.07246,
+ 0.33810, 0.58035, 0.44442, 0.07413,
+ 0.35449, 0.59626, 0.43868, 0.07152,
+ 0.35559, 0.60598, 0.45165, 0.07168,
+ 0.38379, 0.64088, 0.41710, 0.06708,
+ 0.40352, 0.64303, 0.40488, 0.08137,
+ 0.36838, 0.64761, 0.47222, 0.06854,
+ 0.38514, 0.68422, 0.44359, 0.06775,
+ 0.37280, 0.67528, 0.47337, 0.08320,
+ 0.39335, 0.70093, 0.46774, 0.06658,
+ 0.40587, 0.71223, 0.46598, 0.06847,
+ 0.39728, 0.73368, 0.47795, 0.06759,
+ 0.40697, 0.73576, 0.47481, 0.08291,
+ 0.40122, 0.78861, 0.44658, 0.08799,
+ 0.41127, 0.76965, 0.46563, 0.10180,
+ 0.39978, 0.77171, 0.48541, 0.11540,
+ 0.39130, 0.80752, 0.48702, 0.11041,
+ 0.40436, 0.80701, 0.48445, 0.12438,
+ 0.38816, 0.80163, 0.51922, 0.13514,
+ 0.39551, 0.80409, 0.53365, 0.13485,
+ 0.40850, 0.83052, 0.53325, 0.11978,
+ 0.40092, 0.85415, 0.53346, 0.12747,
+ 0.41872, 0.88168, 0.54551, 0.09404,
+ 0.43358, 0.88007, 0.52966, 0.12059,
+ 0.40858, 0.87837, 0.56392, 0.13698,
+ 0.41637, 0.85094, 0.57749, 0.16700,
+ 0.38951, 0.83297, 0.60557, 0.20770,
+ 0.41677, 0.88094, 0.55170, 0.21029,
+ 0.50089, 1.00860, 0.51420, 0.05996,
+ 0.47470, 0.99363, 0.54721, 0.09206,
+ 0.47810, 0.98385, 0.54905, 0.12055,
+ 0.47903, 0.97455, 0.55883, 0.14309,
+ 0.48351, 0.98292, 0.58877, 0.12425,
+ 0.48664, 0.98057, 0.61483, 0.12136,
+ 0.46078, 0.97139, 0.66506, 0.13012,
+ 0.49148, 0.98583, 0.67674, 0.09725,
+ 0.50865, 0.98574, 0.68109, 0.09977,
+ 0.46259, 0.97882, 0.73056, 0.12723,
+ 0.46221, 0.95749, 0.76259, 0.14086,
+ 0.48500, 0.95602, 0.77234, 0.13374,
+ ]).reshape(98,4)
+
+
+WK_B_param = np.array([
+ 4.17218, 16.05892, 26.78365, 69.45643,
+ 1.83008, 7.20225, 16.13585, 18.75551,
+ 0.02620, 2.00907, 10.80597,130.49226,
+ 0.38968, 1.99268, 46.86913,108.84167,
+ 0.50627, 3.68297, 27.90586, 74.98296,
+ 0.41335, 10.98289, 34.80286,177.19113,
+ 0.29792, 7.84094, 22.58809, 72.59254,
+ 0.17964, 2.60856, 11.79972, 38.02912,
+ 0.16403, 3.96612, 12.43903, 40.05053,
+ 0.09101, 0.41253, 5.02463, 17.52954,
+ 0.06008, 2.07182, 7.64444,146.00952,
+ 0.07424, 2.87177, 18.06729, 97.00854,
+ 0.09086, 2.53252, 30.43883, 98.26737,
+ 0.05396, 1.86461, 22.54263, 72.43144,
+ 0.05564, 1.62500, 24.45354, 64.38264,
+ 0.05214, 1.40793, 23.35691, 53.59676,
+ 0.04643, 1.15677, 19.34091, 52.88785,
+ 0.07991, 1.01436, 15.67109, 39.60819,
+ 0.03352, 0.82984, 14.13679,200.97722,
+ 0.02289, 0.71288, 11.18914,135.02390,
+ 0.12527, 1.34248, 12.43524,131.71112,
+ 0.05198, 0.86467, 10.59984,103.56776,
+ 0.03786, 0.57160, 8.30305, 91.78068,
+ 0.00240, 0.44931, 7.92251, 86.64058,
+ 0.01836, 0.41203, 6.73736, 76.30466,
+ 0.03947, 0.43294, 6.26864, 71.29470,
+ 0.03962, 0.43253, 6.05175, 68.72437,
+ 0.03558, 0.39976, 5.36660, 62.46894,
+ 0.05475, 0.45736, 5.38252, 60.43276,
+ 0.00137, 0.26535, 4.48040, 54.26088,
+ 0.10455, 2.18391, 9.04125, 75.16958,
+ 0.09890, 2.06856, 9.89926, 68.13783,
+ 0.01642, 0.32542, 3.51888, 44.50604,
+ 0.07669, 1.89297, 11.31554, 46.32082,
+ 0.08199, 1.76568, 9.87254, 38.10640,
+ 0.06939, 1.53446, 8.98025, 33.04365,
+ 0.07044, 1.59236, 17.53592,215.26198,
+ 0.06199, 1.41265, 14.33812,152.80257,
+ 0.06364, 1.34205, 13.66551,125.72522,
+ 0.06565, 1.25292, 13.09355,109.50252,
+ 0.05921, 1.15624, 13.24924, 98.69958,
+ 0.06162, 1.11236, 12.76149, 90.92026,
+ 0.05081, 0.99771, 11.28925, 84.28943,
+ 0.05120, 1.08672, 12.23172, 85.27316,
+ 0.04662, 0.85252, 10.51121, 74.53949,
+ 0.04933, 0.79381, 9.30944, 41.17414,
+ 0.04481, 0.75608, 9.34354, 67.91975,
+ 0.04867, 0.71518, 8.40595, 64.24400,
+ 0.03672, 0.64379, 7.83687, 73.37281,
+ 0.03308, 0.60931, 7.04977, 64.83582,
+ 0.04023, 0.58192, 6.29247, 55.57061,
+ 0.02842, 0.50687, 5.60835, 48.28004,
+ 0.03830, 0.58340, 6.47550, 47.08820,
+ 0.02097, 0.41007, 4.52105, 37.18178,
+ 0.07813, 1.45053, 15.05933,199.48830,
+ 0.08444, 1.40227, 13.12939,160.56676,
+ 0.07206, 1.19585, 11.55866,127.31371,
+ 0.05717, 0.98756, 9.95556,117.31874,
+ 0.08249, 1.43427, 12.37363,150.55968,
+ 0.07081, 1.31033, 11.44403,144.17706,
+ 0.07442, 1.38680, 11.54391,143.72185,
+ 0.07155, 1.34703, 11.00432,140.09138,
+ 0.07794, 1.55042, 11.89283,142.79585,
+ 0.08508, 1.60712, 11.45367,116.64063,
+ 0.06520, 1.32571, 10.16884,134.69034,
+ 0.06850, 1.43566, 10.57719,131.88972,
+ 0.06264, 1.26756, 9.46411,107.50194,
+ 0.06750, 1.35829, 9.76480,127.40374,
+ 0.06958, 1.38750, 9.41888,122.10940,
+ 0.06574, 1.31578, 9.13448,120.98209,
+ 0.06517, 1.29452, 8.67569,100.34878,
+ 0.06213, 1.30860, 9.18871, 91.20213,
+ 0.06292, 1.23499, 8.42904, 77.59815,
+ 0.05693, 1.15762, 7.83077, 67.14066,
+ 0.05145, 1.11240, 8.33441, 65.71782,
+ 0.05573, 1.11159, 8.00221, 57.35021,
+ 0.04855, 0.99356, 7.38693, 51.75829,
+ 0.04981, 0.97669, 7.38024, 44.52068,
+ 0.05151, 1.00803, 8.03707, 45.01758,
+ 0.04693, 0.98398, 7.83562, 46.51474,
+ 0.05161, 1.02127, 9.18455, 64.88177,
+ 0.05154, 1.03252, 8.49678, 58.79463,
+ 0.04200, 0.90939, 7.71158, 57.79178,
+ 0.04661, 0.87289, 6.84038, 51.36000,
+ 0.04168, 0.73697, 5.86112, 43.78613,
+ 0.04488, 0.83871, 6.44020, 43.51940,
+ 0.05786, 1.20028, 13.85073,172.15909,
+ 0.05239, 1.03225, 11.49796,143.12303,
+ 0.05167, 0.98867, 10.52682,112.18267,
+ 0.04931, 0.95698, 9.61135, 95.44649,
+ 0.04748, 0.93369, 9.89867,102.06961,
+ 0.04660, 0.89912, 9.69785,100.23434,
+ 0.04323, 0.78798, 8.71624, 92.30811,
+ 0.04641, 0.85867, 9.51157,111.02754,
+ 0.04918, 0.87026, 9.41105,104.98576,
+ 0.03904, 0.72797, 8.00506, 86.41747,
+ 0.03969, 0.68167, 7.29607, 75.72682,
+ 0.04291, 0.69956, 7.38554, 77.18528,
+ ]).reshape(98,4)
diff --git a/py4DSTEM/process/diffraction/__init__.py b/py4DSTEM/process/diffraction/__init__.py
new file mode 100644
index 000000000..942547749
--- /dev/null
+++ b/py4DSTEM/process/diffraction/__init__.py
@@ -0,0 +1,4 @@
+from py4DSTEM.process.diffraction.crystal import *
+from py4DSTEM.process.diffraction.flowlines import *
+from py4DSTEM.process.diffraction.tdesign import *
+from py4DSTEM.process.diffraction.crystal_phase import *
diff --git a/py4DSTEM/process/diffraction/crystal.py b/py4DSTEM/process/diffraction/crystal.py
new file mode 100644
index 000000000..fb2911992
--- /dev/null
+++ b/py4DSTEM/process/diffraction/crystal.py
@@ -0,0 +1,1586 @@
+# Functions for calculating diffraction patterns, matching them to experiments, and creating orientation and phase maps.
+
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib.patches import Circle
+from fractions import Fraction
+from typing import Union, Optional
+import sys
+
+from emdfile import PointList
+from py4DSTEM.process.utils import single_atom_scatter, electron_wavelength_angstrom
+
+from py4DSTEM.process.diffraction.utils import Orientation
+
+
+class Crystal:
+ """
+ A class storing a single crystal structure, and associated diffraction data.
+
+ """
+
+ # Various methods for the Crystal class are implemented in a separate file. This
+ # import statement inside the class declaration imports them as methods of the class!
+ # (see https://stackoverflow.com/a/47562412)
+
+ # Automated Crystal Orientation Mapping is implemented in crystal_ACOM.py
+ from py4DSTEM.process.diffraction.crystal_ACOM import (
+ orientation_plan,
+ match_orientations,
+ match_single_pattern,
+ cluster_grains,
+ cluster_orientation_map,
+ calculate_strain,
+        save_ang_file,
+        symmetry_reduce_directions,
+        orientation_map_to_orix_CrystalMap,
+ )
+
+ from py4DSTEM.process.diffraction.crystal_viz import (
+ plot_structure,
+ plot_structure_factors,
+ plot_scattering_intensity,
+ plot_orientation_zones,
+ plot_orientation_plan,
+ plot_orientation_maps,
+ plot_fiber_orientation_maps,
+ plot_clusters,
+ plot_cluster_size,
+ )
+
+ from py4DSTEM.process.diffraction.crystal_calibrate import (
+ calibrate_pixel_size,
+ calibrate_unit_cell,
+ )
+
+ # Dynamical diffraction calculations are implemented in crystal_bloch.py
+ from py4DSTEM.process.diffraction.crystal_bloch import (
+ generate_dynamical_diffraction_pattern,
+ generate_CBED,
+ calculate_dynamical_structure_factors,
+ )
+
+ def __init__(
+ self,
+ positions,
+ numbers,
+ cell,
+ ):
+ """
+ Args:
+ positions (np.array): fractional coordinates of each atom in the cell
+            numbers (np.array): Z number for each atom in the cell; if a single number is passed, it is used for all atoms
+ cell (np.array): specify the unit cell, using a variable number of parameters
+ 1 number: the lattice parameter for a cubic cell
+ 3 numbers: the three lattice parameters for an orthorhombic cell
+ 6 numbers: the a,b,c lattice parameters and ɑ,β,ɣ angles for any cell
+ 3x3 array: row vectors containing the (u,v,w) lattice vectors.
+
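+        Example:
+            A minimal sketch for FCC Cu (a = 3.615 Å):
+
+                >>> pos = [[0, 0, 0], [0, 0.5, 0.5], [0.5, 0, 0.5], [0.5, 0.5, 0]]
+                >>> cu = Crystal(pos, numbers=29, cell=3.615)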
+ """
+ # Initialize Crystal
+ self.positions = np.asarray(positions) #: fractional atomic coordinates
+
+ #: atomic numbers - if only one value is provided, assume all atoms are same species
+ numbers = np.asarray(numbers, dtype="intp")
+ if np.size(numbers) == 1:
+ self.numbers = np.ones(self.positions.shape[0], dtype="intp") * numbers
+ elif np.size(numbers) == self.positions.shape[0]:
+ self.numbers = numbers
+ else:
+ raise Exception("Number of positions and atomic numbers do not match")
+
+ # unit cell, as one of:
+ # [a a a 90 90 90]
+ # [a b c 90 90 90]
+ # [a b c alpha beta gamma]
+        cell = np.asarray(cell, dtype="float64")
+ if np.size(cell) == 1:
+ self.cell = np.hstack([cell, cell, cell, 90, 90, 90])
+ elif np.size(cell) == 3:
+ self.cell = np.hstack([cell, 90, 90, 90])
+ elif np.size(cell) == 6:
+ self.cell = cell
+ elif np.shape(cell)[0] == 3 and np.shape(cell)[1] == 3:
+ self.lat_real = np.array(cell)
+ a = np.linalg.norm(self.lat_real[0, :])
+ b = np.linalg.norm(self.lat_real[1, :])
+ c = np.linalg.norm(self.lat_real[2, :])
+ alpha = np.rad2deg(
+ np.arccos(
+ np.clip(
+ np.sum(self.lat_real[1, :] * self.lat_real[2, :]) / b / c, -1, 1
+ )
+ )
+ )
+ beta = np.rad2deg(
+ np.arccos(
+ np.clip(
+ np.sum(self.lat_real[0, :] * self.lat_real[2, :]) / a / c, -1, 1
+ )
+ )
+ )
+ gamma = np.rad2deg(
+ np.arccos(
+ np.clip(
+ np.sum(self.lat_real[0, :] * self.lat_real[1, :]) / a / b, -1, 1
+ )
+ )
+ )
+ self.cell = (a, b, c, alpha, beta, gamma)
+ else:
+ raise Exception("Cell cannot contain " + np.size(cell) + " entries")
+
+ # pymatgen flag
+ if "pymatgen" in sys.modules:
+ self.pymatgen_available = True
+ else:
+ self.pymatgen_available = False
+ # Calculate lattice parameters
+ self.calculate_lattice()
+
+ def calculate_lattice(self):
+ if not hasattr(self, "lat_real"):
+ # calculate unit cell lattice vectors
+ a = self.cell[0]
+ b = self.cell[1]
+ c = self.cell[2]
+ alpha = np.deg2rad(self.cell[3])
+ beta = np.deg2rad(self.cell[4])
+ gamma = np.deg2rad(self.cell[5])
+ f = np.cos(beta) * np.cos(gamma) - np.cos(alpha)
+ vol = (
+ a
+ * b
+ * c
+ * np.sqrt(
+ 1
+ + 2 * np.cos(alpha) * np.cos(beta) * np.cos(gamma)
+ - np.cos(alpha) ** 2
+ - np.cos(beta) ** 2
+ - np.cos(gamma) ** 2
+ )
+ )
+ self.lat_real = np.array(
+ [
+ [a, 0, 0],
+ [b * np.cos(gamma), b * np.sin(gamma), 0],
+ [
+ c * np.cos(beta),
+ -c * f / np.sin(gamma),
+ vol / (a * b * np.sin(gamma)),
+ ],
+ ]
+ )
+
+ # Inverse lattice, metric tensors
+ self.metric_real = self.lat_real @ self.lat_real.T
+ self.metric_inv = np.linalg.inv(self.metric_real)
+ self.lat_inv = self.metric_inv @ self.lat_real
+
+ def get_strained_crystal(
+ self,
+ exx=0.0,
+ eyy=0.0,
+ ezz=0.0,
+ exy=0.0,
+ exz=0.0,
+ eyz=0.0,
+ deformation_matrix=None,
+ return_deformation_matrix=False,
+ ):
+ """
+        This method returns a new Crystal instance with strain applied. The directions of (x,y,z)
+ are with respect to the default Crystal orientation, which can be checked with
+ print(Crystal.lat_real) applied to the original Crystal.
+
+ Strains are given in fractional values, so exx = 0.01 is 1% strain along the x direction.
+ Deformation matrix should be of the form:
+ deformation_matrix = np.array([
+ [1.0+exx, 1.0*exy, 1.0*exz],
+ [1.0*exy, 1.0+eyy, 1.0*eyz],
+ [1.0*exz, 1.0*eyz, 1.0+ezz],
+ ])
+
+ Parameters
+ --------
+
+ exx (float):
+ fractional strain along the xx direction
+ eyy (float):
+ fractional strain along the yy direction
+ ezz (float):
+ fractional strain along the zz direction
+ exy (float):
+ fractional strain along the xy direction
+ exz (float):
+ fractional strain along the xz direction
+ eyz (float):
+ fractional strain along the yz direction
+ deformation_matrix (np.ndarray):
+ 3x3 array describing deformation matrix
+ return_deformation_matrix (bool):
+ boolean switch to return deformation matrix
+
+ Returns
+ --------
+ return_deformation_matrix == False:
+ strained_crystal (py4DSTEM.Crystal)
+ return_deformation_matrix == True:
+ (strained_crystal, deformation_matrix)
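+
+        Example:
+            A minimal sketch applying 2% tensile strain along x:
+
+                >>> strained = crystal.get_strained_crystal(exx=0.02)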
+ """
+
+ # deformation matrix
+ if deformation_matrix is None:
+ deformation_matrix = np.array(
+ [
+ [1.0 + exx, 1.0 * exy, 1.0 * exz],
+ [1.0 * exy, 1.0 + eyy, 1.0 * eyz],
+ [1.0 * exz, 1.0 * eyz, 1.0 + ezz],
+ ]
+ )
+
+ # new unit cell
+ lat_new = self.lat_real @ deformation_matrix
+
+ # make new crystal class
+ from py4DSTEM.process.diffraction import Crystal
+
+ crystal_strained = Crystal(
+ positions=self.positions.copy(),
+ numbers=self.numbers.copy(),
+ cell=lat_new,
+ )
+
+ if return_deformation_matrix:
+ return crystal_strained, deformation_matrix
+ else:
+ return crystal_strained
+
+ def from_CIF(CIF, conventional_standard_structure=True):
+ """
+ Create a Crystal object from a CIF file, using pymatgen to import the CIF
+
+ Note that pymatgen typically prefers to return primitive unit cells,
+ which can be overridden by setting conventional_standard_structure=True.
+
+ Args:
+ CIF: (str or Path) path to the CIF File
+ conventional_standard_structure: (bool) if True, conventional standard unit cell will be returned
+ instead of the primitive unit cell pymatgen typically returns
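+
+        Example:
+            A minimal sketch (the file path is hypothetical):
+
+                >>> crystal = Crystal.from_CIF("Si.cif")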
+ """
+ from pymatgen.io.cif import CifParser
+
+ parser = CifParser(CIF)
+
+ structure = parser.get_structures()[0]
+
+ return Crystal.from_pymatgen_structure(
+ structure, conventional_standard_structure=conventional_standard_structure
+ )
+
+ def from_pymatgen_structure(
+ structure=None,
+ formula=None,
+ space_grp=None,
+ MP_key=None,
+ conventional_standard_structure=True,
+ ):
+ """
+ Create a Crystal object from a pymatgen Structure object.
+ If a Materials Project API key is installed, you may pass
+ the Materials Project ID of a structure, which will be
+ fetched through the MP API. For setup information see:
+ https://pymatgen.org/usage.html#setting-the-pmg-mapi-key-in-the-config-file.
+        Alternatively, a Materials Project API key can be passed as an argument
+        (MP_key). To get your API key, please visit the Materials Project website
+        and log in / sign up using your email ID. Once logged in, go to the dashboard
+        to generate your own API key (https://materialsproject.org/dashboard).
+
+ Note that pymatgen typically prefers to return primitive unit cells,
+ which can be overridden by setting conventional_standard_structure=True.
+
+ Args:
+ structure: (pymatgen Structure or str), if specified as a string, it will be considered
+ as a Materials Project ID of a structure, otherwise it will accept only
+ pymatgen Structure object. if None, MP database will be queried using the
+ specified formula and/or space groups for the available structure
+            formula: (str), pretty formula to search in the MP database (note that the formulas in the MP
+                        database are not always formatted in the conventional order; please
+                        visit the Materials Project website (https://materialsproject.org/) for details).
+                        If None, the structure argument must not be None
+            space_grp: (int) space group number of the formula provided to query MP database. If None, MP will search
+ for all the available space groups for the formula provided and will consider the
+ one with lowest unit cell volume, only specify when using formula to search MP
+ database
+ MP_key: (str) Materials Project API key
+ conventional_standard_structure: (bool) if True, conventional standard unit cell will be returned
+ instead of the primitive unit cell pymatgen returns
+
+ """
+ import pymatgen as mg
+
+ if structure is not None:
+ if isinstance(structure, str):
+ from mp_api.client import MPRester
+
+ with MPRester(MP_key) as mpr:
+ structure = mpr.get_structure_by_material_id(structure)
+
+ assert isinstance(
+ structure, mg.core.Structure
+ ), "structure must be pymatgen Structure object"
+
+ structure = (
+ mg.symmetry.analyzer.SpacegroupAnalyzer(
+ structure
+ ).get_conventional_standard_structure()
+ if conventional_standard_structure
+ else structure
+ )
+ else:
+ from mp_api.client import MPRester
+
+ with MPRester(MP_key) as mpr:
+ if formula is None:
+ raise Exception(
+ "Atleast a formula needs to be provided to query from MP database!!"
+ )
+ query = mpr.query(
+ criteria={"pretty_formula": formula},
+ properties=["structure", "icsd_ids", "spacegroup"],
+ )
+ if space_grp:
+ query = [
+ query[i]
+ for i in range(len(query))
+ if mg.symmetry.analyzer.SpacegroupAnalyzer(
+ query[i]["structure"]
+ ).get_space_group_number()
+ == space_grp
+ ]
+ selected = query[
+ np.argmin(
+ [
+ query[i]["structure"].lattice.volume
+ for i in range(len(query))
+ ]
+ )
+ ]
+ structure = (
+ mg.symmetry.analyzer.SpacegroupAnalyzer(
+ selected["structure"]
+ ).get_conventional_standard_structure()
+ if conventional_standard_structure
+ else selected["structure"]
+ )
+
+ positions = structure.frac_coords #: fractional atomic coordinates
+
+ cell = np.array(
+ [
+ structure.lattice.a,
+ structure.lattice.b,
+ structure.lattice.c,
+ structure.lattice.alpha,
+ structure.lattice.beta,
+ structure.lattice.gamma,
+ ]
+ )
+
+ numbers = np.array([s.species.elements[0].Z for s in structure])
+
+ return Crystal(positions, numbers, cell)
+
+ def from_unitcell_parameters(
+ latt_params,
+ elements,
+ positions,
+ space_group=None,
+ lattice_type="cubic",
+ from_cartesian=False,
+ conventional_standard_structure=True,
+ ):
+ """
+ Create a Crystal using pymatgen to generate unit cell manually from user inputs
+
+ Args:
+ latt_params: (list of floats) list of lattice parameters. For example, for cubic: latt_params = [a],
+ for hexagonal: latt_params = [a, c], for monoclinic: latt_params = [a,b,c,beta],
+ and in general: latt_params = [a,b,c,alpha,beta,gamma]
+ elements: (list of strings) list of elements, for example for SnS: elements = ["Sn", "S"]
+ positions: (list) list of (x,y,z) positions for each element listed in elements; fractional coordinates by default
+ space_group: (optional) (string or int) space group of the crystal system, if specified, unit cell will be created using
+ pymatgen Structure.from_spacegroup function
+ lattice_type: (string) type of crystal family: cubic, hexagonal, triclinic etc; default: 'cubic'
+ from_cartesian: (bool) if True, positions will be considered as cartesian, default: False
+ conventional_standard_structure: (bool) if True, conventional standard unit cell will be returned
+ instead of the primitive unit cell pymatgen returns
+ Returns:
+ Crystal object
+
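+ Example (an illustrative sketch using approximate fcc Cu parameters):
+
+ >>> xtal = Crystal.from_unitcell_parameters(
+ ... latt_params=[3.61],
+ ... elements=["Cu"],
+ ... positions=[[0, 0, 0]],
+ ... space_group="Fm-3m",
+ ... lattice_type="cubic",
+ ... )
+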
+ """
+
+ import pymatgen as mg
+
+ if lattice_type == "cubic":
+ assert (
+ len(latt_params) == 1
+ ), "Only 1 lattice parameter is expected for cubic: a, but given {}".format(
+ len(latt_params)
+ )
+ lattice = mg.core.Lattice.cubic(latt_params[0])
+ elif lattice_type == "hexagonal":
+ assert (
+ len(latt_params) == 2
+ ), "2 lattice parametere are expected for hexagonal: a, c, but given {len(latt_params)}".format(
+ len(latt_params)
+ )
+ lattice = mg.core.Lattice.hexagonal(latt_params[0], latt_params[1])
+ elif lattice_type == "tetragonal":
+ assert (
+ len(latt_params) == 2
+ ), "2 lattice parametere are expected for tetragonal: a, c, but given {len(latt_params)}".format(
+ len(latt_params)
+ )
+ lattice = mg.core.Lattice.tetragonal(latt_params[0], latt_params[1])
+ elif lattice_type == "orthorhombic":
+ assert (
+ len(latt_params) == 3
+ ), "3 lattice parametere are expected for orthorhombic: a, b, c, but given {len(latt_params)}".format(
+ len(latt_params)
+ )
+ lattice = mg.core.Lattice.orthorhombic(
+ latt_params[0], latt_params[1], latt_params[2]
+ )
+ elif lattice_type == "monoclinic":
+ assert (
+ len(latt_params) == 4
+ ), "4 lattice parametere are expected for monoclinic: a, b, c, beta, but given {len(latt_params)}".format(
+ len(latt_params)
+ )
+ lattice = mg.core.Lattice.monoclinic(
+ latt_params[0], latt_params[1], latt_params[2], latt_params[3]
+ )
+ else:
+ assert (
+ len(latt_params) == 6
+ ), "all 6 lattice parametere are expected: a, b, c, alpha, beta, gamma, but given {len(latt_params)}".format(
+ len(latt_params)
+ )
+ lattice = mg.core.Lattice.from_parameters(
+ latt_params[0],
+ latt_params[1],
+ latt_params[2],
+ latt_params[3],
+ latt_params[4],
+ latt_params[5],
+ )
+
+ if space_group:
+ structure = mg.core.Structure.from_spacegroup(
+ space_group,
+ lattice,
+ elements,
+ positions,
+ coords_are_cartesian=from_cartesian,
+ )
+ else:
+ structure = mg.core.Structure(
+ lattice, elements, positions, coords_are_cartesian=from_cartesian
+ )
+
+ return Crystal.from_pymatgen_structure(
+ structure, conventional_standard_structure=conventional_standard_structure
+ )
+
+ def setup_diffraction(self, accelerating_voltage: float):
+ """
+ Set up attributes used for diffraction calculations without going
+ through the full ACOM pipeline.
+ """
+ self.accel_voltage = accelerating_voltage
+ self.wavelength = electron_wavelength_angstrom(self.accel_voltage)
+
+ def calculate_structure_factors(
+ self,
+ k_max: float = 2.0,
+ tol_structure_factor: float = 1e-4,
+ return_intensities: bool = False,
+ ):
+ """
+ Calculate structure factors for all hkl indices up to max scattering vector k_max
+
+ Parameters
+ --------
+
+ k_max: float
+ max scattering vector to include (1/Angstroms)
+ tol_structure_factor: float
+ tolerance for removing low-valued structure factors
+ return_intensities: bool
+ return the intensities and positions of all structure factor peaks.
+
+ Returns
+ --------
+ (q_SF, I_SF)
+ Tuple of the q vectors and intensities of each structure factor.
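+
+ Each structure factor is computed kinematically as
+ F(hkl) = (1/V) * sum_j f_j(|g|) * exp(2*pi*i * (h*x_j + k*y_j + l*z_j)),
+ where f_j are single-atom electron scattering factors, (x_j, y_j, z_j) are
+ fractional atomic coordinates, and V is the unit cell volume.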
+ """
+
+ # Store k_max
+ self.k_max = np.asarray(k_max)
+
+ # Find shortest lattice vector direction
+ k_test = np.vstack(
+ [
+ self.lat_inv[0, :],
+ self.lat_inv[1, :],
+ self.lat_inv[2, :],
+ self.lat_inv[0, :] + self.lat_inv[1, :],
+ self.lat_inv[0, :] + self.lat_inv[2, :],
+ self.lat_inv[1, :] + self.lat_inv[2, :],
+ self.lat_inv[0, :] + self.lat_inv[1, :] + self.lat_inv[2, :],
+ self.lat_inv[0, :] - self.lat_inv[1, :] + self.lat_inv[2, :],
+ self.lat_inv[0, :] + self.lat_inv[1, :] - self.lat_inv[2, :],
+ self.lat_inv[0, :] - self.lat_inv[1, :] - self.lat_inv[2, :],
+ ]
+ )
+ k_leng_min = np.min(np.linalg.norm(k_test, axis=1))
+
+ # Tile lattice vectors
+ num_tile = np.ceil(self.k_max / k_leng_min)
+ ya, xa, za = np.meshgrid(
+ np.arange(-num_tile, num_tile + 1),
+ np.arange(-num_tile, num_tile + 1),
+ np.arange(-num_tile, num_tile + 1),
+ )
+ hkl = np.vstack([xa.ravel(), ya.ravel(), za.ravel()])
+ # g_vec_all = self.lat_inv @ hkl
+ g_vec_all = (hkl.T @ self.lat_inv).T
+
+ # Delete lattice vectors outside of k_max
+ keep = np.linalg.norm(g_vec_all, axis=0) <= self.k_max
+ self.hkl = hkl[:, keep]
+ self.g_vec_all = g_vec_all[:, keep]
+ self.g_vec_leng = np.linalg.norm(self.g_vec_all, axis=0)
+
+ # Calculate single atom scattering factors
+ # Note this can be sped up a lot, but we may want to generalize to allow non-1.0 occupancy in the future.
+ f_all = np.zeros(
+ (np.size(self.g_vec_leng, 0), self.positions.shape[0]), dtype="float64"
+ )
+ for a0 in range(self.positions.shape[0]):
+ atom_sf = single_atom_scatter([self.numbers[a0]], [1], self.g_vec_leng, "A")
+ atom_sf.get_scattering_factor([self.numbers[a0]], [1], self.g_vec_leng, "A")
+ f_all[:, a0] = atom_sf.fe
+
+ # Calculate structure factors
+ self.struct_factors = np.zeros(np.size(self.g_vec_leng, 0), dtype="complex64")
+ for a0 in range(self.positions.shape[0]):
+ self.struct_factors += f_all[:, a0] * np.exp(
+ (2j * np.pi)
+ * np.sum(
+ self.hkl * np.expand_dims(self.positions[a0, :], axis=1), axis=0
+ )
+ )
+
+ # Divide by unit cell volume
+ unit_cell_volume = np.abs(np.linalg.det(self.lat_real))
+ self.struct_factors /= unit_cell_volume
+
+ # Remove structure factors below tolerance level
+ keep = np.abs(self.struct_factors) > tol_structure_factor
+ self.hkl = self.hkl[:, keep]
+
+ self.g_vec_all = self.g_vec_all[:, keep]
+ self.g_vec_leng = self.g_vec_leng[keep]
+ self.struct_factors = self.struct_factors[keep]
+
+ # Structure factor intensities
+ self.struct_factors_int = np.abs(self.struct_factors) ** 2
+
+ if return_intensities:
+ q_SF = np.linspace(0, self.k_max, 250)
+ I_SF = np.zeros_like(q_SF)
+ for i in range(self.g_vec_leng.shape[0]):
+ idx = np.argmin(np.abs(q_SF - self.g_vec_leng[i]))
+ I_SF[idx] += self.struct_factors_int[i]
+ I_SF = I_SF / np.max(I_SF)
+
+ return (q_SF, I_SF)
+
+ def generate_diffraction_pattern(
+ self,
+ orientation: Optional[Orientation] = None,
+ ind_orientation: Optional[int] = 0,
+ orientation_matrix: Optional[np.ndarray] = None,
+ zone_axis_lattice: Optional[np.ndarray] = None,
+ proj_x_lattice: Optional[np.ndarray] = None,
+ foil_normal_lattice: Optional[Union[list, tuple, np.ndarray]] = None,
+ zone_axis_cartesian: Optional[np.ndarray] = None,
+ proj_x_cartesian: Optional[np.ndarray] = None,
+ foil_normal_cartesian: Optional[Union[list, tuple, np.ndarray]] = None,
+ sigma_excitation_error: float = 0.02,
+ tol_excitation_error_mult: float = 3,
+ tol_intensity: float = 1e-4,
+ k_max: Optional[float] = None,
+ keep_qz=False,
+ return_orientation_matrix=False,
+ ):
+ """
+ Generate a single diffraction pattern, return all peaks as a pointlist.
+
+ Args:
+ orientation (Orientation): an Orientation class object
+ ind_orientation (int): if the input Orientation object contains multiple
+ orientations, this index selects a specific one.
+
+ orientation_matrix (array): (3,3) orientation matrix, where columns represent projection directions.
+ zone_axis_lattice (array): (3,) projection direction in lattice indices
+ proj_x_lattice (array): (3,) x-axis direction in lattice indices
+ zone_axis_cartesian (array): (3,) cartesian projection direction
+ proj_x_cartesian (array): (3,) cartesian x-axis direction
+
+ foil_normal_lattice (array): (3,) foil normal in lattice indices; set to None to use the zone axis
+ foil_normal_cartesian (array): (3,) cartesian foil normal; set to None to use the zone axis
+ Note: the accelerating voltage should be set beforehand (e.g. via
+ setup_diffraction); if it has not been set, 300 kV is assumed.
+ sigma_excitation_error (float): sigma value for envelope applied to s_g (excitation errors) in units of inverse Angstroms
+ tol_excitation_error_mult (float): tolerance in units of sigma for s_g inclusion
+ tol_intensity (np float): tolerance in intensity units for inclusion of diffraction spots
+ k_max (float): Maximum scattering vector
+ keep_qz (bool): Flag to return out-of-plane diffraction vectors
+ return_orientation_matrix (bool): Return the orientation matrix
+
+ Returns:
+ bragg_peaks (PointList): list of all Bragg peaks with fields [qx, qy, intensity, h, k, l]
+ orientation_matrix (array): 3x3 orientation matrix (optional)
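+
+ Example (a minimal sketch; assumes structure factors were computed first):
+
+ >>> xtal.calculate_structure_factors(k_max=2.0)
+ >>> peaks = xtal.generate_diffraction_pattern(zone_axis_lattice=[0, 0, 1])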
+ """
+
+ if not (hasattr(self, "wavelength") and hasattr(self, "accel_voltage")):
+ print("Accelerating voltage not set. Assuming 300 keV!")
+ self.setup_diffraction(300e3)
+
+ # Tolerance for angular tests
+ tol = 1e-6
+
+ # Parse orientation inputs
+ if orientation is not None:
+ if ind_orientation is None:
+ orientation_matrix = orientation.matrix[0]
+ else:
+ orientation_matrix = orientation.matrix[ind_orientation]
+ elif orientation_matrix is None:
+ orientation_matrix = self.parse_orientation(
+ zone_axis_lattice, proj_x_lattice, zone_axis_cartesian, proj_x_cartesian
+ )
+
+ # Get foil normal direction
+ if foil_normal_lattice is not None:
+ foil_normal = self.lattice_to_cartesian(np.array(foil_normal_lattice))
+ elif foil_normal_cartesian is not None:
+ foil_normal = np.array(foil_normal_cartesian)
+ else:
+ foil_normal = None
+ # foil_normal = orientation_matrix[:,2]
+
+ # Rotate crystal into desired projection
+ g = orientation_matrix.T @ self.g_vec_all
+
+ # Calculate excitation errors
+ if foil_normal is None:
+ sg = self.excitation_errors(g)
+ else:
+ foil_normal = (
+ orientation_matrix.T
+ @ (-1 * foil_normal[:, None] / np.linalg.norm(foil_normal))
+ ).ravel()
+ sg = self.excitation_errors(g, foil_normal)
+
+ # Threshold for inclusion in diffraction pattern
+ sg_max = sigma_excitation_error * tol_excitation_error_mult
+ keep = np.abs(sg) <= sg_max
+
+ # Maximum scattering angle cutoff
+ if k_max is not None:
+ keep_kmax = np.linalg.norm(g, axis=0) <= k_max
+ keep = np.logical_and(keep, keep_kmax)
+
+ g_diff = g[:, keep]
+
+ # Diffracted peak intensities and labels
+ g_int = self.struct_factors_int[keep] * np.exp(
+ (sg[keep] ** 2) / (-2 * sigma_excitation_error**2)
+ )
+ hkl = self.hkl[:, keep]
+
+ # Intensity tolerance
+ keep_int = g_int > tol_intensity
+
+ # Output peaks
+ gx_proj = g_diff[0, keep_int]
+ gy_proj = g_diff[1, keep_int]
+
+ # Diffracted peak labels
+ h = hkl[0, keep_int]
+ k = hkl[1, keep_int]
+ l = hkl[2, keep_int]
+
+ # Output as PointList
+ if keep_qz:
+ gz_proj = g_diff[2, keep_int]
+ pl_dtype = np.dtype(
+ [
+ ("qx", "float64"),
+ ("qy", "float64"),
+ ("qz", "float64"),
+ ("intensity", "float64"),
+ ("h", "int"),
+ ("k", "int"),
+ ("l", "int"),
+ ]
+ )
+ bragg_peaks = PointList(np.array([], dtype=pl_dtype))
+ if np.any(keep_int):
+ bragg_peaks.add_data_by_field(
+ [gx_proj, gy_proj, gz_proj, g_int[keep_int], h, k, l]
+ )
+ else:
+ pl_dtype = np.dtype(
+ [
+ ("qx", "float64"),
+ ("qy", "float64"),
+ ("intensity", "float64"),
+ ("h", "int"),
+ ("k", "int"),
+ ("l", "int"),
+ ]
+ )
+ bragg_peaks = PointList(np.array([], dtype=pl_dtype))
+ if np.any(keep_int):
+ bragg_peaks.add_data_by_field(
+ [gx_proj, gy_proj, g_int[keep_int], h, k, l]
+ )
+
+ if return_orientation_matrix:
+ return bragg_peaks, orientation_matrix
+ else:
+ return bragg_peaks
+
+ def generate_ring_pattern(
+ self,
+ k_max=2.0,
+ use_bloch=False,
+ thickness=None,
+ bloch_params=None,
+ orientation_plan_params=None,
+ sigma_excitation_error=0.02,
+ tol_intensity=1e-3,
+ plot_rings=True,
+ plot_params={},
+ return_calc=True,
+ ):
+ """
+ Calculate polycrystalline diffraction pattern from structure
+
+ Args:
+ k_max (float): Maximum scattering vector
+ use_bloch (bool): if true, use dynamic instead of kinematic approach
+ thickness (float): thickness in Ångström to evaluate diffraction patterns,
+ only needed for dynamical calculations
+ bloch_params (dict): optional, parameters to calculate dynamical structure factor,
+ see calculate_dynamical_structure_factors doc strings
+ orientation_plan_params (dict): optional, parameters to calculate orientation plan,
+ see orientation_plan doc strings
+ sigma_excitation_error (float): sigma value for envelope applied to s_g (excitation errors)
+ in units of inverse Angstroms
+ tol_intensity (np float): tolerance in intensity units for inclusion of diffraction spots
+ plot_rings (bool): if True, plot diffraction rings with plot_ring_pattern
+ plot_params (dict): optional, keyword arguments passed to plot_ring_pattern
+ return_calc (bool): return radii and intensities
+
+ Returns:
+ radii_unique (np array): radii of ring pattern in units of scattering vector k
+ intensity_unique (np array): intensity of rings weighted by frequency of diffraction spots
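+
+ Example (a minimal kinematic sketch):
+
+ >>> radii, intensities = xtal.generate_ring_pattern(
+ ... k_max=2.0, plot_rings=False, return_calc=True
+ ... )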
+ """
+
+ if use_bloch:
+ assert (
+ thickness is not None
+ ), "provide thickness for dynamical diffraction calculation"
+ assert hasattr(
+ self, "Ug_dict"
+ ), "run calculate_dynamical_structure_factors first"
+
+ if not hasattr(self, "struct_factors"):
+ self.calculate_structure_factors(
+ k_max=k_max,
+ )
+
+ # check accelerating voltage
+ if hasattr(self, "accel_voltage"):
+ accelerating_voltage = self.accel_voltage
+ else:
+ self.accel_voltage = 300e3
+ print("Accelerating voltage not set. Assuming 300 keV!")
+
+ # check orientation plan
+ if not hasattr(self, "orientation_vecs"):
+ if orientation_plan_params is None:
+ orientation_plan_params = {
+ "zone_axis_range": "auto",
+ "angle_step_zone_axis": 4,
+ "angle_step_in_plane": 4,
+ }
+ self.orientation_plan(
+ **orientation_plan_params,
+ )
+
+ # calculate intensity and radius for rings
+ radii = []
+ intensity = []
+ for a0 in range(self.orientation_vecs.shape[0]):
+ if use_bloch:
+ beams = self.generate_diffraction_pattern(
+ zone_axis_lattice=self.orientation_vecs[a0],
+ sigma_excitation_error=sigma_excitation_error,
+ tol_intensity=tol_intensity,
+ k_max=k_max,
+ )
+ pattern = self.generate_dynamical_diffraction_pattern(
+ beams=beams,
+ zone_axis_lattice=self.orientation_vecs[a0],
+ thickness=thickness,
+ )
+ else:
+ pattern = self.generate_diffraction_pattern(
+ zone_axis_lattice=self.orientation_vecs[a0],
+ sigma_excitation_error=sigma_excitation_error,
+ tol_intensity=tol_intensity,
+ k_max=k_max,
+ )
+
+ intensity.append(pattern["intensity"])
+ radii.append((pattern["qx"] ** 2 + pattern["qy"] ** 2) ** 0.5)
+
+ intensity = np.concatenate(intensity)
+ radii = np.concatenate(radii)
+
+ radii_unique, idx, inv, cts = np.unique(
+ radii, return_counts=True, return_index=True, return_inverse=True
+ )
+ intensity_unique = np.bincount(inv, weights=intensity)
+
+ if plot_rings:
+ from py4DSTEM.process.diffraction.crystal_viz import plot_ring_pattern
+
+ plot_ring_pattern(radii_unique, intensity_unique, **plot_params)
+
+ if return_calc:
+ return radii_unique, intensity_unique
+
+ # Vector conversions and other utilities for Crystal classes
+ def cartesian_to_lattice(self, vec_cartesian):
+ vec_lattice = self.lat_inv @ vec_cartesian
+ return vec_lattice / np.linalg.norm(vec_lattice)
+
+ def lattice_to_cartesian(self, vec_lattice):
+ vec_cartesian = self.lat_real.T @ vec_lattice
+ return vec_cartesian / np.linalg.norm(vec_cartesian)
+
+ def hexagonal_to_lattice(self, vec_hexagonal):
+ return np.array(
+ [
+ 2.0 * vec_hexagonal[0] + vec_hexagonal[1],
+ 2.0 * vec_hexagonal[1] + vec_hexagonal[0],
+ vec_hexagonal[3],
+ ]
+ )
+
+ def lattice_to_hexagonal(self, vec_lattice):
+ return np.array(
+ [
+ (2.0 * vec_lattice[0] - vec_lattice[1]) / 3.0,
+ (2.0 * vec_lattice[1] - vec_lattice[0]) / 3.0,
+ (-vec_lattice[0] - vec_lattice[1]) / 3.0,
+ vec_lattice[2],
+ ]
+ )
+
+ def cartesian_to_miller(self, vec_cartesian):
+ vec_miller = self.lat_real.T @ self.metric_inv @ vec_cartesian
+ return vec_miller / np.linalg.norm(vec_miller)
+
+ def miller_to_cartesian(self, vec_miller):
+ vec_cartesian = self.lat_inv.T @ self.metric_real @ vec_miller
+ return vec_cartesian / np.linalg.norm(vec_cartesian)
+
+ def rational_ind(
+ self,
+ vec,
+ tol_den=1000,
+ ):
+ # This function rationalizes the indices of a vector, up to
+ # some tolerance. Returns integers to prevent rounding errors.
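+ # For example (worked by hand): rational_ind([0.5, 0.5, 1.0]) multiplies
+ # by the denominator 2 to give [1, 1, 2]; the gcd is 1, so the result
+ # is array([1, 1, 2]).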
+ vec = np.array(vec, dtype="float64")
+ sub = np.abs(vec) > 0
+ if np.sum(sub) > 0:
+ for ind in np.argwhere(sub):
+ frac = Fraction(vec[ind[0]]).limit_denominator(tol_den)
+ vec *= np.round(frac.denominator)
+ vec = np.round(
+ vec / np.gcd.reduce(np.round(np.abs(vec[sub])).astype("int"))
+ ).astype("int")
+
+ return vec
+
+ def parse_orientation(
+ self,
+ zone_axis_lattice=None,
+ proj_x_lattice=None,
+ zone_axis_cartesian=None,
+ proj_x_cartesian=None,
+ ):
+ # This helper function parses the various types of orientation inputs,
+ # and returns the normalized, projected (x,y,z) cartesian vectors in
+ # the form of an orientation matrix.
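+ # For example, parse_orientation(zone_axis_lattice=[1, 1, 1]) returns a
+ # 3x3 matrix whose columns are orthonormal (x, y, z) projection vectors,
+ # with the third column along the cartesian [111] zone axis.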
+
+ if zone_axis_lattice is not None:
+ proj_z = np.array(zone_axis_lattice)
+ if proj_z.shape[0] == 4:
+ proj_z = self.hexagonal_to_lattice(proj_z)
+ proj_z = self.lattice_to_cartesian(proj_z)
+ elif zone_axis_cartesian is not None:
+ proj_z = np.array(zone_axis_cartesian)
+ else:
+ proj_z = np.array([0, 0, 1])
+
+ if proj_x_lattice is not None:
+ proj_x = np.array(proj_x_lattice)
+ if proj_x.shape[0] == 4:
+ proj_x = self.hexagonal_to_lattice(proj_x)
+ proj_x = self.lattice_to_cartesian(proj_x)
+ elif proj_x_cartesian is not None:
+ proj_x = np.array(proj_x_cartesian)
+ else:
+ if np.abs(proj_z[2]) > 1 - 1e-6:
+ proj_x = np.cross(np.array([0, 1, 0]), proj_z)
+ else:
+ proj_x = np.array([0, 0, -1])
+
+ # Generate orthogonal coordinate system, normalize
+ proj_y = np.cross(proj_z, proj_x)
+ proj_x = np.cross(proj_y, proj_z)
+ proj_x = proj_x / np.linalg.norm(proj_x)
+ proj_y = proj_y / np.linalg.norm(proj_y)
+ proj_z = proj_z / np.linalg.norm(proj_z)
+
+ return np.vstack((proj_x, proj_y, proj_z)).T
+
+ def excitation_errors(
+ self,
+ g,
+ foil_normal=None,
+ ):
+ """
+ Calculate the excitation errors, assuming k0 = [0, 0, -1/lambda].
+ If foil normal is not specified, we assume it is [0,0,-1].
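+
+ With k0 = [0, 0, -1/lambda] and no foil normal, this evaluates
+ sg = (2*g_z - lambda*|g|^2) / (2 - 2*lambda*g_z),
+ i.e. the signed deviation of each reciprocal lattice point from the Ewald sphere.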
+ """
+ if foil_normal is None:
+ return (2 * g[2, :] - self.wavelength * np.sum(g * g, axis=0)) / (
+ 2 - 2 * self.wavelength * g[2, :]
+ )
+ else:
+ return (2 * g[2, :] - self.wavelength * np.sum(g * g, axis=0)) / (
+ 2 * self.wavelength * np.sum(g * foil_normal[:, None], axis=0)
+ - 2 * foil_normal[2]
+ )
+
+ def calculate_bragg_peak_histogram(
+ self,
+ bragg_peaks,
+ bragg_k_power=1.0,
+ bragg_intensity_power=1.0,
+ k_min=0.0,
+ k_max=None,
+ k_step=0.005,
+ ):
+ """
+ Prepare experimental Bragg peaks for lattice parameter or unit cell fitting.
+
+ Args:
+ bragg_peaks (BraggVectors): Input Bragg vectors.
+ bragg_k_power (float): input Bragg peak intensities are multiplied by k**bragg_k_power
+ to change the weighting of longer scattering vectors
+ bragg_intensity_power (float): input Bragg peak intensities are raised to the power bragg_intensity_power.
+ k_min (float): min k value for fitting range (Å^-1)
+ k_max (float): max k value for fitting range (Å^-1)
+ k_step (float): step size of k in fitting range (Å^-1)
+
+ Returns:
+ k (np array): k coordinates of the histogram bins (Å^-1)
+ int_exp (np array): normalized histogram of Bragg peak intensities
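+
+ Example (a minimal sketch; bragg_peaks is an experimental BraggVectors object):
+
+ >>> k, int_exp = xtal.calculate_bragg_peak_histogram(
+ ... bragg_peaks, k_min=0.0, k_max=2.0, k_step=0.005
+ ... )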
+ """
+
+ # k coordinates
+ if k_max is None:
+ k_max = self.k_max
+ k = np.arange(k_min, k_max + k_step, k_step)
+ k_num = k.shape[0]
+
+ # set rotate and ellipse based on their availability
+ rotate = bragg_peaks.calibration.get_QR_rotation_degrees()
+ ellipse = bragg_peaks.calibration.get_ellipse()
+ rotate = False if rotate is None else True
+ ellipse = False if ellipse is None else True
+
+ # concatenate all peaks
+ bigpl = np.concatenate(
+ [
+ bragg_peaks.get_vectors(
+ rx,
+ ry,
+ center=True,
+ ellipse=ellipse,
+ pixel=True,
+ rotate=rotate,
+ ).data
+ for rx in range(bragg_peaks.shape[0])
+ for ry in range(bragg_peaks.shape[1])
+ ]
+ )
+ qr = np.sqrt(bigpl["qx"] ** 2 + bigpl["qy"] ** 2)
+ int_meas = bigpl["intensity"]
+
+ # get discrete plot from structure factor amplitudes
+ int_exp = np.zeros_like(k)
+ k_px = (qr - k_min) / k_step
+ kf = np.floor(k_px).astype("int")
+ dk = k_px - kf
+
+ sub = np.logical_and(kf >= 0, kf < k_num)
+ int_exp = np.bincount(
+ np.floor(k_px[sub]).astype("int"),
+ weights=(1 - dk[sub]) * int_meas[sub],
+ minlength=k_num,
+ )
+ sub = np.logical_and(k_px >= -1, k_px < k_num - 1)
+ int_exp += np.bincount(
+ np.floor(k_px[sub] + 1).astype("int"),
+ weights=dk[sub] * int_meas[sub],
+ minlength=k_num,
+ )
+ int_exp = (int_exp**bragg_intensity_power) * (k**bragg_k_power)
+ int_exp /= np.max(int_exp)
+ return k, int_exp
+
+
+def generate_moire_diffraction_pattern(
+ bragg_peaks_0,
+ bragg_peaks_1,
+ thresh_0=0.0002,
+ thresh_1=0.0002,
+ exx_1=0.0,
+ eyy_1=0.0,
+ exy_1=0.0,
+ phi_1=0.0,
+ power=2.0,
+):
+ """
+ Calculate a Moire lattice from 2 parent diffraction patterns. The second lattice can be rotated
+ and strained with respect to the original lattice. Note that this strain is applied in real space,
+ and so the inverse of the calculated infinitesimal strain tensor is applied.
+
+ Parameters
+ --------
+ bragg_peaks_0: BraggVector
+ Bragg vectors for parent lattice 0.
+ bragg_peaks_1: BraggVector
+ Bragg vectors for parent lattice 1.
+ thresh_0: float
+ Intensity threshold for structure factors from lattice 0.
+ thresh_1: float
+ Intensity threshold for structure factors from lattice 1.
+ exx_1: float
+ Strain of lattice 1 in x direction (vertical) in real space.
+ eyy_1: float
+ Strain of lattice 1 in y direction (horizontal) in real space.
+ exy_1: float
+ Shear strain of lattice 1 in (x,y) direction (diagonal) in real space.
+ phi_1: float
+ Rotation of lattice 1 in real space.
+ power: float
+ Plotting power law (default is amplitude**2.0, i.e. intensity).
+
+ Returns
+ --------
+ parent_peaks_0, parent_peaks_1, moire_peaks: BraggVectors
+ Bragg vectors for the rotated & strained parent lattices
+ and the moire lattice
+
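+ Example (a minimal sketch; a 3 degree real-space twist between the parents):
+
+ >>> parent_0, parent_1, moire = generate_moire_diffraction_pattern(
+ ... bragg_peaks_0, bragg_peaks_1, phi_1=np.deg2rad(3.0)
+ ... )
+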
+ """
+
+ # get intensities of all peaks
+ int0 = bragg_peaks_0["intensity"] ** (power / 2.0)
+ int1 = bragg_peaks_1["intensity"] ** (power / 2.0)
+
+ # peaks above threshold
+ sub0 = int0 >= thresh_0
+ sub1 = int1 >= thresh_1
+
+ # Remove origin (assuming brightest peak)
+ ind0_or = np.argmax(bragg_peaks_0["intensity"])
+ ind1_or = np.argmax(bragg_peaks_1["intensity"])
+ sub0[ind0_or] = False
+ sub1[ind1_or] = False
+ int0_sub = int0[sub0]
+ int1_sub = int1[sub1]
+
+ # Get peaks
+ qx0 = bragg_peaks_0["qx"][sub0]
+ qy0 = bragg_peaks_0["qy"][sub0]
+ qx1_init = bragg_peaks_1["qx"][sub1]
+ qy1_init = bragg_peaks_1["qy"][sub1]
+
+ # peak labels
+ h0 = bragg_peaks_0["h"][sub0]
+ k0 = bragg_peaks_0["k"][sub0]
+ l0 = bragg_peaks_0["l"][sub0]
+ h1 = bragg_peaks_1["h"][sub1]
+ k1 = bragg_peaks_1["k"][sub1]
+ l1 = bragg_peaks_1["l"][sub1]
+
+ # apply strain tensor to lattice 1
+ m = np.array(
+ [
+ [np.cos(phi_1), -np.sin(phi_1)],
+ [np.sin(phi_1), np.cos(phi_1)],
+ ]
+ ) @ np.linalg.inv(
+ np.array(
+ [
+ [1 + exx_1, exy_1 * 0.5],
+ [exy_1 * 0.5, 1 + eyy_1],
+ ]
+ )
+ )
+ qx1 = m[0, 0] * qx1_init + m[0, 1] * qy1_init
+ qy1 = m[1, 0] * qx1_init + m[1, 1] * qy1_init
+
+ # Generate moire lattice
+ ind0, ind1 = np.meshgrid(
+ np.arange(np.sum(sub0)),
+ np.arange(np.sum(sub1)),
+ indexing="ij",
+ )
+ qx = qx0[ind0] + qx1[ind1]
+ qy = qy0[ind0] + qy1[ind1]
+ int_moire = (int0_sub[ind0] * int1_sub[ind1]) ** 0.5
+
+ # moire labels
+ m_h0 = h0[ind0]
+ m_k0 = k0[ind0]
+ m_l0 = l0[ind0]
+ m_h1 = h1[ind1]
+ m_k1 = k1[ind1]
+ m_l1 = l1[ind1]
+
+ # Convert thresholded and moire peaks to BraggVector class
+
+ pl_dtype_parent = np.dtype(
+ [
+ ("qx", "float"),
+ ("qy", "float"),
+ ("intensity", "float"),
+ ("h", "int"),
+ ("k", "int"),
+ ("l", "int"),
+ ]
+ )
+
+ bragg_parent_0 = PointList(np.array([], dtype=pl_dtype_parent))
+ bragg_parent_0.add_data_by_field(
+ [
+ qx0.ravel(),
+ qy0.ravel(),
+ int0_sub.ravel(),
+ h0.ravel(),
+ k0.ravel(),
+ l0.ravel(),
+ ]
+ )
+
+ bragg_parent_1 = PointList(np.array([], dtype=pl_dtype_parent))
+ bragg_parent_1.add_data_by_field(
+ [
+ qx1.ravel(),
+ qy1.ravel(),
+ int1_sub.ravel(),
+ h1.ravel(),
+ k1.ravel(),
+ l1.ravel(),
+ ]
+ )
+
+ pl_dtype = np.dtype(
+ [
+ ("qx", "float"),
+ ("qy", "float"),
+ ("intensity", "float"),
+ ("h0", "int"),
+ ("k0", "int"),
+ ("l0", "int"),
+ ("h1", "int"),
+ ("k1", "int"),
+ ("l1", "int"),
+ ]
+ )
+ bragg_moire = PointList(np.array([], dtype=pl_dtype))
+ bragg_moire.add_data_by_field(
+ [
+ qx.ravel(),
+ qy.ravel(),
+ int_moire.ravel(),
+ m_h0.ravel(),
+ m_k0.ravel(),
+ m_l0.ravel(),
+ m_h1.ravel(),
+ m_k1.ravel(),
+ m_l1.ravel(),
+ ]
+ )
+
+ return bragg_parent_0, bragg_parent_1, bragg_moire
+
+
+def plot_moire_diffraction_pattern(
+ bragg_parent_0,
+ bragg_parent_1,
+ bragg_moire,
+ int_range=(0, 5e-3),
+ k_max=1.0,
+ plot_subpixel=True,
+ labels=None,
+ marker_size_parent=16,
+ marker_size_moire=4,
+ text_size_parent=10,
+ text_size_moire=6,
+ add_labels_parent=False,
+ add_labels_moire=False,
+ dist_labels=0.03,
+ dist_check=0.06,
+ sep_labels=0.03,
+ figsize=(8, 6),
+ returnfig=False,
+):
+ """
+ Plot Moire lattice and parent lattices.
+
+ Parameters
+ --------
+ bragg_parent_0: BraggVector
+ Bragg vectors for parent lattice 0.
+ bragg_parent_1: BraggVector
+ Bragg vectors for parent lattice 1.
+ bragg_moire: BraggVector
+ Bragg vectors for moire lattice.
+ int_range: (float, float)
+ Plotting intensity range for the Moire peaks.
+ k_max: float
+ Max k value of the plotted Moire lattice.
+ plot_subpixel: bool
+ Apply subpixel corrections to the Bragg spot positions.
+ Matplotlib default scatter plot rounds to the nearest pixel.
+ labels: list
+ List of text labels for parent lattices
+ marker_size_parent: float
+ Size of plot markers for the two parent lattices.
+ marker_size_moire: float
+ Size of plot markers for the Moire lattice.
+ text_size_parent: float
+ Label text size for parent lattice.
+ text_size_moire: float
+ Label text size for Moire lattice.
+ add_labels_parent: bool
+ Plot the parent lattice index labels.
+ add_labels_moire: bool
+ Plot the parent lattice index labels for the Moire spots.
+ dist_labels: float
+ Distance to move the labels off the spots.
+ dist_check: float
+ Set to some distance to "push" the labels away from each other if they are within this distance.
+ sep_labels: float
+ Separation distance for labels which are "pushed" apart.
+ figsize: (float,float)
+ Size of output figure.
+ returnfig: bool
+ Return the (fig, ax) handles of the plot.
+
+ Returns
+ --------
+ fig, ax: matplotlib handles (optional)
+ Figure and axes handles for the moire plot.
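+
+ Example (an illustrative sketch, continuing from generate_moire_diffraction_pattern):
+
+ >>> fig, ax = plot_moire_diffraction_pattern(
+ ... parent_0, parent_1, moire, k_max=1.0, returnfig=True
+ ... )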
+ """
+
+ # peak labels
+
+ if labels is None:
+ labels = ("crystal 0", "crystal 1")
+
+ def overline(x):
+ return str(x) if x >= 0 else (r"\overline{" + str(np.abs(x)) + "}")
+
+ # parent 1
+ qx0 = bragg_parent_0["qx"]
+ qy0 = bragg_parent_0["qy"]
+ h0 = bragg_parent_0["h"]
+ k0 = bragg_parent_0["k"]
+ l0 = bragg_parent_0["l"]
+
+ # parent 2
+ qx1 = bragg_parent_1["qx"]
+ qy1 = bragg_parent_1["qy"]
+ h1 = bragg_parent_1["h"]
+ k1 = bragg_parent_1["k"]
+ l1 = bragg_parent_1["l"]
+
+ # moire
+ qx = bragg_moire["qx"]
+ qy = bragg_moire["qy"]
+ m_h0 = bragg_moire["h0"]
+ m_k0 = bragg_moire["k0"]
+ m_l0 = bragg_moire["l0"]
+ m_h1 = bragg_moire["h1"]
+ m_k1 = bragg_moire["k1"]
+ m_l1 = bragg_moire["l1"]
+ int_moire = bragg_moire["intensity"]
+
+ fig = plt.figure(figsize=figsize)
+ ax = fig.add_axes([0.09, 0.09, 0.65, 0.9])
+ ax_labels = fig.add_axes([0.75, 0, 0.25, 1])
+
+ text_params_parent = {
+ "ha": "center",
+ "va": "center",
+ "family": "sans-serif",
+ "fontweight": "normal",
+ "size": text_size_parent,
+ }
+ text_params_moire = {
+ "ha": "center",
+ "va": "center",
+ "family": "sans-serif",
+ "fontweight": "normal",
+ "size": text_size_moire,
+ }
+
+ if plot_subpixel is False:
+ # moire
+ ax.scatter(
+ qy,
+ qx,
+ # color = (0,0,0,1),
+ c=int_moire,
+ s=marker_size_moire,
+ cmap="gray_r",
+ vmin=int_range[0],
+ vmax=int_range[1],
+ antialiased=True,
+ )
+
+ # parent lattices
+ ax.scatter(
+ qy0,
+ qx0,
+ color=(1, 0, 0, 1),
+ s=marker_size_parent,
+ antialiased=True,
+ )
+ ax.scatter(
+ qy1,
+ qx1,
+ color=(0, 0.7, 1, 1),
+ s=marker_size_parent,
+ antialiased=True,
+ )
+
+ # origin
+ ax.scatter(
+ 0,
+ 0,
+ color=(0, 0, 0, 1),
+ s=marker_size_parent,
+ antialiased=True,
+ )
+
+ else:
+ # moire peaks
+ int_all = np.clip(
+ (int_moire - int_range[0]) / (int_range[1] - int_range[0]), 0, 1
+ )
+ keep = np.logical_and.reduce(
+ (qx >= -k_max, qx <= k_max, qy >= -k_max, qy <= k_max)
+ )
+ for x, y, int_marker in zip(qx[keep], qy[keep], int_all[keep]):
+ ax.add_artist(
+ Circle(
+ xy=(y, x),
+ radius=np.sqrt(marker_size_moire) / 800.0,
+ color=(1 - int_marker, 1 - int_marker, 1 - int_marker),
+ )
+ )
+ if add_labels_moire:
+ for a0 in range(qx.size):
+ if keep.ravel()[a0]:
+ x0 = qx.ravel()[a0]
+ y0 = qy.ravel()[a0]
+ d2 = (qx.ravel() - x0) ** 2 + (qy.ravel() - y0) ** 2
+ sub = d2 < dist_check**2
+ xc = np.mean(qx.ravel()[sub])
+ yc = np.mean(qy.ravel()[sub])
+ xp = x0 - xc
+ yp = y0 - yc
+ if xp == 0 and yp == 0.0:
+ xp = x0 - dist_labels
+ yp = y0
+ else:
+ leng = np.linalg.norm((xp, yp))
+ xp = x0 + xp * dist_labels / leng
+ yp = y0 + yp * dist_labels / leng
+
+ ax.text(
+ yp,
+ xp - sep_labels,
+ "$"
+ + overline(m_h0.ravel()[a0])
+ + overline(m_k0.ravel()[a0])
+ + overline(m_l0.ravel()[a0])
+ + "$",
+ c="r",
+ **text_params_moire,
+ )
+ ax.text(
+ yp,
+ xp,
+ "$"
+ + overline(m_h1.ravel()[a0])
+ + overline(m_k1.ravel()[a0])
+ + overline(m_l1.ravel()[a0])
+ + "$",
+ c=(0, 0.7, 1.0),
+ **text_params_moire,
+ )
+
+ keep = np.logical_and.reduce(
+ (qx0 >= -k_max, qx0 <= k_max, qy0 >= -k_max, qy0 <= k_max)
+ )
+ for x, y in zip(qx0[keep], qy0[keep]):
+ ax.add_artist(
+ Circle(
+ xy=(y, x),
+ radius=np.sqrt(marker_size_parent) / 800.0,
+ color=(1, 0, 0),
+ )
+ )
+ if add_labels_parent:
+ for a0 in range(qx0.size):
+ if keep.ravel()[a0]:
+ xp = qx0.ravel()[a0] - dist_labels
+ yp = qy0.ravel()[a0]
+ ax.text(
+ yp,
+ xp,
+ "$"
+ + overline(h0.ravel()[a0])
+ + overline(k0.ravel()[a0])
+ + overline(l0.ravel()[a0])
+ + "$",
+ c="k",
+ **text_params_parent,
+ )
+
+ keep = np.logical_and.reduce(
+ (qx1 >= -k_max, qx1 <= k_max, qy1 >= -k_max, qy1 <= k_max)
+ )
+ for x, y in zip(qx1[keep], qy1[keep]):
+ ax.add_artist(
+ Circle(
+ xy=(y, x),
+ radius=np.sqrt(marker_size_parent) / 800.0,
+ color=(0, 0.7, 1),
+ )
+ )
+ if add_labels_parent:
+ for a0 in range(qx1.size):
+ if keep.ravel()[a0]:
+ xp = qx1.ravel()[a0] - dist_labels
+ yp = qy1.ravel()[a0]
+ ax.text(
+ yp,
+ xp,
+ "$"
+ + overline(h1.ravel()[a0])
+ + overline(k1.ravel()[a0])
+ + overline(l1.ravel()[a0])
+ + "$",
+ c="k",
+ **text_params_parent,
+ )
+
+ # origin
+ ax.add_artist(
+ Circle(
+ xy=(0, 0),
+ radius=np.sqrt(marker_size_parent) / 800.0,
+ color=(0, 0, 0),
+ )
+ )
+
+ ax.set_xlim((-k_max, k_max))
+ ax.set_ylim((-k_max, k_max))
+ ax.set_ylabel("$q_x$ (1/A)")
+ ax.set_xlabel("$q_y$ (1/A)")
+ ax.invert_yaxis()
+
+ # labels
+ ax_labels.scatter(
+ 0,
+ 0,
+ color=(1, 0, 0, 1),
+ s=marker_size_parent,
+ )
+ ax_labels.scatter(
+ 0,
+ -1,
+ color=(0, 0.7, 1, 1),
+ s=marker_size_parent,
+ )
+ ax_labels.scatter(
+ 0,
+ -2,
+ color=(0, 0, 0, 1),
+ s=marker_size_moire,
+ )
+ ax_labels.text(
+ 0.4,
+ -0.2,
+ labels[0],
+ fontsize=14,
+ )
+ ax_labels.text(
+ 0.4,
+ -1.2,
+ labels[1],
+ fontsize=14,
+ )
+ ax_labels.text(
+ 0.4,
+ -2.2,
+ "Moiré lattice",
+ fontsize=14,
+ )
+
+ ax_labels.set_xlim((-1, 4))
+ ax_labels.set_ylim((-21, 1))
+
+ ax_labels.axis("off")
+
+ if returnfig:
+ return fig, ax
diff --git a/py4DSTEM/process/diffraction/crystal_ACOM.py b/py4DSTEM/process/diffraction/crystal_ACOM.py
new file mode 100644
index 000000000..09ba51ffc
--- /dev/null
+++ b/py4DSTEM/process/diffraction/crystal_ACOM.py
@@ -0,0 +1,2555 @@
+import numpy as np
+import matplotlib.pyplot as plt
+from typing import Union, Optional
+from tqdm import tqdm
+
+from emdfile import tqdmnd, PointList, PointListArray
+from py4DSTEM.data import RealSlice
+from py4DSTEM.process.diffraction.utils import Orientation, OrientationMap, axisEqual3D
+from py4DSTEM.process.utils import electron_wavelength_angstrom
+
+from warnings import warn
+
+from numpy.linalg import lstsq
+
+try:
+ import cupy as cp
+except ModuleNotFoundError:
+ cp = None
+
+
+def orientation_plan(
+ self,
+ zone_axis_range: np.ndarray = np.array([[0, 1, 1], [1, 1, 1]]),
+ angle_step_zone_axis: float = 2.0,
+ angle_coarse_zone_axis: float = None,
+ angle_refine_range: float = None,
+ angle_step_in_plane: float = 2.0,
+ accel_voltage: float = 300e3,
+ corr_kernel_size: float = 0.08,
+ radial_power: float = 1.0,
+ intensity_power: float = 0.25, # New default intensity power scaling
+ calculate_correlation_array=True,
+ tol_peak_delete=None,
+ tol_distance: float = 0.01,
+ fiber_axis=None,
+ fiber_angles=None,
+ figsize: Union[list, tuple, np.ndarray] = (6, 6),
+ CUDA: bool = False,
+ progress_bar: bool = True,
+):
+ """
+ Calculate the rotation basis arrays for an SO(3) rotation correlogram.
+
+ Args:
+ zone_axis_range (float): Row vectors give the range for zone axis orientations.
+ If user specifies 2 vectors (2x3 array), we start at [0,0,1]
+ to make z-x-z rotation work.
+ If user specifies 3 vectors (3x3 array), plan will span these vectors.
+ Setting to 'full' as a string will use a hemispherical range.
+ Setting to 'half' as a string will use a quarter sphere range.
+ Setting to 'fiber' as a string will make a spherical cap around a given vector.
+ Setting to 'auto' will use pymatgen to determine the point group symmetry
+ of the structure and choose an appropriate zone_axis_range
+ angle_step_zone_axis (float): Approximate angular step size for zone axis search [degrees]
+ angle_coarse_zone_axis (float): Coarse step size for zone axis search [degrees]. Setting to
+ None uses the same value as angle_step_zone_axis.
+ angle_refine_range (float): Range of angles to use for zone axis refinement. Setting to
+ None uses same value as angle_coarse_zone_axis.
+
+ angle_step_in_plane (float): Approximate angular step size for in-plane rotation [degrees]
+ accel_voltage (float): Accelerating voltage for electrons [Volts]
+ corr_kernel_size (float): Correlation kernel size length in Angstroms
+ radial_power (float): Power for scaling the correlation intensity as a function of the peak radius
+ intensity_power (float): Power for scaling the correlation intensity as a function of the peak intensity
+ calculate_correlation_array (bool): Set to false to skip calculating the correlation array.
+ This is useful when we only want the angular range / rotation matrices.
+ tol_peak_delete (float): Distance to delete peaks for multiple matches.
+ Default is kernel_size * 0.5
+ tol_distance (float): Distance tolerance for radial shell assignment [1/Angstroms]
+ fiber_axis (float): (3,) vector specifying the fiber axis
+ fiber_angles (float): (2,) vector specifying angle range from fiber axis, and in-plane angular range [degrees]
+ figsize (float): (2,) vector giving the figure size
+ CUDA (bool): Use CUDA for the Fourier operations.
+ progress_bar (bool): if False, no progress bar is displayed
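+
+ Example (a minimal sketch; "auto" requires pymatgen to be available):
+
+ >>> xtal.orientation_plan(
+ ... zone_axis_range="auto",
+ ... angle_step_zone_axis=2.0,
+ ... angle_step_in_plane=2.0,
+ ... )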
+ """
+
+ # Store inputs
+ self.accel_voltage = np.asarray(accel_voltage)
+ self.orientation_kernel_size = np.asarray(corr_kernel_size)
+ if tol_peak_delete is None:
+ self.orientation_tol_peak_delete = self.orientation_kernel_size * 0.5
+ else:
+ self.orientation_tol_peak_delete = np.asarray(tol_peak_delete)
+ if fiber_axis is None:
+ self.orientation_fiber_axis = None
+ else:
+ self.orientation_fiber_axis = np.asarray(fiber_axis)
+ if fiber_angles is None:
+ self.orientation_fiber_angles = None
+ else:
+ self.orientation_fiber_angles = np.asarray(fiber_angles)
+ self.CUDA = CUDA
+
+ # Calculate wavelength
+ self.wavelength = electron_wavelength_angstrom(self.accel_voltage)
+
+ # store the radial and intensity scaling to use later for generating test patterns
+ self.orientation_radial_power = radial_power
+ self.orientation_intensity_power = intensity_power
+
+ # Calculate the ratio between coarse and fine refinement
+ if angle_coarse_zone_axis is not None:
+ self.orientation_refine = True
+ self.orientation_refine_ratio = np.round(
+ angle_coarse_zone_axis / angle_step_zone_axis
+ ).astype("int")
+ self.orientation_angle_coarse = angle_coarse_zone_axis
+ if angle_refine_range is None:
+ self.orientation_refine_range = angle_coarse_zone_axis
+ else:
+ self.orientation_refine_range = angle_refine_range
+ else:
+ self.orientation_refine_ratio = 1.0
+ self.orientation_refine = False
+
+ if self.pymatgen_available:
+ from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
+ from pymatgen.core.structure import Structure
+
+ structure = Structure(
+ self.lat_real, self.numbers, self.positions, coords_are_cartesian=False
+ )
+ self.pointgroup = SpacegroupAnalyzer(structure)
+
+ # Handle the "auto" case first, since it works by overriding zone_axis_range,
+ # fiber_axis, and fiber_angles then using the regular parser:
+ if isinstance(zone_axis_range, str) and zone_axis_range == "auto":
+ assert (
+ self.pointgroup.get_point_group_symbol() in orientation_ranges
+ ), "Unrecognized pointgroup returned by pymatgen!"
+
+ zone_axis_range, fiber_axis, fiber_angles = orientation_ranges[
+ self.pointgroup.get_point_group_symbol()
+ ]
+ if isinstance(zone_axis_range, list):
+ zone_axis_range = np.array(zone_axis_range)
+ elif zone_axis_range == "fiber":
+ self.orientation_fiber_axis = np.asarray(fiber_axis)
+ self.orientation_fiber_angles = np.asarray(fiber_angles)
+
+ print(
+ f"Automatically detected point group {self.pointgroup.get_point_group_symbol()},\n"
+ f" using arguments: zone_axis_range = \n{zone_axis_range}, \n fiber_axis={fiber_axis}, fiber_angles={fiber_angles}."
+ )
+
+ if isinstance(zone_axis_range, str):
+ if (
+ zone_axis_range == "fiber"
+ and fiber_axis is not None
+ and fiber_angles is not None
+ ):
+ # Determine vector ranges
+ self.orientation_fiber_axis = np.array(
+ self.orientation_fiber_axis, dtype="float"
+ )
+ # if self.cartesian_directions:
+ self.orientation_fiber_axis = self.orientation_fiber_axis / np.linalg.norm(
+ self.orientation_fiber_axis
+ )
+
+ # update fiber axis to be centered on the 1st unit cell vector
+ v3 = np.cross(self.orientation_fiber_axis, self.lat_real[0, :])
+ v2 = np.cross(
+ v3,
+ self.orientation_fiber_axis,
+ )
+ v2 = v2 / np.linalg.norm(v2)
+ v3 = v3 / np.linalg.norm(v3)
+
+ if self.orientation_fiber_angles[0] == 0:
+ self.orientation_zone_axis_range = np.vstack(
+ (self.orientation_fiber_axis, v2, v3)
+ ).astype("float")
+ else:
+ if self.orientation_fiber_angles[0] == 180:
+ theta = np.pi / 2.0
+ else:
+ theta = self.orientation_fiber_angles[0] * np.pi / 180.0
+ if (
+ self.orientation_fiber_angles[1] == 180
+ or self.orientation_fiber_angles[1] == 360
+ ):
+ phi = np.pi / 2.0
+ else:
+ phi = self.orientation_fiber_angles[1] * np.pi / 180.0
+
+ # Generate zone axis range
+ v2output = (
+ self.orientation_fiber_axis * np.cos(theta)
+ + (v2 * np.sin(theta)) * np.cos(phi / 2)
+ - (v3 * np.sin(theta)) * np.sin(phi / 2)
+ )
+ v3output = (
+ self.orientation_fiber_axis * np.cos(theta)
+ + (v2 * np.sin(theta)) * np.cos(phi / 2)
+ + (v3 * np.sin(theta)) * np.sin(phi / 2)
+ )
+
+ self.orientation_zone_axis_range = np.vstack(
+ (self.orientation_fiber_axis, v2output, v3output)
+ ).astype("float")
+
+ self.orientation_full = False
+ self.orientation_half = False
+ self.orientation_fiber = True
+ else:
+ self.orientation_zone_axis_range = np.array(
+ [[0, 0, 1], [0, 1, 0], [1, 0, 0]]
+ )
+ if zone_axis_range == "full":
+ self.orientation_full = True
+ self.orientation_half = False
+ self.orientation_fiber = False
+ elif zone_axis_range == "half":
+ self.orientation_full = False
+ self.orientation_half = True
+ self.orientation_fiber = False
+ else:
+ if zone_axis_range == "fiber" and fiber_axis is None:
+ raise ValueError(
+ "For fiber zone axes, you must specify the fiber axis and angular ranges"
+ )
+ else:
+ raise ValueError(
+ "Zone axis range must be a 2x3 array, 3x3 array, or full, half or fiber"
+ )
+
+ else:
+ self.orientation_zone_axis_range = np.array(zone_axis_range, dtype="float")
+
+ # Define 3 vectors which span zone axis orientation range, normalize
+ if zone_axis_range.shape[0] == 3:
+ self.orientation_zone_axis_range = np.array(
+ self.orientation_zone_axis_range, dtype="float"
+ )
+ self.orientation_zone_axis_range[0, :] /= np.linalg.norm(
+ self.orientation_zone_axis_range[0, :]
+ )
+ self.orientation_zone_axis_range[1, :] /= np.linalg.norm(
+ self.orientation_zone_axis_range[1, :]
+ )
+ self.orientation_zone_axis_range[2, :] /= np.linalg.norm(
+ self.orientation_zone_axis_range[2, :]
+ )
+
+ elif zone_axis_range.shape[0] == 2:
+ self.orientation_zone_axis_range = np.vstack(
+ (
+ np.array([0, 0, 1]),
+ np.array(self.orientation_zone_axis_range, dtype="float"),
+ )
+ ).astype("float")
+ self.orientation_zone_axis_range[1, :] /= np.linalg.norm(
+ self.orientation_zone_axis_range[1, :]
+ )
+ self.orientation_zone_axis_range[2, :] /= np.linalg.norm(
+ self.orientation_zone_axis_range[2, :]
+ )
+ self.orientation_full = False
+ self.orientation_half = False
+ self.orientation_fiber = False
+
+ # Solve for number of angular steps in zone axis (rads)
+ angle_u_v = np.arccos(
+ np.sum(
+ self.orientation_zone_axis_range[0, :]
+ * self.orientation_zone_axis_range[1, :]
+ )
+ )
+ angle_u_w = np.arccos(
+ np.sum(
+ self.orientation_zone_axis_range[0, :]
+ * self.orientation_zone_axis_range[2, :]
+ )
+ )
+ step = np.maximum(
+ (180 / np.pi) * angle_u_v / angle_step_zone_axis,
+ (180 / np.pi) * angle_u_w / angle_step_zone_axis,
+ )
+ self.orientation_zone_axis_steps = (
+ np.round(step / self.orientation_refine_ratio) * self.orientation_refine_ratio
+ ).astype(np.integer)
+
+ if self.orientation_fiber and self.orientation_fiber_angles[0] == 0:
+ self.orientation_num_zones = int(1)
+ self.orientation_vecs = np.zeros((1, 3))
+ self.orientation_vecs[0, :] = self.orientation_zone_axis_range[0, :]
+ self.orientation_inds = np.zeros((1, 3), dtype="int")
+
+ else:
+ # Generate points spanning the zone axis range
+ # Calculate points along u and v using the SLERP formula
+ # https://en.wikipedia.org/wiki/Slerp
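+ # slerp(p0, p1; t) = [sin((1 - t) * omega) * p0 + sin(t * omega) * p1] / sin(omega),
+ # where omega is the angle between the unit vectors p0 and p1.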
+ weights = np.linspace(0, 1, self.orientation_zone_axis_steps + 1)
+ pv = self.orientation_zone_axis_range[0, :] * np.sin(
+ (1 - weights[:, None]) * angle_u_v
+ ) / np.sin(angle_u_v) + self.orientation_zone_axis_range[1, :] * np.sin(
+ weights[:, None] * angle_u_v
+ ) / np.sin(
+ angle_u_v
+ )
+
+ # Calculate points along u and w using the SLERP formula
+ pw = self.orientation_zone_axis_range[0, :] * np.sin(
+ (1 - weights[:, None]) * angle_u_w
+ ) / np.sin(angle_u_w) + self.orientation_zone_axis_range[2, :] * np.sin(
+ weights[:, None] * angle_u_w
+ ) / np.sin(
+ angle_u_w
+ )
+
+ # Init array to hold all points
+ self.orientation_num_zones = (
+ (self.orientation_zone_axis_steps + 1)
+ * (self.orientation_zone_axis_steps + 2)
+ / 2
+ ).astype(np.integer)
+ self.orientation_vecs = np.zeros((self.orientation_num_zones, 3))
+ self.orientation_vecs[0, :] = self.orientation_zone_axis_range[0, :]
+ self.orientation_inds = np.zeros((self.orientation_num_zones, 3), dtype="int")
+
+ # Calculate zone axis points on the unit sphere with another application of SLERP,
+ # or circular arc SLERP for fiber texture
+ for a0 in np.arange(1, self.orientation_zone_axis_steps + 1):
+ inds = np.arange(a0 * (a0 + 1) / 2, a0 * (a0 + 1) / 2 + a0 + 1).astype(
+ np.integer
+ )
+
+ p0 = pv[a0, :]
+ p1 = pw[a0, :]
+
+ weights = np.linspace(0, 1, a0 + 1)
+
+ if self.orientation_fiber:
+ # For fiber texture, place points on circular arc perpendicular to the fiber axis
+ self.orientation_vecs[inds, :] = p0[None, :]
+
+ p_proj = (
+ np.dot(p0, self.orientation_fiber_axis)
+ * self.orientation_fiber_axis
+ )
+ p0_sub = p0 - p_proj
+ p1_sub = p1 - p_proj
+
+ angle_p_sub = np.arccos(
+ np.sum(p0_sub * p1_sub)
+ / np.linalg.norm(p0_sub)
+ / np.linalg.norm(p1_sub)
+ )
+
+ self.orientation_vecs[inds, :] = (
+ p_proj
+ + p0_sub[None, :]
+ * np.sin((1 - weights[:, None]) * angle_p_sub)
+ / np.sin(angle_p_sub)
+ + p1_sub[None, :]
+ * np.sin(weights[:, None] * angle_p_sub)
+ / np.sin(angle_p_sub)
+ )
+ else:
+ angle_p = np.arccos(np.sum(p0 * p1))
+
+ self.orientation_vecs[inds, :] = p0[None, :] * np.sin(
+ (1 - weights[:, None]) * angle_p
+ ) / np.sin(angle_p) + p1[None, :] * np.sin(
+ weights[:, None] * angle_p
+ ) / np.sin(
+ angle_p
+ )
+
+ self.orientation_inds[inds, 0] = a0
+ self.orientation_inds[inds, 1] = np.arange(a0 + 1)
+
+ if self.orientation_fiber and self.orientation_fiber_angles[0] == 180:
+ # Mirror about the equator of fiber_zone_axis
+ m = np.identity(3) - 2 * (
+ self.orientation_fiber_axis[:, None] @ self.orientation_fiber_axis[None, :]
+ )
+
+ vec_new = np.copy(self.orientation_vecs) @ m
+ orientation_sector = np.zeros(vec_new.shape[0], dtype="int")
+
+ keep = np.zeros(vec_new.shape[0], dtype="bool")
+ for a0 in range(keep.size):
+ if (
+ np.sqrt(
+ np.min(
+ np.sum((self.orientation_vecs - vec_new[a0, :]) ** 2, axis=1)
+ )
+ )
+ > tol_distance
+ ):
+ keep[a0] = True
+
+ self.orientation_vecs = np.vstack((self.orientation_vecs, vec_new[keep, :]))
+ self.orientation_num_zones = self.orientation_vecs.shape[0]
+
+ self.orientation_inds = np.vstack(
+ (self.orientation_inds, self.orientation_inds[keep, :])
+ ).astype("int")
+ self.orientation_inds[:, 2] = np.hstack(
+ (orientation_sector, np.ones(np.sum(keep), dtype="int"))
+ )
+
+ # Fiber texture angle 1 extend to 180 degree angular range if needed
+ if (
+ self.orientation_fiber
+ and self.orientation_fiber_angles[0] != 0
+ and (
+ self.orientation_fiber_angles[1] == 180
+ or self.orientation_fiber_angles[1] == 360
+ )
+ ):
+ # Mirror about the axes 0 and 1
+ n = np.cross(
+ self.orientation_zone_axis_range[0, :],
+ self.orientation_zone_axis_range[1, :],
+ )
+ n = n / np.linalg.norm(n)
+
+ # n = self.orientation_zone_axis_range[2,:]
+ m = np.identity(3) - 2 * (n[:, None] @ n[None, :])
+
+ vec_new = np.copy(self.orientation_vecs) @ m
+ orientation_sector = np.zeros(vec_new.shape[0], dtype="int")
+
+ keep = np.zeros(vec_new.shape[0], dtype="bool")
+ for a0 in range(keep.size):
+ if (
+ np.sqrt(
+ np.min(
+ np.sum((self.orientation_vecs - vec_new[a0, :]) ** 2, axis=1)
+ )
+ )
+ > tol_distance
+ ):
+ keep[a0] = True
+
+ self.orientation_vecs = np.vstack((self.orientation_vecs, vec_new[keep, :]))
+ self.orientation_num_zones = self.orientation_vecs.shape[0]
+
+ self.orientation_inds = np.vstack(
+ (self.orientation_inds, self.orientation_inds[keep, :])
+ ).astype("int")
+ self.orientation_inds[:, 2] = np.hstack(
+ (orientation_sector, np.ones(np.sum(keep), dtype="int"))
+ )
+ # Fiber texture extend to 360 angular range if needed
+ if (
+ self.orientation_fiber
+ and self.orientation_fiber_angles[0] != 0
+ and self.orientation_fiber_angles[1] == 360
+ ):
+ # Mirror about the axes 0 and 2
+ n = np.cross(
+ self.orientation_zone_axis_range[0, :],
+ self.orientation_zone_axis_range[2, :],
+ )
+ n = n / np.linalg.norm(n)
+
+ # n = self.orientation_zone_axis_range[2,:]
+ m = np.identity(3) - 2 * (n[:, None] @ n[None, :])
+
+ vec_new = np.copy(self.orientation_vecs) @ m
+ orientation_sector = np.zeros(vec_new.shape[0], dtype="int")
+
+ keep = np.zeros(vec_new.shape[0], dtype="bool")
+ for a0 in range(keep.size):
+ if (
+ np.sqrt(
+ np.min(
+ np.sum((self.orientation_vecs - vec_new[a0, :]) ** 2, axis=1)
+ )
+ )
+ > tol_distance
+ ):
+ keep[a0] = True
+
+ self.orientation_vecs = np.vstack((self.orientation_vecs, vec_new[keep, :]))
+ self.orientation_num_zones = self.orientation_vecs.shape[0]
+
+ self.orientation_inds = np.vstack(
+ (self.orientation_inds, self.orientation_inds[keep, :])
+ ).astype("int")
+ self.orientation_inds[:, 2] = np.hstack(
+ (orientation_sector, np.ones(np.sum(keep), dtype="int"))
+ )
+
+ # expand to quarter sphere if needed
+ if self.orientation_half or self.orientation_full:
+ vec_new = np.copy(self.orientation_vecs) * np.array([-1, 1, 1])
+ orientation_sector = np.zeros(vec_new.shape[0], dtype="int")
+
+ keep = np.zeros(vec_new.shape[0], dtype="bool")
+ for a0 in range(keep.size):
+ if (
+ np.sqrt(
+ np.min(
+ np.sum((self.orientation_vecs - vec_new[a0, :]) ** 2, axis=1)
+ )
+ )
+ > tol_distance
+ ):
+ keep[a0] = True
+
+ self.orientation_vecs = np.vstack((self.orientation_vecs, vec_new[keep, :]))
+ self.orientation_num_zones = self.orientation_vecs.shape[0]
+
+ self.orientation_inds = np.vstack(
+ (self.orientation_inds, self.orientation_inds[keep, :])
+ ).astype("int")
+ self.orientation_inds[:, 2] = np.hstack(
+ (orientation_sector, np.ones(np.sum(keep), dtype="int"))
+ )
+
+ # expand to hemisphere if needed
+ if self.orientation_full:
+ vec_new = np.copy(self.orientation_vecs) * np.array([1, -1, 1])
+
+ keep = np.zeros(vec_new.shape[0], dtype="bool")
+ for a0 in range(keep.size):
+ if (
+ np.sqrt(
+ np.min(
+ np.sum((self.orientation_vecs - vec_new[a0, :]) ** 2, axis=1)
+ )
+ )
+ > tol_distance
+ ):
+ keep[a0] = True
+
+ self.orientation_vecs = np.vstack((self.orientation_vecs, vec_new[keep, :]))
+ self.orientation_num_zones = self.orientation_vecs.shape[0]
+
+ orientation_sector = np.hstack(
+ (self.orientation_inds[:, 2], self.orientation_inds[keep, 2] + 2)
+ )
+ self.orientation_inds = np.vstack(
+ (self.orientation_inds, self.orientation_inds[keep, :])
+ ).astype("int")
+ self.orientation_inds[:, 2] = orientation_sector
+
+ # If needed, create coarse orientation sieve
+ if self.orientation_refine:
+ self.orientation_sieve = np.logical_and(
+ np.mod(self.orientation_inds[:, 0], self.orientation_refine_ratio) == 0,
+ np.mod(self.orientation_inds[:, 1], self.orientation_refine_ratio) == 0,
+ )
+ if self.CUDA:
+ self.orientation_sieve_CUDA = cp.asarray(self.orientation_sieve)
+
+ # Convert to spherical coordinates
+ elev = np.arctan2(
+ np.hypot(self.orientation_vecs[:, 0], self.orientation_vecs[:, 1]),
+ self.orientation_vecs[:, 2],
+ )
+ # azim = np.pi / 2 + np.arctan2(
+ # self.orientation_vecs[:, 1], self.orientation_vecs[:, 0]
+ # )
+ azim = np.arctan2(self.orientation_vecs[:, 0], self.orientation_vecs[:, 1])
+
+ # Solve for number of angular steps along in-plane rotation direction
+ self.orientation_in_plane_steps = np.round(360 / angle_step_in_plane).astype(
+ np.integer
+ )
+
+ # Calculate -z angles (Euler angle 3)
+ self.orientation_gamma = np.linspace(
+ 0, 2 * np.pi, self.orientation_in_plane_steps, endpoint=False
+ )
+
+ # Determine the radii of all spherical shells
+ radii_test = np.round(self.g_vec_leng / tol_distance) * tol_distance
+ radii = np.unique(radii_test)
+ # Remove zero beam
+ keep = np.abs(radii) > tol_distance
+ self.orientation_shell_radii = radii[keep]
+
+ # init
+ self.orientation_shell_index = -1 * np.ones(self.g_vec_all.shape[1], dtype="int")
+ self.orientation_shell_count = np.zeros(self.orientation_shell_radii.size)
+
+ # Assign each structure factor point to a radial shell
+ for a0 in range(self.orientation_shell_radii.size):
+ sub = np.abs(self.orientation_shell_radii[a0] - radii_test) <= tol_distance / 2
+
+ self.orientation_shell_index[sub] = a0
+ self.orientation_shell_count[a0] = np.sum(sub)
+ self.orientation_shell_radii[a0] = np.mean(self.g_vec_leng[sub])
+
+ # init storage arrays
+ self.orientation_rotation_angles = np.zeros((self.orientation_num_zones, 3))
+ self.orientation_rotation_matrices = np.zeros((self.orientation_num_zones, 3, 3))
+
+ # If possible, Get symmetry operations for this spacegroup, store in matrix form
+ if self.pymatgen_available:
+ # get operators
+ ops = self.pointgroup.get_point_group_operations()
+
+ # Inverse of lattice
+ zone_axis_range_inv = np.linalg.inv(self.orientation_zone_axis_range)
+
+ # init
+ num_sym = len(ops)
+ self.symmetry_operators = np.zeros((num_sym, 3, 3))
+ self.symmetry_reduction = np.zeros((num_sym, 3, 3))
+
+ # calculate symmetry and reduction matrices
+ for a0 in range(num_sym):
+ self.symmetry_operators[a0] = (
+ self.lat_inv.T @ ops[a0].rotation_matrix.T @ self.lat_real
+ )
+ self.symmetry_reduction[a0] = (
+ zone_axis_range_inv.T @ self.symmetry_operators[a0]
+ ).T
+
+ # Remove duplicates
+ keep = np.ones(num_sym, dtype="bool")
+ for a0 in range(num_sym):
+ if keep[a0]:
+ diff = np.sum(
+ np.abs(self.symmetry_operators - self.symmetry_operators[a0]),
+ axis=(1, 2),
+ )
+ sub = diff < 1e-3
+ sub[: a0 + 1] = False
+ keep[sub] = False
+ self.symmetry_operators = self.symmetry_operators[keep]
+ self.symmetry_reduction = self.symmetry_reduction[keep]
+
+ if (
+ self.orientation_fiber_angles is not None
+ and np.abs(self.orientation_fiber_angles[0] - 180.0) < 1e-3
+ ):
+ zone_axis_range_flip = self.orientation_zone_axis_range.copy()
+ zone_axis_range_flip[0, :] = -1 * zone_axis_range_flip[0, :]
+ zone_axis_range_inv = np.linalg.inv(zone_axis_range_flip)
+
+ num_sym = self.symmetry_operators.shape[0]
+ self.symmetry_operators = np.tile(self.symmetry_operators, [2, 1, 1])
+ self.symmetry_reduction = np.tile(self.symmetry_reduction, [2, 1, 1])
+
+ for a0 in range(num_sym):
+ self.symmetry_reduction[a0 + num_sym] = (
+ zone_axis_range_inv.T @ self.symmetry_operators[a0 + num_sym]
+ ).T
+
+ # Calculate rotation matrices for zone axes
+ for a0 in np.arange(self.orientation_num_zones):
+ m1z = np.array(
+ [
+ [np.cos(azim[a0]), np.sin(azim[a0]), 0],
+ [-np.sin(azim[a0]), np.cos(azim[a0]), 0],
+ [0, 0, 1],
+ ]
+ )
+ m2x = np.array(
+ [
+ [1, 0, 0],
+ [0, np.cos(elev[a0]), np.sin(elev[a0])],
+ [0, -np.sin(elev[a0]), np.cos(elev[a0])],
+ ]
+ )
+ m3z = np.array(
+ [
+ [np.cos(azim[a0]), -np.sin(azim[a0]), 0],
+ [np.sin(azim[a0]), np.cos(azim[a0]), 0],
+ [0, 0, 1],
+ ]
+ )
+ self.orientation_rotation_matrices[a0, :, :] = m1z @ m2x @ m3z
+ self.orientation_rotation_angles[a0, :] = [azim[a0], elev[a0], -azim[a0]]
+
+ # Calculate reference arrays for all orientations
+ k0 = np.array([0.0, 0.0, -1.0 / self.wavelength])
+ n = np.array([0.0, 0.0, -1.0])
+
+ if calculate_correlation_array:
+ # initialize empty correlation array
+ self.orientation_ref = np.zeros(
+ (
+ self.orientation_num_zones,
+ np.size(self.orientation_shell_radii),
+ self.orientation_in_plane_steps,
+ ),
+ dtype="float",
+ )
+
+ for a0 in tqdmnd(
+ np.arange(self.orientation_num_zones),
+ desc="Orientation plan",
+ unit=" zone axes",
+ disable=not progress_bar,
+ ):
+ # reciprocal lattice spots and excitation errors
+ g = self.orientation_rotation_matrices[a0, :, :].T @ self.g_vec_all
+ sg = self.excitation_errors(g)
+
+ # Keep only points that will contribute to this orientation plan slice
+ keep = np.abs(sg) < self.orientation_kernel_size
+
+ # in-plane rotation angle
+ phi = np.arctan2(g[1, :], g[0, :])
+
+ # Loop over all peaks
+ for a1 in np.arange(self.g_vec_all.shape[1]):
+ ind_radial = self.orientation_shell_index[a1]
+
+ if keep[a1] and ind_radial >= 0:
+ # 2D orientation plan
+ self.orientation_ref[a0, ind_radial, :] += (
+ np.power(self.orientation_shell_radii[ind_radial], radial_power)
+ * np.power(self.struct_factors_int[a1], intensity_power)
+ * np.maximum(
+ 1
+ - np.sqrt(
+ sg[a1] ** 2
+ + (
+ (
+ np.mod(
+ self.orientation_gamma - phi[a1] + np.pi,
+ 2 * np.pi,
+ )
+ - np.pi
+ )
+ * self.orientation_shell_radii[ind_radial]
+ )
+ ** 2
+ )
+ / self.orientation_kernel_size,
+ 0,
+ )
+ )
+
+ orientation_ref_norm = np.sqrt(np.sum(self.orientation_ref[a0, :, :] ** 2))
+ if orientation_ref_norm > 0:
+ self.orientation_ref[a0, :, :] /= orientation_ref_norm
+
+ # Maximum value
+ self.orientation_ref_max = np.max(np.real(self.orientation_ref))
+
+ # Fourier domain along angular axis
+ if self.CUDA:
+ self.orientation_ref = cp.asarray(self.orientation_ref)
+ self.orientation_ref = cp.conj(cp.fft.fft(self.orientation_ref))
+ else:
+ self.orientation_ref = np.conj(np.fft.fft(self.orientation_ref))
+
+
+def match_orientations(
+ self,
+ bragg_peaks_array: PointListArray,
+ num_matches_return: int = 1,
+ min_angle_between_matches_deg=None,
+ min_number_peaks: int = 3,
+ inversion_symmetry: bool = True,
+ multiple_corr_reset: bool = True,
+ return_orientation: bool = True,
+ progress_bar: bool = True,
+):
+ """
+ Parameters
+ --------
+ bragg_peaks_array: PointListArray
+ PointListArray containing the Bragg peaks and intensities, with calibrations applied
+    num_matches_return: int
+        Return this many matches, stored along the 3rd dimension of the orientation matrices
+    min_angle_between_matches_deg: float, optional
+        Minimum angle between the zone axes of multiple matches, in degrees.
+        Note that only the zone axes are compared; in-plane rotations are not considered, since multiple in-plane matches are possible.
+ min_number_peaks: int
+ Minimum number of peaks required to perform ACOM matching
+ inversion_symmetry: bool
+ check for inversion symmetry in the matches
+ multiple_corr_reset: bool
+ keep original correlation score for multiple matches
+ return_orientation: bool
+ Return orientation map from function for inspection.
+ The map is always stored in the Crystal object.
+ progress_bar: bool
+ Show or hide the progress bar
+
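+    Examples
+    --------
+    A minimal usage sketch (assumes a Crystal instance ``crystal`` on which
+    ``orientation_plan`` has already been run; argument values are illustrative):
+
+    >>> orientation_map = crystal.match_orientations(
+    ...     bragg_peaks_array,
+    ...     num_matches_return=2,
+    ...     min_number_peaks=5,
+    ... )
+    >>> orientation_map.corr.shape  # (num_x, num_y, num_matches)
+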
+ """
+ orientation_map = OrientationMap(
+ num_x=bragg_peaks_array.shape[0],
+ num_y=bragg_peaks_array.shape[1],
+ num_matches=num_matches_return,
+ )
+
+ # check cal state
+ if bragg_peaks_array.calstate["ellipse"] is False:
+ ellipse = False
+ warn("Warning: bragg peaks not elliptically calibrated")
+ else:
+ ellipse = True
+ if bragg_peaks_array.calstate["rotate"] is False:
+ rotate = False
+ warn("bragg peaks not rotationally calibrated")
+ else:
+ rotate = True
+
+ for rx, ry in tqdmnd(
+ *bragg_peaks_array.shape,
+ desc="Matching Orientations",
+ unit=" PointList",
+ disable=not progress_bar,
+ ):
+ vectors = bragg_peaks_array.get_vectors(
+ scan_x=rx,
+ scan_y=ry,
+ center=True,
+ ellipse=ellipse,
+ pixel=True,
+ rotate=rotate,
+ )
+
+ orientation = self.match_single_pattern(
+ bragg_peaks=vectors,
+ num_matches_return=num_matches_return,
+ min_angle_between_matches_deg=min_angle_between_matches_deg,
+ min_number_peaks=min_number_peaks,
+ inversion_symmetry=inversion_symmetry,
+ multiple_corr_reset=multiple_corr_reset,
+ plot_corr=False,
+ verbose=False,
+ )
+
+ orientation_map.set_orientation(orientation, rx, ry)
+
+ # assign and return
+ self.orientation_map = orientation_map
+
+ if return_orientation:
+ return orientation_map
+ else:
+ return
+
+
+def match_single_pattern(
+ self,
+ bragg_peaks: PointList,
+ num_matches_return: int = 1,
+ min_angle_between_matches_deg=None,
+ min_number_peaks=3,
+ inversion_symmetry=True,
+ multiple_corr_reset=True,
+ plot_polar: bool = False,
+ plot_corr: bool = False,
+ returnfig: bool = False,
+ figsize: Union[list, tuple, np.ndarray] = (12, 4),
+ verbose: bool = False,
+ # plot_corr_3D: bool = False,
+):
+ """
+ Solve for the best fit orientation of a single diffraction pattern.
+
+ Parameters
+ --------
+ bragg_peaks: PointList
+ numpy array containing the Bragg positions and intensities ('qx', 'qy', 'intensity')
+    num_matches_return: int
+        Return this many matches, stored along the 3rd dimension of the orientation matrices
+    min_angle_between_matches_deg: float, optional
+        Minimum angle between the zone axes of multiple matches, in degrees.
+        Note that only the zone axes are compared; in-plane rotations are not considered, since multiple in-plane matches are possible.
+ min_number_peaks: int
+ Minimum number of peaks required to perform ACOM matching
+    inversion_symmetry: bool
+        check for inversion symmetry in the matches
+    multiple_corr_reset: bool
+        keep original correlation score for multiple matches
+ plot_polar: bool
+ set to true to plot the polar transform of the diffraction pattern
+ plot_corr: bool
+ set to true to plot the resulting correlogram
+ returnfig: bool
+ return figure handles
+ figsize: list
+ size of figure
+ verbose: bool
+ Print the fitted zone axes, correlation scores
+
+ Returns
+ --------
+ orientation: Orientation
+ Orientation class containing all outputs
+ fig, ax: handles
+ Figure handles for the plotting output
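+
+    Examples
+    --------
+    A minimal sketch of a single-pattern fit (assumes ``crystal`` has a
+    precomputed orientation plan and ``peaks`` is a PointList of Bragg peaks;
+    names are illustrative):
+
+    >>> orientation = crystal.match_single_pattern(
+    ...     bragg_peaks=peaks,
+    ...     num_matches_return=1,
+    ...     verbose=True,
+    ... )
+    >>> zone_axis = orientation.matrix[0][:, 2]  # best-fit zone axis (cartesian)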
+ """
+
+    # check that the orientation plan, including its correlation array, has been computed
+ if not hasattr(self, "orientation_ref"):
+ raise ValueError(
+ "orientation_plan must be run with 'calculate_correlation_array=True'"
+ )
+
+ orientation = Orientation(num_matches=num_matches_return)
+ if bragg_peaks.data.shape[0] < min_number_peaks:
+ return orientation
+
+ # get bragg peak data
+ qx = bragg_peaks.data["qx"]
+ qy = bragg_peaks.data["qy"]
+ intensity = bragg_peaks.data["intensity"]
+
+ # other init
+ dphi = self.orientation_gamma[1] - self.orientation_gamma[0]
+ corr_value = np.zeros(self.orientation_num_zones)
+ corr_in_plane_angle = np.zeros(self.orientation_num_zones)
+ if inversion_symmetry:
+ corr_inv = np.zeros(self.orientation_num_zones, dtype="bool")
+
+ # loop over the number of matches to return
+ for match_ind in range(num_matches_return):
+ # Convert Bragg peaks to polar coordinates
+ qr = np.sqrt(qx**2 + qy**2)
+ qphi = np.arctan2(qy, qx)
+
+ # Calculate polar Bragg peak image
+ im_polar = np.zeros(
+ (
+ np.size(self.orientation_shell_radii),
+ self.orientation_in_plane_steps,
+ ),
+ dtype="float",
+ )
+
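+        # Accumulate each Bragg peak into polar (radius, gamma) bins using the
+        # same tent kernel as the orientation plan, weighted by powers of the
+        # shell radius and the peak intensity.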
+ for ind_radial, radius in enumerate(self.orientation_shell_radii):
+ dqr = np.abs(qr - radius)
+ sub = dqr < self.orientation_kernel_size
+
+ if np.any(sub):
+ im_polar[ind_radial, :] = np.sum(
+ np.power(radius, self.orientation_radial_power)
+ * np.power(
+ np.maximum(intensity[sub, None], 0.0),
+ self.orientation_intensity_power,
+ )
+ * np.maximum(
+ 1
+ - np.sqrt(
+ dqr[sub, None] ** 2
+ + (
+ (
+ np.mod(
+ self.orientation_gamma[None, :]
+ - qphi[sub, None]
+ + np.pi,
+ 2 * np.pi,
+ )
+ - np.pi
+ )
+ * radius
+ )
+ ** 2
+ )
+ / self.orientation_kernel_size,
+ 0,
+ ),
+ axis=0,
+ )
+
+        # Determine the RMS signal from im_polar for the first match.
+        # Note that we use a scaling power slightly below RMS (0.4 rather than 0.5)
+        # so that later matches don't have higher correlation scores than earlier ones.
+ if multiple_corr_reset is False and num_matches_return > 1:
+ if match_ind == 0:
+ im_polar_scale_0 = np.mean(im_polar**2) ** 0.4
+ else:
+ im_polar_scale = np.mean(im_polar**2) ** 0.4
+ if im_polar_scale > 0:
+ im_polar *= im_polar_scale_0 / im_polar_scale
+ # im_polar /= np.sqrt(np.mean(im_polar**2))
+ # im_polar *= im_polar_0_rms
+
+        # If later refinement is performed, keep a copy of the polar transform:
+        # only the first match's transform when correlation reset is enabled,
+        # otherwise the current transform on every iteration.
+        if self.orientation_refine:
+            if (not multiple_corr_reset) or match_ind == 0:
+                if self.CUDA:
+                    im_polar_refine = cp.asarray(im_polar.copy())
+                else:
+                    im_polar_refine = im_polar.copy()
+
+ # Plot polar space image if needed
+ if plot_polar is True: # and match_ind==0:
+ fig, ax = plt.subplots(1, 1, figsize=figsize)
+ ax.imshow(im_polar)
+ plt.show()
+
+ # FFT along theta
+ if self.CUDA:
+ im_polar_fft = cp.fft.fft(cp.asarray(im_polar))
+ else:
+ im_polar_fft = np.fft.fft(im_polar)
+ if self.orientation_refine:
+ if self.CUDA:
+ im_polar_refine_fft = cp.fft.fft(cp.asarray(im_polar_refine))
+ else:
+ im_polar_refine_fft = np.fft.fft(im_polar_refine)
+
+ # Calculate full orientation correlogram
+ if self.orientation_refine:
+ corr_full = np.zeros(
+ (
+ self.orientation_num_zones,
+ self.orientation_in_plane_steps,
+ )
+ )
+ if self.CUDA:
+ corr_full[self.orientation_sieve, :] = cp.maximum(
+ cp.sum(
+ cp.real(
+ cp.fft.ifft(
+ self.orientation_ref[self.orientation_sieve_CUDA, :, :]
+ * im_polar_fft[None, :, :]
+ )
+ ),
+ axis=1,
+ ),
+ 0,
+ ).get()
+ else:
+ corr_full[self.orientation_sieve, :] = np.maximum(
+ np.sum(
+ np.real(
+ np.fft.ifft(
+ self.orientation_ref[self.orientation_sieve, :, :]
+ * im_polar_fft[None, :, :]
+ )
+ ),
+ axis=1,
+ ),
+ 0,
+ )
+
+ else:
+ if self.CUDA:
+ corr_full = np.maximum(
+ np.sum(
+ np.real(
+ cp.fft.ifft(self.orientation_ref * im_polar_fft[None, :, :])
+ ),
+ axis=1,
+ ),
+ 0,
+ ).get()
+ else:
+ corr_full = np.maximum(
+ np.sum(
+ np.real(
+ np.fft.ifft(self.orientation_ref * im_polar_fft[None, :, :])
+ ),
+ axis=1,
+ ),
+ 0,
+ )
+
+ # If minimum angle is specified and we're on a match later than the first,
+ # we zero correlation values within the given range.
+ if min_angle_between_matches_deg is not None:
+ if match_ind > 0:
+ inds_previous = orientation.inds[:match_ind, 0]
+ for a0 in range(inds_previous.size):
+ mask_zero = np.arccos(
+ np.clip(
+ np.sum(
+ self.orientation_vecs
+ * self.orientation_vecs[inds_previous[a0], :],
+ axis=1,
+ ),
+ -1,
+ 1,
+ )
+ ) < np.deg2rad(min_angle_between_matches_deg)
+ corr_full[mask_zero, :] = 0.0
+
+ # Get maximum (non inverted) correlation value
+ ind_phi = np.argmax(corr_full, axis=1)
+
+ # Calculate orientation correlogram for inverse pattern (in-plane mirror)
+ if inversion_symmetry:
+ if self.orientation_refine:
+ corr_full_inv = np.zeros(
+ (
+ self.orientation_num_zones,
+ self.orientation_in_plane_steps,
+ )
+ )
+ if self.CUDA:
+ corr_full_inv[self.orientation_sieve, :] = cp.maximum(
+ cp.sum(
+ cp.real(
+ cp.fft.ifft(
+ self.orientation_ref[
+ self.orientation_sieve_CUDA, :, :
+ ]
+ * cp.conj(im_polar_fft)[None, :, :]
+ )
+ ),
+ axis=1,
+ ),
+ 0,
+ ).get()
+ else:
+ corr_full_inv[self.orientation_sieve, :] = np.maximum(
+ np.sum(
+ np.real(
+ np.fft.ifft(
+ self.orientation_ref[self.orientation_sieve, :, :]
+ * np.conj(im_polar_fft)[None, :, :]
+ )
+ ),
+ axis=1,
+ ),
+ 0,
+ )
+ else:
+ if self.CUDA:
+ corr_full_inv = np.maximum(
+ np.sum(
+ np.real(
+ cp.fft.ifft(
+ self.orientation_ref
+ * cp.conj(im_polar_fft)[None, :, :]
+ )
+ ),
+ axis=1,
+ ),
+ 0,
+ ).get()
+ else:
+ corr_full_inv = np.maximum(
+ np.sum(
+ np.real(
+ np.fft.ifft(
+ self.orientation_ref
+ * np.conj(im_polar_fft)[None, :, :]
+ )
+ ),
+ axis=1,
+ ),
+ 0,
+ )
+
+ # If minimum angle is specified and we're on a match later than the first,
+ # we zero correlation values within the given range.
+ if min_angle_between_matches_deg is not None:
+ if match_ind > 0:
+ inds_previous = orientation.inds[:match_ind, 0]
+ for a0 in range(inds_previous.size):
+ mask_zero = np.arccos(
+ np.clip(
+ np.sum(
+ self.orientation_vecs
+ * self.orientation_vecs[inds_previous[a0], :],
+ axis=1,
+ ),
+ -1,
+ 1,
+ )
+ ) < np.deg2rad(min_angle_between_matches_deg)
+ corr_full_inv[mask_zero, :] = 0.0
+
+ ind_phi_inv = np.argmax(corr_full_inv, axis=1)
+ corr_inv = np.zeros(self.orientation_num_zones, dtype="bool")
+
+ # Find best match for each zone axis
+ corr_value[:] = 0
+ for a0 in range(self.orientation_num_zones):
+ if (self.orientation_refine is False) or self.orientation_sieve[a0]:
+ # Correlation score
+ if inversion_symmetry:
+ if corr_full_inv[a0, ind_phi_inv[a0]] > corr_full[a0, ind_phi[a0]]:
+ corr_value[a0] = corr_full_inv[a0, ind_phi_inv[a0]]
+ corr_inv[a0] = True
+ else:
+ corr_value[a0] = corr_full[a0, ind_phi[a0]]
+ else:
+ corr_value[a0] = corr_full[a0, ind_phi[a0]]
+
+ # In-plane sub-pixel angular fit
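+                # The maximum is refined with a three-point parabolic fit: for
+                # correlation samples c[0], c[1], c[2] at angles -dphi, 0, +dphi
+                # around the peak, the vertex offset (in units of dphi) is
+                # dc = (c[2] - c[0]) / (4*c[1] - 2*c[0] - 2*c[2]).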
+ if inversion_symmetry and corr_inv[a0]:
+ inds = np.mod(
+ ind_phi_inv[a0] + np.arange(-1, 2), self.orientation_gamma.size
+ ).astype("int")
+ c = corr_full_inv[a0, inds]
+ if np.max(c) > 0:
+ dc = (c[2] - c[0]) / (4 * c[1] - 2 * c[0] - 2 * c[2])
+ corr_in_plane_angle[a0] = (
+ self.orientation_gamma[ind_phi_inv[a0]] + dc * dphi
+ ) + np.pi
+ else:
+ inds = np.mod(
+ ind_phi[a0] + np.arange(-1, 2), self.orientation_gamma.size
+ ).astype("int")
+ c = corr_full[a0, inds]
+ if np.max(c) > 0:
+ dc = (c[2] - c[0]) / (4 * c[1] - 2 * c[0] - 2 * c[2])
+ corr_in_plane_angle[a0] = (
+ self.orientation_gamma[ind_phi[a0]] + dc * dphi
+ )
+
+ # If needed, keep original polar image to recompute the correlations
+ if (
+ multiple_corr_reset
+ and num_matches_return > 1
+ and match_ind == 0
+ and not self.orientation_refine
+ ):
+ corr_value_keep = corr_value.copy()
+ corr_in_plane_angle_keep = corr_in_plane_angle.copy()
+
+ # Determine the best fit orientation
+ ind_best_fit = np.unravel_index(np.argmax(corr_value), corr_value.shape)[0]
+
+ ############################################################
+ # If needed, perform fine step refinement of the zone axis #
+ ############################################################
+ if self.orientation_refine:
+ mask_refine = np.arccos(
+ np.clip(
+ np.sum(
+ self.orientation_vecs * self.orientation_vecs[ind_best_fit, :],
+ axis=1,
+ ),
+ -1,
+ 1,
+ )
+ ) < np.deg2rad(self.orientation_refine_range)
+ if self.CUDA:
+ mask_refine_CUDA = cp.asarray(mask_refine)
+
+ if self.CUDA:
+ corr_full[mask_refine, :] = cp.maximum(
+ cp.sum(
+ cp.real(
+ cp.fft.ifft(
+ self.orientation_ref[mask_refine_CUDA, :, :]
+ * im_polar_refine_fft[None, :, :]
+ )
+ ),
+ axis=1,
+ ),
+ 0,
+ ).get()
+ else:
+ corr_full[mask_refine, :] = np.maximum(
+ np.sum(
+ np.real(
+ np.fft.ifft(
+ self.orientation_ref[mask_refine, :, :]
+ * im_polar_refine_fft[None, :, :]
+ )
+ ),
+ axis=1,
+ ),
+ 0,
+ )
+
+ # Get maximum (non inverted) correlation value
+ ind_phi = np.argmax(corr_full, axis=1)
+
+ # Inversion symmetry
+ if inversion_symmetry:
+ if self.CUDA:
+ corr_full_inv[mask_refine, :] = cp.maximum(
+ cp.sum(
+ cp.real(
+ cp.fft.ifft(
+ self.orientation_ref[mask_refine_CUDA, :, :]
+ * cp.conj(im_polar_refine_fft)[None, :, :]
+ )
+ ),
+ axis=1,
+ ),
+ 0,
+ ).get()
+ else:
+ corr_full_inv[mask_refine, :] = np.maximum(
+ np.sum(
+ np.real(
+ np.fft.ifft(
+ self.orientation_ref[mask_refine, :, :]
+ * np.conj(im_polar_refine_fft)[None, :, :]
+ )
+ ),
+ axis=1,
+ ),
+ 0,
+ )
+ ind_phi_inv = np.argmax(corr_full_inv, axis=1)
+
+ # Determine best in-plane correlation
+ for a0 in np.argwhere(mask_refine):
+ # Correlation score
+ if inversion_symmetry:
+ if corr_full_inv[a0, ind_phi_inv[a0]] > corr_full[a0, ind_phi[a0]]:
+ corr_value[a0] = corr_full_inv[a0, ind_phi_inv[a0]]
+ corr_inv[a0] = True
+ else:
+ corr_value[a0] = corr_full[a0, ind_phi[a0]]
+ else:
+ corr_value[a0] = corr_full[a0, ind_phi[a0]]
+
+ # Subpixel angular fit
+ if inversion_symmetry and corr_inv[a0]:
+ inds = np.mod(
+ ind_phi_inv[a0] + np.arange(-1, 2), self.orientation_gamma.size
+ ).astype("int")
+ c = corr_full_inv[a0, inds]
+ if np.max(c) > 0:
+ dc = (c[2] - c[0]) / (4 * c[1] - 2 * c[0] - 2 * c[2])
+ corr_in_plane_angle[a0] = (
+ self.orientation_gamma[ind_phi_inv[a0]] + dc * dphi
+ ) + np.pi
+ else:
+ inds = np.mod(
+ ind_phi[a0] + np.arange(-1, 2), self.orientation_gamma.size
+ ).astype("int")
+ c = corr_full[a0, inds]
+ if np.max(c) > 0:
+ dc = (c[2] - c[0]) / (4 * c[1] - 2 * c[0] - 2 * c[2])
+ corr_in_plane_angle[a0] = (
+ self.orientation_gamma[ind_phi[a0]] + dc * dphi
+ )
+
+ # Determine the new best fit orientation
+ ind_best_fit = np.unravel_index(
+ np.argmax(corr_value * mask_refine[None, :]), corr_value.shape
+ )[0]
+
+ # Verify current match has a correlation > 0
+ if corr_value[ind_best_fit] > 0:
+ # Get orientation matrix
+ orientation_matrix = np.squeeze(
+ self.orientation_rotation_matrices[ind_best_fit, :, :]
+ )
+
+ # apply in-plane rotation, and inversion if needed
+ if (
+ multiple_corr_reset
+ and match_ind > 0
+ and self.orientation_refine is False
+ ):
+ phi = corr_in_plane_angle_keep[ind_best_fit]
+ else:
+ phi = corr_in_plane_angle[ind_best_fit]
+ m3z = np.array(
+ [
+ [np.cos(phi), np.sin(phi), 0],
+ [-np.sin(phi), np.cos(phi), 0],
+ [0, 0, 1],
+ ]
+ )
+ orientation_matrix = orientation_matrix @ m3z
+ if inversion_symmetry and corr_inv[ind_best_fit]:
+ # Rotate 180 degrees around x axis for projected x-mirroring operation
+ orientation_matrix[:, 1:] = -orientation_matrix[:, 1:]
+
+ # Output best fit values into Orientation class
+ orientation.matrix[match_ind] = orientation_matrix
+
+ if self.orientation_refine:
+ orientation.corr[match_ind] = corr_value[ind_best_fit]
+ else:
+ if multiple_corr_reset and match_ind > 0:
+ orientation.corr[match_ind] = corr_value_keep[ind_best_fit]
+ else:
+ orientation.corr[match_ind] = corr_value[ind_best_fit]
+
+ if inversion_symmetry and corr_inv[ind_best_fit]:
+ ind_phi = ind_phi_inv[ind_best_fit]
+ else:
+ ind_phi = ind_phi[ind_best_fit]
+ orientation.inds[match_ind, 0] = ind_best_fit
+ orientation.inds[match_ind, 1] = ind_phi
+
+ if inversion_symmetry:
+ orientation.mirror[match_ind] = corr_inv[ind_best_fit]
+
+ orientation.angles[match_ind, :] = self.orientation_rotation_angles[
+ ind_best_fit, :
+ ]
+ orientation.angles[match_ind, 2] += phi
+
+            # If the point group is known, use pymatgen to calculate the symmetry-
+            # reduced orientation matrix, producing the crystal direction family.
+ if self.pymatgen_available:
+ orientation = self.symmetry_reduce_directions(
+ orientation,
+ match_ind=match_ind,
+ )
+
+ else:
+ # No more matches are detected, so output default orientation matrix and leave corr = 0
+ orientation.matrix[match_ind] = np.squeeze(
+ self.orientation_rotation_matrices[0, :, :]
+ )
+
+ if verbose:
+ if self.pymatgen_available:
+ if np.abs(self.cell[5] - 120.0) < 1e-6:
+ x_proj_lattice = self.lattice_to_hexagonal(
+ self.cartesian_to_lattice(orientation.family[match_ind][:, 0])
+ )
+ x_proj_lattice = np.round(x_proj_lattice, decimals=3)
+ zone_axis_lattice = self.lattice_to_hexagonal(
+ self.cartesian_to_lattice(orientation.family[match_ind][:, 2])
+ )
+ zone_axis_lattice = np.round(zone_axis_lattice, decimals=3)
+ else:
+ if np.max(np.abs(orientation.family)) > 0.1:
+ x_proj_lattice = self.cartesian_to_lattice(
+ orientation.family[match_ind][:, 0]
+ )
+ x_proj_lattice = np.round(x_proj_lattice, decimals=3)
+ zone_axis_lattice = self.cartesian_to_lattice(
+ orientation.family[match_ind][:, 2]
+ )
+ zone_axis_lattice = np.round(zone_axis_lattice, decimals=3)
+
+ if orientation.corr[match_ind] > 0:
+ print(
+ "Best fit lattice directions: z axis = ("
+ + str(zone_axis_lattice)
+ + "),"
+ " x axis = ("
+ + str(x_proj_lattice)
+ + "),"
+ + " with corr value = "
+ + str(np.round(orientation.corr[match_ind], decimals=3))
+ )
+ else:
+ print("No good match found for index " + str(match_ind))
+
+ else:
+ zone_axis_fit = orientation.matrix[match_ind][:, 2]
+ zone_axis_lattice = self.cartesian_to_lattice(zone_axis_fit)
+ zone_axis_lattice = np.round(zone_axis_lattice, decimals=3)
+ print(
+ "Best fit zone axis (lattice) = ("
+ + str(zone_axis_lattice)
+ + "),"
+ + " with corr value = "
+ + str(np.round(orientation.corr[match_ind], decimals=3))
+ )
+
+ # if needed, delete peaks for next iteration
+ if num_matches_return > 1 and corr_value[ind_best_fit] > 0:
+ bragg_peaks_fit = self.generate_diffraction_pattern(
+ orientation,
+ ind_orientation=match_ind,
+ sigma_excitation_error=self.orientation_kernel_size,
+ )
+
+ remove = np.zeros_like(qx, dtype="bool")
+ scale_int = np.ones_like(qx)
+ for a0 in np.arange(qx.size):
+ d_2 = (bragg_peaks_fit.data["qx"] - qx[a0]) ** 2 + (
+ bragg_peaks_fit.data["qy"] - qy[a0]
+ ) ** 2
+
+ dist_min = np.sqrt(np.min(d_2))
+
+ if dist_min < self.orientation_tol_peak_delete:
+ remove[a0] = True
+ elif dist_min < self.orientation_kernel_size:
+ scale_int[a0] = (dist_min - self.orientation_tol_peak_delete) / (
+ self.orientation_kernel_size - self.orientation_tol_peak_delete
+ )
+
+ intensity = intensity * scale_int
+ qx = qx[~remove]
+ qy = qy[~remove]
+ intensity = intensity[~remove]
+
+ # plotting correlation image
+ if plot_corr is True:
+ corr_plot = corr_value.copy()
+ sig_in_plane = np.squeeze(corr_full[ind_best_fit, :]).copy()
+
+ if self.orientation_full:
+ fig, ax = plt.subplots(1, 2, figsize=figsize * np.array([2, 2]))
+ cmin = np.min(corr_plot)
+ cmax = np.max(corr_plot)
+
+ im_corr_zone_axis = np.zeros(
+ (
+ 2 * self.orientation_zone_axis_steps + 1,
+ 2 * self.orientation_zone_axis_steps + 1,
+ )
+ )
+
+ sub = self.orientation_inds[:, 2] == 0
+ x_inds = (
+ self.orientation_inds[sub, 0] - self.orientation_inds[sub, 1]
+ ).astype("int") + self.orientation_zone_axis_steps
+ y_inds = (
+ self.orientation_inds[sub, 1].astype("int")
+ + self.orientation_zone_axis_steps
+ )
+ inds_1D = np.ravel_multi_index(
+ [x_inds, y_inds], im_corr_zone_axis.shape
+ )
+ im_corr_zone_axis.ravel()[inds_1D] = corr_plot[sub]
+
+ sub = self.orientation_inds[:, 2] == 1
+ x_inds = (
+ self.orientation_inds[sub, 0] - self.orientation_inds[sub, 1]
+ ).astype("int") + self.orientation_zone_axis_steps
+ y_inds = self.orientation_zone_axis_steps - self.orientation_inds[
+ sub, 1
+ ].astype("int")
+ inds_1D = np.ravel_multi_index(
+ [x_inds, y_inds], im_corr_zone_axis.shape
+ )
+ im_corr_zone_axis.ravel()[inds_1D] = corr_plot[sub]
+
+ sub = self.orientation_inds[:, 2] == 2
+ x_inds = (
+ self.orientation_inds[sub, 1] - self.orientation_inds[sub, 0]
+ ).astype("int") + self.orientation_zone_axis_steps
+ y_inds = (
+ self.orientation_inds[sub, 1].astype("int")
+ + self.orientation_zone_axis_steps
+ )
+ inds_1D = np.ravel_multi_index(
+ [x_inds, y_inds], im_corr_zone_axis.shape
+ )
+ im_corr_zone_axis.ravel()[inds_1D] = corr_plot[sub]
+
+ sub = self.orientation_inds[:, 2] == 3
+ x_inds = (
+ self.orientation_inds[sub, 1] - self.orientation_inds[sub, 0]
+ ).astype("int") + self.orientation_zone_axis_steps
+ y_inds = self.orientation_zone_axis_steps - self.orientation_inds[
+ sub, 1
+ ].astype("int")
+ inds_1D = np.ravel_multi_index(
+ [x_inds, y_inds], im_corr_zone_axis.shape
+ )
+ im_corr_zone_axis.ravel()[inds_1D] = corr_plot[sub]
+
+ im_plot = (im_corr_zone_axis - cmin) / (cmax - cmin)
+ ax[0].imshow(im_plot, cmap="viridis", vmin=0.0, vmax=1.0)
+
+ elif self.orientation_half:
+ fig, ax = plt.subplots(1, 2, figsize=figsize * np.array([2, 1]))
+ cmin = np.min(corr_plot)
+ cmax = np.max(corr_plot)
+
+ im_corr_zone_axis = np.zeros(
+ (
+ self.orientation_zone_axis_steps + 1,
+ self.orientation_zone_axis_steps * 2 + 1,
+ )
+ )
+
+ sub = self.orientation_inds[:, 2] == 0
+ x_inds = (
+ self.orientation_inds[sub, 0] - self.orientation_inds[sub, 1]
+ ).astype("int")
+ y_inds = (
+ self.orientation_inds[sub, 1].astype("int")
+ + self.orientation_zone_axis_steps
+ )
+ inds_1D = np.ravel_multi_index(
+ [x_inds, y_inds], im_corr_zone_axis.shape
+ )
+ im_corr_zone_axis.ravel()[inds_1D] = corr_plot[sub]
+
+ sub = self.orientation_inds[:, 2] == 1
+ x_inds = (
+ self.orientation_inds[sub, 0] - self.orientation_inds[sub, 1]
+ ).astype("int")
+ y_inds = self.orientation_zone_axis_steps - self.orientation_inds[
+ sub, 1
+ ].astype("int")
+ inds_1D = np.ravel_multi_index(
+ [x_inds, y_inds], im_corr_zone_axis.shape
+ )
+ im_corr_zone_axis.ravel()[inds_1D] = corr_plot[sub]
+
+ im_plot = (im_corr_zone_axis - cmin) / (cmax - cmin)
+ ax[0].imshow(im_plot, cmap="viridis", vmin=0.0, vmax=1.0)
+
+ else:
+ fig, ax = plt.subplots(1, 2, figsize=figsize)
+ cmin = np.min(corr_plot)
+ cmax = np.max(corr_plot)
+
+ im_corr_zone_axis = np.zeros(
+ (
+ self.orientation_zone_axis_steps + 1,
+ self.orientation_zone_axis_steps + 1,
+ )
+ )
+ im_mask = np.ones(
+ (
+ self.orientation_zone_axis_steps + 1,
+ self.orientation_zone_axis_steps + 1,
+ ),
+ dtype="bool",
+ )
+
+ # Image indices
+ x_inds = (
+ self.orientation_inds[:, 0] - self.orientation_inds[:, 1]
+ ).astype("int")
+ y_inds = self.orientation_inds[:, 1].astype("int")
+
+ # Check vertical range of the orientation triangle.
+ if (
+ self.orientation_fiber_angles is not None
+ and np.abs(self.orientation_fiber_angles[0] - 180.0) > 1e-3
+ ):
+ # Orientation covers only top of orientation sphere
+
+ inds_1D = np.ravel_multi_index(
+ [x_inds, y_inds], im_corr_zone_axis.shape
+ )
+ im_corr_zone_axis.ravel()[inds_1D] = corr_plot
+ im_mask.ravel()[inds_1D] = False
+
+ else:
+ # Orientation covers full vertical range of orientation sphere.
+ # top half
+ sub = self.orientation_inds[:, 2] == 0
+ inds_1D = np.ravel_multi_index(
+ [x_inds[sub], y_inds[sub]], im_corr_zone_axis.shape
+ )
+ im_corr_zone_axis.ravel()[inds_1D] = corr_plot[sub]
+ im_mask.ravel()[inds_1D] = False
+ # bottom half
+ sub = self.orientation_inds[:, 2] == 1
+ inds_1D = np.ravel_multi_index(
+ [
+ self.orientation_zone_axis_steps - y_inds[sub],
+ self.orientation_zone_axis_steps - x_inds[sub],
+ ],
+ im_corr_zone_axis.shape,
+ )
+ im_corr_zone_axis.ravel()[inds_1D] = corr_plot[sub]
+ im_mask.ravel()[inds_1D] = False
+
+ if cmax > cmin:
+ im_plot = np.ma.masked_array(
+ (im_corr_zone_axis - cmin) / (cmax - cmin), mask=im_mask
+ )
+ else:
+ im_plot = im_corr_zone_axis
+
+ ax[0].imshow(im_plot, cmap="viridis", vmin=0.0, vmax=1.0)
+ ax[0].spines["left"].set_color("none")
+ ax[0].spines["right"].set_color("none")
+ ax[0].spines["top"].set_color("none")
+ ax[0].spines["bottom"].set_color("none")
+
+ inds_plot = np.unravel_index(
+ np.argmax(im_plot, axis=None), im_plot.shape
+ )
+ ax[0].scatter(
+ inds_plot[1],
+ inds_plot[0],
+ s=120,
+ linewidth=2,
+ facecolors="none",
+ edgecolors="r",
+ )
+
+ if np.abs(self.cell[5] - 120.0) < 1e-6:
+ label_0 = self.rational_ind(
+ self.lattice_to_hexagonal(
+ self.cartesian_to_lattice(
+ self.orientation_zone_axis_range[0, :]
+ )
+ )
+ )
+ label_1 = self.rational_ind(
+ self.lattice_to_hexagonal(
+ self.cartesian_to_lattice(
+ self.orientation_zone_axis_range[1, :]
+ )
+ )
+ )
+ label_2 = self.rational_ind(
+ self.lattice_to_hexagonal(
+ self.cartesian_to_lattice(
+ self.orientation_zone_axis_range[2, :]
+ )
+ )
+ )
+ else:
+ label_0 = self.rational_ind(
+ self.cartesian_to_lattice(
+ self.orientation_zone_axis_range[0, :]
+ )
+ )
+ label_1 = self.rational_ind(
+ self.cartesian_to_lattice(
+ self.orientation_zone_axis_range[1, :]
+ )
+ )
+ label_2 = self.rational_ind(
+ self.cartesian_to_lattice(
+ self.orientation_zone_axis_range[2, :]
+ )
+ )
+
+ ax[0].set_xticks([0, self.orientation_zone_axis_steps])
+ ax[0].set_xticklabels([str(label_0), str(label_2)], size=14)
+ ax[0].xaxis.tick_top()
+
+ ax[0].set_yticks([self.orientation_zone_axis_steps])
+ ax[0].set_yticklabels([str(label_1)], size=14)
+
+ # In-plane rotation
+ # sig_in_plane = np.squeeze(corr_full[ind_best_fit, :])
+ sig_in_plane_max = np.max(sig_in_plane)
+ if sig_in_plane_max > 0:
+ sig_in_plane /= sig_in_plane_max
+ ax[1].plot(
+ self.orientation_gamma * 180 / np.pi,
+ sig_in_plane,
+ )
+
+ # Add markers for the best fit
+ tol = 0.01
+ sub = sig_in_plane > 1 - tol
+ ax[1].scatter(
+ self.orientation_gamma[sub] * 180 / np.pi,
+ sig_in_plane[sub],
+ s=120,
+ linewidth=2,
+ facecolors="none",
+ edgecolors="r",
+ )
+
+ ax[1].set_xlabel("In-plane rotation angle [deg]", size=16)
+ ax[1].set_ylabel("Corr. of Best Fit Zone Axis", size=16)
+ ax[1].set_ylim([0, 1.03])
+
+ plt.show()
+
+ if returnfig:
+ return orientation, fig, ax
+ else:
+ return orientation
+
+
+def cluster_grains(
+ self,
+ threshold_add=1.0,
+ threshold_grow=0.1,
+ angle_tolerance_deg=5.0,
+ progress_bar=True,
+):
+ """
+    Cluster grains using a rotation criterion and the correlation values.
+
+ Parameters
+ --------
+ threshold_add: float
+ Minimum signal required for a probe position to initialize a cluster.
+ threshold_grow: float
+ Minimum signal required for a probe position to be added to a cluster.
+ angle_tolerance_deg: float
+        Rotation tolerance for clustering grains, in degrees.
+ progress_bar: bool
+        Show the progress bar during clustering
+
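+    Examples
+    --------
+    A minimal sketch (assumes ``crystal.orientation_map`` has been populated
+    by ``match_orientations``; threshold values are illustrative):
+
+    >>> crystal.cluster_grains(
+    ...     threshold_add=0.5,
+    ...     threshold_grow=0.2,
+    ...     angle_tolerance_deg=4.0,
+    ... )
+    >>> crystal.cluster_sizes  # number of probe positions in each grain
+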
+ """
+
+ # symmetry operators
+ sym = self.symmetry_operators
+
+ # Get data
+ # Correlation data = signal to cluster with
+ sig = self.orientation_map.corr.copy()
+ sig_init = sig.copy()
+ mark = sig >= threshold_grow
+ sig[np.logical_not(mark)] = 0
+ # orientation matrix used for angle tolerance
+ matrix = self.orientation_map.matrix.copy()
+
+ # init
+ self.cluster_sizes = np.array((), dtype="int")
+ self.cluster_sig = np.array(())
+ self.cluster_inds = []
+ self.cluster_orientation = []
+ inds_all = np.zeros_like(sig, dtype="int")
+ inds_all.ravel()[:] = np.arange(inds_all.size)
+
+ # Tolerance
+ tol = np.deg2rad(angle_tolerance_deg)
+
+ # Main loop
+ search = True
+ comp = 0.0
+ mark_total = np.sum(np.max(mark, axis=2))
+ pbar = tqdm(total=mark_total, disable=not progress_bar)
+ while search is True:
+ inds_grain = np.argmax(sig)
+
+ val = sig.ravel()[inds_grain]
+
+ if val < threshold_add:
+ search = False
+
+ else:
+ # Start cluster
+ x, y, z = np.unravel_index(inds_grain, sig.shape)
+ mark[x, y, z] = False
+ sig[x, y, z] = 0
+ matrix_cluster = matrix[x, y, z]
+ orientation_cluster = self.orientation_map.get_orientation_single(x, y, z)
+
+ # Neighbors to search
+ xr = np.clip(x + np.arange(-1, 2, dtype="int"), 0, sig.shape[0] - 1)
+ yr = np.clip(y + np.arange(-1, 2, dtype="int"), 0, sig.shape[1] - 1)
+ inds_cand = inds_all[xr[:, None], yr[None], :].ravel()
+ inds_cand = np.delete(
+ inds_cand, mark.ravel()[inds_cand] == False # noqa: E712
+ )
+
+ if inds_cand.size == 0:
+ grow = False
+ else:
+ grow = True
+
+ # grow the cluster
+ while grow is True:
+ inds_new = np.array((), dtype="int")
+
+ keep = np.zeros(inds_cand.size, dtype="bool")
+ for a0 in range(inds_cand.size):
+ xc, yc, zc = np.unravel_index(inds_cand[a0], sig.shape)
+
+ # Angle test between orientation matrices
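+                    # The misorientation angle between two rotation matrices
+                    # R1 and R2 follows from the trace identity
+                    # trace(R1 @ R2.T) = 1 + 2*cos(theta); minimizing over all
+                    # symmetry operators gives the symmetry-reduced angle.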
+ dphi = np.min(
+ np.arccos(
+ np.clip(
+ (
+ np.trace(
+ self.symmetry_operators
+ @ matrix[xc, yc, zc]
+ @ np.transpose(matrix_cluster),
+ axis1=1,
+ axis2=2,
+ )
+ - 1
+ )
+ / 2,
+ -1,
+ 1,
+ )
+ )
+ )
+
+ if np.abs(dphi) < tol:
+ keep[a0] = True
+
+ sig[xc, yc, zc] = 0
+ mark[xc, yc, zc] = False
+
+ xr = np.clip(
+ xc + np.arange(-1, 2, dtype="int"), 0, sig.shape[0] - 1
+ )
+ yr = np.clip(
+ yc + np.arange(-1, 2, dtype="int"), 0, sig.shape[1] - 1
+ )
+ inds_add = inds_all[xr[:, None], yr[None], :].ravel()
+ inds_new = np.append(inds_new, inds_add)
+
+ inds_grain = np.append(inds_grain, inds_cand[keep])
+ inds_cand = np.unique(
+ np.delete(inds_new, mark.ravel()[inds_new] == False) # noqa: E712
+ )
+
+ if inds_cand.size == 0:
+ grow = False
+
+ # convert grain to x,y coordinates, add = list
+ xg, yg, zg = np.unravel_index(inds_grain, sig.shape)
+ xyg = np.unique(np.vstack((xg, yg)), axis=1)
+ sig_mean = np.mean(sig_init.ravel()[inds_grain])
+ self.cluster_sizes = np.append(self.cluster_sizes, xyg.shape[1])
+ self.cluster_sig = np.append(self.cluster_sig, sig_mean)
+ self.cluster_orientation.append(orientation_cluster)
+ self.cluster_inds.append(xyg)
+
+ # update progressbar
+ new_marks = mark_total - np.sum(np.max(mark, axis=2))
+ pbar.update(new_marks)
+ mark_total -= new_marks
+
+ pbar.close()
+
+
+def cluster_orientation_map(
+ self,
+ stripe_width=(2, 2),
+ area_min=2,
+):
+ """
+ Produce a new orientation map from the clustered grains.
+ Use a stripe pattern for the overlapping grains.
+
+ Parameters
+ --------
+ stripe_width: (int,int)
+ Width of stripes for plotting maps with overlapping grains
+ area_min: (int)
+ Minimum size of grains to include
+
+ Returns
+ --------
+
+ orientation_map
+ The clustered orientation map
+
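+    Examples
+    --------
+    A minimal sketch (assumes ``cluster_grains`` has already been run; values
+    are illustrative):
+
+    >>> clustered_map = crystal.cluster_orientation_map(
+    ...     stripe_width=(2, 2),
+    ...     area_min=4,
+    ... )
+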
+ """
+
+ # init
+ orientation_map = OrientationMap(
+ num_x=self.orientation_map.num_x,
+ num_y=self.orientation_map.num_y,
+ num_matches=1,
+ )
+ im_grain = np.zeros(
+ (self.orientation_map.num_x, self.orientation_map.num_y), dtype="bool"
+ )
+ im_count = np.zeros((self.orientation_map.num_x, self.orientation_map.num_y))
+ im_mark = np.zeros((self.orientation_map.num_x, self.orientation_map.num_y))
+
+ # Loop over grains to determine number in each pixel
+ for a0 in range(self.cluster_sizes.shape[0]):
+ if self.cluster_sizes[a0] >= area_min:
+ im_grain[:] = False
+ im_grain[
+ self.cluster_inds[a0][0, :],
+ self.cluster_inds[a0][1, :],
+ ] = True
+ im_count += im_grain
+ im_stripe = im_count >= 2
+ im_single = np.logical_not(im_stripe)
+
+ # prefactor for stripes
+ if stripe_width[0] == 0:
+ dx = 0
+ else:
+ dx = 1 / stripe_width[0]
+ if stripe_width[1] == 0:
+ dy = 0
+ else:
+ dy = 1 / stripe_width[1]
+
+ # loop over grains
+ for a0 in range(self.cluster_sizes.shape[0]):
+ if self.cluster_sizes[a0] >= area_min:
+ im_grain[:] = False
+ im_grain[
+ self.cluster_inds[a0][0, :],
+ self.cluster_inds[a0][1, :],
+ ] = True
+
+ # non-overlapping grains
+ sub = np.logical_and(im_grain, im_single)
+ x, y = np.unravel_index(np.where(sub.ravel()), im_grain.shape)
+ x = np.atleast_1d(np.squeeze(x))
+ y = np.atleast_1d(np.squeeze(y))
+ for a1 in range(x.size):
+ orientation_map.set_orientation(
+ self.cluster_orientation[a0], x[a1], y[a1]
+ )
+
+ # overlapping grains
+ sub = np.logical_and(im_grain, im_stripe)
+ x, y = np.unravel_index(np.where(sub.ravel()), im_grain.shape)
+ x = np.atleast_1d(np.squeeze(x))
+ y = np.atleast_1d(np.squeeze(y))
+ for a1 in range(x.size):
+ d = np.mod(
+                    x[a1] * dx + y[a1] * dy + im_mark[x[a1], y[a1]] + 0.5,
+ im_count[x[a1], y[a1]],
+ )
+
+ if d < 1.0:
+ orientation_map.set_orientation(
+ self.cluster_orientation[a0], x[a1], y[a1]
+ )
+ im_mark[x[a1], y[a1]] += 1
+
+ return orientation_map
+
+
+def calculate_strain(
+ self,
+ bragg_peaks_array: PointListArray,
+ orientation_map: OrientationMap,
+ corr_kernel_size=None,
+ sigma_excitation_error=0.02,
+ tol_excitation_error_mult: float = 3,
+ tol_intensity: float = 1e-4,
+ k_max: Optional[float] = None,
+ min_num_peaks=5,
+ rotation_range=None,
+ mask_from_corr=True,
+ corr_range=(0, 2),
+ corr_normalize=True,
+ progress_bar=True,
+):
+ """
+ This function takes in both a PointListArray containing Bragg peaks, and a
+ corresponding OrientationMap, and uses least squares to compute the
+ deformation tensor which transforms the simulated diffraction pattern
+    into the experimental pattern, for all probe positions.
+
+ TODO: add robust fitting?
+
+ Args:
+ bragg_peaks_array (PointListArray): All Bragg peaks
+ orientation_map (OrientationMap): Orientation map generated from ACOM
+ corr_kernel_size (float): Correlation kernel size - if user does
+ not specify, uses self.corr_kernel_size.
+ sigma_excitation_error (float): sigma value for envelope applied to s_g (excitation errors) in units of inverse Angstroms
+ tol_excitation_error_mult (float): tolerance in units of sigma for s_g inclusion
+        tol_intensity (float): tolerance in intensity units for inclusion of diffraction spots
+ k_max (float): Maximum scattering vector
+ min_num_peaks (int): Minimum number of peaks required.
+ rotation_range (float): Maximum rotation range in radians (for symmetry reduction).
+ progress_bar (bool): Show progress bar
+ mask_from_corr (bool): Use ACOM correlation signal for mask
+ corr_range (np.ndarray): Range of correlation signals for mask
+ corr_normalize (bool): Normalize correlation signal before masking
+
+ Returns:
+ strain_map (RealSlice): strain tensor
+
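+    Example:
+        A minimal sketch (assumes a calibrated ``bragg_peaks_array`` and an
+        ``orientation_map`` from match_orientations; values are illustrative):
+
+        >>> strain_map = crystal.calculate_strain(
+        ...     bragg_peaks_array,
+        ...     orientation_map,
+        ...     tol_intensity=1e-3,
+        ...     min_num_peaks=6,
+        ... )
+        >>> e_xx = strain_map.get_slice("e_xx").data
+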
+ """
+
+ # Initialize empty strain maps
+ strain_map = RealSlice(
+ data=np.zeros((5, bragg_peaks_array.shape[0], bragg_peaks_array.shape[1])),
+ slicelabels=("e_xx", "e_yy", "e_xy", "theta", "mask"),
+ name="strain_map",
+ )
+ if mask_from_corr:
+ corr_range = np.array(corr_range)
+ corr_mask = orientation_map.corr[:, :, 0]
+ if corr_normalize:
+ corr_mask /= np.mean(corr_mask)
+ corr_mask = np.clip(
+ (corr_mask - corr_range[0]) / (corr_range[1] - corr_range[0]), 0, 1
+ )
+ strain_map.get_slice("mask").data[:] = corr_mask
+
+ else:
+ strain_map.get_slice("mask").data[:] = 1.0
+
+ # init values
+ if corr_kernel_size is None:
+ corr_kernel_size = self.orientation_kernel_size
+ radius_max_2 = corr_kernel_size**2
+
+ # check cal state
+ if bragg_peaks_array.calstate["ellipse"] is False:
+ ellipse = False
+ warn("bragg peaks not elliptically calibrated")
+ else:
+ ellipse = True
+ if bragg_peaks_array.calstate["rotate"] is False:
+ rotate = False
+ warn("bragg peaks not rotationally calibrated")
+ else:
+ rotate = True
+
+ # Loop over all probe positions
+ for rx, ry in tqdmnd(
+ *bragg_peaks_array.shape,
+ desc="Calculating strains",
+ unit=" PointList",
+ disable=not progress_bar,
+ ):
+ # Get bragg peaks from experiment and reference
+ p = bragg_peaks_array.get_vectors(
+ scan_x=rx,
+ scan_y=ry,
+ center=True,
+ ellipse=ellipse,
+ pixel=True,
+ rotate=rotate,
+ )
+
+ if p.data.shape[0] >= min_num_peaks:
+ p_ref = self.generate_diffraction_pattern(
+ orientation_map.get_orientation(rx, ry),
+ sigma_excitation_error=sigma_excitation_error,
+ tol_excitation_error_mult=tol_excitation_error_mult,
+ tol_intensity=tol_intensity,
+ k_max=k_max,
+ )
+
+ # init
+ keep = np.zeros(p.data.shape[0], dtype="bool")
+ inds_match = np.zeros(p.data.shape[0], dtype="int")
+
+ # Pair off experimental Bragg peaks with reference peaks
+ for a0 in range(p.data.shape[0]):
+ dist_2 = (p.data["qx"][a0] - p_ref.data["qx"]) ** 2 + (
+ p.data["qy"][a0] - p_ref.data["qy"]
+ ) ** 2
+ ind_min = np.argmin(dist_2)
+
+ if dist_2[ind_min] <= radius_max_2:
+ inds_match[a0] = ind_min
+ keep[a0] = True
+
+ # Get all paired peaks
+ qxy = np.vstack((p.data["qx"][keep], p.data["qy"][keep])).T
+ qxy_ref = np.vstack(
+ (p_ref.data["qx"][inds_match[keep]], p_ref.data["qy"][inds_match[keep]])
+ ).T
+
+ # Apply intensity weighting from experimental measurements
+ qxy *= p.data["intensity"][keep, None]
+ qxy_ref *= p.data["intensity"][keep, None]
+
+ # Fit transformation matrix
+ # Note - not sure about transpose here
+ # (though it might not matter if rotation isn't included)
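+            # lstsq solves qxy_ref @ M = qxy in the least-squares sense for the
+            # 2x2 matrix M; taking m = M.T gives the map from reference to
+            # measured peaks used in the strain decomposition below.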
+ m = lstsq(qxy_ref, qxy, rcond=None)[0].T
+
+ # Get the infinitesimal strain matrix
+ strain_map.get_slice("e_xx").data[rx, ry] = 1 - m[0, 0]
+ strain_map.get_slice("e_yy").data[rx, ry] = 1 - m[1, 1]
+ strain_map.get_slice("e_xy").data[rx, ry] = -(m[0, 1] + m[1, 0]) / 2.0
+ strain_map.get_slice("theta").data[rx, ry] = (m[0, 1] - m[1, 0]) / 2.0
+
+ # Add finite rotation from ACOM orientation map.
+ # I am not sure about the relative signs here.
+ # Also, I need to add in the mirror operator.
+ if orientation_map.mirror[rx, ry, 0]:
+ strain_map.get_slice("theta").data[rx, ry] += (
+ orientation_map.angles[rx, ry, 0, 0]
+ + orientation_map.angles[rx, ry, 0, 2]
+ )
+ else:
+ strain_map.get_slice("theta").data[rx, ry] -= (
+ orientation_map.angles[rx, ry, 0, 0]
+ + orientation_map.angles[rx, ry, 0, 2]
+ )
+
+ else:
+ strain_map.get_slice("mask").data[rx, ry] = 0.0
+
+ if rotation_range is not None:
+ strain_map.get_slice("theta").data[:] = np.mod(
+ strain_map.get_slice("theta").data[:], rotation_range
+ )
+
+ return strain_map
+
+
+def save_ang_file(
+ self,
+ file_name,
+ orientation_map,
+ ind_orientation=0,
+ pixel_size=1.0,
+ pixel_units="px",
+ transpose_xy=True,
+ flip_x=False,
+):
+ """
+ This function outputs an ascii text file in the .ang format, containing
+ the Euler angles of an orientation map.
+
+ Args:
+ file_name (str): Path to save .ang file.
+ orientation_map (OrientationMap): Class containing orientation matrices,
+ correlation values, etc.
+ ind_orientation (int): Which orientation match to plot if num_matches > 1
+ pixel_size (float): Pixel size, if known.
+ pixel_units (str): Units of the pixel size
+ transpose_xy (bool): Transpose x and y pixel coordinates.
+ flip_x (bool): Swap x direction pixels (after transpose).
+
+ Returns:
+ nothing
+
+ """
+
+ from orix.io.plugins.ang import file_writer
+
+ xmap = self.orientation_map_to_orix_CrystalMap(
+ orientation_map,
+ ind_orientation=ind_orientation,
+ pixel_size=pixel_size,
+ pixel_units=pixel_units,
+ return_color_key=False,
+ transpose_xy=transpose_xy,
+ flip_x=flip_x,
+ )
+
+ file_writer(file_name, xmap)
+
+
+def orientation_map_to_orix_CrystalMap(
+ self,
+ orientation_map,
+ ind_orientation=0,
+ pixel_size=1.0,
+ pixel_units="px",
+ transpose_xy=True,
+ flip_x=False,
+ return_color_key=False,
+):
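+    """
+    Convert an OrientationMap into an orix CrystalMap, enabling use of the
+    orix plotting and analysis tools.
+
+    Args:
+        orientation_map (OrientationMap): Map containing the orientation matrices
+        ind_orientation (int): Which orientation match to use if num_matches > 1
+        pixel_size (float): Pixel size used for the coordinate arrays
+        pixel_units (str): Units of the pixel size
+        transpose_xy (bool): Transpose x and y pixel coordinates
+        flip_x (bool): Swap x direction pixels (after transpose)
+        return_color_key (bool): When True, also return an IPFColorKeyTSL color key
+
+    Returns:
+        xmap (CrystalMap), or the tuple (xmap, color_key) when return_color_key is True
+    """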
+ try:
+ from orix.quaternion import Rotation, Orientation
+ from orix.crystal_map import (
+ CrystalMap,
+ Phase,
+ PhaseList,
+ create_coordinate_arrays,
+ )
+ from orix.plot import IPFColorKeyTSL
+ except ImportError:
+ raise Exception("orix failed to import; try pip installing separately")
+
+ from diffpy.structure import Atom, Lattice, Structure
+
+ from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
+ from pymatgen.core.structure import Structure as pgStructure
+
+ from scipy.spatial.transform import Rotation as R
+
+ from py4DSTEM.process.diffraction.utils import element_symbols
+
+ import warnings
+
+ # Get orientation matrices
+ orientation_matrices = orientation_map.matrix[:, :, ind_orientation].copy()
+ if transpose_xy:
+ orientation_matrices = np.transpose(orientation_matrices, (1, 0, 2, 3))
+ if flip_x:
+ orientation_matrices = np.flip(orientation_matrices, axis=0)
+
+ # Convert the orientation matrices into Euler angles
+    # suppress gimbal lock warnings
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ angles = np.vstack(
+ [
+ R.from_matrix(matrix.T).as_euler("zxz")
+ for matrix in orientation_matrices.reshape(-1, 3, 3)
+ ]
+ )
+
+ # generate a list of Rotation objects from the Euler angles
+ rotations = Rotation.from_euler(angles, direction="crystal2lab")
+
+ # Generate x,y coordinates since orix uses flat data internally
+ # coords, _ = create_coordinate_arrays((orientation_map.num_x,orientation_map.num_y),(pixel_size,)*2)
+ coords, _ = create_coordinate_arrays(
+ (orientation_matrices.shape[0], orientation_matrices.shape[1]),
+ (pixel_size,) * 2,
+ )
+
+ # Generate an orix structure from the Crystal
+ atoms = [
+ Atom(element_symbols[Z - 1], pos)
+ for Z, pos in zip(self.numbers, self.positions)
+ ]
+
+ structure = Structure(
+ atoms=atoms,
+ lattice=Lattice(*self.cell),
+ )
+
+ # Use pymatgen to get the symmetry
+ pg_structure = pgStructure(
+ self.lat_real, self.numbers, self.positions, coords_are_cartesian=False
+ )
+ pointgroup = SpacegroupAnalyzer(pg_structure).get_point_group_symbol()
+
+ # If the structure has only one element, name the phase based on the element
+ if np.unique(self.numbers).size == 1:
+ name = element_symbols[self.numbers[0] - 1]
+ else:
+ name = pg_structure.formula
+
+ # Generate an orix Phase to store symmetry
+ phase = Phase(
+ name=name,
+ point_group=pointgroup,
+ structure=structure,
+ )
+
+ xmap = CrystalMap(
+ rotations=rotations,
+ x=coords["x"],
+ y=coords["y"],
+ phase_list=PhaseList(phase),
+ prop={
+ "iq": orientation_map.corr[:, :, ind_orientation].ravel(),
+ "ci": orientation_map.corr[:, :, ind_orientation].ravel(),
+ },
+ scan_unit=pixel_units,
+ )
+
+ ckey = IPFColorKeyTSL(phase.point_group)
+
+ return (xmap, ckey) if return_color_key else xmap
+
+
+def symmetry_reduce_directions(
+ self,
+ orientation,
+ match_ind=0,
+ plot_output=False,
+ figsize=(15, 6),
+ el_shift=0.0,
+ az_shift=-30.0,
+):
+ """
+ This function calculates the symmetry-reduced cartesian directions from
+    an orientation matrix stored in orientation.matrix, and outputs them
+ into orientation.family. It optionally plots the 3D output.
+
+ """
+
+ # optional plot
+ if plot_output:
+ bound = 1.05
+ cam_dir = np.mean(self.orientation_zone_axis_range, axis=0)
+ cam_dir = cam_dir / np.linalg.norm(cam_dir)
+ az = np.rad2deg(np.arctan2(cam_dir[0], cam_dir[1])) + az_shift
+ # if np.abs(self.orientation_fiber_angles[0] - 180.0) < 1e-3:
+ # el = 10
+ # else:
+ el = np.rad2deg(np.arcsin(cam_dir[2])) + el_shift
+ el = 0
+ fig = plt.figure(figsize=figsize)
+
+ num_points = 10
+ t = np.linspace(0, 1, num=num_points + 1, endpoint=True)
+ d = np.array([[0, 1], [0, 2], [1, 2]])
+ orientation_zone_axis_range_flip = self.orientation_zone_axis_range.copy()
+ orientation_zone_axis_range_flip[0, :] = (
+ -1 * orientation_zone_axis_range_flip[0, :]
+ )
+
+ # loop over orientation matrix directions
+ for a0 in range(3):
+ in_range = np.all(
+ np.sum(
+ self.symmetry_reduction
+ * orientation.matrix[match_ind, :, a0][None, :, None],
+ axis=1,
+ )
+ >= 0,
+ axis=1,
+ )
+
+ orientation.family[match_ind, :, a0] = (
+ self.symmetry_operators[np.argmax(in_range)]
+ @ orientation.matrix[match_ind, :, a0]
+ )
+
+ # in_range = np.all(np.sum(self.symmetry_reduction * \
+ # orientation.matrix[match_ind,:,a0][None,:,None],
+ # axis=1) >= 0,
+ # axis=1)
+ # if np.any(in_range):
+ # ind = np.argmax(in_range)
+ # orientation.family[match_ind,:,a0] = self.symmetry_operators[ind] \
+ # @ orientation.matrix[match_ind,:,a0]
+ # else:
+ # # Note this is a quick fix for fiber_angles[0] = 180 degrees
+ # in_range = np.all(np.sum(self.symmetry_reduction * \
+ # (np.array([1,1,-1])*orientation.matrix[match_ind,:,a0][None,:,None]),
+ # axis=1) >= 0,
+ # axis=1)
+ # ind = np.argmax(in_range)
+ # orientation.family[match_ind,:,a0] = self.symmetry_operators[ind] \
+ # @ (np.array([1,1,-1])*orientation.matrix[match_ind,:,a0])
+
+ if plot_output:
+ ax = fig.add_subplot(1, 3, a0 + 1, projection="3d", elev=el, azim=az)
+
+            # draw orientation triangle
+ for a1 in range(d.shape[0]):
+ v = self.orientation_zone_axis_range[d[a1, 0], :][None, :] * t[
+ :, None
+ ] + self.orientation_zone_axis_range[d[a1, 1], :][None, :] * (
+ 1 - t[:, None]
+ )
+ v = v / np.linalg.norm(v, axis=1)[:, None]
+ ax.plot(
+ v[:, 1],
+ v[:, 0],
+ v[:, 2],
+ c="k",
+ )
+ v = self.orientation_zone_axis_range[a1, :][None, :] * t[:, None]
+ ax.plot(
+ v[:, 1],
+ v[:, 0],
+ v[:, 2],
+ c="k",
+ )
+
+ # if needed, draw orientation diamond
+ if (
+ self.orientation_fiber_angles is not None
+ and np.abs(self.orientation_fiber_angles[0] - 180.0) < 1e-3
+ ):
+ for a1 in range(d.shape[0] - 1):
+ v = orientation_zone_axis_range_flip[d[a1, 0], :][None, :] * t[
+ :, None
+ ] + orientation_zone_axis_range_flip[d[a1, 1], :][None, :] * (
+ 1 - t[:, None]
+ )
+ v = v / np.linalg.norm(v, axis=1)[:, None]
+ ax.plot(
+ v[:, 1],
+ v[:, 0],
+ v[:, 2],
+ c="k",
+ )
+ v = orientation_zone_axis_range_flip[0, :][None, :] * t[:, None]
+ ax.plot(
+ v[:, 1],
+ v[:, 0],
+ v[:, 2],
+ c="k",
+ )
+
+ # add points
+ p = self.symmetry_operators @ orientation.matrix[match_ind, :, a0]
+ ax.scatter(
+ xs=p[:, 1],
+ ys=p[:, 0],
+ zs=p[:, 2],
+ s=10,
+ marker="o",
+ # c='k',
+ )
+ v = orientation.family[match_ind, :, a0][None, :] * t[:, None]
+ ax.plot(
+ v[:, 1],
+ v[:, 0],
+ v[:, 2],
+ c="k",
+ )
+ ax.scatter(
+ xs=orientation.family[match_ind, 1, a0],
+ ys=orientation.family[match_ind, 0, a0],
+ zs=orientation.family[match_ind, 2, a0],
+ s=160,
+ marker="o",
+ facecolors="None",
+ edgecolors="r",
+ )
+ ax.scatter(
+ xs=orientation.matrix[match_ind, 1, a0],
+ ys=orientation.matrix[match_ind, 0, a0],
+ zs=orientation.matrix[match_ind, 2, a0],
+ s=80,
+ marker="o",
+ facecolors="None",
+ edgecolors="c",
+ )
+
+ ax.invert_yaxis()
+ ax.axes.set_xlim3d(left=-bound, right=bound)
+ ax.axes.set_ylim3d(bottom=-bound, top=bound)
+ ax.axes.set_zlim3d(bottom=-bound, top=bound)
+ axisEqual3D(ax)
+
+ if plot_output:
+ plt.show()
+
+ return orientation
+
+
+# zone axis range arguments for orientation_plan corresponding
+# to the symmetric wedge for each pointgroup, in the order:
+# [zone_axis_range, fiber_axis, fiber_angles]
+orientation_ranges = {
+ "1": ["fiber", [0, 0, 1], [180.0, 360.0]],
+ "-1": ["full", None, None],
+ "2": ["fiber", [0, 0, 1], [180.0, 180.0]],
+ "m": ["full", None, None],
+ "2/m": ["half", None, None],
+ "222": ["fiber", [0, 0, 1], [90.0, 180.0]],
+ "mm2": ["fiber", [0, 0, 1], [180.0, 90.0]],
+ "mmm": [[[1, 0, 0], [0, 1, 0]], None, None],
+ "4": ["fiber", [0, 0, 1], [90.0, 180.0]],
+ "-4": ["half", None, None],
+ "4/m": [[[1, 0, 0], [0, 1, 0]], None, None],
+ "422": ["fiber", [0, 0, 1], [180.0, 45.0]],
+ "4mm": ["fiber", [0, 0, 1], [180.0, 45.0]],
+ "-42m": ["fiber", [0, 0, 1], [180.0, 45.0]],
+ "4/mmm": [[[1, 0, 0], [1, 1, 0]], None, None],
+ "3": ["fiber", [0, 0, 1], [180.0, 120.0]],
+ "-3": ["fiber", [0, 0, 1], [180.0, 60.0]],
+ "32": ["fiber", [0, 0, 1], [90.0, 60.0]],
+ "3m": ["fiber", [0, 0, 1], [180.0, 60.0]],
+ "-3m": ["fiber", [0, 0, 1], [90.0, 60.0]],
+ "6": ["fiber", [0, 0, 1], [180.0, 60.0]],
+ "-6": ["fiber", [0, 0, 1], [180.0, 60.0]],
+ "6/m": [[[1, 0, 0], [0.5, 0.5 * np.sqrt(3), 0]], None, None],
+ "622": ["fiber", [0, 0, 1], [180.0, 30.0]],
+ "6mm": ["fiber", [0, 0, 1], [180.0, 30.0]],
+ "-6m2": ["fiber", [0, 0, 1], [90.0, 60.0]],
+ "6/mmm": [[[0.5 * np.sqrt(3), 0.5, 0.0], [1, 0, 0]], None, None],
+ "23": [
+ [[1, 0, 0], [1, 1, 1]],
+ None,
+ None,
+ ], # this is probably wrong, it is half the needed range
+ "m-3": [[[1, 0, 0], [1, 1, 1]], None, None],
+ "432": [[[1, 0, 0], [1, 1, 1]], None, None],
+ "-43m": [[[1, -1, 1], [1, 1, 1]], None, None],
+ "m-3m": [[[0, 1, 1], [1, 1, 1]], None, None],
+}
+
+# "-3m": ["fiber", [0, 0, 1], [90.0, 60.0]],
+# "-3m": ["fiber", [0, 0, 1], [180.0, 30.0]],
diff --git a/py4DSTEM/process/diffraction/crystal_bloch.py b/py4DSTEM/process/diffraction/crystal_bloch.py
new file mode 100644
index 000000000..6a3c9b1ac
--- /dev/null
+++ b/py4DSTEM/process/diffraction/crystal_bloch.py
@@ -0,0 +1,696 @@
+import warnings
+import numpy as np
+import numpy.lib.recfunctions as rfn
+from scipy import linalg
+from typing import Union, Optional, Dict, Tuple, List
+from time import time
+from tqdm import tqdm
+from dataclasses import dataclass
+
+from emdfile import PointList
+from py4DSTEM.process.utils import electron_wavelength_angstrom, single_atom_scatter
+from py4DSTEM.process.diffraction.WK_scattering_factors import compute_WK_factor
+
+
+@dataclass
+class DynamicalMatrixCache:
+ has_valid_cache: bool = False
+    cached_U_gmh: Optional[np.ndarray] = None
+
+
+def calculate_dynamical_structure_factors(
+ self,
+ accelerating_voltage: float,
+ method: str = "WK-CP",
+ k_max: float = 2.0,
+ thermal_sigma: Optional[Union[float, dict]] = None,
+ tol_structure_factor: float = 0.0,
+ recompute_kinematic_structure_factors=True,
+ g_vec_precision=None,
+ verbose=True,
+):
+ """
+ Calculate and store the relativistic corrected structure factors used for Bloch computations
+ in a dictionary for faster lookup.
+
+ Args:
+ accelerating_voltage (float): accelerating voltage in eV
+ method (str): Choose which parameterization of the structure factors to use:
+ "Lobato": Uses the kinematic structure factors from crystal.py, using the parameterization from
+ Lobato & Van Dyck, Acta Cryst A 70:6 (2014)
+ "Lobato-absorptive": Lobato factors plus an imaginary part
+ equal to 0.1•f, as a simple but inaccurate way to include absorption, per
+ Hashimoto, Howie, & Whelan, Proc R Soc Lond A 269:80-103 (1962)
+ "WK": Uses the Weickenmeier-Kohl parameterization for
+ the elastic form factors, including Debye-Waller factor,
+ with no absorption, as described in
+ Weickenmeier & Kohl, Acta Cryst A 47:5 (1991)
+ "WK-C": WK form factors plus the "core" contribution to absorption
+ following H. Rose, Optik 45:2 (1976)
+ "WK-P": WK form factors plus the phonon/TDS absorptive contribution
+ "WK-CP": WK form factors plus core and phonon absorption (default)
+
+ k_max (float): max scattering length to compute structure factors to.
+                    Setting this to 2x the k_max used in generating the beams
+ included in a simulation will retain all possible couplings
+        thermal_sigma (float or dict{int->float}): RMS atomic displacement for attenuating form factors to account for thermal
+                    broadening of the potential, only used when a "WK" method is
+                    selected. Required when WK-P or WK-CP are selected.
+                    Units are Å. (This is often written as 〈u〉 in papers.)
+ To specify different 〈u〉 for each element, pass a dictionary
+ with Z as the key, mapping to the appropriate float value
+ tol_structure_factor (float): tolerance for removing low-valued structure factors. Reflections
+ with structure factor below the tolerance will have zero coupling
+ in the dynamical calculations (i.e. they are the ignored weak beams)
+ recompute_kinematic_structure_factors (bool): When True, recomputes the kinematic structure
+ factors using the same tol_structure_factor, and with k_max
+ set to *half* the k_max for the dynamical factors. The factor
+ of half ensures that every beam in a simulation can couple to
+ every other beam (no high-angle couplings in the Bloch matrix
+ are set to zero.)
+ g_vec_precision (optional int): If specified, rounds |g| to this many decimal places so that
+ automatic caching of the atomic form factors is not slowed
+ down due to floating point errors. Setting this to 3 can give
+ substantial speedup at the cost of some reduced accuracy
+
+ See WK_scattering_factors.py for details on the Weickenmeier-Kohl form factors.
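+
+    Example:
+        A minimal sketch (assumes a Crystal instance ``crystal``; values are
+        illustrative):
+
+        >>> crystal.calculate_dynamical_structure_factors(
+        ...     accelerating_voltage=300e3,
+        ...     method="WK-CP",
+        ...     k_max=2.0,
+        ...     thermal_sigma=0.08,
+        ... )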
+ """
+
+ assert method in (
+ "Lobato",
+ "Lobato-absorptive",
+ "WK",
+ "WK-C",
+ "WK-P",
+ "WK-CP",
+ ), "Invalid method specified."
+
+ if "WK" in method:
+ assert (
+ thermal_sigma is not None
+ ), "thermal_sigma must be specifed when using W-K potentials"
+
+ # Calculate the reciprocal lattice points to include based on k_max
+
+ k_max = np.asarray(k_max)
+
+ if recompute_kinematic_structure_factors:
+ if hasattr(self, "struct_factors"):
+ print("Warning: overriding existing structure factors...")
+ self.calculate_structure_factors(
+ k_max=k_max / 2.0,
+ tol_structure_factor=tol_structure_factor,
+ return_intensities=False,
+ )
+
+ # Inverse lattice vectors
+ lat_inv = np.linalg.inv(self.lat_real)
+
+ # Find shortest lattice vector direction
+ k_test = np.vstack(
+ [
+ lat_inv[0, :],
+ lat_inv[1, :],
+ lat_inv[2, :],
+ lat_inv[0, :] + lat_inv[1, :],
+ lat_inv[0, :] + lat_inv[2, :],
+ lat_inv[1, :] + lat_inv[2, :],
+ lat_inv[0, :] + lat_inv[1, :] + lat_inv[2, :],
+ lat_inv[0, :] - lat_inv[1, :] + lat_inv[2, :],
+ lat_inv[0, :] + lat_inv[1, :] - lat_inv[2, :],
+ lat_inv[0, :] - lat_inv[1, :] - lat_inv[2, :],
+ ]
+ )
+ k_leng_min = np.min(np.linalg.norm(k_test, axis=1))
+
+ # Tile lattice vectors
+ num_tile = np.ceil(k_max / k_leng_min)
+ ya, xa, za = np.meshgrid(
+ np.arange(-num_tile, num_tile + 1),
+ np.arange(-num_tile, num_tile + 1),
+ np.arange(-num_tile, num_tile + 1),
+ )
+ hkl = np.vstack([xa.ravel(), ya.ravel(), za.ravel()])
+ g_vec_all = lat_inv @ hkl
+
+ # Delete lattice vectors outside of k_max
+ keep = np.linalg.norm(g_vec_all, axis=0) <= k_max
+ hkl = hkl[:, keep]
+ g_vec_all = g_vec_all[:, keep]
+ g_vec_leng = np.linalg.norm(g_vec_all, axis=0)
+
+ lobato_lookup = single_atom_scatter()
+
+ m0c2 = 5.109989461e5 # electron rest mass, in eV
+ relativistic_factor = (m0c2 + accelerating_voltage) / m0c2
+
+ def get_f_e(q, Z, thermal_sigma, method):
+ if method == "Lobato":
+ # Real lobato factors
+ lobato_lookup.get_scattering_factor([Z], [1.0], q, units="A")
+ return np.complex128(relativistic_factor / np.pi * lobato_lookup.fe)
+ elif method == "Lobato-absorptive":
+ # Fake absorptive Lobato factors
+ lobato_lookup.get_scattering_factor([Z], [1.0], q, units="A")
+ return np.complex128(
+ relativistic_factor / np.pi * lobato_lookup.fe * (1.0 + 0.1j)
+ )
+ elif method == "WK":
+ # Real WK factor
+ return compute_WK_factor(
+ q,
+ Z,
+ accelerating_voltage,
+ thermal_sigma,
+ include_core=False,
+ include_phonon=False,
+ )
+ elif method == "WK-C":
+ # WK, core only
+ return compute_WK_factor(
+ q,
+ Z,
+ accelerating_voltage,
+ thermal_sigma,
+ include_core=True,
+ include_phonon=False,
+ )
+ elif method == "WK-P":
+ # WK, phonon only
+ return compute_WK_factor(
+ q,
+ Z,
+ accelerating_voltage,
+ thermal_sigma,
+ include_core=False,
+ include_phonon=True,
+ )
+ elif method == "WK-CP":
+ # WK, core + phonon
+ return compute_WK_factor(
+ q,
+ Z,
+ accelerating_voltage,
+ thermal_sigma,
+ include_core=True,
+ include_phonon=True,
+ )
+
+ # find unique values of Z and |g| for computing atomic form factors
+ Z_unique, Z_inverse = np.unique(self.numbers, return_inverse=True)
+ g_unique, g_inverse = np.unique(
+ np.round(g_vec_leng, g_vec_precision) if g_vec_precision else g_vec_leng,
+ return_inverse=True,
+ )
+
+ f_e_uniq = np.zeros((Z_unique.size, g_unique.size), dtype=np.complex128)
+
+ for idx, Z in enumerate(Z_unique):
+ # get element-specific thermal displacements, if given
+ sigma = thermal_sigma[Z] if isinstance(thermal_sigma, dict) else thermal_sigma
+ f_e_uniq[idx, :] = get_f_e(g_unique, Z, sigma, method)
+
+ # flesh out the dense array of atomic scattering factors
+ f_e = f_e_uniq[np.ix_(Z_inverse, g_inverse)]
+
+ # Calculate structure factors
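+    # F_hkl = sum_j f_j(|g|) * exp(2*pi*i * (h,k,l) . r_j), with r_j the
+    # fractional atomic positions; the unit cell volume is divided out below.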
+ struct_factors = np.sum(
+ f_e * np.exp(2.0j * np.pi * np.squeeze(self.positions[:, None, :] @ hkl)),
+ axis=0,
+ )
+
+ # Divide by unit cell volume
+ unit_cell_volume = np.abs(np.linalg.det(self.lat_real))
+ struct_factors /= unit_cell_volume
+
+ # Remove structure factors below tolerance level
+ keep = np.abs(struct_factors) >= tol_structure_factor
+ hkl = hkl[:, keep]
+
+ g_vec_all = g_vec_all[:, keep]
+ g_vec_leng = g_vec_leng[keep]
+ struct_factors = struct_factors[keep]
+
+ # Store relativistic corrected structure factors in a dictionary for faster lookup in the Bloch code
+
+ self.accel_voltage = accelerating_voltage
+ self.wavelength = electron_wavelength_angstrom(self.accel_voltage)
+
+ self.Ug_dict = {
+ (hkl[0, i], hkl[1, i], hkl[2, i]): struct_factors[i]
+ for i in range(hkl.shape[1])
+ }
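+# Illustrative note: Ug_dict maps integer (h, k, l) tuples to the corrected
+# structure factors, e.g. U_200 = self.Ug_dict[(2, 0, 0)]. Reflections removed
+# by the tolerance/k_max cuts above are simply absent, which is why the Bloch
+# code below reads entries with Ug_dict.get(key, 0.0 + 0.0j).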
+
+
+def generate_dynamical_diffraction_pattern(
+ self,
+ beams: PointList,
+ thickness: Union[float, list, tuple, np.ndarray],
+ zone_axis_lattice: np.ndarray = None,
+ zone_axis_cartesian: np.ndarray = None,
+ foil_normal_lattice: np.ndarray = None,
+ foil_normal_cartesian: np.ndarray = None,
+ verbose: bool = False,
+ always_return_list: bool = False,
+ dynamical_matrix_cache: Optional[DynamicalMatrixCache] = None,
+ return_complex: bool = False,
+ return_eigenvectors: bool = False,
+ return_Smatrix: bool = False,
+) -> Union[PointList, List[PointList]]:
+ """
+ Generate a dynamical diffraction pattern (or thickness series of patterns)
+ using the Bloch wave method.
+
+ The beams to be included in the Bloch calculation must be pre-calculated
+ and passed as a PointList containing at least (qx, qy, h, k, l) fields.
+
+ If ``thickness`` is a single value, one new PointList will be returned.
+ If ``thickness`` is a sequence of values, a list of PointLists will be returned,
+ corresponding to each thickness value in the input.
+
+ Frequent reference will be made to "Introduction to conventional transmission electron microscopy"
+ by DeGraef, whose overall approach we follow here.
+
+ Args:
+ beams (PointList): PointList from the kinematical diffraction generator
+ which will define the beams included in the Bloch calculation
+        thickness (float or list/array): thickness in Ångström at which to evaluate
+                            the diffraction patterns. The main Bloch calculation can be
+                            reused for multiple thicknesses without much overhead.
+        zone_axis & foil_normal: Incident beam orientation and foil normal direction.
+            Each can be specified in the Cartesian or crystallographic basis,
+            using e.g. zone_axis_lattice or zone_axis_cartesian. These are
+            internally parsed by Crystal.parse_orientation
+
+    Less commonly used args:
+        always_return_list (bool): When True, the return is always a list of PointLists,
+            even for a single thickness
+        dynamical_matrix_cache (DynamicalMatrixCache): Dataclass used for caching of the
+                            dynamical matrix. If the cached matrix does not exist, it is
+                            computed and stored. Subsequent calls will use the cached matrix
+                            for the off-diagonal components of the A matrix and overwrite
+                            the diagonal elements. This is used for CBED calculations.
+        return_complex (bool): When True, returns both the complex amplitude and
+                            intensity. Defaults to False.
+ Returns:
+ bragg_peaks (PointList): Bragg peaks with fields [qx, qy, intensity, h, k, l]
+ or
+ [bragg_peaks,...] (PointList): If thickness is a list/array, or always_return_list is True,
+ a list of PointLists is returned.
+ if return_complex = True:
+ bragg_peaks (PointList): Bragg peaks with fields [qx, qy, intensity, amplitude, h, k, l]
+ or
+ [bragg_peaks,...] (PointList): If thickness is a list/array, or always_return_list is True,
+ a list of PointLists is returned.
+ if return_Smatrix = True:
+ [S_matrix, ...], psi_0: Returns a list of S-matrices for each thickness (this is always a list),
+ and the vector representing the incident plane wave. The beams of the
+ S-matrix have the same order as in the input `beams`.
+
+ """
+ t0 = time() # start timer for matrix setup
+
+ n_beams = beams.data.shape[0]
+
+ beam_g, beam_h = np.meshgrid(np.arange(n_beams), np.arange(n_beams))
+
+ # Parse input orientations:
+ zone_axis_rotation_matrix = self.parse_orientation(
+ zone_axis_lattice=zone_axis_lattice, zone_axis_cartesian=zone_axis_cartesian
+ )
+ if foil_normal_lattice is not None or foil_normal_cartesian is not None:
+ foil_normal = self.parse_orientation(
+ zone_axis_lattice=foil_normal_lattice,
+ zone_axis_cartesian=foil_normal_cartesian,
+ )
+ else:
+ foil_normal = zone_axis_rotation_matrix
+
+ foil_normal = foil_normal[:, 2]
+
+ # Note the difference in notation versus kinematic function:
+ # k0 is the scalar magnitude of the wavevector, rather than
+ # a vector along the zone axis.
+ k0 = 1.0 / electron_wavelength_angstrom(self.accel_voltage)
+
+ ################################################################
+ # Compute the reduced structure matrix \bar{A} in DeGraef 5.52 #
+ ################################################################
+
+ hkl = np.vstack((beams.data["h"], beams.data["k"], beams.data["l"])).T
+
+ # Check if we have a cached dynamical matrix, which saves us from calculating the
+ # off-diagonal elements when running this in a loop with the same zone axis
+ if dynamical_matrix_cache is not None and dynamical_matrix_cache.has_valid_cache:
+ U_gmh = dynamical_matrix_cache.cached_U_gmh
+ else:
+ # No cached matrix is available/desired, so calculate it:
+
+ # get hkl indices of \vec{g} - \vec{h}
+ g_minus_h = np.vstack(
+ (
+ beams.data["h"][beam_g.ravel()] - beams.data["h"][beam_h.ravel()],
+ beams.data["k"][beam_g.ravel()] - beams.data["k"][beam_h.ravel()],
+ beams.data["l"][beam_g.ravel()] - beams.data["l"][beam_h.ravel()],
+ )
+ ).T
+
+ # Get the structure factors for each nonzero element, and zero otherwise
+ U_gmh = np.array(
+ [
+ self.Ug_dict.get((gmh[0], gmh[1], gmh[2]), 0.0 + 0.0j)
+ for gmh in g_minus_h
+ ],
+ dtype=np.complex128,
+ ).reshape(beam_g.shape)
+
+ # If we are supposed to cache, but don't have one saved, save this one:
+ if (
+ dynamical_matrix_cache is not None
+ and not dynamical_matrix_cache.has_valid_cache
+ ):
+ dynamical_matrix_cache.cached_U_gmh = U_gmh
+ dynamical_matrix_cache.has_valid_cache = True
+
+ if verbose:
+ print(f"Bloch matrix has size {U_gmh.shape}")
+
+ # Compute the diagonal entries of \hat{A}: 2 k_0 s_g [5.51]
+ g = (hkl @ self.lat_inv) @ zone_axis_rotation_matrix
+ sg = self.excitation_errors(
+ g.T, foil_normal=-foil_normal @ zone_axis_rotation_matrix
+ )
+
+    # Fill in the diagonal, completing the structure matrix
+ np.fill_diagonal(U_gmh, 2 * k0 * sg + 1.0j * np.imag(self.Ug_dict[(0, 0, 0)]))
+
+ if verbose:
+ print(f"Constructing the A matrix took {(time()-t0)*1000.:.3f} ms.")
+
+ #############################################################################################
+ # Compute eigen-decomposition of \hat{A} to yield C (the matrix containing the eigenvectors #
+ # as its columns) and gamma (the reduced eigenvalues), as in DeGraef 5.52 #
+ #############################################################################################
+
+ t0 = time() # start timer for eigendecomposition
+
+ v, C = linalg.eig(U_gmh) # decompose!
+ gamma_fac = 2.0 * k0 * zone_axis_rotation_matrix[:, 2] @ foil_normal
+ gamma = v / gamma_fac # divide by 2 k_n
+
+ # precompute the inverse of C
+ C_inv = np.linalg.inv(C)
+
+ if verbose:
+ print(f"Decomposing the A matrix took {(time()-t0)*1000.:.3f} ms.")
+
+ ##############################################################################################
+ # Compute diffraction intensities by calculating exit wave \Psi in DeGraef 5.60, and collect #
+ # values into PointLists #
+ ##############################################################################################
+
+ t0 = time()
+
+ psi_0 = np.zeros((n_beams,))
+ psi_0[int(np.where((hkl == [0, 0, 0]).all(axis=1))[0])] = 1.0
+
+    # calculate the diffraction intensities (and amplitudes) for each thickness matrix
+ # I = |psi|^2 ; psi = C @ E(z) @ C^-1 @ psi_0, where E(z) is the thickness matrix
+ if return_Smatrix:
+ Smatrices = [
+ C @ np.diag(np.exp(2.0j * np.pi * z * gamma)) @ C_inv
+ for z in np.atleast_1d(thickness)
+ ]
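+        # Each S-matrix is the full propagator through thickness z,
+        #   S(z) = C exp(2*pi*i*z*gamma) C^-1,
+        # so the exit wave for any incident vector is psi(z) = S(z) @ psi_0
+        # (DeGraef 5.60 written as a single matrix product).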
+ return (
+ (Smatrices, psi_0, (C, C_inv, gamma, gamma_fac))
+ if return_eigenvectors
+ else (Smatrices, psi_0)
+ )
+ elif return_complex:
+ # calculate the amplitudes
+ amplitudes = [
+ C @ (np.exp(2.0j * np.pi * z * gamma) * (C_inv @ psi_0))
+ for z in np.atleast_1d(thickness)
+ ]
+
+ # Do this first to avoid handling structured array
+ intensities = np.abs(amplitudes) ** 2
+
+ # convert amplitudes as a structured array
+    # do we want complex64 or complex128?
+    amplitudes = np.array(amplitudes, dtype=([("amplitude", "<c16")]))
+    # (The original packing of these amplitudes/intensities into the returned
+    #  PointLists could not be recovered; the docstring above records the
+    #  intended return structure.)
+
+
+# (Signature reconstructed from the docstring and body below; the default
+#  values are assumptions.)
+def generate_CBED(
+    self,
+    beams: PointList,
+    thickness: Union[float, list, tuple, np.ndarray],
+    alpha_mrad: float,
+    pixel_size_inv_A: float,
+    DP_size_inv_A: float = None,
+    zone_axis_lattice: np.ndarray = None,
+    zone_axis_cartesian: np.ndarray = None,
+    foil_normal_lattice: np.ndarray = None,
+    foil_normal_cartesian: np.ndarray = None,
+    LACBED: bool = False,
+    dtype: np.dtype = np.float32,
+    progress_bar: bool = True,
+    return_mask: bool = False,
+    two_beam_zone_axis_lattice: np.ndarray = None,
+    return_probe: bool = False,
+) -> Union[np.ndarray, List[np.ndarray], Dict[Tuple[int], np.ndarray]]:
+ """
+ Generate a dynamical CBED pattern using the Bloch wave method.
+
+ Args:
+ beams (PointList): PointList from the kinematical diffraction generator
+ which will define the beams included in the Bloch calculation
+ thickness (float or list/array) thickness in Ångström to evaluate diffraction patterns at.
+ The main Bloch calculation can be reused for multiple thicknesses
+ without much overhead.
+ alpha_mrad (float): Convergence angle for CBED pattern. Note that if disks in the calculation
+ overlap, they will be added incoherently, and the resulting CBED will
+ thus represent the average over the unit cell (i.e. a PACBED pattern,
+ as described in LeBeau et al., Ultramicroscopy 110(2): 2010.)
+ pixel_size_inv_A (float): CBED pixel size in 1/Å.
+ DP_size_inv_A (optional float): If specified, defines the extents of the diffraction pattern.
+ If left unspecified, the DP will be automatically scaled to
+ fit all of the beams present in the input plus some small buffer.
+        zone_axis_lattice / zone_axis_cartesian: 3 element projection direction for the
+                            simulated pattern. Can also be a 3x3 orientation matrix
+                            (zone axis 3rd column)
+        foil_normal_lattice / foil_normal_cartesian: 3 element foil normal - set to None
+                            to use the zone axis
+        LACBED (bool): Return each diffraction disk as a separate image, in a dictionary
+                            keyed by tuples of (h,k,l).
+        proj_x_axis (np float vector): 3 element vector defining image x axis (vertical)
+        two_beam_zone_axis_lattice: When only two beams are present in the "beams" PointList,
+                            the computation of the projected crystallographic directions
+                            becomes ambiguous. In this case, you must specify the indices of
+                            the zone axis used to generate the beams.
+        return_probe (bool): If True, the probe (np.ndarray) will be returned in addition to the CBED
+
+ Returns:
+ If thickness is a scalar: CBED pattern as np.ndarray
+ If thickness is a sequence: CBED patterns for each thickness value as a list of np.ndarrays
+ If LACBED is True and thickness is scalar: Dictionary with tuples of ints (h,k,l) as keys, mapping to np.ndarray.
+ If LACBED is True and thickness is a sequence: List of dictionaries, structured as above.
+        If return_probe is True: a tuple (output, probe) is returned, where output
+            is any of the return types listed above.
+ """
+
+ alpha_rad = alpha_mrad / 1000.0
+
+ # figure out the projected x and y directions from the beams input
+ hkl = np.vstack((beams.data["h"], beams.data["k"], beams.data["l"])).T.astype(
+ np.float64
+ )
+ qxy = np.vstack((beams.data["qx"], beams.data["qy"])).T.astype(np.float64)
+
+ # If there are only two beams, augment the list with a third perpendicular spot
+ if qxy.shape[0] == 2:
+ assert (
+ two_beam_zone_axis_lattice is not None
+ ), "When only two beams are present, two_beam_zone_axis_lattice must be specified."
+ hkl_reflection = hkl[1] if np.all(qxy[0] == 0.0) else hkl[0]
+ qxy_reflection = qxy[1] if np.all(qxy[0] == 0.0) else qxy[0]
+ orthogonal_spot = np.cross(two_beam_zone_axis_lattice, hkl_reflection)
+ hkl_augmented = np.vstack((hkl, orthogonal_spot))
+ qxy_augmented = np.vstack((qxy, np.flipud(qxy_reflection)))
+ proj = np.linalg.lstsq(qxy_augmented, hkl_augmented, rcond=-1)[0]
+ hkl_proj_x = proj[0] / np.linalg.norm(proj[0])
+ # Otherwise calculate them based on the pattern
+ else:
+ proj = np.linalg.lstsq(qxy, hkl, rcond=-1)[0]
+ hkl_proj_x = proj[0] / np.linalg.norm(proj[0])
+
+ # get unit vector in zone axis direction and projected x and y Cartesian directions:
+ zone_axis_rotation_matrix = self.parse_orientation(
+ zone_axis_lattice=zone_axis_lattice,
+ zone_axis_cartesian=zone_axis_cartesian,
+ proj_x_lattice=hkl_proj_x,
+ )
+ ZA = np.array(zone_axis_rotation_matrix[:, 2]) / np.linalg.norm(
+ np.array(zone_axis_rotation_matrix[:, 2])
+ )
+ proj_x = zone_axis_rotation_matrix[:, 0] / np.linalg.norm(
+ zone_axis_rotation_matrix[:, 0]
+ )
+ proj_y = zone_axis_rotation_matrix[:, 1] / np.linalg.norm(
+ zone_axis_rotation_matrix[:, 1]
+ )
+
+ # the foil normal should be the zone axis if unspecified
+ if foil_normal_lattice is None:
+ foil_normal_lattice = zone_axis_lattice
+ if foil_normal_cartesian is None:
+ foil_normal_cartesian = zone_axis_cartesian
+
+ # TODO: refine pixel size to center reflections on pixels
+
+ # Generate list of plane waves inside aperture
+ alpha_pix = np.round(
+ alpha_rad / self.wavelength / pixel_size_inv_A
+ ) # radius of aperture in pixels
+
+ tx_pixels, ty_pixels = np.meshgrid(
+ np.arange(-alpha_pix, alpha_pix + 1), np.arange(-alpha_pix, alpha_pix + 1)
+ ) # plane waves in pixel units
+
+ # remove those outside circular aperture
+ keep_mask = np.hypot(tx_pixels, ty_pixels) < alpha_pix
+ tx_pixels = tx_pixels[keep_mask].astype(np.intp)
+ ty_pixels = ty_pixels[keep_mask].astype(np.intp)
+
+ tx_rad = tx_pixels / alpha_pix * alpha_rad
+ ty_rad = ty_pixels / alpha_pix * alpha_rad
+
+ # calculate plane waves as zone axes using small angle approximation for tilting
+ tZA = ZA - (tx_rad[:, None] * proj_x) - (ty_rad[:, None] * proj_y)
+
+ if LACBED:
+        # In LACBED mode, the default DP size is the same as one diffraction disk (2α)
+ if DP_size_inv_A is None:
+ DP_size = [int(2 * alpha_pix), int(2 * alpha_pix)]
+ else:
+ DP_size = [
+ int(2 * DP_size_inv_A / pixel_size_inv_A),
+ int(2 * DP_size_inv_A / pixel_size_inv_A),
+ ]
+ else:
+ # determine DP size based on beams present, plus a little extra
+ qx_max = np.max(np.abs(beams.data["qx"])) / pixel_size_inv_A
+ qy_max = np.max(np.abs(beams.data["qy"])) / pixel_size_inv_A
+
+ if DP_size_inv_A is None:
+ DP_size = [
+ int(2 * (qx_max + 2 * alpha_pix)),
+ int(2 * (qy_max + 2 * alpha_pix)),
+ ]
+ else:
+ DP_size = [
+ int(2 * DP_size_inv_A / pixel_size_inv_A),
+ int(2 * DP_size_inv_A / pixel_size_inv_A),
+ ]
+
+ qx0 = DP_size[0] // 2
+ qy0 = DP_size[1] // 2
+
+ thickness = np.atleast_1d(thickness)
+
+ if LACBED:
+        # In LACBED mode, DP is a list of dicts mapping (h,k,l) tuples of ints to numpy arrays
+ DP = [
+ {
+ (d["h"], d["k"], d["l"]): np.zeros(DP_size, dtype=dtype)
+ for d in beams.data
+ }
+ for _ in range(len(thickness))
+ ]
+ else:
+        # In CBED mode, DP is a list of arrays
+ DP = [np.zeros(DP_size, dtype=dtype) for _ in range(len(thickness))]
+
+ if return_probe:
+ probe = np.zeros(DP_size, dtype=dtype)
+
+ mask = np.zeros(DP_size, dtype=np.bool_)
+
+ Ugmh_cache = DynamicalMatrixCache()
+
+ for i in tqdm(range(len(tZA)), disable=not progress_bar):
+ bloch = self.generate_dynamical_diffraction_pattern(
+ beams,
+ thickness=thickness,
+ zone_axis_cartesian=tZA[i],
+ foil_normal_cartesian=foil_normal_cartesian,
+ foil_normal_lattice=foil_normal_lattice,
+ always_return_list=True,
+ dynamical_matrix_cache=Ugmh_cache,
+ )
+ if return_probe:
+ probe[tx_pixels[i] + qx0, ty_pixels[i] + qy0] = 1
+
+ if LACBED:
+ # loop over each thickness
+ for patt, sim in zip(DP, bloch):
+ # loop over each beam
+ for refl in sim.data:
+ patt[(refl["h"], refl["k"], refl["l"])][
+ qx0 + tx_pixels[i], qy0 + ty_pixels[i]
+ ] = refl["intensity"]
+ else:
+ xpix = np.round(
+ bloch[0].data["qx"] / pixel_size_inv_A + tx_pixels[i] + qx0
+ ).astype(np.intp)
+ ypix = np.round(
+ bloch[0].data["qy"] / pixel_size_inv_A + ty_pixels[i] + qy0
+ ).astype(np.intp)
+
+ keep_mask = np.logical_and.reduce(
+ (xpix >= 0, ypix >= 0, xpix < DP_size[0], ypix < DP_size[1])
+ )
+
+ xpix = xpix[keep_mask]
+ ypix = ypix[keep_mask]
+
+ mask[xpix, ypix] = True
+
+ for patt, sim in zip(DP, bloch):
+ patt[xpix, ypix] += sim.data["intensity"][keep_mask]
+
+ if not return_probe:
+ if return_mask:
+ return (DP[0], mask) if len(thickness) == 1 else (DP, mask)
+ else:
+ return DP[0] if len(thickness) == 1 else DP
+ else:
+ if return_mask:
+ return (DP[0], probe, mask) if len(thickness) == 1 else (DP, probe, mask)
+ else:
+ return (DP[0], probe) if len(thickness) == 1 else (DP, probe)
diff --git a/py4DSTEM/process/diffraction/crystal_calibrate.py b/py4DSTEM/process/diffraction/crystal_calibrate.py
new file mode 100644
index 000000000..c068bf79e
--- /dev/null
+++ b/py4DSTEM/process/diffraction/crystal_calibrate.py
@@ -0,0 +1,485 @@
+import numpy as np
+from typing import Union, Optional
+from scipy.optimize import curve_fit
+
+from py4DSTEM.process.diffraction.utils import Orientation, calc_1D_profile
+
+try:
+ from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
+ from pymatgen.core.structure import Structure
+except ImportError:
+ pass
+
+
+def calibrate_pixel_size(
+ self,
+ bragg_peaks,
+ scale_pixel_size=1.0,
+ bragg_k_power=1.0,
+ bragg_intensity_power=1.0,
+ k_min=0.0,
+ k_max=None,
+ k_step=0.002,
+ k_broadening=0.002,
+ fit_all_intensities=True,
+ set_calibration_in_place=False,
+ verbose=True,
+ plot_result=False,
+ figsize: Union[list, tuple, np.ndarray] = (12, 6),
+ returnfig=False,
+):
+ """
+ Use the calculated structure factor scattering lengths to compute 1D
+ diffraction patterns, and solve the best-fit relative scaling between them.
+ Returns the fit pixel size in Å^-1.
+
+ Args:
+ bragg_peaks (BraggVectors): Input Bragg vectors.
+        scale_pixel_size (float): Initial guess for scaling of the existing
+            pixel size. If the pixel size is currently uncalibrated, this is a
+            guess of the pixel size in Å^-1. If the pixel size is already
+            (approximately) calibrated, this is the scaling factor to
+            correct that existing calibration.
+        bragg_k_power (float): Input Bragg peak intensities are multiplied by
+            k**bragg_k_power to change the weighting of longer scattering vectors
+        bragg_intensity_power (float): Input Bragg peak intensities are raised
+            to the power bragg_intensity_power.
+        k_min (float): min k value for fitting range (Å^-1)
+        k_max (float): max k value for fitting range (Å^-1)
+        k_step (float): step size of k in fitting range (Å^-1)
+        k_broadening (float): Initial guess for Gaussian broadening of simulated
+            pattern (Å^-1)
+        fit_all_intensities (bool): Set to True to allow all peak intensities to
+            change independently; False forces a single intensity scaling.
+        set_calibration_in_place (bool): if True, set the fit pixel size in the
+            calibration metadata, and calibrate bragg_peaks
+        verbose (bool): Output the calibrated pixel size.
+        plot_result (bool): Plot the resulting fit.
+        figsize (list, tuple, np.ndarray): Figure size of the plot.
+        returnfig (bool): Return figure and axis handles
+
+    Returns
+    -------
+    bragg_peaks_cali: BraggVectors
+        Bragg peaks with the calibrated pixel size applied.
+    fig, ax: handles, optional
+        Figure and axis handles, if returnfig=True.
+
+ """
+
+ assert hasattr(self, "struct_factors"), "Compute structure factors first..."
+
+ # Prepare experimental data
+ k, int_exp = self.calculate_bragg_peak_histogram(
+ bragg_peaks, bragg_k_power, bragg_intensity_power, k_min, k_max, k_step
+ )
+
+ # Perform fitting
+ def fit_profile(k, *coefs):
+ scale_pixel_size = coefs[0]
+ k_broadening = coefs[1]
+ int_scale = coefs[2:]
+
+ int_sf = calc_1D_profile(
+ k,
+ self.g_vec_leng * scale_pixel_size,
+ self.struct_factors_int,
+ k_broadening=k_broadening,
+ int_scale=int_scale,
+ normalize_intensity=False,
+ )
+ return int_sf
+
+ if fit_all_intensities:
+ coefs = (
+ scale_pixel_size,
+ k_broadening,
+ *tuple(np.ones(self.g_vec_leng.shape[0])),
+ )
+ bounds = (0.0, np.inf)
+ popt, pcov = curve_fit(fit_profile, k, int_exp, p0=coefs, bounds=bounds)
+ else:
+ coefs = (scale_pixel_size, k_broadening, 1.0)
+ bounds = (0.0, np.inf)
+ popt, pcov = curve_fit(fit_profile, k, int_exp, p0=coefs, bounds=bounds)
+
+ scale_pixel_size = popt[0]
+ k_broadening = popt[1]
+ int_scale = np.array(popt[2:])
+
+ # Get the answer
+ pix_size_prev = bragg_peaks.calibration.get_Q_pixel_size()
+ pixel_size_new = pix_size_prev / scale_pixel_size
+
+ # if requested, apply calibrations in place
+ if set_calibration_in_place:
+ bragg_peaks.calibration.set_Q_pixel_size(pixel_size_new)
+ bragg_peaks.calibration.set_Q_pixel_units("A^-1")
+
+ # Output calibrated Bragg peaks
+ bragg_peaks_cali = bragg_peaks.copy()
+ bragg_peaks_cali.calibration.set_Q_pixel_size(pixel_size_new)
+ bragg_peaks_cali.calibration.set_Q_pixel_units("A^-1")
+
+ # Output pixel size
+ if verbose:
+ print(f"Calibrated pixel size = {np.round(pixel_size_new, decimals=8)} A^-1")
+
+ # Plotting
+ if plot_result:
+ if int_scale.shape[0] < self.g_vec_leng.shape[0]:
+ int_scale = np.hstack(
+ (int_scale, np.ones(self.g_vec_leng.shape[0] - int_scale.shape[0]))
+ )
+ elif int_scale.shape[0] > self.g_vec_leng.shape[0]:
+ print(int_scale.shape[0])
+ int_scale = int_scale[: self.g_vec_leng.shape[0]]
+
+ if returnfig:
+ fig, ax = self.plot_scattering_intensity(
+ bragg_peaks=bragg_peaks,
+ figsize=figsize,
+ k_broadening=k_broadening,
+ int_power_scale=1.0,
+ int_scale=int_scale,
+ bragg_k_power=bragg_k_power,
+ bragg_intensity_power=bragg_intensity_power,
+ k_min=k_min,
+ k_max=k_max,
+ returnfig=True,
+ )
+ else:
+ self.plot_scattering_intensity(
+ bragg_peaks=bragg_peaks,
+ figsize=figsize,
+ k_broadening=k_broadening,
+ int_power_scale=1.0,
+ int_scale=int_scale,
+ bragg_k_power=bragg_k_power,
+ bragg_intensity_power=bragg_intensity_power,
+ k_min=k_min,
+ k_max=k_max,
+ )
+
+ # return
+ if returnfig and plot_result:
+ return bragg_peaks_cali, (fig, ax)
+ else:
+ return bragg_peaks_cali
+
+
+def calibrate_unit_cell(
+ self,
+ bragg_peaks,
+ coef_index=None,
+ coef_update=None,
+ bragg_k_power=1.0,
+ bragg_intensity_power=1.0,
+ k_min=0.0,
+ k_max=None,
+ k_step=0.005,
+ k_broadening=0.02,
+ fit_all_intensities=True,
+ verbose=True,
+ plot_result=False,
+ figsize: Union[list, tuple, np.ndarray] = (12, 6),
+ returnfig=False,
+):
+ """
+ Solve for the best fit scaling between the computed structure factors and bragg_peaks.
+
+ Args:
+ bragg_peaks (BraggVectors): Input Bragg vectors.
+ coef_index (list of ints): List of ints that act as pointers to unit cell parameters and angles to update.
+ coef_update (list of bool): List of booleans to indicate whether or not to update the cell at
+ that position
+        bragg_k_power (float): Input Bragg peak intensities are multiplied by k**bragg_k_power
+                                to change the weighting of longer scattering vectors
+        bragg_intensity_power (float): Input Bragg peak intensities are raised to the power
+                                bragg_intensity_power.
+        k_min (float): min k value for fitting range (Å^-1)
+        k_max (float): max k value for fitting range (Å^-1)
+        k_step (float): step size of k in fitting range (Å^-1)
+        k_broadening (float): Initial guess for Gaussian broadening of simulated pattern (Å^-1)
+        fit_all_intensities (bool): Set to True to allow all peak intensities to change
+                                independently; False forces a single intensity scaling.
+        verbose (bool): Output the original and calibrated unit cells.
+        plot_result (bool): Plot the resulting fit.
+        figsize (list, tuple, np.ndarray): Figure size of the plot.
+        returnfig (bool): Return figure and axis handles
+
+ Returns:
+ fig, ax (handles): Optional figure and axis handles, if returnfig=True.
+
+ Details:
+        User has the option to define what is allowed to update in the unit cell using the arguments
+        coef_index and coef_update. Each has 6 entries, corresponding to the a, b, c, alpha, beta, gamma
+        parameters of the unit cell, in this order. The coef_update argument is a list of bools specifying
+        whether or not the unit cell value will be allowed to change (True) or must maintain the original
+        value (False) upon fitting. The coef_index argument provides, for each parameter, the index of
+        the parameter from which the updated value is pulled.
+
+        For example, to update a, b, c, alpha, beta, gamma all independently of each other, the following
+        arguments should be used:
+            coef_index = [0, 1, 2, 3, 4, 5]
+            coef_update = [True, True, True, True, True, True]
+
+        The default is set to automatically define what can update in a unit cell based on the
+        point group constraints. When either 'coef_index' or 'coef_update' are None, these constraints
+        will be automatically pulled from the pointgroup.
+
+        For example, the default for cubic unit cells is:
+            coef_index = [0, 0, 0, 3, 3, 3]
+            coef_update = [True, True, True, False, False, False]
+        This allows a, b, and c to update (True in the first 3 entries of coef_update), with b and c
+        tied to the value of a (0 in entries 1 and 2 of coef_index) so that a = b = c. Because
+        coef_update is False for alpha, beta, and gamma (entries 3, 4, 5), no updates are made to
+        the angles.
+
+ The user has the option to predefine coef_index or coef_update to override defaults. In the
+ coef_update list, there must be 6 entries and each are boolean. In the coef_index list, there
+ must be 6 entries, with the first 3 entries being between 0 - 2 and the last 3 entries between
+ 3 - 5. These act as pointers to pull the updated parameter from.
+
+ """
+ # initialize structure
+ if coef_index is None or coef_update is None:
+ structure = Structure(
+ self.lat_real, self.numbers, self.positions, coords_are_cartesian=False
+ )
+ self.pointgroup = SpacegroupAnalyzer(structure)
+ assert (
+ self.pointgroup.get_point_group_symbol() in parameter_updates
+ ), "Unrecognized pointgroup returned by pymatgen!"
+ coef_index, coef_update = parameter_updates[
+ self.pointgroup.get_point_group_symbol()
+ ]
+
+ # Prepare experimental data
+ k, int_exp = self.calculate_bragg_peak_histogram(
+ bragg_peaks, bragg_k_power, bragg_intensity_power, k_min, k_max, k_step
+ )
+
+ # Define Fitting Class
+ class FitCrystal:
+ def __init__(
+ self,
+ crystal,
+ coef_index,
+ coef_update,
+ fit_all_intensities,
+ ):
+ self.coefs_init = crystal.cell
+ self.hkl = crystal.hkl
+ self.struct_factors_int = crystal.struct_factors_int
+ self.coef_index = coef_index
+ self.coef_update = coef_update
+
+ def get_coefs(
+ self,
+ coefs_fit,
+ ):
+ coefs = np.zeros_like(coefs_fit)
+ for a0 in range(6):
+ if self.coef_update[a0]:
+ coefs[a0] = coefs_fit[self.coef_index[a0]]
+ else:
+ coefs[a0] = self.coefs_init[a0]
+ coefs[6:] = coefs_fit[6:]
+
+ return coefs
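+
+        # Worked example for get_coefs: with the cubic defaults
+        # coef_index = [0, 0, 0, 3, 3, 3] and
+        # coef_update = [True, True, True, False, False, False], entries 0-2 all
+        # read coefs_fit[0], so a, b, and c share one fitted value, while alpha,
+        # beta, and gamma are copied from the initial cell.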
+
+ def fitfun(self, k, *coefs_fit):
+ coefs = self.get_coefs(coefs_fit=coefs_fit)
+
+ # Update g vector positions
+ a, b, c = coefs[:3]
+ alpha = np.deg2rad(coefs[3])
+ beta = np.deg2rad(coefs[4])
+ gamma = np.deg2rad(coefs[5])
+ f = np.cos(beta) * np.cos(gamma) - np.cos(alpha)
+ vol = (
+ a
+ * b
+ * c
+ * np.sqrt(
+ 1
+ + 2 * np.cos(alpha) * np.cos(beta) * np.cos(gamma)
+ - np.cos(alpha) ** 2
+ - np.cos(beta) ** 2
+ - np.cos(gamma) ** 2
+ )
+ )
+ lat_real = np.array(
+ [
+ [a, 0, 0],
+ [b * np.cos(gamma), b * np.sin(gamma), 0],
+ [
+ c * np.cos(beta),
+ -c * f / np.sin(gamma),
+ vol / (a * b * np.sin(gamma)),
+ ],
+ ]
+ )
+ # Inverse lattice, metric tensors
+ metric_real = lat_real @ lat_real.T
+ metric_inv = np.linalg.inv(metric_real)
+ lat_inv = metric_inv @ lat_real
+ g_vec_all = (self.hkl.T @ lat_inv).T
+ g_vec_leng = np.linalg.norm(g_vec_all, axis=0)
+
+ # Calculate fitted intensity profile
+ k_broadening = coefs[6]
+ int_scale = coefs[7:]
+ int_sf = calc_1D_profile(
+ k,
+ g_vec_leng,
+ self.struct_factors_int,
+ k_broadening=k_broadening,
+ int_scale=int_scale,
+ normalize_intensity=False,
+ )
+
+ return int_sf
+
+ fit_crystal = FitCrystal(
+ self,
+ coef_index=coef_index,
+ coef_update=coef_update,
+ fit_all_intensities=fit_all_intensities,
+ )
+
+ if fit_all_intensities:
+ coefs = (
+ *tuple(self.cell),
+ k_broadening,
+ *tuple(np.ones(self.g_vec_leng.shape[0])),
+ )
+ bounds = (0.0, np.inf)
+ popt, pcov = curve_fit(
+ fit_crystal.fitfun,
+ k,
+ int_exp,
+ p0=coefs,
+ bounds=bounds,
+ )
+ else:
+ coefs = (
+ *tuple(self.cell),
+ k_broadening,
+ 1.0,
+ )
+ bounds = (0.0, np.inf)
+ popt, pcov = curve_fit(
+ fit_crystal.fitfun,
+ k,
+ int_exp,
+ p0=coefs,
+ bounds=bounds,
+ )
+
+ if verbose:
+ cell_init = self.cell
+ # Update crystal with new lattice parameters
+ self.cell = fit_crystal.get_coefs(popt[:6])
+ self.calculate_lattice()
+ self.calculate_structure_factors(self.k_max)
+
+ # Output
+ if verbose:
+ # Print unit cell parameters
+ print("Original unit cell = " + str(cell_init))
+ print("Calibrated unit cell = " + str(self.cell))
+
+ # Plotting
+ if plot_result:
+ k_broadening = popt[6]
+ int_scale = popt[7:]
+ if int_scale.shape[0] < self.g_vec_leng.shape[0]:
+ int_scale = np.hstack(
+ (int_scale, np.ones(self.g_vec_leng.shape[0] - int_scale.shape[0]))
+ )
+ elif int_scale.shape[0] > self.g_vec_leng.shape[0]:
+ print(int_scale.shape[0])
+ int_scale = int_scale[: self.g_vec_leng.shape[0]]
+
+ if returnfig:
+ fig, ax = self.plot_scattering_intensity(
+ bragg_peaks=bragg_peaks,
+ figsize=figsize,
+ k_broadening=k_broadening,
+ int_power_scale=1.0,
+ int_scale=int_scale,
+ bragg_k_power=bragg_k_power,
+ bragg_intensity_power=bragg_intensity_power,
+ k_min=k_min,
+ k_max=k_max,
+ returnfig=True,
+ )
+ else:
+ self.plot_scattering_intensity(
+ bragg_peaks=bragg_peaks,
+ figsize=figsize,
+ k_broadening=k_broadening,
+ int_power_scale=1.0,
+ int_scale=int_scale,
+ bragg_k_power=bragg_k_power,
+ bragg_intensity_power=bragg_intensity_power,
+ k_min=k_min,
+ k_max=k_max,
+ )
+
+ if returnfig and plot_result:
+ return fig, ax
+ else:
+ return
+
+
+# coef_index and coef_update sets for the fit_unit_cell function, in the order:
+# [coef_index, coef_update]
+parameter_updates = {
+ "1": [[0, 1, 2, 3, 4, 5], [True, True, True, True, True, True]], # Triclinic
+ "-1": [[0, 1, 2, 3, 4, 5], [True, True, True, True, True, True]], # Triclinic
+ "2": [[0, 1, 2, 3, 4, 3], [True, True, True, False, True, False]], # Monoclinic
+ "m": [[0, 1, 2, 3, 4, 3], [True, True, True, False, True, False]], # Monoclinic
+ "2/m": [[0, 1, 2, 3, 4, 3], [True, True, True, False, True, False]], # Monoclinic
+ "222": [
+ [0, 1, 2, 3, 3, 3],
+ [True, True, True, False, False, False],
+ ], # Orthorhombic
+ "mm2": [
+ [0, 1, 2, 3, 3, 3],
+ [True, True, True, False, False, False],
+ ], # Orthorhombic
+ "mmm": [
+ [0, 1, 2, 3, 3, 3],
+ [True, True, True, False, False, False],
+ ], # Orthorhombic
+ "4": [[0, 0, 2, 3, 3, 3], [True, True, True, False, False, False]], # Tetragonal
+ "-4": [[0, 0, 2, 3, 3, 3], [True, True, True, False, False, False]], # Tetragonal
+ "4/m": [[0, 0, 2, 3, 3, 3], [True, True, True, False, False, False]], # Tetragonal
+ "422": [[0, 0, 2, 3, 3, 3], [True, True, True, False, False, False]], # Tetragonal
+ "4mm": [[0, 0, 2, 3, 3, 3], [True, True, True, False, False, False]], # Tetragonal
+ "-42m": [[0, 0, 2, 3, 3, 3], [True, True, True, False, False, False]], # Tetragonal
+ "4/mmm": [
+ [0, 0, 2, 3, 3, 3],
+ [True, True, True, False, False, False],
+ ], # Tetragonal
+ "3": [[0, 0, 0, 3, 3, 3], [True, True, True, True, True, True]], # Trigonal
+ "-3": [[0, 0, 0, 3, 3, 3], [True, True, True, True, True, True]], # Trigonal
+ "32": [[0, 0, 0, 3, 3, 3], [True, True, True, True, True, True]], # Trigonal
+ "3m": [[0, 0, 0, 3, 3, 3], [True, True, True, True, True, True]], # Trigonal
+ "-3m": [[0, 0, 0, 3, 3, 3], [True, True, True, True, True, True]], # Trigonal
+ "6": [[0, 0, 2, 3, 3, 5], [True, True, True, False, False, True]], # Hexagonal
+ "-6": [[0, 0, 2, 3, 3, 5], [True, True, True, False, False, True]], # Hexagonal
+ "6/m": [[0, 0, 2, 3, 3, 5], [True, True, True, False, False, True]], # Hexagonal
+ "622": [[0, 0, 2, 3, 3, 5], [True, True, True, False, False, True]], # Hexagonal
+ "6mm": [[0, 0, 2, 3, 3, 5], [True, True, True, False, False, True]], # Hexagonal
+ "-6m2": [[0, 0, 2, 3, 3, 5], [True, True, True, False, False, True]], # Hexagonal
+ "6/mmm": [[0, 0, 2, 3, 3, 5], [True, True, True, False, False, True]], # Hexagonal
+ "23": [[0, 0, 0, 3, 3, 3], [True, True, True, False, False, False]], # Cubic
+ "m-3": [[0, 0, 0, 3, 3, 3], [True, True, True, False, False, False]], # Cubic
+ "432": [[0, 0, 0, 3, 3, 3], [True, True, True, False, False, False]], # Cubic
+ "-43m": [[0, 0, 0, 3, 3, 3], [True, True, True, False, False, False]], # Cubic
+ "m-3m": [[0, 0, 0, 3, 3, 3], [True, True, True, False, False, False]], # Cubic
+}
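+
+# Illustrative lookup: parameter_updates["m-3m"] returns
+# ([0, 0, 0, 3, 3, 3], [True, True, True, False, False, False]),
+# the cubic constraint a = b = c with all three angles held fixed.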
diff --git a/py4DSTEM/process/diffraction/crystal_phase.py b/py4DSTEM/process/diffraction/crystal_phase.py
new file mode 100644
index 000000000..bac1cf8c7
--- /dev/null
+++ b/py4DSTEM/process/diffraction/crystal_phase.py
@@ -0,0 +1,338 @@
+import numpy as np
+from numpy.linalg import lstsq
+from scipy.optimize import nnls
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+
+from emdfile import tqdmnd, PointListArray
+from py4DSTEM.visualize import show, show_image_grid
+from py4DSTEM.process.diffraction.crystal_viz import plot_diffraction_pattern
+
+
+class Crystal_Phase:
+ """
+ A class storing multiple crystal structures, and associated diffraction data.
+    Must be initialized after orientations have been matched to a PointListArray.
+
+ """
+
+ def __init__(
+ self,
+ crystals,
+ orientation_maps,
+ name,
+ ):
+ """
+ Args:
+ crystals (list): List of crystal instances
+ orientation_maps (list): List of orientation maps
+ name (str): Name of Crystal_Phase instance
+ """
+ if isinstance(crystals, list):
+ self.crystals = crystals
+ self.num_crystals = len(crystals)
+ else:
+ raise TypeError("crystals must be a list of crystal instances.")
+ if isinstance(orientation_maps, list):
+ if len(self.crystals) != len(orientation_maps):
+ raise ValueError(
+ "Orientation maps must have the same number of entries as crystals."
+ )
+ self.orientation_maps = orientation_maps
+ else:
+ raise TypeError("orientation_maps must be a list of orientation maps.")
+ self.name = name
+ return
+
+ def plot_all_phase_maps(self, map_scale_values=None, index=0):
+ """
+ Visualize phase maps of dataset.
+
+ Args:
+            map_scale_values (list of float): Values to scale each phase's correlations by.
+                Defaults to 1 for every phase.
+            index (int): Index of the orientation match to visualize.
+ """
+ phase_maps = []
+ if map_scale_values is None:
+ map_scale_values = [1] * len(self.orientation_maps)
+        corr_sum = np.sum(
+            [
+                (self.orientation_maps[m].corr[:, :, index] * map_scale_values[m])
+                for m in range(len(self.orientation_maps))
+            ],
+            axis=0,  # per-pixel sum across phases, so each map becomes a fraction
+        )
+ for m in range(len(self.orientation_maps)):
+ phase_maps.append(self.orientation_maps[m].corr[:, :, index] / corr_sum)
+ show_image_grid(lambda i: phase_maps[i], 1, len(phase_maps), cmap="inferno")
+ return
+
+ def plot_phase_map(self, index=0, cmap=None):
+ corr_array = np.dstack(
+ [maps.corr[:, :, index] for maps in self.orientation_maps]
+ )
+ best_corr_score = np.max(corr_array, axis=2)
+ best_match_phase = [
+ np.where(corr_array[:, :, p] == best_corr_score, True, False)
+ for p in range(len(self.orientation_maps))
+ ]
+
+ if cmap is None:
+ cm = plt.get_cmap("rainbow")
+ cmap = [
+ cm(1.0 * i / len(self.orientation_maps))
+ for i in range(len(self.orientation_maps))
+ ]
+
+ fig, (ax) = plt.subplots(figsize=(6, 6))
+ ax.matshow(
+ np.zeros((self.orientation_maps[0].num_x, self.orientation_maps[0].num_y)),
+ cmap="gray",
+ )
+ ax.axis("off")
+
+ for m in range(len(self.orientation_maps)):
+            c0 = (cmap[m][0] * 0.35, cmap[m][1] * 0.35, cmap[m][2] * 0.35, 1)
+            c1 = cmap[m]
+ cm = mpl.colors.LinearSegmentedColormap.from_list("cmap", [c0, c1], N=10)
+ ax.matshow(
+ np.ma.array(
+ self.orientation_maps[m].corr[:, :, index], mask=best_match_phase[m]
+ ),
+ cmap=cm,
+ )
+ plt.show()
+
+ return
+
+ # Potentially introduce a way to check best match out of all orientations in phase plan and plug into model
+ # to quantify phase
+
+ # def phase_plan(
+ # self,
+ # method,
+ # zone_axis_range: np.ndarray = np.array([[0, 1, 1], [1, 1, 1]]),
+ # angle_step_zone_axis: float = 2.0,
+ # angle_coarse_zone_axis: float = None,
+ # angle_refine_range: float = None,
+ # angle_step_in_plane: float = 2.0,
+ # accel_voltage: float = 300e3,
+ # intensity_power: float = 0.25,
+ # tol_peak_delete=None,
+ # tol_distance: float = 0.01,
+ # fiber_axis = None,
+ # fiber_angles = None,
+ # ):
+ # return
+
+ def quantify_phase(
+ self,
+ pointlistarray,
+ tolerance_distance=0.08,
+ method="nnls",
+ intensity_power=0,
+ mask_peaks=None,
+ ):
+ """
+ Quantification of the phase of a crystal based on the crystal instances and the pointlistarray.
+
+ Args:
+            pointlistarray (PointListArray): PointListArray to quantify phase of
+            tolerance_distance (float): Distance allowed between a peak and match
+            method (str): Numerical method used to quantify phase
+            intensity_power (float): Power to raise peak intensities to when weighting
+                matches; 0 weights all matched peaks equally.
+            mask_peaks (list, optional): A pointer of which positions to mask peaks from
+ """
+ if isinstance(pointlistarray, PointListArray):
+ phase_weights = np.zeros(
+ (
+ pointlistarray.shape[0],
+ pointlistarray.shape[1],
+ np.sum([map.num_matches for map in self.orientation_maps]),
+ )
+ )
+ phase_residuals = np.zeros(pointlistarray.shape)
+ for Rx, Ry in tqdmnd(pointlistarray.shape[0], pointlistarray.shape[1]):
+ (
+ _,
+ phase_weight,
+ phase_residual,
+ crystal_identity,
+ ) = self.quantify_phase_pointlist(
+ pointlistarray,
+ position=[Rx, Ry],
+ tolerance_distance=tolerance_distance,
+ method=method,
+ intensity_power=intensity_power,
+ mask_peaks=mask_peaks,
+ )
+ phase_weights[Rx, Ry, :] = phase_weight
+ phase_residuals[Rx, Ry] = phase_residual
+ self.phase_weights = phase_weights
+ self.phase_residuals = phase_residuals
+ self.crystal_identity = crystal_identity
+ return
+        else:
+            raise TypeError("pointlistarray must be of type PointListArray.")
+
+ def quantify_phase_pointlist(
+ self,
+ pointlistarray,
+ position,
+ method="nnls",
+ tolerance_distance=0.08,
+ intensity_power=0,
+ mask_peaks=None,
+ ):
+ """
+ Args:
+            pointlistarray (PointListArray): PointListArray to quantify phase of
+            position (tuple/list): Position of pointlist in pointlistarray
+            tolerance_distance (float): Distance allowed between a peak and match
+            method (str): Numerical method used to quantify phase
+            intensity_power (float): Power to raise peak intensities to when weighting
+                matches; 0 weights all matched peaks equally.
+            mask_peaks (list, optional): A pointer of which positions to mask peaks from
+
+ Returns:
+ pointlist_peak_intensity_matches (np.ndarray): Peak matches in the rows of array and the crystals in the columns
+ phase_weights (np.ndarray): Weights of each phase
+ phase_residuals (np.ndarray): Residuals
+            crystal_identity (list): List of lists, where each entry gives the crystal index and the
+                                     orientation-match index associated with the corresponding phase
+                                     weight. For example, if the output is [[0,0], [0,1], [1,0], [1,1]],
+                                     the first entry [0,0] is the first crystal and the first match
+                                     within that crystal, and [0,1] is the first crystal and the
+                                     second match within that crystal.
+ """
+ # Things to add:
+ # 1. Better cost for distance from peaks in pointlists
+ # 2. Iterate through multiple tolerance_distance values to find best value. Cost function residuals, or something else?
+
+ pointlist = pointlistarray.get_pointlist(position[0], position[1])
+ pl_mask = np.where((pointlist["qx"] == 0) & (pointlist["qy"] == 0), 1, 0)
+ pointlist.remove(pl_mask)
+ # False Negatives (exp peak with no match in crystal instances) will appear here, already coded in
+
+ if intensity_power == 0:
+ pl_intensities = np.ones(pointlist["intensity"].shape)
+ else:
+ pl_intensities = pointlist["intensity"] ** intensity_power
+ # Prepare matches for modeling
+ pointlist_peak_matches = []
+ crystal_identity = []
+
+ for c in range(len(self.crystals)):
+ for m in range(self.orientation_maps[c].num_matches):
+ crystal_identity.append([c, m])
+ phase_peak_match_intensities = np.zeros((pointlist["intensity"].shape))
+ bragg_peaks_fit = self.crystals[c].generate_diffraction_pattern(
+ self.orientation_maps[c].get_orientation(position[0], position[1]),
+ ind_orientation=m,
+ )
+ # Find the best match peak within tolerance_distance and add value in the right position
+ for d in range(pointlist["qx"].shape[0]):
+ distances = []
+ for p in range(bragg_peaks_fit["qx"].shape[0]):
+ distances.append(
+ np.sqrt(
+ (pointlist["qx"][d] - bragg_peaks_fit["qx"][p]) ** 2
+ + (pointlist["qy"][d] - bragg_peaks_fit["qy"][p]) ** 2
+ )
+ )
+ ind = np.where(distances == np.min(distances))[0][0]
+
+ # Potentially for-loop over multiple values for 'tolerance_distance' to find best tolerance_distance value
+ if distances[ind] <= tolerance_distance:
+ ## Somewhere in this if statement is probably where better distances from the peak should be coded in
+ if (
+ intensity_power == 0
+ ): # This could potentially be a different intensity_power arg
+ phase_peak_match_intensities[d] = 1 ** (
+ (tolerance_distance - distances[ind])
+ / tolerance_distance
+ )
+ else:
+ phase_peak_match_intensities[d] = bragg_peaks_fit[
+ "intensity"
+ ][ind] ** (
+ (tolerance_distance - distances[ind])
+ / tolerance_distance
+ )
+ else:
+ ## This is probably where the false positives (peaks in crystal but not in experiment) should be handled
+ continue
+
+ pointlist_peak_matches.append(phase_peak_match_intensities)
+ pointlist_peak_intensity_matches = np.dstack(pointlist_peak_matches)
+ pointlist_peak_intensity_matches = (
+ pointlist_peak_intensity_matches.reshape(
+ pl_intensities.shape[0],
+ pointlist_peak_intensity_matches.shape[-1],
+ )
+ )
+
+ if len(pointlist["qx"]) > 0:
+ if mask_peaks is not None:
+ for i in range(len(mask_peaks)):
+ if mask_peaks[i] is None:
+ continue
+ inds_mask = np.where(
+ pointlist_peak_intensity_matches[:, mask_peaks[i]] != 0
+ )[0]
+ for mask in range(len(inds_mask)):
+ pointlist_peak_intensity_matches[inds_mask[mask], i] = 0
+
+ if method == "nnls":
+ phase_weights, phase_residuals = nnls(
+ pointlist_peak_intensity_matches, pl_intensities
+ )
+
+ elif method == "lstsq":
+            phase_weights, phase_residuals, rank, singular_vals = lstsq(
+ pointlist_peak_intensity_matches, pl_intensities, rcond=-1
+ )
+ phase_residuals = np.sum(phase_residuals)
+ else:
+ raise ValueError(method + " Not yet implemented. Try nnls or lstsq.")
+ else:
+ phase_weights = np.zeros((pointlist_peak_intensity_matches.shape[1],))
+ phase_residuals = np.NaN
+ return (
+ pointlist_peak_intensity_matches,
+ phase_weights,
+ phase_residuals,
+ crystal_identity,
+ )
+
+ # def plot_peak_matches(
+ # self,
+ # pointlistarray,
+ # position,
+ # tolerance_distance,
+ # ind_orientation,
+ # pointlist_peak_intensity_matches,
+ # ):
+ # """
+ # A method to view how the tolerance distance impacts the peak matches associated with
+ # the quantify_phase_pointlist method.
+
+ # Args:
+ # pointlistarray,
+ # position,
+ # tolerance_distance
+ # pointlist_peak_intensity_matches
+ # """
+ # pointlist = pointlistarray.get_pointlist(position[0],position[1])
+
+ # for m in range(pointlist_peak_intensity_matches.shape[1]):
+ # bragg_peaks_fit = self.crystals[m].generate_diffraction_pattern(
+ # self.orientation_maps[m].get_orientation(position[0], position[1]),
+ # ind_orientation = ind_orientation
+ # )
+ # peak_inds = np.where(bragg_peaks_fit.data['intensity'] == pointlist_peak_intensity_matches[:,m])
+
+ # fig, (ax1, ax2) = plt.subplots(2,1,figsize = figsize)
+ # ax1 = plot_diffraction_pattern(pointlist,)
+ # return
diff --git a/py4DSTEM/process/diffraction/crystal_viz.py b/py4DSTEM/process/diffraction/crystal_viz.py
new file mode 100644
index 000000000..94cf75b8c
--- /dev/null
+++ b/py4DSTEM/process/diffraction/crystal_viz.py
@@ -0,0 +1,2153 @@
+import matplotlib.pyplot as plt
+from matplotlib.figure import Figure
+from matplotlib.axes import Axes
+import matplotlib.tri as mtri
+from mpl_toolkits.mplot3d import Axes3D, art3d
+from scipy.signal import medfilt
+from scipy.ndimage import gaussian_filter
+from scipy.ndimage import distance_transform_edt
+from skimage.morphology import dilation, erosion
+
+import warnings
+import numpy as np
+from typing import Union, Optional
+
+from emdfile import tqdmnd, PointList, PointListArray
+from py4DSTEM.process.diffraction.utils import calc_1D_profile
+
+
+def plot_structure(
+ self,
+ orientation_matrix: Optional[np.ndarray] = None,
+ zone_axis_lattice: Optional[np.ndarray] = None,
+ proj_x_lattice: Optional[np.ndarray] = None,
+ zone_axis_cartesian: Optional[np.ndarray] = None,
+ proj_x_cartesian: Optional[np.ndarray] = None,
+ size_marker: float = 400,
+ tol_distance: float = 0.001,
+ plot_limit: Optional[np.ndarray] = None,
+ camera_dist: Optional[float] = None,
+ show_axes: bool = False,
+ perspective_axes: bool = True,
+ figsize: Union[tuple, list, np.ndarray] = (8, 8),
+ returnfig: bool = False,
+):
+ """
+    Quick 3D plot of the unit cell / atomic structure.
+
+ Args:
+ orientation_matrix (array): (3,3) orientation matrix, where columns represent projection directions.
+ zone_axis_lattice (array): (3,) projection direction in lattice indices
+ proj_x_lattice (array): (3,) x-axis direction in lattice indices
+ zone_axis_cartesian (array): (3,) cartesian projection direction
+ proj_x_cartesian (array): (3,) cartesian projection direction
+        size_marker (float): Size scaling for markers
+        tol_distance (float): Tolerance for repeating atoms on cell boundary edges.
+ plot_limit (float): (2,3) numpy array containing x y z plot min and max in columns.
+ Default is 1.1* unit cell dimensions.
+ camera_dist (float): Move camera closer to the plot (relative to matplotlib default of 10)
+ show_axes (bool): Whether to plot axes or not.
+ perspective_axes (bool): Select either perspective (true) or orthogonal (false) axes
+ figsize (2 element float): Size scaling of figure axes.
+ returnfig (bool): Return figure and axes handles.
+
+ Returns:
+ fig, ax (optional) figure and axes handles
+ """
+
+ # projection directions
+ if orientation_matrix is None:
+ orientation_matrix = self.parse_orientation(
+ zone_axis_lattice, proj_x_lattice, zone_axis_cartesian, proj_x_cartesian
+ )
+
+ # matplotlib camera orientation
+ if np.abs(abs(orientation_matrix[2, 2]) - 1) < 1e-6:
+ el = 90.0 * np.sign(orientation_matrix[2, 2])
+ else:
+ el = np.rad2deg(
+ np.arctan(
+ orientation_matrix[2, 2]
+ / np.sqrt(orientation_matrix[0, 2] ** 2 + orientation_matrix[1, 2] ** 2)
+ )
+ )
+ az = np.rad2deg(np.arctan2(orientation_matrix[0, 2], orientation_matrix[1, 2]))
+ # TODO roll is not yet implemented in matplot version 3.4.3
+ # matplotlib x projection direction (i.e. estimate the roll angle)
+ # init_y = np.cross(proj_z,np.array([0,1e-6,0]))
+ # init_x = np.cross(init_y,proj_z)
+ # init_x = init_x / np.linalg.norm(init_x)
+ # init_y = init_y / np.linalg.norm(init_y)
+ # beta = np.vstack((init_x,init_y)) @ proj_x[:,None]
+ # alpha = np.rad2deg(np.arctan2(beta[1], beta[0]))
+
+ # unit cell vectors
+ u = self.lat_real[0, :]
+ v = self.lat_real[1, :]
+ w = self.lat_real[2, :]
+
+ # atomic identities
+ ID = self.numbers
+
+ # Fractional atomic coordinates
+ pos = self.positions
+ # x tile
+ sub = pos[:, 0] < tol_distance
+ pos = np.vstack([pos, pos[sub, :] + np.array([1, 0, 0])])
+ ID = np.hstack([ID, ID[sub]])
+ # y tile
+ sub = pos[:, 1] < tol_distance
+ pos = np.vstack([pos, pos[sub, :] + np.array([0, 1, 0])])
+ ID = np.hstack([ID, ID[sub]])
+ # z tile
+ sub = pos[:, 2] < tol_distance
+ pos = np.vstack([pos, pos[sub, :] + np.array([0, 0, 1])])
+ ID = np.hstack([ID, ID[sub]])
+
+ # Cartesian atomic positions
+ xyz = pos @ self.lat_real
+
+ # 3D plotting
+ fig = plt.figure(figsize=figsize)
+ if perspective_axes:
+ ax = fig.add_subplot(projection="3d", elev=el, azim=az)
+ else:
+ ax = fig.add_subplot(projection="3d", elev=el, azim=az, proj_type="ortho")
+
+ # unit cell
+ p = np.vstack([[0, 0, 0], u, u + v, v, w, u + w, u + v + w, v + w])
+ p = p[:, [1, 0, 2]] # Reorder cell boundaries
+
+ f = np.array(
+ [
+ [0, 1, 2, 3],
+ [4, 5, 6, 7],
+ [0, 1, 5, 4],
+ [2, 3, 7, 6],
+ [0, 3, 7, 4],
+ [1, 2, 6, 5],
+ ]
+ )
+
+ # ax.plot3D(xline, yline, zline, 'gray')
+ pc = art3d.Poly3DCollection(
+ p[f],
+ facecolors=[0, 0.7, 1],
+ edgecolor=[0, 0, 0],
+ linewidth=2,
+ alpha=0.2,
+ )
+ ax.add_collection(pc)
+
+ # atoms
+ ID_all = np.unique(ID)
+ for ID_plot in ID_all:
+ sub = ID == ID_plot
+ ax.scatter(
+ xs=xyz[sub, 1], # + d[0],
+ ys=xyz[sub, 0], # + d[1],
+ zs=xyz[sub, 2], # + d[2],
+ s=size_marker,
+ linewidth=2,
+ facecolors=atomic_colors(ID_plot),
+ edgecolor=[0, 0, 0],
+ )
+
+ # plot limit
+ if plot_limit is None:
+ plot_limit = np.array(
+ [
+ [np.min(p[:, 0]), np.min(p[:, 1]), np.min(p[:, 2])],
+ [np.max(p[:, 0]), np.max(p[:, 1]), np.max(p[:, 2])],
+ ]
+ )
+ plot_limit = (plot_limit - np.mean(plot_limit, axis=0)) * 1.1 + np.mean(
+ plot_limit, axis=0
+ )
+
+ # appearance
+ ax.invert_yaxis()
+ if show_axes is False:
+ ax.set_axis_off()
+ ax.axes.set_xlim3d(left=plot_limit[0, 1], right=plot_limit[1, 1])
+ ax.axes.set_ylim3d(bottom=plot_limit[0, 0], top=plot_limit[1, 0])
+ ax.axes.set_zlim3d(bottom=plot_limit[0, 2], top=plot_limit[1, 2])
+ # ax.set_box_aspect((1, 1, 1))
+ axisEqual3D(ax)
+
+ if camera_dist is not None:
+ ax.dist = camera_dist
+
+ plt.show()
+
+ if returnfig:
+ return fig, ax
+
+
+def plot_structure_factors(
+ self,
+ orientation_matrix: Optional[np.ndarray] = None,
+ zone_axis_lattice: Optional[np.ndarray] = None,
+ proj_x_lattice: Optional[np.ndarray] = None,
+ zone_axis_cartesian: Optional[np.ndarray] = None,
+ proj_x_cartesian: Optional[np.ndarray] = None,
+ scale_markers: float = 1e3,
+ plot_limit: Optional[Union[list, tuple, np.ndarray]] = None,
+ camera_dist: Optional[float] = None,
+ show_axes: bool = True,
+ perspective_axes: bool = True,
+ figsize: Union[list, tuple, np.ndarray] = (8, 8),
+ returnfig: bool = False,
+):
+ """
+ 3D scatter plot of the structure factors using magnitude^2, i.e. intensity.
+
+ Args:
+ orientation_matrix (array): (3,3) orientation matrix, where columns represent projection directions.
+ zone_axis_lattice (array): (3,) projection direction in lattice indices
+ proj_x_lattice (array): (3,) x-axis direction in lattice indices
+ zone_axis_cartesian (array): (3,) cartesian projection direction
+ proj_x_cartesian (array): (3,) cartesian projection direction
+ scale_markers (float): size scaling for markers
+ plot_limit (float): x y z plot limits, default is [-1 1]*self.k_max
+ camera_dist (float): Move camera closer to the plot (relative to matplotlib default of 10)
+ show_axes (bool): Whether to plot axes or not.
+ perspective_axes (bool): Select either perspective (true) or orthogonal (false) axes
+ figsize (2 element float): size scaling of figure axes
+ returnfig (bool): set to True to return figure and axes handles
+
+ Returns:
+ fig, ax (optional) figure and axes handles
+ """
+
+ # projection directions
+ if orientation_matrix is None:
+ orientation_matrix = self.parse_orientation(
+ zone_axis_lattice, proj_x_lattice, zone_axis_cartesian, proj_x_cartesian
+ )
+
+ # matplotlib camera orientation
+ if np.abs(abs(orientation_matrix[2, 2]) - 1) < 1e-6:
+ el = 90.0 * np.sign(orientation_matrix[2, 2])
+ else:
+ el = np.rad2deg(
+ np.arctan(
+ orientation_matrix[2, 2]
+ / np.sqrt(orientation_matrix[0, 2] ** 2 + orientation_matrix[1, 2] ** 2)
+ )
+ )
+ az = np.rad2deg(np.arctan2(orientation_matrix[0, 2], orientation_matrix[1, 2]))
+
+ # TODO roll is not yet implemented in matplot version 3.4.3
+ # matplotlib x projection direction (i.e. estimate the roll angle)
+ # init_y = np.cross(proj_z,np.array([0,1e-6,0]))
+ # init_x = np.cross(init_y,proj_z)
+ # init_x = init_x / np.linalg.norm(init_x)
+ # init_y = init_y / np.linalg.norm(init_y)
+ # beta = np.vstack((init_x,init_y)) @ proj_x[:,None]
+ # alpha = np.rad2deg(np.arctan2(beta[1], beta[0]))
+
+ # 3D plotting
+ fig = plt.figure(figsize=figsize)
+ if perspective_axes:
+ ax = fig.add_subplot(projection="3d", elev=el, azim=az)
+ else:
+ ax = fig.add_subplot(projection="3d", elev=el, azim=az, proj_type="ortho")
+
+ # plot all structure factor points
+ ax.scatter(
+ xs=self.g_vec_all[1, :],
+ ys=self.g_vec_all[0, :],
+ zs=self.g_vec_all[2, :],
+ s=scale_markers * self.struct_factors_int,
+ )
+
+ # axes limits
+ if plot_limit is None:
+ plot_limit = self.k_max * 1.05
+
+ # appearance
+ ax.invert_yaxis()
+ if show_axes is False:
+ ax.set_axis_off()
+ ax.axes.set_xlim3d(left=-plot_limit, right=plot_limit)
+ ax.axes.set_ylim3d(bottom=-plot_limit, top=plot_limit)
+ ax.axes.set_zlim3d(bottom=-plot_limit, top=plot_limit)
+ axisEqual3D(ax)
+
+ if camera_dist is not None:
+ ax.dist = camera_dist
+
+ plt.show()
+
+ if returnfig:
+ return fig, ax
+
+
+def plot_scattering_intensity(
+ self,
+ k_min=0.0,
+ k_max=None,
+ k_step=0.001,
+ k_broadening=0.0,
+ k_power_scale=0.0,
+ int_power_scale=0.5,
+ int_scale=1.0,
+ remove_origin=True,
+ bragg_peaks=None,
+ bragg_k_power=0.0,
+ bragg_intensity_power=1.0,
+ bragg_k_broadening=0.005,
+ figsize: Union[list, tuple, np.ndarray] = (10, 4),
+ returnfig: bool = False,
+):
+ """
+ 1D plot of the structure factors
+
+ Parameters
+ --------
+
+ k_min: float
+ min k value for profile range.
+ k_max: float
+ max k value for profile range.
+ k_step: float
+ Step size of k in profile range.
+ k_broadening: float
+ Broadening of simulated pattern.
+ k_power_scale: float
+ Scale SF intensities by k**k_power_scale.
+ int_power_scale: float
+ Scale SF intensities**int_power_scale.
+ int_scale: float
+ Scale output profile by this value.
+ remove_origin: bool
+ Remove origin from plot.
+ bragg_peaks: BraggVectors
+ Passed in bragg_peaks for comparison with simulated pattern.
+ bragg_k_power: float
+ bragg_peaks scaled by k**bragg_k_power.
+ bragg_intensity_power: float
+ bragg_peaks scaled by intensities**bragg_intensity_power.
+ bragg_k_broadening: float
+ Broadening applied to bragg_peaks.
+ figsize: list, tuple, np.ndarray
+ Figure size for plot.
+ returnfig (bool):
+ Return figure and axes handles if this is True.
+
+ Returns
+ --------
+ fig, ax (optional)
+ figure and axes handles
+ """
+
+ # k coordinates
+ if k_max is None:
+ k_max = self.k_max
+ k = np.arange(k_min, k_max + k_step, k_step)
+ k_num = k.shape[0]
+
+ # get discrete plot from structure factor amplitudes
+ int_sf_plot = calc_1D_profile(
+ k,
+ self.g_vec_leng,
+ (self.struct_factors_int**int_power_scale)
+ * (self.g_vec_leng**k_power_scale),
+ remove_origin=True,
+ k_broadening=k_broadening,
+ int_scale=int_scale,
+ )
+
+ # If Bragg peaks are passed in, compute 1D integral
+ if bragg_peaks is not None:
+ # set rotate and ellipse based on their availability
+ rotate = bragg_peaks.calibration.get_QR_rotation_degrees()
+ ellipse = bragg_peaks.calibration.get_ellipse()
+        rotate = rotate is not None
+        ellipse = ellipse is not None
+
+ # concatenate all peaks
+ bigpl = np.concatenate(
+ [
+ bragg_peaks.get_vectors(
+ rx,
+ ry,
+ center=True,
+ ellipse=ellipse,
+ pixel=True,
+ rotate=rotate,
+ ).data
+ for rx in range(bragg_peaks.shape[0])
+ for ry in range(bragg_peaks.shape[1])
+ ]
+ )
+
+ # get radial positions and intensity
+ qr = np.sqrt(bigpl["qx"] ** 2 + bigpl["qy"] ** 2)
+ int_meas = bigpl["intensity"]
+
+ # get discrete plot from structure factor amplitudes
+ int_exp = np.zeros_like(k)
+ k_px = (qr - k_min) / k_step
+ kf = np.floor(k_px).astype("int")
+ dk = k_px - kf
+
+ sub = np.logical_and(kf >= 0, kf < k_num)
+ int_exp = np.bincount(
+ np.floor(k_px[sub]).astype("int"),
+ weights=(1 - dk[sub]) * int_meas[sub],
+ minlength=k_num,
+ )
+ sub = np.logical_and(k_px >= -1, k_px < k_num - 1)
+ int_exp += np.bincount(
+ np.floor(k_px[sub] + 1).astype("int"),
+ weights=dk[sub] * int_meas[sub],
+ minlength=k_num,
+ )
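+
+        # (The two bincount passes above implement linear-interpolation binning:
+        #  a peak at fractional bin position k_px = 3.4 deposits 60% of its
+        #  intensity into bin 3 and 40% into bin 4, so the profile varies
+        #  smoothly as peaks shift between bins.)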
+
+ if bragg_k_broadening > 0.0:
+ int_exp = gaussian_filter(
+ int_exp, bragg_k_broadening / k_step, mode="constant"
+ )
+
+ int_exp_plot = (int_exp**bragg_intensity_power) * (k**bragg_k_power)
+ int_exp_plot /= np.max(int_exp_plot)
+
+ # Plotting
+ fig, ax = plt.subplots(1, 1, figsize=figsize)
+ if bragg_peaks is not None:
+ ax.fill_between(k, int_exp_plot, facecolor=(1.0, 0.0, 0.0, 0.2))
+ ax.plot(k, int_exp_plot, c=(1.0, 0.0, 0.0, 0.8), linewidth=2)
+ ax.fill_between(k, int_sf_plot, facecolor=(0.0, 0.0, 0.0, 0.2))
+ ax.plot(k, int_sf_plot, c=(0.0, 0.0, 0.0, 0.8), linewidth=2)
+ # Appearance
+ ax.set_xlabel("Scattering Vector k [1/A]", fontsize=14)
+ ax.set_yticks([])
+ ax.set_ylabel("Magnitude", fontsize=14)
+
+ if returnfig:
+ return fig, ax
+
+
+def plot_orientation_zones(
+ self,
+ azim_elev: Optional[Union[list, tuple, np.ndarray]] = None,
+ proj_dir_lattice: Optional[Union[list, tuple, np.ndarray]] = None,
+ proj_dir_cartesian: Optional[Union[list, tuple, np.ndarray]] = None,
+ tol_den=10,
+ marker_size: float = 20,
+ plot_limit: Union[list, tuple, np.ndarray] = np.array([-1.1, 1.1]),
+ figsize: Union[list, tuple, np.ndarray] = (8, 8),
+ returnfig: bool = False,
+):
+ """
+    3D scatter plot of the zone axes used for orientation matching.
+
+ Args:
+        azim_elev (array): az and el angles for plot. If no projection direction is
+                           given, the default is the mean vector of the rows of
+                           self.orientation_zone_axis_range.
+        proj_dir_lattice (array): (3,) projection direction in lattice
+        proj_dir_cartesian (array): (3,) projection direction in cartesian
+        tol_den (int): tolerance for rational index denominator
+        marker_size (float): size of markers
+        plot_limit (float): x y z plot limits, default is [-1.1, 1.1]
+ figsize (2 element float): size scaling of figure axes
+ returnfig (bool): set to True to return figure and axes handles
+
+ Returns:
+ fig, ax (optional) figure and axes handles
+ """
+
+ if azim_elev is not None:
+ proj_dir = azim_elev
+ elif proj_dir_lattice is not None:
+ proj_dir = self.lattice_to_cartesian(proj_dir_lattice)
+ elif proj_dir_cartesian is not None:
+ proj_dir = proj_dir_cartesian
+ else:
+ proj_dir = np.mean(self.orientation_zone_axis_range, axis=0)
+
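+    # Convert the projection direction to view angles: a 2 element input is
+    # interpreted as (elevation, azimuth) in degrees, while a 3 element input
+    # is interpreted as a cartesian viewing vector.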
+ if np.size(proj_dir) == 2:
+ el = proj_dir[0]
+ az = proj_dir[1]
+ elif np.size(proj_dir) == 3:
+ if proj_dir[0] == 0 and proj_dir[1] == 0:
+ el = 90 * np.sign(proj_dir[2])
+ else:
+ el = (
+ np.arctan(proj_dir[2] / np.sqrt(proj_dir[0] ** 2 + proj_dir[1] ** 2))
+ * 180
+ / np.pi
+ )
+ az = np.arctan2(proj_dir[1], proj_dir[0]) * 180 / np.pi
+ else:
+        raise Exception(
+            "Projection direction cannot contain "
+            + str(np.size(proj_dir))
+            + " elements"
+        )
+
+ # 3D plotting
+ fig = plt.figure(figsize=figsize)
+ ax = fig.add_subplot(projection="3d", elev=el, azim=90 - az)
+
+ # Sphere
+ # Make data
+ u = np.linspace(0, 2 * np.pi, 100)
+ v = np.linspace(0, np.pi, 100)
+ r = 0.95
+ x = r * np.outer(np.cos(u), np.sin(v))
+ y = r * np.outer(np.sin(u), np.sin(v))
+ z = r * np.outer(np.ones(np.size(u)), np.cos(v))
+ # Plot the surface
+ ax.plot_surface(
+ x,
+ y,
+ z,
+ edgecolor=None,
+ color=np.array([1.0, 0.8, 0.0]),
+ alpha=0.4,
+ antialiased=True,
+ )
+
+ # Lines
+ r = 0.951
+ t = np.linspace(0, 2 * np.pi, 181)
+ t0 = np.zeros((181,))
+
+ warnings.filterwarnings("ignore", module=r"matplotlib\..*")
+ line_params = {"linewidth": 2, "alpha": 0.1, "c": "k"}
+ for phi in np.arange(0, 180, 5):
+ ax.plot3D(
+ np.sin(phi * np.pi / 180) * np.cos(t) * r,
+ np.sin(phi * np.pi / 180) * np.sin(t) * r,
+ np.cos(phi * np.pi / 180) * r,
+ **line_params,
+ )
+
+ # plot zone axes
+ ax.scatter(
+ xs=self.orientation_vecs[:, 1],
+ ys=self.orientation_vecs[:, 0],
+ zs=self.orientation_vecs[:, 2],
+ s=marker_size,
+ )
+
+ # zone axis range labels
+ if np.abs(self.cell[5] - 120.0) < 1e-6:
+ label_0 = self.rational_ind(
+ self.lattice_to_hexagonal(
+ self.cartesian_to_lattice(self.orientation_zone_axis_range[0, :])
+ ),
+ tol_den=tol_den,
+ )
+ label_1 = self.rational_ind(
+ self.lattice_to_hexagonal(
+ self.cartesian_to_lattice(self.orientation_zone_axis_range[1, :])
+ ),
+ tol_den=tol_den,
+ )
+ label_2 = self.rational_ind(
+ self.lattice_to_hexagonal(
+ self.cartesian_to_lattice(self.orientation_zone_axis_range[2, :])
+ ),
+ tol_den=tol_den,
+ )
+ else:
+ label_0 = self.rational_ind(
+ self.cartesian_to_lattice(self.orientation_zone_axis_range[0, :]),
+ tol_den=tol_den,
+ )
+ label_1 = self.rational_ind(
+ self.cartesian_to_lattice(self.orientation_zone_axis_range[1, :]),
+ tol_den=tol_den,
+ )
+ label_2 = self.rational_ind(
+ self.cartesian_to_lattice(self.orientation_zone_axis_range[2, :]),
+ tol_den=tol_den,
+ )
+
+
+ if self.orientation_full is False and self.orientation_half is False:
+ inds = np.array(
+ [
+ 0,
+ self.orientation_num_zones - self.orientation_zone_axis_steps - 1,
+ self.orientation_num_zones - 1,
+ ]
+ )
+ else:
+ inds = np.array([0])
+
+ ax.scatter(
+ xs=self.orientation_vecs[inds, 1] * 1.02,
+ ys=self.orientation_vecs[inds, 0] * 1.02,
+ zs=self.orientation_vecs[inds, 2] * 1.02,
+ s=marker_size * 8,
+ linewidth=2,
+ marker="o",
+ edgecolors="r",
+ alpha=1,
+ zorder=10,
+ )
+
+ text_scale_pos = 1.2
+ text_params = {
+ "va": "center",
+ "family": "sans-serif",
+ "fontweight": "normal",
+ "color": "k",
+ "size": 16,
+ }
+ # 'ha': 'center',
+
+ ax.text(
+ self.orientation_vecs[inds[0], 1] * text_scale_pos,
+ self.orientation_vecs[inds[0], 0] * text_scale_pos,
+ self.orientation_vecs[inds[0], 2] * text_scale_pos,
+ label_0,
+ None,
+ zorder=11,
+ ha="center",
+ **text_params,
+ )
+ if self.orientation_full is False and self.orientation_half is False:
+ ax.text(
+ self.orientation_vecs[inds[1], 1] * text_scale_pos,
+ self.orientation_vecs[inds[1], 0] * text_scale_pos,
+ self.orientation_vecs[inds[1], 2] * text_scale_pos,
+ label_1,
+ None,
+ zorder=12,
+ ha="center",
+ **text_params,
+ )
+ ax.text(
+ self.orientation_vecs[inds[2], 1] * text_scale_pos,
+ self.orientation_vecs[inds[2], 0] * text_scale_pos,
+ self.orientation_vecs[inds[2], 2] * text_scale_pos,
+ label_2,
+ None,
+ zorder=13,
+ ha="center",
+ **text_params,
+ )
+
+
+ # axes limits
+ ax.axes.set_xlim3d(left=plot_limit[0], right=plot_limit[1])
+ ax.axes.set_ylim3d(bottom=plot_limit[0], top=plot_limit[1])
+ ax.axes.set_zlim3d(bottom=plot_limit[0], top=plot_limit[1])
+ ax.set_box_aspect((1, 1, 1))
+ ax.set_axis_off()
+ ax.view_init(elev=el, azim=90 - az)
+
+ plt.show()
+
+ if returnfig:
+ return fig, ax
+
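+# Example usage (assumes these functions are bound as Crystal methods, as in
+# py4DSTEM, and that `crystal` already has an orientation plan computed):
+#
+#     fig, ax = crystal.plot_orientation_zones(azim_elev=(45, 30), returnfig=True)
+#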
+
+def plot_orientation_plan(
+ self,
+ index_plot: int = 0,
+ zone_axis_lattice: Optional[np.ndarray] = None,
+ zone_axis_cartesian: Optional[np.ndarray] = None,
+ figsize: Union[list, tuple, np.ndarray] = (14, 6),
+ returnfig: bool = False,
+):
+ """
+ 3D scatter plot of the structure factors using magnitude^2,
+ i.e. intensity.
+
+ Args:
+ index_plot (int): which index slice to plot
+ zone_axis_plot (3 element float): which zone axis slice to plot
+ figsize (2 element float): size scaling of figure axes
+ returnfig (bool): set to True to return figure and axes handles
+
+ Returns:
+ fig, ax (optional) figure and axes handles
+ """
+
+ # Determine which index to plot if zone_axis_plot is specified
+ if zone_axis_lattice is not None or zone_axis_cartesian is not None:
+ orientation_matrix = self.parse_orientation(
+ zone_axis_lattice=zone_axis_lattice,
+ proj_x_lattice=None,
+ zone_axis_cartesian=zone_axis_cartesian,
+ proj_x_cartesian=None,
+ )
+ index_plot = np.argmin(
+ np.sum(np.abs(self.orientation_vecs - orientation_matrix[:, 2]), axis=1)
+ )
+
+ # initialize figure
+ fig, ax = plt.subplots(1, 2, figsize=figsize)
+
+ # Generate and plot diffraction pattern
+ k_x_y_range = np.array([1, 1]) * self.k_max * 1.2
+ bragg_peaks = self.generate_diffraction_pattern(
+ orientation_matrix=self.orientation_rotation_matrices[index_plot, :],
+ sigma_excitation_error=self.orientation_kernel_size / 3,
+ )
+
+ plot_diffraction_pattern(
+ bragg_peaks,
+ figsize=(figsize[1], figsize[1]),
+ plot_range_kx_ky=k_x_y_range,
+ # scale_markers=10,
+ # shift_labels=0.10,
+ input_fig_handle=[fig, ax],
+ )
+
+    # Plot orientation plan
+ if self.CUDA:
+ im_plot = (
+ np.real(
+ np.fft.ifft(self.orientation_ref[index_plot, :, :].get(), axis=1)
+ ).astype("float")
+ / self.orientation_ref_max
+ )
+ else:
+ im_plot = (
+ np.real(np.fft.ifft(self.orientation_ref[index_plot, :, :], axis=1)).astype(
+ "float"
+ )
+ / self.orientation_ref_max
+ )
+
+ # coordinates
+ x = self.orientation_gamma * 180 / np.pi
+ y = np.arange(np.size(self.orientation_shell_radii))
+ dx = (x[1] - x[0]) / 2.0
+ dy = (y[1] - y[0]) / 2.0
+ extent = [x[0] - dx, x[-1] + dx, y[-1] + dy, y[0] - dy]
+
+ im = ax[1].imshow(
+ im_plot,
+ cmap="inferno",
+ vmin=0.0,
+ vmax=0.5,
+ extent=extent,
+ aspect="auto",
+ interpolation="none",
+ )
+ fig.colorbar(im)
+ ax[1].xaxis.tick_top()
+ ax[1].set_xticks(np.arange(0, 360 + 90, 90))
+ ax[1].set_ylabel("Radial Index", size=20)
+
+ # Add text label
+ zone_axis_fit = self.orientation_vecs[index_plot, :]
+ zone_axis_fit = zone_axis_fit / np.linalg.norm(zone_axis_fit)
+ sub = np.abs(zone_axis_fit) > 0
+ scale = np.min(np.abs(zone_axis_fit[sub]))
+ if scale > 0.14:
+ zone_axis_fit = zone_axis_fit / scale
+
+ temp = np.round(zone_axis_fit, decimals=2)
+ ax[0].text(
+ -k_x_y_range[0] * 0.95,
+ -k_x_y_range[1] * 0.95,
+ "[" + str(temp[0]) + ", " + str(temp[1]) + ", " + str(temp[2]) + "]",
+ size=18,
+ va="top",
+ )
+
+ # plt.tight_layout()
+ plt.show()
+
+ if returnfig:
+ return fig, ax
+
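+# Example usage (assumes `crystal` is a Crystal with an orientation plan):
+#
+#     crystal.plot_orientation_plan(index_plot=10)
+#     crystal.plot_orientation_plan(zone_axis_lattice=[1, 1, 0])
+#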
+
+def plot_diffraction_pattern(
+ bragg_peaks: PointList,
+ bragg_peaks_compare: PointList = None,
+ scale_markers: float = 500,
+ scale_markers_compare: Optional[float] = None,
+ power_markers: float = 1,
+ plot_range_kx_ky: Optional[Union[list, tuple, np.ndarray]] = None,
+ add_labels: bool = True,
+ shift_labels: float = 0.08,
+ shift_marker: float = 0.005,
+ min_marker_size: float = 1e-6,
+ max_marker_size: float = 1000,
+ figsize: Union[list, tuple, np.ndarray] = (12, 6),
+ returnfig: bool = False,
+ input_fig_handle=None,
+):
+ """
+ 2D scatter plot of the Bragg peaks
+
+ Args:
+ bragg_peaks (PointList): numpy array containing ('qx', 'qy', 'intensity', 'h', 'k', 'l')
+ bragg_peaks_compare(PointList): numpy array containing ('qx', 'qy', 'intensity')
+ scale_markers (float): size scaling for markers
+ scale_markers_compare (float): size scaling for markers of comparison
+ power_markers (float): power law scaling for marks (default is 1, i.e. amplitude)
+ plot_range_kx_ky (float): 2 element numpy vector giving the plot range
+ add_labels (bool): flag to add hkl labels to peaks
+ min_marker_size (float): minimum marker size for the comparison peaks
+ max_marker_size (float): maximum marker size for the comparison peaks
+ figsize (2 element float): size scaling of figure axes
+ returnfig (bool): set to True to return figure and axes handles
+ input_fig_handle (fig,ax) Tuple containing a figure / axes handle for the plot.
+ """
+
+ # 2D plotting
+ if input_fig_handle is None:
+ fig, ax = plt.subplots(1, 1, figsize=figsize)
+ else:
+ fig = input_fig_handle[0]
+ ax_parent = input_fig_handle[1]
+ ax = ax_parent[0]
+
+ if power_markers == 2:
+ marker_size = scale_markers * bragg_peaks.data["intensity"]
+ else:
+ marker_size = scale_markers * (
+ bragg_peaks.data["intensity"] ** (power_markers / 2)
+ )
+
+ # Apply marker size limits to primary plot
+ marker_size = np.clip(marker_size, min_marker_size, max_marker_size)
+
+ if bragg_peaks_compare is None:
+ ax.scatter(
+ bragg_peaks.data["qy"], bragg_peaks.data["qx"], s=marker_size, facecolor="k"
+ )
+ else:
+ if scale_markers_compare is None:
+ scale_markers_compare = scale_markers
+
+ if power_markers == 2:
+ marker_size_compare = np.clip(
+ scale_markers_compare * bragg_peaks_compare.data["intensity"],
+ min_marker_size,
+ max_marker_size,
+ )
+ else:
+ marker_size_compare = np.clip(
+ scale_markers_compare
+ * (bragg_peaks_compare.data["intensity"] ** (power_markers / 2)),
+ min_marker_size,
+ max_marker_size,
+ )
+
+ ax.scatter(
+ bragg_peaks_compare.data["qy"],
+ bragg_peaks_compare.data["qx"],
+ s=marker_size_compare,
+ marker="o",
+ facecolor=[0.0, 0.7, 1.0],
+ )
+ ax.scatter(
+ bragg_peaks.data["qy"],
+ bragg_peaks.data["qx"],
+ s=marker_size,
+ marker="+",
+ facecolor="k",
+ )
+
+ ax.set_xlabel("$q_y$ [Å$^{-1}$]")
+ ax.set_ylabel("$q_x$ [Å$^{-1}$]")
+
+ if plot_range_kx_ky is not None:
+ plot_range_kx_ky = np.array(plot_range_kx_ky)
+ if plot_range_kx_ky.ndim == 0:
+ plot_range_kx_ky = np.array((plot_range_kx_ky, plot_range_kx_ky))
+ ax.set_xlim((-plot_range_kx_ky[0], plot_range_kx_ky[0]))
+ ax.set_ylim((-plot_range_kx_ky[1], plot_range_kx_ky[1]))
+ else:
+ k_range = 1.05 * np.sqrt(
+ np.max(bragg_peaks.data["qx"] ** 2 + bragg_peaks.data["qy"] ** 2)
+ )
+ ax.set_xlim((-k_range, k_range))
+ ax.set_ylim((-k_range, k_range))
+
+ ax.invert_yaxis()
+ ax.set_box_aspect(1)
+ ax.xaxis.tick_top()
+
+ # Labels for all peaks
+ if add_labels is True:
+ text_params = {
+ "ha": "center",
+ "va": "center",
+ "family": "sans-serif",
+ "fontweight": "normal",
+ "color": "r",
+ "size": 10,
+ }
+
+ def overline(x):
+ return str(x) if x >= 0 else (r"\overline{" + str(np.abs(x)) + "}")
+
+ for a0 in range(bragg_peaks.data.shape[0]):
+ h = bragg_peaks.data["h"][a0]
+ k = bragg_peaks.data["k"][a0]
+ l = bragg_peaks.data["l"][a0]
+
+ ax.text(
+ bragg_peaks.data["qy"][a0],
+ bragg_peaks.data["qx"][a0]
+ - shift_labels
+ - shift_marker * np.sqrt(marker_size[a0]),
+ "$" + overline(h) + overline(k) + overline(l) + "$",
+ **text_params,
+ )
+
+ # Force plot to have 1:1 aspect ratio
+ ax.set_aspect("equal")
+
+ if input_fig_handle is None:
+ plt.show()
+
+ if returnfig:
+ return fig, ax
+
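+# Example usage: overlay simulated peaks on measured ones (hypothetical names;
+# `bragg_peaks_sim` and `bragg_peaks_meas` are PointLists with 'qx', 'qy' and
+# 'intensity' fields, and the simulated peaks also carry 'h', 'k', 'l'):
+#
+#     plot_diffraction_pattern(
+#         bragg_peaks_sim,
+#         bragg_peaks_compare=bragg_peaks_meas,
+#         scale_markers=500,
+#     )
+#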
+
+def plot_orientation_maps(
+ self,
+ orientation_map=None,
+ orientation_ind: int = 0,
+ dir_in_plane_degrees: float = 0.0,
+ corr_range: np.ndarray = np.array([0, 5]),
+ corr_normalize: bool = True,
+    scale_legend: Optional[Union[list, tuple, np.ndarray]] = None,
+ figsize: Union[list, tuple, np.ndarray] = (16, 5),
+ figbound: Union[list, tuple, np.ndarray] = (0.01, 0.005),
+ show_axes: bool = True,
+ camera_dist=None,
+ plot_limit=None,
+ plot_layout=0,
+ swap_axes_xy_limits=False,
+ returnfig: bool = False,
+ progress_bar=False,
+):
+ """
+ Plot the orientation maps.
+
+ Args:
+ orientation_map (OrientationMap): Class containing orientation matrices, correlation values, etc.
+ Optional - can reference internally stored OrientationMap.
+ orientation_ind (int): Which orientation match to plot if num_matches > 1
+ dir_in_plane_degrees (float): In-plane angle to plot in degrees. Default is 0 / x-axis / vertical down.
+ corr_range (np.ndarray): Correlation intensity range for the plot
+ corr_normalize (bool): If true, set mean correlation to 1.
+ scale_legend (float): 2 elements, x and y scaling of legend panel
+ figsize (array): 2 elements defining figure size
+ figbound (array): 2 elements defining figure boundary
+        show_axes (bool): Flag setting whether orientation map axes are visible.
+ camera_dist (float): distance of camera from legend
+ plot_limit (array): 2x3 array defining plot boundaries of legend
+ plot_layout (int): subplot layout: 0 - 1 row, 3 col
+ 1 - 3 row, 1 col
+ swap_axes_xy_limits (bool): swap x and y boundaries for legend (not sure why we need this in some cases)
+ returnfig (bool): set to True to return figure and axes handles
+ progress_bar (bool): Enable progressbar when calculating orientation images.
+
+    Returns:
+        images_orientation (int): RGB images
+        fig, axs (handles): Figure and axes handles for the plots
+
+ NOTE:
+ Currently, no symmetry reduction. Therefore the x and y orientations
+ are going to be correct only for [001][011][111] orientation triangle.
+
+ """
+
+ # Inputs
+ if orientation_map is None:
+ orientation_map = self.orientation_map
+
+ # Color of the 3 corners
+ color_basis = np.array(
+ [
+ [1.0, 0.0, 0.0],
+ [0.0, 0.7, 0.0],
+ [0.0, 0.3, 1.0],
+ ]
+ )
+
+ # Generate reflection operators for symmetry reduction
+ A_ref = np.zeros((3, 3, 3))
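+    # Each A_ref[a0] is the Householder reflection I - 2 v v^T across one
+    # boundary plane of the zone axis triangle, where v is the unit normal of
+    # the plane spanned by two of the zone-axis range vectors.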
+ for a0 in range(3):
+ if a0 == 0:
+ v = np.cross(
+ self.orientation_zone_axis_range[1, :],
+ self.orientation_zone_axis_range[0, :],
+ )
+ elif a0 == 1:
+ v = np.cross(
+ self.orientation_zone_axis_range[2, :],
+ self.orientation_zone_axis_range[1, :],
+ )
+ elif a0 == 2:
+ v = np.cross(
+ self.orientation_zone_axis_range[0, :],
+ self.orientation_zone_axis_range[2, :],
+ )
+ v = v / np.linalg.norm(v)
+
+ A_ref[a0] = np.array(
+ [
+ [1 - 2 * v[0] ** 2, -2 * v[0] * v[1], -2 * v[0] * v[2]],
+ [-2 * v[0] * v[1], 1 - 2 * v[1] ** 2, -2 * v[1] * v[2]],
+ [-2 * v[0] * v[2], -2 * v[1] * v[2], 1 - 2 * v[2] ** 2],
+ ]
+ )
+
+ # init
+ dir_in_plane = np.deg2rad(dir_in_plane_degrees)
+ ct = np.cos(dir_in_plane)
+ st = np.sin(dir_in_plane)
+ basis_x = np.zeros((orientation_map.num_x, orientation_map.num_y, 3))
+ basis_y = np.zeros((orientation_map.num_x, orientation_map.num_y, 3))
+ basis_z = np.zeros((orientation_map.num_x, orientation_map.num_y, 3))
+ rgb_x = np.zeros((orientation_map.num_x, orientation_map.num_y, 3))
+ rgb_z = np.zeros((orientation_map.num_x, orientation_map.num_y, 3))
+
+ # Basis for fitting orientation projections
+ A = np.linalg.inv(self.orientation_zone_axis_range).T
+
+ # Correlation masking
+ corr = orientation_map.corr[:, :, orientation_ind]
+ if corr_normalize:
+ corr = corr / np.mean(corr)
+ mask = (corr - corr_range[0]) / (corr_range[1] - corr_range[0])
+ mask = np.clip(mask, 0, 1)
+
+ # Generate images
+ for rx, ry in tqdmnd(
+ orientation_map.num_x,
+ orientation_map.num_y,
+ desc="Generating orientation maps",
+ unit=" PointList",
+ disable=not progress_bar,
+ ):
+ if self.pymatgen_available:
+ basis_x[rx, ry, :] = (
+ A @ orientation_map.family[rx, ry, orientation_ind, :, 0]
+ )
+ basis_y[rx, ry, :] = (
+ A @ orientation_map.family[rx, ry, orientation_ind, :, 1]
+ )
+ basis_x[rx, ry, :] = basis_x[rx, ry, :] * ct + basis_y[rx, ry, :] * st
+
+ basis_z[rx, ry, :] = (
+ A @ orientation_map.family[rx, ry, orientation_ind, :, 2]
+ )
+ else:
+ basis_z[rx, ry, :] = (
+ A @ orientation_map.matrix[rx, ry, orientation_ind, :, 2]
+ )
+ basis_x = np.clip(basis_x, 0, 1)
+ basis_z = np.clip(basis_z, 0, 1)
+
+ # Convert to RGB images
+ basis_x_max = np.max(basis_x, axis=2)
+ sub = basis_x_max > 0
+ basis_x_scale = basis_x * mask[:, :, None]
+ for a0 in range(3):
+ basis_x_scale[:, :, a0][sub] /= basis_x_max[sub]
+ basis_x_scale[:, :, a0][np.logical_not(sub)] = 0
+ rgb_x = (
+ basis_x_scale[:, :, 0][:, :, None] * color_basis[0, :][None, None, :]
+ + basis_x_scale[:, :, 1][:, :, None] * color_basis[1, :][None, None, :]
+ + basis_x_scale[:, :, 2][:, :, None] * color_basis[2, :][None, None, :]
+ )
+
+ basis_z_max = np.max(basis_z, axis=2)
+ sub = basis_z_max > 0
+ basis_z_scale = basis_z * mask[:, :, None]
+ for a0 in range(3):
+ basis_z_scale[:, :, a0][sub] /= basis_z_max[sub]
+ basis_z_scale[:, :, a0][np.logical_not(sub)] = 0
+ rgb_z = (
+ basis_z_scale[:, :, 0][:, :, None] * color_basis[0, :][None, None, :]
+ + basis_z_scale[:, :, 1][:, :, None] * color_basis[1, :][None, None, :]
+ + basis_z_scale[:, :, 2][:, :, None] * color_basis[2, :][None, None, :]
+ )
+
+ # Legend init
+ # projection vector
+ cam_dir = np.mean(self.orientation_zone_axis_range, axis=0)
+ cam_dir = cam_dir / np.linalg.norm(cam_dir)
+ az = np.rad2deg(np.arctan2(cam_dir[0], cam_dir[1]))
+ # el = np.rad2deg(np.arccos(cam_dir[2]))
+ el = np.rad2deg(np.arcsin(cam_dir[2]))
+ # coloring
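+    # w0, w1, w2 are barycentric-style weights of each sampled zone axis
+    # within the orientation triangle; each corner maps to one basis color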
+ wx = self.orientation_inds[:, 0] / self.orientation_zone_axis_steps
+ wy = self.orientation_inds[:, 1] / self.orientation_zone_axis_steps
+ w0 = 1 - wx - 0.5 * wy
+ w1 = wx - wy
+ w2 = wy
+ # w0 = 1 - w1/2 - w2/2
+ w_scale = np.maximum(np.maximum(w0, w1), w2)
+ w_scale = 1 - np.exp(-w_scale)
+ w0 = w0 / w_scale
+ w1 = w1 / w_scale
+ w2 = w2 / w_scale
+ rgb_legend = np.clip(
+ w0[:, None] * color_basis[0, :]
+ + w1[:, None] * color_basis[1, :]
+ + w2[:, None] * color_basis[2, :],
+ 0,
+ 1,
+ )
+
+ if np.abs(self.cell[5] - 120.0) < 1e-6:
+ label_0 = self.rational_ind(
+ self.lattice_to_hexagonal(
+ self.cartesian_to_lattice(self.orientation_zone_axis_range[0, :])
+ )
+ )
+ label_1 = self.rational_ind(
+ self.lattice_to_hexagonal(
+ self.cartesian_to_lattice(self.orientation_zone_axis_range[1, :])
+ )
+ )
+ label_2 = self.rational_ind(
+ self.lattice_to_hexagonal(
+ self.cartesian_to_lattice(self.orientation_zone_axis_range[2, :])
+ )
+ )
+ else:
+ label_0 = self.rational_ind(
+ self.cartesian_to_lattice(self.orientation_zone_axis_range[0, :])
+ )
+ label_1 = self.rational_ind(
+ self.cartesian_to_lattice(self.orientation_zone_axis_range[1, :])
+ )
+ label_2 = self.rational_ind(
+ self.cartesian_to_lattice(self.orientation_zone_axis_range[2, :])
+ )
+
+ inds_legend = np.array(
+ [
+ 0,
+ self.orientation_num_zones - self.orientation_zone_axis_steps - 1,
+ self.orientation_num_zones - 1,
+ ]
+ )
+
+ # Determine if lattice direction labels should be left-right
+ # or right-left aligned.
+ v0 = self.orientation_vecs[inds_legend[0], :]
+ v1 = self.orientation_vecs[inds_legend[1], :]
+ v2 = self.orientation_vecs[inds_legend[2], :]
+ n = np.cross(v0, cam_dir)
+ if np.sum(v1 * n) < np.sum(v2 * n):
+ ha_1 = "left"
+ ha_2 = "right"
+ else:
+ ha_1 = "right"
+ ha_2 = "left"
+
+ # plotting frame
+    fig = plt.figure(figsize=figsize)
+    if plot_layout == 0:
+        ax_x = fig.add_axes([0.0 + figbound[0], 0.0, 0.4 - 2 * figbound[0], 1.0])
+        ax_z = fig.add_axes([0.4 + figbound[0], 0.0, 0.4 - 2 * figbound[0], 1.0])
+        ax_l = fig.add_axes(
+            [0.8 + figbound[0], 0.0, 0.2 - 2 * figbound[0], 1.0],
+            projection="3d",
+            elev=el,
+            azim=az,
+        )
+    elif plot_layout == 1:
+        ax_x = fig.add_axes([0.0, 0.0 + figbound[0], 1.0, 0.4 - 2 * figbound[0]])
+        ax_z = fig.add_axes([0.0, 0.4 + figbound[0], 1.0, 0.4 - 2 * figbound[0]])
+        ax_l = fig.add_axes(
+            [0.0, 0.8 + figbound[0], 1.0, 0.2 - 2 * figbound[0]],
+            projection="3d",
+            elev=el,
+            azim=az,
+        )
+
+ # orientation images
+ if self.pymatgen_available:
+ ax_x.imshow(rgb_x)
+ else:
+ ax_x.imshow(np.ones_like(rgb_z))
+ ax_x.text(
+ rgb_z.shape[1] / 2,
+ rgb_z.shape[0] / 2 - 10,
+ "in-plane orientation",
+ fontsize=14,
+ horizontalalignment="center",
+ )
+ ax_x.text(
+ rgb_z.shape[1] / 2,
+ rgb_z.shape[0] / 2 + 0,
+ "for this crystal system",
+ fontsize=14,
+ horizontalalignment="center",
+ )
+ ax_x.text(
+ rgb_z.shape[1] / 2,
+ rgb_z.shape[0] / 2 + 10,
+ "requires pymatgen",
+ fontsize=14,
+ horizontalalignment="center",
+ )
+ ax_z.imshow(rgb_z)
+
+ # Labels for orientation images
+ ax_x.set_title("In-Plane Orientation", size=20)
+ ax_z.set_title("Out-of-Plane Orientation", size=20)
+ if show_axes is False:
+ ax_x.axis("off")
+ ax_z.axis("off")
+
+ # Triangulate faces
+ p = self.orientation_vecs[:, (1, 0, 2)]
+ tri = mtri.Triangulation(
+ self.orientation_inds[:, 1] - self.orientation_inds[:, 0] * 1e-3,
+ self.orientation_inds[:, 0] - self.orientation_inds[:, 1] * 1e-3,
+ )
+ # convert rgb values of pixels to faces
+ rgb_faces = (
+ rgb_legend[tri.triangles[:, 0], :]
+ + rgb_legend[tri.triangles[:, 1], :]
+ + rgb_legend[tri.triangles[:, 2], :]
+ ) / 3
+ # Add triangulated surface plot to axes
+ pc = art3d.Poly3DCollection(
+ p[tri.triangles],
+ facecolors=rgb_faces,
+ alpha=1,
+ )
+ pc.set_antialiased(False)
+ ax_l.add_collection(pc)
+
+ if plot_limit is None:
+ plot_limit = np.array(
+ [
+ [np.min(p[:, 0]), np.min(p[:, 1]), np.min(p[:, 2])],
+ [np.max(p[:, 0]), np.max(p[:, 1]), np.max(p[:, 2])],
+ ]
+ )
+        # expand the legend plot limits slightly beyond the data range
+        plot_limit[:, 0] = (
+            plot_limit[:, 0] - np.mean(plot_limit[:, 0])
+        ) * 1.5 + np.mean(plot_limit[:, 0])
+        plot_limit[:, 1] = (
+            plot_limit[:, 1] - np.mean(plot_limit[:, 1])
+        ) * 1.5 + np.mean(plot_limit[:, 1])
+        plot_limit[:, 2] = (
+            plot_limit[:, 2] - np.mean(plot_limit[:, 2])
+        ) * 1.1 + np.mean(plot_limit[:, 2])
+
+ # ax_l.view_init(elev=el, azim=az)
+ # Appearance
+ ax_l.invert_yaxis()
+ if swap_axes_xy_limits:
+ ax_l.axes.set_xlim3d(left=plot_limit[0, 0], right=plot_limit[1, 0])
+ ax_l.axes.set_ylim3d(bottom=plot_limit[0, 1], top=plot_limit[1, 1])
+ ax_l.axes.set_zlim3d(bottom=plot_limit[0, 2], top=plot_limit[1, 2])
+ else:
+ ax_l.axes.set_xlim3d(left=plot_limit[0, 1], right=plot_limit[1, 1])
+ ax_l.axes.set_ylim3d(bottom=plot_limit[0, 0], top=plot_limit[1, 0])
+ ax_l.axes.set_zlim3d(bottom=plot_limit[0, 2], top=plot_limit[1, 2])
+ axisEqual3D(ax_l)
+ if camera_dist is not None:
+ ax_l.dist = camera_dist
+ ax_l.axis("off")
+
+    # Add text labels
+    text_scale_pos = 0.1
+    text_params = {
+        "va": "center",
+        "family": "sans-serif",
+        "fontweight": "normal",
+        "color": "k",
+        "size": 14,
+    }
+    format_labels = "{0:.2g}"
+    # 3 index labels for non-hexagonal cells, 4 index labels for hexagonal cells
+    num_inds = 3 if np.abs(self.cell[5] - 120.0) > 1e-6 else 4
+    labels = (label_0, label_1, label_2)
+    h_align = ("center", ha_1, ha_2)
+    for a0 in range(3):
+        vec = self.orientation_vecs[inds_legend[a0], :] - cam_dir
+        vec = vec / np.linalg.norm(vec)
+        label_text = (
+            "["
+            + " ".join(format_labels.format(ind) for ind in labels[a0][:num_inds])
+            + "]"
+        )
+        ax_l.text(
+            self.orientation_vecs[inds_legend[a0], 1] + vec[1] * text_scale_pos,
+            self.orientation_vecs[inds_legend[a0], 0] + vec[0] * text_scale_pos,
+            self.orientation_vecs[inds_legend[a0], 2] + vec[2] * text_scale_pos,
+            label_text,
+            None,
+            zorder=11 + a0,
+            ha=h_align[a0],
+            **text_params,
+        )
+
+ plt.show()
+
+ images_orientation = np.zeros((orientation_map.num_x, orientation_map.num_y, 3, 2))
+ if self.pymatgen_available:
+ images_orientation[:, :, :, 0] = rgb_x
+ images_orientation[:, :, :, 1] = rgb_z
+
+ if returnfig:
+ ax = [ax_x, ax_z, ax_l]
+ return images_orientation, fig, ax
+ else:
+ return images_orientation
+
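+# Example usage (assumes these functions are bound as Crystal methods, as in
+# py4DSTEM, and that self.orientation_map has been populated by an ACOM fit):
+#
+#     images = crystal.plot_orientation_maps(corr_range=np.array([0, 5]))
+#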
+
+def plot_fiber_orientation_maps(
+ self,
+ orientation_map,
+ orientation_ind: int = 0,
+ symmetry_order: int = None,
+ symmetry_mirror: bool = False,
+ dir_in_plane_degrees: float = 0.0,
+ corr_range: np.ndarray = np.array([0, 2]),
+ corr_normalize: bool = True,
+ show_axes: bool = True,
+ medfilt_size: int = None,
+ cmap_out_of_plane: str = "plasma",
+ leg_size: int = 200,
+ figsize: Union[list, tuple, np.ndarray] = (12, 8),
+ figbound: Union[list, tuple, np.ndarray] = (0.005, 0.04),
+ returnfig: bool = False,
+):
+ """
+ Generate and plot the orientation maps from fiber texture plots.
+
+    Args:
+        orientation_map (OrientationMap): Class containing orientation matrices, correlation values, etc.
+        orientation_ind (int): Which orientation match to plot if num_matches > 1
+        symmetry_order (int): In-plane rotational symmetry order. Default is computed
+            from the fiber texture angles as 360 / orientation_fiber_angles[1].
+        symmetry_mirror (bool): Set to True if the in-plane orientation also has a mirror symmetry.
+        dir_in_plane_degrees (float): Reference in-plane angle (degrees). Default is 0 / x-axis / vertical down.
+        corr_range (np.ndarray): Correlation intensity range for the plot
+        corr_normalize (bool): If true, set mean correlation to 1.
+        show_axes (bool): Flag setting whether orientation map axes are visible.
+        medfilt_size (int): Size of the median filter applied to the correlation and orientation signals.
+        cmap_out_of_plane (str): Name of the matplotlib colormap used for the out-of-plane tilt map.
+        leg_size (int): Size in pixels of the legend images.
+        figsize (array): 2 elements defining figure size
+        figbound (array): 2 elements defining figure boundary
+        returnfig (bool): set to True to return figure and axes handles
+
+    Returns:
+        images_orientation (int): RGB images
+        fig, axs (handles): Figure and axes handles for the plots
+
+ NOTE:
+ Currently, no symmetry reduction. Therefore the x and y orientations
+ are going to be correct only for [001][011][111] orientation triangle.
+
+ """
+
+ # angular colormap
+ basis = np.array(
+ [
+ [1.0, 0.2, 0.2],
+ [1.0, 0.7, 0.0],
+ [0.0, 0.8, 0.0],
+ [0.0, 0.8, 1.0],
+ [0.2, 0.4, 1.0],
+ [0.9, 0.2, 1.0],
+ ]
+ )
+
+ # Correlation masking
+ corr = orientation_map.corr[:, :, orientation_ind]
+ if corr_normalize:
+ corr = corr / np.mean(corr)
+ if medfilt_size is not None:
+ corr = medfilt(corr, medfilt_size)
+ mask = (corr - corr_range[0]) / (corr_range[1] - corr_range[0])
+ mask = np.clip(mask, 0, 1)
+
+ # Get symmetry
+ if symmetry_order is None:
+ symmetry_order = np.round(360.0 / self.orientation_fiber_angles[1])
+ elif symmetry_mirror:
+ symmetry_order = 2 * symmetry_order
+
+ # Generate out-of-plane orientation signal
+ ang_op = orientation_map.angles[:, :, orientation_ind, 1]
+ if self.orientation_fiber_angles[0] > 0:
+ sig_op = ang_op / np.deg2rad(self.orientation_fiber_angles[0])
+ else:
+ sig_op = ang_op
+ if medfilt_size is not None:
+ sig_op = medfilt(sig_op, medfilt_size)
+
+ # Generate in-plane orientation signal
+ ang_ip = (
+ orientation_map.angles[:, :, orientation_ind, 0]
+ + orientation_map.angles[:, :, orientation_ind, 2]
+ )
+ sig_ip = np.mod((symmetry_order / (2 * np.pi)) * ang_ip, 1.0)
+ if symmetry_mirror:
+ sub = np.sin((symmetry_order / 2) * ang_ip) < 0
+ sig_ip[sub] = np.mod(-sig_ip[sub], 1)
+ sig_ip = np.mod(sig_ip - (dir_in_plane_degrees / 360.0 * symmetry_order), 1.0)
+ if medfilt_size is not None:
+ sig_ip = medfilt(sig_ip, medfilt_size)
+
+ # out-of-plane RGB images
+ cmap = plt.get_cmap(cmap_out_of_plane)
+ im_op = cmap(sig_op)
+ im_op = np.delete(im_op, 3, axis=2)
+ im_op = im_op * mask[:, :, None]
+
+ # in-plane image
+ im_ip = np.zeros((sig_ip.shape[0], sig_ip.shape[1], 3))
+ for a0 in range(basis.shape[0]):
+ weight = np.maximum(
+ 1
+ - np.abs(np.mod(sig_ip - a0 / basis.shape[0] + 0.5, 1.0) - 0.5)
+ * basis.shape[0],
+ 0,
+ )
+ im_ip += basis[a0, :][None, None, :] * weight[:, :, None]
+ im_ip = np.clip(im_ip, 0, 1)
+ im_ip = im_ip * mask[:, :, None]
+
+ # draw in-plane legends
+ r = np.arange(leg_size) - leg_size / 2 + 0.5
+ ya, xa = np.meshgrid(r, r)
+ ra = np.sqrt(xa**2 + ya**2)
+ ta = np.arctan2(ya, xa)
+ sig_leg = np.mod((symmetry_order / (2 * np.pi)) * ta, 1.0)
+ if symmetry_mirror:
+ sub = np.sin((symmetry_order / 2) * ta) < 0
+ sig_leg[sub] = np.mod(-sig_leg[sub], 1)
+ im_ip_leg = np.zeros((leg_size, leg_size, 3))
+ for a0 in range(basis.shape[0]):
+ weight = np.maximum(
+ 1
+ - np.abs(np.mod(sig_leg - a0 / basis.shape[0] + 0.5, 1.0) - 0.5)
+ * basis.shape[0],
+ 0,
+ )
+ im_ip_leg += basis[a0, :][None, None, :] * weight[:, :, None]
+ im_ip_leg = np.clip(im_ip_leg, 0, 1)
+ mask = np.clip(leg_size / 2 - ra + 0.5, 0, 1) * np.clip(
+ ra - leg_size / 4 + 0.5, 0, 1
+ )
+ im_ip_leg = im_ip_leg * mask[:, :, None] + (1 - mask)[:, :, None]
+
+
+ # plotting frame
+    fig = plt.figure(figsize=figsize)
+
+    ax_ip = fig.add_axes(
+        [
+            0.0 + figbound[0],
+            0.25 + figbound[1],
+            0.5 - 2 * figbound[0],
+            0.75 - figbound[1],
+        ]
+    )
+    ax_op = fig.add_axes(
+        [
+            0.5 + figbound[0],
+            0.25 + figbound[1],
+            0.5 - 2 * figbound[0],
+            0.75 - figbound[1],
+        ]
+    )
+
+    ax_ip_l = fig.add_axes(
+        [
+            0.1 + figbound[0],
+            0.0 + figbound[1],
+            0.3 - 2 * figbound[0],
+            0.25 - figbound[1],
+        ]
+    )
+    ax_op_l = fig.add_axes(
+        [
+            0.6 + figbound[0],
+            0.0 + figbound[1],
+            0.3 - 2 * figbound[0],
+            0.25 - figbound[1],
+        ]
+    )
+
+ # in-plane
+ ax_ip.imshow(im_ip)
+ ax_ip.set_title("In-Plane Rotation", size=16)
+
+ # out of plane
+ if self.orientation_fiber_angles[0] > 0:
+ ax_op.imshow(im_op)
+ ax_op.set_title("Out-of-Plane Tilt", size=16)
+ else:
+ ax_op.axis("off")
+
+ if show_axes is False:
+ ax_ip.axis("off")
+ ax_op.axis("off")
+
+ # in plane legend
+ ax_ip_l.imshow(im_ip_leg)
+ ax_ip_l.set_axis_off()
+
+ # out of plane legend
+ if self.orientation_fiber_angles[0] > 0:
+ t = np.tile(
+ np.linspace(0, 1, leg_size, endpoint=True),
+ (np.round(leg_size / 10).astype("int"), 1),
+ )
+ im_op_leg = cmap(t)
+ im_op_leg = np.delete(im_op_leg, 3, axis=2)
+ ax_op_l.imshow(im_op_leg)
+ ax_op_l.set_yticks([])
+
+ ticks = [
+ np.round(leg_size * 0.0),
+ np.round(leg_size * 0.25),
+ np.round(leg_size * 0.5),
+ np.round(leg_size * 0.75),
+ np.round(leg_size * 1.0),
+ ]
+ labels = [
+ str(np.round(self.orientation_fiber_angles[0] * 0.00)) + "$\\degree$",
+ str(np.round(self.orientation_fiber_angles[0] * 0.25)) + "$\\degree$",
+ str(np.round(self.orientation_fiber_angles[0] * 0.50)) + "$\\degree$",
+ str(np.round(self.orientation_fiber_angles[0] * 0.75)) + "$\\degree$",
+ str(np.round(self.orientation_fiber_angles[0] * 1.00)) + "$\\degree$",
+ ]
+ ax_op_l.set_xticks(ticks)
+ ax_op_l.set_xticklabels(labels)
+ else:
+ ax_op_l.axis("off")
+
+ images_orientation = np.zeros((orientation_map.num_x, orientation_map.num_y, 3, 2))
+ images_orientation[:, :, :, 0] = im_ip
+ images_orientation[:, :, :, 1] = im_op
+
+ if returnfig:
+ ax = [ax_ip, ax_op, ax_ip_l, ax_op_l]
+ return images_orientation, fig, ax
+ else:
+ return images_orientation
+
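+# Example usage (hypothetical values; requires an OrientationMap generated
+# from a fiber texture orientation plan):
+#
+#     images = crystal.plot_fiber_orientation_maps(
+#         orientation_map,
+#         symmetry_order=6,
+#         medfilt_size=3,
+#     )
+#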
+
+def plot_clusters(
+ self,
+ area_min=2,
+ outline_grains=True,
+ outline_thickness=1,
+ fill_grains=0.25,
+ smooth_grains=1.0,
+ cmap="viridis",
+ figsize=(8, 8),
+ returnfig=False,
+):
+ """
+ Plot the clusters as an image.
+
+ Parameters
+ --------
+ area_min: int (optional)
+ Min cluster size to include, in units of probe positions.
+ outline_grains: bool (optional)
+ Set to True to draw grains with outlines
+    outline_thickness: int (optional)
+        Thickness of the grain outlines, in pixels.
+    fill_grains: float (optional)
+        Intensity value used to fill the outlined grains.
+    smooth_grains: float (optional)
+        Grain boundaries are smoothed by this value in pixels.
+ figsize: tuple
+ Size of the figure panel
+ returnfig: bool
+ Setting this to true returns the figure and axis handles
+
+ Returns
+ --------
+ fig, ax (optional)
+ Figure and axes handles
+
+ """
+
+ # init
+ im_plot = np.zeros(
+ (
+ self.orientation_map.num_x,
+ self.orientation_map.num_y,
+ )
+ )
+ im_grain = np.zeros(
+ (
+ self.orientation_map.num_x,
+ self.orientation_map.num_y,
+ ),
+ dtype="bool",
+ )
+
+ # make plotting image
+
+ for a0 in range(self.cluster_sizes.shape[0]):
+ if self.cluster_sizes[a0] >= area_min:
+ if outline_grains:
+ im_grain[:] = False
+ im_grain[
+ self.cluster_inds[a0][0, :],
+ self.cluster_inds[a0][1, :],
+ ] = True
+
+ im_dist = distance_transform_edt(
+ erosion(
+ np.invert(im_grain), footprint=np.ones((3, 3), dtype="bool")
+ )
+ ) - distance_transform_edt(im_grain)
+ im_dist = gaussian_filter(im_dist, sigma=smooth_grains, mode="nearest")
+ im_add = np.exp(im_dist**2 / (-0.5 * outline_thickness**2))
+
+ if fill_grains > 0:
+ im_dist = distance_transform_edt(
+ erosion(
+ np.invert(im_grain), footprint=np.ones((3, 3), dtype="bool")
+ )
+ )
+ im_dist = gaussian_filter(
+ im_dist, sigma=smooth_grains, mode="nearest"
+ )
+ im_add += fill_grains * np.exp(
+ im_dist**2 / (-0.5 * outline_thickness**2)
+ )
+
+                im_plot += im_add
+ else:
+ # xg,yg = np.unravel_index(self.cluster_inds[a0], im_plot.shape)
+ im_grain[:] = False
+ im_grain[
+ self.cluster_inds[a0][0, :],
+ self.cluster_inds[a0][1, :],
+ ] = True
+ im_plot += gaussian_filter(
+ im_grain.astype("float"), sigma=smooth_grains, mode="nearest"
+ )
+
+ if outline_grains:
+ im_plot = np.clip(im_plot, 0, 2)
+
+ # plotting
+ fig, ax = plt.subplots(figsize=figsize)
+    ax.imshow(
+        im_plot,
+        cmap=cmap,
+    )
+
+    if returnfig:
+        return fig, ax
+
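+# Example usage (assumes cluster data, i.e. self.cluster_sizes and
+# self.cluster_inds, has been generated by a prior clustering step):
+#
+#     crystal.plot_clusters(area_min=4, outline_grains=True)
+#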
+
+def plot_cluster_size(
+ self,
+ area_min=None,
+ area_max=None,
+ area_step=1,
+ weight_intensity=False,
+ pixel_area=1.0,
+ pixel_area_units="px^2",
+ figsize=(8, 6),
+ returnfig=False,
+):
+ """
+ Plot the cluster sizes
+
+ Parameters
+ --------
+ area_min: int (optional)
+ Min area to include in pixels^2
+ area_max: int (optional)
+ Max area bin in pixels^2
+ area_step: int (optional)
+ Step size of the histogram bin in pixels^2
+ weight_intensity: bool
+ Weight histogram by the peak intensity.
+ pixel_area: float
+ Size of pixel area unit square
+ pixel_area_units: string
+ Units of the pixel area
+ figsize: tuple
+ Size of the figure panel
+ returnfig: bool
+ Setting this to true returns the figure and axis handles
+
+ Returns
+ --------
+ fig, ax (optional)
+ Figure and axes handles
+
+ """
+
+ if area_max is None:
+ area_max = np.max(self.cluster_sizes)
+ area = np.arange(0, area_max, area_step)
+ if area_min is None:
+ sub = self.cluster_sizes.astype("int") < area_max
+ else:
+ sub = np.logical_and(
+ self.cluster_sizes.astype("int") >= area_min,
+ self.cluster_sizes.astype("int") < area_max,
+ )
+ if weight_intensity:
+ hist = np.bincount(
+ self.cluster_sizes[sub] // area_step,
+ weights=self.cluster_sig[sub],
+ minlength=area.shape[0],
+ )
+ else:
+ hist = np.bincount(
+ self.cluster_sizes[sub] // area_step,
+ minlength=area.shape[0],
+ )
+
+ # plotting
+ fig, ax = plt.subplots(figsize=figsize)
+ ax.bar(
+ area * pixel_area,
+ hist,
+ width=0.8 * pixel_area * area_step,
+ )
+ ax.set_xlim((0, area_max * pixel_area))
+ ax.set_xlabel("Grain Area [" + pixel_area_units + "]")
+ if weight_intensity:
+ ax.set_ylabel("Total Signal [arb. units]")
+ else:
+ ax.set_ylabel("Number of Grains")
+
+ if returnfig:
+ return fig, ax
+
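+# Example usage (hypothetical calibration): with probe positions sampled every
+# 2 nm, each position covers 4 nm^2, so grain areas can be shown in nm^2:
+#
+#     crystal.plot_cluster_size(pixel_area=4.0, pixel_area_units="nm^2")
+#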
+
+def axisEqual3D(ax):
+ extents = np.array([getattr(ax, "get_{}lim".format(dim))() for dim in "xyz"])
+ sz = extents[:, 1] - extents[:, 0]
+ centers = np.mean(extents, axis=1)
+ maxsize = max(abs(sz))
+ r = maxsize / 2
+ for ctr, dim in zip(centers, "xyz"):
+ getattr(ax, "set_{}lim".format(dim))(ctr - r, ctr + r)
+ ax.set_box_aspect((1, 1, 1))
+
+
+def atomic_colors(Z, scheme="jmol"):
+ """
+ Return atomic colors for Z.
+
+ Modes are "colin" and "jmol".
+ "colin" uses the handmade but incomplete scheme of Colin Ophus
+ "jmol" uses the JMOL scheme, from http://jmol.sourceforge.net/jscolors
+ which includes all elements up to 109
+ """
+ if scheme == "jmol":
+ return np.array(jmol_colors.get(Z, (0.0, 0.0, 0.0)))
+ else:
+ return {
+ 1: np.array([0.8, 0.8, 0.8]),
+ 2: np.array([1.0, 0.7, 0.0]),
+ 3: np.array([1.0, 0.0, 1.0]),
+ 4: np.array([0.0, 0.5, 0.0]),
+ 5: np.array([0.5, 0.0, 0.0]),
+ 6: np.array([0.5, 0.5, 0.5]),
+ 7: np.array([0.0, 0.7, 1.0]),
+ 8: np.array([1.0, 0.0, 0.0]),
+ 13: np.array([0.6, 0.7, 0.8]),
+ 14: np.array([0.3, 0.3, 0.3]),
+ 15: np.array([1.0, 0.6, 0.0]),
+ 16: np.array([1.0, 0.9, 0.0]),
+ 17: np.array([0.0, 1.0, 0.0]),
+ 79: np.array([1.0, 0.7, 0.0]),
+ }.get(Z, np.array([0.0, 0.0, 0.0]))
+
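+# For example, atomic_colors(79) returns the JMOL gold color
+# np.array([1.000, 0.820, 0.137]), and unknown Z values fall back to black.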
+
+jmol_colors = {
+ 1: (1.000, 1.000, 1.000),
+ 2: (0.851, 1.000, 1.000),
+ 3: (0.800, 0.502, 1.000),
+ 4: (0.761, 1.000, 0.000),
+ 5: (1.000, 0.710, 0.710),
+ 6: (0.565, 0.565, 0.565),
+ 7: (0.188, 0.314, 0.973),
+ 8: (1.000, 0.051, 0.051),
+ 9: (0.565, 0.878, 0.314),
+ 10: (0.702, 0.890, 0.961),
+ 11: (0.671, 0.361, 0.949),
+ 12: (0.541, 1.000, 0.000),
+ 13: (0.749, 0.651, 0.651),
+ 14: (0.941, 0.784, 0.627),
+ 15: (1.000, 0.502, 0.000),
+ 16: (1.000, 1.000, 0.188),
+ 17: (0.122, 0.941, 0.122),
+ 18: (0.502, 0.820, 0.890),
+ 19: (0.561, 0.251, 0.831),
+ 20: (0.239, 1.000, 0.000),
+ 21: (0.902, 0.902, 0.902),
+ 22: (0.749, 0.761, 0.780),
+ 23: (0.651, 0.651, 0.671),
+ 24: (0.541, 0.600, 0.780),
+ 25: (0.612, 0.478, 0.780),
+ 26: (0.878, 0.400, 0.200),
+ 27: (0.941, 0.565, 0.627),
+ 28: (0.314, 0.816, 0.314),
+ 29: (0.784, 0.502, 0.200),
+ 30: (0.490, 0.502, 0.690),
+ 31: (0.761, 0.561, 0.561),
+ 32: (0.400, 0.561, 0.561),
+ 33: (0.741, 0.502, 0.890),
+ 34: (1.000, 0.631, 0.000),
+ 35: (0.651, 0.161, 0.161),
+ 36: (0.361, 0.722, 0.820),
+ 37: (0.439, 0.180, 0.690),
+ 38: (0.000, 1.000, 0.000),
+ 39: (0.580, 1.000, 1.000),
+ 40: (0.580, 0.878, 0.878),
+ 41: (0.451, 0.761, 0.788),
+ 42: (0.329, 0.710, 0.710),
+ 43: (0.231, 0.620, 0.620),
+ 44: (0.141, 0.561, 0.561),
+ 45: (0.039, 0.490, 0.549),
+ 46: (0.000, 0.412, 0.522),
+ 47: (0.753, 0.753, 0.753),
+ 48: (1.000, 0.851, 0.561),
+ 49: (0.651, 0.459, 0.451),
+ 50: (0.400, 0.502, 0.502),
+ 51: (0.620, 0.388, 0.710),
+ 52: (0.831, 0.478, 0.000),
+ 53: (0.580, 0.000, 0.580),
+ 54: (0.259, 0.620, 0.690),
+ 55: (0.341, 0.090, 0.561),
+ 56: (0.000, 0.788, 0.000),
+ 57: (0.439, 0.831, 1.000),
+ 58: (1.000, 1.000, 0.780),
+ 59: (0.851, 1.000, 0.780),
+ 60: (0.780, 1.000, 0.780),
+ 61: (0.639, 1.000, 0.780),
+ 62: (0.561, 1.000, 0.780),
+ 63: (0.380, 1.000, 0.780),
+ 64: (0.271, 1.000, 0.780),
+ 65: (0.188, 1.000, 0.780),
+ 66: (0.122, 1.000, 0.780),
+ 67: (0.000, 1.000, 0.612),
+ 68: (0.000, 0.902, 0.459),
+ 69: (0.000, 0.831, 0.322),
+ 70: (0.000, 0.749, 0.220),
+ 71: (0.000, 0.671, 0.141),
+ 72: (0.302, 0.761, 1.000),
+ 73: (0.302, 0.651, 1.000),
+ 74: (0.129, 0.580, 0.839),
+ 75: (0.149, 0.490, 0.671),
+ 76: (0.149, 0.400, 0.588),
+ 77: (0.090, 0.329, 0.529),
+ 78: (0.816, 0.816, 0.878),
+ 79: (1.000, 0.820, 0.137),
+ 80: (0.722, 0.722, 0.816),
+ 81: (0.651, 0.329, 0.302),
+ 82: (0.341, 0.349, 0.380),
+ 83: (0.620, 0.310, 0.710),
+ 84: (0.671, 0.361, 0.000),
+ 85: (0.459, 0.310, 0.271),
+ 86: (0.259, 0.510, 0.588),
+ 87: (0.259, 0.000, 0.400),
+ 88: (0.000, 0.490, 0.000),
+ 89: (0.439, 0.671, 0.980),
+ 90: (0.000, 0.729, 1.000),
+ 91: (0.000, 0.631, 1.000),
+ 92: (0.000, 0.561, 1.000),
+ 93: (0.000, 0.502, 1.000),
+ 94: (0.000, 0.420, 1.000),
+ 95: (0.329, 0.361, 0.949),
+ 96: (0.471, 0.361, 0.890),
+ 97: (0.541, 0.310, 0.890),
+ 98: (0.631, 0.212, 0.831),
+ 99: (0.702, 0.122, 0.831),
+ 100: (0.702, 0.122, 0.729),
+ 101: (0.702, 0.051, 0.651),
+ 102: (0.741, 0.051, 0.529),
+ 103: (0.780, 0.000, 0.400),
+ 104: (0.800, 0.000, 0.349),
+ 105: (0.820, 0.000, 0.310),
+ 106: (0.851, 0.000, 0.271),
+ 107: (0.878, 0.000, 0.220),
+ 108: (0.902, 0.000, 0.180),
+ 109: (0.922, 0.000, 0.149),
+}
+
+
+def plot_ring_pattern(
+ radii,
+ intensity,
+ theta=[-np.pi, np.pi, 200],
+ intensity_scale=1,
+ intensity_constant=False,
+ color="k",
+ figsize=(10, 10),
+ returnfig=False,
+ input_fig_handle=None,
+ **kwargs,
+):
+ """
+ 2D plot of diffraction rings
+
+ Args:
+        radii (np array): 1D numpy array containing radii for diffraction rings
+        intensity (np array): 1D numpy array containing intensities for diffraction rings
+        theta (3-tuple): first two values specify the angle range, and the last specifies the number of points used for plotting
+ intensity_scale (float): size scaling for ring thickness
+ intensity_constant (bool): if true, all rings are plotted with same line width
+ color (matplotlib color): color of ring, any format recognized by matplotlib
+ figsize (2 element float): size scaling of figure axes
+ returnfig (bool): set to True to return figure and axes handles
+ input_fig_handle (fig,ax) tuple containing a figure / axes handle for the plot
+ """
+
+ theta = np.linspace(*theta)
+
+ if input_fig_handle is None:
+ fig, ax = plt.subplots(1, 1, figsize=figsize, facecolor=(1, 1, 1))
+ else:
+ fig = input_fig_handle[0]
+ ax_parent = input_fig_handle[1]
+ ax = ax_parent[0]
+
+ for a1 in range(radii.shape[0]):
+ if intensity_constant is True:
+ ax.plot(
+ radii[a1] * np.sin(theta),
+ radii[a1] * np.cos(theta),
+ lw=intensity_scale,
+ color=color,
+ **kwargs,
+ )
+ else:
+ ax.plot(
+ radii[a1] * np.sin(theta),
+ radii[a1] * np.cos(theta),
+ lw=intensity[a1] * intensity_scale,
+ color=color,
+ **kwargs,
+ )
+
+ ax.set_xlabel("$q_y$ [Å$^{-1}$]")
+ ax.set_ylabel("$q_x$ [Å$^{-1}$]")
+
+ max_value = np.max(radii) * 1.1
+ ax.set_xlim([-max_value, max_value])
+ ax.set_ylim([-max_value, max_value])
+
+ ax.set_aspect("equal")
+
+ if input_fig_handle is None:
+ plt.show()
+
+ if returnfig:
+ return fig, ax
diff --git a/py4DSTEM/process/diffraction/flowlines.py b/py4DSTEM/process/diffraction/flowlines.py
new file mode 100644
index 000000000..66904d4f8
--- /dev/null
+++ b/py4DSTEM/process/diffraction/flowlines.py
@@ -0,0 +1,1360 @@
+# Functions for creating flowline maps from diffraction spots
+
+
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib.figure import Figure
+from matplotlib.axes import Axes
+from scipy.ndimage import gaussian_filter1d
+
+from matplotlib.colors import hsv_to_rgb
+from matplotlib.colors import rgb_to_hsv
+from matplotlib.colors import ListedColormap
+
+from emdfile import tqdmnd, PointList, PointListArray
+
+
+def make_orientation_histogram(
+ bragg_peaks: PointListArray = None,
+ radial_ranges: np.ndarray = None,
+ orientation_map=None,
+ orientation_ind: int = 0,
+    orientation_growth_angles: np.ndarray = 0.0,
+ orientation_separate_bins: bool = False,
+ orientation_flip_sign: bool = False,
+ upsample_factor=4.0,
+ theta_step_deg=1.0,
+ sigma_x=1.0,
+ sigma_y=1.0,
+ sigma_theta=3.0,
+ normalize_intensity_image: bool = False,
+ normalize_intensity_stack: bool = True,
+ progress_bar: bool = True,
+):
+ """
+    Create a 3D or 4D orientation histogram from a braggpeaks PointListArray
+ from user-specified radial ranges, or from the Euler angles from a fiber
+ texture OrientationMap generated by the ACOM module of py4DSTEM.
+
+ Args:
+        bragg_peaks (PointListArray): 2D array of pointlists containing centered peak locations.
+ radial_ranges (np array): Size (N x 2) array for N radial bins, or (2,) for a single bin.
+ orientation_map (OrientationMap): Class containing the Euler angles to generate a flowline map.
+ orientation_ind (int): Index of the orientation map (default 0)
+ orientation_growth_angles (array): Angles to place into histogram, relative to orientation.
+        orientation_separate_bins (bool): whether to place multiple angles into multiple radial bins.
+        orientation_flip_sign (bool): whether to flip the sign of the orientation angles.
+ upsample_factor (float): Upsample factor
+ theta_step_deg (float): Step size along annular direction in degrees
+ sigma_x (float): Smoothing in x direction before upsample
+        sigma_y (float): Smoothing in y direction before upsample
+ sigma_theta (float): Smoothing in annular direction (units of bins, periodic)
+ normalize_intensity_image (bool): Normalize to max peak intensity = 1, per image
+ normalize_intensity_stack (bool): Normalize to max peak intensity = 1, all images
+ progress_bar (bool): Enable progress bar
+
+ Returns:
+ orient_hist (array): 4D array containing Bragg peak intensity histogram
+ [radial_bin x_probe y_probe theta]
+ """
+
+ # coordinates
+ theta = np.arange(0, 180, theta_step_deg) * np.pi / 180.0
+ dtheta = theta[1] - theta[0]
+ dtheta_deg = dtheta * 180 / np.pi
+ num_theta_bins = np.size(theta)
+
+ if orientation_map is None:
+ # Input bins
+ radial_ranges = np.array(radial_ranges)
+ if radial_ranges.ndim == 1:
+ radial_ranges = radial_ranges[None, :]
+ radial_ranges_2 = radial_ranges**2
+ num_radii = radial_ranges.shape[0]
+ size_input = bragg_peaks.shape
+ else:
+ orientation_growth_angles = np.atleast_1d(orientation_growth_angles)
+ num_angles = orientation_growth_angles.shape[0]
+ size_input = [orientation_map.num_x, orientation_map.num_y]
+ if orientation_separate_bins is False:
+ num_radii = 1
+ else:
+ num_radii = num_angles
+
+ size_output = np.round(
+ np.array(size_input).astype("float") * upsample_factor
+ ).astype("int")
+
+ # output init
+ orient_hist = np.zeros([num_radii, size_output[0], size_output[1], num_theta_bins])
+
+ # Loop over all probe positions
+ for a0 in range(num_radii):
+ t = "Generating histogram " + str(a0)
+ for rx, ry in tqdmnd(
+ *size_input, desc=t, unit=" probe positions", disable=not progress_bar
+ ):
+ x = (rx + 0.5) * upsample_factor - 0.5
+ y = (ry + 0.5) * upsample_factor - 0.5
+ x = np.clip(x, 0, size_output[0] - 2)
+ y = np.clip(y, 0, size_output[1] - 2)
+
+ xF = np.floor(x).astype("int")
+ yF = np.floor(y).astype("int")
+ dx = x - xF
+ dy = y - yF
+
+ add_data = False
+
+ if orientation_map is None:
+ p = bragg_peaks.get_pointlist(rx, ry)
+ r2 = p.data["qx"] ** 2 + p.data["qy"] ** 2
+ sub = np.logical_and(
+ r2 >= radial_ranges_2[a0, 0], r2 < radial_ranges_2[a0, 1]
+ )
+ if np.any(sub):
+ add_data = True
+ intensity = p.data["intensity"][sub]
+ t = np.arctan2(p.data["qy"][sub], p.data["qx"][sub]) / dtheta
+ else:
+ if orientation_map.corr[rx, ry, orientation_ind] > 0:
+ if orientation_separate_bins is False:
+ if orientation_flip_sign:
+ t = (
+ np.array(
+ [
+ (
+ -orientation_map.angles[
+ rx, ry, orientation_ind, 0
+ ]
+ - orientation_map.angles[
+ rx, ry, orientation_ind, 2
+ ]
+ )
+ / dtheta
+ ]
+ )
+ + orientation_growth_angles
+ )
+ else:
+ t = (
+ np.array(
+ [
+ (
+ orientation_map.angles[
+ rx, ry, orientation_ind, 0
+ ]
+ + orientation_map.angles[
+ rx, ry, orientation_ind, 2
+ ]
+ )
+ / dtheta
+ ]
+ )
+ + orientation_growth_angles
+ )
+ intensity = (
+ np.ones(num_angles)
+ * orientation_map.corr[rx, ry, orientation_ind]
+ )
+ add_data = True
+ else:
+ if orientation_flip_sign:
+ t = (
+ np.array(
+ [
+ (
+ -orientation_map.angles[
+ rx, ry, orientation_ind, 0
+ ]
+ - orientation_map.angles[
+ rx, ry, orientation_ind, 2
+ ]
+ )
+ / dtheta
+ ]
+ )
+ + orientation_growth_angles[a0]
+ )
+ else:
+ t = (
+ np.array(
+ [
+ (
+ orientation_map.angles[
+ rx, ry, orientation_ind, 0
+ ]
+ + orientation_map.angles[
+ rx, ry, orientation_ind, 2
+ ]
+ )
+ / dtheta
+ ]
+ )
+ + orientation_growth_angles[a0]
+ )
+ intensity = orientation_map.corr[rx, ry, orientation_ind]
+ add_data = True
+
+            if add_data:
+                tF = np.floor(t).astype("int")
+                dt = t - tF
+
+                # Bilinear interpolation in (x, y) and linear interpolation in
+                # theta: each peak is distributed over the 2x2 neighborhood of
+                # output pixels and its two adjacent theta bins.
+                for dxi, wx in ((0, 1 - dx), (1, dx)):
+                    for dyi, wy in ((0, 1 - dy), (1, dy)):
+                        orient_hist[a0, xF + dxi, yF + dyi, :] += np.bincount(
+                            np.mod(tF, num_theta_bins),
+                            weights=wx * wy * (1 - dt) * intensity,
+                            minlength=num_theta_bins,
+                        )
+                        orient_hist[a0, xF + dxi, yF + dyi, :] += np.bincount(
+                            np.mod(tF + 1, num_theta_bins),
+                            weights=wx * wy * dt * intensity,
+                            minlength=num_theta_bins,
+                        )
+
+ # smoothing / interpolation
+ if (sigma_x is not None) or (sigma_y is not None) or (sigma_theta is not None):
+ if num_radii > 1:
+ print("Interpolating orientation matrices ...", end="")
+ else:
+ print("Interpolating orientation matrix ...", end="")
+ if sigma_x is not None and sigma_x > 0:
+ orient_hist = gaussian_filter1d(
+ orient_hist,
+ sigma_x * upsample_factor,
+ mode="nearest",
+ axis=1,
+ truncate=3.0,
+ )
+ if sigma_y is not None and sigma_y > 0:
+ orient_hist = gaussian_filter1d(
+ orient_hist,
+ sigma_y * upsample_factor,
+ mode="nearest",
+ axis=2,
+ truncate=3.0,
+ )
+ if sigma_theta is not None and sigma_theta > 0:
+ orient_hist = gaussian_filter1d(
+ orient_hist, sigma_theta / dtheta_deg, mode="wrap", axis=3, truncate=2.0
+ )
+ print(" done.")
+
+ # normalization
+ if normalize_intensity_stack is True:
+ orient_hist = orient_hist / np.max(orient_hist)
+ elif normalize_intensity_image is True:
+ for a0 in range(num_radii):
+ orient_hist[a0, :, :, :] = orient_hist[a0, :, :, :] / np.max(
+ orient_hist[a0, :, :, :]
+ )
+
+ return orient_hist
+
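+# Example usage (hypothetical values): build a 4D histogram from a single
+# radial bin of centered Bragg peaks, for use with make_flowline_map below:
+#
+#     orient_hist = make_orientation_histogram(
+#         bragg_peaks,
+#         radial_ranges=np.array([0.95, 1.05]),
+#         upsample_factor=4.0,
+#     )
+#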
+
+def make_flowline_map(
+ orient_hist,
+ thresh_seed=0.2,
+ thresh_grow=0.05,
+ thresh_collision=0.001,
+ sep_seeds=None,
+ sep_xy=6.0,
+ sep_theta=5.0,
+ sort_seeds="intensity",
+ linewidth=2.0,
+ step_size=0.5,
+ min_steps=4,
+ max_steps=1000,
+ sigma_x=1.0,
+ sigma_y=1.0,
+ sigma_theta=2.0,
+ progress_bar: bool = True,
+):
+ """
+ Create an 3D or 4D orientation flowline map - essentially a pixelated "stream map" which represents diffraction data.
+
+ Args:
+ orient_hist (array): Histogram of all orientations with coordinates
+ [radial_bin x_probe y_probe theta]
+ We assume theta bin ranges from 0 to 180 degrees and is periodic.
+ thresh_seed (float): Threshold for seed generation in histogram.
+ thresh_grow (float): Threshold for flowline growth in histogram.
+ thresh_collision (float): Threshold for termination of flowline growth in histogram.
+ sep_seeds (float): Initial seed separation in bins - set to None to use default value,
+ which is equal to 0.5*sep_xy.
+ sep_xy (float): Search radius for flowline direction in x and y.
+        sep_theta (float): Search radius for flowline direction in theta.
+ sort_seeds (str): How to sort the initial seeds for growth:
+ None - no sorting
+ 'intensity' - sort by histogram intensity
+ 'random' - random order
+ linewidth (float): Thickness of the flowlines in pixels.
+ step_size (float): Step size for flowline growth in pixels.
+ min_steps (int): Minimum number of steps for a flowline to be drawn.
+ max_steps (int): Maximum number of steps for a flowline to be drawn.
+ sigma_x (float): Weighted sigma in x direction for direction update.
+ sigma_y (float): Weighted sigma in y direction for direction update.
+ sigma_theta (float): Weighted sigma in theta for direction update.
+ progress_bar (bool): Enable progress bar
+
+ Returns:
+ orient_flowlines (array): 4D array containing flowlines
+ [radial_bin x_probe y_probe theta]
+ """
+
+ # Ensure sep_xy and sep_theta are arrays
+ sep_xy = np.atleast_1d(sep_xy)
+ sep_theta = np.atleast_1d(sep_theta)
+
+ # number of radial bins
+ num_radii = orient_hist.shape[0]
+ if num_radii > 1 and len(sep_xy) == 1:
+ sep_xy = np.ones(num_radii) * sep_xy
+ if num_radii > 1 and len(sep_theta) == 1:
+ sep_theta = np.ones(num_radii) * sep_theta
+
+ # Default seed separation - broadcast to one value per radial bin
+ if sep_seeds is None:
+ sep_seeds = np.round(np.min(sep_xy) / 2 + 0.5).astype("int") * np.ones(
+ num_radii, dtype="int"
+ )
+ else:
+ sep_seeds = np.atleast_1d(sep_seeds).astype("int")
+ if len(sep_seeds) == 1:
+ sep_seeds = (np.ones(num_radii) * sep_seeds).astype("int")
+
+ # coordinates
+ theta = np.linspace(0, np.pi, orient_hist.shape[3], endpoint=False)
+ dtheta = theta[1] - theta[0]
+ size_3D = np.array(
+ [
+ orient_hist.shape[1],
+ orient_hist.shape[2],
+ orient_hist.shape[3],
+ ]
+ )
+
+ # initialize weighting array
+ vx = np.arange(-np.ceil(2 * sigma_x), np.ceil(2 * sigma_x) + 1)
+ vy = np.arange(-np.ceil(2 * sigma_y), np.ceil(2 * sigma_y) + 1)
+ vt = np.arange(-np.ceil(2 * sigma_theta), np.ceil(2 * sigma_theta) + 1)
+ ay, ax, at = np.meshgrid(vy, vx, vt)
+ k = (
+ np.exp(ax**2 / (-2 * sigma_x**2))
+ * np.exp(ay**2 / (-2 * sigma_y**2))
+ * np.exp(at**2 / (-2 * sigma_theta**2))
+ )
+ k = k / np.sum(k)
+ vx = vx[:, None, None].astype("int")
+ vy = vy[None, :, None].astype("int")
+ vt = vt[None, None, :].astype("int")
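+ # (k is a separable Gaussian over (x, y, theta), used to average the local
+ # histogram when initializing and updating each flowline's direction)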
+
+ # initialize flowline array
+ orient_flowlines = np.zeros_like(orient_hist)
+
+ # initialize output
+ xy_t_int = np.zeros((max_steps + 1, 4))
+ xy_t_int_rev = np.zeros((max_steps + 1, 4))
+
+ # Loop over radial bins
+ for a0 in range(num_radii):
+ # initialize collision check array
+ cr = np.arange(-np.ceil(sep_xy[a0]), np.ceil(sep_xy[a0]) + 1)
+ ct = np.arange(-np.ceil(sep_theta[a0]), np.ceil(sep_theta[a0]) + 1)
+ ay, ax, at = np.meshgrid(cr, cr, ct)
+ c_mask = (
+ (ax**2 + ay**2) / sep_xy[a0] ** 2 + at**2 / sep_theta[a0] ** 2
+ <= (1 + 1 / sep_xy[a0]) ** 2
+ )[None, :, :, :]
+ cx = cr[None, :, None, None].astype("int")
+ cy = cr[None, None, :, None].astype("int")
+ ct = ct[None, None, None, :].astype("int")
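+ # (c_mask is an ellipsoidal footprint - radius sep_xy in x/y and sep_theta
+ # in theta - used below to detect collisions with already-drawn flowlines)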
+
+ # Find all seed locations
+ orient = orient_hist[a0, :, :, :]
+ sub_seeds = np.logical_and(
+ np.logical_and(
+ orient >= np.roll(orient, 1, axis=2),
+ orient >= np.roll(orient, -1, axis=2),
+ ),
+ orient >= thresh_seed,
+ )
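+ # (seeds are bins which are local maxima along theta and exceed thresh_seed)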
+
+ # Separate seeds (use the per-radial-bin separation)
+ if sep_seeds[a0] > 0:
+ for a1 in range(sep_seeds[a0] - 1):
+ sub_seeds[a1 :: sep_seeds[a0], :, :] = False
+ sub_seeds[:, a1 :: sep_seeds[a0], :] = False
+
+ # Index seeds
+ x_inds, y_inds, t_inds = np.where(sub_seeds)
+ if sort_seeds is not None:
+ if sort_seeds == "intensity":
+ inds_sort = np.argsort(orient[sub_seeds])[::-1]
+ elif sort_seeds == "random":
+ inds_sort = np.random.permutation(np.count_nonzero(sub_seeds))
+ x_inds = x_inds[inds_sort]
+ y_inds = y_inds[inds_sort]
+ t_inds = t_inds[inds_sort]
+
+ t = "Drawing flowlines " + str(a0)
+ for a1 in tqdmnd(
+ range(0, x_inds.shape[0]), desc=t, unit=" seeds", disable=not progress_bar
+ ):
+ # initial coordinate and intensity
+ xy0 = np.array((x_inds[a1], y_inds[a1]))
+ t0 = theta[t_inds[a1]]
+
+ # init theta
+ inds_theta = np.mod(
+ np.round(t0 / dtheta).astype("int") + vt, orient.shape[2]
+ )
+ orient_crop = (
+ k
+ * orient[
+ np.clip(
+ np.round(xy0[0]).astype("int") + vx, 0, orient.shape[0] - 1
+ ),
+ np.clip(
+ np.round(xy0[1]).astype("int") + vy, 0, orient.shape[1] - 1
+ ),
+ inds_theta,
+ ]
+ )
+ theta_crop = theta[inds_theta]
+ t0 = np.sum(orient_crop * theta_crop) / np.sum(orient_crop)
+
+ # forward direction
+ t = t0
+ v0 = np.array((-np.sin(t), np.cos(t)))
+ v = v0 * step_size
+ xy = xy0
+ int_val = get_intensity(orient, xy0[0], xy0[1], t0 / dtheta)
+ xy_t_int[0, 0:2] = xy0
+ xy_t_int[0, 2] = t / dtheta
+ xy_t_int[0, 3] = int_val
+ # main loop
+ grow = True
+ count = 0
+ while grow is True:
+ count += 1
+
+ # update position and intensity
+ xy = xy + v
+ int_val = get_intensity(orient, xy[0], xy[1], t / dtheta)
+
+ # check for collision
+ flow_crop = orient_flowlines[
+ a0,
+ np.clip(np.round(xy[0]).astype("int") + cx, 0, orient.shape[0] - 1),
+ np.clip(np.round(xy[1]).astype("int") + cy, 0, orient.shape[1] - 1),
+ np.mod(np.round(t / dtheta).astype("int") + ct, orient.shape[2]),
+ ]
+ int_flow = np.max(flow_crop[c_mask])
+
+ if (
+ xy[0] < 0
+ or xy[1] < 0
+ or xy[0] > orient.shape[0]
+ or xy[1] > orient.shape[1]
+ or int_val < thresh_grow
+ or int_flow > thresh_collision
+ ):
+ grow = False
+ else:
+ # update direction
+ inds_theta = np.mod(
+ np.round(t / dtheta).astype("int") + vt, orient.shape[2]
+ )
+ orient_crop = (
+ k
+ * orient[
+ np.clip(
+ np.round(xy[0]).astype("int") + vx,
+ 0,
+ orient.shape[0] - 1,
+ ),
+ np.clip(
+ np.round(xy[1]).astype("int") + vy,
+ 0,
+ orient.shape[1] - 1,
+ ),
+ inds_theta,
+ ]
+ )
+ theta_crop = theta[inds_theta]
+ t = np.sum(orient_crop * theta_crop) / np.sum(orient_crop)
+ v = np.array((-np.sin(t), np.cos(t))) * step_size
+
+ xy_t_int[count, 0:2] = xy
+ xy_t_int[count, 2] = t / dtheta
+ xy_t_int[count, 3] = int_val
+
+ if count > max_steps - 1:
+ grow = False
+
+ # reverse direction
+ t = t0 + np.pi
+ v0 = np.array((-np.sin(t), np.cos(t)))
+ v = v0 * step_size
+ xy = xy0
+ int_val = get_intensity(orient, xy0[0], xy0[1], t0 / dtheta)
+ xy_t_int_rev[0, 0:2] = xy0
+ xy_t_int_rev[0, 2] = t / dtheta
+ xy_t_int_rev[0, 3] = int_val
+ # main loop
+ grow = True
+ count_rev = 0
+ while grow is True:
+ count_rev += 1
+
+ # update position and intensity
+ xy = xy + v
+ int_val = get_intensity(orient, xy[0], xy[1], t / dtheta)
+
+ # check for collision
+ flow_crop = orient_flowlines[
+ a0,
+ np.clip(np.round(xy[0]).astype("int") + cx, 0, orient.shape[0] - 1),
+ np.clip(np.round(xy[1]).astype("int") + cy, 0, orient.shape[1] - 1),
+ np.mod(np.round(t / dtheta).astype("int") + ct, orient.shape[2]),
+ ]
+ int_flow = np.max(flow_crop[c_mask])
+
+ if (
+ xy[0] < 0
+ or xy[1] < 0
+ or xy[0] > orient.shape[0]
+ or xy[1] > orient.shape[1]
+ or int_val < thresh_grow
+ or int_flow > thresh_collision
+ ):
+ grow = False
+ else:
+ # update direction
+ inds_theta = np.mod(
+ np.round(t / dtheta).astype("int") + vt, orient.shape[2]
+ )
+ orient_crop = (
+ k
+ * orient[
+ np.clip(
+ np.round(xy[0]).astype("int") + vx,
+ 0,
+ orient.shape[0] - 1,
+ ),
+ np.clip(
+ np.round(xy[1]).astype("int") + vy,
+ 0,
+ orient.shape[1] - 1,
+ ),
+ inds_theta,
+ ]
+ )
+ theta_crop = theta[inds_theta]
+ t = np.sum(orient_crop * theta_crop) / np.sum(orient_crop) + np.pi
+ v = np.array((-np.sin(t), np.cos(t))) * step_size
+
+ xy_t_int_rev[count_rev, 0:2] = xy
+ xy_t_int_rev[count_rev, 2] = t / dtheta
+ xy_t_int_rev[count_rev, 3] = int_val
+
+ if count_rev > max_steps - 1:
+ grow = False
+
+ # write into output array
+ if count + count_rev > min_steps:
+ if count > 0:
+ orient_flowlines[a0, :, :, :] = set_intensity(
+ orient_flowlines[a0, :, :, :], xy_t_int[1:count, :]
+ )
+ if count_rev > 1:
+ orient_flowlines[a0, :, :, :] = set_intensity(
+ orient_flowlines[a0, :, :, :], xy_t_int_rev[1:count_rev, :]
+ )
+
+ # normalize to step size
+ orient_flowlines = orient_flowlines * step_size
+
+ # linewidth
+ if linewidth > 1.0:
+ s = linewidth - 1.0
+
+ orient_flowlines = gaussian_filter1d(orient_flowlines, s, axis=1, truncate=3.0)
+ orient_flowlines = gaussian_filter1d(orient_flowlines, s, axis=2, truncate=3.0)
+ orient_flowlines = orient_flowlines * (s**2)
+
+ return orient_flowlines
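+
+# Example usage (a minimal sketch; `orient_hist` is assumed to be an
+# orientation histogram with axes [radial_bin, x_probe, y_probe, theta],
+# e.g. as produced by the histogram function above):
+#
+#     orient_flowlines = make_flowline_map(
+#         orient_hist,
+#         thresh_seed=0.2,
+#         thresh_grow=0.05,
+#         max_steps=1000,
+#     )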
+
+
+def make_flowline_rainbow_image(
+ orient_flowlines,
+ int_range=[0, 0.2],
+ sym_rotation_order=2,
+ theta_offset=0.0,
+ greyscale=False,
+ greyscale_max=True,
+ white_background=False,
+ power_scaling=1.0,
+ sum_radial_bins=False,
+ plot_images=True,
+ figsize=None,
+):
+ """
+ Generate RGB output images from the flowline arrays.
+
+ Args:
+ orient_flowlines (array): Flowline array with coordinates
+ [radial_bin x_probe y_probe theta]
+ We assume theta bin ranges from 0 to 180 degrees and is periodic.
+ int_range (array): 2 element array giving the intensity range.
+ sym_rotation_order (int): rotational symmetry order for coloring.
+ theta_offset (float): Offset the angular coloring by this value in radians.
+ greyscale (bool): Set to False for color output, True for greyscale output.
+ greyscale_max (bool): If output is greyscale, use max instead of mean for overlapping flowlines.
+ white_background (bool): For either color or greyscale output, switch to white background (from black).
+ power_scaling (float): Power law scaling for flowline intensity output.
+ sum_radial_bins (bool): Sum all radial bins (alternative is to output separate images).
+ plot_images (bool): Plot the outputs for quick visualization.
+ figsize (2-tuple): Size of output figure.
+
+ Returns:
+ im_flowline (array): 3D or 4D array containing flowline images
+ """
+
+ # init array
+ size_input = orient_flowlines.shape
+ size_output = np.array([size_input[0], size_input[1], size_input[2], 3])
+ im_flowline = np.zeros(size_output)
+ theta_offset = np.atleast_1d(theta_offset)
+
+ if greyscale is True:
+ for a0 in range(size_input[0]):
+ if greyscale_max is True:
+ im = np.max(orient_flowlines[a0, :, :, :], axis=2)
+ else:
+ im = np.mean(orient_flowlines[a0, :, :, :], axis=2)
+
+ sig = np.clip((im - int_range[0]) / (int_range[1] - int_range[0]), 0, 1)
+
+ if power_scaling != 1:
+ sig = sig**power_scaling
+
+ if white_background is False:
+ im_flowline[a0, :, :, :] = sig[:, :, None]
+ else:
+ im_flowline[a0, :, :, :] = 1 - sig[:, :, None]
+
+ else:
+ # Color basis
+ c0 = np.array([1.0, 0.0, 0.0])
+ c1 = np.array([0.0, 0.7, 0.0])
+ c2 = np.array([0.0, 0.3, 1.0])
+
+ # angles
+ theta = np.linspace(0, np.pi, size_input[3], endpoint=False)
+ theta_color = theta * sym_rotation_order
+
+ if size_input[0] > 1 and len(theta_offset) == 1:
+ theta_offset = np.ones(size_input[0]) * theta_offset
+
+ for a0 in range(size_input[0]):
+ # color projections
+ b0 = np.maximum(
+ 1
+ - np.abs(
+ np.mod(theta_offset[a0] + theta_color + np.pi, 2 * np.pi) - np.pi
+ )
+ ** 2
+ / (np.pi * 2 / 3) ** 2,
+ 0,
+ )
+ b1 = np.maximum(
+ 1
+ - np.abs(
+ np.mod(
+ theta_offset[a0] + theta_color - np.pi * 2 / 3 + np.pi,
+ 2 * np.pi,
+ )
+ - np.pi
+ )
+ ** 2
+ / (np.pi * 2 / 3) ** 2,
+ 0,
+ )
+ b2 = np.maximum(
+ 1
+ - np.abs(
+ np.mod(
+ theta_offset[a0] + theta_color - np.pi * 4 / 3 + np.pi,
+ 2 * np.pi,
+ )
+ - np.pi
+ )
+ ** 2
+ / (np.pi * 2 / 3) ** 2,
+ 0,
+ )
+
+ sig = np.clip(
+ (orient_flowlines[a0, :, :, :] - int_range[0])
+ / (int_range[1] - int_range[0]),
+ 0,
+ 1,
+ )
+ if power_scaling != 1:
+ sig = sig**power_scaling
+
+ im_flowline[a0, :, :, :] = (
+ np.sum(sig * b0[None, None, :], axis=2)[:, :, None] * c0[None, None, :]
+ + np.sum(sig * b1[None, None, :], axis=2)[:, :, None]
+ * c1[None, None, :]
+ + np.sum(sig * b2[None, None, :], axis=2)[:, :, None]
+ * c2[None, None, :]
+ )
+
+ # clip limits
+ im_flowline[a0, :, :, :] = np.clip(im_flowline[a0, :, :, :], 0, 1)
+
+ # contrast flip
+ if white_background is True:
+ im = rgb_to_hsv(im_flowline[a0])
+ im_v = im[:, :, 2]
+ im[:, :, 1] = im_v
+ im[:, :, 2] = 1
+ im_flowline[a0] = hsv_to_rgb(im)
+
+ if sum_radial_bins is True:
+ if white_background is False:
+ im_flowline = np.clip(np.sum(im_flowline, axis=0), 0, 1)[None, :, :, :]
+ else:
+ im_flowline = np.min(im_flowline, axis=0)[None, :, :, :]
+
+ if plot_images is True:
+ if figsize is None:
+ fig, ax = plt.subplots(
+ im_flowline.shape[0], 1, figsize=(10, im_flowline.shape[0] * 10)
+ )
+ else:
+ fig, ax = plt.subplots(im_flowline.shape[0], 1, figsize=figsize)
+
+ if im_flowline.shape[0] > 1:
+ for a0 in range(im_flowline.shape[0]):
+ ax[a0].imshow(im_flowline[a0])
+ plt.subplots_adjust(wspace=0, hspace=0.02)
+ else:
+ ax.imshow(im_flowline[0])
+ plt.show()
+
+ return im_flowline
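+
+# Example usage (sketch; `orient_flowlines` as returned by make_flowline_map):
+#
+#     im_flowline = make_flowline_rainbow_image(
+#         orient_flowlines,
+#         int_range=[0, 0.2],
+#         sym_rotation_order=2,
+#         white_background=False,
+#     )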
+
+
+def make_flowline_rainbow_legend(
+ im_size=np.array([256, 256]),
+ sym_rotation_order=2,
+ theta_offset=0.0,
+ white_background=False,
+ return_image=False,
+ radial_range=np.array([0.45, 0.9]),
+ plot_legend=True,
+ figsize=(4, 4),
+):
+ """
+ This function generates a legend for the rainbow-colored flowline maps, and returns it as an RGB image.
+
+ Args:
+ im_size (np.array): Size of legend image in pixels.
+ sym_rotation_order (int): rotational symmetry order for coloring
+ theta_offset (float): Offset the angular coloring by this value in radians.
+ white_background (bool): For either color or greyscale output, switch to white background (from black).
+ return_image (bool): Return the image array.
+ radial_range (np.array): Inner and outer radius for the legend ring.
+ plot_legend (bool): Plot the generated legend.
+ figsize (tuple or list): Size of the plotted legend.
+
+ Returns:
+ im_legend (array): Image array for the legend.
+ """
+
+ # Color basis
+ c0 = np.array([1.0, 0.0, 0.0])
+ c1 = np.array([0.0, 0.7, 0.0])
+ c2 = np.array([0.0, 0.3, 1.0])
+
+ # Coordinates
+ x = np.linspace(-1, 1, im_size[0])
+ y = np.linspace(-1, 1, im_size[1])
+ ya, xa = np.meshgrid(-y, x)
+ ra = np.sqrt(xa**2 + ya**2)
+ ta = np.arctan2(ya, xa)
+ ta_sym = ta * sym_rotation_order
+
+ # mask
+ dr = xa[1, 0] - xa[0, 0]
+ mask = np.clip((radial_range[1] - ra) / dr + 0.5, 0, 1) * np.clip(
+ (ra - radial_range[0]) / dr + 0.5, 0, 1
+ )
+
+ # rgb image
+ b0 = np.maximum(
+ 1
+ - np.abs(np.mod(theta_offset + ta_sym + np.pi, 2 * np.pi) - np.pi) ** 2
+ / (np.pi * 2 / 3) ** 2,
+ 0,
+ )
+ b1 = np.maximum(
+ 1
+ - np.abs(
+ np.mod(theta_offset + ta_sym - np.pi * 2 / 3 + np.pi, 2 * np.pi) - np.pi
+ )
+ ** 2
+ / (np.pi * 2 / 3) ** 2,
+ 0,
+ )
+ b2 = np.maximum(
+ 1
+ - np.abs(
+ np.mod(theta_offset + ta_sym - np.pi * 4 / 3 + np.pi, 2 * np.pi) - np.pi
+ )
+ ** 2
+ / (np.pi * 2 / 3) ** 2,
+ 0,
+ )
+ im_legend = (
+ b0[:, :, None] * c0[None, None, :]
+ + b1[:, :, None] * c1[None, None, :]
+ + b2[:, :, None] * c2[None, None, :]
+ )
+ im_legend = im_legend * mask[:, :, None]
+
+ if white_background is True:
+ im_legend = rgb_to_hsv(im_legend)
+ im_v = im_legend[:, :, 2]
+ im_legend[:, :, 1] = im_v
+ im_legend[:, :, 2] = 1
+ im_legend = hsv_to_rgb(im_legend)
+
+ # plotting
+ if plot_legend:
+ fig, ax = plt.subplots(1, 1, figsize=figsize)
+ ax.imshow(im_legend)
+ ax.invert_yaxis()
+ ax.axis("off")
+
+ if return_image:
+ return im_legend
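+
+# Example usage (sketch): draw the legend matching a 2-fold symmetric rainbow
+# flowline map and keep the RGB array:
+#
+#     im_legend = make_flowline_rainbow_legend(
+#         sym_rotation_order=2,
+#         return_image=True,
+#     )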
+
+
+def make_flowline_combined_image(
+ orient_flowlines,
+ int_range=[0, 0.2],
+ cvals=np.array(
+ [
+ [0.0, 0.7, 0.0],
+ [1.0, 0.0, 0.0],
+ [0.0, 0.7, 1.0],
+ ]
+ ),
+ white_background=False,
+ power_scaling=1.0,
+ sum_radial_bins=True,
+ plot_images=True,
+ figsize=None,
+):
+ """
+ Generate RGB output images from the flowline arrays.
+
+ Args:
+ orient_flowlines (array): Flowline array with coordinates
+ [radial_bin x_probe y_probe theta]
+ We assume theta bin ranges from 0 to 180 degrees and is periodic.
+ int_range (array): 2 element array giving the intensity range.
+ cvals (array): Nx3 array containing RGB colors for the different radial bins.
+ white_background (bool): For either color or greyscale output, switch to white background (from black).
+ power_scaling (float): Power law scaling for flowline intensities.
+ sum_radial_bins (bool): Sum outputs over radial bins.
+ plot_images (bool): Plot the output images for quick visualization.
+ figsize (2-tuple): Size of output figure.
+
+ Returns:
+ im_flowline (array): flowline images
+ """
+
+ # init array
+ size_input = orient_flowlines.shape
+ size_output = np.array([size_input[0], size_input[1], size_input[2], 3])
+ im_flowline = np.zeros(size_output)
+ cvals = np.array(cvals)
+
+ # Generate all color images
+ for a0 in range(size_input[0]):
+ sig = np.clip(
+ (np.sum(orient_flowlines[a0, :, :, :], axis=2) - int_range[0])
+ / (int_range[1] - int_range[0]),
+ 0,
+ 1,
+ )
+ if power_scaling != 1:
+ sig = sig**power_scaling
+
+ if white_background:
+ im_flowline[a0, :, :, :] = 1 - sig[:, :, None] * (
+ 1 - cvals[a0, :][None, None, :]
+ )
+ else:
+ im_flowline[a0, :, :, :] = sig[:, :, None] * cvals[a0, :][None, None, :]
+
+ if sum_radial_bins is True:
+ if white_background is False:
+ im_flowline = np.clip(np.sum(im_flowline, axis=0), 0, 1)[None, :, :, :]
+ else:
+ im_flowline = np.min(im_flowline, axis=0)[None, :, :, :]
+
+ if plot_images is True:
+ if figsize is None:
+ fig, ax = plt.subplots(
+ im_flowline.shape[0], 1, figsize=(10, im_flowline.shape[0] * 10)
+ )
+ else:
+ fig, ax = plt.subplots(im_flowline.shape[0], 1, figsize=figsize)
+
+ if im_flowline.shape[0] > 1:
+ for a0 in range(im_flowline.shape[0]):
+ ax[a0].imshow(im_flowline[a0])
+ ax[a0].axis("off")
+ plt.subplots_adjust(wspace=0, hspace=0.02)
+ else:
+ ax.imshow(im_flowline[0])
+ ax.axis("off")
+ plt.show()
+
+ return im_flowline
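+
+# Example usage (sketch; each row of cvals is the RGB color of one radial bin):
+#
+#     im = make_flowline_combined_image(
+#         orient_flowlines,
+#         int_range=[0, 0.2],
+#         white_background=True,
+#     )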
+
+
+def orientation_correlation(
+ orient_hist,
+ radius_max=None,
+):
+ """
+ Take in the 4D orientation histogram and compute the distance-angle (auto)correlations.
+
+ Args:
+ orient_hist (array): 4D histogram of all orientations with coordinates
+ [radial_bin x_probe y_probe theta]
+ radius_max (float): Maximum radial distance for correlogram calculation. If set to None, the maximum
+ radius will be set to min(orient_hist.shape[1],orient_hist.shape[2])/2.
+
+ Returns:
+ orient_corr (array): 3D array of correlations, indexed as [ring pair, delta theta, delta r]
+ """
+
+ # Array sizes
+ size_input = np.array(orient_hist.shape)
+ if radius_max is None:
+ radius_max = np.ceil(np.min(orient_hist.shape[1:3]) / 2).astype("int")
+ size_corr = np.array(
+ [
+ np.maximum(2 * size_input[1], 2 * radius_max),
+ np.maximum(2 * size_input[2], 2 * radius_max),
+ ]
+ )
+
+ # Initialize orientation histogram
+ orient_hist_pad = np.zeros(
+ (
+ size_input[0],
+ size_corr[0],
+ size_corr[1],
+ size_input[3],
+ ),
+ dtype="complex",
+ )
+ orient_norm_pad = np.zeros(
+ (
+ size_input[0],
+ size_corr[0],
+ size_corr[1],
+ ),
+ dtype="complex",
+ )
+
+ # Pad the histogram in real space
+ x_inds = np.arange(size_input[1])
+ y_inds = np.arange(size_input[2])
+ orient_hist_pad[:, x_inds[:, None], y_inds[None, :], :] = orient_hist
+ orient_norm_pad[:, x_inds[:, None], y_inds[None, :]] = np.sum(
+ orient_hist, axis=3
+ ) / np.sqrt(size_input[3])
+ orient_hist_pad = np.fft.fftn(orient_hist_pad, axes=(1, 2, 3))
+ orient_norm_pad = np.fft.fftn(orient_norm_pad, axes=(1, 2))
+
+ # Radial coordinates for integration
+ x = (
+ np.mod(np.arange(size_corr[0]) + size_corr[0] / 2, size_corr[0])
+ - size_corr[0] / 2
+ )
+ y = (
+ np.mod(np.arange(size_corr[1]) + size_corr[1] / 2, size_corr[1])
+ - size_corr[1] / 2
+ )
+ ya, xa = np.meshgrid(y, x)
+ ra = np.sqrt(xa**2 + ya**2)
+
+ # coordinate subset
+ sub0 = ra <= radius_max
+ sub1 = ra <= radius_max - 1
+ rF0 = np.floor(ra[sub0]).astype("int")
+ rF1 = np.floor(ra[sub1]).astype("int")
+ dr0 = ra[sub0] - rF0
+ dr1 = ra[sub1] - rF1
+ inds = np.concatenate((rF0, rF1 + 1))
+ weights = np.concatenate((1 - dr0, dr1))
+
+ # init output
+ num_corr = (0.5 * size_input[0] * (size_input[0] + 1)).astype("int")
+ orient_corr = np.zeros(
+ (
+ num_corr,
+ (size_input[3] / 2 + 1).astype("int"),
+ radius_max + 1,
+ )
+ )
+
+ # Main correlation calculation
+ ind_output = 0
+ for a0 in range(size_input[0]):
+ for a1 in range(size_input[0]):
+ if a0 <= a1:
+ # Correlation
+ c = np.real(
+ np.fft.ifftn(
+ orient_hist_pad[a0, :, :, :]
+ * np.conj(orient_hist_pad[a1, :, :, :]),
+ axes=(0, 1, 2),
+ )
+ )
+
+ # Loop over all angles from 0 to pi/2 (half of indices)
+ for a2 in range((size_input[3] / 2 + 1).astype("int")):
+ orient_corr[ind_output, a2, :] = np.bincount(
+ inds,
+ weights=weights
+ * np.concatenate((c[:, :, a2][sub0], c[:, :, a2][sub1])),
+ minlength=radius_max + 1,
+ )
+
+ # normalize
+ c_norm = np.real(
+ np.fft.ifftn(
+ orient_norm_pad[a0, :, :] * np.conj(orient_norm_pad[a1, :, :]),
+ axes=(0, 1),
+ )
+ )
+ sig_norm = np.bincount(
+ inds,
+ weights=weights * np.concatenate((c_norm[sub0], c_norm[sub1])),
+ minlength=radius_max + 1,
+ )
+ orient_corr[ind_output, :, :] /= sig_norm[None, :]
+
+ # increment output index
+ ind_output += 1
+
+ return orient_corr
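+
+# Example usage (sketch): the correlogram is normalized such that a value of
+# 1.0 corresponds to a fully random orientation distribution:
+#
+#     orient_corr = orientation_correlation(orient_hist, radius_max=32)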
+
+
+def plot_orientation_correlation(
+ orient_corr,
+ prob_range=[0.1, 10.0],
+ inds_plot=None,
+ pixel_size=None,
+ pixel_units=None,
+ size_fig=[8, 6],
+ return_fig=False,
+):
+ """
+ Plot the distance-angle (auto)correlations in orient_corr.
+
+ Args:
+ orient_corr (array): 3D array of correlations indexed as [ring pair, delta theta, delta r];
+ the 1st index represents each pair of rings.
+ prob_range (array): Plotting range in units of "multiples of random distribution".
+ inds_plot (int or array): Which ring-pair indices to plot. Set to None to plot all pairs.
+ pixel_size (float): Pixel size for x axis.
+ pixel_units (str): units of pixels.
+ size_fig (array): Size of the figure panels.
+ return_fig (bool): Whether to return figure axes.
+
+ Returns:
+ fig, ax Figure and axes handles (optional).
+ """
+
+ # Make sure range is a numpy array
+ prob_range = np.array(prob_range)
+
+ if pixel_units is None:
+ pixel_units = "pixels"
+
+ # Get the pair indices
+ size_input = orient_corr.shape
+ num_corr = (np.sqrt(8 * size_input[0] + 1) / 2 - 1 / 2).astype("int")
+ ya, xa = np.meshgrid(np.arange(num_corr), np.arange(num_corr))
+ keep = ya >= xa
+ # row 0 is the first diff ring, row 1 is the second diff ring
+ pair_inds = np.vstack((xa[keep], ya[keep]))
+
+ if inds_plot is None:
+ inds_plot = np.arange(size_input[0])
+ elif np.ndim(inds_plot) == 0:
+ inds_plot = np.atleast_1d(inds_plot)
+ else:
+ inds_plot = np.array(inds_plot)
+
+ # Custom divergent colormap:
+ # dark blue
+ # light blue
+ # white
+ # red
+ # dark red
+ N = 256
+ cvals = np.zeros((N, 4))
+ cvals[:, 3] = 1
+ c = np.linspace(0.0, 1.0, int(N / 4))
+
+ cvals[0 : int(N / 4), 1] = c * 0.4 + 0.3
+ cvals[0 : int(N / 4), 2] = 1
+
+ cvals[int(N / 4) : int(N / 2), 0] = c
+ cvals[int(N / 4) : int(N / 2), 1] = c * 0.3 + 0.7
+ cvals[int(N / 4) : int(N / 2), 2] = 1
+
+ cvals[int(N / 2) : int(N * 3 / 4), 0] = 1
+ cvals[int(N / 2) : int(N * 3 / 4), 1] = 1 - c
+ cvals[int(N / 2) : int(N * 3 / 4), 2] = 1 - c
+
+ cvals[int(N * 3 / 4) : N, 0] = 1 - 0.5 * c
+ new_cmap = ListedColormap(cvals)
+
+ # plotting
+ num_plot = inds_plot.shape[0]
+ fig, ax = plt.subplots(num_plot, 1, figsize=(size_fig[0], num_plot * size_fig[1]))
+
+ # loop over indices
+ for count, ind in enumerate(inds_plot):
+ if num_plot > 1:
+ p = ax[count].imshow(
+ np.log10(orient_corr[ind, :, :]),
+ vmin=np.log10(prob_range[0]),
+ vmax=np.log10(prob_range[1]),
+ aspect="auto",
+ cmap=new_cmap,
+ )
+ ax_handle = ax[count]
+ else:
+ p = ax.imshow(
+ np.log10(orient_corr[ind, :, :]),
+ vmin=np.log10(prob_range[0]),
+ vmax=np.log10(prob_range[1]),
+ aspect="auto",
+ cmap=new_cmap,
+ )
+ ax_handle = ax
+
+ cbar = fig.colorbar(p, ax=ax_handle)
+ t = cbar.get_ticks()
+ t_lab = []
+ for a1 in range(t.shape[0]):
+ t_lab.append(f"{10**t[a1]:.2g}")
+
+ cbar.set_ticks(t)
+ cbar.ax.set_yticklabels(t_lab)
+ cbar.ax.set_ylabel("Probability [mult. of rand. dist.]", fontsize=12)
+
+ ind_0 = pair_inds[0, ind]
+ ind_1 = pair_inds[1, ind]
+
+ if ind_0 != ind_1:
+ ax_handle.set_title(
+ "Correlation of Rings " + str(ind_0) + " and " + str(ind_1), fontsize=16
+ )
+ else:
+ ax_handle.set_title("Autocorrelation of Ring " + str(ind_0), fontsize=16)
+
+ # x axis labels
+ if pixel_size is not None:
+ x_t = ax_handle.get_xticks()
+ sub = np.logical_or(x_t < 0, x_t > orient_corr.shape[2])
+ x_t_new = np.delete(x_t, sub)
+ ax_handle.set_xticks(x_t_new)
+ ax_handle.set_xticklabels(x_t_new * pixel_size)
+ ax_handle.set_xlabel("Radial Distance [" + pixel_units + "]", fontsize=12)
+
+ # y axis labels
+ ax_handle.invert_yaxis()
+ ax_handle.set_ylabel("Relative Grain Orientation [degrees]", fontsize=12)
+ ax_handle.set_yticks([0, 10, 20, 30, 40, 50, 60, 70, 80, 90])
+ ax_handle.set_yticklabels(["0", "", "", "30", "", "", "60", "", "", "90"])
+
+ if return_fig is True:
+ return fig, ax
+ plt.show()
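+
+# Example usage (sketch; hypothetical 2 nm pixel size):
+#
+#     plot_orientation_correlation(
+#         orient_corr,
+#         prob_range=[0.1, 10.0],
+#         pixel_size=2.0,
+#         pixel_units="nm",
+#     )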
+
+
+def get_intensity(orient, x, y, t):
+ # utility function to get histogram intensities by trilinear interpolation
+ # in (x, y, theta); x and y are clipped, theta wraps periodically
+
+ x = np.clip(x, 0, orient.shape[0] - 2)
+ y = np.clip(y, 0, orient.shape[1] - 2)
+
+ xF = np.floor(x).astype("int")
+ yF = np.floor(y).astype("int")
+ tF = np.floor(t).astype("int")
+ dx = x - xF
+ dy = y - yF
+ dt = t - tF
+ t1 = np.mod(tF, orient.shape[2])
+ t2 = np.mod(tF + 1, orient.shape[2])
+
+ int_vals = (
+ orient[xF, yF, t1] * ((1 - dx) * (1 - dy) * (1 - dt))
+ + orient[xF, yF, t2] * ((1 - dx) * (1 - dy) * (dt))
+ + orient[xF, yF + 1, t1] * ((1 - dx) * (dy) * (1 - dt))
+ + orient[xF, yF + 1, t2] * ((1 - dx) * (dy) * (dt))
+ + orient[xF + 1, yF, t1] * ((dx) * (1 - dy) * (1 - dt))
+ + orient[xF + 1, yF, t2] * ((dx) * (1 - dy) * (dt))
+ + orient[xF + 1, yF + 1, t1] * ((dx) * (dy) * (1 - dt))
+ + orient[xF + 1, yF + 1, t2] * ((dx) * (dy) * (dt))
+ )
+
+ return int_vals
+
+
+def set_intensity(orient, xy_t_int):
+ # utility function to deposit flowline intensities with trilinear weights
+
+ xF = np.floor(xy_t_int[:, 0]).astype("int")
+ yF = np.floor(xy_t_int[:, 1]).astype("int")
+ tF = np.floor(xy_t_int[:, 2]).astype("int")
+ dx = xy_t_int[:, 0] - xF
+ dy = xy_t_int[:, 1] - yF
+ dt = xy_t_int[:, 2] - tF
+
+ # Deposit each point's intensity onto its 8 surrounding bins using
+ # trilinear weights; x and y indices are clipped to the array bounds,
+ # while theta wraps periodically.
+ orient_ravel = orient.ravel()
+ for x_offset, x_weight in ((0, 1 - dx), (1, dx)):
+ for y_offset, y_weight in ((0, 1 - dy), (1, dy)):
+ for t_offset, t_weight in ((0, 1 - dt), (1, dt)):
+ inds_1D = np.ravel_multi_index(
+ [xF + x_offset, yF + y_offset, tF + t_offset],
+ orient.shape[0:3],
+ mode=["clip", "clip", "wrap"],
+ )
+ orient_ravel[inds_1D] += (
+ xy_t_int[:, 3] * x_weight * y_weight * t_weight
+ )
+
+ return orient
diff --git a/py4DSTEM/process/diffraction/tdesign.py b/py4DSTEM/process/diffraction/tdesign.py
new file mode 100644
index 000000000..1a0a81fb6
--- /dev/null
+++ b/py4DSTEM/process/diffraction/tdesign.py
@@ -0,0 +1,2000 @@
+__all__ = ["tdesign"]
+
+import numpy as np
+
+
+def tdesign(degree):
+ """
+ Returns the spherical coordinates of minimal T-designs.
+
+ This function returns the unit vectors and the spherical coordinates
+ of t-designs, which constitute uniform arrangements on the sphere for
+ which spherical polynomials up to degree t can be integrated exactly by
+ summation of their values at the points defined by the t-design.
+ Designs for order up to t=21 are stored and returned. Note that for the
+ spherical harmonic transform (SHT) of a function of order N, a spherical
+ t-design with t>=2N should be used (or equivalently N = floor(t/2)), since
+ the integral evaluates the product of the spherical function with
+ spherical harmonics of up to order N.
+
+ Returns:
+ azim: Nx1, azimuth of each point in the t-design
+ elev: Nx1, polar angle (measured from the +z axis) of each point in the t-design
+ vecs: Nx3, array of cartesian coordinates for each point
+
+ The designs have been copied from:
+ http://neilsloane.com/sphdesigns/
+ and should be referenced as:
+ "McLaren's Improved Snub Cube and Other New Spherical Designs in
+ Three Dimensions", R. H. Hardin and N. J. A. Sloane, Discrete and
+ Computational Geometry, 15 (1996), pp. 429-441.
+
+ Based on the MATLAB implementation by:
+ Archontis Politis, archontis.politis@aalto.fi, 10/11/2014
+ """
+
+ assert degree <= 21, "Degree must be 21 or less."
+ assert degree >= 1, "Degree should be at least 1."
+ assert type(degree) is int, "Degree should be an integer."
+
+ vecs = _tdesigns[degree - 1]
+
+ azim = np.arctan2(vecs[:, 1], vecs[:, 0])
+ elev = np.arctan2(np.hypot(vecs[:, 1], vecs[:, 0]), vecs[:, 2])
+
+ # NB: "elev" above is the polar angle measured from the +z axis; for the
+ # elevation measured from the xy-plane use instead:
+ # elev = np.arctan2(vecs[:, 2], np.hypot(vecs[:, 1], vecs[:, 0]))
+
+ return azim, elev, vecs
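+
+# Example usage (sketch): a degree-t design integrates spherical polynomials
+# of degree <= t exactly by averaging over its points; in particular the
+# points of any design with t >= 1 average to the zero vector:
+#
+#     azim, elev, vecs = tdesign(6)
+#     assert np.allclose(np.linalg.norm(vecs, axis=1), 1.0)
+#     assert np.allclose(vecs.mean(axis=0), 0.0)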
+
+
+_tdesigns = [
+ # degree 1
+ np.array(
+ [
+ [1, 0, 0],
+ [-1, 0, 0],
+ ]
+ ),
+ # degree 2
+ np.array(
+ [
+ [0.577350269189626, 0.577350269189626, 0.577350269189626],
+ [0.577350269189626, -0.577350269189626, -0.577350269189626],
+ [-0.577350269189626, 0.577350269189626, -0.577350269189626],
+ [-0.577350269189626, -0.577350269189626, 0.577350269189626],
+ ]
+ ),
+ # degree 3
+ np.array(
+ [
+ [1, 0, 0],
+ [-1, 0, 0],
+ [0, 1, 0],
+ [0, -1, 0],
+ [0, 0, 1],
+ [0, 0, -1],
+ ]
+ ),
+ # degree 4
+ np.array(
+ [
+ [0.850650808352, 0, -0.525731112119],
+ [0.525731112119, -0.850650808352, 0],
+ [0, -0.525731112119, 0.850650808352],
+ [0.850650808352, 0, 0.525731112119],
+ [-0.525731112119, -0.850650808352, 0],
+ [0, 0.525731112119, -0.850650808352],
+ [-0.850650808352, 0, -0.525731112119],
+ [-0.525731112119, 0.850650808352, 0],
+ [0, 0.525731112119, 0.850650808352],
+ [-0.850650808352, 0, 0.525731112119],
+ [0.525731112119, 0.850650808352, 0],
+ [0, -0.525731112119, -0.850650808352],
+ ]
+ ),
+ # degree 5
+ np.array(
+ [
+ [0.850650808352, 0, -0.525731112119],
+ [0.525731112119, -0.850650808352, 0],
+ [0, -0.525731112119, 0.850650808352],
+ [0.850650808352, 0, 0.525731112119],
+ [-0.525731112119, -0.850650808352, 0],
+ [0, 0.525731112119, -0.850650808352],
+ [-0.850650808352, 0, -0.525731112119],
+ [-0.525731112119, 0.850650808352, 0],
+ [0, 0.525731112119, 0.850650808352],
+ [-0.850650808352, 0, 0.525731112119],
+ [0.525731112119, 0.850650808352, 0],
+ [0, -0.525731112119, -0.850650808352],
+ ]
+ ),
+ # degree 6
+ np.array(
+ [
+ [0.866246818107821, 0.422518653761112, 0.266635401516705],
+ [0.866246818107821, -0.422518653761112, -0.266635401516705],
+ [0.866246818107821, 0.266635401516705, -0.422518653761112],
+ [0.866246818107821, -0.266635401516705, 0.422518653761112],
+ [-0.866246818107821, 0.422518653761112, -0.266635401516705],
+ [-0.866246818107821, -0.422518653761112, 0.266635401516705],
+ [-0.866246818107821, 0.266635401516705, 0.422518653761112],
+ [-0.866246818107821, -0.266635401516705, -0.422518653761112],
+ [0.266635401516705, 0.866246818107821, 0.422518653761112],
+ [-0.266635401516705, 0.866246818107821, -0.422518653761112],
+ [-0.422518653761112, 0.866246818107821, 0.266635401516705],
+ [0.422518653761112, 0.866246818107821, -0.266635401516705],
+ [-0.266635401516705, -0.866246818107821, 0.422518653761112],
+ [0.266635401516705, -0.866246818107821, -0.422518653761112],
+ [0.422518653761112, -0.866246818107821, 0.266635401516705],
+ [-0.422518653761112, -0.866246818107821, -0.266635401516705],
+ [0.422518653761112, 0.266635401516705, 0.866246818107821],
+ [-0.422518653761112, -0.266635401516705, 0.866246818107821],
+ [0.266635401516705, -0.422518653761112, 0.866246818107821],
+ [-0.266635401516705, 0.422518653761112, 0.866246818107821],
+ [0.422518653761112, -0.266635401516705, -0.866246818107821],
+ [-0.422518653761112, 0.266635401516705, -0.866246818107821],
+ [0.266635401516705, 0.422518653761112, -0.866246818107821],
+ [-0.266635401516705, -0.422518653761112, -0.866246818107821],
+ ]
+ ),
+ # degree 7
+ np.array(
+ [
+ [0.866246818107821, 0.422518653761112, 0.266635401516705],
+ [0.866246818107821, -0.422518653761112, -0.266635401516705],
+ [0.866246818107821, 0.266635401516705, -0.422518653761112],
+ [0.866246818107821, -0.266635401516705, 0.422518653761112],
+ [-0.866246818107821, 0.422518653761112, -0.266635401516705],
+ [-0.866246818107821, -0.422518653761112, 0.266635401516705],
+ [-0.866246818107821, 0.266635401516705, 0.422518653761112],
+ [-0.866246818107821, -0.266635401516705, -0.422518653761112],
+ [0.266635401516705, 0.866246818107821, 0.422518653761112],
+ [-0.266635401516705, 0.866246818107821, -0.422518653761112],
+ [-0.422518653761112, 0.866246818107821, 0.266635401516705],
+ [0.422518653761112, 0.866246818107821, -0.266635401516705],
+ [-0.266635401516705, -0.866246818107821, 0.422518653761112],
+ [0.266635401516705, -0.866246818107821, -0.422518653761112],
+ [0.422518653761112, -0.866246818107821, 0.266635401516705],
+ [-0.422518653761112, -0.866246818107821, -0.266635401516705],
+ [0.422518653761112, 0.266635401516705, 0.866246818107821],
+ [-0.422518653761112, -0.266635401516705, 0.866246818107821],
+ [0.266635401516705, -0.422518653761112, 0.866246818107821],
+ [-0.266635401516705, 0.422518653761112, 0.866246818107821],
+ [0.422518653761112, -0.266635401516705, -0.866246818107821],
+ [-0.422518653761112, 0.266635401516705, -0.866246818107821],
+ [0.266635401516705, 0.422518653761112, -0.866246818107821],
+ [-0.266635401516705, -0.422518653761112, -0.866246818107821],
+ ]
+ ),
+ # degree 8
+ np.array(
+ [
+ [0.507475446410817, -0.306200013239571, 0.805425492011663],
+ [-0.306200013239569, 0.805425492011663, 0.507475446410817],
+ [-0.507475446410817, 0.30620001323957, 0.805425492011663],
+ [0.805425492011663, 0.507475446410817, -0.306200013239569],
+ [0.306200013239569, 0.805425492011664, -0.507475446410817],
+ [0.805425492011663, -0.507475446410817, 0.306200013239569],
+ [0.306200013239569, -0.805425492011663, 0.507475446410816],
+ [-0.805425492011663, -0.507475446410817, -0.306200013239569],
+ [-0.30620001323957, -0.805425492011664, -0.507475446410816],
+ [-0.805425492011663, 0.507475446410818, 0.306200013239569],
+ [0.507475446410817, 0.30620001323957, -0.805425492011663],
+ [-0.507475446410817, -0.30620001323957, -0.805425492011663],
+ [0.626363670265271, -0.243527775409194, -0.74051520928072],
+ [-0.243527775409195, -0.74051520928072, 0.626363670265271],
+ [-0.626363670265271, 0.243527775409194, -0.74051520928072],
+ [-0.74051520928072, 0.62636367026527, -0.243527775409195],
+ [0.243527775409195, -0.740515209280719, -0.626363670265271],
+ [-0.74051520928072, -0.62636367026527, 0.243527775409195],
+ [0.243527775409195, 0.740515209280719, 0.626363670265271],
+ [0.74051520928072, -0.62636367026527, -0.243527775409195],
+ [-0.243527775409195, 0.74051520928072, -0.626363670265271],
+ [0.74051520928072, 0.62636367026527, 0.243527775409195],
+ [0.626363670265271, 0.243527775409194, 0.74051520928072],
+ [-0.626363670265271, -0.243527775409194, 0.74051520928072],
+ [-0.286248723426035, 0.957120327092458, -0.044523564585421],
+ [0.957120327092458, -0.04452356458542, -0.286248723426035],
+ [0.286248723426035, -0.957120327092458, -0.044523564585421],
+ [-0.04452356458542, -0.286248723426035, 0.957120327092458],
+ [-0.957120327092458, -0.044523564585419, 0.286248723426035],
+ [-0.044523564585421, 0.286248723426034, -0.957120327092458],
+ [-0.957120327092458, 0.04452356458542, -0.286248723426034],
+ [0.044523564585421, 0.286248723426034, 0.957120327092458],
+ [0.957120327092458, 0.04452356458542, 0.286248723426034],
+ [0.044523564585421, -0.286248723426034, -0.957120327092458],
+ [-0.286248723426034, -0.957120327092458, 0.044523564585421],
+ [0.286248723426035, 0.957120327092458, 0.044523564585421],
+ ]
+ ),
+ # degree 9
+ np.array(
+ [
+ [0.93336469319931, 0.353542188921472, -0.0619537742318597],
+ [0.93336469319931, -0.353542188921472, 0.0619537742318597],
+ [0.93336469319931, -0.0619537742318597, -0.353542188921472],
+ [0.93336469319931, 0.0619537742318597, 0.353542188921472],
+ [-0.93336469319931, 0.353542188921472, 0.0619537742318597],
+ [-0.93336469319931, -0.353542188921472, -0.0619537742318597],
+ [-0.93336469319931, -0.0619537742318597, 0.353542188921472],
+ [-0.93336469319931, 0.0619537742318597, -0.353542188921472],
+ [-0.0619537742318597, 0.93336469319931, 0.353542188921472],
+ [0.0619537742318597, 0.93336469319931, -0.353542188921472],
+ [-0.353542188921472, 0.93336469319931, -0.0619537742318597],
+ [0.353542188921472, 0.93336469319931, 0.0619537742318597],
+ [0.0619537742318597, -0.93336469319931, 0.353542188921472],
+ [-0.0619537742318597, -0.93336469319931, -0.353542188921472],
+ [0.353542188921472, -0.93336469319931, -0.0619537742318597],
+ [-0.353542188921472, -0.93336469319931, 0.0619537742318597],
+ [0.353542188921472, -0.0619537742318597, 0.93336469319931],
+ [-0.353542188921472, 0.0619537742318597, 0.93336469319931],
+ [-0.0619537742318597, -0.353542188921472, 0.93336469319931],
+ [0.0619537742318597, 0.353542188921472, 0.93336469319931],
+ [0.353542188921472, 0.0619537742318597, -0.93336469319931],
+ [-0.353542188921472, -0.0619537742318597, -0.93336469319931],
+ [-0.0619537742318597, 0.353542188921472, -0.93336469319931],
+ [0.0619537742318597, -0.353542188921472, -0.93336469319931],
+ [0.70684169771255, 0.639740098619792, 0.301840057965769],
+ [0.70684169771255, -0.639740098619792, -0.301840057965769],
+ [0.70684169771255, 0.301840057965769, -0.639740098619792],
+ [0.70684169771255, -0.301840057965769, 0.639740098619792],
+ [-0.70684169771255, 0.639740098619792, -0.301840057965769],
+ [-0.70684169771255, -0.639740098619792, 0.301840057965769],
+ [-0.70684169771255, 0.301840057965769, 0.639740098619792],
+ [-0.70684169771255, -0.301840057965769, -0.639740098619792],
+ [0.301840057965769, 0.70684169771255, 0.639740098619792],
+ [-0.301840057965769, 0.70684169771255, -0.639740098619792],
+ [-0.639740098619792, 0.70684169771255, 0.301840057965769],
+ [0.639740098619792, 0.70684169771255, -0.301840057965769],
+ [-0.301840057965769, -0.70684169771255, 0.639740098619792],
+ [0.301840057965769, -0.70684169771255, -0.639740098619792],
+ [0.639740098619792, -0.70684169771255, 0.301840057965769],
+ [-0.639740098619792, -0.70684169771255, -0.301840057965769],
+ [0.639740098619792, 0.301840057965769, 0.70684169771255],
+ [-0.639740098619792, -0.301840057965769, 0.70684169771255],
+ [0.301840057965769, -0.639740098619792, 0.70684169771255],
+ [-0.301840057965769, 0.639740098619792, 0.70684169771255],
+ [0.639740098619792, -0.301840057965769, -0.70684169771255],
+ [-0.639740098619792, 0.301840057965769, -0.70684169771255],
+ [0.301840057965769, 0.639740098619792, -0.70684169771255],
+ [-0.301840057965769, -0.639740098619792, -0.70684169771255],
+ ]
+ ),
+ # degree 10
+ np.array(
+ [
+ [-0.753828667197017, 0.54595190806126, -0.365621190026287],
+ [0.545951908061258, -0.36562119002629, -0.753828667197017],
+ [0.753828667197016, -0.545951908061261, -0.365621190026288],
+ [-0.365621190026289, -0.753828667197017, 0.545951908061259],
+ [-0.545951908061258, -0.365621190026288, 0.753828667197018],
+ [-0.365621190026289, 0.753828667197017, -0.545951908061259],
+ [-0.545951908061258, 0.365621190026289, -0.753828667197017],
+ [0.365621190026287, 0.753828667197017, 0.54595190806126],
+ [0.545951908061259, 0.365621190026289, 0.753828667197017],
+ [0.365621190026287, -0.753828667197018, -0.545951908061259],
+ [-0.753828667197017, -0.545951908061261, 0.365621190026288],
+ [0.753828667197016, 0.545951908061261, 0.365621190026287],
+ [0.70018101936373, -0.713151065847793, 0.034089549761256],
+ [-0.713151065847794, 0.034089549761254, 0.700181019363729],
+ [-0.70018101936373, 0.713151065847793, 0.034089549761256],
+ [0.034089549761255, 0.70018101936373, -0.713151065847793],
+ [0.713151065847793, 0.034089549761254, -0.70018101936373],
+ [0.034089549761257, -0.700181019363729, 0.713151065847794],
+ [0.713151065847794, -0.034089549761255, 0.700181019363728],
+ [-0.034089549761256, -0.700181019363729, -0.713151065847794],
+ [-0.713151065847794, -0.034089549761254, -0.700181019363729],
+ [-0.034089549761257, 0.700181019363729, 0.713151065847794],
+ [0.70018101936373, 0.713151065847793, -0.034089549761257],
+ [-0.700181019363729, -0.713151065847794, -0.034089549761257],
+ [0.276230218261792, 0.077050720725736, -0.957997939953259],
+ [0.077050720725735, -0.957997939953258, 0.276230218261793],
+ [-0.276230218261792, -0.077050720725734, -0.957997939953259],
+ [-0.957997939953259, 0.276230218261791, 0.077050720725738],
+ [-0.077050720725735, -0.957997939953259, -0.276230218261792],
+ [-0.957997939953258, -0.276230218261793, -0.077050720725736],
+ [-0.077050720725736, 0.957997939953258, 0.276230218261794],
+ [0.957997939953259, -0.27623021826179, 0.077050720725737],
+ [0.077050720725734, 0.957997939953259, -0.276230218261792],
+ [0.957997939953258, 0.276230218261793, -0.077050720725738],
+ [0.276230218261793, -0.077050720725736, 0.957997939953258],
+ [-0.276230218261791, 0.077050720725735, 0.957997939953259],
+ [0.451819102555243, -0.783355937521819, 0.42686411621907],
+ [-0.783355937521818, 0.426864116219071, 0.451819102555243],
+ [-0.451819102555243, 0.783355937521819, 0.42686411621907],
+ [0.426864116219071, 0.451819102555242, -0.783355937521819],
+ [0.783355937521818, 0.42686411621907, -0.451819102555244],
+ [0.426864116219072, -0.451819102555242, 0.783355937521818],
+ [0.783355937521819, -0.42686411621907, 0.451819102555242],
+ [-0.426864116219072, -0.451819102555241, -0.783355937521819],
+ [-0.783355937521818, -0.42686411621907, -0.451819102555243],
+ [-0.426864116219072, 0.451819102555241, 0.783355937521819],
+ [0.451819102555243, 0.783355937521818, -0.426864116219071],
+ [-0.451819102555242, -0.783355937521819, -0.426864116219071],
+ [-0.33858435995926, -0.933210037239527, 0.120331448866784],
+ [-0.933210037239526, 0.120331448866787, -0.33858435995926],
+ [0.338584359959261, 0.933210037239526, 0.120331448866786],
+ [0.120331448866785, -0.338584359959261, -0.933210037239526],
+ [0.933210037239526, 0.120331448866789, 0.33858435995926],
+ [0.120331448866785, 0.338584359959261, 0.933210037239526],
+ [0.933210037239526, -0.120331448866787, -0.338584359959262],
+ [-0.120331448866784, 0.338584359959262, -0.933210037239526],
+ [-0.933210037239526, -0.120331448866787, 0.338584359959261],
+ [-0.120331448866784, -0.338584359959262, 0.933210037239526],
+ [-0.338584359959262, 0.933210037239526, -0.120331448866784],
+ [0.338584359959261, -0.933210037239527, -0.120331448866783],
+ ]
+ ),
+ # degree 11
+ np.array(
+ [
+ [-0.674940520480437, 0.725629052064501, 0.133857284499464],
+ [0.09672433446143, -0.910327382989987, -0.402428203412229],
+ [0.906960315916358, 0.135127022135053, 0.398953221871704],
+ [-0.132758704758026, -0.307658524060733, 0.942189661842955],
+ [-0.226055801127587, -0.958831174708704, -0.171876563798827],
+ [0.275738264019853, -0.180692733507538, -0.944096682449892],
+ [0.830881650513589, 0.333278644528177, -0.445601871563928],
+ [-0.616471328612787, -0.2675443371664, 0.740528951931372],
+ [0.430277293287436, -0.892644471615357, -0.13434023290057],
+ [-0.690987198523076, 0.175109339053207, 0.701336874015319],
+ [0.810517041535507, -0.381449337547215, 0.444475565431127],
+ [-0.086734443854626, -0.706008517835924, -0.702872043114784],
+ [0.871320852056737, 0.46045780600396, 0.169642511361809],
+ [-0.600735266749549, 0.303266118552509, -0.739693720820614],
+ [-0.899100947083419, -0.418081246639828, 0.12971336924846],
+ [0.896927087079571, -0.188066327344843, -0.400191025613991],
+ [0.150494960966991, 0.903072153139254, 0.402258564791324],
+ [0.248601716402621, -0.224283612281953, 0.94228129975259],
+ [0.842584674708423, -0.510756382085546, -0.1708185707275],
+ [0.260034500418337, 0.209356489957684, -0.942630319215749],
+ [-0.058802461572434, 0.894595213188746, -0.442991732488095],
+ [0.061611769180132, -0.671290108790159, 0.738629528071408],
+ [0.982337536097614, 0.133784014710179, -0.130823555148513],
+ [-0.382277582532576, -0.605243847900137, 0.698243320392029],
+ [0.611839278216357, 0.651571608497249, 0.448449703569971],
+ [0.646865348569582, -0.298464129297652, -0.701772316597447],
+ [-0.169201016881282, 0.970430912746818, 0.172147783812972],
+ [-0.471725450862325, -0.47529570366279, -0.742676977621112],
+ [0.119369755955723, -0.984692604411347, 0.127009197228668],
+ [0.457289212231729, 0.796155990558714, -0.396260287026038],
+ [-0.813631436350979, 0.420942272793499, 0.40101307803722],
+ [0.287154555386871, 0.16417332397066, 0.943710432821951],
+ [0.746667577045155, 0.644035989066398, -0.166448713352744],
+ [-0.115779644740906, 0.314952464646105, -0.942019118105898],
+ [-0.867579212111466, 0.221916315040665, -0.445038717226738],
+ [0.655140022433912, -0.151162631680508, 0.740230646345257],
+ [0.176736512358047, 0.976002671721061, -0.12721238144483],
+ [0.455284607701078, -0.55278635410423, 0.697956426080188],
+ [-0.432023930219742, 0.781838026058859, 0.449538234998843],
+ [0.485961267092557, 0.525163287076294, -0.698602296584415],
+ [-0.975758639968897, 0.138431863354196, 0.1695042646494],
+ [0.308602378401872, -0.593188152818631, -0.743567338847214],
+ [0.972979693579006, -0.191167383224118, 0.129481842256537],
+ [-0.614624689780931, 0.68217777986423, -0.39606813475866],
+ [-0.653028964396532, -0.644975511259979, 0.396937981974668],
+ [-0.070378900922493, 0.320878001965403, 0.944502047726543],
+ [-0.381252925250545, 0.909662131759037, -0.164805986030565],
+ [-0.332341796304234, -0.009834857390798, -0.943107738283054],
+ [-0.477746621168896, -0.755138676192789, -0.448913962446598],
+ [0.343877558432071, 0.574039599276676, 0.743119615720828],
+ [-0.873212544495548, 0.47009394139203, -0.128497231106812],
+ [0.664216892966437, 0.259987346974329, 0.700872669256879],
+ [-0.878489109322641, -0.170673846340671, 0.446236846278739],
+ [-0.347082716608212, 0.626648635925969, -0.697742842975825],
+ [-0.433716795977713, -0.885744934523588, 0.165365207503367],
+ [0.661861683362982, 0.112512128799614, -0.741134355544863],
+ [0.482068945127674, 0.865869532174741, 0.133714192945202],
+ [-0.8374660393934, -0.372486946227971, -0.399880116725617],
+ [0.410355219266256, -0.82161905066793, 0.39566492086167],
+ [-0.329899568015879, 0.02926988290883, 0.943562159572669],
+ [-0.982429034616553, -0.080964254903198, -0.168160582094488],
+ [-0.090370421487683, -0.316160436207578, -0.944391743662116],
+ [0.571959920493404, -0.686312971502271, -0.449262010965652],
+ [-0.442021476996821, 0.502111749619808, 0.743303978710785],
+ [-0.716515724344093, -0.684793506171761, -0.13290248557738],
+ [-0.044218043628816, 0.709851625611568, 0.702961900983442],
+ [-0.110556556362806, -0.889624975730714, 0.443107944412334],
+ [-0.701028131184281, -0.134257385503649, -0.700381691455451],
+ [0.707841110014082, -0.686721956709281, 0.165450648676302],
+ [0.099860111408803, 0.666551337869757, -0.738740327945793],
+ ]
+ ),
+ # degree 12
+ np.array(
+ [
+ [-0.893804977761136, -0.426862191124497, 0.137482113446834],
+ [-0.426862191241092, 0.137482113445288, -0.893804977705691],
+ [0.893804977770157, 0.426862191128157, 0.137482113376823],
+ [0.1374821132964, -0.893804977739491, -0.426862191218271],
+ [0.426862191272731, 0.137482113377345, 0.893804977701032],
+ [0.137482113529033, 0.893804977707775, 0.426862191209756],
+ [0.426862191185983, -0.137482113474993, -0.893804977727441],
+ [-0.137482113324291, 0.893804977725279, -0.426862191239047],
+ [-0.426862191217414, -0.137482113347288, 0.893804977732073],
+ [-0.137482113501071, -0.893804977693655, 0.426862191248328],
+ [-0.893804977672548, 0.426862191328071, -0.137482113390703],
+ [0.893804977663553, -0.42686219133326, -0.137482113433065],
+ [0.983086600385574, 0.022300380107522, -0.181778516853323],
+ [0.022300380232394, -0.181778516808726, 0.983086600390988],
+ [-0.983086600396613, -0.022300380113323, -0.181778516792915],
+ [-0.181778516710471, 0.983086600409631, 0.022300380211455],
+ [-0.022300380272854, -0.181778516836686, -0.9830866003849],
+ [-0.18177851693601, -0.983086600368179, -0.022300380200376],
+ [-0.0223003801708, 0.181778516841875, 0.983086600386256],
+ [0.181778516710979, -0.983086600409044, 0.022300380233212],
+ [0.022300380212558, 0.181778516804081, -0.983086600392297],
+ [0.181778516934384, 0.983086600367503, -0.022300380243431],
+ [0.983086600391629, -0.022300380332372, 0.181778516792996],
+ [-0.98308660038057, 0.022300380337865, 0.181778516852128],
+ [-0.897951986971875, 0.376695603035365, 0.227558018419664],
+ [0.376695602927528, 0.227558018339206, -0.897951987037503],
+ [0.897951986986053, -0.376695603028569, 0.227558018374966],
+ [0.227558018305554, -0.897951987041904, 0.376695602937366],
+ [-0.376695602875261, 0.227558018455254, 0.89795198703002],
+ [0.227558018486567, 0.89795198699048, -0.3766956029506],
+ [-0.376695602982511, -0.22755801836891, -0.89795198700691],
+ [-0.227558018280939, 0.897951987054767, 0.376695602921573],
+ [0.376695602931437, -0.22755801842558, 0.897951987013974],
+ [-0.227558018511349, -0.897951987002348, -0.376695602907339],
+ [-0.897951987072194, -0.376695602830637, -0.227558018362707],
+ [0.897951987057819, 0.376695602823051, -0.227558018431989],
+ [-0.171330151245221, 0.459786194953055, -0.871345301361568],
+ [0.459786194843117, -0.871345301414649, -0.171330151270292],
+ [0.171330151191219, -0.459786194982334, -0.871345301356736],
+ [-0.871345301364754, -0.171330151162981, 0.459786194977662],
+ [-0.459786195042432, -0.871345301303738, 0.171330151299472],
+ [-0.871345301353407, 0.171330151362727, -0.459786194924734],
+ [-0.459786194855202, 0.87134530140841, -0.171330151269592],
+ [0.871345301392835, 0.171330151178183, 0.45978619491878],
+ [0.459786195054412, 0.871345301309038, 0.171330151240368],
+ [0.871345301325486, -0.171330151377355, -0.459786194972196],
+ [-0.17133015129661, -0.459786194913003, 0.871345301372597],
+ [0.171330151350736, 0.459786194942983, 0.871345301346135],
+ [-0.397191702297223, -0.548095590649226, -0.736091010091219],
+ [-0.548095590778902, -0.736091010056557, -0.397191702182515],
+ [0.397191702250221, 0.548095590625205, -0.736091010134467],
+ [-0.736091010174764, -0.397191702137083, -0.548095590653075],
+ [0.548095590610212, -0.736091010169131, 0.397191702206669],
+ [-0.736091010049194, 0.397191702305889, 0.548095590699385],
+ [0.548095590752529, 0.736091010044117, -0.397191702241962],
+ [0.736091010139925, 0.397191702119602, -0.548095590712531],
+ [-0.548095590584386, 0.736091010182625, 0.3971917022173],
+ [0.736091010083782, -0.39719170228798, 0.548095590665912],
+ [-0.39719170212526, 0.548095590740419, 0.736091010116106],
+ [0.397191702171386, -0.548095590716295, 0.736091010109179],
+ [0.379474725534956, 0.69627727809449, 0.609259291836815],
+ [0.696277278210441, 0.609259291787114, 0.379474725402001],
+ [-0.379474725495576, -0.696277278074161, 0.609259291884576],
+ [0.609259291925953, 0.379474725376213, 0.696277278103008],
+ [-0.696277278071056, 0.609259291933888, -0.379474725422102],
+ [0.60925929179591, -0.379474725515542, -0.696277278140864],
+ [-0.696277278185906, -0.609259291774849, 0.379474725466713],
+ [-0.609259291882878, -0.379474725353089, 0.696277278153303],
+ [0.696277278046548, -0.609259291946589, -0.379474725446676],
+ [-0.609259291838737, 0.379474725493095, -0.696277278115623],
+ [0.37947472533629, -0.696277278181595, -0.609259291861008],
+ [-0.379474725375237, 0.696277278161216, -0.609259291860039],
+ [-0.678701446470328, 0.729764213479081, 0.082513873284097],
+ [0.729764213389772, 0.082513873179234, -0.678701446579104],
+ [0.678701446474772, -0.72976421347722, 0.082513873263995],
+ [0.082513873217671, -0.678701446552547, 0.729764213410125],
+ [-0.729764213370974, 0.082513873368402, 0.678701446576318],
+ [0.082513873326892, 0.678701446534692, -0.729764213414381],
+ [-0.729764213431284, -0.082513873201736, -0.678701446531733],
+ [-0.082513873171694, 0.678701446577399, 0.72976421339221],
+ [0.729764213412797, -0.08251387334668, 0.67870144653399],
+ [-0.082513873373655, -0.678701446558336, -0.729764213387104],
+ [-0.678701446641541, -0.729764213324827, -0.082513873240061],
+ [0.678701446637016, 0.729764213321344, -0.082513873308075],
+ ]
+ ),
+ # degree 13
+ np.array(
+ [
+ [0.276790129286922, -0.235256466916603, 0.931687511509759],
+ [0.198886780634501, 0.360548603139528, 0.911289609983006],
+ [-0.258871339062373, 0.204230077441409, 0.944073993540935],
+ [-0.20028291392731, -0.228346161950354, 0.952756414153864],
+ [-0.883545166667525, -0.414277696639041, -0.218453492821483],
+ [0.397750057908559, -0.901619535998689, -0.16993264471327],
+ [0.876539487069282, 0.434392104192327, -0.207321073274483],
+ [-0.411742357517625, 0.88489597883979, -0.217778184534166],
+ [0.501114093867204, 0.377868932752059, 0.778524074507957],
+ [-0.394238847790386, 0.473687133880952, 0.787525383774109],
+ [-0.495364292002136, -0.406429808740612, 0.767742814213388],
+ [0.370186583802172, -0.559306270968252, 0.741713144300723],
+ [0.411742357517961, -0.884895978839253, 0.217778184535715],
+ [0.883545166668397, 0.414277696639157, 0.218453492817737],
+ [-0.39775005791059, 0.901619535997218, 0.169932644716324],
+ [-0.876539487069878, -0.434392104191278, 0.20732107327416],
+ [-0.69101430131565, -0.702815226887987, -0.168967429499392],
+ [0.684400344460127, -0.714441044251654, -0.145499700314004],
+ [0.660710489482765, 0.731357715035063, -0.169048932993191],
+ [-0.773611287956309, 0.615222357857778, -0.151746583284428],
+ [0.683629784686022, -0.21996733132084, -0.695891292258878],
+ [0.256574099503526, 0.681472791071418, -0.685393730999406],
+ [-0.644474509637892, 0.354062227985534, -0.677711254990588],
+ [-0.220535080416141, -0.731547754140859, -0.645137320046912],
+ [0.394238847792041, -0.473687133882522, -0.787525383772336],
+ [0.495364292000968, 0.406429808741285, -0.767742814213785],
+ [-0.370186583802439, 0.559306270970003, -0.74171314429927],
+ [-0.50111409386464, -0.377868932752239, -0.77852407450952],
+ [-0.488574873968534, -0.006884557978444, -0.872494811095214],
+ [0.055542048727444, -0.584131720249991, -0.809756268404849],
+ [0.526812107464791, 0.049707819617671, -0.848527039107984],
+ [0.004245864108125, 0.4886546223943, -0.872466980836902],
+ [-0.710317361514613, -0.479530914625401, 0.515266288291253],
+ [0.521404384476562, -0.728039165451723, 0.445080264016476],
+ [0.738099355388852, 0.407803273205931, 0.53749961110413],
+ [-0.496057991262554, 0.699670113703365, 0.514186932248264],
+ [-0.973220809307327, 0.194260751789571, -0.122898399685852],
+ [-0.376203572666605, -0.908865964003535, -0.180093118660339],
+ [0.914477900370762, -0.368657988534049, -0.166797653531193],
+ [0.28746218785413, 0.946817914340553, -0.144572914606861],
+ [-0.098900669929334, 0.99509705928004, 0.000707177308311],
+ [-0.986068201425202, -0.161328561779779, 0.04052896855503],
+ [0.098900669927371, -0.995097059280236, -0.000707177307522],
+ [0.98606820142538, 0.161328561777872, -0.040528968558297],
+ [0.815232440848265, 0.131832381174928, -0.563929331266187],
+ [-0.113644567080339, 0.787251605615581, -0.606069155978764],
+ [-0.76050444170531, -0.010569874890521, -0.649246695710724],
+ [0.179848227912241, -0.83248278540524, -0.52404868754798],
+ [-0.92768989180951, -0.047241289482188, -0.370350813692261],
+ [-0.062273759773745, -0.944434685686409, -0.322746190242513],
+ [0.939840260740896, -0.044569841802216, -0.33869427732427],
+ [-0.00824273878155, 0.93408705946015, -0.356950168239866],
+ [-0.287462187854123, -0.94681791433958, 0.144572914613243],
+ [0.973220809307003, -0.194260751794934, 0.122898399679932],
+ [0.376203572664097, 0.908865964003331, 0.180093118666611],
+ [-0.914477900372906, 0.368657988531957, 0.16679765352406],
+ [-0.198886780630987, -0.360548603140065, -0.91128960998356],
+ [0.258871339064112, -0.204230077444254, -0.944073993539843],
+ [0.200282913924527, 0.228346161951352, -0.952756414154209],
+ [-0.276790129284401, 0.235256466920766, -0.931687511509456],
+ [0.496057991258595, -0.69967011370777, -0.514186932246089],
+ [0.710317361512836, 0.479530914628761, -0.515266288290576],
+ [-0.521404384476695, 0.728039165452453, -0.445080264015126],
+ [-0.738099355384499, -0.407803273209131, -0.537499611107679],
+ [-0.815232440849446, -0.131832381170712, 0.563929331265466],
+ [0.113644567080183, -0.787251605613717, 0.606069155981215],
+ [0.760504441709935, 0.010569874889864, 0.649246695705317],
+ [-0.179848227916839, 0.832482785402468, 0.524048687550806],
+ [0.644474509638734, -0.354062227985804, 0.677711254989647],
+ [0.220535080413518, 0.73154775414237, 0.645137320046095],
+ [-0.683629784685343, 0.219967331325312, 0.695891292258132],
+ [-0.256574099500943, -0.681472791069379, 0.6853937310024],
+ [0.00824273878347, -0.934087059459458, 0.356950168241634],
+ [0.927689891812602, 0.047241289479133, 0.370350813684907],
+ [0.062273759768788, 0.944434685686016, 0.322746190244617],
+ [-0.939840260741931, 0.04456984180196, 0.338694277321433],
+ [-0.684400344460716, 0.714441044251305, 0.145499700312953],
+ [-0.660710489482671, -0.731357715034246, 0.169048932997096],
+ [0.773611287955743, -0.615222357858877, 0.151746583282855],
+ [0.691014301313431, 0.702815226889319, 0.168967429502926],
+ [0.823586023578098, -0.394634588904438, 0.407393670798948],
+ [0.494068620358303, 0.708839608416629, 0.503430837272612],
+ [-0.75887513050105, 0.450605887274021, 0.470173234734822],
+ [-0.431499072601357, -0.787935048711447, 0.439279989706176],
+ [0.488574873974618, 0.006884557979146, 0.872494811091801],
+ [-0.055542048725634, 0.584131720247444, 0.80975626840681],
+ [-0.52681210746758, -0.049707819615904, 0.848527039106356],
+ [-0.004245864106237, -0.488654622389235, 0.872466980839748],
+ [-0.49406862035774, -0.708839608420214, -0.503430837268118],
+ [0.758875130496518, -0.450605887275878, -0.470173234740358],
+ [0.431499072601226, 0.787935048714215, -0.43927998970134],
+ [-0.823586023577754, 0.394634588903444, -0.407393670800605],
+ [-0.05223814787874, -0.056184830047506, -0.997052877624217],
+ [0.052238147881538, 0.05618483004769, 0.99705287762406],
+ ]
+ ),
+ # degree 14 (108 points)
+ np.array(
+ [
+ [-0.625520988160254, -0.7673610045544, 0.14099851793647],
+ [-0.76724274137005, 0.141111638293461, -0.625640536852518],
+ [0.625492928633992, 0.767336602497947, 0.141255565185161],
+ [0.141259978285753, -0.625497336538309, -0.767332196977417],
+ [0.767217177597722, 0.141142445065633, 0.625664936367606],
+ [0.140994104732436, 0.625522897885037, 0.76736025871308],
+ [0.767367121351956, -0.141003470846495, -0.625512367805189],
+ [-0.141107421738627, 0.625630007825311, -0.767252102531351],
+ [-0.767341557579575, -0.14125061251247, 0.625487968290521],
+ [-0.141146661279494, -0.625655569171042, 0.767224040795719],
+ [-0.625631916042134, 0.767247698300479, -0.141122907715456],
+ [0.625659975568208, -0.76722329624468, -0.141131175405856],
+ [0.557188048071509, -0.044753456478336, -0.829179478291342],
+ [-0.044855878114542, -0.829283214592494, 0.55702540354432],
+ [-0.557023176560509, 0.04489683751235, -0.829282493940292],
+ [-0.82927698135379, 0.557030331904205, -0.044909882380585],
+ [0.04500608513203, -0.829178759033645, -0.557168769645708],
+ [-0.829184990273806, -0.557180524667655, 0.044744997884746],
+ [0.044745112627007, 0.829186887034329, 0.557177692721376],
+ [0.829285903726696, -0.557022572534015, -0.044841315410966],
+ [-0.044895319644503, 0.829275086591665, -0.557034326619468],
+ [0.829176067900875, 0.557172765298515, 0.045006199906779],
+ [0.55701504778817, 0.04485436035495, 0.829290252501915],
+ [-0.55717991929916, -0.044997741388974, 0.829171719729799],
+ [-0.256065565410913, 0.860770382492113, 0.439891776275906],
+ [0.860817193749452, 0.439942893453515, -0.255820267854343],
+ [0.255978113844099, -0.860846435024418, 0.439793838677361],
+ [0.439780410927513, -0.2559784640674, 0.860853190792787],
+ [-0.86089686693789, 0.439742722239698, 0.255896312466096],
+ [0.439905203705212, 0.256058129693811, -0.860765732339981],
+ [-0.860766305814323, -0.439898638597135, -0.256067480432698],
+ [-0.439951565414546, 0.255829619022141, 0.860809982586329],
+ [0.860845979002674, -0.439786977094207, 0.255991435820162],
+ [-0.439734049218326, -0.255909284650278, -0.860897441039197],
+ [-0.255822182689152, -0.860816739802146, -0.439942668220034],
+ [0.255909634255987, 0.860892792334745, -0.43974294673258],
+ [-0.214847470746312, -0.032398468989078, 0.976110087808274],
+ [-0.032361689068532, 0.976149872834738, -0.214672184610297],
+ [0.214653391258715, 0.032229687143749, 0.976158372851326],
+ [0.976156890684175, -0.214662164023714, -0.032216146673044],
+ [0.032184871976157, 0.976118589466139, 0.214840948877335],
+ [0.976111569264229, 0.214838964338531, 0.032410241444204],
+ [0.032404387317597, -0.976112740167589, -0.214834527404447],
+ [-0.97615046971745, 0.214667748038907, -0.032373112644706],
+ [-0.032227570225248, -0.976155722133346, 0.214665763138194],
+ [-0.976117990230955, -0.214844548351237, 0.032179017872419],
+ [-0.214659241112501, 0.032359572226461, -0.976152789418913],
+ [0.21485332060012, -0.032190790383644, -0.976115671240647],
+ [-0.531657953418075, -0.827333953094149, -0.181268724894593],
+ [-0.827232187812406, -0.181173291587112, -0.531848799813059],
+ [0.531693969367769, 0.827365274479422, -0.181019958909332],
+ [-0.181013937184585, -0.531695771783126, -0.827365433658478],
+ [0.82726500032424, -0.181115392520565, 0.53181748168959],
+ [-0.181274746488222, 0.531662962384716, 0.827329414872902],
+ [0.827337912964399, 0.181265235803751, -0.531652980863198],
+ [0.181178425057189, 0.531838819184424, -0.827237480233043],
+ [-0.827370725476244, 0.181023448305968, 0.53168429898609],
+ [0.181110258615433, -0.531806009787378, 0.82727349901848],
+ [-0.531843826870477, 0.827237640803175, 0.181162991414264],
+ [0.531807810920692, -0.827268962188108, 0.181125692390543],
+ [-0.660052978431453, -0.64107030142389, -0.391610692264717],
+ [-0.640943278162024, -0.391490626360013, -0.660247532105318],
+ [0.660130816198753, 0.641137993287446, -0.391368597447617],
+ [-0.391366127194665, -0.660129990434437, -0.641140351415881],
+ [0.64101419265168, -0.391488664009925, 0.660179847292265],
+ [-0.391613162232706, 0.660059082674287, 0.641062507518011],
+ [0.641074532205542, 0.391604771858404, -0.660052381872206],
+ [0.391493582576226, 0.660240832413506, -0.64094837390819],
+ [-0.641145446695277, 0.391374518513393, 0.660120066685087],
+ [0.391485706851, -0.660169924653076, 0.641026217806203],
+ [-0.660246935065265, 0.640950733115446, 0.391479427883123],
+ [0.660169097297863, -0.641018424977353, 0.391499861829449],
+ [-0.887809544786451, -0.296234001309576, 0.352192601646022],
+ [-0.296066792023988, 0.352356402512996, -0.887800326801429],
+ [0.88773949759831, 0.296173084548554, 0.352420329142482],
+ [0.352416311777685, -0.887743835667625, -0.296164861904896],
+ [0.296002975181616, 0.352256528867817, 0.887861237217634],
+ [0.352196618754376, 0.887807646454613, 0.296234914611202],
+ [0.29624330876316, -0.35220289417087, -0.887802356017778],
+ [-0.352357435214616, 0.887795037857324, -0.296081422255585],
+ [-0.296179491920463, -0.352410037210515, 0.887741445601714],
+ [-0.352255495317536, -0.887858848644387, 0.296011369549315],
+ [-0.887793137385775, 0.296073200109863, -0.352369132285204],
+ [0.887863184573985, -0.29601228334985, -0.352243798503465],
+ [-0.26223504413332, -0.963196832316083, -0.059030871962556],
+ [-0.963146753177879, -0.05898607663367, -0.262428989645347],
+ [0.262246759182379, 0.963207020345643, -0.058812186281487],
+ [-0.058802018270752, -0.262250202293796, -0.963206703695603],
+ [0.963157426232092, -0.058856981708733, 0.262418802676392],
+ [-0.059041039930677, 0.26223953025181, 0.9631949877243],
+ [0.963198910461706, 0.05903143538266, -0.262227284091577],
+ [0.058993521775316, 0.262416743924601, -0.963149633700058],
+ [-0.963209583515922, 0.058811622961071, 0.262237471059664],
+ [0.058849536426177, -0.26240607188442, 0.963161349671286],
+ [-0.262421229412009, 0.963149318669471, 0.058978710569357],
+ [0.262409514362814, -0.96315950669899, 0.058864347675232],
+ [-0.715507563967586, -0.551203770138786, 0.429212452859839],
+ [-0.551069607492995, 0.429343180584727, -0.715532473744488],
+ [0.715422202362423, 0.551129535145429, 0.429450006237378],
+ [0.429450819917231, -0.715428271411138, -0.551121022769127],
+ [0.5509918383738, 0.429319279234025, 0.715606701005125],
+ [0.429211638866897, 0.7155060331503, 0.551206391097486],
+ [0.551211869430264, -0.429219462039715, -0.715497119774448],
+ [-0.429341247317559, 0.715523561571074, -0.551082685436994],
+ [-0.551134100310698, -0.42944299777979, 0.715422892513669],
+ [-0.429321211466422, -0.715601323311163, 0.550997317108093],
+ [-0.715522029029489, 0.551074173985694, -0.429354726001132],
+ [0.71560739063481, -0.550999938995098, -0.429307733096245],
+ ]
+ ),
+ # degree 15 (120 points)
+ np.array(
+ [
+ [0.854403279867469, -0.505354134007206, 0.120881076242474],
+ [-0.50543491755569, 0.120816219805996, 0.85436466754382],
+ [-0.854386776665562, 0.505324765203946, 0.121120260611542],
+ [0.120833358636621, 0.854397789834015, -0.505374827397788],
+ [0.505397909754575, 0.121184507524897, -0.854334400543285],
+ [0.121167891781777, -0.854359592169892, 0.505359307095908],
+ [0.505550243990606, -0.121029099883223, 0.85426629793203],
+ [-0.120901230058257, -0.8542712308899, -0.505572503845152],
+ [-0.505512743080475, -0.120971801893292, -0.854296605243135],
+ [-0.121100221709937, 0.854233310670545, 0.505588950870808],
+ [0.854228086018077, 0.50562286374044, -0.120995440909188],
+ [-0.854244251915001, -0.505593339178042, -0.121004683582819],
+ [-0.264987898778375, 0.883813698575362, -0.385557725524417],
+ [0.883849543661418, -0.385514772323787, -0.264930829632268],
+ [0.26493557758679, -0.883717531264112, -0.385814028600849],
+ [-0.385585969828214, -0.26499188213212, 0.883800182323873],
+ [-0.883729220902574, -0.385857204128221, 0.264833687708496],
+ [-0.385785570894326, 0.264874341279828, -0.883748310675226],
+ [-0.88388533425828, 0.385579930632135, -0.264716514363662],
+ [0.385705565188833, 0.2647713629504, 0.883814088111154],
+ [0.883764779677347, 0.385791993289806, 0.26481003023928],
+ [0.385667159435638, -0.264653128949093, -0.883866258814251],
+ [-0.26465693400637, -0.883897033692944, 0.385594010730408],
+ [0.264708980625706, 0.88380076582133, 0.385778902883153],
+ [-0.973352164893031, 0.228026239253951, 0.024281624940474],
+ [0.228112345886926, 0.024344549147457, -0.973330416960638],
+ [0.973355442809202, -0.228032180463189, 0.024093705130315],
+ [0.02435678809852, -0.973342833359315, 0.228058053182924],
+ [-0.228119646411613, 0.024031592823588, 0.973336483168797],
+ [0.024019836271027, 0.973350319061891, -0.228061842156086],
+ [-0.228243300844141, -0.024101340294447, -0.97330576953791],
+ [-0.024325527828938, 0.973285859331229, 0.228304412400888],
+ [0.228251100231509, -0.024273911824481, 0.973299651930403],
+ [-0.024050630999231, -0.973293331950969, -0.228301680082122],
+ [-0.973284068418876, -0.228330589256741, -0.024150862752081],
+ [0.973280766897442, 0.228336792285628, -0.024225153791815],
+ [0.176494164360597, -0.643915961757687, 0.744460908403072],
+ [-0.643955474734765, 0.744429879896195, 0.176480878502065],
+ [-0.176394545795758, 0.643730246363376, 0.744645106161623],
+ [0.744482176984942, 0.176540768465681, -0.643878595094843],
+ [0.643722173544997, 0.744675565408302, -0.176295392935637],
+ [0.7446232602985, -0.176308754645389, 0.643779017410342],
+ [0.643979646991902, -0.744473128172097, 0.176210032886433],
+ [-0.744568815199564, -0.17638194509082, -0.643821938798528],
+ [-0.643745300418696, -0.744632317149074, -0.176393595252333],
+ [-0.744536607308663, 0.176149174255325, 0.643922905933989],
+ [0.17619570702358, 0.643971118562494, -0.744483895919737],
+ [-0.176295790346183, -0.643784799068587, -0.744621331143847],
+ [-0.538268091669357, -0.714443097207928, 0.447033021534844],
+ [-0.714408571342234, 0.447040121797618, -0.538308018420605],
+ [0.538328211219438, 0.714331668550684, 0.447138685768607],
+ [0.447120197809599, -0.538252620284606, -0.714400199795228],
+ [0.714269338389924, 0.44713116789972, 0.538417153263762],
+ [0.447050760030329, 0.538390645652783, 0.714339646547694],
+ [0.714335378036911, -0.44690804426239, -0.538514779424324],
+ [-0.44721697306067, 0.538430741435079, -0.714205373603506],
+ [-0.714197054202002, -0.447264479166707, 0.53840231560137],
+ [-0.446955147518652, -0.538568616287532, 0.714265316011294],
+ [-0.53855313587161, 0.7142734853459, -0.446960745451628],
+ [0.538492423492547, -0.714162749879013, -0.44721077416177],
+ [-0.854262559171519, -0.121196786481334, 0.505516388403308],
+ [-0.121135225917774, 0.505562840819898, -0.854243800693903],
+ [0.854330941426722, 0.121071937454801, 0.505430735592792],
+ [0.50560092757097, -0.854220500665318, -0.12114057240441],
+ [0.120980483865727, 0.505385380088754, 0.854370727562784],
+ [0.505347417325379, 0.854380006035553, 0.121073502837149],
+ [0.121020187472529, -0.505348640639146, -0.854386836057462],
+ [-0.505616684568734, 0.854250931602746, -0.120859894677939],
+ [-0.120865731027902, -0.505598482368904, 0.854260879175297],
+ [-0.50533076030876, -0.85441039802093, 0.120928468275617],
+ [-0.854368086396689, 0.120928425096259, -0.505402304061426],
+ [0.854300186149896, -0.120804448918389, -0.50554671106217],
+ [0.744463297304691, 0.643945879165515, 0.176374895243003],
+ [0.643874092586096, 0.176352654639277, 0.744530653565125],
+ [-0.744439438351756, -0.643989673955645, 0.176315689786883],
+ [0.176272538481235, 0.744499406704124, 0.643932159155441],
+ [-0.643927227994187, 0.176337348613599, -0.744488323972679],
+ [0.176417156092134, -0.744445359352701, -0.643955040336351],
+ [-0.643773276436742, -0.176537581494637, 0.74457400630557],
+ [-0.176186289075531, -0.744659935538757, 0.643770123526409],
+ [0.643827610309512, -0.176153056179147, -0.744618096074686],
+ [-0.176503857743247, 0.744606426388638, -0.643745025598251],
+ [0.744642563797528, -0.643711305756388, -0.176474380640507],
+ [-0.744666089545738, 0.643755051780831, -0.176215346628265],
+ [-0.228336071531986, 0.973285242484899, -0.024051511354805],
+ [0.97330645157172, -0.02400727784605, -0.228250305452786],
+ [0.228333198106367, -0.973279105669999, -0.024325564920999],
+ [-0.024087022235214, -0.228306334617004, 0.973291340212985],
+ [-0.973298936142401, -0.024368394387641, 0.228244084828147],
+ [-0.024288551979359, 0.228299553239789, -0.97328792257649],
+ [-0.973337174231582, 0.024064219708824, -0.228113258248362],
+ [0.024218299052141, 0.228063963157171, 0.973344904286279],
+ [0.973329707418057, 0.024311563193311, 0.228118891266323],
+ [0.024157446706033, -0.228056100974521, -0.97334825862943],
+ [-0.228026586903388, -0.973357047279715, 0.0240818225246],
+ [0.228030283046768, 0.973350913461235, 0.024293811512186],
+ [0.714188221641478, 0.538577770578359, -0.447067298186109],
+ [0.538527478643698, -0.447091804432198, 0.714210804423472],
+ [-0.714248768779612, -0.538467220657919, -0.447103733571689],
+ [-0.447161335031565, 0.714165750281327, 0.538529499264337],
+ [-0.538389584312319, -0.447080411787951, -0.714321888856505],
+ [-0.447010969457209, -0.714305113088417, -0.538469496444014],
+ [-0.538432015415055, 0.446913681352591, 0.714394237235966],
+ [0.447232748299163, -0.714300498031757, 0.538291433482237],
+ [0.538293039926151, 0.447257878415203, -0.714283552493402],
+ [0.446938450981038, 0.714439434305571, -0.538351479745163],
+ [0.714416613789271, -0.538354159540915, 0.44697170027516],
+ [-0.714357130519037, 0.538243224338143, 0.447200314770336],
+ [-0.88382262479637, 0.264636203965444, 0.385778754532718],
+ [0.264703577640278, 0.38583706665855, -0.88377699335113],
+ [0.883874390523727, -0.264731681104579, 0.385594604209984],
+ [0.385842662030848, -0.883779863837697, 0.264685837233447],
+ [-0.26482277877838, 0.385535295796556, 0.883872972510847],
+ [0.385530295763018, 0.883900556837248, -0.264737977388365],
+ [-0.264823043649479, -0.385616295049092, -0.883837557781314],
+ [-0.385806659371906, 0.883713819704025, 0.264958688191976],
+ [0.264941401779006, -0.385755452567231, 0.883741356075422],
+ [-0.385565266715756, -0.883834666939721, -0.264906977291953],
+ [-0.883791950673863, -0.264913750483099, -0.38565851828926],
+ [0.883739852971877, 0.26500888802554, -0.385712537437807],
+ ]
+ ),
+ # degree 16 (144 points)
+ np.array(
+ [
+ [0.938311825813856, -0.17507925577492, -0.298191501782276],
+ [-0.175109632245629, -0.298282531121024, 0.938277223598034],
+ [-0.938311652301346, 0.175147761450008, -0.298151815044902],
+ [-0.298182757815715, 0.938327057553728, -0.175012502421904],
+ [0.175097712410131, -0.298058347845738, -0.938350687316958],
+ [-0.298185477757762, -0.938323612741539, 0.175026336949732],
+ [0.175121225661409, 0.298070999742225, 0.938342280532811],
+ [0.298159022282375, -0.938297484887434, -0.175211378870018],
+ [-0.175136638135111, 0.298288500226525, -0.938270285480331],
+ [0.298175056505462, 0.938292628074833, 0.175210101816042],
+ [0.938309721676758, 0.175091137054814, 0.298191146635404],
+ [-0.938307020714082, -0.175144295988174, 0.298168426332282],
+ [0.318319389865683, -0.189552295411868, 0.928839433561922],
+ [-0.189466106261457, 0.928833946336168, 0.318386706242113],
+ [-0.318293314473071, 0.18936285961738, 0.928887007853633],
+ [0.928852943553566, 0.318350700348959, -0.189433473386317],
+ [0.189441607397533, 0.928892798895752, -0.318229548512164],
+ [0.928866264406345, -0.318313837307129, 0.189430102746667],
+ [0.18945182591494, -0.928887156552102, 0.318239934719146],
+ [-0.928865750332054, -0.318289122686796, -0.189474146625178],
+ [-0.189481041982253, -0.928834132900175, -0.318377273511944],
+ [-0.928863874908086, 0.318277395441538, 0.189503038080361],
+ [0.318275484124591, 0.18957282380822, -0.92885028970154],
+ [-0.318345902583112, -0.189353418017315, -0.928870911049379],
+ [0.415270907116288, 0.626546860524453, 0.659537038588256],
+ [0.626612654947257, 0.659451415891007, 0.415307609777736],
+ [-0.415241828112963, -0.626676394380167, 0.659432271664102],
+ [0.659494217922308, 0.41519684716212, 0.626641009377521],
+ [-0.626618996427069, 0.659521812332477, -0.415186238180433],
+ [0.659478785687794, -0.415192215022902, -0.626660319321504],
+ [-0.626602233185435, -0.65952877581014, 0.415200475969626],
+ [-0.659472693683341, -0.415326178073293, 0.626577953724091],
+ [0.626606052873236, -0.65944383659479, -0.415329605108723],
+ [-0.659498633823103, 0.415315781516604, -0.626557542136963],
+ [0.415250963158486, -0.626542854390271, -0.659553401331872],
+ [-0.415267233073285, 0.626674158557439, -0.659418398387537],
+ [0.081476869754028, 0.884767493032223, 0.458855100188022],
+ [0.88480215017059, 0.458780629597686, 0.081519868495058],
+ [-0.08148097265168, -0.88484396510395, 0.458706887363658],
+ [0.458778051156021, 0.08139667888042, 0.884814828336823],
+ [-0.884809515892886, 0.45878451702782, -0.081417980329578],
+ [0.458732327572868, -0.081386172952098, -0.884839500978449],
+ [-0.884806469025575, -0.458784799888689, 0.081449492089205],
+ [-0.458770768146743, -0.081567624478124, 0.884802862185155],
+ [0.884821176813587, -0.458741101923224, -0.081535798692882],
+ [-0.458810899116744, 0.081573887356361, -0.884781475706435],
+ [0.081470600041761, -0.884777903494754, -0.458836139396478],
+ [-0.081497545017818, 0.88485010959699, -0.458692090298344],
+ [-0.722581612146772, 0.69116944690793, -0.012673178305347],
+ [0.691146231887784, -0.012722477090735, -0.722602950951623],
+ [0.722589739174094, -0.691157232223568, -0.012874361552029],
+ [-0.012719991090033, -0.722649829139429, 0.691097262526357],
+ [-0.6911640879369, -0.012832809701898, 0.722583920760425],
+ [-0.012740894622282, 0.722658126679523, -0.691088200990487],
+ [-0.691184825451665, 0.012806932405418, -0.722564543516851],
+ [0.01278690708865, 0.722509435119358, 0.69124280189425],
+ [0.691166758903022, 0.012679269543203, 0.722584076430794],
+ [0.012798402734516, -0.722517774658893, -0.691233872281593],
+ [-0.722587198973115, -0.691163495604889, 0.01267920436853],
+ [0.722578352800658, 0.691170335944389, 0.012809792129789],
+ [0.560117573995459, 0.806868022890413, 0.187702682288658],
+ [0.806883478716379, 0.18757144265397, 0.560139273462648],
+ [-0.560134093540899, -0.806891631206385, 0.18755184014617],
+ [0.187652131237362, 0.560025149416763, 0.806943932168034],
+ [-0.806885441512999, 0.18768574188158, -0.560098157976558],
+ [0.187630222901067, -0.560004720839195, -0.806963203679022],
+ [-0.806874677594158, -0.187697516958668, 0.560109718523856],
+ [-0.187614808802038, -0.560215760321792, 0.806820293129301],
+ [0.806892320702248, -0.18757613331143, -0.560124965524367],
+ [-0.187636487681617, 0.56022497423671, -0.806808853900342],
+ [0.56009182108872, -0.806880865199227, -0.18772432267788],
+ [-0.560129384097476, 0.806896245083186, -0.187546054987136],
+ [-0.099485634221032, -0.358895129517995, -0.928060824834181],
+ [-0.359050794288811, -0.927994608087772, -0.099541621850345],
+ [0.099434389660615, 0.359143761945999, -0.927970129049474],
+ [-0.928019026720099, -0.09942019096838, -0.359021324816913],
+ [0.358990815531993, -0.928035748444477, 0.099374262124424],
+ [-0.928007207203491, 0.099420259668564, 0.359051856067911],
+ [0.359002982562248, 0.928031348467288, -0.099371398165657],
+ [0.928017938922059, 0.099510949379702, -0.358998991631458],
+ [-0.359042863742385, 0.92799619207621, 0.099555459356689],
+ [0.928013665632084, -0.099489549105096, 0.359015969030581],
+ [-0.099451875312545, 0.358926751348054, 0.928052213867059],
+ [0.099465503317397, -0.359120063291987, 0.927975966170987],
+ [0.787833199437607, 0.557450082325166, -0.261855409681697],
+ [0.557405388687852, -0.261977292048617, 0.787824302184578],
+ [-0.787861477876718, -0.557364111687839, -0.26195331393273],
+ [-0.261861028070608, 0.787802657602316, 0.557490604990374],
+ [-0.557427204478003, -0.261835304855293, -0.787856068605919],
+ [-0.261850091868655, -0.787804146924511, -0.557493637162722],
+ [-0.557398047481063, 0.261806624190095, 0.787886227950765],
+ [0.26192814680606, -0.787893374500188, 0.557330849971047],
+ [0.557399834363592, 0.261935778537884, -0.787842035292097],
+ [0.261909535328364, 0.787908371337394, -0.55731839524686],
+ [0.787858967733566, -0.557444321449493, 0.261790136264747],
+ [-0.787856023927293, 0.557369329488324, 0.261958615256708],
+ [-0.507282732168614, -0.717049946047353, -0.478020506377115],
+ [-0.71706431400176, -0.477906271006066, -0.507370048109131],
+ [0.507331753192767, 0.71711626280308, -0.477868975583995],
+ [-0.477891616916408, -0.50725750267016, -0.717153699332196],
+ [0.717108744361459, -0.47798851765986, 0.507229756368514],
+ [-0.477913676926975, 0.507235340412842, 0.717154674280526],
+ [0.717103637758922, 0.477942092943937, -0.507280719627002],
+ [0.477949791330649, 0.507362311781387, -0.7170407809538],
+ [-0.717073605236621, 0.477889924387354, 0.507372313830785],
+ [0.477966885504482, -0.507396895429057, 0.717004914118516],
+ [-0.507289494490155, 0.717039874321013, 0.478028437871252],
+ [0.507342973335893, -0.717147616692481, 0.47781000751239],
+ [-0.469705390085658, -0.33624876406351, 0.816280353304085],
+ [-0.336180458859188, 0.816354017519737, -0.4696262526314],
+ [0.469729267279509, 0.336087427571651, 0.816333054879763],
+ [0.816299320102214, -0.469688480988201, -0.336226338688182],
+ [0.336166188592078, 0.816261044646191, 0.469798042397566],
+ [0.816308187841148, 0.469684487990511, 0.336210386818421],
+ [0.336161196424763, -0.816254520485116, -0.469812949806501],
+ [-0.81631474754769, 0.469749196906201, -0.336104038866138],
+ [-0.336166711539314, -0.816355082377068, 0.469634242288587],
+ [-0.816302029136435, -0.469752338316787, 0.33613053695499],
+ [-0.469725914764869, 0.336254309274991, -0.816266258332602],
+ [0.469715709020586, -0.336082571137018, -0.816342855715183],
+ [0.220975783117544, 0.56198189964132, -0.797085972622227],
+ [0.56189854338611, -0.797188442616427, 0.220818056099052],
+ [-0.22090980871236, -0.561819935638318, -0.79721842448229],
+ [-0.7971433029262, 0.220906560624346, 0.561927794358875],
+ [-0.561911046458035, -0.797113560704263, -0.221056434445611],
+ [-0.797166608166814, -0.22092145416411, -0.561888876837612],
+ [-0.561903189214556, 0.797117195899141, 0.221063298519679],
+ [0.797149071206196, -0.221019917708182, 0.56187503437274],
+ [0.56187154518738, 0.797190222992272, -0.220880318440273],
+ [0.797151311779493, 0.220966501329483, -0.561892864715723],
+ [0.220989674227739, -0.56195922843892, 0.79709810529009],
+ [-0.220934514736207, 0.561821479644177, 0.797210489901321],
+ [-0.025586321091663, 0.991400659992677, -0.128335776535923],
+ [0.991391023154192, -0.128410509654448, -0.025584765380375],
+ [0.0255553186148, -0.991378867053065, -0.128510185009118],
+ [-0.128427355734578, -0.025687167640031, 0.991386193023514],
+ [-0.99138841235829, -0.128432289728374, 0.02557660643704],
+ [-0.128471046150121, 0.025696657527584, -0.991380286314492],
+ [-0.991388770492313, 0.128433611757029, -0.025556077805202],
+ [0.128434643068809, 0.02546732329508, 0.991390920829907],
+ [0.991386054149539, 0.128448345934336, 0.025587380962899],
+ [0.128392989158359, -0.02544975216483, -0.991396767419448],
+ [-0.025589705051665, -0.991398220731893, 0.128353943940207],
+ [0.025571746935955, 0.991376476866512, 0.128525354986419],
+ ]
+ ),
+ # degree 17 (156 points)
+ np.array(
+ [
+ [-0.053895316433783, -0.14060350667641, -0.988597971258691],
+ [-0.140602010826056, -0.988598302765153, -0.053893137981829],
+ [0.05389273741214, 0.140602992486377, -0.988598184986247],
+ [-0.988598098647216, -0.053895299659618, -0.140602617421263],
+ [0.140604478516356, -0.988597884351918, 0.053894375180116],
+ [-0.98859806586327, 0.053892813420617, 0.140603800919506],
+ [0.14060181241573, 0.988598276619905, -0.053894135205652],
+ [0.988598002509875, 0.053895652551635, -0.140603158106489],
+ [-0.140604846635692, 0.988597850035353, 0.053894044272357],
+ [0.988598167360928, -0.0538924507278, 0.140603226297146],
+ [-0.053892835910884, 0.140602751584219, 0.988598213878838],
+ [0.053896097450443, -0.140603793852246, 0.988597887836084],
+ [-0.712137820619482, 0.484725955627139, 0.50783902211694],
+ [0.484727142303201, 0.507838589157962, -0.71213732164283],
+ [0.712137878427749, -0.484726895412166, 0.507838044038163],
+ [0.507839760738435, -0.712137969376798, 0.484724963236905],
+ [-0.484727642870466, 0.507838390853387, 0.712137122338588],
+ [0.507839607814555, 0.712136191976364, -0.48472773472555],
+ [-0.484726768067281, -0.507839236872381, -0.712137114474401],
+ [-0.507840112632748, 0.712136257912531, 0.48472710896699],
+ [0.48472854912246, -0.507840095411427, 0.712135289926112],
+ [-0.507838861015403, -0.712137466171353, -0.484726645149226],
+ [-0.712136671868904, -0.484727632936327, -0.507839032024349],
+ [0.712137765857364, 0.484726401555971, -0.507838673275561],
+ [-0.703005448039525, 0.261790111709517, 0.66124827216248],
+ [0.26179085361446, 0.661247136036676, -0.7030062404041],
+ [0.703006433944545, -0.261790569085573, 0.661247042919986],
+ [0.661249487589413, -0.70300443378585, 0.261789765346499],
+ [-0.261791711051733, 0.661247232073399, 0.703005830772316],
+ [0.661247423359215, 0.703005908959219, -0.261791017930756],
+ [-0.261791042135151, -0.661248438610085, -0.703004944999334],
+ [-0.661249044674904, 0.703004424763402, 0.261790908321135],
+ [0.26179224626734, -0.661248916662998, 0.703004046934519],
+ [-0.661247254530942, -0.703006260525483, -0.261790500280794],
+ [-0.70300546948428, -0.261791326803081, -0.66124776830313],
+ [0.703005158463527, 0.261791025358218, -0.661248218308045],
+ [0.062800447246381, 0.786218819998244, -0.614748786827777],
+ [0.786220043108977, -0.614747449388309, 0.062798226760693],
+ [-0.062799502252198, -0.786219239565021, -0.614748346768559],
+ [-0.61474770709614, 0.062799571514381, 0.786219734194995],
+ [-0.786218534519124, -0.614749234835089, -0.062799635733612],
+ [-0.61474933069654, -0.062799956628617, -0.786218433932708],
+ [-0.786219538571286, 0.614747943528109, 0.062799706183351],
+ [0.614749395150454, -0.062799770458141, 0.786218398406292],
+ [0.786217798458002, 0.614750179051967, -0.062799607828601],
+ [0.614747800214019, 0.062800058129605, -0.786219622517107],
+ [0.062800526363459, -0.786218909415802, 0.614748664386918],
+ [-0.062801412397757, 0.786218712810953, 0.614748825315458],
+ [0.829543607739232, 0.321465368220585, 0.456637076783941],
+ [0.321463595502047, 0.456637380632479, 0.829544127443504],
+ [-0.82954344503853, -0.321464459743646, 0.456638011903666],
+ [0.456635000556537, 0.829545159774775, 0.321464312422039],
+ [-0.32146420025779, 0.456637459573867, -0.829543849634571],
+ [0.456637068558535, -0.829544000114897, -0.321464367374743],
+ [-0.321462954195433, -0.456636421102337, 0.829544904150941],
+ [-0.456636713034899, -0.829544190971737, 0.321464379883261],
+ [0.321462955106589, -0.456636688799517, -0.82954475643955],
+ [-0.456637112396323, 0.829544098578271, -0.321464051017078],
+ [0.829544701976578, -0.321462298506758, -0.45663724997129],
+ [-0.829544861446795, 0.321463589390512, -0.456636051515194],
+ [-0.249500423448462, 0.954025094362385, -0.166089307379737],
+ [0.954025470855406, -0.166087567010738, -0.249500142371853],
+ [0.249500943484422, -0.954025029612664, -0.166088898102612],
+ [-0.166086877408662, -0.249500137449683, 0.954025592196158],
+ [-0.954024855383494, -0.166089937003122, 0.249500918108135],
+ [-0.166090151998567, 0.249500107118379, -0.954025030047436],
+ [-0.954025593688894, 0.166087862658579, -0.249499475879328],
+ [0.166089692874499, 0.249499687822531, 0.954025219633797],
+ [0.954024931419817, 0.166090647913648, 0.249500154118264],
+ [0.166087956122076, -0.249500002352048, -0.954025439732882],
+ [-0.249499759982538, -0.954025225930409, 0.166089548307795],
+ [0.249498374708179, 0.954025720257113, 0.166088789826025],
+ [0.860787215766444, 0.418630333044569, -0.289471956203095],
+ [0.418631425510959, -0.289473932939102, 0.860786019707239],
+ [-0.860786771736426, -0.418630687137034, -0.289472764506019],
+ [-0.289474446503673, 0.860786651964262, 0.418629770347917],
+ [-0.418629889262302, -0.289472838226515, -0.860787134978979],
+ [-0.289472399693171, -0.860787556030986, -0.418629326729602],
+ [-0.418629257388446, 0.289472594933548, 0.86078752409688],
+ [0.289473185156189, -0.860787817544824, 0.418628245872098],
+ [0.418628942424652, 0.289472756316209, -0.860787623002977],
+ [0.289473542762772, 0.860786603163078, -0.418630495610795],
+ [0.860788012261758, -0.418628676927689, 0.289471982755196],
+ [-0.860787839361109, 0.418628479474405, 0.289472782452827],
+ [-0.16910412959425, -0.878917391692094, 0.445991044680649],
+ [-0.878918175066698, 0.445989574417742, -0.169103935637546],
+ [0.169102333488308, 0.878918206143462, 0.445990120651083],
+ [0.445989469034572, -0.169103982598388, -0.878918219506016],
+ [0.878916540171611, 0.445992443170219, 0.169104867014003],
+ [0.445990224108014, 0.169104198005711, 0.87891779491425],
+ [0.878918285318947, -0.445988968934561, -0.169104959479873],
+ [-0.445991404870479, 0.169104651147508, -0.87891710857278],
+ [-0.878917501870672, -0.445990557689425, 0.169104841318319],
+ [-0.4459886381463, -0.169104239878744, 0.878918591622365],
+ [-0.16910368279187, 0.878918081785917, -0.445989854117772],
+ [0.169104218306708, -0.878917936541782, -0.445989937303537],
+ [0.699159749436449, 0.682605593469953, 0.2126622875159],
+ [0.682603600110598, 0.212662432840561, 0.699161651389995],
+ [-0.699161274801242, -0.682604056820824, 0.21266220498729],
+ [0.212660843412531, 0.699162101347243, 0.68260363441662],
+ [-0.682604762820295, 0.212661985195386, -0.699160652373835],
+ [0.212662594223091, -0.699161699678606, -0.682603500372528],
+ [-0.682602562402764, -0.212662368159073, 0.69916268419457],
+ [-0.212661230546804, -0.699160950060591, 0.682604693019826],
+ [0.682603395227417, -0.21266176997579, -0.699162053042617],
+ [-0.212661876797938, 0.699162212981133, -0.682603198129121],
+ [0.699162077611852, -0.682602648532129, -0.212664085934605],
+ [-0.69916198280867, 0.682603270588311, -0.212662400948523],
+ [-0.893254372981228, -0.172342415041176, -0.415204428116666],
+ [-0.17234169138316, -0.415205105182347, -0.893254197886418],
+ [0.893254479512939, 0.172341865725213, -0.415204426937409],
+ [-0.415203760144359, -0.893254961141014, -0.172340975855865],
+ [0.172343621116966, -0.415205895670259, 0.893253458129858],
+ [-0.41520633444977, 0.89325370222595, 0.172341298859036],
+ [0.172340599563611, 0.415204277853847, -0.893254793098767],
+ [0.415204013461881, 0.893254652987798, -0.172341962739188],
+ [-0.17234119462921, 0.415204328749048, 0.893254654632054],
+ [0.415206142771325, -0.893254015914866, 0.172340134782712],
+ [-0.89325441464337, 0.172340274858685, 0.415205226823752],
+ [0.893254705389659, -0.172340550786628, 0.415204486793911],
+ [-0.030119107290242, 0.538031004327585, -0.842386774444073],
+ [0.538032715913301, -0.84238566918646, -0.030119444819523],
+ [0.030118087641353, -0.538031590262412, -0.842386436664628],
+ [-0.842386183209587, -0.030119347292783, 0.538031916577671],
+ [-0.538030304105545, -0.842387233636645, 0.030118772718924],
+ [-0.842387312723823, 0.030117901022641, -0.538030229076328],
+ [-0.538031723308682, 0.842386324999934, -0.030118834084264],
+ [0.842387103098144, 0.030119789658303, 0.53803045155907],
+ [0.538029173032331, 0.842387968746045, 0.030118417588877],
+ [0.842386330532407, -0.030117441179441, -0.538031792619125],
+ [-0.030117059116644, -0.538030739137179, 0.842387017049566],
+ [0.030118346812524, 0.538030710181824, 0.84238698950454],
+ [0.951905881051384, -0.301774121097739, 0.052986540323701],
+ [-0.301774405343499, 0.052986798530194, 0.951905776566724],
+ [-0.951905698855431, 0.3017745999, 0.05298708655653],
+ [0.052987612958423, 0.951905238066977, -0.301775960960479],
+ [0.301774562398047, 0.052986834212903, -0.951905724790833],
+ [0.052986766829206, -0.951905252173644, 0.301776065030379],
+ [0.301777293645336, -0.052987994727859, 0.951904794322844],
+ [-0.052986574701187, -0.951905591847301, -0.301775027315507],
+ [-0.301776941841401, -0.052986526316734, -0.951904987591586],
+ [-0.052988794896112, 0.951905240176443, 0.301775746772478],
+ [0.95190486181224, 0.301777462059556, -0.052985823114433],
+ [-0.951905018594348, -0.301776824324304, -0.052986638651951],
+ [0.553606146300219, 0.45440115669048, 0.697882385203248],
+ [0.454399619298559, 0.697882738116233, 0.55360696330584],
+ [-0.553605814018882, -0.454401128197097, 0.697882667342941],
+ [0.697880969772289, 0.55360789125326, 0.454401204632875],
+ [-0.454400796454347, 0.697882268834231, -0.553606588678677],
+ [0.697882520653254, -0.553605103541103, -0.454402219074583],
+ [-0.454400180900896, -0.697882141032518, 0.553607255032936],
+ [-0.697881635390884, -0.553607048847703, 0.454401208680482],
+ [0.454401271818775, -0.69788172578063, -0.553606883077631],
+ [-0.69788232191633, 0.553606343141444, -0.454401014072625],
+ [0.553606029480292, -0.454400306845994, -0.697883031217504],
+ [-0.553606960810672, 0.454400842716203, -0.697881943512493],
+ ]
+ ),
+ # degree 18 (180 points)
+ np.array(
+ [
+ [-0.866376343641697, 0.223696804580225, 0.446488265017841],
+ [0.223696806212017, 0.446488265347841, -0.866376343050305],
+ [0.866376343115579, -0.223696806225293, 0.44648826521453],
+ [0.44648826367979, -0.866376344067145, 0.223696805603153],
+ [-0.223696804286002, 0.446488265023544, 0.866376343714725],
+ [0.446488262849567, 0.866376344947941, -0.22369680384892],
+ [-0.2236968055886, -0.446488263582537, -0.866376344121022],
+ [-0.446488264810465, 0.866376343829741, 0.223696804265844],
+ [0.223696803801399, -0.446488262808774, 0.866376344981234],
+ [-0.446488265014924, -0.866376343219064, -0.223696806222901],
+ [-0.8663763449212, -0.223696804074408, -0.446488262788483],
+ [0.866376344172558, 0.223696805482214, -0.446488263535836],
+ [-0.806844783933568, -0.461758079243128, -0.368484695601989],
+ [-0.461758081774945, -0.368484698390835, -0.806844781210945],
+ [0.80684478133506, 0.461758081613586, -0.368484698321273],
+ [-0.368484697968357, -0.806844781706494, -0.461758081246195],
+ [0.461758078945765, -0.368484695716793, 0.806844784051319],
+ [-0.368484697702554, 0.806844783105505, 0.461758079013772],
+ [0.461758081217295, 0.368484698443883, -0.806844781505862],
+ [0.368484695481328, 0.806844784182969, -0.461758078903629],
+ [-0.461758078967836, 0.36848469793606, 0.806844783025151],
+ [0.36848469846061, -0.806844781213308, 0.461758081715136],
+ [-0.806844782774103, 0.461758079314706, 0.368484698051091],
+ [0.806844781709987, -0.461758081098712, 0.368484698145526],
+ [-0.134842418858112, -0.040021669507572, 0.990058477084218],
+ [-0.040021669975618, 0.990058477016276, -0.134842419218046],
+ [0.134842418981357, 0.040021669788942, 0.990058477056058],
+ [0.990058476924117, -0.134842420436143, -0.040021668151402],
+ [0.040021669677461, 0.990058477116921, 0.13484241856757],
+ [0.990058477286021, 0.134842417855397, 0.040021667893743],
+ [0.040021668037836, -0.990058476927568, -0.134842420444507],
+ [-0.990058477115239, 0.134842418635191, -0.040021669491235],
+ [-0.040021667837798, -0.990058477270892, 0.134842417983082],
+ [-0.990058477042031, -0.134842419087429, 0.040021669778575],
+ [-0.134842418122745, 0.040021667670891, -0.990058477258617],
+ [0.134842420378113, -0.040021667867212, -0.990058476943508],
+ [0.049794077313207, -0.279738156561879, -0.958784185115654],
+ [-0.279738157129975, -0.958784185068512, 0.049794075029415],
+ [-0.049794075085005, 0.279738157111834, -0.958784185070918],
+ [-0.958784184460955, 0.049794077429761, -0.279738158785068],
+ [0.279738156684233, -0.958784185083851, -0.049794077238191],
+ [-0.958784184306963, -0.049794076856858, 0.279738159414846],
+ [0.279738159012938, 0.958784184390379, 0.049794077508567],
+ [0.958784185034086, -0.049794077337113, -0.279738156837192],
+ [-0.279738159575992, 0.95878418424315, -0.049794077180261],
+ [0.958784185016722, 0.049794074909289, 0.279738157328865],
+ [0.049794077031905, 0.279738159517178, 0.958784184268015],
+ [-0.049794077785621, -0.279738158888691, 0.958784184412241],
+ [0.205470768670777, -0.192901743072287, 0.959463746444603],
+ [-0.192901744385898, 0.959463746331714, 0.205470767964668],
+ [-0.205470768086678, 0.19290174463045, 0.959463746256418],
+ [0.959463745738735, 0.205470770340502, -0.192901744804646],
+ [0.19290174290288, 0.959463746447685, -0.205470768815433],
+ [0.95946374536634, -0.205470771694041, 0.192901745215149],
+ [0.192901744892626, -0.959463745685675, 0.205470770505673],
+ [-0.959463746372533, -0.205470769064203, -0.192901743011692],
+ [-0.192901745122065, -0.959463745348892, -0.205470771862908],
+ [-0.959463746220563, 0.205470768260652, 0.192901744623478],
+ [0.205470771726444, 0.192901745460598, -0.959463745310053],
+ [-0.205470770652743, -0.192901744949698, -0.959463745642705],
+ [-0.278905392074019, 0.772004854137857, -0.571156972696319],
+ [0.772004854268172, -0.571156972466399, -0.278905392184152],
+ [0.278905392160238, -0.772004854249339, -0.571156972503532],
+ [-0.571156971675365, -0.278905392078835, 0.772004854891456],
+ [-0.772004854266013, -0.571156972533567, 0.278905392052577],
+ [-0.571156970567139, 0.278905391582234, -0.77200485589077],
+ [-0.772004855078365, 0.571156971421921, -0.27890539208049],
+ [0.57115697255647, 0.278905391952095, 0.77200485428537],
+ [0.772004855995114, 0.571156970376128, 0.278905391684575],
+ [0.571156972406906, -0.278905392034109, -0.772004854366394],
+ [-0.278905391724262, -0.772004855939801, 0.571156970431511],
+ [0.278905392185009, 0.772004854970249, 0.571156971517017],
+ [0.912363859945553, -0.393198149041577, -0.113962286110494],
+ [-0.393198146993911, -0.113962287143254, 0.912363860699027],
+ [-0.912363860783175, 0.393198146824102, -0.113962287055462],
+ [-0.113962288369946, 0.912363860162702, -0.393198147882843],
+ [0.393198149035756, -0.113962285681562, -0.912363860001638],
+ [-0.113962285259825, -0.912363861029922, 0.393198146771995],
+ [0.393198147677495, 0.113962288066404, 0.912363860289116],
+ [0.113962286031302, -0.912363859929225, -0.393198149102416],
+ [-0.39319814691232, 0.113962285090844, -0.912363860990554],
+ [0.113962287370876, 0.912363860762456, 0.393198146780759],
+ [0.9123638610767, 0.393198146641003, 0.113962285337288],
+ [-0.912363860199082, -0.39319814784926, 0.113962288194565],
+ [0.848662336981788, -0.012909984472825, -0.528777429633226],
+ [-0.01290998542466, -0.528777432035493, 0.848662335470524],
+ [-0.848662335552813, 0.012909985214183, -0.528777431908562],
+ [-0.52877743352546, 0.848662334565693, -0.012909983878239],
+ [0.012909984639682, -0.528777429281808, -0.848662337198209],
+ [-0.528777430386149, -0.848662336521067, 0.012909983920382],
+ [0.012909983944448, 0.52877743344648, 0.848662334613896],
+ [0.528777429496827, -0.848662337067122, -0.012909984449961],
+ [-0.012909983871647, 0.528777430419671, -0.848662336500922],
+ [0.528777432240356, 0.848662335344326, 0.012909985329594],
+ [0.848662336343559, 0.012909983743557, 0.528777430675359],
+ [-0.848662334668199, -0.012909983655303, 0.528777433366386],
+ [-0.69585113208617, 0.211164782101034, 0.686440555892948],
+ [0.211164781099711, 0.68644055554441, -0.695851132733858],
+ [0.695851132741401, -0.211164781335084, 0.686440555464357],
+ [0.686440553889191, -0.695851134384757, 0.211164781040182],
+ [-0.211164781930503, 0.686440555960218, 0.695851132071559],
+ [0.686440553598939, 0.695851134525998, -0.21116478151828],
+ [-0.21116478087036, -0.686440553737906, -0.69585113458553],
+ [-0.686440555776475, 0.695851132224679, 0.211164782023223],
+ [0.211164781498505, -0.686440553499572, 0.695851134630023],
+ [-0.686440555292332, -0.695851132981083, -0.211164781104467],
+ [-0.695851134744882, -0.211164781531153, -0.686440553373094],
+ [0.695851134495813, 0.211164781101486, -0.686440553757753],
+ [-0.261718169263029, -0.581630098396244, 0.770201290908541],
+ [-0.581630098290833, 0.770201291506502, -0.261718167737572],
+ [0.261718167857864, 0.581630098126426, 0.770201291589781],
+ [0.770201292726794, -0.261718168321791, -0.581630096412025],
+ [0.581630098450626, 0.770201290888376, 0.261718169201518],
+ [0.770201293263127, 0.261718168077775, 0.581630095811608],
+ [0.581630096213803, -0.770201292881278, -0.261718168307686],
+ [-0.770201291051568, 0.26171816913029, -0.581630098266577],
+ [-0.581630095716607, -0.770201293304276, 0.261718168167806],
+ [-0.770201291705965, -0.261718167641045, 0.581630098070137],
+ [-0.261718168076348, 0.581630095637746, -0.770201293394907],
+ [0.261718168494542, -0.581630096129926, -0.770201292881124],
+ [0.506136437086844, 0.700992881596967, 0.502428987025446],
+ [0.700992883568509, 0.502428985302136, 0.506136436066968],
+ [-0.506136436123196, -0.700992883503112, 0.502428985336736],
+ [0.502428986281426, 0.506136435764488, 0.700992883085013],
+ [-0.700992881635171, 0.502428986938925, -0.50613643711982],
+ [0.502428986199081, -0.506136436342322, -0.700992882726821],
+ [-0.700992883178434, -0.502428986124795, 0.506136435790584],
+ [-0.502428987099143, -0.506136437034413, 0.700992881582003],
+ [0.700992882671006, -0.502428986197914, -0.506136436420782],
+ [-0.502428985277898, 0.506136435955935, -0.700992883666051],
+ [0.506136436300189, -0.700992882789867, -0.502428986153563],
+ [-0.506136435852246, 0.700992882991532, -0.502428986323445],
+ [-0.440748149182578, 0.602242024157979, 0.665616716534547],
+ [0.602242022260099, 0.66561671834234, -0.440748149045733],
+ [0.440748149100016, -0.602242022337998, 0.665616718235914],
+ [0.665616715634027, -0.440748149390786, 0.602242025000887],
+ [-0.602242023804998, 0.665616716814167, 0.440748149242614],
+ [0.665616716586012, 0.440748149783209, -0.602242023661529],
+ [-0.602242024940208, -0.665616715760932, -0.440748149282046],
+ [-0.665616716462371, 0.440748149424008, 0.602242024061062],
+ [0.602242023852026, -0.665616716460744, 0.440748149712092],
+ [-0.665616718266293, -0.44074814917988, -0.602242022245974],
+ [-0.440748149655782, -0.602242023883194, -0.66561671646983],
+ [0.44074814928254, 0.602242025306933, -0.665616715428797],
+ [-0.89025783677553, -0.293518547758229, 0.348264046639405],
+ [-0.293518546899673, 0.348264043649922, -0.890257838228066],
+ [0.890257838178446, 0.293518546762444, 0.348264043892422],
+ [0.34826404446276, -0.890257837322353, -0.293518548682307],
+ [0.293518547908785, 0.348264046686625, 0.89025783670742],
+ [0.348264047178787, 0.890257836270502, 0.293518548650024],
+ [0.293518548932545, -0.348264044336184, -0.890257837289365],
+ [-0.348264046901224, 0.890257836626627, -0.29351854789921],
+ [-0.2935185489462, -0.348264047080228, 0.890257836211408],
+ [-0.348264043786589, -0.890257838192766, 0.293518546844585],
+ [-0.890257836357219, 0.293518548692058, -0.348264046921688],
+ [0.890257837186443, -0.293518548811097, -0.348264044701638],
+ [0.661971946522154, 0.031389655564508, 0.748871037990662],
+ [0.03138965429721, 0.748871040172752, 0.661971944113708],
+ [-0.661971944196008, -0.031389654112142, 0.748871040107759],
+ [0.748871039164329, 0.661971945218693, 0.031389655052549],
+ [-0.031389655768972, 0.748871037783183, -0.661971946747175],
+ [0.748871037422933, -0.661971947171443, -0.031389655416215],
+ [-0.031389655026768, -0.748871039044161, 0.661971945355858],
+ [-0.748871037767735, -0.661971946761125, 0.031389655843332],
+ [0.03138965553856, -0.748871037222178, -0.661971947392751],
+ [-0.748871040238931, 0.661971944045087, -0.031389654165497],
+ [0.66197194707148, -0.031389655223358, -0.748871037519379],
+ [-0.661971945551351, 0.031389654961479, -0.74887103887409],
+ [-0.125732546862956, -0.877697090664539, -0.462427446956124],
+ [-0.877697091831705, -0.462427445382079, -0.125732544504481],
+ [0.125732544403638, 0.877697091976424, -0.46242744513482],
+ [-0.462427446167756, -0.125732547895101, -0.877697090932044],
+ [0.877697090790478, -0.462427446687307, 0.125732546972493],
+ [-0.462427443232932, 0.125732547131528, 0.877697092587683],
+ [0.87769709111192, 0.462427445862905, -0.12573254776065],
+ [0.462427446678181, 0.125732547366796, -0.877697090738801],
+ [-0.87769709250851, 0.462427443357225, 0.125732547227075],
+ [0.462427444949274, -0.125732544734265, 0.877697092026818],
+ [-0.125732546895942, 0.877697092616795, 0.462427443241732],
+ [0.125732547889573, -0.877697091021935, 0.462427445998644],
+ ]
+ ),
+ # degree 19 (204 points)
+ np.array(
+ [
+ [0.553035945587524, -0.472050222255944, 0.686527370580538],
+ [-0.472050227459673, 0.686527365766638, 0.553035947121696],
+ [-0.55303594558747, 0.472050222505474, 0.686527370409006],
+ [0.686527372366403, 0.553035941501725, -0.472050224445432],
+ [0.472050228567412, 0.686527364805305, -0.553035947369552],
+ [0.68652737203169, -0.553035941518164, 0.472050224912964],
+ [0.472050228340927, -0.686527365268236, 0.553035946988198],
+ [-0.686527371732145, -0.553035942965273, -0.47205022365323],
+ [-0.472050227580466, -0.686527365608527, -0.553035947214868],
+ [-0.686527371021655, 0.553035942983048, 0.472050224665708],
+ [0.553035946644886, 0.472050221691609, -0.686527370116806],
+ [-0.553035947212832, -0.472050222465287, -0.68652736912732],
+ [0.534151654424436, 0.792082393152326, 0.29544456761586],
+ [0.792082397489039, 0.295444568376044, 0.534151647573148],
+ [-0.53415165460592, -0.792082392760173, 0.295444568339099],
+ [0.295444567949351, 0.534151645341887, 0.792082399152876],
+ [-0.792082397600766, 0.29544456929757, -0.534151646897765],
+ [0.295444567829592, -0.534151645364488, -0.792082399182305],
+ [-0.792082397865911, -0.295444567933543, 0.534151647259045],
+ [-0.295444567962042, -0.53415164560035, 0.792082398973845],
+ [0.792082397128777, -0.29544456908432, -0.534151647715618],
+ [-0.295444567489444, 0.53415164476261, -0.792082399715064],
+ [0.534151654464793, -0.792082393125927, -0.29544456761367],
+ [-0.534151654460867, 0.792082392713663, -0.295444568726043],
+ [-0.987783901989363, -0.008366313346394, -0.155605166275604],
+ [-0.008366316491905, -0.155605166254194, -0.987783901966094],
+ [0.987783902042018, 0.008366312354305, -0.155605165994688],
+ [-0.155605167507252, -0.98778390181532, -0.008366310987655],
+ [0.008366315777747, -0.155605166766477, 0.987783901891443],
+ [-0.155605168424393, 0.987783901667492, 0.008366311383278],
+ [0.008366317026602, 0.155605166706053, -0.987783901890384],
+ [0.155605166835858, 0.987783901919836, -0.008366311135093],
+ [-0.008366315838957, 0.155605165685948, 0.98778390206114],
+ [0.155605167761508, -0.987783901773982, 0.008366311139443],
+ [-0.98778390211314, 0.008366313179595, 0.155605165498836],
+ [0.987783902165643, -0.008366312162939, 0.155605165220208],
+ [0.950764981387945, 0.202727494112491, -0.234408859255043],
+ [0.202727496789996, -0.234408860732757, 0.950764980452705],
+ [-0.950764980986237, -0.202727494847485, -0.234408860248721],
+ [-0.23440885021567, 0.950764983233011, 0.202727495911382],
+ [-0.20272749729896, -0.234408861592541, -0.950764980132203],
+ [-0.23440885011577, -0.950764983322899, -0.20272749560533],
+ [-0.202727496759051, 0.234408860491485, 0.950764980518789],
+ [0.234408850569327, -0.950764983253747, 0.20272749540521],
+ [0.202727497565203, 0.234408861679376, -0.950764980054025],
+ [0.234408850341461, 0.950764983305224, -0.202727495427267],
+ [0.950764981380539, -0.202727493695606, 0.234408859645621],
+ [-0.950764980970666, 0.202727494426432, 0.234408860676023],
+ [0.512072989115983, -0.124051607170076, -0.849936734455185],
+ [-0.12405160965336, -0.849936734716267, 0.512072988081055],
+ [-0.51207298893537, 0.124051606674421, -0.849936734636344],
+ [-0.849936734725645, 0.512072989351902, -0.124051604343177],
+ [0.124051609706947, -0.849936734419284, -0.512072988561004],
+ [-0.849936734185619, -0.512072990133951, 0.124051604814925],
+ [0.124051609905272, 0.849936734159209, 0.512072988944631],
+ [0.849936734486865, -0.512072989718667, -0.124051604465195],
+ [-0.124051609776913, 0.849936734911909, -0.512072987726399],
+ [0.849936733865973, 0.512072990727649, 0.124051604554246],
+ [0.512072989657044, 0.124051606837459, 0.849936734177751],
+ [-0.512072989396574, -0.124051606970032, 0.84993673431533],
+ [0.391883697914976, 0.850423194793585, -0.351009340424947],
+ [0.850423195330397, -0.351009335923244, 0.39188370078221],
+ [-0.391883697466306, -0.850423195007668, -0.351009340407185],
+ [-0.351009335872243, 0.391883705326061, 0.850423193257595],
+ [-0.850423194593128, -0.351009337444654, -0.391883701019427],
+ [-0.35100933799945, -0.391883705264188, -0.850423192408108],
+ [-0.850423194468673, 0.351009337760337, 0.391883701006749],
+ [0.351009335527498, -0.391883705866852, 0.850423193150685],
+ [0.850423195361416, 0.351009336170377, -0.391883700493539],
+ [0.351009336873483, 0.39188370616407, -0.850423192458173],
+ [0.391883698323181, -0.850423194786778, 0.3510093399857],
+ [-0.391883698167036, 0.850423194811902, 0.351009340099156],
+ [-0.637143378120116, -0.628499374133282, 0.446135464216598],
+ [-0.628499375204954, 0.446135468086576, -0.637143374353178],
+ [0.63714337757707, 0.628499374826823, 0.446135464015109],
+ [0.446135466991108, -0.63714337381196, -0.628499376531226],
+ [0.628499375911897, 0.446135468292267, 0.637143373511799],
+ [0.446135467311664, 0.637143373480954, 0.628499376639239],
+ [0.62849937527089, -0.446135468666392, -0.637143373882143],
+ [-0.446135467006437, 0.637143373286424, -0.628499377053108],
+ [-0.628499376195147, -0.446135467887251, 0.637143373515989],
+ [-0.446135467633094, -0.637143373382135, 0.628499376511253],
+ [-0.637143377816856, 0.628499373935058, -0.446135464928946],
+ [0.637143377542956, -0.62849937478419, -0.446135464123887],
+ [-0.420378708184596, 0.903565957647232, -0.082766550526719],
+ [0.903565960547129, -0.082766548074817, -0.420378702434272],
+ [0.42037870768752, -0.903565957904322, -0.082766550244743],
+ [-0.08276654593922, -0.420378701283585, 0.9035659612781],
+ [-0.903565960760554, -0.082766547146253, 0.420378702158358],
+ [-0.082766545039106, 0.420378701254078, -0.903565961374279],
+ [-0.903565960509685, 0.082766547722836, -0.420378702584056],
+ [0.082766546052241, 0.420378700439882, 0.903565961660275],
+ [0.90356596090935, 0.082766547862683, 0.420378701697478],
+ [0.082766545679396, -0.420378701270528, -0.903565961307975],
+ [-0.420378707505505, -0.903565957945495, 0.082766550719722],
+ [0.420378706956033, 0.903565958233438, 0.082766550367062],
+ [0.491848298473796, 0.355367007972287, 0.794858189196825],
+ [0.355367012626596, 0.794858187708896, 0.491848297515583],
+ [-0.491848298344631, -0.355367008156744, 0.794858189194284],
+ [0.794858192062911, 0.491848294626225, 0.355367006886901],
+ [-0.355367012548889, 0.794858187634159, -0.491848297692508],
+ [0.794858192091183, -0.491848294618182, -0.355367006834796],
+ [-0.355367012605403, -0.79485818761909, 0.491848297676028],
+ [-0.79485819260841, -0.491848293926967, 0.355367006634583],
+ [0.35536701250799, -0.794858187986535, -0.491848297152596],
+ [-0.794858192358054, 0.491848294578517, -0.355367006292778],
+ [0.491848297979809, -0.355367007868558, -0.794858189548874],
+ [-0.491848297571808, 0.355367007417215, -0.794858190003127],
+ [0.060667255805915, 0.97798263888706, 0.199673338501868],
+ [0.977982638810576, 0.199673341482313, 0.060667247229371],
+ [-0.06066725576913, -0.977982638936182, 0.199673338272451],
+ [0.19967333790937, 0.060667250976018, 0.977982639307643],
+ [-0.977982639072168, 0.199673340242081, -0.060667247094362],
+ [0.199673337373138, -0.060667251086811, -0.977982639410252],
+ [-0.977982638978921, -0.199673340702871, 0.060667247080943],
+ [-0.199673337990036, -0.060667251306886, 0.977982639270649],
+ [0.977982638897865, -0.199673341052594, -0.060667247236562],
+ [-0.199673337084201, 0.060667250789575, -0.977982639487682],
+ [0.06066725570001, -0.977982638898456, -0.199673338478232],
+ [-0.060667256074209, 0.977982638961939, -0.199673338053604],
+ [-0.708312961873346, 0.702414591990534, 0.070046334671986],
+ [0.702414584158394, 0.070046328146925, -0.70831297028554],
+ [0.70831296180624, -0.702414591950002, 0.070046335757007],
+ [0.070046325730793, -0.7083129711293, 0.702414583548491],
+ [-0.702414584241602, 0.070046328819927, 0.70831297013647],
+ [0.070046325138075, 0.708312971393231, -0.70241458334145],
+ [-0.702414584340882, -0.070046327329757, -0.708312970185382],
+ [-0.070046326094986, 0.708312970542407, 0.702414584103993],
+ [0.702414584126282, -0.070046328999645, 0.708312970233058],
+ [-0.07004632593766, -0.708312970292399, -0.70241458437179],
+ [-0.70831296129047, -0.702414592488956, -0.070046335567964],
+ [0.708312961059513, 0.702414592640383, -0.070046336384914],
+ [-0.608778246497891, -0.729529462544733, -0.311730348009535],
+ [-0.729529461162802, -0.311730341531525, -0.608778251471052],
+ [0.608778246679673, 0.729529462023489, -0.31173034887438],
+ [-0.311730343069402, -0.608778253416134, -0.729529458882528],
+ [0.729529460955067, -0.311730341992774, 0.608778251483804],
+ [-0.311730342453046, 0.608778253837742, 0.729529458794075],
+ [0.729529461285603, 0.311730341286902, -0.608778251449154],
+ [0.311730342676067, 0.608778254584565, -0.729529458075568],
+ [-0.729529460737167, 0.311730342625706, 0.608778251420826],
+ [0.311730342500045, -0.608778254449614, 0.729529458263397],
+ [-0.608778247292532, 0.72952946202083, 0.31173034768375],
+ [0.608778247330452, -0.729529461617846, 0.311730348552781],
+ [0.230102774190651, -0.807756554170623, 0.542754145543051],
+ [-0.807756552084345, 0.542754149424728, 0.230102772358463],
+ [-0.230102773683601, 0.807756554197333, 0.542754145718266],
+ [0.542754144206019, 0.230102772513564, -0.807756555546758],
+ [0.807756552132751, 0.542754149180793, -0.230102772763921],
+ [0.54275414387689, -0.230102773432955, 0.807756555506005],
+ [0.807756552229309, -0.542754148882616, 0.230102773128283],
+ [-0.542754145084005, -0.230102772500065, -0.80775655496066],
+ [-0.807756552237738, -0.542754149346909, -0.230102772003543],
+ [-0.542754144288786, 0.230102773227955, 0.807756555287639],
+ [0.230102774097675, 0.807756554025896, -0.542754145797859],
+ [-0.230102773562357, -0.807756553761761, -0.542754146417909],
+ [-0.496383809474105, -0.862518230775131, -0.098312843883766],
+ [-0.862518224287333, -0.098312838975785, -0.49638382171939],
+ [0.496383809596231, 0.862518230686221, -0.098312844047173],
+ [-0.098312839350041, -0.496383823019562, -0.862518223496418],
+ [0.862518224300261, -0.098312838333147, 0.496383821824206],
+ [-0.098312838299782, 0.496383823078636, 0.862518223582133],
+ [0.862518224470515, 0.098312838524506, -0.496383821490472],
+ [0.09831283917121, 0.496383824314041, -0.862518222771822],
+ [-0.862518224078588, 0.098312839378387, 0.496383822002367],
+ [0.098312838470056, -0.49638382381015, 0.862518223141735],
+ [-0.496383810069422, 0.862518230414379, 0.098312844042943],
+ [0.496383810403814, -0.862518230215463, 0.098312844099726],
+ [0.278692551327958, 0.919313188465131, 0.277837584477674],
+ [0.919313191744972, 0.277837581526559, 0.278692543450923],
+ [-0.278692551566547, -0.91931318841363, 0.277837584408758],
+ [0.277837583051005, 0.278692544908351, 0.919313190842426],
+ [-0.919313192180326, 0.277837580345951, -0.278692543191822],
+ [0.277837582008532, -0.278692545046071, -0.919313191115735],
+ [-0.919313192196504, -0.277837580255645, 0.278692543228489],
+ [-0.277837582825575, -0.278692545265575, 0.919313190802263],
+ [0.919313191814052, -0.277837581086655, -0.278692543661607],
+ [-0.277837581528602, 0.278692544535811, -0.919313191415468],
+ [0.278692551299389, -0.91931318860489, -0.277837584043894],
+ [-0.278692551719555, 0.919313188501633, -0.277837583964092],
+ [0.711723818982073, -0.147355178359107, -0.686830151423428],
+ [-0.14735518004562, -0.686830151696651, 0.711723818369232],
+ [-0.711723818994987, 0.147355179083635, -0.686830151254603],
+ [-0.686830156031755, 0.711723816221312, -0.147355170213896],
+ [0.147355179878181, -0.68683015150873, -0.711723818585246],
+ [-0.686830155899656, -0.711723816480405, 0.147355169578202],
+ [0.147355179832049, 0.686830151151262, 0.711723818939762],
+ [0.686830156307707, -0.711723816117428, -0.147355169429431],
+ [-0.147355180410728, 0.686830151596083, -0.711723818390689],
+ [0.686830155813336, 0.711723816667769, 0.147355169075579],
+ [0.711723819167954, 0.147355177853636, 0.686830151339256],
+ [-0.711723818958232, -0.147355177932743, 0.686830151539607],
+ [0.910866815770901, -0.407547474081887, 0.065013077890936],
+ [-0.407547470055014, 0.06501307469253, 0.910866817800923],
+ [-0.91086681602966, 0.40754747351676, 0.065013077808199],
+ [0.065013071417123, 0.910866817243773, -0.407547471822745],
+ [0.407547469547224, 0.065013074424327, -0.910866818047266],
+ [0.065013071503944, -0.910866817193855, 0.407547471920462],
+ [0.407547469994702, -0.065013074730498, 0.910866817825199],
+ [-0.065013071237167, -0.910866817002829, -0.407547472389962],
+ [-0.407547469492954, -0.065013074760909, -0.910866818047525],
+ [-0.065013070894069, 0.910866817046896, 0.407547472346204],
+ [0.910866815571027, 0.407547474607393, -0.065013077397032],
+ [-0.910866815826998, -0.407547474069762, -0.065013077180997],
+ ]
+ ),
+ # degree 20
+ np.array(
+ [
+ [-0.251581299355938, 0.965702462813156, -0.064230858090044],
+ [0.965702462812973, -0.064230858090163, -0.251581299356609],
+ [0.25158129935621, -0.965702462813076, -0.064230858090184],
+ [-0.064230858090037, -0.251581299356469, 0.965702462813018],
+ [-0.965702462812988, -0.064230858090212, 0.25158129935654],
+ [-0.064230858090283, 0.251581299356213, -0.965702462813068],
+ [-0.965702462813129, 0.06423085809035, -0.251581299355962],
+ [0.064230858090209, 0.251581299356322, 0.965702462813045],
+ [0.96570246281309, 0.064230858089911, 0.251581299356226],
+ [0.0642308580902, -0.2515812993563, -0.965702462813051],
+ [-0.2515812993566, -0.965702462812992, 0.064230858089919],
+ [0.251581299356516, 0.965702462812981, 0.064230858090402],
+ [-0.774265533845772, 0.381515182343397, -0.504934697500583],
+ [0.381515182343197, -0.504934697500657, -0.774265533845823],
+ [0.774265533845583, -0.381515182343386, -0.504934697500883],
+ [-0.504934697500797, -0.774265533845681, 0.3815151823433],
+ [-0.381515182343153, -0.504934697500805, 0.774265533845748],
+ [-0.504934697500622, 0.774265533845887, -0.381515182343114],
+ [-0.381515182343272, 0.504934697500883, -0.774265533845639],
+ [0.504934697500808, 0.774265533845615, 0.381515182343419],
+ [0.38151518234349, 0.504934697500621, 0.774265533845703],
+ [0.50493469750058, -0.774265533845806, -0.381515182343333],
+ [-0.774265533845719, -0.381515182343321, 0.504934697500723],
+ [0.774265533845894, 0.38151518234298, 0.504934697500711],
+ [0.621892089865857, 0.451716799694261, -0.639689113113747],
+ [0.451716799694191, -0.639689113113918, 0.621892089865731],
+ [-0.621892089865648, -0.451716799694225, -0.639689113113976],
+ [-0.639689113113901, 0.621892089865499, 0.451716799694535],
+ [-0.451716799694008, -0.6396891131138, -0.621892089865986],
+ [-0.639689113113879, -0.621892089865655, -0.451716799694351],
+ [-0.451716799694347, 0.639689113113675, 0.621892089865869],
+ [0.639689113113788, -0.621892089865995, 0.451716799694013],
+ [0.451716799694587, 0.639689113113955, -0.621892089865406],
+ [0.639689113114061, 0.6218920898659, -0.451716799693757],
+ [0.621892089865889, -0.451716799694281, 0.639689113113701],
+ [-0.621892089865898, 0.451716799693713, 0.639689113114094],
+ [0.281811042675091, 0.858047847696197, -0.429344182783814],
+ [0.858047847696408, -0.429344182783659, 0.281811042674688],
+ [-0.281811042675114, -0.858047847696306, -0.429344182783581],
+ [-0.429344182783315, 0.281811042674947, 0.858047847696495],
+ [-0.858047847696386, -0.429344182783329, -0.281811042675257],
+ [-0.429344182783979, -0.281811042674793, -0.858047847696213],
+ [-0.858047847696136, 0.429344182783948, 0.281811042675075],
+ [0.429344182783574, -0.281811042675002, 0.858047847696347],
+ [0.85804784769643, 0.429344182783432, -0.281811042674964],
+ [0.429344182783407, 0.2818110426754, -0.8580478476963],
+ [0.28181104267478, -0.858047847696515, 0.429344182783383],
+ [-0.281811042675193, 0.858047847696227, 0.429344182783688],
+ [-0.649612004107369, -0.615311084069471, 0.44653836782617],
+ [-0.615311084069575, 0.446538367826544, -0.649612004107014],
+ [0.649612004107338, 0.615311084069274, 0.446538367826487],
+ [0.44653836782629, -0.649612004107234, -0.615311084069526],
+ [0.615311084069631, 0.446538367826189, 0.649612004107205],
+ [0.4465383678263, 0.649612004107223, 0.615311084069531],
+ [0.615311084069337, -0.44653836782627, -0.649612004107428],
+ [-0.446538367826248, 0.649612004107346, -0.615311084069439],
+ [-0.615311084069373, -0.446538367826536, 0.649612004107211],
+ [-0.446538367826286, -0.649612004107303, 0.615311084069457],
+ [-0.649612004107121, 0.615311084069723, -0.446538367826183],
+ [0.649612004107125, -0.615311084069551, -0.446538367826415],
+ [0.993363116319503, -0.113468728148246, -0.018829946054775],
+ [-0.113468728148035, -0.018829946054639, 0.993363116319529],
+ [-0.993363116319523, 0.113468728148204, -0.018829946053964],
+ [-0.018829946053903, 0.993363116319554, -0.113468728147943],
+ [0.113468728148066, -0.018829946054323, -0.993363116319532],
+ [-0.018829946054743, -0.993363116319533, 0.113468728147986],
+ [0.113468728148219, 0.018829946054485, 0.993363116319511],
+ [0.018829946054344, -0.99336311631951, -0.113468728148254],
+ [-0.113468728148178, 0.018829946054246, -0.993363116319521],
+ [0.018829946054485, 0.993363116319503, 0.113468728148287],
+ [0.99336311631954, 0.113468728147985, 0.018829946054382],
+ [-0.993363116319531, -0.113468728148037, 0.018829946054542],
+ [0.246398885891569, -0.720801569649804, 0.647867799957501],
+ [-0.720801569649392, 0.647867799957886, 0.246398885891762],
+ [-0.246398885891682, 0.720801569649632, 0.647867799957649],
+ [0.647867799957437, 0.246398885891663, -0.720801569649829],
+ [0.720801569649864, 0.647867799957577, -0.246398885891192],
+ [0.647867799957658, -0.246398885891679, 0.720801569649625],
+ [0.720801569649656, -0.647867799957734, 0.246398885891389],
+ [-0.647867799957904, -0.246398885891433, -0.720801569649489],
+ [-0.720801569649865, -0.647867799957373, -0.246398885891727],
+ [-0.647867799957474, 0.246398885891166, 0.720801569649966],
+ [0.246398885891794, 0.720801569649507, -0.647867799957745],
+ [-0.246398885891456, -0.720801569649666, -0.647867799957697],
+ [-0.793544204802179, -0.387628773401269, -0.469075184865183],
+ [-0.387628773401353, -0.469075184864794, -0.793544204802368],
+ [0.793544204802171, 0.387628773401536, -0.469075184864975],
+ [-0.469075184865097, -0.793544204802034, -0.387628773401668],
+ [0.38762877340168, -0.469075184864988, 0.793544204802093],
+ [-0.46907518486511, 0.793544204802104, 0.387628773401512],
+ [0.387628773401425, 0.469075184865298, -0.793544204802035],
+ [0.469075184865068, 0.793544204802337, -0.387628773401084],
+ [-0.387628773401491, 0.469075184864931, 0.793544204802219],
+ [0.469075184864784, -0.793544204802296, 0.387628773401512],
+ [-0.793544204802265, 0.387628773401224, 0.469075184865075],
+ [0.793544204802185, -0.387628773401823, 0.469075184864715],
+ [0.164945057653003, -0.958376909717154, 0.233038251960587],
+ [-0.958376909716935, 0.233038251961126, 0.164945057653512],
+ [-0.164945057653238, 0.958376909717001, 0.233038251961048],
+ [0.233038251960668, 0.164945057653504, -0.958376909717048],
+ [0.958376909717102, 0.233038251960514, -0.164945057653409],
+ [0.233038251960742, -0.164945057653288, 0.958376909717067],
+ [0.958376909717099, -0.233038251960827, 0.164945057652982],
+ [-0.233038251961122, -0.164945057653226, -0.958376909716986],
+ [-0.958376909717093, -0.233038251960632, -0.164945057653293],
+ [-0.233038251960434, 0.164945057653261, 0.958376909717147],
+ [0.164945057653494, 0.958376909716965, -0.233038251961015],
+ [-0.164945057653458, -0.958376909717031, -0.233038251960769],
+ [0.560484250466976, 0.813252649483695, -0.156452974040834],
+ [0.81325264948369, -0.156452974041446, 0.560484250466813],
+ [-0.56048425046724, -0.813252649483431, -0.156452974041263],
+ [-0.15645297404103, 0.560484250467047, 0.813252649483609],
+ [-0.81325264948382, -0.156452974040726, -0.560484250466826],
+ [-0.156452974041097, -0.560484250466778, -0.813252649483781],
+ [-0.81325264948363, 0.156452974040967, 0.560484250467035],
+ [0.156452974041285, -0.560484250467053, 0.813252649483555],
+ [0.813252649483481, 0.156452974041151, -0.560484250467199],
+ [0.156452974040881, 0.560484250466996, -0.813252649483672],
+ [0.560484250466836, -0.813252649483737, 0.156452974041122],
+ [-0.56048425046674, 0.813252649483823, 0.156452974041018],
+ [0.366630058651312, 0.922018832550933, -0.124353015704282],
+ [0.92201883255088, -0.124353015704762, 0.366630058651284],
+ [-0.366630058651761, -0.922018832550708, -0.124353015704629],
+ [-0.124353015704377, 0.366630058651577, 0.922018832550815],
+ [-0.922018832550933, -0.124353015704203, -0.366630058651341],
+ [-0.124353015704534, -0.366630058651111, -0.922018832550979],
+ [-0.922018832550883, 0.124353015704478, 0.366630058651372],
+ [0.12435301570463, -0.366630058651537, 0.922018832550797],
+ [0.922018832550745, 0.124353015704463, -0.366630058651723],
+ [0.124353015704299, 0.366630058651563, -0.922018832550831],
+ [0.366630058651286, -0.922018832550923, 0.124353015704438],
+ [-0.366630058651229, 0.922018832550938, 0.124353015704492],
+ [-0.804671953651735, -0.070836250755727, 0.589478814365005],
+ [-0.070836250756058, 0.589478814365003, -0.804671953651707],
+ [0.804671953651921, 0.070836250755383, 0.589478814364792],
+ [0.589478814364726, -0.804671953651941, -0.070836250755714],
+ [0.070836250755939, 0.589478814364776, 0.804671953651884],
+ [0.589478814365018, 0.804671953651715, 0.070836250755846],
+ [0.070836250755601, -0.589478814364811, -0.804671953651888],
+ [-0.589478814364784, 0.804671953651884, -0.070836250755875],
+ [-0.070836250755551, -0.589478814364944, 0.804671953651795],
+ [-0.589478814364978, -0.804671953651759, 0.07083625075567],
+ [-0.804671953651836, 0.070836250756193, -0.589478814364811],
+ [0.804671953651731, -0.070836250755764, -0.589478814365006],
+ [-0.830597137771463, -0.481356221636722, 0.280008183125909],
+ [-0.481356221636763, 0.280008183126324, -0.830597137771299],
+ [0.830597137771467, 0.481356221636628, 0.280008183126056],
+ [0.280008183125864, -0.830597137771343, -0.481356221636956],
+ [0.481356221637075, 0.280008183125899, 0.830597137771262],
+ [0.280008183126004, 0.830597137771351, 0.481356221636859],
+ [0.481356221636653, -0.280008183125859, -0.83059713777152],
+ [-0.280008183126012, 0.83059713777152, -0.481356221636564],
+ [-0.481356221636741, -0.280008183126112, 0.830597137771384],
+ [-0.280008183126053, -0.830597137771314, 0.481356221636894],
+ [-0.830597137771366, 0.48135622163684, -0.280008183125994],
+ [0.830597137771194, -0.481356221637029, -0.280008183126178],
+ [0.622576105404642, 0.027441908430236, -0.782077959439399],
+ [0.027441908430276, -0.782077959439431, 0.622576105404601],
+ [-0.622576105404963, -0.027441908430045, -0.78207795943915],
+ [-0.782077959439118, 0.622576105404988, 0.027441908430397],
+ [-0.027441908430201, -0.782077959439296, -0.622576105404774],
+ [-0.782077959439408, -0.622576105404628, -0.027441908430289],
+ [-0.027441908430238, 0.782077959439221, 0.622576105404866],
+ [0.782077959439263, -0.62257610540482, 0.027441908430083],
+ [0.027441908430419, 0.782077959439269, -0.622576105404798],
+ [0.782077959439451, 0.622576105404591, -0.027441908429928],
+ [0.622576105404788, -0.02744190843038, 0.782077959439278],
+ [-0.622576105404572, 0.027441908429868, 0.782077959439468],
+ [-0.93186959347387, 0.318712863282032, -0.173323891998229],
+ [0.318712863281944, -0.173323891998258, -0.931869593473894],
+ [0.931869593473744, -0.318712863282051, -0.173323891998871],
+ [-0.173323891998841, -0.931869593473836, 0.318712863281799],
+ [-0.318712863281924, -0.173323891998617, 0.931869593473834],
+ [-0.173323891998245, 0.931869593473975, -0.318712863281714],
+ [-0.318712863281997, 0.173323891998515, -0.931869593473828],
+ [0.173323891998501, 0.931869593473801, 0.318712863282084],
+ [0.318712863282089, 0.173323891998539, 0.931869593473793],
+ [0.173323891998443, -0.931869593473824, -0.31871286328205],
+ [-0.931869593473865, -0.318712863281928, 0.173323891998448],
+ [0.931869593473897, 0.318712863281802, 0.173323891998503],
+ [0.883848176852703, 0.201423804475213, 0.422185801827685],
+ [0.201423804475703, 0.422185801827661, 0.883848176852602],
+ [-0.883848176852534, -0.201423804475554, 0.422185801827875],
+ [0.42218580182791, 0.883848176852484, 0.201423804475701],
+ [-0.201423804475472, 0.422185801827744, -0.883848176852615],
+ [0.422185801827623, -0.883848176852647, -0.201423804475586],
+ [-0.201423804475397, -0.422185801827833, 0.88384817685259],
+ [-0.42218580182793, -0.883848176852523, 0.201423804475489],
+ [0.201423804475479, -0.422185801827682, -0.883848176852643],
+ [-0.422185801827514, 0.883848176852769, -0.20142380447528],
+ [0.883848176852476, -0.201423804475614, -0.422185801827967],
+ [-0.88384817685271, 0.201423804475563, -0.422185801827502],
+ [0.204275039956405, 0.718770569884226, 0.664560438123663],
+ [0.718770569884334, 0.664560438123474, 0.204275039956637],
+ [-0.20427503995626, -0.718770569884265, 0.664560438123664],
+ [0.66456043812381, 0.204275039956156, 0.71877056988416],
+ [-0.718770569884325, 0.664560438123579, -0.204275039956328],
+ [0.664560438123492, -0.204275039956373, -0.718770569884393],
+ [-0.718770569884361, -0.664560438123554, 0.20427503995628],
+ [-0.664560438123554, -0.204275039956662, 0.718770569884254],
+ [0.71877056988409, -0.664560438123802, -0.204275039956432],
+ [-0.664560438123505, 0.204275039956682, -0.718770569884293],
+ [0.204275039956376, -0.718770569884165, -0.664560438123738],
+ [-0.204275039956367, 0.718770569884538, -0.664560438123337],
+ [-0.898847927472069, 0.43770082336828, 0.022144807560617],
+ [0.437700823367923, 0.022144807560963, -0.898847927472234],
+ [0.898847927472182, -0.437700823368065, 0.02214480756027],
+ [0.022144807560315, -0.898847927472293, 0.437700823367834],
+ [-0.437700823367766, 0.022144807560623, 0.898847927472319],
+ [0.022144807560559, 0.89884792747216, -0.437700823368094],
+ [-0.437700823368255, -0.022144807560327, -0.898847927472088],
+ [-0.022144807560661, 0.898847927472103, 0.437700823368207],
+ [0.43770082336803, -0.022144807560607, 0.898847927472191],
+ [-0.022144807560733, -0.898847927472195, -0.437700823368015],
+ [-0.898847927472245, -0.437700823367908, -0.022144807560796],
+ [0.898847927472313, 0.437700823367778, -0.022144807560634],
+ ]
+ ),
+ # degree 21
+ np.array(
+ [
+ [0.892653535762723, 0.412534053657361, -0.181618610454253],
+ [0.412534053425032, -0.181618610641782, 0.892653535831938],
+ [-0.892653535806407, -0.412534053627853, -0.181618610306575],
+ [-0.181618610613849, 0.892653535740475, 0.41253405363524],
+ [-0.412534053477435, -0.181618610422654, -0.892653535852304],
+ [-0.181618610451384, -0.892653535762812, -0.412534053658432],
+ [-0.41253405331709, 0.181618610611827, 0.892653535887918],
+ [0.181618610400136, -0.8926535358123, 0.412534053573911],
+ [0.412534053327996, 0.1816186104204, -0.892653535921825],
+ [0.181618610580789, 0.892653535810904, -0.412534053497399],
+ [0.892653535867644, -0.412534053472558, 0.181618610358339],
+ [-0.892653535855064, 0.41253405353516, 0.181618610277971],
+ [-0.292093742593433, -0.29576702799317, 0.909507070170347],
+ [-0.295767028026887, 0.90950707008926, -0.292093742811776],
+ [0.292093742447864, 0.295767028039713, 0.909507070201962],
+ [0.909507070147612, -0.292093742926721, -0.295767027733934],
+ [0.295767028145396, 0.909507070084441, 0.292093742706783],
+ [0.909507070188854, 0.292093742689207, 0.295767027841675],
+ [0.295767027907311, -0.909507070148419, -0.292093742748651],
+ [-0.909507070101221, 0.292093743159272, -0.295767027646927],
+ [-0.295767027835333, -0.909507070047293, 0.292093743136414],
+ [-0.909507070218591, -0.292093742721776, 0.295767027718069],
+ [-0.292093742540896, 0.295767027793147, -0.909507070252266],
+ [0.292093742861938, -0.295767027747614, -0.909507070163969],
+ [-0.575225718038192, 0.024120572825078, 0.817639022597403],
+ [0.024120572786144, 0.817639022511238, -0.575225718162301],
+ [0.575225718116478, -0.024120572979213, 0.817639022537781],
+ [0.817639022556003, -0.57522571810348, 0.024120572671469],
+ [-0.024120573041503, 0.817639022440757, 0.575225718251777],
+ [0.817639022458379, 0.575225718229118, -0.024120572984526],
+ [-0.024120572818239, -0.81763902258126, -0.575225718061424],
+ [-0.817639022543578, 0.575225718123882, 0.024120572606111],
+ [0.02412057271295, -0.817639022527296, 0.575225718142546],
+ [-0.817639022600495, -0.575225718035174, -0.024120572792228],
+ [-0.575225717925469, -0.024120572711052, -0.81763902268007],
+ [0.57522571790823, 0.024120572594155, -0.817639022695646],
+ [-0.1288331617248, 0.05224764072024, 0.990288947973853],
+ [0.052247640694409, 0.990288947958895, -0.128833161850251],
+ [0.128833161840325, -0.052247640320038, 0.990288947979938],
+ [0.990288947949717, -0.128833161924796, 0.052247640684558],
+ [-0.05224764038851, 0.990288947967581, 0.128833161907538],
+ [0.99028894797773, 0.128833161878001, -0.052247640268992],
+ [-0.052247640390409, -0.99028894796219, -0.128833161948209],
+ [-0.990288947960626, 0.128833161896649, 0.052247640547187],
+ [0.052247640527808, -0.990288947953251, 0.1288331619612],
+ [-0.990288947970868, -0.128833161936205, -0.052247640255526],
+ [-0.128833161790478, -0.052247640337643, -0.990288947985494],
+ [0.128833161857416, 0.052247640551545, -0.9902889479655],
+ [0.71800638603475, 0.657446876255993, -0.228539787596286],
+ [0.657446876286737, -0.228539787831922, 0.718006385931596],
+ [-0.718006386109442, -0.657446876171434, -0.228539787604877],
+ [-0.228539787737219, 0.718006385947422, 0.657446876302374],
+ [-0.657446876241021, -0.2285397877138, -0.718006386011054],
+ [-0.228539787678997, -0.718006386031359, -0.657446876230945],
+ [-0.657446876361185, 0.228539787860549, 0.718006385854315],
+ [0.228539787703065, -0.718006385857385, 0.657446876412577],
+ [0.657446876304454, 0.228539787874017, -0.718006385901975],
+ [0.228539787784967, 0.718006385813853, -0.657446876431648],
+ [0.71800638588076, -0.657446876363485, 0.228539787770851],
+ [-0.718006385891018, 0.657446876371558, 0.228539787715401],
+ [0.863176473117803, 0.468181816653138, 0.189029528940001],
+ [0.468181816438486, 0.189029529197492, 0.86317647317784],
+ [-0.863176473194446, -0.46818181657642, 0.189029528780033],
+ [0.189029529125527, 0.863176473064389, 0.468181816676708],
+ [-0.468181816392671, 0.189029528897443, -0.863176473268398],
+ [0.189029528792174, -0.863176473143688, -0.4681818166651],
+ [-0.468181816411213, -0.189029529128138, 0.863176473207821],
+ [-0.189029528897852, -0.86317647308972, 0.468181816721931],
+ [0.468181816508867, -0.189029528930555, -0.863176473198123],
+ [-0.189029529001823, 0.863176473106659, -0.468181816648722],
+ [0.863176473135229, -0.468181816648642, -0.189029528871561],
+ [-0.863176473123334, 0.468181816698762, -0.189029528801744],
+ [0.772632856847133, -0.51705945069559, 0.368358511462152],
+ [-0.517059450567132, 0.368358511585515, 0.772632856874286],
+ [-0.772632856806081, 0.517059450647391, 0.368358511615915],
+ [0.368358511648001, 0.772632856806054, -0.517059450624573],
+ [0.517059450494007, 0.368358511816588, -0.772632856813056],
+ [0.368358511720496, -0.772632856802476, 0.517059450578273],
+ [0.517059450583445, -0.368358511487117, 0.77263285691028],
+ [-0.36835851156733, -0.772632856859467, -0.517059450602229],
+ [-0.517059450502369, -0.368358511665956, -0.772632856879275],
+ [-0.368358511469803, 0.772632856855651, 0.517059450677412],
+ [0.772632856934749, 0.517059450691919, -0.368358511283531],
+ [-0.772632856927485, -0.517059450633778, -0.368358511380378],
+ [-0.847819231914648, -0.066325775900167, -0.526121128113002],
+ [-0.066325775913631, -0.526121128257686, -0.847819231823809],
+ [0.847819231883018, 0.066325775819852, -0.526121128174097],
+ [-0.526121128348762, -0.847819231766957, -0.06632577591791],
+ [0.06632577584612, -0.526121128407098, 0.847819231736372],
+ [-0.52612112845924, 0.84781923170908, 0.066325775781366],
+ [0.066325775945785, 0.52612112834438, -0.847819231767496],
+ [0.526121128449532, 0.847819231700692, -0.066325775965613],
+ [-0.066325775877211, 0.526121128306388, 0.847819231796436],
+ [0.526121128504669, -0.847819231665213, 0.06632577598176],
+ [-0.847819231821725, 0.066325775941005, 0.526121128257594],
+ [0.847819231850264, -0.066325775996655, 0.52612112820459],
+ [0.00980574322923, 0.942983815842593, 0.332694109443892],
+ [0.942983815808923, 0.332694109539748, 0.00980574321495],
+ [-0.00980574337969, -0.942983815787291, 0.332694109596207],
+ [0.332694109226554, 0.009805743204272, 0.942983815919532],
+ [-0.94298381577404, 0.332694109635647, -0.009805743315804],
+ [0.332694109397996, -0.00980574329891, -0.942983815858062],
+ [-0.942983815776114, -0.332694109630098, 0.009805743304667],
+ [-0.332694109319027, -0.009805743188507, 0.94298381588707],
+ [0.942983815775082, -0.332694109635199, -0.009805743230763],
+ [-0.332694109455765, 0.009805743389762, -0.942983815836735],
+ [0.00980574330114, -0.942983815752524, -0.332694109697065],
+ [-0.009805743287713, 0.942983815791379, -0.332694109587331],
+ [0.785599248371152, -0.405156945312269, -0.467634120465896],
+ [-0.405156944932125, -0.467634120649859, 0.785599248457698],
+ [-0.78559924820179, 0.405156945434051, -0.467634120644904],
+ [-0.467634120611242, 0.785599248334623, -0.405156945215339],
+ [0.405156945136423, -0.467634120868201, -0.785599248222366],
+ [-0.467634120811804, -0.785599248145609, 0.405156945350347],
+ [0.405156944841985, 0.467634120861332, 0.785599248378305],
+ [0.467634120786726, -0.785599248249857, -0.405156945177156],
+ [-0.405156944999643, 0.467634120871098, -0.785599248291182],
+ [0.467634120893713, 0.78559924823424, 0.405156945083953],
+ [0.785599248313341, 0.405156945117104, 0.467634120732106],
+ [-0.7855992482811, -0.40515694519737, 0.467634120716727],
+ [-0.737331999131492, 0.620851501013764, -0.26624225199189],
+ [0.620851500949186, -0.266242252154895, -0.73733199912701],
+ [0.737331999060061, -0.620851501088737, -0.266242252014883],
+ [-0.266242251948631, -0.737331999103255, 0.62085150106585],
+ [-0.620851501079221, -0.2662422522338, 0.737331998989025],
+ [-0.266242252011624, 0.737331998996222, -0.620851501165951],
+ [-0.620851501072124, 0.26624225222256, -0.73733199899906],
+ [0.266242252113864, 0.737331998832974, 0.620851501315983],
+ [0.620851501187387, 0.266242252328374, 0.737331998863797],
+ [0.26624225193225, -0.73733199893899, -0.620851501267959],
+ [-0.737331998947943, -0.620851501183297, 0.266242252104879],
+ [0.737331998835007, 0.620851501305786, 0.26624225213201],
+ [0.726871469165659, -0.027488282350428, -0.686223186468061],
+ [-0.027488282182755, -0.686223186448325, 0.726871469190633],
+ [-0.726871469172931, 0.027488282371885, -0.686223186459499],
+ [-0.686223186449712, 0.726871469185406, -0.027488282286341],
+ [0.027488282351607, -0.68622318649112, -0.726871469143845],
+ [-0.686223186545622, -0.726871469089794, 0.027488282420281],
+ [0.027488282266836, 0.686223186470335, 0.726871469166674],
+ [0.686223186661183, -0.726871468983422, -0.027488282348185],
+ [-0.027488282251029, 0.686223186523092, -0.726871469117465],
+ [0.686223186609112, 0.726871469033498, 0.027488282323948],
+ [0.726871469070107, 0.02748828233555, 0.686223186569869],
+ [-0.726871469080183, -0.027488282309716, 0.686223186560232],
+ [0.665363385720515, 0.580860267739271, 0.468927408352716],
+ [0.580860267577087, 0.468927408488638, 0.665363385766308],
+ [-0.66536338567738, -0.580860267719575, 0.468927408438318],
+ [0.468927408340783, 0.665363385821863, 0.580860267632813],
+ [-0.580860267528453, 0.468927408678832, -0.665363385674723],
+ [0.468927408372614, -0.665363385698803, -0.580860267748078],
+ [-0.580860267640877, -0.468927408552762, 0.665363385665427],
+ [-0.468927408468336, -0.665363385847947, 0.580860267499961],
+ [0.580860267386752, -0.468927408654519, -0.665363385815563],
+ [-0.468927408375699, 0.665363385651356, -0.580860267799938],
+ [0.665363385651819, -0.580860267791212, -0.46892740838585],
+ [-0.665363385751734, 0.580860267548017, -0.468927408545326],
+ [-0.580125367305304, -0.779099597924434, 0.237609710918707],
+ [-0.779099598053518, 0.237609710909934, -0.580125367135539],
+ [0.580125367186808, 0.779099597977732, 0.237609711033258],
+ [0.237609710695932, -0.58012536727611, -0.779099598014114],
+ [0.779099598064732, 0.23760971114732, 0.58012536702325],
+ [0.237609710819285, 0.580125367047426, 0.779099598146774],
+ [0.779099598170224, -0.237609710849642, -0.580125367003499],
+ [-0.237609710811802, 0.580125367157256, -0.779099598067276],
+ [-0.779099598074961, -0.237609711045128, 0.580125367051369],
+ [-0.237609710609253, -0.580125367022359, 0.779099598229495],
+ [-0.580125367090094, 0.779099598151966, -0.237609710698086],
+ [0.580125367218411, -0.779099597966716, -0.237609710992215],
+ [0.9586680253602, 0.101113605900539, -0.265954236389956],
+ [0.101113605889893, -0.265954236477199, 0.95866802533712],
+ [-0.95866802532641, -0.101113606095432, -0.26595423643766],
+ [-0.265954236634179, 0.958668025294555, 0.101113605880558],
+ [-0.101113606003171, -0.265954236656317, -0.958668025275482],
+ [-0.265954236715455, -0.958668025246162, -0.101113606125602],
+ [-0.101113605825438, 0.265954236414664, 0.958668025361267],
+ [0.265954236286739, -0.958668025393583, 0.101113605855522],
+ [0.101113605802444, 0.265954236260664, -0.958668025406415],
+ [0.265954236515854, 0.958668025322577, -0.101113605926106],
+ [0.9586680254495, -0.101113605909101, 0.265954236064808],
+ [-0.9586680254786, 0.101113605789497, 0.265954236005386],
+ [-0.784431814417085, 0.284319025007229, 0.551207239202516],
+ [0.284319024822848, 0.551207239320709, -0.784431814400862],
+ [0.784431814443422, -0.284319024888131, 0.551207239226467],
+ [0.551207239434677, -0.784431814291888, 0.284319024902556],
+ [-0.284319024640161, 0.551207239347504, 0.784431814448249],
+ [0.551207239408357, 0.784431814400998, -0.284319024652546],
+ [-0.28431902471494, -0.551207239160137, -0.784431814552804],
+ [-0.551207239417649, 0.784431814426743, 0.284319024563503],
+ [0.284319024477106, -0.551207239394067, 0.784431814474629],
+ [-0.551207239227164, -0.784431814510832, -0.284319024700797],
+ [-0.7844318146549, -0.284319024757729, -0.551207238992772],
+ [0.784431814542139, 0.284319024689884, -0.55120723918824],
+ [0.166663878535118, 0.97946877886665, 0.113419851953285],
+ [0.979468778892362, 0.113419852011248, 0.166663878344564],
+ [-0.166663878322335, -0.979468778877222, 0.113419852174659],
+ [0.113419851852603, 0.166663878465092, 0.979468778890224],
+ [-0.979468778908051, 0.113419852233229, -0.166663878101297],
+ [0.113419852023532, -0.166663878213165, -0.979468778913298],
+ [-0.979468778891418, -0.113419852088755, 0.166663878297368],
+ [-0.113419851942299, -0.166663878383785, 0.979468778893673],
+ [0.979468778887792, -0.113419852252651, -0.166663878207142],
+ [-0.113419851887333, 0.166663878420061, -0.979468778893865],
+ [0.166663878513312, -0.97946877885884, -0.113419852052775],
+ [-0.166663878525992, 0.979468778852403, -0.113419852089727],
+ [0.90354263539087, 0.099002690679599, 0.416904273507865],
+ [0.09900269051118, 0.416904273753692, 0.903542635295897],
+ [-0.903542635383533, -0.099002690647923, 0.416904273531288],
+ [0.41690427395825, 0.903542635193768, 0.09900269058185],
+ [-0.099002690414933, 0.416904273699732, -0.903542635331341],
+ [0.416904273843964, -0.903542635237517, -0.099002690663845],
+ [-0.099002690464192, -0.416904273937254, 0.903542635216348],
+ [-0.416904274206036, -0.903542635110147, 0.099002690301575],
+ [0.099002690128044, -0.41690427406438, -0.903542635194523],
+ [-0.416904274113744, 0.903542635131386, -0.099002690496392],
+ [0.903542635279275, -0.099002690467102, -0.416904273800183],
+ [-0.903542635234399, 0.099002690245829, -0.416904273949988],
+ [0.278762404536092, 0.349312185537063, -0.894579520698175],
+ [0.349312185586056, -0.894579520608515, 0.278762404762431],
+ [-0.278762404540525, -0.349312185503473, -0.89457952070991],
+ [-0.894579520734144, 0.278762404727917, 0.349312185291866],
+ [-0.349312185466701, -0.894579520677723, -0.278762404689896],
+ [-0.894579520788864, -0.278762404658677, -0.349312185206984],
+ [-0.349312185551041, 0.894579520682798, 0.278762404567923],
+ [0.894579520785219, -0.278762404680469, 0.349312185198929],
+ [0.349312185549623, 0.89457952067923, -0.278762404581149],
+ [0.894579520781805, 0.278762404555908, -0.349312185307075],
+ [0.27876240443795, -0.3493121855065, 0.894579520740692],
+ [-0.278762404443259, 0.349312185428787, 0.894579520769382],
+ [0.555896230179415, -0.676833211736671, 0.48257246581476],
+ [-0.676833211681567, 0.482572466040116, 0.555896230050876],
+ [-0.555896230314892, 0.676833211522987, 0.482572465958401],
+ [0.482572465910283, 0.555896230164672, -0.676833211680673],
+ [0.676833211457692, 0.482572466092895, -0.555896230277639],
+ [0.482572465902981, -0.555896230367909, 0.676833211518957],
+ [0.676833211635592, -0.482572466071981, 0.555896230079191],
+ [-0.482572466150586, -0.555896230230084, -0.676833211455616],
+ [-0.676833211438286, -0.482572466327737, -0.5558962300974],
+ [-0.482572465972373, 0.55589623026777, 0.676833211551727],
+ [0.555896230192691, 0.676833211589453, -0.482572466005949],
+ [-0.555896230194338, -0.676833211455537, -0.482572466191875],
+ ]
+ ),
+]
diff --git a/py4DSTEM/process/diffraction/utils.py b/py4DSTEM/process/diffraction/utils.py
new file mode 100644
index 000000000..cfb11f044
--- /dev/null
+++ b/py4DSTEM/process/diffraction/utils.py
@@ -0,0 +1,251 @@
+# Utility functions for the crystal module of py4DSTEM
+
+import numpy as np
+from dataclasses import dataclass
+import copy
+from scipy.ndimage import gaussian_filter
+
+from emdfile import tqdmnd
+
+
+@dataclass
+class Orientation:
+ """
+ A class for storing output orientations, generated by fitting a Crystal
+ class orientation plan or Bloch wave pattern matching to a PointList.
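+
+    A minimal usage sketch; field shapes follow from num_matches:
+
+    >>> orientation = Orientation(num_matches=2)
+    >>> orientation.matrix.shape
+    (2, 3, 3)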
+ """
+
+ num_matches: int
+
+ def __post_init__(self):
+ self.matrix = np.zeros((self.num_matches, 3, 3))
+ self.family = np.zeros((self.num_matches, 3, 3))
+ self.corr = np.zeros((self.num_matches))
+ self.inds = np.zeros((self.num_matches, 2), dtype="int")
+ self.mirror = np.zeros((self.num_matches), dtype="bool")
+ self.angles = np.zeros((self.num_matches, 3))
+
+
+@dataclass
+class OrientationMap:
+ """
+ A class for storing output orientations, generated by fitting a Crystal class orientation plan or
+ Bloch wave pattern matching to a PointListArray.
+
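+    A minimal usage sketch (shapes and indices are illustrative, and
+    `orientation` is assumed to be an Orientation with matching num_matches):
+
+    >>> omap = OrientationMap(num_x=32, num_y=32, num_matches=3)
+    >>> omap.set_orientation(orientation, ind_x=0, ind_y=0)
+    >>> best = omap.get_orientation_single(ind_x=0, ind_y=0, ind_match=0)
+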
+ """
+
+ num_x: int
+ num_y: int
+ num_matches: int
+
+ def __post_init__(self):
+ # initialize empty arrays
+ self.matrix = np.zeros((self.num_x, self.num_y, self.num_matches, 3, 3))
+ self.family = np.zeros((self.num_x, self.num_y, self.num_matches, 3, 3))
+ self.corr = np.zeros((self.num_x, self.num_y, self.num_matches))
+ self.inds = np.zeros((self.num_x, self.num_y, self.num_matches, 2), dtype="int")
+ self.mirror = np.zeros((self.num_x, self.num_y, self.num_matches), dtype="bool")
+ self.angles = np.zeros((self.num_x, self.num_y, self.num_matches, 3))
+
+ def set_orientation(self, orientation, ind_x, ind_y):
+ # Add an orientation to the orientation map
+ self.matrix[ind_x, ind_y] = orientation.matrix
+ self.family[ind_x, ind_y] = orientation.family
+ self.corr[ind_x, ind_y] = orientation.corr
+ self.inds[ind_x, ind_y] = orientation.inds
+ self.mirror[ind_x, ind_y] = orientation.mirror
+ self.angles[ind_x, ind_y] = orientation.angles
+
+ def get_orientation(self, ind_x, ind_y):
+ # Return an orientation from the orientation map
+ orientation = Orientation(num_matches=self.num_matches)
+ orientation.matrix = self.matrix[ind_x, ind_y]
+ orientation.family = self.family[ind_x, ind_y]
+ orientation.corr = self.corr[ind_x, ind_y]
+ orientation.inds = self.inds[ind_x, ind_y]
+ orientation.mirror = self.mirror[ind_x, ind_y]
+ orientation.angles = self.angles[ind_x, ind_y]
+ return orientation
+
+ def get_orientation_single(self, ind_x, ind_y, ind_match):
+ orientation = Orientation(num_matches=1)
+ orientation.matrix = self.matrix[ind_x, ind_y, ind_match]
+ orientation.family = self.family[ind_x, ind_y, ind_match]
+ orientation.corr = self.corr[ind_x, ind_y, ind_match]
+ orientation.inds = self.inds[ind_x, ind_y, ind_match]
+ orientation.mirror = self.mirror[ind_x, ind_y, ind_match]
+ orientation.angles = self.angles[ind_x, ind_y, ind_match]
+ return orientation
+
+ # def __copy__(self):
+ # return OrientationMap(self.name)
+ # def __deepcopy__(self, memo):
+ # return OrientationMap(copy.deepcopy(self.name, memo))
+
+
+def sort_orientation_maps(
+ orientation_map,
+ sort="intensity",
+ cluster_thresh=0.1,
+):
+ """
+ Sort the orientation maps along the ind_match direction, either by intensity
+ or by clustering similar angles (greedily, in order of intensity).
+
+ Args:
+        orientation_map (OrientationMap): Initial OrientationMap
+ sort (string): "intensity" or "cluster" for sorting method.
+ cluster_thresh (float): similarity threshold for clustering method
+
+ Returns:
+        orientation_sort (OrientationMap): Sorted OrientationMap
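+
+    A minimal usage sketch (with `orientation_map` a hypothetical OrientationMap):
+
+    >>> orientation_sorted = sort_orientation_maps(orientation_map, sort="intensity")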
+ """
+
+ # make a deep copy
+ orientation_sort = copy.deepcopy(orientation_map)
+
+ if sort == "intensity":
+ for rx, ry in tqdmnd(
+ orientation_sort.num_x,
+ orientation_sort.num_y,
+ desc="Sorting orientations",
+ unit=" probe positions",
+ # disable=not progress_bar,
+ ):
+ inds = np.argsort(orientation_map.corr[rx, ry])[::-1]
+
+ orientation_sort.matrix[rx, ry, :, :, :] = orientation_sort.matrix[
+ rx, ry, inds, :, :
+ ]
+ orientation_sort.family[rx, ry, :, :, :] = orientation_sort.family[
+ rx, ry, inds, :, :
+ ]
+ orientation_sort.corr[rx, ry, :] = orientation_sort.corr[rx, ry, inds]
+ orientation_sort.inds[rx, ry, :, :] = orientation_sort.inds[rx, ry, inds, :]
+ orientation_sort.mirror[rx, ry, :] = orientation_sort.mirror[rx, ry, inds]
+ orientation_sort.angles[rx, ry, :, :] = orientation_sort.angles[
+ rx, ry, inds, :
+ ]
+
+ # elif sort == "cluster":
+ # mask = np.zeros_like(orientation_map.corr, dtype='bool')
+ # TODO - implement clustering method for sorting
+
+ else:
+        raise ValueError("Invalid sorting method: " + sort)
+
+ return orientation_sort
+
+
+def calc_1D_profile(
+ k,
+ g_coords,
+ g_int,
+ remove_origin=True,
+ k_broadening=0.0,
+ int_scale=None,
+ normalize_intensity=True,
+):
+ """
+ Utility function to calculate a 1D histogram from the diffraction vector lengths
+ stored in a Crystal class.
+
+ Args:
+        k (np.array): 1D grid of k coordinates (assumed uniformly spaced).
+ g_coords (np.array): Scattering vector lengths g.
+        g_int (np.array): Scattering vector intensities.
+ remove_origin (bool): Remove the origin peak from the profile.
+ k_broadening (float): Broadening applied to full profile.
+        int_scale (np.array): Either a scalar value multiplied into all peak intensities,
+ or a vector with 1 value per peak to scale peaks individually.
+ normalize_intensity (bool): Normalize maximum output value to 1.
+
+ Returns:
+ int_profile (np.array): Computed intensity profile
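+
+    A minimal usage sketch (g_coords and g_int would typically come from a
+    Crystal's structure factors; values here are illustrative):
+
+    >>> k = np.linspace(0, 2, 256)
+    >>> profile = calc_1D_profile(k, g_coords, g_int, k_broadening=0.02)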
+ """
+
+    # init
+    # guard against int_scale=None, which np.atleast_1d would otherwise wrap
+    # into an object array and break the final intensity scaling
+    if int_scale is None:
+        int_scale = np.ones(1)
+    int_scale = np.atleast_1d(int_scale)
+ k_num = k.shape[0]
+ k_min = k[0]
+ k_step = k[1] - k[0]
+ k_max = k[-1]
+
+ # get discrete plot from structure factor amplitudes
+ int_profile = np.zeros_like(k)
+ k_px = (g_coords - k_min) / k_step
+ kf = np.floor(k_px).astype("int")
+ dk = k_px - kf
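+
+    # each peak's intensity is split linearly between its two neighboring bins:
+    # this first pass is weighted by (1 - dk), the second pass below by dk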
+
+ sub = np.logical_and(kf >= 0, kf < k_num)
+ if int_scale.shape[0] > 1:
+ int_profile = np.bincount(
+ np.floor(k_px[sub]).astype("int"),
+ weights=(1 - dk[sub]) * g_int[sub] * int_scale[sub],
+ minlength=k_num,
+ )
+ else:
+ int_profile = np.bincount(
+ np.floor(k_px[sub]).astype("int"),
+ weights=(1 - dk[sub]) * g_int[sub],
+ minlength=k_num,
+ )
+ sub = np.logical_and(k_px >= -1, k_px < k_num - 1)
+ if int_scale.shape[0] > 1:
+ int_profile += np.bincount(
+ np.floor(k_px[sub] + 1).astype("int"),
+ weights=dk[sub] * g_int[sub] * int_scale[sub],
+ minlength=k_num,
+ )
+ else:
+ int_profile += np.bincount(
+ np.floor(k_px[sub] + 1).astype("int"),
+ weights=dk[sub] * g_int[sub],
+ minlength=k_num,
+ )
+
+    if remove_origin:
+ int_profile[0:2] = 0
+
+ # Apply broadening if needed
+ if k_broadening > 0.0:
+ int_profile = gaussian_filter(
+ int_profile, k_broadening / k_step, mode="constant"
+ )
+
+ if normalize_intensity:
+ int_profile /= np.max(int_profile)
+    if int_scale.shape[0] == 1:
+        int_profile *= int_scale
+
+ return int_profile
+
+
+def axisEqual3D(ax):
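+    """
+    Sets equal scaling on a matplotlib 3D axis, by expanding each axis'
+    limits to the largest extent, centered on its midpoint.
+    """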
+ extents = np.array([getattr(ax, "get_{}lim".format(dim))() for dim in "xyz"])
+ sz = extents[:, 1] - extents[:, 0]
+ centers = np.mean(extents, axis=1)
+ maxsize = max(abs(sz))
+ r = maxsize / 2
+ for ctr, dim in zip(centers, "xyz"):
+ getattr(ax, "set_{}lim".format(dim))(ctr - r, ctr + r)
+
+
+# fmt: off
+
+# a list of symbols for each element
+element_symbols = ('H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', 'Na', 'Mg', 'Al', 'Si',
+ 'P', 'S', 'Cl', 'Ar', 'K', 'Ca', 'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni',
+ 'Cu', 'Zn', 'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr', 'Rb', 'Sr', 'Y', 'Zr', 'Nb',
+ 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', 'In', 'Sn', 'Sb', 'Te', 'I', 'Xe',
+ 'Cs', 'Ba', 'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', 'Ho',
+ 'Er', 'Tm', 'Yb', 'Lu', 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 'Tl',
+ 'Pb', 'Bi', 'Po', 'At', 'Rn', 'Fr', 'Ra', 'Ac', 'Th', 'Pa', 'U', 'Np', 'Pu', 'Am',
+ 'Cm', 'Bk', 'Cf', 'Es', 'Fm', 'Md', 'No', 'Lr', 'Rf', 'Db', 'Sg', 'Bh', 'Hs', 'Mt',
+                   'Ds', 'Rg', 'Cn', 'Nh', 'Fl', 'Mc', 'Lv', 'Ts', 'Og', 'Uue', 'Ubn', 'Ubu')
+
+# a dictionary for converting element names into Z
+elements = {element_symbols[i]: i + 1 for i in range(len(element_symbols))}
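+# e.g. elements['Fe'] == 26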
+# fmt: on
diff --git a/py4DSTEM/process/fit/__init__.py b/py4DSTEM/process/fit/__init__.py
new file mode 100644
index 000000000..6027635df
--- /dev/null
+++ b/py4DSTEM/process/fit/__init__.py
@@ -0,0 +1 @@
+from py4DSTEM.process.fit.fit import *
diff --git a/py4DSTEM/process/fit/fit.py b/py4DSTEM/process/fit/fit.py
new file mode 100644
index 000000000..9973ff79f
--- /dev/null
+++ b/py4DSTEM/process/fit/fit.py
@@ -0,0 +1,284 @@
+# Fitting
+
+import numpy as np
+from scipy.optimize import curve_fit
+from inspect import signature
+
+
+def gaussian(x, A, mu, sigma):
+ return A * np.exp(-0.5 * ((x - mu) / sigma) ** 2)
+
+
+def fit_1D_gaussian(xdata, ydata, xmin, xmax):
+ """
+ Fits a 1D gaussian to the subset of the 1D curve f(xdata)=ydata within the window
+ (xmin,xmax). Returns A,mu,sigma. Retrieve the full curve with
+
+ >>> fit_gaussian = py4DSTEM.process.fit.gaussian(xdata,A,mu,sigma)
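+
+    A minimal usage sketch with synthetic data (values illustrative):
+
+    >>> xdata = np.linspace(-5, 5, 200)
+    >>> ydata = gaussian(xdata, 2.0, 0.5, 1.2)
+    >>> A, mu, sigma = fit_1D_gaussian(xdata, ydata, -5, 5)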
+ """
+
+ mask = (xmin <= xdata) * (xmax > xdata)
+ inds = np.nonzero(mask)[0]
+ _xdata = xdata[inds]
+ _ydata = ydata[inds]
+ scale = np.max(_ydata)
+ _ydata = _ydata / scale
+
+ p0 = [
+ np.max(_ydata),
+ _xdata[np.argmax(_ydata)],
+ (xmax - xmin) / 8.0,
+ ] # TODO: better guess for std
+
+ popt, pcov = curve_fit(gaussian, _xdata, _ydata, p0=p0)
+ A, mu, sigma = scale * popt[0], popt[1], popt[2]
+ return A, mu, sigma
+
+
+def fit_2D(
+ function,
+ data,
+ data_mask=None,
+ popt=None,
+ robust=False,
+ robust_steps=3,
+ robust_thresh=2,
+):
+ """
+ Performs a 2D fit.
+
+ TODO: make returning the mask optional
+
+ Parameters
+ ----------
+ function : callable
+        Some `function(xy, *p)` where `xy` is a (2,N) array of pixel positions
+        (xy[0] the x coordinates, xy[1] the y coordinates), and `p` is the
+        function parameters
+ data : ndarray
+ Some 2D array of any shape (n,m)
+ data_mask : None or boolean array of shape (n,m), optional
+ If specified, fits only the pixels in `data` where this array is True
+    popt : array, optional
+ Initial guess at the parameters `p` of `function`. Note that positions
+ in pixels (i.e. the xy positions) are linearly scaled to the space [0,1]
+ robust : bool
+ Toggles robust fitting, which iteratively rejects outlier data points
+ which have a root-mean-square error beyond `robust_thresh`
+ robust_steps : int
+ The number of robust fitting iterations to perform
+ robust_thresh : int
+ The robust fitting cutoff
+
+    Returns
+    -------
+    (popt,pcov,fit_ar,mask) : 4-tuple
+        The optimal fit parameters, the fitting covariance matrix, the fit
+        array computed with the returned `popt` params, and the final mask
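+
+    A minimal usage sketch, fitting this module's `plane` function to synthetic
+    planar data (values illustrative):
+
+    >>> data = np.fromfunction(lambda i, j: 0.1 * i + 0.2 * j + 5.0, (64, 64))
+    >>> popt, pcov, fit_ar, mask = fit_2D(plane, data)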
+ """
+ # get shape
+ shape = data.shape
+ shape1D = [1, np.prod(shape)]
+
+ # x and y coordinates normalized from 0 to 1
+ x, y = np.linspace(0, 1, shape[0]), np.linspace(0, 1, shape[1])
+ ry, rx = np.meshgrid(y, x)
+ rx_1D = rx.reshape((1, np.prod(shape)))
+ ry_1D = ry.reshape((1, np.prod(shape)))
+ xy = np.vstack((rx_1D, ry_1D))
+
+ # if robust fitting is turned off, set number of robust iterations to 0
+ if robust is False:
+ robust_steps = 0
+
+ # least squares fitting
+ for k in range(robust_steps + 1):
+ # in 1st iteration, set up params and mask
+ if k == 0:
+ if popt is None:
+ popt = np.zeros((1, len(signature(function).parameters) - 1))
+ if data_mask is not None:
+ mask = data_mask
+ else:
+ mask = np.ones(shape, dtype=bool)
+
+ # otherwise, get fitting error and add high error pixels to mask
+ else:
+ fit_mean_square_error = (function(xy, *popt).reshape(shape) - data) ** 2
+ _mask = (
+ fit_mean_square_error
+ > np.mean(fit_mean_square_error) * robust_thresh**2
+ )
+ mask[_mask] = False
+
+ # perform fitting
+ popt, pcov = curve_fit(
+ function,
+ np.vstack((rx_1D[mask.reshape(shape1D)], ry_1D[mask.reshape(shape1D)])),
+ data[mask],
+ p0=popt,
+ )
+
+ fit_ar = function(xy, *popt).reshape(shape)
+ return popt, pcov, fit_ar, mask
+
+
+# Functions for fitting
+
+
+def plane(xy, mx, my, b):
+ return mx * xy[0] + my * xy[1] + b
+
+
+def parabola(xy, c0, cx1, cx2, cy1, cy2, cxy):
+ return (
+ c0
+ + cx1 * xy[0]
+ + cy1 * xy[1]
+ + cx2 * xy[0] ** 2
+ + cy2 * xy[1] ** 2
+ + cxy * xy[0] * xy[1]
+ )
+
+
+def bezier_two(xy, c00, c01, c02, c10, c11, c12, c20, c21, c22):
+ return (
+ c00 * ((1 - xy[0]) ** 2) * ((1 - xy[1]) ** 2)
+ + c10 * 2 * (1 - xy[0]) * xy[0] * ((1 - xy[1]) ** 2)
+ + c20 * (xy[0] ** 2) * ((1 - xy[1]) ** 2)
+ + c01 * 2 * ((1 - xy[0]) ** 2) * (1 - xy[1]) * xy[1]
+ + c11 * 4 * (1 - xy[0]) * xy[0] * (1 - xy[1]) * xy[1]
+ + c21 * 2 * (xy[0] ** 2) * (1 - xy[1]) * xy[1]
+ + c02 * ((1 - xy[0]) ** 2) * (xy[1] ** 2)
+ + c12 * 2 * (1 - xy[0]) * xy[0] * (xy[1] ** 2)
+ + c22 * (xy[0] ** 2) * (xy[1] ** 2)
+ )
+
+
+def polar_gaussian_2D(
+ tq,
+ I0,
+ mu_t,
+ mu_q,
+ sigma_t,
+ sigma_q,
+ C,
+):
+ # unpack position
+ t, q = tq
+    # NOTE: theta is not wrapped to its closest periodic reflection of mu_t
+    # here; see polar_twofold_gaussian_2D below for a periodic version
+    # t = np.square(t-mu_t)
+    # t2 = np.min(np.vstack([t,1-t]))
+    t2 = np.square(t - mu_t)
+ return (
+ I0 * np.exp(-(t2 / (2 * sigma_t**2) + (q - mu_q) ** 2 / (2 * sigma_q**2)))
+ + C
+ )
+
+
+def polar_twofold_gaussian_2D(
+ tq,
+ I0,
+ mu_t,
+ mu_q,
+ sigma_t,
+ sigma_q,
+):
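+    """
+    Anisotropic 2D gaussian in polar (t,q) coordinates, periodic in t with
+    period pi (i.e. twofold rotational symmetry).
+    """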
+ # unpack position
+ t, q = tq
+
+ # theta periodicity
+ dt = np.mod(t - mu_t + np.pi / 2, np.pi) - np.pi / 2
+
+ # output intensity
+ return I0 * np.exp(
+ (dt**2 / (-2.0 * sigma_t**2)) + ((q - mu_q) ** 2 / (-2.0 * sigma_q**2))
+ )
+
+
+def polar_twofold_gaussian_2D_background(
+ tq,
+ I0,
+ mu_t,
+ mu_q,
+ sigma_t,
+ sigma_q,
+ C,
+):
+ # unpack position
+ t, q = tq
+
+ # theta periodicity
+ dt = np.mod(t - mu_t + np.pi / 2, np.pi) - np.pi / 2
+
+ # output intensity
+ return C + I0 * np.exp(
+ (dt**2 / (-2.0 * sigma_t**2)) + ((q - mu_q) ** 2 / (-2.0 * sigma_q**2))
+ )
+
+
+def fit_2D_polar_gaussian(
+ data,
+ mask=None,
+ p0=None,
+ robust=False,
+ robust_steps=3,
+ robust_thresh=2,
+ constant_background=False,
+):
+ """
+
+ NOTE - this cannot work without using pixel coordinates - something is wrong in the workflow.
+
+
+ Fits a 2D gaussian to the pixels in `data` which are set to True in `mask`.
+
+ The gaussian is anisotropic and oriented along (t,q), centered at
+ (mu_t,mu_q), has standard deviations (sigma_t,sigma_q), maximum of I0,
+ and an optional constant offset of C, and is periodic in t.
+
+    f(t,q) = I0 * exp( -( (t-mu_t)^2/(2*sigma_t^2) + (q-mu_q)^2/(2*sigma_q^2) ) )
+    or
+    f(t,q) = I0 * exp( -( (t-mu_t)^2/(2*sigma_t^2) + (q-mu_q)^2/(2*sigma_q^2) ) ) + C
+
+ Parameters
+ ----------
+ data : 2d array
+ the data to fit
+ p0 : 6-tuple
+        initial guess at fit parameters, (I0,mu_t,mu_q,sigma_t,sigma_q,C)
+ mask : 2d boolean array
+ ignore pixels where mask is False
+ robust : bool
+ toggle robust fitting
+ robust_steps : int
+ number of robust fit iterations
+ robust_thresh : number
+ the robust fitting threshold
+ constant_background : bool
+ whether or not to include constant background
+
+ Returns
+ -------
+    (popt,pcov,fit_ar,mask) : 4-tuple
+        the optimal fit parameters, the covariance matrix, the fit array,
+        and the final fitting mask
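+
+    A minimal usage sketch (assuming `polar_data` is a 2D array sampled on a
+    (t,q) grid and `p0` is a reasonable initial guess):
+
+    >>> popt, pcov, fit_ar, mask = fit_2D_polar_gaussian(polar_data, p0=p0, robust=True)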
+ """
+
+ if constant_background:
+ return fit_2D(
+ polar_twofold_gaussian_2D_background,
+ data=data,
+ data_mask=mask,
+ popt=p0,
+ robust=robust,
+ robust_steps=robust_steps,
+ robust_thresh=robust_thresh,
+ )
+ else:
+ return fit_2D(
+ polar_twofold_gaussian_2D,
+ data=data,
+ data_mask=mask,
+ popt=p0,
+ robust=robust,
+ robust_steps=robust_steps,
+ robust_thresh=robust_thresh,
+ )
diff --git a/py4DSTEM/process/phase/.gitignore b/py4DSTEM/process/phase/.gitignore
new file mode 100644
index 000000000..c97f963b3
--- /dev/null
+++ b/py4DSTEM/process/phase/.gitignore
@@ -0,0 +1 @@
+*.sh
diff --git a/py4DSTEM/process/phase/__init__.py b/py4DSTEM/process/phase/__init__.py
new file mode 100644
index 000000000..1005a619d
--- /dev/null
+++ b/py4DSTEM/process/phase/__init__.py
@@ -0,0 +1,16 @@
+# fmt: off
+
+_emd_hook = True
+
+from py4DSTEM.process.phase.iterative_dpc import DPCReconstruction
+from py4DSTEM.process.phase.iterative_mixedstate_multislice_ptychography import MixedstateMultislicePtychographicReconstruction
+from py4DSTEM.process.phase.iterative_mixedstate_ptychography import MixedstatePtychographicReconstruction
+from py4DSTEM.process.phase.iterative_multislice_ptychography import MultislicePtychographicReconstruction
+from py4DSTEM.process.phase.iterative_overlap_magnetic_tomography import OverlapMagneticTomographicReconstruction
+from py4DSTEM.process.phase.iterative_overlap_tomography import OverlapTomographicReconstruction
+from py4DSTEM.process.phase.iterative_parallax import ParallaxReconstruction
+from py4DSTEM.process.phase.iterative_simultaneous_ptychography import SimultaneousPtychographicReconstruction
+from py4DSTEM.process.phase.iterative_singleslice_ptychography import SingleslicePtychographicReconstruction
+from py4DSTEM.process.phase.parameter_optimize import OptimizationParameter, PtychographyOptimizer
+
+# fmt: on
diff --git a/py4DSTEM/process/phase/iterative_base_class.py b/py4DSTEM/process/phase/iterative_base_class.py
new file mode 100644
index 000000000..767789df2
--- /dev/null
+++ b/py4DSTEM/process/phase/iterative_base_class.py
@@ -0,0 +1,2627 @@
+"""
+Module for reconstructing phase objects from 4DSTEM datasets using iterative methods.
+"""
+
+import warnings
+
+import matplotlib.pyplot as plt
+import numpy as np
+from matplotlib.gridspec import GridSpec
+from mpl_toolkits.axes_grid1 import ImageGrid
+from py4DSTEM.visualize import return_scaled_histogram_ordering, show, show_complex
+from scipy.ndimage import rotate
+
+try:
+ import cupy as cp
+except ModuleNotFoundError:
+ cp = np
+
+from emdfile import Array, Custom, Metadata, _read_metadata, tqdmnd
+from py4DSTEM.data import Calibration
+from py4DSTEM.datacube import DataCube
+from py4DSTEM.process.calibration import fit_origin
+from py4DSTEM.process.phase.iterative_ptychographic_constraints import (
+ PtychographicConstraints,
+)
+from py4DSTEM.process.phase.utils import (
+ AffineTransform,
+ generate_batches,
+ polar_aliases,
+)
+from py4DSTEM.process.utils import (
+ electron_wavelength_angstrom,
+ fourier_resample,
+ get_shifted_ar,
+)
+
+warnings.simplefilter(action="always", category=UserWarning)
+
+
+class PhaseReconstruction(Custom):
+ """
+ Base phase reconstruction class.
+ Defines various common functions and properties for subclasses to inherit.
+ """
+
+ def attach_datacube(self, datacube: DataCube):
+ """
+ Attaches a datacube to a class initialized without one.
+
+ Parameters
+ ----------
+ datacube: Datacube
+ Input 4D diffraction pattern intensities
+
+ Returns
+ --------
+ self: PhaseReconstruction
+ Self to enable chaining
+ """
+ self._datacube = datacube
+ return self
+
+ def reinitialize_parameters(self, device: str = None, verbose: bool = None):
+ """
+        Reinitializes common parameters. This is useful when loading a previously
+        saved reconstruction (which sets device='cpu' and verbose=True for
+        compatibility) using different initialization parameters.
+
+ Parameters
+ ----------
+ device: str, optional
+ If not None, imports and assigns appropriate device modules
+ verbose: bool, optional
+ If not None, sets the verbosity to verbose
+
+ Returns
+ --------
+ self: PhaseReconstruction
+ Self to enable chaining
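+
+        For example, to move a loaded reconstruction (here a hypothetical
+        `recon`) onto the GPU and silence its output:
+
+        >>> recon = recon.reinitialize_parameters(device="gpu", verbose=False)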
+ """
+
+ if device is not None:
+ if device == "cpu":
+ self._xp = np
+ self._asnumpy = np.asarray
+ from scipy.ndimage import gaussian_filter
+
+ self._gaussian_filter = gaussian_filter
+ from scipy.special import erf
+
+ self._erf = erf
+ elif device == "gpu":
+ self._xp = cp
+ self._asnumpy = cp.asnumpy
+ from cupyx.scipy.ndimage import gaussian_filter
+
+ self._gaussian_filter = gaussian_filter
+ from cupyx.scipy.special import erf
+
+ self._erf = erf
+ else:
+ raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}")
+ self._device = device
+
+ if verbose is not None:
+ self._verbose = verbose
+
+ return self
+
+ def set_save_defaults(
+ self,
+ save_datacube: bool = False,
+ save_exit_waves: bool = False,
+ save_iterations: bool = True,
+ save_iterations_frequency: int = 1,
+ ):
+ """
+ Sets the class defaults for saving reconstructions to file.
+
+ Parameters
+ ----------
+ save_datacube: bool, optional
+ If True, self._datacube saved to file
+ save_exit_waves: bool, optional
+ If True, self._exit_waves saved to file
+ save_iterations: bool, optional
+ If True, self.probe_iterations and self.object_iterations saved to file
+        save_iterations_frequency: int, optional
+ If save_iterations is True, controls the frequency of saved iterations
+
+ Returns
+ --------
+ self: PhaseReconstruction
+ Self to enable chaining
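+
+        For example, to additionally save exit waves and every 5th iteration
+        (with `recon` a hypothetical reconstruction object):
+
+        >>> recon = recon.set_save_defaults(save_exit_waves=True, save_iterations_frequency=5)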
+ """
+ self._save_datacube = save_datacube
+ self._save_exit_waves = save_exit_waves
+ self._save_iterations = save_iterations
+ self._save_iterations_frequency = save_iterations_frequency
+ return self
+
+ def _preprocess_datacube_and_vacuum_probe(
+ self,
+ datacube,
+ diffraction_intensities_shape=None,
+ reshaping_method="fourier",
+ probe_roi_shape=None,
+ vacuum_probe_intensity=None,
+ dp_mask=None,
+ com_shifts=None,
+ ):
+ """
+ Datacube preprocessing step, to set the reciprocal- and real-space sampling.
+ Let the measured diffraction intensities have size (Rx,Ry,Qx,Qy), with reciprocal-space
+ samping (dkx,dky). This sets a real-space sampling which is inversely proportional to
+ the maximum scattering wavevector (Qx*dkx,Qy*dky).
+
+ Often, it is beneficial to resample the measured diffraction intensities using a different
+ reciprocal-space sampling (dkx',dky'), e.g. downsampling to save memory. This is achieved
+ by specifying a diffraction_intensities_shape (Sx,Sy) which is different than (Qx,Qy).
+ Note this does not affect the maximum scattering wavevector (Qx*dkx,Qy*dky) = (Sx*dkx',Sy*dky'),
+ and thus the real-space sampling stays fixed.
+
+ The real space sampling, (dx, dy), combined with the resampled diffraction_intensities_shape,
+ sets the real-space probe region of interest (ROI) extent (dx*Sx, dy*Sy).
+ Occasionally, one may also want to specify a larger probe ROI extent, e.g when the probe
+        does not comfortably fit without self-overlap artifacts, or when the scan step sizes are much
+ smaller than the real-space sampling (dx,dy). This can be achieved by specifying a
+ probe_roi_shape, which is larger than diffraction_intensities_shape, which will result in
+ zero-padding of the diffraction intensities.
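+
+        As a concrete example: with (Qx,Qy) = (256,256), setting
+        diffraction_intensities_shape = (128,128) doubles the reciprocal-space
+        sampling, (dkx',dky') = (2*dkx,2*dky), leaving the maximum scattering
+        wavevector and hence the real-space sampling unchanged; further setting
+        probe_roi_shape = (192,192) zero-pads the resampled patterns, enlarging
+        the real-space probe ROI extent by a factor of 1.5.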
+
+ Parameters
+ ----------
+ datacube: Datacube
+ Input 4D diffraction pattern intensities
+ diffraction_intensities_shape: (int,int), optional
+ Resampled diffraction intensities shape.
+            If None, no resampling is performed
+        reshaping_method: str, optional
+ Reshaping method to use, one of 'bin', 'bilinear' or 'fourier' (default)
+        probe_roi_shape: (int,int), optional
+            Padded diffraction intensities shape.
+            If None, no padding is performed
+        vacuum_probe_intensity: np.ndarray, optional
+            If not None, the vacuum probe intensity is also resampled and padded
+        dp_mask: np.ndarray, optional
+            If not None, dp_mask is also resampled and padded
+        com_shifts: np.ndarray, optional
+            If not None, com_shifts are multiplied by resampling factor
+
+ Returns
+ --------
+        (datacube, vacuum_probe_intensity, dp_mask, com_shifts) : 4-tuple
+            The resampled and padded datacube, along with the correspondingly
+            resampled/padded vacuum_probe_intensity and dp_mask, and the rescaled com_shifts
+ """
+ if com_shifts is not None:
+ if np.isscalar(com_shifts[0]):
+ com_shifts = (
+ np.ones(self._datacube.Rshape) * com_shifts[0],
+ np.ones(self._datacube.Rshape) * com_shifts[1],
+ )
+
+ if diffraction_intensities_shape is not None:
+ Qx, Qy = datacube.shape[-2:]
+ Sx, Sy = diffraction_intensities_shape
+
+ resampling_factor_x = Sx / Qx
+ resampling_factor_y = Sy / Qy
+
+ if resampling_factor_x != resampling_factor_y:
+ raise ValueError(
+ "Datacube calibration can only handle uniform Q-sampling."
+ )
+
+ if com_shifts is not None:
+ com_shifts = (
+ com_shifts[0] * resampling_factor_x,
+ com_shifts[1] * resampling_factor_x,
+ )
+
+ if reshaping_method == "bin":
+ bin_factor = int(1 / resampling_factor_x)
+ if bin_factor < 1:
+ raise ValueError(
+ f"Calculated binning factor {bin_factor} is less than 1."
+ )
+
+ datacube = datacube.bin_Q(N=bin_factor)
+ if vacuum_probe_intensity is not None:
+ vacuum_probe_intensity = vacuum_probe_intensity[
+ ::bin_factor, ::bin_factor
+ ]
+ if dp_mask is not None:
+ dp_mask = dp_mask[::bin_factor, ::bin_factor]
+ else:
+ datacube = datacube.resample_Q(
+ N=resampling_factor_x, method=reshaping_method
+ )
+ if vacuum_probe_intensity is not None:
+ vacuum_probe_intensity = fourier_resample(
+ vacuum_probe_intensity,
+ output_size=diffraction_intensities_shape,
+ force_nonnegative=True,
+ )
+ if dp_mask is not None:
+ dp_mask = fourier_resample(
+ dp_mask,
+ output_size=diffraction_intensities_shape,
+ force_nonnegative=True,
+ )
+
+ if probe_roi_shape is not None:
+ Qx, Qy = datacube.shape[-2:]
+ Sx, Sy = probe_roi_shape
+ datacube = datacube.pad_Q(output_size=probe_roi_shape)
+
+ if vacuum_probe_intensity is not None or dp_mask is not None:
+ pad_kx = Sx - Qx
+ pad_kx = (pad_kx // 2, pad_kx // 2 + pad_kx % 2)
+
+ pad_ky = Sy - Qy
+ pad_ky = (pad_ky // 2, pad_ky // 2 + pad_ky % 2)
+
+ if vacuum_probe_intensity is not None:
+ vacuum_probe_intensity = np.pad(
+ vacuum_probe_intensity, pad_width=(pad_kx, pad_ky), mode="constant"
+ )
+
+ if dp_mask is not None:
+ dp_mask = np.pad(dp_mask, pad_width=(pad_kx, pad_ky), mode="constant")
+
+ return datacube, vacuum_probe_intensity, dp_mask, com_shifts
+
+ def _extract_intensities_and_calibrations_from_datacube(
+ self,
+ datacube: DataCube,
+ require_calibrations: bool = False,
+ force_scan_sampling: float = None,
+ force_angular_sampling: float = None,
+ force_reciprocal_sampling: float = None,
+ ):
+ """
+ Method to extract intensities and calibrations from datacube.
+
+ Parameters
+ ----------
+ datacube: DataCube
+ Input 4D diffraction pattern intensities
+ require_calibrations: bool
+ If False, warning is issued instead of raising an error
+ force_scan_sampling: float, optional
+ Override DataCube real space scan pixel size calibrations, in Angstrom
+ force_angular_sampling: float, optional
+ Override DataCube reciprocal pixel size calibration, in mrad
+ force_reciprocal_sampling: float, optional
+ Override DataCube reciprocal pixel size calibration, in A^-1
+
+ Assigns
+ --------
+ self._grid_scan_shape: Tuple[int,int]
+ Real-space scan size
+ self._scan_sampling: Tuple[float,float]
+ Real-space scan step sizes in 'A' or 'pixels'
+ self._scan_units: Tuple[str,str]
+ Real-space scan units
+ self._angular_sampling: Tuple[float,float]
+ Reciprocal-space sampling in 'mrad' or 'pixels'
+ self._angular_units: Tuple[str,str]
+ Reciprocal-space angular units
+ self._reciprocal_sampling: Tuple[float,float]
+ Reciprocal-space sampling in 'A^-1' or 'pixels'
+ self._reciprocal_units: Tuple[str,str]
+ Reciprocal-space units
+
+ Returns
+ -------
+ intensities: (Rx,Ry,Qx,Qy) xp.ndarray
+ Raw intensities array stored on device, with dtype xp.float32
+
+ Raises
+ ------
+ ValueError
+ If require_calibrations is True and calibrations are not set
+
+ Warns
+ ------
+ UserWarning
+ If require_calibrations is False and calibrations are not set
+ """
+
+ # Copies intensities to device casting to float32
+ xp = self._xp
+
+ intensities = xp.asarray(datacube.data, dtype=xp.float32)
+ self._grid_scan_shape = intensities.shape[:2]
+
+ # Extracts calibrations
+ calibration = datacube.calibration
+ real_space_units = calibration.get_R_pixel_units()
+ reciprocal_space_units = calibration.get_Q_pixel_units()
+
+ # Real-space
+ if force_scan_sampling is not None:
+ self._scan_sampling = (force_scan_sampling, force_scan_sampling)
+ self._scan_units = "A"
+ else:
+ if real_space_units == "pixels":
+ if require_calibrations:
+ raise ValueError("Real-space calibrations must be given in 'A'")
+
+ if self._verbose:
+ warnings.warn(
+ (
+ "Iterative reconstruction will not be quantitative unless you specify "
+ "real-space calibrations in 'A'"
+ ),
+ UserWarning,
+ )
+
+ self._scan_sampling = (1.0, 1.0)
+ self._scan_units = ("pixels",) * 2
+
+ elif real_space_units == "A":
+ self._scan_sampling = (calibration.get_R_pixel_size(),) * 2
+ self._scan_units = ("A",) * 2
+ elif real_space_units == "nm":
+ self._scan_sampling = (calibration.get_R_pixel_size() * 10,) * 2
+ self._scan_units = ("A",) * 2
+ else:
+ raise ValueError(
+ f"Real-space calibrations must be given in 'A', not {real_space_units}"
+ )
+
+ # Reciprocal-space
+ if force_angular_sampling is not None or force_reciprocal_sampling is not None:
+ # Python has no boolean `xor` keyword, but `!=` on two bools is equivalent
+ angular = force_angular_sampling is not None
+ reciprocal = force_reciprocal_sampling is not None
+ assert (
+ angular != reciprocal
+ ), "Only one of angular or reciprocal calibration can be forced!"
+
+ # angular calibration specified
+ if angular:
+ self._angular_sampling = (force_angular_sampling,) * 2
+ self._angular_units = ("mrad",) * 2
+
+ if self._energy is not None:
+ self._reciprocal_sampling = (
+ force_angular_sampling
+ / electron_wavelength_angstrom(self._energy)
+ / 1e3,
+ ) * 2
+ self._reciprocal_units = ("A^-1",) * 2
+
+ # reciprocal calibration specified
+ if reciprocal:
+ self._reciprocal_sampling = (force_reciprocal_sampling,) * 2
+ self._reciprocal_units = ("A^-1",) * 2
+
+ if self._energy is not None:
+ self._angular_sampling = (
+ force_reciprocal_sampling
+ * electron_wavelength_angstrom(self._energy)
+ * 1e3,
+ ) * 2
+ self._angular_units = ("mrad",) * 2
+
+ else:
+ if reciprocal_space_units == "pixels":
+ if require_calibrations:
+ raise ValueError(
+ "Reciprocal-space calibrations must be given in in 'A^-1' or 'mrad'"
+ )
+
+ if self._verbose:
+ warnings.warn(
+ (
+ "Iterative reconstruction will not be quantitative unless you specify "
+ "appropriate reciprocal-space calibrations"
+ ),
+ UserWarning,
+ )
+
+ self._angular_sampling = (1.0, 1.0)
+ self._angular_units = ("pixels",) * 2
+ self._reciprocal_sampling = (1.0, 1.0)
+ self._reciprocal_units = ("pixels",) * 2
+
+ elif reciprocal_space_units == "A^-1":
+ reciprocal_size = calibration.get_Q_pixel_size()
+ self._reciprocal_sampling = (reciprocal_size,) * 2
+ self._reciprocal_units = ("A^-1",) * 2
+
+ if self._energy is not None:
+ self._angular_sampling = (
+ reciprocal_size
+ * electron_wavelength_angstrom(self._energy)
+ * 1e3,
+ ) * 2
+ self._angular_units = ("mrad",) * 2
+
+ elif reciprocal_space_units == "mrad":
+ angular_size = calibration.get_Q_pixel_size()
+ self._angular_sampling = (angular_size,) * 2
+ self._angular_units = ("mrad",) * 2
+
+ if self._energy is not None:
+ self._reciprocal_sampling = (
+ angular_size / electron_wavelength_angstrom(self._energy) / 1e3,
+ ) * 2
+ self._reciprocal_units = ("A^-1",) * 2
+ else:
+ raise ValueError(
+ (
+ "Reciprocal-space calibrations must be given in 'A^-1' or 'mrad', "
+ f"not {reciprocal_space_units}"
+ )
+ )
+
+ return intensities
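+
+ # Note on the conversions above: they follow k = theta / lambda, i.e.
+ # sampling in A^-1 = sampling in mrad / (electron wavelength in A * 1e3).
+ # A worked example, assuming 300 kV (lambda ~ 0.0197 A): 1 mrad maps to
+ # 1 / 0.0197 / 1e3 ~ 0.051 A^-1.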
+
+ def _calculate_intensities_center_of_mass(
+ self,
+ intensities: np.ndarray,
+ dp_mask: np.ndarray = None,
+ fit_function: str = "plane",
+ com_shifts: np.ndarray = None,
+ com_measured: np.ndarray = None,
+ ):
+ """
+ Common preprocessing function to compute and fit diffraction intensities CoM
+
+ Parameters
+ ----------
+ intensities: (Rx,Ry,Qx,Qy) xp.ndarray
+ Raw intensities array stored on device, with dtype xp.float32
+ dp_mask: ndarray
+ If not None, apply mask to datacube amplitude
+ fit_function: str, optional
+ 2D fitting function for CoM fitting. One of 'plane', 'parabola', 'bezier_two'
+ com_shifts: tuple of ndarrays (CoMx shifts, CoMy shifts), optional
+ If not None, com_shifts are used directly as the fitted CoM values, skipping the fit
+ com_measured: tuple of ndarrays (CoMx measured, CoMy measured), optional
+ If not None, com_measured are used directly as com_measured_x, com_measured_y
+
+ Returns
+ -------
+ com_measured_x: (Rx,Ry) xp.ndarray
+ Measured horizontal center of mass gradient
+ com_measured_y: (Rx,Ry) xp.ndarray
+ Measured vertical center of mass gradient
+ com_fitted_x: (Rx,Ry) xp.ndarray
+ Best fit horizontal center of mass gradient
+ com_fitted_y: (Rx,Ry) xp.ndarray
+ Best fit vertical center of mass gradient
+ com_normalized_x: (Rx,Ry) xp.ndarray
+ Normalized horizontal center of mass gradient
+ com_normalized_y: (Rx,Ry) xp.ndarray
+ Normalized vertical center of mass gradient
+ """
+
+ xp = self._xp
+ asnumpy = self._asnumpy
+
+ # for ptycho
+ if com_measured is not None:
+ com_measured_x, com_measured_y = com_measured
+
+ else:
+ # Coordinates
+ kx = xp.arange(intensities.shape[-2], dtype=xp.float32)
+ ky = xp.arange(intensities.shape[-1], dtype=xp.float32)
+ kya, kxa = xp.meshgrid(ky, kx)
+
+ # calculate CoM
+ if dp_mask is not None:
+ if dp_mask.shape != intensities.shape[-2:]:
+ raise ValueError(
+ (
+ f"Mask shape should be (Qx,Qy):{intensities.shape[-2:]}, "
+ f"not {dp_mask.shape}"
+ )
+ )
+ intensities_mask = intensities * xp.asarray(dp_mask, dtype=xp.float32)
+ else:
+ intensities_mask = intensities
+
+ intensities_sum = xp.sum(intensities_mask, axis=(-2, -1))
+ com_measured_x = (
+ xp.sum(intensities_mask * kxa[None, None], axis=(-2, -1))
+ / intensities_sum
+ )
+ com_measured_y = (
+ xp.sum(intensities_mask * kya[None, None], axis=(-2, -1))
+ / intensities_sum
+ )
+
+ if com_shifts is None:
+ com_measured_x_np = asnumpy(com_measured_x)
+ com_measured_y_np = asnumpy(com_measured_y)
+ finite_mask = np.isfinite(com_measured_x_np)
+
+ com_shifts = fit_origin(
+ (com_measured_x_np, com_measured_y_np),
+ fitfunction=fit_function,
+ mask=finite_mask,
+ )
+
+ # Fit function to center of mass
+ com_fitted_x = xp.asarray(com_shifts[0], dtype=xp.float32)
+ com_fitted_y = xp.asarray(com_shifts[1], dtype=xp.float32)
+
+ # fix CoM units
+ com_normalized_x = (
+ xp.nan_to_num(com_measured_x - com_fitted_x) * self._reciprocal_sampling[0]
+ )
+ com_normalized_y = (
+ xp.nan_to_num(com_measured_y - com_fitted_y) * self._reciprocal_sampling[1]
+ )
+
+ return (
+ com_measured_x,
+ com_measured_y,
+ com_fitted_x,
+ com_fitted_y,
+ com_normalized_x,
+ com_normalized_y,
+ )
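+
+ # A minimal single-pattern numpy sketch of the weighted-centroid computation
+ # above, assuming a hypothetical (Qx, Qy) intensity array I:
+ #
+ # kx = np.arange(I.shape[0])[:, None]
+ # ky = np.arange(I.shape[1])[None, :]
+ # com_x = (I * kx).sum() / I.sum()
+ # com_y = (I * ky).sum() / I.sum()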
+
+ def _solve_for_center_of_mass_relative_rotation(
+ self,
+ _com_measured_x: np.ndarray,
+ _com_measured_y: np.ndarray,
+ _com_normalized_x: np.ndarray,
+ _com_normalized_y: np.ndarray,
+ rotation_angles_deg: np.ndarray = np.arange(-89.0, 90.0, 1.0),
+ plot_rotation: bool = True,
+ plot_center_of_mass: str = "default",
+ maximize_divergence: bool = False,
+ force_com_rotation: float = None,
+ force_com_transpose: bool = None,
+ **kwargs,
+ ):
+ """
+ Common method to solve for the relative rotation between scan directions
+ and the reciprocal coordinate system. We do this by minimizing the curl of the
+ CoM gradient vector field or, alternatively, maximizing the divergence.
+
+ Parameters
+ ----------
+ _com_measured_x: (Rx,Ry) xp.ndarray
+ Measured horizontal center of mass gradient
+ _com_measured_y: (Rx,Ry) xp.ndarray
+ Measured vertical center of mass gradient
+ _com_normalized_x: (Rx,Ry) xp.ndarray
+ Normalized horizontal center of mass gradient
+ _com_normalized_y: (Rx,Ry) xp.ndarray
+ Normalized vertical center of mass gradient
+ rotation_angles_deg: ndarray, optional
+ Array of angles in degrees to perform curl minimization over
+ plot_rotation: bool, optional
+ If True, the CoM curl minimization search result will be displayed
+ plot_center_of_mass: str, optional
+ If 'default', the corrected CoM arrays will be displayed
+ If 'all', the computed and fitted CoM arrays will be displayed
+ maximize_divergence: bool, optional
+ If True, the divergence of the CoM gradient vector field is maximized
+ force_com_rotation: float (degrees), optional
+ Force relative rotation angle between real and reciprocal space
+ force_com_transpose: bool, optional
+ Force whether diffraction intensities need to be transposed.
+
+ Returns
+ --------
+ _rotation_best_rad: float
+ Rotation angle which minimizes CoM curl, in radians
+ _rotation_best_transpose: bool
+ Whether diffraction intensities need to be transposed to minimize CoM curl
+ _com_x: xp.ndarray
+ Corrected horizontal center of mass gradient, on calculation device
+ _com_y: xp.ndarray
+ Corrected vertical center of mass gradient, on calculation device
+ com_x: np.ndarray
+ Corrected horizontal center of mass gradient, as a numpy array
+ com_y: np.ndarray
+ Corrected vertical center of mass gradient, as a numpy array
+
+ Displays
+ --------
+ rotation_curl/div vs rotation_angles_deg, optional
+ Vector calculus quantity being minimized/maximized
+ com_measured_x/y, com_normalized_x/y and com_x/y, optional
+ Measured and normalized CoM gradients
+ rotation_best_deg, optional
+ Summary statistics
+ """
+
+ xp = self._xp
+ asnumpy = self._asnumpy
+
+ if force_com_rotation is not None:
+ # Rotation known
+
+ _rotation_best_rad = np.deg2rad(force_com_rotation)
+
+ if self._verbose:
+ warnings.warn(
+ (
+ "Best fit rotation forced to "
+ f"{force_com_rotation:.0f} degrees."
+ ),
+ UserWarning,
+ )
+
+ if force_com_transpose is not None:
+ # Transpose known
+
+ _rotation_best_transpose = force_com_transpose
+
+ if self._verbose:
+ warnings.warn(
+ f"Transpose of intensities forced to {force_com_transpose}.",
+ UserWarning,
+ )
+
+ else:
+ # Rotation known, transpose unknown
+ com_measured_x = (
+ xp.cos(_rotation_best_rad) * _com_normalized_x
+ - xp.sin(_rotation_best_rad) * _com_normalized_y
+ )
+ com_measured_y = (
+ xp.sin(_rotation_best_rad) * _com_normalized_x
+ + xp.cos(_rotation_best_rad) * _com_normalized_y
+ )
+ if maximize_divergence:
+ com_grad_x_x = com_measured_x[2:, 1:-1] - com_measured_x[:-2, 1:-1]
+ com_grad_y_y = com_measured_y[1:-1, 2:] - com_measured_y[1:-1, :-2]
+ rotation_div = xp.mean(xp.abs(com_grad_x_x + com_grad_y_y))
+ else:
+ com_grad_x_y = com_measured_x[1:-1, 2:] - com_measured_x[1:-1, :-2]
+ com_grad_y_x = com_measured_y[2:, 1:-1] - com_measured_y[:-2, 1:-1]
+ rotation_curl = xp.mean(xp.abs(com_grad_y_x - com_grad_x_y))
+
+ com_measured_x = (
+ xp.cos(_rotation_best_rad) * _com_normalized_y
+ - xp.sin(_rotation_best_rad) * _com_normalized_x
+ )
+ com_measured_y = (
+ xp.sin(_rotation_best_rad) * _com_normalized_y
+ + xp.cos(_rotation_best_rad) * _com_normalized_x
+ )
+ if maximize_divergence:
+ com_grad_x_x = com_measured_x[2:, 1:-1] - com_measured_x[:-2, 1:-1]
+ com_grad_y_y = com_measured_y[1:-1, 2:] - com_measured_y[1:-1, :-2]
+ rotation_div_transpose = xp.mean(
+ xp.abs(com_grad_x_x + com_grad_y_y)
+ )
+ else:
+ com_grad_x_y = com_measured_x[1:-1, 2:] - com_measured_x[1:-1, :-2]
+ com_grad_y_x = com_measured_y[2:, 1:-1] - com_measured_y[:-2, 1:-1]
+ rotation_curl_transpose = xp.mean(
+ xp.abs(com_grad_y_x - com_grad_x_y)
+ )
+
+ if maximize_divergence:
+ _rotation_best_transpose = rotation_div_transpose > rotation_div
+ else:
+ _rotation_best_transpose = rotation_curl_transpose < rotation_curl
+
+ if self._verbose:
+ if _rotation_best_transpose:
+ print("Diffraction intensities should be transposed.")
+ else:
+ print("No need to transpose diffraction intensities.")
+
+ else:
+ # Rotation unknown
+ if force_com_transpose is not None:
+ # Transpose known, rotation unknown
+
+ _rotation_best_transpose = force_com_transpose
+
+ if self._verbose:
+ warnings.warn(
+ f"Transpose of intensities forced to {force_com_transpose}.",
+ UserWarning,
+ )
+
+ rotation_angles_deg = xp.asarray(rotation_angles_deg, dtype=xp.float32)
+ rotation_angles_rad = xp.deg2rad(rotation_angles_deg)[:, None, None]
+
+ if _rotation_best_transpose:
+ com_measured_x = (
+ xp.cos(rotation_angles_rad) * _com_normalized_y[None]
+ - xp.sin(rotation_angles_rad) * _com_normalized_x[None]
+ )
+ com_measured_y = (
+ xp.sin(rotation_angles_rad) * _com_normalized_y[None]
+ + xp.cos(rotation_angles_rad) * _com_normalized_x[None]
+ )
+
+ rotation_angles_rad = asnumpy(xp.squeeze(rotation_angles_rad))
+ rotation_angles_deg = asnumpy(rotation_angles_deg)
+
+ if maximize_divergence:
+ com_grad_x_x = (
+ com_measured_x[:, 2:, 1:-1] - com_measured_x[:, :-2, 1:-1]
+ )
+ com_grad_y_y = (
+ com_measured_y[:, 1:-1, 2:] - com_measured_y[:, 1:-1, :-2]
+ )
+ rotation_div_transpose = xp.mean(
+ xp.abs(com_grad_x_x + com_grad_y_y), axis=(-2, -1)
+ )
+
+ ind_trans_max = xp.argmax(rotation_div_transpose).item()
+ rotation_best_deg = rotation_angles_deg[ind_trans_max]
+ _rotation_best_rad = rotation_angles_rad[ind_trans_max]
+
+ else:
+ com_grad_x_y = (
+ com_measured_x[:, 1:-1, 2:] - com_measured_x[:, 1:-1, :-2]
+ )
+ com_grad_y_x = (
+ com_measured_y[:, 2:, 1:-1] - com_measured_y[:, :-2, 1:-1]
+ )
+ rotation_curl_transpose = xp.mean(
+ xp.abs(com_grad_y_x - com_grad_x_y), axis=(-2, -1)
+ )
+
+ ind_trans_min = xp.argmin(rotation_curl_transpose).item()
+ rotation_best_deg = rotation_angles_deg[ind_trans_min]
+ _rotation_best_rad = rotation_angles_rad[ind_trans_min]
+
+ else:
+ com_measured_x = (
+ xp.cos(rotation_angles_rad) * _com_normalized_x[None]
+ - xp.sin(rotation_angles_rad) * _com_normalized_y[None]
+ )
+ com_measured_y = (
+ xp.sin(rotation_angles_rad) * _com_normalized_x[None]
+ + xp.cos(rotation_angles_rad) * _com_normalized_y[None]
+ )
+
+ rotation_angles_rad = asnumpy(xp.squeeze(rotation_angles_rad))
+ rotation_angles_deg = asnumpy(rotation_angles_deg)
+
+ if maximize_divergence:
+ com_grad_x_x = (
+ com_measured_x[:, 2:, 1:-1] - com_measured_x[:, :-2, 1:-1]
+ )
+ com_grad_y_y = (
+ com_measured_y[:, 1:-1, 2:] - com_measured_y[:, 1:-1, :-2]
+ )
+ rotation_div = xp.mean(
+ xp.abs(com_grad_x_x + com_grad_y_y), axis=(-2, -1)
+ )
+
+ ind_max = xp.argmax(rotation_div).item()
+ rotation_best_deg = rotation_angles_deg[ind_max]
+ _rotation_best_rad = rotation_angles_rad[ind_max]
+
+ else:
+ com_grad_x_y = (
+ com_measured_x[:, 1:-1, 2:] - com_measured_x[:, 1:-1, :-2]
+ )
+ com_grad_y_x = (
+ com_measured_y[:, 2:, 1:-1] - com_measured_y[:, :-2, 1:-1]
+ )
+ rotation_curl = xp.mean(
+ xp.abs(com_grad_y_x - com_grad_x_y), axis=(-2, -1)
+ )
+
+ ind_min = xp.argmin(rotation_curl).item()
+ rotation_best_deg = rotation_angles_deg[ind_min]
+ _rotation_best_rad = rotation_angles_rad[ind_min]
+
+ if self._verbose:
+ print(("Best fit rotation = " f"{rotation_best_deg:.0f} degrees."))
+
+ if plot_rotation:
+ figsize = kwargs.get("figsize", (8, 2))
+ fig, ax = plt.subplots(figsize=figsize)
+
+ if _rotation_best_transpose:
+ ax.plot(
+ rotation_angles_deg,
+ asnumpy(rotation_div_transpose)
+ if maximize_divergence
+ else asnumpy(rotation_curl_transpose),
+ label="CoM after transpose",
+ )
+ else:
+ ax.plot(
+ rotation_angles_deg,
+ asnumpy(rotation_div)
+ if maximize_divergence
+ else asnumpy(rotation_curl),
+ label="CoM",
+ )
+
+ y_r = ax.get_ylim()
+ ax.plot(
+ np.ones(2) * rotation_best_deg,
+ y_r,
+ color=(0, 0, 0, 1),
+ )
+
+ ax.legend(loc="best")
+ ax.set_xlabel("Rotation [degrees]")
+ if maximize_divergence:
+ aspect_ratio = (
+ np.ptp(rotation_div_transpose)
+ if _rotation_best_transpose
+ else np.ptp(rotation_div)
+ )
+ ax.set_ylabel("Mean Absolute Divergence")
+ ax.set_aspect(np.ptp(rotation_angles_deg) / aspect_ratio / 4)
+ else:
+ aspect_ratio = (
+ np.ptp(rotation_curl_transpose)
+ if _rotation_best_transpose
+ else np.ptp(rotation_curl)
+ )
+ ax.set_ylabel("Mean Absolute Curl")
+ ax.set_aspect(np.ptp(rotation_angles_deg) / aspect_ratio / 4)
+ fig.tight_layout()
+
+ else:
+ # Transpose unknown, rotation unknown
+ rotation_angles_deg = xp.asarray(rotation_angles_deg, dtype=xp.float32)
+ rotation_angles_rad = xp.deg2rad(rotation_angles_deg)[:, None, None]
+
+ # Untransposed
+ com_measured_x = (
+ xp.cos(rotation_angles_rad) * _com_normalized_x[None]
+ - xp.sin(rotation_angles_rad) * _com_normalized_y[None]
+ )
+ com_measured_y = (
+ xp.sin(rotation_angles_rad) * _com_normalized_x[None]
+ + xp.cos(rotation_angles_rad) * _com_normalized_y[None]
+ )
+
+ if maximize_divergence:
+ com_grad_x_x = (
+ com_measured_x[:, 2:, 1:-1] - com_measured_x[:, :-2, 1:-1]
+ )
+ com_grad_y_y = (
+ com_measured_y[:, 1:-1, 2:] - com_measured_y[:, 1:-1, :-2]
+ )
+ rotation_div = xp.mean(
+ xp.abs(com_grad_x_x + com_grad_y_y), axis=(-2, -1)
+ )
+ else:
+ com_grad_x_y = (
+ com_measured_x[:, 1:-1, 2:] - com_measured_x[:, 1:-1, :-2]
+ )
+ com_grad_y_x = (
+ com_measured_y[:, 2:, 1:-1] - com_measured_y[:, :-2, 1:-1]
+ )
+ rotation_curl = xp.mean(
+ xp.abs(com_grad_y_x - com_grad_x_y), axis=(-2, -1)
+ )
+
+ # Transposed
+ com_measured_x = (
+ xp.cos(rotation_angles_rad) * _com_normalized_y[None]
+ - xp.sin(rotation_angles_rad) * _com_normalized_x[None]
+ )
+ com_measured_y = (
+ xp.sin(rotation_angles_rad) * _com_normalized_y[None]
+ + xp.cos(rotation_angles_rad) * _com_normalized_x[None]
+ )
+
+ if maximize_divergence:
+ com_grad_x_x = (
+ com_measured_x[:, 2:, 1:-1] - com_measured_x[:, :-2, 1:-1]
+ )
+ com_grad_y_y = (
+ com_measured_y[:, 1:-1, 2:] - com_measured_y[:, 1:-1, :-2]
+ )
+ rotation_div_transpose = xp.mean(
+ xp.abs(com_grad_x_x + com_grad_y_y), axis=(-2, -1)
+ )
+ else:
+ com_grad_x_y = (
+ com_measured_x[:, 1:-1, 2:] - com_measured_x[:, 1:-1, :-2]
+ )
+ com_grad_y_x = (
+ com_measured_y[:, 2:, 1:-1] - com_measured_y[:, :-2, 1:-1]
+ )
+ rotation_curl_transpose = xp.mean(
+ xp.abs(com_grad_y_x - com_grad_x_y), axis=(-2, -1)
+ )
+
+ rotation_angles_rad = asnumpy(xp.squeeze(rotation_angles_rad))
+ rotation_angles_deg = asnumpy(rotation_angles_deg)
+
+ # Find lowest curl / maximum divergence value
+ if maximize_divergence:
+ # Maximize Divergence
+ ind_max = xp.argmax(rotation_div).item()
+ ind_trans_max = xp.argmax(rotation_div_transpose).item()
+
+ if rotation_div[ind_max] >= rotation_div_transpose[ind_trans_max]:
+ rotation_best_deg = rotation_angles_deg[ind_max]
+ _rotation_best_rad = rotation_angles_rad[ind_max]
+ _rotation_best_transpose = False
+ else:
+ rotation_best_deg = rotation_angles_deg[ind_trans_max]
+ _rotation_best_rad = rotation_angles_rad[ind_trans_max]
+ _rotation_best_transpose = True
+
+ self._rotation_div = rotation_div
+ self._rotation_div_transpose = rotation_div_transpose
+ else:
+ # Minimize Curl
+ ind_min = xp.argmin(rotation_curl).item()
+ ind_trans_min = xp.argmin(rotation_curl_transpose).item()
+ self._rotation_curl = rotation_curl
+ self._rotation_curl_transpose = rotation_curl_transpose
+ if rotation_curl[ind_min] <= rotation_curl_transpose[ind_trans_min]:
+ rotation_best_deg = rotation_angles_deg[ind_min]
+ _rotation_best_rad = rotation_angles_rad[ind_min]
+ _rotation_best_transpose = False
+ else:
+ rotation_best_deg = rotation_angles_deg[ind_trans_min]
+ _rotation_best_rad = rotation_angles_rad[ind_trans_min]
+ _rotation_best_transpose = True
+
+ self._rotation_angles_deg = rotation_angles_deg
+ # Print summary
+ if self._verbose:
+ print(("Best fit rotation = " f"{rotation_best_deg:.0f} degrees."))
+ if _rotation_best_transpose:
+ print("Diffraction intensities should be transposed.")
+ else:
+ print("No need to transpose diffraction intensities.")
+
+ # Plot Curl/Div rotation
+ if plot_rotation:
+ figsize = kwargs.get("figsize", (8, 2))
+ fig, ax = plt.subplots(figsize=figsize)
+
+ ax.plot(
+ rotation_angles_deg,
+ asnumpy(rotation_div)
+ if maximize_divergence
+ else asnumpy(rotation_curl),
+ label="CoM",
+ )
+ ax.plot(
+ rotation_angles_deg,
+ asnumpy(rotation_div_transpose)
+ if maximize_divergence
+ else asnumpy(rotation_curl_transpose),
+ label="CoM after transpose",
+ )
+ y_r = ax.get_ylim()
+ ax.plot(
+ np.ones(2) * rotation_best_deg,
+ y_r,
+ color=(0, 0, 0, 1),
+ )
+
+ ax.legend(loc="best")
+ ax.set_xlabel("Rotation [degrees]")
+ if maximize_divergence:
+ ax.set_ylabel("Mean Absolute Divergence")
+ ax.set_aspect(
+ np.ptp(rotation_angles_deg)
+ / np.maximum(
+ np.ptp(rotation_div),
+ np.ptp(rotation_div_transpose),
+ )
+ / 4
+ )
+ else:
+ ax.set_ylabel("Mean Absolute Curl")
+ ax.set_aspect(
+ np.ptp(rotation_angles_deg)
+ / np.maximum(
+ np.ptp(rotation_curl),
+ np.ptp(rotation_curl_transpose),
+ )
+ / 4
+ )
+ fig.tight_layout()
+
+ # Calculate corrected CoM
+ if _rotation_best_transpose:
+ _com_x = (
+ xp.cos(_rotation_best_rad) * _com_normalized_y
+ - xp.sin(_rotation_best_rad) * _com_normalized_x
+ )
+ _com_y = (
+ xp.sin(_rotation_best_rad) * _com_normalized_y
+ + xp.cos(_rotation_best_rad) * _com_normalized_x
+ )
+ else:
+ _com_x = (
+ xp.cos(_rotation_best_rad) * _com_normalized_x
+ - xp.sin(_rotation_best_rad) * _com_normalized_y
+ )
+ _com_y = (
+ xp.sin(_rotation_best_rad) * _com_normalized_x
+ + xp.cos(_rotation_best_rad) * _com_normalized_y
+ )
+
+ # 'Public'-facing attributes as numpy arrays
+ com_x = asnumpy(_com_x)
+ com_y = asnumpy(_com_y)
+
+ # Optionally, plot CoM
+ if plot_center_of_mass == "all":
+ figsize = kwargs.pop("figsize", (8, 12))
+ cmap = kwargs.pop("cmap", "RdBu_r")
+ extent = [
+ 0,
+ self._scan_sampling[1] * _com_measured_x.shape[1],
+ self._scan_sampling[0] * _com_measured_x.shape[0],
+ 0,
+ ]
+
+ fig = plt.figure(figsize=figsize)
+ grid = ImageGrid(fig, 111, nrows_ncols=(3, 2), axes_pad=(0.25, 0.5))
+
+ for ax, arr, title in zip(
+ grid,
+ [
+ _com_measured_x,
+ _com_measured_y,
+ _com_normalized_x,
+ _com_normalized_y,
+ com_x,
+ com_y,
+ ],
+ [
+ "CoM_x",
+ "CoM_y",
+ "Normalized CoM_x",
+ "Normalized CoM_y",
+ "Corrected CoM_x",
+ "Corrected CoM_y",
+ ],
+ ):
+ ax.imshow(asnumpy(arr), extent=extent, cmap=cmap, **kwargs)
+ ax.set_ylabel(f"x [{self._scan_units[0]}]")
+ ax.set_xlabel(f"y [{self._scan_units[1]}]")
+ ax.set_title(title)
+
+ elif plot_center_of_mass == "default":
+ figsize = kwargs.pop("figsize", (8, 4))
+ cmap = kwargs.pop("cmap", "RdBu_r")
+
+ extent = [
+ 0,
+ self._scan_sampling[1] * com_x.shape[1],
+ self._scan_sampling[0] * com_x.shape[0],
+ 0,
+ ]
+
+ fig = plt.figure(figsize=figsize)
+ grid = ImageGrid(fig, 111, nrows_ncols=(1, 2), axes_pad=(0.25, 0.5))
+
+ for ax, arr, title in zip(
+ grid,
+ [
+ com_x,
+ com_y,
+ ],
+ [
+ "Corrected CoM_x",
+ "Corrected CoM_y",
+ ],
+ ):
+ ax.imshow(arr, extent=extent, cmap=cmap, **kwargs)
+ ax.set_ylabel(f"x [{self._scan_units[0]}]")
+ ax.set_xlabel(f"y [{self._scan_units[1]}]")
+ ax.set_title(title)
+
+ return (
+ _rotation_best_rad,
+ _rotation_best_transpose,
+ _com_x,
+ _com_y,
+ com_x,
+ com_y,
+ )
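+
+ # Rationale for the search above: the normalized CoM signal of a thin object
+ # is approximately a gradient (conservative) field, so at the correct
+ # relative rotation its curl, d(CoM_y)/dx - d(CoM_x)/dy, vanishes, while its
+ # divergence, d(CoM_x)/dx + d(CoM_y)/dy, is extremal; both are estimated
+ # above with central differences on the interior pixels.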
+
+ def _normalize_diffraction_intensities(
+ self,
+ diffraction_intensities,
+ com_fitted_x,
+ com_fitted_y,
+ crop_patterns,
+ positions_mask,
+ ):
+ """
+ Fix diffraction intensities CoM, shift to origin, and take square root
+
+ Parameters
+ ----------
+ diffraction_intensities: (Rx,Ry,Sx,Sy) np.ndarray
+ Zero-padded diffraction intensities
+ com_fitted_x: (Rx,Ry) xp.ndarray
+ Best fit horizontal center of mass gradient
+ com_fitted_y: (Rx,Ry) xp.ndarray
+ Best fit vertical center of mass gradient
+ crop_patterns: bool
+ If True, crop patterns to avoid wrap-around artifacts when centering
+ positions_mask: np.ndarray, optional
+ Boolean real-space mask selecting which datacube positions to include in the reconstruction
+
+ Returns
+ -------
+ amplitudes: (Rx * Ry, Sx, Sy) np.ndarray
+ Flat array of normalized diffraction amplitudes
+ mean_intensity: float
+ Mean intensity value
+ """
+
+ xp = self._xp
+ mean_intensity = 0
+
+ diffraction_intensities = self._asnumpy(diffraction_intensities)
+ if positions_mask is not None:
+ number_of_patterns = np.count_nonzero(positions_mask.ravel())
+ else:
+ number_of_patterns = np.prod(diffraction_intensities.shape[:2])
+
+ if crop_patterns:
+ crop_x = int(
+ np.minimum(
+ diffraction_intensities.shape[2] - com_fitted_x.max(),
+ com_fitted_x.min(),
+ )
+ )
+ crop_y = int(
+ np.minimum(
+ diffraction_intensities.shape[3] - com_fitted_y.max(),
+ com_fitted_y.min(),
+ )
+ )
+
+ crop_w = np.minimum(crop_y, crop_x)
+ region_of_interest_shape = (crop_w * 2, crop_w * 2)
+ amplitudes = np.zeros(
+ (
+ number_of_patterns,
+ crop_w * 2,
+ crop_w * 2,
+ ),
+ dtype=np.float32,
+ )
+
+ crop_mask = np.zeros(diffraction_intensities.shape[-2:], dtype=np.bool_)
+ crop_mask[:crop_w, :crop_w] = True
+ crop_mask[-crop_w:, :crop_w] = True
+ crop_mask[:crop_w:, -crop_w:] = True
+ crop_mask[-crop_w:, -crop_w:] = True
+ self._crop_mask = crop_mask
+
+ else:
+ region_of_interest_shape = diffraction_intensities.shape[-2:]
+ amplitudes = np.zeros(
+ (number_of_patterns,) + region_of_interest_shape, dtype=np.float32
+ )
+
+ com_fitted_x = self._asnumpy(com_fitted_x)
+ com_fitted_y = self._asnumpy(com_fitted_y)
+
+ counter = 0
+ for rx in range(diffraction_intensities.shape[0]):
+ for ry in range(diffraction_intensities.shape[1]):
+ if positions_mask is not None:
+ if not positions_mask[rx, ry]:
+ continue
+ intensities = get_shifted_ar(
+ diffraction_intensities[rx, ry],
+ -com_fitted_x[rx, ry],
+ -com_fitted_y[rx, ry],
+ bilinear=True,
+ device="cpu",
+ )
+
+ if crop_patterns:
+ intensities = intensities[crop_mask].reshape(
+ region_of_interest_shape
+ )
+
+ mean_intensity += np.sum(intensities)
+ amplitudes[counter] = np.sqrt(np.maximum(intensities, 0))
+ counter += 1
+
+ amplitudes = xp.asarray(amplitudes)
+ mean_intensity /= amplitudes.shape[0]
+
+ return amplitudes, mean_intensity
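+
+ # Note on the normalization above: the reconstruction compares amplitudes
+ # rather than intensities, so each centered pattern is clipped at zero and
+ # square-rooted; mean_intensity, the mean total intensity per pattern, is
+ # used downstream (e.g. in _return_self_consistency_errors) to normalize
+ # reconstruction errors.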
+
+ def show_complex_CoM(
+ self,
+ com=None,
+ cbar=True,
+ scalebar=True,
+ pixelsize=None,
+ pixelunits=None,
+ **kwargs,
+ ):
+ """
+ Plot complex-valued CoM image
+
+ Parameters
+ ----------
+
+ com: (CoM_x, CoM_y) tuple, optional
+ If None is specified, uses (self.com_x, self.com_y) instead
+ cbar: bool, optional
+ if True, adds colorbar
+ scalebar: bool, optional
+ if True, adds scalebar to the plot
+ pixelunits: str, optional
+ units for scalebar, default is A
+ pixelsize: float, optional
+ default is scan sampling
+ """
+
+ if com is None:
+ com = (self.com_x, self.com_y)
+
+ if pixelsize is None:
+ pixelsize = self._scan_sampling[0]
+ if pixelunits is None:
+ pixelunits = self._scan_units[0]
+
+ figsize = kwargs.pop("figsize", (6, 6))
+ fig, ax = plt.subplots(figsize=figsize)
+
+ complex_com = com[0] + 1j * com[1]
+
+ show_complex(
+ complex_com,
+ cbar=cbar,
+ figax=(fig, ax),
+ scalebar=scalebar,
+ pixelsize=pixelsize,
+ pixelunits=pixelunits,
+ ticks=False,
+ **kwargs,
+ )
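+
+ # A hedged usage sketch: calling self.show_complex_CoM() with no arguments
+ # plots self.com_x + 1j * self.com_y as a single complex-valued image, using
+ # the scan sampling and scan units for the scalebar.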
+
+
+class PtychographicReconstruction(PhaseReconstruction, PtychographicConstraints):
+ """
+ Base ptychographic reconstruction class.
+ Inherits from PhaseReconstruction and PtychographicConstraints.
+ Defines various common functions and properties for subclasses to inherit.
+ """
+
+ def to_h5(self, group):
+ """
+ Wraps datasets and metadata in emdfile classes for writing,
+ notably the object and probe arrays.
+ """
+
+ asnumpy = self._asnumpy
+
+ # instantiation metadata
+ tf = AffineTransform(angle=-self._rotation_best_rad)
+ pos = self.positions
+
+ if pos.ndim == 2:
+ origin = np.mean(pos, axis=0)
+ else:
+ origin = np.mean(pos, axis=(0, 1))
+ scan_positions = tf(pos, origin)
+
+ vacuum_probe_intensity = (
+ asnumpy(self._vacuum_probe_intensity)
+ if self._vacuum_probe_intensity is not None
+ else None
+ )
+ metadata = {
+ "energy": self._energy,
+ "semiangle_cutoff": self._semiangle_cutoff,
+ "rolloff": self._rolloff,
+ "object_padding_px": self._object_padding_px,
+ "object_type": self._object_type,
+ "verbose": self._verbose,
+ "device": self._device,
+ "name": self.name,
+ "vacuum_probe_intensity": vacuum_probe_intensity,
+ "positions": scan_positions,
+ }
+
+ cls = self.__class__
+ class_specific_metadata = {}
+ for key in cls._class_specific_metadata:
+ class_specific_metadata[key[1:]] = getattr(self, key, None)
+
+ metadata |= class_specific_metadata
+
+ self.metadata = Metadata(
+ name="instantiation_metadata",
+ data=metadata,
+ )
+
+ # preprocessing metadata
+ self.metadata = Metadata(
+ name="preprocess_metadata",
+ data={
+ "rotation_angle_rad": self._rotation_best_rad,
+ "data_transpose": self._rotation_best_transpose,
+ "positions_px": asnumpy(self._positions_px),
+ "region_of_interest_shape": self._region_of_interest_shape,
+ "num_diffraction_patterns": self._num_diffraction_patterns,
+ "sampling": self.sampling,
+ "angular_sampling": self.angular_sampling,
+ "positions_mask": self._positions_mask,
+ },
+ )
+
+ # reconstruction metadata
+ is_stack = self._save_iterations and hasattr(self, "object_iterations")
+ if is_stack:
+ num_iterations = len(self.object_iterations)
+ iterations = list(range(0, num_iterations, self._save_iterations_frequency))
+ if num_iterations - 1 not in iterations:
+ iterations.append(num_iterations - 1)
+
+ error = [self.error_iterations[i] for i in iterations]
+ else:
+ error = getattr(self, "error", 0.0)
+
+ self.metadata = Metadata(
+ name="reconstruction_metadata",
+ data={
+ "reconstruction_error": error,
+ },
+ )
+
+ # aberrations metadata
+ self.metadata = Metadata(
+ name="aberrations_metadata",
+ data=self._polar_parameters,
+ )
+
+ # object
+ self._object_emd = Array(
+ name="reconstruction_object",
+ data=asnumpy(self._xp.asarray(self._object)),
+ )
+
+ # probe
+ self._probe_emd = Array(name="reconstruction_probe", data=asnumpy(self._probe))
+
+ if is_stack:
+ iterations_labels = [f"iteration_{i:03}" for i in iterations]
+
+ # object
+ object_iterations = [
+ np.asarray(self.object_iterations[i]) for i in iterations
+ ]
+ self._object_iterations_emd = Array(
+ name="reconstruction_object_iterations",
+ data=np.stack(object_iterations, axis=0),
+ slicelabels=iterations_labels,
+ )
+
+ # probe
+ probe_iterations = [self.probe_iterations[i] for i in iterations]
+ self._probe_iterations_emd = Array(
+ name="reconstruction_probe_iterations",
+ data=np.stack(probe_iterations, axis=0),
+ slicelabels=iterations_labels,
+ )
+
+ # exit_waves
+ if self._save_exit_waves:
+ self._exit_waves_emd = Array(
+ name="reconstruction_exit_waves",
+ data=asnumpy(self._xp.asarray(self._exit_waves)),
+ )
+
+ # datacube
+ if self._save_datacube:
+ self.metadata = self._datacube.calibration
+ Custom.to_h5(self, group)
+ else:
+ dc = self._datacube
+ self._datacube = None
+ Custom.to_h5(self, group)
+ self._datacube = dc
+
+ @classmethod
+ def _get_constructor_args(cls, group):
+ """
+ Returns a dictionary of arguments/values to pass
+ to the class' __init__ function
+ """
+ # Get data
+ dict_data = cls._get_emd_attr_data(cls, group)
+
+ # Get metadata dictionaries
+ instance_md = _read_metadata(group, "instantiation_metadata")
+ polar_params = _read_metadata(group, "aberrations_metadata")._params
+
+ # Fix calibrations bug
+ if "_datacube" in dict_data:
+ calibrations_dict = _read_metadata(group, "calibration")._params
+ cal = Calibration()
+ cal._params.update(calibrations_dict)
+ dc = dict_data["_datacube"]
+ dc.calibration = cal
+ else:
+ dc = None
+
+ obj = dict_data["_object_emd"].data
+ probe = dict_data["_probe_emd"].data
+
+ # Populate args and return
+ kwargs = {
+ "datacube": dc,
+ "initial_object_guess": np.asarray(obj),
+ "initial_probe_guess": np.asarray(probe),
+ "vacuum_probe_intensity": instance_md["vacuum_probe_intensity"],
+ "initial_scan_positions": instance_md["positions"],
+ "energy": instance_md["energy"],
+ "object_padding_px": instance_md["object_padding_px"],
+ "object_type": instance_md["object_type"],
+ "semiangle_cutoff": instance_md["semiangle_cutoff"],
+ "rolloff": instance_md["rolloff"],
+ "name": instance_md["name"],
+ "polar_parameters": polar_params,
+ "verbose": True, # for compatibility
+ "device": "cpu", # for compatibility
+ }
+
+ class_specific_kwargs = {}
+ for key in cls._class_specific_metadata:
+ class_specific_kwargs[key[1:]] = instance_md[key[1:]]
+
+ kwargs |= class_specific_kwargs
+
+ return kwargs
+
+ def _populate_instance(self, group):
+ """
+ Sets post-initialization properties, notably some preprocessing meta
+ """
+ xp = self._xp
+ asnumpy = self._asnumpy
+
+ # Preprocess metadata
+ preprocess_md = _read_metadata(group, "preprocess_metadata")
+ self._rotation_best_rad = preprocess_md["rotation_angle_rad"]
+ self._rotation_best_transpose = preprocess_md["data_transpose"]
+ self._positions_px = xp.asarray(preprocess_md["positions_px"])
+ self._angular_sampling = preprocess_md["angular_sampling"]
+ self._region_of_interest_shape = preprocess_md["region_of_interest_shape"]
+ self._num_diffraction_patterns = preprocess_md["num_diffraction_patterns"]
+ self._positions_mask = preprocess_md["positions_mask"]
+
+ # Reconstruction metadata
+ reconstruction_md = _read_metadata(group, "reconstruction_metadata")
+ error = reconstruction_md["reconstruction_error"]
+
+ # Data
+ dict_data = Custom._get_emd_attr_data(Custom, group)
+ if "_exit_waves_emd" in dict_data:
+ self._exit_waves = dict_data["_exit_waves_emd"].data
+ self._exit_waves = xp.asarray(self._exit_waves, dtype=xp.complex64)
+ else:
+ self._exit_waves = None
+
+ # Check if stack
+ if hasattr(error, "__len__"):
+ self.object_iterations = list(dict_data["_object_iterations_emd"].data)
+ self.probe_iterations = list(dict_data["_probe_iterations_emd"].data)
+ self.error_iterations = error
+ self.error = error[-1]
+ else:
+ self.error = error
+
+ # Slim preprocessing to enable visualize
+ self._positions_px_com = xp.mean(self._positions_px, axis=0)
+ self.object = asnumpy(self._object)
+ self.probe = self.probe_centered
+ self._preprocessed = True
+
+ def _set_polar_parameters(self, parameters: dict):
+ """
+ Set the probe aberrations dictionary.
+
+ Parameters
+ ----------
+ parameters: dict
+ Mapping from aberration symbols to their corresponding values.
+ """
+
+ for symbol, value in parameters.items():
+ if symbol in self._polar_parameters.keys():
+ self._polar_parameters[symbol] = value
+
+ elif symbol == "defocus":
+ self._polar_parameters[polar_aliases[symbol]] = -value
+
+ elif symbol in polar_aliases.keys():
+ self._polar_parameters[polar_aliases[symbol]] = value
+
+ else:
+ raise ValueError("{} not a recognized parameter".format(symbol))
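+
+ # A minimal usage sketch with hypothetical values, illustrating the alias
+ # handling and the defocus sign convention above:
+ #
+ # self._set_polar_parameters({"defocus": 100.0}) # stored as C10 = -100.0
+ # self._set_polar_parameters({"C30": 1.0e4}) # known symbol, stored as-is
+ # self._set_polar_parameters({"not_an_aberration": 1}) # raises ValueError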
+
+ def _calculate_scan_positions_in_pixels(
+ self, positions: np.ndarray, positions_mask
+ ):
+ """
+ Method to compute the initial guess of scan positions in pixels.
+
+ Parameters
+ ----------
+ positions: (J,2) np.ndarray or None
+ Input probe positions in Å.
+ If None, a raster scan using experimental parameters is constructed.
+ positions_mask: np.ndarray, optional
+ Boolean real-space mask selecting which datacube positions to include in the reconstruction
+
+ Returns
+ -------
+ positions_in_px: (J,2) np.ndarray
+ Initial guess of scan positions in pixels
+ """
+
+ grid_scan_shape = self._grid_scan_shape
+ rotation_angle = self._rotation_best_rad
+ step_sizes = self._scan_sampling
+
+ if positions is None:
+ if grid_scan_shape is not None:
+ nx, ny = grid_scan_shape
+
+ if step_sizes is not None:
+ sx, sy = step_sizes
+ x = np.arange(nx) * sx
+ y = np.arange(ny) * sy
+ else:
+ raise ValueError("step_sizes must be set when positions is None")
+ else:
+ raise ValueError("grid_scan_shape must be set when positions is None")
+
+ if self._rotation_best_transpose:
+ x = (x - np.ptp(x) / 2) / self.sampling[1]
+ y = (y - np.ptp(y) / 2) / self.sampling[0]
+ else:
+ x = (x - np.ptp(x) / 2) / self.sampling[0]
+ y = (y - np.ptp(y) / 2) / self.sampling[1]
+ x, y = np.meshgrid(x, y, indexing="ij")
+ if positions_mask is not None:
+ x = x[positions_mask]
+ y = y[positions_mask]
+ else:
+ positions -= np.mean(positions, axis=0)
+ x = positions[:, 0] / self.sampling[1]
+ y = positions[:, 1] / self.sampling[0]
+
+ if rotation_angle is not None:
+ x, y = x * np.cos(rotation_angle) + y * np.sin(rotation_angle), -x * np.sin(
+ rotation_angle
+ ) + y * np.cos(rotation_angle)
+
+ if self._rotation_best_transpose:
+ positions = np.array([y.ravel(), x.ravel()]).T
+ else:
+ positions = np.array([x.ravel(), y.ravel()]).T
+ positions -= np.min(positions, axis=0)
+
+ if self._object_padding_px is None:
+ float_padding = self._region_of_interest_shape / 2
+ self._object_padding_px = (float_padding, float_padding)
+ elif np.isscalar(self._object_padding_px[0]):
+ self._object_padding_px = (
+ (self._object_padding_px[0],) * 2,
+ (self._object_padding_px[1],) * 2,
+ )
+
+ positions[:, 0] += self._object_padding_px[0][0]
+ positions[:, 1] += self._object_padding_px[1][0]
+
+ return positions
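+
+ # Summary of the transform chain above: positions are centered, converted
+ # to pixels by dividing by the object sampling, rotated by the best-fit CoM
+ # rotation angle, optionally transposed, shifted to non-negative values,
+ # and finally offset by the object padding.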
+
+ def _sum_overlapping_patches_bincounts_base(self, patches: np.ndarray):
+ """
+ Base bincounts overlapping-patches sum function, operating on real-valued arrays.
+ Note this assumes the probe is corner-centered.
+
+ Parameters
+ ----------
+ patches: (Rx*Ry,Sx,Sy) np.ndarray
+ Patches to sum
+
+ Returns
+ -------
+ out_array: (Px,Py) np.ndarray
+ Summed array
+ """
+ xp = self._xp
+ x0 = xp.round(self._positions_px[:, 0]).astype("int")
+ y0 = xp.round(self._positions_px[:, 1]).astype("int")
+
+ roi_shape = self._region_of_interest_shape
+ x_ind = xp.fft.fftfreq(roi_shape[0], d=1 / roi_shape[0]).astype("int")
+ y_ind = xp.fft.fftfreq(roi_shape[1], d=1 / roi_shape[1]).astype("int")
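+ # xp.fft.fftfreq(n, d=1/n) gives the corner-centered integer offsets
+ # [0, 1, ..., n//2 - 1, -n//2, ..., -1] (for even n), matching the
+ # corner-centered probe convention noted in the docstring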
+
+ flat_weights = patches.ravel()
+ indices = (
+ (y0[:, None, None] + y_ind[None, None, :]) % self._object_shape[1]
+ ) + (
+ (x0[:, None, None] + x_ind[None, :, None]) % self._object_shape[0]
+ ) * self._object_shape[
+ 1
+ ]
+ counts = xp.bincount(
+ indices.ravel(), weights=flat_weights, minlength=np.prod(self._object_shape)
+ )
+ return xp.reshape(counts, self._object_shape)
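+
+ # The bincount call above is a vectorized scatter-add: for a hypothetical
+ # flat object array `obj_flat` it is equivalent to
+ # np.add.at(obj_flat, indices.ravel(), flat_weights),
+ # but xp.bincount is typically faster and behaves identically on numpy
+ # and cupy.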
+
+ def _sum_overlapping_patches_bincounts(self, patches: np.ndarray):
+ """
+ Sums overlapping patches into an object-shaped array using bincounts.
+ Calls _sum_overlapping_patches_bincounts_base on the real and imaginary parts.
+
+ Parameters
+ ----------
+ patches: (Rx*Ry,Sx,Sy) np.ndarray
+ Patches to sum
+
+ Returns
+ -------
+ out_array: (Px,Py) np.ndarray
+ Summed array
+ """
+
+ xp = self._xp
+ if xp.iscomplexobj(patches):
+ real = self._sum_overlapping_patches_bincounts_base(xp.real(patches))
+ imag = self._sum_overlapping_patches_bincounts_base(xp.imag(patches))
+ return real + 1.0j * imag
+ else:
+ return self._sum_overlapping_patches_bincounts_base(patches)
+
+ def _extract_vectorized_patch_indices(self):
+ """
+ Returns the vectorized row/col indices used for the overlap projection.
+ Note this assumes the probe is corner-centered.
+
+ Returns
+ -------
+ vectorized_patch_indices_row: np.ndarray
+ Row indices for probe patches inside object array
+ vectorized_patch_indices_col: np.ndarray
+ Column indices for probe patches inside object array
+ """
+ xp = self._xp
+ x0 = xp.round(self._positions_px[:, 0]).astype("int")
+ y0 = xp.round(self._positions_px[:, 1]).astype("int")
+
+ roi_shape = self._region_of_interest_shape
+ x_ind = xp.fft.fftfreq(roi_shape[0], d=1 / roi_shape[0]).astype("int")
+ y_ind = xp.fft.fftfreq(roi_shape[1], d=1 / roi_shape[1]).astype("int")
+
+ obj_shape = self._object_shape
+ vectorized_patch_indices_row = (
+ x0[:, None, None] + x_ind[None, :, None]
+ ) % obj_shape[0]
+ vectorized_patch_indices_col = (
+ y0[:, None, None] + y_ind[None, None, :]
+ ) % obj_shape[1]
+
+ return vectorized_patch_indices_row, vectorized_patch_indices_col
+
+ def _crop_rotate_object_fov(
+ self,
+ array,
+ padding=0,
+ ):
+ """
+ Crops and rotates the object to the FOV bounded by the current pixel positions.
+
+ Parameters
+ ----------
+ array: np.ndarray
+ Object array to crop and rotate. Only operates on numpy arrays for compatibility.
+ padding: int, optional
+ Optional padding outside pixel positions
+
+ Returns
+ -------
+ cropped_rotated_array: np.ndarray
+ Cropped and rotated object array
+ """
+
+ asnumpy = self._asnumpy
+ angle = (
+ self._rotation_best_rad
+ if self._rotation_best_transpose
+ else -self._rotation_best_rad
+ )
+
+ tf = AffineTransform(angle=angle)
+ rotated_points = tf(
+ asnumpy(self._positions_px), origin=asnumpy(self._positions_px_com), xp=np
+ )
+
+ min_x, min_y = np.floor(np.amin(rotated_points, axis=0) - padding).astype("int")
+ min_x = min_x if min_x > 0 else 0
+ min_y = min_y if min_y > 0 else 0
+ max_x, max_y = np.ceil(np.amax(rotated_points, axis=0) + padding).astype("int")
+
+ rotated_array = rotate(
+ asnumpy(array), np.rad2deg(-angle), reshape=False, axes=(-2, -1)
+ )[..., min_x:max_x, min_y:max_y]
+
+ if self._rotation_best_transpose:
+ rotated_array = rotated_array.swapaxes(-2, -1)
+
+ return rotated_array
+
+ def tune_angle_and_defocus(
+ self,
+ angle_guess=None,
+ defocus_guess=None,
+ transpose=None,
+ angle_step_size=1,
+ defocus_step_size=20,
+ num_angle_values=5,
+ num_defocus_values=5,
+ max_iter=5,
+ plot_reconstructions=True,
+ plot_convergence=True,
+ return_values=False,
+ **kwargs,
+ ):
+ """
+ Run reconstructions over a parameter space of angles and
+ defocus values. Should be run after preprocess step.
+
+ Parameters
+ ----------
+ angle_guess: float (degrees), optional
+ initial starting guess for rotation angle between real and reciprocal space
+ if None, uses current initialized values
+ defocus_guess: float (A), optional
+ initial starting guess for defocus
+ if None, uses current initialized values
+ transpose: bool, optional
+ whether diffraction intensities are transposed
+ if None, uses current initialized value
+ angle_step_size: float (degrees), optional
+ size of change of rotation angle between real and reciprocal space for
+ each step in parameter space
+ defocus_step_size: float (A), optional
+ size of change of defocus for each step in parameter space
+ num_angle_values: int, optional
+ number of values of angle to test, must be >= 1.
+ num_defocus_values: int, optional
+ number of values of defocus to test, must be >= 1
+ max_iter: int, optional
+ number of iterations to run in ptychographic reconstruction
+ plot_reconstructions: bool, optional
+ if True, plot phase of reconstructed objects
+ plot_convergence: bool, optional
+ if True, plots error for each iteration for each reconstruction.
+ return_values: bool, optional
+ if True, returns objects, convergence
+
+ Returns
+ -------
+ objects: list
+ reconstructed objects
+ convergence: np.ndarray
+ array of convergence values from reconstructions
+ """
+ # calculate angles and defocus values to test
+ if angle_guess is None:
+ angle_guess = self._rotation_best_rad * 180 / np.pi
+ if defocus_guess is None:
+ defocus_guess = -self._polar_parameters["C10"]
+ if transpose is None:
+ transpose = self._rotation_best_transpose
+
+ if num_angle_values == 1:
+ angle_step_size = 0
+
+ if num_defocus_values == 1:
+ defocus_step_size = 0
+
+ angles = np.linspace(
+ angle_guess - angle_step_size * (num_angle_values - 1) / 2,
+ angle_guess + angle_step_size * (num_angle_values - 1) / 2,
+ num_angle_values,
+ )
+
+ defocus_values = np.linspace(
+ defocus_guess - defocus_step_size * (num_defocus_values - 1) / 2,
+ defocus_guess + defocus_step_size * (num_defocus_values - 1) / 2,
+ num_defocus_values,
+ )
+
+ if return_values:
+ convergence = []
+ objects = []
+
+ # current initialized values
+ current_verbose = self._verbose
+ current_defocus = -self._polar_parameters["C10"]
+ current_rotation_deg = self._rotation_best_rad * 180 / np.pi
+ current_transpose = self._rotation_best_transpose
+
+ # Gridspec to plot on
+ if plot_reconstructions:
+ if plot_convergence:
+ spec = GridSpec(
+ ncols=num_defocus_values,
+ nrows=num_angle_values * 2,
+ height_ratios=[1, 1 / 4] * num_angle_values,
+ hspace=0.15,
+ wspace=0.35,
+ )
+ figsize = kwargs.get(
+ "figsize", (4 * num_defocus_values, 5 * num_angle_values)
+ )
+ else:
+ spec = GridSpec(
+ ncols=num_defocus_values,
+ nrows=num_angle_values,
+ hspace=0.15,
+ wspace=0.35,
+ )
+ figsize = kwargs.get(
+ "figsize", (4 * num_defocus_values, 4 * num_angle_values)
+ )
+
+ fig = plt.figure(figsize=figsize)
+
+ progress_bar = kwargs.pop("progress_bar", False)
+ # run loop and plot along the way
+ self._verbose = False
+ for flat_index, (angle, defocus) in enumerate(
+ tqdmnd(angles, defocus_values, desc="Tuning angle and defocus")
+ ):
+ self._polar_parameters["C10"] = -defocus
+ self._probe = None
+ self._object = None
+ self.preprocess(
+ force_com_rotation=angle,
+ force_com_transpose=transpose,
+ plot_center_of_mass=False,
+ plot_rotation=False,
+ plot_probe_overlaps=False,
+ )
+
+ self.reconstruct(
+ reset=True,
+ store_iterations=True,
+ max_iter=max_iter,
+ progress_bar=progress_bar,
+ **kwargs,
+ )
+
+ if plot_reconstructions:
+ row_index, col_index = np.unravel_index(
+ flat_index, (num_angle_values, num_defocus_values)
+ )
+
+ if plot_convergence:
+ object_ax = fig.add_subplot(spec[row_index * 2, col_index])
+ convergence_ax = fig.add_subplot(spec[row_index * 2 + 1, col_index])
+ self._visualize_last_iteration_figax(
+ fig,
+ object_ax=object_ax,
+ convergence_ax=convergence_ax,
+ cbar=True,
+ )
+ convergence_ax.yaxis.tick_right()
+ else:
+ object_ax = fig.add_subplot(spec[row_index, col_index])
+ self._visualize_last_iteration_figax(
+ fig,
+ object_ax=object_ax,
+ convergence_ax=None,
+ cbar=True,
+ )
+
+ object_ax.set_title(
+ f" angle = {angle:.1f} °, defocus = {defocus:.1f} A \n error = {self.error:.3e}"
+ )
+ object_ax.set_xticks([])
+ object_ax.set_yticks([])
+
+ if return_values:
+ objects.append(self.object)
+ convergence.append(self.error_iterations.copy())
+
+ # initialize back to pre-tuning values
+ self._polar_parameters["C10"] = -current_defocus
+ self._probe = None
+ self._object = None
+ self.preprocess(
+ force_com_rotation=current_rotation_deg,
+ force_com_transpose=current_transpose,
+ plot_center_of_mass=False,
+ plot_rotation=False,
+ plot_probe_overlaps=False,
+ )
+ self._verbose = current_verbose
+
+ if plot_reconstructions:
+ spec.tight_layout(fig)
+
+ if return_values:
+ return objects, convergence
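+
+ # A hedged usage sketch with hypothetical values: a 5 x 5 grid search
+ # centered on the current guesses,
+ #
+ # ptycho.tune_angle_and_defocus(
+ # angle_step_size=2, defocus_step_size=50,
+ # num_angle_values=5, num_defocus_values=5, max_iter=5,
+ # )
+ #
+ # runs and plots 25 short reconstructions to bracket the best
+ # (angle, defocus) pair before a full-length reconstruction.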
+
+ def _position_correction(
+ self,
+ relevant_object,
+ relevant_probes,
+ relevant_overlap,
+ relevant_amplitudes,
+ current_positions,
+ positions_step_size,
+ constrain_position_distance,
+ ):
+ """
+ Position correction using estimated intensity gradient.
+
+ Parameters
+ --------
+ relevant_object: np.ndarray
+ Current object estimate
+ relevant_probes: np.ndarray
+ Fractionally-shifted probes
+ relevant_overlap: np.ndarray
+ object * probe overlap
+ relevant_amplitudes: np.ndarray
+ Measured amplitudes
+ current_positions: np.ndarray
+ Current positions estimate
+ positions_step_size: float
+ Positions step size
+ constrain_position_distance: float
+ Distance to constrain position correction within original
+ field of view in A
+
+ Returns
+ --------
+ updated_positions: np.ndarray
+ Updated positions estimate
+ """
+
+ xp = self._xp
+
+ if self._object_type == "potential":
+ complex_object = xp.exp(1j * relevant_object)
+ else:
+ complex_object = relevant_object
+
+ obj_rolled_x_patches = complex_object[
+ (self._vectorized_patch_indices_row + 1) % self._object_shape[0],
+ self._vectorized_patch_indices_col,
+ ]
+ obj_rolled_y_patches = complex_object[
+ self._vectorized_patch_indices_row,
+ (self._vectorized_patch_indices_col + 1) % self._object_shape[1],
+ ]
+
+ overlap_fft = xp.fft.fft2(relevant_overlap)
+
+ exit_waves_dx_fft = overlap_fft - xp.fft.fft2(
+ obj_rolled_x_patches * relevant_probes
+ )
+ exit_waves_dy_fft = overlap_fft - xp.fft.fft2(
+ obj_rolled_y_patches * relevant_probes
+ )
+
+ overlap_fft_conj = xp.conj(overlap_fft)
+ estimated_intensity = xp.abs(overlap_fft) ** 2
+ measured_intensity = relevant_amplitudes**2
+
+ flat_shape = (relevant_overlap.shape[0], -1)
+ difference_intensity = (measured_intensity - estimated_intensity).reshape(
+ flat_shape
+ )
+
+ partial_intensity_dx = 2 * xp.real(
+ exit_waves_dx_fft * overlap_fft_conj
+ ).reshape(flat_shape)
+ partial_intensity_dy = 2 * xp.real(
+ exit_waves_dy_fft * overlap_fft_conj
+ ).reshape(flat_shape)
+
+ coefficients_matrix = xp.dstack((partial_intensity_dx, partial_intensity_dy))
+
+ # positions_update = xp.einsum(
+ # "idk,ik->id", xp.linalg.pinv(coefficients_matrix), difference_intensity
+ # )
+
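+ # Solve the per-pattern 2-parameter least-squares problem via the normal
+ # equations, dr = (A^T A)^-1 A^T (I_meas - I_est); equivalent to the pinv
+ # formulation commented out above, but only requires inverting 2x2 matrices.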
+ coefficients_matrix_T = coefficients_matrix.conj().swapaxes(-1, -2)
+ positions_update = (
+ xp.linalg.inv(coefficients_matrix_T @ coefficients_matrix)
+ @ coefficients_matrix_T
+ @ difference_intensity[..., None]
+ )
+
+ if constrain_position_distance is not None:
+ constrain_position_distance /= xp.sqrt(
+ self.sampling[0] ** 2 + self.sampling[1] ** 2
+ )
+ x1 = (current_positions - positions_step_size * positions_update[..., 0])[
+ :, 0
+ ]
+ y1 = (current_positions - positions_step_size * positions_update[..., 0])[
+ :, 1
+ ]
+ x0 = self._positions_px_initial[:, 0]
+ y0 = self._positions_px_initial[:, 1]
+ if self._rotation_best_transpose:
+ x0, y0 = xp.array([y0, x0])
+ x1, y1 = xp.array([y1, x1])
+
+ if self._rotation_best_rad is not None:
+ rotation_angle = self._rotation_best_rad
+ x0, y0 = x0 * xp.cos(-rotation_angle) + y0 * xp.sin(
+ -rotation_angle
+ ), -x0 * xp.sin(-rotation_angle) + y0 * xp.cos(-rotation_angle)
+ x1, y1 = x1 * xp.cos(-rotation_angle) + y1 * xp.sin(
+ -rotation_angle
+ ), -x1 * xp.sin(-rotation_angle) + y1 * xp.cos(-rotation_angle)
+
+ outlier_ind = (x1 > (xp.max(x0) + constrain_position_distance)) + (
+ x1 < (xp.min(x0) - constrain_position_distance)
+ ) + (y1 > (xp.max(y0) + constrain_position_distance)) + (
+ y1 < (xp.min(y0) - constrain_position_distance)
+ ) > 0
+
+ positions_update[..., 0][outlier_ind] = 0
+
+ current_positions -= positions_step_size * positions_update[..., 0]
+ return current_positions
+
+ def plot_position_correction(
+ self,
+ scale_arrows=1,
+ plot_arrow_freq=1,
+ verbose=True,
+ **kwargs,
+ ):
+ """
+ Function to plot changes to probe positions during ptychographic reconstruction
+
+ Parameters
+ ----------
+ scale_arrows: float, optional
+ scaling factor to be applied on vectors prior to plt.quiver call
+ plot_arrow_freq: int, optional
+ frequency of plotted arrows, i.e. only every nth position arrow is drawn
+ verbose: bool, optional
+ if True, prints the AffineTransform if positions have been updated
+ """
+ if verbose:
+ if hasattr(self, "_tf"):
+ print(self._tf)
+
+ asnumpy = self._asnumpy
+
+ extent = [
+ 0,
+ self.sampling[1] * self._object_shape[1],
+ self.sampling[0] * self._object_shape[0],
+ 0,
+ ]
+
+ initial_pos = asnumpy(self._positions_initial)
+ pos = self.positions
+
+ figsize = kwargs.pop("figsize", (6, 6))
+ color = kwargs.pop("color", (1, 0, 0, 1))
+
+ fig, ax = plt.subplots(figsize=figsize)
+ ax.quiver(
+ initial_pos[::plot_arrow_freq, 1],
+ initial_pos[::plot_arrow_freq, 0],
+ (pos[::plot_arrow_freq, 1] - initial_pos[::plot_arrow_freq, 1])
+ * scale_arrows,
+ (pos[::plot_arrow_freq, 0] - initial_pos[::plot_arrow_freq, 0])
+ * scale_arrows,
+ scale_units="xy",
+ scale=1,
+ color=color,
+ **kwargs,
+ )
+
+ ax.set_ylabel("x [A]")
+ ax.set_xlabel("y [A]")
+ ax.set_xlim((extent[0], extent[1]))
+ ax.set_ylim((extent[2], extent[3]))
+ ax.set_aspect("equal")
+ ax.set_title("Probe positions correction")
+
+ def _return_fourier_probe(
+ self,
+ probe=None,
+ ):
+ """
+ Returns the complex Fourier probe shifted to the center of the array,
+ from a corner-centered complex real-space probe
+
+ Parameters
+ ----------
+ probe: complex array, optional
+ if None is specified, uses self._probe
+
+ Returns
+ -------
+ fourier_probe: np.ndarray
+ Fourier-transformed and center-shifted probe.
+ """
+ xp = self._xp
+
+ if probe is None:
+ probe = self._probe
+ else:
+ probe = xp.asarray(probe, dtype=xp.complex64)
+
+ return xp.fft.fftshift(xp.fft.fft2(probe), axes=(-2, -1))
+
+ def _return_fourier_probe_from_centered_probe(
+ self,
+ probe=None,
+ ):
+ """
+ Returns the complex Fourier probe shifted to the center of the array,
+ from a centered complex real-space probe
+
+ Parameters
+ ----------
+ probe: complex array, optional
+ if None is specified, uses self._probe
+
+ Returns
+ -------
+ fourier_probe: np.ndarray
+ Fourier-transformed and center-shifted probe.
+ """
+ xp = self._xp
+ return self._return_fourier_probe(xp.fft.ifftshift(probe, axes=(-2, -1)))
+
+ def _return_centered_probe(
+ self,
+ probe=None,
+ ):
+ """
+ Returns the complex probe centered in the middle of the array.
+
+ Parameters
+ ----------
+ probe: complex array, optional
+ if None is specified, uses self._probe
+
+ Returns
+ -------
+ centered_probe: np.ndarray
+ Center-shifted probe.
+ """
+ xp = self._xp
+
+ if probe is None:
+ probe = self._probe
+ else:
+ probe = xp.asarray(probe, dtype=xp.complex64)
+
+ return xp.fft.fftshift(probe, axes=(-2, -1))
+
+ def _return_object_fft(
+ self,
+ obj=None,
+ ):
+ """
+ Returns the absolute value of the object FFT shifted to the center of the array
+
+ Parameters
+ ----------
+ obj: array, optional
+ if None is specified, uses self._object
+
+ Returns
+ -------
+ object_fft_amplitude: np.ndarray
+ Amplitude of Fourier-transformed and center-shifted obj.
+ """
+ asnumpy = self._asnumpy
+
+ if obj is None:
+ obj = self._object
+
+ obj = self._crop_rotate_object_fov(asnumpy(obj))
+ return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj))))
+
+ def _return_self_consistency_errors(
+ self,
+ max_batch_size=None,
+ ):
+ """Compute the self-consistency errors for each probe position"""
+
+ xp = self._xp
+ asnumpy = self._asnumpy
+
+ # Batch-size
+ if max_batch_size is None:
+ max_batch_size = self._num_diffraction_patterns
+
+ # Re-initialize fractional positions and vector patches
+ errors = np.array([])
+ positions_px = self._positions_px.copy()
+
+ for start, end in generate_batches(
+ self._num_diffraction_patterns, max_batch=max_batch_size
+ ):
+ # batch indices
+ self._positions_px = positions_px[start:end]
+ self._positions_px_fractional = self._positions_px - xp.round(
+ self._positions_px
+ )
+ (
+ self._vectorized_patch_indices_row,
+ self._vectorized_patch_indices_col,
+ ) = self._extract_vectorized_patch_indices()
+ amplitudes = self._amplitudes[start:end]
+
+ # Overlaps
+ _, _, overlap = self._overlap_projection(self._object, self._probe)
+ fourier_overlap = xp.fft.fft2(overlap)
+
+ # Normalized mean-squared errors
+ batch_errors = xp.sum(
+ xp.abs(amplitudes - xp.abs(fourier_overlap)) ** 2, axis=(-2, -1)
+ )
+ errors = np.hstack((errors, batch_errors))
+
+ self._positions_px = positions_px.copy()
+ errors /= self._mean_diffraction_intensity
+
+ return asnumpy(errors)
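+
+ # The errors returned above are, per probe position,
+ # sum_q |sqrt(I_meas) - |FFT{object * probe}||^2 / mean_diffraction_intensity,
+ # i.e. an amplitude-domain normalized squared error.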
+
+ def _return_projected_cropped_potential(
+ self,
+ ):
+ """Utility function to accommodate multiple classes"""
+ if self._object_type == "complex":
+ projected_cropped_potential = np.angle(self.object_cropped)
+ else:
+ projected_cropped_potential = self.object_cropped
+
+ return projected_cropped_potential
+
+ def show_uncertainty_visualization(
+ self,
+ errors=None,
+ max_batch_size=None,
+ projected_cropped_potential=None,
+ kde_sigma=None,
+ plot_histogram=True,
+ plot_contours=False,
+ **kwargs,
+ ):
+ """Plot uncertainty visualization using self-consistency errors"""
+
+ if errors is None:
+ errors = self._return_self_consistency_errors(max_batch_size=max_batch_size)
+
+ if projected_cropped_potential is None:
+ projected_cropped_potential = self._return_projected_cropped_potential()
+
+ if kde_sigma is None:
+ kde_sigma = 0.5 * self._scan_sampling[0] / self.sampling[0]
+
+ xp = self._xp
+ asnumpy = self._asnumpy
+ gaussian_filter = self._gaussian_filter
+
+ ## Kernel Density Estimation
+
+ # rotated basis
+ angle = (
+ self._rotation_best_rad
+ if self._rotation_best_transpose
+ else -self._rotation_best_rad
+ )
+
+ tf = AffineTransform(angle=angle)
+ rotated_points = tf(self._positions_px, origin=self._positions_px_com, xp=xp)
+
+ padding = xp.min(rotated_points, axis=0).astype("int")
+
+ # bilinear sampling
+ pixel_output = np.array(projected_cropped_potential.shape) + asnumpy(
+ 2 * padding
+ )
+ pixel_size = pixel_output.prod()
+
+ xa = rotated_points[:, 0]
+ ya = rotated_points[:, 1]
+
+ # bilinear sampling
+ xF = xp.floor(xa).astype("int")
+ yF = xp.floor(ya).astype("int")
+ dx = xa - xF
+ dy = ya - yF
+
+ # resampling
+ inds_1D = xp.ravel_multi_index(
+ xp.hstack(
+ [
+ [xF, yF],
+ [xF + 1, yF],
+ [xF, yF + 1],
+ [xF + 1, yF + 1],
+ ]
+ ),
+ pixel_output,
+ mode=["wrap", "wrap"],
+ )
+
+ weights = xp.hstack(
+ (
+ (1 - dx) * (1 - dy),
+ (dx) * (1 - dy),
+ (1 - dx) * (dy),
+ (dx) * (dy),
+ )
+ )
+
+ pix_count = xp.reshape(
+ xp.bincount(inds_1D, weights=weights, minlength=pixel_size), pixel_output
+ )
+
+ pix_output = xp.reshape(
+ xp.bincount(
+ inds_1D,
+ weights=weights * xp.tile(xp.asarray(errors), 4),
+ minlength=pixel_size,
+ ),
+ pixel_output,
+ )
+
+ # kernel density estimate
+ pix_count = gaussian_filter(pix_count, kde_sigma, mode="wrap")
+ pix_count[pix_count == 0.0] = np.inf
+ pix_output = gaussian_filter(pix_output, kde_sigma, mode="wrap")
+ pix_output /= pix_count
+ pix_output = pix_output[padding[0] : -padding[0], padding[1] : -padding[1]]
+        pix_output, _, _ = return_scaled_histogram_ordering(
+            asnumpy(pix_output), normalize=True
+        )
+
+ ## Visualization
+ if plot_histogram:
+ spec = GridSpec(
+ ncols=1,
+ nrows=2,
+ height_ratios=[1, 4],
+ hspace=0.15,
+ )
+ auto_figsize = (4, 5.25)
+ else:
+ spec = GridSpec(
+ ncols=1,
+ nrows=1,
+ )
+ auto_figsize = (4, 4)
+
+ figsize = kwargs.pop("figsize", auto_figsize)
+
+ fig = plt.figure(figsize=figsize)
+
+ if plot_histogram:
+ ax_hist = fig.add_subplot(spec[0])
+
+ counts, bins = np.histogram(errors, bins=50)
+ ax_hist.hist(bins[:-1], bins, weights=counts, color="#5ac8c8", alpha=0.5)
+ ax_hist.set_ylabel("Counts")
+ ax_hist.set_xlabel("Normalized Squared Error")
+
+ ax = fig.add_subplot(spec[-1])
+
+ cmap = kwargs.pop("cmap", "magma")
+ vmin = kwargs.pop("vmin", None)
+ vmax = kwargs.pop("vmax", None)
+
+ projected_cropped_potential, vmin, vmax = return_scaled_histogram_ordering(
+ projected_cropped_potential,
+ vmin=vmin,
+ vmax=vmax,
+ )
+
+ extent = [
+ 0,
+ self.sampling[1] * projected_cropped_potential.shape[1],
+ self.sampling[0] * projected_cropped_potential.shape[0],
+ 0,
+ ]
+
+ ax.imshow(
+ projected_cropped_potential,
+ vmin=vmin,
+ vmax=vmax,
+ extent=extent,
+ alpha=1 - pix_output,
+ cmap=cmap,
+ **kwargs,
+ )
+
+ if plot_contours:
+ aligned_points = asnumpy(rotated_points - padding)
+ aligned_points[:, 0] *= self.sampling[0]
+ aligned_points[:, 1] *= self.sampling[1]
+
+ ax.tricontour(
+ aligned_points[:, 1],
+ aligned_points[:, 0],
+ errors,
+ colors="grey",
+ levels=5,
+ # linestyles='dashed',
+ linewidths=0.5,
+ )
+
+ ax.set_ylabel("x [A]")
+ ax.set_xlabel("y [A]")
+ ax.set_xlim((extent[0], extent[1]))
+ ax.set_ylim((extent[2], extent[3]))
+ ax.xaxis.set_ticks_position("bottom")
+
+ spec.tight_layout(fig)
+
+ def show_fourier_probe(
+ self,
+ probe=None,
+ cbar=True,
+ scalebar=True,
+ pixelsize=None,
+ pixelunits=None,
+ **kwargs,
+ ):
+ """
+        Plot the probe in Fourier space.
+
+ Parameters
+ ----------
+ probe: complex array, optional
+ if None is specified, uses the `probe_fourier` property
+ cbar: bool, optional
+ if True, adds colorbar
+ scalebar: bool, optional
+ if True, adds scalebar to probe
+        pixelsize: float, optional
+            default is probe reciprocal sampling
+        pixelunits: str, optional
+            units for scalebar, default is A^-1
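+
+        Examples
+        --------
+        A minimal sketch, assuming a reconstructed instance named `ptycho`:
+
+        >>> ptycho.show_fourier_probe(cbar=True, scalebar=True)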
+ """
+ asnumpy = self._asnumpy
+
+ if probe is None:
+ probe = self.probe_fourier
+ else:
+ probe = asnumpy(self._return_fourier_probe(probe))
+
+ if pixelsize is None:
+ pixelsize = self._reciprocal_sampling[1]
+ if pixelunits is None:
+ pixelunits = r"$\AA^{-1}$"
+
+ figsize = kwargs.pop("figsize", (6, 6))
+ chroma_boost = kwargs.pop("chroma_boost", 2)
+
+ fig, ax = plt.subplots(figsize=figsize)
+ show_complex(
+ probe,
+ cbar=cbar,
+ figax=(fig, ax),
+ scalebar=scalebar,
+ pixelsize=pixelsize,
+ pixelunits=pixelunits,
+ ticks=False,
+ chroma_boost=chroma_boost,
+ **kwargs,
+ )
+
+ def show_object_fft(self, obj=None, **kwargs):
+ """
+ Plot FFT of reconstructed object
+
+ Parameters
+ ----------
+ obj: complex array, optional
+ if None is specified, uses the `object_fft` property
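+
+        Examples
+        --------
+        A minimal sketch, assuming a reconstructed instance named `ptycho`:
+
+        >>> ptycho.show_object_fft()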
+ """
+ if obj is None:
+ object_fft = self.object_fft
+ else:
+ object_fft = self._return_object_fft(obj)
+
+ figsize = kwargs.pop("figsize", (6, 6))
+ cmap = kwargs.pop("cmap", "magma")
+
+ pixelsize = 1 / (object_fft.shape[1] * self.sampling[1])
+ show(
+ object_fft,
+ figsize=figsize,
+ cmap=cmap,
+ scalebar=True,
+ pixelsize=pixelsize,
+ ticks=False,
+ pixelunits=r"$\AA^{-1}$",
+ **kwargs,
+ )
+
+ @property
+ def probe_fourier(self):
+ """Current probe estimate in Fourier space"""
+ if not hasattr(self, "_probe"):
+ return None
+
+ asnumpy = self._asnumpy
+ return asnumpy(self._return_fourier_probe(self._probe))
+
+ @property
+ def probe_centered(self):
+ """Current probe estimate shifted to the center"""
+ if not hasattr(self, "_probe"):
+ return None
+
+ asnumpy = self._asnumpy
+ return asnumpy(self._return_centered_probe(self._probe))
+
+ @property
+ def object_fft(self):
+ """Fourier transform of current object estimate"""
+
+ if not hasattr(self, "_object"):
+ return None
+
+ return self._return_object_fft(self._object)
+
+ @property
+ def angular_sampling(self):
+ """Angular sampling [mrad]"""
+ return getattr(self, "_angular_sampling", None)
+
+ @property
+ def sampling(self):
+ """Sampling [Å]"""
+
+ if self.angular_sampling is None:
+ return None
+
+ return tuple(
+ electron_wavelength_angstrom(self._energy) * 1e3 / dk / n
+ for dk, n in zip(self.angular_sampling, self._region_of_interest_shape)
+ )
+
+ @property
+ def positions(self):
+ """Probe positions [A]"""
+
+ if self.angular_sampling is None:
+ return None
+
+ asnumpy = self._asnumpy
+
+ positions = self._positions_px.copy()
+ positions[:, 0] *= self.sampling[0]
+ positions[:, 1] *= self.sampling[1]
+
+ return asnumpy(positions)
+
+ @property
+ def object_cropped(self):
+ """Cropped and rotated object"""
+
+ return self._crop_rotate_object_fov(self._object)
diff --git a/py4DSTEM/process/phase/iterative_dpc.py b/py4DSTEM/process/phase/iterative_dpc.py
new file mode 100644
index 000000000..b390ce46d
--- /dev/null
+++ b/py4DSTEM/process/phase/iterative_dpc.py
@@ -0,0 +1,1039 @@
+"""
+Module for reconstructing phase objects from 4DSTEM datasets using iterative methods,
+namely DPC.
+"""
+
+import warnings
+from typing import Sequence, Tuple, Union
+
+import matplotlib.pyplot as plt
+import numpy as np
+from matplotlib.gridspec import GridSpec
+from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable
+
+try:
+ import cupy as cp
+except ModuleNotFoundError:
+ cp = np
+
+from emdfile import Array, Custom, Metadata, _read_metadata, tqdmnd
+from py4DSTEM.data import Calibration
+from py4DSTEM.datacube import DataCube
+from py4DSTEM.process.phase.iterative_base_class import PhaseReconstruction
+
+warnings.simplefilter(action="always", category=UserWarning)
+
+
+class DPCReconstruction(PhaseReconstruction):
+ """
+    Iterative Differential Phase Contrast Reconstruction Class.
+
+ Diffraction intensities dimensions : (Rx,Ry,Qx,Qy)
+ Reconstructed phase object dimensions : (Rx,Ry)
+
+ Parameters
+ ----------
+ datacube: DataCube
+ Input 4D diffraction pattern intensities
+ initial_object_guess: np.ndarray, optional
+ Cropped initial guess of dpc phase
+ energy: float, optional
+ The electron energy of the wave functions in eV
+ verbose: bool, optional
+ If True, class methods will inherit this and print additional information
+ device: str, optional
+        Device the calculation will be performed on. Must be 'cpu' or 'gpu'
+ name: str, optional
+ Class name
+ """
+
+ def __init__(
+ self,
+ datacube: DataCube = None,
+ initial_object_guess: np.ndarray = None,
+ energy: float = None,
+ verbose: bool = True,
+ device: str = "cpu",
+ name: str = "dpc_reconstruction",
+ ):
+ Custom.__init__(self, name=name)
+
+ if device == "cpu":
+ self._xp = np
+ self._asnumpy = np.asarray
+ from scipy.ndimage import gaussian_filter
+
+ self._gaussian_filter = gaussian_filter
+ elif device == "gpu":
+ self._xp = cp
+ self._asnumpy = cp.asnumpy
+ from cupyx.scipy.ndimage import gaussian_filter
+
+ self._gaussian_filter = gaussian_filter
+ else:
+ raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}")
+
+ self.set_save_defaults()
+
+ # Data
+ self._datacube = datacube
+ self._object_phase = initial_object_guess
+
+ # Metadata
+ self._energy = energy
+ self._verbose = verbose
+ self._device = device
+ self._preprocessed = False
+
+ def to_h5(self, group):
+ """
+ Wraps datasets and metadata to write in emdfile classes,
+ notably: the object phase array.
+ """
+
+ # instantiation metadata
+ self.metadata = Metadata(
+ name="instantiation_metadata",
+ data={
+ "energy": self._energy,
+ "verbose": self._verbose,
+ "device": self._device,
+ "name": self.name,
+ },
+ )
+
+ # preprocessing metadata
+ self.metadata = Metadata(
+ name="preprocess_metadata",
+ data={
+ "rotation_angle_rad": self._rotation_best_rad,
+ "data_transpose": self._rotation_best_transpose,
+ "sampling": self.sampling,
+ },
+ )
+
+ # reconstruction metadata
+ is_stack = self._save_iterations and hasattr(self, "object_phase_iterations")
+ if is_stack:
+ num_iterations = len(self.object_phase_iterations)
+ iterations = list(range(0, num_iterations, self._save_iterations_frequency))
+ if num_iterations - 1 not in iterations:
+ iterations.append(num_iterations - 1)
+
+ error = [self.error_iterations[i] for i in iterations]
+ else:
+ error = self.error
+
+ self.metadata = Metadata(
+ name="reconstruction_metadata",
+ data={
+ "reconstruction_error": error,
+ "final_step_size": self._step_size,
+ },
+ )
+
+ if is_stack:
+ iterations_labels = [f"iteration_{i:03}" for i in iterations]
+
+ # object
+ object_iterations = [
+ np.asarray(self.object_phase_iterations[i]) for i in iterations
+ ]
+ self._object_emd = Array(
+ name="reconstruction_object",
+ data=np.stack(object_iterations, axis=0),
+ slicelabels=iterations_labels,
+ )
+
+ else:
+ # object
+ self._object_emd = Array(
+ name="reconstruction_object",
+ data=self._asnumpy(self._object_phase),
+ )
+
+ # datacube
+ if self._save_datacube:
+ self.metadata = self._datacube.calibration
+ Custom.to_h5(self, group)
+ else:
+ dc = self._datacube
+ self._datacube = None
+ Custom.to_h5(self, group)
+ self._datacube = dc
+
+ @classmethod
+ def _get_constructor_args(cls, group):
+ """
+ Returns a dictionary of arguments/values to pass
+ to the class' __init__ function
+ """
+ # Get data
+ dict_data = cls._get_emd_attr_data(cls, group)
+
+ # Get metadata dictionaries
+ instance_md = _read_metadata(group, "instantiation_metadata")
+
+ # Fix calibrations bug
+ if "_datacube" in dict_data:
+ calibrations_dict = _read_metadata(group, "calibration")._params
+ cal = Calibration()
+ cal._params.update(calibrations_dict)
+ dc = dict_data["_datacube"]
+ dc.calibration = cal
+ else:
+ dc = None
+
+ # Check if stack
+ if dict_data["_object_emd"].is_stack:
+ obj = dict_data["_object_emd"][-1].data
+ else:
+ obj = dict_data["_object_emd"].data
+
+ # Populate args and return
+ kwargs = {
+ "datacube": dc,
+ "initial_object_guess": np.asarray(obj),
+ "energy": instance_md["energy"],
+ "name": instance_md["name"],
+ "verbose": True, # for compatibility
+ "device": "cpu", # for compatibility
+ }
+
+ return kwargs
+
+ def _populate_instance(self, group):
+ """
+ Sets post-initialization properties, notably some preprocessing meta
+ optional; during read, this method is run after object instantiation.
+ """
+ # Preprocess metadata
+ preprocess_md = _read_metadata(group, "preprocess_metadata")
+ self._rotation_best_rad = preprocess_md["rotation_angle_rad"]
+ self._rotation_best_transpose = preprocess_md["data_transpose"]
+ self._preprocessed = False
+
+ # Reconstruction metadata
+ reconstruction_md = _read_metadata(group, "reconstruction_metadata")
+ error = reconstruction_md["reconstruction_error"]
+
+ # Data
+ dict_data = Custom._get_emd_attr_data(Custom, group)
+
+ # Check if stack
+ if hasattr(error, "__len__"):
+ self.object_phase_iterations = list(dict_data["_object_emd"].data)
+ self.error_iterations = error
+ self.error = error[-1]
+ else:
+ self.error = error
+
+ self._step_size = reconstruction_md["final_step_size"]
+
+ def preprocess(
+ self,
+ dp_mask: np.ndarray = None,
+ padding_factor: float = 2,
+ rotation_angles_deg: np.ndarray = np.arange(-89.0, 90.0, 1.0),
+ maximize_divergence: bool = False,
+ fit_function: str = "plane",
+ force_com_rotation: float = None,
+ force_com_transpose: bool = None,
+ force_com_shifts: Union[Sequence[np.ndarray], Sequence[float]] = None,
+ force_com_measured: Sequence[np.ndarray] = None,
+ plot_center_of_mass: str = "default",
+ plot_rotation: bool = True,
+ **kwargs,
+ ):
+ """
+ DPC preprocessing step.
+ Calls the base class methods:
+
+ _extract_intensities_and_calibrations_from_datacube(),
+ _calculate_intensities_center_of_mass(), and
+ _solve_for_center_of_mass_relative_rotation()
+
+ Parameters
+ ----------
+ dp_mask: ndarray, optional
+ Mask for datacube intensities (Qx,Qy)
+ padding_factor: float, optional
+ Factor to pad object by to reduce periodic artifacts
+        rotation_angles_deg: np.ndarray, optional
+ Array of angles in degrees to perform curl minimization over
+ maximize_divergence: bool, optional
+ If True, the divergence of the CoM gradient vector field is maximized
+ fit_function: str, optional
+ 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two'
+        force_com_rotation: float (degrees), optional
+            Force relative rotation angle between real and reciprocal space
+        force_com_transpose: bool, optional
+ Force whether diffraction intensities need to be transposed.
+ force_com_shifts: tuple of ndarrays (CoMx, CoMy)
+ Force CoM fitted shifts
+ force_com_measured: tuple of ndarrays (CoMx measured, CoMy measured)
+ Force CoM measured shifts
+ plot_center_of_mass: str, optional
+ If 'default', the corrected CoM arrays will be displayed
+ If 'all', the computed and fitted CoM arrays will be displayed
+ plot_rotation: bool, optional
+ If True, the CoM curl minimization search result will be displayed
+
+ Returns
+ --------
+ self: DPCReconstruction
+ Self to accommodate chaining
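+
+        Examples
+        --------
+        A minimal sketch, assuming a 4D-STEM `datacube` and a 300 keV beam
+        (values are illustrative):
+
+        >>> dpc = DPCReconstruction(datacube=datacube, energy=300e3).preprocess()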
+ """
+ xp = self._xp
+
+ # set additional metadata
+ self._dp_mask = dp_mask
+
+ if self._datacube is None:
+ if force_com_measured is None:
+ raise ValueError(
+ (
+ "The preprocess() method requires either a DataCube "
+ "or `force_com_measured`. "
+ "Please run dpc.attach_datacube(DataCube) to attach DataCube."
+ )
+ )
+ else:
+ self._datacube = DataCube(
+ data=np.empty(force_com_measured[0].shape + (1, 1))
+ )
+
+ self._intensities = self._extract_intensities_and_calibrations_from_datacube(
+ self._datacube,
+ require_calibrations=False,
+ )
+
+ (
+ self._com_measured_x,
+ self._com_measured_y,
+ self._com_fitted_x,
+ self._com_fitted_y,
+ self._com_normalized_x,
+ self._com_normalized_y,
+ ) = self._calculate_intensities_center_of_mass(
+ self._intensities,
+ dp_mask=self._dp_mask,
+ fit_function=fit_function,
+ com_shifts=force_com_shifts,
+ com_measured=force_com_measured,
+ )
+
+ (
+ self._rotation_best_rad,
+ self._rotation_best_transpose,
+ self._com_x,
+ self._com_y,
+ self.com_x,
+ self.com_y,
+ ) = self._solve_for_center_of_mass_relative_rotation(
+ self._com_measured_x,
+ self._com_measured_y,
+ self._com_normalized_x,
+ self._com_normalized_y,
+ rotation_angles_deg=rotation_angles_deg,
+ plot_rotation=plot_rotation,
+ plot_center_of_mass=plot_center_of_mass,
+ maximize_divergence=maximize_divergence,
+ force_com_rotation=force_com_rotation,
+ force_com_transpose=force_com_transpose,
+ **kwargs,
+ )
+
+ # Object Initialization
+ padded_object_shape = np.round(
+ np.array(self._grid_scan_shape) * padding_factor
+ ).astype("int")
+ self._padded_object_phase = xp.zeros(padded_object_shape, dtype=xp.float32)
+ if self._object_phase is not None:
+ self._padded_object_phase[
+ : self._grid_scan_shape[0], : self._grid_scan_shape[1]
+ ] = xp.asarray(self._object_phase, dtype=xp.float32)
+
+ self._padded_object_phase_initial = self._padded_object_phase.copy()
+
+ # Fourier coordinates and operators
+ kx = xp.fft.fftfreq(padded_object_shape[0], d=self._scan_sampling[0])
+ ky = xp.fft.fftfreq(padded_object_shape[1], d=self._scan_sampling[1])
+ kya, kxa = xp.meshgrid(ky, kx)
+ k_den = kxa**2 + kya**2
+ k_den[0, 0] = np.inf
+ k_den = 1 / k_den
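+        # kx_op and ky_op below are Fourier-space inverse-gradient operators,
+        # -1j * k / (4 * |k|**2): multiplying the FFTs of the CoM gradient
+        # components by these and summing integrates the gradient field in
+        # _adjoint(). The k = 0 term is suppressed by the infinite
+        # denominator set above.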
+ self._kx_op = -1j * 0.25 * kxa * k_den
+ self._ky_op = -1j * 0.25 * kya * k_den
+
+ self._preprocessed = True
+
+ if self._device == "gpu":
+ xp._default_memory_pool.free_all_blocks()
+ xp.clear_memo()
+
+ return self
+
+ def _forward(
+ self,
+ padded_phase_object: np.ndarray,
+ mask: np.ndarray,
+ mask_inv: np.ndarray,
+ error: float,
+ step_size: float,
+ ):
+ """
+ DPC forward projection:
+ Computes a centered finite-difference approximation to the phase gradient
+ and projects to the measured CoM gradient
+
+ Parameters
+ ----------
+ padded_phase_object: np.ndarray
+ Current padded phase object estimate
+ mask: np.ndarray
+ Mask of object inside padded array
+ mask_inv: np.ndarray
+ Inverse mask of object inside padded array
+ error: float
+ Current error estimate
+ step_size: float
+ Current reconstruction step-size
+
+ Returns
+ --------
+ obj_dx: np.ndarray
+ Forward-projected horizontal CoM gradient
+ obj_dy: np.ndarray
+ Forward-projected vertical CoM gradient
+ error: float
+ Updated estimate error
+        step_size: float
+            Reconstruction step-size, passed through unchanged; halving on
+            error increase is handled in reconstruct()
+ """
+
+ xp = self._xp
+ dx, dy = self._scan_sampling
+
+ # centered finite-differences
+ obj_dx = (
+ xp.roll(padded_phase_object, 1, axis=0)
+ - xp.roll(padded_phase_object, -1, axis=0)
+ ) / (2 * dx)
+ obj_dy = (
+ xp.roll(padded_phase_object, 1, axis=1)
+ - xp.roll(padded_phase_object, -1, axis=1)
+ ) / (2 * dy)
+
+ # difference from measurement
+ obj_dx[mask] += self._com_x.ravel()
+ obj_dy[mask] += self._com_y.ravel()
+ obj_dx[mask_inv] = 0
+ obj_dy[mask_inv] = 0
+
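+        # normalized mean-squared error: residual gradient power relative
+        # to the measured CoM signal power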
+ new_error = xp.mean(obj_dx[mask] ** 2 + obj_dy[mask] ** 2) / (
+ xp.mean(self._com_x.ravel() ** 2 + self._com_y.ravel() ** 2)
+ )
+
+ return obj_dx, obj_dy, new_error, step_size
+
+ def _adjoint(
+ self,
+ obj_dx: np.ndarray,
+ obj_dy: np.ndarray,
+ kx_op: np.ndarray,
+ ky_op: np.ndarray,
+ ):
+ """
+ DPC adjoint projection:
+ Fourier-integrates the current estimate of the CoM gradient
+
+ Parameters
+ ----------
+ obj_dx: np.ndarray
+ Forward-projected horizontal phase gradient
+ obj_dy: np.ndarray
+ Forward-projected vertical phase gradient
+ kx_op: np.ndarray
+ Scaled k_x operator
+ ky_op: np.ndarray
+ Scaled k_y operator
+
+ Returns
+ --------
+ phase_update: np.ndarray
+ Adjoint-projected phase object
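+
+        Notes
+        -----
+        Implements
+        phase_update = Re(IFFT(FFT(obj_dx) * kx_op + FFT(obj_dy) * ky_op)),
+        where kx_op and ky_op are the inverse-gradient operators
+        precomputed in preprocess().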
+ """
+
+ xp = self._xp
+
+ phase_update = xp.real(
+ xp.fft.ifft2(xp.fft.fft2(obj_dx) * kx_op + xp.fft.fft2(obj_dy) * ky_op)
+ )
+
+ return phase_update
+
+ def _update(
+ self,
+ padded_phase_object: np.ndarray,
+ phase_update: np.ndarray,
+ step_size: float,
+ ):
+ """
+ DPC update step:
+
+ Parameters
+ ----------
+ padded_phase_object: np.ndarray
+ Current padded phase object estimate
+ phase_update: np.ndarray
+ Adjoint-projected phase object
+ step_size: float
+ Update step size
+
+ Returns
+ --------
+ updated_padded_object_phase: np.ndarray
+ Updated padded phase object estimate
+ """
+
+ padded_phase_object += step_size * phase_update
+ return padded_phase_object
+
+ def _object_gaussian_constraint(self, current_object, gaussian_filter_sigma):
+ """
+        Smoothness constraint used to blur the object.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ gaussian_filter_sigma: float
+ Standard deviation of gaussian kernel in A
+
+ Returns
+ --------
+ constrained_object: np.ndarray
+ Constrained object estimate
+ """
+ gaussian_filter = self._gaussian_filter
+
+ gaussian_filter_sigma /= self.sampling[0]
+ current_object = gaussian_filter(current_object, gaussian_filter_sigma)
+
+ return current_object
+
+ def _object_butterworth_constraint(
+ self, current_object, q_lowpass, q_highpass, butterworth_order
+ ):
+ """
+ Butterworth filter used for low/high-pass filtering.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ q_lowpass: float
+ Cut-off frequency in A^-1 for low-pass butterworth filter
+ q_highpass: float
+ Cut-off frequency in A^-1 for high-pass butterworth filter
+ butterworth_order: float
+ Butterworth filter order. Smaller gives a smoother filter
+
+ Returns
+ --------
+ constrained_object: np.ndarray
+ Constrained object estimate
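+
+        Notes
+        -----
+        The combined Fourier envelope applied is
+        env(q) = [1 - 1/(1 + (q/q_highpass)**(2n))] / (1 + (q/q_lowpass)**(2n)),
+        with n = butterworth_order; the object mean is subtracted before
+        filtering and restored afterwards.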
+ """
+ xp = self._xp
+ qx = xp.fft.fftfreq(current_object.shape[0], self.sampling[0])
+ qy = xp.fft.fftfreq(current_object.shape[1], self.sampling[1])
+
+ qya, qxa = xp.meshgrid(qy, qx)
+ qra = xp.sqrt(qxa**2 + qya**2)
+
+ env = xp.ones_like(qra)
+ if q_highpass:
+ env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * butterworth_order))
+ if q_lowpass:
+ env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order))
+
+ current_object_mean = xp.mean(current_object)
+ current_object -= current_object_mean
+ current_object = xp.fft.ifft2(xp.fft.fft2(current_object) * env)
+ current_object += current_object_mean
+
+ return xp.real(current_object)
+
+    def _object_anti_gridding_constraint(self, current_object):
+ """
+ Zero outer pixels of object fft to remove gridding artifacts
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+
+ Returns
+ --------
+ constrained_object: np.ndarray
+ Constrained object estimate
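+
+        Notes
+        -----
+        Operates on the unshifted FFT, so the zeroed central rows and
+        columns correspond to the highest (Nyquist) spatial frequencies.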
+ """
+ xp = self._xp
+
+ # find indices to zero
+ width_x = current_object.shape[0]
+ width_y = current_object.shape[1]
+ ind_min_x = int(xp.floor(width_x / 2) - 2)
+ ind_max_x = int(xp.ceil(width_x / 2) + 2)
+ ind_min_y = int(xp.floor(width_y / 2) - 2)
+ ind_max_y = int(xp.ceil(width_y / 2) + 2)
+
+ # zero pixels
+ object_fft = xp.fft.fft2(current_object)
+ object_fft[ind_min_x:ind_max_x] = 0
+ object_fft[:, ind_min_y:ind_max_y] = 0
+
+ return xp.real(xp.fft.ifft2(object_fft))
+
+ def _constraints(
+ self,
+ current_object,
+ gaussian_filter,
+ gaussian_filter_sigma,
+ butterworth_filter,
+ q_lowpass,
+ q_highpass,
+ butterworth_order,
+ anti_gridding,
+ ):
+ """
+ DPC constraints operator.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ gaussian_filter: bool
+ If True, applies real-space gaussian filter
+ gaussian_filter_sigma: float
+ Standard deviation of gaussian kernel in A
+        butterworth_filter: bool
+            If True, applies low- and/or high-pass butterworth filter
+ q_lowpass: float
+ Cut-off frequency in A^-1 for low-pass butterworth filter
+ q_highpass: float
+ Cut-off frequency in A^-1 for high-pass butterworth filter
+ butterworth_order: float
+ Butterworth filter order. Smaller gives a smoother filter
+        anti_gridding: bool
+            If True, zeros outer pixels of the object fft to remove
+            gridding artifacts
+
+ Returns
+ --------
+ constrained_object: np.ndarray
+ Constrained object estimate
+ """
+ if gaussian_filter:
+ current_object = self._object_gaussian_constraint(
+ current_object, gaussian_filter_sigma
+ )
+
+ if butterworth_filter:
+ current_object = self._object_butterworth_constraint(
+ current_object,
+ q_lowpass,
+ q_highpass,
+ butterworth_order,
+ )
+
+ if anti_gridding:
+            current_object = self._object_anti_gridding_constraint(
+ current_object,
+ )
+
+ return current_object
+
+ def reconstruct(
+ self,
+ reset: bool = None,
+ max_iter: int = 64,
+ step_size: float = None,
+ stopping_criterion: float = 1e-6,
+ backtrack: bool = True,
+ progress_bar: bool = True,
+ gaussian_filter_sigma: float = None,
+ gaussian_filter_iter: int = np.inf,
+ butterworth_filter_iter: int = np.inf,
+ q_lowpass: float = None,
+ q_highpass: float = None,
+ butterworth_order: float = 2,
+        anti_gridding: bool = True,
+ store_iterations: bool = False,
+ ):
+ """
+ Performs Iterative DPC Reconstruction:
+
+ Parameters
+ ----------
+ reset: bool, optional
+ If True, previous reconstructions are ignored
+ max_iter: int, optional
+ Maximum number of iterations
+ step_size: float, optional
+ Reconstruction update step size
+ stopping_criterion: float, optional
+ step_size below which reconstruction exits
+ backtrack: bool, optional
+ If True, steps that increase the error metric are rejected
+ and iteration continues with a reduced step size from the
+ previous iteration
+ progress_bar: bool, optional
+ If True, reconstruction progress bar will be printed
+ gaussian_filter_sigma: float, optional
+ Standard deviation of gaussian kernel in A
+ gaussian_filter_iter: int, optional
+ Number of iterations to run using object smoothness constraint
+        butterworth_filter_iter: int, optional
+            Number of iterations to run using low- and/or high-pass butterworth filter
+ q_lowpass: float
+ Cut-off frequency in A^-1 for low-pass butterworth filter
+ q_highpass: float
+ Cut-off frequency in A^-1 for high-pass butterworth filter
+ butterworth_order: float
+ Butterworth filter order. Smaller gives a smoother filter
+        anti_gridding: bool
+            If True, zeros outer pixels of the object fft to remove
+            gridding artifacts
+ store_iterations: bool, optional
+ If True, all reconstruction iterations will be stored
+
+ Returns
+ --------
+ self: DPCReconstruction
+ Self to accommodate chaining
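+
+        Examples
+        --------
+        A minimal sketch, assuming a preprocessed instance `dpc`
+        (values are illustrative):
+
+        >>> dpc = dpc.reconstruct(max_iter=32, store_iterations=True)
+        >>> dpc.visualize(iterations_grid="auto")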
+ """
+
+ xp = self._xp
+ asnumpy = self._asnumpy
+
+ # Restart
+ if store_iterations and (not hasattr(self, "object_phase_iterations") or reset):
+ self.object_phase_iterations = []
+
+ if reset:
+ self.error = np.inf
+ self.error_iterations = []
+ self._step_size = step_size if step_size is not None else 0.5
+ self._padded_object_phase = self._padded_object_phase_initial.copy()
+ elif reset is None:
+ if hasattr(self, "error"):
+ warnings.warn(
+ (
+ "Continuing reconstruction from previous result. "
+ "Use reset=True for a fresh start."
+ ),
+ UserWarning,
+ )
+ else:
+ self.error_iterations = []
+
+ self.error = getattr(self, "error", np.inf)
+
+ if step_size is None:
+ self._step_size = getattr(self, "_step_size", 0.5)
+ else:
+ self._step_size = step_size
+
+ mask = xp.zeros(self._padded_object_phase.shape, dtype="bool")
+ mask[: self._grid_scan_shape[0], : self._grid_scan_shape[1]] = True
+ mask_inv = xp.logical_not(mask)
+
+ # main loop
+ for a0 in tqdmnd(
+ max_iter,
+ desc="Reconstructing phase",
+ unit=" iter",
+ disable=not progress_bar,
+ ):
+ if self._step_size < stopping_criterion:
+ break
+
+ previous_iteration = self._padded_object_phase.copy()
+
+ # forward operator
+ com_dx, com_dy, new_error, self._step_size = self._forward(
+ self._padded_object_phase, mask, mask_inv, self.error, self._step_size
+ )
+
+ # if the error went up after the previous step, go back to the step
+ # before the error rose and continue with the halved step size
+ if (new_error > self.error) and backtrack:
+ self._padded_object_phase = previous_iteration
+ self._step_size /= 2
+ if self._verbose:
+ print(f"Iteration {a0}, step reduced to {self._step_size}")
+ continue
+ self.error = new_error
+
+ # adjoint operator
+ phase_update = self._adjoint(com_dx, com_dy, self._kx_op, self._ky_op)
+
+ # update
+ self._padded_object_phase = self._update(
+ self._padded_object_phase, phase_update, self._step_size
+ )
+
+ # constraints
+ self._padded_object_phase = self._constraints(
+ self._padded_object_phase,
+ gaussian_filter=a0 < gaussian_filter_iter
+ and gaussian_filter_sigma is not None,
+ gaussian_filter_sigma=gaussian_filter_sigma,
+ butterworth_filter=a0 < butterworth_filter_iter
+ and (q_lowpass is not None or q_highpass is not None),
+ q_lowpass=q_lowpass,
+ q_highpass=q_highpass,
+ butterworth_order=butterworth_order,
+ anti_gridding=anti_gridding,
+ )
+
+ self.error_iterations.append(self.error.item())
+ if store_iterations:
+ self.object_phase_iterations.append(
+ asnumpy(
+ self._padded_object_phase[
+ : self._grid_scan_shape[0], : self._grid_scan_shape[1]
+ ].copy()
+ )
+ )
+
+ if self._step_size < stopping_criterion:
+ if self._verbose:
+ warnings.warn(
+ f"Step-size has decreased below stopping criterion {stopping_criterion}.",
+ UserWarning,
+ )
+
+ # crop result
+ self._object_phase = self._padded_object_phase[
+ : self._grid_scan_shape[0], : self._grid_scan_shape[1]
+ ]
+ self.object_phase = asnumpy(self._object_phase)
+
+ if self._device == "gpu":
+ xp._default_memory_pool.free_all_blocks()
+ xp.clear_memo()
+
+ return self
+
+ def _visualize_last_iteration(
+ self, fig, cbar: bool, plot_convergence: bool, **kwargs
+ ):
+ """
+ Displays last iteration of reconstructed phase object.
+
+ Parameters
+ --------
+ fig, optional
+ Matplotlib figure to draw Gridspec on
+ cbar: bool, optional
+ If true, displays a colorbar
+ plot_convergence: bool, optional
+ If true, the NMSE error plot is displayed
+ """
+
+ figsize = kwargs.pop("figsize", (5, 6))
+ cmap = kwargs.pop("cmap", "magma")
+
+ if plot_convergence:
+ spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0.15)
+ else:
+ spec = GridSpec(ncols=1, nrows=1)
+
+ if fig is None:
+ fig = plt.figure(figsize=figsize)
+
+ extent = [
+ 0,
+ self._scan_sampling[1] * self._grid_scan_shape[1],
+ self._scan_sampling[0] * self._grid_scan_shape[0],
+ 0,
+ ]
+
+ ax1 = fig.add_subplot(spec[0])
+ im = ax1.imshow(self.object_phase, extent=extent, cmap=cmap, **kwargs)
+ ax1.set_ylabel(f"x [{self._scan_units[0]}]")
+ ax1.set_xlabel(f"y [{self._scan_units[1]}]")
+ ax1.set_title(f"DPC phase reconstruction - NMSE error: {self.error:.3e}")
+
+ if cbar:
+ divider = make_axes_locatable(ax1)
+ ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
+ fig.add_axes(ax_cb)
+ fig.colorbar(im, cax=ax_cb)
+
+ if plot_convergence:
+ errors = self.error_iterations
+ ax2 = fig.add_subplot(spec[1])
+ ax2.semilogy(np.arange(len(errors)), errors, **kwargs)
+ ax2.set_xlabel("Iteration number")
+ ax2.set_ylabel("Log NMSE error")
+ ax2.yaxis.tick_right()
+
+ spec.tight_layout(fig)
+
+ def _visualize_all_iterations(
+ self,
+ fig,
+ cbar: bool,
+ plot_convergence: bool,
+ iterations_grid: Tuple[int, int],
+ **kwargs,
+ ):
+ """
+        Displays all iterations of the reconstructed phase object.
+
+ Parameters
+ --------
+ fig, optional
+ Matplotlib figure to draw Gridspec on
+ cbar: bool, optional
+ If true, displays a colorbar
+ plot_convergence: bool, optional
+ If true, the NMSE error plot is displayed
+        iterations_grid: Tuple[int,int]
+            Grid dimensions to plot reconstruction iterations, or 'auto'
+ """
+
+ if not hasattr(self, "object_phase_iterations"):
+ raise ValueError(
+ (
+ "Object iterations were not saved during reconstruction. "
+ "Please re-run using store_iterations=True."
+ )
+ )
+
+ if iterations_grid == "auto":
+ num_iter = len(self.error_iterations)
+
+ if num_iter == 1:
+ return self._visualize_last_iteration(
+ plot_convergence=plot_convergence,
+ cbar=cbar,
+ **kwargs,
+ )
+ else:
+ iterations_grid = (2, 4) if num_iter > 8 else (2, num_iter // 2)
+
+ auto_figsize = (
+ (3 * iterations_grid[1], 3 * iterations_grid[0] + 1)
+ if plot_convergence
+ else (3 * iterations_grid[1], 3 * iterations_grid[0])
+ )
+ figsize = kwargs.pop("figsize", auto_figsize)
+ cmap = kwargs.pop("cmap", "magma")
+
+ total_grids = np.prod(iterations_grid)
+ errors = self.error_iterations
+ phases = self.object_phase_iterations
+ max_iter = len(phases) - 1
+ grid_range = range(0, max_iter + 1, max_iter // (total_grids - 1))
+
+ extent = [
+ 0,
+ self._scan_sampling[1] * self._grid_scan_shape[1],
+ self._scan_sampling[0] * self._grid_scan_shape[0],
+ 0,
+ ]
+
+ if plot_convergence:
+ spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0.15)
+ else:
+ spec = GridSpec(ncols=1, nrows=1)
+
+ if fig is None:
+ fig = plt.figure(figsize=figsize)
+
+ grid = ImageGrid(
+ fig,
+ spec[0],
+ nrows_ncols=iterations_grid,
+ axes_pad=(0.75, 0.5) if cbar else 0.5,
+ cbar_mode="each" if cbar else None,
+ cbar_pad="2.5%" if cbar else None,
+ )
+
+ for n, ax in enumerate(grid):
+ im = ax.imshow(
+ phases[grid_range[n]],
+ extent=extent,
+ cmap=cmap,
+ **kwargs,
+ )
+ ax.set_ylabel(f"x [{self._scan_units[0]}]")
+ ax.set_xlabel(f"y [{self._scan_units[1]}]")
+ if cbar:
+ grid.cbar_axes[n].colorbar(im)
+ ax.set_title(
+ f"Iteration: {grid_range[n]}\nNMSE error: {errors[grid_range[n]]:.3e}"
+ )
+
+ if plot_convergence:
+ ax2 = fig.add_subplot(spec[1])
+ ax2.semilogy(np.arange(len(errors)), errors, **kwargs)
+ ax2.set_xlabel("Iteration number")
+ ax2.set_ylabel("Log NMSE error")
+ ax2.yaxis.tick_right()
+
+ spec.tight_layout(fig)
+
+ def visualize(
+ self,
+ fig=None,
+ iterations_grid: Tuple[int, int] = None,
+ plot_convergence: bool = True,
+ cbar: bool = True,
+ **kwargs,
+ ):
+ """
+ Displays reconstructed phase object.
+
+ Parameters
+ ----------
+ fig, optional
+ Matplotlib figure to draw Gridspec on
+ plot_convergence: bool, optional
+ If true, the NMSE error plot is displayed
+        iterations_grid: Tuple[int,int]
+            Grid dimensions to plot reconstruction iterations, or 'auto'
+ cbar: bool, optional
+ If true, displays a colorbar
+
+ Returns
+ --------
+ self: DPCReconstruction
+ Self to accommodate chaining
+ """
+
+ if iterations_grid is None:
+ self._visualize_last_iteration(
+ fig=fig, plot_convergence=plot_convergence, cbar=cbar, **kwargs
+ )
+ else:
+ self._visualize_all_iterations(
+ fig=fig,
+ plot_convergence=plot_convergence,
+ iterations_grid=iterations_grid,
+ cbar=cbar,
+ **kwargs,
+ )
+
+ return self
+
+ @property
+ def sampling(self):
+ """Sampling [Å]"""
+
+ return self._scan_sampling
diff --git a/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py
new file mode 100644
index 000000000..f4c10cb13
--- /dev/null
+++ b/py4DSTEM/process/phase/iterative_mixedstate_multislice_ptychography.py
@@ -0,0 +1,3654 @@
+"""
+Module for reconstructing phase objects from 4DSTEM datasets using iterative methods,
+namely mixed-state multislice ptychography.
+"""
+
+import warnings
+from typing import Mapping, Sequence, Tuple, Union
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pylops
+from matplotlib.gridspec import GridSpec
+from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable
+from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg, show_complex
+
+try:
+ import cupy as cp
+except ImportError:
+ cp = None
+
+from emdfile import Custom, tqdmnd
+from py4DSTEM import DataCube
+from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction
+from py4DSTEM.process.phase.utils import (
+ ComplexProbe,
+ fft_shift,
+ generate_batches,
+ polar_aliases,
+ polar_symbols,
+ spatial_frequencies,
+)
+from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar
+from scipy.ndimage import rotate
+
+warnings.simplefilter(action="always", category=UserWarning)
+
+
+class MixedstateMultislicePtychographicReconstruction(PtychographicReconstruction):
+ """
+ Mixed-State Multislice Ptychographic Reconstruction Class.
+
+ Diffraction intensities dimensions : (Rx,Ry,Qx,Qy)
+ Reconstructed probe dimensions : (N,Sx,Sy)
+ Reconstructed object dimensions : (T,Px,Py)
+
+ such that (Sx,Sy) is the region-of-interest (ROI) size of our N probes
+ and (Px,Py) is the padded-object size we position our ROI around in
+ each of the T slices.
+
+ Parameters
+ ----------
+ energy: float
+ The electron energy of the wave functions in eV
+    num_slices: int
+        Number of slices to use in the forward model
+    slice_thicknesses: float or Sequence[float]
+        Slice thicknesses in angstroms. If float, all slices are assigned the same thickness
+    num_probes: int, optional
+        Number of mixed-state probes
+ datacube: DataCube, optional
+ Input 4D diffraction pattern intensities
+ semiangle_cutoff: float, optional
+ Semiangle cutoff for the initial probe guess in mrad
+ semiangle_cutoff_pixels: float, optional
+ Semiangle cutoff for the initial probe guess in pixels
+ rolloff: float, optional
+ Semiangle rolloff for the initial probe guess
+ vacuum_probe_intensity: np.ndarray, optional
+ Vacuum probe to use as intensity aperture for initial probe guess
+ polar_parameters: dict, optional
+ Mapping from aberration symbols to their corresponding values. All aberration
+ magnitudes should be given in Å and angles should be given in radians.
+ object_padding_px: Tuple[int,int], optional
+ Pixel dimensions to pad object with
+ If None, the padding is set to half the probe ROI dimensions
+ initial_object_guess: np.ndarray, optional
+ Initial guess for complex-valued object of dimensions (Px,Py)
+ If None, initialized to 1.0j
+ initial_probe_guess: np.ndarray, optional
+ Initial guess for complex-valued probe of dimensions (Sx,Sy). If None,
+ initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations
+ initial_scan_positions: np.ndarray, optional
+ Probe positions in Å for each diffraction intensity
+ If None, initialized to a grid scan
+ theta_x: float
+ x tilt of propagator (in degrees)
+ theta_y: float
+ y tilt of propagator (in degrees)
+ middle_focus: bool
+ if True, adds half the sample thickness to the defocus
+ object_type: str, optional
+ The object can be reconstructed as a real potential ('potential') or a complex
+ object ('complex')
+ positions_mask: np.ndarray, optional
+ Boolean real space mask to select positions in datacube to skip for reconstruction
+ verbose: bool, optional
+ If True, class methods will inherit this and print additional information
+ device: str, optional
+        Device the calculation will be performed on. Must be 'cpu' or 'gpu'
+ name: str, optional
+ Class name
+ kwargs:
+ Provide the aberration coefficients as keyword arguments.
+ """
+
+ # Class-specific Metadata
+ _class_specific_metadata = ("_num_probes", "_num_slices", "_slice_thicknesses")
+
+ def __init__(
+ self,
+ energy: float,
+ num_slices: int,
+ slice_thicknesses: Union[float, Sequence[float]],
+ num_probes: int = None,
+ datacube: DataCube = None,
+ semiangle_cutoff: float = None,
+ semiangle_cutoff_pixels: float = None,
+ rolloff: float = 2.0,
+ vacuum_probe_intensity: np.ndarray = None,
+ polar_parameters: Mapping[str, float] = None,
+ object_padding_px: Tuple[int, int] = None,
+ initial_object_guess: np.ndarray = None,
+ initial_probe_guess: np.ndarray = None,
+ initial_scan_positions: np.ndarray = None,
+ theta_x: float = 0,
+ theta_y: float = 0,
+ middle_focus: bool = False,
+ object_type: str = "complex",
+ positions_mask: np.ndarray = None,
+ verbose: bool = True,
+ device: str = "cpu",
+ name: str = "multi-slice_ptychographic_reconstruction",
+ **kwargs,
+ ):
+ Custom.__init__(self, name=name)
+
+ if initial_probe_guess is None or isinstance(initial_probe_guess, ComplexProbe):
+ if num_probes is None:
+ raise ValueError(
+ (
+ "If initial_probe_guess is None, or a ComplexProbe object, "
+ "num_probes must be specified."
+ )
+ )
+ else:
+ if len(initial_probe_guess.shape) != 3:
+ raise ValueError(
+ "Specified initial_probe_guess must have dimensions (N,Sx,Sy)."
+ )
+ num_probes = initial_probe_guess.shape[0]
+
+ if device == "cpu":
+ self._xp = np
+ self._asnumpy = np.asarray
+ from scipy.ndimage import gaussian_filter
+
+ self._gaussian_filter = gaussian_filter
+ from scipy.special import erf
+
+ self._erf = erf
+ elif device == "gpu":
+ self._xp = cp
+ self._asnumpy = cp.asnumpy
+ from cupyx.scipy.ndimage import gaussian_filter
+
+ self._gaussian_filter = gaussian_filter
+ from cupyx.scipy.special import erf
+
+ self._erf = erf
+ else:
+ raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}")
+
+ for key in kwargs.keys():
+ if (key not in polar_symbols) and (key not in polar_aliases.keys()):
+ raise ValueError("{} not a recognized parameter".format(key))
+
+ if np.isscalar(slice_thicknesses):
+ mean_slice_thickness = slice_thicknesses
+ else:
+ mean_slice_thickness = np.mean(slice_thicknesses)
+
+ if middle_focus:
+ if "defocus" in kwargs:
+ kwargs["defocus"] += mean_slice_thickness * num_slices / 2
+ elif "C10" in kwargs:
+ kwargs["C10"] -= mean_slice_thickness * num_slices / 2
+ elif polar_parameters is not None and "defocus" in polar_parameters:
+ polar_parameters["defocus"] = (
+ polar_parameters["defocus"] + mean_slice_thickness * num_slices / 2
+ )
+ elif polar_parameters is not None and "C10" in polar_parameters:
+ polar_parameters["C10"] = (
+ polar_parameters["C10"] - mean_slice_thickness * num_slices / 2
+ )
+
+ self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols)))
+
+ if polar_parameters is None:
+ polar_parameters = {}
+
+ polar_parameters.update(kwargs)
+ self._set_polar_parameters(polar_parameters)
+
+ slice_thicknesses = np.array(slice_thicknesses)
+ if slice_thicknesses.shape == ():
+ slice_thicknesses = np.tile(slice_thicknesses, num_slices - 1)
+ elif slice_thicknesses.shape[0] != (num_slices - 1):
+ raise ValueError(
+ (
+ f"slice_thicknesses must have length {num_slices - 1}, "
+ f"not {slice_thicknesses.shape[0]}."
+ )
+ )
+
+ if object_type != "potential" and object_type != "complex":
+ raise ValueError(
+ f"object_type must be either 'potential' or 'complex', not {object_type}"
+ )
+
+ if positions_mask is not None and positions_mask.dtype != "bool":
+ warnings.warn(
+ ("`positions_mask` converted to `bool` array"),
+ UserWarning,
+ )
+ positions_mask = np.asarray(positions_mask, dtype="bool")
+
+ self.set_save_defaults()
+
+ # Data
+ self._datacube = datacube
+ self._object = initial_object_guess
+ self._probe = initial_probe_guess
+
+ # Common Metadata
+ self._vacuum_probe_intensity = vacuum_probe_intensity
+ self._scan_positions = initial_scan_positions
+ self._energy = energy
+ self._semiangle_cutoff = semiangle_cutoff
+ self._semiangle_cutoff_pixels = semiangle_cutoff_pixels
+ self._rolloff = rolloff
+ self._object_type = object_type
+ self._positions_mask = positions_mask
+ self._object_padding_px = object_padding_px
+ self._verbose = verbose
+ self._device = device
+ self._preprocessed = False
+
+ # Class-specific Metadata
+ self._num_probes = num_probes
+ self._num_slices = num_slices
+ self._slice_thicknesses = slice_thicknesses
+ self._theta_x = theta_x
+ self._theta_y = theta_y
+
+ def _precompute_propagator_arrays(
+ self,
+ gpts: Tuple[int, int],
+ sampling: Tuple[float, float],
+ energy: float,
+ slice_thicknesses: Sequence[float],
+ theta_x: float,
+ theta_y: float,
+ ):
+ """
+ Precomputes propagator arrays complex wave-function will be convolved by,
+ for all slice thicknesses.
+
+ Parameters
+ ----------
+ gpts: Tuple[int,int]
+ Wavefunction pixel dimensions
+ sampling: Tuple[float,float]
+ Wavefunction sampling in A
+ energy: float
+ The electron energy of the wave functions in eV
+ slice_thicknesses: Sequence[float]
+ Array of slice thicknesses in A
+ theta_x: float
+ x tilt of propagator (in degrees)
+ theta_y: float
+ y tilt of propagator (in degrees)
+
+ Returns
+ -------
+ propagator_arrays: np.ndarray
+ (T,Sx,Sy) shape array storing propagator arrays
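+
+        Notes
+        -----
+        Each slice propagator is the Fresnel kernel
+        exp(-1j * pi * wavelength * dz * (kx**2 + ky**2)),
+        multiplied by the tilt phase ramps
+        exp(2j * pi * dz * (kx * tan(theta_x) + ky * tan(theta_y))).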
+ """
+ xp = self._xp
+
+ # Frequencies
+ kx, ky = spatial_frequencies(gpts, sampling)
+ kx = xp.asarray(kx, dtype=xp.float32)
+ ky = xp.asarray(ky, dtype=xp.float32)
+
+ # Propagators
+ wavelength = electron_wavelength_angstrom(energy)
+ num_slices = slice_thicknesses.shape[0]
+ propagators = xp.empty(
+ (num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64
+ )
+
+ theta_x = np.deg2rad(theta_x)
+ theta_y = np.deg2rad(theta_y)
+
+ for i, dz in enumerate(slice_thicknesses):
+ propagators[i] = xp.exp(
+ 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz)
+ )
+ propagators[i] *= xp.exp(
+ 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz)
+ )
+ propagators[i] *= xp.exp(
+ 1.0j * (2 * kx[:, None] * np.pi * dz * np.tan(theta_x))
+ )
+ propagators[i] *= xp.exp(
+ 1.0j * (2 * ky[None] * np.pi * dz * np.tan(theta_y))
+ )
+
+ return propagators
+
+ def _propagate_array(self, array: np.ndarray, propagator_array: np.ndarray):
+ """
+ Propagates array by Fourier convolving array with propagator_array.
+
+ Parameters
+ ----------
+ array: np.ndarray
+ Wavefunction array to be convolved
+ propagator_array: np.ndarray
+ Propagator array to convolve array with
+
+ Returns
+ -------
+ propagated_array: np.ndarray
+ Fourier-convolved array
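+
+        Examples
+        --------
+        A minimal sketch, assuming a preprocessed instance `ptycho` and an
+        (Sx,Sy) wavefunction `psi` (both names are illustrative):
+
+        >>> psi_next = ptycho._propagate_array(psi, ptycho._propagator_arrays[0])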
+ """
+ xp = self._xp
+
+ return xp.fft.ifft2(xp.fft.fft2(array) * propagator_array)
+
+ def preprocess(
+ self,
+ diffraction_intensities_shape: Tuple[int, int] = None,
+ reshaping_method: str = "fourier",
+ probe_roi_shape: Tuple[int, int] = None,
+ dp_mask: np.ndarray = None,
+ fit_function: str = "plane",
+ plot_center_of_mass: str = "default",
+ plot_rotation: bool = True,
+ maximize_divergence: bool = False,
+ rotation_angles_deg: np.ndarray = np.arange(-89.0, 90.0, 1.0),
+ plot_probe_overlaps: bool = True,
+ force_com_rotation: float = None,
+ force_com_transpose: float = None,
+ force_com_shifts: float = None,
+ force_scan_sampling: float = None,
+ force_angular_sampling: float = None,
+ force_reciprocal_sampling: float = None,
+ object_fov_mask: np.ndarray = None,
+ crop_patterns: bool = False,
+ **kwargs,
+ ):
+ """
+ Ptychographic preprocessing step.
+ Calls the base class methods:
+
+        _extract_intensities_and_calibrations_from_datacube(),
+        _calculate_intensities_center_of_mass(),
+        _solve_for_center_of_mass_relative_rotation(),
+        _normalize_diffraction_intensities(), and
+        _calculate_scan_positions_in_pixels()
+
+ Additionally, it initializes an (T,Px,Py) array of 1.0j
+ and a complex probe using the specified polar parameters.
+
+ Parameters
+ ----------
+ diffraction_intensities_shape: Tuple[int,int], optional
+ Pixel dimensions (Qx',Qy') of the resampled diffraction intensities
+            If None, no resampling of diffraction intensities is performed
+ reshaping_method: str, optional
+            Method to use for reshaping, either 'bin', 'bilinear', or 'fourier' (default)
+        probe_roi_shape: Tuple[int,int], optional
+ Padded diffraction intensities shape.
+ If None, no padding is performed
+ dp_mask: ndarray, optional
+ Mask for datacube intensities (Qx,Qy)
+ fit_function: str, optional
+ 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two'
+ plot_center_of_mass: str, optional
+ If 'default', the corrected CoM arrays will be displayed
+ If 'all', the computed and fitted CoM arrays will be displayed
+ plot_rotation: bool, optional
+ If True, the CoM curl minimization search result will be displayed
+ maximize_divergence: bool, optional
+ If True, the divergence of the CoM gradient vector field is maximized
+        rotation_angles_deg: np.ndarray, optional
+ Array of angles in degrees to perform curl minimization over
+ plot_probe_overlaps: bool, optional
+ If True, initial probe overlaps scanned over the object will be displayed
+ force_com_rotation: float (degrees), optional
+ Force relative rotation angle between real and reciprocal space
+ force_com_transpose: bool, optional
+ Force whether diffraction intensities need to be transposed.
+ force_com_shifts: tuple of ndarrays (CoMx, CoMy)
+ Amplitudes come from diffraction patterns shifted with
+ the CoM in the upper left corner for each probe unless
+ shift is overwritten.
+ force_scan_sampling: float, optional
+ Override DataCube real space scan pixel size calibrations, in Angstrom
+ force_angular_sampling: float, optional
+ Override DataCube reciprocal pixel size calibration, in mrad
+ force_reciprocal_sampling: float, optional
+ Override DataCube reciprocal pixel size calibration, in A^-1
+ object_fov_mask: np.ndarray (boolean)
+ Boolean mask of FOV. Used to calculate additional shrinkage of object
+ If None, probe_overlap intensity is thresholded
+ crop_patterns: bool
+            If True, crop patterns to avoid wrap-around when centering
+
+ Returns
+ --------
+ self: MixedstateMultislicePtychographicReconstruction
+ Self to accommodate chaining
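+
+        Examples
+        --------
+        A minimal sketch, assuming a calibrated 4D-STEM `datacube`
+        (all values are illustrative):
+
+        >>> ptycho = MixedstateMultislicePtychographicReconstruction(
+        ...     datacube=datacube,
+        ...     energy=300e3,
+        ...     num_probes=2,
+        ...     num_slices=4,
+        ...     slice_thicknesses=20,
+        ...     semiangle_cutoff=21.4,
+        ... ).preprocess(plot_probe_overlaps=False)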
+ """
+ xp = self._xp
+ asnumpy = self._asnumpy
+
+ # set additional metadata
+ self._diffraction_intensities_shape = diffraction_intensities_shape
+ self._reshaping_method = reshaping_method
+ self._probe_roi_shape = probe_roi_shape
+ self._dp_mask = dp_mask
+
+ if self._datacube is None:
+ raise ValueError(
+ (
+ "The preprocess() method requires a DataCube. "
+ "Please run ptycho.attach_datacube(DataCube) first."
+ )
+ )
+
+ (
+ self._datacube,
+ self._vacuum_probe_intensity,
+ self._dp_mask,
+ force_com_shifts,
+ ) = self._preprocess_datacube_and_vacuum_probe(
+ self._datacube,
+ diffraction_intensities_shape=self._diffraction_intensities_shape,
+ reshaping_method=self._reshaping_method,
+ probe_roi_shape=self._probe_roi_shape,
+ vacuum_probe_intensity=self._vacuum_probe_intensity,
+ dp_mask=self._dp_mask,
+ com_shifts=force_com_shifts,
+ )
+
+ self._intensities = self._extract_intensities_and_calibrations_from_datacube(
+ self._datacube,
+ require_calibrations=True,
+ force_scan_sampling=force_scan_sampling,
+ force_angular_sampling=force_angular_sampling,
+ force_reciprocal_sampling=force_reciprocal_sampling,
+ )
+
+ (
+ self._com_measured_x,
+ self._com_measured_y,
+ self._com_fitted_x,
+ self._com_fitted_y,
+ self._com_normalized_x,
+ self._com_normalized_y,
+ ) = self._calculate_intensities_center_of_mass(
+ self._intensities,
+ dp_mask=self._dp_mask,
+ fit_function=fit_function,
+ com_shifts=force_com_shifts,
+ )
+
+ (
+ self._rotation_best_rad,
+ self._rotation_best_transpose,
+ self._com_x,
+ self._com_y,
+ self.com_x,
+ self.com_y,
+ ) = self._solve_for_center_of_mass_relative_rotation(
+ self._com_measured_x,
+ self._com_measured_y,
+ self._com_normalized_x,
+ self._com_normalized_y,
+ rotation_angles_deg=rotation_angles_deg,
+ plot_rotation=plot_rotation,
+ plot_center_of_mass=plot_center_of_mass,
+ maximize_divergence=maximize_divergence,
+ force_com_rotation=force_com_rotation,
+ force_com_transpose=force_com_transpose,
+ **kwargs,
+ )
+
+ (
+ self._amplitudes,
+ self._mean_diffraction_intensity,
+ ) = self._normalize_diffraction_intensities(
+ self._intensities,
+ self._com_fitted_x,
+ self._com_fitted_y,
+ crop_patterns,
+ self._positions_mask,
+ )
+
+ # explicitly delete namespace
+ self._num_diffraction_patterns = self._amplitudes.shape[0]
+ self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:])
+ del self._intensities
+
+ self._positions_px = self._calculate_scan_positions_in_pixels(
+ self._scan_positions, self._positions_mask
+ )
+
+ # handle semiangle specified in pixels
+ if self._semiangle_cutoff_pixels:
+ self._semiangle_cutoff = (
+ self._semiangle_cutoff_pixels * self._angular_sampling[0]
+ )
+
+ # Object Initialization
+ if self._object is None:
+ pad_x = self._object_padding_px[0][1]
+ pad_y = self._object_padding_px[1][1]
+ p, q = np.round(np.max(self._positions_px, axis=0))
+ p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype(
+ "int"
+ )
+ q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype(
+ "int"
+ )
+ if self._object_type == "potential":
+ self._object = xp.zeros((self._num_slices, p, q), dtype=xp.float32)
+ elif self._object_type == "complex":
+ self._object = xp.ones((self._num_slices, p, q), dtype=xp.complex64)
+ else:
+ if self._object_type == "potential":
+ self._object = xp.asarray(self._object, dtype=xp.float32)
+ elif self._object_type == "complex":
+ self._object = xp.asarray(self._object, dtype=xp.complex64)
+
+ self._object_initial = self._object.copy()
+ self._object_type_initial = self._object_type
+ self._object_shape = self._object.shape[-2:]
+
+ self._positions_px = xp.asarray(self._positions_px, dtype=xp.float32)
+ self._positions_px_com = xp.mean(self._positions_px, axis=0)
+ self._positions_px -= self._positions_px_com - xp.array(self._object_shape) / 2
+ self._positions_px_com = xp.mean(self._positions_px, axis=0)
+ self._positions_px_fractional = self._positions_px - xp.round(
+ self._positions_px
+ )
+
+ self._positions_px_initial = self._positions_px.copy()
+ self._positions_initial = self._positions_px_initial.copy()
+ self._positions_initial[:, 0] *= self.sampling[0]
+ self._positions_initial[:, 1] *= self.sampling[1]
+
+ # Vectorized Patches
+ (
+ self._vectorized_patch_indices_row,
+ self._vectorized_patch_indices_col,
+ ) = self._extract_vectorized_patch_indices()
+
+ # Probe Initialization
+ if self._probe is None or isinstance(self._probe, ComplexProbe):
+ if self._probe is None:
+ if self._vacuum_probe_intensity is not None:
+ self._semiangle_cutoff = np.inf
+ self._vacuum_probe_intensity = xp.asarray(
+ self._vacuum_probe_intensity, dtype=xp.float32
+ )
+ probe_x0, probe_y0 = get_CoM(
+ self._vacuum_probe_intensity,
+ device=self._device,
+ )
+ self._vacuum_probe_intensity = get_shifted_ar(
+ self._vacuum_probe_intensity,
+ -probe_x0,
+ -probe_y0,
+ bilinear=True,
+ device=self._device,
+ )
+ if crop_patterns:
+ self._vacuum_probe_intensity = self._vacuum_probe_intensity[
+ self._crop_mask
+ ].reshape(self._region_of_interest_shape)
+ _probe = (
+ ComplexProbe(
+ gpts=self._region_of_interest_shape,
+ sampling=self.sampling,
+ energy=self._energy,
+ semiangle_cutoff=self._semiangle_cutoff,
+ rolloff=self._rolloff,
+ vacuum_probe_intensity=self._vacuum_probe_intensity,
+ parameters=self._polar_parameters,
+ device=self._device,
+ )
+ .build()
+ ._array
+ )
+
+ else:
+            if self._probe._gpts != self._region_of_interest_shape:
+                raise ValueError(
+                    "Specified initial_probe_guess gpts must match the "
+                    "region of interest shape."
+                )
+ if hasattr(self._probe, "_array"):
+ _probe = self._probe._array
+ else:
+ self._probe._xp = xp
+ _probe = self._probe.build()._array
+
+ self._probe = xp.zeros(
+ (self._num_probes,) + tuple(self._region_of_interest_shape),
+ dtype=xp.complex64,
+ )
+ sx, sy = self._region_of_interest_shape
+ self._probe[0] = _probe
+
+ # Randomly shift phase of other probes
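+            # Each subsequent probe is the previous one multiplied by a
+            # random linear phase ramp (equivalently, a random sub-pixel
+            # shift), breaking the degeneracy between the initial modes.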
+ for i_probe in range(1, self._num_probes):
+ shift_x = xp.exp(
+ -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sx)
+ )
+ shift_y = xp.exp(
+ -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sy)
+ )
+ self._probe[i_probe] = (
+ self._probe[i_probe - 1] * shift_x[:, None] * shift_y[None]
+ )
+
+ # Normalize probe to match mean diffraction intensity
+ probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe[0])) ** 2)
+ self._probe *= xp.sqrt(self._mean_diffraction_intensity / probe_intensity)
+
+ else:
+ self._probe = xp.asarray(self._probe, dtype=xp.complex64)
+
+ self._probe_initial = self._probe.copy()
+ self._probe_initial_aperture = None # Doesn't really make sense for mixed-state
+
+ self._known_aberrations_array = ComplexProbe(
+ energy=self._energy,
+ gpts=self._region_of_interest_shape,
+ sampling=self.sampling,
+ parameters=self._polar_parameters,
+ device=self._device,
+ )._evaluate_ctf()
+
+ # Precomputed propagator arrays
+ self._propagator_arrays = self._precompute_propagator_arrays(
+ self._region_of_interest_shape,
+ self.sampling,
+ self._energy,
+ self._slice_thicknesses,
+ self._theta_x,
+ self._theta_y,
+ )
+
+ # overlaps
+ shifted_probes = fft_shift(self._probe[0], self._positions_px_fractional, xp)
+ probe_intensities = xp.abs(shifted_probes) ** 2
+ probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities)
+ probe_overlap = self._gaussian_filter(probe_overlap, 1.0)
+
+ if object_fov_mask is None:
+ self._object_fov_mask = asnumpy(probe_overlap > 0.25 * probe_overlap.max())
+ else:
+ self._object_fov_mask = np.asarray(object_fov_mask)
+ self._object_fov_mask_inverse = np.invert(self._object_fov_mask)
+
+ if plot_probe_overlaps:
+ figsize = kwargs.pop("figsize", (13, 4))
+ chroma_boost = kwargs.pop("chroma_boost", 1)
+
+ # initial probe
+ complex_probe_rgb = Complex2RGB(
+ self.probe_centered[0],
+ power=2,
+ chroma_boost=chroma_boost,
+ )
+
+ # propagated
+ propagated_probe = self._probe[0].copy()
+
+ for s in range(self._num_slices - 1):
+ propagated_probe = self._propagate_array(
+ propagated_probe, self._propagator_arrays[s]
+ )
+ complex_propagated_rgb = Complex2RGB(
+ asnumpy(self._return_centered_probe(propagated_probe)),
+ power=2,
+ chroma_boost=chroma_boost,
+ )
+
+ extent = [
+ 0,
+ self.sampling[1] * self._object_shape[1],
+ self.sampling[0] * self._object_shape[0],
+ 0,
+ ]
+
+ probe_extent = [
+ 0,
+ self.sampling[1] * self._region_of_interest_shape[1],
+ self.sampling[0] * self._region_of_interest_shape[0],
+ 0,
+ ]
+
+ fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=figsize)
+
+ ax1.imshow(
+ complex_probe_rgb,
+ extent=probe_extent,
+ )
+
+ divider = make_axes_locatable(ax1)
+ cax1 = divider.append_axes("right", size="5%", pad="2.5%")
+ add_colorbar_arg(
+ cax1,
+ chroma_boost=chroma_boost,
+ )
+ ax1.set_ylabel("x [A]")
+ ax1.set_xlabel("y [A]")
+ ax1.set_title("Initial probe[0] intensity")
+
+ ax2.imshow(
+ complex_propagated_rgb,
+ extent=probe_extent,
+ )
+
+ divider = make_axes_locatable(ax2)
+ cax2 = divider.append_axes("right", size="5%", pad="2.5%")
+ add_colorbar_arg(cax2, chroma_boost=chroma_boost)
+ ax2.set_ylabel("x [A]")
+ ax2.set_xlabel("y [A]")
+ ax2.set_title("Propagated probe[0] intensity")
+
+ ax3.imshow(
+ asnumpy(probe_overlap),
+ extent=extent,
+ cmap="Greys_r",
+ )
+ ax3.scatter(
+ self.positions[:, 1],
+ self.positions[:, 0],
+ s=2.5,
+ color=(1, 0, 0, 1),
+ )
+ ax3.set_ylabel("x [A]")
+ ax3.set_xlabel("y [A]")
+ ax3.set_xlim((extent[0], extent[1]))
+ ax3.set_ylim((extent[2], extent[3]))
+ ax3.set_title("Object field of view")
+
+ fig.tight_layout()
+
+ self._preprocessed = True
+
+ if self._device == "gpu":
+ xp._default_memory_pool.free_all_blocks()
+ xp.clear_memo()
+
+ return self
+
+ def _overlap_projection(self, current_object, current_probe):
+ """
+ Ptychographic overlap projection method.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+
+ Returns
+ --------
+ propagated_probes: np.ndarray
+ Shifted probes at each layer
+ object_patches: np.ndarray
+ Patched object view
+ transmitted_probes: np.ndarray
+ Transmitted probes after N-1 propagations and N transmissions
+ """
+
+ xp = self._xp
+
+ if self._object_type == "potential":
+ complex_object = xp.exp(1j * current_object)
+ else:
+ complex_object = current_object
+
+ object_patches = complex_object[
+ :,
+ self._vectorized_patch_indices_row,
+ self._vectorized_patch_indices_col,
+ ]
+
+ num_probe_positions = object_patches.shape[1]
+
+ propagated_shape = (
+ self._num_slices,
+ num_probe_positions,
+ self._num_probes,
+ self._region_of_interest_shape[0],
+ self._region_of_interest_shape[1],
+ )
+ propagated_probes = xp.empty(propagated_shape, dtype=object_patches.dtype)
+ propagated_probes[0] = fft_shift(
+ current_probe, self._positions_px_fractional, xp
+ )
+
+ for s in range(self._num_slices):
+ # transmit
+ transmitted_probes = (
+ xp.expand_dims(object_patches[s], axis=1) * propagated_probes[s]
+ )
+
+ # propagate
+ if s + 1 < self._num_slices:
+ propagated_probes[s + 1] = self._propagate_array(
+ transmitted_probes, self._propagator_arrays[s]
+ )
+
+ return propagated_probes, object_patches, transmitted_probes
+
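+    # Array-shape bookkeeping for the mixed-state multislice forward model above
+    # (names follow this method; shapes follow the `propagated_shape` definition):
+    #
+    #     propagated_probes : (num_slices, num_positions, num_probes, roi_x, roi_y)
+    #     object_patches    : (num_slices, num_positions, roi_x, roi_y)
+    #     transmitted_probes: (num_positions, num_probes, roi_x, roi_y)
+    #
+    # Each slice s transmits object_patches[s], broadcast over the probe-mode
+    # axis, and propagates the result to slice s + 1; only the wave transmitted
+    # through the final slice is returned.
+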
+ def _gradient_descent_fourier_projection(self, amplitudes, transmitted_probes):
+ """
+        Ptychographic Fourier projection method for the gradient-descent (GD) method.
+
+ Parameters
+ --------
+ amplitudes: np.ndarray
+ Normalized measured amplitudes
+ transmitted_probes: np.ndarray
+ Transmitted probes after N-1 propagations and N transmissions
+
+ Returns
+ --------
+        exit_waves: np.ndarray
+ Exit wave difference
+ error: float
+ Reconstruction error
+ """
+
+ xp = self._xp
+ fourier_exit_waves = xp.fft.fft2(transmitted_probes)
+ intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_exit_waves) ** 2, axis=1))
+ error = xp.sum(xp.abs(amplitudes - intensity_norm) ** 2)
+
+ intensity_norm[intensity_norm == 0.0] = np.inf
+ amplitude_modification = amplitudes / intensity_norm
+
+ fourier_modified_overlap = amplitude_modification[:, None] * fourier_exit_waves
+ modified_exit_wave = xp.fft.ifft2(fourier_modified_overlap)
+
+ exit_waves = modified_exit_wave - transmitted_probes
+
+ return exit_waves, error
+
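+    # Minimal NumPy sketch of the amplitude projection above, for illustration:
+    # the mixed-state intensity is summed incoherently over the probe-mode axis
+    # before comparison, and every mode is rescaled by the same factor.
+    #
+    #     import numpy as np
+    #     waves = np.fft.fft2(transmitted_probes, axes=(-2, -1))
+    #     norm = np.sqrt(np.sum(np.abs(waves) ** 2, axis=1))
+    #     norm[norm == 0.0] = np.inf
+    #     waves *= (amplitudes / norm)[:, None]
+    #     exit_waves = np.fft.ifft2(waves, axes=(-2, -1)) - transmitted_probes
+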
+ def _projection_sets_fourier_projection(
+ self,
+ amplitudes,
+ transmitted_probes,
+ exit_waves,
+ projection_a,
+ projection_b,
+ projection_c,
+ ):
+ """
+        Ptychographic Fourier projection method for the DM_AP and RAAR methods.
+ Generalized projection using three parameters: a,b,c
+
+ DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha
+ DM: DM_AP(1.0), AP: DM_AP(0.0)
+
+ RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2
+ DM : RAAR(1.0)
+
+ RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2
+ DM: RRR(1.0)
+
+ SUPERFLIP : a = 0, b = 1, c = 2
+
+ Parameters
+ --------
+ amplitudes: np.ndarray
+ Normalized measured amplitudes
+ transmitted_probes: np.ndarray
+ Transmitted probes after N-1 propagations and N transmissions
+ exit_waves: np.ndarray
+            Previously estimated exit waves
+ projection_a: float
+ projection_b: float
+ projection_c: float
+
+ Returns
+ --------
+        exit_waves: np.ndarray
+ Updated exit wave difference
+ error: float
+ Reconstruction error
+ """
+
+ xp = self._xp
+ projection_x = 1 - projection_a - projection_b
+ projection_y = 1 - projection_c
+
+ if exit_waves is None:
+ exit_waves = transmitted_probes.copy()
+
+ fourier_exit_waves = xp.fft.fft2(transmitted_probes)
+ intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_exit_waves) ** 2, axis=1))
+ error = xp.sum(xp.abs(amplitudes - intensity_norm) ** 2)
+
+ factor_to_be_projected = (
+ projection_c * transmitted_probes + projection_y * exit_waves
+ )
+ fourier_projected_factor = xp.fft.fft2(factor_to_be_projected)
+
+ intensity_norm_projected = xp.sqrt(
+ xp.sum(xp.abs(fourier_projected_factor) ** 2, axis=1)
+ )
+ intensity_norm_projected[intensity_norm_projected == 0.0] = np.inf
+
+ amplitude_modification = amplitudes / intensity_norm_projected
+ fourier_projected_factor *= amplitude_modification[:, None]
+
+ projected_factor = xp.fft.ifft2(fourier_projected_factor)
+
+ exit_waves = (
+ projection_x * exit_waves
+ + projection_a * transmitted_probes
+ + projection_b * projected_factor
+ )
+
+ return exit_waves, error
+
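+    # The named projection-set schemes in the docstring above reduce to (a, b, c)
+    # triplets; a hypothetical helper mapping (reconstruct() below inlines the
+    # same logic when parsing reconstruction_method):
+    #
+    #     _PROJECTION_PRESETS = {
+    #         "DM_AP": lambda alpha: (-alpha, 1, 1 + alpha),  # DM: alpha=1, AP: alpha=0
+    #         "RAAR": lambda beta: (1 - 2 * beta, beta, 2),   # DM: beta=1
+    #         "RRR": lambda gamma: (-gamma, gamma, 2),        # DM: gamma=1
+    #         "SUPERFLIP": lambda _: (0, 1, 2),
+    #     }
+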
+ def _forward(
+ self,
+ current_object,
+ current_probe,
+ amplitudes,
+ exit_waves,
+ use_projection_scheme,
+ projection_a,
+ projection_b,
+ projection_c,
+ ):
+ """
+ Ptychographic forward operator.
+ Calls _overlap_projection() and the appropriate _fourier_projection().
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+ amplitudes: np.ndarray
+ Normalized measured amplitudes
+ exit_waves: np.ndarray
+            Previously estimated exit waves
+        use_projection_scheme: bool
+ If True, use generalized projection update
+ projection_a: float
+ projection_b: float
+ projection_c: float
+
+ Returns
+ --------
+ propagated_probes: np.ndarray
+ Shifted probes at each layer
+ object_patches: np.ndarray
+ Patched object view
+ transmitted_probes: np.ndarray
+ Transmitted probes after N-1 propagations and N transmissions
+        exit_waves: np.ndarray
+ Updated exit_waves
+ error: float
+ Reconstruction error
+ """
+
+ (
+ propagated_probes,
+ object_patches,
+ transmitted_probes,
+ ) = self._overlap_projection(current_object, current_probe)
+
+ if use_projection_scheme:
+ exit_waves, error = self._projection_sets_fourier_projection(
+ amplitudes,
+ transmitted_probes,
+ exit_waves,
+ projection_a,
+ projection_b,
+ projection_c,
+ )
+
+ else:
+ exit_waves, error = self._gradient_descent_fourier_projection(
+ amplitudes, transmitted_probes
+ )
+
+ return propagated_probes, object_patches, transmitted_probes, exit_waves, error
+
+ def _gradient_descent_adjoint(
+ self,
+ current_object,
+ current_probe,
+ object_patches,
+ propagated_probes,
+ exit_waves,
+ step_size,
+ normalization_min,
+ fix_probe,
+ ):
+ """
+ Ptychographic adjoint operator for GD method.
+ Computes object and probe update steps.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+ object_patches: np.ndarray
+ Patched object view
+ propagated_probes: np.ndarray
+ Shifted probes at each layer
+        exit_waves: np.ndarray
+ Updated exit_waves
+ step_size: float, optional
+ Update step size
+ normalization_min: float, optional
+ Probe normalization minimum as a fraction of the maximum overlap intensity
+ fix_probe: bool, optional
+ If True, probe will not be updated
+
+ Returns
+ --------
+ updated_object: np.ndarray
+ Updated object estimate
+ updated_probe: np.ndarray
+ Updated probe estimate
+ """
+ xp = self._xp
+
+ for s in reversed(range(self._num_slices)):
+ probe = propagated_probes[s]
+ obj = object_patches[s]
+
+ # object-update
+ probe_normalization = xp.zeros_like(current_object[s])
+ object_update = xp.zeros_like(current_object[s])
+
+ for i_probe in range(self._num_probes):
+ probe_normalization += self._sum_overlapping_patches_bincounts(
+ xp.abs(probe[:, i_probe]) ** 2
+ )
+
+ if self._object_type == "potential":
+ object_update += (
+ step_size
+ * self._sum_overlapping_patches_bincounts(
+ xp.real(
+ -1j
+ * xp.conj(obj)
+ * xp.conj(probe[:, i_probe])
+ * exit_waves[:, i_probe]
+ )
+ )
+ )
+ else:
+ object_update += (
+ step_size
+ * self._sum_overlapping_patches_bincounts(
+ xp.conj(probe[:, i_probe]) * exit_waves[:, i_probe]
+ )
+ )
+
+ probe_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * probe_normalization) ** 2
+ + (normalization_min * xp.max(probe_normalization)) ** 2
+ )
+
+ current_object[s] += object_update * probe_normalization
+
+ # back-transmit
+ exit_waves *= xp.expand_dims(xp.conj(obj), axis=1) # / xp.abs(obj) ** 2
+
+ if s > 0:
+ # back-propagate
+ exit_waves = self._propagate_array(
+ exit_waves, xp.conj(self._propagator_arrays[s - 1])
+ )
+ elif not fix_probe:
+ # probe-update
+ object_normalization = xp.sum(
+ (xp.abs(obj) ** 2),
+ axis=0,
+ )
+ object_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * object_normalization) ** 2
+ + (normalization_min * xp.max(object_normalization)) ** 2
+ )
+
+ current_probe += (
+ step_size
+ * xp.sum(
+ exit_waves,
+ axis=0,
+ )
+ * object_normalization[None]
+ )
+
+ return current_object, current_probe
+
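+    # The probe/object normalizations above are soft-clipped reciprocals of the
+    # summed overlap intensity n, with m = normalization_min (sketch):
+    #
+    #     w = 1 / sqrt(1e-16 + ((1 - m) * n) ** 2 + (m * n.max()) ** 2)
+    #
+    # For m = 0 this reduces to an ordinary 1 / n division; m > 0 damps updates
+    # in regions where the overlap intensity is far below its maximum.
+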
+ def _projection_sets_adjoint(
+ self,
+ current_object,
+ current_probe,
+ object_patches,
+ propagated_probes,
+ exit_waves,
+ normalization_min,
+ fix_probe,
+ ):
+ """
+ Ptychographic adjoint operator for DM_AP and RAAR methods.
+ Computes object and probe update steps.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+ object_patches: np.ndarray
+ Patched object view
+ propagated_probes: np.ndarray
+ Shifted probes at each layer
+        exit_waves: np.ndarray
+ Updated exit_waves
+ normalization_min: float, optional
+ Probe normalization minimum as a fraction of the maximum overlap intensity
+ fix_probe: bool, optional
+ If True, probe will not be updated
+
+ Returns
+ --------
+ updated_object: np.ndarray
+ Updated object estimate
+ updated_probe: np.ndarray
+ Updated probe estimate
+ """
+ xp = self._xp
+
+ # careful not to modify exit_waves in-place for projection set methods
+ exit_waves_copy = exit_waves.copy()
+ for s in reversed(range(self._num_slices)):
+ probe = propagated_probes[s]
+ obj = object_patches[s]
+
+ # object-update
+ probe_normalization = xp.zeros_like(current_object[s])
+ object_update = xp.zeros_like(current_object[s])
+
+ for i_probe in range(self._num_probes):
+ probe_normalization += self._sum_overlapping_patches_bincounts(
+ xp.abs(probe[:, i_probe]) ** 2
+ )
+
+ if self._object_type == "potential":
+ object_update += self._sum_overlapping_patches_bincounts(
+ xp.real(
+ -1j
+ * xp.conj(obj)
+ * xp.conj(probe[:, i_probe])
+ * exit_waves_copy[:, i_probe]
+ )
+ )
+ else:
+ object_update += self._sum_overlapping_patches_bincounts(
+ xp.conj(probe[:, i_probe]) * exit_waves_copy[:, i_probe]
+ )
+
+ probe_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * probe_normalization) ** 2
+ + (normalization_min * xp.max(probe_normalization)) ** 2
+ )
+
+ current_object[s] = object_update * probe_normalization
+
+ # back-transmit
+ exit_waves_copy *= xp.expand_dims(
+ xp.conj(obj), axis=1
+ ) # / xp.abs(obj) ** 2
+
+ if s > 0:
+ # back-propagate
+ exit_waves_copy = self._propagate_array(
+ exit_waves_copy, xp.conj(self._propagator_arrays[s - 1])
+ )
+
+ elif not fix_probe:
+ # probe-update
+ object_normalization = xp.sum(
+ (xp.abs(obj) ** 2),
+ axis=0,
+ )
+ object_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * object_normalization) ** 2
+ + (normalization_min * xp.max(object_normalization)) ** 2
+ )
+
+ current_probe = (
+ xp.sum(
+ exit_waves_copy,
+ axis=0,
+ )
+ * object_normalization[None]
+ )
+
+ return current_object, current_probe
+
+ def _adjoint(
+ self,
+ current_object,
+ current_probe,
+ object_patches,
+ propagated_probes,
+ exit_waves,
+ use_projection_scheme: bool,
+ step_size: float,
+ normalization_min: float,
+ fix_probe: bool,
+ ):
+ """
+        Ptychographic adjoint operator.
+        Dispatches to the projection-set or gradient-descent adjoint and
+        computes object and probe update steps.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+ object_patches: np.ndarray
+ Patched object view
+ propagated_probes: np.ndarray
+ Shifted probes at each layer
+        exit_waves: np.ndarray
+            Updated exit_waves
+        use_projection_scheme: bool
+ If True, use generalized projection update
+ step_size: float, optional
+ Update step size
+ normalization_min: float, optional
+ Probe normalization minimum as a fraction of the maximum overlap intensity
+ fix_probe: bool, optional
+ If True, probe will not be updated
+
+ Returns
+ --------
+ updated_object: np.ndarray
+ Updated object estimate
+ updated_probe: np.ndarray
+ Updated probe estimate
+ """
+
+ if use_projection_scheme:
+ current_object, current_probe = self._projection_sets_adjoint(
+ current_object,
+ current_probe,
+ object_patches,
+ propagated_probes,
+ exit_waves,
+ normalization_min,
+ fix_probe,
+ )
+ else:
+ current_object, current_probe = self._gradient_descent_adjoint(
+ current_object,
+ current_probe,
+ object_patches,
+ propagated_probes,
+ exit_waves,
+ step_size,
+ normalization_min,
+ fix_probe,
+ )
+
+ return current_object, current_probe
+
+ def _position_correction(
+ self,
+ current_object,
+ current_probe,
+ transmitted_probes,
+ amplitudes,
+ current_positions,
+ positions_step_size,
+ constrain_position_distance,
+ ):
+ """
+ Position correction using estimated intensity gradient.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+        current_probe: np.ndarray
+            Current probe estimate, fractionally shifted internally
+ transmitted_probes: np.ndarray
+ Transmitted probes after N-1 propagations and N transmissions
+ amplitudes: np.ndarray
+ Measured amplitudes
+ current_positions: np.ndarray
+ Current positions estimate
+ positions_step_size: float
+ Positions step size
+ constrain_position_distance: float
+ Distance to constrain position correction within original
+ field of view in A
+
+ Returns
+ --------
+ updated_positions: np.ndarray
+ Updated positions estimate
+ """
+
+ xp = self._xp
+
+ # Intensity gradient
+ exit_waves_fft = xp.fft.fft2(transmitted_probes)
+ exit_waves_fft_conj = xp.conj(exit_waves_fft)
+ estimated_intensity = xp.abs(exit_waves_fft) ** 2
+ measured_intensity = amplitudes**2
+
+ flat_shape = (transmitted_probes.shape[0], -1)
+ difference_intensity = (measured_intensity - estimated_intensity).reshape(
+ flat_shape
+ )
+
+ # Computing perturbed exit waves one at a time to save on memory
+
+ if self._object_type == "potential":
+ complex_object = xp.exp(1j * current_object)
+ else:
+ complex_object = current_object
+
+ # dx
+ obj_rolled_patches = complex_object[
+ :,
+ (self._vectorized_patch_indices_row + 1) % self._object_shape[0],
+ self._vectorized_patch_indices_col,
+ ]
+
+ propagated_probes_perturbed = xp.empty_like(obj_rolled_patches)
+ propagated_probes_perturbed[0] = fft_shift(
+ current_probe, self._positions_px_fractional, xp
+ )
+
+ for s in range(self._num_slices):
+ # transmit
+ transmitted_probes_perturbed = (
+ obj_rolled_patches[s] * propagated_probes_perturbed[s]
+ )
+
+ # propagate
+ if s + 1 < self._num_slices:
+ propagated_probes_perturbed[s + 1] = self._propagate_array(
+ transmitted_probes_perturbed, self._propagator_arrays[s]
+ )
+
+ exit_waves_dx_fft = exit_waves_fft - xp.fft.fft2(transmitted_probes_perturbed)
+
+ # dy
+ obj_rolled_patches = complex_object[
+ :,
+ self._vectorized_patch_indices_row,
+ (self._vectorized_patch_indices_col + 1) % self._object_shape[1],
+ ]
+
+ propagated_probes_perturbed = xp.empty_like(obj_rolled_patches)
+ propagated_probes_perturbed[0] = fft_shift(
+ current_probe, self._positions_px_fractional, xp
+ )
+
+ for s in range(self._num_slices):
+ # transmit
+ transmitted_probes_perturbed = (
+ obj_rolled_patches[s] * propagated_probes_perturbed[s]
+ )
+
+ # propagate
+ if s + 1 < self._num_slices:
+ propagated_probes_perturbed[s + 1] = self._propagate_array(
+ transmitted_probes_perturbed, self._propagator_arrays[s]
+ )
+
+ exit_waves_dy_fft = exit_waves_fft - xp.fft.fft2(transmitted_probes_perturbed)
+
+ partial_intensity_dx = 2 * xp.real(
+ exit_waves_dx_fft * exit_waves_fft_conj
+ ).reshape(flat_shape)
+ partial_intensity_dy = 2 * xp.real(
+ exit_waves_dy_fft * exit_waves_fft_conj
+ ).reshape(flat_shape)
+
+ coefficients_matrix = xp.dstack((partial_intensity_dx, partial_intensity_dy))
+
+ # positions_update = xp.einsum(
+ # "idk,ik->id", xp.linalg.pinv(coefficients_matrix), difference_intensity
+ # )
+
+ coefficients_matrix_T = coefficients_matrix.conj().swapaxes(-1, -2)
+ positions_update = (
+ xp.linalg.inv(coefficients_matrix_T @ coefficients_matrix)
+ @ coefficients_matrix_T
+ @ difference_intensity[..., None]
+ )
+
+ if constrain_position_distance is not None:
+ constrain_position_distance /= xp.sqrt(
+ self.sampling[0] ** 2 + self.sampling[1] ** 2
+ )
+ x1 = (current_positions - positions_step_size * positions_update[..., 0])[
+ :, 0
+ ]
+ y1 = (current_positions - positions_step_size * positions_update[..., 0])[
+ :, 1
+ ]
+ x0 = self._positions_px_initial[:, 0]
+ y0 = self._positions_px_initial[:, 1]
+ if self._rotation_best_transpose:
+ x0, y0 = xp.array([y0, x0])
+ x1, y1 = xp.array([y1, x1])
+
+ if self._rotation_best_rad is not None:
+ rotation_angle = self._rotation_best_rad
+ x0, y0 = x0 * xp.cos(-rotation_angle) + y0 * xp.sin(
+ -rotation_angle
+ ), -x0 * xp.sin(-rotation_angle) + y0 * xp.cos(-rotation_angle)
+ x1, y1 = x1 * xp.cos(-rotation_angle) + y1 * xp.sin(
+ -rotation_angle
+ ), -x1 * xp.sin(-rotation_angle) + y1 * xp.cos(-rotation_angle)
+
+ outlier_ind = (x1 > (xp.max(x0) + constrain_position_distance)) + (
+ x1 < (xp.min(x0) - constrain_position_distance)
+ ) + (y1 > (xp.max(y0) + constrain_position_distance)) + (
+ y1 < (xp.min(y0) - constrain_position_distance)
+ ) > 0
+
+ positions_update[..., 0][outlier_ind] = 0
+
+ current_positions -= positions_step_size * positions_update[..., 0]
+
+ return current_positions
+
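+    # The explicit normal-equations solve above is equivalent, per probe position,
+    # to the commented-out pinv/einsum formulation (sketch):
+    #
+    #     J = dstack(dI/dx, dI/dy)          # finite-difference intensity gradients
+    #     delta_I = I_measured - I_model    # flattened intensity difference
+    #     update = inv(J^T J) @ J^T @ delta_I
+    #
+    # i.e. a linearized least-squares fit of the positions to the intensity residual.
+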
+ def _probe_center_of_mass_constraint(self, current_probe):
+ """
+ Ptychographic center of mass constraint.
+ Used for centering corner-centered probe intensity.
+
+ Parameters
+ --------
+ current_probe: np.ndarray
+ Current probe estimate
+
+ Returns
+ --------
+ constrained_probe: np.ndarray
+ Constrained probe estimate
+ """
+ xp = self._xp
+ probe_intensity = xp.abs(current_probe[0]) ** 2
+
+ probe_x0, probe_y0 = get_CoM(
+ probe_intensity, device=self._device, corner_centered=True
+ )
+ shifted_probe = fft_shift(current_probe, -xp.array([probe_x0, probe_y0]), xp)
+
+ return shifted_probe
+
+ def _probe_orthogonalization_constraint(self, current_probe):
+ """
+ Ptychographic probe-orthogonalization constraint.
+ Used to ensure mixed states are orthogonal to each other.
+ Adapted from https://github.com/AdvancedPhotonSource/tike/blob/main/src/tike/ptycho/probe.py#L690
+
+ Parameters
+ --------
+ current_probe: np.ndarray
+ Current probe estimate
+
+ Returns
+ --------
+ constrained_probe: np.ndarray
+ Orthogonalized probe estimate
+ """
+ xp = self._xp
+ n_probes = self._num_probes
+
+ # compute upper half of P* @ P
+ pairwise_dot_product = xp.empty((n_probes, n_probes), dtype=current_probe.dtype)
+
+ for i in range(n_probes):
+ for j in range(i, n_probes):
+ pairwise_dot_product[i, j] = xp.sum(
+ current_probe[i].conj() * current_probe[j]
+ )
+
+ # compute eigenvectors (effectively cheaper way of computing V* from SVD)
+ _, evecs = xp.linalg.eigh(pairwise_dot_product, UPLO="U")
+ current_probe = xp.tensordot(evecs.T, current_probe, axes=1)
+
+ # sort by real-space intensity
+ intensities = xp.sum(xp.abs(current_probe) ** 2, axis=(-2, -1))
+ intensities_order = xp.argsort(intensities, axis=None)[::-1]
+ return current_probe[intensities_order]
+
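+    # Toy NumPy check of the orthogonalization above (illustrative; array sizes
+    # are arbitrary): after rotating by the eigenvectors of the Gram matrix, the
+    # pairwise inner products between probe modes become diagonal.
+    #
+    #     import numpy as np
+    #     probes = np.random.randn(3, 8, 8) + 1j * np.random.randn(3, 8, 8)
+    #     gram = np.einsum("ixy,jxy->ij", probes.conj(), probes)
+    #     _, evecs = np.linalg.eigh(gram)
+    #     ortho = np.tensordot(evecs.T, probes, axes=1)
+    #     new_gram = np.einsum("ixy,jxy->ij", ortho.conj(), ortho)
+    #     assert np.allclose(new_gram, np.diag(np.diag(new_gram)))
+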
+ def _object_butterworth_constraint(
+ self, current_object, q_lowpass, q_highpass, butterworth_order
+ ):
+ """
+ 2D Butterworth filter
+ Used for low/high-pass filtering object.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ q_lowpass: float
+ Cut-off frequency in A^-1 for low-pass butterworth filter
+ q_highpass: float
+ Cut-off frequency in A^-1 for high-pass butterworth filter
+ butterworth_order: float
+ Butterworth filter order. Smaller gives a smoother filter
+
+ Returns
+ --------
+ constrained_object: np.ndarray
+ Constrained object estimate
+ """
+ xp = self._xp
+ qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0])
+ qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1])
+ qya, qxa = xp.meshgrid(qy, qx)
+ qra = xp.sqrt(qxa**2 + qya**2)
+
+ env = xp.ones_like(qra)
+ if q_highpass:
+ env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * butterworth_order))
+ if q_lowpass:
+ env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order))
+
+ current_object_mean = xp.mean(current_object)
+ current_object -= current_object_mean
+ current_object = xp.fft.ifft2(xp.fft.fft2(current_object) * env[None])
+ current_object += current_object_mean
+
+ if self._object_type == "potential":
+ current_object = xp.real(current_object)
+
+ return current_object
+
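+    # The low-pass factor above is the standard Butterworth magnitude response,
+    # H(q) = 1 / (1 + (q / q_lowpass) ** (2 * butterworth_order)), and the
+    # high-pass factor is the complement 1 - H(q) evaluated at q_highpass.
+    # Subtracting the object mean before filtering and restoring it afterwards
+    # keeps the DC offset unchanged even when a high-pass is applied.
+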
+ def _object_kz_regularization_constraint(
+ self, current_object, kz_regularization_gamma
+ ):
+ """
+ Arctan regularization filter
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ kz_regularization_gamma: float
+ Slice regularization strength
+
+ Returns
+ --------
+ constrained_object: np.ndarray
+ Constrained object estimate
+ """
+ xp = self._xp
+
+ current_object = xp.pad(
+ current_object, pad_width=((1, 0), (0, 0), (0, 0)), mode="constant"
+ )
+
+ qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0])
+ qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1])
+ qz = xp.fft.fftfreq(current_object.shape[0], self._slice_thicknesses[0])
+
+ kz_regularization_gamma *= self._slice_thicknesses[0] / self.sampling[0]
+
+ qza, qxa, qya = xp.meshgrid(qz, qx, qy, indexing="ij")
+ qz2 = qza**2 * kz_regularization_gamma**2
+ qr2 = qxa**2 + qya**2
+
+ w = 1 - 2 / np.pi * xp.arctan2(qz2, qr2)
+
+ current_object = xp.fft.ifftn(xp.fft.fftn(current_object) * w)
+ current_object = current_object[1:]
+
+ if self._object_type == "potential":
+ current_object = xp.real(current_object)
+
+ return current_object
+
+ def _object_identical_slices_constraint(self, current_object):
+ """
+ Strong regularization forcing all slices to be identical
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+
+ Returns
+ --------
+ constrained_object: np.ndarray
+ Constrained object estimate
+ """
+ object_mean = current_object.mean(0, keepdims=True)
+ current_object[:] = object_mean
+
+ return current_object
+
+ def _object_denoise_tv_pylops(self, current_object, weights, iterations):
+ """
+        Performs second-order TV denoising along x and y, and first-order TV denoising along z
+
+ Parameters
+ ----------
+ current_object: np.ndarray
+ Current object estimate
+ weights : [float, float]
+ Denoising weights[z weight, r weight]. The greater `weight`,
+ the more denoising.
+        iterations: int
+ Number of iterations to run in denoising algorithm.
+ `niter_out` in pylops
+
+ Returns
+ -------
+ constrained_object: np.ndarray
+ Constrained object estimate
+
+ """
+ xp = self._xp
+
+ if xp.iscomplexobj(current_object):
+ current_object_tv = current_object
+ warnings.warn(
+ ("TV denoising is currently only supported for potential objects."),
+ UserWarning,
+ )
+
+ else:
+ # zero pad at top and bottom slice
+ pad_width = ((1, 1), (0, 0), (0, 0))
+ current_object = xp.pad(
+ current_object, pad_width=pad_width, mode="constant"
+ )
+
+ # run tv denoising
+ nz, nx, ny = current_object.shape
+ niter_out = iterations
+ niter_in = 1
+ Iop = pylops.Identity(nx * ny * nz)
+
+ if weights[0] == 0:
+ xy_laplacian = pylops.Laplacian(
+ (nz, nx, ny), axes=(1, 2), edge=False, kind="backward"
+ )
+ l1_regs = [xy_laplacian]
+
+ current_object_tv = pylops.optimization.sparsity.splitbregman(
+ Op=Iop,
+ y=current_object.ravel(),
+ RegsL1=l1_regs,
+ niter_outer=niter_out,
+ niter_inner=niter_in,
+ epsRL1s=[weights[1]],
+ tol=1e-4,
+ tau=1.0,
+ show=False,
+ )[0]
+
+ elif weights[1] == 0:
+ z_gradient = pylops.FirstDerivative(
+ (nz, nx, ny), axis=0, edge=False, kind="backward"
+ )
+ l1_regs = [z_gradient]
+
+ current_object_tv = pylops.optimization.sparsity.splitbregman(
+ Op=Iop,
+ y=current_object.ravel(),
+ RegsL1=l1_regs,
+ niter_outer=niter_out,
+ niter_inner=niter_in,
+ epsRL1s=[weights[0]],
+ tol=1e-4,
+ tau=1.0,
+ show=False,
+ )[0]
+
+ else:
+ z_gradient = pylops.FirstDerivative(
+ (nz, nx, ny), axis=0, edge=False, kind="backward"
+ )
+ xy_laplacian = pylops.Laplacian(
+ (nz, nx, ny), axes=(1, 2), edge=False, kind="backward"
+ )
+ l1_regs = [z_gradient, xy_laplacian]
+
+ current_object_tv = pylops.optimization.sparsity.splitbregman(
+ Op=Iop,
+ y=current_object.ravel(),
+ RegsL1=l1_regs,
+ niter_outer=niter_out,
+ niter_inner=niter_in,
+ epsRL1s=weights,
+ tol=1e-4,
+ tau=1.0,
+ show=False,
+ )[0]
+
+ # remove padding
+ current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1]
+
+ return current_object_tv
+
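+    # Standalone sketch of the split-Bregman call used above (array sizes are
+    # illustrative; the pylops operators and keyword arguments mirror the ones
+    # in this method):
+    #
+    #     import numpy as np
+    #     import pylops
+    #
+    #     nz, nx, ny = 4, 32, 32
+    #     vol = np.random.rand(nz, nx, ny)
+    #     Iop = pylops.Identity(nz * nx * ny)
+    #     z_gradient = pylops.FirstDerivative(
+    #         (nz, nx, ny), axis=0, edge=False, kind="backward"
+    #     )
+    #     denoised = pylops.optimization.sparsity.splitbregman(
+    #         Op=Iop,
+    #         y=vol.ravel(),
+    #         RegsL1=[z_gradient],
+    #         niter_outer=10,
+    #         niter_inner=1,
+    #         epsRL1s=[0.1],
+    #         tol=1e-4,
+    #         tau=1.0,
+    #         show=False,
+    #     )[0].reshape(nz, nx, ny)
+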
+ def _constraints(
+ self,
+ current_object,
+ current_probe,
+ current_positions,
+ fix_com,
+ fit_probe_aberrations,
+ fit_probe_aberrations_max_angular_order,
+ fit_probe_aberrations_max_radial_order,
+ constrain_probe_amplitude,
+ constrain_probe_amplitude_relative_radius,
+ constrain_probe_amplitude_relative_width,
+ constrain_probe_fourier_amplitude,
+ constrain_probe_fourier_amplitude_max_width_pixels,
+ constrain_probe_fourier_amplitude_constant_intensity,
+ fix_probe_aperture,
+ initial_probe_aperture,
+ fix_positions,
+ global_affine_transformation,
+ gaussian_filter,
+ gaussian_filter_sigma,
+ butterworth_filter,
+ q_lowpass,
+ q_highpass,
+ butterworth_order,
+ kz_regularization_filter,
+ kz_regularization_gamma,
+ identical_slices,
+ object_positivity,
+ shrinkage_rad,
+ object_mask,
+ pure_phase_object,
+ tv_denoise_chambolle,
+ tv_denoise_weight_chambolle,
+ tv_denoise_pad_chambolle,
+ tv_denoise,
+ tv_denoise_weights,
+ tv_denoise_inner_iter,
+ orthogonalize_probe,
+ ):
+ """
+ Ptychographic constraints operator.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+ current_positions: np.ndarray
+ Current positions estimate
+ fix_com: bool
+ If True, probe CoM is fixed to the center
+ fit_probe_aberrations: bool
+ If True, fits the probe aberrations to a low-order expansion
+ fit_probe_aberrations_max_angular_order: bool
+ Max angular order of probe aberrations basis functions
+ fit_probe_aberrations_max_radial_order: bool
+ Max radial order of probe aberrations basis functions
+ constrain_probe_amplitude: bool
+ If True, probe amplitude is constrained by top hat function
+ constrain_probe_amplitude_relative_radius: float
+ Relative location of top-hat inflection point, between 0 and 0.5
+ constrain_probe_amplitude_relative_width: float
+ Relative width of top-hat sigmoid, between 0 and 0.5
+ constrain_probe_fourier_amplitude: bool
+ If True, probe aperture is constrained by fitting a sigmoid for each angular frequency.
+ constrain_probe_fourier_amplitude_max_width_pixels: float
+ Maximum pixel width of fitted sigmoid functions.
+ constrain_probe_fourier_amplitude_constant_intensity: bool
+ If True, the probe aperture is additionally constrained to a constant intensity.
+ fix_probe_aperture: bool
+ If True, probe Fourier amplitude is replaced by initial_probe_aperture
+ initial_probe_aperture: np.ndarray
+ Initial probe aperture to use in replacing probe Fourier amplitude
+        fix_positions: bool
+            If True, positions are not updated
+        global_affine_transformation: bool
+            If True, positions are assumed to be a global affine transform from initial scan
+ gaussian_filter: bool
+ If True, applies real-space gaussian filter in A
+ gaussian_filter_sigma: float
+ Standard deviation of gaussian kernel
+ butterworth_filter: bool
+ If True, applies fourier-space butterworth filter
+ q_lowpass: float
+ Cut-off frequency in A^-1 for low-pass butterworth filter
+ q_highpass: float
+ Cut-off frequency in A^-1 for high-pass butterworth filter
+ butterworth_order: float
+ Butterworth filter order. Smaller gives a smoother filter
+ kz_regularization_filter: bool
+ If True, applies fourier-space arctan regularization filter
+ kz_regularization_gamma: float
+ Slice regularization strength
+ identical_slices: bool
+ If True, forces all object slices to be identical
+ object_positivity: bool
+ If True, forces object to be positive
+ shrinkage_rad: float
+ Phase shift in radians to be subtracted from the potential at each iteration
+ object_mask: np.ndarray (boolean)
+ If not None, used to calculate additional shrinkage using masked-mean of object
+ pure_phase_object: bool
+ If True, object amplitude is set to unity
+ tv_denoise_chambolle: bool
+ If True, performs TV denoising along z
+        tv_denoise_weight_chambolle: float
+            Weight of the TV denoising constraint
+        tv_denoise_pad_chambolle: bool
+            If True, pads object at top and bottom with zeros before applying denoising
+ tv_denoise: bool
+ If True, applies TV denoising on object
+ tv_denoise_weights: [float,float]
+ Denoising weights[z weight, r weight]. The greater `weight`,
+ the more denoising.
+        tv_denoise_inner_iter: int
+ Number of iterations to run in inner loop of TV denoising
+ orthogonalize_probe: bool
+ If True, probe will be orthogonalized
+
+ Returns
+ --------
+ constrained_object: np.ndarray
+ Constrained object estimate
+ constrained_probe: np.ndarray
+ Constrained probe estimate
+ constrained_positions: np.ndarray
+ Constrained positions estimate
+ """
+
+ if gaussian_filter:
+ current_object = self._object_gaussian_constraint(
+ current_object, gaussian_filter_sigma, pure_phase_object
+ )
+
+ if butterworth_filter:
+ current_object = self._object_butterworth_constraint(
+ current_object,
+ q_lowpass,
+ q_highpass,
+ butterworth_order,
+ )
+
+ if identical_slices:
+ current_object = self._object_identical_slices_constraint(current_object)
+ elif kz_regularization_filter:
+ current_object = self._object_kz_regularization_constraint(
+ current_object, kz_regularization_gamma
+ )
+ elif tv_denoise:
+ current_object = self._object_denoise_tv_pylops(
+ current_object,
+ tv_denoise_weights,
+ tv_denoise_inner_iter,
+ )
+ elif tv_denoise_chambolle:
+ current_object = self._object_denoise_tv_chambolle(
+ current_object,
+ tv_denoise_weight_chambolle,
+ axis=0,
+ pad_object=tv_denoise_pad_chambolle,
+ )
+
+ if shrinkage_rad > 0.0 or object_mask is not None:
+ current_object = self._object_shrinkage_constraint(
+ current_object,
+ shrinkage_rad,
+ object_mask,
+ )
+
+ if self._object_type == "complex":
+ current_object = self._object_threshold_constraint(
+ current_object, pure_phase_object
+ )
+ elif object_positivity:
+ current_object = self._object_positivity_constraint(current_object)
+
+ if fix_com:
+ current_probe = self._probe_center_of_mass_constraint(current_probe)
+
+ # These constraints don't _really_ make sense for mixed-state
+ if fix_probe_aperture:
+ raise NotImplementedError()
+ elif constrain_probe_fourier_amplitude:
+ raise NotImplementedError()
+ if fit_probe_aberrations:
+ raise NotImplementedError()
+ if constrain_probe_amplitude:
+ raise NotImplementedError()
+
+ if orthogonalize_probe:
+ current_probe = self._probe_orthogonalization_constraint(current_probe)
+
+ if not fix_positions:
+ current_positions = self._positions_center_of_mass_constraint(
+ current_positions
+ )
+
+ if global_affine_transformation:
+ current_positions = self._positions_affine_transformation_constraint(
+ self._positions_px_initial, current_positions
+ )
+
+ return current_object, current_probe, current_positions
+
+ def reconstruct(
+ self,
+ max_iter: int = 64,
+ reconstruction_method: str = "gradient-descent",
+ reconstruction_parameter: float = 1.0,
+ reconstruction_parameter_a: float = None,
+ reconstruction_parameter_b: float = None,
+ reconstruction_parameter_c: float = None,
+ max_batch_size: int = None,
+ seed_random: int = None,
+ step_size: float = 0.5,
+ normalization_min: float = 1,
+ positions_step_size: float = 0.9,
+ fix_com: bool = True,
+ orthogonalize_probe: bool = True,
+ fix_probe_iter: int = 0,
+ fix_probe_aperture_iter: int = 0,
+ constrain_probe_amplitude_iter: int = 0,
+ constrain_probe_amplitude_relative_radius: float = 0.5,
+ constrain_probe_amplitude_relative_width: float = 0.05,
+ constrain_probe_fourier_amplitude_iter: int = 0,
+ constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0,
+ constrain_probe_fourier_amplitude_constant_intensity: bool = False,
+ fix_positions_iter: int = np.inf,
+ constrain_position_distance: float = None,
+ global_affine_transformation: bool = True,
+ gaussian_filter_sigma: float = None,
+ gaussian_filter_iter: int = np.inf,
+ fit_probe_aberrations_iter: int = 0,
+ fit_probe_aberrations_max_angular_order: int = 4,
+ fit_probe_aberrations_max_radial_order: int = 4,
+ butterworth_filter_iter: int = np.inf,
+ q_lowpass: float = None,
+ q_highpass: float = None,
+ butterworth_order: float = 2,
+ kz_regularization_filter_iter: int = np.inf,
+ kz_regularization_gamma: Union[float, np.ndarray] = None,
+ identical_slices_iter: int = 0,
+ object_positivity: bool = True,
+ shrinkage_rad: float = 0.0,
+ fix_potential_baseline: bool = True,
+ pure_phase_object_iter: int = 0,
+ tv_denoise_iter_chambolle=np.inf,
+ tv_denoise_weight_chambolle=None,
+ tv_denoise_pad_chambolle=True,
+ tv_denoise_iter=np.inf,
+ tv_denoise_weights=None,
+ tv_denoise_inner_iter=40,
+ switch_object_iter: int = np.inf,
+ store_iterations: bool = False,
+ progress_bar: bool = True,
+ reset: bool = None,
+ ):
+ """
+ Ptychographic reconstruction main method.
+
+ Parameters
+ --------
+ max_iter: int, optional
+ Maximum number of iterations to run
+ reconstruction_method: str, optional
+ Specifies which reconstruction algorithm to use, one of:
+ "generalized-projections",
+ "DM_AP" (or "difference-map_alternating-projections"),
+ "RAAR" (or "relaxed-averaged-alternating-reflections"),
+ "RRR" (or "relax-reflect-reflect"),
+ "SUPERFLIP" (or "charge-flipping"), or
+ "GD" (or "gradient_descent")
+ reconstruction_parameter: float, optional
+ Reconstruction parameter for various reconstruction methods above.
+ reconstruction_parameter_a: float, optional
+ Reconstruction parameter a for reconstruction_method='generalized-projections'.
+ reconstruction_parameter_b: float, optional
+ Reconstruction parameter b for reconstruction_method='generalized-projections'.
+ reconstruction_parameter_c: float, optional
+ Reconstruction parameter c for reconstruction_method='generalized-projections'.
+ max_batch_size: int, optional
+ Max number of probes to update at once
+ seed_random: int, optional
+ Seeds the random number generator, only applicable when max_batch_size is not None
+ step_size: float, optional
+ Update step size
+ normalization_min: float, optional
+ Probe normalization minimum as a fraction of the maximum overlap intensity
+ positions_step_size: float, optional
+ Positions update step size
+        fix_com: bool, optional
+            If True, fixes center of mass of probe
+        orthogonalize_probe: bool, optional
+            If True, probe will be orthogonalized
+ fix_probe_iter: int, optional
+ Number of iterations to run with a fixed probe before updating probe estimate
+ fix_probe_aperture_iter: int, optional
+ Number of iterations to run with a fixed probe Fourier amplitude before updating probe estimate
+ constrain_probe_amplitude_iter: int, optional
+ Number of iterations to run while constraining the real-space probe with a top-hat support.
+ constrain_probe_amplitude_relative_radius: float
+ Relative location of top-hat inflection point, between 0 and 0.5
+ constrain_probe_amplitude_relative_width: float
+ Relative width of top-hat sigmoid, between 0 and 0.5
+ constrain_probe_fourier_amplitude_iter: int, optional
+ Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency.
+ constrain_probe_fourier_amplitude_max_width_pixels: float
+ Maximum pixel width of fitted sigmoid functions.
+ constrain_probe_fourier_amplitude_constant_intensity: bool
+ If True, the probe aperture is additionally constrained to a constant intensity.
+        fix_positions_iter: int, optional
+            Number of iterations to run with fixed positions before updating positions estimate
+        constrain_position_distance: float, optional
+            Distance to constrain position correction within original field of view in A
+        global_affine_transformation: bool, optional
+            If True, positions are assumed to be a global affine transform from initial scan
+ gaussian_filter_sigma: float, optional
+ Standard deviation of gaussian kernel in A
+ gaussian_filter_iter: int, optional
+ Number of iterations to run using object smoothness constraint
+ fit_probe_aberrations_iter: int, optional
+ Number of iterations to run while fitting the probe aberrations to a low-order expansion
+ fit_probe_aberrations_max_angular_order: bool
+ Max angular order of probe aberrations basis functions
+ fit_probe_aberrations_max_radial_order: bool
+ Max radial order of probe aberrations basis functions
+ butterworth_filter_iter: int, optional
+            Number of iterations to run using high-pass butterworth filter
+ q_lowpass: float
+ Cut-off frequency in A^-1 for low-pass butterworth filter
+ q_highpass: float
+ Cut-off frequency in A^-1 for high-pass butterworth filter
+ butterworth_order: float
+ Butterworth filter order. Smaller gives a smoother filter
+ kz_regularization_filter_iter: int, optional
+ Number of iterations to run using kz regularization filter
+        kz_regularization_gamma: float, optional
+ kz regularization strength
+ identical_slices_iter: int, optional
+ Number of iterations to run using identical slices
+ object_positivity: bool, optional
+ If True, forces object to be positive
+ shrinkage_rad: float
+ Phase shift in radians to be subtracted from the potential at each iteration
+ fix_potential_baseline: bool
+ If true, the potential mean outside the FOV is forced to zero at each iteration
+ pure_phase_object_iter: int, optional
+ Number of iterations where object amplitude is set to unity
+        tv_denoise_iter_chambolle: int, optional
+            Number of iterations with Chambolle TV denoising
+        tv_denoise_weight_chambolle: float
+            Weight of the TV denoising constraint
+        tv_denoise_pad_chambolle: bool
+            If True, pads object at top and bottom with zeros before applying denoising
+        tv_denoise_iter: int, optional
+            Number of iterations to run using TV denoising on the object
+ tv_denoise_weights: [float,float]
+ Denoising weights[z weight, r weight]. The greater `weight`,
+ the more denoising.
+        tv_denoise_inner_iter: int
+ Number of iterations to run in inner loop of TV denoising
+ switch_object_iter: int, optional
+ Iteration to switch object type between 'complex' and 'potential' or between
+ 'potential' and 'complex'
+ store_iterations: bool, optional
+ If True, reconstructed objects and probes are stored at each iteration
+ progress_bar: bool, optional
+ If True, reconstruction progress is displayed
+ reset: bool, optional
+ If True, previous reconstructions are ignored
+
+ Returns
+ --------
+        self: MixedstateMultislicePtychographicReconstruction
+ Self to accommodate chaining
+ """
+ asnumpy = self._asnumpy
+ xp = self._xp
+
+ # Reconstruction method
+
+ if reconstruction_method == "generalized-projections":
+ if (
+ reconstruction_parameter_a is None
+ or reconstruction_parameter_b is None
+ or reconstruction_parameter_c is None
+ ):
+ raise ValueError(
+ (
+ "reconstruction_parameter_a/b/c must all be specified "
+ "when using reconstruction_method='generalized-projections'."
+ )
+ )
+
+ use_projection_scheme = True
+ projection_a = reconstruction_parameter_a
+ projection_b = reconstruction_parameter_b
+ projection_c = reconstruction_parameter_c
+ step_size = None
+ elif (
+ reconstruction_method == "DM_AP"
+ or reconstruction_method == "difference-map_alternating-projections"
+ ):
+ if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0:
+ raise ValueError("reconstruction_parameter must be between 0-1.")
+
+ use_projection_scheme = True
+ projection_a = -reconstruction_parameter
+ projection_b = 1
+ projection_c = 1 + reconstruction_parameter
+ step_size = None
+ elif (
+ reconstruction_method == "RAAR"
+ or reconstruction_method == "relaxed-averaged-alternating-reflections"
+ ):
+ if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0:
+ raise ValueError("reconstruction_parameter must be between 0-1.")
+
+ use_projection_scheme = True
+ projection_a = 1 - 2 * reconstruction_parameter
+ projection_b = reconstruction_parameter
+ projection_c = 2
+ step_size = None
+ elif (
+ reconstruction_method == "RRR"
+ or reconstruction_method == "relax-reflect-reflect"
+ ):
+ if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0:
+ raise ValueError("reconstruction_parameter must be between 0-2.")
+
+ use_projection_scheme = True
+ projection_a = -reconstruction_parameter
+ projection_b = reconstruction_parameter
+ projection_c = 2
+ step_size = None
+ elif (
+ reconstruction_method == "SUPERFLIP"
+ or reconstruction_method == "charge-flipping"
+ ):
+ use_projection_scheme = True
+ projection_a = 0
+ projection_b = 1
+ projection_c = 2
+ reconstruction_parameter = None
+ step_size = None
+ elif (
+ reconstruction_method == "GD" or reconstruction_method == "gradient-descent"
+ ):
+ use_projection_scheme = False
+ projection_a = None
+ projection_b = None
+ projection_c = None
+ reconstruction_parameter = None
+ else:
+ raise ValueError(
+ (
+ "reconstruction_method must be one of 'generalized-projections', "
+ "'DM_AP' (or 'difference-map_alternating-projections'), "
+ "'RAAR' (or 'relaxed-averaged-alternating-reflections'), "
+ "'RRR' (or 'relax-reflect-reflect'), "
+ "'SUPERFLIP' (or 'charge-flipping'), "
+ f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}."
+ )
+ )
+
+        # Batching is inconsistent with the projection-set methods; raise
+        # regardless of verbosity rather than only when verbose printing is on.
+        if max_batch_size is not None and use_projection_scheme:
+            raise ValueError(
+                (
+                    "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. "
+                    "Use reconstruction_method='GD' or set max_batch_size=None."
+                )
+            )
+
+        if self._verbose:
+            if switch_object_iter > max_iter:
+                first_line = f"Performing {max_iter} iterations using a {self._object_type} object type, "
+            else:
+                switch_object_type = (
+                    "complex" if self._object_type == "potential" else "potential"
+                )
+                first_line = (
+                    f"Performing {switch_object_iter} iterations using a {self._object_type} object type and "
+                    f"{max_iter - switch_object_iter} iterations using a {switch_object_type} object type, "
+                )
+            if max_batch_size is not None:
+                print(
+                    (
+                        first_line + f"with the {reconstruction_method} algorithm, "
+                        f"with normalization_min: {normalization_min} and step_size: {step_size}, "
+                        f"in batches of max {max_batch_size} measurements."
+                    )
+                )
+
+ else:
+ if reconstruction_parameter is not None:
+ if np.array(reconstruction_parameter).shape == (3,):
+ print(
+ (
+ first_line
+ + f"with the {reconstruction_method} algorithm, "
+ f"with normalization_min: {normalization_min} and (a,b,c): {reconstruction_parameter}."
+ )
+ )
+ else:
+ print(
+ (
+ first_line
+ + f"with the {reconstruction_method} algorithm, "
+ f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}."
+ )
+ )
+ else:
+                if step_size is not None:
+                    print(
+                        (
+                            first_line
+                            + f"with the {reconstruction_method} algorithm, "
+                            f"with normalization_min: {normalization_min} and step_size: {step_size}."
+                        )
+                    )
+                else:
+                    print(
+                        (
+                            first_line
+                            + f"with the {reconstruction_method} algorithm, "
+                            f"with normalization_min: {normalization_min}."
+                        )
+                    )
+
+ # Batching
+ shuffled_indices = np.arange(self._num_diffraction_patterns)
+ unshuffled_indices = np.zeros_like(shuffled_indices)
+
+ if max_batch_size is not None:
+            # the shuffle below uses NumPy, so seed NumPy's generator
+            np.random.seed(seed_random)
+ else:
+ max_batch_size = self._num_diffraction_patterns
+
+ # initialization
+ if store_iterations and (not hasattr(self, "object_iterations") or reset):
+ self.object_iterations = []
+ self.probe_iterations = []
+
+ if reset:
+ self.error_iterations = []
+ self._object = self._object_initial.copy()
+ self._probe = self._probe_initial.copy()
+ self._positions_px = self._positions_px_initial.copy()
+ self._positions_px_fractional = self._positions_px - xp.round(
+ self._positions_px
+ )
+ (
+ self._vectorized_patch_indices_row,
+ self._vectorized_patch_indices_col,
+ ) = self._extract_vectorized_patch_indices()
+ self._exit_waves = None
+ self._object_type = self._object_type_initial
+ if hasattr(self, "_tf"):
+ del self._tf
+ elif reset is None:
+ if hasattr(self, "error"):
+ warnings.warn(
+ (
+ "Continuing reconstruction from previous result. "
+ "Use reset=True for a fresh start."
+ ),
+ UserWarning,
+ )
+ else:
+ self.error_iterations = []
+ self._exit_waves = None
+
+ # main loop
+ for a0 in tqdmnd(
+ max_iter,
+ desc="Reconstructing object and probe",
+ unit=" iter",
+ disable=not progress_bar,
+ ):
+ error = 0.0
+
+ if a0 == switch_object_iter:
+ if self._object_type == "potential":
+ self._object_type = "complex"
+ self._object = xp.exp(1j * self._object)
+ elif self._object_type == "complex":
+ self._object_type = "potential"
+ self._object = xp.angle(self._object)
+
+ # randomize
+ if not use_projection_scheme:
+ np.random.shuffle(shuffled_indices)
+ unshuffled_indices[shuffled_indices] = np.arange(
+ self._num_diffraction_patterns
+ )
+ positions_px = self._positions_px.copy()[shuffled_indices]
+
+ for start, end in generate_batches(
+ self._num_diffraction_patterns, max_batch=max_batch_size
+ ):
+ # batch indices
+ self._positions_px = positions_px[start:end]
+ self._positions_px_fractional = self._positions_px - xp.round(
+ self._positions_px
+ )
+ (
+ self._vectorized_patch_indices_row,
+ self._vectorized_patch_indices_col,
+ ) = self._extract_vectorized_patch_indices()
+ amplitudes = self._amplitudes[shuffled_indices[start:end]]
+
+ # forward operator
+ (
+ propagated_probes,
+ object_patches,
+ self._transmitted_probes,
+ self._exit_waves,
+ batch_error,
+ ) = self._forward(
+ self._object,
+ self._probe,
+ amplitudes,
+ self._exit_waves,
+ use_projection_scheme,
+ projection_a,
+ projection_b,
+ projection_c,
+ )
+
+ # adjoint operator
+ self._object, self._probe = self._adjoint(
+ self._object,
+ self._probe,
+ object_patches,
+ propagated_probes,
+ self._exit_waves,
+ use_projection_scheme=use_projection_scheme,
+ step_size=step_size,
+ normalization_min=normalization_min,
+ fix_probe=a0 < fix_probe_iter,
+ )
+
+ # position correction
+ if a0 >= fix_positions_iter:
+ positions_px[start:end] = self._position_correction(
+ self._object,
+ self._probe[0],
+ self._transmitted_probes[:, 0],
+ amplitudes,
+ self._positions_px,
+ positions_step_size,
+ constrain_position_distance,
+ )
+
+ error += batch_error
+
+ # Normalize Error
+ error /= self._mean_diffraction_intensity * self._num_diffraction_patterns
+
+ # constraints
+ self._positions_px = positions_px.copy()[unshuffled_indices]
+ self._object, self._probe, self._positions_px = self._constraints(
+ self._object,
+ self._probe,
+ self._positions_px,
+ fix_com=fix_com and a0 >= fix_probe_iter,
+ constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter
+ and a0 >= fix_probe_iter,
+ constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius,
+ constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width,
+ constrain_probe_fourier_amplitude=a0
+ < constrain_probe_fourier_amplitude_iter
+ and a0 >= fix_probe_iter,
+ constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels,
+ constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity,
+ fit_probe_aberrations=a0 < fit_probe_aberrations_iter
+ and a0 >= fix_probe_iter,
+ fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order,
+ fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order,
+ fix_probe_aperture=a0 < fix_probe_aperture_iter,
+ initial_probe_aperture=self._probe_initial_aperture,
+ fix_positions=a0 < fix_positions_iter,
+ global_affine_transformation=global_affine_transformation,
+ gaussian_filter=a0 < gaussian_filter_iter
+ and gaussian_filter_sigma is not None,
+ gaussian_filter_sigma=gaussian_filter_sigma,
+ butterworth_filter=a0 < butterworth_filter_iter
+ and (q_lowpass is not None or q_highpass is not None),
+ q_lowpass=q_lowpass,
+ q_highpass=q_highpass,
+ butterworth_order=butterworth_order,
+ kz_regularization_filter=a0 < kz_regularization_filter_iter
+ and kz_regularization_gamma is not None,
+ kz_regularization_gamma=kz_regularization_gamma[a0]
+ if kz_regularization_gamma is not None
+ and isinstance(kz_regularization_gamma, np.ndarray)
+ else kz_regularization_gamma,
+ identical_slices=a0 < identical_slices_iter,
+ object_positivity=object_positivity,
+ shrinkage_rad=shrinkage_rad,
+ object_mask=self._object_fov_mask_inverse
+ if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0
+ else None,
+ pure_phase_object=a0 < pure_phase_object_iter
+ and self._object_type == "complex",
+ tv_denoise_chambolle=a0 < tv_denoise_iter_chambolle
+ and tv_denoise_weight_chambolle is not None,
+ tv_denoise_weight_chambolle=tv_denoise_weight_chambolle,
+ tv_denoise_pad_chambolle=tv_denoise_pad_chambolle,
+ tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None,
+ tv_denoise_weights=tv_denoise_weights,
+ tv_denoise_inner_iter=tv_denoise_inner_iter,
+ orthogonalize_probe=orthogonalize_probe,
+ )
+
+ self.error_iterations.append(error.item())
+ if store_iterations:
+ self.object_iterations.append(asnumpy(self._object.copy()))
+ self.probe_iterations.append(self.probe_centered)
+
+ # store result
+ self.object = asnumpy(self._object)
+ self.probe = self.probe_centered
+ self.error = error.item()
+
+ if self._device == "gpu":
+ xp._default_memory_pool.free_all_blocks()
+ xp.clear_memo()
+
+ return self
+
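+    # Example usage (a sketch; `ptycho` is a preprocessed instance of this class
+    # and all keyword values are illustrative choices, not recommendations):
+    #
+    #     ptycho = ptycho.reconstruct(
+    #         max_iter=32,
+    #         reconstruction_method="gradient-descent",
+    #         step_size=0.5,
+    #         max_batch_size=256,
+    #         gaussian_filter_sigma=0.3,
+    #         gaussian_filter_iter=16,
+    #         store_iterations=True,
+    #     )
+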
+ def _visualize_last_iteration_figax(
+ self,
+ fig,
+ object_ax,
+ convergence_ax,
+ cbar: bool,
+ padding: int = 0,
+ **kwargs,
+ ):
+ """
+ Displays last reconstructed object on a given fig/ax.
+
+ Parameters
+ --------
+ fig: Figure
+ Matplotlib figure object_ax lives in
+ object_ax: Axes
+ Matplotlib axes to plot reconstructed object in
+ convergence_ax: Axes, optional
+ Matplotlib axes to plot convergence plot in
+ cbar: bool, optional
+ If true, displays a colorbar
+ padding : int, optional
+ Pixels to pad by post rotating-cropping object
+ """
+ cmap = kwargs.pop("cmap", "magma")
+
+ if self._object_type == "complex":
+ obj = np.angle(self.object)
+ else:
+ obj = self.object
+
+ rotated_object = self._crop_rotate_object_fov(
+ np.sum(obj, axis=0), padding=padding
+ )
+ rotated_shape = rotated_object.shape
+
+ extent = [
+ 0,
+ self.sampling[1] * rotated_shape[1],
+ self.sampling[0] * rotated_shape[0],
+ 0,
+ ]
+
+ im = object_ax.imshow(
+ rotated_object,
+ extent=extent,
+ cmap=cmap,
+ **kwargs,
+ )
+
+ if cbar:
+ divider = make_axes_locatable(object_ax)
+ ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
+ fig.add_axes(ax_cb)
+ fig.colorbar(im, cax=ax_cb)
+
+ if convergence_ax is not None and hasattr(self, "error_iterations"):
+            kwargs.pop("vmin", None)
+            kwargs.pop("vmax", None)
+            errors = np.array(self.error_iterations)
+
+ convergence_ax.semilogy(np.arange(len(errors)), errors, **kwargs)
+
+ def _visualize_last_iteration(
+ self,
+ fig,
+ cbar: bool,
+ plot_convergence: bool,
+ plot_probe: bool,
+ plot_fourier_probe: bool,
+ padding: int,
+ **kwargs,
+ ):
+ """
+ Displays last reconstructed object and probe iterations.
+
+ Parameters
+ --------
+ fig: Figure
+ Matplotlib figure to place Gridspec in
+ plot_convergence: bool, optional
+ If true, the normalized mean squared error (NMSE) plot is displayed
+ cbar: bool, optional
+ If true, displays a colorbar
+ plot_probe: bool
+ If true, the reconstructed probe intensity is also displayed
+ plot_fourier_probe: bool, optional
+ If true, the reconstructed complex Fourier probe is displayed
+ padding : int, optional
+ Pixels to pad by post rotating-cropping object
+
+ """
+ figsize = kwargs.pop("figsize", (8, 5))
+ cmap = kwargs.pop("cmap", "magma")
+
+ if plot_fourier_probe:
+ chroma_boost = kwargs.pop("chroma_boost", 2)
+ else:
+ chroma_boost = kwargs.pop("chroma_boost", 1)
+
+ if self._object_type == "complex":
+ obj = np.angle(self.object)
+ else:
+ obj = self.object
+
+ rotated_object = self._crop_rotate_object_fov(
+ np.sum(obj, axis=0), padding=padding
+ )
+ rotated_shape = rotated_object.shape
+
+ extent = [
+ 0,
+ self.sampling[1] * rotated_shape[1],
+ self.sampling[0] * rotated_shape[0],
+ 0,
+ ]
+
+ if plot_fourier_probe:
+ probe_extent = [
+ -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2,
+ self.angular_sampling[1] * self._region_of_interest_shape[1] / 2,
+ self.angular_sampling[0] * self._region_of_interest_shape[0] / 2,
+ -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2,
+ ]
+ elif plot_probe:
+ probe_extent = [
+ 0,
+ self.sampling[1] * self._region_of_interest_shape[1],
+ self.sampling[0] * self._region_of_interest_shape[0],
+ 0,
+ ]
+
+ if plot_convergence:
+ if plot_probe or plot_fourier_probe:
+ spec = GridSpec(
+ ncols=2,
+ nrows=2,
+ height_ratios=[4, 1],
+ hspace=0.15,
+ width_ratios=[
+ (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]),
+ 1,
+ ],
+ wspace=0.35,
+ )
+ else:
+ spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0.15)
+ else:
+ if plot_probe or plot_fourier_probe:
+ spec = GridSpec(
+ ncols=2,
+ nrows=1,
+ width_ratios=[
+ (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]),
+ 1,
+ ],
+ wspace=0.35,
+ )
+ else:
+ spec = GridSpec(ncols=1, nrows=1)
+
+ if fig is None:
+ fig = plt.figure(figsize=figsize)
+
+ if plot_probe or plot_fourier_probe:
+ # Object
+ ax = fig.add_subplot(spec[0, 0])
+ im = ax.imshow(
+ rotated_object,
+ extent=extent,
+ cmap=cmap,
+ **kwargs,
+ )
+
+ ax.set_ylabel("x [A]")
+ ax.set_xlabel("y [A]")
+ if self._object_type == "potential":
+ ax.set_title("Reconstructed object potential")
+ elif self._object_type == "complex":
+ ax.set_title("Reconstructed object phase")
+
+ if cbar:
+ divider = make_axes_locatable(ax)
+ ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
+ fig.add_axes(ax_cb)
+ fig.colorbar(im, cax=ax_cb)
+
+ # Probe
+ kwargs.pop("vmin", None)
+ kwargs.pop("vmax", None)
+
+ ax = fig.add_subplot(spec[0, 1])
+ if plot_fourier_probe:
+ probe_array = Complex2RGB(
+ self.probe_fourier[0], chroma_boost=chroma_boost
+ )
+ ax.set_title("Reconstructed Fourier probe[0]")
+ ax.set_ylabel("kx [mrad]")
+ ax.set_xlabel("ky [mrad]")
+ else:
+ probe_array = Complex2RGB(
+ self.probe[0], power=2, chroma_boost=chroma_boost
+ )
+ ax.set_title("Reconstructed probe[0] intensity")
+ ax.set_ylabel("x [A]")
+ ax.set_xlabel("y [A]")
+
+ im = ax.imshow(
+ probe_array,
+ extent=probe_extent,
+ )
+
+ if cbar:
+ divider = make_axes_locatable(ax)
+ ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
+ add_colorbar_arg(ax_cb, chroma_boost=chroma_boost)
+
+ else:
+ ax = fig.add_subplot(spec[0])
+ im = ax.imshow(
+ rotated_object,
+ extent=extent,
+ cmap=cmap,
+ **kwargs,
+ )
+ ax.set_ylabel("x [A]")
+ ax.set_xlabel("y [A]")
+ if self._object_type == "potential":
+ ax.set_title("Reconstructed object potential")
+ elif self._object_type == "complex":
+ ax.set_title("Reconstructed object phase")
+
+ if cbar:
+ divider = make_axes_locatable(ax)
+ ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
+ fig.add_axes(ax_cb)
+ fig.colorbar(im, cax=ax_cb)
+
+ if plot_convergence and hasattr(self, "error_iterations"):
+ kwargs.pop("vmin", None)
+ kwargs.pop("vmax", None)
+ errors = np.array(self.error_iterations)
+            if plot_probe or plot_fourier_probe:
+ ax = fig.add_subplot(spec[1, :])
+ else:
+ ax = fig.add_subplot(spec[1])
+ ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs)
+ ax.set_ylabel("NMSE")
+ ax.set_xlabel("Iteration number")
+ ax.yaxis.tick_right()
+
+ fig.suptitle(f"Normalized mean squared error: {self.error:.3e}")
+ spec.tight_layout(fig)
+
+ def _visualize_all_iterations(
+ self,
+ fig,
+ cbar: bool,
+ plot_convergence: bool,
+ plot_probe: bool,
+ plot_fourier_probe: bool,
+ iterations_grid: Tuple[int, int],
+ padding: int,
+ **kwargs,
+ ):
+ """
+ Displays all reconstructed object and probe iterations.
+
+ Parameters
+ --------
+ fig: Figure
+ Matplotlib figure to place Gridspec in
+ plot_convergence: bool, optional
+ If true, the normalized mean squared error (NMSE) plot is displayed
+ iterations_grid: Tuple[int,int]
+ Grid dimensions to plot reconstruction iterations
+ cbar: bool, optional
+ If true, displays a colorbar
+ plot_probe: bool
+ If true, the reconstructed probe intensity is also displayed
+ plot_fourier_probe: bool, optional
+ If true, the reconstructed complex Fourier probe is displayed
+ padding : int, optional
+ Pixels to pad by post rotating-cropping object
+ """
+ asnumpy = self._asnumpy
+
+ if not hasattr(self, "object_iterations"):
+ raise ValueError(
+ (
+ "Object and probe iterations were not saved during reconstruction. "
+ "Please re-run using store_iterations=True."
+ )
+ )
+
+ if iterations_grid == "auto":
+ num_iter = len(self.error_iterations)
+
+ if num_iter == 1:
+ return self._visualize_last_iteration(
+ fig=fig,
+ plot_convergence=plot_convergence,
+ plot_probe=plot_probe,
+ plot_fourier_probe=plot_fourier_probe,
+ cbar=cbar,
+ padding=padding,
+ **kwargs,
+ )
+ elif plot_probe or plot_fourier_probe:
+ iterations_grid = (2, 4) if num_iter > 4 else (2, num_iter)
+ else:
+ iterations_grid = (2, 4) if num_iter > 8 else (2, num_iter // 2)
+ else:
+            if (plot_probe or plot_fourier_probe) and iterations_grid[0] != 2:
+                raise ValueError(
+                    "Plotting probes requires an iterations_grid with exactly two rows."
+                )
+
+ auto_figsize = (
+ (3 * iterations_grid[1], 3 * iterations_grid[0] + 1)
+ if plot_convergence
+ else (3 * iterations_grid[1], 3 * iterations_grid[0])
+ )
+ figsize = kwargs.pop("figsize", auto_figsize)
+ cmap = kwargs.pop("cmap", "magma")
+
+ if plot_fourier_probe:
+ chroma_boost = kwargs.pop("chroma_boost", 2)
+ else:
+ chroma_boost = kwargs.pop("chroma_boost", 1)
+
+ errors = np.array(self.error_iterations)
+
+ objects = []
+ object_type = []
+
+ for obj in self.object_iterations:
+ if np.iscomplexobj(obj):
+ obj = np.angle(obj)
+ object_type.append("phase")
+ else:
+ object_type.append("potential")
+ objects.append(
+ self._crop_rotate_object_fov(np.sum(obj, axis=0), padding=padding)
+ )
+
+ if plot_probe or plot_fourier_probe:
+ total_grids = (np.prod(iterations_grid) / 2).astype("int")
+ probes = self.probe_iterations
+ else:
+ total_grids = np.prod(iterations_grid)
+ max_iter = len(objects) - 1
+ grid_range = range(0, max_iter + 1, max_iter // (total_grids - 1))
+
+ extent = [
+ 0,
+ self.sampling[1] * objects[0].shape[1],
+ self.sampling[0] * objects[0].shape[0],
+ 0,
+ ]
+
+ if plot_fourier_probe:
+ probe_extent = [
+ -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2,
+ self.angular_sampling[1] * self._region_of_interest_shape[1] / 2,
+ self.angular_sampling[0] * self._region_of_interest_shape[0] / 2,
+ -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2,
+ ]
+ elif plot_probe:
+ probe_extent = [
+ 0,
+ self.sampling[1] * self._region_of_interest_shape[1],
+ self.sampling[0] * self._region_of_interest_shape[0],
+ 0,
+ ]
+
+ if plot_convergence:
+ if plot_probe or plot_fourier_probe:
+ spec = GridSpec(ncols=1, nrows=3, height_ratios=[4, 4, 1], hspace=0)
+ else:
+ spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0)
+ else:
+ if plot_probe or plot_fourier_probe:
+ spec = GridSpec(ncols=1, nrows=2)
+ else:
+ spec = GridSpec(ncols=1, nrows=1)
+
+ if fig is None:
+ fig = plt.figure(figsize=figsize)
+
+ grid = ImageGrid(
+ fig,
+ spec[0],
+ nrows_ncols=(1, iterations_grid[1])
+ if (plot_probe or plot_fourier_probe)
+ else iterations_grid,
+ axes_pad=(0.75, 0.5) if cbar else 0.5,
+ cbar_mode="each" if cbar else None,
+ cbar_pad="2.5%" if cbar else None,
+ )
+
+ for n, ax in enumerate(grid):
+ im = ax.imshow(
+ objects[grid_range[n]],
+ extent=extent,
+ cmap=cmap,
+ **kwargs,
+ )
+ ax.set_title(f"Iter: {grid_range[n]} {object_type[grid_range[n]]}")
+ ax.set_ylabel("x [A]")
+ ax.set_xlabel("y [A]")
+ if cbar:
+ grid.cbar_axes[n].colorbar(im)
+
+ if plot_probe or plot_fourier_probe:
+ kwargs.pop("vmin", None)
+ kwargs.pop("vmax", None)
+ grid = ImageGrid(
+ fig,
+ spec[1],
+ nrows_ncols=(1, iterations_grid[1]),
+ axes_pad=(0.75, 0.5) if cbar else 0.5,
+ cbar_mode="each" if cbar else None,
+ cbar_pad="2.5%" if cbar else None,
+ )
+
+ for n, ax in enumerate(grid):
+ if plot_fourier_probe:
+ probe_array = Complex2RGB(
+ asnumpy(
+ self._return_fourier_probe_from_centered_probe(
+ probes[grid_range[n]][0]
+ )
+ ),
+ chroma_boost=chroma_boost,
+ )
+ ax.set_title(f"Iter: {grid_range[n]} Fourier probe[0]")
+ ax.set_ylabel("kx [mrad]")
+ ax.set_xlabel("ky [mrad]")
+ else:
+ probe_array = Complex2RGB(
+ probes[grid_range[n]][0],
+ power=2,
+ chroma_boost=chroma_boost,
+ )
+ ax.set_title(f"Iter: {grid_range[n]} probe[0]")
+ ax.set_ylabel("x [A]")
+ ax.set_xlabel("y [A]")
+
+ im = ax.imshow(
+ probe_array,
+ extent=probe_extent,
+ )
+
+ if cbar:
+ add_colorbar_arg(
+ grid.cbar_axes[n],
+ chroma_boost=chroma_boost,
+ )
+
+ if plot_convergence:
+ kwargs.pop("vmin", None)
+ kwargs.pop("vmax", None)
+ if plot_probe or plot_fourier_probe:
+ ax2 = fig.add_subplot(spec[2])
+ else:
+ ax2 = fig.add_subplot(spec[1])
+ ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs)
+ ax2.set_ylabel("NMSE")
+ ax2.set_xlabel("Iteration number")
+ ax2.yaxis.tick_right()
+
+ spec.tight_layout(fig)
+
+ def visualize(
+ self,
+ fig=None,
+ iterations_grid: Tuple[int, int] = None,
+ plot_convergence: bool = True,
+ plot_probe: bool = True,
+ plot_fourier_probe: bool = False,
+ cbar: bool = True,
+ padding: int = 0,
+ **kwargs,
+ ):
+ """
+ Displays reconstructed object and probe.
+
+ Parameters
+ --------
+ fig: Figure
+ Matplotlib figure to place Gridspec in
+ plot_convergence: bool, optional
+ If true, the normalized mean squared error (NMSE) plot is displayed
+ iterations_grid: Tuple[int,int]
+ Grid dimensions to plot reconstruction iterations
+ cbar: bool, optional
+ If true, displays a colorbar
+ plot_probe: bool
+ If true, the reconstructed probe intensity is also displayed
+ plot_fourier_probe: bool, optional
+ If true, the reconstructed complex Fourier probe is displayed
+ padding : int, optional
+ Pixels of padding to apply after rotating and cropping the object
+
+ Returns
+ --------
+ self: PtychographicReconstruction
+ Self to accommodate chaining
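+
+ Examples
+ --------
+ A minimal sketch, assuming `ptycho` is an instance reconstructed with
+ store_iterations=True (the name and values are illustrative):
+
+ >>> ptycho = ptycho.visualize(iterations_grid="auto", cbar=True)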
+ """
+
+ if iterations_grid is None:
+ self._visualize_last_iteration(
+ fig=fig,
+ plot_convergence=plot_convergence,
+ plot_probe=plot_probe,
+ plot_fourier_probe=plot_fourier_probe,
+ cbar=cbar,
+ padding=padding,
+ **kwargs,
+ )
+ else:
+ self._visualize_all_iterations(
+ fig=fig,
+ plot_convergence=plot_convergence,
+ iterations_grid=iterations_grid,
+ plot_probe=plot_probe,
+ plot_fourier_probe=plot_fourier_probe,
+ cbar=cbar,
+ padding=padding,
+ **kwargs,
+ )
+ return self
+
+ def show_fourier_probe(
+ self, probe=None, scalebar=True, pixelsize=None, pixelunits=None, **kwargs
+ ):
+ """
+ Plot probe in Fourier space
+
+ Parameters
+ ----------
+ probe: complex array, optional
+ if None is specified, uses the `probe_fourier` property
+ scalebar: bool, optional
+ if True, adds scalebar to probe
+ pixelunits: str, optional
+ units for scalebar, default is A^-1
+ pixelsize: float, optional
+ default is probe reciprocal sampling
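+
+ Examples
+ --------
+ A minimal sketch, assuming `ptycho` holds a converged reconstruction
+ (the name is illustrative):
+
+ >>> ptycho.show_fourier_probe(scalebar=True, chroma_boost=1.5)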
+ """
+ asnumpy = self._asnumpy
+
+ if probe is None:
+ probe = list(self.probe_fourier)
+ else:
+ if isinstance(probe, np.ndarray) and probe.ndim == 2:
+ probe = [probe]
+ probe = [asnumpy(self._return_fourier_probe(pr)) for pr in probe]
+
+ if pixelsize is None:
+ pixelsize = self._reciprocal_sampling[1]
+ if pixelunits is None:
+ pixelunits = r"$\AA^{-1}$"
+
+ chroma_boost = kwargs.pop("chroma_boost", 2)
+
+ show_complex(
+ probe if len(probe) > 1 else probe[0],
+ scalebar=scalebar,
+ pixelsize=pixelsize,
+ pixelunits=pixelunits,
+ ticks=False,
+ chroma_boost=chroma_boost,
+ **kwargs,
+ )
+
+ def show_transmitted_probe(
+ self,
+ plot_fourier_probe: bool = False,
+ **kwargs,
+ ):
+ """
+ Plots the min, max, and mean transmitted probe after propagation and transmission.
+
+ Parameters
+ ----------
+ plot_fourier_probe: boolean, optional
+ If True, the transmitted probes are also plotted in Fourier space
+ kwargs:
+ Passed to show_complex
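+
+ Examples
+ --------
+ A minimal sketch, assuming a reconstruction has already populated the
+ transmitted probes (`ptycho` is illustrative):
+
+ >>> ptycho.show_transmitted_probe(plot_fourier_probe=True)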
+ """
+
+ xp = self._xp
+ asnumpy = self._asnumpy
+
+ transmitted_probe_intensities = xp.sum(
+ xp.abs(self._transmitted_probes[:, 0]) ** 2, axis=(-2, -1)
+ )
+ min_intensity_transmitted = self._transmitted_probes[
+ xp.argmin(transmitted_probe_intensities), 0
+ ]
+ max_intensity_transmitted = self._transmitted_probes[
+ xp.argmax(transmitted_probe_intensities), 0
+ ]
+ mean_transmitted = self._transmitted_probes[:, 0].mean(0)
+ probes = [
+ asnumpy(self._return_centered_probe(probe))
+ for probe in [
+ mean_transmitted,
+ min_intensity_transmitted,
+ max_intensity_transmitted,
+ ]
+ ]
+ title = [
+ "Mean Transmitted Probe",
+ "Min Intensity Transmitted Probe",
+ "Max Intensity Transmitted Probe",
+ ]
+
+ if plot_fourier_probe:
+ bottom_row = [
+ asnumpy(self._return_fourier_probe(probe))
+ for probe in [
+ mean_transmitted,
+ min_intensity_transmitted,
+ max_intensity_transmitted,
+ ]
+ ]
+ probes = [probes, bottom_row]
+
+ title += [
+ "Mean Transmitted Fourier Probe",
+ "Min Intensity Transmitted Fourier Probe",
+ "Max Intensity Transmitted Fourier Probe",
+ ]
+
+ title = kwargs.pop("title", title)
+ show_complex(
+ probes,
+ title=title,
+ **kwargs,
+ )
+
+ def show_slices(
+ self,
+ ms_object=None,
+ cbar: bool = True,
+ common_color_scale: bool = True,
+ padding: int = 0,
+ num_cols: int = 3,
+ show_fft: bool = False,
+ **kwargs,
+ ):
+ """
+ Displays reconstructed slices of object
+
+ Parameters
+ --------
+ ms_object: np.ndarray, optional
+ Object to plot slices of. If None, uses current object
+ cbar: bool, optional
+ If True, displays a colorbar
+ padding: int, optional
+ Padding to leave uncropped
+ num_cols: int, optional
+ Number of GridSpec columns
+ show_fft: bool, optional
+ if True, plots fft of object slices
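+
+ Examples
+ --------
+ A minimal sketch (`ptycho` and the values are illustrative):
+
+ >>> ptycho.show_slices(num_cols=4, common_color_scale=True, padding=10)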
+ """
+
+ if ms_object is None:
+ ms_object = self._object
+
+ rotated_object = self._crop_rotate_object_fov(ms_object, padding=padding)
+ if show_fft:
+ rotated_object = np.abs(
+ np.fft.fftshift(
+ np.fft.fft2(rotated_object, axes=(-2, -1)), axes=(-2, -1)
+ )
+ )
+ rotated_shape = rotated_object.shape
+
+ if np.iscomplexobj(rotated_object):
+ rotated_object = np.angle(rotated_object)
+
+ extent = [
+ 0,
+ self.sampling[1] * rotated_shape[2],
+ self.sampling[0] * rotated_shape[1],
+ 0,
+ ]
+
+ num_rows = np.ceil(self._num_slices / num_cols).astype("int")
+ wspace = 0.35 if cbar else 0.15
+
+ axsize = kwargs.pop("axsize", (3, 3))
+ cmap = kwargs.pop("cmap", "magma")
+
+ if common_color_scale:
+ vals = np.sort(rotated_object.ravel())
+ ind_vmin = np.round((vals.shape[0] - 1) * 0.02).astype("int")
+ ind_vmax = np.round((vals.shape[0] - 1) * 0.98).astype("int")
+ ind_vmin = np.max([0, ind_vmin])
+ ind_vmax = np.min([len(vals) - 1, ind_vmax])
+ vmin = vals[ind_vmin]
+ vmax = vals[ind_vmax]
+ if vmax == vmin:
+ vmin = vals[0]
+ vmax = vals[-1]
+ else:
+ vmax = None
+ vmin = None
+ vmin = kwargs.pop("vmin", vmin)
+ vmax = kwargs.pop("vmax", vmax)
+
+ spec = GridSpec(
+ ncols=num_cols,
+ nrows=num_rows,
+ hspace=0.15,
+ wspace=wspace,
+ )
+
+ figsize = (axsize[0] * num_cols, axsize[1] * num_rows)
+ fig = plt.figure(figsize=figsize)
+
+ for flat_index, obj_slice in enumerate(rotated_object):
+ row_index, col_index = np.unravel_index(flat_index, (num_rows, num_cols))
+ ax = fig.add_subplot(spec[row_index, col_index])
+ im = ax.imshow(
+ obj_slice,
+ cmap=cmap,
+ vmin=vmin,
+ vmax=vmax,
+ extent=extent,
+ **kwargs,
+ )
+
+ ax.set_title(f"Slice index: {flat_index}")
+
+ if cbar:
+ divider = make_axes_locatable(ax)
+ ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
+ fig.add_axes(ax_cb)
+ fig.colorbar(im, cax=ax_cb)
+
+ if row_index < num_rows - 1:
+ ax.set_xticks([])
+ else:
+ ax.set_xlabel("y [A]")
+
+ if col_index > 0:
+ ax.set_yticks([])
+ else:
+ ax.set_ylabel("x [A]")
+
+ spec.tight_layout(fig)
+
+ def show_depth(
+ self,
+ x1: float,
+ x2: float,
+ y1: float,
+ y2: float,
+ specify_calibrated: bool = False,
+ gaussian_filter_sigma: float = None,
+ ms_object=None,
+ cbar: bool = False,
+ aspect: float = None,
+ plot_line_profile: bool = False,
+ **kwargs,
+ ):
+ """
+ Displays line profile depth section
+
+ Parameters
+ --------
+ x1, x2, y1, y2: floats (pixels)
+ Line profile for depth section runs from (x1,y1) to (x2,y2)
+ Specified in pixels unless specify_calibrated is True
+ specify_calibrated: bool (optional)
+ If True, specify x1, x2, y1, y2 in A values instead of pixels
+ gaussian_filter_sigma: float (optional)
+ Standard deviation of gaussian kernel in A
+ ms_object: np.ndarray
+ Object to plot slices of. If None, uses current object
+ cbar: bool, optional
+ If True, displays a colorbar
+ aspect: float, optional
+ aspect ratio for depth profile plot
+ plot_line_profile: bool
+ If True, also plots line profile showing where depth profile is taken
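+
+ Examples
+ --------
+ A minimal sketch taking a depth section along y at fixed x, with
+ endpoints in Angstroms (`ptycho` and the values are illustrative):
+
+ >>> ptycho.show_depth(
+ ... x1=10.0,
+ ... x2=10.0,
+ ... y1=2.0,
+ ... y2=48.0,
+ ... specify_calibrated=True,
+ ... gaussian_filter_sigma=2.0,
+ ... plot_line_profile=True,
+ ... )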
+ """
+ if ms_object is not None:
+ ms_obj = ms_object
+ else:
+ ms_obj = self.object_cropped
+
+ if specify_calibrated:
+ x1 /= self.sampling[0]
+ x2 /= self.sampling[0]
+ y1 /= self.sampling[1]
+ y2 /= self.sampling[1]
+
+ if x2 == x1:
+ angle = 0
+ elif y2 == y1:
+ angle = np.pi / 2
+ else:
+ angle = np.arctan((x2 - x1) / (y2 - y1))
+
+ x0 = ms_obj.shape[1] / 2
+ y0 = ms_obj.shape[2] / 2
+
+ if (
+ x1 > ms_obj.shape[1]
+ or x2 > ms_obj.shape[1]
+ or y1 > ms_obj.shape[2]
+ or y2 > ms_obj.shape[2]
+ ):
+ raise ValueError("depth section must be in field of view of object")
+
+ from py4DSTEM.process.phase.utils import rotate_point
+
+ x1_0, y1_0 = rotate_point((x0, y0), (x1, y1), angle)
+ x2_0, y2_0 = rotate_point((x0, y0), (x2, y2), angle)
+
+ rotated_object = np.roll(
+ rotate(ms_obj, np.rad2deg(angle), reshape=False, axes=(-1, -2)),
+ int(x1_0),
+ axis=1,
+ )
+
+ if np.iscomplexobj(rotated_object):
+ rotated_object = np.angle(rotated_object)
+ if gaussian_filter_sigma is not None:
+ from scipy.ndimage import gaussian_filter
+
+ gaussian_filter_sigma /= self.sampling[0]
+ rotated_object = gaussian_filter(rotated_object, gaussian_filter_sigma)
+
+ plot_im = rotated_object[:, 0, int(y1_0) : int(y2_0)]
+
+ extent = [
+ 0,
+ self.sampling[1] * plot_im.shape[1],
+ self._slice_thicknesses[0] * plot_im.shape[0],
+ 0,
+ ]
+
+ figsize = kwargs.pop("figsize", (6, 6))
+ if not plot_line_profile:
+ fig, ax = plt.subplots(figsize=figsize)
+ im = ax.imshow(plot_im, cmap="magma", extent=extent)
+ if aspect is not None:
+ ax.set_aspect(aspect)
+ ax.set_xlabel("r [A]")
+ ax.set_ylabel("z [A]")
+ ax.set_title("Multislice depth profile")
+ if cbar:
+ divider = make_axes_locatable(ax)
+ ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
+ fig.add_axes(ax_cb)
+ fig.colorbar(im, cax=ax_cb)
+ else:
+ extent2 = [
+ 0,
+ self.sampling[1] * ms_obj.shape[2],
+ self.sampling[0] * ms_obj.shape[1],
+ 0,
+ ]
+ fig, ax = plt.subplots(2, 1, figsize=figsize)
+ ax[0].imshow(ms_obj.sum(0), cmap="gray", extent=extent2)
+ ax[0].plot(
+ [y1 * self.sampling[1], y2 * self.sampling[1]],
+ [x1 * self.sampling[0], x2 * self.sampling[0]],
+ color="red",
+ )
+ ax[0].set_xlabel("y [A]")
+ ax[0].set_ylabel("x [A]")
+ ax[0].set_title("Multislice depth profile location")
+
+ im = ax[1].imshow(plot_im, cmap="magma", extent=extent)
+ if aspect is not None:
+ ax[1].set_aspect(aspect)
+ ax[1].set_xlabel("r [A]")
+ ax[1].set_ylabel("z [A]")
+ ax[1].set_title("Multislice depth profile")
+ if cbar:
+ divider = make_axes_locatable(ax[1])
+ ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
+ fig.add_axes(ax_cb)
+ fig.colorbar(im, cax=ax_cb)
+ plt.tight_layout()
+
+ def tune_num_slices_and_thicknesses(
+ self,
+ num_slices_guess=None,
+ thicknesses_guess=None,
+ num_slices_step_size=1,
+ thicknesses_step_size=20,
+ num_slices_values=3,
+ num_thicknesses_values=3,
+ update_defocus=False,
+ max_iter=5,
+ plot_reconstructions=True,
+ plot_convergence=True,
+ return_values=False,
+ **kwargs,
+ ):
+ """
+ Run reconstructions over a parameters space of number of slices
+ and slice thicknesses. Should be run after the preprocess step.
+
+ Parameters
+ ----------
+ num_slices_guess: float, optional
+ initial guess for the number of slices, rounded to the nearest integer
+ if None, uses current initialized values
+ thicknesses_guess: float (A), optional
+ initial guess for the slice thicknesses, assuming the same thickness
+ for each slice
+ if None, uses current initialized values
+ num_slices_step_size: float, optional
+ step size for the number of slices in the parameter-space search
+ thicknesses_step_size: float (A), optional
+ step size for the slice thicknesses in the parameter-space search
+ num_slices_values: int, optional
+ number of num_slices values to test, must be >= 1
+ num_thicknesses_values: int, optional
+ number of thicknesses values to test, must be >= 1
+ update_defocus: bool, optional
+ if True, updates defocus based on estimated total thickness
+ max_iter: int, optional
+ number of iterations to run in ptychographic reconstruction
+ plot_reconstructions: bool, optional
+ if True, plot phase of reconstructed objects
+ plot_convergence: bool, optional
+ if True, plots error for each iteration for each reconstruction
+ return_values: bool, optional
+ if True, returns objects, convergence
+
+ Returns
+ -------
+ objects: list
+ reconstructed objects
+ convergence: np.ndarray
+ array of convergence values from reconstructions
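+
+ Examples
+ --------
+ A minimal sketch scanning a 3x3 grid around an initial guess
+ (`ptycho` and the values are illustrative):
+
+ >>> objs, errs = ptycho.tune_num_slices_and_thicknesses(
+ ... num_slices_guess=8,
+ ... thicknesses_guess=20,
+ ... num_slices_step_size=2,
+ ... thicknesses_step_size=10,
+ ... max_iter=5,
+ ... return_values=True,
+ ... )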
+ """
+
+ # calculate number of slices and thicknesses values to test
+ if num_slices_guess is None:
+ num_slices_guess = self._num_slices
+ if thicknesses_guess is None:
+ thicknesses_guess = np.mean(self._slice_thicknesses)
+
+ if num_slices_values == 1:
+ num_slices_step_size = 0
+
+ if num_thicknesses_values == 1:
+ thicknesses_step_size = 0
+
+ num_slices = np.linspace(
+ num_slices_guess - num_slices_step_size * (num_slices_values - 1) / 2,
+ num_slices_guess + num_slices_step_size * (num_slices_values - 1) / 2,
+ num_slices_values,
+ )
+
+ thicknesses = np.linspace(
+ thicknesses_guess
+ - thicknesses_step_size * (num_thicknesses_values - 1) / 2,
+ thicknesses_guess
+ + thicknesses_step_size * (num_thicknesses_values - 1) / 2,
+ num_thicknesses_values,
+ )
+
+ if return_values:
+ convergence = []
+ objects = []
+
+ # current initialized values
+ current_verbose = self._verbose
+ current_num_slices = self._num_slices
+ current_thicknesses = self._slice_thicknesses
+ current_rotation_deg = self._rotation_best_rad * 180 / np.pi
+ current_transpose = self._rotation_best_transpose
+ current_defocus = -self._polar_parameters["C10"]
+
+ # Gridspec to plot on
+ if plot_reconstructions:
+ if plot_convergence:
+ spec = GridSpec(
+ ncols=num_thicknesses_values,
+ nrows=num_slices_values * 2,
+ height_ratios=[1, 1 / 4] * num_slices_values,
+ hspace=0.15,
+ wspace=0.35,
+ )
+ figsize = kwargs.get(
+ "figsize", (4 * num_thicknesses_values, 5 * num_slices_values)
+ )
+ else:
+ spec = GridSpec(
+ ncols=num_thicknesses_values,
+ nrows=num_slices_values,
+ hspace=0.15,
+ wspace=0.35,
+ )
+ figsize = kwargs.get(
+ "figsize", (4 * num_thicknesses_values, 4 * num_slices_values)
+ )
+
+ fig = plt.figure(figsize=figsize)
+
+ progress_bar = kwargs.pop("progress_bar", False)
+ # run loop and plot along the way
+ self._verbose = False
+ for flat_index, (slices, thickness) in enumerate(
+ tqdmnd(num_slices, thicknesses, desc="Tuning num slices and thicknesses")
+ ):
+ slices = int(slices)
+ self._num_slices = slices
+ self._slice_thicknesses = np.tile(thickness, slices - 1)
+ self._probe = None
+ self._object = None
+ if update_defocus:
+ defocus = current_defocus + slices / 2 * thickness
+ self._polar_parameters["C10"] = -defocus
+
+ self.preprocess(
+ plot_center_of_mass=False,
+ plot_rotation=False,
+ plot_probe_overlaps=False,
+ force_com_rotation=current_rotation_deg,
+ force_com_transpose=current_transpose,
+ )
+ self.reconstruct(
+ reset=True,
+ store_iterations=True if plot_convergence else False,
+ max_iter=max_iter,
+ progress_bar=progress_bar,
+ **kwargs,
+ )
+
+ if plot_reconstructions:
+ row_index, col_index = np.unravel_index(
+ flat_index, (num_slices_values, num_thicknesses_values)
+ )
+
+ if plot_convergence:
+ object_ax = fig.add_subplot(spec[row_index * 2, col_index])
+ convergence_ax = fig.add_subplot(spec[row_index * 2 + 1, col_index])
+ self._visualize_last_iteration_figax(
+ fig,
+ object_ax=object_ax,
+ convergence_ax=convergence_ax,
+ cbar=True,
+ )
+ convergence_ax.yaxis.tick_right()
+ else:
+ object_ax = fig.add_subplot(spec[row_index, col_index])
+ self._visualize_last_iteration_figax(
+ fig,
+ object_ax=object_ax,
+ convergence_ax=None,
+ cbar=True,
+ )
+
+ object_ax.set_title(
+ f" num slices = {slices:.0f}, slices thickness = {thickness:.1f} A \n error = {self.error:.3e}"
+ )
+ object_ax.set_xticks([])
+ object_ax.set_yticks([])
+
+ if return_values:
+ objects.append(self.object)
+ convergence.append(self.error_iterations.copy())
+
+ # initialize back to pre-tuning values
+ self._probe = None
+ self._object = None
+ self._num_slices = current_num_slices
+ self._slice_thicknesses = current_thicknesses
+ self._polar_parameters["C10"] = -current_defocus
+ self.preprocess(
+ force_com_rotation=current_rotation_deg,
+ force_com_transpose=current_transpose,
+ plot_center_of_mass=False,
+ plot_rotation=False,
+ plot_probe_overlaps=False,
+ )
+ self._verbose = current_verbose
+
+ if plot_reconstructions:
+ spec.tight_layout(fig)
+
+ if return_values:
+ return objects, convergence
+
+ def _return_object_fft(
+ self,
+ obj=None,
+ ):
+ """
+ Returns the amplitude of the object FFT, shifted to the center of the array
+
+ Parameters
+ ----------
+ obj: array, optional
+ if None is specified, uses self._object
+ """
+ asnumpy = self._asnumpy
+
+ if obj is None:
+ obj = self._object
+
+ obj = asnumpy(obj)
+ if np.iscomplexobj(obj):
+ obj = np.angle(obj)
+
+ obj = self._crop_rotate_object_fov(np.sum(obj, axis=0))
+ return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj))))
+
+ def _return_self_consistency_errors(
+ self,
+ max_batch_size=None,
+ ):
+ """Compute the self-consistency errors for each probe position"""
+
+ xp = self._xp
+ asnumpy = self._asnumpy
+
+ # Batch-size
+ if max_batch_size is None:
+ max_batch_size = self._num_diffraction_patterns
+
+ # Re-initialize fractional positions and vector patches
+ errors = np.array([])
+ positions_px = self._positions_px.copy()
+
+ for start, end in generate_batches(
+ self._num_diffraction_patterns, max_batch=max_batch_size
+ ):
+ # batch indices
+ self._positions_px = positions_px[start:end]
+ self._positions_px_fractional = self._positions_px - xp.round(
+ self._positions_px
+ )
+ (
+ self._vectorized_patch_indices_row,
+ self._vectorized_patch_indices_col,
+ ) = self._extract_vectorized_patch_indices()
+ amplitudes = self._amplitudes[start:end]
+
+ # Overlaps
+ _, _, overlap = self._overlap_projection(self._object, self._probe)
+ fourier_overlap = xp.fft.fft2(overlap)
+ intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1))
+
+ # Normalized mean-squared errors
+ batch_errors = xp.sum(
+ xp.abs(amplitudes - intensity_norm) ** 2, axis=(-2, -1)
+ )
+ errors = np.hstack((errors, batch_errors))
+
+ self._positions_px = positions_px.copy()
+ errors /= self._mean_diffraction_intensity
+
+ return asnumpy(errors)
+
+ def _return_projected_cropped_potential(
+ self,
+ ):
+ """Utility function to accommodate multiple classes"""
+ if self._object_type == "complex":
+ projected_cropped_potential = np.angle(self.object_cropped).sum(0)
+ else:
+ projected_cropped_potential = self.object_cropped.sum(0)
+
+ return projected_cropped_potential
diff --git a/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py
new file mode 100644
index 000000000..d68291143
--- /dev/null
+++ b/py4DSTEM/process/phase/iterative_mixedstate_ptychography.py
@@ -0,0 +1,2390 @@
+"""
+Module for reconstructing phase objects from 4DSTEM datasets using iterative methods,
+namely mixed-state ptychography.
+"""
+
+import warnings
+from typing import Mapping, Tuple
+
+import matplotlib.pyplot as plt
+import numpy as np
+from matplotlib.gridspec import GridSpec
+from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable
+from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg, show_complex
+
+try:
+ import cupy as cp
+except ModuleNotFoundError:
+ cp = np
+
+from emdfile import Custom, tqdmnd
+from py4DSTEM import DataCube
+from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction
+from py4DSTEM.process.phase.utils import (
+ ComplexProbe,
+ fft_shift,
+ generate_batches,
+ polar_aliases,
+ polar_symbols,
+)
+from py4DSTEM.process.utils import get_CoM, get_shifted_ar
+
+warnings.simplefilter(action="always", category=UserWarning)
+
+
+class MixedstatePtychographicReconstruction(PtychographicReconstruction):
+ """
+ Mixed-State Ptychographic Reconstruction Class.
+
+ Diffraction intensities dimensions : (Rx,Ry,Qx,Qy)
+ Reconstructed probe dimensions : (N,Sx,Sy)
+ Reconstructed object dimensions : (Px,Py)
+
+ such that (Sx,Sy) is the region-of-interest (ROI) size of our N probes
+ and (Px,Py) is the padded-object size within which we position our ROI.
+
+ Parameters
+ ----------
+ energy: float
+ The electron energy of the wave functions in eV
+ datacube: DataCube
+ Input 4D diffraction pattern intensities
+ num_probes: int, optional
+ Number of mixed-state probes
+ semiangle_cutoff: float, optional
+ Semiangle cutoff for the initial probe guess in mrad
+ semiangle_cutoff_pixels: float, optional
+ Semiangle cutoff for the initial probe guess in pixels
+ rolloff: float, optional
+ Semiangle rolloff for the initial probe guess
+ vacuum_probe_intensity: np.ndarray, optional
+ Vacuum probe to use as intensity aperture for initial probe guess
+ polar_parameters: dict, optional
+ Mapping from aberration symbols to their corresponding values. All aberration
+ magnitudes should be given in Å and angles should be given in radians.
+ object_padding_px: Tuple[int,int], optional
+ Pixel dimensions to pad object with
+ If None, the padding is set to half the probe ROI dimensions
+ initial_object_guess: np.ndarray, optional
+ Initial guess for complex-valued object of dimensions (Px,Py)
+ If None, initialized to complex ones
+ initial_probe_guess: np.ndarray, optional
+ Initial guess for complex-valued probe of dimensions (Sx,Sy). If None,
+ initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations
+ initial_scan_positions: np.ndarray, optional
+ Probe positions in Å for each diffraction intensity
+ If None, initialized to a grid scan
+ positions_mask: np.ndarray, optional
+ Boolean real space mask to select positions in datacube to skip for reconstruction
+ verbose: bool, optional
+ If True, class methods will inherit this and print additional information
+ device: str, optional
+ Device on which calculations will be performed. Must be 'cpu' or 'gpu'
+ name: str, optional
+ Class name
+ kwargs:
+ Provide the aberration coefficients as keyword arguments.
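+
+ Examples
+ --------
+ A minimal sketch, assuming `dc` is a calibrated 4D-STEM DataCube
+ (the name and parameter values are illustrative):
+
+ >>> ptycho = MixedstatePtychographicReconstruction(
+ ... datacube=dc,
+ ... energy=300e3,
+ ... num_probes=3,
+ ... semiangle_cutoff=21.4,
+ ... defocus=500,
+ ... )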
+ """
+
+ # Class-specific Metadata
+ _class_specific_metadata = ("_num_probes",)
+
+ def __init__(
+ self,
+ energy: float,
+ datacube: DataCube = None,
+ num_probes: int = None,
+ semiangle_cutoff: float = None,
+ semiangle_cutoff_pixels: float = None,
+ rolloff: float = 2.0,
+ vacuum_probe_intensity: np.ndarray = None,
+ polar_parameters: Mapping[str, float] = None,
+ object_padding_px: Tuple[int, int] = None,
+ initial_object_guess: np.ndarray = None,
+ initial_probe_guess: np.ndarray = None,
+ initial_scan_positions: np.ndarray = None,
+ object_type: str = "complex",
+ positions_mask: np.ndarray = None,
+ verbose: bool = True,
+ device: str = "cpu",
+ name: str = "mixed-state_ptychographic_reconstruction",
+ **kwargs,
+ ):
+ Custom.__init__(self, name=name)
+
+ if initial_probe_guess is None or isinstance(initial_probe_guess, ComplexProbe):
+ if num_probes is None:
+ raise ValueError(
+ (
+ "If initial_probe_guess is None, or a ComplexProbe object, "
+ "num_probes must be specified."
+ )
+ )
+ else:
+ if len(initial_probe_guess.shape) != 3:
+ raise ValueError(
+ "Specified initial_probe_guess must have dimensions (N,Sx,Sy)."
+ )
+ num_probes = initial_probe_guess.shape[0]
+
+ if device == "cpu":
+ self._xp = np
+ self._asnumpy = np.asarray
+ from scipy.ndimage import gaussian_filter
+
+ self._gaussian_filter = gaussian_filter
+ from scipy.special import erf
+
+ self._erf = erf
+ elif device == "gpu":
+ self._xp = cp
+ self._asnumpy = cp.asnumpy
+ from cupyx.scipy.ndimage import gaussian_filter
+
+ self._gaussian_filter = gaussian_filter
+ from cupyx.scipy.special import erf
+
+ self._erf = erf
+ else:
+ raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}")
+
+ for key in kwargs.keys():
+ if (key not in polar_symbols) and (key not in polar_aliases.keys()):
+ raise ValueError("{} not a recognized parameter".format(key))
+
+ self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols)))
+
+ if polar_parameters is None:
+ polar_parameters = {}
+
+ polar_parameters.update(kwargs)
+ self._set_polar_parameters(polar_parameters)
+
+ if object_type != "potential" and object_type != "complex":
+ raise ValueError(
+ f"object_type must be either 'potential' or 'complex', not {object_type}"
+ )
+ if positions_mask is not None and positions_mask.dtype != "bool":
+ warnings.warn(
+ ("`positions_mask` converted to `bool` array"),
+ UserWarning,
+ )
+ positions_mask = np.asarray(positions_mask, dtype="bool")
+
+ self.set_save_defaults()
+
+ # Data
+ self._datacube = datacube
+ self._object = initial_object_guess
+ self._probe = initial_probe_guess
+
+ # Common Metadata
+ self._vacuum_probe_intensity = vacuum_probe_intensity
+ self._scan_positions = initial_scan_positions
+ self._energy = energy
+ self._semiangle_cutoff = semiangle_cutoff
+ self._semiangle_cutoff_pixels = semiangle_cutoff_pixels
+ self._rolloff = rolloff
+ self._object_type = object_type
+ self._object_padding_px = object_padding_px
+ self._positions_mask = positions_mask
+ self._verbose = verbose
+ self._device = device
+ self._preprocessed = False
+
+ # Class-specific Metadata
+ self._num_probes = num_probes
+
+ def preprocess(
+ self,
+ diffraction_intensities_shape: Tuple[int, int] = None,
+ reshaping_method: str = "fourier",
+ probe_roi_shape: Tuple[int, int] = None,
+ dp_mask: np.ndarray = None,
+ fit_function: str = "plane",
+ plot_center_of_mass: str = "default",
+ plot_rotation: bool = True,
+ maximize_divergence: bool = False,
+ rotation_angles_deg: np.ndarray = np.arange(-89.0, 90.0, 1.0),
+ plot_probe_overlaps: bool = True,
+ force_com_rotation: float = None,
+ force_com_transpose: float = None,
+ force_com_shifts: float = None,
+ force_scan_sampling: float = None,
+ force_angular_sampling: float = None,
+ force_reciprocal_sampling: float = None,
+ object_fov_mask: np.ndarray = None,
+ crop_patterns: bool = False,
+ **kwargs,
+ ):
+ """
+ Ptychographic preprocessing step.
+ Calls the base class methods:
+
+ _extract_intensities_and_calibrations_from_datacube,
+ _compute_center_of_mass(),
+ _solve_CoM_rotation(),
+ _normalize_diffraction_intensities()
+ _calculate_scan_positions_in_px()
+
+ Additionally, it initializes a (Px,Py) array of complex ones
+ and a complex probe using the specified polar parameters.
+
+ Parameters
+ ----------
+ diffraction_intensities_shape: Tuple[int,int], optional
+ Pixel dimensions (Qx',Qy') of the resampled diffraction intensities
+ If None, no resampling of diffraction intensities is performed
+ reshaping_method: str, optional
+ Method to use for reshaping, either 'bin', 'bilinear', or 'fourier' (default)
+ probe_roi_shape: (int,int), optional
+ Padded diffraction intensities shape.
+ If None, no padding is performed
+ dp_mask: ndarray, optional
+ Mask for datacube intensities (Qx,Qy)
+ fit_function: str, optional
+ 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two'
+ plot_center_of_mass: str, optional
+ If 'default', the corrected CoM arrays will be displayed
+ If 'all', the computed and fitted CoM arrays will be displayed
+ plot_rotation: bool, optional
+ If True, the CoM curl minimization search result will be displayed
+ maximize_divergence: bool, optional
+ If True, the divergence of the CoM gradient vector field is maximized
+ rotation_angles_deg: np.ndarray, optional
+ Array of angles in degrees to perform curl minimization over
+ plot_probe_overlaps: bool, optional
+ If True, initial probe overlaps scanned over the object will be displayed
+ force_com_rotation: float (degrees), optional
+ Force relative rotation angle between real and reciprocal space
+ force_com_transpose: bool, optional
+ Force whether diffraction intensities need to be transposed.
+ force_com_shifts: tuple of ndarrays (CoMx, CoMy)
+ Amplitudes come from diffraction patterns shifted with
+ the CoM in the upper left corner for each probe unless
+ shift is overwritten.
+ force_scan_sampling: float, optional
+ Override DataCube real space scan pixel size calibrations, in Angstrom
+ force_angular_sampling: float, optional
+ Override DataCube reciprocal pixel size calibration, in mrad
+ force_reciprocal_sampling: float, optional
+ Override DataCube reciprocal pixel size calibration, in A^-1
+ object_fov_mask: np.ndarray (boolean)
+ Boolean mask of FOV. Used to calculate additional shrinkage of object
+ If None, probe_overlap intensity is thresholded
+ crop_patterns: bool
+ if True, crop patterns to avoid wrap around of patterns when centering
+
+ Returns
+ --------
+ self: PtychographicReconstruction
+ Self to accommodate chaining
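+
+ Examples
+ --------
+ A minimal sketch, forcing a known rotation rather than solving for it
+ (the values are illustrative):
+
+ >>> ptycho = ptycho.preprocess(
+ ... plot_center_of_mass=False,
+ ... plot_rotation=False,
+ ... force_com_rotation=-15.0,
+ ... )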
+ """
+ xp = self._xp
+ asnumpy = self._asnumpy
+
+ # set additional metadata
+ self._diffraction_intensities_shape = diffraction_intensities_shape
+ self._reshaping_method = reshaping_method
+ self._probe_roi_shape = probe_roi_shape
+ self._dp_mask = dp_mask
+
+ if self._datacube is None:
+ raise ValueError(
+ (
+ "The preprocess() method requires a DataCube. "
+ "Please run ptycho.attach_datacube(DataCube) first."
+ )
+ )
+
+ (
+ self._datacube,
+ self._vacuum_probe_intensity,
+ self._dp_mask,
+ force_com_shifts,
+ ) = self._preprocess_datacube_and_vacuum_probe(
+ self._datacube,
+ diffraction_intensities_shape=self._diffraction_intensities_shape,
+ reshaping_method=self._reshaping_method,
+ probe_roi_shape=self._probe_roi_shape,
+ vacuum_probe_intensity=self._vacuum_probe_intensity,
+ dp_mask=self._dp_mask,
+ com_shifts=force_com_shifts,
+ )
+
+ self._intensities = self._extract_intensities_and_calibrations_from_datacube(
+ self._datacube,
+ require_calibrations=True,
+ force_scan_sampling=force_scan_sampling,
+ force_angular_sampling=force_angular_sampling,
+ force_reciprocal_sampling=force_reciprocal_sampling,
+ )
+
+ (
+ self._com_measured_x,
+ self._com_measured_y,
+ self._com_fitted_x,
+ self._com_fitted_y,
+ self._com_normalized_x,
+ self._com_normalized_y,
+ ) = self._calculate_intensities_center_of_mass(
+ self._intensities,
+ dp_mask=self._dp_mask,
+ fit_function=fit_function,
+ com_shifts=force_com_shifts,
+ )
+
+ (
+ self._rotation_best_rad,
+ self._rotation_best_transpose,
+ self._com_x,
+ self._com_y,
+ self.com_x,
+ self.com_y,
+ ) = self._solve_for_center_of_mass_relative_rotation(
+ self._com_measured_x,
+ self._com_measured_y,
+ self._com_normalized_x,
+ self._com_normalized_y,
+ rotation_angles_deg=rotation_angles_deg,
+ plot_rotation=plot_rotation,
+ plot_center_of_mass=plot_center_of_mass,
+ maximize_divergence=maximize_divergence,
+ force_com_rotation=force_com_rotation,
+ force_com_transpose=force_com_transpose,
+ **kwargs,
+ )
+
+ (
+ self._amplitudes,
+ self._mean_diffraction_intensity,
+ ) = self._normalize_diffraction_intensities(
+ self._intensities,
+ self._com_fitted_x,
+ self._com_fitted_y,
+ crop_patterns,
+ self._positions_mask,
+ )
+
+ # explicitly delete namespace
+ self._num_diffraction_patterns = self._amplitudes.shape[0]
+ self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:])
+ del self._intensities
+
+ self._positions_px = self._calculate_scan_positions_in_pixels(
+ self._scan_positions, self._positions_mask
+ )
+
+ # handle semiangle specified in pixels
+ if self._semiangle_cutoff_pixels:
+ self._semiangle_cutoff = (
+ self._semiangle_cutoff_pixels * self._angular_sampling[0]
+ )
+
+ # Object Initialization
+ if self._object is None:
+ pad_x = self._object_padding_px[0][1]
+ pad_y = self._object_padding_px[1][1]
+ p, q = np.round(np.max(self._positions_px, axis=0))
+ p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype(
+ "int"
+ )
+ q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype(
+ "int"
+ )
+ if self._object_type == "potential":
+ self._object = xp.zeros((p, q), dtype=xp.float32)
+ elif self._object_type == "complex":
+ self._object = xp.ones((p, q), dtype=xp.complex64)
+ else:
+ if self._object_type == "potential":
+ self._object = xp.asarray(self._object, dtype=xp.float32)
+ elif self._object_type == "complex":
+ self._object = xp.asarray(self._object, dtype=xp.complex64)
+
+ self._object_initial = self._object.copy()
+ self._object_type_initial = self._object_type
+ self._object_shape = self._object.shape
+
+ self._positions_px = xp.asarray(self._positions_px, dtype=xp.float32)
+ self._positions_px_com = xp.mean(self._positions_px, axis=0)
+ self._positions_px -= self._positions_px_com - xp.array(self._object_shape) / 2
+ self._positions_px_com = xp.mean(self._positions_px, axis=0)
+ self._positions_px_fractional = self._positions_px - xp.round(
+ self._positions_px
+ )
+
+ self._positions_px_initial = self._positions_px.copy()
+ self._positions_initial = self._positions_px_initial.copy()
+ self._positions_initial[:, 0] *= self.sampling[0]
+ self._positions_initial[:, 1] *= self.sampling[1]
+
+ # Vectorized Patches
+ (
+ self._vectorized_patch_indices_row,
+ self._vectorized_patch_indices_col,
+ ) = self._extract_vectorized_patch_indices()
+
+ # Probe Initialization
+ if self._probe is None or isinstance(self._probe, ComplexProbe):
+ if self._probe is None:
+ if self._vacuum_probe_intensity is not None:
+ self._semiangle_cutoff = np.inf
+ self._vacuum_probe_intensity = xp.asarray(
+ self._vacuum_probe_intensity, dtype=xp.float32
+ )
+ probe_x0, probe_y0 = get_CoM(
+ self._vacuum_probe_intensity,
+ device=self._device,
+ )
+ self._vacuum_probe_intensity = get_shifted_ar(
+ self._vacuum_probe_intensity,
+ -probe_x0,
+ -probe_y0,
+ bilinear=True,
+ device=self._device,
+ )
+ if crop_patterns:
+ self._vacuum_probe_intensity = self._vacuum_probe_intensity[
+ self._crop_mask
+ ].reshape(self._region_of_interest_shape)
+
+ _probe = (
+ ComplexProbe(
+ gpts=self._region_of_interest_shape,
+ sampling=self.sampling,
+ energy=self._energy,
+ semiangle_cutoff=self._semiangle_cutoff,
+ rolloff=self._rolloff,
+ vacuum_probe_intensity=self._vacuum_probe_intensity,
+ parameters=self._polar_parameters,
+ device=self._device,
+ )
+ .build()
+ ._array
+ )
+
+ else:
+ if not np.array_equal(self._probe._gpts, self._region_of_interest_shape):
+ raise ValueError(
+ "Specified initial probe must match the region-of-interest shape."
+ )
+ if hasattr(self._probe, "_array"):
+ _probe = self._probe._array
+ else:
+ self._probe._xp = xp
+ _probe = self._probe.build()._array
+
+ self._probe = xp.zeros(
+ (self._num_probes,) + tuple(self._region_of_interest_shape),
+ dtype=xp.complex64,
+ )
+ sx, sy = self._region_of_interest_shape
+ self._probe[0] = _probe
+
+ # Randomly shift phase of other probes
+ for i_probe in range(1, self._num_probes):
+ shift_x = xp.exp(
+ -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sx)
+ )
+ shift_y = xp.exp(
+ -2j * np.pi * (xp.random.rand() - 0.5) * xp.fft.fftfreq(sy)
+ )
+ self._probe[i_probe] = (
+ self._probe[i_probe - 1] * shift_x[:, None] * shift_y[None]
+ )
+
+ # Normalize probe to match mean diffraction intensity
+ probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe[0])) ** 2)
+ self._probe *= xp.sqrt(self._mean_diffraction_intensity / probe_intensity)
+
+ else:
+ self._probe = xp.asarray(self._probe, dtype=xp.complex64)
+
+ self._probe_initial = self._probe.copy()
+ self._probe_initial_aperture = None # Doesn't really make sense for mixed-state
+
+ self._known_aberrations_array = ComplexProbe(
+ energy=self._energy,
+ gpts=self._region_of_interest_shape,
+ sampling=self.sampling,
+ parameters=self._polar_parameters,
+ device=self._device,
+ )._evaluate_ctf()
+
+ # overlaps
+ shifted_probes = fft_shift(self._probe[0], self._positions_px_fractional, xp)
+ probe_intensities = xp.abs(shifted_probes) ** 2
+ probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities)
+ probe_overlap = self._gaussian_filter(probe_overlap, 1.0)
+
+ if object_fov_mask is None:
+ self._object_fov_mask = asnumpy(probe_overlap > 0.25 * probe_overlap.max())
+ else:
+ self._object_fov_mask = np.asarray(object_fov_mask)
+ self._object_fov_mask_inverse = np.invert(self._object_fov_mask)
+
+ if plot_probe_overlaps:
+ figsize = kwargs.pop("figsize", (4.5 * self._num_probes + 4, 4))
+ chroma_boost = kwargs.pop("chroma_boost", 1)
+
+ # initial probe
+ complex_probe_rgb = Complex2RGB(
+ self.probe_centered,
+ power=2,
+ chroma_boost=chroma_boost,
+ )
+
+ extent = [
+ 0,
+ self.sampling[1] * self._object_shape[1],
+ self.sampling[0] * self._object_shape[0],
+ 0,
+ ]
+
+ probe_extent = [
+ 0,
+ self.sampling[1] * self._region_of_interest_shape[1],
+ self.sampling[0] * self._region_of_interest_shape[0],
+ 0,
+ ]
+
+ fig, axs = plt.subplots(1, self._num_probes + 1, figsize=figsize)
+
+ for i in range(self._num_probes):
+ axs[i].imshow(
+ complex_probe_rgb[i],
+ extent=probe_extent,
+ )
+ axs[i].set_ylabel("x [A]")
+ axs[i].set_xlabel("y [A]")
+ axs[i].set_title(f"Initial probe[{i}] intensity")
+
+ divider = make_axes_locatable(axs[i])
+ cax = divider.append_axes("right", size="5%", pad="2.5%")
+ add_colorbar_arg(cax, chroma_boost=chroma_boost)
+
+ axs[-1].imshow(
+ asnumpy(probe_overlap),
+ extent=extent,
+ cmap="Greys_r",
+ )
+ axs[-1].scatter(
+ self.positions[:, 1],
+ self.positions[:, 0],
+ s=2.5,
+ color=(1, 0, 0, 1),
+ )
+ axs[-1].set_ylabel("x [A]")
+ axs[-1].set_xlabel("y [A]")
+ axs[-1].set_xlim((extent[0], extent[1]))
+ axs[-1].set_ylim((extent[2], extent[3]))
+ axs[-1].set_title("Object field of view")
+
+ fig.tight_layout()
+
+ self._preprocessed = True
+
+ if self._device == "gpu":
+ xp._default_memory_pool.free_all_blocks()
+ xp.clear_memo()
+
+ return self
+
+ def _overlap_projection(self, current_object, current_probe):
+ """
+ Ptychographic overlap projection method.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+
+ Returns
+ --------
+ shifted_probes:np.ndarray
+ fractionally-shifted probes
+ object_patches: np.ndarray
+ Patched object view
+ overlap: np.ndarray
+ shifted_probes * object_patches
+ """
+
+ xp = self._xp
+
+ shifted_probes = fft_shift(current_probe, self._positions_px_fractional, xp)
+
+ if self._object_type == "potential":
+ complex_object = xp.exp(1j * current_object)
+ else:
+ complex_object = current_object
+
+ object_patches = complex_object[
+ self._vectorized_patch_indices_row, self._vectorized_patch_indices_col
+ ]
+
+ overlap = shifted_probes * xp.expand_dims(object_patches, axis=1)
+
+ return shifted_probes, object_patches, overlap
+
+ def _gradient_descent_fourier_projection(self, amplitudes, overlap):
+ """
+ Ptychographic Fourier projection method for the GD algorithm.
+
+ Parameters
+ --------
+ amplitudes: np.ndarray
+ Normalized measured amplitudes
+ overlap: np.ndarray
+ object * probe overlap
+
+ Returns
+ --------
+ exit_waves:np.ndarray
+ Difference between modified and estimated exit waves
+ error: float
+ Reconstruction error
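+
+ Notes
+ -----
+ For mixed-state probes the modulus constraint acts on the incoherent
+ sum over probe modes: every mode is rescaled in Fourier space by
+ amplitudes / sqrt(sum_i |FFT(overlap_i)|^2), so the relative modal
+ structure of the exit waves is preserved.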
+ """
+
+ xp = self._xp
+ fourier_overlap = xp.fft.fft2(overlap)
+ intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1))
+ error = xp.sum(xp.abs(amplitudes - intensity_norm) ** 2)
+
+ intensity_norm[intensity_norm == 0.0] = np.inf
+ amplitude_modification = amplitudes / intensity_norm
+
+ fourier_modified_overlap = amplitude_modification[:, None] * fourier_overlap
+ modified_overlap = xp.fft.ifft2(fourier_modified_overlap)
+
+ exit_waves = modified_overlap - overlap
+
+ return exit_waves, error
+
+ def _projection_sets_fourier_projection(
+ self, amplitudes, overlap, exit_waves, projection_a, projection_b, projection_c
+ ):
+ """
+ Ptychographic Fourier projection method for the DM_AP and RAAR algorithms.
+ Generalized projection using three parameters: a,b,c
+
+ DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha
+ DM: DM_AP(1.0), AP: DM_AP(0.0)
+
+ RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2
+ DM : RAAR(1.0)
+
+ RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2
+ DM: RRR(1.0)
+
+ SUPERFLIP : a = 0, b = 1, c = 2
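+
+ For example, RAAR(\\beta = 0.75) gives a = -0.5, b = 0.75, c = 2, so
+ with projection_x = 1 - a - b = 0.75 and projection_y = 1 - c = -1
+ the update below evaluates to
+ exit_waves <- 0.75*exit_waves - 0.5*overlap + 0.75*projected_factor.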
+
+ Parameters
+ --------
+ amplitudes: np.ndarray
+ Normalized measured amplitudes
+ overlap: np.ndarray
+ object * probe overlap
+ exit_waves: np.ndarray
+ previously estimated exit waves
+ projection_a: float
+ projection_b: float
+ projection_c: float
+
+ Returns
+ --------
+ exit_waves:np.ndarray
+ Updated exit_waves
+ error: float
+ Reconstruction error
+ """
+
+ xp = self._xp
+ projection_x = 1 - projection_a - projection_b
+ projection_y = 1 - projection_c
+
+ if exit_waves is None:
+ exit_waves = overlap.copy()
+
+ fourier_overlap = xp.fft.fft2(overlap)
+ intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1))
+ error = xp.sum(xp.abs(amplitudes - intensity_norm) ** 2)
+
+ factor_to_be_projected = projection_c * overlap + projection_y * exit_waves
+ fourier_projected_factor = xp.fft.fft2(factor_to_be_projected)
+
+ intensity_norm_projected = xp.sqrt(
+ xp.sum(xp.abs(fourier_projected_factor) ** 2, axis=1)
+ )
+ intensity_norm_projected[intensity_norm_projected == 0.0] = np.inf
+
+ amplitude_modification = amplitudes / intensity_norm_projected
+ fourier_projected_factor *= amplitude_modification[:, None]
+
+ projected_factor = xp.fft.ifft2(fourier_projected_factor)
+
+ exit_waves = (
+ projection_x * exit_waves
+ + projection_a * overlap
+ + projection_b * projected_factor
+ )
+
+ return exit_waves, error
+
+ def _forward(
+ self,
+ current_object,
+ current_probe,
+ amplitudes,
+ exit_waves,
+ use_projection_scheme,
+ projection_a,
+ projection_b,
+ projection_c,
+ ):
+ """
+ Ptychographic forward operator.
+ Calls _overlap_projection() and the appropriate _fourier_projection().
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+ amplitudes: np.ndarray
+ Normalized measured amplitudes
+ exit_waves: np.ndarray
+ previously estimated exit waves
+ use_projection_scheme: bool,
+ If True, use generalized projection update
+ projection_a: float
+ projection_b: float
+ projection_c: float
+
+ Returns
+ --------
+ shifted_probes:np.ndarray
+ fractionally-shifted probes
+ object_patches: np.ndarray
+ Patched object view
+ overlap: np.ndarray
+ object * probe overlap
+ exit_waves:np.ndarray
+ Updated exit_waves
+ error: float
+ Reconstruction error
+ """
+
+ shifted_probes, object_patches, overlap = self._overlap_projection(
+ current_object, current_probe
+ )
+ if use_projection_scheme:
+ exit_waves, error = self._projection_sets_fourier_projection(
+ amplitudes,
+ overlap,
+ exit_waves,
+ projection_a,
+ projection_b,
+ projection_c,
+ )
+
+ else:
+ exit_waves, error = self._gradient_descent_fourier_projection(
+ amplitudes, overlap
+ )
+
+ return shifted_probes, object_patches, overlap, exit_waves, error
+
+ def _gradient_descent_adjoint(
+ self,
+ current_object,
+ current_probe,
+ object_patches,
+ shifted_probes,
+ exit_waves,
+ step_size,
+ normalization_min,
+ fix_probe,
+ ):
+ """
+ Ptychographic adjoint operator for GD method.
+ Computes object and probe update steps.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+ object_patches: np.ndarray
+ Patched object view
+ shifted_probes:np.ndarray
+ fractionally-shifted probes
+ exit_waves:np.ndarray
+ Updated exit_waves
+ step_size: float, optional
+ Update step size
+ normalization_min: float, optional
+ Probe normalization minimum as a fraction of the maximum overlap intensity
+ fix_probe: bool, optional
+ If True, probe will not be updated
+
+ Returns
+ --------
+ updated_object: np.ndarray
+ Updated object estimate
+ updated_probe: np.ndarray
+ Updated probe estimate
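+
+ Notes
+ -----
+ In pseudo-code, for the complex-object case with psi_i the exit-wave
+ difference for probe mode i:
+
+ object += step_size * sum_i conj(probe_i) * psi_i * probe_norm
+ probe_i += step_size * sum_pos conj(obj_patch) * psi_i * object_norm
+
+ where probe_norm and object_norm are the regularized reciprocal
+ overlap intensities computed below.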
+ """
+ xp = self._xp
+
+ probe_normalization = xp.zeros_like(current_object)
+ object_update = xp.zeros_like(current_object)
+
+ for i_probe in range(self._num_probes):
+ probe_normalization += self._sum_overlapping_patches_bincounts(
+ xp.abs(shifted_probes[:, i_probe]) ** 2
+ )
+ if self._object_type == "potential":
+ object_update += step_size * self._sum_overlapping_patches_bincounts(
+ xp.real(
+ -1j
+ * xp.conj(object_patches)
+ * xp.conj(shifted_probes[:, i_probe])
+ * exit_waves[:, i_probe]
+ )
+ )
+ else:
+ object_update += step_size * self._sum_overlapping_patches_bincounts(
+ xp.conj(shifted_probes[:, i_probe]) * exit_waves[:, i_probe]
+ )
+ probe_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * probe_normalization) ** 2
+ + (normalization_min * xp.max(probe_normalization)) ** 2
+ )
+
+ current_object += object_update * probe_normalization
+
+ if not fix_probe:
+ object_normalization = xp.sum(
+ (xp.abs(object_patches) ** 2),
+ axis=0,
+ )
+ object_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * object_normalization) ** 2
+ + (normalization_min * xp.max(object_normalization)) ** 2
+ )
+
+ current_probe += step_size * (
+ xp.sum(
+ xp.expand_dims(xp.conj(object_patches), axis=1) * exit_waves,
+ axis=0,
+ )
+ * object_normalization[None]
+ )
+
+ return current_object, current_probe
+
+ def _projection_sets_adjoint(
+ self,
+ current_object,
+ current_probe,
+ object_patches,
+ shifted_probes,
+ exit_waves,
+ normalization_min,
+ fix_probe,
+ ):
+ """
+ Ptychographic adjoint operator for DM_AP and RAAR methods.
+ Computes object and probe update steps.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+ object_patches: np.ndarray
+ Patched object view
+ shifted_probes:np.ndarray
+ fractionally-shifted probes
+ exit_waves:np.ndarray
+ Updated exit_waves
+ normalization_min: float, optional
+ Probe normalization minimum as a fraction of the maximum overlap intensity
+ fix_probe: bool, optional
+ If True, probe will not be updated
+
+ Returns
+ --------
+ updated_object: np.ndarray
+ Updated object estimate
+ updated_probe: np.ndarray
+ Updated probe estimate
+ """
+ xp = self._xp
+
+ probe_normalization = xp.zeros_like(current_object)
+ current_object = xp.zeros_like(current_object)
+
+ for i_probe in range(self._num_probes):
+ probe_normalization += self._sum_overlapping_patches_bincounts(
+ xp.abs(shifted_probes[:, i_probe]) ** 2
+ )
+ if self._object_type == "potential":
+ current_object += self._sum_overlapping_patches_bincounts(
+ xp.real(
+ -1j
+ * xp.conj(object_patches)
+ * xp.conj(shifted_probes[:, i_probe])
+ * exit_waves[:, i_probe]
+ )
+ )
+ else:
+ current_object += self._sum_overlapping_patches_bincounts(
+ xp.conj(shifted_probes[:, i_probe]) * exit_waves[:, i_probe]
+ )
+ probe_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * probe_normalization) ** 2
+ + (normalization_min * xp.max(probe_normalization)) ** 2
+ )
+
+ current_object *= probe_normalization
+
+ if not fix_probe:
+ object_normalization = xp.sum(
+ (xp.abs(object_patches) ** 2),
+ axis=0,
+ )
+ object_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * object_normalization) ** 2
+ + (normalization_min * xp.max(object_normalization)) ** 2
+ )
+
+ current_probe = (
+ xp.sum(
+ xp.expand_dims(xp.conj(object_patches), axis=1) * exit_waves,
+ axis=0,
+ )
+ * object_normalization[None]
+ )
+
+ return current_object, current_probe
+
+ def _adjoint(
+ self,
+ current_object,
+ current_probe,
+ object_patches,
+ shifted_probes,
+ exit_waves,
+ use_projection_scheme: bool,
+ step_size: float,
+ normalization_min: float,
+ fix_probe: bool,
+ ):
+ """
+ Ptychographic adjoint operator.
+ Dispatches to the gradient-descent or projection-sets adjoint
+ and computes object and probe update steps.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+ object_patches: np.ndarray
+ Patched object view
+ shifted_probes:np.ndarray
+ fractionally-shifted probes
+ exit_waves:np.ndarray
+ Updated exit_waves
+ use_projection_scheme: bool,
+ If True, use generalized projection update
+ step_size: float, optional
+ Update step size
+ normalization_min: float, optional
+ Probe normalization minimum as a fraction of the maximum overlap intensity
+ fix_probe: bool, optional
+ If True, probe will not be updated
+
+ Returns
+ --------
+ updated_object: np.ndarray
+ Updated object estimate
+ updated_probe: np.ndarray
+ Updated probe estimate
+ """
+
+ if use_projection_scheme:
+ current_object, current_probe = self._projection_sets_adjoint(
+ current_object,
+ current_probe,
+ object_patches,
+ shifted_probes,
+ exit_waves,
+ normalization_min,
+ fix_probe,
+ )
+ else:
+ current_object, current_probe = self._gradient_descent_adjoint(
+ current_object,
+ current_probe,
+ object_patches,
+ shifted_probes,
+ exit_waves,
+ step_size,
+ normalization_min,
+ fix_probe,
+ )
+
+ return current_object, current_probe
+
+ def _probe_center_of_mass_constraint(self, current_probe):
+ """
+ Ptychographic center of mass constraint.
+ Used for centering corner-centered probe intensity.
+
+ Parameters
+ --------
+ current_probe: np.ndarray
+ Current probe estimate
+
+ Returns
+ --------
+ constrained_probe: np.ndarray
+ Constrained probe estimate
+ """
+ xp = self._xp
+ probe_intensity = xp.abs(current_probe[0]) ** 2
+
+ probe_x0, probe_y0 = get_CoM(
+ probe_intensity, device=self._device, corner_centered=True
+ )
+ shifted_probe = fft_shift(current_probe, -xp.array([probe_x0, probe_y0]), xp)
+
+ return shifted_probe
+
+ def _probe_orthogonalization_constraint(self, current_probe):
+ """
+ Ptychographic probe-orthogonalization constraint.
+ Used to ensure mixed states are orthogonal to each other.
+ Adapted from https://github.com/AdvancedPhotonSource/tike/blob/main/src/tike/ptycho/probe.py#L690
+
+ Parameters
+ --------
+ current_probe: np.ndarray
+ Current probe estimate
+
+ Returns
+ --------
+ constrained_probe: np.ndarray
+ Orthogonalized probe estimate
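+
+ Examples
+ --------
+ A standalone NumPy sketch of the same idea (the arrays are
+ illustrative, not class state):
+
+ >>> import numpy as np
+ >>> rng = np.random.default_rng(0)
+ >>> probes = rng.random((3, 8, 8)) + 1j * rng.random((3, 8, 8))
+ >>> flat = probes.reshape(3, -1)
+ >>> _, evecs = np.linalg.eigh(flat.conj() @ flat.T)
+ >>> ortho = np.tensordot(evecs.T, probes, axes=1).reshape(3, -1)
+ >>> gram = ortho.conj() @ ortho.T
+ >>> bool(np.allclose(gram, np.diag(np.diag(gram))))
+ True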
+ """
+ xp = self._xp
+ n_probes = self._num_probes
+
+ # compute upper half of P* @ P
+ pairwise_dot_product = xp.empty((n_probes, n_probes), dtype=current_probe.dtype)
+
+ for i in range(n_probes):
+ for j in range(i, n_probes):
+ pairwise_dot_product[i, j] = xp.sum(
+ current_probe[i].conj() * current_probe[j]
+ )
+
+ # compute eigenvectors (effectively cheaper way of computing V* from SVD)
+ _, evecs = xp.linalg.eigh(pairwise_dot_product, UPLO="U")
+ current_probe = xp.tensordot(evecs.T, current_probe, axes=1)
+
+ # sort by real-space intensity
+ intensities = xp.sum(xp.abs(current_probe) ** 2, axis=(-2, -1))
+ intensities_order = xp.argsort(intensities, axis=None)[::-1]
+ return current_probe[intensities_order]
+
+ def _constraints(
+ self,
+ current_object,
+ current_probe,
+ current_positions,
+ pure_phase_object,
+ fix_com,
+ fit_probe_aberrations,
+ fit_probe_aberrations_max_angular_order,
+ fit_probe_aberrations_max_radial_order,
+ constrain_probe_amplitude,
+ constrain_probe_amplitude_relative_radius,
+ constrain_probe_amplitude_relative_width,
+ constrain_probe_fourier_amplitude,
+ constrain_probe_fourier_amplitude_max_width_pixels,
+ constrain_probe_fourier_amplitude_constant_intensity,
+ fix_probe_aperture,
+ initial_probe_aperture,
+ fix_positions,
+ global_affine_transformation,
+ gaussian_filter,
+ gaussian_filter_sigma,
+ butterworth_filter,
+ q_lowpass,
+ q_highpass,
+ butterworth_order,
+ tv_denoise,
+ tv_denoise_weight,
+ tv_denoise_inner_iter,
+ orthogonalize_probe,
+ object_positivity,
+ shrinkage_rad,
+ object_mask,
+ ):
+ """
+ Ptychographic constraints operator.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+ current_positions: np.ndarray
+ Current positions estimate
+ pure_phase_object: bool
+ If True, object amplitude is set to unity
+ fix_com: bool
+ If True, probe CoM is fixed to the center
+ fit_probe_aberrations: bool
+ If True, fits the probe aberrations to a low-order expansion
+ fit_probe_aberrations_max_angular_order: bool
+ Max angular order of probe aberrations basis functions
+ fit_probe_aberrations_max_radial_order: bool
+ Max radial order of probe aberrations basis functions
+ constrain_probe_amplitude: bool
+ If True, probe amplitude is constrained by top hat function
+ constrain_probe_amplitude_relative_radius: float
+ Relative location of top-hat inflection point, between 0 and 0.5
+ constrain_probe_amplitude_relative_width: float
+ Relative width of top-hat sigmoid, between 0 and 0.5
+ constrain_probe_fourier_amplitude: bool
+ If True, probe aperture is constrained by fitting a sigmoid for each angular frequency.
+ constrain_probe_fourier_amplitude_max_width_pixels: float
+ Maximum pixel width of fitted sigmoid functions.
+ constrain_probe_fourier_amplitude_constant_intensity: bool
+ If True, the probe aperture is additionally constrained to a constant intensity.
+ fix_probe_aperture: bool
+ If True, probe Fourier amplitude is replaced by the initial probe aperture.
+ initial_probe_aperture: np.ndarray
+ initial probe aperture to use in replacing probe Fourier amplitude
+ fix_positions: bool
+ If True, positions are not updated
+ gaussian_filter: bool
+ If True, applies real-space gaussian filter
+ gaussian_filter_sigma: float
+ Standard deviation of gaussian kernel in A
+ butterworth_filter: bool
+ If True, applies high-pass butterworth filter
+ q_lowpass: float
+ Cut-off frequency in A^-1 for low-pass butterworth filter
+ q_highpass: float
+ Cut-off frequency in A^-1 for high-pass butterworth filter
+ butterworth_order: float
+ Butterworth filter order. Smaller gives a smoother filter
+ orthogonalize_probe: bool
+ If True, probe will be orthogonalized
+ tv_denoise: bool
+ If True, applies TV denoising on object
+ tv_denoise_weight: float
+ Denoising weight. The greater `weight`, the more denoising.
+ tv_denoise_inner_iter: float
+ Number of iterations to run in inner loop of TV denoising
+ object_positivity: bool
+ If True, clips negative potential values
+ shrinkage_rad: float
+ Phase shift in radians to be subtracted from the potential at each iteration
+ object_mask: np.ndarray (boolean)
+ If not None, used to calculate additional shrinkage using masked-mean of object
+
+ Returns
+ --------
+ constrained_object: np.ndarray
+ Constrained object estimate
+ constrained_probe: np.ndarray
+ Constrained probe estimate
+ constrained_positions: np.ndarray
+ Constrained positions estimate
+ """
+
+ if gaussian_filter:
+ current_object = self._object_gaussian_constraint(
+ current_object, gaussian_filter_sigma, pure_phase_object
+ )
+
+ if butterworth_filter:
+ current_object = self._object_butterworth_constraint(
+ current_object,
+ q_lowpass,
+ q_highpass,
+ butterworth_order,
+ )
+
+ if tv_denoise:
+ current_object = self._object_denoise_tv_pylops(
+ current_object, tv_denoise_weight, tv_denoise_inner_iter
+ )
+
+ if shrinkage_rad > 0.0 or object_mask is not None:
+ current_object = self._object_shrinkage_constraint(
+ current_object,
+ shrinkage_rad,
+ object_mask,
+ )
+
+ if self._object_type == "complex":
+ current_object = self._object_threshold_constraint(
+ current_object, pure_phase_object
+ )
+ elif object_positivity:
+ current_object = self._object_positivity_constraint(current_object)
+
+ if fix_com:
+ current_probe = self._probe_center_of_mass_constraint(current_probe)
+
+        # These constraints don't straightforwardly generalize to mixed-state
+        # probes, so raise explicitly rather than failing silently
+        if fix_probe_aperture:
+            raise NotImplementedError(
+                "fix_probe_aperture is not yet implemented for mixed-state probes."
+            )
+        elif constrain_probe_fourier_amplitude:
+            raise NotImplementedError(
+                "constrain_probe_fourier_amplitude is not yet implemented for mixed-state probes."
+            )
+        if fit_probe_aberrations:
+            raise NotImplementedError(
+                "fit_probe_aberrations is not yet implemented for mixed-state probes."
+            )
+        if constrain_probe_amplitude:
+            raise NotImplementedError(
+                "constrain_probe_amplitude is not yet implemented for mixed-state probes."
+            )
+
+ if orthogonalize_probe:
+ current_probe = self._probe_orthogonalization_constraint(current_probe)
+
+ if not fix_positions:
+ current_positions = self._positions_center_of_mass_constraint(
+ current_positions
+ )
+
+ if global_affine_transformation:
+ current_positions = self._positions_affine_transformation_constraint(
+ self._positions_px_initial, current_positions
+ )
+
+ return current_object, current_probe, current_positions
+
+ def reconstruct(
+ self,
+ max_iter: int = 64,
+ reconstruction_method: str = "gradient-descent",
+ reconstruction_parameter: float = 1.0,
+ reconstruction_parameter_a: float = None,
+ reconstruction_parameter_b: float = None,
+ reconstruction_parameter_c: float = None,
+ max_batch_size: int = None,
+ seed_random: int = None,
+ step_size: float = 0.5,
+ normalization_min: float = 1,
+ positions_step_size: float = 0.9,
+ pure_phase_object_iter: int = 0,
+ fix_com: bool = True,
+ orthogonalize_probe: bool = True,
+ fix_probe_iter: int = 0,
+ fix_probe_aperture_iter: int = 0,
+ constrain_probe_amplitude_iter: int = 0,
+ constrain_probe_amplitude_relative_radius: float = 0.5,
+ constrain_probe_amplitude_relative_width: float = 0.05,
+ constrain_probe_fourier_amplitude_iter: int = 0,
+ constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0,
+ constrain_probe_fourier_amplitude_constant_intensity: bool = False,
+ fix_positions_iter: int = np.inf,
+ global_affine_transformation: bool = True,
+ constrain_position_distance: float = None,
+ gaussian_filter_sigma: float = None,
+ gaussian_filter_iter: int = np.inf,
+ fit_probe_aberrations_iter: int = 0,
+ fit_probe_aberrations_max_angular_order: int = 4,
+ fit_probe_aberrations_max_radial_order: int = 4,
+ butterworth_filter_iter: int = np.inf,
+ q_lowpass: float = None,
+ q_highpass: float = None,
+ butterworth_order: float = 2,
+ tv_denoise_iter: int = np.inf,
+ tv_denoise_weight: float = None,
+        tv_denoise_inner_iter: int = 40,
+ object_positivity: bool = True,
+ shrinkage_rad: float = 0.0,
+ fix_potential_baseline: bool = True,
+ switch_object_iter: int = np.inf,
+ store_iterations: bool = False,
+ progress_bar: bool = True,
+ reset: bool = None,
+ ):
+ """
+ Ptychographic reconstruction main method.
+
+ Parameters
+ --------
+ max_iter: int, optional
+ Maximum number of iterations to run
+ reconstruction_method: str, optional
+ Specifies which reconstruction algorithm to use, one of:
+ "generalized-projections",
+ "DM_AP" (or "difference-map_alternating-projections"),
+ "RAAR" (or "relaxed-averaged-alternating-reflections"),
+ "RRR" (or "relax-reflect-reflect"),
+ "SUPERFLIP" (or "charge-flipping"), or
+ "GD" (or "gradient_descent")
+ reconstruction_parameter: float, optional
+ Reconstruction parameter for various reconstruction methods above.
+ reconstruction_parameter_a: float, optional
+ Reconstruction parameter a for reconstruction_method='generalized-projections'.
+ reconstruction_parameter_b: float, optional
+ Reconstruction parameter b for reconstruction_method='generalized-projections'.
+ reconstruction_parameter_c: float, optional
+ Reconstruction parameter c for reconstruction_method='generalized-projections'.
+ max_batch_size: int, optional
+ Max number of probes to update at once
+ seed_random: int, optional
+ Seeds the random number generator, only applicable when max_batch_size is not None
+ step_size: float, optional
+ Update step size
+ normalization_min: float, optional
+ Probe normalization minimum as a fraction of the maximum overlap intensity
+ positions_step_size: float, optional
+ Positions update step size
+ pure_phase_object_iter: int, optional
+ Number of iterations where object amplitude is set to unity
+        fix_com: bool, optional
+            If True, fixes center of mass of probe
+        orthogonalize_probe: bool, optional
+            If True, probe will be orthogonalized
+ fix_probe_iter: int, optional
+ Number of iterations to run with a fixed probe before updating probe estimate
+ fix_probe_aperture_iter: int, optional
+ Number of iterations to run with a fixed probe fourier amplitude before updating probe estimate
+ constrain_probe_amplitude_iter: int, optional
+ Number of iterations to run while constraining the real-space probe with a top-hat support.
+ constrain_probe_amplitude_relative_radius: float
+ Relative location of top-hat inflection point, between 0 and 0.5
+ constrain_probe_amplitude_relative_width: float
+ Relative width of top-hat sigmoid, between 0 and 0.5
+ constrain_probe_fourier_amplitude_iter: int, optional
+ Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency.
+ constrain_probe_fourier_amplitude_max_width_pixels: float
+ Maximum pixel width of fitted sigmoid functions.
+ constrain_probe_fourier_amplitude_constant_intensity: bool
+ If True, the probe aperture is additionally constrained to a constant intensity.
+ fix_positions_iter: int, optional
+ Number of iterations to run with fixed positions before updating positions estimate
+ constrain_position_distance: float
+ Distance to constrain position correction within original field of view in A
+ global_affine_transformation: bool, optional
+ If True, positions are assumed to be a global affine transform from initial scan
+ gaussian_filter_sigma: float, optional
+ Standard deviation of gaussian kernel in A
+ gaussian_filter_iter: int, optional
+ Number of iterations to run using object smoothness constraint
+ fit_probe_aberrations_iter: int, optional
+ Number of iterations to run while fitting the probe aberrations to a low-order expansion
+        fit_probe_aberrations_max_angular_order: int
+            Max angular order of probe aberrations basis functions
+        fit_probe_aberrations_max_radial_order: int
+            Max radial order of probe aberrations basis functions
+ butterworth_filter_iter: int, optional
+            Number of iterations to run using high-pass butterworth filter
+ q_lowpass: float
+ Cut-off frequency in A^-1 for low-pass butterworth filter
+ q_highpass: float
+ Cut-off frequency in A^-1 for high-pass butterworth filter
+ butterworth_order: float
+ Butterworth filter order. Smaller gives a smoother filter
+ tv_denoise_iter: int, optional
+ Number of iterations to run using tv denoise filter on object
+ tv_denoise_weight: float
+ Denoising weight. The greater `weight`, the more denoising.
+        tv_denoise_inner_iter: int
+            Number of iterations to run in inner loop of TV denoising
+ object_positivity: bool, optional
+ If True, forces object to be positive
+ shrinkage_rad: float
+ Phase shift in radians to be subtracted from the potential at each iteration
+ fix_potential_baseline: bool
+ If true, the potential mean outside the FOV is forced to zero at each iteration
+ switch_object_iter: int, optional
+ Iteration to switch object type between 'complex' and 'potential' or between
+ 'potential' and 'complex'
+ store_iterations: bool, optional
+ If True, reconstructed objects and probes are stored at each iteration
+ progress_bar: bool, optional
+ If True, reconstruction progress is displayed
+ reset: bool, optional
+ If True, previous reconstructions are ignored
+
+ Returns
+ --------
+ self: PtychographicReconstruction
+ Self to accommodate chaining
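+
+        Examples
+        --------
+        A minimal usage sketch (hypothetical values; `ptycho` is assumed to be a
+        preprocessed reconstruction instance):
+
+        >>> ptycho = ptycho.reconstruct(
+        ...     max_iter=32,
+        ...     reconstruction_method="gradient-descent",
+        ...     step_size=0.5,
+        ...     max_batch_size=128,
+        ...     store_iterations=True,
+        ... )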
+ """
+ asnumpy = self._asnumpy
+ xp = self._xp
+
+ # Reconstruction method
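+        # Each named method maps onto the generalized-projection triplet (a,b,c)
+        # documented in _projection_sets_fourier_projection:
+        #   DM_AP(α):  a = -α,     b = 1, c = 1 + α
+        #   RAAR(β):   a = 1 - 2β, b = β, c = 2
+        #   RRR(γ):    a = -γ,     b = γ, c = 2
+        #   SUPERFLIP: a = 0,      b = 1, c = 2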
+
+ if reconstruction_method == "generalized-projections":
+ if (
+ reconstruction_parameter_a is None
+ or reconstruction_parameter_b is None
+ or reconstruction_parameter_c is None
+ ):
+ raise ValueError(
+ (
+ "reconstruction_parameter_a/b/c must all be specified "
+ "when using reconstruction_method='generalized-projections'."
+ )
+ )
+
+ use_projection_scheme = True
+ projection_a = reconstruction_parameter_a
+ projection_b = reconstruction_parameter_b
+ projection_c = reconstruction_parameter_c
+ step_size = None
+ elif (
+ reconstruction_method == "DM_AP"
+ or reconstruction_method == "difference-map_alternating-projections"
+ ):
+ if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0:
+ raise ValueError("reconstruction_parameter must be between 0-1.")
+
+ use_projection_scheme = True
+ projection_a = -reconstruction_parameter
+ projection_b = 1
+ projection_c = 1 + reconstruction_parameter
+ step_size = None
+ elif (
+ reconstruction_method == "RAAR"
+ or reconstruction_method == "relaxed-averaged-alternating-reflections"
+ ):
+ if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0:
+ raise ValueError("reconstruction_parameter must be between 0-1.")
+
+ use_projection_scheme = True
+ projection_a = 1 - 2 * reconstruction_parameter
+ projection_b = reconstruction_parameter
+ projection_c = 2
+ step_size = None
+ elif (
+ reconstruction_method == "RRR"
+ or reconstruction_method == "relax-reflect-reflect"
+ ):
+ if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0:
+ raise ValueError("reconstruction_parameter must be between 0-2.")
+
+ use_projection_scheme = True
+ projection_a = -reconstruction_parameter
+ projection_b = reconstruction_parameter
+ projection_c = 2
+ step_size = None
+ elif (
+ reconstruction_method == "SUPERFLIP"
+ or reconstruction_method == "charge-flipping"
+ ):
+ use_projection_scheme = True
+ projection_a = 0
+ projection_b = 1
+ projection_c = 2
+ reconstruction_parameter = None
+ step_size = None
+ elif (
+ reconstruction_method == "GD" or reconstruction_method == "gradient-descent"
+ ):
+ use_projection_scheme = False
+ projection_a = None
+ projection_b = None
+ projection_c = None
+ reconstruction_parameter = None
+ else:
+ raise ValueError(
+ (
+ "reconstruction_method must be one of 'generalized-projections', "
+ "'DM_AP' (or 'difference-map_alternating-projections'), "
+ "'RAAR' (or 'relaxed-averaged-alternating-reflections'), "
+ "'RRR' (or 'relax-reflect-reflect'), "
+ "'SUPERFLIP' (or 'charge-flipping'), "
+ f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}."
+ )
+ )
+
+        if max_batch_size is not None and use_projection_scheme:
+            raise ValueError(
+                (
+                    "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. "
+                    "Use reconstruction_method='GD' or set max_batch_size=None."
+                )
+            )
+
+        if self._verbose:
+ if switch_object_iter > max_iter:
+ first_line = f"Performing {max_iter} iterations using a {self._object_type} object type, "
+ else:
+ switch_object_type = (
+ "complex" if self._object_type == "potential" else "potential"
+ )
+ first_line = (
+ f"Performing {switch_object_iter} iterations using a {self._object_type} object type and "
+ f"{max_iter - switch_object_iter} iterations using a {switch_object_type} object type, "
+ )
+            if max_batch_size is not None:
+                print(
+                    (
+                        first_line + f"with the {reconstruction_method} algorithm, "
+                        f"with normalization_min: {normalization_min} and step_size: {step_size}, "
+                        f"in batches of max {max_batch_size} measurements."
+                    )
+                )
+
+ else:
+ if reconstruction_parameter is not None:
+ if np.array(reconstruction_parameter).shape == (3,):
+ print(
+ (
+ first_line
+ + f"with the {reconstruction_method} algorithm, "
+ f"with normalization_min: {normalization_min} and (a,b,c): {reconstruction_parameter}."
+ )
+ )
+ else:
+ print(
+ (
+ first_line
+ + f"with the {reconstruction_method} algorithm, "
+ f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}."
+ )
+ )
+ else:
+                    if step_size is not None:
+                        print(
+                            (
+                                first_line
+                                + f"with the {reconstruction_method} algorithm, "
+                                f"with normalization_min: {normalization_min} and step_size: {step_size}."
+                            )
+                        )
+                    else:
+                        print(
+                            (
+                                first_line
+                                + f"with the {reconstruction_method} algorithm, "
+                                f"with normalization_min: {normalization_min}."
+                            )
+                        )
+
+ # Batching
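+        # `shuffled_indices` sets the (optionally randomized) measurement order;
+        # `unshuffled_indices` stores the inverse permutation, so batched position
+        # updates can be scattered back into the original ordering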
+ shuffled_indices = np.arange(self._num_diffraction_patterns)
+ unshuffled_indices = np.zeros_like(shuffled_indices)
+
+ if max_batch_size is not None:
+            np.random.seed(seed_random)
+ else:
+ max_batch_size = self._num_diffraction_patterns
+
+ # initialization
+ if store_iterations and (not hasattr(self, "object_iterations") or reset):
+ self.object_iterations = []
+ self.probe_iterations = []
+
+ if reset:
+ self._object = self._object_initial.copy()
+ self.error_iterations = []
+ self._probe = self._probe_initial.copy()
+ self._positions_px = self._positions_px_initial.copy()
+ self._positions_px_fractional = self._positions_px - xp.round(
+ self._positions_px
+ )
+ (
+ self._vectorized_patch_indices_row,
+ self._vectorized_patch_indices_col,
+ ) = self._extract_vectorized_patch_indices()
+ self._exit_waves = None
+ self._object_type = self._object_type_initial
+ if hasattr(self, "_tf"):
+ del self._tf
+ elif reset is None:
+ if hasattr(self, "error"):
+ warnings.warn(
+ (
+ "Continuing reconstruction from previous result. "
+ "Use reset=True for a fresh start."
+ ),
+ UserWarning,
+ )
+ else:
+ self.error_iterations = []
+ self._exit_waves = None
+
+ # main loop
+ for a0 in tqdmnd(
+ max_iter,
+ desc="Reconstructing object and probe",
+ unit=" iter",
+ disable=not progress_bar,
+ ):
+ error = 0.0
+
+ if a0 == switch_object_iter:
+ if self._object_type == "potential":
+ self._object_type = "complex"
+ self._object = xp.exp(1j * self._object)
+ elif self._object_type == "complex":
+ self._object_type = "potential"
+ self._object = xp.angle(self._object)
+
+ # randomize
+ if not use_projection_scheme:
+ np.random.shuffle(shuffled_indices)
+ unshuffled_indices[shuffled_indices] = np.arange(
+ self._num_diffraction_patterns
+ )
+ positions_px = self._positions_px.copy()[shuffled_indices]
+
+ for start, end in generate_batches(
+ self._num_diffraction_patterns, max_batch=max_batch_size
+ ):
+ # batch indices
+ self._positions_px = positions_px[start:end]
+ self._positions_px_fractional = self._positions_px - xp.round(
+ self._positions_px
+ )
+ (
+ self._vectorized_patch_indices_row,
+ self._vectorized_patch_indices_col,
+ ) = self._extract_vectorized_patch_indices()
+ amplitudes = self._amplitudes[shuffled_indices[start:end]]
+
+ # forward operator
+ (
+ shifted_probes,
+ object_patches,
+ overlap,
+ self._exit_waves,
+ batch_error,
+ ) = self._forward(
+ self._object,
+ self._probe,
+ amplitudes,
+ self._exit_waves,
+ use_projection_scheme,
+ projection_a,
+ projection_b,
+ projection_c,
+ )
+
+ # adjoint operator
+ self._object, self._probe = self._adjoint(
+ self._object,
+ self._probe,
+ object_patches,
+ shifted_probes,
+ self._exit_waves,
+ use_projection_scheme=use_projection_scheme,
+ step_size=step_size,
+ normalization_min=normalization_min,
+ fix_probe=a0 < fix_probe_iter,
+ )
+
+ # position correction
+ if a0 >= fix_positions_iter:
+ positions_px[start:end] = self._position_correction(
+ self._object,
+ shifted_probes[:, 0],
+ overlap[:, 0],
+ amplitudes,
+ self._positions_px,
+ positions_step_size,
+ constrain_position_distance,
+ )
+
+ error += batch_error
+
+ # Normalize Error
+ error /= self._mean_diffraction_intensity * self._num_diffraction_patterns
+
+ # constraints
+ self._positions_px = positions_px.copy()[unshuffled_indices]
+ self._object, self._probe, self._positions_px = self._constraints(
+ self._object,
+ self._probe,
+ self._positions_px,
+ fix_com=fix_com and a0 >= fix_probe_iter,
+ constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter
+ and a0 >= fix_probe_iter,
+ constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius,
+ constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width,
+ constrain_probe_fourier_amplitude=a0
+ < constrain_probe_fourier_amplitude_iter
+ and a0 >= fix_probe_iter,
+ constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels,
+ constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity,
+ fit_probe_aberrations=a0 < fit_probe_aberrations_iter
+ and a0 >= fix_probe_iter,
+ fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order,
+ fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order,
+ fix_probe_aperture=a0 < fix_probe_aperture_iter,
+ initial_probe_aperture=self._probe_initial_aperture,
+ fix_positions=a0 < fix_positions_iter,
+ global_affine_transformation=global_affine_transformation,
+ gaussian_filter=a0 < gaussian_filter_iter
+ and gaussian_filter_sigma is not None,
+ gaussian_filter_sigma=gaussian_filter_sigma,
+ butterworth_filter=a0 < butterworth_filter_iter
+ and (q_lowpass is not None or q_highpass is not None),
+ q_lowpass=q_lowpass,
+ q_highpass=q_highpass,
+ butterworth_order=butterworth_order,
+ orthogonalize_probe=orthogonalize_probe,
+ tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None,
+ tv_denoise_weight=tv_denoise_weight,
+ tv_denoise_inner_iter=tv_denoise_inner_iter,
+ object_positivity=object_positivity,
+ shrinkage_rad=shrinkage_rad,
+ object_mask=self._object_fov_mask_inverse
+ if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0
+ else None,
+ pure_phase_object=a0 < pure_phase_object_iter
+ and self._object_type == "complex",
+ )
+
+ self.error_iterations.append(error.item())
+ if store_iterations:
+ self.object_iterations.append(asnumpy(self._object.copy()))
+ self.probe_iterations.append(self.probe_centered)
+
+ # store result
+ self.object = asnumpy(self._object)
+ self.probe = self.probe_centered
+ self.error = error.item()
+
+ if self._device == "gpu":
+ xp._default_memory_pool.free_all_blocks()
+ xp.clear_memo()
+
+ return self
+
+ def _visualize_last_iteration_figax(
+ self,
+ fig,
+ object_ax,
+        convergence_ax=None,
+        cbar: bool = True,
+ padding: int = 0,
+ **kwargs,
+ ):
+ """
+ Displays last reconstructed object on a given fig/ax.
+
+ Parameters
+ --------
+ fig: Figure
+ Matplotlib figure object_ax lives in
+ object_ax: Axes
+ Matplotlib axes to plot reconstructed object in
+ convergence_ax: Axes, optional
+ Matplotlib axes to plot convergence plot in
+ cbar: bool, optional
+ If true, displays a colorbar
+ padding : int, optional
+ Pixels to pad by post rotating-cropping object
+ """
+ cmap = kwargs.pop("cmap", "magma")
+
+ if self._object_type == "complex":
+ obj = np.angle(self.object)
+ else:
+ obj = self.object
+
+ rotated_object = self._crop_rotate_object_fov(obj, padding=padding)
+ rotated_shape = rotated_object.shape
+
+ extent = [
+ 0,
+ self.sampling[1] * rotated_shape[1],
+ self.sampling[0] * rotated_shape[0],
+ 0,
+ ]
+
+ im = object_ax.imshow(
+ rotated_object,
+ extent=extent,
+ cmap=cmap,
+ **kwargs,
+ )
+
+ if cbar:
+ divider = make_axes_locatable(object_ax)
+ ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
+ fig.add_axes(ax_cb)
+ fig.colorbar(im, cax=ax_cb)
+
+ if convergence_ax is not None and hasattr(self, "error_iterations"):
+ kwargs.pop("vmin", None)
+ kwargs.pop("vmax", None)
+ errors = self.error_iterations
+ convergence_ax.semilogy(np.arange(len(errors)), errors, **kwargs)
+
+ def _visualize_last_iteration(
+ self,
+ fig,
+ cbar: bool,
+ plot_convergence: bool,
+ plot_probe: bool,
+ plot_fourier_probe: bool,
+ padding: int,
+ **kwargs,
+ ):
+ """
+ Displays last reconstructed object and probe iterations.
+
+ Parameters
+ --------
+        fig: Figure, optional
+            Matplotlib figure to place Gridspec in; if None, a new figure is created
+        plot_convergence: bool, optional
+ If true, the normalized mean squared error (NMSE) plot is displayed
+ cbar: bool, optional
+ If true, displays a colorbar
+ plot_probe: bool, optional
+ If true, the reconstructed complex probe is displayed
+ plot_fourier_probe: bool, optional
+ If true, the reconstructed complex Fourier probe is displayed
+ padding : int, optional
+ Pixels to pad by post rotating-cropping object
+ """
+ figsize = kwargs.pop("figsize", (8, 5))
+ cmap = kwargs.pop("cmap", "magma")
+
+ if plot_fourier_probe:
+ chroma_boost = kwargs.pop("chroma_boost", 2)
+ else:
+ chroma_boost = kwargs.pop("chroma_boost", 1)
+
+ if self._object_type == "complex":
+ obj = np.angle(self.object)
+ else:
+ obj = self.object
+
+ rotated_object = self._crop_rotate_object_fov(obj, padding=padding)
+ rotated_shape = rotated_object.shape
+
+ extent = [
+ 0,
+ self.sampling[1] * rotated_shape[1],
+ self.sampling[0] * rotated_shape[0],
+ 0,
+ ]
+
+ if plot_fourier_probe:
+ probe_extent = [
+ -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2,
+ self.angular_sampling[1] * self._region_of_interest_shape[1] / 2,
+ self.angular_sampling[0] * self._region_of_interest_shape[0] / 2,
+ -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2,
+ ]
+ elif plot_probe:
+ probe_extent = [
+ 0,
+ self.sampling[1] * self._region_of_interest_shape[1],
+ self.sampling[0] * self._region_of_interest_shape[0],
+ 0,
+ ]
+
+ if plot_convergence:
+ if plot_probe or plot_fourier_probe:
+ spec = GridSpec(
+ ncols=2,
+ nrows=2,
+ height_ratios=[4, 1],
+ hspace=0.15,
+ width_ratios=[
+ (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]),
+ 1,
+ ],
+ wspace=0.35,
+ )
+ else:
+ spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0.15)
+ else:
+ if plot_probe or plot_fourier_probe:
+ spec = GridSpec(
+ ncols=2,
+ nrows=1,
+ width_ratios=[
+ (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]),
+ 1,
+ ],
+ wspace=0.35,
+ )
+ else:
+ spec = GridSpec(ncols=1, nrows=1)
+
+ if fig is None:
+ fig = plt.figure(figsize=figsize)
+
+ if plot_probe or plot_fourier_probe:
+ # Object
+ ax = fig.add_subplot(spec[0, 0])
+ im = ax.imshow(
+ rotated_object,
+ extent=extent,
+ cmap=cmap,
+ **kwargs,
+ )
+ ax.set_ylabel("x [A]")
+ ax.set_xlabel("y [A]")
+ if self._object_type == "potential":
+ ax.set_title("Reconstructed object potential")
+ elif self._object_type == "complex":
+ ax.set_title("Reconstructed object phase")
+
+ if cbar:
+ divider = make_axes_locatable(ax)
+ ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
+ fig.add_axes(ax_cb)
+ fig.colorbar(im, cax=ax_cb)
+
+ # Probe
+ kwargs.pop("vmin", None)
+ kwargs.pop("vmax", None)
+ ax = fig.add_subplot(spec[0, 1])
+
+ if plot_fourier_probe:
+ probe_array = Complex2RGB(
+ self.probe_fourier[0],
+ chroma_boost=chroma_boost,
+ )
+ ax.set_title("Reconstructed Fourier probe[0]")
+ ax.set_ylabel("kx [mrad]")
+ ax.set_xlabel("ky [mrad]")
+ else:
+ probe_array = Complex2RGB(
+ self.probe[0],
+ power=2,
+ chroma_boost=chroma_boost,
+ )
+ ax.set_title("Reconstructed probe[0] intensity")
+ ax.set_ylabel("x [A]")
+ ax.set_xlabel("y [A]")
+
+ im = ax.imshow(
+ probe_array,
+ extent=probe_extent,
+ )
+
+ if cbar:
+ divider = make_axes_locatable(ax)
+ ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
+ add_colorbar_arg(ax_cb, chroma_boost=chroma_boost)
+
+ else:
+ ax = fig.add_subplot(spec[0])
+ im = ax.imshow(
+ rotated_object,
+ extent=extent,
+ cmap=cmap,
+ **kwargs,
+ )
+ ax.set_ylabel("x [A]")
+ ax.set_xlabel("y [A]")
+ if self._object_type == "potential":
+ ax.set_title("Reconstructed object potential")
+ elif self._object_type == "complex":
+ ax.set_title("Reconstructed object phase")
+
+ if cbar:
+ divider = make_axes_locatable(ax)
+ ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
+ fig.add_axes(ax_cb)
+ fig.colorbar(im, cax=ax_cb)
+
+ if plot_convergence and hasattr(self, "error_iterations"):
+ kwargs.pop("vmin", None)
+ kwargs.pop("vmax", None)
+ errors = np.array(self.error_iterations)
+            if plot_probe or plot_fourier_probe:
+ ax = fig.add_subplot(spec[1, :])
+ else:
+ ax = fig.add_subplot(spec[1])
+ ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs)
+ ax.set_ylabel("NMSE")
+ ax.set_xlabel("Iteration number")
+ ax.yaxis.tick_right()
+
+ fig.suptitle(f"Normalized mean squared error: {self.error:.3e}")
+ spec.tight_layout(fig)
+
+ def _visualize_all_iterations(
+ self,
+ fig,
+ cbar: bool,
+ plot_convergence: bool,
+ plot_probe: bool,
+ plot_fourier_probe: bool,
+ iterations_grid: Tuple[int, int],
+ padding: int,
+ **kwargs,
+ ):
+ """
+ Displays all reconstructed object and probe iterations.
+
+ Parameters
+ --------
+ fig: Figure
+ Matplotlib figure to place Gridspec in
+ plot_convergence: bool, optional
+ If true, the normalized mean squared error (NMSE) plot is displayed
+ iterations_grid: Tuple[int,int]
+ Grid dimensions to plot reconstruction iterations
+ cbar: bool, optional
+ If true, displays a colorbar
+ plot_probe: bool, optional
+ If true, the reconstructed complex probe is displayed
+ plot_fourier_probe: bool, optional
+ If true, the reconstructed complex Fourier probe is displayed
+ padding : int, optional
+ Pixels to pad by post rotating-cropping object
+ """
+ asnumpy = self._asnumpy
+
+ if not hasattr(self, "object_iterations"):
+ raise ValueError(
+ (
+ "Object and probe iterations were not saved during reconstruction. "
+ "Please re-run using store_iterations=True."
+ )
+ )
+
+ if iterations_grid == "auto":
+ num_iter = len(self.error_iterations)
+
+ if num_iter == 1:
+ return self._visualize_last_iteration(
+ fig=fig,
+ plot_convergence=plot_convergence,
+ plot_probe=plot_probe,
+ plot_fourier_probe=plot_fourier_probe,
+ cbar=cbar,
+ padding=padding,
+ **kwargs,
+ )
+ elif plot_probe or plot_fourier_probe:
+ iterations_grid = (2, 4) if num_iter > 4 else (2, num_iter)
+ else:
+ iterations_grid = (2, 4) if num_iter > 8 else (2, num_iter // 2)
+ else:
+            if (plot_probe or plot_fourier_probe) and iterations_grid[0] != 2:
+                raise ValueError(
+                    "iterations_grid must have a first dimension of 2 when plotting probes."
+                )
+
+ auto_figsize = (
+ (3 * iterations_grid[1], 3 * iterations_grid[0] + 1)
+ if plot_convergence
+ else (3 * iterations_grid[1], 3 * iterations_grid[0])
+ )
+ figsize = kwargs.pop("figsize", auto_figsize)
+ cmap = kwargs.pop("cmap", "magma")
+
+ if plot_fourier_probe:
+ chroma_boost = kwargs.pop("chroma_boost", 2)
+ else:
+ chroma_boost = kwargs.pop("chroma_boost", 1)
+
+ errors = np.array(self.error_iterations)
+
+ objects = []
+ object_type = []
+
+ for obj in self.object_iterations:
+ if np.iscomplexobj(obj):
+ obj = np.angle(obj)
+ object_type.append("phase")
+ else:
+ object_type.append("potential")
+ objects.append(self._crop_rotate_object_fov(obj, padding=padding))
+
+ if plot_probe or plot_fourier_probe:
+ total_grids = (np.prod(iterations_grid) / 2).astype("int")
+ probes = self.probe_iterations
+ else:
+ total_grids = np.prod(iterations_grid)
+ max_iter = len(objects) - 1
+ grid_range = range(0, max_iter + 1, max_iter // (total_grids - 1))
+
+ extent = [
+ 0,
+ self.sampling[1] * objects[0].shape[1],
+ self.sampling[0] * objects[0].shape[0],
+ 0,
+ ]
+
+ if plot_fourier_probe:
+ probe_extent = [
+ -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2,
+ self.angular_sampling[1] * self._region_of_interest_shape[1] / 2,
+ self.angular_sampling[0] * self._region_of_interest_shape[0] / 2,
+ -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2,
+ ]
+ elif plot_probe:
+ probe_extent = [
+ 0,
+ self.sampling[1] * self._region_of_interest_shape[1],
+ self.sampling[0] * self._region_of_interest_shape[0],
+ 0,
+ ]
+
+ if plot_convergence:
+ if plot_probe or plot_fourier_probe:
+ spec = GridSpec(ncols=1, nrows=3, height_ratios=[4, 4, 1], hspace=0)
+ else:
+ spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0)
+ else:
+ if plot_probe or plot_fourier_probe:
+ spec = GridSpec(ncols=1, nrows=2)
+ else:
+ spec = GridSpec(ncols=1, nrows=1)
+
+ if fig is None:
+ fig = plt.figure(figsize=figsize)
+
+ grid = ImageGrid(
+ fig,
+ spec[0],
+ nrows_ncols=(1, iterations_grid[1])
+ if (plot_probe or plot_fourier_probe)
+ else iterations_grid,
+ axes_pad=(0.75, 0.5) if cbar else 0.5,
+ cbar_mode="each" if cbar else None,
+ cbar_pad="2.5%" if cbar else None,
+ )
+
+ for n, ax in enumerate(grid):
+ im = ax.imshow(
+ objects[grid_range[n]],
+ extent=extent,
+ cmap=cmap,
+ **kwargs,
+ )
+ ax.set_title(f"Iter: {grid_range[n]} {object_type[grid_range[n]]}")
+ ax.set_ylabel("x [A]")
+ ax.set_xlabel("y [A]")
+ if cbar:
+ grid.cbar_axes[n].colorbar(im)
+
+ if plot_probe or plot_fourier_probe:
+ kwargs.pop("vmin", None)
+ kwargs.pop("vmax", None)
+ grid = ImageGrid(
+ fig,
+ spec[1],
+ nrows_ncols=(1, iterations_grid[1]),
+ axes_pad=(0.75, 0.5) if cbar else 0.5,
+ cbar_mode="each" if cbar else None,
+ cbar_pad="2.5%" if cbar else None,
+ )
+
+ for n, ax in enumerate(grid):
+ if plot_fourier_probe:
+ probe_array = Complex2RGB(
+ asnumpy(
+ self._return_fourier_probe_from_centered_probe(
+ probes[grid_range[n]][0]
+ )
+ ),
+ chroma_boost=chroma_boost,
+ )
+ ax.set_title(f"Iter: {grid_range[n]} Fourier probe[0]")
+ ax.set_ylabel("kx [mrad]")
+ ax.set_xlabel("ky [mrad]")
+ else:
+ probe_array = Complex2RGB(
+ probes[grid_range[n]][0],
+ power=2,
+ chroma_boost=chroma_boost,
+ )
+ ax.set_title(f"Iter: {grid_range[n]} probe[0] intensity")
+ ax.set_ylabel("x [A]")
+ ax.set_xlabel("y [A]")
+
+ im = ax.imshow(
+ probe_array,
+ extent=probe_extent,
+ )
+
+ if cbar:
+ add_colorbar_arg(
+ grid.cbar_axes[n],
+ chroma_boost=chroma_boost,
+ )
+
+ if plot_convergence:
+ kwargs.pop("vmin", None)
+ kwargs.pop("vmax", None)
+            if plot_probe or plot_fourier_probe:
+ ax2 = fig.add_subplot(spec[2])
+ else:
+ ax2 = fig.add_subplot(spec[1])
+ ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs)
+ ax2.set_ylabel("NMSE")
+ ax2.set_xlabel("Iteration number")
+ ax2.yaxis.tick_right()
+
+ spec.tight_layout(fig)
+
+ def visualize(
+ self,
+ fig=None,
+ iterations_grid: Tuple[int, int] = None,
+ plot_convergence: bool = True,
+ plot_probe: bool = True,
+ plot_fourier_probe: bool = False,
+ cbar: bool = True,
+ padding: int = 0,
+ **kwargs,
+ ):
+ """
+ Displays reconstructed object and probe.
+
+ Parameters
+ --------
+ fig: Figure
+ Matplotlib figure to place Gridspec in
+ plot_convergence: bool, optional
+ If true, the normalized mean squared error (NMSE) plot is displayed
+ iterations_grid: Tuple[int,int]
+ Grid dimensions to plot reconstruction iterations
+ cbar: bool, optional
+ If true, displays a colorbar
+ plot_probe: bool, optional
+ If true, the reconstructed complex probe is displayed
+ plot_fourier_probe: bool, optional
+ If true, the reconstructed complex Fourier probe is displayed
+ padding : int, optional
+ Pixels to pad by post rotating-cropping object
+
+ Returns
+ --------
+ self: PtychographicReconstruction
+ Self to accommodate chaining
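+
+        Examples
+        --------
+        A minimal call sketch (assuming `ptycho` holds a finished reconstruction):
+
+        >>> ptycho.visualize(plot_convergence=True, plot_probe=True)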
+ """
+
+ if iterations_grid is None:
+ self._visualize_last_iteration(
+ fig=fig,
+ plot_convergence=plot_convergence,
+ plot_probe=plot_probe,
+ plot_fourier_probe=plot_fourier_probe,
+ cbar=cbar,
+ padding=padding,
+ **kwargs,
+ )
+ else:
+ self._visualize_all_iterations(
+ fig=fig,
+ plot_convergence=plot_convergence,
+ iterations_grid=iterations_grid,
+ plot_probe=plot_probe,
+ plot_fourier_probe=plot_fourier_probe,
+ cbar=cbar,
+ padding=padding,
+ **kwargs,
+ )
+
+ return self
+
+ def show_fourier_probe(
+ self, probe=None, scalebar=True, pixelsize=None, pixelunits=None, **kwargs
+ ):
+ """
+        Plot probe in Fourier space
+
+        Parameters
+        ----------
+        probe: complex array, optional
+            if None is specified, uses the `probe_fourier` property
+        scalebar: bool, optional
+            if True, adds scalebar to probe
+        pixelsize: float, optional
+            default is probe reciprocal sampling
+        pixelunits: str, optional
+            units for scalebar, default is A^-1
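+
+        Examples
+        --------
+        A minimal call sketch (assuming `ptycho` holds a finished reconstruction):
+
+        >>> ptycho.show_fourier_probe(scalebar=True)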
+ """
+ asnumpy = self._asnumpy
+
+ if probe is None:
+ probe = list(self.probe_fourier)
+ else:
+ if isinstance(probe, np.ndarray) and probe.ndim == 2:
+ probe = [probe]
+ probe = [asnumpy(self._return_fourier_probe(pr)) for pr in probe]
+
+ if pixelsize is None:
+ pixelsize = self._reciprocal_sampling[1]
+ if pixelunits is None:
+ pixelunits = r"$\AA^{-1}$"
+
+ chroma_boost = kwargs.pop("chroma_boost", 2)
+
+ show_complex(
+ probe if len(probe) > 1 else probe[0],
+ scalebar=scalebar,
+ pixelsize=pixelsize,
+ pixelunits=pixelunits,
+ ticks=False,
+ chroma_boost=chroma_boost,
+ **kwargs,
+ )
+
+ def _return_self_consistency_errors(
+ self,
+ max_batch_size=None,
+ ):
+ """Compute the self-consistency errors for each probe position"""
+
+ xp = self._xp
+ asnumpy = self._asnumpy
+
+ # Batch-size
+ if max_batch_size is None:
+ max_batch_size = self._num_diffraction_patterns
+
+ # Re-initialize fractional positions and vector patches
+ errors = np.array([])
+ positions_px = self._positions_px.copy()
+
+ for start, end in generate_batches(
+ self._num_diffraction_patterns, max_batch=max_batch_size
+ ):
+ # batch indices
+ self._positions_px = positions_px[start:end]
+ self._positions_px_fractional = self._positions_px - xp.round(
+ self._positions_px
+ )
+ (
+ self._vectorized_patch_indices_row,
+ self._vectorized_patch_indices_col,
+ ) = self._extract_vectorized_patch_indices()
+ amplitudes = self._amplitudes[start:end]
+
+ # Overlaps
+ _, _, overlap = self._overlap_projection(self._object, self._probe)
+ fourier_overlap = xp.fft.fft2(overlap)
+ intensity_norm = xp.sqrt(xp.sum(xp.abs(fourier_overlap) ** 2, axis=1))
+
+ # Normalized mean-squared errors
+ batch_errors = xp.sum(
+ xp.abs(amplitudes - intensity_norm) ** 2, axis=(-2, -1)
+ )
+            errors = np.hstack((errors, asnumpy(batch_errors)))
+
+ self._positions_px = positions_px.copy()
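+        # normalize per-position errors by the mean diffraction intensity,
+        # analogous to the NMSE normalization used in reconstruct()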
+ errors /= self._mean_diffraction_intensity
+
+ return asnumpy(errors)
diff --git a/py4DSTEM/process/phase/iterative_multislice_ptychography.py b/py4DSTEM/process/phase/iterative_multislice_ptychography.py
new file mode 100644
index 000000000..93e32b079
--- /dev/null
+++ b/py4DSTEM/process/phase/iterative_multislice_ptychography.py
@@ -0,0 +1,3439 @@
+"""
+Module for reconstructing phase objects from 4DSTEM datasets using iterative methods,
+namely multislice ptychography.
+"""
+
+import warnings
+from typing import Mapping, Sequence, Tuple, Union
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pylops
+from matplotlib.gridspec import GridSpec
+from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable
+from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg, show_complex
+
+try:
+ import cupy as cp
+except ModuleNotFoundError:
+ cp = np
+
+from emdfile import Custom, tqdmnd
+from py4DSTEM import DataCube
+from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction
+from py4DSTEM.process.phase.utils import (
+ ComplexProbe,
+ fft_shift,
+ generate_batches,
+ polar_aliases,
+ polar_symbols,
+ spatial_frequencies,
+)
+from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar
+from scipy.ndimage import rotate
+
+warnings.simplefilter(action="always", category=UserWarning)
+
+
+class MultislicePtychographicReconstruction(PtychographicReconstruction):
+ """
+ Multislice Ptychographic Reconstruction Class.
+
+ Diffraction intensities dimensions : (Rx,Ry,Qx,Qy)
+ Reconstructed probe dimensions : (Sx,Sy)
+ Reconstructed object dimensions : (T,Px,Py)
+
+ such that (Sx,Sy) is the region-of-interest (ROI) size of our probe
+ and (Px,Py) is the padded-object size we position our ROI around in
+ each of the T slices.
+
+ Parameters
+ ----------
+ energy: float
+ The electron energy of the wave functions in eV
+ num_slices: int
+ Number of slices to use in the forward model
+ slice_thicknesses: float or Sequence[float]
+ Slice thicknesses in angstroms. If float, all slices are assigned the same thickness
+ datacube: DataCube, optional
+ Input 4D diffraction pattern intensities
+ semiangle_cutoff: float, optional
+ Semiangle cutoff for the initial probe guess in mrad
+ semiangle_cutoff_pixels: float, optional
+ Semiangle cutoff for the initial probe guess in pixels
+ rolloff: float, optional
+ Semiangle rolloff for the initial probe guess
+ vacuum_probe_intensity: np.ndarray, optional
+ Vacuum probe to use as intensity aperture for initial probe guess
+ polar_parameters: dict, optional
+ Mapping from aberration symbols to their corresponding values. All aberration
+ magnitudes should be given in Å and angles should be given in radians.
+ object_padding_px: Tuple[int,int], optional
+ Pixel dimensions to pad object with
+ If None, the padding is set to half the probe ROI dimensions
+ initial_object_guess: np.ndarray, optional
+        Initial guess for complex-valued object of dimensions (T,Px,Py)
+        If None, initialized to an array of ones (complex) or zeros (potential)
+ initial_probe_guess: np.ndarray, optional
+ Initial guess for complex-valued probe of dimensions (Sx,Sy). If None,
+ initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations
+ initial_scan_positions: np.ndarray, optional
+ Probe positions in Å for each diffraction intensity
+ If None, initialized to a grid scan
+ theta_x: float
+ x tilt of propagator (in degrees)
+ theta_y: float
+ y tilt of propagator (in degrees)
+ middle_focus: bool
+        If True, adds half the sample thickness to the defocus
+ object_type: str, optional
+ The object can be reconstructed as a real potential ('potential') or a complex
+ object ('complex')
+ positions_mask: np.ndarray, optional
+ Boolean real space mask to select positions in datacube to skip for reconstruction
+ verbose: bool, optional
+ If True, class methods will inherit this and print additional information
+    device: str, optional
+        Device the calculations will be performed on. Must be 'cpu' or 'gpu'
+ name: str, optional
+ Class name
+ kwargs:
+ Provide the aberration coefficients as keyword arguments.
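+
+    Examples
+    --------
+    A minimal construction sketch (hypothetical values; `dataset` is assumed to
+    be an existing DataCube):
+
+    >>> ptycho = MultislicePtychographicReconstruction(
+    ...     datacube=dataset,
+    ...     energy=80e3,
+    ...     num_slices=4,
+    ...     slice_thicknesses=25.0,
+    ...     semiangle_cutoff=25.0,
+    ...     defocus=100,
+    ... )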
+ """
+
+ # Class-specific Metadata
+ _class_specific_metadata = ("_num_slices", "_slice_thicknesses")
+
+ def __init__(
+ self,
+ energy: float,
+ num_slices: int,
+ slice_thicknesses: Union[float, Sequence[float]],
+ datacube: DataCube = None,
+ semiangle_cutoff: float = None,
+ semiangle_cutoff_pixels: float = None,
+ rolloff: float = 2.0,
+ vacuum_probe_intensity: np.ndarray = None,
+ polar_parameters: Mapping[str, float] = None,
+ object_padding_px: Tuple[int, int] = None,
+ initial_object_guess: np.ndarray = None,
+ initial_probe_guess: np.ndarray = None,
+ initial_scan_positions: np.ndarray = None,
+ theta_x: float = 0,
+ theta_y: float = 0,
+ middle_focus: bool = False,
+ object_type: str = "complex",
+ positions_mask: np.ndarray = None,
+ verbose: bool = True,
+ device: str = "cpu",
+ name: str = "multi-slice_ptychographic_reconstruction",
+ **kwargs,
+ ):
+ Custom.__init__(self, name=name)
+
+ if device == "cpu":
+ self._xp = np
+ self._asnumpy = np.asarray
+ from scipy.ndimage import gaussian_filter
+
+ self._gaussian_filter = gaussian_filter
+ from scipy.special import erf
+
+ self._erf = erf
+ elif device == "gpu":
+ self._xp = cp
+ self._asnumpy = cp.asnumpy
+ from cupyx.scipy.ndimage import gaussian_filter
+
+ self._gaussian_filter = gaussian_filter
+ from cupyx.scipy.special import erf
+
+ self._erf = erf
+ else:
+ raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}")
+
+ for key in kwargs.keys():
+ if (key not in polar_symbols) and (key not in polar_aliases.keys()):
+ raise ValueError("{} not a recognized parameter".format(key))
+
+ if np.isscalar(slice_thicknesses):
+ mean_slice_thickness = slice_thicknesses
+ else:
+ mean_slice_thickness = np.mean(slice_thicknesses)
+
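+        # `middle_focus` shifts the nominal focal plane to the sample mid-plane;
+        # defocus and C10 are adjusted with opposite signs since defocus = -C10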
+ if middle_focus:
+ if "defocus" in kwargs:
+ kwargs["defocus"] += mean_slice_thickness * num_slices / 2
+ elif "C10" in kwargs:
+ kwargs["C10"] -= mean_slice_thickness * num_slices / 2
+ elif polar_parameters is not None and "defocus" in polar_parameters:
+ polar_parameters["defocus"] = (
+ polar_parameters["defocus"] + mean_slice_thickness * num_slices / 2
+ )
+ elif polar_parameters is not None and "C10" in polar_parameters:
+ polar_parameters["C10"] = (
+ polar_parameters["C10"] - mean_slice_thickness * num_slices / 2
+ )
+
+ self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols)))
+
+ if polar_parameters is None:
+ polar_parameters = {}
+
+ polar_parameters.update(kwargs)
+ self._set_polar_parameters(polar_parameters)
+
+ slice_thicknesses = np.array(slice_thicknesses)
+ if slice_thicknesses.shape == ():
+ slice_thicknesses = np.tile(slice_thicknesses, num_slices - 1)
+ elif slice_thicknesses.shape[0] != (num_slices - 1):
+ raise ValueError(
+ (
+ f"slice_thicknesses must have length {num_slices - 1}, "
+ f"not {slice_thicknesses.shape[0]}."
+ )
+ )
+
+ if object_type != "potential" and object_type != "complex":
+ raise ValueError(
+ f"object_type must be either 'potential' or 'complex', not {object_type}"
+ )
+ if positions_mask is not None and positions_mask.dtype != "bool":
+ warnings.warn(
+ ("`positions_mask` converted to `bool` array"),
+ UserWarning,
+ )
+ positions_mask = np.asarray(positions_mask, dtype="bool")
+
+ self.set_save_defaults()
+
+ # Data
+ self._datacube = datacube
+ self._object = initial_object_guess
+ self._probe = initial_probe_guess
+
+ # Common Metadata
+ self._vacuum_probe_intensity = vacuum_probe_intensity
+ self._scan_positions = initial_scan_positions
+ self._energy = energy
+ self._semiangle_cutoff = semiangle_cutoff
+ self._semiangle_cutoff_pixels = semiangle_cutoff_pixels
+ self._rolloff = rolloff
+ self._object_type = object_type
+ self._positions_mask = positions_mask
+ self._object_padding_px = object_padding_px
+ self._verbose = verbose
+ self._device = device
+ self._preprocessed = False
+
+ # Class-specific Metadata
+ self._num_slices = num_slices
+ self._slice_thicknesses = slice_thicknesses
+ self._theta_x = theta_x
+ self._theta_y = theta_y
+
+ def _precompute_propagator_arrays(
+ self,
+ gpts: Tuple[int, int],
+ sampling: Tuple[float, float],
+ energy: float,
+ slice_thicknesses: Sequence[float],
+ theta_x: float,
+ theta_y: float,
+ ):
+ """
+ Precomputes propagator arrays complex wave-function will be convolved by,
+ for all slice thicknesses.
+
+ Parameters
+ ----------
+ gpts: Tuple[int,int]
+ Wavefunction pixel dimensions
+ sampling: Tuple[float,float]
+ Wavefunction sampling in A
+ energy: float
+ The electron energy of the wave functions in eV
+ slice_thicknesses: Sequence[float]
+ Array of slice thicknesses in A
+ theta_x: float
+ x tilt of propagator (in degrees)
+ theta_y: float
+ y tilt of propagator (in degrees)
+
+ Returns
+ -------
+ propagator_arrays: np.ndarray
+ (T,Sx,Sy) shape array storing propagator arrays
+ """
+ xp = self._xp
+
+ # Frequencies
+ kx, ky = spatial_frequencies(gpts, sampling)
+ kx = xp.asarray(kx, dtype=xp.float32)
+ ky = xp.asarray(ky, dtype=xp.float32)
+
+ # Propagators
+ wavelength = electron_wavelength_angstrom(energy)
+ num_slices = slice_thicknesses.shape[0]
+ propagators = xp.empty(
+ (num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64
+ )
+
+ theta_x = np.deg2rad(theta_x)
+ theta_y = np.deg2rad(theta_y)
+
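+        # Fresnel (near-field) propagator for slice thickness dz, including small
+        # beam tilts: P(k) = exp(-iπλ|k|²dz) · exp(2πi dz (kx·tanθx + ky·tanθy)),
+        # assembled factor by factor below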
+ for i, dz in enumerate(slice_thicknesses):
+ propagators[i] = xp.exp(
+ 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz)
+ )
+ propagators[i] *= xp.exp(
+ 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz)
+ )
+ propagators[i] *= xp.exp(
+ 1.0j * (2 * kx[:, None] * np.pi * dz * np.tan(theta_x))
+ )
+ propagators[i] *= xp.exp(
+ 1.0j * (2 * ky[None] * np.pi * dz * np.tan(theta_y))
+ )
+
+ return propagators
+
+ def _propagate_array(self, array: np.ndarray, propagator_array: np.ndarray):
+ """
+ Propagates array by Fourier convolving array with propagator_array.
+
+ Parameters
+ ----------
+ array: np.ndarray
+ Wavefunction array to be convolved
+ propagator_array: np.ndarray
+ Propagator array to convolve array with
+
+ Returns
+ -------
+ propagated_array: np.ndarray
+ Fourier-convolved array
+ """
+ xp = self._xp
+
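+        # FFT-based convolution, which implies periodic boundary conditions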
+ return xp.fft.ifft2(xp.fft.fft2(array) * propagator_array)
+
+ def preprocess(
+ self,
+ diffraction_intensities_shape: Tuple[int, int] = None,
+ reshaping_method: str = "fourier",
+ probe_roi_shape: Tuple[int, int] = None,
+ dp_mask: np.ndarray = None,
+ fit_function: str = "plane",
+ plot_center_of_mass: str = "default",
+ plot_rotation: bool = True,
+ maximize_divergence: bool = False,
+ rotation_angles_deg: np.ndarray = np.arange(-89.0, 90.0, 1.0),
+ plot_probe_overlaps: bool = True,
+ force_com_rotation: float = None,
+ force_com_transpose: float = None,
+ force_com_shifts: float = None,
+ force_scan_sampling: float = None,
+ force_angular_sampling: float = None,
+ force_reciprocal_sampling: float = None,
+ object_fov_mask: np.ndarray = None,
+ crop_patterns: bool = False,
+ **kwargs,
+ ):
+ """
+ Ptychographic preprocessing step.
+ Calls the base class methods:
+
+        _extract_intensities_and_calibrations_from_datacube(),
+        _calculate_intensities_center_of_mass(),
+        _solve_for_center_of_mass_relative_rotation(),
+        _normalize_diffraction_intensities(), and
+        _calculate_scan_positions_in_pixels()
+
+        Additionally, it initializes a (T,Px,Py) object array (ones for 'complex',
+        zeros for 'potential') and a complex probe using the specified polar parameters.
+
+ Parameters
+ ----------
+ diffraction_intensities_shape: Tuple[int,int], optional
+ Pixel dimensions (Qx',Qy') of the resampled diffraction intensities
+            If None, no resampling of diffraction intensities is performed
+        reshaping_method: str, optional
+            Method to use for reshaping, either 'bin', 'bilinear', or 'fourier' (default)
+        probe_roi_shape: Tuple[int,int], optional
+            Padded diffraction intensities shape.
+ If None, no padding is performed
+ dp_mask: ndarray, optional
+ Mask for datacube intensities (Qx,Qy)
+ fit_function: str, optional
+ 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two'
+ plot_center_of_mass: str, optional
+ If 'default', the corrected CoM arrays will be displayed
+ If 'all', the computed and fitted CoM arrays will be displayed
+ plot_rotation: bool, optional
+ If True, the CoM curl minimization search result will be displayed
+ maximize_divergence: bool, optional
+ If True, the divergence of the CoM gradient vector field is maximized
+        rotation_angles_deg: np.ndarray, optional
+ Array of angles in degrees to perform curl minimization over
+ plot_probe_overlaps: bool, optional
+ If True, initial probe overlaps scanned over the object will be displayed
+ force_com_rotation: float (degrees), optional
+ Force relative rotation angle between real and reciprocal space
+ force_com_transpose: bool, optional
+ Force whether diffraction intensities need to be transposed.
+        force_com_shifts: tuple of ndarrays (CoMx, CoMy)
+            Amplitudes come from diffraction patterns shifted with
+            the CoM in the upper left corner for each probe, unless
+            the shifts are overwritten here.
+ force_scan_sampling: float, optional
+ Override DataCube real space scan pixel size calibrations, in Angstrom
+ force_angular_sampling: float, optional
+ Override DataCube reciprocal pixel size calibration, in mrad
+ force_reciprocal_sampling: float, optional
+ Override DataCube reciprocal pixel size calibration, in A^-1
+ object_fov_mask: np.ndarray (boolean)
+ Boolean mask of FOV. Used to calculate additional shrinkage of object
+ If None, probe_overlap intensity is thresholded
+ crop_patterns: bool
+ if True, crop patterns to avoid wrap around of patterns when centering
+
+ Returns
+ --------
+ self: MultislicePtychographicReconstruction
+ Self to accommodate chaining
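+
+        Examples
+        --------
+        A minimal call sketch (assuming `ptycho` was constructed with a datacube
+        attached):
+
+        >>> ptycho = ptycho.preprocess(plot_probe_overlaps=False)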
+ """
+ xp = self._xp
+ asnumpy = self._asnumpy
+
+ # set additional metadata
+ self._diffraction_intensities_shape = diffraction_intensities_shape
+ self._reshaping_method = reshaping_method
+ self._probe_roi_shape = probe_roi_shape
+ self._dp_mask = dp_mask
+
+ if self._datacube is None:
+ raise ValueError(
+ (
+ "The preprocess() method requires a DataCube. "
+ "Please run ptycho.attach_datacube(DataCube) first."
+ )
+ )
+
+ (
+ self._datacube,
+ self._vacuum_probe_intensity,
+ self._dp_mask,
+ force_com_shifts,
+ ) = self._preprocess_datacube_and_vacuum_probe(
+ self._datacube,
+ diffraction_intensities_shape=self._diffraction_intensities_shape,
+ reshaping_method=self._reshaping_method,
+ probe_roi_shape=self._probe_roi_shape,
+ vacuum_probe_intensity=self._vacuum_probe_intensity,
+ dp_mask=self._dp_mask,
+ com_shifts=force_com_shifts,
+ )
+
+ self._intensities = self._extract_intensities_and_calibrations_from_datacube(
+ self._datacube,
+ require_calibrations=True,
+ force_scan_sampling=force_scan_sampling,
+ force_angular_sampling=force_angular_sampling,
+ force_reciprocal_sampling=force_reciprocal_sampling,
+ )
+
+ (
+ self._com_measured_x,
+ self._com_measured_y,
+ self._com_fitted_x,
+ self._com_fitted_y,
+ self._com_normalized_x,
+ self._com_normalized_y,
+ ) = self._calculate_intensities_center_of_mass(
+ self._intensities,
+ dp_mask=self._dp_mask,
+ fit_function=fit_function,
+ com_shifts=force_com_shifts,
+ )
+
+ (
+ self._rotation_best_rad,
+ self._rotation_best_transpose,
+ self._com_x,
+ self._com_y,
+ self.com_x,
+ self.com_y,
+ ) = self._solve_for_center_of_mass_relative_rotation(
+ self._com_measured_x,
+ self._com_measured_y,
+ self._com_normalized_x,
+ self._com_normalized_y,
+ rotation_angles_deg=rotation_angles_deg,
+ plot_rotation=plot_rotation,
+ plot_center_of_mass=plot_center_of_mass,
+ maximize_divergence=maximize_divergence,
+ force_com_rotation=force_com_rotation,
+ force_com_transpose=force_com_transpose,
+ **kwargs,
+ )
+
+ (
+ self._amplitudes,
+ self._mean_diffraction_intensity,
+ ) = self._normalize_diffraction_intensities(
+ self._intensities,
+ self._com_fitted_x,
+ self._com_fitted_y,
+ crop_patterns,
+ self._positions_mask,
+ )
+
+        self._num_diffraction_patterns = self._amplitudes.shape[0]
+        self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:])
+
+        # explicitly delete intensities namespace
+        del self._intensities
+
+ self._positions_px = self._calculate_scan_positions_in_pixels(
+ self._scan_positions, self._positions_mask
+ )
+
+ # handle semiangle specified in pixels
+ if self._semiangle_cutoff_pixels:
+ self._semiangle_cutoff = (
+ self._semiangle_cutoff_pixels * self._angular_sampling[0]
+ )
+
+ # Object Initialization
+ if self._object is None:
+ pad_x = self._object_padding_px[0][1]
+ pad_y = self._object_padding_px[1][1]
+ p, q = np.round(np.max(self._positions_px, axis=0))
+ p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype(
+ "int"
+ )
+ q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype(
+ "int"
+ )
+ if self._object_type == "potential":
+ self._object = xp.zeros((self._num_slices, p, q), dtype=xp.float32)
+ elif self._object_type == "complex":
+ self._object = xp.ones((self._num_slices, p, q), dtype=xp.complex64)
+ else:
+ if self._object_type == "potential":
+ self._object = xp.asarray(self._object, dtype=xp.float32)
+ elif self._object_type == "complex":
+ self._object = xp.asarray(self._object, dtype=xp.complex64)
+
+ self._object_initial = self._object.copy()
+ self._object_type_initial = self._object_type
+ self._object_shape = self._object.shape[-2:]
+
+ self._positions_px = xp.asarray(self._positions_px, dtype=xp.float32)
+ self._positions_px_com = xp.mean(self._positions_px, axis=0)
+ self._positions_px -= self._positions_px_com - xp.array(self._object_shape) / 2
+ self._positions_px_com = xp.mean(self._positions_px, axis=0)
+ self._positions_px_fractional = self._positions_px - xp.round(
+ self._positions_px
+ )
+
+ self._positions_px_initial = self._positions_px.copy()
+ self._positions_initial = self._positions_px_initial.copy()
+ self._positions_initial[:, 0] *= self.sampling[0]
+ self._positions_initial[:, 1] *= self.sampling[1]
+
+ # Vectorized Patches
+ (
+ self._vectorized_patch_indices_row,
+ self._vectorized_patch_indices_col,
+ ) = self._extract_vectorized_patch_indices()
+
+ # Probe Initialization
+ if self._probe is None:
+ if self._vacuum_probe_intensity is not None:
+ self._semiangle_cutoff = np.inf
+ self._vacuum_probe_intensity = xp.asarray(
+ self._vacuum_probe_intensity, dtype=xp.float32
+ )
+ probe_x0, probe_y0 = get_CoM(
+ self._vacuum_probe_intensity,
+ device=self._device,
+ )
+ self._vacuum_probe_intensity = get_shifted_ar(
+ self._vacuum_probe_intensity,
+ -probe_x0,
+ -probe_y0,
+ bilinear=True,
+ device=self._device,
+ )
+ if crop_patterns:
+ self._vacuum_probe_intensity = self._vacuum_probe_intensity[
+ self._crop_mask
+ ].reshape(self._region_of_interest_shape)
+
+ self._probe = (
+ ComplexProbe(
+ gpts=self._region_of_interest_shape,
+ sampling=self.sampling,
+ energy=self._energy,
+ semiangle_cutoff=self._semiangle_cutoff,
+ rolloff=self._rolloff,
+ vacuum_probe_intensity=self._vacuum_probe_intensity,
+ parameters=self._polar_parameters,
+ device=self._device,
+ )
+ .build()
+ ._array
+ )
+
+ # Normalize probe to match mean diffraction intensity
+ probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2)
+ self._probe *= xp.sqrt(self._mean_diffraction_intensity / probe_intensity)
+
+ else:
+ if isinstance(self._probe, ComplexProbe):
+ if self._probe._gpts != self._region_of_interest_shape:
+ raise ValueError()
+ if hasattr(self._probe, "_array"):
+ self._probe = self._probe._array
+ else:
+ self._probe._xp = xp
+ self._probe = self._probe.build()._array
+
+ # Normalize probe to match mean diffraction intensity
+ probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2)
+ self._probe *= xp.sqrt(
+ self._mean_diffraction_intensity / probe_intensity
+ )
+
+ else:
+ self._probe = xp.asarray(self._probe, dtype=xp.complex64)
+
+ self._probe_initial = self._probe.copy()
+ self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe))
+
+ self._known_aberrations_array = ComplexProbe(
+ energy=self._energy,
+ gpts=self._region_of_interest_shape,
+ sampling=self.sampling,
+ parameters=self._polar_parameters,
+ device=self._device,
+ )._evaluate_ctf()
+
+ # Precomputed propagator arrays
+ self._propagator_arrays = self._precompute_propagator_arrays(
+ self._region_of_interest_shape,
+ self.sampling,
+ self._energy,
+ self._slice_thicknesses,
+ self._theta_x,
+ self._theta_y,
+ )
+
+ # overlaps
+ shifted_probes = fft_shift(self._probe, self._positions_px_fractional, xp)
+ probe_intensities = xp.abs(shifted_probes) ** 2
+ probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities)
+ probe_overlap = self._gaussian_filter(probe_overlap, 1.0)
+
+ if object_fov_mask is None:
+ self._object_fov_mask = asnumpy(probe_overlap > 0.25 * probe_overlap.max())
+ else:
+ self._object_fov_mask = np.asarray(object_fov_mask)
+ self._object_fov_mask_inverse = np.invert(self._object_fov_mask)
+
+ if plot_probe_overlaps:
+ figsize = kwargs.pop("figsize", (13, 4))
+ chroma_boost = kwargs.pop("chroma_boost", 1)
+
+ # initial probe
+ complex_probe_rgb = Complex2RGB(
+ self.probe_centered,
+ power=2,
+ chroma_boost=chroma_boost,
+ )
+
+ # propagated
+ propagated_probe = self._probe.copy()
+
+ for s in range(self._num_slices - 1):
+ propagated_probe = self._propagate_array(
+ propagated_probe, self._propagator_arrays[s]
+ )
+ complex_propagated_rgb = Complex2RGB(
+ asnumpy(self._return_centered_probe(propagated_probe)),
+ power=2,
+ chroma_boost=chroma_boost,
+ )
+
+ extent = [
+ 0,
+ self.sampling[1] * self._object_shape[1],
+ self.sampling[0] * self._object_shape[0],
+ 0,
+ ]
+
+ probe_extent = [
+ 0,
+ self.sampling[1] * self._region_of_interest_shape[1],
+ self.sampling[0] * self._region_of_interest_shape[0],
+ 0,
+ ]
+
+ fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=figsize)
+
+ ax1.imshow(
+ complex_probe_rgb,
+ extent=probe_extent,
+ )
+
+ divider = make_axes_locatable(ax1)
+ cax1 = divider.append_axes("right", size="5%", pad="2.5%")
+ add_colorbar_arg(cax1, chroma_boost=chroma_boost)
+ ax1.set_ylabel("x [A]")
+ ax1.set_xlabel("y [A]")
+ ax1.set_title("Initial probe intensity")
+
+ ax2.imshow(
+ complex_propagated_rgb,
+ extent=probe_extent,
+ )
+
+ divider = make_axes_locatable(ax2)
+ cax2 = divider.append_axes("right", size="5%", pad="2.5%")
+ add_colorbar_arg(
+ cax2,
+ chroma_boost=chroma_boost,
+ )
+ ax2.set_ylabel("x [A]")
+ ax2.set_xlabel("y [A]")
+ ax2.set_title("Propagated probe intensity")
+
+ ax3.imshow(
+ asnumpy(probe_overlap),
+ extent=extent,
+ cmap="Greys_r",
+ )
+ ax3.scatter(
+ self.positions[:, 1],
+ self.positions[:, 0],
+ s=2.5,
+ color=(1, 0, 0, 1),
+ )
+ ax3.set_ylabel("x [A]")
+ ax3.set_xlabel("y [A]")
+ ax3.set_xlim((extent[0], extent[1]))
+ ax3.set_ylim((extent[2], extent[3]))
+ ax3.set_title("Object field of view")
+
+ fig.tight_layout()
+
+ self._preprocessed = True
+
+ if self._device == "gpu":
+ xp._default_memory_pool.free_all_blocks()
+ xp.clear_memo()
+
+ return self
+
+ def _overlap_projection(self, current_object, current_probe):
+ """
+ Ptychographic overlap projection method.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+
+ Returns
+ --------
+ propagated_probes: np.ndarray
+ Shifted probes at each layer
+ object_patches: np.ndarray
+ Patched object view
+ transmitted_probes: np.ndarray
+ Transmitted probes after N-1 propagations and N transmissions
+ """
+
+ xp = self._xp
+
+ if self._object_type == "potential":
+ complex_object = xp.exp(1j * current_object)
+ else:
+ complex_object = current_object
+
+ object_patches = complex_object[
+ :,
+ self._vectorized_patch_indices_row,
+ self._vectorized_patch_indices_col,
+ ]
+
+ propagated_probes = xp.empty_like(object_patches)
+ propagated_probes[0] = fft_shift(
+ current_probe, self._positions_px_fractional, xp
+ )
+
+ for s in range(self._num_slices):
+ # transmit
+ transmitted_probes = object_patches[s] * propagated_probes[s]
+
+ # propagate
+ if s + 1 < self._num_slices:
+ propagated_probes[s + 1] = self._propagate_array(
+ transmitted_probes, self._propagator_arrays[s]
+ )
+
+ return propagated_probes, object_patches, transmitted_probes
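+
+    # A minimal NumPy sketch (hypothetical shapes, not part of this class) of the
+    # multislice recurrence implemented in _overlap_projection above: transmit the
+    # probe through slice s, then propagate it to slice s + 1 via an FFT kernel.
+    #
+    #   import numpy as np
+    #   num_slices, num_probes, n = 3, 4, 32
+    #   psi = np.ones((num_probes, n, n), dtype=np.complex64)           # shifted probes
+    #   slices = np.exp(1j * np.zeros((num_slices, num_probes, n, n)))  # object patches
+    #   propagators = np.ones((num_slices - 1, n, n), dtype=np.complex64)
+    #   for s in range(num_slices):
+    #       psi = slices[s] * psi                                       # transmit
+    #       if s + 1 < num_slices:
+    #           psi = np.fft.ifft2(np.fft.fft2(psi) * propagators[s])   # propagate
+    #   # psi is the transmitted (exit) wave fed to the Fourier projection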
+
+ def _gradient_descent_fourier_projection(self, amplitudes, transmitted_probes):
+ """
+        Ptychographic Fourier projection method for the gradient-descent (GD) method.
+
+ Parameters
+ --------
+ amplitudes: np.ndarray
+ Normalized measured amplitudes
+ transmitted_probes: np.ndarray
+ Transmitted probes after N-1 propagations and N transmissions
+
+ Returns
+ --------
+        exit_waves: np.ndarray
+ Exit wave difference
+ error: float
+ Reconstruction error
+ """
+
+ xp = self._xp
+ fourier_exit_waves = xp.fft.fft2(transmitted_probes)
+
+ error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_exit_waves)) ** 2)
+
+ modified_exit_wave = xp.fft.ifft2(
+ amplitudes * xp.exp(1j * xp.angle(fourier_exit_waves))
+ )
+
+ exit_waves = modified_exit_wave - transmitted_probes
+
+ return exit_waves, error
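+
+    # Minimal sketch (assumed variable names) of the amplitude projection applied in
+    # _gradient_descent_fourier_projection above: keep the phase of the exit wave's
+    # FFT, replace its modulus with the measured amplitudes, and return the
+    # difference as the gradient direction.
+    #
+    #   G = np.fft.fft2(psi)                                            # exit-wave spectrum
+    #   psi_proj = np.fft.ifft2(amplitudes * np.exp(1j * np.angle(G)))  # amplitude-projected
+    #   exit_waves = psi_proj - psi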
+
+ def _projection_sets_fourier_projection(
+ self,
+ amplitudes,
+ transmitted_probes,
+ exit_waves,
+ projection_a,
+ projection_b,
+ projection_c,
+ ):
+ """
+        Ptychographic Fourier projection method for DM_AP and RAAR methods.
+        Generalized projection using three parameters: a, b, c.
+
+ DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha
+ DM: DM_AP(1.0), AP: DM_AP(0.0)
+
+ RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2
+ DM : RAAR(1.0)
+
+ RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2
+ DM: RRR(1.0)
+
+ SUPERFLIP : a = 0, b = 1, c = 2
+
+ Parameters
+ --------
+ amplitudes: np.ndarray
+ Normalized measured amplitudes
+ transmitted_probes: np.ndarray
+ Transmitted probes after N-1 propagations and N transmissions
+        exit_waves: np.ndarray
+            Previously estimated exit waves
+        projection_a: float
+            Generalized projection parameter a
+        projection_b: float
+            Generalized projection parameter b
+        projection_c: float
+            Generalized projection parameter c
+
+ Returns
+ --------
+        exit_waves: np.ndarray
+ Updated exit wave difference
+ error: float
+ Reconstruction error
+ """
+
+ xp = self._xp
+ projection_x = 1 - projection_a - projection_b
+ projection_y = 1 - projection_c
+
+ if exit_waves is None:
+ exit_waves = transmitted_probes.copy()
+
+ fourier_exit_waves = xp.fft.fft2(transmitted_probes)
+ error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_exit_waves)) ** 2)
+
+ factor_to_be_projected = (
+ projection_c * transmitted_probes + projection_y * exit_waves
+ )
+ fourier_projected_factor = xp.fft.fft2(factor_to_be_projected)
+
+ fourier_projected_factor = amplitudes * xp.exp(
+ 1j * xp.angle(fourier_projected_factor)
+ )
+ projected_factor = xp.fft.ifft2(fourier_projected_factor)
+
+ exit_waves = (
+ projection_x * exit_waves
+ + projection_a * transmitted_probes
+ + projection_b * projected_factor
+ )
+
+ return exit_waves, error
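+
+    # Worked examples (illustrative parameter values) of the (a, b, c)
+    # parameterization documented in _projection_sets_fourier_projection above:
+    #
+    #   DM_AP(alpha=0.5) -> a = -0.5, b = 1.0,  c = 1.5
+    #   RAAR(beta=0.75)  -> a = -0.5, b = 0.75, c = 2.0
+    #   RRR(gamma=1.0)   -> a = -1.0, b = 1.0,  c = 2.0  (equivalent to DM)
+    #   SUPERFLIP        -> a =  0.0, b = 1.0,  c = 2.0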
+
+ def _forward(
+ self,
+ current_object,
+ current_probe,
+ amplitudes,
+ exit_waves,
+ use_projection_scheme,
+ projection_a,
+ projection_b,
+ projection_c,
+ ):
+ """
+ Ptychographic forward operator.
+ Calls _overlap_projection() and the appropriate _fourier_projection().
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+ amplitudes: np.ndarray
+ Normalized measured amplitudes
+        exit_waves: np.ndarray
+            Previously estimated exit waves
+        use_projection_scheme: bool
+            If True, use generalized projection update
+        projection_a: float
+            Generalized projection parameter a
+        projection_b: float
+            Generalized projection parameter b
+        projection_c: float
+            Generalized projection parameter c
+
+ Returns
+ --------
+ propagated_probes: np.ndarray
+ Shifted probes at each layer
+ object_patches: np.ndarray
+ Patched object view
+ transmitted_probes: np.ndarray
+ Transmitted probes after N-1 propagations and N transmissions
+        exit_waves: np.ndarray
+            Updated exit waves
+ error: float
+ Reconstruction error
+ """
+
+ (
+ propagated_probes,
+ object_patches,
+ transmitted_probes,
+ ) = self._overlap_projection(current_object, current_probe)
+
+ if use_projection_scheme:
+ exit_waves, error = self._projection_sets_fourier_projection(
+ amplitudes,
+ transmitted_probes,
+ exit_waves,
+ projection_a,
+ projection_b,
+ projection_c,
+ )
+
+ else:
+ exit_waves, error = self._gradient_descent_fourier_projection(
+ amplitudes, transmitted_probes
+ )
+
+ return propagated_probes, object_patches, transmitted_probes, exit_waves, error
+
+ def _gradient_descent_adjoint(
+ self,
+ current_object,
+ current_probe,
+ object_patches,
+ propagated_probes,
+ exit_waves,
+ step_size,
+ normalization_min,
+ fix_probe,
+ ):
+ """
+ Ptychographic adjoint operator for GD method.
+ Computes object and probe update steps.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+ object_patches: np.ndarray
+ Patched object view
+ propagated_probes: np.ndarray
+ Shifted probes at each layer
+        exit_waves: np.ndarray
+            Updated exit waves
+ step_size: float, optional
+ Update step size
+ normalization_min: float, optional
+ Probe normalization minimum as a fraction of the maximum overlap intensity
+ fix_probe: bool, optional
+ If True, probe will not be updated
+
+ Returns
+ --------
+ updated_object: np.ndarray
+ Updated object estimate
+ updated_probe: np.ndarray
+ Updated probe estimate
+ """
+ xp = self._xp
+
+ for s in reversed(range(self._num_slices)):
+ probe = propagated_probes[s]
+ obj = object_patches[s]
+
+ # object-update
+ probe_normalization = self._sum_overlapping_patches_bincounts(
+ xp.abs(probe) ** 2
+ )
+
+ probe_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * probe_normalization) ** 2
+ + (normalization_min * xp.max(probe_normalization)) ** 2
+ )
+
+ if self._object_type == "potential":
+ current_object[s] += step_size * (
+ self._sum_overlapping_patches_bincounts(
+ xp.real(-1j * xp.conj(obj) * xp.conj(probe) * exit_waves)
+ )
+ * probe_normalization
+ )
+ elif self._object_type == "complex":
+ current_object[s] += step_size * (
+ self._sum_overlapping_patches_bincounts(xp.conj(probe) * exit_waves)
+ * probe_normalization
+ )
+
+ # back-transmit
+ exit_waves *= xp.conj(obj) # / xp.abs(obj) ** 2
+
+ if s > 0:
+ # back-propagate
+ exit_waves = self._propagate_array(
+ exit_waves, xp.conj(self._propagator_arrays[s - 1])
+ )
+ elif not fix_probe:
+ # probe-update
+ object_normalization = xp.sum(
+ (xp.abs(obj) ** 2),
+ axis=0,
+ )
+ object_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * object_normalization) ** 2
+ + (normalization_min * xp.max(object_normalization)) ** 2
+ )
+
+ current_probe += (
+ step_size
+ * xp.sum(
+ exit_waves,
+ axis=0,
+ )
+ * object_normalization
+ )
+
+ return current_object, current_probe
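+
+    # Schematic form (in the notation above, not exact code) of the per-slice
+    # gradient-descent updates in _gradient_descent_adjoint: with P the propagated
+    # probe, O the object patch, chi the exit-wave difference, and "norm" the
+    # regularized reciprocal of the summed overlap intensity |P|^2,
+    #
+    #   potential object: V += step_size * sum_patches(Re[-1j * conj(O) * conj(P) * chi]) * norm
+    #   complex object:   O += step_size * sum_patches(conj(P) * chi) * norm
+    #   probe (slice 0):  P += step_size * sum_positions(chi) * obj_norm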
+
+ def _projection_sets_adjoint(
+ self,
+ current_object,
+ current_probe,
+ object_patches,
+ propagated_probes,
+ exit_waves,
+ normalization_min,
+ fix_probe,
+ ):
+ """
+ Ptychographic adjoint operator for DM_AP and RAAR methods.
+ Computes object and probe update steps.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+ object_patches: np.ndarray
+ Patched object view
+ propagated_probes: np.ndarray
+ Shifted probes at each layer
+        exit_waves: np.ndarray
+            Updated exit waves
+ normalization_min: float, optional
+ Probe normalization minimum as a fraction of the maximum overlap intensity
+ fix_probe: bool, optional
+ If True, probe will not be updated
+
+ Returns
+ --------
+ updated_object: np.ndarray
+ Updated object estimate
+ updated_probe: np.ndarray
+ Updated probe estimate
+ """
+ xp = self._xp
+
+ # careful not to modify exit_waves in-place for projection set methods
+ exit_waves_copy = exit_waves.copy()
+ for s in reversed(range(self._num_slices)):
+ probe = propagated_probes[s]
+ obj = object_patches[s]
+
+ # object-update
+ probe_normalization = self._sum_overlapping_patches_bincounts(
+ xp.abs(probe) ** 2
+ )
+ probe_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * probe_normalization) ** 2
+ + (normalization_min * xp.max(probe_normalization)) ** 2
+ )
+
+ if self._object_type == "potential":
+ current_object[s] = (
+ self._sum_overlapping_patches_bincounts(
+ xp.real(-1j * xp.conj(obj) * xp.conj(probe) * exit_waves_copy)
+ )
+ * probe_normalization
+ )
+ elif self._object_type == "complex":
+ current_object[s] = (
+ self._sum_overlapping_patches_bincounts(
+ xp.conj(probe) * exit_waves_copy
+ )
+ * probe_normalization
+ )
+
+ # back-transmit
+ exit_waves_copy *= xp.conj(obj) # / xp.abs(obj) ** 2
+
+ if s > 0:
+ # back-propagate
+ exit_waves_copy = self._propagate_array(
+ exit_waves_copy, xp.conj(self._propagator_arrays[s - 1])
+ )
+
+ elif not fix_probe:
+ # probe-update
+ object_normalization = xp.sum(
+ (xp.abs(obj) ** 2),
+ axis=0,
+ )
+ object_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * object_normalization) ** 2
+ + (normalization_min * xp.max(object_normalization)) ** 2
+ )
+
+ current_probe = (
+ xp.sum(
+ exit_waves_copy,
+ axis=0,
+ )
+ * object_normalization
+ )
+
+ return current_object, current_probe
+
+ def _adjoint(
+ self,
+ current_object,
+ current_probe,
+ object_patches,
+ propagated_probes,
+ exit_waves,
+ use_projection_scheme: bool,
+ step_size: float,
+ normalization_min: float,
+ fix_probe: bool,
+ ):
+ """
+        Ptychographic adjoint operator.
+        Dispatches to the gradient-descent or projection-sets adjoint to compute
+        object and probe update steps.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+ object_patches: np.ndarray
+ Patched object view
+ propagated_probes: np.ndarray
+ Shifted probes at each layer
+        exit_waves: np.ndarray
+            Updated exit waves
+        use_projection_scheme: bool
+            If True, use generalized projection update
+ step_size: float, optional
+ Update step size
+ normalization_min: float, optional
+ Probe normalization minimum as a fraction of the maximum overlap intensity
+ fix_probe: bool, optional
+ If True, probe will not be updated
+
+ Returns
+ --------
+ updated_object: np.ndarray
+ Updated object estimate
+ updated_probe: np.ndarray
+ Updated probe estimate
+ """
+
+ if use_projection_scheme:
+ current_object, current_probe = self._projection_sets_adjoint(
+ current_object,
+ current_probe,
+ object_patches,
+ propagated_probes,
+ exit_waves,
+ normalization_min,
+ fix_probe,
+ )
+ else:
+ current_object, current_probe = self._gradient_descent_adjoint(
+ current_object,
+ current_probe,
+ object_patches,
+ propagated_probes,
+ exit_waves,
+ step_size,
+ normalization_min,
+ fix_probe,
+ )
+
+ return current_object, current_probe
+
+ def _position_correction(
+ self,
+ current_object,
+ current_probe,
+ transmitted_probes,
+ amplitudes,
+ current_positions,
+ positions_step_size,
+ constrain_position_distance,
+ ):
+ """
+ Position correction using estimated intensity gradient.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+        current_probe: np.ndarray
+            Fractionally-shifted probes
+ transmitted_probes: np.ndarray
+ Transmitted probes after N-1 propagations and N transmissions
+ amplitudes: np.ndarray
+ Measured amplitudes
+ current_positions: np.ndarray
+ Current positions estimate
+ positions_step_size: float
+ Positions step size
+ constrain_position_distance: float
+ Distance to constrain position correction within original
+ field of view in A
+
+ Returns
+ --------
+ updated_positions: np.ndarray
+ Updated positions estimate
+ """
+
+ xp = self._xp
+
+ # Intensity gradient
+ exit_waves_fft = xp.fft.fft2(transmitted_probes)
+ exit_waves_fft_conj = xp.conj(exit_waves_fft)
+ estimated_intensity = xp.abs(exit_waves_fft) ** 2
+ measured_intensity = amplitudes**2
+
+ flat_shape = (transmitted_probes.shape[0], -1)
+ difference_intensity = (measured_intensity - estimated_intensity).reshape(
+ flat_shape
+ )
+
+ # Computing perturbed exit waves one at a time to save on memory
+
+ if self._object_type == "potential":
+ complex_object = xp.exp(1j * current_object)
+ else:
+ complex_object = current_object
+
+ # dx
+ obj_rolled_patches = complex_object[
+ :,
+ (self._vectorized_patch_indices_row + 1) % self._object_shape[0],
+ self._vectorized_patch_indices_col,
+ ]
+
+ propagated_probes_perturbed = xp.empty_like(obj_rolled_patches)
+ propagated_probes_perturbed[0] = fft_shift(
+ current_probe, self._positions_px_fractional, xp
+ )
+
+ for s in range(self._num_slices):
+ # transmit
+ transmitted_probes_perturbed = (
+ obj_rolled_patches[s] * propagated_probes_perturbed[s]
+ )
+
+ # propagate
+ if s + 1 < self._num_slices:
+ propagated_probes_perturbed[s + 1] = self._propagate_array(
+ transmitted_probes_perturbed, self._propagator_arrays[s]
+ )
+
+ exit_waves_dx_fft = exit_waves_fft - xp.fft.fft2(transmitted_probes_perturbed)
+
+ # dy
+ obj_rolled_patches = complex_object[
+ :,
+ self._vectorized_patch_indices_row,
+ (self._vectorized_patch_indices_col + 1) % self._object_shape[1],
+ ]
+
+ propagated_probes_perturbed = xp.empty_like(obj_rolled_patches)
+ propagated_probes_perturbed[0] = fft_shift(
+ current_probe, self._positions_px_fractional, xp
+ )
+
+ for s in range(self._num_slices):
+ # transmit
+ transmitted_probes_perturbed = (
+ obj_rolled_patches[s] * propagated_probes_perturbed[s]
+ )
+
+ # propagate
+ if s + 1 < self._num_slices:
+ propagated_probes_perturbed[s + 1] = self._propagate_array(
+ transmitted_probes_perturbed, self._propagator_arrays[s]
+ )
+
+ exit_waves_dy_fft = exit_waves_fft - xp.fft.fft2(transmitted_probes_perturbed)
+
+ partial_intensity_dx = 2 * xp.real(
+ exit_waves_dx_fft * exit_waves_fft_conj
+ ).reshape(flat_shape)
+ partial_intensity_dy = 2 * xp.real(
+ exit_waves_dy_fft * exit_waves_fft_conj
+ ).reshape(flat_shape)
+
+ coefficients_matrix = xp.dstack((partial_intensity_dx, partial_intensity_dy))
+
+ # positions_update = xp.einsum(
+ # "idk,ik->id", xp.linalg.pinv(coefficients_matrix), difference_intensity
+ # )
+
+ coefficients_matrix_T = coefficients_matrix.conj().swapaxes(-1, -2)
+ positions_update = (
+ xp.linalg.inv(coefficients_matrix_T @ coefficients_matrix)
+ @ coefficients_matrix_T
+ @ difference_intensity[..., None]
+ )
+
+ if constrain_position_distance is not None:
+ constrain_position_distance /= xp.sqrt(
+ self.sampling[0] ** 2 + self.sampling[1] ** 2
+ )
+            candidate_positions = (
+                current_positions - positions_step_size * positions_update[..., 0]
+            )
+            x1 = candidate_positions[:, 0]
+            y1 = candidate_positions[:, 1]
+ x0 = self._positions_px_initial[:, 0]
+ y0 = self._positions_px_initial[:, 1]
+ if self._rotation_best_transpose:
+ x0, y0 = xp.array([y0, x0])
+ x1, y1 = xp.array([y1, x1])
+
+ if self._rotation_best_rad is not None:
+ rotation_angle = self._rotation_best_rad
+ x0, y0 = x0 * xp.cos(-rotation_angle) + y0 * xp.sin(
+ -rotation_angle
+ ), -x0 * xp.sin(-rotation_angle) + y0 * xp.cos(-rotation_angle)
+ x1, y1 = x1 * xp.cos(-rotation_angle) + y1 * xp.sin(
+ -rotation_angle
+ ), -x1 * xp.sin(-rotation_angle) + y1 * xp.cos(-rotation_angle)
+
+            outlier_ind = (
+                (x1 > xp.max(x0) + constrain_position_distance)
+                | (x1 < xp.min(x0) - constrain_position_distance)
+                | (y1 > xp.max(y0) + constrain_position_distance)
+                | (y1 < xp.min(y0) - constrain_position_distance)
+            )
+
+ positions_update[..., 0][outlier_ind] = 0
+
+ current_positions -= positions_step_size * positions_update[..., 0]
+
+ return current_positions
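+
+    # Minimal sketch (assumed shapes) of the batched least-squares solve in
+    # _position_correction above, equivalent to the commented-out pinv einsum:
+    #
+    #   A  = coefficients_matrix              # (N, K, 2) intensity gradients dI/dx, dI/dy
+    #   d  = difference_intensity[..., None]  # (N, K, 1) measured minus estimated intensity
+    #   AT = A.conj().swapaxes(-1, -2)        # (N, 2, K)
+    #   shift = np.linalg.inv(AT @ A) @ AT @ d  # (N, 2, 1) normal-equations solution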
+
+ def _object_butterworth_constraint(
+ self, current_object, q_lowpass, q_highpass, butterworth_order
+ ):
+ """
+ 2D Butterworth filter
+ Used for low/high-pass filtering object.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ q_lowpass: float
+ Cut-off frequency in A^-1 for low-pass butterworth filter
+ q_highpass: float
+ Cut-off frequency in A^-1 for high-pass butterworth filter
+ butterworth_order: float
+ Butterworth filter order. Smaller gives a smoother filter
+
+ Returns
+ --------
+ constrained_object: np.ndarray
+ Constrained object estimate
+ """
+ xp = self._xp
+ qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0])
+ qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1])
+ qya, qxa = xp.meshgrid(qy, qx)
+ qra = xp.sqrt(qxa**2 + qya**2)
+
+ env = xp.ones_like(qra)
+ if q_highpass:
+ env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * butterworth_order))
+ if q_lowpass:
+ env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order))
+
+ current_object_mean = xp.mean(current_object)
+ current_object -= current_object_mean
+ current_object = xp.fft.ifft2(xp.fft.fft2(current_object) * env[None])
+ current_object += current_object_mean
+
+ if self._object_type == "potential":
+ current_object = xp.real(current_object)
+
+ return current_object
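+
+    # The Butterworth envelope assembled in _object_butterworth_constraint above,
+    # written out for a single radial frequency q and filter order n:
+    #
+    #   low-pass:  H_lp(q) = 1 / (1 + (q / q_lowpass) ** (2 * n))
+    #   high-pass: H_hp(q) = 1 - 1 / (1 + (q / q_highpass) ** (2 * n))
+    #
+    # The object spectrum is multiplied by H_lp * H_hp, with the object mean
+    # subtracted first and restored afterwards.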
+
+ def _object_kz_regularization_constraint(
+ self, current_object, kz_regularization_gamma
+ ):
+ """
+ Arctan regularization filter
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ kz_regularization_gamma: float
+ Slice regularization strength
+
+ Returns
+ --------
+ constrained_object: np.ndarray
+ Constrained object estimate
+ """
+ xp = self._xp
+
+ current_object = xp.pad(
+ current_object, pad_width=((1, 0), (0, 0), (0, 0)), mode="constant"
+ )
+
+ qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0])
+ qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1])
+ qz = xp.fft.fftfreq(current_object.shape[0], self._slice_thicknesses[0])
+
+ kz_regularization_gamma *= self._slice_thicknesses[0] / self.sampling[0]
+
+ qza, qxa, qya = xp.meshgrid(qz, qx, qy, indexing="ij")
+ qz2 = qza**2 * kz_regularization_gamma**2
+ qr2 = qxa**2 + qya**2
+
+ w = 1 - 2 / np.pi * xp.arctan2(qz2, qr2)
+
+ current_object = xp.fft.ifftn(xp.fft.fftn(current_object) * w)
+ current_object = current_object[1:]
+
+ if self._object_type == "potential":
+ current_object = xp.real(current_object)
+
+ return current_object
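+
+    # The arctan weighting built in _object_kz_regularization_constraint above, for
+    # scaled beam-direction frequency qz and in-plane frequency qr:
+    #
+    #   w(qz, qr) = 1 - (2 / pi) * arctan2(gamma**2 * qz**2, qr**2)
+    #
+    # w -> 1 for purely in-plane features and w -> 0 for rapid slice-to-slice
+    # variation, damping unphysical high-kz structure in the object.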
+
+ def _object_identical_slices_constraint(self, current_object):
+ """
+ Strong regularization forcing all slices to be identical
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+
+ Returns
+ --------
+ constrained_object: np.ndarray
+ Constrained object estimate
+ """
+ object_mean = current_object.mean(0, keepdims=True)
+ current_object[:] = object_mean
+
+ return current_object
+
+ def _object_denoise_tv_pylops(self, current_object, weights, iterations):
+ """
+        Performs second-order TV denoising along x and y (and first-order along z)
+
+ Parameters
+ ----------
+ current_object: np.ndarray
+ Current object estimate
+        weights : [float, float]
+            Denoising weights [z weight, r weight]. The greater the weight,
+            the more denoising.
+        iterations: int
+            Number of iterations to run in the denoising algorithm
+            (`niter_out` in pylops)
+
+ Returns
+ -------
+ constrained_object: np.ndarray
+ Constrained object estimate
+
+ """
+ xp = self._xp
+
+ if xp.iscomplexobj(current_object):
+ current_object_tv = current_object
+ warnings.warn(
+ ("TV denoising is currently only supported for potential objects."),
+ UserWarning,
+ )
+
+ else:
+ # zero pad at top and bottom slice
+ pad_width = ((1, 1), (0, 0), (0, 0))
+ current_object = xp.pad(
+ current_object, pad_width=pad_width, mode="constant"
+ )
+
+ # run tv denoising
+ nz, nx, ny = current_object.shape
+ niter_out = iterations
+ niter_in = 1
+ Iop = pylops.Identity(nx * ny * nz)
+
+ if weights[0] == 0:
+ xy_laplacian = pylops.Laplacian(
+ (nz, nx, ny), axes=(1, 2), edge=False, kind="backward"
+ )
+ l1_regs = [xy_laplacian]
+
+ current_object_tv = pylops.optimization.sparsity.splitbregman(
+ Op=Iop,
+ y=current_object.ravel(),
+ RegsL1=l1_regs,
+ niter_outer=niter_out,
+ niter_inner=niter_in,
+ epsRL1s=[weights[1]],
+ tol=1e-4,
+ tau=1.0,
+ show=False,
+ )[0]
+
+ elif weights[1] == 0:
+ z_gradient = pylops.FirstDerivative(
+ (nz, nx, ny), axis=0, edge=False, kind="backward"
+ )
+ l1_regs = [z_gradient]
+
+ current_object_tv = pylops.optimization.sparsity.splitbregman(
+ Op=Iop,
+ y=current_object.ravel(),
+ RegsL1=l1_regs,
+ niter_outer=niter_out,
+ niter_inner=niter_in,
+ epsRL1s=[weights[0]],
+ tol=1e-4,
+ tau=1.0,
+ show=False,
+ )[0]
+
+ else:
+ z_gradient = pylops.FirstDerivative(
+ (nz, nx, ny), axis=0, edge=False, kind="backward"
+ )
+ xy_laplacian = pylops.Laplacian(
+ (nz, nx, ny), axes=(1, 2), edge=False, kind="backward"
+ )
+ l1_regs = [z_gradient, xy_laplacian]
+
+ current_object_tv = pylops.optimization.sparsity.splitbregman(
+ Op=Iop,
+ y=current_object.ravel(),
+ RegsL1=l1_regs,
+ niter_outer=niter_out,
+ niter_inner=niter_in,
+ epsRL1s=weights,
+ tol=1e-4,
+ tau=1.0,
+ show=False,
+ )[0]
+
+ # remove padding
+ current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1]
+
+ return current_object_tv
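+
+    # The split-Bregman solve in _object_denoise_tv_pylops above minimizes,
+    # schematically (weights w_z, w_r as in the docstring),
+    #
+    #   min_x ||x - y||_2^2 + w_z * ||D_z x||_1 + w_r * ||L_xy x||_1
+    #
+    # where D_z is a first derivative along the slice axis and L_xy a Laplacian in
+    # the slice plane; setting either weight to zero drops the corresponding term.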
+
+ def _constraints(
+ self,
+ current_object,
+ current_probe,
+ current_positions,
+ fix_com,
+ fit_probe_aberrations,
+ fit_probe_aberrations_max_angular_order,
+ fit_probe_aberrations_max_radial_order,
+ constrain_probe_amplitude,
+ constrain_probe_amplitude_relative_radius,
+ constrain_probe_amplitude_relative_width,
+ constrain_probe_fourier_amplitude,
+ constrain_probe_fourier_amplitude_max_width_pixels,
+ constrain_probe_fourier_amplitude_constant_intensity,
+ fix_probe_aperture,
+ initial_probe_aperture,
+ fix_positions,
+ global_affine_transformation,
+ gaussian_filter,
+ gaussian_filter_sigma,
+ butterworth_filter,
+ q_lowpass,
+ q_highpass,
+ butterworth_order,
+ kz_regularization_filter,
+ kz_regularization_gamma,
+ identical_slices,
+ object_positivity,
+ shrinkage_rad,
+ object_mask,
+ pure_phase_object,
+ tv_denoise_chambolle,
+ tv_denoise_weight_chambolle,
+ tv_denoise_pad_chambolle,
+ tv_denoise,
+ tv_denoise_weights,
+ tv_denoise_inner_iter,
+ ):
+ """
+ Ptychographic constraints operator.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+ current_positions: np.ndarray
+ Current positions estimate
+ fix_com: bool
+ If True, probe CoM is fixed to the center
+ fit_probe_aberrations: bool
+ If True, fits the probe aberrations to a low-order expansion
+        fit_probe_aberrations_max_angular_order: int
+            Max angular order of probe aberrations basis functions
+        fit_probe_aberrations_max_radial_order: int
+            Max radial order of probe aberrations basis functions
+ constrain_probe_amplitude: bool
+ If True, probe amplitude is constrained by top hat function
+ constrain_probe_amplitude_relative_radius: float
+ Relative location of top-hat inflection point, between 0 and 0.5
+ constrain_probe_amplitude_relative_width: float
+ Relative width of top-hat sigmoid, between 0 and 0.5
+ constrain_probe_fourier_amplitude: bool
+ If True, probe aperture is constrained by fitting a sigmoid for each angular frequency.
+ constrain_probe_fourier_amplitude_max_width_pixels: float
+ Maximum pixel width of fitted sigmoid functions.
+ constrain_probe_fourier_amplitude_constant_intensity: bool
+ If True, the probe aperture is additionally constrained to a constant intensity.
+ fix_probe_aperture: bool
+ If True, probe Fourier amplitude is replaced by initial_probe_aperture
+ initial_probe_aperture: np.ndarray
+ Initial probe aperture to use in replacing probe Fourier amplitude
+        fix_positions: bool
+            If True, positions are not updated
+        global_affine_transformation: bool
+            If True, positions are assumed to be a global affine transform from initial scan
+ gaussian_filter: bool
+ If True, applies real-space gaussian filter in A
+ gaussian_filter_sigma: float
+ Standard deviation of gaussian kernel
+ butterworth_filter: bool
+ If True, applies fourier-space butterworth filter
+ q_lowpass: float
+ Cut-off frequency in A^-1 for low-pass butterworth filter
+ q_highpass: float
+ Cut-off frequency in A^-1 for high-pass butterworth filter
+ butterworth_order: float
+ Butterworth filter order. Smaller gives a smoother filter
+ kz_regularization_filter: bool
+ If True, applies fourier-space arctan regularization filter
+ kz_regularization_gamma: float
+ Slice regularization strength
+ identical_slices: bool
+ If True, forces all object slices to be identical
+ object_positivity: bool
+ If True, forces object to be positive
+ shrinkage_rad: float
+ Phase shift in radians to be subtracted from the potential at each iteration
+ object_mask: np.ndarray (boolean)
+ If not None, used to calculate additional shrinkage using masked-mean of object
+ pure_phase_object: bool
+ If True, object amplitude is set to unity
+ tv_denoise_chambolle: bool
+ If True, performs TV denoising along z
+        tv_denoise_weight_chambolle: float
+            Weight of TV denoising constraint
+        tv_denoise_pad_chambolle: bool
+            If True, pads object at top and bottom with zeros before applying denoising
+ tv_denoise: bool
+ If True, applies TV denoising on object
+        tv_denoise_weights: [float, float]
+            Denoising weights [z weight, r weight]. The greater the weight,
+            the more denoising.
+        tv_denoise_inner_iter: int
+            Number of iterations to run in inner loop of TV denoising
+
+ Returns
+ --------
+ constrained_object: np.ndarray
+ Constrained object estimate
+ constrained_probe: np.ndarray
+ Constrained probe estimate
+ constrained_positions: np.ndarray
+ Constrained positions estimate
+ """
+
+ if gaussian_filter:
+ current_object = self._object_gaussian_constraint(
+ current_object, gaussian_filter_sigma, pure_phase_object
+ )
+
+ if butterworth_filter:
+ current_object = self._object_butterworth_constraint(
+ current_object,
+ q_lowpass,
+ q_highpass,
+ butterworth_order,
+ )
+
+ if identical_slices:
+ current_object = self._object_identical_slices_constraint(current_object)
+ elif kz_regularization_filter:
+ current_object = self._object_kz_regularization_constraint(
+ current_object, kz_regularization_gamma
+ )
+ elif tv_denoise:
+ current_object = self._object_denoise_tv_pylops(
+ current_object,
+ tv_denoise_weights,
+ tv_denoise_inner_iter,
+ )
+ elif tv_denoise_chambolle:
+ current_object = self._object_denoise_tv_chambolle(
+ current_object,
+ tv_denoise_weight_chambolle,
+ axis=0,
+ pad_object=tv_denoise_pad_chambolle,
+ )
+
+ if shrinkage_rad > 0.0 or object_mask is not None:
+ current_object = self._object_shrinkage_constraint(
+ current_object,
+ shrinkage_rad,
+ object_mask,
+ )
+
+ if self._object_type == "complex":
+ current_object = self._object_threshold_constraint(
+ current_object, pure_phase_object
+ )
+ elif object_positivity:
+ current_object = self._object_positivity_constraint(current_object)
+
+ if fix_com:
+ current_probe = self._probe_center_of_mass_constraint(current_probe)
+
+ if fix_probe_aperture:
+ current_probe = self._probe_aperture_constraint(
+ current_probe,
+ initial_probe_aperture,
+ )
+ elif constrain_probe_fourier_amplitude:
+ current_probe = self._probe_fourier_amplitude_constraint(
+ current_probe,
+ constrain_probe_fourier_amplitude_max_width_pixels,
+ constrain_probe_fourier_amplitude_constant_intensity,
+ )
+
+ if fit_probe_aberrations:
+ current_probe = self._probe_aberration_fitting_constraint(
+ current_probe,
+ fit_probe_aberrations_max_angular_order,
+ fit_probe_aberrations_max_radial_order,
+ )
+
+ if constrain_probe_amplitude:
+ current_probe = self._probe_amplitude_constraint(
+ current_probe,
+ constrain_probe_amplitude_relative_radius,
+ constrain_probe_amplitude_relative_width,
+ )
+
+ if not fix_positions:
+ current_positions = self._positions_center_of_mass_constraint(
+ current_positions
+ )
+
+ if global_affine_transformation:
+ current_positions = self._positions_affine_transformation_constraint(
+ self._positions_px_initial, current_positions
+ )
+
+ return current_object, current_probe, current_positions
+
+ def reconstruct(
+ self,
+ max_iter: int = 64,
+ reconstruction_method: str = "gradient-descent",
+ reconstruction_parameter: float = 1.0,
+ reconstruction_parameter_a: float = None,
+ reconstruction_parameter_b: float = None,
+ reconstruction_parameter_c: float = None,
+ max_batch_size: int = None,
+ seed_random: int = None,
+ step_size: float = 0.5,
+ normalization_min: float = 1,
+ positions_step_size: float = 0.9,
+ fix_com: bool = True,
+ fix_probe_iter: int = 0,
+ fix_probe_aperture_iter: int = 0,
+ constrain_probe_amplitude_iter: int = 0,
+ constrain_probe_amplitude_relative_radius: float = 0.5,
+ constrain_probe_amplitude_relative_width: float = 0.05,
+ constrain_probe_fourier_amplitude_iter: int = 0,
+ constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0,
+ constrain_probe_fourier_amplitude_constant_intensity: bool = False,
+ fix_positions_iter: int = np.inf,
+ constrain_position_distance: float = None,
+ global_affine_transformation: bool = True,
+ gaussian_filter_sigma: float = None,
+ gaussian_filter_iter: int = np.inf,
+ fit_probe_aberrations_iter: int = 0,
+ fit_probe_aberrations_max_angular_order: int = 4,
+ fit_probe_aberrations_max_radial_order: int = 4,
+ butterworth_filter_iter: int = np.inf,
+ q_lowpass: float = None,
+ q_highpass: float = None,
+ butterworth_order: float = 2,
+ kz_regularization_filter_iter: int = np.inf,
+ kz_regularization_gamma: Union[float, np.ndarray] = None,
+ identical_slices_iter: int = 0,
+ object_positivity: bool = True,
+ shrinkage_rad: float = 0.0,
+ fix_potential_baseline: bool = True,
+ pure_phase_object_iter: int = 0,
+ tv_denoise_iter_chambolle=np.inf,
+ tv_denoise_weight_chambolle=None,
+ tv_denoise_pad_chambolle=True,
+ tv_denoise_iter=np.inf,
+ tv_denoise_weights=None,
+ tv_denoise_inner_iter=40,
+ switch_object_iter: int = np.inf,
+ store_iterations: bool = False,
+ progress_bar: bool = True,
+ reset: bool = None,
+ ):
+ """
+ Ptychographic reconstruction main method.
+
+ Parameters
+ --------
+ max_iter: int, optional
+ Maximum number of iterations to run
+ reconstruction_method: str, optional
+ Specifies which reconstruction algorithm to use, one of:
+ "generalized-projections",
+ "DM_AP" (or "difference-map_alternating-projections"),
+ "RAAR" (or "relaxed-averaged-alternating-reflections"),
+ "RRR" (or "relax-reflect-reflect"),
+ "SUPERFLIP" (or "charge-flipping"), or
+ "GD" (or "gradient_descent")
+ reconstruction_parameter: float, optional
+ Reconstruction parameter for various reconstruction methods above.
+ reconstruction_parameter_a: float, optional
+ Reconstruction parameter a for reconstruction_method='generalized-projections'.
+ reconstruction_parameter_b: float, optional
+ Reconstruction parameter b for reconstruction_method='generalized-projections'.
+ reconstruction_parameter_c: float, optional
+ Reconstruction parameter c for reconstruction_method='generalized-projections'.
+ max_batch_size: int, optional
+ Max number of probes to update at once
+ seed_random: int, optional
+ Seeds the random number generator, only applicable when max_batch_size is not None
+ step_size: float, optional
+ Update step size
+ normalization_min: float, optional
+ Probe normalization minimum as a fraction of the maximum overlap intensity
+ positions_step_size: float, optional
+ Positions update step size
+ fix_com: bool, optional
+ If True, fixes center of mass of probe
+ fix_probe_iter: int, optional
+ Number of iterations to run with a fixed probe before updating probe estimate
+ fix_probe_aperture_iter: int, optional
+ Number of iterations to run with a fixed probe Fourier amplitude before updating probe estimate
+ constrain_probe_amplitude_iter: int, optional
+ Number of iterations to run while constraining the real-space probe with a top-hat support.
+ constrain_probe_amplitude_relative_radius: float
+ Relative location of top-hat inflection point, between 0 and 0.5
+ constrain_probe_amplitude_relative_width: float
+ Relative width of top-hat sigmoid, between 0 and 0.5
+ constrain_probe_fourier_amplitude_iter: int, optional
+ Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency.
+ constrain_probe_fourier_amplitude_max_width_pixels: float
+ Maximum pixel width of fitted sigmoid functions.
+ constrain_probe_fourier_amplitude_constant_intensity: bool
+ If True, the probe aperture is additionally constrained to a constant intensity.
+ fix_positions_iter: int, optional
+ Number of iterations to run with fixed positions before updating positions estimate
+ constrain_position_distance: float
+ Distance to constrain position correction within original field of view in A
+ global_affine_transformation: bool, optional
+ If True, positions are assumed to be a global affine transform from initial scan
+ gaussian_filter_sigma: float, optional
+ Standard deviation of gaussian kernel in A
+ gaussian_filter_iter: int, optional
+ Number of iterations to run using object smoothness constraint
+ fit_probe_aberrations_iter: int, optional
+ Number of iterations to run while fitting the probe aberrations to a low-order expansion
+        fit_probe_aberrations_max_angular_order: int
+            Max angular order of probe aberrations basis functions
+        fit_probe_aberrations_max_radial_order: int
+            Max radial order of probe aberrations basis functions
+        butterworth_filter_iter: int, optional
+            Number of iterations to run using high-pass butterworth filter
+ q_lowpass: float
+ Cut-off frequency in A^-1 for low-pass butterworth filter
+ q_highpass: float
+ Cut-off frequency in A^-1 for high-pass butterworth filter
+ butterworth_order: float
+ Butterworth filter order. Smaller gives a smoother filter
+ kz_regularization_filter_iter: int, optional
+ Number of iterations to run using kz regularization filter
+        kz_regularization_gamma: float, optional
+            kz regularization strength
+ identical_slices_iter: int, optional
+ Number of iterations to run using identical slices
+ object_positivity: bool, optional
+ If True, forces object to be positive
+ shrinkage_rad: float
+ Phase shift in radians to be subtracted from the potential at each iteration
+ fix_potential_baseline: bool
+ If true, the potential mean outside the FOV is forced to zero at each iteration
+ pure_phase_object_iter: int, optional
+ Number of iterations where object amplitude is set to unity
+        tv_denoise_iter_chambolle: int, optional
+            Number of iterations to run using Chambolle TV denoising
+        tv_denoise_weight_chambolle: float
+            Weight of TV denoising constraint
+        tv_denoise_pad_chambolle: bool
+            If True, pads object at top and bottom with zeros before applying denoising
+        tv_denoise_iter: int, optional
+            Number of iterations to run using TV denoising on object
+        tv_denoise_weights: [float, float]
+            Denoising weights [z weight, r weight]. The greater the weight,
+            the more denoising.
+        tv_denoise_inner_iter: int
+            Number of iterations to run in inner loop of TV denoising
+ switch_object_iter: int, optional
+ Iteration to switch object type between 'complex' and 'potential' or between
+ 'potential' and 'complex'
+ store_iterations: bool, optional
+ If True, reconstructed objects and probes are stored at each iteration
+ progress_bar: bool, optional
+ If True, reconstruction progress is displayed
+ reset: bool, optional
+ If True, previous reconstructions are ignored
+
+ Returns
+ --------
+ self: MultislicePtychographicReconstruction
+ Self to accommodate chaining
+ """
+ asnumpy = self._asnumpy
+ xp = self._xp
+
+ # Reconstruction method
+
+ if reconstruction_method == "generalized-projections":
+ if (
+ reconstruction_parameter_a is None
+ or reconstruction_parameter_b is None
+ or reconstruction_parameter_c is None
+ ):
+ raise ValueError(
+ (
+ "reconstruction_parameter_a/b/c must all be specified "
+ "when using reconstruction_method='generalized-projections'."
+ )
+ )
+
+ use_projection_scheme = True
+ projection_a = reconstruction_parameter_a
+ projection_b = reconstruction_parameter_b
+ projection_c = reconstruction_parameter_c
+ step_size = None
+ elif (
+ reconstruction_method == "DM_AP"
+ or reconstruction_method == "difference-map_alternating-projections"
+ ):
+ if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0:
+ raise ValueError("reconstruction_parameter must be between 0-1.")
+
+ use_projection_scheme = True
+ projection_a = -reconstruction_parameter
+ projection_b = 1
+ projection_c = 1 + reconstruction_parameter
+ step_size = None
+ elif (
+ reconstruction_method == "RAAR"
+ or reconstruction_method == "relaxed-averaged-alternating-reflections"
+ ):
+ if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0:
+ raise ValueError("reconstruction_parameter must be between 0-1.")
+
+ use_projection_scheme = True
+ projection_a = 1 - 2 * reconstruction_parameter
+ projection_b = reconstruction_parameter
+ projection_c = 2
+ step_size = None
+ elif (
+ reconstruction_method == "RRR"
+ or reconstruction_method == "relax-reflect-reflect"
+ ):
+ if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0:
+ raise ValueError("reconstruction_parameter must be between 0-2.")
+
+ use_projection_scheme = True
+ projection_a = -reconstruction_parameter
+ projection_b = reconstruction_parameter
+ projection_c = 2
+ step_size = None
+ elif (
+ reconstruction_method == "SUPERFLIP"
+ or reconstruction_method == "charge-flipping"
+ ):
+ use_projection_scheme = True
+ projection_a = 0
+ projection_b = 1
+ projection_c = 2
+ reconstruction_parameter = None
+ step_size = None
+ elif (
+ reconstruction_method == "GD" or reconstruction_method == "gradient-descent"
+ ):
+ use_projection_scheme = False
+ projection_a = None
+ projection_b = None
+ projection_c = None
+ reconstruction_parameter = None
+ else:
+ raise ValueError(
+ (
+ "reconstruction_method must be one of 'generalized-projections', "
+ "'DM_AP' (or 'difference-map_alternating-projections'), "
+ "'RAAR' (or 'relaxed-averaged-alternating-reflections'), "
+ "'RRR' (or 'relax-reflect-reflect'), "
+ "'SUPERFLIP' (or 'charge-flipping'), "
+ f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}."
+ )
+ )
+
+        # batching is incompatible with the projection-set schemes, so check this
+        # regardless of verbosity
+        if max_batch_size is not None and use_projection_scheme:
+            raise ValueError(
+                (
+                    "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. "
+                    "Use reconstruction_method='GD' or set max_batch_size=None."
+                )
+            )
+
+        if self._verbose:
+            if switch_object_iter > max_iter:
+                first_line = f"Performing {max_iter} iterations using a {self._object_type} object type, "
+            else:
+                switch_object_type = (
+                    "complex" if self._object_type == "potential" else "potential"
+                )
+                first_line = (
+                    f"Performing {switch_object_iter} iterations using a {self._object_type} object type and "
+                    f"{max_iter - switch_object_iter} iterations using a {switch_object_type} object type, "
+                )
+            if max_batch_size is not None:
+                print(
+                    (
+                        first_line + f"with the {reconstruction_method} algorithm, "
+                        f"with normalization_min: {normalization_min} and step_size: {step_size}, "
+                        f"in batches of max {max_batch_size} measurements."
+                    )
+                )
+            elif reconstruction_parameter is not None:
+                if np.array(reconstruction_parameter).shape == (3,):
+                    print(
+                        (
+                            first_line
+                            + f"with the {reconstruction_method} algorithm, "
+                            f"with normalization_min: {normalization_min} and (a,b,c): {reconstruction_parameter}."
+                        )
+                    )
+                else:
+                    print(
+                        (
+                            first_line
+                            + f"with the {reconstruction_method} algorithm, "
+                            f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}."
+                        )
+                    )
+            elif step_size is not None:
+                print(
+                    (
+                        first_line
+                        + f"with the {reconstruction_method} algorithm, "
+                        f"with normalization_min: {normalization_min} and step_size: {step_size}."
+                    )
+                )
+            else:
+                print(
+                    (
+                        first_line
+                        + f"with the {reconstruction_method} algorithm, "
+                        f"with normalization_min: {normalization_min}."
+                    )
+                )
+
+ # Batching
+ shuffled_indices = np.arange(self._num_diffraction_patterns)
+ unshuffled_indices = np.zeros_like(shuffled_indices)
+
+        if max_batch_size is not None:
+            # shuffling below uses np.random, so seed NumPy's generator
+            np.random.seed(seed_random)
+ else:
+ max_batch_size = self._num_diffraction_patterns
+
+ # initialization
+ if store_iterations and (not hasattr(self, "object_iterations") or reset):
+ self.object_iterations = []
+ self.probe_iterations = []
+
+ if reset:
+ self.error_iterations = []
+ self._object = self._object_initial.copy()
+ self._probe = self._probe_initial.copy()
+ self._positions_px = self._positions_px_initial.copy()
+ self._positions_px_fractional = self._positions_px - xp.round(
+ self._positions_px
+ )
+ (
+ self._vectorized_patch_indices_row,
+ self._vectorized_patch_indices_col,
+ ) = self._extract_vectorized_patch_indices()
+ self._exit_waves = None
+ self._object_type = self._object_type_initial
+ if hasattr(self, "_tf"):
+ del self._tf
+ elif reset is None:
+ if hasattr(self, "error"):
+ warnings.warn(
+ (
+ "Continuing reconstruction from previous result. "
+ "Use reset=True for a fresh start."
+ ),
+ UserWarning,
+ )
+ else:
+ self.error_iterations = []
+ self._exit_waves = None
+
+ # main loop
+ for a0 in tqdmnd(
+ max_iter,
+ desc="Reconstructing object and probe",
+ unit=" iter",
+ disable=not progress_bar,
+ ):
+ error = 0.0
+
+ if a0 == switch_object_iter:
+ if self._object_type == "potential":
+ self._object_type = "complex"
+ self._object = xp.exp(1j * self._object)
+ elif self._object_type == "complex":
+ self._object_type = "potential"
+ self._object = xp.angle(self._object)
+
+ # randomize
+ if not use_projection_scheme:
+ np.random.shuffle(shuffled_indices)
+ unshuffled_indices[shuffled_indices] = np.arange(
+ self._num_diffraction_patterns
+ )
+ positions_px = self._positions_px.copy()[shuffled_indices]
+
+ for start, end in generate_batches(
+ self._num_diffraction_patterns, max_batch=max_batch_size
+ ):
+ # batch indices
+ self._positions_px = positions_px[start:end]
+ self._positions_px_fractional = self._positions_px - xp.round(
+ self._positions_px
+ )
+ (
+ self._vectorized_patch_indices_row,
+ self._vectorized_patch_indices_col,
+ ) = self._extract_vectorized_patch_indices()
+ amplitudes = self._amplitudes[shuffled_indices[start:end]]
+
+ # forward operator
+ (
+ propagated_probes,
+ object_patches,
+ self._transmitted_probes,
+ self._exit_waves,
+ batch_error,
+ ) = self._forward(
+ self._object,
+ self._probe,
+ amplitudes,
+ self._exit_waves,
+ use_projection_scheme,
+ projection_a,
+ projection_b,
+ projection_c,
+ )
+
+ # adjoint operator
+ self._object, self._probe = self._adjoint(
+ self._object,
+ self._probe,
+ object_patches,
+ propagated_probes,
+ self._exit_waves,
+ use_projection_scheme=use_projection_scheme,
+ step_size=step_size,
+ normalization_min=normalization_min,
+ fix_probe=a0 < fix_probe_iter,
+ )
+
+ # position correction
+ if a0 >= fix_positions_iter:
+ positions_px[start:end] = self._position_correction(
+ self._object,
+ self._probe,
+ self._transmitted_probes,
+ amplitudes,
+ self._positions_px,
+ positions_step_size,
+ constrain_position_distance,
+ )
+
+ error += batch_error
+
+ # Normalize Error
+ error /= self._mean_diffraction_intensity * self._num_diffraction_patterns
+
+ # constraints
+ self._positions_px = positions_px.copy()[unshuffled_indices]
+ self._object, self._probe, self._positions_px = self._constraints(
+ self._object,
+ self._probe,
+ self._positions_px,
+ fix_com=fix_com and a0 >= fix_probe_iter,
+ constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter
+ and a0 >= fix_probe_iter,
+ constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius,
+ constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width,
+ constrain_probe_fourier_amplitude=a0
+ < constrain_probe_fourier_amplitude_iter
+ and a0 >= fix_probe_iter,
+ constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels,
+ constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity,
+ fit_probe_aberrations=a0 < fit_probe_aberrations_iter
+ and a0 >= fix_probe_iter,
+ fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order,
+ fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order,
+ fix_probe_aperture=a0 < fix_probe_aperture_iter,
+ initial_probe_aperture=self._probe_initial_aperture,
+ fix_positions=a0 < fix_positions_iter,
+ global_affine_transformation=global_affine_transformation,
+ gaussian_filter=a0 < gaussian_filter_iter
+ and gaussian_filter_sigma is not None,
+ gaussian_filter_sigma=gaussian_filter_sigma,
+ butterworth_filter=a0 < butterworth_filter_iter
+ and (q_lowpass is not None or q_highpass is not None),
+ q_lowpass=q_lowpass,
+ q_highpass=q_highpass,
+ butterworth_order=butterworth_order,
+ kz_regularization_filter=a0 < kz_regularization_filter_iter
+ and kz_regularization_gamma is not None,
+ kz_regularization_gamma=kz_regularization_gamma[a0]
+ if kz_regularization_gamma is not None
+ and isinstance(kz_regularization_gamma, np.ndarray)
+ else kz_regularization_gamma,
+ identical_slices=a0 < identical_slices_iter,
+ object_positivity=object_positivity,
+ shrinkage_rad=shrinkage_rad,
+ object_mask=self._object_fov_mask_inverse
+ if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0
+ else None,
+ pure_phase_object=a0 < pure_phase_object_iter
+ and self._object_type == "complex",
+ tv_denoise_chambolle=a0 < tv_denoise_iter_chambolle
+ and tv_denoise_weight_chambolle is not None,
+ tv_denoise_weight_chambolle=tv_denoise_weight_chambolle,
+ tv_denoise_pad_chambolle=tv_denoise_pad_chambolle,
+ tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None,
+ tv_denoise_weights=tv_denoise_weights,
+ tv_denoise_inner_iter=tv_denoise_inner_iter,
+ )
+
+ self.error_iterations.append(error.item())
+ if store_iterations:
+ self.object_iterations.append(asnumpy(self._object.copy()))
+ self.probe_iterations.append(self.probe_centered)
+
+ # store result
+ self.object = asnumpy(self._object)
+ self.probe = self.probe_centered
+ self.error = error.item()
+
+ if self._device == "gpu":
+ xp._default_memory_pool.free_all_blocks()
+ xp.clear_memo()
+
+ return self
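+
+    # Typical usage (hypothetical parameter values; assumes the reconstruction
+    # object has been preprocessed):
+    #
+    #   ptycho = ptycho.reconstruct(
+    #       max_iter=32,
+    #       reconstruction_method="GD",
+    #       step_size=0.5,
+    #       normalization_min=1,
+    #       store_iterations=True,
+    #   ).visualize()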
+
+ def _visualize_last_iteration_figax(
+ self,
+ fig,
+ object_ax,
+ convergence_ax,
+ cbar: bool,
+ padding: int = 0,
+ **kwargs,
+ ):
+ """
+ Displays last reconstructed object on a given fig/ax.
+
+ Parameters
+ --------
+ fig: Figure
+ Matplotlib figure object_ax lives in
+ object_ax: Axes
+ Matplotlib axes to plot reconstructed object in
+ convergence_ax: Axes, optional
+ Matplotlib axes to plot convergence plot in
+ cbar: bool, optional
+ If true, displays a colorbar
+ padding : int, optional
+            Pixels to pad by after rotating and cropping the object
+ """
+ cmap = kwargs.pop("cmap", "magma")
+
+ if self._object_type == "complex":
+ obj = np.angle(self.object)
+ else:
+ obj = self.object
+
+ rotated_object = self._crop_rotate_object_fov(
+ np.sum(obj, axis=0), padding=padding
+ )
+ rotated_shape = rotated_object.shape
+
+ extent = [
+ 0,
+ self.sampling[1] * rotated_shape[1],
+ self.sampling[0] * rotated_shape[0],
+ 0,
+ ]
+
+ im = object_ax.imshow(
+ rotated_object,
+ extent=extent,
+ cmap=cmap,
+ **kwargs,
+ )
+
+ if cbar:
+ divider = make_axes_locatable(object_ax)
+ ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
+ fig.add_axes(ax_cb)
+ fig.colorbar(im, cax=ax_cb)
+
+ if convergence_ax is not None and hasattr(self, "error_iterations"):
+            kwargs.pop("vmin", None)
+            kwargs.pop("vmax", None)
+            errors = np.array(self.error_iterations)
+
+ convergence_ax.semilogy(np.arange(len(errors)), errors, **kwargs)
+
+ def _visualize_last_iteration(
+ self,
+ fig,
+ cbar: bool,
+ plot_convergence: bool,
+ plot_probe: bool,
+ plot_fourier_probe: bool,
+ padding: int,
+ **kwargs,
+ ):
+ """
+ Displays last reconstructed object and probe iterations.
+
+ Parameters
+ --------
+ fig: Figure
+ Matplotlib figure to place Gridspec in
+ plot_convergence: bool, optional
+ If true, the normalized mean squared error (NMSE) plot is displayed
+ cbar: bool, optional
+ If true, displays a colorbar
+ plot_probe: bool
+ If true, the reconstructed probe intensity is also displayed
+ plot_fourier_probe: bool, optional
+ If true, the reconstructed complex Fourier probe is displayed
+ padding : int, optional
+            Pixels to pad by after rotating and cropping the object
+
+ """
+ figsize = kwargs.pop("figsize", (8, 5))
+ cmap = kwargs.pop("cmap", "magma")
+
+ if plot_fourier_probe:
+ chroma_boost = kwargs.pop("chroma_boost", 2)
+ else:
+ chroma_boost = kwargs.pop("chroma_boost", 1)
+
+ if self._object_type == "complex":
+ obj = np.angle(self.object)
+ else:
+ obj = self.object
+
+ rotated_object = self._crop_rotate_object_fov(
+ np.sum(obj, axis=0), padding=padding
+ )
+ rotated_shape = rotated_object.shape
+
+ extent = [
+ 0,
+ self.sampling[1] * rotated_shape[1],
+ self.sampling[0] * rotated_shape[0],
+ 0,
+ ]
+
+ if plot_fourier_probe:
+ probe_extent = [
+ -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2,
+ self.angular_sampling[1] * self._region_of_interest_shape[1] / 2,
+ self.angular_sampling[0] * self._region_of_interest_shape[0] / 2,
+ -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2,
+ ]
+ elif plot_probe:
+ probe_extent = [
+ 0,
+ self.sampling[1] * self._region_of_interest_shape[1],
+ self.sampling[0] * self._region_of_interest_shape[0],
+ 0,
+ ]
+
+ if plot_convergence:
+ if plot_probe or plot_fourier_probe:
+ spec = GridSpec(
+ ncols=2,
+ nrows=2,
+ height_ratios=[4, 1],
+ hspace=0.15,
+ width_ratios=[
+ (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]),
+ 1,
+ ],
+ wspace=0.35,
+ )
+ else:
+ spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0.15)
+ else:
+ if plot_probe or plot_fourier_probe:
+ spec = GridSpec(
+ ncols=2,
+ nrows=1,
+ width_ratios=[
+ (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]),
+ 1,
+ ],
+ wspace=0.35,
+ )
+ else:
+ spec = GridSpec(ncols=1, nrows=1)
+
+ if fig is None:
+ fig = plt.figure(figsize=figsize)
+
+ if plot_probe or plot_fourier_probe:
+ # Object
+ ax = fig.add_subplot(spec[0, 0])
+ im = ax.imshow(
+ rotated_object,
+ extent=extent,
+ cmap=cmap,
+ **kwargs,
+ )
+
+ ax.set_ylabel("x [A]")
+ ax.set_xlabel("y [A]")
+ if self._object_type == "potential":
+ ax.set_title("Reconstructed object potential")
+ elif self._object_type == "complex":
+ ax.set_title("Reconstructed object phase")
+
+ if cbar:
+ divider = make_axes_locatable(ax)
+ ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
+ fig.add_axes(ax_cb)
+ fig.colorbar(im, cax=ax_cb)
+
+ # Probe
+ kwargs.pop("vmin", None)
+ kwargs.pop("vmax", None)
+
+ ax = fig.add_subplot(spec[0, 1])
+ if plot_fourier_probe:
+ probe_array = Complex2RGB(
+ self.probe_fourier,
+ chroma_boost=chroma_boost,
+ )
+ ax.set_title("Reconstructed Fourier probe")
+ ax.set_ylabel("kx [mrad]")
+ ax.set_xlabel("ky [mrad]")
+ else:
+ probe_array = Complex2RGB(
+ self.probe, power=2, chroma_boost=chroma_boost
+ )
+ ax.set_title("Reconstructed probe intensity")
+ ax.set_ylabel("x [A]")
+ ax.set_xlabel("y [A]")
+
+ im = ax.imshow(
+ probe_array,
+ extent=probe_extent,
+ )
+
+ if cbar:
+ divider = make_axes_locatable(ax)
+ ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
+ add_colorbar_arg(ax_cb, chroma_boost=chroma_boost)
+
+ else:
+ ax = fig.add_subplot(spec[0])
+ im = ax.imshow(
+ rotated_object,
+ extent=extent,
+ cmap=cmap,
+ **kwargs,
+ )
+ ax.set_ylabel("x [A]")
+ ax.set_xlabel("y [A]")
+ if self._object_type == "potential":
+ ax.set_title("Reconstructed object potential")
+ elif self._object_type == "complex":
+ ax.set_title("Reconstructed object phase")
+
+ if cbar:
+ divider = make_axes_locatable(ax)
+ ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
+ fig.add_axes(ax_cb)
+ fig.colorbar(im, cax=ax_cb)
+
+ if plot_convergence and hasattr(self, "error_iterations"):
+ kwargs.pop("vmin", None)
+ kwargs.pop("vmax", None)
+ errors = np.array(self.error_iterations)
+            if plot_probe or plot_fourier_probe:
+ ax = fig.add_subplot(spec[1, :])
+ else:
+ ax = fig.add_subplot(spec[1])
+ ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs)
+ ax.set_ylabel("NMSE")
+ ax.set_xlabel("Iteration number")
+ ax.yaxis.tick_right()
+
+ fig.suptitle(f"Normalized mean squared error: {self.error:.3e}")
+ spec.tight_layout(fig)
+
+ def _visualize_all_iterations(
+ self,
+ fig,
+ cbar: bool,
+ plot_convergence: bool,
+ plot_probe: bool,
+ plot_fourier_probe: bool,
+ iterations_grid: Tuple[int, int],
+ padding: int,
+ **kwargs,
+ ):
+ """
+ Displays all reconstructed object and probe iterations.
+
+ Parameters
+ --------
+ fig: Figure
+ Matplotlib figure to place Gridspec in
+ plot_convergence: bool, optional
+ If true, the normalized mean squared error (NMSE) plot is displayed
+ iterations_grid: Tuple[int,int]
+ Grid dimensions to plot reconstruction iterations
+ cbar: bool, optional
+ If true, displays a colorbar
+ plot_probe: bool
+ If true, the reconstructed probe intensity is also displayed
+ plot_fourier_probe: bool, optional
+ If true, the reconstructed complex Fourier probe is displayed
+ padding : int, optional
+            Pixels to pad by after rotating and cropping the object
+ """
+ asnumpy = self._asnumpy
+
+ if not hasattr(self, "object_iterations"):
+ raise ValueError(
+ (
+ "Object and probe iterations were not saved during reconstruction. "
+ "Please re-run using store_iterations=True."
+ )
+ )
+
+ if iterations_grid == "auto":
+ num_iter = len(self.error_iterations)
+
+ if num_iter == 1:
+ return self._visualize_last_iteration(
+ fig=fig,
+ plot_convergence=plot_convergence,
+ plot_probe=plot_probe,
+ plot_fourier_probe=plot_fourier_probe,
+ cbar=cbar,
+ padding=padding,
+ **kwargs,
+ )
+ elif plot_probe or plot_fourier_probe:
+ iterations_grid = (2, 4) if num_iter > 4 else (2, num_iter)
+ else:
+ iterations_grid = (2, 4) if num_iter > 8 else (2, num_iter // 2)
+ else:
+            if (plot_probe or plot_fourier_probe) and iterations_grid[0] != 2:
+                raise ValueError(
+                    "probe plotting requires iterations_grid to have two rows."
+                )
+
+ auto_figsize = (
+ (3 * iterations_grid[1], 3 * iterations_grid[0] + 1)
+ if plot_convergence
+ else (3 * iterations_grid[1], 3 * iterations_grid[0])
+ )
+ figsize = kwargs.pop("figsize", auto_figsize)
+ cmap = kwargs.pop("cmap", "magma")
+
+ if plot_fourier_probe:
+ chroma_boost = kwargs.pop("chroma_boost", 2)
+ else:
+ chroma_boost = kwargs.pop("chroma_boost", 1)
+
+ errors = np.array(self.error_iterations)
+
+ objects = []
+ object_type = []
+
+ for obj in self.object_iterations:
+ if np.iscomplexobj(obj):
+ obj = np.angle(obj)
+ object_type.append("phase")
+ else:
+ object_type.append("potential")
+ objects.append(
+ self._crop_rotate_object_fov(np.sum(obj, axis=0), padding=padding)
+ )
+
+ if plot_probe or plot_fourier_probe:
+ total_grids = (np.prod(iterations_grid) / 2).astype("int")
+ probes = self.probe_iterations
+ else:
+ total_grids = np.prod(iterations_grid)
+ max_iter = len(objects) - 1
+ grid_range = range(0, max_iter + 1, max_iter // (total_grids - 1))
+
+ extent = [
+ 0,
+ self.sampling[1] * objects[0].shape[1],
+ self.sampling[0] * objects[0].shape[0],
+ 0,
+ ]
+
+ if plot_fourier_probe:
+ probe_extent = [
+ -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2,
+ self.angular_sampling[1] * self._region_of_interest_shape[1] / 2,
+ self.angular_sampling[0] * self._region_of_interest_shape[0] / 2,
+ -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2,
+ ]
+ elif plot_probe:
+ probe_extent = [
+ 0,
+ self.sampling[1] * self._region_of_interest_shape[1],
+ self.sampling[0] * self._region_of_interest_shape[0],
+ 0,
+ ]
+
+ if plot_convergence:
+ if plot_probe or plot_fourier_probe:
+ spec = GridSpec(ncols=1, nrows=3, height_ratios=[4, 4, 1], hspace=0)
+ else:
+ spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0)
+ else:
+ if plot_probe or plot_fourier_probe:
+ spec = GridSpec(ncols=1, nrows=2)
+ else:
+ spec = GridSpec(ncols=1, nrows=1)
+
+ if fig is None:
+ fig = plt.figure(figsize=figsize)
+
+ grid = ImageGrid(
+ fig,
+ spec[0],
+ nrows_ncols=(1, iterations_grid[1])
+ if (plot_probe or plot_fourier_probe)
+ else iterations_grid,
+ axes_pad=(0.75, 0.5) if cbar else 0.5,
+ cbar_mode="each" if cbar else None,
+ cbar_pad="2.5%" if cbar else None,
+ )
+
+ for n, ax in enumerate(grid):
+ im = ax.imshow(
+ objects[grid_range[n]],
+ extent=extent,
+ cmap=cmap,
+ **kwargs,
+ )
+ ax.set_title(f"Iter: {grid_range[n]} {object_type[grid_range[n]]}")
+ ax.set_ylabel("x [A]")
+ ax.set_xlabel("y [A]")
+ if cbar:
+ grid.cbar_axes[n].colorbar(im)
+
+ if plot_probe or plot_fourier_probe:
+ kwargs.pop("vmin", None)
+ kwargs.pop("vmax", None)
+ grid = ImageGrid(
+ fig,
+ spec[1],
+ nrows_ncols=(1, iterations_grid[1]),
+ axes_pad=(0.75, 0.5) if cbar else 0.5,
+ cbar_mode="each" if cbar else None,
+ cbar_pad="2.5%" if cbar else None,
+ )
+
+ for n, ax in enumerate(grid):
+ if plot_fourier_probe:
+ probe_array = Complex2RGB(
+ asnumpy(
+ self._return_fourier_probe_from_centered_probe(
+ probes[grid_range[n]]
+ )
+ ),
+ chroma_boost=chroma_boost,
+ )
+ ax.set_title(f"Iter: {grid_range[n]} Fourier probe")
+ ax.set_ylabel("kx [mrad]")
+ ax.set_xlabel("ky [mrad]")
+ else:
+ probe_array = Complex2RGB(
+ probes[grid_range[n]], power=2, chroma_boost=chroma_boost
+ )
+ ax.set_title(f"Iter: {grid_range[n]} probe intensity")
+ ax.set_ylabel("x [A]")
+ ax.set_xlabel("y [A]")
+
+ im = ax.imshow(
+ probe_array,
+ extent=probe_extent,
+ )
+
+ if cbar:
+ add_colorbar_arg(
+ grid.cbar_axes[n],
+ chroma_boost=chroma_boost,
+ )
+
+ if plot_convergence:
+ kwargs.pop("vmin", None)
+ kwargs.pop("vmax", None)
+ if plot_probe:
+ ax2 = fig.add_subplot(spec[2])
+ else:
+ ax2 = fig.add_subplot(spec[1])
+ ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs)
+ ax2.set_ylabel("NMSE")
+ ax2.set_xlabel("Iteration number")
+ ax2.yaxis.tick_right()
+
+ spec.tight_layout(fig)
+
+ def visualize(
+ self,
+ fig=None,
+ iterations_grid: Tuple[int, int] = None,
+ plot_convergence: bool = True,
+ plot_probe: bool = True,
+ plot_fourier_probe: bool = False,
+ cbar: bool = True,
+ padding: int = 0,
+ **kwargs,
+ ):
+ """
+ Displays reconstructed object and probe.
+
+ Parameters
+ --------
+ fig: Figure
+ Matplotlib figure to place Gridspec in
+ plot_convergence: bool, optional
+ If true, the normalized mean squared error (NMSE) plot is displayed
+ iterations_grid: Tuple[int,int]
+ Grid dimensions to plot reconstruction iterations
+ cbar: bool, optional
+ If true, displays a colorbar
+ plot_probe: bool
+ If true, the reconstructed probe intensity is also displayed
+ plot_fourier_probe: bool, optional
+ If true, the reconstructed complex Fourier probe is displayed
+ padding : int, optional
+ Pixels to pad by post rotating-cropping object
+
+ Returns
+ --------
+ self: PtychographicReconstruction
+ Self to accommodate chaining
+ """
+
+ if iterations_grid is None:
+ self._visualize_last_iteration(
+ fig=fig,
+ plot_convergence=plot_convergence,
+ plot_probe=plot_probe,
+ plot_fourier_probe=plot_fourier_probe,
+ cbar=cbar,
+ padding=padding,
+ **kwargs,
+ )
+ else:
+ self._visualize_all_iterations(
+ fig=fig,
+ plot_convergence=plot_convergence,
+ iterations_grid=iterations_grid,
+ plot_probe=plot_probe,
+ plot_fourier_probe=plot_fourier_probe,
+ cbar=cbar,
+ padding=padding,
+ **kwargs,
+ )
+ return self
+
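+ # Example usage (sketch; `ptycho` is an illustrative instance name):
+ # visualize() returns self, so it can be chained after reconstruction:
+ #
+ #   ptycho.reconstruct(max_iter=64, store_iterations=True)
+ #   ptycho.visualize(iterations_grid="auto", plot_fourier_probe=True)
+ #
+ # With iterations_grid=None only the last iteration is shown; "auto"
+ # requires store_iterations=True during reconstruct().
+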
+ def show_transmitted_probe(
+ self,
+ plot_fourier_probe: bool = False,
+ **kwargs,
+ ):
+ """
+ Plots the min, max, and mean transmitted probe after propagation and transmission.
+
+ Parameters
+ ----------
+ plot_fourier_probe: boolean, optional
+ If True, the transmitted probes are also plotted in Fourier space
+ kwargs:
+ Passed to show_complex
+ """
+
+ xp = self._xp
+ asnumpy = self._asnumpy
+
+ transmitted_probe_intensities = xp.sum(
+ xp.abs(self._transmitted_probes) ** 2, axis=(-2, -1)
+ )
+ min_intensity_transmitted = self._transmitted_probes[
+ xp.argmin(transmitted_probe_intensities)
+ ]
+ max_intensity_transmitted = self._transmitted_probes[
+ xp.argmax(transmitted_probe_intensities)
+ ]
+ mean_transmitted = self._transmitted_probes.mean(0)
+ probes = [
+ asnumpy(self._return_centered_probe(probe))
+ for probe in [
+ mean_transmitted,
+ min_intensity_transmitted,
+ max_intensity_transmitted,
+ ]
+ ]
+ title = [
+ "Mean Transmitted Probe",
+ "Min Intensity Transmitted Probe",
+ "Max Intensity Transmitted Probe",
+ ]
+
+ if plot_fourier_probe:
+ bottom_row = [
+ asnumpy(self._return_fourier_probe(probe))
+ for probe in [
+ mean_transmitted,
+ min_intensity_transmitted,
+ max_intensity_transmitted,
+ ]
+ ]
+ probes = [probes, bottom_row]
+
+ title += [
+ "Mean Transmitted Fourier Probe",
+ "Min Intensity Transmitted Fourier Probe",
+ "Max Intensity Transmitted Fourier Probe",
+ ]
+
+ title = kwargs.pop("title", title)
+ show_complex(
+ probes,
+ title=title,
+ **kwargs,
+ )
+
+ def show_slices(
+ self,
+ ms_object=None,
+ cbar: bool = True,
+ common_color_scale: bool = True,
+ padding: int = 0,
+ num_cols: int = 3,
+ show_fft: bool = False,
+ **kwargs,
+ ):
+ """
+ Displays reconstructed slices of object
+
+ Parameters
+ --------
+ ms_object: nd.array, optional
+ Object to plot slices of. If None, uses current object
+ cbar: bool, optional
+ If True, displays a colorbar
+ common_color_scale: bool, optional
+ If True, all slices share a common color scale
+ padding: int, optional
+ Padding to leave uncropped
+ num_cols: int, optional
+ Number of GridSpec columns
+ show_fft: bool, optional
+ If True, plots the FFT amplitude of the object slices
+ """
+
+ if ms_object is None:
+ ms_object = self._object
+
+ rotated_object = self._crop_rotate_object_fov(ms_object, padding=padding)
+ if show_fft:
+ rotated_object = np.abs(
+ np.fft.fftshift(
+ np.fft.fft2(rotated_object, axes=(-2, -1)), axes=(-2, -1)
+ )
+ )
+ rotated_shape = rotated_object.shape
+
+ if np.iscomplexobj(rotated_object):
+ rotated_object = np.angle(rotated_object)
+
+ extent = [
+ 0,
+ self.sampling[1] * rotated_shape[2],
+ self.sampling[0] * rotated_shape[1],
+ 0,
+ ]
+
+ num_rows = np.ceil(self._num_slices / num_cols).astype("int")
+ wspace = 0.35 if cbar else 0.15
+
+ axsize = kwargs.pop("axsize", (3, 3))
+ cmap = kwargs.pop("cmap", "magma")
+
+ if common_color_scale:
+ vals = np.sort(rotated_object.ravel())
+ ind_vmin = np.round((vals.shape[0] - 1) * 0.02).astype("int")
+ ind_vmax = np.round((vals.shape[0] - 1) * 0.98).astype("int")
+ ind_vmin = np.max([0, ind_vmin])
+ ind_vmax = np.min([len(vals) - 1, ind_vmax])
+ vmin = vals[ind_vmin]
+ vmax = vals[ind_vmax]
+ if vmax == vmin:
+ vmin = vals[0]
+ vmax = vals[-1]
+ else:
+ vmax = None
+ vmin = None
+ vmin = kwargs.pop("vmin", vmin)
+ vmax = kwargs.pop("vmax", vmax)
+
+ spec = GridSpec(
+ ncols=num_cols,
+ nrows=num_rows,
+ hspace=0.15,
+ wspace=wspace,
+ )
+
+ figsize = (axsize[0] * num_cols, axsize[1] * num_rows)
+ fig = plt.figure(figsize=figsize)
+
+ for flat_index, obj_slice in enumerate(rotated_object):
+ row_index, col_index = np.unravel_index(flat_index, (num_rows, num_cols))
+ ax = fig.add_subplot(spec[row_index, col_index])
+ im = ax.imshow(
+ obj_slice,
+ cmap=cmap,
+ vmin=vmin,
+ vmax=vmax,
+ extent=extent,
+ **kwargs,
+ )
+
+ ax.set_title(f"Slice index: {flat_index}")
+
+ if cbar:
+ divider = make_axes_locatable(ax)
+ ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
+ fig.add_axes(ax_cb)
+ fig.colorbar(im, cax=ax_cb)
+
+ if row_index < num_rows - 1:
+ ax.set_xticks([])
+ else:
+ ax.set_xlabel("y [A]")
+
+ if col_index > 0:
+ ax.set_yticks([])
+ else:
+ ax.set_ylabel("x [A]")
+
+ spec.tight_layout(fig)
+
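+ # Example usage (sketch): plot the reconstructed slices on a shared
+ # color scale, and again as FFT amplitudes to inspect resolution:
+ #
+ #   ptycho.show_slices(common_color_scale=True, num_cols=4)
+ #   ptycho.show_slices(show_fft=True)
+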
+ def show_depth(
+ self,
+ x1: float,
+ x2: float,
+ y1: float,
+ y2: float,
+ specify_calibrated: bool = False,
+ gaussian_filter_sigma: float = None,
+ ms_object=None,
+ cbar: bool = False,
+ aspect: float = None,
+ plot_line_profile: bool = False,
+ **kwargs,
+ ):
+ """
+ Displays line profile depth section
+
+ Parameters
+ --------
+ x1, x2, y1, y2: floats (pixels)
+ Line profile for depth section runs from (x1,y1) to (x2,y2)
+ Specified in pixels unless specify_calibrated is True
+ specify_calibrated: bool (optional)
+ If True, specify x1, x2, y1, y2 in A values instead of pixels
+ gaussian_filter_sigma: float (optional)
+ Standard deviation of gaussian kernel in A
+ ms_object: np.array
+ Object to plot slices of. If None, uses current object
+ cbar: bool, optional
+ If True, displays a colorbar
+ aspect: float, optional
+ aspect ratio for depth profile plot
+ plot_line_profile: bool
+ If True, also plots line profile showing where depth profile is taken
+ """
+ if ms_object is not None:
+ ms_obj = ms_object
+ else:
+ ms_obj = self.object_cropped
+
+ if specify_calibrated:
+ x1 /= self.sampling[0]
+ x2 /= self.sampling[0]
+ y1 /= self.sampling[1]
+ y2 /= self.sampling[1]
+
+ if x2 == x1:
+ angle = 0
+ elif y2 == y1:
+ angle = np.pi / 2
+ else:
+ angle = np.arctan((x2 - x1) / (y2 - y1))
+
+ x0 = ms_obj.shape[1] / 2
+ y0 = ms_obj.shape[2] / 2
+
+ if (
+ x1 > ms_obj.shape[1]
+ or x2 > ms_obj.shape[1]
+ or y1 > ms_obj.shape[2]
+ or y2 > ms_obj.shape[2]
+ ):
+ raise ValueError("depth section must be in field of view of object")
+
+ from py4DSTEM.process.phase.utils import rotate_point
+
+ x1_0, y1_0 = rotate_point((x0, y0), (x1, y1), angle)
+ x2_0, y2_0 = rotate_point((x0, y0), (x2, y2), angle)
+
+ rotated_object = np.roll(
+ rotate(ms_obj, np.rad2deg(angle), reshape=False, axes=(-1, -2)),
+ -int(x1_0),
+ axis=1,
+ )
+
+ if np.iscomplexobj(rotated_object):
+ rotated_object = np.angle(rotated_object)
+ if gaussian_filter_sigma is not None:
+ from scipy.ndimage import gaussian_filter
+
+ gaussian_filter_sigma /= self.sampling[0]
+ rotated_object = gaussian_filter(rotated_object, gaussian_filter_sigma)
+
+ plot_im = rotated_object[:, 0, int(y1_0) : int(y2_0)]
+
+ extent = [
+ 0,
+ self.sampling[1] * plot_im.shape[1],
+ self._slice_thicknesses[0] * plot_im.shape[0],
+ 0,
+ ]
+ figsize = kwargs.pop("figsize", (6, 6))
+ if not plot_line_profile:
+ fig, ax = plt.subplots(figsize=figsize)
+ im = ax.imshow(plot_im, cmap="magma", extent=extent)
+ if aspect is not None:
+ ax.set_aspect(aspect)
+ ax.set_xlabel("r [A]")
+ ax.set_ylabel("z [A]")
+ ax.set_title("Multislice depth profile")
+ if cbar:
+ divider = make_axes_locatable(ax)
+ ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
+ fig.add_axes(ax_cb)
+ fig.colorbar(im, cax=ax_cb)
+ else:
+ extent2 = [
+ 0,
+ self.sampling[1] * ms_obj.shape[2],
+ self.sampling[0] * ms_obj.shape[1],
+ 0,
+ ]
+
+ fig, ax = plt.subplots(2, 1, figsize=figsize)
+ ax[0].imshow(ms_obj.sum(0), cmap="gray", extent=extent2)
+ ax[0].plot(
+ [y1 * self.sampling[1], y2 * self.sampling[1]],
+ [x1 * self.sampling[0], x2 * self.sampling[0]],
+ color="red",
+ )
+ ax[0].set_xlabel("y [A]")
+ ax[0].set_ylabel("x [A]")
+ ax[0].set_title("Multislice depth profile location")
+
+ im = ax[1].imshow(plot_im, cmap="magma", extent=extent)
+ if aspect is not None:
+ ax[1].set_aspect(aspect)
+ ax[1].set_xlabel("r [A]")
+ ax[1].set_ylabel("z [A]")
+ ax[1].set_title("Multislice depth profile")
+ if cbar:
+ divider = make_axes_locatable(ax[1])
+ ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
+ fig.add_axes(ax_cb)
+ fig.colorbar(im, cax=ax_cb)
+ plt.tight_layout()
+
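+ # Example usage (sketch): depth section along a line specified in
+ # calibrated units (A), smoothed with a 2 A gaussian kernel:
+ #
+ #   ptycho.show_depth(
+ #       x1=10.0, x2=10.0, y1=0.0, y2=50.0,
+ #       specify_calibrated=True,
+ #       gaussian_filter_sigma=2.0,
+ #       plot_line_profile=True,
+ #   )
+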
+ def tune_num_slices_and_thicknesses(
+ self,
+ num_slices_guess=None,
+ thicknesses_guess=None,
+ num_slices_step_size=1,
+ thicknesses_step_size=20,
+ num_slices_values=3,
+ num_thicknesses_values=3,
+ update_defocus=False,
+ max_iter=5,
+ plot_reconstructions=True,
+ plot_convergence=True,
+ return_values=False,
+ **kwargs,
+ ):
+ """
+ Run reconstructions over a parameters space of number of slices
+ and slice thicknesses. Should be run after the preprocess step.
+
+ Parameters
+ ----------
+ num_slices_guess: float, optional
+ initial starting guess for number of slices, rounds to nearest integer
+ if None, uses current initialized values
+ thicknesses_guess: float (A), optional
+ initial starting guess for thicknesses of slices assuming same
+ thickness for each slice
+ if None, uses current initialized values
+ num_slices_step_size: float, optional
+ size of change of number of slices for each step in parameter space
+ thicknesses_step_size: float (A), optional
+ size of change of slice thicknesses for each step in parameter space
+ num_slices_values: int, optional
+ number of num_slices values to test, must be >= 1
+ num_thicknesses_values: int,optional
+ number of thicknesses values to test, must be >= 1
+ update_defocus: bool, optional
+ if True, updates defocus based on estimated total thickness
+ max_iter: int, optional
+ number of iterations to run in ptychographic reconstruction
+ plot_reconstructions: bool, optional
+ if True, plot phase of reconstructed objects
+ plot_convergence: bool, optional
+ if True, plots error for each iteration for each reconstruction
+ return_values: bool, optional
+ if True, returns objects, convergence
+
+ Returns
+ -------
+ objects: list
+ reconstructed objects
+ convergence: np.ndarray
+ array of convergence values from reconstructions
+ """
+
+ # calculate number of slices and thicknesses values to test
+ if num_slices_guess is None:
+ num_slices_guess = self._num_slices
+ if thicknesses_guess is None:
+ thicknesses_guess = np.mean(self._slice_thicknesses)
+
+ if num_slices_values == 1:
+ num_slices_step_size = 0
+
+ if num_thicknesses_values == 1:
+ thicknesses_step_size = 0
+
+ num_slices = np.linspace(
+ num_slices_guess - num_slices_step_size * (num_slices_values - 1) / 2,
+ num_slices_guess + num_slices_step_size * (num_slices_values - 1) / 2,
+ num_slices_values,
+ )
+
+ thicknesses = np.linspace(
+ thicknesses_guess
+ - thicknesses_step_size * (num_thicknesses_values - 1) / 2,
+ thicknesses_guess
+ + thicknesses_step_size * (num_thicknesses_values - 1) / 2,
+ num_thicknesses_values,
+ )
+
+ if return_values:
+ convergence = []
+ objects = []
+
+ # current initialized values
+ current_verbose = self._verbose
+ current_num_slices = self._num_slices
+ current_thicknesses = self._slice_thicknesses
+ current_rotation_deg = self._rotation_best_rad * 180 / np.pi
+ current_transpose = self._rotation_best_transpose
+ current_defocus = -self._polar_parameters["C10"]
+
+ # Gridspec to plot on
+ if plot_reconstructions:
+ if plot_convergence:
+ spec = GridSpec(
+ ncols=num_thicknesses_values,
+ nrows=num_slices_values * 2,
+ height_ratios=[1, 1 / 4] * num_slices_values,
+ hspace=0.15,
+ wspace=0.35,
+ )
+ figsize = kwargs.get(
+ "figsize", (4 * num_thicknesses_values, 5 * num_slices_values)
+ )
+ else:
+ spec = GridSpec(
+ ncols=num_thicknesses_values,
+ nrows=num_slices_values,
+ hspace=0.15,
+ wspace=0.35,
+ )
+ figsize = kwargs.get(
+ "figsize", (4 * num_thicknesses_values, 4 * num_slices_values)
+ )
+
+ fig = plt.figure(figsize=figsize)
+
+ progress_bar = kwargs.pop("progress_bar", False)
+ # run loop and plot along the way
+ self._verbose = False
+ for flat_index, (slices, thickness) in enumerate(
+ tqdmnd(num_slices, thicknesses, desc="Tuning num slices and thicknesses")
+ ):
+ slices = int(slices)
+ self._num_slices = slices
+ self._slice_thicknesses = np.tile(thickness, slices - 1)
+ self._probe = None
+ self._object = None
+ if update_defocus:
+ defocus = current_defocus + slices / 2 * thickness
+ self._polar_parameters["C10"] = -defocus
+
+ self.preprocess(
+ plot_center_of_mass=False,
+ plot_rotation=False,
+ plot_probe_overlaps=False,
+ force_com_rotation=current_rotation_deg,
+ force_com_transpose=current_transpose,
+ )
+ self.reconstruct(
+ reset=True,
+ store_iterations=True if plot_convergence else False,
+ max_iter=max_iter,
+ progress_bar=progress_bar,
+ **kwargs,
+ )
+
+ if plot_reconstructions:
+ row_index, col_index = np.unravel_index(
+ flat_index, (num_slices_values, num_thicknesses_values)
+ )
+
+ if plot_convergence:
+ object_ax = fig.add_subplot(spec[row_index * 2, col_index])
+ convergence_ax = fig.add_subplot(spec[row_index * 2 + 1, col_index])
+ self._visualize_last_iteration_figax(
+ fig,
+ object_ax=object_ax,
+ convergence_ax=convergence_ax,
+ cbar=True,
+ )
+ convergence_ax.yaxis.tick_right()
+ else:
+ object_ax = fig.add_subplot(spec[row_index, col_index])
+ self._visualize_last_iteration_figax(
+ fig,
+ object_ax=object_ax,
+ convergence_ax=None,
+ cbar=True,
+ )
+
+ object_ax.set_title(
+ f" num slices = {slices:.0f}, slices thickness = {thickness:.1f} A \n error = {self.error:.3e}"
+ )
+ object_ax.set_xticks([])
+ object_ax.set_yticks([])
+
+ if return_values:
+ objects.append(self.object)
+ convergence.append(self.error_iterations.copy())
+
+ # initialize back to pre-tuning values
+ self._probe = None
+ self._object = None
+ self._num_slices = current_num_slices
+ self._slice_thicknesses = current_thicknesses
+ self._polar_parameters["C10"] = -current_defocus
+ self.preprocess(
+ force_com_rotation=current_rotation_deg,
+ force_com_transpose=current_transpose,
+ plot_center_of_mass=False,
+ plot_rotation=False,
+ plot_probe_overlaps=False,
+ )
+ self._verbose = current_verbose
+
+ if plot_reconstructions:
+ spec.tight_layout(fig)
+
+ if return_values:
+ return objects, convergence
+
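+ # Example usage (sketch, illustrative values): scan a 3x3 grid around
+ # an initial guess of 8 slices of 40 A each, keeping the results:
+ #
+ #   objs, errs = ptycho.tune_num_slices_and_thicknesses(
+ #       num_slices_guess=8,
+ #       thicknesses_guess=40,
+ #       num_slices_step_size=2,
+ #       thicknesses_step_size=10,
+ #       return_values=True,
+ #   )
+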
+ def _return_object_fft(
+ self,
+ obj=None,
+ ):
+ """
+ Returns obj fft shifted to center of array
+
+ Parameters
+ ----------
+ obj: array, optional
+ if None is specified, uses self._object
+ """
+ asnumpy = self._asnumpy
+
+ if obj is None:
+ obj = self._object
+
+ obj = asnumpy(obj)
+ if np.iscomplexobj(obj):
+ obj = np.angle(obj)
+
+ obj = self._crop_rotate_object_fov(np.sum(obj, axis=0))
+ return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj))))
+
+ def _return_projected_cropped_potential(
+ self,
+ ):
+ """Utility function to accommodate multiple classes"""
+ if self._object_type == "complex":
+ projected_cropped_potential = np.angle(self.object_cropped).sum(0)
+ else:
+ projected_cropped_potential = self.object_cropped.sum(0)
+
+ return projected_cropped_potential
diff --git a/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py
new file mode 100644
index 000000000..c49a1faac
--- /dev/null
+++ b/py4DSTEM/process/phase/iterative_overlap_magnetic_tomography.py
@@ -0,0 +1,3364 @@
+"""
+Module for reconstructing phase objects from 4DSTEM datasets using iterative methods,
+namely overlap magnetic tomography.
+"""
+
+import warnings
+from typing import Mapping, Sequence, Tuple
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pylops
+from matplotlib.gridspec import GridSpec
+from mpl_toolkits.axes_grid1 import make_axes_locatable
+from py4DSTEM.visualize import show
+from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg
+from scipy.ndimage import rotate as rotate_np
+
+try:
+ import cupy as cp
+except ModuleNotFoundError:
+ cp = np
+
+from emdfile import Custom, tqdmnd
+from py4DSTEM import DataCube
+from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction
+from py4DSTEM.process.phase.utils import (
+ ComplexProbe,
+ fft_shift,
+ generate_batches,
+ polar_aliases,
+ polar_symbols,
+ project_vector_field_divergence,
+ spatial_frequencies,
+)
+from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar
+
+warnings.simplefilter(action="always", category=UserWarning)
+
+
+class OverlapMagneticTomographicReconstruction(PtychographicReconstruction):
+ """
+ Overlap Magnetic Tomographic Reconstruction Class.
+
+ List of diffraction intensities dimensions : (Rx,Ry,Qx,Qy)
+ Reconstructed probe dimensions : (Sx,Sy)
+ Reconstructed object dimensions : (Px,Py,Py)
+
+ such that (Sx,Sy) is the region-of-interest (ROI) size of our probe
+ and (Px,Py,Py) is the padded-object electrostatic potential volume,
+ where the x-axis is the tilt axis.
+
+ Parameters
+ ----------
+ datacube: List of DataCubes
+ Input list of 4D diffraction pattern intensities for different tilts
+ energy: float
+ The electron energy of the wave functions in eV
+ num_slices: int
+ Number of slices to use in the forward model
+ tilt_angles_deg: Sequence[float]
+ List of (\alpha, \beta) tilt angle tuples in degrees,
+ with the following Euler-angle convention:
+ - \alpha tilt around z-axis
+ - \beta tilt around x-axis
+ - -\alpha tilt around z-axis
+ semiangle_cutoff: float, optional
+ Semiangle cutoff for the initial probe guess in mrad
+ semiangle_cutoff_pixels: float, optional
+ Semiangle cutoff for the initial probe guess in pixels
+ rolloff: float, optional
+ Semiangle rolloff for the initial probe guess
+ vacuum_probe_intensity: np.ndarray, optional
+ Vacuum probe to use as intensity aperture for initial probe guess
+ polar_parameters: dict, optional
+ Mapping from aberration symbols to their corresponding values. All aberration
+ magnitudes should be given in Å and angles should be given in radians.
+ object_padding_px: Tuple[int,int], optional
+ Pixel dimensions to pad object with
+ If None, the padding is set to half the probe ROI dimensions
+ initial_object_guess: np.ndarray, optional
+ Initial guess for complex-valued object of dimensions (Px,Py,Py)
+ If None, initialized to 1.0
+ initial_probe_guess: np.ndarray, optional
+ Initial guess for complex-valued probe of dimensions (Sx,Sy). If None,
+ initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations
+ initial_scan_positions: list of np.ndarray, optional
+ Probe positions in Å for each diffraction intensity per tilt
+ If None, initialized to a grid scan centered along tilt axis
+ verbose: bool, optional
+ If True, class methods will inherit this and print additional information
+ device: str, optional
+ Device calculations will be performed on. Must be 'cpu' or 'gpu'
+ object_type: str, optional
+ The object can be reconstructed as a real potential ('potential') or a complex
+ object ('complex')
+ positions_mask: np.ndarray, optional
+ Boolean real space mask to select positions in datacube to skip for reconstruction
+ name: str, optional
+ Class name
+ kwargs:
+ Provide the aberration coefficients as keyword arguments.
+ """
+
+ # Class-specific Metadata
+ _class_specific_metadata = ("_num_slices", "_tilt_angles_deg")
+
+ def __init__(
+ self,
+ energy: float,
+ num_slices: int,
+ tilt_angles_deg: Sequence[Tuple[float, float]],
+ datacube: Sequence[DataCube] = None,
+ semiangle_cutoff: float = None,
+ semiangle_cutoff_pixels: float = None,
+ rolloff: float = 2.0,
+ vacuum_probe_intensity: np.ndarray = None,
+ polar_parameters: Mapping[str, float] = None,
+ object_padding_px: Tuple[int, int] = None,
+ object_type: str = "potential",
+ positions_mask: np.ndarray = None,
+ initial_object_guess: np.ndarray = None,
+ initial_probe_guess: np.ndarray = None,
+ initial_scan_positions: Sequence[np.ndarray] = None,
+ verbose: bool = True,
+ device: str = "cpu",
+ name: str = "overlap-magnetic-tomographic_reconstruction",
+ **kwargs,
+ ):
+ Custom.__init__(self, name=name)
+
+ if device == "cpu":
+ self._xp = np
+ self._asnumpy = np.asarray
+ from scipy.ndimage import gaussian_filter, rotate, zoom
+
+ self._gaussian_filter = gaussian_filter
+ self._zoom = zoom
+ self._rotate = rotate
+ elif device == "gpu":
+ self._xp = cp
+ self._asnumpy = cp.asnumpy
+ from cupyx.scipy.ndimage import gaussian_filter, rotate, zoom
+
+ self._gaussian_filter = gaussian_filter
+ self._zoom = zoom
+ self._rotate = rotate
+ else:
+ raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}")
+
+ for key in kwargs.keys():
+ if (key not in polar_symbols) and (key not in polar_aliases.keys()):
+ raise ValueError("{} not a recognized parameter".format(key))
+
+ self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols)))
+
+ if polar_parameters is None:
+ polar_parameters = {}
+
+ polar_parameters.update(kwargs)
+ self._set_polar_parameters(polar_parameters)
+
+ num_tilts = len(tilt_angles_deg)
+ if initial_scan_positions is None:
+ initial_scan_positions = [None] * num_tilts
+
+ if object_type != "potential":
+ raise NotImplementedError("object_type must be 'potential' for this class")
+
+ if positions_mask is not None and positions_mask.dtype != "bool":
+ warnings.warn(
+ ("`positions_mask` converted to `bool` array"),
+ UserWarning,
+ )
+ positions_mask = np.asarray(positions_mask, dtype="bool")
+
+ self.set_save_defaults()
+
+ # Data
+ self._datacube = datacube
+ self._object = initial_object_guess
+ self._probe = initial_probe_guess
+
+ # Common Metadata
+ self._vacuum_probe_intensity = vacuum_probe_intensity
+ self._scan_positions = initial_scan_positions
+ self._energy = energy
+ self._semiangle_cutoff = semiangle_cutoff
+ self._semiangle_cutoff_pixels = semiangle_cutoff_pixels
+ self._rolloff = rolloff
+ self._object_type = object_type
+ self._object_padding_px = object_padding_px
+ self._positions_mask = positions_mask
+ self._verbose = verbose
+ self._device = device
+ self._preprocessed = False
+
+ # Class-specific Metadata
+ self._num_slices = num_slices
+ self._tilt_angles_deg = tuple(tilt_angles_deg)
+ self._num_tilts = num_tilts
+
+ def _precompute_propagator_arrays(
+ self,
+ gpts: Tuple[int, int],
+ sampling: Tuple[float, float],
+ energy: float,
+ slice_thicknesses: Sequence[float],
+ ):
+ """
+ Precomputes propagator arrays complex wave-function will be convolved by,
+ for all slice thicknesses.
+
+ Parameters
+ ----------
+ gpts: Tuple[int,int]
+ Wavefunction pixel dimensions
+ sampling: Tuple[float,float]
+ Wavefunction sampling in A
+ energy: float
+ The electron energy of the wave functions in eV
+ slice_thicknesses: Sequence[float]
+ Array of slice thicknesses in A
+
+ Returns
+ -------
+ propagator_arrays: np.ndarray
+ (T,Sx,Sy) shape array storing propagator arrays
+ """
+ xp = self._xp
+
+ # Frequencies
+ kx, ky = spatial_frequencies(gpts, sampling)
+ kx = xp.asarray(kx, dtype=xp.float32)
+ ky = xp.asarray(ky, dtype=xp.float32)
+
+ # Propagators
+ wavelength = electron_wavelength_angstrom(energy)
+ num_slices = slice_thicknesses.shape[0]
+ propagators = xp.empty(
+ (num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64
+ )
+ for i, dz in enumerate(slice_thicknesses):
+ propagators[i] = xp.exp(
+ 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz)
+ )
+ propagators[i] *= xp.exp(
+ 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz)
+ )
+
+ return propagators
+
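+ # The propagator above is the separable Fresnel kernel
+ # P(k) = exp(-1j * pi * wavelength * dz * (kx^2 + ky^2)).
+ # A minimal standalone numpy sketch of the same kernel, assuming
+ # nx, ny, dx, dy, wavelength, and dz are defined:
+ #
+ #   import numpy as np
+ #   kx = np.fft.fftfreq(nx, d=dx)
+ #   ky = np.fft.fftfreq(ny, d=dy)
+ #   propagator = np.exp(
+ #       -1j * np.pi * wavelength * dz * (kx[:, None] ** 2 + ky[None] ** 2)
+ #   )
+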
+ def _propagate_array(self, array: np.ndarray, propagator_array: np.ndarray):
+ """
+ Propagates array by Fourier convolving array with propagator_array.
+
+ Parameters
+ ----------
+ array: np.ndarray
+ Wavefunction array to be convolved
+ propagator_array: np.ndarray
+ Propagator array to convolve array with
+
+ Returns
+ -------
+ propagated_array: np.ndarray
+ Fourier-convolved array
+ """
+ xp = self._xp
+
+ return xp.fft.ifft2(xp.fft.fft2(array) * propagator_array)
+
+ def _project_sliced_object(self, array: np.ndarray, output_z):
+ """
+ Projects voxel-sliced object to output_z slices by summing voxels within each slice.
+
+ Parameters
+ ----------
+ array: np.ndarray
+ 3D array to expand/project
+ output_z: int
+ Output_dimension to expand/project array to.
+ If output_z > array.shape[0] array is expanded, else it's projected
+
+ Returns
+ -------
+ expanded_or_projected_array: np.ndarray
+ expanded or projected array
+ """
+ xp = self._xp
+ input_z = array.shape[0]
+
+ voxels_per_slice = np.ceil(input_z / output_z).astype("int")
+ pad_size = voxels_per_slice * output_z - input_z
+
+ padded_array = xp.pad(array, ((0, pad_size), (0, 0), (0, 0)))
+
+ return xp.sum(
+ padded_array.reshape(
+ (
+ -1,
+ voxels_per_slice,
+ )
+ + array.shape[1:]
+ ),
+ axis=1,
+ )
+
+ def _expand_sliced_object(self, array: np.ndarray, output_z):
+ """
+ Expands sliced object to output_z voxels by evenly spreading each slice over its voxels.
+
+ Parameters
+ ----------
+ array: np.ndarray
+ 3D array to expand/project
+ output_z: int
+ Output_dimension to expand/project array to.
+ If output_z > array.shape[0] array is expanded, else it's projected
+
+ Returns
+ -------
+ expanded_or_projected_array: np.ndarray
+ expanded or projected array
+ """
+ xp = self._xp
+ input_z = array.shape[0]
+
+ voxels_per_slice = np.ceil(output_z / input_z).astype("int")
+ remainder_size = voxels_per_slice - (voxels_per_slice * input_z - output_z)
+
+ voxels_in_slice = xp.repeat(voxels_per_slice, input_z)
+ voxels_in_slice[-1] = remainder_size if remainder_size > 0 else voxels_per_slice
+
+ normalized_array = array / xp.asarray(voxels_in_slice)[:, None, None]
+ return xp.repeat(normalized_array, voxels_per_slice, axis=0)[:output_z]
+
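+ # A minimal numpy sketch of the project/expand pair above: projecting
+ # sums groups of voxels into slices, expanding spreads each slice back
+ # over its voxels (dividing by the voxel count), so projecting the
+ # expanded array recovers the original slices:
+ #
+ #   import numpy as np
+ #   voxels = np.arange(6.0).reshape(6, 1, 1)             # 6 voxels
+ #   slices = voxels.reshape(3, 2, 1, 1).sum(axis=1)      # project to 3
+ #   back = np.repeat(slices / 2, 2, axis=0)              # expand to 6
+ #   assert np.allclose(back.reshape(3, 2, 1, 1).sum(axis=1), slices)
+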
+ def _euler_angle_rotate_volume(
+ self,
+ volume_array,
+ alpha_deg,
+ beta_deg,
+ ):
+ """
+ Rotate 3D volume using alpha, beta Euler angles according to the convention:
+
+ - -\\alpha tilt around first axis (z)
+ - \\beta tilt around second axis (x)
+ - \\alpha tilt around first axis (z)
+
+ Note: since we store array as zxy, the x- and y-axis rotations flip sign below.
+
+ """
+
+ rotate = self._rotate
+ volume = volume_array.copy()
+
+ alpha_deg, beta_deg = np.mod(np.array([alpha_deg, beta_deg]) + 180, 360) - 180
+
+ if alpha_deg == -180:
+ # print(f"rotation of {-beta_deg} around x")
+ volume = rotate(
+ volume,
+ beta_deg,
+ axes=(0, 2),
+ reshape=False,
+ order=3,
+ )
+ elif alpha_deg == -90:
+ # print(f"rotation of {beta_deg} around y")
+ volume = rotate(
+ volume,
+ -beta_deg,
+ axes=(0, 1),
+ reshape=False,
+ order=3,
+ )
+ elif alpha_deg == 0:
+ # print(f"rotation of {beta_deg} around x")
+ volume = rotate(
+ volume,
+ -beta_deg,
+ axes=(0, 2),
+ reshape=False,
+ order=3,
+ )
+ elif alpha_deg == 90:
+ # print(f"rotation of {-beta_deg} around y")
+ volume = rotate(
+ volume,
+ beta_deg,
+ axes=(0, 1),
+ reshape=False,
+ order=3,
+ )
+ else:
+ # print((
+ # f"rotation of {-alpha_deg} around z, "
+ # f"rotation of {beta_deg} around x, "
+ # f"rotation of {alpha_deg} around z."
+ # ))
+
+ volume = rotate(
+ volume,
+ -alpha_deg,
+ axes=(1, 2),
+ reshape=False,
+ order=3,
+ )
+
+ volume = rotate(
+ volume,
+ -beta_deg,
+ axes=(0, 2),
+ reshape=False,
+ order=3,
+ )
+
+ volume = rotate(
+ volume,
+ alpha_deg,
+ axes=(1, 2),
+ reshape=False,
+ order=3,
+ )
+
+ return volume
+
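+ # Note on the inverse used in preprocess(): with the convention above,
+ # rotating a volume by (alpha, beta) and then by (alpha, -beta) undoes
+ # the tilt. A minimal self-check (sketch, small volume, inexact at the
+ # edges due to interpolation):
+ #
+ #   v = np.random.rand(8, 8, 8)
+ #   w = self._euler_angle_rotate_volume(v, 30, 20)
+ #   w = self._euler_angle_rotate_volume(w, 30, -20)
+ #   # w approximately recovers v away from the volume edges
+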
+ def preprocess(
+ self,
+ diffraction_intensities_shape: Tuple[int, int] = None,
+ reshaping_method: str = "fourier",
+ probe_roi_shape: Tuple[int, int] = None,
+ dp_mask: np.ndarray = None,
+ fit_function: str = "plane",
+ plot_probe_overlaps: bool = True,
+ rotation_real_space_degrees: float = None,
+ diffraction_patterns_rotate_degrees: float = None,
+ diffraction_patterns_transpose: bool = None,
+ force_com_shifts: Sequence[float] = None,
+ progress_bar: bool = True,
+ force_scan_sampling: float = None,
+ force_angular_sampling: float = None,
+ force_reciprocal_sampling: float = None,
+ object_fov_mask: np.ndarray = None,
+ crop_patterns: bool = False,
+ **kwargs,
+ ):
+ """
+ Ptychographic preprocessing step.
+
+ Additionally, it initializes an (Px,Py,Py) array of zeros
+ and a complex probe using the specified polar parameters.
+
+ Parameters
+ ----------
+ diffraction_intensities_shape: Tuple[int,int], optional
+ Pixel dimensions (Qx',Qy') of the resampled diffraction intensities
+ If None, no resampling of diffraction intensities is performed
+ reshaping_method: str, optional
+ Method to use for reshaping, either 'bin', 'bilinear', or 'fourier' (default)
+ probe_roi_shape: Tuple[int,int], optional
+ Padded diffraction intensities shape.
+ If None, no padding is performed
+ dp_mask: ndarray, optional
+ Mask for datacube intensities (Qx,Qy)
+ fit_function: str, optional
+ 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two'
+ plot_probe_overlaps: bool, optional
+ If True, initial probe overlaps scanned over the object will be displayed
+ rotation_real_space_degrees: float (degrees), optional
+ In plane rotation around z axis between x axis and tilt axis in
+ real space (forced to be in xy plane)
+ diffraction_patterns_rotate_degrees: float, optional
+ Relative rotation angle between real and reciprocal space
+ diffraction_patterns_transpose: bool, optional
+ Whether diffraction intensities need to be transposed.
+ force_com_shifts: list of tuple of ndarrays (CoMx, CoMy)
+ Amplitudes come from diffraction patterns shifted with
+ the CoM in the upper left corner for each probe unless
+ shift is overwritten. One tuple per tilt.
+ force_scan_sampling: float, optional
+ Override DataCube real space scan pixel size calibrations, in Angstrom
+ force_angular_sampling: float, optional
+ Override DataCube reciprocal pixel size calibration, in mrad
+ force_reciprocal_sampling: float, optional
+ Override DataCube reciprocal pixel size calibration, in A^-1
+ object_fov_mask: np.ndarray (boolean)
+ Boolean mask of FOV. Used to calculate additional shrinkage of object
+ If None, probe_overlap intensity is thresholded
+ crop_patterns: bool
+ if True, crop patterns to avoid wrap around of patterns when centering
+
+ Returns
+ --------
+ self: OverlapMagneticTomographicReconstruction
+ Self to accommodate chaining
+ """
+ xp = self._xp
+ asnumpy = self._asnumpy
+
+ # set additional metadata
+ self._diffraction_intensities_shape = diffraction_intensities_shape
+ self._reshaping_method = reshaping_method
+ self._probe_roi_shape = probe_roi_shape
+ self._dp_mask = dp_mask
+
+ if self._datacube is None:
+ raise ValueError(
+ (
+ "The preprocess() method requires a DataCube. "
+ "Please run ptycho.attach_datacube(DataCube) first."
+ )
+ )
+
+ # Prepopulate various arrays
+ num_probes_per_tilt = [0]
+
+ for dc in self._datacube:
+ rx, ry = dc.Rshape
+ num_probes_per_tilt.append(rx * ry)
+
+ self._num_diffraction_patterns = sum(num_probes_per_tilt)
+ self._cum_probes_per_tilt = np.cumsum(np.array(num_probes_per_tilt))
+
+ self._mean_diffraction_intensity = []
+ self._positions_px_all = np.empty((self._num_diffraction_patterns, 2))
+
+ self._rotation_best_rad = np.deg2rad(diffraction_patterns_rotate_degrees)
+ self._rotation_best_transpose = diffraction_patterns_transpose
+
+ if force_com_shifts is None:
+ force_com_shifts = [None] * self._num_tilts
+
+ for tilt_index in tqdmnd(
+ self._num_tilts,
+ desc="Preprocessing data",
+ unit="tilt",
+ disable=not progress_bar,
+ ):
+ if tilt_index == 0:
+ (
+ self._datacube[tilt_index],
+ self._vacuum_probe_intensity,
+ self._dp_mask,
+ force_com_shifts[tilt_index],
+ ) = self._preprocess_datacube_and_vacuum_probe(
+ self._datacube[tilt_index],
+ diffraction_intensities_shape=self._diffraction_intensities_shape,
+ reshaping_method=self._reshaping_method,
+ probe_roi_shape=self._probe_roi_shape,
+ vacuum_probe_intensity=self._vacuum_probe_intensity,
+ dp_mask=self._dp_mask,
+ com_shifts=force_com_shifts[tilt_index],
+ )
+
+ self._amplitudes = xp.empty(
+ (self._num_diffraction_patterns,) + self._datacube[0].Qshape
+ )
+ self._region_of_interest_shape = np.array(
+ self._amplitudes[0].shape[-2:]
+ )
+
+ else:
+ (
+ self._datacube[tilt_index],
+ _,
+ _,
+ force_com_shifts[tilt_index],
+ ) = self._preprocess_datacube_and_vacuum_probe(
+ self._datacube[tilt_index],
+ diffraction_intensities_shape=self._diffraction_intensities_shape,
+ reshaping_method=self._reshaping_method,
+ probe_roi_shape=self._probe_roi_shape,
+ vacuum_probe_intensity=None,
+ dp_mask=None,
+ com_shifts=force_com_shifts[tilt_index],
+ )
+
+ intensities = self._extract_intensities_and_calibrations_from_datacube(
+ self._datacube[tilt_index],
+ require_calibrations=True,
+ force_scan_sampling=force_scan_sampling,
+ force_angular_sampling=force_angular_sampling,
+ force_reciprocal_sampling=force_reciprocal_sampling,
+ )
+
+ (
+ com_measured_x,
+ com_measured_y,
+ com_fitted_x,
+ com_fitted_y,
+ com_normalized_x,
+ com_normalized_y,
+ ) = self._calculate_intensities_center_of_mass(
+ intensities,
+ dp_mask=self._dp_mask,
+ fit_function=fit_function,
+ com_shifts=force_com_shifts[tilt_index],
+ )
+
+ (
+ self._amplitudes[
+ self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[
+ tilt_index + 1
+ ]
+ ],
+ mean_diffraction_intensity_temp,
+ ) = self._normalize_diffraction_intensities(
+ intensities,
+ com_fitted_x,
+ com_fitted_y,
+ crop_patterns,
+ self._positions_mask[tilt_index],
+ )
+
+ self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp)
+
+ del (
+ intensities,
+ com_measured_x,
+ com_measured_y,
+ com_fitted_x,
+ com_fitted_y,
+ com_normalized_x,
+ com_normalized_y,
+ )
+
+ self._positions_px_all[
+ self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[
+ tilt_index + 1
+ ]
+ ] = self._calculate_scan_positions_in_pixels(
+ self._scan_positions[tilt_index], self._positions_mask[tilt_index]
+ )
+
+ # handle semiangle specified in pixels
+ if self._semiangle_cutoff_pixels:
+ self._semiangle_cutoff = (
+ self._semiangle_cutoff_pixels * self._angular_sampling[0]
+ )
+
+ # Object Initialization
+ if self._object is None:
+ pad_x = self._object_padding_px[0][1]
+ pad_y = self._object_padding_px[1][1]
+ p, q = np.round(np.max(self._positions_px_all, axis=0))
+ p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype(
+ "int"
+ )
+ q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype(
+ "int"
+ )
+ self._object = xp.zeros((4, q, p, q), dtype=xp.float32)
+ else:
+ self._object = xp.asarray(self._object, dtype=xp.float32)
+
+ self._object_initial = self._object.copy()
+ self._object_type_initial = self._object_type
+ self._object_shape = self._object.shape[-2:]
+ self._num_voxels = self._object.shape[1]
+
+ # Center Probes
+ self._positions_px_all = xp.asarray(self._positions_px_all, dtype=xp.float32)
+
+ for tilt_index in range(self._num_tilts):
+ self._positions_px = self._positions_px_all[
+ self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[
+ tilt_index + 1
+ ]
+ ]
+ self._positions_px_com = xp.mean(self._positions_px, axis=0)
+ self._positions_px -= (
+ self._positions_px_com - xp.array(self._object_shape) / 2
+ )
+
+ self._positions_px_all[
+ self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[
+ tilt_index + 1
+ ]
+ ] = self._positions_px.copy()
+
+ self._positions_px_initial_all = self._positions_px_all.copy()
+ self._positions_initial_all = self._positions_px_initial_all.copy()
+ self._positions_initial_all[:, 0] *= self.sampling[0]
+ self._positions_initial_all[:, 1] *= self.sampling[1]
+
+ # Probe Initialization
+ if self._probe is None:
+ if self._vacuum_probe_intensity is not None:
+ self._semiangle_cutoff = np.inf
+ self._vacuum_probe_intensity = xp.asarray(
+ self._vacuum_probe_intensity, dtype=xp.float32
+ )
+ probe_x0, probe_y0 = get_CoM(
+ self._vacuum_probe_intensity, device=self._device
+ )
+ self._vacuum_probe_intensity = get_shifted_ar(
+ self._vacuum_probe_intensity,
+ -probe_x0,
+ -probe_y0,
+ bilinear=True,
+ device=self._device,
+ )
+ if crop_patterns:
+ self._vacuum_probe_intensity = self._vacuum_probe_intensity[
+ self._crop_mask
+ ].reshape(self._region_of_interest_shape)
+
+ self._probe = (
+ ComplexProbe(
+ gpts=self._region_of_interest_shape,
+ sampling=self.sampling,
+ energy=self._energy,
+ semiangle_cutoff=self._semiangle_cutoff,
+ rolloff=self._rolloff,
+ vacuum_probe_intensity=self._vacuum_probe_intensity,
+ parameters=self._polar_parameters,
+ device=self._device,
+ )
+ .build()
+ ._array
+ )
+
+ # Normalize probe to match mean diffraction intensity
+ probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2)
+ self._probe *= xp.sqrt(
+ sum(self._mean_diffraction_intensity)
+ / self._num_tilts
+ / probe_intensity
+ )
+
+ else:
+ if isinstance(self._probe, ComplexProbe):
+ if self._probe._gpts != self._region_of_interest_shape:
+ raise ValueError("probe gpts must match the region of interest shape")
+ if hasattr(self._probe, "_array"):
+ self._probe = self._probe._array
+ else:
+ self._probe._xp = xp
+ self._probe = self._probe.build()._array
+
+ # Normalize probe to match mean diffraction intensity
+ probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2)
+ self._probe *= xp.sqrt(
+ sum(self._mean_diffraction_intensity)
+ / self._num_tilts
+ / probe_intensity
+ )
+ else:
+ self._probe = xp.asarray(self._probe, dtype=xp.complex64)
+
+ self._probe_initial = self._probe.copy()
+ self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe))
+
+ self._known_aberrations_array = ComplexProbe(
+ energy=self._energy,
+ gpts=self._region_of_interest_shape,
+ sampling=self.sampling,
+ parameters=self._polar_parameters,
+ device=self._device,
+ )._evaluate_ctf()
+
+ # Precomputed propagator arrays
+ self._slice_thicknesses = np.tile(
+ self._object_shape[1] * self.sampling[1] / self._num_slices,
+ self._num_slices - 1,
+ )
+ self._propagator_arrays = self._precompute_propagator_arrays(
+ self._region_of_interest_shape,
+ self.sampling,
+ self._energy,
+ self._slice_thicknesses,
+ )
+
+ # overlaps
+ if object_fov_mask is None:
+ probe_overlap_3D = xp.zeros_like(self._object[0])
+
+ for tilt_index in np.arange(self._num_tilts):
+ alpha_deg, beta_deg = self._tilt_angles_deg[tilt_index]
+
+ probe_overlap_3D = self._euler_angle_rotate_volume(
+ probe_overlap_3D,
+ alpha_deg,
+ beta_deg,
+ )
+
+ self._positions_px = self._positions_px_all[
+ self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[
+ tilt_index + 1
+ ]
+ ]
+ self._positions_px_fractional = self._positions_px - xp.round(
+ self._positions_px
+ )
+ shifted_probes = fft_shift(
+ self._probe, self._positions_px_fractional, xp
+ )
+ probe_intensities = xp.abs(shifted_probes) ** 2
+ probe_overlap = self._sum_overlapping_patches_bincounts(
+ probe_intensities
+ )
+
+ probe_overlap_3D += probe_overlap[None]
+
+ probe_overlap_3D = self._euler_angle_rotate_volume(
+ probe_overlap_3D,
+ alpha_deg,
+ -beta_deg,
+ )
+
+ probe_overlap_3D = self._gaussian_filter(probe_overlap_3D, 1.0)
+ self._object_fov_mask = asnumpy(
+ probe_overlap_3D > 0.25 * probe_overlap_3D.max()
+ )
+ else:
+ self._object_fov_mask = np.asarray(object_fov_mask)
+ self._positions_px = self._positions_px_all[: self._cum_probes_per_tilt[1]]
+ self._positions_px_fractional = self._positions_px - xp.round(
+ self._positions_px
+ )
+ shifted_probes = fft_shift(self._probe, self._positions_px_fractional, xp)
+ probe_intensities = xp.abs(shifted_probes) ** 2
+ probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities)
+ probe_overlap = self._gaussian_filter(probe_overlap, 1.0)
+
+ self._object_fov_mask_inverse = np.invert(self._object_fov_mask)
+
+ if plot_probe_overlaps:
+ figsize = kwargs.pop("figsize", (13, 4))
+ chroma_boost = kwargs.pop("chroma_boost", 1)
+
+ # initial probe
+ complex_probe_rgb = Complex2RGB(
+ self.probe_centered,
+ power=2,
+ chroma_boost=chroma_boost,
+ )
+
+ # propagated
+ propagated_probe = self._probe.copy()
+
+ for s in range(self._num_slices - 1):
+ propagated_probe = self._propagate_array(
+ propagated_probe, self._propagator_arrays[s]
+ )
+ complex_propagated_rgb = Complex2RGB(
+ asnumpy(self._return_centered_probe(propagated_probe)),
+ power=2,
+ chroma_boost=chroma_boost,
+ )
+
+ extent = [
+ 0,
+ self.sampling[1] * self._object_shape[1],
+ self.sampling[0] * self._object_shape[0],
+ 0,
+ ]
+
+ probe_extent = [
+ 0,
+ self.sampling[1] * self._region_of_interest_shape[1],
+ self.sampling[0] * self._region_of_interest_shape[0],
+ 0,
+ ]
+
+ fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=figsize)
+
+ ax1.imshow(
+ complex_probe_rgb,
+ extent=probe_extent,
+ )
+
+ divider = make_axes_locatable(ax1)
+ cax1 = divider.append_axes("right", size="5%", pad="2.5%")
+ add_colorbar_arg(
+ cax1,
+ chroma_boost=chroma_boost,
+ )
+ ax1.set_ylabel("x [A]")
+ ax1.set_xlabel("y [A]")
+ ax1.set_title("Initial probe intensity")
+
+ ax2.imshow(
+ complex_propagated_rgb,
+ extent=probe_extent,
+ )
+
+ divider = make_axes_locatable(ax2)
+ cax2 = divider.append_axes("right", size="5%", pad="2.5%")
+ add_colorbar_arg(
+ cax2,
+ chroma_boost=chroma_boost,
+ )
+ ax2.set_ylabel("x [A]")
+ ax2.set_xlabel("y [A]")
+ ax2.set_title("Propagated probe intensity")
+
+ ax3.imshow(
+ asnumpy(probe_overlap),
+ extent=extent,
+ cmap="Greys_r",
+ )
+ ax3.scatter(
+ self.positions[0, :, 1],
+ self.positions[0, :, 0],
+ s=2.5,
+ color=(1, 0, 0, 1),
+ )
+ ax3.set_ylabel("x [A]")
+ ax3.set_xlabel("y [A]")
+ ax3.set_xlim((extent[0], extent[1]))
+ ax3.set_ylim((extent[2], extent[3]))
+ ax3.set_title("Object field of view")
+
+ fig.tight_layout()
+
+ self._preprocessed = True
+
+ if self._device == "gpu":
+ xp._default_memory_pool.free_all_blocks()
+ xp.clear_memo()
+
+ return self
+
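+ # Example usage (sketch, illustrative values): preprocessing a tilt
+ # series of datacubes with matching (alpha, beta) tilt pairs:
+ #
+ #   tomo = OverlapMagneticTomographicReconstruction(
+ #       datacube=[dc_0, dc_1, dc_2],       # one DataCube per tilt
+ #       energy=300e3,
+ #       num_slices=16,
+ #       tilt_angles_deg=[(0, -30), (0, 0), (0, 30)],
+ #       semiangle_cutoff=20,
+ #       device="gpu",
+ #   ).preprocess(
+ #       diffraction_patterns_rotate_degrees=-15,
+ #       diffraction_patterns_transpose=False,
+ #   )
+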
+ def _overlap_projection(
+ self, current_object_V, current_object_A_projected, current_probe
+ ):
+ """
+ Ptychographic overlap projection method.
+
+ Parameters
+ --------
+ current_object_V: np.ndarray
+ Current electrostatic object estimate
+ current_object_A_projected: np.ndarray
+ Current projected magnetic object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+
+ Returns
+ --------
+ propagated_probes: np.ndarray
+ Shifted probes at each layer
+ object_patches: np.ndarray
+ Patched object view
+ transmitted_probes: np.ndarray
+ Transmitted probes after N-1 propagations and N transmissions
+ """
+
+ xp = self._xp
+
+ complex_object = xp.exp(1j * (current_object_V + current_object_A_projected))
+ object_patches = complex_object[
+ :, self._vectorized_patch_indices_row, self._vectorized_patch_indices_col
+ ]
+
+ propagated_probes = xp.empty_like(object_patches)
+ propagated_probes[0] = fft_shift(
+ current_probe, self._positions_px_fractional, xp
+ )
+
+ for s in range(self._num_slices):
+ # transmit
+ transmitted_probes = object_patches[s] * propagated_probes[s]
+
+ # propagate
+ if s + 1 < self._num_slices:
+ propagated_probes[s + 1] = self._propagate_array(
+ transmitted_probes, self._propagator_arrays[s]
+ )
+
+ return propagated_probes, object_patches, transmitted_probes
+
+ def _gradient_descent_fourier_projection(self, amplitudes, transmitted_probes):
+ """
+ Ptychographic fourier projection method for GD method.
+
+ Parameters
+ --------
+ amplitudes: np.ndarray
+ Normalized measured amplitudes
+ transmitted_probes: np.ndarray
+ Transmitted probes after N-1 propagations and N transmissions
+
+ Returns
+ --------
+ exit_waves:np.ndarray
+ Updated exit wave difference
+ error: float
+ Reconstruction error
+ """
+
+ xp = self._xp
+ fourier_exit_waves = xp.fft.fft2(transmitted_probes)
+
+ error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_exit_waves)) ** 2)
+
+ modified_exit_wave = xp.fft.ifft2(
+ amplitudes * xp.exp(1j * xp.angle(fourier_exit_waves))
+ )
+
+ exit_waves = modified_exit_wave - transmitted_probes
+
+ return exit_waves, error
+
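+ # The Fourier projection above keeps the estimated phases but replaces
+ # the estimated amplitudes with the measured ones. A minimal numpy
+ # sketch of that constraint, assuming exit_wave and measured_amplitudes
+ # are defined:
+ #
+ #   import numpy as np
+ #   G = np.fft.fft2(exit_wave)
+ #   projected = np.fft.ifft2(measured_amplitudes * np.exp(1j * np.angle(G)))
+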
+ def _projection_sets_fourier_projection(
+ self,
+ amplitudes,
+ transmitted_probes,
+ exit_waves,
+ projection_a,
+ projection_b,
+ projection_c,
+ ):
+ """
+ Ptychographic fourier projection method for DM_AP and RAAR methods.
+ Generalized projection using three parameters: a,b,c
+
+ DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha
+ DM: DM_AP(1.0), AP: DM_AP(0.0)
+
+ RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2
+ DM : RAAR(1.0)
+
+ RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2
+ DM: RRR(1.0)
+
+ SUPERFLIP : a = 0, b = 1, c = 2
+
+ Parameters
+ --------
+ amplitudes: np.ndarray
+ Normalized measured amplitudes
+ transmitted_probes: np.ndarray
+ Transmitted probes after N-1 propagations and N transmissions
+ exit_waves: np.ndarray
+ previously estimated exit waves
+ projection_a: float
+ projection_b: float
+ projection_c: float
+
+ Returns
+ --------
+ exit_waves:np.ndarray
+ Updated exit wave difference
+ error: float
+ Reconstruction error
+ """
+
+ xp = self._xp
+ projection_x = 1 - projection_a - projection_b
+ projection_y = 1 - projection_c
+
+ if exit_waves is None:
+ exit_waves = transmitted_probes.copy()
+
+ fourier_exit_waves = xp.fft.fft2(transmitted_probes)
+ error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_exit_waves)) ** 2)
+
+ factor_to_be_projected = (
+ projection_c * transmitted_probes + projection_y * exit_waves
+ )
+ fourier_projected_factor = xp.fft.fft2(factor_to_be_projected)
+
+ fourier_projected_factor = amplitudes * xp.exp(
+ 1j * xp.angle(fourier_projected_factor)
+ )
+ projected_factor = xp.fft.ifft2(fourier_projected_factor)
+
+ exit_waves = (
+ projection_x * exit_waves
+ + projection_a * transmitted_probes
+ + projection_b * projected_factor
+ )
+
+ return exit_waves, error
+
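+ # A small helper (sketch) mapping the named schemes in the docstring
+ # onto (projection_a, projection_b, projection_c):
+ #
+ #   def projection_abc(method, param=1.0):
+ #       if method == "DM_AP":     # DM: param=1.0, AP: param=0.0
+ #           return -param, 1, 1 + param
+ #       if method == "RAAR":      # DM: RAAR(1.0)
+ #           return 1 - 2 * param, param, 2
+ #       if method == "RRR":       # DM: RRR(1.0)
+ #           return -param, param, 2
+ #       if method == "SUPERFLIP":
+ #           return 0, 1, 2
+ #       raise ValueError(method)
+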
+ def _forward(
+ self,
+ current_object_V,
+ current_object_A_projected,
+ current_probe,
+ amplitudes,
+ exit_waves,
+ use_projection_scheme,
+ projection_a,
+ projection_b,
+ projection_c,
+ ):
+ """
+ Ptychographic forward operator.
+ Calls _overlap_projection() and the appropriate _fourier_projection().
+
+ Parameters
+ --------
+ current_object_V: np.ndarray
+ Current electrostatic object estimate
+ current_object_A_projected: np.ndarray
+ Current projected magnetic object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+ amplitudes: np.ndarray
+ Normalized measured amplitudes
+ exit_waves: np.ndarray
+ previously estimated exit waves
+ use_projection_scheme: bool,
+ If True, use generalized projection update
+ projection_a: float
+ projection_b: float
+ projection_c: float
+
+ Returns
+ --------
+ propagated_probes:np.ndarray
+ Prop[object^n*probe^n]
+ object_patches: np.ndarray
+ Patched object view
+ transmitted_probes: np.ndarray
+ Transmitted probes after N-1 propagations and N transmissions
+ exit_waves:np.ndarray
+ Updated exit_waves
+ error: float
+ Reconstruction error
+ """
+
+ (
+ propagated_probes,
+ object_patches,
+ transmitted_probes,
+ ) = self._overlap_projection(
+ current_object_V,
+ current_object_A_projected,
+ current_probe,
+ )
+
+ if use_projection_scheme:
+ (
+ exit_waves[self._active_tilt_index],
+ error,
+ ) = self._projection_sets_fourier_projection(
+ amplitudes,
+ transmitted_probes,
+ exit_waves[self._active_tilt_index],
+ projection_a,
+ projection_b,
+ projection_c,
+ )
+
+ else:
+ exit_waves, error = self._gradient_descent_fourier_projection(
+ amplitudes, transmitted_probes
+ )
+
+ return propagated_probes, object_patches, transmitted_probes, exit_waves, error
+
+ def _gradient_descent_adjoint(
+ self,
+ current_object_V,
+ current_object_A_projected,
+ current_probe,
+ object_patches,
+ propagated_probes,
+ exit_waves,
+ step_size,
+ normalization_min,
+ fix_probe,
+ ):
+ """
+ Ptychographic adjoint operator for GD method.
+ Computes object and probe update steps.
+
+ Parameters
+ --------
+ current_object_V: np.ndarray
+ Current electrostatic object estimate
+ current_object_A_projected: np.ndarray
+ Current projected magnetic object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+ object_patches: np.ndarray
+ Patched object view
+ propagated_probes: np.ndarray
+ Shifted probes at each layer
+ exit_waves:np.ndarray
+ Updated exit_waves
+ step_size: float, optional
+ Update step size
+ normalization_min: float, optional
+ Probe normalization minimum as a fraction of the maximum overlap intensity
+ fix_probe: bool, optional
+ If True, probe will not be updated
+
+ Returns
+ --------
+ updated_object_V: np.ndarray
+ Updated electrostatic object estimate
+ updated_object_A_projected: np.ndarray
+ Updated projected magnetic object estimate
+ updated_probe: np.ndarray
+ Updated probe estimate
+ """
+ xp = self._xp
+
+ for s in reversed(range(self._num_slices)):
+ probe = propagated_probes[s]
+ obj = object_patches[s]
+
+ # object-update
+ probe_normalization = self._sum_overlapping_patches_bincounts(
+ xp.abs(probe) ** 2
+ )
+ probe_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * probe_normalization) ** 2
+ + (normalization_min * xp.max(probe_normalization)) ** 2
+ )
+
+ object_update = step_size * (
+ self._sum_overlapping_patches_bincounts(
+ xp.real(-1j * xp.conj(obj) * xp.conj(probe) * exit_waves)
+ )
+ * probe_normalization
+ )
+
+ current_object_V[s] += object_update
+ current_object_A_projected[s] += object_update
+
+ # back-transmit
+ exit_waves *= xp.conj(obj)
+
+ if s > 0:
+ # back-propagate
+ exit_waves = self._propagate_array(
+ exit_waves, xp.conj(self._propagator_arrays[s - 1])
+ )
+ elif not fix_probe:
+ # probe-update
+ object_normalization = xp.sum(
+ (xp.abs(obj) ** 2),
+ axis=0,
+ )
+ object_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * object_normalization) ** 2
+ + (normalization_min * xp.max(object_normalization)) ** 2
+ )
+
+ current_probe += (
+ step_size
+ * xp.sum(
+ exit_waves,
+ axis=0,
+ )
+ * object_normalization
+ )
+
+ return current_object_V, current_object_A_projected, current_probe
+
+ def _projection_sets_adjoint(
+ self,
+ current_object_V,
+ current_object_A_projected,
+ current_probe,
+ object_patches,
+ propagated_probes,
+ exit_waves,
+ normalization_min,
+ fix_probe,
+ ):
+ """
+ Ptychographic adjoint operator for DM_AP and RAAR methods.
+ Computes object and probe update steps.
+
+ Parameters
+ --------
+ current_object_V: np.ndarray
+ Current electrostatic object estimate
+ current_object_A_projected: np.ndarray
+ Current projected magnetic object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+ object_patches: np.ndarray
+ Patched object view
+ propagated_probes: np.ndarray
+ Shifted probes at each layer
+ exit_waves:np.ndarray
+ Updated exit_waves
+ normalization_min: float, optional
+ Probe normalization minimum as a fraction of the maximum overlap intensity
+ fix_probe: bool, optional
+ If True, probe will not be updated
+
+ Returns
+ --------
+ updated_object_V: np.ndarray
+ Updated electrostatic object estimate
+ updated_object_A_projected: np.ndarray
+ Updated projected magnetic object estimate
+ updated_probe: np.ndarray
+ Updated probe estimate
+ """
+ xp = self._xp
+
+ # careful not to modify exit_waves in-place for projection set methods
+ exit_waves_copy = exit_waves.copy()
+ for s in reversed(range(self._num_slices)):
+ probe = propagated_probes[s]
+ obj = object_patches[s]
+
+ # object-update
+ probe_normalization = self._sum_overlapping_patches_bincounts(
+ xp.abs(probe) ** 2
+ )
+ probe_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * probe_normalization) ** 2
+ + (normalization_min * xp.max(probe_normalization)) ** 2
+ )
+
+ object_update = (
+ self._sum_overlapping_patches_bincounts(
+ xp.real(-1j * xp.conj(obj) * xp.conj(probe) * exit_waves)
+ )
+ * probe_normalization
+ )
+
+ current_object_V[s] = object_update
+ current_object_A_projected[s] = object_update
+
+ # back-transmit
+ exit_waves_copy *= xp.conj(obj)
+
+ if s > 0:
+ # back-propagate
+ exit_waves_copy = self._propagate_array(
+ exit_waves_copy, xp.conj(self._propagator_arrays[s - 1])
+ )
+
+ elif not fix_probe:
+ # probe-update
+ object_normalization = xp.sum(
+ (xp.abs(obj) ** 2),
+ axis=0,
+ )
+ object_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * object_normalization) ** 2
+ + (normalization_min * xp.max(object_normalization)) ** 2
+ )
+
+ current_probe = (
+ xp.sum(
+ exit_waves_copy,
+ axis=0,
+ )
+ * object_normalization
+ )
+
+ return current_object_V, current_object_A_projected, current_probe
+
+ def _adjoint(
+ self,
+ current_object_V,
+ current_object_A_projected,
+ current_probe,
+ object_patches,
+ propagated_probes,
+ exit_waves,
+ use_projection_scheme: bool,
+ step_size: float,
+ normalization_min: float,
+ fix_probe: bool,
+ ):
+ """
+ Ptychographic adjoint operator.
+ Computes object and probe update steps using the appropriate method.
+
+ Parameters
+ --------
+ current_object_V: np.ndarray
+ Current electrostatic object estimate
+ current_object_A_projected: np.ndarray
+ Current projected magnetic object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+ object_patches: np.ndarray
+ Patched object view
+ propagated_probes: np.ndarray
+ Shifted probes at each layer
+ exit_waves:np.ndarray
+ Updated exit_waves
+ use_projection_scheme: bool,
+ If True, use generalized projection update
+ step_size: float, optional
+ Update step size
+ normalization_min: float, optional
+ Probe normalization minimum as a fraction of the maximum overlap intensity
+ fix_probe: bool, optional
+ If True, probe will not be updated
+
+ Returns
+ --------
+ updated_object_V: np.ndarray
+ Updated electrostatic object estimate
+ updated_object_A_projected: np.ndarray
+ Updated projected magnetic object estimate
+ updated_probe: np.ndarray
+ Updated probe estimate
+ """
+
+ if use_projection_scheme:
+ (
+ current_object_V,
+ current_object_A_projected,
+ current_probe,
+ ) = self._projection_sets_adjoint(
+ current_object_V,
+ current_object_A_projected,
+ current_probe,
+ object_patches,
+ propagated_probes,
+ exit_waves[self._active_tilt_index],
+ normalization_min,
+ fix_probe,
+ )
+ else:
+ (
+ current_object_V,
+ current_object_A_projected,
+ current_probe,
+ ) = self._gradient_descent_adjoint(
+ current_object_V,
+ current_object_A_projected,
+ current_probe,
+ object_patches,
+ propagated_probes,
+ exit_waves,
+ step_size,
+ normalization_min,
+ fix_probe,
+ )
+
+ return current_object_V, current_object_A_projected, current_probe
+
+ def _position_correction(
+ self,
+ current_object,
+ current_probe,
+ transmitted_probes,
+ amplitudes,
+ current_positions,
+ positions_step_size,
+ constrain_position_distance,
+ ):
+ """
+ Position correction using estimated intensity gradient.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe:np.ndarray
+ fractionally-shifted probes
+ transmitted_probes: np.ndarray
+ Transmitted probes at each layer
+ amplitudes: np.ndarray
+ Measured amplitudes
+ current_positions: np.ndarray
+ Current positions estimate
+ positions_step_size: float
+ Positions step size
+ constrain_position_distance: float
+ Distance to constrain position correction within original
+ field of view in A
+
+ Returns
+ --------
+ updated_positions: np.ndarray
+ Updated positions estimate
+ """
+
+ xp = self._xp
+
+ # Intensity gradient
+ exit_waves_fft = xp.fft.fft2(transmitted_probes[-1])
+ exit_waves_fft_conj = xp.conj(exit_waves_fft)
+ estimated_intensity = xp.abs(exit_waves_fft) ** 2
+ measured_intensity = amplitudes**2
+
+ flat_shape = (transmitted_probes[-1].shape[0], -1)
+ difference_intensity = (measured_intensity - estimated_intensity).reshape(
+ flat_shape
+ )
+
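+ # Steepest-descent position update: build a linear model of the estimated
+ # intensity with respect to single-pixel object shifts in x and y, then
+ # solve for the sub-pixel shift that best explains the intensity residual.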
+ # Computing perturbed exit waves one at a time to save on memory
+
+ complex_object = xp.exp(1j * current_object)
+
+ # dx
+ propagated_probes = fft_shift(current_probe, self._positions_px_fractional, xp)
+ obj_rolled_patches = complex_object[
+ :,
+ (self._vectorized_patch_indices_row + 1) % self._object_shape[0],
+ self._vectorized_patch_indices_col,
+ ]
+
+ transmitted_probes_perturbed = xp.empty_like(obj_rolled_patches)
+
+ for s in range(self._num_slices):
+ # transmit
+ transmitted_probes_perturbed[s] = obj_rolled_patches[s] * propagated_probes
+
+ # propagate
+ if s + 1 < self._num_slices:
+ propagated_probes = self._propagate_array(
+ transmitted_probes_perturbed[s], self._propagator_arrays[s]
+ )
+
+ exit_waves_dx_fft = exit_waves_fft - xp.fft.fft2(
+ transmitted_probes_perturbed[-1]
+ )
+
+ # dy
+ propagated_probes = fft_shift(current_probe, self._positions_px_fractional, xp)
+ obj_rolled_patches = complex_object[
+ :,
+ self._vectorized_patch_indices_row,
+ (self._vectorized_patch_indices_col + 1) % self._object_shape[1],
+ ]
+
+ transmitted_probes_perturbed = xp.empty_like(obj_rolled_patches)
+
+ for s in range(self._num_slices):
+ # transmit
+ transmitted_probes_perturbed[s] = obj_rolled_patches[s] * propagated_probes
+
+ # propagate
+ if s + 1 < self._num_slices:
+ propagated_probes = self._propagate_array(
+ transmitted_probes_perturbed[s], self._propagator_arrays[s]
+ )
+
+ exit_waves_dy_fft = exit_waves_fft - xp.fft.fft2(
+ transmitted_probes_perturbed[-1]
+ )
+
+ partial_intensity_dx = 2 * xp.real(
+ exit_waves_dx_fft * exit_waves_fft_conj
+ ).reshape(flat_shape)
+ partial_intensity_dy = 2 * xp.real(
+ exit_waves_dy_fft * exit_waves_fft_conj
+ ).reshape(flat_shape)
+
+ coefficients_matrix = xp.dstack((partial_intensity_dx, partial_intensity_dy))
+
+ # positions_update = xp.einsum(
+ # "idk,ik->id", xp.linalg.pinv(coefficients_matrix), difference_intensity
+ # )
+
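+ # solve the same per-position least-squares problem as the pinv einsum
+ # above via the normal equations; only 2x2 systems are inverted per position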
+ coefficients_matrix_T = coefficients_matrix.conj().swapaxes(-1, -2)
+ positions_update = (
+ xp.linalg.inv(coefficients_matrix_T @ coefficients_matrix)
+ @ coefficients_matrix_T
+ @ difference_intensity[..., None]
+ )
+
+ if constrain_position_distance is not None:
+ constrain_position_distance /= xp.sqrt(
+ self.sampling[0] ** 2 + self.sampling[1] ** 2
+ )
+ x1 = (current_positions - positions_step_size * positions_update[..., 0])[
+ :, 0
+ ]
+ y1 = (current_positions - positions_step_size * positions_update[..., 0])[
+ :, 1
+ ]
+ x0 = self._positions_px_initial[:, 0]
+ y0 = self._positions_px_initial[:, 1]
+ if self._rotation_best_transpose:
+ x0, y0 = xp.array([y0, x0])
+ x1, y1 = xp.array([y1, x1])
+
+ if self._rotation_best_rad is not None:
+ rotation_angle = self._rotation_best_rad
+ x0, y0 = x0 * xp.cos(-rotation_angle) + y0 * xp.sin(
+ -rotation_angle
+ ), -x0 * xp.sin(-rotation_angle) + y0 * xp.cos(-rotation_angle)
+ x1, y1 = x1 * xp.cos(-rotation_angle) + y1 * xp.sin(
+ -rotation_angle
+ ), -x1 * xp.sin(-rotation_angle) + y1 * xp.cos(-rotation_angle)
+
+ outlier_ind = (x1 > (xp.max(x0) + constrain_position_distance)) + (
+ x1 < (xp.min(x0) - constrain_position_distance)
+ ) + (y1 > (xp.max(y0) + constrain_position_distance)) + (
+ y1 < (xp.min(y0) - constrain_position_distance)
+ ) > 0
+
+ positions_update[..., 0][outlier_ind] = 0
+
+ current_positions -= positions_step_size * positions_update[..., 0]
+
+ return current_positions
+
+ def _object_gaussian_constraint(self, current_object, gaussian_filter_sigma):
+ """
+ Ptychographic smoothness constraint.
+ Used for blurring object.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ gaussian_filter_sigma: float
+ Standard deviation of gaussian kernel in A
+
+ Returns
+ --------
+ constrained_object: np.ndarray
+ Constrained object estimate
+ """
+ gaussian_filter = self._gaussian_filter
+
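+ # convert sigma from Å to pixels (assumes isotropic in-plane sampling)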
+ gaussian_filter_sigma /= self.sampling[0]
+ current_object = gaussian_filter(current_object, gaussian_filter_sigma)
+
+ return current_object
+
+ def _object_butterworth_constraint(
+ self, current_object, q_lowpass, q_highpass, butterworth_order
+ ):
+ """
+ Butterworth filter
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ q_lowpass: float
+ Cut-off frequency in A^-1 for low-pass butterworth filter
+ q_highpass: float
+ Cut-off frequency in A^-1 for high-pass butterworth filter
+ butterworth_order: float
+ Butterworth filter order. Smaller gives a smoother filter
+
+ Returns
+ --------
+ constrained_object: np.ndarray
+ Constrained object estimate
+ """
+ xp = self._xp
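+ # the z-axis is sampled with the in-plane y-sampling, consistent with the
+ # cubic padded-object convention used by this class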
+ qz = xp.fft.fftfreq(current_object.shape[0], self.sampling[1])
+ qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0])
+ qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1])
+ qza, qxa, qya = xp.meshgrid(qz, qx, qy, indexing="ij")
+ qra = xp.sqrt(qza**2 + qxa**2 + qya**2)
+
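+ # radially-symmetric Butterworth envelope: high-pass 1 - 1/(1+(q/q_hi)^(2n)),
+ # low-pass 1/(1+(q/q_lo)^(2n)), applied to the mean-subtracted object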
+ env = xp.ones_like(qra)
+ if q_highpass:
+ env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * butterworth_order))
+ if q_lowpass:
+ env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order))
+
+ current_object_mean = xp.mean(current_object)
+ current_object -= current_object_mean
+ current_object = xp.fft.ifftn(xp.fft.fftn(current_object) * env)
+ current_object += current_object_mean
+
+ return xp.real(current_object)
+
+ def _divergence_free_constraint(self, vector_field):
+ """
+ Leray projection operator
+
+ Parameters
+ --------
+ vector_field: np.ndarray
+ Current object vector as Az, Ax, Ay
+
+ Returns
+ --------
+ projected_vector_field: np.ndarray
+ Divergence-less object vector as Az, Ax, Ay
+ """
+ xp = self._xp
+
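+ # Helmholtz/Leray projection: removes the curl-free component so the
+ # returned vector field is (numerically) divergence-free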
+ spacings = (self.sampling[1],) + self.sampling
+ vector_field = project_vector_field_divergence(
+ vector_field, spacings=spacings, xp=xp
+ )
+
+ return vector_field
+
+ def _object_denoise_tv_pylops(self, current_object, weights, iterations):
+ """
+ Performs second order TV denoising along x and y
+
+ Parameters
+ ----------
+ current_object: np.ndarray
+ Current object estimate
+ weights : [float, float]
+ Denoising weights [z weight, r weight]. The greater the weight,
+ the more denoising.
+ iterations: int
+ Number of iterations to run in denoising algorithm.
+ `niter_out` in pylops
+
+ Returns
+ -------
+ constrained_object: np.ndarray
+ Constrained object estimate
+
+ """
+ xp = self._xp
+
+ if xp.iscomplexobj(current_object):
+ current_object_tv = current_object
+ warnings.warn(
+ ("TV denoising is currently only supported for potential objects."),
+ UserWarning,
+ )
+
+ else:
+ # zero pad at top and bottom slice
+ pad_width = ((1, 1), (0, 0), (0, 0))
+ current_object = xp.pad(
+ current_object, pad_width=pad_width, mode="constant"
+ )
+
+ # run tv denoising
+ nz, nx, ny = current_object.shape
+ niter_out = iterations
+ niter_in = 1
+ Iop = pylops.Identity(nx * ny * nz)
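+ # split-Bregman with L1 regularizers: a first derivative along z and a
+ # Laplacian in-plane; the branches below drop whichever term has zero weight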
+
+ if weights[0] == 0:
+ xy_laplacian = pylops.Laplacian(
+ (nz, nx, ny), axes=(1, 2), edge=False, kind="backward"
+ )
+ l1_regs = [xy_laplacian]
+
+ current_object_tv = pylops.optimization.sparsity.splitbregman(
+ Op=Iop,
+ y=current_object.ravel(),
+ RegsL1=l1_regs,
+ niter_outer=niter_out,
+ niter_inner=niter_in,
+ epsRL1s=[weights[1]],
+ tol=1e-4,
+ tau=1.0,
+ show=False,
+ )[0]
+
+ elif weights[1] == 0:
+ z_gradient = pylops.FirstDerivative(
+ (nz, nx, ny), axis=0, edge=False, kind="backward"
+ )
+ l1_regs = [z_gradient]
+
+ current_object_tv = pylops.optimization.sparsity.splitbregman(
+ Op=Iop,
+ y=current_object.ravel(),
+ RegsL1=l1_regs,
+ niter_outer=niter_out,
+ niter_inner=niter_in,
+ epsRL1s=[weights[0]],
+ tol=1e-4,
+ tau=1.0,
+ show=False,
+ )[0]
+
+ else:
+ z_gradient = pylops.FirstDerivative(
+ (nz, nx, ny), axis=0, edge=False, kind="backward"
+ )
+ xy_laplacian = pylops.Laplacian(
+ (nz, nx, ny), axes=(1, 2), edge=False, kind="backward"
+ )
+ l1_regs = [z_gradient, xy_laplacian]
+
+ current_object_tv = pylops.optimization.sparsity.splitbregman(
+ Op=Iop,
+ y=current_object.ravel(),
+ RegsL1=l1_regs,
+ niter_outer=niter_out,
+ niter_inner=niter_in,
+ epsRL1s=weights,
+ tol=1e-4,
+ tau=1.0,
+ show=False,
+ )[0]
+
+ # remove padding
+ current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1]
+
+ return current_object_tv
+
+ def _constraints(
+ self,
+ current_object,
+ current_probe,
+ current_positions,
+ fix_com,
+ fit_probe_aberrations,
+ fit_probe_aberrations_max_angular_order,
+ fit_probe_aberrations_max_radial_order,
+ constrain_probe_amplitude,
+ constrain_probe_amplitude_relative_radius,
+ constrain_probe_amplitude_relative_width,
+ constrain_probe_fourier_amplitude,
+ constrain_probe_fourier_amplitude_max_width_pixels,
+ constrain_probe_fourier_amplitude_constant_intensity,
+ fix_probe_aperture,
+ initial_probe_aperture,
+ fix_positions,
+ global_affine_transformation,
+ gaussian_filter,
+ gaussian_filter_sigma_e,
+ gaussian_filter_sigma_m,
+ butterworth_filter,
+ q_lowpass_e,
+ q_lowpass_m,
+ q_highpass_e,
+ q_highpass_m,
+ butterworth_order,
+ object_positivity,
+ shrinkage_rad,
+ object_mask,
+ tv_denoise,
+ tv_denoise_weights,
+ tv_denoise_inner_iter,
+ ):
+ """
+ Ptychographic constraints operator.
+ Applies the object, probe, and position constraints in sequence.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+ current_positions: np.ndarray
+ Current positions estimate
+ fix_com: bool
+ If True, probe CoM is fixed to the center
+ fit_probe_aberrations: bool
+ If True, fits the probe aberrations to a low-order expansion
+ fit_probe_aberrations_max_angular_order: int
+ Max angular order of probe aberrations basis functions
+ fit_probe_aberrations_max_radial_order: int
+ Max radial order of probe aberrations basis functions
+ constrain_probe_amplitude: bool
+ If True, probe amplitude is constrained by top hat function
+ constrain_probe_amplitude_relative_radius: float
+ Relative location of top-hat inflection point, between 0 and 0.5
+ constrain_probe_amplitude_relative_width: float
+ Relative width of top-hat sigmoid, between 0 and 0.5
+ constrain_probe_fourier_amplitude: bool
+ If True, probe aperture is constrained by fitting a sigmoid for each angular frequency.
+ constrain_probe_fourier_amplitude_max_width_pixels: float
+ Maximum pixel width of fitted sigmoid functions.
+ constrain_probe_fourier_amplitude_constant_intensity: bool
+ If True, the probe aperture is additionally constrained to a constant intensity.
+ fix_probe_aperture: bool
+ If True, probe Fourier amplitude is replaced by initial probe aperture.
+ initial_probe_aperture: np.ndarray
+ Initial probe aperture to use in replacing probe Fourier amplitude.
+ fix_positions: bool
+ If True, positions are not updated
+ global_affine_transformation: bool
+ If True, positions are assumed to be a global affine transform from initial scan
+ gaussian_filter: bool
+ If True, applies real-space gaussian filter
+ gaussian_filter_sigma_e: float
+ Standard deviation of gaussian kernel for electrostatic object in A
+ gaussian_filter_sigma_m: float
+ Standard deviation of gaussian kernel for magnetic object in A
+ butterworth_filter: bool
+ If True, applies high-pass butterworth filter
+ q_lowpass_e: float
+ Cut-off frequency in A^-1 for low-pass filtering electrostatic object
+ q_lowpass_m: float
+ Cut-off frequency in A^-1 for low-pass filtering magnetic object
+ q_highpass_e: float
+ Cut-off frequency in A^-1 for high-pass filtering electrostatic object
+ q_highpass_m: float
+ Cut-off frequency in A^-1 for high-pass filtering magnetic object
+ butterworth_order: float
+ Butterworth filter order. Smaller gives a smoother filter
+ object_positivity: bool
+ If True, forces object to be positive
+ shrinkage_rad: float
+ Phase shift in radians to be subtracted from the potential at each iteration
+ object_mask: np.ndarray (boolean)
+ If not None, used to calculate additional shrinkage using masked-mean of object
+ tv_denoise: bool
+ If True, applies TV denoising on object
+ tv_denoise_weights: [float,float]
+ Denoising weights [z weight, r weight]. The greater the weight,
+ the more denoising.
+ tv_denoise_inner_iter: int
+ Number of iterations to run in inner loop of TV denoising
+
+ Returns
+ --------
+ constrained_object: np.ndarray
+ Constrained object estimate
+ constrained_probe: np.ndarray
+ Constrained probe estimate
+ constrained_positions: np.ndarray
+ Constrained positions estimate
+ """
+
+ if gaussian_filter:
+ current_object[0] = self._object_gaussian_constraint(
+ current_object[0], gaussian_filter_sigma_e
+ )
+ current_object[1] = self._object_gaussian_constraint(
+ current_object[1], gaussian_filter_sigma_m
+ )
+ current_object[2] = self._object_gaussian_constraint(
+ current_object[2], gaussian_filter_sigma_m
+ )
+ current_object[3] = self._object_gaussian_constraint(
+ current_object[3], gaussian_filter_sigma_m
+ )
+
+ if butterworth_filter:
+ current_object[0] = self._object_butterworth_constraint(
+ current_object[0],
+ q_lowpass_e,
+ q_highpass_e,
+ butterworth_order,
+ )
+ current_object[1] = self._object_butterworth_constraint(
+ current_object[1],
+ q_lowpass_m,
+ q_highpass_m,
+ butterworth_order,
+ )
+ current_object[2] = self._object_butterworth_constraint(
+ current_object[2],
+ q_lowpass_m,
+ q_highpass_m,
+ butterworth_order,
+ )
+ current_object[3] = self._object_butterworth_constraint(
+ current_object[3],
+ q_lowpass_m,
+ q_highpass_m,
+ butterworth_order,
+ )
+
+ elif tv_denoise:
+ current_object[0] = self._object_denoise_tv_pylops(
+ current_object[0],
+ tv_denoise_weights,
+ tv_denoise_inner_iter,
+ )
+
+ current_object[1] = self._object_denoise_tv_pylops(
+ current_object[1],
+ tv_denoise_weights,
+ tv_denoise_inner_iter,
+ )
+
+ current_object[2] = self._object_denoise_tv_pylops(
+ current_object[2],
+ tv_denoise_weights,
+ tv_denoise_inner_iter,
+ )
+
+ current_object[3] = self._object_denoise_tv_pylops(
+ current_object[3],
+ tv_denoise_weights,
+ tv_denoise_inner_iter,
+ )
+
+ if shrinkage_rad > 0.0 or object_mask is not None:
+ current_object[0] = self._object_shrinkage_constraint(
+ current_object[0],
+ shrinkage_rad,
+ object_mask,
+ )
+
+ if object_positivity:
+ current_object[0] = self._object_positivity_constraint(current_object[0])
+
+ if fix_com:
+ current_probe = self._probe_center_of_mass_constraint(current_probe)
+
+ if fix_probe_aperture:
+ current_probe = self._probe_aperture_constraint(
+ current_probe,
+ initial_probe_aperture,
+ )
+ elif constrain_probe_fourier_amplitude:
+ current_probe = self._probe_fourier_amplitude_constraint(
+ current_probe,
+ constrain_probe_fourier_amplitude_max_width_pixels,
+ constrain_probe_fourier_amplitude_constant_intensity,
+ )
+
+ if fit_probe_aberrations:
+ current_probe = self._probe_aberration_fitting_constraint(
+ current_probe,
+ fit_probe_aberrations_max_angular_order,
+ fit_probe_aberrations_max_radial_order,
+ )
+
+ if constrain_probe_amplitude:
+ current_probe = self._probe_amplitude_constraint(
+ current_probe,
+ constrain_probe_amplitude_relative_radius,
+ constrain_probe_amplitude_relative_width,
+ )
+
+ if not fix_positions:
+ current_positions = self._positions_center_of_mass_constraint(
+ current_positions
+ )
+
+ if global_affine_transformation:
+ current_positions = self._positions_affine_transformation_constraint(
+ self._positions_px_initial, current_positions
+ )
+
+ return current_object, current_probe, current_positions
+
+ def reconstruct(
+ self,
+ max_iter: int = 64,
+ reconstruction_method: str = "gradient-descent",
+ reconstruction_parameter: float = 1.0,
+ reconstruction_parameter_a: float = None,
+ reconstruction_parameter_b: float = None,
+ reconstruction_parameter_c: float = None,
+ max_batch_size: int = None,
+ seed_random: int = None,
+ step_size: float = 0.5,
+ normalization_min: float = 1,
+ positions_step_size: float = 0.9,
+ fix_com: bool = True,
+ fix_probe_iter: int = 0,
+ fix_probe_aperture_iter: int = 0,
+ constrain_probe_amplitude_iter: int = 0,
+ constrain_probe_amplitude_relative_radius: float = 0.5,
+ constrain_probe_amplitude_relative_width: float = 0.05,
+ constrain_probe_fourier_amplitude_iter: int = 0,
+ constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0,
+ constrain_probe_fourier_amplitude_constant_intensity: bool = False,
+ fix_positions_iter: int = np.inf,
+ constrain_position_distance: float = None,
+ global_affine_transformation: bool = True,
+ gaussian_filter_sigma_e: float = None,
+ gaussian_filter_sigma_m: float = None,
+ gaussian_filter_iter: int = np.inf,
+ fit_probe_aberrations_iter: int = 0,
+ fit_probe_aberrations_max_angular_order: int = 4,
+ fit_probe_aberrations_max_radial_order: int = 4,
+ butterworth_filter_iter: int = np.inf,
+ q_lowpass_e: float = None,
+ q_lowpass_m: float = None,
+ q_highpass_e: float = None,
+ q_highpass_m: float = None,
+ butterworth_order: float = 2,
+ object_positivity: bool = True,
+ shrinkage_rad: float = 0.0,
+ fix_potential_baseline: bool = True,
+ tv_denoise_iter=np.inf,
+ tv_denoise_weights=None,
+ tv_denoise_inner_iter=40,
+ collective_tilt_updates: bool = False,
+ store_iterations: bool = False,
+ progress_bar: bool = True,
+ reset: bool = None,
+ ):
+ """
+ Ptychographic reconstruction main method.
+
+ Parameters
+ --------
+ max_iter: int, optional
+ Maximum number of iterations to run
+ reconstruction_method: str, optional
+ Specifies which reconstruction algorithm to use, one of:
+ "generalized-projections",
+ "DM_AP" (or "difference-map_alternating-projections"),
+ "RAAR" (or "relaxed-averaged-alternating-reflections"),
+ "RRR" (or "relax-reflect-reflect"),
+ "SUPERFLIP" (or "charge-flipping"), or
+ "GD" (or "gradient_descent")
+ reconstruction_parameter: float, optional
+ Reconstruction parameter for various reconstruction methods above.
+ reconstruction_parameter_a: float, optional
+ Reconstruction parameter a for reconstruction_method='generalized-projections'.
+ reconstruction_parameter_b: float, optional
+ Reconstruction parameter b for reconstruction_method='generalized-projections'.
+ reconstruction_parameter_c: float, optional
+ Reconstruction parameter c for reconstruction_method='generalized-projections'.
+ max_batch_size: int, optional
+ Max number of probes to update at once
+ seed_random: int, optional
+ Seeds the random number generator, only applicable when max_batch_size is not None
+ step_size: float, optional
+ Update step size
+ normalization_min: float, optional
+ Probe normalization minimum as a fraction of the maximum overlap intensity
+ positions_step_size: float, optional
+ Positions update step size
+ fix_com: bool, optional
+ If True, fixes center of mass of probe
+ fix_probe_iter: int, optional
+ Number of iterations to run with a fixed probe before updating probe estimate
+ fix_probe_aperture_iter: int, optional
+ Number of iterations to run with a fixed probe Fourier amplitude before updating probe estimate
+ constrain_probe_amplitude_iter: int, optional
+ Number of iterations to run while constraining the real-space probe with a top-hat support.
+ constrain_probe_amplitude_relative_radius: float
+ Relative location of top-hat inflection point, between 0 and 0.5
+ constrain_probe_amplitude_relative_width: float
+ Relative width of top-hat sigmoid, between 0 and 0.5
+ constrain_probe_fourier_amplitude_iter: int, optional
+ Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency.
+ constrain_probe_fourier_amplitude_max_width_pixels: float
+ Maximum pixel width of fitted sigmoid functions.
+ constrain_probe_fourier_amplitude_constant_intensity: bool
+ If True, the probe aperture is additionally constrained to a constant intensity.
+ fix_positions_iter: int, optional
+ Number of iterations to run with fixed positions before updating positions estimate
+ constrain_position_distance: float, optional
+ Distance to constrain position correction within original
+ field of view in A
+ global_affine_transformation: bool, optional
+ If True, positions are assumed to be a global affine transform from initial scan
+ gaussian_filter_sigma_e: float
+ Standard deviation of gaussian kernel for electrostatic object in A
+ gaussian_filter_sigma_m: float
+ Standard deviation of gaussian kernel for magnetic object in A
+ gaussian_filter_iter: int, optional
+ Number of iterations to run using object smoothness constraint
+ fit_probe_aberrations_iter: int, optional
+ Number of iterations to run while fitting the probe aberrations to a low-order expansion
+ fit_probe_aberrations_max_angular_order: int
+ Max angular order of probe aberrations basis functions
+ fit_probe_aberrations_max_radial_order: int
+ Max radial order of probe aberrations basis functions
+ butterworth_filter_iter: int, optional
+ Number of iterations to run using high-pass butterworth filter
+ q_lowpass_e: float
+ Cut-off frequency in A^-1 for low-pass filtering electrostatic object
+ q_lowpass_m: float
+ Cut-off frequency in A^-1 for low-pass filtering magnetic object
+ q_highpass_e: float
+ Cut-off frequency in A^-1 for high-pass filtering electrostatic object
+ q_highpass_m: float
+ Cut-off frequency in A^-1 for high-pass filtering magnetic object
+ butterworth_order: float
+ Butterworth filter order. Smaller gives a smoother filter
+ object_positivity: bool, optional
+ If True, forces object to be positive
+ tv_denoise_iter: int, optional
+ Number of iterations to run using tv denoising on object
+ tv_denoise_weights: [float,float]
+ Denoising weights [z weight, r weight]. The greater the weight,
+ the more denoising.
+ tv_denoise_inner_iter: int
+ Number of iterations to run in inner loop of TV denoising
+ collective_tilt_updates: bool
+ If True, performs collective tilt updates
+ shrinkage_rad: float
+ Phase shift in radians to be subtracted from the potential at each iteration
+ store_iterations: bool, optional
+ If True, reconstructed objects and probes are stored at each iteration
+ progress_bar: bool, optional
+ If True, reconstruction progress is displayed
+ reset: bool, optional
+ If True, previous reconstructions are ignored
+
+ Returns
+ --------
+ self: OverlapMagneticTomographicReconstruction
+ Self to accommodate chaining
+ """
+ asnumpy = self._asnumpy
+ xp = self._xp
+
+ # Reconstruction method
+
+ if reconstruction_method == "generalized-projections":
+ if (
+ reconstruction_parameter_a is None
+ or reconstruction_parameter_b is None
+ or reconstruction_parameter_c is None
+ ):
+ raise ValueError(
+ (
+ "reconstruction_parameter_a/b/c must all be specified "
+ "when using reconstruction_method='generalized-projections'."
+ )
+ )
+
+ use_projection_scheme = True
+ projection_a = reconstruction_parameter_a
+ projection_b = reconstruction_parameter_b
+ projection_c = reconstruction_parameter_c
+ step_size = None
+ elif (
+ reconstruction_method == "DM_AP"
+ or reconstruction_method == "difference-map_alternating-projections"
+ ):
+ if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0:
+ raise ValueError("reconstruction_parameter must be between 0-1.")
+
+ use_projection_scheme = True
+ projection_a = -reconstruction_parameter
+ projection_b = 1
+ projection_c = 1 + reconstruction_parameter
+ step_size = None
+ elif (
+ reconstruction_method == "RAAR"
+ or reconstruction_method == "relaxed-averaged-alternating-reflections"
+ ):
+ if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0:
+ raise ValueError("reconstruction_parameter must be between 0-1.")
+
+ use_projection_scheme = True
+ projection_a = 1 - 2 * reconstruction_parameter
+ projection_b = reconstruction_parameter
+ projection_c = 2
+ step_size = None
+ elif (
+ reconstruction_method == "RRR"
+ or reconstruction_method == "relax-reflect-reflect"
+ ):
+ if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0:
+ raise ValueError("reconstruction_parameter must be between 0-2.")
+
+ use_projection_scheme = True
+ projection_a = -reconstruction_parameter
+ projection_b = reconstruction_parameter
+ projection_c = 2
+ step_size = None
+ elif (
+ reconstruction_method == "SUPERFLIP"
+ or reconstruction_method == "charge-flipping"
+ ):
+ use_projection_scheme = True
+ projection_a = 0
+ projection_b = 1
+ projection_c = 2
+ reconstruction_parameter = None
+ step_size = None
+ elif (
+ reconstruction_method == "GD" or reconstruction_method == "gradient-descent"
+ ):
+ use_projection_scheme = False
+ projection_a = None
+ projection_b = None
+ projection_c = None
+ reconstruction_parameter = None
+ else:
+ raise ValueError(
+ (
+ "reconstruction_method must be one of 'generalized-projections', "
+ "'DM_AP' (or 'difference-map_alternating-projections'), "
+ "'RAAR' (or 'relaxed-averaged-alternating-reflections'), "
+ "'RRR' (or 'relax-reflect-reflect'), "
+ "'SUPERFLIP' (or 'charge-flipping'), "
+ f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}."
+ )
+ )
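+ # note: the projection-based branches above are all special cases of the
+ # generalized-projections update parameterized by (a, b, c); gradient
+ # descent bypasses this parameterization and uses step_size instead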
+
+ # batching is incompatible with projection-set methods; validate this
+ # regardless of verbosity so the error cannot be silenced
+ if max_batch_size is not None and use_projection_scheme:
+ raise ValueError(
+ (
+ "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. "
+ "Use reconstruction_method='GD' or set max_batch_size=None."
+ )
+ )
+
+ if self._verbose:
+ if max_batch_size is not None:
+ print(
+ (
+ f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, "
+ f"with normalization_min: {normalization_min} and step _size: {step_size}, "
+ f"in batches of max {max_batch_size} measurements."
+ )
+ )
+ else:
+ if reconstruction_parameter is not None:
+ if np.array(reconstruction_parameter).shape == (3,):
+ print(
+ (
+ f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, "
+ f"with normalization_min: {normalization_min} and (a,b,c): {reconstruction_parameter}."
+ )
+ )
+ else:
+ print(
+ (
+ f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, "
+ f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}."
+ )
+ )
+ else:
+ if step_size is not None:
+ print(
+ (
+ f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, "
+ f"with normalization_min: {normalization_min} and step_size: {step_size}."
+ )
+ )
+ else:
+ print(
+ (
+ f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, "
+ f"with normalization_min: {normalization_min}."
+ )
+ )
+
+ # Position Correction + Collective Updates not yet implemented
+ if fix_positions_iter < max_iter:
+ raise NotImplementedError(
+ "Position correction is currently incompatible with collective updates."
+ )
+
+ # Batching
+
+ if max_batch_size is not None:
+ xp.random.seed(seed_random)
+
+ # initialization
+ if store_iterations and (not hasattr(self, "object_iterations") or reset):
+ self.object_iterations = []
+ self.probe_iterations = []
+
+ if reset:
+ self._object = self._object_initial.copy()
+ self.error_iterations = []
+ self._probe = self._probe_initial.copy()
+ self._positions_px_all = self._positions_px_initial_all.copy()
+ if hasattr(self, "_tf"):
+ del self._tf
+
+ if use_projection_scheme:
+ self._exit_waves = [None] * self._num_tilts
+ else:
+ self._exit_waves = None
+ elif reset is None:
+ if hasattr(self, "error"):
+ warnings.warn(
+ (
+ "Continuing reconstruction from previous result. "
+ "Use reset=True for a fresh start."
+ ),
+ UserWarning,
+ )
+ else:
+ self.error_iterations = []
+ if use_projection_scheme:
+ self._exit_waves = [None] * self._num_tilts
+ else:
+ self._exit_waves = None
+
+ if gaussian_filter_sigma_m is None:
+ gaussian_filter_sigma_m = gaussian_filter_sigma_e
+
+ if q_lowpass_m is None:
+ q_lowpass_m = q_lowpass_e
+
+ # main loop
+ for a0 in tqdmnd(
+ max_iter,
+ desc="Reconstructing object and probe",
+ unit=" iter",
+ disable=not progress_bar,
+ ):
+ error = 0.0
+
+ if collective_tilt_updates:
+ collective_object = xp.zeros_like(self._object)
+
+ tilt_indices = np.arange(self._num_tilts)
+ np.random.shuffle(tilt_indices)
+
+ for tilt_index in tilt_indices:
+ tilt_error = 0.0
+ self._active_tilt_index = tilt_index
+
+ alpha_deg, beta_deg = self._tilt_angles_deg[self._active_tilt_index]
+ alpha, beta = np.deg2rad([alpha_deg, beta_deg])
+
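+ # rotate all four volumes (V, Az, Ax, Ay) into the current tilt frame;
+ # the inverse rotations (-beta_deg) below undo this after the update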
+ # V
+ self._object[0] = self._euler_angle_rotate_volume(
+ self._object[0],
+ alpha_deg,
+ beta_deg,
+ )
+
+ # Az
+ self._object[1] = self._euler_angle_rotate_volume(
+ self._object[1],
+ alpha_deg,
+ beta_deg,
+ )
+
+ # Ax
+ self._object[2] = self._euler_angle_rotate_volume(
+ self._object[2],
+ alpha_deg,
+ beta_deg,
+ )
+
+ # Ay
+ self._object[3] = self._euler_angle_rotate_volume(
+ self._object[3],
+ alpha_deg,
+ beta_deg,
+ )
+
+ object_A = self._object[1] * np.cos(beta) + np.sin(beta) * (
+ self._object[3] * np.cos(alpha) - self._object[2] * np.sin(alpha)
+ )
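+ # object_A is the vector potential projected along the tilted beam axis:
+ # A_beam = Az*cos(beta) + sin(beta) * (Ay*cos(alpha) - Ax*sin(alpha))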
+
+ object_sliced_V = self._project_sliced_object(
+ self._object[0], self._num_slices
+ )
+
+ object_sliced_A = self._project_sliced_object(
+ object_A, self._num_slices
+ )
+
+ if not use_projection_scheme:
+ object_sliced_old_V = object_sliced_V.copy()
+ object_sliced_old_A = object_sliced_A.copy()
+
+ start_tilt = self._cum_probes_per_tilt[self._active_tilt_index]
+ end_tilt = self._cum_probes_per_tilt[self._active_tilt_index + 1]
+
+ num_diffraction_patterns = end_tilt - start_tilt
+ shuffled_indices = np.arange(num_diffraction_patterns)
+ unshuffled_indices = np.zeros_like(shuffled_indices)
+
+ if max_batch_size is None:
+ current_max_batch_size = num_diffraction_patterns
+ else:
+ current_max_batch_size = max_batch_size
+
+ # randomize
+ if not use_projection_scheme:
+ np.random.shuffle(shuffled_indices)
+
+ unshuffled_indices[shuffled_indices] = np.arange(
+ num_diffraction_patterns
+ )
+
+ positions_px = self._positions_px_all[start_tilt:end_tilt].copy()[
+ shuffled_indices
+ ]
+ initial_positions_px = self._positions_px_initial_all[
+ start_tilt:end_tilt
+ ].copy()[shuffled_indices]
+
+ for start, end in generate_batches(
+ num_diffraction_patterns, max_batch=current_max_batch_size
+ ):
+ # batch indices
+ self._positions_px = positions_px[start:end]
+ self._positions_px_initial = initial_positions_px[start:end]
+ self._positions_px_com = xp.mean(self._positions_px, axis=0)
+ self._positions_px_fractional = self._positions_px - xp.round(
+ self._positions_px
+ )
+
+ (
+ self._vectorized_patch_indices_row,
+ self._vectorized_patch_indices_col,
+ ) = self._extract_vectorized_patch_indices()
+
+ amplitudes = self._amplitudes[start_tilt:end_tilt][
+ shuffled_indices[start:end]
+ ]
+
+ # forward operator
+ (
+ propagated_probes,
+ object_patches,
+ transmitted_probes,
+ self._exit_waves,
+ batch_error,
+ ) = self._forward(
+ object_sliced_V,
+ object_sliced_A,
+ self._probe,
+ amplitudes,
+ self._exit_waves,
+ use_projection_scheme,
+ projection_a,
+ projection_b,
+ projection_c,
+ )
+
+ # adjoint operator
+ object_sliced_V, object_sliced_A, self._probe = self._adjoint(
+ object_sliced_V,
+ object_sliced_A,
+ self._probe,
+ object_patches,
+ propagated_probes,
+ self._exit_waves,
+ use_projection_scheme=use_projection_scheme,
+ step_size=step_size,
+ normalization_min=normalization_min,
+ fix_probe=a0 < fix_probe_iter,
+ )
+
+ # position correction
+ if a0 >= fix_positions_iter:
+ positions_px[start:end] = self._position_correction(
+ object_sliced_V,
+ self._probe,
+ transmitted_probes,
+ amplitudes,
+ self._positions_px,
+ positions_step_size,
+ constrain_position_distance,
+ )
+
+ tilt_error += batch_error
+
+ if not use_projection_scheme:
+ object_sliced_V -= object_sliced_old_V
+ object_sliced_A -= object_sliced_old_A
+
+ object_update_V = self._expand_sliced_object(
+ object_sliced_V, self._num_voxels
+ )
+ object_update_A = self._expand_sliced_object(
+ object_sliced_A, self._num_voxels
+ )
+
+ if collective_tilt_updates:
+ collective_object[0] += self._euler_angle_rotate_volume(
+ object_update_V,
+ alpha_deg,
+ -beta_deg,
+ )
+ collective_object[1] += self._euler_angle_rotate_volume(
+ object_update_A * np.cos(beta),
+ alpha_deg,
+ -beta_deg,
+ )
+ collective_object[2] -= self._euler_angle_rotate_volume(
+ object_update_A * np.sin(alpha) * np.sin(beta),
+ alpha_deg,
+ -beta_deg,
+ )
+ collective_object[3] += self._euler_angle_rotate_volume(
+ object_update_A * np.cos(alpha) * np.sin(beta),
+ alpha_deg,
+ -beta_deg,
+ )
+ else:
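+ # distribute the updates back onto (V, Az, Ax, Ay) in the tilt frame,
+ # using the transpose of the beam-direction projection above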
+ self._object[0] += object_update_V
+ self._object[1] += object_update_A * np.cos(beta)
+ self._object[2] -= object_update_A * np.sin(alpha) * np.sin(beta)
+ self._object[3] += object_update_A * np.cos(alpha) * np.sin(beta)
+
+ self._object[0] = self._euler_angle_rotate_volume(
+ self._object[0],
+ alpha_deg,
+ -beta_deg,
+ )
+
+ self._object[1] = self._euler_angle_rotate_volume(
+ self._object[1],
+ alpha_deg,
+ -beta_deg,
+ )
+
+ self._object[2] = self._euler_angle_rotate_volume(
+ self._object[2],
+ alpha_deg,
+ -beta_deg,
+ )
+
+ self._object[3] = self._euler_angle_rotate_volume(
+ self._object[3],
+ alpha_deg,
+ -beta_deg,
+ )
+
+ # Normalize Error
+ tilt_error /= (
+ self._mean_diffraction_intensity[self._active_tilt_index]
+ * num_diffraction_patterns
+ )
+ error += tilt_error
+
+ # constraints
+ self._positions_px_all[start_tilt:end_tilt] = positions_px.copy()[
+ unshuffled_indices
+ ]
+
+ if not collective_tilt_updates:
+ (
+ self._object,
+ self._probe,
+ self._positions_px_all[start_tilt:end_tilt],
+ ) = self._constraints(
+ self._object,
+ self._probe,
+ self._positions_px_all[start_tilt:end_tilt],
+ fix_com=fix_com and a0 >= fix_probe_iter,
+ constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter
+ and a0 >= fix_probe_iter,
+ constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius,
+ constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width,
+ constrain_probe_fourier_amplitude=a0
+ < constrain_probe_fourier_amplitude_iter
+ and a0 >= fix_probe_iter,
+ constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels,
+ constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity,
+ fit_probe_aberrations=a0 < fit_probe_aberrations_iter
+ and a0 >= fix_probe_iter,
+ fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order,
+ fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order,
+ fix_probe_aperture=a0 < fix_probe_aperture_iter,
+ initial_probe_aperture=self._probe_initial_aperture,
+ fix_positions=a0 < fix_positions_iter,
+ global_affine_transformation=global_affine_transformation,
+ gaussian_filter=a0 < gaussian_filter_iter
+ and gaussian_filter_sigma_m is not None,
+ gaussian_filter_sigma_e=gaussian_filter_sigma_e,
+ gaussian_filter_sigma_m=gaussian_filter_sigma_m,
+ butterworth_filter=a0 < butterworth_filter_iter
+ and (q_lowpass_m is not None or q_highpass_m is not None),
+ q_lowpass_e=q_lowpass_e,
+ q_lowpass_m=q_lowpass_m,
+ q_highpass_e=q_highpass_e,
+ q_highpass_m=q_highpass_m,
+ butterworth_order=butterworth_order,
+ object_positivity=object_positivity,
+ shrinkage_rad=shrinkage_rad,
+ object_mask=self._object_fov_mask_inverse
+ if fix_potential_baseline
+ and self._object_fov_mask_inverse.sum() > 0
+ else None,
+ tv_denoise=a0 < tv_denoise_iter
+ and tv_denoise_weights is not None,
+ tv_denoise_weights=tv_denoise_weights,
+ tv_denoise_inner_iter=tv_denoise_inner_iter,
+ )
+
+ # Normalize Error Over Tilts
+ error /= self._num_tilts
+
+ self._object[1:] = self._divergence_free_constraint(self._object[1:])
+
+ if collective_tilt_updates:
+ self._object += collective_object / self._num_tilts
+
+ (
+ self._object,
+ self._probe,
+ _,
+ ) = self._constraints(
+ self._object,
+ self._probe,
+ None,
+ fix_com=fix_com and a0 >= fix_probe_iter,
+ constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter
+ and a0 >= fix_probe_iter,
+ constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius,
+ constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width,
+ constrain_probe_fourier_amplitude=a0
+ < constrain_probe_fourier_amplitude_iter
+ and a0 >= fix_probe_iter,
+ constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels,
+ constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity,
+ fit_probe_aberrations=a0 < fit_probe_aberrations_iter
+ and a0 >= fix_probe_iter,
+ fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order,
+ fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order,
+ fix_probe_aperture=a0 < fix_probe_aperture_iter,
+ initial_probe_aperture=self._probe_initial_aperture,
+ fix_positions=True,
+ global_affine_transformation=global_affine_transformation,
+ gaussian_filter=a0 < gaussian_filter_iter
+ and gaussian_filter_sigma_m is not None,
+ gaussian_filter_sigma_e=gaussian_filter_sigma_e,
+ gaussian_filter_sigma_m=gaussian_filter_sigma_m,
+ butterworth_filter=a0 < butterworth_filter_iter
+ and (q_lowpass_m is not None or q_highpass_m is not None),
+ q_lowpass_e=q_lowpass_e,
+ q_lowpass_m=q_lowpass_m,
+ q_highpass_e=q_highpass_e,
+ q_highpass_m=q_highpass_m,
+ butterworth_order=butterworth_order,
+ object_positivity=object_positivity,
+ shrinkage_rad=shrinkage_rad,
+ object_mask=self._object_fov_mask_inverse
+ if fix_potential_baseline
+ and self._object_fov_mask_inverse.sum() > 0
+ else None,
+ tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None,
+ tv_denoise_weights=tv_denoise_weights,
+ tv_denoise_inner_iter=tv_denoise_inner_iter,
+ )
+
+ self.error_iterations.append(error.item())
+ if store_iterations:
+ self.object_iterations.append(asnumpy(self._object.copy()))
+ self.probe_iterations.append(self.probe_centered)
+
+ # store result
+ self.object = asnumpy(self._object)
+ self.probe = self.probe_centered
+ self.error = error.item()
+
+ if self._device == "gpu":
+ xp._default_memory_pool.free_all_blocks()
+ xp.clear_memo()
+
+ return self
+
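+ # Hypothetical usage sketch (illustrative values only; `recon` stands for a
+ # preprocessed instance of this class and is not defined in this diff):
+ #
+ #   recon.reconstruct(
+ #       max_iter=32,
+ #       reconstruction_method="gradient-descent",
+ #       step_size=0.5,
+ #       store_iterations=True,
+ #   ).visualize()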
+ def _crop_rotate_object_manually(
+ self,
+ array,
+ angle,
+ x_lims,
+ y_lims,
+ ):
+ """
+ Crops and rotates the object manually.
+
+ Parameters
+ ----------
+ array: np.ndarray
+ Object array to crop and rotate. Only operates on numpy arrays for compatibility.
+ angle: float
+ In-plane angle in degrees to rotate by
+ x_lims: tuple(float,float)
+ min/max x indices
+ y_lims: tuple(float,float)
+ min/max y indices
+
+ Returns
+ -------
+ cropped_rotated_array: np.ndarray
+ Cropped and rotated object array
+ """
+
+ asnumpy = self._asnumpy
+ min_x, max_x = x_lims
+ min_y, max_y = y_lims
+
+ if angle is not None:
+ rotated_array = rotate_np(
+ asnumpy(array), angle, reshape=False, axes=(-2, -1)
+ )
+ else:
+ rotated_array = asnumpy(array)
+
+ return rotated_array[..., min_x:max_x, min_y:max_y]
+
+ def _visualize_last_iteration_figax(
+ self,
+ fig,
+ object_ax,
+ convergence_ax,
+ cbar: bool,
+ projection_angle_deg: float,
+ projection_axes: Tuple[int, int],
+ x_lims: Tuple[int, int],
+ y_lims: Tuple[int, int],
+ **kwargs,
+ ):
+ """
+ Displays last reconstructed object on a given fig/ax.
+
+ Parameters
+ --------
+ fig: Figure
+ Matplotlib figure object_ax lives in
+ object_ax: Axes
+ Matplotlib axes to plot reconstructed object in
+ convergence_ax: Axes, optional
+ Matplotlib axes to plot convergence plot in
+ cbar: bool, optional
+ If true, displays a colorbar
+ projection_angle_deg: float
+ Angle in degrees to rotate 3D array around prior to projection
+ projection_axes: tuple(int,int)
+ Axes defining projection plane
+ x_lims: tuple(float,float)
+ min/max x indices
+ y_lims: tuple(float,float)
+ min/max y indices
+ """
+
+ cmap = kwargs.pop("cmap", "magma")
+
+ asnumpy = self._asnumpy
+
+ if projection_angle_deg is not None:
+ rotated_3d_obj = self._rotate(
+ self._object[0],
+ projection_angle_deg,
+ axes=projection_axes,
+ reshape=False,
+ order=2,
+ )
+ rotated_3d_obj = asnumpy(rotated_3d_obj)
+ else:
+ rotated_3d_obj = self.object[0]
+
+ rotated_object = self._crop_rotate_object_manually(
+ rotated_3d_obj.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims
+ )
+ rotated_shape = rotated_object.shape
+
+ extent = [
+ 0,
+ self.sampling[1] * rotated_shape[1],
+ self.sampling[0] * rotated_shape[0],
+ 0,
+ ]
+
+ im = object_ax.imshow(
+ rotated_object,
+ extent=extent,
+ cmap=cmap,
+ **kwargs,
+ )
+
+ if cbar:
+ divider = make_axes_locatable(object_ax)
+ ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
+ fig.add_axes(ax_cb)
+ fig.colorbar(im, cax=ax_cb)
+
+ if convergence_ax is not None and hasattr(self, "error_iterations"):
+ errors = np.array(self.error_iterations)
+ kwargs.pop("vmin", None)
+ kwargs.pop("vmax", None)
+ convergence_ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs)
+
+ def _visualize_last_iteration(
+ self,
+ fig,
+ cbar: bool,
+ plot_convergence: bool,
+ projection_angle_deg: float,
+ projection_axes: Tuple[int, int],
+ x_lims: Tuple[int, int],
+ y_lims: Tuple[int, int],
+ **kwargs,
+ ):
+ """
+ Displays last reconstructed object and probe iterations.
+
+ Parameters
+ --------
+ fig: Figure
+ Matplotlib figure to place Gridspec in
+ plot_convergence: bool, optional
+ If true, the normalized mean squared error (NMSE) plot is displayed
+ cbar: bool, optional
+ If true, displays a colorbar
+ projection_angle_deg: float
+ Angle in degrees to rotate 3D array around prior to projection
+ projection_axes: tuple(int,int)
+ Axes defining projection plane
+ x_lims: tuple(float,float)
+ min/max x indices
+ y_lims: tuple(float,float)
+ min/max y indices
+ """
+ figsize = kwargs.pop("figsize", (14, 10) if cbar else (12, 10))
+ cmap_e = kwargs.pop("cmap_e", "magma")
+ cmap_m = kwargs.pop("cmap_m", "PuOr")
+
+ asnumpy = self._asnumpy
+
+ if projection_angle_deg is not None:
+ rotated_3d_obj_V = self._rotate(
+ self._object[0],
+ projection_angle_deg,
+ axes=projection_axes,
+ reshape=False,
+ order=2,
+ )
+
+ rotated_3d_obj_Az = self._rotate(
+ self._object[1],
+ projection_angle_deg,
+ axes=projection_axes,
+ reshape=False,
+ order=2,
+ )
+
+ rotated_3d_obj_Ax = self._rotate(
+ self._object[2],
+ projection_angle_deg,
+ axes=projection_axes,
+ reshape=False,
+ order=2,
+ )
+
+ rotated_3d_obj_Ay = self._rotate(
+ self._object[3],
+ projection_angle_deg,
+ axes=projection_axes,
+ reshape=False,
+ order=2,
+ )
+
+ rotated_3d_obj_V = asnumpy(rotated_3d_obj_V)
+ rotated_3d_obj_Az = asnumpy(rotated_3d_obj_Az)
+ rotated_3d_obj_Ax = asnumpy(rotated_3d_obj_Ax)
+ rotated_3d_obj_Ay = asnumpy(rotated_3d_obj_Ay)
+ else:
+ (
+ rotated_3d_obj_V,
+ rotated_3d_obj_Az,
+ rotated_3d_obj_Ax,
+ rotated_3d_obj_Ay,
+ ) = self.object
+
+ rotated_object_Vx = self._crop_rotate_object_manually(
+ rotated_3d_obj_V.sum(1).T, angle=None, x_lims=x_lims, y_lims=y_lims
+ )
+ rotated_object_Vy = self._crop_rotate_object_manually(
+ rotated_3d_obj_V.sum(2).T, angle=None, x_lims=x_lims, y_lims=y_lims
+ )
+ rotated_object_Vz = self._crop_rotate_object_manually(
+ rotated_3d_obj_V.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims
+ )
+
+ rotated_object_Azx = self._crop_rotate_object_manually(
+ rotated_3d_obj_Az.sum(1).T, angle=None, x_lims=x_lims, y_lims=y_lims
+ )
+ rotated_object_Azy = self._crop_rotate_object_manually(
+ rotated_3d_obj_Az.sum(2).T, angle=None, x_lims=x_lims, y_lims=y_lims
+ )
+ rotated_object_Azz = self._crop_rotate_object_manually(
+ rotated_3d_obj_Az.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims
+ )
+
+ rotated_object_Axx = self._crop_rotate_object_manually(
+ rotated_3d_obj_Ax.sum(1).T, angle=None, x_lims=x_lims, y_lims=y_lims
+ )
+ rotated_object_Axy = self._crop_rotate_object_manually(
+ rotated_3d_obj_Ax.sum(2).T, angle=None, x_lims=x_lims, y_lims=y_lims
+ )
+ rotated_object_Axz = self._crop_rotate_object_manually(
+ rotated_3d_obj_Ax.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims
+ )
+
+ rotated_object_Ayx = self._crop_rotate_object_manually(
+ rotated_3d_obj_Ay.sum(1).T, angle=None, x_lims=x_lims, y_lims=y_lims
+ )
+ rotated_object_Ayy = self._crop_rotate_object_manually(
+ rotated_3d_obj_Ay.sum(2).T, angle=None, x_lims=x_lims, y_lims=y_lims
+ )
+ rotated_object_Ayz = self._crop_rotate_object_manually(
+ rotated_3d_obj_Ay.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims
+ )
+
+ rotated_shape = rotated_object_Vx.shape
+
+ extent = [
+ 0,
+ self.sampling[1] * rotated_shape[1],
+ self.sampling[0] * rotated_shape[0],
+ 0,
+ ]
+
+ arrays = [
+ [
+ rotated_object_Vx,
+ rotated_object_Axx,
+ rotated_object_Ayx,
+ rotated_object_Azx,
+ ],
+ [
+ rotated_object_Vy,
+ rotated_object_Axy,
+ rotated_object_Ayy,
+ rotated_object_Azy,
+ ],
+ [
+ rotated_object_Vz,
+ rotated_object_Axz,
+ rotated_object_Ayz,
+ rotated_object_Azz,
+ ],
+ ]
+
+ titles = [
+ [
+ "V projected along x",
+ "Ax projected along x",
+ "Ay projected along x",
+ "Az projected along x",
+ ],
+ [
+ "V projected along y",
+ "Ax projected along y",
+ "Ay projected along y",
+ "Az projected along y",
+ ],
+ [
+ "V projected along z",
+ "Ax projected along z",
+ "Ay projected along z",
+ "Az projected along z",
+ ],
+ ]
+
+ max_e = np.array(
+ [rotated_object_Vx.max(), rotated_object_Vy.max(), rotated_object_Vz.max()]
+ ).max()
+ max_m = np.array(
+ [
+ [
+ np.abs(rotated_object_Axx).max(),
+ np.abs(rotated_object_Ayx).max(),
+ np.abs(rotated_object_Azx).max(),
+ ],
+ [
+ np.abs(rotated_object_Axy).max(),
+ np.abs(rotated_object_Ayy).max(),
+ np.abs(rotated_object_Azy).max(),
+ ],
+ [
+ np.abs(rotated_object_Axz).max(),
+ np.abs(rotated_object_Ayz).max(),
+ np.abs(rotated_object_Azz).max(),
+ ],
+ ]
+ ).max()
+
+ vmin_e = kwargs.pop("vmin_e", 0.0)
+ vmax_e = kwargs.pop("vmax_e", max_e)
+ vmin_m = kwargs.pop("vmin_m", -max_m)
+ vmax_m = kwargs.pop("vmax_m", max_m)
+
+ if plot_convergence:
+ spec = GridSpec(
+ ncols=4, nrows=4, height_ratios=[4, 4, 4, 1], hspace=0.15, wspace=0.35
+ )
+ else:
+ spec = GridSpec(ncols=4, nrows=3, hspace=0.15, wspace=0.35)
+
+ if fig is None:
+ fig = plt.figure(figsize=figsize)
+
+ for sp in spec:
+ row, col = np.unravel_index(sp.num1, (4, 4))
+
+ if row < 3:
+ ax = fig.add_subplot(sp)
+ if sp.is_first_col():
+ cmap = cmap_e
+ vmin = vmin_e
+ vmax = vmax_e
+ else:
+ cmap = cmap_m
+ vmin = vmin_m
+ vmax = vmax_m
+
+ im = ax.imshow(
+ arrays[row][col],
+ cmap=cmap,
+ vmin=vmin,
+ vmax=vmax,
+ extent=extent,
+ **kwargs,
+ )
+
+ if cbar:
+ divider = make_axes_locatable(ax)
+ ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
+ fig.add_axes(ax_cb)
+ fig.colorbar(im, cax=ax_cb)
+
+ ax.set_title(titles[row][col])
+
+ if row < 2:
+ ax.set_xticks([])
+ else:
+ ax.set_xlabel("y [A]")
+
+ if col > 0:
+ ax.set_yticks([])
+ else:
+ ax.set_ylabel("x [A]")
+
+ if plot_convergence and hasattr(self, "error_iterations"):
+ errors = np.array(self.error_iterations)
+
+ ax = fig.add_subplot(spec[-1, :])
+ ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs)
+ ax.set_ylabel("NMSE")
+ ax.set_xlabel("Iteration number")
+ ax.yaxis.tick_right()
+
+ spec.tight_layout(fig)
+
+ def _visualize_all_iterations(
+ self,
+ fig,
+ plot_convergence: bool,
+ iterations_grid: Tuple[int, int],
+ projection_angle_deg: float,
+ projection_axes: Tuple[int, int],
+ x_lims: Tuple[int, int],
+ y_lims: Tuple[int, int],
+ **kwargs,
+ ):
+ """
+ Displays all reconstructed object and probe iterations.
+
+ Parameters
+ --------
+ fig: Figure
+ Matplotlib figure to place Gridspec in
+ plot_convergence: bool, optional
+ If true, the normalized mean squared error (NMSE) plot is displayed
+ cbar: bool, optional
+ If true, displays a colorbar
+ projection_angle_deg: float
+ Angle in degrees to rotate 3D array around prior to projection
+ projection_axes: tuple(int,int)
+ Axes defining projection plane
+ x_lims: tuple(float,float)
+ min/max x indices
+ y_lims: tuple(float,float)
+ min/max y indices
+ iterations_grid: Tuple[int,int]
+ Grid dimensions to plot reconstruction iterations
+ """
+ raise NotImplementedError()
+
+ def visualize(
+ self,
+ fig=None,
+ cbar: bool = True,
+ iterations_grid: Tuple[int, int] = None,
+ plot_convergence: bool = True,
+ projection_angle_deg: float = None,
+ projection_axes: Tuple[int, int] = (0, 2),
+ x_lims=(None, None),
+ y_lims=(None, None),
+ **kwargs,
+ ):
+ """
+ Displays reconstructed object and probe.
+
+ Parameters
+ --------
+ fig: Figure
+ Matplotlib figure to place Gridspec in
+ plot_convergence: bool, optional
+ If true, the normalized mean squared error (NMSE) plot is displayed
+ cbar: bool, optional
+ If true, displays a colorbar
+ projection_angle_deg: float
+ Angle in degrees to rotate 3D array around prior to projection
+ projection_axes: tuple(int,int)
+ Axes defining projection plane
+ x_lims: tuple(float,float)
+ min/max x indices
+ y_lims: tuple(float,float)
+ min/max y indices
+ iterations_grid: Tuple[int,int]
+ Grid dimensions to plot reconstruction iterations
+
+ Returns
+ --------
+ self: OverlapMagneticTomographicReconstruction
+ Self to accommodate chaining
+ """
+
+ if iterations_grid is None:
+ self._visualize_last_iteration(
+ fig=fig,
+ plot_convergence=plot_convergence,
+ projection_angle_deg=projection_angle_deg,
+ projection_axes=projection_axes,
+ cbar=cbar,
+ x_lims=x_lims,
+ y_lims=y_lims,
+ **kwargs,
+ )
+ else:
+ self._visualize_all_iterations(
+ fig=fig,
+ plot_convergence=plot_convergence,
+ iterations_grid=iterations_grid,
+ projection_angle_deg=projection_angle_deg,
+ projection_axes=projection_axes,
+ cbar=cbar,
+ x_lims=x_lims,
+ y_lims=y_lims,
+ **kwargs,
+ )
+
+ return self
+
+ def _return_object_fft(
+ self,
+ obj=None,
+ projection_angle_deg: float = None,
+ projection_axes: Tuple[int, int] = (0, 2),
+ x_lims: Tuple[int, int] = (None, None),
+ y_lims: Tuple[int, int] = (None, None),
+ ):
+ """
+ Returns obj fft shifted to center of array
+
+ Parameters
+ ----------
+ obj: array, optional
+ if None is specified, uses self._object
+ projection_angle_deg: float
+ Angle in degrees to rotate 3D array around prior to projection
+ projection_axes: tuple(int,int)
+ Axes defining projection plane
+ x_lims: tuple(float,float)
+ min/max x indices
+ y_lims: tuple(float,float)
+ min/max y indices
+ """
+
+ xp = self._xp
+ asnumpy = self._asnumpy
+
+ if obj is None:
+ obj = self._object[0]
+ else:
+ obj = xp.asarray(obj[0], dtype=xp.float32)
+
+ if projection_angle_deg is not None:
+ rotated_3d_obj = self._rotate(
+ obj,
+ projection_angle_deg,
+ axes=projection_axes,
+ reshape=False,
+ order=2,
+ )
+ rotated_3d_obj = asnumpy(rotated_3d_obj)
+ else:
+ rotated_3d_obj = asnumpy(obj)
+
+ rotated_object = self._crop_rotate_object_manually(
+ rotated_3d_obj.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims
+ )
+
+ return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(rotated_object))))
+
+ def show_object_fft(
+ self,
+ obj=None,
+ projection_angle_deg: float = None,
+ projection_axes: Tuple[int, int] = (0, 2),
+ x_lims: Tuple[int, int] = (None, None),
+ y_lims: Tuple[int, int] = (None, None),
+ **kwargs,
+ ):
+ """
+ Plot FFT of reconstructed object
+
+ Parameters
+ ----------
+ obj: array, optional
+ if None is specified, uses self._object
+ projection_angle_deg: float
+ Angle in degrees to rotate 3D array around prior to projection
+ projection_axes: tuple(int,int)
+ Axes defining projection plane
+ x_lims: tuple(float,float)
+ min/max x indices
+ y_lims: tuple(float,float)
+ min/max y indices
+ """
+ if obj is None:
+ object_fft = self._return_object_fft(
+ projection_angle_deg=projection_angle_deg,
+ projection_axes=projection_axes,
+ x_lims=x_lims,
+ y_lims=y_lims,
+ )
+ else:
+ object_fft = self._return_object_fft(
+ obj,
+ projection_angle_deg=projection_angle_deg,
+ projection_axes=projection_axes,
+ x_lims=x_lims,
+ y_lims=y_lims,
+ )
+
+ figsize = kwargs.pop("figsize", (6, 6))
+ cmap = kwargs.pop("cmap", "magma")
+
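+ # reciprocal-space pixel size is 1 / (array width * real-space sampling)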
+ pixelsize = 1 / (object_fft.shape[1] * self.sampling[1])
+ show(
+ object_fft,
+ figsize=figsize,
+ cmap=cmap,
+ scalebar=True,
+ pixelsize=pixelsize,
+ ticks=False,
+ pixelunits=r"$\AA^{-1}$",
+ **kwargs,
+ )
+
+ @property
+ def positions(self):
+ """Probe positions [A]"""
+
+ if self.angular_sampling is None:
+ return None
+
+ asnumpy = self._asnumpy
+ positions_all = []
+ for tilt_index in range(self._num_tilts):
+ positions = self._positions_px_all[
+ self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[
+ tilt_index + 1
+ ]
+ ].copy()
+ positions[:, 0] *= self.sampling[0]
+ positions[:, 1] *= self.sampling[1]
+ positions_all.append(asnumpy(positions))
+
+ return np.asarray(positions_all)
+
+ def _return_self_consistency_errors(
+ self,
+ max_batch_size=None,
+ ):
+ """Compute the self-consistency errors for each probe position"""
+ raise NotImplementedError()
+
+ def _return_projected_cropped_potential(
+ self,
+ ):
+ """Utility function to accommodate multiple classes"""
+ raise NotImplementedError()
+
+ def show_uncertainty_visualization(
+ self,
+ errors=None,
+ max_batch_size=None,
+ projected_cropped_potential=None,
+ kde_sigma=None,
+ plot_histogram=True,
+ plot_contours=False,
+ **kwargs,
+ ):
+ """Plot uncertainty visualization using self-consistency errors"""
+ raise NotImplementedError()
diff --git a/py4DSTEM/process/phase/iterative_overlap_tomography.py b/py4DSTEM/process/phase/iterative_overlap_tomography.py
new file mode 100644
index 000000000..ddd13ac58
--- /dev/null
+++ b/py4DSTEM/process/phase/iterative_overlap_tomography.py
@@ -0,0 +1,3244 @@
+"""
+Module for reconstructing phase objects from 4DSTEM datasets using iterative methods,
+namely overlap tomography.
+"""
+
+import warnings
+from typing import Mapping, Sequence, Tuple
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pylops
+from matplotlib.gridspec import GridSpec
+from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable
+from py4DSTEM.visualize import show
+from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg
+from scipy.ndimage import rotate as rotate_np
+
+try:
+ import cupy as cp
+except ModuleNotFoundError:
+ cp = np
+
+from emdfile import Custom, tqdmnd
+from py4DSTEM import DataCube
+from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction
+from py4DSTEM.process.phase.utils import (
+ ComplexProbe,
+ fft_shift,
+ generate_batches,
+ polar_aliases,
+ polar_symbols,
+ spatial_frequencies,
+)
+from py4DSTEM.process.utils import electron_wavelength_angstrom, get_CoM, get_shifted_ar
+
+warnings.simplefilter(action="always", category=UserWarning)
+
+
+class OverlapTomographicReconstruction(PtychographicReconstruction):
+ """
+ Overlap Tomographic Reconstruction Class.
+
+ List of diffraction intensities dimensions : (Rx,Ry,Qx,Qy)
+ Reconstructed probe dimensions : (Sx,Sy)
+ Reconstructed object dimensions : (Px,Py,Py)
+
+ such that (Sx,Sy) is the region-of-interest (ROI) size of our probe
+ and (Px,Py,Py) is the padded-object electrostatic potential volume,
+ with the x-axis along the tilt axis.
+
+ Parameters
+ ----------
+ datacube: List of DataCubes
+ Input list of 4D diffraction pattern intensities
+ energy: float
+ The electron energy of the wave functions in eV
+ num_slices: int
+ Number of slices to use in the forward model
+ tilt_orientation_matrices: Sequence[np.ndarray]
+ List of orientation matrices for each tilt
+ semiangle_cutoff: float, optional
+ Semiangle cutoff for the initial probe guess in mrad
+ semiangle_cutoff_pixels: float, optional
+ Semiangle cutoff for the initial probe guess in pixels
+ rolloff: float, optional
+ Semiangle rolloff for the initial probe guess
+ vacuum_probe_intensity: np.ndarray, optional
+ Vacuum probe to use as intensity aperture for initial probe guess
+ polar_parameters: dict, optional
+ Mapping from aberration symbols to their corresponding values. All aberration
+ magnitudes should be given in Å and angles should be given in radians.
+ object_padding_px: Tuple[int,int], optional
+ Pixel dimensions to pad object with
+ If None, the padding is set to half the probe ROI dimensions
+ initial_object_guess: np.ndarray, optional
+ Initial guess for complex-valued object of dimensions (Px,Py,Py)
+ If None, initialized to 1.0
+ initial_probe_guess: np.ndarray, optional
+ Initial guess for complex-valued probe of dimensions (Sx,Sy). If None,
+ initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations
+ initial_scan_positions: list of np.ndarray, optional
+ Probe positions in Å for each diffraction intensity per tilt
+ If None, initialized to a grid scan centered along tilt axis
+ verbose: bool, optional
+ If True, class methods will inherit this and print additional information
+ device: str, optional
+ Device the calculation will be performed on. Must be 'cpu' or 'gpu'
+ object_type: str, optional
+ The object can be reconstructed as a real potential ('potential') or a complex
+ object ('complex')
+ positions_mask: np.ndarray, optional
+ Boolean real space mask to select positions to ignore in reconstruction
+ name: str, optional
+ Class name
+ kwargs:
+ Provide the aberration coefficients as keyword arguments.
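+
+ Examples
+ --------
+ A minimal usage sketch (the datacube list, tilt angles, and optical
+ parameters below are illustrative assumptions, not fixed defaults):
+
+ >>> import numpy as np
+ >>> from scipy.spatial.transform import Rotation as R
+ >>> # one orientation matrix per tilt, e.g. rotations about the tilt (x) axis
+ >>> tilt_matrices = [
+ ... R.from_euler("x", angle, degrees=True).as_matrix()
+ ... for angle in (-30, 0, 30)
+ ... ]
+ >>> tomo = OverlapTomographicReconstruction(
+ ... datacube=list_of_datacubes, # hypothetical list of 3 DataCubes
+ ... energy=300e3,
+ ... num_slices=20,
+ ... tilt_orientation_matrices=tilt_matrices,
+ ... semiangle_cutoff=25.0,
+ ... device="cpu",
+ ... ).preprocess()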
+ """
+
+ # Class-specific Metadata
+ _class_specific_metadata = ("_num_slices", "_tilt_orientation_matrices")
+ _swap_zxy_to_xyz = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]])
+
+ def __init__(
+ self,
+ energy: float,
+ num_slices: int,
+ tilt_orientation_matrices: Sequence[np.ndarray],
+ datacube: Sequence[DataCube] = None,
+ semiangle_cutoff: float = None,
+ semiangle_cutoff_pixels: float = None,
+ rolloff: float = 2.0,
+ vacuum_probe_intensity: np.ndarray = None,
+ polar_parameters: Mapping[str, float] = None,
+ object_padding_px: Tuple[int, int] = None,
+ object_type: str = "potential",
+ positions_mask: np.ndarray = None,
+ initial_object_guess: np.ndarray = None,
+ initial_probe_guess: np.ndarray = None,
+ initial_scan_positions: Sequence[np.ndarray] = None,
+ verbose: bool = True,
+ device: str = "cpu",
+ name: str = "overlap-tomographic_reconstruction",
+ **kwargs,
+ ):
+ Custom.__init__(self, name=name)
+
+ if device == "cpu":
+ self._xp = np
+ self._asnumpy = np.asarray
+ from scipy.ndimage import affine_transform, gaussian_filter, rotate, zoom
+
+ self._gaussian_filter = gaussian_filter
+ self._zoom = zoom
+ self._rotate = rotate
+ self._affine_transform = affine_transform
+ from scipy.special import erf
+
+ self._erf = erf
+ elif device == "gpu":
+ self._xp = cp
+ self._asnumpy = cp.asnumpy
+ from cupyx.scipy.ndimage import (
+ affine_transform,
+ gaussian_filter,
+ rotate,
+ zoom,
+ )
+
+ self._gaussian_filter = gaussian_filter
+ self._zoom = zoom
+ self._rotate = rotate
+ self._affine_transform = affine_transform
+ from cupyx.scipy.special import erf
+
+ self._erf = erf
+ else:
+ raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}")
+
+ for key in kwargs.keys():
+ if (key not in polar_symbols) and (key not in polar_aliases.keys()):
+ raise ValueError("{} not a recognized parameter".format(key))
+
+ self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols)))
+
+ if polar_parameters is None:
+ polar_parameters = {}
+
+ polar_parameters.update(kwargs)
+ self._set_polar_parameters(polar_parameters)
+
+ num_tilts = len(tilt_orientation_matrices)
+ if initial_scan_positions is None:
+ initial_scan_positions = [None] * num_tilts
+
+ if object_type != "potential":
+ raise NotImplementedError()
+
+ if positions_mask is not None and positions_mask.dtype != "bool":
+ warnings.warn(
+ ("`positions_mask` converted to `bool` array"),
+ UserWarning,
+ )
+ positions_mask = np.asarray(positions_mask, dtype="bool")
+
+ self.set_save_defaults()
+
+ # Data
+ self._datacube = datacube
+ self._object = initial_object_guess
+ self._probe = initial_probe_guess
+
+ # Common Metadata
+ self._vacuum_probe_intensity = vacuum_probe_intensity
+ self._scan_positions = initial_scan_positions
+ self._energy = energy
+ self._semiangle_cutoff = semiangle_cutoff
+ self._semiangle_cutoff_pixels = semiangle_cutoff_pixels
+ self._rolloff = rolloff
+ self._object_type = object_type
+ self._object_padding_px = object_padding_px
+ self._positions_mask = positions_mask
+ self._verbose = verbose
+ self._device = device
+ self._preprocessed = False
+
+ # Class-specific Metadata
+ self._num_slices = num_slices
+ self._tilt_orientation_matrices = tuple(tilt_orientation_matrices)
+ self._num_tilts = num_tilts
+
+ def _precompute_propagator_arrays(
+ self,
+ gpts: Tuple[int, int],
+ sampling: Tuple[float, float],
+ energy: float,
+ slice_thicknesses: Sequence[float],
+ ):
+ """
+ Precomputes the propagator arrays the complex wavefunction will be
+ convolved with, for all slice thicknesses.
+
+ Parameters
+ ----------
+ gpts: Tuple[int,int]
+ Wavefunction pixel dimensions
+ sampling: Tuple[float,float]
+ Wavefunction sampling in A
+ energy: float
+ The electron energy of the wave functions in eV
+ slice_thicknesses: Sequence[float]
+ Array of slice thicknesses in A
+
+ Returns
+ -------
+ propagator_arrays: np.ndarray
+ (T,Sx,Sy) shape array storing propagator arrays
+ """
+ xp = self._xp
+
+ # Frequencies
+ kx, ky = spatial_frequencies(gpts, sampling)
+ kx = xp.asarray(kx, dtype=xp.float32)
+ ky = xp.asarray(ky, dtype=xp.float32)
+
+ # Propagators
+ wavelength = electron_wavelength_angstrom(energy)
+ num_slices = slice_thicknesses.shape[0]
+ propagators = xp.empty(
+ (num_slices, kx.shape[0], ky.shape[0]), dtype=xp.complex64
+ )
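+ # each propagator is the Fresnel (near-field) kernel
+ # exp(-i * pi * wavelength * dz * (kx^2 + ky^2)), built separably below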
+ for i, dz in enumerate(slice_thicknesses):
+ propagators[i] = xp.exp(
+ 1.0j * (-(kx**2)[:, None] * np.pi * wavelength * dz)
+ )
+ propagators[i] *= xp.exp(
+ 1.0j * (-(ky**2)[None] * np.pi * wavelength * dz)
+ )
+
+ return propagators
+
+ def _propagate_array(self, array: np.ndarray, propagator_array: np.ndarray):
+ """
+ Propagates array by Fourier convolving array with propagator_array.
+
+ Parameters
+ ----------
+ array: np.ndarray
+ Wavefunction array to be convolved
+ propagator_array: np.ndarray
+ Propagator array to convolve array with
+
+ Returns
+ -------
+ propagated_array: np.ndarray
+ Fourier-convolved array
+ """
+ xp = self._xp
+
+ return xp.fft.ifft2(xp.fft.fft2(array) * propagator_array)
+
+ def _project_sliced_object(self, array: np.ndarray, output_z):
+ """
+ Projects a voxel-sliced object onto a coarser slice grid by summing
+ consecutive voxels into each output slice.
+
+ Parameters
+ ----------
+ array: np.ndarray
+ 3D array to project
+ output_z: int
+ Output dimension to project the array to.
+ Expected to satisfy output_z <= array.shape[0]
+
+ Returns
+ -------
+ projected_array: np.ndarray
+ projected array
+ """
+ xp = self._xp
+ input_z = array.shape[0]
+
+ voxels_per_slice = np.ceil(input_z / output_z).astype("int")
+ pad_size = voxels_per_slice * output_z - input_z
+
+ padded_array = xp.pad(array, ((0, pad_size), (0, 0), (0, 0)))
+
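+ # sum each group of voxels_per_slice consecutive voxels into one output slice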
+ return xp.sum(
+ padded_array.reshape(
+ (
+ -1,
+ voxels_per_slice,
+ )
+ + array.shape[1:]
+ ),
+ axis=1,
+ )
+
+ def _expand_sliced_object(self, array: np.ndarray, output_z):
+ """
+ Expands a super-sliced object onto a finer voxel grid by evenly
+ dividing each slice among its constituent voxels.
+
+ Parameters
+ ----------
+ array: np.ndarray
+ 3D array to expand
+ output_z: int
+ Output dimension to expand the array to.
+ Expected to satisfy output_z >= array.shape[0]
+
+ Returns
+ -------
+ expanded_array: np.ndarray
+ expanded array
+ """
+ xp = self._xp
+ input_z = array.shape[0]
+
+ voxels_per_slice = np.ceil(output_z / input_z).astype("int")
+ remainder_size = voxels_per_slice - (voxels_per_slice * input_z - output_z)
+
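+ # the last slice may hold fewer voxels when output_z is not an integer
+ # multiple of input_z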
+ voxels_in_slice = xp.repeat(voxels_per_slice, input_z)
+ voxels_in_slice[-1] = remainder_size if remainder_size > 0 else voxels_per_slice
+
+ normalized_array = array / xp.asarray(voxels_in_slice)[:, None, None]
+ return xp.repeat(normalized_array, voxels_per_slice, axis=0)[:output_z]
+
+ def _rotate_zxy_volume(
+ self,
+ volume_array,
+ rot_matrix,
+ ):
+ """ """
+
+ xp = self._xp
+ affine_transform = self._affine_transform
+ swap_zxy_to_xyz = self._swap_zxy_to_xyz
+
+ volume = volume_array.copy()
+ volume_shape = xp.asarray(volume.shape)
+ tf = xp.asarray(swap_zxy_to_xyz.T @ rot_matrix.T @ swap_zxy_to_xyz)
+
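+ # choose the affine offset such that the volume center maps to itself,
+ # i.e. the rotation is performed about the center of the volume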
+ in_center = (volume_shape - 1) / 2
+ out_center = tf @ in_center
+ offset = in_center - out_center
+
+ volume = affine_transform(volume, tf, offset=offset, order=3)
+
+ return volume
+
+ def preprocess(
+ self,
+ diffraction_intensities_shape: Tuple[int, int] = None,
+ reshaping_method: str = "fourier",
+ probe_roi_shape: Tuple[int, int] = None,
+ dp_mask: np.ndarray = None,
+ fit_function: str = "plane",
+ plot_probe_overlaps: bool = True,
+ rotation_real_space_degrees: float = None,
+ diffraction_patterns_rotate_degrees: float = None,
+ diffraction_patterns_transpose: bool = None,
+ force_com_shifts: Sequence[float] = None,
+ force_scan_sampling: float = None,
+ force_angular_sampling: float = None,
+ force_reciprocal_sampling: float = None,
+ progress_bar: bool = True,
+ object_fov_mask: np.ndarray = None,
+ crop_patterns: bool = False,
+ **kwargs,
+ ):
+ """
+ Ptychographic preprocessing step.
+
+ Additionally, it initializes a (Px,Py,Py) potential array of zeros
+ and a complex probe using the specified polar parameters.
+
+ Parameters
+ ----------
+ diffraction_intensities_shape: Tuple[int,int], optional
+ Pixel dimensions (Qx',Qy') of the resampled diffraction intensities
+ If None, no resampling of diffraction intensities is performed
+ reshaping_method: str, optional
+ Method to use for reshaping, either 'bin', 'bilinear', or 'fourier' (default)
+ probe_roi_shape: Tuple[int,int], optional
+ Padded diffraction intensities shape.
+ If None, no padding is performed
+ dp_mask: ndarray, optional
+ Mask for datacube intensities (Qx,Qy)
+ fit_function: str, optional
+ 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two'
+ plot_probe_overlaps: bool, optional
+ If True, initial probe overlaps scanned over the object will be displayed
+ rotation_real_space_degrees: float (degrees), optional
+ In-plane rotation around the z axis between the x axis and the tilt
+ axis in real space (forced to lie in the xy plane)
+ diffraction_patterns_rotate_degrees: float, optional
+ Relative rotation angle between real and reciprocal space
+ diffraction_patterns_transpose: bool, optional
+ Whether diffraction intensities need to be transposed.
+ force_com_shifts: list of tuples of ndarrays (CoMx, CoMy), optional
+ Amplitudes come from diffraction patterns shifted with
+ the CoM in the upper left corner for each probe unless
+ the shift is overwritten. One tuple per tilt.
+ force_scan_sampling: float, optional
+ Override DataCube real space scan pixel size calibrations, in Angstrom
+ force_angular_sampling: float, optional
+ Override DataCube reciprocal pixel size calibration, in mrad
+ force_reciprocal_sampling: float, optional
+ Override DataCube reciprocal pixel size calibration, in A^-1
+ object_fov_mask: np.ndarray (boolean)
+ Boolean mask of FOV. Used to calculate additional shrinkage of object
+ If None, probe_overlap intensity is thresholded
+ crop_patterns: bool
+ If True, crop patterns to avoid wrap-around artifacts when centering
+
+ Returns
+ --------
+ self: OverlapTomographicReconstruction
+ Self to accommodate chaining
+ """
+ xp = self._xp
+ asnumpy = self._asnumpy
+
+ # set additional metadata
+ self._diffraction_intensities_shape = diffraction_intensities_shape
+ self._reshaping_method = reshaping_method
+ self._probe_roi_shape = probe_roi_shape
+ self._dp_mask = dp_mask
+
+ if self._datacube is None:
+ raise ValueError(
+ (
+ "The preprocess() method requires a DataCube. "
+ "Please run ptycho.attach_datacube(DataCube) first."
+ )
+ )
+
+ # Prepopulate various arrays
+ num_probes_per_tilt = [0]
+ for dc in self._datacube:
+ rx, ry = dc.Rshape
+ num_probes_per_tilt.append(rx * ry)
+
+ self._num_diffraction_patterns = sum(num_probes_per_tilt)
+ self._cum_probes_per_tilt = np.cumsum(np.array(num_probes_per_tilt))
+
+ self._mean_diffraction_intensity = []
+ self._positions_px_all = np.empty((self._num_diffraction_patterns, 2))
+
+ self._rotation_best_rad = np.deg2rad(diffraction_patterns_rotate_degrees)
+ self._rotation_best_transpose = diffraction_patterns_transpose
+
+ if force_com_shifts is None:
+ force_com_shifts = [None] * self._num_tilts
+
+ for tilt_index in tqdmnd(
+ self._num_tilts,
+ desc="Preprocessing data",
+ unit="tilt",
+ disable=not progress_bar,
+ ):
+ if tilt_index == 0:
+ (
+ self._datacube[tilt_index],
+ self._vacuum_probe_intensity,
+ self._dp_mask,
+ force_com_shifts[tilt_index],
+ ) = self._preprocess_datacube_and_vacuum_probe(
+ self._datacube[tilt_index],
+ diffraction_intensities_shape=self._diffraction_intensities_shape,
+ reshaping_method=self._reshaping_method,
+ probe_roi_shape=self._probe_roi_shape,
+ vacuum_probe_intensity=self._vacuum_probe_intensity,
+ dp_mask=self._dp_mask,
+ com_shifts=force_com_shifts[tilt_index],
+ )
+
+ self._amplitudes = xp.empty(
+ (self._num_diffraction_patterns,) + self._datacube[0].Qshape
+ )
+ self._region_of_interest_shape = np.array(
+ self._amplitudes[0].shape[-2:]
+ )
+
+ else:
+ (
+ self._datacube[tilt_index],
+ _,
+ _,
+ force_com_shifts[tilt_index],
+ ) = self._preprocess_datacube_and_vacuum_probe(
+ self._datacube[tilt_index],
+ diffraction_intensities_shape=self._diffraction_intensities_shape,
+ reshaping_method=self._reshaping_method,
+ probe_roi_shape=self._probe_roi_shape,
+ vacuum_probe_intensity=None,
+ dp_mask=None,
+ com_shifts=force_com_shifts[tilt_index],
+ )
+
+ intensities = self._extract_intensities_and_calibrations_from_datacube(
+ self._datacube[tilt_index],
+ require_calibrations=True,
+ force_scan_sampling=force_scan_sampling,
+ force_angular_sampling=force_angular_sampling,
+ force_reciprocal_sampling=force_reciprocal_sampling,
+ )
+
+ (
+ com_measured_x,
+ com_measured_y,
+ com_fitted_x,
+ com_fitted_y,
+ com_normalized_x,
+ com_normalized_y,
+ ) = self._calculate_intensities_center_of_mass(
+ intensities,
+ dp_mask=self._dp_mask,
+ fit_function=fit_function,
+ com_shifts=force_com_shifts[tilt_index],
+ )
+
+ (
+ self._amplitudes[
+ self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[
+ tilt_index + 1
+ ]
+ ],
+ mean_diffraction_intensity_temp,
+ ) = self._normalize_diffraction_intensities(
+ intensities,
+ com_fitted_x,
+ com_fitted_y,
+ crop_patterns,
+ self._positions_mask[tilt_index]
+ if self._positions_mask is not None
+ else None,
+ )
+
+ self._mean_diffraction_intensity.append(mean_diffraction_intensity_temp)
+
+ del (
+ intensities,
+ com_measured_x,
+ com_measured_y,
+ com_fitted_x,
+ com_fitted_y,
+ com_normalized_x,
+ com_normalized_y,
+ )
+
+ self._positions_px_all[
+ self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[
+ tilt_index + 1
+ ]
+ ] = self._calculate_scan_positions_in_pixels(
+ self._scan_positions[tilt_index],
+ self._positions_mask[tilt_index]
+ if self._positions_mask is not None
+ else None,
+ )
+
+ # handle semiangle specified in pixels
+ if self._semiangle_cutoff_pixels:
+ self._semiangle_cutoff = (
+ self._semiangle_cutoff_pixels * self._angular_sampling[0]
+ )
+
+ # Object Initialization
+ if self._object is None:
+ pad_x = self._object_padding_px[0][1]
+ pad_y = self._object_padding_px[1][1]
+ p, q = np.round(np.max(self._positions_px_all, axis=0))
+ p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype(
+ "int"
+ )
+ q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype(
+ "int"
+ )
+ self._object = xp.zeros((q, p, q), dtype=xp.float32)
+ else:
+ self._object = xp.asarray(self._object, dtype=xp.float32)
+
+ self._object_initial = self._object.copy()
+ self._object_type_initial = self._object_type
+ self._object_shape = self._object.shape[-2:]
+ self._num_voxels = self._object.shape[0]
+
+ # Center Probes
+ self._positions_px_all = xp.asarray(self._positions_px_all, dtype=xp.float32)
+
+ for tilt_index in range(self._num_tilts):
+ self._positions_px = self._positions_px_all[
+ self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[
+ tilt_index + 1
+ ]
+ ]
+ self._positions_px_com = xp.mean(self._positions_px, axis=0)
+ self._positions_px -= (
+ self._positions_px_com - xp.array(self._object_shape) / 2
+ )
+
+ self._positions_px_all[
+ self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[
+ tilt_index + 1
+ ]
+ ] = self._positions_px.copy()
+
+ self._positions_px_initial_all = self._positions_px_all.copy()
+ self._positions_initial_all = self._positions_px_initial_all.copy()
+ self._positions_initial_all[:, 0] *= self.sampling[0]
+ self._positions_initial_all[:, 1] *= self.sampling[1]
+
+ # Probe Initialization
+ if self._probe is None:
+ if self._vacuum_probe_intensity is not None:
+ self._semiangle_cutoff = np.inf
+ self._vacuum_probe_intensity = xp.asarray(
+ self._vacuum_probe_intensity, dtype=xp.float32
+ )
+ probe_x0, probe_y0 = get_CoM(
+ self._vacuum_probe_intensity, device=self._device
+ )
+ self._vacuum_probe_intensity = get_shifted_ar(
+ self._vacuum_probe_intensity,
+ -probe_x0,
+ -probe_y0,
+ bilinear=True,
+ device=self._device,
+ )
+ if crop_patterns:
+ self._vacuum_probe_intensity = self._vacuum_probe_intensity[
+ self._crop_mask
+ ].reshape(self._region_of_interest_shape)
+
+ self._probe = (
+ ComplexProbe(
+ gpts=self._region_of_interest_shape,
+ sampling=self.sampling,
+ energy=self._energy,
+ semiangle_cutoff=self._semiangle_cutoff,
+ rolloff=self._rolloff,
+ vacuum_probe_intensity=self._vacuum_probe_intensity,
+ parameters=self._polar_parameters,
+ device=self._device,
+ )
+ .build()
+ ._array
+ )
+
+ # Normalize probe to match mean diffraction intensity
+ probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2)
+ self._probe *= xp.sqrt(
+ sum(self._mean_diffraction_intensity)
+ / self._num_tilts
+ / probe_intensity
+ )
+
+ else:
+ if isinstance(self._probe, ComplexProbe):
+ if self._probe._gpts != self._region_of_interest_shape:
+ raise ValueError()
+ if hasattr(self._probe, "_array"):
+ self._probe = self._probe._array
+ else:
+ self._probe._xp = xp
+ self._probe = self._probe.build()._array
+
+ # Normalize probe to match mean diffraction intensity
+ probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2)
+ self._probe *= xp.sqrt(
+ sum(self._mean_diffraction_intensity)
+ / self._num_tilts
+ / probe_intensity
+ )
+ else:
+ self._probe = xp.asarray(self._probe, dtype=xp.complex64)
+
+ self._probe_initial = self._probe.copy()
+ self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe))
+
+ self._known_aberrations_array = ComplexProbe(
+ energy=self._energy,
+ gpts=self._region_of_interest_shape,
+ sampling=self.sampling,
+ parameters=self._polar_parameters,
+ device=self._device,
+ )._evaluate_ctf()
+
+ # Precomputed propagator arrays
+ self._slice_thicknesses = np.tile(
+ self._object_shape[1] * self.sampling[1] / self._num_slices,
+ self._num_slices - 1,
+ )
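+ # only (num_slices - 1) thicknesses are needed, since no propagation
+ # follows the final slice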
+ self._propagator_arrays = self._precompute_propagator_arrays(
+ self._region_of_interest_shape,
+ self.sampling,
+ self._energy,
+ self._slice_thicknesses,
+ )
+
+ # overlaps
+ if object_fov_mask is None:
+ probe_overlap_3D = xp.zeros_like(self._object)
+ old_rot_matrix = np.eye(3) # identity
+
+ for tilt_index in np.arange(self._num_tilts):
+ rot_matrix = self._tilt_orientation_matrices[tilt_index]
+
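+ # rotate incrementally: undo the previous tilt and apply the current
+ # one in a single resampling step to limit interpolation blur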
+ probe_overlap_3D = self._rotate_zxy_volume(
+ probe_overlap_3D,
+ rot_matrix @ old_rot_matrix.T,
+ )
+
+ self._positions_px = self._positions_px_all[
+ self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[
+ tilt_index + 1
+ ]
+ ]
+ self._positions_px_fractional = self._positions_px - xp.round(
+ self._positions_px
+ )
+ shifted_probes = fft_shift(
+ self._probe, self._positions_px_fractional, xp
+ )
+ probe_intensities = xp.abs(shifted_probes) ** 2
+ probe_overlap = self._sum_overlapping_patches_bincounts(
+ probe_intensities
+ )
+
+ probe_overlap_3D += probe_overlap[None]
+ old_rot_matrix = rot_matrix
+
+ probe_overlap_3D = self._rotate_zxy_volume(
+ probe_overlap_3D,
+ old_rot_matrix.T,
+ )
+
+ probe_overlap_3D = self._gaussian_filter(probe_overlap_3D, 1.0)
+ self._object_fov_mask = asnumpy(
+ probe_overlap_3D > 0.25 * probe_overlap_3D.max()
+ )
+ else:
+ self._object_fov_mask = np.asarray(object_fov_mask)
+ self._positions_px = self._positions_px_all[: self._cum_probes_per_tilt[1]]
+ self._positions_px_fractional = self._positions_px - xp.round(
+ self._positions_px
+ )
+ shifted_probes = fft_shift(self._probe, self._positions_px_fractional, xp)
+ probe_intensities = xp.abs(shifted_probes) ** 2
+ probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities)
+ probe_overlap = self._gaussian_filter(probe_overlap, 1.0)
+
+ self._object_fov_mask_inverse = np.invert(self._object_fov_mask)
+
+ if plot_probe_overlaps:
+ figsize = kwargs.pop("figsize", (13, 4))
+ chroma_boost = kwargs.pop("chroma_boost", 1)
+
+ # initial probe
+ complex_probe_rgb = Complex2RGB(
+ self.probe_centered,
+ power=2,
+ chroma_boost=chroma_boost,
+ )
+
+ # propagated
+ propagated_probe = self._probe.copy()
+
+ for s in range(self._num_slices - 1):
+ propagated_probe = self._propagate_array(
+ propagated_probe, self._propagator_arrays[s]
+ )
+ complex_propagated_rgb = Complex2RGB(
+ asnumpy(self._return_centered_probe(propagated_probe)),
+ power=2,
+ chroma_boost=chroma_boost,
+ )
+
+ extent = [
+ 0,
+ self.sampling[1] * self._object_shape[1],
+ self.sampling[0] * self._object_shape[0],
+ 0,
+ ]
+
+ probe_extent = [
+ 0,
+ self.sampling[1] * self._region_of_interest_shape[1],
+ self.sampling[0] * self._region_of_interest_shape[0],
+ 0,
+ ]
+
+ fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=figsize)
+
+ ax1.imshow(
+ complex_probe_rgb,
+ extent=probe_extent,
+ )
+
+ divider = make_axes_locatable(ax1)
+ cax1 = divider.append_axes("right", size="5%", pad="2.5%")
+ add_colorbar_arg(
+ cax1,
+ chroma_boost=chroma_boost,
+ )
+ ax1.set_ylabel("x [A]")
+ ax1.set_xlabel("y [A]")
+ ax1.set_title("Initial probe intensity")
+
+ ax2.imshow(
+ complex_propagated_rgb,
+ extent=probe_extent,
+ )
+
+ divider = make_axes_locatable(ax2)
+ cax2 = divider.append_axes("right", size="5%", pad="2.5%")
+ add_colorbar_arg(
+ cax2,
+ chroma_boost=chroma_boost,
+ )
+ ax2.set_ylabel("x [A]")
+ ax2.set_xlabel("y [A]")
+ ax2.set_title("Propagated probe intensity")
+
+ ax3.imshow(
+ asnumpy(probe_overlap),
+ extent=extent,
+ cmap="Greys_r",
+ )
+ ax3.scatter(
+ self.positions[0, :, 1],
+ self.positions[0, :, 0],
+ s=2.5,
+ color=(1, 0, 0, 1),
+ )
+ ax3.set_ylabel("x [A]")
+ ax3.set_xlabel("y [A]")
+ ax3.set_xlim((extent[0], extent[1]))
+ ax3.set_ylim((extent[2], extent[3]))
+ ax3.set_title("Object field of view")
+
+ fig.tight_layout()
+
+ self._preprocessed = True
+
+ if self._device == "gpu":
+ xp._default_memory_pool.free_all_blocks()
+ xp.clear_memo()
+
+ return self
+
+ def _overlap_projection(self, current_object, current_probe):
+ """
+ Ptychographic overlap projection method.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+
+ Returns
+ --------
+ propagated_probes: np.ndarray
+ Shifted probes at each layer
+ object_patches: np.ndarray
+ Patched object view
+ transmitted_probes: np.ndarray
+ Transmitted probes after N-1 propagations and N transmissions
+ """
+
+ xp = self._xp
+
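+ # potential object: the transmission function of each slice is exp(i * V)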
+ complex_object = xp.exp(1j * current_object)
+ object_patches = complex_object[
+ :, self._vectorized_patch_indices_row, self._vectorized_patch_indices_col
+ ]
+
+ propagated_probes = xp.empty_like(object_patches)
+ propagated_probes[0] = fft_shift(
+ current_probe, self._positions_px_fractional, xp
+ )
+
+ for s in range(self._num_slices):
+ # transmit
+ transmitted_probes = object_patches[s] * propagated_probes[s]
+
+ # propagate
+ if s + 1 < self._num_slices:
+ propagated_probes[s + 1] = self._propagate_array(
+ transmitted_probes, self._propagator_arrays[s]
+ )
+
+ return propagated_probes, object_patches, transmitted_probes
+
+ def _gradient_descent_fourier_projection(self, amplitudes, transmitted_probes):
+ """
+ Ptychographic Fourier projection method for the gradient-descent method.
+
+ Parameters
+ --------
+ amplitudes: np.ndarray
+ Normalized measured amplitudes
+ transmitted_probes: np.ndarray
+ Transmitted probes after N-1 propagations and N transmissions
+
+ Returns
+ --------
+ exit_waves:np.ndarray
+ Updated exit wave difference
+ error: float
+ Reconstruction error
+ """
+
+ xp = self._xp
+ fourier_exit_waves = xp.fft.fft2(transmitted_probes)
+
+ error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_exit_waves)) ** 2)
+
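+ # Fourier amplitude projection: keep the measured amplitudes,
+ # retain the estimated phases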
+ modified_exit_wave = xp.fft.ifft2(
+ amplitudes * xp.exp(1j * xp.angle(fourier_exit_waves))
+ )
+
+ exit_waves = modified_exit_wave - transmitted_probes
+
+ return exit_waves, error
+
+ def _projection_sets_fourier_projection(
+ self,
+ amplitudes,
+ transmitted_probes,
+ exit_waves,
+ projection_a,
+ projection_b,
+ projection_c,
+ ):
+ """
+ Ptychographic Fourier projection method for DM_AP and RAAR methods.
+ Generalized projection using three parameters: a,b,c
+
+ DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha
+ DM: DM_AP(1.0), AP: DM_AP(0.0)
+
+ RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2
+ DM : RAAR(1.0)
+
+ RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2
+ DM: RRR(1.0)
+
+ SUPERFLIP : a = 0, b = 1, c = 2
+
+ Parameters
+ --------
+ amplitudes: np.ndarray
+ Normalized measured amplitudes
+ transmitted_probes: np.ndarray
+ Transmitted probes after N-1 propagations and N transmissions
+ exit_waves: np.ndarray
+ previously estimated exit waves
+ projection_a: float
+ projection_b: float
+ projection_c: float
+
+ Returns
+ --------
+ exit_waves:np.ndarray
+ Updated exit wave difference
+ error: float
+ Reconstruction error
+ """
+
+ xp = self._xp
+ projection_x = 1 - projection_a - projection_b
+ projection_y = 1 - projection_c
+
+ if exit_waves is None:
+ exit_waves = transmitted_probes.copy()
+
+ fourier_exit_waves = xp.fft.fft2(transmitted_probes)
+ error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_exit_waves)) ** 2)
+
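+ # project the relaxed combination c * psi + (1 - c) * exit_waves
+ # onto the measured Fourier amplitudes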
+ factor_to_be_projected = (
+ projection_c * transmitted_probes + projection_y * exit_waves
+ )
+ fourier_projected_factor = xp.fft.fft2(factor_to_be_projected)
+
+ fourier_projected_factor = amplitudes * xp.exp(
+ 1j * xp.angle(fourier_projected_factor)
+ )
+ projected_factor = xp.fft.ifft2(fourier_projected_factor)
+
+ exit_waves = (
+ projection_x * exit_waves
+ + projection_a * transmitted_probes
+ + projection_b * projected_factor
+ )
+
+ return exit_waves, error
+
+ def _forward(
+ self,
+ current_object,
+ current_probe,
+ amplitudes,
+ exit_waves,
+ use_projection_scheme,
+ projection_a,
+ projection_b,
+ projection_c,
+ ):
+ """
+ Ptychographic forward operator.
+ Calls _overlap_projection() and the appropriate _fourier_projection().
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+ amplitudes: np.ndarray
+ Normalized measured amplitudes
+ exit_waves: np.ndarray
+ previously estimated exit waves
+ use_projection_scheme: bool,
+ If True, use generalized projection update
+ projection_a: float
+ projection_b: float
+ projection_c: float
+
+ Returns
+ --------
+ propagated_probes:np.ndarray
+ Prop[object^n*probe^n]
+ object_patches: np.ndarray
+ Patched object view
+ transmitted_probes: np.ndarray
+ Transmitted probes at each layer
+ exit_waves:np.ndarray
+ Updated exit_waves
+ error: float
+ Reconstruction error
+ """
+
+ (
+ propagated_probes,
+ object_patches,
+ transmitted_probes,
+ ) = self._overlap_projection(current_object, current_probe)
+
+ if use_projection_scheme:
+ (
+ exit_waves[self._active_tilt_index],
+ error,
+ ) = self._projection_sets_fourier_projection(
+ amplitudes,
+ transmitted_probes,
+ exit_waves[self._active_tilt_index],
+ projection_a,
+ projection_b,
+ projection_c,
+ )
+
+ else:
+ exit_waves, error = self._gradient_descent_fourier_projection(
+ amplitudes, transmitted_probes
+ )
+
+ return propagated_probes, object_patches, transmitted_probes, exit_waves, error
+
+ def _gradient_descent_adjoint(
+ self,
+ current_object,
+ current_probe,
+ object_patches,
+ propagated_probes,
+ exit_waves,
+ step_size,
+ normalization_min,
+ fix_probe,
+ ):
+ """
+ Ptychographic adjoint operator for GD method.
+ Computes object and probe update steps.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+ object_patches: np.ndarray
+ Patched object view
+ propagated_probes: np.ndarray
+ Shifted probes at each layer
+ exit_waves:np.ndarray
+ Updated exit_waves
+ step_size: float, optional
+ Update step size
+ normalization_min: float, optional
+ Probe normalization minimum as a fraction of the maximum overlap intensity
+ fix_probe: bool, optional
+ If True, probe will not be updated
+
+ Returns
+ --------
+ updated_object: np.ndarray
+ Updated object estimate
+ updated_probe: np.ndarray
+ Updated probe estimate
+ """
+ xp = self._xp
+
+ for s in reversed(range(self._num_slices)):
+ probe = propagated_probes[s]
+ obj = object_patches[s]
+
+ # object-update
+ probe_normalization = self._sum_overlapping_patches_bincounts(
+ xp.abs(probe) ** 2
+ )
+ probe_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * probe_normalization) ** 2
+ + (normalization_min * xp.max(probe_normalization)) ** 2
+ )
+
+ current_object[s] += step_size * (
+ self._sum_overlapping_patches_bincounts(
+ xp.real(-1j * xp.conj(obj) * xp.conj(probe) * exit_waves)
+ )
+ * probe_normalization
+ )
+
+ # back-transmit
+ exit_waves *= xp.conj(obj)
+
+ if s > 0:
+ # back-propagate
+ exit_waves = self._propagate_array(
+ exit_waves, xp.conj(self._propagator_arrays[s - 1])
+ )
+ elif not fix_probe:
+ # probe-update
+ object_normalization = xp.sum(
+ (xp.abs(obj) ** 2),
+ axis=0,
+ )
+ object_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * object_normalization) ** 2
+ + (normalization_min * xp.max(object_normalization)) ** 2
+ )
+
+ current_probe += (
+ step_size
+ * xp.sum(
+ exit_waves,
+ axis=0,
+ )
+ * object_normalization
+ )
+
+ return current_object, current_probe
+
+ def _projection_sets_adjoint(
+ self,
+ current_object,
+ current_probe,
+ object_patches,
+ propagated_probes,
+ exit_waves,
+ normalization_min,
+ fix_probe,
+ ):
+ """
+ Ptychographic adjoint operator for DM_AP and RAAR methods.
+ Computes object and probe update steps.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+ object_patches: np.ndarray
+ Patched object view
+ propagated_probes: np.ndarray
+ Shifted probes at each layer
+ exit_waves:np.ndarray
+ Updated exit_waves
+ normalization_min: float, optional
+ Probe normalization minimum as a fraction of the maximum overlap intensity
+ fix_probe: bool, optional
+ If True, probe will not be updated
+
+ Returns
+ --------
+ updated_object: np.ndarray
+ Updated object estimate
+ updated_probe: np.ndarray
+ Updated probe estimate
+ """
+ xp = self._xp
+
+ # careful not to modify exit_waves in-place for projection set methods
+ exit_waves_copy = exit_waves.copy()
+ for s in reversed(range(self._num_slices)):
+ probe = propagated_probes[s]
+ obj = object_patches[s]
+
+ # object-update
+ probe_normalization = self._sum_overlapping_patches_bincounts(
+ xp.abs(probe) ** 2
+ )
+ probe_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * probe_normalization) ** 2
+ + (normalization_min * xp.max(probe_normalization)) ** 2
+ )
+
+ current_object[s] = (
+ self._sum_overlapping_patches_bincounts(
+ xp.real(-1j * xp.conj(obj) * xp.conj(probe) * exit_waves_copy)
+ )
+ * probe_normalization
+ )
+
+ # back-transmit
+ exit_waves_copy *= xp.conj(obj)
+
+ if s > 0:
+ # back-propagate
+ exit_waves_copy = self._propagate_array(
+ exit_waves_copy, xp.conj(self._propagator_arrays[s - 1])
+ )
+
+ elif not fix_probe:
+ # probe-update
+ object_normalization = xp.sum(
+ (xp.abs(obj) ** 2),
+ axis=0,
+ )
+ object_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * object_normalization) ** 2
+ + (normalization_min * xp.max(object_normalization)) ** 2
+ )
+
+ current_probe = (
+ xp.sum(
+ exit_waves_copy,
+ axis=0,
+ )
+ * object_normalization
+ )
+
+ return current_object, current_probe
+
+ def _adjoint(
+ self,
+ current_object,
+ current_probe,
+ object_patches,
+ propagated_probes,
+ exit_waves,
+ use_projection_scheme: bool,
+ step_size: float,
+ normalization_min: float,
+ fix_probe: bool,
+ ):
+ """
+ Ptychographic adjoint operator, dispatching to the projection-set or
+ gradient-descent adjoint as appropriate.
+ Computes object and probe update steps.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+ object_patches: np.ndarray
+ Patched object view
+ propagated_probes: np.ndarray
+ Shifted probes at each layer
+ exit_waves:np.ndarray
+ Updated exit_waves
+ use_projection_scheme: bool,
+ If True, use generalized projection update
+ step_size: float, optional
+ Update step size
+ normalization_min: float, optional
+ Probe normalization minimum as a fraction of the maximum overlap intensity
+ fix_probe: bool, optional
+ If True, probe will not be updated
+
+ Returns
+ --------
+ updated_object: np.ndarray
+ Updated object estimate
+ updated_probe: np.ndarray
+ Updated probe estimate
+ """
+
+ if use_projection_scheme:
+ current_object, current_probe = self._projection_sets_adjoint(
+ current_object,
+ current_probe,
+ object_patches,
+ propagated_probes,
+ exit_waves[self._active_tilt_index],
+ normalization_min,
+ fix_probe,
+ )
+ else:
+ current_object, current_probe = self._gradient_descent_adjoint(
+ current_object,
+ current_probe,
+ object_patches,
+ propagated_probes,
+ exit_waves,
+ step_size,
+ normalization_min,
+ fix_probe,
+ )
+
+ return current_object, current_probe
+
+ def _position_correction(
+ self,
+ current_object,
+ current_probe,
+ transmitted_probes,
+ amplitudes,
+ current_positions,
+ positions_step_size,
+ constrain_position_distance,
+ ):
+ """
+ Position correction using estimated intensity gradient.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Fractionally-shifted probes
+ transmitted_probes: np.ndarray
+ Transmitted probes after N-1 propagations and N transmissions
+ amplitudes: np.ndarray
+ Measured amplitudes
+ current_positions: np.ndarray
+ Current positions estimate
+ positions_step_size: float
+ Positions step size
+ constrain_position_distance: float
+ Maximum distance (in A) corrected positions may stray outside the
+ original field of view
+
+ Returns
+ --------
+ updated_positions: np.ndarray
+ Updated positions estimate
+ """
+
+ xp = self._xp
+
+ # Intensity gradient
+ exit_waves_fft = xp.fft.fft2(transmitted_probes)
+ exit_waves_fft_conj = xp.conj(exit_waves_fft)
+ estimated_intensity = xp.abs(exit_waves_fft) ** 2
+ measured_intensity = amplitudes**2
+
+ flat_shape = (transmitted_probes.shape[0], -1)
+ difference_intensity = (measured_intensity - estimated_intensity).reshape(
+ flat_shape
+ )
+
+ # Computing perturbed exit waves one at a time to save on memory
+
+ complex_object = xp.exp(1j * current_object)
+
+ # dx
+ obj_rolled_patches = complex_object[
+ :,
+ (self._vectorized_patch_indices_row + 1) % self._object_shape[0],
+ self._vectorized_patch_indices_col,
+ ]
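+ # rolling the patch indices by one pixel samples the object shifted by
+ # one pixel; the resulting exit-wave difference approximates the
+ # intensity gradient with respect to probe position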
+
+ propagated_probes_perturbed = xp.empty_like(obj_rolled_patches)
+ propagated_probes_perturbed[0] = fft_shift(
+ current_probe, self._positions_px_fractional, xp
+ )
+
+ for s in range(self._num_slices):
+ # transmit
+ transmitted_probes_perturbed = (
+ obj_rolled_patches[s] * propagated_probes_perturbed[s]
+ )
+
+ # propagate
+ if s + 1 < self._num_slices:
+ propagated_probes_perturbed[s + 1] = self._propagate_array(
+ transmitted_probes_perturbed, self._propagator_arrays[s]
+ )
+
+ exit_waves_dx_fft = exit_waves_fft - xp.fft.fft2(transmitted_probes_perturbed)
+
+ # dy
+ obj_rolled_patches = complex_object[
+ :,
+ self._vectorized_patch_indices_row,
+ (self._vectorized_patch_indices_col + 1) % self._object_shape[1],
+ ]
+
+ propagated_probes_perturbed = xp.empty_like(obj_rolled_patches)
+ propagated_probes_perturbed[0] = fft_shift(
+ current_probe, self._positions_px_fractional, xp
+ )
+
+ for s in range(self._num_slices):
+ # transmit
+ transmitted_probes_perturbed = (
+ obj_rolled_patches[s] * propagated_probes_perturbed[s]
+ )
+
+ # propagate
+ if s + 1 < self._num_slices:
+ propagated_probes_perturbed[s + 1] = self._propagate_array(
+ transmitted_probes_perturbed, self._propagator_arrays[s]
+ )
+
+ exit_waves_dy_fft = exit_waves_fft - xp.fft.fft2(transmitted_probes_perturbed)
+
+ partial_intensity_dx = 2 * xp.real(
+ exit_waves_dx_fft * exit_waves_fft_conj
+ ).reshape(flat_shape)
+ partial_intensity_dy = 2 * xp.real(
+ exit_waves_dy_fft * exit_waves_fft_conj
+ ).reshape(flat_shape)
+
+ coefficients_matrix = xp.dstack((partial_intensity_dx, partial_intensity_dy))
+
+ # positions_update = xp.einsum(
+ # "idk,ik->id", xp.linalg.pinv(coefficients_matrix), difference_intensity
+ # )
+
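+ # equivalent explicit least-squares solve via the normal equations,
+ # (A^T A)^-1 A^T b, one 2x2 system per probe position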
+ coefficients_matrix_T = coefficients_matrix.conj().swapaxes(-1, -2)
+ positions_update = (
+ xp.linalg.inv(coefficients_matrix_T @ coefficients_matrix)
+ @ coefficients_matrix_T
+ @ difference_intensity[..., None]
+ )
+
+ if constrain_position_distance is not None:
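+ # convert the constraint distance from A to pixels (diagonal sampling)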
+ constrain_position_distance /= xp.sqrt(
+ self.sampling[0] ** 2 + self.sampling[1] ** 2
+ )
+ x1 = (current_positions - positions_step_size * positions_update[..., 0])[
+ :, 0
+ ]
+ y1 = (current_positions - positions_step_size * positions_update[..., 0])[
+ :, 1
+ ]
+ x0 = self._positions_px_initial[:, 0]
+ y0 = self._positions_px_initial[:, 1]
+ if self._rotation_best_transpose:
+ x0, y0 = xp.array([y0, x0])
+ x1, y1 = xp.array([y1, x1])
+
+ if self._rotation_best_rad is not None:
+ rotation_angle = self._rotation_best_rad
+ x0, y0 = x0 * xp.cos(-rotation_angle) + y0 * xp.sin(
+ -rotation_angle
+ ), -x0 * xp.sin(-rotation_angle) + y0 * xp.cos(-rotation_angle)
+ x1, y1 = x1 * xp.cos(-rotation_angle) + y1 * xp.sin(
+ -rotation_angle
+ ), -x1 * xp.sin(-rotation_angle) + y1 * xp.cos(-rotation_angle)
+
+ outlier_ind = (x1 > (xp.max(x0) + constrain_position_distance)) + (
+ x1 < (xp.min(x0) - constrain_position_distance)
+ ) + (y1 > (xp.max(y0) + constrain_position_distance)) + (
+ y1 < (xp.min(y0) - constrain_position_distance)
+ ) > 0
+
+ positions_update[..., 0][outlier_ind] = 0
+ current_positions -= positions_step_size * positions_update[..., 0]
+
+ return current_positions
+
+ def _object_gaussian_constraint(self, current_object, gaussian_filter_sigma):
+ """
+ Ptychographic smoothness constraint.
+ Used for blurring object.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ gaussian_filter_sigma: float
+ Standard deviation of gaussian kernel in A
+
+ Returns
+ --------
+ constrained_object: np.ndarray
+ Constrained object estimate
+ """
+ gaussian_filter = self._gaussian_filter
+
+ gaussian_filter_sigma /= self.sampling[0]
+ current_object = gaussian_filter(current_object, gaussian_filter_sigma)
+
+ return current_object
+
+ def _object_butterworth_constraint(
+ self, current_object, q_lowpass, q_highpass, butterworth_order
+ ):
+ """
+ Butterworth filter
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ q_lowpass: float
+ Cut-off frequency in A^-1 for low-pass butterworth filter
+ q_highpass: float
+ Cut-off frequency in A^-1 for high-pass butterworth filter
+ butterworth_order: float
+ Butterworth filter order. Smaller gives a smoother filter
+
+ Returns
+ --------
+ constrained_object: np.ndarray
+ Constrained object estimate
+ """
+ xp = self._xp
+ qz = xp.fft.fftfreq(current_object.shape[0], self.sampling[1])
+ qx = xp.fft.fftfreq(current_object.shape[1], self.sampling[0])
+ qy = xp.fft.fftfreq(current_object.shape[2], self.sampling[1])
+ qza, qxa, qya = xp.meshgrid(qz, qx, qy, indexing="ij")
+ qra = xp.sqrt(qza**2 + qxa**2 + qya**2)
+
+ env = xp.ones_like(qra)
+ if q_highpass:
+ env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * butterworth_order))
+ if q_lowpass:
+ env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order))
+
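+ # filter fluctuations about the mean only: subtract the mean,
+ # apply the filter, then add the mean back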
+ current_object_mean = xp.mean(current_object)
+ current_object -= current_object_mean
+ current_object = xp.fft.ifftn(xp.fft.fftn(current_object) * env)
+ current_object += current_object_mean
+ return xp.real(current_object)
+
+ def _object_denoise_tv_pylops(self, current_object, weights, iterations):
+ """
+ Performs second-order TV denoising along x and y, and first-order TV denoising along z
+
+ Parameters
+ ----------
+ current_object: np.ndarray
+ Current object estimate
+ weights : [float, float]
+ Denoising weights [z weight, r weight]. The larger the weights,
+ the stronger the denoising.
+ iterations: int
+ Number of iterations to run in the denoising algorithm.
+ `niter_out` in pylops
+
+ Returns
+ -------
+ constrained_object: np.ndarray
+ Constrained object estimate
+
+ """
+ xp = self._xp
+
+ if xp.iscomplexobj(current_object):
+ current_object_tv = current_object
+ warnings.warn(
+ ("TV denoising is currently only supported for potential objects."),
+ UserWarning,
+ )
+
+ else:
+ # zero pad at top and bottom slice
+ pad_width = ((1, 1), (0, 0), (0, 0))
+ current_object = xp.pad(
+ current_object, pad_width=pad_width, mode="constant"
+ )
+
+ # run tv denoising
+ nz, nx, ny = current_object.shape
+ niter_out = iterations
+ niter_in = 1
+ Iop = pylops.Identity(nx * ny * nz)
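+ # identity data-fidelity operator: split-Bregman denoises y = x + noise
+ # subject to the L1 regularizers assembled below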
+
+ if weights[0] == 0:
+ xy_laplacian = pylops.Laplacian(
+ (nz, nx, ny), axes=(1, 2), edge=False, kind="backward"
+ )
+ l1_regs = [xy_laplacian]
+
+ current_object_tv = pylops.optimization.sparsity.splitbregman(
+ Op=Iop,
+ y=current_object.ravel(),
+ RegsL1=l1_regs,
+ niter_outer=niter_out,
+ niter_inner=niter_in,
+ epsRL1s=[weights[1]],
+ tol=1e-4,
+ tau=1.0,
+ show=False,
+ )[0]
+
+ elif weights[1] == 0:
+ z_gradient = pylops.FirstDerivative(
+ (nz, nx, ny), axis=0, edge=False, kind="backward"
+ )
+ l1_regs = [z_gradient]
+
+ current_object_tv = pylops.optimization.sparsity.splitbregman(
+ Op=Iop,
+ y=current_object.ravel(),
+ RegsL1=l1_regs,
+ niter_outer=niter_out,
+ niter_inner=niter_in,
+ epsRL1s=[weights[0]],
+ tol=1e-4,
+ tau=1.0,
+ show=False,
+ )[0]
+
+ else:
+ z_gradient = pylops.FirstDerivative(
+ (nz, nx, ny), axis=0, edge=False, kind="backward"
+ )
+ xy_laplacian = pylops.Laplacian(
+ (nz, nx, ny), axes=(1, 2), edge=False, kind="backward"
+ )
+ l1_regs = [z_gradient, xy_laplacian]
+
+ current_object_tv = pylops.optimization.sparsity.splitbregman(
+ Op=Iop,
+ y=current_object.ravel(),
+ RegsL1=l1_regs,
+ niter_outer=niter_out,
+ niter_inner=niter_in,
+ epsRL1s=weights,
+ tol=1e-4,
+ tau=1.0,
+ show=False,
+ )[0]
+
+ # remove padding
+ current_object_tv = current_object_tv.reshape(current_object.shape)[1:-1]
+
+ return current_object_tv
+
+ def _constraints(
+ self,
+ current_object,
+ current_probe,
+ current_positions,
+ fix_com,
+ fit_probe_aberrations,
+ fit_probe_aberrations_max_angular_order,
+ fit_probe_aberrations_max_radial_order,
+ constrain_probe_amplitude,
+ constrain_probe_amplitude_relative_radius,
+ constrain_probe_amplitude_relative_width,
+ constrain_probe_fourier_amplitude,
+ constrain_probe_fourier_amplitude_max_width_pixels,
+ constrain_probe_fourier_amplitude_constant_intensity,
+ fix_probe_aperture,
+ initial_probe_aperture,
+ fix_positions,
+ global_affine_transformation,
+ gaussian_filter,
+ gaussian_filter_sigma,
+ butterworth_filter,
+ q_lowpass,
+ q_highpass,
+ butterworth_order,
+ object_positivity,
+ shrinkage_rad,
+ object_mask,
+ tv_denoise,
+ tv_denoise_weights,
+ tv_denoise_inner_iter,
+ ):
+ """
+ Ptychographic constraints operator.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+ current_positions: np.ndarray
+ Current positions estimate
+ fix_com: bool
+ If True, probe CoM is fixed to the center
+ fit_probe_aberrations: bool
+ If True, fits the probe aberrations to a low-order expansion
+ fit_probe_aberrations_max_angular_order: bool
+ Max angular order of probe aberrations basis functions
+ fit_probe_aberrations_max_radial_order: bool
+ Max radial order of probe aberrations basis functions
+ constrain_probe_amplitude: bool
+ If True, probe amplitude is constrained by top hat function
+ constrain_probe_amplitude_relative_radius: float
+ Relative location of top-hat inflection point, between 0 and 0.5
+ constrain_probe_amplitude_relative_width: float
+ Relative width of top-hat sigmoid, between 0 and 0.5
+ constrain_probe_fourier_amplitude: bool
+ If True, probe aperture is constrained by fitting a sigmoid for each angular frequency.
+ constrain_probe_fourier_amplitude_max_width_pixels: float
+ Maximum pixel width of fitted sigmoid functions.
+ constrain_probe_fourier_amplitude_constant_intensity: bool
+ If True, the probe aperture is additionally constrained to a constant intensity.
+ fix_probe_aperture: bool,
+ If True, probe Fourier amplitude is replaced by initial probe aperture.
+ initial_probe_aperture: np.ndarray,
+ Initial probe aperture to use in replacing probe Fourier amplitude.
+ fix_positions: bool
+ If True, positions are not updated
+ global_affine_transformation: bool
+ If True, positions are assumed to be a global affine transform from initial scan
+ gaussian_filter: bool
+ If True, applies real-space gaussian filter
+ gaussian_filter_sigma: float
+ Standard deviation of gaussian kernel in A
+ butterworth_filter: bool
+ If True, applies fourier-space butterworth filter
+ q_lowpass: float
+ Cut-off frequency in A^-1 for low-pass butterworth filter
+ q_highpass: float
+ Cut-off frequency in A^-1 for high-pass butterworth filter
+ butterworth_order: float
+ Butterworth filter order. Smaller gives a smoother filter
+ object_positivity: bool
+ If True, forces object to be positive
+ shrinkage_rad: float
+ Phase shift in radians to be subtracted from the potential at each iteration
+ object_mask: np.ndarray (boolean)
+ If not None, used to calculate additional shrinkage using masked-mean of object
+ tv_denoise: bool
+ If True, applies TV denoising on object
+ tv_denoise_weights: [float,float]
+ Denoising weights [z weight, r weight]. The larger the weights,
+ the stronger the denoising.
+ tv_denoise_inner_iter: int
+ Number of iterations to run in inner loop of TV denoising
+
+ Returns
+ --------
+ constrained_object: np.ndarray
+ Constrained object estimate
+ constrained_probe: np.ndarray
+ Constrained probe estimate
+ constrained_positions: np.ndarray
+ Constrained positions estimate
+ """
+
+ if gaussian_filter:
+ current_object = self._object_gaussian_constraint(
+ current_object, gaussian_filter_sigma
+ )
+
+ if butterworth_filter:
+ current_object = self._object_butterworth_constraint(
+ current_object,
+ q_lowpass,
+ q_highpass,
+ butterworth_order,
+ )
+ if tv_denoise:
+ current_object = self._object_denoise_tv_pylops(
+ current_object,
+ tv_denoise_weights,
+ tv_denoise_inner_iter,
+ )
+
+ if shrinkage_rad > 0.0 or object_mask is not None:
+ current_object = self._object_shrinkage_constraint(
+ current_object,
+ shrinkage_rad,
+ object_mask,
+ )
+
+ if object_positivity:
+ current_object = self._object_positivity_constraint(current_object)
+
+ if fix_com:
+ current_probe = self._probe_center_of_mass_constraint(current_probe)
+
+ if fix_probe_aperture:
+ current_probe = self._probe_aperture_constraint(
+ current_probe,
+ initial_probe_aperture,
+ )
+ elif constrain_probe_fourier_amplitude:
+ current_probe = self._probe_fourier_amplitude_constraint(
+ current_probe,
+ constrain_probe_fourier_amplitude_max_width_pixels,
+ constrain_probe_fourier_amplitude_constant_intensity,
+ )
+
+ if fit_probe_aberrations:
+ current_probe = self._probe_aberration_fitting_constraint(
+ current_probe,
+ fit_probe_aberrations_max_angular_order,
+ fit_probe_aberrations_max_radial_order,
+ )
+
+ if constrain_probe_amplitude:
+ current_probe = self._probe_amplitude_constraint(
+ current_probe,
+ constrain_probe_amplitude_relative_radius,
+ constrain_probe_amplitude_relative_width,
+ )
+
+ if not fix_positions:
+ current_positions = self._positions_center_of_mass_constraint(
+ current_positions
+ )
+
+ if global_affine_transformation:
+ current_positions = self._positions_affine_transformation_constraint(
+ self._positions_px_initial, current_positions
+ )
+
+ return current_object, current_probe, current_positions
+
+ def reconstruct(
+ self,
+ max_iter: int = 64,
+ reconstruction_method: str = "gradient-descent",
+ reconstruction_parameter: float = 1.0,
+ reconstruction_parameter_a: float = None,
+ reconstruction_parameter_b: float = None,
+ reconstruction_parameter_c: float = None,
+ max_batch_size: int = None,
+ seed_random: int = None,
+ step_size: float = 0.5,
+ normalization_min: float = 1,
+ positions_step_size: float = 0.9,
+ fix_com: bool = True,
+ fix_probe_iter: int = 0,
+ fix_probe_aperture_iter: int = 0,
+ constrain_probe_amplitude_iter: int = 0,
+ constrain_probe_amplitude_relative_radius: float = 0.5,
+ constrain_probe_amplitude_relative_width: float = 0.05,
+ constrain_probe_fourier_amplitude_iter: int = 0,
+ constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0,
+ constrain_probe_fourier_amplitude_constant_intensity: bool = False,
+ fix_positions_iter: int = np.inf,
+ constrain_position_distance: float = None,
+ global_affine_transformation: bool = True,
+ gaussian_filter_sigma: float = None,
+ gaussian_filter_iter: int = np.inf,
+ fit_probe_aberrations_iter: int = 0,
+ fit_probe_aberrations_max_angular_order: int = 4,
+ fit_probe_aberrations_max_radial_order: int = 4,
+ butterworth_filter_iter: int = np.inf,
+ q_lowpass: float = None,
+ q_highpass: float = None,
+ butterworth_order: float = 2,
+ object_positivity: bool = True,
+ shrinkage_rad: float = 0.0,
+ fix_potential_baseline: bool = True,
+ tv_denoise_iter=np.inf,
+ tv_denoise_weights=None,
+ tv_denoise_inner_iter=40,
+ collective_tilt_updates: bool = False,
+ store_iterations: bool = False,
+ progress_bar: bool = True,
+ reset: bool = None,
+ ):
+ """
+ Ptychographic reconstruction main method.
+
+ Parameters
+ --------
+ max_iter: int, optional
+ Maximum number of iterations to run
+ reconstruction_method: str, optional
+ Specifies which reconstruction algorithm to use, one of:
+ "generalized-projections",
+ "DM_AP" (or "difference-map_alternating-projections"),
+ "RAAR" (or "relaxed-averaged-alternating-reflections"),
+ "RRR" (or "relax-reflect-reflect"),
+ "SUPERFLIP" (or "charge-flipping"), or
+ "GD" (or "gradient_descent")
+ reconstruction_parameter: float, optional
+ Reconstruction parameter for various reconstruction methods above.
+ reconstruction_parameter_a: float, optional
+ Reconstruction parameter a for reconstruction_method='generalized-projections'.
+ reconstruction_parameter_b: float, optional
+ Reconstruction parameter b for reconstruction_method='generalized-projections'.
+ reconstruction_parameter_c: float, optional
+ Reconstruction parameter c for reconstruction_method='generalized-projections'.
+ max_batch_size: int, optional
+ Max number of probes to update at once
+ seed_random: int, optional
+ Seeds the random number generator, only applicable when max_batch_size is not None
+ step_size: float, optional
+ Update step size
+ normalization_min: float, optional
+ Probe normalization minimum as a fraction of the maximum overlap intensity
+ positions_step_size: float, optional
+ Positions update step size
+ fix_com: bool, optional
+ If True, fixes center of mass of probe
+ fix_probe_iter: int, optional
+ Number of iterations to run with a fixed probe before updating probe estimate
+ fix_probe_aperture_iter: int, optional
+ Number of iterations to run with a fixed probe Fourier amplitude before updating probe estimate
+ constrain_probe_amplitude_iter: int, optional
+ Number of iterations to run while constraining the real-space probe with a top-hat support.
+ constrain_probe_amplitude_relative_radius: float
+ Relative location of top-hat inflection point, between 0 and 0.5
+ constrain_probe_amplitude_relative_width: float
+ Relative width of top-hat sigmoid, between 0 and 0.5
+ constrain_probe_fourier_amplitude_iter: int, optional
+ Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency.
+ constrain_probe_fourier_amplitude_max_width_pixels: float
+ Maximum pixel width of fitted sigmoid functions.
+ constrain_probe_fourier_amplitude_constant_intensity: bool
+ If True, the probe aperture is additionally constrained to a constant intensity.
+ fix_positions_iter: int, optional
+ Number of iterations to run with fixed positions before updating positions estimate
+ constrain_position_distance: float, optional
+ Maximum distance (in A) corrected positions may stray outside the
+ original field of view
+ global_affine_transformation: bool, optional
+ If True, positions are assumed to be a global affine transform from initial scan
+ gaussian_filter_sigma: float, optional
+ Standard deviation of gaussian kernel in A
+ gaussian_filter_iter: int, optional
+ Number of iterations to run using object smoothness constraint
+ fit_probe_aberrations_iter: int, optional
+ Number of iterations to run while fitting the probe aberrations to a low-order expansion
+ fit_probe_aberrations_max_angular_order: bool
+ Max angular order of probe aberrations basis functions
+ fit_probe_aberrations_max_radial_order: bool
+ Max radial order of probe aberrations basis functions
+ butterworth_filter_iter: int, optional
+ Number of iterations to run using high-pass butterworth filter
+ q_lowpass: float
+ Cut-off frequency in A^-1 for low-pass butterworth filter
+ q_highpass: float
+ Cut-off frequency in A^-1 for high-pass butterworth filter
+ butterworth_order: float
+ Butterworth filter order. Smaller gives a smoother filter
+ object_positivity: bool, optional
+ If True, forces object to be positive
+ tv_denoise_iter: int, optional
+ Number of iterations to run using TV denoising on object
+ tv_denoise_weights: [float,float]
+ Denoising weights [z weight, r weight]. The larger the weights,
+ the stronger the denoising.
+ tv_denoise_inner_iter: int
+ Number of iterations to run in inner loop of TV denoising
+ collective_tilt_updates: bool
+ If True, perform collective tilt updates
+ shrinkage_rad: float
+ Phase shift in radians to be subtracted from the potential at each iteration
+ store_iterations: bool, optional
+ If True, reconstructed objects and probes are stored at each iteration
+ progress_bar: bool, optional
+ If True, reconstruction progress is displayed
+ reset: bool, optional
+ If True, previous reconstructions are ignored
+
+ Returns
+ --------
+ self: OverlapTomographicReconstruction
+ Self to accommodate chaining
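+
+ Examples
+ --------
+ A typical gradient-descent call (iteration counts, step size, and batch
+ size below are illustrative, not recommended defaults):
+
+ >>> tomo = tomo.reconstruct(
+ ... max_iter=32,
+ ... reconstruction_method="GD",
+ ... step_size=0.5,
+ ... max_batch_size=256,
+ ... store_iterations=True,
+ ... )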
+ """
+ asnumpy = self._asnumpy
+ xp = self._xp
+
+ # Reconstruction method
+
+ if reconstruction_method == "generalized-projections":
+ if (
+ reconstruction_parameter_a is None
+ or reconstruction_parameter_b is None
+ or reconstruction_parameter_c is None
+ ):
+ raise ValueError(
+ (
+ "reconstruction_parameter_a/b/c must all be specified "
+ "when using reconstruction_method='generalized-projections'."
+ )
+ )
+
+ use_projection_scheme = True
+ projection_a = reconstruction_parameter_a
+ projection_b = reconstruction_parameter_b
+ projection_c = reconstruction_parameter_c
+ step_size = None
+ elif (
+ reconstruction_method == "DM_AP"
+ or reconstruction_method == "difference-map_alternating-projections"
+ ):
+ if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0:
+ raise ValueError("reconstruction_parameter must be between 0-1.")
+
+ use_projection_scheme = True
+ projection_a = -reconstruction_parameter
+ projection_b = 1
+ projection_c = 1 + reconstruction_parameter
+ step_size = None
+ elif (
+ reconstruction_method == "RAAR"
+ or reconstruction_method == "relaxed-averaged-alternating-reflections"
+ ):
+ if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0:
+ raise ValueError("reconstruction_parameter must be between 0-1.")
+
+ use_projection_scheme = True
+ projection_a = 1 - 2 * reconstruction_parameter
+ projection_b = reconstruction_parameter
+ projection_c = 2
+ step_size = None
+ elif (
+ reconstruction_method == "RRR"
+ or reconstruction_method == "relax-reflect-reflect"
+ ):
+ if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0:
+ raise ValueError("reconstruction_parameter must be between 0-2.")
+
+ use_projection_scheme = True
+ projection_a = -reconstruction_parameter
+ projection_b = reconstruction_parameter
+ projection_c = 2
+ step_size = None
+ elif (
+ reconstruction_method == "SUPERFLIP"
+ or reconstruction_method == "charge-flipping"
+ ):
+ use_projection_scheme = True
+ projection_a = 0
+ projection_b = 1
+ projection_c = 2
+ reconstruction_parameter = None
+ step_size = None
+ elif (
+ reconstruction_method == "GD" or reconstruction_method == "gradient-descent"
+ ):
+ use_projection_scheme = False
+ projection_a = None
+ projection_b = None
+ projection_c = None
+ reconstruction_parameter = None
+ else:
+ raise ValueError(
+ (
+ "reconstruction_method must be one of 'generalized-projections', "
+ "'DM_AP' (or 'difference-map_alternating-projections'), "
+ "'RAAR' (or 'relaxed-averaged-alternating-reflections'), "
+ "'RRR' (or 'relax-reflect-reflect'), "
+ "'SUPERFLIP' (or 'charge-flipping'), "
+ f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}."
+ )
+ )
+
+ # batched updates are inconsistent with projection-set methods,
+ # so raise regardless of verbosity
+ if max_batch_size is not None and use_projection_scheme:
+ raise ValueError(
+ (
+ "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. "
+ "Use reconstruction_method='GD' or set max_batch_size=None."
+ )
+ )
+
+ if self._verbose:
+ if max_batch_size is not None:
+ print(
+ (
+ f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, "
+ f"with normalization_min: {normalization_min} and step_size: {step_size}, "
+ f"in batches of max {max_batch_size} measurements."
+ )
+ )
+ else:
+ if reconstruction_parameter is not None:
+ if np.array(reconstruction_parameter).shape == (3,):
+ print(
+ (
+ f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, "
+ f"with normalization_min: {normalization_min} and (a,b,c): {reconstruction_parameter}."
+ )
+ )
+ else:
+ print(
+ (
+ f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, "
+ f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}."
+ )
+ )
+ else:
+ if step_size is not None:
+ print(
+ (
+ f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, "
+ f"with normalization_min: {normalization_min}."
+ )
+ )
+ else:
+ print(
+ (
+ f"Performing {max_iter} iterations using the {reconstruction_method} algorithm, "
+ f"with normalization_min: {normalization_min} and step _size: {step_size}."
+ )
+ )
+
+        # Position Correction + Collective Updates not yet implemented
+        if collective_tilt_updates and fix_positions_iter < max_iter:
+            raise NotImplementedError(
+                "Position correction is currently incompatible with collective updates."
+            )
+
+ # Batching
+
+ if max_batch_size is not None:
+ xp.random.seed(seed_random)
+ else:
+ max_batch_size = self._num_diffraction_patterns
+
+ # initialization
+ if store_iterations and (not hasattr(self, "object_iterations") or reset):
+ self.object_iterations = []
+ self.probe_iterations = []
+
+ if reset:
+ self._object = self._object_initial.copy()
+ self.error_iterations = []
+ self._probe = self._probe_initial.copy()
+ self._positions_px_all = self._positions_px_initial_all.copy()
+ if hasattr(self, "_tf"):
+ del self._tf
+
+ if use_projection_scheme:
+ self._exit_waves = [None] * self._num_tilts
+ else:
+ self._exit_waves = None
+ elif reset is None:
+ if hasattr(self, "error"):
+ warnings.warn(
+ (
+ "Continuing reconstruction from previous result. "
+ "Use reset=True for a fresh start."
+ ),
+ UserWarning,
+ )
+ else:
+ self.error_iterations = []
+ if use_projection_scheme:
+ self._exit_waves = [None] * self._num_tilts
+ else:
+ self._exit_waves = None
+
+ # main loop
+ for a0 in tqdmnd(
+ max_iter,
+ desc="Reconstructing object and probe",
+ unit=" iter",
+ disable=not progress_bar,
+ ):
+ error = 0.0
+
+ if collective_tilt_updates:
+ collective_object = xp.zeros_like(self._object)
+
+ tilt_indices = np.arange(self._num_tilts)
+ np.random.shuffle(tilt_indices)
+
+ old_rot_matrix = np.eye(3) # identity
+
+ for tilt_index in tilt_indices:
+ self._active_tilt_index = tilt_index
+
+ tilt_error = 0.0
+
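+                # Rotate the volume from the previous tilt frame into the current
+                # one in a single zxy resampling, by composing the new rotation
+                # with the transpose (inverse) of the previous one.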
+ rot_matrix = self._tilt_orientation_matrices[self._active_tilt_index]
+ self._object = self._rotate_zxy_volume(
+ self._object,
+ rot_matrix @ old_rot_matrix.T,
+ )
+
+ object_sliced = self._project_sliced_object(
+ self._object, self._num_slices
+ )
+ if not use_projection_scheme:
+ object_sliced_old = object_sliced.copy()
+
+ start_tilt = self._cum_probes_per_tilt[self._active_tilt_index]
+ end_tilt = self._cum_probes_per_tilt[self._active_tilt_index + 1]
+
+ num_diffraction_patterns = end_tilt - start_tilt
+ shuffled_indices = np.arange(num_diffraction_patterns)
+ unshuffled_indices = np.zeros_like(shuffled_indices)
+
+ # randomize
+ if not use_projection_scheme:
+ np.random.shuffle(shuffled_indices)
+
+ unshuffled_indices[shuffled_indices] = np.arange(
+ num_diffraction_patterns
+ )
+
+ positions_px = self._positions_px_all[start_tilt:end_tilt].copy()[
+ shuffled_indices
+ ]
+ initial_positions_px = self._positions_px_initial_all[
+ start_tilt:end_tilt
+ ].copy()[shuffled_indices]
+
+ for start, end in generate_batches(
+ num_diffraction_patterns, max_batch=max_batch_size
+ ):
+ # batch indices
+ self._positions_px = positions_px[start:end]
+ self._positions_px_initial = initial_positions_px[start:end]
+ self._positions_px_com = xp.mean(self._positions_px, axis=0)
+ self._positions_px_fractional = self._positions_px - xp.round(
+ self._positions_px
+ )
+
+ (
+ self._vectorized_patch_indices_row,
+ self._vectorized_patch_indices_col,
+ ) = self._extract_vectorized_patch_indices()
+
+ amplitudes = self._amplitudes[start_tilt:end_tilt][
+ shuffled_indices[start:end]
+ ]
+
+ # forward operator
+ (
+ propagated_probes,
+ object_patches,
+ transmitted_probes,
+ self._exit_waves,
+ batch_error,
+ ) = self._forward(
+ object_sliced,
+ self._probe,
+ amplitudes,
+ self._exit_waves,
+ use_projection_scheme,
+ projection_a,
+ projection_b,
+ projection_c,
+ )
+
+ # adjoint operator
+ object_sliced, self._probe = self._adjoint(
+ object_sliced,
+ self._probe,
+ object_patches,
+ propagated_probes,
+ self._exit_waves,
+ use_projection_scheme=use_projection_scheme,
+ step_size=step_size,
+ normalization_min=normalization_min,
+ fix_probe=a0 < fix_probe_iter,
+ )
+
+ # position correction
+ if a0 >= fix_positions_iter:
+ positions_px[start:end] = self._position_correction(
+ object_sliced,
+ self._probe,
+ transmitted_probes,
+ amplitudes,
+ self._positions_px,
+ positions_step_size,
+ constrain_position_distance,
+ )
+
+ tilt_error += batch_error
+
+ if not use_projection_scheme:
+ object_sliced -= object_sliced_old
+
+ object_update = self._expand_sliced_object(
+ object_sliced, self._num_voxels
+ )
+
+ if collective_tilt_updates:
+ collective_object += self._rotate_zxy_volume(
+ object_update, rot_matrix.T
+ )
+ else:
+ self._object += object_update
+
+ old_rot_matrix = rot_matrix
+
+ # Normalize Error
+ tilt_error /= (
+ self._mean_diffraction_intensity[self._active_tilt_index]
+ * num_diffraction_patterns
+ )
+ error += tilt_error
+
+ # constraints
+ self._positions_px_all[start_tilt:end_tilt] = positions_px.copy()[
+ unshuffled_indices
+ ]
+
+ if not collective_tilt_updates:
+ (
+ self._object,
+ self._probe,
+ self._positions_px_all[start_tilt:end_tilt],
+ ) = self._constraints(
+ self._object,
+ self._probe,
+ self._positions_px_all[start_tilt:end_tilt],
+ fix_com=fix_com and a0 >= fix_probe_iter,
+ constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter
+ and a0 >= fix_probe_iter,
+ constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius,
+ constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width,
+ constrain_probe_fourier_amplitude=a0
+ < constrain_probe_fourier_amplitude_iter
+ and a0 >= fix_probe_iter,
+ constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels,
+ constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity,
+ fit_probe_aberrations=a0 < fit_probe_aberrations_iter
+ and a0 >= fix_probe_iter,
+ fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order,
+ fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order,
+ fix_probe_aperture=a0 < fix_probe_aperture_iter,
+ initial_probe_aperture=self._probe_initial_aperture,
+ fix_positions=a0 < fix_positions_iter,
+ global_affine_transformation=global_affine_transformation,
+ gaussian_filter=a0 < gaussian_filter_iter
+ and gaussian_filter_sigma is not None,
+ gaussian_filter_sigma=gaussian_filter_sigma,
+ butterworth_filter=a0 < butterworth_filter_iter
+ and (q_lowpass is not None or q_highpass is not None),
+ q_lowpass=q_lowpass,
+ q_highpass=q_highpass,
+ butterworth_order=butterworth_order,
+ object_positivity=object_positivity,
+ shrinkage_rad=shrinkage_rad,
+ object_mask=self._object_fov_mask_inverse
+ if fix_potential_baseline
+ and self._object_fov_mask_inverse.sum() > 0
+ else None,
+ tv_denoise=a0 < tv_denoise_iter
+ and tv_denoise_weights is not None,
+ tv_denoise_weights=tv_denoise_weights,
+ tv_denoise_inner_iter=tv_denoise_inner_iter,
+ )
+
+ self._object = self._rotate_zxy_volume(self._object, old_rot_matrix.T)
+
+ # Normalize Error Over Tilts
+ error /= self._num_tilts
+
+ if collective_tilt_updates:
+ self._object += collective_object / self._num_tilts
+
+ (
+ self._object,
+ self._probe,
+ _,
+ ) = self._constraints(
+ self._object,
+ self._probe,
+ None,
+ fix_com=fix_com and a0 >= fix_probe_iter,
+ constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter
+ and a0 >= fix_probe_iter,
+ constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius,
+ constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width,
+ constrain_probe_fourier_amplitude=a0
+ < constrain_probe_fourier_amplitude_iter
+ and a0 >= fix_probe_iter,
+ constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels,
+ constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity,
+ fit_probe_aberrations=a0 < fit_probe_aberrations_iter
+ and a0 >= fix_probe_iter,
+ fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order,
+ fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order,
+ fix_probe_aperture=a0 < fix_probe_aperture_iter,
+ initial_probe_aperture=self._probe_initial_aperture,
+ fix_positions=True,
+ global_affine_transformation=global_affine_transformation,
+ gaussian_filter=a0 < gaussian_filter_iter
+ and gaussian_filter_sigma is not None,
+ gaussian_filter_sigma=gaussian_filter_sigma,
+ butterworth_filter=a0 < butterworth_filter_iter
+ and (q_lowpass is not None or q_highpass is not None),
+ q_lowpass=q_lowpass,
+ q_highpass=q_highpass,
+ butterworth_order=butterworth_order,
+ object_positivity=object_positivity,
+ shrinkage_rad=shrinkage_rad,
+ object_mask=self._object_fov_mask_inverse
+ if fix_potential_baseline
+ and self._object_fov_mask_inverse.sum() > 0
+ else None,
+ tv_denoise=a0 < tv_denoise_iter and tv_denoise_weights is not None,
+ tv_denoise_weights=tv_denoise_weights,
+ tv_denoise_inner_iter=tv_denoise_inner_iter,
+ )
+
+ self.error_iterations.append(error.item())
+ if store_iterations:
+ self.object_iterations.append(asnumpy(self._object.copy()))
+ self.probe_iterations.append(self.probe_centered)
+
+ # store result
+ self.object = asnumpy(self._object)
+ self.probe = self.probe_centered
+ self.error = error.item()
+
+ if self._device == "gpu":
+ xp._default_memory_pool.free_all_blocks()
+ xp.clear_memo()
+
+ return self
+
+ def _crop_rotate_object_manually(
+ self,
+ array,
+ angle,
+ x_lims,
+ y_lims,
+ ):
+ """
+        Crops and rotates the object manually.
+
+ Parameters
+ ----------
+ array: np.ndarray
+            Object array to crop and rotate. Only operates on numpy arrays for compatibility.
+ angle: float
+ In-plane angle in degrees to rotate by
+ x_lims: tuple(float,float)
+ min/max x indices
+ y_lims: tuple(float,float)
+ min/max y indices
+
+ Returns
+ -------
+ cropped_rotated_array: np.ndarray
+ Cropped and rotated object array
+ """
+
+ asnumpy = self._asnumpy
+ min_x, max_x = x_lims
+ min_y, max_y = y_lims
+
+ if angle is not None:
+ rotated_array = rotate_np(
+ asnumpy(array), angle, reshape=False, axes=(-2, -1)
+ )
+ else:
+ rotated_array = asnumpy(array)
+
+ return rotated_array[..., min_x:max_x, min_y:max_y]
+
+ def _visualize_last_iteration_figax(
+ self,
+ fig,
+ object_ax,
+ convergence_ax,
+ cbar: bool,
+ projection_angle_deg: float,
+ projection_axes: Tuple[int, int],
+ x_lims: Tuple[int, int],
+ y_lims: Tuple[int, int],
+ **kwargs,
+ ):
+ """
+ Displays last reconstructed object on a given fig/ax.
+
+ Parameters
+ --------
+ fig: Figure
+ Matplotlib figure object_ax lives in
+ object_ax: Axes
+ Matplotlib axes to plot reconstructed object in
+ convergence_ax: Axes, optional
+ Matplotlib axes to plot convergence plot in
+ cbar: bool, optional
+ If true, displays a colorbar
+ projection_angle_deg: float
+ Angle in degrees to rotate 3D array around prior to projection
+ projection_axes: tuple(int,int)
+ Axes defining projection plane
+ x_lims: tuple(float,float)
+ min/max x indices
+ y_lims: tuple(float,float)
+ min/max y indices
+ """
+
+ cmap = kwargs.pop("cmap", "magma")
+
+ asnumpy = self._asnumpy
+
+ if projection_angle_deg is not None:
+ rotated_3d_obj = self._rotate(
+ self._object,
+ projection_angle_deg,
+ axes=projection_axes,
+ reshape=False,
+ order=2,
+ )
+ rotated_3d_obj = asnumpy(rotated_3d_obj)
+ else:
+ rotated_3d_obj = self.object
+
+ rotated_object = self._crop_rotate_object_manually(
+ rotated_3d_obj.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims
+ )
+ rotated_shape = rotated_object.shape
+
+ extent = [
+ 0,
+ self.sampling[1] * rotated_shape[1],
+ self.sampling[0] * rotated_shape[0],
+ 0,
+ ]
+
+ im = object_ax.imshow(
+ rotated_object,
+ extent=extent,
+ cmap=cmap,
+ **kwargs,
+ )
+
+ if cbar:
+ divider = make_axes_locatable(object_ax)
+ ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
+ fig.add_axes(ax_cb)
+ fig.colorbar(im, cax=ax_cb)
+
+        if convergence_ax is not None and hasattr(self, "error_iterations"):
+            kwargs.pop("vmin", None)
+            kwargs.pop("vmax", None)
+            errors = np.array(self.error_iterations)
+            convergence_ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs)
+
+ def _visualize_last_iteration(
+ self,
+ fig,
+ cbar: bool,
+ plot_convergence: bool,
+ plot_probe: bool,
+ plot_fourier_probe: bool,
+ projection_angle_deg: float,
+ projection_axes: Tuple[int, int],
+ x_lims: Tuple[int, int],
+ y_lims: Tuple[int, int],
+ **kwargs,
+ ):
+ """
+ Displays last reconstructed object and probe iterations.
+
+ Parameters
+ --------
+ fig: Figure
+ Matplotlib figure to place Gridspec in
+ plot_convergence: bool, optional
+ If true, the normalized mean squared error (NMSE) plot is displayed
+ cbar: bool, optional
+ If true, displays a colorbar
+ plot_probe: bool
+ If true, the reconstructed probe intensity is also displayed
+ plot_fourier_probe: bool, optional
+ If true, the reconstructed complex Fourier probe is displayed
+ projection_angle_deg: float
+ Angle in degrees to rotate 3D array around prior to projection
+ projection_axes: tuple(int,int)
+ Axes defining projection plane
+ x_lims: tuple(float,float)
+ min/max x indices
+ y_lims: tuple(float,float)
+ min/max y indices
+ """
+ figsize = kwargs.pop("figsize", (8, 5))
+ cmap = kwargs.pop("cmap", "magma")
+
+ if plot_fourier_probe:
+ chroma_boost = kwargs.pop("chroma_boost", 2)
+ else:
+ chroma_boost = kwargs.pop("chroma_boost", 1)
+
+ asnumpy = self._asnumpy
+
+ if projection_angle_deg is not None:
+ rotated_3d_obj = self._rotate(
+ self._object,
+ projection_angle_deg,
+ axes=projection_axes,
+ reshape=False,
+ order=2,
+ )
+ rotated_3d_obj = asnumpy(rotated_3d_obj)
+ else:
+ rotated_3d_obj = self.object
+
+ rotated_object = self._crop_rotate_object_manually(
+ rotated_3d_obj.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims
+ )
+ rotated_shape = rotated_object.shape
+
+ extent = [
+ 0,
+ self.sampling[1] * rotated_shape[1],
+ self.sampling[0] * rotated_shape[0],
+ 0,
+ ]
+
+ if plot_fourier_probe:
+ probe_extent = [
+ -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2,
+ self.angular_sampling[1] * self._region_of_interest_shape[1] / 2,
+ self.angular_sampling[0] * self._region_of_interest_shape[0] / 2,
+ -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2,
+ ]
+ elif plot_probe:
+ probe_extent = [
+ 0,
+ self.sampling[1] * self._region_of_interest_shape[1],
+ self.sampling[0] * self._region_of_interest_shape[0],
+ 0,
+ ]
+
+ if plot_convergence:
+ if plot_probe or plot_fourier_probe:
+ spec = GridSpec(
+ ncols=2,
+ nrows=2,
+ height_ratios=[4, 1],
+ hspace=0.15,
+ width_ratios=[
+ (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]),
+ 1,
+ ],
+ wspace=0.35,
+ )
+ else:
+ spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0.15)
+ else:
+ if plot_probe or plot_fourier_probe:
+ spec = GridSpec(
+ ncols=2,
+ nrows=1,
+ width_ratios=[
+ (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]),
+ 1,
+ ],
+ wspace=0.35,
+ )
+ else:
+ spec = GridSpec(ncols=1, nrows=1)
+
+ if fig is None:
+ fig = plt.figure(figsize=figsize)
+
+ if plot_probe or plot_fourier_probe:
+ # Object
+ ax = fig.add_subplot(spec[0, 0])
+ im = ax.imshow(
+ rotated_object,
+ extent=extent,
+ cmap=cmap,
+ **kwargs,
+ )
+
+ ax.set_ylabel("x [A]")
+ ax.set_xlabel("y [A]")
+ ax.set_title("Reconstructed object projection")
+
+ if cbar:
+ divider = make_axes_locatable(ax)
+ ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
+ fig.add_axes(ax_cb)
+ fig.colorbar(im, cax=ax_cb)
+
+ # Probe
+ kwargs.pop("vmin", None)
+ kwargs.pop("vmax", None)
+
+ ax = fig.add_subplot(spec[0, 1])
+ if plot_fourier_probe:
+ probe_array = Complex2RGB(
+ self.probe_fourier,
+ chroma_boost=chroma_boost,
+ )
+ ax.set_title("Reconstructed Fourier probe")
+ ax.set_ylabel("kx [mrad]")
+ ax.set_xlabel("ky [mrad]")
+ else:
+ probe_array = Complex2RGB(
+ self.probe,
+ power=2,
+ chroma_boost=chroma_boost,
+ )
+ ax.set_title("Reconstructed probe intensity")
+ ax.set_ylabel("x [A]")
+ ax.set_xlabel("y [A]")
+
+ im = ax.imshow(
+ probe_array,
+ extent=probe_extent,
+ **kwargs,
+ )
+
+ if cbar:
+ divider = make_axes_locatable(ax)
+ ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
+ add_colorbar_arg(
+ ax_cb,
+ chroma_boost=chroma_boost,
+ )
+ else:
+ ax = fig.add_subplot(spec[0])
+ im = ax.imshow(
+ rotated_object,
+ extent=extent,
+ cmap=cmap,
+ **kwargs,
+ )
+ ax.set_ylabel("x [A]")
+ ax.set_xlabel("y [A]")
+ ax.set_title("Reconstructed object projection")
+
+ if cbar:
+ divider = make_axes_locatable(ax)
+ ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
+ fig.add_axes(ax_cb)
+ fig.colorbar(im, cax=ax_cb)
+
+ if plot_convergence and hasattr(self, "error_iterations"):
+ kwargs.pop("vmin", None)
+ kwargs.pop("vmax", None)
+ errors = np.array(self.error_iterations)
+            if plot_probe or plot_fourier_probe:
+ ax = fig.add_subplot(spec[1, :])
+ else:
+ ax = fig.add_subplot(spec[1])
+ ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs)
+ ax.set_ylabel("NMSE")
+ ax.set_xlabel("Iteration number")
+ ax.yaxis.tick_right()
+
+ fig.suptitle(f"Normalized mean squared error: {self.error:.3e}")
+ spec.tight_layout(fig)
+
+ def _visualize_all_iterations(
+ self,
+ fig,
+ cbar: bool,
+ plot_convergence: bool,
+ plot_probe: bool,
+ plot_fourier_probe: bool,
+ iterations_grid: Tuple[int, int],
+ projection_angle_deg: float,
+ projection_axes: Tuple[int, int],
+ x_lims: Tuple[int, int],
+ y_lims: Tuple[int, int],
+ **kwargs,
+ ):
+ """
+ Displays all reconstructed object and probe iterations.
+
+ Parameters
+ --------
+ fig: Figure
+ Matplotlib figure to place Gridspec in
+ plot_convergence: bool, optional
+ If true, the normalized mean squared error (NMSE) plot is displayed
+ cbar: bool, optional
+ If true, displays a colorbar
+ plot_probe: bool
+ If true, the reconstructed probe intensity is also displayed
+ plot_fourier_probe: bool, optional
+ If true, the reconstructed complex Fourier probe is displayed
+ iterations_grid: Tuple[int,int]
+ Grid dimensions to plot reconstruction iterations
+ projection_angle_deg: float
+ Angle in degrees to rotate 3D array around prior to projection
+ projection_axes: tuple(int,int)
+ Axes defining projection plane
+ x_lims: tuple(float,float)
+ min/max x indices
+ y_lims: tuple(float,float)
+ min/max y indices
+ """
+ asnumpy = self._asnumpy
+
+ if not hasattr(self, "object_iterations"):
+ raise ValueError(
+ (
+ "Object and probe iterations were not saved during reconstruction. "
+ "Please re-run using store_iterations=True."
+ )
+ )
+
+ if iterations_grid == "auto":
+ num_iter = len(self.error_iterations)
+
+ if num_iter == 1:
+ return self._visualize_last_iteration(
+ fig=fig,
+ plot_convergence=plot_convergence,
+ plot_probe=plot_probe,
+ plot_fourier_probe=plot_fourier_probe,
+ cbar=cbar,
+ projection_angle_deg=projection_angle_deg,
+ projection_axes=projection_axes,
+ x_lims=x_lims,
+ y_lims=y_lims,
+ **kwargs,
+ )
+ elif plot_probe or plot_fourier_probe:
+ iterations_grid = (2, 4) if num_iter > 4 else (2, num_iter)
+ else:
+ iterations_grid = (2, 4) if num_iter > 8 else (2, num_iter // 2)
+ else:
+ if (plot_probe or plot_fourier_probe) and iterations_grid[0] != 2:
+                raise ValueError(
+                    "iterations_grid must have the form (2, num_cols) when plotting the probe."
+                )
+
+ auto_figsize = (
+ (3 * iterations_grid[1], 3 * iterations_grid[0] + 1)
+ if plot_convergence
+ else (3 * iterations_grid[1], 3 * iterations_grid[0])
+ )
+ figsize = kwargs.pop("figsize", auto_figsize)
+ cmap = kwargs.pop("cmap", "magma")
+
+ if plot_fourier_probe:
+ chroma_boost = kwargs.pop("chroma_boost", 2)
+ else:
+ chroma_boost = kwargs.pop("chroma_boost", 1)
+
+ errors = np.array(self.error_iterations)
+
+ if projection_angle_deg is not None:
+ objects = [
+ self._crop_rotate_object_manually(
+ rotate_np(
+ obj,
+ projection_angle_deg,
+ axes=projection_axes,
+ reshape=False,
+ order=2,
+ ).sum(0),
+ angle=None,
+ x_lims=x_lims,
+ y_lims=y_lims,
+ )
+ for obj in self.object_iterations
+ ]
+ else:
+ objects = [
+ self._crop_rotate_object_manually(
+ obj.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims
+ )
+ for obj in self.object_iterations
+ ]
+
+ if plot_probe or plot_fourier_probe:
+ total_grids = (np.prod(iterations_grid) / 2).astype("int")
+ probes = self.probe_iterations
+ else:
+ total_grids = np.prod(iterations_grid)
+ max_iter = len(objects) - 1
+ grid_range = range(0, max_iter + 1, max_iter // (total_grids - 1))
+
+ extent = [
+ 0,
+ self.sampling[1] * objects[0].shape[1],
+ self.sampling[0] * objects[0].shape[0],
+ 0,
+ ]
+
+ if plot_fourier_probe:
+ probe_extent = [
+ -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2,
+ self.angular_sampling[1] * self._region_of_interest_shape[1] / 2,
+ self.angular_sampling[0] * self._region_of_interest_shape[0] / 2,
+ -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2,
+ ]
+ elif plot_probe:
+ probe_extent = [
+ 0,
+ self.sampling[1] * self._region_of_interest_shape[1],
+ self.sampling[0] * self._region_of_interest_shape[0],
+ 0,
+ ]
+
+ if plot_convergence:
+ if plot_probe or plot_fourier_probe:
+ spec = GridSpec(ncols=1, nrows=3, height_ratios=[4, 4, 1], hspace=0)
+ else:
+ spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0)
+ else:
+ if plot_probe or plot_fourier_probe:
+ spec = GridSpec(ncols=1, nrows=2)
+ else:
+ spec = GridSpec(ncols=1, nrows=1)
+
+ if fig is None:
+ fig = plt.figure(figsize=figsize)
+
+ grid = ImageGrid(
+ fig,
+ spec[0],
+            nrows_ncols=(1, iterations_grid[1])
+            if plot_probe or plot_fourier_probe
+            else iterations_grid,
+ axes_pad=(0.75, 0.5) if cbar else 0.5,
+ cbar_mode="each" if cbar else None,
+ cbar_pad="2.5%" if cbar else None,
+ )
+
+ for n, ax in enumerate(grid):
+ im = ax.imshow(
+ objects[grid_range[n]],
+ extent=extent,
+ cmap=cmap,
+ **kwargs,
+ )
+ ax.set_title(f"Iter: {grid_range[n]} Object")
+
+ ax.set_ylabel("x [A]")
+ ax.set_xlabel("y [A]")
+ if cbar:
+ grid.cbar_axes[n].colorbar(im)
+
+ if plot_probe or plot_fourier_probe:
+ kwargs.pop("vmin", None)
+ kwargs.pop("vmax", None)
+ grid = ImageGrid(
+ fig,
+ spec[1],
+ nrows_ncols=(1, iterations_grid[1]),
+ axes_pad=(0.75, 0.5) if cbar else 0.5,
+ cbar_mode="each" if cbar else None,
+ cbar_pad="2.5%" if cbar else None,
+ )
+
+ for n, ax in enumerate(grid):
+ if plot_fourier_probe:
+ probe_array = Complex2RGB(
+ asnumpy(
+ self._return_fourier_probe_from_centered_probe(
+ probes[grid_range[n]]
+ )
+ ),
+ chroma_boost=chroma_boost,
+ )
+ ax.set_title(f"Iter: {grid_range[n]} Fourier probe")
+ ax.set_ylabel("kx [mrad]")
+ ax.set_xlabel("ky [mrad]")
+ else:
+ probe_array = Complex2RGB(
+ probes[grid_range[n]],
+ power=2,
+ chroma_boost=chroma_boost,
+ )
+ ax.set_title(f"Iter: {grid_range[n]} probe intensity")
+ ax.set_ylabel("x [A]")
+ ax.set_xlabel("y [A]")
+
+ im = ax.imshow(
+ probe_array,
+ extent=probe_extent,
+ )
+
+ if cbar:
+ add_colorbar_arg(
+ grid.cbar_axes[n],
+ chroma_boost=chroma_boost,
+ )
+
+ if plot_convergence:
+ kwargs.pop("vmin", None)
+ kwargs.pop("vmax", None)
+            if plot_probe or plot_fourier_probe:
+ ax2 = fig.add_subplot(spec[2])
+ else:
+ ax2 = fig.add_subplot(spec[1])
+ ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs)
+ ax2.set_ylabel("NMSE")
+ ax2.set_xlabel("Iteration number")
+ ax2.yaxis.tick_right()
+
+ spec.tight_layout(fig)
+
+ def visualize(
+ self,
+ fig=None,
+ iterations_grid: Tuple[int, int] = None,
+ plot_convergence: bool = True,
+ plot_probe: bool = True,
+ plot_fourier_probe: bool = False,
+ cbar: bool = True,
+ projection_angle_deg: float = None,
+ projection_axes: Tuple[int, int] = (0, 2),
+ x_lims=(None, None),
+ y_lims=(None, None),
+ **kwargs,
+ ):
+ """
+ Displays reconstructed object and probe.
+
+ Parameters
+ --------
+ fig: Figure
+ Matplotlib figure to place Gridspec in
+ plot_convergence: bool, optional
+ If true, the normalized mean squared error (NMSE) plot is displayed
+ cbar: bool, optional
+ If true, displays a colorbar
+ plot_probe: bool
+ If true, the reconstructed probe intensity is also displayed
+ plot_fourier_probe: bool, optional
+ If true, the reconstructed complex Fourier probe is displayed
+ iterations_grid: Tuple[int,int]
+ Grid dimensions to plot reconstruction iterations
+ projection_angle_deg: float
+ Angle in degrees to rotate 3D array around prior to projection
+ projection_axes: tuple(int,int)
+ Axes defining projection plane
+ x_lims: tuple(float,float)
+ min/max x indices
+ y_lims: tuple(float,float)
+ min/max y indices
+
+ Returns
+ --------
+ self: OverlapTomographicReconstruction
+ Self to accommodate chaining
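+
+        Examples
+        --------
+        A minimal sketch, assuming ``tomo`` has been reconstructed with
+        ``store_iterations=True`` (the variable name is illustrative):
+
+        >>> tomo.visualize(iterations_grid="auto")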
+ """
+
+ if iterations_grid is None:
+ self._visualize_last_iteration(
+ fig=fig,
+ plot_convergence=plot_convergence,
+ plot_probe=plot_probe,
+ plot_fourier_probe=plot_fourier_probe,
+ cbar=cbar,
+ projection_angle_deg=projection_angle_deg,
+ projection_axes=projection_axes,
+ x_lims=x_lims,
+ y_lims=y_lims,
+ **kwargs,
+ )
+ else:
+ self._visualize_all_iterations(
+ fig=fig,
+ plot_convergence=plot_convergence,
+ iterations_grid=iterations_grid,
+ plot_probe=plot_probe,
+ plot_fourier_probe=plot_fourier_probe,
+ cbar=cbar,
+ projection_angle_deg=projection_angle_deg,
+ projection_axes=projection_axes,
+ x_lims=x_lims,
+ y_lims=y_lims,
+ **kwargs,
+ )
+
+ return self
+
+ def _return_object_fft(
+ self,
+ obj=None,
+ projection_angle_deg: float = None,
+ projection_axes: Tuple[int, int] = (0, 2),
+ x_lims: Tuple[int, int] = (None, None),
+ y_lims: Tuple[int, int] = (None, None),
+ ):
+ """
+        Returns the amplitude of the object FFT, shifted to the center of the array
+
+ Parameters
+ ----------
+ obj: array, optional
+ if None is specified, uses self._object
+ projection_angle_deg: float
+ Angle in degrees to rotate 3D array around prior to projection
+ projection_axes: tuple(int,int)
+ Axes defining projection plane
+ x_lims: tuple(float,float)
+ min/max x indices
+ y_lims: tuple(float,float)
+ min/max y indices
+ """
+
+ xp = self._xp
+ asnumpy = self._asnumpy
+
+ if obj is None:
+ obj = self._object
+ else:
+ obj = xp.asarray(obj, dtype=xp.float32)
+
+ if projection_angle_deg is not None:
+ rotated_3d_obj = self._rotate(
+ obj,
+ projection_angle_deg,
+ axes=projection_axes,
+ reshape=False,
+ order=2,
+ )
+ rotated_3d_obj = asnumpy(rotated_3d_obj)
+ else:
+ rotated_3d_obj = asnumpy(obj)
+
+ rotated_object = self._crop_rotate_object_manually(
+ rotated_3d_obj.sum(0), angle=None, x_lims=x_lims, y_lims=y_lims
+ )
+
+ return np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(rotated_object))))
+
+ def show_object_fft(
+ self,
+ obj=None,
+ projection_angle_deg: float = None,
+ projection_axes: Tuple[int, int] = (0, 2),
+ x_lims: Tuple[int, int] = (None, None),
+ y_lims: Tuple[int, int] = (None, None),
+ **kwargs,
+ ):
+ """
+ Plot FFT of reconstructed object
+
+ Parameters
+ ----------
+ obj: array, optional
+ if None is specified, uses self._object
+ projection_angle_deg: float
+ Angle in degrees to rotate 3D array around prior to projection
+ projection_axes: tuple(int,int)
+ Axes defining projection plane
+ x_lims: tuple(float,float)
+ min/max x indices
+ y_lims: tuple(float,float)
+ min/max y indices
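+
+        Examples
+        --------
+        A minimal sketch, assuming ``tomo`` is a reconstructed instance (the
+        variable name is illustrative):
+
+        >>> tomo.show_object_fft()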
+ """
+ if obj is None:
+ object_fft = self._return_object_fft(
+ projection_angle_deg=projection_angle_deg,
+ projection_axes=projection_axes,
+ x_lims=x_lims,
+ y_lims=y_lims,
+ )
+ else:
+ object_fft = self._return_object_fft(
+ obj,
+ projection_angle_deg=projection_angle_deg,
+ projection_axes=projection_axes,
+ x_lims=x_lims,
+ y_lims=y_lims,
+ )
+
+ figsize = kwargs.pop("figsize", (6, 6))
+ cmap = kwargs.pop("cmap", "magma")
+
+ pixelsize = 1 / (object_fft.shape[1] * self.sampling[1])
+ show(
+ object_fft,
+ figsize=figsize,
+ cmap=cmap,
+ scalebar=True,
+ pixelsize=pixelsize,
+ ticks=False,
+ pixelunits=r"$\AA^{-1}$",
+ **kwargs,
+ )
+
+ @property
+ def positions(self):
+ """Probe positions [A]"""
+
+ if self.angular_sampling is None:
+ return None
+
+ asnumpy = self._asnumpy
+ positions_all = []
+ for tilt_index in range(self._num_tilts):
+ positions = self._positions_px_all[
+ self._cum_probes_per_tilt[tilt_index] : self._cum_probes_per_tilt[
+ tilt_index + 1
+ ]
+ ].copy()
+ positions[:, 0] *= self.sampling[0]
+ positions[:, 1] *= self.sampling[1]
+ positions_all.append(asnumpy(positions))
+
+ return np.asarray(positions_all)
+
+ def _return_self_consistency_errors(
+ self,
+ max_batch_size=None,
+ ):
+ """Compute the self-consistency errors for each probe position"""
+ raise NotImplementedError()
+
+ def _return_projected_cropped_potential(
+ self,
+ ):
+ """Utility function to accommodate multiple classes"""
+ raise NotImplementedError()
+
+ def show_uncertainty_visualization(
+ self,
+ errors=None,
+ max_batch_size=None,
+ projected_cropped_potential=None,
+ kde_sigma=None,
+ plot_histogram=True,
+ plot_contours=False,
+ **kwargs,
+ ):
+ """Plot uncertainty visualization using self-consistency errors"""
+ raise NotImplementedError()
diff --git a/py4DSTEM/process/phase/iterative_parallax.py b/py4DSTEM/process/phase/iterative_parallax.py
new file mode 100644
index 000000000..716e1d782
--- /dev/null
+++ b/py4DSTEM/process/phase/iterative_parallax.py
@@ -0,0 +1,2454 @@
+"""
+Module for reconstructing virtual parallax (also known as tilted-shifted bright field)
+images by aligning each virtual BF image.
+"""
+
+import warnings
+from typing import Tuple
+
+import matplotlib.pyplot as plt
+import numpy as np
+from emdfile import Array, Custom, Metadata, _read_metadata, tqdmnd
+from matplotlib.gridspec import GridSpec
+from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
+from py4DSTEM import Calibration, DataCube
+from py4DSTEM.preprocess.utils import get_shifted_ar
+from py4DSTEM.process.phase.iterative_base_class import PhaseReconstruction
+from py4DSTEM.process.phase.utils import AffineTransform
+from py4DSTEM.process.utils.cross_correlate import align_images_fourier
+from py4DSTEM.process.utils.utils import electron_wavelength_angstrom
+from py4DSTEM.visualize import show
+from scipy.linalg import polar
+from scipy.optimize import minimize
+from scipy.special import comb
+
+try:
+ import cupy as cp
+except ModuleNotFoundError:
+ cp = np
+
+warnings.simplefilter(action="always", category=UserWarning)
+
+_aberration_names = {
+ (1, 0): "C1 ",
+ (1, 2): "stig ",
+ (2, 1): "coma ",
+ (2, 3): "trefoil ",
+ (3, 0): "C3 ",
+ (3, 2): "stig2 ",
+ (3, 4): "quadfoil ",
+ (4, 1): "coma2 ",
+ (4, 3): "trefoil2 ",
+ (4, 5): "pentafoil ",
+ (5, 0): "C5 ",
+ (5, 2): "stig3 ",
+ (5, 4): "quadfoil2 ",
+ (5, 6): "hexafoil ",
+}
+
+
+class ParallaxReconstruction(PhaseReconstruction):
+ """
+ Iterative parallax reconstruction class.
+
+ Parameters
+ ----------
+ datacube: DataCube
+ Input 4D diffraction pattern intensities
+ energy: float
+ The electron energy of the wave functions in eV
+ verbose: bool, optional
+ If True, class methods will inherit this and print additional information
+ device: str, optional
+            Device the calculations will be performed on. Must be 'cpu' or 'gpu'
+ object_padding_px: Tuple[int,int], optional
+ Pixel dimensions to pad object with
+ If None, the padding is set to half the probe ROI dimensions
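+
+    Examples
+    --------
+    A minimal sketch of the intended workflow, assuming ``dc`` is a calibrated
+    DataCube (names and values are illustrative):
+
+    >>> parallax = ParallaxReconstruction(datacube=dc, energy=300e3)
+    >>> parallax.preprocess(defocus_guess=1e4).reconstruct()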
+ """
+
+ def __init__(
+ self,
+ energy: float,
+ datacube: DataCube = None,
+ verbose: bool = False,
+ object_padding_px: Tuple[int, int] = (32, 32),
+ device: str = "cpu",
+ name: str = "parallax_reconstruction",
+ ):
+ Custom.__init__(self, name=name)
+
+ if device == "cpu":
+ self._xp = np
+ self._asnumpy = np.asarray
+ from scipy.ndimage import gaussian_filter
+
+ self._gaussian_filter = gaussian_filter
+ elif device == "gpu":
+ self._xp = cp
+ self._asnumpy = cp.asnumpy
+ from cupyx.scipy.ndimage import gaussian_filter
+
+ self._gaussian_filter = gaussian_filter
+ else:
+ raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}")
+
+ self.set_save_defaults()
+
+ # Data
+ self._datacube = datacube
+
+ # Metadata
+ self._energy = energy
+ self._verbose = verbose
+ self._device = device
+ self._object_padding_px = object_padding_px
+ self._preprocessed = False
+
+ def to_h5(self, group):
+ """
+ Wraps datasets and metadata to write in emdfile classes,
+ notably the (subpixel-)aligned BF.
+ """
+ # instantiation metadata
+ self.metadata = Metadata(
+ name="instantiation_metadata",
+ data={
+ "energy": self._energy,
+ "verbose": self._verbose,
+ "device": self._device,
+ "object_padding_px": self._object_padding_px,
+ "name": self.name,
+ },
+ )
+
+ # preprocessing metadata
+ self.metadata = Metadata(
+ name="preprocess_metadata",
+ data={
+ "scan_sampling": self._scan_sampling,
+ "wavelength": self._wavelength,
+ },
+ )
+
+ # reconstruction metadata
+ recon_metadata = {"reconstruction_error": float(self._recon_error)}
+
+ if hasattr(self, "aberration_C1"):
+ recon_metadata |= {
+ "aberration_rotation_QR": self.rotation_Q_to_R_rads,
+ "aberration_transpose": self.transpose,
+ "aberration_C1": self.aberration_C1,
+ "aberration_A1x": self.aberration_A1x,
+ "aberration_A1y": self.aberration_A1y,
+ }
+
+ if hasattr(self, "_kde_upsample_factor"):
+ recon_metadata |= {
+ "kde_upsample_factor": self._kde_upsample_factor,
+ }
+ self._subpixel_aligned_BF_emd = Array(
+ name="subpixel_aligned_BF",
+ data=self._asnumpy(self._recon_BF_subpixel_aligned),
+ )
+
+ if hasattr(self, "aberration_dict"):
+ self.metadata = Metadata(
+ name="aberrations_metadata",
+ data={
+ v["aberration name"]: v["value [Ang]"]
+ for k, v in self.aberration_dict.items()
+ },
+ )
+
+ self.metadata = Metadata(
+ name="reconstruction_metadata",
+ data=recon_metadata,
+ )
+
+ self._aligned_BF_emd = Array(
+ name="aligned_BF",
+ data=self._asnumpy(self._recon_BF),
+ )
+
+ # datacube
+ if self._save_datacube:
+ self.metadata = self._datacube.calibration
+ Custom.to_h5(self, group)
+ else:
+ dc = self._datacube
+ self._datacube = None
+ Custom.to_h5(self, group)
+ self._datacube = dc
+
+ @classmethod
+ def _get_constructor_args(cls, group):
+ """
+ Returns a dictionary of arguments/values to pass
+ to the class' __init__ function
+ """
+ # Get data
+ dict_data = cls._get_emd_attr_data(cls, group)
+
+ # Get metadata dictionaries
+ instance_md = _read_metadata(group, "instantiation_metadata")
+
+ # Fix calibrations bug
+ if "_datacube" in dict_data:
+ calibrations_dict = _read_metadata(group, "calibration")._params
+ cal = Calibration()
+ cal._params.update(calibrations_dict)
+ dc = dict_data["_datacube"]
+ dc.calibration = cal
+ else:
+ dc = None
+
+ # Populate args and return
+ kwargs = {
+ "datacube": dc,
+ "energy": instance_md["energy"],
+ "object_padding_px": instance_md["object_padding_px"],
+ "name": instance_md["name"],
+ "verbose": True, # for compatibility
+ "device": "cpu", # for compatibility
+ }
+
+ return kwargs
+
+ def _populate_instance(self, group):
+ """
+        Sets post-initialization properties, notably some preprocessing metadata.
+        Optional; during read, this method is run after object instantiation.
+ """
+
+ xp = self._xp
+
+ # Preprocess metadata
+ preprocess_md = _read_metadata(group, "preprocess_metadata")
+ self._scan_sampling = preprocess_md["scan_sampling"]
+ self._wavelength = preprocess_md["wavelength"]
+
+ # Reconstruction metadata
+ reconstruction_md = _read_metadata(group, "reconstruction_metadata")
+ self._recon_error = reconstruction_md["reconstruction_error"]
+
+ # Data
+ dict_data = Custom._get_emd_attr_data(Custom, group)
+
+ if "aberration_C1" in reconstruction_md.keys:
+ self.rotation_Q_to_R_rads = reconstruction_md["aberration_rotation_QR"]
+ self.transpose = reconstruction_md["aberration_transpose"]
+ self.aberration_C1 = reconstruction_md["aberration_C1"]
+ self.aberration_A1x = reconstruction_md["aberration_A1x"]
+ self.aberration_A1y = reconstruction_md["aberration_A1y"]
+
+ if "kde_upsample_factor" in reconstruction_md.keys:
+ self._kde_upsample_factor = reconstruction_md["kde_upsample_factor"]
+ self._recon_BF_subpixel_aligned = xp.asarray(
+ dict_data["_subpixel_aligned_BF_emd"].data, dtype=xp.float32
+ )
+
+ self._recon_BF = xp.asarray(dict_data["_aligned_BF_emd"].data, dtype=xp.float32)
+
+ def preprocess(
+ self,
+ edge_blend: int = 16,
+ threshold_intensity: float = 0.8,
+ normalize_images: bool = True,
+ normalize_order=0,
+ descan_correct: bool = True,
+ defocus_guess: float = None,
+ rotation_guess: float = None,
+ plot_average_bf: bool = True,
+ **kwargs,
+ ):
+ """
+ Iterative parallax reconstruction preprocessing method.
+
+ Parameters
+ ----------
+ edge_blend: int, optional
+ Pixels to blend image at the border
+        threshold_intensity: float, optional
+            Fraction of the maximum of dp_mean used to select bright-field pixels
+        normalize_images: bool, optional
+            If True, bright-field images are normalized to have a mean of 1
+        normalize_order: integer, optional
+            Polynomial order for normalization: 0 (constant) or 1 (linear).
+            Higher orders are not yet implemented.
+        descan_correct: bool, optional
+            If True, aligns bright field stack based on measured descan
+        defocus_guess: float, optional
+            Initial guess of defocus value (defocus dF) in A
+            If None, first iteration is assumed to be in-focus
+        rotation_guess: float, optional
+            Initial guess of rotation angle between real and reciprocal space in degrees
+            If None, first iteration assumed to be 0
+ plot_average_bf: bool, optional
+ If True, plots the average bright field image, using defocus_guess
+
+ Returns
+ --------
+ self: ParallaxReconstruction
+ Self to accommodate chaining
+ """
+
+ xp = self._xp
+ asnumpy = self._asnumpy
+
+ if self._datacube is None:
+ raise ValueError(
+ (
+ "The preprocess() method requires a DataCube. "
+ "Please run parallax.attach_datacube(DataCube) first."
+ )
+ )
+
+ # get mean diffraction pattern
+ try:
+ self._dp_mean = xp.asarray(
+ self._datacube.tree("dp_mean").data, dtype=xp.float32
+ )
+ except AssertionError:
+ self._dp_mean = xp.asarray(
+ self._datacube.get_dp_mean().data, dtype=xp.float32
+ )
+
+ # extract calibrations
+ self._intensities = self._extract_intensities_and_calibrations_from_datacube(
+ self._datacube,
+ require_calibrations=True,
+ )
+
+ self._region_of_interest_shape = np.array(self._intensities.shape[-2:])
+        self._grid_scan_shape = np.array(self._intensities.shape[:2])
+
+ # make sure mean diffraction pattern is shaped correctly
+ if (self._dp_mean.shape[0] != self._intensities.shape[2]) or (
+ self._dp_mean.shape[1] != self._intensities.shape[3]
+ ):
+ raise ValueError(
+ "dp_mean must match the datacube shape. Try setting dp_mean = None."
+ )
+
+ # descan correct
+ if descan_correct:
+ (
+ _,
+ _,
+ com_fitted_x,
+ com_fitted_y,
+ _,
+ _,
+ ) = self._calculate_intensities_center_of_mass(
+ self._intensities,
+ dp_mask=None,
+ fit_function="plane",
+ com_shifts=None,
+ com_measured=None,
+ )
+
+ com_fitted_x = asnumpy(com_fitted_x)
+ com_fitted_y = asnumpy(com_fitted_y)
+ intensities = asnumpy(self._intensities)
+ intensities_shifted = np.zeros_like(intensities)
+
+ center_x, center_y = self._region_of_interest_shape / 2
+
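+            # Shift each diffraction pattern so its fitted center of mass lands
+            # at the detector center, removing descan before extracting BF pixels.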
+ for rx in range(intensities_shifted.shape[0]):
+ for ry in range(intensities_shifted.shape[1]):
+ intensity_shifted = get_shifted_ar(
+ intensities[rx, ry],
+ -com_fitted_x[rx, ry] + center_x,
+ -com_fitted_y[rx, ry] + center_y,
+ bilinear=True,
+ device="cpu",
+ )
+
+ intensities_shifted[rx, ry] = intensity_shifted
+
+ self._intensities = xp.asarray(intensities_shifted, xp.float32)
+ self._dp_mean = self._intensities.mean((0, 1))
+
+ # select virtual detector pixels
+ self._dp_mask = self._dp_mean >= (xp.max(self._dp_mean) * threshold_intensity)
+ self._num_bf_images = int(xp.count_nonzero(self._dp_mask))
+ self._wavelength = electron_wavelength_angstrom(self._energy)
+
+ # diffraction space coordinates
+ self._xy_inds = np.argwhere(self._dp_mask)
+ self._kxy = xp.asarray(
+ (self._xy_inds - xp.mean(self._xy_inds, axis=0)[None])
+ * xp.array(self._reciprocal_sampling)[None],
+ dtype=xp.float32,
+ )
+ self._probe_angles = self._kxy * self._wavelength
+ self._kr = xp.sqrt(xp.sum(self._kxy**2, axis=1))
+
+ # Window function
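+        # (sin^2 taper ramping from 0 at the array edge to 1 over roughly
+        # edge_blend pixels, used to blend each BF image into its padding)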
+ x = xp.linspace(-1, 1, self._grid_scan_shape[0] + 1, dtype=xp.float32)[1:]
+ x -= (x[1] - x[0]) / 2
+ wx = (
+ xp.sin(
+ xp.clip(
+ (1 - xp.abs(x)) * self._grid_scan_shape[0] / edge_blend / 2, 0, 1
+ )
+ * (xp.pi / 2)
+ )
+ ** 2
+ )
+ y = xp.linspace(-1, 1, self._grid_scan_shape[1] + 1, dtype=xp.float32)[1:]
+ y -= (y[1] - y[0]) / 2
+ wy = (
+ xp.sin(
+ xp.clip(
+ (1 - xp.abs(y)) * self._grid_scan_shape[1] / edge_blend / 2, 0, 1
+ )
+ * (xp.pi / 2)
+ )
+ ** 2
+ )
+ self._window_edge = wx[:, None] * wy[None, :]
+ self._window_inv = 1 - self._window_edge
+ self._window_pad = xp.zeros(
+ (
+ self._grid_scan_shape[0] + self._object_padding_px[0],
+ self._grid_scan_shape[1] + self._object_padding_px[1],
+ ),
+ dtype=xp.float32,
+ )
+ self._window_pad[
+ self._object_padding_px[0] // 2 : self._grid_scan_shape[0]
+ + self._object_padding_px[0] // 2,
+ self._object_padding_px[1] // 2 : self._grid_scan_shape[1]
+ + self._object_padding_px[1] // 2,
+ ] = self._window_edge
+
+ # Collect BF images
+ all_bfs = xp.moveaxis(
+ self._intensities[:, :, self._xy_inds[:, 0], self._xy_inds[:, 1]],
+ (0, 1, 2),
+ (1, 2, 0),
+ )
+
+        # initialize
+ stack_shape = (
+ self._num_bf_images,
+ self._grid_scan_shape[0] + self._object_padding_px[0],
+ self._grid_scan_shape[1] + self._object_padding_px[1],
+ )
+ if normalize_images:
+ self._stack_BF = xp.ones(stack_shape, dtype=xp.float32)
+ self._stack_BF_no_window = xp.ones(stack_shape, xp.float32)
+
+ if normalize_order == 0:
+ all_bfs /= xp.mean(all_bfs, axis=(1, 2))[:, None, None]
+ self._stack_BF[
+ :,
+ self._object_padding_px[0] // 2 : self._grid_scan_shape[0]
+ + self._object_padding_px[0] // 2,
+ self._object_padding_px[1] // 2 : self._grid_scan_shape[1]
+ + self._object_padding_px[1] // 2,
+ ] = (
+ self._window_inv[None] + self._window_edge[None] * all_bfs
+ )
+
+ self._stack_BF_no_window[
+ :,
+ self._object_padding_px[0] // 2 : self._grid_scan_shape[0]
+ + self._object_padding_px[0] // 2,
+ self._object_padding_px[1] // 2 : self._grid_scan_shape[1]
+ + self._object_padding_px[1] // 2,
+ ] = all_bfs
+
+ elif normalize_order == 1:
+ x = xp.linspace(-0.5, 0.5, all_bfs.shape[1], xp.float32)
+ y = xp.linspace(-0.5, 0.5, all_bfs.shape[2], xp.float32)
+ ya, xa = xp.meshgrid(y, x)
+                basis = xp.vstack(
+ (
+ xp.ones_like(xa),
+ xa.ravel(),
+ ya.ravel(),
+ )
+ ).T
+ for a0 in range(all_bfs.shape[0]):
+                    coefs = xp.linalg.lstsq(basis, all_bfs[a0].ravel(), rcond=None)
+
+ self._stack_BF[
+ a0,
+ self._object_padding_px[0] // 2 : self._grid_scan_shape[0]
+ + self._object_padding_px[0] // 2,
+ self._object_padding_px[1] // 2 : self._grid_scan_shape[1]
+ + self._object_padding_px[1] // 2,
+ ] = self._window_inv[None] + self._window_edge[None] * all_bfs[
+ a0
+ ] / xp.reshape(
+ basis @ coefs[0], all_bfs.shape[1:3]
+ )
+
+ self._stack_BF_no_window[
+ a0,
+ self._object_padding_px[0] // 2 : self._grid_scan_shape[0]
+ + self._object_padding_px[0] // 2,
+ self._object_padding_px[1] // 2 : self._grid_scan_shape[1]
+ + self._object_padding_px[1] // 2,
+ ] = all_bfs[a0] / xp.reshape(basis @ coefs[0], all_bfs.shape[1:3])
+
+ else:
+ all_means = xp.mean(all_bfs, axis=(1, 2))
+ self._stack_BF = xp.full(stack_shape, all_means[:, None, None])
+ self._stack_BF_no_window = xp.full(stack_shape, all_means[:, None, None])
+ self._stack_BF[
+ :,
+ self._object_padding_px[0] // 2 : self._grid_scan_shape[0]
+ + self._object_padding_px[0] // 2,
+ self._object_padding_px[1] // 2 : self._grid_scan_shape[1]
+ + self._object_padding_px[1] // 2,
+ ] = (
+ self._window_inv[None] * all_means[:, None, None]
+ + self._window_edge[None] * all_bfs
+ )
+
+ self._stack_BF_no_window[
+ :,
+ self._object_padding_px[0] // 2 : self._grid_scan_shape[0]
+ + self._object_padding_px[0] // 2,
+ self._object_padding_px[1] // 2 : self._grid_scan_shape[1]
+ + self._object_padding_px[1] // 2,
+ ] = all_bfs
+
+ # Fourier space operators for image shifts
+ qx = xp.fft.fftfreq(self._stack_BF.shape[1], d=1)
+ qx = xp.asarray(qx, dtype=xp.float32)
+
+ qy = xp.fft.fftfreq(self._stack_BF.shape[2], d=1)
+ qy = xp.asarray(qy, dtype=xp.float32)
+
+ qxa, qya = xp.meshgrid(qx, qy, indexing="ij")
+ self._qx_shift = -2j * xp.pi * qxa
+ self._qy_shift = -2j * xp.pi * qya
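+        # (Fourier shift theorem: multiplying an image's FFT by
+        # exp(-2j*pi*(qx*dx + qy*dy)) translates it by (dx, dy) pixels)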
+
+ # Initialization utilities
+ self._stack_mask = xp.tile(self._window_pad[None], (self._num_bf_images, 1, 1))
+ if defocus_guess is not None:
+ Gs = xp.fft.fft2(self._stack_BF)
+
+ self._xy_shifts = (
+ -self._probe_angles * defocus_guess / xp.array(self._scan_sampling)
+ )
+
+            if rotation_guess is not None:
+ angle = xp.deg2rad(rotation_guess)
+ rotation_matrix = xp.array(
+ [[np.cos(angle), np.sin(angle)], [-np.sin(angle), np.cos(angle)]]
+ )
+ self._xy_shifts = xp.dot(self._xy_shifts, rotation_matrix)
+
+ dx = self._xy_shifts[:, 0]
+ dy = self._xy_shifts[:, 1]
+
+ shift_op = xp.exp(
+ self._qx_shift[None] * dx[:, None, None]
+ + self._qy_shift[None] * dy[:, None, None]
+ )
+ self._stack_BF = xp.real(xp.fft.ifft2(Gs * shift_op))
+ self._stack_mask = xp.real(
+ xp.fft.ifft2(xp.fft.fft2(self._stack_mask) * shift_op)
+ )
+
+ del Gs
+ else:
+ self._xy_shifts = xp.zeros((self._num_bf_images, 2), dtype=xp.float32)
+
+ self._stack_mean = xp.mean(self._stack_BF)
+ self._mask_sum = xp.sum(self._window_edge) * self._num_bf_images
+ self._recon_mask = xp.sum(self._stack_mask, axis=0)
+
+ mask_inv = 1 - xp.clip(self._recon_mask, 0, 1)
+
+ self._recon_BF = (
+ self._stack_mean * mask_inv
+ + xp.sum(self._stack_BF * self._stack_mask, axis=0)
+ ) / (self._recon_mask + mask_inv)
+
+ self._recon_error = (
+ xp.atleast_1d(
+ xp.sum(xp.abs(self._stack_BF - self._recon_BF[None]) * self._stack_mask)
+ )
+ / self._mask_sum
+ )
+
+ self._recon_BF_initial = self._recon_BF.copy()
+ self._stack_BF_initial = self._stack_BF.copy()
+ self._stack_mask_initial = self._stack_mask.copy()
+ self._recon_mask_initial = self._recon_mask.copy()
+ self._xy_shifts_initial = self._xy_shifts.copy()
+
+ self.recon_BF = asnumpy(self._recon_BF)
+
+ if plot_average_bf:
+ figsize = kwargs.pop("figsize", (6, 12))
+
+ fig, ax = plt.subplots(1, 2, figsize=figsize)
+
+ self._visualize_figax(fig, ax[0], **kwargs)
+
+ ax[0].set_ylabel("x [A]")
+ ax[0].set_xlabel("y [A]")
+ ax[0].set_title("Average Bright Field Image")
+
+ reciprocal_extent = [
+ -0.5 * (self._reciprocal_sampling[1] * self._dp_mask.shape[1]),
+ 0.5 * (self._reciprocal_sampling[1] * self._dp_mask.shape[1]),
+ 0.5 * (self._reciprocal_sampling[0] * self._dp_mask.shape[0]),
+ -0.5 * (self._reciprocal_sampling[0] * self._dp_mask.shape[0]),
+ ]
+ ax[1].imshow(self._dp_mask, extent=reciprocal_extent, cmap="gray")
+ ax[1].set_title("DP mask")
+ ax[1].set_ylabel(r"$k_x$ [$A^{-1}$]")
+ ax[1].set_xlabel(r"$k_y$ [$A^{-1}$]")
+ plt.tight_layout()
+ self._preprocessed = True
+
+ if self._device == "gpu":
+ xp._default_memory_pool.free_all_blocks()
+ xp.clear_memo()
+
+ return self
+
+ def tune_angle_and_defocus(
+ self,
+ angle_guess=None,
+ defocus_guess=None,
+ angle_step_size=5,
+ defocus_step_size=100,
+ num_angle_values=5,
+ num_defocus_values=5,
+ return_values=False,
+ plot_reconstructions=True,
+ plot_convergence=True,
+ **kwargs,
+ ):
+ """
+        Run parallax reconstruction over a parameter space of pre-determined angles
+        and defocus values.
+
+ Parameters
+ ----------
+ angle_guess: float (degrees), optional
+ initial starting guess for rotation angle between real and reciprocal space
+ if None, uses 0
+ defocus_guess: float (A), optional
+ initial starting guess for defocus (defocus dF)
+ if None, uses 0
+ angle_step_size: float (degrees), optional
+ size of change of rotation angle between real and reciprocal space for
+ each step in parameter space
+ defocus_step_size: float (A), optional
+ size of change of defocus for each step in parameter space
+ num_angle_values: int, optional
+ number of values of angle to test, must be >= 1.
+ num_defocus_values: int,optional
+ number of values of defocus to test, must be >= 1
+ plot_reconstructions: bool, optional
+ if True, plot phase of reconstructed objects
+ plot_convergence: bool, optional
+            if True, makes a 2D plot of the error metric
+ return_values: bool, optional
+ if True, returns objects, convergence
+
+ Returns
+ -------
+ objects: list
+ reconstructed objects
+ convergence: np.ndarray
+ array of convergence values from reconstructions
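+
+        Examples
+        --------
+        A minimal sketch scanning a 3 x 3 parameter grid around initial guesses
+        (values are illustrative):
+
+        >>> parallax.tune_angle_and_defocus(
+        ...     angle_guess=0,
+        ...     defocus_guess=1e4,
+        ...     num_angle_values=3,
+        ...     num_defocus_values=3,
+        ... )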
+ """
+ asnumpy = self._asnumpy
+
+ if angle_guess is None:
+ angle_guess = 0
+ if defocus_guess is None:
+ defocus_guess = 0
+
+ if num_angle_values == 1:
+ angle_step_size = 0
+
+ if num_defocus_values == 1:
+ defocus_step_size = 0
+
+ angles = np.linspace(
+ angle_guess - angle_step_size * (num_angle_values - 1) / 2,
+ angle_guess + angle_step_size * (num_angle_values - 1) / 2,
+ num_angle_values,
+ )
+
+ defocus_values = np.linspace(
+ defocus_guess - defocus_step_size * (num_defocus_values - 1) / 2,
+ defocus_guess + defocus_step_size * (num_defocus_values - 1) / 2,
+ num_defocus_values,
+ )
+ if return_values or plot_convergence:
+ recon_BF = []
+ convergence = []
+
+ if plot_reconstructions:
+ spec = GridSpec(
+ ncols=num_defocus_values,
+ nrows=num_angle_values,
+ hspace=0.15,
+ wspace=0.35,
+ )
+ figsize = kwargs.get(
+ "figsize", (4 * num_defocus_values, 4 * num_angle_values)
+ )
+
+ fig = plt.figure(figsize=figsize)
+
+ # run loop and plot along the way
+ self._verbose = False
+ for flat_index, (angle, defocus) in enumerate(
+ tqdmnd(angles, defocus_values, desc="Tuning angle and defocus")
+ ):
+ self.preprocess(
+ defocus_guess=defocus,
+ rotation_guess=angle,
+ plot_average_bf=False,
+ **kwargs,
+ )
+
+ if plot_reconstructions:
+ row_index, col_index = np.unravel_index(
+ flat_index, (num_angle_values, num_defocus_values)
+ )
+ object_ax = fig.add_subplot(spec[row_index, col_index])
+ self._visualize_figax(
+ fig,
+ ax=object_ax,
+ )
+
+ object_ax.set_title(
+ f" angle = {angle:.1f} °, defocus = {defocus:.1f} A \n error = {self._recon_error[0]:.3e}"
+ )
+ object_ax.set_xticks([])
+ object_ax.set_yticks([])
+
+ if return_values:
+ recon_BF.append(self.recon_BF)
+ if return_values or plot_convergence:
+ convergence.append(asnumpy(self._recon_error[0]))
+
+ if plot_convergence:
+ fig, ax = plt.subplots()
+ ax.set_title("convergence")
+ im = ax.imshow(
+ np.array(convergence).reshape(angles.shape[0], defocus_values.shape[0]),
+ cmap="magma",
+ )
+
+ if angles.shape[0] > 1:
+ ax.set_ylabel("angles")
+ ax.set_yticks(np.arange(angles.shape[0]))
+ ax.set_yticklabels([f"{angle:.1f} °" for angle in angles])
+ else:
+ ax.set_yticks([])
+ ax.set_ylabel(f"angle {angles[0]:.1f}")
+
+ if defocus_values.shape[0] > 1:
+ ax.set_xlabel("defocus values")
+ ax.set_xticks(np.arange(defocus_values.shape[0]))
+ ax.set_xticklabels([f"{df:.1f}" for df in defocus_values])
+ else:
+ ax.set_xticks([])
+ ax.set_xlabel(f"defocus value: {defocus_values[0]:.1f}")
+
+ divider = make_axes_locatable(ax)
+ cax = divider.append_axes("right", size="5%", pad=0.05)
+ fig.colorbar(im, cax=cax)
+
+ fig.tight_layout()
+
+ if return_values:
+ convergence = np.array(convergence).reshape(
+ angles.shape[0], defocus_values.shape[0]
+ )
+ return recon_BF, convergence
+
+ def reconstruct(
+ self,
+ max_alignment_bin: int = None,
+ min_alignment_bin: int = 1,
+ max_iter_at_min_bin: int = 2,
+ cross_correlation_upsample_factor: int = 8,
+ regularizer_matrix_size: Tuple[int, int] = (1, 1),
+ regularize_shifts: bool = True,
+ running_average: bool = True,
+ progress_bar: bool = True,
+ plot_aligned_bf: bool = True,
+ plot_convergence: bool = True,
+ reset: bool = None,
+ **kwargs,
+ ):
+ """
+ Iterative Parallax Reconstruction main reconstruction method.
+
+ Parameters
+ ----------
+ max_alignment_bin: int, optional
+ Maximum bin size for bright field alignment
+ If None, the bright field disk radius is used
+ min_alignment_bin: int, optional
+ Minimum bin size for bright field alignment
+ max_iter_at_min_bin: int, optional
+ Number of iterations to run at the smallest bin size
+ cross_correlation_upsample_factor: int, optional
+ DFT upsample factor for subpixel alignment
+ regularizer_matrix_size: Tuple[int,int], optional
+ Bernstein basis degree used for regularizing shifts
+ regularize_shifts: bool, optional
+ If True, the cross-correlated shifts are constrained to a spline interpolation
+ running_average: bool, optional
+ If True, the bright field reference image is updated in a spiral from the origin
+ progress_bar: bool, optional
+ If True, progress bar is displayed
+ plot_aligned_bf: bool, optional
+ If True, the aligned bright field image is plotted at each bin level
+ plot_convergence: bool, optional
+ If True, the convergence error is also plotted
+ reset: bool, optional
+ If True, the reconstruction is reset
+
+ Returns
+ --------
+ self: BFReconstruction
+ Self to accommodate chaining
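+
+        Examples
+        --------
+        A minimal sketch running coarse-to-fine alignment down to single-pixel
+        bins (values are illustrative):
+
+        >>> parallax.reconstruct(min_alignment_bin=1, max_iter_at_min_bin=2)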
+ """
+
+ xp = self._xp
+ asnumpy = self._asnumpy
+
+ if reset:
+ self._recon_BF = self._recon_BF_initial.copy()
+ self._stack_BF = self._stack_BF_initial.copy()
+ self._stack_mask = self._stack_mask_initial.copy()
+ self._recon_mask = self._recon_mask_initial.copy()
+ self._xy_shifts = self._xy_shifts_initial.copy()
+ elif reset is None:
+ if hasattr(self, "_basis"):
+ warnings.warn(
+ (
+ "Continuing reconstruction from previous result. "
+ "Use reset=True for a fresh start."
+ ),
+ UserWarning,
+ )
+
+ if not regularize_shifts:
+ self._basis = self._kxy
+ else:
+ kr_max = xp.max(self._kr)
+ u = self._kxy[:, 0] * 0.5 / kr_max + 0.5
+ v = self._kxy[:, 1] * 0.5 / kr_max + 0.5
+
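+            # Tensor-product Bernstein polynomial basis over the normalized
+            # detector coordinates (u, v) in [0, 1]; fitting the measured shifts
+            # to this low-order surface regularizes them smoothly across the
+            # bright-field disk.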
+ self._basis = xp.zeros(
+ (
+ self._num_bf_images,
+ (regularizer_matrix_size[0] + 1) * (regularizer_matrix_size[1] + 1),
+ ),
+ dtype=xp.float32,
+ )
+ for ii in np.arange(regularizer_matrix_size[0] + 1):
+ Bi = (
+ comb(regularizer_matrix_size[0], ii)
+ * (u**ii)
+ * ((1 - u) ** (regularizer_matrix_size[0] - ii))
+ )
+
+ for jj in np.arange(regularizer_matrix_size[1] + 1):
+ Bj = (
+ comb(regularizer_matrix_size[1], jj)
+ * (v**jj)
+ * ((1 - v) ** (regularizer_matrix_size[1] - jj))
+ )
+
+ ind = ii * (regularizer_matrix_size[1] + 1) + jj
+ self._basis[:, ind] = Bi * Bj
+
+ # Iterative binning for more robust alignment
+ diameter_pixels = int(
+ xp.maximum(
+ xp.max(self._xy_inds[:, 0]) - xp.min(self._xy_inds[:, 0]),
+ xp.max(self._xy_inds[:, 1]) - xp.min(self._xy_inds[:, 1]),
+ )
+ + 1
+ )
+
+ if max_alignment_bin is not None:
+ max_alignment_bin = np.minimum(diameter_pixels, max_alignment_bin)
+ else:
+ max_alignment_bin = diameter_pixels
+
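+        # Coarse-to-fine schedule: descending powers of two from max_alignment_bin
+        # towards min_alignment_bin, optionally repeating the smallest bin.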
+ bin_min = np.ceil(np.log(min_alignment_bin) / np.log(2))
+ bin_max = np.ceil(np.log(max_alignment_bin) / np.log(2))
+ bin_vals = 2 ** np.arange(bin_min, bin_max)[::-1]
+
+ if max_iter_at_min_bin > 1:
+ bin_vals = np.hstack(
+ (bin_vals, np.repeat(bin_vals[-1], max_iter_at_min_bin - 1))
+ )
+
+ if plot_aligned_bf:
+ num_plots = bin_vals.shape[0]
+ nrows = int(np.sqrt(num_plots))
+ ncols = int(np.ceil(num_plots / nrows))
+
+ if plot_convergence:
+ errors = []
+ spec = GridSpec(
+ ncols=ncols,
+ nrows=nrows + 1,
+ hspace=0.15,
+ wspace=0.15,
+ height_ratios=[1] * nrows + [1 / 4],
+ )
+
+ figsize = kwargs.get("figsize", (4 * ncols, 4 * nrows + 1))
+ else:
+ spec = GridSpec(
+ ncols=ncols,
+ nrows=nrows,
+ hspace=0.15,
+ wspace=0.15,
+ )
+
+ figsize = kwargs.get("figsize", (4 * ncols, 4 * nrows))
+
+ kwargs.pop("figsize", None)
+ fig = plt.figure(figsize=figsize)
+
+ xy_center = (self._xy_inds - xp.median(self._xy_inds, axis=0)).astype("float")
+
+ # Loop over all binning values
+ for a0 in range(bin_vals.shape[0]):
+ G_ref = xp.fft.fft2(self._recon_BF)
+
+ # Segment the virtual images with current binning values
+ xy_inds = xp.round(xy_center / bin_vals[a0] + 0.5).astype("int")
+ xy_vals = np.unique(
+ asnumpy(xy_inds), axis=0
+ ) # axis is not yet supported in cupy
+ # Sort by radial order, from center to outer edge
+ inds_order = xp.argsort(xp.sum(xy_vals**2, axis=1))
+
+ shifts_update = xp.zeros((self._num_bf_images, 2), dtype=xp.float32)
+
+ for a1 in tqdmnd(
+ xy_vals.shape[0],
+ desc="Alignment at bin " + str(bin_vals[a0].astype("int")),
+ unit=" image subsets",
+ disable=not progress_bar,
+ ):
+ ind_align = inds_order[a1]
+
+ # Generate mean image for alignment
+ sub = xp.logical_and(
+ xy_inds[:, 0] == xy_vals[ind_align, 0],
+ xy_inds[:, 1] == xy_vals[ind_align, 1],
+ )
+
+ G = xp.fft.fft2(xp.mean(self._stack_BF[sub], axis=0))
+
+ # Get best fit alignment
+ xy_shift = align_images_fourier(
+ G_ref,
+ G,
+ upsample_factor=cross_correlation_upsample_factor,
+ device=self._device,
+ )
+
+ dx = (
+ xp.mod(
+ xy_shift[0] + self._stack_BF.shape[1] / 2,
+ self._stack_BF.shape[1],
+ )
+ - self._stack_BF.shape[1] / 2
+ )
+ dy = (
+ xp.mod(
+ xy_shift[1] + self._stack_BF.shape[2] / 2,
+ self._stack_BF.shape[2],
+ )
+ - self._stack_BF.shape[2] / 2
+ )
+
+ # output shifts
+ shifts_update[sub, 0] = dx
+ shifts_update[sub, 1] = dy
+
+ # update running estimate of reference image
+ shift_op = xp.exp(self._qx_shift * dx + self._qy_shift * dy)
+
+ if running_average:
+ G_ref = G_ref * a1 / (a1 + 1) + (G * shift_op) / (a1 + 1)
+
+ # regularize the shifts
+ xy_shifts_new = self._xy_shifts + shifts_update
+ coefs = xp.linalg.lstsq(self._basis, xy_shifts_new, rcond=None)[0]
+ xy_shifts_fit = self._basis @ coefs
+ shifts_update = xy_shifts_fit - self._xy_shifts
+
+ # apply shifts
+ Gs = xp.fft.fft2(self._stack_BF)
+
+ dx = shifts_update[:, 0]
+ dy = shifts_update[:, 1]
+ self._xy_shifts[:, 0] += dx
+ self._xy_shifts[:, 1] += dy
+
+ shift_op = xp.exp(
+ self._qx_shift[None] * dx[:, None, None]
+ + self._qy_shift[None] * dy[:, None, None]
+ )
+
+ self._stack_BF = xp.real(xp.fft.ifft2(Gs * shift_op))
+ self._stack_mask = xp.real(
+ xp.fft.ifft2(xp.fft.fft2(self._stack_mask) * shift_op)
+ )
+
+ self._stack_BF = xp.asarray(
+ self._stack_BF, dtype=xp.float32
+ ) # cast back to float32, since the FFT round-trip upcasts
+ self._stack_mask = xp.asarray(
+ self._stack_mask, dtype=xp.float32
+ ) # cast back to float32, since the FFT round-trip upcasts
+
+ del Gs
+
+ # Center the shifts
+ xy_shifts_median = xp.round(xp.median(self._xy_shifts, axis=0)).astype(int)
+ self._xy_shifts -= xy_shifts_median[None, :]
+ self._stack_BF = xp.roll(self._stack_BF, -xy_shifts_median, axis=(1, 2))
+ self._stack_mask = xp.roll(self._stack_mask, -xy_shifts_median, axis=(1, 2))
+
+ # Generate new estimate
+ self._recon_mask = xp.sum(self._stack_mask, axis=0)
+
+ mask_inv = 1 - xp.clip(self._recon_mask, 0, 1)
+ self._recon_BF = (
+ self._stack_mean * mask_inv
+ + xp.sum(self._stack_BF * self._stack_mask, axis=0)
+ ) / (self._recon_mask + mask_inv)
+
+ self._recon_error = (
+ xp.atleast_1d(
+ xp.sum(
+ xp.abs(self._stack_BF - self._recon_BF[None]) * self._stack_mask
+ )
+ )
+ / self._mask_sum
+ )
+
+ if plot_aligned_bf:
+ row_index, col_index = np.unravel_index(a0, (nrows, ncols))
+
+ ax = fig.add_subplot(spec[row_index, col_index])
+ self._visualize_figax(fig, ax, **kwargs)
+
+ ax.set_xticks([])
+ ax.set_yticks([])
+ ax.set_title(f"Aligned BF at bin {int(bin_vals[a0])}")
+
+ if plot_convergence:
+ errors.append(float(self._recon_error))
+
+ if plot_aligned_bf:
+ if plot_convergence:
+ ax = fig.add_subplot(spec[-1, :])
+ ax.plot(np.arange(num_plots), errors)
+ ax.set_xticks(np.arange(num_plots))
+ ax.set_ylabel("Error")
+ spec.tight_layout(fig)
+
+ self.recon_BF = asnumpy(self._recon_BF)
+
+ if self._device == "gpu":
+ xp._default_memory_pool.free_all_blocks()
+ xp.clear_memo()
+
+ return self
+
+ def subpixel_alignment(
+ self,
+ kde_upsample_factor=None,
+ kde_sigma=0.125,
+ plot_upsampled_BF_comparison: bool = True,
+ plot_upsampled_FFT_comparison: bool = False,
+ **kwargs,
+ ):
+ """
+ Upsample and subpixel-align BFs using the measured image shifts.
+ Uses kernel density estimation (KDE) to align upsampled BFs.
+
+ Parameters
+ ----------
+ kde_upsample_factor: float, optional
+ Real-space upsampling factor
+ kde_sigma: float, optional
+ KDE gaussian kernel bandwidth
+ plot_upsampled_BF_comparison: bool, optional
+ If True, the pre/post alignment BF images are plotted for comparison
+ plot_upsampled_FFT_comparison: bool, optional
+ If True, the pre/post alignment BF FFTs are plotted for comparison
+
+ """
+ xp = self._xp
+ asnumpy = self._asnumpy
+ gaussian_filter = self._gaussian_filter
+
+ xy_shifts = self._xy_shifts
+ BF_size = np.array(self._stack_BF_no_window.shape[-2:])
+
+ self._DF_upsample_limit = np.max(
+ 2 * self._region_of_interest_shape / self._scan_shape
+ )
+ self._BF_upsample_limit = (
+ 4 * self._kr.max() / self._reciprocal_sampling[0]
+ ) / self._scan_shape.max()
+ if self._device == "gpu":
+ self._BF_upsample_limit = self._BF_upsample_limit.item()
+
+ if kde_upsample_factor is None:
+ if self._BF_upsample_limit * 3 / 2 > self._DF_upsample_limit:
+ kde_upsample_factor = self._DF_upsample_limit
+
+ warnings.warn(
+ (
+ f"Upsampling factor set to {kde_upsample_factor:.2f} (the "
+ f"dark-field upsampling limit)."
+ ),
+ UserWarning,
+ )
+
+ elif self._BF_upsample_limit * 3 / 2 > 1:
+ kde_upsample_factor = self._BF_upsample_limit * 3 / 2
+
+ warnings.warn(
+ (
+ f"Upsampling factor set to {kde_upsample_factor:.2f} (1.5 times the "
+ f"bright-field upsampling limit of {self._BF_upsample_limit:.2f})."
+ ),
+ UserWarning,
+ )
+ else:
+ kde_upsample_factor = self._DF_upsample_limit * 2 / 3
+
+ warnings.warn(
+ (
+ f"Upsampling factor set to {kde_upsample_factor:.2f} (2/3 times the "
+ f"dark-field upsampling limit of {self._DF_upsample_limit:.2f})."
+ ),
+ UserWarning,
+ )
+
+ if kde_upsample_factor < 1:
+ raise ValueError("kde_upsample_factor must be larger than 1")
+
+ if kde_upsample_factor > self._DF_upsample_limit:
+ warnings.warn(
+ (
+ "Requested upsampling factor exceeds "
+ f"dark-field upsampling limit of {self._DF_upsample_limit:.2f}."
+ ),
+ UserWarning,
+ )
+
+ self._kde_upsample_factor = kde_upsample_factor
+ pixel_output = np.round(BF_size * self._kde_upsample_factor).astype("int")
+ pixel_size = pixel_output.prod()
+
+ # shifted coordinates
+ x = xp.arange(BF_size[0])
+ y = xp.arange(BF_size[1])
+
+ xa, ya = xp.meshgrid(x, y, indexing="ij")
+ xa = ((xa + xy_shifts[:, 0, None, None]) * self._kde_upsample_factor).ravel()
+ ya = ((ya + xy_shifts[:, 1, None, None]) * self._kde_upsample_factor).ravel()
+
+ # bilinear sampling
+ xF = xp.floor(xa).astype("int")
+ yF = xp.floor(ya).astype("int")
+ dx = xa - xF
+ dy = ya - yF
+
+ # resampling
+ inds_1D = xp.ravel_multi_index(
+ xp.hstack(
+ [
+ [xF, yF],
+ [xF + 1, yF],
+ [xF, yF + 1],
+ [xF + 1, yF + 1],
+ ]
+ ),
+ pixel_output,
+ mode=["wrap", "wrap"],
+ )
+
+ weights = xp.hstack(
+ (
+ (1 - dx) * (1 - dy),
+ (dx) * (1 - dy),
+ (1 - dx) * (dy),
+ (dx) * (dy),
+ )
+ )
+
+ pix_count = xp.reshape(
+ xp.bincount(inds_1D, weights=weights, minlength=pixel_size), pixel_output
+ )
+ pix_output = xp.reshape(
+ xp.bincount(
+ inds_1D,
+ weights=weights * xp.tile(self._stack_BF_no_window.ravel(), 4),
+ minlength=pixel_size,
+ ),
+ pixel_output,
+ )
+
+ # kernel density estimate
+ sigma = kde_sigma * self._kde_upsample_factor
+ pix_count = gaussian_filter(pix_count, sigma)
+ pix_count[pix_count == 0.0] = np.inf
+ pix_output = gaussian_filter(pix_output, sigma)
+ pix_output /= pix_count
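+ # The KDE above "splats" the shifted pixel values onto the upsampled
+ # grid and gaussian-blurs both values and weights before normalizing.
+ # A 1D analogue (illustrative only, hypothetical names):
+ #
+ #   counts = np.bincount(inds, weights=w, minlength=n)
+ #   values = np.bincount(inds, weights=w * data, minlength=n)
+ #   kde = gaussian_filter(values, sigma) / gaussian_filter(counts, sigma)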
+
+ self._recon_BF_subpixel_aligned = pix_output
+ self.recon_BF_subpixel_aligned = asnumpy(self._recon_BF_subpixel_aligned)
+
+ # plotting
+ if plot_upsampled_BF_comparison:
+ if plot_upsampled_FFT_comparison:
+ figsize = kwargs.pop("figsize", (8, 8))
+ fig, axs = plt.subplots(2, 2, figsize=figsize)
+ else:
+ figsize = kwargs.pop("figsize", (8, 4))
+ fig, axs = plt.subplots(1, 2, figsize=figsize)
+
+ axs = axs.flat
+ cmap = kwargs.pop("cmap", "magma")
+
+ cropped_object = self._crop_padded_object(self._recon_BF)
+ cropped_object_aligned = self._crop_padded_object(
+ self._recon_BF_subpixel_aligned, upsampled=True
+ )
+
+ extent = [
+ 0,
+ self._scan_sampling[1] * cropped_object.shape[1],
+ self._scan_sampling[0] * cropped_object.shape[0],
+ 0,
+ ]
+
+ axs[0].imshow(
+ cropped_object,
+ extent=extent,
+ cmap=cmap,
+ **kwargs,
+ )
+ axs[0].set_title("Aligned Bright Field")
+
+ axs[1].imshow(
+ cropped_object_aligned,
+ extent=extent,
+ cmap=cmap,
+ **kwargs,
+ )
+ axs[1].set_title("Upsampled Bright Field")
+
+ for ax in axs[:2]:
+ ax.set_ylabel("x [A]")
+ ax.set_xlabel("y [A]")
+
+ if plot_upsampled_FFT_comparison:
+ recon_fft = xp.fft.fftshift(xp.abs(xp.fft.fft2(self._recon_BF)))
+ pad_x = np.round(
+ BF_size[0] * (self._kde_upsample_factor - 1) / 2
+ ).astype("int")
+ pad_y = np.round(
+ BF_size[1] * (self._kde_upsample_factor - 1) / 2
+ ).astype("int")
+ pad_recon_fft = asnumpy(
+ xp.pad(recon_fft, ((pad_x, pad_x), (pad_y, pad_y)))
+ )
+
+ upsampled_fft = asnumpy(
+ xp.fft.fftshift(
+ xp.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned))
+ )
+ )
+
+ reciprocal_extent = [
+ -0.5 / (self._scan_sampling[1] / self._kde_upsample_factor),
+ 0.5 / (self._scan_sampling[1] / self._kde_upsample_factor),
+ 0.5 / (self._scan_sampling[0] / self._kde_upsample_factor),
+ -0.5 / (self._scan_sampling[0] / self._kde_upsample_factor),
+ ]
+
+ show(
+ pad_recon_fft,
+ figax=(fig, axs[2]),
+ extent=reciprocal_extent,
+ cmap="gray",
+ title="Aligned Bright Field FFT",
+ **kwargs,
+ )
+
+ show(
+ upsampled_fft,
+ figax=(fig, axs[3]),
+ extent=reciprocal_extent,
+ cmap="gray",
+ title="Upsampled Bright Field FFT",
+ **kwargs,
+ )
+
+ for ax in axs[2:]:
+ ax.set_ylabel(r"$k_x$ [$A^{-1}$]")
+ ax.set_xlabel(r"$k_y$ [$A^{-1}$]")
+ ax.xaxis.set_ticks_position("bottom")
+
+ fig.tight_layout()
+
+ def aberration_fit(
+ self,
+ fit_BF_shifts: bool = False,
+ fit_CTF_FFT: bool = False,
+ fit_aberrations_max_radial_order: int = 3,
+ fit_aberrations_max_angular_order: int = 4,
+ fit_aberrations_min_radial_order: int = 2,
+ fit_aberrations_min_angular_order: int = 0,
+ fit_max_thon_rings: int = 6,
+ fit_power_alpha: float = 2.0,
+ plot_CTF_comparison: bool = None,
+ plot_BF_shifts_comparison: bool = None,
+ upsampled: bool = True,
+ force_transpose: bool = False,
+ ):
+ """
+ Fit aberrations to the measured image shifts.
+
+ Parameters
+ ----------
+ fit_BF_shifts: bool
+ Set to True to fit aberrations to the measured BF shifts directly.
+ fit_CTF_FFT: bool
+ Set to True to fit aberrations in the FFT of the (upsampled) BF
+ image. Note that this method relies on visible zero crossings in the FFT.
+ fit_aberrations_max_radial_order: int
+ Max radial order for fitting of aberrations.
+ fit_aberrations_max_angular_order: int
+ Max angular order for fitting of aberrations.
+ fit_aberrations_min_radial_order: int
+ Min radial order for fitting of aberrations.
+ fit_aberrations_min_angular_order: int
+ Min angular order for fitting of aberrations.
+ fit_max_thon_rings: int
+ Max number of Thon rings to search for during CTF FFT fitting.
+ fit_power_alpha: float
+ Power to raise FFT alpha weighting during CTF FFT fitting.
+ plot_CTF_comparison: bool, optional
+ If True, the fitted CTF is plotted against the reconstructed frequencies.
+ plot_BF_shifts_comparison: bool, optional
+ If True, the measured vs fitted BF shifts are plotted.
+ upsampled: bool
+ If True, and upsampled BF is available, uses that for CTF FFT fitting.
+ force_transpose: bool
+ If True, and fit_BF_shifts is True, flips the measured x and y shifts
+
+ """
+ xp = self._xp
+ asnumpy = self._asnumpy
+
+ ### First pass
+
+ # Convert real space shifts to Angstroms
+
+ if force_transpose is True:
+ self._xy_shifts_Ang = xp.flip(self._xy_shifts, axis=1) * xp.array(
+ self._scan_sampling
+ )
+ else:
+ self._xy_shifts_Ang = self._xy_shifts * xp.array(self._scan_sampling)
+ self.transpose = force_transpose
+
+ # Solve affine transformation
+ m = asnumpy(
+ xp.linalg.lstsq(self._probe_angles, self._xy_shifts_Ang, rcond=None)[0]
+ )
+ m_rotation, m_aberration = polar(m, side="right")
+
+ # Convert into rotation and aberration coefficients
+ self.rotation_Q_to_R_rads = -1 * np.arctan2(m_rotation[1, 0], m_rotation[0, 0])
+ if np.abs(np.mod(self.rotation_Q_to_R_rads + np.pi, 2.0 * np.pi) - np.pi) > (
+ np.pi * 0.5
+ ):
+ self.rotation_Q_to_R_rads = (
+ np.mod(self.rotation_Q_to_R_rads, 2.0 * np.pi) - np.pi
+ )
+ m_aberration = -1.0 * m_aberration
+
+ self.aberration_C1 = (m_aberration[0, 0] + m_aberration[1, 1]) / 2.0
+
+ if self.transpose:
+ self.aberration_A1x = -(m_aberration[0, 0] - m_aberration[1, 1]) / 2.0
+ self.aberration_A1y = (m_aberration[1, 0] + m_aberration[0, 1]) / 2.0
+ else:
+ self.aberration_A1x = (m_aberration[0, 0] - m_aberration[1, 1]) / 2.0
+ self.aberration_A1y = (m_aberration[1, 0] + m_aberration[0, 1]) / 2.0
+
+ ### Second pass
+
+ # Aberration coefs
+ mn = []
+
+ for m in range(
+ fit_aberrations_min_radial_order - 1, fit_aberrations_max_radial_order
+ ):
+ n_max = np.minimum(fit_aberrations_max_angular_order, m + 1)
+ for n in range(fit_aberrations_min_angular_order, n_max + 1):
+ if (m + n) % 2:
+ mn.append([m, n, 0])
+ if n > 0:
+ mn.append([m, n, 1])
+
+ self._aberrations_mn = np.array(mn)
+ self._aberrations_mn = self._aberrations_mn[
+ np.argsort(self._aberrations_mn[:, 1]), :
+ ]
+
+ sub = self._aberrations_mn[:, 1] > 0
+ self._aberrations_mn[sub, :] = self._aberrations_mn[sub, :][
+ np.argsort(self._aberrations_mn[sub, 0]), :
+ ]
+ self._aberrations_mn[~sub, :] = self._aberrations_mn[~sub, :][
+ np.argsort(self._aberrations_mn[~sub, 0]), :
+ ]
+ self._aberrations_num = self._aberrations_mn.shape[0]
+
+ if plot_CTF_comparison is None:
+ if fit_CTF_FFT:
+ plot_CTF_comparison = True
+
+ if plot_BF_shifts_comparison is None:
+ if fit_BF_shifts:
+ plot_BF_shifts_comparison = True
+
+ # Thon Rings Fitting
+ if fit_CTF_FFT or plot_CTF_comparison:
+ if upsampled and hasattr(self, "_kde_upsample_factor"):
+ im_FFT = xp.abs(xp.fft.fft2(self._recon_BF_subpixel_aligned))
+ sx = self._scan_sampling[0] / self._kde_upsample_factor
+ sy = self._scan_sampling[1] / self._kde_upsample_factor
+
+ reciprocal_extent = [
+ -0.5 / (self._scan_sampling[1] / self._kde_upsample_factor),
+ 0.5 / (self._scan_sampling[1] / self._kde_upsample_factor),
+ 0.5 / (self._scan_sampling[0] / self._kde_upsample_factor),
+ -0.5 / (self._scan_sampling[0] / self._kde_upsample_factor),
+ ]
+
+ else:
+ im_FFT = xp.abs(xp.fft.fft2(self._recon_BF))
+ sx = self._scan_sampling[0]
+ sy = self._scan_sampling[1]
+ upsampled = False
+
+ reciprocal_extent = [
+ -0.5 / self._scan_sampling[1],
+ 0.5 / self._scan_sampling[1],
+ 0.5 / self._scan_sampling[0],
+ -0.5 / self._scan_sampling[0],
+ ]
+
+ # FFT coordinates
+ qx = xp.fft.fftfreq(im_FFT.shape[0], sx)
+ qy = xp.fft.fftfreq(im_FFT.shape[1], sy)
+ qr2 = qx[:, None] ** 2 + qy[None, :] ** 2
+
+ alpha_FFT = xp.sqrt(qr2) * self._wavelength
+ theta_FFT = xp.arctan2(qy[None, :], qx[:, None])
+
+ # Aberration basis
+ self._aberrations_basis_FFT = xp.zeros(
+ (alpha_FFT.size, self._aberrations_num)
+ )
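+ # The aberration surface is expanded in a polar basis:
+ #   chi(alpha, theta) = (2*pi/lambda) * sum_{m,n} C_{m,n}
+ #                       * alpha^(m+1) / (m+1) * {cos, sin}(n*theta)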
+ for a0 in range(self._aberrations_num):
+ m, n, a = self._aberrations_mn[a0]
+ if n == 0:
+ # Radially symmetric basis
+ self._aberrations_basis_FFT[:, a0] = (
+ alpha_FFT ** (m + 1) / (m + 1)
+ ).ravel()
+
+ elif a == 0:
+ # cos coef
+ self._aberrations_basis_FFT[:, a0] = (
+ alpha_FFT ** (m + 1) * xp.cos(n * theta_FFT) / (m + 1)
+ ).ravel()
+ else:
+ # sin coef
+ self._aberrations_basis_FFT[:, a0] = (
+ alpha_FFT ** (m + 1) * xp.sin(n * theta_FFT) / (m + 1)
+ ).ravel()
+
+ # global scaling
+ self._aberrations_basis_FFT *= 2 * np.pi / self._wavelength
+ self._aberrations_surface_shape_FFT = alpha_FFT.shape
+ plot_mask = qr2 > np.pi**2 / 4 / np.abs(self.aberration_C1)
+ angular_mask = xp.cos(8.0 * theta_FFT) ** 2 < 0.25
+
+ # CTF function
+ def calculate_CTF_FFT(alpha_shape, *coefs):
+ chi = xp.zeros_like(self._aberrations_basis_FFT[:, 0])
+ for a0 in range(len(coefs)):
+ chi += coefs[a0] * self._aberrations_basis_FFT[:, a0]
+ return xp.reshape(chi, alpha_shape)
+
+ # Direct Shifts Fitting
+ if fit_BF_shifts:
+ # FFT coordinates
+ sx = 1 / (self._reciprocal_sampling[0] * self._region_of_interest_shape[0])
+ sy = 1 / (self._reciprocal_sampling[1] * self._region_of_interest_shape[1])
+ qx = xp.fft.fftfreq(self._region_of_interest_shape[0], sx)
+ qy = xp.fft.fftfreq(self._region_of_interest_shape[1], sy)
+ qx, qy = np.meshgrid(qx, qy, indexing="ij")
+
+ # passive rotation basis by -theta
+ rotation_angle = -self.rotation_Q_to_R_rads
+ qx, qy = qx * np.cos(rotation_angle) + qy * np.sin(
+ rotation_angle
+ ), -qx * np.sin(rotation_angle) + qy * np.cos(rotation_angle)
+
+ qr2 = qx**2 + qy**2
+ u = qx * self._wavelength
+ v = qy * self._wavelength
+ alpha = xp.sqrt(qr2) * self._wavelength
+ theta = xp.arctan2(qy, qx)
+
+ # Aberration basis
+ self._aberrations_basis = xp.zeros((alpha.size, self._aberrations_num))
+ self._aberrations_basis_du = xp.zeros((alpha.size, self._aberrations_num))
+ self._aberrations_basis_dv = xp.zeros((alpha.size, self._aberrations_num))
+ for a0 in range(self._aberrations_num):
+ m, n, a = self._aberrations_mn[a0]
+
+ if n == 0:
+ # Radially symmetric basis
+ self._aberrations_basis[:, a0] = (
+ alpha ** (m + 1) / (m + 1)
+ ).ravel()
+ self._aberrations_basis_du[:, a0] = (u * alpha ** (m - 1)).ravel()
+ self._aberrations_basis_dv[:, a0] = (v * alpha ** (m - 1)).ravel()
+
+ elif a == 0:
+ # cos coef
+ self._aberrations_basis[:, a0] = (
+ alpha ** (m + 1) * xp.cos(n * theta) / (m + 1)
+ ).ravel()
+ self._aberrations_basis_du[:, a0] = (
+ alpha ** (m - 1)
+ * ((m + 1) * u * xp.cos(n * theta) + n * v * xp.sin(n * theta))
+ / (m + 1)
+ ).ravel()
+ self._aberrations_basis_dv[:, a0] = (
+ alpha ** (m - 1)
+ * ((m + 1) * v * xp.cos(n * theta) - n * u * xp.sin(n * theta))
+ / (m + 1)
+ ).ravel()
+
+ else:
+ # sin coef
+ self._aberrations_basis[:, a0] = (
+ alpha ** (m + 1) * xp.sin(n * theta) / (m + 1)
+ ).ravel()
+ self._aberrations_basis_du[:, a0] = (
+ alpha ** (m - 1)
+ * ((m + 1) * u * xp.sin(n * theta) - n * v * xp.cos(n * theta))
+ / (m + 1)
+ ).ravel()
+ self._aberrations_basis_dv[:, a0] = (
+ alpha ** (m - 1)
+ * ((m + 1) * v * xp.sin(n * theta) + n * u * xp.cos(n * theta))
+ / (m + 1)
+ ).ravel()
+
+ # global scaling
+ self._aberrations_basis *= 2 * np.pi / self._wavelength
+ self._aberrations_surface_shape = alpha.shape
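+ # The du/dv bases above are the analytic derivatives of each chi term
+ # with respect to u = lambda*qx and v = lambda*qy; they are fit to the
+ # measured BF shifts below, which are proportional to the gradient of chi.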
+
+ # CTF function
+ def calculate_CTF(alpha_shape, *coefs):
+ chi = xp.zeros_like(self._aberrations_basis[:, 0])
+ for a0 in range(len(coefs)):
+ chi += coefs[a0] * self._aberrations_basis[:, a0]
+ return xp.reshape(chi, alpha_shape)
+
+ # initial coefficients and plotting intensity range mask
+ self._aberrations_coefs = np.zeros(self._aberrations_num)
+
+ aberrations_mn_list = self._aberrations_mn.tolist()
+ if [1, 0, 0] in aberrations_mn_list:
+ ind_C1 = aberrations_mn_list.index([1, 0, 0])
+ self._aberrations_coefs[ind_C1] = self.aberration_C1
+
+ if [1, 2, 0] in aberrations_mn_list:
+ ind_A1x = aberrations_mn_list.index([1, 2, 0])
+ ind_A1y = aberrations_mn_list.index([1, 2, 1])
+ self._aberrations_coefs[ind_A1x] = self.aberration_A1x
+ self._aberrations_coefs[ind_A1y] = self.aberration_A1y
+
+ # Refinement using CTF fitting / Thon rings
+ if fit_CTF_FFT:
+ # scoring function to minimize: weighted mean of the FFT intensity near
+ # the CTF zero crossings; `max_num_rings` is bound from the loop below
+ def score_CTF(coefs):
+ im_CTF = xp.abs(
+ calculate_CTF_FFT(self._aberrations_surface_shape_FFT, *coefs)
+ )
+ mask = xp.logical_and(
+ im_CTF > 0.5 * np.pi,
+ im_CTF < (max_num_rings + 0.5) * np.pi,
+ )
+ if np.any(mask):
+ weights = xp.cos(im_CTF[mask]) ** 4
+ return asnumpy(
+ xp.sum(
+ weights * im_FFT[mask] * alpha_FFT[mask] ** fit_power_alpha
+ )
+ / xp.sum(weights)
+ )
+ else:
+ return np.inf
+
+ for max_num_rings in range(1, fit_max_thon_rings + 1):
+ # minimization
+ res = minimize(
+ score_CTF,
+ self._aberrations_coefs,
+ # method = 'Nelder-Mead',
+ # method = 'CG',
+ method="BFGS",
+ tol=1e-8,
+ )
+ self._aberrations_coefs = res.x
+
+ # Refinement using CTF fitting / Thon rings
+ elif fit_BF_shifts:
+ # Gradient basis
+ corner_indices = self._xy_inds - xp.asarray(
+ self._region_of_interest_shape // 2
+ )
+ raveled_indices = np.ravel_multi_index(
+ corner_indices.T, self._region_of_interest_shape, mode="wrap"
+ )
+ gradients = xp.vstack(
+ (
+ self._aberrations_basis_du[raveled_indices, :],
+ self._aberrations_basis_dv[raveled_indices, :],
+ )
+ )
+
+ # (Relative) untransposed fit
+ raveled_shifts = self._xy_shifts_Ang.T.ravel()
+ aberrations_coefs, res = xp.linalg.lstsq(
+ gradients, raveled_shifts, rcond=None
+ )[:2]
+
+ self._aberrations_coefs = asnumpy(aberrations_coefs)
+
+ if self.transpose:
+ aberrations_to_flip = (self._aberrations_mn[:, 1] > 0) & (
+ self._aberrations_mn[:, 2] == 0
+ )
+ self._aberrations_coefs[aberrations_to_flip] *= -1
+
+ # Plot the measured/fitted shifts comparison
+ if plot_BF_shifts_comparison:
+ measured_shifts_sx = xp.zeros(
+ self._region_of_interest_shape, dtype=xp.float32
+ )
+ measured_shifts_sx[
+ self._xy_inds[:, 0], self._xy_inds[:, 1]
+ ] = self._xy_shifts_Ang[:, 0]
+
+ measured_shifts_sy = xp.zeros(
+ self._region_of_interest_shape, dtype=xp.float32
+ )
+ measured_shifts_sy[
+ self._xy_inds[:, 0], self._xy_inds[:, 1]
+ ] = self._xy_shifts_Ang[:, 1]
+
+ fitted_shifts = (
+ xp.tensordot(gradients, xp.array(self._aberrations_coefs), axes=1)
+ .reshape((2, -1))
+ .T
+ )
+
+ fitted_shifts_sx = xp.zeros(
+ self._region_of_interest_shape, dtype=xp.float32
+ )
+ fitted_shifts_sx[
+ self._xy_inds[:, 0], self._xy_inds[:, 1]
+ ] = fitted_shifts[:, 0]
+
+ fitted_shifts_sy = xp.zeros(
+ self._region_of_interest_shape, dtype=xp.float32
+ )
+ fitted_shifts_sy[
+ self._xy_inds[:, 0], self._xy_inds[:, 1]
+ ] = fitted_shifts[:, 1]
+
+ max_shift = xp.max(
+ xp.array(
+ [
+ xp.abs(measured_shifts_sx).max(),
+ xp.abs(measured_shifts_sy).max(),
+ xp.abs(fitted_shifts_sx).max(),
+ xp.abs(fitted_shifts_sy).max(),
+ ]
+ )
+ )
+
+ show(
+ [
+ [asnumpy(measured_shifts_sx), asnumpy(fitted_shifts_sx)],
+ [asnumpy(measured_shifts_sy), asnumpy(fitted_shifts_sy)],
+ ],
+ cmap="PiYG",
+ vmin=-max_shift,
+ vmax=max_shift,
+ intensity_range="absolute",
+ axsize=(4, 4),
+ ticks=False,
+ title=[
+ "Measured Vertical Shifts",
+ "Fitted Vertical Shifts",
+ "Measured Horizontal Shifts",
+ "Fitted Horizontal Shifts",
+ ],
+ )
+
+ # Plot the CTF comparison between experiment and fit
+ if plot_CTF_comparison:
+ # Generate FFT plotting image
+ im_scale = asnumpy(im_FFT * alpha_FFT**fit_power_alpha)
+ int_vals = np.sort(im_scale.ravel())
+ int_range = (
+ int_vals[np.round(0.02 * im_scale.size).astype("int")],
+ int_vals[np.round(0.98 * im_scale.size).astype("int")],
+ )
+ int_range = (
+ int_range[0],
+ (int_range[1] - int_range[0]) * 1.0 + int_range[0],
+ )
+ im_scale = np.clip(
+ (np.fft.fftshift(im_scale) - int_range[0])
+ / (int_range[1] - int_range[0]),
+ 0,
+ 1,
+ )
+ im_plot = np.tile(im_scale[:, :, None], (1, 1, 3))
+
+ # Add CTF zero crossings
+ im_CTF = calculate_CTF_FFT(
+ self._aberrations_surface_shape_FFT, *self._aberrations_coefs
+ )
+ im_CTF_cos = xp.cos(xp.abs(im_CTF)) ** 4
+ im_CTF[xp.abs(im_CTF) > (fit_max_thon_rings + 0.5) * np.pi] = np.pi / 2
+ im_CTF = xp.abs(xp.sin(im_CTF)) < 0.15
+ im_CTF[xp.logical_not(plot_mask)] = 0
+
+ im_CTF = np.fft.fftshift(asnumpy(im_CTF * angular_mask))
+ im_plot[:, :, 0] += im_CTF
+ im_plot[:, :, 1] -= im_CTF
+ im_plot[:, :, 2] -= im_CTF
+ im_plot = np.clip(im_plot, 0, 1)
+
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
+ ax1.imshow(
+ im_plot, vmin=int_range[0], vmax=int_range[1], extent=reciprocal_extent
+ )
+ ax2.imshow(
+ np.fft.fftshift(asnumpy(im_CTF_cos)),
+ cmap="gray",
+ extent=reciprocal_extent,
+ )
+
+ for ax in (ax1, ax2):
+ ax.set_ylabel(r"$k_x$ [$A^{-1}$]")
+ ax.set_xlabel(r"$k_y$ [$A^{-1}$]")
+
+ ax1.set_title("Aligned Bright Field FFT")
+ ax2.set_title("Fitted CTF Zero-Crossings")
+
+ fig.tight_layout()
+
+ self.aberration_dict = {
+ tuple(self._aberrations_mn[a0]): {
+ "aberration name": _aberration_names.get(
+ tuple(self._aberrations_mn[a0, :2]), "-"
+ ).strip(),
+ "value [Ang]": self._aberrations_coefs[a0],
+ }
+ for a0 in range(self._aberrations_num)
+ }
+
+ # Print results
+ if self._verbose:
+ if fit_CTF_FFT or fit_BF_shifts:
+ print("Initial Aberration coefficients")
+ print("-------------------------------")
+ print(
+ (
+ "Rotation of Q w.r.t. R = "
+ f"{np.rad2deg(self.rotation_Q_to_R_rads):.3f} deg"
+ )
+ )
+ print(
+ (
+ "Astigmatism (A1x,A1y) = ("
+ f"{self.aberration_A1x:.0f},"
+ f"{self.aberration_A1y:.0f}) Ang"
+ )
+ )
+ print(f"Aberration C1 = {self.aberration_C1:.0f} Ang")
+ print(f"Defocus dF = {-1*self.aberration_C1:.0f} Ang")
+ print(f"Transpose = {self.transpose}")
+
+ if fit_CTF_FFT or fit_BF_shifts:
+ print()
+ print("Refined Aberration coefficients")
+ print("-------------------------------")
+ print("aberration radial angular dir. coefs")
+ print("name order order Ang ")
+ print("---------- ------- ------- ---- -----")
+
+ for a0 in range(self._aberrations_mn.shape[0]):
+ m, n, a = self._aberrations_mn[a0]
+ name = _aberration_names.get((m, n), " -- ")
+ if n == 0:
+ print(
+ name
+ + " "
+ + str(m + 1)
+ + " 0 - "
+ + str(np.round(self._aberrations_coefs[a0]).astype("int"))
+ )
+ elif a == 0:
+ print(
+ name
+ + " "
+ + str(m + 1)
+ + " "
+ + str(n)
+ + " x "
+ + str(np.round(self._aberrations_coefs[a0]).astype("int"))
+ )
+ else:
+ print(
+ name
+ + " "
+ + str(m + 1)
+ + " "
+ + str(n)
+ + " y "
+ + str(np.round(self._aberrations_coefs[a0]).astype("int"))
+ )
+
+ if self._device == "gpu":
+ xp._default_memory_pool.free_all_blocks()
+ xp.clear_memo()
+
+ def _calculate_CTF(self, alpha_shape, sampling, *coefs):
+ xp = self._xp
+
+ # FFT coordinates
+ sx, sy = sampling
+ qx = xp.fft.fftfreq(alpha_shape[0], sx)
+ qy = xp.fft.fftfreq(alpha_shape[1], sy)
+ qr2 = qx[:, None] ** 2 + qy[None, :] ** 2
+
+ alpha = xp.sqrt(qr2) * self._wavelength
+ theta = xp.arctan2(qy[None, :], qx[:, None])
+
+ # Aberration basis
+ aberrations_basis = xp.zeros((alpha.size, self._aberrations_num))
+ for a0 in range(self._aberrations_num):
+ m, n, a = self._aberrations_mn[a0]
+ if n == 0:
+ # Radially symmetric basis
+ aberrations_basis[:, a0] = (alpha ** (m + 1) / (m + 1)).ravel()
+
+ elif a == 0:
+ # cos coef
+ aberrations_basis[:, a0] = (
+ alpha ** (m + 1) * xp.cos(n * theta) / (m + 1)
+ ).ravel()
+ else:
+ # sin coef
+ aberrations_basis[:, a0] = (
+ alpha ** (m + 1) * xp.sin(n * theta) / (m + 1)
+ ).ravel()
+
+ # global scaling
+ aberrations_basis *= 2 * np.pi / self._wavelength
+
+ chi = xp.zeros_like(aberrations_basis[:, 0])
+
+ for a0 in range(len(coefs)):
+ chi += coefs[a0] * aberrations_basis[:, a0]
+
+ return xp.reshape(chi, alpha_shape)
+
+ def aberration_correct(
+ self,
+ use_CTF_fit=None,
+ plot_corrected_phase: bool = True,
+ k_info_limit: float = None,
+ k_info_power: float = 1.0,
+ Wiener_filter=False,
+ Wiener_signal_noise_ratio: float = 1.0,
+ Wiener_filter_low_only: bool = False,
+ upsampled: bool = True,
+ **kwargs,
+ ):
+ """
+ CTF correction of the phase image using the measured defocus aberration.
+
+ Parameters
+ ----------
+ use_CTF_fit: bool, optional
+ Use the CTF fitted to the zero crossings of the FFT.
+ If None, defaults to True when a fitted CTF surface is available
+ plot_corrected_phase: bool, optional
+ If True, the CTF-corrected phase is plotted
+ k_info_limit: float, optional
+ maximum allowed frequency in butterworth filter
+ k_info_power: float, optional
+ power of butterworth filter
+ Wiener_filter: bool, optional
+ Use Wiener filtering instead of CTF sign correction.
+ Wiener_signal_noise_ratio: float, optional
+ Signal-to-noise ratio at k = 0 for the Wiener filter
+ Wiener_filter_low_only: bool, optional
+ Apply Wiener filtering only to the CTF portions before the first CTF maximum.
+ upsampled: bool, optional
+ If True and an upsampled BF reconstruction is available, that image is corrected instead
+ """
+
+ xp = self._xp
+ asnumpy = self._asnumpy
+
+ if not hasattr(self, "aberration_C1"):
+ raise ValueError(
+ (
+ "CTF correction is meant to be ran after alignment and aberration fitting. "
+ "Please run the `reconstruct()` and `aberration_fit()` functions first."
+ )
+ )
+
+ if upsampled and hasattr(self, "_kde_upsample_factor"):
+ im = self._recon_BF_subpixel_aligned
+ sx = self._scan_sampling[0] / self._kde_upsample_factor
+ sy = self._scan_sampling[1] / self._kde_upsample_factor
+ else:
+ upsampled = False
+ im = self._recon_BF
+ sx = self._scan_sampling[0]
+ sy = self._scan_sampling[1]
+
+ # Fourier coordinates
+ kx = xp.fft.fftfreq(im.shape[0], sx)
+ ky = xp.fft.fftfreq(im.shape[1], sy)
+ kra2 = (kx[:, None]) ** 2 + (ky[None, :]) ** 2
+
+ if use_CTF_fit is None:
+ if hasattr(self, "_aberrations_surface_shape"):
+ use_CTF_fit = True
+
+ if use_CTF_fit:
+ sin_chi = np.sin(
+ self._calculate_CTF(im.shape, (sx, sy), *self._aberrations_coefs)
+ )
+
+ CTF_corr = xp.sign(sin_chi)
+ CTF_corr[0, 0] = 0
+
+ # apply correction to mean reconstructed BF image
+ im_fft_corr = xp.fft.fft2(im) * CTF_corr
+
+ # if needed, add low pass filter output image
+ if k_info_limit is not None:
+ im_fft_corr /= 1 + (kra2**k_info_power) / (
+ (k_info_limit) ** (2 * k_info_power)
+ )
+ else:
+ # CTF
+ sin_chi = xp.sin((xp.pi * self._wavelength * self.aberration_C1) * kra2)
+
+ if Wiener_filter:
+ if k_info_limit is None:
+ raise ValueError(
+ "Wiener filtering requires k_info_limit to be set."
+ )
+ SNR_inv = (
+ xp.sqrt(
+ 1
+ + (kra2**k_info_power)
+ / ((k_info_limit) ** (2 * k_info_power))
+ )
+ / Wiener_signal_noise_ratio
+ )
+ CTF_corr = xp.sign(sin_chi) / (sin_chi**2 + SNR_inv)
+ if Wiener_filter_low_only:
+ # limit the Wiener filter to the part of the CTF before the first maximum
+ k_thresh = 1 / xp.sqrt(
+ 2.0 * self._wavelength * xp.abs(self.aberration_C1)
+ )
+ k_mask = kra2 >= k_thresh**2
+ CTF_corr[k_mask] = xp.sign(sin_chi[k_mask])
+
+ # apply correction to mean reconstructed BF image
+ im_fft_corr = xp.fft.fft2(im) * CTF_corr
+
+ else:
+ # CTF without tilt correction (beyond the parallax operator)
+ CTF_corr = xp.sign(sin_chi)
+ CTF_corr[0, 0] = 0
+
+ # apply correction to mean reconstructed BF image
+ im_fft_corr = xp.fft.fft2(im) * CTF_corr
+
+ # if needed, add low pass filter output image
+ if k_info_limit is not None:
+ im_fft_corr /= 1 + (kra2**k_info_power) / (
+ (k_info_limit) ** (2 * k_info_power)
+ )
+
+ # Output phase image
+ self._recon_phase_corrected = xp.real(xp.fft.ifft2(im_fft_corr))
+ self.recon_phase_corrected = asnumpy(self._recon_phase_corrected)
+
+ if self._device == "gpu":
+ xp._default_memory_pool.free_all_blocks()
+ xp.clear_memo()
+
+ # plotting
+ if plot_corrected_phase:
+ figsize = kwargs.pop("figsize", (6, 6))
+ cmap = kwargs.pop("cmap", "magma")
+
+ fig, ax = plt.subplots(figsize=figsize)
+
+ cropped_object = self._crop_padded_object(
+ self._recon_phase_corrected, upsampled=upsampled
+ )
+
+ extent = [
+ 0,
+ sy * cropped_object.shape[1],
+ sx * cropped_object.shape[0],
+ 0,
+ ]
+
+ ax.imshow(
+ cropped_object,
+ extent=extent,
+ cmap=cmap,
+ **kwargs,
+ )
+
+ ax.set_ylabel("x [A]")
+ ax.set_xlabel("y [A]")
+ ax.set_title("Parallax-Corrected Phase Image")
+
+ def depth_section(
+ self,
+ depth_angstroms=np.arange(-250, 260, 100),
+ plot_depth_sections=True,
+ k_info_limit: float = None,
+ k_info_power: float = 1.0,
+ progress_bar=True,
+ **kwargs,
+ ):
+ """
+ Computes a stack of CTF-corrected BF images, each focused at one of the specified depths.
+
+ Parameters
+ ----------
+ depth_angstroms: np.array
+ Defocus depths, in Angstroms, at which to compute the sections
+ plot_depth_sections: bool, optional
+ If True, the depth sections are plotted
+ k_info_limit: float, optional
+ maximum allowed frequency in butterworth filter
+ k_info_power: float, optional
+ power of butterworth filter
+ progress_bar: bool, optional
+ If True, a progress bar is displayed during the depth loop
+
+
+ Returns
+ -------
+ stack_depth: np.array
+ stack of phase images at different depths with shape [depth Nx Ny]
+
+ """
+
+ xp = self._xp
+ asnumpy = self._asnumpy
+ depth_angstroms = xp.atleast_1d(depth_angstroms)
+
+ if not hasattr(self, "aberration_C1"):
+ raise ValueError(
+ (
+ "Depth sectioning is meant to be ran after alignment and aberration fitting. "
+ "Please run the `reconstruct()` and `aberration_fit()` functions first."
+ )
+ )
+
+ # Fourier coordinates
+ kx = xp.fft.fftfreq(self._recon_BF.shape[0], self._scan_sampling[0])
+ ky = xp.fft.fftfreq(self._recon_BF.shape[1], self._scan_sampling[1])
+ kra2 = (kx[:, None]) ** 2 + (ky[None, :]) ** 2
+
+ # information limit
+ if k_info_limit is not None:
+ k_filt = 1 / (
+ 1 + (kra2**k_info_power) / ((k_info_limit) ** (2 * k_info_power))
+ )
+
+ # init
+ stack_depth = xp.zeros(
+ (depth_angstroms.shape[0], self._recon_BF.shape[0], self._recon_BF.shape[1])
+ )
+
+ # plotting
+ if plot_depth_sections:
+ num_plots = depth_angstroms.shape[0]
+ nrows = int(np.sqrt(num_plots))
+ ncols = int(np.ceil(num_plots / nrows))
+
+ spec = GridSpec(
+ ncols=ncols,
+ nrows=nrows,
+ hspace=0.15,
+ wspace=0.15,
+ )
+
+ figsize = kwargs.pop("figsize", (4 * ncols, 4 * nrows))
+ cmap = kwargs.pop("cmap", "magma")
+
+ fig = plt.figure(figsize=figsize)
+
+ # main loop
+ for a0 in tqdmnd(
+ depth_angstroms.shape[0],
+ desc="Depth sectioning ",
+ unit="plane",
+ disable=not progress_bar,
+ ):
+ dz = depth_angstroms[a0]
+
+ # Parallax
+ im_depth = xp.zeros_like(self._recon_BF, dtype="complex")
+ for a1 in range(self._stack_BF.shape[0]):
+ dx = self._probe_angles[a1, 0] * dz
+ dy = self._probe_angles[a1, 1] * dz
+ im_depth += xp.fft.fft2(self._stack_BF[a1]) * xp.exp(
+ self._qx_shift * dx + self._qy_shift * dy
+ )
+
+ # CTF correction
+ sin_chi = xp.sin(
+ (xp.pi * self._wavelength * (self.aberration_C1 + dz)) * kra2
+ )
+ CTF_corr = xp.sign(sin_chi)
+ CTF_corr[0, 0] = 0
+ if k_info_limit is not None:
+ CTF_corr *= k_filt
+
+ # apply correction to mean reconstructed BF image
+ stack_depth[a0] = (
+ xp.real(xp.fft.ifft2(im_depth * CTF_corr)) / self._stack_BF.shape[0]
+ )
+
+ if plot_depth_sections:
+ row_index, col_index = np.unravel_index(a0, (nrows, ncols))
+ ax = fig.add_subplot(spec[row_index, col_index])
+
+ cropped_object = self._crop_padded_object(asnumpy(stack_depth[a0]))
+
+ extent = [
+ 0,
+ self._scan_sampling[1] * cropped_object.shape[1],
+ self._scan_sampling[0] * cropped_object.shape[0],
+ 0,
+ ]
+
+ ax.imshow(
+ cropped_object,
+ extent=extent,
+ cmap=cmap,
+ **kwargs,
+ )
+
+ ax.set_xticks([])
+ ax.set_yticks([])
+ ax.set_title(f"Depth section: {dz}A")
+
+ if self._device == "gpu":
+ xp = self._xp
+ xp._default_memory_pool.free_all_blocks()
+ xp.clear_memo()
+
+ return stack_depth
+
+ def _crop_padded_object(
+ self,
+ padded_object: np.ndarray,
+ remaining_padding: int = 0,
+ upsampled: bool = False,
+ ):
+ """
+ Utility function to crop padded object
+
+ Parameters
+ ----------
+ padded_object: np.ndarray
+ Padded object to be cropped
+ remaining_padding: int, optional
+ Padding to leave uncropped
+ upsampled: bool, optional
+ If True, the object is assumed upsampled and the padding is scaled by
+ the KDE upsampling factor
+
+ Returns
+ -------
+ cropped_object: np.ndarray
+ Cropped object
+
+ """
+
+ asnumpy = self._asnumpy
+
+ if upsampled:
+ pad_x = np.round(
+ self._object_padding_px[0] / 2 * self._kde_upsample_factor
+ ).astype("int")
+ pad_y = np.round(
+ self._object_padding_px[1] / 2 * self._kde_upsample_factor
+ ).astype("int")
+ else:
+ pad_x = self._object_padding_px[0] // 2
+ pad_y = self._object_padding_px[1] // 2
+
+ pad_x -= remaining_padding
+ pad_y -= remaining_padding
+
+ return asnumpy(padded_object[pad_x:-pad_x, pad_y:-pad_y])
+
+ def _visualize_figax(
+ self,
+ fig,
+ ax,
+ remaining_padding: int = 0,
+ upsampled: bool = False,
+ **kwargs,
+ ):
+ """
+ Utility function to visualize bright field average on given fig/ax
+
+ Parameters
+ ----------
+ fig: Figure
+ Matplotlib figure ax lives in
+ ax: Axes
+ Matplotlib axes to plot bright field average in
+ remaining_padding: int, optional
+ Padding to leave uncropped
+
+ """
+
+ cmap = kwargs.pop("cmap", "magma")
+
+ if upsampled:
+ cropped_object = self._crop_padded_object(
+ self._recon_BF_subpixel_aligned, remaining_padding, upsampled
+ )
+
+ extent = [
+ 0,
+ self._scan_sampling[1]
+ * cropped_object.shape[1]
+ / self._kde_upsample_factor,
+ self._scan_sampling[0]
+ * cropped_object.shape[0]
+ / self._kde_upsample_factor,
+ 0,
+ ]
+
+ else:
+ cropped_object = self._crop_padded_object(self._recon_BF, remaining_padding)
+
+ extent = [
+ 0,
+ self._scan_sampling[1] * cropped_object.shape[1],
+ self._scan_sampling[0] * cropped_object.shape[0],
+ 0,
+ ]
+
+ ax.imshow(
+ cropped_object,
+ extent=extent,
+ cmap=cmap,
+ **kwargs,
+ )
+
+ def show_shifts(
+ self,
+ scale_arrows=1,
+ plot_arrow_freq=1,
+ plot_rotated_shifts=True,
+ **kwargs,
+ ):
+ """
+ Utility function to visualize bright field disk pixel shifts
+
+ Parameters
+ ----------
+ scale_arrows: float, optional
+ Scale to multiply shifts by
+ plot_arrow_freq: int, optional
+ Frequency of shifts to plot in quiver plot
+ plot_rotated_shifts: bool, optional
+ If True, a second quiver plot with the shifts rotated into real-space
+ coordinates is also shown
+ """
+
+ xp = self._xp
+ asnumpy = self._asnumpy
+
+ color = kwargs.pop("color", (1, 0, 0, 1))
+ if plot_rotated_shifts and hasattr(self, "rotation_Q_to_R_rads"):
+ figsize = kwargs.pop("figsize", (8, 4))
+ fig, ax = plt.subplots(1, 2, figsize=figsize)
+ scaling_factor = (
+ xp.array(self._reciprocal_sampling)
+ / xp.array(self._scan_sampling)
+ * scale_arrows
+ )
+ rotated_shifts = self._xy_shifts_Ang * scaling_factor
+
+ else:
+ figsize = kwargs.pop("figsize", (4, 4))
+ fig, ax = plt.subplots(figsize=figsize)
+
+ shifts = self._xy_shifts * scale_arrows * self._reciprocal_sampling[0]
+
+ dp_mask_ind = xp.nonzero(self._dp_mask)
+ yy, xx = xp.meshgrid(
+ xp.arange(self._dp_mean.shape[1]), xp.arange(self._dp_mean.shape[0])
+ )
+ freq_mask = xp.logical_and(xx % plot_arrow_freq == 0, yy % plot_arrow_freq == 0)
+ masked_ind = xp.logical_and(freq_mask, self._dp_mask)
+ plot_ind = masked_ind[dp_mask_ind]
+
+ kr_max = xp.max(self._kr)
+ if plot_rotated_shifts and hasattr(self, "rotation_Q_to_R_rads"):
+ ax[0].quiver(
+ asnumpy(self._kxy[plot_ind, 1]),
+ asnumpy(self._kxy[plot_ind, 0]),
+ asnumpy(shifts[plot_ind, 1]),
+ asnumpy(shifts[plot_ind, 0]),
+ color=color,
+ angles="xy",
+ scale_units="xy",
+ scale=1,
+ **kwargs,
+ )
+
+ ax[0].set_xlim([-1.2 * kr_max, 1.2 * kr_max])
+ ax[0].set_ylim([-1.2 * kr_max, 1.2 * kr_max])
+ ax[0].set_title("Measured Bright Field Shifts")
+ ax[0].set_ylabel(r"$k_x$ [$A^{-1}$]")
+ ax[0].set_xlabel(r"$k_y$ [$A^{-1}$]")
+ ax[0].set_aspect("equal")
+
+ # passive coordinate rotation
+ tf_T = AffineTransform(angle=-self.rotation_Q_to_R_rads)
+ rotated_kxy = tf_T(self._kxy[plot_ind], xp=xp)
+ ax[1].quiver(
+ asnumpy(rotated_kxy[:, 1]),
+ asnumpy(rotated_kxy[:, 0]),
+ asnumpy(rotated_shifts[plot_ind, 1]),
+ asnumpy(rotated_shifts[plot_ind, 0]),
+ angles="xy",
+ scale_units="xy",
+ scale=1,
+ **kwargs,
+ )
+
+ ax[1].set_xlim([-1.2 * kr_max, 1.2 * kr_max])
+ ax[1].set_ylim([-1.2 * kr_max, 1.2 * kr_max])
+ ax[1].set_title("Rotated Bright Field Shifts")
+ ax[1].set_ylabel(r"$k_x$ [$A^{-1}$]")
+ ax[1].set_xlabel(r"$k_y$ [$A^{-1}$]")
+ ax[1].set_aspect("equal")
+ else:
+ ax.quiver(
+ asnumpy(self._kxy[plot_ind, 1]),
+ asnumpy(self._kxy[plot_ind, 0]),
+ asnumpy(shifts[plot_ind, 1]),
+ asnumpy(shifts[plot_ind, 0]),
+ color=color,
+ angles="xy",
+ scale_units="xy",
+ scale=1,
+ **kwargs,
+ )
+
+ ax.set_xlim([-1.2 * kr_max, 1.2 * kr_max])
+ ax.set_ylim([-1.2 * kr_max, 1.2 * kr_max])
+ ax.set_title("Measured BF Shifts")
+ ax.set_ylabel(r"$k_x$ [$A^{-1}$]")
+ ax.set_xlabel(r"$k_y$ [$A^{-1}$]")
+ ax.set_aspect("equal")
+
+ fig.tight_layout()
+
+ def visualize(
+ self,
+ **kwargs,
+ ):
+ """
+ Visualization function for bright field average
+
+ Returns
+ --------
+ self: BFReconstruction
+ Self to accommodate chaining
+ """
+
+ figsize = kwargs.pop("figsize", (6, 6))
+
+ fig, ax = plt.subplots(figsize=figsize)
+
+ self._visualize_figax(fig, ax, **kwargs)
+
+ ax.set_ylabel("x [A]")
+ ax.set_xlabel("y [A]")
+ ax.set_title("Reconstructed Bright Field Image")
+
+ return self
+
+ @property
+ def object_cropped(self):
+ """cropped object"""
+ if hasattr(self, "_recon_phase_corrected"):
+ if hasattr(self, "_kde_upsample_factor"):
+ return self._crop_padded_object(
+ self._recon_phase_corrected, upsampled=True
+ )
+ else:
+ return self._crop_padded_object(self._recon_phase_corrected)
+ else:
+ if hasattr(self, "_kde_upsample_factor"):
+ return self._crop_padded_object(
+ self._recon_BF_subpixel_aligned, upsampled=True
+ )
+ else:
+ return self._crop_padded_object(self._recon_BF)
diff --git a/py4DSTEM/process/phase/iterative_ptychographic_constraints.py b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py
new file mode 100644
index 000000000..d29aa1747
--- /dev/null
+++ b/py4DSTEM/process/phase/iterative_ptychographic_constraints.py
@@ -0,0 +1,638 @@
+import warnings
+
+import numpy as np
+import pylops
+from py4DSTEM.process.phase.utils import (
+ array_slice,
+ estimate_global_transformation_ransac,
+ fft_shift,
+ fit_aberration_surface,
+ regularize_probe_amplitude,
+)
+from py4DSTEM.process.utils import get_CoM
+
+
+class PtychographicConstraints:
+ """
+ Container class for PtychographicReconstruction methods.
+ """
+
+ def _object_threshold_constraint(self, current_object, pure_phase_object):
+ """
+ Ptychographic threshold constraint.
+ Used for avoiding the scaling ambiguity between probe and object.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ pure_phase_object: bool
+ If True, object amplitude is set to unity
+
+ Returns
+ --------
+ constrained_object: np.ndarray
+ Constrained object estimate
+ """
+ xp = self._xp
+ phase = xp.angle(current_object)
+
+ if pure_phase_object:
+ amplitude = 1.0
+ else:
+ amplitude = xp.minimum(xp.abs(current_object), 1.0)
+
+ return amplitude * xp.exp(1.0j * phase)
+
+ def _object_shrinkage_constraint(self, current_object, shrinkage_rad, object_mask):
+ """
+ Ptychographic shrinkage constraint.
+ Used to ensure electrostatic potential is positive.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ shrinkage_rad: float
+ Phase shift in radians to be subtracted from the potential at each iteration
+ object_mask: np.ndarray (boolean)
+ If not None, used to calculate additional shrinkage using masked-mean of object
+
+ Returns
+ --------
+ constrained_object: np.ndarray
+ Constrained object estimate
+ """
+ xp = self._xp
+
+ if self._object_type == "complex":
+ phase = xp.angle(current_object)
+ amp = xp.abs(current_object)
+
+ if object_mask is not None:
+ shrinkage_rad += phase[..., object_mask].mean()
+
+ phase -= shrinkage_rad
+
+ current_object = amp * xp.exp(1.0j * phase)
+ else:
+ if object_mask is not None:
+ shrinkage_rad += current_object[..., object_mask].mean()
+
+ current_object -= shrinkage_rad
+
+ return current_object
+
+ def _object_positivity_constraint(self, current_object):
+ """
+ Ptychographic positivity constraint.
+ Used to ensure potential is positive.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+
+ Returns
+ --------
+ constrained_object: np.ndarray
+ Constrained object estimate
+ """
+ xp = self._xp
+
+ return xp.maximum(current_object, 0.0)
+
+ def _object_gaussian_constraint(
+ self, current_object, gaussian_filter_sigma, pure_phase_object
+ ):
+ """
+ Ptychographic smoothness constraint.
+ Used for blurring object.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ gaussian_filter_sigma: float
+ Standard deviation of gaussian kernel in A
+ pure_phase_object: bool
+ If True, gaussian blur performed on phase only
+
+ Returns
+ --------
+ constrained_object: np.ndarray
+ Constrained object estimate
+ """
+ xp = self._xp
+ gaussian_filter = self._gaussian_filter
+ gaussian_filter_sigma /= self.sampling[0]
+
+ if pure_phase_object:
+ phase = xp.angle(current_object)
+ phase = gaussian_filter(phase, gaussian_filter_sigma)
+ current_object = xp.exp(1.0j * phase)
+ else:
+ current_object = gaussian_filter(current_object, gaussian_filter_sigma)
+
+ return current_object
+
+ def _object_butterworth_constraint(
+ self,
+ current_object,
+ q_lowpass,
+ q_highpass,
+ butterworth_order,
+ ):
+ """
+ Ptychographic butterworth filter.
+ Used for low/high-pass filtering object.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ q_lowpass: float
+ Cut-off frequency in A^-1 for low-pass butterworth filter
+ q_highpass: float
+ Cut-off frequency in A^-1 for high-pass butterworth filter
+ butterworth_order: float
+ Butterworth filter order. Smaller gives a smoother filter
+
+ Returns
+ --------
+ constrained_object: np.ndarray
+ Constrained object estimate
+ """
+ xp = self._xp
+ qx = xp.fft.fftfreq(current_object.shape[0], self.sampling[0])
+ qy = xp.fft.fftfreq(current_object.shape[1], self.sampling[1])
+
+ qya, qxa = xp.meshgrid(qy, qx)
+ qra = xp.sqrt(qxa**2 + qya**2)
+
+ env = xp.ones_like(qra)
+ if q_highpass:
+ env *= 1 - 1 / (1 + (qra / q_highpass) ** (2 * butterworth_order))
+ if q_lowpass:
+ env *= 1 / (1 + (qra / q_lowpass) ** (2 * butterworth_order))
+
+ current_object_mean = xp.mean(current_object)
+ current_object -= current_object_mean
+ current_object = xp.fft.ifft2(xp.fft.fft2(current_object) * env)
+ current_object += current_object_mean
+
+ if self._object_type == "potential":
+ current_object = xp.real(current_object)
+
+ return current_object
+
+ def _object_denoise_tv_pylops(self, current_object, weight, iterations):
+ """
+ Performs second order TV denoising along x and y
+
+ Parameters
+ ----------
+ current_object: np.ndarray
+ Current object estimate
+ weight : float
+ Denoising weight. The greater `weight`, the more denoising (at
+ the expense of fidelity to `input`).
+ iterations: float
+ Number of iterations to run in denoising algorithm.
+ `niter_out` in pylops
+
+ Returns
+ -------
+ constrained_object: np.ndarray
+ Constrained object estimate
+
+ """
+ xp = self._xp
+
+ if xp.iscomplexobj(current_object):
+ current_object_tv = current_object
+ warnings.warn(
+ ("TV denoising is currently only supported for potential objects."),
+ UserWarning,
+ )
+
+ else:
+ nx, ny = current_object.shape
+ niter_out = iterations
+ niter_in = 1
+ Iop = pylops.Identity(nx * ny)
+ xy_laplacian = pylops.Laplacian(
+ (nx, ny), axes=(0, 1), edge=False, kind="backward"
+ )
+
+ l1_regs = [xy_laplacian]
+
+ current_object_tv = pylops.optimization.sparsity.splitbregman(
+ Op=Iop,
+ y=current_object.ravel(),
+ RegsL1=l1_regs,
+ niter_outer=niter_out,
+ niter_inner=niter_in,
+ epsRL1s=[weight],
+ tol=1e-4,
+ tau=1.0,
+ show=False,
+ )[0]
+
+ current_object_tv = current_object_tv.reshape(current_object.shape)
+
+ return current_object_tv
+
+ def _object_denoise_tv_chambolle(
+ self,
+ current_object,
+ weight,
+ axis,
+ pad_object,
+ eps=2.0e-4,
+ max_num_iter=200,
+ scaling=None,
+ ):
+ """
+ Perform total-variation denoising on n-dimensional images.
+
+ Parameters
+ ----------
+ current_object: np.ndarray
+ Current object estimate
+ weight : float, optional
+ Denoising weight. The greater `weight`, the more denoising (at
+ the expense of fidelity to `input`).
+ axis: int or tuple
+ Axis for denoising, if None uses all axes
+ pad_object: bool
+ if True, pads object with zeros along axes of blurring
+ eps : float, optional
+ Relative difference of the value of the cost function that determines
+ the stop criterion. The algorithm stops when:
+
+ (E_(n-1) - E_n) < eps * E_0
+
+ max_num_iter : int, optional
+ Maximal number of iterations used for the optimization.
+ scaling : tuple, optional
+ Scale weight of tv denoise on different axes
+
+ Returns
+ -------
+ constrained_object: np.ndarray
+ Constrained object estimate
+
+ Notes
+ -----
+ Rudin, Osher and Fatemi algorithm.
+ Adapted skimage.restoration.denoise_tv_chambolle.
+ """
+ xp = self._xp
+ if xp.iscomplexobj(current_object):
+ updated_object = current_object
+ warnings.warn(
+ ("TV denoising is currently only supported for potential objects."),
+ UserWarning,
+ )
+ else:
+ current_object_sum = xp.sum(current_object)
+ if axis is None:
+ ndim = xp.arange(current_object.ndim).tolist()
+ elif isinstance(axis, tuple):
+ ndim = list(axis)
+ else:
+ ndim = [axis]
+
+ if pad_object:
+ pad_width = ((0, 0),) * current_object.ndim
+ pad_width = list(pad_width)
+ for ax in range(len(ndim)):
+ pad_width[ndim[ax]] = (1, 1)
+ current_object = xp.pad(
+ current_object, pad_width=pad_width, mode="constant"
+ )
+
+ p = xp.zeros(
+ (current_object.ndim,) + current_object.shape,
+ dtype=current_object.dtype,
+ )
+ g = xp.zeros_like(p)
+ d = xp.zeros_like(current_object)
+
+ i = 0
+ while i < max_num_iter:
+ if i > 0:
+ # d will be the (negative) divergence of p
+ d = -p.sum(0)
+ slices_d = [
+ slice(None),
+ ] * current_object.ndim
+ slices_p = [
+ slice(None),
+ ] * (current_object.ndim + 1)
+ for ax in range(len(ndim)):
+ slices_d[ndim[ax]] = slice(1, None)
+ slices_p[ndim[ax] + 1] = slice(0, -1)
+ slices_p[0] = ndim[ax]
+ d[tuple(slices_d)] += p[tuple(slices_p)]
+ slices_d[ndim[ax]] = slice(None)
+ slices_p[ndim[ax] + 1] = slice(None)
+ updated_object = current_object + d
+ else:
+ updated_object = current_object
+ E = (d**2).sum()
+
+ # g stores the gradients of updated_object along each axis
+ # e.g. g[0] is the first order finite difference along axis 0
+ slices_g = [
+ slice(None),
+ ] * (current_object.ndim + 1)
+ for ax in range(len(ndim)):
+ slices_g[ndim[ax] + 1] = slice(0, -1)
+ slices_g[0] = ndim[ax]
+ g[tuple(slices_g)] = xp.diff(updated_object, axis=ndim[ax])
+ slices_g[ndim[ax] + 1] = slice(None)
+ if scaling is not None:
+ scaling /= xp.max(scaling)
+ g *= xp.array(scaling)[:, xp.newaxis, xp.newaxis]
+ norm = xp.sqrt((g**2).sum(axis=0))[xp.newaxis, ...]
+ E += weight * norm.sum()
+ tau = 1.0 / (2.0 * len(ndim))
+ norm *= tau / weight
+ norm += 1.0
+ p -= tau * g
+ p /= norm
+ E /= float(current_object.size)
+ if i == 0:
+ E_init = E
+ E_previous = E
+ else:
+ if xp.abs(E_previous - E) < eps * E_init:
+ break
+ else:
+ E_previous = E
+ i += 1
+
+ if pad_object:
+ for ax in range(len(ndim)):
+ slices = array_slice(ndim[ax], current_object.ndim, 1, -1)
+ updated_object = updated_object[slices]
+ updated_object = (
+ updated_object / xp.sum(updated_object) * current_object_sum
+ )
+
+ return updated_object
+
+ def _probe_center_of_mass_constraint(self, current_probe):
+ """
+ Ptychographic center of mass constraint.
+ Used for centering corner-centered probe intensity.
+
+ Parameters
+ --------
+ current_probe: np.ndarray
+ Current probe estimate
+
+ Returns
+ --------
+ constrained_probe: np.ndarray
+ Constrained probe estimate
+ """
+ xp = self._xp
+
+ probe_intensity = xp.abs(current_probe) ** 2
+
+ probe_x0, probe_y0 = get_CoM(
+ probe_intensity, device=self._device, corner_centered=True
+ )
+ shifted_probe = fft_shift(current_probe, -xp.array([probe_x0, probe_y0]), xp)
+
+ return shifted_probe
+
+ def _probe_amplitude_constraint(
+ self, current_probe, relative_radius, relative_width
+ ):
+ """
+ Ptychographic top-hat filtering of probe.
+
+ Parameters
+ ----------
+ current_probe: np.ndarray
+ Current probe estimate
+ relative_radius: float
+ Relative location of top-hat inflection point, between 0 and 0.5
+ relative_width: float
+ Relative width of top-hat sigmoid, between 0 and 0.5
+
+ Returns
+ --------
+ constrained_probe: np.ndarray
+ Constrained probe estimate
+ """
+ xp = self._xp
+ erf = self._erf
+
+ probe_intensity = xp.abs(current_probe) ** 2
+ current_probe_sum = xp.sum(probe_intensity)
+
+ X = xp.fft.fftfreq(current_probe.shape[0])[:, None]
+ Y = xp.fft.fftfreq(current_probe.shape[1])[None]
+ r = xp.hypot(X, Y) - relative_radius
+
+ sigma = np.sqrt(np.pi) / relative_width
+ tophat_mask = 0.5 * (1 - erf(sigma * r / (1 - r**2)))
+
+ updated_probe = current_probe * tophat_mask
+ updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2)
+ normalization = xp.sqrt(current_probe_sum / updated_probe_sum)
+
+ return updated_probe * normalization
+
+ def _probe_fourier_amplitude_constraint(
+ self,
+ current_probe,
+ width_max_pixels,
+ enforce_constant_intensity,
+ ):
+ """
+ Ptychographic top-hat filtering of Fourier probe.
+
+ Parameters
+ ----------
+ current_probe: np.ndarray
+ Current probe estimate
+ width_max_pixels: float
+ Maximum allowed width, in pixels, of the fitted Fourier-amplitude envelope
+ enforce_constant_intensity: bool
+ If True, the Fourier amplitude is forced to a constant value inside the
+ fitted envelope
+
+ Returns
+ --------
+ constrained_probe: np.ndarray
+ Constrained probe estimate
+ """
+ xp = self._xp
+ asnumpy = self._asnumpy
+
+ current_probe_sum = xp.sum(xp.abs(current_probe) ** 2)
+ current_probe_fft = xp.fft.fft2(current_probe)
+
+ updated_probe_fft, _, _, _ = regularize_probe_amplitude(
+ asnumpy(current_probe_fft),
+ width_max_pixels=width_max_pixels,
+ nearest_angular_neighbor_averaging=5,
+ enforce_constant_intensity=enforce_constant_intensity,
+ corner_centered=True,
+ )
+
+ updated_probe_fft = xp.asarray(updated_probe_fft)
+ updated_probe = xp.fft.ifft2(updated_probe_fft)
+ updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2)
+ normalization = xp.sqrt(current_probe_sum / updated_probe_sum)
+
+ return updated_probe * normalization
+
+ def _probe_aperture_constraint(
+ self,
+ current_probe,
+ initial_probe_aperture,
+ ):
+ """
+ Ptychographic constraint to fix Fourier amplitude to initial aperture.
+
+ Parameters
+ ----------
+ current_probe: np.ndarray
+ Current probe estimate
+ initial_probe_aperture: np.ndarray
+ Initial probe Fourier-amplitude aperture to enforce
+
+ Returns
+ --------
+ constrained_probe: np.ndarray
+ Constrained probe estimate
+ """
+ xp = self._xp
+
+ current_probe_sum = xp.sum(xp.abs(current_probe) ** 2)
+ current_probe_fft_phase = xp.angle(xp.fft.fft2(current_probe))
+
+ updated_probe = xp.fft.ifft2(
+ xp.exp(1j * current_probe_fft_phase) * initial_probe_aperture
+ )
+ updated_probe_sum = xp.sum(xp.abs(updated_probe) ** 2)
+ normalization = xp.sqrt(current_probe_sum / updated_probe_sum)
+
+ return updated_probe * normalization
+
+ def _probe_aberration_fitting_constraint(
+ self,
+ current_probe,
+ max_angular_order,
+ max_radial_order,
+ ):
+ """
+ Ptychographic probe aberration-fitting constraint.
+ Fits a low-order aberration surface to the probe's Fourier phase and
+ constrains the phase to the fitted surface.
+
+ Parameters
+ ----------
+ current_probe: np.ndarray
+ Current probe estimate
+ max_angular_order: int
+ Max angular order for the aberration-surface fit
+ max_radial_order: int
+ Max radial order for the aberration-surface fit
+
+ Returns
+ --------
+ constrained_probe: np.ndarray
+ Constrained probe estimate
+ """
+
+ xp = self._xp
+
+ fourier_probe = xp.fft.fft2(current_probe)
+ fourier_probe_abs = xp.abs(fourier_probe)
+ sampling = self.sampling
+ energy = self._energy
+
+ fitted_angle, _ = fit_aberration_surface(
+ fourier_probe,
+ sampling,
+ energy,
+ max_angular_order,
+ max_radial_order,
+ xp=xp,
+ )
+
+ fourier_probe = fourier_probe_abs * xp.exp(-1.0j * fitted_angle)
+ current_probe = xp.fft.ifft2(fourier_probe)
+
+ return current_probe
+
+ def _positions_center_of_mass_constraint(self, current_positions):
+ """
+ Ptychographic position center of mass constraint.
+ Additionally updates vectorized indices used in _overlap_projection.
+
+ Parameters
+ ----------
+ current_positions: np.ndarray
+ Current positions estimate
+
+ Returns
+ --------
+ constrained_positions: np.ndarray
+ CoM constrained positions estimate
+ """
+ xp = self._xp
+
+ current_positions -= xp.mean(current_positions, axis=0) - self._positions_px_com
+ self._positions_px_fractional = current_positions - xp.round(current_positions)
+
+ (
+ self._vectorized_patch_indices_row,
+ self._vectorized_patch_indices_col,
+ ) = self._extract_vectorized_patch_indices()
+
+ return current_positions
+
+ def _positions_affine_transformation_constraint(
+ self, initial_positions, current_positions
+ ):
+ """
+ Constrains the updated positions to be an affine transformation of the initial scan positions,
+ composed of two scale factors, a shear, and a rotation angle.
+
+ Uses RANSAC to estimate the global transformation robustly.
+ Stores the AffineTransformation in self._tf.
+
+ Parameters
+ ----------
+ initial_positions: np.ndarray
+ Initial scan positions
+ current_positions: np.ndarray
+ Current positions estimate
+
+ Returns
+ -------
+ constrained_positions: np.ndarray
+ Affine-transform constrained positions estimate
+ """
+
+ xp = self._xp
+
+ tf, _ = estimate_global_transformation_ransac(
+ positions0=initial_positions,
+ positions1=current_positions,
+ origin=self._positions_px_com,
+ translation_allowed=True,
+ min_sample=self._num_diffraction_patterns // 10,
+ xp=xp,
+ )
+
+ self._tf = tf
+ current_positions = tf(initial_positions, origin=self._positions_px_com, xp=xp)
+
+ return current_positions
diff --git a/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py
new file mode 100644
index 000000000..233d34e45
--- /dev/null
+++ b/py4DSTEM/process/phase/iterative_simultaneous_ptychography.py
@@ -0,0 +1,3479 @@
+"""
+Module for reconstructing phase objects from 4DSTEM datasets using iterative methods,
+namely joint ptychography.
+"""
+
+import warnings
+from typing import Mapping, Sequence, Tuple
+
+import matplotlib.pyplot as plt
+import numpy as np
+from matplotlib.gridspec import GridSpec
+from mpl_toolkits.axes_grid1 import make_axes_locatable
+from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg
+
+try:
+ import cupy as cp
+except ModuleNotFoundError:
+ cp = np
+
+from emdfile import Custom, tqdmnd
+from py4DSTEM import DataCube
+from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction
+from py4DSTEM.process.phase.utils import (
+ ComplexProbe,
+ fft_shift,
+ generate_batches,
+ polar_aliases,
+ polar_symbols,
+)
+from py4DSTEM.process.utils import get_CoM, get_shifted_ar
+
+warnings.simplefilter(action="always", category=UserWarning)
+
+
+class SimultaneousPtychographicReconstruction(PtychographicReconstruction):
+ """
+ Iterative Simultaneous Ptychographic Reconstruction Class.
+
+ Diffraction intensities dimensions : (Rx,Ry,Qx,Qy) (for each measurement)
+ Reconstructed probe dimensions : (Sx,Sy)
+ Reconstructed electrostatic dimensions : (Px,Py)
+ Reconstructed magnetic dimensions : (Px,Py)
+
+ such that (Sx,Sy) is the region-of-interest (ROI) size of our probe
+ and (Px,Py) is the padded-object size we position our ROI around in.
+
+ Parameters
+ ----------
+ datacube: Sequence[DataCube]
+ Tuple of input 4D diffraction pattern intensities
+ energy: float
+ The electron energy of the wave functions in eV
+ simultaneous_measurements_mode: str, optional
+ One of '-+', '-0+', '0+', where -/0/+ refer to the sign of the magnetic potential
+ semiangle_cutoff: float, optional
+ Semiangle cutoff for the initial probe guess in mrad
+ semiangle_cutoff_pixels: float, optional
+ Semiangle cutoff for the initial probe guess in pixels
+ rolloff: float, optional
+ Semiangle rolloff for the initial probe guess
+ vacuum_probe_intensity: np.ndarray, optional
+ Vacuum probe to use as intensity aperture for initial probe guess
+ polar_parameters: dict, optional
+ Mapping from aberration symbols to their corresponding values. All aberration
+ magnitudes should be given in Å and angles should be given in radians.
+ object_padding_px: Tuple[int,int], optional
+ Pixel dimensions to pad objects with
+ If None, the padding is set to half the probe ROI dimensions
+ positions_mask: np.ndarray, optional
+ Boolean real space mask to select positions in datacube to skip for reconstruction
+ initial_object_guess: np.ndarray, optional
+        Initial guess for complex-valued object of dimensions (Px,Py)
+        If None, initialized to complex ones
+ initial_probe_guess: np.ndarray, optional
+ Initial guess for complex-valued probe of dimensions (Sx,Sy). If None,
+ initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations
+ initial_scan_positions: np.ndarray, optional
+ Probe positions in Å for each diffraction intensity
+ If None, initialized to a grid scan
+ verbose: bool, optional
+ If True, class methods will inherit this and print additional information
+ device: str, optional
+        Device the calculation will be performed on. Must be 'cpu' or 'gpu'
+ object_type: str, optional
+ The object can be reconstructed as a real potential ('potential') or a complex
+ object ('complex')
+ name: str, optional
+ Class name
+ kwargs:
+ Provide the aberration coefficients as keyword arguments.
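+
+    Examples
+    --------
+    A minimal usage sketch with illustrative values; `dc_reverse` and
+    `dc_forward` stand in for two pre-loaded DataCubes acquired with
+    opposite magnetic-potential signs:
+
+    >>> ptycho = SimultaneousPtychographicReconstruction(
+    ...     energy=300e3,
+    ...     datacube=(dc_reverse, dc_forward),
+    ...     simultaneous_measurements_mode="-+",
+    ...     semiangle_cutoff=21.4,
+    ... ).preprocess()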
+ """
+
+ # Class-specific Metadata
+ _class_specific_metadata = ("_simultaneous_measurements_mode",)
+
+ def __init__(
+ self,
+ energy: float,
+ datacube: Sequence[DataCube] = None,
+ simultaneous_measurements_mode: str = "-+",
+ semiangle_cutoff: float = None,
+ semiangle_cutoff_pixels: float = None,
+ rolloff: float = 2.0,
+ vacuum_probe_intensity: np.ndarray = None,
+ polar_parameters: Mapping[str, float] = None,
+ object_padding_px: Tuple[int, int] = None,
+ positions_mask: np.ndarray = None,
+ initial_object_guess: np.ndarray = None,
+ initial_probe_guess: np.ndarray = None,
+ initial_scan_positions: np.ndarray = None,
+ object_type: str = "complex",
+ verbose: bool = True,
+ device: str = "cpu",
+ name: str = "simultaneous_ptychographic_reconstruction",
+ **kwargs,
+ ):
+ Custom.__init__(self, name=name)
+
+ if device == "cpu":
+ self._xp = np
+ self._asnumpy = np.asarray
+ from scipy.ndimage import gaussian_filter
+
+ self._gaussian_filter = gaussian_filter
+ from scipy.special import erf
+
+ self._erf = erf
+ elif device == "gpu":
+ self._xp = cp
+ self._asnumpy = cp.asnumpy
+ from cupyx.scipy.ndimage import gaussian_filter
+
+ self._gaussian_filter = gaussian_filter
+ from cupyx.scipy.special import erf
+
+ self._erf = erf
+ else:
+ raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}")
+
+ for key in kwargs.keys():
+ if (key not in polar_symbols) and (key not in polar_aliases.keys()):
+ raise ValueError("{} not a recognized parameter".format(key))
+
+ self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols)))
+
+ if polar_parameters is None:
+ polar_parameters = {}
+
+ polar_parameters.update(kwargs)
+ self._set_polar_parameters(polar_parameters)
+
+ if object_type != "potential" and object_type != "complex":
+ raise ValueError(
+ f"object_type must be either 'potential' or 'complex', not {object_type}"
+ )
+ if positions_mask is not None and positions_mask.dtype != "bool":
+ warnings.warn(
+ ("`positions_mask` converted to `bool` array"),
+ UserWarning,
+ )
+ positions_mask = np.asarray(positions_mask, dtype="bool")
+
+ self.set_save_defaults()
+
+ # Data
+ self._datacube = datacube
+ self._object = initial_object_guess
+ self._probe = initial_probe_guess
+
+ # Common Metadata
+ self._vacuum_probe_intensity = vacuum_probe_intensity
+ self._scan_positions = initial_scan_positions
+ self._energy = energy
+ self._semiangle_cutoff = semiangle_cutoff
+ self._semiangle_cutoff_pixels = semiangle_cutoff_pixels
+ self._rolloff = rolloff
+ self._object_type = object_type
+ self._object_padding_px = object_padding_px
+ self._positions_mask = positions_mask
+ self._verbose = verbose
+ self._device = device
+ self._preprocessed = False
+
+ # Class-specific Metadata
+ self._simultaneous_measurements_mode = simultaneous_measurements_mode
+
+ def preprocess(
+ self,
+ diffraction_intensities_shape: Tuple[int, int] = None,
+ reshaping_method: str = "fourier",
+ probe_roi_shape: Tuple[int, int] = None,
+ dp_mask: np.ndarray = None,
+ fit_function: str = "plane",
+ plot_rotation: bool = True,
+ maximize_divergence: bool = False,
+ rotation_angles_deg: np.ndarray = np.arange(-89.0, 90.0, 1.0),
+ plot_probe_overlaps: bool = True,
+ force_com_rotation: float = None,
+ force_com_transpose: float = None,
+ force_com_shifts: float = None,
+ force_scan_sampling: float = None,
+ force_angular_sampling: float = None,
+ force_reciprocal_sampling: float = None,
+ object_fov_mask: np.ndarray = None,
+ crop_patterns: bool = False,
+ **kwargs,
+ ):
+ """
+ Ptychographic preprocessing step.
+ Calls the base class methods:
+
+        _extract_intensities_and_calibrations_from_datacube(),
+        _calculate_intensities_center_of_mass(),
+        _solve_for_center_of_mass_relative_rotation(),
+        _normalize_diffraction_intensities(),
+        _calculate_scan_positions_in_pixels()
+
+        Additionally, it initializes the (Px,Py)-shaped object arrays
+ and a complex probe using the specified polar parameters.
+
+ Parameters
+ ----------
+ diffraction_intensities_shape: Tuple[int,int], optional
+ Pixel dimensions (Qx',Qy') of the resampled diffraction intensities
+            If None, no resampling of diffraction intensities is performed
+        reshaping_method: str, optional
+            Method to use for reshaping, either 'bin', 'bilinear', or 'fourier' (default)
+        probe_roi_shape: Tuple[int,int], optional
+ Padded diffraction intensities shape.
+ If None, no padding is performed
+ dp_mask: ndarray, optional
+ Mask for datacube intensities (Qx,Qy)
+ fit_function: str, optional
+ 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two'
+ plot_rotation: bool, optional
+ If True, the CoM curl minimization search result will be displayed
+ maximize_divergence: bool, optional
+ If True, the divergence of the CoM gradient vector field is maximized
+        rotation_angles_deg: np.ndarray, optional
+ Array of angles in degrees to perform curl minimization over
+ plot_probe_overlaps: bool, optional
+ If True, initial probe overlaps scanned over the object will be displayed
+ force_com_rotation: float (degrees), optional
+ Force relative rotation angle between real and reciprocal space
+ force_com_transpose: bool, optional
+ Force whether diffraction intensities need to be transposed.
+ force_com_shifts: sequence of tuples of ndarrays (CoMx, CoMy)
+ Amplitudes come from diffraction patterns shifted with
+ the CoM in the upper left corner for each probe unless
+            the shifts are overridden.
+ force_scan_sampling: float, optional
+ Override DataCube real space scan pixel size calibrations, in Angstrom
+ force_angular_sampling: float, optional
+ Override DataCube reciprocal pixel size calibration, in mrad
+ force_reciprocal_sampling: float, optional
+ Override DataCube reciprocal pixel size calibration, in A^-1
+ object_fov_mask: np.ndarray (boolean)
+ Boolean mask of FOV. Used to calculate additional shrinkage of object
+ If None, probe_overlap intensity is thresholded
+ crop_patterns: bool
+ if True, crop patterns to avoid wrap around of patterns when centering
+
+ Returns
+ --------
+ self: PtychographicReconstruction
+ Self to accommodate chaining
+ """
+ xp = self._xp
+ asnumpy = self._asnumpy
+
+ # set additional metadata
+ self._diffraction_intensities_shape = diffraction_intensities_shape
+ self._reshaping_method = reshaping_method
+ self._probe_roi_shape = probe_roi_shape
+ self._dp_mask = dp_mask
+
+ if self._datacube is None:
+ raise ValueError(
+ (
+ "The preprocess() method requires a DataCube. "
+ "Please run ptycho.attach_datacube(DataCube) first."
+ )
+ )
+
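+        # map the measurement-mode string to an internal recon-mode index and
+        # the expected number of simultaneously-acquired datacubes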
+ if self._simultaneous_measurements_mode == "-+":
+ self._sim_recon_mode = 0
+ self._num_sim_measurements = 2
+ if self._verbose:
+ print(
+                    (
+                        "Magnetic vector potential sign in first measurement assumed to be negative.\n"
+                        "Magnetic vector potential sign in second measurement assumed to be positive."
+                    )
+ )
+ if len(self._datacube) != 2:
+ raise ValueError(
+ f"datacube must be a set of two measurements, not length {len(self._datacube)}."
+ )
+ if self._datacube[0].shape != self._datacube[1].shape:
+ raise ValueError("datacube intensities must be the same size.")
+ elif self._simultaneous_measurements_mode == "-0+":
+ self._sim_recon_mode = 1
+ self._num_sim_measurements = 3
+ if self._verbose:
+ print(
+                    (
+                        "Magnetic vector potential sign in first measurement assumed to be negative.\n"
+                        "Magnetic vector potential assumed to be zero in second measurement.\n"
+                        "Magnetic vector potential sign in third measurement assumed to be positive."
+                    )
+ )
+ if len(self._datacube) != 3:
+ raise ValueError(
+ f"datacube must be a set of three measurements, not length {len(self._datacube)}."
+ )
+ if (
+ self._datacube[0].shape != self._datacube[1].shape
+ or self._datacube[0].shape != self._datacube[2].shape
+ ):
+ raise ValueError("datacube intensities must be the same size.")
+ elif self._simultaneous_measurements_mode == "0+":
+ self._sim_recon_mode = 2
+ self._num_sim_measurements = 2
+ if self._verbose:
+ print(
+                    (
+                        "Magnetic vector potential assumed to be zero in first measurement.\n"
+                        "Magnetic vector potential sign in second measurement assumed to be positive."
+                    )
+ )
+ if len(self._datacube) != 2:
+ raise ValueError(
+ f"datacube must be a set of two measurements, not length {len(self._datacube)}."
+ )
+ if self._datacube[0].shape != self._datacube[1].shape:
+ raise ValueError("datacube intensities must be the same size.")
+ else:
+ raise ValueError(
+ f"simultaneous_measurements_mode must be either '-+', '-0+', or '0+', not {self._simultaneous_measurements_mode}"
+ )
+
+ if force_com_shifts is None:
+ force_com_shifts = [None, None, None]
+ elif len(force_com_shifts) == self._num_sim_measurements:
+ force_com_shifts = list(force_com_shifts)
+ else:
+ raise ValueError(
+ (
+ "force_com_shifts must be a sequence of tuples "
+ "with the same length as the datasets."
+ )
+ )
+
+ # Ensure plot_center_of_mass is not in kwargs
+ kwargs.pop("plot_center_of_mass", None)
+
+ # 1st measurement sets rotation angle and transposition
+ (
+ measurement_0,
+ self._vacuum_probe_intensity,
+ self._dp_mask,
+ force_com_shifts[0],
+ ) = self._preprocess_datacube_and_vacuum_probe(
+ self._datacube[0],
+ diffraction_intensities_shape=self._diffraction_intensities_shape,
+ reshaping_method=self._reshaping_method,
+ probe_roi_shape=self._probe_roi_shape,
+ vacuum_probe_intensity=self._vacuum_probe_intensity,
+ dp_mask=self._dp_mask,
+ com_shifts=force_com_shifts[0],
+ )
+
+ intensities_0 = self._extract_intensities_and_calibrations_from_datacube(
+ measurement_0,
+ require_calibrations=True,
+ force_scan_sampling=force_scan_sampling,
+ force_angular_sampling=force_angular_sampling,
+ force_reciprocal_sampling=force_reciprocal_sampling,
+ )
+
+ (
+ com_measured_x_0,
+ com_measured_y_0,
+ com_fitted_x_0,
+ com_fitted_y_0,
+ com_normalized_x_0,
+ com_normalized_y_0,
+ ) = self._calculate_intensities_center_of_mass(
+ intensities_0,
+ dp_mask=self._dp_mask,
+ fit_function=fit_function,
+ com_shifts=force_com_shifts[0],
+ )
+
+ (
+ self._rotation_best_rad,
+ self._rotation_best_transpose,
+ _com_x_0,
+ _com_y_0,
+ com_x_0,
+ com_y_0,
+ ) = self._solve_for_center_of_mass_relative_rotation(
+ com_measured_x_0,
+ com_measured_y_0,
+ com_normalized_x_0,
+ com_normalized_y_0,
+ rotation_angles_deg=rotation_angles_deg,
+ plot_rotation=plot_rotation,
+ plot_center_of_mass=False,
+ maximize_divergence=maximize_divergence,
+ force_com_rotation=force_com_rotation,
+ force_com_transpose=force_com_transpose,
+ **kwargs,
+ )
+
+ (
+ amplitudes_0,
+ mean_diffraction_intensity_0,
+ ) = self._normalize_diffraction_intensities(
+ intensities_0,
+ com_fitted_x_0,
+ com_fitted_y_0,
+ crop_patterns,
+ self._positions_mask,
+ )
+
+        # explicitly delete intermediate variables
+ del (
+ intensities_0,
+ com_measured_x_0,
+ com_measured_y_0,
+ com_fitted_x_0,
+ com_fitted_y_0,
+ com_normalized_x_0,
+ com_normalized_y_0,
+ _com_x_0,
+ _com_y_0,
+ com_x_0,
+ com_y_0,
+ )
+
+ # 2nd measurement
+ (
+ measurement_1,
+ _,
+ _,
+ force_com_shifts[1],
+ ) = self._preprocess_datacube_and_vacuum_probe(
+ self._datacube[1],
+ diffraction_intensities_shape=self._diffraction_intensities_shape,
+ reshaping_method=self._reshaping_method,
+ probe_roi_shape=self._probe_roi_shape,
+ vacuum_probe_intensity=None,
+ dp_mask=None,
+ com_shifts=force_com_shifts[1],
+ )
+
+ intensities_1 = self._extract_intensities_and_calibrations_from_datacube(
+ measurement_1,
+ require_calibrations=True,
+ force_scan_sampling=force_scan_sampling,
+ force_angular_sampling=force_angular_sampling,
+ force_reciprocal_sampling=force_reciprocal_sampling,
+ )
+
+ (
+ com_measured_x_1,
+ com_measured_y_1,
+ com_fitted_x_1,
+ com_fitted_y_1,
+ com_normalized_x_1,
+ com_normalized_y_1,
+ ) = self._calculate_intensities_center_of_mass(
+ intensities_1,
+ dp_mask=self._dp_mask,
+ fit_function=fit_function,
+ com_shifts=force_com_shifts[1],
+ )
+
+ (
+ _,
+ _,
+ _com_x_1,
+ _com_y_1,
+ com_x_1,
+ com_y_1,
+ ) = self._solve_for_center_of_mass_relative_rotation(
+ com_measured_x_1,
+ com_measured_y_1,
+ com_normalized_x_1,
+ com_normalized_y_1,
+ rotation_angles_deg=rotation_angles_deg,
+ plot_rotation=plot_rotation,
+ plot_center_of_mass=False,
+ maximize_divergence=maximize_divergence,
+ force_com_rotation=np.rad2deg(self._rotation_best_rad),
+ force_com_transpose=self._rotation_best_transpose,
+ **kwargs,
+ )
+
+ (
+ amplitudes_1,
+ mean_diffraction_intensity_1,
+ ) = self._normalize_diffraction_intensities(
+ intensities_1,
+ com_fitted_x_1,
+ com_fitted_y_1,
+ crop_patterns,
+ self._positions_mask,
+ )
+
+        # explicitly delete intermediate variables
+ del (
+ intensities_1,
+ com_measured_x_1,
+ com_measured_y_1,
+ com_fitted_x_1,
+ com_fitted_y_1,
+ com_normalized_x_1,
+ com_normalized_y_1,
+ _com_x_1,
+ _com_y_1,
+ com_x_1,
+ com_y_1,
+ )
+
+ # Optionally, 3rd measurement
+ if self._num_sim_measurements == 3:
+ (
+ measurement_2,
+ _,
+ _,
+ force_com_shifts[2],
+ ) = self._preprocess_datacube_and_vacuum_probe(
+ self._datacube[2],
+ diffraction_intensities_shape=self._diffraction_intensities_shape,
+ reshaping_method=self._reshaping_method,
+ probe_roi_shape=self._probe_roi_shape,
+ vacuum_probe_intensity=None,
+ dp_mask=None,
+ com_shifts=force_com_shifts[2],
+ )
+
+ intensities_2 = self._extract_intensities_and_calibrations_from_datacube(
+ measurement_2,
+ require_calibrations=True,
+ force_scan_sampling=force_scan_sampling,
+ force_angular_sampling=force_angular_sampling,
+ force_reciprocal_sampling=force_reciprocal_sampling,
+ )
+
+ (
+ com_measured_x_2,
+ com_measured_y_2,
+ com_fitted_x_2,
+ com_fitted_y_2,
+ com_normalized_x_2,
+ com_normalized_y_2,
+ ) = self._calculate_intensities_center_of_mass(
+ intensities_2,
+ dp_mask=self._dp_mask,
+ fit_function=fit_function,
+ com_shifts=force_com_shifts[2],
+ )
+
+ (
+ _,
+ _,
+ _com_x_2,
+ _com_y_2,
+ com_x_2,
+ com_y_2,
+ ) = self._solve_for_center_of_mass_relative_rotation(
+ com_measured_x_2,
+ com_measured_y_2,
+ com_normalized_x_2,
+ com_normalized_y_2,
+ rotation_angles_deg=rotation_angles_deg,
+ plot_rotation=plot_rotation,
+ plot_center_of_mass=False,
+ maximize_divergence=maximize_divergence,
+ force_com_rotation=np.rad2deg(self._rotation_best_rad),
+ force_com_transpose=self._rotation_best_transpose,
+ **kwargs,
+ )
+
+ (
+ amplitudes_2,
+ mean_diffraction_intensity_2,
+ ) = self._normalize_diffraction_intensities(
+ intensities_2,
+ com_fitted_x_2,
+ com_fitted_y_2,
+ crop_patterns,
+ self._positions_mask,
+ )
+
+            # explicitly delete intermediate variables
+ del (
+ intensities_2,
+ com_measured_x_2,
+ com_measured_y_2,
+ com_fitted_x_2,
+ com_fitted_y_2,
+ com_normalized_x_2,
+ com_normalized_y_2,
+ _com_x_2,
+ _com_y_2,
+ com_x_2,
+ com_y_2,
+ )
+
+ self._amplitudes = (amplitudes_0, amplitudes_1, amplitudes_2)
+ self._mean_diffraction_intensity = (
+ mean_diffraction_intensity_0
+ + mean_diffraction_intensity_1
+ + mean_diffraction_intensity_2
+ ) / 3
+
+ del amplitudes_0, amplitudes_1, amplitudes_2
+
+ else:
+ self._amplitudes = (amplitudes_0, amplitudes_1)
+ self._mean_diffraction_intensity = (
+ mean_diffraction_intensity_0 + mean_diffraction_intensity_1
+ ) / 2
+
+ del amplitudes_0, amplitudes_1
+
+ self._num_diffraction_patterns = self._amplitudes[0].shape[0]
+ self._region_of_interest_shape = np.array(self._amplitudes[0].shape[-2:])
+
+ self._positions_px = self._calculate_scan_positions_in_pixels(
+ self._scan_positions, self._positions_mask
+ )
+
+ # handle semiangle specified in pixels
+ if self._semiangle_cutoff_pixels:
+ self._semiangle_cutoff = (
+ self._semiangle_cutoff_pixels * self._angular_sampling[0]
+ )
+
+ # Object Initialization
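+        # the padded object must contain the furthest scan position plus the
+        # requested padding, and span at least one probe ROI in each dimension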
+ if self._object is None:
+ pad_x = self._object_padding_px[0][1]
+ pad_y = self._object_padding_px[1][1]
+ p, q = np.round(np.max(self._positions_px, axis=0))
+ p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype(
+ "int"
+ )
+ q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype(
+ "int"
+ )
+ if self._object_type == "potential":
+ object_e = xp.zeros((p, q), dtype=xp.float32)
+ elif self._object_type == "complex":
+ object_e = xp.ones((p, q), dtype=xp.complex64)
+ object_m = xp.zeros((p, q), dtype=xp.float32)
+ else:
+ if self._object_type == "potential":
+ object_e = xp.asarray(self._object[0], dtype=xp.float32)
+ elif self._object_type == "complex":
+ object_e = xp.asarray(self._object[0], dtype=xp.complex64)
+ object_m = xp.asarray(self._object[1], dtype=xp.float32)
+
+ self._object = (object_e, object_m)
+ self._object_initial = (object_e.copy(), object_m.copy())
+ self._object_type_initial = self._object_type
+ self._object_shape = self._object[0].shape
+
+ self._positions_px = xp.asarray(self._positions_px, dtype=xp.float32)
+ self._positions_px_com = xp.mean(self._positions_px, axis=0)
+ self._positions_px -= self._positions_px_com - xp.array(self._object_shape) / 2
+ self._positions_px_com = xp.mean(self._positions_px, axis=0)
+ self._positions_px_fractional = self._positions_px - xp.round(
+ self._positions_px
+ )
+
+ self._positions_px_initial = self._positions_px.copy()
+ self._positions_initial = self._positions_px_initial.copy()
+ self._positions_initial[:, 0] *= self.sampling[0]
+ self._positions_initial[:, 1] *= self.sampling[1]
+
+ # Vectorized Patches
+ (
+ self._vectorized_patch_indices_row,
+ self._vectorized_patch_indices_col,
+ ) = self._extract_vectorized_patch_indices()
+
+ # Probe Initialization
+ if self._probe is None:
+ if self._vacuum_probe_intensity is not None:
+ self._semiangle_cutoff = np.inf
+ self._vacuum_probe_intensity = xp.asarray(
+ self._vacuum_probe_intensity, dtype=xp.float32
+ )
+ probe_x0, probe_y0 = get_CoM(
+ self._vacuum_probe_intensity, device=self._device
+ )
+ self._vacuum_probe_intensity = get_shifted_ar(
+ self._vacuum_probe_intensity,
+ -probe_x0,
+ -probe_y0,
+ bilinear=True,
+ device=self._device,
+ )
+ if crop_patterns:
+ self._vacuum_probe_intensity = self._vacuum_probe_intensity[
+ self._crop_mask
+ ].reshape(self._region_of_interest_shape)
+
+ self._probe = (
+ ComplexProbe(
+ gpts=self._region_of_interest_shape,
+ sampling=self.sampling,
+ energy=self._energy,
+ semiangle_cutoff=self._semiangle_cutoff,
+ rolloff=self._rolloff,
+ vacuum_probe_intensity=self._vacuum_probe_intensity,
+ parameters=self._polar_parameters,
+ device=self._device,
+ )
+ .build()
+ ._array
+ )
+
+ # Normalize probe to match mean diffraction intensity
+ probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2)
+ self._probe *= xp.sqrt(self._mean_diffraction_intensity / probe_intensity)
+
+ else:
+ if isinstance(self._probe, ComplexProbe):
+ if self._probe._gpts != self._region_of_interest_shape:
+                    raise ValueError(
+                        "`initial_probe_guess` gpts must match the region of interest shape."
+                    )
+ if hasattr(self._probe, "_array"):
+ self._probe = self._probe._array
+ else:
+ self._probe._xp = xp
+ self._probe = self._probe.build()._array
+
+ # Normalize probe to match mean diffraction intensity
+ probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2)
+ self._probe *= xp.sqrt(
+ self._mean_diffraction_intensity / probe_intensity
+ )
+ else:
+ self._probe = xp.asarray(self._probe, dtype=xp.complex64)
+
+ self._probe_initial = self._probe.copy()
+ self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe))
+
+ self._known_aberrations_array = ComplexProbe(
+ energy=self._energy,
+ gpts=self._region_of_interest_shape,
+ sampling=self.sampling,
+ parameters=self._polar_parameters,
+ device=self._device,
+ )._evaluate_ctf()
+
+ # overlaps
+ shifted_probes = fft_shift(self._probe, self._positions_px_fractional, xp)
+ probe_intensities = xp.abs(shifted_probes) ** 2
+ probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities)
+ probe_overlap = self._gaussian_filter(probe_overlap, 1.0)
+
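+        # by default, the object FOV mask keeps regions where the blurred,
+        # summed probe intensity exceeds 25% of its maximum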
+ if object_fov_mask is None:
+ self._object_fov_mask = asnumpy(probe_overlap > 0.25 * probe_overlap.max())
+ else:
+ self._object_fov_mask = np.asarray(object_fov_mask)
+ self._object_fov_mask_inverse = np.invert(self._object_fov_mask)
+
+ if plot_probe_overlaps:
+ figsize = kwargs.pop("figsize", (9, 4))
+ chroma_boost = kwargs.pop("chroma_boost", 1)
+
+ # initial probe
+ complex_probe_rgb = Complex2RGB(
+ self.probe_centered,
+ power=2,
+ chroma_boost=chroma_boost,
+ )
+
+ extent = [
+ 0,
+ self.sampling[1] * self._object_shape[1],
+ self.sampling[0] * self._object_shape[0],
+ 0,
+ ]
+
+ probe_extent = [
+ 0,
+ self.sampling[1] * self._region_of_interest_shape[1],
+ self.sampling[0] * self._region_of_interest_shape[0],
+ 0,
+ ]
+
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
+
+ ax1.imshow(
+ complex_probe_rgb,
+ extent=probe_extent,
+ )
+
+ divider = make_axes_locatable(ax1)
+ cax1 = divider.append_axes("right", size="5%", pad="2.5%")
+ add_colorbar_arg(
+ cax1,
+ chroma_boost=chroma_boost,
+ )
+ ax1.set_ylabel("x [A]")
+ ax1.set_xlabel("y [A]")
+ ax1.set_title("Initial probe intensity")
+
+ ax2.imshow(
+ asnumpy(probe_overlap),
+ extent=extent,
+ cmap="Greys_r",
+ )
+ ax2.scatter(
+ self.positions[:, 1],
+ self.positions[:, 0],
+ s=2.5,
+ color=(1, 0, 0, 1),
+ )
+ ax2.set_ylabel("x [A]")
+ ax2.set_xlabel("y [A]")
+ ax2.set_xlim((extent[0], extent[1]))
+ ax2.set_ylim((extent[2], extent[3]))
+ ax2.set_title("Object field of view")
+
+ fig.tight_layout()
+
+ self._preprocessed = True
+
+ if self._device == "gpu":
+ xp._default_memory_pool.free_all_blocks()
+ xp.clear_memo()
+
+ return self
+
+ def _warmup_overlap_projection(self, current_object, current_probe):
+ """
+ Ptychographic overlap projection method.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+
+ Returns
+ --------
+ shifted_probes:np.ndarray
+ fractionally-shifted probes
+ object_patches: np.ndarray
+ Patched object view
+ overlap: np.ndarray
+ shifted_probes * object_patches
+ """
+
+ xp = self._xp
+
+ shifted_probes = fft_shift(current_probe, self._positions_px_fractional, xp)
+
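+        # warmup: only the electrostatic object is reconstructed, so the
+        # magnetic factor is omitted from the overlap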
+ electrostatic_obj, _ = current_object
+
+ if self._object_type == "potential":
+ complex_object = xp.exp(1j * electrostatic_obj)
+ else:
+ complex_object = electrostatic_obj
+
+ electrostatic_obj_patches = complex_object[
+ self._vectorized_patch_indices_row, self._vectorized_patch_indices_col
+ ]
+
+ object_patches = (electrostatic_obj_patches, None)
+ overlap = (shifted_probes * electrostatic_obj_patches, None)
+
+ return shifted_probes, object_patches, overlap
+
+ def _overlap_projection(self, current_object, current_probe):
+ """
+ Ptychographic overlap projection method.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+
+ Returns
+ --------
+ shifted_probes:np.ndarray
+ fractionally-shifted probes
+ object_patches: np.ndarray
+ Patched object view
+ overlap: np.ndarray
+ shifted_probes * object_patches
+ """
+
+ xp = self._xp
+
+ shifted_probes = fft_shift(current_probe, self._positions_px_fractional, xp)
+
+ electrostatic_obj, magnetic_obj = current_object
+
+ if self._object_type == "potential":
+ complex_object_e = xp.exp(1j * electrostatic_obj)
+ else:
+ complex_object_e = electrostatic_obj
+
+ complex_object_m = xp.exp(1j * magnetic_obj)
+
+ electrostatic_obj_patches = complex_object_e[
+ self._vectorized_patch_indices_row, self._vectorized_patch_indices_col
+ ]
+ magnetic_obj_patches = complex_object_m[
+ self._vectorized_patch_indices_row, self._vectorized_patch_indices_col
+ ]
+
+ object_patches = (electrostatic_obj_patches, magnetic_obj_patches)
+
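+        # forward model: each overlap is probe * O_e * O_m^s, where the
+        # magnetic factor O_m = exp(i * V_m) enters conjugated (s = -1) for
+        # 'reverse' measurements, is omitted for 'neutral' ones, and enters
+        # directly (s = +1) for 'forward' ones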
+ if self._sim_recon_mode == 0:
+ overlap_reverse = (
+ shifted_probes
+ * electrostatic_obj_patches
+ * xp.conj(magnetic_obj_patches)
+ )
+ overlap_forward = (
+ shifted_probes * electrostatic_obj_patches * magnetic_obj_patches
+ )
+ overlap = (overlap_reverse, overlap_forward)
+ elif self._sim_recon_mode == 1:
+ overlap_reverse = (
+ shifted_probes
+ * electrostatic_obj_patches
+ * xp.conj(magnetic_obj_patches)
+ )
+ overlap_neutral = shifted_probes * electrostatic_obj_patches
+ overlap_forward = (
+ shifted_probes * electrostatic_obj_patches * magnetic_obj_patches
+ )
+ overlap = (overlap_reverse, overlap_neutral, overlap_forward)
+ else:
+ overlap_neutral = shifted_probes * electrostatic_obj_patches
+ overlap_forward = (
+ shifted_probes * electrostatic_obj_patches * magnetic_obj_patches
+ )
+ overlap = (overlap_neutral, overlap_forward)
+
+ return shifted_probes, object_patches, overlap
+
+ def _warmup_gradient_descent_fourier_projection(self, amplitudes, overlap):
+ """
+        Ptychographic Fourier projection method for the gradient-descent method.
+
+ Parameters
+ --------
+ amplitudes: np.ndarray
+ Normalized measured amplitudes
+ overlap: np.ndarray
+ object * probe overlap
+
+ Returns
+ --------
+ exit_waves:np.ndarray
+ Difference between modified and estimated exit waves
+ error: float
+ Reconstruction error
+ """
+
+ xp = self._xp
+
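+        # amplitude projection: keep the estimated Fourier phase but replace
+        # the Fourier amplitude with the measured amplitude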
+ fourier_overlap = xp.fft.fft2(overlap[0])
+ error = xp.sum(xp.abs(amplitudes[0] - xp.abs(fourier_overlap)) ** 2)
+
+ fourier_modified_overlap = amplitudes[0] * xp.exp(
+ 1j * xp.angle(fourier_overlap)
+ )
+ modified_overlap = xp.fft.ifft2(fourier_modified_overlap)
+
+ exit_waves = (modified_overlap - overlap[0],) + (None,) * (
+ self._num_sim_measurements - 1
+ )
+
+ return exit_waves, error
+
+ def _gradient_descent_fourier_projection(self, amplitudes, overlap):
+ """
+        Ptychographic Fourier projection method for the gradient-descent method.
+
+ Parameters
+ --------
+ amplitudes: np.ndarray
+ Normalized measured amplitudes
+ overlap: np.ndarray
+ object * probe overlap
+
+ Returns
+ --------
+ exit_waves:np.ndarray
+ Difference between modified and estimated exit waves
+ error: float
+ Reconstruction error
+ """
+
+ xp = self._xp
+
+ error = 0.0
+ exit_waves = []
+ for amp, overl in zip(amplitudes, overlap):
+ fourier_overl = xp.fft.fft2(overl)
+ error += xp.sum(xp.abs(amp - xp.abs(fourier_overl)) ** 2)
+
+ fourier_modified_overl = amp * xp.exp(1j * xp.angle(fourier_overl))
+ modified_overl = xp.fft.ifft2(fourier_modified_overl)
+
+ exit_waves.append(modified_overl - overl)
+
+ error /= len(exit_waves)
+ exit_waves = tuple(exit_waves)
+
+ return exit_waves, error
+
+ def _warmup_projection_sets_fourier_projection(
+ self, amplitudes, overlap, exit_waves, projection_a, projection_b, projection_c
+ ):
+ """
+        Ptychographic Fourier projection method for DM_AP and RAAR methods.
+ Generalized projection using three parameters: a,b,c
+
+ DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha
+ DM: DM_AP(1.0), AP: DM_AP(0.0)
+
+ RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2
+ DM : RAAR(1.0)
+
+ RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2
+ DM: RRR(1.0)
+
+ SUPERFLIP : a = 0, b = 1, c = 2
+
+ Parameters
+ --------
+ amplitudes: np.ndarray
+ Normalized measured amplitudes
+ overlap: np.ndarray
+ object * probe overlap
+ exit_waves: np.ndarray
+ previously estimated exit waves
+ projection_a: float
+ projection_b: float
+ projection_c: float
+
+ Returns
+ --------
+ exit_waves:np.ndarray
+ Updated exit_waves
+ error: float
+ Reconstruction error
+ """
+
+ xp = self._xp
+ projection_x = 1 - projection_a - projection_b
+ projection_y = 1 - projection_c
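+        # e.g. DM_AP(alpha=0.5) gives (a, b, c) = (-0.5, 1, 1.5), so that
+        # projection_x = 0.5 and projection_y = -0.5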
+
+ exit_wave = exit_waves[0]
+
+ if exit_wave is None:
+ exit_wave = overlap[0].copy()
+
+ fourier_overlap = xp.fft.fft2(overlap[0])
+ error = xp.sum(xp.abs(amplitudes[0] - xp.abs(fourier_overlap)) ** 2)
+
+ factor_to_be_projected = projection_c * overlap[0] + projection_y * exit_wave
+ fourier_projected_factor = xp.fft.fft2(factor_to_be_projected)
+
+ fourier_projected_factor = amplitudes[0] * xp.exp(
+ 1j * xp.angle(fourier_projected_factor)
+ )
+ projected_factor = xp.fft.ifft2(fourier_projected_factor)
+
+ exit_wave = (
+ projection_x * exit_wave
+ + projection_a * overlap[0]
+ + projection_b * projected_factor
+ )
+
+ exit_waves = (exit_wave,) + (None,) * (self._num_sim_measurements - 1)
+
+ return exit_waves, error
+
+ def _projection_sets_fourier_projection(
+ self, amplitudes, overlap, exit_waves, projection_a, projection_b, projection_c
+ ):
+ """
+        Ptychographic Fourier projection method for DM_AP and RAAR methods.
+ Generalized projection using three parameters: a,b,c
+
+ DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha
+ DM: DM_AP(1.0), AP: DM_AP(0.0)
+
+ RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2
+ DM : RAAR(1.0)
+
+ RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2
+ DM: RRR(1.0)
+
+ SUPERFLIP : a = 0, b = 1, c = 2
+
+ Parameters
+ --------
+ amplitudes: np.ndarray
+ Normalized measured amplitudes
+ overlap: np.ndarray
+ object * probe overlap
+ exit_waves: np.ndarray
+ previously estimated exit waves
+ projection_a: float
+ projection_b: float
+ projection_c: float
+
+ Returns
+ --------
+ exit_waves:np.ndarray
+ Updated exit_waves
+ error: float
+ Reconstruction error
+ """
+
+ xp = self._xp
+ projection_x = 1 - projection_a - projection_b
+ projection_y = 1 - projection_c
+
+ error = 0.0
+ _exit_waves = []
+ for amp, overl, exit_wave in zip(amplitudes, overlap, exit_waves):
+ if exit_wave is None:
+ exit_wave = overl.copy()
+
+ fourier_overl = xp.fft.fft2(overl)
+ error += xp.sum(xp.abs(amp - xp.abs(fourier_overl)) ** 2)
+
+ factor_to_be_projected = projection_c * overl + projection_y * exit_wave
+ fourier_projected_factor = xp.fft.fft2(factor_to_be_projected)
+
+ fourier_projected_factor = amp * xp.exp(
+ 1j * xp.angle(fourier_projected_factor)
+ )
+ projected_factor = xp.fft.ifft2(fourier_projected_factor)
+
+ _exit_waves.append(
+ projection_x * exit_wave
+ + projection_a * overl
+ + projection_b * projected_factor
+ )
+
+ error /= len(_exit_waves)
+ exit_waves = tuple(_exit_waves)
+
+ return exit_waves, error
+
+ def _forward(
+ self,
+ current_object,
+ current_probe,
+ amplitudes,
+ exit_waves,
+ warmup_iteration,
+ use_projection_scheme,
+ projection_a,
+ projection_b,
+ projection_c,
+ ):
+ """
+ Ptychographic forward operator.
+ Calls _overlap_projection() and the appropriate _fourier_projection().
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+ amplitudes: np.ndarray
+ Normalized measured amplitudes
+ exit_waves: np.ndarray
+            previously estimated exit waves
+        warmup_iteration: bool
+            If True, the electrostatic-only warmup operators are used
+        use_projection_scheme: bool,
+ If True, use generalized projection update
+ projection_a: float
+ projection_b: float
+ projection_c: float
+
+ Returns
+ --------
+ shifted_probes:np.ndarray
+ fractionally-shifted probes
+ object_patches: np.ndarray
+ Patched object view
+ overlap: np.ndarray
+ object * probe overlap
+ exit_waves:np.ndarray
+ Updated exit_waves
+ error: float
+ Reconstruction error
+ """
+ if warmup_iteration:
+ shifted_probes, object_patches, overlap = self._warmup_overlap_projection(
+ current_object, current_probe
+ )
+ if use_projection_scheme:
+ exit_waves, error = self._warmup_projection_sets_fourier_projection(
+ amplitudes,
+ overlap,
+ exit_waves,
+ projection_a,
+ projection_b,
+ projection_c,
+ )
+
+ else:
+ exit_waves, error = self._warmup_gradient_descent_fourier_projection(
+ amplitudes, overlap
+ )
+
+ else:
+ shifted_probes, object_patches, overlap = self._overlap_projection(
+ current_object, current_probe
+ )
+ if use_projection_scheme:
+ exit_waves, error = self._projection_sets_fourier_projection(
+ amplitudes,
+ overlap,
+ exit_waves,
+ projection_a,
+ projection_b,
+ projection_c,
+ )
+
+ else:
+ exit_waves, error = self._gradient_descent_fourier_projection(
+ amplitudes, overlap
+ )
+
+ return shifted_probes, object_patches, overlap, exit_waves, error
+
+ def _warmup_gradient_descent_adjoint(
+ self,
+ current_object,
+ current_probe,
+ object_patches,
+ shifted_probes,
+ exit_waves,
+ step_size,
+ normalization_min,
+ fix_probe,
+ ):
+ """
+        Ptychographic adjoint operator for the gradient-descent method.
+ Computes object and probe update steps.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+ object_patches: np.ndarray
+ Patched object view
+ shifted_probes:np.ndarray
+ fractionally-shifted probes
+ exit_waves:np.ndarray
+ Updated exit_waves
+ step_size: float, optional
+ Update step size
+ normalization_min: float, optional
+ Probe normalization minimum as a fraction of the maximum overlap intensity
+ fix_probe: bool, optional
+ If True, probe will not be updated
+
+ Returns
+ --------
+ updated_object: np.ndarray
+ Updated object estimate
+ updated_probe: np.ndarray
+ Updated probe estimate
+ """
+ xp = self._xp
+
+ electrostatic_obj, _ = current_object
+ electrostatic_obj_patches, _ = object_patches
+
+ probe_normalization = self._sum_overlapping_patches_bincounts(
+ xp.abs(shifted_probes) ** 2
+ )
+ probe_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * probe_normalization) ** 2
+ + (normalization_min * xp.max(probe_normalization)) ** 2
+ )
+
+ if self._object_type == "potential":
+ electrostatic_obj += step_size * (
+ self._sum_overlapping_patches_bincounts(
+ xp.real(
+ -1j
+ * xp.conj(electrostatic_obj_patches)
+ * xp.conj(shifted_probes)
+ * exit_waves[0]
+ )
+ )
+ * probe_normalization
+ )
+ elif self._object_type == "complex":
+ electrostatic_obj += step_size * (
+ self._sum_overlapping_patches_bincounts(
+ xp.conj(shifted_probes) * exit_waves[0]
+ )
+ * probe_normalization
+ )
+
+ if not fix_probe:
+ object_normalization = xp.sum(
+ (xp.abs(electrostatic_obj_patches) ** 2),
+ axis=0,
+ )
+ object_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * object_normalization) ** 2
+ + (normalization_min * xp.max(object_normalization)) ** 2
+ )
+
+ current_probe += step_size * (
+ xp.sum(
+ xp.conj(electrostatic_obj_patches) * exit_waves[0],
+ axis=0,
+ )
+ * object_normalization
+ )
+
+ return (electrostatic_obj, None), current_probe
+
+ def _gradient_descent_adjoint(
+ self,
+ current_object,
+ current_probe,
+ object_patches,
+ shifted_probes,
+ exit_waves,
+ step_size,
+ normalization_min,
+ fix_probe,
+ ):
+ """
+        Ptychographic adjoint operator for the gradient-descent method.
+ Computes object and probe update steps.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+ object_patches: np.ndarray
+ Patched object view
+ shifted_probes:np.ndarray
+ fractionally-shifted probes
+ exit_waves:np.ndarray
+ Updated exit_waves
+ step_size: float, optional
+ Update step size
+ normalization_min: float, optional
+ Probe normalization minimum as a fraction of the maximum overlap intensity
+ fix_probe: bool, optional
+ If True, probe will not be updated
+
+ Returns
+ --------
+ updated_object: np.ndarray
+ Updated object estimate
+ updated_probe: np.ndarray
+ Updated probe estimate
+ """
+ xp = self._xp
+
+ electrostatic_obj, magnetic_obj = current_object
+ probe_conj = xp.conj(shifted_probes)
+
+ electrostatic_obj_patches, magnetic_obj_patches = object_patches
+ electrostatic_conj = xp.conj(electrostatic_obj_patches)
+ magnetic_conj = xp.conj(magnetic_obj_patches)
+
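+        # each object update is normalized by the overlap intensity of the
+        # complementary factors: O_e updates use sums of |probe * O_m|^2,
+        # while O_m updates use sums of |probe * O_e|^2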
+ probe_electrostatic_abs = xp.abs(shifted_probes * electrostatic_obj_patches)
+ probe_magnetic_abs = xp.abs(shifted_probes * magnetic_obj_patches)
+
+ probe_electrostatic_normalization = self._sum_overlapping_patches_bincounts(
+ probe_electrostatic_abs**2
+ )
+ probe_electrostatic_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * probe_electrostatic_normalization) ** 2
+ + (normalization_min * xp.max(probe_electrostatic_normalization)) ** 2
+ )
+
+ probe_magnetic_normalization = self._sum_overlapping_patches_bincounts(
+ probe_magnetic_abs**2
+ )
+ probe_magnetic_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * probe_magnetic_normalization) ** 2
+ + (normalization_min * xp.max(probe_magnetic_normalization)) ** 2
+ )
+
+ if self._sim_recon_mode > 0:
+ probe_abs = xp.abs(shifted_probes)
+ probe_normalization = self._sum_overlapping_patches_bincounts(
+ probe_abs**2
+ )
+ probe_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * probe_normalization) ** 2
+ + (normalization_min * xp.max(probe_normalization)) ** 2
+ )
+
+ if self._sim_recon_mode == 0:
+ exit_waves_reverse, exit_waves_forward = exit_waves
+
+ if self._object_type == "potential":
+ electrostatic_obj += (
+ step_size
+ * (
+ self._sum_overlapping_patches_bincounts(
+ xp.real(
+ -1j
+ * magnetic_obj_patches
+ * electrostatic_conj
+ * xp.conj(shifted_probes)
+ * exit_waves_reverse
+ )
+ )
+ * probe_magnetic_normalization
+ )
+ / 2
+ )
+ electrostatic_obj += (
+ step_size
+ * (
+ self._sum_overlapping_patches_bincounts(
+ xp.real(
+ -1j
+ * magnetic_conj
+ * electrostatic_conj
+ * xp.conj(shifted_probes)
+ * exit_waves_forward
+ )
+ )
+ * probe_magnetic_normalization
+ )
+ / 2
+ )
+
+ elif self._object_type == "complex":
+ electrostatic_obj += (
+ step_size
+ * (
+ self._sum_overlapping_patches_bincounts(
+ probe_conj * magnetic_obj_patches * exit_waves_reverse
+ )
+ * probe_magnetic_normalization
+ )
+ / 2
+ )
+ electrostatic_obj += step_size * (
+ self._sum_overlapping_patches_bincounts(
+ probe_conj * magnetic_conj * exit_waves_forward
+ )
+ * probe_magnetic_normalization
+ / 2
+ )
+
+ magnetic_obj += (
+ step_size
+ * (
+ self._sum_overlapping_patches_bincounts(
+ xp.real(
+ 1j
+ * magnetic_obj_patches
+ * electrostatic_conj
+ * xp.conj(shifted_probes)
+ * exit_waves_reverse
+ )
+ )
+ * probe_electrostatic_normalization
+ )
+ / 2
+ )
+ magnetic_obj += (
+ step_size
+ * (
+ self._sum_overlapping_patches_bincounts(
+ xp.real(
+ -1j
+ * magnetic_conj
+ * electrostatic_conj
+ * xp.conj(shifted_probes)
+ * exit_waves_forward
+ )
+ )
+ * probe_electrostatic_normalization
+ )
+ / 2
+ )
+
+ elif self._sim_recon_mode == 1:
+ exit_waves_reverse, exit_waves_neutral, exit_waves_forward = exit_waves
+
+ if self._object_type == "potential":
+ electrostatic_obj += (
+ step_size
+ * (
+ self._sum_overlapping_patches_bincounts(
+ xp.real(
+ -1j
+ * magnetic_obj_patches
+ * electrostatic_conj
+ * xp.conj(shifted_probes)
+ * exit_waves_reverse
+ )
+ )
+ * probe_magnetic_normalization
+ )
+ / 3
+ )
+ electrostatic_obj += (
+ step_size
+ * (
+ self._sum_overlapping_patches_bincounts(
+ xp.real(
+ -1j
+ * electrostatic_conj
+ * xp.conj(shifted_probes)
+ * exit_waves_neutral
+ )
+ )
+ * probe_normalization
+ )
+ / 3
+ )
+ electrostatic_obj += (
+ step_size
+ * (
+ self._sum_overlapping_patches_bincounts(
+ xp.real(
+ -1j
+ * magnetic_conj
+ * electrostatic_conj
+ * xp.conj(shifted_probes)
+ * exit_waves_forward
+ )
+ )
+ * probe_magnetic_normalization
+ )
+ / 3
+ )
+
+ elif self._object_type == "complex":
+ electrostatic_obj += (
+ step_size
+ * (
+ self._sum_overlapping_patches_bincounts(
+ probe_conj * magnetic_obj_patches * exit_waves_reverse
+ )
+ * probe_magnetic_normalization
+ )
+ / 3
+ )
+ electrostatic_obj += step_size * (
+ self._sum_overlapping_patches_bincounts(
+ probe_conj * exit_waves_neutral
+ )
+ * probe_normalization
+ / 3
+ )
+ electrostatic_obj += step_size * (
+ self._sum_overlapping_patches_bincounts(
+ probe_conj * magnetic_conj * exit_waves_forward
+ )
+ * probe_magnetic_normalization
+ / 3
+ )
+
+ magnetic_obj += (
+ step_size
+ * (
+ self._sum_overlapping_patches_bincounts(
+ xp.real(
+ 1j
+ * magnetic_obj_patches
+ * electrostatic_conj
+ * xp.conj(shifted_probes)
+ * exit_waves_reverse
+ )
+ )
+ * probe_electrostatic_normalization
+ )
+ / 2
+ )
+ magnetic_obj += (
+ step_size
+ * (
+ self._sum_overlapping_patches_bincounts(
+ xp.real(
+ -1j
+ * magnetic_conj
+ * electrostatic_conj
+ * xp.conj(shifted_probes)
+ * exit_waves_forward
+ )
+ )
+ * probe_electrostatic_normalization
+ )
+ / 2
+ )
+
+ else:
+ exit_waves_neutral, exit_waves_forward = exit_waves
+
+ if self._object_type == "potential":
+ electrostatic_obj += (
+ step_size
+ * (
+ self._sum_overlapping_patches_bincounts(
+ xp.real(
+ -1j
+ * electrostatic_conj
+ * xp.conj(shifted_probes)
+ * exit_waves_neutral
+ )
+ )
+ * probe_normalization
+ )
+ / 2
+ )
+ electrostatic_obj += (
+ step_size
+ * (
+ self._sum_overlapping_patches_bincounts(
+ xp.real(
+ -1j
+ * magnetic_conj
+ * electrostatic_conj
+ * xp.conj(shifted_probes)
+ * exit_waves_forward
+ )
+ )
+ * probe_magnetic_normalization
+ )
+ / 2
+ )
+
+ elif self._object_type == "complex":
+ electrostatic_obj += step_size * (
+ self._sum_overlapping_patches_bincounts(
+ probe_conj * exit_waves_neutral
+ )
+ * probe_normalization
+ / 2
+ )
+ electrostatic_obj += step_size * (
+ self._sum_overlapping_patches_bincounts(
+ probe_conj * magnetic_conj * exit_waves_forward
+ )
+ * probe_magnetic_normalization
+ / 2
+ )
+
+ magnetic_obj += (
+ step_size
+ * (
+ self._sum_overlapping_patches_bincounts(
+ xp.real(
+ -1j
+ * magnetic_conj
+ * electrostatic_conj
+ * xp.conj(shifted_probes)
+ * exit_waves_forward
+ )
+ )
+ * probe_electrostatic_normalization
+ )
+ / 3
+ )
+
+ if not fix_probe:
+ electrostatic_magnetic_abs = xp.abs(
+ electrostatic_obj_patches * magnetic_obj_patches
+ )
+ electrostatic_magnetic_normalization = xp.sum(
+ electrostatic_magnetic_abs**2,
+ axis=0,
+ )
+ electrostatic_magnetic_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * electrostatic_magnetic_normalization) ** 2
+ + (normalization_min * xp.max(electrostatic_magnetic_normalization))
+ ** 2
+ )
+
+ if self._sim_recon_mode > 0:
+ electrostatic_abs = xp.abs(electrostatic_obj_patches)
+ electrostatic_normalization = xp.sum(
+ electrostatic_abs**2,
+ axis=0,
+ )
+ electrostatic_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * electrostatic_normalization) ** 2
+ + (normalization_min * xp.max(electrostatic_normalization)) ** 2
+ )
+
+ if self._sim_recon_mode == 0:
+ current_probe += step_size * (
+ xp.sum(
+ electrostatic_conj * magnetic_obj_patches * exit_waves_reverse,
+ axis=0,
+ )
+ * electrostatic_magnetic_normalization
+ / 2
+ )
+
+ current_probe += step_size * (
+ xp.sum(
+ electrostatic_conj * magnetic_conj * exit_waves_forward,
+ axis=0,
+ )
+ * electrostatic_magnetic_normalization
+ / 2
+ )
+
+ elif self._sim_recon_mode == 1:
+ current_probe += step_size * (
+ xp.sum(
+ electrostatic_conj * magnetic_obj_patches * exit_waves_reverse,
+ axis=0,
+ )
+ * electrostatic_magnetic_normalization
+ / 3
+ )
+
+ current_probe += step_size * (
+ xp.sum(
+ electrostatic_conj * exit_waves_neutral,
+ axis=0,
+ )
+ * electrostatic_normalization
+ / 3
+ )
+
+ current_probe += step_size * (
+ xp.sum(
+ electrostatic_conj * magnetic_conj * exit_waves_forward,
+ axis=0,
+ )
+ * electrostatic_magnetic_normalization
+ / 3
+ )
+ else:
+ current_probe += step_size * (
+ xp.sum(
+ electrostatic_conj * exit_waves_neutral,
+ axis=0,
+ )
+ * electrostatic_normalization
+ / 2
+ )
+
+ current_probe += step_size * (
+ xp.sum(
+ electrostatic_conj * magnetic_conj * exit_waves_forward,
+ axis=0,
+ )
+ * electrostatic_magnetic_normalization
+ / 2
+ )
+
+ current_object = (electrostatic_obj, magnetic_obj)
+
+ return current_object, current_probe
+
+ def _warmup_projection_sets_adjoint(
+ self,
+ current_object,
+ current_probe,
+ object_patches,
+ shifted_probes,
+ exit_waves,
+ normalization_min,
+ fix_probe,
+ ):
+ """
+ Ptychographic adjoint operator for DM_AP and RAAR methods.
+ Computes object and probe update steps.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+ object_patches: np.ndarray
+ Patched object view
+ shifted_probes:np.ndarray
+ fractionally-shifted probes
+ exit_waves:np.ndarray
+ Updated exit_waves
+ normalization_min: float, optional
+ Probe normalization minimum as a fraction of the maximum overlap intensity
+ fix_probe: bool, optional
+ If True, probe will not be updated
+
+ Returns
+ --------
+ updated_object: np.ndarray
+ Updated object estimate
+ updated_probe: np.ndarray
+ Updated probe estimate
+ """
+ xp = self._xp
+
+ electrostatic_obj, _ = current_object
+ electrostatic_obj_patches, _ = object_patches
+
+ probe_normalization = self._sum_overlapping_patches_bincounts(
+ xp.abs(shifted_probes) ** 2
+ )
+ probe_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * probe_normalization) ** 2
+ + (normalization_min * xp.max(probe_normalization)) ** 2
+ )
+
+ electrostatic_obj = (
+ self._sum_overlapping_patches_bincounts(
+ xp.conj(shifted_probes) * exit_waves[0]
+ )
+ * probe_normalization
+ )
+
+ if not fix_probe:
+ object_normalization = xp.sum(
+ (xp.abs(electrostatic_obj_patches) ** 2),
+ axis=0,
+ )
+ object_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * object_normalization) ** 2
+ + (normalization_min * xp.max(object_normalization)) ** 2
+ )
+
+ current_probe = (
+ xp.sum(
+ xp.conj(electrostatic_obj_patches) * exit_waves[0],
+ axis=0,
+ )
+ * object_normalization
+ )
+
+ return (electrostatic_obj, None), current_probe
+
+ def _projection_sets_adjoint(
+ self,
+ current_object,
+ current_probe,
+ object_patches,
+ shifted_probes,
+ exit_waves,
+ normalization_min,
+ fix_probe,
+ ):
+ """
+ Ptychographic adjoint operator for DM_AP and RAAR methods.
+ Computes object and probe update steps.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+ object_patches: np.ndarray
+ Patched object view
+ shifted_probes:np.ndarray
+ fractionally-shifted probes
+ exit_waves:np.ndarray
+ Updated exit_waves
+ normalization_min: float, optional
+ Probe normalization minimum as a fraction of the maximum overlap intensity
+ fix_probe: bool, optional
+ If True, probe will not be updated
+
+ Returns
+ --------
+ updated_object: np.ndarray
+ Updated object estimate
+ updated_probe: np.ndarray
+ Updated probe estimate
+ """
+
+ xp = self._xp
+
+ electrostatic_obj, magnetic_obj = current_object
+ probe_conj = xp.conj(shifted_probes)
+
+ electrostatic_obj_patches, magnetic_obj_patches = object_patches
+ electrostatic_conj = xp.conj(electrostatic_obj_patches)
+ magnetic_conj = xp.conj(magnetic_obj_patches)
+
+ probe_electrostatic_abs = xp.abs(shifted_probes * electrostatic_obj_patches)
+ probe_magnetic_abs = xp.abs(shifted_probes * magnetic_obj_patches)
+
+ probe_electrostatic_normalization = self._sum_overlapping_patches_bincounts(
+ probe_electrostatic_abs**2
+ )
+ probe_electrostatic_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * probe_electrostatic_normalization) ** 2
+ + (normalization_min * xp.max(probe_electrostatic_normalization)) ** 2
+ )
+
+ probe_magnetic_normalization = self._sum_overlapping_patches_bincounts(
+ probe_magnetic_abs**2
+ )
+ probe_magnetic_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * probe_magnetic_normalization) ** 2
+ + (normalization_min * xp.max(probe_magnetic_normalization)) ** 2
+ )
+
+ if self._sim_recon_mode > 0:
+ probe_abs = xp.abs(shifted_probes)
+
+ probe_normalization = self._sum_overlapping_patches_bincounts(
+ probe_abs**2
+ )
+ probe_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * probe_normalization) ** 2
+ + (normalization_min * xp.max(probe_normalization)) ** 2
+ )
+
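+        # unlike the gradient-descent adjoint, the projection-set adjoint
+        # overwrites the object and probe with the projected estimates rather
+        # than taking gradient steps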
+ if self._sim_recon_mode == 0:
+ exit_waves_reverse, exit_waves_forward = exit_waves
+
+ electrostatic_obj = (
+ self._sum_overlapping_patches_bincounts(
+ probe_conj * magnetic_obj_patches * exit_waves_reverse
+ )
+ * probe_magnetic_normalization
+ / 2
+ )
+
+ electrostatic_obj += (
+ self._sum_overlapping_patches_bincounts(
+ probe_conj * magnetic_conj * exit_waves_forward
+ )
+ * probe_magnetic_normalization
+ / 2
+ )
+
+ magnetic_obj = xp.conj(
+ self._sum_overlapping_patches_bincounts(
+ probe_conj * electrostatic_conj * exit_waves_reverse
+ )
+ * probe_electrostatic_normalization
+ / 2
+ )
+
+ magnetic_obj += (
+ self._sum_overlapping_patches_bincounts(
+ probe_conj * electrostatic_conj * exit_waves_forward
+ )
+ * probe_electrostatic_normalization
+ / 2
+ )
+
+ elif self._sim_recon_mode == 1:
+ exit_waves_reverse, exit_waves_neutral, exit_waves_forward = exit_waves
+
+ electrostatic_obj = (
+ self._sum_overlapping_patches_bincounts(
+ probe_conj * magnetic_obj_patches * exit_waves_reverse
+ )
+ * probe_magnetic_normalization
+ / 3
+ )
+
+ electrostatic_obj += (
+ self._sum_overlapping_patches_bincounts(probe_conj * exit_waves_neutral)
+ * probe_normalization
+ / 3
+ )
+
+ electrostatic_obj += (
+ self._sum_overlapping_patches_bincounts(
+ probe_conj * magnetic_conj * exit_waves_forward
+ )
+ * probe_magnetic_normalization
+ / 3
+ )
+
+ magnetic_obj = xp.conj(
+ self._sum_overlapping_patches_bincounts(
+ probe_conj * electrostatic_conj * exit_waves_reverse
+ )
+ * probe_electrostatic_normalization
+ / 2
+ )
+
+ magnetic_obj += (
+ self._sum_overlapping_patches_bincounts(
+ probe_conj * electrostatic_conj * exit_waves_forward
+ )
+ * probe_electrostatic_normalization
+ / 2
+ )
+
+ else:
+            raise NotImplementedError(
+                "Projection-set algorithms are not implemented for simultaneous_measurements_mode '0+'."
+            )
+
+ if not fix_probe:
+ electrostatic_magnetic_abs = xp.abs(
+ electrostatic_obj_patches * magnetic_obj_patches
+ )
+
+ electrostatic_magnetic_normalization = xp.sum(
+ (electrostatic_magnetic_abs**2),
+ axis=0,
+ )
+ electrostatic_magnetic_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * electrostatic_magnetic_normalization) ** 2
+ + (normalization_min * xp.max(electrostatic_magnetic_normalization))
+ ** 2
+ )
+
+ if self._sim_recon_mode > 0:
+ electrostatic_abs = xp.abs(electrostatic_obj_patches)
+ electrostatic_normalization = xp.sum(
+ (electrostatic_abs**2),
+ axis=0,
+ )
+ electrostatic_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * electrostatic_normalization) ** 2
+ + (normalization_min * xp.max(electrostatic_normalization)) ** 2
+ )
+
+ if self._sim_recon_mode == 0:
+ current_probe = (
+ xp.sum(
+ electrostatic_conj * magnetic_obj_patches * exit_waves_reverse,
+ axis=0,
+ )
+ * electrostatic_magnetic_normalization
+ / 2
+ )
+
+ current_probe += (
+ xp.sum(
+ electrostatic_conj * magnetic_conj * exit_waves_forward,
+ axis=0,
+ )
+ * electrostatic_magnetic_normalization
+ / 2
+ )
+
+ elif self._sim_recon_mode == 1:
+ current_probe = (
+ xp.sum(
+ electrostatic_conj * magnetic_obj_patches * exit_waves_reverse,
+ axis=0,
+ )
+ * electrostatic_magnetic_normalization
+ / 3
+ )
+
+ current_probe += (
+ xp.sum(
+ electrostatic_conj * exit_waves_neutral,
+ axis=0,
+ )
+ * electrostatic_normalization
+ / 3
+ )
+
+ current_probe += (
+ xp.sum(
+ electrostatic_conj * magnetic_conj * exit_waves_forward,
+ axis=0,
+ )
+ * electrostatic_magnetic_normalization
+ / 3
+ )
+ else:
+ current_probe = (
+ xp.sum(
+ electrostatic_conj * exit_waves_neutral,
+ axis=0,
+ )
+ * electrostatic_normalization
+ / 2
+ )
+
+ current_probe += (
+ xp.sum(
+ electrostatic_conj * magnetic_conj * exit_waves_forward,
+ axis=0,
+ )
+ * electrostatic_magnetic_normalization
+ / 2
+ )
+
+ current_object = (electrostatic_obj, magnetic_obj)
+
+ return current_object, current_probe
+
+ def _adjoint(
+ self,
+ current_object,
+ current_probe,
+ object_patches,
+ shifted_probes,
+ exit_waves,
+ warmup_iteration: bool,
+ use_projection_scheme: bool,
+ step_size: float,
+ normalization_min: float,
+ fix_probe: bool,
+ ):
+ """
+        Ptychographic adjoint operator.
+        Dispatches to the warmup, projection-set, or gradient-descent adjoint
+        and computes object and probe update steps.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+ object_patches: np.ndarray
+ Patched object view
+ shifted_probes:np.ndarray
+ fractionally-shifted probes
+ exit_waves:np.ndarray
+            Updated exit_waves
+        warmup_iteration: bool
+            If True, the electrostatic-only warmup adjoint is used
+        use_projection_scheme: bool,
+ If True, use generalized projection update
+ step_size: float, optional
+ Update step size
+ normalization_min: float, optional
+ Probe normalization minimum as a fraction of the maximum overlap intensity
+ fix_probe: bool, optional
+ If True, probe will not be updated
+
+ Returns
+ --------
+ updated_object: np.ndarray
+ Updated object estimate
+ updated_probe: np.ndarray
+ Updated probe estimate
+ """
+
+ if warmup_iteration:
+ if use_projection_scheme:
+ current_object, current_probe = self._warmup_projection_sets_adjoint(
+ current_object,
+ current_probe,
+ object_patches,
+ shifted_probes,
+ exit_waves,
+ normalization_min,
+ fix_probe,
+ )
+ else:
+ current_object, current_probe = self._warmup_gradient_descent_adjoint(
+ current_object,
+ current_probe,
+ object_patches,
+ shifted_probes,
+ exit_waves,
+ step_size,
+ normalization_min,
+ fix_probe,
+ )
+
+ else:
+ if use_projection_scheme:
+ current_object, current_probe = self._projection_sets_adjoint(
+ current_object,
+ current_probe,
+ object_patches,
+ shifted_probes,
+ exit_waves,
+ normalization_min,
+ fix_probe,
+ )
+ else:
+ current_object, current_probe = self._gradient_descent_adjoint(
+ current_object,
+ current_probe,
+ object_patches,
+ shifted_probes,
+ exit_waves,
+ step_size,
+ normalization_min,
+ fix_probe,
+ )
+
+ return current_object, current_probe
+
+ def _constraints(
+ self,
+ current_object,
+ current_probe,
+ current_positions,
+ pure_phase_object,
+ fix_com,
+ fit_probe_aberrations,
+ fit_probe_aberrations_max_angular_order,
+ fit_probe_aberrations_max_radial_order,
+ constrain_probe_amplitude,
+ constrain_probe_amplitude_relative_radius,
+ constrain_probe_amplitude_relative_width,
+ constrain_probe_fourier_amplitude,
+ constrain_probe_fourier_amplitude_max_width_pixels,
+ constrain_probe_fourier_amplitude_constant_intensity,
+ fix_probe_aperture,
+ initial_probe_aperture,
+ fix_positions,
+ global_affine_transformation,
+ gaussian_filter,
+ gaussian_filter_sigma_e,
+ gaussian_filter_sigma_m,
+ butterworth_filter,
+ q_lowpass_e,
+ q_lowpass_m,
+ q_highpass_e,
+ q_highpass_m,
+ butterworth_order,
+ tv_denoise,
+ tv_denoise_weight,
+ tv_denoise_inner_iter,
+ warmup_iteration,
+ object_positivity,
+ shrinkage_rad,
+ object_mask,
+ ):
+ """
+ Ptychographic constraints operator.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+ current_positions: np.ndarray
+ Current positions estimate
+ pure_phase_object: bool
+ If True, object amplitude is set to unity
+ fix_com: bool
+ If True, probe CoM is fixed to the center
+ fit_probe_aberrations: bool
+ If True, fits the probe aberrations to a low-order expansion
+ fit_probe_aberrations_max_angular_order: bool
+ Max angular order of probe aberrations basis functions
+ fit_probe_aberrations_max_radial_order: bool
+ Max radial order of probe aberrations basis functions
+ constrain_probe_amplitude: bool
+ If True, probe amplitude is constrained by top hat function
+ constrain_probe_amplitude_relative_radius: float
+ Relative location of top-hat inflection point, between 0 and 0.5
+ constrain_probe_amplitude_relative_width: float
+ Relative width of top-hat sigmoid, between 0 and 0.5
+ constrain_probe_fourier_amplitude: bool
+ If True, probe aperture is constrained by fitting a sigmoid for each angular frequency.
+ constrain_probe_fourier_amplitude_max_width_pixels: float
+ Maximum pixel width of fitted sigmoid functions.
+ constrain_probe_fourier_amplitude_constant_intensity: bool
+ If True, the probe aperture is additionally constrained to a constant intensity.
+        fix_probe_aperture: bool
+            If True, probe Fourier amplitude is replaced by initial probe aperture.
+        initial_probe_aperture: np.ndarray
+            Initial probe aperture to use in replacing probe Fourier amplitude.
+        fix_positions: bool
+            If True, positions are not updated
+        global_affine_transformation: bool
+            If True, positions are assumed to be a global affine transform from initial scan
+ gaussian_filter: bool
+ If True, applies real-space gaussian filter
+ gaussian_filter_sigma_e: float
+ Standard deviation of gaussian kernel for electrostatic object in A
+ gaussian_filter_sigma_m: float
+ Standard deviation of gaussian kernel for magnetic object in A
+ butterworth_filter: bool
+            If True, applies Butterworth filtering (low- and/or high-pass)
+ q_lowpass_e: float
+ Cut-off frequency in A^-1 for low-pass filtering electrostatic object
+ q_lowpass_m: float
+ Cut-off frequency in A^-1 for low-pass filtering magnetic object
+ q_highpass_e: float
+ Cut-off frequency in A^-1 for high-pass filtering electrostatic object
+ q_highpass_m: float
+ Cut-off frequency in A^-1 for high-pass filtering magnetic object
+ butterworth_order: float
+ Butterworth filter order. Smaller gives a smoother filter
+ tv_denoise: bool
+ If True, applies TV denoising on object
+ tv_denoise_weight: float
+ Denoising weight. The greater `weight`, the more denoising.
+        tv_denoise_inner_iter: int
+ Number of iterations to run in inner loop of TV denoising
+ warmup_iteration: bool
+            If True, constrains the electrostatic object only
+ object_positivity: bool
+ If True, clips negative potential values
+ shrinkage_rad: float
+ Phase shift in radians to be subtracted from the potential at each iteration
+ object_mask: np.ndarray (boolean)
+ If not None, used to calculate additional shrinkage using masked-mean of object
+
+ Returns
+ --------
+ constrained_object: np.ndarray
+ Constrained object estimate
+ constrained_probe: np.ndarray
+ Constrained probe estimate
+ constrained_positions: np.ndarray
+ Constrained positions estimate
+ """
+
+ electrostatic_obj, magnetic_obj = current_object
+
+ if gaussian_filter:
+ electrostatic_obj = self._object_gaussian_constraint(
+ electrostatic_obj, gaussian_filter_sigma_e, pure_phase_object
+ )
+ if not warmup_iteration:
+ magnetic_obj = self._object_gaussian_constraint(
+ magnetic_obj,
+ gaussian_filter_sigma_m,
+ pure_phase_object,
+ )
+
+ if butterworth_filter:
+ electrostatic_obj = self._object_butterworth_constraint(
+ electrostatic_obj,
+ q_lowpass_e,
+ q_highpass_e,
+ butterworth_order,
+ )
+ if not warmup_iteration:
+ magnetic_obj = self._object_butterworth_constraint(
+ magnetic_obj,
+ q_lowpass_m,
+ q_highpass_m,
+ butterworth_order,
+ )
+
+ if self._object_type == "complex":
+ magnetic_obj = magnetic_obj.real
+ if tv_denoise:
+ electrostatic_obj = self._object_denoise_tv_pylops(
+ electrostatic_obj, tv_denoise_weight, tv_denoise_inner_iter
+ )
+
+ if not warmup_iteration:
+ magnetic_obj = self._object_denoise_tv_pylops(
+ magnetic_obj, tv_denoise_weight, tv_denoise_inner_iter
+ )
+
+ if shrinkage_rad > 0.0 or object_mask is not None:
+ electrostatic_obj = self._object_shrinkage_constraint(
+ electrostatic_obj,
+ shrinkage_rad,
+ object_mask,
+ )
+
+ if self._object_type == "complex":
+ electrostatic_obj = self._object_threshold_constraint(
+ electrostatic_obj, pure_phase_object
+ )
+ elif object_positivity:
+ electrostatic_obj = self._object_positivity_constraint(electrostatic_obj)
+
+ current_object = (electrostatic_obj, magnetic_obj)
+
+ if fix_com:
+ current_probe = self._probe_center_of_mass_constraint(current_probe)
+
+ if fix_probe_aperture:
+ current_probe = self._probe_aperture_constraint(
+ current_probe,
+ initial_probe_aperture,
+ )
+ elif constrain_probe_fourier_amplitude:
+ current_probe = self._probe_fourier_amplitude_constraint(
+ current_probe,
+ constrain_probe_fourier_amplitude_max_width_pixels,
+ constrain_probe_fourier_amplitude_constant_intensity,
+ )
+
+ if fit_probe_aberrations:
+ current_probe = self._probe_aberration_fitting_constraint(
+ current_probe,
+ fit_probe_aberrations_max_angular_order,
+ fit_probe_aberrations_max_radial_order,
+ )
+
+ if constrain_probe_amplitude:
+ current_probe = self._probe_amplitude_constraint(
+ current_probe,
+ constrain_probe_amplitude_relative_radius,
+ constrain_probe_amplitude_relative_width,
+ )
+
+ if not fix_positions:
+ current_positions = self._positions_center_of_mass_constraint(
+ current_positions
+ )
+
+ if global_affine_transformation:
+ current_positions = self._positions_affine_transformation_constraint(
+ self._positions_px_initial, current_positions
+ )
+
+ return current_object, current_probe, current_positions
+
+ def reconstruct(
+ self,
+ max_iter: int = 64,
+ reconstruction_method: str = "gradient-descent",
+ reconstruction_parameter: float = 1.0,
+ reconstruction_parameter_a: float = None,
+ reconstruction_parameter_b: float = None,
+ reconstruction_parameter_c: float = None,
+ max_batch_size: int = None,
+ seed_random: int = None,
+ step_size: float = 0.5,
+ normalization_min: float = 1,
+ positions_step_size: float = 0.9,
+ pure_phase_object_iter: int = 0,
+ fix_com: bool = True,
+ fix_probe_iter: int = 0,
+ warmup_iter: int = 0,
+ fix_probe_aperture_iter: int = 0,
+ constrain_probe_amplitude_iter: int = 0,
+ constrain_probe_amplitude_relative_radius: float = 0.5,
+ constrain_probe_amplitude_relative_width: float = 0.05,
+ constrain_probe_fourier_amplitude_iter: int = 0,
+ constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0,
+ constrain_probe_fourier_amplitude_constant_intensity: bool = False,
+ fix_positions_iter: int = np.inf,
+ constrain_position_distance: float = None,
+ global_affine_transformation: bool = True,
+ gaussian_filter_sigma_e: float = None,
+ gaussian_filter_sigma_m: float = None,
+ gaussian_filter_iter: int = np.inf,
+ fit_probe_aberrations_iter: int = 0,
+ fit_probe_aberrations_max_angular_order: int = 4,
+ fit_probe_aberrations_max_radial_order: int = 4,
+ butterworth_filter_iter: int = np.inf,
+ q_lowpass_e: float = None,
+ q_lowpass_m: float = None,
+ q_highpass_e: float = None,
+ q_highpass_m: float = None,
+ butterworth_order: float = 2,
+ tv_denoise_iter: int = np.inf,
+ tv_denoise_weight: float = None,
+        tv_denoise_inner_iter: int = 40,
+ object_positivity: bool = True,
+ shrinkage_rad: float = 0.0,
+ fix_potential_baseline: bool = True,
+ switch_object_iter: int = np.inf,
+ store_iterations: bool = False,
+ progress_bar: bool = True,
+ reset: bool = None,
+ ):
+ """
+ Ptychographic reconstruction main method.
+
+ Parameters
+ --------
+ max_iter: int, optional
+ Maximum number of iterations to run
+ reconstruction_method: str, optional
+ Specifies which reconstruction algorithm to use, one of:
+ "generalized-projections",
+ "DM_AP" (or "difference-map_alternating-projections"),
+ "RAAR" (or "relaxed-averaged-alternating-reflections"),
+ "RRR" (or "relax-reflect-reflect"),
+ "SUPERFLIP" (or "charge-flipping"), or
+ "GD" (or "gradient_descent")
+ reconstruction_parameter: float, optional
+ Reconstruction parameter for various reconstruction methods above.
+ reconstruction_parameter_a: float, optional
+ Reconstruction parameter a for reconstruction_method='generalized-projections'.
+ reconstruction_parameter_b: float, optional
+ Reconstruction parameter b for reconstruction_method='generalized-projections'.
+ reconstruction_parameter_c: float, optional
+ Reconstruction parameter c for reconstruction_method='generalized-projections'.
+ max_batch_size: int, optional
+ Max number of probes to update at once
+ seed_random: int, optional
+ Seeds the random number generator, only applicable when max_batch_size is not None
+ step_size: float, optional
+ Update step size
+ normalization_min: float, optional
+ Probe normalization minimum as a fraction of the maximum overlap intensity
+ positions_step_size: float, optional
+ Positions update step size
+        pure_phase_object_iter: int, optional
+ Number of iterations where object amplitude is set to unity
+ fix_com: bool, optional
+ If True, fixes center of mass of probe
+        fix_probe_iter: int, optional
+            Number of iterations to run with a fixed probe before updating probe estimate
+        warmup_iter: int, optional
+            Number of iterations to run updating only the electrostatic object,
+            before the magnetic object is also updated
+ fix_probe_aperture_iter: int, optional
+ Number of iterations to run with a fixed probe Fourier amplitude before updating probe estimate
+ constrain_probe_amplitude_iter: int, optional
+ Number of iterations to run while constraining the real-space probe with a top-hat support.
+ constrain_probe_amplitude_relative_radius: float
+ Relative location of top-hat inflection point, between 0 and 0.5
+ constrain_probe_amplitude_relative_width: float
+ Relative width of top-hat sigmoid, between 0 and 0.5
+ constrain_probe_fourier_amplitude_iter: int, optional
+ Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency.
+ constrain_probe_fourier_amplitude_max_width_pixels: float
+ Maximum pixel width of fitted sigmoid functions.
+ constrain_probe_fourier_amplitude_constant_intensity: bool
+ If True, the probe aperture is additionally constrained to a constant intensity.
+ fix_positions_iter: int, optional
+ Number of iterations to run with fixed positions before updating positions estimate
+ constrain_position_distance: float
+ Distance to constrain position correction within original
+ field of view in A
+ global_affine_transformation: bool, optional
+ If True, positions are assumed to be a global affine transform from initial scan
+ gaussian_filter_sigma_e: float
+ Standard deviation of gaussian kernel for electrostatic object in A
+ gaussian_filter_sigma_m: float
+ Standard deviation of gaussian kernel for magnetic object in A
+ gaussian_filter_iter: int, optional
+ Number of iterations to run using object smoothness constraint
+ fit_probe_aberrations_iter: int, optional
+ Number of iterations to run while fitting the probe aberrations to a low-order expansion
+        fit_probe_aberrations_max_angular_order: int
+            Max angular order of probe aberrations basis functions
+        fit_probe_aberrations_max_radial_order: int
+            Max radial order of probe aberrations basis functions
+ butterworth_filter_iter: int, optional
+            Number of iterations to run using Butterworth filtering (low- and/or high-pass)
+ q_lowpass_e: float
+ Cut-off frequency in A^-1 for low-pass filtering electrostatic object
+ q_lowpass_m: float
+ Cut-off frequency in A^-1 for low-pass filtering magnetic object
+ q_highpass_e: float
+ Cut-off frequency in A^-1 for high-pass filtering electrostatic object
+ q_highpass_m: float
+ Cut-off frequency in A^-1 for high-pass filtering magnetic object
+ butterworth_order: float
+ Butterworth filter order. Smaller gives a smoother filter
+ tv_denoise_iter: int, optional
+            Number of iterations to run while applying TV denoising to the object
+ tv_denoise_weight: float
+ Denoising weight. The greater `weight`, the more denoising.
+        tv_denoise_inner_iter: int
+ Number of iterations to run in inner loop of TV denoising
+ object_positivity: bool, optional
+ If True, forces object to be positive
+ shrinkage_rad: float
+ Phase shift in radians to be subtracted from the potential at each iteration
+ fix_potential_baseline: bool
+ If true, the potential mean outside the FOV is forced to zero at each iteration
+ switch_object_iter: int, optional
+ Iteration to switch object type between 'complex' and 'potential' or between
+ 'potential' and 'complex'
+ store_iterations: bool, optional
+ If True, reconstructed objects and probes are stored at each iteration
+ progress_bar: bool, optional
+ If True, reconstruction progress is displayed
+ reset: bool, optional
+ If True, previous reconstructions are ignored
+
+ Returns
+ --------
+ self: PtychographicReconstruction
+ Self to accommodate chaining
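+
+        Examples
+        --------
+        A minimal usage sketch (assumes `ptycho` is an already-preprocessed
+        instance of this class; the parameter values are illustrative only)::
+
+            ptycho = ptycho.reconstruct(
+                max_iter=32,
+                reconstruction_method="GD",
+                step_size=0.5,
+                warmup_iter=8,
+                store_iterations=True,
+            ).visualize()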
+ """
+ asnumpy = self._asnumpy
+ xp = self._xp
+
+ # Reconstruction method
+
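+        # All projection schemes below are instances of the generalized update
+        #   exit_waves <- x*exit_waves + a*overlap + b*P_F[c*overlap + y*exit_waves],
+        # with x = 1 - a - b, y = 1 - c, and P_F imposing the measured Fourier
+        # amplitudes (see _projection_sets_fourier_projection):
+        #   DM_AP(alpha): a = -alpha,     b = 1,     c = 1 + alpha
+        #   RAAR(beta):   a = 1 - 2*beta, b = beta,  c = 2
+        #   RRR(gamma):   a = -gamma,     b = gamma, c = 2
+        #   SUPERFLIP:    a = 0,          b = 1,     c = 2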
+ if reconstruction_method == "generalized-projections":
+ if (
+ reconstruction_parameter_a is None
+ or reconstruction_parameter_b is None
+ or reconstruction_parameter_c is None
+ ):
+ raise ValueError(
+ (
+ "reconstruction_parameter_a/b/c must all be specified "
+ "when using reconstruction_method='generalized-projections'."
+ )
+ )
+
+ use_projection_scheme = True
+ projection_a = reconstruction_parameter_a
+ projection_b = reconstruction_parameter_b
+ projection_c = reconstruction_parameter_c
+ step_size = None
+ elif (
+ reconstruction_method == "DM_AP"
+ or reconstruction_method == "difference-map_alternating-projections"
+ ):
+ if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0:
+ raise ValueError("reconstruction_parameter must be between 0-1.")
+
+ use_projection_scheme = True
+ projection_a = -reconstruction_parameter
+ projection_b = 1
+ projection_c = 1 + reconstruction_parameter
+ step_size = None
+ elif (
+ reconstruction_method == "RAAR"
+ or reconstruction_method == "relaxed-averaged-alternating-reflections"
+ ):
+ if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0:
+ raise ValueError("reconstruction_parameter must be between 0-1.")
+
+ use_projection_scheme = True
+ projection_a = 1 - 2 * reconstruction_parameter
+ projection_b = reconstruction_parameter
+ projection_c = 2
+ step_size = None
+ elif (
+ reconstruction_method == "RRR"
+ or reconstruction_method == "relax-reflect-reflect"
+ ):
+ if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0:
+ raise ValueError("reconstruction_parameter must be between 0-2.")
+
+ use_projection_scheme = True
+ projection_a = -reconstruction_parameter
+ projection_b = reconstruction_parameter
+ projection_c = 2
+ step_size = None
+ elif (
+ reconstruction_method == "SUPERFLIP"
+ or reconstruction_method == "charge-flipping"
+ ):
+ use_projection_scheme = True
+ projection_a = 0
+ projection_b = 1
+ projection_c = 2
+ reconstruction_parameter = None
+ step_size = None
+ elif (
+ reconstruction_method == "GD" or reconstruction_method == "gradient-descent"
+ ):
+ use_projection_scheme = False
+ projection_a = None
+ projection_b = None
+ projection_c = None
+ reconstruction_parameter = None
+ else:
+ raise ValueError(
+ (
+ "reconstruction_method must be one of 'generalized-projections', "
+ "'DM_AP' (or 'difference-map_alternating-projections'), "
+ "'RAAR' (or 'relaxed-averaged-alternating-reflections'), "
+ "'RRR' (or 'relax-reflect-reflect'), "
+ "'SUPERFLIP' (or 'charge-flipping'), "
+ f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}."
+ )
+ )
+
+ if use_projection_scheme and self._sim_recon_mode == 2:
+ raise NotImplementedError(
+ "simultaneous_measurements_mode == '0+' and projection set algorithms are currently incompatible."
+ )
+
+ if self._verbose:
+ if switch_object_iter > max_iter:
+ first_line = f"Performing {max_iter} iterations using a {self._object_type} object type, "
+ else:
+ switch_object_type = (
+ "complex" if self._object_type == "potential" else "potential"
+ )
+ first_line = (
+ f"Performing {switch_object_iter} iterations using a {self._object_type} object type and "
+ f"{max_iter - switch_object_iter} iterations using a {switch_object_type} object type, "
+ )
+ if max_batch_size is not None:
+ if use_projection_scheme:
+ raise ValueError(
+ (
+ "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. "
+ "Use reconstruction_method='GD' or set max_batch_size=None."
+ )
+ )
+ else:
+ print(
+ (
+ first_line + f"with the {reconstruction_method} algorithm, "
+ f"with normalization_min: {normalization_min} and step _size: {step_size}, "
+ f"in batches of max {max_batch_size} measurements."
+ )
+ )
+
+ else:
+ if reconstruction_parameter is not None:
+ if np.array(reconstruction_parameter).shape == (3,):
+ print(
+ (
+ first_line
+ + f"with the {reconstruction_method} algorithm, "
+ f"with normalization_min: {normalization_min} and (a,b,c): {reconstruction_parameter}."
+ )
+ )
+ else:
+ print(
+ (
+ first_line
+ + f"with the {reconstruction_method} algorithm, "
+ f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}."
+ )
+ )
+            else:
+                if step_size is not None:
+                    print(
+                        (
+                            first_line
+                            + f"with the {reconstruction_method} algorithm, "
+                            f"with normalization_min: {normalization_min} and step_size: {step_size}."
+                        )
+                    )
+                else:
+                    print(
+                        (
+                            first_line
+                            + f"with the {reconstruction_method} algorithm, "
+                            f"with normalization_min: {normalization_min}."
+                        )
+                    )
+
+ # Batching
+ shuffled_indices = np.arange(self._num_diffraction_patterns)
+ unshuffled_indices = np.zeros_like(shuffled_indices)
+
+ if max_batch_size is not None:
+ xp.random.seed(seed_random)
+ else:
+ max_batch_size = self._num_diffraction_patterns
+
+ # initialization
+ if store_iterations and (not hasattr(self, "object_iterations") or reset):
+ self.object_iterations = []
+ self.probe_iterations = []
+
+ if reset:
+ self._object = (
+ self._object_initial[0].copy(),
+ self._object_initial[1].copy(),
+ )
+ self._probe = self._probe_initial.copy()
+ self.error_iterations = []
+ self._positions_px = self._positions_px_initial.copy()
+ self._positions_px_fractional = self._positions_px - xp.round(
+ self._positions_px
+ )
+ (
+ self._vectorized_patch_indices_row,
+ self._vectorized_patch_indices_col,
+ ) = self._extract_vectorized_patch_indices()
+ self._exit_waves = (None,) * self._num_sim_measurements
+ self._object_type = self._object_type_initial
+ if hasattr(self, "_tf"):
+ del self._tf
+ elif reset is None:
+ if hasattr(self, "error"):
+ warnings.warn(
+ (
+ "Continuing reconstruction from previous result. "
+ "Use reset=True for a fresh start."
+ ),
+ UserWarning,
+ )
+ else:
+ self.error_iterations = []
+ self._exit_waves = (None,) * self._num_sim_measurements
+
+ if gaussian_filter_sigma_m is None:
+ gaussian_filter_sigma_m = gaussian_filter_sigma_e
+
+ if q_lowpass_m is None:
+ q_lowpass_m = q_lowpass_e
+
+ # main loop
+ for a0 in tqdmnd(
+ max_iter,
+ desc="Reconstructing object and probe",
+ unit=" iter",
+ disable=not progress_bar,
+ ):
+ error = 0.0
+
+ if a0 == switch_object_iter:
+ if self._object_type == "potential":
+ self._object_type = "complex"
+ self._object = (xp.exp(1j * self._object[0]), self._object[1])
+ elif self._object_type == "complex":
+ self._object_type = "potential"
+ self._object = (xp.angle(self._object[0]), self._object[1])
+
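+            # warmup finished: reset the magnetic object to its initial guess
+            # before joint electrostatic/magnetic updates begin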
+ if a0 == warmup_iter:
+ self._object = (self._object[0], self._object_initial[1].copy())
+
+ # randomize
+ if not use_projection_scheme:
+ np.random.shuffle(shuffled_indices)
+ unshuffled_indices[shuffled_indices] = np.arange(
+ self._num_diffraction_patterns
+ )
+ positions_px = self._positions_px.copy()[shuffled_indices]
+
+ for start, end in generate_batches(
+ self._num_diffraction_patterns, max_batch=max_batch_size
+ ):
+ # batch indices
+ self._positions_px = positions_px[start:end]
+ self._positions_px_fractional = self._positions_px - xp.round(
+ self._positions_px
+ )
+ (
+ self._vectorized_patch_indices_row,
+ self._vectorized_patch_indices_col,
+ ) = self._extract_vectorized_patch_indices()
+
+ amps = []
+ for amplitudes in self._amplitudes:
+ amps.append(amplitudes[shuffled_indices[start:end]])
+ amplitudes = tuple(amps)
+
+ # forward operator
+ (
+ shifted_probes,
+ object_patches,
+ overlap,
+ self._exit_waves,
+ batch_error,
+ ) = self._forward(
+ self._object,
+ self._probe,
+ amplitudes,
+ self._exit_waves,
+ warmup_iteration=a0 < warmup_iter,
+ use_projection_scheme=use_projection_scheme,
+ projection_a=projection_a,
+ projection_b=projection_b,
+ projection_c=projection_c,
+ )
+
+ # adjoint operator
+ self._object, self._probe = self._adjoint(
+ self._object,
+ self._probe,
+ object_patches,
+ shifted_probes,
+ self._exit_waves,
+ warmup_iteration=a0 < warmup_iter,
+ use_projection_scheme=use_projection_scheme,
+ step_size=step_size,
+ normalization_min=normalization_min,
+ fix_probe=a0 < fix_probe_iter,
+ )
+
+ # position correction
+ if a0 >= fix_positions_iter:
+ positions_px[start:end] = self._position_correction(
+ self._object[0],
+ shifted_probes,
+ overlap[0],
+ amplitudes[0],
+ self._positions_px,
+ positions_step_size,
+ constrain_position_distance,
+ )
+
+ error += batch_error
+
+ # Normalize Error
+ error /= self._mean_diffraction_intensity * self._num_diffraction_patterns
+
+ # constraints
+ self._positions_px = positions_px.copy()[unshuffled_indices]
+ self._object, self._probe, self._positions_px = self._constraints(
+ self._object,
+ self._probe,
+ self._positions_px,
+ fix_com=fix_com and a0 >= fix_probe_iter,
+ constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter
+ and a0 >= fix_probe_iter,
+ constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius,
+ constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width,
+ constrain_probe_fourier_amplitude=a0
+ < constrain_probe_fourier_amplitude_iter
+ and a0 >= fix_probe_iter,
+ constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels,
+ constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity,
+ fit_probe_aberrations=a0 < fit_probe_aberrations_iter
+ and a0 >= fix_probe_iter,
+ fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order,
+ fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order,
+ fix_probe_aperture=a0 < fix_probe_aperture_iter,
+ initial_probe_aperture=self._probe_initial_aperture,
+ fix_positions=a0 < fix_positions_iter,
+ global_affine_transformation=global_affine_transformation,
+ warmup_iteration=a0 < warmup_iter,
+ gaussian_filter=a0 < gaussian_filter_iter
+ and gaussian_filter_sigma_m is not None,
+ gaussian_filter_sigma_e=gaussian_filter_sigma_e,
+ gaussian_filter_sigma_m=gaussian_filter_sigma_m,
+ butterworth_filter=a0 < butterworth_filter_iter
+ and (q_lowpass_m is not None or q_highpass_m is not None),
+ q_lowpass_e=q_lowpass_e,
+ q_lowpass_m=q_lowpass_m,
+ q_highpass_e=q_highpass_e,
+ q_highpass_m=q_highpass_m,
+ butterworth_order=butterworth_order,
+ tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None,
+ tv_denoise_weight=tv_denoise_weight,
+ tv_denoise_inner_iter=tv_denoise_inner_iter,
+ object_positivity=object_positivity,
+ shrinkage_rad=shrinkage_rad,
+ object_mask=self._object_fov_mask_inverse
+ if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0
+ else None,
+ pure_phase_object=a0 < pure_phase_object_iter
+ and self._object_type == "complex",
+ )
+
+ self.error_iterations.append(error.item())
+ if store_iterations:
+ if a0 < warmup_iter:
+ self.object_iterations.append(
+ (asnumpy(self._object[0].copy()), None)
+ )
+ else:
+ self.object_iterations.append(
+ (
+ asnumpy(self._object[0].copy()),
+ asnumpy(self._object[1].copy()),
+ )
+ )
+ self.probe_iterations.append(self.probe_centered)
+
+ # store result
+ if a0 < warmup_iter:
+ self.object = (asnumpy(self._object[0]), None)
+ else:
+ self.object = (asnumpy(self._object[0]), asnumpy(self._object[1]))
+ self.probe = self.probe_centered
+ self.error = error.item()
+
+ if self._device == "gpu":
+ xp._default_memory_pool.free_all_blocks()
+ xp.clear_memo()
+
+ return self
+
+ def _visualize_last_iteration_figax(
+ self,
+ fig,
+ object_ax,
+        convergence_ax,
+ cbar: bool,
+ padding: int = 0,
+ **kwargs,
+ ):
+ """
+ Displays last reconstructed object on a given fig/ax.
+
+ Parameters
+ --------
+ fig: Figure
+            Matplotlib figure that object_ax lives in
+ object_ax: Axes
+ Matplotlib axes to plot reconstructed object in
+ convergence_ax: Axes, optional
+ Matplotlib axes to plot convergence plot in
+ cbar: bool, optional
+ If true, displays a colorbar
+        padding: int, optional
+            Pixels to pad by after rotating and cropping the object
+ """
+ cmap = kwargs.pop("cmap", "magma")
+
+ if self._object_type == "complex":
+ obj = np.angle(self.object[0])
+ else:
+ obj = self.object[0]
+
+ rotated_object = self._crop_rotate_object_fov(obj, padding=padding)
+ rotated_shape = rotated_object.shape
+
+ extent = [
+ 0,
+ self.sampling[1] * rotated_shape[1],
+ self.sampling[0] * rotated_shape[0],
+ 0,
+ ]
+
+ im = object_ax.imshow(
+ rotated_object,
+ extent=extent,
+ cmap=cmap,
+ **kwargs,
+ )
+
+ if cbar:
+ divider = make_axes_locatable(object_ax)
+ ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
+ fig.add_axes(ax_cb)
+ fig.colorbar(im, cax=ax_cb)
+
+ if convergence_ax is not None and hasattr(self, "error_iterations"):
+ kwargs.pop("vmin", None)
+ kwargs.pop("vmax", None)
+ errors = self.error_iterations
+ convergence_ax.semilogy(np.arange(len(errors)), errors, **kwargs)
+
+ def _visualize_last_iteration(
+ self,
+ fig,
+ cbar: bool,
+ plot_convergence: bool,
+ plot_probe: bool,
+ plot_fourier_probe: bool,
+ padding: int,
+ **kwargs,
+ ):
+ """
+ Displays last reconstructed object and probe iterations.
+
+ Parameters
+ --------
+ fig: Figure
+ Matplotlib figure to place Gridspec in
+ plot_convergence: bool, optional
+ If true, the normalized mean squared error (NMSE) plot is displayed
+ cbar: bool, optional
+ If true, displays a colorbar
+ plot_probe: bool, optional
+ If true, the reconstructed complex probe is displayed
+ plot_fourier_probe: bool, optional
+ If true, the reconstructed complex Fourier probe is displayed
+        padding: int, optional
+            Pixels to pad by after rotating and cropping the object
+ """
+ figsize = kwargs.pop("figsize", (12, 5))
+ cmap_e = kwargs.pop("cmap_e", "magma")
+ cmap_m = kwargs.pop("cmap_m", "PuOr")
+
+ if self._object_type == "complex":
+ obj_e = np.angle(self.object[0])
+ obj_m = self.object[1]
+ else:
+ obj_e, obj_m = self.object
+
+ rotated_electrostatic = self._crop_rotate_object_fov(obj_e, padding=padding)
+ rotated_magnetic = self._crop_rotate_object_fov(obj_m, padding=padding)
+ rotated_shape = rotated_electrostatic.shape
+
+ min_e = rotated_electrostatic.min()
+ max_e = rotated_electrostatic.max()
+ max_m = np.abs(rotated_magnetic).max()
+ min_m = -max_m
+
+ vmin_e = kwargs.pop("vmin_e", min_e)
+ vmax_e = kwargs.pop("vmax_e", max_e)
+ vmin_m = kwargs.pop("vmin_m", min_m)
+ vmax_m = kwargs.pop("vmax_m", max_m)
+
+ if plot_fourier_probe:
+ chroma_boost = kwargs.pop("chroma_boost", 2)
+ else:
+ chroma_boost = kwargs.pop("chroma_boost", 1)
+
+ extent = [
+ 0,
+ self.sampling[1] * rotated_shape[1],
+ self.sampling[0] * rotated_shape[0],
+ 0,
+ ]
+
+ if plot_fourier_probe:
+ probe_extent = [
+ -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2,
+ self.angular_sampling[1] * self._region_of_interest_shape[1] / 2,
+ self.angular_sampling[0] * self._region_of_interest_shape[0] / 2,
+ -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2,
+ ]
+ elif plot_probe:
+ probe_extent = [
+ 0,
+ self.sampling[1] * self._region_of_interest_shape[1],
+ self.sampling[0] * self._region_of_interest_shape[0],
+ 0,
+ ]
+
+ if plot_convergence:
+ if plot_probe or plot_fourier_probe:
+ spec = GridSpec(
+ ncols=3,
+ nrows=2,
+ height_ratios=[4, 1],
+ hspace=0.15,
+ width_ratios=[
+ 1,
+ 1,
+ (probe_extent[1] / probe_extent[2]) / (extent[1] / extent[2]),
+ ],
+ wspace=0.35,
+ )
+ else:
+ spec = GridSpec(ncols=2, nrows=2, height_ratios=[4, 1], hspace=0.15)
+ else:
+ if plot_probe or plot_fourier_probe:
+ spec = GridSpec(
+ ncols=3,
+ nrows=1,
+ width_ratios=[
+ 1,
+ 1,
+ (probe_extent[1] / probe_extent[2]) / (extent[1] / extent[2]),
+ ],
+ wspace=0.35,
+ )
+ else:
+ spec = GridSpec(ncols=2, nrows=1)
+
+ if fig is None:
+ fig = plt.figure(figsize=figsize)
+
+ if plot_probe or plot_fourier_probe:
+ # Electrostatic Object
+ ax = fig.add_subplot(spec[0, 0])
+ im = ax.imshow(
+ rotated_electrostatic,
+ extent=extent,
+ cmap=cmap_e,
+ vmin=vmin_e,
+ vmax=vmax_e,
+ **kwargs,
+ )
+ ax.set_ylabel("x [A]")
+ ax.set_xlabel("y [A]")
+ if self._object_type == "potential":
+ ax.set_title("Reconstructed electrostatic potential")
+ elif self._object_type == "complex":
+ ax.set_title("Reconstructed electrostatic phase")
+
+ if cbar:
+ divider = make_axes_locatable(ax)
+ ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
+ fig.add_axes(ax_cb)
+ fig.colorbar(im, cax=ax_cb)
+
+ # Magnetic Object
+ ax = fig.add_subplot(spec[0, 1])
+ im = ax.imshow(
+ rotated_magnetic,
+ extent=extent,
+ cmap=cmap_m,
+ vmin=vmin_m,
+ vmax=vmax_m,
+ **kwargs,
+ )
+ ax.set_ylabel("x [A]")
+ ax.set_xlabel("y [A]")
+ ax.set_title("Reconstructed magnetic potential")
+
+ if cbar:
+ divider = make_axes_locatable(ax)
+ ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
+ fig.add_axes(ax_cb)
+ fig.colorbar(im, cax=ax_cb)
+
+ # Probe
+ ax = fig.add_subplot(spec[0, 2])
+ if plot_fourier_probe:
+ probe_array = Complex2RGB(
+ self.probe_fourier,
+ chroma_boost=chroma_boost,
+ )
+ ax.set_title("Reconstructed Fourier probe")
+ ax.set_ylabel("kx [mrad]")
+ ax.set_xlabel("ky [mrad]")
+ else:
+ probe_array = Complex2RGB(
+ self.probe, power=2, chroma_boost=chroma_boost
+ )
+ ax.set_title("Reconstructed probe intensity")
+ ax.set_ylabel("x [A]")
+ ax.set_xlabel("y [A]")
+
+ im = ax.imshow(
+ probe_array,
+ extent=probe_extent,
+ )
+
+ if cbar:
+ divider = make_axes_locatable(ax)
+ ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
+ add_colorbar_arg(ax_cb, chroma_boost=chroma_boost)
+
+ else:
+ # Electrostatic Object
+ ax = fig.add_subplot(spec[0, 0])
+ im = ax.imshow(
+ rotated_electrostatic,
+ extent=extent,
+ cmap=cmap_e,
+ vmin=vmin_e,
+ vmax=vmax_e,
+ **kwargs,
+ )
+ ax.set_ylabel("x [A]")
+ ax.set_xlabel("y [A]")
+ if self._object_type == "potential":
+ ax.set_title("Reconstructed electrostatic potential")
+ elif self._object_type == "complex":
+ ax.set_title("Reconstructed electrostatic phase")
+
+ if cbar:
+ divider = make_axes_locatable(ax)
+ ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
+ fig.add_axes(ax_cb)
+ fig.colorbar(im, cax=ax_cb)
+
+ # Magnetic Object
+ ax = fig.add_subplot(spec[0, 1])
+ im = ax.imshow(
+ rotated_magnetic,
+ extent=extent,
+ cmap=cmap_m,
+ vmin=vmin_m,
+ vmax=vmax_m,
+ **kwargs,
+ )
+ ax.set_ylabel("x [A]")
+ ax.set_xlabel("y [A]")
+ ax.set_title("Reconstructed magnetic potential")
+
+ if cbar:
+ divider = make_axes_locatable(ax)
+ ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
+ fig.add_axes(ax_cb)
+ fig.colorbar(im, cax=ax_cb)
+
+ if plot_convergence and hasattr(self, "error_iterations"):
+ errors = np.array(self.error_iterations)
+ ax = fig.add_subplot(spec[1, :])
+ ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs)
+ ax.set_ylabel("NMSE")
+ ax.set_xlabel("Iteration number")
+ ax.yaxis.tick_right()
+
+ fig.suptitle(f"Normalized mean squared error: {self.error:.3e}")
+ spec.tight_layout(fig)
+
+ def _visualize_all_iterations(
+ self,
+ fig,
+ cbar: bool,
+ plot_convergence: bool,
+ plot_probe: bool,
+ plot_fourier_probe: bool,
+ iterations_grid: Tuple[int, int],
+ padding: int,
+ **kwargs,
+ ):
+ """
+ Displays all reconstructed object and probe iterations.
+
+ Parameters
+ --------
+ fig: Figure
+ Matplotlib figure to place Gridspec in
+ plot_convergence: bool, optional
+ If true, the normalized mean squared error (NMSE) plot is displayed
+ iterations_grid: Tuple[int,int]
+ Grid dimensions to plot reconstruction iterations
+ cbar: bool, optional
+ If true, displays a colorbar
+ plot_probe: bool, optional
+ If true, the reconstructed complex probe is displayed
+ plot_fourier_probe: bool, optional
+ If true, the reconstructed complex Fourier probe is displayed
+        padding: int, optional
+            Pixels to pad by after rotating and cropping the object
+ """
+        raise NotImplementedError(
+            "Plotting all iterations is not implemented for this class."
+        )
+
+ def visualize(
+ self,
+ fig=None,
+ iterations_grid: Tuple[int, int] = None,
+ plot_convergence: bool = True,
+ plot_probe: bool = True,
+ plot_fourier_probe: bool = False,
+ cbar: bool = True,
+ padding: int = 0,
+ **kwargs,
+ ):
+ """
+ Displays reconstructed object and probe.
+
+ Parameters
+ --------
+ fig: Figure
+ Matplotlib figure to place Gridspec in
+ plot_convergence: bool, optional
+ If true, the normalized mean squared error (NMSE) plot is displayed
+ iterations_grid: Tuple[int,int]
+ Grid dimensions to plot reconstruction iterations
+ cbar: bool, optional
+ If true, displays a colorbar
+ plot_probe: bool, optional
+ If true, the reconstructed complex probe is displayed
+ plot_fourier_probe: bool, optional
+ If true, the reconstructed complex Fourier probe is displayed
+        padding: int, optional
+            Pixels to pad by after rotating and cropping the object
+
+ Returns
+ --------
+ self: PtychographicReconstruction
+ Self to accommodate chaining
+ """
+
+ if iterations_grid is None:
+ self._visualize_last_iteration(
+ fig=fig,
+ plot_convergence=plot_convergence,
+ plot_probe=plot_probe,
+ plot_fourier_probe=plot_fourier_probe,
+ cbar=cbar,
+ padding=padding,
+ **kwargs,
+ )
+ else:
+ self._visualize_all_iterations(
+ fig=fig,
+ plot_convergence=plot_convergence,
+ iterations_grid=iterations_grid,
+ plot_probe=plot_probe,
+ plot_fourier_probe=plot_fourier_probe,
+ cbar=cbar,
+ padding=padding,
+ **kwargs,
+ )
+
+ return self
+
+ @property
+ def self_consistency_errors(self):
+ """Compute the self-consistency errors for each probe position"""
+
+ xp = self._xp
+ asnumpy = self._asnumpy
+
+ # Re-initialize fractional positions and vector patches, max_batch_size = None
+ self._positions_px_fractional = self._positions_px - xp.round(
+ self._positions_px
+ )
+
+ (
+ self._vectorized_patch_indices_row,
+ self._vectorized_patch_indices_col,
+ ) = self._extract_vectorized_patch_indices()
+
+ # Overlaps
+ _, _, overlap = self._warmup_overlap_projection(self._object, self._probe)
+ fourier_overlap = xp.fft.fft2(overlap[0])
+
+ # Normalized mean-squared errors
+ error = xp.sum(
+ xp.abs(self._amplitudes[0] - xp.abs(fourier_overlap)) ** 2, axis=(-2, -1)
+ )
+ error /= self._mean_diffraction_intensity
+
+ return asnumpy(error)
+
+ def _return_self_consistency_errors(
+ self,
+ max_batch_size=None,
+ ):
+ """Compute the self-consistency errors for each probe position"""
+
+ xp = self._xp
+ asnumpy = self._asnumpy
+
+ # Batch-size
+ if max_batch_size is None:
+ max_batch_size = self._num_diffraction_patterns
+
+ # Re-initialize fractional positions and vector patches
+ errors = np.array([])
+ positions_px = self._positions_px.copy()
+
+ for start, end in generate_batches(
+ self._num_diffraction_patterns, max_batch=max_batch_size
+ ):
+ # batch indices
+ self._positions_px = positions_px[start:end]
+ self._positions_px_fractional = self._positions_px - xp.round(
+ self._positions_px
+ )
+ (
+ self._vectorized_patch_indices_row,
+ self._vectorized_patch_indices_col,
+ ) = self._extract_vectorized_patch_indices()
+ amplitudes = self._amplitudes[0][start:end]
+
+ # Overlaps
+ _, _, overlap = self._warmup_overlap_projection(self._object, self._probe)
+ fourier_overlap = xp.fft.fft2(overlap[0])
+
+ # Normalized mean-squared errors
+ batch_errors = xp.sum(
+ xp.abs(amplitudes - xp.abs(fourier_overlap)) ** 2, axis=(-2, -1)
+ )
+ errors = np.hstack((errors, batch_errors))
+
+ self._positions_px = positions_px.copy()
+ errors /= self._mean_diffraction_intensity
+
+ return asnumpy(errors)
+
+ def _return_projected_cropped_potential(
+ self,
+ ):
+ """Utility function to accommodate multiple classes"""
+ if self._object_type == "complex":
+ projected_cropped_potential = np.angle(self.object_cropped[0])
+ else:
+ projected_cropped_potential = self.object_cropped[0]
+
+ return projected_cropped_potential
+
+ @property
+ def object_cropped(self):
+ """Cropped and rotated object"""
+
+ obj_e, obj_m = self._object
+ obj_e = self._crop_rotate_object_fov(obj_e)
+ obj_m = self._crop_rotate_object_fov(obj_m)
+ return (obj_e, obj_m)
diff --git a/py4DSTEM/process/phase/iterative_singleslice_ptychography.py b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py
new file mode 100644
index 000000000..350d0a3cb
--- /dev/null
+++ b/py4DSTEM/process/phase/iterative_singleslice_ptychography.py
@@ -0,0 +1,2213 @@
+"""
+Module for reconstructing phase objects from 4DSTEM datasets using iterative methods,
+namely (single-slice) ptychography.
+"""
+
+import warnings
+from typing import Mapping, Tuple
+
+import matplotlib.pyplot as plt
+import numpy as np
+from matplotlib.gridspec import GridSpec
+from mpl_toolkits.axes_grid1 import ImageGrid, make_axes_locatable
+from py4DSTEM.visualize.vis_special import Complex2RGB, add_colorbar_arg
+
+try:
+ import cupy as cp
+except ModuleNotFoundError:
+ cp = np
+
+from emdfile import Custom, tqdmnd
+from py4DSTEM.datacube import DataCube
+from py4DSTEM.process.phase.iterative_base_class import PtychographicReconstruction
+from py4DSTEM.process.phase.utils import (
+ ComplexProbe,
+ fft_shift,
+ generate_batches,
+ polar_aliases,
+ polar_symbols,
+)
+from py4DSTEM.process.utils import get_CoM, get_shifted_ar
+
+warnings.simplefilter(action="always", category=UserWarning)
+
+
+class SingleslicePtychographicReconstruction(PtychographicReconstruction):
+ """
+ Iterative Ptychographic Reconstruction Class.
+
+ Diffraction intensities dimensions : (Rx,Ry,Qx,Qy)
+ Reconstructed probe dimensions : (Sx,Sy)
+ Reconstructed object dimensions : (Px,Py)
+
+ such that (Sx,Sy) is the region-of-interest (ROI) size of our probe
+    and (Px,Py) is the padded-object size within which we position our ROI.
+
+ Parameters
+ ----------
+ energy: float
+ The electron energy of the wave functions in eV
+ datacube: DataCube
+ Input 4D diffraction pattern intensities
+ semiangle_cutoff: float, optional
+ Semiangle cutoff for the initial probe guess in mrad
+ semiangle_cutoff_pixels: float, optional
+ Semiangle cutoff for the initial probe guess in pixels
+ rolloff: float, optional
+ Semiangle rolloff for the initial probe guess
+ vacuum_probe_intensity: np.ndarray, optional
+ Vacuum probe to use as intensity aperture for initial probe guess
+ polar_parameters: dict, optional
+ Mapping from aberration symbols to their corresponding values. All aberration
+ magnitudes should be given in Å and angles should be given in radians.
+ object_padding_px: Tuple[int,int], optional
+ Pixel dimensions to pad object with
+ If None, the padding is set to half the probe ROI dimensions
+ initial_object_guess: np.ndarray, optional
+ Initial guess for complex-valued object of dimensions (Px,Py)
+ If None, initialized to 1.0j
+ initial_probe_guess: np.ndarray, optional
+ Initial guess for complex-valued probe of dimensions (Sx,Sy). If None,
+ initialized to ComplexProbe with semiangle_cutoff, energy, and aberrations
+ initial_scan_positions: np.ndarray, optional
+ Probe positions in Å for each diffraction intensity
+ If None, initialized to a grid scan
+ verbose: bool, optional
+ If True, class methods will inherit this and print additional information
+ device: str, optional
+        Device the calculation will be performed on. Must be 'cpu' or 'gpu'
+ object_type: str, optional
+ The object can be reconstructed as a real potential ('potential') or a complex
+ object ('complex')
+ positions_mask: np.ndarray, optional
+ Boolean real space mask to select positions in datacube to skip for reconstruction
+ name: str, optional
+ Class name
+ kwargs:
+ Provide the aberration coefficients as keyword arguments.
+ """
+
+ # Class-specific Metadata
+ _class_specific_metadata = ()
+
+ def __init__(
+ self,
+ energy: float,
+ datacube: DataCube = None,
+ semiangle_cutoff: float = None,
+ semiangle_cutoff_pixels: float = None,
+ rolloff: float = 2.0,
+ vacuum_probe_intensity: np.ndarray = None,
+ polar_parameters: Mapping[str, float] = None,
+ initial_object_guess: np.ndarray = None,
+ initial_probe_guess: np.ndarray = None,
+ initial_scan_positions: np.ndarray = None,
+ object_padding_px: Tuple[int, int] = None,
+ object_type: str = "complex",
+ positions_mask: np.ndarray = None,
+ verbose: bool = True,
+ device: str = "cpu",
+ name: str = "ptychographic_reconstruction",
+ **kwargs,
+ ):
+ Custom.__init__(self, name=name)
+
+ if device == "cpu":
+ self._xp = np
+ self._asnumpy = np.asarray
+ from scipy.ndimage import gaussian_filter
+
+ self._gaussian_filter = gaussian_filter
+ from scipy.special import erf
+
+ self._erf = erf
+ elif device == "gpu":
+ self._xp = cp
+ self._asnumpy = cp.asnumpy
+ from cupyx.scipy.ndimage import gaussian_filter
+
+ self._gaussian_filter = gaussian_filter
+ from cupyx.scipy.special import erf
+
+ self._erf = erf
+ else:
+ raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}")
+
+ for key in kwargs.keys():
+ if (key not in polar_symbols) and (key not in polar_aliases.keys()):
+ raise ValueError("{} not a recognized parameter".format(key))
+
+ self._polar_parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols)))
+
+ if polar_parameters is None:
+ polar_parameters = {}
+
+ polar_parameters.update(kwargs)
+ self._set_polar_parameters(polar_parameters)
+
+ if object_type != "potential" and object_type != "complex":
+ raise ValueError(
+ f"object_type must be either 'potential' or 'complex', not {object_type}"
+ )
+
+ if positions_mask is not None and positions_mask.dtype != "bool":
+ warnings.warn(
+ ("`positions_mask` converted to `bool` array"),
+ UserWarning,
+ )
+ positions_mask = np.asarray(positions_mask, dtype="bool")
+
+ self.set_save_defaults()
+
+ # Data
+ self._datacube = datacube
+ self._object = initial_object_guess
+ self._probe = initial_probe_guess
+
+ # Common Metadata
+ self._vacuum_probe_intensity = vacuum_probe_intensity
+ self._scan_positions = initial_scan_positions
+ self._energy = energy
+ self._semiangle_cutoff = semiangle_cutoff
+ self._semiangle_cutoff_pixels = semiangle_cutoff_pixels
+ self._rolloff = rolloff
+ self._object_type = object_type
+ self._object_padding_px = object_padding_px
+ self._positions_mask = positions_mask
+ self._verbose = verbose
+ self._device = device
+ self._preprocessed = False
+
+ # Class-specific Metadata
+
+ def preprocess(
+ self,
+ diffraction_intensities_shape: Tuple[int, int] = None,
+ reshaping_method: str = "fourier",
+ probe_roi_shape: Tuple[int, int] = None,
+ dp_mask: np.ndarray = None,
+ fit_function: str = "plane",
+ plot_center_of_mass: str = "default",
+ plot_rotation: bool = True,
+ maximize_divergence: bool = False,
+ rotation_angles_deg: np.ndarray = np.arange(-89.0, 90.0, 1.0),
+ plot_probe_overlaps: bool = True,
+ force_com_rotation: float = None,
+ force_com_transpose: float = None,
+ force_com_shifts: float = None,
+ force_scan_sampling: float = None,
+ force_angular_sampling: float = None,
+ force_reciprocal_sampling: float = None,
+ object_fov_mask: np.ndarray = None,
+ crop_patterns: bool = False,
+ **kwargs,
+ ):
+ """
+ Ptychographic preprocessing step.
+ Calls the base class methods:
+
+ _extract_intensities_and_calibrations_from_datacube,
+ _compute_center_of_mass(),
+ _solve_CoM_rotation(),
+ _normalize_diffraction_intensities()
+ _calculate_scan_positions_in_px()
+
+ Additionally, it initializes an (Px,Py) array of 1.0j
+ and a complex probe using the specified polar parameters.
+
+ Parameters
+ ----------
+ diffraction_intensities_shape: Tuple[int,int], optional
+ Pixel dimensions (Qx',Qy') of the resampled diffraction intensities
+            If None, no resampling of diffraction intensities is performed
+ reshaping_method: str, optional
+            Method to use for reshaping, either 'bin', 'bilinear', or 'fourier' (default)
+        probe_roi_shape: Tuple[int,int], optional
+ Padded diffraction intensities shape.
+ If None, no padding is performed
+ dp_mask: ndarray, optional
+ Mask for datacube intensities (Qx,Qy)
+ fit_function: str, optional
+ 2D fitting function for CoM fitting. One of 'plane','parabola','bezier_two'
+ plot_center_of_mass: str, optional
+ If 'default', the corrected CoM arrays will be displayed
+ If 'all', the computed and fitted CoM arrays will be displayed
+ plot_rotation: bool, optional
+ If True, the CoM curl minimization search result will be displayed
+ maximize_divergence: bool, optional
+ If True, the divergence of the CoM gradient vector field is maximized
+        rotation_angles_deg: np.ndarray, optional
+ Array of angles in degrees to perform curl minimization over
+ plot_probe_overlaps: bool, optional
+ If True, initial probe overlaps scanned over the object will be displayed
+ force_com_rotation: float (degrees), optional
+ Force relative rotation angle between real and reciprocal space
+ force_com_transpose: bool, optional
+ Force whether diffraction intensities need to be transposed.
+        force_com_shifts: tuple of ndarrays (CoMx, CoMy)
+            By default, amplitudes are computed from diffraction patterns
+            shifted so the fitted CoM sits in the upper-left corner for each
+            probe position; if provided, these shifts are used instead.
+ force_scan_sampling: float, optional
+ Override DataCube real space scan pixel size calibrations, in Angstrom
+ force_angular_sampling: float, optional
+ Override DataCube reciprocal pixel size calibration, in mrad
+ force_reciprocal_sampling: float, optional
+ Override DataCube reciprocal pixel size calibration, in A^-1
+ object_fov_mask: np.ndarray (boolean)
+ Boolean mask of FOV. Used to calculate additional shrinkage of object
+ If None, probe_overlap intensity is thresholded
+ crop_patterns: bool
+            If True, crop patterns to avoid wrap-around when centering
+
+ Returns
+ --------
+ self: PtychographicReconstruction
+ Self to accommodate chaining
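+
+        Examples
+        --------
+        A minimal sketch, assuming `dataset` is a calibrated DataCube (the
+        variable name and parameter values are illustrative only)::
+
+            ptycho = SingleslicePtychographicReconstruction(
+                datacube=dataset,
+                energy=80e3,
+                semiangle_cutoff=21.4,
+            ).preprocess(
+                plot_rotation=False,
+                plot_probe_overlaps=False,
+            )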
+ """
+ xp = self._xp
+ asnumpy = self._asnumpy
+
+ # set additional metadata
+ self._diffraction_intensities_shape = diffraction_intensities_shape
+ self._reshaping_method = reshaping_method
+ self._probe_roi_shape = probe_roi_shape
+ self._dp_mask = dp_mask
+
+ if self._datacube is None:
+ raise ValueError(
+ (
+ "The preprocess() method requires a DataCube. "
+ "Please run ptycho.attach_datacube(DataCube) first."
+ )
+ )
+
+ (
+ self._datacube,
+ self._vacuum_probe_intensity,
+ self._dp_mask,
+ force_com_shifts,
+ ) = self._preprocess_datacube_and_vacuum_probe(
+ self._datacube,
+ diffraction_intensities_shape=self._diffraction_intensities_shape,
+ reshaping_method=self._reshaping_method,
+ probe_roi_shape=self._probe_roi_shape,
+ vacuum_probe_intensity=self._vacuum_probe_intensity,
+ dp_mask=self._dp_mask,
+ com_shifts=force_com_shifts,
+ )
+
+ self._intensities = self._extract_intensities_and_calibrations_from_datacube(
+ self._datacube,
+ require_calibrations=True,
+ force_scan_sampling=force_scan_sampling,
+ force_angular_sampling=force_angular_sampling,
+ force_reciprocal_sampling=force_reciprocal_sampling,
+ )
+
+ (
+ self._com_measured_x,
+ self._com_measured_y,
+ self._com_fitted_x,
+ self._com_fitted_y,
+ self._com_normalized_x,
+ self._com_normalized_y,
+ ) = self._calculate_intensities_center_of_mass(
+ self._intensities,
+ dp_mask=self._dp_mask,
+ fit_function=fit_function,
+ com_shifts=force_com_shifts,
+ )
+
+ (
+ self._rotation_best_rad,
+ self._rotation_best_transpose,
+ self._com_x,
+ self._com_y,
+ self.com_x,
+ self.com_y,
+ ) = self._solve_for_center_of_mass_relative_rotation(
+ self._com_measured_x,
+ self._com_measured_y,
+ self._com_normalized_x,
+ self._com_normalized_y,
+ rotation_angles_deg=rotation_angles_deg,
+ plot_rotation=plot_rotation,
+ plot_center_of_mass=plot_center_of_mass,
+ maximize_divergence=maximize_divergence,
+ force_com_rotation=force_com_rotation,
+ force_com_transpose=force_com_transpose,
+ **kwargs,
+ )
+
+ (
+ self._amplitudes,
+ self._mean_diffraction_intensity,
+ ) = self._normalize_diffraction_intensities(
+ self._intensities,
+ self._com_fitted_x,
+ self._com_fitted_y,
+ crop_patterns,
+ self._positions_mask,
+ )
+
+        # save the values we still need, then free the raw intensities
+        self._num_diffraction_patterns = self._amplitudes.shape[0]
+        self._region_of_interest_shape = np.array(self._amplitudes.shape[-2:])
+        del self._intensities
+
+ self._positions_px = self._calculate_scan_positions_in_pixels(
+ self._scan_positions, self._positions_mask
+ )
+
+ # handle semiangle specified in pixels
+ if self._semiangle_cutoff_pixels:
+ self._semiangle_cutoff = (
+ self._semiangle_cutoff_pixels * self._angular_sampling[0]
+ )
+
+ # Object Initialization
+ if self._object is None:
+ pad_x = self._object_padding_px[0][1]
+ pad_y = self._object_padding_px[1][1]
+ p, q = np.round(np.max(self._positions_px, axis=0))
+ p = np.max([np.round(p + pad_x), self._region_of_interest_shape[0]]).astype(
+ "int"
+ )
+ q = np.max([np.round(q + pad_y), self._region_of_interest_shape[1]]).astype(
+ "int"
+ )
+ if self._object_type == "potential":
+ self._object = xp.zeros((p, q), dtype=xp.float32)
+ elif self._object_type == "complex":
+ self._object = xp.ones((p, q), dtype=xp.complex64)
+ else:
+ if self._object_type == "potential":
+ self._object = xp.asarray(self._object, dtype=xp.float32)
+ elif self._object_type == "complex":
+ self._object = xp.asarray(self._object, dtype=xp.complex64)
+
+ self._object_initial = self._object.copy()
+ self._object_type_initial = self._object_type
+ self._object_shape = self._object.shape
+
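+        # center the scan positions on the padded object array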
+ self._positions_px = xp.asarray(self._positions_px, dtype=xp.float32)
+ self._positions_px_com = xp.mean(self._positions_px, axis=0)
+ self._positions_px -= self._positions_px_com - xp.array(self._object_shape) / 2
+ self._positions_px_com = xp.mean(self._positions_px, axis=0)
+ self._positions_px_fractional = self._positions_px - xp.round(
+ self._positions_px
+ )
+
+ self._positions_px_initial = self._positions_px.copy()
+ self._positions_initial = self._positions_px_initial.copy()
+ self._positions_initial[:, 0] *= self.sampling[0]
+ self._positions_initial[:, 1] *= self.sampling[1]
+
+ # Vectorized Patches
+ (
+ self._vectorized_patch_indices_row,
+ self._vectorized_patch_indices_col,
+ ) = self._extract_vectorized_patch_indices()
+
+ # Probe Initialization
+ if self._probe is None:
+ if self._vacuum_probe_intensity is not None:
+ self._semiangle_cutoff = np.inf
+ self._vacuum_probe_intensity = xp.asarray(
+ self._vacuum_probe_intensity, dtype=xp.float32
+ )
+ probe_x0, probe_y0 = get_CoM(
+ self._vacuum_probe_intensity,
+ device=self._device,
+ )
+ self._vacuum_probe_intensity = get_shifted_ar(
+ self._vacuum_probe_intensity,
+ -probe_x0,
+ -probe_y0,
+ bilinear=True,
+ device=self._device,
+ )
+ if crop_patterns:
+ self._vacuum_probe_intensity = self._vacuum_probe_intensity[
+ self._crop_mask
+ ].reshape(self._region_of_interest_shape)
+
+ self._probe = (
+ ComplexProbe(
+ gpts=self._region_of_interest_shape,
+ sampling=self.sampling,
+ energy=self._energy,
+ semiangle_cutoff=self._semiangle_cutoff,
+ rolloff=self._rolloff,
+ vacuum_probe_intensity=self._vacuum_probe_intensity,
+ parameters=self._polar_parameters,
+ device=self._device,
+ )
+ .build()
+ ._array
+ )
+
+ # Normalize probe to match mean diffraction intensity
+ probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2)
+ self._probe *= xp.sqrt(self._mean_diffraction_intensity / probe_intensity)
+
+ else:
+ if isinstance(self._probe, ComplexProbe):
+ if self._probe._gpts != self._region_of_interest_shape:
+                    raise ValueError(
+                        "probe gpts must match the region-of-interest shape."
+                    )
+ if hasattr(self._probe, "_array"):
+ self._probe = self._probe._array
+ else:
+ self._probe._xp = xp
+ self._probe = self._probe.build()._array
+
+ # Normalize probe to match mean diffraction intensity
+ probe_intensity = xp.sum(xp.abs(xp.fft.fft2(self._probe)) ** 2)
+ self._probe *= xp.sqrt(
+ self._mean_diffraction_intensity / probe_intensity
+ )
+ else:
+ self._probe = xp.asarray(self._probe, dtype=xp.complex64)
+
+ self._probe_initial = self._probe.copy()
+ self._probe_initial_aperture = xp.abs(xp.fft.fft2(self._probe))
+
+ self._known_aberrations_array = ComplexProbe(
+ energy=self._energy,
+ gpts=self._region_of_interest_shape,
+ sampling=self.sampling,
+ parameters=self._polar_parameters,
+ device=self._device,
+ )._evaluate_ctf()
+
+ # overlaps
+ shifted_probes = fft_shift(self._probe, self._positions_px_fractional, xp)
+ probe_intensities = xp.abs(shifted_probes) ** 2
+ probe_overlap = self._sum_overlapping_patches_bincounts(probe_intensities)
+ probe_overlap = self._gaussian_filter(probe_overlap, 1.0)
+
+ if object_fov_mask is None:
+ self._object_fov_mask = asnumpy(probe_overlap > 0.25 * probe_overlap.max())
+ else:
+ self._object_fov_mask = np.asarray(object_fov_mask)
+ self._object_fov_mask_inverse = np.invert(self._object_fov_mask)
+
+ if plot_probe_overlaps:
+ figsize = kwargs.pop("figsize", (9, 4))
+ chroma_boost = kwargs.pop("chroma_boost", 1)
+
+ # initial probe
+ complex_probe_rgb = Complex2RGB(
+ self.probe_centered,
+ power=2,
+ chroma_boost=chroma_boost,
+ )
+
+ extent = [
+ 0,
+ self.sampling[1] * self._object_shape[1],
+ self.sampling[0] * self._object_shape[0],
+ 0,
+ ]
+
+ probe_extent = [
+ 0,
+ self.sampling[1] * self._region_of_interest_shape[1],
+ self.sampling[0] * self._region_of_interest_shape[0],
+ 0,
+ ]
+
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
+
+ ax1.imshow(
+ complex_probe_rgb,
+ extent=probe_extent,
+ )
+
+ divider = make_axes_locatable(ax1)
+ cax1 = divider.append_axes("right", size="5%", pad="2.5%")
+ add_colorbar_arg(cax1, chroma_boost=chroma_boost)
+ ax1.set_ylabel("x [A]")
+ ax1.set_xlabel("y [A]")
+ ax1.set_title("Initial probe intensity")
+
+ ax2.imshow(
+ asnumpy(probe_overlap),
+ extent=extent,
+ cmap="gray",
+ )
+ ax2.scatter(
+ self.positions[:, 1],
+ self.positions[:, 0],
+ s=2.5,
+ color=(1, 0, 0, 1),
+ )
+ ax2.set_ylabel("x [A]")
+ ax2.set_xlabel("y [A]")
+ ax2.set_xlim((extent[0], extent[1]))
+ ax2.set_ylim((extent[2], extent[3]))
+ ax2.set_title("Object field of view")
+
+ fig.tight_layout()
+
+ self._preprocessed = True
+
+ if self._device == "gpu":
+ xp._default_memory_pool.free_all_blocks()
+ xp.clear_memo()
+
+ return self
+
+ def _overlap_projection(self, current_object, current_probe):
+ """
+ Ptychographic overlap projection method.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+
+ Returns
+ --------
+ shifted_probes:np.ndarray
+ fractionally-shifted probes
+ object_patches: np.ndarray
+ Patched object view
+ overlap: np.ndarray
+ shifted_probes * object_patches
+ """
+
+ xp = self._xp
+
+ shifted_probes = fft_shift(current_probe, self._positions_px_fractional, xp)
+
+ if self._object_type == "potential":
+ complex_object = xp.exp(1j * current_object)
+ else:
+ complex_object = current_object
+
+ object_patches = complex_object[
+ self._vectorized_patch_indices_row, self._vectorized_patch_indices_col
+ ]
+
+ overlap = shifted_probes * object_patches
+
+ return shifted_probes, object_patches, overlap
+
+ def _gradient_descent_fourier_projection(self, amplitudes, overlap):
+ """
+ Ptychographic fourier projection method for GD method.
+
+ Parameters
+ --------
+ amplitudes: np.ndarray
+ Normalized measured amplitudes
+ overlap: np.ndarray
+ object * probe overlap
+
+ Returns
+ --------
+ exit_waves:np.ndarray
+ Difference between modified and estimated exit waves
+ error: float
+ Reconstruction error
+ """
+
+ xp = self._xp
+ fourier_overlap = xp.fft.fft2(overlap)
+ error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_overlap)) ** 2)
+
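+        # Amplitude projection: keep the estimated Fourier phases, impose the
+        # measured Fourier amplitudes, and return to real space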
+ fourier_modified_overlap = amplitudes * xp.exp(1j * xp.angle(fourier_overlap))
+ modified_overlap = xp.fft.ifft2(fourier_modified_overlap)
+
+ exit_waves = modified_overlap - overlap
+
+ return exit_waves, error
+
+ def _projection_sets_fourier_projection(
+ self, amplitudes, overlap, exit_waves, projection_a, projection_b, projection_c
+ ):
+ """
+ Ptychographic fourier projection method for DM_AP and RAAR methods.
+ Generalized projection using three parameters: a,b,c
+
+ DM_AP(\\alpha) : a = -\\alpha, b = 1, c = 1 + \\alpha
+ DM: DM_AP(1.0), AP: DM_AP(0.0)
+
+ RAAR(\\beta) : a = 1-2\\beta, b = \\beta, c = 2
+ DM : RAAR(1.0)
+
+ RRR(\\gamma) : a = -\\gamma, b = \\gamma, c = 2
+ DM: RRR(1.0)
+
+ SUPERFLIP : a = 0, b = 1, c = 2
+
+ Parameters
+ --------
+ amplitudes: np.ndarray
+ Normalized measured amplitudes
+ overlap: np.ndarray
+ object * probe overlap
+ exit_waves: np.ndarray
+ previously estimated exit waves
+ projection_a: float
+ projection_b: float
+ projection_c: float
+
+ Returns
+ --------
+ exit_waves:np.ndarray
+ Updated exit_waves
+ error: float
+ Reconstruction error
+ """
+
+ xp = self._xp
+ projection_x = 1 - projection_a - projection_b
+ projection_y = 1 - projection_c
+
+ if exit_waves is None:
+ exit_waves = overlap.copy()
+
+ fourier_overlap = xp.fft.fft2(overlap)
+ error = xp.sum(xp.abs(amplitudes - xp.abs(fourier_overlap)) ** 2)
+
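+        # Generalized projections update:
+        #   exit_waves <- x*exit_waves + a*overlap + b*P_F[c*overlap + y*exit_waves],
+        # where P_F imposes the measured Fourier amplitudes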
+ factor_to_be_projected = projection_c * overlap + projection_y * exit_waves
+ fourier_projected_factor = xp.fft.fft2(factor_to_be_projected)
+
+ fourier_projected_factor = amplitudes * xp.exp(
+ 1j * xp.angle(fourier_projected_factor)
+ )
+ projected_factor = xp.fft.ifft2(fourier_projected_factor)
+
+ exit_waves = (
+ projection_x * exit_waves
+ + projection_a * overlap
+ + projection_b * projected_factor
+ )
+
+ return exit_waves, error
+
+ def _forward(
+ self,
+ current_object,
+ current_probe,
+ amplitudes,
+ exit_waves,
+ use_projection_scheme,
+ projection_a,
+ projection_b,
+ projection_c,
+ ):
+ """
+ Ptychographic forward operator.
+ Calls _overlap_projection() and the appropriate _fourier_projection().
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+ amplitudes: np.ndarray
+ Normalized measured amplitudes
+ exit_waves: np.ndarray
+ previously estimated exit waves
+        use_projection_scheme: bool
+ If True, use generalized projection update
+ projection_a: float
+ projection_b: float
+ projection_c: float
+
+ Returns
+ --------
+        shifted_probes: np.ndarray
+ fractionally-shifted probes
+ object_patches: np.ndarray
+ Patched object view
+ overlap: np.ndarray
+ object * probe overlap
+        exit_waves: np.ndarray
+ Updated exit_waves
+ error: float
+ Reconstruction error
+ """
+
+ shifted_probes, object_patches, overlap = self._overlap_projection(
+ current_object, current_probe
+ )
+ if use_projection_scheme:
+ exit_waves, error = self._projection_sets_fourier_projection(
+ amplitudes,
+ overlap,
+ exit_waves,
+ projection_a,
+ projection_b,
+ projection_c,
+ )
+
+ else:
+ exit_waves, error = self._gradient_descent_fourier_projection(
+ amplitudes, overlap
+ )
+
+ return shifted_probes, object_patches, overlap, exit_waves, error
+
+ def _gradient_descent_adjoint(
+ self,
+ current_object,
+ current_probe,
+ object_patches,
+ shifted_probes,
+ exit_waves,
+ step_size,
+ normalization_min,
+ fix_probe,
+ ):
+ """
+ Ptychographic adjoint operator for GD method.
+ Computes object and probe update steps.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+ object_patches: np.ndarray
+ Patched object view
+        shifted_probes: np.ndarray
+            fractionally-shifted probes
+        exit_waves: np.ndarray
+            Updated exit_waves
+ step_size: float, optional
+ Update step size
+ normalization_min: float, optional
+ Probe normalization minimum as a fraction of the maximum overlap intensity
+ fix_probe: bool, optional
+ If True, probe will not be updated
+
+ Returns
+ --------
+ updated_object: np.ndarray
+ Updated object estimate
+ updated_probe: np.ndarray
+ Updated probe estimate
+ """
+ xp = self._xp
+
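+        # the normalization blends the local summed probe intensity with its
+        # global maximum, weighted by normalization_min, regularizing the
+        # object update in weakly-illuminated regions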
+ probe_normalization = self._sum_overlapping_patches_bincounts(
+ xp.abs(shifted_probes) ** 2
+ )
+ probe_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * probe_normalization) ** 2
+ + (normalization_min * xp.max(probe_normalization)) ** 2
+ )
+
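+        # for a "potential" object O = exp(1j*V), differentiating through the
+        # exponential contributes a factor of -1j*conj(O); only the real part
+        # updates the real-valued potential V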
+ if self._object_type == "potential":
+ current_object += step_size * (
+ self._sum_overlapping_patches_bincounts(
+ xp.real(
+ -1j
+ * xp.conj(object_patches)
+ * xp.conj(shifted_probes)
+ * exit_waves
+ )
+ )
+ * probe_normalization
+ )
+ elif self._object_type == "complex":
+ current_object += step_size * (
+ self._sum_overlapping_patches_bincounts(
+ xp.conj(shifted_probes) * exit_waves
+ )
+ * probe_normalization
+ )
+
+ if not fix_probe:
+ object_normalization = xp.sum(
+ (xp.abs(object_patches) ** 2),
+ axis=0,
+ )
+ object_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * object_normalization) ** 2
+ + (normalization_min * xp.max(object_normalization)) ** 2
+ )
+
+ current_probe += step_size * (
+ xp.sum(
+ xp.conj(object_patches) * exit_waves,
+ axis=0,
+ )
+ * object_normalization
+ )
+
+ return current_object, current_probe
+
+ def _projection_sets_adjoint(
+ self,
+ current_object,
+ current_probe,
+ object_patches,
+ shifted_probes,
+ exit_waves,
+ normalization_min,
+ fix_probe,
+ ):
+ """
+ Ptychographic adjoint operator for DM_AP and RAAR methods.
+ Computes object and probe update steps.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+ object_patches: np.ndarray
+ Patched object view
+        shifted_probes: np.ndarray
+            fractionally-shifted probes
+        exit_waves: np.ndarray
+            Updated exit_waves
+ normalization_min: float, optional
+ Probe normalization minimum as a fraction of the maximum overlap intensity
+ fix_probe: bool, optional
+ If True, probe will not be updated
+
+ Returns
+ --------
+ updated_object: np.ndarray
+ Updated object estimate
+ updated_probe: np.ndarray
+ Updated probe estimate
+ """
+ xp = self._xp
+
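+        # unlike the gradient-descent adjoint, the projection-set adjoint
+        # replaces the object and probe estimates outright rather than
+        # taking an incremental step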
+ probe_normalization = self._sum_overlapping_patches_bincounts(
+ xp.abs(shifted_probes) ** 2
+ )
+ probe_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * probe_normalization) ** 2
+ + (normalization_min * xp.max(probe_normalization)) ** 2
+ )
+
+ if self._object_type == "potential":
+ current_object = (
+ self._sum_overlapping_patches_bincounts(
+ xp.real(
+ -1j
+ * xp.conj(object_patches)
+ * xp.conj(shifted_probes)
+ * exit_waves
+ )
+ )
+ * probe_normalization
+ )
+ elif self._object_type == "complex":
+ current_object = (
+ self._sum_overlapping_patches_bincounts(
+ xp.conj(shifted_probes) * exit_waves
+ )
+ * probe_normalization
+ )
+
+ if not fix_probe:
+ object_normalization = xp.sum(
+ (xp.abs(object_patches) ** 2),
+ axis=0,
+ )
+ object_normalization = 1 / xp.sqrt(
+ 1e-16
+ + ((1 - normalization_min) * object_normalization) ** 2
+ + (normalization_min * xp.max(object_normalization)) ** 2
+ )
+
+ current_probe = (
+ xp.sum(
+ xp.conj(object_patches) * exit_waves,
+ axis=0,
+ )
+ * object_normalization
+ )
+
+ return current_object, current_probe
+
+ def _adjoint(
+ self,
+ current_object,
+ current_probe,
+ object_patches,
+ shifted_probes,
+ exit_waves,
+ use_projection_scheme: bool,
+ step_size: float,
+ normalization_min: float,
+ fix_probe: bool,
+ ):
+ """
+ Ptychographic adjoint operator.
+ Computes object and probe update steps.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+ object_patches: np.ndarray
+ Patched object view
+        shifted_probes: np.ndarray
+            fractionally-shifted probes
+        exit_waves: np.ndarray
+            Updated exit_waves
+        use_projection_scheme: bool
+            If True, use generalized projection update
+ step_size: float, optional
+ Update step size
+ normalization_min: float, optional
+ Probe normalization minimum as a fraction of the maximum overlap intensity
+ fix_probe: bool, optional
+ If True, probe will not be updated
+
+ Returns
+ --------
+ updated_object: np.ndarray
+ Updated object estimate
+ updated_probe: np.ndarray
+ Updated probe estimate
+ """
+
+ if use_projection_scheme:
+ current_object, current_probe = self._projection_sets_adjoint(
+ current_object,
+ current_probe,
+ object_patches,
+ shifted_probes,
+ exit_waves,
+ normalization_min,
+ fix_probe,
+ )
+ else:
+ current_object, current_probe = self._gradient_descent_adjoint(
+ current_object,
+ current_probe,
+ object_patches,
+ shifted_probes,
+ exit_waves,
+ step_size,
+ normalization_min,
+ fix_probe,
+ )
+
+ return current_object, current_probe
+
+ def _constraints(
+ self,
+ current_object,
+ current_probe,
+ current_positions,
+ pure_phase_object,
+ fix_com,
+ fit_probe_aberrations,
+ fit_probe_aberrations_max_angular_order,
+ fit_probe_aberrations_max_radial_order,
+ constrain_probe_amplitude,
+ constrain_probe_amplitude_relative_radius,
+ constrain_probe_amplitude_relative_width,
+ constrain_probe_fourier_amplitude,
+ constrain_probe_fourier_amplitude_max_width_pixels,
+ constrain_probe_fourier_amplitude_constant_intensity,
+ fix_probe_aperture,
+ initial_probe_aperture,
+ fix_positions,
+ global_affine_transformation,
+ gaussian_filter,
+ gaussian_filter_sigma,
+ butterworth_filter,
+ q_lowpass,
+ q_highpass,
+ butterworth_order,
+ tv_denoise,
+ tv_denoise_weight,
+ tv_denoise_inner_iter,
+ object_positivity,
+ shrinkage_rad,
+ object_mask,
+ ):
+ """
+ Ptychographic constraints operator.
+
+ Parameters
+ --------
+ current_object: np.ndarray
+ Current object estimate
+ current_probe: np.ndarray
+ Current probe estimate
+ current_positions: np.ndarray
+ Current positions estimate
+ pure_phase_object: bool
+ If True, object amplitude is set to unity
+ fix_com: bool
+ If True, probe CoM is fixed to the center
+ fit_probe_aberrations: bool
+ If True, fits the probe aberrations to a low-order expansion
+ fit_probe_aberrations_max_angular_order: bool
+ Max angular order of probe aberrations basis functions
+ fit_probe_aberrations_max_radial_order: bool
+ Max radial order of probe aberrations basis functions
+ constrain_probe_amplitude: bool
+ If True, probe amplitude is constrained by top hat function
+ constrain_probe_amplitude_relative_radius: float
+ Relative location of top-hat inflection point, between 0 and 0.5
+ constrain_probe_amplitude_relative_width: float
+ Relative width of top-hat sigmoid, between 0 and 0.5
+ constrain_probe_fourier_amplitude: bool
+ If True, probe aperture is constrained by fitting a sigmoid for each angular frequency.
+ constrain_probe_fourier_amplitude_max_width_pixels: float
+ Maximum pixel width of fitted sigmoid functions.
+ constrain_probe_fourier_amplitude_constant_intensity: bool
+ If True, the probe aperture is additionally constrained to a constant intensity.
+        fix_probe_aperture: bool
+            If True, probe Fourier amplitude is replaced by initial probe aperture.
+        initial_probe_aperture: np.ndarray
+            Initial probe aperture to use in replacing probe Fourier amplitude.
+ fix_positions: bool
+ If True, positions are not updated
+        global_affine_transformation: bool
+            If True, positions are assumed to be a global affine transform from initial scan
+ gaussian_filter: bool
+ If True, applies real-space gaussian filter
+ gaussian_filter_sigma: float
+ Standard deviation of gaussian kernel in A
+ butterworth_filter: bool
+            If True, applies high-pass butterworth filter
+ q_lowpass: float
+ Cut-off frequency in A^-1 for low-pass butterworth filter
+ q_highpass: float
+ Cut-off frequency in A^-1 for high-pass butterworth filter
+ butterworth_order: float
+ Butterworth filter order. Smaller gives a smoother filter
+ tv_denoise: bool
+ If True, applies TV denoising on object
+ tv_denoise_weight: float
+ Denoising weight. The greater `weight`, the more denoising.
+ tv_denoise_inner_iter: float
+ Number of iterations to run in inner loop of TV denoising
+ object_positivity: bool
+ If True, clips negative potential values
+ shrinkage_rad: float
+ Phase shift in radians to be subtracted from the potential at each iteration
+ object_mask: np.ndarray (boolean)
+ If not None, used to calculate additional shrinkage using masked-mean of object
+
+ Returns
+ --------
+ constrained_object: np.ndarray
+ Constrained object estimate
+ constrained_probe: np.ndarray
+ Constrained probe estimate
+ constrained_positions: np.ndarray
+ Constrained positions estimate
+ """
+
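+        # constraints are applied in sequence: object-space filters first,
+        # then probe constraints, then (optionally) position constraints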
+ if gaussian_filter:
+ current_object = self._object_gaussian_constraint(
+ current_object, gaussian_filter_sigma, pure_phase_object
+ )
+
+ if butterworth_filter:
+ current_object = self._object_butterworth_constraint(
+ current_object,
+ q_lowpass,
+ q_highpass,
+ butterworth_order,
+ )
+
+ if tv_denoise:
+ current_object = self._object_denoise_tv_pylops(
+ current_object, tv_denoise_weight, tv_denoise_inner_iter
+ )
+
+ if shrinkage_rad > 0.0 or object_mask is not None:
+ current_object = self._object_shrinkage_constraint(
+ current_object,
+ shrinkage_rad,
+ object_mask,
+ )
+
+ if self._object_type == "complex":
+ current_object = self._object_threshold_constraint(
+ current_object, pure_phase_object
+ )
+ elif object_positivity:
+ current_object = self._object_positivity_constraint(current_object)
+
+ if fix_com:
+ current_probe = self._probe_center_of_mass_constraint(current_probe)
+
+ if fix_probe_aperture:
+ current_probe = self._probe_aperture_constraint(
+ current_probe,
+ initial_probe_aperture,
+ )
+ elif constrain_probe_fourier_amplitude:
+ current_probe = self._probe_fourier_amplitude_constraint(
+ current_probe,
+ constrain_probe_fourier_amplitude_max_width_pixels,
+ constrain_probe_fourier_amplitude_constant_intensity,
+ )
+
+ if fit_probe_aberrations:
+ current_probe = self._probe_aberration_fitting_constraint(
+ current_probe,
+ fit_probe_aberrations_max_angular_order,
+ fit_probe_aberrations_max_radial_order,
+ )
+
+ if constrain_probe_amplitude:
+ current_probe = self._probe_amplitude_constraint(
+ current_probe,
+ constrain_probe_amplitude_relative_radius,
+ constrain_probe_amplitude_relative_width,
+ )
+
+ if not fix_positions:
+ current_positions = self._positions_center_of_mass_constraint(
+ current_positions
+ )
+
+ if global_affine_transformation:
+ current_positions = self._positions_affine_transformation_constraint(
+ self._positions_px_initial, current_positions
+ )
+
+ return current_object, current_probe, current_positions
+
+ def reconstruct(
+ self,
+ max_iter: int = 64,
+ reconstruction_method: str = "gradient-descent",
+ reconstruction_parameter: float = 1.0,
+ reconstruction_parameter_a: float = None,
+ reconstruction_parameter_b: float = None,
+ reconstruction_parameter_c: float = None,
+ max_batch_size: int = None,
+ seed_random: int = None,
+ step_size: float = 0.5,
+ normalization_min: float = 1,
+ positions_step_size: float = 0.9,
+ pure_phase_object_iter: int = 0,
+ fix_com: bool = True,
+ fix_probe_iter: int = 0,
+ fix_probe_aperture_iter: int = 0,
+ constrain_probe_amplitude_iter: int = 0,
+ constrain_probe_amplitude_relative_radius: float = 0.5,
+ constrain_probe_amplitude_relative_width: float = 0.05,
+ constrain_probe_fourier_amplitude_iter: int = 0,
+ constrain_probe_fourier_amplitude_max_width_pixels: float = 3.0,
+ constrain_probe_fourier_amplitude_constant_intensity: bool = False,
+ fix_positions_iter: int = np.inf,
+ constrain_position_distance: float = None,
+ global_affine_transformation: bool = True,
+ gaussian_filter_sigma: float = None,
+ gaussian_filter_iter: int = np.inf,
+ fit_probe_aberrations_iter: int = 0,
+ fit_probe_aberrations_max_angular_order: int = 4,
+ fit_probe_aberrations_max_radial_order: int = 4,
+ butterworth_filter_iter: int = np.inf,
+ q_lowpass: float = None,
+ q_highpass: float = None,
+ butterworth_order: float = 2,
+ tv_denoise_iter: int = np.inf,
+ tv_denoise_weight: float = None,
+ tv_denoise_inner_iter: float = 40,
+ object_positivity: bool = True,
+ shrinkage_rad: float = 0.0,
+ fix_potential_baseline: bool = True,
+ switch_object_iter: int = np.inf,
+ store_iterations: bool = False,
+ progress_bar: bool = True,
+ reset: bool = None,
+ ):
+ """
+ Ptychographic reconstruction main method.
+
+ Parameters
+ --------
+ max_iter: int, optional
+ Maximum number of iterations to run
+ reconstruction_method: str, optional
+ Specifies which reconstruction algorithm to use, one of:
+ "generalized-projections",
+ "DM_AP" (or "difference-map_alternating-projections"),
+ "RAAR" (or "relaxed-averaged-alternating-reflections"),
+ "RRR" (or "relax-reflect-reflect"),
+ "SUPERFLIP" (or "charge-flipping"), or
+ "GD" (or "gradient_descent")
+ reconstruction_parameter: float, optional
+ Reconstruction parameter for various reconstruction methods above.
+ reconstruction_parameter_a: float, optional
+ Reconstruction parameter a for reconstruction_method='generalized-projections'.
+ reconstruction_parameter_b: float, optional
+ Reconstruction parameter b for reconstruction_method='generalized-projections'.
+ reconstruction_parameter_c: float, optional
+ Reconstruction parameter c for reconstruction_method='generalized-projections'.
+ max_batch_size: int, optional
+ Max number of probes to update at once
+ seed_random: int, optional
+ Seeds the random number generator, only applicable when max_batch_size is not None
+ step_size: float, optional
+ Update step size
+ normalization_min: float, optional
+ Probe normalization minimum as a fraction of the maximum overlap intensity
+ positions_step_size: float, optional
+ Positions update step size
+ pure_phase_object_iter: int, optional
+ Number of iterations where object amplitude is set to unity
+ fix_com: bool, optional
+ If True, fixes center of mass of probe
+ fix_probe_iter: int, optional
+ Number of iterations to run with a fixed probe before updating probe estimate
+ fix_probe_aperture_iter: int, optional
+ Number of iterations to run with a fixed probe Fourier amplitude before updating probe estimate
+ constrain_probe_amplitude_iter: int, optional
+ Number of iterations to run while constraining the real-space probe with a top-hat support.
+ constrain_probe_amplitude_relative_radius: float
+ Relative location of top-hat inflection point, between 0 and 0.5
+ constrain_probe_amplitude_relative_width: float
+ Relative width of top-hat sigmoid, between 0 and 0.5
+ constrain_probe_fourier_amplitude_iter: int, optional
+ Number of iterations to run while constraining the Fourier-space probe by fitting a sigmoid for each angular frequency.
+ constrain_probe_fourier_amplitude_max_width_pixels: float
+ Maximum pixel width of fitted sigmoid functions.
+ constrain_probe_fourier_amplitude_constant_intensity: bool
+ If True, the probe aperture is additionally constrained to a constant intensity.
+ fix_positions_iter: int, optional
+ Number of iterations to run with fixed positions before updating positions estimate
+ constrain_position_distance: float, optional
+ Distance to constrain position correction within original
+ field of view in A
+ global_affine_transformation: bool, optional
+ If True, positions are assumed to be a global affine transform from initial scan
+ gaussian_filter_sigma: float, optional
+ Standard deviation of gaussian kernel in A
+ gaussian_filter_iter: int, optional
+ Number of iterations to run using object smoothness constraint
+ fit_probe_aberrations_iter: int, optional
+ Number of iterations to run while fitting the probe aberrations to a low-order expansion
+ fit_probe_aberrations_max_angular_order: bool
+ Max angular order of probe aberrations basis functions
+ fit_probe_aberrations_max_radial_order: bool
+ Max radial order of probe aberrations basis functions
+ butterworth_filter_iter: int, optional
+            Number of iterations to run using high-pass butterworth filter
+ q_lowpass: float
+ Cut-off frequency in A^-1 for low-pass butterworth filter
+ q_highpass: float
+ Cut-off frequency in A^-1 for high-pass butterworth filter
+ butterworth_order: float
+ Butterworth filter order. Smaller gives a smoother filter
+ tv_denoise_iter: int, optional
+ Number of iterations to run using tv denoise filter on object
+ tv_denoise_weight: float
+ Denoising weight. The greater `weight`, the more denoising.
+ tv_denoise_inner_iter: float
+ Number of iterations to run in inner loop of TV denoising
+ object_positivity: bool, optional
+ If True, forces object to be positive
+ shrinkage_rad: float
+ Phase shift in radians to be subtracted from the potential at each iteration
+ fix_potential_baseline: bool
+ If true, the potential mean outside the FOV is forced to zero at each iteration
+ switch_object_iter: int, optional
+ Iteration to switch object type between 'complex' and 'potential' or between
+ 'potential' and 'complex'
+ store_iterations: bool, optional
+ If True, reconstructed objects and probes are stored at each iteration
+ progress_bar: bool, optional
+ If True, reconstruction progress is displayed
+ reset: bool, optional
+ If True, previous reconstructions are ignored
+
+ Returns
+ --------
+ self: PtychographicReconstruction
+ Self to accommodate chaining
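+
+        Examples
+        --------
+        A minimal sketch, assuming `ptycho` is an already-preprocessed
+        reconstruction object (argument values are illustrative):
+
+        >>> ptycho = ptycho.reconstruct(
+        ...     max_iter=64,
+        ...     reconstruction_method="gradient-descent",
+        ...     step_size=0.5,
+        ...     store_iterations=True,
+        ... )
+        >>> ptycho.visualize()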
+ """
+ asnumpy = self._asnumpy
+ xp = self._xp
+
+ # Reconstruction method
+
+ if reconstruction_method == "generalized-projections":
+ if (
+ reconstruction_parameter_a is None
+ or reconstruction_parameter_b is None
+ or reconstruction_parameter_c is None
+ ):
+ raise ValueError(
+ (
+ "reconstruction_parameter_a/b/c must all be specified "
+ "when using reconstruction_method='generalized-projections'."
+ )
+ )
+
+ use_projection_scheme = True
+ projection_a = reconstruction_parameter_a
+ projection_b = reconstruction_parameter_b
+ projection_c = reconstruction_parameter_c
+            reconstruction_parameter = [projection_a, projection_b, projection_c]
+ step_size = None
+ elif (
+ reconstruction_method == "DM_AP"
+ or reconstruction_method == "difference-map_alternating-projections"
+ ):
+ if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0:
+ raise ValueError("reconstruction_parameter must be between 0-1.")
+
+ use_projection_scheme = True
+ projection_a = -reconstruction_parameter
+ projection_b = 1
+ projection_c = 1 + reconstruction_parameter
+ step_size = None
+ elif (
+ reconstruction_method == "RAAR"
+ or reconstruction_method == "relaxed-averaged-alternating-reflections"
+ ):
+ if reconstruction_parameter < 0.0 or reconstruction_parameter > 1.0:
+ raise ValueError("reconstruction_parameter must be between 0-1.")
+
+ use_projection_scheme = True
+ projection_a = 1 - 2 * reconstruction_parameter
+ projection_b = reconstruction_parameter
+ projection_c = 2
+ step_size = None
+ elif (
+ reconstruction_method == "RRR"
+ or reconstruction_method == "relax-reflect-reflect"
+ ):
+ if reconstruction_parameter < 0.0 or reconstruction_parameter > 2.0:
+ raise ValueError("reconstruction_parameter must be between 0-2.")
+
+ use_projection_scheme = True
+ projection_a = -reconstruction_parameter
+ projection_b = reconstruction_parameter
+ projection_c = 2
+ step_size = None
+ elif (
+ reconstruction_method == "SUPERFLIP"
+ or reconstruction_method == "charge-flipping"
+ ):
+ use_projection_scheme = True
+ projection_a = 0
+ projection_b = 1
+ projection_c = 2
+ reconstruction_parameter = None
+ step_size = None
+ elif (
+ reconstruction_method == "GD" or reconstruction_method == "gradient-descent"
+ ):
+ use_projection_scheme = False
+ projection_a = None
+ projection_b = None
+ projection_c = None
+ reconstruction_parameter = None
+ else:
+ raise ValueError(
+ (
+ "reconstruction_method must be one of 'generalized-projections', "
+ "'DM_AP' (or 'difference-map_alternating-projections'), "
+ "'RAAR' (or 'relaxed-averaged-alternating-reflections'), "
+ "'RRR' (or 'relax-reflect-reflect'), "
+ "'SUPERFLIP' (or 'charge-flipping'), "
+ f"or 'GD' (or 'gradient-descent'), not {reconstruction_method}."
+ )
+ )
+
+        # batch-wise (stochastic) updates are only consistent with gradient descent,
+        # so validate this combination regardless of verbosity
+        if max_batch_size is not None and use_projection_scheme:
+            raise ValueError(
+                (
+                    "Stochastic object/probe updating is inconsistent with 'DM_AP', 'RAAR', 'RRR', and 'SUPERFLIP'. "
+                    "Use reconstruction_method='GD' or set max_batch_size=None."
+                )
+            )
+
+        if self._verbose:
+            if switch_object_iter > max_iter:
+                first_line = f"Performing {max_iter} iterations using a {self._object_type} object type, "
+            else:
+                switch_object_type = (
+                    "complex" if self._object_type == "potential" else "potential"
+                )
+                first_line = (
+                    f"Performing {switch_object_iter} iterations using a {self._object_type} object type and "
+                    f"{max_iter - switch_object_iter} iterations using a {switch_object_type} object type, "
+                )
+            if max_batch_size is not None:
+                print(
+                    (
+                        first_line + f"with the {reconstruction_method} algorithm, "
+                        f"with normalization_min: {normalization_min} and step_size: {step_size}, "
+                        f"in batches of max {max_batch_size} measurements."
+                    )
+                )
+            elif reconstruction_parameter is not None:
+                if np.array(reconstruction_parameter).shape == (3,):
+                    print(
+                        (
+                            first_line
+                            + f"with the {reconstruction_method} algorithm, "
+                            f"with normalization_min: {normalization_min} and (a,b,c): {reconstruction_parameter}."
+                        )
+                    )
+                else:
+                    print(
+                        (
+                            first_line
+                            + f"with the {reconstruction_method} algorithm, "
+                            f"with normalization_min: {normalization_min} and α: {reconstruction_parameter}."
+                        )
+                    )
+            elif step_size is not None:
+                print(
+                    (
+                        first_line
+                        + f"with the {reconstruction_method} algorithm, "
+                        f"with normalization_min: {normalization_min} and step_size: {step_size}."
+                    )
+                )
+            else:
+                print(
+                    (
+                        first_line
+                        + f"with the {reconstruction_method} algorithm, "
+                        f"with normalization_min: {normalization_min}."
+                    )
+                )
+
+ # Batching
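+        # shuffled_indices is re-permuted each iteration for stochastic updates;
+        # unshuffled_indices stores the inverse permutation so the positions can
+        # be restored to their original ordering after each iteration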
+ shuffled_indices = np.arange(self._num_diffraction_patterns)
+ unshuffled_indices = np.zeros_like(shuffled_indices)
+
+        if max_batch_size is not None:
+            # the shuffle below uses numpy's RNG, so seed numpy rather than xp
+            np.random.seed(seed_random)
+        else:
+            max_batch_size = self._num_diffraction_patterns
+
+ # initialization
+ if store_iterations and (not hasattr(self, "object_iterations") or reset):
+ self.object_iterations = []
+ self.probe_iterations = []
+
+ if reset:
+ self.error_iterations = []
+ self._object = self._object_initial.copy()
+ self._probe = self._probe_initial.copy()
+ self._positions_px = self._positions_px_initial.copy()
+ self._positions_px_fractional = self._positions_px - xp.round(
+ self._positions_px
+ )
+ (
+ self._vectorized_patch_indices_row,
+ self._vectorized_patch_indices_col,
+ ) = self._extract_vectorized_patch_indices()
+ self._exit_waves = None
+ self._object_type = self._object_type_initial
+ if hasattr(self, "_tf"):
+ del self._tf
+ elif reset is None:
+ if hasattr(self, "error"):
+ warnings.warn(
+ (
+ "Continuing reconstruction from previous result. "
+ "Use reset=True for a fresh start."
+ ),
+ UserWarning,
+ )
+ else:
+ self.error_iterations = []
+ self._exit_waves = None
+
+ # main loop
+ for a0 in tqdmnd(
+ max_iter,
+ desc="Reconstructing object and probe",
+ unit=" iter",
+ disable=not progress_bar,
+ ):
+ error = 0.0
+
+ if a0 == switch_object_iter:
+ if self._object_type == "potential":
+ self._object_type = "complex"
+ self._object = xp.exp(1j * self._object)
+ elif self._object_type == "complex":
+ self._object_type = "potential"
+ self._object = xp.angle(self._object)
+
+ # randomize
+ if not use_projection_scheme:
+ np.random.shuffle(shuffled_indices)
+ unshuffled_indices[shuffled_indices] = np.arange(
+ self._num_diffraction_patterns
+ )
+ positions_px = self._positions_px.copy()[shuffled_indices]
+
+ for start, end in generate_batches(
+ self._num_diffraction_patterns, max_batch=max_batch_size
+ ):
+ # batch indices
+ self._positions_px = positions_px[start:end]
+ self._positions_px_fractional = self._positions_px - xp.round(
+ self._positions_px
+ )
+ (
+ self._vectorized_patch_indices_row,
+ self._vectorized_patch_indices_col,
+ ) = self._extract_vectorized_patch_indices()
+ amplitudes = self._amplitudes[shuffled_indices[start:end]]
+
+ # forward operator
+ (
+ shifted_probes,
+ object_patches,
+ overlap,
+ self._exit_waves,
+ batch_error,
+ ) = self._forward(
+ self._object,
+ self._probe,
+ amplitudes,
+ self._exit_waves,
+ use_projection_scheme,
+ projection_a,
+ projection_b,
+ projection_c,
+ )
+
+ # adjoint operator
+ self._object, self._probe = self._adjoint(
+ self._object,
+ self._probe,
+ object_patches,
+ shifted_probes,
+ self._exit_waves,
+ use_projection_scheme=use_projection_scheme,
+ step_size=step_size,
+ normalization_min=normalization_min,
+ fix_probe=a0 < fix_probe_iter,
+ )
+
+ # position correction
+ if a0 >= fix_positions_iter:
+ positions_px[start:end] = self._position_correction(
+ self._object,
+ shifted_probes,
+ overlap,
+ amplitudes,
+ self._positions_px,
+ positions_step_size,
+ constrain_position_distance,
+ )
+
+ error += batch_error
+
+ # Normalize Error
+ error /= self._mean_diffraction_intensity * self._num_diffraction_patterns
+
+ # constraints
+ self._positions_px = positions_px.copy()[unshuffled_indices]
+ self._object, self._probe, self._positions_px = self._constraints(
+ self._object,
+ self._probe,
+ self._positions_px,
+ fix_com=fix_com and a0 >= fix_probe_iter,
+ constrain_probe_amplitude=a0 < constrain_probe_amplitude_iter
+ and a0 >= fix_probe_iter,
+ constrain_probe_amplitude_relative_radius=constrain_probe_amplitude_relative_radius,
+ constrain_probe_amplitude_relative_width=constrain_probe_amplitude_relative_width,
+ constrain_probe_fourier_amplitude=a0
+ < constrain_probe_fourier_amplitude_iter
+ and a0 >= fix_probe_iter,
+ constrain_probe_fourier_amplitude_max_width_pixels=constrain_probe_fourier_amplitude_max_width_pixels,
+ constrain_probe_fourier_amplitude_constant_intensity=constrain_probe_fourier_amplitude_constant_intensity,
+ fit_probe_aberrations=a0 < fit_probe_aberrations_iter
+ and a0 >= fix_probe_iter,
+ fit_probe_aberrations_max_angular_order=fit_probe_aberrations_max_angular_order,
+ fit_probe_aberrations_max_radial_order=fit_probe_aberrations_max_radial_order,
+ fix_probe_aperture=a0 < fix_probe_aperture_iter,
+ initial_probe_aperture=self._probe_initial_aperture,
+ fix_positions=a0 < fix_positions_iter,
+ global_affine_transformation=global_affine_transformation,
+ gaussian_filter=a0 < gaussian_filter_iter
+ and gaussian_filter_sigma is not None,
+ gaussian_filter_sigma=gaussian_filter_sigma,
+ butterworth_filter=a0 < butterworth_filter_iter
+ and (q_lowpass is not None or q_highpass is not None),
+ q_lowpass=q_lowpass,
+ q_highpass=q_highpass,
+ butterworth_order=butterworth_order,
+ tv_denoise=a0 < tv_denoise_iter and tv_denoise_weight is not None,
+ tv_denoise_weight=tv_denoise_weight,
+ tv_denoise_inner_iter=tv_denoise_inner_iter,
+ object_positivity=object_positivity,
+ shrinkage_rad=shrinkage_rad,
+ object_mask=self._object_fov_mask_inverse
+ if fix_potential_baseline and self._object_fov_mask_inverse.sum() > 0
+ else None,
+ pure_phase_object=a0 < pure_phase_object_iter
+ and self._object_type == "complex",
+ )
+
+ self.error_iterations.append(error.item())
+ if store_iterations:
+ self.object_iterations.append(asnumpy(self._object.copy()))
+ self.probe_iterations.append(self.probe_centered)
+
+ # store result
+ self.object = asnumpy(self._object)
+ self.probe = self.probe_centered
+ self.error = error.item()
+
+ if self._device == "gpu":
+ xp._default_memory_pool.free_all_blocks()
+ xp.clear_memo()
+
+ return self
+
+ def _visualize_last_iteration_figax(
+ self,
+ fig,
+ object_ax,
+ convergence_ax: None,
+ cbar: bool,
+ padding: int = 0,
+ **kwargs,
+ ):
+ """
+ Displays last reconstructed object on a given fig/ax.
+
+ Parameters
+ --------
+ fig: Figure
+            Matplotlib figure that object_ax lives in
+ object_ax: Axes
+ Matplotlib axes to plot reconstructed object in
+ convergence_ax: Axes, optional
+ Matplotlib axes to plot convergence plot in
+ cbar: bool, optional
+ If true, displays a colorbar
+ padding : int, optional
+            Pixels to pad the object by after rotation and cropping
+ """
+ cmap = kwargs.pop("cmap", "magma")
+
+ if self._object_type == "complex":
+ obj = np.angle(self.object)
+ else:
+ obj = self.object
+
+ rotated_object = self._crop_rotate_object_fov(obj, padding=padding)
+ rotated_shape = rotated_object.shape
+
+ extent = [
+ 0,
+ self.sampling[1] * rotated_shape[1],
+ self.sampling[0] * rotated_shape[0],
+ 0,
+ ]
+
+ im = object_ax.imshow(
+ rotated_object,
+ extent=extent,
+ cmap=cmap,
+ **kwargs,
+ )
+
+ if cbar:
+ divider = make_axes_locatable(object_ax)
+ ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
+ fig.add_axes(ax_cb)
+ fig.colorbar(im, cax=ax_cb)
+
+ if convergence_ax is not None and hasattr(self, "error_iterations"):
+ kwargs.pop("vmin", None)
+ kwargs.pop("vmax", None)
+ errors = self.error_iterations
+ convergence_ax.semilogy(np.arange(len(errors)), errors, **kwargs)
+
+ def _visualize_last_iteration(
+ self,
+ fig,
+ cbar: bool,
+ plot_convergence: bool,
+ plot_probe: bool,
+ plot_fourier_probe: bool,
+ padding: int,
+ **kwargs,
+ ):
+ """
+ Displays last reconstructed object and probe iterations.
+
+ Parameters
+ --------
+ fig: Figure
+ Matplotlib figure to place Gridspec in
+ plot_convergence: bool, optional
+ If true, the normalized mean squared error (NMSE) plot is displayed
+ cbar: bool, optional
+ If true, displays a colorbar
+ plot_probe: bool, optional
+ If true, the reconstructed complex probe is displayed
+ plot_fourier_probe: bool, optional
+ If true, the reconstructed complex Fourier probe is displayed
+ padding : int, optional
+            Pixels to pad the object by after rotation and cropping
+ """
+ figsize = kwargs.pop("figsize", (8, 5))
+ cmap = kwargs.pop("cmap", "magma")
+
+ if plot_fourier_probe:
+ chroma_boost = kwargs.pop("chroma_boost", 2)
+ else:
+ chroma_boost = kwargs.pop("chroma_boost", 1)
+
+ if self._object_type == "complex":
+ obj = np.angle(self.object)
+ else:
+ obj = self.object
+
+ rotated_object = self._crop_rotate_object_fov(obj, padding=padding)
+ rotated_shape = rotated_object.shape
+
+ extent = [
+ 0,
+ self.sampling[1] * rotated_shape[1],
+ self.sampling[0] * rotated_shape[0],
+ 0,
+ ]
+
+ if plot_fourier_probe:
+ probe_extent = [
+ -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2,
+ self.angular_sampling[1] * self._region_of_interest_shape[1] / 2,
+ self.angular_sampling[0] * self._region_of_interest_shape[0] / 2,
+ -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2,
+ ]
+ elif plot_probe:
+ probe_extent = [
+ 0,
+ self.sampling[1] * self._region_of_interest_shape[1],
+ self.sampling[0] * self._region_of_interest_shape[0],
+ 0,
+ ]
+
+ if plot_convergence:
+ if plot_probe or plot_fourier_probe:
+ spec = GridSpec(
+ ncols=2,
+ nrows=2,
+ height_ratios=[4, 1],
+ hspace=0.15,
+ width_ratios=[
+ (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]),
+ 1,
+ ],
+ wspace=0.35,
+ )
+ else:
+ spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0.15)
+ else:
+ if plot_probe or plot_fourier_probe:
+ spec = GridSpec(
+ ncols=2,
+ nrows=1,
+ width_ratios=[
+ (extent[1] / extent[2]) / (probe_extent[1] / probe_extent[2]),
+ 1,
+ ],
+ wspace=0.35,
+ )
+ else:
+ spec = GridSpec(ncols=1, nrows=1)
+
+ if fig is None:
+ fig = plt.figure(figsize=figsize)
+
+ if plot_probe or plot_fourier_probe:
+ # Object
+ ax = fig.add_subplot(spec[0, 0])
+ im = ax.imshow(
+ rotated_object,
+ extent=extent,
+ cmap=cmap,
+ **kwargs,
+ )
+ ax.set_ylabel("x [A]")
+ ax.set_xlabel("y [A]")
+ if self._object_type == "potential":
+ ax.set_title("Reconstructed object potential")
+ elif self._object_type == "complex":
+ ax.set_title("Reconstructed object phase")
+
+ if cbar:
+ divider = make_axes_locatable(ax)
+ ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
+ fig.add_axes(ax_cb)
+ fig.colorbar(im, cax=ax_cb)
+
+ # Probe
+ kwargs.pop("vmin", None)
+ kwargs.pop("vmax", None)
+
+ ax = fig.add_subplot(spec[0, 1])
+ if plot_fourier_probe:
+ probe_array = Complex2RGB(
+ self.probe_fourier,
+ chroma_boost=chroma_boost,
+ )
+ ax.set_title("Reconstructed Fourier probe")
+ ax.set_ylabel("kx [mrad]")
+ ax.set_xlabel("ky [mrad]")
+ else:
+ probe_array = Complex2RGB(
+ self.probe,
+ power=2,
+ chroma_boost=chroma_boost,
+ )
+ ax.set_title("Reconstructed probe intensity")
+ ax.set_ylabel("x [A]")
+ ax.set_xlabel("y [A]")
+
+ im = ax.imshow(
+ probe_array,
+ extent=probe_extent,
+ )
+
+ if cbar:
+ divider = make_axes_locatable(ax)
+ ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
+ add_colorbar_arg(ax_cb, chroma_boost=chroma_boost)
+
+ else:
+ ax = fig.add_subplot(spec[0])
+ im = ax.imshow(
+ rotated_object,
+ extent=extent,
+ cmap=cmap,
+ **kwargs,
+ )
+ ax.set_ylabel("x [A]")
+ ax.set_xlabel("y [A]")
+ if self._object_type == "potential":
+ ax.set_title("Reconstructed object potential")
+ elif self._object_type == "complex":
+ ax.set_title("Reconstructed object phase")
+
+ if cbar:
+ divider = make_axes_locatable(ax)
+ ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
+ fig.add_axes(ax_cb)
+ fig.colorbar(im, cax=ax_cb)
+
+ if plot_convergence and hasattr(self, "error_iterations"):
+ kwargs.pop("vmin", None)
+ kwargs.pop("vmax", None)
+ errors = np.array(self.error_iterations)
+            if plot_probe or plot_fourier_probe:
+ ax = fig.add_subplot(spec[1, :])
+ else:
+ ax = fig.add_subplot(spec[1])
+ ax.semilogy(np.arange(errors.shape[0]), errors, **kwargs)
+ ax.set_ylabel("NMSE")
+ ax.set_xlabel("Iteration number")
+ ax.yaxis.tick_right()
+
+ fig.suptitle(f"Normalized mean squared error: {self.error:.3e}")
+ spec.tight_layout(fig)
+
+ def _visualize_all_iterations(
+ self,
+ fig,
+ cbar: bool,
+ plot_convergence: bool,
+ plot_probe: bool,
+ plot_fourier_probe: bool,
+ iterations_grid: Tuple[int, int],
+ padding: int,
+ **kwargs,
+ ):
+ """
+ Displays all reconstructed object and probe iterations.
+
+ Parameters
+ --------
+ fig: Figure
+ Matplotlib figure to place Gridspec in
+ plot_convergence: bool, optional
+ If true, the normalized mean squared error (NMSE) plot is displayed
+ iterations_grid: Tuple[int,int]
+ Grid dimensions to plot reconstruction iterations
+ cbar: bool, optional
+ If true, displays a colorbar
+ plot_probe: bool
+ If true, the reconstructed complex probe is displayed
+ plot_fourier_probe: bool
+ If true, the reconstructed complex Fourier probe is displayed
+ padding : int, optional
+            Pixels to pad the object by after rotation and cropping
+ """
+ asnumpy = self._asnumpy
+
+ if not hasattr(self, "object_iterations"):
+ raise ValueError(
+ (
+ "Object and probe iterations were not saved during reconstruction. "
+ "Please re-run using store_iterations=True."
+ )
+ )
+
+ if iterations_grid == "auto":
+ num_iter = len(self.error_iterations)
+
+ if num_iter == 1:
+ return self._visualize_last_iteration(
+ fig=fig,
+ plot_convergence=plot_convergence,
+ plot_probe=plot_probe,
+ plot_fourier_probe=plot_fourier_probe,
+ cbar=cbar,
+ padding=padding,
+ **kwargs,
+ )
+ elif plot_probe or plot_fourier_probe:
+ iterations_grid = (2, 4) if num_iter > 4 else (2, num_iter)
+ else:
+ iterations_grid = (2, 4) if num_iter > 8 else (2, num_iter // 2)
+ else:
+            if (plot_probe or plot_fourier_probe) and iterations_grid[0] != 2:
+                raise ValueError(
+                    "Plotting probes requires iterations_grid to have exactly two rows."
+                )
+
+ auto_figsize = (
+ (3 * iterations_grid[1], 3 * iterations_grid[0] + 1)
+ if plot_convergence
+ else (3 * iterations_grid[1], 3 * iterations_grid[0])
+ )
+ figsize = kwargs.pop("figsize", auto_figsize)
+ cmap = kwargs.pop("cmap", "magma")
+
+ if plot_fourier_probe:
+ chroma_boost = kwargs.pop("chroma_boost", 2)
+ else:
+ chroma_boost = kwargs.pop("chroma_boost", 1)
+
+ errors = np.array(self.error_iterations)
+
+ objects = []
+ object_type = []
+
+ for obj in self.object_iterations:
+ if np.iscomplexobj(obj):
+ obj = np.angle(obj)
+ object_type.append("phase")
+ else:
+ object_type.append("potential")
+ objects.append(self._crop_rotate_object_fov(obj, padding=padding))
+
+ if plot_probe or plot_fourier_probe:
+ total_grids = (np.prod(iterations_grid) / 2).astype("int")
+ probes = self.probe_iterations
+ else:
+ total_grids = np.prod(iterations_grid)
+ max_iter = len(objects) - 1
+ grid_range = range(0, max_iter + 1, max_iter // (total_grids - 1))
+
+ extent = [
+ 0,
+ self.sampling[1] * objects[0].shape[1],
+ self.sampling[0] * objects[0].shape[0],
+ 0,
+ ]
+
+ if plot_fourier_probe:
+ probe_extent = [
+ -self.angular_sampling[1] * self._region_of_interest_shape[1] / 2,
+ self.angular_sampling[1] * self._region_of_interest_shape[1] / 2,
+ self.angular_sampling[0] * self._region_of_interest_shape[0] / 2,
+ -self.angular_sampling[0] * self._region_of_interest_shape[0] / 2,
+ ]
+ elif plot_probe:
+ probe_extent = [
+ 0,
+ self.sampling[1] * self._region_of_interest_shape[1],
+ self.sampling[0] * self._region_of_interest_shape[0],
+ 0,
+ ]
+
+ if plot_convergence:
+ if plot_probe or plot_fourier_probe:
+ spec = GridSpec(ncols=1, nrows=3, height_ratios=[4, 4, 1], hspace=0)
+ else:
+ spec = GridSpec(ncols=1, nrows=2, height_ratios=[4, 1], hspace=0)
+ else:
+ if plot_probe or plot_fourier_probe:
+ spec = GridSpec(ncols=1, nrows=2)
+ else:
+ spec = GridSpec(ncols=1, nrows=1)
+
+ if fig is None:
+ fig = plt.figure(figsize=figsize)
+
+ grid = ImageGrid(
+ fig,
+ spec[0],
+ nrows_ncols=(1, iterations_grid[1])
+ if (plot_probe or plot_fourier_probe)
+ else iterations_grid,
+ axes_pad=(0.75, 0.5) if cbar else 0.5,
+ cbar_mode="each" if cbar else None,
+ cbar_pad="2.5%" if cbar else None,
+ )
+
+ for n, ax in enumerate(grid):
+ im = ax.imshow(
+ objects[grid_range[n]],
+ extent=extent,
+ cmap=cmap,
+ **kwargs,
+ )
+ ax.set_title(f"Iter: {grid_range[n]} {object_type[grid_range[n]]}")
+ ax.set_ylabel("x [A]")
+ ax.set_xlabel("y [A]")
+ if cbar:
+ grid.cbar_axes[n].colorbar(im)
+
+ if plot_probe or plot_fourier_probe:
+ kwargs.pop("vmin", None)
+ kwargs.pop("vmax", None)
+
+ grid = ImageGrid(
+ fig,
+ spec[1],
+ nrows_ncols=(1, iterations_grid[1]),
+ axes_pad=(0.75, 0.5) if cbar else 0.5,
+ cbar_mode="each" if cbar else None,
+ cbar_pad="2.5%" if cbar else None,
+ )
+
+ for n, ax in enumerate(grid):
+ if plot_fourier_probe:
+ probe_array = Complex2RGB(
+ asnumpy(
+ self._return_fourier_probe_from_centered_probe(
+ probes[grid_range[n]]
+ )
+ ),
+ chroma_boost=chroma_boost,
+ )
+ ax.set_title(f"Iter: {grid_range[n]} Fourier probe")
+ ax.set_ylabel("kx [mrad]")
+ ax.set_xlabel("ky [mrad]")
+
+ else:
+ probe_array = Complex2RGB(
+ probes[grid_range[n]],
+ power=2,
+ chroma_boost=chroma_boost,
+ )
+ ax.set_title(f"Iter: {grid_range[n]} probe intensity")
+ ax.set_ylabel("x [A]")
+ ax.set_xlabel("y [A]")
+
+ im = ax.imshow(
+ probe_array,
+ extent=probe_extent,
+ )
+
+ if cbar:
+ add_colorbar_arg(
+ grid.cbar_axes[n],
+ chroma_boost=chroma_boost,
+ )
+
+ if plot_convergence:
+ kwargs.pop("vmin", None)
+ kwargs.pop("vmax", None)
+            if plot_probe or plot_fourier_probe:
+ ax2 = fig.add_subplot(spec[2])
+ else:
+ ax2 = fig.add_subplot(spec[1])
+ ax2.semilogy(np.arange(errors.shape[0]), errors, **kwargs)
+ ax2.set_ylabel("NMSE")
+ ax2.set_xlabel("Iteration number")
+ ax2.yaxis.tick_right()
+
+ spec.tight_layout(fig)
+
+ def visualize(
+ self,
+ fig=None,
+ iterations_grid: Tuple[int, int] = None,
+ plot_convergence: bool = True,
+ plot_probe: bool = True,
+ plot_fourier_probe: bool = False,
+ cbar: bool = True,
+ padding: int = 0,
+ **kwargs,
+ ):
+ """
+ Displays reconstructed object and probe.
+
+ Parameters
+ --------
+ fig: Figure
+ Matplotlib figure to place Gridspec in
+ plot_convergence: bool, optional
+ If true, the normalized mean squared error (NMSE) plot is displayed
+ iterations_grid: Tuple[int,int]
+ Grid dimensions to plot reconstruction iterations
+ cbar: bool, optional
+ If true, displays a colorbar
+ plot_probe: bool
+ If true, the reconstructed probe intensity is also displayed
+ plot_fourier_probe: bool, optional
+ If true, the reconstructed complex Fourier probe is displayed
+ padding : int, optional
+            Pixels to pad the object by after rotation and cropping
+
+ Returns
+ --------
+ self: PtychographicReconstruction
+ Self to accommodate chaining
+ """
+
+ if iterations_grid is None:
+ self._visualize_last_iteration(
+ fig=fig,
+ plot_convergence=plot_convergence,
+ plot_probe=plot_probe,
+ plot_fourier_probe=plot_fourier_probe,
+ cbar=cbar,
+ padding=padding,
+ **kwargs,
+ )
+ else:
+ self._visualize_all_iterations(
+ fig=fig,
+ plot_convergence=plot_convergence,
+ iterations_grid=iterations_grid,
+ plot_probe=plot_probe,
+ plot_fourier_probe=plot_fourier_probe,
+ cbar=cbar,
+ padding=padding,
+ **kwargs,
+ )
+
+ return self
diff --git a/py4DSTEM/process/phase/parameter_optimize.py b/py4DSTEM/process/phase/parameter_optimize.py
new file mode 100644
index 000000000..91a71cb30
--- /dev/null
+++ b/py4DSTEM/process/phase/parameter_optimize.py
@@ -0,0 +1,598 @@
+from functools import partial
+from typing import Callable, Union
+
+import matplotlib.pyplot as plt
+import numpy as np
+from matplotlib.gridspec import GridSpec
+from py4DSTEM.process.phase.iterative_base_class import PhaseReconstruction
+from py4DSTEM.process.phase.utils import AffineTransform
+from skopt import gp_minimize
+from skopt.plots import plot_convergence as skopt_plot_convergence
+from skopt.plots import plot_evaluations as skopt_plot_evaluations
+from skopt.plots import plot_gaussian_process as skopt_plot_gaussian_process
+from skopt.plots import plot_objective as skopt_plot_objective
+from skopt.space import Categorical, Integer, Real
+from skopt.utils import use_named_args
+from tqdm import tqdm
+
+
+class PtychographyOptimizer:
+ """
+ Optimize ptychographic hyperparameters with Bayesian Optimization of a
+ Gaussian process. Any of the scalar-valued real or integer, boolean, or categorical
+ arguments to the ptychographic init-preprocess-reconstruct pipeline can be optimized over.
+ """
+
+ def __init__(
+ self,
+ reconstruction_type: type[PhaseReconstruction],
+ init_args: dict,
+ preprocess_args: dict = {},
+ reconstruction_args: dict = {},
+ affine_args: dict = {},
+ ):
+ """
+ Parameter optimization for ptychographic reconstruction based on Bayesian Optimization
+ with Gaussian Process.
+
+ Usage
+ -----
+ Dictionaries of the arguments to __init__, AffineTransform (for distorting the initial
+ scan positions), preprocess, and reconstruct are required. For parameters not optimized
+ over, the value in the dictionary is used. To optimize a parameter, instead pass an
+ OptimizationParameter object inside the dictionary to specify the initial guess, bounds,
+ and type of parameter, for example:
+ >>> 'param':OptimizationParameter(initial guess, lower bound, upper bound)
+ Calling optimize will then run the optimization simultaneously over all
+ optimization parameters. To obtain the optimized parameters, call get_optimized_arguments
+ to return a set of dictionaries where the OptimizationParameter objects have been replaced
+ with the optimal values. These can then be modified for running a full reconstruction.
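+
+        A minimal sketch of the full workflow (``SomeReconstruction`` stands in
+        for a concrete reconstruction class):
+        >>> optimizer = PtychographyOptimizer(
+        ...     reconstruction_type=SomeReconstruction,
+        ...     init_args=init_args,
+        ...     preprocess_args=preprocess_args,
+        ...     reconstruction_args=reconstruction_args,
+        ... )
+        >>> optimizer.optimize(n_calls=50, n_initial_points=20)
+        >>> init_opt, prep_opt, reco_opt = optimizer.get_optimized_arguments()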
+
+ Parameters
+ ----------
+ reconstruction_type: class
+ Type of ptychographic reconstruction to perform
+ init_args: dict
+ Keyword arguments passed to the __init__ method of the reconstruction class
+        preprocess_args: dict
+            Keyword arguments passed to the preprocess method of the reconstruction object
+        reconstruction_args: dict
+            Keyword arguments passed to the reconstruct method of the reconstruction object
+ affine_args: dict
+ Keyword arguments passed to AffineTransform. The transform is applied to the initial
+ scan positions.
+ """
+
+ # loop over each argument dictionary and split into static and optimization variables
+ (
+ self._init_static_args,
+ self._init_optimize_args,
+ ) = self._split_static_and_optimization_vars(init_args)
+ (
+ self._affine_static_args,
+ self._affine_optimize_args,
+ ) = self._split_static_and_optimization_vars(affine_args)
+ (
+ self._preprocess_static_args,
+ self._preprocess_optimize_args,
+ ) = self._split_static_and_optimization_vars(preprocess_args)
+ (
+ self._reconstruction_static_args,
+ self._reconstruction_optimize_args,
+ ) = self._split_static_and_optimization_vars(reconstruction_args)
+
+        # Save list of skopt parameter objects and initial guess
+ self._parameter_list = []
+ self._x0 = []
+ for k, v in (
+ self._init_optimize_args
+ | self._affine_optimize_args
+ | self._preprocess_optimize_args
+ | self._reconstruction_optimize_args
+ ).items():
+ self._parameter_list.append(v._get(k))
+ self._x0.append(v._initial_value)
+
+ self._init_args = init_args
+ self._affine_args = affine_args
+ self._preprocess_args = preprocess_args
+ self._reconstruction_args = reconstruction_args
+
+ self._reconstruction_type = reconstruction_type
+
+ self._set_optimizer_defaults()
+
+ def optimize(
+ self,
+ n_calls: int = 50,
+ n_initial_points: int = 20,
+ error_metric: Union[Callable, str] = "log",
+ **skopt_kwargs: dict,
+ ):
+ """
+ Run optimizer
+
+ Parameters
+ ----------
+ n_calls: int
+ Number of times to run ptychographic reconstruction
+ n_initial_points: int
+ Number of uniformly spaced trial points to test before
+ beginning Bayesian optimization (must be less than n_calls)
+ error_metric: Callable or str
+ Function used to compute the reconstruction error.
+ When passed as a string, may be one of:
+ 'log': log(NMSE) of final object
+ 'linear': NMSE of final object
+ 'log-converged': log(NMSE) of final object if
+ NMSE is decreasing, 0 if NMSE increasing
+ 'linear-converged': NMSE of final object if
+ NMSE is decreasing, 1 if NMSE increasing
+ 'TV': sum( abs( grad( object ) ) ) / sum( abs( object ) )
+ 'std': negative standard deviation of cropped object
+ 'std-phase': negative standard deviation of
+ phase of the cropped object
+ 'entropy-phase': entropy of the phase of the
+ cropped object
+ When passed as a Callable, a function that takes the
+ PhaseReconstruction object as its only argument
+ and returns the error metric as a single float
+ skopt_kwargs: dict
+ Additional arguments to be passed to skopt.gp_minimize
+
+ """
+
+ error_metric = self._get_error_metric(error_metric)
+
+ self._optimization_function = self._get_optimization_function(
+ self._reconstruction_type,
+ self._parameter_list,
+ self._init_static_args,
+ self._affine_static_args,
+ self._preprocess_static_args,
+ self._reconstruction_static_args,
+ self._init_optimize_args,
+ self._affine_optimize_args,
+ self._preprocess_optimize_args,
+ self._reconstruction_optimize_args,
+ error_metric,
+ )
+
+ # Make a progress bar
+ pbar = tqdm(total=n_calls, desc="Optimizing parameters")
+
+ # We need to wrap the callback because if it returns a value
+ # the optimizer breaks its loop
+ def callback(*args, **kwargs):
+ pbar.update(1)
+
+ self._skopt_result = gp_minimize(
+ self._optimization_function,
+ self._parameter_list,
+ n_calls=n_calls,
+ n_initial_points=n_initial_points,
+ x0=self._x0,
+ callback=callback,
+ **skopt_kwargs,
+ )
+
+ print("Optimized parameters:")
+ for p, x in zip(self._parameter_list, self._skopt_result.x):
+ print(f"{p.name}: {x}")
+
+ # Finish the tqdm progressbar so subsequent things behave nicely
+ pbar.close()
+
+ return self
+
+ def visualize(
+ self,
+ plot_gp_model=True,
+ plot_convergence=False,
+ plot_objective=True,
+ plot_evaluations=False,
+ **kwargs,
+ ):
+ """
+ Visualize optimization results
+
+ Parameters
+ ----------
+ plot_gp_model: bool
+            Display fitted Gaussian process model (only available for 1-dimensional problems)
+ plot_convergence: bool
+ Display convergence history
+ plot_objective: bool
+ Display GP objective function and partial dependence plots
+ plot_evaluations: bool
+ Display histograms of sampled points
+ kwargs:
+            Passed directly to the skopt plot_gaussian_process/plot_objective
+ """
+ ndims = len(self._parameter_list)
+ if ndims == 1:
+ if plot_convergence:
+ figsize = kwargs.pop("figsize", (9, 9))
+ spec = GridSpec(nrows=2, ncols=1, height_ratios=[2, 1], hspace=0.15)
+ else:
+ figsize = kwargs.pop("figsize", (9, 6))
+ spec = GridSpec(nrows=1, ncols=1)
+
+ fig = plt.figure(figsize=figsize)
+ ax = fig.add_subplot(spec[0])
+ skopt_plot_gaussian_process(self._skopt_result, ax=ax, **kwargs)
+
+ if plot_convergence:
+ ax = fig.add_subplot(spec[1])
+ skopt_plot_convergence(self._skopt_result, ax=ax)
+
+ else:
+ if plot_convergence:
+ figsize = kwargs.pop("figsize", (4 * ndims, 4 * (ndims + 0.5)))
+ spec = GridSpec(
+ nrows=ndims + 1,
+ ncols=ndims,
+ height_ratios=[2] * ndims + [1],
+ hspace=0.15,
+ )
+ else:
+ figsize = kwargs.pop("figsize", (4 * ndims, 4 * ndims))
+ spec = GridSpec(nrows=ndims, ncols=ndims, hspace=0.15)
+
+ if plot_evaluations:
+ axs = skopt_plot_evaluations(self._skopt_result)
+ elif plot_objective:
+ cmap = kwargs.pop("cmap", "magma")
+ axs = skopt_plot_objective(self._skopt_result, cmap=cmap, **kwargs)
+ elif plot_convergence:
+ skopt_plot_convergence(self._skopt_result)
+ return self
+
+ fig = axs[0, 0].figure
+ fig.set_size_inches(figsize)
+ for i in range(ndims):
+ for j in range(ndims):
+ ax = axs[i, j]
+ ax.remove()
+ ax.figure = fig
+ fig.add_axes(ax)
+ ax.set_subplotspec(spec[i, j])
+
+ if plot_convergence:
+ ax = fig.add_subplot(spec[ndims, :])
+ skopt_plot_convergence(self._skopt_result, ax=ax)
+
+ spec.tight_layout(fig)
+
+ return self
+
+ def get_optimized_arguments(self):
+ """
+ Get argument dictionaries containing optimized hyperparameters
+
+ Returns
+ -------
+ init_opt, prep_opt, reco_opt: dicts
+ Dictionaries of arguments to __init__, preprocess, and reconstruct
+ where the OptimizationParameter items have been replaced with the optimal
+ values obtained from the optimizer
+ """
+ optimized_dict = {
+ p.name: v for p, v in zip(self._parameter_list, self._skopt_result.x)
+ }
+
+ filtered_dict = {
+ k: v for k, v in optimized_dict.items() if k in self._init_args
+ }
+ init_opt = self._init_args | filtered_dict
+
+ filtered_dict = {
+ k: v for k, v in optimized_dict.items() if k in self._affine_args
+ }
+ affine_opt = self._affine_args | filtered_dict
+
+ affine_transform = partial(AffineTransform, **self._affine_static_args)(
+ **affine_opt
+ )
+ scan_positions = self._get_scan_positions(
+ affine_transform, init_opt["datacube"]
+ )
+ init_opt["initial_scan_positions"] = scan_positions
+
+ filtered_dict = {
+ k: v for k, v in optimized_dict.items() if k in self._preprocess_args
+ }
+ prep_opt = self._preprocess_args | filtered_dict
+
+ filtered_dict = {
+ k: v for k, v in optimized_dict.items() if k in self._reconstruction_args
+ }
+ reco_opt = self._reconstruction_args | filtered_dict
+
+ return init_opt, prep_opt, reco_opt
+
+ def _split_static_and_optimization_vars(self, argdict):
+ static_args = {}
+ optimization_args = {}
+ for k, v in argdict.items():
+ if isinstance(v, OptimizationParameter):
+ optimization_args[k] = v
+ else:
+ static_args[k] = v
+ return static_args, optimization_args
+
+ def _get_scan_positions(self, affine_transform, dataset):
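+        # build a regular real-space scan grid (in A) from the dataset
+        # calibration and apply the affine distortion to it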
+ R_pixel_size = dataset.calibration.get_R_pixel_size()
+ x, y = (
+ np.arange(dataset.R_Nx) * R_pixel_size,
+ np.arange(dataset.R_Ny) * R_pixel_size,
+ )
+ x, y = np.meshgrid(x, y, indexing="ij")
+ scan_positions = np.stack((x.ravel(), y.ravel()), axis=1)
+ scan_positions = scan_positions @ affine_transform.asarray()
+ return scan_positions
+
+ def _get_error_metric(self, error_metric: Union[Callable, str]) -> Callable:
+ """
+ Get error metric as a function, converting builtin method names
+ to functions
+ """
+
+ if callable(error_metric):
+ return error_metric
+
+ assert error_metric in (
+ "log",
+ "linear",
+ "log-converged",
+ "linear-converged",
+ "TV",
+ "std",
+ "std-phase",
+ "entropy-phase",
+ ), f"Error metric {error_metric} not recognized."
+
+ if error_metric == "log":
+
+ def f(ptycho):
+ return np.log(ptycho.error)
+
+ elif error_metric == "linear":
+
+ def f(ptycho):
+ return ptycho.error
+
+ elif error_metric == "log-converged":
+
+ def f(ptycho):
+ converged = ptycho.error_iterations[-1] <= np.min(
+ ptycho.error_iterations
+ )
+ return np.log(ptycho.error) if converged else 0.0
+
+ elif error_metric == "log-linear":
+
+ def f(ptycho):
+ converged = ptycho.error_iterations[-1] <= np.min(
+ ptycho.error_iterations
+ )
+ return ptycho.error if converged else 1.0
+
+ elif error_metric == "TV":
+
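+            # total-variation-style metric: L1 norm of the object gradients,
+            # normalized by the total object magnitude (lower means smoother)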
+ def f(ptycho):
+ gx, gy = np.gradient(ptycho.object_cropped, axis=(-2, -1))
+ obj_mag = np.sum(np.abs(ptycho.object_cropped))
+ tv = np.sum(np.abs(gx)) + np.sum(np.abs(gy))
+ return tv / obj_mag
+
+ elif error_metric == "std":
+
+ def f(ptycho):
+ return -np.std(ptycho.object_cropped)
+
+ elif error_metric == "std-phase":
+
+ def f(ptycho):
+ return -np.std(np.angle(ptycho.object_cropped))
+
+ elif error_metric == "entropy-phase":
+
+ def f(ptycho):
+ obj = np.angle(ptycho.object_cropped)
+ gx, gy = np.gradient(obj)
+ ghist, _, _ = np.histogram2d(
+ gx.ravel(), gy.ravel(), bins=obj.shape, density=True
+ )
+ nz = ghist > 0
+ S = np.sum(ghist[nz] * np.log2(ghist[nz]))
+ return S
+
+ else:
+ raise ValueError(f"Error metric {error_metric} not recognized.")
+
+ return f
+
+ def _get_optimization_function(
+ self,
+ cls: type[PhaseReconstruction],
+ parameter_list: list,
+ init_static_args: dict,
+ affine_static_args: dict,
+ preprocess_static_args: dict,
+ reconstruct_static_args: dict,
+ init_optimization_params: dict,
+ affine_optimization_params: dict,
+ preprocess_optimization_params: dict,
+ reconstruct_optimization_params: dict,
+ error_metric: Callable,
+ ):
+ """
+ Wrap the ptychography pipeline into a single function that encapsulates all of the
+ non-optimization arguments and accepts a concatenated set of keyword arguments. The
+ wrapper function returns the final error value from the ptychography run.
+
+ parameter_list is a list of skopt Dimension objects
+
+ Both static and optimization args are passed in dictionaries. The values of the
+ static dictionary are the fixed parameters, and only the keys of the optimization
+ dictionary are used.
+ """
+
+ # Get lists of optimization parameters for each step
+ init_params = list(init_optimization_params.keys())
+ afft_params = list(affine_optimization_params.keys())
+ prep_params = list(preprocess_optimization_params.keys())
+ reco_params = list(reconstruct_optimization_params.keys())
+
+ # Construct partial methods to encapsulate the static parameters.
+ # If only ``reconstruct`` has optimization variables, perform
+ # preprocessing now, store the ptycho object, and use dummy
+ # functions instead of the partials
+ if (len(init_params), len(afft_params), len(prep_params)) == (0, 0, 0):
+ affine_preprocessed = AffineTransform(**affine_static_args)
+ init_args = init_static_args.copy()
+ init_args["initial_scan_positions"] = self._get_scan_positions(
+ affine_preprocessed, init_static_args["datacube"]
+ )
+
+ ptycho_preprocessed = cls(**init_args).preprocess(**preprocess_static_args)
+
+ def obj(**kwargs):
+ return ptycho_preprocessed
+
+ def prep(ptycho, **kwargs):
+ return ptycho
+
+ else:
+ obj = partial(cls, **init_static_args)
+ prep = partial(cls.preprocess, **preprocess_static_args)
+
+ affine = partial(AffineTransform, **affine_static_args)
+ recon = partial(cls.reconstruct, **reconstruct_static_args)
+
+ # Target function for Gaussian process optimization that takes a single
+ # dict of named parameters and returns the ptycho error metric
+ @use_named_args(parameter_list)
+ def f(**kwargs):
+ init_args = {k: kwargs[k] for k in init_params}
+ afft_args = {k: kwargs[k] for k in afft_params}
+ prep_args = {k: kwargs[k] for k in prep_params}
+ reco_args = {k: kwargs[k] for k in reco_params}
+
+ # Create affine transform object
+ tr = affine(**afft_args)
+ # Apply affine transform to pixel grid, using the
+ # calibrations lifted from the dataset
+ dataset = init_static_args["datacube"]
+ init_args["initial_scan_positions"] = self._get_scan_positions(tr, dataset)
+
+ ptycho = obj(**init_args)
+ prep(ptycho, **prep_args)
+ recon(ptycho, **reco_args)
+
+ return error_metric(ptycho)
+
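+        # Illustrative note: the returned function is shaped for
+        # scikit-optimize (hypothetical call, assuming `parameter_list`
+        # holds the skopt Dimension objects):
+        #
+        #   from skopt import gp_minimize
+        #   result = gp_minimize(f, parameter_list, n_calls=50)
+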
+ return f
+
+ def _set_optimizer_defaults(
+ self,
+ verbose=False,
+ plot_center_of_mass=False,
+ plot_rotation=False,
+ plot_probe_overlaps=False,
+ progress_bar=False,
+ store_iterations=False,
+ reset=True,
+ ):
+ """
+        Set all of the verbose and plotting options to False by default, while allowing user overrides.
+ """
+ self._init_static_args["verbose"] = verbose
+
+ self._preprocess_static_args["plot_center_of_mass"] = plot_center_of_mass
+ self._preprocess_static_args["plot_rotation"] = plot_rotation
+ self._preprocess_static_args["plot_probe_overlaps"] = plot_probe_overlaps
+
+ self._reconstruction_static_args["progress_bar"] = progress_bar
+ self._reconstruction_static_args["store_iterations"] = store_iterations
+ self._reconstruction_static_args["reset"] = reset
+
+
+class OptimizationParameter:
+ """
+ Wrapper for scikit-optimize Space objects used for convenient calling in the PtyhochraphyOptimizer
+ """
+
+ def __init__(
+ self,
+ initial_value: Union[float, int, bool],
+ lower_bound: Union[float, int, bool] = None,
+ upper_bound: Union[float, int, bool] = None,
+ scaling: str = "uniform",
+ space: str = "real",
+        categories: list = None,
+ ):
+ """
+ Wrapper for scikit-optimize Space objects used as inputs to PtychographyOptimizer
+
+ Parameters
+ ----------
+ initial_value:
+ Initial value, used for first evaluation in optimizer
+ lower_bound, upper_bound:
+ Bounds on real or integer variables (not needed for bool or categorical)
+ scaling: str
+ Prior knowledge on sensitivity of the variable. Can be 'uniform' or 'log-uniform'
+ space: str
+ Type of variable. Can be 'real', 'integer', 'bool', or 'categorical'
+ categories: list
+ List of options for Categorical parameter
+ """
+        # Check input
+        if categories is None:
+            categories = []
+
+        space = space.lower()
+ if space not in ("real", "integer", "bool", "categorical"):
+ raise ValueError(f"Unknown Parameter type: {space}")
+
+ scaling = scaling.lower()
+ if scaling not in ("uniform", "log-uniform"):
+ raise ValueError(f"Unknown scaling: {scaling}")
+
+ # Get the right scikit-optimize class
+ space_map = {
+ "real": Real,
+ "integer": Integer,
+ "bool": Categorical,
+ "categorical": Categorical,
+ }
+ param = space_map[space]
+
+ # If a boolean property, the categories are True/False
+ if space == "bool":
+ categories = [True, False]
+
+ if categories == [] and space in ("categorical", "bool"):
+ raise ValueError("Empty list of categories!")
+
+ # store necessary information
+ self._initial_value = initial_value
+ self._categories = categories
+ self._lower_bound = lower_bound
+ self._upper_bound = upper_bound
+ self._scaling = scaling
+ self._param_type = param
+
+ def _get(self, name):
+ self._name = name
+ if self._param_type is Categorical:
+ self._skopt_param = self._param_type(
+ name=self._name, categories=self._categories
+ )
+ else:
+ self._skopt_param = self._param_type(
+ name=self._name,
+ low=self._lower_bound,
+ high=self._upper_bound,
+ prior=self._scaling,
+ )
+ return self._skopt_param
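+
+
+# Illustrative usage sketch (assumes a PtychographyOptimizer that accepts
+# OptimizationParameter instances as keyword arguments):
+#
+#   defocus = OptimizationParameter(
+#       initial_value=100, lower_bound=0, upper_bound=300, space="real"
+#   )
+#   # the optimizer internally calls ._get(name) to produce the named skopt
+#   # Dimension, e.g. defocus._get("defocus") -> Real(name="defocus", ...)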
diff --git a/py4DSTEM/process/phase/utils.py b/py4DSTEM/process/phase/utils.py
new file mode 100644
index 000000000..a1eb54c80
--- /dev/null
+++ b/py4DSTEM/process/phase/utils.py
@@ -0,0 +1,1676 @@
+import functools
+from typing import Mapping, Tuple, Union
+
+import matplotlib.pyplot as plt
+import numpy as np
+from scipy.optimize import curve_fit
+
+try:
+ import cupy as cp
+ from cupyx.scipy.fft import rfft
+except ImportError:
+ cp = None
+ from scipy.fft import dstn, idstn
+
+from py4DSTEM.process.utils import get_CoM
+from py4DSTEM.process.utils.cross_correlate import align_and_shift_images
+from py4DSTEM.process.utils.utils import electron_wavelength_angstrom
+from scipy.ndimage import gaussian_filter, uniform_filter1d
+from skimage.restoration import unwrap_phase
+
+# fmt: off
+
+#: Symbols for the polar representation of all optical aberrations up to the fifth order.
+polar_symbols = (
+ "C10", "C12", "phi12",
+ "C21", "phi21", "C23", "phi23",
+ "C30", "C32", "phi32", "C34", "phi34",
+ "C41", "phi41", "C43", "phi43", "C45", "phi45",
+ "C50", "C52", "phi52", "C54", "phi54", "C56", "phi56",
+)
+
+#: Aliases for the most commonly used optical aberrations.
+polar_aliases = {
+ "defocus": "C10", "astigmatism": "C12", "astigmatism_angle": "phi12",
+ "coma": "C21", "coma_angle": "phi21",
+ "Cs": "C30",
+ "C5": "C50",
+}
+
+# fmt: on
+
+### Probe functions
+
+
+class ComplexProbe:
+ """
+ Complex Probe Class.
+
+ Simplified version of CTF and Probe from abTEM:
+ https://github.com/abTEM/abTEM/blob/master/abtem/transfer.py
+ https://github.com/abTEM/abTEM/blob/master/abtem/waves.py
+
+ Parameters
+ ----------
+ energy: float
+ The electron energy of the wave functions this contrast transfer function will be applied to [eV].
+ semiangle_cutoff: float
+ The semiangle cutoff describes the sharp Fourier space cutoff due to the objective aperture [mrad].
+ gpts : Tuple[int,int]
+ Number of grid points describing the wave functions.
+ sampling : Tuple[float,float]
+ Lateral sampling of wave functions in Å
+ device: str, optional
+ Device to perform calculations on. Must be either 'cpu' or 'gpu'
+ rolloff: float, optional
+ Tapers the cutoff edge over the given angular range [mrad].
+ vacuum_probe_intensity: np.ndarray, optional
+ Squared of corner-centered aperture amplitude to use, instead of semiangle_cutoff + rolloff
+ focal_spread: float, optional
+ The 1/e width of the focal spread due to chromatic aberration and lens current instability [Å].
+ angular_spread: float, optional
+ The 1/e width of the angular deviations due to source size [mrad].
+ gaussian_spread: float, optional
+ The 1/e width image deflections due to vibrations and thermal magnetic noise [Å].
+ phase_shift : float, optional
+ A constant phase shift [radians].
+ parameters: dict, optional
+ Mapping from aberration symbols to their corresponding values. All aberration magnitudes should be given in Å
+ and angles should be given in radians.
+ kwargs:
+ Provide the aberration coefficients as keyword arguments.
+ """
+
+ def __init__(
+ self,
+ energy: float,
+ gpts: Tuple[int, int],
+ sampling: Tuple[float, float],
+ semiangle_cutoff: float = np.inf,
+ rolloff: float = 2.0,
+ vacuum_probe_intensity: np.ndarray = None,
+ device: str = "cpu",
+ focal_spread: float = 0.0,
+ angular_spread: float = 0.0,
+ gaussian_spread: float = 0.0,
+ phase_shift: float = 0.0,
+ parameters: Mapping[str, float] = None,
+ **kwargs,
+ ):
+ if device == "cpu":
+ self._xp = np
+ self._asnumpy = np.asarray
+ elif device == "gpu":
+ self._xp = cp
+ self._asnumpy = cp.asnumpy
+ else:
+ raise ValueError(f"device must be either 'cpu' or 'gpu', not {device}")
+
+ for key in kwargs.keys():
+ if (key not in polar_symbols) and (key not in polar_aliases.keys()):
+ raise ValueError("{} not a recognized parameter".format(key))
+
+ self._vacuum_probe_intensity = vacuum_probe_intensity
+ self._semiangle_cutoff = semiangle_cutoff
+ self._rolloff = rolloff
+ self._focal_spread = focal_spread
+ self._angular_spread = angular_spread
+ self._gaussian_spread = gaussian_spread
+ self._phase_shift = phase_shift
+ self._energy = energy
+ self._wavelength = electron_wavelength_angstrom(energy)
+ self._gpts = gpts
+ self._sampling = sampling
+
+ self._parameters = dict(zip(polar_symbols, [0.0] * len(polar_symbols)))
+
+ if parameters is None:
+ parameters = {}
+
+ parameters.update(kwargs)
+ self.set_parameters(parameters)
+
+ def set_parameters(self, parameters: dict):
+ """
+ Set the phase of the phase aberration.
+ Parameters
+ ----------
+ parameters: dict
+ Mapping from aberration symbols to their corresponding values.
+ """
+
+ for symbol, value in parameters.items():
+ if symbol in self._parameters.keys():
+ self._parameters[symbol] = value
+
+ elif symbol == "defocus":
+ self._parameters[polar_aliases[symbol]] = -value
+
+ elif symbol in polar_aliases.keys():
+ self._parameters[polar_aliases[symbol]] = value
+
+ else:
+ raise ValueError("{} not a recognized parameter".format(symbol))
+
+ return parameters
+
+ def evaluate_aperture(
+ self, alpha: Union[float, np.ndarray], phi: Union[float, np.ndarray] = None
+ ) -> Union[float, np.ndarray]:
+ xp = self._xp
+ semiangle_cutoff = self._semiangle_cutoff / 1000
+
+ if self._vacuum_probe_intensity is not None:
+ vacuum_probe_intensity = xp.asarray(
+ self._vacuum_probe_intensity, dtype=xp.float32
+ )
+ vacuum_probe_amplitude = xp.sqrt(xp.maximum(vacuum_probe_intensity, 0))
+ return vacuum_probe_amplitude
+
+ if self._semiangle_cutoff == xp.inf:
+ return xp.ones_like(alpha)
+
+ if self._rolloff > 0.0:
+ rolloff = self._rolloff / 1000.0 # * semiangle_cutoff
+ array = 0.5 * (
+ 1 + xp.cos(np.pi * (alpha - semiangle_cutoff + rolloff) / rolloff)
+ )
+ array[alpha > semiangle_cutoff] = 0.0
+ array = xp.where(
+ alpha > semiangle_cutoff - rolloff,
+ array,
+ xp.ones_like(alpha, dtype=xp.float32),
+ )
+ else:
+ array = xp.array(alpha < semiangle_cutoff).astype(xp.float32)
+ return array
+
+ def evaluate_temporal_envelope(
+ self, alpha: Union[float, np.ndarray]
+ ) -> Union[float, np.ndarray]:
+ xp = self._xp
+ return xp.exp(
+ -((0.5 * xp.pi / self._wavelength * self._focal_spread * alpha**2) ** 2)
+ ).astype(xp.float32)
+
+ def evaluate_gaussian_envelope(
+ self, alpha: Union[float, np.ndarray]
+ ) -> Union[float, np.ndarray]:
+ xp = self._xp
+ return xp.exp(
+ -0.5 * self._gaussian_spread**2 * alpha**2 / self._wavelength**2
+ )
+
+ def evaluate_spatial_envelope(
+ self, alpha: Union[float, np.ndarray], phi: Union[float, np.ndarray]
+ ) -> Union[float, np.ndarray]:
+ xp = self._xp
+ p = self._parameters
+ dchi_dk = (
+ 2
+ * xp.pi
+ / self._wavelength
+ * (
+ (p["C12"] * xp.cos(2.0 * (phi - p["phi12"])) + p["C10"]) * alpha
+ + (
+ p["C23"] * xp.cos(3.0 * (phi - p["phi23"]))
+ + p["C21"] * xp.cos(1.0 * (phi - p["phi21"]))
+ )
+ * alpha**2
+ + (
+ p["C34"] * xp.cos(4.0 * (phi - p["phi34"]))
+ + p["C32"] * xp.cos(2.0 * (phi - p["phi32"]))
+ + p["C30"]
+ )
+ * alpha**3
+ + (
+ p["C45"] * xp.cos(5.0 * (phi - p["phi45"]))
+ + p["C43"] * xp.cos(3.0 * (phi - p["phi43"]))
+ + p["C41"] * xp.cos(1.0 * (phi - p["phi41"]))
+ )
+ * alpha**4
+ + (
+ p["C56"] * xp.cos(6.0 * (phi - p["phi56"]))
+ + p["C54"] * xp.cos(4.0 * (phi - p["phi54"]))
+ + p["C52"] * xp.cos(2.0 * (phi - p["phi52"]))
+ + p["C50"]
+ )
+ * alpha**5
+ )
+ )
+
+ dchi_dphi = (
+ -2
+ * xp.pi
+ / self._wavelength
+ * (
+ 1 / 2.0 * (2.0 * p["C12"] * xp.sin(2.0 * (phi - p["phi12"]))) * alpha
+ + 1
+ / 3.0
+ * (
+ 3.0 * p["C23"] * xp.sin(3.0 * (phi - p["phi23"]))
+ + 1.0 * p["C21"] * xp.sin(1.0 * (phi - p["phi21"]))
+ )
+ * alpha**2
+ + 1
+ / 4.0
+ * (
+ 4.0 * p["C34"] * xp.sin(4.0 * (phi - p["phi34"]))
+ + 2.0 * p["C32"] * xp.sin(2.0 * (phi - p["phi32"]))
+ )
+ * alpha**3
+ + 1
+ / 5.0
+ * (
+ 5.0 * p["C45"] * xp.sin(5.0 * (phi - p["phi45"]))
+ + 3.0 * p["C43"] * xp.sin(3.0 * (phi - p["phi43"]))
+ + 1.0 * p["C41"] * xp.sin(1.0 * (phi - p["phi41"]))
+ )
+ * alpha**4
+ + 1
+ / 6.0
+ * (
+ 6.0 * p["C56"] * xp.sin(6.0 * (phi - p["phi56"]))
+ + 4.0 * p["C54"] * xp.sin(4.0 * (phi - p["phi54"]))
+ + 2.0 * p["C52"] * xp.sin(2.0 * (phi - p["phi52"]))
+ )
+ * alpha**5
+ )
+ )
+
+ return xp.exp(
+ -xp.sign(self._angular_spread)
+ * (self._angular_spread / 2 / 1000) ** 2
+ * (dchi_dk**2 + dchi_dphi**2)
+ )
+
+ def evaluate_chi(
+ self, alpha: Union[float, np.ndarray], phi: Union[float, np.ndarray]
+ ) -> Union[float, np.ndarray]:
+ xp = self._xp
+ p = self._parameters
+
+ alpha2 = alpha**2
+ alpha = xp.array(alpha)
+
+ array = xp.zeros(alpha.shape, dtype=np.float32)
+ if any([p[symbol] != 0.0 for symbol in ("C10", "C12", "phi12")]):
+ array += (
+ 1 / 2 * alpha2 * (p["C10"] + p["C12"] * xp.cos(2 * (phi - p["phi12"])))
+ )
+
+ if any([p[symbol] != 0.0 for symbol in ("C21", "phi21", "C23", "phi23")]):
+ array += (
+ 1
+ / 3
+ * alpha2
+ * alpha
+ * (
+ p["C21"] * xp.cos(phi - p["phi21"])
+ + p["C23"] * xp.cos(3 * (phi - p["phi23"]))
+ )
+ )
+
+ if any(
+ [p[symbol] != 0.0 for symbol in ("C30", "C32", "phi32", "C34", "phi34")]
+ ):
+ array += (
+ 1
+ / 4
+ * alpha2**2
+ * (
+ p["C30"]
+ + p["C32"] * xp.cos(2 * (phi - p["phi32"]))
+ + p["C34"] * xp.cos(4 * (phi - p["phi34"]))
+ )
+ )
+
+ if any(
+ [
+ p[symbol] != 0.0
+ for symbol in ("C41", "phi41", "C43", "phi43", "C45", "phi41")
+ ]
+ ):
+ array += (
+ 1
+ / 5
+ * alpha2**2
+ * alpha
+ * (
+ p["C41"] * xp.cos((phi - p["phi41"]))
+ + p["C43"] * xp.cos(3 * (phi - p["phi43"]))
+ + p["C45"] * xp.cos(5 * (phi - p["phi45"]))
+ )
+ )
+
+ if any(
+ [
+ p[symbol] != 0.0
+ for symbol in ("C50", "C52", "phi52", "C54", "phi54", "C56", "phi56")
+ ]
+ ):
+ array += (
+ 1
+ / 6
+ * alpha2**3
+ * (
+ p["C50"]
+ + p["C52"] * xp.cos(2 * (phi - p["phi52"]))
+ + p["C54"] * xp.cos(4 * (phi - p["phi54"]))
+ + p["C56"] * xp.cos(6 * (phi - p["phi56"]))
+ )
+ )
+
+ array = 2 * xp.pi / self._wavelength * array + self._phase_shift
+ return array
+
+ def evaluate_aberrations(
+ self, alpha: Union[float, np.ndarray], phi: Union[float, np.ndarray]
+ ) -> Union[float, np.ndarray]:
+ xp = self._xp
+ return xp.exp(-1.0j * self.evaluate_chi(alpha, phi))
+
+ def evaluate(
+ self, alpha: Union[float, np.ndarray], phi: Union[float, np.ndarray]
+ ) -> Union[float, np.ndarray]:
+ array = self.evaluate_aberrations(alpha, phi)
+
+ if self._semiangle_cutoff < np.inf or self._vacuum_probe_intensity is not None:
+ array *= self.evaluate_aperture(alpha, phi)
+
+ if self._focal_spread > 0.0:
+ array *= self.evaluate_temporal_envelope(alpha)
+
+ if self._angular_spread > 0.0:
+ array *= self.evaluate_spatial_envelope(alpha, phi)
+
+ if self._gaussian_spread > 0.0:
+ array *= self.evaluate_gaussian_envelope(alpha)
+
+ return array
+
+ def _evaluate_ctf(self):
+ alpha, phi = self.get_scattering_angles()
+
+ array = self.evaluate(alpha, phi)
+ return array
+
+ def get_scattering_angles(self):
+ kx, ky = self.get_spatial_frequencies()
+ alpha, phi = self.polar_coordinates(
+ kx * self._wavelength, ky * self._wavelength
+ )
+ return alpha, phi
+
+ def get_spatial_frequencies(self):
+ xp = self._xp
+ kx, ky = spatial_frequencies(self._gpts, self._sampling)
+ kx = xp.asarray(kx, dtype=xp.float32)
+ ky = xp.asarray(ky, dtype=xp.float32)
+ return kx, ky
+
+ def polar_coordinates(self, x, y):
+ """Calculate a polar grid for a given Cartesian grid."""
+ xp = self._xp
+ alpha = xp.sqrt(x.reshape((-1, 1)) ** 2 + y.reshape((1, -1)) ** 2)
+ phi = xp.arctan2(x.reshape((-1, 1)), y.reshape((1, -1)))
+ return alpha, phi
+
+ def build(self):
+ """Builds corner-centered complex probe in the center of the region of interest."""
+ xp = self._xp
+ array = xp.fft.ifft2(self._evaluate_ctf())
+ array = array / xp.sqrt((xp.abs(array) ** 2).sum())
+ self._array = array
+ return self
+
+ def visualize(self, **kwargs):
+ """Plots the probe intensity."""
+ xp = self._xp
+ asnumpy = self._asnumpy
+
+        cmap = kwargs.pop("cmap", "Greys_r")
+
+ plt.imshow(
+ asnumpy(xp.abs(xp.fft.ifftshift(self._array)) ** 2),
+ cmap=cmap,
+ **kwargs,
+ )
+ return self
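+
+# Illustrative usage sketch (parameter values are hypothetical):
+#
+#   probe = ComplexProbe(
+#       energy=300e3,            # beam energy in eV
+#       gpts=(256, 256),         # grid points
+#       sampling=(0.1, 0.1),     # real-space sampling in Angstroms
+#       semiangle_cutoff=20,     # aperture semiangle in mrad
+#       defocus=100,             # alias for C10, in Angstroms
+#   ).build()
+#   probe.visualize()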
+
+
+def spatial_frequencies(gpts: Tuple[int, int], sampling: Tuple[float, float]):
+ """
+ Calculate spatial frequencies of a grid.
+
+ Parameters
+ ----------
+ gpts: tuple of int
+ Number of grid points.
+ sampling: tuple of float
+ Sampling of the potential [1 / Å].
+
+ Returns
+ -------
+ tuple of arrays
+ """
+
+ return tuple(
+ np.fft.fftfreq(n, d).astype(np.float32) for n, d in zip(gpts, sampling)
+ )
+
+
+### FFT-shift functions
+
+
+def fourier_translation_operator(
+ positions: np.ndarray, shape: tuple, xp=np
+) -> np.ndarray:
+ """
+ Create an array representing one or more phase ramp(s) for shifting another array.
+
+ Parameters
+ ----------
+ positions : array of xy-positions
+        Positions to calculate Fourier translation operators for
+ shape : two int
+ Array dimensions to be fourier-shifted
+ xp: Callable
+ Array computing module
+
+ Returns
+ -------
+ Fourier translation operators
+ """
+
+ positions_shape = positions.shape
+
+ if len(positions_shape) == 1:
+ positions = positions[None]
+
+ kx, ky = spatial_frequencies(shape, (1.0, 1.0))
+ kx = kx.reshape((1, -1, 1))
+ ky = ky.reshape((1, 1, -1))
+ kx = xp.asarray(kx, dtype=xp.float32)
+ ky = xp.asarray(ky, dtype=xp.float32)
+ positions = xp.asarray(positions, dtype=xp.float32)
+ x = positions[:, 0].reshape((-1,) + (1, 1))
+ y = positions[:, 1].reshape((-1,) + (1, 1))
+
+ result = xp.exp(-2.0j * np.pi * kx * x) * xp.exp(-2.0j * np.pi * ky * y)
+
+ if len(positions_shape) == 1:
+ return result[0]
+ else:
+ return result
+
+
+def fft_shift(array, positions, xp=np):
+ """
+ Fourier-shift array using positions.
+
+ Parameters
+ ----------
+ array: np.ndarray
+ Array to be shifted
+ positions: array of xy-positions
+        Positions to Fourier-shift the array with
+ xp: Callable
+ Array computing module
+
+ Returns
+ -------
+ Fourier-shifted array
+ """
+ translation_operator = fourier_translation_operator(positions, array.shape[-2:], xp)
+ fourier_array = xp.fft.fft2(array)
+
+ if len(translation_operator.shape) == 3 and len(fourier_array.shape) == 3:
+ shifted_fourier_array = fourier_array[None] * translation_operator[:, None]
+ else:
+ shifted_fourier_array = fourier_array * translation_operator
+
+ return xp.fft.ifft2(shifted_fourier_array)
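+
+# Illustrative sketch: shifting a delta function by (2.5, -1.0) pixels
+# (values arbitrary; the result is complex, take the real part to inspect):
+#
+#   img = np.zeros((8, 8))
+#   img[0, 0] = 1.0
+#   shifted = np.real(fft_shift(img, np.array([2.5, -1.0])))
+#   # `shifted` now peaks near pixel (2.5, -1.0), modulo wraparound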
+
+
+### Batching functions
+
+
+def subdivide_into_batches(
+ num_items: int, num_batches: int = None, max_batch: int = None
+):
+ """
+ Split an n integer into m (almost) equal integers, such that the sum of smaller integers equals n.
+
+ Parameters
+ ----------
+ n: int
+ The integer to split.
+ m: int
+ The number integers n will be split into.
+
+ Returns
+ -------
+ list of int
+ """
+    if (num_batches is not None) and (max_batch is not None):
+        raise RuntimeError("provide only one of num_batches and max_batch")
+
+    if num_batches is None:
+        if max_batch is not None:
+            num_batches = (num_items + (-num_items % max_batch)) // max_batch
+        else:
+            raise RuntimeError("one of num_batches or max_batch must be provided")
+
+ if num_items < num_batches:
+ raise RuntimeError("num_batches may not be larger than num_items")
+
+ elif num_items % num_batches == 0:
+ return [num_items // num_batches] * num_batches
+ else:
+ v = []
+ zp = num_batches - (num_items % num_batches)
+ pp = num_items // num_batches
+ for i in range(num_batches):
+ if i >= zp:
+ v = [pp + 1] + v
+ else:
+ v = [pp] + v
+ return v
+
+
+def generate_batches(
+ num_items: int, num_batches: int = None, max_batch: int = None, start=0
+):
+ for batch in subdivide_into_batches(num_items, num_batches, max_batch):
+ end = start + batch
+ yield start, end
+
+ start = end
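+
+# Illustrative sketch:
+#
+#   subdivide_into_batches(10, num_batches=3)  # -> [4, 3, 3]
+#   list(generate_batches(10, max_batch=4))    # -> [(0, 4), (4, 7), (7, 10)]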
+
+
+#### Affine transformation functions
+
+
+class AffineTransform:
+ """
+ Affine Transform Class.
+
+ Simplified version of AffineTransform from tike:
+ https://github.com/AdvancedPhotonSource/tike/blob/f9004a32fda5e49fa63b987e9ffe3c8447d59950/src/tike/ptycho/position.py
+
+ AffineTransform() -> Identity
+
+ Parameters
+ ----------
+ scale0: float
+ x-scaling
+ scale1: float
+ y-scaling
+ shear1: float
+ \\gamma shear
+ angle: float
+ \\theta rotation angle
+ t0: float
+ x-translation
+ t1: float
+ y-translation
+ dilation: float
+ Isotropic expansion (multiplies scale0 and scale1)
+ """
+
+ def __init__(
+ self,
+ scale0: float = 1.0,
+ scale1: float = 1.0,
+ shear1: float = 0.0,
+ angle: float = 0.0,
+ t0: float = 0.0,
+ t1: float = 0.0,
+ dilation: float = 1.0,
+ ):
+ self.scale0 = scale0 * dilation
+ self.scale1 = scale1 * dilation
+ self.shear1 = shear1
+ self.angle = angle
+ self.t0 = t0
+ self.t1 = t1
+
+    @classmethod
+    def fromarray(cls, T: np.ndarray):
+        """
+        Return an AffineTransform from a 2x2 (or 3x2, with translation) matrix.
+        Uses the decomposition method from Graphics Gems 2, Section 7.1.
+        """
+ R = T[:2, :2].copy()
+ scale0 = np.linalg.norm(R[0])
+ if scale0 <= 0:
+ return AffineTransform()
+ R[0] /= scale0
+ shear1 = R[0] @ R[1]
+ R[1] -= shear1 * R[0]
+ scale1 = np.linalg.norm(R[1])
+ if scale1 <= 0:
+ return AffineTransform()
+ R[1] /= scale1
+ shear1 /= scale1
+ angle = np.arccos(R[0, 0])
+
+ if T.shape[0] > 2:
+ t0, t1 = T[2]
+ else:
+ t0 = t1 = 0.0
+
+ return AffineTransform(
+ scale0=float(scale0),
+ scale1=float(scale1),
+ shear1=float(shear1),
+ angle=float(angle),
+ t0=t0,
+ t1=t1,
+ )
+
+ def asarray(self):
+ """
+ Return an 2x2 matrix of scale, shear, rotation.
+ This matrix is scale @ shear @ rotate from left to right.
+ """
+ cosx = np.cos(self.angle)
+ sinx = np.sin(self.angle)
+ return (
+ np.array(
+ [
+ [self.scale0, 0.0],
+ [0.0, self.scale1],
+ ],
+ dtype="float32",
+ )
+ @ np.array(
+ [
+ [1.0, 0.0],
+ [self.shear1, 1.0],
+ ],
+ dtype="float32",
+ )
+ @ np.array(
+ [
+ [+cosx, -sinx],
+ [+sinx, +cosx],
+ ],
+ dtype="float32",
+ )
+ )
+
+ def asarray3(self):
+ """
+ Return an 3x2 matrix of scale, shear, rotation, translation.
+ This matrix is scale @ shear @ rotate from left to right.
+ Expects a homogenous (z) coordinate of 1.
+ """
+ T = np.empty((3, 2), dtype="float32")
+ T[2] = (self.t0, self.t1)
+ T[:2, :2] = self.asarray()
+ return T
+
+ def astuple(self):
+ """Return the constructor parameters in a tuple."""
+ return (
+ self.scale0,
+ self.scale1,
+ self.shear1,
+ self.angle,
+ self.t0,
+ self.t1,
+ )
+
+ def __call__(self, x: np.ndarray, origin=(0, 0), xp=np):
+ origin = xp.asarray(origin, dtype=xp.float32)
+ tf_matrix = self.asarray()
+ tf_matrix = xp.asarray(tf_matrix, dtype=xp.float32)
+ tf_translation = xp.array((self.t0, self.t1)) + origin
+ return ((x - origin) @ tf_matrix) + tf_translation
+
+ def __str__(self):
+ return (
+ "AffineTransform( \n"
+ f" scale0 = {self.scale0:.4f}, scale1 = {self.scale1:.4f}, \n"
+ f" shear1 = {self.shear1:.4f}, angle = {self.angle:.4f}, \n"
+ f" t0 = {self.t0:.4f}, t1 = {self.t1:.4f}, \n"
+ ")"
+ )
+
+
+def estimate_global_transformation(
+ positions0: np.ndarray,
+ positions1: np.ndarray,
+ origin: Tuple[int, int] = (0, 0),
+ translation_allowed: bool = True,
+ xp=np,
+):
+ """Use least squares to estimate the global affine transformation."""
+ origin = xp.asarray(origin, dtype=xp.float32)
+
+ try:
+ if translation_allowed:
+ a = xp.pad(positions0 - origin, ((0, 0), (0, 1)), constant_values=1)
+ else:
+ a = positions0 - origin
+
+ b = positions1 - origin
+ aT = a.conj().swapaxes(-1, -2)
+ x = xp.linalg.inv(aT @ a) @ aT @ b
+
+ tf = AffineTransform.fromarray(x)
+
+ except xp.linalg.LinAlgError:
+ tf = AffineTransform()
+
+ error = xp.linalg.norm(tf(positions0, origin=origin, xp=xp) - positions1)
+
+ return tf, error
+
+
+def estimate_global_transformation_ransac(
+ positions0: np.ndarray,
+ positions1: np.ndarray,
+ origin: Tuple[int, int] = (0, 0),
+ translation_allowed: bool = True,
+ min_sample: int = 64,
+ max_error: float = 16,
+ min_consensus: float = 0.75,
+ max_iter: int = 20,
+ xp=np,
+):
+ """Use RANSAC to estimate the global affine transformation."""
+ best_fitness = np.inf # small fitness is good
+ transform = AffineTransform()
+
+ # Choose a subset
+ for subset in np.random.choice(
+ a=len(positions0),
+ size=(max_iter, min_sample),
+ replace=True,
+ ):
+ # Fit to subset
+ subset = np.unique(subset)
+ candidate_model, _ = estimate_global_transformation(
+ positions0=positions0[subset],
+ positions1=positions1[subset],
+ origin=origin,
+ translation_allowed=translation_allowed,
+ xp=xp,
+ )
+
+        # Determine inliers and outliers
+        position_error = xp.linalg.norm(
+            candidate_model(positions0, origin=origin, xp=xp) - positions1,
+            axis=-1,
+        )
+        inliers = position_error <= max_error
+
+        # Check if consensus reached
+        if xp.sum(inliers) / len(inliers) >= min_consensus:
+            # Refit with consensus inliers
+            candidate_model, fitness = estimate_global_transformation(
+                positions0=positions0[inliers],
+                positions1=positions1[inliers],
+                origin=origin,
+                translation_allowed=translation_allowed,
+                xp=xp,
+            )
+            if fitness < best_fitness:
+                best_fitness = fitness
+                transform = candidate_model
+
+ return transform, best_fitness
+
+
+def fourier_ring_correlation(
+ image_1,
+ image_2,
+ pixel_size=None,
+ bin_size=None,
+ sigma=None,
+ align_images=False,
+ upsample_factor=8,
+ device="cpu",
+ plot_frc=True,
+ frc_color="red",
+ half_bit_color="blue",
+):
+ """
+    Computes the Fourier ring correlation (FRC) of two arrays.
+    Arrays must be the same size.
+
+ Parameters
+ ----------
+    image_1: ndarray
+        first image for FRC
+    image_2: ndarray
+ second image for FRC
+ pixel_size: tuple
+ size of pixels in A (x,y)
+ bin_size: float, optional
+ size of bins for ring profile
+ sigma: float, optional
+ standard deviation for Gaussian kernel
+ align_images: bool
+ if True, aligns images using DFT upsampling of cross correlation.
+    upsample_factor: int
+        if align_images, upsampling for correlation. Must be greater than 2.
+    device: str, optional
+        device the calculation will be performed on. Must be 'cpu' or 'gpu'
+ plot_frc: bool, optional
+ if True, plots frc
+ frc_color: str, optional
+ color of FRC line in plot
+ half_bit_color: str, optional
+ color of half-bit line
+
+ Returns
+ --------
+ q_frc: ndarray
+ spatial frequencies of FRC
+ frc: ndarray
+ fourier ring correlation
+ half_bit: ndarray
+ half-bit criteria
+ """
+ if device == "cpu":
+ xp = np
+ elif device == "gpu":
+ xp = cp
+
+ if align_images:
+ image_2 = align_and_shift_images(
+ image_1,
+ image_2,
+ upsample_factor=upsample_factor,
+ device=device,
+ )
+
+ fft_image_1 = xp.fft.fft2(image_1)
+ fft_image_2 = xp.fft.fft2(image_2)
+
+ cc_mixed = xp.real(fft_image_1 * xp.conj(fft_image_2))
+ cc_image_1 = xp.abs(fft_image_1) ** 2
+ cc_image_2 = xp.abs(fft_image_2) ** 2
+
+ # take 1D profile
+ q_frc, cc_mixed_1D, n = return_1D_profile(
+ cc_mixed,
+ pixel_size=pixel_size,
+ sigma=sigma,
+ bin_size=bin_size,
+ device=device,
+ )
+ _, cc_image_1_1D, _ = return_1D_profile(
+ cc_image_1, pixel_size=pixel_size, sigma=sigma, bin_size=bin_size, device=device
+ )
+ _, cc_image_2_1D, _ = return_1D_profile(
+ cc_image_2,
+ pixel_size=pixel_size,
+ sigma=sigma,
+ bin_size=bin_size,
+ device=device,
+ )
+
+ frc = cc_mixed_1D / ((cc_image_1_1D * cc_image_2_1D) ** 0.5)
+ half_bit = 2 / xp.sqrt(n / 2)
+
+ ind_max = xp.argmax(n)
+ q_frc = q_frc[1:ind_max]
+ frc = frc[1:ind_max]
+ half_bit = half_bit[1:ind_max]
+
+ if plot_frc:
+ fig, ax = plt.subplots()
+ if device == "gpu":
+ ax.plot(q_frc.get(), frc.get(), label="FRC", color=frc_color)
+ ax.plot(q_frc.get(), half_bit.get(), label="half bit", color=half_bit_color)
+ ax.set_xlim([0, q_frc.get().max()])
+ else:
+ ax.plot(q_frc, frc, label="FRC", color=frc_color)
+ ax.plot(q_frc, half_bit, label="half bit", color=half_bit_color)
+ ax.set_xlim([0, q_frc.max()])
+ ax.legend()
+ ax.set_ylim([0, 1])
+
+        if pixel_size is None:
+            ax.set_xlabel(r"Spatial frequency (1/pixels)")
+        else:
+            ax.set_xlabel(r"Spatial frequency ($\AA^{-1}$)")
+ ax.set_ylabel("FRC")
+
+ return q_frc, frc, half_bit
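+
+# Illustrative usage sketch (assumes two registered real-space images of the
+# same shape; pixel_size values are hypothetical):
+#
+#   q_frc, frc, half_bit = fourier_ring_correlation(
+#       image_1, image_2, pixel_size=(0.1, 0.1), align_images=True
+#   )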
+
+
+def return_1D_profile(
+ intensity, pixel_size=None, bin_size=None, sigma=None, device="cpu"
+):
+ """
+ Return 1D radial profile from corner centered array
+
+ Parameters
+ ----------
+ intensity: ndarray
+ Array for computing 1D profile
+ pixel_size: tuple
+ Size of pixels in A (x,y)
+ bin_size: float, optional
+ Size of bins for ring profile
+ sigma: float, optional
+ standard deviation for Gaussian kernel
+    device: str, optional
+        device the calculation will be performed on. Must be 'cpu' or 'gpu'
+
+ Returns
+ --------
+ q_bins: ndarray
+ spatial frequencies of bins
+ I_bins: ndarray
+ Intensity of bins
+ n: ndarray
+ Number of pixels in each bin
+ """
+ if device == "cpu":
+ xp = np
+ elif device == "gpu":
+ xp = cp
+
+ if pixel_size is None:
+ pixel_size = (1, 1)
+
+ x = xp.fft.fftfreq(intensity.shape[0], pixel_size[0])
+ y = xp.fft.fftfreq(intensity.shape[1], pixel_size[1])
+ q = xp.sqrt(x[:, None] ** 2 + y[None, :] ** 2)
+ q = q.ravel()
+
+ intensity = intensity.ravel()
+
+ if bin_size is None:
+ bin_size = q[1] - q[0]
+
+ q_bins = xp.arange(0, q.max() + bin_size, bin_size)
+
+ inds = q / bin_size
+ inds_f = xp.floor(inds).astype("int")
+ d_ind = inds - inds_f
+
+ nf = xp.bincount(inds_f, weights=(1 - d_ind), minlength=q_bins.shape[0])
+ nc = xp.bincount(inds_f + 1, weights=(d_ind), minlength=q_bins.shape[0])
+ n = nf + nc
+
+ I_bins0 = xp.bincount(
+ inds_f, weights=intensity * (1 - d_ind), minlength=q_bins.shape[0]
+ )
+ I_bins1 = xp.bincount(
+ inds_f + 1, weights=intensity * (d_ind), minlength=q_bins.shape[0]
+ )
+
+ I_bins = (I_bins0 + I_bins1) / n
+ if sigma is not None:
+ I_bins = gaussian_filter(I_bins, sigma)
+
+ return q_bins, I_bins, n
+
+
+def fourier_rotate_real_volume(array, angle, axes=(0, 1), xp=np):
+ """
+ Rotates a 3D array using three Fourier-based shear operators.
+
+ Parameters
+ ----------
+ array: ndarray
+ 3D array to rotate
+ angle: float
+ Angle in deg to rotate array by
+    axes: tuple, Optional
+        Axes defining the plane of rotation
+ xp: Callable, optional
+ Array computing module
+
+ Returns
+ --------
+ output_arr: ndarray
+ Fourier-rotated array
+ """
+ input_arr = xp.asarray(array, dtype=array.dtype)
+ array_shape = np.array(input_arr.shape)
+ ndim = input_arr.ndim
+
+ if ndim != 3:
+ raise ValueError("input array should be 3D")
+
+ axes = list(axes)
+
+ if len(axes) != 2:
+ raise ValueError("axes should contain exactly two values")
+
+ if not all([float(ax).is_integer() for ax in axes]):
+ raise ValueError("axes should contain only integer values")
+
+ if axes[0] < 0:
+ axes[0] += ndim
+ if axes[1] < 0:
+ axes[1] += ndim
+ if axes[0] < 0 or axes[1] < 0 or axes[0] >= ndim or axes[1] >= ndim:
+ raise ValueError("invalid rotation plane specified")
+
+ axes.sort()
+ rotation_ax = np.setdiff1d([0, 1, 2], axes)[0]
+ plane_dims = array_shape[axes]
+
+ qx = xp.fft.fftfreq(plane_dims[0], 1)
+ qy = xp.fft.fftfreq(plane_dims[1], 1)
+ qxa, qya = xp.meshgrid(qx, qy, indexing="ij")
+
+ x = xp.arange(plane_dims[0]) - plane_dims[0] / 2
+ y = xp.arange(plane_dims[1]) - plane_dims[1] / 2
+ xa, ya = xp.meshgrid(x, y, indexing="ij")
+
+ theta_90 = round(angle / 90)
+ theta_rest = (angle + 45) % 90 - 45
+
+ theta = np.deg2rad(theta_rest)
+ a = np.tan(-theta / 2)
+ b = np.sin(theta)
+
+ xOp = xp.exp(-2j * np.pi * qxa * ya * a)
+ yOp = xp.exp(-2j * np.pi * qya * xa * b)
+
+ output_arr = input_arr.copy()
+
+ # 90 degree rotation
+ if abs(theta_90) > 0:
+ if plane_dims[0] == plane_dims[1]:
+ output_arr = xp.rot90(output_arr, theta_90, axes=axes)
+ else:
+ if plane_dims[0] > plane_dims[1]:
+ xx = np.arange(plane_dims[1]) + (plane_dims[0] - plane_dims[1]) // 2
+ if rotation_ax == 0:
+ output_arr[:, xx, :] = xp.rot90(
+ output_arr[:, xx, :], theta_90, axes=axes
+ )
+ output_arr[:, : xx[0], :] = 0
+ output_arr[:, xx[-1] :, :] = 0
+ else:
+ output_arr[xx, :, :] = xp.rot90(
+ output_arr[xx, :, :], theta_90, axes=axes
+ )
+ output_arr[: xx[0], :, :] = 0
+ output_arr[xx[-1] :, :, :] = 0
+ else:
+ yy = np.arange(plane_dims[0]) + (plane_dims[1] - plane_dims[0]) // 2
+ if rotation_ax == 2:
+ output_arr[:, yy, :] = xp.rot90(
+ output_arr[:, yy, :], theta_90, axes=axes
+ )
+ output_arr[:, : yy[0], :] = 0
+ output_arr[:, yy[-1] :, :] = 0
+ else:
+ output_arr[:, :, yy] = xp.rot90(
+ output_arr[:, :, yy], theta_90, axes=axes
+ )
+ output_arr[:, :, : yy[0]] = 0
+ output_arr[:, :, yy[-1] :] = 0
+
+ # small rotation
+ if rotation_ax == 0:
+ output_arr = xp.fft.ifft(xp.fft.fft(output_arr, axis=1) * xOp[None, :], axis=1)
+ output_arr = xp.fft.ifft(xp.fft.fft(output_arr, axis=2) * yOp[None, :], axis=2)
+ output_arr = xp.fft.ifft(xp.fft.fft(output_arr, axis=1) * xOp[None, :], axis=1)
+ output_arr = xp.real(output_arr)
+
+ elif rotation_ax == 1:
+ output_arr = xp.fft.ifft(
+ xp.fft.fft(output_arr, axis=0) * xOp[:, None, :], axis=0
+ )
+ output_arr = xp.fft.ifft(
+ xp.fft.fft(output_arr, axis=2) * yOp[:, None, :], axis=2
+ )
+ output_arr = xp.fft.ifft(
+ xp.fft.fft(output_arr, axis=0) * xOp[:, None, :], axis=0
+ )
+ output_arr = np.real(output_arr)
+
+ else:
+ output_arr = xp.fft.ifft(
+ xp.fft.fft(output_arr, axis=0) * xOp[:, :, None], axis=0
+ )
+ output_arr = xp.fft.ifft(
+ xp.fft.fft(output_arr, axis=1) * yOp[:, :, None], axis=1
+ )
+ output_arr = xp.fft.ifft(
+ xp.fft.fft(output_arr, axis=0) * xOp[:, :, None], axis=0
+ )
+ output_arr = xp.real(output_arr)
+
+ return output_arr
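+
+# Illustrative usage sketch (values arbitrary):
+#
+#   vol = np.random.rand(32, 32, 32)
+#   rotated = fourier_rotate_real_volume(vol, angle=30, axes=(0, 1))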
+
+
+### Divergence Projection Functions
+
+
+def compute_divergence(vector_field, spacings, xp=np):
+ """Computes divergence of vector_field"""
+ num_dims = len(spacings)
+ div = xp.zeros_like(vector_field[0])
+
+ for i in range(num_dims):
+ div += xp.gradient(vector_field[i], spacings[i], axis=i)
+
+ return div
+
+
+def compute_gradient(scalar_field, spacings, xp=np):
+ """Computes gradient of scalar_field"""
+ num_dims = len(spacings)
+ grad = xp.zeros((num_dims,) + scalar_field.shape)
+
+ for i in range(num_dims):
+ grad[i] = xp.gradient(scalar_field, spacings[i], axis=i)
+
+ return grad
+
+
+def array_slice(axis, ndim, start, end, step=1):
+ """Returns array slice along dynamic axis"""
+ return (slice(None),) * (axis % ndim) + (slice(start, end, step),)
+
+
+def make_array_rfft_compatible(array_nd, axis=0, xp=np):
+ """Expand array to be rfft compatible"""
+ array_shape = np.array(array_nd.shape)
+ d = array_nd.ndim
+ n = array_shape[axis]
+ array_shape[axis] = (n + 1) * 2
+
+ dtype = array_nd.dtype
+ padded_array = xp.zeros(array_shape, dtype=dtype)
+
+ padded_array[array_slice(axis, d, 1, n + 1)] = -array_nd
+ padded_array[array_slice(axis, d, None, -n - 1, -1)] = array_nd
+
+ return padded_array
+
+
+def dst_I(array_nd, xp=np):
+ """1D rfft-based DST-I"""
+ d = array_nd.ndim
+ for axis in range(d):
+ crop_slice = array_slice(axis, d, 1, -1)
+ array_nd = rfft(
+ make_array_rfft_compatible(array_nd, axis=axis, xp=xp), axis=axis
+ )[crop_slice].imag
+
+ return array_nd
+
+
+def idst_I(array_nd, xp=np):
+ """1D rfft-based iDST-I"""
+ scaling = np.prod((np.array(array_nd.shape) + 1) * 2)
+ return dst_I(array_nd, xp=xp) / scaling
+
+
+def preconditioned_laplacian(num_exterior, spacing=1, xp=np):
+ """DST-I eigenvalues"""
+ n = num_exterior - 1
+ evals_1d = 2 - 2 * xp.cos(np.pi * xp.arange(1, num_exterior) / num_exterior)
+
+ op = (
+ xp.repeat(evals_1d, n**2)
+ + xp.tile(evals_1d, n**2)
+ + xp.tile(xp.repeat(evals_1d, n), n)
+ )
+
+ return -op / spacing**2
+
+
+def preconditioned_poisson_solver(rhs_interior, spacing=1, xp=np):
+ """DST-I based poisson solver"""
+    nx, ny, nz = rhs_interior.shape
+    if nx != ny or nx != nz:
+        raise ValueError("rhs_interior must be a cubic array")
+
+ op = preconditioned_laplacian(nx + 1, spacing=spacing, xp=xp)
+ if xp is np:
+ dst_rhs = dstn(rhs_interior, type=1).ravel()
+ dst_u = (dst_rhs / op).reshape((nx, ny, nz))
+ sol = idstn(dst_u, type=1)
+ else:
+ dst_rhs = dst_I(rhs_interior, xp=xp).ravel()
+ dst_u = (dst_rhs / op).reshape((nx, ny, nz))
+ sol = idst_I(dst_u, xp=xp)
+
+ return sol
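+
+# Illustrative sketch (CPU/scipy path; values arbitrary). The solver expects a
+# cubic interior RHS with homogeneous Dirichlet boundary conditions:
+#
+#   rhs = np.random.randn(32, 32, 32)
+#   u = preconditioned_poisson_solver(rhs, spacing=1.0)
+#   # u approximately satisfies the discrete laplacian(u) = rhs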
+
+
+def project_vector_field_divergence(vector_field, spacings=(1, 1, 1), xp=np):
+ """
+ Returns solenoidal part of vector field using projection:
+
+ f - \\grad{p}
+ s.t. \\laplacian{p} = \\div{f}
+ """
+
+ div_v = compute_divergence(vector_field, spacings, xp=xp)
+ p = preconditioned_poisson_solver(div_v, spacings[0], xp=xp)
+ grad_p = compute_gradient(p, spacings, xp=xp)
+ return vector_field - grad_p
+
+
+# Nesterov acceleration functions
+# https://blogs.princeton.edu/imabandit/2013/04/01/acceleratedgradientdescent/
+
+
+@functools.cache
+def nesterov_lambda(one_indexed_iter_num):
+ if one_indexed_iter_num == 0:
+ return 0
+ return (1 + np.sqrt(1 + 4 * nesterov_lambda(one_indexed_iter_num - 1) ** 2)) / 2
+
+
+def nesterov_gamma(zero_indexed_iter_num):
+ one_indexed_iter_num = zero_indexed_iter_num + 1
+ return (1 - nesterov_lambda(one_indexed_iter_num)) / nesterov_lambda(
+ one_indexed_iter_num + 1
+ )
+
+
+def cartesian_to_polar_transform_2Ddata(
+ im_cart,
+ xy_center,
+ num_theta_bins=90,
+ radius_max=None,
+ corner_centered=False,
+ xp=np,
+):
+ """
+ Quick cartesian to polar conversion.
+ """
+
+ # coordinates
+ if radius_max is None:
+ if corner_centered:
+ radius_max = np.min(np.array(im_cart.shape) // 2)
+ else:
+ radius_max = np.sqrt(np.sum(np.array(im_cart.shape) ** 2)) // 2
+
+ r = xp.arange(radius_max)
+ t = xp.linspace(
+ 0,
+ 2.0 * np.pi,
+ num_theta_bins,
+ endpoint=False,
+ )
+ ra, ta = xp.meshgrid(r, t)
+
+ # resampling coordinates
+ x = ra * xp.cos(ta) + xy_center[0]
+ y = ra * xp.sin(ta) + xy_center[1]
+
+ xf = xp.floor(x).astype("int")
+ yf = xp.floor(y).astype("int")
+ dx = x - xf
+ dy = y - yf
+
+ mode = "wrap" if corner_centered else "clip"
+
+ # resample image
+ im_polar = (
+ im_cart.ravel()[
+ xp.ravel_multi_index(
+ (xf, yf),
+ im_cart.shape,
+ mode=mode,
+ )
+ ]
+ * (1 - dx)
+ * (1 - dy)
+ + im_cart.ravel()[
+ xp.ravel_multi_index(
+ (xf + 1, yf),
+ im_cart.shape,
+ mode=mode,
+ )
+ ]
+ * (dx)
+ * (1 - dy)
+ + im_cart.ravel()[
+ xp.ravel_multi_index(
+ (xf, yf + 1),
+ im_cart.shape,
+ mode=mode,
+ )
+ ]
+ * (1 - dx)
+ * (dy)
+ + im_cart.ravel()[
+ xp.ravel_multi_index(
+ (xf + 1, yf + 1),
+ im_cart.shape,
+ mode=mode,
+ )
+ ]
+ * (dx)
+ * (dy)
+ )
+
+ return im_polar
+
+
+def polar_to_cartesian_transform_2Ddata(
+ im_polar,
+ xy_size,
+ xy_center,
+ corner_centered=False,
+ xp=np,
+):
+ """
+ Quick polar to cartesian conversion.
+ """
+
+ # coordinates
+ sx, sy = xy_size
+ cx, cy = xy_center
+
+ if corner_centered:
+ x = xp.fft.fftfreq(sx, d=1 / sx)
+ y = xp.fft.fftfreq(sy, d=1 / sy)
+ else:
+ x = xp.arange(sx)
+ y = xp.arange(sy)
+
+ xa, ya = xp.meshgrid(x, y, indexing="ij")
+ ra = xp.hypot(xa - cx, ya - cy)
+ ta = xp.arctan2(ya - cy, xa - cx)
+
+ t = xp.linspace(0, 2 * np.pi, im_polar.shape[0], endpoint=False)
+ t_step = t[1] - t[0]
+
+ # resampling coordinates
+ t_ind = ta / t_step
+ r_ind = ra.copy()
+ tf = xp.floor(t_ind).astype("int")
+ rf = xp.floor(r_ind).astype("int")
+
+ # resample image
+ im_cart = im_polar.ravel()[
+ xp.ravel_multi_index(
+ (tf, rf),
+ im_polar.shape,
+ mode=("wrap", "clip"),
+ )
+ ]
+
+ return im_cart
+
+
+def regularize_probe_amplitude(
+ probe_init,
+ width_max_pixels=2.0,
+ nearest_angular_neighbor_averaging=5,
+ enforce_constant_intensity=True,
+ corner_centered=False,
+):
+ """
+    Fits a clipped-linear step edge along each angular direction.
+
+ Parameters
+ --------
+ probe_init: np.array
+ 2D complex image of the probe in Fourier space.
+ width_max_pixels: float
+ Maximum edge width of the probe in pixels.
+ nearest_angular_neighbor_averaging: int
+ Number of nearest angular neighbor pixels to average to make aperture less jagged.
+    enforce_constant_intensity: bool
+        Set to True to make the intensity inside the aperture constant.
+ corner_centered: bool
+ If True, the probe is assumed to be corner-centered
+
+ Returns
+ --------
+ probe_corr: np.ndarray
+ 2D complex image of the corrected probe in Fourier space.
+ coefs_all: np.ndarray
+ coefficients for the sigmoid fits
+ """
+
+ # Get probe intensity
+ probe_amp = np.abs(probe_init)
+ probe_angle = np.angle(probe_init)
+ probe_int = probe_amp**2
+
+ # Center of mass for probe intensity
+ xy_center = get_CoM(probe_int, device="cpu", corner_centered=corner_centered)
+
+ # Convert intensity to polar coordinates
+ polar_int = cartesian_to_polar_transform_2Ddata(
+ probe_int,
+ xy_center=xy_center,
+ corner_centered=corner_centered,
+ xp=np,
+ )
+
+ # Fit corrected probe intensity
+ radius = np.arange(polar_int.shape[1])
+
+ # estimate initial parameters
+ sub = polar_int > (np.max(polar_int) * 0.5)
+ sig_0 = np.mean(polar_int[sub])
+ rad_0 = np.max(np.argwhere(np.sum(sub, axis=0)))
+ width = width_max_pixels * 0.5
+
+ # init
+ def step_model(radius, sig_0, rad_0, width):
+ return sig_0 * np.clip((rad_0 - radius) / width, 0.0, 1.0)
+
+ coefs_all = np.zeros((polar_int.shape[0], 3))
+ coefs_all[:, 0] = sig_0
+ coefs_all[:, 1] = rad_0
+ coefs_all[:, 2] = width
+
+ # bounds
+ lb = (0.0, 0.0, 1e-4)
+ ub = (np.inf, np.inf, width_max_pixels)
+
+ # refine parameters, generate polar image
+ polar_fit = np.zeros_like(polar_int)
+ for a0 in range(polar_int.shape[0]):
+ coefs_all[a0, :] = curve_fit(
+ step_model,
+ radius,
+ polar_int[a0, :],
+ p0=coefs_all[a0, :],
+ xtol=1e-12,
+ bounds=(lb, ub),
+ )[0]
+ polar_fit[a0, :] = step_model(radius, *coefs_all[a0, :])
+
+ # Compute best-fit constant intensity inside probe, update bounds
+ sig_0 = np.median(coefs_all[:, 0])
+ coefs_all[:, 0] = sig_0
+ lb = (sig_0 - 1e-8, 0.0, 1e-4)
+ ub = (sig_0 + 1e-8, np.inf, width_max_pixels)
+
+ # refine parameters, generate polar image
+ polar_int_corr = np.zeros_like(polar_int)
+ for a0 in range(polar_int.shape[0]):
+ coefs_all[a0, :] = curve_fit(
+ step_model,
+ radius,
+ polar_int[a0, :],
+ p0=coefs_all[a0, :],
+ xtol=1e-12,
+ bounds=(lb, ub),
+ )[0]
+ # polar_int_corr[a0, :] = step_model(radius, *coefs_all[a0, :])
+
+ # make aperture less jagged, using moving mean
+ coefs_all = np.apply_along_axis(
+ uniform_filter1d,
+ 0,
+ coefs_all,
+ size=nearest_angular_neighbor_averaging,
+ mode="wrap",
+ )
+ for a0 in range(polar_int.shape[0]):
+ polar_int_corr[a0, :] = step_model(radius, *coefs_all[a0, :])
+
+ # Convert back to cartesian coordinates
+ int_corr = polar_to_cartesian_transform_2Ddata(
+ polar_int_corr,
+ xy_size=probe_init.shape,
+ xy_center=xy_center,
+ corner_centered=corner_centered,
+ )
+
+ amp_corr = np.sqrt(np.maximum(int_corr, 0))
+
+ # Assemble output probe
+ if not enforce_constant_intensity:
+ max_coeff = np.sqrt(coefs_all[:, 0]).max()
+ amp_corr = amp_corr / max_coeff * probe_amp
+
+ probe_corr = amp_corr * np.exp(1j * probe_angle)
+
+ return probe_corr, polar_int, polar_int_corr, coefs_all
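+
+# Illustrative usage sketch (`probe_fourier` is a hypothetical name for a 2D
+# complex array of the probe in Fourier space):
+#
+#   probe_corr, polar_int, polar_int_corr, coefs = regularize_probe_amplitude(
+#       probe_fourier, width_max_pixels=2.0, enforce_constant_intensity=True
+#   )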
+
+
+def aberrations_basis_function(
+ probe_size,
+ probe_sampling,
+ energy,
+ max_angular_order,
+ max_radial_order,
+ xp=np,
+):
+ """ """
+
+ # Add constant phase shift in basis
+ mn = [[-1, 0, 0]]
+
+ for m in range(1, max_radial_order):
+ n_max = np.minimum(max_angular_order, m + 1)
+ for n in range(0, n_max + 1):
+ if (m + n) % 2:
+ mn.append([m, n, 0])
+ if n > 0:
+ mn.append([m, n, 1])
+
+ aberrations_mn = np.array(mn)
+ aberrations_mn = aberrations_mn[np.argsort(aberrations_mn[:, 1]), :]
+
+ sub = aberrations_mn[:, 1] > 0
+ aberrations_mn[sub, :] = aberrations_mn[sub, :][
+ np.argsort(aberrations_mn[sub, 0]), :
+ ]
+ aberrations_mn[~sub, :] = aberrations_mn[~sub, :][
+ np.argsort(aberrations_mn[~sub, 0]), :
+ ]
+ aberrations_num = aberrations_mn.shape[0]
+
+ sx, sy = probe_size
+ dx, dy = probe_sampling
+ wavelength = electron_wavelength_angstrom(energy)
+
+ qx = xp.fft.fftfreq(sx, dx)
+ qy = xp.fft.fftfreq(sy, dy)
+ qr2 = qx[:, None] ** 2 + qy[None, :] ** 2
+ alpha = xp.sqrt(qr2) * wavelength
+ theta = xp.arctan2(qy[None, :], qx[:, None])
+
+ # Aberration basis
+ aberrations_basis = xp.ones((alpha.size, aberrations_num))
+
+ # Skip constant to avoid dividing by zero in normalization
+ for a0 in range(1, aberrations_num):
+ m, n, a = aberrations_mn[a0]
+ if n == 0:
+ # Radially symmetric basis
+ aberrations_basis[:, a0] = (alpha ** (m + 1) / (m + 1)).ravel()
+
+ elif a == 0:
+ # cos coef
+ aberrations_basis[:, a0] = (
+ alpha ** (m + 1) * xp.cos(n * theta) / (m + 1)
+ ).ravel()
+ else:
+ # sin coef
+ aberrations_basis[:, a0] = (
+ alpha ** (m + 1) * xp.sin(n * theta) / (m + 1)
+ ).ravel()
+
+ # global scaling
+ aberrations_basis *= 2 * np.pi / wavelength
+
+ return aberrations_basis, aberrations_mn
+
+
+def fit_aberration_surface(
+ complex_probe,
+ probe_sampling,
+ energy,
+ max_angular_order,
+ max_radial_order,
+ xp=np,
+):
+ """ """
+ probe_amp = xp.abs(complex_probe)
+ probe_angle = -xp.angle(complex_probe)
+
+ if xp is np:
+ probe_angle = probe_angle.astype(np.float64)
+ unwrapped_angle = unwrap_phase(probe_angle, wrap_around=True).astype(xp.float32)
+ else:
+ probe_angle = xp.asnumpy(probe_angle).astype(np.float64)
+ unwrapped_angle = unwrap_phase(probe_angle, wrap_around=True)
+ unwrapped_angle = xp.asarray(unwrapped_angle).astype(xp.float32)
+
+ raveled_basis, _ = aberrations_basis_function(
+ complex_probe.shape,
+ probe_sampling,
+ energy,
+ max_angular_order,
+ max_radial_order,
+ xp=xp,
+ )
+
+ raveled_weights = probe_amp.ravel()
+
+ Aw = raveled_basis * raveled_weights[:, None]
+ bw = unwrapped_angle.ravel() * raveled_weights
+ coeff = xp.linalg.lstsq(Aw, bw, rcond=None)[0]
+
+ fitted_angle = xp.tensordot(raveled_basis, coeff, axes=1).reshape(probe_angle.shape)
+
+ return fitted_angle, coeff
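+
+# Illustrative usage sketch (values hypothetical):
+#
+#   fitted_angle, coeff = fit_aberration_surface(
+#       complex_probe,
+#       probe_sampling=(0.1, 0.1),
+#       energy=300e3,
+#       max_angular_order=4,
+#       max_radial_order=4,
+#   )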
+
+
+def rotate_point(origin, point, angle):
+ """
+ Rotate a point (x1, y1) counterclockwise by a given angle around
+ a given origin (x0, y0).
+
+ Parameters
+ --------
+ origin: 2-tuple of floats
+ (x0, y0)
+ point: 2-tuple of floats
+ (x1, y1)
+ angle: float (radians)
+
+ Returns
+ --------
+    rotated point (2-tuple)
+
+ """
+ ox, oy = origin
+ px, py = point
+
+ qx = ox + np.cos(angle) * (px - ox) - np.sin(angle) * (py - oy)
+ qy = oy + np.sin(angle) * (px - ox) + np.cos(angle) * (py - oy)
+ return qx, qy
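+
+# Illustrative sketch: rotating (1, 0) by 90 degrees about the origin gives
+# approximately (0, 1):
+#
+#   rotate_point((0, 0), (1, 0), np.pi / 2)  # -> (~0.0, ~1.0)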
diff --git a/py4DSTEM/process/polar/__init__.py b/py4DSTEM/process/polar/__init__.py
new file mode 100644
index 000000000..06d32c88e
--- /dev/null
+++ b/py4DSTEM/process/polar/__init__.py
@@ -0,0 +1,10 @@
+from py4DSTEM.process.polar.polar_datacube import PolarDatacube
+from py4DSTEM.process.polar.polar_fits import fit_amorphous_ring, plot_amorphous_ring
+from py4DSTEM.process.polar.polar_peaks import (
+ find_peaks_single_pattern,
+ find_peaks,
+ refine_peaks,
+ plot_radial_peaks,
+ plot_radial_background,
+ make_orientation_histogram,
+)
diff --git a/py4DSTEM/process/polar/polar_analysis.py b/py4DSTEM/process/polar/polar_analysis.py
new file mode 100644
index 000000000..4f053dcfa
--- /dev/null
+++ b/py4DSTEM/process/polar/polar_analysis.py
@@ -0,0 +1,644 @@
+# Analysis scripts for amorphous 4D-STEM data using polar transformations.
+
+import numpy as np
+import matplotlib.pyplot as plt
+from scipy.optimize import curve_fit
+from scipy.ndimage import gaussian_filter
+
+from emdfile import tqdmnd
+
+
+def calculate_radial_statistics(
+ self,
+ plot_results_mean=False,
+ plot_results_var=False,
+ figsize=(8, 4),
+ returnval=False,
+ returnfig=False,
+ progress_bar=True,
+):
+ """
+ Calculate the radial statistics used in fluctuation electron microscopy (FEM)
+ and as an initial step in radial distribution function (RDF) calculation.
+ The computed quantities are the radial mean, variance, and normalized variance.
+
+    There are several ways the means and variances can be computed. Here we
+    first compute the angular mean and standard deviation pattern by pattern,
+    i.e. for diffraction signal d(x,y; q,\theta) we take
+
+        d_mean_all(x,y; q) = \frac{1}{2\pi} \int_{0}^{2\pi} d(x,y; q,\theta) d\theta
+        d_std_all(x,y; q) = \sqrt{ \frac{1}{2\pi} \int_{0}^{2\pi}
+            [ d(x,y; q,\theta) - d_mean_all(x,y; q) ]^2 d\theta }
+
+    Then we find the mean profile by averaging over all scan positions, and
+    the variance profile from the spread of the per-position means about it:
+
+        d_mean(q) = \frac{1}{N_x N_y} \sum_{x,y} d_mean_all(x,y; q)
+        d_var(q) = \frac{1}{N_x N_y} \sum_{x,y} [ d_mean_all(x,y; q) - d_mean(q) ]^2
+
+    and the normalized variance is d_var / d_mean^2.
+
+ This follows the methods described in [@cophus TODO ADD CITATION].
+
+
+ Parameters
+ --------
+ plot_results_mean: bool
+ Toggles plotting the computed radial means
+ plot_results_var: bool
+ Toggles plotting the computed radial variances
+ figsize: 2-tuple
+ Size of output figures
+ returnval: bool
+ Toggles returning the answer. Answers are always stored internally.
+ returnfig: bool
+ Toggles returning figures that have been plotted. Only figures for
+ which `plot_results_*` is True are returned.
+
+ Returns
+ --------
+    radial_mean: np.array
+        Optional - returned iff returnval is True. The average radial intensity.
+    radial_var: np.array
+        Optional - returned iff returnval is True. The radial variance.
+    radial_var_norm: np.array
+        Optional - returned iff returnval is True. The normalized radial variance.
+ fig_means: 2-tuple (fig,ax)
+ Optional - returned iff returnfig is True. Plot of the radial means.
+ fig_var: 2-tuple (fig,ax)
+ Optional - returned iff returnfig is True. Plot of the radial variances.
+ """
+
+ # init radial data arrays
+ self.radial_all = np.zeros(
+ (
+ self._datacube.shape[0],
+ self._datacube.shape[1],
+ self.polar_shape[1],
+ )
+ )
+ self.radial_all_std = np.zeros(
+ (
+ self._datacube.shape[0],
+ self._datacube.shape[1],
+ self.polar_shape[1],
+ )
+ )
+
+ # Compute the radial mean and standard deviation for each probe position
+ for rx, ry in tqdmnd(
+ self._datacube.shape[0],
+ self._datacube.shape[1],
+ desc="Radial statistics",
+ unit=" probe positions",
+ disable=not progress_bar,
+ ):
+ self.radial_all[rx, ry] = np.mean(self.data[rx, ry], axis=0)
+ self.radial_all_std[rx, ry] = np.sqrt(
+ np.mean((self.data[rx, ry] - self.radial_all[rx, ry][None]) ** 2, axis=0)
+ )
+
+ self.radial_mean = np.mean(self.radial_all, axis=(0, 1))
+ self.radial_var = np.mean(
+ (self.radial_all - self.radial_mean[None, None]) ** 2, axis=(0, 1)
+ )
+
+ self.radial_var_norm = np.copy(self.radial_var)
+ sub = self.radial_mean > 0.0
+ self.radial_var_norm[sub] /= self.radial_mean[sub] ** 2
+
+ # prepare answer
+ statistics = self.radial_mean, self.radial_var, self.radial_var_norm
+ if returnval:
+ ans = statistics if not returnfig else [statistics]
+ else:
+ ans = None if not returnfig else []
+
+ # plot results
+ if plot_results_mean:
+ fig, ax = plot_radial_mean(
+ self,
+ figsize=figsize,
+ returnfig=True,
+ )
+ if returnfig:
+ ans.append((fig, ax))
+ if plot_results_var:
+ fig, ax = plot_radial_var_norm(
+ self,
+ figsize=figsize,
+ returnfig=True,
+ )
+ if returnfig:
+ ans.append((fig, ax))
+
+ # return
+ return ans
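+
+
+# Illustrative usage sketch (assumes `pc` is a PolarDatacube with this method
+# attached):
+#
+#   radial_mean, radial_var, radial_var_norm = pc.calculate_radial_statistics(
+#       plot_results_var=True, returnval=True
+#   )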
+
+
+def plot_radial_mean(
+ self,
+ log_x=False,
+ log_y=False,
+ figsize=(8, 4),
+ returnfig=False,
+):
+ """
+ Plot the radial means.
+
+ Parameters
+ ----------
+ log_x : bool
+ Toggle log scaling of the x-axis
+ log_y : bool
+ Toggle log scaling of the y-axis
+ figsize : 2-tuple
+ Size of the output figure
+ returnfig : bool
+ Toggle returning the figure
+ """
+ fig, ax = plt.subplots(figsize=figsize)
+ ax.plot(
+ self.qq,
+ self.radial_mean,
+ )
+
+ if log_x:
+ ax.set_xscale("log")
+ if log_y:
+ ax.set_yscale("log")
+
+ ax.set_xlabel("Scattering Vector (" + self.calibration.get_Q_pixel_units() + ")")
+ ax.set_ylabel("Radial Mean")
+ if log_x and self.qq[0] == 0.0:
+ ax.set_xlim((self.qq[1], self.qq[-1]))
+ else:
+ ax.set_xlim((self.qq[0], self.qq[-1]))
+
+ if returnfig:
+ return fig, ax
+
+
+def plot_radial_var_norm(
+ self,
+ figsize=(8, 4),
+ returnfig=False,
+):
+ """
+    Plot the normalized radial variance.
+
+ Parameters
+ ----------
+ figsize : 2-tuple
+ Size of the output figure
+ returnfig : bool
+ Toggle returning the figure
+
+ """
+ fig, ax = plt.subplots(figsize=figsize)
+ ax.plot(
+ self.qq,
+ self.radial_var_norm,
+ )
+
+ ax.set_xlabel("Scattering Vector (" + self.calibration.get_Q_pixel_units() + ")")
+ ax.set_ylabel("Normalized Variance")
+ ax.set_xlim((self.qq[0], self.qq[-1]))
+
+ if returnfig:
+ return fig, ax
+
+
+def calculate_pair_dist_function(
+ self,
+ k_min=0.05,
+ k_max=None,
+ k_width=0.25,
+ k_lowpass=None,
+ k_highpass=None,
+ r_min=0.0,
+ r_max=20.0,
+ r_step=0.02,
+ damp_origin_fluctuations=True,
+ density=None,
+ plot_background_fits=False,
+ plot_sf_estimate=False,
+ plot_reduced_pdf=True,
+ plot_pdf=False,
+ figsize=(8, 4),
+ maxfev=None,
+ returnval=False,
+ returnfig=False,
+):
+ """
+ Calculate the pair distribution function (PDF).
+
+ First a background is calculated using primarily the signal at the highest
+ scattering vectors available, given by a sum of two exponentials ~exp(-k^2)
+ and ~exp(-k^4) modelling the single atom scattering factor plus a constant
+ offset. Next, the structure factor is computed as
+
+ S(k) = (I(k) - bg(k)) * k / f(k)
+
+ where k is the magnitude of the scattering vector, I(k) is the mean radial
+ signal, f(k) is the single atom scattering factor, and bg(k) is the total
+ background signal (i.e. f(k) plus a constant offset). S(k) is masked outside
+ of the selected fitting region of k-values [k_min,k_max] and low/high pass
+ filters are optionally applied. The structure factor is then inverted into
+ the reduced pair distribution function g(r) using
+
+        g(r) = \frac{2}{\pi} \int \sin( 2\pi k r ) S(k) dk
+
+ The value of the integral is (optionally) damped to zero at the origin to
+ match the physical requirement that this condition holds. Finally, the
+ full PDF G(r) is computed if a known density is provided, using
+
+        G(r) = 1 + \frac{2}{\pi} g(r) / ( 4\pi D r dr )
+
+ This follows the methods described in [@cophus TODO ADD CITATION].
+
+
+ Parameters
+ ----------
+ k_min : number
+ Minimum scattering vector to include in the calculation
+ k_max : number or None
+ Maximum scattering vector to include in the calculation. Note that
+ this cutoff is used when calculating the structure factor - however it
+ is *not* used when estimating the background / single atom scattering
+ factor, which is best estimated from high scattering lengths.
+    k_width : number
+        The fitting window for the structure factor calculation [k_min,k_max]
+        includes a damped region at its edges, i.e. the signal is smoothly damped
+        to zero in the regions [k_min, k_min+k_width] and [k_max-k_width,k_max]
+    k_lowpass : number or None
+        Lowpass filter, in units of the scattering vector stepsize (i.e. self.qstep)
+    k_highpass : number or None
+        Highpass filter, in units of the scattering vector stepsize (i.e. self.qstep)
+ r_min,r_max,r_step : numbers
+ Define the real space coordinates r that the PDF g(r) will be computed in.
+ The coordinates will be np.arange(r_min,r_max,r_step), given in units
+ inverse to the scattering vector units.
+ damp_origin_fluctuations : bool
+ The value of the PDF approaching the origin should be zero, however numerical
+ instability may result in non-physical finite values there. This flag toggles
+ damping the value of the PDF to zero near the origin.
+ density : number or None
+ The density of the sample, if known. If this is not provided, only the
+ reduced PDF is calculated. If this value is provided, the PDF is also
+ calculated.
+    plot_background_fits : bool
+        Toggles plotting the background fits
+    plot_sf_estimate : bool
+        Toggles plotting the structure factor estimate
+    plot_reduced_pdf : bool
+        Toggles plotting the reduced pair distribution function
+    plot_pdf : bool
+        Toggles plotting the pair distribution function
+    figsize : 2-tuple
+        Size of the output figures
+    maxfev : integer or None
+        Max number of iterations to use when fitting the background
+ returnval: bool
+ Toggles returning the answer. Answers are always stored internally.
+ returnfig: bool
+ Toggles returning figures that have been plotted. Only figures for
+ which `plot_*` is True are returned.
+ """
+
+ # set up coordinates and scaling
+ k = self.qq
+ dk = k[1] - k[0]
+ k2 = k**2
+ Ik = self.radial_mean
+ int_mean = np.mean(Ik)
+ sub_fit = k >= k_min
+
+ # initial guesses for background coefs
+ const_bg = np.min(self.radial_mean) / int_mean
+ int0 = np.median(self.radial_mean) / int_mean - const_bg
+ sigma0 = np.mean(k)
+ coefs = [const_bg, int0, sigma0, int0, sigma0]
+ lb = [0, 0, 0, 0, 0]
+ ub = [np.inf, np.inf, np.inf, np.inf, np.inf]
+ # Weight the fit towards high k values
+ noise_est = k[-1] - k + dk
+
+ # Estimate the mean atomic form factor + background
+ if maxfev is None:
+ coefs = curve_fit(
+ scattering_model,
+ k2[sub_fit],
+ Ik[sub_fit] / int_mean,
+ sigma=noise_est[sub_fit],
+ p0=coefs,
+ xtol=1e-8,
+ bounds=(lb, ub),
+ )[0]
+ else:
+ coefs = curve_fit(
+ scattering_model,
+ k2[sub_fit],
+ Ik[sub_fit] / int_mean,
+ sigma=noise_est[sub_fit],
+ p0=coefs,
+ xtol=1e-8,
+ bounds=(lb, ub),
+ maxfev=maxfev,
+ )[0]
+
+ coefs[0] *= int_mean
+ coefs[1] *= int_mean
+ coefs[3] *= int_mean
+
+ # Calculate the mean atomic form factor without a constant offset
+ # coefs_fk = (0.0, coefs[1], coefs[2], coefs[3], coefs[4])
+ # fk = scattering_model(k2, coefs_fk)
+ bg = scattering_model(k2, coefs)
+ fk = bg - coefs[0]
+
+ # mask for structure factor estimate
+ if k_max is None:
+ k_max = np.max(k)
+ mask = np.clip(
+ np.minimum(
+ (k - 0.0) / k_width,
+ (k_max - k) / k_width,
+ ),
+ 0,
+ 1,
+ )
+ mask = np.sin(mask * (np.pi / 2))
+
+ # Estimate the reduced structure factor S(k)
+ Sk = (Ik - bg) * k / fk
+
+ # Masking edges of S(k)
+ mask_sum = np.sum(mask)
+ Sk = (Sk - np.sum(Sk * mask) / mask_sum) * mask
+
+ # Filtering of S(k)
+ if k_lowpass is not None and k_lowpass > 0.0:
+ Sk = gaussian_filter(Sk, sigma=k_lowpass / dk, mode="nearest")
+ if k_highpass is not None and k_highpass > 0.0:
+ Sk_lowpass = gaussian_filter(Sk, sigma=k_highpass / dk, mode="nearest")
+ Sk -= Sk_lowpass
+
+ # Calculate the PDF
+ r = np.arange(r_min, r_max, r_step)
+ ra, ka = np.meshgrid(r, k)
+ pdf_reduced = (
+ (2 / np.pi)
+ * dk
+ * np.sum(
+ np.sin(2 * np.pi * ra * ka) * Sk[:, None],
+ axis=0,
+ )
+ )
+
+ # Damp the unphysical fluctuations at the PDF origin
+ if damp_origin_fluctuations:
+ ind_max = np.argmax(pdf_reduced)
+ r_ind_max = r[ind_max]
+ r_mask = np.minimum(r / r_ind_max, 1.0)
+ r_mask = np.sin(r_mask * np.pi / 2) ** 2
+ pdf_reduced *= r_mask
+
+ # Store results
+ self.pdf_r = r
+ self.pdf_reduced = pdf_reduced
+
+ self.Sk = Sk
+ self.fk = fk
+ self.bg = bg
+ self.offset = coefs[0]
+ self.Sk_mask = mask
+
+ # if density is provided, we can estimate the full PDF
+ if density is not None:
+ pdf = pdf_reduced.copy()
+ pdf[1:] /= 4 * np.pi * density * r[1:] * (r[1] - r[0])
+ pdf *= 2 / np.pi
+ pdf += 1
+
+ # damp and clip values below zero
+ if damp_origin_fluctuations:
+ pdf *= r_mask
+ pdf = np.maximum(pdf, 0.0)
+
+ # store results
+ self.pdf = pdf
+
+ # prepare answer
+ if density is None:
+ return_values = self.pdf_r, self.pdf_reduced
+ else:
+ return_values = self.pdf_r, self.pdf_reduced, self.pdf
+ if returnval:
+ ans = return_values if not returnfig else [return_values]
+ else:
+ ans = None if not returnfig else []
+
+ # Plots
+ if plot_background_fits:
+ fig, ax = self.plot_background_fits(figsize=figsize, returnfig=True)
+ if returnfig:
+ ans.append((fig, ax))
+
+ if plot_sf_estimate:
+ fig, ax = self.plot_sf_estimate(figsize=figsize, returnfig=True)
+ if returnfig:
+ ans.append((fig, ax))
+
+ if plot_reduced_pdf:
+ fig, ax = self.plot_reduced_pdf(figsize=figsize, returnfig=True)
+ if returnfig:
+ ans.append((fig, ax))
+
+ if plot_pdf:
+ fig, ax = self.plot_pdf(figsize=figsize, returnfig=True)
+ if returnfig:
+ ans.append((fig, ax))
+
+ # return
+ return ans
+
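+# Example usage (a sketch; `polar` is a PolarDatacube whose radial statistics
+# have already been computed, and the parameter values are hypothetical):
+#
+#     out = polar.calculate_pair_dist_function(
+#         k_min=0.5,          # skip the low-q region near the direct beam
+#         k_lowpass=0.02,     # light smoothing of S(k)
+#         plot_reduced_pdf=False,
+#         returnval=True,
+#     )
+#     r, pdf_reduced = out    # also stored at polar.pdf_r / polar.pdf_reduced
+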
+
+def plot_background_fits(
+ self,
+ figsize=(8, 4),
+ returnfig=False,
+):
+ """
+ TODO
+ """
+ fig, ax = plt.subplots(figsize=figsize)
+ ax.plot(
+ self.qq,
+ self.radial_mean,
+ color="k",
+ )
+ ax.plot(
+ self.qq,
+ self.bg,
+ color="r",
+ )
+ ax.set_xlabel("Scattering Vector (" + self.calibration.get_Q_pixel_units() + ")")
+ ax.set_ylabel("Radial Mean")
+ ax.set_xlim((self.qq[0], self.qq[-1]))
+ ax.set_xlabel("Scattering Vector [A^-1]")
+ ax.set_ylabel("I(k) and Background Fit Estimates")
+ ax.set_ylim(
+ (
+ np.min(self.radial_mean[self.radial_mean > 0]) * 0.8,
+ np.max(self.radial_mean * self.Sk_mask) * 1.25,
+ )
+ )
+ ax.set_yscale("log")
+ if returnfig:
+ return fig, ax
+ plt.show()
+
+
+def plot_sf_estimate(
+ self,
+ figsize=(8, 4),
+ returnfig=False,
+):
+ """
+ TODO
+ """
+ fig, ax = plt.subplots(figsize=figsize)
+ ax.plot(
+ self.qq,
+ self.Sk,
+ color="r",
+ )
+ yr = (np.min(self.Sk), np.max(self.Sk))
+ ax.set_ylim(
+ (
+ yr[0] - 0.05 * (yr[1] - yr[0]),
+ yr[1] + 0.05 * (yr[1] - yr[0]),
+ )
+ )
+ ax.set_xlabel("Scattering Vector [A^-1]")
+ ax.set_ylabel("Reduced Structure Factor")
+ if returnfig:
+ return fig, ax
+ plt.show()
+
+
+def plot_reduced_pdf(
+ self,
+ figsize=(8, 4),
+ returnfig=False,
+):
+ """
+ TODO
+ """
+ fig, ax = plt.subplots(figsize=figsize)
+ ax.plot(
+ self.pdf_r,
+ self.pdf_reduced,
+ color="r",
+ )
+ ax.set_xlabel("Radius [A]")
+ ax.set_ylabel("Reduced Pair Distribution Function")
+ if returnfig:
+ return fig, ax
+ plt.show()
+
+
+def plot_pdf(
+ self,
+ figsize=(8, 4),
+ returnfig=False,
+):
+ """
+ TODO
+ """
+ fig, ax = plt.subplots(figsize=figsize)
+ ax.plot(
+ self.pdf_r,
+ self.pdf,
+ color="r",
+ )
+ ax.set_xlabel("Radius [A]")
+ ax.set_ylabel("Pair Distribution Function")
+ if returnfig:
+ return fig, ax
+ plt.show()
+
+
+def calculate_FEM_local(
+ self,
+ figsize=(8, 6),
+ returnfig=False,
+):
+ """
+ Calculate fluctuation electron microscopy (FEM) statistics, including radial mean,
+    variance, and normalized variance. This function is intended to compute the radial
+    average and variance for each individual probe position, which can then be mapped
+    over the field-of-view. Not yet implemented.
+
+ Parameters
+ --------
+ self: PolarDatacube
+ Polar datacube used for measuring FEM properties.
+
+ Returns
+ --------
+ radial_avg: np.array
+ Average radial intensity
+ radial_var: np.array
+ Variance in the radial dimension
+
+
+ """
+
+    raise NotImplementedError("calculate_FEM_local is not yet implemented")
+
+
+def scattering_model(k2, *coefs):
+ """
+ The scattering model used to fit the PDF background. The fit
+ function is a constant plus two exponentials - one in k^2 and one
+ in k^4:
+
+    f(k; c,i0,s0,i1,s1) =
+        c + i0*exp(-k^2/(2*s0^2)) + i1*exp(-k^4/(2*s1^4))
+
+ Parameters
+ ----------
+ k2 : 1d array
+ the scattering vector squared
+ coefs : 5-tuple
+        The model parameters (c,i0,s0,i1,s1)
+ """
+ coefs = np.squeeze(np.array(coefs))
+
+ const_bg = coefs[0]
+ int0 = coefs[1]
+ sigma0 = coefs[2]
+ int1 = coefs[3]
+ sigma1 = coefs[4]
+
+ int_model = (
+ const_bg
+ + int0 * np.exp(k2 / (-2 * sigma0**2))
+ + int1 * np.exp(k2**2 / (-2 * sigma1**4))
+ )
+
+ return int_model
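+
+
+# Example (a sketch): evaluating the background model on a q grid, e.g. to
+# inspect an initial guess before fitting (coefficient values are hypothetical):
+#
+#     k = np.linspace(0.1, 2.0, 256)
+#     bg = scattering_model(k**2, [0.05, 1.0, 0.7, 0.3, 0.9])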
diff --git a/py4DSTEM/process/polar/polar_datacube.py b/py4DSTEM/process/polar/polar_datacube.py
new file mode 100644
index 000000000..56071c534
--- /dev/null
+++ b/py4DSTEM/process/polar/polar_datacube.py
@@ -0,0 +1,563 @@
+import numpy as np
+from py4DSTEM.datacube import DataCube
+from scipy.ndimage import gaussian_filter1d
+
+
+class PolarDatacube:
+
+ """
+ An interface to a 4D-STEM datacube under polar-elliptical transformation.
+ """
+
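+    # Example usage (a sketch; `datacube` is a calibrated py4DSTEM DataCube):
+    #
+    #     polar = PolarDatacube(datacube, n_annular=180, two_fold_symmetry=True)
+    #     im_polar = polar.data[10, 12]  # polar transform at scan position (10,12)
+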
+ def __init__(
+ self,
+ datacube,
+ qmin=0.0,
+ qmax=None,
+ qstep=1.0,
+ n_annular=180,
+ qscale=None,
+ mask=None,
+ mask_thresh=0.1,
+ ellipse=True,
+ two_fold_symmetry=False,
+ ):
+ """
+ Parameters
+ ----------
+ datacube : DataCube
+ The datacube in cartesian coordinates
+ qmin : number
+            Minimum radius of the polar transformation, in pixels
+ qmax : number or None
+ Maximum radius of the polar transformation, in pixels
+ qstep : number
+ Width of radial bins, in pixels
+ n_annular : integer
+            Number of bins in the annular direction. Bins will each
+            have a width of 360/n_annular degrees, or 180/n_annular
+            degrees if two_fold_symmetry is set to True
+ qscale : number or None
+ Radial scaling power to apply to polar transform
+ mask : boolean array
+ Cartesian space shaped mask to apply to all transforms
+ mask_thresh : number
+ Pixels below this value in the transformed mask are considered
+ masked pixels
+ ellipse : bool
+ Setting to False forces a circular transform. Setting to True
+ performs an elliptic transform iff elliptic calibrations are
+ available.
+        two_fold_symmetry : bool
+ Setting to True computes the transform mod(theta,pi), i.e. assumes
+ all patterns possess two-fold rotation (Friedel symmetry). The
+ output angular range in this case becomes [0, pi) as opposed to the
+ default of [0,2*pi).
+ """
+
+ # attach datacube
+ assert isinstance(datacube, DataCube)
+ self._datacube = datacube
+ self._datacube.polar = self
+
+ # check for calibrations
+ assert hasattr(self._datacube, "calibration"), "No .calibration found"
+ self.calibration = self._datacube.calibration
+
+ # setup data getter
+ self._set_polar_data_getter()
+
+ # setup sampling
+
+ # polar
+ self._qscale = qscale
+ if qmax is None:
+ qmax = np.min(self._datacube.Qshape) / np.sqrt(2)
+ self._n_annular = n_annular
+ self.two_fold_symmetry = two_fold_symmetry # implicitly calls set_annular_bins
+ self.set_radial_bins(qmin, qmax, qstep)
+
+ # cartesian
+ self._xa, self._ya = np.meshgrid(
+ np.arange(self._datacube.Q_Nx),
+ np.arange(self._datacube.Q_Ny),
+ indexing="ij",
+ )
+
+ # ellipse
+ self.ellipse = ellipse
+
+ # mask
+ self._mask_thresh = mask_thresh
+ self.mask = mask
+
+
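+    # Analysis and plotting methods are implemented in the sibling modules
+    # below and bound to the class here, so they are available directly on
+    # PolarDatacube instances.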
+ from py4DSTEM.process.polar.polar_analysis import (
+ calculate_radial_statistics,
+ calculate_pair_dist_function,
+ calculate_FEM_local,
+ plot_radial_mean,
+ plot_radial_var_norm,
+ plot_background_fits,
+ plot_sf_estimate,
+ plot_reduced_pdf,
+ plot_pdf,
+ )
+ from py4DSTEM.process.polar.polar_peaks import (
+ find_peaks_single_pattern,
+ find_peaks,
+ refine_peaks_local,
+ refine_peaks,
+ plot_radial_peaks,
+ plot_radial_background,
+ model_radial_background,
+ make_orientation_histogram,
+ )
+
+ # sampling methods + properties
+ def set_radial_bins(
+ self,
+ qmin,
+ qmax,
+ qstep,
+ ):
+ self._qmin = qmin
+ self._qmax = qmax
+ self._qstep = qstep
+
+ self.radial_bins = np.arange(self._qmin, self._qmax, self._qstep)
+ self._radial_step = self._datacube.calibration.get_Q_pixel_size() * self._qstep
+ self.set_polar_shape()
+ self.qscale = self._qscale
+
+ @property
+ def qmin(self):
+ return self._qmin
+
+ @qmin.setter
+ def qmin(self, x):
+ self.set_radial_bins(x, self._qmax, self._qstep)
+
+ @property
+ def qmax(self):
+ return self._qmax
+
+    @qmax.setter
+ def qmax(self, x):
+ self.set_radial_bins(self._qmin, x, self._qstep)
+
+ @property
+ def qstep(self):
+ return self._qstep
+
+ @qstep.setter
+ def qstep(self, x):
+ self.set_radial_bins(self._qmin, self._qmax, x)
+
+ def set_annular_bins(self, n_annular):
+ self._n_annular = n_annular
+ self._annular_bins = np.linspace(
+ 0, self._annular_range, self._n_annular, endpoint=False
+ )
+ self._annular_step = self.annular_bins[1] - self.annular_bins[0]
+ self.set_polar_shape()
+
+ @property
+ def annular_bins(self):
+ return self._annular_bins
+
+ @property
+ def annular_step(self):
+ return self._annular_step
+
+ @property
+ def two_fold_symmetry(self):
+ return self._two_fold_symmetry
+
+ @two_fold_symmetry.setter
+ def two_fold_symmetry(self, x):
+ assert isinstance(
+ x, bool
+ ), f"two_fold_symmetry must be boolean, not type {type(x)}"
+ self._two_fold_symmetry = x
+ if x:
+ self._annular_range = np.pi
+ else:
+ self._annular_range = 2 * np.pi
+ self.set_annular_bins(self._n_annular)
+
+ @property
+ def n_annular(self):
+ return self._n_annular
+
+ @n_annular.setter
+ def n_annular(self, x):
+ self.set_annular_bins(x)
+
+ def set_polar_shape(self):
+ if hasattr(self, "radial_bins") and hasattr(self, "annular_bins"):
+ # set shape
+ self.polar_shape = np.array(
+ (self.annular_bins.shape[0], self.radial_bins.shape[0])
+ )
+ self.polar_size = np.prod(self.polar_shape)
+ # set KDE params
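+            # (near q = 0 each annular bin subtends less than one cartesian
+            # pixel, so _annular_bin_step - the number of annular bins per
+            # cartesian pixel - grows, and _sigma_KDE smooths more there)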
+ self._annular_bin_step = 1 / (
+ self._annular_step * (self.radial_bins + self.qstep * 0.5)
+ )
+ self._sigma_KDE = self._annular_bin_step * 0.5
+ # set array indices
+ self._annular_indices = np.arange(self.polar_shape[0]).astype(int)
+ self._radial_indices = np.arange(self.polar_shape[1]).astype(int)
+
+ # coordinate grid properties
+ @property
+ def tt(self):
+ return self._annular_bins
+
+ @property
+ def tt_deg(self):
+ return self._annular_bins * 180 / np.pi
+
+ @property
+ def qq(self):
+ return self.radial_bins * self.calibration.get_Q_pixel_size()
+
+ # scaling property
+ @property
+ def qscale(self):
+ return self._qscale
+
+ @qscale.setter
+ def qscale(self, x):
+ self._qscale = x
+ if x is not None:
+ self._qscale_ar = (self.qq / self.qq[-1]) ** x
+
+ # expose raw data
+ @property
+ def data_raw(self):
+ return self._datacube
+
+ # expose transformed data
+ @property
+ def data(self):
+ return self._polar_data_getter
+
+ def _set_polar_data_getter(self):
+ self._polar_data_getter = PolarDataGetter(polarcube=self)
+
+ # mask properties
+ @property
+ def mask(self):
+ return self._mask
+
+ @mask.setter
+ def mask(self, x):
+ if x is None:
+ self._mask = x
+ else:
+ assert (
+ x.shape == self._datacube.Qshape
+ ), "Mask shape must match diffraction space"
+ self._mask = x
+ self._mask_polar = self.transform(x)
+
+ @property
+ def mask_polar(self):
+ return self._mask_polar
+
+ @property
+ def mask_thresh(self):
+ return self._mask_thresh
+
+ @mask_thresh.setter
+ def mask_thresh(self, x):
+ self._mask_thresh = x
+ self.mask = self.mask
+
+ # expose transformation
+ @property
+ def transform(self):
+ """
+ Return a transformed copy of the diffraction pattern `cartesian_data`.
+
+ Parameters
+ ----------
+ cartesian_data : array
+ The data
+ origin : tuple or list or None
+ Variable behavior depending on the arg type. Length 2 tuples uses
+ these values directly. Length 2 list of ints uses the calibrated
+ origin value at this scan position. None uses the calibrated mean
+ origin.
+ mask : boolean array or None
+ A mask applied to the data before transformation. The value of
+ masked pixels (0's) in the output is determined by `returnval`. Note
+ that this mask is applied in combination with any mask at
+ PolarData.mask.
+        returnval : 'masked' or 'nan' or 'all' or 'zeros' or 'all_zeros'
+            Controls the returned data; see `PolarDataGetter._transform` for
+            a full description of each option. The default, 'masked', returns
+            a numpy masked array.
+ """
+ return self._polar_data_getter._transform
+
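+    # Example usage (a sketch): transform one pattern using the calibrated
+    # mean origin, or return NaNs at unsampled pixels instead of a masked array:
+    #
+    #     im_polar = polar.transform(datacube.data[5, 3])
+    #     im_polar_nan = polar.transform(datacube.data[5, 3], returnval="nan")
+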
+ def __repr__(self):
+ space = " " * len(self.__class__.__name__) + " "
+ string = f"{self.__class__.__name__}( "
+ string += (
+ "Retrieves diffraction images in polar coordinates, using .data[x,y] )"
+ )
+ return string
+
+
+class PolarDataGetter:
+ def __init__(
+ self,
+ polarcube,
+ ):
+ self._polarcube = polarcube
+
+ def __getitem__(self, pos):
+ # unpack scan position
+ x, y = pos
+ # get the data
+ cartesian_data = self._polarcube._datacube[x, y]
+ # transform
+ ans = self._transform(cartesian_data, origin=[x, y], returnval="masked")
+ # return
+ return ans
+
+ def _transform(
+ self,
+ cartesian_data,
+ origin=None,
+ ellipse=None,
+ mask=None,
+ mask_thresh=None,
+ returnval="masked",
+ ):
+ """
+ Return a transformed copy of the diffraction pattern `cartesian_data`.
+
+ Parameters
+ ----------
+ cartesian_data : array
+ The data
+ origin : tuple or list or None
+ Variable behavior depending on the arg type. Length 2 tuples uses
+ these values directly. Length 2 list of ints uses the calibrated
+ origin value at this scan position. None uses the calibrated mean
+ origin.
+ ellipse : tuple or None
+ Variable behavior depending on the arg type. Length 3 tuples uses
+ these values directly (a,b,theta). None uses the calibrated value.
+ mask : boolean array or None
+ A mask applied to the data before transformation. The value of
+ masked pixels in the output is determined by `returnval`. Note that
+ this mask is applied in combination with any mask at PolarData.mask.
+ mask_thresh : number
+ Pixels in the transformed mask with values below this number are
+ considered masked, and will be populated by the values specified
+ by `returnval`.
+ returnval : 'masked' or 'nan' or 'all' or 'zeros' or 'all_zeros'
+ Controls the returned data, including how un-sampled points
+ are handled.
+ - 'masked' returns a numpy masked array.
+ - 'nan' returns a normal numpy array with unsampled pixels set to
+ np.nan.
+ - 'all' returns a 4-tuple of numpy arrays - the transformed data
+ with unsampled pixels set to 'nan', the normalization array, the
+ normalization array scaled to account for the q-dependent
+ sampling density, and the polar boolean mask
+ - 'zeros' returns a normal numpy with unsampled pixels set to 0
+ - 'all_zeros' returns the same 4-tuple as 'all', but with unsampled
+ pixels in the transformed data array set to zeros.
+
+ Returns
+ --------
+ variable
+ see `returnval`, above. Default is a masked array representing
+ the polar transformed data.
+ """
+
+ # get calibrations
+ if origin is None:
+ origin = self._polarcube.calibration.get_origin_mean()
+ elif isinstance(origin, list):
+ origin = self._polarcube.calibration.get_origin(origin[0], origin[1])
+ elif isinstance(origin, tuple):
+ pass
+ else:
+ raise Exception(f"Invalid type for `origin`, {type(origin)}")
+
+ if ellipse is None:
+ ellipse = self._polarcube.calibration.get_ellipse()
+ elif isinstance(ellipse, tuple):
+ pass
+ else:
+ raise Exception(f"Invalid type for `ellipse`, {type(ellipse)}")
+
+ # combine passed mask with default mask
+ mask0 = self._polarcube.mask
+ if mask is None and mask0 is None:
+ mask = np.ones_like(cartesian_data, dtype=bool)
+ elif mask is None:
+ mask = mask0
+ elif mask0 is None:
+ mask = mask
+ else:
+ mask = mask * mask0
+
+ if mask_thresh is None:
+ mask_thresh = self._polarcube.mask_thresh
+
+ # transform data
+ ans = self._transform_array(
+ cartesian_data * mask.astype("float"),
+ origin,
+ ellipse,
+ )
+
+ # transform normalization array
+ ans_norm = self._transform_array(
+ mask.astype("float"),
+ origin,
+ ellipse,
+ )
+
+ # scale the normalization array by the bin density
+ norm_array = ans_norm * self._polarcube._annular_bin_step[np.newaxis]
+ mask_bool = norm_array < mask_thresh
+
+ # apply normalization
+ ans = np.divide(
+ ans,
+ ans_norm,
+ out=np.full_like(ans, np.nan),
+ where=np.logical_not(mask_bool),
+ )
+
+ # radial power law scaling of output
+ if self._polarcube.qscale is not None:
+ ans *= self._polarcube._qscale_ar[np.newaxis, :]
+
+ # return
+ if returnval == "masked":
+ ans = np.ma.array(data=ans, mask=mask_bool)
+ return ans
+ elif returnval == "nan":
+ ans[mask_bool] = np.nan
+ return ans
+ elif returnval == "all":
+ return ans, ans_norm, norm_array, mask_bool
+ elif returnval == "zeros":
+ ans[mask_bool] = 0
+ return ans
+ elif returnval == "all_zeros":
+ ans[mask_bool] = 0
+ return ans, ans_norm, norm_array, mask_bool
+ else:
+ raise Exception(f"Unexpected value {returnval} encountered for `returnval`")
+
+ def _transform_array(
+ self,
+ data,
+ origin,
+ ellipse,
+ ):
+ # set origin
+ x = self._polarcube._xa - origin[0]
+ y = self._polarcube._ya - origin[1]
+
+ # circular
+        if ellipse is None or self._polarcube.ellipse is False:
+ # get polar coords
+ rr = np.sqrt(x**2 + y**2)
+ tt = np.mod(np.arctan2(y, x), self._polarcube._annular_range)
+
+ # elliptical
+ else:
+ # unpack ellipse
+ a, b, theta = ellipse
+
+ # Get polar coords
+ xc = x * np.cos(theta) + y * np.sin(theta)
+ yc = (y * np.cos(theta) - x * np.sin(theta)) * (a / b)
+ rr = (b / a) * np.hypot(xc, yc)
+ tt = np.mod(np.arctan2(yc, xc) + theta, self._polarcube._annular_range)
+
+ # transform to bin sampling
+ r_ind = (rr - self._polarcube.radial_bins[0]) / self._polarcube.qstep
+ t_ind = tt / self._polarcube.annular_step
+
+ # get integers and increments
+ r_ind_floor = np.floor(r_ind).astype("int")
+ t_ind_floor = np.floor(t_ind).astype("int")
+ dr = r_ind - r_ind_floor
+ dt = t_ind - t_ind_floor
+
+ # resample
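+        # each cartesian pixel deposits its intensity into the four surrounding
+        # polar bins with bilinear weights (1-dr)(1-dt), (1-dr)dt, dr(1-dt) and
+        # dr dt, accumulated over the flattened polar grid with np.bincount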
+ sub = np.logical_and(
+ r_ind_floor >= 0,
+ r_ind_floor < self._polarcube.polar_shape[1],
+ )
+ im = np.bincount(
+ r_ind_floor[sub]
+ + np.mod(t_ind_floor[sub], self._polarcube.polar_shape[0])
+ * self._polarcube.polar_shape[1],
+ weights=data[sub] * (1 - dr[sub]) * (1 - dt[sub]),
+ minlength=self._polarcube.polar_size,
+ )
+ im += np.bincount(
+ r_ind_floor[sub]
+ + np.mod(t_ind_floor[sub] + 1, self._polarcube.polar_shape[0])
+ * self._polarcube.polar_shape[1],
+ weights=data[sub] * (1 - dr[sub]) * (dt[sub]),
+ minlength=self._polarcube.polar_size,
+ )
+ sub = np.logical_and(
+ r_ind_floor >= -1, r_ind_floor < self._polarcube.polar_shape[1] - 1
+ )
+ im += np.bincount(
+ r_ind_floor[sub]
+ + 1
+ + np.mod(t_ind_floor[sub], self._polarcube.polar_shape[0])
+ * self._polarcube.polar_shape[1],
+ weights=data[sub] * (dr[sub]) * (1 - dt[sub]),
+ minlength=self._polarcube.polar_size,
+ )
+ im += np.bincount(
+ r_ind_floor[sub]
+ + 1
+ + np.mod(t_ind_floor[sub] + 1, self._polarcube.polar_shape[0])
+ * self._polarcube.polar_shape[1],
+ weights=data[sub] * (dr[sub]) * (dt[sub]),
+ minlength=self._polarcube.polar_size,
+ )
+
+ # reshape to 2D
+ ans = np.reshape(im, self._polarcube.polar_shape)
+
+ # apply KDE
+ for a0 in range(self._polarcube.polar_shape[1]):
+ # Use 5% (= exp(-(1/2*.1669)^2)) cutoff value
+ # for adjacent pixel in kernel
+ if self._polarcube._sigma_KDE[a0] > 0.1669:
+ ans[:, a0] = gaussian_filter1d(
+ ans[:, a0],
+ sigma=self._polarcube._sigma_KDE[a0],
+ mode="wrap",
+ )
+
+ # return
+ return ans
+
+ def __repr__(self):
+ space = " " * len(self.__class__.__name__) + " "
+ string = f"{self.__class__.__name__}( "
+ string += "Retrieves the diffraction pattern at scan position (x,y) in polar coordinates when sliced with [x,y]."
+ return string
diff --git a/py4DSTEM/process/polar/polar_fits.py b/py4DSTEM/process/polar/polar_fits.py
new file mode 100644
index 000000000..3e39c5584
--- /dev/null
+++ b/py4DSTEM/process/polar/polar_fits.py
@@ -0,0 +1,398 @@
+import numpy as np
+import matplotlib.pyplot as plt
+
+# from scipy.optimize import leastsq
+from scipy.optimize import curve_fit
+
+
+def fit_amorphous_ring(
+ im,
+ center=None,
+ radial_range=None,
+ coefs=None,
+ mask_dp=None,
+ show_fit_mask=False,
+ maxfev=None,
+ verbose=False,
+ plot_result=True,
+ plot_log_scale=False,
+ plot_int_scale=(-3, 3),
+ figsize=(8, 8),
+ return_all_coefs=True,
+):
+ """
+ Fit an amorphous halo with a two-sided Gaussian model, plus a background
+ Gaussian function.
+
+ Parameters
+ --------
+ im: np.array
+ 2D image array to perform fitting on
+ center: np.array
+ (x,y) center coordinates for fitting mask. If not specified
+ by the user, we will assume the center coordinate is (im.shape-1)/2.
+ radial_range: np.array
+ (radius_inner, radius_outer) radial range to perform fitting over.
+ If not specified by the user, we will assume (im.shape[0]/4,im.shape[0]/2).
+ coefs: np.array (optional)
+ Array containing initial fitting coefficients for the amorphous fit.
+ mask_dp: np.array
+ Dark field mask for fitting, in addition to the radial range specified above.
+ show_fit_mask: bool
+        Set to True to preview the fitting mask and initial guess for the ellipse params
+ maxfev: int
+ Max number of fitting evaluations for curve_fit.
+ verbose: bool
+ Print fit results
+ plot_result: bool
+ Plot the result of the fitting
+ plot_log_scale: bool
+ Plot logarithmic image intensities
+ plot_int_scale: tuple of 2 values
+ Min and max plotting range in standard deviations of image intensity
+ figsize: tuple, list, np.array (optional)
+ Figure size for plots
+ return_all_coefs: bool
+ Set to True to return the 11 parameter fit, rather than the 5 parameter ellipse
+
+ Returns
+ --------
+ params_ellipse: np.array
+ 5 parameter elliptic fit coefficients
+ params_ellipse_fit: np.array (optional)
+ 11 parameter elliptic fit coefficients
+ """
+
+ # Default values
+ if center is None:
+ center = np.array(((im.shape[0] - 1) / 2, (im.shape[1] - 1) / 2))
+ if radial_range is None:
+ radial_range = (im.shape[0] / 4, im.shape[0] / 2)
+
+ # coordinates
+ xa, ya = np.meshgrid(
+ np.arange(im.shape[0]),
+ np.arange(im.shape[1]),
+ indexing="ij",
+ )
+
+ # Make fitting mask
+ ra2 = (xa - center[0]) ** 2 + (ya - center[1]) ** 2
+ mask = np.logical_and(
+ ra2 >= radial_range[0] ** 2,
+ ra2 <= radial_range[1] ** 2,
+ )
+ if mask_dp is not None:
+ # Logical AND the radial mask with the user-provided mask
+ mask = np.logical_and(mask, mask_dp)
+ vals = im[mask]
+ basis = np.vstack((xa[mask], ya[mask]))
+
+ # initial fitting parameters
+ if coefs is None:
+ # ellipse parameters
+ x0 = center[0]
+ y0 = center[1]
+ R_mean = np.mean(radial_range)
+ a = R_mean
+ b = R_mean
+ t = 0
+
+ # Gaussian model parameters
+ int_min = np.min(vals)
+ int_max = np.max(vals)
+ int0 = (int_max - int_min) / 2
+ int12 = (int_max - int_min) / 2
+ k_bg = int_min
+ sigma0 = np.mean(radial_range)
+ sigma1 = (radial_range[1] - radial_range[0]) / 4
+ sigma2 = (radial_range[1] - radial_range[0]) / 4
+
+ coefs = (x0, y0, a, b, t, int0, int12, k_bg, sigma0, sigma1, sigma2)
+ lb = (0, 0, radial_range[0], radial_range[0], -np.inf, 0, 0, 0, 1, 1, 1)
+ ub = (
+ im.shape[0],
+ im.shape[1],
+ radial_range[1],
+ radial_range[1],
+ np.inf,
+ np.inf,
+ np.inf,
+ np.inf,
+ np.inf,
+ np.inf,
+ np.inf,
+ )
+
+ if show_fit_mask:
+ # show image preview of fitting mask
+
+ # Generate hybrid image for plotting
+ if plot_log_scale:
+ int_med = np.median(np.log(vals))
+ int_std = np.sqrt(np.median((np.log(vals) - int_med) ** 2))
+ int_range = (
+ int_med + plot_int_scale[0] * int_std,
+ int_med + plot_int_scale[1] * int_std,
+ )
+ im_plot = np.tile(
+ np.clip(
+ (np.log(im[:, :, None]) - int_range[0])
+ / (int_range[1] - int_range[0]),
+ 0,
+ 1,
+ ),
+ (1, 1, 3),
+ )
+
+ else:
+ int_med = np.median(vals)
+ int_std = np.sqrt(np.median((vals - int_med) ** 2))
+ int_range = (
+ int_med + plot_int_scale[0] * int_std,
+ int_med + plot_int_scale[1] * int_std,
+ )
+ im_plot = np.tile(
+ np.clip(
+ (im[:, :, None] - int_range[0]) / (int_range[1] - int_range[0]),
+ 0,
+ 1,
+ ),
+ (1, 1, 3),
+ )
+ im_plot[:, :, 0] *= 1 - mask
+
+ fig, ax = plt.subplots(figsize=figsize)
+ ax.imshow(im_plot)
+
+ else:
+ # Perform elliptic fitting
+ int_mean = np.mean(vals)
+
+ if maxfev is None:
+ coefs = curve_fit(
+ amorphous_model,
+ basis,
+ vals / int_mean,
+ p0=coefs,
+ xtol=1e-8,
+ bounds=(lb, ub),
+ )[0]
+ else:
+ coefs = curve_fit(
+ amorphous_model,
+ basis,
+ vals / int_mean,
+ p0=coefs,
+ xtol=1e-8,
+ bounds=(lb, ub),
+ maxfev=maxfev,
+ )[0]
+ coefs[4] = np.mod(coefs[4], 2 * np.pi)
+ coefs[5:8] *= int_mean
+
+ if verbose:
+ print("x0 = " + str(np.round(coefs[0], 3)) + " px")
+ print("y0 = " + str(np.round(coefs[1], 3)) + " px")
+ print("a = " + str(np.round(coefs[2], 3)) + " px")
+ print("b = " + str(np.round(coefs[3], 3)) + " px")
+ print("t = " + str(np.round(np.rad2deg(coefs[4]), 3)) + " deg")
+
+ if plot_result and not show_fit_mask:
+ plot_amorphous_ring(
+ im=im,
+ coefs=coefs,
+ radial_range=radial_range,
+ plot_log_scale=plot_log_scale,
+ plot_int_scale=plot_int_scale,
+ figsize=figsize,
+ )
+
+ # Return fit parameters
+ if return_all_coefs:
+ return coefs
+ else:
+ return coefs[:5]
+
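+# Example usage (a sketch; `dp` is a single 2D diffraction pattern and the
+# radial range, in pixels, is hypothetical):
+#
+#     coefs = fit_amorphous_ring(dp, radial_range=(30, 80), verbose=True)
+#     x0, y0, a, b, theta = coefs[:5]
+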
+
+def plot_amorphous_ring(
+ im,
+ coefs,
+ radial_range=(0, np.inf),
+ plot_log_scale=True,
+ plot_int_scale=(-3, 3),
+ figsize=(8, 8),
+):
+ """
+    Plot the fitted amorphous halo model overlaid on the image, including the
+    ellipse axes and the inner / outer ring widths.
+
+    Parameters
+    --------
+    im: np.array
+        2D image array the fit was performed on
+    coefs: np.array
+        all 11 fitting coefficients
+    radial_range: tuple of 2 values
+        (radius_inner, radius_outer) radial range used for the fit mask
+    plot_log_scale: bool
+        Plot logarithmic image intensities
+    plot_int_scale: tuple of 2 values
+        Min and max plotting range in standard deviations of image intensity
+    figsize: tuple, list, np.array (optional)
+        Figure size for plots
+    """
+
+ # get needed coefs
+ center = coefs[0:2]
+
+ # coordinates
+ xa, ya = np.meshgrid(
+ np.arange(im.shape[0]),
+ np.arange(im.shape[1]),
+ indexing="ij",
+ )
+
+ # Make fitting mask
+ ra2 = (xa - center[0]) ** 2 + (ya - center[1]) ** 2
+ mask = np.logical_and(
+ ra2 >= radial_range[0] ** 2,
+ ra2 <= radial_range[1] ** 2,
+ )
+ vals = im[mask]
+
+ # Generate resulting best fit image
+ im_fit = np.reshape(
+ amorphous_model(np.vstack((xa.ravel(), ya.ravel())), coefs), im.shape
+ )
+
+ # plotting arrays
+ phi = np.linspace(0, 2 * np.pi, 360)
+ cp = np.cos(phi)
+ sp = np.sin(phi)
+
+ # plotting intensity range
+ if plot_log_scale:
+ int_med = np.median(np.log(vals))
+ int_std = np.sqrt(np.median((np.log(vals) - int_med) ** 2))
+ int_range = (
+ int_med + plot_int_scale[0] * int_std,
+ int_med + plot_int_scale[1] * int_std,
+ )
+ im_plot = np.tile(
+ np.clip(
+ (np.log(im[:, :, None]) - int_range[0]) / (int_range[1] - int_range[0]),
+ 0,
+ 1,
+ ),
+ (1, 1, 3),
+ )
+ else:
+ int_med = np.median(vals)
+ int_std = np.sqrt(np.median((vals - int_med) ** 2))
+ int_range = (
+ int_med + plot_int_scale[0] * int_std,
+ int_med + plot_int_scale[1] * int_std,
+ )
+ im_plot = np.clip(
+ (im[:, :, None] - int_range[0]) / (int_range[1] - int_range[0]), 0, 1
+ )
+
+ # plotting
+ fig, ax = plt.subplots(figsize=figsize)
+ ax.imshow(
+ im_plot,
+ vmin=0,
+ vmax=1,
+ cmap="gray",
+ )
+
+ x0 = coefs[0]
+ y0 = coefs[1]
+ a = coefs[2]
+ b = coefs[3]
+ t = coefs[4]
+ s1 = coefs[9]
+ s2 = coefs[10]
+
+ ax.plot(
+ y0 + np.array((-1, 1)) * a * np.sin(t),
+ x0 + np.array((-1, 1)) * a * np.cos(t),
+ c="r",
+ )
+ ax.plot(
+ y0 + np.array((-1, 1)) * b * np.cos(t),
+ x0 + np.array((1, -1)) * b * np.sin(t),
+ c="r",
+ linestyle="dashed",
+ )
+
+ ax.plot(
+ y0 + a * np.sin(t) * cp + b * np.cos(t) * sp,
+ x0 + a * np.cos(t) * cp - b * np.sin(t) * sp,
+ c="r",
+ )
+ scale = 1 - s1 / a
+ ax.plot(
+ y0 + scale * a * np.sin(t) * cp + scale * b * np.cos(t) * sp,
+ x0 + scale * a * np.cos(t) * cp - scale * b * np.sin(t) * sp,
+ c="r",
+ linestyle="dashed",
+ )
+ scale = 1 + s2 / a
+ ax.plot(
+ y0 + scale * a * np.sin(t) * cp + scale * b * np.cos(t) * sp,
+ x0 + scale * a * np.cos(t) * cp - scale * b * np.sin(t) * sp,
+ c="r",
+ linestyle="dashed",
+ )
+ ax.set_xlim((0, im.shape[1] - 1))
+ ax.set_ylim((im.shape[0] - 1, 0))
+
+
+def amorphous_model(basis, *coefs):
+    """
+    Evaluate the 11-parameter amorphous ring model at the (x,y) points in `basis`:
+    a constant offset k_bg, an isotropic Gaussian (int0, sigma0) centered on
+    (x0,y0), and a two-sided Gaussian ring of intensity int12 along the ellipse
+    (x0,y0,a,b,t), with inner width sigma1 and outer width sigma2.
+    """
+    coefs = np.squeeze(np.array(coefs))
+
+ x0 = coefs[0]
+ y0 = coefs[1]
+ a = coefs[2]
+ b = coefs[3]
+ t = coefs[4]
+ int0 = coefs[5]
+ int12 = coefs[6]
+ k_bg = coefs[7]
+ sigma0 = coefs[8]
+ sigma1 = coefs[9]
+ sigma2 = coefs[10]
+
+ x0 = basis[0, :] - x0
+ y0 = basis[1, :] - y0
+ x = np.cos(t) * x0 - (b / a) * np.sin(t) * y0
+ y = np.sin(t) * x0 + (b / a) * np.cos(t) * y0
+
+ r2 = x**2 + y**2
+ dr = np.sqrt(r2) - b
+ dr2 = dr**2
+ sub = dr < 0
+
+ int_model = k_bg + int0 * np.exp(r2 / (-2 * sigma0**2))
+ int_model[sub] += int12 * np.exp(dr2[sub] / (-2 * sigma1**2))
+ sub = np.logical_not(sub)
+ int_model[sub] += int12 * np.exp(dr2[sub] / (-2 * sigma2**2))
+
+ return int_model
diff --git a/py4DSTEM/process/polar/polar_peaks.py b/py4DSTEM/process/polar/polar_peaks.py
new file mode 100644
index 000000000..4064fccaf
--- /dev/null
+++ b/py4DSTEM/process/polar/polar_peaks.py
@@ -0,0 +1,1407 @@
+import numpy as np
+import matplotlib.pyplot as plt
+
+from scipy.ndimage import gaussian_filter, gaussian_filter1d
+from scipy.signal import peak_prominences
+from skimage.feature import peak_local_max
+from scipy.optimize import curve_fit, leastsq
+import warnings
+
+# from emdfile import tqdmnd, PointList, PointListArray
+from py4DSTEM import tqdmnd, PointList, PointListArray
+from py4DSTEM.process.fit import (
+ polar_twofold_gaussian_2D,
+ polar_twofold_gaussian_2D_background,
+)
+
+
+def find_peaks_single_pattern(
+ self,
+ x,
+ y,
+ mask=None,
+ bragg_peaks=None,
+ bragg_mask_radius=None,
+ sigma_annular_deg=10.0,
+ sigma_radial_px=3.0,
+ sigma_annular_deg_max=None,
+ radial_background_subtract=True,
+ radial_background_thresh=0.25,
+ num_peaks_max=100,
+ threshold_abs=1.0,
+ threshold_prom_annular=None,
+ threshold_prom_radial=None,
+ remove_masked_peaks=False,
+ scale_sigma_annular=0.5,
+ scale_sigma_radial=0.25,
+ return_background=False,
+ plot_result=True,
+ plot_power_scale=1.0,
+ plot_scale_size=10.0,
+ figsize=(12, 6),
+ returnfig=False,
+):
+ """
+ Peak detection function for polar transformations.
+
+ Parameters
+ --------
+ x: int
+ x index of diffraction pattern
+ y: int
+ y index of diffraction pattern
+ mask: np.array
+ Boolean mask in Cartesian space, to filter detected peaks.
+ bragg_peaks: py4DSTEM.BraggVectors
+        Set of Bragg peaks used to generate a mask in Cartesian space, to filter detected peaks
+ sigma_annular_deg: float
+ smoothing along the annular direction in degrees, periodic
+ sigma_radial_px: float
+ smoothing along the radial direction in pixels, not periodic
+ sigma_annular_deg_max: float
+ Specify this value for the max annular sigma. Peaks larger than this will be split
+ into multiple peaks, depending on the ratio.
+ radial_background_subtract: bool
+ If true, subtract radial background estimate
+ radial_background_thresh: float
+ Relative order of sorted values to use as background estimate.
+ Setting to 0.5 is equivalent to median, 0.0 is min value.
+    num_peaks_max: int
+        Max number of peaks to return.
+ threshold_abs: float
+ Absolute image intensity threshold for peaks.
+    threshold_prom_annular: float
+        Threshold for prominence, along the annular direction.
+    threshold_prom_radial: float
+        Threshold for prominence, along the radial direction.
+ remove_masked_peaks: bool
+ Delete peaks that are in the region masked by "mask"
+ scale_sigma_annular: float
+ Scaling of the estimated annular standard deviation.
+ scale_sigma_radial: float
+ Scaling of the estimated radial standard deviation.
+ return_background: bool
+ Return the background signal.
+    plot_result: bool
+        Plot the detected peaks
+ plot_power_scale: float
+ Image intensity power law scaling.
+ plot_scale_size: float
+ Marker scaling in the plot.
+ figsize: 2-tuple
+ Size of the result plotting figure.
+ returnfig: bool
+ Return the figure and axes handles.
+
+ Returns
+ --------
+
+ peaks_polar : pointlist
+ The detected peaks
+ fig, ax : (optional)
+ Figure and axes handles
+
+ """
+
+ # if needed, generate mask from Bragg peaks
+ if bragg_peaks is not None:
+ mask_bragg = self._datacube.get_braggmask(
+ bragg_peaks,
+ x,
+ y,
+ radius=bragg_mask_radius,
+ )
+ if mask is None:
+ mask = mask_bragg
+ else:
+ mask = np.logical_or(mask, mask_bragg)
+
+ # Convert sigma values into units of bins
+ sigma_annular = np.deg2rad(sigma_annular_deg) / self.annular_step
+ sigma_radial = sigma_radial_px / self.qstep
+
+ # Get transformed image and normalization array
+ im_polar, im_polar_norm, norm_array, mask_bool = self.transform(
+ self._datacube.data[x, y],
+ mask=mask,
+ returnval="all_zeros",
+ )
+ # Change sign convention of mask
+ mask_bool = np.logical_not(mask_bool)
+
+ # Background subtraction
+ if radial_background_subtract:
+ sig_bg = np.zeros(im_polar.shape[1])
+ for a0 in range(im_polar.shape[1]):
+ if np.any(mask_bool[:, a0]):
+ vals = np.sort(im_polar[mask_bool[:, a0], a0])
+ ind = np.round(radial_background_thresh * (vals.shape[0] - 1)).astype(
+ "int"
+ )
+ sig_bg[a0] = vals[ind]
+ sig_bg_mask = np.sum(mask_bool, axis=0) >= (im_polar.shape[0] // 2)
+ im_polar = np.maximum(im_polar - sig_bg[None, :], 0)
+
+ # apply smoothing and normalization
+ im_polar_sm = gaussian_filter(
+ im_polar * norm_array,
+ sigma=(sigma_annular, sigma_radial),
+ mode=("wrap", "nearest"),
+ )
+ im_mask = gaussian_filter(
+ norm_array,
+ sigma=(sigma_annular, sigma_radial),
+ mode=("wrap", "nearest"),
+ )
+ sub = im_mask > 0.001 * np.max(im_mask)
+ im_polar_sm[sub] /= im_mask[sub]
+
+ # Find local maxima
+ peaks = peak_local_max(
+ im_polar_sm,
+ num_peaks=num_peaks_max,
+ threshold_abs=threshold_abs,
+ )
+
+ # check if peaks should be removed from the polar transformation mask
+ if remove_masked_peaks:
+ peaks = np.delete(
+ peaks,
+ mask_bool[peaks[:, 0], peaks[:, 1]] == False, # noqa: E712
+ axis=0,
+ )
+
+ # peak intensity
+ peaks_int = im_polar_sm[peaks[:, 0], peaks[:, 1]]
+
+    # Estimate the prominence of the peaks, and their size in units of pixels
+ peaks_prom = np.zeros((peaks.shape[0], 4))
+ annular_ind_center = np.atleast_1d(
+ np.array(im_polar_sm.shape[0] // 2).astype("int")
+ )
+ for a0 in range(peaks.shape[0]):
+ # annular
+ trace_annular = np.roll(
+ np.squeeze(im_polar_sm[:, peaks[a0, 1]]), annular_ind_center - peaks[a0, 0]
+ )
+ p_annular = peak_prominences(
+ trace_annular,
+ annular_ind_center,
+ )
+ sigma_annular = scale_sigma_annular * np.minimum(
+ annular_ind_center - p_annular[1], p_annular[2] - annular_ind_center
+ )
+
+ # radial
+ trace_radial = im_polar_sm[peaks[a0, 0], :]
+ p_radial = peak_prominences(
+ trace_radial,
+ np.atleast_1d(peaks[a0, 1]),
+ )
+ sigma_radial = scale_sigma_radial * np.minimum(
+ peaks[a0, 1] - p_radial[1], p_radial[2] - peaks[a0, 1]
+ )
+
+ # output
+ peaks_prom[a0, 0] = p_annular[0]
+ peaks_prom[a0, 1] = sigma_annular[0]
+ peaks_prom[a0, 2] = p_radial[0]
+ peaks_prom[a0, 3] = sigma_radial[0]
+
+    # if needed, remove peaks using prominence criteria
+ if threshold_prom_annular is not None:
+ remove = peaks_prom[:, 0] < threshold_prom_annular
+ peaks = np.delete(
+ peaks,
+ remove,
+ axis=0,
+ )
+ peaks_int = np.delete(
+ peaks_int,
+ remove,
+ )
+ peaks_prom = np.delete(
+ peaks_prom,
+ remove,
+ axis=0,
+ )
+ if threshold_prom_radial is not None:
+ remove = peaks_prom[:, 2] < threshold_prom_radial
+ peaks = np.delete(
+ peaks,
+ remove,
+ axis=0,
+ )
+ peaks_int = np.delete(
+ peaks_int,
+ remove,
+ )
+ peaks_prom = np.delete(
+ peaks_prom,
+ remove,
+ axis=0,
+ )
+
+ # combine peaks into one array
+ peaks_all = np.column_stack((peaks, peaks_int, peaks_prom))
+
+ # Split peaks into multiple peaks if they have sigma values larger than user-specified threshold
+ if sigma_annular_deg_max is not None:
+ peaks_new = np.zeros((0, peaks_all.shape[1]))
+ for a0 in range(peaks_all.shape[0]):
+ if peaks_all[a0, 4] >= (1.5 * sigma_annular_deg_max):
+ num = np.round(peaks_all[a0, 4] / sigma_annular_deg_max)
+ sigma_annular_new = peaks_all[a0, 4] / num
+
+ v = np.arange(num)
+ v -= np.mean(v)
+ t_new = np.mod(
+ peaks_all[a0, 0] + 2 * v * sigma_annular_new, self._n_annular
+ )
+
+ for a1 in range(num.astype("int")):
+ peaks_new = np.vstack(
+ (
+ peaks_new,
+ np.array(
+ (
+ t_new[a1],
+ peaks_all[a0, 1],
+ peaks_all[a0, 2],
+ peaks_all[a0, 3],
+ sigma_annular_new,
+ peaks_all[a0, 5],
+ peaks_all[a0, 6],
+ )
+ ),
+ )
+ )
+ else:
+ peaks_new = np.vstack((peaks_new, peaks_all[a0, :]))
+ peaks_all = peaks_new
+
+ # Output data as a pointlist
+ peaks_polar = PointList(
+ peaks_all.ravel().view(
+ [
+ ("qt", float),
+ ("qr", float),
+ ("intensity", float),
+ ("prom_annular", float),
+ ("sigma_annular", float),
+ ("prom_radial", float),
+ ("sigma_radial", float),
+ ]
+ ),
+ name="peaks_polar",
+ )
+
+ if plot_result:
+ # init
+ im_plot = im_polar.copy()
+ im_plot = np.maximum(im_plot, 0) ** plot_power_scale
+
+ t = np.linspace(0, 2 * np.pi, 180 + 1)
+ ct = np.cos(t)
+ st = np.sin(t)
+
+ fig, ax = plt.subplots(figsize=figsize)
+
+ ax.imshow(
+ im_plot,
+ cmap="gray",
+ )
+
+ # peaks
+ ax.scatter(
+ peaks_polar["qr"],
+ peaks_polar["qt"],
+ s=peaks_polar["intensity"] * plot_scale_size,
+ marker="o",
+ color=(1, 0, 0),
+ )
+ for a0 in range(peaks_polar.data.shape[0]):
+ ax.plot(
+ peaks_polar["qr"][a0] + st * peaks_polar["sigma_radial"][a0],
+ peaks_polar["qt"][a0] + ct * peaks_polar["sigma_annular"][a0],
+ linewidth=1,
+ color="r",
+ )
+ if peaks_polar["qt"][a0] - peaks_polar["sigma_annular"][a0] < 0:
+ ax.plot(
+ peaks_polar["qr"][a0] + st * peaks_polar["sigma_radial"][a0],
+ peaks_polar["qt"][a0]
+ + ct * peaks_polar["sigma_annular"][a0]
+ + im_plot.shape[0],
+ linewidth=1,
+ color="r",
+ )
+ if (
+ peaks_polar["qt"][a0] + peaks_polar["sigma_annular"][a0]
+ > im_plot.shape[0]
+ ):
+ ax.plot(
+ peaks_polar["qr"][a0] + st * peaks_polar["sigma_radial"][a0],
+ peaks_polar["qt"][a0]
+ + ct * peaks_polar["sigma_annular"][a0]
+ - im_plot.shape[0],
+ linewidth=1,
+ color="r",
+ )
+
+ # plot appearance
+ ax.set_xlim((0, im_plot.shape[1] - 1))
+ ax.set_ylim((im_plot.shape[0] - 1, 0))
+
+ if returnfig and plot_result:
+ if return_background:
+ return peaks_polar, sig_bg, sig_bg_mask, fig, ax
+ else:
+ return peaks_polar, fig, ax
+ else:
+ if return_background:
+ return peaks_polar, sig_bg, sig_bg_mask
+ else:
+ return peaks_polar
+
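+# Example usage (a sketch; `polar` is a PolarDatacube and the parameter values
+# are hypothetical):
+#
+#     peaks = polar.find_peaks_single_pattern(
+#         5, 3,
+#         sigma_annular_deg=10,
+#         sigma_radial_px=3,
+#         plot_result=False,
+#     )
+#     print(peaks["qr"], peaks["qt"])
+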
+
+def find_peaks(
+ self,
+ mask=None,
+ bragg_peaks=None,
+ bragg_mask_radius=None,
+ sigma_annular_deg=10.0,
+ sigma_radial_px=3.0,
+ sigma_annular_deg_max=None,
+ radial_background_subtract=True,
+ radial_background_thresh=0.25,
+ num_peaks_max=100,
+ threshold_abs=1.0,
+ threshold_prom_annular=None,
+ threshold_prom_radial=None,
+ remove_masked_peaks=False,
+ scale_sigma_annular=0.5,
+ scale_sigma_radial=0.25,
+ progress_bar=True,
+):
+ """
+    Peak detection function for polar transformations. Loops through all probe
+    positions, finds peaks, and stores the peak positions and background signals.
+    The remaining arguments are as in `find_peaks_single_pattern`.
+
+    Parameters
+    --------
+    sigma_annular_deg: float
+        smoothing along the annular direction in degrees, periodic
+    sigma_radial_px: float
+        smoothing along the radial direction in pixels, not periodic
+
+    Returns
+    --------
+    Nothing is returned; detected peaks are stored at self.peaks, and the
+    background signals at self.background_radial and self.background_radial_mask.
+
+ """
+
+ # init
+ self.bragg_peaks = bragg_peaks
+ self.bragg_mask_radius = bragg_mask_radius
+ self.peaks = PointListArray(
+        dtype=[
+            ("qt", "<f4"),
+            ("qr", "<f4"),
+            ("intensity", "<f4"),
+            ("prom_annular", "<f4"),
+            ("sigma_annular", "<f4"),
+            ("prom_radial", "<f4"),
+            ("sigma_radial", "<f4"),
+        ],
+        shape=self._datacube.Rshape,
+        name="peaks_polardata",
+    )
+
+            if np.sum(mask_peak) > min_num_pixels_fit:
+ try:
+ # perform fitting
+ p0, pcov = curve_fit(
+ polar_twofold_gaussian_2D,
+ tq[:, mask_peak.ravel()],
+ im_polar[mask_peak],
+ p0=p0,
+ # bounds = bounds,
+ )
+
+ # Output parameters
+ self.peaks[rx, ry]["intensity"][a0] = p0[0]
+ self.peaks[rx, ry]["qt"][a0] = p0[1] / t_step
+ self.peaks[rx, ry]["qr"][a0] = p0[2] / q_step
+ self.peaks[rx, ry]["sigma_annular"][a0] = p0[3] / t_step
+ self.peaks[rx, ry]["sigma_radial"][a0] = p0[4] / q_step
+
+                except Exception:
+                    # if the fit fails, keep the initial parameter estimates
+                    pass
+
+ else:
+ # initial parameters
+ p0 = [
+ p["intensity"][a0],
+ p["qt"][a0] * t_step,
+ p["qr"][a0] * q_step,
+ p["sigma_annular"][a0] * t_step,
+ p["sigma_radial"][a0] * q_step,
+ 0,
+ ]
+
+ # Mask around peak for fitting
+ dt = np.mod(tt - p0[1] + np.pi / 2, np.pi) - np.pi / 2
+ mask_peak = np.logical_and(
+ mask_bool,
+ dt**2 / (fit_range_sigma_annular * p0[3]) ** 2
+ + (qq - p0[2]) ** 2 / (fit_range_sigma_radial * p0[4]) ** 2
+ <= 1,
+ )
+
+ if np.sum(mask_peak) > min_num_pixels_fit:
+ try:
+ # perform fitting
+ p0, pcov = curve_fit(
+ polar_twofold_gaussian_2D_background,
+ tq[:, mask_peak.ravel()],
+ im_polar[mask_peak],
+ p0=p0,
+ # bounds = bounds,
+ )
+
+ # Output parameters
+ self.peaks[rx, ry]["intensity"][a0] = p0[0]
+ self.peaks[rx, ry]["qt"][a0] = p0[1] / t_step
+ self.peaks[rx, ry]["qr"][a0] = p0[2] / q_step
+ self.peaks[rx, ry]["sigma_annular"][a0] = p0[3] / t_step
+ self.peaks[rx, ry]["sigma_radial"][a0] = p0[4] / q_step
+
+                except Exception:
+                    # if the fit fails, keep the initial parameter estimates
+                    pass
+
+
+def plot_radial_peaks(
+ self,
+ q_pixel_units=False,
+ qmin=None,
+ qmax=None,
+ qstep=None,
+ label_y_axis=False,
+ figsize=(8, 4),
+ returnfig=False,
+):
+ """
+ Calculate and plot the total peak signal as a function of the radial coordinate.
+
+ """
+
+ # Get all peak data
+ vects = np.concatenate(
+ [
+ self.peaks[i, j].data
+ for i in range(self._datacube.Rshape[0])
+ for j in range(self._datacube.Rshape[1])
+ ]
+ )
+ if q_pixel_units:
+ qr = vects["qr"]
+ else:
+ qr = (vects["qr"] + self.qmin) * self._radial_step
+ intensity = vects["intensity"]
+
+ # bins
+ if qmin is None:
+ qmin = self.qq[0]
+ if qmax is None:
+ qmax = self.qq[-1]
+ if qstep is None:
+ qstep = self.qq[1] - self.qq[0]
+ q_bins = np.arange(qmin, qmax, qstep)
+ q_num = q_bins.shape[0]
+ if q_pixel_units:
+ q_bins /= self._radial_step
+
+ # histogram
+ q_ind = (qr - q_bins[0]) / (q_bins[1] - q_bins[0])
+ qf = np.floor(q_ind).astype("int")
+ dq = q_ind - qf
+
+ sub = np.logical_and(qf >= 0, qf < q_num)
+ int_peaks = np.bincount(
+ np.floor(q_ind[sub]).astype("int"),
+ weights=(1 - dq[sub]) * intensity[sub],
+ minlength=q_num,
+ )
+ sub = np.logical_and(q_ind >= -1, q_ind < q_num - 1)
+ int_peaks += np.bincount(
+ np.floor(q_ind[sub] + 1).astype("int"),
+ weights=dq[sub] * intensity[sub],
+ minlength=q_num,
+ )
+
+ # plotting
+ fig, ax = plt.subplots(figsize=figsize)
+ ax.plot(
+ q_bins,
+ int_peaks,
+ color="r",
+ linewidth=2,
+ )
+ ax.set_xlim((q_bins[0], q_bins[-1]))
+ if q_pixel_units:
+ ax.set_xlabel(
+ "Scattering Angle (pixels)",
+ fontsize=14,
+ )
+ else:
+ ax.set_xlabel(
+ "Scattering Angle (" + self.calibration.get_Q_pixel_units() + ")",
+ fontsize=14,
+ )
+ ax.set_ylabel(
+ "Total Peak Signal",
+ fontsize=14,
+ )
+ if not label_y_axis:
+ ax.tick_params(left=False, labelleft=False)
+
+ if returnfig:
+ return fig, ax
+
+
+def model_radial_background(
+ self,
+ ring_position=None,
+ ring_sigma=None,
+ ring_int=None,
+ refine_model=True,
+ plot_result=True,
+ figsize=(8, 4),
+):
+ """
+    User-provided radial background model, of the form:
+
+ int = int_const
+ + int_0 * exp( - q**2 / (2*s0**2) )
+ + int_1 * exp( - (q - q_1)**2 / (2*s1**2) )
+ + ...
+ + int_n * exp( - (q - q_n)**2 / (2*sn**2) )
+
+ where n is the number of amorphous halos / rings included in the fit.
+
+ """
+
+ # Get mean radial background and mask
+ self.background_radial_mean = np.sum(
+ self.background_radial * self.background_radial_mask, axis=(0, 1)
+ )
+ background_radial_mean_norm = np.sum(self.background_radial_mask, axis=(0, 1))
+ self.background_mask = background_radial_mean_norm > (
+ np.max(background_radial_mean_norm) * 0.05
+ )
+ self.background_radial_mean[self.background_mask] /= background_radial_mean_norm[
+ self.background_mask
+ ]
+ self.background_radial_mean[np.logical_not(self.background_mask)] = 0
+
+ # init
+ if ring_position is not None:
+ ring_position = np.atleast_1d(np.array(ring_position))
+ num_rings = ring_position.shape[0]
+ else:
+ num_rings = 0
+ self.background_coefs = np.zeros(3 + 3 * num_rings)
+
+ if ring_sigma is None:
+ ring_sigma = (
+ np.atleast_1d(np.ones(num_rings))
+ * self.polar_shape[1]
+ * 0.05
+ * self._radial_step
+ )
+ else:
+ ring_sigma = np.atleast_1d(np.array(ring_sigma))
+
+ # Background model initial parameters
+ int_const = np.min(self.background_radial_mean)
+ int_0 = np.max(self.background_radial_mean) - int_const
+ sigma_0 = self.polar_shape[1] * 0.25 * self._radial_step
+ self.background_coefs[0] = int_const
+ self.background_coefs[1] = int_0
+ self.background_coefs[2] = sigma_0
+
+ # Additional Gaussians
+ if ring_int is None:
+ # Estimate peak intensities
+ sig_0 = int_const + int_0 * np.exp(self.qq**2 / (-2 * sigma_0**2))
+ sig_peaks = np.maximum(self.background_radial_mean - sig_0, 0.0)
+
+ ring_int = np.atleast_1d(np.zeros(num_rings))
+ for a0 in range(num_rings):
+ ind = np.argmin(np.abs(self.qq - ring_position[a0]))
+ ring_int[a0] = sig_peaks[ind]
+
+ else:
+ ring_int = np.atleast_1d(np.array(ring_int))
+ for a0 in range(num_rings):
+ self.background_coefs[3 * a0 + 3] = ring_int[a0]
+ self.background_coefs[3 * a0 + 4] = ring_sigma[a0]
+ self.background_coefs[3 * a0 + 5] = ring_position[a0]
+ lb = np.zeros_like(self.background_coefs)
+ ub = np.ones_like(self.background_coefs) * np.inf
+
+ # Create background model
+ def background_model(q, *coefs):
+ coefs = np.squeeze(np.array(coefs))
+ num_rings = np.round((coefs.shape[0] - 3) / 3).astype("int")
+
+ sig = np.ones(q.shape[0]) * coefs[0]
+ sig += coefs[1] * np.exp(q**2 / (-2 * coefs[2] ** 2))
+
+ for a0 in range(num_rings):
+ sig += coefs[3 * a0 + 3] * np.exp(
+ (q - coefs[3 * a0 + 5]) ** 2 / (-2 * coefs[3 * a0 + 4] ** 2)
+ )
+
+ return sig
+
+ self.background_model = background_model
+
+ # Refine background model coefficients
+ if refine_model:
+ self.background_coefs = curve_fit(
+ self.background_model,
+ self.qq[self.background_mask],
+ self.background_radial_mean[self.background_mask],
+ p0=self.background_coefs,
+ xtol=1e-12,
+ bounds=(lb, ub),
+ )[0]
+
+ # plotting
+ if plot_result:
+ self.plot_radial_background(
+ q_pixel_units=False,
+ plot_background_model=True,
+ figsize=figsize,
+ )
+
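+# Example usage (a sketch): model the background with a single amorphous ring;
+# the ring position is hypothetical and given in the units of self.qq:
+#
+#     polar.model_radial_background(ring_position=0.55, plot_result=True)
+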
+
+def refine_peaks(
+ self,
+ mask=None,
+ # reset_fits_to_init_positions = False,
+ scale_sigma_estimate=0.5,
+ min_num_pixels_fit=10,
+ maxfev=None,
+ progress_bar=True,
+):
+ """
+    Use a global fitting model for all images. Requires a background model
+    specified with self.model_radial_background().
+
+    TODO: add fitting reset (reset_fits_to_init_positions)
+        add min number pixels condition
+        track any failed fitting points, output as a boolean array
+
+    Parameters
+    --------
+    mask: np.array
+        Mask image to apply to all images
+    scale_sigma_estimate: float
+        Factor to reduce sigma of peaks by, to prevent the fit from running away.
+    min_num_pixels_fit: int
+        Minimum number of pixels required to perform fitting
+    maxfev: int
+        Maximum number of iterations in fit. Set to a low number for a fast fit.
+    progress_bar: bool
+        Enable progress bar
+
+    Returns
+    --------
+    Nothing is returned; refined peaks are stored at self.peaks_refine, and the
+    refined background coefficients at self.background_refine.
+
+ """
+
+ # coordinate scaling
+ t_step = self._annular_step
+ q_step = self._radial_step
+
+ # Background model params
+ num_rings = np.round((self.background_coefs.shape[0] - 3) / 3).astype("int")
+
+ # basis
+ qq, tt = np.meshgrid(
+ self.qq,
+ self.tt,
+ )
+ basis = np.zeros((qq.size, 3))
+ basis[:, 0] = tt.ravel()
+ basis[:, 1] = qq.ravel()
+ basis[:, 2] = num_rings
+
+ # init
+ self.peaks_refine = PointListArray(
+ dtype=[
+ ("qt", "float"),
+ ("qr", "float"),
+ ("intensity", "float"),
+ ("sigma_annular", "float"),
+ ("sigma_radial", "float"),
+ ],
+ shape=self._datacube.Rshape,
+ name="peaks_polardata_refined",
+ )
+ self.background_refine = np.zeros(
+ (
+ self._datacube.Rshape[0],
+ self._datacube.Rshape[1],
+ np.round(3 * num_rings + 3).astype("int"),
+ )
+ )
+
+ # Main loop over probe positions
+ for rx, ry in tqdmnd(
+ self._datacube.shape[0],
+ self._datacube.shape[1],
+ desc="Refining peaks ",
+ unit=" probe positions",
+ disable=not progress_bar,
+ ):
+ # Get transformed image and normalization array
+ im_polar, im_polar_norm, norm_array, mask_bool = self.transform(
+ self._datacube.data[rx, ry],
+ mask=mask,
+ returnval="all_zeros",
+ )
+ # Change sign convention of mask
+ mask_bool = np.logical_not(mask_bool)
+
+ # Get initial peaks, in dimensioned units
+ p = self.peaks[rx, ry]
+ qt = p.data["qt"] * t_step
+ qr = (p.data["qr"] + self.qmin) * q_step
+ int_peaks = p.data["intensity"]
+ s_annular = p.data["sigma_annular"] * t_step
+ s_radial = p.data["sigma_radial"] * q_step
+ num_peaks = p["qt"].shape[0]
+
+ # unified coefficients
+ # Note we sharpen sigma estimate for refinement
+ coefs_all = np.hstack(
+ (
+ self.background_coefs,
+ qt,
+ qr,
+ int_peaks,
+ s_annular * scale_sigma_estimate,
+ s_radial * scale_sigma_estimate,
+ )
+ )
+
+ # bounds
+ lb = np.zeros_like(coefs_all)
+ ub = np.ones_like(coefs_all) * np.inf
+
+ # Construct fitting model
+ def fit_image(basis, *coefs):
+ coefs = np.squeeze(np.array(coefs))
+
+ num_rings = np.round(basis[0, 2]).astype("int")
+ num_peaks = np.round((coefs.shape[0] - (3 * num_rings + 3)) / 5).astype(
+ "int"
+ )
+
+ coefs_bg = coefs[: (3 * num_rings + 3)]
+ coefs_peaks = coefs[(3 * num_rings + 3) :]
+
+ # Background
+ sig = self.background_model(basis[:, 1], coefs_bg)
+
+ # add peaks
+ for a0 in range(num_peaks):
+ dt = (
+ np.mod(
+ basis[:, 0] - coefs_peaks[num_peaks * 0 + a0] + np.pi / 2, np.pi
+ )
+ - np.pi / 2
+ )
+ dq = basis[:, 1] - coefs_peaks[num_peaks * 1 + a0]
+
+ sig += coefs_peaks[num_peaks * 2 + a0] * np.exp(
+ dt**2 / (-2 * coefs_peaks[num_peaks * 3 + a0] ** 2)
+ + dq**2 / (-2 * coefs_peaks[num_peaks * 4 + a0] ** 2)
+ )
+
+ return sig
+
+ # refine fitting model
+ try:
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ if maxfev is None:
+ coefs_all = curve_fit(
+ fit_image,
+ basis[mask_bool.ravel(), :],
+ im_polar[mask_bool],
+ p0=coefs_all,
+ xtol=1e-12,
+ bounds=(lb, ub),
+ )[0]
+ else:
+ coefs_all = curve_fit(
+ fit_image,
+ basis[mask_bool.ravel(), :],
+ im_polar[mask_bool],
+ p0=coefs_all,
+ xtol=1e-12,
+ maxfev=maxfev,
+ bounds=(lb, ub),
+ )[0]
+
+ # Output refined peak parameters
+ coefs_peaks = np.reshape(coefs_all[(3 * num_rings + 3) :], (5, num_peaks)).T
+ self.peaks_refine[rx, ry] = PointList(
+ coefs_peaks.ravel().view(
+ [
+ ("qt", float),
+ ("qr", float),
+ ("intensity", float),
+ ("sigma_annular", float),
+ ("sigma_radial", float),
+ ]
+ ),
+ name="peaks_polar",
+ )
+        except Exception:
+ # if fitting has failed, we will still output the last iteration
+ # TODO - add a flag for unconverged fits
+ coefs_peaks = np.reshape(coefs_all[(3 * num_rings + 3) :], (5, num_peaks)).T
+ self.peaks_refine[rx, ry] = PointList(
+ coefs_peaks.ravel().view(
+ [
+ ("qt", float),
+ ("qr", float),
+ ("intensity", float),
+ ("sigma_annular", float),
+ ("sigma_radial", float),
+ ]
+ ),
+ name="peaks_polar",
+ )
+
+
+ # Output refined parameters for background
+ coefs_bg = coefs_all[: (3 * num_rings + 3)]
+ self.background_refine[rx, ry] = coefs_bg
+
+
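+# Example usage (a sketch): refine all detected peaks against the global
+# background + peak model (maxfev capped here for speed):
+#
+#     polar.refine_peaks(maxfev=100)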
+
+def plot_radial_background(
+ self,
+ q_pixel_units=False,
+ label_y_axis=False,
+ plot_background_model=False,
+ figsize=(8, 4),
+ returnfig=False,
+):
+ """
+    Calculate and plot the mean radial background signal and its standard deviation.
+
+ """
+
+ # mean
+ self.background_radial_mean = np.sum(
+ self.background_radial * self.background_radial_mask, axis=(0, 1)
+ )
+ background_radial_mean_norm = np.sum(self.background_radial_mask, axis=(0, 1))
+ self.background_mask = background_radial_mean_norm > (
+ np.max(background_radial_mean_norm) * 0.05
+ )
+ self.background_radial_mean[self.background_mask] /= background_radial_mean_norm[
+ self.background_mask
+ ]
+ self.background_radial_mean[np.logical_not(self.background_mask)] = 0
+
+ # variance and standard deviation
+ self.background_radial_var = np.sum(
+ (self.background_radial - self.background_radial_mean[None, None, :]) ** 2
+ * self.background_radial_mask,
+ axis=(0, 1),
+ )
+    self.background_radial_var[self.background_mask] /= background_radial_mean_norm[
+        self.background_mask
+    ]
+ self.background_radial_var[np.logical_not(self.background_mask)] = 0
+ self.background_radial_std = np.sqrt(self.background_radial_var)
+
+    if q_pixel_units:
+        q_axis = np.arange(self.qq.shape[0])[self.background_mask]
+    else:
+        q_axis = self.qq[self.background_mask]
+
+ fig, ax = plt.subplots(figsize=figsize)
+ ax.fill_between(
+ q_axis,
+ self.background_radial_mean[self.background_mask]
+ - self.background_radial_std[self.background_mask],
+ self.background_radial_mean[self.background_mask]
+ + self.background_radial_std[self.background_mask],
+ color="r",
+ alpha=0.2,
+ )
+ ax.plot(
+ q_axis,
+ self.background_radial_mean[self.background_mask],
+ color="r",
+ linewidth=2,
+ )
+
+ # overlay fitting model
+ if plot_background_model:
+        sig = self.background_model(
+            self.qq[self.background_mask],
+            self.background_coefs,
+        )
+ ax.plot(q_axis, sig, color="k", linewidth=2, linestyle="--")
+
+ # plot appearance
+ ax.set_xlim((q_axis[0], q_axis[-1]))
+ if q_pixel_units:
+ ax.set_xlabel(
+ "Scattering Angle (pixels)",
+ fontsize=14,
+ )
+ else:
+ ax.set_xlabel(
+ "Scattering Angle (" + self.calibration.get_Q_pixel_units() + ")",
+ fontsize=14,
+ )
+ ax.set_ylabel(
+ "Background Signal",
+ fontsize=14,
+ )
+ if not label_y_axis:
+ ax.tick_params(left=False, labelleft=False)
+
+ if returnfig:
+ return fig, ax
+
+
+def make_orientation_histogram(
+ self,
+ radial_ranges: np.ndarray = None,
+ orientation_flip_sign: bool = False,
+ orientation_offset_degrees: float = 0.0,
+ orientation_separate_bins: bool = False,
+ upsample_factor: float = 4.0,
+ use_refined_peaks=True,
+ use_peak_sigma=False,
+ peak_sigma_samples=6,
+ theta_step_deg: float = None,
+ sigma_x: float = 1.0,
+ sigma_y: float = 1.0,
+ sigma_theta: float = 3.0,
+ normalize_intensity_image: bool = False,
+ normalize_intensity_stack: bool = True,
+ progress_bar: bool = True,
+):
+ """
+ Make an orientation histogram, in order to use flowline visualization of orientation maps.
+ Use peaks attached to polardatacube.
+
+ NOTE - currently assumes two fold rotation symmetry
+ TODO - add support for non two fold symmetry polardatacube
+
+ Args:
+ radial_ranges (np array): Size (N x 2) array for N radial bins, or (2,) for a single bin.
+ orientation_flip_sign (bool): Flip the direction of theta
+ orientation_offset_degrees (float): Offset for orientation angles
+ orientation_separate_bins (bool): whether to place multiple angles into multiple radial bins.
+ upsample_factor (float): Upsample factor
+        use_refined_peaks (bool): Use refined peak positions
+        use_peak_sigma (bool): Spread signal along annular direction using measured std.
+        theta_step_deg (float): Step size along annular direction in degrees
+        sigma_x (float): Smoothing in x direction before upsample
+        sigma_y (float): Smoothing in y direction before upsample
+ sigma_theta (float): Smoothing in annular direction (units of bins, periodic)
+ normalize_intensity_image (bool): Normalize to max peak intensity = 1, per image
+ normalize_intensity_stack (bool): Normalize to max peak intensity = 1, all images
+ progress_bar (bool): Enable progress bar
+
+ Returns:
+ orient_hist (array): 4D array containing Bragg peak intensity histogram
+ [radial_bin x_probe y_probe theta]
+ """
+
+ # coordinates
+ if theta_step_deg is None:
+ # Get angles from polardatacube
+ theta = self.tt
+ else:
+ theta = np.arange(0, 180, theta_step_deg) * np.pi / 180.0
+ dtheta = theta[1] - theta[0]
+ dtheta_deg = dtheta * 180 / np.pi
+ num_theta_bins = np.size(theta)
+
+ # Input bins
+ radial_ranges = np.array(radial_ranges)
+ if radial_ranges.ndim == 1:
+ radial_ranges = radial_ranges[None, :]
+ radial_ranges_2 = radial_ranges**2
+ num_radii = radial_ranges.shape[0]
+ size_input = self._datacube.shape[0:2]
+
+ # Output size
+ size_output = np.round(
+ np.array(size_input).astype("float") * upsample_factor
+ ).astype("int")
+
+ # output init
+ orient_hist = np.zeros([num_radii, size_output[0], size_output[1], num_theta_bins])
+
+ if use_peak_sigma:
+ v_sigma = np.linspace(-2, 2, 2 * peak_sigma_samples + 1)
+ w_sigma = np.exp(-(v_sigma**2) / 2)
+
+ if use_refined_peaks is False:
+ warnings.warn("Orientation histogram is using non-refined peak positions")
+
+ # Loop over all probe positions
+ for a0 in range(num_radii):
+ t = "Generating histogram " + str(a0)
+
+ for rx, ry in tqdmnd(
+ *size_input, desc=t, unit=" probe positions", disable=not progress_bar
+ ):
+ x = (rx + 0.5) * upsample_factor - 0.5
+ y = (ry + 0.5) * upsample_factor - 0.5
+ x = np.clip(x, 0, size_output[0] - 2)
+ y = np.clip(y, 0, size_output[1] - 2)
+
+ xF = np.floor(x).astype("int")
+ yF = np.floor(y).astype("int")
+ dx = x - xF
+ dy = y - yF
+
+ add_data = False
+ if use_refined_peaks:
+ q = self.peaks_refine[rx, ry]["qr"]
+ else:
+ q = (self.peaks[rx, ry]["qr"] + self.qmin) * self._radial_step
+ r2 = q**2
+ sub = np.logical_and(
+ r2 >= radial_ranges_2[a0, 0], r2 < radial_ranges_2[a0, 1]
+ )
+
+ if np.any(sub):
+ add_data = True
+ intensity = self.peaks[rx, ry]["intensity"][sub]
+
+ # Angles of all peaks
+ if use_refined_peaks:
+ theta = self.peaks_refine[rx, ry]["qt"][sub]
+ else:
+ theta = self.peaks[rx, ry]["qt"][sub] * self._annular_step
+ if orientation_flip_sign:
+ theta *= -1
+                    theta += np.deg2rad(orientation_offset_degrees)
+
+ t = theta / dtheta
+
+ # If needed, expand signal using peak sigma to write into multiple bins
+ if use_peak_sigma:
+ if use_refined_peaks:
+ theta_std = (
+ self.peaks_refine[rx, ry]["sigma_annular"][sub] / dtheta
+ )
+ else:
+ theta_std = self.peaks[rx, ry]["sigma_annular"][sub] / dtheta
+ t = (t[:, None] + theta_std[:, None] * v_sigma[None, :]).ravel()
+ intensity = (intensity[:, None] * w_sigma[None, :]).ravel()
+
+ if add_data:
+ tF = np.floor(t).astype("int")
+ dt = t - tF
+
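+                # Deposit intensity into the 8 surrounding histogram bins --
+                # 2 neighbors in x, 2 in y, and 2 (periodic) in theta --
+                # weighted by the trilinear fractions dx, dy, and dt.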
+ orient_hist[a0, xF, yF, :] = orient_hist[a0, xF, yF, :] + np.bincount(
+ np.mod(tF, num_theta_bins),
+ weights=(1 - dx) * (1 - dy) * (1 - dt) * intensity,
+ minlength=num_theta_bins,
+ )
+ orient_hist[a0, xF, yF, :] = orient_hist[a0, xF, yF, :] + np.bincount(
+ np.mod(tF + 1, num_theta_bins),
+ weights=(1 - dx) * (1 - dy) * (dt) * intensity,
+ minlength=num_theta_bins,
+ )
+
+ orient_hist[a0, xF + 1, yF, :] = orient_hist[
+ a0, xF + 1, yF, :
+ ] + np.bincount(
+ np.mod(tF, num_theta_bins),
+ weights=(dx) * (1 - dy) * (1 - dt) * intensity,
+ minlength=num_theta_bins,
+ )
+ orient_hist[a0, xF + 1, yF, :] = orient_hist[
+ a0, xF + 1, yF, :
+ ] + np.bincount(
+ np.mod(tF + 1, num_theta_bins),
+ weights=(dx) * (1 - dy) * (dt) * intensity,
+ minlength=num_theta_bins,
+ )
+
+ orient_hist[a0, xF, yF + 1, :] = orient_hist[
+ a0, xF, yF + 1, :
+ ] + np.bincount(
+ np.mod(tF, num_theta_bins),
+ weights=(1 - dx) * (dy) * (1 - dt) * intensity,
+ minlength=num_theta_bins,
+ )
+ orient_hist[a0, xF, yF + 1, :] = orient_hist[
+ a0, xF, yF + 1, :
+ ] + np.bincount(
+ np.mod(tF + 1, num_theta_bins),
+ weights=(1 - dx) * (dy) * (dt) * intensity,
+ minlength=num_theta_bins,
+ )
+
+ orient_hist[a0, xF + 1, yF + 1, :] = orient_hist[
+ a0, xF + 1, yF + 1, :
+ ] + np.bincount(
+ np.mod(tF, num_theta_bins),
+ weights=(dx) * (dy) * (1 - dt) * intensity,
+ minlength=num_theta_bins,
+ )
+ orient_hist[a0, xF + 1, yF + 1, :] = orient_hist[
+ a0, xF + 1, yF + 1, :
+ ] + np.bincount(
+ np.mod(tF + 1, num_theta_bins),
+ weights=(dx) * (dy) * (dt) * intensity,
+ minlength=num_theta_bins,
+ )
+
+ # smoothing / interpolation
+ if (sigma_x is not None) or (sigma_y is not None) or (sigma_theta is not None):
+ if num_radii > 1:
+ print("Interpolating orientation matrices ...", end="")
+ else:
+ print("Interpolating orientation matrix ...", end="")
+ if sigma_x is not None and sigma_x > 0:
+ orient_hist = gaussian_filter1d(
+ orient_hist,
+ sigma_x * upsample_factor,
+ mode="nearest",
+ axis=1,
+ truncate=3.0,
+ )
+ if sigma_y is not None and sigma_y > 0:
+ orient_hist = gaussian_filter1d(
+ orient_hist,
+ sigma_y * upsample_factor,
+ mode="nearest",
+ axis=2,
+ truncate=3.0,
+ )
+ if sigma_theta is not None and sigma_theta > 0:
+ orient_hist = gaussian_filter1d(
+ orient_hist, sigma_theta / dtheta_deg, mode="wrap", axis=3, truncate=2.0
+ )
+ print(" done.")
+
+ # normalization
+ if normalize_intensity_stack is True:
+ orient_hist = orient_hist / np.max(orient_hist)
+ elif normalize_intensity_image is True:
+ for a0 in range(num_radii):
+ orient_hist[a0, :, :, :] = orient_hist[a0, :, :, :] / np.max(
+ orient_hist[a0, :, :, :]
+ )
+
+ return orient_hist
diff --git a/py4DSTEM/process/rdf/__init__.py b/py4DSTEM/process/rdf/__init__.py
new file mode 100644
index 000000000..feff32583
--- /dev/null
+++ b/py4DSTEM/process/rdf/__init__.py
@@ -0,0 +1 @@
+from py4DSTEM.process.rdf.rdf import *
diff --git a/py4DSTEM/process/rdf/amorph.py b/py4DSTEM/process/rdf/amorph.py
new file mode 100644
index 000000000..3aaf63c45
--- /dev/null
+++ b/py4DSTEM/process/rdf/amorph.py
@@ -0,0 +1,227 @@
+import numpy as np
+import matplotlib.pyplot as plt
+# TODO: replace this star import with explicit names -- these functions have
+# moved around some, and in general specifying the imported functions is
+# better practice.
+from py4DSTEM.process.utils.elliptical_coords import *
+from py4DSTEM.process.calibration import fit_ellipse_amorphous_ring
+import matplotlib
+from tqdm import tqdm
+
+# this fixes figure sizes on HiDPI screens
+matplotlib.rcParams["figure.dpi"] = 200
+plt.ion()
+
+
+def fit_stack(datacube, init_coefs, mask=None):
+ """
+ This will fit an ellipse using the polar elliptical transform code to all the
+ diffraction patterns. It will take in a datacube and return a coefficient array which
+ can then be used to map strain, fit the centers, etc.
+
+ Args:
+        datacube: a datacube of diffraction data
+        init_coefs: an initial starting guess for the fit
+        mask: a 2D or 4D mask -- either a single mask applied to the whole
+            stack, or one mask per pattern.
+
+ Returns:
+ an array of coefficients of the fit
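+
+    Example (a sketch; assumes `dc` is a DataCube and `coef0` is an initial
+    guess obtained by fitting a single pattern):
+
+        coefs = fit_stack(dc, init_coefs=coef0)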
+ """
+ coefs_array = np.zeros([i for i in datacube.data.shape[0:2]] + [len(init_coefs)])
+ for i in tqdm(range(datacube.R_Nx)):
+ for j in tqdm(range(datacube.R_Ny)):
+            if mask is None:
+                mask_current = None
+            elif len(mask.shape) == 2:
+                mask_current = mask
+            elif len(mask.shape) == 4:
+                mask_current = mask[i, j, :, :]
+
+ coefs = fit_ellipse_amorphous_ring(
+ datacube.data[i, j, :, :], init_coefs, mask=mask_current
+ )
+ coefs_array[i, j] = coefs
+
+ return coefs_array
+
+
+def calculate_coef_strain(coef_cube, r_ref):
+ """
+ This function will calculate the strains from a 3D matrix output by fit_stack
+
+ Coefs order:
+ * I0 the intensity of the first gaussian function
+ * I1 the intensity of the Janus gaussian
+ * sigma0 std of first gaussian
+ * sigma1 inner std of Janus gaussian
+ * sigma2 outer std of Janus gaussian
+ * c_bkgd a constant offset
+ * R center of the Janus gaussian
+ * x0,y0 the origin
+    * B,C the ellipse parameters, satisfying x^2 + Bxy + Cy^2 = 1
+
+ Args:
+ coef_cube: output from fit_stack
+ r_ref: a reference 0 strain radius - needed because we fit r as well as B and C
+
+ Returns:
+ (3-tuple) A 3-tuple containing:
+
+ * **exx**: strain in the x axis direction in image coordinates
+ * **eyy**: strain in the y axis direction in image coordinates
+ * **exy**: shear
+
+ """
+ R = coef_cube[:, :, 6]
+    # r_ratio is a correction factor for what defines zero strain, and must be
+    # applied to A, B, and C. This has been found _experimentally_!
+    # TODO: have someone else read this.
+    r_ratio = R / r_ref
+
+ A = 1 / r_ratio**2
+ B = coef_cube[:, :, 9] / r_ratio**2
+ C = coef_cube[:, :, 10] / r_ratio**2
+
+ exx, eyy, exy = np.empty_like(A), np.empty_like(C), np.empty_like(B)
+
+ for i in range(A.shape[0]):
+ for j in range(A.shape[1]):
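+            # Decompose the ellipse matrix into its principal axes; the symmetric
+            # square root, rotated back into image coordinates, is the linear
+            # transformation whose deviation from the identity gives the strain.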
+ m_ellipse = np.asarray([[A[i, j], B[i, j] / 2], [B[i, j] / 2, C[i, j]]])
+ e_vals, e_vecs = np.linalg.eig(m_ellipse)
+ ang = np.arctan2(e_vecs[1, 0], e_vecs[0, 0])
+ rot_matrix = np.asarray(
+ [[np.cos(ang), -np.sin(ang)], [np.sin(ang), np.cos(ang)]]
+ )
+ transformation_matrix = np.diag(np.sqrt(e_vals))
+ transformation_matrix = rot_matrix @ transformation_matrix @ rot_matrix.T
+
+ exx[i, j] = transformation_matrix[0, 0] - 1
+ eyy[i, j] = transformation_matrix[1, 1] - 1
+ exy[i, j] = 0.5 * (
+ transformation_matrix[0, 1] + transformation_matrix[1, 0]
+ )
+
+ return exx, eyy, exy
+
+
+def plot_strains(strains, cmap="RdBu_r", vmin=None, vmax=None, mask=None):
+ """
+ This function will plot strains with a unified color scale.
+
+ Args:
+ strains (3-tuple of arrays): (exx, eyy, exy)
+ cmap, vmin, vmax: imshow parameters
+ mask: real space mask of values not to show (black)
+ """
+ cmap = plt.get_cmap(cmap)
+ if vmin is None:
+ vmin = np.min(strains)
+ if vmax is None:
+ vmax = np.max(strains)
+ if mask is None:
+        mask = np.zeros_like(strains[0])  # no mask given -- hide nothing
+ else:
+ cmap.set_under("black")
+ cmap.set_over("black")
+ cmap.set_bad("black")
+
+ mask = mask.astype(bool)
+
+ for i in strains:
+ i[mask] = np.nan
+
+    f, (ax1, ax2, ax3) = plt.subplots(1, 3, num=88, figsize=(9, 5.8), clear=True)
+ ax1.imshow(strains[0], cmap=cmap, vmin=vmin, vmax=vmax)
+ ax1.tick_params(
+ axis="both",
+ which="both",
+ bottom=False,
+ top=False,
+ left=False,
+ right=False,
+ labelbottom=False,
+ labelleft=False,
+ )
+ ax1.set_title(r"$\epsilon_{xx}$")
+
+ ax2.imshow(strains[1], cmap=cmap, vmin=vmin, vmax=vmax)
+ ax2.tick_params(
+ axis="both",
+ which="both",
+ bottom=False,
+ top=False,
+ left=False,
+ right=False,
+ labelbottom=False,
+ labelleft=False,
+ )
+ ax2.set_title(r"$\epsilon_{yy}$")
+
+ im = ax3.imshow(strains[2], cmap=cmap, vmin=vmin, vmax=vmax)
+ ax3.tick_params(
+ axis="both",
+ which="both",
+ bottom=False,
+ top=False,
+ left=False,
+ right=False,
+ labelbottom=False,
+ labelleft=False,
+ )
+ ax3.set_title(r"$\epsilon_{xy}$")
+
+ cbar_ax = f.add_axes([0.125, 0.25, 0.775, 0.05])
+ f.colorbar(im, cax=cbar_ax, orientation="horizontal")
+
+ return
+
+
+def convert_stack_polar(datacube, coef_cube):
+ """
+ This function will take the coef_cube from fit_stack and apply it to the image stack,
+ to return polar transformed images.
+
+ Args:
+ datacube: data in datacube format
+ coef_cube: coefs from fit_stack
+
+ Returns:
+ polar transformed datacube
+ """
+
+    # Not yet implemented -- returning the undefined name `datacube_polar`
+    # would raise a NameError, so fail explicitly instead.
+    raise NotImplementedError("convert_stack_polar is not yet implemented")
+
+
+def compute_polar_stack_symmetries(datacube_polar):
+ """
+ This function will take in a datacube of polar-transformed diffraction patterns, and
+ do the autocorrelation, before taking the fourier transform along the theta
+ direction, such that symmetries can be measured. They will be plotted by a different
+ function
+
+ Args:
+ datacube_polar: diffraction pattern cube that has been polar transformed
+
+ Returns:
+ the normalized fft along the theta direction of the autocorrelated patterns in
+ datacube_polar
+ """
+
+    # Not yet implemented -- returning the undefined name `datacube_symmetries`
+    # would raise a NameError, so fail explicitly instead.
+    raise NotImplementedError("compute_polar_stack_symmetries is not yet implemented")
+
+
+def plot_symmetries(datacube_symmetries, sym_order):
+ """
+ This function will take in a datacube from compute_polar_stack_symmetries and plot a
+ specific symmetry order.
+
+ Args:
+ datacube_symmetries: result of compute_polar_stack_symmetries, the stack of
+ fft'd autocorrelated diffraction patterns
+ sym_order: symmetry order desired to plot
+
+ Returns:
+ None
+ """
+
+    raise NotImplementedError("plot_symmetries is not yet implemented")
diff --git a/py4DSTEM/process/rdf/rdf.py b/py4DSTEM/process/rdf/rdf.py
new file mode 100644
index 000000000..cee7eeee9
--- /dev/null
+++ b/py4DSTEM/process/rdf/rdf.py
@@ -0,0 +1,81 @@
+# Module for extracting radial distribution functions g(r) from a series of diffraction
+# images. The process closely follows the procedure covered in:
+# Cockayne, D.J.H., Annu. Rev. Mater. Res. 37:159-87 (2007).
+
+import numpy as np
+from scipy.special import erf
+from scipy.fftpack import dst, idst
+
+from py4DSTEM.process.utils import single_atom_scatter
+
+
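+# A minimal sketch of the intended pipeline (illustrative only; `polar_img` and
+# `polar_mask` are assumed to come from a polar-elliptical transform, and `q` is
+# a 1D array of scattering vector magnitudes in inverse Angstroms):
+#
+#   radial_I = get_radial_intensity(polar_img, polar_mask)
+#   scatter = fit_scattering_factor(
+#       scale=1.0, elements=np.array([14]), composition=np.array([1.0]),
+#       q_arr=q, units="A",
+#   )
+#   phi = get_phi(radial_I, scatter, q)
+#   phi *= get_mask(left=0.5, right=8.0, midpoint=4.0, slopes=[10, 10], q_arr=q)
+#   g_r, r = get_rdf(phi, q)
+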
+def get_radial_intensity(polar_img, polar_mask):
+ """
+    Takes in a polar-transformed image and the mask (if any) applied to that image,
+    and returns the azimuthally averaged radial intensity.
+ Designed to be compatible with polar-elliptical transforms from utils
+ """
+ yMean = np.mean(polar_img, axis=0)
+ yNorm = np.mean(polar_mask, axis=0)
+ sub = yNorm > 1e-1
+ yMean[sub] = yMean[sub] / yNorm[sub]
+
+ return yMean
+
+
+def fit_scattering_factor(scale, elements, composition, q_arr, units):
+ """
+    Scale is a linear scaling factor.
+    Elements is a 1D array of atomic numbers.
+    Composition is a 1D array, the same length as elements, describing the average
+    atomic composition of the sample. q_arr is a 1D array of Fourier coordinates,
+    given in inverse Angstroms. Units is a string, either 'VA' or 'A', which returns
+    the scattering factor in volt-angstroms or in angstroms.
+ """
+
+ ##TODO: actually do fitting
+ scatter = single_atom_scatter(elements, composition, q_arr, units)
+ scatter.get_scattering_factor()
+ return scale * scatter.fe**2
+
+
+def get_phi(radialIntensity, scatter, q_arr):
+ """
+    radialIntensity is the azimuthally averaged intensity (e.g. the output of
+    get_radial_intensity); scatter is the fitted scattering factor
+    (scale * scatter.fe**2). Returns the reduced intensity
+    phi = q_arr * (radialIntensity - scatter) / scatter.
+ """
+ return ((radialIntensity - scatter) / scatter) * q_arr
+
+
+def get_mask(left, right, midpoint, slopes, q_arr):
+ """
+    left and right are floats giving the positions of the rising and falling
+    erf edges of the mask; midpoint is a float where the two edges are stitched
+    together; slopes is [float, float], the steepness of each edge.
+ """
+ vec = q_arr
+ mask_left = (erf(slopes[0] * (vec - left)) + 1) / 2
+ mask_right = (erf(slopes[1] * (right - vec)) + 1) / 2
+ mid_idx = np.max(np.where(q_arr < midpoint))
+ mask_left[mid_idx:] = 0
+ mask_right[0:mid_idx] = 0
+
+ return mask_left + mask_right
+
+
+def get_rdf(phi, q_arr):
+ """
+ phi can be masked or not masked
+ """
+ sample_freq = 1 / (
+ q_arr[1] - q_arr[0]
+ ) # this assumes regularly spaced samples in q-space
+ radius = (np.arange(q_arr.shape[0]) / q_arr.shape[0]) * sample_freq
+    radius = radius * 0.5  # factor of 1/2 from the sine transform frequency convention
+ radius += radius[
+ 1
+ ] # shift by minimum frequency, since first frequency sampled is finite
+
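+    # phi(q) -> G(r) via a type-II discrete sine transform; g(r) then follows
+    # from the relation G(r) = 4*pi*r*(g(r) - 1).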
+ G_r = dst(phi, type=2)
+ g_r = G_r / (4 * np.pi * radius) + 1
+ return g_r, radius
diff --git a/py4DSTEM/process/strain/__init__.py b/py4DSTEM/process/strain/__init__.py
new file mode 100644
index 000000000..b487c916b
--- /dev/null
+++ b/py4DSTEM/process/strain/__init__.py
@@ -0,0 +1,10 @@
+from py4DSTEM.process.strain.strain import StrainMap
+from py4DSTEM.process.strain.latticevectors import (
+ index_bragg_directions,
+ add_indices_to_braggvectors,
+ fit_lattice_vectors,
+ fit_lattice_vectors_all_DPs,
+ get_reference_g1g2,
+ get_strain_from_reference_g1g2,
+ get_rotated_strain_map,
+)
diff --git a/py4DSTEM/process/strain/latticevectors.py b/py4DSTEM/process/strain/latticevectors.py
new file mode 100644
index 000000000..dcff91709
--- /dev/null
+++ b/py4DSTEM/process/strain/latticevectors.py
@@ -0,0 +1,469 @@
+# Functions for indexing the Bragg directions
+
+import numpy as np
+from emdfile import PointList, PointListArray, tqdmnd
+from numpy.linalg import lstsq
+from py4DSTEM.data import RealSlice
+
+
+def index_bragg_directions(x0, y0, gx, gy, g1, g2):
+ """
+    From an origin (x0,y0), a set of reciprocal lattice vectors gx,gy, and a pair of
+ lattice vectors g1=(g1x,g1y), g2=(g2x,g2y), find the indices (h,k) of all the
+ reciprocal lattice directions.
+
+ The approach is to solve the matrix equation
+ ``alpha = beta * M``
+ where alpha is the 2xN array of the (x,y) coordinates of N measured bragg directions,
+ beta is the 2x2 array of the two lattice vectors u,v, and M is the 2xN array of the
+ h,k indices.
+
+ Args:
+ x0 (float): x-coord of origin
+ y0 (float): y-coord of origin
+ gx (1d array): x-coord of the reciprocal lattice vectors
+ gy (1d array): y-coord of the reciprocal lattice vectors
+ g1 (2-tuple of floats): g1x,g1y
+ g2 (2-tuple of floats): g2x,g2y
+
+ Returns:
+ (3-tuple) A 3-tuple containing:
+
+ * **h**: *(ndarray of ints)* first index of the bragg directions
+ * **k**: *(ndarray of ints)* second index of the bragg directions
+        * **bragg_directions**: *(PointList)* a 4-coordinate PointList with the
+          indexed bragg directions; coords 'qx' and 'qy' contain bragg_x and
+          bragg_y; coords 'h' and 'k' contain h and k.
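+
+    Example:
+        A minimal sketch (all values here are illustrative):
+
+            h, k, directions = index_bragg_directions(
+                x0=64.0, y0=64.0, gx=peaks_x, gy=peaks_y,
+                g1=(10.2, 0.1), g2=(-0.1, 10.3),
+            )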
+ """
+ # Get beta, the matrix of lattice vectors
+ beta = np.array([[g1[0], g2[0]], [g1[1], g2[1]]])
+
+    # Get alpha, the matrix of measured bragg directions
+ alpha = np.vstack([gx - x0, gy - y0])
+
+    # Calculate M, the matrix of (h,k) indices
+ M = lstsq(beta, alpha, rcond=None)[0].T
+ M = np.round(M).astype(int)
+
+ # Get h,k
+ h = M[:, 0]
+ k = M[:, 1]
+
+    # Store in a PointList. The PointList is seeded with a single dummy entry
+    # (a zero-dimensional array), which is removed again after the real data
+    # is added.
+ coords = [("qx", float), ("qy", float), ("h", int), ("k", int)]
+ temp_array = np.zeros([], dtype=coords)
+ bragg_directions = PointList(data=temp_array)
+ bragg_directions.add_data_by_field((gx, gy, h, k))
+ mask = np.zeros(bragg_directions["qx"].shape[0])
+ mask[0] = 1
+ bragg_directions.remove(mask)
+
+ return h, k, bragg_directions
+
+
+def add_indices_to_braggvectors(
+ braggpeaks, lattice, maxPeakSpacing, qx_shift=0, qy_shift=0, mask=None
+):
+ """
+ Using the peak positions (qx,qy) and indices (h,k) in the PointList lattice,
+ identify the indices for each peak in the PointListArray braggpeaks.
+    Return a new braggpeaks_indexed PointListArray, containing a copy of braggpeaks plus
+    two additional data columns -- 'h' and 'k' -- specifying the peak indices with the
+    ints (h,k). Peaks which cannot be indexed (i.e. which lie farther than maxPeakSpacing
+    from every lattice point) are omitted. If `mask` is specified, only the locations
+    where mask is True are indexed.
+
+ Args:
+ braggpeaks (PointListArray): the braggpeaks to index. Must contain
+ the coordinates 'qx', 'qy', and 'intensity'
+ lattice (PointList): the positions (qx,qy) of the (h,k) lattice points.
+ Must contain the coordinates 'qx', 'qy', 'h', and 'k'
+ maxPeakSpacing (float): Maximum distance from the ideal lattice points
+ to include a peak for indexing
+ qx_shift,qy_shift (number): the shift of the origin in the `lattice` PointList
+ relative to the `braggpeaks` PointListArray
+ mask (bool): Boolean mask, same shape as the pointlistarray, indicating which
+ locations should be indexed. This can be used to index different regions of
+ the scan with different lattices
+
+ Returns:
+ (PointListArray): The original braggpeaks pointlistarray, with new coordinates
+ 'h', 'k', containing the indices of each indexable peak.
+ """
+
+ # assert isinstance(braggpeaks,BraggVectors)
+ # assert isinstance(lattice, PointList)
+ # assert np.all([name in lattice.dtype.names for name in ('qx','qy','h','k')])
+
+ if mask is None:
+ mask = np.ones(braggpeaks.Rshape, dtype=bool)
+
+ assert (
+ mask.shape == braggpeaks.Rshape
+ ), "mask must have same shape as pointlistarray"
+ assert mask.dtype == bool, "mask must be boolean"
+
+ coords = [
+ ("qx", float),
+ ("qy", float),
+ ("intensity", float),
+ ("h", int),
+ ("k", int),
+ ]
+
+ indexed_braggpeaks = PointListArray(
+ dtype=coords,
+ shape=braggpeaks.Rshape,
+ )
+
+ calstate = braggpeaks.calstate
+
+ # loop over all the scan positions
+ for Rx, Ry in tqdmnd(mask.shape[0], mask.shape[1]):
+ if mask[Rx, Ry]:
+ pl = braggpeaks.get_vectors(
+ Rx,
+ Ry,
+ center=True,
+ ellipse=calstate["ellipse"],
+ rotate=calstate["rotate"],
+ pixel=False,
+ )
+ for i in range(pl.data.shape[0]):
+ r2 = (pl.data["qx"][i] - lattice.data["qx"] + qx_shift) ** 2 + (
+ pl.data["qy"][i] - lattice.data["qy"] + qy_shift
+ ) ** 2
+ ind = np.argmin(r2)
+ if r2[ind] <= maxPeakSpacing**2:
+ indexed_braggpeaks[Rx, Ry].add_data_by_field(
+ (
+ pl.data["qx"][i],
+ pl.data["qy"][i],
+ pl.data["intensity"][i],
+ lattice.data["h"][ind],
+ lattice.data["k"][ind],
+ )
+ )
+
+ return indexed_braggpeaks
+
+
+def fit_lattice_vectors(braggpeaks, x0=0, y0=0, minNumPeaks=5):
+ """
+ Fits lattice vectors g1,g2 to braggpeaks given some known (h,k) indexing.
+
+ Args:
+ braggpeaks (PointList): A 6 coordinate PointList containing the data to fit.
+ Coords are 'qx','qy' (the bragg peak positions), 'intensity' (used as a
+ weighting factor when fitting), 'h','k' (indexing). May optionally also
+ contain 'index_mask' (bool), indicating which peaks have been successfully
+            indexed and should be used.
+ x0 (float): x-coord of the origin
+ y0 (float): y-coord of the origin
+ minNumPeaks (int): if there are fewer than minNumPeaks peaks found in braggpeaks
+ which can be indexed, return None for all return parameters
+
+ Returns:
+ (7-tuple) A 7-tuple containing:
+
+ * **x0**: *(float)* the x-coord of the origin of the best-fit lattice.
+ * **y0**: *(float)* the y-coord of the origin
+ * **g1x**: *(float)* x-coord of the first lattice vector
+ * **g1y**: *(float)* y-coord of the first lattice vector
+ * **g2x**: *(float)* x-coord of the second lattice vector
+ * **g2y**: *(float)* y-coord of the second lattice vector
+ * **error**: *(float)* the fit error
+ """
+ assert isinstance(braggpeaks, PointList)
+ assert np.all(
+ [name in braggpeaks.dtype.names for name in ("qx", "qy", "intensity", "h", "k")]
+ )
+ braggpeaks = braggpeaks.copy()
+
+ # Remove unindexed peaks
+ if "index_mask" in braggpeaks.dtype.names:
+ deletemask = braggpeaks.data["index_mask"] == False # noqa:E712
+ braggpeaks.remove(deletemask)
+
+ # Check to ensure enough peaks are present
+ if braggpeaks.length < minNumPeaks:
+ return None, None, None, None, None, None, None
+
+ # Get M, the matrix of (h,k) indices
+ h, k = braggpeaks.data["h"], braggpeaks.data["k"]
+ M = np.vstack((np.ones_like(h, dtype=int), h, k)).T
+
+ # Get alpha, the matrix of measured Bragg peak positions
+ alpha = np.vstack((braggpeaks.data["qx"] - x0, braggpeaks.data["qy"] - y0)).T
+
+ # Get weighted matrices
+ weights = braggpeaks.data["intensity"]
+ weighted_M = M * weights[:, np.newaxis]
+ weighted_alpha = alpha * weights[:, np.newaxis]
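+    # Scaling both sides of M @ beta = alpha by the peak intensities makes the
+    # least-squares solution favor bright, well-localized peaks.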
+
+ # Solve for lattice vectors
+ beta = lstsq(weighted_M, weighted_alpha, rcond=None)[0]
+ x0, y0 = beta[0, 0], beta[0, 1]
+ g1x, g1y = beta[1, 0], beta[1, 1]
+ g2x, g2y = beta[2, 0], beta[2, 1]
+
+ # Calculate the error
+ alpha_calculated = np.matmul(M, beta)
+ error = np.sqrt(np.sum((alpha - alpha_calculated) ** 2, axis=1))
+ error = np.sum(error * weights) / np.sum(weights)
+
+ return x0, y0, g1x, g1y, g2x, g2y, error
+
+
+def fit_lattice_vectors_all_DPs(braggpeaks, x0=0, y0=0, minNumPeaks=5):
+ """
+ Fits lattice vectors g1,g2 to each diffraction pattern in braggpeaks, given some
+ known (h,k) indexing.
+
+ Args:
+ braggpeaks (PointList): A 6 coordinate PointList containing the data to fit.
+ Coords are 'qx','qy' (the bragg peak positions), 'intensity' (used as a
+ weighting factor when fitting), 'h','k' (indexing). May optionally also
+ contain 'index_mask' (bool), indicating which peaks have been successfully
+            indexed and should be used.
+ x0 (float): x-coord of the origin
+ y0 (float): y-coord of the origin
+ minNumPeaks (int): if there are fewer than minNumPeaks peaks found in braggpeaks
+ which can be indexed, return None for all return parameters
+
+ Returns:
+ (RealSlice): A RealSlice ``g1g2map`` containing the following 8 arrays:
+
+ * ``g1g2_map.get_slice('x0')`` x-coord of the origin of the best fit lattice
+ * ``g1g2_map.get_slice('y0')`` y-coord of the origin
+ * ``g1g2_map.get_slice('g1x')`` x-coord of the first lattice vector
+ * ``g1g2_map.get_slice('g1y')`` y-coord of the first lattice vector
+ * ``g1g2_map.get_slice('g2x')`` x-coord of the second lattice vector
+ * ``g1g2_map.get_slice('g2y')`` y-coord of the second lattice vector
+ * ``g1g2_map.get_slice('error')`` the fit error
+ * ``g1g2_map.get_slice('mask')`` 1 for successful fits, 0 for unsuccessful
+ fits
+ """
+ assert isinstance(braggpeaks, PointListArray)
+ assert np.all(
+ [name in braggpeaks.dtype.names for name in ("qx", "qy", "intensity", "h", "k")]
+ )
+
+ # Make RealSlice to contain outputs
+ slicelabels = ("x0", "y0", "g1x", "g1y", "g2x", "g2y", "error", "mask")
+ g1g2_map = RealSlice(
+ data=np.zeros((8, braggpeaks.shape[0], braggpeaks.shape[1])),
+ slicelabels=slicelabels,
+ name="g1g2_map",
+ )
+
+ # Fit lattice vectors
+ for Rx, Ry in tqdmnd(
+ braggpeaks.shape[0],
+ braggpeaks.shape[1],
+ desc="Fitting lattice vectors",
+ unit="DP",
+ unit_scale=True,
+ ):
+ braggpeaks_curr = braggpeaks.get_pointlist(Rx, Ry)
+ qx0, qy0, g1x, g1y, g2x, g2y, error = fit_lattice_vectors(
+ braggpeaks_curr, x0, y0, minNumPeaks
+ )
+ # Store data
+ if g1x is not None:
+ g1g2_map.get_slice("x0").data[Rx, Ry] = qx0
+ g1g2_map.get_slice("y0").data[Rx, Ry] = qx0
+ g1g2_map.get_slice("g1x").data[Rx, Ry] = g1x
+ g1g2_map.get_slice("g1y").data[Rx, Ry] = g1y
+ g1g2_map.get_slice("g2x").data[Rx, Ry] = g2x
+ g1g2_map.get_slice("g2y").data[Rx, Ry] = g2y
+ g1g2_map.get_slice("error").data[Rx, Ry] = error
+ g1g2_map.get_slice("mask").data[Rx, Ry] = 1
+
+ return g1g2_map
+
+
+def get_reference_g1g2(g1g2_map, mask):
+ """
+ Gets a pair of reference lattice vectors from a region of real space specified by
+ mask. Takes the median of the lattice vectors in g1g2_map within the specified
+ region.
+
+ Args:
+ g1g2_map (RealSlice): the lattice vector map; contains 2D arrays in g1g2_map.data
+ under the keys 'g1x', 'g1y', 'g2x', and 'g2y'. See documentation for
+ fit_lattice_vectors_all_DPs() for more information.
+ mask (ndarray of bools): use lattice vectors from g1g2_map scan positions wherever
+ mask==True
+
+ Returns:
+ (2-tuple of 2-tuples) A 2-tuple containing:
+
+ * **g1**: *(2-tuple)* first reference lattice vector (x,y)
+ * **g2**: *(2-tuple)* second reference lattice vector (x,y)
+ """
+ assert isinstance(g1g2_map, RealSlice)
+ assert np.all(
+ [name in g1g2_map.slicelabels for name in ("g1x", "g1y", "g2x", "g2y")]
+ )
+ assert mask.dtype == bool
+ g1x = np.median(g1g2_map.get_slice("g1x").data[mask])
+ g1y = np.median(g1g2_map.get_slice("g1y").data[mask])
+ g2x = np.median(g1g2_map.get_slice("g2x").data[mask])
+ g2y = np.median(g1g2_map.get_slice("g2y").data[mask])
+ return (g1x, g1y), (g2x, g2y)
+
+
+def get_strain_from_reference_g1g2(g1g2_map, g1, g2):
+ """
+ Gets a strain map from the reference lattice vectors g1,g2 and lattice vector map
+ g1g2_map.
+
+ Note that this function will return the strain map oriented with respect to the x/y
+ axes of diffraction space - to rotate the coordinate system, use
+ get_rotated_strain_map(). Calibration of the rotational misalignment between real and
+ diffraction space may also be necessary.
+
+ Args:
+ g1g2_map (RealSlice): the lattice vector map; contains 2D arrays in g1g2_map.data
+ under the keys 'g1x', 'g1y', 'g2x', and 'g2y'. See documentation for
+ fit_lattice_vectors_all_DPs() for more information.
+ g1 (2-tuple): first reference lattice vector (x,y)
+ g2 (2-tuple): second reference lattice vector (x,y)
+
+ Returns:
+        (RealSlice) the strain map; contains the elements of the infinitesimal strain
+ matrix, in the following 5 arrays:
+
+ * ``strain_map.get_slice('e_xx')``: change in lattice x-components with respect
+ to x
+ * ``strain_map.get_slice('e_yy')``: change in lattice y-components with respect
+ to y
+ * ``strain_map.get_slice('e_xy')``: change in lattice x-components with respect
+ to y
+ * ``strain_map.get_slice('theta')``: rotation of lattice with respect to
+ reference
+ * ``strain_map.get_slice('mask')``: 0/False indicates unknown values
+
+        Note: the strain matrix has been symmetrized, so e_xy and e_yx are identical
+ """
+ assert isinstance(g1g2_map, RealSlice)
+ assert np.all(
+ [name in g1g2_map.slicelabels for name in ("g1x", "g1y", "g2x", "g2y", "mask")]
+ )
+
+ # Get RealSlice for output storage
+ R_Nx, R_Ny = g1g2_map.get_slice("g1x").shape
+ strain_map = RealSlice(
+ data=np.zeros((5, R_Nx, R_Ny)),
+ slicelabels=("e_xx", "e_yy", "e_xy", "theta", "mask"),
+ name="strain_map",
+ )
+
+ # Get reference lattice matrix
+ g1x, g1y = g1
+ g2x, g2y = g2
+ M = np.array([[g1x, g1y], [g2x, g2y]])
+
+ for Rx, Ry in tqdmnd(
+ R_Nx,
+ R_Ny,
+ desc="Calculating strain",
+ unit="DP",
+ unit_scale=True,
+ ):
+ # Get lattice vectors for DP at Rx,Ry
+ alpha = np.array(
+ [
+ [
+ g1g2_map.get_slice("g1x").data[Rx, Ry],
+ g1g2_map.get_slice("g1y").data[Rx, Ry],
+ ],
+ [
+ g1g2_map.get_slice("g2x").data[Rx, Ry],
+ g1g2_map.get_slice("g2y").data[Rx, Ry],
+ ],
+ ]
+ )
+ # Get transformation matrix
+ beta = lstsq(M, alpha, rcond=None)[0].T
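+        # beta is the linear map taking the reference lattice onto the measured
+        # lattice; its deviation from the identity encodes strain and rotation.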
+
+ # Get the infinitesimal strain matrix
+ strain_map.get_slice("e_xx").data[Rx, Ry] = 1 - beta[0, 0]
+ strain_map.get_slice("e_yy").data[Rx, Ry] = 1 - beta[1, 1]
+ strain_map.get_slice("e_xy").data[Rx, Ry] = -(beta[0, 1] + beta[1, 0]) / 2.0
+ strain_map.get_slice("theta").data[Rx, Ry] = (beta[0, 1] - beta[1, 0]) / 2.0
+ strain_map.get_slice("mask").data[Rx, Ry] = g1g2_map.get_slice("mask").data[
+ Rx, Ry
+ ]
+ return strain_map
+
+
+def get_rotated_strain_map(unrotated_strain_map, xaxis_x, xaxis_y, flip_theta):
+ """
+ Starting from a strain map defined with respect to the xy coordinate system of
+ diffraction space, i.e. where exx and eyy are the compression/tension along the Qx
+ and Qy directions, respectively, get a strain map defined with respect to some other
+ right-handed coordinate system, in which the x-axis is oriented along (xaxis_x,
+ xaxis_y).
+
+ Args:
+ xaxis_x,xaxis_y (float): diffraction space (x,y) coordinates of a vector
+ along the new x-axis
+ unrotated_strain_map (RealSlice): a RealSlice object containing 2D arrays of the
+            infinitesimal strain matrix elements, stored at
+ * unrotated_strain_map.get_slice('e_xx')
+ * unrotated_strain_map.get_slice('e_xy')
+ * unrotated_strain_map.get_slice('e_yy')
+ * unrotated_strain_map.get_slice('theta')
+
+ Returns:
+ (RealSlice) the rotated counterpart to unrotated_strain_map, with the
+ rotated_strain_map.get_slice('e_xx') element oriented along the new coordinate
+ system
+ """
+ assert isinstance(unrotated_strain_map, RealSlice)
+ assert np.all(
+ [
+ key in ["e_xx", "e_xy", "e_yy", "theta", "mask"]
+ for key in unrotated_strain_map.slicelabels
+ ]
+ )
+ theta = -np.arctan2(xaxis_y, xaxis_x)
+ cost = np.cos(theta)
+ sint = np.sin(theta)
+ cost2 = cost**2
+ sint2 = sint**2
+
+ Rx, Ry = unrotated_strain_map.get_slice("e_xx").data.shape
+ rotated_strain_map = RealSlice(
+ data=np.zeros((5, Rx, Ry)),
+ slicelabels=["e_xx", "e_xy", "e_yy", "theta", "mask"],
+        name=unrotated_strain_map.name + "_rotated",
+ )
+
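+    # Rotate the strain tensor into the new coordinates, e' = R e R^T; with
+    # c = cos(theta) and s = sin(theta) this gives
+    #   e'_xx = c^2 e_xx - 2cs e_xy + s^2 e_yy
+    #   e'_xy = cs (e_xx - e_yy) + (c^2 - s^2) e_xy
+    #   e'_yy = s^2 e_xx + 2cs e_xy + c^2 e_yy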
+ rotated_strain_map.data[0, :, :] = (
+ cost2 * unrotated_strain_map.get_slice("e_xx").data
+ - 2 * cost * sint * unrotated_strain_map.get_slice("e_xy").data
+ + sint2 * unrotated_strain_map.get_slice("e_yy").data
+ )
+ rotated_strain_map.data[1, :, :] = (
+ cost
+ * sint
+ * (
+ unrotated_strain_map.get_slice("e_xx").data
+ - unrotated_strain_map.get_slice("e_yy").data
+ )
+ + (cost2 - sint2) * unrotated_strain_map.get_slice("e_xy").data
+ )
+ rotated_strain_map.data[2, :, :] = (
+ sint2 * unrotated_strain_map.get_slice("e_xx").data
+ + 2 * cost * sint * unrotated_strain_map.get_slice("e_xy").data
+ + cost2 * unrotated_strain_map.get_slice("e_yy").data
+ )
+ if flip_theta is True:
+ rotated_strain_map.data[3, :, :] = -unrotated_strain_map.get_slice("theta").data
+ else:
+ rotated_strain_map.data[3, :, :] = unrotated_strain_map.get_slice("theta").data
+ rotated_strain_map.data[4, :, :] = unrotated_strain_map.get_slice("mask").data
+ return rotated_strain_map
diff --git a/py4DSTEM/process/strain/strain.py b/py4DSTEM/process/strain/strain.py
new file mode 100644
index 000000000..099ecdefd
--- /dev/null
+++ b/py4DSTEM/process/strain/strain.py
@@ -0,0 +1,1582 @@
+# Defines the Strain class
+
+import warnings
+from typing import Optional
+
+import matplotlib.pyplot as plt
+from matplotlib.patches import Circle
+from matplotlib.collections import PatchCollection
+from mpl_toolkits.axes_grid1 import make_axes_locatable
+import numpy as np
+from py4DSTEM import PointList, PointListArray, tqdmnd
+from py4DSTEM.braggvectors import BraggVectors
+from py4DSTEM.data import Data, RealSlice
+from py4DSTEM.preprocess.utils import get_maxima_2D
+from py4DSTEM.process.strain.latticevectors import (
+ add_indices_to_braggvectors,
+ fit_lattice_vectors_all_DPs,
+ get_reference_g1g2,
+ get_rotated_strain_map,
+ get_strain_from_reference_g1g2,
+ index_bragg_directions,
+)
+from py4DSTEM.visualize import (
+ show,
+ add_bragg_index_labels,
+ add_pointlabels,
+ add_vector,
+ ax_addaxes,
+ ax_addaxes_QtoR,
+)
+
+
+class StrainMap(RealSlice, Data):
+ """
+    Storage and processing methods for 4D-STEM strain maps.
+
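+    A typical workflow (a sketch; parameter values are illustrative):
+
+        strainmap = StrainMap(braggvectors)
+        strainmap.choose_basis_vectors()
+        strainmap.set_max_peak_spacing(max_peak_spacing=0.5)
+        strainmap.fit_basis_vectors()
+        strainmap.get_strain()
+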
+ """
+
+ def __init__(self, braggvectors: BraggVectors, name: Optional[str] = "strainmap"):
+ """
+ Parameters
+ ----------
+ braggvectors : BraggVectors
+ The Bragg vectors
+ name : str
+ The name of the strainmap
+
+ Returns
+ -------
+ A new StrainMap instance.
+ """
+ assert isinstance(
+ braggvectors, BraggVectors
+ ), f"braggvectors must be BraggVectors, not type {type(braggvectors)}"
+
+ # initialize as a RealSlice
+ RealSlice.__init__(
+ self,
+ name=name,
+ data=np.empty(
+ (
+ 6,
+ braggvectors.Rshape[0],
+ braggvectors.Rshape[1],
+ )
+ ),
+ slicelabels=["exx", "eyy", "exy", "theta", "mask", "error"],
+ )
+
+ # set up braggvectors
+ # this assigns the bvs, ensures the origin is calibrated,
+ # and adds the strainmap to the bvs' tree
+ self.braggvectors = braggvectors
+
+ # initialize as Data
+ Data.__init__(self)
+
+ # set calstate
+ # this property is used only to check to make sure that
+ # the braggvectors being used throughout a workflow are
+ # the same. The state of calibration of the vectors is noted
+ # here, and then checked each time the vectors are used -
+ # if they differ, an error message and instructions for
+ # re-calibration are issued
+ self.calstate = self.braggvectors.calstate
+ assert self.calstate["center"], "braggvectors must be centered"
+ if self.calstate["rotate"] is False:
+ warnings.warn(
+ ("Real to reciprocal space rotation not calibrated"),
+ UserWarning,
+ )
+
+ # get the BVM
+ # a new BVM using the current calstate is computed
+ self.bvm = self.braggvectors.histogram(mode="cal")
+
+ # braggvector properties
+
+ @property
+ def braggvectors(self):
+ return self._braggvectors
+
+ @braggvectors.setter
+ def braggvectors(self, x):
+ assert isinstance(
+ x, BraggVectors
+ ), f".braggvectors must be BraggVectors, not type {type(x)}"
+ assert (
+ x.calibration.origin is not None
+ ), "braggvectors must have a calibrated origin"
+ self._braggvectors = x
+ self._braggvectors.tree(self, force=True)
+
+ @property
+ def rshape(self):
+ return self._braggvectors.Rshape
+
+ @property
+ def qshape(self):
+ return self._braggvectors.Qshape
+
+ @property
+ def origin(self):
+ return self.calibration.get_origin_mean()
+
+ def reset_calstate(self):
+ """
+ Resets the calibration state. This recomputes the BVM, and removes any computations
+ this StrainMap instance has stored, which will need to be recomputed.
+ """
+ for attr in (
+ "g0",
+ "g1",
+ "g2",
+ ):
+ if hasattr(self, attr):
+ delattr(self, attr)
+ self.calstate = self.braggvectors.calstate
+
+ # Class methods
+
+ def choose_basis_vectors(
+ self,
+ index_g1=None,
+ index_g2=None,
+ index_origin=None,
+ subpixel="multicorr",
+ upsample_factor=16,
+ sigma=0,
+ minAbsoluteIntensity=0,
+ minRelativeIntensity=0,
+ relativeToPeak=0,
+ minSpacing=0,
+ edgeBoundary=1,
+ maxNumPeaks=10,
+ x0=None,
+ y0=None,
+ figsize=(14, 9),
+ c_indices="lightblue",
+ c0="g",
+ c1="r",
+ c2="r",
+ c_vectors="r",
+ c_vectorlabels="w",
+ size_indices=15,
+ width_vectors=1,
+ size_vectorlabels=15,
+ vis_params={},
+ returncalc=False,
+ returnfig=False,
+ ):
+ """
+ Choose basis lattice vectors g1 and g2 for strain mapping.
+
+ Overlays the bvm with the points detected via local 2D
+ maxima detection, plus an index for each point. Three points
+ are selected which correspond to the origin, and the basis
+ reciprocal lattice vectors g1 and g2. By default these are
+ automatically located; the user can override and select these
+ manually using the `index_*` arguments.
+
+ Parameters
+ ----------
+ index_g1 : int
+ selected index for g1
+ index_g2 :int
+ selected index for g2
+ index_origin : int
+ selected index for the origin
+ subpixel : str in ('pixel','poly','multicorr')
+ See the docstring for py4DSTEM.preprocess.get_maxima_2D
+ upsample_factor : int
+ See the py4DSTEM.preprocess.get_maxima_2D docstring
+ sigma : number
+ See the py4DSTEM.preprocess.get_maxima_2D docstring
+ minAbsoluteIntensity : number
+ See the py4DSTEM.preprocess.get_maxima_2D docstring
+ minRelativeIntensity : number
+ See the py4DSTEM.preprocess.get_maxima_2D docstring
+ relativeToPeak : int
+ See the py4DSTEM.preprocess.get_maxima_2D docstring
+ minSpacing : number
+ See the py4DSTEM.preprocess.get_maxima_2D docstring
+ edgeBoundary : number
+ See the py4DSTEM.preprocess.get_maxima_2D docstring
+ maxNumPeaks : int
+ See the py4DSTEM.preprocess.get_maxima_2D docstring
+ figsize : 2-tuple
+ the size of the figure
+ c_indices : color
+ color of the maxima
+ c0 : color
+ color of the origin
+ c1 : color
+ color of g1 point
+ c2 : color
+ color of g2 point
+ c_vectors : color
+ color of the g1/g2 vectors
+ c_vectorlabels : color
+ color of the vector labels
+ size_indices : number
+ size of the indices
+ width_vectors : number
+ width of the vectors
+ size_vectorlabels : number
+ size of the vector labels
+ vis_params : dict
+ additional visualization parameters passed to `show`
+ returncalc : bool
+ toggles returning the answer
+ returnfig : bool
+ toggles returning the figure
+
+ Returns
+ -------
+ (optional) : None or (g0,g1,g2) or (fig,(ax1,ax2)) or the latter two
+ """
+ # validate inputs
+ for i in (index_origin, index_g1, index_g2):
+ assert isinstance(i, (int, np.integer)) or (
+ i is None
+ ), "indices must be integers!"
+ # check the calstate
+ assert (
+ self.calstate == self.braggvectors.calstate
+ ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`."
+
+ # find the maxima
+
+ g = get_maxima_2D(
+ self.bvm.data,
+ subpixel=subpixel,
+ upsample_factor=upsample_factor,
+ sigma=sigma,
+ minAbsoluteIntensity=minAbsoluteIntensity,
+ minRelativeIntensity=minRelativeIntensity,
+ relativeToPeak=relativeToPeak,
+ minSpacing=minSpacing,
+ edgeBoundary=edgeBoundary,
+ maxNumPeaks=maxNumPeaks,
+ )
+
+ # guess the origin and g1 g2 vectors if indices aren't provided
+ if np.any([x is None for x in (index_g1, index_g2, index_origin)]):
+ # get distances and angles from calibrated origin
+ g_dists = np.hypot(g["x"] - self.origin[0], g["y"] - self.origin[1])
+ g_angles = np.angle(
+ g["x"] - self.origin[0] + 1j * (g["y"] - self.origin[1])
+ )
+
+ # guess the origin
+ if index_origin is None:
+ index_origin = np.argmin(g_dists)
+ g_dists[index_origin] = 2 * np.max(g_dists)
+
+ # guess g1
+ if index_g1 is None:
+ index_g1 = np.argmin(g_dists)
+ g_dists[index_g1] = 2 * np.max(g_dists)
+
+ # guess g2
+ if index_g2 is None:
+ angle_scaling = np.cos(g_angles - g_angles[index_g1]) ** 2
+ index_g2 = np.argmin(g_dists * (angle_scaling + 0.1))
+
+ # get the lattice vectors
+ gx, gy = g["x"], g["y"]
+ g0 = gx[index_origin], gy[index_origin]
+ g1x = gx[index_g1] - g0[0]
+ g1y = gy[index_g1] - g0[1]
+ g2x = gx[index_g2] - g0[0]
+ g2y = gy[index_g2] - g0[1]
+ g1, g2 = (g1x, g1y), (g2x, g2y)
+
+ # index the lattice vectors
+ _, _, braggdirections = index_bragg_directions(
+ g0[0], g0[1], g["x"], g["y"], g1, g2
+ )
+
+ # make the figure
+ fig, ax = plt.subplots(1, 3, figsize=figsize)
+ show(self.bvm.data, figax=(fig, ax[0]), **vis_params)
+ show(self.bvm.data, figax=(fig, ax[1]), **vis_params)
+ self.show_bragg_indexing(
+ self.bvm.data,
+ bragg_directions=braggdirections,
+ points=True,
+ figax=(fig, ax[2]),
+ size=size_indices,
+ **vis_params,
+ )
+
+ # Add indices to left panel
+ d = {"x": gx, "y": gy, "size": size_indices, "color": c_indices}
+ d0 = {
+ "x": gx[index_origin],
+ "y": gy[index_origin],
+ "size": size_indices,
+ "color": c0,
+ "fontweight": "bold",
+ "labels": [str(index_origin)],
+ }
+ d1 = {
+ "x": gx[index_g1],
+ "y": gy[index_g1],
+ "size": size_indices,
+ "color": c1,
+ "fontweight": "bold",
+ "labels": [str(index_g1)],
+ }
+ d2 = {
+ "x": gx[index_g2],
+ "y": gy[index_g2],
+ "size": size_indices,
+ "color": c2,
+ "fontweight": "bold",
+ "labels": [str(index_g2)],
+ }
+ add_pointlabels(ax[0], d)
+ add_pointlabels(ax[0], d0)
+ add_pointlabels(ax[0], d1)
+ add_pointlabels(ax[0], d2)
+
+ # Add vectors to right panel
+ dg1 = {
+ "x0": gx[index_origin],
+ "y0": gy[index_origin],
+ "vx": g1[0],
+ "vy": g1[1],
+ "width": width_vectors,
+ "color": c_vectors,
+ "label": r"$g_1$",
+ "labelsize": size_vectorlabels,
+ "labelcolor": c_vectorlabels,
+ }
+ dg2 = {
+ "x0": gx[index_origin],
+ "y0": gy[index_origin],
+ "vx": g2[0],
+ "vy": g2[1],
+ "width": width_vectors,
+ "color": c_vectors,
+ "label": r"$g_2$",
+ "labelsize": size_vectorlabels,
+ "labelcolor": c_vectorlabels,
+ }
+ add_vector(ax[1], dg1)
+ add_vector(ax[1], dg2)
+
+ # store vectors
+ self.g = g
+ self.g0 = g0
+ self.g1 = g1
+ self.g2 = g2
+
+ # center the bragg directions and store
+ braggdirections.data["qx"] -= self.origin[0]
+ braggdirections.data["qy"] -= self.origin[1]
+ self.braggdirections = braggdirections
+
+ # return
+ if returncalc and returnfig:
+ return (self.g0, self.g1, self.g2, self.braggdirections), (fig, ax)
+ elif returncalc:
+ return (self.g0, self.g1, self.g2, self.braggdirections)
+ elif returnfig:
+ return (fig, ax)
+ else:
+ return
+
+ def set_max_peak_spacing(
+ self,
+ max_peak_spacing,
+ returnfig=False,
+ **vis_params,
+ ):
+ """
+ Set the size of the regions of diffraction space in which detected Bragg
+ peaks will be indexed and included in subsequent fitting of basis
+ vectors, and visualize those regions.
+
+ Parameters
+ ----------
+ max_peak_spacing : number
+ The maximum allowable distance in pixels between a detected Bragg peak and
+ the indexed maxima found in `choose_basis_vectors` for the detected
+ peak to be indexed
+ returnfig : bool
+ Toggles returning the figure
+ vis_params : dict
+ Any additional arguments are passed to the `show` function when
+ visualization the BVM
+ """
+ # set the max peak spacing
+ self.max_peak_spacing = max_peak_spacing
+
+ # make the figure
+ fig, ax = show(
+ self.bvm.data,
+ returnfig=True,
+ **vis_params,
+ )
+
+ # make the circle patch collection
+ patches = []
+ qx = self.braggdirections["qx"]
+ qy = self.braggdirections["qy"]
+ origin = self.origin
+ for idx in range(len(qx)):
+ c = Circle(
+ xy=(qy[idx] + origin[1], qx[idx] + origin[0]),
+ radius=self.max_peak_spacing,
+ edgecolor="r",
+ fill=False,
+ )
+ patches.append(c)
+ pc = PatchCollection(patches, match_original=True)
+
+ # draw the circles
+ ax.add_collection(pc)
+
+ # return
+ if returnfig:
+ return fig, ax
+ else:
+ plt.show()
+
+ def fit_basis_vectors(
+ self, mask=None, max_peak_spacing=None, vis_params={}, returncalc=False
+ ):
+ """
+ Fit the basis lattice vectors to the detected Bragg peaks at each
+ scan position.
+
+ First, the lattice vectors at each scan position are indexed using the
+ basis vectors g1 and g2 specified previously with `choose_basis_vectors`
+ Detected Bragg peaks which are farther from the set of lattice vectors
+        found in `choose_basis_vectors` than the maximum peak spacing are
+ ignored; the maximum peak spacing can be set previously by calling
+ `set_max_peak_spacing` or by specifying the `max_peak_spacing` argument
+ here. A fit is then performed to refine the values of g1 and g2 at each
+ scan position, fitting the basis vectors to all detected and indexed
+ peaks, weighting the peaks according to their intensity.
+
+ Parameters
+ ----------
+ mask : 2d boolean array
+ A real space shaped Boolean mask indicating scan positions at which
+ to fit the lattice vectors.
+ max_peak_spacing : float
+ Maximum distance from the ideal lattice points to include a peak
+ for indexing
+ vis_params : dict
+ Visualization parameters for showing the max peak spacing; ignored
+ if `max_peak_spacing` is not set
+ returncalc : bool
+ if True, returns bragg_directions, bragg_vectors_indexed, g1g2_map
+ """
+ # check the calstate
+ assert (
+ self.calstate == self.braggvectors.calstate
+ ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`."
+
+ # handle the max peak spacing
+ if max_peak_spacing is not None:
+ self.set_max_peak_spacing(max_peak_spacing, **vis_params)
+ assert hasattr(self, "max_peak_spacing"), "Set the maximum peak spacing!"
+
+ # index the bragg vectors
+
+ # handle the mask
+ if mask is None:
+ mask = np.ones(self.braggvectors.Rshape, dtype=bool)
+ assert (
+ mask.shape == self.braggvectors.Rshape
+ ), "mask must have same shape as pointlistarray"
+ assert mask.dtype == bool, "mask must be boolean"
+ self.mask = mask
+
+ # set up new braggpeaks PLA
+ indexed_braggpeaks = PointListArray(
+ dtype=[
+ ("qx", float),
+ ("qy", float),
+ ("intensity", float),
+ ("h", int),
+ ("k", int),
+ ],
+ shape=self.braggvectors.Rshape,
+ )
+
+ # loop over all the scan positions
+ # and perform indexing, excluding peaks outside of max_peak_spacing
+ calstate = self.braggvectors.calstate
+ for Rx, Ry in tqdmnd(
+ mask.shape[0],
+ mask.shape[1],
+ desc="Indexing Bragg scattering",
+ unit="DP",
+ unit_scale=True,
+ ):
+ if mask[Rx, Ry]:
+ pl = self.braggvectors.get_vectors(
+ Rx,
+ Ry,
+ center=True,
+ ellipse=calstate["ellipse"],
+ rotate=calstate["rotate"],
+ pixel=False,
+ )
+ for i in range(pl.data.shape[0]):
+ r = np.hypot(
+ pl.data["qx"][i] - self.braggdirections.data["qx"],
+ pl.data["qy"][i] - self.braggdirections.data["qy"],
+ )
+ ind = np.argmin(r)
+ if r[ind] <= self.max_peak_spacing:
+ indexed_braggpeaks[Rx, Ry].add_data_by_field(
+ (
+ pl.data["qx"][i],
+ pl.data["qy"][i],
+ pl.data["intensity"][i],
+ self.braggdirections.data["h"][ind],
+ self.braggdirections.data["k"][ind],
+ )
+ )
+ self.bragg_vectors_indexed = indexed_braggpeaks
+
+ # fit bragg vectors
+ g1g2_map = fit_lattice_vectors_all_DPs(self.bragg_vectors_indexed)
+ self.g1g2_map = g1g2_map
+
+ # update the mask
+ g1g2_mask = self.g1g2_map["mask"].data.astype("bool")
+ self.mask = np.logical_and(self.mask, g1g2_mask)
+
+ # return
+ if returncalc:
+ return self.bragg_vectors_indexed, self.g1g2_map
+
+ def get_strain(
+ self, gvects=None, coordinate_rotation=0, returncalc=False, **kwargs
+ ):
+ """
+ Compute the strain as the deviation of the basis reciprocal lattice
+ vectors which have been fit at each scan position with respect to a
+ pair of reference lattice vectors, determined by the argument `gvects`.
+
+ Parameters
+ ----------
+ gvects : None or 2d-array or tuple
+ Specifies how to select the reference lattice vectors. If None,
+ use the median of the fit lattice vectors over the whole dataset.
+ If a 2d array is passed, it should be real space shaped and boolean.
+ In this case, uses the median of the fit lattice vectors in all scan
+ positions where this array is True. Otherwise, should be a length 2
+ tuple of length 2 array/list/tuples, which are used directly as
+ g1 and g2.
+ coordinate_rotation : number
+ Rotate the reference coordinate system counterclockwise by this
+ amount, in degrees
+        returncalc : bool
+            If True, returns the rotated strain map
+ """
+ # confirm that the calstate hasn't changed
+ assert (
+ self.calstate == self.braggvectors.calstate
+ ), "The calibration state has changed! To resync the calibration state, use `.reset_calstate`."
+
+ # get the reference g-vectors
+ if gvects is None:
+ g1_ref, g2_ref = get_reference_g1g2(self.g1g2_map, self.mask)
+ elif isinstance(gvects, np.ndarray):
+ assert gvects.shape == self.rshape
+ assert gvects.dtype == bool
+ g1_ref, g2_ref = get_reference_g1g2(
+ self.g1g2_map, np.logical_and(gvects, self.mask)
+ )
+ else:
+ g1_ref = np.array(gvects[0])
+ g2_ref = np.array(gvects[1])
+
+ # find the strain
+ strainmap_g1g2 = get_strain_from_reference_g1g2(self.g1g2_map, g1_ref, g2_ref)
+ self.strainmap_g1g2 = strainmap_g1g2
+
+ # get the reference coordinate system
+ theta = np.radians(coordinate_rotation)
+ xaxis_x = np.cos(theta)
+ xaxis_y = np.sin(theta)
+ self.coordinate_rotation_degrees = coordinate_rotation
+ self.coordinate_rotation_radians = theta
+
+ # get the strain in the reference coordinates
+ strainmap_rotated = get_rotated_strain_map(
+ self.strainmap_g1g2,
+ xaxis_x=xaxis_x,
+ xaxis_y=xaxis_y,
+ flip_theta=False,
+ )
+
+ # store the data
+ self.data[0] = strainmap_rotated["e_xx"].data
+ self.data[1] = strainmap_rotated["e_yy"].data
+ self.data[2] = strainmap_rotated["e_xy"].data
+ self.data[3] = strainmap_rotated["theta"].data
+ self.data[4] = strainmap_rotated["mask"].data
+
+ # plot the results
+ fig, ax = self.show_strain(
+ **kwargs,
+ returnfig=True,
+ )
+
+ # return
+ if returncalc:
+            return strainmap_rotated
+
+ def get_reference_g1g2(self, ROI):
+ """
+ Get reference g1,g2 vectors by taking the median fit vectors
+ in the specified ROI.
+
+ Parameters
+ ----------
+ ROI : real space shaped 2d boolean ndarray
+ Use scan positions where ROI is True
+
+ Returns
+ -------
+ g1_ref,g2_ref : 2 tuple of length 2 ndarrays
+ """
+ g1_ref, g2_ref = get_reference_g1g2(self.g1g2_map, ROI)
+ return g1_ref, g2_ref
+
+ def show_strain(
+ self,
+ vrange=[-3, 3],
+ vrange_theta=[-3, 3],
+ vrange_exx=None,
+ vrange_exy=None,
+ vrange_eyy=None,
+ bkgrd=True,
+ show_cbars=None,
+ bordercolor="k",
+ borderwidth=1,
+ titlesize=18,
+ ticklabelsize=10,
+ ticknumber=5,
+ unitlabelsize=16,
+ cmap="RdBu_r",
+ cmap_theta="PRGn",
+ mask_color="k",
+ color_axes="k",
+ show_gvects=True,
+ color_gvects="r",
+ legend_camera_length=1.6,
+ scale_gvects=0.6,
+ layout="square",
+ figsize=None,
+ returnfig=False,
+ ):
+ """
+ Display a strain map, showing the 4 strain components
+ (e_xx,e_yy,e_xy,theta), and masking each image with
+ strainmap.get_slice('mask')
+
+ Parameters
+ ----------
+ vrange : length 2 list or tuple
+ The colorbar intensity range for exx,eyy, and exy.
+ vrange_theta : length 2 list or tuple
+ The colorbar intensity range for theta.
+ vrange_exx : length 2 list or tuple
+ The colorbar intensity range for exx; overrides `vrange`
+ for exx
+ vrange_exy : length 2 list or tuple
+ The colorbar intensity range for exy; overrides `vrange`
+ for exy
+ vrange_eyy : length 2 list or tuple
+ The colorbar intensity range for eyy; overrides `vrange`
+ for eyy
+ bkgrd : bool
+ Overlay a mask over background pixels
+ show_cbars : None or a tuple of strings
+ Show colorbars for the specified axes. Valid strings are
+ 'exx', 'eyy', 'exy', and 'theta'.
+ bordercolor : color
+ Color for the image borders
+ borderwidth : number
+ Width of the image borders
+ titlesize : number
+ Size of the image titles
+ ticklabelsize : number
+ Size of the colorbar ticks
+ ticknumber : number
+ Number of ticks on colorbars
+ unitlabelsize : number
+ Size of the units label on the colorbars
+ cmap : colormap
+ Colormap for exx, exy, and eyy
+ cmap_theta : colormap
+ Colormap for theta
+ mask_color : color
+ Color for the background mask
+ color_axes : color
+ Color for the legend coordinate axes
+ show_gvects : bool
+ Toggles displaying the g-vectors in the legend
+ color_gvects : color
+ Color for the legend g-vectors
+ legend_camera_length : number
+ The distance the legend is viewed from; a smaller number yields
+ a larger legend
+ scale_gvects : number
+ Scaling for the legend g-vectors relative to the coordinate axes
+        layout : str
+            Determines the layout of the grid which the strain components
+            will be plotted in. Must be in ('square', 'horizontal', 'vertical').
+ figsize : length 2 tuple of numbers
+ Size of the figure
+ returnfig : bool
+ Toggles returning the figure
+ """
+ # Lookup table for different layouts
+ assert layout in ("square", "horizontal", "vertical")
+ layout_lookup = {
+ "square": ["left", "right", "left", "right"],
+ "horizontal": ["bottom", "bottom", "bottom", "bottom"],
+ "vertical": ["right", "right", "right", "right"],
+ }
+
+ layout_p = layout_lookup[layout]
+
+ # Set which colorbars to display
+ if show_cbars is None:
+ if np.all(
+ [
+ v is None
+ for v in (
+ vrange_exx,
+ vrange_eyy,
+ vrange_exy,
+ )
+ ]
+ ):
+ show_cbars = ("eyy", "theta")
+ else:
+ show_cbars = ("exx", "eyy", "exy", "theta")
+ else:
+ assert np.all([v in ("exx", "eyy", "exy", "theta") for v in show_cbars])
+
+ # Contrast limits
+ if vrange_exx is None:
+ vrange_exx = vrange
+ if vrange_exy is None:
+ vrange_exy = vrange
+ if vrange_eyy is None:
+ vrange_eyy = vrange
+        for vr in (vrange_exx, vrange_eyy, vrange_exy, vrange_theta):
+            assert len(vr) == 2, "vranges must have length 2"
+ vmin_exx, vmax_exx = vrange_exx[0] / 100.0, vrange_exx[1] / 100.0
+ vmin_eyy, vmax_eyy = vrange_eyy[0] / 100.0, vrange_eyy[1] / 100.0
+ vmin_exy, vmax_exy = vrange_exy[0] / 100.0, vrange_exy[1] / 100.0
+ # theta is plotted in units of degrees
+ vmin_theta, vmax_theta = vrange_theta[0] / (180.0 / np.pi), vrange_theta[1] / (
+ 180.0 / np.pi
+ )
+
+ # Get images
+ e_xx = np.ma.array(
+ self.get_slice("exx").data,
+ mask=self.get_slice("mask").data == False, # noqa: E712,E501
+ )
+ e_yy = np.ma.array(
+ self.get_slice("eyy").data,
+ mask=self.get_slice("mask").data == False, # noqa: E712,E501
+ )
+ e_xy = np.ma.array(
+ self.get_slice("exy").data,
+ mask=self.get_slice("mask").data == False, # noqa: E712,E501
+ )
+ theta = np.ma.array(
+ self.get_slice("theta").data,
+ mask=self.get_slice("mask").data == False, # noqa: E712
+ )
+
+ ## Plot
+
+ # if figsize hasn't been set, set it based on the
+ # chosen layout and the image shape
+ if figsize is None:
+ ratio = np.sqrt(self.rshape[1] / self.rshape[0])
+ if layout == "square":
+ figsize = (13 * ratio, 8 / ratio)
+ elif layout == "horizontal":
+ figsize = (10 * ratio, 4 / ratio)
+ else:
+ figsize = (4 * ratio, 10 / ratio)
+
+ # set up layout
+ if layout == "square":
+ fig, ((ax11, ax12, ax_legend1), (ax21, ax22, ax_legend2)) = plt.subplots(
+ 2, 3, figsize=figsize
+ )
+ elif layout == "horizontal":
+ figsize = (figsize[0] * np.sqrt(2), figsize[1] / np.sqrt(2))
+ fig, (ax11, ax12, ax21, ax22, ax_legend) = plt.subplots(
+ 1, 5, figsize=figsize
+ )
+ else:
+ figsize = (figsize[0] / np.sqrt(2), figsize[1] * np.sqrt(2))
+ fig, (ax11, ax12, ax21, ax22, ax_legend) = plt.subplots(
+ 5, 1, figsize=figsize
+ )
+
+ # display images, returning cbar axis references
+ cax11 = show(
+ e_xx,
+ figax=(fig, ax11),
+ vmin=vmin_exx,
+ vmax=vmax_exx,
+ intensity_range="absolute",
+ cmap=cmap,
+ mask=self.mask,
+ mask_color=mask_color,
+ returncax=True,
+ )
+ cax12 = show(
+ e_yy,
+ figax=(fig, ax12),
+ vmin=vmin_eyy,
+ vmax=vmax_eyy,
+ intensity_range="absolute",
+ cmap=cmap,
+ mask=self.mask,
+ mask_color=mask_color,
+ returncax=True,
+ )
+ cax21 = show(
+ e_xy,
+ figax=(fig, ax21),
+ vmin=vmin_exy,
+ vmax=vmax_exy,
+ intensity_range="absolute",
+ cmap=cmap,
+ mask=self.mask,
+ mask_color=mask_color,
+ returncax=True,
+ )
+ cax22 = show(
+ theta,
+ figax=(fig, ax22),
+ vmin=vmin_theta,
+ vmax=vmax_theta,
+ intensity_range="absolute",
+ cmap=cmap_theta,
+ mask=self.mask,
+ mask_color=mask_color,
+ returncax=True,
+ )
+ ax11.set_title(r"$\epsilon_{xx}$", size=titlesize)
+ ax12.set_title(r"$\epsilon_{yy}$", size=titlesize)
+ ax21.set_title(r"$\epsilon_{xy}$", size=titlesize)
+ ax22.set_title(r"$\theta$", size=titlesize)
+
+ # Add black background
+ if bkgrd:
+ mask = np.ma.masked_where(
+ self.get_slice("mask").data.astype(bool),
+ np.zeros_like(self.get_slice("mask").data),
+ )
+ ax11.matshow(mask, cmap="gray")
+ ax12.matshow(mask, cmap="gray")
+ ax21.matshow(mask, cmap="gray")
+ ax22.matshow(mask, cmap="gray")
+
+ # add colorbars
+ show_cbars = np.array(
+ [
+ "exx" in show_cbars,
+ "eyy" in show_cbars,
+ "exy" in show_cbars,
+ "theta" in show_cbars,
+ ]
+ )
+ if np.any(show_cbars):
+ divider11 = make_axes_locatable(ax11)
+ divider12 = make_axes_locatable(ax12)
+ divider21 = make_axes_locatable(ax21)
+ divider22 = make_axes_locatable(ax22)
+ cbax11 = divider11.append_axes(layout_p[0], size="4%", pad=0.15)
+ cbax12 = divider12.append_axes(layout_p[1], size="4%", pad=0.15)
+ cbax21 = divider21.append_axes(layout_p[2], size="4%", pad=0.15)
+ cbax22 = divider22.append_axes(layout_p[3], size="4%", pad=0.15)
+ for ind, show_cbar, cax, cbax, vmin, vmax, tickside, tickunits in zip(
+ range(4),
+ show_cbars,
+ (cax11, cax12, cax21, cax22),
+ (cbax11, cbax12, cbax21, cbax22),
+ (vmin_exx, vmin_eyy, vmin_exy, vmin_theta),
+ (vmax_exx, vmax_eyy, vmax_exy, vmax_theta),
+ (layout_p[0], layout_p[1], layout_p[2], layout_p[3]),
+ ("% ", " %", "% ", r" $^\circ$"),
+ ):
+ if show_cbar:
+ ticks = np.linspace(vmin, vmax, ticknumber, endpoint=True)
+ if ind < 3:
+ ticklabels = np.round(
+ np.linspace(
+ 100 * vmin, 100 * vmax, ticknumber, endpoint=True
+ ),
+ decimals=2,
+ ).astype(str)
+ else:
+ ticklabels = np.round(
+ np.linspace(
+ (180 / np.pi) * vmin,
+ (180 / np.pi) * vmax,
+ ticknumber,
+ endpoint=True,
+ ),
+ decimals=2,
+ ).astype(str)
+
+ if tickside in ("left", "right"):
+ cb = plt.colorbar(
+ cax, cax=cbax, ticks=ticks, orientation="vertical"
+ )
+ cb.ax.set_yticklabels(ticklabels, size=ticklabelsize)
+ cbax.yaxis.set_ticks_position(tickside)
+ cbax.set_ylabel(tickunits, size=unitlabelsize, rotation=0)
+ cbax.yaxis.set_label_position(tickside)
+ else:
+ cb = plt.colorbar(
+ cax, cax=cbax, ticks=ticks, orientation="horizontal"
+ )
+ cb.ax.set_xticklabels(ticklabels, size=ticklabelsize)
+ cbax.xaxis.set_ticks_position(tickside)
+ cbax.set_xlabel(tickunits, size=unitlabelsize, rotation=0)
+ cbax.xaxis.set_label_position(tickside)
+ else:
+ cbax.axis("off")
+
+ # Add borders
+ if bordercolor is not None:
+ for ax in (ax11, ax12, ax21, ax22):
+ for s in ["bottom", "top", "left", "right"]:
+ ax.spines[s].set_color(bordercolor)
+ ax.spines[s].set_linewidth(borderwidth)
+ ax.set_xticks([])
+ ax.set_yticks([])
+
+ # Legend
+
+ # for layout "square", combine vertical plots on the right end
+ if layout == "square":
+ # get gridspec object
+ gs = ax_legend1.get_gridspec()
+ # remove last two axes
+ ax_legend1.remove()
+ ax_legend2.remove()
+ # make new axis
+ ax_legend = fig.add_subplot(gs[:, -1])
+
+ # get the coordinate axes' directions
+ rotation = self.coordinate_rotation_radians
+ xaxis_vectx = np.cos(rotation)
+ xaxis_vecty = np.sin(rotation)
+ yaxis_vectx = np.cos(rotation + np.pi / 2)
+ yaxis_vecty = np.sin(rotation + np.pi / 2)
+
+ # make the coordinate axes
+ ax_legend.arrow(
+ x=0,
+ y=0,
+ dx=xaxis_vecty,
+ dy=xaxis_vectx,
+ color=color_axes,
+ length_includes_head=True,
+ width=0.01,
+ head_width=0.1,
+ )
+ ax_legend.arrow(
+ x=0,
+ y=0,
+ dx=yaxis_vecty,
+ dy=yaxis_vectx,
+ color=color_axes,
+ length_includes_head=True,
+ width=0.01,
+ head_width=0.1,
+ )
+ ax_legend.text(
+ x=xaxis_vecty * 1.16,
+ y=xaxis_vectx * 1.16,
+ s="x",
+ fontsize=14,
+ color=color_axes,
+ horizontalalignment="center",
+ verticalalignment="center",
+ )
+ ax_legend.text(
+ x=yaxis_vecty * 1.16,
+ y=yaxis_vectx * 1.16,
+ s="y",
+ fontsize=14,
+ color=color_axes,
+ horizontalalignment="center",
+ verticalalignment="center",
+ )
+
+ # make the g-vectors
+ if show_gvects:
+ # get the g-vectors directions
+ g1q = np.array(self.g1)
+ g2q = np.array(self.g2)
+ g1norm = np.linalg.norm(g1q)
+ g2norm = np.linalg.norm(g2q)
+ g1q /= g1norm
+ g2q /= g2norm
+ # set the lengths
+ g_ratio = g2norm / g1norm
+ if g_ratio > 1:
+ g1q /= g_ratio
+ else:
+ g2q *= g_ratio
+ g1_x, g1_y = g1q
+ g2_x, g2_y = g2q
+
+ # draw the g vectors
+ ax_legend.arrow(
+ x=0,
+ y=0,
+ dx=g1_y * scale_gvects,
+ dy=g1_x * scale_gvects,
+ color=color_gvects,
+ length_includes_head=True,
+ width=0.005,
+ head_width=0.05,
+ )
+ ax_legend.arrow(
+ x=0,
+ y=0,
+ dx=g2_y * scale_gvects,
+ dy=g2_x * scale_gvects,
+ color=color_gvects,
+ length_includes_head=True,
+ width=0.005,
+ head_width=0.05,
+ )
+ ax_legend.text(
+ x=g1_y * scale_gvects * 1.2,
+ y=g1_x * scale_gvects * 1.2,
+ s=r"$g_1$",
+ fontsize=12,
+ color=color_gvects,
+ horizontalalignment="center",
+ verticalalignment="center",
+ )
+ ax_legend.text(
+ x=g2_y * scale_gvects * 1.2,
+ y=g2_x * scale_gvects * 1.2,
+ s=r"$g_2$",
+ fontsize=12,
+ color=color_gvects,
+ horizontalalignment="center",
+ verticalalignment="center",
+ )
+
+ # find center and extent
+ xmin = np.min([0, 0, xaxis_vectx, yaxis_vectx])
+ xmax = np.max([0, 0, xaxis_vectx, yaxis_vectx])
+ ymin = np.min([0, 0, xaxis_vecty, yaxis_vecty])
+ ymax = np.max([0, 0, xaxis_vecty, yaxis_vecty])
+ if show_gvects:
+ xmin = np.min([xmin, g1_x, g2_x])
+ xmax = np.max([xmax, g1_x, g2_x])
+ ymin = np.min([ymin, g1_y, g2_y])
+ ymax = np.max([ymax, g1_y, g2_y])
+ x0 = np.mean([xmin, xmax])
+ y0 = np.mean([ymin, ymax])
+ xL = (xmax - x0) * legend_camera_length
+ yL = (ymax - y0) * legend_camera_length
+
+ # set the extent and aspect
+ ax_legend.set_xlim([y0 - yL, y0 + yL])
+ ax_legend.set_ylim([x0 - xL, x0 + xL])
+ ax_legend.invert_yaxis()
+ ax_legend.set_aspect("equal")
+ ax_legend.axis("off")
+
+ # show/return
+ if not returnfig:
+ plt.show()
+ return
+ else:
+ axs = ((ax11, ax12), (ax21, ax22))
+ return fig, axs
+
+ def show_reference_directions(
+ self,
+ im_uncal=None,
+ im_cal=None,
+ color_axes="linen",
+ color_gvects="r",
+ origin_uncal=None,
+ origin_cal=None,
+ camera_length=1.8,
+ visp_uncal={"scaling": "log"},
+ visp_cal={"scaling": "log"},
+ layout="horizontal",
+ titlesize=16,
+ size_labels=14,
+ figsize=None,
+ returnfig=False,
+ ):
+ """
+ Show the reference coordinate system used to compute the strain
+ overlaid over calibrated and uncalibrated diffraction space images.
+
+        The diffraction images used can be specified with the `im_uncal`
+        and `im_cal` arguments, and default to the uncalibrated and calibrated
+        Bragg vector maps. The coordinate axes and g-vectors overlaid on the
+        uncalibrated image are rotated by the -QR rotation from the
+        calibration metadata, so that the displayed reference directions
+        correspond between the two images.
+
+ Parameters
+ ----------
+        im_uncal : 2d array or None
+            Uncalibrated diffraction space image to display; defaults to
+            the uncalibrated Bragg vector map.
+ im_cal : 2d array or None
+ Calibrated diffraction space image to display; defaults to
+ the calibrated Bragg vector map.
+ color_axes : color
+ The color of the overlaid coordinate axes
+ color_gvects : color
+ The color of the g-vectors
+ origin_uncal : 2-tuple or None
+ Where to place the origin of the coordinate system overlaid on
+ the uncalibrated diffraction image. Defaults to the mean origin
+ from the calibration metadata.
+ origin_cal : 2-tuple or None
+ Where to place the origin of the coordinate system overlaid on
+ the calibrated diffraction image. Defaults to the mean origin
+ from the calibration metadata.
+ camera_length : number
+ Determines the length of the overlaid coordinate axes; a smaller
+ number yields larger axes.
+ visp_uncal : dict
+ Visualization parameters for the uncalibrated diffraction image.
+ visp_cal : dict
+ Visualization parameters for the calibrated diffraction image.
+ layout : str; either "horizontal" or "vertical"
+ Determines the layout of the visualization.
+ titlesize : number
+ The size of the plot titles
+ size_labels : number
+ The size of the axis labels
+ figsize : length 2 tuple of numbers or None
+ Size of the figure
+ returnfig : bool
+ Toggles returning the figure
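+
+        Examples
+        --------
+        A minimal sketch, assuming `strainmap` is a StrainMap instance with
+        calibrated Bragg vectors and fitted g-vectors:
+
+            strainmap.show_reference_directions(camera_length=1.5)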
+ """
+ # Set up the figure
+ assert layout in ("horizontal", "vertical")
+
+ # Set the figsize
+ if figsize is None:
+ ratio = np.sqrt(self.rshape[1] / self.rshape[0])
+ if layout == "horizontal":
+ figsize = (10 * ratio, 8 / ratio)
+ else:
+ figsize = (8 * ratio, 12 / ratio)
+
+ # Create the figure
+ if layout == "horizontal":
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
+ else:
+ fig, (ax1, ax2) = plt.subplots(2, 1, figsize=figsize)
+
+ # prepare images
+ if im_uncal is None:
+ im_uncal = self.braggvectors.histogram(mode="raw")
+ if im_cal is None:
+ im_cal = self.braggvectors.histogram(mode="cal")
+
+ # display images
+ show(im_cal, figax=(fig, ax1), **visp_cal)
+ show(im_uncal, figax=(fig, ax2), **visp_uncal)
+ ax1.set_title("Calibrated", size=titlesize)
+ ax2.set_title("Uncalibrated", size=titlesize)
+
+ # Get the coordinate axes
+
+ # get the directions
+
+ # calibrated
+ rotation = self.coordinate_rotation_radians
+ xaxis_cal = np.array([np.cos(rotation), np.sin(rotation)])
+ yaxis_cal = np.array(
+ [np.cos(rotation + np.pi / 2), np.sin(rotation + np.pi / 2)]
+ )
+
+ # uncalibrated
+ QRrot = self.calibration.get_QR_rotation()
+ rotation = np.sum([self.coordinate_rotation_radians, -QRrot])
+ xaxis_uncal = np.array([np.cos(rotation), np.sin(rotation)])
+ yaxis_uncal = np.array(
+ [np.cos(rotation + np.pi / 2), np.sin(rotation + np.pi / 2)]
+ )
+ # inversion
+ if self.calibration.get_QR_flip():
+ xaxis_uncal = np.array([xaxis_uncal[1], xaxis_uncal[0]])
+ yaxis_uncal = np.array([yaxis_uncal[1], yaxis_uncal[0]])
+
+ # set the lengths
+ Lmean = np.mean([im_cal.shape[0], im_cal.shape[1]]) / 2
+ xaxis_cal *= Lmean / camera_length
+ yaxis_cal *= Lmean / camera_length
+ xaxis_uncal *= Lmean / camera_length
+ yaxis_uncal *= Lmean / camera_length
+
+ # Get the g-vectors
+
+ # calibrated
+ g1_cal = np.array(self.g1)
+ g2_cal = np.array(self.g2)
+
+ # uncalibrated
+ R = np.array([[np.cos(QRrot), -np.sin(QRrot)], [np.sin(QRrot), np.cos(QRrot)]])
+ g1_uncal = np.matmul(g1_cal, R)
+ g2_uncal = np.matmul(g2_cal, R)
+ # inversion
+ if self.calibration.get_QR_flip():
+ g1_uncal = np.array([g1_uncal[1], g1_uncal[0]])
+ g2_uncal = np.array([g2_uncal[1], g2_uncal[0]])
+
+ # Set origin positions
+ if origin_uncal is None:
+ origin_uncal = self.calibration.get_origin_mean()
+ if origin_cal is None:
+ origin_cal = self.calibration.get_origin_mean()
+
+ # Draw calibrated coordinate axes
+ coordax_width = Lmean * 2 / 100
+ ax1.arrow(
+ x=origin_cal[1],
+ y=origin_cal[0],
+ dx=xaxis_cal[1],
+ dy=xaxis_cal[0],
+ color=color_axes,
+ length_includes_head=True,
+ width=coordax_width,
+ head_width=coordax_width * 5,
+ )
+ ax1.arrow(
+ x=origin_cal[1],
+ y=origin_cal[0],
+ dx=yaxis_cal[1],
+ dy=yaxis_cal[0],
+ color=color_axes,
+ length_includes_head=True,
+ width=coordax_width,
+ head_width=coordax_width * 5,
+ )
+ ax1.text(
+ x=origin_cal[1] + xaxis_cal[1] * 1.16,
+ y=origin_cal[0] + xaxis_cal[0] * 1.16,
+ s="x",
+ fontsize=size_labels,
+ color=color_axes,
+ horizontalalignment="center",
+ verticalalignment="center",
+ )
+ ax1.text(
+ x=origin_cal[1] + yaxis_cal[1] * 1.16,
+ y=origin_cal[0] + yaxis_cal[0] * 1.16,
+ s="y",
+ fontsize=size_labels,
+ color=color_axes,
+ horizontalalignment="center",
+ verticalalignment="center",
+ )
+
+ # Draw uncalibrated coordinate axes
+ ax2.arrow(
+ x=origin_uncal[1],
+ y=origin_uncal[0],
+ dx=xaxis_uncal[1],
+ dy=xaxis_uncal[0],
+ color=color_axes,
+ length_includes_head=True,
+ width=coordax_width,
+ head_width=coordax_width * 5,
+ )
+ ax2.arrow(
+ x=origin_uncal[1],
+ y=origin_uncal[0],
+ dx=yaxis_uncal[1],
+ dy=yaxis_uncal[0],
+ color=color_axes,
+ length_includes_head=True,
+ width=coordax_width,
+ head_width=coordax_width * 5,
+ )
+ ax2.text(
+ x=origin_uncal[1] + xaxis_uncal[1] * 1.16,
+ y=origin_uncal[0] + xaxis_uncal[0] * 1.16,
+ s="x",
+ fontsize=size_labels,
+ color=color_axes,
+ horizontalalignment="center",
+ verticalalignment="center",
+ )
+ ax2.text(
+ x=origin_uncal[1] + yaxis_uncal[1] * 1.16,
+ y=origin_uncal[0] + yaxis_uncal[0] * 1.16,
+ s="y",
+ fontsize=size_labels,
+ color=color_axes,
+ horizontalalignment="center",
+ verticalalignment="center",
+ )
+
+ # Draw the calibrated g-vectors
+
+ # draw the g vectors
+ ax1.arrow(
+ x=origin_cal[1],
+ y=origin_cal[0],
+ dx=g1_cal[1],
+ dy=g1_cal[0],
+ color=color_gvects,
+ length_includes_head=True,
+ width=coordax_width * 0.5,
+ head_width=coordax_width * 2.5,
+ )
+ ax1.arrow(
+ x=origin_cal[1],
+ y=origin_cal[0],
+ dx=g2_cal[1],
+ dy=g2_cal[0],
+ color=color_gvects,
+ length_includes_head=True,
+ width=coordax_width * 0.5,
+ head_width=coordax_width * 2.5,
+ )
+ ax1.text(
+ x=origin_cal[1] + g1_cal[1] * 1.16,
+ y=origin_cal[0] + g1_cal[0] * 1.16,
+ s=r"$g_1$",
+ fontsize=size_labels * 0.88,
+ color=color_gvects,
+ horizontalalignment="center",
+ verticalalignment="center",
+ )
+ ax1.text(
+ x=origin_cal[1] + g2_cal[1] * 1.16,
+ y=origin_cal[0] + g2_cal[0] * 1.16,
+ s=r"$g_2$",
+ fontsize=size_labels * 0.88,
+ color=color_gvects,
+ horizontalalignment="center",
+ verticalalignment="center",
+ )
+
+ # Draw the uncalibrated g-vectors
+
+ # draw the g vectors
+ ax2.arrow(
+ x=origin_uncal[1],
+ y=origin_uncal[0],
+ dx=g1_uncal[1],
+ dy=g1_uncal[0],
+ color=color_gvects,
+ length_includes_head=True,
+ width=coordax_width * 0.5,
+ head_width=coordax_width * 2.5,
+ )
+ ax2.arrow(
+ x=origin_uncal[1],
+ y=origin_uncal[0],
+ dx=g2_uncal[1],
+ dy=g2_uncal[0],
+ color=color_gvects,
+ length_includes_head=True,
+ width=coordax_width * 0.5,
+ head_width=coordax_width * 2.5,
+ )
+ ax2.text(
+ x=origin_uncal[1] + g1_uncal[1] * 1.16,
+ y=origin_uncal[0] + g1_uncal[0] * 1.16,
+ s=r"$g_1$",
+ fontsize=size_labels * 0.88,
+ color=color_gvects,
+ horizontalalignment="center",
+ verticalalignment="center",
+ )
+ ax2.text(
+ x=origin_uncal[1] + g2_uncal[1] * 1.16,
+ y=origin_uncal[0] + g2_uncal[0] * 1.16,
+ s=r"$g_2$",
+ fontsize=size_labels * 0.88,
+ color=color_gvects,
+ horizontalalignment="center",
+ verticalalignment="center",
+ )
+
+ # show/return
+ if not returnfig:
+ plt.show()
+ return
+ else:
+ return fig, (ax1, ax2)
+
+    @staticmethod
+    def show_lattice_vectors(
+ ar,
+ x0,
+ y0,
+ g1,
+ g2,
+ color="r",
+ width=1,
+ labelsize=20,
+ labelcolor="w",
+ returnfig=False,
+ **kwargs,
+ ):
+ """
+ Adds the vectors g1,g2 to an image, with tail positions at (x0,y0).
+ g1 and g2 are 2-tuples (gx,gy).
+ """
+ fig, ax = show(ar, returnfig=True, **kwargs)
+
+ # Add vectors
+ dg1 = {
+ "x0": x0,
+ "y0": y0,
+ "vx": g1[0],
+ "vy": g1[1],
+ "width": width,
+ "color": color,
+ "label": r"$g_1$",
+ "labelsize": labelsize,
+ "labelcolor": labelcolor,
+ }
+ dg2 = {
+ "x0": x0,
+ "y0": y0,
+ "vx": g2[0],
+ "vy": g2[1],
+ "width": width,
+ "color": color,
+ "label": r"$g_2$",
+ "labelsize": labelsize,
+ "labelcolor": labelcolor,
+ }
+ add_vector(ax, dg1)
+ add_vector(ax, dg2)
+
+ if returnfig:
+ return fig, ax
+ else:
+ plt.show()
+ return
+
+ def show_bragg_indexing(
+ self,
+ ar,
+ bragg_directions,
+ voffset=5,
+ hoffset=0,
+ color="w",
+ size=20,
+ points=True,
+ pointcolor="r",
+ pointsize=50,
+ figax=None,
+ returnfig=False,
+ **kwargs,
+ ):
+ """
+ Shows an array with an overlay describing the Bragg directions
+
+ Parameters
+ ----------
+ ar : np.ndarray
+ The display image
+ bragg_directions : PointList
+ The Bragg scattering directions. Must have coordinates
+ 'qx','qy','h', and 'k'. Optionally may also have 'l'.
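+
+        Examples
+        --------
+        A minimal sketch, assuming `strainmap` is a StrainMap instance whose
+        Bragg peaks have been indexed, so that `strainmap.bragg_directions`
+        holds the indexed PointList:
+
+            strainmap.show_bragg_indexing(
+                strainmap.braggvectors.histogram(mode="cal"),
+                bragg_directions=strainmap.bragg_directions,
+            )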
+ """
+ assert isinstance(bragg_directions, PointList)
+ for k in ("qx", "qy", "h", "k"):
+ assert k in bragg_directions.data.dtype.fields
+
+ if figax is None:
+ fig, ax = show(ar, returnfig=True, **kwargs)
+ else:
+ fig = figax[0]
+ ax = figax[1]
+ show(ar, figax=figax, **kwargs)
+
+ d = {
+ "bragg_directions": bragg_directions,
+ "voffset": voffset,
+ "hoffset": hoffset,
+ "color": color,
+ "size": size,
+ "points": points,
+ "pointsize": pointsize,
+ "pointcolor": pointcolor,
+ }
+ add_bragg_index_labels(ax, d)
+
+ if returnfig:
+ return fig, ax
+ else:
+ return
+
+ def copy(self, name=None):
+ name = name if name is not None else self.name + "_copy"
+ strainmap_copy = StrainMap(self.braggvectors)
+ for attr in (
+ "g",
+ "g0",
+ "g1",
+ "g2",
+ "calstate",
+ "bragg_directions",
+ "bragg_vectors_indexed",
+ "g1g2_map",
+ "strainmap_g1g2",
+ "strainmap_rotated",
+ "mask",
+ ):
+ if hasattr(self, attr):
+ setattr(strainmap_copy, attr, getattr(self, attr))
+
+        for k in self.metadata.keys():
+            # the metadata property setter stores each Metadata object by
+            # name, so this attaches a copy of each one to the new instance
+            strainmap_copy.metadata = self.metadata[k].copy()
+ return strainmap_copy
+
+ # TODO IO methods
+
+ # read
+ @classmethod
+ def _get_constructor_args(cls, group):
+ """
+ Returns a dictionary of args/values to pass to the class constructor
+ """
+ ar_constr_args = RealSlice._get_constructor_args(group)
+ args = {
+ "data": ar_constr_args["data"],
+ "name": ar_constr_args["name"],
+ }
+ return args
diff --git a/py4DSTEM/process/utils/__init__.py b/py4DSTEM/process/utils/__init__.py
new file mode 100644
index 000000000..643de1bf5
--- /dev/null
+++ b/py4DSTEM/process/utils/__init__.py
@@ -0,0 +1,15 @@
+from py4DSTEM.process.utils.utils import *
+from py4DSTEM.process.utils.cross_correlate import *
+from py4DSTEM.process.utils.multicorr import *
+from py4DSTEM.process.utils.elliptical_coords import *
+from py4DSTEM.process.utils.masks import *
+from py4DSTEM.process.utils.single_atom_scatter import *
+
+# from preprocessing
+from py4DSTEM.preprocess.utils import (
+ bin2D,
+ get_maxima_2D,
+ get_shifted_ar,
+ filter_2D_maxima,
+ linear_interpolation_2D,
+)
diff --git a/py4DSTEM/process/utils/cross_correlate.py b/py4DSTEM/process/utils/cross_correlate.py
new file mode 100644
index 000000000..50de91e33
--- /dev/null
+++ b/py4DSTEM/process/utils/cross_correlate.py
@@ -0,0 +1,178 @@
+# Cross correlation function
+
+import numpy as np
+from py4DSTEM.preprocess.utils import get_shifted_ar
+from py4DSTEM.process.utils.multicorr import upsampled_correlation
+
+try:
+ import cupy as cp
+except ModuleNotFoundError:
+ cp = np
+
+
+def get_cross_correlation(ar, template, corrPower=1, _returnval="real"):
+ """
+ Get the cross/phase/hybrid correlation of `ar` with `template`, where
+ the latter is in real space.
+
+ If _returnval is 'real', returns the real-valued cross-correlation.
+ Otherwise, returns the complex valued result.
+ """
+ assert _returnval in ("real", "fourier")
+ template_FT = np.conj(np.fft.fft2(template))
+ return get_cross_correlation_FT(
+ ar, template_FT, corrPower=corrPower, _returnval=_returnval
+ )
+
+
+def get_cross_correlation_FT(ar, template_FT, corrPower=1, _returnval="real"):
+ """
+    Get the cross/phase/hybrid correlation of `ar` with `template_FT`, where
+    the latter is already in Fourier space (i.e. `template_FT` is
+    `np.conj(np.fft.fft2(template))`).
+
+ If _returnval is 'real', returns the real-valued cross-correlation.
+ Otherwise, returns the complex valued result.
+ """
+ assert _returnval in ("real", "fourier")
+ m = np.fft.fft2(ar) * template_FT
+ if corrPower != 1:
+ cc = np.abs(m) ** (corrPower) * np.exp(1j * np.angle(m))
+ else:
+ cc = m
+ if _returnval == "real":
+ cc = np.maximum(np.real(np.fft.ifft2(cc)), 0)
+ return cc
+
+
+def get_shift(ar1, ar2, corrPower=1):
+ """
+ Determine the relative shift between a pair of arrays giving the best overlap.
+
+ Shift determination uses the brightest pixel in the cross correlation, and is
+ thus limited to pixel resolution. corrPower specifies the cross correlation
+ power, with 1 corresponding to a cross correlation and 0 a phase correlation.
+
+ Args:
+ ar1,ar2 (2D ndarrays):
+ corrPower (float between 0 and 1, inclusive): 1=cross correlation, 0=phase
+ correlation
+
+ Returns:
+ (2-tuple): (shiftx,shifty) - the relative image shift, in pixels
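+
+    Example (illustrative; a pure integer-pixel shift is recovered exactly):
+        ar = np.random.rand(64, 64)
+        shifted = np.roll(ar, (3, 5), axis=(0, 1))
+        get_shift(shifted, ar)  # -> (3, 5), the shift of `shifted` relative to `ar`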
+ """
+ cc = get_cross_correlation(ar1, ar2, corrPower)
+ xshift, yshift = np.unravel_index(np.argmax(cc), ar1.shape)
+ return xshift, yshift
+
+
+def align_images_fourier(
+ G1,
+ G2,
+ upsample_factor,
+ device="cpu",
+):
+ """
+ Alignment of two images using DFT upsampling of cross correlation.
+
+ Parameters
+    ----------
+ G1: ndarray
+ fourier transform of image 1
+ G2: ndarray
+ fourier transform of image 2
+ upsample_factor: float
+ upsampling for correlation. Must be greater than 2.
+    device: str, optional
+        device the calculation will be performed on. Must be 'cpu' or 'gpu'
+
+    Returns:
+        xy_shift (2-element array): the (x,y) relative image shift, in pixels
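+
+    Example (a sketch, assuming `im1` and `im2` are same-shape 2D arrays):
+        G1 = np.fft.fft2(im1)
+        G2 = np.fft.fft2(im2)
+        xy_shift = align_images_fourier(G1, G2, upsample_factor=16)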
+ """
+
+ if device == "cpu":
+ xp = np
+ elif device == "gpu":
+ xp = cp
+
+ G1 = xp.asarray(G1)
+ G2 = xp.asarray(G2)
+
+ # cross correlation
+ cc = G1 * xp.conj(G2)
+ cc_real = xp.real(xp.fft.ifft2(cc))
+
+ # local max
+ x0, y0 = xp.unravel_index(cc_real.argmax(), cc.shape)
+
+ # half pixel shifts
+ x_inds = xp.mod(x0 + xp.arange(-1, 2), cc.shape[0]).astype("int")
+ y_inds = xp.mod(y0 + xp.arange(-1, 2), cc.shape[1]).astype("int")
+
+ vx = cc_real[x_inds, y0]
+ vy = cc_real[x0, y_inds]
+ dx = (vx[2] - vx[0]) / (4 * vx[1] - 2 * vx[2] - 2 * vx[0])
+ dy = (vy[2] - vy[0]) / (4 * vy[1] - 2 * vy[2] - 2 * vy[0])
+
+ x0 = xp.round((x0 + dx) * 2.0) / 2.0
+ y0 = xp.round((y0 + dy) * 2.0) / 2.0
+
+ # subpixel shifts
+ xy_shift = upsampled_correlation(
+ cc, upsample_factor, xp.array([x0, y0]), device=device
+ )
+
+ return xy_shift
+
+
+def align_and_shift_images(
+ image_1,
+ image_2,
+ upsample_factor,
+ device="cpu",
+):
+ """
+ Alignment of two images using DFT upsampling of cross correlation.
+
+ Parameters
+    ----------
+ image_1: ndarray
+ image 1
+ image_2: ndarray
+ image 2
+ upsample_factor: float
+ upsampling for correlation. Must be greater than 2.
+    device: str, optional
+        device the calculation will be performed on. Must be 'cpu' or 'gpu'.
+
+    Returns:
+        image_2_shifted (ndarray): image_2 shifted into registration with image_1
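+
+    Example (a sketch, assuming `im1` and `im2` are same-shape 2D arrays):
+        im2_aligned = align_and_shift_images(im1, im2, upsample_factor=16)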
+ """
+
+ if device == "cpu":
+ xp = np
+
+ elif device == "gpu":
+ xp = cp
+
+ image_1 = xp.asarray(image_1)
+ image_2 = xp.asarray(image_2)
+
+ xy_shift = align_images_fourier(
+ xp.fft.fft2(image_1),
+ xp.fft.fft2(image_2),
+ upsample_factor=upsample_factor,
+ device=device,
+ )
+ dx = (
+ xp.mod(xy_shift[0] + image_1.shape[0] / 2, image_1.shape[0])
+ - image_1.shape[0] / 2
+ )
+ dy = (
+ xp.mod(xy_shift[1] + image_1.shape[1] / 2, image_1.shape[1])
+ - image_1.shape[1] / 2
+ )
+
+ image_2_shifted = get_shifted_ar(image_2, dx, dy, device=device)
+
+ return image_2_shifted
diff --git a/py4DSTEM/process/utils/elliptical_coords.py b/py4DSTEM/process/utils/elliptical_coords.py
new file mode 100644
index 000000000..97291bc20
--- /dev/null
+++ b/py4DSTEM/process/utils/elliptical_coords.py
@@ -0,0 +1,434 @@
+"""
+Contains functions relating to polar-elliptical calculations.
+
+This includes
+ - transforming data from cartesian to polar-elliptical coordinates
+ - converting between ellipse representations
+ - radial and polar-elliptical radial integration
+
+Functions for measuring/fitting elliptical distortions are found in
+process/calibration/ellipse.py. Functions for computing radial and
+polar-elliptical radial backgrounds are found in process/preprocess/ellipse.py.
+
+py4DSTEM uses 2 ellipse representations - one user-facing representation, and
+one internal representation. The user-facing representation is in terms of the
+following 5 parameters:
+
+ x0,y0 the center of the ellipse
+ a the semimajor axis length
+ b the semiminor axis length
+ theta the (positive, right handed) tilt of the a-axis
+ to the x-axis, in radians
+
+Internally, fits are performed using the canonical ellipse parameterization,
+in terms of the parameters (x0,y0,A,B,C):
+
+    A(x-x0)^2 + B(x-x0)(y-y0) + C(y-y0)^2 = 1
+
+It is possible to convert between (a,b,theta) <--> (A,B,C) using
+the convert_ellipse_params() and convert_ellipse_params_r() functions.
+
+Transformation from cartesian to polar-elliptical space is done using
+
+ x = x0 + a*r*cos(phi)*cos(theta) + b*r*sin(phi)*sin(theta)
+ y = y0 + a*r*cos(phi)*sin(theta) - b*r*sin(phi)*cos(theta)
+
+where (r,phi) are the polar-elliptical coordinates. All angular quantities are in
+radians.
+"""
+
+import numpy as np
+
+### Convert between representations
+
+
+def convert_ellipse_params(A, B, C):
+ """
+ Converts ellipse parameters from canonical form (A,B,C) into semi-axis lengths and
+ tilt (a,b,theta).
+ See module docstring for more info.
+
+ Args:
+ A,B,C (floats): parameters of an ellipse in the form:
+ Ax^2 + Bxy + Cy^2 = 1
+
+ Returns:
+ (3-tuple): A 3-tuple consisting of:
+
+ * **a**: (float) the semimajor axis length
+ * **b**: (float) the semiminor axis length
+ * **theta**: (float) the tilt of the ellipse semimajor axis with respect to
+ the x-axis, in radians
+ """
+ val = np.sqrt((A - C) ** 2 + B**2)
+ b4a = B**2 - 4 * A * C
+ # Get theta
+ if B == 0:
+ if A < C:
+ theta = 0
+ else:
+ theta = np.pi / 2.0
+ else:
+ theta = np.arctan2((C - A - val), B)
+ # Get a,b
+ a = -np.sqrt(-2 * b4a * (A + C + val)) / b4a
+ b = -np.sqrt(-2 * b4a * (A + C - val)) / b4a
+ a, b = max(a, b), min(a, b)
+ return a, b, theta
+
+
+def convert_ellipse_params_r(a, b, theta):
+ """
+ Converts from ellipse parameters (a,b,theta) to (A,B,C).
+ See module docstring for more info.
+
+ Args:
+ a,b,theta (floats): parameters of an ellipse, where `a`/`b` are the
+ semimajor/semiminor axis lengths, and theta is the tilt of the semimajor axis
+ with respect to the x-axis, in radians.
+
+ Returns:
+ (3-tuple): A 3-tuple consisting of (A,B,C), the ellipse parameters in
+ canonical form.
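+
+    Example (an illustrative round trip with convert_ellipse_params):
+        A, B, C = convert_ellipse_params_r(2.0, 1.0, 0)
+        # gives (A,B,C) = (0.25, 0.0, 1.0), i.e. the ellipse x^2/4 + y^2 = 1
+        convert_ellipse_params(A, B, C)  # recovers (a, b, theta) = (2.0, 1.0, 0)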
+ """
+ sin2, cos2 = np.sin(theta) ** 2, np.cos(theta) ** 2
+ a2, b2 = a**2, b**2
+ A = sin2 / b2 + cos2 / a2
+ C = cos2 / b2 + sin2 / a2
+ B = 2 * (b2 - a2) * np.sin(theta) * np.cos(theta) / (a2 * b2)
+ return A, B, C
+
+
+### Polar elliptical transformation
+
+
+def cartesian_to_polarelliptical_transform(
+ cartesianData,
+ p_ellipse,
+ dr=1,
+ dphi=np.radians(2),
+ r_range=None,
+ mask=None,
+ maskThresh=0.99,
+):
+ """
+ Transforms an array of data in cartesian coordinates into a data array in
+ polar-elliptical coordinates.
+
+ Discussion of the elliptical parametrization used can be found in the docstring
+ for the process.utils.elliptical_coords module.
+
+ Args:
+ cartesianData (2D float array): the data in cartesian coordinates
+ p_ellipse (5-tuple): specifies (qx0,qy0,a,b,theta), the parameters for the
+ transformation. These are the same 5 parameters which are outputs
+ of the elliptical fitting functions in the process.calibration
+ module, e.g. fit_ellipse_amorphous_ring and fit_ellipse_1D. For
+ more details, see the process.utils.elliptical_coords module docstring
+ dr (float): sampling of the (r,phi) coords: the width of the bins in r
+ dphi (float): sampling of the (r,phi) coords: the width of the bins in phi,
+ in radians
+        r_range (number or length 2 list/tuple or None): specifies the sampling of the
+            (r,phi) coords. Precise behavior depends on the parameter type:
+ * if None, autoselects max r value
+ * if r_range is a number, specifies the maximum r value
+ * if r_range is a length 2 list/tuple, specifies the min/max r values
+ mask (2d array of bools): shape must match cartesianData; where mask==False,
+ ignore these datapoints in making the polarElliptical data array
+ maskThresh (float): the final data mask is calculated by converting mask (above)
+ from cartesian to polar elliptical coords. Due to interpolation, this
+ results in some non-boolean values - this is converted back to a boolean
+ array by taking polarEllipticalMask = polarTrans(mask) < maskThresh. Cells
+ where polarTrans is less than 1 (i.e. has at least one masked NN) should
+ generally be masked, hence the default value of 0.99.
+
+ Returns:
+ (3-tuple): A 3-tuple, containing:
+
+ * **polarEllipticalData**: *(2D masked array)* a masked array containing
+ the data and the data mask, in polarElliptical coordinates
+ * **rr**: *(2D array)* meshgrid of the r coordinates
+ * **pp**: *(2D array)* meshgrid of the phi coordinates
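+
+    Example (a minimal sketch, for the undistorted case a=b=1, theta=0):
+        data = np.random.rand(128, 128)
+        polar, rr, pp = cartesian_to_polarelliptical_transform(
+            data, p_ellipse=(64, 64, 1, 1, 0)
+        )
+        # polar[i, j] is the interpolated value at r = rr[i, j], phi = pp[i, j]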
+ """
+ if mask is None:
+ mask = np.ones_like(cartesianData.data, dtype=bool)
+ assert (
+ cartesianData.shape == mask.shape
+ ), "Mask and cartesian data array shapes must match."
+ assert len(p_ellipse) == 5, "p_ellipse must have length 5"
+
+ # Get params
+ qx0, qy0, a, b, theta = p_ellipse
+ Nx, Ny = cartesianData.shape
+
+ # Define r_range:
+ if r_range is None:
+ # find corners of image
+ corners = np.array(
+ [
+                [0, 0],
+                [0, cartesianData.shape[1]],
+                [cartesianData.shape[0], 0],
+                [cartesianData.shape[0], cartesianData.shape[1]],
+ ]
+ )
+ # find maximum corner distance
+ r_min, r_max = 0, np.ceil(
+ np.max(
+ np.sqrt(
+ np.sum(
+ (corners - np.broadcast_to(np.array((qx0, qy0)), corners.shape))
+ ** 2,
+ axis=1,
+ )
+ )
+ )
+ ).astype(int)
+ else:
+ try:
+ r_min, r_max = r_range[0], r_range[1]
+ except TypeError:
+ r_min, r_max = 0, r_range
+
+ # Define the r/phi coords
+ r_bins = np.arange(r_min + dr / 2.0, r_max + dr / 2.0, dr) # values are bin centers
+ p_bins = np.arange(-np.pi + dphi / 2.0, np.pi + dphi / 2.0, dphi)
+ rr, pp = np.meshgrid(r_bins, p_bins)
+ Nr, Np = rr.shape
+
+ # Get (qx,qy) corresponding to each (r,phi) in the newly defined coords
+ xr = rr * np.cos(pp)
+ yr = rr * np.sin(pp)
+ qx = qx0 + xr * np.cos(theta) - yr * (b / a) * np.sin(theta)
+ qy = qy0 + xr * np.sin(theta) + yr * (b / a) * np.cos(theta)
+
+ # qx,qy are now shape (Nr,Np) arrays, such that (qx[r,phi],qy[r,phi]) is the point
+ # in cartesian space corresponding to r,phi. We now get the values for the final
+ # polarEllipticalData array by interpolating values at these coords from the original
+ # cartesianData array.
+
+ transform_mask = (qx > 0) * (qy > 0) * (qx < Nx - 1) * (qy < Ny - 1)
+
+ # Bilinear interpolation
+ xF = np.floor(qx[transform_mask])
+ yF = np.floor(qy[transform_mask])
+ dx = qx[transform_mask] - xF
+ dy = qy[transform_mask] - yF
+ x_inds = np.vstack((xF, xF + 1, xF, xF + 1)).astype(int)
+ y_inds = np.vstack((yF, yF, yF + 1, yF + 1)).astype(int)
+ weights = np.vstack(
+ ((1 - dx) * (1 - dy), (dx) * (1 - dy), (1 - dx) * (dy), (dx) * (dy))
+ )
+ transform_mask = transform_mask.ravel()
+ polarEllipticalData = np.zeros(Nr * Np)
+ polarEllipticalData[transform_mask] = np.sum(
+ cartesianData[x_inds, y_inds] * weights, axis=0
+ )
+ polarEllipticalData = np.reshape(polarEllipticalData, (Nr, Np))
+
+ # Transform mask
+ polarEllipticalMask = np.zeros(Nr * Np)
+ polarEllipticalMask[transform_mask] = np.sum(mask[x_inds, y_inds] * weights, axis=0)
+ polarEllipticalMask = np.reshape(polarEllipticalMask, (Nr, Np))
+
+ polarEllipticalData = np.ma.array(
+ data=polarEllipticalData, mask=polarEllipticalMask < maskThresh
+ )
+ return polarEllipticalData, rr, pp
+
+
+### Cartesian elliptical transform
+
+
+def elliptical_resample_datacube(
+ datacube,
+ p_ellipse,
+ mask=None,
+ maskThresh=0.99,
+):
+ """
+    Perform elliptical resampling on each diffraction pattern in a DataCube.
+    A detailed description of the args is found in ``elliptical_resample``.
+
+    NOTE: Only use this function if you need to resample the raw data.
+    If you only need the Bragg disk positions to be corrected, use the
+    BraggVector calibration routines, as it is much faster to correct
+    the peak positions than the entire datacube.
+ """
+
+ from emdfile import tqdmnd
+
+ for rx, ry in tqdmnd(datacube.R_Nx, datacube.R_Ny):
+ datacube.data[rx, ry] = elliptical_resample(
+ datacube.data[rx, ry], p_ellipse, mask, maskThresh
+ )
+
+ return datacube
+
+
+def elliptical_resample(
+ data,
+ p_ellipse,
+ mask=None,
+ maskThresh=0.99,
+):
+ """
+ Resamples data with elliptic distortion to correct distortion of the
+ input pattern.
+
+ Discussion of the elliptical parametrization used can be found in the docstring
+ for the process.utils.elliptical_coords module.
+
+ Args:
+ data (2D float array): the data in cartesian coordinates
+ p_ellipse (5-tuple): specifies (qx0,qy0,a,b,theta), the parameters for the
+ transformation. These are the same 5 parameters which are outputs
+ of the elliptical fitting functions in the process.calibration
+ module, e.g. fit_ellipse_amorphous_ring and fit_ellipse_1D. For
+ more details, see the process.utils.elliptical_coords module docstring
+ mask (2d array of bools): shape must match cartesianData; where mask==False,
+ ignore these datapoints in making the polarElliptical data array
+ maskThresh (float): the final data mask is calculated by converting mask (above)
+ from cartesian to polar elliptical coords. Due to interpolation, this
+ results in some non-boolean values - this is converted back to a boolean
+ array by taking polarEllipticalMask = polarTrans(mask) < maskThresh. Cells
+ where polarTrans is less than 1 (i.e. has at least one masked NN) should
+ generally be masked, hence the default value of 0.99.
+
+    Returns:
+        (2D masked array): a masked array containing the resampled data and
+            the resampled data mask
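+
+    Example (a sketch, assuming `dp` is a diffraction pattern and
+    (qx0, qy0, a, b, theta) come from an ellipse fit):
+        corrected = elliptical_resample(dp, p_ellipse=(qx0, qy0, a, b, theta))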
+ """
+ if mask is None:
+ mask = np.ones_like(data, dtype=bool)
+ assert data.shape == mask.shape, "Mask and data array shapes must match."
+ assert len(p_ellipse) == 5, "p_ellipse must have length 5"
+
+ # Expand params
+ qx0, qy0, a, b, theta = p_ellipse
+ Nx, Ny = data.shape
+
+ # Get (qx,qy) corresponding to the coordinates distorted by the ellipse
+ xr, yr = np.mgrid[0:Nx, 0:Ny]
+    xr0 = xr.astype(np.float64) - qx0
+    yr0 = yr.astype(np.float64) - qy0
+ xr = xr0 * np.cos(-theta) - yr0 * np.sin(-theta)
+ yr = xr0 * np.sin(-theta) + yr0 * np.cos(-theta)
+ qx = qx0 + xr * np.cos(theta) - yr * (b / a) * np.sin(theta)
+ qy = qy0 + xr * np.sin(theta) + yr * (b / a) * np.cos(theta)
+
+ # qx,qy are now shape (Nx,Ny) arrays, such that (qx[x,y],qy[x,y]) is the point
+ # in the distorted space corresponding to x,y. We now get the values for the final
+ # resampled_data array by interpolating values at these coords from the original
+ # data array.
+
+ transform_mask = (qx > 0) * (qy > 0) * (qx < Nx - 1) * (qy < Ny - 1)
+
+ # Bilinear interpolation
+ xF = np.floor(qx[transform_mask])
+ yF = np.floor(qy[transform_mask])
+ dx = qx[transform_mask] - xF
+ dy = qy[transform_mask] - yF
+ x_inds = np.vstack((xF, xF + 1, xF, xF + 1)).astype(int)
+ y_inds = np.vstack((yF, yF, yF + 1, yF + 1)).astype(int)
+ weights = np.vstack(
+ ((1 - dx) * (1 - dy), (dx) * (1 - dy), (1 - dx) * (dy), (dx) * (dy))
+ )
+ transform_mask = transform_mask.ravel()
+ resampled_data = np.zeros(Nx * Ny)
+ resampled_data[transform_mask] = np.sum(data[x_inds, y_inds] * weights, axis=0)
+ resampled_data = np.reshape(resampled_data, (Nx, Ny))
+
+ # Transform mask
+ data_mask = np.zeros(Nx * Ny)
+ data_mask[transform_mask] = np.sum(mask[x_inds, y_inds] * weights, axis=0)
+ data_mask = np.reshape(data_mask, (Nx, Ny))
+
+ resampled_data = np.ma.array(data=resampled_data, mask=data_mask < maskThresh)
+ return resampled_data
+
+
+### Radial integration
+
+
+def radial_elliptical_integral(
+ ar,
+ dr,
+ p_ellipse,
+ rmax=None,
+):
+ """
+ Computes the radial integral of array ar from center (x0,y0) with a step size in r of
+ dr.
+
+ Args:
+ ar (2d array): the data
+ dr (number): the r sampling
+ p_ellipse (5-tuple): the parameters (x0,y0,a,b,theta) for the ellipse
+        rmax (float): maximum radial value
+
+ Returns:
+ (2-tuple): A 2-tuple containing:
+
+        * **rbin_centers**: *(1d array)* the bin centers of the radial integral
+        * **radial_integral**: *(1d array)* the radial integral
+ """
+ x0, y0 = p_ellipse[0], p_ellipse[1]
+ if rmax is None:
+ rmax = int(
+ max(
+ (
+ np.hypot(x0, y0),
+ np.hypot(x0, ar.shape[1] - y0),
+ np.hypot(ar.shape[0] - x0, y0),
+ np.hypot(ar.shape[0] - x0, ar.shape[1] - y0),
+ )
+ )
+ )
+
+ polarAr, rr, pp = cartesian_to_polarelliptical_transform(
+ ar, p_ellipse=p_ellipse, dr=dr, dphi=np.radians(2), r_range=rmax
+ )
+ radial_integral = np.sum(polarAr, axis=0)
+ rbin_centers = rr[0, :]
+ return rbin_centers, radial_integral
+
+
+def radial_integral(ar, x0=None, y0=None, dr=0.1, rmax=None):
+ """
+ Computes the radial integral of array ar from center (x0,y0) with a step size in r of dr.
+
+ Args:
+ ar (2d array): the data
+ x0,y0 (floats): the origin
+ dr (number): radial step size
+ rmax (float): maximum radial dimension
+
+ Returns:
+ (2-tuple): A 2-tuple containing:
+
+        * **rbin_centers**: *(1d array)* the bin centers of the radial integral
+ * **radial_integral**: *(1d array)* the radial integral
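+
+    Example (illustrative, assuming `dp` is a 2D diffraction pattern):
+        rbins, intensity = radial_integral(dp, x0=64.2, y0=63.7, dr=0.5)
+        # rbins holds the bin centers, intensity the azimuthally-summed signal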
+ """
+
+ # Default values
+ if x0 is None:
+ x0 = ar.shape[0] / 2
+ if y0 is None:
+ y0 = ar.shape[1] / 2
+
+ if rmax is None:
+ return radial_elliptical_integral(ar, dr, (x0, y0, 1, 1, 0))
+ else:
+ return radial_elliptical_integral(ar, dr, (x0, y0, 1, 1, 0), rmax=rmax)
diff --git a/py4DSTEM/process/utils/masks.py b/py4DSTEM/process/utils/masks.py
new file mode 100644
index 000000000..c6800edc9
--- /dev/null
+++ b/py4DSTEM/process/utils/masks.py
@@ -0,0 +1,78 @@
+# Functions for generating masks
+
+import numpy as np
+from scipy.ndimage import binary_dilation
+
+
+def get_beamstop_mask(dp, qx0, qy0, theta, dtheta=1, w=10, r=10):
+ """
+ Generates a beamstop shaped mask.
+
+ Args:
+ dp (2d array): a diffraction pattern
+ qx0,qy0 (numbers): the center position of the beamstop
+ theta (number): the orientation of the beamstop, in degrees
+ dtheta (number): angular span of the wedge representing the beamstop, in degrees
+ w (integer): half the width of the beamstop arm, in pixels
+ r (number): the radius of a circle at the end of the beamstop, in pixels
+
+ Returns:
+ (2d boolean array): the mask
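+
+    Example (illustrative, assuming `dp` is a 2D diffraction pattern):
+        mask = get_beamstop_mask(dp, qx0=64, qy0=66, theta=30, dtheta=5, w=8, r=12)
+        dp_masked = np.where(mask, 0, dp)  # zero the pixels under the beamstop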
+ """
+ # Handle inputs
+ theta = np.mod(np.radians(theta), 2 * np.pi)
+ dtheta = np.abs(np.radians(dtheta))
+
+ # Get a meshgrid
+ Q_Nx, Q_Ny = dp.shape
+ qyy, qxx = np.meshgrid(np.arange(Q_Ny), np.arange(Q_Nx))
+ qyy, qxx = qyy - qy0, qxx - qx0
+
+ # wedge handles
+ if dtheta > 0:
+ qzz = qxx + qyy * 1j
+ phi = np.mod(np.angle(qzz), 2 * np.pi)
+ # Handle the branch cut in the complex plane
+ if theta - dtheta < 0:
+ phi, theta = np.mod(phi + dtheta, 2 * np.pi), theta + dtheta
+ elif theta + dtheta > 2 * np.pi:
+ phi, theta = np.mod(phi - dtheta, 2 * np.pi), theta - dtheta
+ mask1 = np.abs(phi - theta) < dtheta
+ if w > 0:
+ mask1 = binary_dilation(mask1, iterations=w)
+
+    # straight handles
+    else:
+        # dtheta == 0: no wedge arm; only the circular endcap below contributes
+        mask1 = np.zeros_like(dp, dtype=bool)
+
+ # circle mask
+ qrr = np.hypot(qxx, qyy)
+ mask2 = qrr < r
+
+ # combine masks
+ mask = np.logical_or(mask1, mask2)
+
+ return mask
+
+
+def make_circular_mask(shape, qxy0, radius):
+ """
+    Create a hard circular mask, for use in DPC integration
+    or as a filter in diffraction or real space.
+
+ Args:
+ shape (2-tuple of ints) image size, in pixels
+ qxy0 (2-tuple of floats) center coordinates, in pixels. Must be in (row, column) format.
+ radius (float) radius of mask, in pixels
+
+ Returns:
+ mask (2D boolean array) the mask
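+
+    Example (illustrative; a simple virtual detector built from the mask,
+    assuming `datacube.data` has shape (R_Nx, R_Ny, 128, 128)):
+        mask = make_circular_mask((128, 128), (64.0, 64.0), radius=20)
+        virtual_image = (datacube.data * mask).sum(axis=(-2, -1))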
+
+ """
+ # coordinates
+ qx = np.arange(shape[0]) - qxy0[0]
+ qy = np.arange(shape[1]) - qxy0[1]
+ [qya, qxa] = np.meshgrid(qy, qx)
+
+ # return circular mask
+ return qxa**2 + qya**2 < radius**2
diff --git a/py4DSTEM/process/utils/multicorr.py b/py4DSTEM/process/utils/multicorr.py
new file mode 100644
index 000000000..bc07390bb
--- /dev/null
+++ b/py4DSTEM/process/utils/multicorr.py
@@ -0,0 +1,200 @@
+"""
+loosely based on multicorr.py found at:
+https://github.com/ercius/openNCEM/blob/master/ncempy/algo/multicorr.py
+
+modified by SEZ, May 2019 to integrate with py4DSTEM utility functions
+ * rewrote upsampleFFT (previously did not work correctly)
+ * modified upsampled_correlation to accept xyShift, the point around which to
+ upsample the DFT
+ * eliminated the factor-2 FFT upsample step in favor of using parabolic
+ for first-pass subpixel (since parabolic is so fast)
+ * rewrote the matrix multiply DFT to be more pythonic
+"""
+
+import numpy as np
+
+try:
+ import cupy as cp
+except ModuleNotFoundError:
+ cp = np
+
+
+def upsampled_correlation(imageCorr, upsampleFactor, xyShift, device="cpu"):
+ """
+ Refine the correlation peak of imageCorr around xyShift by DFT upsampling.
+
+ There are two approaches to Fourier upsampling for subpixel refinement: (a) one
+ can pad an (appropriately shifted) FFT with zeros and take the inverse transform,
+ or (b) one can compute the DFT by matrix multiplication using modified
+ transformation matrices. The former approach is straightforward but requires
+ performing the FFT algorithm (which is fast) on very large data. The latter method
+ trades one speedup for a slowdown elsewhere: the matrix multiply steps are expensive
+ but we operate on smaller matrices. Since we are only interested in a very small
+ region of the FT around a peak of interest, we use the latter method to get
+ a substantial speedup and enormous decrease in memory requirement. This
+ "DFT upsampling" approach computes the transformation matrices for the matrix-
+ multiply DFT around a small 1.5px wide region in the original `imageCorr`.
+
+ Following the matrix multiply DFT we use parabolic subpixel fitting to
+ get even more precision! (below 1/upsampleFactor pixels)
+
+ NOTE: previous versions of multiCorr operated in two steps: using the zero-
+ padding upsample method for a first-pass factor-2 upsampling, followed by the
+ DFT upsampling (at whatever user-specified factor). I have implemented it
+ differently, to better support iterating over multiple peaks. **The DFT is always
+ upsampled around xyShift, which MUST be specified to HALF-PIXEL precision
+ (no more, no less) to replicate the behavior of the factor-2 step.**
+ (It is possible to refactor this so that peak detection is done on a Fourier
+ upsampled image rather than using the parabolic subpixel and rounding as now...
+ I like keeping it this way because all of the parameters and logic will be identical
+ to the other subpixel methods.)
+
+
+ Args:
+ imageCorr (complex valued ndarray):
+ Complex product of the FFTs of the two images to be registered
+ i.e. m = np.fft.fft2(DP) * probe_kernel_FT;
+ imageCorr = np.abs(m)**(corrPower) * np.exp(1j*np.angle(m))
+ upsampleFactor (int):
+ Upsampling factor. Must be greater than 2. (To do upsampling
+ with factor 2, use upsampleFFT, which is faster.)
+ xyShift:
+ Location in original image coordinates around which to upsample the
+ FT. This should be given to exactly half-pixel precision to
+ replicate the initial FFT step that this implementation skips
+
+ Returns:
+ (2-element np array): Refined location of the peak in image coordinates.
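+
+    Example (a sketch; `im` and `template` are assumed same-shape 2D arrays,
+    and the half-pixel xyShift would come from a coarse peak search):
+        m = np.fft.fft2(im) * np.conj(np.fft.fft2(template))
+        imageCorr = np.abs(m) ** 0.5 * np.exp(1j * np.angle(m))
+        peak = upsampled_correlation(imageCorr, 16, np.array([32.5, 17.0]))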
+ """
+
+ if device == "cpu":
+ xp = np
+ elif device == "gpu":
+ xp = cp
+
+ assert upsampleFactor > 2
+
+ xyShift[0] = xp.round(xyShift[0] * upsampleFactor) / upsampleFactor
+ xyShift[1] = xp.round(xyShift[1] * upsampleFactor) / upsampleFactor
+
+ globalShift = xp.fix(xp.ceil(upsampleFactor * 1.5) / 2)
+
+ upsampleCenter = xp.asarray(globalShift - upsampleFactor * xyShift)
+
+ imageCorrUpsample = xp.conj(
+ dftUpsample(xp.conj(imageCorr), upsampleFactor, upsampleCenter, device=device)
+ )
+
+ xySubShift = xp.asarray(
+ xp.unravel_index(imageCorrUpsample.argmax(), imageCorrUpsample.shape)
+ )
+
+ # add a subpixel shift via parabolic fitting
+    try:
+        icc = xp.real(
+            imageCorrUpsample[
+                xySubShift[0] - 1 : xySubShift[0] + 2,
+                xySubShift[1] - 1 : xySubShift[1] + 2,
+            ]
+        )
+        dx = (icc[2, 1] - icc[0, 1]) / (4 * icc[1, 1] - 2 * icc[2, 1] - 2 * icc[0, 1])
+        dy = (icc[1, 2] - icc[1, 0]) / (4 * icc[1, 1] - 2 * icc[1, 2] - 2 * icc[1, 0])
+    except IndexError:
+        # the peak is at the edge of the upsampled region, so the full 3x3
+        # neighborhood needed for the parabolic fit does not exist
+        dx, dy = 0, 0
+
+ xySubShift = xySubShift - globalShift
+
+ xyShift = xyShift + (xySubShift + xp.array([dx, dy])) / upsampleFactor
+
+ return xyShift
+
+
+def upsampleFFT(cc, device="cpu"):
+ """
+ Zero-padding FFT upsampling. Returns the real IFFT of the input with 2x
+ upsampling. This may have an error for matrices with an odd size. Takes
+ a complex np array as input.
+ """
+ if device == "cpu":
+ xp = np
+ elif device == "gpu":
+ xp = cp
+
+ sz = cc.shape
+ ups = xp.zeros((sz[0] * 2, sz[1] * 2), dtype=complex)
+
+ ups[: int(np.ceil(sz[0] / 2)), : int(np.ceil(sz[1] / 2))] = cc[
+ : int(np.ceil(sz[0] / 2)), : int(np.ceil(sz[1] / 2))
+ ]
+ ups[-int(np.ceil(sz[0] / 2)) :, : int(np.ceil(sz[1] / 2))] = cc[
+ -int(np.ceil(sz[0] / 2)) :, : int(np.ceil(sz[1] / 2))
+ ]
+ ups[: int(np.ceil(sz[0] / 2)), -int(np.ceil(sz[1] / 2)) :] = cc[
+ : int(np.ceil(sz[0] / 2)), -int(np.ceil(sz[1] / 2)) :
+ ]
+ ups[-int(np.ceil(sz[0] / 2)) :, -int(np.ceil(sz[1] / 2)) :] = cc[
+ -int(np.ceil(sz[0] / 2)) :, -int(np.ceil(sz[1] / 2)) :
+ ]
+
+ return xp.real(xp.fft.ifft2(ups))
+
+
+def dftUpsample(imageCorr, upsampleFactor, xyShift, device="cpu"):
+ """
+    This performs a matrix multiply DFT around a small neighboring region of the initial
+    correlation peak. By using the matrix multiply DFT to do the Fourier upsampling, the
+    efficiency is greatly improved. This is adapted from the subfunction dftups found in
+ the dftregistration function on the Matlab File Exchange.
+
+ https://www.mathworks.com/matlabcentral/fileexchange/18401-efficient-subpixel-image-registration-by-cross-correlation
+
+ The matrix multiplication DFT is from:
+
+ Manuel Guizar-Sicairos, Samuel T. Thurman, and James R. Fienup, "Efficient subpixel
+ image registration algorithms," Opt. Lett. 33, 156-158 (2008).
+ http://www.sciencedirect.com/science/article/pii/S0045790612000778
+
+ Args:
+ imageCorr (complex valued ndarray):
+ Correlation image between two images in Fourier space.
+ upsampleFactor (int):
+ Scalar integer of how much to upsample.
+ xyShift (list of 2 floats):
+ Coordinates in the UPSAMPLED GRID around which to upsample.
+ These must be single-pixel IN THE UPSAMPLED GRID
+
+ Returns:
+ (ndarray):
+ Upsampled image from region around correlation peak.
+ """
+ if device == "cpu":
+ xp = np
+ elif device == "gpu":
+ xp = cp
+
+ imageSize = imageCorr.shape
+ pixelRadius = 1.5
+ numRow = np.ceil(pixelRadius * upsampleFactor)
+ numCol = numRow
+
+ colKern = xp.exp(
+ (-1j * 2 * np.pi / (imageSize[1] * upsampleFactor))
+ * xp.outer(
+ (xp.fft.ifftshift((xp.arange(imageSize[1]))) - xp.floor(imageSize[1] / 2)),
+ (xp.arange(numCol) - xyShift[1]),
+ )
+ )
+
+ rowKern = xp.exp(
+ (-1j * 2 * np.pi / (imageSize[0] * upsampleFactor))
+ * xp.outer(
+ (xp.arange(numRow) - xyShift[0]),
+ (xp.fft.ifftshift(xp.arange(imageSize[0])) - xp.floor(imageSize[0] / 2)),
+ )
+ )
+
+ imageUpsample = xp.real(rowKern @ imageCorr @ colKern)
+ return imageUpsample
diff --git a/py4DSTEM/process/utils/scattering_factors.txt b/py4DSTEM/process/utils/scattering_factors.txt
new file mode 100644
index 000000000..1e33df156
--- /dev/null
+++ b/py4DSTEM/process/utils/scattering_factors.txt
@@ -0,0 +1,103 @@
+6.47384848835291790e-03 2.78519885379148890e+00 -4.90192576780229040e-01 2.77620428330644750e+00 5.73284160390876480e-01 2.77538591050625130e+00 -3.79403301483990480e-01 2.76759302867258810e+00 5.54426474774079140e-01 2.76511897642927540e+00
+3.05745116099835460e+00 1.08967248726078810e+00 -6.20044779127325260e+01 9.39838798143121100e-01 6.40055537084614490e+01 9.25289034386265530e-01 -5.00132578542780590e+00 8.22947498708650580e-01 1.51798828700526440e-01 5.77393110675402220e-01
+3.92622272886147930e+00 8.14276013517280360e+00 -4.54861962639998030e+00 4.98941077007855770e+00 2.19335312878658510e+00 4.14428999239410880e+00 6.99451265033965710e-02 4.01922315065680210e-01 2.09864224851937560e-03 1.56479034719823580e-01
+3.39824970557054050e+00 4.44270178622409520e+00 -1.90866886095696660e+00 3.32451542526422950e+00 3.90702117539227370e-02 1.89772880348214850e-01 -1.11631010210714520e-02 8.71918614644603550e-02 9.46204465357523140e-03 8.27809060041340560e-02
+1.47279248639329290e+00 3.74974048281819130e+00 -4.01933042199387140e-01 5.88066536139673750e-01 3.05998956982689360e-01 5.15639613103010990e-01 1.96144217173168000e-02 1.21377570080603680e-01 9.77177106088202540e-04 6.80982412160313910e-02
+1.24466088621343300e+02 2.42120849256005630e+00 -2.20352857078963780e+02 2.30537943752425800e+00 1.95235352280479130e+02 2.04851932106564230e+00 -9.81079361269799650e+01 1.93352552917547380e+00 1.42023041213623150e-02 7.68976818478339650e-02
+5.81327150702556140e+01 1.70044856413471050e+00 -1.47542409087812730e+02 1.55903852601740440e+00 1.30143065649639450e+02 1.41576827473146880e+00 -3.96195674084154280e+01 1.27841818205455790e+00 1.05957763331480850e-02 5.65587798474805540e-02
+2.99474045242362440e+01 1.30283987880010680e+00 -7.76101266255278260e+01 1.15794105258309530e+00 9.98817764623144200e+01 1.00988549338025120e+00 -5.12127005505673050e+01 9.43327971433265970e-01 8.19618954446032010e-03 4.33197611321825550e-02
+9.48984894503524750e-01 1.45882933198645910e+00 -3.01333923043554930e+01 6.88779993187680020e-01 5.27965078127338640e+01 6.54239869346695650e-01 -2.27062703795272430e+01 6.14836130811994290e-01 6.56997664531440950e-03 3.42837419495011160e-02
+5.82741192220907370e-01 1.28118573143877200e+00 3.70676561841054910e-01 4.44520897170477600e-01 -5.46744967350809240e-01 1.98650875510481020e-01 4.14052682480208050e-01 1.85477246656276460e-01 5.19903080863993140e-03 2.75738382033885800e-02
+2.36700603946792610e+01 8.45148773514603140e+00 -2.18531786159742470e+01 8.04096600474298210e+00 5.92499448108946390e-01 6.24996000526314990e-01 -2.44652290310244000e-02 1.32450394947296380e-01 4.83950221706535770e-03 2.33994362049878630e-02
+4.85501047687149520e+00 5.94639273842456450e+00 -2.66220906476843670e+00 4.17130312520697900e+00 4.78001236085108030e-01 3.98269808150374270e-01 -7.02307064692064830e-02 1.61886185837474190e-01 3.98905828104019820e-03 1.95345056363103450e-02
+2.83409561607507500e+00 6.66235023980533200e+00 -4.28004133378261020e+00 5.51294722224021430e-01 4.42191680548311440e+00 5.09328963445973780e-01 -3.45774471896400580e-02 1.11784837425331210e-01 3.52385941406079890e-03 1.67602351805257390e-02
+2.87189142611612350e+00 5.08487103642989610e+00 -2.06173501195173530e+00 4.29178185305126190e-01 2.17114024204478720e+00 3.66485434192162230e-01 -6.63073633058801900e-02 1.19710611296903400e-01 3.01070709670513740e-03 1.43994536128397540e-02
+2.79151840023151410e+00 3.90065961865466140e+00 -4.36506837823822110e+00 3.29825968377150000e-01 4.43558455516699010e+00 3.06089956505888550e-01 -8.09635773399473850e-02 1.08083232545972810e-01 2.67900017966440440e-03 1.25894495331186820e-02
+2.67971415610199460e+00 3.06889121199971140e+00 -4.74252822230755160e-01 3.78216702185809000e-01 5.14835948989687650e-01 1.88721811902548830e-01 -9.58360024990722450e-02 9.23370590031494100e-02 2.48871963818935240e-03 1.11920877241144090e-02
+2.56624839980020260e+00 2.41594920365612390e+00 -3.38876350828591740e-01 4.21414239310215990e-01 1.14584558755515010e+00 1.09592404975830310e-01 -9.23109316547079620e-01 9.90955458226752960e-02 2.29168002041042150e-03 9.99665948927521040e-03
+2.45981746414068510e+00 1.94004631988856560e+00 -3.64198177076995090e-01 3.99241067884398780e-01 2.50584477222474460e-01 1.17472406274412200e-01 -5.77437029544345810e-02 5.67803726023621840e-02 2.30143866828583670e-03 9.15579832920707970e-03
+5.81107878601454790e+00 1.26691483399036820e+01 -5.02537096539422590e+01 3.95641039698166490e+00 4.88609412059842470e+01 3.68385059577154510e+00 7.40628592048382940e-02 1.07458517569562820e-01 7.27802738626221870e-04 6.65576789391501140e-03
+2.11781161524159530e+01 6.39608619431736170e+00 -3.39043824317468880e+02 3.74024713891749450e+00 3.22756958523296650e+02 3.64888449922605270e+00 6.50077673896996690e-02 9.45090634514673540e-02 6.55874366578593520e-04 5.98520619883758600e-03
+1.26035186572148630e+01 6.15625615363852940e+00 -2.76875382053702590e+02 3.08873554266679310e+00 2.68871603907342830e+02 3.02727663298548320e+00 5.56824178897458030e-02 8.18874748375215680e-02 5.77071255145399560e-04 5.38289832050797570e-03
+8.57595775238129930e+00 6.00780668875581810e+00 -2.10331563465304870e+02 2.60285856745213010e+00 2.06097172601541360e+02 2.55352345051105620e+00 4.77773948977261930e-02 7.11429484024153900e-02 5.05716484480320570e-04 4.85628439383726120e-03
+6.52768433234788950e+00 5.83552479352448120e+00 -2.00430576829172650e+02 2.23255952381127630e+00 1.98015053889994020e+02 2.19786018594228990e+00 4.13911518064005580e-02 6.23973876828547300e-02 4.47455024329856520e-04 4.40383649079958020e-03
+3.02831784843691310e+00 8.35911504314631770e+00 -9.55393933081432320e+01 1.80263790264103240e+00 9.61761562352198070e+01 1.77509488957047210e+00 3.59777315957987700e-02 5.48144412143113730e-02 3.91492890718422320e-04 3.99828968916034810e-03
+4.37417550633122690e+00 5.51031705504928220e+00 -1.60925510918779510e+02 1.68798216402433910e+00 1.60273308060322340e+02 1.66614047777702390e+00 3.12303861049162410e-02 4.83370390321277020e-02 3.46966021020938010e-04 3.64746960053068160e-03
+3.79810090836859620e+00 5.31712645899433060e+00 -9.16893549381687190e+01 1.49713094884748130e+00 9.14454252155429830e+01 1.46809241804410510e+00 2.72754344027604720e-02 4.27247850128954910e-02 3.03379854383771030e-04 3.32791855231874110e-03
+3.33037874467544270e+00 5.18135964579543180e+00 -7.70017596472967090e+01 1.32915122265139420e+00 7.70725221790516880e+01 1.30284928963228920e+00 2.39904669040569940e-02 3.80645486826370960e-02 2.68256665523942030e-04 3.05010007991570780e-03
+2.96908078725286370e+00 5.04180949109380010e+00 -7.57477069129040220e+01 1.18275507921629310e+00 7.60398287625307320e+01 1.16216545846629930e+00 2.10162113692966970e-02 3.37479087883035340e-02 2.31151751135153040e-04 2.78680862070740520e-03
+1.75207145212145600e+00 6.18750497986187130e+00 -4.30410523492124500e+01 1.00266263628976620e+00 4.40705915543573850e+01 9.85384311353030280e-01 1.86876154088144730e-02 3.02984703916117610e-02 2.01727324791624140e-04 2.55855598748879050e-03
+2.46637110499459130e+00 4.91028078493815910e+00 -6.14678541332537450e+01 9.67898520322992170e-01 6.20176945237481260e+01 9.51283834775355500e-01 1.64160173931416210e-02 2.69600967667566260e-02 1.72487117595380870e-04 2.34109611046253710e-03
+2.76010203108428340e+00 6.10128224537662690e+00 -3.44452614207467890e+01 7.65143313553464880e-01 3.52262267244016340e+01 7.51328623382859660e-01 1.32067196999419880e-02 2.24879634317251170e-02 1.25945560928245720e-04 2.06737374278712340e-03
+3.18241635260000020e+00 5.01719040860914680e+00 -5.24514037811166670e+01 7.12395764437798170e-01 5.29690827162218480e+01 7.02280192528193960e-01 1.14096168591864820e-02 1.96747295694015350e-02 9.50954358145057910e-05 1.84146614494011450e-03
+3.45642969119603950e+00 4.01358016032945210e+00 -3.33176044431721220e+01 6.62355778050629060e-01 3.35712193855332190e+01 6.45771941056073830e-01 9.79002295640342070e-03 1.70919353231011460e-02 6.53434862265172480e-05 1.60301602839420070e-03
+3.64905047801926410e+00 3.25043267112593390e+00 -4.36851662221238610e+01 6.09666201650288950e-01 4.36920288601154780e+01 5.96971300802123240e-01 8.44902284199144410e-03 1.48554512740362250e-02 3.78611474110038160e-05 1.33625587356563020e-03
+3.83846312224289490e+00 2.61189470573232270e+00 -5.22723471011233870e+01 5.66195062874718210e-01 5.19861279495659620e+01 5.55279326699873900e-01 7.33955989338044760e-03 1.29746469432440690e-02 1.64694207951339920e-05 1.02986536855875730e-03
+4.02541030310827440e+00 2.13648381443739990e+00 -4.63042332078929770e+01 5.26591166454131180e-01 4.57213681904101180e+01 5.14136784464577560e-01 6.35337959625315720e-03 1.12807242242006810e-02 1.33477845255742950e-06 4.88089857940970360e-04
+4.77092509299825980e+00 1.33668881330412750e+01 1.47597850155283420e+00 1.33738379577196680e+00 3.04451355544188670e-01 1.77532369437782830e-01 3.59474981916728340e-03 7.79105033001826780e-03 3.00085549228278960e-07 2.82255139648537230e-04
+3.38975351595393980e+00 2.05744814368171130e+01 2.14348348679172410e+00 1.91079945218516370e+00 3.54322603510978160e-01 1.97410589396603890e-01 3.74009340085664200e-03 8.13459465343988920e-03 3.00342502342229660e-07 2.92685791056951060e-04
+4.60721019875234020e+00 1.08686905557151850e+01 1.42801851039840270e+00 1.31137455873124110e+00 2.95581045577753830e-01 1.68022870784711150e-01 3.38997840852895120e-03 7.35964545350478290e-03 2.66862974969022400e-07 2.62343281021387260e-04
+4.31175453406871600e+00 9.45896580517490190e+00 1.49331578039339450e+00 1.33063622845245620e+00 2.81236050128884140e-01 1.56506871444453550e-01 3.09337738455747140e-03 6.82410523811735180e-03 2.58024448512530600e-07 2.49818879110702270e-04
+3.11179039113495380e+00 1.06903141343023660e+01 2.20259060945803050e+00 1.65316356158933140e+00 2.70330774910020890e-01 1.45115185718969860e-01 2.68799194751098100e-03 6.13956351582850470e-03 2.32549483347779190e-07 2.32240235937672540e-04
+2.83105968432053560e+00 1.04357195758950730e+01 2.34858137489674460e+00 1.60482868674597220e+00 2.45105888429696440e-01 1.31696934774606920e-01 2.35282108218515860e-03 5.54977901383345900e-03 2.31270838343859220e-07 2.22747463764897690e-04
+2.57179859323352570e+00 1.01643117131377530e+01 2.45633741957203040e+00 1.53441919244550480e+00 2.20658440881371460e-01 1.19138619754110400e-01 2.05531418012438150e-03 5.01853240882988840e-03 2.32132947793692180e-07 2.14590379120861160e-04
+2.33230030353946560e+00 9.92167460159959450e+00 2.53578025489049620e+00 1.45585668808788140e+00 1.98208042136026000e-01 1.07682187873349280e-01 1.76120130460053530e-03 4.47243545654967340e-03 1.98129411546024490e-07 1.96574716232499390e-04
+2.11352534859470030e+00 9.65913725859762270e+00 2.58636316155013770e+00 1.37106656934329620e+00 1.77063913863686530e-01 9.70353027584648780e-02 1.49737051003056920e-03 3.97128509088792110e-03 2.05481445555622420e-07 1.91355190737798320e-04
+6.42159796182687260e-01 5.97479750263406120e+00 2.97914814426328920e+00 1.43359432541277740e+00 1.68154426004270800e-01 9.09868401172676950e-02 1.33744213878407760e-03 3.62410137116162160e-03 1.91410968246861570e-07 1.80688914416045000e-04
+1.55317216680389560e+00 8.15620235758956550e+00 2.63930364699988780e+00 1.21600887481801510e+00 1.42015486956788170e-01 7.90098864915873420e-02 1.00850460099760920e-03 2.96547901363948970e-03 1.94638429973024380e-07 1.74593095919146200e-04
+6.15307851992860150e+01 3.11468102533247300e+00 -7.86016741201582080e+01 2.76016983388574480e+00 2.15501292602770460e+01 1.93551312324722380e+00 1.37685015664191560e-01 7.22468347260293160e-02 3.24644930946521010e-04 1.17001629624684460e-03
+4.22232177901524610e+00 6.07265510403227450e+00 -2.64121318353202560e+01 1.64550178959387900e+00 2.72852852708746920e+01 1.52257074939506750e+00 1.21617901884137960e-01 6.56507914695225600e-02 3.06883546371525800e-04 1.12093851473035640e-03
+5.14222074642053690e+00 5.27272636474484460e+00 -2.54945413762576720e+01 1.53194959209148300e+00 2.57414487548601870e+01 1.40257525040087170e+00 1.11778227592199600e-01 6.11685387263086880e-02 2.93647384737742710e-04 1.07560815394522340e-03
+6.24164031831883650e+00 4.26984108082687540e+00 -9.33868724419552190e+01 1.39407746140250150e+00 9.26332875829884300e+01 1.35566854910434480e+00 1.03412692221298720e-01 5.72667949652199830e-02 2.81848426894390730e-04 1.03307954282067370e-03
+7.37743301813403020e+00 3.46917757783897770e+00 -1.26025106931628140e+02 1.29759810886736600e+00 1.24128405004095550e+02 1.26771067646066030e+00 9.59978371818569760e-02 5.37771750179423110e-02 2.71072216398882020e-04 9.93084577029176810e-04
+9.64400666272123170e+00 2.72645545366544080e+00 -1.22924435350112520e+02 1.23723425850866020e+00 1.18682564816567250e+02 1.20062036872249790e+00 8.95025947025538950e-02 5.06678682285199920e-02 2.61276121490475030e-04 9.55437883383497270e-04
+1.55451749674860600e+01 2.10637340865492680e+00 -1.18241027856744580e+02 1.20860376129512240e+00 1.08009524963196970e+02 1.15395270567214010e+00 8.36259341992221660e-02 4.78189391274824250e-02 2.51991862429073870e-04 9.19886257592613110e-04
+4.28708739181692260e+00 2.26587870798541500e+01 3.23250665422964990e+00 2.23797386470064240e+00 6.74029533561706250e-01 3.68995568664316260e-01 6.18908083387697610e-02 4.02266575298484350e-02 2.35612052952190030e-04 8.83761890878035130e-04
+6.24475187390461530e+00 1.51431354190985630e+01 2.35172271416388460e+00 1.45379000604814030e+00 4.74279373222252000e-01 3.20835646387801770e-01 6.38113874286849870e-02 4.04354532094699790e-02 2.34651280562792400e-04 8.54081131280032660e-04
+6.09788179599509660e+00 1.24288544315854800e+01 2.19495164736675010e+00 1.50535992352023460e+00 5.48172791960022550e-01 3.33738039717124290e-01 6.16669573222662570e-02 3.87744553499998830e-02 2.26807355864714200e-04 8.24049884064866980e-04
+5.79526879647240540e+00 1.42801055058256010e+01 2.37022664107843270e+00 1.35969015719134510e+00 4.71398756901114880e-01 3.02017349663514120e-01 5.74368260587866800e-02 3.66436798147007710e-02 2.18979489259659270e-04 7.95434526518371140e-04
+5.60406255377525750e+00 1.39517490230574650e+01 2.35796259561812920e+00 1.31239754917439220e+00 4.76001098572882360e-01 2.94933701192956530e-01 5.44123374315247100e-02 3.48601542462324140e-02 2.11414602205149750e-04 7.68261682661809690e-04
+5.42908391969770320e+00 1.36503649427659970e+01 2.33687325360880300e+00 1.26759841390319910e+00 4.83373554104881860e-01 2.88610681576929930e-01 5.14654964557478970e-02 3.31291807446609310e-02 2.03776132863761090e-04 7.42348483801266290e-04
+5.26774445089444580e+00 1.33602096827394550e+01 2.30855813326304960e+00 1.22585856586941720e+00 4.93265479026012750e-01 2.82919632127866860e-01 4.86356279741696550e-02 3.14735397127137930e-02 1.96308842320675220e-04 7.17658025069615150e-04
+5.12680428517077580e+00 1.31581501026129040e+01 2.26925534008380580e+00 1.18129508243337060e+00 5.04209300453302480e-01 2.77107038219062910e-01 4.56923992416221960e-02 2.97957604549544570e-02 1.88675050493886220e-04 6.94011099199098280e-04
+4.97962359749809200e+00 1.28392663078033800e+01 2.24183087455669480e+00 1.14705446420486030e+00 5.12933961402893710e-01 2.70387160935243730e-01 4.29801848498334440e-02 2.82418714952601960e-02 1.81381692485787560e-04 6.71486603173992420e-04
+5.07835830045611390e+00 1.05132725469047320e+01 1.95744027181658950e+00 1.11764941281524610e+00 5.92825983221375140e-01 2.84386741833675240e-01 4.19502034143925600e-02 2.72663327641347110e-02 1.75241091528316820e-04 6.50310860707303690e-04
+4.71161636657300060e+00 1.23109415948539130e+01 2.17261950786578910e+00 1.08743796767492460e+00 5.39717601342865730e-01 2.59804965509275510e-01 3.79796004547811110e-02 2.53289923895189670e-02 1.66923763564841650e-04 6.29278566961911160e-04
+4.59075504485146620e+00 1.20656740623369880e+01 2.13572373036567380e+00 1.05811737115864360e+00 5.51355560094152540e-01 2.53994438918661800e-01 3.56058406718513800e-02 2.39508739649649150e-02 1.59824016856789790e-04 6.09482748433288060e-04
+4.48400110626710810e+00 1.18749250691570630e+01 2.08904347078539800e+00 1.02828493708789590e+00 5.65932618316104640e-01 2.48907807908075960e-01 3.32703590211063040e-02 2.25844065532147840e-02 1.52445610282892090e-04 5.90374445159424470e-04
+4.37665140786963210e+00 1.16653080717139480e+01 2.04645150383634090e+00 1.00314769675495310e+00 5.80662860591587670e-01 2.44153292686312390e-01 3.12388528327766720e-02 2.13618580000731020e-02 1.45374869662535650e-04 5.72110338454766620e-04
+4.28308318247419530e+00 1.14961905956683400e+01 1.99538001453730710e+00 9.77039102018820600e-01 5.97053571332504250e-01 2.39429132975829070e-01 2.90951668869559870e-02 2.00912233327255780e-02 1.38064769037521620e-04 5.54377138492266260e-04
+4.19563840737123250e+00 1.14107750456652330e+01 1.94333285978645610e+00 9.49011924454877360e-01 6.12446617285465790e-01 2.34950326535304650e-01 2.71515207694932420e-02 1.89051576280975480e-02 1.30594787351929680e-04 5.37233649209711080e-04
+4.35692593296384260e+00 9.29434514718568930e+00 1.69589204777315850e+00 9.10500045957419960e-01 6.63904520068470450e-01 2.38746595911376420e-01 2.63020018760281370e-02 1.82098542546462290e-02 1.25497318499823790e-04 5.21559371963562160e-04
+4.33138405664923450e+00 7.87684433810288360e+00 1.52764864728680740e+00 9.42515642627748780e-01 7.35795922891237960e-01 2.41699480039806010e-01 2.49526323201402640e-02 1.72899894448508980e-02 1.18740852581057760e-04 5.05834631313525580e-04
+4.19726001319641990e+00 6.93674024894457200e+00 1.46854806774708810e+00 1.01725276618348250e+00 7.83931248211088170e-01 2.38948939715070610e-01 2.32494094864395090e-02 1.62332433075381040e-02 1.11261358607285950e-04 4.90239004021149030e-04
+3.97629715819077400e+00 6.29685779228774930e+00 1.52292684257341810e+00 1.11289951157691090e+00 7.97706246390561310e-01 2.31057040678473440e-01 2.13666741558685220e-02 1.51013545567204450e-02 1.03078689385717270e-04 4.74682477105025770e-04
+3.75144381439811880e+00 5.79754636082953830e+00 1.62768802977632740e+00 1.18223631077710680e+00 7.81756755454271810e-01 2.19913586024770210e-01 1.95170803912291560e-02 1.39810455481456640e-02 9.43199804293427490e-05 4.59127125829838610e-04
+3.48401517338711560e+00 5.43998846080953680e+00 1.79377920416965450e+00 1.22792134140128130e+00 7.44878356691920260e-01 2.06208856992909220e-01 1.75427785162434460e-02 1.27899401721146020e-02 8.44872351568859950e-05 4.43032226787478600e-04
+1.59956578198844190e+00 5.79244447385567530e+00 2.97534452194120510e+00 1.55300982973225970e+00 6.95092678366882160e-01 1.88626335936648240e-01 1.48779627458687970e-02 1.11763470662453270e-02 6.90549579101583330e-05 4.22772361655551400e-04
+2.04021563975728260e+00 6.65819429609627460e+00 2.89922634624825460e+00 1.41337923778932420e+00 6.36344083815797660e-01 1.74000104502178950e-01 1.32071919595408320e-02 1.00687804530210510e-02 5.67382192500131220e-05 4.03077105692223100e-04
+1.67593467064870840e+00 5.52231093211402510e+00 3.00486602969729290e+00 1.38007223007196280e+00 5.95340013161635540e-01 1.62229237655945410e-01 1.17163186623094810e-02 9.01814890416575630e-03 4.29678297639817100e-05 3.79277667477667070e-04
+2.23522850443105270e+00 5.02030988960240340e+00 2.68276638651994890e+00 1.23077590583777650e+00 5.55194926212433270e-01 1.52248122992863640e-01 1.07273354366258700e-02 8.28399116906210380e-03 3.28473992046262910e-05 3.56241938931758870e-04
+2.80342737412509990e+00 6.55876872804534190e+00 2.71882787966020030e+00 1.16972422516667860e+00 5.22475915391999560e-01 1.43556691800178100e-01 9.84537836971042040e-03 7.61976526230039810e-03 2.34524526508316210e-05 3.29627673902985420e-04
+3.60861020997771350e+00 6.58162594621962780e+00 2.45056774737198340e+00 1.02772852660588470e+00 4.78639500107256980e-01 1.33533680634720870e-01 8.87214235169215240e-03 6.84861238948429770e-03 1.04001932909467290e-05 2.76388875456666010e-04
+4.24209901057315890e+00 5.75280115536079960e+00 2.09994342055827850e+00 8.73901489195461510e-01 4.32836631379922900e-01 1.23599956180524660e-01 8.02021570381059210e-03 6.17600333058570230e-03 7.21785030851340250e-07 1.41495177617230900e-04
+4.63620095668962320e+00 4.88825299780982900e+00 1.78063311426987080e+00 7.56310526880074720e-01 4.03730443952738490e-01 1.17302364452249240e-01 7.68539554668673950e-03 5.93034394242894160e-03 8.95415902807643600e-08 7.66668663873028710e-05
+4.96592250595070530e+00 4.09129378787496290e+00 1.43815561507033650e+00 6.29228996602551270e-01 3.71299200579298050e-01 1.11096678695535160e-01 7.47256428450201060e-03 5.77235010347761780e-03 1.14115284152585070e-07 8.08511893879382500e-05
+5.30615614475033230e+00 3.48800735481655670e+00 1.11733160236072120e+00 4.81190741149771170e-01 3.15858723110854330e-01 1.01874467945753100e-01 7.20342242310210040e-03 5.59245374417078650e-03 1.07355330545863460e-07 7.79506673540736650e-05
+4.52053399042335970e+00 1.94482234232948810e+01 4.10695397909124620e+00 1.89824673155996850e+00 7.13946878503728510e-01 1.69553563595341790e-01 1.69294027687453900e-02 1.14819567548934220e-02 8.57492129191724410e-05 3.46122038257982700e-04
+6.52401072001873320e+00 1.40092554298974860e+01 3.20787080745661290e+00 1.32635035961665260e+00 5.40478774349637650e-01 1.31408568386036120e-01 8.78278806898853160e-03 6.28647434520409560e-03 6.91010602949131460e-06 2.27839981458950950e-04
+6.89602853559619080e+00 1.10763825670398680e+01 2.83514154536523220e+00 1.17132616290503000e+00 5.03506881888815090e-01 1.23494651391799910e-01 8.02294510966706890e-03 5.73515595790497110e-03 9.20400954758739720e-08 7.19707541802523240e-05
+7.09374900162630160e+00 9.09473795165944670e+00 2.52912373903193990e+00 1.06391667033161210e+00 4.82107819888889010e-01 1.18619446621877890e-01 7.71936674145110980e-03 5.53945708895034650e-03 7.27114199425041410e-08 6.58494730528462700e-05
+6.43401324797284160e+00 1.02596851307297850e+01 2.97099970535788760e+00 1.13177452306163300e+00 4.60796651787848010e-01 1.12775935362382300e-01 7.24033125775455300e-03 5.27459583353132120e-03 6.36236664309127860e-08 6.19263824088759710e-05
+6.21070826784019920e+00 1.00214105859162840e+01 3.03934453825623450e+00 1.10399880155883360e+00 4.37339984439162780e-01 1.07218973162453880e-01 6.80714764147316580e-03 5.02432130850682420e-03 6.18229327742851270e-08 6.00506994197195790e-05
+6.00431598308986470e+00 9.81793046719106140e+00 3.09431654531417970e+00 1.06978705068029440e+00 4.12808487716417040e-01 1.01635291448645790e-01 6.40891657880066860e-03 4.79524846888725040e-03 6.73007376216075430e-08 6.02806598726462900e-05
+5.20061716810304200e+00 1.12038310194406670e+01 3.49849440367144030e+00 1.12846369546928190e+00 4.05414931131320270e-01 9.89333472686131640e-02 6.22343700826529910e-03 4.65917431977950700e-03 6.00859329543440850e-08 5.72590638962334120e-05
+5.02533860036029090e+00 1.09772190697628730e+01 3.51843988285154060e+00 1.08477214832630780e+00 3.81950349446272210e-01 9.36746874853835870e-02 5.82110208096211020e-03 4.42682346873485640e-03 6.52609338829711920e-08 5.73856837050511520e-05
+5.34656100260619650e+00 9.23118379752449410e+00 3.22468466560910820e+00 9.72836229193835170e-01 3.49461752625445210e-01 8.70596220966562280e-02 5.29251756748815110e-03 4.13011964465032340e-03 6.15917630236237360e-08 5.49463894449675660e-05
+5.22582350334004710e+00 9.07135706618718270e+00 3.22818873995295300e+00 9.31124277456634840e-01 3.27098878823851360e-01 8.20720602059845510e-02 4.88882575592239010e-03 3.89029655993416560e-03 5.21272269419363250e-08 5.10367976636641620e-05
+4.58641247753547890e+00 1.03186102092335170e+01 3.51869595665101230e+00 9.57278837829232840e-01 3.19261714201205420e-01 7.96807431443880690e-02 4.72979647985625770e-03 3.77193199996848920e-03 5.51324480806204080e-08 5.09750860376011080e-05
+4.45799480675412600e+00 1.00890693383401260e+01 3.50867212637662320e+00 9.19446447361622620e-01 3.01375380528315420e-01 7.56419177786329250e-02 4.40763746214481880e-03 3.57026465124809290e-03 4.88787903228591640e-08 4.80495174395811320e-05
+4.33897576401170860e+00 9.96330937223836170e+00 3.49185098896403410e+00 8.78083956982974370e-01 2.81471099016286750e-01 7.11637115768777580e-02 4.00209271499046060e-03 3.32152404215077700e-03 5.52929808195514040e-08 4.85103179460372770e-05
+4.22729470403101180e+00 9.72400602040031360e+00 3.47249227510661560e+00 8.42873775960210400e-01 2.64822229496898600e-01 6.73534743974851110e-02 3.69072877833192920e-03 3.12364606263320840e-03 6.25871426200247310e-08 4.91217097017619240e-05
+4.10951702443020390e+00 9.67735994510120180e+00 3.45799132522750740e+00 8.06940042470817190e-01 2.47087351222386680e-01 6.32815436954167060e-02 3.30423920940995170e-03 2.87544639641209180e-03 5.99104926231931580e-08 4.70679153697598100e-05
+4.52147421198378830e+00 8.28309906861142050e+00 3.20212985587804380e+00 7.31918958125393870e-01 2.23028726956451650e-01 5.80942773018654980e-02 2.81716453892029730e-03 2.56168016047444900e-03 4.06427954038620530e-08 4.03816515529006540e-05
\ No newline at end of file
diff --git a/py4DSTEM/process/utils/single_atom_scatter.py b/py4DSTEM/process/utils/single_atom_scatter.py
new file mode 100644
index 000000000..8d6e2a891
--- /dev/null
+++ b/py4DSTEM/process/utils/single_atom_scatter.py
@@ -0,0 +1,90 @@
+import numpy as np
+import os
+
+
+class single_atom_scatter(object):
+ """
+ This class calculates the composition averaged single atom scattering factor for a
+ material. The parameterization is based upon Lobato, Acta Cryst. (2014). A70,
+ 636–649.
+
+    Elements is a 1D array of atomic numbers.
+    Composition is a 1D array, the same length as elements, describing the average
+    atomic composition of the sample. Q_coords is a 1D array of Fourier coordinates,
+    given in inverse Angstroms. Units is a string, either 'VA' or 'A', which sets
+    whether the scattering factor is returned in volt-angstroms or in angstroms.
+ """
+
+ def __init__(self, elements=None, composition=None, q_coords=None, units=None):
+ self.elements = elements
+ self.composition = composition
+ self.q_coords = q_coords
+ self.units = units
+ path = os.path.join(os.path.dirname(__file__), "scattering_factors.txt")
+ self.e_scattering_factors = np.loadtxt(path, dtype=np.float64)
+
+ return
+
+ def electron_scattering_factor(self, Z, gsq, units="A"):
+ ai = self.e_scattering_factors[Z - 1, 0:10:2]
+ bi = self.e_scattering_factors[Z - 1, 1:10:2]
+
+ # Planck's constant in Js
+ h = 6.62607004e-34
+ # Electron rest mass in kg
+ me = 9.10938356e-31
+ # Electron charge in Coulomb
+ qe = 1.60217662e-19
+
+ fe = np.zeros_like(gsq)
+ for i in range(5):
+ fe += ai[i] * (2 + bi[i] * gsq) / (1 + bi[i] * gsq) ** 2
+
+ # Result can be returned in units of Volt Angstrom³ ('VA') or Angstrom ('A')
+ if units == "VA":
+ return h**2 / (2 * np.pi * me * qe) * 1e18 * fe
+ elif units == "A":
+ return fe
+
+ def get_scattering_factor(
+ self, elements=None, composition=None, q_coords=None, units=None
+ ):
+        if elements is None:
+            assert (
+                self.elements is not None
+            ), "Must pass a list of atomic numbers in either class initialization or in call to get_scattering_factor()"
+            elements = self.elements
+
+        if composition is None:
+            assert (
+                self.composition is not None
+            ), "Must pass composition fractions in either class initialization or in call to get_scattering_factor()"
+            composition = self.composition
+
+        if q_coords is None:
+            assert (
+                self.q_coords is not None
+            ), "Must pass a q_space array in either class initialization or in call to get_scattering_factor()"
+            q_coords = self.q_coords
+
+ if units is None:
+ units = self.units
+ if self.units is None:
+ print("Setting output units to Angstroms")
+ units = "A"
+
+ assert len(elements) == len(
+ composition
+ ), "Each element must have an associated composition."
+
+        # normalize composition if passed as stoichiometry instead of atomic fractions
+        # (cast to a float array first, so in-place division works for list inputs)
+        composition = np.asarray(composition, dtype=float)
+        if np.sum(composition) > 1:
+            composition /= np.sum(composition)
+
+ fe = np.zeros_like(q_coords)
+ for i in range(len(elements)):
+ fe += composition[i] * self.electron_scattering_factor(
+ elements[i], np.square(q_coords), units
+ )
+
+ self.fe = fe
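+
+
+# Example usage -- a minimal sketch, not part of the class API. The element and
+# composition values below are illustrative assumptions (GaAs, 50/50):
+#
+#   q = np.linspace(0, 2, 256)  # Fourier coordinates, inverse Angstroms
+#   scatter = single_atom_scatter(
+#       elements=[31, 33], composition=[0.5, 0.5], q_coords=q, units="A"
+#   )
+#   scatter.get_scattering_factor()
+#   fe = scatter.fe  # composition-averaged scattering factor, in Angstroms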
diff --git a/py4DSTEM/process/utils/utils.py b/py4DSTEM/process/utils/utils.py
new file mode 100644
index 000000000..4ef2e1d8a
--- /dev/null
+++ b/py4DSTEM/process/utils/utils.py
@@ -0,0 +1,717 @@
+# Defines utility functions used by other functions in the /process/ directory.
+
+import numpy as np
+from numpy.fft import fftfreq, fftshift
+from scipy.ndimage import gaussian_filter
+from scipy.spatial import Voronoi
+import math as ma
+import matplotlib.pyplot as plt
+from mpl_toolkits.axes_grid1 import make_axes_locatable
+from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
+import matplotlib.font_manager as fm
+
+from emdfile import tqdmnd
+from py4DSTEM.process.utils.multicorr import upsampled_correlation
+from py4DSTEM.preprocess.utils import make_Fourier_coords2D
+
+try:
+ from IPython.display import clear_output
+except ImportError:
+
+ def clear_output(wait=True):
+ pass
+
+
+try:
+ import cupy as cp
+except ModuleNotFoundError:
+ cp = np
+
+
+def radial_reduction(ar, x0, y0, binsize=1, fn=np.mean, coords=None):
+ """
+ Evaluate a reduction function on pixels within annular rings centered on (x0,y0),
+ with a ring width of binsize.
+
+ By default, returns the mean value of pixels within each annulus.
+    Some other useful reductions include: np.sum, np.std, np.count_nonzero, np.median, ...
+
+ When running in a loop, pre-compute the pixel coordinates and pass them in
+ for improved performance, like so:
+ coords = np.mgrid[0:ar.shape[0],0:ar.shape[1]]
+ radial_sums = radial_reduction(ar, x0,y0, coords=coords)
+ """
+    # test against None explicitly: `if coords` is ambiguous for a numpy array
+    qx, qy = (
+        coords if coords is not None else np.mgrid[0 : ar.shape[0], 0 : ar.shape[1]]
+    )
+
+ r = (
+ np.floor(np.hypot(qx - x0, qy - y0).ravel() / binsize).astype(np.int64)
+ * binsize
+ )
+ edges = np.cumsum(np.bincount(r)[::binsize])
+ slices = [slice(0, edges[0])] + [
+ slice(edges[i], edges[i + 1]) for i in range(len(edges) - 1)
+ ]
+ rargsort = np.argsort(r)
+ sorted_ar = ar.ravel()[rargsort]
+ reductions = np.array([fn(sorted_ar[s]) for s in slices])
+
+ return reductions
+
+
+def plot(
+ img,
+ title="Image",
+ savePath=None,
+ cmap="inferno",
+ show=True,
+ vmax=None,
+ figsize=(10, 10),
+ scale=None,
+):
+ fig, ax = plt.subplots(figsize=figsize)
+ im = ax.imshow(img, interpolation="nearest", cmap=plt.get_cmap(cmap), vmax=vmax)
+ divider = make_axes_locatable(ax)
+ cax = divider.append_axes("right", size="5%", pad=0.05)
+ plt.colorbar(im, cax=cax)
+ ax.set_title(title)
+ fontprops = fm.FontProperties(size=18)
+ if scale is not None:
+ scalebar = AnchoredSizeBar(
+ ax.transData,
+ scale[0],
+ scale[1],
+ "lower right",
+ pad=0.1,
+ color="white",
+ frameon=False,
+ size_vertical=img.shape[0] / 40,
+ fontproperties=fontprops,
+ )
+
+ ax.add_artist(scalebar)
+ ax.grid(False)
+ if savePath is not None:
+ fig.savefig(savePath + ".png", dpi=600)
+ fig.savefig(savePath + ".eps", dpi=600)
+ if show:
+ plt.show()
+
+
+def electron_wavelength_angstrom(E_eV):
+ m = 9.109383 * 10**-31
+ e = 1.602177 * 10**-19
+ c = 299792458
+ h = 6.62607 * 10**-34
+
+ lam = (
+ h
+ / ma.sqrt(2 * m * e * E_eV)
+ / ma.sqrt(1 + e * E_eV / 2 / m / c**2)
+ * 10**10
+ )
+ return lam
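+
+# Sanity check against standard textbook values for the relativistic electron
+# wavelength (well-known reference numbers, not outputs verified here):
+#   electron_wavelength_angstrom(80e3)  ~ 0.0418 Angstroms
+#   electron_wavelength_angstrom(200e3) ~ 0.0251 Angstroms
+#   electron_wavelength_angstrom(300e3) ~ 0.0197 Angstroms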
+
+
+def electron_interaction_parameter(E_eV):
+ m = 9.109383 * 10**-31
+ e = 1.602177 * 10**-19
+ c = 299792458
+ h = 6.62607 * 10**-34
+ lam = (
+ h
+ / ma.sqrt(2 * m * e * E_eV)
+ / ma.sqrt(1 + e * E_eV / 2 / m / c**2)
+ * 10**10
+ )
+ sigma = (
+ (2 * np.pi / lam / E_eV) * (m * c**2 + e * E_eV) / (2 * m * c**2 + e * E_eV)
+ )
+ return sigma
+
+
+def sector_mask(shape, centre, radius, angle_range=(0, 360)):
+ """
+ Return a boolean mask for a circular sector. The start/stop angles in
+ `angle_range` should be given in clockwise order.
+
+ Args:
+ shape: 2D shape of the mask
+ centre: 2D center of the circular sector
+ radius: radius of the circular mask
+ angle_range: angular range of the circular mask
+ """
+ x, y = np.ogrid[: shape[0], : shape[1]]
+ cx, cy = centre
+ tmin, tmax = np.deg2rad(angle_range)
+
+ # ensure stop angle > start angle
+ if tmax < tmin:
+ tmax += 2 * np.pi
+
+ # convert cartesian --> polar coordinates
+ r2 = (x - cx) * (x - cx) + (y - cy) * (y - cy)
+ theta = np.arctan2(x - cx, y - cy) - tmin
+
+ # wrap angles between 0 and 2*pi
+ theta %= 2 * np.pi
+
+ # circular mask
+ circmask = r2 <= radius * radius
+
+ # angular mask
+ anglemask = theta < (tmax - tmin)
+
+ return circmask * anglemask
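+
+# Example -- a minimal sketch: boolean mask over the 0-90 degree sector of a
+# disk of radius 50 centered in a 128x128 array.
+#   mask = sector_mask((128, 128), centre=(64, 64), radius=50, angle_range=(0, 90))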
+
+
+def get_qx_qy_1d(M, dx=[1, 1], fft_shifted=False):
+ """
+ Generates 1D Fourier coordinates for a (Nx,Ny)-shaped 2D array.
+ Specifying the dx argument sets a unit size.
+
+ Args:
+ M: (2,) shape of the returned array
+ dx: (2,) tuple, pixel size
+ fft_shifted: True if result should be fft_shifted to have the origin in the center of the array
+ """
+
+ qxa = fftfreq(M[0], dx[0])
+ qya = fftfreq(M[1], dx[1])
+ if fft_shifted:
+ qxa = fftshift(qxa)
+ qya = fftshift(qya)
+ return qxa, qya
+
+
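+# NB: this definition shadows the make_Fourier_coords2D imported from
+# py4DSTEM.preprocess.utils at the top of this module; the two are assumed
+# here to be equivalent.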
+def make_Fourier_coords2D(Nx, Ny, pixelSize=1):
+ """
+ Generates Fourier coordinates for a (Nx,Ny)-shaped 2D array.
+ Specifying the pixelSize argument sets a unit size.
+ """
+ if hasattr(pixelSize, "__len__"):
+ assert len(pixelSize) == 2, "pixelSize must either be a scalar or have length 2"
+ pixelSize_x = pixelSize[0]
+ pixelSize_y = pixelSize[1]
+ else:
+ pixelSize_x = pixelSize
+ pixelSize_y = pixelSize
+
+ qx = np.fft.fftfreq(Nx, pixelSize_x)
+ qy = np.fft.fftfreq(Ny, pixelSize_y)
+ qy, qx = np.meshgrid(qy, qx)
+ return qx, qy
+
+
+def get_CoM(ar, device="cpu", corner_centered=False):
+ """
+ Finds and returns the center of mass of array ar.
+ If corner_centered is True, uses fftfreq for indices.
+ """
+ if device == "cpu":
+ xp = np
+ elif device == "gpu":
+ xp = cp
+
+ ar = xp.asarray(ar)
+ nx, ny = ar.shape
+
+ if corner_centered:
+ ry, rx = xp.meshgrid(xp.fft.fftfreq(ny, 1 / ny), xp.fft.fftfreq(nx, 1 / nx))
+ else:
+ ry, rx = xp.meshgrid(xp.arange(ny), xp.arange(nx))
+
+ tot_intens = xp.sum(ar)
+ xCoM = xp.sum(rx * ar) / tot_intens
+ yCoM = xp.sum(ry * ar) / tot_intens
+ return xCoM, yCoM
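+
+# Example: for a uniform array the center of mass is the geometric center,
+#   get_CoM(np.ones((4, 4)))  ->  (1.5, 1.5)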
+
+
+def get_maxima_1D(ar, sigma=0, minSpacing=0, minRelativeIntensity=0, relativeToPeak=0):
+ """
+ Finds the indices where 1D array ar is a local maximum.
+ Optional parameters allow blurring the array and filtering the output;
+ setting each to 0 (default) turns off these functions.
+
+ Args:
+ ar (1D array):
+ sigma (number): gaussian blur std to apply to ar before finding maxima
+ minSpacing (number): if two maxima are found within minSpacing, the dimmer one
+ is removed
+ minRelativeIntensity (number): maxima dimmer than minRelativeIntensity compared
+ to the relativeToPeak'th brightest maximum are removed
+ relativeToPeak (int): 0=brightest maximum. 1=next brightest, etc.
+
+ Returns:
+ (array of ints): An array of indices where ar is a local maximum, sorted by intensity.
+ """
+ assert len(ar.shape) == 1, "ar must be 1D"
+ assert isinstance(
+ relativeToPeak, (int, np.integer)
+ ), "relativeToPeak must be an int"
+ if sigma > 0:
+ ar = gaussian_filter(ar, sigma)
+
+ # Get maxima and intensity arrays
+ maxima_bool = np.logical_and((ar > np.roll(ar, -1)), (ar >= np.roll(ar, +1)))
+ x = np.arange(len(ar))[maxima_bool]
+ intensity = ar[maxima_bool]
+
+ # Sort by intensity
+ temp_ar = np.array(
+ [(x, inten) for inten, x in sorted(zip(intensity, x), reverse=True)]
+ )
+ x, intensity = temp_ar[:, 0], temp_ar[:, 1]
+
+ # Remove points which are too close
+ if minSpacing > 0:
+ deletemask = np.zeros(len(x), dtype=bool)
+ for i in range(len(x)):
+ if not deletemask[i]:
+ delete = np.abs(x[i] - x) < minSpacing
+ delete[: i + 1] = False
+ deletemask = deletemask | delete
+ x = np.delete(x, deletemask.nonzero()[0])
+ intensity = np.delete(intensity, deletemask.nonzero()[0])
+
+ # Remove points which are too dim
+ if minRelativeIntensity > 0:
+ deletemask = intensity / intensity[relativeToPeak] < minRelativeIntensity
+ x = np.delete(x, deletemask.nonzero()[0])
+ intensity = np.delete(intensity, deletemask.nonzero()[0])
+
+ return x.astype(int)
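+
+# Example -- a minimal sketch: peaks of a smoothed noisy sinusoid, keeping only
+# maxima at least 10% as bright as the brightest and at least 5 samples apart.
+#   y = np.sin(np.linspace(0, 6 * np.pi, 300)) + 0.1 * np.random.rand(300)
+#   peaks = get_maxima_1D(y, sigma=2, minSpacing=5, minRelativeIntensity=0.1)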
+
+
+def linear_interpolation_1D(ar, x):
+ """
+ Calculates the 1D linear interpolation of array ar at position x using the two
+ nearest elements.
+ """
+ x0, x1 = int(np.floor(x)), int(np.ceil(x))
+ dx = x - x0
+ return (1 - dx) * ar[x0] + dx * ar[x1]
+
+
+def add_to_2D_array_from_floats(ar, x, y, I):
+ """
+ Adds the values I to array ar, distributing the value between the four pixels nearest
+ (x,y) using linear interpolation. Inputs (x,y,I) may be floats or arrays of floats.
+
+ Note that if the same [x,y] coordinate appears more than once in the input array,
+ only the *final* value of I at that coordinate will get added.
+ """
+ Nx, Ny = ar.shape
+ x0, x1 = (np.floor(x)).astype(int), (np.ceil(x)).astype(int)
+ y0, y1 = (np.floor(y)).astype(int), (np.ceil(y)).astype(int)
+ mask = np.logical_and(
+ np.logical_and(np.logical_and((x0 >= 0), (y0 >= 0)), (x1 < Nx)), (y1 < Ny)
+ )
+ dx = x - x0
+ dy = y - y0
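+    # NB: fancy-indexed "+=" assigns rather than accumulates over repeated
+    # coordinates -- hence the docstring caveat above; np.add.at would
+    # accumulate duplicates, at some speed cost.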
+ ar[x0[mask], y0[mask]] += (1 - dx[mask]) * (1 - dy[mask]) * I[mask]
+ ar[x0[mask], y1[mask]] += (1 - dx[mask]) * (dy[mask]) * I[mask]
+ ar[x1[mask], y0[mask]] += (dx[mask]) * (1 - dy[mask]) * I[mask]
+ ar[x1[mask], y1[mask]] += (dx[mask]) * (dy[mask]) * I[mask]
+ return ar
+
+
+def get_voronoi_vertices(voronoi, nx, ny, dist=10):
+ """
+ From a scipy.spatial.Voronoi instance, return a list of ndarrays, where each array
+ is shape (N,2) and contains the (x,y) positions of the vertices of a voronoi region.
+
+    The problem this function solves is that in a Voronoi instance, some vertices outside
+    the field of view of the tessellated region are left unspecified; only the existence
+    of a point beyond the field is referenced (which may or may not be 'at infinity').
+    This function specifies all points, such that the vertices and edges of the
+    tessellation may be directly laid over data.
+
+    Args:
+        voronoi (scipy.spatial.Voronoi): the voronoi tessellation
+        nx (int): the x field-of-view of the tessellated region
+        ny (int): the y field-of-view of the tessellated region
+ dist (float, optional): place new vertices by extending new voronoi edges outside
+ the frame by a distance of this factor times the distance of its known vertex
+ from the frame edge
+
+ Returns:
+ (list of ndarrays of shape (N,2)): the (x,y) coords of the vertices of each
+ voronoi region
+ """
+ assert isinstance(
+ voronoi, Voronoi
+ ), "voronoi must be a scipy.spatial.Voronoi instance"
+
+ vertex_list = []
+
+ # Get info about ridges containing an unknown vertex. Include:
+ # -the index of its known vertex, in voronoi.vertices, and
+ # -the indices of its regions, in voronoi.point_region
+ edgeridge_vertices_and_points = []
+ for i in range(len(voronoi.ridge_vertices)):
+ ridge = voronoi.ridge_vertices[i]
+ if -1 in ridge:
+ edgeridge_vertices_and_points.append(
+ [max(ridge), voronoi.ridge_points[i, 0], voronoi.ridge_points[i, 1]]
+ )
+ edgeridge_vertices_and_points = np.array(edgeridge_vertices_and_points)
+
+ # Loop over all regions
+ for index in range(len(voronoi.regions)):
+ # Get the vertex indices
+ vertex_indices = voronoi.regions[index]
+ vertices = np.array([0, 0])
+ # Loop over all vertices
+ for i in range(len(vertex_indices)):
+ index_current = vertex_indices[i]
+ if index_current != -1:
+ # For known vertices, just add to a running list
+ vertices = np.vstack((vertices, voronoi.vertices[index_current]))
+ else:
+ # For unknown vertices, get the first vertex it connects to,
+ # and the two voronoi points that this ridge divides
+ index_prev = vertex_indices[(i - 1) % len(vertex_indices)]
+ edgeridge_index = int(
+ np.argwhere(edgeridge_vertices_and_points[:, 0] == index_prev)
+ )
+ index_vert, region0, region1 = edgeridge_vertices_and_points[
+ edgeridge_index, :
+ ]
+ x, y = voronoi.vertices[index_vert]
+ # Only add new points for unknown vertices if the known index it connects to
+ # is inside the frame. Add points by finding the line segment starting at
+ # the known point which is perpendicular to the segment connecting the two
+ # voronoi points, and extending that line segment outside the frame.
+ if (x > 0) and (x < nx) and (y > 0) and (y < ny):
+ x_r0, y_r0 = voronoi.points[region0]
+ x_r1, y_r1 = voronoi.points[region1]
+ m = -(x_r1 - x_r0) / (y_r1 - y_r0)
+ # Choose the direction to extend the ridge
+ ts = np.array([-x, -y / m, nx - x, (ny - y) / m])
+ x_t = lambda t: x + t
+ y_t = lambda t: y + m * t
+ t = ts[np.argmin(np.hypot(x - x_t(ts), y - y_t(ts)))]
+ x_new, y_new = x_t(dist * t), y_t(dist * t)
+ vertices = np.vstack((vertices, np.array([x_new, y_new])))
+ else:
+ # If handling unknown points connecting to points outside the frame is
+ # desired, add here
+ pass
+
+                # Repeat for the second vertex the unknown vertex connects to
+                # (this block belongs inside the unknown-vertex branch, at the
+                # same level as the first-vertex handling above)
+                index_next = vertex_indices[(i + 1) % len(vertex_indices)]
+                edgeridge_index = int(
+                    np.argwhere(edgeridge_vertices_and_points[:, 0] == index_next)
+                )
+                index_vert, region0, region1 = edgeridge_vertices_and_points[
+                    edgeridge_index, :
+                ]
+                x, y = voronoi.vertices[index_vert]
+                if (x > 0) and (x < nx) and (y > 0) and (y < ny):
+                    x_r0, y_r0 = voronoi.points[region0]
+                    x_r1, y_r1 = voronoi.points[region1]
+                    m = -(x_r1 - x_r0) / (y_r1 - y_r0)
+                    # Choose the direction to extend the ridge
+                    ts = np.array([-x, -y / m, nx - x, (ny - y) / m])
+                    x_t = lambda t: x + t
+                    y_t = lambda t: y + m * t
+                    t = ts[np.argmin(np.hypot(x - x_t(ts), y - y_t(ts)))]
+                    x_new, y_new = x_t(dist * t), y_t(dist * t)
+                    vertices = np.vstack((vertices, np.array([x_new, y_new])))
+                else:
+                    pass
+
+ # Remove regions with insufficiently many vertices
+ if len(vertices) < 4:
+ vertices = np.array([])
+ # Remove initial dummy point
+ else:
+ vertices = vertices[1:, :]
+ # Update vertex list with this region's vertices
+ vertex_list.append(vertices)
+
+ return vertex_list
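+
+# Example -- a hedged sketch: tessellate random points in a 100x100 field of
+# view and recover fully specified vertices for each region.
+#   points = np.random.uniform(0, 100, size=(10, 2))
+#   vertex_list = get_voronoi_vertices(Voronoi(points), nx=100, ny=100)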
+
+
+def get_ewpc_filter_function(Q_Nx, Q_Ny):
+ """
+ Returns a function for computing the exit wave power cepstrum of a diffraction
+ pattern using a Hanning window. This can be passed as the filter_function in the
+ Bragg disk detection functions (with the probe an array of ones) to find the lattice
+ vectors by the EWPC method (but be careful as the lengths are now in realspace
+ units!) See https://arxiv.org/abs/1911.00984
+ """
+ h = np.hanning(Q_Nx)[:, np.newaxis] * np.hanning(Q_Ny)[np.newaxis, :]
+ return (
+ lambda x: np.abs(np.fft.fftshift(np.fft.fft2(h * np.log(np.maximum(x, 0.01)))))
+ ** 2
+ )
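+
+# Example -- a minimal sketch: `dp` is a hypothetical 256x256 diffraction
+# pattern (nonnegative array); the returned callable computes its EWPC.
+#   ewpc = get_ewpc_filter_function(256, 256)
+#   cepstrum = ewpc(dp)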
+
+
+def fourier_resample(
+ array,
+ scale=None,
+ output_size=None,
+ force_nonnegative=False,
+ bandlimit_nyquist=None,
+ bandlimit_power=2,
+ dtype=np.float32,
+):
+ """
+ Resize a 2D array along any dimension, using Fourier interpolation / extrapolation.
+ For 4D input arrays, only the final two axes can be resized.
+
+ The scaling of the array can be specified by passing either `scale`, which sets
+ the scaling factor along both axes to be scaled; or by passing `output_size`,
+ which specifies the final dimensions of the scaled axes (and allows for different
+ scaling along the x,y or kx,ky axes.)
+
+ Args:
+ array (2D/4D numpy array): Input array, or 4D stack of arrays, to be resized.
+ scale (float): scalar value giving the scaling factor for all dimensions
+ output_size (2-tuple of ints): two values giving either the (x,y) output size for 2D, or (kx,ky) for 4D
+ force_nonnegative (bool): Force all outputs to be nonnegative, after filtering
+ bandlimit_nyquist (float): Gaussian filter information limit in Nyquist units (0.5 max in both directions)
+ bandlimit_power (float): Gaussian filter power law scaling (higher is sharper)
+ dtype (numpy dtype): datatype for binned array. default is single precision float
+
+ Returns:
+ the resized array (2D/4D numpy array)
+ """
+
+ # Verify input is 2D or 4D
+ if np.size(array.shape) != 2 and np.size(array.shape) != 4:
+ raise Exception(
+ "Function does not support arrays with "
+ + str(np.size(array.shape))
+ + " dimensions"
+ )
+
+ # Get input size from last 2 dimensions
+ input__size = array.shape[-2:]
+
+ if scale is not None:
+ assert (
+ output_size is None
+ ), "Cannot specify both a scaling factor and output size"
+ assert np.size(scale) == 1, "scale should be a single value"
+ scale = np.asarray(scale)
+ output_size = (input__size * scale).astype("intp")
+ else:
+ assert scale is None, "Cannot specify both a scaling factor and output size"
+ assert np.size(output_size) == 2, "output_size must contain two values"
+ output_size = np.asarray(output_size)
+
+ scale_output = np.prod(output_size) / np.prod(input__size)
+
+ if bandlimit_nyquist is not None:
+ kx = np.fft.fftfreq(output_size[0])
+ ky = np.fft.fftfreq(output_size[1])
+ k2 = kx[:, None] ** 2 + ky[None, :] ** 2
+ # Gaussian filter
+ k_filt = np.exp(
+ (k2 ** (bandlimit_power / 2)) / (-2 * bandlimit_nyquist**bandlimit_power)
+ )
+
+ # generate slices
+ # named as {dimension}_{corner}_{in_/out},
+ # where corner is ul, ur, ll, lr for {upper/lower}{left/right}
+
+ # x slices
+ if output_size[0] > input__size[0]:
+ # x dimension increases
+ x0 = int((input__size[0] + 1) // 2)
+ x1 = int(input__size[0] // 2)
+
+ x_ul_out = slice(0, x0)
+ x_ul_in_ = slice(0, x0)
+
+ x_ll_out = slice(0 - x1 + output_size[0], output_size[0])
+ x_ll_in_ = slice(0 - x1 + input__size[0], input__size[0])
+
+ x_ur_out = slice(0, x0)
+ x_ur_in_ = slice(0, x0)
+
+ x_lr_out = slice(0 - x1 + output_size[0], output_size[0])
+ x_lr_in_ = slice(0 - x1 + input__size[0], input__size[0])
+
+ elif output_size[0] < input__size[0]:
+ # x dimension decreases
+ x0 = int((output_size[0] + 1) // 2)
+ x1 = int(output_size[0] // 2)
+
+ x_ul_out = slice(0, x0)
+ x_ul_in_ = slice(0, x0)
+
+ x_ll_out = slice(0 - x1 + output_size[0], output_size[0])
+ x_ll_in_ = slice(0 - x1 + input__size[0], input__size[0])
+
+ x_ur_out = slice(0, x0)
+ x_ur_in_ = slice(0, x0)
+
+ x_lr_out = slice(0 - x1 + output_size[0], output_size[0])
+ x_lr_in_ = slice(0 - x1 + input__size[0], input__size[0])
+
+ else:
+ # x dimension does not change
+ x_ul_out = slice(None)
+ x_ul_in_ = slice(None)
+
+ x_ll_out = slice(None)
+ x_ll_in_ = slice(None)
+
+ x_ur_out = slice(None)
+ x_ur_in_ = slice(None)
+
+ x_lr_out = slice(None)
+ x_lr_in_ = slice(None)
+
+ # y slices
+ if output_size[1] > input__size[1]:
+ # y increases
+ y0 = int((input__size[1] + 1) // 2)
+ y1 = int(input__size[1] // 2)
+
+ y_ul_out = slice(0, y0)
+ y_ul_in_ = slice(0, y0)
+
+ y_ll_out = slice(0, y0)
+ y_ll_in_ = slice(0, y0)
+
+ y_ur_out = slice(0 - y1 + output_size[1], output_size[1])
+ y_ur_in_ = slice(0 - y1 + input__size[1], input__size[1])
+
+ y_lr_out = slice(0 - y1 + output_size[1], output_size[1])
+ y_lr_in_ = slice(0 - y1 + input__size[1], input__size[1])
+
+ elif output_size[1] < input__size[1]:
+ # y decreases
+ y0 = int((output_size[1] + 1) // 2)
+ y1 = int(output_size[1] // 2)
+
+ y_ul_out = slice(0, y0)
+ y_ul_in_ = slice(0, y0)
+
+ y_ll_out = slice(0, y0)
+ y_ll_in_ = slice(0, y0)
+
+ y_ur_out = slice(0 - y1 + output_size[1], output_size[1])
+ y_ur_in_ = slice(0 - y1 + input__size[1], input__size[1])
+
+ y_lr_out = slice(0 - y1 + output_size[1], output_size[1])
+ y_lr_in_ = slice(0 - y1 + input__size[1], input__size[1])
+
+ else:
+ # y dimension does not change
+ y_ul_out = slice(None)
+ y_ul_in_ = slice(None)
+
+ y_ll_out = slice(None)
+ y_ll_in_ = slice(None)
+
+ y_ur_out = slice(None)
+ y_ur_in_ = slice(None)
+
+ y_lr_out = slice(None)
+ y_lr_in_ = slice(None)
+
+ if len(array.shape) == 2:
+ # image array
+ array_resize = np.zeros(output_size, dtype=np.complex64)
+ array_fft = np.fft.fft2(array)
+
+ # copy each quadrant into the resize array
+ array_resize[x_ul_out, y_ul_out] = array_fft[x_ul_in_, y_ul_in_]
+ array_resize[x_ll_out, y_ll_out] = array_fft[x_ll_in_, y_ll_in_]
+ array_resize[x_ur_out, y_ur_out] = array_fft[x_ur_in_, y_ur_in_]
+ array_resize[x_lr_out, y_lr_out] = array_fft[x_lr_in_, y_lr_in_]
+
+ # Band limit if needed
+ if bandlimit_nyquist is not None:
+ array_resize *= k_filt
+
+ # Back to real space
+ array_resize = np.real(np.fft.ifft2(array_resize)).astype(dtype)
+
+ elif len(array.shape) == 4:
+ # This case is the same as the 2D case, but loops over the probe index arrays
+
+ # init arrays
+ array_resize = np.zeros((*array.shape[:2], *output_size), dtype)
+ array_fft = np.zeros(input__size, dtype=np.complex64)
+ array_output = np.zeros(output_size, dtype=np.complex64)
+
+ for Rx, Ry in tqdmnd(
+ array.shape[0],
+ array.shape[1],
+ desc="Resampling 4D datacube",
+ unit="DP",
+ unit_scale=True,
+ ):
+ array_fft[:, :] = np.fft.fft2(array[Rx, Ry, :, :])
+ array_output[:, :] = 0
+
+ # copy each quadrant into the resize array
+ array_output[x_ul_out, y_ul_out] = array_fft[x_ul_in_, y_ul_in_]
+ array_output[x_ll_out, y_ll_out] = array_fft[x_ll_in_, y_ll_in_]
+ array_output[x_ur_out, y_ur_out] = array_fft[x_ur_in_, y_ur_in_]
+ array_output[x_lr_out, y_lr_out] = array_fft[x_lr_in_, y_lr_in_]
+
+ # Band limit if needed
+ if bandlimit_nyquist is not None:
+ array_output *= k_filt
+
+ # Back to real space
+ array_resize[Rx, Ry, :, :] = np.real(np.fft.ifft2(array_output)).astype(
+ dtype
+ )
+
+ # Enforce positivity if needed, after filtering
+ if force_nonnegative:
+ array_resize = np.maximum(array_resize, 0)
+
+ # Normalization
+ array_resize = array_resize * scale_output
+
+ return array_resize
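+
+
+# Example usage -- a hedged sketch; `img` is assumed to be a 2D numpy array:
+#   upsampled = fourier_resample(img, scale=2)
+#   resized = fourier_resample(img, output_size=(256, 192), bandlimit_nyquist=0.5)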
diff --git a/py4DSTEM/process/wholepatternfit/__init__.py b/py4DSTEM/process/wholepatternfit/__init__.py
new file mode 100644
index 000000000..8fb0e351e
--- /dev/null
+++ b/py4DSTEM/process/wholepatternfit/__init__.py
@@ -0,0 +1,2 @@
+from .wp_models import *
+from .wpf import *
diff --git a/py4DSTEM/process/wholepatternfit/wp_models.py b/py4DSTEM/process/wholepatternfit/wp_models.py
new file mode 100644
index 000000000..3d53c1743
--- /dev/null
+++ b/py4DSTEM/process/wholepatternfit/wp_models.py
@@ -0,0 +1,1303 @@
+from typing import Optional
+from enum import Flag, auto
+import numpy as np
+
+
+class WPFModelType(Flag):
+ """
+ Flags to signify capabilities and other semantics of a Model
+ """
+
+ BACKGROUND = auto()
+
+ AMORPHOUS = auto()
+ LATTICE = auto()
+ MOIRE = auto()
+
+ DUMMY = auto() # Model has no direct contribution to pattern
+ META = auto() # Model depends on multiple sub-Models
+
+
+class WPFModel:
+ """
+    Prototype class for a component of a whole-pattern model.
+ Holds the following:
+ name: human-readable name of the model
+ params: a dict of names and initial (or returned) values of the model parameters
+ func: a function that takes as arguments:
+ • the diffraction pattern being built up, which the function should modify in place
+ • positional arguments in the same order as the params dictionary
+ • keyword arguments. this is to provide some pre-computed information for convenience
+ kwargs will include:
+ • xArray, yArray meshgrid of the x and y coordinates
+ • global_x0 global x-coordinate of the pattern center
+ • global_y0 global y-coordinate of the pattern center
+ jacobian: a function that takes as arguments:
+ • the diffraction pattern being built up, which the function should modify in place
+ • positional arguments in the same order as the params dictionary
+ • offset: the first index (j) that values should be written into
+ (the function should ONLY write into 0,1, and offset:offset+nParams)
+ 0 and 1 are the entries for global_x0 and global_y0, respectively
+ **REMEMBER TO ADD TO 0 and 1 SINCE ALL MODELS CAN CONTRIBUTE TO THIS PARTIAL DERIVATIVE**
+ • keyword arguments. this is to provide some pre-computed information for convenience
+ """
+
+ def __init__(self, name: str, params: dict, model_type=WPFModelType.DUMMY):
+ self.name = name
+ self.params = params
+
+ self.nParams = len(params.keys())
+
+ self.hasJacobian = getattr(self, "jacobian", None) is not None
+
+ self.model_type = model_type
+
+ def func(self, DP: np.ndarray, x, **kwargs) -> None:
+ raise NotImplementedError()
+
+ # Required signature for the Jacobian:
+ #
+ # def jacobian(self, J: np.ndarray, *args, offset: int, **kwargs) -> None:
+ # raise NotImplementedError()
+
+
+class Parameter:
+ def __init__(
+ self,
+ initial_value,
+ lower_bound: Optional[float] = None,
+ upper_bound: Optional[float] = None,
+ ):
+ """
+ Object representing a fitting parameter with bounds.
+
+ Can be specified three ways:
+ Parameter(initial_value) - Unbounded, with an initial guess
+ Parameter(initial_value, deviation) - Bounded within deviation of initial_guess
+ Parameter(initial_value, lower_bound, upper_bound) - Both bounds specified
+ """
+ if hasattr(initial_value, "__iter__"):
+ if len(initial_value) == 2:
+ initial_value = (
+ initial_value[0],
+ initial_value[0] - initial_value[1],
+ initial_value[0] + initial_value[1],
+ )
+ self.set_params(*initial_value)
+ else:
+ self.set_params(initial_value, lower_bound, upper_bound)
+
+ # Store a dummy offset. This must be set by WPF during setup
+ # This stores the index in the master parameter and Jacobian arrays
+ # corresponding to this parameter
+ self.offset = np.nan
+
+ def set_params(
+ self,
+ initial_value,
+ lower_bound,
+ upper_bound,
+ ):
+ self.initial_value = initial_value
+ self.lower_bound = lower_bound if lower_bound is not None else -np.inf
+ self.upper_bound = upper_bound if upper_bound is not None else np.inf
+
+ def __str__(self):
+ return f"Value: {self.initial_value} (Range: {self.lower_bound},{self.upper_bound})"
+
+ def __repr__(self):
+ return f"Value: {self.initial_value} (Range: {self.lower_bound},{self.upper_bound})"
+
+
+class _BaseModel(WPFModel):
+ """
+ Model object used by the WPF class as a container for the global Parameters.
+
+ **This object should not be instantiated directly.**
+ """
+
+ def __init__(self, x0, y0, name="Globals"):
+ params = {"x center": Parameter(x0), "y center": Parameter(y0)}
+
+ super().__init__(name, params, model_type=WPFModelType.DUMMY)
+
+ def func(self, DP: np.ndarray, x, **kwargs) -> None:
+ pass
+
+ def jacobian(self, J: np.ndarray, *args, **kwargs) -> None:
+ pass
+
+
+class DCBackground(WPFModel):
+ """
+ Model representing constant background intensity.
+
+ Parameters
+ ----------
+ background_value
+ Background intensity value.
+ Specified as initial_value, (initial_value, deviation), or
+ (initial_value, lower_bound, upper_bound). See
+ Parameter documentation for details.
+ """
+
+ def __init__(self, background_value=0.0, name="DC Background"):
+ params = {"DC Level": Parameter(background_value)}
+
+ super().__init__(name, params, model_type=WPFModelType.BACKGROUND)
+
+ def func(self, DP: np.ndarray, x, **kwargs) -> None:
+ DP += x[self.params["DC Level"].offset]
+
+ def jacobian(self, J: np.ndarray, *args, **kwargs):
+ J[:, self.params["DC Level"].offset] = 1
+
+
+class GaussianBackground(WPFModel):
+ """
+ Model representing a 2D Gaussian intensity distribution
+
+ Parameters
+ ----------
+ WPF: WholePatternFit
+ Parent WPF object
+ sigma
+ parameter specifying width of the Gaussian
+ Specified as initial_value, (initial_value, deviation), or
+ (initial_value, lower_bound, upper_bound). See
+ Parameter documentation for details.
+ intensity
+ parameter specifying intensity of the Gaussian
+ Specified as initial_value, (initial_value, deviation), or
+ (initial_value, lower_bound, upper_bound). See
+ Parameter documentation for details.
+ global_center: bool
+ If True, uses same center coordinate as the global model
+ If False, uses an independent center
+ x0, y0:
+ Center coordinates of model for local origin
+ Specified as initial_value, (initial_value, deviation), or
+ (initial_value, lower_bound, upper_bound). See
+ Parameter documentation for details.
+ """
+
+ def __init__(
+ self,
+ WPF,
+ sigma,
+ intensity,
+ global_center=True,
+ x0=0.0,
+ y0=0.0,
+ name="Gaussian Background",
+ ):
+ params = {"sigma": Parameter(sigma), "intensity": Parameter(intensity)}
+ if global_center:
+ params["x center"] = WPF.coordinate_model.params["x center"]
+ params["y center"] = WPF.coordinate_model.params["y center"]
+ else:
+ params["x center"] = Parameter(x0)
+ params["y center"] = Parameter(y0)
+
+ super().__init__(name, params, model_type=WPFModelType.BACKGROUND)
+
+ def func(self, DP: np.ndarray, x: np.ndarray, **kwargs) -> None:
+ sigma = x[self.params["sigma"].offset]
+ level = x[self.params["intensity"].offset]
+
+ r = kwargs["parent"]._get_distance(
+ x, self.params["x center"], self.params["y center"]
+ )
+
+ DP += level * np.exp(r**2 / (-2 * sigma**2))
+
+ def jacobian(self, J: np.ndarray, x: np.ndarray, **kwargs) -> None:
+ sigma = x[self.params["sigma"].offset]
+ level = x[self.params["intensity"].offset]
+ x0 = x[self.params["x center"].offset]
+ y0 = x[self.params["y center"].offset]
+
+ r = kwargs["parent"]._get_distance(
+ x, self.params["x center"], self.params["y center"]
+ )
+ exp_expr = np.exp(r**2 / (-2 * sigma**2))
+
+ # dF/d(x0)
+ J[:, self.params["x center"].offset] += (
+ level * (kwargs["xArray"] - x0) * exp_expr / sigma**2
+ ).ravel()
+
+ # dF/d(y0)
+ J[:, self.params["y center"].offset] += (
+ level * (kwargs["yArray"] - y0) * exp_expr / sigma**2
+ ).ravel()
+
+        # dF/d(sigma)
+ J[:, self.params["sigma"].offset] += (
+ level * r**2 * exp_expr / sigma**3
+ ).ravel()
+
+ # dF/d(level)
+ J[:, self.params["intensity"].offset] += exp_expr.ravel()
+
+
+class GaussianRing(WPFModel):
+ """
+ Model representing a halo with Gaussian falloff
+
+ Parameters
+ ----------
+ WPF: WholePatternFit
+ parent fitting object
+ radius:
+ radius of halo
+ Specified as initial_value, (initial_value, deviation), or
+ (initial_value, lower_bound, upper_bound). See
+ Parameter documentation for details.
+ sigma:
+ width of Gaussian falloff
+ Specified as initial_value, (initial_value, deviation), or
+ (initial_value, lower_bound, upper_bound). See
+ Parameter documentation for details.
+ intensity:
+ Intensity of the halo
+ Specified as initial_value, (initial_value, deviation), or
+ (initial_value, lower_bound, upper_bound). See
+ Parameter documentation for details.
+ global_center: bool
+ If True, uses same center coordinate as the global model
+ If False, uses an independent center
+ x0, y0:
+ Center coordinates of model for local origin
+ Specified as initial_value, (initial_value, deviation), or
+ (initial_value, lower_bound, upper_bound). See
+ Parameter documentation for details.
+ """
+
+ def __init__(
+ self,
+ WPF,
+ radius,
+ sigma,
+ intensity,
+ global_center=True,
+ x0=0.0,
+ y0=0.0,
+ name="Gaussian Ring",
+ ):
+ params = {
+ "radius": Parameter(radius),
+ "sigma": Parameter(sigma),
+ "intensity": Parameter(intensity),
+ "x center": WPF.coordinate_model.params["x center"]
+ if global_center
+ else Parameter(x0),
+ "y center": WPF.coordinate_model.params["y center"]
+ if global_center
+ else Parameter(y0),
+ }
+
+ super().__init__(name, params, model_type=WPFModelType.AMORPHOUS)
+
+ def func(self, DP: np.ndarray, x: np.ndarray, **kwargs) -> None:
+ radius = x[self.params["radius"].offset]
+ sigma = x[self.params["sigma"].offset]
+        level = x[self.params["intensity"].offset]
+
+ r = kwargs["parent"]._get_distance(
+ x, self.params["x center"], self.params["y center"]
+ )
+
+ DP += level * np.exp((r - radius) ** 2 / (-2 * sigma**2))
+
+ def jacobian(self, J: np.ndarray, x: np.ndarray, **kwargs) -> None:
+ radius = x[self.params["radius"].offset]
+ sigma = x[self.params["sigma"].offset]
+        level = x[self.params["intensity"].offset]
+
+ x0 = x[self.params["x center"].offset]
+ y0 = x[self.params["y center"].offset]
+ r = kwargs["parent"]._get_distance(
+ x, self.params["x center"], self.params["y center"]
+ )
+
+ local_r = radius - r
+        # clip the center distance away from zero so the division below is safe
+        clipped_r = np.maximum(r, 0.1)
+
+ exp_expr = np.exp(local_r**2 / (-2 * sigma**2))
+
+ # dF/d(x0)
+ J[:, self.params["x center"].offset] += (
+ level
+ * exp_expr
+ * (kwargs["xArray"] - x0)
+ * local_r
+ / (sigma**2 * clipped_r)
+ ).ravel()
+
+ # dF/d(y0)
+        J[:, self.params["y center"].offset] += (
+ level
+ * exp_expr
+ * (kwargs["yArray"] - y0)
+ * local_r
+ / (sigma**2 * clipped_r)
+ ).ravel()
+
+ # dF/d(radius)
+ J[:, self.params["radius"].offset] += (
+ -1.0 * level * exp_expr * local_r / (sigma**2)
+ ).ravel()
+
+ # dF/d(sigma)
+ J[:, self.params["sigma"].offset] += (
+ level * local_r**2 * exp_expr / sigma**3
+ ).ravel()
+
+ # dF/d(intensity)
+ J[:, self.params["intensity"].offset] += exp_expr.ravel()
+
+
+class SyntheticDiskLattice(WPFModel):
+ """
+ Model representing a lattice of diffraction disks with a soft edge
+
+ Parameters
+ ----------
+
+ WPF: WholePatternFit
+ parent fitting object
+ ux,uy,vx,vy
+ x and y components of the lattice vectors u and v.
+ Specified as initial_value, (initial_value, deviation), or
+ (initial_value, lower_bound, upper_bound). See
+ Parameter documentation for details.
+ disk_radius
+ Radius of each diffraction disk.
+ Specified as initial_value, (initial_value, deviation), or
+ (initial_value, lower_bound, upper_bound). See
+ Parameter documentation for details.
+ disk_width
+ Width of the smooth falloff at the edge of the disk
+ Specified as initial_value, (initial_value, deviation), or
+ (initial_value, lower_bound, upper_bound). See
+ Parameter documentation for details.
+ u_max, v_max
+ Maximum lattice indices to include in the pattern.
+ Disks outside the pattern are automatically clipped.
+ intensity_0
+ Initial intensity for each diffraction disk.
+ Each disk intensity is an independent fit variable in the final model
+ Specified as initial_value, (initial_value, deviation), or
+ (initial_value, lower_bound, upper_bound). See
+ Parameter documentation for details.
+ refine_radius: bool
+ Flag whether disk radius is made a fitting parameter
+ refine_width: bool
+ Flag whether disk edge width is made a fitting parameter
+ global_center: bool
+ If True, uses same center coordinate as the global model
+ If False, uses an independent center
+ x0, y0:
+ Center coordinates of model for local origin
+ Specified as initial_value, (initial_value, deviation), or
+ (initial_value, lower_bound, upper_bound). See
+ Parameter documentation for details.
+ exclude_indices: list
+ Indices to exclude from the pattern
+ include_indices: list
+ If specified, only the indices in the list are added to the pattern
+ """
+
+ def __init__(
+ self,
+ WPF,
+ ux: float,
+ uy: float,
+ vx: float,
+ vy: float,
+ disk_radius: float,
+ disk_width: float,
+ u_max: int,
+ v_max: int,
+ intensity_0: float,
+ refine_radius: bool = False,
+ refine_width: bool = False,
+ global_center: bool = True,
+ x0: float = 0.0,
+ y0: float = 0.0,
+ exclude_indices: list = [],
+ include_indices: list = None,
+ name="Synthetic Disk Lattice",
+ verbose=False,
+ ):
+ self.disk_radius = disk_radius
+ self.disk_width = disk_width
+
+ params = {}
+
+ if global_center:
+ params["x center"] = WPF.coordinate_model.params["x center"]
+ params["y center"] = WPF.coordinate_model.params["y center"]
+ else:
+ params["x center"] = Parameter(x0)
+ params["y center"] = Parameter(y0)
+
+ x0 = params["x center"].initial_value
+ y0 = params["y center"].initial_value
+
+ params["ux"] = Parameter(ux)
+ params["uy"] = Parameter(uy)
+ params["vx"] = Parameter(vx)
+ params["vy"] = Parameter(vy)
+
+ Q_Nx = WPF.static_data["Q_Nx"]
+ Q_Ny = WPF.static_data["Q_Ny"]
+
+ if include_indices is None:
+ u_inds, v_inds = np.mgrid[-u_max : u_max + 1, -v_max : v_max + 1]
+ self.u_inds = u_inds.ravel()
+ self.v_inds = v_inds.ravel()
+
+ delete_mask = np.zeros_like(self.u_inds, dtype=bool)
+ for i, (u, v) in enumerate(zip(u_inds.ravel(), v_inds.ravel())):
+ x = (
+ x0
+ + (u * params["ux"].initial_value)
+ + (v * params["vx"].initial_value)
+ )
+ y = (
+ y0
+ + (u * params["uy"].initial_value)
+ + (v * params["vy"].initial_value)
+ )
+ if [u, v] in exclude_indices:
+ delete_mask[i] = True
+ elif (x < 0) or (x > Q_Nx) or (y < 0) or (y > Q_Ny):
+ delete_mask[i] = True
+ if verbose:
+ print(
+ f"Excluding peak [{u},{v}] because it is outside the pattern..."
+ )
+ else:
+ params[f"[{u},{v}] Intensity"] = Parameter(intensity_0)
+
+ self.u_inds = self.u_inds[~delete_mask]
+ self.v_inds = self.v_inds[~delete_mask]
+ else:
+ for ind in include_indices:
+ params[f"[{ind[0]},{ind[1]}] Intensity"] = Parameter(intensity_0)
+ inds = np.array(include_indices)
+ self.u_inds = inds[:, 0]
+ self.v_inds = inds[:, 1]
+
+ self.refine_radius = refine_radius
+ self.refine_width = refine_width
+ if refine_radius:
+ params["disk radius"] = Parameter(disk_radius)
+ if refine_width:
+ params["edge width"] = Parameter(disk_width)
+
+ super().__init__(name, params, model_type=WPFModelType.LATTICE)
+
+ def func(self, DP: np.ndarray, x: np.ndarray, **static_data) -> None:
+ x0 = x[self.params["x center"].offset]
+ y0 = x[self.params["y center"].offset]
+ ux = x[self.params["ux"].offset]
+ uy = x[self.params["uy"].offset]
+ vx = x[self.params["vx"].offset]
+ vy = x[self.params["vy"].offset]
+
+ disk_radius = (
+ x[self.params["disk radius"].offset]
+ if self.refine_radius
+ else self.disk_radius
+ )
+
+ disk_width = (
+ x[self.params["edge width"].offset]
+ if self.refine_width
+ else self.disk_width
+ )
+
+ for i, (u, v) in enumerate(zip(self.u_inds, self.v_inds)):
+ x_pos = x0 + (u * ux) + (v * vx)
+ y_pos = y0 + (u * uy) + (v * vy)
+
+ DP += x[self.params[f"[{u},{v}] Intensity"].offset] / (
+ 1.0
+ + np.exp(
+ np.minimum(
+ 4
+ * (
+ np.sqrt(
+ (static_data["xArray"] - x_pos) ** 2
+ + (static_data["yArray"] - y_pos) ** 2
+ )
+ - disk_radius
+ )
+ / disk_width,
+ 20,
+ )
+ )
+ )
+
+ def jacobian(self, J: np.ndarray, x: np.ndarray, **static_data) -> None:
+ x0 = x[self.params["x center"].offset]
+ y0 = x[self.params["y center"].offset]
+ ux = x[self.params["ux"].offset]
+ uy = x[self.params["uy"].offset]
+ vx = x[self.params["vx"].offset]
+ vy = x[self.params["vy"].offset]
+ WPF = static_data["parent"]
+
+ r = np.maximum(
+ 5e-1, WPF._get_distance(x, self.params["x center"], self.params["y center"])
+ )
+
+ disk_radius = (
+ x[self.params["disk radius"].offset]
+ if self.refine_radius
+ else self.disk_radius
+ )
+
+ disk_width = (
+ x[self.params["edge width"].offset]
+ if self.refine_width
+ else self.disk_width
+ )
+
+ for i, (u, v) in enumerate(zip(self.u_inds, self.v_inds)):
+ x_pos = x0 + (u * ux) + (v * vx)
+ y_pos = y0 + (u * uy) + (v * vy)
+
+ disk_intensity = x[self.params[f"[{u},{v}] Intensity"].offset]
+
+ r_disk = np.maximum(
+ 5e-1,
+ np.sqrt(
+ (static_data["xArray"] - x_pos) ** 2
+ + (static_data["yArray"] - y_pos) ** 2
+ ),
+ )
+
+ mask = r_disk < (2 * disk_radius)
+
+ top_exp = mask * np.exp(
+ np.minimum(30, 4 * ((mask * r_disk) - disk_radius) / disk_width)
+ )
+
+ # dF/d(x0)
+ dx = (
+ 4
+ * disk_intensity
+ * (static_data["xArray"] - x_pos)
+ * top_exp
+ / ((1.0 + top_exp) ** 2 * disk_width * r)
+ ).ravel()
+
+ # dF/d(y0)
+ dy = (
+ 4
+ * disk_intensity
+ * (static_data["yArray"] - y_pos)
+ * top_exp
+ / ((1.0 + top_exp) ** 2 * disk_width * r)
+ ).ravel()
+
+ # insert center position derivatives
+ J[:, self.params["x center"].offset] += disk_intensity * dx
+ J[:, self.params["y center"].offset] += disk_intensity * dy
+
+ # insert lattice vector derivatives
+ J[:, self.params["ux"].offset] += disk_intensity * u * dx
+ J[:, self.params["uy"].offset] += disk_intensity * u * dy
+ J[:, self.params["vx"].offset] += disk_intensity * v * dx
+ J[:, self.params["vy"].offset] += disk_intensity * v * dy
+
+ # insert intensity derivative
+ dI = (mask * (1.0 / (1.0 + top_exp))).ravel()
+ J[:, self.params[f"[{u},{v}] Intensity"].offset] += dI
+
+ # insert disk radius derivative
+ if self.refine_radius:
+ dR = (
+ 4.0 * disk_intensity * top_exp / (disk_width * (1.0 + top_exp) ** 2)
+ ).ravel()
+ J[:, self.params["disk radius"].offset] += dR
+
+ if self.refine_width:
+ dW = (
+ 4.0
+ * disk_intensity
+ * top_exp
+ * (r_disk - disk_radius)
+ / (disk_width**2 * (1.0 + top_exp) ** 2)
+ ).ravel()
+ J[:, self.params["edge width"].offset] += dW
+
+
+class SyntheticDiskMoire(WPFModel):
+ """
+ Model of diffraction disks arising from interference between two lattices.
+
+ The Moire unit cell is determined automatically using the two input lattices.
+
+ Parameters
+ ----------
+ WPF: WholePatternFit
+ parent fitting object
+ lattice_a, lattice_b: SyntheticDiskLattice
+ parent lattices for the Moire
+ intensity_0
+ Initial guess of Moire disk intensity
+ Specified as initial_value, (initial_value, deviation), or
+ (initial_value, lower_bound, upper_bound). See
+ Parameter documentation for details.
+ decorated_peaks: list
+ When specified, only the reflections in the list are decorated with Moire spots
+ If not specified, all peaks are decorated
+ link_moire_disk_intensities: bool
+ When False, each Moire disk has an independently fit intensity
+ When True, Moire disks arising from the same order of parent reflection share
+ the same intensity
+ link_disk_parameters: bool
+ When True, edge_width and disk_radius are inherited from lattice_a
+ refine_width: bool
+ Flag whether disk edge width is a fit variable
+ edge_width
+ Width of the soft edge of the diffraction disk.
+ Specified as initial_value, (initial_value, deviation), or
+ (initial_value, lower_bound, upper_bound). See
+ Parameter documentation for details.
+ refine_radius: bool
+ Flag whether disk radius is a fit variable
+ disk_radius
+ Radius of the diffraction disks
+ Specified as initial_value, (initial_value, deviation), or
+ (initial_value, lower_bound, upper_bound). See
+ Parameter documentation for details.
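+
+ Examples
+ --------
+ A minimal sketch, assuming two SyntheticDiskLattice models sharing a
+ center coordinate have already been added to a WholePatternFit object
+ ``wpf``::
+
+ >>> moire = SyntheticDiskMoire(wpf, lattice_a, lattice_b, intensity_0=1.0)
+ >>> wpf.add_model(moire)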
+ """
+
+ def __init__(
+ self,
+ WPF,
+ lattice_a: SyntheticDiskLattice,
+ lattice_b: SyntheticDiskLattice,
+ intensity_0: float,
+ decorated_peaks: list = None,
+ link_moire_disk_intensities: bool = False,
+ link_disk_parameters: bool = True,
+ refine_width: bool = True,
+ edge_width: list = None,
+ refine_radius: bool = True,
+ disk_radius: list = None,
+ name: str = "Moire Lattice",
+ ):
+ # ensure both models share the same center coordinate
+ if (lattice_a.params["x center"] is not lattice_b.params["x center"]) or (
+ lattice_a.params["y center"] is not lattice_b.params["y center"]
+ ):
+ raise ValueError(
+ "The center coordinates for each model must be linked, "
+ "either by passing global_center=True or linking after instantiation."
+ )
+
+ self.lattice_a = lattice_a
+ self.lattice_b = lattice_b
+
+ # construct a 2x4 matrix "M" that transforms the parent lattices into
+ # the moire lattice vectors
+
+ lat_ab = self._get_parent_lattices(lattice_a, lattice_b)
+
+ # pick the pairing that gives the smallest unit cell
+ test_peaks = np.stack((lattice_b.u_inds, lattice_b.v_inds), axis=1)
+ tests = np.stack(
+ [
+ np.hstack((np.eye(2), np.vstack((b1, b2))))
+ for b1 in test_peaks
+ for b2 in test_peaks
+ if not np.allclose(b1, b2)
+ ],
+ axis=0,
+ )
+ # choose only cells where the two unit vectors are not nearly parallel,
+ # and penalize cells with a large discrepancy in lattice vector length
+ lat_m = tests @ lat_ab
+ a_dot_b = (
+ np.sum(lat_m[:, 0] * lat_m[:, 1], axis=1)
+ / np.minimum(
+ np.linalg.norm(lat_m[:, 0], axis=1), np.linalg.norm(lat_m[:, 1], axis=1)
+ )
+ ** 2
+ )
+ tests = tests[
+ np.abs(a_dot_b) < 0.9
+ ] # this factor of 0.9 sets the parallel cutoff
+ # with the parallel vectors filtered, pick the cell with the smallest volume
+ lat_m = tests @ lat_ab
+ V = np.sum(
+ lat_m[:, 0]
+ * np.cross(
+ np.hstack((lat_m[:, 1], np.zeros((lat_m.shape[0],))[:, None])),
+ [0, 0, 1],
+ )[:, :2],
+ axis=1,
+ )
+ M = tests[np.argmin(np.abs(V))]
+
+ # ensure the moire vectors are less than 90 degrees apart
+ if np.arccos(
+ ((M @ lat_ab)[0] @ (M @ lat_ab)[1])
+ / (np.linalg.norm((M @ lat_ab)[0]) * np.linalg.norm((M @ lat_ab)[1]))
+ ) > np.radians(90):
+ M[1] *= -1.0
+
+ # ensure they are right-handed
+ if np.cross(*(M @ lat_ab)) < 0.0:
+ M = np.flipud(np.eye(2)) @ M
+
+ # store moire construction
+ self.moire_matrix = M
+
+ # generate the indices of each peak, then find unique peaks
+ if decorated_peaks is not None:
+ decorated_peaks = np.array(decorated_peaks)
+ parent_peaks = np.vstack(
+ (
+ np.concatenate(
+ (decorated_peaks, np.zeros_like(decorated_peaks)), axis=1
+ ),
+ np.concatenate(
+ (np.zeros_like(decorated_peaks), decorated_peaks), axis=1
+ ),
+ )
+ )
+ else:
+ parent_peaks = np.vstack(
+ (
+ np.concatenate(
+ (
+ np.stack((lattice_a.u_inds, lattice_a.v_inds), axis=1),
+ np.zeros((lattice_a.u_inds.shape[0], 2)),
+ ),
+ axis=1,
+ ),
+ np.concatenate(
+ (
+ np.zeros((lattice_b.u_inds.shape[0], 2)),
+ np.stack((lattice_b.u_inds, lattice_b.v_inds), axis=1),
+ ),
+ axis=1,
+ ),
+ )
+ )
+
+ # trial indices for moire peaks
+ mx, my = np.mgrid[-1:2, -1:2]
+ moire_peaks = np.stack([mx.ravel(), my.ravel()], axis=1)[1:-1]
+
+ # construct a giant index array with columns a_h a_k b_h b_k m_h m_k
+ parent_expanded = np.zeros((parent_peaks.shape[0], 6))
+ parent_expanded[:, :4] = parent_peaks
+ moire_expanded = np.zeros((moire_peaks.shape[0], 6))
+ moire_expanded[:, 4:] = moire_peaks
+
+ all_indices = (
+ parent_expanded[:, None, :] + moire_expanded[None, :, :]
+ ).reshape(-1, 6)
+
+ lat_abm = np.vstack((lat_ab, M @ lat_ab))
+
+ all_peaks = all_indices @ lat_abm
+
+ _, idx_unique = np.unique(all_peaks, axis=0, return_index=True)
+
+ all_indices = all_indices[idx_unique]
+
+ # remove peaks outside of pattern
+ Q_Nx = WPF.static_data["Q_Nx"]
+ Q_Ny = WPF.static_data["Q_Ny"]
+ all_peaks = all_indices @ lat_abm
+ all_peaks[:, 0] += lattice_a.params["x center"].initial_value
+ all_peaks[:, 1] += lattice_a.params["y center"].initial_value
+ delete_mask = np.logical_or.reduce(
+ [
+ all_peaks[:, 0] < 0.0,
+ all_peaks[:, 0] >= Q_Nx,
+ all_peaks[:, 1] < 0.0,
+ all_peaks[:, 1] >= Q_Ny,
+ ]
+ )
+ all_indices = all_indices[~delete_mask]
+
+ # remove spots that coincide with primary peaks
+ parent_spots = parent_peaks @ lat_ab
+ self.moire_indices_uvm = np.array(
+ [idx for idx in all_indices if (idx @ lat_abm) not in parent_spots]
+ )
+
+ self.link_moire_disk_intensities = link_moire_disk_intensities
+ if link_moire_disk_intensities:
+ # each order of parent reflection has a separate moire intensity
+ max_order = int(np.max(np.abs(self.moire_indices_uvm[:, :4])))
+
+ params = {
+ f"Order {n} Moire Intensity": Parameter(intensity_0)
+ for n in range(max_order + 1)
+ }
+ else:
+ params = {
+ f"a ({ax},{ay}), b ({bx},{by}), moire ({mx},{my}) Intensity": Parameter(
+ intensity_0
+ )
+ for ax, ay, bx, by, mx, my in self.moire_indices_uvm
+ }
+
+ params["x center"] = lattice_a.params["x center"]
+ params["y center"] = lattice_a.params["y center"]
+
+ # add disk edge and width parameters if needed
+ if link_disk_parameters:
+ if (lattice_a.refine_width) and (lattice_b.refine_width):
+ self.refine_width = True
+ params["edge width"] = lattice_a.params["edge width"]
+ if (lattice_a.refine_radius) and (lattice_b.refine_radius):
+ self.refine_radius = True
+ params["disk radius"] = lattice_a.params["disk radius"]
+ else:
+ self.refine_width = refine_width
+ if self.refine_width:
+ params["edge width"] = Parameter(edge_width)
+
+ self.refine_radius = refine_radius
+ if self.refine_radius:
+ params["disk radius"] = Parameter(disk_radius)
+
+ # store some data that helps compute the derivatives
+ selector_matrices = np.eye(8).reshape(-1, 4, 2)
+ selector_parameters = [
+ self.lattice_a.params["ux"],
+ self.lattice_a.params["uy"],
+ self.lattice_a.params["vx"],
+ self.lattice_a.params["vy"],
+ self.lattice_b.params["ux"],
+ self.lattice_b.params["uy"],
+ self.lattice_b.params["vx"],
+ self.lattice_b.params["vy"],
+ ]
+ self.parent_vector_selectors = [
+ (p, m) for p, m in zip(selector_parameters, selector_matrices)
+ ]
+
+ super().__init__(
+ name,
+ params,
+ model_type=WPFModelType.META | WPFModelType.MOIRE,
+ )
+
+ def _get_parent_lattices(self, lattice_a, lattice_b):
+ lat_a = np.array(
+ [
+ [
+ lattice_a.params["ux"].initial_value,
+ lattice_a.params["uy"].initial_value,
+ ],
+ [
+ lattice_a.params["vx"].initial_value,
+ lattice_a.params["vy"].initial_value,
+ ],
+ ]
+ )
+
+ lat_b = np.array(
+ [
+ [
+ lattice_b.params["ux"].initial_value,
+ lattice_b.params["uy"].initial_value,
+ ],
+ [
+ lattice_b.params["vx"].initial_value,
+ lattice_b.params["vy"].initial_value,
+ ],
+ ]
+ )
+
+ return np.vstack((lat_a, lat_b))
+
+ def func(self, DP: np.ndarray, x: np.ndarray, **static_data):
+ # construct the moire unit cell from the current vectors
+ # of the two parent lattices
+
+ lat_ab = self._get_parent_lattices(self.lattice_a, self.lattice_b)
+ lat_abm = np.vstack((lat_ab, self.moire_matrix @ lat_ab))
+
+ # grab shared parameters
+ disk_radius = (
+ x[self.params["disk radius"].offset]
+ if self.refine_radius
+ else self.disk_radius
+ )
+
+ disk_width = (
+ x[self.params["edge width"].offset]
+ if self.refine_width
+ else self.disk_width
+ )
+
+ # compute positions of each moire peak
+ positions = self.moire_indices_uvm @ lat_abm
+ positions += np.array(
+ [x[self.params["x center"].offset], x[self.params["y center"].offset]]
+ )
+
+ for (x_pos, y_pos), indices in zip(positions, self.moire_indices_uvm):
+ # Each peak has an intensity based on the max index of parent lattice
+ # which it decorates
+ order = int(np.max(np.abs(indices[:4])))
+
+ if self.link_moire_disk_intensities:
+ intensity = x[self.params[f"Order {order} Moire Intensity"].offset]
+ else:
+ ax, ay, bx, by, mx, my = indices
+ intensity = x[
+ self.params[
+ f"a ({ax},{ay}), b ({bx},{by}), moire ({mx},{my}) Intensity"
+ ].offset
+ ]
+
+ DP += intensity / (
+ 1.0
+ + np.exp(
+ np.minimum(
+ 4
+ * (
+ np.sqrt(
+ (static_data["xArray"] - x_pos) ** 2
+ + (static_data["yArray"] - y_pos) ** 2
+ )
+ - disk_radius
+ )
+ / disk_width,
+ 20,
+ )
+ )
+ )
+
+ def jacobian(self, J: np.ndarray, x: np.ndarray, **static_data):
+ # construct the moire unit cell from the current vectors
+ # of the two parent lattices
+ lat_ab = self._get_parent_lattices(self.lattice_a, self.lattice_b)
+ lat_abm = np.vstack((lat_ab, self.moire_matrix @ lat_ab))
+
+ # grab shared parameters
+ disk_radius = (
+ x[self.params["disk radius"].offset]
+ if self.refine_radius
+ else self.disk_radius
+ )
+
+ disk_width = (
+ x[self.params["edge width"].offset]
+ if self.refine_width
+ else self.disk_width
+ )
+
+ # distance from center coordinate
+ r = np.maximum(
+ 5e-1,
+ static_data["parent"]._get_distance(
+ x, self.params["x center"], self.params["y center"]
+ ),
+ )
+
+ # compute positions of each moire peak
+ positions = self.moire_indices_uvm @ lat_abm
+ positions += np.array(
+ [x[self.params["x center"].offset], x[self.params["y center"].offset]]
+ )
+
+ for (x_pos, y_pos), indices in zip(positions, self.moire_indices_uvm):
+ # Each peak has an intensity based on the max index of parent lattice
+ # which it decorates
+ if self.link_moire_disk_intensities:
+ order = int(np.max(np.abs(indices[:4])))
+ intensity_idx = self.params[f"Order {order} Moire Intensity"].offset
+ else:
+ ax, ay, bx, by, mx, my = indices
+ intensity_idx = self.params[
+ f"a ({ax},{ay}), b ({bx},{by}), moire ({mx},{my}) Intensity"
+ ].offset
+ disk_intensity = x[intensity_idx]
+
+ r_disk = np.maximum(
+ 5e-1,
+ np.sqrt(
+ (static_data["xArray"] - x_pos) ** 2
+ + (static_data["yArray"] - y_pos) ** 2
+ ),
+ )
+
+ mask = r_disk < (2 * disk_radius)
+
+ # clamp the argument of the exponent at a very large finite value
+ top_exp = mask * np.exp(
+ np.minimum(30, 4 * ((mask * r_disk) - disk_radius) / disk_width)
+ )
+
+ # dF/d(x0)
+ dx = (
+ 4
+ * disk_intensity
+ * (static_data["xArray"] - x_pos)
+ * top_exp
+ / ((1.0 + top_exp) ** 2 * disk_width * r)
+ ).ravel()
+
+ # dF/d(y0)
+ dy = (
+ 4
+ * disk_intensity
+ * (static_data["yArray"] - y_pos)
+ * top_exp
+ / ((1.0 + top_exp) ** 2 * disk_width * r)
+ ).ravel()
+
+ # insert center position derivatives
+ J[:, self.params["x center"].offset] += disk_intensity * dx
+ J[:, self.params["y center"].offset] += disk_intensity * dy
+
+ # insert lattice vector derivatives
+ for par, mat in self.parent_vector_selectors:
+ # find the x and y derivatives of the position of this
+ # disk in terms of each of the parent lattice vectors
+ d_abm = np.vstack((mat, self.moire_matrix @ mat))
+ d_param = indices @ d_abm
+ J[:, par.offset] += disk_intensity * (d_param[0] * dx + d_param[1] * dy)
+
+ # insert intensity derivative
+ dI = (mask * (1.0 / (1.0 + top_exp))).ravel()
+ J[:, intensity_idx] += dI
+
+ # insert disk radius derivative
+ if self.refine_radius:
+ dR = (
+ 4.0 * disk_intensity * top_exp / (disk_width * (1.0 + top_exp) ** 2)
+ ).ravel()
+ J[:, self.params["disk radius"].offset] += dR
+
+ if self.refine_width:
+ dW = (
+ 4.0
+ * disk_intensity
+ * top_exp
+ * (r_disk - disk_radius)
+ / (disk_width**2 * (1.0 + top_exp) ** 2)
+ ).ravel()
+ J[:, self.params["edge width"].offset] += dW
+
+
+class ComplexOverlapKernelDiskLattice(WPFModel):
+ def __init__(
+ self,
+ WPF,
+ probe_kernel: np.ndarray,
+ ux: float,
+ uy: float,
+ vx: float,
+ vy: float,
+ u_max: int,
+ v_max: int,
+ intensity_0: float,
+ exclude_indices: list = [],
+ global_center: bool = True,
+ x0=0.0,
+ y0=0.0,
+ name="Complex Overlapped Disk Lattice",
+ verbose=False,
+ ):
+ raise NotImplementedError(
+ "This model type has not been updated for use with the new architecture."
+ )
+
+ params = {}
+
+ self.probe_kernelFT = np.fft.fft2(probe_kernel)
+
+ if global_center:
+ params["x center"] = WPF.coordinate_model.params["x center"]
+ params["y center"] = WPF.coordinate_model.params["y center"]
+ else:
+ params["x center"] = Parameter(x0)
+ params["y center"] = Parameter(y0)
+
+ x0 = params["x center"].initial_value
+ y0 = params["y center"].initial_value
+
+ params["ux"] = Parameter(ux)
+ params["uy"] = Parameter(uy)
+ params["vx"] = Parameter(vx)
+ params["vy"] = Parameter(vy)
+
+ u_inds, v_inds = np.mgrid[-u_max : u_max + 1, -v_max : v_max + 1]
+ self.u_inds = u_inds.ravel()
+ self.v_inds = v_inds.ravel()
+
+ delete_mask = np.zeros_like(self.u_inds, dtype=bool)
+ Q_Nx = WPF.static_data["Q_Nx"]
+ Q_Ny = WPF.static_data["Q_Ny"]
+
+ self.yqArray = np.tile(np.fft.fftfreq(Q_Ny)[np.newaxis, :], (Q_Nx, 1))
+ self.xqArray = np.tile(np.fft.fftfreq(Q_Nx)[:, np.newaxis], (1, Q_Ny))
+
+ for i, (u, v) in enumerate(zip(u_inds.ravel(), v_inds.ravel())):
+ x = (
+ WPF.static_data["global_x0"]
+ + (u * params["ux"].initial_value)
+ + (v * params["vx"].initial_value)
+ )
+ y = (
+ WPF.static_data["global_y0"]
+ + (u * params["uy"].initial_value)
+ + (v * params["vy"].initial_value)
+ )
+ if [u, v] in exclude_indices:
+ delete_mask[i] = True
+ elif (x < 0) or (x > Q_Nx) or (y < 0) or (y > Q_Ny):
+ delete_mask[i] = True
+ if verbose:
+ print(
+ f"Excluding peak [{u},{v}] because it is outside the pattern..."
+ )
+ else:
+ params[f"[{u},{v}] Intensity"] = Parameter(intensity_0)
+ if u == 0 and v == 0:
+ params[f"[{u}, {v}] Phase"] = Parameter(
+ 0.0, 0.0, 0.0
+ ) # direct beam clamped at zero phase
+ else:
+ params[f"[{u}, {v}] Phase"] = Parameter(0.01, -np.pi, np.pi)
+
+ self.u_inds = self.u_inds[~delete_mask]
+ self.v_inds = self.v_inds[~delete_mask]
+
+ super().__init__(name, params, model_type=WPFModelType.LATTICE)
+
+ def func(self, DP: np.ndarray, x_fit, **kwargs) -> None:
+ x0 = x_fit[self.params["x center"].offset]
+ y0 = x_fit[self.params["y center"].offset]
+ ux = x_fit[self.params["ux"].offset]
+ uy = x_fit[self.params["uy"].offset]
+ vx = x_fit[self.params["vx"].offset]
+ vy = x_fit[self.params["vy"].offset]
+
+ localDP = np.zeros_like(DP, dtype=np.complex64)
+
+ for i, (u, v) in enumerate(zip(self.u_inds, self.v_inds)):
+ x = x0 + (u * ux) + (v * vx)
+ y = y0 + (u * uy) + (v * vy)
+
+ localDP += (
+ x_fit[self.params[f"[{u},{v}] Intensity"].offset]
+ * np.exp(1j * x_fit[self.params[f"[{u},{v}] Phase"].offset])
+ * np.abs(
+ np.fft.ifft2(
+ self.probe_kernelFT
+ * np.exp(-2j * np.pi * (self.xqArray * x + self.yqArray * y))
+ )
+ )
+ )
+
+ DP += np.abs(localDP) ** 2
+
+
+class KernelDiskLattice(WPFModel):
+ def __init__(
+ self,
+ WPF,
+ probe_kernel: np.ndarray,
+ ux: float,
+ uy: float,
+ vx: float,
+ vy: float,
+ u_max: int,
+ v_max: int,
+ intensity_0: float,
+ exclude_indices: list = [],
+ global_center: bool = True,
+ x0=0.0,
+ y0=0.0,
+ name="Custom Kernel Disk Lattice",
+ verbose=False,
+ ):
+ params = {}
+
+ self.probe_kernelFT = np.fft.fft2(probe_kernel)
+
+ if global_center:
+ params["x center"] = WPF.coordinate_model.params["x center"]
+ params["y center"] = WPF.coordinate_model.params["y center"]
+ else:
+ params["x center"] = Parameter(x0)
+ params["y center"] = Parameter(y0)
+
+ x0 = params["x center"].initial_value
+ y0 = params["y center"].initial_value
+
+ params["ux"] = Parameter(ux)
+ params["uy"] = Parameter(uy)
+ params["vx"] = Parameter(vx)
+ params["vy"] = Parameter(vy)
+
+ u_inds, v_inds = np.mgrid[-u_max : u_max + 1, -v_max : v_max + 1]
+ self.u_inds = u_inds.ravel()
+ self.v_inds = v_inds.ravel()
+
+ delete_mask = np.zeros_like(self.u_inds, dtype=bool)
+ Q_Nx = WPF.static_data["Q_Nx"]
+ Q_Ny = WPF.static_data["Q_Ny"]
+
+ self.yqArray = np.tile(np.fft.fftfreq(Q_Ny)[np.newaxis, :], (Q_Nx, 1))
+ self.xqArray = np.tile(np.fft.fftfreq(Q_Nx)[:, np.newaxis], (1, Q_Ny))
+
+ for i, (u, v) in enumerate(zip(u_inds.ravel(), v_inds.ravel())):
+ x = x0 + (u * params["ux"].initial_value) + (v * params["vx"].initial_value)
+ y = y0 + (u * params["uy"].initial_value) + (v * params["vy"].initial_value)
+ if [u, v] in exclude_indices:
+ delete_mask[i] = True
+ elif (x < 0) or (x > Q_Nx) or (y < 0) or (y > Q_Ny):
+ delete_mask[i] = True
+ if verbose:
+ print(
+ f"Excluding peak [{u},{v}] because it is outside the pattern..."
+ )
+ else:
+ params[f"[{u},{v}] Intensity"] = Parameter(intensity_0)
+
+ self.u_inds = self.u_inds[~delete_mask]
+ self.v_inds = self.v_inds[~delete_mask]
+
+ super().__init__(name, params, model_type=WPFModelType.LATTICE)
+
+ def func(self, DP: np.ndarray, x_fit: np.ndarray, **static_data) -> None:
+ x0 = x_fit[self.params["x center"].offset]
+ y0 = x_fit[self.params["y center"].offset]
+ ux = x_fit[self.params["ux"].offset]
+ uy = x_fit[self.params["uy"].offset]
+ vx = x_fit[self.params["vx"].offset]
+ vy = x_fit[self.params["vy"].offset]
+
+ for i, (u, v) in enumerate(zip(self.u_inds, self.v_inds)):
+ x = x0 + (u * ux) + (v * vx)
+ y = y0 + (u * uy) + (v * vy)
+
+ DP += (
+ x_fit[self.params[f"[{u},{v}] Intensity"].offset]
+ * np.abs(
+ np.fft.ifft2(
+ self.probe_kernelFT
+ * np.exp(-2j * np.pi * (self.xqArray * x + self.yqArray * y))
+ )
+ )
+ ) ** 2
diff --git a/py4DSTEM/process/wholepatternfit/wpf.py b/py4DSTEM/process/wholepatternfit/wpf.py
new file mode 100644
index 000000000..f206004b4
--- /dev/null
+++ b/py4DSTEM/process/wholepatternfit/wpf.py
@@ -0,0 +1,710 @@
+from __future__ import annotations
+from py4DSTEM import DataCube, RealSlice
+from emdfile import tqdmnd
+from py4DSTEM.process.wholepatternfit.wp_models import (
+ WPFModel,
+ _BaseModel,
+ WPFModelType,
+ Parameter,
+)
+
+from typing import Optional
+import numpy as np
+
+from scipy.optimize import least_squares
+import matplotlib.pyplot as plt
+import matplotlib.colors as mpl_c
+from matplotlib.gridspec import GridSpec
+
+__all__ = ["WholePatternFit"]
+
+
+class WholePatternFit:
+ from py4DSTEM.process.wholepatternfit.wpf_viz import (
+ show_model_grid,
+ show_lattice_points,
+ show_fit_metrics,
+ )
+
+ def __init__(
+ self,
+ datacube: DataCube,
+ x0: Optional[float] = None,
+ y0: Optional[float] = None,
+ mask: Optional[np.ndarray] = None,
+ use_jacobian: bool = True,
+ meanCBED: Optional[np.ndarray] = None,
+ ):
+ """
+ Perform pixelwise fits using composable models and numerical optimization.
+
+ Instantiate components of the fit model using the objects in wp_models,
+ and add them to the WPF object using ``add_model``.
+ All fitting parameters, including ``x0`` and ``y0``, can be specified as
+ floats or, if the parameter should be bounded, as a tuple with the format:
+ (initial guess, lower bound, upper bound)
+ Then, refine the initial guess using ``fit_to_mean_CBED``. If the initial
+ refinement is good, save it using ``accept_mean_CBED_fit``, which updates
+ the initial guesses in each model object.
+
+ Then, refine the model to each diffraction pattern in the dataset using
+ ``fit_all_patterns``. The fit results are returned in RealSlice objects
+ with slice labels corresponding to the names of each model and their
+ parameters.
+
+ To map strain, use ``get_lattice_maps`` to extract RealSlice objects with
+ the refined g vectors at each scan point, and then use the ordinary py4DSTEM
+ strain mapping pipeline.
+
+ Parameters
+ ----------
+ datacube : (DataCube)
+ x0, y0 : Optional float or np.ndarray to specify the initial guess for the origin
+ of diffraction space, in pixels
+ mask : Optional np.ndarray to specify which pixels in the diffraction pattern
+ should be used for computing the loss function. Pixels occluded by a beamstop
+ or fixed detector should be set to False so they do not contribute to the loss
+ use_jacobian: bool, whether or not to use the analytic Jacobians for each model
+ in the optimizer. When False, finite differences are used for all gradient evaluations
+ meanCBED: Optional np.ndarray, specifies the diffraction pattern used for
+ initial refinement of the parameters. If not specified, the average across
+ all scan positions is computed
+
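+ Examples
+ --------
+ A minimal workflow sketch; ``some_model`` stands in for any WPFModel
+ instance constructed from the classes in wp_models::
+
+ >>> wpf = WholePatternFit(datacube, x0=(64.0, 2.0), y0=(64.0, 2.0))
+ >>> wpf.add_model(some_model)
+ >>> wpf.fit_to_mean_CBED()
+ >>> wpf.accept_mean_CBED_fit()
+ >>> fit_data, fit_metrics = wpf.fit_all_patterns()
+ >>> g_maps = wpf.get_lattice_maps()
+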
+ """
+ self.datacube = datacube
+ self.meanCBED = (
+ meanCBED if meanCBED is not None else np.mean(datacube.data, axis=(0, 1))
+ )
+ # Global scaling parameter
+ self.intensity_scale = 1 / np.mean(self.meanCBED)
+
+ self.mask = mask if mask is not None else np.ones_like(self.meanCBED)
+
+ if hasattr(x0, "__iter__") and hasattr(y0, "__iter__"):
+ x0 = np.array(x0)
+ y0 = np.array(y0)
+ if x0.size == 2:
+ global_xy0_lb = np.array([x0[0] - x0[1], y0[0] - y0[1]])
+ global_xy0_ub = np.array([x0[0] + x0[1], y0[0] + y0[1]])
+ elif x0.size == 3:
+ global_xy0_lb = np.array([x0[1], y0[1]])
+ global_xy0_ub = np.array([x0[2], y0[2]])
+ else:
+ global_xy0_lb = np.array([0.0, 0.0])
+ global_xy0_ub = np.array([datacube.Q_Nx, datacube.Q_Ny])
+ x0 = x0[0]
+ y0 = y0[0]
+
+ else:
+ global_xy0_lb = np.array([0.0, 0.0])
+ global_xy0_ub = np.array([datacube.Q_Nx, datacube.Q_Ny])
+
+ # The WPF object holds a special Model that manages the shareable center coordinates
+ self.coordinate_model = _BaseModel(
+ x0=(x0, global_xy0_lb[0], global_xy0_ub[0]),
+ y0=(y0, global_xy0_lb[1], global_xy0_ub[1]),
+ )
+
+ self.model = [
+ self.coordinate_model,
+ ]
+
+ self.nParams = 0
+ self.use_jacobian = use_jacobian
+
+ # set up the global arguments
+ self._setup_static_data()
+
+ # for debugging: tracks all function evals
+ self._track = False
+ self._fevals = []
+ self._xevals = []
+ # self._cost_history = []
+
+ def add_model(self, model: WPFModel):
+ """
+ Add a WPFModel to the current model
+
+ Parameters
+ ----------
+ model: WPFModel
+ model to add to the fitting routine
+ """
+ self.model.append(model)
+
+ self.nParams += len(model.params.keys())
+
+ self._finalize_model()
+
+ def add_model_list(self, model_list: list[WPFModel]):
+ """
+ Add multiple WPFModel objects to the current model
+
+ Parameters
+ ----------
+ model: list[WPFModel]
+ models to add to the fitting routine
+ """
+ for m in model_list:
+ self.add_model(m)
+
+ def link_parameters(
+ self,
+ parent_model: WPFModel,
+ child_model: WPFModel | list[WPFModel],
+ parameters: str | list[str],
+ ):
+ """
+ Link parameters of separate models together. The parameters of
+ the child_model are replaced with the parameters of the parent_model.
+ Note, this does not add the models to the WPF object, that must
+ be performed separately.
+
+ Parameters
+ ----------
+ parent_model: WPFModel
+ model from which parameters will be copied
+ child_model: WPFModel or list of WPFModels
+ model(s) whose independent parameters are to be linked
+ with those of the parent_model
+ parameters: str or list of str
+ names of parameters to be linked
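+
+ Examples
+ --------
+ A sketch sharing the disk shape between two disk models, using the
+ parameter names defined by the disk models in wp_models::
+
+ >>> wpf.link_parameters(lattice_a, lattice_b, ["disk radius", "edge width"])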
+ """
+ # Make sure child_model and parameters are iterable
+ child_model = (
+ [
+ child_model,
+ ]
+ if not hasattr(child_model, "__iter__")
+ else child_model
+ )
+
+ parameters = (
+ [
+ parameters,
+ ]
+ if not hasattr(parameters, "__iter__")
+ else parameters
+ )
+
+ for child in child_model:
+ for par in parameters:
+ child.params[par] = parent_model.params[par]
+
+ def generate_initial_pattern(self) -> np.ndarray:
+ """
+ Generate a diffraction pattern using the initial parameter
+ guesses for each model component
+
+ Returns
+ -------
+ initial_pattern: np.ndarray
+
+ """
+
+ # update parameters:
+ self._finalize_model()
+ return self._pattern(self.x0, self.static_data.copy()) / self.intensity_scale
+
+ def fit_to_mean_CBED(self, **fit_opts):
+ """
+ Fit model parameters to the mean CBED pattern
+
+ Parameters
+ ----------
+ fit_opts: keyword arguments passed to scipy.optimize.least_squares
+
+ Returns
+ -------
+ optimizer_result: dict
+ Output of scipy.optimize.least_squares
+ (also stored in self.mean_CBED_fit)
+
+ """
+ # first make sure we have the latest parameters
+ self._finalize_model()
+
+ # set the current active pattern to the mean CBED:
+ current_pattern = self.meanCBED * self.intensity_scale
+
+ self._fevals = []
+ self._xevals = []
+ self._cost_history = []
+
+ default_opts = {
+ "method": "trf",
+ "verbose": 1,
+ "x_scale": "jac",
+ }
+ default_opts.update(fit_opts)
+
+ if self.hasJacobian and self.use_jacobian:
+ opt = least_squares(
+ self._pattern_error,
+ self.x0,
+ jac=self._jacobian,
+ bounds=(self.lower_bound, self.upper_bound),
+ args=(current_pattern, self.static_data),
+ **default_opts,
+ )
+ else:
+ opt = least_squares(
+ self._pattern_error,
+ self.x0,
+ bounds=(self.lower_bound, self.upper_bound),
+ args=(current_pattern, self.static_data),
+ **default_opts,
+ )
+
+ self.mean_CBED_fit = opt
+
+ # Plotting
+ fig = plt.figure(constrained_layout=True, figsize=(12, 12))
+ gs = GridSpec(2, 2, figure=fig)
+
+ ax = fig.add_subplot(gs[0, 0])
+ err_hist = np.array(self._cost_history)
+ ax.plot(err_hist)
+ ax.set_ylabel("Sum Squared Error")
+ ax.set_xlabel("Iterations")
+ ax.set_yscale("log")
+
+ DP = (
+ self._pattern(self.mean_CBED_fit.x, self.static_data) / self.intensity_scale
+ )
+ ax = fig.add_subplot(gs[0, 1])
+ CyRd = mpl_c.LinearSegmentedColormap.from_list(
+ "CyRd", ["#00ccff", "#ffffff", "#ff0000"]
+ )
+ ax.matshow(
+ err_im := -(DP - self.meanCBED),
+ cmap=CyRd,
+ vmin=-np.abs(err_im).max() / 4,
+ vmax=np.abs(err_im).max() / 4,
+ )
+ ax.set_title("Error")
+ ax.axis("off")
+
+ ax = fig.add_subplot(gs[1, :])
+ ax.matshow(np.hstack((DP, self.meanCBED)) ** 0.25, cmap="turbo")
+ ax.axis("off")
+ ax.text(0.25, 0.92, "Refined", transform=ax.transAxes, ha="center", va="center")
+ ax.text(
+ 0.75, 0.92, "Mean CBED", transform=ax.transAxes, ha="center", va="center"
+ )
+
+ plt.show()
+
+ return opt
+
+ def fit_all_patterns(
+ self,
+ resume: bool = False,
+ real_space_mask: Optional[np.ndarray] = None,
+ show_fit_metrics: bool = True,
+ distributed: bool = True,
+ num_jobs: int = None,
+ threads_per_job: int = 1,
+ **fit_opts,
+ ):
+ """
+ Apply model fitting to all patterns.
+
+ Parameters
+ ----------
+ resume: bool (optional)
+ Set to True to continue a previous fit with more iterations.
+ real_space_mask: np.ndarray of bools (optional)
+ Only perform the fitting on a subset of the probe positions,
+ where real_space_mask[rx,ry] == True.
+ distributed: bool (optional)
+ Whether to evaluate using a pool of worker threads
+ num_jobs: int (optional)
+ number of parallel worker threads to launch if distributed=True
+ Defaults to number of CPU cores
+ threads_per_job: int (optional)
+ number of threads for each parallel job. If num_jobs is not specified,
+ the number of workers is automatically chosen so as to not oversubscribe
+ the cores (num_jobs = CPU_count // threads_per_job)
+ fit_opts: args (optional)
+ args passed to scipy.optimize.least_squares
+
+ Returns
+ -------
+ fit_data: RealSlice
+ Fitted coefficients for all probe positions
+ fit_metrics: RealSlice
+ Fitting metrics for all probe positions
+
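+ Examples
+ --------
+ A sketch of a parallel fit over a subset of probe positions, where
+ ``roi`` is a hypothetical boolean mask over the scan grid::
+
+ >>> fit_data, fit_metrics = wpf.fit_all_patterns(
+ ... real_space_mask=roi, distributed=True, num_jobs=4
+ ... )
+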
+ """
+
+ # make sure we have the latest parameters
+ unique_params, unique_names = self._finalize_model()
+
+ # set tracking off
+ self._track = False
+ self._fevals = []
+
+ if resume:
+ assert hasattr(self, "fit_data"), "No existing fit data to resume from!"
+
+ # init
+ fit_data = np.zeros((self.x0.shape[0], self.datacube.R_Nx, self.datacube.R_Ny))
+ fit_metrics = np.zeros((4, self.datacube.R_Nx, self.datacube.R_Ny))
+
+ # Default fitting options
+ default_opts = {
+ "method": "trf",
+ "verbose": 0,
+ "x_scale": "jac",
+ }
+ default_opts.update(fit_opts)
+
+ # Masking function
+ if real_space_mask is None:
+ mask = np.ones(
+ (self.datacube.R_Nx, self.datacube.R_Ny),
+ dtype=bool,
+ )
+ else:
+ mask = real_space_mask
+
+ # Loop over probe positions
+ if not distributed:
+ for rx, ry in tqdmnd(self.datacube.R_Nx, self.datacube.R_Ny):
+ if mask[rx, ry]:
+ current_pattern = (
+ self.datacube.data[rx, ry, :, :] * self.intensity_scale
+ )
+ x0 = self.fit_data.data[rx, ry] if resume else self.x0
+
+ try:
+ if self.hasJacobian and self.use_jacobian:
+ opt = least_squares(
+ self._pattern_error,
+ x0,
+ jac=self._jacobian,
+ bounds=(self.lower_bound, self.upper_bound),
+ args=(current_pattern, self.static_data),
+ **default_opts,
+ )
+ else:
+ opt = least_squares(
+ self._pattern_error,
+ x0,
+ bounds=(self.lower_bound, self.upper_bound),
+ args=(current_pattern, self.static_data),
+ **default_opts,
+ )
+
+ fit_data_single = opt.x
+ fit_metrics_single = [
+ opt.cost,
+ opt.optimality,
+ opt.nfev,
+ opt.status,
+ ]
+ except Exception:
+ fit_data_single = x0
+ fit_metrics_single = [0, 0, 0, -2]
+
+ fit_data[:, rx, ry] = fit_data_single
+ fit_metrics[:, rx, ry] = fit_metrics_single
+
+ else:
+ # distributed evaluation
+ self._fit_distributed(
+ resume=resume,
+ real_space_mask=mask,
+ num_jobs=num_jobs,
+ threads_per_job=threads_per_job,
+ fit_opts=default_opts,
+ fit_data=fit_data,
+ fit_metrics=fit_metrics,
+ )
+
+ self.fit_data = RealSlice(fit_data, name="Fit Data", slicelabels=unique_names)
+ self.fit_metrics = RealSlice(
+ fit_metrics,
+ name="Fit Metrics",
+ slicelabels=["cost", "optimality", "nfev", "status"],
+ )
+
+ if show_fit_metrics:
+ self.show_fit_metrics()
+
+ return self.fit_data, self.fit_metrics
+
+ def accept_mean_CBED_fit(self):
+ """
+ Sets the parameters optimized by fitting to mean CBED
+ as the initial guess for each of the component models.
+ """
+ x = self.mean_CBED_fit.x
+
+ for model in self.model:
+ for param in model.params.values():
+ param.initial_value = x[param.offset]
+
+ def get_lattice_maps(self) -> list[RealSlice]:
+ """
+ Get the fitted reciprocal lattice vectors refined at each scan point.
+
+ Returns
+ -------
+ g_maps: list[RealSlice]
+ RealSlice objects containing the lattice data for each scan position
+ """
+ assert hasattr(self, "fit_data"), "Please run fitting first!"
+
+ lattices = [m for m in self.model if WPFModelType.LATTICE in m.model_type]
+
+ g_maps = []
+ for lat in lattices:
+ data = np.stack(
+ [
+ self.fit_data.data[lat.params["ux"].offset],
+ self.fit_data.data[lat.params["uy"].offset],
+ self.fit_data.data[lat.params["vx"].offset],
+ self.fit_data.data[lat.params["vy"].offset],
+ self.fit_metrics["status"].data
+ >= 0, # negative status indicates fit error
+ ],
+ axis=0,
+ )
+
+ g_map = RealSlice(
+ data,
+ slicelabels=["g1x", "g1y", "g2x", "g2y", "mask"],
+ name=lat.name,
+ )
+ g_maps.append(g_map)
+
+ return g_maps
+
+ def _setup_static_data(self):
+ """
+ Generate basic data that each model can access during the fitting routine
+ """
+ self.static_data = {}
+
+ xArray, yArray = np.mgrid[0 : self.datacube.Q_Nx, 0 : self.datacube.Q_Ny]
+ self.static_data["xArray"] = xArray
+ self.static_data["yArray"] = yArray
+
+ self.static_data["Q_Nx"] = self.datacube.Q_Nx
+ self.static_data["Q_Ny"] = self.datacube.Q_Ny
+
+ self.static_data["parent"] = self
+
+ def _get_distance(self, params: np.ndarray, x: Parameter, y: Parameter):
+ """
+ Return the distance from a point in pixel coordinates specified
+ by two Parameter objects.
+ This method is intended to cache the result for the _BaseModel center
+ coordinate for performance (caching is not yet implemented).
+ """
+ if (
+ x is self.model[0].params["x center"]
+ and y is self.model[0].params["y center"]
+ ):
+ # TODO: actually implement caching
+ pass
+
+ return np.hypot(
+ self.static_data["xArray"] - params[x.offset],
+ self.static_data["yArray"] - params[y.offset],
+ )
+
+ def _pattern_error(self, x, current_pattern, shared_data):
+ DP = self._pattern(x, shared_data)
+
+ DP = (DP - current_pattern) * self.mask
+
+ if self._track:
+ self._fevals.append(DP)
+ self._xevals.append(x)
+ self._cost_history.append(np.sum(DP**2))
+
+ return DP.ravel()
+
+ def _pattern(self, x, shared_data):
+ DP = np.zeros((self.datacube.Q_Nx, self.datacube.Q_Ny))
+
+ for m in self.model:
+ m.func(DP, x, **shared_data)
+
+ return DP * self.mask
+
+ def _jacobian(self, x, current_pattern, shared_data):
+ # TODO: automatic mixed analytic/finite difference
+
+ J = np.zeros(((self.datacube.Q_Nx * self.datacube.Q_Ny), self.nParams))
+
+ for m in self.model:
+ m.jacobian(J, x, **shared_data)
+
+ return J * self.mask.ravel()[:, np.newaxis]
+
+ def _finalize_model(self):
+ # iterate over all models and assign indices, accumulate list
+ # of unique parameters. then, accumulate initial value and bounds vectors
+
+ # get unique names for each model
+ model_names = []
+ for m in self.model:
+ n = m.name
+ if n in model_names:
+ i = 1
+ while n in model_names:
+ n = m.name + "_" + str(i)
+ i += 1
+ model_names.append(n)
+
+ unique_params = []
+ unique_names = []
+ idx = 0
+ for model, model_name in zip(self.model, model_names):
+ for param_name, param in model.params.items():
+ if param not in unique_params:
+ unique_params.append(param)
+ unique_names.append(model_name + "/" + param_name)
+ param.offset = idx
+ idx += 1
+
+ self.x0 = np.array([param.initial_value for param in unique_params])
+ self.upper_bound = np.array([param.upper_bound for param in unique_params])
+ self.lower_bound = np.array([param.lower_bound for param in unique_params])
+
+ self.hasJacobian = all([m.hasJacobian for m in self.model])
+
+ self.nParams = self.x0.shape[0]
+
+ return unique_params, unique_names
+
+ def _fit_single_pattern(
+ self,
+ data: np.ndarray,
+ initial_guess: np.ndarray,
+ mask: bool,
+ fit_opts,
+ ):
+ """
+ Apply model fitting to one pattern.
+
+ Parameters
+ ----------
+ data: np.ndarray
+ Diffraction pattern
+ initial_guess: np.ndarray
+ starting guess for fitting
+ mask: bool
+ Fitting is skipped if mask is False, and default values are returned
+ fit_opts:
+ args passed to scipy.optimize.least_squares
+
+ Returns
+ -------
+ fit_coefs: np.array
+ Fitted coefficients
+ fit_metrics: np.array
+ Fitting metrics
+
+ """
+ if mask:
+ try:
+ if self.hasJacobian and self.use_jacobian:
+ opt = least_squares(
+ self._pattern_error,
+ initial_guess,
+ jac=self._jacobian,
+ bounds=(self.lower_bound, self.upper_bound),
+ args=(data * self.intensity_scale, self.static_data),
+ **fit_opts,
+ )
+ else:
+ opt = least_squares(
+ self._pattern_error,
+ initial_guess,
+ bounds=(self.lower_bound, self.upper_bound),
+ args=(data * self.intensity_scale, self.static_data),
+ **fit_opts,
+ )
+
+ fit_coefs = opt.x
+ fit_metrics_single = [
+ opt.cost,
+ opt.optimality,
+ opt.nfev,
+ opt.status,
+ ]
+ except Exception as err:
+ # print(err)
+ fit_coefs = initial_guess
+ fit_metrics_single = [0, 0, 0, -2]
+
+ return fit_coefs, fit_metrics_single
+ else:
+ return np.zeros_like(initial_guess), [0, 0, 0, 0]
+
+ def _fit_distributed(
+ self,
+ fit_opts: dict,
+ fit_data: np.ndarray,
+ fit_metrics: np.ndarray,
+ real_space_mask: np.ndarray,
+ resume=False,
+ num_jobs=None,
+ threads_per_job=1,
+ ):
+ """
+ Run fitting using multiprocessing to fit several patterns in parallel
+ """
+ from mpire import WorkerPool, cpu_count
+ from threadpoolctl import threadpool_limits
+
+ # prevent oversubscription when using multiple threads per job
+ num_jobs = num_jobs or cpu_count() // threads_per_job
+
+ def f(shared_data, args):
+ with threadpool_limits(limits=threads_per_job):
+ return self._fit_single_pattern(**args, fit_opts=shared_data)
+
+ # hopefully the data entries remain as views until dispatch time...
+ fit_inputs = [
+ (
+ {
+ "data": self.datacube[rx, ry],
+ "initial_guess": self.fit_data[rx, ry] if resume else self.x0,
+ "mask": real_space_mask[rx, ry],
+ },
+ )
+ for rx in range(self.datacube.R_Nx)
+ for ry in range(self.datacube.R_Ny)
+ ]
+
+ with WorkerPool(
+ n_jobs=num_jobs,
+ shared_objects=fit_opts,
+ ) as pool:
+ results = pool.map(
+ f,
+ fit_inputs,
+ progress_bar=True,
+ )
+
+ for (rx, ry), res in zip(
+ np.ndindex((self.datacube.R_Nx, self.datacube.R_Ny)), results
+ ):
+ fit_data[:, rx, ry] = res[0]
+ fit_metrics[:, rx, ry] = res[1]
+
+ def __getstate__(self):
+ # Prevent pickling from copying the datacube, so that distributed
+ # evaluation does not balloon memory usage.
+ # Copy the object's state from self.__dict__ which contains
+ # all our instance attributes. Always use the dict.copy()
+ # method to avoid modifying the original state.
+ state = self.__dict__.copy()
+ # Remove the unpicklable entries.
+ del state["datacube"]
+ return state
diff --git a/py4DSTEM/process/wholepatternfit/wpf_viz.py b/py4DSTEM/process/wholepatternfit/wpf_viz.py
new file mode 100644
index 000000000..436ae40a2
--- /dev/null
+++ b/py4DSTEM/process/wholepatternfit/wpf_viz.py
@@ -0,0 +1,262 @@
+from typing import Optional
+import numpy as np
+
+import matplotlib.pyplot as plt
+import matplotlib.colors as mpl_c
+from matplotlib.gridspec import GridSpec
+
+from py4DSTEM.process.wholepatternfit.wp_models import WPFModelType
+
+
+def show_model_grid(self, x=None, **plot_kwargs):
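+ """
+ Plot each model component as a separate panel, evaluated at parameter
+ vector x (defaults to the mean-CBED fit result).
+ """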
+ x = self.mean_CBED_fit.x if x is None else x
+
+ model = [m for m in self.model if WPFModelType.DUMMY not in m.model_type]
+
+ N = len(model)
+ cols = int(np.ceil(np.sqrt(N)))
+ rows = int(np.ceil(N / cols))
+
+ kwargs = dict(constrained_layout=True)
+ kwargs.update(plot_kwargs)
+ fig, ax = plt.subplots(rows, cols, **kwargs)
+
+ for a, m in zip(ax.flat, model):
+ DP = np.zeros((self.datacube.Q_Nx, self.datacube.Q_Ny))
+ m.func(DP, x, **self.static_data)
+
+ a.matshow(DP, cmap="turbo")
+
+ # Determine if text color should be white or black
+ int_range = np.array((np.min(DP), np.max(DP)))
+ if int_range[0] != int_range[1]:
+ r = (np.mean(DP[: DP.shape[0] // 10, :]) - int_range[0]) / (
+ int_range[1] - int_range[0]
+ )
+ if r < 0.5:
+ color = "w"
+ else:
+ color = "k"
+ else:
+ color = "w"
+
+ a.text(
+ 0.5,
+ 0.92,
+ m.name,
+ transform=a.transAxes,
+ ha="center",
+ va="center",
+ color=color,
+ )
+ for a in ax.flat:
+ a.axis("off")
+
+ plt.show()
+
+
+def show_lattice_points(
+ self,
+ im=None,
+ vmin=None,
+ vmax=None,
+ power=None,
+ show_vectors=True,
+ crop_to_pattern=False,
+ returnfig=False,
+ moire_origin_idx=[0, 0, 0, 0],
+ *args,
+ **kwargs,
+):
+ """
+ Plotting utility to show the initial lattice points.
+
+ Parameters
+ ----------
+ im: np.ndarray
+ Optional: Image to show, defaults to mean CBED
+ vmin, vmax: float
+ Intensity ranges for plotting im
+ power: float
+ Gamma level for showing im
+ show_vectors: bool
+ Flag to plot the lattice vectors
+ crop_to_pattern: bool
+ Flag to limit the field of view to the pattern area. If False,
+ spots outside the pattern are shown
+ returnfig: bool
+ If True, (fig,ax) are returned and plt.show() is not called
+ moire_origin_idx: list of length 4
+ Indices of peak on which to draw Moire vectors, written as
+ [a_u, a_v, b_u, b_v]
+ args, kwargs
+ Passed to plt.subplots
+
+ Returns
+ -------
+ fig,ax: If returnfig=True
+ """
+
+ if im is None:
+ im = self.meanCBED
+ if power is None:
+ power = 0.5
+
+ fig, ax = plt.subplots(*args, **kwargs)
+ if vmin is None and vmax is None:
+ ax.matshow(
+ im**power,
+ cmap="gray",
+ )
+ else:
+ ax.matshow(
+ im**power,
+ vmin=vmin,
+ vmax=vmax,
+ cmap="gray",
+ )
+
+ lattices = [m for m in self.model if WPFModelType.LATTICE in m.model_type]
+
+ for m in lattices:
+ ux, uy = m.params["ux"].initial_value, m.params["uy"].initial_value
+ vx, vy = m.params["vx"].initial_value, m.params["vy"].initial_value
+
+ lat = np.array([[ux, uy], [vx, vy]])
+ inds = np.stack([m.u_inds, m.v_inds], axis=1)
+
+ spots = inds @ lat
+ spots[:, 0] += m.params["x center"].initial_value
+ spots[:, 1] += m.params["y center"].initial_value
+
+ axpts = ax.scatter(
+ spots[:, 1],
+ spots[:, 0],
+ s=100,
+ marker="x",
+ label=m.name,
+ )
+
+ if show_vectors:
+ ax.arrow(
+ m.params["y center"].initial_value,
+ m.params["x center"].initial_value,
+ m.params["uy"].initial_value,
+ m.params["ux"].initial_value,
+ length_includes_head=True,
+ color=axpts.get_facecolor(),
+ width=1.0,
+ )
+
+ ax.arrow(
+ m.params["y center"].initial_value,
+ m.params["x center"].initial_value,
+ m.params["vy"].initial_value,
+ m.params["vx"].initial_value,
+ length_includes_head=True,
+ color=axpts.get_facecolor(),
+ width=1.0,
+ )
+
+ moires = [m for m in self.model if WPFModelType.MOIRE in m.model_type]
+
+ for m in moires:
+ lat_ab = m._get_parent_lattices(m.lattice_a, m.lattice_b)
+ lat_abm = np.vstack((lat_ab, m.moire_matrix @ lat_ab))
+
+ spots = m.moire_indices_uvm @ lat_abm
+ spots[:, 0] += m.params["x center"].initial_value
+ spots[:, 1] += m.params["y center"].initial_value
+
+ axpts = ax.scatter(
+ spots[:, 1],
+ spots[:, 0],
+ s=100,
+ marker="+",
+ label=m.name,
+ )
+
+ if show_vectors:
+ arrow_origin = np.array(moire_origin_idx) @ lat_ab
+ arrow_origin[0] += m.params["x center"].initial_value
+ arrow_origin[1] += m.params["y center"].initial_value
+
+ ax.arrow(
+ arrow_origin[1],
+ arrow_origin[0],
+ lat_abm[4, 1],
+ lat_abm[4, 0],
+ length_includes_head=True,
+ color=axpts.get_facecolor(),
+ width=1.0,
+ )
+
+ ax.arrow(
+ arrow_origin[1],
+ arrow_origin[0],
+ lat_abm[5, 1],
+ lat_abm[5, 0],
+ length_includes_head=True,
+ color=axpts.get_facecolor(),
+ width=1.0,
+ )
+
+ ax.legend()
+
+ if crop_to_pattern:
+ ax.set_xlim(0, im.shape[1] - 1)
+ ax.set_ylim(im.shape[0] - 1, 0)
+
+ return (fig, ax) if returnfig else plt.show()
+
+
+def show_fit_metrics(self, returnfig=False, **subplots_kwargs):
+ assert hasattr(self, "fit_metrics"), "Please run fitting first!"
+
+ kwargs = dict(figsize=(14, 12), constrained_layout=True)
+ kwargs.update(subplots_kwargs)
+ fig, ax = plt.subplots(2, 2, **kwargs)
+ im = ax[0, 0].matshow(self.fit_metrics["cost"].data, norm=mpl_c.LogNorm())
+ ax[0, 0].set_title("Final Cost Function")
+ fig.colorbar(im, ax=ax[0, 0])
+
+ opt_cmap = mpl_c.ListedColormap(
+ (
+ (0.6, 0.05, 0.05),
+ (0.8941176470588236, 0.10196078431372549, 0.10980392156862745),
+ (0.21568627450980393, 0.49411764705882355, 0.7215686274509804),
+ (0.30196078431372547, 0.6862745098039216, 0.2901960784313726),
+ (0.596078431372549, 0.3058823529411765, 0.6392156862745098),
+ (1.0, 0.4980392156862745, 0.0),
+ (1.0, 1.0, 0.2),
+ )
+ )
+ im = ax[0, 1].matshow(
+ self.fit_metrics["status"].data, cmap=opt_cmap, vmin=-2.5, vmax=4.5
+ )
+ cbar = fig.colorbar(im, ax=ax[0, 1], ticks=[-2, -1, 0, 1, 2, 3, 4])
+ cbar.ax.set_yticklabels(
+ [
+ "Unknown Error",
+ "MINPACK Error",
+ "Max f evals exceeded",
+ "$gtol$ satisfied",
+ "$ftol$ satisfied",
+ "$xtol$ satisfied",
+ "$xtol$ & $ftol$ satisfied",
+ ]
+ )
+ ax[0, 1].set_title("Optimizer Status")
+ fig.set_facecolor("w")
+
+ im = ax[1, 0].matshow(self.fit_metrics["optimality"].data, norm=mpl_c.LogNorm())
+ ax[1, 0].set_title("First Order Optimality")
+ fig.colorbar(im, ax=ax[1, 0])
+
+ im = ax[1, 1].matshow(self.fit_metrics["nfev"].data)
+ ax[1, 1].set_title("Number f evals")
+ fig.colorbar(im, ax=ax[1, 1])
+
+ fig.set_facecolor("w")
+
+ return (fig, ax) if returnfig else plt.show()
diff --git a/py4DSTEM/utils/__init__.py b/py4DSTEM/utils/__init__.py
new file mode 100644
index 000000000..b0c484e80
--- /dev/null
+++ b/py4DSTEM/utils/__init__.py
@@ -0,0 +1 @@
+from py4DSTEM.utils.configuration_checker import check_config
diff --git a/py4DSTEM/utils/configuration_checker.py b/py4DSTEM/utils/configuration_checker.py
new file mode 100644
index 000000000..904dceb29
--- /dev/null
+++ b/py4DSTEM/utils/configuration_checker.py
@@ -0,0 +1,502 @@
+#### This file contains functions that check whether various
+# libraries/compute options are available
+import importlib
+
+# list of modules we expect/may expect to be installed
+# as part of a standard py4DSTEM installation
+# this needs to be the import name e.g. import mp_api not mp-api
+modules = [
+ "crystal4D",
+ "cupy",
+ "dask",
+ "dill",
+ "distributed",
+ "gdown",
+ "h5py",
+ "ipyparallel",
+ "jax",
+ "matplotlib",
+ "mp_api",
+ "ncempy",
+ "numba",
+ "numpy",
+ "pymatgen",
+ "skimage",
+ "sklearn",
+ "scipy",
+ "tensorflow",
+ "tensorflow-addons",
+ "tqdm",
+]
+
+# currently this was copied and pasted from setup.py;
+# hopefully there's a programmatic way to do this.
+module_dependencies = {
+ "base": [
+ "numpy",
+ "scipy",
+ "h5py",
+ "ncempy",
+ "matplotlib",
+ "skimage",
+ "sklearn",
+ "tqdm",
+ "dill",
+ "gdown",
+ "dask",
+ "distributed",
+ ],
+ "ipyparallel": ["ipyparallel", "dill"],
+ "cuda": ["cupy"],
+ "acom": ["pymatgen", "mp_api"],
+ "aiml": ["tensorflow", "tensorflow-addons", "crystal4D"],
+ "aiml-cuda": ["tensorflow", "tensorflow-addons", "crystal4D", "cupy"],
+ "numba": ["numba"],
+}
+
+
+#### Class and Functions to Create Coloured Strings ####
+class colours:
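+ """ANSI escape sequences used to build coloured terminal output."""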
+ CEND = "\x1b[0m"
+ WARNING = "\x1b[7;93m"
+ SUCCESS = "\x1b[7;92m"
+ FAIL = "\x1b[7;91m"
+ BOLD = "\033[1m"
+ UNDERLINE = "\033[4m"
+
+
+def create_warning(s: str) -> str:
+ """
+ Creates a yellow shaded with white font version of string s
+
+ Args:
+ s (str): string to be turned into a warning style string
+
+ Returns:
+ str: stylized version of string s
+ """
+ s = colours.WARNING + s + colours.CEND
+ return s
+
+
+def create_success(s: str) -> str:
+ """
+ Creates a green shaded with white font version of string s
+
+ Args:
+ s (str): string to be turned into a success style string
+
+ Returns:
+ str: stylized version of string s
+ """
+ s = colours.SUCCESS + s + colours.CEND
+ return s
+
+
+def create_failure(s: str) -> str:
+ """
+ Creates a red shaded with white font version of string s
+
+ Args:
+ s (str): string to be turned into a failure style string
+
+ Returns:
+ str: stylized version of string s
+ """
+ s = colours.FAIL + s + colours.CEND
+ return s
+
+
+def create_bold(s: str) -> str:
+ """
+ Creates a bold version of string s
+
+ Args:
+ s (str): string to be turned into a bold style string
+
+ Returns:
+ str: stylized version of string s
+ """
+ s = colours.BOLD + s + colours.CEND
+ return s
+
+
+def create_underline(s: str) -> str:
+ """
+ Creates an underlined version of string s
+
+ Args:
+ s (str): string to be turned into an underlined style string
+
+ Returns:
+ str: stylized version of string s
+ """
+ s = colours.UNDERLINE + s + colours.CEND
+ return s
+
+
+#### Functions to check imports etc.
+### here I use the term state to define a boolean condition as to whether a library/module was successfully imported/can be used
+
+
+def get_import_states(modules: list = modules) -> dict:
+ """
+ Check the ability to import modules and store the results as a boolean value. Returns as a dict.
+ e.g. import_states_dict['numpy'] == True/False
+
+ Args:
+ modules (list, optional): List of modules to check the ability to import. Defaults to modules.
+
+ Returns:
+ dict: dictionary where key is the name of the module, and val is the boolean state if it could be imported
+ """
+ # Create the import states dict
+ import_states_dict = {}
+
+ # check if the modules import
+ # and update the states dict to reflect this
+ for m in modules:
+ state = import_tester(m)
+ import_states_dict[m] = state
+
+ return import_states_dict
+
+
+def get_module_states(state_dict: dict) -> dict:
+ """_summary_
+
+ Args:
+ state_dict (dict): _description_
+
+ Returns:
+ dict: _description_
+ """
+
+ # create an empty dict to put module states into:
+ module_states = {}
+
+ # key is the name of the module e.g. ACOM
+ # val is a list of its dependencies
+ # module_dependencies comes from the namespace
+ for key, val in module_dependencies.items():
+ # create a list to store the status of the dependencies
+ temp_lst = []
+
+ # loop over all the dependencies required for the module to work
+ # append the bool if they could be imported
+ for depend in val:
+ temp_lst.append(state_dict[depend])
+
+ # check that all the dependencies could be imported i.e. state == True
+ # and set the state of the module to that
+ module_states[key] = all(temp_lst)
+
+ return module_states
+
+
+def print_import_states(import_states: dict) -> None:
+ """_summary_
+
+ Args:
+ import_states (dict): _description_
+
+ Returns:
+ _type_: _description_
+ """
+ # m is the name of the import module
+ # state is whether it was importable
+ for m, state in import_states.items():
+ # if it was importable i.e. state == True
+ if state:
+ s = f" Module {m.capitalize()} Imported Successfully "
+ s = create_success(s)
+ s = f"{s: <80}"
+ print(s)
+ # if unable to import i.e. state == False
+ else:
+ s = f" Module {m.capitalize()} Import Failed "
+ s = create_failure(s)
+ s = f"{s: <80}"
+ print(s)
+ return None
+
+
+def print_module_states(module_states: dict) -> None:
+ """_summary_
+
+ Args:
+ module_states (dict): _description_
+
+ Returns:
+ _type_: _description_
+ """
+ # Print out the state of all the modules in colour code
+ # key is the name of a py4DSTEM Module
+ # Val is the state i.e. True/False
+ for key, val in module_states.items():
+ # if the state is True i.e. all dependencies are installed
+ if val:
+ s = f" All Dependencies for {key.capitalize()} are Installed "
+ s = create_success(s)
+ print(s)
+ # if something is missing
+ else:
+ s = f" Not All Dependencies for {key.capitalize()} are Installed"
+ s = create_failure(s)
+ print(s)
+ return None
+
+
+def perform_extra_checks(
+ import_states: dict, verbose: bool, gratuitously_verbose: bool, **kwargs
+) -> None:
+ """
+ Run any additional diagnostic checks for the modules that imported successfully.
+
+ Args:
+ import_states (dict): dictionary mapping module names to their import state (bool)
+ verbose (bool): passed through to the extra check functions
+ gratuitously_verbose (bool): when True, also list modules that have no extra checks
+
+ Returns:
+ None
+ """
+
+ # print a header message
+ extra_checks_message = "Running Extra Checks"
+ extra_checks_message = create_bold(extra_checks_message)
+ print(f"{extra_checks_message}")
+ # For modules that imported successfully, run any extra checks
+ for key, val in import_states.items():
+ if val:
+ # s = create_underline(key.capitalize())
+ # print(s)
+ func = funcs_dict.get(key)
+ if func is not None:
+ s = create_underline(key.capitalize())
+ print(s)
+ func(verbose=verbose, gratuitously_verbose=gratuitously_verbose)
+ else:
+ # if gratuitously_verbose print out all modules without checks
+ if gratuitously_verbose:
+ s = create_underline(key.capitalize())
+ print(s)
+ print_no_extra_checks(key)
+ else:
+ pass
+
+ return None
+
+
+def import_tester(m: str) -> bool:
+ """
+ This function will try to import the module, m,
+ and returns the success as a boolean.
+ Args:
+ m (str): string name of a module
+
+ Returns:
+ bool: boolean if the module was able to be imported
+ """
+ # set a boolean switch
+ state = True
+
+ # try and import the module
+ try:
+ importlib.import_module(m)
+ except Exception:
+ state = False
+
+ return state
+
+
+def check_module_functionality(state_dict: dict) -> None:
+ """
+ This function checks all the py4DSTEM modules, e.g. acom, aiml, and whether all the required dependencies are importable
+
+ Args:
+ state_dict (dict): dictionary of the state, i.e. boolean, of all the modules and the ability to import.
+ It will then print a 'Success' or 'Failure' message. All dependencies must be available to succeed.
+
+ Returns:
+ None: Prints the state of a py4DSTEM module's library dependencies
+ """
+
+ # create an empty dict to put module states into:
+ module_states = {}
+
+ # key is the name of the module e.g. ACOM
+ # val is a list of its dependencies
+ for key, val in module_dependencies.items():
+ # create a list to store the status of the dependencies
+ temp_lst = []
+
+ # loop over all the dependencies required for the module to work
+ # append the bool if they could be imported
+ for depend in val:
+ temp_lst.append(state_dict[depend])
+
+ # check that all the dependencies could be imported i.e. state == True
+ # and set the state of the module to that
+ module_states[key] = all(temp_lst)
+
+ # Print out the state of all the modules in colour code
+ for key, val in module_states.items():
+ # if the state is True
+ if val:
+ s = f" All Dependencies for {key.capitalize()} are Installed "
+ s = create_success(s)
+ print(s)
+ # if something is missing
+ else:
+ s = f" Not All Dependencies for {key.capitalize()} are Installed"
+ s = create_failure(s)
+ print(s)
+
+ return None # module_states
+
+
+#### ADDITIONAL CHECKS ####
+
+
+def check_cupy_gpu(gratuitously_verbose: bool, **kwargs):
+ """
+ This function performs some additional tests which may be useful in
+ diagnosing Cupy GPU performance
+
+ Args:
+ gratuitously_verbose (bool): Will print out the attributes of all detected GPUs. Defaults to False.
+ """
+ # import some libraries
+ from pprint import pprint
+ import cupy as cp
+
+ # check that CUDA is detected correctly
+ cuda_availability = cp.cuda.is_available()
+ if cuda_availability:
+ s = " CUDA is Available "
+ s = create_success(s)
+ s = f"{s: <80}"
+ print(s)
+ else:
+ s = " CUDA is Unavailable "
+ s = create_failure(s)
+ s = f"{s: <80}"
+ print(s)
+
+ # Count how many GPUs Cupy can detect
+ # if all 24 probes succeed, the count stays at 24 and is reported as "at least 24"
+ num_gpus_detected = 24
+ for i in range(24):
+ try:
+ d = cp.cuda.Device(i)
+ hasattr(d, "attributes") # accessing attributes raises for an invalid device index
+ except Exception:
+ num_gpus_detected = i
+ break
+
+ # print how many GPUs were detected, filter for a couple of special conditions
+ if num_gpus_detected == 0:
+ s = " Detected no GPUs "
+ s = create_failure(s)
+ s = f"{s: <80}"
+ print(s)
+ elif num_gpus_detected >= 24:
+ s = " Detected at least 24 GPUs, could be more "
+ s = create_warning(s)
+ s = f"{s: <80}"
+ print(s)
+ else:
+ s = f" Detected {num_gpus_detected} GPUs "
+ s = create_success(s)
+ s = f"{s: <80}"
+ print(s)
+
+ cuda_path = cp.cuda.get_cuda_path()
+ print(f"Detected CUDA Path:\t{cuda_path}")
+ cupy_version = cp.__version__
+ print(f"Cupy Version:\t\t{cupy_version}")
+
+ # if gratuitously_verbose, print per-GPU attributes
+ if gratuitously_verbose:
+ for i in range(num_gpus_detected):
+ d = cp.cuda.Device(i)
+ s = f"GPU: {i}"
+ s = create_warning(s)
+ print(f" {s} ")
+ pprint(d.attributes)
+ return None
+
+
+def print_no_extra_checks(m: str):
+ """
+ This function prints a warning style message that the module m
+ currently has no extra checks.
+
+ Args:
+ m (str): This is the name of the module
+
+ Returns:
+ None
+ """
+ s = f" There are no Extra Checks for {m} "
+ s = create_warning(s)
+ s = f"{s}"
+ print(s)
+
+ return None
+
+
+# dict of extra check functions
+funcs_dict = {"cupy": check_cupy_gpu}
+
+
+#### main function used to check the configuration of the installation
+def check_config(
+ # modules:list = modules, # removed to not be user editable as this will often break. Could make it append to modules... but for now just removing
+ verbose: bool = False,
+ gratuitously_verbose: bool = False,
+ # egregiously_verbose:bool = False
+) -> None:
+ """
+ This function checks the state of required imports to run py4DSTEM.
+
+ Default behaviour will provide a summary of install dependencies for each module e.g. Base, ACOM etc.
+
+ Args:
+ verbose (bool, optional): Will provide the status of all possible requirements for py4DSTEM, and perform any additional checks. Defaults to False.
+ gratuitously_verbose (bool, optional): Provides more in-depth analysis. Defaults to False.
+
+ Returns:
+ None
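+
+ Example:
+ A typical invocation, assuming py4DSTEM is installed:
+
+ >>> from py4DSTEM.utils import check_config
+ >>> check_config(verbose=True)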
+ """
+
+ # get the states of all imports
+ states_dict = get_import_states(modules)
+
+ # get the states of all modules dependencies
+ modules_dict = get_module_states(states_dict)
+
+ # print the modules' compatibility
+ # prepare a message
+ modules_checks_message = "Checking Module Dependencies"
+ modules_checks_message = create_bold(modules_checks_message)
+ print(modules_checks_message)
+ # print the status
+ print_module_states(modules_dict)
+
+ if verbose:
+ # Print that Import Checks are happening
+ imports_check_message = "Running Import Checks"
+ imports_check_message = create_bold(imports_check_message)
+ print(f"{imports_check_message}")
+
+ print_import_states(states_dict)
+
+ perform_extra_checks(
+ import_states=states_dict,
+ verbose=verbose,
+ gratuitously_verbose=gratuitously_verbose,
+ )
+
+ return None
diff --git a/py4DSTEM/version.py b/py4DSTEM/version.py
new file mode 100644
index 000000000..103751b29
--- /dev/null
+++ b/py4DSTEM/version.py
@@ -0,0 +1 @@
+__version__ = "0.14.9"
diff --git a/py4DSTEM/visualize/README.md b/py4DSTEM/visualize/README.md
new file mode 100644
index 000000000..6bbbac233
--- /dev/null
+++ b/py4DSTEM/visualize/README.md
@@ -0,0 +1,11 @@
+# `py4DSTEM.visualize`
+
+Visualization functions. The basic visualization function has the call signature
+
+```
+show(ar,figsize=(5,5),cmap='gray',scaling='none',intensity_range='ordered',
+     vmin=None,vmax=None,returnfig=False,figax=None,**kwargs)
+```
+
+Most other visualization functions are built on top of this one, and accept these arguments, possibly plus others. Additional keyword arguments passed as `**kwargs` are forwarded to matplotlib's `ax.matshow`. To create a plot and then edit it further, set `returnfig=True`; the function then returns the 2-tuple `(fig,ax)`.
+
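+A minimal usage sketch (assumes a 2D numpy array; `show` is importable from `py4DSTEM.visualize`):
+
+```
+import numpy as np
+from py4DSTEM.visualize import show
+
+ar = np.random.rand(128, 128)
+
+# plot, then grab the figure and axes for further customization
+fig, ax = show(ar, returnfig=True)
+ax.set_title("random test image")
+```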
diff --git a/py4DSTEM/visualize/__init__.py b/py4DSTEM/visualize/__init__.py
new file mode 100644
index 000000000..d9e5b4c68
--- /dev/null
+++ b/py4DSTEM/visualize/__init__.py
@@ -0,0 +1,5 @@
+from py4DSTEM.visualize.overlay import *
+from py4DSTEM.visualize.show import *
+from py4DSTEM.visualize.vis_RQ import *
+from py4DSTEM.visualize.vis_grid import *
+from py4DSTEM.visualize.vis_special import *
diff --git a/py4DSTEM/visualize/overlay.py b/py4DSTEM/visualize/overlay.py
new file mode 100644
index 000000000..32baff443
--- /dev/null
+++ b/py4DSTEM/visualize/overlay.py
@@ -0,0 +1,1293 @@
+import numpy as np
+from matplotlib.patches import Rectangle, Circle, Wedge, Ellipse
+from matplotlib.axes import Axes
+from matplotlib.colors import is_color_like
+from numbers import Number
+from math import log
+from fractions import Fraction
+
+from emdfile import PointList
+
+
+def add_rectangles(ax, d):
+ """
+ Adds one or more rectangles to Axis ax using the parameters in dictionary d.
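+
+    The dictionary d has required and optional parameters as follows:
+        lims (req'd) a 4-tuple (xmin,xmax,ymin,ymax), or a list of 4-tuples
+        color (color)
+        fill (bool)
+        alpha (number)
+        linewidth (number)
+
+    Example::
+
+        >>> add_rectangles(ax, {'lims':(10,20,10,20),'color':'r'})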
+ """
+ # Handle inputs
+ assert isinstance(ax, Axes)
+ # lims
+ assert "lims" in d.keys()
+ lims = d["lims"]
+ if isinstance(lims, tuple):
+ assert len(lims) == 4
+ lims = [lims]
+ assert isinstance(lims, list)
+ N = len(lims)
+ assert all([isinstance(t, tuple) for t in lims])
+ assert all([len(t) == 4 for t in lims])
+ # color
+ color = d["color"] if "color" in d.keys() else "r"
+ if isinstance(color, list):
+ assert len(color) == N
+ assert all([is_color_like(c) for c in color])
+ else:
+ assert is_color_like(color)
+ color = [color for i in range(N)]
+ # fill
+ fill = d["fill"] if "fill" in d.keys() else False
+ if isinstance(fill, bool):
+ fill = [fill for i in range(N)]
+ else:
+ assert isinstance(fill, list)
+ assert len(fill) == N
+ assert all([isinstance(f, bool) for f in fill])
+ # alpha
+ alpha = d["alpha"] if "alpha" in d.keys() else 1
+ if isinstance(alpha, (float, int, np.float64)):
+ alpha = [alpha for i in range(N)]
+ else:
+ assert isinstance(alpha, list)
+ assert len(alpha) == N
+ assert all([isinstance(a, (float, int, np.float64)) for a in alpha])
+ # linewidth
+ linewidth = d["linewidth"] if "linewidth" in d.keys() else 2
+ if isinstance(linewidth, (float, int, np.float64)):
+ linewidth = [linewidth for i in range(N)]
+ else:
+ assert isinstance(linewidth, list)
+ assert len(linewidth) == N
+ assert all([isinstance(lw, (float, int, np.float64)) for lw in linewidth])
+ # additional parameters
+ kws = [
+ k for k in d.keys() if k not in ("lims", "color", "fill", "alpha", "linewidth")
+ ]
+ kwargs = dict()
+ for k in kws:
+ kwargs[k] = d[k]
+
+ # add the rectangles
+ for i in range(N):
+ l, c, f, a, lw = lims[i], color[i], fill[i], alpha[i], linewidth[i]
+ rect = Rectangle(
+ (l[2] - 0.5, l[0] - 0.5),
+ l[3] - l[2],
+ l[1] - l[0],
+ color=c,
+ fill=f,
+ alpha=a,
+ linewidth=lw,
+ **kwargs,
+ )
+ ax.add_patch(rect)
+
+ return
+
+
+def add_circles(ax, d):
+ """
+ adds one or more circles to axis ax using the parameters in dictionary d.
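+
+    The dictionary d has required and optional parameters as follows:
+        center (req'd) a 2-tuple (x,y), or a list of 2-tuples
+        R (req'd) the radius, a number or a list of numbers
+        color (color)
+        fill (bool)
+        alpha (number)
+        linewidth (number)
+
+    Example::
+
+        >>> add_circles(ax, {'center':(64,64),'R':12,'color':'r'})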
+ """
+ # handle inputs
+ assert isinstance(ax, Axes)
+ # center
+ assert "center" in d.keys()
+ center = d["center"]
+ if isinstance(center, tuple):
+ assert len(center) == 2
+ center = [center]
+ assert isinstance(center, list)
+ N = len(center)
+ assert all([isinstance(x, tuple) for x in center])
+ assert all([len(x) == 2 for x in center])
+ # radius
+ assert "R" in d.keys()
+ R = d["R"]
+ if isinstance(R, Number):
+ R = [R for i in range(N)]
+ assert isinstance(R, list)
+ assert len(R) == N
+ assert all([isinstance(i, Number) for i in R])
+ # color
+ color = d["color"] if "color" in d.keys() else "r"
+ if isinstance(color, list):
+ assert len(color) == N
+ assert all([is_color_like(c) for c in color])
+ else:
+ assert is_color_like(color)
+ color = [color for i in range(N)]
+ # fill
+ fill = d["fill"] if "fill" in d.keys() else False
+ if isinstance(fill, bool):
+ fill = [fill for i in range(N)]
+ else:
+ assert isinstance(fill, list)
+ assert len(fill) == N
+ assert all([isinstance(f, bool) for f in fill])
+ # alpha
+ alpha = d["alpha"] if "alpha" in d.keys() else 1
+ if isinstance(alpha, (float, int, np.float64)):
+ alpha = [alpha for i in range(N)]
+ else:
+ assert isinstance(alpha, list)
+ assert len(alpha) == N
+ assert all([isinstance(a, (float, int, np.float64)) for a in alpha])
+ # linewidth
+ linewidth = d["linewidth"] if "linewidth" in d.keys() else 2
+ if isinstance(linewidth, (float, int, np.float64)):
+ linewidth = [linewidth for i in range(N)]
+ else:
+ assert isinstance(linewidth, list)
+ assert len(linewidth) == N
+ assert all([isinstance(lw, (float, int, np.float64)) for lw in linewidth])
+ # additional parameters
+ kws = [
+ k
+ for k in d.keys()
+ if k not in ("center", "R", "color", "fill", "alpha", "linewidth")
+ ]
+ kwargs = dict()
+ for k in kws:
+ kwargs[k] = d[k]
+
+ # add the circles
+ for i in range(N):
+ cent, r, col, f, a, lw = (
+ center[i],
+ R[i],
+ color[i],
+ fill[i],
+ alpha[i],
+ linewidth[i],
+ )
+ circ = Circle(
+ (cent[1], cent[0]), r, color=col, fill=f, alpha=a, linewidth=lw, **kwargs
+ )
+ ax.add_patch(circ)
+
+ return
+
+
+def add_annuli(ax, d):
+ """
+ Adds one or more annuli to Axis ax using the parameters in dictionary d.
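+
+    The dictionary d has required and optional parameters as follows:
+        center (req'd) a 2-tuple (x,y), or a list of 2-tuples
+        radii (req'd) a 2-tuple (ri,ro) of inner and outer radii, or a list of 2-tuples
+        color (color)
+        fill (bool)
+        alpha (number)
+        linewidth (number)
+
+    Example::
+
+        >>> add_annuli(ax, {'center':(28,68),'radii':(16,24),'color':'r'})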
+ """
+
+ # Check that all required inputs are present
+ assert isinstance(ax, Axes)
+ assert "center" in d.keys()
+ assert "radii" in d.keys()
+
+ # Get user-provided center and radii
+ center = d["center"]
+ radii = d["radii"]
+
+ # Determine number of annuli being plotted
+ if isinstance(center, list):
+ N = len(center)
+ elif isinstance(radii, list):
+ N = len(radii)
+ else:
+ N = 1
+
+ # center
+ if isinstance(center, tuple):
+ assert len(center) == 2
+ center = [center] * N
+ # assert(isinstance(center,list))
+ assert all([isinstance(x, tuple) for x in center])
+ assert all([len(x) == 2 for x in center])
+ # radii
+ if isinstance(radii, tuple):
+ assert len(radii) == 2
+ ri = [radii[0] for i in range(N)]
+ ro = [radii[1] for i in range(N)]
+ else:
+ assert isinstance(radii, list)
+ assert all([isinstance(x, tuple) for x in radii])
+ assert len(radii) == N
+ ri = [radii[i][0] for i in range(N)]
+ ro = [radii[i][1] for i in range(N)]
+ assert all([isinstance(i, Number) for i in ri])
+ assert all([isinstance(i, Number) for i in ro])
+ # color
+ color = d["color"] if "color" in d.keys() else "r"
+ if isinstance(color, list):
+ assert len(color) == N
+ assert all([is_color_like(c) for c in color])
+ else:
+ assert is_color_like(color)
+ color = [color for i in range(N)]
+ # fill
+ fill = d["fill"] if "fill" in d.keys() else True
+ if isinstance(fill, bool):
+ fill = [fill for i in range(N)]
+ else:
+ assert isinstance(fill, list)
+ assert len(fill) == N
+ assert all([isinstance(f, bool) for f in fill])
+ # alpha
+ alpha = d["alpha"] if "alpha" in d.keys() else 1
+ if isinstance(alpha, (float, int, np.float64)):
+ alpha = [alpha for i in range(N)]
+ else:
+ assert isinstance(alpha, list)
+ assert len(alpha) == N
+ assert all([isinstance(a, (float, int, np.float64)) for a in alpha])
+ # linewidth
+ linewidth = d["linewidth"] if "linewidth" in d.keys() else 2
+ if isinstance(linewidth, (float, int, np.float64)):
+ linewidth = [linewidth for i in range(N)]
+ else:
+ assert isinstance(linewidth, list)
+ assert len(linewidth) == N
+ assert all([isinstance(lw, (float, int, np.float64)) for lw in linewidth])
+ # additional parameters
+ kws = [
+ k
+ for k in d.keys()
+ if k not in ("center", "radii", "color", "fill", "alpha", "linewidth")
+ ]
+ kwargs = dict()
+ for k in kws:
+ kwargs[k] = d[k]
+
+ # add the annuli
+ for i in range(N):
+ cent, Ri, Ro, col, f, a, lw = (
+ center[i],
+ ri[i],
+ ro[i],
+ color[i],
+ fill[i],
+ alpha[i],
+ linewidth[i],
+ )
+ annulus = Wedge(
+ (cent[1], cent[0]),
+ Ro,
+ 0,
+ 360,
+ width=Ro - Ri,
+ color=col,
+ fill=f,
+ alpha=a,
+ linewidth=lw,
+ **kwargs,
+ )
+ ax.add_patch(annulus)
+
+ return
+
+
+def add_ellipses(ax, d):
+ """
+ Adds one or more ellipses to axis ax using the parameters in dictionary d.
+
+ Parameters:
+ center
+ a
+ b
+ theta
+ color
+ fill
+ alpha
+ linewidth
+ linestyle
+ """
+ # handle inputs
+ assert isinstance(ax, Axes)
+ # semimajor axis length
+ assert "a" in d.keys()
+ a = d["a"]
+ if isinstance(a, Number):
+ a = [a]
+ assert isinstance(a, list)
+ N = len(a)
+ assert all([isinstance(i, Number) for i in a])
+ # semiminor axis length
+ assert "b" in d.keys()
+ b = d["b"]
+ if isinstance(b, Number):
+ b = [b]
+ assert isinstance(b, list)
+ assert len(b) == N
+ assert all([isinstance(i, Number) for i in b])
+ # center
+ assert "center" in d.keys()
+ center = d["center"]
+ if isinstance(center, tuple):
+ assert len(center) == 2
+ center = [center for i in range(N)]
+ assert isinstance(center, list)
+ assert len(center) == N
+ assert all([isinstance(x, tuple) for x in center])
+ assert all([len(x) == 2 for x in center])
+ # theta
+ assert "theta" in d.keys()
+ theta = d["theta"]
+ if isinstance(theta, Number):
+ theta = [theta for i in range(N)]
+ assert isinstance(theta, list)
+ assert len(theta) == N
+ assert all([isinstance(i, Number) for i in theta])
+ # color
+ color = d["color"] if "color" in d.keys() else "r"
+ if isinstance(color, list):
+ assert len(color) == N
+ assert all([is_color_like(c) for c in color])
+ else:
+ assert is_color_like(color)
+ color = [color for i in range(N)]
+ # fill
+ fill = d["fill"] if "fill" in d.keys() else False
+ if isinstance(fill, bool):
+ fill = [fill for i in range(N)]
+ else:
+ assert isinstance(fill, list)
+ assert len(fill) == N
+ assert all([isinstance(f, bool) for f in fill])
+ # alpha
+ alpha = d["alpha"] if "alpha" in d.keys() else 1
+ if isinstance(alpha, (float, int, np.float64)):
+ alpha = [alpha for i in range(N)]
+ else:
+ assert isinstance(alpha, list)
+ assert len(alpha) == N
+ assert all([isinstance(alp, (float, int, np.float64)) for alp in alpha])
+ # linewidth
+ linewidth = d["linewidth"] if "linewidth" in d.keys() else 2
+ if isinstance(linewidth, (float, int, np.float64)):
+ linewidth = [linewidth for i in range(N)]
+ else:
+ assert isinstance(linewidth, list)
+ assert len(linewidth) == N
+ assert all([isinstance(lw, (float, int, np.float64)) for lw in linewidth])
+ # linestyle
+ linestyle = d["linestyle"] if "linestyle" in d.keys() else "-"
+ if isinstance(linestyle, (str)):
+ linestyle = [linestyle for i in range(N)]
+ else:
+ assert isinstance(linestyle, list)
+ assert len(linestyle) == N
+ assert all([isinstance(lw, (str)) for lw in linestyle])
+ # additional parameters
+ kws = [
+ k
+ for k in d.keys()
+ if k
+ not in (
+ "center",
+ "a",
+ "b",
+ "theta",
+ "color",
+ "fill",
+ "alpha",
+ "linewidth",
+ "linestyle",
+ )
+ ]
+ kwargs = dict()
+ for k in kws:
+ kwargs[k] = d[k]
+
+ # add the ellipses
+ for i in range(N):
+ cent, _a, _b, _theta, col, f, _alpha, lw, ls = (
+ center[i],
+ a[i],
+ b[i],
+ theta[i],
+ color[i],
+ fill[i],
+ alpha[i],
+ linewidth[i],
+        'none', 'full', 'power', or 'log'. If 'power' is specified, the parameter ``power`` must
+ )
+ ellipse = Ellipse(
+ (cent[1], cent[0]),
+ 2 * _b,
+ 2 * _a,
+ angle=-np.degrees(_theta),
+ color=col,
+ fill=f,
+ alpha=_alpha,
+ linewidth=lw,
+ linestyle=ls,
+ **kwargs,
+ )
+ ax.add_patch(ellipse)
+
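+
+        Example::
+
+            >>> show(ar,hist=True,n_bins=100)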
+ return
+
+
+def add_points(ax, d):
+ """
+ adds one or more points to axis ax using the parameters in dictionary d.
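+
+    The dictionary d has required and optional parameters as follows:
+        x,y (req'd) the point positions
+        s the relative point sizes; plotted sizes are s*scale/max(s)
+        scale (number) the maximum plotted point size
+        pointcolor (color)
+        alpha (number)
+        open_circles (bool) if True, draws open rather than filled circles
+
+    Example::
+
+        >>> add_points(ax, {'x':[10,20],'y':[30,40],'pointcolor':'r'})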
+ """
+        ('ordered','minmax','absolute','std','centered'). See the argument
+ assert isinstance(ax, Axes)
+ # x
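+
+        Examples::
+
+            >>> show(ar,intensity_range='minmax')
+            >>> show(ar,intensity_range='absolute',vmin=0,vmax=10)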
+ assert "x" in d.keys()
+ x = d["x"]
+ if isinstance(x, Number):
+ x = [x]
+ x = np.array(x)
+ N = len(x)
+ # y
+ assert "y" in d.keys()
+ y = d["y"]
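+
+        Example::
+
+            >>> mask = np.zeros_like(ar,dtype=bool)
+            >>> mask[10:20,10:20] = True
+            >>> show(ar,mask=mask,mask_alpha=0.3)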
+ if isinstance(y, Number):
+ y = [y]
+ y = np.array(y)
+ assert len(y) == N
+ # s
+ s = d["s"] if "s" in d.keys() else np.ones(N)
+ if isinstance(s, Number):
+ s = np.ones_like(x) * s
+ assert len(s) == N
+ s = np.where(s > 0, s, 0)
+ # scale
+ scale = d["scale"] if "scale" in d.keys() else 25
+ assert isinstance(scale, Number)
+ # point color
+ color = d["pointcolor"] if "pointcolor" in d.keys() else "r"
+ if isinstance(color, (list, np.ndarray)):
+ assert len(color) == N
+ assert all([is_color_like(c) for c in color])
+ else:
+ assert is_color_like(color)
+ color = [color for i in range(N)]
+ # alpha
+ alpha = d["alpha"] if "alpha" in d.keys() else 1.0
+ assert isinstance(alpha, Number)
+ # open_circles
+ open_circles = d["open_circles"] if "open_circles" in d.keys() else False
+ assert isinstance(open_circles, bool)
+ # additional parameters
+ kws = [
+ k
+ for k in d.keys()
+ if k not in ("x", "y", "s", "scale", "pointcolor", "alpha", "open_circles")
+        is to pass a Calibration object containing this info to ``show`` as the
+ kwargs = dict()
+ for k in kws:
+ kwargs[k] = d[k]
+
+ # add the points
+ if open_circles:
+ ax.scatter(
+ y, x, s=scale, edgecolor=color, facecolor="none", alpha=alpha, **kwargs
+ )
+ else:
+ ax.scatter(y, x, s=s * scale / np.max(s), color=color, alpha=alpha, **kwargs)
+
+ return
+            'position':'ul','label':True})
+
+def add_pointlabels(ax, d):
+ """
+ adds number indices for a set of points to axis ax using the parameters in dictionary d.
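+
+    The dictionary d has required and optional parameters as follows:
+        x,y (req'd) the point positions
+        size (number) the label fontsize
+        color (color)
+        alpha (number)
+        labels the label strings; defaults to the point indices '0','1','2',...
+
+    Example::
+
+        >>> add_pointlabels(ax, {'x':[10,20],'y':[30,40],'labels':['a','b']})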
+ """
+ # handle inputs
+ assert isinstance(ax, Axes)
+ # x
+ assert "x" in d.keys()
+ x = d["x"]
+ if isinstance(x, Number):
+ x = [x]
+ x = np.array(x)
+ N = len(x)
+ # y
+ assert "y" in d.keys()
+ y = d["y"]
+ if isinstance(y, Number):
+ y = [y]
+ y = np.array(y)
+ assert len(y) == N
+ # size
+ size = d["size"] if "size" in d.keys() else 20
+ assert isinstance(size, Number)
+ # color
+ color = d["color"] if "color" in d.keys() else "r"
+ assert is_color_like(color)
+ # alpha
+ alpha = d["alpha"] if "alpha" in d.keys() else 1.0
+ assert isinstance(alpha, Number)
+ # labels
+ labels = d["labels"] if "labels" in d.keys() else np.arange(N).astype(str)
+ assert len(labels) == N
+ # additional parameters
+ kws = [
+ k for k in d.keys() if k not in ("x", "y", "size", "color", "alpha", "labels")
+ ]
+ kwargs = dict()
+ for k in kws:
+ kwargs[k] = d[k]
+
+ # add the point labels
+ for i in range(N):
+ ax.text(y[i], x[i], s=labels[i], color=color, size=size, alpha=alpha, **kwargs)
+
+ return
+
+
+def add_bragg_index_labels(ax, d):
+ """
+ Adds labels for indexed bragg directions to a plot, using the parameters in dict d.
+
+ The dictionary d has required and optional parameters as follows:
+ bragg_directions (req'd) (PointList) the Bragg directions. This PointList must have
+ the fields 'qx','qy','h', and 'k', and may optionally have 'l'
+ voffset (number) vertical offset for the labels
+ hoffset (number) horizontal offset for the labels
+ color (color)
+ size (number)
+ points (bool)
+ pointsize (number)
+ pointcolor (color)
+ """
+ # handle inputs
+ assert isinstance(ax, Axes)
+ # bragg directions
+ assert "bragg_directions" in d.keys()
+ bragg_directions = d["bragg_directions"]
+ assert isinstance(bragg_directions, PointList)
+ for k in ("qx", "qy", "h", "k"):
+ assert k in bragg_directions.data.dtype.fields
+              will set the minimum display value to saturate the lower 2% of pixels
+            * 'minmax': The vmin/vmax values are np.min(ar)/np.max(ar)
+ hoffset = d["hoffset"] if "hoffset" in d.keys() else 0
+ voffset = d["voffset"] if "voffset" in d.keys() else 5
+            * 'std': The vmin/vmax values are ``np.median(ar) -/+ N*np.std(ar)``, where
+              N is set by this function's vmin,vmax arguments
+            * 'centered': The vmin/vmax values are set to ``c -/+ m``, where by default
+              'c' is the array mean and 'm' is max(abs(ar-c)); the two values can be
+              user-specified using the kwargs vmin/vmax -> c/m
+ # points
+ points = d["points"] if "points" in d.keys() else True
+        min,max: aliases for vmin,vmax; using them throws a deprecation warning
+ pointcolor = d["pointcolor"] if "pointcolor" in d.keys() else "r"
+ assert isinstance(points, bool)
+ assert isinstance(pointsize, Number)
+ assert is_color_like(pointcolor)
+
+        borderwidth (number): the width of the border, if one is drawn
+ if points:
+ ax.scatter(
+ bragg_directions.data["qy"],
+ bragg_directions.data["qx"],
+ color=pointcolor,
+ s=pointsize,
+ )
+
+ # add index labels
+ for i in range(bragg_directions.length):
+ x, y = bragg_directions.data["qx"][i], bragg_directions.data["qy"][i]
+ x -= voffset
+ y += hoffset
+ h, k = bragg_directions.data["h"][i], bragg_directions.data["k"][i]
+ h = str(h) if h >= 0 else r"$\overline{{{}}}$".format(np.abs(h))
+ k = str(k) if k >= 0 else r"$\overline{{{}}}$".format(np.abs(k))
+            arrays in different regions of an image by invoking the ``figax`` kwarg over
+ if include_l:
+ l = bragg_directions.data["l"][i]
+ l = str(l) if l >= 0 else r"$\overline{{{}}}$".format(np.abs(l))
+            s += "," + l
+ ax.text(y, x, s, color=color, size=size, ha="center", va="bottom")
+
+ return
+
+
+def add_vector(ax, d):
+ """
+            which will attempt to use it to overlay a scalebar. If True, uses calibration or pixelsize/pixelunits
+
+ The dictionary d has required and optional parameters as follows:
+ x0,y0 (req'd) the tail position
+ vx,vy (req'd) the vector
+ color (color)
+ width (number)
+ label (str)
+ labelsize (number)
+ labelcolor (color)
+ """
+ # handle inputs
+ assert isinstance(ax, Axes)
+    # handle deprecated argument aliases
+ assert "x0" in d.keys()
+ assert "y0" in d.keys()
+ assert "vx" in d.keys()
+ assert "vy" in d.keys()
+ x0, y0, vx, vy = d["x0"], d["y0"], d["vx"], d["vy"]
+ # width
+ width = d["width"] if "width" in d.keys() else 1
+ # color
+ color = d["color"] if "color" in d.keys() else "r"
+ assert is_color_like(color)
+ # label
+ label = d["label"] if "label" in d.keys() else False
+ labelsize = d["labelsize"] if "labelsize" in d.keys() else 20
+ labelcolor = d["labelcolor"] if "labelcolor" in d.keys() else "w"
+ assert isinstance(label, (str, bool))
+ assert isinstance(labelsize, Number)
+ assert is_color_like(labelcolor)
+ # additional parameters
+ kws = [
+ k
+ for k in d.keys()
+ if k
+ not in (
+ "x0",
+ "y0",
+ "vx",
+ "vy",
+ "width",
+ "color",
+ "label",
+ "labelsize",
+ "labelcolor",
+ )
+ ]
+ kwargs = dict()
+ for k in kws:
+ kwargs[k] = d[k]
+
+ # Add the vector
+ ax.arrow(
+ y0, x0, vy, vx, color=color, width=width, length_includes_head=True, **kwargs
+ )
+
+ # Add label
+ if label:
+ x, y = x0 + 0.5 * vx, y0 + 0.5 * vy
+ ax.text(y, x, label, size=labelsize, color=labelcolor, ha="center", va="center")
+
+ return
+
+
+def add_grid_overlay(ax, d):
+ """
+ adds an overlaid grid over some subset of pixels in an image
+ using the parameters in dictionary d.
+
+ The dictionary d has required and optional parameters as follows:
+ x0,y0 (req'd) (ints) the corner of the grid
+        xL,yL (req'd) (ints) the extent of the grid
+ color (color)
+ linewidth (number)
+ alpha (number)
+ """
+ # handle inputs
+ assert isinstance(ax, Axes)
+ # corner, extent
+ lims = [0, 0, 0, 0]
+ for i, k in enumerate(("x0", "y0", "xL", "yL")):
+ assert k in d.keys(), "Error: add_grid_overlay expects keys 'x0','y0','xL','yL'"
+ lims[i] = d[k]
+ x0, y0, xL, yL = lims
+ # color
+ color = d["color"] if "color" in d.keys() else "k"
+ assert is_color_like(color)
+ # alpha
+ alpha = d["alpha"] if "alpha" in d.keys() else 1
+ assert isinstance(alpha, (Number))
+ # linewidth
+ linewidth = d["linewidth"] if "linewidth" in d.keys() else 1
+ assert isinstance(linewidth, (Number))
+ # additional parameters
+ kws = [
+ k
+ for k in d.keys()
+ if k not in ("x0", "y0", "xL", "yL", "color", "alpha", "linewidth")
+ ]
+ kwargs = dict()
+ for k in kws:
+ kwargs[k] = d[k]
+
+ # add the grid
+ yy, xx = np.meshgrid(np.arange(y0, y0 + yL), np.arange(x0, x0 + xL))
+ for xi in range(xL):
+ for yi in range(yL):
+ x, y = xx[xi, yi], yy[xi, yi]
+ rect = Rectangle(
+ (y - 0.5, x - 0.5),
+ 1,
+ 1,
+ lw=linewidth,
+ color=color,
+ alpha=alpha,
+ fill=False,
+ )
+ ax.add_patch(rect)
+ return
+
+
+def add_scalebar(ax, d):
+ """
+ Adds an overlaid scalebar to an image, using the parameters in dict d.
+
+ The dictionary d has required and optional parameters as follows:
+ Nx,Ny (req'd) the image extent
+ space (str) 'Q' or 'R'
+ length (number) the scalebar length
+ width (number) the scalebar width
+ pixelsize (number)
+ pixelunits (str)
+ color (color)
+ label (bool)
+ labelsize (number)
+ labelcolor (color)
+ alpha (number)
+        position (str) 'ul','ur','bl', or 'br' for the
+                       upper left, upper right, bottom left, or bottom right
+ ticks (bool) if False, turns off image border ticks
+ """
+ # handle inputs
+ assert isinstance(ax, Axes)
+ # image extent
+ assert "Nx" in d.keys()
+ assert "Ny" in d.keys()
+ Nx, Ny = d["Nx"], d["Ny"]
+ # real or diffraction
+ space = d["space"] if "space" in d.keys() else "Q"
+ assert space in ("Q", "R")
+ # length,width
+ length = d["length"] if "length" in d.keys() else None
+ width = d["width"] if "width" in d.keys() else 6
+ # pixelsize, pixelunits
+ pixelsize = d["pixelsize"] if "pixelsize" in d.keys() else 1
+ pixelunits = d["pixelunits"] if "pixelunits" in d.keys() else "pixels"
+ # color
+ color = d["color"] if "color" in d.keys() else "w"
+ assert is_color_like(color)
+ # labels
+ label = d["label"] if "label" in d.keys() else True
+ labelsize = d["labelsize"] if "labelsize" in d.keys() else 16
+ labelcolor = d["labelcolor"] if "labelcolor" in d.keys() else color
+ assert isinstance(label, bool)
+ assert isinstance(labelsize, Number)
+ assert is_color_like(labelcolor)
+ # alpha
+ alpha = d["alpha"] if "alpha" in d.keys() else 1
+ assert isinstance(alpha, (Number))
+ # position
+ position = d["position"] if "position" in d.keys() else "br"
+ assert position in ("ul", "ur", "bl", "br")
+ # ticks
+ ticks = d["ticks"] if "ticks" in d.keys() else False
+ assert isinstance(ticks, bool)
+ # additional parameters
+ kws = [
+ k
+ for k in d.keys()
+ if k
+ not in (
+ "Nx",
+ "Ny",
+ "length",
+ "width",
+ "pixelsize",
+ "pixelunits",
+ "color",
+ "label",
+ "labelsize",
+ "labelcolor",
+ "alpha",
+ "position",
+ "ticks",
+ )
+ ]
+ kwargs = dict()
+ for k in kws:
+ kwargs[k] = d[k]
+
+ # Get length
+ if length is None:
+ length_units, length_pixels, _ = get_nice_spacing(Nx, Ny, pixelsize)
+ else:
+ length_units, length_pixels = length, length / pixelsize
+
+ # Get position
+ if position == "ul":
+ x0, y0 = 0, 0
+ xshiftdir, yshiftdir = 1, 1
+ elif position == "ur":
+ x0, y0 = 0, Ny - 1
+ xshiftdir, yshiftdir = 1, -1
+ elif position == "bl":
+ x0, y0 = Nx - 1, 0
+ xshiftdir, yshiftdir = -1, 1
+ else:
+ x0, y0 = Nx - 1, Ny - 1
+ xshiftdir, yshiftdir = -1, -1
+ pad = 0.2 * length_pixels
+ xshift = xshiftdir * pad
+ yshift = yshiftdir * (length_pixels / 2.0 + pad)
+ x0 = x0 + xshift
+ y0 = y0 + yshift
+ xi, yi = x0, y0 - length_pixels / 2.0
+ xf, yf = x0, y0 + length_pixels / 2.0
+ labelpos_x = x0 + pad * xshiftdir / 2.0
+ labelpos_y = y0
+
+ # Add line
+ ax.plot(
+ (yi, yf),
+ (xi, xf),
+ color=color,
+ alpha=alpha,
+ lw=width,
+ solid_capstyle="butt",
+ )
+
+ # Add label
+ if label:
+ labeltext = f"{np.round(length_units,3)}" + " " + pixelunits
+ if xshiftdir > 0:
+ va = "top"
+ else:
+ va = "bottom"
+ ax.text(
+ labelpos_y,
+ labelpos_x,
+ labeltext,
+ size=labelsize,
+ color=labelcolor,
+ alpha=alpha,
+ ha="center",
+ va=va,
+ )
+
+ # if not ticks:
+ # ax.set_xticks([])
+ # ax.set_yticks([])
+ return
+
+
+def add_cartesian_grid(ax, d):
+ """
+ Adds an overlaid cartesian coordinate grid over an image
+ using the parameters in dictionary d.
+
+ The dictionary d has required and optional parameters as follows:
+ x0,y0 (req'd) the origin
+ Nx,Ny (req'd) the image extent
+ space (str) 'Q' or 'R'
+ spacing (number) spacing between gridlines
+ pixelsize (number)
+ pixelunits (str)
+ lw (number)
+ ls (str)
+ color (color)
+ label (bool)
+ labelsize (number)
+ labelcolor (color)
+ alpha (number)
+ """
+ # handle inputs
+ assert isinstance(ax, Axes)
+ # origin
+ assert "x0" in d.keys()
+ assert "y0" in d.keys()
+ x0, y0 = d["x0"], d["y0"]
+ # image extent
+ assert "Nx" in d.keys()
+ assert "Ny" in d.keys()
+ Nx, Ny = d["Nx"], d["Ny"]
+ assert x0 < Nx and y0 < Ny
+ # real or diffraction
+ space = d["space"] if "space" in d.keys() else "Q"
+ assert space in ("Q", "R")
+ # spacing, pixelsize, pixelunits
+ spacing = d["spacing"] if "spacing" in d.keys() else None
+ pixelsize = d["pixelsize"] if "pixelsize" in d.keys() else 1
+ pixelunits = d["pixelunits"] if "pixelunits" in d.keys() else "pixels"
+ # gridlines
+ lw = d["lw"] if "lw" in d.keys() else 1
+ ls = d["ls"] if "ls" in d.keys() else ":"
+ # color
+ color = d["color"] if "color" in d.keys() else "w"
+ assert is_color_like(color)
+ # labels
+ label = d["label"] if "label" in d.keys() else False
+ labelsize = d["labelsize"] if "labelsize" in d.keys() else 12
+ labelcolor = d["labelcolor"] if "labelcolor" in d.keys() else "k"
+ assert isinstance(label, bool)
+ assert isinstance(labelsize, Number)
+ assert is_color_like(labelcolor)
+ # alpha
+ alpha = d["alpha"] if "alpha" in d.keys() else 0.35
+ assert isinstance(alpha, (Number))
+ # additional parameters
+ kws = [
+ k
+ for k in d.keys()
+ if k
+ not in (
+ "x0",
+ "y0",
+ "spacing",
+ "lw",
+ "ls",
+ "color",
+ "label",
+ "labelsize",
+ "labelcolor",
+ "alpha",
+ )
+ ]
+ kwargs = dict()
+ for k in kws:
+ kwargs[k] = d[k]
+
+ # Get the major grid-square size
+ if spacing is None:
+ gridspacing, _, _gridspacing = get_nice_spacing(Nx, Ny, pixelsize)
+ else:
+ gridspacing, _gridspacing = spacing, 1.5
+
+ # Get positions for the major gridlines
+ xmin = (-x0) * pixelsize
+ xmax = (Nx - 1 - x0) * pixelsize
+ ymin = (-y0) * pixelsize
+ ymax = (Ny - 1 - y0) * pixelsize
+ xticksmajor = np.concatenate(
+ (
+ -1 * np.arange(0, np.abs(xmin), gridspacing)[1:][::-1],
+ np.arange(0, xmax, gridspacing),
+ )
+ )
+ yticksmajor = np.concatenate(
+ (
+ -1 * np.arange(0, np.abs(ymin), gridspacing)[1:][::-1],
+ np.arange(0, ymax, gridspacing),
+ )
+ )
+ xticklabels = xticksmajor.copy()
+ yticklabels = yticksmajor.copy()
+ xticksmajor = (xticksmajor - xmin) / pixelsize
+ yticksmajor = (yticksmajor - ymin) / pixelsize
+
+ # Get labels
+ exp_spacing = int(np.round(log(gridspacing, 10), 6))
+ if np.sign(log(gridspacing, 10)) < 0:
+ exp_spacing -= 1
+ base_spacing = gridspacing / (10**exp_spacing)
+ xticklabels = xticklabels / (10**exp_spacing)
+ yticklabels = yticklabels / (10**exp_spacing)
+ if exp_spacing == 1:
+ xticklabels *= 10
+ yticklabels *= 10
+ if _gridspacing in (0.4, 0.75, 1.5, 2.5) and exp_spacing != 1:
+ xticklabels = ["{:.1f}".format(n) for n in xticklabels]
+ yticklabels = ["{:.1f}".format(n) for n in yticklabels]
+ else:
+ xticklabels = ["{:.0f}".format(n) for n in xticklabels]
+ yticklabels = ["{:.0f}".format(n) for n in yticklabels]
+
+ # Add the grid
+ ax.set_xticks(yticksmajor)
+ ax.set_yticks(xticksmajor)
+ ax.xaxis.set_ticks_position("bottom")
+ if label:
+ ax.set_xticklabels(yticklabels, size=labelsize, color=labelcolor)
+ ax.set_yticklabels(xticklabels, size=labelsize, color=labelcolor)
+ axislabel_x = r"$q_x$" if space == "Q" else r"$x$"
+ axislabel_y = r"$q_y$" if space == "Q" else r"$y$"
+ if exp_spacing in (0, 1):
+ ax.set_xlabel(
+ axislabel_y + " (" + pixelunits + ")", size=labelsize, color=labelcolor
+ )
+ ax.set_ylabel(
+ axislabel_x + " (" + pixelunits + ")", size=labelsize, color=labelcolor
+ )
+ else:
+ ax.set_xlabel(
+ axislabel_y + " (" + pixelunits + " e" + str(exp_spacing) + ")",
+ size=labelsize,
+ color=labelcolor,
+ )
+ ax.set_ylabel(
+ axislabel_x + " (" + pixelunits + " e" + str(exp_spacing) + ")",
+ size=labelsize,
+ color=labelcolor,
+ )
+ else:
+ ax.set_xticklabels([])
+ ax.set_yticklabels([])
+ ax.grid(linestyle=ls, linewidth=lw, color=color, alpha=alpha)
+
+ return
+
+
+def add_polarelliptical_grid(ax, d):
+ """
+    Adds an overlaid polar-elliptical coordinate grid over an image
+ using the parameters in dictionary d.
+
+ The dictionary d has required and optional parameters as follows:
+ x0,y0 (req'd) the origin
+ e,theta (req'd) the ellipticity (a/b) and major axis angle (radians)
+ Nx,Ny (req'd) the image extent
+ space (str) 'Q' or 'R'
+ spacing (number) spacing between radial gridlines
+ N_thetalines (int) the number of theta gridlines
+ pixelsize (number)
+ pixelunits (str)
+ lw (number)
+ ls (str)
+ color (color)
+ label (bool)
+ labelsize (number)
+ labelcolor (color)
+ alpha (number)
+ """
+ # handle inputs
+ assert isinstance(ax, Axes)
+ # origin
+ assert "x0" in d.keys()
+ assert "y0" in d.keys()
+ x0, y0 = d["x0"], d["y0"]
+ # ellipticity
+ assert "e" in d.keys()
+ assert "theta" in d.keys()
+ e, theta = d["e"], d["theta"]
+ # image extent
+ assert "Nx" in d.keys()
+ assert "Ny" in d.keys()
+ Nx, Ny = d["Nx"], d["Ny"]
+ assert x0 < Nx and y0 < Ny
+ # real or diffraction
+ space = d["space"] if "space" in d.keys() else "Q"
+ assert space in ("Q", "R")
+ # spacing, N_thetalines, pixelsize, pixelunits
+ spacing = d["spacing"] if "spacing" in d.keys() else None
+ N_thetalines = d["N_thetalines"] if "N_thetalines" in d.keys() else 8
+ assert N_thetalines % 2 == 0, "N_thetalines must be even"
+ N_thetalines = N_thetalines // 2
+ assert isinstance(N_thetalines, (int, np.integer))
+ pixelsize = d["pixelsize"] if "pixelsize" in d.keys() else 1
+ pixelunits = d["pixelunits"] if "pixelunits" in d.keys() else "pixels"
+ # gridlines
+ lw = d["lw"] if "lw" in d.keys() else 1
+ ls = d["ls"] if "ls" in d.keys() else ":"
+ # color
+ color = d["color"] if "color" in d.keys() else "w"
+ assert is_color_like(color)
+ # labels
+ label = d["label"] if "label" in d.keys() else False
+ labelsize = d["labelsize"] if "labelsize" in d.keys() else 8
+ labelcolor = d["labelcolor"] if "labelcolor" in d.keys() else color
+ assert isinstance(label, bool)
+ assert isinstance(labelsize, Number)
+ assert is_color_like(labelcolor)
+ # alpha
+ alpha = d["alpha"] if "alpha" in d.keys() else 0.5
+ assert isinstance(alpha, (Number))
+ # additional parameters
+ kws = [
+ k
+ for k in d.keys()
+ if k
+ not in (
+ "x0",
+ "y0",
+ "spacing",
+ "lw",
+ "ls",
+ "color",
+ "label",
+ "labelsize",
+ "labelcolor",
+ "alpha",
+ )
+ ]
+ kwargs = dict()
+ for k in kws:
+ kwargs[k] = d[k]
+
+ # Get the radial spacing
+ if spacing is None:
+ spacing, _, _spacing = get_nice_spacing(Nx, Ny, pixelsize)
+ spacing = spacing / 2.0
+ else:
+ _spacing = 1.5
+
+ # Get positions for the radial gridlines
+ xmin = (-x0) * pixelsize
+ xmax = (Nx - 1 - x0) * pixelsize
+ ymin = (-y0) * pixelsize
+ ymax = (Ny - 1 - y0) * pixelsize
+ rcorners = (
+ np.hypot(xmin, ymin),
+ np.hypot(xmin, ymax),
+ np.hypot(xmax, ymin),
+ np.hypot(xmax, ymax),
+ )
+ rticks = np.arange(0, np.max(rcorners), spacing)[1:]
+ rticklabels = rticks.copy()
+ rticks = rticks / pixelsize
+
+ # Add radial gridlines
+ N = len(rticks)
+ d_ellipses = {
+ "a": list(rticks),
+ "center": (x0, y0),
+ "e": e,
+ "theta": theta,
+ "fill": False,
+ "color": color,
+ "linewidth": lw,
+ "linestyle": ls,
+ "alpha": alpha,
+ }
+ add_ellipses(ax, d_ellipses)
+
+ # Add radial gridline labels
+ if label:
+ # Get gridline label positions
+ rlabelpos_scale = 1 + (e - 1) * np.sin(np.pi / 2 - theta) ** 4
+ rlabelpositions = x0 - rticks * rlabelpos_scale
+ for i in range(len(rticklabels)):
+ xpos = rlabelpositions[i]
+ if xpos > labelsize / 2:
+ ax.text(
+ y0,
+ rlabelpositions[i],
+ rticklabels[i],
+ size=labelsize,
+ color=labelcolor,
+ alpha=alpha,
+ ha="center",
+ va="center",
+ )
+
+ # Add theta gridlines
+ def add_line(ax, x0, y0, theta, Nx, Ny):
+ """adds a line through (x0,y0) at an angle theta which terminates at the image edges
+        returns the termination points (xi,yi),(xf,yf)
+ """
+ theta = np.mod(np.pi / 2 - theta, np.pi)
+ if theta == 0:
+ xs, ys = [0, Nx - 1], [y0, y0]
+ elif theta == np.pi / 2:
+ xs, ys = [x0, x0], [0, Ny - 1]
+ else:
+ # Get line params
+ m = np.tan(theta)
+ b = y0 - m * x0
+ # Get intersections with x=0,x=Nx-1,y=0,y=Ny-1
+ x1, y1 = 0, b
+ x2, y2 = Nx - 1, m * (Nx - 1) + b
+ x3, y3 = -b / m, 0
+ x4, y4 = (Ny - 1 - b) / m, Ny - 1
+ # Determine which points are on the image bounding box
+ xs, ys = [], []
+ if 0 <= y1 < Ny - 1:
+ xs.append(x1), ys.append(y1)
+ if 0 <= y2 < Ny - 1:
+ xs.append(x2), ys.append(y2)
+ if 0 <= x3 < Nx - 1:
+ xs.append(x3), ys.append(y3)
+ if 0 <= x4 < Nx - 1:
+ xs.append(x4), ys.append(y4)
+ assert len(xs) == len(ys) == 2
+
+ ax.plot(xs, ys, color=color, ls=ls, alpha=alpha, lw=lw)
+ return tuple([(xs[i], ys[i]) for i in range(2)])
+
+ thetalabelpos = []
+ for t in theta + np.linspace(0, np.pi, N_thetalines, endpoint=False):
+ thetalabelpos.append(add_line(ax, x0, y0, t, Nx, Ny))
+ thetalabelpos = [thetalabelpos[i][0] for i in range(len(thetalabelpos))] + [
+ thetalabelpos[i][1] for i in range(len(thetalabelpos))
+ ]
+ # Get angles for the theta gridlines
+ thetaticklabels = [
+ str(Fraction(i, N_thetalines)) + r"$\pi$" for i in range(2 * N_thetalines)
+ ]
+ thetaticklabels[0] = "0"
+ thetaticklabels[N_thetalines] = r"$\pi$"
+
+ # Add theta gridline labels
+ if label:
+ for i in range(len(thetaticklabels)):
+ x, y = thetalabelpos[i]
+ if x == 0:
+ ha, va = "left", "center"
+ elif x == Nx - 1:
+ ha, va = "right", "center"
+ elif y == 0:
+ ha, va = "center", "top"
+ else:
+ ha, va = "center", "bottom"
+ ax.text(
+ x,
+ y,
+ thetaticklabels[i],
+ size=labelsize,
+ color=labelcolor,
+ alpha=alpha,
+ ha=ha,
+ va=va,
+ )
+
+ ax.set_xticks([])
+ ax.set_yticks([])
+ ax.set_xlim([0, Nx - 1])
+ ax.set_ylim([0, Ny - 1])
+ ax.invert_yaxis()
+ return
+
+
+def add_rtheta_grid(ar, d):
+    # not yet implemented
+    return
+
+
+def get_nice_spacing(Nx, Ny, pixelsize):
+ """Get a nice distance for gridlines, scalebars, etc
+
+ Args:
+        Nx,Ny (int): the image dimensions
+ pixelsize (float): the size of each pixel, in some units
+
+ Returns:
+ (3-tuple): A 3-tuple containing:
+
+ * **spacing_units**: the spacing in real units
+            * **spacing_pixels**: the spacing in pixels
+ * **spacing**: the leading digits of the spacing
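+
+    Example:
+        For a 512x512 image with pixelsize 0.02, D = 512*0.02/2 = 5.12, so
+        base = 5.12 and _spacing = 2, giving a spacing of 2 units, or 100 pixels.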
+ """
+ D = np.mean((Nx * pixelsize, Ny * pixelsize)) / 2.0
+ exp = int(log(D, 10))
+ if np.sign(log(D, 10)) < 0:
+ exp -= 1
+ base = D / (10**exp)
+ if base >= 1 and base < 2.1:
+ _spacing = 0.5
+ elif base >= 2.1 and base < 4.6:
+ _spacing = 1
+ elif base >= 4.6 and base <= 10:
+ _spacing = 2
+ # if base>=1 and base<1.25:
+ # _spacing=0.4
+ # elif base>=1.25 and base<1.75:
+ # _spacing=0.5
+ # elif base>=1.75 and base<2.5:
+ # _spacing=0.75
+ # elif base>=2.5 and base<3.25:
+ # _spacing=1
+ # elif base>=3.25 and base<4.75:
+ # _spacing=1.5
+ # elif base>=4.75 and base<6:
+ # _spacing=2
+ # elif base>=6 and base<8:
+ # _spacing=2.5
+ # elif base>=8 and base<10:
+ # _spacing=3
+ else:
+ raise Exception("how did this happen?? base={}".format(base))
+ spacing = _spacing * 10**exp
+ return spacing, spacing / pixelsize, _spacing
diff --git a/py4DSTEM/visualize/show.py b/py4DSTEM/visualize/show.py
new file mode 100644
index 000000000..8462eec7d
--- /dev/null
+++ b/py4DSTEM/visualize/show.py
@@ -0,0 +1,1401 @@
+import warnings
+from copy import copy
+from math import log
+from numbers import Number
+
+import matplotlib.pyplot as plt
+import numpy as np
+from matplotlib.axes import Axes
+from matplotlib.colors import is_color_like
+from matplotlib.figure import Figure
+from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
+from py4DSTEM.data import Calibration, DiffractionSlice, RealSlice
+from py4DSTEM.visualize.overlay import (
+ add_annuli,
+ add_cartesian_grid,
+ add_circles,
+ add_ellipses,
+ add_grid_overlay,
+ add_points,
+ add_polarelliptical_grid,
+ add_rectangles,
+ add_rtheta_grid,
+ add_scalebar,
+)
+
+
+def show(
+ ar,
+ figsize=(5, 5),
+ cmap="gray",
+ scaling="none",
+ intensity_range="ordered",
+ clipvals=None,
+ vmin=None,
+ vmax=None,
+ min=None,
+ max=None,
+ power=None,
+ power_offset=True,
+ combine_images=False,
+ ticks=True,
+ bordercolor=None,
+ borderwidth=5,
+ show_image=True,
+ return_ar_scaled=False,
+ return_intensity_range=False,
+ returncax=False,
+ returnfig=False,
+ figax=None,
+ hist=False,
+ n_bins=256,
+ mask=None,
+ mask_color="k",
+ mask_alpha=0.0,
+ masked_intensity_range=False,
+ rectangle=None,
+ circle=None,
+ annulus=None,
+ ellipse=None,
+ points=None,
+ grid_overlay=None,
+ cartesian_grid=None,
+ polarelliptical_grid=None,
+ rtheta_grid=None,
+ scalebar=None,
+ calibration=None,
+ rx=None,
+ ry=None,
+ space="Q",
+ pixelsize=None,
+ pixelunits=None,
+ x0=None,
+ y0=None,
+ a=None,
+ e=None,
+ theta=None,
+ title=None,
+ show_fft=False,
+ show_cbar=False,
+ **kwargs,
+):
+ """
+ General visualization function for 2D arrays.
+
+ The simplest use of this function is::
+
+ >>> show(ar)
+
+ which will generate and display a matplotlib figure showing the 2D array ``ar``.
+ Additional functionality includes:
+
+ * scaling the image (log scaling, power law scaling)
+ * displaying the image histogram
+ * altering the histogram clip values
+ * masking some subset of the image
+ * setting the colormap
+ * adding geometric overlays (e.g. points, circles, rectangles, annuli)
+ * adding informational overlays (scalebars, coordinate grids, oriented axes or
+ vectors)
+ * further customization tools
+
+ These are each discussed in turn below.
+
+ Scaling:
+ Setting the parameter ``scaling`` will scale the display image. Options are
+ 'none', 'auto', 'power', or 'log'. If 'power' is specified, the parameter ``power`` must
+ also be passed. The underlying data is not altered. Values less than or equal to
+ zero are set to zero. If the image histogram is displayed using ``hist=True``,
+ the scaled image histogram is shown.
+
+ Examples::
+
+ >>> show(ar,scaling='log')
+ >>> show(ar,power=0.5)
+ >>> show(ar,scaling='power',power=0.5,hist=True)
+
+ Histogram:
+ Setting the argument ``hist=True`` will display the image histogram, instead of
+ the image. The displayed histogram will reflect any scaling requested. The number
+ of bins can be set with ``n_bins``. The upper and lower clip values, indicating
+ where the image display will be saturated, are shown with dashed lines.
+
+ Intensity range:
+ Controlling the lower and upper values at which the display image will be
+ saturated is accomplished with the ``intensity_range`` parameter, or its
+ (soon deprecated) alias ``clipvals``, in combination with ``vmin``,
+ and ``vmax``. The method by which the upper and lower clip values
+ are determined is controlled by ``intensity_range``, and must be a string in
+ ('None','ordered','minmax','absolute','std','centered'). See the argument
+ description for ``intensity_range`` for a description of the behavior for each.
+ The clip values can be returned with the ``return_intensity_range`` parameter.
+
+ Masking:
+ If a numpy masked array is passed to show, the function will automatically
+ mask the appropriate pixels. Alternatively, a boolean array of the same shape as
+ the data array may be passed to the ``mask`` argument, and these pixels will be
+ masked. Masked pixels are displayed as a single uniform color, black by default,
+ and which can be specified with the ``mask_color`` argument. Masked pixels
+ are excluded when displaying the histogram or computing clip values. The mask
+ can also be blended with the hidden data by setting the ``mask_alpha`` argument.
+
+ Overlays (geometric):
+ The function natively supports overlaying points, circles, rectangles, annuli,
+ and ellipses. Each is invoked by passing a dictionary to the appropriate input
+ variable specifying the geometry and features of the requested overlay. For
+ example:
+
+ >>> show(ar, rectangle={'lims':(10,20,10,20),'color':'r'})
+
+ will overlay a single red square, and
+
+ >>> show(ar, annulus={'center':[(28,68),(92,160)],
+ 'radii':[(16,24),(12,36)],
+ 'fill':True,
+ 'alpha':[0.9,0.3],
+ 'color':['r',(0,1,1,1)]})
+
+ will overlay two annuli with two different centers, radii, colors, and
+ transparencies. For a description of the accepted dictionary parameters
+ for each type of overlay, see the visualize functions add_*, where
+ * = ('rectangle','circle','annulus','ellipse','points'). (These docstrings
+ are under construction!)
+
+ Overlays (informational):
+ Informational overlays supported by this function include coordinate axes
+ (cartesian, polar-elliptical, or r-theta) and scalebars. These are added
+ by passing the appropriate input argument a dictionary of the desired
+ parameters, as with geometric overlays. However, there are two key differences
+ between these overlays and the geometric overlays. First, informational
+ overlays (coordinate systems and scalebars) require information about the
+ plot - e.g. the position of the origin, the pixel sizes, the pixel units,
+ any elliptical distortions, etc. The easiest way to pass this information
+ is by pass a Calibration object containing this info to ``show`` as the
+ keyword ``calibration``. Second, once the coordinate information has been
+ passed, informational overlays can autoselect their own parameters, thus simply
+ passing an empty dict to one of these parameters will add that overlay.
+
+ For example:
+
+ >>> show(dp, scalebar={}, calibration=calibration)
+
+ will display the diffraction pattern ``dp`` with a scalebar overlaid in the
+ bottom left corner given the pixel size and units described in ``calibration``,
+ and
+
+ >>> show(dp, calibration=calibration, scalebar={'length':0.5,'width':2,
+ 'position':'ul','label':True'})
+
+ will display a more customized scalebar.
+
+ When overlaying coordinate grids, it is important to note that some relevant
+ parameters, e.g. the position of the origin, may change by scan position.
+ In these cases, the parameters ``rx``,``ry`` must also be passed to ``show``,
+ to tell the ``Calibration`` object where to look for the relevant parameters.
+ For example:
+
+ >>> show(dp, cartesian_grid={}, calibration=calibration, rx=2,ry=5)
+
+ will overlay a cartesian coordinate grid on the diffraction pattern at scan
+ position (2,5). Adding
+
+ >>> show(dp, calibration=calibration, rx=2, ry=5, cartesian_grid={'label':True,
+ 'alpha':0.7,'color':'r'})
+
+ will customize the appearance of the grid further. And
+
+ >>> show(im, calibration=calibration, cartesian_grid={}, space='R')
+
+ displays a cartesian grid over a real space image. For more details, see the
+ documentation for the visualize functions add_*, where * = ('scalebar',
+ 'cartesian_grid', 'polarelliptical_grid', 'rtheta_grid'). (Under construction!)
+
+ Further customization:
+ Most parameters accepted by a matplotlib axis will be accepted by ``show``.
+ Pass a valid matplotlib colormap or a known string indicating a colormap
+ as the argument ``cmap`` to specify the colormap. Pass ``figsize`` to
+ specify the figure size. Etc.
+
+ Further customization can be accomplished by either (1) returning the figure
+ generated by show and then manipulating it using the normal matplotlib
+ functions, or (2) generating a matplotlib Figure with Axes any way you like
+ (e.g. with ``plt.subplots``) and then using this function to plot inside a
+ single one of the Axes of your choice.
+
+ Option (1) is accomplished by simply passing this function ``returnfig=True``.
+ Thus:
+
+ >>> fig,ax = show(ar, returnfig=True)
+
+ will now give you direct access to the figure and axes to continue to alter.
+ Option (2) is accomplished by passing an existing figure and axis to ``show``
+ as a 2-tuple to the ``figax`` argument. Thus:
+
+ >>> fig,(ax1,ax2) = plt.subplots(1,2)
+ >>> show(ar, figax=(fig,ax1))
+ >>> show(ar, figax=(fig,ax2), hist=True)
+
+ will generate a 2-axis figure, and then plot the array ``ar`` as an image on
+ the left, while plotting its histogram on the right.
+
+
+ Args:
+ ar (2D array or a list of 2D arrays): the data to plot. Normally this
+ is a 2D array of the data. If a list of 2D arrays is passed, plots
+ a corresponding grid of images.
+ figsize (2-tuple): size of the plot
+ cmap (colormap): any matplotlib cmap; default is gray
+ scaling (str): selects a scaling scheme for the intensity values. Default is
+ none. Accepted values:
+ * 'none': do not scale intensity values
+ * 'full': fill entire color range with sorted intensity values
+ * 'power': power law scaling
+ * 'log': values where ar<=0 are set to 0
+ intensity_range (str): method for setting clipvalues (min and max intensities).
+ The original name "clipvals" is now deprecated.
+ Default is 'ordered'. Accepted values:
+ * 'ordered': vmin/vmax are set to fractions of the
+ distribution of pixel values in the array, e.g. vmin=0.02
+ will set the minumum display value to saturate the lower 2% of pixels
+ * 'minmax': The vmin/vmax values are np.min(ar)/np.max(r)
+ * 'absolute': The vmin/vmax values are set to the values of
+ the vmin,vmax arguments received by this function
+ * 'std': The vmin/vmax values are ``np.median(ar) -/+ N*np.std(ar)``, and
+ N is this functions min,max vals.
+ * 'centered': The vmin/vmax values are set to ``c -/+ m``, where by default
+ 'c' is zero and m is the max(abs(ar-c), or the two params can be user
+ specified using the kwargs vmin/vmax -> c/m.
+ vmin (number): min intensity, behavior depends on clipvals
+ vmax (number): max intensity, behavior depends on clipvals
+ min,max: alias' for vmin,vmax, throws deprecation warning
+ power (number): specifies the scaling power
+ power_offset (bool): If true, image has min value subtracted before power scaling
+ ticks (bool): Turn outer tick marks on or off
+ bordercolor (color or None): if not None, add a border of this color.
+ The color can be anything matplotlib recognizes as a color.
+ borderwidth (number):
+ returnfig (bool): if True, the function returns the tuple (figure,axis)
+ figax (None or 2-tuple): controls which matplotlib Axes object draws the image.
+ If None, generates a new figure with a single Axes instance. Otherwise, ax
+ must be a 2-tuple containing the matplotlib class instances (Figure,Axes),
+ with ar then plotted in the specified Axes instance.
+ hist (bool): if True, instead of plotting a 2D image in ax, plots a histogram of
+ the intensity values of ar, after any scaling this function has performed.
+ Plots the clipvals as dashed vertical lines
+ n_bins (int): number of hist bins
+ mask (None or boolean array): if not None, must have the same shape as 'ar'.
+ Wherever mask==True, plot the pixel normally, and where ``mask==False``,
+ pixel values are set to mask_color. If hist==True, ignore these values in the
+ histogram. If ``mask_alpha`` is specified, the mask is blended with the array
+ underneath, with 0 yielding an opaque mask and 1 yielding a fully transparent
+ mask. If ``mask_color`` is set to ``'empty'`` instead of a matplotlib.color,
+ nothing is done to pixels where ``mask==False``, allowing overlaying multiple
+ arrays in different regions of an image by invoking the ``figax` kwarg over
+ multiple calls to show
+ mask_color (color): see 'mask'
+ mask_alpha (float): see 'mask'
+ masked_intensity_range (bool): controls if masked pixel values are included when
+ determining the display value range; False indicates that all pixel values
+ will be used to determine the intensity range, True indicates only unmasked
+ pixels will be used
+ scalebar (None or dict or Bool): if None, and a DiffractionSlice or RealSlice
+ with calibrations is passed, adds a scalebar. If scalebar is not displaying the proper
+ calibration, check .calibration pixel_size and pixel_units. If None and an array is passed,
+ does not add a scalebar. If a dict is passed, it is propagated to the add_scalebar function
+ which will attempt to use it to overlay a scalebar. If True, uses calibraiton or pixelsize/pixelunits
+ for scalebar. If False, no scalebar is added.
+ show_fft (bool): if True, plots 2D-fft of array
+ show_cbar (bool) : if True, adds cbar
+ **kwargs: any keywords accepted by matplotlib's ax.matshow()
+
+ Returns:
+ if returnfig==False (default), the figure is plotted and nothing is returned.
+ if returnfig==True, return the figure and the axis.
+ """
+ if scalebar is True:
+ scalebar = {}
+
+ # Alias dep
+ if min is not None:
+ vmin = min
+ if max is not None:
+ vmax = max
+ if min is not None or max is not None:
+ warnings.warn(
+ "Warning, min/max are deprecated and will not be supported in a future version. Use vmin/vmax instead."
+ )
+ if clipvals is not None:
+ warnings.warn(
+ "Warning, clipvals is deprecated and will not be supported in a future version. Use intensity_range instead."
+ )
+ if intensity_range is None:
+ intensity_range = clipvals
+
+ # check if list is of length 1
+ ar = ar[0] if (isinstance(ar, list) and len(ar) == 1) else ar
+
+    # plot a grid if `ar` is a list, or use multichannel functionality to make an RGB image
+ if isinstance(ar, list):
+ args = locals()
+ if "kwargs" in args.keys():
+ del args["kwargs"]
+ rm = []
+ for k in args.keys():
+ if args[k] is None:
+ rm.append(k)
+ for k in rm:
+ del args[k]
+
+ if combine_images is False:
+ # use show_grid to plot grid of images
+ from py4DSTEM.visualize.show_extention import _show_grid
+
+ if returnfig:
+ return _show_grid(**args, **kwargs)
+ else:
+ _show_grid(**args, **kwargs)
+ return
+ else:
+ # generate a multichannel combined RGB image
+
+ # init
+ num_images = len(ar)
+ hue_angles = np.linspace(0.0, 2.0 * np.pi, num_images, endpoint=False)
+ cos_total = np.zeros(ar[0].shape)
+ sin_total = np.zeros(ar[0].shape)
+ val_total = np.zeros(ar[0].shape)
+
+ # loop over images
+ from py4DSTEM.visualize import show
+
+            if show_fft:
+                # apply a Hann window and take the FFT amplitude of each image
+                n0 = ar[0].shape
+                w0 = np.hanning(n0[1]) * np.hanning(n0[0])[:, None]
+                ar = [np.abs(np.fft.fftshift(np.fft.fft2(w0 * im))) for im in ar]
+ for a0 in range(num_images):
+ im = show(
+ ar[a0],
+ scaling="none",
+ intensity_range=intensity_range,
+ clipvals=clipvals,
+ vmin=vmin,
+ vmax=vmax,
+ power=power,
+ power_offset=power_offset,
+ return_ar_scaled=True,
+ show_image=False,
+ **kwargs,
+ )
+ cos_total += np.cos(hue_angles[a0]) * im
+ sin_total += np.sin(hue_angles[a0]) * im
+ # val_max = np.maximum(val_max, im)
+ val_total += im
+
+ # Assemble final image
+ sat_change = np.maximum(val_total - 1.0, 0.0)
+ ar_hsv = np.zeros((ar[0].shape[0], ar[0].shape[1], 3))
+ ar_hsv[:, :, 0] = np.mod(
+ np.arctan2(sin_total, cos_total) / (2 * np.pi), 1.0
+ )
+ ar_hsv[:, :, 1] = 1 - sat_change
+ ar_hsv[:, :, 2] = val_total # np.sqrt(cos_total**2 + sin_total**2)
+ ar_hsv = np.clip(ar_hsv, 0.0, 1.0)
+
+ # Convert to RGB
+ from matplotlib.colors import hsv_to_rgb
+
+ ar_rgb = hsv_to_rgb(ar_hsv)
+
+ # Output image for plotting
+ ar = ar_rgb
+
+ # support for native data types
+ elif not isinstance(ar, np.ndarray):
+ # support for calibration/auto-scalebars
+ if (
+ hasattr(ar, "calibration")
+ and (ar.calibration is not None)
+ and (scalebar is not False)
+ ):
+ cal = ar.calibration
+ er = ".calibration attribute must be a Calibration instance"
+ assert isinstance(cal, Calibration), er
+ if isinstance(ar, DiffractionSlice):
+ scalebar = {
+ "Nx": ar.data.shape[0],
+ "Ny": ar.data.shape[1],
+ "pixelsize": cal.get_Q_pixel_size(),
+ "pixelunits": cal.get_Q_pixel_units(),
+ "space": "Q",
+ "position": "br",
+ }
+ pixelsize = cal.get_Q_pixel_size()
+ pixelunits = cal.get_Q_pixel_units()
+ elif isinstance(ar, RealSlice):
+ scalebar = {
+ "Nx": ar.data.shape[0],
+ "Ny": ar.data.shape[1],
+ "pixelsize": cal.get_R_pixel_size(),
+ "pixelunits": cal.get_R_pixel_units(),
+                    "space": "R",
+ "position": "br",
+ }
+ pixelsize = cal.get_R_pixel_size()
+ pixelunits = cal.get_R_pixel_units()
+ # get the data
+ if hasattr(ar, "data"):
+ if ar.data.ndim == 2:
+ ar = ar.data
+ else:
+ raise Exception('input argument "ar" has unsupported type ' + str(type(ar)))
+ # Otherwise, plot one image
+ if show_fft:
+ if combine_images is False:
+ ar = np.abs(np.fft.fftshift(np.fft.fft2(ar.copy())))
+
+ # get image from a masked array
+ if mask is not None:
+ assert mask.shape == ar.shape
+ assert is_color_like(mask_color) or mask_color == "empty"
+ if isinstance(ar, np.ma.masked_array):
+ ar = np.ma.array(data=ar.data, mask=np.logical_or(ar.mask, ~mask))
+ else:
+ ar = np.ma.array(data=ar, mask=np.logical_not(mask))
+ elif isinstance(ar, np.ma.masked_array):
+ pass
+ else:
+ mask = np.zeros_like(ar, dtype=bool)
+ ar = np.ma.array(data=ar, mask=mask)
+
+ # New intensity scaling logic
+ assert scaling in ("none", "full", "log", "power", "hist")
+ assert intensity_range in (
+ "ordered",
+ "absolute",
+ "manual",
+ "minmax",
+ "std",
+ "centered",
+ )
+ if power is not None:
+ scaling = "power"
+ if scaling == "none":
+ _ar = ar.copy()
+ _mask = np.ones_like(_ar.data, dtype=bool)
+ elif scaling == "full":
+ _ar = np.reshape(ar.ravel().argsort().argsort(), ar.shape) / (ar.size - 1)
+ _mask = np.ones_like(_ar.data, dtype=bool)
+ elif scaling == "log":
+ _mask = ar.data > 0.0
+ _ar = np.zeros_like(ar.data, dtype=float)
+ _ar[_mask] = np.log(ar.data[_mask])
+ _ar[~_mask] = np.nan
+ if np.all(np.isnan(_ar)):
+ _ar[:, :] = 0
+ if intensity_range == "absolute":
+ if vmin is not None:
+ if vmin > 0.0:
+ vmin = np.log(vmin)
+ else:
+ vmin = np.min(_ar[_mask])
+ if vmax is not None:
+ vmax = np.log(vmax)
+ elif scaling == "power":
+ if power_offset is False:
+ _mask = ar.data > 0.0
+ _ar = np.zeros_like(ar.data, dtype=float)
+ _ar[_mask] = np.power(ar.data[_mask], power)
+ _ar[~_mask] = np.nan
+ else:
+ ar_min = np.min(ar)
+ if ar_min < 0:
+ _ar = np.power(ar.copy() - np.min(ar), power)
+ else:
+ _ar = np.power(ar.copy(), power)
+ _mask = np.ones_like(_ar.data, dtype=bool)
+ if intensity_range == "absolute":
+ if vmin is not None:
+ vmin = np.power(vmin, power)
+ if vmax is not None:
+ vmax = np.power(vmax, power)
+ else:
+ raise Exception
+
+ # Create the masked array applying the user mask (this is done before the
+ # vmin and vmax are determined so the mask affects those)
+ _ar = np.ma.array(data=_ar.data, mask=np.logical_or(~_mask, ar.mask))
+
+ # set scaling for boolean arrays
+ if _ar.dtype == "bool":
+ intensity_range = "absolute"
+ vmin = 0
+ vmax = 1
+
+ # Set the clipvalues
+ if intensity_range == "manual":
+ warnings.warn(
+ "Warning - intensity_range='manual' is deprecated, use 'absolute' instead"
+ )
+ intensity_range = "absolute"
+ if intensity_range == "ordered":
+ if vmin is None:
+ vmin = 0.02
+ if vmax is None:
+ vmax = 0.98
+ if masked_intensity_range:
+ vals = np.sort(
+ _ar[np.logical_and(~np.isnan(_ar), np.logical_not(_ar.mask))]
+ )
+ else:
+ vals = np.sort(_ar.data[~np.isnan(_ar)])
+ ind_vmin = np.round((vals.shape[0] - 1) * vmin).astype("int")
+ ind_vmax = np.round((vals.shape[0] - 1) * vmax).astype("int")
+ ind_vmin = np.max([0, ind_vmin])
+ ind_vmax = np.min([len(vals) - 1, ind_vmax])
+ vmin = vals[ind_vmin]
+ vmax = vals[ind_vmax]
+ # check if vmin and vmax are the same, defaulting to minmax scaling if needed
+ if vmax == vmin:
+ vmin = vals[0]
+ vmax = vals[-1]
+ elif intensity_range == "minmax":
+ vmin, vmax = np.nanmin(_ar), np.nanmax(_ar)
+ elif intensity_range == "absolute":
+ if vmin is None:
+ vmin = np.min(_ar)
+ print(
+ "Warning, vmin not provided, setting minimum intensity = " + str(vmin)
+ )
+ if vmax is None:
+ vmax = np.max(_ar)
+ print(
+ "Warning, vmax not provided, setting maximum intensity = " + str(vmax)
+ )
+ # assert vmin is not None and vmax is not None
+ # vmin,vmax = vmin,vmax
+ elif intensity_range == "std":
+ assert vmin is not None and vmax is not None
+ m, s = np.nanmedian(_ar), np.nanstd(_ar)
+ vmin = m + vmin * s
+ vmax = m + vmax * s
+ elif intensity_range == "centered":
+ c = np.nanmean(_ar) if vmin is None else vmin
+ m = np.nanmax(np.ma.abs(c - _ar)) if vmax is None else vmax
+ vmin = c - m
+ vmax = c + m
+ else:
+ raise Exception
+
+ if show_image:
+ # Create or attach to the appropriate Figure and Axis
+ if figax is None:
+ fig, ax = plt.subplots(1, 1, figsize=figsize)
+ else:
+ fig, ax = figax
+ assert isinstance(fig, Figure)
+ assert isinstance(ax, Axes)
+
+ # Create colormap with mask_color for bad values
+ cm = copy(plt.get_cmap(cmap))
+ if mask_color == "empty":
+ cm.set_bad(alpha=0)
+ else:
+ cm.set_bad(color=mask_color)
+
+ # Plot the image
+ if not hist:
+ cax = ax.matshow(_ar, vmin=vmin, vmax=vmax, cmap=cm, **kwargs)
+ if np.any(_ar.mask):
+ mask_display = np.ma.array(data=_ar.data, mask=~_ar.mask)
+ ax.matshow(
+ mask_display, cmap=cmap, alpha=mask_alpha, vmin=vmin, vmax=vmax
+ )
+ if show_cbar:
+ ax_divider = make_axes_locatable(ax)
+ c_axis = ax_divider.append_axes("right", size="7%")
+ fig.colorbar(cax, cax=c_axis)
+ # ...or, plot its histogram
+ else:
+ hist, bin_edges = np.histogram(
+ _ar, bins=np.linspace(np.min(_ar), np.max(_ar), num=n_bins)
+ )
+ w = bin_edges[1] - bin_edges[0]
+ x = bin_edges[:-1] + w / 2.0
+ ax.bar(x, hist, width=w)
+ ax.vlines((vmin, vmax), 0, ax.get_ylim()[1], color="k", ls="--")
+
+ # add a title
+ if title is not None:
+ ax.set_title(title)
+
+ # Add a border
+ if bordercolor is not None:
+ for s in ["bottom", "top", "left", "right"]:
+ ax.spines[s].set_color(bordercolor)
+ ax.spines[s].set_linewidth(borderwidth)
+ ax.set_xticks([])
+ ax.set_yticks([])
+
+ # Add shape/point overlays
+ if rectangle is not None:
+ add_rectangles(ax, rectangle)
+ if circle is not None:
+ add_circles(ax, circle)
+ if annulus is not None:
+ add_annuli(ax, annulus)
+ if ellipse is not None:
+ add_ellipses(ax, ellipse)
+ if points is not None:
+ add_points(ax, points)
+ if grid_overlay is not None:
+ add_grid_overlay(ax, grid_overlay)
+
+ # Parse arguments for scale/coordinate overlays
+ if calibration is not None:
+ assert isinstance(calibration, Calibration)
+ assert space in ("Q", "R")
+ # pixel size/units
+ if pixelsize is None and calibration is None:
+ pixelsize = 1
+ elif pixelsize is None:
+ if space == "Q":
+ pixelsize = calibration.get_Q_pixel_size()
+ else:
+ pixelsize = calibration.get_R_pixel_size()
+ if pixelunits is None and calibration is None:
+ pixelunits = "pixels"
+ elif pixelunits is None:
+ if space == "Q":
+ pixelunits = calibration.get_Q_pixel_units()
+ else:
+ pixelunits = calibration.get_R_pixel_units()
+ # origin
+ if space == "Q":
+ if x0 is not None:
+ pass
+ elif calibration is not None:
+ try:
+ x0 = calibration.get_origin(rx, ry)[0]
+ except AttributeError:
+ raise Exception(
+ "The Calibration instance passed does not contain a value for qx0"
+ )
+ else:
+ x0 = 0
+ if y0 is not None:
+ pass
+ elif calibration is not None:
+ try:
+ y0 = calibration.get_origin(rx, ry)[1]
+ except AttributeError:
+ raise Exception(
+ "The Calibration instance passed does not contain a value for qy0"
+ )
+ else:
+ y0 = 0
+ else:
+ x0 = x0 if x0 is not None else 0
+ y0 = y0 if y0 is not None else 0
+ # ellipticity
+ if space == "Q":
+ if a is not None:
+ pass
+ elif calibration is not None:
+ try:
+ a = calibration.get_a(rx, ry)
+ except AttributeError:
+ raise Exception(
+ "The Calibration instance passed does not contain a value for a"
+ )
+ else:
+ a = 1
+ if theta is not None:
+ pass
+ elif calibration is not None:
+ try:
+ theta = calibration.get_theta(rx, ry)
+ except AttributeError:
+ raise Exception(
+ "The Calibration instance passed does not contain a value for theta"
+ )
+ else:
+ theta = 0
+ else:
+ a = a if a is not None else 1
+ theta = theta if theta is not None else 0
+
+ # Add a scalebar
+ if scalebar is not None and scalebar is not False:
+ # populate the scalebar parameters
+ scalebar["Nx"] = ar.shape[0]
+ scalebar["Ny"] = ar.shape[1]
+ scalebar["pixelsize"] = pixelsize
+ scalebar["pixelunits"] = pixelunits
+ scalebar["space"] = space
+ # determine good default scale bar fontsize
+ if "labelsize" not in scalebar.keys():
+ if figax is not None:
+ bbox = figax[1].get_window_extent()
+ dpi = figax[0].dpi
+ size = (bbox.width / dpi, bbox.height / dpi)
+ scalebar["labelsize"] = np.min(np.array(size)) * 3.0
+ else:
+ scalebar["labelsize"] = np.min(np.array(figsize)) * 2.0
+ add_scalebar(ax, scalebar)
+
+ # Add cartesian grid
+ if cartesian_grid is not None:
+ Nx, Ny = ar.shape
+ assert isinstance(
+ x0, Number
+ ), "Error: x0 must be a number. If a Coordinate system was passed, try passing a position (rx,ry)."
+ assert isinstance(
+ y0, Number
+ ), "Error: y0 must be a number. If a Coordinate system was passed, try passing a position (rx,ry)."
+ cartesian_grid["x0"], cartesian_grid["y0"] = x0, y0
+ cartesian_grid["Nx"], cartesian_grid["Ny"] = Nx, Ny
+ cartesian_grid["pixelsize"] = pixelsize
+ cartesian_grid["pixelunits"] = pixelunits
+ cartesian_grid["space"] = space
+ add_cartesian_grid(ax, cartesian_grid)
+
+ # Add polarelliptical grid
+ if polarelliptical_grid is not None:
+ Nx, Ny = ar.shape
+ assert isinstance(
+ x0, Number
+ ), "Error: x0 must be a number. If a Coordinate system was passed, try passing a position (rx,ry)."
+ assert isinstance(
+ y0, Number
+ ), "Error: y0 must be a number. If a Coordinate system was passed, try passing a position (rx,ry)."
+ assert isinstance(
+ e, Number
+ ), "Error: e must be a number. If a Coordinate system was passed, try passing a position (rx,ry)."
+ assert isinstance(
+ theta, Number
+ ), "Error: theta must be a number. If a Coordinate system was passed, try passing a position (rx,ry)."
+ polarelliptical_grid["x0"], polarelliptical_grid["y0"] = x0, y0
+ polarelliptical_grid["e"], polarelliptical_grid["theta"] = e, theta
+ polarelliptical_grid["Nx"], polarelliptical_grid["Ny"] = Nx, Ny
+ polarelliptical_grid["pixelsize"] = pixelsize
+ polarelliptical_grid["pixelunits"] = pixelunits
+ polarelliptical_grid["space"] = space
+ add_polarelliptical_grid(ax, polarelliptical_grid)
+
+ # Add r-theta grid
+ if rtheta_grid is not None:
+ add_rtheta_grid(ax, rtheta_grid)
+
+ # tick marks
+ if ticks is False:
+ ax.set_xticks([])
+ ax.set_yticks([])
+
+ # Show or return
+ returnval = []
+ if returnfig:
+ returnval.append((fig, ax))
+ if return_ar_scaled:
+ ar_scaled = np.clip((ar - vmin) / (vmax - vmin), 0.0, 1.0)
+ returnval.append(ar_scaled)
+ if return_intensity_range:
+ if scaling == "log":
+ vmin, vmax = np.power(np.e, vmin), np.power(np.e, vmax)
+ elif scaling == "power":
+ vmin, vmax = np.power(vmin, 1 / power), np.power(vmax, 1 / power)
+ returnval.append((vmin, vmax))
+ if returncax:
+ returnval.append(cax)
+ if len(returnval) == 0:
+ if figax is None:
+ plt.show()
+ return
+ elif (len(returnval)) == 1:
+ return returnval[0]
+ else:
+ return tuple(returnval)
+
+
+def show_hist(
+ arr,
+ bins=200,
+ vlines=None,
+ vlinecolor="k",
+ vlinestyle="--",
+ returnhist=False,
+ returnfig=False,
+):
+ """
+ Visualization function to show histogram from any ndarray (arr).
+
+ Accepts:
+ arr (ndarray) any array
+ bins (int) number of bins that the intensity values will be sorted
+ into for histogram
+ returnhist (bool) determines whether or not the histogram values are
+ returned (see Returns)
+ returnfig (bool) determines whether or not figure and its axis are
+ returned (see Returns)
+
+ Returns:
+ If
+ returnhist==False and returnfig==False returns nothing
+ returnhist==True and returnfig==False returns (counts,bin_edges) the histogram
+ values and bin edge locations
+ returnhist==False and returnfig==True returns (fig,ax), the Figure and Axis
+ returnhist==True and returnfig==True returns (counts,bin_edges),(fig,ax)
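+
+ Example (an illustrative sketch; `im` stands for any 2D array):
+
+ >>> counts, bin_edges = show_hist(im, bins=128, vlines=(0.1, 0.9), returnhist=True)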
+ """
+ counts, bin_edges = np.histogram(arr, bins=bins, range=(np.min(arr), np.max(arr)))
+ bin_width = bin_edges[1] - bin_edges[0]
+ bin_centers = bin_edges[:-1] + bin_width / 2
+
+ fig, ax = plt.subplots(1, 1)
+ ax.bar(bin_centers, counts, width=bin_width, align="center")
+ plt.ylabel("Counts")
+ plt.xlabel("Intensity")
+ if vlines is not None:
+ ax.vlines(vlines, 0, np.max(counts), color=vlinecolor, ls=vlinestyle)
+ if not returnhist and not returnfig:
+ plt.show()
+ return
+ elif returnhist and not returnfig:
+ return counts, bin_edges
+ elif not returnhist and returnfig:
+ return fig, ax
+ else:
+ return (counts, bin_edges), (fig, ax)
+
+
+# Show functions with overlaid scalebars and/or coordinate system gridlines
+
+
+def show_Q(
+ ar,
+ scalebar=True,
+ grid=False,
+ polargrid=False,
+ Q_pixel_size=None,
+ Q_pixel_units=None,
+ calibration=None,
+ rx=None,
+ ry=None,
+ qx0=None,
+ qy0=None,
+ e=None,
+ theta=None,
+ scalebarloc=0,
+ scalebarsize=None,
+ scalebarwidth=None,
+ scalebartext=None,
+ scalebartextloc="above",
+ scalebartextsize=12,
+ gridspacing=None,
+ gridcolor="w",
+ majorgridlines=True,
+ majorgridlw=1,
+ majorgridls=":",
+ minorgridlines=True,
+ minorgridlw=0.5,
+ minorgridls=":",
+ gridlabels=False,
+ gridlabelsize=12,
+ gridlabelcolor="k",
+ alpha=0.35,
+ **kwargs,
+):
+ """
+ Shows a diffraction space image with options for several overlays to define the scale,
+ including a scalebar, a cartesian grid, or a polar / polar-elliptical grid.
+
+ Regardless of which overlay is requested, the function must receive either values
+ for Q_pixel_size and Q_pixel_units, or a Calibration instance containing these values.
+ If both are passed, the absolutely passed values take precedence.
+ If a cartesian grid is requested, (qx0,qy0) are required, either passed absolutely or
+ passed as a Calibration instance with the appropriate (rx,ry) value.
+ If a polar grid is requested, (qx0,qy0,e,theta) are required, again either absolutely
+ or via a Calibration instance.
+
+ Any arguments accepted by the show() function (e.g. image scaling, clipvalues, etc)
+ may be passed to this function as kwargs.
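+
+ Example (an illustrative sketch; `dp` stands for any 2D diffraction pattern):
+
+ >>> show_Q(dp, grid=True, Q_pixel_size=0.01, Q_pixel_units="A^-1", qx0=64, qy0=64)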
+ """
+ # Check inputs
+ assert isinstance(ar, np.ndarray) and len(ar.shape) == 2
+ if calibration is not None:
+ assert isinstance(calibration, Calibration)
+ try:
+ Q_pixel_size = (
+ Q_pixel_size if Q_pixel_size is not None else calibration.get_Q_pixel_size()
+ )
+ except AttributeError:
+ raise Exception(
+ "Q_pixel_size must be specified, either in calibration or absolutely"
+ )
+ try:
+ Q_pixel_units = (
+ Q_pixel_units
+ if Q_pixel_units is not None
+ else calibration.get_Q_pixel_units()
+ )
+ except AttributeError:
+ raise Exception(
+ "Q_pixel_size must be specified, either in calibration or absolutely"
+ )
+ if grid or polargrid:
+ try:
+ qx0 = qx0 if qx0 is not None else calibration.get_qx0(rx, ry)
+ except AttributeError:
+ raise Exception(
+ "qx0 must be specified, either in calibration or absolutely"
+ )
+ try:
+ qy0 = qy0 if qy0 is not None else calibration.get_qy0(rx, ry)
+ except AttributeError:
+ raise Exception(
+ "qy0 must be specified, either in calibration or absolutely"
+ )
+ assert isinstance(
+ qx0, Number
+ ), "Error: qx0 must be a number. If a Coordinate system was passed, try passing a position (rx,ry)."
+ assert isinstance(
+ qy0, Number
+ ), "Error: qy0 must be a number. If a Coordinate system was passed, try passing a position (rx,ry)."
+ if polargrid:
+ e = e if e is not None else calibration.get_e(rx, ry)
+ theta = theta if theta is not None else calibration.get_theta(rx, ry)
+ assert isinstance(
+ e, Number
+ ), "Error: e must be a number. If a Coordinate system was passed, try passing a position (rx,ry)."
+ assert isinstance(
+ theta, Number
+ ), "Error: theta must be a number. If a Coordinate system was passed, try passing a position (rx,ry)."
+
+ # Make the plot
+ fig, ax = show(ar, returnfig=True, **kwargs)
+
+ # Add a scalebar
+ if scalebar:
+ pass
+
+ # Add a cartesian grid
+ if grid:
+ # parse arguments
+ assert isinstance(majorgridlines, bool)
+ majorgridlw = majorgridlw if majorgridlines else 0
+ assert isinstance(majorgridlw, Number)
+ assert isinstance(majorgridls, str)
+ assert isinstance(minorgridlines, bool)
+ minorgridlw = minorgridlw if minorgridlines else 0
+ assert isinstance(minorgridlw, Number)
+ assert isinstance(minorgridls, str)
+ assert is_color_like(gridcolor)
+ assert isinstance(gridlabels, bool)
+ assert isinstance(gridlabelsize, Number)
+ assert is_color_like(gridlabelcolor)
+ if gridspacing is not None:
+ assert isinstance(gridspacing, Number)
+
+ Q_Nx, Q_Ny = ar.shape
+ assert qx0 < Q_Nx and qy0 < Q_Ny
+
+ # Get the major grid-square size
+ _gridspacing = None  # only set when gridspacing is determined automatically
+ if gridspacing is None:
+ D = np.mean((Q_Nx * Q_pixel_size, Q_Ny * Q_pixel_size)) / 2.0
+ exp = int(log(D, 10))
+ if np.sign(log(D, 10)) < 0:
+ exp -= 1
+ base = D / (10**exp)
+ if base >= 1 and base < 1.25:
+ _gridspacing = 0.4
+ elif base >= 1.25 and base < 1.75:
+ _gridspacing = 0.5
+ elif base >= 1.75 and base < 2.5:
+ _gridspacing = 0.75
+ elif base >= 2.5 and base < 3.25:
+ _gridspacing = 1
+ elif base >= 3.25 and base < 4.75:
+ _gridspacing = 1.5
+ elif base >= 4.75 and base < 6:
+ _gridspacing = 2
+ elif base >= 6 and base < 8:
+ _gridspacing = 2.5
+ elif base >= 8 and base < 10:
+ _gridspacing = 3
+ else:
+ raise Exception("how did this happen?? base={}".format(base))
+ gridspacing = _gridspacing * 10**exp
+
+ # Get the positions and label for the major gridlines
+ xmin = (-qx0) * Q_pixel_size
+ xmax = (Q_Nx - 1 - qx0) * Q_pixel_size
+ ymin = (-qy0) * Q_pixel_size
+ ymax = (Q_Ny - 1 - qy0) * Q_pixel_size
+ xticksmajor = np.concatenate(
+ (
+ -1 * np.arange(0, np.abs(xmin), gridspacing)[1:][::-1],
+ np.arange(0, xmax, gridspacing),
+ )
+ )
+ yticksmajor = np.concatenate(
+ (
+ -1 * np.arange(0, np.abs(ymin), gridspacing)[1:][::-1],
+ np.arange(0, ymax, gridspacing),
+ )
+ )
+ xticklabels = xticksmajor.copy()
+ yticklabels = yticksmajor.copy()
+ xticksmajor = (xticksmajor - xmin) / Q_pixel_size
+ yticksmajor = (yticksmajor - ymin) / Q_pixel_size
+ # Labels
+ exp_spacing = int(np.round(log(gridspacing, 10), 6))
+ if np.sign(log(gridspacing, 10)) < 0:
+ exp_spacing -= 1
+ xticklabels = xticklabels / (10**exp_spacing)
+ yticklabels = yticklabels / (10**exp_spacing)
+ if exp_spacing == 1:
+ xticklabels *= 10
+ yticklabels *= 10
+ if _gridspacing in (0.4, 0.75, 1.5, 2.5) and exp_spacing != 1:
+ xticklabels = ["{:.1f}".format(n) for n in xticklabels]
+ yticklabels = ["{:.1f}".format(n) for n in yticklabels]
+ else:
+ xticklabels = ["{:.0f}".format(n) for n in xticklabels]
+ yticklabels = ["{:.0f}".format(n) for n in yticklabels]
+
+ # Add the grid
+ ax.set_xticks(yticksmajor)
+ ax.set_yticks(xticksmajor)
+ ax.xaxis.set_ticks_position("bottom")
+ if gridlabels:
+ ax.set_xticklabels(yticklabels, size=gridlabelsize, color=gridlabelcolor)
+ ax.set_yticklabels(xticklabels, size=gridlabelsize, color=gridlabelcolor)
+ if exp_spacing in (0, 1):
+ ax.set_xlabel(r"$q_y$ (" + Q_pixel_units + ")")
+ ax.set_ylabel(r"$q_x$ (" + Q_pixel_units + ")")
+ else:
+ ax.set_xlabel(
+ r"$q_y$ (" + Q_pixel_units + " e" + str(exp_spacing) + ")"
+ )
+ ax.set_ylabel(
+ r"$q_x$ (" + Q_pixel_units + " e" + str(exp_spacing) + ")"
+ )
+ else:
+ ax.set_xticklabels([])
+ ax.set_yticklabels([])
+ ax.grid(
+ linestyle=majorgridls, linewidth=majorgridlw, color=gridcolor, alpha=alpha
+ )
+
+ # Add the grid overlay (majorgridlw/minorgridlw have already been
+ # zeroed above when the corresponding gridlines are disabled)
+ if majorgridlines or minorgridlines:
+ add_cartesian_grid(
+ ax,
+ d={
+ "x0": qx0,
+ "y0": qy0,
+ "spacing": gridspacing,
+ "majorlw": majorgridlw,
+ "majorls": majorgridls,
+ "minorlw": minorgridlw,
+ "minorls": minorgridls,
+ "color": gridcolor,
+ "label": gridlabels,
+ "labelsize": gridlabelsize,
+ "labelcolor": gridlabelcolor,
+ "alpha": alpha,
+ },
+ )
+
+ # Add a polar-elliptical grid
+ if polargrid:
+ pass
+
+ return
+
+
+# Shape overlays
+
+
+def show_rectangles(
+ ar,
+ lims=(0, 1, 0, 1),
+ color="r",
+ fill=True,
+ alpha=0.25,
+ linewidth=2,
+ returnfig=False,
+ **kwargs,
+):
+ """
+ Visualization function which plots a 2D array with one or more overlayed rectangles.
+ lims is specified in the order (x0,xf,y0,yf). The rectangle bounds begin at the upper
+ left corner of (x0,y0) and end at the upper left corner of (xf,yf) -- i.e. inclusive
+ in the lower bound, exclusive in the upper bound -- so that the boxed region encloses
+ the area of array ar specified by ar[x0:xf,y0:yf].
+
+ To overlay one rectangle, lims must be a single 4-tuple. To overlay N rectangles,
+ lims must be a list of N 4-tuples. color, fill, and alpha may each be single values,
+ which are then applied to all the rectangles, or a length N list.
+
+ See the docstring for py4DSTEM.visualize.show() for descriptions of all input
+ parameters not listed below.
+
+ Accepts:
+ lims (4-tuple, or list of N 4-tuples) the rectangle bounds (x0,xf,y0,yf)
+ color (valid matplotlib color, or list of N colors)
+ fill (bool or list of N bools) filled in or empty rectangles
+ alpha (number, 0 to 1) transparency
+ linewidth (number)
+
+ Returns:
+ If returnfig==False (default), the figure is plotted and nothing is returned.
+ If returnfig==True, the figure and its one axis are returned, and can be
+ further edited.
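+
+ Example (an illustrative sketch; `im` stands for any 2D array):
+
+ >>> show_rectangles(im, lims=[(10, 40, 10, 40), (60, 90, 60, 90)], color=["y", "b"], fill=False)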
+ """
+ fig, ax = show(ar, returnfig=True, **kwargs)
+ d = {
+ "lims": lims,
+ "color": color,
+ "fill": fill,
+ "alpha": alpha,
+ "linewidth": linewidth,
+ }
+ add_rectangles(ax, d)
+
+ if not returnfig:
+ return
+ else:
+ return fig, ax
+
+
+def show_circles(
+ ar,
+ center,
+ R,
+ color="r",
+ fill=True,
+ alpha=0.3,
+ linewidth=2,
+ returnfig=False,
+ **kwargs,
+):
+ """
+ Visualization function which plots a 2D array with one or more overlayed circles.
+ To overlay one circle, center must be a single 2-tuple. To overlay N circles,
+ center must be a list of N 2-tuples. color, fill, and alpha may each be single values,
+ which are then applied to all the circles, or a length N list.
+
+ See the docstring for py4DSTEM.visualize.show() for descriptions of all input
+ parameters not listed below.
+
+ Accepts:
+ ar (2D array) the data
+ center (2-tuple, or list of N 2-tuples) the center of the circle (x0,y0)
+ R (number or list of N numbers) the circle radii
+ color (valid matplotlib color, or list of N colors)
+ fill (bool or list of N bools) filled in or empty circles
+ alpha (number, 0 to 1) transparency
+ linewidth (number)
+
+ Returns:
+ If returnfig==False (default), the figure is plotted and nothing is returned.
+ If returnfig==True, the figure and its one axis are returned, and can be
+ further edited.
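+
+ Example (an illustrative sketch; `im` stands for any 2D array):
+
+ >>> show_circles(im, center=[(32, 32), (96, 96)], R=[10, 6], fill=False)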
+ """
+
+ fig, ax = show(ar, returnfig=True, **kwargs)
+
+ d = {
+ "center": center,
+ "R": R,
+ "color": color,
+ "fill": fill,
+ "alpha": alpha,
+ "linewidth": linewidth,
+ }
+ add_circles(ax, d)
+
+ if not returnfig:
+ return
+ else:
+ return fig, ax
+
+
+def show_ellipses(
+ ar,
+ center,
+ a,
+ b,
+ theta,
+ color="r",
+ fill=True,
+ alpha=0.3,
+ linewidth=2,
+ returnfig=False,
+ **kwargs,
+):
+ """
+ Visualization function which plots a 2D array with one or more overlayed ellipses.
+ To overlay one ellipse, center must be a single 2-tuple. To overlay N ellipses,
+ center must be a list of N 2-tuples. Similarly, the remaining ellipse parameters -
+ a, b, and theta - must each be a single number or a len-N list. color, fill, and
+ alpha may each be single values, which are then applied to all the ellipses, or
+ length N lists.
+
+ See the docstring for py4DSTEM.visualize.show() for descriptions of all input
+ parameters not listed below.
+
+ Accepts:
+ center (2-tuple, or list of N 2-tuples) the center of the circle (x0,y0)
+ a (number or list of N numbers) the semimajor axis length
+ b (number or list of N numbers) the semiminor axis length
+ theta (number or list of N numbers) the tilt angle in radians
+ color (valid matplotlib color, or list of N colors)
+ fill (bool or list of N bools) filled in or empty ellipses
+ alpha (number, 0 to 1) transparency
+ linewidth (number)
+
+ Returns:
+ If returnfig==False (default), the figure is plotted and nothing is returned.
+ If returnfig==True, the figure and its one axis are returned, and can be
+ further edited.
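+
+ Example (an illustrative sketch; `im` stands for any 2D array):
+
+ >>> show_ellipses(im, center=(64, 64), a=24, b=12, theta=np.pi / 6, fill=False)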
+ """
+ fig, ax = show(ar, returnfig=True, **kwargs)
+ d = {
+ "center": center,
+ "a": a,
+ "b": b,
+ "theta": theta,
+ "color": color,
+ "fill": fill,
+ "alpha": alpha,
+ "linewidth": linewidth,
+ }
+ add_ellipses(ax, d)
+
+ if not returnfig:
+ return
+ else:
+ return fig, ax
+
+
+def show_annuli(
+ ar,
+ center,
+ radii,
+ color="r",
+ fill=True,
+ alpha=0.3,
+ linewidth=2,
+ returnfig=False,
+ **kwargs,
+):
+ """
+ Visualization function which plots a 2D array with one or more overlayed annuli.
+ To overlay one annulus, center must be a single 2-tuple. To overlay N annuli,
+ center must be a list of N 2-tuples. color, fill, and alpha may each be single values,
+ which are then applied to all the annuli, or a length N list.
+
+ See the docstring for py4DSTEM.visualize.show() for descriptions of all input
+ parameters not listed below.
+
+ Accepts:
+ center (2-tuple, or list of N 2-tuples) the center of the annulus (x0,y0)
+ radii (2-tuple, or list of N 2-tuples) the inner and outer radii
+ color (string or list of N strings)
+ fill (bool or list of N bools) filled in or empty annuli
+ alpha (number, 0 to 1) transparency
+ linewidth (number)
+
+ Returns:
+ If returnfig==False (default), the figure is plotted and nothing is returned.
+ If returnfig==True, the figure and its one axis are returned, and can be
+ further edited.
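+
+ Example (an illustrative sketch; `im` stands for any 2D array):
+
+ >>> show_annuli(im, center=(64, 64), radii=(12, 18), alpha=0.4)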
+ """
+ fig, ax = show(ar, returnfig=True, **kwargs)
+ d = {
+ "center": center,
+ "radii": radii,
+ "color": color,
+ "fill": fill,
+ "alpha": alpha,
+ "linewidth": linewidth,
+ }
+ add_annuli(ax, d)
+
+ if not returnfig:
+ return
+ else:
+ return fig, ax
+
+
+def show_points(
+ ar,
+ x,
+ y,
+ s=1,
+ scale=50,
+ alpha=1,
+ pointcolor="r",
+ open_circles=False,
+ title=None,
+ returnfig=False,
+ **kwargs,
+):
+ """
+ Plots a 2D array with one or more points.
+ x and y are the point centers and must have the same length, N.
+ s gives the relative point sizes, and must have length 1 or N.
+ scale is the size of the largest point.
+ pointcolor must have length 1 or N.
+
+ Accepts:
+ ar (array) the image
+ x,y (number or iterable of numbers) the point positions
+ s (number or iterable of numbers) the relative point sizes
+ scale (number) the maximum point size
+ title (str) title for plot
+ pointcolor
+ alpha
+
+ Returns:
+ If returnfig==False (default), the figure is plotted and nothing is returned.
+ If returnfig==True, the figure and its one axis are returned, and can be
+ further edited.
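+
+ Example (an illustrative sketch; `im` stands for any 2D array):
+
+ >>> show_points(im, x=[12, 40], y=[20, 56], s=[1, 2], scale=80, pointcolor="c")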
+ """
+ fig, ax = show(ar, title=title, returnfig=True, **kwargs)
+ d = {
+ "x": x,
+ "y": y,
+ "s": s,
+ "scale": scale,
+ "pointcolor": pointcolor,
+ "alpha": alpha,
+ "open_circles": open_circles,
+ }
+ add_points(ax, d)
+
+ if not returnfig:
+ return
+ else:
+ return fig, ax
diff --git a/py4DSTEM/visualize/show_extention.py b/py4DSTEM/visualize/show_extention.py
new file mode 100644
index 000000000..8fdf522a2
--- /dev/null
+++ b/py4DSTEM/visualize/show_extention.py
@@ -0,0 +1,34 @@
+from py4DSTEM.visualize.vis_grid import show_image_grid
+
+
+def _show_grid(**kwargs):
+ """ """
+ assert "ar" in kwargs.keys()
+ ar = kwargs["ar"]
+ del kwargs["ar"]
+
+ # parse grid of images
+ if isinstance(ar[0], list):
+ assert all([isinstance(ar[i], list) for i in range(len(ar))])
+ W = len(ar[0])
+ H = len(ar)
+
+ def get_ar(i):
+ h = i // W
+ w = i % W
+ try:
+ return ar[h][w]
+ except IndexError:
+ return
+
+ else:
+ W = len(ar)
+ H = 1
+
+ def get_ar(i):
+ return ar[i]
+
+ if kwargs["returnfig"]:
+ return show_image_grid(get_ar, H, W, **kwargs)
+ else:
+ show_image_grid(get_ar, H, W, **kwargs)
diff --git a/py4DSTEM/visualize/vis_RQ.py b/py4DSTEM/visualize/vis_RQ.py
new file mode 100644
index 000000000..85c0eb042
--- /dev/null
+++ b/py4DSTEM/visualize/vis_RQ.py
@@ -0,0 +1,599 @@
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib.axes import Axes
+
+from py4DSTEM.visualize.show import show, show_points
+
+
+def show_selected_dp(
+ datacube,
+ image,
+ rx,
+ ry,
+ figsize=(12, 6),
+ returnfig=False,
+ pointsize=50,
+ pointcolor="r",
+ scaling="log",
+ **kwargs,
+):
+ """ """
+ dp = datacube.data[rx, ry, :, :]
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
+ _, _ = show_points(
+ image,
+ rx,
+ ry,
+ scale=pointsize,
+ pointcolor=pointcolor,
+ figax=(fig, ax1),
+ returnfig=True,
+ )
+ _, _ = show(dp, figax=(fig, ax2), scaling=scaling, returnfig=True)
+ if not returnfig:
+ plt.show()
+ return
+ else:
+ return fig, (ax1, ax2)
+
+
+def show_RQ(
+ realspace_image,
+ diffractionspace_image,
+ realspace_pdict={},
+ diffractionspace_pdict={"scaling": "log"},
+ figsize=(12, 6),
+ returnfig=False,
+):
+ """
+ Shows side-by-side real/reciprocal space images.
+
+ Accepts:
+ realspace_image (2D array)
+ diffractionspace_image (2D array)
+ realspace_pdict (dictionary) arguments and values to pass
+ to the show() fn for the real space image
+ diffractionspace_pdict (dictionary)
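+
+ Example (an illustrative sketch; `im_R`/`im_Q` stand for real/diffraction space arrays):
+
+ >>> show_RQ(im_R, im_Q, realspace_pdict={"cmap": "gray"})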
+ """
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
+ show(realspace_image, figax=(fig, ax1), **realspace_pdict)
+ show(diffractionspace_image, figax=(fig, ax2), **diffractionspace_pdict)
+ if not returnfig:
+ plt.show()
+ return
+ else:
+ return fig, (ax1, ax2)
+
+
+def ax_addvector(ax, vx, vy, vlength, x0, y0, width=1, color="r"):
+ """
+ Adds a vector to the subplot at ax.
+
+ Accepts:
+ ax (matplotlib subplot)
+ vx,vy (numbers) x,y components of the vector
+ Only the orientation is used, vector is
+ normalized and rescaled by
+ vlength (number) the vector length
+ x0,y0 (numbers) the origin / vector tail position
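+
+ Example (an illustrative sketch; `im` stands for any 2D array):
+
+ >>> fig, ax = show(im, returnfig=True)
+ >>> ax_addvector(ax, vx=1, vy=1, vlength=25, x0=10, y0=10)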
+ """
+ vL = np.hypot(vx, vy)
+ vx, vy = vlength * vx / vL, vlength * vy / vL
+ ax.arrow(y0, x0, vy, vx, color=color, width=width, length_includes_head=True)
+
+
+def ax_addvector_RtoQ(ax, vx, vy, vlength, x0, y0, QR_rotation, width=1, color="r"):
+ """
+ Adds a vector to the subplot at ax, where the vector (vx,vy) passed
+ to the function is in real space and the plotted vector is transformed
+ into and plotted in reciprocal space.
+
+ Accepts:
+ ax (matplotlib subplot)
+ vx,vy (numbers) x,y components of the vector,
+ in *real* space. Only the orientation is used,
+ vector is normalized and rescaled by
+ vlength (number) the vector length, in *reciprocal*
+ space
+ x0,y0 (numbers) the origin / vector tail position,
+ in *reciprocal* space
+ QR_rotation (number) the offset angle between real and
+ diffraction space. Specifically, this is
+ the counterclockwise rotation of real space
+ with respect to diffraction space. In degrees.
+ """
+ from py4DSTEM.process.calibration.rotation import get_Qvector_from_Rvector
+
+ _, _, vx, vy = get_Qvector_from_Rvector(vx, vy, QR_rotation)
+ vx, vy = vx * vlength, vy * vlength
+ ax.arrow(y0, x0, vy, vx, color=color, width=width, length_includes_head=True)
+
+
+def ax_addvector_QtoR(ax, vx, vy, vlength, x0, y0, QR_rotation, width=1, color="r"):
+ """
+ Adds a vector to the subplot at ax, where the vector (vx,vy) passed
+ to the function is in reciprocal space and the plotted vector is
+ transformed into and plotted in real space.
+
+ Accepts:
+ ax (matplotlib subplot)
+ vx,vy (numbers) x,y components of the vector,
+ in *reciprocal* space. Only the orientation is
+ used, vector is normalized and rescaled by
+ vlength (number) the vector length, in *real* space
+ x0,y0 (numbers) the origin / vector tail position,
+ in *real* space
+ QR_rotation (number) the offset angle between real and
+ diffraction space. Specifically, this is
+ the counterclockwise rotation of real space
+ with respect to diffraction space. In degrees.
+ """
+ from py4DSTEM.process.calibration.rotation import get_Rvector_from_Qvector
+
+ vx, vy, _, _ = get_Rvector_from_Qvector(vx, vy, QR_rotation)
+ vx, vy = vx * vlength, vy * vlength
+ ax.arrow(y0, x0, vy, vx, color=color, width=width, length_includes_head=True)
+
+
+def show_RQ_vector(
+ realspace_image,
+ diffractionspace_image,
+ realspace_pdict,
+ diffractionspace_pdict,
+ vx,
+ vy,
+ vlength_R,
+ vlength_Q,
+ x0_R,
+ y0_R,
+ x0_Q,
+ y0_Q,
+ QR_rotation,
+ vector_space="R",
+ width_R=1,
+ color_R="r",
+ width_Q=1,
+ color_Q="r",
+ figsize=(12, 6),
+ returnfig=False,
+):
+ """
+ Shows side-by-side real/reciprocal space images with a vector
+ overlaid in each showing corresponding directions.
+
+ Accepts:
+ realspace_image (2D array)
+ diffractionspace_image (2D array)
+ realspace_pdict (dictionary) arguments and values to pass
+ to the show() fn for the real space image
+ diffractionspace_pdict (dictionary)
+ vx,vy (numbers) x,y components of the vector
+ in either real or diffraction space,
+ depending on the value of vector_space.
+ Note (vx,vy) is used for the orientation
+ only - the two vectors are normalized
+ and rescaled by
+ vlength_R,vlength_Q (number) the vector length in each
+ space, in pixels
+ x0_R,y0_R,x0_Q,y0_Q (numbers) the origins / vector tail positions
+ QR_rotation (number) the offset angle between real and
+ diffraction space. Specifically, this is
+ the counterclockwise rotation of real space
+ with respect to diffraction space. In degrees.
+ vector_space (string) must be 'R' or 'Q'. Specifies
+ whether the (vx,vy) values passed to this
+ function describe a real or diffraction
+ space vector.
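+
+ Example (an illustrative sketch; `im_R`/`im_Q` stand for real/diffraction space arrays):
+
+ >>> show_RQ_vector(im_R, im_Q, {}, {"scaling": "log"}, vx=1, vy=0,
+ ... vlength_R=20, vlength_Q=20, x0_R=32, y0_R=32, x0_Q=64, y0_Q=64, QR_rotation=30)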
+ """
+ assert vector_space in ("R", "Q")
+ fig, (ax1, ax2) = show_RQ(
+ realspace_image,
+ diffractionspace_image,
+ realspace_pdict,
+ diffractionspace_pdict,
+ figsize=figsize,
+ returnfig=True,
+ )
+ if vector_space == "R":
+ ax_addvector(ax1, vx, vy, vlength_R, x0_R, y0_R, width=width_R, color=color_R)
+ ax_addvector_RtoQ(
+ ax2,
+ vx,
+ vy,
+ vlength_Q,
+ x0_Q,
+ y0_Q,
+ QR_rotation,
+ width=width_Q,
+ color=color_Q,
+ )
+ else:
+ ax_addvector(ax2, vx, vy, vlength_Q, x0_Q, y0_Q, width=width_Q, color=color_Q)
+ ax_addvector_QtoR(
+ ax1,
+ vx,
+ vy,
+ vlength_R,
+ x0_R,
+ y0_R,
+ QR_rotation,
+ width=width_R,
+ color=color_R,
+ )
+ if not returnfig:
+ plt.show()
+ return
+ else:
+ return fig, (ax1, ax2)
+
+
+def show_RQ_vectors(
+ realspace_image,
+ diffractionspace_image,
+ realspace_pdict,
+ diffractionspace_pdict,
+ vx,
+ vy,
+ vlength_R,
+ vlength_Q,
+ x0_R,
+ y0_R,
+ x0_Q,
+ y0_Q,
+ QR_rotation,
+ vector_space="R",
+ width_R=1,
+ color_R="r",
+ width_Q=1,
+ color_Q="r",
+ figsize=(12, 6),
+ returnfig=False,
+):
+ """
+ Shows side-by-side real/reciprocal space images with several vectors
+ overlaid in each showing corresponding directions.
+
+ Accepts:
+ realspace_image (2D array)
+ diffractionspace_image (2D array)
+ realspace_pdict (dictionary) arguments and values to pass
+ to the show() fn for the real space image
+ diffractionspace_pdict (dictionary)
+ vx,vy (1D arrays) x,y components of the vectors
+ in either real or diffraction space,
+ depending on the value of vector_space.
+ Note (vx,vy) is used for the orientation
+ only - the two vectors are normalized
+ and rescaled by
+ vlength_R,vlength_Q (number) the vector length in each
+ space, in pixels
+ x0_R,y0_R,x0_Q,y0_Q (numbers) the origins / vector tail positions
+ QR_rotation (number) the offset angle between real and
+ diffraction space. Specifically, this is
+ the counterclockwise rotation of real space
+ with respect to diffraction space. In degrees.
+ vector_space (string) must be 'R' or 'Q'. Specifies
+ whether the (vx,vy) values passed to this
+ function describe a real or diffraction
+ space vector.
+ """
+ assert vector_space in ("R", "Q")
+ assert len(vx) == len(vy)
+ if isinstance(color_R, tuple) or isinstance(color_R, list):
+ assert len(vx) == len(color_R)
+ else:
+ color_R = [color_R for i in range(len(vx))]
+ if isinstance(color_Q, tuple) or isinstance(color_Q, list):
+ assert len(vx) == len(color_Q)
+ else:
+ color_Q = [color_Q for i in range(len(vx))]
+
+ fig, (ax1, ax2) = show_RQ(
+ realspace_image,
+ diffractionspace_image,
+ realspace_pdict,
+ diffractionspace_pdict,
+ figsize=figsize,
+ returnfig=True,
+ )
+ for x, y, cR, cQ in zip(vx, vy, color_R, color_Q):
+ if vector_space == "R":
+ ax_addvector(ax1, x, y, vlength_R, x0_R, y0_R, width=width_R, color=cR)
+ ax_addvector_RtoQ(
+ ax2, x, y, vlength_Q, x0_Q, y0_Q, QR_rotation, width=width_Q, color=cQ
+ )
+ else:
+ ax_addvector(ax2, x, y, vlength_Q, x0_Q, y0_Q, width=width_Q, color=cQ)
+ ax_addvector_QtoR(
+ ax1, x, y, vlength_R, x0_R, y0_R, QR_rotation, width=width_R, color=cR
+ )
+ if not returnfig:
+ plt.show()
+ return
+ else:
+ return fig, (ax1, ax2)
+
+
+def ax_addaxes(
+ ax,
+ vx,
+ vy,
+ vlength,
+ x0,
+ y0,
+ width=1,
+ color="r",
+ labelaxes=True,
+ labelsize=12,
+ labelcolor="r",
+ righthandedcoords=True,
+):
+ """
+ Adds a pair of x/y axes to the matplotlib subplot ax. The user supplies
+ the x-axis direction with (vx,vy), and the y-axis is then chosen
+ by rotating 90 degrees, in a direction set by righthandedcoords.
+
+ Accepts:
+ ax (matplotlib subplot)
+ vx,vy (numbers) x,y components of the x-axis,
+ Only the orientation is used; the axis
+ is normalized and rescaled by
+ vlength (number) the axis length
+ x0,y0 (numbers) the origin of the axes
+ labelaxes (bool) if True, label 'x' and 'y'
+ righthandedcoords (bool) if True, y-axis is counterclockwise
+ with respect to x-axis
+ """
+ # Get the x-axis
+ vL = np.hypot(vx, vy)
+ xaxis_x, xaxis_y = vlength * vx / vL, vlength * vy / vL
+ # Get the y-axes
+ if righthandedcoords:
+ yaxis_x, yaxis_y = -xaxis_y, xaxis_x
+ else:
+ yaxis_x, yaxis_y = xaxis_y, -xaxis_x
+ ax_addvector(ax, xaxis_x, xaxis_y, vlength, x0, y0, width=width, color=color)
+ ax_addvector(ax, yaxis_x, yaxis_y, vlength, x0, y0, width=width, color=color)
+ # Label axes:
+ if labelaxes:
+ xaxislabel_x = x0 + 1.1 * xaxis_x
+ xaxislabel_y = y0 + xaxis_y
+ yaxislabel_x = x0 + yaxis_x
+ yaxislabel_y = y0 + 1.1 * yaxis_y
+ ax.text(xaxislabel_y, xaxislabel_x, "x", color=labelcolor, size=labelsize)
+ ax.text(yaxislabel_y, yaxislabel_x, "y", color=labelcolor, size=labelsize)
+
+
+def ax_addaxes_QtoR(
+ ax,
+ vx,
+ vy,
+ vlength,
+ x0,
+ y0,
+ QR_rotation,
+ width=1,
+ color="r",
+ labelaxes=True,
+ labelsize=12,
+ labelcolor="r",
+):
+ """
+ Adds a pair of x/y axes to the matplotlib subplot ax. The user supplies
+ the x-axis direction with (vx,vy) in reciprocal space coordinates, and
+ the function transforms and displays the corresponding vector in real space.
+
+ Accepts:
+ ax (matplotlib subplot)
+ vx,vy (numbers) x,y components of the x-axis,
+ in reciprocal space coordinates. Only
+ the orientation is used; the axes
+ are normalized and rescaled by
+ vlength (number) the axis length, in real space
+ x0,y0 (numbers) the origin of the axes, in
+ real space
+ labelaxes (bool) if True, label 'x' and 'y'
+ QR_rotation (number) the offset angle between real and
+ diffraction space. Specifically, this is
+ the counterclockwise rotation of real space
+ with respect to diffraction space. In degrees.
+ """
+ from py4DSTEM.process.calibration.rotation import get_Rvector_from_Qvector
+
+ vx, vy, _, _ = get_Rvector_from_Qvector(vx, vy, QR_rotation)
+ ax_addaxes(
+ ax,
+ vx,
+ vy,
+ vlength,
+ x0,
+ y0,
+ width=width,
+ color=color,
+ labelaxes=labelaxes,
+ labelsize=labelsize,
+ labelcolor=labelcolor,
+ righthandedcoords=True,
+ )
+
+
+def ax_addaxes_RtoQ(
+ ax,
+ vx,
+ vy,
+ vlength,
+ x0,
+ y0,
+ QR_rotation,
+ width=1,
+ color="r",
+ labelaxes=True,
+ labelsize=12,
+ labelcolor="r",
+):
+ """
+ Adds a pair of x/y axes to the matplotlib subplot ax. The user supplies
+ the x-axis direction with (vx,vy) in real space coordinates, and the function
+ transforms and displays the corresponding vector in reciprocal space.
+
+ Accepts:
+ ax (matplotlib subplot)
+ vx,vy (numbers) x,y components of the x-axis,
+ in real space coordinates. Only
+ the orientation is used; the axes
+ are normalized and rescaled by
+ vlength (number) the axis length, in reciprocal space
+ x0,y0 (numbers) the origin of the axes, in
+ reciprocal space
+ labelaxes (bool) if True, label 'x' and 'y'
+ QR_rotation (number) the offset angle between real and
+ diffraction space. Specifically, this is
+ the counterclockwise rotation of real space
+ with respect to diffraction space. In degrees.
+ """
+ from py4DSTEM.process.calibration.rotation import get_Qvector_from_Rvector
+
+ _, _, vx, vy = get_Qvector_from_Rvector(vx, vy, QR_rotation)
+ ax_addaxes(
+ ax,
+ vx,
+ vy,
+ vlength,
+ x0,
+ y0,
+ width=width,
+ color=color,
+ labelaxes=labelaxes,
+ labelsize=labelsize,
+ labelcolor=labelcolor,
+ righthandedcoords=True,
+ )
+
+
+def show_RQ_axes(
+ realspace_image,
+ diffractionspace_image,
+ realspace_pdict,
+ diffractionspace_pdict,
+ vx,
+ vy,
+ vlength_R,
+ vlength_Q,
+ x0_R,
+ y0_R,
+ x0_Q,
+ y0_Q,
+ QR_rotation,
+ vector_space="R",
+ width_R=1,
+ color_R="r",
+ width_Q=1,
+ color_Q="r",
+ labelaxes=True,
+ labelcolor_R="r",
+ labelcolor_Q="r",
+ labelsize_R=12,
+ labelsize_Q=12,
+ figsize=(12, 6),
+ returnfig=False,
+):
+ """
+ Shows side-by-side real/reciprocal space images with a set of corresponding
+ coordinate axes overlaid in each. (vx,vy) specifies the x-axis, and the y-axis
+ is rotated 90 degrees counterclockwise in reciprocal space (relevant in case of
+ an R/Q transposition).
+
+ Accepts:
+ realspace_image (2D array)
+ diffractionspace_image (2D array)
+ realspace_pdict (dictionary) arguments and values to pass
+ to the show() fn for the real space image
+ diffractionspace_pdict (dictionary)
+ vx,vy (numbers) x,y components of the x-axis
+ in either real or diffraction space,
+ depending on the value of vector_space.
+ Note (vx,vy) is used for the orientation
+ only - the vectors are normalized
+ and rescaled by
+ vlength_R,vlength_Q (number or 1D arrays) the vector length in each
+ space, in pixels
+ x0_R,y0_R,x0_Q,y0_Q (number) the origins / vector tail positions
+ QR_rotation (number) the offset angle between real and
+ diffraction space. Specifically, this is
+ the counterclockwise rotation of real space
+ with respect to diffraction space. In degrees.
+ vector_space (string) must be 'R' or 'Q'. Specifies
+ whether the (vx,vy) values passed to this
+ function describe a real or diffraction
+ space vector.
+ """
+ assert vector_space in ("R", "Q")
+ fig, (ax1, ax2) = show_RQ(
+ realspace_image,
+ diffractionspace_image,
+ realspace_pdict,
+ diffractionspace_pdict,
+ figsize=figsize,
+ returnfig=True,
+ )
+ if vector_space == "R":
+ ax_addaxes(
+ ax1,
+ vx,
+ vy,
+ vlength_R,
+ x0_R,
+ y0_R,
+ width=width_R,
+ color=color_R,
+ labelaxes=labelaxes,
+ labelsize=labelsize_R,
+ labelcolor=labelcolor_R,
+ )
+ ax_addaxes_RtoQ(
+ ax2,
+ vx,
+ vy,
+ vlength_Q,
+ x0_Q,
+ y0_Q,
+ QR_rotation,
+ width=width_Q,
+ color=color_Q,
+ labelaxes=labelaxes,
+ labelsize=labelsize_Q,
+ labelcolor=labelcolor_Q,
+ )
+ else:
+ ax_addaxes(
+ ax2,
+ vx,
+ vy,
+ vlength_Q,
+ x0_Q,
+ y0_Q,
+ width=width_Q,
+ color=color_Q,
+ labelaxes=labelaxes,
+ labelsize=labelsize_Q,
+ labelcolor=labelcolor_Q,
+ )
+ ax_addaxes_QtoR(
+ ax1,
+ vx,
+ vy,
+ vlength_R,
+ x0_R,
+ y0_R,
+ QR_rotation,
+ width=width_R,
+ color=color_R,
+ labelaxes=labelaxes,
+ labelsize=labelsize_R,
+ labelcolor=labelcolor_R,
+ )
+ if not returnfig:
+ plt.show()
+ return
+ else:
+ return fig, (ax1, ax2)
diff --git a/py4DSTEM/visualize/vis_grid.py b/py4DSTEM/visualize/vis_grid.py
new file mode 100644
index 000000000..d24b0b8d8
--- /dev/null
+++ b/py4DSTEM/visualize/vis_grid.py
@@ -0,0 +1,297 @@
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib.patches import Rectangle
+
+from py4DSTEM.visualize.show import show, show_points
+from py4DSTEM.visualize.overlay import add_grid_overlay
+
+
+def show_DP_grid(
+ datacube, x0, y0, xL, yL, axsize=(6, 6), returnfig=False, space=0, **kwargs
+):
+ """
+ Shows a grid of diffraction patterns from DataCube datacube, starting from
+ scan position (x0,y0) and extending xL,yL.
+
+ Accepts:
+ datacube (DataCube) the 4D-STEM data
+ (x0,y0) the corner of the grid of DPs to display
+ xL,yL the extent of the grid
+ axsize the size of each diffraction pattern
+ space (number) controls the space between subplots
+
+ Returns:
+ If returnfig==False (default), the figure is plotted and nothing is returned.
+ If returnfig==True, the figure and its axes are returned, and can be
+ further edited.
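+
+ Example (an illustrative sketch, assuming a loaded DataCube `datacube`):
+
+ >>> show_DP_grid(datacube, x0=10, y0=10, xL=3, yL=3, scaling="log")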
+ """
+ yy, xx = np.meshgrid(np.arange(y0, y0 + yL), np.arange(x0, x0 + xL))
+
+ fig, axs = plt.subplots(xL, yL, figsize=(yL * axsize[0], xL * axsize[1]))
+ # ensure axs is a 2D (xL,yL) array of Axes, even when xL or yL is 1
+ axs = np.array(axs).reshape(xL, yL)
+ for xi in range(xL):
+ for yi in range(yL):
+ ax = axs[xi, yi]
+ x, y = xx[xi, yi], yy[xi, yi]
+ dp = datacube.data[x, y, :, :]
+ _, _ = show(dp, figax=(fig, ax), returnfig=True, **kwargs)
+ plt.tight_layout()
+ plt.subplots_adjust(wspace=space, hspace=space)
+
+ if not returnfig:
+ plt.show()
+ return
+ else:
+ return fig, axs
+
+
+def show_grid_overlay(
+ image, x0, y0, xL, yL, color="k", linewidth=1, alpha=1, returnfig=False, **kwargs
+):
+ """
+ Shows the image with an overlaid boxgrid outline about the pixels
+ beginning at (x0,y0) and with extent xL,yL in the two directions.
+
+ Accepts:
+ image the image array
+ x0,y0 the corner of the grid
+ xL,yL the extent of the grid
+ """
+ fig, ax = show(image, returnfig=True, **kwargs)
+ add_grid_overlay(
+ ax,
+ d={
+ "x0": x0,
+ "y0": y0,
+ "xL": xL,
+ "yL": yL,
+ "color": color,
+ "linewidth": linewidth,
+ "alpha": alpha,
+ },
+ )
+ plt.tight_layout()
+
+ if not returnfig:
+ plt.show()
+ return
+ else:
+ return fig, ax
+
+
+def _show_grid_overlay(
+ image, x0, y0, xL, yL, color="k", linewidth=1, alpha=1, returnfig=False, **kwargs
+):
+ """
+ Shows the image with an overlaid boxgrid outline about the pixels
+ beginning at (x0,y0) and with extent xL,yL in the two directions.
+
+ Accepts:
+ image the image array
+ x0,y0 the corner of the grid
+ xL,yL the extent of the grid
+ """
+ yy, xx = np.meshgrid(np.arange(y0, y0 + yL), np.arange(x0, x0 + xL))
+
+ fig, ax = show(image, returnfig=True, **kwargs)
+ for xi in range(xL):
+ for yi in range(yL):
+ x, y = xx[xi, yi], yy[xi, yi]
+ rect = Rectangle(
+ (y - 0.5, x - 0.5),
+ 1,
+ 1,
+ lw=linewidth,
+ color=color,
+ alpha=alpha,
+ fill=False,
+ )
+ ax.add_patch(rect)
+ plt.tight_layout()
+
+ if not returnfig:
+ plt.show()
+ return
+ else:
+ return fig, ax
+
+
+def show_image_grid(
+ get_ar,
+ H,
+ W,
+ axsize=(6, 6),
+ returnfig=False,
+ figax=None,
+ title=None,
+ title_index=False,
+ suptitle=None,
+ get_bordercolor=None,
+ get_x=None,
+ get_y=None,
+ get_pointcolors=None,
+ get_s=None,
+ open_circles=False,
+ **kwargs,
+):
+ """
+ Displays a set of images in a grid.
+
+ The images are specified by some function get_ar(i), which returns an
+ image for values of some integer index i. The values of i passed to
+ get_ar are 0 through HW-1.
+
+ To display the first 4 two-dimensional slices of some 3D array ar,
+ you can do
+
+ >>> show_image_grid(lambda i:ar[:,:,i], H=2, W=2)
+
+ It's also possible to add colored borders, or overlaid points,
+ using similar functions to get_ar, i.e. functions which return
+ the color or set of points of interest as a function of index
+ i, which must be defined in the range [0,HW-1].
+
+ Accepts:
+ get_ar a function which returns a 2D array when passed
+ the integers 0 through HW-1
+ H,W integers, the dimensions of the grid
+ axsize the size of each image
+ figax controls which matplotlib Figure/Axes draw the images.
+ If None, generates a new figure with an (H,W) grid of Axes.
+ Otherwise, figax must be a 2-tuple containing the matplotlib
+ instances (Figure, Axes-array), with the images drawn into that grid.
+ title if title is a string, it is printed as the suptitle. If a suptitle
+ is also provided, the suptitle is printed instead.
+ if title is a list of strings (ex: ['title 1','title 2']), each image
+ gets the corresponding title from the list.
+ title_index if True, prints the index i passed to get_ar over each image
+ suptitle string, suptitle on plot
+ get_bordercolor
+ if not None, should be a function defined over
+ the same i as get_ar, and which returns a
+ valid matplotlib color for each i. Adds
+ a colored bounding box about each image. E.g.
+ if `colors` is an array of colors:
+
+ >>> show_image_grid(lambda i:ar[:,:,i],H=2,W=2,
+ get_bordercolor=lambda i:colors[i])
+
+ get_x,get_y functions which returns sets of x/y positions
+ as a function of index i
+ get_s function which returns a set of point sizes
+ as a function of index i
+ get_pointcolors a function which returns a color or list of colors
+ as a function of index i
+
+ Returns:
+ If returnfig==False (default), the figure is plotted and nothing is returned.
+ If returnfig==True, the figure and its axes are returned, and can be
+ further edited.
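+
+ Another illustrative sketch, assuming a 3D stack `stack` of shape (6,Nx,Ny):
+
+ >>> show_image_grid(lambda i: stack[i], H=2, W=3, title_index=True)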
+ """
+ _get_bordercolor = get_bordercolor is not None
+ _get_points = (get_x is not None) and (get_y is not None)
+ _get_colors = get_pointcolors is not None
+ _get_s = get_s is not None
+
+ if figax is None:
+ fig, axs = plt.subplots(H, W, figsize=(W * axsize[0], H * axsize[1]))
+ else:
+ fig, axs = figax
+ # ensure axs is always a 2D (H,W) array of Axes, even when H or W is 1
+ axs = np.array(axs).reshape(H, W)
+ for i in range(H):
+ for j in range(W):
+ ax = axs[i, j]
+ N = i * W + j
+ # make titles
+ if isinstance(title, list):
+ print_title = title[N]
+ else:
+ print_title = None
+ if title_index:
+ if print_title is not None:
+ print_title = f"{N}. " + print_title
+ else:
+ print_title = f"{N}."
+ # make figures
+ try:
+ ar = get_ar(N)
+ if _get_bordercolor and _get_points:
+ bc = get_bordercolor(N)
+ x, y = get_x(N), get_y(N)
+ if _get_colors:
+ pointcolors = get_pointcolors(N)
+ else:
+ pointcolors = "r"
+ if _get_s:
+ s = get_s(N)
+ _, _ = show_points(
+ ar,
+ figax=(fig, ax),
+ returnfig=True,
+ bordercolor=bc,
+ x=x,
+ y=y,
+ s=s,
+ pointcolor=pointcolors,
+ open_circles=open_circles,
+ title=print_title,
+ **kwargs,
+ )
+ else:
+ _, _ = show_points(
+ ar,
+ figax=(fig, ax),
+ returnfig=True,
+ bordercolor=bc,
+ x=x,
+ y=y,
+ pointcolor=pointcolors,
+ open_circles=open_circles,
+ title=print_title,
+ **kwargs,
+ )
+ elif _get_bordercolor:
+ bc = get_bordercolor(N)
+ _, _ = show(
+ ar,
+ figax=(fig, ax),
+ returnfig=True,
+ bordercolor=bc,
+ title=print_title,
+ **kwargs,
+ )
+ elif _get_points:
+ x, y = get_x(N), get_y(N)
+ if _get_colors:
+ pointcolors = get_pointcolors(N)
+ else:
+ pointcolors = "r"
+ _, _ = show_points(
+ ar,
+ figax=(fig, ax),
+ x=x,
+ y=y,
+ returnfig=True,
+ pointcolor=pointcolors,
+ open_circles=open_circles,
+ title=print_title,
+ **kwargs,
+ )
+ else:
+ _, _ = show(
+ ar, figax=(fig, ax), returnfig=True, title=print_title, **kwargs
+ )
+ except IndexError:
+ ax.axis("off")
+ if isinstance(title, str):
+ fig.suptitle(title)
+ if suptitle:
+ fig.suptitle(suptitle)
+ plt.tight_layout()
+
+ if not returnfig:
+ return
+ else:
+ return fig, axs
diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py
new file mode 100644
index 000000000..c1e9d6b19
--- /dev/null
+++ b/py4DSTEM/visualize/vis_special.py
@@ -0,0 +1,897 @@
+from matplotlib import cm, colors as mcolors, pyplot as plt
+import numpy as np
+from matplotlib.patches import Wedge
+from mpl_toolkits.axes_grid1 import make_axes_locatable
+from scipy.spatial import Voronoi
+
+from emdfile import PointList
+from py4DSTEM.visualize import show
+from py4DSTEM.visualize.overlay import (
+ add_pointlabels,
+ add_vector,
+ add_bragg_index_labels,
+ add_ellipses,
+ add_points,
+ add_scalebar,
+)
+from py4DSTEM.visualize.vis_grid import show_image_grid
+from py4DSTEM.visualize.vis_RQ import ax_addaxes, ax_addaxes_QtoR
+from colorspacious import cspace_convert
+
+
+def show_elliptical_fit(
+ ar,
+ fitradii,
+ p_ellipse,
+ fill=True,
+ color_ann="y",
+ color_ell="r",
+ alpha_ann=0.2,
+ alpha_ell=0.7,
+ linewidth_ann=2,
+ linewidth_ell=2,
+ returnfig=False,
+ **kwargs,
+):
+ """
+ Plots an elliptical curve over its annular fit region.
+
+ Args:
+ ar (2d array): the data
+ fitradii (2-tuple of numbers): the annulus inner and outer fit radii
+ p_ellipse (5-tuple): the parameters of the fit ellipse, (qx0,qy0,a,b,theta).
+ See the module docstring for utils.elliptical_coords for more details.
+ fill (bool): if True, fills in the annular fitting region,
+ else shows only inner/outer edges
+ color_ann (color): annulus color
+ color_ell (color): ellipse color
+ alpha_ann: transparency for the annulus
+ alpha_ell: transparency for the fit ellipse
+ linewidth_ann:
+ linewidth_ell:
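+
+ Example (an illustrative sketch; `dp_mean` stands for a mean diffraction pattern):
+
+ >>> show_elliptical_fit(dp_mean, fitradii=(30, 60), p_ellipse=(64, 64, 22.0, 20.5, 0.1))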
+ """
+ Ri, Ro = fitradii
+ qx0, qy0, a, b, theta = p_ellipse
+ fig, ax = show(
+ ar,
+ annulus={
+ "center": (qx0, qy0),
+ "radii": (Ri, Ro),
+ "fill": fill,
+ "color": color_ann,
+ "alpha": alpha_ann,
+ "linewidth": linewidth_ann,
+ },
+ ellipse={
+ "center": (qx0, qy0),
+ "a": a,
+ "b": b,
+ "theta": theta,
+ "color": color_ell,
+ "alpha": alpha_ell,
+ "linewidth": linewidth_ell,
+ },
+ returnfig=True,
+ **kwargs,
+ )
+
+ if not returnfig:
+ plt.show()
+ return
+ else:
+ return fig, ax
+
+
+def show_amorphous_ring_fit(
+ dp,
+ fitradii,
+ p_dsg,
+ N=12,
+ cmap=("gray", "gray"),
+ fitborder=True,
+ fitbordercolor="k",
+ fitborderlw=0.5,
+ scaling="log",
+ ellipse=False,
+ ellipse_color="r",
+ ellipse_alpha=0.7,
+ ellipse_lw=2,
+ returnfig=False,
+ **kwargs,
+):
+ """
+ Display a diffraction pattern with a fit to its amorphous ring, interleaving
+ the data and the fit in a pinwheel pattern.
+
+ Args:
+ dp (array): the diffraction pattern
+ fitradii (2-tuple of numbers): the min/max distances of the fitting annulus
+ p_dsg (11-tuple): the fit parameters to the double-sided gaussian
+ function returned by fit_ellipse_amorphous_ring
+ N (int): the number of pinwheel sections
+ cmap (colormap or 2-tuple of colormaps): if passed a single cmap, uses this
+ colormap for both the data and the fit; if passed a 2-tuple of cmaps, uses
+ the first for the data and the second for the fit
+ fitborder (bool): if True, plots a border line around the fit data
+ fitbordercolor (color): color of the fitborder
+ fitborderlw (number): linewidth of the fitborder
+ scaling (str): the normal scaling param -- see docstring for visualize.show
+ ellipse (bool): if True, overlay an ellipse
+ returnfig (bool): if True, returns the figure
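+
+ Example (an illustrative sketch; `p_fit` stands for the 11-tuple returned by
+ fit_ellipse_amorphous_ring):
+
+ >>> show_amorphous_ring_fit(dp, fitradii=(30, 80), p_dsg=p_fit, ellipse=True)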
+ """
+ from py4DSTEM.process.calibration import double_sided_gaussian
+ from py4DSTEM.process.utils import convert_ellipse_params
+
+ assert len(p_dsg) == 11
+ assert isinstance(N, (int, np.integer))
+ if isinstance(cmap, tuple):
+ cmap_data, cmap_fit = cmap[0], cmap[1]
+ else:
+ cmap_data, cmap_fit = cmap, cmap
+ Q_Nx, Q_Ny = dp.shape
+ qmin, qmax = fitradii
+
+ # Make coords
+ qx0, qy0 = p_dsg[6], p_dsg[7]
+ qyy, qxx = np.meshgrid(np.arange(Q_Ny), np.arange(Q_Nx))
+ qx, qy = qxx - qx0, qyy - qy0
+ q = np.hypot(qx, qy)
+ theta = np.arctan2(qy, qx)
+
+ # Make mask
+ thetas = np.linspace(-np.pi, np.pi, 2 * N + 1)
+ pinwheel = np.zeros((Q_Nx, Q_Ny), dtype=bool)
+ for i in range(N):
+ pinwheel += (theta > thetas[2 * i]) * (theta <= thetas[2 * i + 1])
+ mask = pinwheel * (q > qmin) * (q <= qmax)
+
+ # Get fit data
+ fit = double_sided_gaussian(p_dsg, qxx, qyy)
+
+ # Show
+ (fig, ax), (vmin, vmax) = show(
+ dp,
+ scaling=scaling,
+ cmap=cmap_data,
+ mask=np.logical_not(mask),
+ mask_color="empty",
+ returnfig=True,
+ return_intensity_range=True,
+ **kwargs,
+ )
+ show(
+ fit,
+ scaling=scaling,
+ figax=(fig, ax),
+ intensity_range="absolute",
+ vmin=vmin,
+ vmax=vmax,
+ cmap=cmap_fit,
+ mask=mask,
+ mask_color="empty",
+ **kwargs,
+ )
+ if fitborder:
+ if N % 2 == 1:
+ thetas += (thetas[1] - thetas[0]) / 2
+ if (N // 2 % 2) == 0:
+ thetas = np.roll(thetas, -1)
+ for i in range(N):
+ ax.add_patch(
+ Wedge(
+ (qy0, qx0),
+ qmax,
+ np.degrees(thetas[2 * i]),
+ np.degrees(thetas[2 * i + 1]),
+ width=qmax - qmin,
+ fill=None,
+ color=fitbordercolor,
+ lw=fitborderlw,
+ )
+ )
+
+ # Add ellipse overlay
+ if ellipse:
+ A, B, C = p_dsg[8], p_dsg[9], p_dsg[10]
+ a, b, theta = convert_ellipse_params(A, B, C)
+ ellipse = {
+ "center": (qx0, qy0),
+ "a": a,
+ "b": b,
+ "theta": theta,
+ "color": ellipse_color,
+ "alpha": ellipse_alpha,
+ "linewidth": ellipse_lw,
+ }
+ add_ellipses(ax, ellipse)
+
+ if not returnfig:
+ plt.show()
+ return
+ else:
+ return fig, ax
+
+
+def show_qprofile(
+ q,
+ intensity,
+ ymax=None,
+ figsize=(12, 4),
+ returnfig=False,
+ color="k",
+ xlabel="q (pixels)",
+ ylabel="Intensity (A.U.)",
+ labelsize=16,
+ ticklabelsize=14,
+ grid=True,
+ label=None,
+ **kwargs,
+):
+ """
+ Plots a diffraction space radial profile.
+ Params:
+ q (1D array) the diffraction coordinate / x-axis
+ intensity (1D array) the y-axis values
+ ymax (number) max value for the yaxis
+ color (matplotlib color) profile color
+ xlabel (str)
+ ylabel
+ labelsize size of x and y labels
+ ticklabelsize
+ grid True or False
+ label a legend label for the plotted curve
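+
+ Example (an illustrative sketch; `q`/`radial_intensity` stand for 1D profile arrays):
+
+ >>> show_qprofile(q, radial_intensity, color="b", label="mean radial profile")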
+ """
+ if ymax is None:
+ ymax = np.max(intensity) * 1.05
+
+ fig, ax = plt.subplots(figsize=figsize)
+ ax.plot(q, intensity, color=color, label=label)
+ ax.grid(grid)
+ ax.set_ylim(0, ymax)
+ ax.tick_params(axis="x", labelsize=ticklabelsize)
+ ax.set_yticklabels([])
+ ax.set_xlabel(xlabel, size=labelsize)
+ ax.set_ylabel(ylabel, size=labelsize)
+ if not returnfig:
+ plt.show()
+ return
+ else:
+ return fig, ax
+
+
+def show_kernel(kernel, R, L, W, figsize=(12, 6), returnfig=False, **kwargs):
+ """
+ Plots, side by side, the probe kernel and its line profile.
+ R is the kernel plot's window size.
+ L and W are the length and width of the lineprofile.
+ """
+ lineprofile_1 = np.concatenate(
+ [np.sum(kernel[-L:, :W], axis=1), np.sum(kernel[:L, :W], axis=1)]
+ )
+ lineprofile_2 = np.concatenate(
+ [np.sum(kernel[:W, -L:], axis=0), np.sum(kernel[:W, :L], axis=0)]
+ )
+
+ im_kernel = np.vstack(
+ [
+ np.hstack([kernel[-int(R) :, -int(R) :], kernel[-int(R) :, : int(R)]]),
+ np.hstack([kernel[: int(R), -int(R) :], kernel[: int(R), : int(R)]]),
+ ]
+ )
+
+ fig, axs = plt.subplots(1, 2, figsize=figsize)
+ axs[0].matshow(im_kernel, cmap="gray")
+ axs[0].plot(np.ones(2 * R) * R, np.arange(2 * R), c="r")
+ axs[0].plot(np.arange(2 * R), np.ones(2 * R) * R, c="c")
+
+ axs[1].plot(np.arange(len(lineprofile_1)), lineprofile_1, c="r")
+ axs[1].plot(np.arange(len(lineprofile_2)), lineprofile_2, c="c")
+
+ if not returnfig:
+ plt.show()
+ return
+ else:
+ return fig, axs
+
+
+def show_voronoi(
+ ar,
+ x,
+ y,
+ color_points="r",
+ color_lines="w",
+ max_dist=None,
+ returnfig=False,
+ **kwargs,
+):
+ """
+ Shows the image ar overlaid with the Voronoi tessellation generated by the points (x,y).
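+
+ Example (an illustrative sketch; `im`, `xs`, `ys` stand for an image and point coordinates):
+
+ >>> show_voronoi(im, xs, ys, max_dist=20)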
+ """
+ from py4DSTEM.process.utils import get_voronoi_vertices
+
+ Nx, Ny = ar.shape
+ points = np.vstack((x, y)).T
+ voronoi = Voronoi(points)
+ vertices = get_voronoi_vertices(voronoi, Nx, Ny)
+
+ if max_dist is None:
+ fig, ax = show(ar, returnfig=True, **kwargs)
+ else:
+ centers = [(x[i], y[i]) for i in range(len(x))]
+ fig, ax = show(
+ ar,
+ returnfig=True,
+ **kwargs,
+ circle={
+ "center": centers,
+ "R": max_dist,
+ "fill": False,
+ "color": color_points,
+ },
+ )
+
+ ax.scatter(voronoi.points[:, 1], voronoi.points[:, 0], color=color_points)
+ for region in range(len(vertices)):
+ vertices_curr = vertices[region]
+ for i in range(len(vertices_curr)):
+ x0, y0 = vertices_curr[i, :]
+ xf, yf = vertices_curr[(i + 1) % len(vertices_curr), :]
+ ax.plot((y0, yf), (x0, xf), color=color_lines)
+ ax.set_xlim([0, Ny])
+ ax.set_ylim([0, Nx])
+ ax.invert_yaxis()
+ if not returnfig:
+ plt.show()
+ return
+ else:
+ return fig, ax
+
+
+def show_class_BPs(ar, x, y, s, s2, color="r", color2="y", **kwargs):
+ """
+ Shows the image ar with the points (x,y) overlaid twice, at sizes s2 and s.
+ """
+ N = len(x)
+ assert N == len(y) == len(s)
+
+ fig, ax = show(ar, returnfig=True, **kwargs)
+ ax.scatter(y, x, s=s2, color=color2)
+ ax.scatter(y, x, s=s, color=color)
+ plt.show()
+ return
+
+
+def show_class_BPs_grid(
+ ar,
+ H,
+ W,
+ x,
+ y,
+ get_s,
+ s2,
+ color="r",
+ color2="y",
+ returnfig=False,
+ axsize=(6, 6),
+ titlesize=0,
+ get_bordercolor=None,
+ **kwargs,
+):
+ """
+ Shows a grid of H x W copies of the image ar, with the points (x,y) overlaid at per-panel sizes get_s(i).
+ """
+ fig, axs = show_image_grid(
+ lambda i: ar,
+ H,
+ W,
+ axsize=axsize,
+ titlesize=titlesize,
+ get_bordercolor=get_bordercolor,
+ returnfig=True,
+ **kwargs,
+ )
+ for i in range(H):
+ for j in range(W):
+ ax = axs[i, j]
+ N = i * W + j
+ s = get_s(N)
+ ax.scatter(y, x, s=s2, color=color2)
+ ax.scatter(y, x, s=s, color=color)
+ if not returnfig:
+ plt.show()
+ return
+ else:
+ return fig, axs
+
+
+def show_pointlabels(
+ ar, x, y, color="lightblue", size=20, alpha=1, returnfig=False, **kwargs
+):
+ """
+ Show enumerated index labels for a set of points
+ """
+ fig, ax = show(ar, returnfig=True, **kwargs)
+ d = {"x": x, "y": y, "size": size, "color": color, "alpha": alpha}
+ add_pointlabels(ax, d)
+
+ if returnfig:
+ return fig, ax
+ else:
+ plt.show()
+ return
+
+
+def select_point(
+ ar,
+ x,
+ y,
+ i,
+ color="lightblue",
+ color_selected="r",
+ size=20,
+ returnfig=False,
+ **kwargs,
+):
+ """
+ Show enumerated index labels for a set of points, with one selected point highlighted
+ """
+ fig, ax = show(ar, returnfig=True, **kwargs)
+ d1 = {"x": x, "y": y, "size": size, "color": color}
+ d2 = {
+ "x": x[i],
+ "y": y[i],
+ "size": size,
+ "color": color_selected,
+ "fontweight": "bold",
+ }
+ add_pointlabels(ax, d1)
+ add_pointlabels(ax, d2)
+
+ if returnfig:
+ return fig, ax
+ else:
+ plt.show()
+ return
+
+
+def show_max_peak_spacing(
+ ar, spacing, braggdirections, color="g", lw=2, returnfig=False, **kwargs
+):
+ """Show a circle of radius `spacing` about each Bragg direction"""
+ centers = [
+ (braggdirections.data["qx"][i], braggdirections.data["qy"][i])
+ for i in range(braggdirections.length)
+ ]
+ fig, ax = show(
+ ar,
+ circle={
+ "center": centers,
+ "R": spacing,
+ "color": color,
+ "fill": False,
+ "lw": lw,
+ },
+ returnfig=True,
+ **kwargs,
+ )
+ if returnfig:
+ return fig, ax
+ else:
+ plt.show()
+ return
+
+
+def show_origin_meas(data):
+ """
+ Show the measured positions of the origin.
+
+ Args:
+ data (DataCube or Calibration or 2-tuple of arrays (qx0,qy0))
+ """
+ from py4DSTEM.data import Calibration
+ from py4DSTEM.datacube import DataCube
+
+ if isinstance(data, tuple):
+ assert len(data) == 2
+ qx, qy = data
+ elif isinstance(data, DataCube):
+ qx, qy = data.calibration.get_origin_meas()
+ elif isinstance(data, Calibration):
+ qx, qy = data.get_origin_meas()
+ else:
+ raise Exception("data must be of type Datacube or Calibration or tuple")
+
+ show_image_grid(get_ar=lambda i: [qx, qy][i], H=1, W=2, cmap="RdBu")
+
+
+def show_origin_fit(data):
+ """
+ Show the measured, fit, and residuals of the origin positions.
+
+ Args:
+ data (DataCube or Calibration or (3,2)-tuple of arrays
+ ((qx0_meas,qy0_meas),(qx0_fit,qy0_fit),(qx0_residuals,qy0_residuals))
+ """
+ from py4DSTEM.data import Calibration
+ from py4DSTEM.datacube import DataCube
+
+ if isinstance(data, tuple):
+ assert len(data) == 3
+ qx0_meas, qy0_meas = data[0]
+ qx0_fit, qy0_fit = data[1]
+ qx0_residuals, qy0_residuals = data[2]
+ elif isinstance(data, DataCube):
+ qx0_meas, qy0_meas = data.calibration.get_origin_meas()
+ qx0_fit, qy0_fit = data.calibration.get_origin()
+ qx0_residuals, qy0_residuals = data.calibration.get_origin_residuals()
+ elif isinstance(data, Calibration):
+ qx0_meas, qy0_meas = data.get_origin_meas()
+ qx0_fit, qy0_fit = data.get_origin()
+ qx0_residuals, qy0_residuals = data.get_origin_residuals()
+ else:
+ raise Exception("data must be of type DataCube or Calibration or tuple")
+
+ show_image_grid(
+ get_ar=lambda i: [
+ qx0_meas,
+ qx0_fit,
+ qx0_residuals,
+ qy0_meas,
+ qy0_fit,
+ qy0_residuals,
+ ][i],
+ H=2,
+ W=3,
+ cmap="RdBu",
+ )
+
+
+def show_selected_dps(
+ datacube,
+ positions,
+ im,
+ bragg_pos=None,
+ colors=None,
+ HW=None,
+ figsize_im=(6, 6),
+ figsize_dp=(4, 4),
+ **kwargs,
+):
+ """
+ Shows two plots: first, a real space image overlaid with colored dots
+ at the specified positions; second, a grid of diffraction patterns
+ corresponding to these scan positions.
+
+ Args:
+ datacube (DataCube):
+ positions (len N list or tuple of 2-tuples): the scan positions
+ im (2d array): a real space image
+ bragg_pos (len N list of pointlistarrays): Bragg disk positions
+ for each scan position. If passed, overlays the disk positions
+ and suppresses the plot of the real space image
+ colors (len N list of colors or None):
+ HW (2-tuple of ints): diffraction pattern grid shape
+ figsize_im (2-tuple): size of the image figure
+ figsize_dp (2-tuple): size of each diffraction pattern panel
+ **kwargs (dict): arguments passed to visualize.show for the
+ *diffraction patterns*. Default is `scaling='log'`
+ """
+ from py4DSTEM.datacube import DataCube
+
+ assert isinstance(datacube, DataCube)
+ N = len(positions)
+ assert all(
+ [len(x) == 2 for x in positions]
+ ), "Improperly formated argument `positions`"
+ if bragg_pos is not None:
+ show_disk_pos = True
+ assert len(bragg_pos) == N
+ else:
+ show_disk_pos = False
+ if colors is None:
+ from matplotlib.cm import gist_ncar
+
+ linsp = np.linspace(0, 1, N, endpoint=False)
+ colors = [gist_ncar(i) for i in linsp]
+ assert len(colors) == N, "Number of positions and colors don't match"
+ from matplotlib.colors import is_color_like
+
+ assert all(is_color_like(i) for i in colors)
+ if HW is None:
+ W = int(np.ceil(np.sqrt(N)))
+ if W < 3:
+ W = 3
+ H = int(np.ceil(N / W))
+ else:
+ H, W = HW
+ assert all([isinstance(x, (int, np.integer)) for x in (H, W)])
+
+ x = [i[0] for i in positions]
+ y = [i[1] for i in positions]
+ if "scaling" not in kwargs.keys():
+ kwargs["scaling"] = "log"
+ if not show_disk_pos:
+ fig, ax = show(im, figsize=figsize_im, returnfig=True)
+ add_points(ax, d={"x": x, "y": y, "pointcolor": colors})
+ show_image_grid(
+ get_ar=lambda i: datacube.data[x[i], y[i], :, :],
+ H=H,
+ W=W,
+ get_bordercolor=lambda i: colors[i],
+ axsize=figsize_dp,
+ **kwargs,
+ )
+ else:
+ show_image_grid(
+ get_ar=lambda i: datacube.data[x[i], y[i], :, :],
+ H=H,
+ W=W,
+ get_bordercolor=lambda i: colors[i],
+ axsize=figsize_dp,
+ get_x=lambda i: bragg_pos[i].data["qx"],
+ get_y=lambda i: bragg_pos[i].data["qy"],
+ get_pointcolors=lambda i: colors[i],
+ **kwargs,
+ )
+
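+# Usage sketch (hypothetical values, not from the py4DSTEM docs): pick a few
+# scan positions and show their diffraction patterns.
+# >>> positions = [(5, 5), (10, 20), (30, 8)]
+# >>> im = datacube.data.sum(axis=(2, 3))  # a quick real space image
+# >>> show_selected_dps(datacube, positions, im)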
+
+def Complex2RGB(complex_data, vmin=None, vmax=None, power=None, chroma_boost=1):
+ """
+ complex_data (array): complex array to plot
+ vmin (float) : minimum absolute value
+ vmax (float) : maximum absolute value
+ power (float) : power to raise amplitude to
+ chroma_boost (float): boosts chroma for higher-contrast (~1-2.5)
+ """
+ amp = np.abs(complex_data)
+ phase = np.angle(complex_data)
+
+ if power is not None:
+ amp = amp**power
+
+ if np.isclose(np.max(amp), np.min(amp)):
+ if vmin is None:
+ vmin = 0
+ if vmax is None:
+ vmax = np.max(amp)
+ else:
+ if vmin is None:
+ vmin = 0.02
+ if vmax is None:
+ vmax = 0.98
+ vals = np.sort(amp[~np.isnan(amp)])
+ ind_vmin = np.round((vals.shape[0] - 1) * vmin).astype("int")
+ ind_vmax = np.round((vals.shape[0] - 1) * vmax).astype("int")
+ ind_vmin = np.max([0, ind_vmin])
+ ind_vmax = np.min([len(vals) - 1, ind_vmax])
+ vmin = vals[ind_vmin]
+ vmax = vals[ind_vmax]
+
+ amp = np.where(amp < vmin, vmin, amp)
+ amp = np.where(amp > vmax, vmax, amp)
+ amp = ((amp - vmin) / vmax).clip(1e-16, 1)
+
+ J = amp * 61.5 # Note we restrict luminance to the monotonic chroma cutoff
+ C = np.minimum(chroma_boost * 98 * J / 123, 110)
+ h = np.rad2deg(phase) + 180
+
+ JCh = np.stack((J, C, h), axis=-1)
+ rgb = cspace_convert(JCh, "JCh", "sRGB1").clip(0, 1)
+
+ return rgb
+
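+# Usage sketch: map a complex array to RGB for display. The array below is
+# illustrative only; np and plt are the module-level imports used above.
+# >>> z = np.random.randn(32, 32) + 1j * np.random.randn(32, 32)
+# >>> rgb = Complex2RGB(z, chroma_boost=1.5)
+# >>> plt.imshow(rgb); plt.axis("off"); plt.show()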
+
+def add_colorbar_arg(cax, chroma_boost=1, c=49, j=61.5):
+ """
+ cax : axis to add cbar to
+ chroma_boost (float): boosts chroma for higher-contrast (~1-2.25)
+ c (float) : constant chroma value
+ j (float) : constant luminance value
+ """
+
+ h = np.linspace(0, 360, 256, endpoint=False)
+ J = np.full_like(h, j)
+ C = np.full_like(h, np.minimum(c * chroma_boost, 110))
+ JCh = np.stack((J, C, h), axis=-1)
+ rgb_vals = cspace_convert(JCh, "JCh", "sRGB1").clip(0, 1)
+ newcmp = mcolors.ListedColormap(rgb_vals)
+ norm = mcolors.Normalize(vmin=-np.pi, vmax=np.pi)
+
+ cb = plt.colorbar(cm.ScalarMappable(norm=norm, cmap=newcmp), cax=cax)
+
+ cb.set_label("arg", rotation=0, ha="center", va="bottom")
+ cb.ax.yaxis.set_label_coords(0.5, 1.01)
+ cb.set_ticks(np.array([-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi]))
+ cb.set_ticklabels(
+ [r"$-\pi$", r"$-\dfrac{\pi}{2}$", "$0$", r"$\dfrac{\pi}{2}$", r"$\pi$"]
+ )
+
+
+def show_complex(
+ ar_complex,
+ vmin=None,
+ vmax=None,
+ power=None,
+ chroma_boost=1,
+ cbar=True,
+ scalebar=False,
+ pixelunits="pixels",
+ pixelsize=1,
+ returnfig=False,
+ **kwargs,
+):
+ """
+ Function to plot complex arrays
+
+ Args:
+ ar_complex (2D array) : complex array to be plotted. If ar_complex is a list of complex arrays,
+ e.g. [array1, array2], the arrays are plotted side by side in one figure
+ vmin (float, optional) : minimum absolute value
+ vmax (float, optional) : maximum absolute value
+ if None, vmin/vmax are set to fractions of the distribution of pixel values in the array,
+ e.g. vmin=0.02 will set the minimum display value so that the lowest 2% of pixels saturate
+ power (float,optional) : power to raise amplitude to
+ chroma_boost (float) : boosts chroma for higher-contrast (~1-2.25)
+ cbar (bool, optional) : if True, include color bar
+ scalebar (bool, optional) : if True, adds scale bar
+ pixelunits (str, optional) : units for scalebar
+ pixelsize (float, optional) : size of one pixel in pixelunits for scalebar
+ returnfig (bool, optional) : if True, the function returns the tuple (figure,axis)
+
+ Returns:
+ if returnfig==False (default), the figure is plotted and nothing is returned.
+ if returnfig==True, return the figure and the axis.
+ """
+ # convert to complex colors
+ ar_complex = (
+ ar_complex[0]
+ if (isinstance(ar_complex, list) and len(ar_complex) == 1)
+ else ar_complex
+ )
+ if isinstance(ar_complex, list):
+ if isinstance(ar_complex[0], list):
+ rgb = [
+ Complex2RGB(ar, vmin, vmax, power=power, chroma_boost=chroma_boost)
+ for sublist in ar_complex
+ for ar in sublist
+ ]
+ H = len(ar_complex)
+ W = len(ar_complex[0])
+
+ else:
+ rgb = [
+ Complex2RGB(ar, vmin, vmax, power=power, chroma_boost=chroma_boost)
+ for ar in ar_complex
+ ]
+ if len(rgb[0].shape) == 4:
+ H = len(ar_complex)
+ W = rgb[0].shape[0]
+ else:
+ H = 1
+ W = len(ar_complex)
+ is_grid = True
+ else:
+ rgb = Complex2RGB(
+ ar_complex, vmin, vmax, power=power, chroma_boost=chroma_boost
+ )
+ if len(rgb.shape) == 4:
+ is_grid = True
+ H = 1
+ W = rgb.shape[0]
+ elif len(rgb.shape) == 5:
+ is_grid = True
+ H = rgb.shape[0]
+ W = rgb.shape[1]
+ rgb = rgb.reshape((-1,) + rgb.shape[-3:])
+ else:
+ is_grid = False
+ # plot
+ if is_grid:
+ from py4DSTEM.visualize import show_image_grid
+
+ fig, ax = show_image_grid(
+ get_ar=lambda i: rgb[i],
+ H=H,
+ W=W,
+ vmin=0,
+ vmax=1,
+ intensity_range="absolute",
+ returnfig=True,
+ **kwargs,
+ )
+ if scalebar is True:
+ scalebar = {
+ "Nx": ar_complex[0].shape[0],
+ "Ny": ar_complex[0].shape[1],
+ "pixelsize": pixelsize,
+ "pixelunits": pixelunits,
+ }
+
+ add_scalebar(ax[0, 0], scalebar)
+ else:
+ fig, ax = show(
+ rgb, vmin=0, vmax=1, intensity_range="absolute", returnfig=True, **kwargs
+ )
+
+ if scalebar is True:
+ scalebar = {
+ "Nx": ar_complex.shape[0],
+ "Ny": ar_complex.shape[1],
+ "pixelsize": pixelsize,
+ "pixelunits": pixelunits,
+ }
+
+ add_scalebar(ax, scalebar)
+
+ # add color bar
+ if cbar:
+ if is_grid:
+ for ax_flat in ax.flatten():
+ divider = make_axes_locatable(ax_flat)
+ ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
+ add_colorbar_arg(ax_cb, chroma_boost=chroma_boost)
+ else:
+ divider = make_axes_locatable(ax)
+ ax_cb = divider.append_axes("right", size="5%", pad="2.5%")
+ add_colorbar_arg(ax_cb, chroma_boost=chroma_boost)
+
+ fig.tight_layout()
+
+ if returnfig:
+ return fig, ax
+
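+# Usage sketch: visualize a synthetic complex field with a phase colorbar
+# and a scalebar (the values below are illustrative, not library defaults).
+# >>> qx = np.arange(64)[:, None]
+# >>> z = np.exp(1j * 0.3 * qx) * np.hanning(64)[None, :]
+# >>> show_complex(z, cbar=True, scalebar=True, pixelunits="nm", pixelsize=0.5)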
+
+def return_scaled_histogram_ordering(array, vmin=None, vmax=None, normalize=False):
+ """
+ Utility function for calculating min and max values for plotting array
+ based on distribution of pixel values
+
+ Parameters
+ ----------
+ array: np.array
+ array to be plotted
+ vmin: float
+ lower fraction cut off of pixel values
+ vmax: float
+ upper fraction cut off of pixel values
+ normalize: bool
+ if True, rescales from 0 to 1
+
+ Returns
+ ----------
+ scaled_array: np.array
+ array clipped outside vmin and vmax
+ vmin: float
+ lower value to be plotted
+ vmax: float
+ upper value to be plotted
+ """
+
+ if vmin is None:
+ vmin = 0.02
+ if vmax is None:
+ vmax = 0.98
+
+ vals = np.sort(array.ravel())
+ ind_vmin = np.round((vals.shape[0] - 1) * vmin).astype("int")
+ ind_vmax = np.round((vals.shape[0] - 1) * vmax).astype("int")
+ ind_vmin = np.max([0, ind_vmin])
+ ind_vmax = np.min([len(vals) - 1, ind_vmax])
+ vmin = vals[ind_vmin]
+ vmax = vals[ind_vmax]
+
+ if vmax == vmin:
+ vmin = vals[0]
+ vmax = vals[-1]
+
+ scaled_array = array.copy()
+ scaled_array = np.where(scaled_array < vmin, vmin, scaled_array)
+ scaled_array = np.where(scaled_array > vmax, vmax, scaled_array)
+
+ if normalize:
+ scaled_array -= scaled_array.min()
+ scaled_array /= scaled_array.max()
+ vmin = 0
+ vmax = 1
+
+ return scaled_array, vmin, vmax
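+
+
+# Usage sketch: clip an image to its 2nd-98th percentile range for display.
+# >>> img = np.random.rand(128, 128) ** 3
+# >>> scaled, vmin, vmax = return_scaled_histogram_ordering(img)
+# >>> plt.imshow(scaled, vmin=vmin, vmax=vmax); plt.show()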
diff --git a/real_space_control_widget.ui b/real_space_control_widget.ui
deleted file mode 100644
index 090245fc3..000000000
--- a/real_space_control_widget.ui
+++ /dev/null
@@ -1,312 +0,0 @@
-
- [Qt Designer .ui XML, 312 lines; markup lost in extraction. The file
- defined "diffraction_space_widget", a 671x561 window titled "Real Space",
- with a "Binning and Cropping" group (a "Bin by" spinbox and "Bin Data",
- "Set Crop Window", and "Crop Data" buttons), a "Processing Thing 1" group
- (a horizontal slider and an "Execute Thing" button), and two "Processing
- Thing 2" groups (a "Method" combo box with Methods 1-4, "Parameter
- Values" fields, and an "Execute" button).]
-
diff --git a/setup.py b/setup.py
new file mode 100644
index 000000000..631f23f9a
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,66 @@
+from setuptools import setup, find_packages
+from distutils.util import convert_path
+
+with open("README.md", "r") as f:
+ long_description = f.read()
+
+version_ns = {}
+vpath = convert_path("py4DSTEM/version.py")
+with open(vpath) as version_file:
+ exec(version_file.read(), version_ns)
+
+setup(
+ name="py4DSTEM",
+ version=version_ns["__version__"],
+ packages=find_packages(),
+ description="An open source python package for processing and analysis of 4D STEM data.",
+ long_description=long_description,
+ long_description_content_type="text/markdown",
+ url="https://github.com/py4dstem/py4DSTEM/",
+ author="Benjamin H. Savitzky",
+ author_email="ben.savitzky@gmail.com",
+ license="GNU GPLv3",
+ keywords="STEM 4DSTEM",
+ python_requires=">=3.9,<=3.12",
+ install_requires=[
+ "numpy >= 1.19",
+ "scipy >= 1.5.2",
+ "h5py >= 3.2.0",
+ "hdf5plugin >= 4.1.3",
+ "ncempy >= 1.8.1",
+ "matplotlib >= 3.2.2",
+ "scikit-image >= 0.17.2",
+ "scikit-learn >= 0.23.2",
+ "scikit-optimize >= 0.9.0",
+ "tqdm >= 4.46.1",
+ "dill >= 0.3.3",
+ "gdown >= 4.7.1",
+ "dask >= 2.3.0",
+ "distributed >= 2.3.0",
+ "emdfile >= 0.0.14",
+ "mpire >= 2.7.1",
+ "threadpoolctl >= 3.1.0",
+ "pylops >= 2.1.0",
+ "colorspacious >= 1.1.2",
+ ],
+ extras_require={
+ "ipyparallel": ["ipyparallel >= 6.2.4", "dill >= 0.3.3"],
+ "cuda": ["cupy >= 10.0.0"],
+ "acom": ["pymatgen >= 2022", "mp-api == 0.24.1"],
+ "aiml": ["tensorflow == 2.4.1", "tensorflow-addons <= 0.14.0", "crystal4D"],
+ "aiml-cuda": [
+ "tensorflow == 2.4.1",
+ "tensorflow-addons <= 0.14.0",
+ "crystal4D",
+ "cupy >= 10.0.0",
+ ],
+ "numba": ["numba >= 0.49.1"],
+ },
+ package_data={
+ "py4DSTEM": [
+ "process/utils/scattering_factors.txt",
+ "braggvectors/multicorr_row_kernel.cu",
+ "braggvectors/multicorr_col_kernel.cu",
+ ]
+ },
+)
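+
+# Install examples (a sketch; the extras names come from extras_require above):
+#   pip install .            # base install
+#   pip install ".[cuda]"    # adds cupy for GPU acceleration
+#   pip install ".[numba]"   # adds numba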
diff --git a/test/README.md b/test/README.md
new file mode 100644
index 000000000..711d8379d
--- /dev/null
+++ b/test/README.md
@@ -0,0 +1,51 @@
+# `py4DSTEM.test` submodule
+
+Testing py4DSTEM with pytest.
+
+
+
+## Setup
+
+Install the latest pytest with
+
+`pip install -U pytest`
+
+
+Some tests need data files to run.
+In an environment with py4DSTEM installed,
+do `python download_test_data.py` from this directory.
+The script will make a new `unit_test_data` directory and
+download the requisite files there.
+
+
+
+## Running tests
+
+To run all tests, you can then do `pytest` from
+the command line - pytest will collect and run all the tests
+in this directory and its subdirectories. You can also run a
+single test file or all files in a single test subdirectory with
+
+`pytest test_file.py`
+`pytest test_dir`
+
+
+
+
+
+## Adding new tests
+
+When pytest is run it will find files in this directory and its
+subdirectories with the format `test_*.py` or `*_test.py`.
+Name your new file `test_*.py` for some short, descriptive `*`
+specifying the nature of your tests.
+
+Inside the file, any function called `test_*` will be found and run
+by pytest, and in classes named `Test*` any methods called `test_*` will
+also be found and run.
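+
+For example, a minimal test file (hypothetical name and contents) that
+pytest would collect and run might look like:
+
+```python
+# test_example.py
+import numpy as np
+
+
+def test_array_sum():
+    assert np.arange(4).sum() == 6
+
+
+class TestArrays:
+    def test_shape(self):
+        assert np.zeros((2, 3)).shape == (2, 3)
+```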
+
+
+
+
+
+
diff --git a/test/download_test_data.py b/test/download_test_data.py
new file mode 100644
index 000000000..913dc73a6
--- /dev/null
+++ b/test/download_test_data.py
@@ -0,0 +1,15 @@
+# When run as a Python script, this file
+# makes a folder called 'unit_test_data' if one
+# doesn't already exist, and downloads
+# py4DSTEM's test data there.
+
+
+from py4DSTEM import _TESTPATH
+
+filepath = _TESTPATH
+
+
+if __name__ == "__main__":
+ from py4DSTEM.io import download_file_from_google_drive as download
+
+ download(id_="unit_test_data", destination=filepath, overwrite=True)
diff --git a/test/gettestdata.py b/test/gettestdata.py
new file mode 100644
index 000000000..f1012f03c
--- /dev/null
+++ b/test/gettestdata.py
@@ -0,0 +1,74 @@
+# A command line tool for downloading data to run the py4DSTEM test suite
+
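+# Example invocations, run from this directory:
+#   python gettestdata.py basic
+#   python gettestdata.py io --overwrite --verbose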
+
+import argparse
+from os.path import exists
+from os import makedirs
+
+from py4DSTEM import _TESTPATH as testpath
+from py4DSTEM.io import gdrive_download as download
+
+
+# Make the argument parser
+parser = argparse.ArgumentParser(
+ description="A command line tool for downloading data to run the py4DSTEM test suite"
+)
+
+# Set up data download options
+data_options = [
+ "tutorials",
+ "io",
+ "basic",
+ "strain",
+]
+
+# Add arguments
+parser.add_argument(
+ "data",
+ help="which data to download.",
+ choices=data_options,
+)
+parser.add_argument(
+ "-o",
+ "--overwrite",
+ help="if turned on, overwrite files that are already present. Otherwise, skips these files.",
+ action="store_true",
+)
+parser.add_argument(
+ "-v", "--verbose", help="turn on verbose output", action="store_true"
+)
+
+
+# Get the command line arguments
+args = parser.parse_args()
+
+
+# Set up paths
+if not exists(testpath):
+ makedirs(testpath)
+
+
+# Set data collection key
+if args.data == "tutorials":
+ data = ["tutorials"]
+elif args.data == "io":
+ data = ["test_io", "test_arina"]
+elif args.data == "basic":
+ data = ["small_datacube"]
+elif args.data == "strain":
+ data = ["strain"]
+else:
+ raise Exception(f"invalid data choice, {parser.data}")
+
+# Download data
+for d in data:
+ download(d, destination=testpath, overwrite=args.overwrite, verbose=args.verbose)
+
+# Always download the basic datacube
+if args.data != "basic":
+ download(
+ "small_datacube",
+ destination=testpath,
+ overwrite=args.overwrite,
+ verbose=args.verbose,
+ )
diff --git a/test/test_braggvectors.py b/test/test_braggvectors.py
new file mode 100644
index 000000000..7b0a9989b
--- /dev/null
+++ b/test/test_braggvectors.py
@@ -0,0 +1,92 @@
+import py4DSTEM
+import numpy as np
+from os.path import join
+
+# set filepath
+path = join(py4DSTEM._TESTPATH, "test_io/legacy_v0.9_simAuNanoplatelet_bin.h5")
+
+
+class TestDiskDetectionBasic:
+ # setup/teardown
+ def setup_class(cls):
+ # Read sim Au datacube
+ datacube = py4DSTEM.io.read(path, data_id="polyAu_4DSTEM")
+ cls.datacube = datacube
+
+ # prepare a probe
+ mask = np.zeros(datacube.Rshape, dtype=bool)
+ mask[28:33, 14:19] = 1
+ probe = datacube.get_vacuum_probe(ROI=mask)
+ alpha_pr, qx0_pr, qy0_pr = py4DSTEM.process.calibration.get_probe_size(
+ probe.probe
+ )
+ probe.get_kernel(
+ mode="sigmoid", origin=(qx0_pr, qy0_pr), radii=(alpha_pr, 2 * alpha_pr)
+ )
+ cls.probe = probe
+
+ # Set disk detection parameters
+ cls.detect_params = {
+ "corrPower": 1.0,
+ "sigma": 0,
+ "edgeBoundary": 2,
+ "minRelativeIntensity": 0,
+ "minAbsoluteIntensity": 8,
+ "minPeakSpacing": 4,
+ "subpixel": "poly",
+ "maxNumPeaks": 1000,
+ # 'CUDA': True,
+ }
+
+ # find disks
+ cls.braggpeaks = datacube.find_Bragg_disks(
+ template=probe.kernel,
+ **cls.detect_params,
+ )
+
+ # set an arbitrary center for testing
+ cls.braggpeaks.calibration.set_origin(
+ (datacube.Qshape[0] / 2, datacube.Qshape[1] / 2)
+ )
+
+ # tests
+
+ def test_BraggVectors_import(self):
+ from py4DSTEM.braggvectors import BraggVectors # noqa: F401
+
+ pass
+
+ def test_disk_detection_selected_positions(self):
+ rxs = 36, 15, 11, 59, 32, 34
+ rys = (
+ 9,
+ 15,
+ 31,
+ 39,
+ 20,
+ 68,
+ )
+
+ disks_selected = self.datacube.find_Bragg_disks( # noqa: F841
+ data=(rxs, rys),
+ template=self.probe.kernel,
+ **self.detect_params,
+ )
+
+ def test_BraggVectors(self):
+ print(self.braggpeaks)
+ print()
+ print(self.braggpeaks.raw[0, 0])
+ print()
+ print(self.braggpeaks.cal[0, 0])
+ print()
+ print(
+ self.braggpeaks.get_vectors(
+ scan_x=5,
+ scan_y=5,
+ center=True,
+ ellipse=False,
+ pixel=False,
+ rotate=False,
+ )
+ )
diff --git a/test/test_calibration.py b/test/test_calibration.py
new file mode 100644
index 000000000..78a24d800
--- /dev/null
+++ b/test/test_calibration.py
@@ -0,0 +1,59 @@
+import py4DSTEM
+from py4DSTEM import Calibration
+import numpy as np
+from os import mkdir, remove, rmdir
+from os.path import join, exists
+
+# set filepaths
+path_datacube = join(py4DSTEM._TESTPATH, "small_datacube.dm4")
+path_3Darray = join(py4DSTEM._TESTPATH, "test_io/small_dm3_3Dstack.dm3")
+
+path_out_dir = join(py4DSTEM._TESTPATH, "test_outputs")
+path_out = join(path_out_dir, "test_calibration.h5")
+
+
+class TestCalibration:
+ # setup
+
+ def setup_class(cls):
+ if not exists(path_out_dir):
+ mkdir(path_out_dir)
+
+ def teardown_class(cls):
+ if exists(path_out_dir):
+ rmdir(path_out_dir)
+
+ def teardown_method(self):
+ if exists(path_out):
+ remove(path_out)
+
+ # test
+
+ def test_imported_datacube_calibration(self):
+ datacube = py4DSTEM.import_file(path_datacube)
+
+ assert hasattr(datacube, "calibration")
+ assert isinstance(datacube.calibration, Calibration)
+ assert hasattr(datacube, "root")
+ assert isinstance(datacube.root, py4DSTEM.Root)
+
+ def test_instantiated_datacube_calibration(self):
+ datacube = py4DSTEM.DataCube(data=np.ones((4, 8, 128, 128)))
+
+ assert hasattr(datacube, "calibration")
+ assert isinstance(datacube.calibration, Calibration)
+ assert hasattr(datacube, "root")
+ assert isinstance(datacube.root, py4DSTEM.Root)
+
+ datacube.calibration.set_Q_pixel_size(10)
+
+ py4DSTEM.save(path_out, datacube)
+
+ new_datacube = py4DSTEM.read(path_out)
+
+ assert hasattr(new_datacube, "calibration")
+ assert isinstance(new_datacube.calibration, Calibration)
+ assert hasattr(new_datacube, "root")
+ assert isinstance(new_datacube.root, py4DSTEM.Root)
+
+ assert new_datacube.calibration.get_Q_pixel_size() == 10
diff --git a/test/test_crystal.py b/test/test_crystal.py
new file mode 100644
index 000000000..c36837ba5
--- /dev/null
+++ b/test/test_crystal.py
@@ -0,0 +1,23 @@
+# from py4DSTEM.classes import (
+# Crystal
+# )
+
+
+class TestCrystal:
+ def setup_class(cls):
+ pass
+
+ def teardown_class(cls):
+ pass
+
+ def setup_method(self):
+ pass
+
+ def teardown_method(self):
+ pass
+
+ def test_Crystal(self):
+ # crystal = Crystal( **args )
+ # assert(isinstance(crystal,Crystal))
+
+ pass
diff --git a/test/test_datacube.py b/test/test_datacube.py
new file mode 100644
index 000000000..67f419014
--- /dev/null
+++ b/test/test_datacube.py
@@ -0,0 +1,31 @@
+import py4DSTEM
+import numpy as np
+
+# set filepath
+path = py4DSTEM._TESTPATH + "/small_datacube.dm4"
+
+
+class TestDataCube:
+ # setup/teardown
+ def setup_class(cls):
+ # Read datacube
+ datacube = py4DSTEM.import_file(path)
+ cls.datacube = datacube
+
+ # tests
+
+ def test_binning_default_dtype(self):
+ dtype = self.datacube.data.dtype
+ assert dtype == np.uint16
+
+ self.datacube.bin_Q(2)
+
+ assert self.datacube.data.dtype == dtype
+
+ new_dtype = np.uint32
+ self.datacube.bin_Q(2, dtype=new_dtype)
+
+ assert self.datacube.data.dtype == new_dtype
+ assert self.datacube.data.dtype != dtype
+
+ pass
diff --git a/test/test_import.py b/test/test_import.py
new file mode 100644
index 000000000..fbaa4285d
--- /dev/null
+++ b/test/test_import.py
@@ -0,0 +1,7 @@
+# test import
+
+
+def test_import():
+ import py4DSTEM
+
+ py4DSTEM.__version__
diff --git a/test/test_misc.py b/test/test_misc.py
new file mode 100644
index 000000000..aafb2de22
--- /dev/null
+++ b/test/test_misc.py
@@ -0,0 +1,24 @@
+import py4DSTEM
+import numpy as np
+
+
+def test_attach():
+ """tests to make sure Data.attach handles metadata merging correctly"""
+
+ x = py4DSTEM.DiffractionSlice(np.ones((5, 5)), name="x")
+ y = py4DSTEM.DiffractionSlice(np.ones((5, 5)), name="y")
+
+ x.calibration.set_Q_pixel_size(50)
+ y.calibration.set_Q_pixel_size(2)
+
+ x.attach(y)
+
+ assert "y" in x.treekeys
+ assert x.calibration.get_Q_pixel_size() == 50
+
+
+def test_datacube_copy():
+ """tests datacube.copy()"""
+ x = py4DSTEM.DataCube(data=np.zeros((3, 3, 4, 4)))
+ y = x.copy()
+ assert isinstance(y, py4DSTEM.DataCube)
diff --git a/test/test_native_io/test_calibration_io.py b/test/test_native_io/test_calibration_io.py
new file mode 100644
index 000000000..4bf1fbb3c
--- /dev/null
+++ b/test/test_native_io/test_calibration_io.py
@@ -0,0 +1,36 @@
+# import py4DSTEM
+# import numpy as np
+# from os.path import join
+
+# set filepath
+# path = join(py4DSTEM._TESTPATH, "filename")
+
+
+# class TestCalibrationIO:
+#
+#
+#
+# def test_datacube_cal_io(self):
+# # TODO
+# # make a datacube
+# # modify its calibration
+# # save
+# # load the datacube
+# # check its calibration
+# assert 0
+# pass
+#
+#
+# def test_datacube_child_node(self):
+# # TODO
+# # make a datacube
+# # make a child node
+# # confirm calibrations are the same
+# # modify the calibration
+# # save
+# # load the datacube
+# # check its calibration
+# # load just the child node
+# # check its calibration
+# assert 0
+# pass
diff --git a/test/test_native_io/test_listwrite.py b/test/test_native_io/test_listwrite.py
new file mode 100644
index 000000000..e90d2ccd8
--- /dev/null
+++ b/test/test_native_io/test_listwrite.py
@@ -0,0 +1,29 @@
+import py4DSTEM
+import numpy as np
+
+# filepath
+from os import getcwd, remove
+from os.path import join, exists
+
+path = join(getcwd(), "test.h5")
+
+
+def test_listwrite():
+ # make two arrays
+ ar1 = py4DSTEM.RealSlice(data=np.arange(24).reshape((2, 3, 4)), name="array1")
+ ar2 = py4DSTEM.RealSlice(data=np.arange(48).reshape((4, 3, 4)), name="array2")
+
+ # save them
+ py4DSTEM.save(filepath=path, data=[ar1, ar2], mode="o")
+
+ # read them
+ data1 = py4DSTEM.read(path, datapath="array1_root")
+ data2 = py4DSTEM.read(path, datapath="array2_root")
+
+ # check
+ assert np.array_equal(data1.data, ar1.data)
+ assert np.array_equal(data2.data, ar2.data)
+
+ # delete the file
+ if exists(path):
+ remove(path)
diff --git a/test/test_native_io/test_realslice_read.py b/test/test_native_io/test_realslice_read.py
new file mode 100644
index 000000000..f58a54435
--- /dev/null
+++ b/test/test_native_io/test_realslice_read.py
@@ -0,0 +1,13 @@
+# Test reading realslices in v13
+
+
+import py4DSTEM
+from os.path import join
+
+
+# Set filepaths
+filepath = join(py4DSTEM._TESTPATH, "test_io/test_realslice_io.h5")
+
+
+def test_read_realslice():
+ realslice = py4DSTEM.read(filepath, datapath="4DSTEM/Fit Data") # noqa: F841
diff --git a/test/test_native_io/test_single_object_io.py b/test/test_native_io/test_single_object_io.py
new file mode 100644
index 000000000..2d7791435
--- /dev/null
+++ b/test/test_native_io/test_single_object_io.py
@@ -0,0 +1,222 @@
+import numpy as np
+from os.path import join
+from numpy import array_equal
+
+import py4DSTEM
+from py4DSTEM import save, read
+
+from py4DSTEM import (
+ DiffractionSlice,
+ RealSlice,
+ QPoints,
+ DataCube,
+ VirtualImage,
+ VirtualDiffraction,
+ BraggVectors,
+ Probe,
+)
+
+# Set paths
+dirpath = py4DSTEM._TESTPATH
+path_dm3 = join(dirpath, "test_io/small_dm3_3Dstack.dm3")
+path_h5 = join(dirpath, "test.h5")
+
+
+class TestDataCubeIO:
+ def test_datacube_instantiation(self):
+ """
+ Instantiate a datacube and apply basic calibrations
+ """
+ datacube = DataCube(data=np.arange(np.prod((4, 5, 6, 7))).reshape((4, 5, 6, 7)))
+ # calibration
+ datacube.calibration.set_Q_pixel_size(0.062)
+ datacube.calibration.set_Q_pixel_units("A^-1")
+ datacube.calibration.set_R_pixel_size(2.8)
+ datacube.calibration.set_R_pixel_units("nm")
+
+ return datacube
+
+ def test_datacube_io(self):
+ """
+ Instantiate, save, then read a datacube, and
+ compare its contents before/after
+ """
+ datacube = self.test_datacube_instantiation()
+
+ assert isinstance(datacube, DataCube)
+ # test dim vectors
+ assert datacube.dim_names[0] == "Rx"
+ assert datacube.dim_names[1] == "Ry"
+ assert datacube.dim_names[2] == "Qx"
+ assert datacube.dim_names[3] == "Qy"
+ assert datacube.dim_units[0] == "nm"
+ assert datacube.dim_units[1] == "nm"
+ assert datacube.dim_units[2] == "A^-1"
+ assert datacube.dim_units[3] == "A^-1"
+ assert datacube.dims[0][1] == 2.8
+ assert datacube.dims[2][1] == 0.062
+ # check the calibrations
+ assert datacube.calibration.get_Q_pixel_size() == 0.062
+ assert datacube.calibration.get_Q_pixel_units() == "A^-1"
+ # save and read
+ save(path_h5, datacube, mode="o")
+ new_datacube = read(path_h5)
+ # check it's the same
+ assert isinstance(new_datacube, DataCube)
+ assert array_equal(datacube.data, new_datacube.data)
+ assert new_datacube.calibration.get_Q_pixel_size() == 0.062
+ assert new_datacube.calibration.get_Q_pixel_units() == "A^-1"
+ assert new_datacube.dims[0][1] == 2.8
+ assert new_datacube.dims[2][1] == 0.062
+
+
+class TestBraggVectorsIO:
+ def test_braggvectors_instantiation(self):
+ """
+ Instantiate a braggvectors instance
+ """
+ braggvectors = BraggVectors(Rshape=(5, 6), Qshape=(7, 8))
+ for x in range(braggvectors.Rshape[0]):
+ for y in range(braggvectors.Rshape[1]):
+ L = int(4 * (np.sin(x * y) + 1))
+ braggvectors._v_uncal[x, y].add(
+ np.ones(L, dtype=braggvectors._v_uncal.dtype)
+ )
+ return braggvectors
+
+ def test_braggvectors_io(self):
+ """Save then read a BraggVectors instance, and compare contents before/after"""
+ braggvectors = self.test_braggvectors_instantiation()
+
+ assert isinstance(braggvectors, BraggVectors)
+ # save then read
+ save(path_h5, braggvectors, mode="o")
+ new_braggvectors = read(path_h5)
+ # check it's the same
+ assert isinstance(new_braggvectors, BraggVectors)
+ assert new_braggvectors is not braggvectors
+ for x in range(new_braggvectors.shape[0]):
+ for y in range(new_braggvectors.shape[1]):
+ assert array_equal(
+ new_braggvectors._v_uncal[x, y].data,
+ braggvectors._v_uncal[x, y].data,
+ )
+
+
+class TestSlices:
+ # test instantiation
+
+ def test_diffractionslice_instantiation(self):
+ diffractionslice = DiffractionSlice(
+ data=np.arange(np.prod((4, 8, 2))).reshape((4, 8, 2)),
+ slicelabels=["a", "b"],
+ )
+ return diffractionslice
+
+ def test_realslice_instantiation(self):
+ realslice = RealSlice(
+ data=np.arange(np.prod((8, 4, 2))).reshape((8, 4, 2)),
+ slicelabels=["x", "y"],
+ )
+ return realslice
+
+ def test_virtualdiffraction_instantiation(self):
+ virtualdiffraction = VirtualDiffraction(
+ data=np.arange(np.prod((8, 4, 2))).reshape((8, 4, 2)),
+ )
+ return virtualdiffraction
+
+ def test_virtualimage_instantiation(self):
+ virtualimage = VirtualImage(
+ data=np.arange(np.prod((8, 4, 2))).reshape((8, 4, 2)),
+ )
+ return virtualimage
+
+ def test_probe_instantiation(self):
+ probe = Probe(data=np.arange(8 * 12).reshape((8, 12)))
+ # add a kernel
+ probe.kernel = np.ones_like(probe.probe)
+ # return
+ return probe
+
+ # test io
+
+ def test_diffractionslice_io(self):
+ """test diffractionslice io"""
+ diffractionslice = self.test_diffractionslice_instantiation()
+ assert isinstance(diffractionslice, DiffractionSlice)
+ # save and read
+ save(path_h5, diffractionslice, mode="o")
+ new_diffractionslice = read(path_h5)
+ # check it's the same
+ assert isinstance(new_diffractionslice, DiffractionSlice)
+ assert array_equal(diffractionslice.data, new_diffractionslice.data)
+ assert diffractionslice.slicelabels == new_diffractionslice.slicelabels
+
+ def test_realslice_io(self):
+ """test realslice io"""
+ realslice = self.test_realslice_instantiation()
+ assert isinstance(realslice, RealSlice)
+ # save and read
+ save(path_h5, realslice, mode="o")
+ rs = read(path_h5)
+ # check it's the same
+ assert isinstance(rs, RealSlice)
+ assert array_equal(realslice.data, rs.data)
+ assert rs.slicelabels == realslice.slicelabels
+
+ def test_virtualdiffraction_io(self):
+ """test virtualdiffraction io"""
+ virtualdiffraction = self.test_virtualdiffraction_instantiation()
+ assert isinstance(virtualdiffraction, VirtualDiffraction)
+ # save and read
+ save(path_h5, virtualdiffraction, mode="o")
+ vd = read(path_h5)
+ # check it's the same
+ assert isinstance(vd, VirtualDiffraction)
+ assert array_equal(vd.data, virtualdiffraction.data)
+ pass
+
+ def test_virtualimage_io(self):
+ """test virtualimage io"""
+ virtualimage = self.test_virtualimage_instantiation()
+ assert isinstance(virtualimage, VirtualImage)
+ # save and read
+ save(path_h5, virtualimage, mode="o")
+ virtIm = read(path_h5)
+ # check it's the same
+ assert isinstance(virtIm, VirtualImage)
+ assert array_equal(virtualimage.data, virtIm.data)
+ pass
+
+ def test_probe1_io(self):
+ """test probe io"""
+ probe0 = self.test_probe_instantiation()
+ assert isinstance(probe0, Probe)
+ # save and read
+ save(path_h5, probe0, mode="o")
+ probe = read(path_h5)
+ # check it's the same
+ assert isinstance(probe, Probe)
+ assert array_equal(probe0.data, probe.data)
+ pass
+
+
+class TestPoints:
+ def test_qpoints_instantiation(self):
+ qpoints = QPoints(
+ data=np.ones(10, dtype=[("qx", float), ("qy", float), ("intensity", float)])
+ )
+ return qpoints
+
+ def test_qpoints_io(self):
+ """test qpoints io"""
+ qpoints0 = self.test_qpoints_instantiation()
+ assert isinstance(qpoints0, QPoints)
+ # save and read
+ save(path_h5, qpoints0, mode="o")
+ qpoints = read(path_h5)
+ # check it's the same
+ assert isinstance(qpoints, QPoints)
+ assert array_equal(qpoints0.data, qpoints.data)
+ pass
diff --git a/test/test_native_io/test_v0_13.py b/test/test_native_io/test_v0_13.py
new file mode 100644
index 000000000..e1d91bba4
--- /dev/null
+++ b/test/test_native_io/test_v0_13.py
@@ -0,0 +1,32 @@
+from py4DSTEM import read, print_h5_tree, _TESTPATH
+from os.path import join
+
+
+# Set filepaths
+filepath = join(_TESTPATH, "test_io/legacy_v0.13.h5")
+
+
+class TestV13:
+ # setup/teardown
+ def setup_class(cls):
+ cls.path = filepath
+ pass
+
+ @classmethod
+ def teardown_class(cls):
+ pass
+
+ def setup_method(self, method):
+ pass
+
+ def teardown_method(self, method):
+ pass
+
+ def test_print_tree(self):
+ print_h5_tree(self.path)
+
+ def test_read(self):
+ d = read(
+ self.path,
+ )
+ d
diff --git a/test/test_native_io/test_v0_14.py b/test/test_native_io/test_v0_14.py
new file mode 100644
index 000000000..ef7df120b
--- /dev/null
+++ b/test/test_native_io/test_v0_14.py
@@ -0,0 +1,120 @@
+import py4DSTEM
+from os.path import join, exists
+
+
+path = join(py4DSTEM._TESTPATH, "test_io/legacy_v0.14.h5")
+
+
+def _make_v14_test_file():
+ # enforce v14
+ assert py4DSTEM.__version__.split(".")[1] == "14", "the v14 test file can only be written with py4DSTEM v0.14"
+
+ # Set filepaths
+ filepath_data = join(
+ py4DSTEM._TESTPATH, "test_io/legacy_v0.9_simAuNanoplatelet_bin.h5"
+ )
+
+ # Read sim Au datacube
+ datacube = py4DSTEM.io.read(filepath_data, data_id="polyAu_4DSTEM")
+
+ # # Virtual diffraction
+
+ # Get mean and max DPs
+ datacube.get_dp_mean()
+ datacube.get_dp_max()
+
+ # # Disk detection
+
+ # find a vacuum region
+ import numpy as np
+
+ mask = np.zeros(datacube.Rshape, dtype=bool)
+ mask[28:33, 14:19] = 1
+
+ # generate a probe
+ probe = datacube.get_vacuum_probe(ROI=mask)
+
+ # Find the center and semiangle
+ alpha, qx0, qy0 = py4DSTEM.process.probe.get_probe_size(probe.probe)
+
+ # prepare the probe kernel
+ kern = probe.get_kernel(
+ mode="sigmoid", origin=(qx0, qy0), radii=(alpha, 2 * alpha)
+ ) # noqa: F841
+
+ # Set disk detection parameters
+ detect_params = {
+ "corrPower": 1.0,
+ "sigma": 0,
+ "edgeBoundary": 2,
+ "minRelativeIntensity": 0,
+ "minAbsoluteIntensity": 8,
+ "minPeakSpacing": 4,
+ "subpixel": "poly",
+ "maxNumPeaks": 1000,
+ # 'CUDA': True,
+ }
+
+ # compute
+ braggpeaks = datacube.find_Bragg_disks( # noqa: F841
+ template=probe.kernel,
+ **detect_params,
+ )
+
+ # # Virtual Imaging
+
+ # set geometries
+ geo_bf = ((qx0, qy0), alpha + 6)
+ geo_df = ((qx0, qy0), (3 * alpha, 6 * alpha))
+
+ # bright field
+ datacube.get_virtual_image(
+ mode="circle",
+ geometry=geo_bf,
+ name="bright_field",
+ )
+
+ # dark field
+ datacube.get_virtual_image(mode="annulus", geometry=geo_df, name="dark_field")
+
+ # # Write
+
+ py4DSTEM.save(path, datacube, tree=None, mode="o")
+
+
+class TestV14:
+ # setup/teardown
+ def setup_class(cls):
+ if not (exists(path)):
+ print("no test file for v14 found")
+ if py4DSTEM.__version__.split(".")[1] == "14":
+ print("v14 detected. writing new test file...")
+ _make_v14_test_file()
+ else:
+ raise Exception(
+ f"No v14 testfile was found at path {path}, and a new one can't be written with this py4DSTEM version {py4DSTEM.__version__}"
+ )
+
+ @classmethod
+ def teardown_class(cls):
+ pass
+
+ def setup_method(self, method):
+ pass
+
+ def teardown_method(self, method):
+ pass
+
+ def test_meowth(self):
+ # py4DSTEM.print_h5_tree(path)
+ data = py4DSTEM.read(path)
+ data.tree()
+
+ assert isinstance(data.tree("braggvectors"), py4DSTEM.BraggVectors)
+ assert isinstance(data.tree("bright_field"), py4DSTEM.VirtualImage)
+ assert isinstance(data.tree("dark_field"), py4DSTEM.VirtualImage)
+ assert isinstance(data.tree("dp_max"), py4DSTEM.VirtualDiffraction)
+ assert isinstance(data.tree("dp_mean"), py4DSTEM.VirtualDiffraction)
+ assert isinstance(data.tree("probe"), py4DSTEM.Probe)
+
+ pass
diff --git a/test/test_native_io/test_v0_9.py b/test/test_native_io/test_v0_9.py
new file mode 100644
index 000000000..6ebea1661
--- /dev/null
+++ b/test/test_native_io/test_v0_9.py
@@ -0,0 +1,14 @@
+from py4DSTEM import read, DataCube, _TESTPATH
+from os.path import join
+
+path = join(_TESTPATH, "test_io/legacy_v0.9_simAuNanoplatelet_bin.h5")
+
+
+def test_read_v0_9_noID():
+ d = read(path)
+ d
+
+
+def test_read_v0_9_withID():
+ d = read(path, data_id="polyAu_4DSTEM")
+ assert isinstance(d, DataCube)
diff --git a/test/test_nonnative_io/test_arina.py b/test/test_nonnative_io/test_arina.py
new file mode 100644
index 000000000..c02964cf8
--- /dev/null
+++ b/test/test_nonnative_io/test_arina.py
@@ -0,0 +1,16 @@
+import py4DSTEM
+import emdfile
+from os.path import join
+
+
+# Set filepaths
+filepath = join(py4DSTEM._TESTPATH, "test_arina/STO_STEM_bench_20us_master.h5")
+
+
+def test_read_arina():
+ # read
+ data = py4DSTEM.import_file(filepath)
+
+ # check imported data
+ assert isinstance(data, emdfile.Array)
+ assert isinstance(data, py4DSTEM.DataCube)
diff --git a/test/test_nonnative_io/test_dm.py b/test/test_nonnative_io/test_dm.py
new file mode 100644
index 000000000..ee6f1b2eb
--- /dev/null
+++ b/test/test_nonnative_io/test_dm.py
@@ -0,0 +1,24 @@
+import py4DSTEM
+import emdfile
+from os.path import join
+
+
+# Set filepaths
+filepath_dm4_datacube = join(py4DSTEM._TESTPATH, "small_datacube.dm4")
+filepath_dm3_3Dstack = join(py4DSTEM._TESTPATH, "test_io/small_dm3_3Dstack.dm3")
+
+
+def test_dmfile_datacube():
+ data = py4DSTEM.import_file(filepath_dm4_datacube)
+ assert isinstance(data, emdfile.Array)
+ assert isinstance(data, py4DSTEM.DataCube)
+
+
+def test_dmfile_3Darray():
+ data = py4DSTEM.import_file(filepath_dm3_3Dstack)
+ assert isinstance(data, emdfile.Array)
+
+
+# TODO
+# def test_dmfile_multiple_datablocks():
+# def test_dmfile_2Darray
diff --git a/test/test_probe.py b/test/test_probe.py
new file mode 100644
index 000000000..d55e0dd3f
--- /dev/null
+++ b/test/test_probe.py
@@ -0,0 +1,51 @@
+import py4DSTEM
+from py4DSTEM import Probe
+import numpy as np
+
+# set filepath
+path = py4DSTEM._TESTPATH + "/small_datacube.dm4"
+
+
+class TestProbe:
+ # setup/teardown
+ def setup_class(cls):
+ # Read datacube
+ datacube = py4DSTEM.import_file(path)
+ cls.datacube = datacube
+
+ # tests
+
+ def test_probe_gen_from_dp(self):
+ p = Probe.from_vacuum_data(self.datacube[0, 0])
+ assert isinstance(p, Probe)
+ pass
+
+ def test_probe_gen_from_stack(self):
+ # get a 3D stack
+ x, y = np.zeros(10).astype(int), np.arange(10).astype(int)
+ data = self.datacube.data[x, y, :, :]
+ # get the probe
+ p = Probe.from_vacuum_data(data)
+ assert isinstance(p, Probe)
+ pass
+
+ def test_probe_gen_from_datacube_ROI_1(self):
+ ROI = np.zeros(self.datacube.Rshape, dtype=bool)
+ ROI[3:7, 5:10] = True
+ p = self.datacube.get_vacuum_probe(ROI)
+ assert isinstance(p, Probe)
+
+ self.datacube.tree()
+ self.datacube.tree(True)
+ _p = self.datacube.tree("probe")
+ print(_p)
+
+ assert p is self.datacube.tree("probe")
+ pass
+
+ def test_probe_gen_from_datacube_ROI_2(self):
+ ROI = (3, 7, 5, 10)
+ p = self.datacube.get_vacuum_probe(ROI)
+ assert isinstance(p, Probe)
+ assert p is self.datacube.tree("probe")
+ pass
diff --git a/test/test_strain.py b/test/test_strain.py
new file mode 100644
index 000000000..e2be240c4
--- /dev/null
+++ b/test/test_strain.py
@@ -0,0 +1,26 @@
+import py4DSTEM
+from py4DSTEM import StrainMap
+from os.path import join
+
+
+# set filepath
+path = join(py4DSTEM._TESTPATH, "strain/downsample_Si_SiGe_analysis_braggdisks_cal.h5")
+
+
+class TestStrainMap:
+ # setup/teardown
+ def setup_class(cls):
+ # Read braggpeaks
+ # origin is calibrated
+ cls.braggpeaks = py4DSTEM.io.read(path)
+
+ # tests
+
+ def test_strainmap_instantiation(self):
+ strainmap = StrainMap(
+ braggvectors=self.braggpeaks,
+ )
+
+ assert isinstance(strainmap, StrainMap)
+ assert strainmap.calibration is not None
+ assert strainmap.calibration is strainmap.braggvectors.calibration
diff --git a/test/test_workflow/test_basics.py b/test/test_workflow/test_basics.py
new file mode 100644
index 000000000..0c838707a
--- /dev/null
+++ b/test/test_workflow/test_basics.py
@@ -0,0 +1,67 @@
+import py4DSTEM
+from os.path import join
+
+# set filepath
+path = join(py4DSTEM._TESTPATH, "test_io/legacy_v0.9_simAuNanoplatelet_bin.h5")
+
+
+class TestBasics:
+ # setup/teardown
+ def setup_class(cls):
+ # Read sim Au datacube
+ datacube = py4DSTEM.io.read(path, data_id="polyAu_4DSTEM")
+ cls.datacube = datacube
+
+ # get center and probe radius
+ datacube.get_dp_mean()
+ alpha, qx0, qy0 = datacube.get_probe_size()
+ cls.alpha = alpha
+ cls.qx0, cls.qy0 = qx0, qy0
+
+ # tests
+
+ def test_get_dp(self):
+ dp = self.datacube[10, 30]
+ dp
+
+ def test_show(self):
+ dp = self.datacube[10, 30]
+ py4DSTEM.visualize.show(dp)
+
+ # virtual diffraction and imaging
+
+ def test_virt_diffraction(self):
+ dp_mean = self.datacube.get_dp_mean() # noqa: F841
+ self.datacube.get_dp_max()
+
+ def test_virt_imaging_bf(self):
+ geo = ((self.qx0, self.qy0), self.alpha + 3)
+
+ # position detector
+ self.datacube.position_detector(
+ mode="circle",
+ geometry=geo,
+ )
+
+ # compute
+ self.datacube.get_virtual_image(
+ mode="circle",
+ geometry=geo,
+ name="bright_field",
+ )
+
+ def test_virt_imaging_adf(self):
+ geo = ((self.qx0, self.qy0), (3 * self.alpha, 6 * self.alpha))
+
+ # position detector
+ self.datacube.position_detector(
+ mode="annulus",
+ geometry=geo,
+ )
+
+ # compute
+ self.datacube.get_virtual_image(
+ mode="annulus",
+ geometry=geo,
+ name="annular_dark_field",
+ )
diff --git a/test/test_workflow/test_disk_detection_basic.py b/test/test_workflow/test_disk_detection_basic.py
new file mode 100644
index 000000000..be9d8a1c8
--- /dev/null
+++ b/test/test_workflow/test_disk_detection_basic.py
@@ -0,0 +1,64 @@
+import py4DSTEM
+from os.path import join
+from numpy import zeros
+
+# set filepath
+path = join(py4DSTEM._TESTPATH, "test_io/legacy_v0.9_simAuNanoplatelet_bin.h5")
+
+
+class TestDiskDetectionBasic:
+ # setup/teardown
+ def setup_class(cls):
+ # Read sim Au datacube
+ datacube = py4DSTEM.io.read(path, data_id="polyAu_4DSTEM")
+ cls.datacube = datacube
+
+ # prepare a probe
+ mask = zeros(datacube.Rshape, dtype=bool)
+ mask[28:33, 14:19] = 1
+ probe = datacube.get_vacuum_probe(ROI=mask)
+ alpha_pr, qx0_pr, qy0_pr = py4DSTEM.process.calibration.get_probe_size(
+ probe.probe
+ )
+ probe.get_kernel(
+ mode="sigmoid", origin=(qx0_pr, qy0_pr), radii=(alpha_pr, 2 * alpha_pr)
+ )
+ cls.probe = probe
+
+ # Set disk detection parameters
+ cls.detect_params = {
+ "corrPower": 1.0,
+ "sigma": 0,
+ "edgeBoundary": 2,
+ "minRelativeIntensity": 0,
+ "minAbsoluteIntensity": 8,
+ "minPeakSpacing": 4,
+ "subpixel": "poly",
+ "maxNumPeaks": 1000,
+ # 'CUDA': True,
+ }
+
+ # tests
+
+ def test_disk_detection_selected_positions(self):
+ rxs = 36, 15, 11, 59, 32, 34
+ rys = (
+ 9,
+ 15,
+ 31,
+ 39,
+ 20,
+ 68,
+ )
+
+ disks_selected = self.datacube.find_Bragg_disks( # noqa: F841
+ data=(rxs, rys),
+ template=self.probe.kernel,
+ **self.detect_params,
+ )
+
+ def test_disk_detection(self):
+ braggpeaks = self.datacube.find_Bragg_disks( # noqa: F841
+ template=self.probe.kernel,
+ **self.detect_params,
+ )
diff --git a/test/test_workflow/test_disk_detection_with_calibration.py b/test/test_workflow/test_disk_detection_with_calibration.py
new file mode 100644
index 000000000..f3892b4bc
--- /dev/null
+++ b/test/test_workflow/test_disk_detection_with_calibration.py
@@ -0,0 +1,64 @@
+import py4DSTEM
+from os.path import join
+from numpy import zeros
+
+# set filepath
+path = join(py4DSTEM._TESTPATH, "test_io/legacy_v0.9_simAuNanoplatelet_bin.h5")
+
+
+class TestDiskDetectionWithCalibration:
+ # setup/teardown
+ def setup_class(cls):
+ # Read sim Au datacube
+ datacube = py4DSTEM.io.read(path, data_id="polyAu_4DSTEM")
+ cls.datacube = datacube
+
+ # prepare a probe
+ mask = zeros(datacube.Rshape, dtype=bool)
+ mask[28:33, 14:19] = 1
+ probe = datacube.get_vacuum_probe(ROI=mask)
+ alpha_pr, qx0_pr, qy0_pr = py4DSTEM.process.calibration.get_probe_size(
+ probe.probe
+ )
+ probe.get_kernel(
+ mode="sigmoid", origin=(qx0_pr, qy0_pr), radii=(alpha_pr, 2 * alpha_pr)
+ )
+ cls.probe = probe
+
+ # Set disk detection parameters
+ cls.detect_params = {
+ "corrPower": 1.0,
+ "sigma": 0,
+ "edgeBoundary": 2,
+ "minRelativeIntensity": 0,
+ "minAbsoluteIntensity": 8,
+ "minPeakSpacing": 4,
+ "subpixel": "poly",
+ "maxNumPeaks": 1000,
+ # 'CUDA': True,
+ }
+
+ # tests
+
+ def test_disk_detection(self):
+ braggpeaks = self.datacube.find_Bragg_disks( # noqa: F841
+ template=self.probe.kernel,
+ **self.detect_params,
+ )
+
+ # calibrate center
+
+ # calibrate ellipse
+
+ # calibrate pixel
+
+ # calibrate rotation
+
+ # show
+
+ # save
+
+ # load
+
+ # check loaded data
+ # check loaded cali
diff --git a/utils.py b/utils.py
deleted file mode 100644
index a44119305..000000000
--- a/utils.py
+++ /dev/null
@@ -1,40 +0,0 @@
-import os
-from PySide2 import QtCore, QtGui, QtUiTools
-import pyqtgraph as pg
-
-
-def sibling_path(fpath, fname):
- """
- Given a file with absolute path fpath, returns the absolute path to another file with name
- fname in the same directory.
- """
- return os.path.join(os.path.dirname(fpath), fname)
-
-
-def load_qt_ui_file(ui_filename):
- """
- Loads a ui file specifying a user interface configuration
- """
- ui_loader = QtUiTools.QUiLoader()
- ui_file = QtCore.QFile(ui_filename)
- ui_file.open(QtCore.QFile.ReadOnly)
- ui = ui_loader.load(ui_file)
- ui_file.close()
- return ui
-
-def pg_point_roi(view_box):
- """
- Point selection. Based on pyqtgraph; returns a pyqtgraph CircleROI object.
- This object has a sigRegionChanged.connect() signal method to connect to other functions.
- """
- circ_roi = pg.CircleROI( (0,0), (2,2), movable=True, pen=(0,9))
- h = circ_roi.addTranslateHandle((0.5,0.5))
- h.pen = pg.mkPen('r')
- h.update()
- view_box.addItem(circ_roi)
- circ_roi.removeHandle(0)
- return circ_roi
-
-
-
-
diff --git a/viewer.py b/viewer.py
deleted file mode 100644
index 341ad8d15..000000000
--- a/viewer.py
+++ /dev/null
@@ -1,328 +0,0 @@
-######## Viewer for 4D STEM data ########
-#
-# Defines a class -- DataViewer - enabling a simple GUI for
-# interacting with 4D STEM datasets.
-#
-# Relevant documentation for lower level code:
-#
-# ScopeFoundry
-# ScopeFoundry is a flexible package for both scientific data visualization and control of laboratory experiments. See http://www.scopefoundry.org/. This code uses the ScopeFoundry object
-# LQCollection, which enables intelligent interactive storage of logged quantities.
-#
-# Qt
-# Qt is being run through Pyside/PySide2/PyQt/Qt for Python. See https://www.qt.io/qt-for-python. Presently PySide is being used.
-# TODO: (maybe) use PySide2 (moves some objects from QtGui to the newer QtWidgets. Or (maybe)
-# use qtpy, a small wrapper which supports systems with either PySide or PySide2 (basically, for
-# python 2 or 3).
-#
-# pyqtgraph
-# pyqtgraph is a library which facilitates fast-running scientific visualization. See http://pyqtgraph.org/. pyqtgraph is being used for the final data displays.
-
-
-from __future__ import division, print_function
-from PySide2 import QtCore, QtWidgets
-import numpy as np
-import sys, os
-from ScopeFoundry import BaseApp, LQCollection
-from utils import load_qt_ui_file, sibling_path, pg_point_roi
-import pyqtgraph as pg
-import dm3_lib as dm3
-from control_panel import ControlPanel
-from datacube import DataCube
-
-import IPython
-if IPython.version_info[0] < 4:
- from IPython.qt.console.rich_ipython_widget import RichIPythonWidget as RichJupyterWidget
- from IPython.qt.inprocess import QtInProcessKernelManager
-else:
- from qtconsole.rich_jupyter_widget import RichJupyterWidget
- from qtconsole.inprocess import QtInProcessKernelManager
-
-
-class DataViewer(QtCore.QObject):
- """
- DataViewer objects inherit from the ScopeFoundry.BaseApp class.
- ScopeFoundry.BaseApp objects inherit from the QtCore.QObject class.
- Additional functionality is provided by pyqtgraph widgets.
-
- The class is used by instantiating and then entering the main Qt loop with, e.g.:
- app = DataViewer(sys.argv)
- app.exec_()
- """
- def __init__(self, argv):
- """
- Initialize class, setting up windows and widgets.
- """
- self.this_dir, self.this_filename = os.path.split(__file__)
-
- # Set a pointer referring to the application object
- self.qtapp = QtWidgets.QApplication.instance()
- if not self.qtapp:
- self.qtapp = QtWidgets.QApplication(argv)
-
- # TODO: consider removing dependency on LQCollection object
- self.settings = LQCollection()
-
- # Set up widgets
- self.setup_diffraction_space_widget()
- self.setup_real_space_widget()
- self.setup_diffraction_space_control_widget()
- self.setup_real_space_control_widget()
- self.setup_console_widget()
- self.setup_geometry()
- return
-
- ############ Setup methods #############
-
- def setup_diffraction_space_widget(self):
- """
- Set up the diffraction space window.
- """
- # Create pyqtgraph ImageView object
- self.diffraction_space_widget = pg.ImageView()
- self.diffraction_space_widget.setImage(np.random.random((512,512)))
-
- # Create virtual detector ROI selector
- self.virtual_detector_roi = pg.RectROI([256, 256], [50,50], pen=(3,9))
- self.diffraction_space_widget.getView().addItem(self.virtual_detector_roi)
- self.virtual_detector_roi.sigRegionChanged.connect(self.update_virtual_image)
-
- # Name, show, return
- self.diffraction_space_widget.setWindowTitle('Diffraction Space')
- self.diffraction_space_widget.show()
- return self.diffraction_space_widget
-
- def setup_real_space_widget(self):
- """
- Set up the real space window.
- """
- # Create pyqtgraph ImageView object
- self.real_space_widget = pg.ImageView()
- self.real_space_widget.setImage(np.random.random((512,512)))
-
- # Add point selector connected to displayed diffraction pattern
- self.real_space_point_selector = pg_point_roi(self.real_space_widget.getView())
- self.real_space_point_selector.sigRegionChanged.connect(self.update_diffraction_view)
-
- # Name, show, return
- self.real_space_widget.setWindowTitle('Real Space')
- self.real_space_widget.show()
- return self.real_space_widget
-
- def setup_diffraction_space_control_widget(self):
- """
- Set up the control window for diffraction space.
- """
- #self.diffraction_space_control_widget = load_qt_ui_file(sibling_path(__file__, "diffraction_space_control_widget.ui"))
- self.diffraction_space_control_widget = ControlPanel()
- self.diffraction_space_control_widget.setWindowTitle("Diffraction space")
- self.diffraction_space_control_widget.show()
- self.diffraction_space_control_widget.raise_()
-
- ########## Controls ##########
- # For each control:
- # -create references in self.settings
- # -connect UI changes to updates in self.settings
- # -call methods
- ##############################
-
- # File loading
- self.settings.New('data_filename',dtype='file')
- self.settings.data_filename.connect_to_browse_widgets(self.diffraction_space_control_widget.lineEdit_LoadFile, self.diffraction_space_control_widget.pushButton_BrowseFiles)
- self.settings.data_filename.updated_value.connect(self.load_file)
- #self.diffraction_space_control_widget.pushButton_LoadFile.clicked.connect(self.load_file)
-
- # Scan shape
- self.settings.New('R_Nx', dtype=int, initial=1)
- self.settings.New('R_Ny', dtype=int, initial=1)
-
- self.settings.R_Nx.updated_value.connect(self.update_scan_shape_Nx)
- self.settings.R_Ny.updated_value.connect(self.update_scan_shape_Ny)
-
- self.settings.R_Nx.connect_bidir_to_widget(self.diffraction_space_control_widget.spinBox_Nx)
- self.settings.R_Ny.connect_bidir_to_widget(self.diffraction_space_control_widget.spinBox_Ny)
-
- return self.diffraction_space_control_widget
-
- def setup_real_space_control_widget(self):
- """
- Set up the control window.
- """
- self.real_space_control_widget = load_qt_ui_file(sibling_path(__file__, "real_space_control_widget.ui"))
- self.real_space_control_widget.setWindowTitle("Real space")
- self.real_space_control_widget.show()
- self.real_space_control_widget.raise_()
- return self.real_space_control_widget
-
- def setup_console_widget(self):
- self.kernel_manager = QtInProcessKernelManager()
- self.kernel_manager.start_kernel()
- self.kernel = self.kernel_manager.kernel
- self.kernel.gui = 'qt4'
- self.kernel.shell.push({'np': np, 'app': self})
- self.kernel_client = self.kernel_manager.client()
- self.kernel_client.start_channels()
-
- self.console_widget = RichJupyterWidget()
- self.console_widget.setWindowTitle("4D-STEM IPython Console")
- self.console_widget.kernel_manager = self.kernel_manager
- self.console_widget.kernel_client = self.kernel_client
-
- self.console_widget.show()
- return self.console_widget
-
-
- def setup_geometry(self):
- """
- Arrange windows and their geometries.
- """
- self.diffraction_space_widget.setGeometry(100,0,600,600)
- self.diffraction_space_control_widget.setGeometry(0,0,350,600)
- self.real_space_widget.setGeometry(700,0,600,600)
- self.real_space_control_widget.setGeometry(1150,0,200,600)
- self.console_widget.setGeometry(0,670,1300,170)
-
- self.console_widget.raise_()
- self.real_space_control_widget.raise_()
- self.real_space_widget.raise_()
- self.diffraction_space_widget.raise_()
- self.diffraction_space_control_widget.raise_()
- return
-
- ######### Methods controlling responses to user inputs #########
-
- def load_file(self):
- """
- Loads a file by creating and storing a DataCube object
- """
- fname = self.settings.data_filename.val
- print("Loading file",fname)
-
- # Instantiate DataCube object
- self.datacube = DataCube(fname)
-
- # Update scan shape information
- self.R_N = self.datacube.R_N
- self.settings.R_Nx.update_value(1)
- self.settings.R_Ny.update_value(self.R_N)
-
- self.diffraction_space_widget.setImage(self.data_3Dflattened.swapaxes(1,2))
-
- return
-
- def update_virtual_image(self):
- roi_state = self.virtual_detector_roi.saveState()
- x0,y0 = roi_state['pos']
- slices, transforms = self.virtual_detector_roi.getArraySlice(self.data_3Dflattened, self.diffraction_space_widget.getImageItem())
- slice_x, slice_y, slice_z = slices
- self.real_space_widget.setImage(self.data4D[:,:,slice_y, slice_x].sum(axis=(2,3)).T)
- return
-
- def update_diffraction_view(self):
- roi_state = self.real_space_point_selector.saveState()
- x0,y0 = roi_state['pos']
- xc,yc = x0+1,y0+1
- stack_num = self.settings.R_Nx.val*int(yc)+int(xc)
- self.diffraction_space_widget.setCurrentIndex(stack_num)
- return
-
- def update_scan_shape_Nx(self):
- R_Nx = self.settings.R_Nx.val
- self.settings.R_Ny.update_value(int(self.R_N/R_Nx))
- R_Ny = self.settings.R_Ny.val
- try:
- self.datacube.set_scan_shape(R_Ny, R_Nx)
- except ValueError:
- pass
- if hasattr(self, "virtual_detector_roi"):
- self.update_virtual_image()
- return
-
- def update_scan_shape_Ny(self):
- R_Ny = self.settings.R_Ny.val
- self.settings.R_Nx.update_value(int(self.R_N/R_Ny))
- R_Nx = self.settings.R_Nx.val
- try:
- self.datacube.set_scan_shape(R_Ny, R_Nx)
- except ValueError:
- pass
- if hasattr(self, "virtual_detector_roi"):
- self.update_virtual_image()
- return
-
- def exec_(self):
- return self.qtapp.exec_()
-
-
-
- ####### DEPRECATED ##########
-
- #def update_scan_shape_Nx(self):
- # R_Nx = self.settings.R_Nx.val
- # self.settings.R_Ny.update_value(int(self.R_N/R_Nx))
- # R_Ny = self.settings.R_Ny.val
- # try:
- # self.data4D = self.data_3Dflattened.reshape(R_Ny,R_Nx,self.Q_Ny,self.Q_Nx)
- # except ValueError:
- # pass
- # if hasattr(self, "virtual_detector_roi"):
- # self.update_virtual_image()
- # return
-
- #def update_scan_shape_Ny(self):
- # R_Ny = self.settings.R_Ny.val
- # self.settings.R_Nx.update_value(int(self.R_N/R_Ny))
- # R_Nx = self.settings.R_Nx.val
- # try:
- # self.data4D = self.data_3Dflattened.reshape(R_Ny,R_Nx,self.Q_Ny,self.Q_Nx)
- # except ValueError:
- # pass
- # if hasattr(self, "virtual_detector_roi"):
- # self.update_virtual_image()
- # return
-
-
-
- #def on_stem_pt_roi_change(self):
- # roi_state = self.stem_pt_roi.saveState()
- # x0,y0 = roi_state['pos']
- # xc,yc = x0+1, y0+1
- # stack_num = self.settings.R_Nx.val*int(yc)+int(xc)
- # self.stack_imv.setCurrentIndex(stack_num)
-
- #def on_real_space_roi_change(self):
- # roi_state = self.real_space_roi.saveState()
- # x0,y0 = roi_state['pos']
- # slices, transforms = self.virtual_aperture_roi.getArraySlice(self.data_3Dflattened, self.stack_imv.getImageItem())
- # slice_x, slice_y, slice_z = slices
- # self.stem_imv.setImage(self.data4D[:,:,slice_y, slice_x].sum(axis=(2,3)).T)
-
- #def load_file(self):
- # fname = self.settings.data_filename.val
- # print("Loading file",fname)
- #
- # try:
- # self.dm3f = dm3.DM3(fname, debug=True)
- # self.data_3Dflattened = self.dm3f.imagedata
- # except Exception as err:
- # print("Failed to load", err)
- # self.data_3Dflattened = np.random.rand(100,512,512)
- # self.R_N, self.Q_Ny, self.Q_Nx = self.data_3Dflattened.shape
- #
- # self.diffraction_space_widget.setImage(self.data_3Dflattened.swapaxes(1,2))
- #
- # self.settings.R_Nx.update_value(1)
- # self.settings.R_Ny.update_value(self.R_N)
- # return
-
-
-############### End of class ###############
-
-
-if __name__=="__main__":
- app = DataViewer(sys.argv)
-
- sys.exit(app.exec_())
-
-
-