from __future__ import annotations

import html
import json
import re
import resource
import sys
import textwrap
import time
from pathlib import Path

# Machine-specific locations. ROOT is the local llm-foundry checkout (its src/
# tree is imported below and LocalRetriever searches it in main()); OUT is where
# the report artifacts are written.
# NOTE(review): these absolute paths only exist on the original author's
# machine — consider an environment-variable or CLI override.
ROOT = Path('/home/.z/workspaces/con_2gAuSTkawPiOse8J/llm-foundry')
SRC = ROOT / 'src'
OUT = Path('/home/workspace/Deliverables/kvquant-bitforge-side-by-side-proof')
# Model identifier handed to HuggingFacePipelineBackend in main().
MODEL = 'Qwen/Qwen2.5-0.5B-Instruct'

# Make the repo's src/ package importable without installing it; this must run
# before the llm_foundry imports that follow.
if str(SRC) not in sys.path:
    sys.path.insert(0, str(SRC))

from llm_foundry.adapters import HuggingFacePipelineBackend
from llm_foundry.memory import CompressionEngine, ObsidianMemoryVault
from llm_foundry.rag import LocalRetriever
from llm_foundry.tokenizer import estimate_token_count

# The launch-priority question; it opens both the "before" and the "after" prompt.
QUESTION = (
    'A small product team has one day before launch. The checkout sometimes fails, '
    'but the dashboard is only slow. Which should they fix first, and why? '
    'Answer in exactly 4 bullets.'
)

# Supporting facts. main() appends them verbatim to the "before" prompt and hands
# them to CompressionEngine.compress_transcript as the transcript for the "after"
# prompt.
NOISY_CONTEXT = [
    'The checkout bug blocks payment completion for a subset of users.',
    'The dashboard is slow, but it does not stop people from buying.',
    'The team has limited time and only one engineer available for the fix.',
    'The launch date is tomorrow.',
    'The team wants a short answer with a clear priority and a practical reason.',
    'The team wants the answer to be easy to paste into a status update.',
    'The team already knows speed matters, but blocking revenue matters more.',
]

# Keywords score_output searches for (as substrings of the lower-cased answer)
# when computing its content score.
RUBRIC_TERMS = [
    'checkout',
    'payment',
    'dashboard',
    'priority',
    'first',
    'revenue',
    'trust',
    'blocked',
    'risk',
    'launch',
]


def run_prompt(backend: HuggingFacePipelineBackend, prompt: str) -> tuple[str, float]:
    """Generate a completion for *prompt* and measure how long it took.

    Returns:
        ``(output, latency_ms)``: whatever ``backend.generate`` returned, and the
        wall-clock duration of that single call in milliseconds, measured with
        ``time.perf_counter``.
    """
    began = time.perf_counter()
    completion = backend.generate(prompt)
    elapsed_seconds = time.perf_counter() - began
    return completion, elapsed_seconds * 1000


# A bullet line is optional indent, then a list marker ('-', '*', '+', '•', '1.'
# or '1)'), then whitespace, then content. Requiring whitespace after the marker
# keeps horizontal rules ('---'), bold lead-ins ('**Fix**') and decimals
# ('1.5 s') from being counted.
_BULLET_LINE = re.compile(r'^\s*(?:[-*+•]|\d+[.)])\s+\S')


