diff --git a/appveyor.yml b/appveyor.yml
index 5659fac0..a1680938 100644
--- a/appveyor.yml
+++ b/appveyor.yml
@@ -3,8 +3,8 @@ environment:
   matrix:
     - TARGET_ARCH: x64
       CONDA_PY: 36
-      CONDA_INSTALL_LOCN: C:\Miniconda3-x64
-      GDAL_DATA: C:\Miniconda3-x64\Library\share\gdal
+      CONDA_INSTALL_LOCN: C:\Miniconda36-x64
+      GDAL_DATA: C:\Miniconda36-x64\Library\share\gdal
       PROJECT_DIR: C:\projects\spatialist
 
 platform:
@@ -42,3 +42,6 @@ build: off
 
 test_script:
   - coverage run --source spatialist/ -m pytest
+
+#after_test:
+#  - coveralls
diff --git a/spatialist/ancillary.py b/spatialist/ancillary.py
index da451f55..1d0df89b 100644
--- a/spatialist/ancillary.py
+++ b/spatialist/ancillary.py
@@ -6,6 +6,9 @@ This script gathers central functions and classes for general applications
 """
 import sys
+import dill
+import tempfile
+import platform
 import tblib.pickling_support
 
 if sys.version_info >= (3, 0):
@@ -267,44 +270,84 @@ def multicore(function, cores, multiargs, **singleargs):
     processlist = [dictmerge(dict([(arg, multiargs[arg][i]) for arg in multiargs]), singleargs)
                    for i in range(len(multiargs[list(multiargs.keys())[0]]))]
 
-    # block printing of the executed function
-    results = None
-
-    def wrapper(**kwargs):
-        try:
-            return function(**kwargs)
-        except Exception as e:
-            return ExceptionWrapper(e)
-
-    with HiddenPrints():
-        # start pool of processes and do the work
-        try:
-            pool = mp.Pool(processes=cores)
-        except NameError:
-            raise ImportError("package 'pathos' could not be imported")
-        results = pool.imap(lambda x: wrapper(**x), processlist)
-        pool.close()
-        pool.join()
-
-    i = 0
-    out = []
-    for item in results:
-        if isinstance(item, ExceptionWrapper):
-            item.ee = type(item.ee)(str(item.ee) +
-                                    "\n(called function '{}' with args {})"
-                                    .format(function.__name__, processlist[i]))
-            raise (item.re_raise())
-        out.append(item)
-        i += 1
-
-    # evaluate the return of the processing function;
-    # if any value is not None then the whole list of results is returned
-    eval = [x for x in out if x is not None]
-    if len(eval) == 0:
-        return None
+    if platform.system() == 'Windows':
+
+        # on Windows, parallel processing must strictly be guarded by an "if __name__ == '__main__':" block;
+        # the work is therefore outsourced to a helper script and all of its input is serialized so that objects can be shared:
+        # https://stackoverflow.com/questions/38236211/why-multiprocessing-process-behave-differently-on-windows-and-linux-for-global-o
+
+        # a helper script to perform the parallel processing
+        script = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'multicore_helper.py')
+
+        # a temporary file to which the serialized function variables are written
+        tmpfile = os.path.join(tempfile.gettempdir(), 'spatialist_dump')
+
+        # check whether everything can be serialized
+        if not dill.pickles([function, cores, processlist]):
+            raise RuntimeError('cannot fully serialize function arguments;\n'
+                               ' see https://github.com/uqfoundation/dill for supported types')
+
+        # write the serialized variables
+        with open(tmpfile, 'wb') as tmp:
+            dill.dump([function, cores, processlist], tmp, byref=False)
+
+        # run the helper script
+        proc = sp.Popen([sys.executable, script], stdin=sp.PIPE, stderr=sp.PIPE)
+        out, err = proc.communicate()
+        if proc.returncode != 0:
+            raise RuntimeError(err.decode())
+
+        # retrieve the serialized output which the helper script has written to the temporary file
+        with open(tmpfile, 'rb') as tmp:
+            result = dill.load(tmp)
+        return result
     else:
-        return out
+        results = None
+
+        def wrapper(**kwargs):
+            try:
+                return function(**kwargs)
+            except Exception as e:
+                return ExceptionWrapper(e)
+
+        # block printing of the executed function
+        with HiddenPrints():
+            # start pool of processes and do the work
+            try:
+                pool = mp.Pool(processes=cores)
+            except NameError:
+                raise ImportError("package 'pathos' could not be imported")
+            results = pool.imap(lambda x: wrapper(**x), processlist)
+            pool.close()
+            pool.join()
+
+        i = 0
+        out = []
+        for item in results:
+            if isinstance(item, ExceptionWrapper):
+                item.ee = type(item.ee)(str(item.ee) +
+                                        "\n(called function '{}' with args {})"
+                                        .format(function.__name__, processlist[i]))
+                raise (item.re_raise())
+            out.append(item)
+            i += 1
+
+        # evaluate the return of the processing function;
+        # if any value is not None then the whole list of results is returned
+        eval = [x for x in out if x is not None]
+        if len(eval) == 0:
+            return None
+        else:
+            return out
+
+def add(x, y, z):
+    """
+    a dummy function for testing the multicore function;
+    it cannot be defined inside the test script because serialization requires the
+    function's reference module to be importable, which a test script is not
+    """
+    return x + y + z
 
 
 class ExceptionWrapper(object):
     """
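For context, a minimal usage sketch of the reworked multicore(), mirroring the assertions in test_ancillary.py further below; on Windows the same call transparently serializes its inputs with dill and routes them through multicore_helper.py:

import spatialist.ancillary as anc

# 'x' varies per job, 'y' and 'z' are broadcast to every job
assert anc.multicore(anc.add, cores=2, multiargs={'x': [1, 2]}, y=5, z=9) == [15, 16]
assert anc.multicore(anc.add, cores=2, multiargs={'x': [1, 2], 'y': [5, 6]}, z=9) == [15, 17]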
diff --git a/spatialist/auxil.py b/spatialist/auxil.py
index 5e758aac..45965cc0 100644
--- a/spatialist/auxil.py
+++ b/spatialist/auxil.py
@@ -123,7 +123,10 @@ def gdalwarp(src, dst, options):
     -------
 
     """
-    out = gdal.Warp(dst, src, options=gdal.WarpOptions(**options))
+    try:
+        out = gdal.Warp(dst, src, options=gdal.WarpOptions(**options))
+    except RuntimeError as e:
+        raise RuntimeError('{}:\n src: {}\n dst: {}\n options: {}'.format(str(e), src, dst, options))
     out = None
diff --git a/spatialist/envi.py b/spatialist/envi.py
index f2792665..ab1b43cb 100644
--- a/spatialist/envi.py
+++ b/spatialist/envi.py
@@ -115,7 +115,7 @@ def __hdr2dict(self):
             while '}' not in line:
                 i += 1
                 line += lines[i].strip('\n').lstrip()
-        line = list(filter(None, re.split('\s+=\s+', line)))
+        line = list(filter(None, re.split(r'\s+=\s+', line)))
         line[1] = re.split(',[ ]*', line[1].strip('{}'))
         key = line[0].replace(' ', '_')
         val = line[1] if len(line[1]) > 1 else line[1][0]
diff --git a/spatialist/multicore_helper.py b/spatialist/multicore_helper.py
new file mode 100644
index 00000000..2e05e517
--- /dev/null
+++ b/spatialist/multicore_helper.py
@@ -0,0 +1,58 @@
+#################################################################
+# helper script enabling the use of function ancillary.multicore
+# on Windows operating systems
+# John Truckenbrodt 2019
+#################################################################
+import os
+import tempfile
+import dill
+
+try:
+    import pathos.multiprocessing as mp
+except ImportError:
+    pass
+
+from spatialist.ancillary import HiddenPrints
+
+if __name__ == '__main__':
+
+    # de-serialize the arguments written by function ancillary.multicore
+    tmpfile = os.path.join(tempfile.gettempdir(), 'spatialist_dump')
+    with open(tmpfile, 'rb') as tmp:
+        func, cores, processlist = dill.load(tmp)
+
+    # serialize the job arguments to be able to pass them to the processes
+    processlist = [dill.dumps([func, x]) for x in processlist]
+
+    # a simple wrapper to execute the jobs in the sub-processes
+    # re-import of modules and passing pickled variables is necessary since on
+    # Windows the environment is not shared between parent and child processes
+    def wrapper(job):
+        import dill
+        function, proc = dill.loads(job)
+        return function(**proc)
+
+    # hide print messages in the sub-processes
+    with HiddenPrints():
+        # start pool of processes and do the work
+        try:
+            pool = mp.Pool(processes=cores)
+        except NameError:
+            raise ImportError("package 'pathos' could not be imported")
+        results = pool.imap(wrapper, processlist)
+        pool.close()
+        pool.join()
+
+    outlist = list(results)
+
+    # evaluate the return of the processing function;
+    # if any value is not None then the whole list of results is returned
+    eval = [x for x in outlist if x is not None]
+    if len(eval) == 0:
+        out = None
+    else:
+        out = outlist
+
+    # serialize and write the output list so that it can be read back in function ancillary.multicore
+    with open(tmpfile, 'wb') as tmp:
+        dill.dump(out, tmp, byref=False)
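The handshake between multicore() and the helper script is a plain dill round-trip through a fixed file in the temp directory. A standalone sketch of that protocol (the payload values are illustrative):

import os
import dill
import tempfile

tmpfile = os.path.join(tempfile.gettempdir(), 'spatialist_dump')

def add(x, y, z):
    return x + y + z

# parent side: dump the function, the core count and the per-job keyword dicts
with open(tmpfile, 'wb') as tmp:
    dill.dump([add, 2, [{'x': 1, 'y': 5, 'z': 9}, {'x': 2, 'y': 5, 'z': 9}]], tmp, byref=False)

# helper side: load them back and execute each job
with open(tmpfile, 'rb') as tmp:
    func, cores, processlist = dill.load(tmp)
print([func(**job) for job in processlist])  # [15, 16]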
-        targetres: tuple
+        targetres: tuple or list
             two entries for x and y spatial resolution in units of the source CRS
         srcnodata: int or float
             the nodata value of the source files
         dstnodata: int or float
             the nodata value of the destination file(s)
         shapefile: str, Vector or None
-            a shapefile for defining the area of the destination files
+            a shapefile for defining the spatial extent of the destination files
         layernames: list
-            the names of the output layers; if `None`, the basenames of the input files are used
+            the names of the output layers; if `None`, the basenames of the input files are used; overrides sortfun
         sortfun: function
-            a function for sorting the input files; this is needed for defining the mosaicking order
+            a function for sorting the input files; not used if layernames is not None.
+            This is first used for sorting the items in each sub-list of srcfiles;
+            the basename of the first item in a sub-list will then be used as the name for the mosaic of this group.
+            After mosaicking, the function is again used for sorting the names in the final output
+            (only relevant if `separate` is False)
         separate: bool
-            should the files be written to a single raster block or separate files?
-            If True, each tile is written to GeoTiff.
+            should the files be written to a single raster stack (ENVI format) or separate files (GTiff format)?
         overwrite: bool
             overwrite the file if it already exists?
         compress: bool
             compress the geotiff files?
         cores: int
-            the number of CPU threads to use; this is only relevant if separate = True
+            the number of CPU threads to use; this is only relevant if `separate` is True, in which case each
+            mosaicking/resampling job is passed to a different CPU
 
         Returns
         -------
@@ -1029,47 +1034,49 @@ def stack(srcfiles, dstfile, resampling, targetres, srcnodata, dstnodata, shapef
     raster CRS prior to retrieving its extent.
""" if len(dissolve(srcfiles)) == 0: - raise IOError('no input files provided to function raster.stack') + raise RuntimeError('no input files provided to function raster.stack') if layernames is not None: if len(layernames) != len(srcfiles): - raise IOError('mismatch between number of source file groups and layernames') + raise RuntimeError('mismatch between number of source file groups and layernames') if not isinstance(targetres, (list, tuple)) or len(targetres) != 2: raise RuntimeError('targetres must be a list or tuple with two entries for x and y resolution') if len(srcfiles) == 1 and not isinstance(srcfiles[0], list): - raise IOError('only one file specified; nothing to be done') + raise RuntimeError('only one file specified; nothing to be done') - if resampling not in ['near', 'bilinear', 'cubic', 'cubicspline', 'lanczos', 'average', 'mode', 'max', 'min', 'med', - 'Q1', 'Q3']: - raise IOError('resampling method not supported') + if resampling not in ['near', 'bilinear', 'cubic', 'cubicspline', 'lanczos', + 'average', 'mode', 'max', 'min', 'med', 'Q1', 'Q3']: + raise RuntimeError('resampling method not supported') projections = list() for x in dissolve(srcfiles): try: projection = Raster(x).projection - except OSError as e: + except RuntimeError as e: print('cannot read file: {}'.format(x)) raise e projections.append(projection) projections = list(set(projections)) if len(projections) > 1: - raise IOError('raster projection mismatch') - elif len(projections) == 0: + raise RuntimeError('raster projection mismatch') + elif projections[0] == '': raise RuntimeError('could not retrieve the projection from any of the {} input images'.format(len(srcfiles))) else: srs = projections[0] # read shapefile bounding coordinates and reduce list of rasters to those overlapping with the shapefile if shapefile is not None: - shp = shapefile if isinstance(shapefile, Vector) else Vector(shapefile) + shp = shapefile.clone() if isinstance(shapefile, Vector) else Vector(shapefile) shp.reproject(srs) ext = shp.extent arg_ext = (ext['xmin'], ext['ymin'], ext['xmax'], ext['ymax']) - for i in range(len(srcfiles)): - group = sorted(srcfiles[i], key=sortfun) if isinstance(srcfiles[i], list) else [srcfiles[i]] + for i, item in enumerate(srcfiles): + group = item if isinstance(item, list) else [item] + if layernames is None and sortfun is not None: + group = sorted(group, key=sortfun) group = [x for x in group if intersect(shp, Raster(x).bbox())] if len(group) > 1: srcfiles[i] = group @@ -1077,15 +1084,12 @@ def stack(srcfiles, dstfile, resampling, targetres, srcnodata, dstnodata, shapef srcfiles[i] = group[0] else: srcfiles[i] = None - srcfiles = filter(None, srcfiles) + shp.close() + srcfiles = list(filter(None, srcfiles)) else: arg_ext = None - # create temporary directory for writing intermediate files dst_base = os.path.splitext(dstfile)[0] - tmpdir = dst_base + '__tmp' - if not os.path.isdir(tmpdir): - os.makedirs(tmpdir) options_warp = {'options': ['-q'], 'format': 'GTiff' if separate else 'ENVI', @@ -1103,50 +1107,62 @@ def stack(srcfiles, dstfile, resampling, targetres, srcnodata, dstnodata, shapef options_buildvrt = {'outputBounds': arg_ext, 'srcNodata': srcnodata} # create VRT files for mosaicing - for i in range(len(srcfiles)): - base = srcfiles[i][0] if isinstance(srcfiles[i], list) else srcfiles[i] - vrt = os.path.join(tmpdir, os.path.splitext(os.path.basename(base))[0] + '.vrt') - gdalbuildvrt(srcfiles[i], vrt, options_buildvrt) + for i, group in enumerate(srcfiles): + base = group[0] if 
+        # in-memory VRT files cannot be shared between multiple processes on Windows
+        # this has to do with different process forking behaviour
+        # see function spatialist.ancillary.multicore and this link:
+        # https://stackoverflow.com/questions/38236211/why-multiprocessing-process-behave-differently-on-windows-and-linux-for-global-o
+        vrt_base = os.path.splitext(os.path.basename(base))[0] + '.vrt'
+        if platform.system() == 'Windows':
+            vrt = os.path.join(tempfile.gettempdir(), vrt_base)
+        else:
+            vrt = '/vsimem/' + vrt_base
+        gdalbuildvrt(group, vrt, options_buildvrt)
         srcfiles[i] = vrt
 
-    # if no specific layernames are defined and sortfun is not set to None,
-    # sort files by custom function or, by default, the basename of the raster/VRT file
+    # if no specific layernames are defined, sort files by custom function
     if layernames is None and sortfun is not None:
-        srcfiles = sorted(srcfiles, key=sortfun if sortfun else os.path.basename)
+        srcfiles = sorted(srcfiles, key=sortfun)
 
+    # use the file basenames without extension as band names if none are defined
     bandnames = [os.path.splitext(os.path.basename(x))[0] for x in srcfiles] if layernames is None else layernames
 
-    if separate or len(srcfiles) == 1:
+    if len(list(set(bandnames))) != len(bandnames):
+        raise RuntimeError('output bandnames are not unique')
+
+    if separate:
         if not os.path.isdir(dstfile):
             os.makedirs(dstfile)
         dstfiles = [os.path.join(dstfile, x) + '.tif' for x in bandnames]
-        if overwrite:
-            files = [x for x in zip(srcfiles, dstfiles)]
-        else:
-            files = [x for x in zip(srcfiles, dstfiles) if not os.path.isfile(x[1])]
-            if len(files) == 0:
+        jobs = [x for x in zip(srcfiles, dstfiles)]
+        if not overwrite:
+            jobs = [x for x in jobs if not os.path.isfile(x[1])]
+        if len(jobs) == 0:
             print('all target tiff files already exist, nothing to be done')
-            shutil.rmtree(tmpdir)
             return
-        srcfiles, dstfiles = map(list, zip(*files))
+        srcfiles, dstfiles = map(list, zip(*jobs))
         multicore(gdalwarp, cores=cores, multiargs={'src': srcfiles, 'dst': dstfiles}, options=options_warp)
     else:
-        # create VRT for stacking
-        vrt = os.path.join(tmpdir, os.path.basename(dst_base) + '.vrt')
-        options_buildvrt['options'] = ['-separate']
-        gdalbuildvrt(srcfiles, vrt, options_buildvrt)
-
-        # warp files
-        gdalwarp(vrt, dstfile, options_warp)
-
-        # edit ENVI HDR files to contain specific layer names
-        with envi.HDRobject(dstfile + '.hdr') as hdr:
-            hdr.band_names = bandnames
-            hdr.write()
-
-        # remove temporary directory and files
-        shutil.rmtree(tmpdir)
+        if len(srcfiles) == 1:
+            options_warp['format'] = 'GTiff'
+            if not dstfile.endswith('.tif'):
+                dstfile = os.path.splitext(dstfile)[0] + '.tif'
+            gdalwarp(srcfiles[0], dstfile, options_warp)
+        else:
+            # create VRT for stacking
+            vrt = '/vsimem/' + os.path.basename(dst_base) + '.vrt'
+            options_buildvrt['options'] = ['-separate']
+            gdalbuildvrt(srcfiles, vrt, options_buildvrt)
+
+            # warp files
+            gdalwarp(vrt, dstfile, options_warp)
+
+            # edit ENVI HDR files to contain specific layer names
+            with envi.HDRobject(dstfile + '.hdr') as hdr:
+                hdr.band_names = bandnames
+                hdr.write()
 
 
 class Dtype(object):
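A usage sketch of the reworked stack(), assuming two hypothetical single-band rasters 'a.tif' and 'b.tif' sharing one CRS; a flat list is stacked into a multi-band ENVI file, while a sub-list is mosaicked first and a single remaining layer is now written as GeoTiff instead of raising an error:

from spatialist import stack

# two bands in one ENVI stack with explicit band names
stack(srcfiles=['a.tif', 'b.tif'], dstfile='stacked', resampling='near',
      targetres=(30, 30), srcnodata=-99, dstnodata=-99,
      layernames=['band1', 'band2'], overwrite=True)

# the sub-list is mosaicked; the single result is written as 'mosaic.tif'
stack(srcfiles=[['a.tif', 'b.tif']], dstfile='mosaic', resampling='near',
      targetres=(30, 30), srcnodata=-99, dstnodata=-99, overwrite=True)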
diff --git a/spatialist/sqlite_util.py b/spatialist/sqlite_util.py
index 1c8797fa..55bcd81f 100644
--- a/spatialist/sqlite_util.py
+++ b/spatialist/sqlite_util.py
@@ -18,12 +18,12 @@ def check_loading():
         'please refer to the spatialist installation instructions'
 
 try:
     import sqlite3
-    
+
     check_loading()
 except RuntimeError:
     try:
         from pysqlite2 import dbapi2 as sqlite3
-        
+
         check_loading()
     except ImportError:
         raise RuntimeError(errormessage.format('pysqlite2 does not exist as alternative'))
@@ -96,7 +96,7 @@ def __init__(self, driver=':memory:', extensions=None):
             print('using sqlite version {}'.format(self.version['sqlite']))
             if 'spatialite' in self.version.keys():
                 print('using spatialite version {}'.format(self.version['spatialite']))
-    
+
     @property
     def version(self):
         out = {'sqlite': sqlite3.sqlite_version}
@@ -108,7 +108,7 @@ def version(self):
             except sqlite3.OperationalError:
                 pass
         return out
-    
+
     def get_tablenames(self):
         cursor = self.conn.cursor()
         cursor.execute('SELECT * FROM sqlite_master WHERE type="table"')
@@ -116,7 +116,7 @@ def get_tablenames(self):
         if bool(type('unicode')):
             names = [str(x) for x in names]
         return names
-    
+
     def load_extension(self, extension):
         if re.search('spatialite', extension):
             spatialite_setup()
@@ -132,11 +132,11 @@ def load_extension(self, extension):
                 except sqlite3.OperationalError as e:
                     print('{0}: {1}'.format(option, str(e)))
                     continue
-        
+
             # if loading mod_spatialite fails try to load libspatialite directly
             if select is None:
                 self.__load_regular('spatialite')
-        
+
             # initialize spatial support
             if 'spatial_ref_sys' not in self.get_tablenames():
                 cursor = self.conn.cursor()
@@ -147,26 +147,26 @@ def load_extension(self, extension):
                 # mod_spatialite extension
                 cursor.execute('SELECT InitSpatialMetaData(1);')
                 self.conn.commit()
-        
+
         else:
             self.__load_regular(extension)
-    
+
     def __load_regular(self, extension):
         options = []
-        
+
         # create an extension library option starting with 'lib' without extension suffixes;
         # e.g. 'libgdal' but not 'gdal.so'
         ext_base = self.__split_ext(extension)
         if not ext_base.startswith('lib'):
             ext_base = 'lib' + ext_base
         options.append(ext_base)
-        
+
         # get the full extension library name; e.g. 'libgdal.so.20'
         ext_mod = find_library(extension.replace('lib', ''))
         if ext_mod is None:
             raise RuntimeError('no library found for extension {}'.format(extension))
         options.append(ext_mod)
-        
+
         # loop through extension library name options and try to load them
         success = False
         for option in options:
@@ -178,12 +178,13 @@ def __load_regular(self, extension):
                 break
             except sqlite3.OperationalError:
                 continue
-        
+
         if not success:
             raise RuntimeError('failed to load extension {}'.format(extension))
-    
-    def __split_ext(self, extension):
+
+    @staticmethod
+    def __split_ext(extension):
         base = extension
-        while re.search('\.', base):
+        while re.search(r'\.', base):
             base = os.path.splitext(base)[0]
         return base
diff --git a/spatialist/tests/test_ancillary.py b/spatialist/tests/test_ancillary.py
index be9e92be..7875dee5 100644
--- a/spatialist/tests/test_ancillary.py
+++ b/spatialist/tests/test_ancillary.py
@@ -46,18 +46,17 @@ def test_which():
 
 
 def test_multicore():
-    add = lambda x, y, z: x + y + z
-    assert anc.multicore(add, cores=2, multiargs={'x': [1, 2]}, y=5, z=9) == [15, 16]
-    assert anc.multicore(add, cores=2, multiargs={'x': [1, 2], 'y': [5, 6]}, z=9) == [15, 17]
+    assert anc.multicore(anc.add, cores=2, multiargs={'x': [1, 2]}, y=5, z=9) == [15, 16]
+    assert anc.multicore(anc.add, cores=2, multiargs={'x': [1, 2], 'y': [5, 6]}, z=9) == [15, 17]
     # unknown argument in multiargs
     with pytest.raises(AttributeError):
-        anc.multicore(add, cores=2, multiargs={'foobar': [1, 2]}, y=5, z=9)
+        anc.multicore(anc.add, cores=2, multiargs={'foobar': [1, 2]}, y=5, z=9)
     # unknown argument in single args
     with pytest.raises(AttributeError):
-        anc.multicore(add, cores=2, multiargs={'x': [1, 2]}, y=5, foobar=9)
+        anc.multicore(anc.add, cores=2, multiargs={'x': [1, 2]}, y=5, foobar=9)
     # multiarg values of different length
     with pytest.raises(AttributeError):
-        anc.multicore(add, cores=2, multiargs={'x': [1, 2], 'y': [5, 6, 7]}, z=9)
+        anc.multicore(anc.add, cores=2, multiargs={'x': [1, 2], 'y': [5, 6, 7]}, z=9)
 
 
 def test_finder(tmpdir, testdata):
diff --git a/spatialist/tests/test_spatial.py b/spatialist/tests/test_spatial.py
index a4c9ae16..7bbb7ebe 100644
--- a/spatialist/tests/test_spatial.py
+++ b/spatialist/tests/test_spatial.py
@@ -1,9 +1,11 @@
 import os
+import shutil
 import pytest
 import platform
 import numpy as np
 from osgeo import ogr
-from spatialist import crsConvert, haversine, Raster, stack, ogr2ogr, gdal_translate, gdal_rasterize, bbox, rasterize
+from spatialist import crsConvert, haversine, Raster, stack, ogr2ogr, gdal_translate, gdal_rasterize, bbox, rasterize, \
+    gdalwarp
 from spatialist.raster import Dtype
 from spatialist.vector import feature2vector, dissolve, Vector, intersect
 from spatialist.envi import hdr, HDRobject
@@ -74,7 +76,7 @@ def test_dissolve(tmpdir, travis, testdata):
     bbox3_name = os.path.join(str(tmpdir), 'bbox3.shp')
     bbox1.write(bbox3_name)
     bbox1.close()
-    
+
     if not travis and platform.system() != 'Windows':
         # dissolve the geometries in bbox3 and write the result to new bbox4
         # this test is currently disabled for Travis as the current sqlite3 version on Travis seems to not support
@@ -106,7 +108,7 @@ def test_Raster(tmpdir, testdata):
         assert len(ras.layers()) == 1
         assert ras.projcs == 'WGS 84 / UTM zone 31N'
         assert ras.res == (20.0, 20.0)
-    
+
         # test writing a subset with no original data in memory
         outname = os.path.join(str(tmpdir), 'test_sub.tif')
         with ras[0:200, 0:100] as sub:
@@ -114,14 +116,14 @@ def test_Raster(tmpdir, testdata):
         with Raster(outname) as ras2:
             assert ras2.cols == 100
             assert ras2.rows == 200
-    
+
         ras.load()
         mat = ras.matrix()
         assert isinstance(mat, np.ndarray)
         ras.assign(mat, band=0)
         # ras.reduce()
         ras.rescale(lambda x: 10 * x)
-    
+
         # test writing data with original data in memory
         ras.write(os.path.join(str(tmpdir), 'test'), format='GTiff', compress_tif=True)
         with pytest.raises(RuntimeError):
@@ -152,13 +154,13 @@ def test_Raster_extract(testdata):
             ras.extract(1, 4830000)
         with pytest.raises(RuntimeError):
             ras.extract(624000, 1)
-    
+
         # ensure corner extraction capability
         assert ras.extract(px=ras.geo['xmin'], py=ras.geo['ymax']) == -10.147890090942383
         assert ras.extract(px=ras.geo['xmin'], py=ras.geo['ymin']) == -14.640368461608887
         assert ras.extract(px=ras.geo['xmax'], py=ras.geo['ymax']) == -9.599242210388182
         assert ras.extract(px=ras.geo['xmax'], py=ras.geo['ymin']) == -9.406558990478516
-    
+
         # test nodata handling capability and correct indexing
         mat = ras.matrix()
         mat[0:10, 0:10] = ras.nodata
@@ -170,50 +172,120 @@ def test_Raster_extract(testdata):
 
 
 def test_dtypes():
     assert Dtype('Float32').gdalint == 6
+    assert Dtype(6).gdalstr == 'Float32'
+    assert Dtype('uint32').gdalstr == 'UInt32'
     with pytest.raises(ValueError):
         Dtype('foobar')
+    with pytest.raises(ValueError):
+        Dtype(999)
+    with pytest.raises(TypeError):
+        Dtype(None)
 
 
 def test_stack(tmpdir, testdata):
     name = testdata['tif']
     outname = os.path.join(str(tmpdir), 'test')
     tr = (30, 30)
-    with pytest.raises(IOError):
+    # no input files provided
+    with pytest.raises(RuntimeError):
         stack(srcfiles=[], resampling='near', targetres=tr, srcnodata=-99, dstnodata=-99,
               dstfile=outname)
-    
-    with pytest.raises(IOError):
+
+    # two files, but only one layer name
+    with pytest.raises(RuntimeError):
         stack(srcfiles=[name, name], resampling='near', targetres=tr,
               srcnodata=-99, dstnodata=-99, dstfile=outname, layernames=['a'])
-    
+
+    # targetres must be a two-entry tuple/list
     with pytest.raises(RuntimeError):
         stack(srcfiles=[name, name], resampling='near', targetres=30,
               srcnodata=-99, dstnodata=-99, dstfile=outname)
-    
-    with pytest.raises(IOError):
+
+    # only one file specified
+    with pytest.raises(RuntimeError):
         stack(srcfiles=[name], resampling='near', targetres=tr, overwrite=True,
               srcnodata=-99, dstnodata=-99, dstfile=outname)
-    
+
+    # targetres must contain two values
     with pytest.raises(RuntimeError):
         stack(srcfiles=[name, name], resampling='near', targetres=(30, 30, 30),
               srcnodata=-99, dstnodata=-99, dstfile=outname)
-    
-    with pytest.raises(IOError):
+
+    # unknown resampling method
+    with pytest.raises(RuntimeError):
         stack(srcfiles=[name, name], resampling='foobar', targetres=tr,
               srcnodata=-99, dstnodata=-99, dstfile=outname)
-    
+
+    # non-existing files
+    with pytest.raises(RuntimeError):
+        stack(srcfiles=['foo', 'bar'], resampling='near', targetres=tr,
+              srcnodata=-99, dstnodata=-99, dstfile=outname)
+
+    # create a multi-band stack
     stack(srcfiles=[name, name], resampling='near', targetres=tr, overwrite=True,
-          srcnodata=-99, dstnodata=-99, dstfile=outname)
-    
-    outdir = os.path.join(str(tmpdir), 'subdir')
-    stack(srcfiles=[name, name], resampling='near', targetres=tr, overwrite=True, layernames=['test1', 'test2'],
-          srcnodata=-99, dstnodata=-99, dstfile=outdir, separate=True, compress=True)
-    
+          srcnodata=-99, dstnodata=-99, dstfile=outname, layernames=['test1', 'test2'])
     with Raster(outname) as ras:
         assert ras.bands == 2
         # Raster.rescale currently only supports one band
         with pytest.raises(ValueError):
             ras.rescale(lambda x: x * 10)
+
+    # pass shapefile
+    outname = os.path.join(str(tmpdir), 'test2')
+    with Raster(name).bbox() as box:
+        stack(srcfiles=[name, name], resampling='near', targetres=tr, overwrite=True,
+              srcnodata=-99, dstnodata=-99, dstfile=outname, shapefile=box, layernames=['test1', 'test2'])
+    with Raster(outname) as ras:
+        assert ras.bands == 2
+
+    # pass shapefile and do mosaicking
+    outname = os.path.join(str(tmpdir), 'test3')
+    with Raster(name).bbox() as box:
+        stack(srcfiles=[[name, name]], resampling='near', targetres=tr, overwrite=True,
+              srcnodata=-99, dstnodata=-99, dstfile=outname, shapefile=box)
+    with Raster(outname + '.tif') as ras:
+        assert ras.bands == 1
+        assert ras.format == 'GTiff'
+
+    # projection mismatch
+    name2 = os.path.join(str(tmpdir), os.path.basename(name))
+    outname = os.path.join(str(tmpdir), 'test4')
+    gdalwarp(name, name2, options={'dstSRS': crsConvert(4326, 'wkt')})
+    with pytest.raises(RuntimeError):
+        stack(srcfiles=[name, name2], resampling='near', targetres=tr, overwrite=True,
+              srcnodata=-99, dstnodata=-99, dstfile=outname)
+
+    # no projection found
+    outname = os.path.join(str(tmpdir), 'test5')
+    gdal_translate(name, name2, {'options': ['-co', 'PROFILE=BASELINE']})
+    with Raster(name2) as ras:
+        print(ras.projection)
+    with pytest.raises(RuntimeError):
+        stack(srcfiles=[name2, name2], resampling='near', targetres=tr, overwrite=True,
+              srcnodata=-99, dstnodata=-99, dstfile=outname)
+
+    # create separate GeoTiffs
+    outdir = os.path.join(str(tmpdir), 'subdir')
+    stack(srcfiles=[name, name], resampling='near', targetres=tr, overwrite=True, layernames=['test1', 'test2'],
+          srcnodata=-99, dstnodata=-99, dstfile=outdir, separate=True, compress=True)
+
+    # repeat with overwrite disabled (no error raised, just a print message)
+    stack(srcfiles=[name, name], resampling='near', targetres=tr, overwrite=False, layernames=['test1', 'test2'],
+          srcnodata=-99, dstnodata=-99, dstfile=outdir, separate=True, compress=True)
+
+    # repeat without layernames but with sortfun;
+    # the resulting bandnames are not unique
+    outdir = os.path.join(str(tmpdir), 'subdir2')
+    with pytest.raises(RuntimeError):
+        stack(srcfiles=[name, name], resampling='near', targetres=tr, overwrite=True, sortfun=os.path.basename,
+              srcnodata=-99, dstnodata=-99, dstfile=outdir, separate=True, compress=True)
+
+    # repeat without layernames but with sortfun and unique bandnames
+    name2 = os.path.join(str(tmpdir), os.path.basename(name).replace('VV', 'XX'))
+    shutil.copyfile(name, name2)
+    outdir = os.path.join(str(tmpdir), 'subdir2')
+    stack(srcfiles=[name, name2], resampling='near', targetres=tr, overwrite=True, sortfun=os.path.basename,
+          srcnodata=-99, dstnodata=-99, dstfile=outdir, separate=True, compress=True)
 
 
 def test_auxil(tmpdir, testdata):
@@ -230,22 +302,22 @@ def test_rasterize(tmpdir, testdata):
     outname = os.path.join(str(tmpdir), 'test.shp')
     with Raster(testdata['tif']) as ras:
         vec = ras.bbox()
-    
+
     # test length mismatch between burn_values and expressions
     with pytest.raises(RuntimeError):
         rasterize(vec, reference=ras, outname=outname, burn_values=[1], expressions=['foo', 'bar'])
-    
+
     # test a faulty expression
     with pytest.raises(RuntimeError):
         rasterize(vec, reference=ras, outname=outname, burn_values=[1], expressions=['foo'])
-    
+
     # test default parametrization
     rasterize(vec, reference=ras, outname=outname)
     assert os.path.isfile(outname)
-    
+
     # test appending to existing file with valid expression
     rasterize(vec, reference=ras, outname=outname, append=True, burn_values=[1], expressions=['area=23262400.0'])
-    
+
     # test wrong input type for reference
     with pytest.raises(RuntimeError):
         rasterize(vec, reference='foobar', outname=outname)
@@ -277,7 +349,7 @@ def test_sqlite(appveyor):
     con.close()
     con = __Handler()
     assert sorted(con.version.keys()) == ['sqlite']
-    
+
     con = __Handler(extensions=['spatialite'])
     assert sorted(con.version.keys()) == ['spatialite', 'sqlite']
     assert 'spatial_ref_sys' in con.get_tablenames()
diff --git a/spatialist/vector.py b/spatialist/vector.py
index ae3926eb..9680525e 100644
--- a/spatialist/vector.py
+++ b/spatialist/vector.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 ################################################################
 # OGR wrapper for convenient vector data handling and processing
-# John Truckenbrodt 2015-2018
+# John Truckenbrodt 2015-2019
 ################################################################
 
 
@@ -204,6 +204,9 @@ def bbox(self, outname=None, format='ESRI Shapefile', overwrite=True):
         else:
             bbox(self.extent, self.srs, outname=outname, format=format, overwrite=overwrite)
 
+    def clone(self):
+        return feature2vector(self.getfeatures(), ref=self)
+
     def close(self):
         """
         closes the OGR vector file connection
@@ -805,6 +808,9 @@ def intersect(obj1, obj2):
     if not isinstance(obj1, Vector) or not isinstance(obj2, Vector):
         raise RuntimeError('both objects must be of type Vector')
 
+    obj1 = obj1.clone()
+    obj2 = obj2.clone()
+
     obj1.reproject(obj2.srs)
#######################################################
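Since intersect() now reprojects clones instead of the passed objects, callers no longer see their inputs silently modified. A short sketch with hypothetical shapefiles in different CRS:

from spatialist.vector import Vector, intersect

with Vector('site_utm.shp') as a, Vector('site_ll.shp') as b:
    overlap = intersect(a, b)
    # a and b keep their original CRS; only internal clones were reprojected
    if overlap is not None:
        overlap.close()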