Created
August 24, 2018 21:44
-
-
Save kmaher9/32cdbd80c05bea9e7fb6b87e69d1b877 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"fmt" | |
"io/ioutil" | |
"github.com/jbrukh/bayesian" | |
) | |
// exported: the classes that are used to store the learned data. | |
const ( | |
Business bayesian.Class = "Business" // represents all learned business documents. | |
Tech bayesian.Class = "Tech" // represents all learned tech documents. | |
businessDirectoryLocation = "bbc/business" | |
techDirectoryLocation = "bbc/tech" | |
testFileLocation = "bbc/test.txt" | |
) | |
var businessFiles []string // holds the extracted contents of all business files, for use in training. | |
var techFiles []string // holds the extracted contents of all tech files, for use in training. | |
func main() { | |
classifier := bayesian.NewClassifier(Business, Tech) | |
enumerateClasses() | |
learn(classifier) | |
predict(testFileLocation, classifier) | |
} | |
func predict(location string, classifier *bayesian.Classifier) { | |
probabilities, _, _ := classifier.ProbScores([]string{readFile(location)}) | |
fmt.Println(probabilities) | |
} | |
func learn(classifier *bayesian.Classifier) { | |
classifier.Learn(businessFiles, Business) | |
classifier.Learn(techFiles, Tech) | |
} | |
func enumerateClasses() { | |
businessDirectory := enumerateDirectory(businessDirectoryLocation) // retrieves a list of all filenames stored in the directory. | |
for _, file := range businessDirectory { | |
fileContent := readFile(file) // returns the entire contents of the files as a string. | |
businessFiles = append(businessFiles, fileContent) | |
} | |
techDirectory := enumerateDirectory(techDirectoryLocation) | |
for _, file := range techDirectory { | |
fileContent := readFile(file) | |
techFiles = append(techFiles, fileContent) | |
} | |
} | |
func enumerateDirectory(location string) []string { | |
var files []string | |
listing, err := ioutil.ReadDir(location) | |
if err != nil { | |
panic(err) | |
} | |
for _, l := range listing { | |
files = append(files, location+"/"+l.Name()) | |
} | |
return files | |
} | |
func readFile(location string) string { | |
file, err := ioutil.ReadFile(location) | |
if err != nil { | |
panic(err) | |
} | |
return string(file) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment