diff --git a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py
index cc77265d..f37c9dae 100644
--- a/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py
+++ b/pyrit/prompt_target/hugging_face/hugging_face_chat_target.py
@@ -38,7 +38,8 @@ class HuggingFaceChatTarget(PromptChatTarget):
     def __init__(
         self,
         *,
-        model_id: str,
+        model_id: Optional[str] = None,
+        model_path: Optional[str] = None,
         hf_access_token: Optional[str] = None,
         use_cuda: bool = False,
         tensor_format: str = "pt",
@@ -47,16 +48,26 @@ def __init__(
         temperature: float = 1.0,
         top_p: float = 1.0,
         skip_special_tokens: bool = True,
+        trust_remote_code: bool = False,
     ) -> None:
         super().__init__()
 
+        if (model_id is None) == (model_path is None):
+            raise ValueError("Provide exactly one of `model_id` or `model_path`.")
+
         self.model_id = model_id
+        self.model_path = model_path
         self.use_cuda = use_cuda
         self.tensor_format = tensor_format
+        self.trust_remote_code = trust_remote_code
 
-        # Use the `get_required_value` to get the API key (from env or passed value)
-        self.huggingface_token = default_values.get_required_value(
-            env_var_name=self.HUGGINGFACE_TOKEN_ENVIRONMENT_VARIABLE, passed_value=hf_access_token
+        # Only get the Hugging Face token if a model ID is provided
+        self.huggingface_token = (
+            default_values.get_required_value(
+                env_var_name=self.HUGGINGFACE_TOKEN_ENVIRONMENT_VARIABLE, passed_value=hf_access_token
+            )
+            if model_id
+            else None
         )
 
         try:
@@ -106,46 +117,69 @@ async def load_model_and_tokenizer(self):
             Exception: If the model loading fails.
         """
         try:
-            # Define the default Hugging Face cache directory
-            cache_dir = os.path.join(
-                os.path.expanduser("~"), ".cache", "huggingface", "hub", f"models--{self.model_id.replace('/', '--')}"
-            )
+            # Determine the identifier for caching purposes
+            model_identifier = self.model_path or self.model_id
 
             # Check if the model is already cached
-            if HuggingFaceChatTarget._cache_enabled and HuggingFaceChatTarget._cached_model_id == self.model_id:
-                logger.info(f"Using cached model and tokenizer for {self.model_id}.")
+            if HuggingFaceChatTarget._cache_enabled and HuggingFaceChatTarget._cached_model_id == model_identifier:
+                logger.info(f"Using cached model and tokenizer for {model_identifier}.")
                 self.model = HuggingFaceChatTarget._cached_model
                 self.tokenizer = HuggingFaceChatTarget._cached_tokenizer
                 return
 
-            if self.necessary_files is None:
-                # Download all files if no specific files are provided
-                logger.info(f"Downloading all files for {self.model_id}...")
-                await download_specific_files(self.model_id, None, self.huggingface_token, cache_dir)
+            if self.model_path:
+                # Load the tokenizer and model from the local directory
+                logger.info(f"Loading model from local path: {self.model_path}...")
+                self.tokenizer = AutoTokenizer.from_pretrained(
+                    self.model_path, trust_remote_code=self.trust_remote_code
+                )
+                self.model = AutoModelForCausalLM.from_pretrained(
+                    self.model_path, trust_remote_code=self.trust_remote_code
+                )
             else:
-                # Download only the necessary files
-                logger.info(f"Downloading specific files for {self.model_id}...")
-                await download_specific_files(self.model_id, self.necessary_files, self.huggingface_token, cache_dir)
-
-            # Load the tokenizer and model from the specified directory
-            logger.info(f"Loading model {self.model_id} from cache path: {cache_dir}...")
-            self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, cache_dir=cache_dir)
-            self.model = AutoModelForCausalLM.from_pretrained(self.model_id, cache_dir=cache_dir)
+                # Define the default Hugging Face cache directory
+                cache_dir = os.path.join(
+                    os.path.expanduser("~"),
+                    ".cache",
+                    "huggingface",
+                    "hub",
+                    f"models--{self.model_id.replace('/', '--')}",
+                )
+
+                if self.necessary_files is None:
+                    # Download all files if no specific files are provided
+                    logger.info(f"Downloading all files for {self.model_id}...")
+                    await download_specific_files(self.model_id, None, self.huggingface_token, cache_dir)
+                else:
+                    # Download only the necessary files
+                    logger.info(f"Downloading specific files for {self.model_id}...")
+                    await download_specific_files(
+                        self.model_id, self.necessary_files, self.huggingface_token, cache_dir
+                    )
+
+                # Load the tokenizer and model from the specified directory
+                logger.info(f"Loading model {self.model_id} from cache path: {cache_dir}...")
+                self.tokenizer = AutoTokenizer.from_pretrained(
+                    self.model_id, cache_dir=cache_dir, trust_remote_code=self.trust_remote_code
+                )
+                self.model = AutoModelForCausalLM.from_pretrained(
+                    self.model_id, cache_dir=cache_dir, trust_remote_code=self.trust_remote_code
+                )
 
             # Move the model to the correct device
             self.model = self.model.to(self.device)
 
             # Debug prints to check types
-            logger.info(f"Model loaded: {type(self.model)}")  # Debug print
-            logger.info(f"Tokenizer loaded: {type(self.tokenizer)}")  # Debug print
+            logger.info(f"Model loaded: {type(self.model)}")
+            logger.info(f"Tokenizer loaded: {type(self.tokenizer)}")
 
             # Cache the loaded model and tokenizer if caching is enabled
             if HuggingFaceChatTarget._cache_enabled:
                 HuggingFaceChatTarget._cached_model = self.model
                 HuggingFaceChatTarget._cached_tokenizer = self.tokenizer
-                HuggingFaceChatTarget._cached_model_id = self.model_id
+                HuggingFaceChatTarget._cached_model_id = model_identifier
 
-            logger.info(f"Model {self.model_id} loaded successfully.")
+            logger.info(f"Model {model_identifier} loaded successfully.")
 
         except Exception as e:
             logger.error(f"Error loading model {self.model_id}: {e}")
@@ -207,10 +241,12 @@ async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse:
 
             logger.info(f"Assistant's response: {assistant_response}")
 
+            model_identifier = self.model_id or self.model_path
+
             return construct_response_from_request(
                 request=request,
                 response_text_pieces=[assistant_response],
-                prompt_metadata=json.dumps({"model_id": self.model_id}),
+                prompt_metadata=json.dumps({"model_id": model_identifier}),
             )
 
         except Exception as e:
diff --git a/tests/target/test_huggingface_chat_target.py b/tests/target/test_huggingface_chat_target.py
index b4f43a78..208743f6 100644
--- a/tests/target/test_huggingface_chat_target.py
+++ b/tests/target/test_huggingface_chat_target.py
@@ -225,3 +225,41 @@ def test_enable_disable_cache():
     assert HuggingFaceChatTarget._cached_model is None
     assert HuggingFaceChatTarget._cached_tokenizer is None
     assert HuggingFaceChatTarget._cached_model_id is None
+
+
+@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed")
+@pytest.mark.asyncio
+async def test_load_model_with_model_path():
+    """Test loading a model from a local directory (`model_path`)."""
+    model_path = "./mock_local_model_path"
+    hf_chat = HuggingFaceChatTarget(model_path=model_path, use_cuda=False, trust_remote_code=False)
+    await hf_chat.load_model_and_tokenizer()
+    assert hf_chat.model is not None
+    assert hf_chat.tokenizer is not None
+
+
+@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed")
+@pytest.mark.asyncio
+async def test_load_model_with_trust_remote_code():
+    """Test loading a remote model requiring `trust_remote_code=True`."""
+    model_id = "mock_remote_model"
+    hf_chat = HuggingFaceChatTarget(model_id=model_id, use_cuda=False, trust_remote_code=True)
+    await hf_chat.load_model_and_tokenizer()
+    assert hf_chat.model is not None
+    assert hf_chat.tokenizer is not None
+
+
+@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed")
+def test_init_with_both_model_id_and_model_path_raises():
+    """Ensure providing both `model_id` and `model_path` raises an error."""
+    with pytest.raises(ValueError) as excinfo:
+        HuggingFaceChatTarget(model_id="test_model", model_path="./mock_local_model_path", use_cuda=False)
+    assert "Provide exactly one of `model_id` or `model_path`." in str(excinfo.value)
+
+
+@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed")
+def test_load_model_without_model_id_or_path():
+    """Ensure initializing without `model_id` or `model_path` raises an error."""
+    with pytest.raises(ValueError) as excinfo:
+        HuggingFaceChatTarget(use_cuda=False)
+    assert "Provide exactly one of `model_id` or `model_path`." in str(excinfo.value)
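
Usage sketch (illustrative, not part of the patch): with the constructor change above, callers pass exactly one of `model_id` or `model_path`; the model ID and local path below are placeholders, and the import path is taken from the file location in the diff.

import asyncio

from pyrit.prompt_target.hugging_face.hugging_face_chat_target import HuggingFaceChatTarget


async def main() -> None:
    # Hub-based loading by ID: the constructor still resolves a Hugging Face token
    # from the environment or the hf_access_token argument in this branch.
    hub_target = HuggingFaceChatTarget(model_id="some-org/some-chat-model", trust_remote_code=False)
    await hub_target.load_model_and_tokenizer()

    # Local loading: when only model_path is given, the token lookup is skipped entirely.
    local_target = HuggingFaceChatTarget(model_path="./models/my-local-model")
    await local_target.load_model_and_tokenizer()


asyncio.run(main())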