Created
July 5, 2016 21:45
-
-
Save AJamesPhillips/07471da4b4be0190d8e34bf357c3c431 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import AIToolbox | |
import XCTest | |
@testable import ios_hub | |
/// Tests for AIToolbox's `LogisticRegression` classifier (SGD solver).
///
/// Each scenario trains a two-input logistic regression on a data set in which
/// the class is determined by the second feature alone: x2 < 50 is class 0,
/// otherwise class 1. The first feature varies from point to point.
/// Eight probe points straddling that boundary are then classified.
///
/// Any error thrown while building data, training, classifying or reading
/// results is reported with `XCTFail`. Previously such errors were only
/// printed, which let the test pass.
class LogRegTests: XCTestCase {

    /// Builds a two-input classification `DataSet`.
    ///
    /// Point `i` gets the input `[firstFeatures[i], Double(i) * secondFeatureStep]`.
    /// It is labelled 0 when that second feature is below 50.0, else 1.
    ///
    /// - Parameters:
    ///   - firstFeatures: values for the first input, one per data point.
    ///   - secondFeatureStep: spacing between consecutive second-feature values.
    /// - Returns: the populated data set. If a point is rejected, the test is
    ///   failed and the points added up to that point are returned.
    private func makeDataSet(firstFeatures: [Double], secondFeatureStep: Double) -> DataSet {
        let data = DataSet(dataType: .Classification, inputDimension: 2, outputDimension: 1)
        do {
            for index in 0..<firstFeatures.count {
                // Double(index) * step is exact for these small integers, so the
                // second feature takes exactly the values 0.0, step, 2*step, ...
                let secondFeature = Double(index) * secondFeatureStep
                try data.addDataPoint(input: [firstFeatures[index], secondFeature],
                                      output: secondFeature < 50.0 ? 0 : 1)
            }
        }
        catch {
            XCTFail("Invalid data set created: \(error)")
        }
        return data
    }

    /// Ten-point training set: x2 = 0, 10, ..., 90, so the classes are 0 x5 then 1 x5.
    func getSmallTestData() -> DataSet {
        // NOTE(review): these are every tenth point of getLargeTestData().
        // The exceptions are x2 = 0, 80 and 90, whose first features differ
        // (91.70 vs 41.70, 28.33 vs 88.33, 61.47 vs 11.47). Confirm this is intended.
        let firstFeatures: [Double] = [
            91.7022004703, 41.9194514403, 80.0744568676, 9.83468338331, 98.8861088906,
            1.93669578703, 10.2334428828, 90.3401915288, 28.3306091206, 61.4745972953,
        ]
        return makeDataSet(firstFeatures, secondFeatureStep: 10.0)
    }

    /// 101-point training set: x2 = 0, 1, ..., 100, so x2 0-49 is class 0 and x2 50-100 is class 1.
    func getLargeTestData() -> DataSet {
        let firstFeatures: [Double] = [
            // x2 = 0-9
            41.7022004703, 72.0324493442, 0.0114374817345, 30.2332572632, 14.6755890817,
            9.23385947688, 18.6260211378, 34.5560727043, 39.6767474231, 53.8816734003,
            // x2 = 10-19
            41.9194514403, 68.5219500397, 20.4452249732, 87.8117436391, 2.73875931979,
            67.0467510178, 41.7304802367, 55.8689828446, 14.0386938595, 19.8101489085,
            // x2 = 20-29
            80.0744568676, 96.8261575719, 31.3424178159, 69.2322615669, 87.6389152296,
            89.4606663504, 8.50442113698, 3.90547832329, 16.9830419565, 87.8142503429,
            // x2 = 30-39
            9.83468338331, 42.1107625005, 95.7889530151, 53.3165284973, 69.187711395,
            31.5515631006, 68.6500927682, 83.4625671897, 1.82882773442, 75.0144314945,
            // x2 = 40-49
            98.8861088906, 74.816565438, 28.0443992064, 78.9279328451, 10.3226006578,
            44.7893526176, 90.8595503093, 29.3614148374, 28.7775338586, 13.0028572118,
            // x2 = 50-59 (class 1 from here on)
            1.93669578703, 67.883553294, 21.1628116, 26.5546659372, 49.157315928,
            5.33625451171, 57.4117605492, 14.6728574906, 58.9305536903, 69.9758360021,
            // x2 = 60-69
            10.2334428828, 41.405598782, 69.4400157728, 41.4179269527, 4.99534589461,
            53.5896405916, 66.379464522, 51.4889112058, 94.4594755991, 58.6555040502,
            // x2 = 70-79
            90.3401915288, 13.7474704146, 13.9276347251, 80.739128871, 39.7676836986,
            16.5354197117, 92.7508580396, 34.7765859746, 75.0812103136, 72.599798535,
            // x2 = 80-89
            88.3306091206, 62.3672207056, 75.0942434027, 34.8898341978, 26.9927891765,
            89.5886218196, 42.8091189871, 96.4840047148, 66.3441497818, 62.1695720209,
            // x2 = 90-99
            11.4745972953, 94.9489258707, 44.991213348, 57.8389614387, 40.8136802761,
            23.7026980243, 90.3379520562, 57.3679486672, 0.287032703116, 61.7144913621,
            // x2 = 100. NOTE(review): repeats the x2 = 99 value; possibly a copy/paste slip.
            61.7144913621,
        ]
        return makeDataSet(firstFeatures, secondFeatureStep: 1.0)
    }

