Skip to content

Commit

Permalink
Added model and endpoints kwargs options.
Browse files Browse the repository at this point in the history
  • Loading branch information
3coins committed Nov 8, 2023
1 parent b3e8f83 commit 960c692
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 10 deletions.
13 changes: 5 additions & 8 deletions packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,14 +518,11 @@ def run_ai_cell(self, args: CellArgs, prompt: str):
provider_params["request_schema"] = args.request_schema
provider_params["response_path"] = args.response_path

# Validate that the request schema is well-formed JSON
try:
json.loads(args.request_schema)
except json.JSONDecodeError as e:
raise ValueError(
"request-schema must be valid JSON. "
f"Error at line {e.lineno}, column {e.colno}: {e.msg}"
) from None
if args.model_kwargs:
provider_params["model_kwargs"] = args.model_kwargs

if args.endpoint_kwargs:
provider_params["endpoint_kwargs"] = args.endpoint_kwargs

provider = Provider(**provider_params)

Expand Down
44 changes: 42 additions & 2 deletions packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from typing import Literal, Optional, get_args

import click
Expand Down Expand Up @@ -33,14 +34,14 @@
)

ENDPOINT_ARGS_SHORT_OPTION = "-e"
ENDPOINT_ARGS_LONG_OPTION = "--endpoint-args"
ENDPOINT_ARGS_LONG_OPTION = "--endpoint-kwargs"
ENDPOINT_ARGS_HELP = (
"A JSON value that specifies extra values that will be passed "
"to the SageMaker Endpoint invoke function."
)

MODEL_ARGS_SHORT_OPTION = "-m"
MODEL_ARGS_LONG_OPTION = "--model-args"
MODEL_ARGS_LONG_OPTION = "--model-kwargs"
MODEL_ARGS_HELP = (
"A JSON value that specifies extra values that will be passed to"
"the payload body of the invoke function. This can be useful to"
Expand All @@ -58,6 +59,8 @@ class CellArgs(BaseModel):
region_name: Optional[str]
request_schema: Optional[str]
response_path: Optional[str]
model_kwargs: Optional[str]
endpoint_kwargs: Optional[str]


# Should match CellArgs, but without "reset"
Expand Down Expand Up @@ -109,6 +112,19 @@ def get_help(self, ctx):
click.echo(super().get_help(ctx))


def verify_json_value(ctx, param, value):
if not value:
return value
try:
json.loads(value)
except json.JSONDecodeError as e:
raise ValueError(
f"{param.get_error_hint(ctx)} must be valid JSON. "
f"Error at line {e.lineno}, column {e.colno}: {e.msg}"
)
return value


@click.command()
@click.argument("model_id")
@click.option(
Expand Down Expand Up @@ -136,6 +152,7 @@ def get_help(self, ctx):
REQUEST_SCHEMA_LONG_OPTION,
required=False,
help=REQUEST_SCHEMA_HELP,
callback=verify_json_value,
)
@click.option(
RESPONSE_PATH_SHORT_OPTION,
Expand All @@ -148,6 +165,14 @@ def get_help(self, ctx):
ENDPOINT_ARGS_LONG_OPTION,
required=False,
help=ENDPOINT_ARGS_HELP,
callback=verify_json_value,
)
@click.option(
MODEL_ARGS_SHORT_OPTION,
MODEL_ARGS_LONG_OPTION,
required=False,
help=MODEL_ARGS_HELP,
callback=verify_json_value,
)
def cell_magic_parser(**kwargs):
"""
Expand Down Expand Up @@ -188,13 +213,28 @@ def line_magic_parser():
REQUEST_SCHEMA_LONG_OPTION,
required=False,
help=REQUEST_SCHEMA_HELP,
callback=verify_json_value,
)
@click.option(
RESPONSE_PATH_SHORT_OPTION,
RESPONSE_PATH_LONG_OPTION,
required=False,
help=RESPONSE_PATH_HELP,
)
@click.option(
ENDPOINT_ARGS_SHORT_OPTION,
ENDPOINT_ARGS_LONG_OPTION,
required=False,
help=ENDPOINT_ARGS_HELP,
callback=verify_json_value,
)
@click.option(
MODEL_ARGS_SHORT_OPTION,
MODEL_ARGS_LONG_OPTION,
required=False,
help=MODEL_ARGS_HELP,
callback=verify_json_value,
)
def error_subparser(**kwargs):
"""
Explains the most recent error. Takes the same options (except -r) as
Expand Down
12 changes: 12 additions & 0 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,8 +677,14 @@ class BedrockProvider(BaseProvider, Bedrock):
format="text",
),
TextField(key="region_name", label="Region name (optional)", format="text"),
MultilineTextField(key="model_kwargs", label="Model Arguments", format="json"),
]

def __init__(self, *args, **kwargs):
model_kwargs = pop_with_default(kwargs, "model_kwargs", "{}")
model_kwargs = json.loads(model_kwargs)
super().__init__(*args, **kwargs, model_kwargs=model_kwargs)

async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
return await self._call_in_executor(*args, **kwargs)

Expand All @@ -701,8 +707,14 @@ class BedrockChatProvider(BaseProvider, BedrockChat):
format="text",
),
TextField(key="region_name", label="Region name (optional)", format="text"),
MultilineTextField(key="model_kwargs", label="Model Arguments", format="json"),
]

def __init__(self, *args, **kwargs):
model_kwargs = pop_with_default(kwargs, "model_kwargs", "{}")
model_kwargs = json.loads(model_kwargs)
super().__init__(*args, **kwargs, model_kwargs=model_kwargs)

async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
return await self._call_in_executor(*args, **kwargs)

Expand Down

0 comments on commit 960c692

Please sign in to comment.