From ccbe247d9d3e1efbf27294e87cd8694ee4a18ab4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Alexandre=20D=C3=A9fossez?=
Date: Tue, 17 Sep 2024 13:39:31 +0200
Subject: [PATCH] refacto

---
 moshi/README.md                     |   1 +
 moshi/moshi/client.py               |  32 +++++-
 moshi/moshi/client_utils.py         |   8 +-
 moshi/moshi/models/__init__.py      |  14 +--
 moshi/moshi/models/compression.py   |  25 +++--
 moshi/moshi/models/loaders.py       |  86 +++++++++------
 moshi/moshi/modules/seanet.py       |   6 +-
 moshi/moshi/modules/transformer.py  |  13 +--
 moshi/moshi/quantization/base.py    |   4 +-
 moshi/moshi/quantization/core_vq.py |   7 +-
 moshi/moshi/quantization/vq.py      |   4 +-
 moshi/moshi/server.py               | 164 ++++++++++++++--------------
 moshi/pyproject.toml                |  10 +-
 moshi/setup.cfg                     |   1 +
 14 files changed, 214 insertions(+), 161 deletions(-)
 create mode 100644 moshi/README.md

diff --git a/moshi/README.md b/moshi/README.md
new file mode 100644
index 0000000..022ce8d
--- /dev/null
+++ b/moshi/README.md
@@ -0,0 +1 @@
+# moshi - pytorch
diff --git a/moshi/moshi/client.py b/moshi/moshi/client.py
index 2872738..cba9048 100644
--- a/moshi/moshi/client.py
+++ b/moshi/moshi/client.py
@@ -1,16 +1,17 @@
 # Copyright (c) Kyutai, all rights reserved.
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+"""Client for the Moshi server."""
 import argparse
 import asyncio
 import queue
 import sys
 
+import aiohttp
 import numpy as np
 import sphn
 import sounddevice as sd
-import aiohttp
 
 from .client_utils import AnyPrinter, Printer, RawPrinter
 
@@ -141,7 +142,27 @@ async def run(self) -> None:
 
 
 async def run(printer: AnyPrinter, args):
-    uri = f"ws://{args.host}:{args.port}/api/chat"
+    if args.url is None:
+        proto = "ws"
+        if args.https:
+            proto += "s"
+        uri = f"{proto}://{args.host}:{args.port}/api/chat"
+    else:
+        proto = "wss"
+        if '://' in args.url:
+            proto, without_proto = args.url.split('://', 1)
+            if proto in ['ws', 'http']:
+                proto = "ws"
+            elif proto in ['wss', 'https']:
+                proto = "wss"
+            else:
+                printer.log("error", f"The provided URL {args.url} seems to contain a protocol but it is unknown.")
+                sys.exit(1)
+        else:
+            without_proto = args.url
+        uri = f"{proto}://{without_proto}/api/chat"
+
+    printer.log("info", f"Connecting to {uri}.")
     async with aiohttp.ClientSession() as session:
         async with session.ws_connect(uri) as ws:
             printer.log("info", "connected!")
@@ -152,8 +173,11 @@ async def run(printer: AnyPrinter, args):
 
 def main():
     parser = argparse.ArgumentParser("client_opus")
-    parser.add_argument("--host", default="localhost", type=str)
-    parser.add_argument("--port", default=8998, type=int)
+    parser.add_argument("--host", default="localhost", type=str, help="Hostname to connect to.")
+    parser.add_argument("--port", default=8998, type=int, help="Port to connect to.")
+    parser.add_argument("--https", action='store_true',
+                        help="Set this flag to use an HTTPS connection.")
+    parser.add_argument("--url", type=str, help='Provide a full URL directly, e.g. to a gradio tunnel.')
 
     args = parser.parse_args()
     printer: AnyPrinter
diff --git a/moshi/moshi/client_utils.py b/moshi/moshi/client_utils.py
index c7aaa92..0bc37f5 100644
--- a/moshi/moshi/client_utils.py
+++ b/moshi/moshi/client_utils.py
@@ -1,6 +1,8 @@
 # Copyright (c) Kyutai, all rights reserved.
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+"""Utilities for the command line client, in particular for handling interactions with the terminal.
+""" from dataclasses import dataclass import sys @@ -14,11 +16,11 @@ def colorize(text, color): def make_log(level: str, msg: str) -> str: if level == "warning": - prefix = colorize("Warning:", "1;31") + prefix = colorize("[Warn]", "1;31") elif level == "info": - prefix = colorize("Info:", "1;34") + prefix = colorize("[Info]", "1;34") elif level == "error": - prefix = colorize("Error:", "1;31") + prefix = colorize("[Err ]", "1;31") else: raise ValueError(f"Unknown level {level}") return prefix + " " + msg diff --git a/moshi/moshi/models/__init__.py b/moshi/moshi/models/__init__.py index 1fcf526..5501848 100644 --- a/moshi/moshi/models/__init__.py +++ b/moshi/moshi/models/__init__.py @@ -1,20 +1,14 @@ # Copyright (c) Kyutai, all rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. - -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. """ -Models for EnCodec, AudioGen, MusicGen, as well as the generic LMModel. +Models for the compression model Moshi, """ # flake8: noqa -from .encodec import ( +from .compression import ( CompressionModel, - EncodecModel, + MimiModel, ) from .lm import LMModel, LMGen -from .moshi_ import get_encodec, get_lm +from .loaders import get_mimi, get_moshi_lm diff --git a/moshi/moshi/models/compression.py b/moshi/moshi/models/compression.py index 02d1723..9cf9996 100644 --- a/moshi/moshi/models/compression.py +++ b/moshi/moshi/models/compression.py @@ -2,13 +2,15 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +# Part of this file is adapted from encodec.py in https://github.com/facebookresearch/audiocraft +# released under the following license. # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -"""Compression models or wrapper around existing models. -Also defines the main interface that a model must follow to be usable as an audio tokenizer. +"""Compression models or wrapper around existing models. In particular, provides the implementation +for Mimi. Also defines the main interface that a model must follow to be usable as an audio tokenizer. """ from abc import abstractmethod @@ -19,7 +21,6 @@ import torch from torch import nn -from torch.nn import functional as F from ..quantization import ( @@ -46,12 +47,12 @@ def forward(self, x: torch.Tensor) -> QuantizedResult: ... @abstractmethod def encode(self, x: torch.Tensor) -> torch.Tensor: - """See `EncodecModel.encode`.""" + """See `MimiModel.encode`.""" ... @abstractmethod def decode(self, codes: torch.Tensor) -> torch.Tensor: - """See `EncodecModel.decode`.""" + """See `MimiModel.decode`.""" ... @abstractmethod @@ -90,7 +91,7 @@ def set_num_codebooks(self, n: int): @dataclass -class _EncodecState: +class _MimiState: graphed_tr_enc: CUDAGraphed | None graphed_tr_dec: CUDAGraphed | None @@ -98,8 +99,8 @@ def reset(self): pass -class EncodecModel(CompressionModel[_EncodecState]): - """Encodec model operating on the raw waveform. +class MimiModel(CompressionModel[_MimiState]): + """Mimi model operating on the raw waveform. Args: encoder (nn.Module): Encoder network. 
@@ -122,6 +123,7 @@ class EncodecModel(CompressionModel[_EncodecState]):
         torch_compile_encoder_decoder (bool): if True, uses torch.compile on the encoder / decoder.
             Deactivated by default for training as this is incompatible at the moment with weight norm.
             See https://github.com/pytorch/pytorch/issues/121902
+            Also, this seems to work well with torch 2.2.0, but to fail completely with 2.4.0.
     """
 
     def __init__(
@@ -217,14 +219,14 @@ def __init__(
             channel_wise=upsample_channel_wise_bug,
         )
 
-    def _init_streaming_state(self, batch_size: int) -> _EncodecState:
+    def _init_streaming_state(self, batch_size: int) -> _MimiState:
         graphed_tr_dec = None
         graphed_tr_enc = None
         if self.encoder_transformer is not None:
             graphed_tr_enc = CUDAGraphed(self.encoder_transformer)
         if self.decoder_transformer is not None:
             graphed_tr_dec = CUDAGraphed(self.decoder_transformer)
-        return _EncodecState(graphed_tr_enc, graphed_tr_dec)
+        return _MimiState(graphed_tr_enc, graphed_tr_dec)
 
     @property
     def channels(self) -> int:
@@ -368,7 +370,8 @@ def encode(self, x: torch.Tensor) -> torch.Tensor:
             x (torch.Tensor): Float tensor of shape [B, C, T]
 
         Returns:
-            codes (torch.Tensor): an int tensor of shape [B, K, T] with K the number of codebooks used and T the timestep.
+            codes (torch.Tensor): an int tensor of shape [B, K, T]
+                with K the number of codebooks used and T the timestep.
         """
         emb = self._encode_to_unquantized_latent(x)
         codes = self.quantizer.encode(emb)
diff --git a/moshi/moshi/models/loaders.py b/moshi/moshi/models/loaders.py
index d136200..d629ead 100644
--- a/moshi/moshi/models/loaders.py
+++ b/moshi/moshi/models/loaders.py
@@ -1,21 +1,30 @@
 # Copyright (c) Kyutai, all rights reserved.
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+"""Retrieves the pretrained models for Moshi and Mimi.""" +from pathlib import Path +import typing as tp -from ..modules import SEANetEncoder, SEANetDecoder, transformer -from .encodec import EncodecModel +from huggingface_hub import hf_hub_download +from safetensors.torch import load_model +import torch + +from .compression import MimiModel from .lm import LMModel +from ..modules import SEANetEncoder, SEANetDecoder, transformer from ..quantization import SplitResidualVectorQuantizer -import torch -from safetensors.torch import load_model -from pathlib import Path -import typing as tp SAMPLE_RATE = 24000 FRAME_RATE = 12.5 +HF_REPO = 'kmhf' +_MODEL_REGISTRY = { + 'mimi': 'tokenizer-e351c8d8-checkpoint125.safetensors', + 'moshiko': 'moshiko_pt_301e30bf@120.safetensors', + 'moshika': 'moshika_pt_3d736a96@120.safetensors', +} -seanet_kwargs = { +_seanet_kwargs = { "channels": 1, "dimension": 512, "causal": True, @@ -35,15 +44,15 @@ "ratios": [8, 6, 5, 4], "true_skip": True, } -quantizer_kwargs = { +_quantizer_kwargs = { "dimension": 256, "n_q": 32, "bins": 2048, - "input_dimension": seanet_kwargs["dimension"], - "output_dimension": seanet_kwargs["dimension"], + "input_dimension": _seanet_kwargs["dimension"], + "output_dimension": _seanet_kwargs["dimension"], } -transformer_kwargs = { - "d_model": seanet_kwargs["dimension"], +_transformer_kwargs = { + "d_model": _seanet_kwargs["dimension"], "num_heads": 8, "num_layers": 8, "causal": True, @@ -55,17 +64,17 @@ "norm": "layer_norm", "positional_embedding": "rope", "dim_feedforward": 2048, - "input_dimension": seanet_kwargs["dimension"], - "output_dimensions": [seanet_kwargs["dimension"]], + "input_dimension": _seanet_kwargs["dimension"], + "output_dimensions": [_seanet_kwargs["dimension"]], } -lm_kwargs = { +_lm_kwargs = { "dim": 4096, "text_card": 32000, "existing_text_padding_id": 3, "n_q": 16, "dep_q": 8, - "card": quantizer_kwargs["bins"], + "card": _quantizer_kwargs["bins"], "num_heads": 32, "num_layers": 32, "hidden_scale": 4.125, @@ -92,24 +101,39 @@ } -def _is_safetensors(filename: tp.Union[str, Path]) -> bool: - filename = Path(filename) - return filename.suffix in (".safetensors", ".sft", ".sfts") +def _is_safetensors(path: Path | str) -> bool: + return Path(path).suffix in (".safetensors", ".sft", ".sfts") + + +def resolve_model_checkpoint(name: str, hf_repo: str = HF_REPO, allow_local_file: bool = False) -> Path: + """Load a model checkpoint from HF, + potentially resolving `name` if it is a known alias (mimi, moshiko, moshika). + If `allow_local_file` is True, then if a file `name` exists, it will be used instead. 
+ """ + if name in _MODEL_REGISTRY: + filename = _MODEL_REGISTRY[name] + elif allow_local_file and Path(name).exists(): + return Path(name) + else: + filename = name + return Path(hf_hub_download(hf_repo, filename)) -def get_encodec(filename: tp.Union[str, Path], device): - encoder = SEANetEncoder(**seanet_kwargs) - decoder = SEANetDecoder(**seanet_kwargs) +def get_mimi(filename: tp.Union[str, Path], + device: torch.device | str = 'cpu') -> MimiModel: + """Return a pretrained Mimi model.""" + encoder = SEANetEncoder(**_seanet_kwargs) + decoder = SEANetDecoder(**_seanet_kwargs) encoder_transformer = transformer.ProjectedTransformer( - device=device, **transformer_kwargs + device=device, **_transformer_kwargs ) decoder_transformer = transformer.ProjectedTransformer( - device=device, **transformer_kwargs + device=device, **_transformer_kwargs ) quantizer = SplitResidualVectorQuantizer( - **quantizer_kwargs, + **_quantizer_kwargs, ) - model = EncodecModel( + model = MimiModel( encoder, decoder, quantizer, @@ -126,21 +150,19 @@ def get_encodec(filename: tp.Union[str, Path], device): if _is_safetensors(filename): load_model(model, filename) else: - pkg = torch.load( - filename, - "cpu", - ) + pkg = torch.load(filename, "cpu") model.load_state_dict(pkg["model"]) model.set_num_codebooks(8) return model -def get_lm(filename: tp.Union[str, Path], device): +def get_moshi_lm(filename: tp.Union[str, Path], + device: torch.device | str = 'cpu') -> LMModel: dtype = torch.bfloat16 model = LMModel( device=device, dtype=dtype, - **lm_kwargs, + **_lm_kwargs, ).to(device=device, dtype=dtype) model.eval() if _is_safetensors(filename): diff --git a/moshi/moshi/modules/seanet.py b/moshi/moshi/modules/seanet.py index 0fe706b..1d8ff28 100644 --- a/moshi/moshi/modules/seanet.py +++ b/moshi/moshi/modules/seanet.py @@ -159,8 +159,7 @@ def __init__( self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks self.disable_norm_outer_blocks = disable_norm_outer_blocks assert ( - self.disable_norm_outer_blocks >= 0 - and self.disable_norm_outer_blocks <= self.n_blocks + self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks ), ( "Number of blocks for which to disable norm is invalid." "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0." @@ -307,8 +306,7 @@ def __init__( self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks self.disable_norm_outer_blocks = disable_norm_outer_blocks assert ( - self.disable_norm_outer_blocks >= 0 - and self.disable_norm_outer_blocks <= self.n_blocks + self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks ), ( "Number of blocks for which to disable norm is invalid." "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0." 
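The loaders refactor above replaces get_encodec/get_lm with get_mimi/get_moshi_lm plus alias
resolution. A minimal usage sketch of the new entry points, assuming the patched moshi package is
installed and the HF repo above is reachable; the local path './mimi.safetensors' is purely
hypothetical:

    from moshi.models import loaders

    # A registered alias resolves through _MODEL_REGISTRY and is fetched from the HF repo.
    ckpt = loaders.resolve_model_checkpoint("mimi", loaders.HF_REPO)
    mimi = loaders.get_mimi(ckpt, device="cpu")

    # With allow_local_file=True, an existing local path bypasses the download.
    # ckpt = loaders.resolve_model_checkpoint("./mimi.safetensors", allow_local_file=True)
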
diff --git a/moshi/moshi/modules/transformer.py b/moshi/moshi/modules/transformer.py index d3d0ede..212d721 100644 --- a/moshi/moshi/modules/transformer.py +++ b/moshi/moshi/modules/transformer.py @@ -240,10 +240,7 @@ def reset(self): def complete(self, k: torch.Tensor, v: torch.Tensor) -> KVCacheResult: assert k.shape[:-1] == v.shape[:-1], (k.shape, v.shape) B, H, T, D = k.shape - indexes = ( - torch.arange(T, device=self.end_offset.device, dtype=self.end_offset.dtype) - + self.end_offset - ) + indexes = torch.arange(T, device=self.end_offset.device, dtype=self.end_offset.dtype) + self.end_offset indexes = indexes % self.capacity self.cache[0].index_copy_(2, indexes, k) self.cache[1].index_copy_(2, indexes, v) @@ -485,8 +482,8 @@ def __init__( context=context, rope=rope, weights_per_step=weights_per_step, - **attn_kwargs, - **factory_kwargs, + **attn_kwargs, # type: ignore + **factory_kwargs, # type: ignore ) # type: ignore self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs) self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs) @@ -542,8 +539,8 @@ def __init__( self.layer_scale_1 = nn.Identity() self.layer_scale_2 = nn.Identity() else: - self.layer_scale_1 = LayerScale(d_model, layer_scale, **factory_kwargs) - self.layer_scale_2 = LayerScale(d_model, layer_scale, **factory_kwargs) + self.layer_scale_1 = LayerScale(d_model, layer_scale, **factory_kwargs) # type: ignore + self.layer_scale_2 = LayerScale(d_model, layer_scale, **factory_kwargs) # type: ignore def _init_streaming_state(self, batch_size: int) -> _LayerState: return _LayerState(offset_cpu=0) diff --git a/moshi/moshi/quantization/base.py b/moshi/moshi/quantization/base.py index e8f0ad4..02228a9 100644 --- a/moshi/moshi/quantization/base.py +++ b/moshi/moshi/quantization/base.py @@ -68,7 +68,7 @@ def num_codebooks(self) -> int: raise NotImplementedError() @property - def semantic_quantizer(self): + def semantic_quantizer(self) -> 'BaseQuantizer': """This returns the quantizer that models the first level of the hierarchy (typically semantic). In this case, it's the quantizer itself. @@ -76,7 +76,7 @@ def semantic_quantizer(self): return self @property - def acoustic_quantizer(self): + def acoustic_quantizer(self) -> 'BaseQuantizer': """This returns the quantizer that models the higher levels of the hierarchy (typically acoustic). In this case, it's the quantizer itself. diff --git a/moshi/moshi/quantization/core_vq.py b/moshi/moshi/quantization/core_vq.py index 670b3a9..54abb5b 100644 --- a/moshi/moshi/quantization/core_vq.py +++ b/moshi/moshi/quantization/core_vq.py @@ -8,10 +8,9 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
-import math import typing as tp -from einops import rearrange, repeat +from einops import rearrange import torch from torch import nn from torch import distributed @@ -339,7 +338,7 @@ def forward( n_q = n_q or len(self.layers) previous_layer_is_initialized = True - for i, layer in enumerate(self.layers[:n_q]): + for i, layer in enumerate(self.layers[:n_q]): # type: ignore quantized, codes, loss, metrics = layer( residual, initialize=previous_layer_is_initialized ) @@ -366,7 +365,7 @@ def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor: residual = x all_indices = [] n_q = n_q or len(self.layers) - for layer in self.layers[:n_q]: + for layer in self.layers[:n_q]: # type: ignore indices = layer.encode(residual) quantized = layer.decode(indices) residual = residual - quantized diff --git a/moshi/moshi/quantization/vq.py b/moshi/moshi/quantization/vq.py index 0e436c1..4fa5b0a 100644 --- a/moshi/moshi/quantization/vq.py +++ b/moshi/moshi/quantization/vq.py @@ -321,12 +321,12 @@ def dimension(self): return self.rvq_first.dimension @property - def semantic_quantizer(self): + def semantic_quantizer(self) -> ResidualVectorQuantizer: """This returns the quantizer that models the first level of the hierarchy (typically semantic).""" return self.rvq_first @property - def acoustic_quantizer(self): + def acoustic_quantizer(self) -> ResidualVectorQuantizer: """This returns the quantizer that models the higher levels of the hierarchy (typically acoustic).""" return self.rvq_rest diff --git a/moshi/moshi/server.py b/moshi/moshi/server.py index 70dfbff..b683520 100644 --- a/moshi/moshi/server.py +++ b/moshi/moshi/server.py @@ -5,67 +5,29 @@ import argparse import asyncio from dataclasses import dataclass -from pathlib import Path import random +import os +from pathlib import Path import tarfile import time +import secrets +import sys -import os +import aiohttp +from aiohttp import web +from huggingface_hub import hf_hub_download import numpy as np import sentencepiece import sphn import torch -import aiohttp -from aiohttp import web - -from huggingface_hub import hf_hub_download - -from .models import moshi_, EncodecModel, LMGen -SAMPLE_RATE = moshi_.SAMPLE_RATE -DEVICE = "cuda:0" -ENABLE_PROFILING = False - -def colorize(text, color): - code = f"\033[{color}m" - restore = "\033[0m" - return "".join([code, text, restore]) +from .client_utils import make_log +from .models import loaders, MimiModel, LMModel, LMGen def log(level: str, msg: str): - if level == "warning": - prefix = colorize("[Warn]", "1;31") - elif level == "info": - prefix = colorize("[Info]", "1;34") - elif level == "error": - prefix = colorize("[Err ]", "1;31") - else: - raise ValueError(f"Unknown level {level}") - print(prefix + " " + msg) - - -parser = argparse.ArgumentParser() -parser.add_argument("--host", default="localhost", type=str) -parser.add_argument("--port", default=8998, type=int) -parser.add_argument("--static", type=str) -parser.add_argument("--tokenizer", type=str) -parser.add_argument("--moshi-weights", type=str) -parser.add_argument("--mimi-weights", type=str) -parser.add_argument("--hf-repo", type=str, default="kmhf/msh-v0.1") - -args = parser.parse_args() - -if args.tokenizer is None: - args.tokenizer = hf_hub_download(args.hf_repo, "tokenizer_spm_32k_3.model") -if args.moshi_weights is None: - args.moshi_weights = hf_hub_download( - args.hf_repo, "moshiko_pt_301e30bf@120.safetensors" - ) -if args.mimi_weights is None: - args.mimi_weights = hf_hub_download( - args.hf_repo, 
"tokenizer-e351c8d8-checkpoint125.safetensors" - ) + print(make_log(level, msg)) def seed_all(seed): @@ -79,43 +41,35 @@ def seed_all(seed): torch.backends.cudnn.benchmark = False -seed_all(42424242) - - @dataclass class ServerState: - ec: EncodecModel + mimi: MimiModel text_tokenizer: sentencepiece.SentencePieceProcessor lm_gen: LMGen lock: asyncio.Lock - def __init__(self): - log("info", "loading mimi") - self.ec = moshi_.get_encodec(args.mimi_weights, DEVICE) - log("info", "mimi loaded") - self.text_tokenizer = sentencepiece.SentencePieceProcessor(args.tokenizer) - log("info", "loading moshi") - lm = moshi_.get_lm(args.moshi_weights, DEVICE) + def __init__(self, mimi: MimiModel, text_tokenizer: sentencepiece.SentencePieceProcessor, + lm: LMModel, device: str | torch.device): + self.mimi = mimi + self.text_tokenizer = text_tokenizer self.lm_gen = LMGen(lm) - self.frame_size = int(self.ec.sample_rate / self.ec.frame_rate) + self.device = device + self.frame_size = int(self.mimi.sample_rate / self.mimi.frame_rate) self.lock = asyncio.Lock() - self.ec.streaming_forever(1) + self.mimi.streaming_forever(1) self.lm_gen.streaming_forever(1) - log("info", "lm loaded") def warmup(self): for chunk in range(4): - chunk = torch.zeros( - 1, 1, self.frame_size, dtype=torch.float32, device=DEVICE - ) - codes = self.ec.encode(chunk) + chunk = torch.zeros(1, 1, self.frame_size, dtype=torch.float32, device=self.device) + codes = self.mimi.encode(chunk) for c in range(codes.shape[-1]): - tokens = self.lm_gen.step(codes[:, :, c : c + 1]) + tokens = self.lm_gen.step(codes[:, :, c: c + 1]) if tokens is None: continue - _ = self.ec.decode(tokens[:, 1:]) + _ = self.mimi.decode(tokens[:, 1:]) torch.cuda.synchronize() async def handle_chat(self, request): @@ -168,21 +122,21 @@ async def opus_loop(): while all_pcm_data.shape[-1] >= self.frame_size: be = time.time() chunk = all_pcm_data[: self.frame_size] - all_pcm_data = all_pcm_data[self.frame_size :] + all_pcm_data = all_pcm_data[self.frame_size:] chunk = torch.from_numpy(chunk) - chunk = chunk.to(device=DEVICE)[None, None] - codes = self.ec.encode(chunk) + chunk = chunk.to(device=self.device)[None, None] + codes = self.mimi.encode(chunk) for c in range(codes.shape[-1]): - tokens = self.lm_gen.step(codes[:, :, c : c + 1]) + tokens = self.lm_gen.step(codes[:, :, c: c + 1]) if tokens is None: continue assert tokens.shape[1] == self.lm_gen.lm_model.dep_q + 1 - main_pcm = self.ec.decode(tokens[:, 1:]) + main_pcm = self.mimi.decode(tokens[:, 1:]) main_pcm = main_pcm.cpu() opus_writer.append_pcm(main_pcm[0, 0].numpy()) text_token = tokens[0, 0, 0].item() if text_token not in (0, 3): - _text = self.text_tokenizer.id_to_piece(text_token) + _text = self.text_tokenizer.id_to_piece(text_token) # type: ignore _text = _text.replace("▁", " ") msg = b"\x02" + bytes(_text, encoding="utf8") log("info", f"text token '{_text}'") @@ -201,9 +155,9 @@ async def send_loop(): log("info", "accepted connection") close = False async with self.lock: - opus_writer = sphn.OpusStreamWriter(self.ec.sample_rate) - opus_reader = sphn.OpusStreamReader(self.ec.sample_rate) - self.ec.reset_streaming() + opus_writer = sphn.OpusStreamWriter(self.mimi.sample_rate) + opus_reader = sphn.OpusStreamReader(self.mimi.sample_rate) + self.mimi.reset_streaming() self.lm_gen.reset_streaming() # Send the handshake. 
        await ws.send_bytes(b"\x00")
@@ -213,14 +167,62 @@
 
 
 def main():
-    state = ServerState()
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", default="localhost", type=str)
+    parser.add_argument("--port", default=8998, type=int)
+    parser.add_argument("--static", type=str)
+    parser.add_argument("--gradio_tunnel", action='store_true', help='Activate a gradio tunnel.')
+    parser.add_argument("--gradio_tunnel_token",
+                        help='Provide a custom (secret) token here to keep getting the same URL.')
+
+    parser.add_argument("--tokenizer-name", type=str, default="tokenizer_spm_32k_3.model",
+                        help="Name of the text tokenizer file in the given HF repo, or path to a local file.")
+    parser.add_argument("--moshi-name", type=str, default="moshiko",
+                        help="Name of the Moshi checkpoint in the given HF repo, or path to a local file.")
+    parser.add_argument("--mimi-name", type=str, default="mimi",
+                        help="Name of the Mimi checkpoint in the given HF repo, or path to a local file.")
+    parser.add_argument("--hf-repo", type=str, default=loaders.HF_REPO,
+                        help="HF repo to look into, defaults to Kyutai official one.")
+    parser.add_argument("--device", type=str, default="cuda", help="Device on which to run, defaults to 'cuda'.")
+
+    args = parser.parse_args()
+    seed_all(42424242)
+
+    setup_tunnel = None
+    tunnel_token = ''
+    if args.gradio_tunnel:
+        try:
+            from gradio import networking  # type: ignore
+        except ImportError:
+            log("error", "Cannot find gradio which is required to activate a tunnel. "
+                         "Please install with `pip install gradio`.")
+            sys.exit(1)
+        setup_tunnel = networking.setup_tunnel
+        if args.gradio_tunnel_token is None:
+            tunnel_token = secrets.token_urlsafe(32)
+        else:
+            tunnel_token = args.gradio_tunnel_token
+
+    log("info", "loading mimi")
+    mimi_path = loaders.resolve_model_checkpoint(args.mimi_name, args.hf_repo, allow_local_file=True)
+    mimi = loaders.get_mimi(mimi_path, args.device)
+    log("info", "mimi loaded")
+
+    tokenizer_path = loaders.resolve_model_checkpoint(args.tokenizer_name, args.hf_repo, allow_local_file=True)
+    text_tokenizer = sentencepiece.SentencePieceProcessor(str(tokenizer_path))  # type: ignore
+    log("info", "loading moshi")
+
+    moshi_path = loaders.resolve_model_checkpoint(args.moshi_name, args.hf_repo, allow_local_file=True)
+    lm = loaders.get_moshi_lm(moshi_path, args.device)
+
+    state = ServerState(mimi, text_tokenizer, lm, args.device)
     log("info", "warming up the model")
     state.warmup()
     app = web.Application()
     app.router.add_get("/api/chat", state.handle_chat)
     static_path: None | str = None
     if args.static is None:
-        log("info", f"retrieving the static content")
+        log("info", "retrieving the static content")
         dist_tgz = hf_hub_download(args.hf_repo, "dist.tgz")
         dist_tgz = Path(dist_tgz)
         dist = dist_tgz.parent / "dist"
@@ -232,7 +234,6 @@ def main():
         # When set to the "none" string, we don't serve any static content.
static_path = args.static if static_path is not None: - async def handle_root(_): return web.FileResponse(os.path.join(static_path, "index.html")) @@ -241,6 +242,9 @@ async def handle_root(_): app.router.add_static( "/", path=static_path, follow_symlinks=True, name="static" ) + if setup_tunnel is not None: + tunnel = setup_tunnel('localhost', args.port, tunnel_token, None) + log("info", f"Tunnel started listening at {tunnel}.") log("info", f"listening to ws://{args.host}:{args.port}") web.run_app(app, port=args.port) diff --git a/moshi/pyproject.toml b/moshi/pyproject.toml index 61d4c54..4b77729 100644 --- a/moshi/pyproject.toml +++ b/moshi/pyproject.toml @@ -1,6 +1,5 @@ [project] name = "moshi" -version = "0.0.1" requires-python = ">= 3.10" description = "Moshi is moshi" dependencies = [ @@ -17,6 +16,7 @@ dependencies = [ authors = [{name="Laurent Mazaré", email="laurent@kyutai.org"}] maintainers = [{name="Laurent Mazaré", email="laurent@kyutai.org"}] license = {text = "MIT"} +dynamic = ["version"] [build-system] @@ -25,3 +25,11 @@ build-backend = "setuptools.build_meta" [tool.setuptools] packages = ["moshi", "moshi.utils", "moshi.modules", "moshi.models", "moshi.quantization"] + + +[project.optional-dependencies] +dev = [ + "pyright", + "flake8", + "pre-commit", +] diff --git a/moshi/setup.cfg b/moshi/setup.cfg index dc7aa4b..5bccac4 100644 --- a/moshi/setup.cfg +++ b/moshi/setup.cfg @@ -3,3 +3,4 @@ max-line-length = 120 [flake8] max-line-length = 120 +ignore = E203,E704
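
For reference, a hedged round-trip sketch of the encode/decode path that ServerState.warmup
exercises above, run offline on CPU; the zero waveform is a stand-in for real audio, and the
server itself additionally wraps the transformers in CUDA graphs and runs in streaming mode:

    import torch
    from moshi.models import loaders

    mimi = loaders.get_mimi(loaders.resolve_model_checkpoint("mimi"), device="cpu")
    frame_size = int(mimi.sample_rate / mimi.frame_rate)  # 24000 / 12.5 = 1920 samples per frame

    wav = torch.zeros(1, 1, frame_size)  # [B, C, T], one codec frame of silence
    with torch.no_grad():
        codes = mimi.encode(wav)   # [B, K, T'] integer codes, K = 8 after set_num_codebooks(8)
        pcm = mimi.decode(codes)   # reconstructed waveform, [B, C, T]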