
    .hS                        d dl mZ d dlZd dlmZ d dlmZ d dlZd dl	Z	d dl
mZmZmZ d dlmZ d dlmZmZmZ d dlmZ d d	lmZmZmZ d d
lmZ  G d de      Zy)    )annotationsN)Path)Any)build_dataloaderbuild_yolo_dataset	converter)BaseValidator)LOGGERnmsops)check_requirements)ConfusionMatrix
DetMetricsbox_iou)plot_imagesc                       e Zd ZdZdd fdZddZddZddZddZddZ	ddZ
dd	Zdd
Zd dZddZd!dZd"d#dZd$dZd%dZ	 d&	 	 	 	 	 	 	 	 	 d'dZd(dZd)dZd*dZd+dZ	 	 d,	 	 	 	 	 	 	 	 	 	 	 d-dZ xZS ).DetectionValidatora~  
    A class extending the BaseValidator class for validation based on a detection model.

    This class implements validation functionality specific to object detection tasks, including metrics calculation,
    prediction processing, and visualization of results.

    Attributes:
        is_coco (bool): Whether the dataset is COCO.
        is_lvis (bool): Whether the dataset is LVIS.
        class_map (list[int]): Mapping from model class indices to dataset class indices.
        metrics (DetMetrics): Object detection metrics calculator.
        iouv (torch.Tensor): IoU thresholds for mAP calculation.
        niou (int): Number of IoU thresholds.
        lb (list[Any]): List for storing ground truth labels for hybrid saving.
        jdict (list[dict[str, Any]]): List for storing JSON detection results.
        stats (dict[str, list[torch.Tensor]]): Dictionary for storing statistics during validation.

    Examples:
        >>> from ultralytics.models.yolo.detect import DetectionValidator
        >>> args = dict(model="yolo11n.pt", data="coco8.yaml")
        >>> validator = DetectionValidator(args=args)
        >>> validator()
    c                
   t         |   ||||       d| _        d| _        d| _        d| j
                  _        t        j                  ddd      | _	        | j                  j                         | _        t               | _        y)a  
        Initialize detection validator with necessary variables and settings.

        Args:
            dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
            save_dir (Path, optional): Directory to save results.
            args (dict[str, Any], optional): Arguments for the validator.
            _callbacks (list[Any], optional): List of callback functions.
        FNdetectg      ?gffffff?
   )super__init__is_cocois_lvis	class_mapargstasktorchlinspaceiouvnumelniour   metrics)self
dataloadersave_dirr   
_callbacks	__class__s        `/var/www/html/ai-service/venv/lib/python3.12/site-packages/ultralytics/models/yolo/detect/val.pyr   zDetectionValidator.__init__-   sg     	XtZ@!		NN3b1	IIOO%	!|    c                ^   |j                         D ]W  \  }}t        |t        j                        s!|j	                  | j
                  | j
                  j                  dk(        ||<   Y | j                  j                  r|d   j                         n|d   j                         dz  |d<   |S )z
        Preprocess batch of images for YOLO validation.

        Args:
            batch (dict[str, Any]): Batch containing images and annotations.

        Returns:
            (dict[str, Any]): Preprocessed batch.
        cuda)non_blockingimg   )
items
isinstancer   Tensortodevicetyper   halffloat)r$   batchkvs       r)   
preprocesszDetectionValidator.preprocess@   s     KKM 	VDAq!U\\*44$++:J:Jf:T4Ua	V 04yy~~e))+5<CUCUCW[^^er*   c                6   | j                   j                  | j                  j                  d      }t	        |t
              xrL d|v xrF |j                  t        j                   d      xs" |j                  t        j                   d      | _	        t	        |t
              xr d|v xr | j                   | _
        | j                  rt        j                         n*t        t        dt        |j                         dz               | _        | j                  xj$                  | j                  j&                  xr) | j                  xs | j                  xr | j(                   z  c_        |j                   | _        t        |j                         | _        t-        |dd      | _        d	| _        g | _        |j                   | j4                  _        t7        |j                   | j                  j8                  xr | j                  j:                  
      | _        y)z
        Initialize evaluation metrics for YOLO detection validation.

        Args:
            model (torch.nn.Module): Model to validate.
         cocozval2017.txtztest-dev2017.txtlvis   end2endFr   )namessave_matchesN)datagetr   splitr1   strendswithossepr   r   r   coco80_to_coco91_classlistrangelenrB   r   	save_jsonvaltrainingncgetattrrA   seenjdictr#   r   plots	visualizeconfusion_matrix)r$   modelrP   s      r)   init_metricszDetectionValidator.init_metricsP   s    iimmDIIOOR0sC  d#d45bP`Fa9b 	
 "#s+R#RdllBR?C||99;QUV[\]_bchcncn_ors_sVtQu		tyy}}e$,,2N$,,eX\XeXeTee[[
