Skip to content

Commit

Permalink
refactor attempt to utilized property annotations
Browse files Browse the repository at this point in the history
This removed the override of class setattr and getattr in favor
of the `property` annotations to enable deepcopy

Signed-off-by: Jeffrey Martin <[email protected]>
  • Loading branch information
jmartin-tech committed Nov 27, 2024
1 parent 5c34717 commit ca1a665
Showing 1 changed file with 78 additions and 84 deletions.
162 changes: 78 additions & 84 deletions garak/attempt.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,95 +105,89 @@ def as_dict(self) -> dict:
"messages": self.messages,
}

def __getattribute__(self, name: str) -> Any:
"""override prompt and outputs access to take from history"""
if name == "prompt":
if len(self.messages) == 0: # nothing set
return None
if isinstance(self.messages[0], dict): # only initial prompt set
return self.messages[0]["content"]
if isinstance(
self.messages, list
): # there's initial prompt plus some history
return self.messages[0][0]["content"]
else:
raise ValueError(
"Message history of attempt uuid %s in unexpected state, sorry: "
% str(self.uuid)
+ repr(self.messages)
)
@property
def prompt(self):
if len(self.messages) == 0: # nothing set
return None
if isinstance(self.messages[0], dict): # only initial prompt set
return self.messages[0]["content"]
if isinstance(self.messages, list): # there's initial prompt plus some history
return self.messages[0][0]["content"]
else:
raise ValueError(
"Message history of attempt uuid %s in unexpected state, sorry: "
% str(self.uuid)
+ repr(self.messages)
)

elif name == "outputs":
if len(self.messages) and isinstance(self.messages[0], list):
# work out last_output_turn that was assistant
assistant_turns = [
@property
def outputs(self):
if len(self.messages) and isinstance(self.messages[0], list):
# work out last_output_turn that was assistant
assistant_turns = [
idx
for idx, val in enumerate(self.messages[0])
if val["role"] == "assistant"
]
if assistant_turns == []:
return []
last_output_turn = max(assistant_turns)
# return these (via list compr)
return [m[last_output_turn]["content"] for m in self.messages]
else:
return []

@property
def latest_prompts(self):
if len(self.messages[0]) > 1:
# work out last_output_turn that was user
last_output_turn = max(
[
idx
for idx, val in enumerate(self.messages[0])
if val["role"] == "assistant"
if val["role"] == "user"
]
if assistant_turns == []:
return []
last_output_turn = max(assistant_turns)
# return these (via list compr)
return [m[last_output_turn]["content"] for m in self.messages]
else:
return []

elif name == "latest_prompts":
if len(self.messages[0]) > 1:
# work out last_output_turn that was user
last_output_turn = max(
[
idx
for idx, val in enumerate(self.messages[0])
if val["role"] == "user"
]
)
# return these (via list compr)
return [m[last_output_turn]["content"] for m in self.messages]
else:
return (
self.prompt
) # returning a string instead of a list tips us off that generation count is not yet known

elif name == "all_outputs":
all_outputs = []
if len(self.messages) and not isinstance(self.messages[0], dict):
for thread in self.messages:
for turn in thread:
if turn["role"] == "assistant":
all_outputs.append(turn["content"])
return all_outputs

else:
return super().__getattribute__(name)

def __setattr__(self, name: str, value: Any) -> None:
"""override prompt and outputs access to take from history NB. output elements need to be able to be None"""

if name == "prompt":
if value is None:
raise TypeError("'None' prompts are not valid")
self._add_first_turn("user", value)

elif name == "outputs":
if not (isinstance(value, list) or isinstance(value, GeneratorType)):
raise TypeError("Value for attempt.outputs must be a list or generator")
value = list(value)
if len(self.messages) == 0:
raise TypeError("A prompt must be set before outputs are given")
# do we have only the initial prompt? in which case, let's flesh out messages a bit
elif len(self.messages) == 1 and isinstance(self.messages[0], dict):
self._expand_prompt_to_histories(len(value))
# append each list item to each history, with role:assistant
self._add_turn("assistant", value)

elif name == "latest_prompts":
assert isinstance(value, list)
self._add_turn("user", value)

)
# return these (via list compr)
return [m[last_output_turn]["content"] for m in self.messages]
else:
return super().__setattr__(name, value)
return (
self.prompt
) # returning a string instead of a list tips us off that generation count is not yet known

@property
def all_outputs(self):
all_outputs = []
if len(self.messages) and not isinstance(self.messages[0], dict):
for thread in self.messages:
for turn in thread:
if turn["role"] == "assistant":
all_outputs.append(turn["content"])
return all_outputs

@prompt.setter
def prompt(self, value):
if value is None:
raise TypeError("'None' prompts are not valid")
self._add_first_turn("user", value)

@outputs.setter
def outputs(self, value):
if not (isinstance(value, list) or isinstance(value, GeneratorType)):
raise TypeError("Value for attempt.outputs must be a list or generator")
value = list(value)
if len(self.messages) == 0:
raise TypeError("A prompt must be set before outputs are given")
# do we have only the initial prompt? in which case, let's flesh out messages a bit
elif len(self.messages) == 1 and isinstance(self.messages[0], dict):
self._expand_prompt_to_histories(len(value))
# append each list item to each history, with role:assistant
self._add_turn("assistant", value)

@latest_prompts.setter
def latest_prompts(self, value):
assert isinstance(value, list)
self._add_turn("user", value)

def _expand_prompt_to_histories(self, breadth):
"""expand a prompt-only message history to many threads"""
Expand Down

0 comments on commit ca1a665

Please sign in to comment.