Skip to content

Commit

Permalink
Merge pull request #4 from itsmontoya/feature-implement-n-character-g…
Browse files Browse the repository at this point in the history
…rams

feature: implement character ngrams
  • Loading branch information
itsmontoya committed Jul 15, 2024
2 parents a99b1ba + 565f7c3 commit 292df86
Show file tree
Hide file tree
Showing 10 changed files with 390 additions and 103 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ samples:

- [X] Working implementation as Go library
- [X] Training sets
- [ ] Support Character NGrams
- [X] Support Character NGrams
- [ ] Text normalization added to inbound text processing
- [X] CLI utility

## Long term goals
- [ ] Generated model as MMAP file
9 changes: 5 additions & 4 deletions bag-cli/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package main

import (
"flag"
"fmt"
"log"
"os"

Expand Down Expand Up @@ -34,10 +33,12 @@ func main() {
return
}

fmt.Printf("TS: %+v\n", t)

a.interactivePrint("Training set loaded\n")
a.b = bag.NewFromTrainingSet(t)
if a.b, err = bag.NewFromTrainingSet(t); err != nil {
log.Fatalf("error initializing from training set: %v\n", err)
return
}

a.interactivePrint("Model generated\n")
a.interactivePrint("Interactive mode is active. Type your input and press Enter:\n")

Expand Down
41 changes: 28 additions & 13 deletions bag.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,32 @@ package bag

import "math"

// New creates a Bag from the provided Config.
// The Config is validated first; if validation fails, the error is returned
// and no Bag is produced. Otherwise the Bag is initialized with empty
// per-label vocabulary and count maps, and unset Config values are filled in.
func New(c Config) *Bag {
func New(c Config) (out *Bag, err error) {
// Validate Config before allocating anything
if err = c.Validate(); err != nil {
return
}

var b Bag
b.c = c
b.vocabByLabel = map[string]Vocabulary{}
b.countByLabel = map[string]int{}
// Fill unset values as default (applied to the Bag's own copy of the Config)
b.c.fill()
return &b
out = &b
return
}

// NewFromTrainingSet creates a Bag from the training set's Config and trains
// it with every sample in the set, using the sample's map key as its label.
// An error returned by New is passed back to the caller unchanged.
func NewFromTrainingSet(t TrainingSet) *Bag {
b := New(t.Config)
func NewFromTrainingSet(t TrainingSet) (b *Bag, err error) {
if b, err = New(t.Config); err != nil {
return
}

// Train on each sample under the label it is grouped by
for label, samples := range t.Samples {
for _, sample := range samples {
b.Train(sample, label)
}
}

return b
return
}

type Bag struct {
Expand All @@ -36,7 +43,7 @@ type Bag struct {

func (b *Bag) GetResults(in string) (r Results) {
// Convert inbound data to NGrams
ns := toNGrams(in, b.c.NGramSize)
ns := b.toNGrams(in)
// Initialize results with the same size as the current number of vocabulary labels
r = make(Results, len(b.vocabByLabel))
// Iterate through vocabulary sets by label
Expand All @@ -47,16 +54,23 @@ func (b *Bag) GetResults(in string) (r Results) {

return
}
// toNGrams converts the inbound text into n-grams of the configured size.
// Word n-grams are produced when the configured type is "word"; any other
// type falls through to character n-grams.
func (b *Bag) toNGrams(in string) (ns []string) {
	switch b.c.NGramType {
	case "word":
		ns = toNGrams(in, b.c.NGramSize)
	default:
		ns = tocharacterNGrams(in, b.c.NGramSize)
	}

	return ns
}

func (b *Bag) Train(in, label string) {
// Convert inbound data to a slice of NGrams
ns := toNGrams(in, b.c.NGramSize)
ns := b.toNGrams(in)
// Get vocabulary for the provided label; if the vocabulary doesn't exist, it will be created
v := b.getOrCreateVocabulary(label)
// Iterate through NGrams
for _, n := range ns {
// Increment the vocabulary value for the current NGram
v[n.String()]++
v[n]++
}

// Increment count of trained documents for the provided label
Expand All @@ -66,16 +80,17 @@ func (b *Bag) Train(in, label string) {
}

// getProbability uses a Naive Bayes classifier to determine probability for a given label
func (b *Bag) getProbability(ns []NGram, label string, vocab Vocabulary) (probability float64) {
func (b *Bag) getProbability(ns []string, label string, vocab Vocabulary) (probability float64) {
// Set initial probability value as the prior probability value
probability = b.getPriorProbability(label)
// Get the current counts by label (to be used by Laplace smoothing during for-loop)
countsByLabel := float64(b.countByLabel[label] + len(vocab))
countsByLabel := float64(b.countByLabel[label]) + b.c.SmoothingParameter*float64(len(vocab))

// Iterate through NGrams
for _, n := range ns {
// Utilize Laplace smoothing to improve our results when an ngram isn't found within the trained dataset
// Likelihood with Laplace smoothing
count := float64(vocab[n.String()] + b.c.SmoothingParameter)
count := float64(vocab[n]) + b.c.SmoothingParameter
// Add the logarithm of the smoothed likelihood (count / countsByLabel)
probability += math.Log(count / countsByLabel)
}
Expand Down
201 changes: 127 additions & 74 deletions bag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,116 @@ package bag

import (
"fmt"
"log"
"os"
"testing"

"github.com/go-yaml/yaml"
)

var (
exampleBag *Bag
exampleResults Results
)

var (
testTrainingYesNo TrainingSet
)

// TestMain loads the shared yes/no training set from the examples directory
// into testTrainingYesNo before running the package's tests.
// Any failure to open or decode the file aborts the test run.
func TestMain(m *testing.M) {
	f, err := os.Open("./examples/yes-no-training.yaml")
	if err != nil {
		log.Fatal(err)
	}

	err = yaml.NewDecoder(f).Decode(&testTrainingYesNo)
	// Close explicitly rather than with defer: both log.Fatal and os.Exit
	// terminate the process without running deferred calls.
	f.Close()
	if err != nil {
		log.Fatal(err)
	}

	os.Exit(m.Run())
}

// TestNew checks the error result of New for each Config case below:
// an empty Config is expected to succeed, and a Config with an unknown
// NGramType is expected to fail.
func TestNew(t *testing.T) {
	type args struct {
		c Config
	}

	cases := []struct {
		name    string
		args    args
		wantErr bool
	}{
		{name: "empty", args: args{c: Config{}}, wantErr: false},
		{name: "invalid ngram type", args: args{c: Config{NGramType: "foobar"}}, wantErr: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			_, err := New(tc.args.c)
			gotErr := err != nil
			if gotErr == tc.wantErr {
				return
			}

			t.Errorf("New() error = %v, wantErr %v", err, tc.wantErr)
		})
	}
}

// TestNewFromTrainingSet checks the error result of NewFromTrainingSet for
// each training set case below; currently a set whose Config carries an
// unknown NGramType is expected to fail.
func TestNewFromTrainingSet(t *testing.T) {
	type args struct {
		t TrainingSet
	}

	type teststruct struct {
		name    string
		args    args
		wantErr bool
	}

	tests := []teststruct{
		{
			name: "invalid ngram type",
			args: args{
				t: TrainingSet{
					Config: Config{
						NGramType: "foobar",
					},
				},
			},
			wantErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			_, err := NewFromTrainingSet(tt.args.t)
			if (err != nil) != tt.wantErr {
				// Name the function actually under test so failures are not
				// mistaken for TestNew failures.
				t.Errorf("NewFromTrainingSet() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}

func TestBag_GetResults(t *testing.T) {
positiveNegative := SamplesByLabel{
"positive": {
Expand All @@ -25,65 +127,6 @@ func TestBag_GetResults(t *testing.T) {
},
}

yesNo := SamplesByLabel{
"yes": {
"Yes",
"Yeah",
"Yep",
"Yup",
"Yea",
"Sure",
"Absolutely",
"Definitely",
"Of course",
"For sure",
"Indeed",
"Affirmative",
"Roger",
"Totally",
"Certainly",
"Without a doubt",
"You bet",
"Uh-huh",
"Right on",
"Cool",
"Okie dokie",
"Aye",
"Yass",
"Fo sho",
"Bet",
"10-4",
},
"no": {
"No",
"Nope",
"Nah",
"Nuh-uh",
"No way",
"Not at all",
"no",
"Not really",
"I don't think so",
"Absolutely not",
"No chance",
"No way, José",
"Out of the question",
"By no means",
"Under no circumstances",
"Never",
"Not in a million years",
"Not happening",
"No can do",
"Not on your life",
"Hell no",
"Nah fam",
"Pass",
"Hard pass",
"Nopey dopey",
"Nix",
},
}

type fields struct {
t TrainingSet
}
Expand Down Expand Up @@ -127,9 +170,7 @@ func TestBag_GetResults(t *testing.T) {
{
name: "yes",
fields: fields{
t: TrainingSet{
Samples: yesNo,
},
t: testTrainingYesNo,
},
args: args{
in: "Oh yes.",
Expand All @@ -139,12 +180,7 @@ func TestBag_GetResults(t *testing.T) {
{
name: "no",
fields: fields{
t: TrainingSet{
Config: Config{
NGramSize: 1,
},
Samples: yesNo,
},
t: testTrainingYesNo,
},
args: args{
in: "Oh no.",
Expand All @@ -155,23 +191,38 @@ func TestBag_GetResults(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
b := NewFromTrainingSet(tt.fields.t)
b, err := NewFromTrainingSet(tt.fields.t)
if err != nil {
t.Fatal(err)
}

gotR := b.GetResults(tt.args.in).GetHighestProbability()
if gotR != tt.wantMatch {
t.Errorf("Bag.GetResults() = %v, want %v", gotR, tt.wantMatch)
t.Errorf("Bag.GetResults() = wrong value for <%v>: %v, want %v", tt.args.in, gotR, tt.wantMatch)
fmt.Println(b.GetResults(tt.args.in))
}
})
}
}

// ExampleNew shows how to construct a Bag from a zero-value Config and store
// it in the package-level exampleBag variable.
func ExampleNew() {
var cfg Config
var (
cfg Config
err error
)

// Initialize with default values; a construction error aborts via log.Fatal
exampleBag = New(cfg)
if exampleBag, err = New(cfg); err != nil {
log.Fatal(err)
}
}

func ExampleNewFromTrainingSet() {
var t TrainingSet
var (
t TrainingSet
err error
)

t.Samples = SamplesByLabel{
"positive": {
"I love this product, it is amazing!",
Expand All @@ -187,7 +238,9 @@ func ExampleNewFromTrainingSet() {
}

// Initialize with default values
exampleBag = NewFromTrainingSet(t)
if exampleBag, err = NewFromTrainingSet(t); err != nil {
log.Fatal(err)
}
}

func ExampleBag_Train() {
Expand Down
Loading

0 comments on commit 292df86

Please sign in to comment.