diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3fa5bef..780ec4f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,4 +6,10 @@ repos: language: system entry: bash -c 'cd moshi && flake8' pass_filenames: false + hooks: + - id: pyright-moshi + name: pyrgith on moshi package + language: system + entry: bash -c 'cd moshi && pyright' + pass_filenames: false diff --git a/moshi/moshi/modules/transformer.py b/moshi/moshi/modules/transformer.py index 0a9575e..212d721 100644 --- a/moshi/moshi/modules/transformer.py +++ b/moshi/moshi/modules/transformer.py @@ -482,8 +482,8 @@ def __init__( context=context, rope=rope, weights_per_step=weights_per_step, - **attn_kwargs, - **factory_kwargs, + **attn_kwargs, # type: ignore + **factory_kwargs, # type: ignore ) # type: ignore self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs) self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs) @@ -539,8 +539,8 @@ def __init__( self.layer_scale_1 = nn.Identity() self.layer_scale_2 = nn.Identity() else: - self.layer_scale_1 = LayerScale(d_model, layer_scale, **factory_kwargs) - self.layer_scale_2 = LayerScale(d_model, layer_scale, **factory_kwargs) + self.layer_scale_1 = LayerScale(d_model, layer_scale, **factory_kwargs) # type: ignore + self.layer_scale_2 = LayerScale(d_model, layer_scale, **factory_kwargs) # type: ignore def _init_streaming_state(self, batch_size: int) -> _LayerState: return _LayerState(offset_cpu=0) diff --git a/moshi/moshi/quantization/base.py b/moshi/moshi/quantization/base.py index e8f0ad4..02228a9 100644 --- a/moshi/moshi/quantization/base.py +++ b/moshi/moshi/quantization/base.py @@ -68,7 +68,7 @@ def num_codebooks(self) -> int: raise NotImplementedError() @property - def semantic_quantizer(self): + def semantic_quantizer(self) -> 'BaseQuantizer': """This returns the quantizer that models the first level of the hierarchy (typically semantic). In this case, it's the quantizer itself. @@ -76,7 +76,7 @@ def semantic_quantizer(self): return self @property - def acoustic_quantizer(self): + def acoustic_quantizer(self) -> 'BaseQuantizer': """This returns the quantizer that models the higher levels of the hierarchy (typically acoustic). In this case, it's the quantizer itself. diff --git a/moshi/moshi/quantization/vq.py b/moshi/moshi/quantization/vq.py index 0e436c1..4fa5b0a 100644 --- a/moshi/moshi/quantization/vq.py +++ b/moshi/moshi/quantization/vq.py @@ -321,12 +321,12 @@ def dimension(self): return self.rvq_first.dimension @property - def semantic_quantizer(self): + def semantic_quantizer(self) -> ResidualVectorQuantizer: """This returns the quantizer that models the first level of the hierarchy (typically semantic).""" return self.rvq_first @property - def acoustic_quantizer(self): + def acoustic_quantizer(self) -> ResidualVectorQuantizer: """This returns the quantizer that models the higher levels of the hierarchy (typically acoustic).""" return self.rvq_rest