#!/usr/bin/env python3
"""
High-throughput image-text embedding using OpenCLIP.

Usage:
    # Embed all images in a directory
    python3 embed_images.py --input /path/to/images --output embeddings.db

    # Search for images by text
    python3 embed_images.py --query "car" --db embeddings.db --top 10

    # Search for similar images
    python3 embed_images.py --image-query /path/to/image.jpg --db embeddings.db --top 10
"""
import argparse
import json
import sqlite3
import time
from pathlib import Path
from typing import List, Tuple

import numpy as np
import open_clip
import torch
from PIL import Image
from tqdm import tqdm


# Default model - ViT-B-32 is a good balance of speed and quality
MODEL_NAME = "ViT-B-32"
# LAION-2B checkpoint tag passed to open_clip.create_model_and_transforms
PRETRAINED = "laion2b_s34b_b79k"
# Expected input resolution for ViT-B-32.
# NOTE(review): not referenced anywhere in this file — preprocessing size
# comes from the open_clip transform; confirm before relying on it.
IMAGE_SIZE = 224


def load_model(model_name: str = MODEL_NAME, pretrained: str = PRETRAINED):
    """Instantiate an OpenCLIP model plus its tokenizer and preprocessing.

    The model is moved to CUDA when available (CPU otherwise) and switched
    to eval mode.

    Returns:
        A ``(model, tokenizer, preprocess, device)`` tuple, where *device*
        is the string ``"cuda"`` or ``"cpu"``.
    """
    print(f"Loading model: {model_name} / {pretrained}")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    model, _, preprocess = open_clip.create_model_and_transforms(
        model_name, pretrained=pretrained, device=device
    )
    model.eval()
    tokenizer = open_clip.get_tokenizer(model_name)

    print(f"Model loaded. Embedding dim: {model.visual.output_dim}")
    return model, tokenizer, preprocess, device


def load_images(input_path: Path, pattern: str = "*.jpg", recursive: bool = False) -> List[Path]:
    """Collect image file paths under *input_path*.

    A single file is returned as a one-element list. For a directory,
    files matching *pattern* plus the common jpeg/png/webp extensions are
    gathered (recursively when *recursive* is set), deduplicated, and
    returned in sorted order.
    """
    if input_path.is_file():
        return [input_path]

    globber = input_path.rglob if recursive else input_path.glob
    found = set()
    for pat in (pattern, "*.jpeg", "*.png", "*.webp"):
        found.update(globber(pat))
    return sorted(found)


def embed_images_batch(
    model,
    preprocess,
    device,
    image_paths: List[Path],
    batch_size: int = 64,
    show_progress: bool = True
) -> List[List[float]]:
    """Embed images in batches, returning one entry per input path.

    Output is positionally aligned with *image_paths*: a unit-normalized
    embedding (list of floats) for each image that loaded, ``None`` for
    each image whose load/preprocess raised (the error is printed).
    """
    print(f"Embedding {len(image_paths)} images with batch size {batch_size}...")

    batch_starts = range(0, len(image_paths), batch_size)
    if show_progress:
        batch_starts = tqdm(batch_starts, desc="Embedding")

    embeddings: List[List[float]] = []
    with torch.no_grad():
        for start in batch_starts:
            batch_paths = image_paths[start:start + batch_size]

            tensors = []
            ok_positions = []  # batch-local indices that loaded successfully
            for pos, img_path in enumerate(batch_paths):
                try:
                    tensors.append(preprocess(Image.open(img_path).convert("RGB")))
                    ok_positions.append(pos)
                except Exception as e:
                    print(f"Error loading {img_path}: {e}")

            # Whole batch failed to load: emit placeholders and move on.
            if not tensors:
                embeddings.extend([None] * len(batch_paths))
                continue

            features = model.encode_image(torch.stack(tensors).to(device))
            features = features / features.norm(dim=-1, keepdim=True)
            vectors = features.cpu().numpy().tolist()

            # Re-align results with original batch positions; failed loads
            # keep a None placeholder.
            slots = [None] * len(batch_paths)
            for vec, pos in zip(vectors, ok_positions):
                slots[pos] = vec
            embeddings.extend(slots)

    return embeddings


def embed_text(model, tokenizer, device, text: str) -> List[float]:
    """Return the unit-normalized CLIP embedding of *text* as a plain list."""
    tokens = tokenizer([text]).to(device)

    with torch.no_grad():
        features = model.encode_text(tokens)
        features = features / features.norm(dim=-1, keepdim=True)

    return features.cpu().numpy()[0].tolist()


