#!/usr/bin/env python3
"""
Generate text embeddings using Ollama and store them in SQLite.

Usage:
    python3 embed_text.py --text "Your text here" --model embeddinggemma
    python3 embed_text.py --file texts.txt --model embeddinggemma
    python3 embed_text.py --dir /path/to/docs --model embeddinggemma --pattern "*.md"
"""

import argparse
import json
import os
import sqlite3
import subprocess
import urllib.error
import urllib.request
from pathlib import Path
from typing import List, Tuple


def get_embedding(text: str, model: str = "embeddinggemma") -> List[float]:
    """
    Get embedding from Ollama for a given text.

    Uses the Ollama HTTP API (``POST /api/embeddings``) on the default local
    port. The previous implementation piped the text into ``ollama run``,
    which performs text *generation* and never emits an embedding vector
    (embedding-only models reject ``run`` outright), so ``json.loads`` on its
    stdout could not succeed.

    Args:
        text: The text to embed
        model: The Ollama model to use (default: embeddinggemma)

    Returns:
        List of float values representing the embedding

    Raises:
        Exception: If the Ollama server is unreachable or returns a
            response without an ``embedding`` list.
    """
    payload = json.dumps({"model": model, "prompt": text}).encode("utf-8")
    request = urllib.request.Request(
        "http://localhost:11434/api/embeddings",
        data=payload,
        headers={"Content-Type": "application/json"},
    )

    try:
        with urllib.request.urlopen(request) as response:
            body = json.loads(response.read().decode("utf-8"))
    except urllib.error.URLError as err:
        # Keep the original "Ollama error: ..." message shape for callers
        # that log it.
        raise Exception(f"Ollama error: {err}") from err

    embedding = body.get("embedding")
    if not isinstance(embedding, list):
        raise Exception(f"Ollama error: unexpected response {body!r}")
    return embedding


def create_database(db_path: str, table_name: str = "texts") -> sqlite3.Connection:
    """
    Create SQLite database with embeddings table.

    Args:
        db_path: Path to the database file
        table_name: Name of the table to create

    Returns:
        SQLite connection object

    Raises:
        ValueError: If ``table_name`` is not a plain Python identifier.
            The name is interpolated directly into the DDL (SQLite cannot
            parameterize identifiers), so anything else would permit SQL
            injection via the table name.
    """
    if not table_name.isidentifier():
        raise ValueError(f"Invalid table name: {table_name!r}")

    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    # embedding is stored as a JSON string in the BLOB column (see
    # store_embedding); embedding_dim is denormalized for quick filtering.
    cursor.execute(f"""
        CREATE TABLE IF NOT EXISTS {table_name} (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            text TEXT NOT NULL,
            source TEXT,
            embedding BLOB,
            model TEXT,
            embedding_dim INTEGER,
            created_at TEXT DEFAULT (datetime('now'))
        )
    """)

    conn.commit()
    return conn


def store_embedding(
    conn: sqlite3.Connection,
    text: str,
    embedding: List[float],
    source: str = None,
    model: str = "embeddinggemma",
    table_name: str = "texts"
) -> None:
    """
    Persist one text/embedding pair into the given table.

    The vector is serialized to JSON text before insertion, and its
    dimensionality is recorded alongside it so rows can be filtered by
    dimension without deserializing.

    Args:
        conn: Open SQLite connection
        text: Original text content
        embedding: Embedding vector
        source: Source file or identifier
        model: Model used for embedding
        table_name: Table name to store in
    """
    serialized = json.dumps(embedding)
    row = (text, source, serialized, model, len(embedding))

    # Connection.execute opens an implicit cursor for the single statement.
    conn.execute(f"""
        INSERT INTO {table_name} (text, source, embedding, model, embedding_dim)
        VALUES (?, ?, ?, ?, ?)
    """, row)
    conn.commit()


