# -*- coding: utf-8 -*-
"""ArrowSynth_InstSegTrain_TrainPercentage_Comittable.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1T-LSaqJVHsBV4tG0wueeICywTIrJCHSS
"""

import os
from PIL import Image
import numpy as np
from pycocotools import mask
from pycocotools.coco import COCO
from pycocotools import coco
import json
import re
import torch

from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data.datasets import register_coco_instances
from detectron2.data import transforms as T
from detectron2.engine import DefaultTrainer
from detectron2.data.dataset_mapper import DatasetMapper
from detectron2.data import build_detection_train_loader
from detectron2.engine import DefaultPredictor
from detectron2.data import MetadataCatalog
from detectron2.data import DatasetCatalog
from detectron2.utils.logger import setup_logger

# --- Runtime / dataset configuration (module-level side effects) ---

# Number of CUDA devices visible to this process; informational only here
# (the Detectron2 trainer below picks up devices on its own).
num_gpus = torch.cuda.device_count()
print(f"num_gpus: {num_gpus}")

# Dataset family to train on; each sample lives in its own sub-directory of
# `data_root`, containing `imgs/` (frames) and `labels/` (instance maps).
sample_category = "ArrowSynthTurns"
data_root = f"/home/szage/runs/data/ArrowSynth/{sample_category}_train_samples"
# Fraction of the (sorted) samples to train on; also scales the solver
# schedule (MAX_ITER / STEPS) further down.
train_percentage = 1.0
# Sorted so the sample order — and hence the train subset — is deterministic.
sample_names = sorted(os.listdir(data_root))

# Function to create COCO annotations with RLE encoding
def create_coco_annotations(image_id, instance_masks, image_size, category_id=1):
    """Build a list of COCO annotation dicts for one image.

    Parameters
    ----------
    image_id : int
        COCO image id the annotations belong to.
    instance_masks : list of np.ndarray
        One binary (H, W) uint8 mask per object instance.
    image_size : int
        Total pixel count of the image. Kept for interface compatibility;
        unused, because area and bbox are derived from the masks themselves.
    category_id : int, optional
        COCO category id assigned to every instance (default 1, 'arrow').

    Returns
    -------
    list of dict
        COCO-style annotation records with JSON-serializable RLE segmentation.
    """
    annotations = []
    for i, binary_mask in enumerate(instance_masks):
        # pycocotools expects Fortran-ordered arrays for RLE encoding.
        rle_mask = mask.encode(np.asfortranarray(binary_mask))
        # `counts` comes back as bytes; decode so json.dump can serialize it.
        rle_mask['counts'] = rle_mask['counts'].decode('utf-8')
        annotations.append({
            # Unique id: image id concatenated with a 1-based, zero-padded
            # instance index (assumes fewer than 10000 instances per image).
            'id': int(f"{image_id}{i + 1:04d}"),
            'image_id': image_id,
            'category_id': category_id,
            'segmentation': rle_mask,
            'area': int(np.sum(binary_mask)),
            'bbox': mask.toBbox(rle_mask).tolist(),
            'iscrowd': 0,
        })
    return annotations

# Function to convert your data to MS-COCO format
def convert_to_coco_format(data_root, sample_name):
    """Convert one sample folder (imgs/ + labels/) to an MS-COCO JSON file.

    Expects ``<data_root>/<sample_name>/imgs/<sample_name>_NNN_img.png``
    frames with matching ``labels/<sample_name>_NNN_label.png`` grayscale
    instance maps (pixel value == instance id, 0 == background). Writes
    ``<sample_name>_coco.json`` into the sample folder.

    Parameters
    ----------
    data_root : str
        Root directory containing one sub-directory per sample.
    sample_name : str
        Name of the sample sub-directory to convert.

    Raises
    ------
    ValueError
        If the set of image ids is not covered by the ground-truth ids.
    """
    video_folder = os.path.join(data_root, sample_name, 'imgs')
    gt_folder = os.path.join(data_root, sample_name, 'labels')
    output_file = os.path.join(data_root, sample_name, sample_name + '_coco.json')

    image_files = sorted([file for file in os.listdir(video_folder) if file.endswith('.png')])
    IDs = sorted([file[0:7] for file in image_files])

    track_files = sorted([file for file in os.listdir(gt_folder) if file.endswith('.png')])
    track_IDs = sorted([file[0:7] for file in track_files])
    if len(set(IDs) - set(track_IDs)) != 0:
        # Bug fix: report the GT *count* — the original interpolated the raw
        # track_IDs list where a number was clearly intended.
        raise ValueError(f"Mismatch in input and GT sample number: Input ({len(IDs)}), GT ({len(track_IDs)})")

    coco_data = {
        'images': [],
        'annotations': [],
        'categories': [{'id': 1, 'name': 'arrow'}],
    }

    image_id = 1

    for frame_number in range(len(IDs)):
        frame_name = f'{sample_name}_{frame_number:03d}_img.png'
        image_path = os.path.join(video_folder, frame_name)
        gt_path = os.path.join(gt_folder, f'{sample_name}_{frame_number:03d}_label.png')

        if os.path.exists(image_path) and os.path.exists(gt_path):
            # Only the dimensions are needed from the RGB frame; PIL reads
            # them from the header without decoding the pixel data. The
            # context manager also closes the file handle (the original
            # leaked it).
            with Image.open(image_path) as img:
                width, height = img.size

            # Grayscale instance map: one gray level per object instance.
            with Image.open(gt_path) as gt_img:
                instance_mask = np.array(gt_img.convert('L'))
            unique_instances = np.unique(instance_mask)
            unique_instances = unique_instances[unique_instances > 0]  # drop background (0)
            instance_masks = [np.where(instance_mask == instance, 1, 0).astype(np.uint8) for instance in unique_instances]

            coco_data['images'].append({
                'id': image_id,
                'file_name': frame_name,
                'width': width,
                'height': height,
                'coco_url': image_path,
            })

            annotations = create_coco_annotations(image_id, instance_masks, width * height)
            coco_data['annotations'].extend(annotations)

            image_id += 1

    with open(output_file, 'w') as json_file:
        json.dump(coco_data, json_file)

    print(f'{sample_name}_coco.json created')

