Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Python] Add device param to HuggingFacePipelineModelHandler #29223

Merged
merged 7 commits into from
Nov 1, 2023
Merged
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 35 additions & 4 deletions sdks/python/apache_beam/ml/inference/huggingface_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,7 @@ def __init__(
task: Union[str, PipelineTask] = "",
model: str = "",
*,
device: Optional[str] = None,
inference_fn: PipelineInferenceFn = _default_pipeline_inference_fn,
load_pipeline_args: Optional[Dict[str, Any]] = None,
inference_args: Optional[Dict[str, Any]] = None,
Expand All @@ -583,10 +584,6 @@ def __init__(
"""
Implementation of the ModelHandler interface for Hugging Face Pipelines.

**Note:** To specify which device to use (CPU/GPU),
use the load_pipeline_args with key-value as you would do in the usual
Hugging Face pipeline. Ex: load_pipeline_args={'device':0})

Example Usage model::
pcoll | RunInference(HuggingFacePipelineModelHandler(
task="fill-mask"))
Expand All @@ -606,6 +603,11 @@ def __init__(
task="text-generation", model="meta-llama/Llama-2-7b-hf",
load_pipeline_args={'model_kwargs':{'quantization_map':config}})

device (str): the device (`"CPU"` or `"GPU"`) on which you wish to run
the pipeline. Defaults to GPU. If GPU is not available then it falls
back to CPU. You can also use advanced option like `device_map` with
key-value pair as you would do in the usual Hugging Face pipeline using
`load_pipeline_args`. Ex: load_pipeline_args={'device_map':auto}).
inference_fn: the inference function to use during RunInference.
Default is _default_pipeline_inference_fn.
load_pipeline_args (Dict[str, Any]): keyword arguments to provide load
Expand Down Expand Up @@ -638,8 +640,37 @@ def __init__(
if max_batch_size is not None:
self._batching_kwargs['max_batch_size'] = max_batch_size
self._large_model = large_model

# Check if the device is specified twice. If true then the device parameter
# of model handler is overridden.
self._deduplicate_device_value(device)
_validate_constructor_args_hf_pipeline(self._task, self._model)

def _deduplicate_device_value(self, device: str):
riteshghorse marked this conversation as resolved.
Show resolved Hide resolved
damccorm marked this conversation as resolved.
Show resolved Hide resolved
current_device = device.upper() if device else None
if 'device' not in self._load_pipeline_args:
if (not current_device and current_device != 'CPU' and
riteshghorse marked this conversation as resolved.
Show resolved Hide resolved
current_device != 'GPU'):
raise ValueError(
f"Invalid device value: {device}. Please specify "
"either CPU or GPU. Defaults to GPU if no value "
"is provided.")
elif current_device == 'CPU':
self._load_pipeline_args['device'] = 'cpu'
else:
if is_gpu_available_torch():
self._load_pipeline_args['device'] = 'cuda:1'
else:
_LOGGER.warning(
"HuggingFaceModelHandler specified a 'GPU' device, "
"but GPUs are not available. Switching to CPU.")
self._load_pipeline_args['device'] = 'cpu'
else:
if current_device:
_LOGGER.warning(
riteshghorse marked this conversation as resolved.
Show resolved Hide resolved
'`device` specified in `load_pipeline_args`. `device` '
'parameter for HuggingFacePipelineModelHandler will be ignored.')

def load_model(self):
"""Loads and initializes the pipeline for processing."""
return pipeline(
Expand Down
Loading