ekk"ui7	
"[[ /ekkPTPYPYP_P_Pwdhdmdmdwdw xr*   c                    ddz  S )zBReturn a formatted string summarizing class metrics of YOLO model.z%22s%11s%11s%11s%11s%11s%11s)ClassImages	InstanceszBox(PRmAP50z	mAP50-95) r$   s    r)   get_desczDetectionValidator.get_desch   s    #'kkkr*   c                   t        j                  || j                  j                  | j                  j                  | j                  j
                  dk(  rdn| j                  d| j                  j                  xs | j                  j                  | j                  j                  | j                  | j                  j
                  dk(  	      }|D cg c])  }|ddddf   |dddf   |dddf   |ddd	df   d
+ c}S c c}w )aN  
        Apply Non-maximum suppression to prediction outputs.

        Args:
            preds (torch.Tensor): Raw predictions from the model.

        Returns:
            (list[dict[str, torch.Tensor]]): Processed predictions after NMS, where each dict contains
                'bboxes', 'conf', 'cls', and 'extra' tensors.
        r   r   Tobb)rR   multi_labelagnosticmax_detrA   rotatedN         )bboxesconfclsextra)r   non_max_suppressionr   rn   iour   rR   
single_clsagnostic_nmsrh   rA   )r$   predsoutputsxs       r)   postprocesszDetectionValidator.postprocessl   s     ))IINNIIMMIINNh.qDGGYY))CTYY-C-CII%%LLIINNe+

 cjj]^1QU8Qq!tWQq!tWqQRTUTVQVxXjjjs   .Dc                H   |d   |k(  }|d   |   j                  d      }|d   |   }|d   |   }|d   j                  dd }|d	   |   }|j                  d
   r<t        j                  |      t	        j
                  || j                        g d   z  }||||||d   |   dS )a*  
        Prepare a batch of images and annotations for validation.

        Args:
            si (int): Batch index.
            batch (dict[str, Any]): Batch data containing images and annotations.

        Returns:
            (dict[str, Any]): Prepared batch with processed annotations.
        	batch_idxro   rm   	ori_shaper.      N	ratio_padr   )r4   )r@   r   r@   r   im_file)ro   rm   r|   imgszr~   r   )squeezeshaper   	xywh2xyxyr   tensorr4   )	r$   sir8   idxro   bboxr|   r   r~   s	            r)   _prepare_batchz!DetectionValidator._prepare_batch   s     K B&El3''+Xs#+&r*	e""12&+&r*	99Q<==&eDKK)PQ])^^D""Y'+
 	
r*   c                L    | j                   j                  r|dxx   dz  cc<   |S )a  
        Prepare predictions for evaluation against ground truth.

        Args:
            pred (dict[str, torch.Tensor]): Post-processed predictions from the model.

        Returns:
            (dict[str, torch.Tensor]): Prepared predictions in native space.
        ro   r   )r   rs   )r$   preds     r)   _prepare_predz DetectionValidator._prepare_pred   s$     99K1Kr*   c                8   t        |      D ]  \  }}| xj                  dz  c_        | j                  ||      }| j                  |      }|d   j	                         j                         }|d   j                  d   dk(  }| j                  j                  i | j                  ||      |t        j                  |      |rt        j                  d      n |d   j	                         j                         |rt        j                  d      n |d   j	                         j                         d       | j                  j                  rx| j                  j!                  ||| j                  j"                         | j                  j$                  r0| j                  j'                  |d   |   |d   | j(                         |r| j                  j*                  s| j                  j,                  r| j/                  ||      }	| j                  j*                  r| j1                  	|       | j                  j,                  s8| j3                  	| j                  j4                  |d	   | j(                  d
