-
Notifications
You must be signed in to change notification settings - Fork 93
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move similarity search examples (#275)
- Loading branch information
1 parent
8ad263d
commit e1dae83
Showing
4 changed files
with
1,077 additions
and
0 deletions.
There are no files selected for viewing
370 changes: 370 additions & 0 deletions
370
api_examples/hopsworks/vector_similarity_search/news-search-knn-save-model.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,370 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "1e23a0d0", | ||
"metadata": {}, | ||
"source": [ | ||
"# News search using kNN in Hopsworks" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "03fb3165", | ||
"metadata": {}, | ||
"source": [ | ||
"In this tutorial, you are going to learn how to create a news search application which allows you to search news using natural language. You will create embedding for the news and search news similar to a given description using embeddings and kNN search. You will also learn how to avoid training-serving skew by using model registry. The steps include:\n", | ||
"1. Load news data\n", | ||
"2. Create embedddings for news heading and news body\n", | ||
"3. Save the embedding model to model registry\n", | ||
"4. Ingest the news data and embedding into Hopsworks\n", | ||
"5. Search news using Hopsworks" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "12c30fea", | ||
"metadata": {}, | ||
"source": [ | ||
"## Load news data" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "d5d5513d", | ||
"metadata": {}, | ||
"source": [ | ||
"First, you need to load the news articles downloaded from [Kaggle news articles](https://www.kaggle.com/datasets/asad1m9a9h6mood/news-articles).\n", | ||
"Since creating embeddings for the full news is time-consuming, here we sample some articles." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "e95346ff", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import pandas as pd\n", | ||
"\n", | ||
"df_all = pd.read_csv(\"https://repo.hops.works/dev/jdowling/Articles.csv\", encoding='utf-8', encoding_errors='ignore')\n", | ||
"df = df_all.sample(n=300).reset_index().drop([\"index\"], axis=1)\n", | ||
"df[\"news_id\"] = list(range(len(df)))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "96bfc948", | ||
"metadata": {}, | ||
"source": [ | ||
"## Create embeddings" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "b7bd09b2", | ||
"metadata": {}, | ||
"source": [ | ||
"Next, you need to create embeddings for heading and body of the news. The embeddings will then be used for kNN search against the embedding of the news description you want to search. Here we use a light weighted language model (LM) which encodes the news into embeddings. You can use any other language models including LLM (llama, Mistral)." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "88053d37", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!pip install sentence_transformers -q" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "9c22d8fb", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from sentence_transformers import SentenceTransformer\n", | ||
"model = SentenceTransformer('all-MiniLM-L6-v2')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "43e230f9", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# truncate the body to 100 characters\n", | ||
"embeddings_body = model.encode([body for body in df[\"Article\"]])\n", | ||
"embeddings_heading = model.encode(df[\"Heading\"])\n", | ||
"df[\"embedding_heading\"] = pd.Series(embeddings_heading.tolist())\n", | ||
"df[\"embedding_body\"] = pd.Series(embeddings_body.tolist())" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "41b150f6", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"df.head()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "65f73330", | ||
"metadata": {}, | ||
"source": [ | ||
"## Ingest into Hopsworks" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "0721e6c1", | ||
"metadata": {}, | ||
"source": [ | ||
"You need to ingest the data to Hopsworks, so that they are stored and indexed. First, you login into Hopsworks and prepare the feature store." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "27a19f00", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import hopsworks\n", | ||
"proj = hopsworks.login()\n", | ||
"fs = proj.get_feature_store()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "36cf4331", | ||
"metadata": {}, | ||
"source": [ | ||
"Next, as embeddings are stored in an index in the backing vecotor database, you need to specify the index name and the embedding features in the dataframe. You can also save the embedding model to model registry, and attach the model to the embedding features. This is useful for avoiding training-serving skew as at inference time you can get back the same model used for creating embedding at training time." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "4c564fc9", | ||
"metadata": {}, | ||
"source": [ | ||
"First, you save the model to model registry." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "6e5e5270", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import pickle\n", | ||
"model_name = \"SentenceTransformer_all_MiniLM_L6_v2\"\n", | ||
"mr = proj.get_model_registry()\n", | ||
"# Check if the model has been created, and get back the model if available. Otherwise create the model.\n", | ||
"try:\n", | ||
" hsml_model = mr.get_model(model_name, 1)\n", | ||
"except:\n", | ||
" with open(f\"{model_name}.pkl\", \"wb\") as f:\n", | ||
" pickle.dump(model, f)\n", | ||
" hsml_model = mr.python.create_model(model_name)\n", | ||
" hsml_model.save(f\"{model_name}.pkl\")\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "db4a76ea", | ||
"metadata": {}, | ||
"source": [ | ||
"Then, you specify the index name, embedding features, and model." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "47e0cf41", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"version = 1\n", | ||
"from hsfs import embedding\n", | ||
"\n", | ||
"emb = embedding.EmbeddingIndex(index_name=f\"news_fg_{version}\")\n", | ||
"# specify the name, dimension, and model of the embedding features \n", | ||
"emb.add_embedding(\"embedding_body\", model.get_sentence_embedding_dimension(), model=hsml_model)\n", | ||
"emb.add_embedding(\"embedding_heading\", model.get_sentence_embedding_dimension(), model=hsml_model)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "25bddf53", | ||
"metadata": {}, | ||
"source": [ | ||
"Next, you create a feature group with the `embedding_index` and ingest data into the feature group." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "e55522ba", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"news_fg = fs.get_or_create_feature_group(\n", | ||
" name=\"news_fg\",\n", | ||
" embedding_index=emb,\n", | ||
" primary_key=[\"news_id\"],\n", | ||
" version=version,\n", | ||
" online_enabled=True\n", | ||
")\n", | ||
"\n", | ||
"news_fg.insert(df, write_options={\"start_offline_materialization\": False})" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "db6b194b", | ||
"metadata": {}, | ||
"source": [ | ||
"## Search News" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "7fc8854e", | ||
"metadata": {}, | ||
"source": [ | ||
"Once the data are ingested into Hopsworks, you can search news by giving a news description. The news description first needs to be encoded by the same LM you used to encode the news. You can get back the model in the model registry from the embedding feature. And then you can search news which are similar to the description using kNN search functionality provided by the feature group." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "ede462ef", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# set the logging level to WARN to avoid INFO message\n", | ||
"import logging\n", | ||
"logging.getLogger().setLevel(logging.WARN)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "f8d7df08", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"news_description = \"news about europe\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "ba2b2bde", | ||
"metadata": {}, | ||
"source": [ | ||
"You can get back the model file from embedding feature object, and load the model file back to the embedding model." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "0f9f2cd4", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"hsml_model = news_fg.embedding_index.get_embedding(\"embedding_heading\").model\n", | ||
"local_model_path = hsml_model.download()\n", | ||
"with open(f\"{local_model_path}/{hsml_model.name}.pkl\", 'rb') as f:\n", | ||
" loaded_model = pickle.load(f)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "401915b0", | ||
"metadata": {}, | ||
"source": [ | ||
"You can search similar news to the description against news heading." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "7356be82", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"results = news_fg.find_neighbors(loaded_model.encode(news_description), k=3, col=\"embedding_heading\")\n", | ||
"# print out the heading\n", | ||
"for result in results:\n", | ||
" print(result[1][2])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "24e70246", | ||
"metadata": {}, | ||
"source": [ | ||
"Alternative, you can search similar news to the description against the news body and filter by news type." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "32d954bc", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"results = news_fg.find_neighbors(loaded_model.encode(news_description), k=3, col=\"embedding_body\",\n", | ||
" filter=news_fg.newstype == \"business\")\n", | ||
"# print out the heading\n", | ||
"for result in results:\n", | ||
" print(result[1][2])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "c0938afc", | ||
"metadata": {}, | ||
"source": [ | ||
"## Next step" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "28ffda57", | ||
"metadata": {}, | ||
"source": [ | ||
"Now you are able to search articles using natural language. You can learn how to rank the result in [this tutorial]() and learn best practices in the [guide]()." | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.13" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
Oops, something went wrong.