diff --git a/blacksheep/server/application.py b/blacksheep/server/application.py index de2e945a..7295dd35 100644 --- a/blacksheep/server/application.py +++ b/blacksheep/server/application.py @@ -32,7 +32,6 @@ from blacksheep.common import extend from blacksheep.common.files.asyncfs import FilesHandler from blacksheep.contents import ASGIContent -from blacksheep.exceptions import HTTPException from blacksheep.messages import Request, Response from blacksheep.middlewares import get_middlewares_chain from blacksheep.scribe import send_asgi_response @@ -739,31 +738,22 @@ async def _handle_websocket(self, scope, receive, send): RouteMethod.GET_WS, scope["path"] ) - if route: - ws.route_values = route.values - try: - return await route.handler(ws) - except UnauthorizedError as unauthorized_error: - # If the WebSocket connection was not accepted yet, we close the - # connection with an HTTP Status Code, otherwise we close the connection - # with a WebSocket status code - if ws.accepted: - # Use a WebSocket error code, not an HTTP error code - await ws.close(1005, "Unauthorized") - else: - # Still in handshake phase, we close with an HTTP Status Code - # https://asgi.readthedocs.io/en/latest/specs/www.html#close-send-event - await ws.close(403, str(unauthorized_error)) - except HTTPException as http_exception: - # Same like above - if ws.accepted: - # Use a WebSocket error code, not an HTTP error code - await ws.close(1005, str(http_exception)) - else: - # Still in handshake phase, we close with an HTTP Status Code - # https://asgi.readthedocs.io/en/latest/specs/www.html#close-send-event - await ws.close(http_exception.status, str(http_exception)) - await ws.close() + if route is None: + return await ws.close() + + ws.route_values = route.values + + try: + return await route.handler(ws) + except Exception as exc: + # If WebSocket connection accepted, close + # the connection using WebSocket Internal error code. + if ws.accepted: + return await ws.close(1011, reason=str(exc)) + + # Otherwise, just close the connection, the ASGI server + # will anyway respond 403 to the client. + return await ws.close() async def _handle_http(self, scope, receive, send): assert scope["type"] == "http" diff --git a/itests/app_2.py b/itests/app_2.py index 91f7e5d3..7634a65d 100644 --- a/itests/app_2.py +++ b/itests/app_2.py @@ -150,9 +150,15 @@ async def echo_text_admin_users(websocket: WebSocket): await websocket.send_text(msg) -@app_2.router.ws("/websocket-echo-text-http-exp") +@app_2.router.ws("/websocket-error-before-accept") async def echo_text_http_exp(websocket: WebSocket): - raise BadRequest("Example") + raise RuntimeError("Error before accept") + + +@app_2.router.ws("/websocket-server-error") +async def websocket_server_error(websocket: WebSocket): + await websocket.accept() + raise RuntimeError("Server error") @auth("authenticated") diff --git a/itests/test_server.py b/itests/test_server.py index 905d94ad..abdbdd1e 100644 --- a/itests/test_server.py +++ b/itests/test_server.py @@ -1,14 +1,12 @@ import json -import os import shutil from base64 import urlsafe_b64encode from urllib.parse import unquote from uuid import uuid4 -import pytest import websockets import yaml -from websockets.exceptions import InvalidStatusCode +from websockets.exceptions import ConnectionClosedError, InvalidStatusCode from .client_fixtures import get_static_path from .server_fixtures import * # NoQA @@ -793,7 +791,7 @@ async def test_websocket(server_host, server_port_4, route, data): "route", [ "websocket-echo-text-auth", - "websocket-echo-text-http-exp", + "websocket-error-before-accept", ], ) async def test_websocket_auth(server_host, server_port_2, route): @@ -805,3 +803,16 @@ async def test_websocket_auth(server_host, server_port_2, route): assert error.value.status_code == 403 assert "server rejected" in str(error.value) + + +@pytest.mark.asyncio +async def test_websocket_server_error(server_host, server_port_2): + uri = f"ws://{server_host}:{server_port_2}/websocket-server-error" + + with pytest.raises(ConnectionClosedError) as error: + async with websockets.connect(uri) as ws: + async for _message in ws: + pass + + assert error.value.code == 1011 + assert error.value.reason == "Server error"