z  t7        |d         j8                   dz          y)z
        Update metrics with new predictions and ground truth.

        Args:
            preds (list[dict[str, torch.Tensor]]): List of predictions from the model.
            batch (dict[str, Any]): Batch data containing ground truth.
        r@   ro   r   rn   )
target_cls
target_imgrn   pred_cls)rn   r.   r   r|   labelsz.txtN)	enumeraterT   r   r   cpunumpyr   r#   update_stats_process_batchnpuniquezerosr   rV   rX   process_batchrn   rW   plot_matchesr&   rO   save_txtscale_predspred_to_jsonsave_one_txt	save_confr   stem)
r$   ru   r8   r   r   pbatchprednro   no_predpredn_scaleds
             r)   update_metricsz!DetectionValidator.update_metrics   s    "%( $	HBIINI((U3F&&t,E-##%++-CEl((+q0GLL%%))%8"%"$))C.+2BHHQKf8I8I8K8Q8Q8S/6E%L<L<L<N<T<T<V yy%%33E6		3W99&&))66uU|B7GPYIZ\`\i\ij yy""dii&8&8#//v>yy""!!,7yy!!!! II'';'MMH,$vi7H2I2N2N1Ot/TT	?$	r*   c                D   | j                   j                  r9dD ]4  }| j                  j                  | j                  || j
                         6 | j                  | j                  _        | j                  | j                  _        | j                  | j                  _        y)z8Set final values for metrics speed and confusion matrix.)TF)r&   	normalizeon_plotN)r   rV   rX   plotr&   r   speedr#   )r$   r   s     r)   finalize_metricsz#DetectionValidator.finalize_metrics   sv    99??( n	%%**DMMY`d`l`l*mn!ZZ(,(=(=% $r*   c                    | j                   j                  | j                  | j                  j                  | j
                         | j                   j                          | j                   j                  S )z
        Calculate and return metrics statistics.

        Returns:
            (dict[str, Any]): Dictionary containing metrics results.
        )r&   r   r   )r#   processr&   r   rV   r   clear_statsresults_dictrb   s    r)   	get_statszDetectionValidator.get_stats   sP     	dmm$))//SWS_S_`  "||(((r*   c                   ddt        | j                  j                        z  z   }t        j                  |d| j
                  | j                  j                  j                         g| j                  j                         z         | j                  j                  j                         dk(  r-t        j                  d| j                  j                   d       | j                  j                  r| j                  s| j                  dkD  rt        | j                  j                        rt!        | j                  j"                        D ]w  \  }}t        j                  || j$                  |   | j                  j&                  |   | j                  j                  |   g| j                  j)                  |      z         y yyyyy)	z0Print training/validation set metrics per class.z%22s%11i%11iz%11.3gallr   zno labels found in z, set, can not compute metrics without labelsr@   N)rN   r#   keysr
   inforT   nt_per_classsummean_resultswarningr   r   verboserQ   rR   statsr   ap_class_indexrB   nt_per_imageclass_result)r$   pfics       r)   print_resultsz DetectionValidator.print_results   s]    8c$,,2C2C.D#DDB%DLL,E,E,I,I,KjdllNgNgNijjk<<$$((*a/NN00@@lmn 99T]]tww{s4<<K]K]G^!$,,"="=> 	1

