import argparse
import json
import os
import sys
from pathlib import Path
from urllib.error import HTTPError, URLError
from urllib.request import Request, urlopen

# Exa search endpoint that receives the POSTed JSON payload.
API_URL = "https://api.exa.ai/search"

# Values accepted by --type; forwarded unchanged as the payload's "type" field.
SEARCH_TYPES = [
    "auto",
    "fast",
    "instant",
    "deep",
    "deep-reasoning",
    "neural",
]


def load_json_text(value: str | None, label: str) -> dict | None:
    if not value:
        return None
    try:
        return json.loads(value)
    except json.JSONDecodeError as exc:
        raise SystemExit(f"Invalid JSON for {label}: {exc}")


def load_json_file(path: str | None, label: str) -> dict | None:
    if not path:
        return None
    resolved = Path(path).expanduser()
    try:
        return json.loads(resolved.read_text())
    except FileNotFoundError:
        raise SystemExit(f"{label} file not found at {resolved}")
    except json.JSONDecodeError as exc:
        raise SystemExit(f"Invalid JSON in {label}: {exc}")


def format_http_error(exc: HTTPError) -> str:
    """Render an HTTP error response as a single human-readable line.

    When the response body is a JSON object, its ``error`` field (falling
    back to the raw body), ``tag`` and ``requestId`` are included, joined by
    `` | ``. A body that is not JSON, or is JSON but not an object, is shown
    verbatim. An empty body is replaced by the HTTP reason phrase.

    Args:
        exc: The HTTPError raised by ``urlopen``; its body is consumed.

    Returns:
        A message of the form ``"API error <code>: <detail>[ | tag=...][ | requestId=...]"``.
    """
    raw = exc.read().decode("utf-8", errors="replace")
    if not raw.strip():
        # Without this the message would end in a bare "API error 502: ".
        return f"API error {exc.code}: {exc.reason}"
    try:
        payload = json.loads(raw)
    except json.JSONDecodeError:
        return f"API error {exc.code}: {raw}"
    if not isinstance(payload, dict):
        # Valid JSON that is not an object (array, string, number) has no .get().
        return f"API error {exc.code}: {raw}"
    error = payload.get("error") or raw
    tag = payload.get("tag")
    request_id = payload.get("requestId")
    parts = [f"API error {exc.code}: {error}"]
    if tag:
        parts.append(f"tag={tag}")
    if request_id:
        parts.append(f"requestId={request_id}")
    return " | ".join(parts)


def summarize_results(resp: dict) -> None:
    """Print a short digest of the first three results in a search response.

    Each line shows the result's title (``"Untitled"`` when missing), its URL
    (falling back to the result ``id``, then ``"(no url)"``), and, when
    present, the published date and author in parentheses. A placeholder
    line is printed when the response carries no results.
    """
    hits = resp.get("results", [])
    print()
    print("Results summary:")
    if not hits:
        print("  (no results returned)")
        return
    shown = hits[:3]
    for position, hit in enumerate(shown, 1):
        label = hit.get("title") or "Untitled"
        link = hit.get("url") or hit.get("id") or "(no url)"
        # Keep only the metadata fields that are actually populated.
        details = [
            f"{prefix} {value}"
            for prefix, value in (
                ("published", hit.get("publishedDate")),
                ("author", hit.get("author")),
            )
            if value
        ]
        tail = " (" + "; ".join(details) + ")" if details else ""
        print(f"  {position}. {label} — {link}{tail}")


def summarize_grounding(resp: dict) -> None:
    """Print the grounding/citation section of a search response, if present.

    Reads ``resp["output"]["grounding"]``, a list of entries each carrying a
    ``field`` (default ``"content"``), a ``confidence`` and a list of
    ``citations`` with ``title``/``url``. Nothing is printed when ``output``
    is absent, is not an object, or has no grounding entries.
    """
    output = resp.get("output")
    if not isinstance(output, dict):
        # Missing output, or an output that isn't an object (e.g. plain text),
        # has no grounding to report; the original `.get` on it would crash.
        return
    grounding = output.get("grounding") or []
    if not grounding:
        return
    print()
    print("Grounding:")
    for entry in grounding:
        field = entry.get("field", "content")
        confidence = entry.get("confidence")
        # Compare against None/"" instead of truthiness so a numeric
        # confidence of 0 (if one is ever reported) isn't shown as "unknown".
        if confidence is None or confidence == "":
            confidence = "unknown"
        print(f"  Field: {field} (confidence: {confidence})")
        citations = entry.get("citations") or []
        if not citations:
            print("    (no citations)")
            continue
        for cite in citations:
            title = cite.get("title") or cite.get("url") or "(unknown)"
            url = cite.get("url") or "(no url)"
            print(f"    • {title} — {url}")


