✨ feat(reddit-personas): creating agent personas from user data across various subreddits
Gayatri Krishnakumar authored and committed on Dec 23, 2024
1 parent 9aaf6f1 · commit 0e19faf
Showing 8 changed files with 2,253 additions and 0 deletions.
@@ -184,3 +184,5 @@ lib/
clientcred.secret
notes.txt
.yamllint

.venv
30 changes: 30 additions & 0 deletions
examples/election/src/election_sim/sim_utils/reddit_personas/Post-processing_Cleaning.py
@@ -0,0 +1,30 @@
import json
import re

with open("/Users/gayatrikrishnakumar/Desktop/World_Adapter/persona.txt") as f:
    raw_data = f.read()

data = raw_data.replace("```json", "").replace("```", "")
data = data.strip()

if not data.startswith("["):
    data = f"[{data}"
if not data.endswith("]"):
    data = f"{data}]"

data = re.sub(r"}\s*{", "},\n{", data)

data = re.sub(r",\s*,", ",", data)

data = re.sub(r",\s*\]", "]", data)

try:
    parsed = json.loads(data)

    with open("cleaned_personas.json", "w") as f:
        json.dump(parsed, f, indent=4, ensure_ascii=False)
    print("JSON cleaned and saved to cleaned_personas.json")
except json.JSONDecodeError as e:
    print("Error parsing JSON:", e)

print(data)
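As a quick illustration of what the cleaning steps above accomplish, here is a minimal sketch run on a made-up fragment of model output (the sample string is hypothetical, not taken from persona.txt): fenced, back-to-back JSON objects with no separating commas are turned into a parseable array.

import json
import re

# Hypothetical raw model output: fenced, with two adjacent objects and no commas.
raw = '```json\n{"Name": "A"}\n{"Name": "B"}\n```'

text = raw.replace("```json", "").replace("```", "").strip()
if not text.startswith("["):
    text = f"[{text}"
if not text.endswith("]"):
    text = f"{text}]"
text = re.sub(r"}\s*{", "},\n{", text)  # comma between adjacent objects
text = re.sub(r",\s*,", ",", text)      # collapse doubled commas
text = re.sub(r",\s*\]", "]", text)     # drop a trailing comma before ]

print(json.loads(text))  # [{'Name': 'A'}, {'Name': 'B'}]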
26 changes: 26 additions & 0 deletions
examples/election/src/election_sim/sim_utils/reddit_personas/Pre-processing_Cleaning.py
@@ -0,0 +1,26 @@
import json

# Read data from an external JSON file
input_file = "/Users/gayatrikrishnakumar/Desktop/World_Adapter/Persona Generation/posts_byuser.json"  # Replace with your file path
output_file = "formatted_comments.json"  # Replace with your desired output file path

# Load the data from the file
with open(input_file) as file:
    data = json.load(file)

# Grouping titles by author_id
formatted_data = {}
for entry in data:
    author_id = entry["author_id"]
    if author_id not in formatted_data:
        formatted_data[author_id] = {"author_id": author_id, "titles": []}
    formatted_data[author_id]["titles"].append(entry["title"])

# Converting the result into a list of objects
result = list(formatted_data.values())

# Save the result into another JSON file
with open(output_file, "w") as file:
    json.dump(result, file, indent=4)

print(f"Formatted data has been saved to {output_file}")
22 changes: 22 additions & 0 deletions
examples/election/src/election_sim/sim_utils/reddit_personas/Pre-scraping_Extraction.py
@@ -0,0 +1,22 @@
import json

# Load the JSON file
input_file = "/Users/gayatrikrishnakumar/Desktop/World_Adapter/Persona Generation/top_k_data.json"  # Replace with your input JSON file name
output_file = "author_names.txt"

try:
    # Read the JSON data
    with open(input_file) as file:
        data = json.load(file)

    # Extract author names
    author_names = [entry.get("author_name") for entry in data if entry.get("author_name")]

    # Write author names to a .txt file
    with open(output_file, "w") as file:
        file.write("\n".join(author_names))

    print(f"Author names successfully written to {output_file}")

except Exception as e:
    print(f"An error occurred: {e}")
86 changes: 86 additions & 0 deletions
examples/election/src/election_sim/sim_utils/reddit_personas/Ranking_Interaction.py
@@ -0,0 +1,86 @@
import json
import sys


def calculate_number_of_votes(score, upvote_ratio):
    # sourcery skip: assign-if-exp, reintroduce-else
    """
    Calculate the number of votes based on score and upvote ratio.
    Number of Votes = Score / (2 * Upvote Ratio - 1)
    """
    denominator = (2 * upvote_ratio) - 1
    if denominator <= 0:
        return None
    return score / denominator


def calculate_engagement_ratio(num_comments, number_of_votes):
    # sourcery skip: assign-if-exp, reintroduce-else
    """
    Calculate the engagement ratio.
    Engagement Ratio = Number of Comments / Number of Votes
    """
    if number_of_votes == 0:
        return None
    return num_comments / number_of_votes


def process_posts(data):
    processed = []
    for post in data:
        score = post.get("score", 0)
        upvote_ratio = post.get("upvote_ratio", 0)
        num_comments = post.get("num_comments", 0)

        number_of_votes = calculate_number_of_votes(score, upvote_ratio)

        if number_of_votes is None:
            continue

        engagement_ratio = calculate_engagement_ratio(num_comments, number_of_votes)

        if engagement_ratio is None:
            continue

        post["engagement_ratio"] = engagement_ratio
        processed.append(post)

    return processed


def main():
    input_file = "submissions.json"
    output_file = "sorted_data.json"

    try:
        with open(input_file) as f:
            data = json.load(f)
    except FileNotFoundError:
        print(f"Error: The file '{input_file}' was not found.")
        sys.exit(1)
    except json.JSONDecodeError:
        print(f"Error: The file '{input_file}' is not valid JSON.")
        sys.exit(1)

    if not isinstance(data, list):
        print("Error: JSON data is not a list of posts.")
        sys.exit(1)

    processed_posts = process_posts(data)

    if not processed_posts:
        print("No valid posts to process.")
        sys.exit(0)

    sorted_posts = sorted(processed_posts, key=lambda x: x["engagement_ratio"], reverse=True)

    with open(output_file, "w") as f:
        json.dump(sorted_posts, f, indent=4)

    print(f"Successfully sorted posts by engagement ratio and saved to '{output_file}'.")
    print(f"Total processed posts: {len(processed_posts)}")
    print(f"Total skipped posts due to invalid data: {len(data) - len(processed_posts)}")


if __name__ == "__main__":
    main()
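A quick worked example of the two formulas in the docstrings, with hypothetical numbers: a post with score 120 and an upvote ratio of 0.9 implies 120 / (2 * 0.9 - 1) = 150 votes, and with 30 comments its engagement ratio is 30 / 150 = 0.2. Posts whose upvote ratio is at or below 0.5 (denominator <= 0) or that end up with zero votes are skipped.

# Hypothetical post: score 120, upvote ratio 0.9, 30 comments.
score, upvote_ratio, num_comments = 120, 0.9, 30

votes = score / (2 * upvote_ratio - 1)  # 120 / 0.8 = 150.0
engagement = num_comments / votes       # 30 / 150 = 0.2
print(votes, engagement)                # 150.0 0.2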
107 changes: 107 additions & 0 deletions
examples/election/src/election_sim/sim_utils/reddit_personas/Reddit_Persona_Generation.py
@@ -0,0 +1,107 @@
import json

from openai import APIError, OpenAI, RateLimitError

OPENAI_API_KEY = ""


def gen_users(file_path, batch_size=5):
    system_message = (
        "You're a psychologist working in social and political sciences. "
        "You have to extract information about people based on interactions they have with others, "
        "and use it to design personas for a simulation."
    )

    with open(file_path) as file:
        data = json.load(file)

    if not isinstance(data, list):
        raise ValueError("Expected the JSON file to contain a list of interactions.")

    persona_descriptor = """
    {
        "Name": "Assign a realistic and appropriate name",
        "User_Reference": "Mention which user the persona is based on",
        "Sex": "Specify the sex/gender of the persona",
        "Political_Identity": "Describe the political orientation (e.g., liberal, conservative, moderate)",
        "Big5_traits": {
            "Openness": "Rate on a scale from 1 to 10",
            "Conscientiousness": "Rate on a scale from 1 to 10",
            "Extraversion": "Rate on a scale from 1 to 10",
            "Agreeableness": "Rate on a scale from 1 to 10",
            "Neuroticism": "Rate on a scale from 1 to 10"
        },
        "Schwartz_values": {
            "Self-Direction": "Rate on a scale from 1 to 10",
            "Stimulation": "Rate on a scale from 1 to 10",
            "Hedonism": "Rate on a scale from 1 to 10",
            "Achievement": "Rate on a scale from 1 to 10",
            "Power": "Rate on a scale from 1 to 10",
            "Security": "Rate on a scale from 1 to 10",
            "Conformity": "Rate on a scale from 1 to 10",
            "Tradition": "Rate on a scale from 1 to 10",
            "Benevolence": "Rate on a scale from 1 to 10",
            "Universalism": "Rate on a scale from 1 to 10"
        },
        "Description based on interactions": "Make observations about the user based on their interactions"
    }
    Assign scores thoughtfully: Use the scales provided to assign scores that accurately reflect the persona's characteristics based on the interactions.
    Consistency: Ensure that the scores for personality traits align logically with the behavioral indicators.
    Detail Orientation: Provide enough detail in each descriptor to make the persona realistic and relatable.
    Realistic Names: Give them realistic names and mention which user the persona is based on.
    Only generate the output json objects for each person and NOTHING else
    """

    client = OpenAI(api_key=OPENAI_API_KEY)

    total_items = len(data)
    batches = [data[i : i + batch_size] for i in range(0, total_items, batch_size)]

    # Open the output file once and append results after each batch
    with open("persona.txt", "w") as outfile:
        for batch_number, batch_data in enumerate(batches, start=1):
            batch_json = json.dumps(batch_data, ensure_ascii=False)
            prompt = (
                f"{batch_json}\n\n"
                "Based on these interactions, I want you to create personas with the following descriptors:\n"
                f"{persona_descriptor}\n\n"
                "IMPORTANT: Return ONLY the JSON objects for each person, nothing else. "
                "No additional text, explanations, or formatting. Only the JSON."
            )

            try:
                response = client.chat.completions.create(
                    model="gpt-4o",
                    messages=[
                        {"role": "system", "content": system_message},
                        {"role": "user", "content": prompt},
                    ],
                    temperature=1,
                    max_tokens=16383,
                    top_p=1,
                    frequency_penalty=0,
                    presence_penalty=0,
                )
            except RateLimitError as e:
                print(f"Rate limit error on batch {batch_number}: {e}")
                continue
            except APIError as e:
                print(f"API error on batch {batch_number}: {e}")
                continue

            summary = response.choices[0].message.content.strip()

            # Print raw output to inspect
            print(f"=== RAW OUTPUT FROM MODEL (Batch {batch_number}) ===")
            print(summary)
            print("========================================")

            # Write the raw output directly to the text file
            outfile.write(summary + "\n")

    print("All persona information has been written to persona.txt")


if __name__ == "__main__":
    gen_users("/Users/gayatrikrishnakumar/Desktop/World_Adapter/formatted_comments.json")
28 changes: 28 additions & 0 deletions
examples/election/src/election_sim/sim_utils/reddit_personas/Top-k_Extraction.py
@@ -0,0 +1,28 @@
import json


def extract_top_k(file_path, key, k, output_file):
    with open(file_path) as file:
        data = json.load(file)

    if key not in data[0]:
        print(f"The key '{key}' is not found in the JSON objects.")
        return

    sorted_data = sorted(data, key=lambda x: x[key], reverse=True)

    top_k_data = sorted_data[:k]

    with open(output_file, "w") as output:
        json.dump(top_k_data, output, indent=4)

    print(f"Top {k} JSON objects have been saved to '{output_file}'.")


file_path = "sorted_data.json"
key = "engagement_ratio"
k = 100
output_file = "top_k_data.json"


extract_top_k(file_path, key, k, output_file)
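The selection itself is just a sort and a slice; a small hypothetical run over three already-ranked posts keeps the two with the highest engagement_ratio. The resulting top_k_data.json is what Pre-scraping_Extraction.py reads to pull author names.

# Hypothetical records already carrying engagement_ratio (as written by Ranking_Interaction.py).
posts = [
    {"author_name": "a", "engagement_ratio": 0.2},
    {"author_name": "b", "engagement_ratio": 0.5},
    {"author_name": "c", "engagement_ratio": 0.1},
]

top_2 = sorted(posts, key=lambda x: x["engagement_ratio"], reverse=True)[:2]
print([p["author_name"] for p in top_2])  # ['b', 'a']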