From c67f8b164dacfab3e292c4d735c393aba675fd28 Mon Sep 17 00:00:00 2001 From: Laurent Date: Tue, 24 Sep 2024 10:10:50 +0200 Subject: [PATCH] Add the ssl flag to the pytorch server too. --- moshi/moshi/server.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/moshi/moshi/server.py b/moshi/moshi/server.py index 212ba9c..2dbf106 100644 --- a/moshi/moshi/server.py +++ b/moshi/moshi/server.py @@ -182,6 +182,14 @@ def main(): help="HF repo to look into, defaults Moshiko. " "Use this to select a different pre-trained model.") parser.add_argument("--device", type=str, default="cuda", help="Device on which to run, defaults to 'cuda'.") + parser.add_argument( + "--ssl", + type=str, + help=( + "use https instead of http, this flag should point to a directory " + "that contains valid key.pem and cert.pem files" + ) + ) args = parser.parse_args() seed_all(42424242) @@ -244,12 +252,23 @@ async def handle_root(_): app.router.add_static( "/", path=static_path, follow_symlinks=True, name="static" ) - log("info", f"Access the Web UI directly at http://{args.host}:{args.port}") + protocol = "http" + ssl_context = None + if args.ssl is not None: + import ssl + + ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + cert_file = os.path.join(args.ssl, "cert.pem") + key_file = os.path.join(args.ssl, "key.pem") + ssl_context.load_cert_chain(certfile=cert_file, keyfile=key_file) + protocol = "https" + + log("info", f"Access the Web UI directly at {protocol}://{args.host}:{args.port}") if setup_tunnel is not None: tunnel = setup_tunnel('localhost', args.port, tunnel_token, None) log("info", f"Tunnel started, if executing on a remote GPU, you can use {tunnel}.") log("info", "Note that this tunnel goes through the US and you might experience high latency in Europe.") - web.run_app(app, port=args.port) + web.run_app(app, port=args.port, ssl_context=ssl_context) with torch.no_grad():