Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Digest Authentification support for the legacy client part only #1111

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 179 additions & 0 deletions src/websockets/headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@
UpgradeProtocol,
)

from urllib.parse import urlparse
import hashlib
import time
import os

from .uri import WebSocketURI

__all__ = [
"build_host",
Expand All @@ -34,6 +40,120 @@

T = TypeVar("T")

# Class from the module requests
class HTTPDigestAuth():
"""Attaches HTTP Digest Authentication to the given Request object."""

def __init__(self, username, password):
self.username = username
self.password = password
self.last_nonce = ''
self.nonce_count = 0
self.chal = {}
self.pos = None

def build_digest_header(self, method, url):
"""
:rtype: str
"""

realm = self.chal['realm']
nonce = self.chal['nonce']
qop = self.chal.get('qop')
algorithm = self.chal.get('algorithm')
opaque = self.chal.get('opaque')
hash_utf8 = None

if algorithm is None:
_algorithm = 'MD5'
else:
_algorithm = algorithm.upper()
# lambdas assume digest modules are imported at the top level
if _algorithm == 'MD5' or _algorithm == 'MD5-SESS':
def md5_utf8(x):
if isinstance(x, str):
x = x.encode('utf-8')
return hashlib.md5(x).hexdigest()
hash_utf8 = md5_utf8
elif _algorithm == 'SHA':
def sha_utf8(x):
if isinstance(x, str):
x = x.encode('utf-8')
return hashlib.sha1(x).hexdigest()
hash_utf8 = sha_utf8
elif _algorithm == 'SHA-256':
def sha256_utf8(x):
if isinstance(x, str):
x = x.encode('utf-8')
return hashlib.sha256(x).hexdigest()
hash_utf8 = sha256_utf8
elif _algorithm == 'SHA-512':
def sha512_utf8(x):
if isinstance(x, str):
x = x.encode('utf-8')
return hashlib.sha512(x).hexdigest()
hash_utf8 = sha512_utf8

KD = lambda s, d: hash_utf8("%s:%s" % (s, d))

if hash_utf8 is None:
return None

# XXX not implemented yet
entdig = None
p_parsed = urlparse(url)
#: path is request-uri defined in RFC 2616 which should not be empty
path = p_parsed.path or "/"
if p_parsed.query:
path += '?' + p_parsed.query

A1 = '%s:%s:%s' % (self.username, realm, self.password)
A2 = '%s:%s' % (method, path)

HA1 = hash_utf8(A1)
HA2 = hash_utf8(A2)

if nonce == self.last_nonce:
self.nonce_count += 1
else:
self.nonce_count = 1
ncvalue = '%08x' % self.nonce_count
s = str(self.nonce_count).encode('utf-8')
s += nonce.encode('utf-8')
s += time.ctime().encode('utf-8')
s += os.urandom(8)

cnonce = (hashlib.sha1(s).hexdigest()[:16])
if _algorithm == 'MD5-SESS':
HA1 = hash_utf8('%s:%s:%s' % (HA1, nonce, cnonce))

if not qop:
respdig = KD(HA1, "%s:%s" % (nonce, HA2))
elif qop == 'auth' or 'auth' in qop.split(','):
noncebit = "%s:%s:%s:%s:%s" % (
nonce, ncvalue, cnonce, 'auth', HA2
)
respdig = KD(HA1, noncebit)
else:
# XXX handle auth-int.
return None

self.last_nonce = nonce

# XXX should the partial digests be encoded too?
base = 'username="%s", realm="%s", nonce="%s", uri="%s", ' \
'response="%s"' % (self.username, realm, nonce, path, respdig)
if opaque:
base += ', opaque="%s"' % opaque
if algorithm:
base += ', algorithm="%s"' % algorithm
if entdig:
base += ', digest="%s"' % entdig
if qop:
base += ', qop="auth", nc=%s, cnonce="%s"' % (ncvalue, cnonce)

return 'Digest %s' % (base)


def build_host(host: str, port: int, secure: bool) -> str:
"""
Expand Down Expand Up @@ -585,3 +705,62 @@ def build_authorization_basic(username: str, password: str) -> str:
user_pass = f"{username}:{password}"
basic_credentials = base64.b64encode(user_pass.encode()).decode()
return "Basic " + basic_credentials

def build_authorization_digest(username: str, password: str, response_header, wsuri: WebSocketURI,) -> str:
"""

Build an ``Authorization`` header for HTTP Digest Auth.

It needs the once challenge that the server sended in the 401 headers response

"""
assert ":" not in username

auth_rec = response_header.get('www-authenticate', '')

auth_rec_splited = auth_rec.split(',')
# Remove the Auth type (Digest or Basic)
first_field = auth_rec_splited[0]
first_field_splited = first_field.split(' ')
auth_type = first_field_splited[0]
auth_rec_splited[0] = first_field_splited[1]

realm = ''
qop = ''
nonce = ''

for field in auth_rec_splited:
field = field.strip()
key, value = field.split('=', 1)
if key == 'realm':
realm = value
elif key == 'qop':
qop = value
elif key == 'nonce':
nonce = value

digest_challenge = HTTPDigestAuth(username, password)

# Remove the double quotes (begining and end)
digest_challenge.chal['realm'] = realm[1:-1]
digest_challenge.chal['nonce'] = nonce[1:-1]
digest_challenge.chal['qop'] = qop[1:-1]

# Reconstruction of the websocket uri for build_digest_header()
if wsuri.secure:
uri = 'wss://'
else:
uri = 'ws://'

if wsuri.host:
uri += wsuri.host

if wsuri.port:
uri += ':' + str(wsuri.port)

if wsuri.resource_name:
uri += wsuri.resource_name

digest_credential = digest_challenge.build_digest_header('GET', wsuri.resource_name)

return digest_credential
65 changes: 65 additions & 0 deletions src/websockets/legacy/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from ..extensions import ClientExtensionFactory, Extension
from ..extensions.permessage_deflate import enable_client_permessage_deflate
from ..headers import (
build_authorization_digest,
build_authorization_basic,
build_extension,
build_host,
Expand Down Expand Up @@ -324,6 +325,70 @@ async def handshake(
if "Location" not in response_headers:
raise InvalidHeader("Location")
raise RedirectHandshake(response_headers["Location"])

elif status_code == 401:
# The Digest method needs a challenge, the server send the 401 response
# with the data of the challenge in the headers
challenge_headers = Headers()

# Extract the authentication type at the beginning of the www-authenticate header
websocket_auth_type = response_headers.get('www-authenticate', '')
websocket_auth_type = websocket_auth_type.split(' ', 1)
websocket_auth_type = websocket_auth_type[0].lower()

if websocket_auth_type == 'digest':

# Start to build the challenge headers
if wsuri.port == (443 if wsuri.secure else 80): # pragma: no cover
challenge_headers["Host"] = wsuri.host
else:
challenge_headers["Host"] = f"{wsuri.host}:{wsuri.port}"

# Build the digest
challenge_headers["Authorization"] = build_authorization_digest(
*wsuri.user_info, response_headers, wsuri
)

# Other headers
if origin is not None:
challenge_headers["Origin"] = origin

# Don't generate a new key with Digest auth, use the existing key provided by the server
key = build_request(challenge_headers, key)

if available_extensions is not None:
extensions_header = build_extension(
[
(extension_factory.name, extension_factory.get_request_params())
for extension_factory in available_extensions
]
)
challenge_headers["Sec-WebSocket-Extensions"] = extensions_header

if available_subprotocols is not None:
protocol_header = build_subprotocol(available_subprotocols)
challenge_headers["Sec-WebSocket-Protocol"] = protocol_header

if extra_headers is not None:
if isinstance(extra_headers, Headers):
extra_headers = extra_headers.raw_items()
elif isinstance(extra_headers, collections.abc.Mapping):
extra_headers = extra_headers.items()
for name, value in extra_headers:
challenge_headers[name] = value

# Last headers for the challenge response
challenge_headers.setdefault("User-Agent", USER_AGENT)

# Send Challenge
self.write_http_request(wsuri.resource_name, challenge_headers)
# Wait response
status_code, response_headers = await self.read_http_response()

else:
# If not digest type, same as not 101 answer
raise InvalidStatusCode(status_code, response_headers)

elif status_code != 101:
raise InvalidStatusCode(status_code, response_headers)

Expand Down
9 changes: 7 additions & 2 deletions src/websockets/legacy/handshake.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,25 @@
__all__ = ["build_request", "check_request", "build_response", "check_response"]


def build_request(headers: Headers) -> str:
def build_request(headers: Headers, digest_key: str = '') -> str:
"""
Build a handshake request to send to the server.

Update request headers passed in argument.

Args:
headers: handshake request headers.
digest_key: the key sended by the server during the handshake, only for Digest Auth
This paramter is optional, if not provided, the key is generated by this function

Returns:
str: ``key`` that must be passed to :func:`check_response`.

"""
key = generate_key()
if not digest_key:
key = generate_key()
else:
key = digest_key
headers["Upgrade"] = "websocket"
headers["Connection"] = "Upgrade"
headers["Sec-WebSocket-Key"] = key
Expand Down