Skip to content

Commit

Permalink
Merge pull request #28 from kyutai-labs/refacto
Browse files Browse the repository at this point in the history
refacto [1/N]
  • Loading branch information
LaurentMazare authored Sep 18, 2024
2 parents bdf176d + 811c885 commit 5112385
Show file tree
Hide file tree
Showing 20 changed files with 324 additions and 246 deletions.
1 change: 1 addition & 0 deletions moshi/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# moshi - pytorch
32 changes: 28 additions & 4 deletions moshi/moshi/client.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Client for the Moshi server."""

import argparse
import asyncio
import queue
import sys

import aiohttp
import numpy as np
import sphn
import sounddevice as sd
import aiohttp

from .client_utils import AnyPrinter, Printer, RawPrinter

Expand Down Expand Up @@ -141,7 +142,27 @@ async def run(self) -> None:


async def run(printer: AnyPrinter, args):
uri = f"ws://{args.host}:{args.port}/api/chat"
if args.url is None:
proto = "ws"
if args.https:
proto += "s"
uri = f"{proto}://{args.host}:{args.port}/api/chat"
else:
proto = "wss"
if '://' in args.url:
proto, without_proto = args.url.split('://', 1)
if proto in ['ws', 'http']:
proto = "ws"
elif proto in ['wss', 'https']:
proto = "wss"
else:
printer.log("error", "The provided URL {args.url} seems to contain a protocol but it is unknown.")
sys.exit(1)
else:
without_proto = args.url
uri = f"{proto}://{without_proto}/api/chat"

printer.log("info", "Connecting to {uri}.")
async with aiohttp.ClientSession() as session:
async with session.ws_connect(uri) as ws:
printer.log("info", "connected!")
Expand All @@ -152,8 +173,11 @@ async def run(printer: AnyPrinter, args):

def main():
parser = argparse.ArgumentParser("client_opus")
parser.add_argument("--host", default="localhost", type=str)
parser.add_argument("--port", default=8998, type=int)
parser.add_argument("--host", default="localhost", type=str, help="Hostname to connect to.")
parser.add_argument("--port", default=8998, type=int, help="Port to connect to.")
parser.add_argument("--https", action='store_true',
help="Set this flag for using a https connection.")
parser.add_argument("--url", type=str, help='Provides directly a URL, e.g. to a gradio tunnel.')
args = parser.parse_args()
printer: AnyPrinter

Expand Down
8 changes: 5 additions & 3 deletions moshi/moshi/client_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Utilities for the command line client, in particular for handling interactions with the terminal.
"""

from dataclasses import dataclass
import sys
Expand All @@ -14,11 +16,11 @@ def colorize(text, color):

def make_log(level: str, msg: str) -> str:
if level == "warning":
prefix = colorize("Warning:", "1;31")
prefix = colorize("[Warn]", "1;31")
elif level == "info":
prefix = colorize("Info:", "1;34")
prefix = colorize("[Info]", "1;34")
elif level == "error":
prefix = colorize("Error:", "1;31")
prefix = colorize("[Err ]", "1;31")
else:
raise ValueError(f"Unknown level {level}")
return prefix + " " + msg
Expand Down
14 changes: 4 additions & 10 deletions moshi/moshi/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,14 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Models for EnCodec, AudioGen, MusicGen, as well as the generic LMModel.
Models for the compression model Moshi,
"""

# flake8: noqa
from .encodec import (
from .compression import (
CompressionModel,
EncodecModel,
MimiModel,
)
from .lm import LMModel, LMGen
from .moshi_ import get_encodec, get_lm
from .loaders import get_mimi, get_moshi_lm
31 changes: 18 additions & 13 deletions moshi/moshi/models/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Part of this file is adapted from encodec.py in https://github.com/facebookresearch/audiocraft
# released under the following license.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Compression models or wrapper around existing models.
Also defines the main interface that a model must follow to be usable as an audio tokenizer.
"""Compression models or wrapper around existing models. In particular, provides the implementation
for Mimi. Also defines the main interface that a model must follow to be usable as an audio tokenizer.
"""

from abc import abstractmethod
Expand All @@ -19,7 +21,6 @@

import torch
from torch import nn
from torch.nn import functional as F


from ..quantization import (
Expand All @@ -46,12 +47,12 @@ def forward(self, x: torch.Tensor) -> QuantizedResult: ...

@abstractmethod
def encode(self, x: torch.Tensor) -> torch.Tensor:
"""See `EncodecModel.encode`."""
"""See `MimiModel.encode`."""
...

@abstractmethod
def decode(self, codes: torch.Tensor) -> torch.Tensor:
"""See `EncodecModel.decode`."""
"""See `MimiModel.decode`."""
...

@abstractmethod
Expand Down Expand Up @@ -90,16 +91,16 @@ def set_num_codebooks(self, n: int):


@dataclass
class _EncodecState:
class _MimiState:
graphed_tr_enc: CUDAGraphed | None
graphed_tr_dec: CUDAGraphed | None

def reset(self):
pass


class EncodecModel(CompressionModel[_EncodecState]):
"""Encodec model operating on the raw waveform.
class MimiModel(CompressionModel[_MimiState]):
"""Mimi model operating on the raw waveform.
Args:
encoder (nn.Module): Encoder network.
Expand All @@ -122,6 +123,7 @@ class EncodecModel(CompressionModel[_EncodecState]):
torch_compile_encoder_decoder (bool): if True, uses torch.compile on the encoder / decoder.
Deactivated by default for training as this is incompatible at the moment with weight norm.
See https://github.com/pytorch/pytorch/issues/121902
Also this seems to work well with 2.2.0, but completely fail with 2.4.0.
"""

