diff --git a/CHANGELOG.md b/CHANGELOG.md index f5affc5..ca6fdb1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ ### Fixed - Removed deprecated LLMChain from GraphCypherQAChain to resolve instantiation issues with the use_function_response parameter. +- Removed unnecessary # type: ignore comments, improving type safety and code clarity. ## 0.1.1 diff --git a/libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py b/libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py index 69ff000..e84a4df 100644 --- a/libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py +++ b/libs/neo4j/langchain_neo4j/chains/graph_qa/cypher.py @@ -164,8 +164,8 @@ class GraphCypherQAChain(Chain): """ graph: GraphStore = Field(exclude=True) - cypher_generation_chain: Runnable - qa_chain: Runnable + cypher_generation_chain: Runnable[Dict[str, Any], str] + qa_chain: Runnable[Dict[str, Any], str] graph_schema: str input_key: str = "query" #: :meta private: output_key: str = "result" #: :meta private: @@ -239,7 +239,7 @@ def from_llm( qa_prompt: Optional[BasePromptTemplate] = None, cypher_prompt: Optional[BasePromptTemplate] = None, cypher_llm: Optional[BaseLanguageModel] = None, - qa_llm: Optional[Union[BaseLanguageModel, Any]] = None, + qa_llm: Optional[BaseLanguageModel] = None, exclude_types: List[str] = [], include_types: List[str] = [], validate_cypher: bool = False, @@ -250,16 +250,28 @@ def from_llm( **kwargs: Any, ) -> GraphCypherQAChain: """Initialize from LLM.""" + # Ensure at least one LLM is provided + if llm is None and qa_llm is None and cypher_llm is None: + raise ValueError("At least one LLM must be provided") - if not cypher_llm and not llm: - raise ValueError("Either `llm` or `cypher_llm` parameters must be provided") - if not qa_llm and not llm: - raise ValueError("Either `llm` or `qa_llm` parameters must be provided") - if cypher_llm and qa_llm and llm: + # Prevent all three LLMs from being provided simultaneously + if llm is not None and qa_llm is not None and cypher_llm is not None: raise ValueError( "You can specify up to two of 'cypher_llm', 'qa_llm'" ", and 'llm', but not all three simultaneously." ) + + # Assign default LLMs if specific ones are not provided + if llm is not None: + qa_llm = qa_llm or llm + cypher_llm = cypher_llm or llm + else: + # If llm is None, both qa_llm and cypher_llm must be provided + if qa_llm is None or cypher_llm is None: + raise ValueError( + "If `llm` is not provided, both `qa_llm` and `cypher_llm` must be " + "provided." + ) if cypher_prompt: if cypher_llm_kwargs: raise ValueError( @@ -271,6 +283,11 @@ def from_llm( cypher_prompt = cypher_llm_kwargs.pop( "prompt", CYPHER_GENERATION_PROMPT ) + if not isinstance(cypher_prompt, BasePromptTemplate): + raise ValueError( + "The cypher_llm_kwargs `prompt` must inherit from " + "BasePromptTemplate" + ) else: cypher_prompt = CYPHER_GENERATION_PROMPT if qa_prompt: @@ -282,6 +299,11 @@ def from_llm( else: if qa_llm_kwargs: qa_prompt = qa_llm_kwargs.pop("prompt", CYPHER_QA_PROMPT) + if not isinstance(qa_prompt, BasePromptTemplate): + raise ValueError( + "The qa_llm_kwargs `prompt` must inherit from " + "BasePromptTemplate" + ) else: qa_prompt = CYPHER_QA_PROMPT use_qa_llm_kwargs = qa_llm_kwargs if qa_llm_kwargs is not None else {} @@ -289,10 +311,12 @@ def from_llm( cypher_llm_kwargs if cypher_llm_kwargs is not None else {} ) - qa_llm = qa_llm or llm if use_function_response: try: - qa_llm.bind_tools({}) # type: ignore[union-attr] + if hasattr(qa_llm, "bind_tools"): + qa_llm.bind_tools({}) + else: + raise AttributeError response_prompt = ChatPromptTemplate.from_messages( [ SystemMessage(content=function_response_system), @@ -300,15 +324,13 @@ def from_llm( MessagesPlaceholder(variable_name="function_response"), ] ) - qa_chain = response_prompt | qa_llm | StrOutputParser() # type: ignore + qa_chain = response_prompt | qa_llm | StrOutputParser() except (NotImplementedError, AttributeError): raise ValueError("Provided LLM does not support native tools/functions") else: - qa_chain = qa_prompt | qa_llm.bind(**use_qa_llm_kwargs) | StrOutputParser() # type: ignore - - cypher_llm = cypher_llm or llm + qa_chain = qa_prompt | qa_llm.bind(**use_qa_llm_kwargs) | StrOutputParser() cypher_generation_chain = ( - cypher_prompt | cypher_llm.bind(**use_cypher_llm_kwargs) | StrOutputParser() # type: ignore + cypher_prompt | cypher_llm.bind(**use_cypher_llm_kwargs) | StrOutputParser() ) if exclude_types and include_types: @@ -379,6 +401,7 @@ def _call( else: context = [] + final_result: Union[List[Dict[str, Any]], str] if self.return_direct: final_result = context else: @@ -390,15 +413,14 @@ def _call( intermediate_steps.append({"context": context}) if self.use_function_response: function_response = get_function_response(question, context) - final_result = self.qa_chain.invoke( # type: ignore + final_result = self.qa_chain.invoke( {"question": question, "function_response": function_response}, ) else: - result = self.qa_chain.invoke( # type: ignore + final_result = self.qa_chain.invoke( {"question": question, "context": context}, callbacks=callbacks, ) - final_result = result # type: ignore chain_result: Dict[str, Any] = {self.output_key: final_result} if self.return_intermediate_steps: diff --git a/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py b/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py index 506533c..dd97de0 100644 --- a/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py +++ b/libs/neo4j/langchain_neo4j/graphs/neo4j_graph.py @@ -461,21 +461,23 @@ def query( or e.code == "Neo.DatabaseError.Transaction.TransactionStartFailed" ) - and "in an implicit transaction" in e.message # type: ignore + and e.message is not None + and "in an implicit transaction" in e.message ) or ( # isPeriodicCommitError e.code == "Neo.ClientError.Statement.SemanticError" + and e.message is not None and ( - "in an open transaction is not possible" in e.message # type: ignore - or "tried to execute in an explicit transaction" in e.message # type: ignore + "in an open transaction is not possible" in e.message + or "tried to execute in an explicit transaction" in e.message ) ) ): raise # fallback to allow implicit transactions with self._driver.session(database=self._database) as session: - data = session.run(Query(text=query, timeout=self.timeout), params) # type: ignore - json_data = [r.data() for r in data] + result = session.run(Query(text=query, timeout=self.timeout), params) + json_data = [r.data() for r in result] if self.sanitize: json_data = [value_sanitize(el) for el in json_data] return json_data diff --git a/libs/neo4j/langchain_neo4j/vectorstores/neo4j_vector.py b/libs/neo4j/langchain_neo4j/vectorstores/neo4j_vector.py index 452ffc1..9438f1b 100644 --- a/libs/neo4j/langchain_neo4j/vectorstores/neo4j_vector.py +++ b/libs/neo4j/langchain_neo4j/vectorstores/neo4j_vector.py @@ -655,21 +655,23 @@ def query( or e.code == "Neo.DatabaseError.Transaction.TransactionStartFailed" ) - and "in an implicit transaction" in e.message # type: ignore + and e.message is not None + and "in an implicit transaction" in e.message ) or ( # isPeriodicCommitError e.code == "Neo.ClientError.Statement.SemanticError" + and e.message is not None and ( - "in an open transaction is not possible" in e.message # type: ignore - or "tried to execute in an explicit transaction" in e.message # type: ignore + "in an open transaction is not possible" in e.message + or "tried to execute in an explicit transaction" in e.message ) ) ): raise # Fallback to allow implicit transactions with self._driver.session(database=self._database) as session: - data = session.run(Query(text=query), params) # type: ignore - return [r.data() for r in data] + result = session.run(Query(text=query), params) + return [r.data() for r in result] def verify_version(self) -> None: """ diff --git a/libs/neo4j/tests/integration_tests/graphs/test_neo4j.py b/libs/neo4j/tests/integration_tests/graphs/test_neo4j.py index 17fcace..6c27707 100644 --- a/libs/neo4j/tests/integration_tests/graphs/test_neo4j.py +++ b/libs/neo4j/tests/integration_tests/graphs/test_neo4j.py @@ -122,8 +122,9 @@ def test_neo4j_timeout() -> None: try: graph.query("UNWIND range(0,100000,1) AS i MERGE (:Foo {id:i})") except Exception as e: + assert hasattr(e, "code") assert ( - e.code # type: ignore[attr-defined] + e.code == "Neo.ClientError.Transaction.TransactionTimedOutClientConfiguration" ) diff --git a/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py b/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py index 0a74ffa..3a617b5 100644 --- a/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py +++ b/libs/neo4j/tests/unit_tests/graphs/test_neo4j_graph.py @@ -5,14 +5,14 @@ from langchain_neo4j.graphs.neo4j_graph import Neo4jGraph, value_sanitize -def test_value_sanitize_with_small_list(): # type: ignore[no-untyped-def] +def test_value_sanitize_with_small_list() -> None: small_list = list(range(15)) # list size > LIST_LIMIT input_dict = {"key1": "value1", "small_list": small_list} expected_output = {"key1": "value1", "small_list": small_list} assert value_sanitize(input_dict) == expected_output -def test_value_sanitize_with_oversized_list(): # type: ignore[no-untyped-def] +def test_value_sanitize_with_oversized_list() -> None: oversized_list = list(range(150)) # list size > LIST_LIMIT input_dict = {"key1": "value1", "oversized_list": oversized_list} expected_output = { @@ -22,21 +22,21 @@ def test_value_sanitize_with_oversized_list(): # type: ignore[no-untyped-def] assert value_sanitize(input_dict) == expected_output -def test_value_sanitize_with_nested_oversized_list(): # type: ignore[no-untyped-def] +def test_value_sanitize_with_nested_oversized_list() -> None: oversized_list = list(range(150)) # list size > LIST_LIMIT input_dict = {"key1": "value1", "oversized_list": {"key": oversized_list}} expected_output = {"key1": "value1", "oversized_list": {}} assert value_sanitize(input_dict) == expected_output -def test_value_sanitize_with_dict_in_list(): # type: ignore[no-untyped-def] +def test_value_sanitize_with_dict_in_list() -> None: oversized_list = list(range(150)) # list size > LIST_LIMIT input_dict = {"key1": "value1", "oversized_list": [1, 2, {"key": oversized_list}]} expected_output = {"key1": "value1", "oversized_list": [1, 2, {}]} assert value_sanitize(input_dict) == expected_output -def test_value_sanitize_with_dict_in_nested_list(): # type: ignore[no-untyped-def] +def test_value_sanitize_with_dict_in_nested_list() -> None: input_dict = { "key1": "value1", "deeply_nested_lists": [[[[{"final_nested_key": list(range(200))}]]]], @@ -45,9 +45,9 @@ def test_value_sanitize_with_dict_in_nested_list(): # type: ignore[no-untyped-d assert value_sanitize(input_dict) == expected_output -def test_driver_state_management(): # type: ignore[no-untyped-def] +def test_driver_state_management() -> None: """Comprehensive test for driver state management.""" - with patch("neo4j.GraphDatabase.driver") as mock_driver: + with patch("neo4j.GraphDatabase.driver", autospec=True) as mock_driver: # Setup mock driver mock_driver_instance = MagicMock() mock_driver.return_value = mock_driver_instance @@ -60,7 +60,7 @@ def test_driver_state_management(): # type: ignore[no-untyped-def] # Store original driver original_driver = graph._driver - original_driver.close = MagicMock() + assert isinstance(original_driver.close, MagicMock) # Test initial state assert hasattr(graph, "_driver") @@ -84,9 +84,9 @@ def test_driver_state_management(): # type: ignore[no-untyped-def] graph.refresh_schema() -def test_close_method_removes_driver(): # type: ignore[no-untyped-def] +def test_close_method_removes_driver() -> None: """Test that close method removes the _driver attribute.""" - with patch("neo4j.GraphDatabase.driver") as mock_driver: + with patch("neo4j.GraphDatabase.driver", autospec=True) as mock_driver: # Configure mock to return a mock driver mock_driver_instance = MagicMock() mock_driver.return_value = mock_driver_instance @@ -103,9 +103,7 @@ def test_close_method_removes_driver(): # type: ignore[no-untyped-def] # Store a reference to the original driver original_driver = graph._driver - - # Ensure driver's close method can be mocked - original_driver.close = MagicMock() + assert isinstance(original_driver.close, MagicMock) # Call close method graph.close() @@ -120,9 +118,9 @@ def test_close_method_removes_driver(): # type: ignore[no-untyped-def] graph.close() # Should not raise any exception -def test_multiple_close_calls_safe(): # type: ignore[no-untyped-def] +def test_multiple_close_calls_safe() -> None: """Test that multiple close calls do not raise errors.""" - with patch("neo4j.GraphDatabase.driver") as mock_driver: + with patch("neo4j.GraphDatabase.driver", autospec=True) as mock_driver: # Configure mock to return a mock driver mock_driver_instance = MagicMock() mock_driver.return_value = mock_driver_instance @@ -139,9 +137,7 @@ def test_multiple_close_calls_safe(): # type: ignore[no-untyped-def] # Store a reference to the original driver original_driver = graph._driver - - # Mock the driver's close method - original_driver.close = MagicMock() + assert isinstance(original_driver.close, MagicMock) # First close graph.close() diff --git a/libs/neo4j/tests/unit_tests/vectorstores/test_neo4j.py b/libs/neo4j/tests/unit_tests/vectorstores/test_neo4j.py index 837cb79..201864d 100644 --- a/libs/neo4j/tests/unit_tests/vectorstores/test_neo4j.py +++ b/libs/neo4j/tests/unit_tests/vectorstores/test_neo4j.py @@ -321,7 +321,7 @@ def test_get_search_index_query_invalid_search_type() -> None: with pytest.raises(ValueError) as exc_info: _get_search_index_query( - search_type=invalid_search_type, # type: ignore + search_type=invalid_search_type, # type: ignore[arg-type] index_type=IndexType.NODE, ) @@ -356,7 +356,7 @@ def test_check_if_not_null_with_none_value() -> None: def test_handle_field_filter_invalid_field_type() -> None: with pytest.raises(ValueError) as exc_info: - _handle_field_filter(field=123, value="some_value") # type: ignore + _handle_field_filter(field=123, value="some_value") # type: ignore[arg-type] assert "field should be a string" in str(exc_info.value) @@ -535,7 +535,7 @@ def test_neo4jvector_invalid_distance_strategy() -> None: url="bolt://localhost:7687", username="neo4j", password="password", - distance_strategy="INVALID_STRATEGY", # type: ignore + distance_strategy="INVALID_STRATEGY", # type: ignore[arg-type] ) assert "distance_strategy must be either 'EUCLIDEAN_DISTANCE' or 'COSINE'" in str( exc_info.value