Skip to content

Commit

Permalink
[GenAI Factory] Refactoring workflow, workflow server, client and con…
Browse files Browse the repository at this point in the history
…fig (#23)

* [GenAI Factory] Refactoring workflow, workflow server, client and config

* fixed import

* moved logger to workflow server and added setup.py

* format code

* import workflow_server from workflow_path

* remove chroma from requirements

* guys rc

* fix setup.py file

* put logger in utils & fix controller client in workflow server

* fix logger import

* cli entrypoint

* fix cli main

* .

* fix build graph

* lint

* fix workflow to schema

* update default user name

* add logger info message

* debug

* fix _get and graph build

* workflow schema graph as dict

* working example, working factory

* add agent example

* removed branch name from requirements

* updated quick_start example folder name

* last updates
  • Loading branch information
yonishelach authored Sep 24, 2024
1 parent 933b873 commit 90887a6
Show file tree
Hide file tree
Showing 47 changed files with 2,565 additions and 944 deletions.
12 changes: 10 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,21 @@

CONTROLLER_NAME = "genai-factory-controller"

.PHONY: genai-factory
genai-factory:
# Build the Docker image using the
docker-compose up -d --build
@echo "GenAI Factory Controller and UI application are running in the background"
@echo "UI application is available at http://localhost:3000"
@echo "Controller API is available at http://localhost:8001"

.PHONY: controller
controller:
# Build controller's image:
docker build -f controller/Dockerfile -t $(CONTROLLER_NAME):latest .

# Run controller locally in a container:
docker run -d --net host --name $(CONTROLLER_NAME) $(CONTROLLER_NAME):latest
docker run -d -p 8001:8001 --name $(CONTROLLER_NAME) $(CONTROLLER_NAME):latest

# Announce the server is running:
@echo "GenAI Factory Controller is running in the background"
Expand All @@ -42,4 +50,4 @@ lint-imports: ## Validates import dependencies
fmt-check: ## Check the code (using ruff)
@echo "Running ruff checks..."
python -m ruff check --exit-non-zero-on-fix
python -m ruff format --check
python -m ruff format --check
11 changes: 9 additions & 2 deletions controller/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,15 @@ RUN mkdir -p ../data
# Install requirements:
RUN pip install -r /controller/requirements.txt

# Set python path environment variable:
ENV PYTHONPATH="/controller/src"
ENV CTRL_DATA_PATH="/data"

# Expose the controller's API port:
EXPOSE 8001

# Initiate database:
RUN python -m controller.src.main initdb
RUN python -m controller initdb

# Run the controller's API server:
CMD ["uvicorn", "controller.src.api:app", "--port", "8001", "--reload"]
CMD ["uvicorn", "controller.api:app","--host", "0.0.0.0", "--port", "8001", "--reload"]
3 changes: 2 additions & 1 deletion controller/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ python-dotenv
pyyaml
click
requests
tabulate
tabulate
git+https://github.com/mlrun/genai-factory.git
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def print_config():
@click.option(
"-m", "--metadata", type=(str, str), multiple=True, help="Metadata Key value pair"
)
@click.option("-v", "--version", type=str, help="document version")
@click.option("-v", "--version", type=str, help="document version", default="")
@click.option("-d", "--data-source", type=str, help="Data source name")
@click.option(
"-f", "--from-file", is_flag=True, help="Take the document paths from the file"
Expand All @@ -120,7 +120,7 @@ def ingest(path, project, name, loader, metadata, version, data_source, from_fil
:param from_file: Take the document paths from the file
"""
db_session = client.get_db_session()
project = client.get_project(project_name=project, db_session=db_session)
project = client.get_project(name=project, db_session=db_session)
data_source = client.get_data_source(
project_id=project.uid, name=data_source, db_session=db_session
)
Expand Down Expand Up @@ -200,7 +200,7 @@ def infer(
"""
db_session = client.get_db_session()

project = client.get_project(project_name=project, db_session=db_session)
project = client.get_project(name=project, db_session=db_session)
# Getting the workflow:
workflow = client.get_workflow(
project_id=project.uid, name=workflow_name, db_session=db_session
Expand Down
18 changes: 5 additions & 13 deletions controller/src/controller/api/endpoints/data_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,7 @@ def get_data_source(
:return: The data source from the database.
"""
project_id = client.get_project(
project_name=project_name, db_session=db_session
).uid
project_id = client.get_project(name=project_name, db_session=db_session).uid
try:
# Parse the version if provided:
uid, version = parse_version(uid, version)
Expand Down Expand Up @@ -153,9 +151,7 @@ def delete_data_source(
:returThe response from the database.
"""
project_id = client.get_project(
project_name=project_name, db_session=db_session
).uid
project_id = client.get_project(name=project_name, db_session=db_session).uid
uid, version = parse_version(uid, version)
try:
client.delete_data_source(
Expand Down Expand Up @@ -198,11 +194,9 @@ def list_data_sources(
:return: The response from the database.
"""
owner = client.get_user(user_name=auth.username, db_session=db_session)
owner = client.get_user(name=auth.username, db_session=db_session)
owner_id = getattr(owner, "uid", None)
project_id = client.get_project(
project_name=project_name, db_session=db_session
).uid
project_id = client.get_project(name=project_name, db_session=db_session).uid
try:
data = client.list_data_sources(
project_id=project_id,
Expand Down Expand Up @@ -251,9 +245,7 @@ def ingest(
:return: The response from the application.
"""
project_id = client.get_project(
project_name=project_name, db_session=db_session
).uid
project_id = client.get_project(name=project_name, db_session=db_session).uid
uid, ds_version = parse_version(uid, version)
data_source = client.get_data_source(
name=name,
Expand Down
14 changes: 4 additions & 10 deletions controller/src/controller/api/endpoints/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,7 @@ def get_dataset(
:return: The dataset from the database.
"""
project_id = client.get_project(
project_name=project_name, db_session=db_session
).uid
project_id = client.get_project(name=project_name, db_session=db_session).uid
try:
uid, version = parse_version(uid, version)
data = client.get_dataset(
Expand Down Expand Up @@ -137,9 +135,7 @@ def delete_dataset(
:return: The response from the database.
"""
project_id = client.get_project(
project_name=project_name, db_session=db_session
).uid
project_id = client.get_project(name=project_name, db_session=db_session).uid
uid, version = parse_version(uid, version)
try:
client.delete_dataset(
Expand Down Expand Up @@ -182,11 +178,9 @@ def list_datasets(
:return: The response from the database.
"""
owner = client.get_user(user_name=auth.username, db_session=db_session)
owner = client.get_user(name=auth.username, db_session=db_session)
owner_id = getattr(owner, "uid", None)
project_id = client.get_project(
project_name=project_name, db_session=db_session
).uid
project_id = client.get_project(name=project_name, db_session=db_session).uid
try:
data = client.list_datasets(
project_id=project_id,
Expand Down
14 changes: 4 additions & 10 deletions controller/src/controller/api/endpoints/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,7 @@ def get_document(
:return: The document from the database.
"""
project_id = client.get_project(
project_name=project_name, db_session=db_session
).uid
project_id = client.get_project(name=project_name, db_session=db_session).uid
uid, version = parse_version(uid, version)
try:
data = client.get_document(
Expand Down Expand Up @@ -139,9 +137,7 @@ def delete_document(
:return: The response from the database.
"""
project_id = client.get_project(
project_name=project_name, db_session=db_session
).uid
project_id = client.get_project(name=project_name, db_session=db_session).uid
uid, version = parse_version(uid, version)
try:
client.delete_document(
Expand Down Expand Up @@ -182,11 +178,9 @@ def list_documents(
:return: The response from the database.
"""
owner = client.get_user(user_name=auth.username, db_session=db_session)
owner = client.get_user(name=auth.username, db_session=db_session)
owner_id = getattr(owner, "uid", None)
project_id = client.get_project(
project_name=project_name, db_session=db_session
).uid
project_id = client.get_project(name=project_name, db_session=db_session).uid
try:
data = client.list_documents(
project_id=project_id,
Expand Down
14 changes: 4 additions & 10 deletions controller/src/controller/api/endpoints/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,7 @@ def get_model(
:return: The model from the database.
"""
project_id = client.get_project(
project_name=project_name, db_session=db_session
).uid
project_id = client.get_project(name=project_name, db_session=db_session).uid
uid, version = parse_version(uid, version)
try:
data = client.get_model(
Expand Down Expand Up @@ -137,9 +135,7 @@ def delete_model(
:return: The response from the database.
"""
project_id = client.get_project(
project_name=project_name, db_session=db_session
).uid
project_id = client.get_project(name=project_name, db_session=db_session).uid
uid, version = parse_version(uid, version)
try:
client.delete_model(
Expand Down Expand Up @@ -182,11 +178,9 @@ def list_models(
:return: The response from the database.
"""
owner = client.get_user(user_name=auth.username, db_session=db_session)
owner = client.get_user(name=auth.username, db_session=db_session)
owner_id = getattr(owner, "uid", None)
project_id = client.get_project(
project_name=project_name, db_session=db_session
).uid
project_id = client.get_project(name=project_name, db_session=db_session).uid
try:
data = client.list_models(
project_id=project_id,
Expand Down
2 changes: 1 addition & 1 deletion controller/src/controller/api/endpoints/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def list_projects(
:return: The response from the database.
"""
if owner_name is not None:
owner_id = client.get_user(user_name=owner_name, db_session=db_session).uid
owner_id = client.get_user(name=owner_name, db_session=db_session).uid
else:
owner_id = None
try:
Expand Down
14 changes: 4 additions & 10 deletions controller/src/controller/api/endpoints/prompt_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,7 @@ def get_prompt(
:return: The prompt from the database.
"""
project_id = client.get_project(
project_name=project_name, db_session=db_session
).uid
project_id = client.get_project(name=project_name, db_session=db_session).uid
uid, version = parse_version(uid, version)
try:
data = client.get_prompt_template(
Expand Down Expand Up @@ -139,9 +137,7 @@ def delete_prompt(
:return: The response from the database.
"""
project_id = client.get_project(
project_name=project_name, db_session=db_session
).uid
project_id = client.get_project(name=project_name, db_session=db_session).uid
uid, version = parse_version(uid, version)
try:
client.delete_prompt_template(
Expand Down Expand Up @@ -182,11 +178,9 @@ def list_prompts(
:return: The response from the database.
"""
owner = client.get_user(user_name=auth.username, db_session=db_session)
owner = client.get_user(name=auth.username, db_session=db_session)
owner_id = getattr(owner, "uid", None)
project_id = client.get_project(
project_name=project_name, db_session=db_session
).uid
project_id = client.get_project(name=project_name, db_session=db_session).uid
try:
data = client.list_prompt_templates(
project_id=project_id,
Expand Down
6 changes: 3 additions & 3 deletions controller/src/controller/api/endpoints/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def get_session(
"""
user_id = None
if name == "$last":
user_id = client.get_user(user_name=user_name, db_session=db_session).uid
user_id = client.get_user(name=user_name, db_session=db_session).uid
name = None
try:
data = client.get_session(
Expand Down Expand Up @@ -127,7 +127,7 @@ def delete_session(
:return: The response from the database.
"""
user_id = client.get_user(user_name=user_name, db_session=db_session).uid
user_id = client.get_user(name=user_name, db_session=db_session).uid
try:
client.delete_session(
name=name, uid=uid, user_id=user_id, db_session=db_session
Expand Down Expand Up @@ -163,7 +163,7 @@ def list_sessions(
:return: The response from the database.
"""
user_id = client.get_user(user_name=user_name, db_session=db_session).uid
user_id = client.get_user(name=user_name, db_session=db_session).uid
try:
data = client.list_sessions(
user_id=user_id,
Expand Down
23 changes: 8 additions & 15 deletions controller/src/controller/api/endpoints/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,7 @@ def get_workflow(
:return: The workflow from the database.
"""
project_id = client.get_project(
project_name=project_name, db_session=db_session
).uid
project_id = client.get_project(name=project_name, db_session=db_session).uid
uid, version = parse_version(uid, version)
try:
data = client.get_workflow(
Expand Down Expand Up @@ -153,9 +151,7 @@ def delete_workflow(
:return: The response from the database.
"""
project_id = client.get_project(
project_name=project_name, db_session=db_session
).uid
project_id = client.get_project(name=project_name, db_session=db_session).uid
uid, version = parse_version(uid=uid, version=version)
try:
client.delete_workflow(
Expand Down Expand Up @@ -198,11 +194,9 @@ def list_workflows(
:return: The response from the database.
"""
owner = client.get_user(user_name=auth.username, db_session=db_session)
owner = client.get_user(name=auth.username, db_session=db_session)
owner_id = getattr(owner, "uid", None)
project_id = client.get_project(
project_name=project_name, db_session=db_session
).uid
project_id = client.get_project(name=project_name, db_session=db_session).uid
try:
data = client.list_workflows(
name=name,
Expand Down Expand Up @@ -242,17 +236,14 @@ def infer_workflow(
:return: The response from the database.
"""
# Get workflow from the database
project_id = client.get_project(
project_name=project_name, db_session=db_session
).uid
project_id = client.get_project(name=project_name, db_session=db_session).uid
workflow = client.get_workflow(
project_id=project_id, name=name, db_session=db_session
)
if workflow is None:
return APIResponse(
success=False, error=f"Workflow with name = {name} not found"
)
path = workflow.get_infer_path()

if query.session_name:
# Get session by name:
Expand All @@ -263,7 +254,7 @@ def infer_workflow(
name=query.session_name,
workflow_id=workflow.uid,
owner_id=client.get_user(
user_name=auth.username, db_session=db_session
name=auth.username, db_session=db_session
).uid,
),
)
Expand All @@ -272,9 +263,11 @@ def infer_workflow(
"item": query.dict(),
"workflow": workflow.to_dict(short=True),
}
path = workflow.deployment

# Sent the event to the application's workflow:
try:
print(f"Sending data to {path}: {data}")
data = _send_to_application(
path=path,
method="POST",
Expand Down
Loading

0 comments on commit 90887a6

Please sign in to comment.