
Merge branch 'main' into readme
adefossez committed Sep 18, 2024
2 parents 152062e + a45ec3e commit 5e99abb
Showing 19 changed files with 128 additions and 95 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/precommit.yml
@@ -1,4 +1,4 @@
-name: precommmit
+name: precommit
 on:
   push:
     branches: [ main ]
2 changes: 2 additions & 0 deletions client/index.html
@@ -3,6 +3,8 @@
   <head>
     <meta charset="UTF-8" />
     <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+    <link rel="icon" type="image/png" sizes="32x32" href="/assets/favicon-32x32.png">
+    <link rel="icon" type="image/png" sizes="16x16" href="/assets/favicon-16x16.png">
     <title>moshi.chat</title>
   </head>
   <body class=" bg-black font-mono font-thin">
Binary file added client/public/assets/favicon-16x16.png
Binary file not shown.
Binary file added client/public/assets/favicon-32x32.png
Binary file not shown.
Binary file added client/public/assets/favicon.ico
Binary file not shown.
6 changes: 4 additions & 2 deletions moshi/moshi/models/compression.py
@@ -220,12 +220,14 @@ def __init__(
         )

     def _init_streaming_state(self, batch_size: int) -> _MimiState:
+        device = next(self.parameters()).device
+        disable = device.type != 'cuda'
         graphed_tr_dec = None
         graphed_tr_enc = None
         if self.encoder_transformer is not None:
-            graphed_tr_enc = CUDAGraphed(self.encoder_transformer)
+            graphed_tr_enc = CUDAGraphed(self.encoder_transformer, disable=disable)
         if self.decoder_transformer is not None:
-            graphed_tr_dec = CUDAGraphed(self.decoder_transformer)
+            graphed_tr_dec = CUDAGraphed(self.decoder_transformer, disable=disable)
         return _MimiState(graphed_tr_enc, graphed_tr_dec)

     @property
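Note: the `disable` flag here is derived from wherever the module's weights live, so CUDA graphing only engages on GPU. A minimal sketch of the same pattern outside the codec, assuming `moshi` is installed; the toy `net` is hypothetical:

```python
import torch
from torch import nn

from moshi.utils.compile import CUDAGraphed

net = nn.Linear(8, 8)  # hypothetical stand-in for the encoder/decoder transformer
device = next(net.parameters()).device
# On a non-CUDA device the wrapper just calls `net` directly, no graph capture.
graphed_net = CUDAGraphed(net, disable=(device.type != 'cuda'))
out = graphed_net(torch.randn(2, 8, device=device))
```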
5 changes: 3 additions & 2 deletions moshi/moshi/models/lm.py
@@ -372,8 +372,9 @@ def _init_streaming_state(self, batch_size: int) -> _LMGenState:
             dtype=torch.long,
         )

-        graphed_main = CUDAGraphed(lm_model.forward_text)
-        graphed_depth = CUDAGraphed(self.depformer_step)
+        disable = lm_model.device.type != 'cuda'
+        graphed_main = CUDAGraphed(lm_model.forward_text, disable=disable)
+        graphed_depth = CUDAGraphed(self.depformer_step, disable=disable)

         return _LMGenState(cache, initial, graphed_main, graphed_depth)

2 changes: 1 addition & 1 deletion moshi/moshi/models/loaders.py
@@ -104,7 +104,7 @@ def _is_safetensors(path: Path | str) -> bool:
     return Path(path).suffix in (".safetensors", ".sft", ".sfts")


-def resolve_model_checkpoint(name: str, hf_repo: str = HF_REPO, allow_local_file: bool = False) -> Path:
+def resolve_model_checkpoint(name: str, hf_repo: str = HF_REPO, allow_local_file: bool = True) -> Path:
     """Load a model checkpoint from HF.
     If `allow_local_file` is True, then if a file `name` exists, it will be used instead.
     """
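With `allow_local_file` now defaulting to True, callers can pass either a repo filename or an existing local path. A sketch of both resolution paths (the local path is hypothetical):

```python
from moshi.models import loaders

# If the name points at an existing file, that file is used directly.
local_ckpt = loaders.resolve_model_checkpoint("./checkpoints/mimi.safetensors")  # hypothetical path
# Otherwise the name is looked up in the HF repo and downloaded.
hub_ckpt = loaders.resolve_model_checkpoint(loaders.MIMI_V0_1, loaders.HF_REPO)
```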
17 changes: 11 additions & 6 deletions moshi/moshi/modules/transformer.py
@@ -9,6 +9,7 @@
 See `StreamingTransformer` for more information.
 """

+from contextlib import ExitStack
 from dataclasses import dataclass
 import typing as tp

@@ -17,6 +18,7 @@
 import torch.nn as nn
 from torch.nn import functional as F

+from ..utils.compile import no_compile
 from .gating import make_gating
 from .rope import RotaryEmbedding
 from .streaming import StreamingModule, StreamingContainer
@@ -579,12 +581,15 @@ def _sa_block(self, x: torch.Tensor):
         return x_orig + self.layer_scale_1(update)

     def forward(self, x: torch.Tensor):
-        x = self._sa_block(x)
-        x = self._ff_block(x)
-        state = self._streaming_state
-        if state:
-            state.offset_cpu += x.shape[1]
-        return x
+        with ExitStack() as stack:
+            if x.device.type != 'cuda':
+                stack.enter_context(no_compile())
+            x = self._sa_block(x)
+            x = self._ff_block(x)
+            state = self._streaming_state
+            if state:
+                state.offset_cpu += x.shape[1]
+            return x


@dataclass
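The `ExitStack` idiom used here is a general way to enter a context manager only under a condition while keeping a single code path. A standalone stdlib sketch, with `torch.no_grad()` as a stand-in for `no_compile()`:

```python
from contextlib import ExitStack

import torch

def apply_maybe_no_grad(fn, x: torch.Tensor) -> torch.Tensor:
    # Enter the context only on CPU; the body below is shared either way.
    with ExitStack() as stack:
        if x.device.type != 'cuda':
            stack.enter_context(torch.no_grad())  # stand-in for no_compile()
        return fn(x)
```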
12 changes: 6 additions & 6 deletions moshi/moshi/server.py
@@ -175,11 +175,11 @@ def main():
     parser.add_argument("--gradio_tunnel_token",
                         help='Provide a custom (secret) token here to keep getting the same URL.')

-    parser.add_argument("--tokenizer-name", type=str, default=loaders.TEXT_TOKENIZER_V0_1,
+    parser.add_argument("--tokenizer", type=str, default=loaders.TEXT_TOKENIZER_V0_1,
                         help="Name of the text tokenizer file in the given HF repo, or path to a local file.")
-    parser.add_argument("--moshi-name", type=str, default=loaders.MOSHIKO_V0_1,
+    parser.add_argument("--moshi-weight", type=str, default=loaders.MOSHIKO_V0_1,
                         help="Name of the Moshi checkpoint in the given HF repo, or path to a local file.")
-    parser.add_argument("--mimi-name", type=str, default=loaders.MIMI_V0_1,
+    parser.add_argument("--mimi-weight", type=str, default=loaders.MIMI_V0_1,
                         help="Name of the Mimi checkpoint in the given HF repo, or path to a local file.")
     parser.add_argument("--hf-repo", type=str, default=loaders.HF_REPO,
                         help="HF repo to look into, defaults to Kyutai official one.")
@@ -204,15 +204,15 @@ def main():
         tunnel_token = args.gradio_tunnel_token

     log("info", "loading mimi")
-    mimi_path = loaders.resolve_model_checkpoint(args.mimi_name, args.hf_repo, allow_local_file=True)
+    mimi_path = loaders.resolve_model_checkpoint(args.mimi_weight, args.hf_repo)
     mimi = loaders.get_mimi(mimi_path, args.device)
     log("info", "mimi loaded")

-    tokenizer_path = loaders.resolve_model_checkpoint(args.tokenizer_name, args.hf_repo, allow_local_file=True)
+    tokenizer_path = loaders.resolve_model_checkpoint(args.tokenizer, args.hf_repo)
     text_tokenizer = loaders.get_text_tokenizer(tokenizer_path)

     log("info", "loading moshi")
-    moshi_path = loaders.resolve_model_checkpoint(args.moshi_name, args.hf_repo, allow_local_file=True)
+    moshi_path = loaders.resolve_model_checkpoint(args.moshi_weight, args.hf_repo)
     lm = loaders.get_moshi_lm(moshi_path, args.device)
     log("info", "moshi loaded")

11 changes: 7 additions & 4 deletions moshi/moshi/utils/compile.py
@@ -23,7 +23,7 @@

 @contextmanager
 def no_compile():
-    """Disable torch.compile locally."""
+    """Disable torch.compile locally. Now Pytorch 2.4 provides a function to do that."""
     global _compile_disabled

     prev_disabled = _compile_disabled
@@ -194,11 +194,14 @@ class CUDAGraphed:
             be top level args, not nested in structures (tuples, dicts, etc). Keyword
             arguments are NOT supported for simplicity.
         warmup_steps: how many call to make normally before CUDA Graphing. In particular, this
-            allows torch.compiled functions to get properly compiled."""
+            allows torch.compiled functions to get properly compiled.
+        disabled: if True, just call the func directly, useful to quickly deactivate on CPU.
+    """

-    def __init__(self, func: tp.Callable, warmup_steps: int = 1):
+    def __init__(self, func: tp.Callable, warmup_steps: int = 1, disable: bool = False):
         self.func = func
         self.warmup_steps = warmup_steps
+        self.disable = disable
         self._graph: cuda.CUDAGraph | None = None
         self._output: tuple | None = None
         self._args: tuple | None = None
@@ -214,7 +217,7 @@ def reset(self, warmup_steps: int = 0) -> None:
     def __call__(self, *args, **kwargs) -> tp.Any:
         if kwargs:
             raise RuntimeError("Named arguments not supported for now.")
-        if not _is_cuda_graph_enabled() or in_cuda_graph():
+        if self.disable or not _is_cuda_graph_enabled() or in_cuda_graph():
             return self.func(*args, **kwargs)

         def _clone_tensors(args: tuple) -> tuple:
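With the early return in `__call__`, a disabled wrapper behaves exactly like the bare function. A sketch of both modes, assuming the class as shown above; the toy `step` function is illustrative:

```python
import torch

from moshi.utils.compile import CUDAGraphed

def step(x: torch.Tensor) -> torch.Tensor:
    return x * 2 + 1

# disable=True: plain eager calls, even when CUDA is available.
eager = CUDAGraphed(step, disable=True)
print(eager(torch.ones(3)))  # tensor([3., 3., 3.])

# Default mode on a CUDA machine: the first warmup_steps calls run eagerly,
# then the function is captured into a CUDA graph and replayed.
if torch.cuda.is_available():
    graphed = CUDAGraphed(step, warmup_steps=2)
    x = torch.ones(3, device='cuda')
    for _ in range(4):
        graphed(x)
```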
7 changes: 3 additions & 4 deletions moshi/pyproject.toml
@@ -3,7 +3,7 @@ name = "moshi"
 requires-python = ">= 3.10"
 description = "Moshi is moshi"
 dependencies = [
-    "numpy >= 2.1.0, < 2.2",
+    "numpy >= 1.26, < 2.2",
     "safetensors >= 0.4.0, < 0.5",
     "huggingface-hub >= 0.24, < 0.25",
     "einops == 0.7",
@@ -18,14 +18,13 @@ maintainers = [{name="Laurent Mazaré", email="[email protected]"}]
 license = {text = "MIT"}
 dynamic = ["version"]

-[tool.setuptools.dynamic]
-version = {attr = "moshi.__version__"}
-
 [build-system]
 requires = ["setuptools"]
 build-backend = "setuptools.build_meta"

 [tool.setuptools]
 packages = ["moshi", "moshi.utils", "moshi.modules", "moshi.models", "moshi.quantization"]

+[tool.setuptools.dynamic]
+version = {attr = "moshi.__version__"}

1 change: 1 addition & 0 deletions moshi/requirements.txt
@@ -5,5 +5,6 @@ sounddevice==0.5.0
 soundfile==0.12.1
 sphn==0.1.4
 torch==2.2.0
+numpy==1.26.4
 aiohttp>=3.10.5, <3.11
 huggingface-hub==0.24.6
7 changes: 7 additions & 0 deletions moshi_mlx/moshi_mlx/local_web.py
@@ -17,6 +17,7 @@
 import sphn
 import aiohttp
 from aiohttp import web
+import webbrowser

 import mlx.core as mx
 import mlx.nn as nn
@@ -334,6 +335,11 @@ async def handle_root(_):
     runner = web.AppRunner(app)
     await runner.setup()
     site = web.TCPSite(runner, args.host, args.port)
+
+    if not args.no_browser:
+        log("info", f"opening browser at http://{args.host}:{args.port}")
+        webbrowser.open(f"http://{args.host}:{args.port}")
+
     await asyncio.gather(
         recv_loop(), send_loop(), recv_loop2(), send_loop2(), site.start()
     )
@@ -356,6 +362,7 @@ def main():
     parser.add_argument("--static", type=str)
     parser.add_argument("--host", default="localhost", type=str)
     parser.add_argument("--port", default=8998, type=int)
+    parser.add_argument("--no-browser", action="store_true")

     args = parser.parse_args()

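The auto-open behavior plus its opt-out reduces to a small stdlib pattern; a self-contained sketch:

```python
import argparse
import webbrowser

parser = argparse.ArgumentParser()
parser.add_argument("--host", default="localhost", type=str)
parser.add_argument("--port", default=8998, type=int)
parser.add_argument("--no-browser", action="store_true")
args = parser.parse_args()

# argparse maps --no-browser to args.no_browser (dashes become underscores).
if not args.no_browser:
    webbrowser.open(f"http://{args.host}:{args.port}")
```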
2 changes: 1 addition & 1 deletion rust/moshi-backend/config-q8.json
@@ -1,6 +1,6 @@
 {
     "instance_name": "foo",
-    "hf_repo": "kmhf/msh-v0.1",
+    "hf_repo": "kmhf/moshi-v0.1",
     "lm_model_file": "$HOME/tmp/[email protected]",
     "text_tokenizer_file": "$HOME/tmp/tokenizer_spm_32k_3.model",
     "log_dir": "$HOME/tmp/moshi-logs",
2 changes: 1 addition & 1 deletion rust/moshi-backend/config.json
@@ -1,6 +1,6 @@
 {
     "instance_name": "foo",
-    "hf_repo": "kmhf/msh-v0.1",
+    "hf_repo": "kmhf/moshi-v0.1",
     "lm_model_file": "$HOME/tmp/[email protected]",
     "text_tokenizer_file": "$HOME/tmp/tokenizer_spm_32k_3.model",
     "log_dir": "$HOME/tmp/moshi-logs",
59 changes: 33 additions & 26 deletions scripts/mimi_streaming_test.py
@@ -3,20 +3,23 @@
 # LICENSE file in the root directory of this source tree.

 import argparse
-import moshi
+import random
 import time
-import torch

+import numpy as np
 import sphn
+import torch
 from torch.profiler import profile, ProfilerActivity
-import numpy as np
-import random

-SAMPLE_RATE = moshi.models.moshi.SAMPLE_RATE
-DEVICE = "cuda:0"
-ENABLE_PROFILING = False
+from moshi.models import loaders


 parser = argparse.ArgumentParser()
-parser.add_argument("--weights", type=str)
+parser.add_argument("--weights", type=str, default=loaders.MIMI_V0_1)
+parser.add_argument("--hf-repo", type=str, default=loaders.HF_REPO)
+parser.add_argument("--device", type=str,
+                    default='cuda' if torch.cuda.device_count() else 'cpu')
+parser.add_argument("--profile", action='store_true')
 args = parser.parse_args()


@@ -35,23 +38,27 @@ def seed_all(seed):


 print("loading mimi")
-ec = moshi.models.moshi.get_encodec(args.weights, DEVICE)
+mimi = loaders.get_mimi(
+    loaders.resolve_model_checkpoint(args.weights, args.hf_repo),
+    args.device)
 print("mimi loaded")


-def encodec_streaming_test(ec, pcm_chunk_size=1920, max_duration_sec=10.0):
+def mimi_streaming_test(mimi, max_duration_sec=10.0):
+    pcm_chunk_size = int(mimi.sample_rate / mimi.frame_rate)
     # wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
     sample_pcm, sample_sr = sphn.read("bria.mp3")
+    sample_rate = mimi.sample_rate
     print("loaded pcm", sample_pcm.shape, sample_sr)
     sample_pcm = sphn.resample(
-        sample_pcm, src_sample_rate=sample_sr, dst_sample_rate=SAMPLE_RATE
+        sample_pcm, src_sample_rate=sample_sr, dst_sample_rate=sample_rate
     )
-    sample_pcm = torch.tensor(sample_pcm, device=DEVICE)
-    max_duration_len = int(SAMPLE_RATE * max_duration_sec)
+    sample_pcm = torch.tensor(sample_pcm, device=args.device)
+    max_duration_len = int(sample_rate * max_duration_sec)
     if sample_pcm.shape[-1] > max_duration_len:
         sample_pcm = sample_pcm[..., :max_duration_len]
     print("resampled pcm", sample_pcm.shape, sample_sr)
-    sample_pcm = sample_pcm[None].to(device=DEVICE)
+    sample_pcm = sample_pcm[None].to(device=args.device)

     print("streaming encoding...")
     start_time = time.time()
@@ -61,34 +68,34 @@ def run_loop():
         for start_idx in range(0, sample_pcm.shape[-1], pcm_chunk_size):
             end_idx = min(sample_pcm.shape[-1], start_idx + pcm_chunk_size)
             chunk = sample_pcm[..., start_idx:end_idx]
-            codes, _scale = ec.encode(chunk)
+            codes = mimi.encode(chunk)
             if codes.shape[-1]:
                 print(start_idx, codes.shape, end="\r")
                 all_codes.append(codes)

-    if ENABLE_PROFILING:
+    if args.profile:
         with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
             run_loop()
         prof.export_chrome_trace("trace.json")
     else:
         run_loop()
-    all_codes = torch.cat(all_codes, dim=-1)
-    print(f"codes {all_codes.shape} generated in {time.time() - start_time:.2f}s")
+    all_codes_th = torch.cat(all_codes, dim=-1)
+    print(f"codes {all_codes_th.shape} generated in {time.time() - start_time:.2f}s")
     print("streaming decoding...")
     all_pcms = []
-    with ec.streaming():
-        for i in range(all_codes.shape[-1]):
-            codes = all_codes[..., i : i + 1]
-            pcm = ec.decode(codes, scale=None)
+    with mimi.streaming(1):
+        for i in range(all_codes_th.shape[-1]):
+            codes = all_codes_th[..., i : i + 1]
+            pcm = mimi.decode(codes)
             print(i, pcm.shape, end="\r")
             all_pcms.append(pcm)
     all_pcms = torch.cat(all_pcms, dim=-1)
     print("pcm", all_pcms.shape, all_pcms.dtype)
-    sphn.write_wav("streaming_out.wav", all_pcms[0, 0].cpu().numpy(), SAMPLE_RATE)
-    pcm = ec.decode(all_codes, scale=None)
+    sphn.write_wav("streaming_out.wav", all_pcms[0, 0].cpu().numpy(), sample_rate)
+    pcm = mimi.decode(all_codes_th)
     print("pcm", pcm.shape, pcm.dtype)
-    sphn.write_wav.write_wav("roundtrip_out.wav", pcm[0, 0].cpu().numpy(), SAMPLE_RATE)
+    sphn.write_wav("roundtrip_out.wav", pcm[0, 0].cpu().numpy(), sample_rate)


 with torch.no_grad():
-    encodec_streaming_test(ec)
+    mimi_streaming_test(mimi)
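Condensed, the round-trip the rewritten script performs looks like the sketch below. It assumes `mimi` was loaded via `loaders.get_mimi` as above and `pcm` is a `[B, C, T]` tensor on the right device; whether the encode loop also runs under `mimi.streaming` is not visible in this hunk, so that part is an assumption:

```python
import torch

def roundtrip(mimi, pcm: torch.Tensor) -> torch.Tensor:
    # Samples per codec frame, as computed at the top of mimi_streaming_test.
    chunk = int(mimi.sample_rate / mimi.frame_rate)
    with torch.no_grad():
        with mimi.streaming(1):  # assumed: streaming state for the encode pass
            codes = torch.cat([mimi.encode(pcm[..., i:i + chunk])
                               for i in range(0, pcm.shape[-1], chunk)], dim=-1)
        with mimi.streaming(1):  # fresh streaming state for the decode pass
            frames = [mimi.decode(codes[..., i:i + 1])
                      for i in range(codes.shape[-1])]
    return torch.cat(frames, dim=-1)
```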
