Skip to content

Commit

Permalink
Fix WebSocket exception handling
Browse files Browse the repository at this point in the history
  • Loading branch information
Klavionik committed Dec 14, 2023
1 parent 5f9bc45 commit b725631
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 32 deletions.
42 changes: 16 additions & 26 deletions blacksheep/server/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from blacksheep.common import extend
from blacksheep.common.files.asyncfs import FilesHandler
from blacksheep.contents import ASGIContent
from blacksheep.exceptions import HTTPException
from blacksheep.messages import Request, Response
from blacksheep.middlewares import get_middlewares_chain
from blacksheep.scribe import send_asgi_response
Expand Down Expand Up @@ -739,31 +738,22 @@ async def _handle_websocket(self, scope, receive, send):
RouteMethod.GET_WS, scope["path"]
)

if route:
ws.route_values = route.values
try:
return await route.handler(ws)
except UnauthorizedError as unauthorized_error:
# If the WebSocket connection was not accepted yet, we close the
# connection with an HTTP Status Code, otherwise we close the connection
# with a WebSocket status code
if ws.accepted:
# Use a WebSocket error code, not an HTTP error code
await ws.close(1005, "Unauthorized")
else:
# Still in handshake phase, we close with an HTTP Status Code
# https://asgi.readthedocs.io/en/latest/specs/www.html#close-send-event
await ws.close(403, str(unauthorized_error))
except HTTPException as http_exception:
# Same like above
if ws.accepted:
# Use a WebSocket error code, not an HTTP error code
await ws.close(1005, str(http_exception))
else:
# Still in handshake phase, we close with an HTTP Status Code
# https://asgi.readthedocs.io/en/latest/specs/www.html#close-send-event
await ws.close(http_exception.status, str(http_exception))
await ws.close()
if route is None:
return await ws.close()

ws.route_values = route.values

try:
return await route.handler(ws)
except Exception as exc:
# If WebSocket connection accepted, close
# the connection using WebSocket Internal error code.
if ws.accepted:
return await ws.close(1011, reason=str(exc))

# Otherwise, just close the connection, the ASGI server
# will anyway respond 403 to the client.
return await ws.close()

async def _handle_http(self, scope, receive, send):
assert scope["type"] == "http"
Expand Down
10 changes: 8 additions & 2 deletions itests/app_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,15 @@ async def echo_text_admin_users(websocket: WebSocket):
await websocket.send_text(msg)


@app_2.router.ws("/websocket-echo-text-http-exp")
@app_2.router.ws("/websocket-error-before-accept")
async def echo_text_http_exp(websocket: WebSocket):
raise BadRequest("Example")
raise RuntimeError("Error before accept")


@app_2.router.ws("/websocket-server-error")
async def websocket_server_error(websocket: WebSocket):
await websocket.accept()
raise RuntimeError("Server error")


@auth("authenticated")
Expand Down
19 changes: 15 additions & 4 deletions itests/test_server.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import json
import os
import shutil
from base64 import urlsafe_b64encode
from urllib.parse import unquote
from uuid import uuid4

import pytest
import websockets
import yaml
from websockets.exceptions import InvalidStatusCode
from websockets.exceptions import ConnectionClosedError, InvalidStatusCode

from .client_fixtures import get_static_path
from .server_fixtures import * # NoQA
Expand Down Expand Up @@ -793,7 +791,7 @@ async def test_websocket(server_host, server_port_4, route, data):
"route",
[
"websocket-echo-text-auth",
"websocket-echo-text-http-exp",
"websocket-error-before-accept",
],
)
async def test_websocket_auth(server_host, server_port_2, route):
Expand All @@ -805,3 +803,16 @@ async def test_websocket_auth(server_host, server_port_2, route):

assert error.value.status_code == 403
assert "server rejected" in str(error.value)


@pytest.mark.asyncio
async def test_websocket_server_error(server_host, server_port_2):
uri = f"ws://{server_host}:{server_port_2}/websocket-server-error"

with pytest.raises(ConnectionClosedError) as error:
async with websockets.connect(uri) as ws:
async for _message in ws:
pass

assert error.value.code == 1011
assert error.value.reason == "Server error"

0 comments on commit b725631

Please sign in to comment.