    /// Creates a 2-input `LogisticRegression` using the `.SGD` solving method
    /// and trains it on `data`. A training error fails the test; the model is
    /// still returned so later steps can run.
    func createAndTrainLogisticRegression(data: DataSet) -> LogisticRegression {
        let lr = LogisticRegression(numInputs : 2, solvingMethod: .SGD)
        do {
            try lr.trainClassifier(data)
        }
        catch {
            XCTFail("Error training logistic regression: \(error)")
        }
        return lr
    }

    /// Builds the eight probe points and classifies them with `lr`.
    ///
    /// The probes alternate x1 between 100 and 0, with x2 = 10, 20, ..., 80.
    /// Errors while building the probe set or classifying fail the test.
    ///
    /// - Returns: the probe data set after `lr.classify` has run on it.
    func classifyTestData(lr: LogisticRegression) -> DataSet {
        let probeInputs: [[Double]] = [
            [100.0, 10.0], [0.0, 20.0], [100.0, 30.0], [0.0, 40.0],
            [100.0, 50.0], [0.0, 60.0], [100.0, 70.0], [0.0, 80.0],
        ]
        let test = DataSet(dataType: .Classification, inputDimension: 2, outputDimension: 1)
        do {
            for input in probeInputs {
                try test.addTestDataPoint(input: input)
            }
        }
        catch {
            XCTFail("Invalid test sequence data set created: \(error)")
        }
        // Classify the set
        do {
            try lr.classify(test)
        }
        catch {
            XCTFail("Error having logistic regression classify: \(error)")
        }
        return test
    }

    /// Asserts that `test.getClass(i) == expectedClasses[i]` for every index.
    ///
    /// Each index is checked independently, so one throwing lookup does not
    /// hide the results of the others. The assertion message is
    /// "`messagePrefix` `i`".
    private func verifyClasses(test: DataSet, expectedClasses: [Int], messagePrefix: String) {
        for index in 0..<expectedClasses.count {
            do {
                let result = try test.getClass(index)
                XCTAssertEqual(result, expectedClasses[index], "\(messagePrefix) \(index)")
            }
            catch {
                XCTFail("Error getting test results: \(error)")
            }
        }
    }

    /// Trains on the small and the large data set, then checks both models.
    ///
    /// Each model must put the probes with x2 < 50 in class 0 and the rest
    /// in class 1.
    ///
    /// NOTE(review): the original per-assert notes recorded these outcomes:
    /// - small set: probes 0-3 "usually fail" and probes 4-7 "may fail";
    /// - large set: every probe "fails randomly".
    /// The run-to-run variation presumably comes from SGD's random starting
    /// weights (TODO confirm). The unscaled 0-100 feature ranges may also
    /// matter.
    func testLogReg() {
        let expectedClasses = [0, 0, 0, 0, 1, 1, 1, 1]

        let dataSmall = getSmallTestData()
        let lrSmall = createAndTrainLogisticRegression(dataSmall)
        let testSmall = classifyTestData(lrSmall)
        verifyClasses(testSmall, expectedClasses: expectedClasses,
                      messagePrefix: "logistic regression test")

        let dataLarge = getLargeTestData()
        let lrLarge = createAndTrainLogisticRegression(dataLarge)
        let testLarge = classifyTestData(lrLarge)
        verifyClasses(testLarge, expectedClasses: expectedClasses,
                      messagePrefix: "logistic regression test large")
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment