Skip to content

Commit

Permalink
improve commenting throughout
Browse files Browse the repository at this point in the history
  • Loading branch information
itsmontoya committed Jul 18, 2024
1 parent 8fde6cd commit bd4f5db
Show file tree
Hide file tree
Showing 9 changed files with 69 additions and 10 deletions.
7 changes: 7 additions & 0 deletions bag.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,15 @@ func New(c Config) (out *Bag, err error) {
// NewFromTrainingSet will initialize and return a new pre-trained Bag from a provided training set
func NewFromTrainingSet(t TrainingSet) (b *Bag, err error) {
if b, err = New(t.Config); err != nil {
// Error initializing, return
return
}

// Train with provided samples, iterate over samples by label
for label, samples := range t.Samples {
// For each within samples slice
for _, sample := range samples {
// Train for a given sample and label
b.Train(sample, label)
}
}
Expand All @@ -35,10 +39,13 @@ func NewFromTrainingSet(t TrainingSet) (b *Bag, err error) {
// NewFromTrainingSetFile loads the training set stored at filepath and returns a new Bag pre-trained from it.
// When the training set cannot be loaded, a nil Bag is returned together with the load error.
func NewFromTrainingSetFile(filepath string) (b *Bag, err error) {
	// Load the training set from the file; bail out with a nil Bag on failure
	set, err := makeTrainingSetFromFile(filepath)
	if err != nil {
		return nil, err
	}

	// Hand off to the training set constructor, which initializes and trains the Bag
	return NewFromTrainingSet(set)
}

Expand Down
35 changes: 25 additions & 10 deletions circularbuffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,43 +17,58 @@ type circularBuffer[T any] struct {
s []T
}

// Shift will add an item to the end of the circular buffer,
// if the buffer is full - it will pop an item from the front
func (c *circularBuffer[T]) Shift(item T) (popped T) {
// Get oldest buffer item as popped value
popped = c.s[c.end]
// Replace oldest position with new item
c.s[c.end] = item

c.end++
if c.len < c.cap {
c.len++
} else {
if c.start++; c.start >= c.cap {
c.start = 0
}

// Increment oldest position and check to see if the oldest position exceeds capacity
if c.end++; c.end >= c.cap {
// Oldest position exceeds capacity, set to 0
c.end = 0
}

if c.end >= c.cap {
c.end = 0
// Check to see if length is less than capacity
if c.len < c.cap {
// Length is not at capacity, increment
c.len++
// Increment start value and check to see if new value exceeds capacity
} else if c.start++; c.start >= c.cap {
// New start value exceeds start capacity, set to 0
c.start = 0
}

return
}

// ForEach calls fn once per buffered item, beginning at the start position and
// visiting as many items as the buffer currently holds.
// Iteration stops early as soon as fn returns true, in which case ForEach also
// returns true; if every item is visited without fn returning true, it returns false.
func (c *circularBuffer[T]) ForEach(fn func(t T) (end bool)) (ended bool) {
	// Cursor into the backing slice, beginning at the start position
	pos := c.start
	for visited := 0; visited < c.len; visited++ {
		// Offer the item under the cursor to the callback
		if fn(c.s[pos]) {
			// Callback asked to stop, report the early exit
			return true
		}

		// Advance the cursor, wrapping back to the front once it reaches the length
		pos++
		if pos >= c.len {
			pos = 0
		}
	}

	// Every item was visited without the callback asking to stop
	return false
}

// Len reports how many items the circular buffer currently holds
func (c *circularBuffer[T]) Len() int {
	return c.len
}
1 change: 1 addition & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type Config struct {
func (c *Config) Validate() (err error) {
c.fill()

// Check to see if n-gram type is supported
switch c.NGramType {
case "word":
case "character":
Expand Down
6 changes: 6 additions & 0 deletions results.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,20 @@ package bag

import "math"

// Results maps each possible label (classification) to the probability of a processed input matching that label.
// NOTE(review): GetHighestProbability treats values as possibly negative, so these are presumably
// log-scale scores rather than values in [0, 1] - confirm against the code that produces them.
type Results map[string]float64

func (r Results) GetHighestProbability() (match string) {
// Since probability values can be negative, initialize to negative infinity
max := math.Inf(-1)
// Iterate through probability results
for label, prob := range r {
// Check to see if the current probability is higher than the max
if prob > max {
// Current probability is higher
// Set max as the current probability
max = prob
// Set match as the current label
match = label
}
}
Expand Down
1 change: 1 addition & 0 deletions samples.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
package bag

// Samples represents a set of input samples (one string per sample) to be used for model training.
type Samples []string
1 change: 1 addition & 0 deletions samplesbylabel.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
package bag

// SamplesByLabel represents sets of samples keyed by label (label -> samples belonging to that label).
type SamplesByLabel map[string]Samples
5 changes: 5 additions & 0 deletions trainingset.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,23 @@ import (
"github.com/go-yaml/yaml"
)

// makeTrainingSetFromFile opens the file at filepath and decodes its YAML contents into a TrainingSet.
// A failure to open the file is wrapped with %w, so callers can still match the underlying
// cause (for example errors.Is(err, os.ErrNotExist)); a decode failure is returned as-is.
func makeTrainingSetFromFile(filepath string) (t TrainingSet, err error) {
	var f *os.File
	// Attempt to open file at given filepath
	if f, err = os.Open(filepath); err != nil {
		// Wrap instead of flattening to text so the original error stays in the chain
		err = fmt.Errorf("error opening training set: %w", err)
		return
	}
	// Close file when function exits; the file is only read, so the Close error is not acted upon
	defer f.Close()

	// Initialize new YAML decoder and decode file as a training set
	err = yaml.NewDecoder(f).Decode(&t)
	return
}

// TrainingSet is used to train a bag of words (BoW) model
type TrainingSet struct {
Config `yaml:"config"`

Expand Down
22 changes: 22 additions & 0 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,41 +5,63 @@ import (
"unicode"
)

// toWords will split a string into words, calling onWord once for every word found
//   - whitespace separates words and is omitted (runs of whitespace never produce empty words)
//   - punctuation (and any other character that is neither a letter nor whitespace) is omitted
//   - all letters repeating more than 2 times in a row within a word will be truncated to 2
//   - all letters are lowercased
func toWords(in string, onWord func(string)) {
	// Buffer is used to write characters while building words
	var buf bytes.Buffer
	// last is the most recently written letter and run is how many times in a row it
	// has been written. Both are cleared at every word boundary so that repetition at
	// the end of one word can never suppress the leading letters of the next word.
	var (
		last rune
		run  int
	)

	// flush emits the pending word (if there is one) and clears all per-word state
	flush := func() {
		if buf.Len() > 0 {
			onWord(buf.String())
			buf.Reset()
		}

		last, run = 0, 0
	}

	for _, char := range in {
		switch {
		case unicode.IsLetter(char):
			char = unicode.ToLower(char)
			switch {
			case char != last:
				// Letter changed, begin a new run
				last, run = char, 1
			case run < 2:
				// Second occurrence in a row is still allowed
				run++
			default:
				// Third (or later) occurrence in a row, drop it
				continue
			}

			// Write character to buffer
			buf.WriteRune(char)

		case unicode.IsSpace(char):
			// Space encountered, emit the pending word (if any)
			flush()
		}
	}

	// Emit whatever is left in the buffer once the input is exhausted
	flush()
}

// toCharacters will split a string into characters
// - whitespace is included
// - punctuation is omitted
// - all characters are lowercased
func toCharacters(in string, onChar func(rune)) {
// Iterate through all input string runes
for _, char := range in {
switch {
case unicode.IsLetter(char):
// Lowercase character
char = unicode.ToLower(char)
// Call onChar with characters
onChar(char)
case unicode.IsSpace(char):
// Call onChar with whitespace
onChar(char)
}
}
Expand Down
1 change: 1 addition & 0 deletions vocabulary.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
package bag

// Vocabulary represents the number of encounters for each n-gram, keyed by the n-gram in string form.
type Vocabulary map[string]int

0 comments on commit bd4f5db

Please sign in to comment.