Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 27, 2023
1 parent 060fb54 commit fcb7aa9
Show file tree
Hide file tree
Showing 4 changed files with 233 additions and 166 deletions.
53 changes: 27 additions & 26 deletions tools/speech_data_explorer/data_explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,7 @@ def parse_args():
help='field name for which you want to see statistics (optional). Example: pred_text_contextnet.',
)
parser.add_argument(
'--gpu',
'-gpu',
action='store_true',
help='use GPU-acceleration',
'--gpu', '-gpu', action='store_true', help='use GPU-acceleration',
)
args = parser.parse_args()

Expand Down Expand Up @@ -490,7 +487,7 @@ def plot_histogram(data, key, label, gpu_acceleration=False):
data_frame = data[key].to_list()
else:
data_frame = [item[key] for item in data]

fig = px.histogram(
data_frame=data_frame,
nbins=50,
Expand All @@ -504,10 +501,10 @@ def plot_histogram(data, key, label, gpu_acceleration=False):
return fig


def plot_word_accuracy(vocabulary_data):
def plot_word_accuracy(vocabulary_data):
labels = ['Unrecognized', 'Sometimes recognized', 'Always recognized']
counts = [0, 0, 0]

if args.gpu:
counts[0] = (vocabulary_data['Accuracy'] == 0).sum()
counts[1] = (vocabulary_data['Accuracy'] < 100).sum()
Expand Down Expand Up @@ -576,24 +573,27 @@ def absolute_audio_filepath(audio_filepath, audio_base_path):
if args.gpu:
if args.names_compared is not None:
raise Exception(f"Currently comparison mode is not available with gpu acceleration.")

hypothesis_fields = ["pred_text"]
if args.show_statistics is not None:
hypothesis_fields = [args.show_statistics]

enable_plk = True
if args.disable_caching_metrics:
enable_plk = False

cu_df = cuDF()

dataset = Dataset(manifest_filepath = args.manifest, data_engine = cu_df,
hypothesis_fields = hypothesis_fields,
estimate_audio_metrics = args.estimate_audio_metrics,
enable_plk = enable_plk)

dataset = Dataset(
manifest_filepath=args.manifest,
data_engine=cu_df,
hypothesis_fields=hypothesis_fields,
estimate_audio_metrics=args.estimate_audio_metrics,
enable_plk=enable_plk,
)

dataset = dataset.process()

data = dataset.samples_data
num_hours = dataset.duration
vocabulary = dataset.vocabulary_data
Expand All @@ -602,8 +602,8 @@ def absolute_audio_filepath(audio_filepath, audio_base_path):
metrics_available = len(dataset.hypotheses) != 0
if metrics_available:
wer = dataset.hypotheses[hypothesis_fields[0]].wer
cer = dataset.hypotheses[hypothesis_fields[0]].cer
wmr = dataset.hypotheses[hypothesis_fields[0]].wmr
cer = dataset.hypotheses[hypothesis_fields[0]].cer
wmr = dataset.hypotheses[hypothesis_fields[0]].wmr
mwa = dataset.hypotheses[hypothesis_fields[0]].mwa
else:
if not comparison_mode:
Expand Down Expand Up @@ -706,7 +706,7 @@ def absolute_audio_filepath(audio_filepath, audio_base_path):
figure_word_acc = plot_word_accuracy(vocabulary_data)
else:
figure_word_acc = plot_word_accuracy(vocabulary)

stats_layout = [
dbc.Row(dbc.Col(html.H5(children='Global Statistics'), class_name='text-secondary'), class_name='mt-3'),
dbc.Row(
Expand Down Expand Up @@ -827,7 +827,7 @@ def absolute_audio_filepath(audio_filepath, audio_base_path):

wordstable_columns = [{'name': 'Word', 'id': 'Word'}, {'name': 'Count', 'id': 'Amount'}]

if args.gpu:
if args.gpu:
vocabulary_columns = vocabulary.columns
else:
vocabulary_columns = vocabulary[0].keys()
Expand Down Expand Up @@ -910,22 +910,22 @@ def update_wordstable(page_current, sort_by, filter_query):
if op in ('eq', 'ne', 'lt', 'le', 'gt', 'ge'):
if args.gpu:
vocabulary_view = vocabulary_view.loc[getattr(operator, op)(vocabulary_view[col_name], filter_value)]
else:
else:
vocabulary_view = [x for x in vocabulary_view if getattr(operator, op)(x[col_name], filter_value)]
elif op == 'contains':
vocabulary_view = [x for x in vocabulary_view if filter_value in str(x[col_name])]

if len(sort_by):
col = sort_by[0]['column_id']
ascending = sort_by[0]['direction'] != 'desc'

if args.gpu:
vocabulary_view = vocabulary_view.sort_values(col, ascending=ascending)
else:
vocabulary_view = sorted(vocabulary_view, key=lambda x: x[col], reverse=descending)
if page_current * DATA_PAGE_SIZE >= len(vocabulary_view):
page_current = len(vocabulary_view) // DATA_PAGE_SIZE

if args.gpu:
return [
vocabulary_view[page_current * DATA_PAGE_SIZE : (page_current + 1) * DATA_PAGE_SIZE].to_dict('records'),
Expand All @@ -937,6 +937,7 @@ def update_wordstable(page_current, sort_by, filter_query):
math.ceil(len(vocabulary_view) / DATA_PAGE_SIZE),
]


if args.gpu:
col_names = data.columns
else:
Expand Down Expand Up @@ -1564,22 +1565,22 @@ def update_datatable(page_current, sort_by, filter_query):
if op in ('eq', 'ne', 'lt', 'le', 'gt', 'ge'):
if args.gpu:
data_view = data_view.loc[getattr(operator, op)(data_view[col_name], filter_value)]
else:
else:
data_view = [x for x in data_view if getattr(operator, op)(x[col_name], filter_value)]
elif op == 'contains':
data_view = [x for x in data_view if filter_value in str(x[col_name])]

if len(sort_by):
col = sort_by[0]['column_id']
ascending = sort_by[0]['direction'] != 'desc'

if args.gpu:
data_view = data_view.sort_values(col, ascending=ascending)
else:
data_view = sorted(data_view, key=lambda x: x[col], reverse=descending)
if page_current * DATA_PAGE_SIZE >= len(data_view):
page_current = len(data_view) // DATA_PAGE_SIZE

if args.gpu:
return [
data_view[page_current * DATA_PAGE_SIZE : (page_current + 1) * DATA_PAGE_SIZE].to_dict('records'),
Expand Down
Loading

0 comments on commit fcb7aa9

Please sign in to comment.