diff --git a/.github/ISSUE_TEMPLATE/bug.md b/.github/ISSUE_TEMPLATE/bug.md new file mode 100644 index 0000000..b0f3d65 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug.md @@ -0,0 +1,29 @@ +--- +name: 🐛 Bug Report +about: Submit a bug report to help us improve +labels: bug, triage +--- + +## 🐛 Bug Report + + + +## To Reproduce + + + +## Your Environment + + + +**Please fill in this part; failure to do so will lead to your issue being closed directly.** + +- Backend impacted (PyTorch, MLX, Rust, or Other): +- Operating system (OSX, Windows, Linux): +- Operating system version: + +If the backend impacted is PyTorch: +- Python version: +- PyTorch version: +- CUDA version (run `python -c 'import torch; print(torch.version.cuda)'`): +- GPU model and memory: diff --git a/.github/ISSUE_TEMPLATE/question.md b/.github/ISSUE_TEMPLATE/question.md new file mode 100644 index 0000000..a074579 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/question.md @@ -0,0 +1,9 @@ +--- +name: "❓Questions/Help/Support" +about: If you have a question about the paper, code or algorithm, please ask here! +labels: question, triage +--- + +## ❓ Questions + + diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..8ae0569 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1 @@ +# Test diff --git a/.github/actions/moshi_build/action.yml b/.github/actions/moshi_build/action.yml new file mode 100755 index 0000000..432edf4 --- /dev/null +++ b/.github/actions/moshi_build/action.yml @@ -0,0 +1,27 @@ +name: moshi_build +description: 'Build env.' +runs: + using: "composite" + steps: + - uses: actions/setup-python@v2 + with: + python-version: '3.10' + - uses: actions/cache@v3 + id: cache + with: + path: env + key: env-${{ hashFiles('moshi/pyproject.toml') }} + - name: Install dependencies + if: steps.cache.outputs.cache-hit != 'true' + shell: bash + run: | + python3 -m venv env + . 
env/bin/activate + python -m pip install --upgrade pip + pip install torch==2.4.0 --index-url https://download.pytorch.org/whl/cpu + pip install -e './moshi[dev]' + - name: Setup env + shell: bash + run: | + . env/bin/activate + pre-commit install diff --git a/.github/actions/rust_build/action.yml b/.github/actions/rust_build/action.yml new file mode 100755 index 0000000..7a11a68 --- /dev/null +++ b/.github/actions/rust_build/action.yml @@ -0,0 +1,33 @@ +name: rust_build +description: 'Setup rust env' +inputs: + os: + default: ubuntu-latest + toolchain: + default: stable + target: + default: check +runs: + using: "composite" + steps: + - uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: ${{ inputs.toolchain }} + override: true + - name: cargo cache + uses: actions/cache@v3 + with: + path: | + ~/.cargo/bin/ + ~/.cargo/registry/index/ + ~/.cargo/registry/cache/ + ~/.cargo/git/db/ + rust/target/ + key: ${{ inputs.os }}-cargo-${{ inputs.target }}-${{ hashFiles('**/Cargo.lock') }} + restore-keys: ${{ inputs.os }}-cargo- + - name: install deps + shell: bash + run: | + sudo apt-get update + sudo apt-get install libasound2-dev diff --git a/.github/workflows/precommit.yml b/.github/workflows/precommit.yml new file mode 100644 index 0000000..51ca7ba --- /dev/null +++ b/.github/workflows/precommit.yml @@ -0,0 +1,17 @@ +name: precommit +on: + push: + branches: [ main ] + pull_request: + branches: [ main, refacto ] + +jobs: + run_precommit: + name: Run precommit + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: ./.github/actions/moshi_build + - run: | + . 
env/bin/activate + bash .git/hooks/pre-commit diff --git a/.github/workflows/rust-ci.yml b/.github/workflows/rust-ci.yml new file mode 100644 index 0000000..43448c9 --- /dev/null +++ b/.github/workflows/rust-ci.yml @@ -0,0 +1,53 @@ +on: + push: + branches: [ main ] + pull_request: + branches: [ main, refacto ] + +name: Rust CI + +jobs: + check: + name: Check + defaults: + run: + working-directory: ./rust + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + rust: [stable] + steps: + - uses: actions/checkout@v2 + - uses: ./.github/actions/rust_build + - name: check + shell: bash + run: | + cargo check + - name: clippy + shell: bash + run: | + cargo clippy -- -D warnings + - name: fmt + shell: bash + run: | + cargo fmt --all -- --check + test: + name: Test + defaults: + run: + working-directory: ./rust + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + rust: [stable] + steps: + - uses: actions/checkout@v2 + - uses: ./.github/actions/rust_build + with: + target: test + - name: test + shell: bash + run: | + cargo test diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..a977dab --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,24 @@ +repos: + - repo: local + hooks: + - id: flake8-moshi + name: flake8 on moshi package + language: system + entry: bash -c 'cd moshi && flake8' + pass_filenames: false + - id: pyright-moshi + name: pyright on moshi package + language: system + entry: bash -c 'cd moshi && pyright' + pass_filenames: false + - id: flake8-moshi_mlx + name: flake8 on moshi_mlx package + language: system + entry: bash -c 'cd moshi_mlx && flake8' + pass_filenames: false + - id: pyright-moshi_mlx + name: pyright on moshi_mlx package + language: system + entry: bash -c 'cd moshi_mlx && pyright' + pass_filenames: false + diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..af18af5 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,30 @@ +# 
Contributing to Moshi + +## Pull Requests + +Moshi is the implementation of a research paper. +Therefore, we do not plan on accepting many pull requests for new features. +We certainly welcome them for bug fixes. + +1. Fork the repo and create your branch from `main`. +2. If you've changed APIs, update the documentation. +3. Ensure pre-commit hooks pass properly, in particular the linting and typing. +4. Accept the Contributor License Agreement (see below). + +Note that in general we will not accept refactoring of the code. + + +## Contributor License Agreement ("CLA") + +In order to accept your pull request, we need you to submit a Contributor License Agreement. +As this CLA is not ready yet, we will delay acceptance of your PR. + +## Issues + +Please submit issues on our GitHub repository. + +## License + +By contributing to Moshi, you agree that your contributions will be licensed +under the LICENSE-* files in the root directory of this source tree. +In particular, the Rust code is licensed under Apache, and the Python code under MIT. diff --git a/moshi/LICENSE b/moshi/LICENSE new file mode 100644 index 0000000..31aa793 --- /dev/null +++ b/moshi/LICENSE @@ -0,0 +1,23 @@ +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. diff --git a/moshi/pyproject.toml b/moshi/pyproject.toml index d5b7578..e0da712 100644 --- a/moshi/pyproject.toml +++ b/moshi/pyproject.toml @@ -25,6 +25,9 @@ version = {attr = "moshi.__version__"} requires = ["setuptools"] build-backend = "setuptools.build_meta" +[tool.setuptools.dynamic] +version = {attr = "moshi.__version__"} + [project.optional-dependencies] dev = [ "pyright", diff --git a/moshi_mlx/LICENSE b/moshi_mlx/LICENSE new file mode 100644 index 0000000..31aa793 --- /dev/null +++ b/moshi_mlx/LICENSE @@ -0,0 +1,23 @@ +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. 
diff --git a/moshi_mlx/moshi_mlx/__init__.py b/moshi_mlx/moshi_mlx/__init__.py index 5876c62..2c0d3d3 100644 --- a/moshi_mlx/moshi_mlx/__init__.py +++ b/moshi_mlx/moshi_mlx/__init__.py @@ -1,12 +1,12 @@ # Copyright (c) Kyutai, all rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +# flake8: noqa """ moshi_mlx is the MLX inference codebase for Kyutai audio generation models. """ -# flake8: noqa from . import modules, models, utils __version__ = "0.0.1" diff --git a/moshi_mlx/moshi_mlx/local.py b/moshi_mlx/moshi_mlx/local.py index 742b56e..28ab3a6 100644 --- a/moshi_mlx/moshi_mlx/local.py +++ b/moshi_mlx/moshi_mlx/local.py @@ -20,7 +20,7 @@ from .client_utils import AnyPrinter, Printer, RawPrinter import rustymimi -import moshi_mlx +from moshi_mlx import models, utils import huggingface_hub @@ -99,10 +99,10 @@ def log(s): printer_q.put_nowait((PrinterType.INFO, s)) log(f"[SERVER] loading text tokenizer {tokenizer_file}") - text_tokenizer = sentencepiece.SentencePieceProcessor(tokenizer_file) + text_tokenizer = sentencepiece.SentencePieceProcessor(tokenizer_file) # type: ignore mx.random.seed(299792458) - lm_config = moshi_mlx.models.config_v0_1() - model = moshi_mlx.models.Lm(lm_config) + lm_config = models.config_v0_1() + model = models.Lm(lm_config) model.set_dtype(mx.bfloat16) if args.quantized is not None: group_size = 32 if args.quantized == 4 else 64 @@ -114,11 +114,11 @@ def log(s): model.warmup() log("[SERVER] model warmed up") - gen = moshi_mlx.models.LmGen( + gen = models.LmGen( model=model, max_steps=steps + 5, - text_sampler=moshi_mlx.utils.Sampler(), - audio_sampler=moshi_mlx.utils.Sampler(), + text_sampler=utils.Sampler(), + audio_sampler=utils.Sampler(), check=False, ) @@ -137,7 +137,7 @@ def log(s): text_token = text_token[0].item() audio_tokens = gen.last_audio_tokens() if text_token not in (0, 3): - _text = text_tokenizer.id_to_piece(text_token) + _text = 
text_tokenizer.id_to_piece(text_token) # type: ignore _text = _text.replace("▁", " ") printer_q.put_nowait((PrinterType.TOKEN, _text)) else: @@ -158,7 +158,7 @@ def client(printer_q, client_to_server, server_to_client, args): ) input_queue = queue.Queue() output_queue = queue.Queue() - audio_tokenizer = rustymimi.StreamTokenizer(mimi_file) + audio_tokenizer = rustymimi.StreamTokenizer(mimi_file) # type: ignore start = server_to_client.get() printer_q.put_nowait( (PrinterType.INFO, f"[CLIENT] received '{start}' from server, starting...") @@ -255,9 +255,9 @@ def main(printer: AnyPrinter): parser.add_argument("--tokenizer", type=str) parser.add_argument("--model", type=str) parser.add_argument("--mimi", type=str) - parser.add_argument("--quantized", type=int) + parser.add_argument("-q", "--quantized", type=int) parser.add_argument("--steps", default=2500, type=int) - parser.add_argument("--hf-repo", type=str, default="") + parser.add_argument("--hf-repo", type=str, default="kmhf/msh-v0.1") args = parser.parse_args() diff --git a/moshi_mlx/moshi_mlx/local_web.py b/moshi_mlx/moshi_mlx/local_web.py index 2e22aba..a3632e5 100644 --- a/moshi_mlx/moshi_mlx/local_web.py +++ b/moshi_mlx/moshi_mlx/local_web.py @@ -22,7 +22,7 @@ import mlx.nn as nn import rustymimi -import moshi_mlx +from moshi_mlx import models, utils import huggingface_hub @@ -120,10 +120,10 @@ def model_server(client_to_server, server_to_client, args): steps = args.steps log("info", f"[SERVER] loading text tokenizer {tokenizer_file}") - text_tokenizer = sentencepiece.SentencePieceProcessor(tokenizer_file) + text_tokenizer = sentencepiece.SentencePieceProcessor(tokenizer_file) # type: ignore mx.random.seed(299792458) - lm_config = moshi_mlx.models.config_v0_1() - model = moshi_mlx.models.Lm(lm_config) + lm_config = models.config_v0_1() + model = models.Lm(lm_config) model.set_dtype(mx.bfloat16) if args.quantized is not None: group_size = 32 if args.quantized == 4 else 64 @@ -135,11 +135,11 @@ def 
model_server(client_to_server, server_to_client, args): model.warmup() log("info", "[SERVER] model warmed up") - gen = moshi_mlx.models.LmGen( + gen = models.LmGen( model=model, max_steps=steps + 5, - text_sampler=moshi_mlx.utils.Sampler(), - audio_sampler=moshi_mlx.utils.Sampler(), + text_sampler=utils.Sampler(), + audio_sampler=utils.Sampler(), check=False, ) @@ -153,7 +153,7 @@ def model_server(client_to_server, server_to_client, args): text_token = text_token[0].item() audio_tokens = gen.last_audio_tokens() if text_token not in (0, 3): - _text = text_tokenizer.id_to_piece(text_token) + _text = text_tokenizer.id_to_piece(text_token) # type: ignore _text = _text.replace("▁", " ") server_to_client.put_nowait((1, _text)) if audio_tokens is not None: @@ -172,7 +172,7 @@ def web_server(client_to_server, server_to_client, args): input_queue = queue.Queue() output_queue = queue.Queue() text_queue = queue.Queue() - audio_tokenizer = rustymimi.StreamTokenizer(mimi_file) + audio_tokenizer = rustymimi.StreamTokenizer(mimi_file) # type: ignore start = server_to_client.get() log("info", f"[CLIENT] received '{start}' from server, starting...") @@ -311,7 +311,7 @@ async def go(): app.router.add_get("/api/chat", handle_chat) static_path: None | str = None if args.static is None: - log("info", f"retrieving the static content") + log("info", "retrieving the static content") dist_tgz = hf_hub_download(args.hf_repo, "dist.tgz") dist_tgz = Path(dist_tgz) dist = dist_tgz.parent / "dist" @@ -330,7 +330,7 @@ async def handle_root(_): log("info", f"serving static content from {static_path}") app.router.add_get("/", handle_root) app.router.add_static("/", path=static_path, name="static") - log("info", f"listening to ws://{args.host}:{args.port}") + log("info", f"listening to http://{args.host}:{args.port}") runner = web.AppRunner(app) await runner.setup() site = web.TCPSite(runner, args.host, args.port) @@ -350,9 +350,9 @@ def main(): parser.add_argument("--tokenizer", type=str) 
parser.add_argument("--model", type=str) parser.add_argument("--mimi", type=str) - parser.add_argument("--quantized", type=int) + parser.add_argument("-q", "--quantized", type=int) parser.add_argument("--steps", default=2500, type=int) - parser.add_argument("--hf-repo", type=str, default="") + parser.add_argument("--hf-repo", type=str, default="kmhf/msh-v0.1") parser.add_argument("--static", type=str) parser.add_argument("--host", default="localhost", type=str) parser.add_argument("--port", default=8998, type=int) @@ -370,7 +370,6 @@ def main(): # Start the processes p1.start() p2.start() - events = [] try: while p1.is_alive() and p2.is_alive(): diff --git a/moshi_mlx/moshi_mlx/models/__init__.py b/moshi_mlx/moshi_mlx/models/__init__.py index e39de69..14d021d 100644 --- a/moshi_mlx/moshi_mlx/models/__init__.py +++ b/moshi_mlx/moshi_mlx/models/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) Kyutai, all rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +# flake8: noqa """ Models for EnCodec, AudioGen, MusicGen, as well as the generic LMModel. 
""" diff --git a/moshi_mlx/moshi_mlx/models/generate.py b/moshi_mlx/moshi_mlx/models/generate.py index be03144..1b7bb36 100644 --- a/moshi_mlx/moshi_mlx/models/generate.py +++ b/moshi_mlx/moshi_mlx/models/generate.py @@ -69,13 +69,13 @@ def step(self, other_audio_tokens: mx.array) -> mx.array: audio_token = self.gen_sequence[:, cb_idx + 1, gen_idx][None] else: audio_token = mx.array([[self.audio_padding_token]]) - if (audio_token == self.ungenerated_token).any(): + if (audio_token == self.ungenerated_token).any(): # type: ignore raise ValueError( f"ungenerated value in audio tokens cb: {cb_idx} step: {self.step_idx}" ) assert audio_token.shape == (1, 1), "invalid audio-tokens shape" audio_tokens.append(audio_token) - if (text_tokens == self.ungenerated_token).any(): + if (text_tokens == self.ungenerated_token).any(): # type: ignore raise ValueError(f"ungenerated value in text tokens {self.step_idx}") assert text_tokens.shape == (1, 1), "invalid text-tokens shape" text_tokens, audio_tokens = self.model.sample( @@ -101,8 +101,8 @@ def last_audio_tokens(self) -> Optional[mx.array]: if gen_idx < 0: return None tokens = self.gen_sequence[:, 1 : 1 + self.main_codebooks, gen_idx] - if (tokens == self.audio_padding_token).any(): + if (tokens == self.audio_padding_token).any(): # type: ignore return None - if (tokens == self.ungenerated_token).any(): + if (tokens == self.ungenerated_token).any(): # type: ignore raise ValueError(f"ungenerated value in last-audio tokens {self.step_idx}") return tokens diff --git a/moshi_mlx/moshi_mlx/models/lm.py b/moshi_mlx/moshi_mlx/models/lm.py index a392397..027358a 100644 --- a/moshi_mlx/moshi_mlx/models/lm.py +++ b/moshi_mlx/moshi_mlx/models/lm.py @@ -3,7 +3,6 @@ # LICENSE file in the root directory of this source tree. 
from dataclasses import dataclass -from typing import List, Optional, Tuple import mlx.core as mx import mlx.nn as nn @@ -27,7 +26,7 @@ class LmConfig: text_out_vocab_size: int audio_vocab_size: int audio_codebooks: int - audio_delays: List[int] + audio_delays: list[int] @property def audio_eos_token(self) -> int: @@ -62,7 +61,7 @@ class DepFormer(nn.Module): def __init__(self, cfg: LmConfig): super().__init__() - self.slices: List[DepFormerSlice] = [] + self.slices: list[DepFormerSlice] = [] for slice_idx in range(cfg.depformer.num_slices): in_vs = cfg.text_in_vocab_size if slice_idx == 0 else cfg.audio_vocab_size slice = DepFormerSlice( @@ -144,11 +143,11 @@ def __call__( def sample( self, text_token_ids: mx.array, - audio_token_ids: List[mx.array], + audio_token_ids: list[mx.array], step_idx: int, text_sampler: sampling.Sampler, audio_sampler: sampling.Sampler, - ) -> Tuple[mx.array, mx.array]: + ) -> tuple[mx.array, mx.array]: xs = self.text_emb(text_token_ids) for token_ids, emb in zip(audio_token_ids, self.audio_embs): xs = xs + emb(token_ids) diff --git a/moshi_mlx/moshi_mlx/modules/__init__.py b/moshi_mlx/moshi_mlx/modules/__init__.py index 2d2de8d..10a4218 100644 --- a/moshi_mlx/moshi_mlx/modules/__init__.py +++ b/moshi_mlx/moshi_mlx/modules/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) Kyutai, all rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
+# flake8: noqa """Modules used for building the models.""" from .kv_cache import KVCache, RotatingKVCache diff --git a/moshi_mlx/moshi_mlx/modules/kv_cache.py b/moshi_mlx/moshi_mlx/modules/kv_cache.py index 7cc7005..257675a 100644 --- a/moshi_mlx/moshi_mlx/modules/kv_cache.py +++ b/moshi_mlx/moshi_mlx/modules/kv_cache.py @@ -4,7 +4,7 @@ import inspect from dataclasses import dataclass -from typing import Any, Optional, Tuple +from typing import Any import mlx.core as mx @@ -24,7 +24,7 @@ def __init__(self, head_dim, n_kv_heads): self.offset = 0 self.step = 256 - def update_and_fetch(self, keys, values) -> Tuple[mx.array, mx.array]: + def update_and_fetch(self, keys, values) -> tuple[mx.array, mx.array]: prev = self.offset if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]: B = keys.shape[0] @@ -34,6 +34,7 @@ def update_and_fetch(self, keys, values) -> Tuple[mx.array, mx.array]: new_k = mx.zeros(k_shape, keys.dtype) new_v = mx.zeros(v_shape, values.dtype) if self.keys is not None: + assert self.values is not None if prev % self.step != 0: self.keys = self.keys[..., :prev, :] self.values = self.values[..., :prev, :] @@ -44,6 +45,7 @@ def update_and_fetch(self, keys, values) -> Tuple[mx.array, mx.array]: self.offset += keys.shape[2] self.keys[..., prev : self.offset, :] = keys + assert self.values is not None self.values[..., prev : self.offset, :] = values return self.keys[..., : self.offset, :], self.values[..., : self.offset, :] @@ -83,7 +85,7 @@ def _trim(self, trim_size, v, append=None): to_cat.append(append) return mx.concatenate(to_cat, axis=2) - def update_and_fetch(self, keys, values) -> Tuple[mx.array, mx.array]: + def update_and_fetch(self, keys, values) -> tuple[mx.array, mx.array]: prev = self.offset B, _, S = keys.shape[:3] @@ -114,6 +116,7 @@ def update_and_fetch(self, keys, values) -> Tuple[mx.array, mx.array]: new_k = mx.zeros(k_shape, keys.dtype) new_v = mx.zeros(v_shape, values.dtype) if self.keys is not None: + assert self.values 
is not None self.keys = mx.concatenate([self.keys, new_k], axis=2) self.values = mx.concatenate([self.values, new_v], axis=2) else: @@ -133,6 +136,7 @@ def update_and_fetch(self, keys, values) -> Tuple[mx.array, mx.array]: # Assign self.keys[..., self._idx : self._idx + 1, :] = keys + assert self.values is not None self.values[..., self._idx : self._idx + 1, :] = values self.offset += 1 self._idx += 1 @@ -171,7 +175,7 @@ def create_additive_causal_mask(N: int, offset: int = 0): return mask * -1e9 -def create_attention_mask(h: mx.array, cache: Optional[Any] = None): +def create_attention_mask(h: mx.array, cache: Any | None = None): T = h.shape[1] if T > 1: if cache is not None and cache[0] is not None: diff --git a/moshi_mlx/moshi_mlx/modules/transformer.py b/moshi_mlx/moshi_mlx/modules/transformer.py index d039edc..31c1b06 100644 --- a/moshi_mlx/moshi_mlx/modules/transformer.py +++ b/moshi_mlx/moshi_mlx/modules/transformer.py @@ -3,7 +3,6 @@ # LICENSE file in the root directory of this source tree. 
from dataclasses import dataclass -from typing import List, Optional, Tuple from .kv_cache import KVCache, RotatingKVCache import mlx.core as mx @@ -19,7 +18,7 @@ class TransformerConfig: norm_first: bool bias_ff: bool bias_attn: bool - layer_scale: Optional[float] + layer_scale: float | None positional_embedding: str use_conv_block: bool cross_attention: bool @@ -75,7 +74,7 @@ def __call__( self, xs: mx.array, cache: KVCache | RotatingKVCache, - mask: Optional[mx.array] = None, + mask: mx.array | None = None, ) -> mx.array: assert self.cfg.kv_repeat == 1, "only kv_repeat==1 is supported" @@ -181,7 +180,7 @@ def __init__(self, cfg: TransformerConfig): def __call__( self, xs: mx.array, - cache: List[KVCache] | List[RotatingKVCache], + cache: list[KVCache] | list[RotatingKVCache], ) -> mx.array: for layer, c in zip(self.layers, cache): xs = layer(xs, cache=c) diff --git a/moshi_mlx/moshi_mlx/utils/__init__.py b/moshi_mlx/moshi_mlx/utils/__init__.py index a2f25b9..f67bd8c 100644 --- a/moshi_mlx/moshi_mlx/utils/__init__.py +++ b/moshi_mlx/moshi_mlx/utils/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) Kyutai, all rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
+# flake8: noqa """Utilities.""" from .sampling import Sampler diff --git a/moshi_mlx/moshi_mlx/utils/sampling.py b/moshi_mlx/moshi_mlx/utils/sampling.py index 1b1f4e9..97b50d3 100644 --- a/moshi_mlx/moshi_mlx/utils/sampling.py +++ b/moshi_mlx/moshi_mlx/utils/sampling.py @@ -3,7 +3,6 @@ from dataclasses import dataclass from functools import partial -from typing import List, Optional, Tuple, Dict import mlx.core as mx @@ -39,7 +38,7 @@ def min_p_sampling( raise ValueError( f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}" ) - # reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605 + # reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605 # noqa # Softmax probabilities probs = mx.softmax(logits * (1 / temperature), axis=-1) @@ -78,7 +77,7 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr Returns: token selected based on the top-p criterion. 
""" - # referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460 + # referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460 # noqa probs = mx.softmax(logits * (1 / temperature), axis=-1) # sort probs in ascending order @@ -111,9 +110,9 @@ class Sampler: top_p: float = 0.95 min_p: float = 0.0 min_tokens_to_keep: int = 1 - logit_bias: Optional[Dict[int, float]] = None + logit_bias: dict[int, float] | None = None - def __call__(self, logits: mx.array) -> Tuple[mx.array, float]: + def __call__(self, logits: mx.array) -> tuple[mx.array, mx.array]: if self.logit_bias: indices = mx.array(list(self.logit_bias.keys())) values = mx.array(list(self.logit_bias.values())) diff --git a/moshi_mlx/pyproject.toml b/moshi_mlx/pyproject.toml index 3e2d87d..c6bce62 100644 --- a/moshi_mlx/pyproject.toml +++ b/moshi_mlx/pyproject.toml @@ -1,6 +1,5 @@ [project] name = "moshi_mlx" -version = "0.0.1" requires-python = ">= 3.10" description = "Moshi is moshi, but running on macOS" dependencies = [ @@ -17,11 +16,19 @@ dependencies = [ authors = [{name="Laurent Mazaré", email="laurent@kyutai.org"}] maintainers = [{name="Laurent Mazaré", email="laurent@kyutai.org"}] license = {text = "MIT"} +dynamic = ["version"] [build-system] requires = ["setuptools"] build-backend = "setuptools.build_meta" -[tool.setuptools] -packages = ["moshi_mlx", "moshi_mlx.modules", "moshi_mlx.models", "moshi_mlx.utils"] +[tool.setuptools.dynamic] +version = {attr = "moshi_mlx.__version__"} + +[project.optional-dependencies] +dev = [ + "pyright", + "flake8", + "pre-commit", +] diff --git a/moshi_mlx/setup.cfg b/moshi_mlx/setup.cfg index dc7aa4b..5bccac4 100644 --- a/moshi_mlx/setup.cfg +++ b/moshi_mlx/setup.cfg @@ -3,3 +3,4 @@ max-line-length = 120 [flake8] max-line-length = 120 +ignore = E203,E704 diff --git a/requirements-dev.txt 
b/requirements-dev.txt new file mode 100644 index 0000000..6f5acfb --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,3 @@ +pre-commit>=3.8 +pyright>=1.1 +flake8>=7.1