def _build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser(description="Call the Exa Search API with a reusable payload builder and grounding summary.")
    parser.add_argument("--query", required=True, help="Natural-language search query.")
    parser.add_argument("--type", choices=SEARCH_TYPES, default="auto", help="Search type.")
    parser.add_argument("--num-results", type=int, default=10, help="How many results to request (1-100).")
    parser.add_argument("--category", help="Restrict to a category (company, people, research paper, news, etc.).")
    parser.add_argument("--user-location", help="Two-letter ISO country code to bias results.")
    parser.add_argument("--include-domains", nargs="+", help="Domains to include.")
    parser.add_argument("--exclude-domains", nargs="+", help="Domains to exclude.")
    parser.add_argument("--start-published-date", help="ISO date YYYY-MM-DD or full ISO 8601.")
    parser.add_argument("--end-published-date", help="ISO date YYYY-MM-DD or full ISO 8601.")
    parser.add_argument("--start-crawl-date", help="ISO date YYYY-MM-DD or full ISO 8601.")
    parser.add_argument("--end-crawl-date", help="ISO date YYYY-MM-DD or full ISO 8601.")
    parser.add_argument("--include-text", nargs="+", help="Strings that must appear in results.")
    parser.add_argument("--exclude-text", nargs="+", help="Strings that must not appear in results.")
    parser.add_argument("--moderation", action="store_true", help="Enable moderation filter.")
    parser.add_argument("--additional-queries", nargs="+", help="Extra query variations for deep-only searches.")
    parser.add_argument("--system-prompt", help="System prompt for deep or deep-reasoning searches.")
    parser.add_argument("--output-schema-file", help="Path to JSON schema file.")
    # Both options fill the same "contents" field; accepting both at once
    # used to silently ignore --contents-file, so make them exclusive.
    contents_group = parser.add_mutually_exclusive_group()
    contents_group.add_argument("--contents", help="JSON string for the contents object.")
    contents_group.add_argument("--contents-file", help="Path to JSON file for the contents object.")
    parser.add_argument("--extra", help="JSON file whose fields are merged into the payload.")
    parser.add_argument("--dry-run", action="store_true", help="Print payload without sending it.")
    return parser


def _build_payload(args: argparse.Namespace) -> dict[str, object]:
    """Translate parsed CLI arguments into the JSON request body."""
    payload: dict[str, object] = {
        "query": args.query,
        "type": args.type,
        "numResults": args.num_results,
    }

    # Optional fields are only added when the user set them. Order matches
    # the original builder so --dry-run output is unchanged.
    optional_fields = (
        ("category", args.category),
        ("userLocation", args.user_location),
        ("startPublishedDate", args.start_published_date),
        ("endPublishedDate", args.end_published_date),
        ("startCrawlDate", args.start_crawl_date),
        ("endCrawlDate", args.end_crawl_date),
        ("includeDomains", args.include_domains),
        ("excludeDomains", args.exclude_domains),
        ("includeText", args.include_text),
        ("excludeText", args.exclude_text),
        ("moderation", args.moderation),  # store_true: True when set
        ("additionalQueries", args.additional_queries),
        ("systemPrompt", args.system_prompt),
    )
    for key, value in optional_fields:
        if value:
            payload[key] = value

    # Branch on which option was given rather than chaining with `or`, so an
    # explicit empty object (`--contents '{}'`) is still sent.
    if args.contents is not None:
        contents = load_json_text(args.contents, "contents")
    else:
        contents = load_json_file(args.contents_file, "contents")
    if contents is not None:
        payload["contents"] = contents

    if args.output_schema_file:
        payload["outputSchema"] = load_json_file(args.output_schema_file, "outputSchema")

    # Merged last, so --extra fields override anything built above.
    extra = load_json_file(args.extra, "extra")
    if extra:
        payload.update(extra)
    return payload


def _post_search(payload: dict[str, object], api_key: str, timeout: float = 60) -> dict:
    """POST ``payload`` to API_URL and return the decoded JSON object response.

    Raises:
        SystemExit: on an HTTP error status, a network failure or timeout, or
            a response body that is not a JSON object.
    """
    request = Request(
        API_URL,
        data=json.dumps(payload).encode("utf-8"),
        headers={
            "Content-Type": "application/json",
            "x-api-key": api_key,
            "User-Agent": "exa-deep-search-skill/1.0",
        },
        method="POST",
    )

    try:
        with urlopen(request, timeout=timeout) as response:
            body = response.read()
    except HTTPError as exc:  # must precede URLError, its base class
        raise SystemExit(format_http_error(exc))
    except URLError as exc:
        raise SystemExit(f"Network error: {exc.reason}")
    except TimeoutError:
        # A timeout while reading the body is raised directly, not wrapped in URLError.
        raise SystemExit(f"Network error: request timed out after {timeout} seconds")
    except OSError as exc:
        # e.g. the connection was reset mid-response.
        raise SystemExit(f"Network error: {exc}")

    try:
        resp_json = json.loads(body.decode("utf-8"))
    except ValueError as exc:  # UnicodeDecodeError and JSONDecodeError
        raise SystemExit(f"Could not decode API response as JSON: {exc}")
    if not isinstance(resp_json, dict):
        raise SystemExit(f"Unexpected API response (expected a JSON object): {resp_json!r}")
    return resp_json


def _print_report(resp_json: dict) -> None:
    """Print the full response, the reported cost, and the result/grounding summaries."""
    print("=== Response ===")
    print(json.dumps(resp_json, indent=2))

    cost_info = resp_json.get("costDollars")
    cost = cost_info.get("total") if isinstance(cost_info, dict) else None
    # Only numbers can be formatted with `:.6f`; anything else is skipped.
    if isinstance(cost, (int, float)):
        print(f"\nEstimated cost: ${cost:.6f}")

    summarize_results(resp_json)
    summarize_grounding(resp_json)


def main() -> None:
    """Run the CLI: build the search payload, then print it or send it.

    With --dry-run the payload is printed and no request is made, so no API
    key is needed. Otherwise EXA_API_KEY must be set. The payload is POSTed
    and the response is printed along with a short summary.
    """
    parser = _build_parser()
    args = parser.parse_args()
    # Reject non-positive counts locally; the upper bound is left to the API.
    if args.num_results < 1:
        parser.error("--num-results must be at least 1.")

    payload = _build_payload(args)

    if args.dry_run:
        print("Dry run payload:")
        print(json.dumps(payload, indent=2))
        return

    api_key = os.environ.get("EXA_API_KEY")
    if not api_key:
        raise SystemExit("Set EXA_API_KEY in your environment before running this script.")

    _print_report(_post_search(payload, api_key))


# Entry-point guard: run the CLI only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()
