from __future__ import annotations

import json
import resource
import sys
import textwrap
from pathlib import Path

# Repo layout: local llm-foundry checkout and its src/ package root.
ROOT = Path('/home/.z/workspaces/con_2gAuSTkawPiOse8J/llm-foundry')
SRC = ROOT / 'src'
# Destination directory for the generated report/transcript/JSON deliverables.
OUT = Path('/home/workspace/Deliverables/qwen-sky-proof')
# Local instruct model and generation cap applied to both runs.
MODEL = 'Qwen/Qwen2.5-0.5B-Instruct'
MAX_NEW = 64

# Make the in-repo llm_foundry package importable without installing it.
sys.path.insert(0, str(SRC))

from llm_foundry.adapters import HuggingFacePipelineBackend
from llm_foundry.memory import CompressionEngine, ObsidianMemoryVault
from llm_foundry.tokenizer import estimate_token_count

def rss_mb():
    """Return this process's peak resident set size, scaled to MB.

    NOTE(review): ``ru_maxrss`` is reported in kilobytes on Linux but in
    bytes on macOS — the ``/ 1024`` scaling assumes Linux; confirm the
    target platform.
    """
    usage = resource.getrusage(resource.RUSAGE_SELF)
    return usage.ru_maxrss / 1024.0

def run_prompt(backend, prompt):
    """Generate a completion for *prompt* on *backend* and time it.

    Args:
        backend: any object exposing ``generate(prompt) -> str``.
        prompt: the full prompt text to send.

    Returns:
        ``(output, elapsed_ms)`` — the backend's output and wall-clock
        latency in milliseconds (``time.perf_counter`` based).

    Fix: the original relied on a module-level ``import time`` that
    appears *after* this definition in the file; importing locally makes
    the function self-contained and immune to import reordering.
    """
    import time

    start = time.perf_counter()
    output = backend.generate(prompt)
    elapsed_ms = (time.perf_counter() - start) * 1000
    return output, elapsed_ms

import time

def esc(t):
    """HTML-escape *t* (None becomes ''); '&' is replaced first so the
    entities introduced by later replacements are not double-escaped."""
    text = t or ''
    for char, entity in (
        ('&', '&amp;'),
        ('<', '&lt;'),
        ('>', '&gt;'),
        ('"', '&quot;'),
        ("'", '&#39;'),
    ):
        text = text.replace(char, entity)
    return text

