diff --git a/api/formatters/__init__.py b/api/formatters/__init__.py index 625ea09c..5c8f640c 100644 --- a/api/formatters/__init__.py +++ b/api/formatters/__init__.py @@ -7,11 +7,7 @@ class Formats(str, Enum): - covjson = "covjson" + covjson = "CoverageJSON" # According to EDR spec -def get_edr_formatters() -> dict: - available_formatters = {"covjson": covjson.Covjson()} - # Should also setup dict for alias discover - - return available_formatters +formatters = {"CoverageJSON": covjson.convert_to_covjson} diff --git a/api/formatters/base_formatter.py b/api/formatters/base_formatter.py deleted file mode 100644 index 59386508..00000000 --- a/api/formatters/base_formatter.py +++ /dev/null @@ -1,18 +0,0 @@ -from abc import ABC -from abc import abstractmethod - - -class EDRFormatter(ABC): - """ - This is the abstract class for implementing a formatter in the E-SOH EDR formatter - Name of class should represent expected output format. - """ - - pass - - @abstractmethod - def convert(self, datastore_reply): - """ - Main method for converting protobuf object to given format. - """ - pass diff --git a/api/formatters/covjson.py b/api/formatters/covjson.py index 5a7555c8..160dd1ec 100644 --- a/api/formatters/covjson.py +++ b/api/formatters/covjson.py @@ -15,82 +15,78 @@ from covjson_pydantic.reference_system import ReferenceSystemConnectionObject from covjson_pydantic.unit import Unit from fastapi import HTTPException -from formatters.base_formatter import EDRFormatter from pydantic import AwareDatetime -class Covjson(EDRFormatter): - """ - Class for converting protobuf object to coverage json - """ +# alias = ["covjson", "coveragejson"] +# mime_type = "application/prs.coverage+json" - alias = ["covjson", "coveragejson"] - mime_type = "application/json" # find the type for covjson - def convert(self, response): - # Collect data - coverages = [] - data = [self._collect_data(md.ts_mdata, md.obs_mdata) for md in response.observations] +def convert_to_covjson(response): + # Collect data + coverages = [] + data = [_collect_data(md.ts_mdata, md.obs_mdata) for md in response.observations] - # Need to sort before using groupBy. Also sort on param_id to get consistently sorted output - data.sort(key=lambda x: (x[0], x[1])) - # The multiple coverage logic is not needed for this endpoint, - # but we want to share this code between endpoints - for (lat, lon, times), group in groupby(data, lambda x: x[0]): - referencing = [ - ReferenceSystemConnectionObject( - coordinates=["y", "x"], - system=ReferenceSystem(type="GeographicCRS", id="http://www.opengis.net/def/crs/EPSG/0/4326"), - ), - ReferenceSystemConnectionObject( - coordinates=["z"], - system=ReferenceSystem(type="TemporalRS", calendar="Gregorian"), - ), - ] - domain = Domain( - domainType=DomainType.point_series, - axes=Axes( - x=ValuesAxis[float](values=[lon]), - y=ValuesAxis[float](values=[lat]), - t=ValuesAxis[AwareDatetime](values=times), - ), - referencing=referencing, + # Need to sort before using groupBy. Also sort on param_id to get consistently sorted output + data.sort(key=lambda x: (x[0], x[1])) + # The multiple coverage logic is not needed for this endpoint, + # but we want to share this code between endpoints + for (lat, lon, times), group in groupby(data, lambda x: x[0]): + referencing = [ + ReferenceSystemConnectionObject( + coordinates=["y", "x"], + system=ReferenceSystem(type="GeographicCRS", id="http://www.opengis.net/def/crs/EPSG/0/4326"), + ), + ReferenceSystemConnectionObject( + coordinates=["z"], + system=ReferenceSystem(type="TemporalRS", calendar="Gregorian"), + ), + ] + domain = Domain( + domainType=DomainType.point_series, + axes=Axes( + x=ValuesAxis[float](values=[lon]), + y=ValuesAxis[float](values=[lat]), + t=ValuesAxis[AwareDatetime](values=times), + ), + referencing=referencing, + ) + + parameters = {} + ranges = {} + for (_, _, _), param_id, unit, values in group: + if all(math.isnan(v) for v in values): + continue # Drop ranges if completely nan. + # TODO: Drop the whole coverage if it becomes empty? + values_no_nan = [v if not math.isnan(v) else None for v in values] + # TODO: Improve this based on "standard name", etc. + parameters[param_id] = Parameter( + observedProperty=ObservedProperty(label={"en": param_id}), unit=Unit(label={"en": unit}) + ) # TODO: Also fill symbol? + ranges[param_id] = NdArray( + values=values_no_nan, axisNames=["t", "y", "x"], shape=[len(values_no_nan), 1, 1] ) - parameters = {} - ranges = {} - for (_, _, _), param_id, unit, values in group: - if all(math.isnan(v) for v in values): - continue # Drop ranges if completely nan. - # TODO: Drop the whole coverage if it becomes empty? - values_no_nan = [v if not math.isnan(v) else None for v in values] - # TODO: Improve this based on "standard name", etc. - parameters[param_id] = Parameter( - observedProperty=ObservedProperty(label={"en": param_id}), unit=Unit(label={"en": unit}) - ) # TODO: Also fill symbol? - ranges[param_id] = NdArray( - values=values_no_nan, axisNames=["t", "y", "x"], shape=[len(values_no_nan), 1, 1] - ) + coverages.append(Coverage(domain=domain, parameters=parameters, ranges=ranges)) - coverages.append(Coverage(domain=domain, parameters=parameters, ranges=ranges)) + if len(coverages) == 0: + raise HTTPException(status_code=404, detail="No data found") + elif len(coverages) == 1: + return coverages[0] + else: + return CoverageCollection( + coverages=coverages, parameters=coverages[0].parameters + ) # HACK to take parameters from first one - if len(coverages) == 0: - raise HTTPException(status_code=404, detail="No data found") - elif len(coverages) == 1: - return coverages[0] - else: - return CoverageCollection( - coverages=coverages, parameters=coverages[0].parameters - ) # HACK to take parameters from first one - def _collect_data(self, ts_mdata, obs_mdata): - lat = obs_mdata[0].geo_point.lat # HACK: For now assume they all have the same position - lon = obs_mdata[0].geo_point.lon - tuples = ( - (o.obstime_instant.ToDatetime(tzinfo=timezone.utc), float(o.value)) for o in obs_mdata - ) # HACK: str -> float - (times, values) = zip(*tuples) - param_id = ts_mdata.instrument - unit = ts_mdata.unit +def _collect_data(ts_mdata, obs_mdata): + lat = obs_mdata[0].geo_point.lat # HACK: For now assume they all have the same position + lon = obs_mdata[0].geo_point.lon + tuples = ( + (o.obstime_instant.ToDatetime(tzinfo=timezone.utc), float(o.value)) for o in obs_mdata + ) # HACK: str -> float + (times, values) = zip(*tuples) + param_id = ts_mdata.instrument + unit = ts_mdata.unit - return (lat, lon, times), param_id, unit, values + return (lat, lon, times), param_id, unit, values diff --git a/api/routers/edr.py b/api/routers/edr.py index e15de25b..7cb149a2 100644 --- a/api/routers/edr.py +++ b/api/routers/edr.py @@ -17,8 +17,6 @@ router = APIRouter(prefix="/collections/observations") -edr_formatter = formatters.get_edr_formatters() - @router.get( "/locations", @@ -78,7 +76,7 @@ async def get_data_location_id( temporal_interval=dstore.TimeInterval(start=range[0], end=range[1]) if range else None, ) response = await getObsRequest(get_obs_request) - return edr_formatter[f].convert(response) + return formatters.formatters[f](response) @router.get( @@ -122,5 +120,5 @@ async def get_data_area( temporal_interval=dstore.TimeInterval(start=range[0], end=range[1]) if range else None, ) coverages = await getObsRequest(get_obs_request) - coverages = edr_formatter[f].convert(coverages) + coverages = formatters.formatters[f](coverages) return coverages