import cv2
import numpy as np
import base64
import asyncio
import requests
import logging
import time
from threading import Thread, Lock
from queue import Queue
from fastapi import FastAPI, HTTPException, BackgroundTasks
from pydantic import BaseModel
from typing import List, Dict, Any, Optional
from ultralytics import YOLO
import json

# Configuration du logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

app = FastAPI(title="Drone Surveillance API - Système Minier")

# Configuration
class Config:
    LARAVEL_API_URL = "http://localhost:8000/api"  # URL de votre Laravel
    LARAVEL_API_KEY = "votre-api-key-secrete"       # Clé API pour Laravel
    CONFIDENCE_THRESHOLD = 0.7
    RTSP_TIMEOUT = 10
    MAX_RETRIES = 3

# Liste des classes pertinentes et leur mapping
# Ceci rend le filtrage et le renommage plus clair
RELEVANT_CLASSES_MAPPING = {
    'person': 'person',
    'car': 'vehicle',
    'truck': 'heavy_equipment',
    'bus': 'heavy_equipment',
    # Ajouter d'autres classes si nécessaire, par ex. 'excavator', 'dump_truck'
    # 'fire': 'fire',
    # 'smoke': 'smoke'
}

# Chargement du modèle YOLO
# Utilisation de torch.device pour vérifier la disponibilité du GPU
try:
    model = YOLO('yolov8n.pt')
    logger.info("✅ Modèle YOLO chargé avec succès")
    logger.info(f"📊 Device: {model.device}")
except Exception as e:
    logger.error(f"❌ Erreur chargement modèle: {e}")
    model = None

# Modèles Pydantic
class DetectionRequest(BaseModel):
    image_data: str
    drone_id: int
    zone_id: int
    gps_coordinates: Dict[str, float]
    timestamp: str
    confidence_threshold: float = 0.7
    metadata: Dict[str, Any] = {}

class BoundingBox(BaseModel):
    x: int
    y: int
    width: int
    height: int

class DetectionResult(BaseModel):
    class_name: str = Field(..., description="Nom de la classe détectée (ex: 'person', 'vehicle')")
    confidence: float = Field(..., ge=0, le=1, description="Niveau de confiance de la détection")
    bbox: BoundingBox = Field(..., description="Coordonnées de la boîte englobante")

class DetectionRequest(BaseModel):
    image_data: str = Field(..., description="Données de l'image encodées en Base64")
    drone_id: int
    zone_id: int
    gps_coordinates: Dict[str, float]
    timestamp: str
    confidence_threshold: float = Field(0.7, ge=0, le=1, description="Seuil de confiance pour filtrer les détections")
    metadata: Dict[str, Any] = Field({}, description="Métadonnées additionnelles")

class DetectionResponse(BaseModel):
    detections: List[DetectionResult] = Field([], description="Liste des objets détectés")
    processing_time: float = Field(..., description="Temps de traitement de la requête en secondes")
    frame_id: str = Field(..., description="Identifiant unique du cadre, souvent le timestamp")
    drone_id: int = Field(..., description="Identifiant du drone émetteur")
    zone_id: int = Field(..., description="Identifiant de la zone surveillée")
    gps_coordinates: Dict[str, float] = Field({}, description="Coordonnées GPS du drone")

# Gestionnaire de flux RTSP
class RTSPStreamManager:
    def __init__(self):
        self.streams: Dict[int, 'RTSPStream'] = {}
        self.lock = Lock()
    
    def add_stream(self, drone_id: int, rtsp_url: str):
        with self.lock:
            if drone_id in self.streams:
                self.streams[drone_id].stop()
            
            self.streams[drone_id] = RTSPStream(drone_id, rtsp_url)
            self.streams[drone_id].start()
            logger.info(f"📹 Stream ajouté pour drone {drone_id}: {rtsp_url}")
    
    def remove_stream(self, drone_id: int):
        with self.lock:
            if drone_id in self.streams:
                self.streams[drone_id].stop()
                del self.streams[drone_id]
                logger.info(f"🛑 Stream arrêté pour drone {drone_id}")
    
    def get_frame(self, drone_id: int) -> Optional[np.ndarray]:
        with self.lock:
            if drone_id in self.streams:
                return self.streams[drone_id].get_latest_frame()
        return None

