diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py index b5741a19b..dc5675b4f 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py @@ -518,14 +518,11 @@ def run_ai_cell(self, args: CellArgs, prompt: str): provider_params["request_schema"] = args.request_schema provider_params["response_path"] = args.response_path - # Validate that the request schema is well-formed JSON - try: - json.loads(args.request_schema) - except json.JSONDecodeError as e: - raise ValueError( - "request-schema must be valid JSON. " - f"Error at line {e.lineno}, column {e.colno}: {e.msg}" - ) from None + if args.model_kwargs: + provider_params["model_kwargs"] = args.model_kwargs + + if args.endpoint_kwargs: + provider_params["endpoint_kwargs"] = args.endpoint_kwargs provider = Provider(**provider_params) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py index deffb3176..da4e889b1 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py @@ -1,3 +1,4 @@ +import json from typing import Literal, Optional, get_args import click @@ -33,14 +34,14 @@ ) ENDPOINT_ARGS_SHORT_OPTION = "-e" -ENDPOINT_ARGS_LONG_OPTION = "--endpoint-args" +ENDPOINT_ARGS_LONG_OPTION = "--endpoint-kwargs" ENDPOINT_ARGS_HELP = ( "A JSON value that specifies extra values that will be passed " "to the SageMaker Endpoint invoke function." ) MODEL_ARGS_SHORT_OPTION = "-m" -MODEL_ARGS_LONG_OPTION = "--model-args" +MODEL_ARGS_LONG_OPTION = "--model-kwargs" MODEL_ARGS_HELP = ( "A JSON value that specifies extra values that will be passed to" "the payload body of the invoke function. This can be useful to" @@ -58,6 +59,8 @@ class CellArgs(BaseModel): region_name: Optional[str] request_schema: Optional[str] response_path: Optional[str] + model_kwargs: Optional[str] + endpoint_kwargs: Optional[str] # Should match CellArgs, but without "reset" @@ -109,6 +112,19 @@ def get_help(self, ctx): click.echo(super().get_help(ctx)) +def verify_json_value(ctx, param, value): + if not value: + return value + try: + json.loads(value) + except json.JSONDecodeError as e: + raise ValueError( + f"{param.get_error_hint(ctx)} must be valid JSON. " + f"Error at line {e.lineno}, column {e.colno}: {e.msg}" + ) + return value + + @click.command() @click.argument("model_id") @click.option( @@ -136,6 +152,7 @@ def get_help(self, ctx): REQUEST_SCHEMA_LONG_OPTION, required=False, help=REQUEST_SCHEMA_HELP, + callback=verify_json_value, ) @click.option( RESPONSE_PATH_SHORT_OPTION, @@ -148,6 +165,14 @@ def get_help(self, ctx): ENDPOINT_ARGS_LONG_OPTION, required=False, help=ENDPOINT_ARGS_HELP, + callback=verify_json_value, +) +@click.option( + MODEL_ARGS_SHORT_OPTION, + MODEL_ARGS_LONG_OPTION, + required=False, + help=MODEL_ARGS_HELP, + callback=verify_json_value, ) def cell_magic_parser(**kwargs): """ @@ -188,6 +213,7 @@ def line_magic_parser(): REQUEST_SCHEMA_LONG_OPTION, required=False, help=REQUEST_SCHEMA_HELP, + callback=verify_json_value, ) @click.option( RESPONSE_PATH_SHORT_OPTION, @@ -195,6 +221,20 @@ def line_magic_parser(): required=False, help=RESPONSE_PATH_HELP, ) +@click.option( + ENDPOINT_ARGS_SHORT_OPTION, + ENDPOINT_ARGS_LONG_OPTION, + required=False, + help=ENDPOINT_ARGS_HELP, + callback=verify_json_value, +) +@click.option( + MODEL_ARGS_SHORT_OPTION, + MODEL_ARGS_LONG_OPTION, + required=False, + help=MODEL_ARGS_HELP, + callback=verify_json_value, +) def error_subparser(**kwargs): """ Explains the most recent error. Takes the same options (except -r) as diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py index 933d834ea..edc139d5e 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py @@ -677,8 +677,14 @@ class BedrockProvider(BaseProvider, Bedrock): format="text", ), TextField(key="region_name", label="Region name (optional)", format="text"), + MultilineTextField(key="model_kwargs", label="Model Arguments", format="json"), ] + def __init__(self, *args, **kwargs): + model_kwargs = pop_with_default(kwargs, "model_kwargs", "{}") + model_kwargs = json.loads(model_kwargs) + super().__init__(*args, **kwargs, model_kwargs=model_kwargs) + async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]: return await self._call_in_executor(*args, **kwargs) @@ -701,8 +707,14 @@ class BedrockChatProvider(BaseProvider, BedrockChat): format="text", ), TextField(key="region_name", label="Region name (optional)", format="text"), + MultilineTextField(key="model_kwargs", label="Model Arguments", format="json"), ] + def __init__(self, *args, **kwargs): + model_kwargs = pop_with_default(kwargs, "model_kwargs", "{}") + model_kwargs = json.loads(model_kwargs) + super().__init__(*args, **kwargs, model_kwargs=model_kwargs) + async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]: return await self._call_in_executor(*args, **kwargs)