diff --git a/.gitignore b/.gitignore index 64707625..618c5f77 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,5 @@ __pycache__ outputs/ datasets/* !datasets/sacre_coeur/ +build +hloc/third_party diff --git a/hloc/extractors/d2net.py b/hloc/extractors/d2net.py index 64af5f9e..1afc8946 100644 --- a/hloc/extractors/d2net.py +++ b/hloc/extractors/d2net.py @@ -1,11 +1,11 @@ -import sys +import os, sys from pathlib import Path import subprocess import torch from ..utils.base_model import BaseModel -d2net_path = Path(__file__).parent / '../../third_party/d2net' +d2net_path = os.path.dirname(__import__('hloc').__file__) + '/third_party/d2net' sys.path.append(str(d2net_path)) from lib.model_test import D2Net as _D2Net from lib.pyramid import process_multiscale diff --git a/hloc/extractors/dir.py b/hloc/extractors/dir.py index 4290135d..81bf8afa 100644 --- a/hloc/extractors/dir.py +++ b/hloc/extractors/dir.py @@ -1,4 +1,4 @@ -import sys +import os, sys from pathlib import Path import torch from zipfile import ZipFile @@ -9,7 +9,7 @@ from ..utils.base_model import BaseModel sys.path.append(str( - Path(__file__).parent / '../../third_party/deep-image-retrieval')) + os.path.dirname(__import__('hloc').__file__) + '/third_party/d2net/deep-image-retrieval')) os.environ['DB_ROOT'] = '' # required by dirtorch from dirtorch.utils import common # noqa: E402 diff --git a/hloc/extractors/r2d2.py b/hloc/extractors/r2d2.py index e1465ee8..a64f6943 100644 --- a/hloc/extractors/r2d2.py +++ b/hloc/extractors/r2d2.py @@ -1,10 +1,10 @@ -import sys +import os, sys from pathlib import Path import torchvision.transforms as tvf from ..utils.base_model import BaseModel -r2d2_path = Path(__file__).parent / "../../third_party/r2d2" +r2d2_path = os.path.dirname(__import__('hloc').__file__) + "/third_party/r2d2" sys.path.append(str(r2d2_path)) from extract import load_network, NonMaxSuppression, extract_multiscale diff --git a/hloc/extractors/superpoint.py b/hloc/extractors/superpoint.py index 739246a1..4aa6b31a 100644 --- a/hloc/extractors/superpoint.py +++ b/hloc/extractors/superpoint.py @@ -1,10 +1,10 @@ -import sys +import os, sys from pathlib import Path import torch from ..utils.base_model import BaseModel -sys.path.append(str(Path(__file__).parent / '../../third_party')) +sys.path.append(str(os.path.dirname(__import__('hloc').__file__) + '/third_party/')) from SuperGluePretrainedNetwork.models import superpoint # noqa E402 diff --git a/hloc/matchers/superglue.py b/hloc/matchers/superglue.py index 916f9785..616c8b1c 100644 --- a/hloc/matchers/superglue.py +++ b/hloc/matchers/superglue.py @@ -1,9 +1,9 @@ -import sys +import os, sys from pathlib import Path from ..utils.base_model import BaseModel -sys.path.append(str(Path(__file__).parent / '../../third_party')) +sys.path.append(str(os.path.dirname(__import__('hloc').__file__) + '/third_party/')) from SuperGluePretrainedNetwork.models.superglue import SuperGlue as SG diff --git a/setup.py b/setup.py index 0561844e..9baff070 100644 --- a/setup.py +++ b/setup.py @@ -1,8 +1,35 @@ from pathlib import Path from setuptools import setup, find_packages +from setuptools.command.install import install as _install +from setuptools.command.develop import develop as _develop +from shutil import copytree +import os +import site -description = ['Tools and baselines for visual localization and mapping'] +class InstallAndCopy(_install): + def run(self): + # Call the standard install procedure + _install.run(self) + + # Define the source and destination for the thirdparty folder + source = Path(__file__).parent / 'third_party' + destination = Path(site.getsitepackages()[0]) / 'hloc' / 'third_party' + + # Copy the thirdparty folder + if source.exists() and not destination.exists(): + copytree(source, destination) + +class CustomDevelop(_develop): + def run(self): + # Run standard develop command + _develop.run(self) + + # create symlink in hloc folder to make it work in editable mode + if not os.path.exists(Path(__file__).parent / 'hloc' / 'third_party'): + os.symlink(Path(__file__).parent / 'third_party', Path(__file__).parent / 'hloc' / 'third_party') + +# Read in various files for the setup function root = Path(__file__).parent with open(str(root / 'README.md'), 'r', encoding='utf-8') as f: readme = f.read() @@ -11,14 +38,17 @@ with open(str(root / 'requirements.txt'), 'r') as f: dependencies = f.read().split('\n') +# Setup function setup( name='hloc', version=version, packages=find_packages(), python_requires='>=3.6', install_requires=dependencies, + cmdclass={'install': InstallAndCopy, + 'develop': CustomDevelop}, # Use the custom install class author='Paul-Edouard Sarlin', - description=description, + description='Tools and baselines for visual localization and mapping', long_description=readme, long_description_content_type="text/markdown", url='https://github.com/cvg/Hierarchical-Localization/',