# -*- coding: utf-8 -*-
"""ArrowSynth_LocalTrackingTrain_TrainPercentage_Comittable.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1yfuX0TsVYqhpS4ZOay34h909CxMh7QYz
"""

import sys
import numpy as np
import cv2
import os
from PIL import Image
import time
import random
import multiprocessing
import re

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.transforms import transforms as T
#from torchvision.transforms import v2 as T

import segmentation_models_pytorch as smp

from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.optim.lr_scheduler import CosineAnnealingLR

import gc

tracking_range = 4            # frames tracked on each side of the central frame
train_percentage = 1.0        # fraction of the training samples to use
epochs = 50
warm_restarts_per_epoch = 2   # cosine warm restarts per epoch in the first phase
scheduler_switch_epochs = 8   # epochs before switching to plain cosine annealing
backbone = "resnet50"
model_save_steps = 1000       # checkpoint period, in training iterations
resume = False                # resume from the checkpoint named in last_model.txt
max_workers = 8               # upper bound on dataloader worker processes

sample_category = "ArrowSynthTurns"
data_root = f"/home/szage/runs/data/ArrowSynth/{sample_category}_train_samples"
sample_names = sorted(os.listdir(data_root))
sample_names = sample_names[:int(len(sample_names) * train_percentage)]  # keep the first train_percentage fraction

model_name = f"{sample_category}_[DLV3p,{backbone}]_FBtr{tracking_range}_Ep{epochs}_Adv2_SR{train_percentage}"
output_root = "/home/szage/runs/trained_nets/GeneralTracking/ArrowSynth/"

output_dir = None
if os.path.exists(output_root) and os.path.isdir(output_root):
  output_dir = os.path.join(output_root, f"{sample_category}/LocalTracking/{sample_category}_DLV3p_{backbone}_FBtr{tracking_range}_Ep{epochs}_SR{train_percentage}")
  os.makedirs(output_dir, exist_ok=True)
else:
  raise FileNotFoundError(f"The directory '{output_root}' does not exist.")

num_workers = min(max_workers, multiprocessing.cpu_count())
print(f"Used CPU workers for dataloading: {num_workers}")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

def remove_extension(file_list):
    # os.path.splitext only strips the final extension, so dotted names survive
    return [os.path.splitext(file)[0] for file in file_list]
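# e.g. remove_extension(["0000001.png"]) -> ["0000001"]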

def format_duration(seconds):
    days, remainder = divmod(seconds, 86400)
    hours, remainder = divmod(remainder, 3600)
    minutes, seconds = divmod(remainder, 60)

    formatted_time = f"{int(days)} days {int(hours):02}:{int(minutes):02}:{int(seconds):02}"
    return formatted_time
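# e.g. format_duration(90061) -> "1 days 01:01:01"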

class InvalidDataError(Exception):
    pass

# Defining dataset and color augmentations

class CTCLocalTrackingDataset(Dataset):
    def __init__(self, data_root, sample_names, tracking_range, color_transforms=None):
        self.data_root = data_root
        self.samples_list = self._register_valid_samples(data_root, sample_names)
        print("Registered Samples: ")
        for i in range(len(self.samples_list)):
          print(f"sample: {self.samples_list[i]['sample_name']}, num_frames: {self.samples_list[i]['num_frames']}")
        self.tracking_range = tracking_range
        self.color_transforms = color_transforms
        #print(f"Amount of available frames: {len(self.sample_names)*self.video_len}\n")

    def _register_valid_samples(self, data_root, sample_names):
      samples_list = []
      for sample_name in sample_names:
        video_folder = os.path.join(data_root, sample_name, 'imgs')
        gt_folder = os.path.join(data_root, sample_name, 'labels')

        # Frame IDs are assumed to be encoded in the first 7 characters of the file names
        image_files = sorted([file for file in os.listdir(video_folder) if file.endswith('.png')])
        IDs = sorted([file[0:7] for file in image_files])

        track_files = sorted([file for file in os.listdir(gt_folder) if file.endswith('.png')])
        track_IDs = sorted([file[0:7] for file in track_files])
        if len(set(IDs)-set(track_IDs)) == 0:
          sample_dict = {"sample_name": sample_name, "num_frames": len(IDs), "image_files": image_files, "annot_files": track_files}
          samples_list.append(sample_dict)
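          # Example entry (hypothetical values):
          # {"sample_name": "sample_000", "num_frames": 100,
          #  "image_files": ["0000000.png", ...], "annot_files": ["0000000.png", ...]}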
        else:
          print(f"Sample {sample_name} is missing tracking annotations for some frames; it is not registered!")
      return samples_list

    def __len__(self):
      # One training item per frame, summed over all registered samples
      return sum(sample['num_frames'] for sample in self.samples_list)

    def getitem(self, idx):
      # Convenience alias for indexing the dataset directly (outside a DataLoader)
      return self.__getitem__(idx)

    def __getitem__(self, idx):
      # Map the flat dataset index to a (sample, frame) pair
      frame = idx
      for sample in self.samples_list:
        if frame < sample['num_frames']:
            break
        frame -= sample['num_frames']

      central_img = Image.open(os.path.join(self.data_root, sample['sample_name'], 'imgs', sample['image_files'][frame]))
      grayscale = False
      if len(np.shape(central_img)) == 2:
        grayscale = True
        central_img = np.expand_dims(central_img, axis=2)
      H, W, C = np.shape(central_img)
      ret_val = True

      try:

        # Placeholder stacks; every slot is reassigned below, so the shared
        # references created by list multiplication are never mutated in place
        input = [np.zeros([H,W,C])]*(2*self.tracking_range+1)
        annot = [np.zeros([H,W])]*(2*self.tracking_range+1)
        input[self.tracking_range] = central_img

        # Forward loop from center (to handle forward temporal edges)
        for dt in range(0, self.tracking_range+1):
          if frame+dt<sample['num_frames'] and dt != 0:
            img = np.array(Image.open(os.path.join(self.data_root, sample['sample_name'], 'imgs', sample['image_files'][frame+dt])))
            if grayscale:
              img = np.expand_dims(img, axis=2)
            input[self.tracking_range+dt] = img

          else:
            input[self.tracking_range+dt] = input[self.tracking_range]

          if frame+dt<sample['num_frames']:
            annot[self.tracking_range+dt] = np.array(Image.open(os.path.join(self.data_root, sample['sample_name'], 'labels', sample['annot_files'][frame+dt])))

          else:
            annot[self.tracking_range+dt] = annot[self.tracking_range]

        # Backward loop from center (to handle backward temporal edges)
        for dt in reversed(range(-self.tracking_range,1)):
          if frame+dt >= 0 and dt != 0:
            img = np.array(Image.open(os.path.join(self.data_root, sample['sample_name'], 'imgs', sample['image_files'][frame+dt])))
            if grayscale:
              img = np.expand_dims(img, axis=2)
            input[self.tracking_range+dt] = img

          else:
            input[self.tracking_range+dt] = input[self.tracking_range]

          if frame+dt >= 0 and dt != 0:
            annot[self.tracking_range+dt] = np.array(Image.open(os.path.join(self.data_root, sample['sample_name'], 'labels', sample['annot_files'][frame+dt])))
          else:
            annot[self.tracking_range+dt] = annot[self.tracking_range]

        input = np.array(input)
        annot = np.array(annot)

        # Pick one annotated object at random from the central frame
        unique_values = np.unique(annot[self.tracking_range])
        unique_values = unique_values[unique_values != 0]
        if len(unique_values) == 0:
          raise InvalidDataError(f"Data with no objects at index {idx}")
        object_id = random.choice(unique_values)
        label = annot == object_id

        # Merge temporal and color dimensions: [T, H, W, C] -> [H, W, T*C]
        input = input.transpose(1, 2, 0, 3).reshape(H, W, -1)
        label = np.transpose(label, (1, 2, 0))
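        # e.g. frame t occupies channels [t*C, (t+1)*C) of the merged input:
        # transpose -> [H, W, T, C], then reshape -> [H, W, T*C]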

        # Perform color augmentations (positional augmentations are not available for now)
        if self.color_transforms:
          for i in range(np.shape(label)[2]):
            if grayscale:
              input[:,:,i:i+1] = np.expand_dims(self.color_transforms(Image.fromarray(np.squeeze(input[:,:,i:i+1], axis = 2))), axis = 2)
            else:
              input[:,:,3*i:3*(i+1)] = self.color_transforms(Image.fromarray(input[:,:,3*i:3*(i+1)]))
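        # Note: RandomApply and the jitter parameters are sampled independently
        # per frame, so augmented frames are not photometrically consistent in time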

        # Mark object on last input channel with its solid bounding box (not centroid marking as it may be outside of the object)
        input = np.array(input, dtype=np.uint8)
        bx, by, bw, bh = cv2.boundingRect(np.array(label[:,:,self.tracking_range], dtype=np.uint8)*255)
        bounding_rect = cv2.rectangle(np.zeros([H,W], dtype=np.uint8), (bx, by), (bx + bw, by + bh), 255, thickness=cv2.FILLED)
        input = np.concatenate([input, np.expand_dims(bounding_rect, axis=2)], axis=2)

        # Transform the data from [H, W, C] to [C, H, W] and into float Torch tensors
        input = torch.tensor(input, dtype=torch.float32) / 255.0
        label = torch.tensor(label, dtype=torch.float32)
        input = input.permute(2,0,1)
        label = label.permute(2,0,1)

      except Exception as e:
        # On any loading/augmentation failure, return a zeroed item and flag it
        # invalid via ret_val so the training loop can skip the batch
        print(f"Error during data loading: {e}")
        input = torch.zeros([C*(2*self.tracking_range+1)+1, H, W], dtype=torch.float32)
        label = torch.zeros([2*self.tracking_range+1, H, W], dtype=torch.float32)
        ret_val = False


      # Pad the data to a 16-divisible shape for the DeepLabV3+ architecture
      pad_h = (16 - H % 16) % 16
      pad_w = (16 - W % 16) % 16
      input = F.pad(input, (0, pad_w, 0, pad_h), mode='constant', value=0)
      label = F.pad(label, (0, pad_w, 0, pad_h), mode='constant', value=0)
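      # e.g. H=100, W=250 -> pad_h=12, pad_w=6, i.e. a 112x256 padded frame;
      # the returned item is then input [C*(2*tracking_range+1)+1, H', W'] and
      # label [2*tracking_range+1, H', W']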

      return input, label, ret_val

ColorTransforms = T.RandomApply(torch.nn.ModuleList([
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.02)
]), p=0.5)

# SMP architecture (interchangeable; DeepLabV3+ is used here)

SMPModel = smp.DeepLabV3Plus(
    encoder_name=backbone,
    encoder_weights="imagenet",
    in_channels=(2*tracking_range+1)+1,
    classes=2*tracking_range+1,
)
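# Optional channel sanity check (a minimal sketch): the dataset emits
# C*(2*tracking_range+1) image channels plus one box-marker channel, so the
# in_channels above assumes grayscale frames (C == 1)
with torch.no_grad():
    dummy = torch.zeros(1, (2*tracking_range+1)+1, 64, 64)
    assert SMPModel(dummy).shape == (1, 2*tracking_range+1, 64, 64)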

batch_size = 8
loss_disp_period = 10  # iterations between log lines
loss_function = nn.BCEWithLogitsLoss()
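# smp builds DeepLabV3Plus with activation=None by default, so the model outputs
# raw logits, matching BCEWithLogitsLoss (the sigmoid is applied internally, per
# tracked frame channel)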

train_loader = DataLoader(CTCLocalTrackingDataset(data_root, sample_names, tracking_range = tracking_range, color_transforms = ColorTransforms),
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=num_workers)
iters = len(train_loader)  # batches per epoch

if resume:
  with open(os.path.join(output_dir,"last_model.txt"), "r") as f:
    last_model_name = f.read().strip()
  # Full-model checkpoint: weights_only=False is needed on PyTorch >= 2.6, where
  # torch.load defaults to weights-only loading
  SMPModel = torch.load(os.path.join(output_dir, last_model_name), weights_only=False)
else:
  # Fresh run: truncate (or create) the training log
  with open(os.path.join(output_dir,model_name+"_log.txt"), "w"): pass
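# Note: resuming restores the model weights only; optimizer momentum, scheduler
# state and the epoch counter all restart from zero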

SMPModel.to(device)
optimizer = optim.SGD(SMPModel.parameters(), lr=0.1, weight_decay=1e-5)
scheduler1 = CosineAnnealingWarmRestarts(optimizer, T_0=int(iters/warm_restarts_per_epoch))
scheduler2 = CosineAnnealingLR(optimizer, T_max = (epochs-scheduler_switch_epochs)*iters)
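# Two-phase LR schedule: cosine annealing with warm restarts (T_0 is half an
# epoch, i.e. warm_restarts_per_epoch=2 restarts per epoch) for the first
# scheduler_switch_epochs epochs, then one long cosine decay over the remaining
# epochs; see the switch in the training loop below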

# Training the model

for epoch in range(epochs):
  running_loss = 0.0
  running_time = 0.0
  t = time.time()
  for i, data in enumerate(train_loader, 0):
    current_step = epoch*iters+i

    # Skip the whole batch if any item failed to load (ret_val flag from the dataset)
    if all(data[2]):
      inputs, ground_truth = data[0], data[1]

      inputs = inputs.to(device)
      ground_truth = ground_truth.to(device)

      optimizer.zero_grad()
      output = SMPModel(inputs)
      loss = loss_function(output, ground_truth)
      loss.backward()
      optimizer.step()

      running_loss += loss.item()

    if current_step<=scheduler_switch_epochs*iters:
      scheduler1.step()
    else:
      scheduler2.step()

    running_time += time.time()-t
    t = time.time()

    # Periodically log loss, learning rate and an ETA estimate
    if current_step!=0 and current_step % loss_disp_period == 0:
      step_time = running_time/loss_disp_period
      full_steps = iters*epochs
      time_estimate = (full_steps-current_step)*step_time
      lr = optimizer.param_groups[0]['lr']
      log_line = (f'eta: {format_duration(time_estimate)}, t_step: {step_time:.2f} sec, '
                  f'ep: {epoch + 1}, iter: {current_step}/{full_steps}, '
                  f'lr: {lr:.2g}, loss: {running_loss/loss_disp_period:.2g}')
      print(log_line)
      with open(os.path.join(output_dir,model_name+"_log.txt"), "a+") as LogFile:
        LogFile.write(log_line + '\n')
      running_loss = 0.0
      running_time = 0.0
      gc.collect()

    # Periodic checkpoint plus a pointer file used by the resume logic
    if current_step!=0 and current_step % model_save_steps == 0:
      torch.save(SMPModel, os.path.join(output_dir,f"{model_name}_{current_step}.pth"))
      with open(os.path.join(output_dir,"last_model.txt"), "w") as f:
        f.write(f"{model_name}_{current_step}.pth")
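      # Note: torch.save on the full model pickles the class path as well as the
      # weights, so loading requires the same segmentation_models_pytorch code;
      # saving state_dict() instead would be more portable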

# Save the final model and update the checkpoint pointer
torch.save(SMPModel, os.path.join(output_dir,f"{model_name}_final.pth"))
with open(os.path.join(output_dir,"last_model.txt"), "w") as f:
  f.write(f"{model_name}_final.pth")

print('Finished Training of SMP model')