Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

6 mc dropout in hf pipeline #9

Merged
merged 15 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,7 @@ Thumbs.db
# Common editor files
*~
*.swp

# other
temp
.vscode
79 changes: 79 additions & 0 deletions scripts/variational_TTS_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
"""
An example use of the transcription, translation and summarisation pipeline.
"""
import torch
from datasets import Audio, load_dataset

from arc_spice.dropout_utils.variational_inference import TTSVariationalPipeline


def main(TTS_params):
    """
    Run the transcription -> translation -> summarisation pipeline on a single
    streamed utterance and print clean and variational uncertainty metrics.

    Args:
        TTS_params: dict of per-step HuggingFace pipeline settings, keyed by
            "transcriber", "translator" and "summariser" (each with a
            "specific_task" and "model" entry).
    """
    var_pipe = TTSVariationalPipeline(TTS_params, n_variational_runs=2)

    # Stream one French test utterance, resampled to the 16 kHz the models expect.
    ds = load_dataset(
        "facebook/multilingual_librispeech", "french", split="test", streaming=True
    )
    ds = ds.cast_column("audio", Audio(sampling_rate=16_000))
    input_speech = next(iter(ds))["audio"]

    var_pipe.clean_inference(input_speech["array"])
    clean_output = var_pipe.clean_output

    # FIX: the original used double quotes for the subscripts inside
    # double-quoted f-strings (e.g. f"{clean_output[step]["logits"]}"), which
    # is a SyntaxError on Python < 3.12 (PEP 701). Single inner quotes work on
    # all supported versions.

    # logit shapes
    print("\nLogit shapes:")
    for step in var_pipe.pipeline_map:
        print(f"{step.capitalize()}: {clean_output[step]['logits'].shape}")

    # entropy
    print("\nMean entropy:")
    for step in var_pipe.pipeline_map:
        print(f"{step.capitalize()}: {torch.mean(clean_output[step]['entropy'])}")

    # normalised entropy: accumulate (1 - entropy) across steps as a crude
    # pipeline-level confidence score
    print("\nNormalised mean entropy:")
    cumulative = 1
    for step in var_pipe.pipeline_map:
        step_entropy = torch.mean(clean_output[step]["normalised_entropy"])
        cumulative *= 1 - step_entropy
        print(f"{step.capitalize()}: {step_entropy}")
    print(f"Cumulative confidence (1 - entropy): {cumulative}")

    # probabilities: product of per-step mean top-token probabilities
    print("\nMean top probabilities:")
    cumulative = 1
    for step in var_pipe.pipeline_map:
        step_prob = torch.mean(clean_output[step]["probs"])
        cumulative *= step_prob
        print(f"{step.capitalize()}: {step_prob}")
    print(f"Cumulative confidence: {cumulative}")

    # geometric mean of the token probabilities per step
    print("\nConditional probabilities:")
    for step in var_pipe.pipeline_map:
        token_probs = clean_output[step]["probs"]
        cond_prob = torch.pow(torch.prod(token_probs, -1), 1 / len(token_probs))
        print(f"{step.capitalize()}: {cond_prob}")

    var_pipe.variational_inference(x=input_speech["array"])
    variational_output = var_pipe.var_output
    print("\nVariational Inference Semantic Density:")
    for step in variational_output["variational"]:
        print(f"{step}: {variational_output['variational'][step]['semantic_density']}")


if __name__ == "__main__":
    # HuggingFace task/checkpoint configuration for each stage of the
    # transcription -> translation -> summarisation pipeline.
    transcriber_config = {
        "specific_task": "automatic-speech-recognition",
        "model": "jonatasgrosman/wav2vec2-large-xlsr-53-french",
    }
    translator_config = {
        "specific_task": "translation_fr_to_en",
        "model": "ybanas/autotrain-fr-en-translate-51410121895",
    }
    summariser_config = {
        "specific_task": "summarization",
        "model": "marianna13/flan-t5-base-summarization",
    }
    main(
        TTS_params={
            "transcriber": transcriber_config,
            "translator": translator_config,
            "summariser": summariser_config,
        }
    )
Empty file.
38 changes: 38 additions & 0 deletions src/arc_spice/dropout_utils/dropout_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import torch
from transformers import Pipeline, pipeline


def set_dropout(model: torch.nn.Module, dropout_flag: bool) -> None:
    """
    Enable or disable every dropout layer of a model, in place.

    Dropout is only active while a module is in training mode, so putting a
    layer into train(True) switches dropout on and train(False) switches it
    off.

    Args:
        model: pytorch model
        dropout_flag: dropout -> True/False
    """
    dropout_layers = (
        module for module in model.modules() if isinstance(module, torch.nn.Dropout)
    )
    for layer in dropout_layers:
        layer.train(mode=dropout_flag)


def MCDropoutPipeline(task: str, model: str):
    """
    Build a HuggingFace pipeline whose dropout layers are forced on, so that
    repeated forward passes yield Monte-Carlo dropout samples.

    Args:
        task: HuggingFace task name (e.g. "summarization")
        model: model checkpoint name or path

    Returns:
        the pipeline, with its model's dropout layers left in training mode
    """
    pl = pipeline(
        task=task,
        model=model,
    )
    # BUG FIX: set_dropout mutates the model in place and returns None; the
    # original assigned its result back (pl.model = set_dropout(...)), which
    # replaced the pipeline's model with None. Call it for the side effect.
    set_dropout(model=pl.model, dropout_flag=True)
    return pl


def test_dropout(pipe: Pipeline, dropout_flag: bool):
    """
    Check that every dropout layer in the pipeline's model is in the expected
    training state, and report how many layers were inspected.

    Args:
        pipe: HuggingFace pipeline whose model is inspected
        dropout_flag: expected `training` state of each dropout layer

    Raises:
        AssertionError: if any dropout layer's state differs from dropout_flag
    """
    count = 0
    for module in pipe.model.modules():
        if isinstance(module, torch.nn.Dropout):
            count += 1
            assert module.training == dropout_flag

    print(f"{count} dropout layers found in correct configuration.")
Loading
Loading