-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: separate app components (#17)
- Loading branch information
Showing
10 changed files
with
180 additions
and
118 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,139 +1,33 @@ | ||
""" A simple example of Streamlit. """ | ||
import textwrap | ||
import os | ||
import tiktoken | ||
import fitz | ||
import streamlit as st | ||
import openai | ||
from dotenv import load_dotenv | ||
from langchain.chat_models import ChatOpenAI | ||
from langchain.callbacks import StreamlitCallbackHandler | ||
from src.chroma_client import ChromaDB | ||
import src.gui_messages as gm | ||
from src import settings | ||
|
||
from src.agent import PDFExplainer | ||
from gnosis.chroma_client import ChromaDB | ||
import gnosis.gui_messages as gm | ||
from gnosis import settings | ||
from gnosis.components.sidebar import sidebar | ||
from gnosis.components.main import main | ||
|
||
|
||
load_dotenv() | ||
|
||
|
||
def set_api_key(): | ||
"""Set the OpenAI API key.""" | ||
openai.api_key = st.session_state.api_key | ||
st.session_state.api_message = gm.api_message(openai.api_key) | ||
|
||
|
||
def click_wk_button(): | ||
"""Set the OpenAI API key.""" | ||
st.session_state.wk_button = not st.session_state.wk_button | ||
|
||
|
||
openai.api_key = os.getenv("OPENAI_API_KEY") | ||
|
||
if "api_message" not in st.session_state: | ||
st.session_state.api_message = gm.api_message(openai.api_key) | ||
|
||
|
||
if "wk_button" not in st.session_state: | ||
st.session_state.wk_button = False | ||
|
||
|
||
# Build settings | ||
chroma_db = ChromaDB(openai.api_key) | ||
collection = settings.build(chroma_db) | ||
|
||
# Sidebar | ||
with st.sidebar: | ||
st.write("## OpenAI API key") | ||
openai.api_key = st.text_input( | ||
"Enter OpenAI API key", | ||
value="", | ||
type="password", | ||
key="api_key", | ||
placeholder="Enter your OpenAI API key", | ||
on_change=set_api_key, | ||
label_visibility="collapsed", | ||
) | ||
st.write( | ||
"You can find your API key at https://platform.openai.com/account/api-keys" | ||
) | ||
if "wk_button" not in st.session_state: | ||
st.session_state.wk_button = False | ||
|
||
st.checkbox( | ||
"Use Wikipedia", on_change=click_wk_button, value=st.session_state.wk_button | ||
) | ||
st.subheader("Creativity") | ||
st.write("The higher the value, the crazier the text.") | ||
st.slider( | ||
"Temperature", | ||
min_value=0.0, | ||
max_value=2.0, | ||
value=0.9, | ||
step=0.01, | ||
key="temperature", | ||
) | ||
|
||
if st.button("Delete collection"): | ||
st.warning("Are you sure?") | ||
if st.button("Yes"): | ||
try: | ||
chroma_db.delete_collection(collection.name) | ||
except AttributeError: | ||
st.error("Collection erased.") | ||
|
||
# Main | ||
st.title("GnosisPages") | ||
st.subheader("Create your knowledge base") | ||
|
||
## Uploader | ||
|
||
st.write( | ||
"Upload, extract and consult the content of PDF Files for builiding your knowledge base!" | ||
) | ||
pdf = st.file_uploader("Upload a file", type="pdf") | ||
|
||
if pdf is not None: | ||
with fitz.open(stream=pdf.read(), filetype="pdf") as doc: # open document | ||
with st.spinner("Extracting text..."): | ||
text = chr(12).join([page.get_text() for page in doc]) | ||
st.subheader("Text preview") | ||
st.write(text[0:300] + "...") | ||
if st.button("Save chunks"): | ||
with st.spinner("Saving chunks..."): | ||
chunks = textwrap.wrap(text, 1250) | ||
for idx, chunk in enumerate(chunks): | ||
encoding = tiktoken.get_encoding("cl100k_base") | ||
num_tokens = len(encoding.encode(chunk)) | ||
collection.add( | ||
documents=[chunk], | ||
metadatas=[{"source": pdf.name, "num_tokens": num_tokens}], | ||
ids=[pdf.name + str(idx)], | ||
) | ||
else: | ||
st.write("Please upload a file of type: pdf") | ||
|
||
st.subheader("Consult your knowledge base") | ||
|
||
|
||
prompt = st.chat_input() | ||
|
||
if prompt: | ||
# Create Agent | ||
try: | ||
openai_api_key = openai.api_key | ||
llm = ChatOpenAI( | ||
temperature=st.session_state.temperature, | ||
model="gpt-3.5-turbo-16k", | ||
api_key=openai.api_key, | ||
) | ||
agent = PDFExplainer( | ||
llm, | ||
chroma_db, | ||
extra_tools=st.session_state.wk_button, | ||
).agent | ||
except Exception: # pylint: disable=broad-exception-caught | ||
st.warning("Missing OpenAI API Key.") | ||
sidebar(chroma_db, collection) | ||
|
||
st.chat_message("user").write(prompt) | ||
with st.chat_message("assistant"): | ||
st_callback = StreamlitCallbackHandler(st.container()) | ||
response = agent.run(prompt, callbacks=[st_callback]) | ||
st.write(response) | ||
main(openai.api_key, chroma_db, collection) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
"""Module for building the Langchain Agent""" | ||
import streamlit as st | ||
from langchain.chat_models import ChatOpenAI | ||
from gnosis.agent import PDFExplainer | ||
|
||
|
||
def build(key, client): | ||
"""An Agent builder""" | ||
# Build Agent | ||
try: | ||
print(str(st.session_state.temperature)) | ||
llm = ChatOpenAI( | ||
temperature=st.session_state.temperature, | ||
model="gpt-3.5-turbo-16k", | ||
api_key=key, | ||
) | ||
agent = PDFExplainer( | ||
llm, | ||
client, | ||
extra_tools=st.session_state.wk_button, | ||
).agent | ||
except Exception: # pylint: disable=broad-exception-caught | ||
st.warning("Missing OpenAI API Key.") | ||
|
||
return agent |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
"""Handler functions for the components""" | ||
import streamlit as st | ||
import openai | ||
import gnosis.gui_messages as gm | ||
|
||
|
||
def set_api_key(): | ||
"""Set the OpenAI API key.""" | ||
openai.api_key = st.session_state.api_key | ||
st.session_state.api_message = gm.api_message(openai.api_key) | ||
|
||
|
||
def click_wk_button(): | ||
"""Set the OpenAI API key.""" | ||
st.session_state.wk_button = not st.session_state.wk_button |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
"""Main component""" | ||
import textwrap | ||
import tiktoken | ||
import fitz | ||
import streamlit as st | ||
from langchain.callbacks import StreamlitCallbackHandler | ||
import gnosis.gui_messages as gm | ||
from gnosis.builder import build | ||
|
||
|
||
def uploader(collection): | ||
"""Component for upload files""" | ||
st.write( | ||
"Upload, extract and consult the content of PDF Files for builiding your knowledge base!" | ||
) | ||
pdf = st.file_uploader("Upload a file", type="pdf") | ||
|
||
if pdf is not None: | ||
with fitz.open(stream=pdf.read(), filetype="pdf") as doc: # open document | ||
with st.spinner("Extracting text..."): | ||
text = chr(12).join([page.get_text() for page in doc]) | ||
st.subheader("Text preview") | ||
st.write(text[0:300] + "...") | ||
if st.button("Save chunks"): | ||
with st.spinner("Saving chunks..."): | ||
chunks = textwrap.wrap(text, 1250) | ||
for idx, chunk in enumerate(chunks): | ||
encoding = tiktoken.get_encoding("cl100k_base") | ||
num_tokens = len(encoding.encode(chunk)) | ||
collection.add( | ||
documents=[chunk], | ||
metadatas=[{"source": pdf.name, "num_tokens": num_tokens}], | ||
ids=[pdf.name + str(idx)], | ||
) | ||
else: | ||
st.write("Please upload a file of type: pdf") | ||
|
||
|
||
def main(key, client, collection): | ||
"""Main component""" | ||
gm.header() | ||
|
||
uploader(collection) | ||
|
||
st.subheader("Consult your knowledge base") | ||
|
||
prompt = st.chat_input() | ||
|
||
if prompt: | ||
agent = build(key, client) | ||
|
||
st.chat_message("user").write(prompt) | ||
with st.chat_message("assistant"): | ||
st_callback = StreamlitCallbackHandler(st.container()) | ||
response = agent.run(prompt, callbacks=[st_callback]) | ||
st.write(response) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
"""Sidebar component for the Streamlit app.""" | ||
import streamlit as st | ||
import openai | ||
from gnosis.components.handlers import set_api_key, click_wk_button | ||
|
||
|
||
def delete_collection(client, collection): | ||
"""Delete collection button.""" | ||
if st.button("Delete collection"): | ||
st.warning("Are you sure?") | ||
if st.button("Yes"): | ||
try: | ||
client.delete_collection(collection.name) | ||
except AttributeError: | ||
st.error("Collection erased.") | ||
|
||
|
||
def openai_api_key_box(): | ||
"""Box for entrying OpenAi API Key""" | ||
st.sidebar.write("## OpenAI API key") | ||
openai.api_key = st.sidebar.text_input( | ||
"Enter OpenAI API key", | ||
value="", | ||
type="password", | ||
key="api_key", | ||
placeholder="Enter your OpenAI API key", | ||
on_change=set_api_key, | ||
label_visibility="collapsed", | ||
) | ||
st.sidebar.write( | ||
"You can find your API key at https://platform.openai.com/account/api-keys" | ||
) | ||
|
||
|
||
def creativity_slider(): | ||
"""Slider with temperature level""" | ||
st.sidebar.subheader("Creativity") | ||
st.sidebar.write("The higher the value, the crazier the text.") | ||
st.sidebar.slider( | ||
"Temperature", | ||
min_value=0.0, | ||
max_value=1.25, # Max level is 2, but it's too stochastic | ||
value=0.5, | ||
step=0.01, | ||
key="temperature", | ||
) | ||
|
||
|
||
def wk_checkbox(): | ||
"""Wikipedia Checkbox for changing state""" | ||
st.sidebar.checkbox( | ||
"Use Wikipedia", on_change=click_wk_button, value=st.session_state.wk_button | ||
) | ||
|
||
|
||
# Sidebar | ||
def sidebar(client, collection): | ||
"""Sidebar component for the Streamlit app.""" | ||
with st.sidebar: | ||
openai_api_key_box() | ||
|
||
wk_checkbox() | ||
|
||
creativity_slider() | ||
|
||
delete_collection(client, collection) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.