# Flux RTSP individuel
class RTSPStream:
    def __init__(self, drone_id: int, rtsp_url: str):
        self.drone_id = drone_id
        self.rtsp_url = rtsp_url
        self.cap = None
        self.latest_frame = None
        self.lock = Lock()
        self.is_running = False
        self.thread = None
        self.frame_queue = Queue(maxsize=2)  # Garder seulement 2 frames
    
    def start(self):
        """Démarre la capture RTSP"""
        self.is_running = True
        self.thread = Thread(target=self._capture_worker, daemon=True)
        self.thread.start()
    
    def stop(self):
        """Arrête la capture RTSP"""
        self.is_running = False
        if self.cap:
            self.cap.release()
        if self.thread:
            self.thread.join(timeout=5)
    
    def _capture_worker(self):
        """Worker de capture avec reconnexion automatique"""
        retry_count = 0
        
        while self.is_running and retry_count < Config.MAX_RETRIES:
            try:
                logger.info(f"🔗 Connexion RTSP pour drone {self.drone_id}: {self.rtsp_url}")
                
                self.cap = cv2.VideoCapture(self.rtsp_url)
                self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
                self.cap.set(cv2.CAP_PROP_FPS, 15)  # Réduire à 15 FPS pour stabilité
                
                if not self.cap.isOpened():
                    raise Exception("Impossible d'ouvrir le flux RTSP")
                
                logger.info(f"✅ Flux RTSP connecté pour drone {self.drone_id}")
                retry_count = 0  # Reset counter on success
                
                while self.is_running:
                    ret, frame = self.cap.read()
                    
                    if not ret:
                        logger.warning(f"📹 Frame perdue pour drone {self.drone_id}")
                        break
                    
                    # Mettre à jour la dernière frame
                    with self.lock:
                        self.latest_frame = frame.copy()
                    
                    # Limiter le FPS pour éviter la surcharge
                    time.sleep(0.067)  # ~15 FPS
                
            except Exception as e:
                logger.error(f"❌ Erreur flux drone {self.drone_id}: {e}")
                retry_count += 1
                
                if self.cap:
                    self.cap.release()
                    self.cap = None
                
                if retry_count < Config.MAX_RETRIES:
                    logger.info(f"🔄 Reconnexion dans 5s... ({retry_count}/{Config.MAX_RETRIES})")
                    time.sleep(5)
                else:
                    logger.error(f"💥 Abandon après {Config.MAX_RETRIES} tentatives")
                    break
    
    def get_latest_frame(self) -> Optional[np.ndarray]:
        """Récupère la dernière frame disponible"""
        with self.lock:
            return self.latest_frame.copy() if self.latest_frame is not None else None

# Gestionnaire global
stream_manager = RTSPStreamManager()

# Fonctions utilitaires
def encode_image_to_base64(image: np.ndarray) -> str:
    """Encode une image OpenCV en base64"""
    try:
        _, buffer = cv2.imencode('.jpg', image, [cv2.IMWRITE_JPEG_QUALITY, 80])
        return base64.b64encode(buffer).decode('utf-8')
    except Exception as e:
        logger.error(f"Erreur encodage image: {e}")
        raise

def send_to_laravel(drone_id: int, frame_data: str, detections: List[Dict], metadata: Dict):
    """Envoie les détections à l'API Laravel"""
    try:
        payload = {
            'drone_id': drone_id,
            'frame_data': frame_data,
            'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
            'detections': detections,
            'metadata': metadata
        }
        
        headers = {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {Config.LARAVEL_API_KEY}',
            'X-API-Key': Config.LARAVEL_API_KEY
        }
        
        # CHOISIR UN SEUL ENDPOINT - Je recommande celui-ci :
        response = requests.post(
            f"{Config.LARAVEL_API_URL}/detection/process-frame",
            json=payload,
            headers=headers,
            timeout=10
        )
        
        if response.status_code == 200:
            logger.debug(f"✅ Détections envoyées à Laravel pour drone {drone_id}")
        else:
            logger.warning(f"⚠️ Erreur envoi Laravel: {response.status_code} - {response.text}")
            
    except requests.exceptions.RequestException as e:
        logger.error(f"❌ Erreur connexion Laravel: {e}")
    except Exception as e:
        logger.error(f"❌ Erreur inattendue envoi Laravel: {e}")

def process_detections(image: np.ndarray, drone_id: int, zone_id: int) -> List[Dict]:
    """Traite une image et retourne les détections"""
    if model is None:
        return []
    
    try:
        # Inférence YOLO
        results = model(image, conf=Config.CONFIDENCE_THRESHOLD, verbose=False)
        
        detections = []
        for result in results:
            boxes = result.boxes
            if boxes is not None:
                for box in boxes:
                    class_name = model.names[int(box.cls)]
                    
                    # Filtrage des classes pertinentes
                    if class_name in ['person', 'car', 'truck', 'bus', 'fire', 'smoke']:
                        detection = {
                            'class_name': class_name,
                            'confidence': float(box.conf),
                            'bbox': {
                                'x': int(box.xywh[0][0] - box.xywh[0][2]/2),
                                'y': int(box.xywh[0][1] - box.xywh[0][3]/2),
                                'width': int(box.xywh[0][2]),
                                'height': int(box.xywh[0][3])
                            }
                        }
                        detections.append(detection)
        
        return detections
        
    except HTTPException as e:
        # Re-lever l'exception pour qu'elle soit gérée par FastAPI
        raise e
    except Exception as e:
        logger.error(f"❌ Erreur traitement image: {e}")
        return []

