From 749c62125450ef9920e1d9ff3f2c952fc01f48d4 Mon Sep 17 00:00:00 2001 From: "Eric Borts (Lambda)" Date: Thu, 3 Apr 2014 22:48:14 -0600 Subject: [PATCH] Added test_predictOneVsAll --- ex3/test_ex3.m | 2 ++ ex3/test_predictOneVsAll.m | 24 ++++++++++++++++++++++++ 2 files changed, 26 insertions(+) create mode 100644 ex3/test_predictOneVsAll.m diff --git a/ex3/test_ex3.m b/ex3/test_ex3.m index 2f5f082..fe0b975 100644 --- a/ex3/test_ex3.m +++ b/ex3/test_ex3.m @@ -1,3 +1,5 @@ %!test test_sanity() %!test test_oneVsAll() + +%!test test_predictOneVsAll() \ No newline at end of file diff --git a/ex3/test_predictOneVsAll.m b/ex3/test_predictOneVsAll.m new file mode 100644 index 0000000..f0b3151 --- /dev/null +++ b/ex3/test_predictOneVsAll.m @@ -0,0 +1,24 @@ +function test_predictOneVsAll () + epsilon = 1e-3; + + % learning three classes + % + % x < 1.5 => 1 + % 1.5 < x < 3.5 => 2 + % 3.5 < x => 3 + % + X = [0 1 2 3 4 5]'; + y = [1 1 2 2 3 3]'; % direct classification + num_labels = 3; + + all_theta = oneVsAll(X, y, num_labels, 0); + assert(size(all_theta), [3 2]); + + yy = predictOneVsAll(all_theta, X); + assert(yy, y); + + % predict ones we haven't seen yet (should match the original formula) + X = [-10 2.5 10.0]'; + assert(predictOneVsAll(all_theta, X), [1 2 3]'); + +endfunction