# -*- coding: utf-8 -*-
"""MOTSynthCVPR22_VideoPrep_SmallRes_Comittable.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1qdPEeG63RIkd9xCL4TJ6WrGdoLlH1qV2
"""

import numpy as np
import os
from time import time
import traceback
import imageio
from PIL import Image
import psutil
import gc

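# Sample chunks (zero-based sample indices, half-open ranges). Set range() in
# sample_names below (and the train/test folder in output_root) to the chunk to convert:
# Train samples: 1: [0-100), 2: [100-200), 3: [200-300), 4: [300-400), 5: [400-500), 6: [500-600)
# Test samples: 7: [600-700), 8: [700-768)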

# Zero-padded, three-digit sample ids for this run ("000" ... "099").
sample_names = [f"{idx:03d}" for idx in range(0, 100)]
# Tag such as "[000,099]" used to prefix this run's log file names.
sample_range_str = f"[{sample_names[0]},{sample_names[-1]}]"

# Source data (MP4 recordings + annotations) and destination of the converted output.
input_root = "/home/szage/runs/data/MOTSynthCVPR22/original"
output_root = "/home/szage/runs/data/MOTSynthCVPR22/transformed_smallres/train"

# Number of frames extracted per recording and the linear downscaling factor.
max_frames = 1800
resize_factor = 0.25

# Input sub-folders; the trailing slash matters because recording_dir may be
# string-concatenated with file names further below.
annot_dir = os.path.join(input_root, "annotations/")
recording_dir = os.path.join(input_root, "recordings/")

# Create the output layout up front (no error if it already exists).
for _subdir in ("annotations", "recordings", "logs"):
  os.makedirs(os.path.join(output_root, _subdir), exist_ok=True)

# Per-run log files, all living in <output_root>/logs and prefixed with the sample range.
message_logfile_name = os.path.join(output_root, "logs", f"{sample_range_str}video_message_log.txt")
error_logfile_name = os.path.join(output_root, "logs", f"{sample_range_str}video_error_log.txt")
corrupt_samples_logfile_name = os.path.join(output_root, "logs", f"{sample_range_str}video_corrupt_samples.txt")

# Start every run with empty logs: opening in 'w' mode truncates each file.
for _log_path in (message_logfile_name, error_logfile_name, corrupt_samples_logfile_name):
  with open(_log_path, 'w'):
    pass

def _convert_sample(sample_name):
  """Convert one recording into downscaled JPEG frames.

  Reads the first ``max_frames`` frames of ``<recording_dir>/<sample_name>.mp4``,
  shrinks each by ``resize_factor`` (LANCZOS) and saves it as
  ``<output_root>/recordings/frames/<sample_name>/rgb/NNNN.jpg``. ``NNNN`` is the
  1-based frame number, zero-padded to four digits.

  Args:
    sample_name: zero-padded sample id, e.g. ``"007"``.

  Returns:
    Peak *system-wide* used RAM in GB (``psutil.virtual_memory().used``),
    sampled once per frame. This is not the memory of this process alone.

  Raises:
    Whatever the reader, resize or save raises, e.g. when the video cannot be
    decoded or has fewer than ``max_frames`` frames. The caller treats any
    exception as a corrupt sample.
  """
  frame_dir = os.path.join(output_root, "recordings", "frames", sample_name, "rgb")
  # The caller only checks that the sample folder exists. Make sure the "rgb"
  # sub-folder exists too; otherwise every save() fails and a healthy sample
  # is reported as corrupt.
  os.makedirs(frame_dir, exist_ok=True)

  max_memory_usage = 0
  # The context manager closes the reader, so no explicit close() is needed.
  with imageio.get_reader(os.path.join(recording_dir, sample_name + ".mp4")) as reader:
    for frame in range(max_frames):
      img = Image.fromarray(reader.get_data(frame))

      current_memory_usage = psutil.virtual_memory().used / (1024 ** 3)
      max_memory_usage = max(max_memory_usage, current_memory_usage)

      # PIL reports (width, height) directly. Copying the frame back into a
      # NumPy array just to read its shape is wasteful, and unpacking three
      # dimensions fails for 2-D grayscale frames.
      width, height = img.size
      img = img.resize((int(width * resize_factor), int(height * resize_factor)), Image.Resampling.LANCZOS)

      img.save(os.path.join(frame_dir, str(frame + 1).zfill(4) + ".jpg"))

  return max_memory_usage


for sample_name in sample_names:
  # Only samples whose output folder already exists are converted; others are
  # skipped silently.
  # NOTE(review): presumably that folder is created by a preceding annotation-prep
  # step, so only samples with prepared annotations get frames -- confirm.
  if not os.path.exists(os.path.join(output_root, "recordings", "frames", sample_name)):
    continue
  try:
    t1 = time()
    max_memory_usage = _convert_sample(sample_name)
    # Presumably forces release of decoder/frame buffers before the next sample.
    gc.collect()
    t2 = time()
    with open(message_logfile_name, 'a') as f:
      f.write(f"sample [{sample_name}]\tprocessing time: {t2 - t1:.2f} s\tmax RAM usage: {max_memory_usage:.2f} GB\n")
  except Exception:
    # Per-sample boundary: record the traceback and the sample id, then go on
    # with the next sample instead of aborting the whole batch.
    with open(error_logfile_name, 'a') as f:
      f.write(f"sample error\t[{sample_name}]\n{traceback.format_exc()}\n")
    with open(corrupt_samples_logfile_name, 'a') as f:
      f.write(f"{sample_name}\n")