From dd647e3338c718e40c01e76231d33f482fc8dfe4 Mon Sep 17 00:00:00 2001 From: vasnake Date: Fri, 4 Apr 2014 23:19:03 +0400 Subject: [PATCH] added: predictOneVsAll tests --- ex3/test_predictOneVsAll.m | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/ex3/test_predictOneVsAll.m b/ex3/test_predictOneVsAll.m index f0b3151..226e647 100644 --- a/ex3/test_predictOneVsAll.m +++ b/ex3/test_predictOneVsAll.m @@ -1,6 +1,6 @@ function test_predictOneVsAll () epsilon = 1e-3; - + % learning three classes % % x < 1.5 => 1 @@ -10,15 +10,29 @@ function test_predictOneVsAll () X = [0 1 2 3 4 5]'; y = [1 1 2 2 3 3]'; % direct classification num_labels = 3; - + all_theta = oneVsAll(X, y, num_labels, 0); assert(size(all_theta), [3 2]); - + yy = predictOneVsAll(all_theta, X); assert(yy, y); - + % predict ones we haven't seen yet (should match the original formula) X = [-10 2.5 10.0]'; assert(predictOneVsAll(all_theta, X), [1 2 3]'); - + + + % https://class.coursera.org/ml-005/forum/thread?thread_id=1425 + all_theta = oneVsAll([0.1 3.1 1.2; 1.8 0.9 0.7; 3.2 -1.4 6.7], [1 2 0]', 3, 0.3); + p = predictOneVsAll(all_theta, [0.1 3.1 1.2; 1.8 0.9 0.7; 3.2 -1.4 6.7]); + assert(p, [ 1 2 2 ]'); + + all_theta = oneVsAll([0 1 2 2 1 0 1 3 4 5 5 4 3]', [1 1 1 1 1 1 2 2 2 2 2 2 2]', 2, 1); + p = predictOneVsAll(all_theta, [0 1 2 2 1 0 1 3 4 5 5 4 3]'); + assert(p, [1 1 1 1 1 1 1 2 2 2 2 2 2]'); + + all_theta = oneVsAll([0 1 2 2 1 0 3 4 5 5 4 3]', [1 1 1 1 1 1 2 2 2 2 2 2]', 2, 0.001); + p = predictOneVsAll(all_theta, [0 1 2 2 1 0 3 4 5 5 4 3]'); + assert(p, [1 1 1 1 1 1 2 2 2 2 2 2]'); + endfunction