#!/usr/bin/env python3
"""
High-throughput batch embedding using sentence-transformers or Ollama API.

Usage:
    # Using sentence-transformers (recommended, ~800 texts/sec)
    python3 fast_embed.py --input /path/to/docs --output embeddings.db --backend st

    # Using Ollama API (concurrent, ~10-20 texts/sec)
    python3 fast_embed.py --input /path/to/docs --output embeddings.db --backend ollama

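    # Specify model and batch size
    python3 fast_embed.py --input docs --output out.db --backend st --model all-mpnet-base-v2 --batch-size 64

    # Only files matching --pattern (default "*.md") are read; add --recursive to include subdirectories
    python3 fast_embed.py --input docs --output out.db --pattern "*.txt" --recursive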
"""
import argparse
import asyncio
import json
import sqlite3
import time
from pathlib import Path

# Default models, used by main() when --model is not given.
ST_MODEL = "all-MiniLM-L6-v2"  # sentence-transformers model (--backend st)
OLLAMA_MODEL = "embeddinggemma"  # Ollama model (--backend ollama)
# Ollama embeddings endpoint: embed_with_ollama() POSTs {"model", "prompt"} here
# and reads the "embedding" field of the JSON response.
OLLAMA_URL = "http://localhost:11434/api/embeddings"


def embed_with_sentence_transformers(texts: list[str], model_name: str, batch_size: int) -> list[list[float]]:
    """Encode *texts* locally with a sentence-transformers model.

    The library is imported lazily so the Ollama backend works without it.
    The model runs on whatever device it selects; that device is printed.

    Args:
        texts: Documents to embed, one vector per entry.
        model_name: Name passed to ``SentenceTransformer``.
        batch_size: Number of texts encoded per forward pass.

    Returns:
        One plain-Python list of floats per input text, in input order.
    """
    from sentence_transformers import SentenceTransformer

    encoder = SentenceTransformer(model_name)
    print(f"Model: {model_name}, Device: {encoder.device}")

    vectors = encoder.encode(
        texts,
        batch_size=batch_size,
        show_progress_bar=True,
    )
    # Convert each row to a plain list so it can be JSON-serialized later.
    return [row.tolist() for row in vectors]


async def embed_with_ollama(texts: list[str], model_name: str, concurrency: int = 16, max_chars: int = 2000) -> list[list[float]]:
    """Embed *texts* through the Ollama HTTP API with bounded concurrency.

    Each text is truncated to *max_chars* characters and POSTed to
    ``OLLAMA_URL`` as ``{"model": model_name, "prompt": ...}``. At most
    *concurrency* requests are in flight at once.

    Args:
        texts: Documents to embed.
        model_name: Ollama model name.
        concurrency: Maximum number of simultaneous requests.
        max_chars: Characters of each text sent to the server.

    Returns:
        One entry per input text, in input order. An entry is ``None`` when
        its request failed: connection error, timeout, non-200 status, a body
        that is not JSON, or a response without an ``"embedding"`` field.
        Each failure is printed rather than raised, so one bad text does not
        abort the whole batch.
    """
    import aiohttp

    async def get_one(session, index, text, semaphore):
        payload = {"model": model_name, "prompt": text[:max_chars]}
        async with semaphore:
            try:
                async with session.post(OLLAMA_URL, json=payload) as resp:
                    if resp.status != 200:
                        detail = (await resp.text())[:200]
                        print(f"Text #{index}: Ollama returned HTTP {resp.status}: {detail}")
                        return None
                    result = await resp.json()
            # ClientError covers connection problems and non-JSON content types;
            # ValueError covers malformed JSON and undecodable bodies.
            except (aiohttp.ClientError, asyncio.TimeoutError, ValueError) as e:
                print(f"Text #{index}: request failed: {e!r}")
                return None
        embedding = result.get("embedding") if isinstance(result, dict) else None
        if embedding is None:
            print(f"Text #{index}: response has no 'embedding' field")
        return embedding

    semaphore = asyncio.Semaphore(concurrency)
    async with aiohttp.ClientSession() as session:
        tasks = [get_one(session, i, t, semaphore) for i, t in enumerate(texts)]
        return await asyncio.gather(*tasks)


def load_texts(input_path: Path, pattern: str = "*.md", recursive: bool = False) -> list[tuple[str, str]]:
    """Load UTF-8 text files. Returns ``[(name, content), ...]``.

    If *input_path* is a file, it is read directly and returned as a single
    entry named after the file. Otherwise, files under the directory that
    match the glob *pattern* are read in sorted path order. Subdirectories are
    searched only when *recursive* is true, and directories whose names
    happen to match the pattern are ignored.

    For directory input, ``name`` is the file's path relative to *input_path*,
    joined with ``/``. For a non-recursive search that is just the file name.
    Using the relative path keeps names unique across subdirectories and lets
    callers rebuild the full path as ``input_path / name``.

    Files that cannot be read or are not valid UTF-8 are reported and skipped.
    """
    if input_path.is_file():
        return [(input_path.name, input_path.read_text(encoding="utf-8"))]

    glob_fn = input_path.rglob if recursive else input_path.glob

    results = []
    for f in sorted(glob_fn(pattern)):
        # Glob patterns also match directories; only regular files are loaded.
        if not f.is_file():
            continue
        try:
            content = f.read_text(encoding="utf-8")
        except (OSError, UnicodeDecodeError) as e:
            print(f"Skipping {f}: {e}")
            continue
        results.append((f.relative_to(input_path).as_posix(), content))

    return results


def save_to_db(db_path: Path, data: list[tuple[str, str, list[float]]], model: str, table: str = "embeddings"):
    """Save embeddings to a SQLite table, replacing any rows already in it.

    The table is created if missing and then emptied, so after a successful
    call it holds exactly the rows in *data*.

    Args:
        db_path: SQLite database file; sqlite3 creates it if it does not exist.
        data: ``(name, filepath, embedding)`` tuples. ``embedding`` is stored
            as JSON text. A falsy embedding (``None`` or empty) is stored with
            ``embedding_dim`` 0.
        model: Model name recorded on every row.
        table: Table name. It is quoted as an SQL identifier, so any string
            (including one taken from the command line) is safe.
    """
    # Identifiers cannot be bound as query parameters, so quote the name and
    # escape embedded double quotes instead of interpolating it raw.
    quoted = '"' + table.replace('"', '""') + '"'
    rows = [
        (name, filepath, json.dumps(embedding), model, len(embedding) if embedding else 0)
        for name, filepath, embedding in data
    ]

    conn = sqlite3.connect(db_path)
    try:
        # The connection context manager commits on success and rolls back on
        # error; the outer try/finally guarantees the connection is closed.
        with conn:
            conn.execute(f"""
                CREATE TABLE IF NOT EXISTS {quoted} (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    name TEXT NOT NULL,
                    filepath TEXT,
                    embedding TEXT,
                    model TEXT,
                    embedding_dim INTEGER,
                    created_at TEXT DEFAULT (datetime('now'))
                )
            """)
            conn.execute(f"DELETE FROM {quoted}")
            conn.executemany(
                f"INSERT INTO {quoted} (name, filepath, embedding, model, embedding_dim) VALUES (?, ?, ?, ?, ?)",
                rows,
            )
    finally:
        conn.close()


def main():
    """Command-line entry point: load files, embed them, and save to SQLite.

    Parses the CLI arguments and loads matching files with ``load_texts``.
    It then embeds their contents with the selected backend, prints
    throughput, and writes one row per file to the output database via
    ``save_to_db``. Files whose embedding came back empty are counted and
    reported before saving.
    """
    parser = argparse.ArgumentParser(description="High-throughput batch embedding")
    parser.add_argument("--input", "-i", required=True, help="Input file or directory")
    parser.add_argument("--output", "-o", default="embeddings.db", help="Output SQLite database")
    parser.add_argument("--backend", "-b", choices=["st", "ollama"], default="st",
                        help="Backend: st (sentence-transformers) or ollama")
    parser.add_argument("--model", "-m", help="Model name (default depends on backend)")
    parser.add_argument("--batch-size", type=int, default=64, help="Batch size for sentence-transformers")
    parser.add_argument("--concurrency", type=int, default=16, help="Concurrency for Ollama API")
    parser.add_argument("--pattern", default="*.md", help="File pattern to match")
    parser.add_argument("--recursive", "-r", action="store_true", help="Search recursively")
    parser.add_argument("--table", default="embeddings", help="Table name in database")
    args = parser.parse_args()

    input_path = Path(args.input)
    output_path = Path(args.output)

    # Load texts
    print(f"Loading texts from {input_path}...")
    text_data = load_texts(input_path, args.pattern, args.recursive)
    print(f"Found {len(text_data)} files")

    if not text_data:
        print("No files found!")
        return

    texts = [content for _, content in text_data]

    # Embed. perf_counter is monotonic and high-resolution, unlike time.time().
    model = args.model
    start = time.perf_counter()

    if args.backend == "st":
        model = model or ST_MODEL
        print(f"Embedding with sentence-transformers ({model})...")
        embeddings = embed_with_sentence_transformers(texts, model, args.batch_size)
    else:
        model = model or OLLAMA_MODEL
        print(f"Embedding with Ollama API ({model})...")
        embeddings = asyncio.run(embed_with_ollama(texts, model, args.concurrency))

    elapsed = time.perf_counter() - start
    # Guard against a zero-length interval on very fast runs / coarse clocks.
    rate = len(texts) / elapsed if elapsed > 0 else float("inf")
    print(f"Embedded {len(texts)} texts in {elapsed:.2f}s ({rate:.1f} texts/sec)")

    failed = sum(1 for emb in embeddings if not emb)
    if failed:
        print(f"Warning: {failed} of {len(texts)} texts have no embedding (saved with embedding_dim 0)")

    # Combine with metadata. For directory input, the file path is rebuilt by
    # joining the loaded name onto the input directory.
    is_dir = input_path.is_dir()
    results = [
        (name, str(input_path / name) if is_dir else str(input_path), emb)
        for (name, _), emb in zip(text_data, embeddings)
    ]

    # Save
    save_to_db(output_path, results, model, args.table)
    print(f"Saved to {output_path}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