def create_database(db_path: Path, table: str = "images") -> sqlite3.Connection:
    """Create (if needed) the embeddings table and return an open connection.

    The schema matches what save_embeddings() writes: one row per image,
    keyed by unique filepath, with the embedding stored as JSON text.

    Args:
        db_path: SQLite database file (created if missing).
        table: destination table name; must be a plain identifier because
            table names cannot be bound as SQL parameters.

    Returns:
        An open sqlite3.Connection (caller is responsible for closing it).

    Raises:
        ValueError: if *table* is not a valid Python/SQL identifier.
    """
    # Table names are interpolated into the SQL text; reject anything that
    # isn't a plain identifier to prevent SQL injection via --table.
    if not table.isidentifier():
        raise ValueError(f"Invalid table name: {table!r}")

    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    cursor.execute(f"""
        CREATE TABLE IF NOT EXISTS {table} (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            filepath TEXT NOT NULL UNIQUE,
            filename TEXT NOT NULL,
            embedding TEXT,
            model TEXT,
            embedding_dim INTEGER,
            created_at TEXT DEFAULT (datetime('now'))
        )
    """)

    # Table-scoped index name (consistent with save_embeddings). A fixed
    # name like "idx_filename" would be silently skipped by IF NOT EXISTS
    # when a second table is created in the same database.
    cursor.execute(f"CREATE INDEX IF NOT EXISTS idx_filename_{table} ON {table}(filename)")

    conn.commit()
    return conn


def save_embeddings(
    db_path: Path,
    image_paths: List[Path],
    embeddings: List[List[float]],
    model_name: str,
    table: str = "images"
):
    """Upsert image embeddings into a SQLite database.

    Rows are keyed by filepath (INSERT OR REPLACE), so re-running over the
    same images refreshes their embeddings. Entries whose embedding is
    ``None`` (failed image loads) are skipped.

    Args:
        db_path: SQLite database file (created if missing).
        image_paths: paths aligned one-to-one with *embeddings*.
        embeddings: JSON-serializable vectors, or None per failed image.
        model_name: recorded verbatim in the "model" column.
        table: destination table; must be a plain identifier because it is
            interpolated into the SQL text.

    Raises:
        ValueError: if *table* is not a valid identifier.
    """
    # Reject non-identifier table names: they are interpolated into SQL
    # and cannot be bound as parameters.
    if not table.isidentifier():
        raise ValueError(f"Invalid table name: {table!r}")

    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        cursor.execute(f"""
            CREATE TABLE IF NOT EXISTS {table} (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                filepath TEXT NOT NULL UNIQUE,
                filename TEXT NOT NULL,
                embedding TEXT,
                model TEXT,
                embedding_dim INTEGER,
                created_at TEXT DEFAULT (datetime('now'))
            )
        """)
        cursor.execute(f"CREATE INDEX IF NOT EXISTS idx_filename_{table} ON {table}(filename)")

        # Single executemany instead of one execute per row.
        rows = [
            (str(img_path), img_path.name, json.dumps(emb), model_name, len(emb))
            for img_path, emb in zip(image_paths, embeddings)
            if emb is not None
        ]
        cursor.executemany(f"""
            INSERT OR REPLACE INTO {table} (filepath, filename, embedding, model, embedding_dim)
            VALUES (?, ?, ?, ?, ?)
        """, rows)

        conn.commit()
    finally:
        # Close even on failure so the db file isn't left locked.
        conn.close()

    print(f"Saved embeddings to {db_path}")

