#!/usr/bin/env python3
"""
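Search for texts similar to a query by comparing Ollama embeddings (cosine similarity by default; Euclidean distance and dot product are also available via --metric).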

Usage:
    python3 search_similar.py --query "programming" --db embeddings.db
    python3 search_similar.py --query "machine learning" --db embeddings.db --top 5
"""

import sqlite3
import subprocess
import json
import argparse
import math
from typing import List, Tuple


def get_embedding(text: str, model: str = "embeddinggemma") -> List[float]:
    """Get an embedding vector for ``text`` from a local Ollama model.

    Runs ``ollama run <model>`` with ``text`` on stdin and parses stdout as a
    JSON array of numbers.

    Args:
        text: Text to embed.
        model: Name of the Ollama model to run.

    Returns:
        The embedding as a list of floats.

    Raises:
        RuntimeError: If the ``ollama`` process exits with a non-zero status.
            (A subclass of Exception, so callers catching Exception still work.)
        ValueError: If stdout is not a non-empty JSON array of numbers.
        FileNotFoundError: If the ``ollama`` executable cannot be found.
    """
    result = subprocess.run(
        ["ollama", "run", model],
        capture_output=True,
        # Pin UTF-8 instead of the locale default so non-ASCII queries do not
        # fail with UnicodeEncodeError on non-UTF-8 locales (e.g. cp1252).
        encoding="utf-8",
        errors="replace",
        input=text,
    )

    if result.returncode != 0:
        raise RuntimeError(
            f"Ollama error (model {model!r}, exit code {result.returncode}): "
            f"{result.stderr.strip()}"
        )

    output = result.stdout.strip()
    try:
        embedding = json.loads(output)
    except json.JSONDecodeError as err:
        raise ValueError(
            f"Ollama output for model {model!r} is not valid JSON: {output[:200]!r}"
        ) from err

    # Reject anything that is not a usable vector here, rather than letting a
    # dict/number/empty list surface later as a confusing TypeError or 0.0 score.
    if not (
        isinstance(embedding, list)
        and embedding
        and all(isinstance(x, (int, float)) for x in embedding)
    ):
        raise ValueError(
            f"Expected a non-empty JSON array of numbers from model {model!r}, "
            f"got: {output[:200]!r}"
        )
    return [float(x) for x in embedding]


def cosine_similarity(vec1: List[float], vec2: List[float]) -> float:
    """Calculate cosine similarity between two vectors.

    Returns ``dot(vec1, vec2) / (|vec1| * |vec2|)``, which lies in [-1, 1] up
    to floating-point rounding; higher means more similar. If either vector
    has zero norm the similarity is undefined and 0.0 is returned.

    Raises:
        ValueError: If the vectors have different lengths. Without this check
            ``zip`` would silently truncate the longer vector and produce a
            meaningless score (e.g. when comparing embeddings from two models).
    """
    if len(vec1) != len(vec2):
        raise ValueError(f"Vector length mismatch: {len(vec1)} != {len(vec2)}")

    dot = sum(a * b for a, b in zip(vec1, vec2))
    norm1 = math.sqrt(sum(a * a for a in vec1))
    norm2 = math.sqrt(sum(b * b for b in vec2))

    if norm1 == 0 or norm2 == 0:
        return 0.0

    return dot / (norm1 * norm2)


def euclidean_distance(vec1: List[float], vec2: List[float]) -> float:
    """Calculate Euclidean distance between two vectors.

    Lower values mean the vectors are closer (more similar).

    Raises:
        ValueError: If the vectors have different lengths; ``zip`` would
            otherwise silently ignore the extra components of the longer one.
    """
    if len(vec1) != len(vec2):
        raise ValueError(f"Vector length mismatch: {len(vec1)} != {len(vec2)}")
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(vec1, vec2)))


def dot_product(vec1: List[float], vec2: List[float]) -> float:
    """Calculate dot product between two vectors.

    Higher values mean more similar; unlike cosine similarity the result is
    not normalized, so vector magnitude affects the score.

    Raises:
        ValueError: If the vectors have different lengths; ``zip`` would
            otherwise silently truncate the longer one.
    """
    if len(vec1) != len(vec2):
        raise ValueError(f"Vector length mismatch: {len(vec1)} != {len(vec2)}")
    return sum(a * b for a, b in zip(vec1, vec2))


def _quote_identifier(name: str) -> str:
    """Quote a (possibly schema-qualified) SQLite identifier for use in SQL text.

    Identifiers cannot be bound as ``?`` parameters, so each dot-separated part
    is wrapped in double quotes with embedded double quotes doubled:
    ``texts`` -> ``"texts"``, ``main.texts`` -> ``"main"."texts"``.
    """
    return ".".join('"' + part.replace('"', '""') + '"' for part in name.split("."))


def search_similar(
    query: str,
    db_path: str,
    model: str = "embeddinggemma",
    top_n: int = 10,
    metric: str = "cosine",
    table_name: str = "texts"
) -> List[Tuple[str, str, float]]:
    """
    Search for texts similar to query.

    Embeds ``query`` with Ollama, scores it against every row of
    ``table_name`` (columns ``id``, ``text``, ``source``, ``embedding``; the
    ``embedding`` column is parsed with ``json.loads``), and returns the best
    ``top_n`` matches.

    Args:
        query: Query text
        db_path: Path to database
        model: Ollama model to use
        top_n: Number of results to return
        metric: Similarity metric ('cosine', 'euclidean', 'dot'). For
            'euclidean' lower scores are better; for the others higher is better.
        table_name: Table name to search; may be schema-qualified
            (e.g. 'main.texts'). It is quoted before being placed in the SQL.

    Returns:
        List of (text, source, score) tuples, best match first

    Raises:
        ValueError: If ``metric`` is unknown, or a stored embedding's length
            differs from the query embedding's.
        sqlite3.Error: If the table cannot be queried.
    """
    # metric -> (scoring function, whether a higher score means more similar).
    # Validated before the slow Ollama call so bad input fails fast.
    metrics = {
        "cosine": (cosine_similarity, True),
        "euclidean": (euclidean_distance, False),
        "dot": (dot_product, True),
    }
    if metric not in metrics:
        raise ValueError(f"Unknown metric: {metric}")
    similarity_fn, higher_is_better = metrics[metric]

    print(f"Getting embedding for query: '{query}'")
    query_embedding = get_embedding(query, model)
    print(f"Query embedding dimension: {len(query_embedding)}")

    sql = f"SELECT id, text, source, embedding FROM {_quote_identifier(table_name)}"

    results = []
    conn = sqlite3.connect(db_path)
    try:
        # Iterate the cursor directly rather than materializing fetchall().
        for text_id, text, source, embedding_blob in conn.execute(sql):
            if embedding_blob is None:
                # NOTE(review): assumes a NULL embedding means "not yet
                # embedded"; such rows cannot be scored, so skip them.
                continue
            embedding = json.loads(embedding_blob)
            if len(embedding) != len(query_embedding):
                # Mismatched lengths usually mean the row was embedded with a
                # different model; scoring it would be meaningless.
                raise ValueError(
                    f"Embedding for row {text_id!r} has dimension {len(embedding)}, "
                    f"but the query embedding has dimension {len(query_embedding)}; "
                    f"were they produced by different models?"
                )
            results.append((text, source, similarity_fn(query_embedding, embedding)))
    finally:
        # Close even when a row fails to decode or has the wrong dimension.
        conn.close()

    # list.sort is stable (also with reverse=True), so ties keep read order.
    results.sort(key=lambda row: row[2], reverse=higher_is_better)
    return results[:top_n]


def main():
    """Command-line entry point: parse arguments, run the search, print results."""
    parser = argparse.ArgumentParser(description="Search for similar texts")
    parser.add_argument("--query", required=True, help="Query text")
    parser.add_argument("--db", default="embeddings.db", help="Database path")
    parser.add_argument("--model", default="embeddinggemma", help="Ollama model")
    parser.add_argument("--top", type=int, default=10, help="Number of results")
    parser.add_argument("--metric", default="cosine", choices=["cosine", "euclidean", "dot"], help="Similarity metric")
    parser.add_argument("--table", default="texts", help="Table name")

    args = parser.parse_args()

    # A negative --top would make results[:top_n] drop the *last* results
    # instead of limiting them, so reject non-positive values up front.
    if args.top < 1:
        parser.error("--top must be a positive integer")

    results = search_similar(
        query=args.query,
        db_path=args.db,
        model=args.model,
        top_n=args.top,
        metric=args.metric,
        table_name=args.table
    )

    # Print results
    print(f"\nTop {len(results)} results for '{args.query}' ({args.metric}):\n")
    for i, (text, source, score) in enumerate(results, 1):
        # text comes straight from the database and could be NULL (None);
        # show it as empty rather than crash on slicing/len().
        snippet = text or ""
        print(f"{i}. {source or 'N/A'}")
        print(f"   {snippet[:100]}{'...' if len(snippet) > 100 else ''}")
        print(f"   Score: {score:.4f}")
        print()


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()