Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Fix tests on Twisted trunk. (#16528)
Browse files Browse the repository at this point in the history
Twisted trunk makes a change to the `TLSMemoryBIOFactory` where
the underlying protocol is changed from `TLSMemoryBIOProtocol` to
`BufferingTLSTransport` to improve performance of TLS code (see
twisted/twisted#11989).

In order to properly hook this code up in tests we need to pass the test
reactor's clock into `TLSMemoryBIOFactory` to avoid the global (trial)
reactor being used by default.

Twisted does something similar internally for tests:
https://github.com/twisted/twisted/blob/157cd8e659705940e895d321339d467e76ae9d0a/src/twisted/web/test/test_agent.py#L871-L874
  • Loading branch information
clokep authored Oct 25, 2023
1 parent 95076f7 commit e182dbb
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 111 deletions.
1 change: 1 addition & 0 deletions changelog.d/16528.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix running unit tests on Twisted trunk.
37 changes: 35 additions & 2 deletions tests/http/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,20 @@
import subprocess
from typing import List

from incremental import Version
from zope.interface import implementer

import twisted
from OpenSSL import SSL
from OpenSSL.SSL import Connection
from twisted.internet.address import IPv4Address
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
from twisted.internet.interfaces import (
IOpenSSLServerConnectionCreator,
IProtocolFactory,
IReactorTime,
)
from twisted.internet.ssl import Certificate, trustRootFromCertificates
from twisted.protocols.tls import TLSMemoryBIOProtocol
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.web.client import BrowserLikePolicyForHTTPS # noqa: F401
from twisted.web.iweb import IPolicyForHTTPS # noqa: F401

Expand Down Expand Up @@ -153,6 +159,33 @@ def serverConnectionForTLS(self, tlsProtocol: TLSMemoryBIOProtocol) -> Connectio
return Connection(ctx, None)


def wrap_server_factory_for_tls(
factory: IProtocolFactory, clock: IReactorTime, sanlist: List[bytes]
) -> TLSMemoryBIOFactory:
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
The resultant factory will create a TLS server which presents a certificate
signed by our test CA, valid for the domains in `sanlist`
Args:
factory: protocol factory to wrap
sanlist: list of domains the cert should be valid for
Returns:
interfaces.IProtocolFactory
"""
connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
# Twisted > 23.8.0 has a different API that accepts a clock.
if twisted.version <= Version("Twisted", 23, 8, 0):
return TLSMemoryBIOFactory(
connection_creator, isClient=False, wrappedFactory=factory
)
else:
return TLSMemoryBIOFactory(
connection_creator, isClient=False, wrappedFactory=factory, clock=clock # type: ignore[call-arg]
)


# A dummy address, useful for tests that use FakeTransport and don't care about where
# packets are going to/coming from.
dummy_address = IPv4Address("TCP", "127.0.0.1", 80)
60 changes: 24 additions & 36 deletions tests/http/federation/test_matrix_federation_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
IProtocolFactory,
)
from twisted.internet.protocol import Factory, Protocol
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.protocols.tls import TLSMemoryBIOProtocol
from twisted.web._newclient import ResponseNeverReceived
from twisted.web.client import Agent
from twisted.web.http import HTTPChannel, Request
Expand All @@ -57,11 +57,7 @@
from synapse.util.caches.ttlcache import TTLCache

from tests import unittest
from tests.http import (
TestServerTLSConnectionFactory,
dummy_address,
get_test_ca_cert_file,
)
from tests.http import dummy_address, get_test_ca_cert_file, wrap_server_factory_for_tls
from tests.server import FakeTransport, ThreadedMemoryReactorClock
from tests.utils import checked_cast, default_config

Expand Down Expand Up @@ -125,7 +121,18 @@ def _make_connection(
# build the test server
server_factory = _get_test_protocol_factory()
if ssl:
server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist)
server_factory = wrap_server_factory_for_tls(
server_factory,
self.reactor,
tls_sanlist
or [
b"DNS:testserv",
b"DNS:target-server",
b"DNS:xn--bcher-kva.com",
b"IP:1.2.3.4",
b"IP:::1",
],
)

server_protocol = server_factory.buildProtocol(dummy_address)
assert server_protocol is not None
Expand Down Expand Up @@ -435,8 +442,16 @@ def _do_get_via_proxy(
request.finish()

# now we make another test server to act as the upstream HTTP server.
server_ssl_protocol = _wrap_server_factory_for_tls(
_get_test_protocol_factory()
server_ssl_protocol = wrap_server_factory_for_tls(
_get_test_protocol_factory(),
self.reactor,
sanlist=[
b"DNS:testserv",
b"DNS:target-server",
b"DNS:xn--bcher-kva.com",
b"IP:1.2.3.4",
b"IP:::1",
],
).buildProtocol(dummy_address)

# Tell the HTTP server to send outgoing traffic back via the proxy's transport.
Expand Down Expand Up @@ -1786,33 +1801,6 @@ def _check_logcontext(context: LoggingContextOrSentinel) -> None:
raise AssertionError("Expected logcontext %s but was %s" % (context, current))


def _wrap_server_factory_for_tls(
factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None
) -> TLSMemoryBIOFactory:
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
The resultant factory will create a TLS server which presents a certificate
signed by our test CA, valid for the domains in `sanlist`
Args:
factory: protocol factory to wrap
sanlist: list of domains the cert should be valid for
Returns:
interfaces.IProtocolFactory
"""
if sanlist is None:
sanlist = [
b"DNS:testserv",
b"DNS:target-server",
b"DNS:xn--bcher-kva.com",
b"IP:1.2.3.4",
b"IP:::1",
]

connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
return TLSMemoryBIOFactory(
connection_creator, isClient=False, wrappedFactory=factory
)


def _get_test_protocol_factory() -> IProtocolFactory:
"""Get a protocol Factory which will build an HTTPChannel
Returns:
Expand Down
44 changes: 10 additions & 34 deletions tests/http/test_proxyagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,14 @@
)
from twisted.internet.interfaces import IProtocol, IProtocolFactory
from twisted.internet.protocol import Factory, Protocol
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.protocols.tls import TLSMemoryBIOProtocol
from twisted.web.http import HTTPChannel

from synapse.http.client import BlocklistingReactorWrapper
from synapse.http.connectproxyclient import BasicProxyCredentials
from synapse.http.proxyagent import ProxyAgent, parse_proxy

from tests.http import (
TestServerTLSConnectionFactory,
dummy_address,
get_test_https_policy,
)
from tests.http import dummy_address, get_test_https_policy, wrap_server_factory_for_tls
from tests.server import FakeTransport, ThreadedMemoryReactorClock
from tests.unittest import TestCase
from tests.utils import checked_cast
Expand Down Expand Up @@ -272,7 +268,9 @@ def _make_connection(
the server Protocol returned by server_factory
"""
if ssl:
server_factory = _wrap_server_factory_for_tls(server_factory, tls_sanlist)
server_factory = wrap_server_factory_for_tls(
server_factory, self.reactor, tls_sanlist or [b"DNS:test.com"]
)

server_protocol = server_factory.buildProtocol(dummy_address)
assert server_protocol is not None
Expand Down Expand Up @@ -639,8 +637,8 @@ def _do_https_request_via_proxy(
request.finish()

# now we make another test server to act as the upstream HTTP server.
server_ssl_protocol = _wrap_server_factory_for_tls(
_get_test_protocol_factory()
server_ssl_protocol = wrap_server_factory_for_tls(
_get_test_protocol_factory(), self.reactor, sanlist=[b"DNS:test.com"]
).buildProtocol(dummy_address)

# Tell the HTTP server to send outgoing traffic back via the proxy's transport.
Expand Down Expand Up @@ -806,7 +804,9 @@ def test_https_request_via_uppercase_proxy_with_blocklist(self) -> None:
request.finish()

# now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
ssl_factory = wrap_server_factory_for_tls(
_get_test_protocol_factory(), self.reactor, sanlist=[b"DNS:test.com"]
)
ssl_protocol = ssl_factory.buildProtocol(dummy_address)
assert isinstance(ssl_protocol, TLSMemoryBIOProtocol)
http_server = ssl_protocol.wrappedProtocol
Expand Down Expand Up @@ -870,30 +870,6 @@ def test_proxy_with_https_scheme(self) -> None:
self.assertEqual(proxy_ep._wrappedEndpoint._port, 8888)


def _wrap_server_factory_for_tls(
factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None
) -> TLSMemoryBIOFactory:
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
The resultant factory will create a TLS server which presents a certificate
signed by our test CA, valid for the domains in `sanlist`
Args:
factory: protocol factory to wrap
sanlist: list of domains the cert should be valid for
Returns:
interfaces.IProtocolFactory
"""
if sanlist is None:
sanlist = [b"DNS:test.com"]

connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
return TLSMemoryBIOFactory(
connection_creator, isClient=False, wrappedFactory=factory
)


def _get_test_protocol_factory() -> IProtocolFactory:
"""Get a protocol Factory which will build an HTTPChannel
Expand Down
52 changes: 13 additions & 39 deletions tests/replication/test_multi_media_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@
import os
from typing import Any, Optional, Tuple

from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
from twisted.internet.protocol import Factory
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.http import HTTPChannel
from twisted.web.server import Request
Expand All @@ -27,7 +25,11 @@
from synapse.server import HomeServer
from synapse.util import Clock

from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
from tests.http import (
TestServerTLSConnectionFactory,
get_test_ca_cert_file,
wrap_server_factory_for_tls,
)
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import FakeChannel, FakeTransport, make_request
from tests.test_utils import SMALL_PNG
Expand Down Expand Up @@ -94,7 +96,13 @@ def _get_media_req(
(host, port, client_factory, _timeout, _bindAddress) = clients.pop()

# build the test server
server_tls_protocol = _build_test_server(get_connection_factory())
server_factory = Factory.forProtocol(HTTPChannel)
# Request.finish expects the factory to have a 'log' method.
server_factory.log = _log_request

server_tls_protocol = wrap_server_factory_for_tls(
server_factory, self.reactor, sanlist=[b"DNS:example.com"]
).buildProtocol(None)

# now, tell the client protocol factory to build the client protocol (it will be a
# _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
Expand All @@ -114,7 +122,7 @@ def _get_media_req(
)

# fish the test server back out of the server-side TLS protocol.
http_server: HTTPChannel = server_tls_protocol.wrappedProtocol # type: ignore[assignment]
http_server: HTTPChannel = server_tls_protocol.wrappedProtocol

# give the reactor a pump to get the TLS juices flowing.
self.reactor.pump((0.1,))
Expand Down Expand Up @@ -240,40 +248,6 @@ def _count_remote_thumbnails(self) -> int:
return sum(len(files) for _, _, files in os.walk(path))


def get_connection_factory() -> TestServerTLSConnectionFactory:
# this needs to happen once, but not until we are ready to run the first test
global test_server_connection_factory
if test_server_connection_factory is None:
test_server_connection_factory = TestServerTLSConnectionFactory(
sanlist=[b"DNS:example.com"]
)
return test_server_connection_factory


def _build_test_server(
connection_creator: IOpenSSLServerConnectionCreator,
) -> TLSMemoryBIOProtocol:
"""Construct a test server
This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol
Args:
connection_creator: thing to build SSL connections
Returns:
TLSMemoryBIOProtocol
"""
server_factory = Factory.forProtocol(HTTPChannel)
# Request.finish expects the factory to have a 'log' method.
server_factory.log = _log_request

server_tls_factory = TLSMemoryBIOFactory(
connection_creator, isClient=False, wrappedFactory=server_factory
)

return server_tls_factory.buildProtocol(None)


def _log_request(request: Request) -> None:
"""Implements Factory.log, which is expected by Request.finish"""
logger.info("Completed request %s", request)
12 changes: 12 additions & 0 deletions tests/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,11 @@
from unittest.mock import Mock

import attr
from incremental import Version
from typing_extensions import ParamSpec
from zope.interface import implementer

import twisted
from twisted.internet import address, tcp, threads, udp
from twisted.internet._resolver import SimpleResolverComplexifier
from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
Expand Down Expand Up @@ -474,6 +476,16 @@ def getHostByName(
return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
return succeed(lookups[name])

# In order for the TLS protocol tests to work, modify _get_default_clock
# on newer Twisted versions to use the test reactor's clock.
#
# This is *super* dirty since it is never undone and relies on the next
# test to overwrite it.
if twisted.version > Version("Twisted", 23, 8, 0):
from twisted.protocols import tls

tls._get_default_clock = lambda: self # type: ignore[attr-defined]

self.nameResolver = SimpleResolverComplexifier(FakeResolver())
super().__init__()

Expand Down

0 comments on commit e182dbb

Please sign in to comment.