# -*- coding: utf-8 -*-
"""MOTSynthCVPR22_InstSegTrain_Comittable.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1ufJlz3JvkKqgGI0HzEtlyKynAyXNlP3V
"""

import os
import torch

from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data.datasets import register_coco_instances
from detectron2.data import transforms as T
from detectron2.engine import DefaultTrainer
from detectron2.data.dataset_mapper import DatasetMapper
from detectron2.data import build_detection_train_loader
from detectron2.utils.logger import setup_logger

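# Dataset layout (one COCO-style annotation file per recorded sequence):
#   <data_root>/annotations/<sequence>.json
#   <data_root>/recordings/...   (image root referenced by the annotation files)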
data_root = "/home/szage/runs/data/MOTSynthCVPR22/transformed_smallres/train"
annot_dir = os.path.join(data_root, "annotations")
recording_dir = os.path.join(data_root, "recordings")

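# Each sequence is registered as its own COCO dataset; detectron2 concatenates
# all of the names listed in cfg.DATASETS.TRAIN at training time.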
sample_names = sorted(
    os.path.splitext(f)[0] for f in os.listdir(annot_dir) if f.endswith(".json")
)

train_samples = tuple(sample_names)
for sample_name in sample_names:
    register_coco_instances(
        sample_name,
        {},
        os.path.join(annot_dir, sample_name + ".json"),
        recording_dir,
    )

setup_logger()

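# Build the training config: start from the model-zoo Mask R-CNN R-101-FPN
# baseline and fine-tune it on the registered MOTSynth sequences.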
# https://stackoverflow.com/questions/67061435/how-to-train-detectron2-model-with-multiple-custom-dataset
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml"))
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml")
cfg.MODEL.MASK_ON = True
cfg.DATASETS.TRAIN = train_samples
cfg.INPUT.MASK_FORMAT = "bitmask"
cfg.DATALOADER.NUM_WORKERS = 4

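# Multi-step LR schedule: linear warmup for WARMUP_ITERS iterations, then the
# learning rate is multiplied by GAMMA at each milestone in STEPS.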
cfg.SOLVER.IMS_PER_BATCH = 4
cfg.SOLVER.BASE_LR = 0.01
cfg.SOLVER.GAMMA = 0.1
cfg.SOLVER.MAX_ITER = 540000
cfg.SOLVER.STEPS = (270000, 420000, 500000)
cfg.SOLVER.WARMUP_ITERS = 1000

"""
cfg.SOLVER.IMS_PER_BATCH = 4
cfg.SOLVER.BASE_LR = 0.001
cfg.SOLVER.GAMMA = 0.1
cfg.SOLVER.MAX_ITER = 270000
cfg.SOLVER.STEPS = (210000, 250000)
cfg.SOLVER.WARMUP_ITERS = 1000
"""

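# Single foreground class (person); 512 RoIs are sampled per image for the ROI heads.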
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
cfg.OUTPUT_DIR = "/home/szage/runs/trained_nets/GeneralTracking/MOTSynthCVPR22/InstanceSegmentation"
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

# Persist the full config next to the checkpoints for reproducibility.
with open(os.path.join(cfg.OUTPUT_DIR, "config.yaml"), "w") as f:
    f.write(cfg.dump())

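# Stochastic augmentation policy: every transform is wrapped in RandomApply, so
# each one is only applied to a fraction of the training images.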
augs = [
    T.RandomApply(T.RandomBrightness(0.7, 1.3), prob=0.2),
    T.RandomApply(T.RandomContrast(0.7, 1.3), prob=0.2),
    T.RandomApply(T.RandomFlip(prob=1.0), prob=0.5),              # horizontal flip on half of the images
    T.RandomApply(T.RandomExtent((0.7, 1.3), (1, 1)), prob=0.2),  # random re-crop (zoom in/out) of the image extent
]

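# DefaultTrainer builds its data loader from the config alone, so the custom
# augmentations are injected by overriding build_train_loader with a DatasetMapper.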
class MyTrainer(DefaultTrainer):
    @classmethod
    def build_train_loader(cls, cfg):
        mapper = DatasetMapper(cfg, is_train=True, augmentations=augs)
        return build_detection_train_loader(cfg, mapper=mapper)

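# resume=False: initialise from the COCO-pretrained weights in cfg.MODEL.WEIGHTS
# instead of resuming from a previous checkpoint in OUTPUT_DIR.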
trainer = MyTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()
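
# Optional follow-up (not executed here): a minimal inference sketch, assuming the
# run finished and produced model_final.pth in OUTPUT_DIR. The sample image path
# is a hypothetical placeholder.
#
# from detectron2.engine import DefaultPredictor
# import cv2
#
# cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
# cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
# predictor = DefaultPredictor(cfg)
# outputs = predictor(cv2.imread("/path/to/sample_frame.jpg"))  # hypothetical path
# print(outputs["instances"].pred_masks.shape)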