def score_output(
    text: str,
    *,
    expected_bullets: int = 4,
    rubric_terms: list[str] | None = None,
) -> dict[str, float | int | bool]:
    """Score a model answer with a cheap structure + keyword rubric.

    The structure score peaks at 1.0 when the answer has exactly
    ``expected_bullets`` bullet lines. It falls off linearly with the distance
    from that count, so too many bullets are penalised just like too few. For
    counts up to the target it equals ``count / expected_bullets``. The content
    score is the number of rubric terms found in the lower-cased answer,
    saturating at 6 hits.

    Args:
        text: Raw model output.
        expected_bullets: Bullet count the prompt asks for (QUESTION says
            "exactly 4 bullets"). Must be >= 1.
        rubric_terms: Keywords to search for; defaults to ``RUBRIC_TERMS``.

    Returns:
        Dict with the raw counts, both sub-scores, the weighted ``overall``
        score (55% structure, 45% content) and a few boolean topic flags.

    Raises:
        ValueError: if ``expected_bullets`` is less than 1.
    """
    if expected_bullets < 1:
        raise ValueError(f'expected_bullets must be >= 1, got {expected_bullets}')
    terms = RUBRIC_TERMS if rubric_terms is None else rubric_terms
    lowered = text.lower()
    bullet_count = sum(1 for line in text.splitlines() if _BULLET_LINE.match(line))
    term_hits = sum(1 for term in terms if term.lower() in lowered)
    # Symmetric fall-off around the target: 0 bullets -> 0.0, target -> 1.0,
    # and 2x the target (or more) -> 0.0.
    structure_score = max(0.0, 1.0 - abs(bullet_count - expected_bullets) / expected_bullets)
    content_score = min(1.0, term_hits / 6)
    overall = round((structure_score * 0.55) + (content_score * 0.45), 3)
    return {
        'bullet_count': bullet_count,
        'term_hits': term_hits,
        'structure_score': round(structure_score, 3),
        'content_score': round(content_score, 3),
        'overall': overall,
        'has_checkout': 'checkout' in lowered or 'payment' in lowered,
        'has_dashboard': 'dashboard' in lowered,
        'has_priority_reason': any(term in lowered for term in ['priority', 'because', 'first', 'risk', 'launch']),
    }


def rss_mb() -> float:
    """Return this process's peak resident set size in MiB.

    ``ru_maxrss`` is a high-water mark, not the current RSS. Its unit is
    platform-dependent: kilobytes on Linux, bytes on macOS. Dividing by 1024
    alone would report macOS values 1024x too large.
    """
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    if sys.platform == 'darwin':
        return peak / (1024.0 * 1024.0)
    return peak / 1024.0


def esc(value: str) -> str:
    """HTML-escape *value* so it can be embedded in the report markup.

    Escapes ``&``, ``<`` and ``>``, plus both quote characters (``quote=True``
    is ``html.escape``'s default, spelled out here).
    """
    return html.escape(value, quote=True)


def render_panel(title: str, content: str) -> str:
    """Return a ``panel`` section: an escaped ``<h3>`` title over an escaped ``<pre>`` body."""
    heading = esc(title)
    body = esc(content)
    return ''.join((
        "<section class='panel'>",
        f'<h3>{heading}</h3>',
        f'<pre>{body}</pre>',
        '</section>',
    ))


def render_card(label: str, value: str) -> str:
    """Return a ``metric`` tile: the escaped value rendered above its escaped label."""
    shown_value = esc(value)
    shown_label = esc(label)
    return ''.join((
        "<div class='metric'>",
        f"<div class='metric-value'>{shown_value}</div>",
        f"<div class='metric-label'>{shown_label}</div>",
        '</div>',
    ))


