"""
CorvOS v2.1-Sigma - Quantum Security Shield
Protecao contra ataques a coerencia lambda2 = 0.999
Synapse-k | Arkhe(n) | Rio de Janeiro
"""
import numpy as np
from scipy.linalg import eig, norm
from scipy.fft import fft, fftfreq
from dataclasses import dataclass, field
from typing import List, Dict, Tuple, Set, Optional
from datetime import datetime, timedelta
from enum import Enum
import hashlib, hmac, asyncio, random
from collections import deque, defaultdict

class ThreatType(Enum):
    """Attack categories that a QuantumThreat event can be tagged with.

    Each member's value is its own name in lowercase snake_case. The meaning
    of each category is inferred from its name only -- confirm against the
    threat model before relying on a specific interpretation.
    """
    DECOHERENCE_INJECTION = "decoherence_injection"
    SENSOR_POISONING = "sensor_poisoning"
    RETROCAUSAL_HIJACKING = "retrocausal_hijacking"
    QUANTUM_DDOS = "quantum_ddos"
    EIGENVALUE_MANIPULATION = "eigenvalue_manipulation"
    ENTANGLEMENT_SPOOFING = "entanglement_spoofing"
    TEMPORAL_PARADOX = "temporal_paradox"

class ThreatSeverity(Enum):
    """Severity levels for a QuantumThreat.

    The integer values increase with severity (LOW=1 up to CRITICAL=4), so
    ``member.value`` can be compared to rank threats.
    """
    CRITICAL = 4
    HIGH = 3
    MEDIUM = 2
    LOW = 1

@dataclass
class QuantumThreat:
    """Mutable record describing one security event.

    Fields are stored exactly as given; in particular ``attack_vector`` is not
    copied here, so later mutation of the passed array is visible in the record.
    """
    threat_id: str  # identifier string chosen by whoever creates the event
    threat_type: ThreatType  # attack category
    severity: ThreatSeverity  # severity level
    timestamp: datetime  # when the event was recorded; naive vs tz-aware is up to the caller
    affected_sensors: Set[int]  # indices of the sensors implicated in the event
    # NOTE(review): the name suggests a coherence delta, but a caller may store
    # the raw measured lambda here -- confirm the intended meaning.
    coherence_impact: float
    attack_vector: np.ndarray  # readings tied to the event; presumably one value per sensor -- TODO confirm
    containment_status: str = "ACTIVE"  # starts as "ACTIVE"; other values are not defined by this class
    mitigation_applied: Optional[str] = None  # description of the mitigation taken, or None if unset

