Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor attempt to utilized property annotations #1027

Merged
merged 1 commit into from
Dec 2, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 78 additions & 84 deletions garak/attempt.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,95 +105,89 @@ def as_dict(self) -> dict:
"messages": self.messages,
}

def __getattribute__(self, name: str) -> Any:
"""override prompt and outputs access to take from history"""
if name == "prompt":
if len(self.messages) == 0: # nothing set
return None
if isinstance(self.messages[0], dict): # only initial prompt set
return self.messages[0]["content"]
if isinstance(
self.messages, list
): # there's initial prompt plus some history
return self.messages[0][0]["content"]
else:
raise ValueError(
"Message history of attempt uuid %s in unexpected state, sorry: "
% str(self.uuid)
+ repr(self.messages)
)
@property
def prompt(self):
if len(self.messages) == 0: # nothing set
return None
if isinstance(self.messages[0], dict): # only initial prompt set
return self.messages[0]["content"]
if isinstance(self.messages, list): # there's initial prompt plus some history
return self.messages[0][0]["content"]
else:
raise ValueError(
"Message history of attempt uuid %s in unexpected state, sorry: "
% str(self.uuid)
+ repr(self.messages)
)

elif name == "outputs":
if len(self.messages) and isinstance(self.messages[0], list):
# work out last_output_turn that was assistant
assistant_turns = [
@property
def outputs(self):
if len(self.messages) and isinstance(self.messages[0], list):
# work out last_output_turn that was assistant
assistant_turns = [
idx
for idx, val in enumerate(self.messages[0])
if val["role"] == "assistant"
]
if assistant_turns == []:
return []
last_output_turn = max(assistant_turns)
# return these (via list compr)
return [m[last_output_turn]["content"] for m in self.messages]
else:
return []

@property
def latest_prompts(self):
if len(self.messages[0]) > 1:
# work out last_output_turn that was user
last_output_turn = max(
[
idx
for idx, val in enumerate(self.messages[0])
if val["role"] == "assistant"
if val["role"] == "user"
]
if assistant_turns == []:
return []
last_output_turn = max(assistant_turns)
# return these (via list compr)
return [m[last_output_turn]["content"] for m in self.messages]
else:
return []

elif name == "latest_prompts":
if len(self.messages[0]) > 1:
# work out last_output_turn that was user
last_output_turn = max(
[
idx
for idx, val in enumerate(self.messages[0])
if val["role"] == "user"
]
)
# return these (via list compr)
return [m[last_output_turn]["content"] for m in self.messages]
else:
return (
self.prompt
) # returning a string instead of a list tips us off that generation count is not yet known

elif name == "all_outputs":
all_outputs = []
if len(self.messages) and not isinstance(self.messages[0], dict):
for thread in self.messages:
for turn in thread:
if turn["role"] == "assistant":
all_outputs.append(turn["content"])
return all_outputs

else:
return super().__getattribute__(name)

def __setattr__(self, name: str, value: Any) -> None:
"""override prompt and outputs access to take from history NB. output elements need to be able to be None"""

if name == "prompt":
if value is None:
raise TypeError("'None' prompts are not valid")
self._add_first_turn("user", value)

elif name == "outputs":
if not (isinstance(value, list) or isinstance(value, GeneratorType)):
raise TypeError("Value for attempt.outputs must be a list or generator")
value = list(value)
if len(self.messages) == 0:
raise TypeError("A prompt must be set before outputs are given")
# do we have only the initial prompt? in which case, let's flesh out messages a bit
elif len(self.messages) == 1 and isinstance(self.messages[0], dict):
self._expand_prompt_to_histories(len(value))
# append each list item to each history, with role:assistant
self._add_turn("assistant", value)

elif name == "latest_prompts":
assert isinstance(value, list)
self._add_turn("user", value)

)
# return these (via list compr)
return [m[last_output_turn]["content"] for m in self.messages]
else:
return super().__setattr__(name, value)
return (
self.prompt
) # returning a string instead of a list tips us off that generation count is not yet known

@property
def all_outputs(self):
all_outputs = []
if len(self.messages) and not isinstance(self.messages[0], dict):
for thread in self.messages:
for turn in thread:
if turn["role"] == "assistant":
all_outputs.append(turn["content"])
return all_outputs

@prompt.setter
def prompt(self, value):
if value is None:
raise TypeError("'None' prompts are not valid")
self._add_first_turn("user", value)

@outputs.setter
def outputs(self, value):
if not (isinstance(value, list) or isinstance(value, GeneratorType)):
raise TypeError("Value for attempt.outputs must be a list or generator")
value = list(value)
if len(self.messages) == 0:
raise TypeError("A prompt must be set before outputs are given")
# do we have only the initial prompt? in which case, let's flesh out messages a bit
elif len(self.messages) == 1 and isinstance(self.messages[0], dict):
self._expand_prompt_to_histories(len(value))
# append each list item to each history, with role:assistant
self._add_turn("assistant", value)

@latest_prompts.setter
def latest_prompts(self, value):
assert isinstance(value, list)
self._add_turn("user", value)

def _expand_prompt_to_histories(self, breadth):
"""expand a prompt-only message history to many threads"""
Expand Down
Loading