
Commit

semantic density now calculated in the class during variational inference
J-Dymond committed Oct 10, 2024
1 parent 8a41bb6 commit ac2f04e
Showing 2 changed files with 100 additions and 30 deletions.
25 changes: 15 additions & 10 deletions scripts/variational_TTS_example.py
@@ -1,8 +1,6 @@
"""
An example use of the transcription, translation and summarisation pipeline.
"""
import json

import torch
from datasets import Audio, load_dataset

@@ -11,14 +9,17 @@

def main(TTS_params):
"""main function"""
var_pipe = TTSVariationalPipeline(TTS_params)
var_pipe = TTSVariationalPipeline(TTS_params, n_variational_runs=2)

ds = load_dataset(
"facebook/multilingual_librispeech", "french", split="test", streaming=True
)
ds = ds.cast_column("audio", Audio(sampling_rate=16_000))
input_speech = next(iter(ds))["audio"]

clean_output = var_pipe.clean_inference(input_speech["array"])
var_pipe.clean_inference(input_speech["array"])
clean_output = var_pipe.clean_output

# logit shapes
print("\nLogit shapes:")
for step in var_pipe.pipeline_map.keys():
@@ -47,14 +48,18 @@ def main(TTS_params):
print(f"{step.capitalize()}: {step_prob}")
print(f"Cumulative confidence: {cumulative}")

variational_output = var_pipe.variational_inference(x=input_speech['array'],n_runs=2)
print("\nConditional probabilities:")
for step in var_pipe.pipeline_map.keys():
token_probs = clean_output[step]["probs"]
cond_prob = torch.pow(torch.prod(token_probs,-1),1/len(token_probs))
print(f"{step.capitalize()}: {cond_prob}")

var_pipe.variational_inference(x=input_speech['array'])
variational_output = var_pipe.var_output
print("\nVariational Inference Semantic Density:")
for step in variational_output['variational'].keys():
print(f"{step}: {variational_output['variational'][step]['semantic_density']}")

for step in var_pipe.pipeline_map.keys():
print(f'\n{step}:')
step_output = variational_output['variational'][step]
for run in step_output:
print(run['semantic_embedding'])

if __name__ == "__main__":
TTS_pars = {
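For reference, the conditional probability printed by the script is the geometric mean of the per-token probabilities, i.e. a length-normalised sequence probability. A standalone sketch of that expression, using made-up token probabilities:

import torch

# Hypothetical per-token probabilities for a four-token output.
token_probs = torch.tensor([0.91, 0.87, 0.95, 0.78])

# Geometric mean over tokens, as in the example script above.
cond_prob = torch.pow(torch.prod(token_probs, -1), 1 / len(token_probs))
print(cond_prob)  # approximately 0.875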
105 changes: 85 additions & 20 deletions src/arc_spice/dropout_utils/variational_inference.py
@@ -5,10 +5,11 @@
import torch
import torch.nn.functional as F
from torch.distributions import Categorical
from torch.nn.functional import softmax
from torch.nn.functional import cosine_similarity, softmax
from transformers import (
AutomaticSpeechRecognitionPipeline,
AutoModel,
AutoModelForSequenceClassification,
AutoTokenizer,
SummarizationPipeline,
TranslationPipeline,
@@ -39,7 +40,7 @@ class TTSVariationalPipeline:
variational version of the TTS pipeline
"""

def __init__(self, pars: dict[str, dict[str, str]]):
def __init__(self, pars: dict[str, dict[str, str]], n_variational_runs=5):
self.transcriber = pipeline(
task=pars["transcriber"]["specific_task"],
model=pars["transcriber"]["model"],
@@ -64,17 +65,35 @@ def __init__(self, pars: dict[str, dict[str, str]]):
"sentence-transformers/all-MiniLM-L6-v2"
)

self.nli_tokenizer = AutoTokenizer.from_pretrained(
"microsoft/deberta-large-mnli"
)

self.nli_model = AutoModelForSequenceClassification.from_pretrained(
"microsoft/deberta-large-mnli"
)

self.pipeline_map = {
"transcription": self.transcriber,
"translation": self.translator,
"summarisation": self.summariser,
}
self.generate_kwargs = {"output_scores": True}

self.func_map = {
"transcription": self.transcribe,
"translation": self.translate,
"summarisation": self.summarise,
}
self.naive_outputs = {
"outputs",
"logits",
"entropy",
"normalised_entropy",
"probs",
"semantic_embedding",
}
self.n_variational_runs = n_variational_runs
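The constructor now also loads microsoft/deberta-large-mnli, which calculate_semantic_density uses below to score agreement between the clean output and each variational run. A minimal standalone sketch of that scoring, assuming this checkpoint's published label order (0: contradiction, 1: neutral, 2: entailment) and made-up premise/hypothesis text:

from torch.nn.functional import softmax
from transformers import AutoModelForSequenceClassification, AutoTokenizer

nli_tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-large-mnli")
nli_model = AutoModelForSequenceClassification.from_pretrained(
    "microsoft/deberta-large-mnli"
)

# Clean output as premise, variational-run output as hypothesis,
# joined with [SEP] as in calculate_semantic_density.
nli_inp = "The cat sat on the mat." + " [SEP] " + "A cat is on a mat."
encoded_nli = nli_tokenizer.encode(nli_inp, return_tensors="pt")
nli_out = softmax(nli_model(encoded_nli)["logits"], dim=-1)[0]

# Kernel: entailment counts fully, neutral half, contradiction not at all.
kernel_func = 1 - (nli_out[0] + 0.5 * nli_out[1])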

def get_confidence_metrics(
self, output_dict: dict[str, str | torch.Tensor]
@@ -149,6 +168,50 @@ def summarise(self, source_text: str):
output_dict.update(confidence_metrics)
return output_dict

def collect_metrics(self):
new_var_dict = {}
for step in self.var_output["variational"].keys():
new_var_dict[step] = {}
for metric in self.naive_outputs:
new_values = [None] * self.n_variational_runs
for run in range(self.n_variational_runs):
new_values[run] = self.var_output["variational"][step][run][metric]
new_var_dict[step][metric] = new_values

self.var_output["variational"] = new_var_dict

def calculate_semantic_density(self):
for step in self.var_output["variational"].keys():
clean_out = self.var_output["clean"][step]["outputs"]
var_step = self.var_output["variational"][step]
kernel_funcs = torch.zeros(self.n_variational_runs)
cond_probs = torch.zeros(self.n_variational_runs)
sims = [None] * self.n_variational_runs
for run_index, run_out in enumerate(var_step["outputs"]):
run_prob = var_step["probs"][run_index]
nli_inp = clean_out + " [SEP] " + run_out
encoded_nli = self.nli_tokenizer.encode(
nli_inp, padding=True, return_tensors="pt"
)
sims[run_index] = cosine_similarity(
self.var_output["clean"][step]["semantic_embedding"],
var_step["semantic_embedding"][run_index],
)
nli_out = softmax(self.nli_model(encoded_nli)["logits"], dim=-1)[0]
kernel_func = 1 - (nli_out[0] + (0.5 * nli_out[1]))
cond_probs[run_index] = torch.pow(
torch.prod(run_prob, -1), 1 / len(run_prob)
)
kernel_funcs[run_index] = kernel_func
semantic_density = (
1
/ (torch.sum(cond_probs))
* torch.sum(torch.mul(cond_probs, kernel_funcs))
)
self.var_output["variational"][step].update(
{"semantic_density": semantic_density.item(), "cosine_similarity": sims}
)
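The expression above is the semantic density estimator: each run's NLI kernel value is weighted by that run's length-normalised conditional probability, and the weighted sum is normalised by the total probability mass:

semantic_density = sum_i(cond_probs[i] * kernel_funcs[i]) / sum_i(cond_probs[i])

For instance, with made-up values cond_probs = (0.9, 0.6) and kernel_funcs = (0.8, 0.3), the density is (0.9*0.8 + 0.6*0.3) / (0.9 + 0.6) = 0.90 / 1.5 = 0.60. The cosine similarities between the clean and per-run semantic embeddings are stored alongside the density but do not enter the estimator itself.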

def clean_inference(self, x: Union[np.ndarray, bytes, str]):
"""
@@ -161,48 +224,50 @@ def clean_inference(self, x: Union[np.ndarray, bytes, str]):
summarised transcript with associated uncertainties at each step
"""

output = {step: {} for step in self.pipeline_map.keys()}
self.clean_output = {step: {} for step in self.pipeline_map.keys()}
# transcription
transcription = self.transcribe(x)
output["transcription"].update(transcription)
self.clean_output["transcription"].update(transcription)

# translation
translation = self.translate(transcription["outputs"])
output["translation"].update(translation)
self.clean_output["translation"].update(translation)

# summarisation
summarisation = self.summarise(translation["outputs"])
output["summarisation"].update(summarisation)
self.clean_output["summarisation"].update(summarisation)

return output

def variational_inference(self, x, n_runs=5):
def variational_inference(self, x):
# we need clean inputs to pass to each step, we run that first
output = {"clean": {}, "variational": {}}
output["clean"] = self.clean_inference(x)
self.var_output = {"clean": {}, "variational": {}}
self.clean_inference(x)
self.var_output["clean"] = self.clean_output
# each step accepts a different input from the clean pipeline
input_map = {
"transcription": x,
"translation": output["clean"]["transcription"]["outputs"],
"summarisation": output["clean"]["translation"]["outputs"],
"translation": self.var_output["clean"]["transcription"]["outputs"],
"summarisation": self.var_output["clean"]["translation"]["outputs"],
}
# for each model in pipeline
for model_key, pl in self.pipeline_map.items():
# turn on dropout for this model
set_dropout(model=pl.model, dropout_flag=True)
# create the output list
output["variational"][model_key] = [None] * n_runs
self.var_output["variational"][model_key] = [None] * self.n_variational_runs
# do n runs of the inference
for run_idx in range(n_runs):
output["variational"][model_key][run_idx] = self.func_map[model_key](
input_map[model_key]
)
for run_idx in range(self.n_variational_runs):
self.var_output["variational"][model_key][run_idx] = self.func_map[
model_key
](input_map[model_key])
# turn off dropout for this model
set_dropout(model=pl.model, dropout_flag=False)
return output

self.collect_metrics()
self.calculate_semantic_density()
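variational_inference relies on the module's set_dropout helper (imported above this hunk, outside the visible diff) to enable Monte Carlo dropout at inference time. Its body is not part of this commit; a typical implementation, shown here purely as an assumption, flips only the nn.Dropout submodules into training mode:

import torch.nn as nn

def set_dropout(model: nn.Module, dropout_flag: bool) -> None:
    # Toggle only the Dropout layers; the rest of the model stays in
    # eval mode, so e.g. normalisation statistics are unaffected.
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.train(mode=dropout_flag)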

def __call__(self, x):
return self.clean_inference(x)
self.clean_inference(x)
return self.clean_output
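Taken together, the new interface is attribute-based: callers run inference and then read clean_output or var_output rather than a return value. A usage sketch, reusing input_speech and TTS_params as defined in the example script above:

var_pipe = TTSVariationalPipeline(TTS_params, n_variational_runs=2)
var_pipe.variational_inference(input_speech["array"])
for step, step_output in var_pipe.var_output["variational"].items():
    print(step, step_output["semantic_density"])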


class CustomSpeechRecognitionPipeline(AutomaticSpeechRecognitionPipeline):
