Skip to content

Commit

Permalink
coreml : use the correct n_mel value (#1458)
Browse files: browse the repository at this point in the history
  • Loading branch information
jxy authored Nov 8, 2023
1 parent baeb733 commit 0de8582
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 6 deletions.
2 changes: 1 addition & 1 deletion coreml/whisper-encoder-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v

/**
Make a prediction using the convenience interface
@param logmel_data as 1 × 80 × 3000 3-dimensional array of floats:
@param logmel_data as 1 × n_mel × 3000 3-dimensional array of floats:
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
@return the prediction as whisper_encoder_implOutput
*/
Expand Down
4 changes: 4 additions & 0 deletions coreml/whisper-encoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
// Code is derived from the work of Github user @wangchou
// ref: https://github.com/wangchou/callCoreMLFromCpp

#include <stdint.h>

#if __cplusplus
extern "C" {
#endif
Expand All @@ -14,6 +16,8 @@ void whisper_coreml_free(struct whisper_coreml_context * ctx);

// Encode the log-mel input `mel` using the Core ML context `ctx`.
//
// ctx   - Core ML context to encode with.
// n_ctx - size of the innermost (contiguous) dimension of `mel`.
// n_mel - number of rows of `mel`.
// mel   - input floats; the implementation wraps this pointer as a float32
//         array of shape 1 x n_mel x n_ctx with strides
//         (n_ctx*n_mel, n_ctx, 1), so it should hold n_mel * n_ctx
//         contiguous values.
// out   - output float buffer.
//         NOTE(review): the required size of `out` is not visible from this
//         declaration - confirm against the implementation.
void whisper_coreml_encode(
const whisper_coreml_context * ctx,
int64_t n_ctx,
int64_t n_mel,
float * mel,
float * out);

Expand Down
6 changes: 4 additions & 2 deletions coreml/whisper-encoder.mm
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,15 @@ void whisper_coreml_free(struct whisper_coreml_context * ctx) {

void whisper_coreml_encode(
const whisper_coreml_context * ctx,
int64_t n_ctx,
int64_t n_mel,
float * mel,
float * out) {
MLMultiArray * inMultiArray = [
[MLMultiArray alloc] initWithDataPointer: mel
shape: @[@1, @80, @3000]
shape: @[@1, @(n_mel), @(n_ctx)]
dataType: MLMultiArrayDataTypeFloat32
strides: @[@(240000), @(3000), @1]
strides: @[@(n_ctx*n_mel), @(n_ctx), @1]
deallocator: nil
error: nil
];
Expand Down
4 changes: 2 additions & 2 deletions models/convert-whisper-to-coreml.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def install_hooks(layer: nn.Module):
def convert_encoder(hparams, model, quantize=False):
model.eval()

input_shape = (1, 80, 3000)
input_shape = (1, hparams.n_mels, 3000)
input_data = torch.randn(input_shape)
traced_model = torch.jit.trace(model, input_data)

Expand Down Expand Up @@ -302,7 +302,7 @@ def convert_decoder(hparams, model, quantize=False):
parser.add_argument("--optimize-ane", type=bool, help="optimize for ANE execution (currently broken)", default=False)
args = parser.parse_args()

if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large", "large-v1", "large-v2"]:
if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "small.en-tdrz", "medium", "medium.en", "large", "large-v1", "large-v2"]:
raise ValueError("Invalid model name")

whisper = load_model(args.model).cpu()
Expand Down
2 changes: 1 addition & 1 deletion whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1603,7 +1603,7 @@ static struct ggml_cgraph * whisper_build_graph_conv(
ggml_allocr_alloc(alloc, cur);

if (!ggml_allocr_is_measure(alloc)) {
whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) cur->data);
}
#endif
#ifdef WHISPER_USE_OPENVINO
Expand Down

0 comments on commit 0de8582

Please sign in to comment.