From fdc72d1181d7a89eb3cab027d62826ec70a67e83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C4=81ris=20Narti=C5=A1s?= Date: Tue, 9 Jan 2024 20:44:40 +0200 Subject: [PATCH] i.svm: Add LIBSVM-based image classification (#2189) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two modules – i.svm.train and i.svm.predict – provide a supervised raster imagery classification workflow. Both modules utilize LIBSVM to perform actual classification. Modules are designed to mimic existing classification modules e.g. by providing similar parameter names. Thanks to @wenzeslaus and @nilason for helpful feedback. --- configure | 201 +++++++ configure.ac | 27 + gui/wxpython/gui_core/gselect.py | 7 +- gui/wxpython/xml/toolboxes.xml | 7 + imagery/Makefile | 2 + imagery/i.signatures/main.c | 14 +- .../testsuite/test_i_signatures.py | 59 ++ imagery/i.svm.predict/Makefile | 14 + imagery/i.svm.predict/i.svm.predict.html | 82 +++ imagery/i.svm.predict/main.c | 424 +++++++++++++++ .../testsuite/test_i_svm_predict.py | 191 +++++++ imagery/i.svm.train/Makefile | 14 + imagery/i.svm.train/fill.c | 102 ++++ imagery/i.svm.train/fill.h | 29 + imagery/i.svm.train/i.svm.train.html | 111 ++++ imagery/i.svm.train/main.c | 513 ++++++++++++++++++ .../i.svm.train/testsuite/test_i_svm_train.py | 321 +++++++++++ include/Make/Platform.make.in | 5 + include/grass/config.h.in | 9 + include/grass/imagery.h | 4 +- lib/imagery/manage_signatures.c | 3 + lib/imagery/testsuite/test_imagery_find.py | 55 ++ .../test_imagery_signature_management.py | 337 +++++++++++- 23 files changed, 2524 insertions(+), 7 deletions(-) create mode 100644 imagery/i.svm.predict/Makefile create mode 100644 imagery/i.svm.predict/i.svm.predict.html create mode 100644 imagery/i.svm.predict/main.c create mode 100644 imagery/i.svm.predict/testsuite/test_i_svm_predict.py create mode 100644 imagery/i.svm.train/Makefile create mode 100644 imagery/i.svm.train/fill.c create mode 100644 imagery/i.svm.train/fill.h create mode 100644 imagery/i.svm.train/i.svm.train.html create mode 100644 imagery/i.svm.train/main.c create mode 100644 imagery/i.svm.train/testsuite/test_i_svm_train.py diff --git a/configure b/configure index 8f3a589cf9e..73f7c3f5d8d 100755 --- a/configure +++ b/configure @@ -679,6 +679,9 @@ CAIROLIB CAIROINC CAIRO_HAS_XRENDER_SURFACE CAIRO_HAS_XRENDER +USE_LIBSVM +LIBSVM_LIB +LIBSVM_INC LAPACKINC LAPACKLIB BLASINC @@ -902,6 +905,7 @@ with_odbc with_fftw with_blas with_lapack +with_libsvm with_cairo with_freetype with_nls @@ -948,6 +952,8 @@ with_blas_includes with_blas_libs with_lapack_includes with_lapack_libs +with_libsvm_includes +with_libsvm_libs with_cairo_includes with_cairo_libs with_cairo_ldflags @@ -1634,6 +1640,7 @@ Optional Packages: --with-fftw support FFTW functionality (default: yes) --with-blas support BLAS functionality (default: no) --with-lapack support LAPACK functionality (default: no) + --with-libsvm support LIBSVM functionality (default: no) --with-cairo support Cairo functionality (default: yes) --with-freetype support FreeType functionality (default: yes) --with-nls support NLS functionality (default: no) @@ -1709,6 +1716,9 @@ Optional Packages: --with-lapack-includes=DIRS LAPACK include files are in DIRS --with-lapack-libs=DIRS LAPACK library files are in DIRS + --with-libsvm-includes=DIRS + LIBSVM include files are in DIRS + --with-libsvm-libs=DIRS LIBSVM library files are in DIRS --with-cairo-includes=DIRS cairo include files are in DIRS --with-cairo-libs=DIRS cairo library files are in DIRS @@ -5468,6 +5478,17 @@ fi +# Check whether --with-libsvm was given. +if test ${with_libsvm+y} +then : + withval=$with_libsvm; +else $as_nop + with_libsvm=no +fi + + + + # Check whether --with-cairo was given. if test ${with_cairo+y} then : @@ -5921,6 +5942,25 @@ fi +# Check whether --with-libsvm-includes was given. +if test ${with_libsvm_includes+y} +then : + withval=$with_libsvm_includes; +fi + + + + +# Check whether --with-libsvm-libs was given. +if test ${with_libsvm_libs+y} +then : + withval=$with_libsvm_libs; +fi + + + + + # Check whether --with-cairo-includes was given. if test ${with_cairo_includes+y} then : @@ -14519,6 +14559,165 @@ fi # $USE_BLAS # Done checking LAPACK +# libsvm option +LIBSVM_INC= +LIBSVM_LIB= +USE_LIBSVM= + + +{ printf "%s\n" "$as_me:${as_lineno-$LINENO}: checking whether to use LIBSVM" >&5 +printf %s "checking whether to use LIBSVM... " >&6; } +{ printf "%s\n" "$as_me:${as_lineno-$LINENO}: result: \"$with_libsvm\"" >&5 +printf "%s\n" "\"$with_libsvm\"" >&6; } +case "$with_libsvm" in + "no") USE_LIBSVM= ;; + "yes") USE_LIBSVM="1" ;; + *) as_fn_error $? "*** You must answer yes or no." "$LINENO" 5 ;; +esac + + + +if test -n "$USE_LIBSVM"; then + +{ printf "%s\n" "$as_me:${as_lineno-$LINENO}: checking for location of LIBSVM includes" >&5 +printf %s "checking for location of LIBSVM includes... " >&6; } +case "$with_libsvm_includes" in +y | ye | yes | n | no) + as_fn_error $? "*** You must supply a directory to --with-libsvm-includes." "$LINENO" 5 + ;; +esac +{ printf "%s\n" "$as_me:${as_lineno-$LINENO}: result: $with_libsvm_includes" >&5 +printf "%s\n" "$with_libsvm_includes" >&6; } + +if test -n "$with_libsvm_includes" ; then + for dir in $with_libsvm_includes; do + if test -d "$dir"; then + LIBSVM_INC="$LIBSVM_INC -I$dir" + else + as_fn_error $? "*** LIBSVM includes directory $dir does not exist." "$LINENO" 5 + fi + done +fi + + +ac_save_cppflags="$CPPFLAGS" +CPPFLAGS="$LIBSVM_INC $CPPFLAGS" + for ac_header in svm.h +do : + ac_fn_c_check_header_compile "$LINENO" "svm.h" "ac_cv_header_svm_h" "$ac_includes_default" +if test "x$ac_cv_header_svm_h" = xyes +then : + printf "%s\n" "#define HAVE_SVM_H 1" >>confdefs.h + +else $as_nop + +ac_save_cppflags="$CPPFLAGS" +CPPFLAGS="$LIBSVM_INC $CPPFLAGS" + for ac_header in libsvm/svm.h +do : + ac_fn_c_check_header_compile "$LINENO" "libsvm/svm.h" "ac_cv_header_libsvm_svm_h" "$ac_includes_default" +if test "x$ac_cv_header_libsvm_svm_h" = xyes +then : + printf "%s\n" "#define HAVE_LIBSVM_SVM_H 1" >>confdefs.h + +else $as_nop + + as_fn_error $? "*** Unable to locate LIBSVM includes." "$LINENO" 5 + +fi + +done +CPPFLAGS=$ac_save_cppflags + + +fi + +done +CPPFLAGS=$ac_save_cppflags + + +{ printf "%s\n" "$as_me:${as_lineno-$LINENO}: checking for location of libsvm library" >&5 +printf %s "checking for location of libsvm library... " >&6; } +case "$with_libsvm_libs" in +y | ye | yes | n | no) + as_fn_error $? "*** You must supply a directory to --with-libsvm-libs." "$LINENO" 5 + ;; +esac +{ printf "%s\n" "$as_me:${as_lineno-$LINENO}: result: $with_libsvm_libs" >&5 +printf "%s\n" "$with_libsvm_libs" >&6; } + +if test -n "$with_libsvm_libs"; then + for dir in $with_libsvm_libs; do + if test -d "$dir"; then + LIBSVM_LIB="$LIBSVM_LIB -L$dir" + else + as_fn_error $? "*** libsvm library directory $dir does not exist." "$LINENO" 5 + fi + done +fi + + +ac_save_ldflags="$LDFLAGS" +LDFLAGS="$LIBSVM_LIB $LDFLAGS" + + +{ printf "%s\n" "$as_me:${as_lineno-$LINENO}: checking for svm_load_model in -lsvm" >&5 +printf %s "checking for svm_load_model in -lsvm... " >&6; } + +ac_check_lib_save_LIBS=$LIBS +LIBS="-lsvm $LIBS" +cat confdefs.h - <<_ACEOF >conftest.$ac_ext +/* end confdefs.h. */ + +/* Override any GCC internal prototype to avoid an error. + Use char because int might match the return type of a GCC + builtin and then its argument prototype would still apply. */ +char svm_load_model (); +int +main (void) +{ +return svm_load_model (); + ; + return 0; +} +_ACEOF +if ac_fn_c_try_link "$LINENO" +then : + ac_cv_lib_svm_svm_load_model=yes +else $as_nop + ac_cv_lib_svm_svm_load_model=no +fi +rm -f core conftest.err conftest.$ac_objext conftest.beam \ + conftest$ac_exeext conftest.$ac_ext +LIBS=$ac_check_lib_save_LIBS +{ printf "%s\n" "$as_me:${as_lineno-$LINENO}: result: $ac_cv_lib_svm_svm_load_model" >&5 +printf "%s\n" "$ac_cv_lib_svm_svm_load_model" >&6; } +if test "x$ac_cv_lib_svm_svm_load_model" = xyes +then : + LIBSVM_LIB="$LIBSVM_LIB -lsvm " +else $as_nop + +LDFLAGS=${ac_save_ldflags} + + as_fn_error $? "*** Unable to locate LIBSVM library." "$LINENO" 5 + + +fi + + + +LDFLAGS=${ac_save_ldflags} + + +printf "%s\n" "#define HAVE_LIBSVM 1" >>confdefs.h + +fi + + + + +# Done with LIBSVM + # Enable Cairo display driver option @@ -17436,6 +17635,8 @@ echo " Large File support (LFS): `if test -n "${USE_LARGEFILES}" ; then echo echo " libLAS support: `if test -n "${USE_LIBLAS}" ; then echo yes ; else echo no ; fi`" +echo " LIBSVM support: `if test -n "${USE_LIBSVM}" ; then echo yes ; else echo no ; fi`" + echo " MySQL support: `if test -n "${USE_MYSQL}" ; then echo yes ; else echo no ; fi`" echo " NetCDF support: `if test -n "${USE_NETCDF}" ; then echo yes ; else echo no ; fi`" diff --git a/configure.ac b/configure.ac index 2f1bfc69218..3ad17a73c5a 100644 --- a/configure.ac +++ b/configure.ac @@ -296,6 +296,7 @@ LOC_ARG_WITH(odbc, ODBC, no) LOC_ARG_WITH(fftw, FFTW) LOC_ARG_WITH(blas, BLAS, no) LOC_ARG_WITH(lapack, LAPACK, no) +LOC_ARG_WITH(libsvm, LIBSVM, no) LOC_ARG_WITH(cairo, Cairo) LOC_ARG_WITH(freetype, FreeType) LOC_ARG_WITH(nls, NLS, no) @@ -385,6 +386,9 @@ LOC_ARG_WITH_LIB(blas, BLAS) LOC_ARG_WITH_INC(lapack, LAPACK) LOC_ARG_WITH_LIB(lapack, LAPACK) +LOC_ARG_WITH_INC(libsvm, LIBSVM) +LOC_ARG_WITH_LIB(libsvm, LIBSVM) + LOC_ARG_WITH_INC(cairo, cairo) LOC_ARG_WITH_LIB(cairo, cairo) LOC_ARG_WITH_LDFLAGS(cairo, cairo) @@ -1696,6 +1700,28 @@ AC_SUBST(LAPACKINC) # Done checking LAPACK +# libsvm option +LIBSVM_INC= +LIBSVM_LIB= +USE_LIBSVM= + +LOC_CHECK_USE(libsvm,LIBSVM,USE_LIBSVM) + +if test -n "$USE_LIBSVM"; then + LOC_CHECK_INC_PATH(libsvm,LIBSVM,LIBSVM_INC) + LOC_CHECK_INCLUDES(svm.h,LIBSVM,$LIBSVM_INC, [ + LOC_CHECK_INCLUDES(libsvm/svm.h,LIBSVM,$LIBSVM_INC) + ]) + LOC_CHECK_LIB_PATH(libsvm,libsvm,LIBSVM_LIB) + LOC_CHECK_LIBS(svm,svm_load_model,LIBSVM,$LIBSVM_LIB,LIBSVM_LIB,,,) + AC_DEFINE(HAVE_LIBSVM, 1, [Define to 1 if using LIBSVM.]) +fi + +AC_SUBST(LIBSVM_INC) +AC_SUBST(LIBSVM_LIB) +AC_SUBST(USE_LIBSVM) +# Done with LIBSVM + # Enable Cairo display driver option LOC_CHECK_USE(cairo,Cairo,USE_CAIRO) @@ -2022,6 +2048,7 @@ LOC_MSG_USE(GEOS support,USE_GEOS) LOC_MSG_USE(LAPACK support,USE_LAPACK) LOC_MSG_USE(Large File support (LFS), USE_LARGEFILES) LOC_MSG_USE(libLAS support,USE_LIBLAS) +LOC_MSG_USE(LIBSVM support,USE_LIBSVM) LOC_MSG_USE(MySQL support,USE_MYSQL) LOC_MSG_USE(NetCDF support,USE_NETCDF) LOC_MSG_USE(NLS support,USE_NLS) diff --git a/gui/wxpython/gui_core/gselect.py b/gui/wxpython/gui_core/gselect.py index 3f819d7d849..89344ee725b 100644 --- a/gui/wxpython/gui_core/gselect.py +++ b/gui/wxpython/gui_core/gselect.py @@ -3096,7 +3096,7 @@ def __init__( def UpdateItems(self, element): """Update list of signature files for given element - :param str element: signatures/sig or signatures/sigset + :param str element: signatures/sig, signatures/sigset or signatures/libsvm """ items = [] if self.mapsets: @@ -3120,6 +3120,7 @@ def _append_mapset_signatures(self, mapset, element, items): from grass.lib.imagery import ( I_SIGFILE_TYPE_SIG, I_SIGFILE_TYPE_SIGSET, + I_SIGFILE_TYPE_LIBSVM, I_signatures_list_by_type, I_free_signatures_list, ) @@ -3133,6 +3134,8 @@ def _append_mapset_signatures(self, mapset, element, items): sig_type = I_SIGFILE_TYPE_SIG elif element == "signatures/sigset": sig_type = I_SIGFILE_TYPE_SIGSET + elif element == "signatures/libsvm": + sig_type = I_SIGFILE_TYPE_LIBSVM else: return list_ptr = ctypes.POINTER(ctypes.c_char_p) @@ -3151,7 +3154,7 @@ def __init__( ): super().__init__(parent, id, size=size, **kwargs) self.SetName("SignatureTypeSelect") - self.SetItems(["sig", "sigset"]) + self.SetItems(["sig", "sigset", "libsvm"]) class SeparatorSelect(wx.ComboBox): diff --git a/gui/wxpython/xml/toolboxes.xml b/gui/wxpython/xml/toolboxes.xml index 6ea83fe471c..4e8a4b7c3fb 100644 --- a/gui/wxpython/xml/toolboxes.xml +++ b/gui/wxpython/xml/toolboxes.xml @@ -1559,6 +1559,13 @@ + + + + + + + diff --git a/imagery/Makefile b/imagery/Makefile index f08a7cf72f2..76fa69cb656 100644 --- a/imagery/Makefile +++ b/imagery/Makefile @@ -32,6 +32,8 @@ SUBDIRS = \ i.segment \ i.signatures \ i.smap \ + i.svm.predict \ + i.svm.train \ i.target \ i.topo.corr \ i.pca \ diff --git a/imagery/i.signatures/main.c b/imagery/i.signatures/main.c index 681037db69b..58ad0e5f9e2 100644 --- a/imagery/i.signatures/main.c +++ b/imagery/i.signatures/main.c @@ -48,6 +48,9 @@ void print_json(const char *type, I_SIGFILE_TYPE sigtype, const char *mapset) fprintf(stdout, "],\n"); fprintf(stdout, "\"sigset\": ["); print_inline(I_SIGFILE_TYPE_SIGSET, mapset); + fprintf(stdout, "],\n"); + fprintf(stdout, "\"libsvm\": ["); + print_inline(I_SIGFILE_TYPE_LIBSVM, mapset); fprintf(stdout, "]\n"); } else if (sigtype == I_SIGFILE_TYPE_SIG) { @@ -60,6 +63,11 @@ void print_json(const char *type, I_SIGFILE_TYPE sigtype, const char *mapset) print_inline(I_SIGFILE_TYPE_SIGSET, mapset); fprintf(stdout, "]\n"); } + else if (sigtype == I_SIGFILE_TYPE_LIBSVM) { + fprintf(stdout, "\"libsvm\": ["); + print_inline(I_SIGFILE_TYPE_LIBSVM, mapset); + fprintf(stdout, "]\n"); + } fprintf(stdout, "}\n"); fflush(stdout); } @@ -104,7 +112,7 @@ int main(int argc, char *argv[]) parms.type->type = TYPE_STRING; parms.type->key_desc = "name"; parms.type->required = NO; - parms.type->options = "sig,sigset"; + parms.type->options = "sig,sigset,libsvm"; parms.type->guidependency = "remove,rename,copy"; parms.type->gisprompt = "old,sigtype,sigtype"; parms.type->description = _("Type of signature file"); @@ -174,6 +182,8 @@ int main(int argc, char *argv[]) sigtype = I_SIGFILE_TYPE_SIG; else if (strcmp(parms.type->answer, "sigset") == 0) sigtype = I_SIGFILE_TYPE_SIGSET; + else if (strcmp(parms.type->answer, "libsvm") == 0) + sigtype = I_SIGFILE_TYPE_LIBSVM; if (parms.copy->answers) { int i = 0; @@ -214,6 +224,8 @@ int main(int argc, char *argv[]) print_plain("sig", I_SIGFILE_TYPE_SIG, parms.mapset->answer); print_plain("sigset", I_SIGFILE_TYPE_SIGSET, parms.mapset->answer); + print_plain("libsvm", I_SIGFILE_TYPE_LIBSVM, + parms.mapset->answer); } else print_json(NULL, sigtype, parms.mapset->answer); diff --git a/imagery/i.signatures/testsuite/test_i_signatures.py b/imagery/i.signatures/testsuite/test_i_signatures.py index 93c5158f014..34f38774bc0 100644 --- a/imagery/i.signatures/testsuite/test_i_signatures.py +++ b/imagery/i.signatures/testsuite/test_i_signatures.py @@ -35,18 +35,28 @@ def setUpClass(cls): # tools, we must ensure signature directories exist os.makedirs(f"{cls.mpath}/signatures/sig/", exist_ok=True) os.makedirs(f"{cls.mpath}/signatures/sigset/", exist_ok=True) + os.makedirs(f"{cls.mpath}/signatures/libsvm/", exist_ok=True) + # Fake signature of sig type cls.sig_name1 = tempname(10) sig_dir1 = f"{cls.mpath}/signatures/sig/{cls.sig_name1}" os.makedirs(sig_dir1) cls.sigdirs.append(sig_dir1) sigfile_name1 = f"{sig_dir1}/sig" open(sigfile_name1, "a").close() + # Fake signature of sigset type cls.sig_name2 = tempname(10) sig_dir2 = f"{cls.mpath}/signatures/sigset/{cls.sig_name2}" os.makedirs(sig_dir2) cls.sigdirs.append(sig_dir2) sigfile_name2 = f"{sig_dir2}/sig" open(sigfile_name2, "a").close() + # Fake signature of libsvm type + cls.sig_name3 = tempname(10) + sig_dir3 = f"{cls.mpath}/signatures/libsvm/{cls.sig_name3}" + os.makedirs(sig_dir3) + cls.sigdirs.append(sig_dir3) + sigfile_name3 = f"{sig_dir3}/sig" + open(sigfile_name3, "a").close() @classmethod def tearDownClass(cls): @@ -62,6 +72,7 @@ def test_print_all_plain(self): self.assertTrue(i_sig.outputs.stdout) self.assertIn(self.sig_name1, i_sig.outputs.stdout) self.assertIn(self.sig_name2, i_sig.outputs.stdout) + self.assertIn(self.sig_name3, i_sig.outputs.stdout) def test_print_type_plain(self): """ @@ -73,12 +84,21 @@ def test_print_type_plain(self): self.assertTrue(i_sig.outputs.stdout) self.assertIn(self.sig_name1, i_sig.outputs.stdout) self.assertNotIn(self.sig_name2, i_sig.outputs.stdout) + self.assertNotIn(self.sig_name3, i_sig.outputs.stdout) # Case for sigset i_sig = SimpleModule("i.signatures", type="sigset", flags="p") self.assertModule(i_sig) self.assertTrue(i_sig.outputs.stdout) self.assertNotIn(self.sig_name1, i_sig.outputs.stdout) self.assertIn(self.sig_name2, i_sig.outputs.stdout) + self.assertNotIn(self.sig_name3, i_sig.outputs.stdout) + # Case for libsvm + i_sig = SimpleModule("i.signatures", type="libsvm", flags="p") + self.assertModule(i_sig) + self.assertTrue(i_sig.outputs.stdout) + self.assertNotIn(self.sig_name1, i_sig.outputs.stdout) + self.assertNotIn(self.sig_name2, i_sig.outputs.stdout) + self.assertIn(self.sig_name3, i_sig.outputs.stdout) def test_print_all_json(self): """ @@ -90,6 +110,7 @@ def test_print_all_json(self): json_out = json.loads(i_sig.outputs.stdout) self.assertIn(f"{self.sig_name1}@{self.mapset_name}", json_out["sig"]) self.assertIn(f"{self.sig_name2}@{self.mapset_name}", json_out["sigset"]) + self.assertIn(f"{self.sig_name3}@{self.mapset_name}", json_out["libsvm"]) def test_print_type_json(self): """ @@ -102,6 +123,7 @@ def test_print_type_json(self): json_out = json.loads(i_sig.outputs.stdout) self.assertIn(f"{self.sig_name1}@{self.mapset_name}", json_out["sig"]) self.assertNotIn("sigset", json_out.keys()) + self.assertNotIn("libsvm", json_out.keys()) # Case for sigset i_sig = SimpleModule("i.signatures", type="sigset", format="json", flags="p") self.assertModule(i_sig) @@ -109,6 +131,15 @@ def test_print_type_json(self): json_out = json.loads(i_sig.outputs.stdout) self.assertIn(f"{self.sig_name2}@{self.mapset_name}", json_out["sigset"]) self.assertNotIn("sig", json_out.keys()) + self.assertNotIn("libsvm", json_out.keys()) + # Case for libsvm + i_sig = SimpleModule("i.signatures", type="libsvm", format="json", flags="p") + self.assertModule(i_sig) + self.assertTrue(i_sig.outputs.stdout) + json_out = json.loads(i_sig.outputs.stdout) + self.assertIn(f"{self.sig_name3}@{self.mapset_name}", json_out["libsvm"]) + self.assertNotIn("sig", json_out.keys()) + self.assertNotIn("sigset", json_out.keys()) class ManageSignaturesTestCase(TestCase): @@ -121,30 +152,49 @@ def setUpClass(cls): # tools, we must ensure signature directories exist os.makedirs(f"{cls.mpath}/signatures/sig/", exist_ok=True) os.makedirs(f"{cls.mpath}/signatures/sigset/", exist_ok=True) + os.makedirs(f"{cls.mpath}/signatures/libsvm/", exist_ok=True) + # sig cls.sig_name1 = tempname(10) sig_dir1 = f"{cls.mpath}/signatures/sig/{cls.sig_name1}" os.makedirs(sig_dir1) cls.sigdirs.append(sig_dir1) sigfile_name1 = f"{sig_dir1}/sig" open(sigfile_name1, "a").close() + # sig cls.sig_name2 = tempname(10) sig_dir2 = f"{cls.mpath}/signatures/sig/{cls.sig_name2}" os.makedirs(sig_dir2) cls.sigdirs.append(sig_dir2) sigfile_name2 = f"{sig_dir2}/sig" open(sigfile_name2, "a").close() + # sigset cls.sig_name3 = tempname(10) sig_dir3 = f"{cls.mpath}/signatures/sigset/{cls.sig_name3}" os.makedirs(sig_dir3) cls.sigdirs.append(sig_dir3) sigfile_name3 = f"{sig_dir3}/sig" open(sigfile_name3, "a").close() + # sigset cls.sig_name4 = tempname(10) sig_dir4 = f"{cls.mpath}/signatures/sigset/{cls.sig_name4}" os.makedirs(sig_dir4) cls.sigdirs.append(sig_dir4) sigfile_name4 = f"{sig_dir4}/sig" open(sigfile_name4, "a").close() + # libsvm + cls.sig_name5 = tempname(10) + sig_dir5 = f"{cls.mpath}/signatures/libsvm/{cls.sig_name5}" + os.makedirs(sig_dir5) + cls.sigdirs.append(sig_dir5) + sigfile_name5 = f"{sig_dir5}/sig" + open(sigfile_name5, "a").close() + # libsvm + cls.sig_name6 = tempname(10) + sig_dir6 = f"{cls.mpath}/signatures/libsvm/{cls.sig_name6}" + os.makedirs(sig_dir6) + cls.sigdirs.append(sig_dir6) + sigfile_name6 = f"{sig_dir6}/sig" + open(sigfile_name6, "a").close() @classmethod def tearDownClass(cls): @@ -169,6 +219,7 @@ def test_copy(self): self.assertIn(f"{self.sig_name1}@{self.mapset_name}", json_out["sig"]) self.assertNotIn(f"{a_copy}@{self.mapset_name}", json_out["sig"]) self.assertNotIn(f"{a_copy}@{self.mapset_name}", json_out["sigset"]) + self.assertNotIn(f"{a_copy}@{self.mapset_name}", json_out["libsvm"]) # If all is correct, copy should succeed i_sig = SimpleModule("i.signatures", type="sig", copy=(self.sig_name1, a_copy)) @@ -180,6 +231,7 @@ def test_copy(self): self.assertIn(f"{self.sig_name1}@{self.mapset_name}", json_out["sig"]) self.assertIn(f"{a_copy}@{self.mapset_name}", json_out["sig"]) self.assertNotIn(f"{a_copy}@{self.mapset_name}", json_out["sigset"]) + self.assertNotIn(f"{a_copy}@{self.mapset_name}", json_out["libsvm"]) def test_rename(self): a_copy = tempname(10) @@ -199,6 +251,7 @@ def test_rename(self): self.assertIn(f"{self.sig_name2}@{self.mapset_name}", json_out["sig"]) self.assertNotIn(f"{a_copy}@{self.mapset_name}", json_out["sig"]) self.assertNotIn(f"{a_copy}@{self.mapset_name}", json_out["sigset"]) + self.assertNotIn(f"{a_copy}@{self.mapset_name}", json_out["libsvm"]) # If all is correct, rename should succeed i_sig = SimpleModule( @@ -211,8 +264,10 @@ def test_rename(self): json_out = json.loads(l_sig.outputs.stdout) self.assertNotIn(f"{self.sig_name2}@{self.mapset_name}", json_out["sig"]) self.assertNotIn(f"{self.sig_name2}@{self.mapset_name}", json_out["sigset"]) + self.assertNotIn(f"{self.sig_name2}@{self.mapset_name}", json_out["libsvm"]) self.assertIn(f"{a_copy}@{self.mapset_name}", json_out["sig"]) self.assertNotIn(f"{a_copy}@{self.mapset_name}", json_out["sigset"]) + self.assertNotIn(f"{a_copy}@{self.mapset_name}", json_out["libsvm"]) def test_remove(self): # Fail if type is not provided @@ -229,6 +284,8 @@ def test_remove(self): self.assertIn(f"{self.sig_name3}@{self.mapset_name}", json_out["sigset"]) self.assertIn(f"{self.sig_name4}@{self.mapset_name}", json_out["sigset"]) self.assertIn(f"{self.sig_name1}@{self.mapset_name}", json_out["sig"]) + self.assertIn(f"{self.sig_name5}@{self.mapset_name}", json_out["libsvm"]) + self.assertIn(f"{self.sig_name6}@{self.mapset_name}", json_out["libsvm"]) # If all is correct, remove should succeed i_sig = SimpleModule("i.signatures", type="sigset", remove=self.sig_name3) @@ -240,6 +297,8 @@ def test_remove(self): self.assertNotIn(f"{self.sig_name3}@{self.mapset_name}", json_out["sigset"]) self.assertIn(f"{self.sig_name4}@{self.mapset_name}", json_out["sigset"]) self.assertIn(f"{self.sig_name1}@{self.mapset_name}", json_out["sig"]) + self.assertIn(f"{self.sig_name5}@{self.mapset_name}", json_out["libsvm"]) + self.assertIn(f"{self.sig_name6}@{self.mapset_name}", json_out["libsvm"]) if __name__ == "__main__": diff --git a/imagery/i.svm.predict/Makefile b/imagery/i.svm.predict/Makefile new file mode 100644 index 00000000000..d20b31a3cbe --- /dev/null +++ b/imagery/i.svm.predict/Makefile @@ -0,0 +1,14 @@ +MODULE_TOPDIR = ../.. + +PGM = i.svm.predict + +LIBES = $(RASTERLIB) $(IMAGERYLIB) $(GISLIB) $(LIBSVM_LIB) +DEPENDENCIES = $(RASTERDEP) $(IMAGERYDEP) $(GISDEP) + +EXTRA_INC = $(LIBSVM_INC) + +include $(MODULE_TOPDIR)/include/Make/Module.make + +ifneq ($(USE_LIBSVM),) +default: cmd +endif diff --git a/imagery/i.svm.predict/i.svm.predict.html b/imagery/i.svm.predict/i.svm.predict.html new file mode 100644 index 00000000000..3ddf4f33879 --- /dev/null +++ b/imagery/i.svm.predict/i.svm.predict.html @@ -0,0 +1,82 @@ +

DESCRIPTION

+ +

i.svm.predict predicts values with a Support Vector Machine (SVM) +and stores them in a raster file. Predictions are based on a signature file +generated with i.svm.train. +

+ +

Internally the module performs input value rescaling of each of imagery +group rasters by minimum and maximum range determined during training.

+ +

NOTES

+ +

i.svm.train internally is using the LIBSVM. For introduction +into value prediction or estimation with LIBSVM, see +a +Practical Guide to Support Vector Classification by +Chih-Wei Hsu, Chih-Chung Chang, and Chih-Jen Lin.

+ +

It is strongly suggested to have semantic labels set for each raster +map in the training data (feature value) and in value prediction imagery groups. +Use r.support to set semantic labels.

+ +

PERFORMANCE

+ +

Value prediction is done cell by cell and thus memory consumption +should be constant.

+ +

The cache parameter determines the maximum memory allocated +for kernel caching to enhance computational speed. It's important to +note that the actual module's memory consumption may vary from this +setting, as it solely impacts LIBSVM's internal caching. The cache is +utilized on an as-needed basis, so it's unlikely to reach the specified value.

+ +

EXAMPLE

+ +

This is the second part of classification process. See +i.svm.train for the first part.

+ +

Predict land use classes form a LANDSAT scene from +October of 2002 with a SVM trained on a 1996 land +use map landuse96_28m.

+
+i.svm.predict group=lsat7_2002 subgroup=res_30m \
+    signaturefile=landuse96_rnd_points output=pred_landuse_2002
+
+ +

SEE ALSO

+ + +Train SVM: i.svm.train
+Set semantic labels: r.support
+Other classification modules: i.maxlik, +i.smap +

+LIBSVM home page: LIBSVM - A +Library for Support Vector Machines + +

REFERENCES

+ +

Please cite both - LIBSVM and i.svm.

+
    +
  • + For i.svm.* modules:
    + Nartiss, M., & Melniks, R. (2023). Improving pixel-­based classification of GRASS + GIS with support vector machine. Transactions in GIS, 00, 1–16. + https://doi.org/10.1111/tgis.13102 +
  • +
  • + For LIBSVM:
    + Chang, C.-C., & Lin, C.-J. (2011). LIBSVM : a library for support vector machines. + ACM Transactions on Intelligent Systems and Technology, 2:27:1--27:27. +
  • +
+ +

AUTHORS

+ +Maris Nartiss, University of Latvia. + + diff --git a/imagery/i.svm.predict/main.c b/imagery/i.svm.predict/main.c new file mode 100644 index 00000000000..968950d5627 --- /dev/null +++ b/imagery/i.svm.predict/main.c @@ -0,0 +1,424 @@ + +/**************************************************************************** + * + * MODULE: i.svm.predict + * AUTHOR(S): Maris Nartiss - maris.gis gmail.com + * PURPOSE: Predicts values with Support Vector Machine classifier + * + * COPYRIGHT: (C) 2023 by Maris Nartiss and the GRASS Development Team + * + * This program is free software under the GNU General Public + * License (>=v2). Read the file COPYING that comes with GRASS + * for details. + * + * Development of this module was supported from + * science funding of University of Latvia (2020-2023). + * + *****************************************************************************/ +#include +#include +#include +#include + +#include +#if HAVE_SVM_H +#include +#elif HAVE_LIBSVM_SVM_H +#include +#endif + +#include +#include +#include +#include + +/* LIBSVM message wrapper */ +void print_func(const char *s) +{ + G_verbose_message("%s", s); +} + +int main(int argc, char *argv[]) +{ + struct GModule *module; + struct Option *opt_group, *opt_subgroup, *opt_sigfile, *opt_values; + struct Option *opt_svm_cache_size; + + G_gisinit(argv[0]); + + module = G_define_module(); + G_add_keyword(_("imagery")); + G_add_keyword(_("svm")); + G_add_keyword(_("classification")); + G_add_keyword(_("prediction")); + G_add_keyword(_("regression")); + module->label = _("Predict with a SVM"); + module->description = _("Predict with a Support Vector Machine"); + + opt_group = G_define_standard_option(G_OPT_I_GROUP); + /* GTC: SVM prediction input */ + opt_group->description = _("Maps with feature values (attributes)"); + + opt_subgroup = G_define_standard_option(G_OPT_I_SUBGROUP); + opt_subgroup->required = NO; + + opt_sigfile = G_define_option(); + opt_sigfile->key = "signaturefile"; + opt_sigfile->type = TYPE_STRING; + opt_sigfile->key_desc = "name"; + opt_sigfile->required = YES; + opt_sigfile->gisprompt = "old,signatures/libsvm,sigfile"; + opt_sigfile->description = _("Name of input file containing signatures"); + + opt_values = G_define_standard_option(G_OPT_R_OUTPUT); + opt_values->required = YES; + opt_values->description = + _("Output map with predicted class or calculated value"); + + opt_svm_cache_size = G_define_option(); + opt_svm_cache_size->key = "cache"; + opt_svm_cache_size->type = TYPE_INTEGER; + opt_svm_cache_size->key_desc = "cache size"; + opt_svm_cache_size->required = NO; + opt_svm_cache_size->options = "1-"; + opt_svm_cache_size->answer = "512"; + opt_svm_cache_size->description = _("LIBSVM kernel cache size in MB"); + + if (G_parser(argc, argv)) + exit(EXIT_FAILURE); + + /* Input validation */ + /* Input maps */ + char name_values[GNAME_MAX], name_sigfile[GNAME_MAX]; + char name_group[GNAME_MAX], name_subgroup[GNAME_MAX]; + char mapset_values[GMAPSET_MAX], mapset_sigfile[GMAPSET_MAX]; + char mapset_group[GMAPSET_MAX], mapset_subgroup[GMAPSET_MAX]; + char sigfile_dir[GPATH_MAX], model_file[GPATH_MAX]; + if (G_unqualified_name(opt_group->answer, NULL, name_group, mapset_group) == + 0) + strcpy(mapset_group, G_mapset()); + if (opt_subgroup->answer && + G_unqualified_name(opt_subgroup->answer, NULL, name_subgroup, + mapset_subgroup) != 0 && + strcmp(mapset_subgroup, mapset_group) != 0) + G_fatal_error(_("Invalid subgroup <%s> provided"), + opt_subgroup->answer); + if (!I_find_group2(name_group, mapset_group)) { + G_fatal_error(_("Group <%s> not found in mapset <%s>"), name_group, + mapset_group); + } + if (opt_subgroup->answer && + !I_find_subgroup2(name_group, name_subgroup, mapset_group)) { + G_fatal_error(_("Subgroup <%s> in group <%s@%s> not found"), + name_subgroup, name_group, mapset_group); + } + + if (G_unqualified_name(opt_sigfile->answer, NULL, name_sigfile, + mapset_sigfile) == 0) + strcpy(mapset_sigfile, G_mapset()); + if (!I_find_signature2(I_SIGFILE_TYPE_LIBSVM, name_sigfile, mapset_sigfile)) + G_fatal_error(_("Signature file <%s@%s> not found"), name_sigfile, + mapset_sigfile); + + if (G_unqualified_name(opt_values->answer, G_mapset(), name_values, + mapset_values) < 0) + G_fatal_error(_("<%s> does not match the current mapset"), + mapset_values); + if (G_legal_filename(name_values) < 0) + G_fatal_error(_("<%s> is an illegal file name"), name_values); + + /* Get bands */ + struct Ref group_ref; + if (opt_subgroup->answer) { + if (!I_get_subgroup_ref2(name_group, opt_subgroup->answer, mapset_group, + &group_ref)) { + G_fatal_error( + _("There was an error reading subgroup <%s> in group <%s@%s>"), + opt_subgroup->answer, name_group, mapset_group); + } + } + else { + if (!I_get_group_ref2(name_group, mapset_group, &group_ref)) { + G_fatal_error(_("There was an error reading group <%s@%s>"), + name_group, mapset_group); + } + } + if (group_ref.nfiles <= 0) { + if (opt_subgroup->answer) + G_fatal_error( + _("Subgroup <%s> in group <%s@%s> contains no raster maps."), + opt_subgroup->answer, name_group, mapset_group); + else + G_fatal_error(_("Group <%s@%s> contains no raster maps."), + name_group, mapset_group); + } + const char **semantic_labels_group = + G_malloc(group_ref.nfiles * sizeof(char *)); + for (int n = 0; n < group_ref.nfiles; n++) { + semantic_labels_group[n] = Rast_get_semantic_label_or_name( + group_ref.file[n].name, group_ref.file[n].mapset); + } + + I_get_signatures_dir(sigfile_dir, I_SIGFILE_TYPE_LIBSVM); + /* Read signature file version */ + int sigfile_version; + FILE *misc_file = + G_fopen_old_misc(sigfile_dir, "version", name_sigfile, mapset_sigfile); + if (fscanf(misc_file, "%d", &sigfile_version) != 1) { + G_fatal_error(_("Invalid signature file")); + } + fclose(misc_file); + /* Current version number is 1 */ + if (sigfile_version != 1) { + G_fatal_error(_("Invalid signature file version")); + } + + /* Reorder group items to match order from the signature file (=training + * order) */ + misc_file = G_fopen_old_misc(sigfile_dir, "semantic_label", name_sigfile, + mapset_sigfile); + if (!misc_file) + G_fatal_error(_("Unable to read signature file '%s'."), name_sigfile); + char **names_ordered = G_malloc(group_ref.nfiles * sizeof(char *)); + char **mapsets_ordered = G_malloc(group_ref.nfiles * sizeof(char *)); + + char frmt[10]; + snprintf(frmt, sizeof(frmt), "%%%ds", GNAME_MAX - 1); + char semantic_label[GNAME_MAX]; + int semantic_label_count = 0, semantic_label_match_count = 0; + while (fscanf(misc_file, frmt, semantic_label) == 1) { + semantic_label_count++; + bool found = false; + + for (int n = 0; n < group_ref.nfiles; n++) { + if (semantic_label[0] != '\n' && + strcmp(semantic_label, semantic_labels_group[n]) == 0) { + semantic_label_match_count++; + found = true; + names_ordered[n] = group_ref.file[n].name; + mapsets_ordered[n] = group_ref.file[n].mapset; + break; + } + } + if (!found) + G_fatal_error(_("Imagery group does not contain a raster with a " + "semantic label '%s'"), + semantic_label); + } + fclose(misc_file); + if (semantic_label_match_count != semantic_label_count || + semantic_label_match_count != group_ref.nfiles) { + G_fatal_error(_("Unable to match all signature file bands to imagery " + "group bands. " + "Signature band count: %d, imagery group band count: " + "%d, band match count: %d."), + semantic_label_count, group_ref.nfiles, + semantic_label_match_count); + } + + /* Read rescaling parameters */ + DCELL *means, *ranges, mean, range; + int scale_count = 0; + misc_file = + G_fopen_old_misc(sigfile_dir, "scale", name_sigfile, mapset_sigfile); + if (!misc_file) + G_fatal_error(_("Unable to read signature file '%s'."), name_sigfile); + means = G_malloc(group_ref.nfiles * sizeof(DCELL)); + ranges = G_malloc(group_ref.nfiles * sizeof(DCELL)); + while (fscanf(misc_file, "%lf %lf", &mean, &range) == 2) { + if (scale_count >= group_ref.nfiles) + G_fatal_error(_("Unable to read signature file '%s'."), + name_sigfile); + means[scale_count] = mean; + ranges[scale_count] = range; + if (range == 0) + G_fatal_error(_("Unable to read signature file '%s'."), + name_sigfile); + scale_count++; + } + fclose(misc_file); + if (scale_count != group_ref.nfiles) + G_fatal_error(_("Unable to read signature file '%s'."), name_sigfile); + + /* Pass LIBSVM messages through GRASS */ + svm_set_print_string_function(&print_func); + + /* Load trained model from a file */ + struct svm_model *model; + G_verbose_message("Reading in trained SVM"); + G_file_name_misc(model_file, sigfile_dir, "sig", name_sigfile, + mapset_sigfile); + model = svm_load_model(model_file); + if (model == NULL) + G_fatal_error(_("Unable to open trained model file <%s>"), + name_sigfile); + + G_message(_("Starting value prediction process")); + /* For row, cell: svm_predict */ + int row, col, band; + int nrows, ncols; + int svm_type; + int fd_values = 0; + RASTER_MAP_TYPE out_type; + + int *fd_bands; + DCELL **buf_bands; + struct svm_node *nodes; + + svm_type = svm_get_svm_type(model); + + nrows = Rast_window_rows(); + ncols = Rast_window_cols(); + + buf_bands = (DCELL **)G_malloc(group_ref.nfiles * sizeof(DCELL *)); + fd_bands = (int *)G_calloc(group_ref.nfiles, sizeof(int)); + for (band = 0; band < group_ref.nfiles; band++) { + buf_bands[band] = Rast_allocate_d_buf(); + fd_bands[band] = + Rast_open_old(names_ordered[band], mapsets_ordered[band]); + } + nodes = (struct svm_node *)G_malloc(((size_t)group_ref.nfiles + 1) * + sizeof(struct svm_node)); + + /* Predict a class (C_SVC, NU_SVC, ONE_CLASS) + * Other SVM types calculate a value */ + if (svm_type == C_SVC || svm_type == NU_SVC || svm_type == ONE_CLASS) { + CELL *out_row; + DCELL val; + + out_row = Rast_allocate_c_buf(); + out_type = CELL_TYPE; + fd_values = Rast_open_c_new(name_values); + + for (row = 0; row < nrows; row++) { + G_percent(row, nrows, 2); + for (band = 0; band < group_ref.nfiles; band++) + Rast_get_d_row(fd_bands[band], &buf_bands[band][0], row); + for (col = 0; col < ncols; col++) { + nodes[0].index = -1; + for (band = 0; band < group_ref.nfiles; band++) { + if (Rast_is_d_null_value(&buf_bands[band][col])) + continue; + nodes[band].index = band; + nodes[band].value = + (buf_bands[band][col] - means[band]) / ranges[band]; + } + + /* All values where NULLs */ + if (nodes[0].index == -1) { + Rast_set_c_null_value(&out_row[col], 1); + continue; + } + /* Mark the end of values in nodes */ + nodes[group_ref.nfiles].index = -1; + + val = svm_predict(model, nodes); + out_row[col] = (CELL)val; + } + Rast_put_row(fd_values, out_row, out_type); + } + G_percent(1, 1, 1); + G_free(out_row); + } + else { + DCELL *out_row; + DCELL val; + + out_row = Rast_allocate_d_buf(); + out_type = DCELL_TYPE; + fd_values = Rast_open_fp_new(name_values); + + for (row = 0; row < nrows; row++) { + G_percent(row, nrows, 2); + for (band = 0; band < group_ref.nfiles; band++) + Rast_get_d_row(fd_bands[band], &buf_bands[band][0], row); + for (col = 0; col < ncols; col++) { + nodes[0].index = -1; + for (band = 0; band < group_ref.nfiles; band++) { + if (Rast_is_d_null_value(&buf_bands[band][col])) + continue; + nodes[band].index = band; + nodes[band].value = + (buf_bands[band][col] - means[band]) / ranges[band]; + } + + /* All values where NULLs */ + if (nodes[0].index == -1) { + Rast_set_d_null_value(&out_row[col], 1); + continue; + } + /* Mark the end of values in nodes */ + nodes[group_ref.nfiles].index = -1; + + val = svm_predict(model, nodes); + out_row[col] = val; + } + Rast_put_row(fd_values, out_row, out_type); + } + G_percent(1, 1, 1); + G_free(out_row); + } + + /* Clean up */ + Rast_close(fd_values); + for (band = 0; band < group_ref.nfiles; band++) { + Rast_close(fd_bands[band]); + G_free(buf_bands[band]); + } + G_free(nodes); + G_free(means); + G_free(ranges); + + /* Try to give full history */ + struct History history; + G_verbose_message("Writing out history"); + Rast_short_history(name_values, "raster", &history); + misc_file = + G_fopen_old_misc(sigfile_dir, "history", name_sigfile, mapset_sigfile); + if (misc_file != NULL) { + char hist_line[4096]; /* history lines are limited to 4096 */ + + while (G_getl(hist_line, sizeof(hist_line), misc_file) == 1) { + Rast_append_history(&history, hist_line); + } + fclose(misc_file); + } + Rast_command_history(&history); + if (opt_subgroup->answer) + Rast_format_history(&history, HIST_DATSRC_1, "Group/subgroup: %s@%s/%s", + name_group, mapset_group, opt_subgroup->answer); + else + Rast_format_history(&history, HIST_DATSRC_1, "Group: %s@%s", name_group, + mapset_group); + Rast_format_history(&history, HIST_DATSRC_2, "Signature file: %s@%s", + name_sigfile, mapset_sigfile); + Rast_write_history(name_values, &history); + + if (svm_type != ONE_CLASS) { + char in_path[GPATH_MAX], out_path[GPATH_MAX]; + + /* Copy CATs file from the original training map */ + G_verbose_message("Copying category information"); + G_file_name_misc(in_path, sigfile_dir, "cats", name_sigfile, + mapset_sigfile); + /* Avoid warnings if file does not exist */ + if (access(in_path, 0) == 0) { + G_file_name(out_path, "cats", name_values, G_mapset()); + G_copy_file(in_path, out_path); + } + + /* Copy color file from the original training map */ + G_verbose_message("Copying color information"); + G_file_name_misc(in_path, sigfile_dir, "colr", name_sigfile, + mapset_sigfile); + if (access(in_path, 0) == 0) { + G_file_name(out_path, "colr", name_values, G_mapset()); + G_copy_file(in_path, out_path); + } + } + Rast_put_cell_title(name_values, + /* GTC: A map title */ + _("Values predicted with a Support Vector Machine")); + + exit(EXIT_SUCCESS); +} diff --git a/imagery/i.svm.predict/testsuite/test_i_svm_predict.py b/imagery/i.svm.predict/testsuite/test_i_svm_predict.py new file mode 100644 index 00000000000..706af4f4fa2 --- /dev/null +++ b/imagery/i.svm.predict/testsuite/test_i_svm_predict.py @@ -0,0 +1,191 @@ +""" +Name: i.svm.predict input & output tests +Purpose: Validates user input validation code and output generation + +Author: Maris Nartiss +Copyright: (C) 2023 by Maris Nartiss and the GRASS Development Team +Licence: This program is free software under the GNU General Public + License (>=v2). Read the file COPYING that comes with GRASS + for details. +""" +import unittest +import shutil + +from grass.script import core as grass +from grass.gunittest.case import TestCase +from grass.gunittest.main import test +from grass.gunittest.gmodules import SimpleModule +from grass.pygrass.gis import Mapset +from grass.lib.imagery import ( + I_SIGFILE_TYPE_LIBSVM, + I_signatures_remove, +) + + +class IOValidationTest(TestCase): + """Test input validation and output generation with i.svm.predict""" + + @classmethod + @unittest.skipIf(shutil.which("i.svm.predict") is None, "i.svm.predict not found.") + def setUpClass(cls): + cls.tmp_rasts = [] + cls.tmp_groups = [] + cls.mapset_name = Mapset().name + # Small region for small testing rasters + cls.use_temp_region() + cls.runModule("g.region", n=10, s=0, e=10, w=0, res=1) + cls.rastt = grass.tempname(10) + cls.runModule( + "r.mapcalc", expression=f"{cls.rastt}=rand(0.0,1)", seed=1, quiet=True + ) + cls.tmp_rasts.append(cls.rastt) + cls.runModule("r.colors", _map=cls.rastt, color="grey", quiet=True) + cls.runModule( + "r.support", _map=cls.rastt, semantic_label="GRASS_RNDT", quiet=True + ) + # A raster without a semantic label + cls.rast1 = grass.tempname(10) + cls.runModule( + "r.mapcalc", expression=f"{cls.rast1}=rand(0.0,1)", seed=1, quiet=True + ) + cls.tmp_rasts.append(cls.rast1) + cls.runModule("r.colors", _map=cls.rast1, color="grey", quiet=True) + # A raster with a semantic label + cls.rast2 = grass.tempname(10) + cls.runModule( + "r.mapcalc", expression=f"{cls.rast2}=rand(0.0,1)", seed=1, quiet=True + ) + cls.tmp_rasts.append(cls.rast2) + cls.runModule( + "r.support", _map=cls.rast2, semantic_label="GRASS_RND1", quiet=True + ) + cls.rast3 = grass.tempname(10) + cls.runModule( + "r.mapcalc", expression=f"{cls.rast3}=rand(0.0,1)", seed=1, quiet=True + ) + cls.tmp_rasts.append(cls.rast3) + cls.runModule( + "r.support", _map=cls.rast3, semantic_label="GRASS_RND2", quiet=True + ) + # An empty imagery group + cls.group1 = grass.tempname(10) + cls.runModule("i.group", group=cls.group1, _input=(cls.rast1,), quiet=True) + cls.tmp_groups.append(cls.group1) + cls.runModule( + "i.group", flags="r", group=cls.group1, _input=(cls.rast1,), quiet=True + ) + # A good imagery group + cls.group3 = grass.tempname(10) + cls.runModule( + "i.group", group=cls.group3, _input=(cls.rast1, cls.rast3), quiet=True + ) + cls.tmp_groups.append(cls.group3) + # A imagery group with different semantic label count + cls.group4 = grass.tempname(10) + cls.runModule("i.group", group=cls.group4, _input=(cls.rast3), quiet=True) + cls.tmp_groups.append(cls.group4) + cls.group5 = grass.tempname(10) + cls.runModule( + "i.group", + group=cls.group5, + _input=(cls.rast1, cls.rast3, cls.rastt), + quiet=True, + ) + cls.tmp_groups.append(cls.group5) + # Generate a signature file + cls.sig1 = grass.tempname(10) + isvm = SimpleModule( + "i.svm.train", + group=cls.group3, + trainingmap=cls.rastt, + signaturefile=cls.sig1, + quiet=True, + ) + isvm.run() + + @classmethod + def tearDownClass(cls): + """Remove the temporary region and generated data""" + cls.del_temp_region() + for rast in cls.tmp_rasts: + cls.runModule("g.remove", flags="f", _type="raster", name=rast) + for group in cls.tmp_groups: + cls.runModule("g.remove", flags="f", _type="group", name=group) + I_signatures_remove(I_SIGFILE_TYPE_LIBSVM, cls.sig1) + + @unittest.skipIf(shutil.which("i.svm.predict") is None, "i.svm.predict not found.") + def test_empty_group(self): + """Empty imagery group handling""" + rast = grass.tempname(10) + isvm = SimpleModule( + "i.svm.predict", + group=self.group1, + output=rast, + signaturefile=self.sig1, + quiet=True, + ) + self.tmp_rasts.append(rast) + self.assertModuleFail(isvm) + self.assertTrue(isvm.outputs.stderr) + self.assertIn(self.group1, isvm.outputs.stderr) + + @unittest.skipIf(shutil.which("i.svm.predict") is None, "i.svm.predict not found.") + def test_semantic_label_mismatch1(self): + """There are more semantic labels in the signature file than in the group""" + rast = grass.tempname(10) + isvm = SimpleModule( + "i.svm.predict", + group=self.group4, + output=rast, + signaturefile=self.sig1, + quiet=True, + ) + self.tmp_rasts.append(rast) + self.assertModuleFail(isvm) + self.assertTrue(isvm.outputs.stderr) + self.assertIn( + "Imagery group does not contain a raster with a semantic label", + isvm.outputs.stderr, + ) + self.assertIn( + self.rast1, + isvm.outputs.stderr, + ) + + @unittest.skipIf(shutil.which("i.svm.predict") is None, "i.svm.predict not found.") + def test_semantic_label_mismatch2(self): + """There are more semantic labels in the group than in the signature file""" + rast = grass.tempname(10) + isvm = SimpleModule( + "i.svm.predict", + group=self.group5, + output=rast, + signaturefile=self.sig1, + quiet=True, + ) + self.tmp_rasts.append(rast) + self.assertModuleFail(isvm) + self.assertTrue(isvm.outputs.stderr) + self.assertTrue( + "Signature band count: 2, imagery group band count: 3" + in isvm.outputs.stderr + ) + + @unittest.skipIf(shutil.which("i.svm.predict") is None, "i.svm.predict not found.") + def test_prediction(self): + """A successful run""" + rast = grass.tempname(10) + isvm = SimpleModule( + "i.svm.predict", + group=self.group3, + output=rast, + signaturefile=self.sig1, + quiet=True, + ) + self.tmp_rasts.append(rast) + self.assertModule(isvm) + self.assertRasterExists(rast) + + +if __name__ == "__main__": + test() diff --git a/imagery/i.svm.train/Makefile b/imagery/i.svm.train/Makefile new file mode 100644 index 00000000000..19883b46eca --- /dev/null +++ b/imagery/i.svm.train/Makefile @@ -0,0 +1,14 @@ +MODULE_TOPDIR = ../.. + +PGM = i.svm.train + +LIBES = $(RASTERLIB) $(IMAGERYLIB) $(GISLIB) $(LIBSVM_LIB) +DEPENDENCIES = $(RASTERDEP) $(IMAGERYDEP) $(GISDEP) + +EXTRA_INC = $(LIBSVM_INC) + +include $(MODULE_TOPDIR)/include/Make/Module.make + +ifneq ($(USE_LIBSVM),) +default: cmd +endif diff --git a/imagery/i.svm.train/fill.c b/imagery/i.svm.train/fill.c new file mode 100644 index 00000000000..f43941bb1a4 --- /dev/null +++ b/imagery/i.svm.train/fill.c @@ -0,0 +1,102 @@ +/* + * i.svm.train Functions filling svm_problem struct + * + * Copyright 2023 by Maris Nartiss, and The GRASS Development Team + * Author: Maris Nartiss + * + * This program is free software licensed under the GPL (>=v2). + * Read the COPYING file that comes with GRASS for details. + * + */ +#include +#include + +#include "fill.h" + +void fill_problem(const char *name_labels, const char *mapset_labels, + struct Ref band_refs, const DCELL *means, const DCELL *ranges, + struct svm_problem *problem) +{ + /* Keep track of used svm_problem node head */ + int label_num = 0; + int label_max = 0; + problem->l = 0; + problem->x = NULL; + problem->y = NULL; + + int nrows = Rast_window_rows(); + int ncols = Rast_window_cols(); + + int fd_labels = Rast_open_old(name_labels, mapset_labels); + /* svm_problem always stores labels as doubles */ + DCELL *buf_labels = Rast_allocate_d_buf(); + + DCELL **buf_bands = (DCELL **)G_malloc(band_refs.nfiles * sizeof(DCELL *)); + int *fd_bands = (int *)G_calloc(band_refs.nfiles, sizeof(int)); + int band; + for (band = 0; band < band_refs.nfiles; band++) { + buf_bands[band] = Rast_allocate_d_buf(); + fd_bands[band] = Rast_open_old(band_refs.file[band].name, + band_refs.file[band].mapset); + } + + for (int row = 0; row < nrows; row++) { + G_percent(row, nrows, 10); + Rast_get_d_row(fd_labels, buf_labels, row); + for (band = 0; band < band_refs.nfiles; band++) + Rast_get_d_row(fd_bands[band], &buf_bands[band][0], row); + for (int col = 0; col < ncols; col++) { + if (Rast_is_d_null_value(&buf_labels[col])) + continue; + if (label_num >= label_max) { + label_max += SIZE_INCREMENT; + problem->y = + G_realloc(problem->y, (size_t)label_max * sizeof(double)); + problem->x = G_realloc( + problem->x, (size_t)label_max * sizeof(struct svm_node *)); + } + problem->l = label_num; + problem->y[label_num] = buf_labels[col]; + problem->x[label_num] = NULL; + int value_num = 0; + int value_max = 0; + for (band = 0; band < band_refs.nfiles; band++) { + if (Rast_is_d_null_value(&buf_bands[band][col])) + continue; + if (value_num >= value_max) { + /* Three bands are typical, thus we start with 4 nodes */ + value_max += 4; + problem->x[label_num] = G_realloc( + problem->x[label_num], + ((size_t)value_max + 1) * sizeof(struct svm_node)); + } + problem->x[label_num][value_num].index = band; + problem->x[label_num][value_num].value = + (buf_bands[band][col] - means[band]) / ranges[band]; + value_num++; + } + /* If label has no data */ + if (value_num == 0) { + continue; + } + problem->x[label_num][value_num].index = -1; + label_num++; + } + } + + /* Although there could be more memory allocated, not all might be filled */ + problem->l = label_num; + + /* Clean up */ + Rast_close(fd_labels); + G_free(buf_labels); + + for (band = 0; band < band_refs.nfiles; band++) { + Rast_close(fd_bands[band]); + G_free(buf_bands[band]); + } + G_free(fd_bands); + G_free(buf_bands); + G_percent(1, 1, 1); + G_percent_reset(); +} diff --git a/imagery/i.svm.train/fill.h b/imagery/i.svm.train/fill.h new file mode 100644 index 00000000000..e9c28caffec --- /dev/null +++ b/imagery/i.svm.train/fill.h @@ -0,0 +1,29 @@ +/* + * i.svm.train Functions filling svm_problem struct + * + * Copyright 2023 by Maris Nartiss, and The GRASS Development Team + * Author: Maris Nartiss + * + * This program is free software licensed under the GPL (>=v2). + * Read the COPYING file that comes with GRASS for details. + * + */ + +#include +#if HAVE_SVM_H +#include +#elif HAVE_LIBSVM_SVM_H +#include +#endif + +#include + +#ifndef FILL_H +#define FILL_H + +#define SIZE_INCREMENT 64; + +void fill_problem(const char *, const char *, struct Ref, const DCELL *, + const DCELL *, struct svm_problem *); + +#endif // FILL_H diff --git a/imagery/i.svm.train/i.svm.train.html b/imagery/i.svm.train/i.svm.train.html new file mode 100644 index 00000000000..ca26a1636d1 --- /dev/null +++ b/imagery/i.svm.train/i.svm.train.html @@ -0,0 +1,111 @@ +

DESCRIPTION

+ +

i.svm.train finds parameters for a Support Vector Machine +and stores them in a signature file for later usage by +i.svm.predict. +

+ +

Internally the module performs input value rescaling of each of imagery +group rasters by mean normalisation based on minimum and maximum value present +in the raster metadata. Rescaling parameters are +written into the signature file for use during prediction.

+ +

NOTES

+ +

i.svm.train internally is using the LIBSVM. For introduction +into value prediction or estimation with LIBSVM, see +a +Practical Guide to Support Vector Classification by +Chih-Wei Hsu, Chih-Chung Chang, and Chih-Jen Lin.

+ +

It is strongly suggested to have semantic labels set for each raster +map in the training data (feature value) imagery group. +Use r.support to set semantic labels.

+ +

PERFORMANCE

+ +

SVM training is done by loading all training data into memory. +In a case of large input raster files, use sparse label rasters +(e.g. raster points or small patches instead of uninterrupted cover).

+ +

During the training process there is no progress output printed. +Training with large number of data points can take significant time - +just be patient.

+ +

By default the shrinking heuristics option of LIBSVM is enabled. +It should not impact the outcome, just the training time. On some +input parameter and data combinations training with the shrinking +heuristics disabled might be faster.

+ +

The cache parameter determines the maximum memory allocated +for kernel caching to enhance computational speed. It's important to +note that the actual module's memory consumption may vary from this +setting, as it solely impacts LIBSVM's internal caching. The cache is +utilized on an as-needed basis, so it's unlikely to reach the specified value.

+ +

EXAMPLE

+ +

This is the first part of classification process. See +i.svm.predict for the second part.

+ +

Train a SVM to identify land use classes according to the 1996 land +use map landuse96_28m and then classify a LANDSAT scene from +October of 2002. Example requires the nc_spm_08 location.

+
+# Align computation region to the scene
+g.region raster=lsat7_2002_10 -p
+
+# store VIZ, NIR, MIR into group/subgroup
+i.group group=lsat7_2002 subgroup=res_30m \
+    input=lsat7_2002_10,lsat7_2002_20,lsat7_2002_30,lsat7_2002_40,lsat7_2002_50,lsat7_2002_70
+
+# Now digitize training areas "training" with the digitizer
+# and convert to raster model with v.to.rast
+v.to.rast input=training output=training use=cat label_column=label
+# If you are just playing around and do not care about the accuracy of outcome,
+# just use one of existing maps instead e.g.
+# r.random input=landuse96_28m npoints=10000 raster=training -s
+
+# Train the SVM
+i.svm.train group=lsat7_2002 subgroup=res_30m \
+    trainingmap=training signaturefile=landuse96_rnd_points
+
+# Go to i.svm.predict for the next step.
+
+ +

SEE ALSO

+ + +Predict values: i.svm.predict
+Set semantic labels: r.support
+Other classification modules: i.maxlik, +i.smap +

+LIBSVM home page: LIBSVM - A +Library for Support Vector Machines + +

REFERENCES

+ +

Please cite both - LIBSVM and i.svm.

+
    +
  • + For i.svm.* modules:
    + Nartiss, M., & Melniks, R. (2023). Improving pixel-­based classification of GRASS + GIS with support vector machine. Transactions in GIS, 00, 1–16. + https://doi.org/10.1111/tgis.13102 +
  • +
  • + For LIBSVM:
    + Chang, C.-C., & Lin, C.-J. (2011). LIBSVM : a library for support vector machines. + ACM Transactions on Intelligent Systems and Technology, 2:27:1--27:27. +
  • +
+ +

AUTHORS

+ +Maris Nartiss, University of Latvia. + + diff --git a/imagery/i.svm.train/main.c b/imagery/i.svm.train/main.c new file mode 100644 index 00000000000..e8406b52393 --- /dev/null +++ b/imagery/i.svm.train/main.c @@ -0,0 +1,513 @@ + +/**************************************************************************** + * + * MODULE: i.svm.train + * AUTHOR(S): Maris Nartiss - maris.gis gmail.com + * PURPOSE: Trains Support Vector Machine classifier + * + * COPYRIGHT: (C) 2023 by Maris Nartiss and the GRASS Development Team + * + * This program is free software under the GNU General Public + * License (>=v2). Read the file COPYING that comes with GRASS + * for details. + * + * Development of this module was supported from + * science funding of University of Latvia (2020-2023). + * + *****************************************************************************/ +#include +#include + +#include +#if HAVE_SVM_H +#include +#elif HAVE_LIBSVM_SVM_H +#include +#endif + +#include +#include +#include +#include + +#include "fill.h" + +/* LIBSVM message wrapper */ +void print_func(const char *s) +{ + G_verbose_message("%s", s); +} + +int main(int argc, char *argv[]) +{ + struct GModule *module; + struct Option *opt_group, *opt_subgroup, *opt_sigfile, *opt_labels; + struct Option *opt_svm_type, *opt_svm_kernel; + struct Option *opt_svm_cache_size, *opt_svm_degree, *opt_svm_gamma, + *opt_svm_coef0, *opt_svm_eps, *opt_svm_cost, *opt_svm_nu, *opt_svm_p; + struct Flag *flag_svm_shrink, *flag_svm_prob; + + G_gisinit(argv[0]); + + module = G_define_module(); + G_add_keyword(_("imagery")); + G_add_keyword(_("svm")); + G_add_keyword(_("classification")); + G_add_keyword(_("training")); + module->label = _("Train a SVM"); + module->description = _("Train a Support Vector Machine"); + + opt_group = G_define_standard_option(G_OPT_I_GROUP); + /* GTC: SVM training input */ + opt_group->description = _("Maps with feature values (attributes)"); + + opt_subgroup = G_define_standard_option(G_OPT_I_SUBGROUP); + opt_subgroup->required = NO; + + opt_labels = G_define_standard_option(G_OPT_R_INPUT); + opt_labels->key = "trainingmap"; + opt_labels->description = _("Map with training labels or target values"); + + opt_sigfile = G_define_option(); + opt_sigfile->key = "signaturefile"; + opt_sigfile->type = TYPE_STRING; + opt_sigfile->key_desc = "name"; + opt_sigfile->required = YES; + opt_sigfile->gisprompt = "new,signatures/libsvm,sigfile"; + opt_sigfile->description = + _("Name for output file containing result signatures"); + + opt_svm_type = G_define_option(); + opt_svm_type->key = "type"; + opt_svm_type->type = TYPE_STRING; + opt_svm_type->key_desc = "name"; + opt_svm_type->required = NO; + opt_svm_type->options = "c_svc,nu_svc,one_class,epsilon_svr,nu_svr"; + opt_svm_type->answer = "c_svc"; + opt_svm_type->description = _("Type of SVM"); + opt_svm_type->guisection = _("SVM parameters"); + G_asprintf((char **)&(opt_svm_type->descriptions), + "c_svc;%s;" + "nu_svc;%s;" + "one_class;%s;" + "epsilon_svr;%s;" + "nu_svr;%s;", + /* GTC: SVM type */ + _("C-SVM classification"), + /* GTC: SVM type */ + _("nu-SVM classification"), + /* GTC: SVM type */ + _("one-class SVM"), + /* GTC: SVM type */ + _("epsilon-SVM regression"), + /* GTC: SVM type */ + _("nu-SVM regression")); + + opt_svm_kernel = G_define_option(); + opt_svm_kernel->key = "kernel"; + opt_svm_kernel->type = TYPE_STRING; + opt_svm_kernel->key_desc = "name"; + opt_svm_kernel->required = NO; + opt_svm_kernel->options = "linear,poly,rbf,sigmoid"; + opt_svm_kernel->answer = "rbf"; + opt_svm_kernel->description = _("SVM kernel type"); + opt_svm_kernel->guisection = _("SVM parameters"); + G_asprintf((char **)&(opt_svm_kernel->descriptions), + "linear;%s;" + "poly;%s;" + "rbf;%s;" + "sigmoid;%s;" /* "precomputed;%s;" */, + /* GTC: SVM kernel type */ + _("u'*v"), + /* GTC: SVM kernel type */ + _("(gamma*u'*v + coef0)^degree"), + /* GTC: SVM kernel type */ + _("exp(-gamma*|u-v|^2)"), + /* GTC: SVM kernel type */ + _("tanh(gamma*u'*v + coef0)")); + /* TODO: precomputed */ + + opt_svm_cache_size = G_define_option(); + opt_svm_cache_size->key = "cache"; + opt_svm_cache_size->type = TYPE_INTEGER; + opt_svm_cache_size->key_desc = "cache size"; + opt_svm_cache_size->required = NO; + opt_svm_cache_size->options = "1-"; + opt_svm_cache_size->answer = "512"; + opt_svm_cache_size->description = _("LIBSVM kernel cache size in MB"); + + opt_svm_degree = G_define_option(); + opt_svm_degree->key = "degree"; + opt_svm_degree->type = TYPE_INTEGER; + opt_svm_degree->key_desc = "value"; + opt_svm_degree->required = NO; + opt_svm_degree->options = "0-"; + opt_svm_degree->answer = "3"; + opt_svm_degree->description = _("Degree in kernel function"); + opt_svm_degree->guisection = _("SVM options"); + + opt_svm_gamma = G_define_option(); + opt_svm_gamma->key = "gamma"; + opt_svm_gamma->type = TYPE_DOUBLE; + opt_svm_gamma->key_desc = "value"; + opt_svm_gamma->required = NO; + opt_svm_gamma->answer = "1"; + opt_svm_gamma->description = _("Gamma in kernel function"); + opt_svm_gamma->guisection = _("SVM options"); + + opt_svm_coef0 = G_define_option(); + opt_svm_coef0->key = "coef0"; + opt_svm_coef0->type = TYPE_DOUBLE; + opt_svm_coef0->key_desc = "value"; + opt_svm_coef0->required = NO; + opt_svm_coef0->answer = "0"; + opt_svm_coef0->description = _("coef0 in kernel function"); + opt_svm_coef0->guisection = _("SVM options"); + + opt_svm_eps = G_define_option(); + opt_svm_eps->key = "eps"; + opt_svm_eps->type = TYPE_DOUBLE; + opt_svm_eps->key_desc = "value"; + opt_svm_eps->required = NO; + /* GTC: SVM epsilon */ + opt_svm_eps->label = _("Tolerance of termination criterion"); + opt_svm_eps->description = + _("Defaults to 0.00001 for nu-SVC and 0.001 for others"); + opt_svm_eps->guisection = _("SVM options"); + + opt_svm_cost = G_define_option(); + opt_svm_cost->key = "cost"; + opt_svm_cost->type = TYPE_DOUBLE; + opt_svm_cost->key_desc = "value"; + opt_svm_cost->required = NO; + opt_svm_cost->answer = "1"; + /* GTC: SVM C */ + opt_svm_cost->label = _("Cost of constraints violation"); + opt_svm_cost->description = + _("The parameter C of C-SVC, epsilon-SVR, and nu-SVR"); + opt_svm_cost->guisection = _("SVM options"); + + opt_svm_nu = G_define_option(); + opt_svm_nu->key = "nu"; + opt_svm_nu->type = TYPE_DOUBLE; + opt_svm_nu->key_desc = "value"; + opt_svm_nu->required = NO; + opt_svm_nu->answer = "0.5"; + opt_svm_nu->description = + _("The parameter nu of nu-SVC, one-class SVM, and nu-SVR"); + opt_svm_nu->guisection = _("SVM options"); + + opt_svm_p = G_define_option(); + opt_svm_p->key = "p"; + opt_svm_p->type = TYPE_DOUBLE; + opt_svm_p->key_desc = "value"; + opt_svm_p->required = NO; + opt_svm_p->answer = "0.1"; + opt_svm_p->description = _("The epsilon in epsilon-insensitive loss " + "function of epsilon-SVM regression"); + opt_svm_p->guisection = _("SVM options"); + + flag_svm_shrink = G_define_flag(); + flag_svm_shrink->key = 's'; + flag_svm_shrink->label = _("Do not use the shrinking heuristics"); + /* GTC: SVM flag description */ + flag_svm_shrink->description = + _("Defaults to use the shrinking heuristics"); + flag_svm_shrink->guisection = _("SVM options"); + + flag_svm_prob = G_define_flag(); + flag_svm_prob->key = 'p'; + flag_svm_prob->label = + _("Train a SVC or SVR model for probability estimates"); + /* GTC: SVM flag description */ + flag_svm_prob->description = _("Defaults to no probabilities in model"); + flag_svm_prob->guisection = _("SVM options"); + + if (G_parser(argc, argv)) + exit(EXIT_FAILURE); + + /* Input validation */ + /* Input maps */ + char name_group[GNAME_MAX], name_subgroup[GNAME_MAX], + name_sigfile[GNAME_MAX]; + char mapset_group[GMAPSET_MAX], mapset_subgroup[GMAPSET_MAX], + mapset_sigfile[GMAPSET_MAX]; + char sigfile_dir[GPATH_MAX]; + char in_path[GPATH_MAX], out_path[GPATH_MAX]; + if (G_unqualified_name(opt_group->answer, NULL, name_group, mapset_group) == + 0) + strcpy(mapset_group, G_mapset()); + if (opt_subgroup->answer && + G_unqualified_name(opt_subgroup->answer, NULL, name_subgroup, + mapset_subgroup) != 0 && + strcmp(mapset_subgroup, mapset_group) != 0) + G_fatal_error(_("Invalid subgroup <%s> provided"), + opt_subgroup->answer); + if (!I_find_group2(name_group, mapset_group)) { + G_fatal_error(_("Group <%s> not found in mapset <%s>"), name_group, + mapset_group); + } + if (opt_subgroup->answer && + !I_find_subgroup2(name_group, name_subgroup, mapset_group)) { + G_fatal_error(_("Subgroup <%s> in group <%s@%s> not found"), + name_subgroup, name_group, mapset_group); + } + + const char *mapset_labels; + char name_labels[GNAME_MAX + GMAPSET_MAX]; + strcpy(name_labels, opt_labels->answer); + if ((mapset_labels = G_find_raster(name_labels, "")) == NULL) { + G_fatal_error(_("Raster map <%s> not found"), opt_labels->answer); + } + + if (G_unqualified_name(opt_sigfile->answer, G_mapset(), name_sigfile, + mapset_sigfile) < 0) + G_fatal_error(_("<%s> does not match the current mapset"), + mapset_sigfile); + if (G_legal_filename(name_sigfile) < 0) + G_fatal_error(_("<%s> is an illegal file name"), name_sigfile); + + /* Input SVM parameters */ + /* TODO: Implement parameter checking duplicating svm_check_parameter() to + * generate translatable errors */ + struct svm_parameter parameters; + parameters.cache_size = atoi(opt_svm_cache_size->answer); + parameters.degree = atoi(opt_svm_degree->answer); + parameters.gamma = atof(opt_svm_gamma->answer); + parameters.coef0 = atof(opt_svm_coef0->answer); + parameters.C = atof(opt_svm_cost->answer); + parameters.nu = atof(opt_svm_nu->answer); + parameters.p = atof(opt_svm_p->answer); + + if (strcmp(opt_svm_type->answer, "c_svc") == 0) + parameters.svm_type = C_SVC; + else if (strcmp(opt_svm_type->answer, "nu_svc") == 0) + parameters.svm_type = NU_SVC; + else if (strcmp(opt_svm_type->answer, "one_class") == 0) + parameters.svm_type = ONE_CLASS; + else if (strcmp(opt_svm_type->answer, "epsilon_svr") == 0) + parameters.svm_type = EPSILON_SVR; + else if (strcmp(opt_svm_type->answer, "nu_svr") == 0) + parameters.svm_type = NU_SVR; + else + G_fatal_error(_("Wrong SVM type")); + + if (strcmp(opt_svm_kernel->answer, "linear") == 0) + parameters.kernel_type = LINEAR; + else if (strcmp(opt_svm_kernel->answer, "poly") == 0) + parameters.kernel_type = POLY; + else if (strcmp(opt_svm_kernel->answer, "rbf") == 0) + parameters.kernel_type = RBF; + else if (strcmp(opt_svm_kernel->answer, "sigmoid") == 0) + parameters.kernel_type = SIGMOID; + else if (strcmp(opt_svm_kernel->answer, "precomputed") == 0) + parameters.kernel_type = PRECOMPUTED; + else + G_fatal_error(_("Wrong kernel type")); + + if (opt_svm_eps->answer) + parameters.eps = atof(opt_svm_eps->answer); + else { + if (parameters.svm_type == NU_SVC) + parameters.eps = 0.00001; + else + parameters.eps = 0.001; + } + + if (flag_svm_shrink->answer) + parameters.shrinking = 0; + else + parameters.shrinking = 1; + + if (flag_svm_prob->answer) + parameters.probability = 1; + else + parameters.probability = 0; + + /* TODO: implement weight support */ + parameters.nr_weight = 0; + + /* Get bands */ + struct Ref group_ref; + if (opt_subgroup->answer) { + if (!I_get_subgroup_ref2(name_group, opt_subgroup->answer, mapset_group, + &group_ref)) { + G_fatal_error( + _("There was an error reading subgroup <%s> in group <%s@%s>"), + opt_subgroup->answer, name_group, mapset_group); + } + } + else { + if (!I_get_group_ref2(name_group, mapset_group, &group_ref)) { + G_fatal_error(_("There was an error reading group <%s@%s>"), + name_group, mapset_group); + } + } + if (group_ref.nfiles <= 0) { + if (opt_subgroup->answer) + G_fatal_error( + _("Subgroup <%s> in group <%s@%s> contains no raster maps."), + opt_subgroup->answer, name_group, mapset_group); + else + G_fatal_error(_("Group <%s@%s> contains no raster maps."), + name_group, mapset_group); + } + const char **semantic_labels = G_malloc(group_ref.nfiles * sizeof(char *)); + + /* Precompute values for mean normalization */ + DCELL *means, *ranges; + means = G_malloc(group_ref.nfiles * sizeof(DCELL)); + ranges = G_malloc(group_ref.nfiles * sizeof(DCELL)); + for (int n = 0; n < group_ref.nfiles; n++) { + struct Range crange; + struct FPRange fprange; + int cmin, cmax; + double dmin, dmax; + int ret; + semantic_labels[n] = Rast_get_semantic_label_or_name( + group_ref.file[n].name, group_ref.file[n].mapset); + /* Use raster range for value rescaling */ + ret = Rast_read_range(group_ref.file[n].name, group_ref.file[n].mapset, + &crange); + if (ret == 1) { + Rast_get_range_min_max(&crange, &cmin, &cmax); + /* Calculate mean without a risk of integer overflow */ + means[n] = + (cmin / 2.0) + (cmax / 2.0) + ((cmin % 2 + cmax % 2) / 2.0); + ranges[n] = (cmax - cmin) / 2.0; + } + else if (ret == 3) { + ret = Rast_read_fp_range(group_ref.file[n].name, + group_ref.file[n].mapset, &fprange); + if (ret != 1) { + G_fatal_error( + _("Unable to get value range for raster map <%s@%s>"), + group_ref.file[n].name, group_ref.file[n].mapset); + } + Rast_get_fp_range_min_max(&fprange, &dmin, &dmax); + means[n] = (dmin + dmax) / 2.0; + ranges[n] = (dmax - dmin) / 2.0; + } + else { + G_fatal_error(_("Unable to get value range for raster map <%s@%s>"), + group_ref.file[n].name, group_ref.file[n].mapset); + } + if (ranges[n] < GRASS_EPSILON) { + G_fatal_error(_("Invalid value range for raster map <%s@%s>"), + group_ref.file[n].name, group_ref.file[n].mapset); + } + } + + /* Pass LIBSVM messages through GRASS */ + svm_set_print_string_function(&print_func); + + /* Fill svm_problem struct with training data */ + struct svm_problem problem; + G_message(_("Reading training data")); + fill_problem(name_labels, mapset_labels, group_ref, means, ranges, + &problem); + + /* svm_check_parameter needs filled svm_problem struct thus checking only + * now */ + G_verbose_message("Checking SVM parametrization"); + const char *parameters_error = svm_check_parameter(&problem, ¶meters); + if (parameters_error) + G_fatal_error(_("SVM parameter validation returned an error: %s\n"), + parameters_error); + + /* Train model. Might take some time. */ + struct svm_model *model; + G_message(_("Starting training process (it will take some time; " + "no progress is printed, be patient)")); + model = svm_train(&problem, ¶meters); + + /* Write out training results */ + G_verbose_message("Writing out trained SVM"); + /* This is a specific case as file is not written by GRASS but + by LIBSVM and thus "normal" GRASS lib functions can not be used. */ + I_make_signatures_dir(I_SIGFILE_TYPE_LIBSVM); + I_get_signatures_dir(sigfile_dir, I_SIGFILE_TYPE_LIBSVM); + /* G_fopen_new_misc should create a directory for later use */ + FILE *misc_file = G_fopen_new_misc(sigfile_dir, "version", name_sigfile); + if (!misc_file) + G_fatal_error(_("Unable to write trained model to file '%s'."), + name_sigfile); + fprintf(misc_file, "1\n"); + fclose(misc_file); + + /* Write out SVM values in a signature file */ + G_file_name_misc(out_path, sigfile_dir, "sig", name_sigfile, G_mapset()); + int out_status = svm_save_model(out_path, model); + if (out_status != 0) { + G_fatal_error( + _("Unable to write trained model to file '%s'. Error code: %d"), + out_path, out_status); + } + svm_free_and_destroy_model(&model); + /* Write out semantic label info */ + misc_file = G_fopen_new_misc(sigfile_dir, "semantic_label", name_sigfile); + if (!misc_file) + G_fatal_error(_("Unable to write trained model to file '%s'."), + name_sigfile); + for (int n = 0; n < group_ref.nfiles; n++) { + fprintf(misc_file, "%s\n", semantic_labels[n]); + } + fclose(misc_file); + G_free(semantic_labels); + + /* Write out rescaling value as the same value has to be used for prediction + */ + misc_file = G_fopen_new_misc(sigfile_dir, "scale", name_sigfile); + if (!misc_file) + G_fatal_error(_("Unable to write trained model to file '%s'."), + name_sigfile); + for (int n = 0; n < group_ref.nfiles; n++) { + fprintf(misc_file, "%lf %lf\n", means[n], ranges[n]); + } + fclose(misc_file); + G_free(means); + G_free(ranges); + + /* Copy CATs file. Will be used for prediction result maps */ + struct Categories cats; + G_verbose_message("Copying category information"); + if (Rast_read_cats(name_labels, mapset_labels, &cats) == 0) { + /* Path to training label map CATs file */ + G_file_name(in_path, "cats", name_labels, mapset_labels); + G_file_name_misc(out_path, sigfile_dir, "cats", name_sigfile, + G_mapset()); + /* It is OK to call G_copy if source file doesn't exist */ + G_copy_file(in_path, out_path); + } + + /* Copy color file. Will be used for prediction result maps */ + G_verbose_message("Copying colour information"); + if (G_find_file2("colr", name_labels, mapset_labels)) { + /* Path to training label map colr file */ + G_file_name(in_path, "colr", name_labels, mapset_labels); + G_file_name_misc(out_path, sigfile_dir, "colr", name_sigfile, + G_mapset()); + /* It is OK to call G_copy if source file doesn't exist */ + G_copy_file(in_path, out_path); + } + + /* History will be appended to a prediction result map history */ + struct History history; + G_verbose_message("Writing out history"); + misc_file = G_fopen_new_misc(sigfile_dir, "history", name_sigfile); + if (misc_file != NULL) { + G_zero(&history, sizeof(struct History)); + /* Rast_command_history performs command wrapping */ + Rast_command_history(&history); + for (int i = 0; i < history.nlines; i++) + fprintf(misc_file, "%s\n", history.lines[i]); + fclose(misc_file); + } + else { + G_warning(_("Unable to write history information for <%s>"), + name_sigfile); + } + + G_message(_("Training successfully complete")); + exit(EXIT_SUCCESS); +} diff --git a/imagery/i.svm.train/testsuite/test_i_svm_train.py b/imagery/i.svm.train/testsuite/test_i_svm_train.py new file mode 100644 index 00000000000..d895ffcd22f --- /dev/null +++ b/imagery/i.svm.train/testsuite/test_i_svm_train.py @@ -0,0 +1,321 @@ +""" +Name: i.svm.train input & output tests +Purpose: Validates user input validation code and output generation + +Author: Maris Nartiss +Copyright: (C) 2023 by Maris Nartiss and the GRASS Development Team +Licence: This program is free software under the GNU General Public + License (>=v2). Read the file COPYING that comes with GRASS + for details. +""" +import os +import unittest +import ctypes +import shutil + +from grass.script import core as grass +from grass.gunittest.case import TestCase +from grass.gunittest.main import test +from grass.gunittest.gmodules import SimpleModule +from grass.pygrass.gis import Mapset +from grass.pygrass import utils + +from grass.lib.gis import ( + GPATH_MAX, + GNAME_MAX, + G_file_name_misc, +) +from grass.lib.imagery import ( + I_SIGFILE_TYPE_LIBSVM, + I_get_signatures_dir, + I_signatures_remove, +) + + +class IOValidationTest(TestCase): + """Test input validation and output generation with i.svm.train""" + + @classmethod + @unittest.skipIf(shutil.which("i.svm.train") is None, "i.svm.train not found.") + def setUpClass(cls): + cls.tmp_rasts = [] + cls.tmp_groups = [] + cls.tmp_sigs = [] + cls.mapset_name = Mapset().name + # Small region for small testing rasters + cls.use_temp_region() + cls.runModule("g.region", n=10, s=0, e=10, w=0, res=1) + # A raster without a semantic label + cls.rast1 = grass.tempname(10) + cls.runModule( + "r.mapcalc", expression=f"{cls.rast1}=rand(0.0,1)", seed=1, quiet=True + ) + cls.tmp_rasts.append(cls.rast1) + cls.runModule("r.colors", _map=cls.rast1, color="grey", quiet=True) + # A raster with a semantic label + cls.rast2 = grass.tempname(10) + cls.runModule( + "r.mapcalc", expression=f"{cls.rast2}=rand(0.0,1)", seed=1, quiet=True + ) + cls.tmp_rasts.append(cls.rast2) + cls.runModule( + "r.support", _map=cls.rast2, semantic_label="GRASS_RND1", quiet=True + ) + cls.rast3 = grass.tempname(10) + cls.runModule( + "r.mapcalc", expression=f"{cls.rast3}=rand(0.0,1)", seed=1, quiet=True + ) + cls.tmp_rasts.append(cls.rast3) + cls.runModule( + "r.support", _map=cls.rast3, semantic_label="GRASS_RND2", quiet=True + ) + cls.rast4 = grass.tempname(10) + cls.runModule( + "r.mapcalc", expression=f"{cls.rast4}=rand(0.0,1)", seed=1, quiet=True + ) + cls.tmp_rasts.append(cls.rast4) + cls.runModule( + "r.support", _map=cls.rast4, semantic_label="GRASS_RND3", quiet=True + ) + cls.rast5 = grass.tempname(10) + cls.runModule( + "r.mapcalc", expression=f"{cls.rast5}=rand(-1.0,1)", seed=1, quiet=True + ) + cls.tmp_rasts.append(cls.rast5) + cls.runModule( + "r.support", _map=cls.rast5, semantic_label="GRASS_RND4", quiet=True + ) + cls.rast6 = grass.tempname(10) + cls.runModule( + "r.mapcalc", + expression=( + f"{cls.rast6}=if(row() == 1 && col() == 1, 10, if(row() == 2 " + "&& col() == 2, -10, rand(-10.0,10)))" + ), + seed=1, + quiet=True, + ) + cls.tmp_rasts.append(cls.rast6) + cls.runModule( + "r.support", _map=cls.rast6, semantic_label="GRASS_RND5", quiet=True + ) + # An empty raster + cls.rast7 = grass.tempname(10) + cls.runModule("r.mapcalc", expression=f"{cls.rast7}=null()", quiet=True) + cls.tmp_rasts.append(cls.rast7) + cls.runModule( + "r.support", _map=cls.rast7, semantic_label="GRASS_RND7", quiet=True + ) + # An empty imagery group + cls.group1 = grass.tempname(10) + cls.runModule("i.group", group=cls.group1, _input=(cls.rast1,), quiet=True) + cls.tmp_groups.append(cls.group1) + cls.runModule( + "i.group", flags="r", group=cls.group1, _input=(cls.rast1,), quiet=True + ) + # A good imagery group + cls.group3 = grass.tempname(10) + cls.runModule( + "i.group", group=cls.group3, _input=(cls.rast2, cls.rast3), quiet=True + ) + cls.tmp_groups.append(cls.group3) + # Range test group + cls.group4 = grass.tempname(10) + cls.runModule( + "i.group", group=cls.group4, _input=(cls.rast5, cls.rast6), quiet=True + ) + cls.tmp_groups.append(cls.group4) + # A group with empty raster + cls.group5 = grass.tempname(10) + cls.runModule( + "i.group", group=cls.group5, _input=(cls.rast6, cls.rast7), quiet=True + ) + cls.tmp_groups.append(cls.group5) + + @classmethod + def tearDownClass(cls): + """Remove the temporary region and generated data""" + cls.del_temp_region() + for rast in cls.tmp_rasts: + cls.runModule("g.remove", flags="f", _type="raster", name=rast) + for group in cls.tmp_groups: + cls.runModule("g.remove", flags="f", _type="group", name=group) + for sig in cls.tmp_sigs: + I_signatures_remove(I_SIGFILE_TYPE_LIBSVM, sig) + + @unittest.skipIf(shutil.which("i.svm.train") is None, "i.svm.train not found.") + def test_empty_group(self): + """Empty imagery group handling""" + sigfile = grass.tempname(10) + isvm = SimpleModule( + "i.svm.train", + group=self.group1, + trainingmap=self.rast1, + signaturefile=sigfile, + quiet=True, + ) + self.assertModuleFail(isvm) + self.assertTrue(isvm.outputs.stderr) + self.assertIn(self.group1, isvm.outputs.stderr) + + @unittest.skipIf(shutil.which("i.svm.train") is None, "i.svm.train not found.") + def test_wrong_sigfile_mapset(self): + """Attempt to use FQ signature file name with not current mapset""" + sigfile = grass.tempname(10) + mapset = grass.tempname(10) + isvm = SimpleModule( + "i.svm.train", + group=self.group3, + trainingmap=self.rast1, + signaturefile=f"{sigfile}@{mapset}", + quiet=True, + ) + self.assertModuleFail(isvm) + self.assertTrue(isvm.outputs.stderr) + self.assertIn(mapset, isvm.outputs.stderr) + + @unittest.skipIf(shutil.which("i.svm.train") is None, "i.svm.train not found.") + def test_wrong_svm_param(self): + """Attempt to use invalid SVM parametres""" + sigfile = grass.tempname(10) + isvm = SimpleModule( + "i.svm.train", + group=self.group3, + trainingmap=self.rast1, + signaturefile=sigfile, + eps=-1, + quiet=True, + ) + self.assertModuleFail(isvm) + self.assertTrue(isvm.outputs.stderr) + self.assertIn("eps", isvm.outputs.stderr) + + @unittest.skipIf(shutil.which("i.svm.train") is None, "i.svm.train not found.") + def test_creation_of_misc_files(self): + """Validate creation of category, history and colour files""" + sigfile = grass.tempname(10) + csigdir = ctypes.create_string_buffer(GNAME_MAX) + I_get_signatures_dir(csigdir, I_SIGFILE_TYPE_LIBSVM) + sigdir = utils.decode(csigdir.value) + isvm = SimpleModule( + "i.svm.train", + group=self.group3, + trainingmap=self.rast1, + signaturefile=sigfile, + quiet=True, + ) + self.assertModule(isvm) + self.tmp_sigs.append(sigfile) + cpath = ctypes.create_string_buffer(GPATH_MAX) + G_file_name_misc(cpath, sigdir, "version", sigfile, self.mapset_name) + misc_file = utils.decode(cpath.value) + self.assertTrue(os.path.isfile(misc_file)) + G_file_name_misc(cpath, sigdir, "sig", sigfile, self.mapset_name) + misc_file = utils.decode(cpath.value) + self.assertTrue(os.path.isfile(misc_file)) + G_file_name_misc(cpath, sigdir, "cats", sigfile, self.mapset_name) + misc_file = utils.decode(cpath.value) + self.assertTrue(os.path.isfile(misc_file)) + G_file_name_misc(cpath, sigdir, "colr", sigfile, self.mapset_name) + misc_file = utils.decode(cpath.value) + self.assertTrue(os.path.isfile(misc_file)) + G_file_name_misc(cpath, sigdir, "history", sigfile, self.mapset_name) + misc_file = utils.decode(cpath.value) + self.assertTrue(os.path.isfile(misc_file)) + + @unittest.skipIf(shutil.which("i.svm.train") is None, "i.svm.train not found.") + def test_dont_fail_if_misc_files_missing(self): + """Colour file is missing but it should not cause a failure""" + sigfile = grass.tempname(10) + csigdir = ctypes.create_string_buffer(GNAME_MAX) + I_get_signatures_dir(csigdir, I_SIGFILE_TYPE_LIBSVM) + sigdir = utils.decode(csigdir.value) + isvm = SimpleModule( + "i.svm.train", + group=self.group3, + trainingmap=self.rast4, + signaturefile=sigfile, + quiet=True, + ) + self.assertModule(isvm) + self.tmp_sigs.append(sigfile) + cpath = ctypes.create_string_buffer(GPATH_MAX) + G_file_name_misc(cpath, sigdir, "version", sigfile, self.mapset_name) + misc_file = utils.decode(cpath.value) + self.assertTrue(os.path.isfile(misc_file)) + G_file_name_misc(cpath, sigdir, "sig", sigfile, self.mapset_name) + misc_file = utils.decode(cpath.value) + self.assertTrue(os.path.isfile(misc_file)) + G_file_name_misc(cpath, sigdir, "cats", sigfile, self.mapset_name) + misc_file = utils.decode(cpath.value) + self.assertTrue(os.path.isfile(misc_file)) + G_file_name_misc(cpath, sigdir, "colr", sigfile, self.mapset_name) + misc_file = utils.decode(cpath.value) + self.assertFalse(os.path.isfile(misc_file)) + G_file_name_misc(cpath, sigdir, "history", sigfile, self.mapset_name) + misc_file = utils.decode(cpath.value) + self.assertTrue(os.path.isfile(misc_file)) + + @unittest.skipIf(shutil.which("i.svm.train") is None, "i.svm.train not found.") + def test_rescaling(self): + """Raster values should be rescaled""" + sigfile = grass.tempname(10) + csigdir = ctypes.create_string_buffer(GNAME_MAX) + I_get_signatures_dir(csigdir, I_SIGFILE_TYPE_LIBSVM) + sigdir = utils.decode(csigdir.value) + isvm = SimpleModule( + "i.svm.train", + group=self.group4, + trainingmap=self.rast4, + signaturefile=sigfile, + quiet=True, + ) + self.assertModule(isvm) + self.tmp_sigs.append(sigfile) + cpath = ctypes.create_string_buffer(GPATH_MAX) + G_file_name_misc(cpath, sigdir, "version", sigfile, self.mapset_name) + misc_file = utils.decode(cpath.value) + self.assertTrue(os.path.isfile(misc_file)) + G_file_name_misc(cpath, sigdir, "sig", sigfile, self.mapset_name) + misc_file = utils.decode(cpath.value) + self.assertTrue(os.path.isfile(misc_file)) + G_file_name_misc(cpath, sigdir, "cats", sigfile, self.mapset_name) + misc_file = utils.decode(cpath.value) + self.assertTrue(os.path.isfile(misc_file)) + G_file_name_misc(cpath, sigdir, "colr", sigfile, self.mapset_name) + misc_file = utils.decode(cpath.value) + self.assertFalse(os.path.isfile(misc_file)) + G_file_name_misc(cpath, sigdir, "history", sigfile, self.mapset_name) + misc_file = utils.decode(cpath.value) + self.assertTrue(os.path.isfile(misc_file)) + G_file_name_misc(cpath, sigdir, "scale", sigfile, self.mapset_name) + misc_file = utils.decode(cpath.value) + self.assertTrue(os.path.isfile(misc_file)) + with open(misc_file) as rf: + lines = rf.readlines() + M, R = lines[0].strip().split(" ") + self.assertTrue(float(M) > -1 and float(M) < 1) + self.assertTrue(float(R) <= 2) + M, R = lines[1].strip().split(" ") + self.assertTrue(float(M) > -1 and float(M) < 1) + self.assertTrue(float(R) <= 20) + + @unittest.skipIf(shutil.which("i.svm.train") is None, "i.svm.train not found.") + def test_fail_on_empty_raster(self): + """One of imagery group rasters is empty""" + sigfile = grass.tempname(10) + isvm = SimpleModule( + "i.svm.train", + group=self.group5, + trainingmap=self.rast4, + signaturefile=sigfile, + quiet=True, + ) + self.assertModuleFail(isvm) + self.tmp_sigs.append(sigfile) + self.assertTrue(isvm.outputs.stderr) + self.assertIn("range", isvm.outputs.stderr) + + +if __name__ == "__main__": + test() diff --git a/include/Make/Platform.make.in b/include/Make/Platform.make.in index 93a6f07923c..c65ed28c686 100644 --- a/include/Make/Platform.make.in +++ b/include/Make/Platform.make.in @@ -162,6 +162,11 @@ BLASINC = @BLASINC@ LAPACKLIB = @LAPACKLIB@ LAPACKINC = @LAPACKINC@ +#LIBSVM +LIBSVM_LIB = @LIBSVM_LIB@ +LIBSVM_INC = @LIBSVM_INC@ +USE_LIBSVM = @USE_LIBSVM@ + #GDAL/OGR GDALLIBS = @GDAL_LIBS@ GDALCFLAGS = @GDAL_CFLAGS@ diff --git a/include/grass/config.h.in b/include/grass/config.h.in index 9b1d8cafd44..5730328ccb5 100644 --- a/include/grass/config.h.in +++ b/include/grass/config.h.in @@ -110,6 +110,12 @@ /* Define to 1 if you have the header file. */ #undef HAVE_LIBPQ_FE_H +/* Define to 1 if using LIBSVM. */ +#undef HAVE_LIBSVM + +/* Define to 1 if you have the header file. */ +#undef HAVE_LIBSVM_SVM_H + /* Define to 1 if you have the header file. */ #undef HAVE_LIMITS_H @@ -227,6 +233,9 @@ /* Define to 1 if you have the header file. */ #undef HAVE_STRING_H +/* Define to 1 if you have the header file. */ +#undef HAVE_SVM_H + /* Define to 1 if you have the header file. */ #undef HAVE_SYS_IOCTL_H diff --git a/include/grass/imagery.h b/include/grass/imagery.h index 0fceb05311a..1768ebb6f68 100644 --- a/include/grass/imagery.h +++ b/include/grass/imagery.h @@ -193,12 +193,12 @@ struct scdScattData { typedef enum { I_SIGFILE_TYPE_SIG, /*! Signature files used by i.maxlik */ I_SIGFILE_TYPE_SIGSET, /*! Signature files used by i.smap */ - + I_SIGFILE_TYPE_LIBSVM, /*! Signature files used by i.svm */ } I_SIGFILE_TYPE; #define SIGNATURE_TYPE_MIXED 1 /* Unused? */ #define I_SIGFILE_TYPE_COUNT \ - 2 /*! Total count of supported signature file types */ + 3 /*! Total count of supported signature file types */ #define GROUPFILE "CURGROUP" #define SUBGROUPFILE "CURSUBGROUP" diff --git a/lib/imagery/manage_signatures.c b/lib/imagery/manage_signatures.c index 5e095a96e15..26167b5d77c 100644 --- a/lib/imagery/manage_signatures.c +++ b/lib/imagery/manage_signatures.c @@ -34,6 +34,9 @@ void I_get_signatures_dir(char *dir, I_SIGFILE_TYPE type) else if (type == I_SIGFILE_TYPE_SIGSET) { sprintf(dir, "signatures%csigset", HOST_DIRSEP); } + else if (type == I_SIGFILE_TYPE_LIBSVM) { + sprintf(dir, "signatures%clibsvm", HOST_DIRSEP); + } else { G_fatal_error("Programming error: unknown signature file type"); } diff --git a/lib/imagery/testsuite/test_imagery_find.py b/lib/imagery/testsuite/test_imagery_find.py index 479eba62a28..a943e2d3520 100644 --- a/lib/imagery/testsuite/test_imagery_find.py +++ b/lib/imagery/testsuite/test_imagery_find.py @@ -22,6 +22,7 @@ from grass.lib.imagery import ( I_SIGFILE_TYPE_SIG, I_SIGFILE_TYPE_SIGSET, + I_SIGFILE_TYPE_LIBSVM, I_find_signature, I_find_signature2, ) @@ -37,6 +38,7 @@ def setUpClass(cls): # tools, we must ensure signature directories exist os.makedirs(f"{cls.mpath}/signatures/sig/", exist_ok=True) os.makedirs(f"{cls.mpath}/signatures/sigset/", exist_ok=True) + os.makedirs(f"{cls.mpath}/signatures/libsvm/", exist_ok=True) cls.sig_name1 = tempname(10) cls.sig_dir1 = f"{cls.mpath}/signatures/sigset/{cls.sig_name1}" os.makedirs(cls.sig_dir1) @@ -47,6 +49,11 @@ def setUpClass(cls): os.makedirs(cls.sig_dir2) cls.sigdirs.append(cls.sig_dir2) open(f"{cls.sig_dir2}/sig", "a").close() + cls.sig_name3 = tempname(10) + cls.sig_dir3 = f"{cls.mpath}/signatures/libsvm/{cls.sig_name3}" + os.makedirs(cls.sig_dir3) + cls.sigdirs.append(cls.sig_dir3) + open(f"{cls.sig_dir3}/sig", "a").close() @classmethod def tearDownClass(cls): @@ -101,6 +108,30 @@ def test_find_sigset(self): ret = I_find_signature(I_SIGFILE_TYPE_SIGSET, self.sig_name1, "PERMANENT") self.assertFalse(ret) + def test_find_libsvm(self): + # Non existing without a mapset + ret = I_find_signature(I_SIGFILE_TYPE_LIBSVM, tempname(10), None) + self.assertFalse(ret) + # Non existing with a mapset + ret = I_find_signature(I_SIGFILE_TYPE_LIBSVM, tempname(10), self.mapset_name) + self.assertFalse(ret) + # Libsvm with sig type should equal non existing + ret = I_find_signature(I_SIGFILE_TYPE_SIG, self.sig_name3, self.mapset_name) + self.assertFalse(ret) + # Existing without a mapset + ret = I_find_signature(I_SIGFILE_TYPE_LIBSVM, self.sig_name3, None) + self.assertTrue(ret) + ms = utils.decode(ret) + self.assertEqual(ms, self.mapset_name) + # Existing with a mapset + ret = I_find_signature(I_SIGFILE_TYPE_LIBSVM, self.sig_name3, self.mapset_name) + self.assertTrue(ret) + ms = utils.decode(ret) + self.assertEqual(ms, self.mapset_name) + # Existing in a different mapset should fail + ret = I_find_signature(I_SIGFILE_TYPE_LIBSVM, self.sig_name3, "PERMANENT") + self.assertFalse(ret) + def test_find2_sig(self): # Non existing without a mapset ret = I_find_signature2(I_SIGFILE_TYPE_SIG, tempname(10), None) @@ -149,6 +180,30 @@ def test_find2_sigset(self): ret = I_find_signature2(I_SIGFILE_TYPE_SIGSET, self.sig_name1, "PERMANENT") self.assertFalse(ret) + def test_find2_libsvm(self): + # Non existing without a mapset + ret = I_find_signature2(I_SIGFILE_TYPE_LIBSVM, tempname(10), None) + self.assertFalse(ret) + # Non existing with a mapset + ret = I_find_signature2(I_SIGFILE_TYPE_LIBSVM, tempname(10), self.mapset_name) + self.assertFalse(ret) + # Libsvm with sig type should equal non existing + ret = I_find_signature2(I_SIGFILE_TYPE_SIG, self.sig_name3, self.mapset_name) + self.assertFalse(ret) + # Existing without a mapset + ret = I_find_signature2(I_SIGFILE_TYPE_LIBSVM, self.sig_name3, None) + self.assertTrue(ret) + ms = utils.decode(ret) + self.assertEqual(ms, self.mapset_name) + # Existing with a mapset + ret = I_find_signature2(I_SIGFILE_TYPE_LIBSVM, self.sig_name3, self.mapset_name) + self.assertTrue(ret) + ms = utils.decode(ret) + self.assertEqual(ms, self.mapset_name) + # Existing in a different mapset should fail + ret = I_find_signature2(I_SIGFILE_TYPE_LIBSVM, self.sig_name3, "PERMANENT") + self.assertFalse(ret) + if __name__ == "__main__": test() diff --git a/lib/imagery/testsuite/test_imagery_signature_management.py b/lib/imagery/testsuite/test_imagery_signature_management.py index ae926afd6e8..b2ec954a74e 100644 --- a/lib/imagery/testsuite/test_imagery_signature_management.py +++ b/lib/imagery/testsuite/test_imagery_signature_management.py @@ -30,6 +30,7 @@ from grass.lib.imagery import ( I_SIGFILE_TYPE_SIG, I_SIGFILE_TYPE_SIGSET, + I_SIGFILE_TYPE_LIBSVM, I_find_signature, I_signatures_remove, I_signatures_copy, @@ -41,7 +42,7 @@ ) -class GetSignaturesElementTestCase(TestCase): +class GetSignaturesDirTestCase(TestCase): def test_get_sig(self): cdir = ctypes.create_string_buffer(GNAME_MAX) I_get_signatures_dir(cdir, I_SIGFILE_TYPE_SIG) @@ -52,8 +53,13 @@ def test_get_sigset(self): I_get_signatures_dir(cdir, I_SIGFILE_TYPE_SIGSET) self.assertEqual(utils.decode(cdir.value), f"signatures{HOST_DIRSEP}sigset") + def test_get_libsvm(self): + elem = ctypes.create_string_buffer(GNAME_MAX) + I_get_signatures_dir(elem, I_SIGFILE_TYPE_LIBSVM) + self.assertEqual(utils.decode(elem.value), f"signatures{HOST_DIRSEP}libsvm") -class MakeSignaturesElementTestCase(TestCase): + +class MakeSignaturesDirTestCase(TestCase): @classmethod def setUpClass(cls): cls.org_mapset = Mapset() @@ -90,6 +96,17 @@ def test_make_sigset(self): os.path.isdir(os.path.join(self.tmp_mapset_path, "signatures", "sigset")) ) + def test_make_libsvm(self): + I_make_signatures_dir(I_SIGFILE_TYPE_LIBSVM) + self.assertTrue( + os.path.isdir(os.path.join(self.tmp_mapset_path, "signatures", "libsvm")) + ) + # There should not be any side effects of calling function multiple times + I_make_signatures_dir(I_SIGFILE_TYPE_LIBSVM) + self.assertTrue( + os.path.isdir(os.path.join(self.tmp_mapset_path, "signatures", "libsvm")) + ) + class SignaturesRemoveTestCase(TestCase): @classmethod @@ -101,6 +118,7 @@ def setUpClass(cls): # tools, we must ensure signature directories exist os.makedirs(f"{cls.mpath}/signatures/sig/", exist_ok=True) os.makedirs(f"{cls.mpath}/signatures/sigset/", exist_ok=True) + os.makedirs(f"{cls.mpath}/signatures/libsvm/", exist_ok=True) @classmethod def tearDownClass(cls): @@ -266,6 +284,86 @@ def test_remove_nonexisting_sigset(self): ms = utils.decode(ret) self.assertEqual(ms, self.mapset_name) + def test_remove_existing_libsvm(self): + # This test will fail if run in PERMANENT! + # Set up files and mark for clean-up + sig_name1 = tempname(10) + sig_dir1 = f"{self.mpath}/signatures/libsvm/{sig_name1}" + os.makedirs(sig_dir1) + sigfile_name1 = f"{sig_dir1}/sig" + open(sigfile_name1, "a").close() + self.sigdirs.append(sig_dir1) + sig_name2 = tempname(10) + sig_dir2 = f"{self.mpath}/signatures/libsvm/{sig_name2}" + os.makedirs(sig_dir2) + sigfile_name2 = f"{sig_dir2}/sig" + open(sigfile_name2, "a").close() + self.sigdirs.append(sig_dir2) + sig_name3 = tempname(10) + sig_dir3 = f"{self.mpath}/signatures/sig/{sig_name3}" + os.makedirs(sig_dir3) + sigfile_name3 = f"{sig_dir3}/sig" + open(sigfile_name3, "a").close() + self.sigdirs.append(sig_dir3) + # Try to remove with wrong type + ret = I_signatures_remove(I_SIGFILE_TYPE_SIG, sig_name2) + self.assertEqual(ret, 1) + # Try to remove with wrong mapset + ret = I_signatures_remove(I_SIGFILE_TYPE_LIBSVM, f"{sig_name2}@PERMANENT") + self.assertEqual(ret, 1) + # Should be still present + ret = I_find_signature(I_SIGFILE_TYPE_LIBSVM, sig_name2, self.mapset_name) + self.assertTrue(ret) + ms = utils.decode(ret) + self.assertEqual(ms, self.mapset_name) + # Now remove with correct type + ret = I_signatures_remove(I_SIGFILE_TYPE_LIBSVM, sig_name2) + self.assertEqual(ret, 0) + # removed should be gone + ret = I_find_signature(I_SIGFILE_TYPE_LIBSVM, sig_name2, self.mapset_name) + self.assertFalse(ret) + # Others should remain + ret = I_find_signature(I_SIGFILE_TYPE_LIBSVM, sig_name1, self.mapset_name) + self.assertTrue(ret) + ms = utils.decode(ret) + self.assertEqual(ms, self.mapset_name) + ret = I_find_signature(I_SIGFILE_TYPE_SIG, sig_name3, self.mapset_name) + self.assertTrue(ret) + ms = utils.decode(ret) + self.assertEqual(ms, self.mapset_name) + + def test_remove_nonexisting_libsvm(self): + # Set up files and mark for clean-up + sig_name1 = tempname(10) + sig_dir1 = f"{self.mpath}/signatures/sigset/{sig_name1}" + os.makedirs(sig_dir1) + sigfile_name1 = f"{sig_dir1}/sig" + open(sigfile_name1, "a").close() + self.sigdirs.append(sig_dir1) + sig_name2 = tempname(10) + # Do not create sig_name2 matching file + sig_name3 = tempname(10) + sig_dir3 = f"{self.mpath}/signatures/libsvm/{sig_name3}" + os.makedirs(sig_dir3) + sigfile_name3 = f"{sig_dir3}/sig" + open(sigfile_name3, "a").close() + self.sigdirs.append(sig_dir3) + # Now remove one (should fail as file is absent) + ret = I_signatures_remove(I_SIGFILE_TYPE_LIBSVM, sig_name2) + self.assertEqual(ret, 1) + # removed should be still absent + ret = I_find_signature(I_SIGFILE_TYPE_LIBSVM, sig_name2, self.mapset_name) + self.assertFalse(ret) + # All others should remain + ret = I_find_signature(I_SIGFILE_TYPE_SIGSET, sig_name1, self.mapset_name) + self.assertTrue(ret) + ms = utils.decode(ret) + self.assertEqual(ms, self.mapset_name) + ret = I_find_signature(I_SIGFILE_TYPE_LIBSVM, sig_name3, self.mapset_name) + self.assertTrue(ret) + ms = utils.decode(ret) + self.assertEqual(ms, self.mapset_name) + class SignaturesCopyTestCase(TestCase): @classmethod @@ -277,6 +375,7 @@ def setUpClass(cls): # tools, we must ensure signature directories exist os.makedirs(f"{cls.mpath}/signatures/sig/", exist_ok=True) os.makedirs(f"{cls.mpath}/signatures/sigset/", exist_ok=True) + os.makedirs(f"{cls.mpath}/signatures/libsvm/", exist_ok=True) # A mapset with a random name cls.src_mapset_name = tempname(10) G_make_mapset(None, None, cls.src_mapset_name) @@ -300,6 +399,14 @@ def setUpClass(cls): f = open(f"{cls.src_sigset_dir}/sig", "w") f.write("A sigset file") f.close() + os.makedirs(f"{cls.src_mapset_path}/signatures/libsvm/") + cls.src_libsvm = tempname(10) + cls.src_libsvm_dir = f"{cls.src_mapset_path}/signatures/libsvm/{cls.src_libsvm}" + os.makedirs(cls.src_libsvm_dir) + cls.sigdirs.append(cls.src_libsvm_dir) + f = open(f"{cls.src_libsvm_dir}/sig", "w") + f.write("A libsvm file") + f.close() @classmethod def tearDownClass(cls): @@ -327,6 +434,12 @@ def test_sigset_does_not_exist(self): ) self.assertEqual(ret, 1) + def test_libsvm_does_not_exist(self): + ret = I_signatures_copy( + I_SIGFILE_TYPE_LIBSVM, tempname(10), self.mapset_name, tempname(10) + ) + self.assertEqual(ret, 1) + def test_success_unqualified_sig(self): dst = tempname(10) ret = I_find_signature(I_SIGFILE_TYPE_SIG, dst, self.mapset_name) @@ -409,6 +522,49 @@ def test_success_fq_sigset(self): os.path.isfile(f"{self.mpath}/signatures/sigset/{dst_name}/sig") ) + def test_success_unqualified_libsvm(self): + dst = tempname(10) + ret = I_find_signature(I_SIGFILE_TYPE_LIBSVM, dst, self.mapset_name) + self.assertFalse(ret) + ret = I_find_signature( + I_SIGFILE_TYPE_LIBSVM, self.src_libsvm, self.src_mapset_name + ) + self.assertTrue(ret) + ret = I_signatures_copy( + I_SIGFILE_TYPE_LIBSVM, self.src_libsvm, self.src_mapset_name, dst + ) + self.sigdirs.append(f"{self.mpath}/signatures/libsvm/{dst}") + self.assertEqual(ret, 0) + ret = I_find_signature(I_SIGFILE_TYPE_LIBSVM, dst, self.mapset_name) + self.assertTrue(ret) + ms = utils.decode(ret) + self.assertEqual(ms, self.mapset_name) + self.assertTrue(os.path.isfile(f"{self.mpath}/signatures/libsvm/{dst}/sig")) + + def test_success_fq_libsvm(self): + dst = tempname(10) + dst_dir = f"{self.mpath}/signatures/libsvm/{dst}" + self.sigdirs.append(dst_dir) + dst = dst + "@" + self.mapset_name + ret = I_find_signature(I_SIGFILE_TYPE_LIBSVM, dst, self.mapset_name) + self.assertFalse(ret) + ret = I_find_signature( + I_SIGFILE_TYPE_LIBSVM, self.src_libsvm, self.src_mapset_name + ) + self.assertTrue(ret) + ret = I_signatures_copy( + I_SIGFILE_TYPE_LIBSVM, + self.src_libsvm + "@" + self.src_mapset_name, + self.src_mapset_name, + dst, + ) + self.assertEqual(ret, 0) + ret = I_find_signature(I_SIGFILE_TYPE_LIBSVM, dst, self.mapset_name) + self.assertTrue(ret) + ms = utils.decode(ret) + self.assertEqual(ms, self.mapset_name) + self.assertTrue(os.path.isfile(f"{dst_dir}/sig")) + class SignaturesRenameTestCase(TestCase): @classmethod @@ -420,6 +576,7 @@ def setUpClass(cls): # tools, we must ensure signature directories exist os.makedirs(f"{cls.mpath}/signatures/sig/", exist_ok=True) os.makedirs(f"{cls.mpath}/signatures/sigset/", exist_ok=True) + os.makedirs(f"{cls.mpath}/signatures/libsvm/", exist_ok=True) @classmethod def tearDownClass(cls): @@ -444,6 +601,10 @@ def test_sigset_does_not_exist(self): ret = I_signatures_rename(I_SIGFILE_TYPE_SIGSET, tempname(10), tempname(10)) self.assertEqual(ret, 1) + def test_libsvm_does_not_exist(self): + ret = I_signatures_rename(I_SIGFILE_TYPE_LIBSVM, tempname(10), tempname(10)) + self.assertEqual(ret, 1) + def test_success_unqualified_sig(self): src_sig = tempname(10) sig_dir = f"{self.mpath}/signatures/sig/{src_sig}" @@ -544,6 +705,56 @@ def test_success_fq_sigset(self): os.path.isfile(f"{self.mpath}/signatures/sigset/{dst_name}/sig") ) + def test_success_unqualified_libsvm(self): + src_sig = tempname(10) + sig_dir = f"{self.mpath}/signatures/libsvm/{src_sig}" + os.makedirs(sig_dir) + self.sigdirs.append(sig_dir) + f = open(f"{sig_dir}/sig", "w") + f.write("A libsvm file") + f.close() + dst = tempname(10) + self.sigdirs.append(f"{self.mpath}/signatures/libsvm/{dst}") + ret = I_find_signature(I_SIGFILE_TYPE_LIBSVM, dst, self.mapset_name) + self.assertFalse(ret) + ret = I_find_signature(I_SIGFILE_TYPE_LIBSVM, src_sig, self.mapset_name) + self.assertTrue(ret) + ret = I_signatures_rename(I_SIGFILE_TYPE_LIBSVM, src_sig, dst) + self.assertEqual(ret, 0) + ret = I_find_signature(I_SIGFILE_TYPE_LIBSVM, dst, self.mapset_name) + self.assertTrue(ret) + ms = utils.decode(ret) + self.assertEqual(ms, self.mapset_name) + self.assertTrue(os.path.isfile(f"{self.mpath}/signatures/libsvm/{dst}/sig")) + + def test_success_fq_libsvm(self): + src_sig = tempname(10) + sig_dir = f"{self.mpath}/signatures/libsvm/{src_sig}" + os.makedirs(sig_dir) + self.sigdirs.append(sig_dir) + f = open(f"{sig_dir}/sig", "w") + f.write("A libsvm file") + f.close() + dst = tempname(10) + dst_dir = f"{self.mpath}/signatures/libsvm/{dst}" + self.sigdirs.append(dst_dir) + dst = dst + "@" + self.mapset_name + ret = I_find_signature(I_SIGFILE_TYPE_LIBSVM, dst, self.mapset_name) + self.assertFalse(ret) + ret = I_find_signature(I_SIGFILE_TYPE_LIBSVM, src_sig, self.mapset_name) + self.assertTrue(ret) + ret = I_signatures_rename( + I_SIGFILE_TYPE_LIBSVM, + src_sig + "@" + self.mapset_name, + dst, + ) + self.assertEqual(ret, 0) + ret = I_find_signature(I_SIGFILE_TYPE_LIBSVM, dst, self.mapset_name) + self.assertTrue(ret) + ms = utils.decode(ret) + self.assertEqual(ms, self.mapset_name) + self.assertTrue(os.path.isfile(f"{dst_dir}/sig")) + class SignaturesListByTypeTestCase(TestCase): @classmethod @@ -556,6 +767,7 @@ def setUpClass(cls): # tools, we must ensure signature directories exist os.makedirs(f"{cls.mpath}/signatures/sig/", exist_ok=True) os.makedirs(f"{cls.mpath}/signatures/sigset/", exist_ok=True) + os.makedirs(f"{cls.mpath}/signatures/libsvm/", exist_ok=True) # A mapset with a random name cls.rnd_mapset_name = tempname(10) G_make_mapset(None, None, cls.rnd_mapset_name) @@ -564,6 +776,7 @@ def setUpClass(cls): ) os.makedirs(f"{cls.rnd_mapset_path}/signatures/sig/") os.makedirs(f"{cls.rnd_mapset_path}/signatures/sigset/") + os.makedirs(f"{cls.rnd_mapset_path}/signatures/libsvm/") @classmethod def tearDownClass(cls): @@ -586,6 +799,11 @@ def test_no_sigs_at_all(self): ) self.assertEqual(ret, 0) I_free_signatures_list(ret, ctypes.byref(sig_list)) + ret = I_signatures_list_by_type( + I_SIGFILE_TYPE_LIBSVM, self.rnd_mapset_name, ctypes.byref(sig_list) + ) + self.assertEqual(ret, 0) + I_free_signatures_list(ret, ctypes.byref(sig_list)) def test_sig_in_different_mapset(self): # Should return 0 signatures from a different mapset @@ -619,6 +837,22 @@ def test_sig_in_different_mapset(self): shutil.rmtree(sig_dir) self.assertEqual(ret, 0) I_free_signatures_list(ret, ctypes.byref(sig_list)) + # Libsvm type + local_sig = tempname(10) + sig_dir = f"{self.mpath}/signatures/libsvm/{local_sig}" + os.makedirs(sig_dir) + sig_file = f"{sig_dir}/sig" + self.sigdirs.append(sig_dir) + f = open(sig_file, "w") + f.write("A libsvm file") + f.close() + sig_list = self.list_ptr() + ret = I_signatures_list_by_type( + I_SIGFILE_TYPE_LIBSVM, self.rnd_mapset_name, ctypes.byref(sig_list) + ) + os.remove(sig_file) + self.assertEqual(ret, 0) + I_free_signatures_list(ret, ctypes.byref(sig_list)) def test_single_sig(self): # Case when only a single signature file is present @@ -657,6 +891,23 @@ def test_single_sig(self): val = utils.decode(sigset_list[0]) self.assertEqual(val, f"{rnd_sigset}@{self.rnd_mapset_name}") I_free_signatures_list(ret, ctypes.byref(sigset_list)) + # libsvm type + rnd_sig = tempname(10) + sig_dir = f"{self.rnd_mapset_path}/signatures/libsvm/{rnd_sig}" + os.makedirs(sig_dir) + sig_file = f"{sig_dir}/sig" + f = open(sig_file, "w") + f.write("A libsvm file") + f.close() + sig_list = self.list_ptr() + ret = I_signatures_list_by_type( + I_SIGFILE_TYPE_LIBSVM, self.rnd_mapset_name, ctypes.byref(sig_list) + ) + shutil.rmtree(sig_dir) + self.assertEqual(ret, 1) + val = utils.decode(sig_list[0]) + self.assertEqual(val, f"{rnd_sig}@{self.rnd_mapset_name}") + I_free_signatures_list(ret, ctypes.byref(sig_list)) def test_multiple_sigs(self): # Should result into a multiple sigs returned @@ -719,6 +970,36 @@ def test_multiple_sigs(self): self.assertIn(utils.decode(sigset_list[0]), golden) self.assertIn(utils.decode(sigset_list[1]), golden) I_free_signatures_list(ret, ctypes.byref(sigset_list)) + # libsvm type + rnd_sig1 = tempname(10) + sig_dir1 = f"{self.rnd_mapset_path}/signatures/libsvm/{rnd_sig1}" + os.makedirs(sig_dir1) + sig_file1 = f"{sig_dir1}/sig" + f = open(sig_file1, "w") + f.write("A libsvm file") + f.close() + rnd_sig2 = tempname(10) + sig_dir2 = f"{self.rnd_mapset_path}/signatures/libsvm/{rnd_sig2}" + os.makedirs(sig_dir2) + sig_file2 = f"{sig_dir2}/sig" + f = open(sig_file2, "w") + f.write("A libsvm file") + f.close() + # POINTER(POINTER(c_char)) + sig_list = self.list_ptr() + ret = I_signatures_list_by_type( + I_SIGFILE_TYPE_LIBSVM, self.rnd_mapset_name, ctypes.byref(sig_list) + ) + shutil.rmtree(sig_dir1) + shutil.rmtree(sig_dir2) + self.assertEqual(ret, 2) + golden = ( + f"{rnd_sig1}@{self.rnd_mapset_name}", + f"{rnd_sig2}@{self.rnd_mapset_name}", + ) + self.assertIn(utils.decode(sig_list[0]), golden) + self.assertIn(utils.decode(sig_list[1]), golden) + I_free_signatures_list(ret, ctypes.byref(sig_list)) def test_multiple_sigs_multiple_mapsets(self): # Test searching in multiple mapsets. Identical to SIGSET case @@ -822,6 +1103,58 @@ def test_multiple_sigsets_multiple_mapsets(self): self.assertIn(golden[1], ret_list) I_free_signatures_list(ret, ctypes.byref(sig_list)) + def test_multiple_libsvms_multiple_mapsets(self): + # Test searching in multiple mapsets. Identical to SIG and SIGSET case + rnd_sig1 = tempname(10) + sig_dir1 = f"{self.rnd_mapset_path}/signatures/libsvm/{rnd_sig1}" + os.makedirs(sig_dir1) + sig_file1 = f"{sig_dir1}/sig" + f = open(sig_file1, "w") + f.write("A libsvm file") + f.close() + rnd_sig2 = tempname(10) + sig_dir2 = f"{self.mpath}/signatures/libsvm/{rnd_sig2}" + os.makedirs(sig_dir2) + sig_file2 = f"{sig_dir2}/sig" + f = open(sig_file2, "w") + f.write("A libsvm file") + f.close() + self.sigdirs.append(sig_dir2) + sig_list = self.list_ptr() + ret = I_signatures_list_by_type( + I_SIGFILE_TYPE_LIBSVM, None, ctypes.byref(sig_list) + ) + # As temporary mapset is not in the search path, there must be + # at least one sig file present + # There could be more sigs if this is not an empty mapset + self.assertTrue(ret >= 1) + ret_list = list(map(utils.decode, sig_list[:ret])) + golden = ( + f"{rnd_sig1}@{self.rnd_mapset_name}", + f"{rnd_sig2}@{self.mapset_name}", + ) + self.assertIn(golden[1], ret_list) + # Temporary mapset is not in the search path: + self.assertNotIn(golden[0], ret_list) + I_free_signatures_list(ret, ctypes.byref(sig_list)) + # Add temporary mapset to search path and re-run test + grass.run_command("g.mapsets", mapset=self.rnd_mapset_name, operation="add") + # Search path is cached for this run => reset! + G_reset_mapsets() + ret = I_signatures_list_by_type( + I_SIGFILE_TYPE_LIBSVM, None, ctypes.byref(sig_list) + ) + grass.run_command("g.mapsets", mapset=self.rnd_mapset_name, operation="remove") + G_reset_mapsets() + shutil.rmtree(sig_dir1) + shutil.rmtree(sig_dir2) + # There could be more sigs if this is not an empty mapset + self.assertTrue(ret >= 2) + ret_list = list(map(utils.decode, sig_list[:ret])) + self.assertIn(golden[0], ret_list) + self.assertIn(golden[1], ret_list) + I_free_signatures_list(ret, ctypes.byref(sig_list)) + if __name__ == "__main__": test()