#!/usr/bin/env python3
"""
Generate pre-rendered JSON data for notebook pages from DuckDB databases.

This script queries a DuckDB database and transforms the results into a JSON
structure suitable for use with React notebook pages in web/scratch.

Usage:
    python generate-story-json.py <database_path> <output_path> [--regional]

Example:
    python generate-story-json.py research/usa_energy/usa_electricity.duckdb \
        web/scratch/src/data/usa_energy/energy_story.json
"""

import json
import math
import sys
from pathlib import Path

import duckdb


def _json_safe_records(records: list) -> list:
    """Return a copy of *records* with non-finite floats replaced by ``None``.

    ``json.dump`` serialises NaN/Infinity as bare ``NaN``/``Infinity`` tokens,
    which are not valid JSON and are rejected by ``JSON.parse``. SQL NULLs in
    numeric columns come back from a pandas DataFrame as NaN, so they are
    mapped to ``None`` (JSON ``null``) here. The input dicts are not mutated.
    """
    return [
        {
            key: None if isinstance(value, float) and not math.isfinite(value) else value
            for key, value in record.items()
        }
        for record in records
    ]


def generate_usa_energy_json(db_path: str, output_path: str) -> None:
    """Generate JSON for US electricity energy analysis.

    Reads the ``annual_overview``, ``retail_sales_by_sector``,
    ``generation_by_fuel_source`` and ``summary`` tables from the DuckDB
    database at ``db_path`` and writes a single JSON document with the keys
    ``meta``, ``headline``, ``annual``, ``demand_by_sector``, ``demand_index``,
    ``generation_key_sources``, ``generation_shares`` and
    ``renewables_breakdown``.

    Args:
        db_path: Path to an existing DuckDB database file.
        output_path: Destination JSON file; missing parent directories are
            created.

    Raises:
        Whatever ``duckdb`` raises when the database cannot be opened or a
        query fails; the connection is closed in either case.
    """
    # Open read-only: the script only runs SELECTs, and this stops DuckDB from
    # silently creating an empty database file when db_path is mistyped.
    conn = duckdb.connect(db_path, read_only=True)

    try:
        def fetch_records(query: str) -> list:
            """Run *query* and return its rows as JSON-safe dicts."""
            frame = conn.execute(query).fetchdf()
            return _json_safe_records(frame.to_dict(orient="records"))

        # Generate headline statistics
        headline_query = """
        WITH overview AS (
            SELECT
                year,
                total_generation_twh,
                total_retail_sales_twh,
                coal_pct,
                natural_gas_pct,
                nuclear_pct,
                renewables_pct,
                loss_pct
            FROM annual_overview
        ),
        first_last AS (
            SELECT
                (SELECT year FROM overview ORDER BY year LIMIT 1) AS first_year,
                (SELECT year FROM overview ORDER BY year DESC LIMIT 1) AS last_year
        )
        SELECT
            fl.first_year,
            fl.last_year,
            ((SELECT total_generation_twh FROM overview WHERE year = fl.last_year) /
             (SELECT total_generation_twh FROM overview WHERE year = fl.first_year) - 1) * 100 AS generation_change_pct,
            ((SELECT total_retail_sales_twh FROM overview WHERE year = fl.last_year) /
             (SELECT total_retail_sales_twh FROM overview WHERE year = fl.first_year) - 1) * 100 AS retail_sales_change_pct,
            (SELECT MAX(coal_pct) FROM overview) - (SELECT MIN(coal_pct) FROM overview) AS coal_share_pp,
            (SELECT MAX(natural_gas_pct) FROM overview) - (SELECT MIN(natural_gas_pct) FROM overview) AS natural_gas_share_pp,
            (SELECT MAX(renewables_pct) FROM overview) - (SELECT MIN(renewables_pct) FROM overview) AS renewables_share_pp,
            (SELECT MAX(nuclear_pct) FROM overview) - (SELECT MIN(nuclear_pct) FROM overview) AS nuclear_share_pp,
            (SELECT loss_pct FROM overview WHERE year = fl.first_year) AS loss_pct_first,
            (SELECT loss_pct FROM overview WHERE year = fl.last_year) AS loss_pct_last
        FROM first_last fl
        """

        # Keys are listed in the same order as the SELECT columns above.
        headline_keys = (
            "first_year",
            "last_year",
            "generation_change_pct",
            "retail_sales_change_pct",
            "coal_share_pp",
            "natural_gas_share_pp",
            "renewables_share_pp",
            "nuclear_share_pp",
            "loss_pct_first",
            "loss_pct_last",
        )
        headline_row = conn.execute(headline_query).fetchone()
        headline_dict = _json_safe_records([dict(zip(headline_keys, headline_row))])[0]

        # Generate annual data
        annual_query = """
        SELECT
            year,
            total_generation_twh AS generation_twh,
            total_retail_sales_twh AS retail_sales_twh,
            losses_and_other_twh AS losses_twh,
            loss_pct
        FROM annual_overview
        ORDER BY year
        """
        annual = fetch_records(annual_query)

        # Generate demand by sector
        demand_query = """
        SELECT
            year,
            residential_million_kwh / 1000 AS residential_twh,
            commercial_million_kwh / 1000 AS commercial_twh,
            industrial_million_kwh / 1000 AS industrial_twh,
            transportation_million_kwh / 1000 AS transportation_twh,
            total_million_kwh / 1000 AS total_twh
        FROM retail_sales_by_sector
        ORDER BY year
        """
        demand_by_sector = fetch_records(demand_query)

        # Generate demand index.
        # The index is relative to the earliest year in the table (FIRST_VALUE
        # over ORDER BY year); the output key name hard-codes 2014.
        # NOTE(review): confirm the table always starts at 2014.
        # yoy_pct is NULL for the first year and is emitted as JSON null.
        demand_index_query = """
        WITH base AS (
            SELECT year, total_million_kwh / 1000 AS total_twh
            FROM retail_sales_by_sector
            ORDER BY year
        ),
        indexed AS (
            SELECT
                year,
                total_twh,
                FIRST_VALUE(total_twh) OVER (ORDER BY year) AS first_total,
                LAG(total_twh) OVER (ORDER BY year) AS prev_total
            FROM base
        )
        SELECT
            year,
            total_twh,
            (total_twh / first_total) * 100 AS index_2014_100,
            CASE
                WHEN prev_total IS NULL THEN NULL
                ELSE ((total_twh - prev_total) / prev_total) * 100
            END AS yoy_pct
        FROM indexed
        ORDER BY year
        """
        demand_index = fetch_records(demand_index_query)

        # Generate generation by key sources.
        # NOTE(review): other_twh includes geothermal, wood and other biomass,
        # which are also counted in renewables_twh - confirm this overlap is
        # intended before stacking the two series.
        generation_query = """
        SELECT
            year,
            total_thousand_mwh / 1000 AS total_twh,
            coal_thousand_mwh / 1000 AS coal_twh,
            natural_gas_thousand_mwh / 1000 AS natural_gas_twh,
            nuclear_thousand_mwh / 1000 AS nuclear_twh,
            (wind_thousand_mwh + solar_total_thousand_mwh + hydro_conventional_thousand_mwh +
             geothermal_thousand_mwh + wood_thousand_mwh + biomass_other_thousand_mwh) / 1000 AS renewables_twh,
            (geothermal_thousand_mwh + wood_thousand_mwh + biomass_other_thousand_mwh +
             petroleum_thousand_mwh + petroleum_coke_thousand_mwh + other_gases_thousand_mwh +
             other_thousand_mwh) / 1000 AS other_twh
        FROM generation_by_fuel_source
        ORDER BY year
        """
        generation_key = fetch_records(generation_query)

        # Generate generation shares (from summary table)
        shares_query = """
        SELECT
            year,
            coal_pct,
            natural_gas_pct,
            nuclear_pct,
            renewables_pct
        FROM summary
        ORDER BY year
        """
        generation_shares = fetch_records(shares_query)

        # Generate renewables breakdown
        renewables_query = """
        SELECT
            year,
            wind_thousand_mwh / 1000 AS wind_twh,
            solar_total_thousand_mwh / 1000 AS solar_twh,
            hydro_conventional_thousand_mwh / 1000 AS hydro_twh,
            (geothermal_thousand_mwh + wood_thousand_mwh + biomass_other_thousand_mwh) / 1000 AS other_renewables_twh,
            (wind_thousand_mwh + solar_total_thousand_mwh + hydro_conventional_thousand_mwh +
             geothermal_thousand_mwh + wood_thousand_mwh + biomass_other_thousand_mwh) / 1000 AS renewables_twh
        FROM generation_by_fuel_source
        ORDER BY year
        """
        renewables_breakdown = fetch_records(renewables_query)
    finally:
        # Release the database handle even when a query fails.
        conn.close()

    # Build the complete JSON structure
    output = {
        "meta": {
            "dataset": "usa_energy",
            "source": "U.S. Energy Information Administration (EIA) Open Data (ELEC.* series)",
            "coverage": "United States, annual",
            "years": [headline_dict["first_year"], headline_dict["last_year"]],
            "generated_from": str(Path(db_path).parent),
        },
        "headline": headline_dict,
        "annual": annual,
        "demand_by_sector": demand_by_sector,
        "demand_index": demand_index,
        "generation_key_sources": generation_key,
        "generation_shares": generation_shares,
        "renewables_breakdown": renewables_breakdown,
    }

    # Write output
    output_file = Path(output_path)
    output_file.parent.mkdir(parents=True, exist_ok=True)

    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)

    print(f"Generated JSON: {output_path}")


