# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/11_metrics.core.ipynb (unless otherwise specified).

__all__ = ['Metric', 'records2coco', 'coco_api_from_records', 'COCOMetric']

# Cell
from ..imports import *
from ..core import *
from ..models import *
from .coco_eval import CocoEvaluator
from pycocotools.coco import COCO

# Cell
class Metric:
    """Base class for metrics that need access to a model.

    Subclasses override `step` and `end`. A model is attached with
    `register_model` and read back through the `model` property.
    """

    def __init__(self):
        self._model = None

    def step(self, xb, yb, preds):
        """Consume inputs `xb`, targets `yb` and predictions `preds`.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    def end(self, outs):
        """Finalize the metric from `outs`. Must be overridden by subclasses."""
        raise NotImplementedError

    def register_model(self, model):
        """Store `model` so that it is returned by the `model` property."""
        self._model = model

    @property
    def model(self):
        """The registered model.

        Raises:
            RuntimeError: if `register_model` has not stored a model yet.
        """
        if not notnone(self._model):
            raise RuntimeError('Register a model with `register_model` before using the metric')
        return self._model

# Cell
def records2coco(records, catmap):
    """Convert `records` into a COCO-format dataset dictionary.

    Args:
        records: iterable of records. Each one is read through `r.iinfo`
            (`iid`, `fp`, `w`, `h`) and `r.annot`, whose items provide `oid`,
            `bbox` (`xywh`, `area`), `seg` (optional, with `to_erle`) and
            `iscrowd`.
        catmap: category map. Each `(id, obj)` pair of its `i2o` mapping becomes
            a `{'id': id, 'name': obj.name}` category entry.

    Returns:
        dict with the keys 'images', 'annotations' and 'categories'.
    """
    cats = [{'id': i, 'name': o.name} for i, o in catmap.i2o.items()]
    annots = defaultdict(list)
    iinfos = []
    # Annotation ids must start at 1, not 0. pycocotools' COCOeval records the
    # matched ground-truth id in zero-initialised arrays and treats 0 as
    # "unmatched", so a detection matched to a gt with id 0 would be counted as
    # a false positive. torchvision's `convert_to_coco_api` does the same.
    ann_id = 1
    for r in tqdm(records):
        iinfos.append({
            'id': r.iinfo.iid,
            'file_name': r.iinfo.fp.name,
            'width': r.iinfo.w,
            'height': r.iinfo.h,
        })
        for annot in r.annot:
            # NOTE: ids are unique only within this call (one running counter).
            annots['id'].append(ann_id)
            annots['image_id'].append(r.iinfo.iid)
            annots['category_id'].append(annot.oid)
            annots['bbox'].append(annot.bbox.xywh)
            annots['area'].append(annot.bbox.area)
            # TODO: for other types of masks
            # NOTE(review): 'segmentation' is filled only for annotations with a
            # `seg`, and `extend` adds as many entries as `to_erle` returns. If
            # only some annotations have a seg, or one yields several RLEs, the
            # length check below fails.
            if notnone(annot.seg): annots['segmentation'].extend(annot.seg.to_erle(r.iinfo.h, r.iinfo.w))
            annots['iscrowd'].append(annot.iscrowd)
            # TODO: Keypoints
            ann_id += 1
    # The column-wise lists are zipped into per-annotation dicts below, so they
    # must all have the same length.
    assert allequal(lmap(len, annots.values())), 'Mismatch lenght of elements'
    annots = [{k: v[j] for k, v in annots.items()} for j in range_of(annots['id'])]
    return {'images': iinfos, 'annotations': annots, 'categories': cats}

# Cell
def coco_api_from_records(records, catmap):
    """Build a `pycocotools.coco.COCO` object directly from `records`.

    The COCO object is created without an annotation file. The dict produced by
    `records2coco(records, catmap)` is assigned as its dataset, and
    `createIndex` is called to build COCO's lookup index over it.
    """
    coco = COCO()
    coco.dataset = records2coco(records, catmap)
    coco.createIndex()
    return coco

# Cell
def _get_iou_types(model):
    """Return the list of COCO IoU types to evaluate for `model`.

    Models wrapped in `DistributedDataParallel` or `DataParallel` are unwrapped
    through `.module` first. "bbox" is always included, and "segm" is added for
    `MaskRCNNModel` instances.

    Raises:
        NotImplementedError: for torchvision `KeypointRCNN` models, because
            keypoint evaluation is not supported yet.
    """
    wrappers = (torch.nn.parallel.DistributedDataParallel, torch.nn.DataParallel)
    model_without_ddp = model.module if isinstance(model, wrappers) else model
    iou_types = ["bbox"]
    if isinstance(model_without_ddp, MaskRCNNModel):
        iou_types.append("segm")
    if isinstance(model_without_ddp, torchvision.models.detection.KeypointRCNN):
        # TODO: support keypoints by appending "keypoints" to iou_types instead.
        raise NotImplementedError("Keypoint evaluation is not supported yet")
    return iou_types

# Cell
class COCOMetric(Metric):
    """COCO evaluation metric built on `CocoEvaluator`.

    The ground-truth COCO API is built once from `records`/`catmap`. A new
    `CocoEvaluator` is created when a model is registered and again at the end
    of every `end` call.
    """

    def __init__(self, records, catmap):
        super().__init__()
        self._coco_ds = coco_api_from_records(records, catmap)

    def register_model(self, model):
        """Register `model`, then build an evaluator for its IoU types."""
        super().register_model(model)
        self._create_coco_eval()

    def step(self, xb, yb, preds):
        """Add one batch of predictions to the evaluator.

        Every tensor in each prediction dict is moved to CPU. Each prediction
        is then keyed by `image_id.item()` of the target paired with it in `yb`.
        """
        # TODO: Implement batch_to_cpu helper function
        cpu = torch.device('cpu')
        cpu_preds = [{name: tensor.to(cpu) for name, tensor in pred.items()} for pred in preds]
        res = {target["image_id"].item(): pred for target, pred in zip(yb, cpu_preds)}
        self.coco_evaluator.update(res)

    def end(self, outs):
        """Synchronize, accumulate and summarize, then reset the evaluator."""
        evaluator = self.coco_evaluator
        evaluator.synchronize_between_processes()
        evaluator.accumulate()
        evaluator.summarize()
        # Start from an empty evaluator for the next pass.
        self._create_coco_eval()

    def _create_coco_eval(self):
        # Requires a registered model (`self.model` raises otherwise).
        iou_types = _get_iou_types(self.model)
        self.coco_evaluator = CocoEvaluator(self._coco_ds, iou_types)