Skip to content

Commit

Permalink
Make restapi ports dynamic and non-conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Nov 21, 2024
1 parent 81895bb commit 3bf2c72
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 8 deletions.
13 changes: 11 additions & 2 deletions tests/restapi/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,25 @@
@pytest.fixture(scope='function')
def restapi_server():
"""Make REST API server"""
import socket
from aiida.restapi.common.config import CLI_DEFAULTS
from aiida.restapi.run_api import configure_api
from werkzeug.serving import make_server

def _restapi_server(restapi=None):
# Dynamically find a free port
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.bind(('', 0)) # Bind to a free port provided by the OS
_, port = sock.getsockname() # Get the dynamically assigned port

if restapi is None:
flask_restapi = configure_api()
else:
flask_restapi = configure_api(flask_api=restapi)

return make_server(
host=CLI_DEFAULTS['HOST_NAME'],
port=int(CLI_DEFAULTS['PORT']),
port=port,
app=flask_restapi.app,
threaded=True,
processes=1,
Expand All @@ -43,7 +49,10 @@ def _restapi_server(restapi=None):
def server_url():
from aiida.restapi.common.config import API_CONFIG, CLI_DEFAULTS

return f"http://{CLI_DEFAULTS['HOST_NAME']}:{CLI_DEFAULTS['PORT']}{API_CONFIG['PREFIX']}"
def _server_url(hostname: str | None = None, port: int | None = None):
return f"http://{hostname or CLI_DEFAULTS['HOST_NAME']}:{port or CLI_DEFAULTS['PORT']}{API_CONFIG['PREFIX']}"

return _server_url


@pytest.fixture
Expand Down
8 changes: 6 additions & 2 deletions tests/restapi/test_identifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,11 @@ def test_full_type_unregistered(process_class, restapi_server, server_url):
server = restapi_server()
server_thread = Thread(target=server.serve_forever)

_server_url = server_url(port=server.server_port)

try:
server_thread.start()
type_count_response = requests.get(f'{server_url}/nodes/full_types', timeout=10)
type_count_response = requests.get(f'{_server_url}/nodes/full_types', timeout=10)
finally:
server.shutdown()

Expand Down Expand Up @@ -188,9 +190,11 @@ def test_full_type_backwards_compatibility(node_class, restapi_server, server_ur
server = restapi_server()
server_thread = Thread(target=server.serve_forever)

_server_url = server_url(port=server.server_port)

try:
server_thread.start()
type_count_response = requests.get(f'{server_url}/nodes/full_types', timeout=10)
type_count_response = requests.get(f'{_server_url}/nodes/full_types', timeout=10)
finally:
server.shutdown()

Expand Down
6 changes: 4 additions & 2 deletions tests/restapi/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,12 @@ def test_count_consistency(restapi_server, server_url):
server = restapi_server()
server_thread = Thread(target=server.serve_forever)

_server_url = server_url(port=server.server_port)

try:
server_thread.start()
type_count_response = requests.get(f'{server_url}/nodes/full_types_count', timeout=10)
statistics_response = requests.get(f'{server_url}/nodes/statistics', timeout=10)
type_count_response = requests.get(f'{_server_url}/nodes/full_types_count', timeout=10)
statistics_response = requests.get(f'{_server_url}/nodes/statistics', timeout=10)
finally:
server.shutdown()

Expand Down
6 changes: 4 additions & 2 deletions tests/restapi/test_threaded_restapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,19 @@ def test_run_threaded_server(restapi_server, server_url, aiida_localhost):
This test will fail, if database connections are not being properly closed by the end-point calls.
"""
server = restapi_server()
computer_id = aiida_localhost.uuid

# Create a thread that will contain the running server,
# since we do not wish to block the main thread
server_thread = Thread(target=server.serve_forever)
_server_url = server_url(port=server.server_port)

computer_id = aiida_localhost.uuid

try:
server_thread.start()

for _ in range(NO_OF_REQUESTS):
response = requests.get(f'{server_url}/computers/{computer_id}', timeout=10)
response = requests.get(f'{_server_url}/computers/{computer_id}', timeout=10)

assert response.status_code == 200

Expand Down

0 comments on commit 3bf2c72

Please sign in to comment.