diff --git a/bag.go b/bag.go index a4d1aa0..cc408a3 100644 --- a/bag.go +++ b/bag.go @@ -54,13 +54,6 @@ func (b *Bag) GetResults(in string) (r Results) { return } -func (b *Bag) toNGrams(in string) (ns []string) { - if b.c.NGramType == "word" { - return toNGrams(in, b.c.NGramSize) - } - - return tocharacterNGrams(in, b.c.NGramSize) -} func (b *Bag) Train(in, label string) { // Convert inbound data to a slice of NGrams @@ -73,16 +66,23 @@ func (b *Bag) Train(in, label string) { v[n]++ } - // Increment count of trained documents for the provided label - b.countByLabel[label]++ - // Increment total count of trained documents - b.totalCount++ + // Increment model counters + b.incrementCounts(label) +} + +// toNGrams converts the inbound string into n-grams based on the configuration settings +func (b *Bag) toNGrams(in string) (ns []string) { + if b.c.NGramType == "word" { + return toNGrams(in, b.c.NGramSize) + } + + return tocharacterNGrams(in, b.c.NGramSize) } // getProbability uses a Naive Bayes classifier to determine probability for a given label func (b *Bag) getProbability(ns []string, label string, vocab Vocabulary) (probability float64) { // Set initial probability value as the prior probability value - probability = b.getPriorProbability(label) + probability = b.getLogPriorProbability(label) // Get the current counts by label (to be used by Laplace smoothing during for-loop) countsByLabel := float64(b.countByLabel[label]) + b.c.SmoothingParameter*float64(len(vocab)) @@ -98,7 +98,7 @@ func (b *Bag) getProbability(ns []string, label string, vocab Vocabulary) (proba return } -func (b *Bag) getPriorProbability(label string) (probability float64) { +func (b *Bag) getLogPriorProbability(label string) (probability float64) { count := float64(b.countByLabel[label]) total := float64(b.totalCount) // Get the logarithmic value of count divided by total count @@ -118,3 +118,10 @@ func (b *Bag) getOrCreateVocabulary(label string) (v Vocabulary) { return } + +func (b *Bag) incrementCounts(label string) { + // Increment count of trained documents for the provided label + b.countByLabel[label]++ + // Increment total count of trained documents + b.totalCount++ +} diff --git a/ngram.go b/ngram.go index c4dffb2..a3e0dc9 100644 --- a/ngram.go +++ b/ngram.go @@ -2,57 +2,57 @@ package bag import "bytes" -// toNGrams will convert inbound data to an NGram of provided size +// toNGrams will convert inbound data to an nGram of provided size func toNGrams(in string, size int) (ns []string) { - // Initialize NGram with a provided size - n := make(NGram, size) + // Initialize nGram with a provided size + n := make(nGram, size) // Iterate inbound data as words toWords(in, func(word string) { - // Append word to NGram + // Append word to nGram n = n.Append(word) if !n.IsFull() { // NGram is not full - we do not want to append yet, return return } - // Append current NGram to NGrams slice + // Append current nGram to nGrams slice ns = append(ns, n.String()) }) if !n.IsFull() && !n.IsZero() { - // The NGram is not full, so we haven't appended yet - // The NGram is not empty, so we have something to append - // Append current NGram to NGrams slice + // The nGram is not full, so we haven't appended yet + // The nGram is not empty, so we have something to append + // Append current nGram to nGrams slice ns = append(ns, n.String()) } return } -// NGram represents an NGram (variable sized) -type NGram []string +// nGram represents an N-Gram (variable sized) +type nGram []string -// Append will append a given string to an NGram and output the new value -// Note: The original NGram is NOT modified -func (n NGram) Append(str string) (out NGram) { - // Initialize new NGram with the same size as the original NGram - out = make(NGram, len(n)) - // Iterate through original NGram, starting at index 1 +// Append will append a given string to an nGram and output the new value +// Note: The original nGram is NOT modified +func (n nGram) Append(str string) (out nGram) { + // Initialize new nGram with the same size as the original nGram + out = make(nGram, len(n)) + // Iterate through original nGram, starting at index 1 for i := 1; i < len(n); i++ { - // Set the value of the current original NGram index as the value for the previous index for the output NGram + // Set the value of the current original nGram index as the value for the previous index for the output nGram out[i-1] = n[i] } - // Set the last value of the output NGram as the input string + // Set the last value of the output nGram as the input string out[len(n)-1] = str return } -// String will convert the NGram contents to a string -func (n NGram) String() (out string) { +// String will convert the nGram contents to a string +func (n nGram) String() (out string) { // Initialize buffer buf := bytes.NewBuffer(nil) - // Iterate through NGram values + // Iterate through nGram values n.iterate(func(value string) { if buf.Len() > 0 { // Buffer is not empty, prefix the iterating value with a space @@ -67,21 +67,21 @@ func (n NGram) String() (out string) { return buf.String() } -// IsZero returns whether or not the NGram is empty -func (n NGram) IsZero() bool { +// IsZero returns whether or not the nGram is empty +func (n nGram) IsZero() bool { // Return result of if the value in the last position is empty return len(n[len(n)-1]) == 0 } -// IsFull returns whether or not the NGram is full -func (n NGram) IsFull() bool { +// IsFull returns whether or not the nGram is full +func (n nGram) IsFull() bool { // Return result of if the value in the first position is populated return len(n[0]) > 0 } -// iterate will iterate through the NGram values -func (n NGram) iterate(fn func(word string)) { - // Iterate through NGram values +// iterate will iterate through the nGram values +func (n nGram) iterate(fn func(word string)) { + // Iterate through nGram values for _, str := range n { // Check if value is empty if len(str) == 0 { diff --git a/sample.go b/sample.go deleted file mode 100644 index 731894d..0000000 --- a/sample.go +++ /dev/null @@ -1,6 +0,0 @@ -package bag - -type Sample struct { - Input string `toml:"input"` - Label string `toml:"label"` -} diff --git a/samples.go b/samples.go new file mode 100644 index 0000000..2951a88 --- /dev/null +++ b/samples.go @@ -0,0 +1,3 @@ +package bag + +type Samples []string diff --git a/samplesbylabel.go b/samplesbylabel.go new file mode 100644 index 0000000..8ecd9dd --- /dev/null +++ b/samplesbylabel.go @@ -0,0 +1,3 @@ +package bag + +type SamplesByLabel map[string]Samples diff --git a/trainingset.go b/trainingset.go index c845dbe..d44f20a 100644 --- a/trainingset.go +++ b/trainingset.go @@ -5,7 +5,3 @@ type TrainingSet struct { Samples SamplesByLabel `yaml:"samples"` } - -type SamplesByLabel map[string]Samples - -type Samples []string