Skip to content

Commit

Permalink
finishing up osworld benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
chuongnguyen26 committed Jan 6, 2025
1 parent 674476d commit d7f4695
Show file tree
Hide file tree
Showing 12 changed files with 41 additions and 476 deletions.
4 changes: 2 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,5 @@ experiments/*/output
# Analysis
agential/benchmarks/computer_use/osworld/vmware_vm_data

.zshrc

# code storage for testing purposes
code_storage
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

import io

from typing import List, Tuple
from typing import List, Tuple, cast
from xml.etree.ElementTree import Element

from PIL import Image, ImageDraw, ImageFont
from PIL import Image, ImageFile, ImageDraw, ImageFont
from PIL.ImageFont import FreeTypeFont

state_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/state"

Check failure on line 11 in agential/agents/osworld_baseline/accessibility_tree_wrap/heuristic_retrieve.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Ruff (I001)

agential/agents/osworld_baseline/accessibility_tree_wrap/heuristic_retrieve.py:3:1: I001 Import block is un-sorted or un-formatted

Check failure on line 11 in agential/agents/osworld_baseline/accessibility_tree_wrap/heuristic_retrieve.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (I001)

agential/agents/osworld_baseline/accessibility_tree_wrap/heuristic_retrieve.py:3:1: I001 Import block is un-sorted or un-formatted
state_ns_windows = "https://accessibility.windows.example.org/ns/state"
Expand Down Expand Up @@ -182,7 +183,7 @@ def draw_bounding_boxes(

# Load the screenshot image
image_stream = io.BytesIO(image_file_content)
image = Image.open(image_stream)
image: Image.Image = Image.open(image_stream)
if float(down_sampling_ratio) != 1.0:
image = image.resize(
(
Expand All @@ -197,10 +198,11 @@ def draw_bounding_boxes(

try:
# Adjust the path to the font file you have or use a default one
font = ImageFont.truetype("arial.ttf", 15)
font: FreeTypeFont = ImageFont.truetype("arial.ttf", 15)
except IOError:
# Fallback to a basic font if the specified font can't be loaded
font = ImageFont.load_default()
cur_font = ImageFont.load_default()
font = cast(FreeTypeFont, cur_font)

index = 1

Expand Down Expand Up @@ -245,7 +247,7 @@ def draw_bounding_boxes(
coords[0],
bottom_right[1],
) # Adjust Y to be above the bottom right
text_bbox: Tuple[int, int, int, int] = draw.textbbox(
text_bbox: Tuple[float, float, float, float] = draw.textbbox(
text_position, str(index), font=font, anchor="lb"
)
# offset: int = bottom_right[1]-text_bbox[3]
Expand Down
2 changes: 1 addition & 1 deletion agential/benchmarks/computer_use/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
"""Computer Use Base Benchmark."""
"""Computer Use Base Benchmark."""
Empty file.
Empty file.
106 changes: 0 additions & 106 deletions agential/benchmarks/computer_use/osworld/__init__.py
Original file line number Diff line number Diff line change
@@ -1,107 +1 @@
"""OSWorld Benchmark"""

Check failure on line 1 in agential/benchmarks/computer_use/osworld/__init__.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Ruff (D415)

agential/benchmarks/computer_use/osworld/__init__.py:1:1: D415 First line should end with a period, question mark, or exclamation point

Check failure on line 1 in agential/benchmarks/computer_use/osworld/__init__.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (D415)

agential/benchmarks/computer_use/osworld/__init__.py:1:1: D415 First line should end with a period, question mark, or exclamation point
from desktop_env.providers.base import VMManager, Provider
from typing import Tuple
import os

from typing import Tuple

from desktop_env.providers.base import Provider, VMManager


def initializer(
self,
provider_name: str = "vmware",
region: str = None,
path_to_vm: str = None,
snapshot_name: str = "init_state",
action_space: str = "computer_13",
cache_dir: str = "cache",
screen_size: Tuple[int] = (1920, 1080),
headless: bool = False,
require_a11y_tree: bool = True,
require_terminal: bool = False,
os_type: str = "Ubuntu",
):
"""Args:
provider_name (str): virtualization provider name, default to "vmware"
region (str): the region for allocate machines, work for cloud services, default to "us-east-1"
path_to_vm (str): path to .vmx file
snapshot_name (str): snapshot name to revert to, default to "init_state"
action_space (str): "computer_13" | "pyautogui"
cache_dir (str): cache directory to cache task-related stuffs like
reference file for evaluation
screen_size (Tuple[int]): screen size of the VM
headless (bool): whether to run the VM in headless mode
require_a11y_tree (bool): whether to require accessibility tree
require_terminal (bool): whether to require terminal output.
"""
# Initialize VM manager and vitualization provider
self.region = region

# Default
self.server_port = 5000
self.chromium_port = 9222
self.vnc_port = 8006
self.vlc_port = 8080
self.manager, self.provider = _create_vm_manager_and_provider(provider_name, region)

self.os_type = os_type

# Initialize environment variables
if path_to_vm:
self.path_to_vm = (
os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm)))
if provider_name in {"vmware", "virtualbox"}
else path_to_vm
)
else:
self.path_to_vm = self.manager.get_vm_path(self.os_type, region)

self.snapshot_name = snapshot_name
self.cache_dir_base: str = cache_dir
# todo: add the logic to get the screen size from the VM
self.headless = headless
self.require_a11y_tree = require_a11y_tree
self.require_terminal = require_terminal

# Initialize emulator and controller
if provider_name != "docker": # Check if this is applicable to other VM providers
self._start_emulator()

# mode: human or machine
self.instruction = None
assert action_space in ["computer_13", "pyautogui"]
self.action_space = action_space # todo: refactor it to the ActType

# episodic stuffs, like counters, will be updated or reset
# when calling self.reset()
self._traj_no: int = -1
self._step_no: int = 0
self.action_history: List[Dict[str, any]] = []


def _create_vm_manager_and_provider(provider_name: str, region: str):
"""Factory function to get the Virtual Machine Manager and Provider instances based on the provided provider name."""
provider_name = provider_name.lower().strip()
if provider_name == "vmware":
from desktop_env.providers.vmware.manager import VMwareVMManager
from desktop_env.providers.vmware.provider import VMwareProvider

return VMwareVMManager(), VMwareProvider(region)
elif provider_name == "virtualbox":
from desktop_env.providers.virtualbox.manager import VirtualBoxVMManager
from desktop_env.providers.virtualbox.provider import VirtualBoxProvider

return VirtualBoxVMManager(), VirtualBoxProvider(region)
elif provider_name in ["aws", "amazon web services"]:
from desktop_env.providers.aws.manager import AWSVMManager
from desktop_env.providers.aws.provider import AWSProvider

return AWSVMManager(), AWSProvider(region)
elif provider_name == "azure":
from desktop_env.providers.azure.manager import AzureVMManager
from desktop_env.providers.azure.provider import AzureProvider

return AzureVMManager(), AzureProvider(region)
else:
raise NotImplementedError(f"{provider_name} not implemented!")
21 changes: 6 additions & 15 deletions agential/benchmarks/computer_use/osworld/osworld.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
"""OSWorld Benchmark."""

import os
import subprocess

from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict

from desktop_env.desktop_env import DesktopEnv

from agential.benchmarks.computer_use.base import BaseComputerUseBenchmark
from agential.benchmarks.computer_use.osworld import initializer

import os
import subprocess
Expand Down Expand Up @@ -76,10 +74,7 @@ class OSWorld(BaseComputerUseBenchmark):
Renders the environment's current state for visualization purposes.
"""

Check failure on line 75 in agential/benchmarks/computer_use/osworld/osworld.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Ruff (D205)

agential/benchmarks/computer_use/osworld/osworld.py:47:5: D205 1 blank line required between summary line and description

Check failure on line 75 in agential/benchmarks/computer_use/osworld/osworld.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (D205)

agential/benchmarks/computer_use/osworld/osworld.py:47:5: D205 1 blank line required between summary line and description

def __init__(
self,
**kwargs: Any
) -> None:
def __init__(self, **kwargs: Any) -> None:
"""
Initializes the OSWorld benchmark with the provided configuration parameters.
Expand All @@ -89,20 +84,16 @@ def __init__(
"""

Check failure on line 84 in agential/benchmarks/computer_use/osworld/osworld.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Ruff (D212)

agential/benchmarks/computer_use/osworld/osworld.py:78:9: D212 Multi-line docstring summary should start at the first line

Check failure on line 84 in agential/benchmarks/computer_use/osworld/osworld.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Ruff (D212)

agential/benchmarks/computer_use/osworld/osworld.py:78:9: D212 Multi-line docstring summary should start at the first line
super().__init__(**kwargs)

DesktopEnv.__init__ = initializer ## Temp and to be removed by Chuong

self.path_to_vm = kwargs.get("path_to_vm")

try:
if self.path_to_vm is not None and not os.path.exists(self.path_to_vm):
del kwargs["path_to_vm"]
self.env = DesktopEnv(**kwargs)
except:
try:
vmrun_command = ['vmrun', 'start', self.path_to_vm]
vmrun_command = f"vmrun start {self.path_to_vm}"
subprocess.run(vmrun_command, check=True)

self.env = DesktopEnv(path_to_vm=self.path_to_vm, **kwargs)

print("VM started successfully.")
self.env = DesktopEnv(**kwargs)
except subprocess.CalledProcessError as e:
print(f"Error occurred: {e}")

Expand Down
Loading

0 comments on commit d7f4695

Please sign in to comment.