111!411!4 2215		 H_{]r*   c                Z   |d   j                   d   dk(  s|d   j                   d   dk(  r9dt        j                  |d   j                   d   | j                  ft              iS t        |d   |d         }d| j                  |d   |d   |      j                         j                         iS )a  
        Return correct prediction matrix.

        Args:
            preds (dict[str, torch.Tensor]): Dictionary containing prediction data with 'bboxes' and 'cls' keys.
            batch (dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' and 'cls' keys.

        Returns:
            (dict[str, np.ndarray]): Dictionary containing 'tp' key with correct prediction matrix of shape (N, 10) for 10 IoU levels.
        ro   r   tpdtyperm   )	r   r   r   r"   boolr   match_predictionsr   r   )r$   ru   r8   rr   s       r)   r   z!DetectionValidator._process_batch  s     <a A%u););A)>!)C"((E%L$6$6q$9499#ETRSSeHouX7d,,U5\5<MQQSYY[\\r*   c                `    t        | j                  ||| j                  || j                        S )al  
        Build YOLO Dataset.

        Args:
            img_path (str): Path to the folder containing images.
            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
            batch (int, optional): Size of batches, this is for `rect`.

        Returns:
            (Dataset): YOLO dataset.
        )modestride)r   r   rD   r   )r$   img_pathr   r8   s       r)   build_datasetz DetectionValidator.build_dataset  s(     "$))Xudiid[_[f[fggr*   c           	         | j                  ||d      }t        ||| j                  j                  dd| j                  j                  | j
                        S )a   
        Construct and return dataloader.

        Args:
            dataset_path (str): Path to the dataset.
            batch_size (int): Size of each batch.

        Returns:
            (torch.utils.data.DataLoader): Dataloader for validation.
        rP   )r8   r   Fr{   )shufflerank	drop_last
pin_memory)r   r   r   workerscompilerQ   )r$   dataset_path
batch_sizedatasets       r)   get_dataloaderz!DetectionValidator.get_dataloader"  sU     $$\%$PIIii''}}
 	
r*   c                t    t        ||d   | j                  d| dz  | j                  | j                         y)z
        Plot validation image samples.

        Args:
            batch (dict[str, Any]): Batch containing images and annotations.
            ni (int): Batch index.
        r   	val_batchz_labels.jpg)r   pathsfnamerB   r   N)r   r&   rB   r   )r$   r8   nis      r)   plot_val_samplesz#DetectionValidator.plot_val_samples8  s:     		"--IbT"==**LL	
r*   c                   t        |      D ]#  \  }}t        j                  |d         |z  |d<   % |d   j                         }|xs | j                  j
                  }|D 	ci c].  }|t        j                  |D 	cg c]
  }	|	|   d|  c}	d      0 }
}}	t        j                  |
d   ddddf         |
d   ddddf<   t        |d   |
|d	   | j                  d
| dz  | j                  | j                         yc c}	w c c}	}w )au  
        Plot predicted bounding boxes on input images and save the result.

        Args:
            batch (dict[str, Any]): Batch containing images and annotations.
            preds (list[dict[str, torch.Tensor]]): List of predictions from the model.
            ni (int): Batch index.
            max_det (Optional[int]): Maximum number of detections to plot.
        rn   rz   r   Ndimrm   rj   r.   r   r   z	_pred.jpg)imagesr   r   r   rB   r   )r   r   	ones_liker   r   rh   catr   	xyxy2xywhr   r&   rB   r   )r$   r8   ru   r   rh   r   r   r   r9   rw   batched_predss              r)   plot_predictionsz#DetectionValidator.plot_predictionsH  s    !' 	BGAt %V = AD	BQx}}.TYY..W[\RSEIIu&E!qtHW~&E1MM\\),}X7NqRTSTRTu7U)Vh2A2&< 	"--IbT";;**LL	
 'F\s   $D	=DD	D	c                2   ddl m}  |t        j                  |d   |d   ft        j                        d| j
                  t        j                  |d   |d   j                  d      |d	   j                  d      gd
            j                  ||       y)a  
        Save YOLO detections to a txt file in normalized coordinates in a specific format.

        Args:
            predn (dict[str, torch.Tensor]): Dictionary containing predictions with keys 'bboxes', 'conf', and 'cls'.
            save_conf (bool): Whether to save confidence scores.
            shape (tuple[int, int]): Shape of the original image (height, width).
            file (Path): File path to save the detections.
        r   )Resultsr@   r   Nrm   rn   r{   ro   r   )pathrB   boxes)r   )
