
CI [2/N] #29

Merged 30 commits on Sep 18, 2024
29 changes: 29 additions & 0 deletions .github/ISSUE_TEMPLATE/bug.md
@@ -0,0 +1,29 @@
---
name: 🐛 Bug Report
about: Submit a bug report to help us improve
labels: bug, triage
---

## 🐛 Bug Report

<!-- A clear and concise description of what the bug is -->

## To Reproduce

<!-- How to reproduce the bug -->

## Your Environment

<!-- Include as many relevant details as possible about the environment in which you experienced the bug. -->

**Please fill in this section; failure to do so will lead to your issue being closed directly.**

- Backend impacted (PyTorch, MLX, Rust, or Other):
- Operating system (OSX, Windows, Linux):
- Operating system version:

If the backend impacted is PyTorch:
- Python version:
- PyTorch version:
- CUDA version (run `python -c 'import torch; print(torch.version.cuda)'`):
- GPU model and memory:
9 changes: 9 additions & 0 deletions .github/ISSUE_TEMPLATE/question.md
@@ -0,0 +1,9 @@
---
name: "❓Questions/Help/Support"
about: If you have a question about the paper, code or algorithm, please ask here!
labels: question, triage
---

## ❓ Questions

<!-- (Please ask your question here.) -->
1 change: 1 addition & 0 deletions .github/PULL_REQUEST_TEMPLATE.md
@@ -0,0 +1 @@
# Test
Collaborator Author:

Will test once this is merged; for now I cannot see how this renders.

27 changes: 27 additions & 0 deletions .github/actions/moshi_build/action.yml
@@ -0,0 +1,27 @@
name: moshi_build
description: 'Build env.'
runs:
using: "composite"
steps:
- uses: actions/setup-python@v2
with:
python-version: '3.10'
- uses: actions/cache@v3
id: cache
with:
path: env
key: env-${{ hashFiles('moshi/pyproject.toml') }}
- name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true'
shell: bash
run: |
python3 -m venv env
. env/bin/activate
python -m pip install --upgrade pip
pip install torch==2.4.0 --index-url https://download.pytorch.org/whl/cpu
pip install -e './moshi[dev]'
- name: Setup env
shell: bash
run: |
. env/bin/activate
pre-commit install
33 changes: 33 additions & 0 deletions .github/actions/rust_build/action.yml
@@ -0,0 +1,33 @@
name: rust_build
description: 'Setup rust env'
inputs:
os:
default: ubuntu-latest
toolchain:
default: stable
target:
default: check
runs:
using: "composite"
steps:
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: ${{ inputs.toolchain }}
override: true
- name: cargo cache
uses: actions/cache@v3
with:
path: |
~/.cargo/bin/
~/.cargo/registry/index/
~/.cargo/registry/cache/
~/.cargo/git/db/
rust/target/
key: ${{ inputs.os }}-cargo-${{ inputs.target }}-${{ hashFiles('**/Cargo.lock') }}
restore-keys: ${{ inputs.os }}-cargo-
- name: install deps
shell: bash
run: |
sudo apt-get update
sudo apt-get install libasound2-dev
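
A workflow would consume this composite action roughly as follows (the job name and checkout step are illustrative; `os`, `toolchain`, and `target` fall back to the defaults declared above when omitted):

```yaml
jobs:
  check:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      # Reference the action by its in-repo path; `with:` overrides the defaults.
      - uses: ./.github/actions/rust_build
        with:
          toolchain: stable
          target: test
```

Keeping `target` in the cache key means the `check` and `test` jobs maintain separate cargo caches, at the cost of some duplication.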
17 changes: 17 additions & 0 deletions .github/workflows/precommit.yml
@@ -0,0 +1,17 @@
name: precommit
on:
push:
branches: [ main ]
pull_request:
branches: [ main, refacto ]

jobs:
run_precommit:
name: Run precommit
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: ./.github/actions/moshi_build
- run: |
. env/bin/activate
bash .git/hooks/pre-commit
53 changes: 53 additions & 0 deletions .github/workflows/rust-ci.yml
@@ -0,0 +1,53 @@
on:
push:
branches: [ main ]
pull_request:
branches: [ main, refacto ]

name: Rust CI

jobs:
check:
name: Check
defaults:
run:
working-directory: ./rust
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest]
rust: [stable]
steps:
- uses: actions/checkout@v2
- uses: ./.github/actions/rust_build
- name: check
shell: bash
run: |
cargo check
- name: clippy
shell: bash
run: |
cargo clippy -- -D warnings
- name: fmt
shell: bash
run: |
cargo fmt --all -- --check
test:
name: Test
defaults:
run:
working-directory: ./rust
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest]
rust: [stable]
steps:
- uses: actions/checkout@v2
- uses: ./.github/actions/rust_build
with:
target: test
- name: test
shell: bash
run: |
cargo test
24 changes: 24 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,24 @@
repos:
- repo: local
hooks:
- id: flake8-moshi
name: flake8 on moshi package
language: system
entry: bash -c 'cd moshi && flake8'
pass_filenames: false
- id: pyright-moshi
name: pyright on moshi package
language: system
entry: bash -c 'cd moshi && pyright'
pass_filenames: false
- id: flake8-moshi_mlx
name: flake8 on moshi_mlx package
language: system
entry: bash -c 'cd moshi_mlx && flake8'
pass_filenames: false
- id: pyright-moshi_mlx
name: pyright on moshi_mlx package
language: system
entry: bash -c 'cd moshi_mlx && pyright'
pass_filenames: false

30 changes: 30 additions & 0 deletions CONTRIBUTING.md
@@ -0,0 +1,30 @@
# Contributing to Moshi

## Pull Requests

Moshi is the implementation of a research paper.
Therefore, we do not plan on accepting many pull requests for new features.
We certainly welcome them for bug fixes.

1. Fork the repo and create your branch from `main`.
2. If you've changed APIs, update the documentation.
3. Ensure pre-commit hooks pass properly, in particular the linting and typing.
4. Accept the Contributor License Agreement (see below).

Note that in general we will not accept refactoring of the code.


## Contributor License Agreement ("CLA")

In order to accept your pull request, we need you to submit a Contributor License Agreement.
As this CLA is not ready yet, we will delay acceptance of your PR.

## Issues

Please submit issues on our GitHub repository.

## License

By contributing to Moshi, you agree that your contributions will be licensed
under the LICENSE-* files in the root directory of this source tree.
In particular, the rust code is licensed under APACHE, and the python code under MIT.
2 changes: 2 additions & 0 deletions moshi/pyproject.toml
@@ -26,6 +26,8 @@ build-backend = "setuptools.build_meta"
[tool.setuptools]
packages = ["moshi", "moshi.utils", "moshi.modules", "moshi.models", "moshi.quantization"]

[tool.setuptools.dynamic]
version = {attr = "moshi.__version__"}

[project.optional-dependencies]
dev = [
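The `version = {attr = "moshi.__version__"}` entry tells setuptools to read the package version from the `moshi` module at build time instead of hard-coding it. A minimal sketch of that attr-style lookup — not the actual setuptools implementation, and using the stdlib `json` module's `__version__` attribute only as a stand-in, since `moshi` may not be importable here:

```python
import importlib

def resolve_attr(spec: str):
    # "moshi.__version__" -> import module "moshi", then read "__version__".
    module_name, _, attr = spec.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, attr)

# Stand-in: the stdlib json module also exposes a __version__ attribute.
version = resolve_attr("json.__version__")
```

With this in place, bumping `moshi/moshi/__init__.py` is enough to release a new version.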
2 changes: 1 addition & 1 deletion moshi_mlx/moshi_mlx/__init__.py
@@ -1,12 +1,12 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# flake8: noqa

"""
moshi_mlx is the MLX inference codebase for Kyutai audio generation models.
"""

# flake8: noqa
from . import modules, models, utils

__version__ = "0.0.1"
6 changes: 3 additions & 3 deletions moshi_mlx/moshi_mlx/local.py
@@ -99,7 +99,7 @@ def log(s):
printer_q.put_nowait((PrinterType.INFO, s))

log(f"[SERVER] loading text tokenizer {tokenizer_file}")
text_tokenizer = sentencepiece.SentencePieceProcessor(tokenizer_file)
text_tokenizer = sentencepiece.SentencePieceProcessor(tokenizer_file) # type: ignore
mx.random.seed(299792458)
lm_config = moshi_mlx.models.config_v0_1()
model = moshi_mlx.models.Lm(lm_config)
@@ -137,7 +137,7 @@ def log(s):
text_token = text_token[0].item()
audio_tokens = gen.last_audio_tokens()
if text_token not in (0, 3):
_text = text_tokenizer.id_to_piece(text_token)
_text = text_tokenizer.id_to_piece(text_token) # type: ignore
_text = _text.replace("▁", " ")
printer_q.put_nowait((PrinterType.TOKEN, _text))
else:
@@ -158,7 +158,7 @@ def client(printer_q, client_to_server, server_to_client, args):
)
input_queue = queue.Queue()
output_queue = queue.Queue()
audio_tokenizer = rustymimi.StreamTokenizer(mimi_file)
audio_tokenizer = rustymimi.StreamTokenizer(mimi_file) # type: ignore
start = server_to_client.get()
printer_q.put_nowait(
(PrinterType.INFO, f"[CLIENT] received '{start}' from server, starting...")
9 changes: 4 additions & 5 deletions moshi_mlx/moshi_mlx/local_web.py
@@ -120,7 +120,7 @@ def model_server(client_to_server, server_to_client, args):
steps = args.steps

log("info", f"[SERVER] loading text tokenizer {tokenizer_file}")
text_tokenizer = sentencepiece.SentencePieceProcessor(tokenizer_file)
text_tokenizer = sentencepiece.SentencePieceProcessor(tokenizer_file) # type: ignore
mx.random.seed(299792458)
lm_config = moshi_mlx.models.config_v0_1()
model = moshi_mlx.models.Lm(lm_config)
@@ -153,7 +153,7 @@ def model_server(client_to_server, server_to_client, args):
text_token = text_token[0].item()
audio_tokens = gen.last_audio_tokens()
if text_token not in (0, 3):
_text = text_tokenizer.id_to_piece(text_token)
_text = text_tokenizer.id_to_piece(text_token) # type: ignore
_text = _text.replace("▁", " ")
server_to_client.put_nowait((1, _text))
if audio_tokens is not None:
@@ -172,7 +172,7 @@ def web_server(client_to_server, server_to_client, args):
input_queue = queue.Queue()
output_queue = queue.Queue()
text_queue = queue.Queue()
audio_tokenizer = rustymimi.StreamTokenizer(mimi_file)
audio_tokenizer = rustymimi.StreamTokenizer(mimi_file) # type: ignore
start = server_to_client.get()
log("info", f"[CLIENT] received '{start}' from server, starting...")

@@ -311,7 +311,7 @@ async def go():
app.router.add_get("/api/chat", handle_chat)
static_path: None | str = None
if args.static is None:
log("info", f"retrieving the static content")
log("info", "retrieving the static content")
dist_tgz = hf_hub_download(args.hf_repo, "dist.tgz")
dist_tgz = Path(dist_tgz)
dist = dist_tgz.parent / "dist"
@@ -370,7 +370,6 @@ def main():
# Start the processes
p1.start()
p2.start()
events = []

try:
while p1.is_alive() and p2.is_alive():
1 change: 1 addition & 0 deletions moshi_mlx/moshi_mlx/models/__init__.py
@@ -1,6 +1,7 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# flake8: noqa
"""
Models for EnCodec, AudioGen, MusicGen, as well as the generic LMModel.
"""
8 changes: 4 additions & 4 deletions moshi_mlx/moshi_mlx/models/generate.py
@@ -69,13 +69,13 @@ def step(self, other_audio_tokens: mx.array) -> mx.array:
audio_token = self.gen_sequence[:, cb_idx + 1, gen_idx][None]
else:
audio_token = mx.array([[self.audio_padding_token]])
if (audio_token == self.ungenerated_token).any():
if (audio_token == self.ungenerated_token).any(): # type: ignore
raise ValueError(
f"ungenerated value in audio tokens cb: {cb_idx} step: {self.step_idx}"
)
assert audio_token.shape == (1, 1), "invalid audio-tokens shape"
audio_tokens.append(audio_token)
if (text_tokens == self.ungenerated_token).any():
if (text_tokens == self.ungenerated_token).any(): # type: ignore
raise ValueError(f"ungenerated value in text tokens {self.step_idx}")
assert text_tokens.shape == (1, 1), "invalid text-tokens shape"
text_tokens, audio_tokens = self.model.sample(
@@ -101,8 +101,8 @@ def last_audio_tokens(self) -> Optional[mx.array]:
if gen_idx < 0:
return None
tokens = self.gen_sequence[:, 1 : 1 + self.main_codebooks, gen_idx]
if (tokens == self.audio_padding_token).any():
if (tokens == self.audio_padding_token).any(): # type: ignore
return None
if (tokens == self.ungenerated_token).any():
if (tokens == self.ungenerated_token).any(): # type: ignore
raise ValueError(f"ungenerated value in last-audio tokens {self.step_idx}")
return tokens
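
The `.any()` guards above ensure placeholder values never escape the generator: padding means "nothing to emit yet", while an ungenerated sentinel indicates a scheduling bug. The same pattern, sketched with plain Python lists in place of `mx.array` (both sentinel values below are hypothetical, not the actual token ids):

```python
UNGENERATED = -1       # hypothetical sentinel for not-yet-sampled slots
AUDIO_PADDING = 2048   # hypothetical padding token id

def last_audio_tokens(tokens):
    # Padding anywhere in the frame: no complete audio frame to emit yet.
    if any(t == AUDIO_PADDING for t in tokens):
        return None
    # An ungenerated sentinel should be impossible here; fail loudly.
    if any(t == UNGENERATED for t in tokens):
        raise ValueError("ungenerated value in last-audio tokens")
    return tokens

print(last_audio_tokens([5, 7, 9]))  # → [5, 7, 9]
```

Returning `None` rather than a partial frame keeps the caller's contract simple: it either gets a full frame of real tokens or nothing.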