Skip to content

Commit

Permalink
refacto
Browse files Browse the repository at this point in the history
  • Loading branch information
adefossez committed Sep 17, 2024
1 parent bdf176d commit ccbe247
Show file tree
Hide file tree
Showing 14 changed files with 214 additions and 161 deletions.
1 change: 1 addition & 0 deletions moshi/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# moshi - pytorch
32 changes: 28 additions & 4 deletions moshi/moshi/client.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Client for the Moshi server."""

import argparse
import asyncio
import queue
import sys

import aiohttp
import numpy as np
import sphn
import sounddevice as sd
import aiohttp

from .client_utils import AnyPrinter, Printer, RawPrinter

Expand Down Expand Up @@ -141,7 +142,27 @@ async def run(self) -> None:


async def run(printer: AnyPrinter, args):
uri = f"ws://{args.host}:{args.port}/api/chat"
if args.url is None:
proto = "ws"
if args.https:
proto += "s"
uri = f"{proto}://{args.host}:{args.port}/api/chat"
else:
proto = "wss"
if '://' in args.url:
proto, without_proto = args.url.split('://', 1)
if proto in ['ws', 'http']:
proto = "ws"
elif proto in ['wss', 'https']:
proto = "wss"
else:
printer.log("error", "The provided URL {args.url} seems to contain a protocol but it is unknown.")
sys.exit(1)
else:
without_proto = args.url
uri = f"{proto}://{without_proto}/api/chat"

printer.log("info", "Connecting to {uri}.")
async with aiohttp.ClientSession() as session:
async with session.ws_connect(uri) as ws:
printer.log("info", "connected!")
Expand All @@ -152,8 +173,11 @@ async def run(printer: AnyPrinter, args):

def main():
parser = argparse.ArgumentParser("client_opus")
parser.add_argument("--host", default="localhost", type=str)
parser.add_argument("--port", default=8998, type=int)
parser.add_argument("--host", default="localhost", type=str, help="Hostname to connect to.")
parser.add_argument("--port", default=8998, type=int, help="Port to connect to.")
parser.add_argument("--https", action='store_true',
help="Set this flag for using a https connection.")
parser.add_argument("--url", type=str, help='Provides directly a URL, e.g. to a gradio tunnel.')
args = parser.parse_args()
printer: AnyPrinter

Expand Down
8 changes: 5 additions & 3 deletions moshi/moshi/client_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Utilities for the command line client, in particular for handling interactions with the terminal.
"""

from dataclasses import dataclass
import sys
Expand All @@ -14,11 +16,11 @@ def colorize(text, color):

def make_log(level: str, msg: str) -> str:
if level == "warning":
prefix = colorize("Warning:", "1;31")
prefix = colorize("[Warn]", "1;31")
elif level == "info":
prefix = colorize("Info:", "1;34")
prefix = colorize("[Info]", "1;34")
elif level == "error":
prefix = colorize("Error:", "1;31")
prefix = colorize("[Err ]", "1;31")
else:
raise ValueError(f"Unknown level {level}")
return prefix + " " + msg
Expand Down
14 changes: 4 additions & 10 deletions moshi/moshi/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,14 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Models for EnCodec, AudioGen, MusicGen, as well as the generic LMModel.
Models for the compression model Moshi,
"""

# flake8: noqa
from .encodec import (
from .compression import (
CompressionModel,
EncodecModel,
MimiModel,
)
from .lm import LMModel, LMGen
from .moshi_ import get_encodec, get_lm
from .loaders import get_mimi, get_moshi_lm
25 changes: 14 additions & 11 deletions moshi/moshi/models/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Part of this file is adapted from encodec.py in https://github.com/facebookresearch/audiocraft
# released under the following license.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Compression models or wrapper around existing models.
Also defines the main interface that a model must follow to be usable as an audio tokenizer.
"""Compression models or wrapper around existing models. In particular, provides the implementation
for Mimi. Also defines the main interface that a model must follow to be usable as an audio tokenizer.
"""

from abc import abstractmethod
Expand All @@ -19,7 +21,6 @@

import torch
from torch import nn
from torch.nn import functional as F


from ..quantization import (
Expand All @@ -46,12 +47,12 @@ def forward(self, x: torch.Tensor) -> QuantizedResult: ...

@abstractmethod
def encode(self, x: torch.Tensor) -> torch.Tensor:
"""See `EncodecModel.encode`."""
"""See `MimiModel.encode`."""
...

@abstractmethod
def decode(self, codes: torch.Tensor) -> torch.Tensor:
"""See `EncodecModel.decode`."""
"""See `MimiModel.decode`."""
...

@abstractmethod
Expand Down Expand Up @@ -90,16 +91,16 @@ def set_num_codebooks(self, n: int):


@dataclass
class _EncodecState:
class _MimiState:
graphed_tr_enc: CUDAGraphed | None
graphed_tr_dec: CUDAGraphed | None

def reset(self):
pass


class EncodecModel(CompressionModel[_EncodecState]):
"""Encodec model operating on the raw waveform.
class MimiModel(CompressionModel[_MimiState]):
"""Mimi model operating on the raw waveform.
Args:
encoder (nn.Module): Encoder network.
Expand All @@ -122,6 +123,7 @@ class EncodecModel(CompressionModel[_EncodecState]):
torch_compile_encoder_decoder (bool): if True, uses torch.compile on the encoder / decoder.
Deactivated by default for training as this is incompatible at the moment with weight norm.
See https://github.com/pytorch/pytorch/issues/121902
Also this seems to work well with 2.2.0, but completely fail with 2.4.0.
"""