ultralytics.engine.resultsr   r   r   uint8rB   r   r   	unsqueezer   )r$   r   r   r   filer   s         r)   r   zDetectionValidator.save_one_txte  s     	7HHeAha):**))U8_eFm.E.Eb.I5QV<KaKabdKeflmn		

 (49(
-r*   c                <   t        |d         }|j                  }|j                         rt        |      n|}t	        j
                  |d         }|ddddfxx   |ddddf   dz  z  cc<   t        |j                         |d   j                         |d   j                               D ]i  \  }}}	| j                  j                  ||j                  | j                  t        |	         |D 
cg c]  }
t        |
d       c}
t        |d      d	       k yc c}
w )
a  
        Serialize YOLO predictions to COCO json format.

        Args:
            predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys
                with bounding box coordinates, confidence scores, and class predictions.
            pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.

        Examples:
             >>> result = {
             ...     "image_id": 42,
             ...     "file_name": "42.jpg",
             ...     "category_id": 18,
             ...     "bbox": [258.15, 41.29, 348.26, 243.78],
             ...     "score": 0.236,
             ... }
        r   rm   Nr}   rn   ro      rk   )image_id	file_namecategory_idr   score)r   r   	isnumericintr   r   ziptolistrU   appendnamer   round)r$   r   r   r   r   r   boxbsr   rw   s              r)   r   zDetectionValidator.pred_to_jsonx  s    $ F9%&yy $ 03t9dmmE(O,ArrE
c!QR%j1n$
3::<v)=)=)?uATATAVW 		GAq!JJ (!%#'>>#a&#9234QU1a[4"1a[		 5s   /Dc           	     t    i |dt        j                  |d   |d   j                         |d   |d         iS )z.Scales predictions to the original image size.rm   r   r|   r~   )r~   )r   scale_boxesclone)r$   r   r   s      r)   r   zDetectionValidator.scale_preds  sN    

coowh%%'{# -	
 	
r*   c                    | j                   dz  }| j                  d   dz  | j                  rdnd| j                  j                   dz  }| j                  |||      S )a  
        Evaluate YOLO output in JSON format and return performance statistics.

        Args:
            stats (dict[str, Any]): Current statistics dictionary.

        Returns:
            (dict[str, Any]): Updated statistics dictionary with COCO/LVIS evaluation results.
        zpredictions.jsonr   r   zinstances_val2017.jsonlvis_v1_z.json)r&   rD   r   r   rF   coco_evaluate)r$   r   	pred_json	anno_jsons       r)   	eval_jsonzDetectionValidator.eval_json  sh     MM$66	IIf+/<<'x		GXX]=^` 	
 !!%I>>r*   c           	     
   | j                   j                  r7| j                  s| j                  rt	        | j
                        rt        j                  d| d| d       	 ||fD ]  }|j                         rJ | d        t        |t              r|gn|}t        |t              r|gn|}t        d       ddlm}m}  ||      }	|	j                  |      }
t!        |      D ]M  \  }} ||	|
|| j                  t        j                        }| j"                  j$                  j&                  D cg c]   }t)        t+        |      j,                        " c}|j.                  _        |j3                          |j5                          |j7                          |j8                  d	   |d
||   d    d<   |j8                  d   |d||   d    d<   | j                  s|j8                  d   |d||   d    d<   |j8                  d   |d||   d    d<   |j8                  d   |d||   d    d<   P | j                  r|d   |d<   |S |S c c}w # t:        $ r#}t        j<                  d|        Y d}~|S d}~ww xY w)az  
        Evaluate COCO/LVIS metrics using faster-coco-eval library.

        Performs evaluation using the faster-coco-eval library to compute mAP metrics
        for object detection. Updates the provided stats dictionary with computed metrics
        including mAP50, mAP50-95, and LVIS-specific metrics if applicable.

        Args:
            stats (dict[str, Any]): Dictionary to store computed metrics and statistics.
            pred_json (str | Path]): Path to JSON file containing predictions in COCO format.
            anno_json (str | Path]): Path to JSON file containing ground truth annotations in COCO format.
            iou_types (str | list[str]]): IoU type(s) for evaluation. Can be single string or list of strings.
                Common values include "bbox", "segm", "keypoints". Defaults to "bbox".
            suffix (str | list[str]]): Suffix to append to metric names in stats dictionary. Should correspond
                to iou_types if multiple types provided. Defaults to "Box".

        Returns:
            (dict[str, Any]): Updated stats dictionary containing the computed COCO/LVIS evaluation metrics.
        z'
