From de9c7db17dc9464475f588b12c6e3924d694ebaf Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Sun, 9 Jun 2019 17:24:26 +0200 Subject: [PATCH] SVMLin & SVMSGD --- shogun/classifier/svm/SVMLin.cpp | 31 +++++-------------------------- shogun/classifier/svm/SVMLin.h | 14 ++------------ shogun/classifier/svm/SVMSGD.cpp | 29 ++++++++++++----------------- shogun/classifier/svm/SVMSGD.h | 7 +++---- 4 files changed, 22 insertions(+), 59 deletions(-) diff --git a/shogun/classifier/svm/SVMLin.cpp b/shogun/classifier/svm/SVMLin.cpp index dc82d27..956a406 100644 --- a/shogun/classifier/svm/SVMLin.cpp +++ b/shogun/classifier/svm/SVMLin.cpp @@ -25,17 +25,6 @@ CSVMLin::CSVMLin() init(); } -CSVMLin::CSVMLin( - float64_t C, CDotFeatures* traindat, CLabels* trainlab) -: CLinearMachine(), C1(C), C2(C), epsilon(1e-5), use_bias(true) -{ - set_features(traindat); - set_labels(trainlab); - - init(); -} - - CSVMLin::~CSVMLin() { } @@ -51,22 +40,13 @@ void CSVMLin::init() SG_ADD(&epsilon, "epsilon", "Convergence precision."); } -bool CSVMLin::train_machine(CFeatures* data) +void CSVMLin::train_machine(CFeatures* features, CLabels* labels) { - ASSERT(m_labels) - - if (data) - { - if (!data->has_property(FP_DOT)) - SG_ERROR("Specified features are not of type CDotFeatures\n") - set_features((CDotFeatures*) data); - } - ASSERT(features) - SGVector train_labels=((CBinaryLabels*) m_labels)->get_labels(); - int32_t num_feat=features->get_dim_feature_space(); - int32_t num_vec=features->get_num_vectors(); + SGVector train_labels = binary_labels(labels)->get_labels(); + int32_t num_feat = features->as()->get_dim_feature_space(); + int32_t num_vec = features->get_num_vectors(); ASSERT(num_vec==train_labels.vlen) @@ -81,7 +61,7 @@ bool CSVMLin::train_machine(CFeatures* data) Data.n=num_feat+1; Data.nz=num_feat+1; Data.Y=train_labels.vector; - Data.features=features; + Data.features = features->as(); Data.C = SG_MALLOC(float64_t, Data.l); Options.algo = SVM; @@ -119,5 +99,4 @@ bool CSVMLin::train_machine(CFeatures* data) SG_FREE(Data.C); SG_FREE(Outputs.vec); - return true; } diff --git a/shogun/classifier/svm/SVMLin.h b/shogun/classifier/svm/SVMLin.h index ae5db4e..1cdb50a 100644 --- a/shogun/classifier/svm/SVMLin.h +++ b/shogun/classifier/svm/SVMLin.h @@ -29,15 +29,6 @@ class CSVMLin : public CLinearMachine /** default constructor */ CSVMLin(); - /** constructor - * - * @param C constant C - * @param traindat training features - * @param trainlab labels for features - */ - CSVMLin( - float64_t C, CDotFeatures* traindat, - CLabels* trainlab); virtual ~CSVMLin(); /** set C @@ -93,10 +84,9 @@ class CSVMLin : public CLinearMachine * @param data training data (parameter can be avoided if distance or * kernel-based classifiers are used and distance/kernels are * initialized with train data) - * - * @return whether training was successful + * @param labels training labels */ - virtual bool train_machine(CFeatures* data=NULL); + virtual void train_machine(CFeatures* features, CLabels* labels); /** set up parameters */ void init(); diff --git a/shogun/classifier/svm/SVMSGD.cpp b/shogun/classifier/svm/SVMSGD.cpp index fe538a0..f0d92a7 100644 --- a/shogun/classifier/svm/SVMSGD.cpp +++ b/shogun/classifier/svm/SVMSGD.cpp @@ -68,18 +68,10 @@ void CSVMSGD::set_loss_function(CLossFunction* loss_func) loss=loss_func; } -bool CSVMSGD::train_machine(CFeatures* data) +void CSVMSGD::train_machine(CFeatures* features, CLabels* labels) { // allocate memory for w and initialize everyting w and bias with 0 - auto labels = binary_labels(m_labels); - - if (data) - { - if (!data->has_property(FP_DOT)) - SG_ERROR("Specified features are not of type CDotFeatures\n") - set_features((CDotFeatures*) data); - } - + auto bin_labels = binary_labels(m_labels); ASSERT(features) int32_t num_train_labels = labels->get_num_labels(); @@ -88,7 +80,8 @@ bool CSVMSGD::train_machine(CFeatures* data) ASSERT(num_vec==num_train_labels) ASSERT(num_vec>0) - SGVector w(features->get_dim_feature_space()); + SGVector w( + features->as()->get_dim_feature_space()); w.zero(); bias=0; @@ -121,13 +114,17 @@ bool CSVMSGD::train_machine(CFeatures* data) for (int32_t i=0; iget_label(i); - float64_t z = y * (features->dense_dot(i, w.vector, w.vlen) + bias); + float64_t y = bin_labels->get_label(i); + float64_t z = + y * + (features->as()->dense_dot(i, w.vector, w.vlen) + + bias); if (z < 1 || is_log_loss) { - float64_t etd = -eta * loss->first_derivative(z,1); - features->add_to_dense_vec(etd * y / wscale, i, w.vector, w.vlen); + float64_t etd = -eta * loss->first_derivative(z, 1); + features->as()->add_to_dense_vec( + etd * y / wscale, i, w.vector, w.vlen); if (use_bias) { @@ -153,8 +150,6 @@ bool CSVMSGD::train_machine(CFeatures* data) SG_INFO("Norm: %.6f, Bias: %.6f\n", wnorm, bias) set_w(w); - - return true; } void CSVMSGD::calibrate() diff --git a/shogun/classifier/svm/SVMSGD.h b/shogun/classifier/svm/SVMSGD.h index 25af009..5e341d0 100644 --- a/shogun/classifier/svm/SVMSGD.h +++ b/shogun/classifier/svm/SVMSGD.h @@ -138,13 +138,12 @@ class CSVMSGD : public CLinearMachine /** train classifier * - * @param data training data (parameter can be avoided if distance or + * @param features training data (parameter can be avoided if distance or * kernel-based classifiers are used and distance/kernels are * initialized with train data) - * - * @return whether training was successful + * @param labels training labels */ - virtual bool train_machine(CFeatures* data=NULL); + virtual void train_machine(CFeatures* features, CLabels* labels); private: void init();