From 811c885c644e5f5aea0444ae148c32dd37996ffd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Alexandre=20D=C3=A9fossez?=
Date: Tue, 17 Sep 2024 17:22:41 +0200
Subject: [PATCH] fix

---
 moshi/moshi/models/loaders.py  |  2 +-
 moshi/moshi/server.py          | 12 ++---
 moshi/pyproject.toml           |  2 +-
 moshi/requirements.txt         |  1 +
 scripts/mimi_streaming_test.py |  3 +-
 scripts/moshi_benchmark.py     | 87 ++++++++++++++++++----------------
 6 files changed, 57 insertions(+), 50 deletions(-)

diff --git a/moshi/moshi/models/loaders.py b/moshi/moshi/models/loaders.py
index af0ca7a..1917694 100644
--- a/moshi/moshi/models/loaders.py
+++ b/moshi/moshi/models/loaders.py
@@ -104,7 +104,7 @@ def _is_safetensors(path: Path | str) -> bool:
     return Path(path).suffix in (".safetensors", ".sft", ".sfts")
 
 
-def resolve_model_checkpoint(name: str, hf_repo: str = HF_REPO, allow_local_file: bool = False) -> Path:
+def resolve_model_checkpoint(name: str, hf_repo: str = HF_REPO, allow_local_file: bool = True) -> Path:
     """Load a model checkpoint from HF.
     If `allow_local_file` is True, then if a file `name` exists, it will be used instead.
     """
diff --git a/moshi/moshi/server.py b/moshi/moshi/server.py
index e610061..a67a332 100644
--- a/moshi/moshi/server.py
+++ b/moshi/moshi/server.py
@@ -175,11 +175,11 @@ def main():
     parser.add_argument("--gradio_tunnel_token",
                         help='Provide a custom (secret) token here to keep getting the same URL.')
-    parser.add_argument("--tokenizer-name", type=str, default=loaders.TEXT_TOKENIZER_V0_1,
+    parser.add_argument("--tokenizer", type=str, default=loaders.TEXT_TOKENIZER_V0_1,
                         help="Name of the text tokenizer file in the given HF repo, or path to a local file.")
-    parser.add_argument("--moshi-name", type=str, default=loaders.MOSHIKO_V0_1,
+    parser.add_argument("--moshi-weight", type=str, default=loaders.MOSHIKO_V0_1,
                         help="Name of the Moshi checkpoint in the given HF repo, or path to a local file.")
-    parser.add_argument("--mimi-name", type=str, default=loaders.MIMI_V0_1,
+    parser.add_argument("--mimi-weight", type=str, default=loaders.MIMI_V0_1,
                         help="Name of the Mimi checkpoint in the given HF repo, or path to a local file.")
     parser.add_argument("--hf-repo", type=str, default=loaders.HF_REPO,
                         help="HF repo to look into, defaults to Kyutai official one.")
@@ -204,15 +204,15 @@ def main():
     tunnel_token = args.gradio_tunnel_token
 
     log("info", "loading mimi")
-    mimi_path = loaders.resolve_model_checkpoint(args.mimi_name, args.hf_repo, allow_local_file=True)
+    mimi_path = loaders.resolve_model_checkpoint(args.mimi_weight, args.hf_repo)
     mimi = loaders.get_mimi(mimi_path, args.device)
     log("info", "mimi loaded")
 
-    tokenizer_path = loaders.resolve_model_checkpoint(args.tokenizer_name, args.hf_repo, allow_local_file=True)
+    tokenizer_path = loaders.resolve_model_checkpoint(args.tokenizer, args.hf_repo)
     text_tokenizer = loaders.get_text_tokenizer(tokenizer_path)
 
     log("info", "loading moshi")
-    moshi_path = loaders.resolve_model_checkpoint(args.moshi_name, args.hf_repo, allow_local_file=True)
+    moshi_path = loaders.resolve_model_checkpoint(args.moshi_weight, args.hf_repo)
     lm = loaders.get_moshi_lm(moshi_path, args.device)
     log("info", "moshi loaded")
 
diff --git a/moshi/pyproject.toml b/moshi/pyproject.toml
index d7044c3..d5b7578 100644
--- a/moshi/pyproject.toml
+++ b/moshi/pyproject.toml
@@ -3,7 +3,7 @@ name = "moshi"
 requires-python = ">= 3.10"
 description = "Moshi is moshi"
 dependencies = [
-    "numpy >= 2.1.0, < 2.2",
+    "numpy >= 1.26, < 2.2",
     "safetensors >= 0.4.0, < 0.5",
     "huggingface-hub >= 0.24, < 0.25",
     "einops == 0.7",
diff --git a/moshi/requirements.txt b/moshi/requirements.txt
index 9a93905..876de9d 100644
--- a/moshi/requirements.txt
+++ b/moshi/requirements.txt
@@ -5,5 +5,6 @@ sounddevice==0.5.0
 soundfile==0.12.1
 sphn==0.1.4
 torch==2.2.0
+numpy==1.26.4
 aiohttp>=3.10.5, <3.11
 huggingface-hub==0.24.6
diff --git a/scripts/mimi_streaming_test.py b/scripts/mimi_streaming_test.py
index 5f8c4da..54865a3 100644
--- a/scripts/mimi_streaming_test.py
+++ b/scripts/mimi_streaming_test.py
@@ -44,7 +44,8 @@ def seed_all(seed):
 print("mimi loaded")
 
 
-def mimi_streaming_test(mimi, pcm_chunk_size=1920, max_duration_sec=10.0):
+def mimi_streaming_test(mimi, max_duration_sec=10.0):
+    pcm_chunk_size = int(mimi.sample_rate / mimi.frame_rate)
     # wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
     sample_pcm, sample_sr = sphn.read("bria.mp3")
     sample_rate = mimi.sample_rate
diff --git a/scripts/moshi_benchmark.py b/scripts/moshi_benchmark.py
index 056542f..0bd015b 100644
--- a/scripts/moshi_benchmark.py
+++ b/scripts/moshi_benchmark.py
@@ -3,26 +3,30 @@
 # LICENSE file in the root directory of this source tree.
 
 import argparse
-import moshi
-import sentencepiece
-import torch
-import sphn
-import numpy as np
 import random
 import time
+import numpy as np
+import sentencepiece
+import sphn
+import torch
 from torch.profiler import profile, ProfilerActivity
 
-SAMPLE_RATE = moshi.models.moshi.SAMPLE_RATE
-DEVICE = "cuda:0"
-ENABLE_PROFILING = False
+from moshi.models import loaders, LMGen
+
 
 parser = argparse.ArgumentParser()
-parser.add_argument("--tokenizer", type=str)
-parser.add_argument("--moshi-weights", type=str)
-parser.add_argument("--mimi-weights", type=str)
+parser.add_argument("--tokenizer", type=str, default=loaders.TEXT_TOKENIZER_V0_1,
+                    help="Name of the text tokenizer file in the given HF repo, or path to a local file.")
+parser.add_argument("--moshi-weight", type=str, default=loaders.MOSHIKO_V0_1,
+                    help="Name of the Moshi checkpoint in the given HF repo, or path to a local file.")
+parser.add_argument("--mimi-weight", type=str, default=loaders.MIMI_V0_1,
+                    help="Name of the Mimi checkpoint in the given HF repo, or path to a local file.")
+parser.add_argument("--hf-repo", type=str, default=loaders.HF_REPO,
+                    help="HF repo to look into, defaults to Kyutai official one.")
 parser.add_argument("--steps", default=100, type=int)
 parser.add_argument("--profile", action="store_true")
+parser.add_argument("--device", type=str, default='cuda')
 
 args = parser.parse_args()
@@ -39,52 +43,53 @@ def seed_all(seed):
 seed_all(42424242)
 
 
+tokenizer_path = loaders.resolve_model_checkpoint(args.tokenizer, args.hf_repo)
+text_tokenizer = loaders.get_text_tokenizer(tokenizer_path)
 print("loading mimi")
-ec = moshi.models.moshi.get_encodec(args.mimi_weights, DEVICE)
+mimi_path = loaders.resolve_model_checkpoint(args.mimi_weight, args.hf_repo)
+mimi = loaders.get_mimi(mimi_path, args.device)
 print("mimi loaded")
-text_tokenizer = sentencepiece.SentencePieceProcessor(args.tokenizer)
 
 print("loading moshi")
-lm = moshi.models.moshi.get_lm(args.moshi_weights, DEVICE)
-lm.to(torch.bfloat16)
+moshi_path = loaders.resolve_model_checkpoint(args.moshi_weight, args.hf_repo)
+lm = loaders.get_moshi_lm(moshi_path, args.device)
+lm_gen = LMGen(lm)
 print("lm loaded")
 
-lm_gen = moshi.models.LMGen(lm)
-
 
 def cb(step, total):
     print(f"{step:06d} / {total:06d}", end="\r")
 
 
 def streaming_test(bs):
-
     main_audio = []
     main_text = []
+    frame_size = int(mimi.sample_rate / mimi.frame_rate)
+
     def run_step():
         start_time = time.time()
         # Chunk should contain the pcm data from the user, single channel with a sample rate of 24000.
-        chunk = torch.zeros((bs, 1, 1920), dtype=torch.float, device=DEVICE)
-        codes = ec.encode(chunk)
+        chunk = torch.zeros((bs, 1, frame_size), dtype=torch.float, device=args.device)
+        codes = mimi.encode(chunk)
         assert codes.shape[-1] == 1
-        for c in range(codes.shape[-1]):
-            be = time.time()
-            ev = torch.cuda.Event(enable_timing=True)
-            ev.record()
-            tokens = lm_gen.step(codes[:, :, c : c + 1])
-            if tokens is None:
-                print("Skipping")
-                return
-            evb = torch.cuda.Event(enable_timing=True)
-            evb.record()
-            dt_step = time.time() - be
-            text_tokens = tokens[:, 0, 0]
-            audio_tokens = tokens[:, 1:, :]
-            main_pcm = ec.decode(audio_tokens)
-            # main_pcm is the audio to be played back to the user, here we just append it and store it in
-            # a file once the loop is finished.
-            main_audio.append(main_pcm[0])
+        be = time.time()
+        ev = torch.cuda.Event(enable_timing=True)
+        ev.record()
+        tokens = lm_gen.step(codes[:, :, :1])
+        if tokens is None:
+            print("Skipping")
+            return
+        evb = torch.cuda.Event(enable_timing=True)
+        evb.record()
+        dt_step = time.time() - be
+        text_tokens = tokens[:, 0, 0]
+        audio_tokens = tokens[:, 1:, :]
+        main_pcm = mimi.decode(audio_tokens)
+        # main_pcm is the audio to be played back to the user, here we just append it and store it in
+        # a file once the loop is finished.
+        main_audio.append(main_pcm[0])
         evb.synchronize()
         dg = ev.elapsed_time(evb)
         torch.cuda.synchronize()
@@ -109,17 +114,17 @@ def run_step():
         run_step()
     print()
     prof.export_chrome_trace("trace.json")
-    main_audio = torch.cat(main_audio, dim=-1)
-    print(main_audio.shape)
+    main_audio_th = torch.cat(main_audio, dim=-1)
+    print(main_audio_th.shape)
     print("generated text:")
     print("".join(main_text))
     sphn.write_wav(
-        "gen_main.wav", main_audio[0].cpu().numpy().astype(np.float32), SAMPLE_RATE
+        "gen_main.wav", main_audio_th[0].cpu().numpy().astype(np.float32), mimi.sample_rate
     )
 
 
 print("streaming test")
 bs = 1
 with torch.no_grad():
-    with ec.streaming(bs), lm_gen.streaming(bs):
+    with mimi.streaming(bs), lm_gen.streaming(bs):
         streaming_test(bs)
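
Note on the frame-size change: the hardcoded 1920 that both scripts used is exactly one Mimi frame of PCM at the model's 24000 Hz sample rate and 12.5 Hz frame rate, so deriving it as int(mimi.sample_rate / mimi.frame_rate) keeps the scripts correct if those constants ever change. A minimal sketch of the arithmetic (illustration only, not part of the patch; the two rates are assumptions taken from the run_step comment and Mimi's advertised frame rate):

    # Why int(mimi.sample_rate / mimi.frame_rate) reproduces the old constant.
    sample_rate = 24000.0  # PCM samples per second (assumed, per the run_step comment)
    frame_rate = 12.5      # Mimi frames per second (assumed)
    frame_size = int(sample_rate / frame_rate)
    assert frame_size == 1920  # matches the literal this patch removes from both scripts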
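
With the renamed flags, scripts/moshi_benchmark.py now resolves and loads checkpoints the same way as moshi/server.py. A minimal sketch of that shared flow (illustration only; every name below appears in the diff, and the 'cuda' device string mirrors the new --device default):

    from moshi.models import loaders, LMGen

    # Resolve checkpoints from the HF repo; local file paths also work now
    # that resolve_model_checkpoint defaults to allow_local_file=True.
    mimi_path = loaders.resolve_model_checkpoint(loaders.MIMI_V0_1, loaders.HF_REPO)
    mimi = loaders.get_mimi(mimi_path, 'cuda')

    tokenizer_path = loaders.resolve_model_checkpoint(loaders.TEXT_TOKENIZER_V0_1, loaders.HF_REPO)
    text_tokenizer = loaders.get_text_tokenizer(tokenizer_path)

    moshi_path = loaders.resolve_model_checkpoint(loaders.MOSHIKO_V0_1, loaders.HF_REPO)
    lm = loaders.get_moshi_lm(moshi_path, 'cuda')
    lm_gen = LMGen(lm)  # streaming generation wrapper used by the benchmark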