def build_html(data: dict) -> str:
    """Render the self-contained HTML report for one before/after comparison.

    Args:
        data: The comparison dict assembled in ``main`` (the same dict is dumped
            to comparison.json). Keys read here:
            - ``model``, ``question_short``, ``memory_saved_pct``,
              ``memory_block``, ``terminal_transcript`` and ``retrieval_text``.
            - ``before`` and ``after``, each holding ``prompt``,
              ``prompt_tokens``, ``latency_ms``, ``output`` and a ``score`` dict
              (``overall``, ``bullet_count``, ``term_hits``).

    Returns:
        A complete HTML document as a string. Free-text fields go through
        ``esc`` (directly or via ``render_panel``/``render_card``); numeric
        fields are interpolated with format specs.
    """
    before = data['before']
    after = data['after']
    return f"""<!doctype html>
<html lang='en'>
<head>
  <meta charset='utf-8' />
  <meta name='viewport' content='width=device-width, initial-scale=1' />
  <title>KVQuant / BitForge side-by-side proof</title>
  <style>
    :root {{ color-scheme: dark; }}
    body {{ margin: 0; background: #070c16; color: #e8eeff; font-family: Inter, system-ui, sans-serif; }}
    .wrap {{ max-width: 1420px; margin: 0 auto; padding: 28px; }}
    .hero, .panel, .box, .card {{ background: linear-gradient(180deg, rgba(255,255,255,.06), rgba(255,255,255,.02)); border: 1px solid rgba(255,255,255,.08); border-radius: 20px; padding: 22px; margin: 16px 0; box-shadow: 0 18px 70px rgba(0,0,0,.26); }}
    h1 {{ font-size: clamp(2.2rem, 5vw, 4rem); line-height: .96; margin: 0 0 10px; }}
    h2, h3 {{ margin: 0 0 10px; }}
    p, li {{ color: #bdc7e8; line-height: 1.65; }}
    .chips {{ display: flex; flex-wrap: wrap; gap: 8px; margin-top: 16px; }}
    .chip {{ padding: 6px 10px; border-radius: 999px; background: rgba(124,156,255,.13); border: 1px solid rgba(124,156,255,.25); color: #dfe7ff; font-size: .9rem; }}
    .metrics {{ display: grid; grid-template-columns: repeat(4, minmax(0, 1fr)); gap: 14px; }}
    .metric {{ background: #091022; border: 1px solid rgba(255,255,255,.08); border-radius: 18px; padding: 18px; }}
    .metric-value {{ font-size: 2rem; font-weight: 800; color: #86efac; line-height: 1.05; }}
    .metric-label {{ color: #aab4d4; margin-top: 8px; }}
    .compare {{ display: grid; grid-template-columns: repeat(2, minmax(0, 1fr)); gap: 16px; }}
    .compare-col {{ display: grid; gap: 16px; }}
    pre, table {{ background: #091022; border: 1px solid rgba(255,255,255,.08); border-radius: 14px; }}
    pre {{ padding: 14px; margin: 0; white-space: pre-wrap; overflow: auto; }}
    table {{ width: 100%; border-collapse: collapse; overflow: hidden; }}
    th, td {{ padding: 10px 9px; border-bottom: 1px solid rgba(255,255,255,.08); text-align: left; vertical-align: top; }}
    th {{ text-transform: uppercase; letter-spacing: .12em; font-size: .78rem; color: #86efac; }}
    .small {{ color: #90a0c7; font-size: .95rem; }}
    a {{ color: #8ab4ff; text-decoration: none; }}
    a:hover {{ text-decoration: underline; }}
    @media (max-width: 960px) {{ .compare {{ grid-template-columns: 1fr; }} .metrics {{ grid-template-columns: repeat(2, minmax(0, 1fr)); }} }}
  </style>
</head>
<body>
  <div class='wrap'>
    <section class='hero'>
      <h1>KVQuant / BitForge: side-by-side proof of the same local model</h1>
      <p>This is the version that is easy to read: the before prompt and answer are on the left, the after prompt and answer are on the right, and the metrics sit above and below them. Same local model. Same question. Different stack.</p>
      <div class='chips'>
        <span class='chip'>model={esc(data['model'])}</span>
        <span class='chip'>question={esc(data['question_short'])}</span>
        <span class='chip'>before tokens={before['prompt_tokens']}</span>
        <span class='chip'>after tokens={after['prompt_tokens']}</span>
        <span class='chip'>memory saved={data['memory_saved_pct']:.1f}%</span>
      </div>
    </section>

    <section class='metrics'>
      {render_card('Before latency', f"{before['latency_ms']:.1f} ms")}
      {render_card('After latency', f"{after['latency_ms']:.1f} ms")}
      {render_card('Before score', f"{before['score']['overall']:.3f}")}
      {render_card('After score', f"{after['score']['overall']:.3f}")}
    </section>

    <section class='compare'>
      <div class='compare-col'>
        <div class='box'>
          <h2>Before</h2>
          <p class='small'>Raw prompt, no compression, no semantic retrieval.</p>
        </div>
        {render_panel('Prompt', before['prompt'])}
        {render_panel('Answer', before['output'])}
      </div>
      <div class='compare-col'>
        <div class='box'>
          <h2>After</h2>
          <p class='small'>Compressed context, memory notes, same question.</p>
        </div>
        {render_panel('Prompt', after['prompt'])}
        {render_panel('Answer', after['output'])}
      </div>
    </section>

    <section class='box'>
      <h2>Scores</h2>
      <table>
        <thead><tr><th>Version</th><th>Prompt tokens</th><th>Latency</th><th>Accuracy-like score</th><th>Bullets</th><th>Term hits</th></tr></thead>
        <tbody>
          <tr><td>Before</td><td>{before['prompt_tokens']}</td><td>{before['latency_ms']:.1f} ms</td><td>{before['score']['overall']:.3f}</td><td>{before['score']['bullet_count']}</td><td>{before['score']['term_hits']}</td></tr>
          <tr><td>After</td><td>{after['prompt_tokens']}</td><td>{after['latency_ms']:.1f} ms</td><td>{after['score']['overall']:.3f}</td><td>{after['score']['bullet_count']}</td><td>{after['score']['term_hits']}</td></tr>
        </tbody>
      </table>
    </section>

    <section class='box'>
      <h2>Memory and retrieval</h2>
      <pre>{esc(data['memory_block'])}</pre>
    </section>

    <section class='box'>
      <h2>Terminal transcript</h2>
      <pre>{esc(data['terminal_transcript'])}</pre>
    </section>

    <section class='box'>
      <h2>Repo retrieval hits</h2>
      <pre>{esc(data['retrieval_text'])}</pre>
    </section>

    <section class='box'>
      <h2>Links</h2>
      <p>
        GitHub: <a href='https://github.com/AmSach/llm-foundry'>https://github.com/AmSach/llm-foundry</a><br />
        GitHub profile: <a href='https://github.com/AmSach'>https://github.com/AmSach</a><br />
        Instagram: <a href='https://www.instagram.com/i.amsach'>https://www.instagram.com/i.amsach</a><br />
        LinkedIn: <a href='https://www.linkedin.com/in/theamansachan'>https://www.linkedin.com/in/theamansachan</a>
      </p>
    </section>
  </div>
</body>
</html>"""


