From f571a36f1765308bdadd197e4ab47c7de33751a9 Mon Sep 17 00:00:00 2001
From: Andrew Kane
Date: Mon, 2 Sep 2024 19:27:02 -0700
Subject: [PATCH] Added experimental knn option

---
 CHANGELOG.md                    |  1 +
 lib/searchkick/index_options.rb | 25 ++++++++++++++++++++++
 lib/searchkick/model.rb         |  2 +-
 lib/searchkick/query.rb         | 38 ++++++++++++++++++++++++++++++++-
 test/knn_test.rb                | 10 +++++++++
 test/models/product.rb          |  9 +++++++-
 test/support/activerecord.rb    |  1 +
 7 files changed, 83 insertions(+), 3 deletions(-)
 create mode 100644 test/knn_test.rb

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8508de0c..8c1e6e4b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,6 @@
 ## 5.4.0 (unreleased)
 
+- Added experimental `knn` option
 - Added experimental support for `_raw` to `where` option
 - Added warning for `exists` with non-`true` values
 - Added warning for full reindex and `:queue` mode
diff --git a/lib/searchkick/index_options.rb b/lib/searchkick/index_options.rb
index f7cdf4ea..7a7460aa 100644
--- a/lib/searchkick/index_options.rb
+++ b/lib/searchkick/index_options.rb
@@ -169,6 +169,10 @@ def generate_settings
         max_shingle_diff: 4
       }
 
+      if options[:knn] && Searchkick.opensearch?
+        settings[:index][:knn] = true
+      end
+
       if options[:case_sensitive]
         settings[:analysis][:analyzer].each do |_, analyzer|
           analyzer[:filter].delete("lowercase")
@@ -406,6 +410,27 @@ def generate_mappings
         mapping[field] = shape_options.merge(type: "geo_shape")
       end
 
+      (options[:knn] || []).each do |field, knn_options|
+        if Searchkick.opensearch?
+          mapping[field.to_s] = {
+            type: "knn_vector",
+            dimension: knn_options[:dimensions],
+            method: {
+              name: "hnsw",
+              space_type: "cosinesimil",
+              engine: "lucene"
+            }
+          }
+        else
+          mapping[field.to_s] = {
+            type: "dense_vector",
+            dims: knn_options[:dimensions],
+            index: true,
+            similarity: "cosine"
+          }
+        end
+      end
+
       if options[:inheritance]
         mapping[:type] = keyword_mapping
       end
diff --git a/lib/searchkick/model.rb b/lib/searchkick/model.rb
index 41aa4875..ed13e28b 100644
--- a/lib/searchkick/model.rb
+++ b/lib/searchkick/model.rb
@@ -4,7 +4,7 @@ def searchkick(**options)
       options = Searchkick.model_options.merge(options)
 
       unknown_keywords = options.keys - [:_all, :_type, :batch_size, :callbacks, :case_sensitive, :conversions, :deep_paging, :default_fields,
-        :filterable, :geo_shape, :highlight, :ignore_above, :index_name, :index_prefix, :inheritance, :language,
+        :filterable, :geo_shape, :highlight, :ignore_above, :index_name, :index_prefix, :inheritance, :knn, :language,
         :locations, :mappings, :match, :max_result_window, :merge_mappings, :routing, :searchable, :search_synonyms, :settings, :similarity,
         :special_characters, :stem, :stemmer, :stem_conversions, :stem_exclusion, :stemmer_override, :suggest, :synonyms, :text_end,
         :text_middle, :text_start, :unscope, :word, :word_end, :word_middle, :word_start]
diff --git a/lib/searchkick/query.rb b/lib/searchkick/query.rb
index bef8be1b..22d782c6 100644
--- a/lib/searchkick/query.rb
+++ b/lib/searchkick/query.rb
@@ -19,7 +19,7 @@ class Query
     def initialize(klass, term = "*", **options)
       unknown_keywords = options.keys - [:aggs, :block, :body, :body_options, :boost,
         :boost_by, :boost_by_distance, :boost_by_recency, :boost_where, :conversions, :conversions_term, :debug, :emoji, :exclude, :explain,
-        :fields, :highlight, :includes, :index_name, :indices_boost, :limit, :load,
+        :fields, :highlight, :includes, :index_name, :indices_boost, :knn, :limit, :load,
         :match, :misspellings, :models, :model_includes, :offset, :operator, :order, :padding, :page, :per_page, :profile,
         :request_params, :routing, :scope_results, :scroll, :select, :similar, :smart_aggs, :suggest, :total_entries, :track, :type, :where]
       raise ArgumentError, "unknown keywords: #{unknown_keywords.join(", ")}" if unknown_keywords.any?
@@ -526,6 +526,42 @@ def prepare
         end
       end
 
+      # knn
+      if options[:knn]
+        if term != "*"
+          raise ArgumentError, "Hybrid search not supported yet"
+        end
+
+        if options[:where]
+          raise ArgumentError, "KNN search with where not supported yet"
+        end
+
+        if options[:knn].size != 1
+          raise ArgumentError, "Invalid knn option"
+        end
+
+        k = per_page + offset
+
+        if Searchkick.opensearch?
+          payload[:query].delete(:match_all)
+          payload[:query][:knn] = {}
+          options[:knn].each do |field, vector|
+            payload[:query][:knn][field.to_sym] = {
+              vector: vector,
+              k: k
+            }
+          end
+        else
+          options[:knn].each do |field, vector|
+            payload[:knn] = {
+              field: field,
+              k: k,
+              query_vector: vector
+            }
+          end
+        end
+      end
+
       # pagination
       pagination_options = options[:page] || options[:limit] || options[:per_page] || options[:offset] || options[:padding]
       if !options[:body] || pagination_options
diff --git a/test/knn_test.rb b/test/knn_test.rb
new file mode 100644
index 00000000..f3a4be0d
--- /dev/null
+++ b/test/knn_test.rb
@@ -0,0 +1,10 @@
+require_relative "test_helper"
+
+class KnnTest < Minitest::Test
+  def test_works
+    store [{name: "A", embedding: [1, 2, 3]}, {name: "B", embedding: [-1, -2, -3]}]
+    assert_order "*", ["A", "B"], knn: {embedding: [1, 2, 3]}
+    expected = Searchkick.opensearch? ? [1, 0] : [2, 1]
+    assert_equal expected, Product.search(knn: {embedding: [1, 2, 3]}).hits.map { |v| v["_score"] }
+  end
+end
diff --git a/test/models/product.rb b/test/models/product.rb
index b6179df4..7a197dae 100644
--- a/test/models/product.rb
+++ b/test/models/product.rb
@@ -20,7 +20,14 @@ class Product
     highlight: [:name],
     filterable: [:name, :color, :description],
     similarity: "BM25",
-    match: ENV["MATCH"] ? ENV["MATCH"].to_sym : nil
+    match: ENV["MATCH"] ? ENV["MATCH"].to_sym : nil,
+    knn: {embedding: {dimensions: 3}}
+
+  if ActiveRecord::VERSION::STRING.to_f >= 7.1
+    serialize :embedding, coder: JSON
+  else
+    serialize :embedding, JSON
+  end
 
   attr_accessor :conversions, :user_ids, :aisle, :details
 
diff --git a/test/support/activerecord.rb b/test/support/activerecord.rb
index e8e07e4a..c8412648 100644
--- a/test/support/activerecord.rb
+++ b/test/support/activerecord.rb
@@ -32,6 +32,7 @@
   t.decimal :longitude, precision: 10, scale: 7
   t.text :description
   t.text :alt_description
+  t.text :embedding
   t.timestamps null: true
 end
 