diff --git a/libs/core/langchain_core/output_parsers/xml.py b/libs/core/langchain_core/output_parsers/xml.py index 9a93023ce1298..c0e3e72baf4e1 100644 --- a/libs/core/langchain_core/output_parsers/xml.py +++ b/libs/core/langchain_core/output_parsers/xml.py @@ -55,9 +55,12 @@ def parse(self, text: str) -> Dict[str, List[Any]]: def _transform( self, input: Iterator[Union[str, BaseMessage]] ) -> Iterator[AddableDict]: + xml_start_re = re.compile(r"<[a-zA-Z:_]") parser = ET.XMLPullParser(["start", "end"]) + xml_started = False current_path: List[str] = [] current_path_has_children = False + buffer = "" for chunk in input: if isinstance(chunk, BaseMessage): # extract text @@ -65,8 +68,19 @@ def _transform( if not isinstance(chunk_content, str): continue chunk = chunk_content - # pass chunk to parser - parser.feed(chunk) + # add chunk to buffer of unprocessed text + buffer += chunk + # if xml string hasn't started yet, continue to next chunk + if not xml_started: + if match := xml_start_re.search(buffer): + # if xml string has started, remove all text before it + buffer = buffer[match.start() :] + xml_started = True + else: + continue + # feed buffer to parser + parser.feed(buffer) + buffer = "" # yield all events for event, elem in parser.read_events(): if event == "start": @@ -80,7 +94,10 @@ def _transform( if not current_path_has_children: yield nested_element(current_path, elem) # prevent yielding of parent element - current_path_has_children = True + if current_path: + current_path_has_children = True + else: + xml_started = False # close parser parser.close() diff --git a/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py b/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py index fb92e96331a9c..697f4e4776e81 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py +++ b/libs/core/tests/unit_tests/output_parsers/test_xml_parser.py @@ -22,7 +22,22 @@ @pytest.mark.parametrize( "result", - [DEF_RESULT_ENCODING, DEF_RESULT_ENCODING[DEF_RESULT_ENCODING.find("\n") :]], + [ + DEF_RESULT_ENCODING, + DEF_RESULT_ENCODING[DEF_RESULT_ENCODING.find("\n") :], + f""" +```xml +{DEF_RESULT_ENCODING} +``` +""", + f""" +Some random text +```xml +{DEF_RESULT_ENCODING} +``` +More random text +""", + ], ) def test_xml_output_parser(result: str) -> None: """Test XMLOutputParser."""