# -*- coding: utf-8 -*-
"""AmobeSynth_Predictor_Comittable.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1kte2EK6AWB9UCpYRF1zmxgexL0DV9x3K
"""

#@title Importing required functions and libraries

# Imports from Symmetry-Tracker repo

import sys
sys.path.append("/home/szage/runs/evaluation/AmobeSynth/requirements/Symmetry-Tracker")

from symmetry_tracker.general_functionalities.video_transformation import TransformVideoFromTIFF

from symmetry_tracker.segmentation.segmentator import SingleVideoSegmentation
from symmetry_tracker.segmentation.segmentation_io import DisplaySegmentation, WriteSegmentation

from symmetry_tracker.tracking.symmetry_tracker import SingleVideoSymmetryTracking
from symmetry_tracker.tracking.tracking_io import DisplayTracks, WriteTracks, SaveTracksVideo, SaveTracks, LoadTracks
from symmetry_tracker.tracking.post_processing import InterpolateMissingObjects, RemoveShortPaths, HeuristicalEquivalence

from symmetry_tracker.tracking.kalman_tracker import SingleVideoKalmanTracking
from symmetry_tracker.tracking.symmetry_tracker_l2dist import SingleVideoSymmetryTracking_L2Distance
from symmetry_tracker.tracking.symmetry_tracker_shapedist import SingleVideoSymmetryTracking_ShapeDistance

# Other necessary imports

import torch
import os
import shutil

#@title Pipeline parameter setup

sample_record_name = "AmobeSynthPhaseContrast"
sample_record_name_input = "amobesynth_phasecontrast_test_samples"
sample_ratio = "1.0"  # sample-reduction ratio; kept as a string because it is embedded in file names

# Tracking model training parameters (must match the trained model's file name)
TimeKernelSize = 4
Epochs = 50

# Input paths
ModelsRoot = "/home/szage/runs/trained_nets/GeneralTracking/AmobeSynth/"
SegmentationModelPath = os.path.join(ModelsRoot, sample_record_name, "InstanceSegmentation", f"SampleReduced{sample_ratio}", "model_final.pth")
SegmentationModelConfigPath = os.path.join(ModelsRoot, sample_record_name, "InstanceSegmentation", f"SampleReduced{sample_ratio}", "config.yaml")
TrackingModelPath = os.path.join(ModelsRoot, sample_record_name, "LocalTracking",
                                 f"{sample_record_name}_DLV3p_resnet50_FBtr{TimeKernelSize}_Ep{Epochs}_SR{sample_ratio}",
                                 f"{sample_record_name}_[DLV3p,resnet50]_FBtr{TimeKernelSize}_Ep{Epochs}_Adv2_SR{sample_ratio}_final.pth")

# Output paths
OutputRoot = f"./home/szage/runs/evaluation/AmobeSynth/predictions/"
SegmentationBaseRoot = os.path.join(OutputRoot,"segmentations")
TrackingBaseRoot = os.path.join(OutputRoot,"tracks")
SegmentationSaveRoot = os.path.join(SegmentationBaseRoot,f"{sample_record_name}_SR{sample_ratio}/")
TrackingWriteRoot = os.path.join(TrackingBaseRoot,f"{sample_record_name}_SR{sample_ratio}/")

os.makedirs(SegmentationSaveRoot, exist_ok=True)  # also creates SegmentationBaseRoot
os.makedirs(TrackingWriteRoot, exist_ok=True)     # also creates TrackingBaseRoot

# Select the compute device: GPU if available, otherwise fall back to CPU
Device = "cuda:0" if torch.cuda.is_available() else "cpu"
print("Compute device: " + Device)

#@title Full Prediction on multiple videos

import time

def format_time(seconds):
    """Format a duration in seconds as HH:MM:SS."""
    hours, remainder = divmod(seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}"

print(f"Predicting: {sample_record_name}_SR{sample_ratio}")

for sample_id_numeric in range(81, 101):  # test samples 081 through 100

  t0 = time.time()

  sample_id = f"{sample_id_numeric:03}"

  InputVideoPath = f"./home/szage/runs/data/AmobeSynth/{sample_record_name_input}/{sample_id}/imgs/"
  SegmentationSavePath = os.path.join(SegmentationSaveRoot, f"{sample_id}.txt")
  TrackingWritePath = os.path.join(TrackingWriteRoot, f"{sample_id}.txt")

  # Run instance segmentation over the input video
  Outmasks = SingleVideoSegmentation(InputVideoPath,
                                     SegmentationModelPath,
                                     SegmentationModelConfigPath,
                                     Device,
                                     Color = "GRAYSCALE",
                                     ScoreThreshold = 0.4)

  WriteSegmentation(Outmasks, SegmentationSavePath)

  # Local symmetry-based tracking on top of the saved segmentation
  AnnotDF = SingleVideoSymmetryTracking(InputVideoPath,
                                        TrackingModelPath,
                                        Device,
                                        SegmentationSavePath,
                                        TimeKernelSize = TimeKernelSize,
                                        Color = "GRAYSCALE",
                                        Marker = "BBOX",
                                        SegmentationConfidence = 0.2,
                                        MinRequiredSimilarity = 0.2,
                                        MaxTimeKernelShift = None)

  # Post-processing: interpolate objects that are missing from intermediate frames
  AnnotDF = InterpolateMissingObjects(AnnotDF)

  WriteTracks(AnnotDF, TrackingWritePath)

  t1 = time.time()
  dt_formatted = format_time(t1-t0)

  print(f"Processed: {sample_id}\tdt: {dt_formatted}")

print("Prediction done")