Skip to content

Commit

Permalink
go : adding features to the go-whisper example, go ci, etc (ggerganov…
Browse files Browse the repository at this point in the history
…#384)

* Updated bindings so they can be used in third-party packages.

* Updated makefiles to optionally set the FMA flag, for Xeon E5 on Darwin

* Added test script

* Changes for examples

* Reverted

* Made the NewContext method private
  • Loading branch information
djthorpe authored and abitofevrything committed Jan 8, 2023
1 parent 850b55a commit c560b8b
Show file tree
Hide file tree
Showing 10 changed files with 369 additions and 30 deletions.
17 changes: 17 additions & 0 deletions .github/workflows/bindings.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Runs the Go bindings tests on pushes that touch files under bindings/go.
name: Bindings Tests
on:
  push:
    paths:
      - bindings/go/**

jobs:
  ubuntu-latest:
    runs-on: ubuntu-latest
    steps:
      # Install a Go toolchain matching the ^1.19 version constraint.
      - uses: actions/setup-go@v3
        with:
          go-version: '^1.19'
      - uses: actions/checkout@v1
      # Run the bindings' own make target from their directory.
      - run: |
          cd bindings/go
          make test
22 changes: 22 additions & 0 deletions bindings/go/examples/go-whisper/color.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package main

import "fmt"

///////////////////////////////////////////////////////////////////////////////
// CONSTANTS

const (
	// Reset is the ANSI sequence that restores the default text attributes.
	Reset = "\033[0m"
	// RGBPrefix opens an 8-bit (256-color) foreground color sequence; it is
	// followed by a single palette index in decimal.
	RGBPrefix = "\033[38;5;"
	// RGBSuffix terminates the color sequence.
	RGBSuffix = "m"
)

///////////////////////////////////////////////////////////////////////////////
// PUBLIC METHODS

// Colorize wraps text in an ANSI 8-bit grayscale foreground color and a
// trailing reset sequence. v selects the gray level from 0 to 23; values
// outside that range are clamped, so v=24 yields the last gray level rather
// than wrapping around to the first one.
func Colorize(text string, v int) string {
	// https://en.wikipedia.org/wiki/ANSI_escape_code#8-bit
	// Grayscale colors are in the range 232-255
	if v < 0 {
		v = 0
	} else if v > 23 {
		v = 23
	}
	return RGBPrefix + fmt.Sprint(232+v) + RGBSuffix + text + Reset
}
97 changes: 96 additions & 1 deletion bindings/go/examples/go-whisper/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@ package main

import (
"flag"
"fmt"
"strings"
"time"

// Packages
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
)

///////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -42,6 +48,26 @@ func (flags *Flags) GetLanguage() string {
return flags.Lookup("language").Value.String()
}

// IsTranslate returns the boolean value of the "translate" flag.
func (flags *Flags) IsTranslate() bool {
	getter := flags.Lookup("translate").Value.(flag.Getter)
	return getter.Get().(bool)
}

// GetOffset returns the value of the "offset" flag as a time.Duration.
func (flags *Flags) GetOffset() time.Duration {
	getter := flags.Lookup("offset").Value.(flag.Getter)
	return getter.Get().(time.Duration)
}

// GetDuration returns the value of the "duration" flag as a time.Duration.
func (flags *Flags) GetDuration() time.Duration {
	getter := flags.Lookup("duration").Value.(flag.Getter)
	return getter.Get().(time.Duration)
}

// GetThreads returns the value of the "threads" flag as a uint.
func (flags *Flags) GetThreads() uint {
	getter := flags.Lookup("threads").Value.(flag.Getter)
	return getter.Get().(uint)
}

// GetOut returns the string value of the "out" flag, converted to lower case.
func (flags *Flags) GetOut() string {
	format := flags.Lookup("out").Value.String()
	return strings.ToLower(format)
}

// IsSpeedup reports whether the "speedup" flag's string value is "true".
func (flags *Flags) IsSpeedup() bool {
	value := flags.Lookup("speedup").Value
	return value.String() == "true"
}
Expand All @@ -50,12 +76,81 @@ func (flags *Flags) IsTokens() bool {
return flags.Lookup("tokens").Value.String() == "true"
}

// IsColorize reports whether the "colorize" flag's string value is "true".
func (flags *Flags) IsColorize() bool {
	value := flags.Lookup("colorize").Value
	return value.String() == "true"
}

// GetMaxLen returns the value of the "max-len" flag as a uint.
func (flags *Flags) GetMaxLen() uint {
	getter := flags.Lookup("max-len").Value.(flag.Getter)
	return getter.Get().(uint)
}

// GetMaxTokens returns the value of the "max-tokens" flag as a uint.
func (flags *Flags) GetMaxTokens() uint {
	getter := flags.Lookup("max-tokens").Value.(flag.Getter)
	return getter.Get().(uint)
}

// GetWordThreshold returns the float64 value of the "word-thold" flag,
// narrowed to float32.
func (flags *Flags) GetWordThreshold() float32 {
	raw := flags.Lookup("word-thold").Value.(flag.Getter).Get()
	return float32(raw.(float64))
}

// SetParams copies the command-line settings onto the processing context.
// Each setting is applied only when its flag holds a non-default value, and a
// line describing the change is written to flags.Output() first. The only
// error returned is one reported by context.SetLanguage.
func (flags *Flags) SetParams(context whisper.Context) error {
	// Language: left untouched when empty or "auto".
	language := flags.GetLanguage()
	if language != "" && language != "auto" {
		fmt.Fprintf(flags.Output(), "Setting language to %q\n", language)
		err := context.SetLanguage(language)
		if err != nil {
			return err
		}
	}

	// Translation is only enabled when the context reports it is multilingual.
	if flags.IsTranslate() {
		if context.IsMultilingual() {
			fmt.Fprint(flags.Output(), "Setting translate to true\n")
			context.SetTranslate(true)
		}
	}

	// Time window to process.
	startAt := flags.GetOffset()
	if startAt != 0 {
		fmt.Fprintf(flags.Output(), "Setting offset to %v\n", startAt)
		context.SetOffset(startAt)
	}
	length := flags.GetDuration()
	if length != 0 {
		fmt.Fprintf(flags.Output(), "Setting duration to %v\n", length)
		context.SetDuration(length)
	}

	// Processing options.
	if flags.IsSpeedup() {
		fmt.Fprint(flags.Output(), "Setting speedup to true\n")
		context.SetSpeedup(true)
	}
	numThreads := flags.GetThreads()
	if numThreads != 0 {
		fmt.Fprintf(flags.Output(), "Setting threads to %d\n", numThreads)
		context.SetThreads(numThreads)
	}

	// Segmentation limits and token threshold.
	maxLen := flags.GetMaxLen()
	if maxLen != 0 {
		fmt.Fprintf(flags.Output(), "Setting max_segment_length to %d\n", maxLen)
		context.SetMaxSegmentLength(maxLen)
	}
	maxTokens := flags.GetMaxTokens()
	if maxTokens != 0 {
		fmt.Fprintf(flags.Output(), "Setting max_tokens to %d\n", maxTokens)
		context.SetMaxTokensPerSegment(maxTokens)
	}
	threshold := flags.GetWordThreshold()
	if threshold != 0 {
		fmt.Fprintf(flags.Output(), "Setting word_threshold to %f\n", threshold)
		context.SetTokenThreshold(threshold)
	}

	// Return success
	return nil
}

///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS

// registerFlags defines every command-line option of the example program on
// the given flag set. Each name is registered exactly once: defining the same
// flag name twice on a flag set panics.
func registerFlags(flag *Flags) {
	flag.String("model", "", "Path to the model file")
	flag.String("language", "", "Spoken language")
	flag.Bool("translate", false, "Translate from source language to english")
	flag.Duration("offset", 0, "Time offset")
	flag.Duration("duration", 0, "Duration of audio to process")
	flag.Uint("threads", 0, "Number of threads to use")
	flag.Bool("speedup", false, "Enable speedup")
	flag.Uint("max-len", 0, "Maximum segment length in characters")
	flag.Uint("max-tokens", 0, "Maximum tokens per segment")
	flag.Float64("word-thold", 0, "Maximum segment score")
	flag.Bool("tokens", false, "Display tokens")
	flag.Bool("colorize", false, "Colorize tokens")
	flag.String("out", "", "Output format (srt, none or leave as empty string)")
}
3 changes: 1 addition & 2 deletions bindings/go/examples/go-whisper/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ func main() {

// Process files
for _, filename := range flags.Args() {
fmt.Println("Processing", filename)
if err := Process(model, filename, flags.GetLanguage(), flags.IsSpeedup(), flags.IsTokens()); err != nil {
if err := Process(model, filename, flags); err != nil {
fmt.Fprintln(os.Stderr, err)
continue
}
Expand Down
85 changes: 66 additions & 19 deletions bindings/go/examples/go-whisper/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
wav "github.com/go-audio/wav"
)

func Process(model whisper.Model, path string, lang string, speedup, tokens bool) error {
func Process(model whisper.Model, path string, flags *Flags) error {
var data []float32

// Create processing context
Expand All @@ -20,14 +20,20 @@ func Process(model whisper.Model, path string, lang string, speedup, tokens bool
return err
}

// Set the parameters
if err := flags.SetParams(context); err != nil {
return err
}

// Open the file
fmt.Fprintf(flags.Output(), "Loading %q\n", path)
fh, err := os.Open(path)
if err != nil {
return err
}
defer fh.Close()

// Decode the WAV file
// Decode the WAV file - load the full buffer
dec := wav.NewDecoder(fh)
if buf, err := dec.FullPCMBuffer(); err != nil {
return err
Expand All @@ -39,42 +45,83 @@ func Process(model whisper.Model, path string, lang string, speedup, tokens bool
data = buf.AsFloat32Buffer().Data
}

// Set the parameters
// Segment callback when -tokens is specified
var cb whisper.SegmentCallback
if lang != "" {
if err := context.SetLanguage(lang); err != nil {
return err
}
}
if speedup {
context.SetSpeedup(true)
}
if tokens {
if flags.IsTokens() {
cb = func(segment whisper.Segment) {
fmt.Printf("%02d [%6s->%6s] ", segment.Num, segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
fmt.Fprintf(flags.Output(), "%02d [%6s->%6s] ", segment.Num, segment.Start.Truncate(time.Millisecond), segment.End.Truncate(time.Millisecond))
for _, token := range segment.Tokens {
fmt.Printf("%q ", token.Text)
if flags.IsColorize() && context.IsText(token) {
fmt.Fprint(flags.Output(), Colorize(token.Text, int(token.P*24.0)), " ")
} else {
fmt.Fprint(flags.Output(), token.Text, " ")
}
}
fmt.Println("")
fmt.Fprintln(flags.Output(), "")
fmt.Fprintln(flags.Output(), "")
}
}

// Process the data
fmt.Fprintf(flags.Output(), " ...processing %q\n", path)
if err := context.Process(data, cb); err != nil {
return err
}

// Print out the results
switch {
case flags.GetOut() == "srt":
return OutputSRT(os.Stdout, context)
case flags.GetOut() == "none":
return nil
default:
return Output(os.Stdout, context, flags.IsColorize())
}
}

// OutputSRT writes the segments of context to w as SubRip (SRT) cues: a
// 1-based sequence number, a "start --> end" timestamp line, the segment
// text and a blank separator line. It returns nil once NextSegment reports
// io.EOF, or the first other error NextSegment returns.
func OutputSRT(w io.Writer, context whisper.Context) error {
	n := 1
	for {
		segment, err := context.NextSegment()
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return err
		}
		fmt.Fprintln(w, n)
		// Use an explicit format: Fprintln would insert an extra space on each
		// side of the arrow, and SRT expects exactly " --> ".
		fmt.Fprintf(w, "%s --> %s\n", srtTimestamp(segment.Start), srtTimestamp(segment.End))
		fmt.Fprintln(w, segment.Text)
		fmt.Fprintln(w, "")
		n++
	}
}

// Output writes each segment of context to w as one line, prefixed with its
// start and end times truncated to milliseconds. With colorize set, the line
// is built from the segment's text tokens, each colored from its P value via
// Colorize; otherwise the segment text is printed as is. It returns nil when
// NextSegment reports io.EOF, or the first other error it returns.
func Output(w io.Writer, context whisper.Context, colorize bool) error {
	for {
		seg, err := context.NextSegment()
		switch {
		case err == io.EOF:
			return nil
		case err != nil:
			return err
		}

		from := seg.Start.Truncate(time.Millisecond)
		to := seg.End.Truncate(time.Millisecond)
		fmt.Fprintf(w, "[%6s->%6s]", from, to)

		// Plain mode: the whole segment text on one line.
		if !colorize {
			fmt.Fprintln(w, " ", seg.Text)
			continue
		}

		// Colorized mode: only tokens the context classifies as text.
		for _, tok := range seg.Tokens {
			if context.IsText(tok) {
				fmt.Fprint(w, " ", Colorize(tok.Text, int(tok.P*24.0)))
			}
		}
		fmt.Fprint(w, "\n")
	}
}

// Return success
return nil
// srtTimestamp formats t as an SRT timestamp of the form HH:MM:SS,mmm.
func srtTimestamp(t time.Duration) string {
	hours := t / time.Hour
	minutes := (t % time.Hour) / time.Minute
	seconds := (t % time.Minute) / time.Second
	millis := (t % time.Second) / time.Millisecond
	return fmt.Sprintf("%02d:%02d:%02d,%03d", hours, minutes, seconds, millis)
}
25 changes: 25 additions & 0 deletions bindings/go/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ func (p *Params) SetSpeedup(v bool) {
p.speed_up = toBool(v)
}

// Set language id
func (p *Params) SetLanguage(lang int) error {
str := C.whisper_lang_str(C.int(lang))
if str == nil {
Expand All @@ -57,25 +58,49 @@ func (p *Params) SetLanguage(lang int) error {
return nil
}

// Language returns the id that whisper_lang_id reports for the configured
// language, or -1 when no language is set.
func (p *Params) Language() int {
	if p.language != nil {
		return int(C.whisper_lang_id(p.language))
	}
	return -1
}

// SetThreads sets the number of threads to use.
func (p *Params) SetThreads(threads int) {
	n := C.int(threads)
	p.n_threads = n
}

// SetOffset sets the start offset, in milliseconds.
func (p *Params) SetOffset(offset_ms int) {
	ms := C.int(offset_ms)
	p.offset_ms = ms
}

// SetDuration sets the duration of audio to process, in milliseconds.
func (p *Params) SetDuration(duration_ms int) {
	ms := C.int(duration_ms)
	p.duration_ms = ms
}

// SetTokenThreshold sets the timestamp token probability threshold (~0.01).
func (p *Params) SetTokenThreshold(t float32) {
	threshold := C.float(t)
	p.thold_pt = threshold
}

// SetTokenSumThreshold sets the timestamp token sum probability threshold
// (~0.01).
func (p *Params) SetTokenSumThreshold(t float32) {
	threshold := C.float(t)
	p.thold_ptsum = threshold
}

// SetMaxSegmentLength sets the maximum segment length, in characters.
func (p *Params) SetMaxSegmentLength(n int) {
	limit := C.int(n)
	p.max_len = limit
}

// SetMaxTokensPerSegment sets the maximum number of tokens per segment
// (0 = no limit).
func (p *Params) SetMaxTokensPerSegment(n int) {
	limit := C.int(n)
	p.max_tokens = limit
}

///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS

Expand Down
9 changes: 5 additions & 4 deletions bindings/go/pkg/whisper/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@ import (
// ERRORS

// Sentinel error values declared for the whisper package. Each name is
// declared exactly once.
var (
	ErrUnableToLoadModel    = errors.New("unable to load model")
	ErrInternalAppError     = errors.New("internal application error")
	ErrProcessingFailed     = errors.New("processing failed")
	ErrUnsupportedLanguage  = errors.New("unsupported language")
	ErrModelNotMultilingual = errors.New("model is not multilingual")
)

///////////////////////////////////////////////////////////////////////////////
Expand Down
Loading

0 comments on commit c560b8b

Please sign in to comment.