# Background task pour le traitement continu
def process_drone_stream(drone_id: int, rtsp_url: str, zone_id: int):
    """Tâche de fond pour traiter le flux d'un drone"""
    logger.info(f"🚀 Démarrage traitement flux drone {drone_id}")
    
    # Ajouter le stream au manager
    stream_manager.add_stream(drone_id, rtsp_url)
    
    frame_count = 0
    while True:
        try:
            # Récupérer la dernière frame
            frame = stream_manager.get_frame(drone_id)
            
            if frame is not None:
                # Traiter une frame sur 3 pour réduire la charge (∼5 FPS)
                if frame_count % 3 == 0:
                    # Effectuer la détection
                    detections = process_detections(frame, drone_id, zone_id)
                    
                    if detections:
                        # Encoder et envoyer à Laravel
                        frame_data = encode_image_to_base64(frame)
                        metadata = {
                            'width': frame.shape[1],
                            'height': frame.shape[0],
                            'frame_count': frame_count
                        }
                        
                        send_to_laravel(drone_id, frame_data, detections, metadata)
                        logger.info(f"📊 Drone {drone_id}: {len(detections)} détections")
                
                frame_count += 1
            
            # Pause pour éviter la surcharge
            time.sleep(0.2)  # 5 FPS max
            
        except Exception as e:
            logger.error(f"❌ Erreur traitement drone {drone_id}: {e}")
            time.sleep(5)

# Endpoints FastAPI
@app.post("/api/v1/detect")
async def detect_objects(request: DetectionRequest):
    """Endpoint pour détection unique"""
    start_time = time.time()
    
    if model is None:
        raise HTTPException(status_code=500, detail="Modèle non chargé")
    
    try:
        # Décoder l'image
        image_bytes = base64.b64decode(request.image_data)
        nparr = np.frombuffer(image_bytes, np.uint8)
        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        
        if image is None:
            raise HTTPException(status_code=400, detail="Image invalide")
        
        # Effectuer la détection
        detections = process_detections(image, request.drone_id, request.zone_id)
        processing_time = time.time() - start_time
        
        logger.info(f"🎯 Détection unique - Drone {request.drone_id}: {len(detections)} objets")
        
        return {
            "detections": detections,
            "processing_time": processing_time,
            "frame_id": request.timestamp
        }
        
    except Exception as e:
        logger.error(f"❌ Erreur détection: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/api/v1/stream/start")
async def start_drone_stream(drone_id: int, rtsp_url: str, zone_id: int = 1, background_tasks: BackgroundTasks = None):
    """Démarre le traitement continu d'un flux drone"""
    try:
        # Lancer le traitement en arrière-plan
        background_tasks.add_task(process_drone_stream, drone_id, rtsp_url, zone_id)
        
        return {
            "status": "success",
            "message": f"Traitement flux démarré pour drone {drone_id}",
            "drone_id": drone_id,
            "rtsp_url": rtsp_url
        }
        
    except Exception as e:
        logger.error(f"❌ Erreur démarrage flux: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/api/v1/stream/stop/{drone_id}")
async def stop_drone_stream(drone_id: int):
    """Arrête le traitement d'un flux drone"""
    try:
        stream_manager.remove_stream(drone_id)
        
        return {
            "status": "success", 
            "message": f"Flux arrêté pour drone {drone_id}",
            "drone_id": drone_id
        }
        
    except Exception as e:
        logger.error(f"❌ Erreur arrêt flux: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
async def health_check():
    """Endpoint de santé"""
    return {
        "status": "healthy" if model is not None else "unhealthy",
        "model_loaded": model is not None,
        "active_streams": len(stream_manager.streams),
        "timestamp": time.strftime('%Y-%m-%d %H:%M:%S')
    }

@app.get("/streams/status")
async def streams_status():
    """Statut des flux actifs"""
    streams_info = {}
    for drone_id, stream in stream_manager.streams.items():
        streams_info[drone_id] = {
            "is_running": stream.is_running,
            "has_frame": stream.latest_frame is not None
        }
    
    return {"active_streams": streams_info}

# Point d'entrée pour le développement
if __name__ == "__main__":
    import uvicorn
    
    logger.info("🚀 Démarrage API de Surveillance Drone...")
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=5000,
        log_level="info"
    )
