Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementing Encoder Begin Callback for golang binding #2291

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bindings/go/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func main() {
if err != nil {
panic(err)
}
if err := context.Process(samples, nil, nil); err != nil {
if err := context.Process(samples, nil, nil, nil); err != nil {
return err
}

Expand Down
2 changes: 1 addition & 1 deletion bindings/go/examples/go-whisper/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func Process(model whisper.Model, path string, flags *Flags) error {
// Process the data
fmt.Fprintf(flags.Output(), " ...processing %q\n", path)
context.ResetTimings()
if err := context.Process(data, cb, nil); err != nil {
if err := context.Process(data, nil, cb, nil); err != nil {
return err
}

Expand Down
39 changes: 23 additions & 16 deletions bindings/go/pkg/whisper/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ func (context *context) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]f
// Process new sample data and return any errors
func (context *context) Process(
data []float32,
callEncoderBegin EncoderBeginCallback,
callNewSegment SegmentCallback,
callProgress ProgressCallback,
) error {
Expand All @@ -177,30 +178,32 @@ func (context *context) Process(
// We don't do parallel processing at the moment
processors := 0
if processors > 1 {
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) {
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, callEncoderBegin,
func(new int) {
if callNewSegment != nil {
num_segments := context.model.ctx.Whisper_full_n_segments()
s0 := num_segments - new
for i := s0; i < num_segments; i++ {
callNewSegment(toSegment(context.model.ctx, i))
}
}
}); err != nil {
return err
}
} else if err := context.model.ctx.Whisper_full(context.params, data, callEncoderBegin,
func(new int) {
if callNewSegment != nil {
num_segments := context.model.ctx.Whisper_full_n_segments()
s0 := num_segments - new
for i := s0; i < num_segments; i++ {
callNewSegment(toSegment(context.model.ctx, i))
}
}
}); err != nil {
return err
}
} else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) {
if callNewSegment != nil {
num_segments := context.model.ctx.Whisper_full_n_segments()
s0 := num_segments - new
for i := s0; i < num_segments; i++ {
callNewSegment(toSegment(context.model.ctx, i))
}, func(progress int) {
if callProgress != nil {
callProgress(progress)
}
}
}, func(progress int) {
if callProgress != nil {
callProgress(progress)
}
}); err != nil {
}); err != nil {
return err
}

Expand Down Expand Up @@ -286,6 +289,10 @@ func (context *context) IsLANG(t Token, lang string) bool {
}
}

func (context *context) GetDetectedLanguage() string {
return whisper.Whisper_lang_str(context.model.ctx.Whisper_full_lang_id())
}

///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS

Expand Down
18 changes: 12 additions & 6 deletions bindings/go/pkg/whisper/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ type SegmentCallback func(Segment)
// processing. It is called during the Process function
type ProgressCallback func(int)

// EncoderBeginCallback is the callback function for checking if we want to
// continue processing. It is called during the Process function
type EncoderBeginCallback func() bool

// Model is the interface to a whisper model. Create a new model with the
// function whisper.New(string)
type Model interface {
Expand All @@ -31,12 +35,14 @@ type Model interface {
Languages() []string
}

// Context is the speach recognition context.
// Context is the speech recognition context.
type Context interface {
SetLanguage(string) error // Set the language to use for speech recognition, use "auto" for auto detect language.
SetTranslate(bool) // Set translate flag
IsMultilingual() bool // Return true if the model is multilingual.
Language() string // Get language
SetLanguage(string) error // Set the language to use for speech recognition, use "auto" for auto detect language.
SetTranslate(bool) // Set translate flag
IsMultilingual() bool // Return true if the model is multilingual.
Language() string // Get language
GetDetectedLanguage() string // Get auto detected language


SetOffset(time.Duration) // Set offset
SetDuration(time.Duration) // Set duration
Expand All @@ -53,7 +59,7 @@ type Context interface {
// Process mono audio data and return any errors.
// If defined, newly generated segments are passed to the
// callback function during processing.
Process([]float32, SegmentCallback, ProgressCallback) error
Process([]float32, EncoderBeginCallback, SegmentCallback, ProgressCallback) error

// After process is called, return segments until the end of the stream
// is reached, when io.EOF is returned.
Expand Down
2 changes: 1 addition & 1 deletion bindings/go/pkg/whisper/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func (model *model) NewContext() (Context, error) {
}

// Create new context
params := model.ctx.Whisper_full_default_params(whisper.SAMPLING_GREEDY)
params := model.ctx.Whisper_full_default_params(whisper.SAMPLING_BEAM_SEARCH)
params.SetTranslate(false)
params.SetPrintSpecial(false)
params.SetPrintProgress(false)
Expand Down
8 changes: 7 additions & 1 deletion bindings/go/whisper.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,13 @@ func Whisper_print_system_info() string {
// Return default parameters for a strategy
func (ctx *Context) Whisper_full_default_params(strategy SamplingStrategy) Params {
// Get default parameters
return Params(C.whisper_full_default_params_cb((*C.struct_whisper_context)(ctx), C.enum_whisper_sampling_strategy(strategy)))
p := Params(C.whisper_full_default_params_cb((*C.struct_whisper_context)(ctx), C.enum_whisper_sampling_strategy(strategy)))

p.greedy.best_of = 5
p.thold_pt = 0
p.thold_ptsum = 0

return p
}

// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
Expand Down
Loading