class PhaseAnomalyDetector:
    """Flag sensors whose phase readings look spoofed.

    Two heuristics are combined over a sliding history of phase frames:

    * sudden jump: a sensor whose phase changed by more than
      ``jump_threshold`` between any two consecutive frames of the last three;
    * isolation: a sensor whose mean distance to its ``n_neighbors`` closest
      peers in the latest frame exceeds ``isolation_threshold``.

    Phase differences are wrapped onto [-pi, pi) before taking magnitudes, so a
    reading that crosses the +/-pi boundary is not mistaken for a jump of almost
    2*pi. (This assumes phases are in radians, consistent with the thresholds
    being expressed as fractions of pi.)

    Args:
        n_sensors: Expected number of sensors per frame. Kept for
            compatibility; the computations use the actual frame length.
        history_length: Maximum number of frames retained.
        jump_threshold: Per-step phase change (radians) treated as a jump.
        isolation_threshold: Mean nearest-neighbour distance treated as isolated.
        n_neighbors: How many closest peers are averaged for isolation.
    """

    def __init__(self, n_sensors: int = 168, history_length: int = 100,
                 jump_threshold: float = np.pi / 2, isolation_threshold: float = 1.0,
                 n_neighbors: int = 5):
        self.n_sensors = n_sensors
        self.phase_history = deque(maxlen=history_length)
        self.expected_phase_derivative = 0.0  # not read by any method of this class
        self.jump_threshold = jump_threshold
        self.isolation_threshold = isolation_threshold
        self.n_neighbors = n_neighbors

    @staticmethod
    def _wrap(delta: np.ndarray) -> np.ndarray:
        """Map phase differences onto the interval [-pi, pi)."""
        return (delta + np.pi) % (2 * np.pi) - np.pi

    def update(self, current_phases: np.ndarray) -> None:
        """Append a copy of ``current_phases`` so later caller mutations don't alter history."""
        self.phase_history.append(np.array(current_phases))

    def detect_spoofing(self) -> Tuple[bool, Set[int], float]:
        """Inspect the last three frames for jumping or isolated sensors.

        Returns:
            ``(detected, affected_sensor_indices, severity)``. With fewer than
            three frames, or when nothing is flagged, returns
            ``(False, set(), 0.0)``. Severity is the flagged fraction of sensors
            scaled by the mean per-sensor max jump relative to pi (capped at 1).
        """
        if len(self.phase_history) < 3:
            return False, set(), 0.0
        recent = np.array(list(self.phase_history)[-3:])
        steps = np.abs(self._wrap(np.diff(recent, axis=0)))
        max_jump = steps.max(axis=0)
        jumped = np.flatnonzero(max_jump > self.jump_threshold)
        isolated = np.flatnonzero(
            self._compute_neighbor_differences(recent[-1]) > self.isolation_threshold)
        affected = {int(i) for i in np.union1d(jumped, isolated)}
        if not affected:
            return False, set(), 0.0
        n_frame = max_jump.size
        severity = len(affected) / n_frame * min(1.0, float(np.mean(max_jump)) / np.pi)
        return True, affected, severity

    def _compute_neighbor_differences(self, phases: np.ndarray) -> np.ndarray:
        """Return, per sensor, the mean wrapped distance to its closest peers in ``phases``.

        Uses up to ``n_neighbors`` peers (fewer if the frame is smaller). A frame
        with fewer than two sensors has no peers and yields zeros.
        """
        phases = np.asarray(phases, dtype=float).ravel()
        n = phases.size
        if n < 2:
            return np.zeros(n)
        dist = np.abs(self._wrap(phases[:, None] - phases[None, :]))
        np.fill_diagonal(dist, np.inf)  # a sensor is not its own neighbour
        k = min(self.n_neighbors, n - 1)
        return np.sort(dist, axis=1)[:, :k].mean(axis=1)

