From f56423e7452065865cfbf3f5c4890637167960ff Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Wed, 14 Aug 2024 20:40:40 +0200 Subject: [PATCH] fix: bugs text-classification bulk --- pyproject.toml | 2 +- src/dataset_viber/_plotly/bulk.py | 36 +++++++++++++++++-------------- 2 files changed, 21 insertions(+), 17 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6885fe0..64a16a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dataset-viber" -version = "0.1.1" +version = "0.1.2" description = "Dataset Viber is your chill repo for data collection, annotation and vibe checks." authors = [ {name = "davidberenstein1957", email = "david.m.berenstein@gmail.com"}, diff --git a/src/dataset_viber/_plotly/bulk.py b/src/dataset_viber/_plotly/bulk.py index 2ec5511..83219af 100644 --- a/src/dataset_viber/_plotly/bulk.py +++ b/src/dataset_viber/_plotly/bulk.py @@ -153,10 +153,11 @@ def update_labels(n_clicks, selectedData, new_label, current_figure): current_figure["data"] = updated_traces local_dataframe = self.umap_df.copy() - tooltip_data = self.get_chat_tooltip(local_dataframe) - local_dataframe[self.content_column] = local_dataframe[ - self.content_column - ].apply(lambda x: x[0]["content"]) + tooltip_data = self.get_tooltip(local_dataframe) + if self.content_format == "chat": + local_dataframe[self.content_column] = local_dataframe[ + self.content_column + ].apply(lambda x: x[0]["content"]) return current_figure, local_dataframe.to_dict("records"), tooltip_data # Callback to print the dataframe @@ -216,10 +217,11 @@ def update_selection(selectedData, figure): ] local_dataframe = filtered_df.copy() - tooltip_data = self.get_chat_tooltip(local_dataframe) - local_dataframe[self.content_column] = local_dataframe[ - self.content_column - ].apply(lambda x: x[0]["content"]) + tooltip_data = self.get_tooltip(local_dataframe) + if self.content_format == "chat": + local_dataframe[self.content_column] = local_dataframe[ + self.content_column + ].apply(lambda x: x[0]["content"]) return figure, local_dataframe.to_dict("records"), tooltip_data self.app = app @@ -433,24 +435,24 @@ def _get_app_layout(self, figure, dataframe, labels, hf_token): n_clicks=0, ) ), - # dbc.Button( - # "Upload to Hub", - # id="upload-button", - # n_clicks=0, - # ), + dbc.Button( + "Upload to Hub", + id="upload-button", + n_clicks=0, + ), dbc.Button("Download Text", id="btn-download-txt"), dcc.Download(id="download-text"), ] ) if self.content_format == "chat": - tooltip_data = self.get_chat_tooltip(local_dataframe) + tooltip_data = self.get_tooltip(local_dataframe) local_dataframe[self.content_column] = local_dataframe[ self.content_column ].apply(lambda x: x[0]["content"]) columns = local_dataframe.columns elif self.content_format == "text": tooltip_data = None - columns = [local_dataframe.columns] + columns = local_dataframe.columns else: raise ValueError( "content_format should be either 'text' or 'chat' but got {self.content_format}" @@ -568,7 +570,9 @@ def format_content(self, content, max_length=120, content_format="text"): wrapped_text += f"{turn['role']}:
{self.format_content(turn['content'])}

" return wrapped_text - def get_chat_tooltip(self, dataframe): + def get_tooltip(self, dataframe): + if self.content_format == "text": + return None return [ { self.content_column: {