def make_report():
    """Run one task twice on a local model — bare vs. memory-augmented —
    and write ``report.html``, ``terminal_transcript.txt`` and
    ``comparison.json`` into ``OUT``.

    Fixes vs. original: two SyntaxErrors inside the HTML f-string (an
    assignment expression and mismatched braces in the two delta rows),
    ``mkdir`` crashing on re-run, a missing closing paren in the CSS
    ``linear-gradient``, a possible division by zero on an empty prompt,
    and missing ``encoding=`` on two ``write_text`` calls.
    """
    # exist_ok=True so re-running the script doesn't crash on the existing dir.
    OUT.mkdir(parents=True, exist_ok=True)
    backend = HuggingFacePipelineBackend(MODEL, max_new_tokens=MAX_NEW)

    # Memory vault seeded with facts the model can use
    vault = ObsidianMemoryVault(OUT / 'memory-vault')
    vault.add_note('sky-scattering', 'Short wavelengths scatter more than long ones. Blue/violet scatter first, leaving red/orange at sunset.', tags=['science','physics'])
    vault.add_note('sunset-color', 'At sunset light travels longer through atmosphere, scattering blue away, leaving red/orange hues.', tags=['science','sunset'])
    vault.add_note('day-sky', 'During day we see scattered blue light overhead. Short waves scatter in all directions.', tags=['science','sky'])

    compressor = CompressionEngine(vault=vault)

    # Both get SAME question+task. After gets relevant memory injected.
    TASK = 'Answer: why is the sky blue during the day and red at sunset? Use 3 bullet points. Plain English.'
    NO_CONTEXT = TASK
    CONTEXT = compressor.compress_transcript(
        task=TASK,
        transcript=[
            'The sky changes colour because of how sunlight interacts with the atmosphere.',
            'Shorter wavelengths scatter first, making the day sky blue.',
            'At sunset light passes through more atmosphere, scattering blue away, leaving red.',
        ],
        memory_query='sky blue red sunset scattering atmosphere',
        target_tokens=60,
    )
    WITH_CONTEXT = f"{TASK}\n\nCONTEXT:\n{CONTEXT.to_prompt()}"

    no_ctx_out, no_ctx_ms = run_prompt(backend, NO_CONTEXT)
    with_ctx_out, with_ctx_ms = run_prompt(backend, WITH_CONTEXT)

    no_ctx_toks = estimate_token_count(NO_CONTEXT)
    ctx_toks = estimate_token_count(WITH_CONTEXT)
    peak = rss_mb()

    # Token delta; guard the (theoretical) zero-token baseline to avoid
    # a ZeroDivisionError. Positive == the after-prompt is smaller.
    tok_delta_pct = 100.0 * (1 - ctx_toks / no_ctx_toks) if no_ctx_toks else 0.0
    ms_delta = with_ctx_ms - no_ctx_ms

    transcript = textwrap.dedent(f"""\
        == LLM Foundry local model proof ==
        model={MODEL}
        max_new_tokens={MAX_NEW}
        before_tokens={no_ctx_toks}  after_tokens={ctx_toks}
        before_latency_ms={no_ctx_ms:.1f}  after_latency_ms={with_ctx_ms:.1f}

        TASK (identical for both):
        {TASK}

        BEFORE (task only, no memory):
        {NO_CONTEXT}

        BEFORE OUTPUT:
        {no_ctx_out}

        AFTER (task + compressed memory context):
        {WITH_CONTEXT}

        AFTER OUTPUT:
        {with_ctx_out}

        DELTA:
        token_delta_pct={tok_delta_pct:.1f}%  (saved on context)
        latency_delta_ms={ms_delta:.1f}  (higher because context is larger)
        peak_rss_mb={peak:.0f}
    """)

    def section(title, pre_text, caption=''):
        # One titled <pre> panel; all model/user text is HTML-escaped via esc().
        cap = f'<p class="cap">{esc(caption)}</p>' if caption else ''
        return (
            f'<section class="panel">'
            f'<h3>{esc(title)}</h3>{cap}'
            f'<pre>{esc(pre_text or "(no output)")}</pre>'
            f'</section>'
        )

    def metric(val, label, note=''):
        # One big-number metric card with optional footnote.
        note_html = f'<div class="note">{esc(note)}</div>' if note else ''
        return (
            f'<div class="metric">'
            f'<div class="val">{esc(val)}</div>'
            f'<div class="lbl">{esc(label)}</div>'
            f'{note_html}'
            f'</div>'
        )

    def row(a_lbl, a_val, b_lbl, b_val, diff, diff_note):
        # Comparison table row; diff > 0 renders green/up, otherwise red/down.
        cls = 'pos' if diff > 0 else 'neg'
        arrow = '&#8593;' if diff > 0 else '&#8595;'
        return (
            f'<tr>'
            f'<td>{esc(a_lbl)}</td><td class="val">{esc(a_val)}</td>'
            f'<td>{esc(b_lbl)}</td><td class="val">{esc(b_val)}</td>'
            f'<td class="{cls}">{arrow} {diff_note}</td>'
            f'</tr>'
        )

    html = f"""<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8"/>
<meta name="viewport" content="width=device-width,initial-scale=1"/>
<title>LLM Foundry local model proof</title>
<style>
:root{{color-scheme:dark;}}
body{{margin:0;background:#060c18;color:#d8e4ff;font-family:Inter,system-ui,sans-serif;}}
.wrap{{max-width:1380px;margin:0 auto;padding:24px;}}
.hero,section{{background:linear-gradient(135deg,rgba(255,255,255,.05),rgba(255,255,255,.02));border:1px solid rgba(124,156,255,.1);border-radius:20px;padding:22px;margin:14px 0;}}
h1{{font-size:clamp(1.8rem,4.5vw,3.2rem);line-height:1.05;margin:0 0 8px;}}
.sub{{color:#7a90c4;margin:0 0 14px;font-size:1.05rem;}}
.chips{{display:flex;flex-wrap:wrap;gap:8px;}}
.chip{{padding:5px 12px;border-radius:999px;background:rgba(124,156,255,.12);border:1px solid rgba(124,156,255,.2);font-size:.88rem;}}
.mgrid{{display:grid;grid-template-columns:repeat(3,1fr);gap:14px;margin:14px 0;}}
.metric{{background:#0c1830;border:1px solid rgba(255,255,255,.07);border-radius:16px;padding:16px 14px;}}
.val{{font-size:2.1rem;font-weight:800;color:#86efac;line-height:1;}}
.lbl{{color:#8a9bca;margin-top:6px;font-size:.92rem;}}
.note{{color:#5a6b94;font-size:.82rem;margin-top:4px;}}
.compare{{display:grid;grid-template-columns:1fr 1fr;gap:18px;margin:14px 0;}}
.col{{display:flex;flex-direction:column;gap:14px;}}
.panel{{background:#0b1628;border:1px solid rgba(255,255,255,.08);border-radius:16px;padding:16px 18px;}}
.panel h3{{margin:0 0 10px;font-size:.88rem;text-transform:uppercase;letter-spacing:.1em;color:#86efac;}}
.cap{{color:#6070a0;font-size:.85rem;margin:4px 0 0;}}
pre{{background:#08101e;border:1px solid rgba(255,255,255,.06);border-radius:12px;padding:14px;margin:0;white-space:pre-wrap;word-break:break-word;font-size:.94rem;line-height:1.6;max-height:260px;overflow:auto;}}
.table{{background:#0b1628;border:1px solid rgba(255,255,255,.08);border-radius:16px;padding:18px 20px;margin:14px 0;}}
.table h3{{margin:0 0 12px;font-size:.88rem;text-transform:uppercase;letter-spacing:.1em;color:#86efac;}}
table{{width:100%;border-collapse:collapse;}}
th{{text-align:left;padding:8px 10px;border-bottom:1px solid rgba(255,255,255,.07);font-size:.78rem;text-transform:uppercase;letter-spacing:.1em;color:#86efac;}}
td{{padding:9px 10px;border-bottom:1px solid rgba(255,255,255,.04);font-size:.95rem;}}
.pos{{color:#86efac;}} .neg{{color:#f38ba8;}}
.footer a{{color:#8ab4ff;}}
@media(max-width:700px){{.compare,.mgrid{{grid-template-columns:1fr;}}}}
</style>
</head>
<body>
<div class="wrap">
  <div class="hero">
    <h1>Same local model. Same task. Real before vs after memory.</h1>
    <p class="sub">Qwen2.5-0.5B-Instruct through LLM Foundry's memory layer. No API calls. No cloud. 100% local.</p>
    <div class="chips">
      <span class="chip">model={MODEL}</span>
      <span class="chip">HuggingFacePipelineBackend</span>
      <span class="chip">Local memory vault + embeddings</span>
      <span class="chip">0 cloud APIs</span>
    </div>
  </div>

  <div class="mgrid">
    {metric(f'{no_ctx_ms:.0f} ms', 'Before latency', 'raw task, no memory')}
    {metric(f'{with_ctx_ms:.0f} ms', 'After latency', '+ memory context')}
    {metric(f'{peak:.0f} MB', 'Peak RAM', 'both runs')}
  </div>

  <div class="compare">
    <div class="col">
      {section('BEFORE — task only', NO_CONTEXT, 'no memory, no compression')}
      {section('BEFORE output', no_ctx_out or '(no output)')}
    </div>
    <div class="col">
      {section('AFTER — task + memory context', WITH_CONTEXT, '+ compressed memory context from vault')}
      {section('AFTER output', with_ctx_out or '(no output)')}
    </div>
  </div>

  <div class="table">
    <h3>Delta</h3>
    <table>
      <thead><tr><th>Metric</th><th>Before</th><th>After</th><th>Change</th></tr></thead>
      <tbody>
        {row('Latency', f'{no_ctx_ms:.0f} ms', 'Latency', f'{with_ctx_ms:.0f} ms', ms_delta, f'{"faster" if ms_delta < 0 else "slower"} by {abs(ms_delta):.0f} ms')}
        {row('Context tokens', f'{no_ctx_toks}', 'Context tokens', f'{ctx_toks}', -tok_delta_pct, f'{"saved" if tok_delta_pct > 0 else "added"} {abs(tok_delta_pct):.0f}% via compression')}
        {row('Peak RAM', f'{peak:.0f} MB', '—', '—', 0, '')}
      </tbody>
    </table>
  </div>

  <div class="footer" style="background:#0b1628;border:1px solid rgba(124,156,255,.1);border-radius:16px;padding:18px 20px;margin:14px 0;">
    <p style="margin:0 0 8px;color:#7a90c4;font-size:.92rem;">
      model={MODEL} &middot; backend=HuggingFacePipelineBackend &middot; max_new_tokens={MAX_NEW} &middot; 100% local
    </p>
    <p style="margin:0;color:#8ab4ff;font-size:.9rem;">
      <a href="https://github.com/AmSach/llm-foundry">GitHub: AmSach/llm-foundry</a> &middot;
      <a href="https://www.instagram.com/i.amsach">Instagram @i.amsach</a> &middot;
      <a href="https://www.linkedin.com/in/theamansachan">LinkedIn TheAmanSachan</a>
    </p>
  </div>
</div>
</body>
</html>"""

    # Write all three artifacts with explicit UTF-8 (was inconsistent before).
    (OUT / 'report.html').write_text(html, encoding='utf-8')
    (OUT / 'terminal_transcript.txt').write_text(transcript, encoding='utf-8')
    (OUT / 'comparison.json').write_text(json.dumps({
        'model': MODEL,
        'before': {'prompt': NO_CONTEXT, 'output': no_ctx_out, 'latency_ms': no_ctx_ms, 'tokens': no_ctx_toks},
        'after': {'prompt': WITH_CONTEXT, 'output': with_ctx_out, 'latency_ms': with_ctx_ms, 'tokens': ctx_toks},
        'tok_delta_pct': tok_delta_pct,
        'latency_delta_ms': ms_delta,
        'peak_rss_mb': peak,
    }, indent=2), encoding='utf-8')
    print(f"DONE tok_delta_pct={tok_delta_pct:.1f} ms_delta={ms_delta:.1f}")
    print(f"BEFORE:\n{no_ctx_out}\n")
    print(f"AFTER:\n{with_ctx_out}\n")

# Script entry point — build the full before/after report when run directly.
if __name__ == '__main__':
    make_report()
