Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 35 additions & 16 deletions src/hawk/memory_tools.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Memory-as-voluntary-tools for agent-driven memory management.

Lets agents strategically decide what to remember/recall rather than
auto-ingesting everything. Wraps yaad's memory API as tool functions.
auto-ingesting everything. Wraps any MemoryBackend-compatible client
as tool functions.

Usage:
from hawk.memory_tools import MemoryTools
Expand All @@ -12,11 +13,28 @@

from __future__ import annotations

from typing import Any
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable

if TYPE_CHECKING:
from collections.abc import Sequence

from .tools import Tool


@runtime_checkable
class MemoryBackend(Protocol):
"""Structural interface for persistent memory backends.

Any object with ``remember`` and ``recall`` methods satisfying these
signatures is accepted — no import or inheritance required.
"""

def remember(self, content: str, *, session_id: str | None = None) -> Any: ...
def recall(
self, query: str, *, limit: int = 5, session_id: str | None = None
) -> Sequence[str]: ...


class MemoryTools:
"""Provides record/retrieve memory operations as agent tools.

Expand All @@ -30,6 +48,9 @@ def __init__(self, client: Any, *, session_id: str | None = None) -> None:
self._client = client
self._session_id = session_id
self._local_memories: list[dict[str, str]] = []
self._backend: MemoryBackend | None = (
client if hasattr(client, "remember") and hasattr(client, "recall") else None
)

def record_memory(
self, content: str, category: str = "general", importance: str = "normal"
Expand All @@ -42,34 +63,32 @@ def record_memory(
}
self._local_memories.append(memory)

# If client supports yaad memory API, persist
try:
if hasattr(self._client, "remember"):
self._client.remember(content, session_id=self._session_id)
if self._backend is not None:
try:
self._backend.remember(content, session_id=self._session_id)
return f"Recorded to persistent memory: '{content[:100]}...'"
except Exception as exc:
import logging
except Exception as exc:
import logging

logging.getLogger(__name__).debug("Failed to persist memory via yaad: %s", exc)
logging.getLogger(__name__).debug("Failed to persist memory: %s", exc)

return f"Recorded to session memory: '{content[:100]}...'"

def retrieve_memories(self, query: str, limit: int = 5) -> str:
"""Retrieve relevant memories for the current context."""
results = []

# Try yaad recall
try:
if hasattr(self._client, "recall"):
recalled = self._client.recall(query, limit=limit, session_id=self._session_id)
if self._backend is not None:
try:
recalled = self._backend.recall(query, limit=limit, session_id=self._session_id)
if recalled:
return f"Recalled {len(recalled)} memories:\n" + "\n".join(
f"- {m}" for m in recalled
)
except Exception as exc:
import logging
except Exception as exc:
import logging

logging.getLogger(__name__).debug("Failed to recall memories via yaad: %s", exc)
logging.getLogger(__name__).debug("Failed to recall memories: %s", exc)

# Fallback to local fuzzy match
query_lower = query.lower()
Expand Down
Loading