Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

go bindings: Adding features to the go-whisper example, etc #384

Merged
merged 9 commits into from
Jan 7, 2023
17 changes: 17 additions & 0 deletions .github/workflows/bindings.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
name: Bindings Tests
on:
push:
paths:
- bindings/go/**

jobs:
ubuntu-latest:
runs-on: ubuntu-latest
steps:
- uses: actions/setup-go@v3
with:
go-version: '^1.19'
- uses: actions/checkout@v1
- run: |
cd bindings/go
make test
22 changes: 22 additions & 0 deletions bindings/go/examples/go-whisper/color.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package main

import "fmt"

///////////////////////////////////////////////////////////////////////////////
// CONSTANTS

const (
Reset = "\033[0m"
RGBPrefix = "\033[38;5;" // followed by RGB values in decimal format separated by colons
RGBSuffix = "m"
)

///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS

// Colorize text with RGB values, from 0 to 23
func Colorize(text string, v int) string {
// https://en.wikipedia.org/wiki/ANSI_escape_code#8-bit
// Grayscale colors are in the range 232-255
return RGBPrefix + fmt.Sprint(v%24+232) + RGBSuffix + text + Reset
}
97 changes: 96 additions & 1 deletion bindings/go/examples/go-whisper/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@ package main

import (
"flag"
"fmt"
"strings"
"time"

// Packages
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
)

///////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -42,6 +48,26 @@ func (flags *Flags) GetLanguage() string {
return flags.Lookup("language").Value.String()
}

func (flags *Flags) IsTranslate() bool {
return flags.Lookup("translate").Value.(flag.Getter).Get().(bool)
}

func (flags *Flags) GetOffset() time.Duration {
return flags.Lookup("offset").Value.(flag.Getter).Get().(time.Duration)
}

func (flags *Flags) GetDuration() time.Duration {
return flags.Lookup("duration").Value.(flag.Getter).Get().(time.Duration)
}

func (flags *Flags) GetThreads() uint {
return flags.Lookup("threads").Value.(flag.Getter).Get().(uint)
}

func (flags *Flags) GetOut() string {
return strings.ToLower(flags.Lookup("out").Value.String())
}

func (flags *Flags) IsSpeedup() bool {
return flags.Lookup("speedup").Value.String() == "true"
}
Expand All @@ -50,12 +76,81 @@ func (flags *Flags) IsTokens() bool {
return flags.Lookup("tokens").Value.String() == "true"
}

func (flags *Flags) IsColorize() bool {
return flags.Lookup("colorize").Value.String() == "true"
}

func (flags *Flags) GetMaxLen() uint {
return flags.Lookup("max-len").Value.(flag.Getter).Get().(uint)
}

func (flags *Flags) GetMaxTokens() uint {
return flags.Lookup("max-tokens").Value.(flag.Getter).Get().(uint)
}

func (flags *Flags) GetWordThreshold() float32 {
return float32(flags.Lookup("word-thold").Value.(flag.Getter).Get().(float64))
}

func (flags *Flags) SetParams(context whisper.Context) error {
if lang := flags.GetLanguage(); lang != "" && lang != "auto" {
fmt.Fprintf(flags.Output(), "Setting language to %q\n", lang)
if err := context.SetLanguage(lang); err != nil {
return err
}
}
if flags.IsTranslate() && context.IsMultilingual() {
fmt.Fprintf(flags.Output(), "Setting translate to true\n")
context.SetTranslate(true)
}
if offset := flags.GetOffset(); offset != 0 {
fmt.Fprintf(flags.Output(), "Setting offset to %v\n", offset)
context.SetOffset(offset)
}
if duration := flags.GetDuration(); duration != 0 {
fmt.Fprintf(flags.Output(), "Setting duration to %v\n", duration)
context.SetDuration(duration)
}
if flags.IsSpeedup() {
fmt.Fprintf(flags.Output(), "Setting speedup to true\n")
context.SetSpeedup(true)
}
if threads := flags.GetThreads(); threads != 0 {
fmt.Fprintf(flags.Output(), "Setting threads to %d\n", threads)
context.SetThreads(threads)
}
if max_len := flags.GetMaxLen(); max_len != 0 {
fmt.Fprintf(flags.Output(), "Setting max_segment_length to %d\n", max_len)
context.SetMaxSegmentLength(max_len)
}
if max_tokens := flags.GetMaxTokens(); max_tokens != 0 {
fmt.Fprintf(flags.Output(), "Setting max_tokens to %d\n", max_tokens)
context.SetMaxTokensPerSegment(max_tokens)
}
if word_threshold := flags.GetWordThreshold(); word_threshold != 0 {
fmt.Fprintf(flags.Output(), "Setting word_threshold to %f\n", word_threshold)
context.SetTokenThreshold(word_threshold)
}

// Return success
return nil
}

///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS

func registerFlags(flag *Flags) {
flag.String("model", "", "Path to the model file")
flag.String("language", "", "Language")
flag.String("language", "", "Spoken language")
flag.Bool("translate", false, "Translate from source language to english")
flag.Duration("offset", 0, "Time offset")
flag.Duration("duration", 0, "Duration of audio to process")
flag.Uint("threads", 0, "Number of threads to use")
flag.Bool("speedup", false, "Enable speedup")
flag.Uint("max-len", 0, "Maximum segment length in characters")
flag.Uint("max-tokens", 0, "Maximum tokens per segment")
flag.Float64("word-thold", 0, "Maximum segment score")
flag.Bool("tokens", false, "Display tokens")
flag.Bool("colorize", false, "Colorize tokens")
flag.String("out", "", "Output format (srt, none or leave as empty string)")
}
3 changes: 1 addition & 2 deletions bindings/go/examples/go-whisper/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ func main() {

// Process files
for _, filename := range flags.Args() {
fmt.Println("Processing", filename)
if err := Process(model, filename, flags.GetLanguage(), flags.IsSpeedup(), flags.IsTokens()); err != nil {
if err := Process(model, filename, flags); err != nil {
fmt.Fprintln(os.Stderr, err)
continue
}
Expand Down
85 changes: 66 additions & 19 deletions bindings/go/examples/go-whisper/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
wav "github.com/go-audio/wav"
)

func Process(model whisper.Model, path string, lang string, speedup, tokens bool) error {
func Process(model whisper.Model, path string, flags *Flags) error {
var data []float32

// Create processing context
Expand All @@ -20,14 +20,20 @@ func Process(model whisper.Model, path string, lang string, speedup, tokens bool
return err
}

// Set the parameters
if err := flags.SetParams(context); err != nil {
return err
}

// Open the file
fmt.Fprintf(flags.Output(), "Loading %q\n", path)
fh, err := os.Open(path)
if err != nil {
return err
}
defer fh.Close()

// Decode the WAV file
// Decode the WAV file - load the full buffer
dec := wav.NewDecoder(fh)
if buf, err := dec.FullPCMBuffer(); err != nil {
return err
Expand All @@ -39,42 +45,83 @@ func Process(model whisper.Model, path string, lang string, speedup, tokens bool
data = buf.AsFloat32Buffer().Data
}

// Set the parameters
// Segment callback when -tokens is specified
var cb whisper.SegmentCallback
if lang != "" {
if err := context.SetLanguage(lang); err != nil {
return err
}
}
if speedup {
context.SetSpeedup(true)
}
if tokens {
if flags.IsTokens() {
cb = func(segment whisper.Segment) {
fmt.Printf("%02d [%6s->%6s] ", segment.Num, segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
fmt.Fprintf(flags.Output(), "%02d [%6s->%6s] ", segment.Num, segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
for _, token := range segment.Tokens {
fmt.Printf("%q ", token.Text)
if flags.IsColorize() && context.IsText(token) {
fmt.Fprint(flags.Output(), Colorize(token.Text, int(token.P*24.0)), " ")
} else {
fmt.Fprint(flags.Output(), token.Text, " ")
}
}
fmt.Println("")
fmt.Fprintln(flags.Output(), "")
fmt.Fprintln(flags.Output(), "")
}
}

// Process the data
fmt.Fprintf(flags.Output(), " ...processing %q\n", path)
if err := context.Process(data, cb); err != nil {
return err
}

// Print out the results
switch {
case flags.GetOut() == "srt":
return OutputSRT(os.Stdout, context)
case flags.GetOut() == "none":
return nil
default:
return Output(os.Stdout, context, flags.IsColorize())
}
}

// Output text as SRT file
func OutputSRT(w io.Writer, context whisper.Context) error {
n := 1
for {
segment, err := context.NextSegment()
if err == io.EOF {
break
return nil
} else if err != nil {
return err
}
fmt.Printf("[%6s->%6s] %s\n", segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond), segment.Text)
fmt.Fprintln(w, n)
fmt.Fprintln(w, srtTimestamp(segment.Start), " --> ", srtTimestamp(segment.End))
fmt.Fprintln(w, segment.Text)
fmt.Fprintln(w, "")
n++
}
}

// Output text to terminal
func Output(w io.Writer, context whisper.Context, colorize bool) error {
for {
segment, err := context.NextSegment()
if err == io.EOF {
return nil
} else if err != nil {
return err
}
fmt.Fprintf(w, "[%6s->%6s]", segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
if colorize {
for _, token := range segment.Tokens {
if !context.IsText(token) {
continue
}
fmt.Fprint(w, " ", Colorize(token.Text, int(token.P*24.0)))
}
fmt.Fprint(w, "\n")
} else {
fmt.Fprintln(w, " ", segment.Text)
}
}
}

// Return success
return nil
// Return srtTimestamp
func srtTimestamp(t time.Duration) string {
return fmt.Sprintf("%02d:%02d:%02d,%03d", t/time.Hour, (t%time.Hour)/time.Minute, (t%time.Minute)/time.Second, (t%time.Second)/time.Millisecond)
}
25 changes: 25 additions & 0 deletions bindings/go/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ func (p *Params) SetSpeedup(v bool) {
p.speed_up = toBool(v)
}

// Set language id
func (p *Params) SetLanguage(lang int) error {
str := C.whisper_lang_str(C.int(lang))
if str == nil {
Expand All @@ -57,25 +58,49 @@ func (p *Params) SetLanguage(lang int) error {
return nil
}

// Get language id
func (p *Params) Language() int {
if p.language == nil {
return -1
}
return int(C.whisper_lang_id(p.language))
}

// Set number of threads to use
func (p *Params) SetThreads(threads int) {
p.n_threads = C.int(threads)
}

// Set start offset in ms
func (p *Params) SetOffset(offset_ms int) {
p.offset_ms = C.int(offset_ms)
}

// Set audio duration to process in ms
func (p *Params) SetDuration(duration_ms int) {
p.duration_ms = C.int(duration_ms)
}

// Set timestamp token probability threshold (~0.01)
func (p *Params) SetTokenThreshold(t float32) {
p.thold_pt = C.float(t)
}

// Set timestamp token sum probability threshold (~0.01)
func (p *Params) SetTokenSumThreshold(t float32) {
p.thold_ptsum = C.float(t)
}

// Set max segment length in characters
func (p *Params) SetMaxSegmentLength(n int) {
p.max_len = C.int(n)
}

// Set max tokens per segment (0 = no limit)
func (p *Params) SetMaxTokensPerSegment(n int) {
p.max_tokens = C.int(n)
}

///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS

Expand Down
9 changes: 5 additions & 4 deletions bindings/go/pkg/whisper/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@ import (
// ERRORS

var (
ErrUnableToLoadModel = errors.New("unable to load model")
ErrInternalAppError = errors.New("internal application error")
ErrProcessingFailed = errors.New("processing failed")
ErrUnsupportedLanguage = errors.New("unsupported language")
ErrUnableToLoadModel = errors.New("unable to load model")
ErrInternalAppError = errors.New("internal application error")
ErrProcessingFailed = errors.New("processing failed")
ErrUnsupportedLanguage = errors.New("unsupported language")
ErrModelNotMultilingual = errors.New("model is not multilingual")
)

///////////////////////////////////////////////////////////////////////////////
Expand Down
Loading