diff --git a/imagedephi/gui.py b/imagedephi/gui.py index 9b3db37f..59cd7ba1 100644 --- a/imagedephi/gui.py +++ b/imagedephi/gui.py @@ -5,7 +5,7 @@ from pathlib import Path from fastapi import BackgroundTasks, FastAPI, Form, HTTPException, Request -from fastapi.responses import HTMLResponse, PlainTextResponse +from fastapi.responses import HTMLResponse, PlainTextResponse, RedirectResponse from fastapi.templating import Jinja2Templates from jinja2 import FunctionLoader from starlette.background import BackgroundTask @@ -63,11 +63,25 @@ def reset_shutdown_event() -> None: shutdown_event.clear() -@app.get("/", response_class=HTMLResponse) +@app.get("/", response_class=RedirectResponse) +def home(request: Request) -> str: + # On Windows, there may be multiple roots, so pick the one that's an ancestor of the CWD + # On Linux, this should typically resolve to "/" + root_directory = Path.cwd().root + + # TODO: FastAPI has a bug where a URL object can't be directly returned here + return str( + request.url_for("select_directory").include_query_params( + input_directory=str(root_directory), output_directory=str(root_directory) + ) + ) + + +@app.get("/select-directory", response_class=HTMLResponse) def select_directory( request: Request, - input_directory: Path = Path("/"), # noqa: B008 - output_directory: Path = Path("/"), # noqa: B008 + input_directory: Path, + output_directory: Path, ): # TODO: if input_directory is specified but an empty string, it gets instantiated as the CWD if not input_directory.is_dir(): diff --git a/imagedephi/main.py b/imagedephi/main.py index 268f3565..a06ac286 100644 --- a/imagedephi/main.py +++ b/imagedephi/main.py @@ -8,6 +8,7 @@ import webbrowser import click +from fastapi.datastructures import URL from hypercorn import Config from hypercorn.asyncio import serve import yaml @@ -105,9 +106,11 @@ async def announce_ready() -> None: # To avoid race conditions, ensure that the webserver is # actually running before launching the browser await wait_for_port(port) - url = f"http://{host}:{port}/" - click.echo(f"Server is running at {url} .") - webbrowser.open(url) + home_url = app.url_path_for("home").make_absolute_url( + URL(scheme="http", netloc=f"{host}:{port}") + ) + click.echo(f"Server is running at {home_url} .") + webbrowser.open(str(home_url)) async with asyncio.TaskGroup() as task_group: task_group.create_task(announce_ready()) diff --git a/tests/test_gui.py b/tests/test_gui.py index 47cd4287..84bece6b 100644 --- a/tests/test_gui.py +++ b/tests/test_gui.py @@ -11,8 +11,8 @@ def client() -> TestClient: return TestClient(app) -def test_gui_select_directory(client: TestClient) -> None: - response = client.get("/") +def test_gui_home(client: TestClient) -> None: + response = client.get(app.url_path_for("home"), follow_redirects=True) assert response.status_code == 200 assert "Select Directory" in response.text @@ -35,7 +35,8 @@ def test_gui_select_directory_input_not_found( tmp_path: Path, ) -> None: response = client.get( - app.url_path_for("select_directory"), params={"input_directory": str(tmp_path / "fake")} + app.url_path_for("select_directory"), + params={"input_directory": str(tmp_path / "fake"), "output_directory": str(tmp_path)}, ) assert response.status_code == 404 @@ -47,7 +48,8 @@ def test_gui_select_directory_output_not_found( tmp_path: Path, ) -> None: response = client.get( - app.url_path_for("select_directory"), params={"output_directory": str(tmp_path / "fake")} + app.url_path_for("select_directory"), + params={"input_directory": str(tmp_path), "output_directory": str(tmp_path / "fake")}, ) assert response.status_code == 404