def _render_terminal_transcript(before: dict, after: dict, memory_saved_pct: float, peak_rss: float) -> str:
    """Render the plain-text transcript written to terminal_transcript.txt.

    Lines are joined explicitly instead of dedenting an indented template. The
    interpolated prompts and answers span several lines that start at column 0,
    so ``textwrap.dedent`` would find no common indent and leave every template
    line indented.
    """
    lines = [
        '== KVQuant / BitForge side-by-side proof ==',
        f'model={MODEL}',
        f"before_prompt_tokens={before['prompt_tokens']}",
        f"after_prompt_tokens={after['prompt_tokens']}",
        f'memory_saved_pct={memory_saved_pct:.1f}%',
        f'peak_rss_mb={peak_rss:.1f}',
        '',
        'QUESTION:',
        QUESTION,
        '',
        'BEFORE PROMPT:',
        before['prompt'],
        '',
        'BEFORE ANSWER:',
        before['output'],
        '',
        'AFTER PROMPT:',
        after['prompt'],
        '',
        'AFTER ANSWER:',
        after['output'],
        '',
        'DELTA:',
        f"latency_delta_ms={after['latency_ms'] - before['latency_ms']:.1f}",
        f"prompt_tokens_saved={before['prompt_tokens'] - after['prompt_tokens']}",
        f"accuracy_delta={after['score']['overall'] - before['score']['overall']:.3f}",
        f'memory_saved_pct={memory_saved_pct:.1f}%',
    ]
    # map(str, ...) mirrors the implicit str() an f-string template applied.
    return '\n'.join(map(str, lines))


def _render_memory_block(compressed_prompt: str) -> str:
    """Render the before/after memory notes plus the compressed context that was sent."""
    lines = [
        'BEFORE memory notes',
        '- raw prompt',
        '- clutter stays in context',
        '- no compressed retrieval',
        '',
        'AFTER memory notes',
        '- compressed context',
        '- relevant retrieval',
        '- smaller working set',
        '',
        'compressed context used after:',
        str(compressed_prompt),
    ]
    return '\n'.join(lines).strip()


