"""Minimal Connect agent — public-deps-only reference implementation.

Demonstrates the full relay handshake using only public Python packages.
No dependency on Connect's internal modules. The wire types below are
re-declared from the spec at ``docs/RELAY_WIRE_PROTOCOL.md``; they are
byte-compatible with Connect's own ``connect/relay/protocol.py``.

What this agent does:
1. Registers a fresh agent via POST /v1/agents (no Connect-side account
   needed; the API returns an api_key bound to the new agent).
2. Opens an outbound WebSocket to the relay.
3. Completes the binary AuthMessage handshake.
4. Replies to inbound RequestMessage frames with a simple JSON echo.
5. Replies to PingMessage frames with PongMessage.
6. Exits cleanly on DrainMessage.

Run:

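    pip install msgspec httpx httpx-ws
    BOT_NAME=my-test-agent python examples/minimal_agent.py

(With uv, ``uv run --with msgspec --with httpx --with httpx-ws python
examples/minimal_agent.py`` installs the dependencies and runs in one step.)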

Environment (all optional):
    BOT_NAME        — display name (default ``minimal-bot``; a short uuid
                      suffix is added on first registration to avoid
                      collisions)
    CONNECT_API     — default ``https://api.actex.ai/connect``
    CONNECT_WS      — default ``wss://api.actex.ai/connect/ws/agent``
    AGENT_URL       — URL published in the agent card; its domain must
                      resolve (default:
                      ``https://example.com/agents/<bot_name>``)
    BOT_KEY_FILE    — cache for the registered identity across runs
                      (default: ``/tmp/<bot_name>.json``)
    LOG_LEVEL       — INFO (default) | DEBUG

This file is intentionally standalone. Anything you can do here, you
can port to any language with a MessagePack library.
"""

from __future__ import annotations

import asyncio
import json
import logging
import os
import sys
import uuid
from pathlib import Path
from typing import Any, NamedTuple

import httpx
import msgspec
from httpx_ws import aconnect_ws

# -- Wire types (re-declared from the protocol spec) -------------------------
# These are byte-compatible with connect/relay/protocol.py. The integer
# tag values are the source of truth; field names and types match the
# spec at docs/RELAY_WIRE_PROTOCOL.md.


class AuthMessage(msgspec.Struct, tag_field="type", tag=0):
    """Client -> server handshake frame (wire tag ``0``).

    Must be the first frame sent after the WebSocket upgrade. This sample
    fills only ``token`` and leaves DH and capabilities at their defaults.
    """

    # The agent's API key in this sample (``run_ws`` sends AuthMessage(token=api_key)).
    token: str = ""
    # Client DH public key; this sample never sets it, i.e. DH is not negotiated.
    dh_public_key: bytes = b""
    # Optional feature names the agent opts into (e.g. "chunked-upload",
    # per the note in run_ws) — exact vocabulary is defined by the spec.
    capabilities: list[str] = msgspec.field(default_factory=list)

class AuthOkMessage(msgspec.Struct, tag_field="type", tag=1):
    """Server -> client: handshake accepted (wire tag ``1``).

    ``run_ws`` expects this as the reply to ``AuthMessage`` and logs
    ``agent_id``, ``instance_id`` and ``proxy_url`` on success.
    """

    agent_id: str = ""
    proxy_url: str = ""
    # Server-side DH public key; presumably empty when the client sent none
    # (this sample never negotiates DH) — confirm against the spec.
    dh_public_key: bytes = b""
    instance_id: str = ""


class AuthFailMessage(msgspec.Struct, tag_field="type", tag=2):
    """Server -> client: handshake rejected (wire tag ``2``).

    ``run_ws`` logs ``reason`` and returns without entering the read loop.
    """

    reason: str = ""


class RequestMessage(msgspec.Struct, tag_field="type", tag=3):
    """Server -> client: an HTTP-style request for the agent (wire tag ``3``).

    ``run_ws`` passes each one to ``handle_request`` and sends back the
    resulting ``ResponseMessage``.
    """

    # Correlates this request with its ResponseMessage / CancelMessage.
    request_id: str = ""
    method: str = "POST"
    path: str = "/"
    headers: dict[str, str] = msgspec.field(default_factory=dict)
    body: bytes = b""
    # Presumably False when more body arrives via UploadChunkMessage frames
    # (chunked-upload capability) — confirm against the spec.
    is_final: bool = True


class ResponseMessage(msgspec.Struct, tag_field="type", tag=4):
    """Client -> server: the agent's reply to a request (wire tag ``4``).

    ``handle_request`` builds one per inbound ``RequestMessage``, copying
    its ``request_id``.
    """

    request_id: str = ""
    status_code: int = 200
    headers: dict[str, str] = msgspec.field(default_factory=dict)
    body: bytes = b""
    # Optional error text; not used by this sample.
    error: str | None = None
    # Presumably False for non-terminal chunks of a streamed response —
    # confirm against the spec. This sample always sends a single final frame.
    is_final: bool = True


class PingMessage(msgspec.Struct, tag_field="type", tag=5):
    """Server -> client liveness probe (wire tag ``5``); answer with ``PongMessage``."""

    # Without DH this sample echoes it back unchanged; with DH the pong HMAC
    # must be recomputed with the session key (see the note in run_ws).
    hmac: bytes = b""


class PongMessage(msgspec.Struct, tag_field="type", tag=6):
    """Client -> server reply to ``PingMessage`` (wire tag ``6``)."""

    hmac: bytes = b""


class ErrorMessage(msgspec.Struct, tag_field="type", tag=7):
    """Per-request error frame (wire tag ``7``).

    NOTE(review): not sent or handled anywhere in this sample and not part
    of the inbound ``ServerMessage`` union; presumably client -> server
    only — confirm direction against docs/RELAY_WIRE_PROTOCOL.md.
    """

    request_id: str = ""
    error: str = ""


class DrainMessage(msgspec.Struct, tag_field="type", tag=8):
    """Server -> client: stop serving (wire tag ``8``); ``run_ws`` logs it and returns."""

    reason: str = ""
    # Presumably how long the server waits before closing the connection —
    # confirm against the spec. This sample only logs it.
    deadline_seconds: float = 0.0


class CancelMessage(msgspec.Struct, tag_field="type", tag=9):
    """Server -> client: abandon the request ``request_id`` (wire tag ``9``).

    This sample handles requests synchronously, so ``run_ws`` only logs it.
    """

    request_id: str = ""
    reason: str = ""


class UploadChunkMessage(msgspec.Struct, tag_field="type", tag=10):
    """Server -> client: one piece of a chunked request body (wire tag ``10``).

    Only expected when the agent declared the "chunked-upload" capability;
    this sample does not, so ``run_ws`` just warns if one arrives.
    """

    request_id: str = ""
    # Chunk position; presumably increasing from 0 — confirm against the spec.
    sequence: int = 0
    body: bytes = b""
    is_final: bool = False


# Server-to-client union for decoding inbound frames.
# msgspec picks the concrete class from each frame's integer "type" tag.
# A frame whose tag is not one of these members (for example ErrorMessage,
# tag 7, which is deliberately absent here) makes inbound_decoder raise a
# msgspec decode/validation error instead of returning a message.
ServerMessage = (
    AuthOkMessage
    | AuthFailMessage
    | RequestMessage
    | PingMessage
    | DrainMessage
    | CancelMessage
    | UploadChunkMessage
)


# Shared MessagePack codecs: ``encoder`` serializes any outbound Struct;
# ``inbound_decoder`` accepts only the ServerMessage union above.
encoder = msgspec.msgpack.Encoder()
inbound_decoder = msgspec.msgpack.Decoder(ServerMessage)


# -- Registration ------------------------------------------------------------


def _make_card(name: str, url: str) -> dict[str, Any]:
    """Minimal A2A v1.0 agent card for registration.

    The ``url`` field's domain must resolve (DNS-must-resolve validator
    on the registration endpoint). For testing, ``https://example.com``
    works. For production, use a domain you control.
    """
    return {
        "card": {
            "name": name,
            "description": (
                "Minimal reference agent. Returns request body wrapped "
                "in {ack: true, echo: <body>}."
            ),
            "version": "1.0.0",
            "protocol_version": "1.0",
            "url": url,
            "skills": [
                {
                    "id": "echo",
                    "name": "Echo",
                    "description": "Returns request body verbatim.",
                    "tags": ["smoke-test", "reference"],
                }
            ],
        },
        "tagline": "Minimal reference",
        "category": "general",
    }


class AgentIdentity(NamedTuple):
    """A registered agent's id, API key and display name (what the key file caches)."""

    agent_id: str  # server-assigned id ("id" in the registration response / cache)
    api_key: str  # live credential sent as AuthMessage.token; log only a prefix
    name: str  # display name including the random registration suffix


async def _register_or_load(
    api: httpx.AsyncClient,
    bot_name: str,
    agent_url: str,
    key_file: Path,
) -> AgentIdentity:
    """Return the agent's identity, registering on first call.

    Caches in ``key_file``; subsequent runs reuse the same agent. Note:
    the cache is keyed by file path, not by ``bot_name`` or ``agent_url``
    — change those env vars and you'll need to delete the file to
    re-register.

    The cache contains the agent's API key, so a newly written cache file
    is created owner-read/write only (mode ``0o600``).

    Args:
        api: HTTP client rooted at the Connect API; used for
            ``POST /v1/agents``.
        bot_name: Base display name; a 6-hex-digit random suffix is
            appended when registering.
        agent_url: URL advertised in the registration card.
        key_file: JSON cache file for the identity.

    Returns:
        The cached identity if ``key_file`` exists, else a newly
        registered one (which is then written to ``key_file``).

    Raises:
        RuntimeError: the registration endpoint returned an HTTP error;
            the message includes the response body.
    """
    if key_file.exists():
        cached = json.loads(key_file.read_text())
        return AgentIdentity(cached["id"], cached["api_key"], cached["name"])

    suffix = uuid.uuid4().hex[:6]
    full_name = f"{bot_name}-{suffix}"
    payload = _make_card(full_name, agent_url)

    response = await api.post("/v1/agents", json=payload)
    if response.is_error:
        # Bubble up the actual validation message so users can see
        # which field failed instead of a bare "422".
        raise RuntimeError(f"registration failed: HTTP {response.status_code} — {response.text}")
    body = response.json()

    identity = AgentIdentity(body["id"], body["api_key"], full_name)
    serialized = json.dumps(
        {"id": identity.agent_id, "api_key": identity.api_key, "name": identity.name}
    )
    # The file holds a live API key and may sit in a shared directory such
    # as /tmp. Create it owner-only instead of with the umask default
    # (commonly world-readable 0o644). The umask can only clear bits, so the
    # result is never broader than 0o600.
    fd = os.open(key_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w", encoding="utf-8") as fh:
        fh.write(serialized)
    return identity


# -- Application logic -------------------------------------------------------


def handle_request(req: RequestMessage) -> ResponseMessage:
    """Build the reply for one inbound ``RequestMessage``.

    This is the sample's only application logic; swap in your own
    handler. The reply is always status 200 with a JSON body of the form
    ``{"ack": true, "echo": <text>}``. ``<text>`` is the request body
    decoded as UTF-8, or ``""`` for an empty body, or ``"<binary>"`` when
    the bytes are not valid UTF-8.
    """
    echo_body = ""
    if req.body:
        try:
            echo_body = req.body.decode("utf-8")
        except UnicodeDecodeError:
            echo_body = "<binary>"

    envelope = {"ack": True, "echo": echo_body}
    encoded = json.dumps(envelope).encode("utf-8")

    # Echo the request id back unchanged so the reply pairs with its request.
    return ResponseMessage(
        request_id=req.request_id,
        body=encoded,
        headers={"content-type": "application/json"},
        status_code=200,
    )


# -- WebSocket loop ----------------------------------------------------------


async def run_ws(ws_url: str, api_key: str, log: logging.Logger) -> None:
    """Authenticate, then process inbound frames until the server drains us.

    Sequence:

    1. Send ``AuthMessage(token=api_key)`` as the first frame after the
       WebSocket upgrade.
    2. Read one frame; continue only if it decodes to ``AuthOkMessage``.
    3. Loop: answer ``PingMessage`` with ``PongMessage``, answer
       ``RequestMessage`` with ``handle_request``'s result, and stop on
       ``DrainMessage``. Cancels and upload chunks are only logged.

    Inbound frames that fail to decode (a ``type`` tag outside
    ``ServerMessage``, or a malformed payload) are logged and skipped
    rather than crashing the session.

    Args:
        ws_url: Relay WebSocket endpoint.
        api_key: Agent API key, sent as the auth token.
        log: Logger for lifecycle events and protocol anomalies.

    Returns ``None`` on auth rejection, an unexpected or undecodable first
    frame, a receive failure inside the read loop, or drain. Errors raised
    while connecting, sending, receiving the first frame, or inside
    ``handle_request`` propagate to the caller.
    """
    async with aconnect_ws(ws_url) as ws:

        async def send(msg: Any) -> None:
            await ws.send_bytes(encoder.encode(msg))

        # 1. Send AuthMessage as the FIRST frame after upgrade.
        await send(AuthMessage(token=api_key))

        # 2. Expect AuthOkMessage (or AuthFailMessage on failure).
        first_raw = await ws.receive_bytes()
        try:
            first = inbound_decoder.decode(first_raw)
        except msgspec.DecodeError as exc:
            log.error("could not decode first frame (%d bytes): %s", len(first_raw), exc)
            return
        if isinstance(first, AuthFailMessage):
            log.error("auth rejected: %s", first.reason)
            return
        if not isinstance(first, AuthOkMessage):
            log.error("unexpected first frame: %s", type(first).__name__)
            return
        log.info(
            "connected agent_id=%s instance_id=%s proxy_url=%s",
            first.agent_id,
            first.instance_id,
            first.proxy_url,
        )

        # 3. Run the read loop.
        while True:
            try:
                raw = await ws.receive_bytes()
            except Exception as exc:
                log.warning("websocket closed (%s): %s", type(exc).__name__, exc)
                return

            try:
                msg = inbound_decoder.decode(raw)
            except msgspec.DecodeError as exc:
                # Unknown tag (e.g. a message type the relay added after this
                # sample was written) or a malformed payload. Dropping one
                # frame is far cheaper than tearing down every in-flight
                # request with an uncaught exception.
                log.warning("dropping undecodable frame (%d bytes): %s", len(raw), exc)
                continue

            if isinstance(msg, PingMessage):
                # Echo the hmac as-is. We did not negotiate DH, so the server
                # does not validate it. With DH enabled, recompute over
                # b"connect-pong" using the session key (see spec § HMAC mode).
                await send(PongMessage(hmac=msg.hmac))
                continue

            if isinstance(msg, RequestMessage):
                # NOTE: an exception from handle_request propagates and ends
                # the session; wrap it if your handler can fail.
                await send(handle_request(msg))
                continue

            if isinstance(msg, DrainMessage):
                log.info(
                    "server signalled drain (reason=%s deadline=%.1fs); exiting",
                    msg.reason,
                    msg.deadline_seconds,
                )
                return

            if isinstance(msg, CancelMessage):
                # Abort the in-flight task for msg.request_id. This sample
                # is synchronous so there is nothing to cancel.
                log.debug("cancel request_id=%s reason=%s", msg.request_id, msg.reason)
                continue

            if isinstance(msg, UploadChunkMessage):
                # We did not declare "chunked-upload" capability; server
                # should never send these.
                log.warning("unexpected UploadChunkMessage")
                continue

            # Reached by union members not handled above (e.g. a stray
            # AuthOkMessage / AuthFailMessage after the handshake).
            log.warning("unknown message type %s", type(msg).__name__)


# -- Entrypoint --------------------------------------------------------------


async def main() -> int:
    """Run the sample agent end to end and return a process exit code.

    Configures logging, reads settings from the environment (every
    variable has a default), registers or loads the cached identity, then
    hands off to ``run_ws`` until it returns.

    Returns:
        Always ``0``. ``run_ws`` logs and returns on auth rejection or
        drain, so those also exit 0; registration and connection errors
        propagate as exceptions instead.
    """
    # logging only recognises upper-case level names ("DEBUG"); normalise
    # so LOG_LEVEL=debug works instead of raising ValueError at startup.
    logging.basicConfig(
        level=os.environ.get("LOG_LEVEL", "INFO").strip().upper(),
        format="%(asctime)s [%(name)s] %(levelname)s %(message)s",
    )

    bot_name = os.environ.get("BOT_NAME", "minimal-bot")
    log = logging.getLogger(bot_name)

    connect_api = os.environ.get("CONNECT_API", "https://api.actex.ai/connect").rstrip("/")
    connect_ws = os.environ.get("CONNECT_WS", "wss://api.actex.ai/connect/ws/agent")
    # AGENT_URL must point at a domain that resolves (DNS-must-resolve
    # validator on POST /v1/agents). example.com is a good default for
    # testing; in production, use a domain you control.
    agent_url = os.environ.get("AGENT_URL", f"https://example.com/agents/{bot_name}")
    key_file = Path(os.environ.get("BOT_KEY_FILE", f"/tmp/{bot_name}.json"))

    async with httpx.AsyncClient(base_url=connect_api, timeout=20.0) as api:
        identity = await _register_or_load(api, bot_name, agent_url, key_file)
        # Log only a short prefix of the API key: it is a live credential.
        log.info(
            "agent_id=%s name=%s key_prefix=%s...",
            identity.agent_id,
            identity.name,
            identity.api_key[:8],
        )
        await run_ws(connect_ws, identity.api_key, log)

    return 0


if __name__ == "__main__":
    try:
        sys.exit(asyncio.run(main()))
    except KeyboardInterrupt:
        # Ctrl-C: exit quietly with the conventional 128 + SIGINT status
        # rather than printing an asyncio cancellation traceback.
        sys.exit(130)
