Per repo loading + final cosmetic and doc changes (#40)
* fix

* missing

* wip

* adapting rust

* plop

* model list

* improvs

* fix rust

* fixing bug in vq and streaming
adefossez authored Sep 18, 2024
1 parent bcb79be commit 1ae8e10
Showing 15 changed files with 206 additions and 96 deletions.
59 changes: 50 additions & 9 deletions README.md
@@ -49,13 +49,32 @@ There are three separate versions of the moshi inference stack in this repo.
- The python version using PyTorch is in the [`moshi/`](moshi/) directory.
- The python version using MLX for M series Macs is in the [`moshi_mlx/`](moshi_mlx/) directory.
- The rust version used in production is in the [`rust/`](rust/) directory.
In particular, this contains a Rust implementation of Mimi, with Python bindings available
as `rustymimi`.

Finally, the code for the live demo is provided in the [`client/`](client/) directory.


## Models

We release three models:
- our speech codec Mimi,
- Moshi fine-tuned on a male synthetic voice (Moshiko),
- Moshi fine-tuned on a female synthetic voice (Moshika).

Depending on the backend, the available file formats and quantization levels vary. Here is the list
of Hugging Face repos for each model. Mimi is bundled in all of them and always uses the same checkpoint format.

- Moshika for PyTorch (bf16): [kmhf/moshika-pytorch-bf16](https://huggingface.co/kmhf/moshika-pytorch-bf16).
- Moshiko for PyTorch (bf16): [kmhf/moshiko-pytorch-bf16](https://huggingface.co/kmhf/moshiko-pytorch-bf16).
- Moshika for MLX (int4, int8, bf16): [kmhf/moshika-mlx-q4](https://huggingface.co/kmhf/moshika-mlx-q4), [kmhf/moshika-mlx-q8](https://huggingface.co/kmhf/moshika-mlx-q8), [kmhf/moshika-mlx-bf16](https://huggingface.co/kmhf/moshika-mlx-bf16).
- Moshiko for MLX (int4, int8, bf16): [kmhf/moshiko-mlx-q4](https://huggingface.co/kmhf/moshiko-mlx-q4), [kmhf/moshiko-mlx-q8](https://huggingface.co/kmhf/moshiko-mlx-q8), [kmhf/moshiko-mlx-bf16](https://huggingface.co/kmhf/moshiko-mlx-bf16).
- Moshika for Rust/Candle (int8, bf16): [kmhf/moshika-candle-q8](https://huggingface.co/kmhf/moshika-candle-q8), [kmhf/moshika-candle-bf16](https://huggingface.co/kmhf/moshika-candle-bf16).
- Moshiko for Rust/Candle (int8, bf16): [kmhf/moshiko-candle-q8](https://huggingface.co/kmhf/moshiko-candle-q8), [kmhf/moshiko-candle-bf16](https://huggingface.co/kmhf/moshiko-candle-bf16).

## Requirements

You will need at least Python 3.10. For using the rust backend, you will need a recent version of
the [Rust toolchain](https://rustup.rs/). For specific requirements, please check the individual backends
You will need at least Python 3.10. For specific requirements, please check the individual backends
directories. You can install the PyTorch and MLX clients with the following:

```bash
@@ -64,41 +83,54 @@ pip install moshi_mlx # moshi MLX, from PyPI
# Or the bleeding edge versions for Moshi and Moshi-MLX.
pip install -e "git+https://[email protected]/kyutai-labs/moshi.git#egg=moshi&subdirectory=moshi"
pip install -e "git+https://[email protected]/kyutai-labs/moshi.git#egg=moshi_mlx&subdirectory=moshi_mlx"

pip install rustymimi # mimi, rust implementation with Python bindings from PyPI
```
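
As a quick sanity check of the installation, the packages should simply be importable; the `moshi_mlx` and `rustymimi` lines below only apply if you installed those backends:

```bash
python -c "import moshi; print('moshi OK')"
python -c "import moshi_mlx; print('moshi_mlx OK')"   # macOS / MLX install only
python -c "import rustymimi; print('rustymimi OK')"   # only if rustymimi was installed
```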

While we hope that the present codebase will work on Windows, we do not provide official support for it.
We have tested the MLX version on a MacBook Pro M3. At the moment, we do not support quantization
for the PyTorch version, so you will need a GPU with a significant amount of memory (24GB).

To use the rust backend, you will need a recent version of the [Rust toolchain](https://rustup.rs/).
To compile GPU support, you will also need [CUDA](https://developer.nvidia.com/cuda-toolkit) properly installed for your GPU, in particular with `nvcc`.
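
To quickly confirm that the Rust toolchain and, if you compile GPU support, the CUDA compiler are visible from your shell, you can run:

```bash
rustc --version
cargo --version
nvcc --version   # only needed when compiling with GPU support
```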

## Development

If you wish to install from a clone of this repository, for instance to further develop Moshi, you can do the following:
```
```bash
# From the root of the clone of the repo
pip install -e 'moshi[dev]'
pip install -e 'moshi_mlx[dev]'
pre-commit install
```

If you wish to build `rustymimi` locally (assuming you have Rust properly installed):
```bash
pip install maturin
maturin dev -r -m rust/mimi-pyo3/Cargo.toml
```

## Python (PyTorch)

The python api can be found in the `moshi` directory. It provides a streaming
The PyTorch-based API can be found in the `moshi` directory. It provides a streaming
version of the audio tokenizer (mimi) and the language model (moshi).

In order to run in interactive mode, you need to start a server which will
run the model; you can then use either the web UI or a command-line client.

Start the server with:
```bash
python -m moshi.server [--gradio_tunnel]
python -m moshi.server [--gradio-tunnel] [--hf-repo kmhf/moshika-pytorch-bf16]
```

And then access the web UI on [localhost:8998](http://localhost:8998). If your GPU is on a remote machine
with no direct access, `--gradio_tunnel` will create a tunnel with a URL accessible from anywhere.
with no direct access, `--gradio-tunnel` will create a tunnel with a URL accessible from anywhere.
Keep in mind that this tunnel goes through the US and can add significant latency (up to 500ms from Europe).
You can use `--gradio-tunnel-token` to set a fixed secret and reuse the same address over time.
Alternatively, you might want to use SSH to redirect your connection.
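
For instance, a plain SSH port forward to the default port 8998 would look like the following, where `user@gpu-host` is a placeholder for your own machine:

```bash
# Forward the remote Moshi server to http://localhost:8998 on your laptop.
ssh -N -L 8998:localhost:8998 user@gpu-host
```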

You can use `--hf-repo` to select a different pretrained model by pointing it at the appropriate Hugging Face repository.

Accessing a server that is not localhost via http may cause issues around using
the microphone in the web UI (in some browsers this is only allowed using
https).
@@ -110,12 +142,19 @@ python -m moshi.client [--url URL_TO_GRADIO]
Note however that, unlike the web browser, this client is bare-bones: it doesn't do any echo cancellation,
nor does it try to compensate for a growing lag by skipping frames.

For more information, in particular on how to use the API directly, please
check out [moshi/README.md](moshi/README.md).

## Python (MLX) for local inference on macOS

Once you have installed `moshi_mlx`, you can run
```bash
python -m moshi_mlx.local -q 4 # weights quantized to 4 bits
python -m moshi_mlx.local -q 8 # weights quantized to 8 bits
# And using a different pretrained model:
python -m moshi_mlx.local -q 4 --hf-repo kmhf/moshika-mlx-q4
python -m moshi_mlx.local -q 8 --hf-repo kmhf/moshika-mlx-q8
# Be careful to always match the `-q` and `--hf-repo` flags.
```

This uses a command-line interface, which is bare-bones. It doesn't do any echo cancellation,
@@ -136,7 +175,8 @@ cargo run --features cuda --bin moshi-backend -r -- --config moshi-backend/confi
When using macOS, you can replace `--features cuda` with `--features metal`.

Alternatively, you can use `config-q8.json` rather than `config.json` to use the
quantized q8 model.
quantized q8 model. You can select a different pretrained model, e.g. Moshika,
by changing the `"hf_repo"` key in either file.
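
As an illustration, the relevant part of the config would contain an entry along these lines; the other keys of the config file are omitted here, and the repo name is one of the Candle repos from the model list above:

```json
{
  "hf_repo": "kmhf/moshika-candle-q8"
}
```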

Once the server has printed 'standalone worker listening', you can use the web
UI. By default the rust version uses https so it will be at
@@ -163,7 +203,7 @@ cargo run --bin moshi-cli -r -- tui --host localhost
### Python with PyTorch

```bash
PYTHONPATH=moshi python -m moshi.client
python -m moshi.client
```

### WebUI
@@ -192,7 +232,8 @@ If you use either Mimi or Moshi, please cite the following paper,
```
@article{defossez2024moshi,
title={Moshi: a speech-text foundation model for real-time dialogue},
author={Alexandre Défossez and Laurent Mazaré and Manu Orsini and Amélie Royer and Patrick Pérez and Hervé Jégou and Edouard Grave and Neil Zeghidour},
author={Alexandre Défossez and Laurent Mazaré and Manu Orsini and Amélie Royer and
Patrick Pérez and Hervé Jégou and Edouard Grave and Neil Zeghidour},
journal={arXiv:TBC},
year={2024},
}
54 changes: 51 additions & 3 deletions moshi/README.md
@@ -34,14 +34,17 @@ run the model, you can then use either the web UI or a command line client.

Start the server with:
```bash
python -m moshi.server [--gradio_tunnel]
python -m moshi.server [--gradio-tunnel]
```

And then access the web UI on [localhost:8998](http://localhost:8998). If your GPU is on a remote machine
with no direct access, `--gradio_tunnel` will create a tunnel with a URL accessible from anywhere.
with no direct access, `--gradio-tunnel` will create a tunnel with a URL accessible from anywhere.
Keep in mind that this tunnel goes through the US and can add significant latency (up to 500ms from Europe).
You can use `--gradio-tunnel-token` to set a fixed secret and reuse the same address over time.
Alternatively, you might want to use SSH to redirect your connection.
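
For example, to keep a stable tunnel URL across restarts you can pass a fixed secret of your choosing (the token value below is only a placeholder):

```bash
python -m moshi.server --gradio-tunnel --gradio-tunnel-token MY_SECRET_TOKEN
```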

You can use `--hf-repo` to select a different pretrained model by pointing it at the appropriate Hugging Face repository.

Accessing a server that is not localhost via http may cause issues around using
the microphone in the web UI (in some browsers this is only allowed using
https).
@@ -53,6 +56,49 @@ python -m moshi.client [--url URL_TO_GRADIO]
Note however that, unlike the web browser, this client is bare-bones: it doesn't do any echo cancellation,
nor does it try to compensate for a growing lag by skipping frames.


## API - Mimi

You can use Mimi and Moshi programmatically as follows:
```python
from huggingface_hub import hf_hub_download
import torch

from moshi.models import loaders

mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
mimi = loaders.get_mimi(mimi_weight, device='cpu')
mimi.set_num_codebooks(8)  # up to 32.

wav = torch.randn(1, 1, 24000 * 10)  # should be [B, C=1, T]
with torch.no_grad():
    codes = mimi.encode(wav)  # [B, K = 8, T]
    decoded = mimi.decode(codes)

    # Supports streaming too.
    frame_size = int(mimi.sample_rate / mimi.frame_rate)
    all_codes = []
    with mimi.streaming(batch_size=1):
        for offset in range(0, wav.shape[-1], frame_size):
            frame = wav[:, :, offset: offset + frame_size]
            codes = mimi.encode(frame)
            assert codes.shape[-1] == 1
            all_codes.append(codes)

moshi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MOSHI_NAME)
moshi = loaders.get_moshi_lm(moshi_weight, device='cuda')
out_wav_chunks = []
# Now we will stream over both Moshi I/O, and decode on the fly with Mimi.
with torch.no_grad(), moshi.streaming(1), mimi.streaming(1):
    for code in all_codes:
        tokens_out = moshi.step(code.cuda())
        # tokens_out is [B, 1 + 8, 1], with tokens_out[:, 0] the text token.
        if tokens_out is not None:
            # Mimi was loaded on CPU above, so move the audio tokens back to it.
            wav_chunk = mimi.decode(tokens_out[:, 1:].cpu())
            out_wav_chunks.append(wav_chunk)
out_wav = torch.cat(out_wav_chunks, dim=-1)
```
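
To actually listen to the result, you can write the generated audio to a file. This is just a sketch: it assumes `torchaudio` is installed, which is not a dependency of `moshi` itself, and that `out_wav` was produced as above.

```python
import torchaudio

# out_wav is [B=1, C=1, T] at mimi.sample_rate (24 kHz); torchaudio expects [C, T].
torchaudio.save("moshi_out.wav", out_wav[0].cpu(), int(mimi.sample_rate))
```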

## Development

If you wish to install from a clone of this repository, for instance to further develop Moshi, you can do the following:
@@ -88,10 +134,12 @@ If you use either Mimi or Moshi, please cite the following paper,
```
@article{defossez2024moshi,
title={Moshi: a speech-text foundation model for real-time dialogue},
author={Alexandre Défossez and Laurent Mazaré and Manu Orsini and Amélie Royer and Patrick Pérez and Hervé Jégou and Edouard Grave and Neil Zeghidour},
author={Alexandre Défossez and Laurent Mazaré and Manu Orsini and Amélie Royer and
Patrick Pérez and Hervé Jégou and Edouard Grave and Neil Zeghidour},
journal={arXiv:TBC},
year={2024},
}
```

[moshi]: https://arxiv.org/
[main_repo]: https://github.com/kyutai-labs/moshi
27 changes: 5 additions & 22 deletions moshi/moshi/models/loaders.py
@@ -4,9 +4,7 @@
"""Retrieves the pretrained models for Moshi and Mimi."""
from pathlib import Path

from huggingface_hub import hf_hub_download
from safetensors.torch import load_model
import sentencepiece
import torch

from .compression import MimiModel
@@ -16,11 +14,11 @@

SAMPLE_RATE = 24000
FRAME_RATE = 12.5
HF_REPO = 'kmhf/msh-v0.1'
MIMI_V0_1 = 'tokenizer-e351c8d8-checkpoint125.safetensors'
MOSHIKO_V0_1 = 'moshiko_pt_301e30bf@120.safetensors'
MOSHIKA_V0_1 = 'moshika_pt_3d736a96@120.safetensors'
TEXT_TOKENIZER_V0_1 = 'tokenizer_spm_32k_3.model'

TEXT_TOKENIZER_NAME = 'tokenizer_spm_32k_3.model'
MOSHI_NAME = 'model.safetensors'
MIMI_NAME = 'tokenizer-e351c8d8-checkpoint125.safetensors'
DEFAULT_REPO = 'kmhf/moshiko-pytorch-bf16'


_seanet_kwargs = {
@@ -104,21 +102,6 @@ def _is_safetensors(path: Path | str) -> bool:
    return Path(path).suffix in (".safetensors", ".sft", ".sfts")


def resolve_model_checkpoint(name: str, hf_repo: str = HF_REPO, allow_local_file: bool = True) -> Path:
    """Load a model checkpoint from HF.
    If `allow_local_file` is True, then if a file `name` exists, it will be used instead.
    """
    if allow_local_file and Path(name).exists():
        return Path(name)
    else:
        filename = name
    return Path(hf_hub_download(hf_repo, filename))


def get_text_tokenizer(filename: str | Path) -> sentencepiece.SentencePieceProcessor:
    return sentencepiece.SentencePieceProcessor(str(filename))  # type: ignore


def get_mimi(filename: str | Path,
             device: torch.device | str = 'cpu') -> MimiModel:
    """Return a pretrained Mimi model."""
6 changes: 4 additions & 2 deletions moshi/moshi/modules/streaming.py
@@ -67,18 +67,20 @@ def set_streaming_propagate(self, streaming_propagate: bool):
        self._streaming_propagate = streaming_propagate

    def _apply_named_streaming(self, fn: tp.Any):
        def _handle_module(prefix: str, module: nn.Module):
        def _handle_module(prefix: str, module: nn.Module, recurse: bool = True):
            propagate = True
            if isinstance(module, StreamingModule):
                if module._streaming_propagate:
                    fn(prefix, module)
                else:
                    propagate = False
            if not recurse:
                return
            if propagate:
                for name, child in module.named_children():
                    _handle_module(prefix + "." + name, child)

        _handle_module("", self)
        _handle_module("", self, recurse=False)
        for name, child in self.named_children():
            _handle_module(name, child)

6 changes: 4 additions & 2 deletions moshi/moshi/quantization/core_vq.py
@@ -120,8 +120,8 @@ def __init__(
self.register_buffer("_initialized", torch.tensor([False], dtype=torch.float))
self.register_buffer("cluster_usage", torch.ones(codebook_size))
self.register_buffer("embedding_sum", embedding)
self.register_buffer("_embedding", None, persistent=False)
self._cached_initialized = False
self._embedding = None

def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs) -> None:
# Mapping old names to new names
@@ -142,9 +142,11 @@ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs) -> None:
    @property
    def embedding(self) -> torch.Tensor:
        if self._embedding is None:
            self._embedding = (
            embedding = (
                self.embedding_sum / self.cluster_usage.clamp(min=self.epsilon)[:, None]
            )
            self.register_buffer("_embedding", embedding, persistent=False)
            return embedding
        return self._embedding

    def _broadcast_buffers(self) -> None:
35 changes: 18 additions & 17 deletions moshi/moshi/server.py
@@ -171,18 +171,16 @@ def main():
parser.add_argument("--host", default="localhost", type=str)
parser.add_argument("--port", default=8998, type=int)
parser.add_argument("--static", type=str)
parser.add_argument("--gradio_tunnel", action='store_true', help='Activate a gradio tunnel.')
parser.add_argument("--gradio_tunnel_token",
parser.add_argument("--gradio-tunnel", action='store_true', help='Activate a gradio tunnel.')
parser.add_argument("--gradio-tunnel-token",
help='Provide a custom (secret) token here to keep getting the same URL.')

parser.add_argument("--tokenizer", type=str, default=loaders.TEXT_TOKENIZER_V0_1,
help="Name of the text tokenizer file in the given HF repo, or path to a local file.")
parser.add_argument("--moshi-weight", type=str, default=loaders.MOSHIKO_V0_1,
help="Name of the Moshi checkpoint in the given HF repo, or path to a local file.")
parser.add_argument("--mimi-weight", type=str, default=loaders.MIMI_V0_1,
help="Name of the Mimi checkpoint in the given HF repo, or path to a local file.")
parser.add_argument("--hf-repo", type=str, default=loaders.HF_REPO,
help="HF repo to look into, defaults to Kyutai official one.")
parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.")
parser.add_argument("--moshi-weight", type=str, help="Path to a local checkpoint file for Moshi.")
parser.add_argument("--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi.")
parser.add_argument("--hf-repo", type=str, default=loaders.DEFAULT_REPO,
help="HF repo to look into, defaults Moshiko. "
"Use this to select a different pre-trained model.")
parser.add_argument("--device", type=str, default="cuda", help="Device on which to run, defaults to 'cuda'.")

args = parser.parse_args()
@@ -204,16 +202,19 @@ def main():
    tunnel_token = args.gradio_tunnel_token

    log("info", "loading mimi")
    mimi_path = loaders.resolve_model_checkpoint(args.mimi_weight, args.hf_repo)
    mimi = loaders.get_mimi(mimi_path, args.device)
    if args.mimi_weight is None:
        args.mimi_weight = hf_hub_download(args.hf_repo, loaders.MIMI_NAME)
    mimi = loaders.get_mimi(args.mimi_weight, args.device)
    log("info", "mimi loaded")

    tokenizer_path = loaders.resolve_model_checkpoint(args.tokenizer, args.hf_repo)
    text_tokenizer = loaders.get_text_tokenizer(tokenizer_path)
    if args.tokenizer is None:
        args.tokenizer = hf_hub_download(args.hf_repo, loaders.TEXT_TOKENIZER_NAME)
    text_tokenizer = sentencepiece.SentencePieceProcessor(args.tokenizer)  # type: ignore

    log("info", "loading moshi")
    moshi_path = loaders.resolve_model_checkpoint(args.moshi_weight, args.hf_repo)
    lm = loaders.get_moshi_lm(moshi_path, args.device)
    if args.moshi_weight is None:
        args.moshi_weight = hf_hub_download(args.hf_repo, loaders.MOSHI_NAME)
    lm = loaders.get_moshi_lm(args.moshi_weight, args.device)
    log("info", "moshi loaded")

    state = ServerState(mimi, text_tokenizer, lm, args.device)
@@ -224,7 +225,7 @@ def main():
    static_path: None | str = None
    if args.static is None:
        log("info", "retrieving the static content")
        dist_tgz = hf_hub_download(args.hf_repo, "dist.tgz")
        dist_tgz = hf_hub_download("kmhf/moshi-artifacts", "dist.tgz")
        dist_tgz = Path(dist_tgz)
        dist = dist_tgz.parent / "dist"
        if not dist.exists():
3 changes: 2 additions & 1 deletion moshi_mlx/README.md
@@ -47,7 +47,8 @@ If you use either Mimi or Moshi, please cite the following paper,
```
@article{defossez2024moshi,
title={Moshi: a speech-text foundation model for real-time dialogue},
author={Alexandre Défossez and Laurent Mazaré and Manu Orsini and Amélie Royer and Patrick Pérez and Hervé Jégou and Edouard Grave and Neil Zeghidour},
author={Alexandre Défossez and Laurent Mazaré and Manu Orsini and Amélie Royer and
Patrick Pérez and Hervé Jégou and Edouard Grave and Neil Zeghidour},
journal={arXiv:TBC},
year={2024},
}