# -*- coding: utf-8 -*-
"""ArrowSynth_Predictor_Comittable.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1fYiVe12YiUqW8kiy55P99ZvSaM1Jx3Lm
"""

#@title Importing required functions and libraries

# Imports from Symmetry-Tracker repo

import sys
sys.path.append("/home/szage/runs/evaluation/ArrowSynth/requirements/Symmetry-Tracker")

from symmetry_tracker.general_functionalities.video_transformation import TransformVideoFromTIFF

from symmetry_tracker.segmentation.segmentator import SingleVideoSegmentation
from symmetry_tracker.segmentation.segmentation_io import DisplaySegmentation, WriteSegmentation

from symmetry_tracker.tracking.symmetry_tracker import SingleVideoSymmetryTracking
from symmetry_tracker.tracking.tracking_io import DisplayTracks, WriteTracks, SaveTracksVideo, SaveTracks, LoadTracks
from symmetry_tracker.tracking.post_processing import InterpolateMissingObjects, RemoveShortPaths, HeuristicalEquivalence

from symmetry_tracker.tracking.kalman_tracker import SingleVideoKalmanTracking
from symmetry_tracker.tracking.symmetry_tracker_l2dist import SingleVideoSymmetryTracking_L2Distance
from symmetry_tracker.tracking.symmetry_tracker_shapedist import SingleVideoSymmetryTracking_ShapeDistance

# Other necessary imports

import torch
import os
import shutil

#@title Pipeline parameter setup

# Record / predictor selection for this evaluation run.
sample_record_name = "ArrowSynthTurns"
predictor_names = ["Kalman", "Symmetry", "SymmetryL2", "SymmetryShape"]
sample_ratio = "1.0"

# Tracker train parameters — must match the values used when the tracking model
# was trained, since they are baked into the model file name below.
TimeKernelSize = 4
Epochs = 50

# Input paths: trained segmentation and tracking model artifacts.
ModelsRoot = "/home/szage/runs/trained_nets/GeneralTracking/ArrowSynth/"
SegmentationModelPath = os.path.join(ModelsRoot, sample_record_name, "InstanceSegmentation", f"SampleReduced{sample_ratio}", "model_final.pth")
SegmentationModelConfigPath = os.path.join(ModelsRoot, sample_record_name, "InstanceSegmentation", f"SampleReduced{sample_ratio}", "config.yaml")
TrackingModelPath = os.path.join(ModelsRoot, sample_record_name, "LocalTracking",
                                 f"{sample_record_name}_DLV3p_resnet50_FBtr{TimeKernelSize}_Ep{Epochs}_SR{sample_ratio}",
                                 f"{sample_record_name}_[DLV3p,resnet50]_FBtr{TimeKernelSize}_Ep{Epochs}_Adv2_SR{sample_ratio}_final.pth")

# Output paths (plain string: the original f-string had no placeholders).
OutputRoot = "/home/szage/runs/evaluation/ArrowSynth/predictions/"
SegmentationBaseRoot = os.path.join(OutputRoot, "segmentations")
TrackingBaseRoot = os.path.join(OutputRoot, "tracks")
SegmentationSaveRoot = os.path.join(SegmentationBaseRoot, f"{sample_record_name}_SR{sample_ratio}/")

# exist_ok=True replaces the race-prone exists()/makedirs() pattern; creating
# the leaf SegmentationSaveRoot also creates SegmentationBaseRoot implicitly,
# but all three are kept explicit for clarity.
os.makedirs(SegmentationBaseRoot, exist_ok=True)
os.makedirs(TrackingBaseRoot, exist_ok=True)
os.makedirs(SegmentationSaveRoot, exist_ok=True)

# Matching colab environment (for now GPU vs CPU)
Device = "cuda:0" if torch.cuda.is_available() else "cpu"
print("Colab environment: " + Device)

#@title Full Prediction on multiple videos

import time

def format_time(seconds):
    """Render a non-negative duration in seconds as an ``HH:MM:SS`` string."""
    total = int(seconds)
    hh = total // 3600
    mm = (total % 3600) // 60
    ss = total % 60
    return f"{hh:02d}:{mm:02d}:{ss:02d}"