def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Cosine similarity of two vectors, as a native Python float.

    Returns 0.0 when either vector has zero norm (cosine is undefined),
    instead of propagating a divide-by-zero NaN. The result is cast with
    float() so callers get a plain Python float, matching the annotation,
    rather than a numpy scalar.
    """
    a_arr = np.asarray(a, dtype=np.float64)
    b_arr = np.asarray(b, dtype=np.float64)
    denom = np.linalg.norm(a_arr) * np.linalg.norm(b_arr)
    if denom == 0.0:
        return 0.0
    return float(np.dot(a_arr, b_arr) / denom)


def search_by_text(
    db_path: Path,
    query: str,
    model,
    tokenizer,
    device,
    top_k: int = 10,
    table: str = "images"
) -> List[Tuple[str, str, float]]:
    """Rank stored images against a text query by cosine similarity.

    Embeds *query* with the CLIP text tower, compares it against every
    embedding stored in *table*, and returns up to *top_k*
    (filepath, filename, score) tuples, highest score first.
    """
    query_emb = embed_text(model, tokenizer, device, query)

    # Load all embeddings from the database; try/finally so the connection
    # is closed even if the SELECT fails (it previously leaked on error).
    conn = sqlite3.connect(db_path)
    try:
        rows = conn.execute(
            f"SELECT filepath, filename, embedding FROM {table}"
        ).fetchall()
    finally:
        conn.close()

    similarities = [
        (filepath, filename, cosine_similarity(query_emb, json.loads(emb_json)))
        for filepath, filename, emb_json in rows
    ]

    similarities.sort(key=lambda x: x[2], reverse=True)
    return similarities[:top_k]


def search_by_image(
    db_path: Path,
    query_image_path: Path,
    model,
    preprocess,
    device,
    top_k: int = 10,
    table: str = "images"
) -> List[Tuple[str, str, float]]:
    """Rank stored images by visual similarity to a query image.

    Embeds the image at *query_image_path* with the CLIP image tower,
    compares it against every embedding stored in *table*, and returns up
    to *top_k* (filepath, filename, score) tuples, highest score first.
    """
    # Embed the query image (unit-normalized, matching stored embeddings).
    query_img = Image.open(query_image_path).convert("RGB")
    batch = preprocess(query_img).unsqueeze(0).to(device)

    with torch.no_grad():
        features = model.encode_image(batch)
        features = features / features.norm(dim=-1, keepdim=True)

    query_emb = features.cpu().numpy()[0].tolist()

    # Load all embeddings from the database; try/finally so the connection
    # is closed even if the SELECT fails (it previously leaked on error).
    conn = sqlite3.connect(db_path)
    try:
        rows = conn.execute(
            f"SELECT filepath, filename, embedding FROM {table}"
        ).fetchall()
    finally:
        conn.close()

    similarities = [
        (filepath, filename, cosine_similarity(query_emb, json.loads(emb_json)))
        for filepath, filename, emb_json in rows
    ]

    similarities.sort(key=lambda x: x[2], reverse=True)
    return similarities[:top_k]


def main():
    """CLI entry point: embed a directory of images, or search an existing DB."""
    parser = argparse.ArgumentParser(description="Image-text embedding with OpenCLIP")
    parser.add_argument("--input", "-i", help="Input directory or file")
    parser.add_argument("--output", "-o", default="image_embeddings.db", help="Output database")
    parser.add_argument("--query", "-q", help="Text query to search")
    parser.add_argument("--image-query", help="Image file to search for similar images")
    parser.add_argument("--db", default="image_embeddings.db", help="Database file for search")
    parser.add_argument("--model", default=MODEL_NAME, help="Model name")
    parser.add_argument("--pretrained", default=PRETRAINED, help="Pretrained weights")
    parser.add_argument("--batch-size", type=int, default=64, help="Batch size for embedding")
    parser.add_argument("--top", type=int, default=10, help="Number of results to return")
    parser.add_argument("--pattern", default="*.jpg", help="File pattern to match")
    parser.add_argument("--recursive", "-r", action="store_true", help="Search recursively")
    parser.add_argument("--table", default="images", help="Table name in database")

    args = parser.parse_args()

    # Load model (needed in all modes: embedding and both search modes)
    model, tokenizer, preprocess, device = load_model(args.model, args.pretrained)

    # Text-search mode
    if args.query:
        print(f"\nSearching for: '{args.query}'")
        results = search_by_text(Path(args.db), args.query, model, tokenizer, device, args.top, args.table)
        print(f"\nTop {len(results)} results:")
        for filepath, filename, score in results:
            # BUG FIX: print the actual filename instead of the
            # literal "(unknown)" placeholder.
            print(f"  {score:.4f} - {filename} ({filepath})")
        return

    # Image-search mode
    if args.image_query:
        print(f"\nSearching for images similar to: {args.image_query}")
        # BUG FIX: search_by_image takes (db_path, query_image_path, ...);
        # the previous call passed them swapped, treating the query image
        # as the database file.
        results = search_by_image(Path(args.db), Path(args.image_query), model, preprocess, device, args.top, args.table)
        print(f"\nTop {len(results)} results:")
        for filepath, filename, score in results:
            # BUG FIX: print the actual filename instead of "(unknown)".
            print(f"  {score:.4f} - {filename} ({filepath})")
        return

    # Embedding mode
    if not args.input:
        parser.error("--input is required for embedding mode")

    input_path = Path(args.input)
    output_path = Path(args.output)

    print(f"Loading images from {input_path}...")
    image_paths = load_images(input_path, args.pattern, args.recursive)
    print(f"Found {len(image_paths)} images")

    if not image_paths:
        print("No images found!")
        return

    # Embed and report throughput
    start = time.time()
    embeddings = embed_images_batch(model, preprocess, device, image_paths, args.batch_size)
    elapsed = time.time() - start

    valid_count = sum(1 for e in embeddings if e is not None)
    print(f"Embedded {valid_count} images in {elapsed:.2f}s ({valid_count/elapsed:.1f} images/sec)")

    # Persist to SQLite
    save_embeddings(output_path, image_paths, embeddings, f"{args.model}/{args.pretrained}", args.table)


# Run the CLI only when executed as a script (not when imported as a module).
if __name__ == "__main__":
    main()