# Generate a COCO JSON for every sample folder (writes files into data_root).
for sample_name in sample_names:
  convert_to_coco_format(data_root, sample_name)

# Drop any datasets registered by a previous run (e.g. notebook re-execution),
# otherwise register_coco_instances would raise on duplicate names.
DatasetCatalog.clear()
MetadataCatalog.clear()

# Train on the first `train_percentage` fraction of the sorted samples.
train_samples = tuple(sample_names)
train_samples = train_samples[:int(len(train_samples) * train_percentage)]

# Register *all* samples with Detectron2 (each sample is its own named
# dataset); cfg.DATASETS.TRAIN below selects only `train_samples`.
for sample_name in sample_names:
  register_coco_instances(sample_name, {}, os.path.join(data_root,sample_name,sample_name+"_coco.json"), os.path.join(data_root,sample_name, 'imgs'))

setup_logger()

#https://stackoverflow.com/questions/67061435/how-to-train-detectron2-model-with-multiple-custom-dataset
# NOTE(review): most Detectron2 recipes call
# cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
# before setting WEIGHTS; here only the checkpoint URL is set, so the
# architecture comes from the get_cfg() defaults — confirm this is intended.
cfg = get_cfg()
# Start from COCO-pretrained Mask R-CNN R50-FPN 3x weights.
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
cfg.MODEL.MASK_ON = True
cfg.DATASETS.TRAIN = train_samples
# Masks in the JSONs are RLE dicts, not polygons, so "bitmask" is required.
cfg.INPUT.MASK_FORMAT = "bitmask"
cfg.DATALOADER.NUM_WORKERS = 4

cfg.SOLVER.IMS_PER_BATCH = 4
cfg.SOLVER.BASE_LR = 0.001
cfg.SOLVER.GAMMA = 0.1
# Schedule is scaled with the dataset fraction so each sample is seen a
# comparable number of times regardless of train_percentage.
cfg.SOLVER.MAX_ITER = int(27000*train_percentage)
cfg.SOLVER.STEPS = (int(21000*train_percentage), int(25000*train_percentage))
cfg.SOLVER.WARMUP_ITERS = 1000

cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512
# Single foreground class: 'arrow'.
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
cfg.OUTPUT_DIR = f"/home/szage/runs/trained_nets/GeneralTracking/ArrowSynth/{sample_category}/InstanceSegmentation/SampleReduced{train_percentage}/"
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

# Persist the resolved config next to the checkpoints for reproducibility.
with open(os.path.join(cfg.OUTPUT_DIR,"config.yaml"), "w") as f:
    cfg.dump(stream=f)

# Training-time augmentations; each is applied independently with prob=0.2
# (the flip additionally only triggers 50% of the time when applied).
augs = [
    T.RandomApply(T.RandomBrightness(0.7, 1.3),prob=0.2),
    T.RandomApply(T.RandomContrast(0.7, 1.3),prob=0.2),
    T.RandomApply(T.RandomFlip(prob=0.5),prob=0.2),
    T.RandomApply(T.RandomExtent((0.7, 1.3),(1,1)),prob=0.2)
]

class MyTrainer(DefaultTrainer):
    """DefaultTrainer whose train loader runs the module-level `augs` pipeline."""

    @classmethod
    def build_train_loader(cls, cfg):
        # Swap the default mapper for one that applies our augmentation list
        # to every training sample.
        augmented_mapper = DatasetMapper(cfg, is_train=True, augmentations=augs)
        return build_detection_train_loader(cfg, mapper=augmented_mapper)

# Kick off training from the COCO-pretrained weights. resume=False starts a
# fresh run (iteration 0) even if OUTPUT_DIR already contains checkpoints.
trainer = MyTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()