diff --git a/CLAUDE.md b/CLAUDE.md index 3a6fec8c..410f7d5f 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -125,13 +125,12 @@ models.register("my-model", custom_model_instance) ### Agent OS Abstraction -`AgentOs` provides an abstraction layer for OS-level operations: +`ComputerAgentOS` provides an abstraction layer for OS-level operations: ``` -AgentOs (Abstract Interface) - ├── AskUiControllerClient (gRPC to AskUI Agent OS - primary) +ComputerAgentOS (Abstract Interface) + ├── MultiComputerTargetAgentOS (gRPC to AskUI Agent OS - primary) ├── PlaywrightAgentOs (Web browser automation) - └── AndroidAgentOs (Android ADB) ``` ### Locator System @@ -175,7 +174,7 @@ Tools are auto-discovered and can be dynamically loaded via MCP configurations. - `src/askui/prompts/` - System prompts for different models ### Tools & OS -- `src/askui/tools/agent_os.py` - Abstract `AgentOs` interface +- `src/askui/tools/agent_os.py` - Abstract `ComputerAgentOS` interface - `src/askui/tools/askui/` - gRPC client for AskUI Agent OS - `src/askui/tools/android/` - Android-specific tools - `src/askui/tools/playwright/` - Web automation tools @@ -247,7 +246,7 @@ When writing or updating documentation in `docs/`: ## Important Patterns ### Composition over Inheritance -- `AgentToolbox` wraps `AgentOs` implementations +- `AgentToolbox` wraps `ComputerAgentOS` implementations - `ModelRouter` composes multiple model providers - `CompositeReporter` aggregates multiple reporters @@ -261,7 +260,7 @@ When writing or updating documentation in `docs/`: - Retry strategies with exponential backoff ### Adapter Pattern -- `AgentOs` abstraction bridges OS implementations (gRPC, Playwright, ADB) +- `ComputerAgentOS` abstraction bridges OS implementations (gRPC, Playwright, ADB) - `ModelFacade` adapts models to `ActModel`/`GetModel`/`LocateModel` interfaces ### Dependency Injection @@ -299,13 +298,13 @@ When writing or updating documentation in `docs/`: ### Adding Custom Tools 1. Implement `Tool` protocol in `models/shared/tools.py` 2. Register in appropriate MCP server (`api/mcp_servers/{type}.py`) -3. Use `@auto_inject_agent_os` for AgentOs dependency +3. Use `@auto_inject_agent_os` for ComputerAgentOS dependency 4. Follow Pydantic schema validation ### Adding New Agent Types 1. Inherit from `Agent` 2. Implement required abstract methods -3. Provide appropriate `AgentOs` implementation +3. Provide appropriate `ComputerAgentOS` implementation 4. Register in agent factory if needed ## Performance & Caching diff --git a/docs/07_tools.md b/docs/07_tools.md index ffb8552c..743bf971 100644 --- a/docs/07_tools.md +++ b/docs/07_tools.md @@ -68,7 +68,7 @@ Work with any agent type, no special dependencies required. #### Computer Tools (`computer/`) -Require `AgentOs` and work with `ComputerAgent` for desktop automation. +Require `ComputerAgentOS` and work with `ComputerAgent` for desktop automation. **Examples:** - `ComputerSaveScreenshotTool(base_dir)` - Save screenshots to disk diff --git a/mypy.ini b/mypy.ini index cfb75eb0..7a8a99d0 100644 --- a/mypy.ini +++ b/mypy.ini @@ -16,6 +16,8 @@ plugins = pydantic.mypy,sqlalchemy.ext.mypy.plugin exclude = (?x)( ^src/askui/models/ui_tars_ep/ui_tars_api\.py$ | ^src/askui/tools/askui/askui_ui_controller_grpc/.*$ + | ^venv/.*$ + | ^\.venv/.*$ ) mypy_path = src:tests explicit_package_bases = true diff --git a/src/askui/__init__.py b/src/askui/__init__.py index 858cf1a5..9d74baf4 100644 --- a/src/askui/__init__.py +++ b/src/askui/__init__.py @@ -45,6 +45,7 @@ from .models.types.response_schemas import ResponseSchema, ResponseSchemaBase from .retry import ConfigurableRetry, Retry from .tools import ModifierKey, PcKey +from .tools.askui import LocalComputerTarget, RemoteComputerTarget from .utils.image_utils import ImageSource from .utils.source_utils import InputSource @@ -69,6 +70,8 @@ logging.getLogger(__name__).addHandler(logging.NullHandler()) __all__ = [ + "RemoteComputerTarget", + "LocalComputerTarget", "Agent", "AutomationError", "ComputerAgent", diff --git a/src/askui/agent_base.py b/src/askui/agent_base.py index 5775a9c0..70d2d3cb 100644 --- a/src/askui/agent_base.py +++ b/src/askui/agent_base.py @@ -26,7 +26,7 @@ from askui.models.shared.truncation_strategies import TruncationStrategy from askui.prompts.act_prompts import CACHE_USE_PROMPT, create_default_prompt from askui.telemetry.otel import OtelSettings, setup_opentelemetry_tracing -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS from askui.tools.android.agent_os import AndroidAgentOs from askui.tools.caching_tools import ( InspectCacheMetadata, @@ -57,7 +57,7 @@ def __init__( reporter: Reporter | None = None, retry: Retry | None = None, tools: list[Tool] | None = None, - agent_os: AgentOs | AndroidAgentOs | None = None, + agent_os: ComputerAgentOS | AndroidAgentOs | None = None, settings: AgentSettings | None = None, callbacks: list[ConversationCallback] | None = None, truncation_strategy: TruncationStrategy | None = None, diff --git a/src/askui/computer_agent.py b/src/askui/computer_agent.py index ad0a6627..34b5ea2f 100644 --- a/src/askui/computer_agent.py +++ b/src/askui/computer_agent.py @@ -17,11 +17,13 @@ create_computer_agent_prompt, ) from askui.tools.computer import ( + ComputerGetCurrentComputerTargetIdTool, ComputerGetMousePositionTool, ComputerGetSystemInfoTool, ComputerKeyboardPressedTool, ComputerKeyboardReleaseTool, ComputerKeyboardTapTool, + ComputerListAgentOsTargetComputersTool, ComputerListDisplaysTool, ComputerMouseClickTool, ComputerMouseHoldDownTool, @@ -31,6 +33,7 @@ ComputerRetrieveActiveDisplayTool, ComputerScreenshotTool, ComputerSetActiveDisplayTool, + ComputerSwitchAgentOsTargetComputerTool, ComputerTypeTool, ) from askui.tools.exception_tool import ExceptionTool @@ -38,7 +41,7 @@ from .reporting import CompositeReporter, Reporter from .retry import Retry from .tools import AgentToolbox, ComputerAgentOsFacade, ModifierKey, PcKey -from .tools.askui import AskUiControllerClient +from .tools.askui import ComputerTarget, MultiComputerTargetAgentOS logger = logging.getLogger(__name__) @@ -50,10 +53,30 @@ class ComputerAgent(Agent): This agent can perform various UI interactions like clicking, typing, scrolling, and more. It uses computer vision models to locate UI elements and execute actions on them. + A single `ComputerAgent` can drive **one or more machines** through the + `agent_os_target_computers` argument. Each entry is an Agent OS target + computer (local subprocess or remote gRPC endpoint) identified by a stable + `computer_id`. At any moment one target is *active* and receives all + explicit calls (`click`, `type`, `keyboard`, ...). The active target can be + changed at runtime via + `agent.tools.os.switch_agent_os_target_computer(computer_id)` or scoped to a + block using `agent.tools.os.temporary_select(computer_id)`. The `act()` + model is also given list/switch/get-current tools so it can orchestrate + work across machines on its own (e.g. read something on one computer and + re-enter it on another). + Args: - display (int, optional): The display number to use for screen interactions. Defaults to `1`. + display (int, optional): The display number to use for screen interactions on the default local target. Ignored when `agent_os_target_computers` is provided. Defaults to `1`. reporters (list[Reporter] | None, optional): List of reporter instances for logging and reporting. If `None`, an empty list is used. - tools (AgentToolbox | None, optional): Custom toolbox instance. If `None`, a default one will be created with `AskUiControllerClient`. + agent_os_target_computers (list[ComputerTarget] | None, optional): + Target computers the agent can route actions to. May mix one + `LocalComputerTarget` (managing a controller subprocess on this + machine) with any number of `RemoteComputerTarget`s pointing at + controllers already running on other machines. Constraints: at + least one target, at most one local, and remote `address`es plus + all `computer_id`s must be unique. The first entry becomes the + initial active target. Defaults to a single local target bound to + `display`. settings (AgentSettings | None, optional): Provider-based model settings. If `None`, uses the default AskUI model stack. retry (Retry, optional): The retry instance to use for retrying failed actions. Defaults to `ConfigurableRetry` with exponential backoff. Currently only supported for `locate()` method. act_tools (list[Tool] | None, optional): Additional tools to make available for @@ -61,6 +84,8 @@ class ComputerAgent(Agent): via `act(..., tools=[...])` (see example below). Example: + Single local machine (the default): + ```python from askui import ComputerAgent @@ -70,6 +95,36 @@ class ComputerAgent(Agent): agent.act("Open settings menu") ``` + Example: + Research on one machine and write up the findings on another. The + first target in the list is the active one; `temporary_select` + re-routes a block of explicit calls and restores the previous + active target on exit. + + ```python + from askui import ComputerAgent + from askui.tools.askui import LocalComputerTarget, RemoteComputerTarget + + with ComputerAgent( + agent_os_target_computers=[ + LocalComputerTarget(computer_id="research-box"), + RemoteComputerTarget( + address="192.168.1.42:26000", + description="Writer box with a text editor open", + computer_id="writer-box", + ), + ], + ) as agent: + agent.act( + "On research-box, open a browser, google 'askui', and read " + "the top results to gather key facts about what AskUI is, " + "what it does, and notable features. Then switch to " + "writer-box and write a Markdown document titled " + "'AskUI Findings' summarizing those facts as a bulleted " + "list in the open text editor." + ) + ``` + Example (optional tools for `act()`): Register tools from `askui.tools.store` (or your own `Tool` implementations) either on the agent so they apply to all `act()` calls, or only for one call. @@ -94,11 +149,11 @@ class ComputerAgent(Agent): @telemetry.record_call( exclude={ "reporters", - "tools", "settings", "act_tools", "callbacks", "truncation_strategy", + "agent_os_target_computers", } ) @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) @@ -106,7 +161,7 @@ def __init__( self, display: Annotated[int, Field(ge=1)] = 1, reporters: list[Reporter] | None = None, - tools: AgentToolbox | None = None, + agent_os_target_computers: list[ComputerTarget] | None = None, settings: AgentSettings | None = None, retry: Retry | None = None, act_tools: list[Tool] | None = None, @@ -114,10 +169,11 @@ def __init__( truncation_strategy: TruncationStrategy | None = None, ) -> None: reporter = CompositeReporter(reporters=reporters) - self.tools = tools or AgentToolbox( - agent_os=AskUiControllerClient( + self.tools = AgentToolbox( + agent_os=MultiComputerTargetAgentOS( display=display, reporter=reporter, + agent_os_target_computers=agent_os_target_computers, ) ) super().__init__( @@ -500,8 +556,8 @@ def cli( with ComputerAgent() as agent: # Use for Windows - agent.cli(r'start "" "C:\Program Files\VideoLAN\VLC\vlc.exe"') # Start in VLC non-blocking - agent.cli(r'"C:\Program Files\VideoLAN\VLC\vlc.exe"') # Start in VLC blocking + agent.cli(r'start "" "C:\\Program Files\\VideoLAN\\VLC\\vlc.exe"') # Start in VLC non-blocking + agent.cli(r'"C:\\Program Files\\VideoLAN\\VLC\\vlc.exe"') # Start in VLC blocking # Mac agent.cli("open -a chrome") # Open Chrome non-blocking for mac @@ -541,6 +597,9 @@ def get_default_tools() -> list[Tool]: ComputerListDisplaysTool(), ComputerRetrieveActiveDisplayTool(), ComputerSetActiveDisplayTool(), + ComputerListAgentOsTargetComputersTool(), + ComputerSwitchAgentOsTargetComputerTool(), + ComputerGetCurrentComputerTargetIdTool(), ] diff --git a/src/askui/models/shared/android_base_tool.py b/src/askui/models/shared/android_base_tool.py index 5fc1c90b..fe4942bf 100644 --- a/src/askui/models/shared/android_base_tool.py +++ b/src/askui/models/shared/android_base_tool.py @@ -2,7 +2,7 @@ from askui.models.shared.tool_tags import ToolTags from askui.models.shared.tools import ToolWithAgentOS -from askui.tools import AgentOs +from askui.tools import ComputerAgentOS from askui.tools.agent_os_type_error import AgentOsTypeError from askui.tools.android.agent_os import AndroidAgentOs @@ -41,11 +41,11 @@ def agent_os(self) -> AndroidAgentOs: return agent_os @agent_os.setter - def agent_os(self, agent_os: AgentOs | AndroidAgentOs) -> None: + def agent_os(self, agent_os: ComputerAgentOS | AndroidAgentOs) -> None: """Set the agent OS. Args: - agent_os (AgentOs | AndroidAgentOs): The agent OS instance to set. + agent_os (ComputerAgentOS | AndroidAgentOs): The agent OS instance to set. Raises: TypeError: If the agent OS is not an AndroidAgentOs instance. diff --git a/src/askui/models/shared/computer_base_tool.py b/src/askui/models/shared/computer_base_tool.py index 0b6f13be..10a45d90 100644 --- a/src/askui/models/shared/computer_base_tool.py +++ b/src/askui/models/shared/computer_base_tool.py @@ -2,17 +2,17 @@ from askui.models.shared.tool_tags import ToolTags from askui.models.shared.tools import ToolWithAgentOS -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS from askui.tools.agent_os_type_error import AgentOsTypeError from askui.tools.android.agent_os import AndroidAgentOs class ComputerBaseTool(ToolWithAgentOS): - """Tool base class that has an AgentOs available.""" + """Tool base class that has a ComputerAgentOS available.""" def __init__( self, - agent_os: AgentOs | None = None, + agent_os: ComputerAgentOS | None = None, required_tags: list[str] | None = None, **kwargs: Any, ) -> None: @@ -23,33 +23,34 @@ def __init__( ) @property - def agent_os(self) -> AgentOs: + def agent_os(self) -> ComputerAgentOS: """Get the agent OS. Returns: - AgentOs: The agent OS instance. + ComputerAgentOS: The agent OS instance. """ agent_os = super().agent_os - if not isinstance(agent_os, AgentOs): + if not isinstance(agent_os, ComputerAgentOS): raise AgentOsTypeError( - expected_type=AgentOs, + expected_type=ComputerAgentOS, actual_type=type(agent_os), ) return agent_os @agent_os.setter - def agent_os(self, agent_os: AgentOs | AndroidAgentOs) -> None: + def agent_os(self, agent_os: ComputerAgentOS | AndroidAgentOs) -> None: """Set the agent OS facade. Args: - agent_os (AgentOs | AndroidAgentOs): The agent OS facade instance to set. + agent_os (ComputerAgentOS | AndroidAgentOs): The agent OS facade + instance to set. Raises: - TypeError: If the agent OS is not an AgentOs instance. + TypeError: If the agent OS is not a ComputerAgentOS instance. """ - if not isinstance(agent_os, AgentOs): + if not isinstance(agent_os, ComputerAgentOS): raise AgentOsTypeError( - expected_type=AgentOs, + expected_type=ComputerAgentOS, actual_type=type(agent_os), ) self._agent_os = agent_os diff --git a/src/askui/models/shared/playwright_base_tool.py b/src/askui/models/shared/playwright_base_tool.py index 1415c99a..da2772a0 100644 --- a/src/askui/models/shared/playwright_base_tool.py +++ b/src/askui/models/shared/playwright_base_tool.py @@ -2,18 +2,18 @@ from askui.models.shared.tool_tags import ToolTags from askui.models.shared.tools import ToolWithAgentOS -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS from askui.tools.agent_os_type_error import AgentOsTypeError from askui.tools.android.agent_os import AndroidAgentOs from askui.tools.playwright.agent_os import PlaywrightAgentOs class PlaywrightBaseTool(ToolWithAgentOS): - """Tool base class that has an the Playwright AgentOs available.""" + """Tool base class that has a Playwright ComputerAgentOS available.""" def __init__( self, - agent_os: AgentOs | None = None, + agent_os: ComputerAgentOS | None = None, required_tags: list[str] | None = None, **kwargs: Any, ) -> None: @@ -39,12 +39,14 @@ def agent_os(self) -> PlaywrightAgentOs: return agent_os @agent_os.setter - def agent_os(self, agent_os: AgentOs | AndroidAgentOs | PlaywrightAgentOs) -> None: + def agent_os( + self, agent_os: ComputerAgentOS | AndroidAgentOs | PlaywrightAgentOs + ) -> None: """Set the agent OS. Args: - agent_os (AgentOs | AndroidAgentOs | PlaywrightAgentOs): The agent OS - instance to set. + agent_os (ComputerAgentOS | AndroidAgentOs | PlaywrightAgentOs): The + agent OS instance to set. Raises: TypeError: If the agent OS is not an `PlaywrightAgentOs` instance. diff --git a/src/askui/models/shared/tools.py b/src/askui/models/shared/tools.py index 74912911..22df2631 100644 --- a/src/askui/models/shared/tools.py +++ b/src/askui/models/shared/tools.py @@ -31,7 +31,7 @@ ToolResultBlockParam, ToolUseBlockParam, ) -from askui.tools import AgentOs +from askui.tools import ComputerAgentOS from askui.tools.android.agent_os import AndroidAgentOs from askui.utils.image_utils import ImageSource, base64_to_image @@ -349,23 +349,23 @@ def __call__(self, *args: Any, **kwargs: Any) -> ToolCallResult: class ToolWithAgentOS(Tool): - """Tool base class that has an AgentOs available.""" + """Tool base class that has a ComputerAgentOS available.""" def __init__( self, required_tags: list[str], - agent_os: AgentOs | AndroidAgentOs | None = None, + agent_os: ComputerAgentOS | AndroidAgentOs | None = None, **kwargs: Any, ) -> None: super().__init__(**kwargs, required_tags=required_tags) - self._agent_os: AgentOs | AndroidAgentOs | None = agent_os + self._agent_os: ComputerAgentOS | AndroidAgentOs | None = agent_os @property - def agent_os(self) -> AgentOs | AndroidAgentOs: - """Get the agent OS. + def agent_os(self) -> ComputerAgentOS | AndroidAgentOs: + """Get the AgentOS. Returns: - AgentOs | AndroidAgentOs: The agent OS instance. + ComputerAgentOS | AndroidAgentOs: The AgentOS instance. """ if self._agent_os is None: msg = ( @@ -377,11 +377,11 @@ def agent_os(self) -> AgentOs | AndroidAgentOs: return self._agent_os @agent_os.setter - def agent_os(self, agent_os: AgentOs | AndroidAgentOs) -> None: + def agent_os(self, agent_os: ComputerAgentOS | AndroidAgentOs) -> None: self._agent_os = agent_os def is_agent_os_initialized(self) -> bool: - """Check if the agent OS is initialized.""" + """Check if the AgentOS is initialized.""" return self._agent_os is not None @@ -460,21 +460,21 @@ def __init__( tools: list[Tool] | None = None, mcp_client: McpClientProtocol | None = None, include: set[str] | None = None, - agent_os_list: list[AgentOs | AndroidAgentOs] | None = None, + agent_os_list: list[ComputerAgentOS | AndroidAgentOs] | None = None, ) -> None: self._mcp_client = mcp_client self._include = include - self._agent_os_list: list[AgentOs | AndroidAgentOs] = [] + self._agent_os_list: list[ComputerAgentOS | AndroidAgentOs] = [] self._tools: list[Tool] = tools or [] if agent_os_list: for agent_os in agent_os_list: self.add_agent_os(agent_os) - def add_agent_os(self, agent_os: AgentOs | AndroidAgentOs) -> None: - """Add an agent OS to the collection. + def add_agent_os(self, agent_os: ComputerAgentOS | AndroidAgentOs) -> None: + """Add an AgentOS to the collection. Args: - agent_os (AgentOs | AndroidAgentOs): The agent OS instance to add. + agent_os (ComputerAgentOS | AndroidAgentOs): The AgentOS instance to add. """ self._agent_os_list.append(agent_os) @@ -534,12 +534,23 @@ def reset_tools(self, tools: list[Tool] | None = None) -> None: """Reset the tools in the collection with new tools.""" self._tools = tools or [] - def get_agent_os_by_tags(self, tags: list[str]) -> AgentOs | AndroidAgentOs: - """Get an agent OS by tags.""" + def get_agent_os_by_tags( + self, required_tags: list[str] + ) -> ComputerAgentOS | AndroidAgentOs: + """ + Find the first registered AgentOS whose tags are a superset of + `required_tags`. + + Every tag in `required_tags` must appear in the AgentOS's tags; the + AgentOS may declare additional tags beyond those. + + Raises: + ValueError: when no registered AgentOS satisfies the required tags. + """ for agent_os in self._agent_os_list: - if all(tag in agent_os.tags for tag in tags): + if all(required in agent_os.tags for required in required_tags): return agent_os - msg = f"Agent OS with tags [{', '.join(tags)}] not found" + msg = f"No AgentOS satisfies required tags [{', '.join(required_tags)}]" raise ValueError(msg) def _initialize_tools(self) -> None: diff --git a/src/askui/tools/__init__.py b/src/askui/tools/__init__.py index ecd5bf24..c0f0dcc4 100644 --- a/src/askui/tools/__init__.py +++ b/src/askui/tools/__init__.py @@ -1,10 +1,11 @@ -from .agent_os import AgentOs, Coordinate, ModifierKey, PcKey +from .agent_os import AgentOs, ComputerAgentOS, Coordinate, ModifierKey, PcKey from .askui.askui_controller import RenderObjectStyle from .computer_agent_os_facade import ComputerAgentOsFacade from .toolbox import AgentToolbox __all__ = [ "AgentOs", + "ComputerAgentOS", "AgentToolbox", "ModifierKey", "PcKey", diff --git a/src/askui/tools/agent_os.py b/src/askui/tools/agent_os.py index 96ecc831..48697a70 100644 --- a/src/askui/tools/agent_os.py +++ b/src/askui/tools/agent_os.py @@ -1,12 +1,17 @@ from abc import ABC, abstractmethod +from contextlib import AbstractContextManager from typing import TYPE_CHECKING, Literal from PIL import Image from pydantic import BaseModel, ConfigDict, Field +from typing_extensions import Self from askui.models.shared.tool_tags import ToolTags if TYPE_CHECKING: + from askui.tools.askui.agent_os_target_computer import ( + ComputerTarget, + ) from askui.tools.askui.askui_ui_controller_grpc.generated import ( Controller_V1_pb2 as controller_v1_pbs, ) @@ -213,7 +218,7 @@ def __str__(self) -> str: InputEvent = ClickEvent -class AgentOs(ABC): +class ComputerAgentOS(ABC): """ Abstract base class for Agent OS. Cannot be instantiated directly. @@ -677,6 +682,55 @@ def set_window_in_focus(self, process_id: int, window_id: int) -> None: """ raise NotImplementedError + def add_agent_os_target_computer( + self, agent_os_target_computer: "ComputerTarget" + ) -> "ComputerTarget": + """Register an additional target computer. Auto-connects if connected.""" + raise NotImplementedError + + def reset_agent_os_target_computers( + self, + agent_os_target_computers: "list[ComputerTarget] | None" = None, + ) -> None: + """Disconnect (if connected) and replace the target computer list.""" + raise NotImplementedError + + def describe_agent_os_target_computers(self) -> list[str]: + """Return the `repr()` string of every registered target computer.""" + raise NotImplementedError + + def get_current_computer_target_id(self, report: bool = True) -> str: + """Return the `computer_id` of the currently active target computer.""" + raise NotImplementedError + + def switch_agent_os_target_computer(self, computer_id: str) -> "ComputerTarget": + """Switch the active target computer by its `computer_id`.""" + raise NotImplementedError + + def temporary_select(self, computer_id: str) -> AbstractContextManager[Self]: + """ + Temporarily switch the active target computer for the duration of a `with` + block, then restore the previously-active target on exit (even if the + block raises). + + Args: + computer_id (str): Computer id of the target to activate inside the + block. + + Returns: + AbstractContextManager[Self]: Context manager that yields this + `ComputerAgentOS` with the selected target active. + + Example: + ```python + with agent_os.temporary_select('Remote-Machine') as remote_machine: + img = remote_machine.screenshot() + img.save("remote_machine.png") + # previous active target restored here + ``` + """ + raise NotImplementedError + def get_file_names(self, absolute_directory_path: str) -> list[str]: """ List file names in an absolute directory on the automation target @@ -719,3 +773,13 @@ def remove_virtual_displays(self) -> None: NotImplementedError: If the implementation does not support this operation. """ raise NotImplementedError + + +AgentOs = ComputerAgentOS +"""Deprecated alias for `ComputerAgentOS`, kept for backward compatibility. + +`AgentOs` was renamed to `ComputerAgentOS` to reflect that it is the +computer-specific Agent OS interface (mouse, keyboard, displays, ...) rather +than a universal abstraction across all device types. Prefer `ComputerAgentOS` +in new code. +""" diff --git a/src/askui/tools/android/agent_os.py b/src/askui/tools/android/agent_os.py index 3a5a8285..d7fe7e04 100644 --- a/src/askui/tools/android/agent_os.py +++ b/src/askui/tools/android/agent_os.py @@ -1,7 +1,9 @@ from abc import ABC, abstractmethod +from contextlib import AbstractContextManager from typing import List, Literal from PIL import Image +from typing_extensions import Self from askui.tools.android.uiautomator_hierarchy import UIElementCollection @@ -502,3 +504,26 @@ def get_ui_elements(self) -> UIElementCollection: Gets the UI elements. """ raise NotImplementedError + + def temporary_select(self, device_sn: str) -> AbstractContextManager[Self]: + """ + Temporarily switch the active device for the duration of a `with` block, + then restore the previously-active device on exit (even if the block + raises). + + Args: + device_sn (str): Serial number of the device to activate inside the + block. + + Returns: + AbstractContextManager[Self]: Context manager that yields this + `AndroidAgentOs` with `device_sn` active. + + Example: + ```python + with android_agent_os.temporary_select('table_phone') as table_phone: + table_phone.tap(100, 200) + # previous active device restored here + ``` + """ + raise NotImplementedError diff --git a/src/askui/tools/android/agent_os_facade.py b/src/askui/tools/android/agent_os_facade.py index f27d0eee..0bc19aea 100644 --- a/src/askui/tools/android/agent_os_facade.py +++ b/src/askui/tools/android/agent_os_facade.py @@ -1,6 +1,9 @@ +from collections.abc import Iterator +from contextlib import contextmanager from typing import List, Optional, Tuple from PIL import Image +from typing_extensions import Self from askui.models.shared.tool_tags import ToolTags from askui.tools.android.agent_os import ANDROID_KEY, AndroidAgentOs, AndroidDisplay @@ -112,6 +115,15 @@ def set_device_by_serial_number(self, device_sn: str) -> None: self._agent_os.set_device_by_serial_number(device_sn) self._real_screen_resolution = None + @contextmanager + def temporary_select(self, device_sn: str) -> Iterator[Self]: + with self._agent_os.temporary_select(device_sn): + self._real_screen_resolution = None + try: + yield self + finally: + self._real_screen_resolution = None + def get_connected_devices_serial_numbers(self) -> list[str]: return self._agent_os.get_connected_devices_serial_numbers() diff --git a/src/askui/tools/android/ppadb_agent_os.py b/src/askui/tools/android/ppadb_agent_os.py index 9ffa7452..517ed4e1 100644 --- a/src/askui/tools/android/ppadb_agent_os.py +++ b/src/askui/tools/android/ppadb_agent_os.py @@ -2,12 +2,15 @@ import re import shlex import string +from collections.abc import Iterator +from contextlib import contextmanager from pathlib import Path from typing import List, Optional, get_args from PIL import Image from ppadb.client import Client as AdbClient from ppadb.device import Device as AndroidDevice +from typing_extensions import Self from askui.reporting import NULL_REPORTER, Reporter from askui.tools.android.agent_os import ( @@ -202,6 +205,24 @@ def set_device_by_serial_number(self, device_sn: str) -> None: msg = f"Device name {device_sn} not found" raise AndroidAgentOsError(msg) + @contextmanager + def temporary_select(self, device_sn: str) -> Iterator[Self]: + previous_sn = self._device.serial if self._device is not None else None + self._reporter.add_message( + self._REPORTER_ROLE_NAME, + f"temporary_select({device_sn!r}) [previous={previous_sn!r}]", + ) + self.set_device_by_serial_number(device_sn) + try: + yield self + finally: + if previous_sn is not None and previous_sn != device_sn: + self.set_device_by_serial_number(previous_sn) + self._reporter.add_message( + self._REPORTER_ROLE_NAME, + f"temporary_select({device_sn!r}) -> restored", + ) + def _screenshot_without_reporting(self) -> Image.Image: device: AndroidDevice = self._get_selected_device() self._check_if_display_is_selected() diff --git a/src/askui/tools/askui/__init__.py b/src/askui/tools/askui/__init__.py index 5d46a982..db94e66d 100644 --- a/src/askui/tools/askui/__init__.py +++ b/src/askui/tools/askui/__init__.py @@ -1,6 +1,19 @@ -from .askui_controller import AskUiControllerClient, AskUiControllerServer +from .agent_os_target_computer import ( + ComputerTarget, + LocalComputerTarget, + RemoteComputerTarget, +) +from .askui_controller import MultiComputerTargetAgentOS +from .computer_target_connection import ComputerTargetConnection +from .computer_target_pool import ( + ComputerTargetPool, +) __all__ = [ - "AskUiControllerClient", - "AskUiControllerServer", + "ComputerTarget", + "ComputerTargetConnection", + "ComputerTargetPool", + "MultiComputerTargetAgentOS", + "LocalComputerTarget", + "RemoteComputerTarget", ] diff --git a/src/askui/tools/askui/agent_os_target_computer.py b/src/askui/tools/askui/agent_os_target_computer.py new file mode 100644 index 00000000..03cc596f --- /dev/null +++ b/src/askui/tools/askui/agent_os_target_computer.py @@ -0,0 +1,378 @@ +import logging +import pathlib +import subprocess +import sys +import time +import uuid +from urllib.parse import urlparse + +from typing_extensions import override + +from askui.tools.askui.askui_controller_settings import AskUiControllerSettings +from askui.tools.askui.computer_target_connection import ComputerTargetConnection +from askui.tools.askui.exceptions import AskUiControllerError +from askui.tools.utils import process_exists, wait_for_port + +logger = logging.getLogger(__name__) + + +class ComputerTarget: + """ + Base class describing a computer target (a machine running the AskUI Agent + OS) that a `MultiComputerTargetAgentOS` client can connect to. + + A computer target runs the server-side counterpart of the `ComputerAgentOS` + client abstraction: it exposes a gRPC API for OS-level operations + (screenshot, mouse, keyboard, ...) and is identified by a unique session + GUID. Each computer target also tracks which display it is currently + operating against. + + Args: + address (str): gRPC address of the target computer + (e.g. ``"localhost:23000"``). + description (str): Human-readable description. + display (int, optional): Display ID selected for this target computer. + Defaults to `1`. + computer_id (str | None, optional): Stable, human-friendly identifier for + the target computer. Used by `ComputerTargetPool` lookup + helpers. Must be unique across registered target computers. Defaults + to the target computer's `session_guid`. + """ + + def __init__( + self, + address: str, + description: str, + display: int = 1, + computer_id: str | None = None, + ) -> None: + self._session_guid = "{" + str(uuid.uuid4()) + "}" + self._address = address + self._description = description + self._display = display + self._computer_id = ( + computer_id if computer_id is not None else self._session_guid + ) + self._connection: ComputerTargetConnection | None = None + + @property + def session_guid(self) -> str: + """Unique session GUID assigned to this target computer.""" + return self._session_guid + + @property + def computer_id(self) -> str: + """ + Stable identifier for this target computer. Defaults to `session_guid` + when no custom id was supplied at construction time. + """ + return self._computer_id + + @property + def address(self) -> str: + """gRPC address of the target computer.""" + return self._address + + @property + def description(self) -> str: + """Description of this target computer.""" + return self._description + + @property + def display(self) -> int: + """Display ID currently selected for this target computer.""" + return self._display + + @display.setter + def display(self, value: int) -> None: + self._display = value + + @property + def is_local(self) -> bool: + """Whether this target computer represents a locally-managed process.""" + return False + + @property + def is_connected(self) -> bool: + """Whether an open gRPC connection to this target computer exists.""" + return self._connection is not None + + @property + def connection(self) -> ComputerTargetConnection: + """ + The open gRPC connection to this target computer. + + Raises: + AskUiControllerError: If this target computer is not connected (i.e. + `connect()` has not been called). + """ + if self._connection is None: + error_msg = ( + f"Agent OS target computer {self._description!r} " + f"(computer_id={self._computer_id!r}, address={self._address}) " + "is not connected. Call `MultiComputerTargetAgentOS.connect()` " + "first." + ) + raise AskUiControllerError(error_msg) + return self._connection + + def connect(self) -> None: + """ + Open the gRPC connection to this target computer. Idempotent: returns + silently if already connected. Delegates the gRPC specifics to + `ComputerTargetConnection.open()`. + """ + if self._connection is None: + self._connection = ComputerTargetConnection.open(self) + + def disconnect(self) -> None: + """ + Close the gRPC connection to this target computer. No-op if not + connected. Delegates the gRPC teardown to + `ComputerTargetConnection.close()`. + """ + conn = self._connection + if conn is None: + return + self._connection = None + conn.close(self) + + def start(self, clean_up: bool = False) -> None: + """Start the underlying controller process. No-op for non-local targets.""" + + def stop(self, force: bool = False) -> None: + """Stop the underlying controller process. No-op for non-local targets.""" + + def __repr__(self) -> str: + return ( + f"{type(self).__name__}(" + f"computer_id={self._computer_id!r}, " + f"description={self._description!r}, " + f"display={self._display!r})" + ) + + +class LocalComputerTarget(ComputerTarget): + """ + Local computer target: manages an AskUI Remote Device Controller + subprocess on this machine. + + Args: + settings (AskUiControllerSettings | None, optional): Process-level settings + (executable path, args). Defaults to a fresh `AskUiControllerSettings`. + address (str, optional): gRPC address. Defaults to ``"localhost:23000"``. + is_service (bool, optional): When `True`, `start()` does not launch the + controller binary because it is managed externally (e.g. AskUI Core + Service on Windows). Defaults to `False`. + discover_service (bool, optional): On Windows, probe for a running + ``askuicoreservice`` and, if found, switch the address to port + ``26000`` and set `is_service` to `True`. Defaults to `True`. + description (str, optional) + display (int, optional): Display ID selected for this target computer. + Defaults to `1`. + """ + + _ASKUI_CORE_SERVICE_NAME = "AskuiCoreService" + _ASKUI_CORE_SERVICE_PORT = 26000 + + def __init__( + self, + description: str = "Local computer target", + settings: AskUiControllerSettings | None = None, + address: str = "localhost:23000", + discover_service: bool = True, + display: int = 1, + computer_id: str | None = None, + ) -> None: + super().__init__( + address=address, + description=description, + display=display, + computer_id=computer_id, + ) + self._is_service = False + self._settings = settings or AskUiControllerSettings() + self._process: subprocess.Popen[bytes] | None = None + if discover_service: + self._discover_service(address) + + @property + @override + def is_local(self) -> bool: + return True + + @property + def is_service(self) -> bool: + """Whether the controller process is managed externally (skip `start()`).""" + return self._is_service + + @staticmethod + def _is_askui_core_service_running() -> bool: + """Return `True` when the `AskuiCoreService` Windows service is RUNNING.""" + if sys.platform == "win32": + try: + result = subprocess.run( + [ + "sc", + "query", + LocalComputerTarget._ASKUI_CORE_SERVICE_NAME, + ], + capture_output=True, + text=True, + timeout=5, + check=False, + ) + if result.returncode != 0: + return False + except (OSError, subprocess.SubprocessError) as e: + error_msg = ( + "Failed to query " + f"{LocalComputerTarget._ASKUI_CORE_SERVICE_NAME} service: {e}" + ) + logger.debug(error_msg) + return False + return "RUNNING" in result.stdout.upper() + return False + + def _discover_service(self, address: str) -> None: + if LocalComputerTarget._is_askui_core_service_running(): + service_msg = ( + f"Detected running {self._ASKUI_CORE_SERVICE_NAME}; using port " + f"{self._ASKUI_CORE_SERVICE_PORT} (controller managed by service)" + ) + logger.info(service_msg) + address = LocalComputerTarget.replace_port( + address, self._ASKUI_CORE_SERVICE_PORT + ) + self._is_service = True + + @staticmethod + def replace_port(address: str, port: int) -> str: + addr = address if "://" in address else "//" + address + parsed = urlparse(addr) + host = parsed.hostname or "localhost" + return f"{host}:{port}" + + def _parse_port(self) -> int: + addr = self._address if "://" in self._address else "//" + self._address + parsed = urlparse(addr) + if parsed.port is None: + error_msg = ( + f"Could not parse port from address {self._address!r}. " + "Expected format 'host:port' (e.g. 'localhost:23000')." + ) + raise ValueError(error_msg) + return parsed.port + + def _start_process( + self, + path: pathlib.Path, + args: str | None = None, + ) -> None: + commands = [str(path)] + if args: + commands.extend(args.split()) + if not logger.isEnabledFor(logging.DEBUG): + self._process = subprocess.Popen( + commands, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL + ) + else: + self._process = subprocess.Popen(commands) + wait_for_port(self._parse_port()) + + @override + def start(self, clean_up: bool = False) -> None: + """ + Start the controller process unless this target uses a service-managed + binary. + + Args: + clean_up (bool, optional): Whether to clean up existing processes + (only on Windows) before starting. Defaults to `False`. + """ + if self._is_service: + logger.debug( + "Skipping local controller start; process is managed by service" + ) + return + if ( + sys.platform == "win32" + and clean_up + and process_exists("AskuiRemoteDeviceController.exe") + ): + self.clean_up() + logger.debug( + "Starting AskUI Remote Device Controller", + extra={"path": str(self._settings.controller_path)}, + ) + self._start_process( + self._settings.controller_path, self._settings.controller_args + ) + time.sleep(0.5) + + def clean_up(self) -> None: + subprocess.run("taskkill.exe /IM AskUI*") + time.sleep(0.1) + + @override + def stop(self, force: bool = False) -> None: + """ + Stop the controller process. + + Args: + force (bool, optional): Whether to forcefully terminate the process. + Defaults to `False`. + """ + if self._process is None: + return + + try: + if force: + self._process.kill() + if sys.platform == "win32": + self.clean_up() + else: + self._process.terminate() + except Exception: # noqa: BLE001 - We want to catch all other exceptions here + logger.exception("Error stopping local controller process") + finally: + self._process = None + + +class RemoteComputerTarget(ComputerTarget): + """ + Remote computer target: the client connects to an already-running + controller on another machine. + + No process management is performed; `start()` and `stop()` are no-ops. + + Args: + address (str): gRPC address of the remote target computer (required). + description (str): Human-readable description. + display (int, optional): Display ID selected for this target computer. + Defaults to `1`. + computer_id (str | None, optional): Stable, human-friendly identifier for + the target computer. Defaults to the target computer's + `session_guid`. + """ + + def __init__( + self, + address: str, + description: str, + display: int = 1, + computer_id: str | None = None, + ) -> None: + super().__init__( + address=address, + description=description, + display=display, + computer_id=computer_id, + ) + + +__all__ = [ + "ComputerTarget", + "LocalComputerTarget", + "RemoteComputerTarget", +] diff --git a/src/askui/tools/askui/askui_controller.py b/src/askui/tools/askui/askui_controller.py index 26aeb5d0..b466cf09 100644 --- a/src/askui/tools/askui/askui_controller.py +++ b/src/askui/tools/askui/askui_controller.py @@ -1,1415 +1,1500 @@ -import base64 -import logging -import pathlib -import subprocess -import sys -import time -import types -import uuid -from typing import Literal, Type - -import grpc -from google.protobuf.json_format import MessageToDict -from PIL import Image -from typing_extensions import Self, override - -from askui.container import telemetry -from askui.reporting import NULL_REPORTER, Reporter -from askui.tools.agent_os import ( - AgentOs, - Coordinate, - Display, - DisplaysListResponse, - ModifierKey, - PcKey, -) -from askui.tools.askui.askui_controller_client_settings import ( - AskUiControllerClientSettings, -) -from askui.tools.askui.askui_controller_settings import AskUiControllerSettings -from askui.tools.askui.askui_ui_controller_grpc.desktop_agent_os_error import ( - DesktopAgentOsError, -) -from askui.tools.askui.askui_ui_controller_grpc.generated import ( - Controller_V1_pb2 as controller_v1_pbs, -) -from askui.tools.askui.askui_ui_controller_grpc.generated import ( - Controller_V1_pb2_grpc as controller_v1, -) -from askui.tools.askui.askui_ui_controller_grpc.generated.AgentOS_Send_Request_2501 import ( # noqa: E501 - AddRenderObjectCommand, - AskUIAgentOSSendRequestSchema, - ClearRenderObjectsCommand, - Command, - DeleteRenderObjectCommand, - GetActiveProcessCommand, - GetActiveWindowCommand, - GetFileCommand, - GetFileNamesCommand, - GetMousePositionCommand, - GetSystemInfoCommand, - Guid, - Header, - Length, - Location, - Message, - Parameter3, - RemoveVirtualDisplaysCommand, - RenderImage, - RenderObjectId, - RenderObjectStyle, - RenderText, - SetActiveProcessCommand, - SetActiveWindowCommand, - SetMousePositionCommand, - UpdateRenderObjectCommand, -) -from askui.tools.askui.askui_ui_controller_grpc.generated.AgentOS_Send_Response_2501 import ( # noqa: E501 - AskUIAgentOSSendResponseSchema, - GetActiveProcessResponse, - GetActiveProcessResponseModel, - GetActiveWindowResponse, - GetActiveWindowResponseModel, - GetFileNamesResponse, - GetFileResponse, - GetSystemInfoResponse, - GetSystemInfoResponseModel, -) -from askui.utils.annotated_image import AnnotatedImage -from askui.utils.image_utils import base64_to_image - -from ..utils import process_exists, wait_for_port -from .exceptions import ( - AskUiControllerError, - AskUiControllerInvalidCommandError, - AskUiControllerOperationTimeoutError, -) - -logger = logging.getLogger(__name__) - - -class AskUiControllerServer: - """ - Concrete implementation of `ControllerServer` for managing the AskUI Remote Device - Controller process. - Handles process discovery, startup, and shutdown for the native controller binary. - - Args: - settings (AskUiControllerSettings | None, optional): Settings for the AskUI. - """ - - def __init__(self, settings: AskUiControllerSettings | None = None) -> None: - self._process: subprocess.Popen[bytes] | None = None - self._settings = settings or AskUiControllerSettings() - - def _start_process( - self, - path: pathlib.Path, - args: str | None = None, - ) -> None: - commands = [str(path)] - if args: - commands.extend(args.split()) - if not logger.isEnabledFor(logging.DEBUG): - self._process = subprocess.Popen( - commands, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL - ) - else: - self._process = subprocess.Popen(commands) - wait_for_port(23000) - - def start(self, clean_up: bool = False) -> None: - """ - Start the controller process. - - Args: - clean_up (bool, optional): Whether to clean up existing processes - (only on Windows) before starting. Defaults to `False`. - """ - if ( - sys.platform == "win32" - and clean_up - and process_exists("AskuiRemoteDeviceController.exe") - ): - self.clean_up() - logger.debug( - "Starting AskUI Remote Device Controller", - extra={"path": str(self._settings.controller_path)}, - ) - self._start_process( - self._settings.controller_path, self._settings.controller_args - ) - time.sleep(0.5) - - def clean_up(self) -> None: - subprocess.run("taskkill.exe /IM AskUI*") - time.sleep(0.1) - - def stop(self, force: bool = False) -> None: - """ - Stop the controller process. - - Args: - force (bool, optional): Whether to forcefully terminate the process. - Defaults to `False`. - """ - if self._process is None: - return # Nothing to stop - - try: - if force: - self._process.kill() - if sys.platform == "win32": - self.clean_up() - else: - self._process.terminate() - except Exception: # noqa: BLE001 - We want to catch all other exceptions here - logger.exception("Controller error") - finally: - self._process = None - - -class AskUiControllerClient(AgentOs): - """ - Implementation of `AgentOs` that communicates with the AskUI Remote Device - Controller via gRPC. - - Args: - reporter (Reporter): Reporter used for reporting with the `"AgentOs"`. - display (int, optional): Display number to use. Defaults to `1`. - controller_server (AskUiControllerServer | None, optional): Custom controller - server. Defaults to `ControllerServer`. - """ - - @telemetry.record_call(exclude={"reporter", "controller_server"}) - def __init__( - self, - reporter: Reporter = NULL_REPORTER, - display: int = 1, - controller_server: AskUiControllerServer | None = None, - settings: AskUiControllerClientSettings | None = None, - ) -> None: - self._stub: controller_v1.ControllerAPIStub | None = None - self._channel: grpc.Channel | None = None - self._session_info: controller_v1_pbs.SessionInfo | None = None - self._pre_action_wait = 0 - self._post_action_wait = 0.05 - self._max_retries = 10 - self._display = display - self._reporter = reporter - self._controller_server = controller_server or AskUiControllerServer() - self._session_guid = "{" + str(uuid.uuid4()) + "}" - self._settings = settings or AskUiControllerClientSettings() - - @telemetry.record_call() - @override - def connect(self) -> None: - """ - Establishes a connection to the AskUI Remote Device Controller. - - This method starts the controller server, establishes a gRPC channel, - creates a session, and sets up the initial display. - """ - if self._settings.server_autostart: - self._controller_server.start() - self._channel = grpc.insecure_channel( - self._settings.server_address, - options=[ - ("grpc.max_send_message_length", 2**30), - ("grpc.max_receive_message_length", 2**30), - ("grpc.default_deadline", 300000), - ], - ) - self._stub = controller_v1.ControllerAPIStub(self._channel) - self._start_session() - self._start_execution() - self.set_display(self._display) - if self._settings.clean_virtual_displays: - logger.info( - "clean_virtual_displays is enabled. Removing all virtual displays ... " - ) - self.remove_virtual_displays() - logger.info("Virtual displays removed.") - - def _get_stub(self) -> controller_v1.ControllerAPIStub: - assert isinstance(self._stub, controller_v1.ControllerAPIStub), ( - "Stub is not initialized. Call `connect()` first." - ) - return self._stub - - def _run_recorder_action( - self, - acion_class_id: controller_v1_pbs.ActionClassID, - action_parameters: controller_v1_pbs.ActionParameters, - ) -> controller_v1_pbs.Response_RunRecordedAction: - time.sleep(self._pre_action_wait) - response: controller_v1_pbs.Response_RunRecordedAction = ( - self._get_stub().RunRecordedAction( - controller_v1_pbs.Request_RunRecordedAction( - sessionInfo=self._session_info, - actionClassID=acion_class_id, - actionParameters=action_parameters, - ) - ) - ) - - time.sleep((response.requiredMilliseconds / 1000)) - num_retries = 0 - for _ in range(self._max_retries): - poll_response: controller_v1_pbs.Response_Poll = self._get_stub().Poll( - controller_v1_pbs.Request_Poll( - sessionInfo=self._session_info, - pollEventID=controller_v1_pbs.PollEventID.PollEventID_ActionFinished, - ) - ) - if ( - poll_response.pollEventParameters.actionFinished.actionID - == response.actionID - ): - break - time.sleep(self._post_action_wait) - num_retries += 1 - if num_retries == self._max_retries - 1: - raise AskUiControllerOperationTimeoutError - return response - - @telemetry.record_call() - @override - def disconnect(self) -> None: - """ - Terminates the connection to the AskUI Remote Device Controller. - - This method stops the execution, ends the session, closes the gRPC channel, - and stops the controller server. - """ - try: - self._stop_execution() - self._stop_session() - if self._channel is not None: - self._channel.close() - self._controller_server.stop() - except Exception as e: # noqa: BLE001 - # We want to catch all other exceptions here and not re-raise them - msg = ( - "Error while disconnecting from the AskUI Remote Device Controller" - f" Error: {e}" - ) - logger.exception(msg) - - @telemetry.record_call() - def __enter__(self) -> Self: - """ - Context manager entry point that establishes the connection. - - Returns: - Self: The instance of AskUiControllerClient. - """ - self.connect() - return self - - @telemetry.record_call(exclude={"exc_value", "traceback"}) - def __exit__( - self, - exc_type: Type[BaseException] | None, - exc_value: BaseException | None, - traceback: types.TracebackType | None, - ) -> None: - """ - Context manager exit point that disconnects the client. - - Args: - exc_type: The exception type if an exception was raised. - exc_value: The exception value if an exception was raised. - traceback: The traceback if an exception was raised. - """ - self.disconnect() - - def _start_session(self) -> None: - response = self._get_stub().StartSession( - controller_v1_pbs.Request_StartSession( - sessionGUID=self._session_guid, immediateExecution=True - ) - ) - self._session_info = response.sessionInfo - - def _stop_session(self) -> None: - self._get_stub().EndSession( - controller_v1_pbs.Request_EndSession(sessionInfo=self._session_info) - ) - - def _start_execution(self) -> None: - self._get_stub().StartExecution( - controller_v1_pbs.Request_StartExecution(sessionInfo=self._session_info) - ) - - def _stop_execution(self) -> None: - self._get_stub().StopExecution( - controller_v1_pbs.Request_StopExecution(sessionInfo=self._session_info) - ) - - @telemetry.record_call() - @override - def screenshot(self, report: bool = True) -> Image.Image: - """ - Take a screenshot of the current screen. - - Args: - report (bool, optional): Whether to include the screenshot in reporting. - Defaults to `True`. - - Returns: - Image.Image: A PIL Image object containing the screenshot. - - """ - screenResponse = self._get_stub().CaptureScreen( - controller_v1_pbs.Request_CaptureScreen( - sessionInfo=self._session_info, - captureParameters=controller_v1_pbs.CaptureParameters( - displayID=self._display - ), - ) - ) - r, g, b, _ = Image.frombytes( - "RGBA", - (screenResponse.bitmap.width, screenResponse.bitmap.height), - screenResponse.bitmap.data, - ).split() - image = Image.merge("RGB", (b, g, r)) - self._reporter.add_message("AgentOS", "screenshot()", image) - return image - - @telemetry.record_call() - @override - def mouse_move(self, x: int, y: int, duration: int = 500) -> None: - """ - Moves the mouse cursor to specified screen coordinates. - - Args: - x (int): The horizontal coordinate (in pixels) to move to. - y (int): The vertical coordinate (in pixels) to move to. - duration (int): The duration (in ms) the movement should take. - """ - self._reporter.add_message( - "AgentOS", - f"mouse_move({x}, {y}, duration={duration})", - AnnotatedImage(lambda: self.screenshot(report=False), point_list=[(x, y)]), - ) - self._run_recorder_action( - acion_class_id=controller_v1_pbs.ActionClassID_MouseMove, - action_parameters=controller_v1_pbs.ActionParameters( - mouseMove=controller_v1_pbs.ActionParameters_MouseMove( - position=controller_v1_pbs.Coordinate2(x=x, y=y), - milliseconds=duration, - ) - ), - ) - - @telemetry.record_call(exclude={"text"}) - @override - def type(self, text: str, typing_speed: int = 50) -> None: - """ - Type text at current cursor position as if entered on a keyboard. - - Args: - text (str): The text to type. - typing_speed (int, optional): The speed of typing in characters per second. - Defaults to `50`. - """ - self._reporter.add_message("AgentOS", f'type("{text}", {typing_speed})') - self._run_recorder_action( - acion_class_id=controller_v1_pbs.ActionClassID_KeyboardType_UnicodeText, - action_parameters=controller_v1_pbs.ActionParameters( - keyboardTypeUnicodeText=controller_v1_pbs.ActionParameters_KeyboardType_UnicodeText( - text=text.encode("utf-16-le"), - typingSpeed=typing_speed, - typingSpeedValue=controller_v1_pbs.TypingSpeedValue.TypingSpeedValue_CharactersPerSecond, - ) - ), - ) - - @telemetry.record_call() - @override - def click( - self, button: Literal["left", "middle", "right"] = "left", count: int = 1 - ) -> None: - """ - Click a mouse button. - - Args: - button (Literal["left", "middle", "right"], optional): The mouse button to - click. Defaults to `"left"`. - count (int, optional): Number of times to click. Defaults to `1`. - """ - self._reporter.add_message("AgentOS", f'click("{button}", {count})') - mouse_button = None - match button: - case "left": - mouse_button = controller_v1_pbs.MouseButton_Left - case "middle": - mouse_button = controller_v1_pbs.MouseButton_Middle - case "right": - mouse_button = controller_v1_pbs.MouseButton_Right - self._run_recorder_action( - acion_class_id=controller_v1_pbs.ActionClassID_MouseButton_PressAndRelease, - action_parameters=controller_v1_pbs.ActionParameters( - mouseButtonPressAndRelease=controller_v1_pbs.ActionParameters_MouseButton_PressAndRelease( - mouseButton=mouse_button, count=count - ) - ), - ) - - @telemetry.record_call() - @override - def mouse_down(self, button: Literal["left", "middle", "right"] = "left") -> None: - """ - Press and hold a mouse button. - - Args: - button (Literal["left", "middle", "right"], optional): The mouse button to - press. Defaults to `"left"`. - """ - self._reporter.add_message("AgentOS", f'mouse_down("{button}")') - mouse_button = None - match button: - case "left": - mouse_button = controller_v1_pbs.MouseButton_Left - case "middle": - mouse_button = controller_v1_pbs.MouseButton_Middle - case "right": - mouse_button = controller_v1_pbs.MouseButton_Right - self._run_recorder_action( - acion_class_id=controller_v1_pbs.ActionClassID_MouseButton_Press, - action_parameters=controller_v1_pbs.ActionParameters( - mouseButtonPress=controller_v1_pbs.ActionParameters_MouseButton_Press( - mouseButton=mouse_button - ) - ), - ) - - @telemetry.record_call() - @override - def mouse_up(self, button: Literal["left", "middle", "right"] = "left") -> None: - """ - Release a mouse button. - - Args: - button (Literal["left", "middle", "right"], optional): The mouse button to - release. Defaults to `"left"`. - """ - self._reporter.add_message("AgentOS", f'mouse_up("{button}")') - mouse_button = None - match button: - case "left": - mouse_button = controller_v1_pbs.MouseButton_Left - case "middle": - mouse_button = controller_v1_pbs.MouseButton_Middle - case "right": - mouse_button = controller_v1_pbs.MouseButton_Right - self._run_recorder_action( - acion_class_id=controller_v1_pbs.ActionClassID_MouseButton_Release, - action_parameters=controller_v1_pbs.ActionParameters( - mouseButtonRelease=controller_v1_pbs.ActionParameters_MouseButton_Release( - mouseButton=mouse_button - ) - ), - ) - - @telemetry.record_call() - @override - def mouse_scroll(self, dx: int, dy: int) -> None: - """ - Scroll the mouse wheel. - - Args: - dx (int): The horizontal scroll amount. Positive values scroll right, - negative values scroll left. - dy (int): The vertical scroll amount. Positive values scroll down, - negative values scroll up. - """ - self._reporter.add_message("AgentOS", f"mouse_scroll({dx}, {dy})") - if dx != 0: - self._run_recorder_action( - acion_class_id=controller_v1_pbs.ActionClassID_MouseWheelScroll, - action_parameters=controller_v1_pbs.ActionParameters( - mouseWheelScroll=controller_v1_pbs.ActionParameters_MouseWheelScroll( - direction=controller_v1_pbs.MouseWheelScrollDirection.MouseWheelScrollDirection_Horizontal, - deltaType=controller_v1_pbs.MouseWheelDeltaType.MouseWheelDelta_Raw, - delta=dx, - milliseconds=50, - ) - ), - ) - if dy != 0: - self._run_recorder_action( - acion_class_id=controller_v1_pbs.ActionClassID_MouseWheelScroll, - action_parameters=controller_v1_pbs.ActionParameters( - mouseWheelScroll=controller_v1_pbs.ActionParameters_MouseWheelScroll( - direction=controller_v1_pbs.MouseWheelScrollDirection.MouseWheelScrollDirection_Vertical, - deltaType=controller_v1_pbs.MouseWheelDeltaType.MouseWheelDelta_Raw, - delta=dy, - milliseconds=50, - ) - ), - ) - - @telemetry.record_call() - @override - def keyboard_pressed( - self, key: PcKey | ModifierKey, modifier_keys: list[ModifierKey] | None = None - ) -> None: - """ - Press and hold a keyboard key. - - Args: - key (PcKey | ModifierKey): The key to press. - modifier_keys (list[ModifierKey] | None, optional): List of modifier keys to - press along with the main key. Defaults to `None`. - """ - self._reporter.add_message( - "AgentOS", f'keyboard_pressed("{key}", {modifier_keys})' - ) - if modifier_keys is None: - modifier_keys = [] - self._run_recorder_action( - acion_class_id=controller_v1_pbs.ActionClassID_KeyboardKey_Press, - action_parameters=controller_v1_pbs.ActionParameters( - keyboardKeyPress=controller_v1_pbs.ActionParameters_KeyboardKey_Press( - keyName=key, modifierKeyNames=modifier_keys - ) - ), - ) - - @telemetry.record_call() - @override - def keyboard_release( - self, key: PcKey | ModifierKey, modifier_keys: list[ModifierKey] | None = None - ) -> None: - """ - Release a keyboard key. - - Args: - key (PcKey | ModifierKey): The key to release. - modifier_keys (list[ModifierKey] | None, optional): List of modifier keys to - release along with the main key. Defaults to `None`. - """ - self._reporter.add_message( - "AgentOS", f'keyboard_release("{key}", {modifier_keys})' - ) - if modifier_keys is None: - modifier_keys = [] - self._run_recorder_action( - acion_class_id=controller_v1_pbs.ActionClassID_KeyboardKey_Release, - action_parameters=controller_v1_pbs.ActionParameters( - keyboardKeyRelease=controller_v1_pbs.ActionParameters_KeyboardKey_Release( - keyName=key, modifierKeyNames=modifier_keys - ) - ), - ) - - @telemetry.record_call() - @override - def keyboard_tap( - self, - key: PcKey | ModifierKey, - modifier_keys: list[ModifierKey] | None = None, - count: int = 1, - ) -> None: - """ - Press and immediately release a keyboard key. - - Args: - key (PcKey | ModifierKey): The key to tap. - modifier_keys (list[ModifierKey] | None, optional): List of modifier keys to - press along with the main key. Defaults to `None`. - count (int, optional): The number of times to tap the key. Defaults to `1`. - """ - self._reporter.add_message( - "AgentOS", - f'keyboard_tap("{key}", {modifier_keys}, {count})', - ) - if modifier_keys is None: - modifier_keys = [] - for _ in range(count): - self._run_recorder_action( - acion_class_id=controller_v1_pbs.ActionClassID_KeyboardKey_PressAndRelease, - action_parameters=controller_v1_pbs.ActionParameters( - keyboardKeyPressAndRelease=controller_v1_pbs.ActionParameters_KeyboardKey_PressAndRelease( - keyName=key, modifierKeyNames=modifier_keys - ) - ), - ) - - @telemetry.record_call() - @override - def set_display(self, display: int = 1) -> None: - """ - Set the active display. - - Args: - display (int, optional): The display ID to set as active. - This can be either a real display ID or a virtual display ID. - Defaults to `1`. - """ - self._get_stub().SetActiveDisplay( - controller_v1_pbs.Request_SetActiveDisplay(displayID=display) - ) - self._display = display - self._reporter.add_message("AgentOS", f"set_display({display})") - - @telemetry.record_call(exclude={"command"}) - @override - def run_command(self, command: str, timeout_ms: int = 30000) -> None: - """ - Execute a shell command. - - Args: - command (str): The command to execute. - timeout_ms (int, optional): The timeout for command - execution in milliseconds. Defaults to `30000` (30 seconds). - """ - self._reporter.add_message("AgentOS", f'run_command("{command}", {timeout_ms})') - self._run_recorder_action( - acion_class_id=controller_v1_pbs.ActionClassID_RunCommand, - action_parameters=controller_v1_pbs.ActionParameters( - runcommand=controller_v1_pbs.ActionParameters_RunCommand( - command=command, timeoutInMilliseconds=timeout_ms - ) - ), - ) - - @telemetry.record_call() - @override - def retrieve_active_display(self) -> Display: - """ - Retrieve the currently active display/screen. - - Returns: - Display: The currently active display/screen. - """ - self._reporter.add_message("AgentOS", "retrieve_active_display()") - displays_list_response = self.list_displays() - for display in displays_list_response.data: - if display.id == self._display: - self._reporter.add_message( - "AgentOS", f"retrieve_active_display() -> {display}" - ) - return display - error_msg = f"Display {self._display} not found" - raise ValueError(error_msg) - - @telemetry.record_call() - @override - def list_displays( - self, - ) -> DisplaysListResponse: - """ - List all available Displays from the controller. - It includes both real and virtual displays - without describing the type of display (virtual or real). - - Returns: - DisplaysListResponse - """ - - self._reporter.add_message("AgentOS", "list_displays()") - - response: controller_v1_pbs.Response_GetDisplayInformation = ( - self._get_stub().GetDisplayInformation(controller_v1_pbs.Request_Void()) - ) - - response_dict = MessageToDict( - response, - preserving_proto_field_name=True, - ) - - displays = DisplaysListResponse.model_validate(response_dict) - - self._reporter.add_message("AgentOS", f"list_displays() ->{str(displays)}") - - return displays - - @telemetry.record_call() - def get_process_list( - self, get_extended_info: bool = False - ) -> controller_v1_pbs.Response_GetProcessList: - """ - Get a list of running processes. - - Args: - get_extended_info (bool, optional): Whether to include - extended process information. - Defaults to `False`. - - Returns: - controller_v1_pbs.Response_GetProcessList: Process list response containing: - - processes: List of ProcessInfo objects - """ - - self._reporter.add_message("AgentOS", f"get_process_list({get_extended_info})") - - response: controller_v1_pbs.Response_GetProcessList = ( - self._get_stub().GetProcessList( - controller_v1_pbs.Request_GetProcessList( - getExtendedInfo=get_extended_info - ) - ) - ) - self._reporter.add_message( - "AgentOS", f"get_process_list({get_extended_info}) -> {response}" - ) - - return response - - @telemetry.record_call() - def get_window_list( - self, process_id: int - ) -> controller_v1_pbs.Response_GetWindowList: - """ - Get a list of windows for a specific process. - - Args: - process_id (int): The ID of the process to get windows for. - - Returns: - controller_v1_pbs.Response_GetWindowList: Window list response containing: - - windows: List of WindowInfo objects with ID and name - """ - - self._reporter.add_message("AgentOS", f"get_window_list({process_id})") - - response: controller_v1_pbs.Response_GetWindowList = ( - self._get_stub().GetWindowList( - controller_v1_pbs.Request_GetWindowList(processID=process_id) - ) - ) - - self._reporter.add_message( - "AgentOS", f"get_window_list({process_id}) -> {response}" - ) - - return response - - @telemetry.record_call() - def get_automation_target_list( - self, - ) -> controller_v1_pbs.Response_GetAutomationTargetList: - """ - Get a list of available automation targets. - - Returns: - controller_v1_pbs.Response_GetAutomationTargetList: - Automation target list response: - - targets: List of AutomationTarget objects - """ - - self._reporter.add_message("AgentOS", "get_automation_target_list()") - - response: controller_v1_pbs.Response_GetAutomationTargetList = ( - self._get_stub().GetAutomationTargetList(controller_v1_pbs.Request_Void()) - ) - self._reporter.add_message( - "AgentOS", f"get_automation_target_list() -> {response}" - ) - - return response - - @telemetry.record_call() - def set_mouse_delay(self, delay_ms: int) -> None: - """ - Configure mouse action delay. - - Args: - delay_ms (int): The delay in milliseconds to set for mouse actions. - """ - - self._reporter.add_message("AgentOS", f"set_mouse_delay({delay_ms})") - - self._get_stub().SetMouseDelay( - controller_v1_pbs.Request_SetMouseDelay( - sessionInfo=self._session_info, delayInMilliseconds=delay_ms - ) - ) - - @telemetry.record_call() - def set_keyboard_delay(self, delay_ms: int) -> None: - """ - Configure keyboard action delay. - - Args: - delay_ms (int): The delay in milliseconds to set for keyboard actions. - """ - - self._reporter.add_message("AgentOS", f"set_keyboard_delay({delay_ms})") - - self._get_stub().SetKeyboardDelay( - controller_v1_pbs.Request_SetKeyboardDelay( - sessionInfo=self._session_info, delayInMilliseconds=delay_ms - ) - ) - - @telemetry.record_call() - def set_active_window(self, process_id: int, window_id: int) -> int: - """ - Set the active window for automation. - Adds the window as a virtual display and returns the display ID. - It raises an error if display length is not increased after adding the window. - - Args: - process_id (int): The ID of the process that owns the window. - window_id (int): The ID of the window to set as active. - - returns: - int: The new Display ID. - Raises: - AskUiControllerError: - If display length is not increased after adding the window. - """ - - self._reporter.add_message( - "AgentOS", f"set_active_window({process_id}, {window_id})" - ) - - display_length_before_adding_window = len(self.list_displays().data) - - self._get_stub().SetActiveWindow( - controller_v1_pbs.Request_SetActiveWindow( - processID=process_id, windowID=window_id - ) - ) - new_display_length = len(self.list_displays().data) - if new_display_length <= display_length_before_adding_window: - msg = f"Failed to set active window {window_id} for process {process_id}" - raise AskUiControllerError(msg) - self._reporter.add_message( - "AgentOS", - f"set_active_window({process_id}, {window_id}) -> {new_display_length}", - ) - return new_display_length - - @telemetry.record_call() - def set_active_automation_target(self, target_id: int) -> None: - """ - Set the active automation target. - - Args: - target_id (int): The ID of the automation target to set as active. - """ - - self._reporter.add_message( - "AgentOS", f"set_active_automation_target({target_id})" - ) - - self._get_stub().SetActiveAutomationTarget( - controller_v1_pbs.Request_SetActiveAutomationTarget(ID=target_id) - ) - - @telemetry.record_call() - def schedule_batched_action( - self, - action_class_id: controller_v1_pbs.ActionClassID, - action_parameters: controller_v1_pbs.ActionParameters, - ) -> controller_v1_pbs.Response_ScheduleBatchedAction: - """ - Schedule an action for batch execution. - - Args: - action_class_id (controller_v1_pbs.ActionClassID): The class ID - of the action to schedule. - action_parameters (controller_v1_pbs.ActionParameters): - Parameters for the action. - - Returns: - controller_v1_pbs.Response_ScheduleBatchedAction: Response containing - the scheduled action ID. - """ - - self._reporter.add_message( - "AgentOS", - f"schedule_batched_action({action_class_id}, {action_parameters})", - ) - - response: controller_v1_pbs.Response_ScheduleBatchedAction = ( - self._get_stub().ScheduleBatchedAction( - controller_v1_pbs.Request_ScheduleBatchedAction( - sessionInfo=self._session_info, - actionClassID=action_class_id, - actionParameters=action_parameters, - ) - ) - ) - - return response - - @telemetry.record_call() - def start_batch_run(self) -> None: - """ - Start executing batched actions. - """ - - self._reporter.add_message("AgentOS", "start_batch_run()") - - self._get_stub().StartBatchRun( - controller_v1_pbs.Request_StartBatchRun(sessionInfo=self._session_info) - ) - - @telemetry.record_call() - def stop_batch_run(self) -> None: - """ - Stop executing batched actions. - """ - - self._reporter.add_message("AgentOS", "stop_batch_run()") - - self._get_stub().StopBatchRun( - controller_v1_pbs.Request_StopBatchRun(sessionInfo=self._session_info) - ) - - @telemetry.record_call() - def get_action_count(self) -> controller_v1_pbs.Response_GetActionCount: - """ - Get the count of recorded or batched actions. - - Returns: - controller_v1_pbs.Response_GetActionCount: Response - containing the action count. - """ - - response: controller_v1_pbs.Response_GetActionCount = ( - self._get_stub().GetActionCount( - controller_v1_pbs.Request_GetActionCount(sessionInfo=self._session_info) - ) - ) - self._reporter.add_message("AgentOS", f"get_action_count() -> {response}") - return response - - @telemetry.record_call() - def get_action(self, action_index: int) -> controller_v1_pbs.Response_GetAction: - """ - Get a specific action by its index. - - Args: - action_index (int): The index of the action to retrieve. - - Returns: - controller_v1_pbs.Response_GetAction: Action information containing: - - actionID: The action ID - - actionClassID: The action class ID - - actionParameters: The action parameters - """ - - self._reporter.add_message("AgentOS", f"get_action({action_index})") - - response: controller_v1_pbs.Response_GetAction = self._get_stub().GetAction( - controller_v1_pbs.Request_GetAction( - sessionInfo=self._session_info, actionIndex=action_index - ) - ) - - return response - - @telemetry.record_call() - def remove_action(self, action_id: int) -> None: - """ - Remove a specific action by its ID. - - Args: - action_id (int): The ID of the action to remove. - """ - - self._reporter.add_message("AgentOS", f"remove_action({action_id})") - - self._get_stub().RemoveAction( - controller_v1_pbs.Request_RemoveAction( - sessionInfo=self._session_info, actionID=action_id - ) - ) - - @telemetry.record_call() - def remove_all_actions(self) -> None: - """ - Clear all recorded or batched actions. - """ - - self._reporter.add_message("AgentOS", "remove_all_actions()") - - self._get_stub().RemoveAllActions( - controller_v1_pbs.Request_RemoveAllActions(sessionInfo=self._session_info) - ) - - def _send_command(self, command: Command) -> AskUIAgentOSSendResponseSchema: - """ - Send a general command to the controller. - - Args: - command (Command): The command to send to the controller. - - Returns: - AskUIAgentOSSendResponseSchema: Response containing - the message from the controller. - - Raises: - AskUiControllerInvalidCommandError: If the command fails schema validation - on the server side. - """ - - header = Header(authentication=Guid(root=self._session_guid)) - message = Message(header=header, command=command) - - request = AskUIAgentOSSendRequestSchema(message=message) - - request_str = request.model_dump_json(exclude_none=True, by_alias=True) - - try: - response: controller_v1_pbs.Response_Send = self._get_stub().Send( - controller_v1_pbs.Request_Send(message=request_str) - ) - except grpc.RpcError as e: - if e.code() == grpc.StatusCode.INVALID_ARGUMENT: - details = e.details() or None - raise AskUiControllerInvalidCommandError(details) from e - raise - - return AskUIAgentOSSendResponseSchema.model_validate_json(response.message) - - @telemetry.record_call() - def get_mouse_position(self) -> Coordinate: - """ - Get the mouse cursor position - - Returns: - Coordinate: Response containing the result of the mouse position change. - """ - self._reporter.add_message("AgentOS", "get_mouse_position()") - res = self._send_command(GetMousePositionCommand()) - coordinate = Coordinate( - x=res.message.command.response.position.x.root, # type: ignore[union-attr] - y=res.message.command.response.position.y.root, # type: ignore[union-attr] - ) - self._reporter.add_message("AgentOS", f"get_mouse_position() -> {coordinate}") - return coordinate - - @telemetry.record_call() - def set_mouse_position(self, x: int, y: int) -> None: - """ - Set the mouse cursor position to specific coordinates. - - Args: - x (int): The horizontal coordinate (in pixels) to set the cursor to. - y (int): The vertical coordinate (in pixels) to set the cursor to. - """ - location = Location(x=Length(root=x), y=Length(root=y)) - command = SetMousePositionCommand(parameters=[location]) - self._reporter.add_message("AgentOS", f"set_mouse_position({x},{y})") - self._send_command(command) - - @telemetry.record_call() - def render_quad(self, style: RenderObjectStyle) -> int: - """ - Render a quad object to the display. - - Args: - style (RenderObjectStyle): The style properties for the quad. - - Returns: - int: Object ID. - """ - self._reporter.add_message("AgentOS", f"render_quad({style})") - command = AddRenderObjectCommand(parameters=["Quad", style]) - res = self._send_command(command) - return int(res.message.command.response.id.root) # type: ignore[union-attr] - - @telemetry.record_call() - def render_line(self, style: RenderObjectStyle, points: list[Coordinate]) -> int: - """ - Render a line object to the display. - - Args: - style (RenderObjectStyle): The style properties for the line. - points (list[Coordinates]): The points defining the line. - - Returns: - int: Object ID. - """ - self._reporter.add_message("AgentOS", f"render_line({style}, {points})") - command = AddRenderObjectCommand(parameters=["Line", style, points]) - res = self._send_command(command) - return int(res.message.command.response.id.root) # type: ignore[union-attr] - - @telemetry.record_call(exclude={"image_data"}) - def render_image(self, style: RenderObjectStyle, image_data: str) -> int: - """ - Render an image object to the display. - - Args: - style (RenderObjectStyle): The style properties for the image. - image_data (str): The base64-encoded image data. - - Returns: - int: Object ID. - """ - self._reporter.add_message("AgentOS", f"render_image({style}, [image_data])") - image = RenderImage(root=image_data) - command = AddRenderObjectCommand(parameters=["Image", style, image]) - res = self._send_command(command) - - return int(res.message.command.response.id.root) # type: ignore[union-attr] - - @telemetry.record_call() - def render_text(self, style: RenderObjectStyle, content: str) -> int: - """ - Render a text object to the display. - - Args: - style (RenderObjectStyle): The style properties for the text. - content (str): The text content to display. - - Returns: - int: Object ID. - """ - self._reporter.add_message("AgentOS", f"render_text({style}, {content})") - text = RenderText(root=content) - command = AddRenderObjectCommand(parameters=["Text", style, text]) - res = self._send_command(command) - return int(res.message.command.response.id.root) # type: ignore[union-attr] - - @telemetry.record_call() - def update_render_object(self, object_id: int, style: RenderObjectStyle) -> None: - """ - Update styling properties of an existing render object. - - Args: - object_id (float): The ID of the render object to update. - style (RenderObjectStyle): The new style properties. - - Returns: - int: Object ID. - """ - self._reporter.add_message( - "AgentOS", f"update_render_object({object_id}, {style})" - ) - render_object_id = RenderObjectId(root=object_id) - command = UpdateRenderObjectCommand(parameters=[render_object_id, style]) - self._send_command(command) - - @telemetry.record_call() - def delete_render_object(self, object_id: int) -> None: - """ - Delete an existing render object from the display. - - Args: - object_id (RenderObjectId): The ID of the render object to delete. - """ - self._reporter.add_message("AgentOS", f"delete_render_object({object_id})") - render_object_id = RenderObjectId(root=object_id) - command = DeleteRenderObjectCommand(parameters=[render_object_id]) - self._send_command(command) - - @telemetry.record_call() - def clear_render_objects(self) -> None: - """ - Clear all render objects from the display. - """ - self._reporter.add_message("AgentOS", "clear_render_objects()") - command = ClearRenderObjectsCommand() - self._send_command(command) - - def get_system_info(self) -> GetSystemInfoResponseModel: - """ - Get the system information. - - Returns: - SystemInfo: The system information. - """ - assert isinstance(self._stub, controller_v1.ControllerAPIStub), ( - "Stub is not initialized" - ) - self._reporter.add_message("AgentOS", "get_system_info()") - command = GetSystemInfoCommand() - res = self._send_command(command).message.command - if not isinstance(res, GetSystemInfoResponse): - message = f"unexpected response type: {res}" - raise DesktopAgentOsError(message) - self._reporter.add_message("AgentOS", f"get_system_info() -> {res.response}") - return res.response - - def get_active_process(self) -> GetActiveProcessResponseModel: - """ - Get the active process. - - Returns: - GetActiveProcessResponseModel: The active process. - """ - assert isinstance(self._stub, controller_v1.ControllerAPIStub), ( - "Stub is not initialized" - ) - self._reporter.add_message("AgentOS", "get_active_process()") - command = GetActiveProcessCommand() - res = self._send_command(command).message.command - if not isinstance(res, GetActiveProcessResponse): - message = f"unexpected response type: {res}" - raise DesktopAgentOsError(message) - self._reporter.add_message("AgentOS", f"get_active_process() -> {res.response}") - return res.response - - def set_active_process(self, process_id: int) -> None: - """ - Set the active process. - - Args: - process_id (int): The ID of the process to set as active. - """ - assert isinstance(self._stub, controller_v1.ControllerAPIStub), ( - "Stub is not initialized" - ) - self._reporter.add_message("AgentOS", f"set_active_process({process_id})") - _process_id = Parameter3(root=process_id) - command = SetActiveProcessCommand(parameters=[_process_id]) - self._send_command(command) - - def get_active_window(self) -> GetActiveWindowResponseModel: - """ - Gets the window id and name in addition to the process id - and name of the currently active window (in focus). - - - Returns: - GetActiveWindowResponseModel: The active window. - """ - assert isinstance(self._stub, controller_v1.ControllerAPIStub), ( - "Stub is not initialized" - ) - self._reporter.add_message("AgentOS", "get_active_window()") - command = GetActiveWindowCommand() - res = self._send_command(command).message.command - if not isinstance(res, GetActiveWindowResponse): - message = f"unexpected response type: {res}" - raise DesktopAgentOsError(message) - self._reporter.add_message("AgentOS", f"get_active_window() -> {res.response}") - return res.response - - def set_window_in_focus(self, process_id: int, window_id: int) -> None: - """ - Sets the window with the specified windowId of the process - with the specified processId active, - which brings it to the front and gives it focus. - - Args: - process_id (int): The ID of the process that owns the window. - window_id (int): The ID of the window to set as active. - """ - assert isinstance(self._stub, controller_v1.ControllerAPIStub), ( - "Stub is not initialized" - ) - self._reporter.add_message( - "AgentOS", f"set_window_in_focus({process_id}, {window_id})" - ) - _process_id = Parameter3(root=process_id) - _window_id = Parameter3(root=window_id) - command = SetActiveWindowCommand(parameters=[_process_id, _window_id]) - self._send_command(command) - - def get_file_names(self, absolute_directory_path: str) -> list[str]: - """ - Get the file names in the given absolute directory on the device under - automation. - - Args: - absolute_directory_path (str): The absolute directory path to list - file names from. - - Returns: - list[str]: The file names returned by the controller. - """ - assert isinstance(self._stub, controller_v1.ControllerAPIStub), ( - "Stub is not initialized" - ) - self._reporter.add_message( - "AgentOS", f"get_file_names({absolute_directory_path})" - ) - command = GetFileNamesCommand(parameters=[absolute_directory_path]) - res = self._send_command(command).message.command - if not isinstance(res, GetFileNamesResponse): - message = f"unexpected response type: {res}" - raise DesktopAgentOsError(message) - if res.error is not None: - raise DesktopAgentOsError(res.error) - if res.response is None: - message = f"{type(res).__name__} is missing both error and response" - raise DesktopAgentOsError(message) - self._reporter.add_message( - "AgentOS", f"get_file_names({absolute_directory_path}) -> {res.response}" - ) - return res.response.fileNames - - def get_file(self, path: str) -> Image.Image | str: - """ - Get the contents of a file at the given path on the device under - automation. - - The controller returns the file as a Base64-encoded string, which is - decoded and returned as `PIL.Image.Image` when the bytes can be opened - as an image (PNG, JPEG, BMP, GIF, WebP, TIFF, ...), or as `str` when - they decode cleanly as UTF-8 text. - - Args: - path (str): The file path to read on the device under automation. - - Returns: - Image.Image | str: The decoded file contents. - - Raises: - DesktopAgentOsError: If the file cannot be read or the response is invalid. - """ - assert isinstance(self._stub, controller_v1.ControllerAPIStub), ( - "Stub is not initialized" - ) - self._reporter.add_message("AgentOS", f"get_file({path})") - command = GetFileCommand(parameters=[path]) - res = self._send_command(command).message.command - if not isinstance(res, GetFileResponse): - message = f"unexpected response type: {res}" - raise DesktopAgentOsError(message) - if res.error is not None: - raise DesktopAgentOsError(res.error) - if res.response is None: - message = f"{type(res).__name__} is missing both error and response" - raise DesktopAgentOsError(message) - decoded = self._decode_file_payload(res.response.file.content) - if isinstance(decoded, Image.Image): - detail = f"image ({decoded.format}, {decoded.size[0]}x{decoded.size[1]})" - self._reporter.add_message( - "AgentOS", f"get_file({path}) -> {detail}", decoded - ) - return decoded - - detail = f"text ({len(decoded)} chars)" - self._reporter.add_message("AgentOS", f"get_file({path}) -> {detail}") - return decoded - - def remove_virtual_displays(self) -> None: - """ - Remove all virtual displays from the controller, leaving only real - displays active. - """ - assert isinstance(self._stub, controller_v1.ControllerAPIStub), ( - "Stub is not initialized" - ) - self._reporter.add_message("AgentOS", "remove_virtual_displays()") - command = RemoveVirtualDisplaysCommand() - self._send_command(command) - self._reporter.add_message("AgentOS", "remove_virtual_displays() -> done") - - @staticmethod - def _decode_file_payload(base64_data: str) -> Image.Image | str: - try: - return base64_to_image(base64_data) - except ValueError: - pass - data = base64.b64decode(base64_data, validate=True) - if b"\x00" not in data: - try: - return data.decode("utf-8") - except UnicodeDecodeError: - pass - message = "File contents are neither a supported image nor UTF-8 text" - raise DesktopAgentOsError(message) +import base64 +import time +import types +from collections.abc import Iterator +from contextlib import contextmanager +from typing import Literal, Type + +import grpc +from google.protobuf.json_format import MessageToDict +from PIL import Image +from typing_extensions import Self, override + +from askui.container import telemetry +from askui.reporting import NULL_REPORTER, Reporter +from askui.tools.agent_os import ( + ComputerAgentOS, + Coordinate, + Display, + DisplaysListResponse, + ModifierKey, + PcKey, +) +from askui.tools.askui.agent_os_target_computer import ( + ComputerTarget, + LocalComputerTarget, +) +from askui.tools.askui.askui_ui_controller_grpc.desktop_agent_os_error import ( + DesktopAgentOsError, +) +from askui.tools.askui.askui_ui_controller_grpc.generated import ( + Controller_V1_pb2 as controller_v1_pbs, +) +from askui.tools.askui.askui_ui_controller_grpc.generated import ( + Controller_V1_pb2_grpc as controller_v1, +) +from askui.tools.askui.askui_ui_controller_grpc.generated.AgentOS_Send_Request_2501 import ( # noqa: E501 + AddRenderObjectCommand, + AskUIAgentOSSendRequestSchema, + ClearRenderObjectsCommand, + Command, + DeleteRenderObjectCommand, + GetActiveProcessCommand, + GetActiveWindowCommand, + GetFileCommand, + GetFileNamesCommand, + GetMousePositionCommand, + GetSystemInfoCommand, + Guid, + Header, + Length, + Location, + Message, + Parameter3, + RemoveVirtualDisplaysCommand, + RenderImage, + RenderObjectId, + RenderObjectStyle, + RenderText, + SetActiveProcessCommand, + SetActiveWindowCommand, + SetMousePositionCommand, + UpdateRenderObjectCommand, +) +from askui.tools.askui.askui_ui_controller_grpc.generated.AgentOS_Send_Response_2501 import ( # noqa: E501 + AskUIAgentOSSendResponseSchema, + GetActiveProcessResponse, + GetActiveProcessResponseModel, + GetActiveWindowResponse, + GetActiveWindowResponseModel, + GetFileNamesResponse, + GetFileResponse, + GetSystemInfoResponse, + GetSystemInfoResponseModel, +) +from askui.tools.askui.computer_target_pool import ( + ComputerTargetPool, +) +from askui.utils.annotated_image import AnnotatedImage +from askui.utils.image_utils import base64_to_image + +from .exceptions import ( + AskUiControllerError, + AskUiControllerInvalidCommandError, + AskUiControllerOperationTimeoutError, +) + + +class MultiComputerTargetAgentOS(ComputerAgentOS): + """ + Implementation of `ComputerAgentOS` that communicates with one or more + computer targets (AskUI Remote Device Controller processes) via gRPC. + + A client is configured with a non-empty list of `agent_os_target_computers` + (at most one local, the rest remote with unique addresses). `connect()` opens + a gRPC channel and session for *every* registered target. Exactly one target + is *active* at a time; agent-os actions are routed to its connection. + `disconnect()` closes every open connection and stops only those local + processes that were started by this client (i.e. `is_local` and not + `is_service` at connect time). + + Use `add_agent_os_target_computer` to register additional targets (which + auto-connect if the client is currently connected), + `switch_agent_os_target_computer` to change the active one, + `describe_agent_os_target_computers` to inspect the registered targets, and + `reset_agent_os_target_computers` to clear or replace the list. + + Args: + reporter (Reporter): Reporter used for reporting with the `"AgentOS"`. + display (int, optional): Display number to use. Defaults to `1`. + agent_os_target_computers (list[ComputerTarget] | None, optional): + Computer targets to register. Must be non-empty if provided, contain + at most one local target, and have unique addresses across remote + targets. If `None` (default), a single `LocalComputerTarget` + with default settings is registered. + """ + + _REPORTER_SOURCE = "AgentOS" + + @telemetry.record_call(exclude={"reporter", "agent_os_target_computers"}) + def __init__( + self, + reporter: Reporter = NULL_REPORTER, + display: int = 1, + agent_os_target_computers: list[ComputerTarget] | None = None, + ) -> None: + if not agent_os_target_computers: + agent_os_target_computers = [LocalComputerTarget(display=display)] + + self._pre_action_wait = 0 + self._post_action_wait = 0.05 + self._max_retries = 10 + self._reporter = reporter + self._manager = ComputerTargetPool( + agent_os_target_computers=agent_os_target_computers + ) + + @property + def agent_os_target_computer_manager(self) -> ComputerTargetPool: + """The underlying target-computer manager.""" + return self._manager + + @property + def is_connected(self) -> bool: + """`True` when at least one target-computer connection is open.""" + return self._manager.is_connected + + def _require_active_agent_os_target_computer(self) -> ComputerTarget: + return self._manager.require_active() + + @property + def _session_info(self) -> controller_v1_pbs.SessionInfo: + return self._manager.active_connection().session_info + + @telemetry.record_call(exclude={"agent_os_target_computer"}) + @override + def add_agent_os_target_computer( + self, agent_os_target_computer: ComputerTarget + ) -> ComputerTarget: + """ + Register an already-constructed target computer. Auto-connects if the + client is currently connected. + """ + self._reporter.add_message( + self._REPORTER_SOURCE, + f"add_agent_os_target_computer({agent_os_target_computer!r})", + ) + self._manager.add(agent_os_target_computer) + return agent_os_target_computer + + @telemetry.record_call(exclude={"agent_os_target_computers"}) + @override + def reset_agent_os_target_computers( + self, + agent_os_target_computers: list[ComputerTarget] | None = None, + ) -> None: + """ + Disconnect (if connected) and replace the target computer list. + + Args: + agent_os_target_computers (list[ComputerTarget] | None, optional): + New list of target computers to register after the reset. If + `None`, the list is left empty and a subsequent `connect()` will + fail until at least one target has been registered again. Same + validation rules as the constructor (at most one local, unique + remote addresses). + """ + self._reporter.add_message( + self._REPORTER_SOURCE, + f"reset_agent_os_target_computers({agent_os_target_computers!r})", + ) + was_connected = self.is_connected + if was_connected: + self.disconnect() + self._manager.reset() + if agent_os_target_computers is not None: + for agent_os_target_computer in agent_os_target_computers: + self._manager.add(agent_os_target_computer) + if was_connected: + self.connect() + + @telemetry.record_call() + @override + def describe_agent_os_target_computers(self) -> list[str]: + """Return the `repr()` string of every registered target computer.""" + self._reporter.add_message( + self._REPORTER_SOURCE, "describe_agent_os_target_computers()" + ) + agent_os_target_computer_reprs = self._manager.describe() + self._reporter.add_message( + self._REPORTER_SOURCE, + "describe_agent_os_target_computers() -> " + f"{agent_os_target_computer_reprs!r}", + ) + return agent_os_target_computer_reprs + + @telemetry.record_call() + @override + def get_current_computer_target_id(self, report: bool = True) -> str: + """Return the `computer_id` of the currently active Agent OS target computer.""" + if report: + self._reporter.add_message( + self._REPORTER_SOURCE, "get_current_computer_target_id()" + ) + computer_id = self._require_active_agent_os_target_computer().computer_id + if report: + self._reporter.add_message( + self._REPORTER_SOURCE, + f"get_current_computer_target_id() -> {computer_id!r}", + ) + return computer_id + + @telemetry.record_call() + @override + def switch_agent_os_target_computer(self, computer_id: str) -> ComputerTarget: + """ + Switch the active target computer by its `computer_id` (the user-supplied + identifier; defaults to the target's `session_guid` when none was supplied + at construction time). + + Connections to all registered targets stay open across switches; this just + changes which connection routes future agent-os actions. If the target was + added after `connect()` and isn't connected yet, it is connected on switch. + + Args: + computer_id (str): The computer id of the target to switch to. + + Returns: + ComputerTarget: The newly active target computer. + """ + self._reporter.add_message( + self._REPORTER_SOURCE, f"switch_agent_os_target_computer({computer_id!r})" + ) + agent_os_target_computer = self._manager.switch(computer_id) + self._reporter.add_message( + self._REPORTER_SOURCE, + ( + f"switch_agent_os_target_computer({computer_id!r}) -> " + f"{agent_os_target_computer!r}" + ), + ) + return agent_os_target_computer + + @contextmanager + @override + def temporary_select(self, computer_id: str) -> Iterator[Self]: + previous = self._manager.active + self._reporter.add_message( + self._REPORTER_SOURCE, + f"temporary_select({computer_id!r}) [previous={previous!r}]", + ) + self.switch_agent_os_target_computer(computer_id) + try: + yield self + finally: + if previous is not None and previous.computer_id != computer_id: + self.switch_agent_os_target_computer(previous.computer_id) + self._reporter.add_message( + self._REPORTER_SOURCE, + f"temporary_select({computer_id!r}) -> restored", + ) + + @telemetry.record_call() + @override + def connect(self) -> None: + """ + Open a gRPC channel and session to every registered target computer via + the underlying `ComputerTargetPool`. + """ + self._manager.connect() + + def _get_stub(self) -> controller_v1.ControllerAPIStub: + return self._manager.active_connection().stub + + def _run_recorder_action( + self, + acion_class_id: controller_v1_pbs.ActionClassID, + action_parameters: controller_v1_pbs.ActionParameters, + ) -> controller_v1_pbs.Response_RunRecordedAction: + time.sleep(self._pre_action_wait) + response: controller_v1_pbs.Response_RunRecordedAction = ( + self._get_stub().RunRecordedAction( + controller_v1_pbs.Request_RunRecordedAction( + sessionInfo=self._session_info, + actionClassID=acion_class_id, + actionParameters=action_parameters, + ) + ) + ) + + time.sleep((response.requiredMilliseconds / 1000)) + num_retries = 0 + for _ in range(self._max_retries): + poll_response: controller_v1_pbs.Response_Poll = self._get_stub().Poll( + controller_v1_pbs.Request_Poll( + sessionInfo=self._session_info, + pollEventID=controller_v1_pbs.PollEventID.PollEventID_ActionFinished, + ) + ) + if ( + poll_response.pollEventParameters.actionFinished.actionID + == response.actionID + ): + break + time.sleep(self._post_action_wait) + num_retries += 1 + if num_retries == self._max_retries - 1: + agent_os_target_computer = self._require_active_agent_os_target_computer() + timeout_seconds = self._max_retries * self._post_action_wait + timeout_msg = ( + f"Action did not finish on target computer " + f"{agent_os_target_computer.description!r} " + f"(session_guid={agent_os_target_computer.session_guid}) within " + f"{timeout_seconds:.2f}s ({self._max_retries} polls of " + f"{self._post_action_wait:.2f}s). " + f"Action class id: {acion_class_id}." + ) + raise AskUiControllerOperationTimeoutError( + message=timeout_msg, timeout_seconds=timeout_seconds + ) + return response + + @telemetry.record_call() + @override + def disconnect(self) -> None: + """ + Close every open target-computer connection via the underlying + `ComputerTargetPool`. + """ + self._manager.disconnect() + + @telemetry.record_call() + def __enter__(self) -> Self: + """ + Context manager entry point that establishes the connection. + + Returns: + Self: The instance of MultiComputerTargetAgentOS. + """ + self.connect() + return self + + @telemetry.record_call(exclude={"exc_value", "traceback"}) + def __exit__( + self, + exc_type: Type[BaseException] | None, + exc_value: BaseException | None, + traceback: types.TracebackType | None, + ) -> None: + """ + Context manager exit point that disconnects the client. + + Args: + exc_type: The exception type if an exception was raised. + exc_value: The exception value if an exception was raised. + traceback: The traceback if an exception was raised. + """ + self.disconnect() + + @telemetry.record_call() + @override + def screenshot(self, report: bool = True) -> Image.Image: + """ + Take a screenshot of the current screen. + + Args: + report (bool, optional): Whether to include the screenshot in reporting. + Defaults to `True`. + + Returns: + Image.Image: A PIL Image object containing the screenshot. + + """ + screenResponse = self._get_stub().CaptureScreen( + controller_v1_pbs.Request_CaptureScreen( + sessionInfo=self._session_info, + captureParameters=controller_v1_pbs.CaptureParameters( + displayID=self._require_active_agent_os_target_computer().display + ), + ) + ) + r, g, b, _ = Image.frombytes( + "RGBA", + (screenResponse.bitmap.width, screenResponse.bitmap.height), + screenResponse.bitmap.data, + ).split() + image = Image.merge("RGB", (b, g, r)) + self._reporter.add_message(self._REPORTER_SOURCE, "screenshot()", image) + return image + + @telemetry.record_call() + @override + def mouse_move(self, x: int, y: int, duration: int = 500) -> None: + """ + Moves the mouse cursor to specified screen coordinates. + + Args: + x (int): The horizontal coordinate (in pixels) to move to. + y (int): The vertical coordinate (in pixels) to move to. + duration (int): The duration (in ms) the movement should take. + """ + self._reporter.add_message( + self._REPORTER_SOURCE, + f"mouse_move({x}, {y}, duration={duration})", + AnnotatedImage(lambda: self.screenshot(report=False), point_list=[(x, y)]), + ) + self._run_recorder_action( + acion_class_id=controller_v1_pbs.ActionClassID_MouseMove, + action_parameters=controller_v1_pbs.ActionParameters( + mouseMove=controller_v1_pbs.ActionParameters_MouseMove( + position=controller_v1_pbs.Coordinate2(x=x, y=y), + milliseconds=duration, + ) + ), + ) + + @telemetry.record_call(exclude={"text"}) + @override + def type(self, text: str, typing_speed: int = 50) -> None: + """ + Type text at current cursor position as if entered on a keyboard. + + Args: + text (str): The text to type. + typing_speed (int, optional): The speed of typing in characters per second. + Defaults to `50`. + """ + self._reporter.add_message( + self._REPORTER_SOURCE, f'type("{text}", {typing_speed})' + ) + self._run_recorder_action( + acion_class_id=controller_v1_pbs.ActionClassID_KeyboardType_UnicodeText, + action_parameters=controller_v1_pbs.ActionParameters( + keyboardTypeUnicodeText=controller_v1_pbs.ActionParameters_KeyboardType_UnicodeText( + text=text.encode("utf-16-le"), + typingSpeed=typing_speed, + typingSpeedValue=controller_v1_pbs.TypingSpeedValue.TypingSpeedValue_CharactersPerSecond, + ) + ), + ) + + @telemetry.record_call() + @override + def click( + self, button: Literal["left", "middle", "right"] = "left", count: int = 1 + ) -> None: + """ + Click a mouse button. + + Args: + button (Literal["left", "middle", "right"], optional): The mouse button to + click. Defaults to `"left"`. + count (int, optional): Number of times to click. Defaults to `1`. + """ + self._reporter.add_message(self._REPORTER_SOURCE, f'click("{button}", {count})') + mouse_button = None + match button: + case "left": + mouse_button = controller_v1_pbs.MouseButton_Left + case "middle": + mouse_button = controller_v1_pbs.MouseButton_Middle + case "right": + mouse_button = controller_v1_pbs.MouseButton_Right + self._run_recorder_action( + acion_class_id=controller_v1_pbs.ActionClassID_MouseButton_PressAndRelease, + action_parameters=controller_v1_pbs.ActionParameters( + mouseButtonPressAndRelease=controller_v1_pbs.ActionParameters_MouseButton_PressAndRelease( + mouseButton=mouse_button, count=count + ) + ), + ) + + @telemetry.record_call() + @override + def mouse_down(self, button: Literal["left", "middle", "right"] = "left") -> None: + """ + Press and hold a mouse button. + + Args: + button (Literal["left", "middle", "right"], optional): The mouse button to + press. Defaults to `"left"`. + """ + self._reporter.add_message(self._REPORTER_SOURCE, f'mouse_down("{button}")') + mouse_button = None + match button: + case "left": + mouse_button = controller_v1_pbs.MouseButton_Left + case "middle": + mouse_button = controller_v1_pbs.MouseButton_Middle + case "right": + mouse_button = controller_v1_pbs.MouseButton_Right + self._run_recorder_action( + acion_class_id=controller_v1_pbs.ActionClassID_MouseButton_Press, + action_parameters=controller_v1_pbs.ActionParameters( + mouseButtonPress=controller_v1_pbs.ActionParameters_MouseButton_Press( + mouseButton=mouse_button + ) + ), + ) + + @telemetry.record_call() + @override + def mouse_up(self, button: Literal["left", "middle", "right"] = "left") -> None: + """ + Release a mouse button. + + Args: + button (Literal["left", "middle", "right"], optional): The mouse button to + release. Defaults to `"left"`. + """ + self._reporter.add_message(self._REPORTER_SOURCE, f'mouse_up("{button}")') + mouse_button = None + match button: + case "left": + mouse_button = controller_v1_pbs.MouseButton_Left + case "middle": + mouse_button = controller_v1_pbs.MouseButton_Middle + case "right": + mouse_button = controller_v1_pbs.MouseButton_Right + self._run_recorder_action( + acion_class_id=controller_v1_pbs.ActionClassID_MouseButton_Release, + action_parameters=controller_v1_pbs.ActionParameters( + mouseButtonRelease=controller_v1_pbs.ActionParameters_MouseButton_Release( + mouseButton=mouse_button + ) + ), + ) + + @telemetry.record_call() + @override + def mouse_scroll(self, dx: int, dy: int) -> None: + """ + Scroll the mouse wheel. + + Args: + dx (int): The horizontal scroll amount. Positive values scroll right, + negative values scroll left. + dy (int): The vertical scroll amount. Positive values scroll down, + negative values scroll up. + """ + self._reporter.add_message(self._REPORTER_SOURCE, f"mouse_scroll({dx}, {dy})") + if dx != 0: + self._run_recorder_action( + acion_class_id=controller_v1_pbs.ActionClassID_MouseWheelScroll, + action_parameters=controller_v1_pbs.ActionParameters( + mouseWheelScroll=controller_v1_pbs.ActionParameters_MouseWheelScroll( + direction=controller_v1_pbs.MouseWheelScrollDirection.MouseWheelScrollDirection_Horizontal, + deltaType=controller_v1_pbs.MouseWheelDeltaType.MouseWheelDelta_Raw, + delta=dx, + milliseconds=50, + ) + ), + ) + if dy != 0: + self._run_recorder_action( + acion_class_id=controller_v1_pbs.ActionClassID_MouseWheelScroll, + action_parameters=controller_v1_pbs.ActionParameters( + mouseWheelScroll=controller_v1_pbs.ActionParameters_MouseWheelScroll( + direction=controller_v1_pbs.MouseWheelScrollDirection.MouseWheelScrollDirection_Vertical, + deltaType=controller_v1_pbs.MouseWheelDeltaType.MouseWheelDelta_Raw, + delta=dy, + milliseconds=50, + ) + ), + ) + + @telemetry.record_call() + @override + def keyboard_pressed( + self, key: PcKey | ModifierKey, modifier_keys: list[ModifierKey] | None = None + ) -> None: + """ + Press and hold a keyboard key. + + Args: + key (PcKey | ModifierKey): The key to press. + modifier_keys (list[ModifierKey] | None, optional): List of modifier keys to + press along with the main key. Defaults to `None`. + """ + self._reporter.add_message( + self._REPORTER_SOURCE, f'keyboard_pressed("{key}", {modifier_keys})' + ) + if modifier_keys is None: + modifier_keys = [] + self._run_recorder_action( + acion_class_id=controller_v1_pbs.ActionClassID_KeyboardKey_Press, + action_parameters=controller_v1_pbs.ActionParameters( + keyboardKeyPress=controller_v1_pbs.ActionParameters_KeyboardKey_Press( + keyName=key, modifierKeyNames=modifier_keys + ) + ), + ) + + @telemetry.record_call() + @override + def keyboard_release( + self, key: PcKey | ModifierKey, modifier_keys: list[ModifierKey] | None = None + ) -> None: + """ + Release a keyboard key. + + Args: + key (PcKey | ModifierKey): The key to release. + modifier_keys (list[ModifierKey] | None, optional): List of modifier keys to + release along with the main key. Defaults to `None`. + """ + self._reporter.add_message( + self._REPORTER_SOURCE, f'keyboard_release("{key}", {modifier_keys})' + ) + if modifier_keys is None: + modifier_keys = [] + self._run_recorder_action( + acion_class_id=controller_v1_pbs.ActionClassID_KeyboardKey_Release, + action_parameters=controller_v1_pbs.ActionParameters( + keyboardKeyRelease=controller_v1_pbs.ActionParameters_KeyboardKey_Release( + keyName=key, modifierKeyNames=modifier_keys + ) + ), + ) + + @telemetry.record_call() + @override + def keyboard_tap( + self, + key: PcKey | ModifierKey, + modifier_keys: list[ModifierKey] | None = None, + count: int = 1, + ) -> None: + """ + Press and immediately release a keyboard key. + + Args: + key (PcKey | ModifierKey): The key to tap. + modifier_keys (list[ModifierKey] | None, optional): List of modifier keys to + press along with the main key. Defaults to `None`. + count (int, optional): The number of times to tap the key. Defaults to `1`. + """ + self._reporter.add_message( + self._REPORTER_SOURCE, + f'keyboard_tap("{key}", {modifier_keys}, {count})', + ) + if modifier_keys is None: + modifier_keys = [] + for _ in range(count): + self._run_recorder_action( + acion_class_id=controller_v1_pbs.ActionClassID_KeyboardKey_PressAndRelease, + action_parameters=controller_v1_pbs.ActionParameters( + keyboardKeyPressAndRelease=controller_v1_pbs.ActionParameters_KeyboardKey_PressAndRelease( + keyName=key, modifierKeyNames=modifier_keys + ) + ), + ) + + @telemetry.record_call() + @override + def set_display(self, display: int = 1) -> None: + """ + Set the active display. + + Args: + display (int, optional): The display ID to set as active. + This can be either a real display ID or a virtual display ID. + Defaults to `1`. + """ + self._get_stub().SetActiveDisplay( + controller_v1_pbs.Request_SetActiveDisplay(displayID=display) + ) + self._require_active_agent_os_target_computer().display = display + self._reporter.add_message(self._REPORTER_SOURCE, f"set_display({display})") + + @telemetry.record_call(exclude={"command"}) + @override + def run_command(self, command: str, timeout_ms: int = 30000) -> None: + """ + Execute a shell command. + + Args: + command (str): The command to execute. + timeout_ms (int, optional): The timeout for command + execution in milliseconds. Defaults to `30000` (30 seconds). + """ + self._reporter.add_message( + self._REPORTER_SOURCE, f'run_command("{command}", {timeout_ms})' + ) + self._run_recorder_action( + acion_class_id=controller_v1_pbs.ActionClassID_RunCommand, + action_parameters=controller_v1_pbs.ActionParameters( + runcommand=controller_v1_pbs.ActionParameters_RunCommand( + command=command, timeoutInMilliseconds=timeout_ms + ) + ), + ) + + @telemetry.record_call() + @override + def retrieve_active_display(self) -> Display: + """ + Retrieve the currently active display/screen. + + Returns: + Display: The currently active display/screen. + """ + self._reporter.add_message(self._REPORTER_SOURCE, "retrieve_active_display()") + agent_os_target_computer = self._require_active_agent_os_target_computer() + active_display_id = agent_os_target_computer.display + displays_list_response = self.list_displays() + for display in displays_list_response.data: + if display.id == active_display_id: + self._reporter.add_message( + self._REPORTER_SOURCE, f"retrieve_active_display() -> {display}" + ) + return display + available_ids = ( + ", ".join(str(d.id) for d in displays_list_response.data) or "none" + ) + error_msg = ( + f"Display {active_display_id} not found on target computer " + f"{agent_os_target_computer.description!r} " + f"(session_guid={agent_os_target_computer.session_guid}). " + f"Available display ids: {available_ids}. " + "Call `set_display()` with a valid id, or `list_displays()` to inspect." + ) + raise ValueError(error_msg) + + @telemetry.record_call() + @override + def list_displays( + self, + ) -> DisplaysListResponse: + """ + List all available Displays from the controller. + It includes both real and virtual displays + without describing the type of display (virtual or real). + + Returns: + DisplaysListResponse + """ + + self._reporter.add_message(self._REPORTER_SOURCE, "list_displays()") + + response: controller_v1_pbs.Response_GetDisplayInformation = ( + self._get_stub().GetDisplayInformation(controller_v1_pbs.Request_Void()) + ) + + response_dict = MessageToDict( + response, + preserving_proto_field_name=True, + ) + + displays = DisplaysListResponse.model_validate(response_dict) + + self._reporter.add_message( + self._REPORTER_SOURCE, f"list_displays() ->{str(displays)}" + ) + + return displays + + @telemetry.record_call() + def get_process_list( + self, get_extended_info: bool = False + ) -> controller_v1_pbs.Response_GetProcessList: + """ + Get a list of running processes. + + Args: + get_extended_info (bool, optional): Whether to include + extended process information. + Defaults to `False`. + + Returns: + controller_v1_pbs.Response_GetProcessList: Process list response containing: + - processes: List of ProcessInfo objects + """ + + self._reporter.add_message( + self._REPORTER_SOURCE, f"get_process_list({get_extended_info})" + ) + + response: controller_v1_pbs.Response_GetProcessList = ( + self._get_stub().GetProcessList( + controller_v1_pbs.Request_GetProcessList( + getExtendedInfo=get_extended_info + ) + ) + ) + self._reporter.add_message( + self._REPORTER_SOURCE, + f"get_process_list({get_extended_info}) -> {response}", + ) + + return response + + @telemetry.record_call() + def get_window_list( + self, process_id: int + ) -> controller_v1_pbs.Response_GetWindowList: + """ + Get a list of windows for a specific process. + + Args: + process_id (int): The ID of the process to get windows for. + + Returns: + controller_v1_pbs.Response_GetWindowList: Window list response containing: + - windows: List of WindowInfo objects with ID and name + """ + + self._reporter.add_message( + self._REPORTER_SOURCE, f"get_window_list({process_id})" + ) + + response: controller_v1_pbs.Response_GetWindowList = ( + self._get_stub().GetWindowList( + controller_v1_pbs.Request_GetWindowList(processID=process_id) + ) + ) + + self._reporter.add_message( + self._REPORTER_SOURCE, f"get_window_list({process_id}) -> {response}" + ) + + return response + + @telemetry.record_call() + def get_automation_target_list( + self, + ) -> controller_v1_pbs.Response_GetAutomationTargetList: + """ + Get a list of available automation targets. + + Returns: + controller_v1_pbs.Response_GetAutomationTargetList: + Automation target list response: + - targets: List of AutomationTarget objects + """ + + self._reporter.add_message( + self._REPORTER_SOURCE, "get_automation_target_list()" + ) + + response: controller_v1_pbs.Response_GetAutomationTargetList = ( + self._get_stub().GetAutomationTargetList(controller_v1_pbs.Request_Void()) + ) + self._reporter.add_message( + self._REPORTER_SOURCE, f"get_automation_target_list() -> {response}" + ) + + return response + + @telemetry.record_call() + def set_mouse_delay(self, delay_ms: int) -> None: + """ + Configure mouse action delay. + + Args: + delay_ms (int): The delay in milliseconds to set for mouse actions. + """ + + self._reporter.add_message( + self._REPORTER_SOURCE, f"set_mouse_delay({delay_ms})" + ) + + self._get_stub().SetMouseDelay( + controller_v1_pbs.Request_SetMouseDelay( + sessionInfo=self._session_info, delayInMilliseconds=delay_ms + ) + ) + + @telemetry.record_call() + def set_keyboard_delay(self, delay_ms: int) -> None: + """ + Configure keyboard action delay. + + Args: + delay_ms (int): The delay in milliseconds to set for keyboard actions. + """ + + self._reporter.add_message( + self._REPORTER_SOURCE, f"set_keyboard_delay({delay_ms})" + ) + + self._get_stub().SetKeyboardDelay( + controller_v1_pbs.Request_SetKeyboardDelay( + sessionInfo=self._session_info, delayInMilliseconds=delay_ms + ) + ) + + @telemetry.record_call() + def set_active_window(self, process_id: int, window_id: int) -> int: + """ + Set the active window for automation. + Adds the window as a virtual display and returns the display ID. + It raises an error if display length is not increased after adding the window. + + Args: + process_id (int): The ID of the process that owns the window. + window_id (int): The ID of the window to set as active. + + returns: + int: The new Display ID. + Raises: + AskUiControllerError: + If display length is not increased after adding the window. + """ + + self._reporter.add_message( + self._REPORTER_SOURCE, f"set_active_window({process_id}, {window_id})" + ) + + display_length_before_adding_window = len(self.list_displays().data) + + self._get_stub().SetActiveWindow( + controller_v1_pbs.Request_SetActiveWindow( + processID=process_id, windowID=window_id + ) + ) + new_display_length = len(self.list_displays().data) + if new_display_length <= display_length_before_adding_window: + msg = ( + f"Failed to add window {window_id} of process {process_id} as a " + f"virtual display: display count did not increase " + f"({display_length_before_adding_window} -> {new_display_length}). " + "Verify the process and window ids exist and are valid for the " + "active target computer." + ) + raise AskUiControllerError(msg) + self._reporter.add_message( + self._REPORTER_SOURCE, + f"set_active_window({process_id}, {window_id}) -> {new_display_length}", + ) + return new_display_length + + @telemetry.record_call() + def set_active_automation_target(self, target_id: int) -> None: + """ + Set the active automation target. + + Args: + target_id (int): The ID of the automation target to set as active. + """ + + self._reporter.add_message( + self._REPORTER_SOURCE, f"set_active_automation_target({target_id})" + ) + + self._get_stub().SetActiveAutomationTarget( + controller_v1_pbs.Request_SetActiveAutomationTarget(ID=target_id) + ) + + @telemetry.record_call() + def schedule_batched_action( + self, + action_class_id: controller_v1_pbs.ActionClassID, + action_parameters: controller_v1_pbs.ActionParameters, + ) -> controller_v1_pbs.Response_ScheduleBatchedAction: + """ + Schedule an action for batch execution. + + Args: + action_class_id (controller_v1_pbs.ActionClassID): The class ID + of the action to schedule. + action_parameters (controller_v1_pbs.ActionParameters): + Parameters for the action. + + Returns: + controller_v1_pbs.Response_ScheduleBatchedAction: Response containing + the scheduled action ID. + """ + + self._reporter.add_message( + self._REPORTER_SOURCE, + f"schedule_batched_action({action_class_id}, {action_parameters})", + ) + + response: controller_v1_pbs.Response_ScheduleBatchedAction = ( + self._get_stub().ScheduleBatchedAction( + controller_v1_pbs.Request_ScheduleBatchedAction( + sessionInfo=self._session_info, + actionClassID=action_class_id, + actionParameters=action_parameters, + ) + ) + ) + + return response + + @telemetry.record_call() + def start_batch_run(self) -> None: + """ + Start executing batched actions. + """ + + self._reporter.add_message(self._REPORTER_SOURCE, "start_batch_run()") + + self._get_stub().StartBatchRun( + controller_v1_pbs.Request_StartBatchRun(sessionInfo=self._session_info) + ) + + @telemetry.record_call() + def stop_batch_run(self) -> None: + """ + Stop executing batched actions. + """ + + self._reporter.add_message(self._REPORTER_SOURCE, "stop_batch_run()") + + self._get_stub().StopBatchRun( + controller_v1_pbs.Request_StopBatchRun(sessionInfo=self._session_info) + ) + + @telemetry.record_call() + def get_action_count(self) -> controller_v1_pbs.Response_GetActionCount: + """ + Get the count of recorded or batched actions. + + Returns: + controller_v1_pbs.Response_GetActionCount: Response + containing the action count. + """ + + response: controller_v1_pbs.Response_GetActionCount = ( + self._get_stub().GetActionCount( + controller_v1_pbs.Request_GetActionCount(sessionInfo=self._session_info) + ) + ) + self._reporter.add_message( + self._REPORTER_SOURCE, f"get_action_count() -> {response}" + ) + return response + + @telemetry.record_call() + def get_action(self, action_index: int) -> controller_v1_pbs.Response_GetAction: + """ + Get a specific action by its index. + + Args: + action_index (int): The index of the action to retrieve. + + Returns: + controller_v1_pbs.Response_GetAction: Action information containing: + - actionID: The action ID + - actionClassID: The action class ID + - actionParameters: The action parameters + """ + + self._reporter.add_message(self._REPORTER_SOURCE, f"get_action({action_index})") + + response: controller_v1_pbs.Response_GetAction = self._get_stub().GetAction( + controller_v1_pbs.Request_GetAction( + sessionInfo=self._session_info, actionIndex=action_index + ) + ) + + return response + + @telemetry.record_call() + def remove_action(self, action_id: int) -> None: + """ + Remove a specific action by its ID. + + Args: + action_id (int): The ID of the action to remove. + """ + + self._reporter.add_message(self._REPORTER_SOURCE, f"remove_action({action_id})") + + self._get_stub().RemoveAction( + controller_v1_pbs.Request_RemoveAction( + sessionInfo=self._session_info, actionID=action_id + ) + ) + + @telemetry.record_call() + def remove_all_actions(self) -> None: + """ + Clear all recorded or batched actions. + """ + + self._reporter.add_message(self._REPORTER_SOURCE, "remove_all_actions()") + + self._get_stub().RemoveAllActions( + controller_v1_pbs.Request_RemoveAllActions(sessionInfo=self._session_info) + ) + + def _send_command(self, command: Command) -> AskUIAgentOSSendResponseSchema: + """ + Send a general command to the controller. + + Args: + command (Command): The command to send to the controller. + + Returns: + AskUIAgentOSSendResponseSchema: Response containing + the message from the controller. + + Raises: + AskUiControllerInvalidCommandError: If the command fails schema validation + on the target computer side. + """ + + agent_os_target_computer = self._require_active_agent_os_target_computer() + header = Header(authentication=Guid(root=agent_os_target_computer.session_guid)) + message = Message(header=header, command=command) + + request = AskUIAgentOSSendRequestSchema(message=message) + + request_str = request.model_dump_json(exclude_none=True, by_alias=True) + + try: + response: controller_v1_pbs.Response_Send = self._get_stub().Send( + controller_v1_pbs.Request_Send(message=request_str) + ) + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.INVALID_ARGUMENT: + details = e.details() or None + raise AskUiControllerInvalidCommandError(details) from e + raise + + return AskUIAgentOSSendResponseSchema.model_validate_json(response.message) + + @telemetry.record_call() + def get_mouse_position(self) -> Coordinate: + """ + Get the mouse cursor position + + Returns: + Coordinate: Response containing the result of the mouse position change. + """ + self._reporter.add_message(self._REPORTER_SOURCE, "get_mouse_position()") + res = self._send_command(GetMousePositionCommand()) + coordinate = Coordinate( + x=res.message.command.response.position.x.root, # type: ignore[union-attr] + y=res.message.command.response.position.y.root, # type: ignore[union-attr] + ) + self._reporter.add_message( + self._REPORTER_SOURCE, f"get_mouse_position() -> {coordinate}" + ) + return coordinate + + @telemetry.record_call() + def set_mouse_position(self, x: int, y: int) -> None: + """ + Set the mouse cursor position to specific coordinates. + + Args: + x (int): The horizontal coordinate (in pixels) to set the cursor to. + y (int): The vertical coordinate (in pixels) to set the cursor to. + """ + location = Location(x=Length(root=x), y=Length(root=y)) + command = SetMousePositionCommand(parameters=[location]) + self._reporter.add_message( + self._REPORTER_SOURCE, f"set_mouse_position({x},{y})" + ) + self._send_command(command) + + @telemetry.record_call() + def render_quad(self, style: RenderObjectStyle) -> int: + """ + Render a quad object to the display. + + Args: + style (RenderObjectStyle): The style properties for the quad. + + Returns: + int: Object ID. + """ + self._reporter.add_message(self._REPORTER_SOURCE, f"render_quad({style})") + command = AddRenderObjectCommand(parameters=["Quad", style]) + res = self._send_command(command) + return int(res.message.command.response.id.root) # type: ignore[union-attr] + + @telemetry.record_call() + def render_line(self, style: RenderObjectStyle, points: list[Coordinate]) -> int: + """ + Render a line object to the display. + + Args: + style (RenderObjectStyle): The style properties for the line. + points (list[Coordinates]): The points defining the line. + + Returns: + int: Object ID. + """ + self._reporter.add_message( + self._REPORTER_SOURCE, f"render_line({style}, {points})" + ) + command = AddRenderObjectCommand(parameters=["Line", style, points]) + res = self._send_command(command) + return int(res.message.command.response.id.root) # type: ignore[union-attr] + + @telemetry.record_call(exclude={"image_data"}) + def render_image(self, style: RenderObjectStyle, image_data: str) -> int: + """ + Render an image object to the display. + + Args: + style (RenderObjectStyle): The style properties for the image. + image_data (str): The base64-encoded image data. + + Returns: + int: Object ID. + """ + self._reporter.add_message( + self._REPORTER_SOURCE, f"render_image({style}, [image_data])" + ) + image = RenderImage(root=image_data) + command = AddRenderObjectCommand(parameters=["Image", style, image]) + res = self._send_command(command) + + return int(res.message.command.response.id.root) # type: ignore[union-attr] + + @telemetry.record_call() + def render_text(self, style: RenderObjectStyle, content: str) -> int: + """ + Render a text object to the display. + + Args: + style (RenderObjectStyle): The style properties for the text. + content (str): The text content to display. + + Returns: + int: Object ID. + """ + self._reporter.add_message( + self._REPORTER_SOURCE, f"render_text({style}, {content})" + ) + text = RenderText(root=content) + command = AddRenderObjectCommand(parameters=["Text", style, text]) + res = self._send_command(command) + return int(res.message.command.response.id.root) # type: ignore[union-attr] + + @telemetry.record_call() + def update_render_object(self, object_id: int, style: RenderObjectStyle) -> None: + """ + Update styling properties of an existing render object. + + Args: + object_id (float): The ID of the render object to update. + style (RenderObjectStyle): The new style properties. + + Returns: + int: Object ID. + """ + self._reporter.add_message( + self._REPORTER_SOURCE, f"update_render_object({object_id}, {style})" + ) + render_object_id = RenderObjectId(root=object_id) + command = UpdateRenderObjectCommand(parameters=[render_object_id, style]) + self._send_command(command) + + @telemetry.record_call() + def delete_render_object(self, object_id: int) -> None: + """ + Delete an existing render object from the display. + + Args: + object_id (RenderObjectId): The ID of the render object to delete. + """ + self._reporter.add_message( + self._REPORTER_SOURCE, f"delete_render_object({object_id})" + ) + render_object_id = RenderObjectId(root=object_id) + command = DeleteRenderObjectCommand(parameters=[render_object_id]) + self._send_command(command) + + @telemetry.record_call() + def clear_render_objects(self) -> None: + """ + Clear all render objects from the display. + """ + self._reporter.add_message(self._REPORTER_SOURCE, "clear_render_objects()") + command = ClearRenderObjectsCommand() + self._send_command(command) + + def get_system_info(self) -> GetSystemInfoResponseModel: + """ + Get the system information. + + Returns: + SystemInfo: The system information. + """ + self._reporter.add_message(self._REPORTER_SOURCE, "get_system_info()") + command = GetSystemInfoCommand() + res = self._send_command(command).message.command + if not isinstance(res, GetSystemInfoResponse): + message = ( + f"get_system_info: expected GetSystemInfoResponse from the " + f"controller but got {type(res).__name__}: {res!r}" + ) + raise DesktopAgentOsError(message) + self._reporter.add_message( + self._REPORTER_SOURCE, f"get_system_info() -> {res.response}" + ) + return res.response + + def get_active_process(self) -> GetActiveProcessResponseModel: + """ + Get the active process. + + Returns: + GetActiveProcessResponseModel: The active process. + """ + self._reporter.add_message(self._REPORTER_SOURCE, "get_active_process()") + command = GetActiveProcessCommand() + res = self._send_command(command).message.command + if not isinstance(res, GetActiveProcessResponse): + message = ( + f"get_active_process: expected GetActiveProcessResponse from the " + f"controller but got {type(res).__name__}: {res!r}" + ) + raise DesktopAgentOsError(message) + self._reporter.add_message( + self._REPORTER_SOURCE, f"get_active_process() -> {res.response}" + ) + return res.response + + def set_active_process(self, process_id: int) -> None: + """ + Set the active process. + + Args: + process_id (int): The ID of the process to set as active. + """ + self._reporter.add_message( + self._REPORTER_SOURCE, f"set_active_process({process_id})" + ) + _process_id = Parameter3(root=process_id) + command = SetActiveProcessCommand(parameters=[_process_id]) + self._send_command(command) + + def get_active_window(self) -> GetActiveWindowResponseModel: + """ + Gets the window id and name in addition to the process id + and name of the currently active window (in focus). + + + Returns: + GetActiveWindowResponseModel: The active window. + """ + self._reporter.add_message(self._REPORTER_SOURCE, "get_active_window()") + command = GetActiveWindowCommand() + res = self._send_command(command).message.command + if not isinstance(res, GetActiveWindowResponse): + message = ( + f"get_active_window: expected GetActiveWindowResponse from the " + f"controller but got {type(res).__name__}: {res!r}" + ) + raise DesktopAgentOsError(message) + self._reporter.add_message( + self._REPORTER_SOURCE, f"get_active_window() -> {res.response}" + ) + return res.response + + def set_window_in_focus(self, process_id: int, window_id: int) -> None: + """ + Sets the window with the specified windowId of the process + with the specified processId active, + which brings it to the front and gives it focus. + + Args: + process_id (int): The ID of the process that owns the window. + window_id (int): The ID of the window to set as active. + """ + self._reporter.add_message( + self._REPORTER_SOURCE, f"set_window_in_focus({process_id}, {window_id})" + ) + _process_id = Parameter3(root=process_id) + _window_id = Parameter3(root=window_id) + command = SetActiveWindowCommand(parameters=[_process_id, _window_id]) + self._send_command(command) + + @telemetry.record_call() + @override + def get_file_names(self, absolute_directory_path: str) -> list[str]: + """ + Get the file names in the given absolute directory on the device under + automation. + + Args: + absolute_directory_path (str): The absolute directory path to list + file names from. + + Returns: + list[str]: The file names returned by the controller. + """ + self._reporter.add_message( + self._REPORTER_SOURCE, f"get_file_names({absolute_directory_path})" + ) + command = GetFileNamesCommand(parameters=[absolute_directory_path]) + res = self._send_command(command).message.command + if not isinstance(res, GetFileNamesResponse): + message = f"unexpected response type: {res}" + raise DesktopAgentOsError(message) + if res.error is not None: + raise DesktopAgentOsError(res.error) + if res.response is None: + message = f"{type(res).__name__} is missing both error and response" + raise DesktopAgentOsError(message) + self._reporter.add_message( + self._REPORTER_SOURCE, + f"get_file_names({absolute_directory_path}) -> {res.response}", + ) + return res.response.fileNames + + @telemetry.record_call() + @override + def get_file(self, path: str) -> Image.Image | str: + """ + Get the contents of a file at the given path on the device under + automation. + + The controller returns the file as a Base64-encoded string, which is + decoded and returned as `PIL.Image.Image` when the bytes can be opened + as an image (PNG, JPEG, BMP, GIF, WebP, TIFF, ...), or as `str` when + they decode cleanly as UTF-8 text. + + Args: + path (str): The file path to read on the device under automation. + + Returns: + Image.Image | str: The decoded file contents. + + Raises: + DesktopAgentOsError: If the file cannot be read or the response is invalid. + """ + self._reporter.add_message(self._REPORTER_SOURCE, f"get_file({path})") + command = GetFileCommand(parameters=[path]) + res = self._send_command(command).message.command + if not isinstance(res, GetFileResponse): + message = f"unexpected response type: {res}" + raise DesktopAgentOsError(message) + if res.error is not None: + raise DesktopAgentOsError(res.error) + if res.response is None: + message = f"{type(res).__name__} is missing both error and response" + raise DesktopAgentOsError(message) + decoded = self._decode_file_payload(res.response.file.content) + if isinstance(decoded, Image.Image): + detail = f"image ({decoded.format}, {decoded.size[0]}x{decoded.size[1]})" + self._reporter.add_message( + self._REPORTER_SOURCE, f"get_file({path}) -> {detail}", decoded + ) + return decoded + + detail = f"text ({len(decoded)} chars)" + self._reporter.add_message( + self._REPORTER_SOURCE, f"get_file({path}) -> {detail}" + ) + return decoded + + @telemetry.record_call() + @override + def remove_virtual_displays(self) -> None: + """ + Remove all virtual displays from the controller, leaving only real + displays active. + """ + self._reporter.add_message(self._REPORTER_SOURCE, "remove_virtual_displays()") + command = RemoveVirtualDisplaysCommand() + self._send_command(command) + self._reporter.add_message( + self._REPORTER_SOURCE, "remove_virtual_displays() -> done" + ) + + @staticmethod + def _decode_file_payload(base64_data: str) -> Image.Image | str: + try: + return base64_to_image(base64_data) + except ValueError: + pass + data = base64.b64decode(base64_data, validate=True) + if b"\x00" not in data: + try: + return data.decode("utf-8") + except UnicodeDecodeError: + pass + message = "File contents are neither a supported image nor UTF-8 text" + raise DesktopAgentOsError(message) + + +AskUiControllerClient = MultiComputerTargetAgentOS diff --git a/src/askui/tools/askui/askui_controller_client_settings.py b/src/askui/tools/askui/askui_controller_client_settings.py deleted file mode 100644 index 6e53b747..00000000 --- a/src/askui/tools/askui/askui_controller_client_settings.py +++ /dev/null @@ -1,34 +0,0 @@ -from pydantic import Field -from pydantic_settings import BaseSettings, SettingsConfigDict - - -class AskUiControllerClientSettings(BaseSettings): - """ - Settings for the AskUI Remote Device Controller client. - """ - - model_config = SettingsConfigDict( - env_prefix="ASKUI_CONTROLLER_CLIENT_", - ) - - server_address: str = Field( - default="localhost:23000", - description="Address of the AskUI Remote Device Controller server.", - ) - - server_autostart: bool = Field( - default=True, - description="Whether to automatically start the AskUI Remote Device" - "Controller server. Defaults to True.", - ) - - clean_virtual_displays: bool = Field( - default=False, - description=( - "Whether to clean virtual displays after the controller is started." - "Default: False" - ), - ) - - -__all__ = ["AskUiControllerClientSettings"] diff --git a/src/askui/tools/askui/computer_target_connection.py b/src/askui/tools/askui/computer_target_connection.py new file mode 100644 index 00000000..54b8e225 --- /dev/null +++ b/src/askui/tools/askui/computer_target_connection.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import grpc + +from askui.tools.askui.askui_ui_controller_grpc.generated import ( + Controller_V1_pb2 as controller_v1_pbs, +) +from askui.tools.askui.askui_ui_controller_grpc.generated import ( + Controller_V1_pb2_grpc as controller_v1, +) +from askui.tools.askui.exceptions import AskUiControllerError + +if TYPE_CHECKING: + from askui.tools.askui.agent_os_target_computer import ComputerTarget + +logger = logging.getLogger(__name__) + + +@dataclass +class ComputerTargetConnection: + """ + The live gRPC connection to a `ComputerTarget`: the open channel, the + controller stub bound to it, and the session opened on the target computer. + + Holds only the live connection handles; the `ComputerTarget` it belongs to + is passed in when opening or closing. Encapsulates all gRPC specifics so + that `ComputerTarget` and `ComputerTargetPool` stay free of channel / stub / + session details. + + Args: + channel (grpc.Channel): The open gRPC channel. + stub (ControllerAPIStub): The controller API stub bound to `channel`. + session_info (SessionInfo): The session opened on the target computer. + """ + + channel: grpc.Channel + stub: controller_v1.ControllerAPIStub + session_info: controller_v1_pbs.SessionInfo + + @classmethod + def open(cls, target: ComputerTarget) -> ComputerTargetConnection: + """ + Open a gRPC channel and session to `target`. + + Starts the target's local controller process first (a no-op for remote + and service-managed targets), opens an insecure gRPC channel, starts a + session, starts execution, and sets the configured display. + + On failure during session setup, the channel is closed and any started + process is stopped before re-raising. + """ + target.start() + channel = grpc.insecure_channel( + target.address, + options=[ + ("grpc.max_send_message_length", 2**30), + ("grpc.max_receive_message_length", 2**30), + ("grpc.default_deadline", 300000), + ], + ) + stub = controller_v1.ControllerAPIStub(channel) + try: + session_response: controller_v1_pbs.Response_StartSession = ( + stub.StartSession( + controller_v1_pbs.Request_StartSession( + sessionGUID=target.session_guid, + immediateExecution=True, + ) + ) + ) + session_info = session_response.sessionInfo + stub.StartExecution( + controller_v1_pbs.Request_StartExecution(sessionInfo=session_info) + ) + stub.SetActiveDisplay( + controller_v1_pbs.Request_SetActiveDisplay(displayID=target.display) + ) + except Exception as e: + try: + channel.close() + finally: + target.stop() + error_msg = ( + f"Failed to connect to Agent OS target computer " + f"{target.description!r} " + f"(computer_id={target.computer_id!r}, " + f"session_guid={target.session_guid}, " + f"display={target.display}, " + f"address={target.address}): {e}" + ) + raise AskUiControllerError(error_msg) from e + return cls(channel=channel, stub=stub, session_info=session_info) + + def close(self, target: ComputerTarget) -> None: + """ + Close this connection to `target`. + + Stops execution, ends the session, closes the gRPC channel, and stops + the target's local controller process (a no-op unless this client + started one). Errors are logged but never raised, so a partial failure + still releases the rest of the connection. + """ + computer_id = target.computer_id + try: + self.stub.StopExecution( + controller_v1_pbs.Request_StopExecution(sessionInfo=self.session_info) + ) + self.stub.EndSession( + controller_v1_pbs.Request_EndSession(sessionInfo=self.session_info) + ) + except Exception: # noqa: BLE001 + logger.exception( + "Error stopping execution/session for controller %s", computer_id + ) + try: + self.channel.close() + except Exception: # noqa: BLE001 + logger.exception("Error closing channel for controller %s", computer_id) + try: + target.stop() + except Exception: # noqa: BLE001 + logger.exception( + "Error stopping client-started controller process for %s", computer_id + ) + + +__all__ = ["ComputerTargetConnection"] diff --git a/src/askui/tools/askui/computer_target_pool.py b/src/askui/tools/askui/computer_target_pool.py new file mode 100644 index 00000000..7c5ce23d --- /dev/null +++ b/src/askui/tools/askui/computer_target_pool.py @@ -0,0 +1,281 @@ +from askui.tools.askui.agent_os_target_computer import ( + ComputerTarget, +) +from askui.tools.askui.computer_target_connection import ComputerTargetConnection +from askui.tools.askui.exceptions import AskUiControllerError + + +class ComputerTargetPool: + """ + Manages a collection of `ComputerTarget` instances and tracks the currently + active one. Each target owns its own gRPC connection + (`ComputerTarget.connection`); the pool only orchestrates connecting / + disconnecting them and selecting the active one. + + Responsibilities: + - Register / unregister `ComputerTarget` instances with uniqueness + constraints (at most one local, unique computer ids / session GUIDs, + unique remote addresses). + - Drive `connect()` / `disconnect()` on registered targets (individually + or all at once). + - Track which registered target is currently active and expose its + connection needed to route agent-os actions to it. + + The first target added becomes active by default. Use `switch` to change + which target is active. `connect` opens connections to every registered + target; subsequently `add` / `switch` auto-connect any + newly-introduced target whenever the manager already holds at least one + open connection. + + Targets are addressed exclusively by their `computer_id`. + + Args: + agent_os_target_computers (list[ComputerTarget] | None, optional): + Initial targets to register. + """ + + def __init__( + self, + agent_os_target_computers: list[ComputerTarget] | None = None, + ) -> None: + # Single store. Python dicts preserve insertion order, so this also + # defines `list()` order and the first-added-is-active semantics. Each + # target owns its own connection, so no separate connection store is + # needed here. + self._by_computer_id: dict[str, ComputerTarget] = {} + self._active_computer_id: str | None = None + if agent_os_target_computers: + for target in agent_os_target_computers: + self.add(target) + + @property + def is_connected(self) -> bool: + """`True` when at least one registered target has an open connection.""" + return any(t.is_connected for t in self._by_computer_id.values()) + + def add(self, target: ComputerTarget) -> ComputerTarget: + """ + Register an Agent OS target computer. Auto-connects when the manager + already has at least one open connection. + + Args: + target (ComputerTarget): The target computer to register. + + Returns: + ComputerTarget: The registered target. + + Raises: + ValueError: If another local target is already registered, the same + session GUID or computer id is already registered, or another + remote target with the same address is already registered. + """ + self._validate_addable(target) + self._by_computer_id[target.computer_id] = target + if self._active_computer_id is None: + self._active_computer_id = target.computer_id + if self.is_connected: + self.connect_target(target) + return target + + def reset(self) -> None: + """Disconnect every open connection and remove all registered targets.""" + self.disconnect() + self._by_computer_id.clear() + self._active_computer_id = None + + def remove(self, computer_id: str) -> None: + """ + Remove a registered target by its `computer_id`. If the target was + connected, its connection is closed first. + + Args: + computer_id (str): The computer id of the target to remove. + + Raises: + KeyError: If no target with the given computer id is registered. + """ + self._require(computer_id) + self.disconnect_target(computer_id) + del self._by_computer_id[computer_id] + if self._active_computer_id == computer_id: + self._active_computer_id = next(iter(self._by_computer_id), None) + + def describe(self) -> list[str]: + """ + Return the `repr()` of every registered target, in registration order. + """ + return [repr(target) for target in self._by_computer_id.values()] + + def get(self, computer_id: str) -> ComputerTarget: + """ + Return the registered target with the given `computer_id`. + + Raises: + KeyError: If no target with the given computer id is registered. + """ + return self._require(computer_id) + + def switch(self, computer_id: str) -> ComputerTarget: + """ + Set the active target by its `computer_id`. Auto-connects the new + active target when the manager already has at least one open connection + but this target is not yet connected. + + Args: + computer_id (str): The computer id of the target to activate. + + Returns: + ComputerTarget: The newly active target. + + Raises: + KeyError: If no target with the given computer id is registered. + """ + target = self._require(computer_id) + self._active_computer_id = computer_id + if self.is_connected and not target.is_connected: + self.connect_target(target) + return target + + @property + def active(self) -> ComputerTarget | None: + """The currently active target, or `None` if no targets are registered.""" + if self._active_computer_id is None: + return None + return self._by_computer_id.get(self._active_computer_id) + + def require_active(self) -> ComputerTarget: + """ + Return the currently active target. + + Raises: + AskUiControllerError: If no target is currently active. + """ + target = self.active + if target is None: + error_msg = ( + "No active Agent OS target computer. Register one via " + "`MultiComputerTargetAgentOS.add_agent_os_target_computer()`, or " + "pass `agent_os_target_computers` to the " + "`MultiComputerTargetAgentOS` constructor." + ) + raise AskUiControllerError(error_msg) + return target + + def active_connection(self) -> ComputerTargetConnection: + """ + Return the gRPC connection for the currently active target. + + Raises: + AskUiControllerError: If no target is currently active or the active + target has no open connection (i.e. `connect()` has not been + called). + """ + return self.require_active().connection + + def connect(self) -> None: + """ + Open the connection to every registered Agent OS target via + `ComputerTarget.connect()`. Targets already connected are skipped, so + calling `connect()` twice is safe. + + Raises: + AskUiControllerError: If no targets are registered. + + On failure mid-loop, all targets connected so far are rolled back via + `disconnect()` before re-raising. + """ + if not self._by_computer_id: + error_msg = ( + "Cannot connect: no Agent OS target computers registered. Provide " + "at least one via the `MultiComputerTargetAgentOS` constructor's " + "`agent_os_target_computers` argument, or call " + "`add_agent_os_target_computer()` before `connect()`." + ) + raise AskUiControllerError(error_msg) + try: + for target in self._by_computer_id.values(): + self.connect_target(target) + except Exception: + self.disconnect() + raise + + def connect_target(self, target: ComputerTarget) -> None: + """ + Open the connection to a single registered Agent OS target. Idempotent: + returns silently if the target is already connected. Delegates to + `ComputerTarget.connect()`. + """ + target.connect() + + def disconnect(self) -> None: + """ + Close every open Agent OS target connection. Errors on one connection + are logged but do not abort the loop - a partial failure still releases + the others. + """ + for target in self._by_computer_id.values(): + target.disconnect() + + def disconnect_target(self, computer_id: str) -> None: + """ + Close a single open Agent OS target connection identified by its + `computer_id`. No-op if no such connection is open or no such target is + registered. Delegates to `ComputerTarget.disconnect()`. + """ + target = self._by_computer_id.get(computer_id) + if target is not None: + target.disconnect() + + def __len__(self) -> int: + return len(self._by_computer_id) + + def __contains__(self, computer_id: object) -> bool: + return isinstance(computer_id, str) and computer_id in self._by_computer_id + + def _validate_addable(self, target: ComputerTarget) -> None: + if target.is_local: + existing_local = next( + (t for t in self._by_computer_id.values() if t.is_local), None + ) + if existing_local is not None: + error_msg = ( + "Cannot register a second local Agent OS target computer. At " + "most one local target is supported. Existing local target: " + f"{existing_local.description!r} " + f"(computer_id={existing_local.computer_id!r}). " + "Remove it first via `remove(computer_id)`." + ) + raise ValueError(error_msg) + if target.computer_id in self._by_computer_id: + error_msg = ( + "An Agent OS target computer with " + f"computer_id={target.computer_id!r} is already registered. " + "Each target must have a unique computer_id." + ) + raise ValueError(error_msg) + if not target.is_local and any( + (not t.is_local) and t.address == target.address + for t in self._by_computer_id.values() + ): + error_msg = ( + f"A remote Agent OS target computer with address " + f"{target.address!r} is already registered. Each remote target " + "must have a unique address." + ) + raise ValueError(error_msg) + + def _require(self, computer_id: str) -> ComputerTarget: + target = self._by_computer_id.get(computer_id) + if target is not None: + return target + registered = ", ".join(repr(cid) for cid in self._by_computer_id) or "none" + error_msg = ( + f"No Agent OS target computer with computer_id={computer_id!r} is " + f"registered. Registered computer ids: {registered}. Use " + "`describe_agent_os_target_computers()` to inspect the registered " + "targets." + ) + raise KeyError(error_msg) + + +__all__ = ["ComputerTargetPool"] diff --git a/src/askui/tools/askui/exceptions.py b/src/askui/tools/askui/exceptions.py index 1398ff2b..ecfc2c16 100644 --- a/src/askui/tools/askui/exceptions.py +++ b/src/askui/tools/askui/exceptions.py @@ -1,8 +1,9 @@ class AskUiControllerError(Exception): """Base exception for AskUI controller errors. - This exception is raised when there is an error in the AskUI controller (client), - which handles the communication with the AskUI controller (server). + This exception is raised when there is an error in the AskUI controller + client, which handles the communication with the AskUI controller process + running on the target computer. Args: message (str): The error message. @@ -42,7 +43,11 @@ class AskUiControllerOperationTimeoutError(AskUiControllerError): """ def __init__( - self, message: str = "Action not yet done", timeout_seconds: float | None = None + self, + message: str = ( + "Controller action did not finish within the expected time window." + ), + timeout_seconds: float | None = None, ): super().__init__(message) self.timeout_seconds = timeout_seconds @@ -52,21 +57,23 @@ class AskUiControllerInvalidCommandError(AskUiControllerError): """Exception raised when a command sent to the controller is invalid. This exception is raised when a command fails schema validation on the - controller server side, typically due to malformed command structure or + target computer side, typically due to malformed command structure or invalid parameters. Args: - details (str | None): Optional additional error details from the server. + details (str | None): Optional additional error details from the target + computer. """ def __init__(self, details: str | None = None): error_msg = ( - "AgentOS: Command validation failed" - " This error may be resolved by updating the AskUI" - " controller to the latest version." + "AgentOS: command validation failed on the target computer. " + "This is typically caused by a malformed command or a version " + "mismatch; updating the AskUI controller to the latest version " + "may resolve it." ) if details: - error_msg += f"\n{details}" + error_msg += f"\nController details: {details}" super().__init__(error_msg) self.details = details diff --git a/src/askui/tools/computer/__init__.py b/src/askui/tools/computer/__init__.py index 0410151e..b146a31e 100644 --- a/src/askui/tools/computer/__init__.py +++ b/src/askui/tools/computer/__init__.py @@ -1,10 +1,12 @@ from .connect_tool import ComputerConnectTool from .disconnect_tool import ComputerDisconnectTool +from .get_current_computer_target_id_tool import ComputerGetCurrentComputerTargetIdTool from .get_mouse_position_tool import ComputerGetMousePositionTool from .get_system_info_tool import ComputerGetSystemInfoTool from .keyboard_pressed_tool import ComputerKeyboardPressedTool from .keyboard_release_tool import ComputerKeyboardReleaseTool from .keyboard_tap_tool import ComputerKeyboardTapTool +from .list_agent_os_target_computers_tool import ComputerListAgentOsTargetComputersTool from .list_displays_tool import ComputerListDisplaysTool from .mouse_click_tool import ComputerMouseClickTool from .mouse_hold_down_tool import ComputerMouseHoldDownTool @@ -14,12 +16,16 @@ from .retrieve_active_display_tool import ComputerRetrieveActiveDisplayTool from .screenshot_tool import ComputerScreenshotTool from .set_active_display_tool import ComputerSetActiveDisplayTool +from .switch_agent_os_target_computer_tool import ( + ComputerSwitchAgentOsTargetComputerTool, +) from .type_tool import ComputerTypeTool __all__ = [ "ComputerGetSystemInfoTool", "ComputerConnectTool", "ComputerDisconnectTool", + "ComputerGetCurrentComputerTargetIdTool", "ComputerGetMousePositionTool", "ComputerKeyboardPressedTool", "ComputerKeyboardReleaseTool", @@ -32,6 +38,8 @@ "ComputerScreenshotTool", "ComputerTypeTool", "ComputerListDisplaysTool", + "ComputerListAgentOsTargetComputersTool", "ComputerRetrieveActiveDisplayTool", "ComputerSetActiveDisplayTool", + "ComputerSwitchAgentOsTargetComputerTool", ] diff --git a/src/askui/tools/computer/connect_tool.py b/src/askui/tools/computer/connect_tool.py index 7e0e35f4..e4ece900 100644 --- a/src/askui/tools/computer/connect_tool.py +++ b/src/askui/tools/computer/connect_tool.py @@ -1,11 +1,11 @@ from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS class ComputerConnectTool(ComputerBaseTool): """Computer Connect Tool""" - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="connect", description=( diff --git a/src/askui/tools/computer/disconnect_tool.py b/src/askui/tools/computer/disconnect_tool.py index 6f3cea25..88a0fe86 100644 --- a/src/askui/tools/computer/disconnect_tool.py +++ b/src/askui/tools/computer/disconnect_tool.py @@ -1,11 +1,11 @@ from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS class ComputerDisconnectTool(ComputerBaseTool): """Computer Disconnect Tool""" - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="disconnect", description=( diff --git a/src/askui/tools/computer/get_current_computer_target_id_tool.py b/src/askui/tools/computer/get_current_computer_target_id_tool.py new file mode 100644 index 00000000..74ac248f --- /dev/null +++ b/src/askui/tools/computer/get_current_computer_target_id_tool.py @@ -0,0 +1,18 @@ +from askui.models.shared import ComputerBaseTool +from askui.tools.agent_os import ComputerAgentOS + + +class ComputerGetCurrentComputerTargetIdTool(ComputerBaseTool): + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: + super().__init__( + name="get_current_computer_target_id", + description=""" + Return the `computer_id` of the currently active Agent OS target + computer that agent-os actions are routed to. + """, + agent_os=agent_os, + ) + self.is_cacheable = True + + def __call__(self) -> str: + return self.agent_os.get_current_computer_target_id() diff --git a/src/askui/tools/computer/get_mouse_position_tool.py b/src/askui/tools/computer/get_mouse_position_tool.py index 059822a5..09729790 100644 --- a/src/askui/tools/computer/get_mouse_position_tool.py +++ b/src/askui/tools/computer/get_mouse_position_tool.py @@ -8,12 +8,20 @@ class ComputerGetMousePositionTool(ComputerBaseTool): def __init__(self, agent_os: ComputerAgentOsFacade | None = None) -> None: super().__init__( name="get_mouse_position", - description="Get the current mouse position.", + description=( + "Get the current mouse position on the currently active Agent OS " + "target computer. The result is prefixed with the active target " + "computer's id." + ), agent_os=agent_os, required_tags=[ToolTags.SCALED_AGENT_OS.value], ) self.is_cacheable = True def __call__(self) -> str: + target_id = self.agent_os.get_current_computer_target_id(report=False) cursor_position = self.agent_os.get_mouse_position() - return f"Mouse is at position ({cursor_position.x}, {cursor_position.y})." + return ( + f"[Computer '{target_id}']: Mouse is at position " + f"({cursor_position.x}, {cursor_position.y})." + ) diff --git a/src/askui/tools/computer/get_system_info_tool.py b/src/askui/tools/computer/get_system_info_tool.py index 7f68c07d..c82c0008 100644 --- a/src/askui/tools/computer/get_system_info_tool.py +++ b/src/askui/tools/computer/get_system_info_tool.py @@ -1,11 +1,12 @@ from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS class ComputerGetSystemInfoTool(ComputerBaseTool): """ - Get the system information. - This tool returns the system information as a JSON object. + Get the system information of the currently active Agent OS target computer. + This tool returns the system information as a JSON object prefixed with the + active target computer's id. The JSON object contains the following fields: - platform: The operating system platform. - label: The operating system label. @@ -13,12 +14,14 @@ class ComputerGetSystemInfoTool(ComputerBaseTool): - architecture: The operating system architecture. """ - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="get_system_info_tool", description=""" - Get the system information. - This tool returns the system information as a JSON object. + Get the system information of the currently active Agent OS target + computer. This tool returns the system information as a JSON object + prefixed with the active target computer's id so it is clear which + computer the info belongs to. The JSON object contains the following fields: - platform: The operating system platform. - label: The operating system label. @@ -29,4 +32,6 @@ def __init__(self, agent_os: AgentOs | None = None) -> None: ) def __call__(self) -> str: - return str(self.agent_os.get_system_info().model_dump_json()) + target_id = self.agent_os.get_current_computer_target_id(report=False) + system_info_json = self.agent_os.get_system_info().model_dump_json() + return f"[Computer '{target_id}']: {system_info_json}" diff --git a/src/askui/tools/computer/keyboard_pressed_tool.py b/src/askui/tools/computer/keyboard_pressed_tool.py index e85fad88..8f4fdb05 100644 --- a/src/askui/tools/computer/keyboard_pressed_tool.py +++ b/src/askui/tools/computer/keyboard_pressed_tool.py @@ -1,13 +1,13 @@ from typing import get_args from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs, ModifierKey, PcKey +from askui.tools.agent_os import ComputerAgentOS, ModifierKey, PcKey class ComputerKeyboardPressedTool(ComputerBaseTool): """Computer Keyboard Pressed Tool""" - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="keyboard_pressed", description="Press and hold a keyboard key.", diff --git a/src/askui/tools/computer/keyboard_release_tool.py b/src/askui/tools/computer/keyboard_release_tool.py index 13603f4b..7a7aedf9 100644 --- a/src/askui/tools/computer/keyboard_release_tool.py +++ b/src/askui/tools/computer/keyboard_release_tool.py @@ -1,13 +1,13 @@ from typing import get_args from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs, ModifierKey, PcKey +from askui.tools.agent_os import ComputerAgentOS, ModifierKey, PcKey class ComputerKeyboardReleaseTool(ComputerBaseTool): """Computer Keyboard Release Tool""" - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="keyboard_release", description="Release a keyboard key.", diff --git a/src/askui/tools/computer/keyboard_tap_tool.py b/src/askui/tools/computer/keyboard_tap_tool.py index 62f48227..64f96956 100644 --- a/src/askui/tools/computer/keyboard_tap_tool.py +++ b/src/askui/tools/computer/keyboard_tap_tool.py @@ -1,13 +1,13 @@ from typing import get_args from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs, ModifierKey, PcKey +from askui.tools.agent_os import ComputerAgentOS, ModifierKey, PcKey class ComputerKeyboardTapTool(ComputerBaseTool): """Computer Keyboard Tap Tool""" - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="keyboard_tap", description="Tap (press and release) a keyboard key.", diff --git a/src/askui/tools/computer/list_agent_os_target_computers_tool.py b/src/askui/tools/computer/list_agent_os_target_computers_tool.py new file mode 100644 index 00000000..bafe9ba9 --- /dev/null +++ b/src/askui/tools/computer/list_agent_os_target_computers_tool.py @@ -0,0 +1,19 @@ +from askui.models.shared import ComputerBaseTool +from askui.tools.agent_os import ComputerAgentOS + + +class ComputerListAgentOsTargetComputersTool(ComputerBaseTool): + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: + super().__init__( + name="list_agent_os_target_computers", + description=""" + List all the registered Agent OS target computers that the agent + can route actions to. Each target computer has a unique + `computer_id` that can be used to switch between them. + """, + agent_os=agent_os, + ) + + def __call__(self) -> str: + target_computer_reprs = self.agent_os.describe_agent_os_target_computers() + return "\n".join(target_computer_reprs) diff --git a/src/askui/tools/computer/list_displays_tool.py b/src/askui/tools/computer/list_displays_tool.py index 68f3c207..e500e262 100644 --- a/src/askui/tools/computer/list_displays_tool.py +++ b/src/askui/tools/computer/list_displays_tool.py @@ -1,19 +1,23 @@ from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS class ComputerListDisplaysTool(ComputerBaseTool): - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="list_displays", description=""" - List all the available displays on the computer. + List all the available displays on the currently active Agent OS + target computer. The result is prefixed with the active target + computer's id so it is clear which computer the displays belong to. """, agent_os=agent_os, ) self.is_cacheable = True def __call__(self) -> str: - return self.agent_os.list_displays().model_dump_json( + target_id = self.agent_os.get_current_computer_target_id(report=False) + displays_json = self.agent_os.list_displays().model_dump_json( exclude={"data": {"__all__": {"size"}}}, ) + return f"[Computer '{target_id}']: {displays_json}" diff --git a/src/askui/tools/computer/mouse_click_tool.py b/src/askui/tools/computer/mouse_click_tool.py index 002f7902..264e27eb 100644 --- a/src/askui/tools/computer/mouse_click_tool.py +++ b/src/askui/tools/computer/mouse_click_tool.py @@ -1,13 +1,13 @@ from typing import get_args from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs, MouseButton +from askui.tools.agent_os import ComputerAgentOS, MouseButton class ComputerMouseClickTool(ComputerBaseTool): """Computer Mouse Click Tool""" - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="mouse_click", description="Click and release the mouse button at the current position.", diff --git a/src/askui/tools/computer/mouse_hold_down_tool.py b/src/askui/tools/computer/mouse_hold_down_tool.py index 9387b117..74f68496 100644 --- a/src/askui/tools/computer/mouse_hold_down_tool.py +++ b/src/askui/tools/computer/mouse_hold_down_tool.py @@ -1,13 +1,13 @@ from typing import get_args from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs, MouseButton +from askui.tools.agent_os import ComputerAgentOS, MouseButton class ComputerMouseHoldDownTool(ComputerBaseTool): """Computer Mouse Hold Down Tool""" - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="mouse_hold_down", description="Hold down the mouse button at the current position.", diff --git a/src/askui/tools/computer/mouse_release_tool.py b/src/askui/tools/computer/mouse_release_tool.py index b8227d9c..39651f22 100644 --- a/src/askui/tools/computer/mouse_release_tool.py +++ b/src/askui/tools/computer/mouse_release_tool.py @@ -1,13 +1,13 @@ from typing import get_args from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs, MouseButton +from askui.tools.agent_os import ComputerAgentOS, MouseButton class ComputerMouseReleaseTool(ComputerBaseTool): """Computer Mouse Release Tool""" - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="mouse_release", description="Release the mouse button at the current position.", diff --git a/src/askui/tools/computer/retrieve_active_display_tool.py b/src/askui/tools/computer/retrieve_active_display_tool.py index 7eef6cfd..853785d7 100644 --- a/src/askui/tools/computer/retrieve_active_display_tool.py +++ b/src/askui/tools/computer/retrieve_active_display_tool.py @@ -1,20 +1,24 @@ from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS class ComputerRetrieveActiveDisplayTool(ComputerBaseTool): - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="retrieve_active_display", description=""" - Retrieve the currently active display on the computer. - The display is used to take screenshots and perform actions. + Retrieve the currently active display on the currently active Agent + OS target computer. The display is used to take screenshots and + perform actions. The result is prefixed with the active target + computer's id so it is clear which computer the display belongs to. """, agent_os=agent_os, ) self.is_cacheable = True def __call__(self) -> str: - return str( - self.agent_os.retrieve_active_display().model_dump_json(exclude={"size"}) + target_id = self.agent_os.get_current_computer_target_id(report=False) + display_json = self.agent_os.retrieve_active_display().model_dump_json( + exclude={"size"} ) + return f"[Computer '{target_id}']: {display_json}" diff --git a/src/askui/tools/computer/screenshot_tool.py b/src/askui/tools/computer/screenshot_tool.py index fcf46553..0928d389 100644 --- a/src/askui/tools/computer/screenshot_tool.py +++ b/src/askui/tools/computer/screenshot_tool.py @@ -10,12 +10,21 @@ class ComputerScreenshotTool(ComputerBaseTool): def __init__(self, agent_os: ComputerAgentOsFacade | None = None) -> None: super().__init__( name="screenshot", - description="Take a screenshot of the current screen.", + description=( + "Take a screenshot of the current screen on the currently active " + "Agent OS target computer. The accompanying message is prefixed " + "with the active target computer's id so it is clear which " + "computer the screenshot was taken on." + ), agent_os=agent_os, required_tags=[ToolTags.SCALED_AGENT_OS.value], ) self.is_cacheable = True def __call__(self) -> tuple[str, Image.Image]: + target_id = self.agent_os.get_current_computer_target_id(report=False) screenshot = self.agent_os.screenshot() - return "Screenshot was taken.", screenshot + return ( + f"[Computer '{target_id}']: Screenshot was taken.", + screenshot, + ) diff --git a/src/askui/tools/computer/set_active_display_tool.py b/src/askui/tools/computer/set_active_display_tool.py index 94719dec..bec7ba89 100644 --- a/src/askui/tools/computer/set_active_display_tool.py +++ b/src/askui/tools/computer/set_active_display_tool.py @@ -1,9 +1,9 @@ from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS class ComputerSetActiveDisplayTool(ComputerBaseTool): - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="set_active_display", description=""" diff --git a/src/askui/tools/computer/switch_agent_os_target_computer_tool.py b/src/askui/tools/computer/switch_agent_os_target_computer_tool.py new file mode 100644 index 00000000..ded871ef --- /dev/null +++ b/src/askui/tools/computer/switch_agent_os_target_computer_tool.py @@ -0,0 +1,28 @@ +from askui.models.shared import ComputerBaseTool +from askui.tools.agent_os import ComputerAgentOS + + +class ComputerSwitchAgentOsTargetComputerTool(ComputerBaseTool): + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: + super().__init__( + name="switch_agent_os_target_computer", + description=""" + Switch the active Agent OS target computer by its `computer_id`. + Future agent-os actions are routed to the newly selected target + computer. Use `list_agent_os_target_computers` to discover the + available computer ids. + """, + input_schema={ + "type": "object", + "properties": { + "computer_id": { + "type": "string", + }, + }, + "required": ["computer_id"], + }, + agent_os=agent_os, + ) + + def __call__(self, computer_id: str) -> str: + return repr(self.agent_os.switch_agent_os_target_computer(computer_id)) diff --git a/src/askui/tools/computer/type_tool.py b/src/askui/tools/computer/type_tool.py index c2b11741..7ac3343d 100644 --- a/src/askui/tools/computer/type_tool.py +++ b/src/askui/tools/computer/type_tool.py @@ -1,11 +1,11 @@ from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS class ComputerTypeTool(ComputerBaseTool): """Computer Type Tool""" - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="type", description="Type text on the computer.", diff --git a/src/askui/tools/computer_agent_os_facade.py b/src/askui/tools/computer_agent_os_facade.py index 28a1a8c5..14840593 100644 --- a/src/askui/tools/computer_agent_os_facade.py +++ b/src/askui/tools/computer_agent_os_facade.py @@ -1,10 +1,13 @@ +from collections.abc import Iterator +from contextlib import contextmanager from typing import TYPE_CHECKING from PIL import Image +from typing_extensions import Self from askui.models.shared.tool_tags import ToolTags from askui.tools.agent_os import ( - AgentOs, + ComputerAgentOS, Coordinate, Display, DisplaySize, @@ -18,6 +21,9 @@ from askui.utils.image_utils import scale_coordinates, scale_image_to_fit if TYPE_CHECKING: + from askui.tools.askui.agent_os_target_computer import ( + ComputerTarget, + ) from askui.tools.askui.askui_ui_controller_grpc.generated import ( Controller_V1_pb2 as controller_v1_pbs, ) @@ -28,15 +34,15 @@ ) -class ComputerAgentOsFacade(AgentOs): +class ComputerAgentOsFacade(ComputerAgentOS): """ - Facade for AgentOs that adds coordinate scaling functionality. + Facade for ComputerAgentOS that adds coordinate scaling functionality. This class is used to scale the coordinates to the target resolution and back to the real screen resolution. """ - def __init__(self, agent_os: AgentOs) -> None: + def __init__(self, agent_os: ComputerAgentOS) -> None: self._agent_os = agent_os self._target_resolution: tuple[int, int] = (1024, 768) self._real_screen_resolution: DisplaySize | None = None @@ -266,6 +272,39 @@ def set_window_in_focus(self, process_id: int, window_id: int) -> None: """ self._agent_os.set_window_in_focus(process_id, window_id) + def add_agent_os_target_computer( + self, agent_os_target_computer: "ComputerTarget" + ) -> "ComputerTarget": + return self._agent_os.add_agent_os_target_computer(agent_os_target_computer) + + def reset_agent_os_target_computers( + self, + agent_os_target_computers: "list[ComputerTarget] | None" = None, + ) -> None: + self._agent_os.reset_agent_os_target_computers(agent_os_target_computers) + + def describe_agent_os_target_computers(self) -> list[str]: + return self._agent_os.describe_agent_os_target_computers() + + def get_current_computer_target_id(self, report: bool = True) -> str: + return self._agent_os.get_current_computer_target_id(report=report) + + def switch_agent_os_target_computer(self, computer_id: str) -> "ComputerTarget": + agent_os_target_computer = self._agent_os.switch_agent_os_target_computer( + computer_id + ) + self._real_screen_resolution = None + return agent_os_target_computer + + @contextmanager + def temporary_select(self, computer_id: str) -> Iterator[Self]: + with self._agent_os.temporary_select(computer_id): + self._real_screen_resolution = None + try: + yield self + finally: + self._real_screen_resolution = None + def get_file_names(self, absolute_directory_path: str) -> list[str]: """ List file names in an absolute directory on the automation target. diff --git a/src/askui/tools/playwright/agent_os.py b/src/askui/tools/playwright/agent_os.py index 6381be37..93d6d500 100644 --- a/src/askui/tools/playwright/agent_os.py +++ b/src/askui/tools/playwright/agent_os.py @@ -19,11 +19,18 @@ from askui.reporting import NULL_REPORTER, Reporter from askui.utils.annotated_image import AnnotatedImage -from ..agent_os import AgentOs, Display, DisplaySize, InputEvent, ModifierKey, PcKey +from ..agent_os import ( + ComputerAgentOS, + Display, + DisplaySize, + InputEvent, + ModifierKey, + PcKey, +) -class PlaywrightAgentOs(AgentOs): - """Playwright-based implementation of `AgentOs`. +class PlaywrightAgentOs(ComputerAgentOS): + """Playwright-based implementation of `ComputerAgentOS`. This implementation uses Playwright's Python SDK to control browser automation and simulate user interactions. It provides mouse control, keyboard input, diff --git a/src/askui/tools/store/__init__.py b/src/askui/tools/store/__init__.py index 2a05056d..eba8bb31 100644 --- a/src/askui/tools/store/__init__.py +++ b/src/askui/tools/store/__init__.py @@ -3,7 +3,7 @@ Tools are organized by category: - `android`: Tools specific to Android agents (require AndroidAgentOs) - `computer`: Tools specific to Computer/Desktop agents (require ComputerAgentOsFacade) -- `universal`: Tools that work with any agent type (don't require AgentOs) +- `universal`: Tools that work with any agent type (don't require ComputerAgentOS) Example: ```python diff --git a/src/askui/tools/store/computer/__init__.py b/src/askui/tools/store/computer/__init__.py index fb7f5427..f2fc0ca2 100644 --- a/src/askui/tools/store/computer/__init__.py +++ b/src/askui/tools/store/computer/__init__.py @@ -1,6 +1,6 @@ """Computer-specific tools. -These tools require AgentOs (or ComputerAgentOsFacade) and are designed +These tools require ComputerAgentOS (or ComputerAgentOsFacade) and are designed for use with VisionAgent. """ diff --git a/src/askui/tools/store/computer/experimental/get_file.py b/src/askui/tools/store/computer/experimental/get_file.py index b7bf5c93..a7e47010 100644 --- a/src/askui/tools/store/computer/experimental/get_file.py +++ b/src/askui/tools/store/computer/experimental/get_file.py @@ -1,7 +1,7 @@ from PIL import Image from askui.models.shared import ComputerBaseTool, ToolTags -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS class ComputerGetFileTool(ComputerBaseTool): @@ -24,7 +24,7 @@ class ComputerGetFileTool(ComputerBaseTool): ``` """ - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="get_file_tool", description=( diff --git a/src/askui/tools/store/computer/experimental/get_file_names.py b/src/askui/tools/store/computer/experimental/get_file_names.py index 5002b0eb..643820fb 100644 --- a/src/askui/tools/store/computer/experimental/get_file_names.py +++ b/src/askui/tools/store/computer/experimental/get_file_names.py @@ -1,5 +1,5 @@ from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS class ComputerGetFileNamesTool(ComputerBaseTool): @@ -23,7 +23,7 @@ class ComputerGetFileNamesTool(ComputerBaseTool): ``` """ - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="get_file_names_tool", description=( diff --git a/src/askui/tools/store/computer/experimental/remove_virtual_displays.py b/src/askui/tools/store/computer/experimental/remove_virtual_displays.py index 1a7b2000..1a952b69 100644 --- a/src/askui/tools/store/computer/experimental/remove_virtual_displays.py +++ b/src/askui/tools/store/computer/experimental/remove_virtual_displays.py @@ -1,5 +1,5 @@ from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS class ComputerRemoveVirtualDisplaysTool(ComputerBaseTool): @@ -24,7 +24,7 @@ class ComputerRemoveVirtualDisplaysTool(ComputerBaseTool): ``` """ - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="remove_virtual_displays_tool", description=( diff --git a/src/askui/tools/store/computer/experimental/window_management/add_window_as_virtual_display.py b/src/askui/tools/store/computer/experimental/window_management/add_window_as_virtual_display.py index 7750b333..58e6b73e 100644 --- a/src/askui/tools/store/computer/experimental/window_management/add_window_as_virtual_display.py +++ b/src/askui/tools/store/computer/experimental/window_management/add_window_as_virtual_display.py @@ -1,5 +1,5 @@ from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS class ComputerAddWindowAsVirtualDisplayTool(ComputerBaseTool): @@ -9,7 +9,7 @@ class ComputerAddWindowAsVirtualDisplayTool(ComputerBaseTool): for UI automation tasks. """ - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="add_window_as_virtual_display_tool", description=""" diff --git a/src/askui/tools/store/computer/experimental/window_management/list_process.py b/src/askui/tools/store/computer/experimental/window_management/list_process.py index 775f5a8c..c3141370 100644 --- a/src/askui/tools/store/computer/experimental/window_management/list_process.py +++ b/src/askui/tools/store/computer/experimental/window_management/list_process.py @@ -1,5 +1,5 @@ from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS class ComputerListProcessTool(ComputerBaseTool): @@ -9,7 +9,7 @@ class ComputerListProcessTool(ComputerBaseTool): applications and their process IDs. """ - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="list_process_tool", description=""" diff --git a/src/askui/tools/store/computer/experimental/window_management/list_process_windows.py b/src/askui/tools/store/computer/experimental/window_management/list_process_windows.py index 850c8dcc..00edebff 100644 --- a/src/askui/tools/store/computer/experimental/window_management/list_process_windows.py +++ b/src/askui/tools/store/computer/experimental/window_management/list_process_windows.py @@ -1,5 +1,5 @@ from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS class ComputerListProcessWindowsTool(ComputerBaseTool): @@ -9,7 +9,7 @@ class ComputerListProcessWindowsTool(ComputerBaseTool): list_process_tool. """ - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="list_process_windows_tool", description=""" diff --git a/src/askui/tools/store/computer/experimental/window_management/set_process_in_focus.py b/src/askui/tools/store/computer/experimental/window_management/set_process_in_focus.py index 2e27550f..7e7fff84 100644 --- a/src/askui/tools/store/computer/experimental/window_management/set_process_in_focus.py +++ b/src/askui/tools/store/computer/experimental/window_management/set_process_in_focus.py @@ -1,5 +1,5 @@ from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS class ComputerSetProcessInFocusTool(ComputerBaseTool): @@ -9,7 +9,7 @@ class ComputerSetProcessInFocusTool(ComputerBaseTool): operating system or the process determine which window should be focused. """ - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="set_process_in_focus_tool", description=""" diff --git a/src/askui/tools/store/computer/experimental/window_management/set_window_in_focus.py b/src/askui/tools/store/computer/experimental/window_management/set_window_in_focus.py index e597a78c..41573f5a 100644 --- a/src/askui/tools/store/computer/experimental/window_management/set_window_in_focus.py +++ b/src/askui/tools/store/computer/experimental/window_management/set_window_in_focus.py @@ -1,5 +1,5 @@ from askui.models.shared import ComputerBaseTool -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS class ComputerSetWindowInFocusTool(ComputerBaseTool): @@ -9,7 +9,7 @@ class ComputerSetWindowInFocusTool(ComputerBaseTool): before performing automation tasks. """ - def __init__(self, agent_os: AgentOs | None = None) -> None: + def __init__(self, agent_os: ComputerAgentOS | None = None) -> None: super().__init__( name="set_window_in_focus_tool", description=""" diff --git a/src/askui/tools/toolbox.py b/src/askui/tools/toolbox.py index 3f954fe4..a7362810 100644 --- a/src/askui/tools/toolbox.py +++ b/src/askui/tools/toolbox.py @@ -3,7 +3,7 @@ import httpx import pyperclip -from askui.tools.agent_os import AgentOs +from askui.tools.agent_os import ComputerAgentOS class AgentToolbox: @@ -13,16 +13,18 @@ class AgentToolbox: Provides access to OS-level actions, clipboard, web browser, HTTP client etc. Args: - agent_os (AgentOs): The OS interface implementation to use for agent actions. + agent_os (ComputerAgentOS): The OS interface implementation to use for + agent actions. Attributes: webbrowser: Python's built-in `webbrowser` module for opening URLs. clipboard: `pyperclip` module for clipboard access. - agent_os (AgentOs): The OS interface for mouse, keyboard, and screen actions. + agent_os (ComputerAgentOS): The OS interface for mouse, keyboard, and + screen actions. httpx: HTTPX client for HTTP requests. """ - def __init__(self, agent_os: AgentOs): + def __init__(self, agent_os: ComputerAgentOS): self.webbrowser = webbrowser self.clipboard = pyperclip self.os = agent_os diff --git a/tests/conftest.py b/tests/conftest.py index 5eb112db..f2a792af 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,8 +5,7 @@ from PIL import Image from pytest_mock import MockerFixture -from askui.tools.agent_os import AgentOs, Display, DisplaySize -from askui.tools.toolbox import AgentToolbox +from askui.tools.agent_os import ComputerAgentOS, Display, DisplaySize @pytest.fixture @@ -84,22 +83,28 @@ def path_fixtures_github_com__icon(path_fixtures_images: pathlib.Path) -> pathli @pytest.fixture -def agent_os_mock(mocker: MockerFixture) -> AgentOs: +def agent_os_mock(mocker: MockerFixture) -> ComputerAgentOS: """Fixture providing a mock agent os.""" - mock = mocker.MagicMock(spec=AgentOs) + mock = mocker.MagicMock(spec=ComputerAgentOS) mock.retrieve_active_display.return_value = Display( id=1, name="Display 1", size=DisplaySize(width=100, height=100), ) mock.screenshot.return_value = Image.new("RGB", (100, 100), color="white") - return cast("AgentOs", mock) + return cast("ComputerAgentOS", mock) @pytest.fixture -def agent_toolbox_mock(agent_os_mock: AgentOs) -> AgentToolbox: - """Fixture providing a mock agent toolbox.""" - return AgentToolbox(agent_os=agent_os_mock) +def agent_os_mock_patch( + mocker: MockerFixture, agent_os_mock: ComputerAgentOS +) -> ComputerAgentOS: + """Patches `MultiComputerTargetAgentOS` so `ComputerAgent` uses `agent_os_mock`.""" + mocker.patch( + "askui.computer_agent.MultiComputerTargetAgentOS", + return_value=agent_os_mock, + ) + return agent_os_mock @pytest.fixture(autouse=True) diff --git a/tests/e2e/agent/conftest.py b/tests/e2e/agent/conftest.py index 19bdbaa6..b502949e 100644 --- a/tests/e2e/agent/conftest.py +++ b/tests/e2e/agent/conftest.py @@ -27,7 +27,7 @@ from askui.models.shared.settings import LocateSettings from askui.models.types.geometry import PointList from askui.reporting import Reporter, SimpleHtmlReporter -from askui.tools.toolbox import AgentToolbox +from askui.tools.agent_os import ComputerAgentOS from askui.utils.image_utils import ImageSource @@ -98,7 +98,7 @@ def combo_locate_model(path_fixtures: pathlib.Path) -> LocateModel: @pytest.fixture def agent_with_pta_model( pta_locate_model: LocateModel, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, ) -> Generator[ComputerAgent, None, None]: with ComputerAgent( @@ -106,7 +106,6 @@ def agent_with_pta_model( detection_provider=_LocateModelDetectionProvider(pta_locate_model) ), reporters=[simple_html_reporter], - tools=agent_toolbox_mock, ) as agent: yield agent @@ -114,7 +113,7 @@ def agent_with_pta_model( @pytest.fixture def agent_with_ocr_model( ocr_locate_model: LocateModel, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, ) -> Generator[ComputerAgent, None, None]: with ComputerAgent( @@ -122,7 +121,6 @@ def agent_with_ocr_model( detection_provider=_LocateModelDetectionProvider(ocr_locate_model) ), reporters=[simple_html_reporter], - tools=agent_toolbox_mock, ) as agent: yield agent @@ -130,7 +128,7 @@ def agent_with_ocr_model( @pytest.fixture def agent_with_ai_element_model( ai_element_locate_model: LocateModel, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, ) -> Generator[ComputerAgent, None, None]: with ComputerAgent( @@ -138,7 +136,6 @@ def agent_with_ai_element_model( detection_provider=_LocateModelDetectionProvider(ai_element_locate_model) ), reporters=[simple_html_reporter], - tools=agent_toolbox_mock, ) as agent: yield agent @@ -146,7 +143,7 @@ def agent_with_ai_element_model( @pytest.fixture def agent_with_combo_model( combo_locate_model: LocateModel, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, ) -> Generator[ComputerAgent, None, None]: with ComputerAgent( @@ -154,19 +151,17 @@ def agent_with_combo_model( detection_provider=_LocateModelDetectionProvider(combo_locate_model) ), reporters=[simple_html_reporter], - tools=agent_toolbox_mock, ) as agent: yield agent @pytest.fixture def vision_agent( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, ) -> Generator[ComputerAgent, None, None]: """Fixture providing a ComputerAgent instance.""" with ComputerAgent( reporters=[simple_html_reporter], - tools=agent_toolbox_mock, ) as agent: yield agent diff --git a/tests/e2e/agent/test_get.py b/tests/e2e/agent/test_get.py index bae0d4e8..b34aa67e 100644 --- a/tests/e2e/agent/test_get.py +++ b/tests/e2e/agent/test_get.py @@ -19,7 +19,7 @@ from askui.models.shared.settings import GetSettings from askui.models.types.response_schemas import ResponseSchema from askui.reporting import Reporter -from askui.tools.toolbox import AgentToolbox +from askui.tools.agent_os import ComputerAgentOS from askui.utils.source_utils import Source @@ -97,7 +97,7 @@ class BrowserContextResponse(ResponseSchemaBase): ) def test_get( vision_agent: ComputerAgent, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, get_model: GetModel | None, @@ -112,7 +112,6 @@ def test_get( settings=AgentSettings( image_qa_provider=_GetModelImageQAProvider(get_model) ), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: url = agent.get( @@ -142,14 +141,13 @@ def test_get( ], ) def test_get_with_pdf_with_gemini_model( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, get_model: GetModel, path_fixtures_dummy_pdf: pathlib.Path, ) -> None: with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( @@ -180,7 +178,7 @@ def test_get_with_pdf_with_gemini_model( ], ) def test_get_with_pdf_too_large( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, get_model: GetModel, path_fixtures_dummy_pdf: pathlib.Path, @@ -189,7 +187,6 @@ def test_get_with_pdf_too_large( mocker.patch("askui.models.askui.get_model.MAX_FILE_SIZE_BYTES", 1) with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: with pytest.raises(ValueError, match="PDF file size exceeds the limit"): @@ -232,14 +229,13 @@ def test_get_with_pdf_too_large_with_default_model( ], ) def test_get_with_xlsx_with_gemini_model( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, get_model: GetModel, path_fixtures_dummy_excel: pathlib.Path, ) -> None: with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( @@ -279,14 +275,13 @@ class SalaryResponse(ResponseSchemaBase): ], ) def test_get_with_xlsx_with_gemini_model_with_response_schema( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, get_model: GetModel, path_fixtures_dummy_excel: pathlib.Path, ) -> None: with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( @@ -325,7 +320,7 @@ def test_get_with_docs_with_default_model( def test_get_with_fallback_model( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, ) -> None: @@ -338,7 +333,6 @@ def test_get_with_fallback_model( image_qa_provider=_GetModelImageQAProvider(askui_get_model) ), reporters=[simple_html_reporter], - tools=agent_toolbox_mock, ) as agent: url = agent.get( "What is the current url shown in the url bar?", @@ -393,7 +387,7 @@ def test_get_with_response_schema_with_default_value( ) def test_get_with_response_schema( vision_agent: ComputerAgent, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, get_model: GetModel | None, @@ -409,7 +403,6 @@ def test_get_with_response_schema( settings=AgentSettings( image_qa_provider=_GetModelImageQAProvider(get_model) ), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( @@ -434,14 +427,13 @@ def test_get_with_response_schema( ], ) def test_get_with_nested_and_inherited_response_schema( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, get_model: GetModel, ) -> None: with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( @@ -473,14 +465,13 @@ class LinkedListNode(ResponseSchemaBase): ], ) def test_get_with_recursive_response_schema( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, get_model: GetModel, ) -> None: with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: with pytest.raises( @@ -507,14 +498,13 @@ def test_get_with_recursive_response_schema( ], ) def test_get_with_string_schema( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, get_model: GetModel, ) -> None: with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( @@ -545,14 +535,13 @@ def test_get_with_string_schema( ], ) def test_get_with_boolean_schema( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, get_model: GetModel, ) -> None: with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( @@ -577,14 +566,13 @@ def test_get_with_boolean_schema( ], ) def test_get_with_integer_schema( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, get_model: GetModel, ) -> None: with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( @@ -609,14 +597,13 @@ def test_get_with_integer_schema( ], ) def test_get_with_float_schema( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, get_model: GetModel, ) -> None: with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( @@ -641,14 +628,13 @@ def test_get_with_float_schema( ], ) def test_get_returns_str_when_no_schema_specified( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, get_model: GetModel, ) -> None: with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( @@ -675,14 +661,13 @@ class Basis(ResponseSchemaBase): ], ) def test_get_with_basis_schema( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, get_model: GetModel, ) -> None: with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( @@ -715,14 +700,13 @@ class BasisWithNestedRootModel(ResponseSchemaBase): ], ) def test_get_with_nested_root_model( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, get_model: GetModel, ) -> None: with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( @@ -774,7 +758,7 @@ class PageDom(ResponseSchemaBase): ], ) def test_get_with_deeply_nested_response_schema_with_model_that_does_not_support_recursion( - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, get_model: GetModel, @@ -786,7 +770,6 @@ def test_get_with_deeply_nested_response_schema_with_model_that_does_not_support """ with ComputerAgent( settings=AgentSettings(image_qa_provider=_GetModelImageQAProvider(get_model)), - tools=agent_toolbox_mock, reporters=[simple_html_reporter], ) as agent: response = agent.get( diff --git a/tests/e2e/test_telemetry.py b/tests/e2e/test_telemetry.py index 25b9202a..70539277 100644 --- a/tests/e2e/test_telemetry.py +++ b/tests/e2e/test_telemetry.py @@ -5,13 +5,13 @@ from askui import locators as loc from askui.container import telemetry from askui.telemetry.processors import Segment, SegmentSettings -from askui.tools.toolbox import AgentToolbox +from askui.tools.agent_os import ComputerAgentOS @pytest.mark.timeout(60) def test_telemetry_with_nonexistent_domain_should_not_block( github_login_screenshot: Image.Image, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 ) -> None: telemetry.set_processors( [ @@ -23,6 +23,6 @@ def test_telemetry_with_nonexistent_domain_should_not_block( ) ] ) - with ComputerAgent(tools=agent_toolbox_mock) as agent: + with ComputerAgent() as agent: agent.locate(loc.Text(), screenshot=github_login_screenshot) assert True diff --git a/tests/e2e/tools/askui/test_askui_controller.py b/tests/e2e/tools/askui/test_askui_controller.py index bca9e591..b0cff359 100644 --- a/tests/e2e/tools/askui/test_askui_controller.py +++ b/tests/e2e/tools/askui/test_askui_controller.py @@ -7,33 +7,33 @@ from askui.reporting import CompositeReporter from askui.tools.agent_os import Coordinate +from askui.tools.askui import LocalComputerTarget from askui.tools.askui.askui_controller import ( - AskUiControllerClient, - AskUiControllerServer, + MultiComputerTargetAgentOS, RenderObjectStyle, ) from askui.tools.askui.askui_controller_settings import AskUiControllerSettings @pytest.fixture -def controller_server() -> AskUiControllerServer: - return AskUiControllerServer( +def agent_os_target_computer() -> LocalComputerTarget: + return LocalComputerTarget( settings=AskUiControllerSettings(controller_args="--showOverlay true") ) @pytest.fixture def controller_client( - controller_server: AskUiControllerServer, -) -> AskUiControllerClient: - return AskUiControllerClient( + agent_os_target_computer: LocalComputerTarget, +) -> MultiComputerTargetAgentOS: + return MultiComputerTargetAgentOS( reporter=CompositeReporter(), display=1, - controller_server=controller_server, + agent_os_target_computers=[agent_os_target_computer], ) -def test_actions(controller_client: AskUiControllerClient) -> None: +def test_actions(controller_client: MultiComputerTargetAgentOS) -> None: with controller_client: controller_client.screenshot() controller_client.mouse_move(0, 0) @@ -42,14 +42,15 @@ def test_actions(controller_client: AskUiControllerClient) -> None: @pytest.mark.parametrize("button", ["left", "right", "middle"]) def test_click_all_buttons( - controller_client: AskUiControllerClient, button: Literal["left", "middle", "right"] + controller_client: MultiComputerTargetAgentOS, + button: Literal["left", "middle", "right"], ) -> None: """Test clicking each mouse button""" with controller_client: controller_client.click(button=button) -def test_mouse_multiple_clicks(controller_client: AskUiControllerClient) -> None: +def test_mouse_multiple_clicks(controller_client: MultiComputerTargetAgentOS) -> None: """Test click count parameter""" with controller_client: controller_client.click(count=3) @@ -57,7 +58,8 @@ def test_mouse_multiple_clicks(controller_client: AskUiControllerClient) -> None @pytest.mark.parametrize("button", ["left", "right", "middle"]) def test_mouse_press_hold_release( - controller_client: AskUiControllerClient, button: Literal["left", "middle", "right"] + controller_client: MultiComputerTargetAgentOS, + button: Literal["left", "middle", "right"], ) -> None: """Test mouse_down() and mouse_up() operations""" with controller_client: @@ -67,14 +69,14 @@ def test_mouse_press_hold_release( @pytest.mark.parametrize("x,y", [(0, 0), (100, 100), (500, 300)]) def test_mouse_move_coordinates( - controller_client: AskUiControllerClient, x: int, y: int + controller_client: MultiComputerTargetAgentOS, x: int, y: int ) -> None: """Test mouse movement to various coordinates""" with controller_client: controller_client.mouse_move(x, y) -def test_mouse_scroll_directions(controller_client: AskUiControllerClient) -> None: +def test_mouse_scroll_directions(controller_client: MultiComputerTargetAgentOS) -> None: """Test horizontal and vertical scrolling""" with controller_client: controller_client.mouse_scroll(0, 5) # Vertical scroll @@ -82,54 +84,58 @@ def test_mouse_scroll_directions(controller_client: AskUiControllerClient) -> No controller_client.mouse_scroll(3, -2) # Combined scroll -def test_type_text_basic(controller_client: AskUiControllerClient) -> None: +def test_type_text_basic(controller_client: MultiComputerTargetAgentOS) -> None: """Test typing simple text""" with controller_client: controller_client.type("Hello World") -def test_type_text_with_speed(controller_client: AskUiControllerClient) -> None: +def test_type_text_with_speed(controller_client: MultiComputerTargetAgentOS) -> None: """Test typing with custom speed""" with controller_client: controller_client.type("Fast typing", typing_speed=100) controller_client.type("Slow typing", typing_speed=10) -def test_keyboard_tap_with_modifiers(controller_client: AskUiControllerClient) -> None: +def test_keyboard_tap_with_modifiers( + controller_client: MultiComputerTargetAgentOS, +) -> None: """Test key combination like Ctrl+C""" with controller_client: controller_client.keyboard_tap("c", modifier_keys=["command"]) controller_client.keyboard_tap("v", modifier_keys=["command"]) -def test_keyboard_tap_multiple(controller_client: AskUiControllerClient) -> None: +def test_keyboard_tap_multiple(controller_client: MultiComputerTargetAgentOS) -> None: """Test multiple key taps""" with controller_client: controller_client.keyboard_tap("escape", count=3) -def test_keyboard_press_hold_release(controller_client: AskUiControllerClient) -> None: +def test_keyboard_press_hold_release( + controller_client: MultiComputerTargetAgentOS, +) -> None: """Test keyboard_pressed() and keyboard_release()""" with controller_client: controller_client.keyboard_pressed("escape") controller_client.keyboard_release("escape") -def test_screenshot_basic(controller_client: AskUiControllerClient) -> None: +def test_screenshot_basic(controller_client: MultiComputerTargetAgentOS) -> None: """Test taking screenshots with different report settings""" with controller_client: image_with_report = controller_client.screenshot() assert isinstance(image_with_report, Image.Image) -def test_get_display_information(controller_client: AskUiControllerClient) -> None: +def test_get_display_information(controller_client: MultiComputerTargetAgentOS) -> None: """Test retrieving display information""" with controller_client: display_info = controller_client.list_displays() assert display_info is not None -def test_get_process_list(controller_client: AskUiControllerClient) -> None: +def test_get_process_list(controller_client: MultiComputerTargetAgentOS) -> None: """Test retrieving running processes""" with controller_client: processes = controller_client.get_process_list() @@ -139,38 +145,40 @@ def test_get_process_list(controller_client: AskUiControllerClient) -> None: assert processes_extended is not None -def test_get_automation_target_list(controller_client: AskUiControllerClient) -> None: +def test_get_automation_target_list( + controller_client: MultiComputerTargetAgentOS, +) -> None: """Test retrieving automation targets""" with controller_client: targets = controller_client.get_automation_target_list() assert targets is not None -def test_set_display(controller_client: AskUiControllerClient) -> None: +def test_set_display(controller_client: MultiComputerTargetAgentOS) -> None: """Test changing active display""" with controller_client: controller_client.set_display(1) -def test_set_mouse_delay(controller_client: AskUiControllerClient) -> None: +def test_set_mouse_delay(controller_client: MultiComputerTargetAgentOS) -> None: """Test configuring mouse action delay""" with controller_client: controller_client.set_mouse_delay(100) -def test_set_keyboard_delay(controller_client: AskUiControllerClient) -> None: +def test_set_keyboard_delay(controller_client: MultiComputerTargetAgentOS) -> None: """Test configuring keyboard action delay""" with controller_client: controller_client.set_keyboard_delay(50) -def test_run_command(controller_client: AskUiControllerClient) -> None: +def test_run_command(controller_client: MultiComputerTargetAgentOS) -> None: """Test executing shell commands""" with controller_client: controller_client.run_command("echo test", 0) -def test_get_action_count(controller_client: AskUiControllerClient) -> None: +def test_get_action_count(controller_client: MultiComputerTargetAgentOS) -> None: """Test getting count of batched actions""" with controller_client: count = controller_client.get_action_count() @@ -179,7 +187,7 @@ def test_get_action_count(controller_client: AskUiControllerClient) -> None: def test_operations_before_connect() -> None: """Test calling methods before connect() raises appropriate errors""" - client = AskUiControllerClient(reporter=CompositeReporter(), display=1) + client = MultiComputerTargetAgentOS(reporter=CompositeReporter(), display=1) with pytest.raises( AssertionError, match="Stub is not initialized. Call `connect()` first." @@ -187,19 +195,19 @@ def test_operations_before_connect() -> None: client.screenshot() -def test_invalid_coordinates(controller_client: AskUiControllerClient) -> None: +def test_invalid_coordinates(controller_client: MultiComputerTargetAgentOS) -> None: """Test mouse operations with potentially problematic coordinates""" with controller_client: controller_client.mouse_move(-1, -1) controller_client.mouse_move(9999, 9999) -def test_set_mouse_position(controller_client: AskUiControllerClient) -> None: +def test_set_mouse_position(controller_client: MultiComputerTargetAgentOS) -> None: with controller_client: controller_client.set_mouse_position(100, 100) -def test_get_mouse_position(controller_client: AskUiControllerClient) -> None: +def test_get_mouse_position(controller_client: MultiComputerTargetAgentOS) -> None: """Test getting current mouse coordinates""" with controller_client: position = controller_client.get_mouse_position() @@ -208,7 +216,7 @@ def test_get_mouse_position(controller_client: AskUiControllerClient) -> None: assert hasattr(position, "y") -def test_render_quad(controller_client: AskUiControllerClient) -> None: +def test_render_quad(controller_client: MultiComputerTargetAgentOS) -> None: """Test adding a quad render object to the display""" with controller_client: style = RenderObjectStyle( @@ -225,7 +233,7 @@ def test_render_quad(controller_client: AskUiControllerClient) -> None: assert response is not None -def test_render_line(controller_client: AskUiControllerClient) -> None: +def test_render_line(controller_client: MultiComputerTargetAgentOS) -> None: """Test rendering a line object to the display""" with controller_client: style = RenderObjectStyle( @@ -240,7 +248,7 @@ def test_render_line(controller_client: AskUiControllerClient) -> None: def test_render_image( - controller_client: AskUiControllerClient, + controller_client: MultiComputerTargetAgentOS, askui_logo_bmp: Image.Image, ) -> None: """Test rendering an image object to the display""" @@ -262,7 +270,7 @@ def test_render_image( assert response is not None -def test_render_text(controller_client: AskUiControllerClient) -> None: +def test_render_text(controller_client: MultiComputerTargetAgentOS) -> None: """Test rendering a text object to the display""" with controller_client: style = RenderObjectStyle( @@ -279,7 +287,7 @@ def test_render_text(controller_client: AskUiControllerClient) -> None: assert response is not None -def test_update_render_object(controller_client: AskUiControllerClient) -> None: +def test_update_render_object(controller_client: MultiComputerTargetAgentOS) -> None: """Test updating an existing render object""" with controller_client: style = RenderObjectStyle( @@ -306,7 +314,7 @@ def test_update_render_object(controller_client: AskUiControllerClient) -> None: controller_client.update_render_object(object_id, update_style) -def test_update_text_object(controller_client: AskUiControllerClient) -> None: +def test_update_text_object(controller_client: MultiComputerTargetAgentOS) -> None: """Test updating an existing render object""" with controller_client: style = RenderObjectStyle( @@ -334,7 +342,7 @@ def test_update_text_object(controller_client: AskUiControllerClient) -> None: controller_client.update_render_object(object_id, update_style) -def test_delete_render_object(controller_client: AskUiControllerClient) -> None: +def test_delete_render_object(controller_client: MultiComputerTargetAgentOS) -> None: """Test deleting an existing render object""" with controller_client: style = RenderObjectStyle( @@ -350,7 +358,7 @@ def test_delete_render_object(controller_client: AskUiControllerClient) -> None: controller_client.delete_render_object(quad_id) -def test_clear_render_objects(controller_client: AskUiControllerClient) -> None: +def test_clear_render_objects(controller_client: MultiComputerTargetAgentOS) -> None: """Test clearing all render objects""" with controller_client: style1 = RenderObjectStyle( @@ -374,7 +382,7 @@ def test_clear_render_objects(controller_client: AskUiControllerClient) -> None: controller_client.clear_render_objects() -def test_get_system_info(controller_client: AskUiControllerClient) -> None: +def test_get_system_info(controller_client: MultiComputerTargetAgentOS) -> None: """Test getting system information""" with controller_client: system_info = controller_client.get_system_info() @@ -385,7 +393,7 @@ def test_get_system_info(controller_client: AskUiControllerClient) -> None: assert system_info.architecture is not None -def test_get_active_process(controller_client: AskUiControllerClient) -> None: +def test_get_active_process(controller_client: MultiComputerTargetAgentOS) -> None: with controller_client: active_process = controller_client.get_active_process() @@ -395,7 +403,7 @@ def test_get_active_process(controller_client: AskUiControllerClient) -> None: assert active_process.process.id is not None -def test_set_active_process(controller_client: AskUiControllerClient) -> None: +def test_set_active_process(controller_client: MultiComputerTargetAgentOS) -> None: """Test setting the active process""" with controller_client: controller_client.set_active_process(1062) @@ -404,7 +412,7 @@ def test_set_active_process(controller_client: AskUiControllerClient) -> None: assert active_process.process is not None -def test_get_active_window(controller_client: AskUiControllerClient) -> None: +def test_get_active_window(controller_client: MultiComputerTargetAgentOS) -> None: """Test getting the active window""" with controller_client: active_window = controller_client.get_active_window() diff --git a/tests/integration/agent/test_retry.py b/tests/integration/agent/test_retry.py index 8f08d51a..bd1d453e 100644 --- a/tests/integration/agent/test_retry.py +++ b/tests/integration/agent/test_retry.py @@ -10,7 +10,7 @@ from askui.models.exceptions import ElementNotFoundError, ModelNotFoundError from askui.models.shared.settings import LocateSettings from askui.models.types.geometry import PointList -from askui.tools.toolbox import AgentToolbox +from askui.tools.agent_os import ComputerAgentOS from askui.utils.image_utils import ImageSource @@ -58,21 +58,21 @@ def always_failing_provider() -> FailingDetectionProvider: @pytest.fixture def agent_with_retry( - failing_provider: FailingDetectionProvider, agent_toolbox_mock: AgentToolbox + failing_provider: FailingDetectionProvider, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 ) -> ComputerAgent: return ComputerAgent( settings=AgentSettings(detection_provider=failing_provider), - tools=agent_toolbox_mock, ) @pytest.fixture def agent_with_retry_on_multiple_exceptions( - failing_provider: FailingDetectionProvider, agent_toolbox_mock: AgentToolbox + failing_provider: FailingDetectionProvider, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 ) -> ComputerAgent: return ComputerAgent( settings=AgentSettings(detection_provider=failing_provider), - tools=agent_toolbox_mock, retry=ConfigurableRetry( on_exception_types=( ElementNotFoundError, @@ -88,11 +88,11 @@ def agent_with_retry_on_multiple_exceptions( @pytest.fixture def agent_always_fail( - always_failing_provider: FailingDetectionProvider, agent_toolbox_mock: AgentToolbox + always_failing_provider: FailingDetectionProvider, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG001 ) -> ComputerAgent: return ComputerAgent( settings=AgentSettings(detection_provider=always_failing_provider), - tools=agent_toolbox_mock, retry=ConfigurableRetry( on_exception_types=(ElementNotFoundError,), strategy="Fixed", diff --git a/tests/integration/test_custom_models.py b/tests/integration/test_custom_models.py index 996f610a..0bb8a266 100644 --- a/tests/integration/test_custom_models.py +++ b/tests/integration/test_custom_models.py @@ -26,7 +26,7 @@ from askui.models.shared.prompts import SystemPrompt from askui.models.shared.settings import GetSettings, LocateSettings from askui.models.shared.tools import ToolCollection -from askui.tools.toolbox import AgentToolbox +from askui.tools.agent_os import ComputerAgentOS from askui.utils.image_utils import ImageSource from askui.utils.source_utils import Source @@ -148,12 +148,11 @@ def detection_provider(self) -> SimpleDetectionProvider: def test_inject_and_use_custom_vlm_provider( self, vlm_provider: SimpleVlmProvider, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG002 ) -> None: """Test injecting and using a custom VLM provider.""" with ComputerAgent( settings=AgentSettings(vlm_provider=vlm_provider), - tools=agent_toolbox_mock, ) as agent: agent.act("test goal") @@ -175,12 +174,11 @@ def test_inject_and_use_custom_vlm_provider( def test_inject_and_use_custom_image_qa_provider( self, image_qa_provider: SimpleImageQAProvider, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG002 ) -> None: """Test injecting and using a custom image Q&A provider.""" with ComputerAgent( settings=AgentSettings(image_qa_provider=image_qa_provider), - tools=agent_toolbox_mock, ) as agent: result = agent.get("test query") @@ -190,13 +188,12 @@ def test_inject_and_use_custom_image_qa_provider( def test_inject_and_use_custom_image_qa_provider_with_pdf( self, image_qa_provider: SimpleImageQAProvider, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG002 path_fixtures_dummy_pdf: pathlib.Path, ) -> None: """Test injecting and using a custom image Q&A provider with a PDF.""" with ComputerAgent( settings=AgentSettings(image_qa_provider=image_qa_provider), - tools=agent_toolbox_mock, ) as agent: result = agent.get("test query", source=path_fixtures_dummy_pdf) @@ -206,12 +203,11 @@ def test_inject_and_use_custom_image_qa_provider_with_pdf( def test_inject_and_use_custom_detection_provider( self, detection_provider: SimpleDetectionProvider, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG002 ) -> None: """Test injecting and using a custom detection provider.""" with ComputerAgent( settings=AgentSettings(detection_provider=detection_provider), - tools=agent_toolbox_mock, ) as agent: agent.click("test element") @@ -222,7 +218,7 @@ def test_inject_all_custom_providers( vlm_provider: SimpleVlmProvider, image_qa_provider: SimpleImageQAProvider, detection_provider: SimpleDetectionProvider, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG002 ) -> None: """Test injecting all custom providers at once.""" with ComputerAgent( @@ -231,7 +227,6 @@ def test_inject_all_custom_providers( image_qa_provider=image_qa_provider, detection_provider=detection_provider, ), - tools=agent_toolbox_mock, ) as agent: agent.act("test goal") result = agent.get("test query") @@ -258,7 +253,7 @@ def test_inject_all_custom_providers( def test_use_response_schema_with_custom_image_qa_provider( self, image_qa_provider: SimpleImageQAProvider, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG002 ) -> None: """Test using a response schema with a custom image Q&A provider.""" response = SimpleResponseSchema(value="test value") @@ -266,7 +261,6 @@ def test_use_response_schema_with_custom_image_qa_provider( with ComputerAgent( settings=AgentSettings(image_qa_provider=image_qa_provider), - tools=agent_toolbox_mock, ) as agent: result = agent.get("test query", response_schema=SimpleResponseSchema) @@ -276,8 +270,8 @@ def test_use_response_schema_with_custom_image_qa_provider( def test_defaults_to_built_in_providers_when_not_provided( self, - agent_toolbox_mock: AgentToolbox, + agent_os_mock_patch: ComputerAgentOS, # noqa: ARG002 ) -> None: """Test agent uses built-in defaults when custom ones not provided.""" - with ComputerAgent(tools=agent_toolbox_mock) as agent: + with ComputerAgent() as agent: assert agent is not None diff --git a/tests/unit/tools/askui/test_agent_os_target_computer.py b/tests/unit/tools/askui/test_agent_os_target_computer.py new file mode 100644 index 00000000..86bffd7c --- /dev/null +++ b/tests/unit/tools/askui/test_agent_os_target_computer.py @@ -0,0 +1,131 @@ +from typing import Callable + +import pytest + +from askui.tools.askui.agent_os_target_computer import ( + ComputerTarget, + LocalComputerTarget, + RemoteComputerTarget, +) + + +class TestReplacePort: + def test_replaces_port_on_bare_authority(self) -> None: + assert ( + LocalComputerTarget.replace_port("example.com:1234", 23000) + == "example.com:23000" + ) + + def test_replaces_port_on_url_with_scheme(self) -> None: + assert ( + LocalComputerTarget.replace_port("http://example.com:1234", 23000) + == "example.com:23000" + ) + + def test_falls_back_to_localhost_when_host_missing(self) -> None: + # A bare ":1234" has no hostname, so the helper falls back to "localhost". + assert LocalComputerTarget.replace_port(":1234", 23000) == "localhost:23000" + + +class TestAgentOsTargetComputer: + def test_session_guid_unique_per_instance(self) -> None: + a = RemoteComputerTarget(address="1.2.3.4:23000", description="a") + b = RemoteComputerTarget(address="5.6.7.8:23000", description="b") + assert a.session_guid != b.session_guid + + def test_computer_id_defaults_to_session_guid(self) -> None: + s = RemoteComputerTarget(address="1.2.3.4:23000", description="a") + assert s.computer_id == s.session_guid + + def test_explicit_computer_id_is_preserved(self) -> None: + s = RemoteComputerTarget( + address="1.2.3.4:23000", description="a", computer_id="laptop" + ) + assert s.computer_id == "laptop" + assert s.session_guid != "laptop" + + def test_display_defaults_to_one_and_is_settable(self) -> None: + s = RemoteComputerTarget(address="1.2.3.4:23000", description="a") + assert s.display == 1 + s.display = 3 + assert s.display == 3 + + def test_explicit_display_is_preserved(self) -> None: + s = RemoteComputerTarget(address="1.2.3.4:23000", description="a", display=2) + assert s.display == 2 + + def test_repr_contains_identity_fields(self) -> None: + s = RemoteComputerTarget( + address="1.2.3.4:23000", + description="my rig", + display=2, + computer_id="rig", + ) + r = repr(s) + assert "RemoteComputerTarget" in r + assert "computer_id='rig'" in r + assert "description='my rig'" in r + assert "display=2" in r + + def test_base_class_is_not_local(self) -> None: + s = RemoteComputerTarget(address="1.2.3.4:23000", description="a") + assert s.is_local is False + + def test_start_and_stop_are_no_ops_on_remote(self) -> None: + s = RemoteComputerTarget(address="1.2.3.4:23000", description="a") + s.start() + s.stop() + + +class TestLocalAgentOsTargetComputer: + def test_is_local(self) -> None: + s = LocalComputerTarget(discover_service=False) + assert s.is_local is True + + def test_default_description(self) -> None: + s = LocalComputerTarget(discover_service=False) + assert s.description == "Local computer target" + + def test_default_address(self) -> None: + s = LocalComputerTarget(discover_service=False) + assert s.address == "localhost:23000" + + def test_is_service_default_false(self) -> None: + s = LocalComputerTarget(discover_service=False) + assert s.is_service is False + + def test_explicit_computer_id(self) -> None: + s = LocalComputerTarget(discover_service=False, computer_id="my-laptop") + assert s.computer_id == "my-laptop" + + def test_parse_port_rejects_bad_address(self) -> None: + s = LocalComputerTarget(discover_service=False, address="no-port-here") + with pytest.raises(ValueError, match="Could not parse port"): + s._parse_port() # noqa: SLF001 - intentional unit test against helper + + def test_parse_port_extracts_port(self) -> None: + s = LocalComputerTarget(discover_service=False, address="localhost:24567") + assert s._parse_port() == 24567 # noqa: SLF001 + + +class TestSubclassesPassThroughDisplayAndId: + @pytest.mark.parametrize( + "factory", + [ + lambda: LocalComputerTarget( + discover_service=False, display=4, computer_id="local" + ), + lambda: RemoteComputerTarget( + address="1.2.3.4:23000", + description="r", + display=4, + computer_id="remote", + ), + ], + ) + def test_display_and_computer_id_round_trip( + self, factory: Callable[[], ComputerTarget] + ) -> None: + s: ComputerTarget = factory() + assert s.display == 4 + assert s.computer_id in {"local", "remote"} diff --git a/tests/unit/tools/askui/test_askui_controller_client.py b/tests/unit/tools/askui/test_askui_controller_client.py new file mode 100644 index 00000000..4c007f5a --- /dev/null +++ b/tests/unit/tools/askui/test_askui_controller_client.py @@ -0,0 +1,216 @@ +""" +Unit tests for `MultiComputerTargetAgentOS`'s multi-target registration / routing +logic. These tests intentionally avoid exercising the gRPC code path (which +needs a real controller binary). They cover the in-memory bookkeeping done by +the client and its `ComputerTargetPool`. +""" + +import pytest + +from askui.tools.askui.agent_os_target_computer import ( + LocalComputerTarget, + RemoteComputerTarget, +) +from askui.tools.askui.askui_controller import MultiComputerTargetAgentOS +from askui.tools.askui.computer_target_pool import ( + ComputerTargetPool, +) +from askui.tools.askui.exceptions import AskUiControllerError + + +def _make_local( + description: str = "local", computer_id: str | None = None, display: int = 1 +) -> LocalComputerTarget: + return LocalComputerTarget( + description=description, + discover_service=False, + computer_id=computer_id, + display=display, + ) + + +def _make_remote( + address: str = "1.2.3.4:23000", + description: str = "remote", + computer_id: str | None = None, + display: int = 1, +) -> RemoteComputerTarget: + return RemoteComputerTarget( + address=address, + description=description, + computer_id=computer_id, + display=display, + ) + + +class TestConstruction: + def test_default_registers_single_local_target(self) -> None: + client = MultiComputerTargetAgentOS() + manager = client.agent_os_target_computer_manager + assert len(manager) == 1 + assert isinstance(manager.active, LocalComputerTarget) + + def test_default_propagates_display_to_default_local_target(self) -> None: + client = MultiComputerTargetAgentOS(display=3) + active = client.agent_os_target_computer_manager.active + assert active is not None + assert active.display == 3 + + def test_accepts_explicit_targets(self) -> None: + a = _make_local(computer_id="local") + b = _make_remote(computer_id="remote") + client = MultiComputerTargetAgentOS(agent_os_target_computers=[a, b]) + assert client.agent_os_target_computer_manager.describe() == [ + repr(a), + repr(b), + ] + assert client.agent_os_target_computer_manager.active is a + + def test_explicit_targets_keep_their_own_display(self) -> None: + """Constructor's display arg only seeds the auto-created default target.""" + a = _make_local(computer_id="local", display=2) + b = _make_remote(computer_id="remote", display=3) + client = MultiComputerTargetAgentOS(display=5, agent_os_target_computers=[a, b]) + assert client.agent_os_target_computer_manager.get("local").display == 2 + assert client.agent_os_target_computer_manager.get("remote").display == 3 + + def test_is_connected_false_before_connect(self) -> None: + client = MultiComputerTargetAgentOS(agent_os_target_computers=[_make_remote()]) + assert client.is_connected is False + + +class TestActiveTarget: + def test_get_current_returns_first_registered_id(self) -> None: + a = _make_local(computer_id="a") + b = _make_remote(computer_id="b") + client = MultiComputerTargetAgentOS(agent_os_target_computers=[a, b]) + assert client.get_current_computer_target_id() == "a" + + def test_get_current_with_empty_manager_raises(self) -> None: + client = MultiComputerTargetAgentOS(agent_os_target_computers=[_make_remote()]) + client.agent_os_target_computer_manager.reset() + with pytest.raises( + AskUiControllerError, match="No active Agent OS target computer" + ): + client.get_current_computer_target_id(report=False) + + +class TestSwitchAgentOsTargetComputer: + def test_switch_changes_active_when_disconnected(self) -> None: + a = _make_local(computer_id="a") + b = _make_remote(computer_id="b") + client = MultiComputerTargetAgentOS(agent_os_target_computers=[a, b]) + client.switch_agent_os_target_computer("b") + assert client.agent_os_target_computer_manager.active is b + + def test_switch_unknown_computer_id_raises_keyerror(self) -> None: + client = MultiComputerTargetAgentOS( + agent_os_target_computers=[_make_local(computer_id="a")] + ) + with pytest.raises(KeyError, match="missing"): + client.switch_agent_os_target_computer("missing") + + def test_switch_returns_the_new_active_target(self) -> None: + a = _make_local(computer_id="a") + b = _make_remote(computer_id="b") + client = MultiComputerTargetAgentOS(agent_os_target_computers=[a, b]) + result = client.switch_agent_os_target_computer("b") + assert result is b + + def test_per_target_display_preserved_across_switch(self) -> None: + a = _make_local(computer_id="a", display=1) + b = _make_remote(computer_id="b", display=4) + client = MultiComputerTargetAgentOS(agent_os_target_computers=[a, b]) + client.switch_agent_os_target_computer("b") + active_b = client.agent_os_target_computer_manager.active + assert active_b is not None + assert active_b.display == 4 + client.switch_agent_os_target_computer("a") + active_a = client.agent_os_target_computer_manager.active + assert active_a is not None + assert active_a.display == 1 + + +class TestDescribeAndReset: + def test_describe_returns_registered_target_summaries(self) -> None: + a = _make_local(computer_id="a") + b = _make_remote(computer_id="b") + client = MultiComputerTargetAgentOS(agent_os_target_computers=[a, b]) + assert client.describe_agent_os_target_computers() == [repr(a), repr(b)] + + def test_reset_with_no_args_leaves_manager_empty(self) -> None: + client = MultiComputerTargetAgentOS( + agent_os_target_computers=[_make_remote(computer_id="r")] + ) + client.reset_agent_os_target_computers() + assert client.describe_agent_os_target_computers() == [] + + def test_reset_with_new_list_replaces_registrations(self) -> None: + client = MultiComputerTargetAgentOS( + agent_os_target_computers=[_make_remote(computer_id="old")] + ) + new_agent_os_target_computer = _make_remote( + address="9.9.9.9:23000", computer_id="new" + ) + client.reset_agent_os_target_computers([new_agent_os_target_computer]) + assert client.describe_agent_os_target_computers() == [ + repr(new_agent_os_target_computer) + ] + assert ( + client.agent_os_target_computer_manager.active + is new_agent_os_target_computer + ) + + +class TestAddAgentOsTargetComputerWhileDisconnected: + def test_add_already_constructed_target(self) -> None: + client = MultiComputerTargetAgentOS( + agent_os_target_computers=[_make_local(computer_id="l")] + ) + extra = _make_remote(address="2.2.2.2:23000", computer_id="r") + result = client.add_agent_os_target_computer(extra) + assert result is extra + assert repr(extra) in client.describe_agent_os_target_computers() + + +class TestTemporarySelect: + def test_temporary_select_restores_previous_active(self) -> None: + a = _make_local(computer_id="a") + b = _make_remote(computer_id="b") + client = MultiComputerTargetAgentOS(agent_os_target_computers=[a, b]) + manager = client.agent_os_target_computer_manager + before = manager.active + assert before is a + with client.temporary_select("b"): + inside = manager.active + assert inside is b + after = manager.active + assert after is a + + def test_temporary_select_restores_previous_even_on_exception(self) -> None: + a = _make_local(computer_id="a") + b = _make_remote(computer_id="b") + client = MultiComputerTargetAgentOS(agent_os_target_computers=[a, b]) + error_message = "boom" + with ( + pytest.raises(RuntimeError, match=error_message), + client.temporary_select("b"), + ): + assert client.agent_os_target_computer_manager.active is b + raise RuntimeError(error_message) + assert client.agent_os_target_computer_manager.active is a + + def test_temporary_select_same_id_is_a_noop_around_yield(self) -> None: + a = _make_local(computer_id="a") + client = MultiComputerTargetAgentOS(agent_os_target_computers=[a]) + with client.temporary_select("a"): + assert client.agent_os_target_computer_manager.active is a + assert client.agent_os_target_computer_manager.active is a + + +class TestUsesAgentOsTargetComputerManager: + def test_underlying_manager_is_an_agent_os_target_computer_manager(self) -> None: + client = MultiComputerTargetAgentOS( + agent_os_target_computers=[_make_local(computer_id="l")] + ) + assert isinstance(client.agent_os_target_computer_manager, ComputerTargetPool) diff --git a/tests/unit/tools/askui/test_askui_controller_client_settings.py b/tests/unit/tools/askui/test_askui_controller_client_settings.py deleted file mode 100644 index 3a086453..00000000 --- a/tests/unit/tools/askui/test_askui_controller_client_settings.py +++ /dev/null @@ -1,73 +0,0 @@ -from unittest.mock import patch - -import pytest -from pydantic import ValidationError - -from askui.tools.askui.askui_controller_client_settings import ( - AskUiControllerClientSettings, -) - - -class TestAskUiControllerClientSettings: - """Test suite for AskUiControllerClientSettings.""" - - def test_defaults(self) -> None: - """Defaults are applied when no environment variables are set.""" - with patch.dict("os.environ", {}, clear=True): - settings = AskUiControllerClientSettings() - assert settings.server_address == "localhost:23000" - assert settings.server_autostart is True - - def test_server_address_from_env(self) -> None: - """ - `ASKUI_CONTROLLER_CLIENT_SERVER_ADDRESS` overrides default for `server_address`. - """ - with patch.dict( - "os.environ", - {"ASKUI_CONTROLLER_CLIENT_SERVER_ADDRESS": "127.0.0.1:24000"}, - clear=True, - ): - settings = AskUiControllerClientSettings() - assert settings.server_address == "127.0.0.1:24000" - - def test_server_autostart_from_env_false(self) -> None: - """`ASKUI_CONTROLLER_CLIENT_SERVER_AUTOSTART` parses boolean from env.""" - with patch.dict( - "os.environ", - {"ASKUI_CONTROLLER_CLIENT_SERVER_AUTOSTART": "False"}, - clear=True, - ): - settings = AskUiControllerClientSettings() - assert settings.server_autostart is False - - def test_server_autostart_from_env_true(self) -> None: - """Boolean true value is parsed correctly from environment variable.""" - with patch.dict( - "os.environ", - {"ASKUI_CONTROLLER_CLIENT_SERVER_AUTOSTART": "true"}, - clear=True, - ): - settings = AskUiControllerClientSettings() - assert settings.server_autostart is True - - def test_server_address_from_constructor(self) -> None: - """`server_address` is set correctly from constructor.""" - settings = AskUiControllerClientSettings(server_address="127.0.0.1:24000") - assert settings.server_address == "127.0.0.1:24000" - - def test_server_autostart_from_constructor(self) -> None: - """`server_autostart` is set correctly from constructor.""" - settings = AskUiControllerClientSettings(server_autostart=False) - assert settings.server_autostart is False - - def test_autostart_from_env_with_invalid_value(self) -> None: - """ - Test that ValidationError is raised when environment variable is invalid. - """ - with patch.dict( - "os.environ", - {"ASKUI_CONTROLLER_CLIENT_SERVER_AUTOSTART": "invalid"}, - clear=True, - ): - with pytest.raises(ValidationError): - AskUiControllerClientSettings() diff --git a/tests/unit/tools/askui/test_computer_target_pool.py b/tests/unit/tools/askui/test_computer_target_pool.py new file mode 100644 index 00000000..65f33c25 --- /dev/null +++ b/tests/unit/tools/askui/test_computer_target_pool.py @@ -0,0 +1,203 @@ +from collections.abc import Callable + +import pytest + +from askui.tools.askui.agent_os_target_computer import ( + ComputerTarget, + LocalComputerTarget, + RemoteComputerTarget, +) +from askui.tools.askui.computer_target_pool import ( + ComputerTargetPool, +) + + +def _make_remote( + address: str = "1.2.3.4:23000", + description: str = "remote", + computer_id: str | None = None, +) -> RemoteComputerTarget: + return RemoteComputerTarget( + address=address, description=description, computer_id=computer_id + ) + + +def _make_local(computer_id: str | None = None) -> LocalComputerTarget: + return LocalComputerTarget(discover_service=False, computer_id=computer_id) + + +@pytest.fixture(params=["local", "remote"]) +def make_target( + request: pytest.FixtureRequest, +) -> Callable[..., ComputerTarget]: + """Build a single target of the parametrized kind so a test runs once per kind. + + Use for tests that register exactly one target and where the local/remote + distinction is irrelevant to the behavior under test. + """ + + def _make( + computer_id: str | None = None, + address: str = "1.2.3.4:23000", + ) -> ComputerTarget: + if request.param == "local": + return _make_local(computer_id=computer_id) + return _make_remote(address=address, computer_id=computer_id) + + return _make + + +class TestConstruction: + def test_empty_constructor_yields_empty_manager(self) -> None: + m = ComputerTargetPool() + assert m.describe() == [] + assert m.active is None + assert len(m) == 0 + + def test_constructor_registers_initial_targets_in_order(self) -> None: + a = _make_remote(address="1.1.1.1:23000", computer_id="a") + b = _make_remote(address="2.2.2.2:23000", computer_id="b") + m = ComputerTargetPool(agent_os_target_computers=[a, b]) + assert m.describe() == [repr(a), repr(b)] + # First registered becomes active. + assert m.active is a + + def test_first_added_becomes_active( + self, make_target: Callable[..., ComputerTarget] + ) -> None: + m = ComputerTargetPool() + a = make_target(computer_id="a") + m.add(a) + assert m.active is a + + +class TestAddConstraints: + def test_rejects_second_local_target(self) -> None: + m = ComputerTargetPool() + m.add(_make_local(computer_id="first")) + with pytest.raises(ValueError, match="second local Agent OS target computer"): + m.add(_make_local(computer_id="second")) + + def test_rejects_duplicate_computer_id(self) -> None: + m = ComputerTargetPool() + m.add(_make_remote(address="1.1.1.1:23000", computer_id="rig")) + with pytest.raises(ValueError, match="computer_id='rig'"): + m.add(_make_remote(address="2.2.2.2:23000", computer_id="rig")) + + def test_rejects_duplicate_remote_address(self) -> None: + m = ComputerTargetPool() + m.add(_make_remote(address="1.1.1.1:23000", computer_id="a")) + with pytest.raises( + ValueError, + match="remote Agent OS target computer with address '1.1.1.1:23000'", + ): + m.add(_make_remote(address="1.1.1.1:23000", computer_id="b")) + + def test_allows_local_plus_remote_with_same_address(self) -> None: + m = ComputerTargetPool() + m.add(_make_local(computer_id="local")) + # Local target's default address is 'localhost:23000' but the local/remote + # address-uniqueness rule only applies between remote targets. + m.add( + _make_remote( + address="localhost:23000", description="remote", computer_id="remote" + ) + ) + assert len(m) == 2 + + +class TestGetAndSwitch: + def test_get_returns_target_by_computer_id( + self, make_target: Callable[..., ComputerTarget] + ) -> None: + m = ComputerTargetPool() + a = make_target(address="1.1.1.1:23000", computer_id="a") + m.add(a) + assert m.get("a") is a + + def test_get_raises_keyerror_with_registered_ids( + self, make_target: Callable[..., ComputerTarget] + ) -> None: + m = ComputerTargetPool() + m.add(make_target(address="1.1.1.1:23000", computer_id="a")) + with pytest.raises(KeyError) as exc_info: + m.get("missing") + message = str(exc_info.value) + assert "missing" in message + assert "'a'" in message # registered id surfaced + + def test_switch_changes_active(self) -> None: + m = ComputerTargetPool() + a = _make_remote(address="1.1.1.1:23000", computer_id="a") + b = _make_remote(address="2.2.2.2:23000", computer_id="b") + m.add(a) + m.add(b) + assert m.active is a + m.switch("b") + assert m.active is b + + def test_switch_unknown_id_raises_keyerror( + self, make_target: Callable[..., ComputerTarget] + ) -> None: + m = ComputerTargetPool() + m.add(make_target(computer_id="a")) + with pytest.raises(KeyError, match="missing"): + m.switch("missing") + + +class TestRemove: + def test_remove_drops_target(self) -> None: + m = ComputerTargetPool() + a = _make_remote(address="1.1.1.1:23000", computer_id="a") + b = _make_remote(address="2.2.2.2:23000", computer_id="b") + m.add(a) + m.add(b) + m.remove("a") + assert m.describe() == [repr(b)] + + def test_remove_active_falls_back_to_first_remaining(self) -> None: + m = ComputerTargetPool() + a = _make_remote(address="1.1.1.1:23000", computer_id="a") + b = _make_remote(address="2.2.2.2:23000", computer_id="b") + m.add(a) + m.add(b) + assert m.active is a + m.remove("a") + assert m.active is b + + def test_remove_last_clears_active( + self, make_target: Callable[..., ComputerTarget] + ) -> None: + m = ComputerTargetPool() + m.add(make_target(computer_id="a")) + m.remove("a") + assert m.active is None + assert len(m) == 0 + + def test_remove_inactive_keeps_active_unchanged(self) -> None: + m = ComputerTargetPool() + a = _make_remote(address="1.1.1.1:23000", computer_id="a") + b = _make_remote(address="2.2.2.2:23000", computer_id="b") + m.add(a) + m.add(b) + m.remove("b") + assert m.active is a + + def test_remove_unknown_raises_keyerror( + self, make_target: Callable[..., ComputerTarget] + ) -> None: + m = ComputerTargetPool() + m.add(make_target(computer_id="a")) + with pytest.raises(KeyError): + m.remove("missing") + + +class TestReset: + def test_reset_clears_all(self) -> None: + m = ComputerTargetPool() + m.add(_make_remote(computer_id="a")) + m.add(_make_remote(address="2.2.2.2:23000", computer_id="b")) + m.reset() + assert m.describe() == [] + assert m.active is None + assert len(m) == 0 diff --git a/tests/unit/tools/computer/__init__.py b/tests/unit/tools/computer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/tools/computer/test_agent_os_target_computer_tools.py b/tests/unit/tools/computer/test_agent_os_target_computer_tools.py new file mode 100644 index 00000000..c3c6132a --- /dev/null +++ b/tests/unit/tools/computer/test_agent_os_target_computer_tools.py @@ -0,0 +1,80 @@ +from unittest.mock import MagicMock + +import pytest + +from askui.tools.agent_os import ComputerAgentOS +from askui.tools.askui.agent_os_target_computer import RemoteComputerTarget +from askui.tools.computer import ( + ComputerGetCurrentComputerTargetIdTool, + ComputerListAgentOsTargetComputersTool, + ComputerSwitchAgentOsTargetComputerTool, +) + + +@pytest.fixture +def fake_agent_os() -> MagicMock: + """A MagicMock that passes `isinstance(x, ComputerAgentOS)` checks.""" + return MagicMock(spec=ComputerAgentOS) + + +class TestComputerListAgentOsTargetComputersTool: + def test_tool_name(self, fake_agent_os: MagicMock) -> None: + tool = ComputerListAgentOsTargetComputersTool(agent_os=fake_agent_os) + assert tool.base_name == "list_agent_os_target_computers" + + def test_returns_newline_joined_reprs(self, fake_agent_os: MagicMock) -> None: + a = RemoteComputerTarget( + address="1.1.1.1:23000", description="a", computer_id="a" + ) + b = RemoteComputerTarget( + address="2.2.2.2:23000", description="b", computer_id="b" + ) + fake_agent_os.describe_agent_os_target_computers.return_value = [ + repr(a), + repr(b), + ] + tool = ComputerListAgentOsTargetComputersTool(agent_os=fake_agent_os) + out = tool() + assert out == f"{a!r}\n{b!r}" + + def test_empty_list_yields_empty_string(self, fake_agent_os: MagicMock) -> None: + fake_agent_os.describe_agent_os_target_computers.return_value = [] + tool = ComputerListAgentOsTargetComputersTool(agent_os=fake_agent_os) + assert tool() == "" + + +class TestComputerSwitchAgentOsTargetComputerTool: + def test_tool_name(self, fake_agent_os: MagicMock) -> None: + tool = ComputerSwitchAgentOsTargetComputerTool(agent_os=fake_agent_os) + assert tool.base_name == "switch_agent_os_target_computer" + + def test_input_schema_requires_computer_id(self, fake_agent_os: MagicMock) -> None: + tool = ComputerSwitchAgentOsTargetComputerTool(agent_os=fake_agent_os) + schema = tool.input_schema + assert "computer_id" in schema["properties"] + assert schema["required"] == ["computer_id"] + + def test_call_delegates_to_switch_agent_os_target_computer( + self, fake_agent_os: MagicMock + ) -> None: + switched = RemoteComputerTarget( + address="1.1.1.1:23000", description="new", computer_id="new" + ) + fake_agent_os.switch_agent_os_target_computer.return_value = switched + tool = ComputerSwitchAgentOsTargetComputerTool(agent_os=fake_agent_os) + out = tool(computer_id="new") + fake_agent_os.switch_agent_os_target_computer.assert_called_once_with("new") + assert out == repr(switched) + + +class TestComputerGetCurrentComputerTargetIdTool: + def test_tool_name(self, fake_agent_os: MagicMock) -> None: + tool = ComputerGetCurrentComputerTargetIdTool(agent_os=fake_agent_os) + assert tool.base_name == "get_current_computer_target_id" + + def test_call_returns_current_computer_id(self, fake_agent_os: MagicMock) -> None: + fake_agent_os.get_current_computer_target_id.return_value = "a" + tool = ComputerGetCurrentComputerTargetIdTool(agent_os=fake_agent_os) + out = tool() + fake_agent_os.get_current_computer_target_id.assert_called_once_with() + assert out == "a"