# -*- coding: utf-8 -*-
"""MOTSynthCVPR22_AnnotPrep_SmallRes_Comittable.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/13XP-Ls9JQJDN21kiRU0-o2aIyyhdY_87
"""

import json
import numpy as np
import cv2
import pycocotools.mask as maskUtils
import os
from time import time
import multiprocessing
import traceback
import shutil

# Train samples: 1: [0-100), 2: [100-200), 3: [200-300), 4: [300-400), 5: [400-500), 6: [500-600)
# Test samples: 7: [600-700), 8: [700-768)
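# The block below selects chunk 1 of the train split (sequences 000-099).
# To process a different chunk, adjust the range accordingly, e.g. for the
# second test chunk (assuming the chunk layout listed above):
#   sample_names = list(map(lambda x: str(x).zfill(3), range(700, 768)))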

sample_names = list(map(lambda x: str(x).zfill(3), range(0,100)))
sample_range_str = f"[{sample_names[0]},{sample_names[-1]}]"

input_root = "/home/szage/runs/data/MOTSynthCVPR22/original"
output_root = "/home/szage/runs/data/MOTSynthCVPR22/transformed_smallres/train"
max_allowed_multiprocesses = 16

resize_factor = 0.25
min_threshold_ratio = 0.0005
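# resize_factor scales both image dimensions (and every mask) by the same amount;
# min_threshold_ratio drops masks whose pixel count does not exceed
# min_threshold_ratio * H * W of the resized frame (0.0005 = 0.05%).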

annot_dir = os.path.join(input_root, "annotations/")
recording_dir = os.path.join(input_root, "recordings/")

os.makedirs(os.path.join(output_root,"annotations"), exist_ok=True)
os.makedirs(os.path.join(output_root,"recordings"), exist_ok=True)
os.makedirs(os.path.join(output_root,"logs"), exist_ok=True)

message_logfile_name = os.path.join(output_root,"logs",sample_range_str+"annot_message_log.txt")
error_logfile_name = os.path.join(output_root,"logs",sample_range_str+"annot_error_log.txt")
corrupt_samples_logfile_name = os.path.join(output_root,"logs",sample_range_str+"annot_corrupt_samples.txt")
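# Start each run with empty log files.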
open(message_logfile_name, 'w').close()
open(error_logfile_name, 'w').close()
open(corrupt_samples_logfile_name, 'w').close()

num_processes = multiprocessing.cpu_count()
with open(message_logfile_name, 'a') as f:
  f.write(f"cpu count: {num_processes}\n")
  f.write(f"max allowed processes: {max_allowed_multiprocesses}\n")

def process_annot(sample_name, recording_dir, annot_dir, output_dir, resize_factor, min_threshold_ratio, error_logfile_name):
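  """Resize the annotations of one sequence.

  Loads <annot_dir>/<sample_name>.json, scales the image size and every RLE mask
  by resize_factor, drops masks below the relative area threshold, recomputes
  bounding boxes and areas from the resized masks, and writes the result to
  <output_dir>/annotations/<sample_name>.json. On any error the sample is logged
  as corrupt and its partial outputs are removed.
  """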
  try:
    os.makedirs(os.path.join(output_dir,"recordings","frames",sample_name,"rgb"), exist_ok=True)
    OriginalAnnot = None
    TransformedAnnot = {}

    t1 = time()
    with open(os.path.join(annot_dir, sample_name+".json")) as json_file:
      OriginalAnnot = json.load(json_file)
    t2 = time()

    TransformedAnnot["info"] = OriginalAnnot["info"]
    TransformedAnnot["info"]["img_height"]=int(TransformedAnnot["info"]["img_height"]*resize_factor)
    TransformedAnnot["info"]["img_width"]=int(TransformedAnnot["info"]["img_width"]*resize_factor)
    H = TransformedAnnot["info"]["img_height"]
    W = TransformedAnnot["info"]["img_width"]

    TransformedAnnot["licenses"] = OriginalAnnot["licenses"]

    TransformedAnnot["images"] = OriginalAnnot["images"]
    for i in range(len(TransformedAnnot["images"])):
      TransformedAnnot["images"][i]['height'] = H
      TransformedAnnot["images"][i]['width'] = W

    TransformedAnnot["categories"] = [{"id":  1, "name": "person", "supercategory": "person"}]

    TransformedAnnot["annotations"] = []

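    # Group the flat annotation list by image_id so each frame's objects can be
    # processed together.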
    AnnotFrameDict = {}
    for item in OriginalAnnot['annotations']:
        image_id = item['image_id']
        if image_id not in AnnotFrameDict:
            AnnotFrameDict[image_id] = []
        AnnotFrameDict[image_id].append(item)

    # Image ids encode "<sequence id><frame index zero-padded to 4 digits>"; iterate
    # the grouped annotations directly so frames without objects cannot raise a KeyError.
    for img_annot in AnnotFrameDict.values():

      accepted_annots = []

      for obj in img_annot:

        # Decode the full-resolution RLE mask and downscale it; nearest-neighbour
        # interpolation avoids interpolation artefacts on the binary mask.
        segmentation = maskUtils.decode(obj["segmentation"])
        segmentation = cv2.resize(segmentation, (W, H), interpolation=cv2.INTER_NEAREST)

        sum_segm = np.sum(segmentation)
        if sum_segm>H*W*min_threshold_ratio:

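          # Re-encode the downscaled mask as COCO RLE; counts is decoded to str so
          # the annotation stays JSON-serializable. The bbox is recomputed from the
          # resized mask rather than scaling the original bbox.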
          segmentation_RLE = maskUtils.encode(np.asarray(segmentation, order="F", dtype="uint8"))
          segmentation_RLE["counts"] = segmentation_RLE["counts"].decode('utf-8')

          x, y, w, h = cv2.boundingRect(segmentation)

          accepted_annot = {"segmentation": segmentation_RLE,
                            "iscrowd": obj["iscrowd"],
                            "image_id": obj["image_id"],
                            "category_id": obj["category_id"],
                            "id": obj["id"],
                            "bbox": [int(x),int(y),int(w),int(h)],
                            "area": int(np.sum(segmentation)),
                            "ped_id": obj["ped_id"]}
          accepted_annots.append(accepted_annot)

      TransformedAnnot["annotations"].extend(accepted_annots)

    t3 = time()
    with open(os.path.join(output_dir, "annotations", sample_name+".json"), 'w') as json_file:
      json.dump(TransformedAnnot, json_file)

    t4 = time()
    with open(message_logfile_name, 'a') as f:
      f.write(f"<- Sample {sample_name} |\tLoad: {t2 - t1:.2f} s\tTf: {t3 - t2:.2f} s\tSave: {t4 - t3:.2f} s\n")

  except Exception:
    with open(error_logfile_name, 'a') as f:
      f.write(f"sample error\t[{sample_name}]\n{traceback.format_exc()}\n")
    with open(corrupt_samples_logfile_name, 'a') as f:
      f.write(f"{sample_name}\n")
    # Best-effort cleanup of any partially written output for this sample.
    try:
      shutil.rmtree(os.path.join(output_dir,"recordings","frames",sample_name))
    except OSError:
      pass
    try:
      os.remove(os.path.join(output_dir, "annotations", sample_name+".json"))
    except OSError:
      pass

def process_annot_wrapper(sample_name):
  with open(message_logfile_name, 'a') as f:
    f.write(f"-> Sample {sample_name} in pool\n")
  process_annot(sample_name, recording_dir, annot_dir, output_root, resize_factor, min_threshold_ratio, error_logfile_name)


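# One worker per chunk of samples, capped by the CPU count and the configured maximum.
# Note: the Pool is created at module level without a __main__ guard, which relies on
# the fork start method (the default on Linux).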
num_required_processes = min([len(sample_names), num_processes, max_allowed_multiprocesses])
pool = multiprocessing.Pool(processes=num_required_processes)
chunk_size = len(sample_names) // num_required_processes
pool.map(process_annot_wrapper, sample_names, chunksize=chunk_size)
pool.close()
pool.join()
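
# Optional sanity check (a minimal sketch, not called by this script): decode the
# first annotation of one transformed sample and verify that the stored bbox and
# area are consistent with its RLE mask. The sample name "000" is only an example.
def sanity_check_sample(sample_name="000"):
  with open(os.path.join(output_root, "annotations", sample_name + ".json")) as f:
    annot = json.load(f)
  obj = annot["annotations"][0]
  mask = np.ascontiguousarray(maskUtils.decode(obj["segmentation"]))
  x, y, w, h = cv2.boundingRect(mask)
  assert obj["bbox"] == [int(x), int(y), int(w), int(h)]
  assert obj["area"] == int(mask.sum())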