From 7004f60570723bb13288ec275e7ceceac2b8dfcd Mon Sep 17 00:00:00 2001 From: Paul Butcher Date: Wed, 26 Jul 2023 13:13:18 +0100 Subject: [PATCH] Fix paired aggregation behaviour (#676) * Fix paired aggregation behaviour * improve safety * fix sorting even more * improve commentary * Karl and Jake * remove redundant test --- .../weco/api/search/models/Aggregation.scala | 31 ++-- .../FiltersAndAggregationsBuilder.scala | 125 ++++++++++++---- .../services/ImagesRequestBuilder.scala | 14 +- .../search/services/WorksRequestBuilder.scala | 18 +-- .../models/AggregationResultsTest.scala | 66 --------- .../search/services/AggregationsTest.scala | 20 +-- .../FiltersAndAggregationsBuilderTest.scala | 133 +++++++++++------- 7 files changed, 215 insertions(+), 192 deletions(-) diff --git a/search/src/main/scala/weco/api/search/models/Aggregation.scala b/search/src/main/scala/weco/api/search/models/Aggregation.scala index f8728890ec..9bb499c341 100644 --- a/search/src/main/scala/weco/api/search/models/Aggregation.scala +++ b/search/src/main/scala/weco/api/search/models/Aggregation.scala @@ -48,16 +48,28 @@ import scala.util.{Success, Try} // ... // } object AggregationMapping { - import weco.json.JsonUtil._ - private case class Result(buckets: Option[Seq[Bucket]]) - // When we use a global aggregation we can't predict the key name of the - // resultant sub-aggregation. This optic says, - // "for each key of the root object that has a key `buckets`, decode + // We can't predict the key name of the resultant sub-aggregation. + // This optic says, "for each key of the root object that has a key `buckets`, decode // the value of that field as an array of Buckets" private val globalAggBuckets = root.each.buckets.each.as[Bucket] + // This optic does the same for buckets within the self aggregation + private val selfAggBuckets = root.self.each.buckets.each.as[Bucket] + + // When we use the self aggregation pattern, buckets are returned + // in aggregations at multiple depths. 
This will return + // buckets from the expected locations. + // The order of sub aggregations vs the top-level aggregation is not guaranteed, + // so construct a sequence consisting of first the top-level buckets, then the self buckets. + // The top-level buckets will contain all the properly counted bucket values. The self buckets + // exist only to "top up" the main list with the filtered values if those values were not returned in + // the main aggregation. + // If any of the filtered terms are present in the main aggregation, then they will be duplicated + // in the self buckets, hence the need for distinct. + private def bucketsFromAnywhere(json: Json): Seq[Bucket] = + (globalAggBuckets.getAll(json) ++ selfAggBuckets.getAll(json)) distinct private case class Bucket( key: Json, @@ -80,7 +92,7 @@ object AggregationMapping { Success(buckets) case Result(None) => parse(jsonString) - .map(globalAggBuckets.getAll) + .map(bucketsFromAnywhere) .toTry } .map { buckets => @@ -98,12 +110,7 @@ object AggregationMapping { } .map { buckets => Aggregation( - buckets - // Sort manually here because filtered aggregations are bucketed before filtering - // therefore they are not always ordered by their final counts. - // Sorting in Scala is stable. 
- .sortBy(_.count)(Ordering[Int].reverse) - .toList + buckets.toList ) } diff --git a/search/src/main/scala/weco/api/search/services/FiltersAndAggregationsBuilder.scala b/search/src/main/scala/weco/api/search/services/FiltersAndAggregationsBuilder.scala index 9b1e7c3c06..ae645fdcaf 100644 --- a/search/src/main/scala/weco/api/search/services/FiltersAndAggregationsBuilder.scala +++ b/search/src/main/scala/weco/api/search/services/FiltersAndAggregationsBuilder.scala @@ -5,9 +5,11 @@ import com.sksamuel.elastic4s.requests.searches.aggs.{ AbstractAggregation, Aggregation, FilterAggregation, - GlobalAggregation + TermsAggregation, + TermsOrder } import com.sksamuel.elastic4s.requests.searches.queries.Query +import com.sksamuel.elastic4s.requests.searches.term.TermsQuery import weco.api.search.models._ import weco.api.search.models.request.{ ImageAggregationRequest, @@ -34,36 +36,107 @@ trait FiltersAndAggregationsBuilder[Filter, AggregationRequest] { val filters: List[Filter] val requestToAggregation: AggregationRequest => Aggregation val filterToQuery: Filter => Query - val searchQuery: Query def pairedAggregationRequests(filter: Filter): List[AggregationRequest] + // Ensure that characters like parentheses can still be returned by the self aggregation, escape any + // regex tokens that might appear in filter terms. + // (as present in some labels, e.g. /concepts/gafuyqgp: "Nicholson, Michael C. (Michael Christopher), 1962-") + private def escapeRegexTokens(term: String): String = + term.replaceAll("""([.?+*|{}\[\]()\\"])""", "\\\\$1") + + /** + * An aggregation that will contain the filtered-upon value even + * if no documents in the aggregation context match it. 
+ */ + private def toSelfAggregation( + agg: AbstractAggregation, + filterQuery: Query + ): AbstractAggregation = + agg match { + case terms: TermsAggregation => + val filterTerm = filterQuery match { + case TermsQuery(_, values, _, _, _, _, _) => + //Aggregable values are a JSON object encoded as a string + // Filter terms correspond to a property of this JSON + // object (normally id or label). + // In order to ensure that this aggregation only matches + // the whole value in question, it is enclosed in + // escaped quotes. + // This is not perfect, but should be sufficient. + // If (e.g.) this filter operates + // on id and there is a label for a different value + // that exactly matches it, then both will be returned. + // This is an unlikely scenario, and will still result + // in the desired value being returned. + s"""\\"(${values + .map(value => escapeRegexTokens(value.toString)) + .mkString("|")})\\"""" + case _ => "" + } + // The aggregation context may be excluding all documents + // that match this filter term. Setting minDocCount to 0 + // allows the term in question to be returned. + terms.minDocCount(0).includeRegex(s".*($filterTerm).*") + case agg => agg + } + + // Given an aggregation request, convert it to an actual aggregation + // with a predictable sort order. + // Higher count buckets come first, and within each group of higher count buckets, + // the keys should be in alphabetical order. 
+ private def requestToOrderedAggregation( + aggReq: AggregationRequest + ): AbstractAggregation = + requestToAggregation(aggReq) match { + case terms: TermsAggregation => + terms.order( + Seq(TermsOrder("_count", asc = false), TermsOrder("_key", asc = true)) + ) + case agg => agg + } + lazy val filteredAggregations: List[AbstractAggregation] = aggregationRequests.map { aggReq => - val agg = requestToAggregation(aggReq) - pairedFilter(aggReq) match { - case Some(paired) => - val subFilters = filters.filterNot(_ == paired) - GlobalAggregation( - // We would like to rename the aggregation here to something predictable - // (eg "global_agg") but because it is an opaque AbstractAggregation we - // make do with naming it the same as its parent GlobalAggregation, so that - // the latter can be picked off when parsing in WorkAggregations - name = agg.name, - subaggs = Seq( - agg.addSubagg( - FilterAggregation( - "filtered", - boolQuery.filter { - searchQuery :: subFilters.map(filterToQuery) - } - ) - ) + filteredAggregation(aggReq) + } + + /** + * Turn an aggregation request into an actual aggregation. + * The aggregation will be filtered using all filters that do not operate on the same field as the aggregation (if any). + * It also contains a subaggregation that *does* additionally filter on that field. 
This ensures that the filtered + * values are returned even if they fall outside the top n buckets as defined by the main aggregation + */ + private def filteredAggregation(aggReq: AggregationRequest) = { + val agg = requestToOrderedAggregation(aggReq) + pairedFilter(aggReq) match { + case Some(paired) => + val otherFilters = filters.filterNot(_ == paired) + val pairedQuery = filterToQuery(paired) + FilterAggregation( + name = agg.name, + boolQuery.filter { + otherFilters.map(filterToQuery) + }, + subaggs = Seq( + agg, + FilterAggregation( + name = "self", + pairedQuery, + subaggs = Seq(toSelfAggregation(agg, pairedQuery)) ) ) - case _ => agg - } + ) + case _ => + FilterAggregation( + name = agg.name, + boolQuery.filter { + filters.map(filterToQuery) + }, + subaggs = Seq(agg) + ) } + } private def pairedFilter( aggregationRequest: AggregationRequest @@ -78,8 +151,7 @@ class WorkFiltersAndAggregationsBuilder( val aggregationRequests: List[WorkAggregationRequest], val filters: List[WorkFilter], val requestToAggregation: WorkAggregationRequest => Aggregation, - val filterToQuery: WorkFilter => Query, - val searchQuery: Query + val filterToQuery: WorkFilter => Query ) extends FiltersAndAggregationsBuilder[WorkFilter, WorkAggregationRequest] { override def pairedAggregationRequests( @@ -102,8 +174,7 @@ class ImageFiltersAndAggregationsBuilder( val aggregationRequests: List[ImageAggregationRequest], val filters: List[ImageFilter], val requestToAggregation: ImageAggregationRequest => Aggregation, - val filterToQuery: ImageFilter => Query, - val searchQuery: Query + val filterToQuery: ImageFilter => Query ) extends FiltersAndAggregationsBuilder[ImageFilter, ImageAggregationRequest] { override def pairedAggregationRequests( diff --git a/search/src/main/scala/weco/api/search/services/ImagesRequestBuilder.scala b/search/src/main/scala/weco/api/search/services/ImagesRequestBuilder.scala index 211472765c..d18e9a2d09 100644 --- 
a/search/src/main/scala/weco/api/search/services/ImagesRequestBuilder.scala +++ b/search/src/main/scala/weco/api/search/services/ImagesRequestBuilder.scala @@ -39,12 +39,7 @@ class ImagesRequestBuilder(queryConfig: QueryConfig) def request(searchOptions: ImageSearchOptions, index: Index): SearchRequest = search(index) .aggs { filteredAggregationBuilder(searchOptions).filteredAggregations } - .query( - searchQuery(searchOptions) - .filter( - buildImageFilterQuery(searchOptions.filters) - ) - ) + .query(searchQuery(searchOptions)) .sortBy { sortBy(searchOptions) } .limit(searchOptions.pageSize) .from(PaginationQuery.safeGetFrom(searchOptions)) @@ -54,6 +49,9 @@ class ImagesRequestBuilder(queryConfig: QueryConfig) // to send the image's vectors to Elasticsearch "query.inferredData.reducedFeatures" ) + .postFilter { + must(buildImageFilterQuery(searchOptions.filters)) + } private def filteredAggregationBuilder(searchOptions: ImageSearchOptions) = new ImageFiltersAndAggregationsBuilder( @@ -61,7 +59,6 @@ class ImagesRequestBuilder(queryConfig: QueryConfig) filters = searchOptions.filters, requestToAggregation = toAggregation, filterToQuery = buildImageFilterQuery, - searchQuery = searchQuery(searchOptions) ) private def searchQuery(searchOptions: ImageSearchOptions): BoolQuery = @@ -144,7 +141,8 @@ class ImagesRequestBuilder(queryConfig: QueryConfig) RangeQuery( "query.source.production.dates.range.from", lte = lte, - gte = gte) + gte = gte + ) } private def buildImageFilterQuery(filters: Seq[ImageFilter]): Seq[Query] = diff --git a/search/src/main/scala/weco/api/search/services/WorksRequestBuilder.scala b/search/src/main/scala/weco/api/search/services/WorksRequestBuilder.scala index 1033a45639..f18d4f29bf 100644 --- a/search/src/main/scala/weco/api/search/services/WorksRequestBuilder.scala +++ b/search/src/main/scala/weco/api/search/services/WorksRequestBuilder.scala @@ -27,11 +27,16 @@ object WorksRequestBuilder implicit val s = searchOptions search(index) .aggs { 
filteredAggregationBuilder.filteredAggregations } - .query { filteredQuery } + .query { searchQuery } .sortBy { sortBy } .limit { searchOptions.pageSize } .from { PaginationQuery.safeGetFrom(searchOptions) } .sourceInclude("display", "type") + .postFilter { + must( + buildWorkFilterQuery(VisibleWorkFilter :: searchOptions.filters) + ) + } } private def filteredAggregationBuilder( @@ -41,8 +46,7 @@ object WorksRequestBuilder aggregationRequests = searchOptions.aggregations, filters = searchOptions.filters, requestToAggregation = toAggregation, - filterToQuery = buildWorkFilterQuery, - searchQuery = searchQuery + filterToQuery = buildWorkFilterQuery ) private def toAggregation(aggReq: WorkAggregationRequest) = aggReq match { @@ -134,14 +138,6 @@ object WorksRequestBuilder } .getOrElse { boolQuery } - private def filteredQuery( - implicit searchOptions: WorkSearchOptions - ): BoolQuery = - searchQuery - .filter { - buildWorkFilterQuery(VisibleWorkFilter :: searchOptions.filters) - } - private def buildWorkFilterQuery(filters: Seq[WorkFilter]): Seq[Query] = filters.map { buildWorkFilterQuery diff --git a/search/src/test/scala/weco/api/search/models/AggregationResultsTest.scala b/search/src/test/scala/weco/api/search/models/AggregationResultsTest.scala index 49861e4554..72b2ce051c 100644 --- a/search/src/test/scala/weco/api/search/models/AggregationResultsTest.scala +++ b/search/src/test/scala/weco/api/search/models/AggregationResultsTest.scala @@ -143,70 +143,4 @@ class AggregationResultsTest extends AnyFunSpec with Matchers { ) ) } - - it("sorts the buckets by count (in descending order)") { - val searchResponse = SearchResponse( - took = 1234, - isTimedOut = false, - isTerminatedEarly = false, - suggest = Map(), - _shards = Shards(total = 1, failed = 0, successful = 1), - scrollId = None, - hits = SearchHits( - total = Total(0, "potatoes"), - maxScore = 0.0, - hits = Array() - ), - _aggregationsAsMap = Map( - "format" -> Map( - "doc_count" -> 12345, - "format" -> 
Map( - "doc_count_error_upper_bound" -> 0, - "sum_other_doc_count" -> 0, - "buckets" -> List( - Map( - "key" -> """ "damson" """, - "doc_count" -> 10, - "filtered" -> Map( - "doc_count" -> 1 - ) - ), - Map( - "key" -> """ "cherry" """, - "doc_count" -> 9, - "filtered" -> Map( - "doc_count" -> 2 - ) - ), - Map( - "key" -> """ "banana" """, - "doc_count" -> 8, - "filtered" -> Map( - "doc_count" -> 3 - ) - ), - Map( - "key" -> """ "apricot" """, - "doc_count" -> 7, - "filtered" -> Map( - "doc_count" -> 4 - ) - ) - ) - ) - ) - ) - ) - val singleAgg = WorkAggregations(searchResponse) - singleAgg.get.format - .flatMap(_.buckets.headOption) - .get shouldBe AggregationBucket( - data = Json.fromString("apricot"), - count = 4 - ) - singleAgg.get.format - .map(_.buckets.map(_.count)) - .get - .reverse shouldBe sorted - } } diff --git a/search/src/test/scala/weco/api/search/services/AggregationsTest.scala b/search/src/test/scala/weco/api/search/services/AggregationsTest.scala index c834e032e8..9ed2011c5e 100644 --- a/search/src/test/scala/weco/api/search/services/AggregationsTest.scala +++ b/search/src/test/scala/weco/api/search/services/AggregationsTest.scala @@ -148,35 +148,19 @@ class AggregationsTest List(WorkAggregationRequest.Format, WorkAggregationRequest.Subject), filters = List( FormatFilter(List("a")), - SubjectLabelFilter(Seq("pGkJTZWwn4")) + SubjectLabelFilter(Seq("9SceRNaTEl")) ) ) whenReady(aggregationQuery(index, searchOptions)) { aggs => val buckets = aggs.format.get.buckets - buckets.length shouldBe works.length + buckets.length shouldBe 7 buckets.map(b => getKey(b.data, "label").get.asString.get) should contain theSameElementsAs List( "Books", "Manuscripts", "Music", - "Journals", - "Maps", - "E-videos", - "Videos", "Archives and manuscripts", - "Audio", - "E-journals", - "Pictures", - "Ephemera", - "CD-Roms", "Film", - "Mixed materials", - "Digital Images", - "3-D Objects", - "E-sound", "Standing order", - "E-books", - "Student dissertations", - 
"Manuscripts", "Web sites" ) } diff --git a/search/src/test/scala/weco/api/search/services/FiltersAndAggregationsBuilderTest.scala b/search/src/test/scala/weco/api/search/services/FiltersAndAggregationsBuilderTest.scala index e58c659d07..1ec4e25389 100644 --- a/search/src/test/scala/weco/api/search/services/FiltersAndAggregationsBuilderTest.scala +++ b/search/src/test/scala/weco/api/search/services/FiltersAndAggregationsBuilderTest.scala @@ -3,13 +3,14 @@ package weco.api.search.services import com.sksamuel.elastic4s.requests.searches.aggs.{ AbstractAggregation, Aggregation, - FilterAggregation, - GlobalAggregation + FilterAggregation } import com.sksamuel.elastic4s.requests.searches.queries.Query import com.sksamuel.elastic4s.requests.searches.queries.compound.BoolQuery +import org.scalatest.LoneElement import org.scalatest.funspec.AnyFunSpec import org.scalatest.matchers.should.Matchers +import org.scalatest.prop.TableDrivenPropertyChecks import weco.api.search.models._ import weco.api.search.models.request.WorkAggregationRequest import weco.api.search.models.{ @@ -19,7 +20,11 @@ import weco.api.search.models.{ WorkFilter } -class FiltersAndAggregationsBuilderTest extends AnyFunSpec with Matchers { +class FiltersAndAggregationsBuilderTest + extends AnyFunSpec + with Matchers + with TableDrivenPropertyChecks + with LoneElement { describe("aggregation-level filtering") { it("applies to aggregations with a paired filter") { @@ -30,22 +35,39 @@ class FiltersAndAggregationsBuilderTest extends AnyFunSpec with Matchers { List(WorkAggregationRequest.Format, WorkAggregationRequest.Languages), filters = List(formatFilter, languagesFilter), requestToAggregation = requestToAggregation, - filterToQuery = filterToQuery, - searchQuery = MockSearchQuery + filterToQuery = filterToQuery ) builder.filteredAggregations should have length 2 - builder.filteredAggregations.head shouldBe a[GlobalAggregation] - val topAgg = builder.filteredAggregations.head - 
.asInstanceOf[GlobalAggregation] - .subaggs - .head - topAgg shouldBe a[MockAggregation] - - val agg = topAgg.asInstanceOf[MockAggregation] - agg.subaggs.head shouldBe a[FilterAggregation] - agg.request shouldBe WorkAggregationRequest.Format + builder.filteredAggregations.head shouldBe a[FilterAggregation] + // The first aggregation is Format + val filterAgg = builder.filteredAggregations.head + .asInstanceOf[FilterAggregation] + // Filtered on Language=en + filterAgg.query + .asInstanceOf[BoolQuery] + .filters + .loneElement shouldBe MockQuery(LanguagesFilter(Seq("en"))) + + // Within that filtered aggregation are the two aggregation + // requests. + filterAgg.subaggs should have length 2 + // First comes the aggregation for format, without the format filter + val formatAgg = filterAgg.subaggs.head + formatAgg + .asInstanceOf[MockAggregation] + .request shouldBe WorkAggregationRequest.Format + //Then comes the self aggregation + val selfAgg = filterAgg.subaggs(1).asInstanceOf[FilterAggregation] + // Which is the same aggregation + selfAgg.subaggs.loneElement + .asInstanceOf[MockAggregation] + .request shouldBe WorkAggregationRequest.Format + //but additionally, it matches the filter on this field. + selfAgg.query + .asInstanceOf[MockQuery] + .filter shouldBe FormatFilter(Seq("bananas")) } it("does not apply to aggregations without a paired filter") { @@ -54,15 +76,16 @@ class FiltersAndAggregationsBuilderTest extends AnyFunSpec with Matchers { aggregationRequests = List(WorkAggregationRequest.Format), filters = List(languagesFilter), requestToAggregation = requestToAggregation, - filterToQuery = filterToQuery, - searchQuery = MockSearchQuery + filterToQuery = filterToQuery ) - - builder.filteredAggregations should have length 1 - builder.filteredAggregations.head shouldBe a[MockAggregation] - builder.filteredAggregations.head + // The aggregation list is just the requested aggregation + // filtered by the requested (unpaired) filter. 
+ builder.filteredAggregations.loneElement + .asInstanceOf[FilterAggregation] + .subaggs + .loneElement // This marks the absence of the "self" filteraggregation .asInstanceOf[MockAggregation] - .subaggs should have length 0 + .request shouldBe WorkAggregationRequest.Format } it("applies paired filters to non-paired aggregations") { @@ -72,57 +95,68 @@ class FiltersAndAggregationsBuilderTest extends AnyFunSpec with Matchers { List(WorkAggregationRequest.Format, WorkAggregationRequest.Languages), filters = List(formatFilter), requestToAggregation = requestToAggregation, - filterToQuery = filterToQuery, - searchQuery = MockSearchQuery + filterToQuery = filterToQuery ) builder.filteredAggregations should have length 2 - builder.filteredAggregations.head shouldBe a[GlobalAggregation] + //The first aggregation is Format, which has a "self" subaggregation. + // The details of the content of this aggregation are examined + // in the "applies to aggregations with a paired filter" test. builder.filteredAggregations.head - .asInstanceOf[GlobalAggregation] - .subaggs - .head - .asInstanceOf[MockAggregation] - .subaggs should have length 1 + .asInstanceOf[FilterAggregation] + .subaggs should have length 2 - builder.filteredAggregations(1) shouldBe a[MockAggregation] + //The second aggregation is Language, which has no corresponding + // filter in this query, so does not have a "self" subaggregation builder .filteredAggregations(1) + .asInstanceOf[FilterAggregation] + .subaggs + .loneElement .asInstanceOf[MockAggregation] - .subaggs should have length 0 + .request shouldBe WorkAggregationRequest.Languages } it("applies all other aggregation-dependent filters to the paired filter") { val formatFilter = FormatFilter(Seq("bananas")) val languagesFilter = LanguagesFilter(Seq("en")) val genreFilter = GenreFilter(Seq("durian")) + val filters = List(formatFilter, languagesFilter, genreFilter) val builder = new WorkFiltersAndAggregationsBuilder( aggregationRequests = List( 
WorkAggregationRequest.Format, WorkAggregationRequest.Languages, WorkAggregationRequest.Genre ), - filters = List(formatFilter, languagesFilter, genreFilter), + filters = filters, requestToAggregation = requestToAggregation, - filterToQuery = filterToQuery, - searchQuery = MockSearchQuery + filterToQuery = filterToQuery ) - val agg = - builder.filteredAggregations.head - .asInstanceOf[GlobalAggregation] - .subaggs - .head - .asInstanceOf[MockAggregation] - .subaggs - .head + builder.filteredAggregations should have length 3 + + forAll( + Table( + ("agg", "matchingFilter"), + // The aggregations and the filter list are expected to be + // in the same order, so zipping them should result in + // each aggregation being paired with its corresponding filter + builder.filteredAggregations.zip( + filters + ): _* + ) + ) { (agg: AbstractAggregation, thisFilter: WorkFilter) => + val filterQuery = agg .asInstanceOf[FilterAggregation] - agg.query shouldBe a[BoolQuery] - val query = agg.query.asInstanceOf[BoolQuery] - query.filters should not contain MockQuery(formatFilter) - query.filters should contain only ( - MockSearchQuery, MockQuery(languagesFilter), MockQuery(genreFilter) - ) + .query + .asInstanceOf[BoolQuery] + //Three filters are requested, each aggregation should + // have only two. i.e. not its own + filterQuery.filters should have length 2 + // And this ensures that it is the correct two. + filterQuery.filters.map(_.asInstanceOf[MockQuery].filter) should contain theSameElementsAs filters + .filterNot(_ == thisFilter) + } } } @@ -134,7 +168,6 @@ class FiltersAndAggregationsBuilderTest extends AnyFunSpec with Matchers { private def filterToQuery(filter: WorkFilter): Query = MockQuery(filter) private case class MockQuery(filter: WorkFilter) extends Query - private case object MockSearchQuery extends Query private case class MockAggregation( name: String,