def embed_text(text: str, db_path: str, model: str = "embeddinggemma", source: str = None) -> None:
    """
    Embed one piece of text and persist it in the database.

    Args:
        text: Text to embed
        db_path: Path to database
        model: Ollama model to use
        source: Source identifier
    """
    print(f"Embedding text: {text[:50]}...")

    vector = get_embedding(text, model)
    print(f"  Embedding dimension: {len(vector)}")

    # Open (creating the schema if needed), write one row, close.
    db = create_database(db_path)
    store_embedding(db, text, vector, source=source, model=model)
    db.close()

    print("  Stored in database")

def embed_file(file_path: str, db_path: str, model: str = "embeddinggemma") -> None:
    """
    Embed texts from a file, one text per line.

    Blank lines are skipped. A failure on one line is reported and the
    remaining lines are still processed.

    Args:
        file_path: Path to text file
        db_path: Path to database
        model: Ollama model to use
    """
    with open(file_path, 'r') as handle:
        lines = handle.readlines()
    total = len(lines)

    print(f"Found {total} lines in {file_path}")

    conn = create_database(db_path)

    for lineno, raw in enumerate(lines, 1):
        stripped = raw.strip()
        if not stripped:
            continue

        print(f"Processing line {lineno}/{total}: {stripped[:50]}...")

        try:
            vector = get_embedding(stripped, model)
            # Record file:line provenance so rows can be traced back.
            store_embedding(
                conn,
                stripped,
                vector,
                source=f"{file_path}:{lineno}",
                model=model
            )
        except Exception as err:
            print(f"  Error: {err}")

    conn.close()
    print(f"\nDone! Processed {total} lines")


def embed_directory(
    dir_path: str, 
    db_path: str, 
    model: str = "embeddinggemma", 
    pattern: str = "*.md"
) -> None:
    """
    Embed every file in a directory whose name matches *pattern*.

    Each matching file is embedded as a single document. A failure on one
    file is reported and the remaining files are still processed.

    Args:
        dir_path: Path to directory
        db_path: Path to database
        model: Ollama model to use
        pattern: File pattern to match (e.g., "*.md", "*.txt")
    """
    directory = Path(dir_path)
    matches = list(directory.glob(pattern))
    total = len(matches)

    print(f"Found {total} files matching '{pattern}' in {directory}")

    conn = create_database(db_path)

    for index, path in enumerate(matches, 1):
        print(f"\nProcessing file {index}/{total}: {path.name}")

        try:
            with open(path, 'r') as handle:
                contents = handle.read()

            vector = get_embedding(contents, model)
            # The full path serves as the row's provenance.
            store_embedding(
                conn,
                contents,
                vector,
                source=str(path),
                model=model
            )
            print(f"  Embedded (dimension: {len(vector)})")

        except Exception as err:
            print(f"  Error: {err}")

    conn.close()
    print(f"\nDone! Processed {total} files")


def main():
    """Parse CLI arguments and dispatch to the matching embed_* helper."""
    parser = argparse.ArgumentParser(description="Generate text embeddings using Ollama")
    parser.add_argument("--text", help="Text to embed")
    parser.add_argument("--file", help="File with texts (one per line)")
    parser.add_argument("--dir", help="Directory with files to embed")
    parser.add_argument("--pattern", default="*.md", help="File pattern for directory (default: *.md)")
    parser.add_argument("--db", default="embeddings.db", help="Database path (default: embeddings.db)")
    parser.add_argument("--model", default="embeddinggemma", help="Ollama model (default: embeddinggemma)")

    args = parser.parse_args()

    # First matching input mode wins: --text, then --file, then --dir.
    if args.text:
        embed_text(args.text, args.db, args.model)
        return
    if args.file:
        embed_file(args.file, args.db, args.model)
        return
    if args.dir:
        embed_directory(args.dir, args.db, args.model, args.pattern)
        return

    # No input mode given: show usage instead of doing nothing silently.
    parser.print_help()


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()