Skip to content

Instantly share code, notes, and snippets.

@AJamesPhillips
Created July 5, 2016 21:45
Show Gist options
  • Save AJamesPhillips/07471da4b4be0190d8e34bf357c3c431 to your computer and use it in GitHub Desktop.
Save AJamesPhillips/07471da4b4be0190d8e34bf357c3c431 to your computer and use it in GitHub Desktop.
import AIToolbox
import XCTest
@testable import ios_hub
/// Trains AIToolbox's `LogisticRegression` (SGD solving method) on two fixed
/// two-input classification sets and checks the classes assigned to eight
/// probe points.
///
/// In both training sets the class label is 0 for the first half of the
/// points and 1 from the point whose second input is 50 onwards.
///
/// NOTE(review): the original author annotated the assertions in `testLogReg`
/// as failing intermittently ("Usually fails" for the small set, "Fails
/// randomly" for the large set). Presumably this comes from the SGD training
/// not being deterministic -- confirm against AIToolbox before relying on
/// this test as a gate.
class LogRegTests: XCTestCase {

    // MARK: - Training data

    /// Builds the 10-point training set.
    ///
    /// The second input runs 0, 10, ..., 90; points 0-4 are class 0 and
    /// points 5-9 are class 1.
    ///
    /// - Returns: The populated data set. If adding a point throws, the test
    ///   is marked as failed and the partially filled set is returned.
    func getSmallTestData() -> DataSet {
        let firstInputs: [Double] = [
            91.7022004703, 41.9194514403, 80.0744568676, 9.83468338331, 98.8861088906,
            1.93669578703, 10.2334428828, 90.3401915288, 28.3306091206, 61.4745972953
        ]
        return makeTrainingData(firstInputs, secondInputStep: 10, firstIndexOfClassOne: 5)
    }

    /// Builds the 101-point training set.
    ///
    /// The second input runs 0, 1, ..., 100; points 0-49 are class 0 and
    /// points 50-100 are class 1.
    ///
    /// - Returns: The populated data set. If adding a point throws, the test
    ///   is marked as failed and the partially filled set is returned.
    func getLargeTestData() -> DataSet {
        let firstInputs: [Double] = [
            // second input 0...9
            41.7022004703, 72.0324493442, 0.0114374817345, 30.2332572632, 14.6755890817,
            9.23385947688, 18.6260211378, 34.5560727043, 39.6767474231, 53.8816734003,
            // second input 10...19
            41.9194514403, 68.5219500397, 20.4452249732, 87.8117436391, 2.73875931979,
            67.0467510178, 41.7304802367, 55.8689828446, 14.0386938595, 19.8101489085,
            // second input 20...29
            80.0744568676, 96.8261575719, 31.3424178159, 69.2322615669, 87.6389152296,
            89.4606663504, 8.50442113698, 3.90547832329, 16.9830419565, 87.8142503429,
            // second input 30...39
            9.83468338331, 42.1107625005, 95.7889530151, 53.3165284973, 69.187711395,
            31.5515631006, 68.6500927682, 83.4625671897, 1.82882773442, 75.0144314945,
            // second input 40...49
            98.8861088906, 74.816565438, 28.0443992064, 78.9279328451, 10.3226006578,
            44.7893526176, 90.8595503093, 29.3614148374, 28.7775338586, 13.0028572118,
            // second input 50...59 (class 1 from here on)
            1.93669578703, 67.883553294, 21.1628116, 26.5546659372, 49.157315928,
            5.33625451171, 57.4117605492, 14.6728574906, 58.9305536903, 69.9758360021,
            // second input 60...69
            10.2334428828, 41.405598782, 69.4400157728, 41.4179269527, 4.99534589461,
            53.5896405916, 66.379464522, 51.4889112058, 94.4594755991, 58.6555040502,
            // second input 70...79
            90.3401915288, 13.7474704146, 13.9276347251, 80.739128871, 39.7676836986,
            16.5354197117, 92.7508580396, 34.7765859746, 75.0812103136, 72.599798535,
            // second input 80...89
            88.3306091206, 62.3672207056, 75.0942434027, 34.8898341978, 26.9927891765,
            89.5886218196, 42.8091189871, 96.4840047148, 66.3441497818, 62.1695720209,
            // second input 90...99
            11.4745972953, 94.9489258707, 44.991213348, 57.8389614387, 40.8136802761,
            23.7026980243, 90.3379520562, 57.3679486672, 0.287032703116, 61.7144913621,
            // second input 100 (same first input as the previous point)
            61.7144913621
        ]
        return makeTrainingData(firstInputs, secondInputStep: 1, firstIndexOfClassOne: 50)
    }

