Skip to content

Commit

Permalink
use gemini 1.5
Browse files Browse the repository at this point in the history
  • Loading branch information
kaarthik108 committed Apr 11, 2024
1 parent e3272fb commit 019034a
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 13 deletions.
16 changes: 8 additions & 8 deletions chain.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Callable, Dict, Optional

import streamlit as st
from langchain.chat_models import ChatOpenAI
from langchain_community.chat_models import ChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.vectorstores import SupabaseVectorStore
Expand Down Expand Up @@ -33,7 +33,7 @@ class ModelConfig(BaseModel):

@validator("model_type", pre=True, always=True)
def validate_model_type(cls, v):
if v not in ["gpt", "mixtral8x22b", "claude", "mixtral8x7b"]:
if v not in ["gpt", "gemini", "claude", "mixtral8x7b"]:
raise ValueError(f"Unsupported model type: {v}")
return v

Expand All @@ -56,8 +56,8 @@ def setup(self):
self.setup_claude()
elif self.model_type == "mixtral8x7b":
self.setup_mixtral_8x7b()
elif self.model_type == "mixtral8x22b":
self.setup_mixtral_8x22b()
elif self.model_type == "gemini":
self.setup_gemini()


def setup_gpt(self):
Expand Down Expand Up @@ -97,9 +97,9 @@ def setup_claude(self):
},
)

def setup_mixtral_8x22b(self):
def setup_gemini(self):
self.llm = ChatOpenAI(
model_name="mistralai/mixtral-8x22b",
model_name="google/gemini-pro-1.5",
temperature=0.1,
api_key=self.secrets["OPENROUTER_API_KEY"],
max_tokens=700,
Expand Down Expand Up @@ -155,8 +155,8 @@ def load_chain(model_name="GPT-3.5", callback_handler=None):
model_type = "mixtral8x7b"
elif "claude" in model_name.lower():
model_type = "claude"
elif "mixtral 8x22b" in model_name.lower():
model_type = "mixtral8x22b"
elif "gemini" in model_name.lower():
model_type = "gemini"
else:
raise ValueError(f"Unsupported model name: {model_name}")

Expand Down
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
st.caption("Talk your way through data")
model = st.radio(
"",
options=["Claude-3 Haiku", "Mixtral 8x7B", "Mixtral 8x22B", "GPT-3.5"],
options=["Claude-3 Haiku", "Mixtral 8x7B", "Gemini 1.5 Pro", "GPT-3.5"],
index=0,
horizontal=True,
)
Expand Down
10 changes: 6 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
langchain==0.1.5
langchain==0.1.15
pandas==1.5.0
pydantic==1.10.8
snowflake_snowpark_python==1.5.0
snowflake-snowpark-python[pandas]
streamlit==1.31.0
supabase==1.0.3
supabase==2.4.1
unstructured==0.7.12
tiktoken==0.5.2
openai==1.11.0
openai==1.17.0
black==23.3.0
boto3==1.28.57
langchain_openai==0.0.5
langchain_openai==0.1.2
langchain-community==0.0.32
langchain-core==0.1.41
1 change: 1 addition & 0 deletions utils/snowchat_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def start_loading_message(self):
self.placeholder.markdown(loading_message_content, unsafe_allow_html=True)

def on_llm_new_token(self, token, run_id, parent_run_id=None, **kwargs):
print("on llm bnew token ",token)
if not self.has_streaming_started:
self.has_streaming_started = True

Expand Down

0 comments on commit 019034a

Please sign in to comment.