class DecoherenceAttackDetector:
    """Detect coherence decay and high-frequency noise in sensor readings.

    Two checks run in order on every :meth:`detect_attack` call:

    1. Induced decoherence: the average per-update drop of lambda over the
       last ``window`` samples exceeds ``decay_threshold``.
    2. RF interference: the mean magnitude of the upper three quarters of the
       latest one-sided spectrum exceeds ``noise_threshold``.

    Args:
        target_lambda: Nominal lambda value (stored; not read by the checks here).
        decay_threshold: Per-update lambda drop that counts as an attack.
        window: Number of most recent lambda samples used for the decay rate.
        noise_threshold: Mean high-band magnitude that counts as interference.
    """

    def __init__(self, target_lambda: float = 0.999, decay_threshold: float = 0.001,
                 window: int = 10, noise_threshold: float = 0.5):
        self.target_lambda = target_lambda
        self.decay_threshold = decay_threshold
        self.window = window
        self.noise_threshold = noise_threshold
        self.lambda_history = deque(maxlen=100)
        self.noise_spectrum = None

    def update(self, current_lambda: float, sensor_readings: np.ndarray) -> None:
        """Record ``current_lambda`` and recompute the spectrum of ``sensor_readings``.

        An empty reading clears the spectrum instead of failing inside the FFT.
        """
        self.lambda_history.append(current_lambda)
        readings = np.asarray(sensor_readings)
        if readings.size == 0:
            self.noise_spectrum = None
            return
        fft_vals = fft(readings)
        # Keep the first half of the FFT bins (indices < N // 2).
        # NOTE(review): magnitudes are not normalised by N, so noise_threshold
        # effectively scales with the number of readings -- confirm intended.
        self.noise_spectrum = np.abs(fft_vals[:len(fft_vals) // 2])

    def detect_attack(self) -> Tuple[bool, float, str]:
        """Return ``(detected, severity, kind)``.

        ``kind`` is ``"induced_decoherence"``, ``"rf_interference"`` or
        ``"none"``; severity is in [0, 1]. Needs at least two lambda samples.
        """
        if len(self.lambda_history) < 2:
            return False, 0.0, "none"
        recent = list(self.lambda_history)[-max(self.window, 2):]
        # n samples span n - 1 update intervals.
        decay_rate = (recent[0] - recent[-1]) / (len(recent) - 1)
        if decay_rate > self.decay_threshold:
            severity = min(1.0, decay_rate / (self.decay_threshold * 5))
            return True, severity, "induced_decoherence"
        if self.noise_spectrum is not None:
            high_band = self.noise_spectrum[len(self.noise_spectrum) // 4:]
            if high_band.size:
                high_freq_noise = float(np.mean(high_band))
                if high_freq_noise > self.noise_threshold:
                    return True, min(1.0, high_freq_noise), "rf_interference"
        return False, 0.0, "none"

class ByzantineConsensusFilter:
    """Reputation-weighted, outlier-rejecting consensus over one frame of readings.

    Every sensor starts with reputation 1.0. Each call to
    :meth:`compute_consensus` centres the frame on the reputation-weighted
    median. Readings within three times the median absolute deviation from
    that centre count as consistent. Consistent sensors gain 2 % reputation
    (capped at 1.0); inconsistent ones lose 20 %.
    """

    def __init__(self, n_sensors: int = 168, quorum_required: int = 112):
        self.n_sensors = n_sensors
        self.quorum_required = quorum_required
        self.sensor_reputation = np.ones(n_sensors)
        # Not read by any method of this class.
        self.reputation_decay = 0.95

    def update_reputation(self, sensor_id: int, was_consistent: bool):
        """Reward (x1.02, capped at 1.0) or penalise (x0.8) one sensor's reputation."""
        current = self.sensor_reputation[sensor_id]
        if not was_consistent:
            self.sensor_reputation[sensor_id] = current * 0.8
            return
        self.sensor_reputation[sensor_id] = min(1.0, current * 1.02)

    def compute_consensus(self, readings: np.ndarray) -> Tuple[float, Set[int], Set[int]]:
        """Compute a robust consensus value and update reputations.

        Returns:
            ``(consensus, valid, malicious)``. If at least ``quorum_required``
            readings are consistent, ``consensus`` is their mean and ``valid``
            holds their indices. Otherwise ``consensus`` falls back to the
            weighted median and ``valid`` is empty. ``malicious`` always holds
            the indices of the inconsistent readings.
        """
        normalised = self.sensor_reputation / np.sum(self.sensor_reputation)
        centre = self._weighted_median(readings, normalised)
        spread = np.abs(readings - centre)
        cutoff = 3 * np.median(spread)
        inliers = np.nonzero(spread <= cutoff)[0]
        outliers = np.nonzero(spread > cutoff)[0]
        for verdict, ids in ((True, inliers), (False, outliers)):
            for sid in ids:
                self.update_reputation(sid, verdict)
        if len(inliers) < self.quorum_required:
            return centre, set(), set(outliers)
        return np.mean(readings[inliers]), set(inliers), set(outliers)

    def _weighted_median(self, data: np.ndarray, weights: np.ndarray) -> float:
        """Return the first sorted value whose cumulative weight reaches 0.5."""
        order = np.argsort(data)
        cumulative = np.cumsum(weights[order])
        return data[order][np.searchsorted(cumulative, 0.5)]

class RetrocausalHandshakeValidator:
    """Validate pre-ACK packets before accepting them.

    Checks run in this order: replay, proof-of-work prefix, claimed time in the
    future, quorum size, signature. An accepted ``packet_hash`` is remembered
    so it cannot be replayed.

    Args:
        difficulty: Number of leading ``'0'`` characters required in ``packet_hash``.
        quorum_required: Minimum number of sensors in ``sensor_quorum``.
        secret_key: Optional shared key. When given, signatures are
            HMAC-SHA256 truncated to 16 hex chars. When ``None`` (the default,
            kept for compatibility), the unkeyed SHA-256 scheme is used. Anyone
            who knows the packet hash and time can compute that scheme.

    NOTE(review): the proof-of-work check only inspects the prefix of the
    caller-supplied hash; it is not recomputed from any payload here. Also,
    ``seen_hashes`` grows without bound.
    """

    def __init__(self, difficulty: int = 4, quorum_required: int = 112,
                 secret_key: Optional[bytes] = None):
        self.difficulty = difficulty
        self.quorum_required = quorum_required
        if isinstance(secret_key, str):
            secret_key = secret_key.encode()
        self.secret_key = secret_key
        self.seen_hashes = set()

    def _expected_signature(self, packet_hash: str, claimed_future_time: datetime) -> str:
        """Return the 16-hex-char signature expected for this hash/time pair."""
        message = f"{packet_hash}{claimed_future_time.isoformat()}".encode()
        if self.secret_key is None:
            digest = hashlib.sha256(message).hexdigest()
        else:
            digest = hmac.new(self.secret_key, message, hashlib.sha256).hexdigest()
        return digest[:16]

    def validate_pre_ack(self, packet_hash: str, claimed_future_time: datetime,
                         sensor_quorum: Set[int], signature: str) -> Tuple[bool, str]:
        """Return ``(accepted, reason)`` for one pre-ACK packet."""
        if packet_hash in self.seen_hashes:
            return False, "Replay detected"
        if not packet_hash.startswith('0' * self.difficulty):
            return False, "Invalid proof-of-work"
        # Take "now" in the caller's timezone regime: comparing a tz-aware
        # datetime with a naive one raises TypeError.
        now = datetime.now(claimed_future_time.tzinfo)
        if claimed_future_time < now:
            return False, "Pre-ACK from past, not future"
        if len(sensor_quorum) < self.quorum_required:
            return False, f"Insufficient quorum: {len(sensor_quorum)}/{self.quorum_required}"
        expected_sig = self._expected_signature(packet_hash, claimed_future_time)
        # Constant-time comparison avoids leaking the matching prefix via timing.
        if not isinstance(signature, str) or not hmac.compare_digest(
                signature.encode(), expected_sig.encode()):
            return False, "Invalid signature"
        self.seen_hashes.add(packet_hash)
        return True, "Valid pre-ACK accepted"

class QuantumSecurityShield:
    """Combine the phase, decoherence and Byzantine detectors into one shield.

    Each frame passed to :meth:`process_sensor_readings` runs through all three
    detectors. Flagged sensors are quarantined in ``blocked_sensors``, and every
    frame with a finding is appended to ``event_log`` as a ``QuantumThreat``.

    Quarantine: a sensor stays blocked for ``quarantine_duration`` (treated as
    seconds) after it was last flagged; re-flagging restarts its clock. If more
    than ``max_blocked`` sensors remain quarantined, the ``release_batch``
    longest-held ones are released, so the pool cannot stay saturated.
    """

    # Upper bound on concurrently quarantined sensors before the oldest are released.
    max_blocked: int = 50
    # Number of sensors released per cleanup pass when that bound is exceeded.
    release_batch: int = 10

    def __init__(self):
        self.phase_detector = PhaseAnomalyDetector()
        self.decoherence_detector = DecoherenceAttackDetector()
        self.consensus_filter = ByzantineConsensusFilter()
        self.handshake_validator = RetrocausalHandshakeValidator()
        self.event_log: List[QuantumThreat] = []
        self.blocked_sensors: Set[int] = set()
        # Quarantine length, interpreted as seconds by _clean_old_blocked.
        self.quarantine_duration = 60
        # When each quarantined sensor was most recently (re)blocked.
        self._blocked_at: Dict[int, datetime] = {}

    def _block(self, sensors: Set[int]) -> None:
        """Quarantine ``sensors``, (re)starting their quarantine clock."""
        now = datetime.now()
        for sid in sensors:
            sid = int(sid)
            self.blocked_sensors.add(sid)
            self._blocked_at[sid] = now

    async def process_sensor_readings(self, sensor_phases: np.ndarray,
                                      measured_lambda: float) -> Dict:
        """Run one frame through every detector and update quarantine/event state.

        Declared ``async`` for interface compatibility; it does not await anything.

        Returns:
            A dict with keys:

            - ``consensus_lambda``: the consensus computed by the Byzantine
              filter over ``sensor_phases`` (not over ``measured_lambda``).
            - ``valid_sensors``: how many sensors were counted as valid.
            - ``blocked_sensors``: how many sensors are currently blocked.
            - ``attack_detected``: whether spoofing or decoherence was found.
            - ``attack_type``: the decoherence kind, ``'phase_spoofing'`` or ``'none'``.
            - ``severity``: the larger of the two detector severities.
            - ``mitigation``: a description of the mitigation applied.
        """
        self.phase_detector.update(sensor_phases)
        self.decoherence_detector.update(measured_lambda, sensor_phases)
        spoofing_detected, spoofed_sensors, spoof_severity = self.phase_detector.detect_spoofing()
        decoherence_detected, deco_severity, attack_type = self.decoherence_detector.detect_attack()
        consensus, valid_sensors, malicious_sensors = self.consensus_filter.compute_consensus(sensor_phases)

        new_blocked = malicious_sensors - valid_sensors
        self._block(new_blocked)

        attack_detected = spoofing_detected or decoherence_detected
        severity = max(spoof_severity, deco_severity)
        mitigation = "none"
        if attack_detected:
            mitigation = self._apply_mitigation(spoofing_detected, decoherence_detected,
                                                spoofed_sensors, new_blocked)

        if attack_detected or malicious_sensors:
            if spoofing_detected:
                threat_type = ThreatType.ENTANGLEMENT_SPOOFING
            elif decoherence_detected:
                threat_type = ThreatType.DECOHERENCE_INJECTION
            else:
                # Only the Byzantine filter fired: outlier readings point at poisoned sensors.
                threat_type = ThreatType.SENSOR_POISONING
            now = datetime.now()
            self.event_log.append(QuantumThreat(
                threat_id=f"EVT-{now.timestamp()}",
                threat_type=threat_type,
                severity=ThreatSeverity.HIGH if severity > 0.6 else ThreatSeverity.MEDIUM,
                timestamp=now,
                affected_sensors=spoofed_sensors | malicious_sensors,
                coherence_impact=measured_lambda,
                # Copy so later mutation of the caller's array can't rewrite the log.
                attack_vector=np.array(sensor_phases),
                mitigation_applied=mitigation,
            ))

        self._clean_old_blocked()
        return {
            'consensus_lambda': consensus,
            'valid_sensors': len(valid_sensors),
            'blocked_sensors': len(self.blocked_sensors),
            'attack_detected': attack_detected,
            'attack_type': attack_type if decoherence_detected else 'phase_spoofing' if spoofing_detected else 'none',
            'severity': severity,
            'mitigation': mitigation
        }

    def _apply_mitigation(self, spoofing: bool, decoherence: bool,
                          spoofed: Set[int], malicious: Set[int]) -> str:
        """Quarantine spoofed sensors and return a comma-separated mitigation list.

        ``malicious`` is accepted for interface compatibility; those sensors are
        already quarantined by the caller.
        """
        mitigations = []
        if spoofing:
            self._block(spoofed)
            mitigations += ["isolated_spoofed_sensors", "reduced_retrocausal_window"]
        if decoherence:
            mitigations += ["increased_controller_gain", "activated_em_shielding"]
        return ", ".join(mitigations)

    def _clean_old_blocked(self):
        """Release sensors whose quarantine expired, then enforce ``max_blocked``."""
        now = datetime.now()
        quarantine = timedelta(seconds=self.quarantine_duration)
        # Sensors added to blocked_sensors directly (not via _block) start their clock now.
        for sid in self.blocked_sensors:
            self._blocked_at.setdefault(sid, now)
        released = {sid for sid in self.blocked_sensors
                    if now - self._blocked_at[sid] >= quarantine}
        still_blocked = self.blocked_sensors - released
        if len(still_blocked) > self.max_blocked:
            oldest_first = sorted(still_blocked, key=self._blocked_at.__getitem__)
            released.update(oldest_first[:self.release_batch])
        self.blocked_sensors -= released
        # Keep timestamps only for sensors that are still quarantined.
        self._blocked_at = {sid: t for sid, t in self._blocked_at.items()
                            if sid in self.blocked_sensors}

    def get_security_status(self) -> Dict:
        """Summarise quarantine, event-log and reputation state.

        ``system_integrity`` is the unblocked fraction of the consensus filter's
        sensors (0.0 when all are blocked). ``last_event`` is a shallow copy of
        the latest event's fields, so callers cannot mutate the logged record.
        """
        n_total = self.consensus_filter.n_sensors
        n_blocked = len(self.blocked_sensors)
        return {
            'blocked_sensors': n_blocked,
            'total_events': len(self.event_log),
            'last_event': dict(vars(self.event_log[-1])) if self.event_log else None,
            'system_integrity': max(0.0, 1.0 - n_blocked / n_total) if n_total else 0.0,
            'consensus_reliability': float(self.consensus_filter.sensor_reputation.mean())
        }