Evaluating faster-coco-eval mAP using z and z...z file not foundzfaster-coco-eval>=1.6.7r   )COCOCOCOeval_faster)iouType
lvis_styleprint_functionAP_50zmetrics/mAP50()AP_allzmetrics/mAP50-95(APrzmetrics/APr(APczmetrics/APc(APfzmetrics/APf(zmetrics/mAP50-95(B)fitnessz faster-coco-eval unable to run: N)r   rO   r   r   rN   rU   r
   r   is_filer1   rG   r   faster_coco_evalr	  r
  loadResr   r%   r   im_filesr   r   r   paramsimgIdsevaluate
accumulate	summarizestats_as_dict	Exceptionr   )r$   r   r  r  	iou_typessuffixrw   r	  r
  annor   r   iou_typerP   es                  r)   r  z DetectionValidator.coco_evaluate  sw   6 99DLLDLLc$**oKKB9+US\R]]`abG"I- >A99;=1#_(==;>+5i+EYK9	%/%<&&"#<=BI||I.#,Y#7 YKAx)dH^d^i^iC EIOOD[D[DdDd(eqT!W\\):(eCJJ%LLNNN$MMO ?B>O>OPW>XEN6!9Q<.:;ADARARS[A\E-fQil^1=>||@C@Q@QRW@XVAYq\N!<=@C@Q@QRW@XVAYq\N!<=@C@Q@QRW@XVAYq\N!<=!Y$ <<',-B'CE)$ u% )f   G!A!EFFGs9   $I <B:I 6%IBI "A+I I 	JI==J)NNNN)returnNone)r8   dict[str, Any]r%  r'  )rY   ztorch.nn.Moduler%  r&  )r%  rG   )ru   ztorch.Tensorr%  list[dict[str, torch.Tensor]])r   r   r8   r'  r%  r'  )r   dict[str, torch.Tensor]r%  r)  )ru   r(  r8   r'  r%  r&  )r%  r'  )ru   r)  r8   r'  r%  zdict[str, np.ndarray])rP   N)r   rG   r   rG   r8   
int | Noner%  ztorch.utils.data.Dataset)r   rG   r   r   r%  ztorch.utils.data.DataLoader)r8   r'  r   r   r%  r&  )N)
r8   r'  ru   r(  r   r   rh   r*  r%  r&  )
r   r)  r   r   r   ztuple[int, int]r   r   r%  r&  )r   r)  r   r'  r%  r&  )r   r)  r   r'  r%  r)  )r   r'  r%  r'  )r   Box)r   r'  r  rG   r  rG   r   str | list[str]r!  r,  r%  r'  )__name__
__module____qualname____doc__r   r;   rZ   rc   rx   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r  r  __classcell__)r(   s   @r)   r   r      s    0$& y0lk0
8,\.	)(] h
,
" ko
#
,I
OR
]g
	
:.& D

?. &,"'== = 	=
 #=  = 
=r*   r   )
__future__r   rI   pathlibr   typingr   r   r   r   ultralytics.datar   r   r   ultralytics.engine.validatorr	   ultralytics.utilsr
   r   r   ultralytics.utils.checksr   ultralytics.utils.metricsr   r   r   ultralytics.utils.plottingr   r   ra   r*   r)   <module>r;     sB    # 	     L L 6 . . 7 J J 2a ar*   