def __init__(
Expand Down Expand Up @@ -217,14 +219,16 @@ def __init__(
channel_wise=upsample_channel_wise_bug,
)

def _init_streaming_state(self, batch_size: int) -> _EncodecState:
def _init_streaming_state(self, batch_size: int) -> _MimiState:
device = next(self.parameters()).device
disable = device.type != 'cuda'
graphed_tr_dec = None
graphed_tr_enc = None
if self.encoder_transformer is not None:
graphed_tr_enc = CUDAGraphed(self.encoder_transformer)
graphed_tr_enc = CUDAGraphed(self.encoder_transformer, disable=disable)
if self.decoder_transformer is not None:
graphed_tr_dec = CUDAGraphed(self.decoder_transformer)
return _EncodecState(graphed_tr_enc, graphed_tr_dec)
graphed_tr_dec = CUDAGraphed(self.decoder_transformer, disable=disable)
return _MimiState(graphed_tr_enc, graphed_tr_dec)

@property
def channels(self) -> int:
Expand Down Expand Up @@ -368,7 +372,8 @@ def encode(self, x: torch.Tensor) -> torch.Tensor:
x (torch.Tensor): Float tensor of shape [B, C, T]
Returns:
codes (torch.Tensor): an int tensor of shape [B, K, T] with K the number of codebooks used and T the timestep.
codes (torch.Tensor): an int tensor of shape [B, K, T]
with K the number of codebooks used and T the timestep.
"""
emb = self._encode_to_unquantized_latent(x)
codes = self.quantizer.encode(emb)
Expand Down
5 changes: 3 additions & 2 deletions moshi/moshi/models/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,8 +372,9 @@ def _init_streaming_state(self, batch_size: int) -> _LMGenState:
dtype=torch.long,
)

graphed_main = CUDAGraphed(lm_model.forward_text)
graphed_depth = CUDAGraphed(self.depformer_step)
disable = lm_model.device.type != 'cuda'
graphed_main = CUDAGraphed(lm_model.forward_text, disable=disable)
graphed_depth = CUDAGraphed(self.depformer_step, disable=disable)

return _LMGenState(cache, initial, graphed_main, graphed_depth)

Expand Down
86 changes: 54 additions & 32 deletions moshi/moshi/models/loaders.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,29 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Retrieves the pretrained models for Moshi and Mimi."""
from pathlib import Path

from ..modules import SEANetEncoder, SEANetDecoder, transformer
from .encodec import EncodecModel
from huggingface_hub import hf_hub_download
from safetensors.torch import load_model
import sentencepiece
import torch

from .compression import MimiModel
from .lm import LMModel
from ..modules import SEANetEncoder, SEANetDecoder, transformer
from ..quantization import SplitResidualVectorQuantizer
import torch
from safetensors.torch import load_model
from pathlib import Path
import typing as tp

SAMPLE_RATE = 24000
FRAME_RATE = 12.5
HF_REPO = 'kmhf/msh-v0.1'
MIMI_V0_1 = 'tokenizer-e351c8d8-checkpoint125.safetensors'
MOSHIKO_V0_1 = '[email protected]'
MOSHIKA_V0_1 = '[email protected]'
TEXT_TOKENIZER_V0_1 = 'tokenizer_spm_32k_3.model'


seanet_kwargs = {
_seanet_kwargs = {
"channels": 1,
"dimension": 512,
"causal": True,
Expand All @@ -35,15 +43,15 @@
"ratios": [8, 6, 5, 4],
"true_skip": True,
}
quantizer_kwargs = {
_quantizer_kwargs = {
"dimension": 256,
"n_q": 32,
"bins": 2048,
"input_dimension": seanet_kwargs["dimension"],
"output_dimension": seanet_kwargs["dimension"],
"input_dimension": _seanet_kwargs["dimension"],
"output_dimension": _seanet_kwargs["dimension"],
}
transformer_kwargs = {
"d_model": seanet_kwargs["dimension"],
_transformer_kwargs = {
"d_model": _seanet_kwargs["dimension"],
"num_heads": 8,
"num_layers": 8,
"causal": True,
Expand All @@ -55,17 +63,17 @@
"norm": "layer_norm",
"positional_embedding": "rope",
"dim_feedforward": 2048,
"input_dimension": seanet_kwargs["dimension"],
"output_dimensions": [seanet_kwargs["dimension"]],
"input_dimension": _seanet_kwargs["dimension"],
"output_dimensions": [_seanet_kwargs["dimension"]],
}

lm_kwargs = {
_lm_kwargs = {
"dim": 4096,
"text_card": 32000,
"existing_text_padding_id": 3,
"n_q": 16,
"dep_q": 8,
"card": quantizer_kwargs["bins"],
"card": _quantizer_kwargs["bins"],
"num_heads": 32,
"num_layers": 32,
"hidden_scale": 4.125,
Expand All @@ -92,24 +100,40 @@
}


def _is_safetensors(filename: tp.Union[str, Path]) -> bool:
filename = Path(filename)
return filename.suffix in (".safetensors", ".sft", ".sfts")
def _is_safetensors(path: Path | str) -> bool:
return Path(path).suffix in (".safetensors", ".sft", ".sfts")


def get_encodec(filename: tp.Union[str, Path], device):
encoder = SEANetEncoder(**seanet_kwargs)
decoder = SEANetDecoder(**seanet_kwargs)
def resolve_model_checkpoint(name: str, hf_repo: str = HF_REPO, allow_local_file: bool = True) -> Path:
"""Load a model checkpoint from HF.
If `allow_local_file` is True, then if a file `name` exists, it will be used instead.
"""
if allow_local_file and Path(name).exists():
return Path(name)
else:
filename = name
return Path(hf_hub_download(hf_repo, filename))


def get_text_tokenizer(filename: str | Path) -> sentencepiece.SentencePieceProcessor:
return sentencepiece.SentencePieceProcessor(str(filename)) # type: ignore


def get_mimi(filename: str | Path,
device: torch.device | str = 'cpu') -> MimiModel:
"""Return a pretrained Mimi model."""
encoder = SEANetEncoder(**_seanet_kwargs)
decoder = SEANetDecoder(**_seanet_kwargs)
encoder_transformer = transformer.ProjectedTransformer(
device=device, **transformer_kwargs
device=device, **_transformer_kwargs
)
decoder_transformer = transformer.ProjectedTransformer(
device=device, **transformer_kwargs
device=device, **_transformer_kwargs
)
quantizer = SplitResidualVectorQuantizer(
**quantizer_kwargs,
**_quantizer_kwargs,
)
model = EncodecModel(
model = MimiModel(
encoder,
decoder,
quantizer,
Expand All @@ -126,21 +150,19 @@ def get_encodec(filename: tp.Union[str, Path], device):
if _is_safetensors(filename):
load_model(model, filename)
else:
pkg = torch.load(
filename,
"cpu",
)
pkg = torch.load(filename, "cpu")
model.load_state_dict(pkg["model"])
model.set_num_codebooks(8)
return model


def get_lm(filename: tp.Union[str, Path], device):
def get_moshi_lm(filename: str | Path,
device: torch.device | str = 'cpu') -> LMModel:
dtype = torch.bfloat16
model = LMModel(
device=device,
dtype=dtype,
**lm_kwargs,
**_lm_kwargs,
).to(device=device, dtype=dtype)
model.eval()
if _is_safetensors(filename):
Expand Down
Loading

0 comments on commit 5112385

Please sign in to comment.