from __future__ import annotations
import json, resource, sys, time, textwrap as tw
from pathlib import Path

# Hard-coded absolute locations; this script only works on a machine where
# these paths exist (NOTE(review): presumably a specific workspace — confirm).
ROOT = Path('/home/.z/workspaces/con_2gAuSTkawPiOse8J/llm-foundry')
SRC = ROOT / 'src'
# Directory that receives the memory vault, transcript and JSON comparison.
OUT = Path('/home/workspace/Deliverables/qwen-sky-proof')
# Model identifier handed to HuggingFacePipelineBackend in make().
MODEL = 'Qwen/Qwen2.5-0.5B-Instruct'
# Passed to the backend as max_new_tokens.
MAX_NEW = 64
# Put SRC first on the import path so the llm_foundry imports below are
# looked up there before any other location.
sys.path.insert(0, str(SRC))

from llm_foundry.adapters import HuggingFacePipelineBackend
from llm_foundry.memory import CompressionEngine, ObsidianMemoryVault
from llm_foundry.tokenizer import estimate_token_count


def ms_timer():
    """Return the current reading of the high-resolution performance counter.

    Despite the name, the value is in fractional *seconds* (it is the raw
    result of ``time.perf_counter()``); callers convert differences to
    milliseconds themselves.
    """
    now = time.perf_counter()
    return now


def rss_mb():
    """Return this process's peak resident set size in megabytes.

    ``ru_maxrss`` is reported in kilobytes on Linux but in bytes on macOS,
    so the divisor is chosen per platform to always yield megabytes.
    """
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # macOS reports bytes; Linux (and most other Unixes) report kilobytes.
    divisor = 1024.0 * 1024.0 if sys.platform == 'darwin' else 1024.0
    return peak / divisor


def run_prompt(backend, prompt):
    """Run ``backend.generate(prompt)`` and measure how long it took.

    Returns a ``(output, elapsed_ms)`` tuple, where ``output`` is whatever
    ``backend.generate`` returned and ``elapsed_ms`` is the wall-clock
    duration of that call in milliseconds.
    """
    started = ms_timer()
    result = backend.generate(prompt)
    # ms_timer() yields seconds, so scale the difference to milliseconds.
    elapsed_ms = (ms_timer() - started) * 1000.0
    return result, elapsed_ms


def esc(s):
    """Escape ``&``, ``<``, ``>`` and ``"`` in *s* as HTML entities.

    Any falsy input (``None``, empty string, ...) yields an empty string.
    """
    text = s if s else ''
    # '&' must be handled first so the ampersands introduced by the later
    # entities are not escaped a second time.
    substitutions = (
        ('&', '&amp;'),
        ('<', '&lt;'),
        ('>', '&gt;'),
        ('"', '&quot;'),
    )
    for raw, entity in substitutions:
        text = text.replace(raw, entity)
    return text


