From 87bd30087a2f51c90db025e650b8bdcf7b58cc3c Mon Sep 17 00:00:00 2001 From: riteshghorse Date: Tue, 19 Sep 2023 00:59:20 -0400 Subject: [PATCH] update doc --- .../ml/inference/huggingface_inference.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/huggingface_inference.py b/sdks/python/apache_beam/ml/inference/huggingface_inference.py index 2a4c8b1344fe..bc330488bc45 100644 --- a/sdks/python/apache_beam/ml/inference/huggingface_inference.py +++ b/sdks/python/apache_beam/ml/inference/huggingface_inference.py @@ -571,7 +571,7 @@ class HuggingFacePipelineModelHandler(ModelHandler[str, def __init__( self, task: Union[str, PipelineTask] = "", - model=None, + model: str = "", *, inference_fn: PipelineInferenceFn = _default_pipeline_inference_fn, load_pipeline_args: Optional[Dict[str, Any]] = None, @@ -594,11 +594,16 @@ def __init__( Args: task (str or enum.Enum): task supported by HuggingFace Pipelines. Accepts a string task or an enum.Enum from PipelineTask. - model : path to the pretrained model-id on Hugging Face Models Hub - to use custom model for the chosen task. If the model already defines - the task then no need to specify the task parameter. Use the model_id - string instead of an actual model here. Model specific kwargs can be - specified with model_kwargs using load_pipeline_args. + model (str): path to the pretrained model-id on Hugging Face Models Hub + to use custom model for the chosen task. If the `model` already defines + the task then no need to specify the `task` parameter. Use the model-id + string instead of an actual model here. + Model-specific kwargs for `from_pretrained(..., **model_kwargs)` can be + specified with `model_kwargs` using `load_pipeline_args`. 
+ Example Usage: + model_handler = HuggingFacePipelineModelHandler( + task="text-generation", model="meta-llama/Llama-2-7b-hf", + load_pipeline_args={'model_kwargs':{'quantization_map':config}}) inference_fn: the inference function to use during RunInference. Default is _default_pipeline_inference_fn. load_pipeline_args (Dict[str, Any]): keyword arguments to provide load