    /// Creates a two-input classification `DataSet` from a list of first-input
    /// values.
    ///
    /// - Parameters:
    ///   - firstInputs: First input of each point, in insertion order.
    ///   - secondInputStep: The second input of point `i` is `i * secondInputStep`.
    ///   - firstIndexOfClassOne: Points before this index get class 0, the
    ///     rest get class 1.
    /// - Returns: The data set; on a thrown error the test is failed and the
    ///   points added so far are returned.
    private func makeTrainingData(_ firstInputs: [Double], secondInputStep: Int, firstIndexOfClassOne: Int) -> DataSet {
        let data = DataSet(dataType: .Classification, inputDimension: 2, outputDimension: 1)
        do {
            for index in 0..<firstInputs.count {
                let secondInput = Double(index * secondInputStep)
                let label = index < firstIndexOfClassOne ? 0 : 1
                try data.addDataPoint(input: [firstInputs[index], secondInput], output: label)
            }
        }
        catch {
            // Fail loudly: a silently truncated training set would make the
            // later assertions meaningless.
            XCTFail("Invalid data set created: \(error)")
        }
        return data
    }

    // MARK: - Training and classification

    /// Creates a two-input `LogisticRegression` using the `.SGD` solving
    /// method and trains it on `data`.
    ///
    /// - Parameter data: Training set passed to `trainClassifier`.
    /// - Returns: The regression object. If training throws, the test is
    ///   marked as failed and the object is returned as it stands.
    func createAndTrainLogisticRegression(data: DataSet) -> LogisticRegression {
        let lr = LogisticRegression(numInputs : 2, solvingMethod: .SGD)
        do {
            try lr.trainClassifier(data)
        }
        catch {
            XCTFail("Error training logistic regression: \(error)")
        }
        return lr
    }

    /// Builds the eight-point probe set and asks `lr` to classify it.
    ///
    /// The probes alternate the first input between 100 and 0 while the
    /// second input runs 10, 20, ..., 80.
    ///
    /// - Parameter lr: Regression object whose `classify` is called on the
    ///   probe set.
    /// - Returns: The probe data set after the `classify` call. Errors while
    ///   adding points or classifying fail the test instead of being ignored.
    func classifyTestData(lr: LogisticRegression) -> DataSet {
        let probes: [[Double]] = [
            [100.0, 10.0],
            [  0.0, 20.0],
            [100.0, 30.0],
            [  0.0, 40.0],
            [100.0, 50.0],
            [  0.0, 60.0],
            [100.0, 70.0],
            [  0.0, 80.0]
        ]

        // Create the test set
        let test = DataSet(dataType: .Classification, inputDimension: 2, outputDimension: 1)
        do {
            for probe in probes {
                try test.addTestDataPoint(input: probe)
            }
        }
        catch {
            XCTFail("Invalid test sequence data set created: \(error)")
        }

        // Classify the set (still attempted after a setup failure, as before)
        do {
            try lr.classify(test)
        }
        catch {
            XCTFail("Error having logistic regression classify: \(error)")
        }
        return test
    }

    // MARK: - Tests

    /// Trains on the small and then the large data set, checking after each
    /// that probes with second input below 50 are class 0 and the rest are
    /// class 1.
    func testLogReg() {
        let dataSmall = getSmallTestData()
        let lrSmall = createAndTrainLogisticRegression(dataSmall)
        let testSmall = classifyTestData(lrSmall)
        verifyProbeClasses(testSmall, messagePrefix: "logistic regression test")

        let dataLarge = getLargeTestData()
        let lrLarge = createAndTrainLogisticRegression(dataLarge)
        let testLarge = classifyTestData(lrLarge)
        verifyProbeClasses(testLarge, messagePrefix: "logistic regression test large")
    }

    /// Asserts the class assigned to each of the eight probe points.
    ///
    /// - Parameters:
    ///   - classified: Probe set returned by `classifyTestData`.
    ///   - messagePrefix: Text placed before the probe index in each
    ///     assertion message.
    private func verifyProbeClasses(_ classified: DataSet, messagePrefix: String) {
        // Probes 0-3 have second input 10...40, probes 4-7 have 50...80.
        let expectedClasses = [0, 0, 0, 0, 1, 1, 1, 1]
        do {
            for index in 0..<expectedClasses.count {
                let result = try classified.getClass(index)
                XCTAssertEqual(result, expectedClasses[index], "\(messagePrefix) \(index)")
            }
        }
        catch {
            // Previously only printed, which let the test pass with the
            // remaining assertions skipped.
            XCTFail("Error getting test results: \(error)")
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment