diff --git a/application/api/user/routes.py b/application/api/user/routes.py index 2ead8ef10..0bdd7ab4a 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -479,6 +479,19 @@ def get(self): "retriever": "brave_search", } ) + + if "twitter_search" in settings.RETRIEVERS_ENABLED: + data.append( + { + "name": "Twitter Search", + "language": "en", + "date": "twitter_search", + "model": settings.EMBEDDINGS_NAME, + "location": "custom", + "tokens": "", + "retriever": "twitter_search", + } + ) except Exception as err: return make_response(jsonify({"success": False, "error": str(err)}), 400) diff --git a/application/core/settings.py b/application/core/settings.py index d4b02481c..263c4bfa6 100644 --- a/application/core/settings.py +++ b/application/core/settings.py @@ -19,7 +19,7 @@ class Settings(BaseSettings): MODEL_TOKEN_LIMITS: dict = {"gpt-3.5-turbo": 4096, "claude-2": 1e5} UPLOAD_FOLDER: str = "inputs" VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb" - RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"] # also brave_search + RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search", "twitter_search"] # also brave_search # LLM Cache CACHE_REDIS_URL: str = "redis://localhost:6379/2" @@ -74,7 +74,8 @@ class Settings(BaseSettings): LANCEDB_PATH: str = "/tmp/lancedb" # Path where LanceDB stores its local data LANCEDB_TABLE_NAME: Optional[str] = "docsgpts" # Name of the table to use for storing vectors BRAVE_SEARCH_API_KEY: Optional[str] = None - + TWITTER_API_KEY: Optional[str] = None + TWITTER_API_KEY_SECRET: Optional[str] = None FLASK_DEBUG_MODE: bool = False diff --git a/application/requirements.txt b/application/requirements.txt index 6ea1d1ba2..966d6a38b 100644 --- a/application/requirements.txt +++ b/application/requirements.txt @@ -70,6 +70,7 @@ redis==5.0.1 referencing==0.30.2 regex==2024.9.11 requests==2.32.3 +requests-oauthlib==2.0.0 retry==0.9.2 sentence-transformers==3.0.1 tiktoken==0.7.0 diff --git a/application/retriever/retriever_creator.py b/application/retriever/retriever_creator.py index 07be373d8..5850844d0 100644 --- a/application/retriever/retriever_creator.py +++ b/application/retriever/retriever_creator.py @@ -1,7 +1,7 @@ from application.retriever.classic_rag import ClassicRAG from application.retriever.duckduck_search import DuckDuckSearch from application.retriever.brave_search import BraveRetSearch - +from application.retriever.twitter_search import TwitterRetSearch class RetrieverCreator: @@ -9,6 +9,7 @@ class RetrieverCreator: 'classic': ClassicRAG, 'duckduck_search': DuckDuckSearch, 'brave_search': BraveRetSearch, + 'twitter_search': TwitterRetSearch, 'default': ClassicRAG } diff --git a/application/retriever/twitter_search.py b/application/retriever/twitter_search.py new file mode 100644 index 000000000..434b48015 --- /dev/null +++ b/application/retriever/twitter_search.py @@ -0,0 +1,187 @@ +import sys +from application.retriever.base import BaseRetriever +from application.retriever.classic_rag import ClassicRAG +from application.core.settings import settings +from application.llm.llm_creator import LLMCreator +from application.utils import num_tokens_from_string +import requests +import base64 + + +class TwitterRetSearch(BaseRetriever): + + def __init__( + self, + question, + source, + chat_history, + prompt, + chunks=2, + token_limit=150, + gpt_model="docsgpt", + user_api_key=None, + ): + self.question = question + self.source = source + self.chat_history = chat_history + self.prompt = prompt + self.chunks = chunks + self.gpt_model = gpt_model + self.token_limit = ( + token_limit + if token_limit + < settings.MODEL_TOKEN_LIMITS.get( + self.gpt_model, settings.DEFAULT_MAX_HISTORY + ) + else settings.MODEL_TOKEN_LIMITS.get( + self.gpt_model, settings.DEFAULT_MAX_HISTORY + ) + ) + self.user_api_key = user_api_key + + def _get_data(self): + if self.chunks == 0: + docs = [] + else: + # Question should ask llm to generate search query for twitter based on the question + llm_query = f"Generate a search term for the Twitter API based on: {self.question}. Provide single or multiple words without quotes." + + messages_combine = [{"role": "user", "content": llm_query}] + llm = LLMCreator.create_llm( + settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key + ) + + completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine) + twitter_search_query = "" + for line in completion: + twitter_search_query += str(line) + + results = self.search_tweets(twitter_search_query, count=int(self.chunks)) + + # TODO work on processing the results json below by following proper schema of x api + + docs = [] + for i in results: + try: + title = i["title"] + link = i["link"] + snippet = i["snippet"] + docs.append({"text": snippet, "title": title, "link": link}) + except IndexError: + pass + if settings.LLM_NAME == "llama.cpp": + docs = [docs[0]] + + return docs + + def gen(self): + docs = self._get_data() + + # join all page_content together with a newline + docs_together = "\n".join([doc["text"] for doc in docs]) + p_chat_combine = self.prompt.replace("{summaries}", docs_together) + messages_combine = [{"role": "system", "content": p_chat_combine}] + for doc in docs: + yield {"source": doc} + + if len(self.chat_history) > 1: + tokens_current_history = 0 + # count tokens in history + self.chat_history.reverse() + for i in self.chat_history: + if "prompt" in i and "response" in i: + tokens_batch = num_tokens_from_string(i["prompt"]) + num_tokens_from_string( + i["response"] + ) + if tokens_current_history + tokens_batch < self.token_limit: + tokens_current_history += tokens_batch + messages_combine.append( + {"role": "user", "content": i["prompt"]} + ) + messages_combine.append( + {"role": "system", "content": i["response"]} + ) + messages_combine.append({"role": "user", "content": self.question}) + + llm = LLMCreator.create_llm( + settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key + ) + + completion = llm.gen_stream(model=self.gpt_model, messages=messages_combine) + for line in completion: + yield {"answer": str(line)} + + def search(self): + return self._get_data() + + def get_params(self): + return { + "question": self.question, + "source": self.source, + "chat_history": self.chat_history, + "prompt": self.prompt, + "chunks": self.chunks, + "token_limit": self.token_limit, + "gpt_model": self.gpt_model, + "user_api_key": self.user_api_key + } + + + def get_bearer_token(self, consumer_key, consumer_secret): + + # Step 1: Concatenate with a colon + bearer_token_credentials = f"{consumer_key}:{consumer_secret}" + + # Step 2: Base64 encode the concatenated string + base64_encoded_credentials = base64.b64encode(bearer_token_credentials.encode('utf-8')).decode('utf-8') + + # Step 3: Obtain Bearer Token + url = 'https://api.x.com/oauth2/token' + headers = { + 'Authorization': f'Basic {base64_encoded_credentials}', + 'Content-Type': 'application/x-www-form-urlencoded;charset=UTF-8' + } + data = { + 'grant_type': 'client_credentials' + } + + # Make the POST request to get the Bearer Token + response = requests.post(url, headers=headers, data=data) + + # Check if the request was successful + if response.status_code == 200: + token_response = response.json() + return token_response.get('access_token') + else: + raise Exception(f"Failed to get bearer token: {response.status_code}, {response.text}") + + + # Function to search for tweets using Twitter API v1.1 + def search_tweets(self, search_term): + oauth2_token = self.get_bearer_token(settings.TWITTER_API_KEY, settings.TWITTER_API_KEY_SECRET) + print(oauth2_token, file=sys.stderr) + # Parameters for the search query + params = { + 'query': search_term, + } + + # Make the GET request using httpx + SEARCH_URL = "https://api.twitter.com/2/tweets/search/recent" + headers = { + 'Authorization': f'Bearer {oauth2_token}' + } + response = requests.get(SEARCH_URL, headers=headers, params=params) + + print(response.status_code, file=sys.stderr) + + # Check if the response is OK + if response.status_code != 200: + print(response.text, file=sys.stderr) + raise Exception(f"Request failed: {response.status_code} {response.text}") + + # Parse the JSON response + tweet_data = response.json() + + # Extract and return relevant tweet information + return tweet_data.get('statuses', []) +