Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support large log file visualization #121

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions examples/browsergym/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,6 @@ def process_obs_for_viz(obs: dict[str, any], verbose: bool = False):
obs["last_action"]
)

# FIXME: the screenshot is too large to be uploaded to the visualizer server; uncomment this when the issue is fixed
# processed_obs["screenshot"] = compress_base64_image(processed_obs["screenshot"])
processed_obs["screenshot"] = str(processed_obs["screenshot"])[:50]

if not verbose:
return {
"screenshot": processed_obs["screenshot"],
Expand Down Expand Up @@ -128,7 +124,9 @@ def browsergym_node_data_factory(x: MCTSNode, verbose: bool = False):
}


def browsergym_edge_data_factory(n: Union[MCTSNode, BeamSearchNode, DFSNode], verbose: bool = False) -> EdgeData:
def browsergym_edge_data_factory(
n: Union[MCTSNode, BeamSearchNode, DFSNode], verbose: bool = False
) -> EdgeData:
function_calls = highlevel_action_parser.search_string(n.action)
function_calls = sum(function_calls.as_list(), [])

Expand Down
14 changes: 12 additions & 2 deletions reasoners/visualization/tree_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@ class Node:
id: NodeId
data: NodeData
selected_edge: Optional[EdgeId] = None
convert_images: bool = True

@dataclass
class Edge:
    """A directed edge of the snapshot, linking a source node to a target node."""

    # Identifier of this edge.
    id: EdgeId
    # Node id this edge leaves from; out_edges() matches on this field.
    source: NodeId
    # Node id this edge points to; in_edges() matches on this field.
    target: NodeId
    # Payload attached to the edge.
    data: EdgeData
    # NOTE(review): not read in the visible code; presumably tells the
    # serializer whether images embedded in `data` should be converted — confirm.
    convert_images: bool = True

def __init__(self, nodes: Collection[Node], edges: Collection[Edge]) -> None:
self.nodes: dict[NodeId, TreeSnapshot.Node] = {node.id: node for node in nodes}
Expand Down Expand Up @@ -51,10 +53,18 @@ def edge(self, edge_id: EdgeId) -> Edge:
return self.edges[edge_id]

def out_edges(self, node_id: NodeId) -> Collection[Edge]:
    """Return the edges that leave ``node_id``.

    Args:
        node_id: Id compared against each edge's ``source``.

    Returns:
        A list of the matching edges, in the iteration order of
        ``self.edges``; empty when no edge starts at ``node_id``.
    """
    # Resolve every id through self.edge() exactly once; the previous form
    # looked each edge up twice (once to filter, once to collect).
    resolved = (self.edge(edge_id) for edge_id in self.edges)
    return [edge for edge in resolved if edge.source == node_id]

def in_edges(self, node_id: NodeId) -> Collection[Edge]:
    """Return the edges that arrive at ``node_id``.

    Args:
        node_id: Id compared against each edge's ``target``.

    Returns:
        A list of the matching edges, in the iteration order of
        ``self.edges``; empty when no edge ends at ``node_id``.
    """
    # Resolve every id through self.edge() exactly once; the previous form
    # looked each edge up twice (once to filter, once to collect).
    resolved = (self.edge(edge_id) for edge_id in self.edges)
    return [edge for edge in resolved if edge.target == node_id]

def parent(self, node_id: NodeId) -> NodeId:
return self._parent[node_id]
Expand Down
52 changes: 46 additions & 6 deletions reasoners/visualization/visualizer_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import dataclasses
import json
import gzip
from typing import Optional, Union

import requests
Expand All @@ -15,6 +16,11 @@ class VisualizerClient:
def __init__(self, base_url: str = _API_DEFAULT_BASE_URL) -> None:
self.base_url = base_url

@dataclasses.dataclass
class UploadUrl:
    """Upload target described by the ``/logs/get-upload-url`` response.

    Instances are built directly from the response JSON, so the field names
    must match the keys the server returns.
    """

    # Mapping describing where to send the file; post_log reads its "url"
    # and "fields" entries.
    upload_url: dict
    # Name the file is uploaded under and reported back with on completion.
    # It is used as the multipart filename and as a JSON string value, so it
    # is a string rather than a mapping.
    file_name: str

