Commit 7aadf7d: merge

J-Dymond committed Nov 15, 2024
2 parents 3fe9180 + e9b2d17
Showing 5 changed files with 479 additions and 3 deletions.
3 changes: 0 additions & 3 deletions .gitignore
@@ -159,10 +159,7 @@ Thumbs.db
*.swp


-# project related
-data/Taxi1500*
-slurm_scripts/slurm_logs*
# other
temp
.vscode
.venv-3.10
79 changes: 79 additions & 0 deletions scripts/variational_TTS_example.py
@@ -0,0 +1,79 @@
"""
An example use of the transcription, translation and summarisation pipeline.
"""
import torch
from datasets import Audio, load_dataset

from arc_spice.dropout_utils.variational_inference import TTSVariationalPipeline


def main(TTS_params):
    """Run clean and variational inference on a sample utterance and print the resulting statistics."""
    var_pipe = TTSVariationalPipeline(TTS_params, n_variational_runs=2)

    ds = load_dataset(
        "facebook/multilingual_librispeech", "french", split="test", streaming=True
    )
    ds = ds.cast_column("audio", Audio(sampling_rate=16_000))
    input_speech = next(iter(ds))["audio"]

    var_pipe.clean_inference(input_speech["array"])
    clean_output = var_pipe.clean_output

    # logit shapes
    print("\nLogit shapes:")
    for step in var_pipe.pipeline_map.keys():
        print(f"{step.capitalize()}: {clean_output[step]['logits'].shape}")

    # entropy
    print("\nMean entropy:")
    for step in var_pipe.pipeline_map.keys():
        print(f"{step.capitalize()}: {torch.mean(clean_output[step]['entropy'])}")

    # normalised entropy
    print("\nNormalised mean entropy:")
    cumulative = 1
    for step in var_pipe.pipeline_map.keys():
        step_entropy = torch.mean(clean_output[step]["normalised_entropy"])
        cumulative *= 1 - step_entropy
        print(f"{step.capitalize()}: {step_entropy}")
    print(f"Cumulative confidence (1 - entropy): {cumulative}")

    # probabilities
    print("\nMean top probabilities:")
    cumulative = 1
    for step in var_pipe.pipeline_map.keys():
        step_prob = torch.mean(clean_output[step]["probs"])
        cumulative *= step_prob
        print(f"{step.capitalize()}: {step_prob}")
    print(f"Cumulative confidence: {cumulative}")

    # length-normalised (geometric mean) token probability for each step
    print("\nConditional probabilities:")
    for step in var_pipe.pipeline_map.keys():
        token_probs = clean_output[step]["probs"]
        cond_prob = torch.pow(torch.prod(token_probs, -1), 1 / len(token_probs))
        print(f"{step.capitalize()}: {cond_prob}")

    var_pipe.variational_inference(x=input_speech["array"])
    variational_output = var_pipe.var_output
    print("\nVariational Inference Semantic Density:")
    for step in variational_output["variational"].keys():
        print(f"{step}: {variational_output['variational'][step]['semantic_density']}")


if __name__ == "__main__":
    TTS_pars = {
        "transcriber": {
            "specific_task": "automatic-speech-recognition",
            "model": "jonatasgrosman/wav2vec2-large-xlsr-53-french",
        },
        "translator": {
            "specific_task": "translation_fr_to_en",
            "model": "ybanas/autotrain-fr-en-translate-51410121895",
        },
        "summariser": {
            "specific_task": "summarization",
            "model": "marianna13/flan-t5-base-summarization",
        },
    }
    main(TTS_params=TTS_pars)
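
The confidence arithmetic printed above can be reproduced in isolation. The sketch below uses hypothetical token probabilities and step entropies (not pipeline outputs) to show the geometric-mean conditional probability and the cumulative products the script reports.

import torch

# hypothetical per-token top probabilities for one pipeline step
token_probs = torch.tensor([0.9, 0.8, 0.95])

# length-normalised (geometric mean) conditional probability, as in the script
cond_prob = torch.pow(torch.prod(token_probs, -1), 1 / len(token_probs))

# hypothetical normalised entropies for transcribe -> translate -> summarise
step_entropies = [0.05, 0.10, 0.20]

# cumulative confidence: product of (1 - normalised entropy) over the steps
cumulative = 1.0
for step_entropy in step_entropies:
    cumulative *= 1 - step_entropy

print(f"Geometric-mean token probability: {cond_prob}")
print(f"Cumulative confidence (1 - entropy): {cumulative}")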
Empty file.
38 changes: 38 additions & 0 deletions src/arc_spice/dropout_utils/dropout_pipeline.py
@@ -0,0 +1,38 @@
import torch
from transformers import Pipeline, pipeline


def set_dropout(model: torch.nn.Module, dropout_flag: bool) -> None:
    """
    Turn the dropout layers of a model on or off in place.

    Args:
        model: pytorch model
        dropout_flag: dropout -> True/False
    """
    for _, module in model.named_modules():
        if isinstance(module, torch.nn.Dropout):
            # dropout on (True) -> training mode, train(True)
            # dropout off (False) -> eval mode, train(False)
            module.train(dropout_flag)


def MCDropoutPipeline(task: str, model: str) -> Pipeline:
    pl = pipeline(
        task=task,
        model=model,
    )
    # set_dropout modifies the model in place and returns None,
    # so the pipeline's model must not be reassigned to its return value
    set_dropout(model=pl.model, dropout_flag=True)
    return pl


def test_dropout(pipe: Pipeline, dropout_flag: bool):
    model = pipe.model
    dropout_count = 0
    for _, module in model.named_modules():
        if isinstance(module, torch.nn.Dropout):
            dropout_count += 1
            assert module.training == dropout_flag

    print(f"{dropout_count} dropout layers found in correct configuration.")
