diff --git a/src/hawk/memory_tools.py b/src/hawk/memory_tools.py index f5491b7..be8c539 100644 --- a/src/hawk/memory_tools.py +++ b/src/hawk/memory_tools.py @@ -1,7 +1,8 @@ """Memory-as-voluntary-tools for agent-driven memory management. Lets agents strategically decide what to remember/recall rather than -auto-ingesting everything. Wraps yaad's memory API as tool functions. +auto-ingesting everything. Wraps any MemoryBackend-compatible client +as tool functions. Usage: from hawk.memory_tools import MemoryTools @@ -12,11 +13,28 @@ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable + +if TYPE_CHECKING: + from collections.abc import Sequence from .tools import Tool +@runtime_checkable +class MemoryBackend(Protocol): + """Structural interface for persistent memory backends. + + Any object with ``remember`` and ``recall`` methods satisfying these + signatures is accepted — no import or inheritance required. + """ + + def remember(self, content: str, *, session_id: str | None = None) -> Any: ... + def recall( + self, query: str, *, limit: int = 5, session_id: str | None = None + ) -> Sequence[str]: ... + + class MemoryTools: """Provides record/retrieve memory operations as agent tools. @@ -30,6 +48,9 @@ def __init__(self, client: Any, *, session_id: str | None = None) -> None: self._client = client self._session_id = session_id self._local_memories: list[dict[str, str]] = [] + self._backend: MemoryBackend | None = ( + client if hasattr(client, "remember") and hasattr(client, "recall") else None + ) def record_memory( self, content: str, category: str = "general", importance: str = "normal" @@ -42,15 +63,14 @@ def record_memory( } self._local_memories.append(memory) - # If client supports yaad memory API, persist - try: - if hasattr(self._client, "remember"): - self._client.remember(content, session_id=self._session_id) + if self._backend is not None: + try: + self._backend.remember(content, session_id=self._session_id) return f"Recorded to persistent memory: '{content[:100]}...'" - except Exception as exc: - import logging + except Exception as exc: + import logging - logging.getLogger(__name__).debug("Failed to persist memory via yaad: %s", exc) + logging.getLogger(__name__).debug("Failed to persist memory: %s", exc) return f"Recorded to session memory: '{content[:100]}...'" @@ -58,18 +78,17 @@ def retrieve_memories(self, query: str, limit: int = 5) -> str: """Retrieve relevant memories for the current context.""" results = [] - # Try yaad recall - try: - if hasattr(self._client, "recall"): - recalled = self._client.recall(query, limit=limit, session_id=self._session_id) + if self._backend is not None: + try: + recalled = self._backend.recall(query, limit=limit, session_id=self._session_id) if recalled: return f"Recalled {len(recalled)} memories:\n" + "\n".join( f"- {m}" for m in recalled ) - except Exception as exc: - import logging + except Exception as exc: + import logging - logging.getLogger(__name__).debug("Failed to recall memories via yaad: %s", exc) + logging.getLogger(__name__).debug("Failed to recall memories: %s", exc) # Fallback to local fuzzy match query_lower = query.lower()