diff --git a/lemarche/siaes/models.py b/lemarche/siaes/models.py index 394407f0f..91d2f35b6 100644 --- a/lemarche/siaes/models.py +++ b/lemarche/siaes/models.py @@ -1186,6 +1186,7 @@ def elasticsearch_index_metadata(self): "id": self.id, "name": self.name, "website": self.website if self.website else "", + "kind": self.kind, } if self.latitude and self.longitude: metadata["geo_location"] = { diff --git a/lemarche/tenders/models.py b/lemarche/tenders/models.py index 0f9583b61..dbd15d76e 100644 --- a/lemarche/tenders/models.py +++ b/lemarche/tenders/models.py @@ -719,9 +719,10 @@ def set_siae_found_list(self): geo_distance=self.distance_location, geo_lat=self.location.coords.y, geo_lon=self.location.coords.x, + siae_kinds=self.siae_kind, ) else: - siae_ids = api_elasticsearch.siaes_similarity_search(self.description) + siae_ids = api_elasticsearch.siaes_similarity_search(self.description, siae_kinds=self.siae_kind) siaes_had_found_by_ia = Siae.objects.filter(id__in=siae_ids) siaes_had_found_by_ia_too = [] diff --git a/lemarche/utils/apis/api_elasticsearch.py b/lemarche/utils/apis/api_elasticsearch.py index f33ae44c4..eb1e87c1b 100644 --- a/lemarche/utils/apis/api_elasticsearch.py +++ b/lemarche/utils/apis/api_elasticsearch.py @@ -12,7 +12,7 @@ ) -def siaes_similarity_search(search_text: str, search_filter: dict = {}): +def siaes_similarity_search(search_text: str, search_filter: list = [], siae_kinds: list = []): """Performs semantic search with Elasticsearch as a vector db Args: @@ -21,6 +21,10 @@ def siaes_similarity_search(search_text: str, search_filter: dict = {}): Returns: list: list of siaes id that match the search query """ + + if siae_kinds: + search_filter.append({"terms": {"metadata.kind.keyword": siae_kinds}}) + db = ElasticsearchStore( embedding=OpenAIEmbeddings(), es_user=settings.ELASTICSEARCH_USERNAME, @@ -40,7 +44,7 @@ def siaes_similarity_search(search_text: str, search_filter: dict = {}): def siaes_similarity_search_with_geo_distance( - search_text: str, geo_distance: int = None, geo_lat: float = None, geo_lon: float = None + search_text: str, geo_distance: int = None, geo_lat: float = None, geo_lon: float = None, siae_kinds: list = [] ): search_filter = [] if geo_distance and geo_lat and geo_lon: @@ -56,10 +60,10 @@ def siaes_similarity_search_with_geo_distance( } ] - return siaes_similarity_search(search_text, search_filter) + return siaes_similarity_search(search_text, search_filter, siae_kinds) -def siaes_similarity_search_with_city(search_text: str, city: Perimeter): +def siaes_similarity_search_with_city(search_text: str, city: Perimeter, siae_kinds: list = []): search_filter = [ { "bool": { @@ -88,4 +92,4 @@ def siaes_similarity_search_with_city(search_text: str, city: Perimeter): } } ] - return siaes_similarity_search(search_text, search_filter) + return siaes_similarity_search(search_text, search_filter, siae_kinds)