Update headline.py

joaorura committed Oct 29, 2024
1 parent e7987e5 commit 60cd938
Showing 1 changed file with 60 additions and 3 deletions.
63 changes: 60 additions & 3 deletions src/ragas/testset/transforms/splitters/headline.py
@@ -1,10 +1,40 @@
+import re
 import typing as t
 from dataclasses import dataclass
+from unidecode import unidecode

 from ragas.testset.graph import Node, NodeType, Relationship
 from ragas.testset.transforms.base import Splitter


+def normalize_text(text):
+    return unidecode(re.sub(r'\s+', '', text).lower())
+
+
+def remove_indices(text):
+    cleaned_text = re.sub(r'(\d+\.)+ *', '', text)
+    return cleaned_text
+
+
+def adjust_indices(original_text, indices):
+    last_index = 0
+    count = 0
+
+    indices = sorted(indices)
+    new_indices = []
+    for index in indices:
+        while last_index < len(original_text):
+            if not original_text[last_index].isspace():
+                count += 1
+                if count == index + 1:
+                    new_indices.append(last_index)
+                    last_index += 1
+                    break
+            last_index += 1
+
+    return new_indices
+
+
 @dataclass
 class HeadlineSplitter(Splitter):
     async def split(self, node: Node) -> t.Tuple[t.List[Node], t.List[Relationship]]:
@@ -16,14 +46,41 @@ async def split(self, node: Node) -> t.Tuple[t.List[Node], t.List[Relationship]]
         if headlines is None:
             raise ValueError("'headlines' property not found in this node")

+        if len(headlines) == 0:
+            return [], []
+
         # create the chunks for the different sections
         indices = []
+        normalized_text = normalize_text(text)
+
         for headline in headlines:
-            indices.append(text.find(headline))
+            if headline is not None and not headline.isspace():
+                indice = normalized_text.find(normalize_text(headline))
+                if indice == -1:
+                    text_search = remove_indices(headline)
+                    text_search = normalize_text(text_search)
+                    indice = normalized_text.find(text_search)
+
+                if indice != -1:
+                    indices.append(indice)
+
+        if len(indices) == 0:
+            return [], []
+
+        indices = adjust_indices(text, indices)
+
         indices.append(len(text))
-        chunks = [text[indices[i] : indices[i + 1]] for i in range(len(indices) - 1)]

         # create the nodes
+        chunks = []
+        for i in range(len(indices) - 1):
+            aux = text[indices[i] : indices[i + 1]]
+
+            if not aux.isspace():
+                chunks.append(aux)
+
+        if len(chunks) == 0:
+            return [], []
+
         nodes = [
             Node(type=NodeType.CHUNK, properties={"page_content": chunk})
             for chunk in chunks
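For context, here is a minimal sketch (not part of the commit) of why the new lookup is more forgiving than the old verbatim `text.find(headline)`: the headline is matched against a whitespace-stripped, lowercased, unaccented copy of the text, and if that still fails, the numeric section prefix is dropped from the headline before retrying. The sample strings below are invented for illustration; `normalize_text` and `remove_indices` are the helpers added in this commit.

```python
import re

from unidecode import unidecode


def normalize_text(text):
    # Added in this commit: drop all whitespace, lowercase, strip accents.
    return unidecode(re.sub(r'\s+', '', text).lower())


def remove_indices(text):
    # Added in this commit: drop numeric section prefixes such as "1." or "2.3.".
    cleaned_text = re.sub(r'(\d+\.)+ *', '', text)
    return cleaned_text


# Invented sample: the headline list carries numbering that the body text lacks.
text = "Relatório Final\n\nVisão  Geral\nO sistema processa documentos."
headline = "1. Visão Geral"

print(text.find(headline))                             # -1: the old verbatim search fails
normalized_text = normalize_text(text)
print(normalized_text.find(normalize_text(headline)))  # -1: "1." does not appear in the body
fallback = normalize_text(remove_indices(headline))
print(normalized_text.find(fallback))                  # 14: "visaogeral" found in the normalized text
```

The same normalization also absorbs whitespace and accent differences between the extracted headline and the body text.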

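An index found this way points into the whitespace-stripped string, so it cannot be used to slice the original text directly; `adjust_indices` (added in the diff above) recovers the real offset by walking the original text and counting non-whitespace characters. A small sketch with the same invented sample:

```python
def adjust_indices(original_text, indices):
    # As added in this commit: position i in the whitespace-stripped text
    # corresponds to the (i + 1)-th non-whitespace character of original_text.
    last_index = 0
    count = 0

    indices = sorted(indices)
    new_indices = []
    for index in indices:
        while last_index < len(original_text):
            if not original_text[last_index].isspace():
                count += 1
                if count == index + 1:
                    new_indices.append(last_index)
                    last_index += 1
                    break
            last_index += 1

    return new_indices


# Invented sample: 14 is where "visaogeral" starts in the normalized text (see the sketch above).
text = "Relatório Final\n\nVisão  Geral\nO sistema processa documentos."
print(adjust_indices(text, [14]))  # [17]
print(text[17:])                   # "Visão  Geral\nO sistema processa documentos."
```

In the updated `split`, these adjusted offsets (plus `len(text)` as the final boundary) are what the chunks are sliced from, and whitespace-only chunks are discarded before nodes are created.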