print(f"Predicting: {sample_record_name}_SR{sample_ratio}")

# Run segmentation once per video, then every tracker variant on that
# shared segmentation, writing one track file per (video, predictor).
for sample_id_numeric in range(81, 101):
    t0 = time.time()

    # Zero-padded 3-digit sample id, e.g. 081 ... 100.
    sample_id = f"{sample_id_numeric:03}"

    # BUGFIX: the original referenced the undefined name
    # `sample_record_name_input`, which raised NameError at runtime.
    InputVideoPath = f"/home/szage/runs/data/AmobeSynth/{sample_record_name}/{sample_id}/imgs/"
    SegmentationSavePath = os.path.join(SegmentationSaveRoot, f"{sample_id}.txt")

    # Segment the video once; all predictors below consume the saved file.
    Outmasks = SingleVideoSegmentation(InputVideoPath,
                                       SegmentationModelPath,
                                       SegmentationModelConfigPath,
                                       Device,
                                       Color="GRAYSCALE",
                                       ScoreThreshold=0.4)
    WriteSegmentation(Outmasks, SegmentationSavePath)

    for predictor_name in predictor_names:
        TrackingWriteRoot = os.path.join(TrackingBaseRoot, f"{sample_record_name}_SR{sample_ratio}_{predictor_name}/")
        os.makedirs(TrackingWriteRoot, exist_ok=True)
        TrackingWritePath = os.path.join(TrackingWriteRoot, f"{sample_id}.txt")

        if predictor_name == "Kalman":
            AnnotDF = SingleVideoKalmanTracking(InputVideoPath,
                                                SegmentationSavePath,
                                                MaxCentroidDistance=40)

        elif predictor_name == "Symmetry":
            AnnotDF = SingleVideoSymmetryTracking(InputVideoPath,
                                                  TrackingModelPath,
                                                  Device,
                                                  SegmentationSavePath,
                                                  TimeKernelSize=TimeKernelSize,
                                                  Color="GRAYSCALE",
                                                  Marker="BBOX",
                                                  SegmentationConfidence=0.2,
                                                  MinRequiredSimilarity=0.2,
                                                  MaxTimeKernelShift=None)

        elif predictor_name == "SymmetryL2":
            AnnotDF = SingleVideoSymmetryTracking_L2Distance(InputVideoPath,
                                                             TrackingModelPath,
                                                             Device,
                                                             SegmentationSavePath,
                                                             TimeKernelSize=TimeKernelSize,
                                                             Color="GRAYSCALE",
                                                             Marker="BBOX",
                                                             SegmentationConfidence=0.2,
                                                             MinRequiredSimilarity=0.2,
                                                             MaxTimeKernelShift=None)

        # BUGFIX: the original had a syntax error here — an `if` header
        # immediately followed by a stray `elif` line. The intent was a
        # plain `elif` branch for the shape-distance tracker.
        elif predictor_name == "SymmetryShape":
            AnnotDF = SingleVideoSymmetryTracking_ShapeDistance(InputVideoPath,
                                                                TrackingModelPath,
                                                                Device,
                                                                SegmentationSavePath,
                                                                TimeKernelSize=TimeKernelSize,
                                                                Color="GRAYSCALE",
                                                                Marker="BBOX",
                                                                SegmentationConfidence=0.2,
                                                                MinRequiredSimilarity=0.2,
                                                                MaxTimeKernelShift=None)

        else:
            # Guard: without this, an unknown predictor name would silently
            # reuse the previous iteration's AnnotDF (or raise NameError).
            raise ValueError(f"Unknown predictor name: {predictor_name}")

        AnnotDF = InterpolateMissingObjects(AnnotDF)

        WriteTracks(AnnotDF, TrackingWritePath)

    t1 = time.time()
    dt_formatted = format_time(t1 - t0)

    print(f"Processed: {sample_id}\tdt: {dt_formatted}")

print("Prediction done")