From 26122d51c39136a34097e2f323cbc9bffb1ad423 Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Fri, 25 Oct 2024 15:44:43 +0800 Subject: [PATCH 1/5] Update litellm_wrapper.py --- optillm/litellm_wrapper.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/optillm/litellm_wrapper.py b/optillm/litellm_wrapper.py index 9d19b09..5ac6232 100644 --- a/optillm/litellm_wrapper.py +++ b/optillm/litellm_wrapper.py @@ -24,7 +24,10 @@ class Chat: class Completions: @staticmethod def create(model: str, messages: List[Dict[str, str]], **kwargs): - response = completion(model=model, messages=messages, **kwargs, safety_settings=SAFETY_SETTINGS) + if model.startswith("gemini"): + response = completion(model=model, messages=messages, **kwargs, safety_settings=SAFETY_SETTINGS) + else: + response = completion(model=model, messages=messages, **kwargs) # Convert LiteLLM response to match OpenAI response structure return response From 996e48dba957a50d906aaf6f1135aa7b631544c8 Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Mon, 28 Oct 2024 06:49:43 +0800 Subject: [PATCH 2/5] Update entropy_decoding.py remove second softmax --- optillm/entropy_decoding.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/optillm/entropy_decoding.py b/optillm/entropy_decoding.py index 30146d7..33fa4a2 100644 --- a/optillm/entropy_decoding.py +++ b/optillm/entropy_decoding.py @@ -27,7 +27,8 @@ def calculate_varentropy_logsoftmax(logits: torch.Tensor, axis: int = -1) -> Tup return entropy, varentropy def calculate_attention_metrics(attention_scores: torch.Tensor) -> Dict[str, torch.Tensor]: - attention_probs = F.softmax(attention_scores, dim=-1) + # attention_probs = F.softmax(attention_scores, dim=-1) + attention_probs = attention_scores attn_entropy = -torch.sum(attention_probs * torch.log2(torch.clamp(attention_probs, 1e-10, 1.0)), dim=-1) attn_varentropy = torch.var(attn_entropy, dim=-1) From 129ac8090a4a65cb4fa6edb1cefe71974f09c782 Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Mon, 28 Oct 2024 07:03:29 +0800 Subject: [PATCH 3/5] Update entropy_decoding.py update attention scores --- optillm/entropy_decoding.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/optillm/entropy_decoding.py b/optillm/entropy_decoding.py index 33fa4a2..bbc3de5 100644 --- a/optillm/entropy_decoding.py +++ b/optillm/entropy_decoding.py @@ -26,18 +26,28 @@ def calculate_varentropy_logsoftmax(logits: torch.Tensor, axis: int = -1) -> Tup varentropy = torch.sum(probs * (log_probs / LN_2 + entropy.unsqueeze(-1))**2, dim=axis) return entropy, varentropy -def calculate_attention_metrics(attention_scores: torch.Tensor) -> Dict[str, torch.Tensor]: - # attention_probs = F.softmax(attention_scores, dim=-1) - attention_probs = attention_scores +def calculate_attention_metrics(attention_weights: torch.Tensor) -> Dict[str, torch.Tensor]: + # attention_weights are already probabilities (post-softmax) + attention_probs = attention_weights + + # Calculate entropy attn_entropy = -torch.sum(attention_probs * torch.log2(torch.clamp(attention_probs, 1e-10, 1.0)), dim=-1) + + # Calculate variance of entropy attn_varentropy = torch.var(attn_entropy, dim=-1) + attn_varentropy = torch.where(torch.isnan(attn_varentropy), + torch.zeros_like(attn_varentropy), + attn_varentropy) - attn_varentropy = torch.where(torch.isnan(attn_varentropy), torch.zeros_like(attn_varentropy), attn_varentropy) + # Calculate mean attention and agreement mean_attention = torch.mean(attention_probs, dim=1) agreement = torch.mean(torch.abs(attention_probs - mean_attention.unsqueeze(1)), dim=(1, 2)) - - interaction_strength = torch.mean(torch.abs(attention_scores), dim=(1, 2, 3)) - + + # For interaction strength, we can use log probabilities to approximate the original scores + # This maintains the relative relationships while providing a reasonable proxy for attention strength + attention_scores_proxy = torch.log(torch.clamp(attention_probs, 1e-10, 1.0)) + interaction_strength = torch.mean(torch.abs(attention_scores_proxy), dim=(1, 2, 3)) + return { "attn_entropy": torch.mean(attn_entropy), "attn_varentropy": torch.mean(attn_varentropy), From 97265191b0b2a1e805adaf70fe1d648730611d69 Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Mon, 28 Oct 2024 07:21:32 +0800 Subject: [PATCH 4/5] Update entropy_decoding.py --- optillm/entropy_decoding.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/optillm/entropy_decoding.py b/optillm/entropy_decoding.py index bbc3de5..3a768fc 100644 --- a/optillm/entropy_decoding.py +++ b/optillm/entropy_decoding.py @@ -27,24 +27,26 @@ def calculate_varentropy_logsoftmax(logits: torch.Tensor, axis: int = -1) -> Tup return entropy, varentropy def calculate_attention_metrics(attention_weights: torch.Tensor) -> Dict[str, torch.Tensor]: - # attention_weights are already probabilities (post-softmax) attention_probs = attention_weights # Calculate entropy attn_entropy = -torch.sum(attention_probs * torch.log2(torch.clamp(attention_probs, 1e-10, 1.0)), dim=-1) - # Calculate variance of entropy - attn_varentropy = torch.var(attn_entropy, dim=-1) + # Calculate variance of entropy with unbiased=False to avoid df issues + # Also add a check for singleton dimensions + if attn_entropy.size(-1) > 1: + attn_varentropy = torch.var(attn_entropy, dim=-1, unbiased=False) + else: + attn_varentropy = torch.zeros_like(attn_entropy) + attn_varentropy = torch.where(torch.isnan(attn_varentropy), torch.zeros_like(attn_varentropy), attn_varentropy) - # Calculate mean attention and agreement + # Rest remains the same mean_attention = torch.mean(attention_probs, dim=1) agreement = torch.mean(torch.abs(attention_probs - mean_attention.unsqueeze(1)), dim=(1, 2)) - # For interaction strength, we can use log probabilities to approximate the original scores - # This maintains the relative relationships while providing a reasonable proxy for attention strength attention_scores_proxy = torch.log(torch.clamp(attention_probs, 1e-10, 1.0)) interaction_strength = torch.mean(torch.abs(attention_scores_proxy), dim=(1, 2, 3)) From 0ebae20d323f0d01ead3ecac7e9ebdc35ba3584b Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Mon, 28 Oct 2024 07:51:54 +0800 Subject: [PATCH 5/5] Update setup.py Bump version for release --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index ae76402..7d91dfe 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="optillm", - version="0.0.6", + version="0.0.7", packages=find_packages(), py_modules=['optillm'], package_data={