diff --git a/docs/source/users/index.md b/docs/source/users/index.md index 4226df56b..b5e7066b5 100644 --- a/docs/source/users/index.md +++ b/docs/source/users/index.md @@ -888,8 +888,8 @@ This configuration allows specifying arbitrary parameters that are unpacked and This is useful for passing parameters such as model tuning that affect the response generation by the model. This is also an appropriate place to pass in custom attributes required by certain providers/models. -The accepted value should be a dictionary, with top level keys as the model id (provider:model_id), and value -should be any arbitrary dictionary which is unpacked and passed as is to the provider class. +The accepted value is a dictionary, with top level keys as the model id (provider:model_id), and value +should be any arbitrary dictionary which is unpacked and passed as-is to the provider class. #### Configuring as a startup option In this sample, the `bedrock` provider will be created with the value for `model_kwargs` when `ai21.j2-mid-v1` model is selected. diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py index dc5675b4f..5266b071d 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py @@ -518,13 +518,9 @@ def run_ai_cell(self, args: CellArgs, prompt: str): provider_params["request_schema"] = args.request_schema provider_params["response_path"] = args.response_path - if args.model_kwargs: - provider_params["model_kwargs"] = args.model_kwargs + model_parameters = json.loads(args.model_parameters) - if args.endpoint_kwargs: - provider_params["endpoint_kwargs"] = args.endpoint_kwargs - - provider = Provider(**provider_params) + provider = Provider(**provider_params, **model_parameters) # Apply a prompt template. prompt = provider.get_prompt_template(args.format).format(prompt=prompt) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py index da4e889b1..a3b14bbb0 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py @@ -33,20 +33,12 @@ + "does nothing with other providers." ) -ENDPOINT_ARGS_SHORT_OPTION = "-e" -ENDPOINT_ARGS_LONG_OPTION = "--endpoint-kwargs" -ENDPOINT_ARGS_HELP = ( +MODEL_PARAMETERS_SHORT_OPTION = "-m" +MODEL_PARAMETERS_LONG_OPTION = "--model-parameters" +MODEL_PARAMETERS_HELP = ( "A JSON value that specifies extra values that will be passed " - "to the SageMaker Endpoint invoke function." -) - -MODEL_ARGS_SHORT_OPTION = "-m" -MODEL_ARGS_LONG_OPTION = "--model-kwargs" -MODEL_ARGS_HELP = ( - "A JSON value that specifies extra values that will be passed to" - "the payload body of the invoke function. This can be useful to" - "pass model tuning parameters such as token count, temperature " - "etc., that affects the response generated by of a model." + "to the model. The accepted value parsed to a dict, unpacked " + "and passed as-is to the provider class." ) @@ -59,8 +51,7 @@ class CellArgs(BaseModel): region_name: Optional[str] request_schema: Optional[str] response_path: Optional[str] - model_kwargs: Optional[str] - endpoint_kwargs: Optional[str] + model_parameters: Optional[str] # Should match CellArgs, but without "reset" @@ -161,18 +152,12 @@ def verify_json_value(ctx, param, value): help=RESPONSE_PATH_HELP, ) @click.option( - ENDPOINT_ARGS_SHORT_OPTION, - ENDPOINT_ARGS_LONG_OPTION, - required=False, - help=ENDPOINT_ARGS_HELP, - callback=verify_json_value, -) -@click.option( - MODEL_ARGS_SHORT_OPTION, - MODEL_ARGS_LONG_OPTION, + MODEL_PARAMETERS_SHORT_OPTION, + MODEL_PARAMETERS_LONG_OPTION, required=False, - help=MODEL_ARGS_HELP, + help=MODEL_PARAMETERS_HELP, callback=verify_json_value, + default="{}", ) def cell_magic_parser(**kwargs): """ @@ -222,18 +207,12 @@ def line_magic_parser(): help=RESPONSE_PATH_HELP, ) @click.option( - ENDPOINT_ARGS_SHORT_OPTION, - ENDPOINT_ARGS_LONG_OPTION, - required=False, - help=ENDPOINT_ARGS_HELP, - callback=verify_json_value, -) -@click.option( - MODEL_ARGS_SHORT_OPTION, - MODEL_ARGS_LONG_OPTION, + MODEL_PARAMETERS_SHORT_OPTION, + MODEL_PARAMETERS_LONG_OPTION, required=False, - help=MODEL_ARGS_HELP, + help=MODEL_PARAMETERS_HELP, callback=verify_json_value, + default="{}", ) def error_subparser(**kwargs): """