Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix paired aggregation behaviour #676

Merged
merged 7 commits into from
Jul 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 19 additions & 12 deletions search/src/main/scala/weco/api/search/models/Aggregation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,28 @@ import scala.util.{Success, Try}
// ...
// }
object AggregationMapping {

import weco.json.JsonUtil._

// Shape of a simple aggregation response: `buckets` is present iff the
// root object itself carries a `buckets` array. When it is None, the
// buckets live deeper in the JSON and are extracted with optics instead
// (see bucketsFromAnywhere).
private case class Result(buckets: Option[Seq[Bucket]])

// When we use a global aggregation we can't predict the key name of the
// resultant sub-aggregation. This optic says,
// "for each key of the root object that has a key `buckets`, decode
// We can't predict the key name of the resultant sub-aggregation.
// This optic says, "for each key of the root object that has a key `buckets`, decode
// the value of that field as an array of Buckets"
private val globalAggBuckets = root.each.buckets.each.as[Bucket]
// This optic does the same for buckets within the self aggregation
private val selfAggBuckets = root.self.each.buckets.each.as[Bucket]

// When we use the self aggregation pattern, buckets are returned
// in aggregations at multiple depths. This returns buckets from all
// expected locations.
//
// The order of sub-aggregations vs the top-level aggregation is not
// guaranteed, so build a sequence consisting of the top-level buckets
// first, then the self buckets. The top-level buckets contain all the
// properly counted bucket values; the self buckets exist only to "top up"
// the main list with filtered values that were not returned in the main
// aggregation. Any filtered term already present in the main aggregation
// is duplicated in the self buckets, hence the need for `distinct`.
//
// NOTE: `distinct` is called with dot notation; the bare postfix form
// (`… distinct`) requires the postfixOps language import and is
// rejected under -Xfatal-warnings.
private def bucketsFromAnywhere(json: Json): Seq[Bucket] =
  (globalAggBuckets.getAll(json) ++ selfAggBuckets.getAll(json)).distinct

private case class Bucket(
key: Json,
Expand All @@ -80,7 +92,7 @@ object AggregationMapping {
Success(buckets)
case Result(None) =>
parse(jsonString)
.map(globalAggBuckets.getAll)
.map(bucketsFromAnywhere)
.toTry
}
.map { buckets =>
Expand All @@ -98,12 +110,7 @@ object AggregationMapping {
}
.map { buckets =>
Aggregation(
buckets
// Sort manually here because filtered aggregations are bucketed before filtering
// therefore they are not always ordered by their final counts.
// Sorting in Scala is stable.
.sortBy(_.count)(Ordering[Int].reverse)
.toList
buckets.toList
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ import com.sksamuel.elastic4s.requests.searches.aggs.{
AbstractAggregation,
Aggregation,
FilterAggregation,
GlobalAggregation
TermsAggregation,
TermsOrder
}
import com.sksamuel.elastic4s.requests.searches.queries.Query
import com.sksamuel.elastic4s.requests.searches.term.TermsQuery
import weco.api.search.models._
import weco.api.search.models.request.{
ImageAggregationRequest,
Expand All @@ -34,36 +36,107 @@ trait FiltersAndAggregationsBuilder[Filter, AggregationRequest] {
val filters: List[Filter]
val requestToAggregation: AggregationRequest => Aggregation
val filterToQuery: Filter => Query
val searchQuery: Query

def pairedAggregationRequests(filter: Filter): List[AggregationRequest]

// Ensure that characters like parentheses can still be returned by the
// self aggregation: any regex metacharacter that may appear in a filter
// term is prefixed with a backslash.
// (Such characters occur in some labels, e.g. /concepts/gafuyqgp:
// "Nicholson, Michael C. (Michael Christopher), 1962-")
private def escapeRegexTokens(term: String): String = {
  // The set of regex metacharacters that must be escaped, matching the
  // character class [.?+*|{}\[\]()\\"].
  val specialChars: Set[Char] =
    Set('.', '?', '+', '*', '|', '{', '}', '[', ']', '(', ')', '\\', '"')
  term.flatMap { ch =>
    if (specialChars.contains(ch)) s"\\$ch" else ch.toString
  }
}

/**
  * Build an aggregation that will contain the filtered-upon value even
  * if no documents in the aggregation context match it.
  *
  * Only TermsAggregations are adjusted; any other aggregation type is
  * returned unchanged.
  *
  * @param agg the aggregation to adapt (normally the main terms aggregation)
  * @param filterQuery the query for the filter paired with this aggregation;
  *                    only TermsQuery filters contribute an include pattern
  */
private def toSelfAggregation(
  agg: AbstractAggregation,
  filterQuery: Query
): AbstractAggregation =
  agg match {
    case terms: TermsAggregation =>
      val filterTerm = filterQuery match {
        case TermsQuery(_, values, _, _, _, _, _) =>
          // Aggregable values are a JSON object encoded as a string.
          // Filter terms correspond to a property of this JSON
          // object (normally id or label).
          // In order to ensure that this aggregation only matches
          // the whole value in question, it is enclosed in
          // escaped quotes.
          // This is not perfect, but should be sufficient.
          // If (e.g.) this filter operates
          // on id and there is a label for a different value
          // that exactly matches it, then both will be returned.
          // This is an unlikely scenario, and will still result
          // in the desired value being returned.
          s"""\\"(${values
            .map(value => escapeRegexTokens(value.toString))
            .mkString("|")})\\""""
        // Non-terms filters produce an empty alternation, i.e. the
        // include regex degenerates to ".*().*" (matches everything).
        case _ => ""
      }
      // The aggregation context may be excluding all documents
      // that match this filter term. Setting minDocCount to 0
      // allows the term in question to be returned.
      terms.minDocCount(0).includeRegex(s".*($filterTerm).*")
    // Not a terms aggregation: nothing sensible to adapt, pass through.
    case agg => agg
  }

// Given an aggregation request, convert it to an actual aggregation
// with a predictable sort order: buckets with higher counts come first,
// and within each group of equal-count buckets, keys are sorted in
// ascending alphabetical order. Non-terms aggregations pass through
// unchanged.
private def requestToOrderedAggregation(
  aggReq: AggregationRequest
): AbstractAggregation = {
  // Primary sort: descending doc count; tiebreak: ascending key.
  val bucketOrdering =
    Seq(TermsOrder("_count", asc = false), TermsOrder("_key", asc = true))
  requestToAggregation(aggReq) match {
    case terms: TermsAggregation => terms.order(bucketOrdering)
    case other                   => other
  }
}

lazy val filteredAggregations: List[AbstractAggregation] =
aggregationRequests.map { aggReq =>
val agg = requestToAggregation(aggReq)
pairedFilter(aggReq) match {
case Some(paired) =>
val subFilters = filters.filterNot(_ == paired)
GlobalAggregation(
// We would like to rename the aggregation here to something predictable
// (eg "global_agg") but because it is an opaque AbstractAggregation we
// make do with naming it the same as its parent GlobalAggregation, so that
// the latter can be picked off when parsing in WorkAggregations
name = agg.name,
subaggs = Seq(
agg.addSubagg(
FilterAggregation(
"filtered",
boolQuery.filter {
searchQuery :: subFilters.map(filterToQuery)
}
)
)
filteredAggregation(aggReq)
}

/**
* Turn an aggregation request into an actual aggregation.
* The aggregation will be filtered using all filters that do not operate on the same field as the aggregation (if any).
* It also contains a subaggregation that *does* additionally filter on that field. This ensures that the filtered
* values are returned even if they fall outside the top n buckets as defined by the main aggregation
*/
private def filteredAggregation(aggReq: AggregationRequest) = {
val agg = requestToOrderedAggregation(aggReq)
pairedFilter(aggReq) match {
case Some(paired) =>
val otherFilters = filters.filterNot(_ == paired)
val pairedQuery = filterToQuery(paired)
FilterAggregation(
name = agg.name,
boolQuery.filter {
otherFilters.map(filterToQuery)
},
subaggs = Seq(
agg,
FilterAggregation(
name = "self",
pairedQuery,
subaggs = Seq(toSelfAggregation(agg, pairedQuery))
)
)
case _ => agg
}
)
case _ =>
FilterAggregation(
name = agg.name,
boolQuery.filter {
filters.map(filterToQuery)
},
subaggs = Seq(agg)
)
}
}

private def pairedFilter(
aggregationRequest: AggregationRequest
Expand All @@ -78,8 +151,7 @@ class WorkFiltersAndAggregationsBuilder(
val aggregationRequests: List[WorkAggregationRequest],
val filters: List[WorkFilter],
val requestToAggregation: WorkAggregationRequest => Aggregation,
val filterToQuery: WorkFilter => Query,
val searchQuery: Query
val filterToQuery: WorkFilter => Query
) extends FiltersAndAggregationsBuilder[WorkFilter, WorkAggregationRequest] {

override def pairedAggregationRequests(
Expand All @@ -102,8 +174,7 @@ class ImageFiltersAndAggregationsBuilder(
val aggregationRequests: List[ImageAggregationRequest],
val filters: List[ImageFilter],
val requestToAggregation: ImageAggregationRequest => Aggregation,
val filterToQuery: ImageFilter => Query,
val searchQuery: Query
val filterToQuery: ImageFilter => Query
) extends FiltersAndAggregationsBuilder[ImageFilter, ImageAggregationRequest] {

override def pairedAggregationRequests(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,7 @@ class ImagesRequestBuilder(queryConfig: QueryConfig)
def request(searchOptions: ImageSearchOptions, index: Index): SearchRequest =
search(index)
.aggs { filteredAggregationBuilder(searchOptions).filteredAggregations }
.query(
searchQuery(searchOptions)
.filter(
buildImageFilterQuery(searchOptions.filters)
)
)
.query(searchQuery(searchOptions))
.sortBy { sortBy(searchOptions) }
.limit(searchOptions.pageSize)
.from(PaginationQuery.safeGetFrom(searchOptions))
Expand All @@ -54,14 +49,16 @@ class ImagesRequestBuilder(queryConfig: QueryConfig)
// to send the image's vectors to Elasticsearch
"query.inferredData.reducedFeatures"
)
.postFilter {
must(buildImageFilterQuery(searchOptions.filters))
}

private def filteredAggregationBuilder(searchOptions: ImageSearchOptions) =
new ImageFiltersAndAggregationsBuilder(
aggregationRequests = searchOptions.aggregations,
filters = searchOptions.filters,
requestToAggregation = toAggregation,
filterToQuery = buildImageFilterQuery,
searchQuery = searchQuery(searchOptions)
)

private def searchQuery(searchOptions: ImageSearchOptions): BoolQuery =
Expand Down Expand Up @@ -144,7 +141,8 @@ class ImagesRequestBuilder(queryConfig: QueryConfig)
RangeQuery(
"query.source.production.dates.range.from",
lte = lte,
gte = gte)
gte = gte
)
}

private def buildImageFilterQuery(filters: Seq[ImageFilter]): Seq[Query] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,16 @@ object WorksRequestBuilder
implicit val s = searchOptions
search(index)
.aggs { filteredAggregationBuilder.filteredAggregations }
.query { filteredQuery }
.query { searchQuery }
.sortBy { sortBy }
.limit { searchOptions.pageSize }
.from { PaginationQuery.safeGetFrom(searchOptions) }
.sourceInclude("display", "type")
.postFilter {
must(
buildWorkFilterQuery(VisibleWorkFilter :: searchOptions.filters)
)
}
}

private def filteredAggregationBuilder(
Expand All @@ -41,8 +46,7 @@ object WorksRequestBuilder
aggregationRequests = searchOptions.aggregations,
filters = searchOptions.filters,
requestToAggregation = toAggregation,
filterToQuery = buildWorkFilterQuery,
searchQuery = searchQuery
filterToQuery = buildWorkFilterQuery
)

private def toAggregation(aggReq: WorkAggregationRequest) = aggReq match {
Expand Down Expand Up @@ -134,14 +138,6 @@ object WorksRequestBuilder
}
.getOrElse { boolQuery }

private def filteredQuery(
implicit searchOptions: WorkSearchOptions
): BoolQuery =
searchQuery
.filter {
buildWorkFilterQuery(VisibleWorkFilter :: searchOptions.filters)
}

private def buildWorkFilterQuery(filters: Seq[WorkFilter]): Seq[Query] =
filters.map {
buildWorkFilterQuery
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,70 +143,4 @@ class AggregationResultsTest extends AnyFunSpec with Matchers {
)
)
}

it("sorts the buckets by count (in descending order)") {
val searchResponse = SearchResponse(
took = 1234,
isTimedOut = false,
isTerminatedEarly = false,
suggest = Map(),
_shards = Shards(total = 1, failed = 0, successful = 1),
scrollId = None,
hits = SearchHits(
total = Total(0, "potatoes"),
maxScore = 0.0,
hits = Array()
),
_aggregationsAsMap = Map(
"format" -> Map(
"doc_count" -> 12345,
"format" -> Map(
"doc_count_error_upper_bound" -> 0,
"sum_other_doc_count" -> 0,
"buckets" -> List(
Map(
"key" -> """ "damson" """,
"doc_count" -> 10,
"filtered" -> Map(
"doc_count" -> 1
)
),
Map(
"key" -> """ "cherry" """,
"doc_count" -> 9,
"filtered" -> Map(
"doc_count" -> 2
)
),
Map(
"key" -> """ "banana" """,
"doc_count" -> 8,
"filtered" -> Map(
"doc_count" -> 3
)
),
Map(
"key" -> """ "apricot" """,
"doc_count" -> 7,
"filtered" -> Map(
"doc_count" -> 4
)
)
)
)
)
)
)
val singleAgg = WorkAggregations(searchResponse)
singleAgg.get.format
.flatMap(_.buckets.headOption)
.get shouldBe AggregationBucket(
data = Json.fromString("apricot"),
count = 4
)
singleAgg.get.format
.map(_.buckets.map(_.count))
.get
.reverse shouldBe sorted
}
}
Loading