diff --git a/.github/workflows/bindings.yml b/.github/workflows/bindings.yml new file mode 100644 index 00000000000..1bccf59e76b --- /dev/null +++ b/.github/workflows/bindings.yml @@ -0,0 +1,17 @@ +name: Bindings Tests +on: + push: + paths: + - bindings/go/** + +jobs: + ubuntu-latest: + runs-on: ubuntu-latest + steps: + - uses: actions/setup-go@v3 + with: + go-version: '^1.19' + - uses: actions/checkout@v1 + - run: | + cd bindings/go + make test diff --git a/bindings/go/examples/go-whisper/color.go b/bindings/go/examples/go-whisper/color.go new file mode 100644 index 00000000000..fa5ac2f26ba --- /dev/null +++ b/bindings/go/examples/go-whisper/color.go @@ -0,0 +1,28 @@ +package main + +import "fmt" + +/////////////////////////////////////////////////////////////////////////////// +// CONSTANTS + +const ( + Reset = "\033[0m" + RGBPrefix = "\033[38;5;" // followed by a single 8-bit palette index in decimal + RGBSuffix = "m" +) + +/////////////////////////////////////////////////////////////////////////////// +// PUBLIC METHODS + +// Colorize wraps text in a grayscale ANSI color; v is a level from 0 to 23 and is clamped to that range +func Colorize(text string, v int) string { + // https://en.wikipedia.org/wiki/ANSI_escape_code#8-bit + // Grayscale colors are in the range 232-255 + // Clamp rather than wrap: v%24 would send 24 (a probability of 1.0 scaled by 24) back to index 232 + if v < 0 { + v = 0 + } else if v > 23 { + v = 23 + } + return RGBPrefix + fmt.Sprint(v+232) + RGBSuffix + text + Reset +} diff --git a/bindings/go/examples/go-whisper/flags.go b/bindings/go/examples/go-whisper/flags.go index a5353d1c83a..ea204455c80 100644 --- a/bindings/go/examples/go-whisper/flags.go +++ b/bindings/go/examples/go-whisper/flags.go @@ -2,6 +2,12 @@ package main import ( "flag" + "fmt" + "strings" + "time" + + // Packages + whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" ) /////////////////////////////////////////////////////////////////////////////// @@ -42,6 +48,26 @@ func (flags *Flags) GetLanguage() string { return flags.Lookup("language").Value.String() } +func (flags *Flags) IsTranslate() bool { + return
flags.Lookup("translate").Value.(flag.Getter).Get().(bool) +} + +func (flags *Flags) GetOffset() time.Duration { + return flags.Lookup("offset").Value.(flag.Getter).Get().(time.Duration) +} + +func (flags *Flags) GetDuration() time.Duration { + return flags.Lookup("duration").Value.(flag.Getter).Get().(time.Duration) +} + +func (flags *Flags) GetThreads() uint { + return flags.Lookup("threads").Value.(flag.Getter).Get().(uint) +} + +func (flags *Flags) GetOut() string { + return strings.ToLower(flags.Lookup("out").Value.String()) +} + func (flags *Flags) IsSpeedup() bool { return flags.Lookup("speedup").Value.String() == "true" } @@ -50,12 +76,81 @@ func (flags *Flags) IsTokens() bool { return flags.Lookup("tokens").Value.String() == "true" } +func (flags *Flags) IsColorize() bool { + return flags.Lookup("colorize").Value.String() == "true" +} + +func (flags *Flags) GetMaxLen() uint { + return flags.Lookup("max-len").Value.(flag.Getter).Get().(uint) +} + +func (flags *Flags) GetMaxTokens() uint { + return flags.Lookup("max-tokens").Value.(flag.Getter).Get().(uint) +} + +func (flags *Flags) GetWordThreshold() float32 { + return float32(flags.Lookup("word-thold").Value.(flag.Getter).Get().(float64)) +} + +func (flags *Flags) SetParams(context whisper.Context) error { + if lang := flags.GetLanguage(); lang != "" && lang != "auto" { + fmt.Fprintf(flags.Output(), "Setting language to %q\n", lang) + if err := context.SetLanguage(lang); err != nil { + return err + } + } + if flags.IsTranslate() && context.IsMultilingual() { + fmt.Fprintf(flags.Output(), "Setting translate to true\n") + context.SetTranslate(true) + } + if offset := flags.GetOffset(); offset != 0 { + fmt.Fprintf(flags.Output(), "Setting offset to %v\n", offset) + context.SetOffset(offset) + } + if duration := flags.GetDuration(); duration != 0 { + fmt.Fprintf(flags.Output(), "Setting duration to %v\n", duration) + context.SetDuration(duration) + } + if flags.IsSpeedup() { + fmt.Fprintf(flags.Output(), 
"Setting speedup to true\n") + context.SetSpeedup(true) + } + if threads := flags.GetThreads(); threads != 0 { + fmt.Fprintf(flags.Output(), "Setting threads to %d\n", threads) + context.SetThreads(threads) + } + if max_len := flags.GetMaxLen(); max_len != 0 { + fmt.Fprintf(flags.Output(), "Setting max_segment_length to %d\n", max_len) + context.SetMaxSegmentLength(max_len) + } + if max_tokens := flags.GetMaxTokens(); max_tokens != 0 { + fmt.Fprintf(flags.Output(), "Setting max_tokens to %d\n", max_tokens) + context.SetMaxTokensPerSegment(max_tokens) + } + if word_threshold := flags.GetWordThreshold(); word_threshold != 0 { + fmt.Fprintf(flags.Output(), "Setting word_threshold to %f\n", word_threshold) + context.SetTokenThreshold(word_threshold) + } + + // Return success + return nil +} + /////////////////////////////////////////////////////////////////////////////// // PRIVATE METHODS func registerFlags(flag *Flags) { flag.String("model", "", "Path to the model file") - flag.String("language", "", "Language") + flag.String("language", "", "Spoken language") + flag.Bool("translate", false, "Translate from source language to english") + flag.Duration("offset", 0, "Time offset") + flag.Duration("duration", 0, "Duration of audio to process") + flag.Uint("threads", 0, "Number of threads to use") flag.Bool("speedup", false, "Enable speedup") + flag.Uint("max-len", 0, "Maximum segment length in characters") + flag.Uint("max-tokens", 0, "Maximum tokens per segment") + flag.Float64("word-thold", 0, "Maximum segment score") flag.Bool("tokens", false, "Display tokens") + flag.Bool("colorize", false, "Colorize tokens") + flag.String("out", "", "Output format (srt, none or leave as empty string)") } diff --git a/bindings/go/examples/go-whisper/main.go b/bindings/go/examples/go-whisper/main.go index b3a89db7552..1bff7f5d50a 100644 --- a/bindings/go/examples/go-whisper/main.go +++ b/bindings/go/examples/go-whisper/main.go @@ -35,8 +35,7 @@ func main() { // Process files for _, 
filename := range flags.Args() { - fmt.Println("Processing", filename) - if err := Process(model, filename, flags.GetLanguage(), flags.IsSpeedup(), flags.IsTokens()); err != nil { + if err := Process(model, filename, flags); err != nil { fmt.Fprintln(os.Stderr, err) continue } diff --git a/bindings/go/examples/go-whisper/process.go b/bindings/go/examples/go-whisper/process.go index a0e2be86c9b..aacdc6965be 100644 --- a/bindings/go/examples/go-whisper/process.go +++ b/bindings/go/examples/go-whisper/process.go @@ -11,7 +11,7 @@ import ( wav "github.com/go-audio/wav" ) -func Process(model whisper.Model, path string, lang string, speedup, tokens bool) error { +func Process(model whisper.Model, path string, flags *Flags) error { var data []float32 // Create processing context @@ -20,14 +20,20 @@ func Process(model whisper.Model, path string, lang string, speedup, tokens bool return err } + // Set the parameters + if err := flags.SetParams(context); err != nil { + return err + } + // Open the file + fmt.Fprintf(flags.Output(), "Loading %q\n", path) fh, err := os.Open(path) if err != nil { return err } defer fh.Close() - // Decode the WAV file + // Decode the WAV file - load the full buffer dec := wav.NewDecoder(fh) if buf, err := dec.FullPCMBuffer(); err != nil { return err @@ -39,42 +45,83 @@ func Process(model whisper.Model, path string, lang string, speedup, tokens bool data = buf.AsFloat32Buffer().Data } - // Set the parameters + // Segment callback when -tokens is specified var cb whisper.SegmentCallback - if lang != "" { - if err := context.SetLanguage(lang); err != nil { - return err - } - } - if speedup { - context.SetSpeedup(true) - } - if tokens { + if flags.IsTokens() { cb = func(segment whisper.Segment) { - fmt.Printf("%02d [%6s->%6s] ", segment.Num, segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond)) + fmt.Fprintf(flags.Output(), "%02d [%6s->%6s] ", segment.Num, segment.Start.Truncate(time.Millisecond), 
segment.End.Truncate(time.Millisecond)) for _, token := range segment.Tokens { - fmt.Printf("%q ", token.Text) + if flags.IsColorize() && context.IsText(token) { + fmt.Fprint(flags.Output(), Colorize(token.Text, int(token.P*24.0)), " ") + } else { + fmt.Fprint(flags.Output(), token.Text, " ") + } } - fmt.Println("") + fmt.Fprintln(flags.Output(), "") + fmt.Fprintln(flags.Output(), "") } } // Process the data + fmt.Fprintf(flags.Output(), " ...processing %q\n", path) if err := context.Process(data, cb); err != nil { return err } // Print out the results + switch { + case flags.GetOut() == "srt": + return OutputSRT(os.Stdout, context) + case flags.GetOut() == "none": + return nil + default: + return Output(os.Stdout, context, flags.IsColorize()) + } +} + +// Output text as SRT file: a 1-based counter, a "start --> end" timing line, the segment text and a blank separator +func OutputSRT(w io.Writer, context whisper.Context) error { + n := 1 for { segment, err := context.NextSegment() if err == io.EOF { - break + return nil } else if err != nil { return err } - fmt.Printf("[%6s->%6s] %s\n", segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond), segment.Text) + fmt.Fprintln(w, n) + fmt.Fprintf(w, "%s --> %s\n", srtTimestamp(segment.Start), srtTimestamp(segment.End)) + fmt.Fprintln(w, segment.Text) + fmt.Fprintln(w, "") + n++ } +} + +// Output text to terminal +func Output(w io.Writer, context whisper.Context, colorize bool) error { + for { + segment, err := context.NextSegment() + if err == io.EOF { + return nil + } else if err != nil { + return err + } + fmt.Fprintf(w, "[%6s->%6s]", segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond)) + if colorize { + for _, token := range segment.Tokens { + if !context.IsText(token) { + continue + } + fmt.Fprint(w, " ", Colorize(token.Text, int(token.P*24.0))) + } + fmt.Fprint(w, "\n") + } else { + fmt.Fprintln(w, " ", segment.Text) + } + } +} - // Return success - return nil +// Return srtTimestamp +func srtTimestamp(t time.Duration) string { + return
fmt.Sprintf("%02d:%02d:%02d,%03d", t/time.Hour, (t%time.Hour)/time.Minute, (t%time.Minute)/time.Second, (t%time.Second)/time.Millisecond) } diff --git a/bindings/go/params.go b/bindings/go/params.go index c67a7299b85..d7dc238f5ad 100644 --- a/bindings/go/params.go +++ b/bindings/go/params.go @@ -47,6 +47,7 @@ func (p *Params) SetSpeedup(v bool) { p.speed_up = toBool(v) } +// Set language id func (p *Params) SetLanguage(lang int) error { str := C.whisper_lang_str(C.int(lang)) if str == nil { @@ -57,6 +58,7 @@ func (p *Params) SetLanguage(lang int) error { return nil } +// Get language id func (p *Params) Language() int { if p.language == nil { return -1 @@ -64,18 +66,41 @@ func (p *Params) Language() int { return int(C.whisper_lang_id(p.language)) } +// Set number of threads to use func (p *Params) SetThreads(threads int) { p.n_threads = C.int(threads) } +// Set start offset in ms func (p *Params) SetOffset(offset_ms int) { p.offset_ms = C.int(offset_ms) } +// Set audio duration to process in ms func (p *Params) SetDuration(duration_ms int) { p.duration_ms = C.int(duration_ms) } +// Set timestamp token probability threshold (~0.01) +func (p *Params) SetTokenThreshold(t float32) { + p.thold_pt = C.float(t) +} + +// Set timestamp token sum probability threshold (~0.01) +func (p *Params) SetTokenSumThreshold(t float32) { + p.thold_ptsum = C.float(t) +} + +// Set max segment length in characters +func (p *Params) SetMaxSegmentLength(n int) { + p.max_len = C.int(n) +} + +// Set max tokens per segment (0 = no limit) +func (p *Params) SetMaxTokensPerSegment(n int) { + p.max_tokens = C.int(n) +} + /////////////////////////////////////////////////////////////////////////////// // PRIVATE METHODS diff --git a/bindings/go/pkg/whisper/consts.go b/bindings/go/pkg/whisper/consts.go index 710073f08e2..5c22dc13a31 100644 --- a/bindings/go/pkg/whisper/consts.go +++ b/bindings/go/pkg/whisper/consts.go @@ -11,10 +11,11 @@ import ( // ERRORS var ( - ErrUnableToLoadModel = 
errors.New("unable to load model") - ErrInternalAppError = errors.New("internal application error") - ErrProcessingFailed = errors.New("processing failed") - ErrUnsupportedLanguage = errors.New("unsupported language") + ErrUnableToLoadModel = errors.New("unable to load model") + ErrInternalAppError = errors.New("internal application error") + ErrProcessingFailed = errors.New("processing failed") + ErrUnsupportedLanguage = errors.New("unsupported language") + ErrModelNotMultilingual = errors.New("model is not multilingual") ) /////////////////////////////////////////////////////////////////////////////// diff --git a/bindings/go/pkg/whisper/context.go b/bindings/go/pkg/whisper/context.go index baff611c813..5dda57e97e6 100644 --- a/bindings/go/pkg/whisper/context.go +++ b/bindings/go/pkg/whisper/context.go @@ -24,7 +24,7 @@ var _ Context = (*context)(nil) /////////////////////////////////////////////////////////////////////////////// // LIFECYCLE -func NewContext(model *model, params whisper.Params) (Context, error) { +func newContext(model *model, params whisper.Params) (Context, error) { context := new(context) context.model = model context.params = params @@ -41,6 +41,9 @@ func (context *context) SetLanguage(lang string) error { if context.model.ctx == nil { return ErrInternalAppError } + if !context.model.IsMultilingual() { + return ErrModelNotMultilingual + } if id := context.model.ctx.Whisper_lang_id(lang); id < 0 { return ErrUnsupportedLanguage } else if err := context.params.SetLanguage(id); err != nil { @@ -50,16 +53,60 @@ func (context *context) SetLanguage(lang string) error { return nil } +func (context *context) IsMultilingual() bool { + return context.model.IsMultilingual() +} + // Get language func (context *context) Language() string { return whisper.Whisper_lang_str(context.params.Language()) } +// Set translate flag +func (context *context) SetTranslate(v bool) { + context.params.SetTranslate(v) +} + // Set speedup flag func (context *context) 
SetSpeedup(v bool) { context.params.SetSpeedup(v) } +// Set number of threads to use +func (context *context) SetThreads(v uint) { + context.params.SetThreads(int(v)) +} + +// Set time offset +func (context *context) SetOffset(v time.Duration) { + context.params.SetOffset(int(v.Milliseconds())) +} + +// Set duration of audio to process (passed to params as whole milliseconds) +func (context *context) SetDuration(v time.Duration) { + context.params.SetDuration(int(v.Milliseconds())) +} + +// Set timestamp token probability threshold (~0.01) +func (context *context) SetTokenThreshold(t float32) { + context.params.SetTokenThreshold(t) +} + +// Set timestamp token sum probability threshold (~0.01) +func (context *context) SetTokenSumThreshold(t float32) { + context.params.SetTokenSumThreshold(t) +} + +// Set max segment length in characters +func (context *context) SetMaxSegmentLength(n uint) { + context.params.SetMaxSegmentLength(int(n)) +} + +// Set max tokens per segment (0 = no limit) +func (context *context) SetMaxTokensPerSegment(n uint) { + context.params.SetMaxTokensPerSegment(int(n)) +} + // Process new sample data and return any errors func (context *context) Process(data []float32, cb SegmentCallback) error { if context.model.ctx == nil { @@ -119,6 +166,65 @@ func (context *context) NextSegment() (Segment, error) { return result, nil } +// Test for text tokens +func (context *context) IsText(t Token) bool { + switch { + case context.IsBEG(t): + return false + case context.IsSOT(t): + return false + case whisper.Token(t.Id) >= context.model.ctx.Whisper_token_eot(): + return false + case context.IsPREV(t): + return false + case context.IsSOLM(t): + return false + case context.IsNOT(t): + return false + default: + return true + } +} + +// Test for "begin" token +func (context *context) IsBEG(t Token) bool { + return whisper.Token(t.Id) == context.model.ctx.Whisper_token_beg() +} + +// Test for "start of transcription" token +func (context *context) IsSOT(t Token) bool { + return whisper.Token(t.Id) == 
context.model.ctx.Whisper_token_sot() +} + +// Test for "end of transcription" token +func (context *context) IsEOT(t Token) bool { + return whisper.Token(t.Id) == context.model.ctx.Whisper_token_eot() +} + +// Test for "start of prev" token +func (context *context) IsPREV(t Token) bool { + return whisper.Token(t.Id) == context.model.ctx.Whisper_token_prev() +} + +// Test for "start of lm" token +func (context *context) IsSOLM(t Token) bool { + return whisper.Token(t.Id) == context.model.ctx.Whisper_token_solm() +} + +// Test for "No timestamps" token +func (context *context) IsNOT(t Token) bool { + return whisper.Token(t.Id) == context.model.ctx.Whisper_token_not() +} + +// Test for token associated with a specific language +func (context *context) IsLANG(t Token, lang string) bool { + if id := context.model.ctx.Whisper_lang_id(lang); id >= 0 { + return whisper.Token(t.Id) == context.model.ctx.Whisper_token_lang(id) + } else { + return false + } +} + /////////////////////////////////////////////////////////////////////////////// // PRIVATE METHODS diff --git a/bindings/go/pkg/whisper/interface.go b/bindings/go/pkg/whisper/interface.go index 53e4f3f0257..5ca913a8f72 100644 --- a/bindings/go/pkg/whisper/interface.go +++ b/bindings/go/pkg/whisper/interface.go @@ -20,6 +20,9 @@ type Model interface { // Return a new speech-to-text context. NewContext() (Context, error) + // Return true if the model is multilingual. + IsMultilingual() bool + // Return all languages supported. Languages() []string } @@ -27,8 +30,18 @@ type Model interface { // Context is the speach recognition context. type Context interface { SetLanguage(string) error // Set the language to use for speech recognition. + SetTranslate(bool) // Set translate flag + IsMultilingual() bool // Return true if the model is multilingual. 
Language() string // Get language - SetSpeedup(bool) // Set speedup flag + + SetOffset(time.Duration) // Set offset + SetDuration(time.Duration) // Set duration + SetThreads(uint) // Set number of threads to use + SetSpeedup(bool) // Set speedup flag + SetTokenThreshold(float32) // Set timestamp token probability threshold + SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold + SetMaxSegmentLength(uint) // Set max segment length in characters + SetMaxTokensPerSegment(uint) // Set max tokens per segment (0 = no limit) // Process mono audio data and return any errors. // If defined, newly generated segments are passed to the @@ -38,6 +51,15 @@ type Context interface { // After process is called, return segments until the end of the stream // is reached, when io.EOF is returned. NextSegment() (Segment, error) + + IsBEG(Token) bool // Test for "begin" token + IsSOT(Token) bool // Test for "start of transcription" token + IsEOT(Token) bool // Test for "end of transcription" token + IsPREV(Token) bool // Test for "start of prev" token + IsSOLM(Token) bool // Test for "start of lm" token + IsNOT(Token) bool // Test for "No timestamps" token + IsLANG(Token, string) bool // Test for token associated with a specific language + IsText(Token) bool // Test for text token } // Segment is the text result of a speech recognition. 
diff --git a/bindings/go/pkg/whisper/model.go b/bindings/go/pkg/whisper/model.go index 13cb52ca7ec..94c2197db73 100644 --- a/bindings/go/pkg/whisper/model.go +++ b/bindings/go/pkg/whisper/model.go @@ -23,7 +23,7 @@ var _ Model = (*model)(nil) /////////////////////////////////////////////////////////////////////////////// // LIFECYCLE -func New(path string) (*model, error) { +func New(path string) (Model, error) { model := new(model) if _, err := os.Stat(path); err != nil { return nil, err @@ -64,6 +64,11 @@ func (model *model) String() string { /////////////////////////////////////////////////////////////////////////////// // PUBLIC METHODS +// Return true if model is multilingual (language and translation options are supported) +func (model *model) IsMultilingual() bool { + return model.ctx.Whisper_is_multilingual() != 0 +} + // Return all recognized languages. Initially it is set to auto-detect func (model *model) Languages() []string { result := make([]string, 0, whisper.Whisper_lang_max_id()) @@ -91,5 +96,5 @@ func (model *model) NewContext() (Context, error) { params.SetThreads(runtime.NumCPU()) // Return new context - return NewContext(model, params) + return newContext(model, params) }