diff --git a/search/src/main/scala/weco/api/search/services/ImagesRequestBuilder.scala b/search/src/main/scala/weco/api/search/services/ImagesRequestBuilder.scala index 996ccc94d..a5c76ddc7 100644 --- a/search/src/main/scala/weco/api/search/services/ImagesRequestBuilder.scala +++ b/search/src/main/scala/weco/api/search/services/ImagesRequestBuilder.scala @@ -53,10 +53,10 @@ class ImagesRequestBuilder() searchOptions.searchQuery.isDefined || searchOptions.color.isDefined, includes = Seq("display", "vectorValues.reducedFeatures"), aggs = filteredAggregationBuilder(pairables).filteredAggregations, - preFilter = unpairables.map(buildImageFilterQuery), + preFilter = unpairables.collect(buildImageFilterQuery), postFilter = Some( must( - searchOptions.filters.map(buildImageFilterQuery) + searchOptions.filters.collect(buildImageFilterQuery) ) ), knn = searchOptions.color.map(ColorQuery(_)) @@ -111,8 +111,7 @@ class ImagesRequestBuilder() searchOptions.sortOrder } - private def buildImageFilterQuery(filter: ImageFilter): Query = - filter match { + private val buildImageFilterQuery: PartialFunction[ImageFilter, Query] = { case LicenseFilter(licenseIds) => termsQuery( field = "filterableValues.locations.license.id", @@ -134,10 +133,10 @@ class ImagesRequestBuilder() genreLabels ) case GenreConceptFilter(conceptIds) if conceptIds.nonEmpty => - termsQuery( - "filterableValues.source.genres.concepts.id", - conceptIds - ) + termsQuery( + "filterableValues.source.genres.concepts.id", + conceptIds + ) case SubjectLabelFilter(subjectLabels) => termsQuery( "filterableValues.source.subjects.label", diff --git a/search/src/main/scala/weco/api/search/services/WorksRequestBuilder.scala b/search/src/main/scala/weco/api/search/services/WorksRequestBuilder.scala index 27db572f6..fd98a3205 100644 --- a/search/src/main/scala/weco/api/search/services/WorksRequestBuilder.scala +++ b/search/src/main/scala/weco/api/search/services/WorksRequestBuilder.scala @@ -50,10 +50,10 @@ object WorksRequestBuilder sortByScore = searchOptions.searchQuery.isDefined, includes = Seq("display", "type"), aggs = filteredAggregationBuilder(pairables).filteredAggregations, - preFilter = (VisibleWorkFilter :: unpairables).map(buildWorkFilterQuery), + preFilter = (VisibleWorkFilter :: unpairables).collect(buildWorkFilterQuery), postFilter = Some( must( - pairables.map(buildWorkFilterQuery) + pairables.collect(buildWorkFilterQuery) ) ) ) @@ -139,8 +139,7 @@ object WorksRequestBuilder searchOptions.sortOrder } - private def buildWorkFilterQuery(workFilter: WorkFilter): Query = - workFilter match { + val buildWorkFilterQuery: PartialFunction[WorkFilter, Query] = { case VisibleWorkFilter => termQuery(field = "type", value = "Visible") case FormatFilter(formatIds) =>