# -*- coding: utf-8 -*-
"""MOTSynthCVPR22_LocalTrackingTrain_Comittable.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1KyRLIT924hHyB_hlS271cAWJJ-AqwQGQ
"""

import sys
import numpy as np
import cv2
import os
from PIL import Image
import time
import random
import multiprocessing

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.transforms import transforms as T
#from torchvision.transforms import v2 as T

import segmentation_models_pytorch as smp

from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.optim.lr_scheduler import CosineAnnealingLR

import gc

tracking_range = 2             # frames tracked on each side of the central frame
epochs = 4
warm_restarts_per_epoch = 5    # cosine warm restarts per epoch (phase 1 of the LR schedule)
scheduler_switch_epochs = 3    # epoch after which the LR schedule switches to plain cosine decay
backbone = "resnet50"
model_save_steps = 10000       # checkpoint period in optimizer steps
resume = False                 # resume from the checkpoint recorded in last_model.txt
max_workers = 8                # upper bound on DataLoader worker processes

data_root = "/home/szage/runs/data/MOTSynthCVPR22/transformed_smallres/train/recordings/frames/"
model_name = f"MOTSynth_KernelTracker_[DLV3p,{backbone}]_FBtr{tracking_range}_Ep{epochs}_Adv2"
output_root = "/home/szage/runs/trained_nets/GeneralTracking/MOTSynthCVPR22/LocalTracking/"

output_dir = None
if os.path.exists(output_root) and os.path.isdir(output_root):
  output_dir = os.path.join(output_root, f"DLV3p_{backbone}_FBtr{tracking_range}_Ep{epochs}")
  os.makedirs(output_dir, exist_ok=True)
else:
  raise FileNotFoundError(f"The directory '{output_root}' does not exist.")

num_workers = min(max_workers, multiprocessing.cpu_count())
print(f"Used CPU workers for dataloading: {num_workers}")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# Utility functions and definitions

def remove_extension(file_list):
    # os.path.splitext is robust to file names that contain extra dots
    return [os.path.splitext(file)[0] for file in file_list]

def format_duration(seconds):
    days, remainder = divmod(seconds, 86400)
    hours, remainder = divmod(remainder, 3600)
    minutes, seconds = divmod(remainder, 60)

    formatted_time = f"{int(days)} days {int(hours):02}:{int(minutes):02}:{int(seconds):02}"
    return formatted_time
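
# Example: format_duration(90061) -> "1 days 01:01:01"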

class InvalidDataError(Exception):
    pass

# Defining dataset and color augmentations

class MOTSynthLocalTrackingDataset(Dataset):
    def __init__(self, data_root, tracking_range, video_len = 1800, color_transforms=None, min_central_obj_threshold_ratio = 0.0005):
        self.data_root = data_root
        valid_sample_names = self._register_valid_samples(data_root, video_len)
        print("Registered samples:")
        print('\n'.join(map(lambda name: f'- {name}', valid_sample_names)))
        self.sample_names = list(valid_sample_names)
        self.tracking_range = tracking_range
        self.video_len = video_len
        self.color_transforms = color_transforms
        self.min_central_obj_threshold_ratio = min_central_obj_threshold_ratio
        print(f"Amount of available frames: {len(self.sample_names)*self.video_len}\n")

    def _register_valid_samples(self, data_root, video_len=None):
        sample_names = set(os.listdir(data_root))
        valid_sample_names = set()
        for sample_name in sample_names:
          sample_path = os.path.join(data_root, sample_name)
          sample_path_folders = os.listdir(sample_path)
          if ("annot_IDs" in sample_path_folders) and ("rgb" in sample_path_folders):
            annots = set(remove_extension(os.listdir(os.path.join(sample_path, "annot_IDs"))))
            rgbs = set(remove_extension(os.listdir(os.path.join(sample_path, "rgb"))))
            if annots == rgbs and (video_len is None or len(rgbs) == video_len):
              valid_sample_names.add(sample_name)
        return valid_sample_names

    def __len__(self):
        return len(self.sample_names)*self.video_len

    def __getitem__(self, idx):
      video_name = self.sample_names[idx // self.video_len]
      frame = idx % self.video_len + 1
      central_img = cv2.imread(os.path.join(self.data_root, video_name, "rgb", "{:04d}".format(frame) + ".jpg"))
      if central_img is None:
        raise FileNotFoundError(f"Missing central frame {frame:04d} of sequence '{video_name}'")
      H, W, C = np.shape(central_img)
      ret_val = True

      try:

        # Preallocate one slot per frame; every slot is reassigned below, so the
        # shared-reference pitfall of list multiplication does not bite here
        input = [np.zeros([H, W, C])] * (2*self.tracking_range + 1)
        annot = [np.zeros([H, W])] * (2*self.tracking_range + 1)
        input[self.tracking_range] = central_img

        # Forward loop from center (to handle forward temporal edges)
        for dt in range(0, self.tracking_range+1):
          if frame+dt<=self.video_len and dt != 0:
            input[self.tracking_range+dt] = cv2.imread(os.path.join(self.data_root,video_name,"rgb",
                                                                    "{:04d}".format(frame+dt)+".jpg"))
          else:
            input[self.tracking_range+dt] = input[self.tracking_range]

          if frame+dt<=self.video_len:
            annot[self.tracking_range+dt] = np.array(Image.open(os.path.join(self.data_root,video_name,"annot_IDs",
                                                                                "{:04d}".format(frame+dt)+".png")))
          else:
            annot[self.tracking_range+dt] = annot[self.tracking_range]

        # Backward loop from center (to handle backward temporal edges)
        for dt in reversed(range(-self.tracking_range,1)):
          if frame+dt > 0 and dt != 0:
            input[self.tracking_range+dt] = cv2.imread(os.path.join(self.data_root,video_name,"rgb",
                                                                    "{:04d}".format(frame+dt)+".jpg"))
          else:
            input[self.tracking_range+dt] = input[self.tracking_range]

          if frame+dt > 0 and dt != 0:
            annot[self.tracking_range+dt] = np.array(Image.open(os.path.join(self.data_root,video_name,"annot_IDs",
                                                                                "{:04d}".format(frame+dt)+".png")))
          else:
            annot[self.tracking_range+dt] = annot[self.tracking_range]

        input = np.array(input)
        annot = np.array(annot)

        # Select an object at random, but only one with a sufficient number of pixels in the central image
        unique_values = np.unique(annot[self.tracking_range])
        unique_values = unique_values[unique_values != 0]
        valid_object_id = None
        if len(unique_values) > 0:
          np.random.shuffle(unique_values)
          for object_id in unique_values:
            central_img_obj = np.array(annot[self.tracking_range] == object_id, dtype=int)
            if np.sum(central_img_obj) > H*W*self.min_central_obj_threshold_ratio:
              valid_object_id = object_id
              break
        else:
          raise InvalidDataError(f"Data with no objects at index {idx}")
        if valid_object_id is not None:
          label = annot == valid_object_id
        else:
          raise InvalidDataError(f"Data with no unobscured objects at index {idx}")

        # Merge the temporal and color dimensions: [T, H, W, C] -> [H, W, T*C]
        input = input.transpose(1, 2, 0, 3).reshape(H, W, -1)
        label = np.transpose(label, (1, 2, 0))

        # Perform color augmentations frame by frame (positional augmentations are not supported for now).
        # Note: cv2 loads frames as BGR while PIL assumes RGB; the jitter is applied to the
        # channels as-is, which is acceptable as augmentation noise
        if self.color_transforms:
          for i in range(2*self.tracking_range + 1):
            input[:,:,3*i:3*(i+1)] = self.color_transforms(Image.fromarray(input[:,:,3*i:3*(i+1)]))

        # Mark the object on an extra input channel with its solid bounding box
        # (a centroid marker could fall outside the object, so a filled box is used instead)
        input = np.array(input, dtype=np.uint8)
        bx, by, bw, bh = cv2.boundingRect(np.array(label[:,:,self.tracking_range], dtype=np.uint8)*255)
        bounding_rect = cv2.rectangle(np.zeros([H, W]), (bx, by), (bx + bw, by + bh), 255, thickness=cv2.FILLED)
        input = np.concatenate([input, np.expand_dims(bounding_rect, axis=2)], axis=2)

        # Transform the data from [H, W, C] to [C, H, W] and into float Torch tensors
        input = torch.tensor(input, dtype=torch.float32) / 255.0
        label = torch.tensor(label, dtype=torch.float32)
        input = input.permute(2,0,1)
        label = label.permute(2,0,1)

      except Exception as e:
        # On any loading failure, return correctly shaped zero tensors and flag the
        # sample as invalid so the training loop can skip the affected batch
        print(f"Error during data loading: {e}")
        input = torch.zeros([C*(2*self.tracking_range+1)+1, H, W], dtype=torch.float32)
        label = torch.zeros([2*self.tracking_range+1, H, W], dtype=torch.float32)
        ret_val = False

      # Pad the data to a 16-divisible shape, as required by the DeepLabV3+ architecture
      pad_h = (16 - H % 16) % 16
      pad_w = (16 - W % 16) % 16
      input = F.pad(input, (0, pad_w, 0, pad_h), mode='constant', value=0)
      label = F.pad(label, (0, pad_w, 0, pad_h), mode='constant', value=0)

      return input, label, ret_val
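
# A minimal sanity-check sketch (commented out so the training run is unaffected;
# index 0 is an arbitrary illustrative choice):
# _ds = MOTSynthLocalTrackingDataset(data_root, tracking_range=tracking_range)
# _x, _y, _ok = _ds[0]
# print(_x.shape)  # [3*(2*tracking_range+1)+1, H_pad, W_pad]: stacked frames + box channel
# print(_y.shape)  # [2*tracking_range+1, H_pad, W_pad]: per-frame binary masks of one object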

ColorTransforms = T.RandomApply(torch.nn.ModuleList([
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.02)
]), p=0.5)

# SMP architecture (the decoder is up to choice; DeepLabV3+ is used here)

SMPModel = smp.DeepLabV3Plus(
    encoder_name=backbone,
    encoder_weights="imagenet",
    in_channels=3*(2*tracking_range+1)+1,
    classes=2*tracking_range+1,
)
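
# Optional shape check, a quick sketch (the 256x320 dummy size is an arbitrary
# 16-divisible choice, not the dataset resolution):
# with torch.no_grad():
#     SMPModel.eval()
#     dummy = torch.zeros(1, 3*(2*tracking_range+1)+1, 256, 320)
#     print(SMPModel(dummy).shape)  # expected: [1, 2*tracking_range+1, 256, 320]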

# Setting up training parameters and environment, loading the training data

batch_size = 8
loss_disp_period = 10
loss_function = nn.BCEWithLogitsLoss()

train_dataset = MOTSynthLocalTrackingDataset(data_root,
                                             tracking_range=tracking_range,
                                             color_transforms=ColorTransforms,
                                             min_central_obj_threshold_ratio=0.0005)
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=num_workers)
iters = len(train_loader)

if resume:
  # Resume from the checkpoint recorded in last_model.txt
  with open(os.path.join(output_dir, "last_model.txt"), "r") as f:
    last_model_name = f.read()
    SMPModel = torch.load(os.path.join(output_dir, last_model_name))
else:
  # Start a fresh (truncated) log file for this run
  with open(os.path.join(output_dir, model_name + "_log.txt"), "w"): pass

SMPModel.to(device)
optimizer = optim.SGD(SMPModel.parameters(), lr=0.1, weight_decay=1e-5)
# Two-phase LR schedule: cosine warm restarts for the first `scheduler_switch_epochs`
# epochs, then a single cosine decay over the remaining epochs
scheduler1 = CosineAnnealingWarmRestarts(optimizer, T_0=max(1, iters // warm_restarts_per_epoch))
scheduler2 = CosineAnnealingLR(optimizer, T_max=(epochs - scheduler_switch_epochs)*iters)
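# For example, with iters = 1000 batches per epoch and warm_restarts_per_epoch = 5,
# scheduler1 restarts every T_0 = 200 optimizer steps; once current_step passes
# scheduler_switch_epochs*iters, the loop below steps scheduler2 instead.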

# Training the model

for epoch in range(epochs):
  running_loss = 0.0
  running_time = 0.0
  t = time.time()
  for i, data in enumerate(train_loader, 0):
    current_step = epoch*iters+i

    # Skip the whole batch if any sample in it failed to load
    if all(data[2]):
      inputs, ground_truth = data[0], data[1]

      inputs = inputs.to(device)
      ground_truth = ground_truth.to(device)

      optimizer.zero_grad()
      output = SMPModel(inputs)
      loss = loss_function(output, ground_truth)
      loss.backward()
      optimizer.step()

      running_loss += loss.item()

    if current_step <= scheduler_switch_epochs * iters:
      scheduler1.step()
    else:
      scheduler2.step()

    running_time += time.time()-t
    t = time.time()

    # Print and log statistics periodically
    if current_step != 0 and current_step % loss_disp_period == 0:
      step_time = running_time / loss_disp_period
      full_steps = iters * epochs
      time_estimate = (full_steps - current_step) * step_time
      lr = optimizer.param_groups[0]['lr']
      log_line = (f'eta: {format_duration(time_estimate)}, t_step: {step_time:.2f} sec, '
                  f'ep: {epoch + 1}, iter: {current_step}/{full_steps}, '
                  f'lr: {lr:.2g}, loss: {running_loss/loss_disp_period:.2g}')
      print(log_line)
      with open(os.path.join(output_dir, model_name + "_log.txt"), "a") as LogFile:
        LogFile.write(log_line + '\n')
      running_loss = 0.0
      running_time = 0.0
      gc.collect()

    # Periodic checkpointing; last_model.txt records the newest checkpoint for resuming
    if current_step != 0 and current_step % model_save_steps == 0:
      torch.save(SMPModel, os.path.join(output_dir, f"{model_name}_{current_step}.pth"))
      with open(os.path.join(output_dir, "last_model.txt"), "w") as f:
        f.write(f"{model_name}_{current_step}.pth")

torch.save(SMPModel, os.path.join(output_dir,f"{model_name}_final.pth"))
with open(os.path.join(output_dir,"last_model.txt"), "w") as f:
  f.write(f"{model_name}_final.pth")

print('Finished Training of SMP model')