def make():
    """Run the before/after prompt comparison and write the result files.

    Steps:
      1. Create ``OUT`` and build a ``HuggingFacePipelineBackend`` for
         ``MODEL`` with ``max_new_tokens=MAX_NEW``.
      2. Seed an ``ObsidianMemoryVault`` under ``OUT / 'memory-vault'`` with
         three notes and compress a short transcript against it.
      3. Generate once with the bare task prompt and once with the task plus
         the compressed context, timing each call.
      4. Write ``terminal_transcript.txt`` and ``comparison.json`` into
         ``OUT`` (both as UTF-8) and print a short summary to stdout.

    Returns ``None``.
    """
    OUT.mkdir(parents=True, exist_ok=True)
    backend = HuggingFacePipelineBackend(MODEL, max_new_tokens=MAX_NEW)

    vault = ObsidianMemoryVault(OUT / 'memory-vault')
    vault.add_note('scattering', 'Short wavelengths scatter more. Blue/violet scatter first in air. Longer red/orange travel straight. Sky blue = scattered short waves overhead. Sunset red = red travel through atmosphere, blue scattered away.', tags=['physics'])
    vault.add_note('atmosphere', 'Atmosphere thickness changes colour. Light scatters differently at sunrise vs midday. Rayleigh scattering: blue light scatters in all directions in air molecules.', tags=['physics'])
    vault.add_note('sunset', 'At sunset light passes through more atmosphere. Blue is scattered out. Red/orange reach eyes directly. Sunrise same effect in reverse.', tags=['physics'])
    compressor = CompressionEngine(vault=vault)

    task_prompt = 'Answer in 3 bullet points: why is the sky blue during the day and red at sunset? Keep each bullet under 15 words. Plain English. No preamble.'

    # "Before" prompt: the task alone. "After" prompt: task + compressed context.
    no_ctx_prompt = task_prompt
    ctx = compressor.compress_transcript(
        task='why sky blue daytime red sunset',
        transcript=[
            'The sky is blue because short wavelengths scatter in all directions overhead.',
            'At sunset light travels through more atmosphere, scattering blue away, leaving red.',
            'Rayleigh scattering explains both. Simple English for kids.',
        ],
        memory_query='sky blue red sunset scattering atmosphere rayleigh',
        target_tokens=60,
    )
    with_ctx_prompt = task_prompt + '\n\nCONTEXT:\n' + ctx.to_prompt()

    # NOTE(review): the no-context run goes first, so any one-time warm-up
    # cost in the backend would be attributed to it — confirm whether a
    # warm-up call is wanted before timing.
    no_ctx_out, no_ctx_ms = run_prompt(backend, no_ctx_prompt)
    ctx_out, ctx_ms = run_prompt(backend, with_ctx_prompt)

    no_ctx_toks = estimate_token_count(no_ctx_prompt)
    ctx_toks = estimate_token_count(with_ctx_prompt)
    peak = rss_mb()
    # max(1, ...) guards against division by zero if the estimate is 0.
    # NOTE(review): with_ctx_prompt is the task plus extra context, so this
    # "saved" percentage is presumably negative — confirm the intended metric.
    tok_saved_pct = 100.0 * (1.0 - float(ctx_toks) / max(1, no_ctx_toks))
    ms_delta = ctx_ms - no_ctx_ms

    transcript = ''.join([
        '== LLM Foundry local model proof ==\n',
        'model=%s\n' % MODEL,
        'max_new_tokens=%d\n' % MAX_NEW,
        'before_tokens=%d  after_tokens=%d\n' % (no_ctx_toks, ctx_toks),
        'before_ms=%.1f  after_ms=%.1f\n\n' % (no_ctx_ms, ctx_ms),
        'BEFORE (task only, no memory):\n%s\n\n' % no_ctx_prompt,
        'BEFORE OUTPUT:\n%s\n\n' % no_ctx_out,
        'AFTER (task + memory context):\n%s\n\n' % with_ctx_prompt,
        'AFTER OUTPUT:\n%s\n\n' % ctx_out,
        'tok_saved=%.1f%%  ms_delta=%.1fms  peak_rss=%.0fMB\n' % (tok_saved_pct, ms_delta, peak),
    ])

    # Model output can contain non-ASCII text; write explicitly as UTF-8 so
    # the result does not depend on (or fail under) the locale's encoding.
    (OUT / 'terminal_transcript.txt').write_text(transcript, encoding='utf-8')
    comparison = {
        'model': MODEL,
        'before_prompt': no_ctx_prompt,
        'before_output': no_ctx_out,
        'before_ms': no_ctx_ms,
        'before_tokens': no_ctx_toks,
        'after_prompt': with_ctx_prompt,
        'after_output': ctx_out,
        'after_ms': ctx_ms,
        'after_tokens': ctx_toks,
        'tok_saved_pct': tok_saved_pct,
        'ms_delta': ms_delta,
        'peak_rss_mb': peak,
    }
    (OUT / 'comparison.json').write_text(
        json.dumps(comparison, indent=2), encoding='utf-8'
    )

    # Console summary; outputs are truncated to 300 characters.
    print('tok_saved_pct=%.1f' % tok_saved_pct)
    print('no_ctx_ms=%.1f  ctx_ms=%.1f' % (no_ctx_ms, ctx_ms))
    print('BEFORE: ' + no_ctx_out[:300])
    print('AFTER: ' + ctx_out[:300])


# Script entry point. There is no `if __name__ == '__main__'` guard, so the
# whole comparison also runs if this file is imported as a module.
make()
