From 19469d938a1b3e4d3f922feb6efec29a3ad422bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexandre=20D=C3=A9fossez?= Date: Thu, 19 Sep 2024 15:33:17 +0200 Subject: [PATCH] Fix moshi_benchmark Fix for issue reported in PR https://github.com/kyutai-labs/moshi/pull/64 --- scripts/moshi_benchmark.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/scripts/moshi_benchmark.py b/scripts/moshi_benchmark.py index 081e9d0..e1089cd 100644 --- a/scripts/moshi_benchmark.py +++ b/scripts/moshi_benchmark.py @@ -17,14 +17,12 @@ parser = argparse.ArgumentParser() -parser.add_argument("--tokenizer", type=str, default=loaders.TEXT_TOKENIZER_V0_1, - help="Name of the text tokenizer file in the given HF repo, or path to a local file.") -parser.add_argument("--moshi-weight", type=str, default=loaders.MOSHIKO_V0_1, - help="Name of the Moshi checkpoint in the given HF repo, or path to a local file.") -parser.add_argument("--mimi-weight", type=str, default=loaders.MIMI_V0_1, - help="Name of the Mimi checkpoint in the given HF repo, or path to a local file.") -parser.add_argument("--hf-repo", type=str, default=loaders.HF_REPO, - help="HF repo to look into, defaults to Kyutai's official one.") +parser.add_argument("--tokenizer", type=str, help="Path to a local tokenizer file.") +parser.add_argument("--moshi-weight", type=str, help="Path to a local checkpoint file for Moshi.") +parser.add_argument("--mimi-weight", type=str, help="Path to a local checkpoint file for Mimi.") +parser.add_argument("--hf-repo", type=str, default=loaders.DEFAULT_REPO, + help="HF repo to look into, defaults Moshiko. " + "Use this to select a different pre-trained model.") parser.add_argument("--steps", default=100, type=int) parser.add_argument("--profile", action="store_true") parser.add_argument("--device", type=str, default='cuda') @@ -58,16 +56,12 @@ def seed_all(seed): if args.moshi_weight is None: args.moshi_weight = hf_hub_download(args.hf_repo, loaders.MOSHI_NAME) lm = loaders.get_moshi_lm(args.moshi_weight, args.device) -moshi_path = loaders.resolve_model_checkpoint(args.moshi_weight, args.hf_repo) -lm = loaders.get_moshi_lm(moshi_path, args.device) lm_gen = LMGen(lm) print("lm loaded") - def cb(step, total): print(f"{step:06d} / {total:06d}", end="\r") - def streaming_test(bs): main_audio = [] main_text = []