Skip to content

Commit

Permalink
Merge pull request #15 from GopherML/pre-v1.0.0-clean-up
Browse files Browse the repository at this point in the history
Pre v1.0.0 clean up
  • Loading branch information
itsmontoya committed Jul 19, 2024
2 parents 31ac2c8 + bd4f5db commit 70b2112
Show file tree
Hide file tree
Showing 10 changed files with 178 additions and 31 deletions.
49 changes: 40 additions & 9 deletions bag.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package bag

import "math"

// New will initialize and return a new Bag with a provided configuration
func New(c Config) (out *Bag, err error) {
// Validate Config
if err = c.Validate(); err != nil {
Expand All @@ -11,36 +12,56 @@ func New(c Config) (out *Bag, err error) {
var b Bag
b.c = c
b.vocabByLabel = map[string]Vocabulary{}
b.countByLabel = map[string]int{}
b.documentCountByLabel = map[string]int{}
out = &b
return
}

// NewFromTrainingSet will initialize and return a new pre-trained Bag from a provided training set
func NewFromTrainingSet(t TrainingSet) (b *Bag, err error) {
if b, err = New(t.Config); err != nil {
// Error initializing, return
return
}

// Train with provided samples, iterate over samples by label
for label, samples := range t.Samples {
// For each within samples slice
for _, sample := range samples {
// Train for a given sample and label
b.Train(sample, label)
}
}

return
}

// NewFromTrainingSetFile will initialize and return a new pre-trained Bag from a provided training set filepath
func NewFromTrainingSetFile(filepath string) (b *Bag, err error) {
var t TrainingSet
// Make new training set
if t, err = makeTrainingSetFromFile(filepath); err != nil {
// Error making training set, return
return
}

// Create new Bag from training set
return NewFromTrainingSet(t)
}

// Bag represents a bag of words (BoW) model
type Bag struct {
// Configuration values
c Config
// Vocabulary sets by label
vocabByLabel map[string]Vocabulary
// Count of trained documents by label
countByLabel map[string]int
documentCountByLabel map[string]int
// Total count of trained documents
totalCount int
totalDocumentCount int
}

// GetResults will return the classification results for a given input string
func (b *Bag) GetResults(in string) (r Results) {
// Convert inbound data to NGrams
ns := b.toNGrams(in)
Expand All @@ -55,6 +76,7 @@ func (b *Bag) GetResults(in string) (r Results) {
return
}

// Train will process a given input string and assign it the provided label for training
func (b *Bag) Train(in, label string) {
// Convert inbound data to a slice of NGrams
ns := b.toNGrams(in)
Expand All @@ -73,9 +95,11 @@ func (b *Bag) Train(in, label string) {
// toNGrams converts the inbound string into n-grams based on the configuration settings
func (b *Bag) toNGrams(in string) (ns []string) {
if b.c.NGramType == "word" {
// NGram type is word, use n-grams
return toNGrams(in, b.c.NGramSize)
}

// NGram type is character, use character n-grams
return toCharacterNGrams(in, b.c.NGramSize)
}

Expand All @@ -84,7 +108,7 @@ func (b *Bag) getProbability(ns []string, label string, vocab Vocabulary) (proba
// Set initial probability value as the prior probability value
probability = b.getLogPriorProbability(label)
// Get the current counts by label (to be used by Laplace smoothing during for-loop)
countsByLabel := float64(b.countByLabel[label]) + b.c.SmoothingParameter*float64(len(vocab))
countsByLabel := float64(b.documentCountByLabel[label]) + b.c.SmoothingParameter*float64(len(vocab))

// Iterate through NGrams
for _, n := range ns {
Expand All @@ -98,15 +122,21 @@ func (b *Bag) getProbability(ns []string, label string, vocab Vocabulary) (proba
return
}

// getLogPriorProbability will get the starting probability value for a given label
func (b *Bag) getLogPriorProbability(label string) (probability float64) {
count := float64(b.countByLabel[label])
total := float64(b.totalCount)
// Document count for the given label
countByLabel := float64(b.documentCountByLabel[label])
// Total document count
total := float64(b.totalDocumentCount)
// Get the logarithmic value of count divided by total count
return math.Log(count / total)
return math.Log(countByLabel / total)
}

// getOrCreate vocabulary will get a vocabulary set for a given label,
// if the vocabulary doesn't exist - it is created
func (b *Bag) getOrCreateVocabulary(label string) (v Vocabulary) {
var ok bool
// Attempt to get vocabulary for the given label
v, ok = b.vocabByLabel[label]
// Check if vocabulary set does not exist for the provided label
if !ok {
Expand All @@ -119,9 +149,10 @@ func (b *Bag) getOrCreateVocabulary(label string) (v Vocabulary) {
return
}

// incrementCounts will increment trained documents count globally and by label
func (b *Bag) incrementCounts(label string) {
// Increment count of trained documents for the provided label
b.countByLabel[label]++
b.documentCountByLabel[label]++
// Increment total count of trained documents
b.totalCount++
b.totalDocumentCount++
}
69 changes: 57 additions & 12 deletions bag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import (
"log"
"os"
"testing"

"github.com/go-yaml/yaml"
)

var (
Expand All @@ -19,16 +17,8 @@ var (
)

func TestMain(m *testing.M) {
var (
f *os.File
err error
)
if f, err = os.Open("./examples/yes-no-training.yaml"); err != nil {
log.Fatal(err)
}
defer f.Close()

if err = yaml.NewDecoder(f).Decode(&testTrainingYesNo); err != nil {
var err error
if testTrainingYesNo, err = makeTrainingSetFromFile("./examples/yes-no-training.yaml"); err != nil {
log.Fatal(err)
}

Expand Down Expand Up @@ -112,6 +102,61 @@ func TestNewFromTrainingSet(t *testing.T) {
}
}

func TestNewFromTrainingSetFile(t *testing.T) {
type args struct {
filepath string
}

type teststruct struct {
name string
args args

wantConfig Config
wantErr bool
}

tests := []teststruct{
{
name: "basic",
args: args{
filepath: "./examples/yes-no-training.yaml",
},
wantConfig: Config{
NGramSize: 2,
NGramType: "character",
SmoothingParameter: 1,
},
wantErr: false,
},
{
name: "no file",
args: args{
filepath: "./examples/no_exists.yaml",
},
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
b, err := NewFromTrainingSetFile(tt.args.filepath)
if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
return
}

if err != nil {
return
}

if b.c != tt.wantConfig {
t.Errorf("New() bag.Config = %+v, wantConfig %+v", b.c, tt.wantConfig)
return
}
})
}
}

func TestBag_GetResults(t *testing.T) {
positiveNegative := SamplesByLabel{
"positive": {
Expand Down
35 changes: 25 additions & 10 deletions circularbuffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,43 +17,58 @@ type circularBuffer[T any] struct {
s []T
}

// Shift will add an item to the end of the circular buffer,
// if the buffer is full - it will pop an item from the front
func (c *circularBuffer[T]) Shift(item T) (popped T) {
// Get oldest buffer item as popped value
popped = c.s[c.end]
// Replace oldest position with new item
c.s[c.end] = item

c.end++
if c.len < c.cap {
c.len++
} else {
if c.start++; c.start >= c.cap {
c.start = 0
}

// Increment oldest position and check to see if the oldest position exceeds capacity
if c.end++; c.end >= c.cap {
// Oldest position exceeds capacity, set to 0
c.end = 0
}

if c.end >= c.cap {
c.end = 0
// Check to see if length is less than capaicity
if c.len < c.cap {
// Length is not at capacity, increment
c.len++
// Increment start value and check to see if new value exceeds capacity
} else if c.start++; c.start >= c.cap {
// New start value exceeds start capacity, set to 0
c.start = 0
}

return
}

// ForEach will iterate through the buffer items
func (c *circularBuffer[T]) ForEach(fn func(t T) (end bool)) (ended bool) {
// First index is at starting position
index := c.start
// Iterate for the length of the buffer
for i := 0; i < c.len; i++ {
// Get item at current index
item := c.s[index]
// Pass item to func
if fn(item) {
// Func returned break boolean as true, return true
return true
}

// Increment index and see if index exceeds length
if index++; index >= c.len {
// Index exceeds length, set to 0
index = 0
}
}

return
}

// Len will return the length of a circular buffer
func (c *circularBuffer[T]) Len() int {
return c.len
}
1 change: 1 addition & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type Config struct {
func (c *Config) Validate() (err error) {
c.fill()

// Check to see if n-gram type is supported
switch c.NGramType {
case "word":
case "character":
Expand Down
6 changes: 6 additions & 0 deletions results.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,20 @@ package bag

import "math"

// Results (by label) represents the probability of a processed input matching each of the possible labels (classifications)
type Results map[string]float64

func (r Results) GetHighestProbability() (match string) {
// Since probability values can be negative, initialize to negative infinity
max := math.Inf(-1)
// Iterate through probability results
for label, prob := range r {
// Check to see if the current probability is higher than the max
if prob > max {
// Current probability is higher
// Set max as the current probability
max = prob
// Set match as the current label
match = label
}
}
Expand Down
1 change: 1 addition & 0 deletions samples.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
package bag

// Samples represents a set of input samples to be used for model training
type Samples []string
1 change: 1 addition & 0 deletions samplesbylabel.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
package bag

// SamplesByLabel represents sets of samples keyed by label
type SamplesByLabel map[string]Samples
24 changes: 24 additions & 0 deletions trainingset.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,29 @@
package bag

import (
"fmt"
"os"

"github.com/go-yaml/yaml"
)

// makeTrainingSetFromFile will initialize a training set from a filepath
func makeTrainingSetFromFile(filepath string) (t TrainingSet, err error) {
var f *os.File
// Attempt to open file at given filepath
if f, err = os.Open(filepath); err != nil {
err = fmt.Errorf("error opening training set: %v", err)
return
}
// Close file when function exits
defer f.Close()

// Initialize new YAML decoder and decode file as a training set
err = yaml.NewDecoder(f).Decode(&t)
return
}

// TrainingSet is used to train a bag of words (BoW) model
type TrainingSet struct {
Config `yaml:"config"`

Expand Down
Loading

0 comments on commit 70b2112

Please sign in to comment.