diff --git a/src/websockets/headers.py b/src/websockets/headers.py index 9ae3035a5..0a3ce78cb 100644 --- a/src/websockets/headers.py +++ b/src/websockets/headers.py @@ -16,6 +16,12 @@ UpgradeProtocol, ) +from urllib.parse import urlparse +import hashlib +import time +import os + +from .uri import WebSocketURI __all__ = [ "build_host", @@ -34,6 +40,120 @@ T = TypeVar("T") +# Class from the module requests +class HTTPDigestAuth(): + """Attaches HTTP Digest Authentication to the given Request object.""" + + def __init__(self, username, password): + self.username = username + self.password = password + self.last_nonce = '' + self.nonce_count = 0 + self.chal = {} + self.pos = None + + def build_digest_header(self, method, url): + """ + :rtype: str + """ + + realm = self.chal['realm'] + nonce = self.chal['nonce'] + qop = self.chal.get('qop') + algorithm = self.chal.get('algorithm') + opaque = self.chal.get('opaque') + hash_utf8 = None + + if algorithm is None: + _algorithm = 'MD5' + else: + _algorithm = algorithm.upper() + # lambdas assume digest modules are imported at the top level + if _algorithm == 'MD5' or _algorithm == 'MD5-SESS': + def md5_utf8(x): + if isinstance(x, str): + x = x.encode('utf-8') + return hashlib.md5(x).hexdigest() + hash_utf8 = md5_utf8 + elif _algorithm == 'SHA': + def sha_utf8(x): + if isinstance(x, str): + x = x.encode('utf-8') + return hashlib.sha1(x).hexdigest() + hash_utf8 = sha_utf8 + elif _algorithm == 'SHA-256': + def sha256_utf8(x): + if isinstance(x, str): + x = x.encode('utf-8') + return hashlib.sha256(x).hexdigest() + hash_utf8 = sha256_utf8 + elif _algorithm == 'SHA-512': + def sha512_utf8(x): + if isinstance(x, str): + x = x.encode('utf-8') + return hashlib.sha512(x).hexdigest() + hash_utf8 = sha512_utf8 + + KD = lambda s, d: hash_utf8("%s:%s" % (s, d)) + + if hash_utf8 is None: + return None + + # XXX not implemented yet + entdig = None + p_parsed = urlparse(url) + #: path is request-uri defined in RFC 2616 which should not be empty + path = p_parsed.path or "/" + if p_parsed.query: + path += '?' + p_parsed.query + + A1 = '%s:%s:%s' % (self.username, realm, self.password) + A2 = '%s:%s' % (method, path) + + HA1 = hash_utf8(A1) + HA2 = hash_utf8(A2) + + if nonce == self.last_nonce: + self.nonce_count += 1 + else: + self.nonce_count = 1 + ncvalue = '%08x' % self.nonce_count + s = str(self.nonce_count).encode('utf-8') + s += nonce.encode('utf-8') + s += time.ctime().encode('utf-8') + s += os.urandom(8) + + cnonce = (hashlib.sha1(s).hexdigest()[:16]) + if _algorithm == 'MD5-SESS': + HA1 = hash_utf8('%s:%s:%s' % (HA1, nonce, cnonce)) + + if not qop: + respdig = KD(HA1, "%s:%s" % (nonce, HA2)) + elif qop == 'auth' or 'auth' in qop.split(','): + noncebit = "%s:%s:%s:%s:%s" % ( + nonce, ncvalue, cnonce, 'auth', HA2 + ) + respdig = KD(HA1, noncebit) + else: + # XXX handle auth-int. + return None + + self.last_nonce = nonce + + # XXX should the partial digests be encoded too? + base = 'username="%s", realm="%s", nonce="%s", uri="%s", ' \ + 'response="%s"' % (self.username, realm, nonce, path, respdig) + if opaque: + base += ', opaque="%s"' % opaque + if algorithm: + base += ', algorithm="%s"' % algorithm + if entdig: + base += ', digest="%s"' % entdig + if qop: + base += ', qop="auth", nc=%s, cnonce="%s"' % (ncvalue, cnonce) + + return 'Digest %s' % (base) + def build_host(host: str, port: int, secure: bool) -> str: """ @@ -585,3 +705,62 @@ def build_authorization_basic(username: str, password: str) -> str: user_pass = f"{username}:{password}" basic_credentials = base64.b64encode(user_pass.encode()).decode() return "Basic " + basic_credentials + +def build_authorization_digest(username: str, password: str, response_header, wsuri: WebSocketURI,) -> str: + """ + + Build an ``Authorization`` header for HTTP Digest Auth. + + It needs the once challenge that the server sended in the 401 headers response + + """ + assert ":" not in username + + auth_rec = response_header.get('www-authenticate', '') + + auth_rec_splited = auth_rec.split(',') + # Remove the Auth type (Digest or Basic) + first_field = auth_rec_splited[0] + first_field_splited = first_field.split(' ') + auth_type = first_field_splited[0] + auth_rec_splited[0] = first_field_splited[1] + + realm = '' + qop = '' + nonce = '' + + for field in auth_rec_splited: + field = field.strip() + key, value = field.split('=', 1) + if key == 'realm': + realm = value + elif key == 'qop': + qop = value + elif key == 'nonce': + nonce = value + + digest_challenge = HTTPDigestAuth(username, password) + + # Remove the double quotes (begining and end) + digest_challenge.chal['realm'] = realm[1:-1] + digest_challenge.chal['nonce'] = nonce[1:-1] + digest_challenge.chal['qop'] = qop[1:-1] + + # Reconstruction of the websocket uri for build_digest_header() + if wsuri.secure: + uri = 'wss://' + else: + uri = 'ws://' + + if wsuri.host: + uri += wsuri.host + + if wsuri.port: + uri += ':' + str(wsuri.port) + + if wsuri.resource_name: + uri += wsuri.resource_name + + digest_credential = digest_challenge.build_digest_header('GET', wsuri.resource_name) + + return digest_credential diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index 6704d16ce..c5cca602d 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -33,6 +33,7 @@ from ..extensions import ClientExtensionFactory, Extension from ..extensions.permessage_deflate import enable_client_permessage_deflate from ..headers import ( + build_authorization_digest, build_authorization_basic, build_extension, build_host, @@ -324,6 +325,70 @@ async def handshake( if "Location" not in response_headers: raise InvalidHeader("Location") raise RedirectHandshake(response_headers["Location"]) + + elif status_code == 401: + # The Digest method needs a challenge, the server send the 401 response + # with the data of the challenge in the headers + challenge_headers = Headers() + + # Extract the authentication type at the beginning of the www-authenticate header + websocket_auth_type = response_headers.get('www-authenticate', '') + websocket_auth_type = websocket_auth_type.split(' ', 1) + websocket_auth_type = websocket_auth_type[0].lower() + + if websocket_auth_type == 'digest': + + # Start to build the challenge headers + if wsuri.port == (443 if wsuri.secure else 80): # pragma: no cover + challenge_headers["Host"] = wsuri.host + else: + challenge_headers["Host"] = f"{wsuri.host}:{wsuri.port}" + + # Build the digest + challenge_headers["Authorization"] = build_authorization_digest( + *wsuri.user_info, response_headers, wsuri + ) + + # Other headers + if origin is not None: + challenge_headers["Origin"] = origin + + # Don't generate a new key with Digest auth, use the existing key provided by the server + key = build_request(challenge_headers, key) + + if available_extensions is not None: + extensions_header = build_extension( + [ + (extension_factory.name, extension_factory.get_request_params()) + for extension_factory in available_extensions + ] + ) + challenge_headers["Sec-WebSocket-Extensions"] = extensions_header + + if available_subprotocols is not None: + protocol_header = build_subprotocol(available_subprotocols) + challenge_headers["Sec-WebSocket-Protocol"] = protocol_header + + if extra_headers is not None: + if isinstance(extra_headers, Headers): + extra_headers = extra_headers.raw_items() + elif isinstance(extra_headers, collections.abc.Mapping): + extra_headers = extra_headers.items() + for name, value in extra_headers: + challenge_headers[name] = value + + # Last headers for the challenge response + challenge_headers.setdefault("User-Agent", USER_AGENT) + + # Send Challenge + self.write_http_request(wsuri.resource_name, challenge_headers) + # Wait response + status_code, response_headers = await self.read_http_response() + + else: + # If not digest type, same as not 101 answer + raise InvalidStatusCode(status_code, response_headers) + elif status_code != 101: raise InvalidStatusCode(status_code, response_headers) diff --git a/src/websockets/legacy/handshake.py b/src/websockets/legacy/handshake.py index 569937bb9..9a33c8ad9 100644 --- a/src/websockets/legacy/handshake.py +++ b/src/websockets/legacy/handshake.py @@ -14,7 +14,7 @@ __all__ = ["build_request", "check_request", "build_response", "check_response"] -def build_request(headers: Headers) -> str: +def build_request(headers: Headers, digest_key: str = '') -> str: """ Build a handshake request to send to the server. @@ -22,12 +22,17 @@ def build_request(headers: Headers) -> str: Args: headers: handshake request headers. + digest_key: the key sended by the server during the handshake, only for Digest Auth + This paramter is optional, if not provided, the key is generated by this function Returns: str: ``key`` that must be passed to :func:`check_response`. """ - key = generate_key() + if not digest_key: + key = generate_key() + else: + key = digest_key headers["Upgrade"] = "websocket" headers["Connection"] = "Upgrade" headers["Sec-WebSocket-Key"] = key