def __init__(
Expand Down Expand Up @@ -217,14 +219,14 @@ def __init__(
channel_wise=upsample_channel_wise_bug,
)

def _init_streaming_state(self, batch_size: int) -> _EncodecState:
def _init_streaming_state(self, batch_size: int) -> _MimiState:
graphed_tr_dec = None
graphed_tr_enc = None
if self.encoder_transformer is not None:
graphed_tr_enc = CUDAGraphed(self.encoder_transformer)
if self.decoder_transformer is not None:
graphed_tr_dec = CUDAGraphed(self.decoder_transformer)
return _EncodecState(graphed_tr_enc, graphed_tr_dec)
return _MimiState(graphed_tr_enc, graphed_tr_dec)

@property
def channels(self) -> int:
Expand Down Expand Up @@ -368,7 +370,8 @@ def encode(self, x: torch.Tensor) -> torch.Tensor:
x (torch.Tensor): Float tensor of shape [B, C, T]
Returns:
codes (torch.Tensor): an int tensor of shape [B, K, T] with K the number of codebooks used and T the timestep.
codes (torch.Tensor): an int tensor of shape [B, K, T]
with K the number of codebooks used and T the timestep.
"""
emb = self._encode_to_unquantized_latent(x)
codes = self.quantizer.encode(emb)
Expand Down
86 changes: 54 additions & 32 deletions moshi/moshi/models/loaders.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,30 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Retrieves the pretrained models for Moshi and Mimi."""
from pathlib import Path
import typing as tp

from ..modules import SEANetEncoder, SEANetDecoder, transformer
from .encodec import EncodecModel
from huggingface_hub import hf_hub_download
from safetensors.torch import load_model
import torch

from .compression import MimiModel
from .lm import LMModel
from ..modules import SEANetEncoder, SEANetDecoder, transformer
from ..quantization import SplitResidualVectorQuantizer
import torch
from safetensors.torch import load_model
from pathlib import Path
import typing as tp

SAMPLE_RATE = 24000
FRAME_RATE = 12.5
HF_REPO = 'kmhf'

_MODEL_REGISTRY = {
'mimi': 'tokenizer-e351c8d8-checkpoint125.safetensors',
'moshiko': '[email protected]',
'moshika': '[email protected]',
}

seanet_kwargs = {
_seanet_kwargs = {
"channels": 1,
"dimension": 512,
"causal": True,
Expand All @@ -35,15 +44,15 @@
"ratios": [8, 6, 5, 4],
"true_skip": True,
}
quantizer_kwargs = {
_quantizer_kwargs = {
"dimension": 256,
"n_q": 32,
"bins": 2048,
"input_dimension": seanet_kwargs["dimension"],
"output_dimension": seanet_kwargs["dimension"],
"input_dimension": _seanet_kwargs["dimension"],
"output_dimension": _seanet_kwargs["dimension"],
}
transformer_kwargs = {
"d_model": seanet_kwargs["dimension"],
_transformer_kwargs = {
"d_model": _seanet_kwargs["dimension"],
"num_heads": 8,
"num_layers": 8,
"causal": True,
Expand All @@ -55,17 +64,17 @@
"norm": "layer_norm",
"positional_embedding": "rope",
"dim_feedforward": 2048,
"input_dimension": seanet_kwargs["dimension"],
"output_dimensions": [seanet_kwargs["dimension"]],
"input_dimension": _seanet_kwargs["dimension"],
"output_dimensions": [_seanet_kwargs["dimension"]],
}

lm_kwargs = {
_lm_kwargs = {
"dim": 4096,
"text_card": 32000,
"existing_text_padding_id": 3,
"n_q": 16,
"dep_q": 8,
"card": quantizer_kwargs["bins"],
"card": _quantizer_kwargs["bins"],
"num_heads": 32,
"num_layers": 32,
"hidden_scale": 4.125,
Expand All @@ -92,24 +101,39 @@
}


def _is_safetensors(filename: tp.Union[str, Path]) -> bool:
filename = Path(filename)
return filename.suffix in (".safetensors", ".sft", ".sfts")
def _is_safetensors(path: Path | str) -> bool:
return Path(path).suffix in (".safetensors", ".sft", ".sfts")


def resolve_model_checkpoint(name: str, hf_repo: str = HF_REPO, allow_local_file: bool = False) -> Path:
"""Load a model checkpoint from HF,
potentially resolving `name` if it is a known alias (mimi, moshiko, moshika).
If `allow_local_file` is True, then if a file `name` exists, it will be used instead.
"""
if name in _MODEL_REGISTRY:
filename = _MODEL_REGISTRY[name]
elif allow_local_file and Path(name).exists():
return Path(name)
else:
filename = name
return Path(hf_hub_download(hf_repo, filename))


def get_encodec(filename: tp.Union[str, Path], device):
encoder = SEANetEncoder(**seanet_kwargs)
decoder = SEANetDecoder(**seanet_kwargs)
def get_mimi(filename: tp.Union[str, Path],
device: torch.device | str = 'cpu') -> MimiModel:
"""Return a pretrained Mimi model."""
encoder = SEANetEncoder(**_seanet_kwargs)
decoder = SEANetDecoder(**_seanet_kwargs)
encoder_transformer = transformer.ProjectedTransformer(
device=device, **transformer_kwargs
device=device, **_transformer_kwargs
)
decoder_transformer = transformer.ProjectedTransformer(
device=device, **transformer_kwargs
device=device, **_transformer_kwargs
)
quantizer = SplitResidualVectorQuantizer(
**quantizer_kwargs,
**_quantizer_kwargs,
)
model = EncodecModel(
model = MimiModel(
encoder,
decoder,
quantizer,
Expand All @@ -126,21 +150,19 @@ def get_encodec(filename: tp.Union[str, Path], device):
if _is_safetensors(filename):
load_model(model, filename)
else:
pkg = torch.load(
filename,
"cpu",
)
pkg = torch.load(filename, "cpu")
model.load_state_dict(pkg["model"])
model.set_num_codebooks(8)
return model


def get_lm(filename: tp.Union[str, Path], device):
def get_moshi_lm(filename: tp.Union[str, Path],
device: torch.device | str = 'cpu') -> LMModel:
dtype = torch.bfloat16
model = LMModel(
device=device,
dtype=dtype,
**lm_kwargs,
**_lm_kwargs,
).to(device=device, dtype=dtype)
model.eval()
if _is_safetensors(filename):
Expand Down
6 changes: 2 additions & 4 deletions moshi/moshi/modules/seanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,7 @@ def __init__(
self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks
self.disable_norm_outer_blocks = disable_norm_outer_blocks
assert (
self.disable_norm_outer_blocks >= 0
and self.disable_norm_outer_blocks <= self.n_blocks
self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks
), (
"Number of blocks for which to disable norm is invalid."
"It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
Expand Down Expand Up @@ -307,8 +306,7 @@ def __init__(
self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks
self.disable_norm_outer_blocks = disable_norm_outer_blocks
assert (
self.disable_norm_outer_blocks >= 0
and self.disable_norm_outer_blocks <= self.n_blocks
self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks
), (
"Number of blocks for which to disable norm is invalid."
"It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
Expand Down
Loading

0 comments on commit ccbe247

Please sign in to comment.