Merge branch 'rename' into ci
adefossez committed Sep 18, 2024
2 parents 85bcac8 + 5112385 commit 59e732c
Showing 6 changed files with 57 additions and 50 deletions.
2 changes: 1 addition & 1 deletion moshi/moshi/models/loaders.py
@@ -104,7 +104,7 @@ def _is_safetensors(path: Path | str) -> bool:
     return Path(path).suffix in (".safetensors", ".sft", ".sfts")


-def resolve_model_checkpoint(name: str, hf_repo: str = HF_REPO, allow_local_file: bool = False) -> Path:
+def resolve_model_checkpoint(name: str, hf_repo: str = HF_REPO, allow_local_file: bool = True) -> Path:
     """Load a model checkpoint from HF.
     If `allow_local_file` is True, then if a file `name` exists, it will be used instead.
     """
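With the default flipped to allow_local_file=True, callers no longer have to pass the flag explicitly to point at a checkpoint on disk. A minimal sketch of both call patterns, assuming only the loaders API visible in this diff (the local path below is purely illustrative):

from moshi.models import loaders

# Resolve a checkpoint by its name inside the default HF repo.
mimi_path = loaders.resolve_model_checkpoint(loaders.MIMI_V0_1)

# With allow_local_file=True now the default, an existing local file is used as-is.
local_path = loaders.resolve_model_checkpoint("./checkpoints/mimi.safetensors")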
12 changes: 6 additions & 6 deletions moshi/moshi/server.py
@@ -175,11 +175,11 @@ def main():
     parser.add_argument("--gradio_tunnel_token",
                         help='Provide a custom (secret) token here to keep getting the same URL.')

-    parser.add_argument("--tokenizer-name", type=str, default=loaders.TEXT_TOKENIZER_V0_1,
+    parser.add_argument("--tokenizer", type=str, default=loaders.TEXT_TOKENIZER_V0_1,
                         help="Name of the text tokenizer file in the given HF repo, or path to a local file.")
-    parser.add_argument("--moshi-name", type=str, default=loaders.MOSHIKO_V0_1,
+    parser.add_argument("--moshi-weight", type=str, default=loaders.MOSHIKO_V0_1,
                         help="Name of the Moshi checkpoint in the given HF repo, or path to a local file.")
-    parser.add_argument("--mimi-name", type=str, default=loaders.MIMI_V0_1,
+    parser.add_argument("--mimi-weight", type=str, default=loaders.MIMI_V0_1,
                         help="Name of the Mimi checkpoint in the given HF repo, or path to a local file.")
     parser.add_argument("--hf-repo", type=str, default=loaders.HF_REPO,
                         help="HF repo to look into, defaults to Kyutai official one.")
@@ -204,15 +204,15 @@ def main():
     tunnel_token = args.gradio_tunnel_token

     log("info", "loading mimi")
-    mimi_path = loaders.resolve_model_checkpoint(args.mimi_name, args.hf_repo, allow_local_file=True)
+    mimi_path = loaders.resolve_model_checkpoint(args.mimi_weight, args.hf_repo)
     mimi = loaders.get_mimi(mimi_path, args.device)
     log("info", "mimi loaded")

-    tokenizer_path = loaders.resolve_model_checkpoint(args.tokenizer_name, args.hf_repo, allow_local_file=True)
+    tokenizer_path = loaders.resolve_model_checkpoint(args.tokenizer, args.hf_repo)
     text_tokenizer = loaders.get_text_tokenizer(tokenizer_path)

     log("info", "loading moshi")
-    moshi_path = loaders.resolve_model_checkpoint(args.moshi_name, args.hf_repo, allow_local_file=True)
+    moshi_path = loaders.resolve_model_checkpoint(args.moshi_weight, args.hf_repo)
     lm = loaders.get_moshi_lm(moshi_path, args.device)
     log("info", "moshi loaded")
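In short, --tokenizer-name, --moshi-name and --mimi-name become --tokenizer, --moshi-weight and --mimi-weight, and the server now relies on the new allow_local_file=True default instead of passing it explicitly. A hypothetical invocation with the renamed flags (the entry point and file paths are illustrative, not taken from this commit; all three flags also have working defaults pointing at the official repo):

python -m moshi.server --tokenizer ./checkpoints/tokenizer.model --moshi-weight ./checkpoints/moshiko.safetensors --mimi-weight ./checkpoints/mimi.safetensors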
2 changes: 1 addition & 1 deletion moshi/pyproject.toml
@@ -3,7 +3,7 @@ name = "moshi"
 requires-python = ">= 3.10"
 description = "Moshi is moshi"
 dependencies = [
-    "numpy >= 2.1.0, < 2.2",
+    "numpy >= 1.26, < 2.2",
     "safetensors >= 0.4.0, < 0.5",
     "huggingface-hub >= 0.24, < 0.25",
     "einops == 0.7",
1 change: 1 addition & 0 deletions moshi/requirements.txt
@@ -5,5 +5,6 @@ sounddevice==0.5.0
 soundfile==0.12.1
 sphn==0.1.4
 torch==2.2.0
+numpy==1.26.4
 aiohttp>=3.10.5, <3.11
 huggingface-hub==0.24.6
3 changes: 2 additions & 1 deletion scripts/mimi_streaming_test.py
@@ -44,7 +44,8 @@ def seed_all(seed):
 print("mimi loaded")


-def mimi_streaming_test(mimi, pcm_chunk_size=1920, max_duration_sec=10.0):
+def mimi_streaming_test(mimi, max_duration_sec=10.0):
+    pcm_chunk_size = int(mimi.sample_rate / mimi.frame_rate)
     # wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
     sample_pcm, sample_sr = sphn.read("bria.mp3")
     sample_rate = mimi.sample_rate
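The chunk size is now derived from the codec instead of being hardcoded. Assuming Mimi's usual 24 kHz sample rate and 12.5 Hz frame rate (neither number is stated in this diff), the arithmetic reproduces the old constant:

sample_rate = 24000                              # Hz, assumed Mimi audio sample rate
frame_rate = 12.5                                # Hz, assumed Mimi codec frame rate
pcm_chunk_size = int(sample_rate / frame_rate)   # 24000 / 12.5 = 1920 samples per frame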
87 changes: 46 additions & 41 deletions scripts/moshi_benchmark.py
@@ -3,26 +3,30 @@
 # LICENSE file in the root directory of this source tree.

 import argparse
-import moshi
-import sentencepiece
-import torch
-import sphn
-import numpy as np
 import random
 import time

+import numpy as np
+import sentencepiece
+import sphn
+import torch
 from torch.profiler import profile, ProfilerActivity

-SAMPLE_RATE = moshi.models.moshi.SAMPLE_RATE
-DEVICE = "cuda:0"
-ENABLE_PROFILING = False
+from moshi.models import loaders, LMGen
+

 parser = argparse.ArgumentParser()
-parser.add_argument("--tokenizer", type=str)
-parser.add_argument("--moshi-weights", type=str)
-parser.add_argument("--mimi-weights", type=str)
+parser.add_argument("--tokenizer", type=str, default=loaders.TEXT_TOKENIZER_V0_1,
+                    help="Name of the text tokenizer file in the given HF repo, or path to a local file.")
+parser.add_argument("--moshi-weight", type=str, default=loaders.MOSHIKO_V0_1,
+                    help="Name of the Moshi checkpoint in the given HF repo, or path to a local file.")
+parser.add_argument("--mimi-weight", type=str, default=loaders.MIMI_V0_1,
+                    help="Name of the Mimi checkpoint in the given HF repo, or path to a local file.")
+parser.add_argument("--hf-repo", type=str, default=loaders.HF_REPO,
+                    help="HF repo to look into, defaults to Kyutai official one.")
 parser.add_argument("--steps", default=100, type=int)
 parser.add_argument("--profile", action="store_true")
+parser.add_argument("--device", type=str, default='cuda')
 args = parser.parse_args()
@@ -39,52 +39,53 @@ def seed_all(seed):

 seed_all(42424242)

+tokenizer_path = loaders.resolve_model_checkpoint(args.tokenizer, args.hf_repo)
+text_tokenizer = loaders.get_text_tokenizer(tokenizer_path)
+
 print("loading mimi")
-ec = moshi.models.moshi.get_encodec(args.mimi_weights, DEVICE)
+mimi_path = loaders.resolve_model_checkpoint(args.mimi_weight, args.hf_repo)
+mimi = loaders.get_mimi(mimi_path, args.device)
 print("mimi loaded")
-text_tokenizer = sentencepiece.SentencePieceProcessor(args.tokenizer)

 print("loading moshi")
-lm = moshi.models.moshi.get_lm(args.moshi_weights, DEVICE)
-lm.to(torch.bfloat16)
+moshi_path = loaders.resolve_model_checkpoint(args.moshi_weight, args.hf_repo)
+lm = loaders.get_moshi_lm(moshi_path, args.device)
+lm_gen = LMGen(lm)
 print("lm loaded")

-lm_gen = moshi.models.LMGen(lm)


 def cb(step, total):
     print(f"{step:06d} / {total:06d}", end="\r")


 def streaming_test(bs):

     main_audio = []
     main_text = []

+    frame_size = int(mimi.sample_rate / mimi.frame_rate)
+
     def run_step():
         start_time = time.time()
         # Chunk should contain the pcm data from the user, single channel with a sample rate of 24000.
-        chunk = torch.zeros((bs, 1, 1920), dtype=torch.float, device=DEVICE)
-        codes = ec.encode(chunk)
+        chunk = torch.zeros((bs, 1, frame_size), dtype=torch.float, device=args.device)
+        codes = mimi.encode(chunk)
         assert codes.shape[-1] == 1
-        for c in range(codes.shape[-1]):
-            be = time.time()
-            ev = torch.cuda.Event(enable_timing=True)
-            ev.record()
-            tokens = lm_gen.step(codes[:, :, c : c + 1])
-            if tokens is None:
-                print("Skipping")
-                return
-            evb = torch.cuda.Event(enable_timing=True)
-            evb.record()
-            dt_step = time.time() - be
-            text_tokens = tokens[:, 0, 0]
-            audio_tokens = tokens[:, 1:, :]
-            main_pcm = ec.decode(audio_tokens)
-            # main_pcm is the audio to be played back to the user, here we just append it and store it in
-            # a file once the loop is finished.
-            main_audio.append(main_pcm[0])
+        be = time.time()
+        ev = torch.cuda.Event(enable_timing=True)
+        ev.record()
+        tokens = lm_gen.step(codes[:, :, :1])
+        if tokens is None:
+            print("Skipping")
+            return
+        evb = torch.cuda.Event(enable_timing=True)
+        evb.record()
+        dt_step = time.time() - be
+        text_tokens = tokens[:, 0, 0]
+        audio_tokens = tokens[:, 1:, :]
+        main_pcm = mimi.decode(audio_tokens)
+        # main_pcm is the audio to be played back to the user, here we just append it and store it in
+        # a file once the loop is finished.
+        main_audio.append(main_pcm[0])
         evb.synchronize()
         dg = ev.elapsed_time(evb)
         torch.cuda.synchronize()
@@ -109,17 +114,17 @@ def run_step():
             run_step()
             print()
         prof.export_chrome_trace("trace.json")
-    main_audio = torch.cat(main_audio, dim=-1)
-    print(main_audio.shape)
+    main_audio_th = torch.cat(main_audio, dim=-1)
+    print(main_audio_th.shape)
     print("generated text:")
     print("".join(main_text))
     sphn.write_wav(
-        "gen_main.wav", main_audio[0].cpu().numpy().astype(np.float32), SAMPLE_RATE
+        "gen_main.wav", main_audio_th[0].cpu().numpy().astype(np.float32), mimi.sample_rate
     )


 print("streaming test")
 bs = 1
 with torch.no_grad():
-    with ec.streaming(bs), lm_gen.streaming(bs):
+    with mimi.streaming(bs), lm_gen.streaming(bs):
         streaming_test(bs)
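Taken together, the benchmark now goes through the moshi.models.loaders helpers instead of the old moshi.models.moshi module, and all magic constants (sample rate, chunk size, device) come from the loaded models or the CLI. A condensed sketch of that flow, using only calls that appear in this diff (the zero-filled chunk is just a stand-in for real 24 kHz user audio, and the single-step loop is simplified):

import torch
from moshi.models import loaders, LMGen

device = "cuda"
mimi = loaders.get_mimi(loaders.resolve_model_checkpoint(loaders.MIMI_V0_1), device)
lm = loaders.get_moshi_lm(loaders.resolve_model_checkpoint(loaders.MOSHIKO_V0_1), device)
lm_gen = LMGen(lm)

frame_size = int(mimi.sample_rate / mimi.frame_rate)  # samples per codec frame
chunk = torch.zeros((1, 1, frame_size), dtype=torch.float, device=device)  # stand-in for user audio

with torch.no_grad(), mimi.streaming(1), lm_gen.streaming(1):
    codes = mimi.encode(chunk)              # audio tokens, shape (batch, codebooks, 1)
    tokens = lm_gen.step(codes[:, :, :1])   # may be None; the benchmark skips such steps
    if tokens is not None:
        text_tokens = tokens[:, 0, 0]       # text stream
        pcm = mimi.decode(tokens[:, 1:, :]) # decode the audio streams back to a waveform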
