#!/usr/bin/env python3
"""
Batch process files or directories to generate embeddings.

Usage:
    python3 batch_embed.py --input /path/to/articles --output articles.db --model embeddinggemma
    python3 batch_embed.py --input /path/to/docs --output docs.db --model embeddinggemma --pattern "*.txt"
"""

import sqlite3
import subprocess
import json
import argparse
import os
from pathlib import Path
from typing import List


def get_embedding(text: str, model: str = "embeddinggemma") -> List[float]:
    """Return the embedding vector for *text* from a local Ollama model.

    Runs ``ollama run <model>`` with *text* on stdin and parses the JSON
    array the CLI prints on stdout.

    Args:
        text: The text to embed.
        model: Ollama model name (default: "embeddinggemma").

    Returns:
        The embedding as a list of floats.

    Raises:
        RuntimeError: If the CLI exits non-zero, prints non-JSON output,
            or the parsed value is not a list.
        FileNotFoundError: If the ``ollama`` executable is not on PATH.
    """
    result = subprocess.run(
        ["ollama", "run", model],
        capture_output=True,
        text=True,
        input=text,
    )

    if result.returncode != 0:
        # RuntimeError is a subclass of Exception, so existing callers
        # catching Exception still work.
        raise RuntimeError(f"Ollama error: {result.stderr}")

    # Parse defensively: a malformed response should produce a clear error
    # rather than a bare JSONDecodeError with no context.
    try:
        embedding = json.loads(result.stdout.strip())
    except json.JSONDecodeError as exc:
        raise RuntimeError(
            f"Ollama returned non-JSON output: {result.stdout[:200]!r}"
        ) from exc

    if not isinstance(embedding, list):
        raise RuntimeError(
            f"Expected a JSON list from Ollama, got {type(embedding).__name__}"
        )

    return embedding


def create_database(db_path: str, table_name: str = "texts") -> sqlite3.Connection:
    """Create (or open) a SQLite database with an embeddings table.

    Args:
        db_path: Path to the database file (":memory:" also works).
        table_name: Name of the table to create if it does not exist.
            Must be a plain Python identifier.

    Returns:
        An open sqlite3.Connection with the table committed.

    Raises:
        ValueError: If *table_name* is not a valid identifier. SQLite
            cannot bind identifiers with ``?`` placeholders, so this guard
            protects the f-string interpolation into the DDL below.
    """
    if not table_name.isidentifier():
        raise ValueError(f"Invalid table name: {table_name!r}")

    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    cursor.execute(f"""
        CREATE TABLE IF NOT EXISTS {table_name} (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            text TEXT NOT NULL,
            source TEXT,
            embedding BLOB,
            model TEXT,
            embedding_dim INTEGER,
            created_at TEXT DEFAULT (datetime('now'))
        )
    """)

    conn.commit()
    return conn


def batch_embed(
    input_path: str,
    output_db: str,
    model: str = "embeddinggemma",
    pattern: str = "*.md",
    table_name: str = "texts",
    recursive: bool = False
) -> None:
    """
    Batch process files to generate embeddings and store them in SQLite.

    Each matched file is read as UTF-8, embedded via Ollama, and inserted
    as one row. Failures on individual files are reported and skipped
    (best-effort batch); the final summary reports how many succeeded.

    Args:
        input_path: Path to a single file or a directory of files.
        output_db: Path to the output SQLite database.
        model: Ollama model to use.
        pattern: Glob pattern to match when input_path is a directory.
        table_name: Table to store rows in (must be a plain identifier).
        recursive: Whether to search the directory recursively.

    Raises:
        FileNotFoundError: If input_path does not exist.
        ValueError: If table_name is not a valid identifier.
    """
    root = Path(input_path)
    if not root.exists():
        raise FileNotFoundError(f"Input path does not exist: {input_path}")
    # Identifiers cannot be bound with ? placeholders, so guard the
    # f-string interpolation into the INSERT below.
    if not table_name.isidentifier():
        raise ValueError(f"Invalid table name: {table_name!r}")

    # Collect files to embed; sorted for deterministic processing order.
    if root.is_file():
        files = [root]
    elif recursive:
        files = sorted(root.rglob(pattern))
    else:
        files = sorted(root.glob(pattern))

    print(f"Found {len(files)} files matching '{pattern}'")

    conn = create_database(output_db, table_name)
    cursor = conn.cursor()
    succeeded = 0

    try:
        for i, file_path in enumerate(files, 1):
            print(f"\n[{i}/{len(files)}] Processing: {file_path.name}")

            try:
                text = file_path.read_text(encoding='utf-8')

                if not text.strip():
                    print("  Skipping empty file")
                    continue

                embedding = get_embedding(text, model)
                print(f"  Embedded (dimension: {len(embedding)})")

                # The embedding is stored as its JSON text; SQLite accepts
                # text in the BLOB-declared column.
                cursor.execute(f"""
                    INSERT INTO {table_name} (text, source, embedding, model, embedding_dim)
                    VALUES (?, ?, ?, ?, ?)
                """, (text, str(file_path), json.dumps(embedding), model, len(embedding)))

                # Commit per file so completed work survives a crash.
                conn.commit()
                succeeded += 1

            except Exception as e:
                # Deliberate best-effort: report the failure and keep going.
                print(f"  Error: {e}")
    finally:
        # Close even if an unexpected error escapes the loop.
        conn.close()

    print(f"\nDone! Embedded {succeeded}/{len(files)} files")
    print(f"Database saved to: {output_db}")


def main():
    """Command-line entry point: parse arguments and run the batch job."""
    parser = argparse.ArgumentParser(
        description="Batch process files to generate embeddings"
    )

    # Table-driven option registration keeps the flag definitions compact.
    cli_options = [
        ("--input", {"required": True, "help": "Input file or directory"}),
        ("--output", {"required": True, "help": "Output database path"}),
        ("--model", {"default": "embeddinggemma", "help": "Ollama model"}),
        ("--pattern", {"default": "*.md", "help": "File pattern (default: *.md)"}),
        ("--table", {"default": "texts", "help": "Table name (default: texts)"}),
        ("--recursive", {"action": "store_true", "help": "Search recursively"}),
    ]
    for flag, kwargs in cli_options:
        parser.add_argument(flag, **kwargs)

    args = parser.parse_args()

    batch_embed(
        input_path=args.input,
        output_db=args.output,
        model=args.model,
        pattern=args.pattern,
        table_name=args.table,
        recursive=args.recursive,
    )


# Allow the file to be used both as a script and as an importable module.
if __name__ == "__main__":
    main()