From 25d7d94b8934997596f0ddf6e30abd314dc9326b Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Tue, 30 Jul 2024 10:30:42 +0800 Subject: [PATCH] fix(rag): Fix db schema aretriever bug (#1755) --- .github/workflows/sync-docs.yaml | 20 ------------------- dbgpt/rag/retriever/db_schema.py | 6 ++++-- dbgpt/rag/retriever/tests/test_db_struct.py | 22 +++++++++++++-------- 3 files changed, 18 insertions(+), 30 deletions(-) delete mode 100644 .github/workflows/sync-docs.yaml diff --git a/.github/workflows/sync-docs.yaml b/.github/workflows/sync-docs.yaml deleted file mode 100644 index 03baeae86..000000000 --- a/.github/workflows/sync-docs.yaml +++ /dev/null @@ -1,20 +0,0 @@ -name: Trigger Auto Publish - -on: - push: - tags: - - "*" - -jobs: - trigger-api: - runs-on: ubuntu-latest - - steps: - - name: Checkout code - uses: actions/checkout@v2 - - - name: Trigger Publish API - run: | - curl -X POST ${{secrets.PUBLISH_SECRET_API}} \ - -H "Content-Type: application/json" \ - -d '{"tag": "${{ github.ref }}"}' diff --git a/dbgpt/rag/retriever/db_schema.py b/dbgpt/rag/retriever/db_schema.py index 0e9922f70..9bced9267 100644 --- a/dbgpt/rag/retriever/db_schema.py +++ b/dbgpt/rag/retriever/db_schema.py @@ -167,7 +167,7 @@ async def _aretrieve( result_candidates = await run_async_tasks( tasks=candidates, concurrency_limit=1 ) - return result_candidates + return cast(List[Chunk], reduce(lambda x, y: x + y, result_candidates)) else: from dbgpt.rag.summary.rdbms_db_summary import ( # noqa: F401 _parse_db_summary, @@ -177,7 +177,9 @@ async def _aretrieve( tasks=[self._aparse_db_summary(root_tracer.get_current_span_id())], concurrency_limit=1, ) - return [Chunk(content=table_summary) for table_summary in table_summaries] + return [ + Chunk(content=table_summary) for table_summary in table_summaries[0] + ] async def _aretrieve_with_score( self, diff --git a/dbgpt/rag/retriever/tests/test_db_struct.py b/dbgpt/rag/retriever/tests/test_db_struct.py index 4cda20365..0b667f69e 100644 --- a/dbgpt/rag/retriever/tests/test_db_struct.py +++ b/dbgpt/rag/retriever/tests/test_db_struct.py @@ -22,29 +22,35 @@ def mock_vector_store_connector(): @pytest.fixture -def dbstruct_retriever(mock_db_connection, mock_vector_store_connector): +def db_struct_retriever(mock_db_connection, mock_vector_store_connector): return DBSchemaRetriever( connector=mock_db_connection, index_store=mock_vector_store_connector, ) -def mock_parse_db_summary() -> str: +def mock_parse_db_summary(conn) -> List[str]: """Patch _parse_db_summary method.""" - return "Table summary" + return ["Table summary"] # Mocking the _parse_db_summary method in your test function @patch.object( dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary ) -def test_retrieve_with_mocked_summary(dbstruct_retriever): +def test_retrieve_with_mocked_summary(db_struct_retriever): query = "Table summary" - chunks: List[Chunk] = dbstruct_retriever._retrieve(query) + chunks: List[Chunk] = db_struct_retriever._retrieve(query) assert isinstance(chunks[0], Chunk) assert chunks[0].content == "Table summary" -async def async_mock_parse_db_summary() -> str: - """Asynchronous patch for _parse_db_summary method.""" - return "Table summary" +@pytest.mark.asyncio +@patch.object( + dbgpt.rag.summary.rdbms_db_summary, "_parse_db_summary", mock_parse_db_summary +) +async def test_aretrieve_with_mocked_summary(db_struct_retriever): + query = "Table summary" + chunks: List[Chunk] = await db_struct_retriever._aretrieve(query) + assert isinstance(chunks[0], Chunk) + assert chunks[0].content == "Table summary"