# -*- coding: utf-8 -*-
"""MOTSynthCVPR22_AnnotPrep_TrackingFramewise_SmallRes_Comittable.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1Zqa4hRwna6T2PeKtect3JnICOXFPBPMN
"""

import json
import numpy as np
import pycocotools.mask as maskUtils
import os
from time import time
import traceback
import shutil
import psutil
from PIL import Image

# Train samples: 1: [0-100), 2: [100-200), 3: [200-300), 4:, [300-400), 5: [400-500), 6: [500-600)
# Test samples: 7: [600-700), 8: [700-768)

# Zero-padded sample identifiers for this shard (e.g. "000" .. "099").
sample_names = [str(i).zfill(3) for i in range(0, 100)]
sample_range_str = f"[{sample_names[0]},{sample_names[-1]}]"

output_root = "/home/szage/runs/data/MOTSynthCVPR22/transformed_smallres/train"

# Make sure the log directory exists before truncating the log files;
# without this, a fresh output tree crashes with FileNotFoundError.
log_dir = os.path.join(output_root, "logs")
os.makedirs(log_dir, exist_ok=True)

_log_prefix = sample_range_str + "annotTRfwise"
message_logfile_name = os.path.join(log_dir, _log_prefix + "_message_log.txt")
error_logfile_name = os.path.join(log_dir, _log_prefix + "_error_log.txt")
corrupt_samples_logfile_name = os.path.join(log_dir, _log_prefix + "_corrupt_samples.txt")

# Truncate (create empty) each log file at the start of the run.
for _logfile in (message_logfile_name, error_logfile_name, corrupt_samples_logfile_name):
    with open(_logfile, 'w'):
        pass

# For each sample with extracted frames: render one uint16 PNG per frame in
# which each pixel holds the pedestrian ID covering it (0 = background).
for sample_name in sample_names:
  # Guard clause: only process samples whose frames were already extracted.
  if not os.path.exists(os.path.join(output_root, "recordings", "frames", sample_name)):
    continue
  annot_dir = os.path.join(output_root, "recordings", "frames", sample_name, "annot_IDs")
  try:
    t1 = time()
    max_memory_usage = 0
    os.makedirs(annot_dir, exist_ok=True)

    # Load the per-sample COCO-style annotation file.
    with open(os.path.join(output_root, "annotations", sample_name + ".json")) as json_file:
      original_annot = json.load(json_file)

    H = original_annot["info"]["img_height"]
    W = original_annot["info"]["img_width"]

    # Group the flat annotation list by image_id (one bucket per frame).
    annot_frame_dict = {}
    for item in original_annot['annotations']:
      annot_frame_dict.setdefault(item['image_id'], []).append(item)

    # image_ids are built as int("<sample_name><frame:04d>"); frames are
    # assumed to start at 0 and be contiguous — TODO confirm against data.
    for frame in range(len(annot_frame_dict)):
      img_annot = annot_frame_dict[int(sample_name + str(frame).zfill(4))]
      # ID map for this frame; uint16 fits the pedestrian ID range.
      combined_segmentations = np.zeros((H, W), dtype=np.uint16)
      for obj in img_annot:
        # Decode the RLE mask and stamp the pedestrian ID into the map.
        # (Renamed from `id`, which shadowed the builtin.)
        mask = np.array(maskUtils.decode(obj["segmentation"]), dtype=bool)
        combined_segmentations[mask] = obj["ped_id"]

      # Track peak system RAM usage (GB) over the sample for the log line.
      current_memory_usage = psutil.virtual_memory().used / (1024 ** 3)
      max_memory_usage = max(max_memory_usage, current_memory_usage)

      # Frames are saved 1-indexed to match the extracted frame images.
      Image.fromarray(combined_segmentations).save(
          os.path.join(annot_dir, str(frame + 1).zfill(4) + ".png"))

    t2 = time()
    with open(message_logfile_name, 'a') as f:
      f.write(f"sample [{sample_name}]\tprocessing time: {t2 - t1:.2f} s\tmax RAM usage: {max_memory_usage:.2f} GB\n")
  except Exception:
    # Record the failure and mark the sample corrupt, then delete any
    # partially written ID maps so a rerun starts from a clean state.
    with open(error_logfile_name, 'a') as f:
      f.write(f"sample error\t[{sample_name}]\n{traceback.format_exc()}\n")
    with open(corrupt_samples_logfile_name, 'a') as f:
      f.write(f"{sample_name}\n")
    # ignore_errors=True replaces the original bare `except: pass`: the
    # cleanup stays best-effort without swallowing unrelated exceptions.
    shutil.rmtree(annot_dir, ignore_errors=True)