Skip to content

Instantly share code, notes, and snippets.

@kmaher9
Created August 24, 2018 21:44
Show Gist options
  • Save kmaher9/32cdbd80c05bea9e7fb6b87e69d1b877 to your computer and use it in GitHub Desktop.
Save kmaher9/32cdbd80c05bea9e7fb6b87e69d1b877 to your computer and use it in GitHub Desktop.
package main
import (
	"fmt"
	"io/ioutil"
	"path/filepath"
	"strings"
	"unicode"

	"github.com/jbrukh/bayesian"
)
// Class labels used to tag training documents and to score new ones.
const (
	// Business labels every document loaded from businessDirectoryLocation.
	Business bayesian.Class = "Business"
	// Tech labels every document loaded from techDirectoryLocation.
	Tech bayesian.Class = "Tech"
)

// Corpus locations. These are relative paths, so they resolve against the
// process's current working directory.
const (
	businessDirectoryLocation = "bbc/business"
	techDirectoryLocation     = "bbc/tech"
	testFileLocation          = "bbc/test.txt"
)
// Raw training corpora: one string per file holding that file's full
// contents. enumerateClasses fills them and learn reads them.
var (
	businessFiles []string // contents of each file under businessDirectoryLocation
	techFiles     []string // contents of each file under techDirectoryLocation
)
func main() {
classifier := bayesian.NewClassifier(Business, Tech)
enumerateClasses()
learn(classifier)
predict(testFileLocation, classifier)
}
// predict scores the document stored at location against every class known
// to classifier and prints the resulting probability slice. The slice
// presumably holds one entry per class, in the order passed to
// bayesian.NewClassifier — confirm against the bayesian package docs.
// It panics if the file cannot be read.
func predict(location string, classifier *bayesian.Classifier) {
	// ProbScores works on a list of words. The document must be tokenized the
	// same way learn tokenizes training data, or nothing will match.
	// Only the per-class scores are printed; the other two results are unused.
	probabilities, _, _ := classifier.ProbScores(tokenize(readFile(location)))
	fmt.Println(probabilities)
}

// learn trains classifier on every loaded business and tech document. Each
// document is tokenized and learned on its own, so the classifier records
// word frequencies instead of treating whole files as single tokens.
func learn(classifier *bayesian.Classifier) {
	for _, doc := range businessFiles {
		classifier.Learn(tokenize(doc), Business)
	}
	for _, doc := range techFiles {
		classifier.Learn(tokenize(doc), Tech)
	}
}

// tokenize lower-cases text and splits it into words. Any rune that is not a
// letter or a number is a separator and is dropped, so "Market," and
// "market" produce the same token.
func tokenize(text string) []string {
	return strings.FieldsFunc(strings.ToLower(text), func(r rune) bool {
		return !unicode.IsLetter(r) && !unicode.IsNumber(r)
	})
}
func enumerateClasses() {
businessDirectory := enumerateDirectory(businessDirectoryLocation) // retrieves a list of all filenames stored in the directory.
for _, file := range businessDirectory {
fileContent := readFile(file) // returns the entire contents of the files as a string.
businessFiles = append(businessFiles, fileContent)
}
techDirectory := enumerateDirectory(techDirectoryLocation)
for _, file := range techDirectory {
fileContent := readFile(file)
techFiles = append(techFiles, fileContent)
}
}
// enumerateDirectory returns the paths of the non-directory entries directly
// inside location, in the order ioutil.ReadDir yields them (sorted by file
// name). Subdirectories are skipped because their contents cannot be read as
// a document. It panics with ReadDir's error if location cannot be listed.
func enumerateDirectory(location string) []string {
	listing, err := ioutil.ReadDir(location)
	if err != nil {
		panic(err)
	}
	files := make([]string, 0, len(listing))
	for _, entry := range listing {
		if entry.IsDir() {
			// Reading a directory with ioutil.ReadFile fails, which would
			// abort the whole run in readFile.
			continue
		}
		files = append(files, filepath.Join(location, entry.Name()))
	}
	return files
}
// readFile returns the whole contents of the file at location as a string.
// It panics with the underlying read error if the file cannot be read.
func readFile(location string) string {
	raw, readErr := ioutil.ReadFile(location)
	if readErr == nil {
		return string(raw)
	}
	panic(readErr)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment