Skip to content

Commit

Permalink
fix: bugs text-classification bulk
Browse files (browse the repository at this point in the history)
  • Loading branch information
davidberenstein1957 committed Aug 14, 2024
1 parent 5018ef1 commit f56423e
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 17 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "dataset-viber"
version = "0.1.1"
version = "0.1.2"
description = "Dataset Viber is your chill repo for data collection, annotation and vibe checks."
authors = [
{name = "davidberenstein1957", email = "[email protected]"},
Expand Down
36 changes: 20 additions & 16 deletions src/dataset_viber/_plotly/bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,11 @@ def update_labels(n_clicks, selectedData, new_label, current_figure):
current_figure["data"] = updated_traces

local_dataframe = self.umap_df.copy()
tooltip_data = self.get_chat_tooltip(local_dataframe)
local_dataframe[self.content_column] = local_dataframe[
self.content_column
].apply(lambda x: x[0]["content"])
tooltip_data = self.get_tooltip(local_dataframe)
if self.content_format == "chat":
local_dataframe[self.content_column] = local_dataframe[
self.content_column
].apply(lambda x: x[0]["content"])
return current_figure, local_dataframe.to_dict("records"), tooltip_data

# Callback to print the dataframe
Expand Down Expand Up @@ -216,10 +217,11 @@ def update_selection(selectedData, figure):
]

local_dataframe = filtered_df.copy()
tooltip_data = self.get_chat_tooltip(local_dataframe)
local_dataframe[self.content_column] = local_dataframe[
self.content_column
].apply(lambda x: x[0]["content"])
tooltip_data = self.get_tooltip(local_dataframe)
if self.content_format == "chat":
local_dataframe[self.content_column] = local_dataframe[
self.content_column
].apply(lambda x: x[0]["content"])
return figure, local_dataframe.to_dict("records"), tooltip_data

self.app = app
Expand Down Expand Up @@ -433,24 +435,24 @@ def _get_app_layout(self, figure, dataframe, labels, hf_token):
n_clicks=0,
)
),
# dbc.Button(
# "Upload to Hub",
# id="upload-button",
# n_clicks=0,
# ),
dbc.Button(
"Upload to Hub",
id="upload-button",
n_clicks=0,
),
dbc.Button("Download Text", id="btn-download-txt"),
dcc.Download(id="download-text"),
]
)
if self.content_format == "chat":
tooltip_data = self.get_chat_tooltip(local_dataframe)
tooltip_data = self.get_tooltip(local_dataframe)
local_dataframe[self.content_column] = local_dataframe[
self.content_column
].apply(lambda x: x[0]["content"])
columns = local_dataframe.columns
elif self.content_format == "text":
tooltip_data = None
columns = [local_dataframe.columns]
columns = local_dataframe.columns
else:
raise ValueError(
"content_format should be either 'text' or 'chat' but got {self.content_format}"
Expand Down Expand Up @@ -568,7 +570,9 @@ def format_content(self, content, max_length=120, content_format="text"):
wrapped_text += f"<b>{turn['role']}</b>:<br>{self.format_content(turn['content'])}<br><br>"
return wrapped_text

def get_chat_tooltip(self, dataframe):
def get_tooltip(self, dataframe):
if self.content_format == "text":
return None
return [
{
self.content_column: {
Expand Down

0 comments on commit f56423e

Please sign in to comment.