@dataclasses.dataclass
class TreeLogReceipt:
id: str
Expand All @@ -24,22 +30,54 @@ class TreeLogReceipt:
def access_url(self) -> str:
return f"{_VISUALIZER_DEFAULT_BASE_URL}/visualizer/{self.id}?accessKey={self.access_key}"

def post_log(self, data: Union[TreeLog, str, dict]) -> Optional[TreeLogReceipt]:
def get_upload_url(self) -> Optional[UploadUrl]:
    """Ask the API for an upload target for a tree log.

    Issues a GET to ``{base_url}/logs/get-upload-url``.

    Returns:
        An ``UploadUrl`` built from the JSON response body when the server
        answers with HTTP 200; otherwise ``None``, after printing the status
        code and response text.
    """
    print("Getting log upload link...")
    response = requests.get(f"{self.base_url}/logs/get-upload-url")

    if response.status_code == 200:
        # The JSON keys are expected to match the UploadUrl field names.
        return self.UploadUrl(**response.json())

    print(
        f"GET Upload URL failed with status code: {response.status_code}, message: {response.text}"
    )
    return None

def post_log(
    self,
    data: Union[TreeLog, str, dict],
    upload_url: Optional[UploadUrl] = None,
) -> Optional[TreeLogReceipt]:
    """Compress a tree log, upload it, and register the upload with the API.

    Args:
        data: The log to upload. A ``TreeLog`` is serialized with
            ``TreeLogEncoder``, a ``dict`` with plain ``json.dumps``, and a
            ``str`` is sent as-is (it is assumed to already be JSON).
        upload_url: Upload target from ``get_upload_url()``. Optional so
            that ``post_log(data)`` callers keep working; when omitted or
            ``None``, one is requested here.

    Returns:
        A ``TreeLogReceipt`` built from the ``/logs/upload-complete``
        response, or ``None`` if no upload target could be obtained or
        either request failed (the failure is printed).
    """
    if upload_url is None:
        # get_upload_url() returns None on failure, and callers may pass
        # that straight through; fetch one here instead of crashing on
        # attribute access below.
        upload_url = self.get_upload_url()
        if upload_url is None:
            return None

    if isinstance(data, TreeLog):
        data = json.dumps(data, cls=TreeLogEncoder)
    elif isinstance(data, dict):
        data = json.dumps(data)

    # Encode first so the reported size is really bytes, not characters.
    raw = data.encode("utf-8")
    print(f"Tree log size: {len(raw)} bytes")
    compressed = gzip.compress(raw)
    print(f"Tree log compressed size: {len(compressed)} bytes")

    print("Uploading log...")
    # NOTE(review): the "url"/"fields" shape looks like a presigned POST
    # (form fields plus the file part) — confirm against the server.
    response = requests.post(
        upload_url.upload_url["url"],
        data=upload_url.upload_url["fields"],
        files={"file": (upload_url.file_name, compressed)},
    )

    # Both 200 and 204 are treated as a successful upload.
    if response.status_code not in (200, 204):
        print(
            f"POST Log failed with status code: {response.status_code}, message: {response.text}"
        )
        return None

    # Tell the API the upload finished so it can issue the receipt.
    response = requests.post(
        f"{self.base_url}/logs/upload-complete",
        json={"file_name": upload_url.file_name},
    )

    if response.status_code != 200:
        print(
            f"POST Upload Complete failed with status code: {response.status_code}, message: {response.text}"
        )
        return None

    return self.TreeLogReceipt(**response.json())


Expand Down Expand Up @@ -68,7 +106,9 @@ def visualize(
else:
raise TypeError(f"Unsupported result type: {type(result)}")

receipt = VisualizerClient().post_log(tree_log)
client = VisualizerClient()
upload_url = client.get_upload_url()
receipt = client.post_log(tree_log, upload_url)

if receipt is not None:
present_visualizer(receipt)