def main() -> None:
    """Run the before/after comparison and write the report artifacts.

    The run:
    1. Seeds a memory vault rooted at ``OUT / 'memory-vault'``.
    2. Asks MODEL the same question twice: once with the raw NOISY_CONTEXT, and
       once with context produced by CompressionEngine.
    3. Scores both answers and records the repo retrieval hits.

    It writes ``comparison.json``, ``terminal_transcript.txt`` and
    ``report.html`` (all UTF-8) under ``OUT`` and prints a metric summary.
    """
    OUT.mkdir(parents=True, exist_ok=True)
    vault = ObsidianMemoryVault(OUT / 'memory-vault')
    # NOTE(review): if the vault persists notes on disk, re-running this script
    # may accumulate duplicate notes — confirm ObsidianMemoryVault.add_note semantics.
    vault.add_note('Priority rule', 'If a task blocks money, trust, or correctness, fix it before polish.', tags=['priority', 'rule'])
    vault.add_note('Before note', 'Before: the model sees a raw prompt, more clutter, and no compressed memory.', tags=['before', 'kvquant'])
    vault.add_note('After note', 'After: the model sees compressed context, relevant retrieval, and a smaller working set.', tags=['after', 'bitforge'])

    backend = HuggingFacePipelineBackend(MODEL, max_new_tokens=128)
    compressor = CompressionEngine(vault=vault)
    retriever = LocalRetriever(ROOT)

    # Untimed warm-up. Whichever prompt is timed first would otherwise absorb
    # one-off startup costs (presumably lazy model / kernel initialisation),
    # biasing the latency comparison against the "before" run.
    backend.generate(QUESTION)

    before_prompt = QUESTION + '\n\n' + '\n'.join(NOISY_CONTEXT)
    before_output, before_latency = run_prompt(backend, before_prompt)
    before_score = score_output(before_output)
    before_prompt_tokens = estimate_token_count(before_prompt)

    compressed_context = compressor.compress_transcript(
        task='Answer the launch-priority question',
        transcript=NOISY_CONTEXT,
        memory_query='priority rule blocks money trust correctness checkout dashboard',
        target_tokens=120,
    )
    # Render once so the context shown in the report is exactly what was sent.
    compressed_prompt = compressed_context.to_prompt()
    after_prompt = QUESTION + '\n\n' + compressed_prompt
    after_output, after_latency = run_prompt(backend, after_prompt)
    after_score = score_output(after_output)
    after_prompt_tokens = estimate_token_count(after_prompt)

    hits = retriever.search('priority rule checkout dashboard launch trust correctness', top_k=4)
    retrieval_text = '\n'.join(f'{hit.path} | score={hit.score:.3f} | {hit.text[:180]}' for hit in hits) or 'no retrieval hits'

    # Relative reduction in estimated prompt tokens (the "memory" key name is
    # kept unchanged for output compatibility).
    memory_saved_pct = 100.0 * (1 - (after_prompt_tokens / before_prompt_tokens)) if before_prompt_tokens else 0.0
    peak_rss = rss_mb()

    before = {
        'prompt': before_prompt,
        'prompt_tokens': before_prompt_tokens,
        'latency_ms': before_latency,
        'output': before_output,
        'score': before_score,
    }
    after = {
        'prompt': after_prompt,
        'prompt_tokens': after_prompt_tokens,
        'latency_ms': after_latency,
        'output': after_output,
        'score': after_score,
    }
    terminal_transcript = _render_terminal_transcript(before, after, memory_saved_pct, peak_rss)
    memory_block = _render_memory_block(compressed_prompt)

    data = {
        'model': MODEL,
        'question_short': QUESTION,
        'before': before,
        'after': after,
        'memory_saved_pct': memory_saved_pct,
        'memory_block': memory_block,
        'terminal_transcript': terminal_transcript,
        'retrieval_text': retrieval_text,
    }

    # Explicit UTF-8 everywhere: ensure_ascii=False and raw model output may
    # contain non-ASCII text that the locale's default encoding cannot encode.
    (OUT / 'comparison.json').write_text(json.dumps(data, indent=2, ensure_ascii=False), encoding='utf-8')
    (OUT / 'terminal_transcript.txt').write_text(terminal_transcript, encoding='utf-8')
    (OUT / 'report.html').write_text(build_html(data), encoding='utf-8')

    print(OUT / 'report.html')
    print(OUT / 'terminal_transcript.txt')
    print(f"before_latency_ms={before_latency:.1f}")
    print(f"after_latency_ms={after_latency:.1f}")
    print(f"before_accuracy={before_score['overall']:.3f}")
    print(f"after_accuracy={after_score['overall']:.3f}")
    print(f"before_prompt_tokens={before_prompt_tokens}")
    print(f"after_prompt_tokens={after_prompt_tokens}")
    print(f"memory_saved_pct={memory_saved_pct:.1f}")
    print(f"peak_rss_mb={peak_rss:.1f}")


if __name__ == '__main__':
    main()