def generate_regional_json(db_path: str, output_path: str) -> None:
    """Write demo (mock) generation data for four ISOs to a JSON file.

    The figures are hard-coded illustrative series for 2014-2024, not query
    results: ``db_path`` is accepted but never read by this function.

    Args:
        db_path: Unused.
        output_path: Destination JSON file; missing parent directories are
            created.
    """
    iso_years = list(range(2014, 2025))

    # Per-ISO mock series, one value per year, in the order
    # (coal, natural gas, nuclear, renewables).
    mock_series = {
        # California: rapid renewables growth, coal eliminated
        "CAISO": (
            [15.0, 12.0, 9.0, 6.0, 4.0, 2.0, 1.0, 0.5, 0.2, 0.0, 0.0],
            [120.0, 125.0, 130.0, 135.0, 140.0, 145.0, 140.0, 135.0, 130.0, 125.0, 120.0],
            [18.0, 18.0, 18.0, 18.0, 18.0, 18.0, 9.0, 9.0, 9.0, 9.0, 9.0],
            [35.0, 45.0, 55.0, 70.0, 90.0, 110.0, 130.0, 150.0, 170.0, 190.0, 210.0],
        ),
        # Texas: wind-dominated, gas steady
        "ERCOT": (
            [40.0, 35.0, 30.0, 25.0, 20.0, 15.0, 10.0, 8.0, 6.0, 4.0, 3.0],
            [150.0, 155.0, 160.0, 165.0, 170.0, 175.0, 180.0, 185.0, 190.0, 195.0, 200.0],
            [12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0],
            [15.0, 25.0, 40.0, 60.0, 85.0, 110.0, 135.0, 160.0, 185.0, 210.0, 235.0],
        ),
        # Mid-Atlantic: coal to gas transition
        "PJM": (
            [180.0, 170.0, 160.0, 150.0, 140.0, 130.0, 120.0, 110.0, 100.0, 90.0, 80.0],
            [120.0, 130.0, 140.0, 150.0, 160.0, 170.0, 180.0, 190.0, 200.0, 210.0, 220.0],
            [90.0, 90.0, 90.0, 90.0, 90.0, 90.0, 90.0, 90.0, 90.0, 90.0, 90.0],
            [15.0, 20.0, 25.0, 35.0, 50.0, 70.0, 90.0, 110.0, 130.0, 150.0, 170.0],
        ),
        # New York: hydro/nuclear base, renewables grow
        "NYISO": (
            [5.0, 4.0, 3.0, 2.0, 1.0, 0.5, 0.2, 0.0, 0.0, 0.0, 0.0],
            [80.0, 85.0, 90.0, 95.0, 100.0, 105.0, 110.0, 115.0, 120.0, 125.0, 130.0],
            [45.0, 45.0, 45.0, 45.0, 45.0, 45.0, 35.0, 35.0, 35.0, 35.0, 35.0],
            [35.0, 40.0, 45.0, 55.0, 70.0, 85.0, 100.0, 115.0, 130.0, 145.0, 160.0],
        ),
    }

    def as_share(series, totals):
        """Express each value of *series* as a percentage of the matching total."""
        return [value / whole * 100 for value, whole in zip(series, totals)]

    twh_entries = []
    share_entries = []
    summaries = {}

    for iso, (coal_twh, gas_twh, nuclear_twh, renewables_twh) in mock_series.items():
        # Yearly total across the four fuel groups.
        totals = [
            c + g + n + r
            for c, g, n, r in zip(coal_twh, gas_twh, nuclear_twh, renewables_twh)
        ]
        coal_share = as_share(coal_twh, totals)
        gas_share = as_share(gas_twh, totals)
        nuclear_share = as_share(nuclear_twh, totals)
        renewables_share = as_share(renewables_twh, totals)

        twh_entries.append({
            "region": iso,
            "years": iso_years,
            "coal": coal_twh,
            "natural_gas": gas_twh,
            "nuclear": nuclear_twh,
            "renewables": renewables_twh,
        })

        share_entries.append({
            "region": iso,
            "years": iso_years,
            "coal_pct": coal_share,
            "natural_gas_pct": gas_share,
            "nuclear_pct": nuclear_share,
            "renewables_pct": renewables_share,
        })

        # First-to-last-year changes; the share deltas are percentage points.
        summaries[iso] = {
            "first_year": iso_years[0],
            "last_year": iso_years[-1],
            "total_change_pct": ((totals[-1] / totals[0]) - 1) * 100,
            "coal_share_change": coal_share[-1] - coal_share[0],
            "natural_gas_share_change": gas_share[-1] - gas_share[0],
            "renewables_share_change": renewables_share[-1] - renewables_share[0],
        }

    payload = {
        "generation_twh": twh_entries,
        "generation_shares": share_entries,
        "summary": summaries,
    }

    # Write output
    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)

    with destination.open("w") as handle:
        json.dump(payload, handle, indent=2)

    print(f"Generated regional JSON: {output_path}")


def main():
    """Command-line entry point.

    Expects ``<database_path> <output_path>`` as the first two positional
    arguments and an optional ``--regional`` flag, which may appear anywhere
    on the command line. With the flag the mock regional JSON is written;
    otherwise the national JSON is generated from the database.

    Exits with status 1 and prints usage to stderr when fewer than two
    positional arguments are supplied.
    """
    cli_args = sys.argv[1:]
    regional = "--regional" in cli_args
    # Drop the flag before reading positionals so that it is never mistaken
    # for the database or output path, whatever its position.
    positional = [arg for arg in cli_args if arg != "--regional"]

    if len(positional) < 2:
        print(
            "Usage: python generate-story-json.py <database_path> <output_path> [--regional]",
            file=sys.stderr,
        )
        sys.exit(1)

    db_path, output_path = positional[0], positional[1]

    if regional:
        generate_regional_json(db_path, output_path)
    else:
        generate_usa_energy_json(db_path, output_path)


if __name__ == "__main__":
    main()