#!/usr/bin/env python3

import cv2 as cv
import numpy as np
import matplotlib.pyplot as plt
import json
import sys
import time
import psutil
import os
import urllib.request
import urllib.parse
import json
import tempfile

# Base URL of the PHP backend. Every endpoint URL and image URL used in this
# script is built by appending a path to it. Update this to your actual
# server URL.
SERVER_URL = "https://melt.graphics/Melt/wtf/stJordi_backend"

# Item class from the original script
class Item:
    """A detected object: a type tag plus position, rotation, scale and size.

    All vector attributes are stored as independent float NumPy arrays.
    """

    # Names of the attributes that hold float vectors, in display order.
    _VECTOR_FIELDS = ("position", "rotation", "scale", "size")

    def __init__(self, objectType, position, rotation, scale, size):
        """
        Build an Item from its type and four vectors.

        Args:
            objectType (int): The type of the object.
            position: Vector3-like (x, y, z) position.
            rotation: Vector3-like (x, y, z) rotation.
            scale: Vector3-like (x, y, z) scale.
            size: Vector2-like (width, height) size.

        Each vector argument is copied into a new float array, so later
        changes to the caller's data do not affect the Item.
        """
        self.objectType = objectType
        supplied = (position, rotation, scale, size)
        for field, value in zip(self._VECTOR_FIELDS, supplied):
            setattr(self, field, np.array(value, dtype=float))

    def __repr__(self):
        """Return ``Item(objectType=..., position=..., ...)`` for debugging."""
        names = ("objectType",) + self._VECTOR_FIELDS
        listing = ", ".join(f"{name}={getattr(self, name)}" for name in names)
        return f"Item({listing})"

# Global perspective transformation matrix shared by the rectangle helpers.
# It starts as None; get_transform_matrix() assigns it (through `global M`)
# and get_transform_rect() reads it, so the matrix must be computed before
# any rectangle is transformed.
M = None

# Function to transform rectangle coordinates
def get_transform_rect(x, y, w, h):
    """Project a rectangle through the global matrix ``M``.

    Only two points are projected: the corner (x, y) and the opposite corner
    (x + w, y + h). The result is the box spanned by those two projected
    points, not the bounding box of all four projected corners.

    Args:
        x, y: Coordinates of the first corner.
        w, h: Width and height of the rectangle.

    Returns:
        tuple: ``(tx, ty, tw, th)`` as Python ints; each value is truncated
        with ``int()``.
    """
    def _project(px, py):
        # Homogeneous multiply followed by the perspective divide.
        hx, hy, hw = np.dot(M, np.array([px, py, 1]))
        return hx / hw, hy / hw

    near_x, near_y = _project(x, y)
    far_x, far_y = _project(x + w, y + h)

    return int(near_x), int(near_y), int(far_x - near_x), int(far_y - near_y)

# Function to find inside rectangles
def find_inside_rects(x, y, w, h):
    """Remove every rectangle whose projected box lies inside another one.

    Rectangles are given as four parallel lists. Containment is tested on the
    boxes returned by get_transform_rect() and is inclusive, so a box that
    touches the border of its container still counts as inside. Of two
    identical boxes only the higher-indexed one is removed.

    Args:
        x, y, w, h (list): Parallel lists of rectangle origins and sizes.
            They are modified in place.

    Returns:
        tuple: The same four lists ``(x, y, w, h)`` after removal.
    """
    # Nothing can be nested with fewer than two rectangles.
    if len(x) < 2:
        return x, y, w, h

    # The rectangles are never altered here (only removed), so each one is
    # projected once instead of once per compared pair.
    boxes = [get_transform_rect(*rect) for rect in zip(x, y, w, h)]

    # Walk backwards so that removing index i never shifts an index that is
    # still waiting to be visited.
    for i in range(len(boxes) - 1, -1, -1):
        left, top, width, height = boxes[i]
        right, bottom = left + width, top + height
        for j in range(len(boxes) - 1, -1, -1):
            if i == j:
                continue
            o_left, o_top, o_width, o_height = boxes[j]
            if (left >= o_left and top >= o_top and
                    right <= o_left + o_width and
                    bottom <= o_top + o_height):
                # Rectangle i is inside rectangle j: drop it everywhere.
                for column in (x, y, w, h, boxes):
                    column.pop(i)
                break
    return x, y, w, h

# Function to find outside rectangles
def find_outside_rects(x, y, w, h, x_track, y_track, w_track, h_track):
    """Remove rectangles whose projected box falls outside the tracking area.

    A rectangle is kept when its box from get_transform_rect() overlaps the
    region ``0..w_track`` horizontally and ``0..h_track`` vertically (borders
    included); otherwise it is deleted from all four lists.

    Args:
        x, y, w, h (list): Parallel lists of rectangle origins and sizes.
            They are modified in place.
        x_track, y_track: Accepted for interface compatibility; not used.
        w_track, h_track: Width and height of the tracking area.

    Returns:
        tuple: The same four lists ``(x, y, w, h)`` after removal.
    """
    idx = 0
    while idx < len(x):
        left, top, width, height = get_transform_rect(x[idx], y[idx], w[idx], h[idx])
        overlaps_x = (left + width) >= 0 and left <= w_track
        overlaps_y = (top + height) >= 0 and top <= h_track
        if overlaps_x and overlaps_y:
            idx += 1
            continue
        # Outside: delete it and stay on the same index, which now holds the
        # next rectangle.
        for column in (x, y, w, h):
            del column[idx]

    return x, y, w, h

# Function to unify rectangles
def unify_rects(x, y, w, h):
    """Merge rectangles that overlap, or nearly touch, into bounding boxes.

    Closeness is tested on the projected boxes from get_transform_rect() with
    a 5-pixel slack on each axis. Two close rectangles are replaced by their
    union bounding box, computed in the original (unprojected) coordinates.

    Merging is repeated until no two remaining rectangles are close. A merged
    box is larger than its parts, so it is compared again with rectangles
    that were already checked against the smaller version.

    Args:
        x, y, w, h (list): Parallel lists of rectangle origins and sizes.
            They are modified in place.

    Returns:
        tuple: The same four lists ``(x, y, w, h)`` after merging.
    """
    margin = 5  # slack, in projected pixels, for "very close" rectangles

    def _close(first, second):
        fx, fy, fw, fh = first
        sx, sy, sw, sh = second
        return (fx <= (sx + sw + margin) and (fx + fw + margin) >= sx and
                fy <= (sy + sh + margin) and (fy + fh + margin) >= sy)

    i = 0
    while i < len(x):
        box_i = get_transform_rect(x[i], y[i], w[i], h[i])
        grew = False
        j = i + 1
        while j < len(x):
            box_j = get_transform_rect(x[j], y[j], w[j], h[j])
            if not _close(box_i, box_j):
                j += 1
                continue

            # Replace both rectangles with one big rectangle covering them.
            x_new = min(x[i], x[j])
            y_new = min(y[i], y[j])
            w_new = max(x[i] + w[i], x[j] + w[j]) - x_new
            h_new = max(y[i] + h[i], y[j] + h[j]) - y_new
            for column in (x, y, w, h):
                column.pop(j)
            x[i], y[i], w[i], h[i] = x_new, y_new, w_new, h_new

            # Rectangle i changed: refresh its projection and re-test the
            # rectangles after it that were skipped against the smaller box.
            box_i = get_transform_rect(x[i], y[i], w[i], h[i])
            grew = True
            j = i + 1

        # A grown box may now reach rectangles at lower indices, so rescan
        # from the start. Every merge shortens the lists, so this terminates.
        i = 0 if (grew and i > 0) else i + 1

    return x, y, w, h

# Line intersection function
def line_intersection(A, A2, B, B2):
    """Return the intersection of the line A -> A2 with the line B -> B2.

    Both lines are treated as infinite. The point is found by solving
    ``A + t * (A2 - A) = B + s * (B2 - B)`` for ``t`` and ``s`` using the x
    and y components of the inputs.

    Args:
        A, A2: Two points on the first line (any 2D array-like: NumPy array,
            tuple or list).
        B, B2: Two points on the second line (same accepted types).

    Returns:
        np.ndarray: Float array holding the intersection point.

    Raises:
        numpy.linalg.LinAlgError: If the system is singular, i.e. the lines
            are parallel or a line is given by two identical points.
    """
    # Normalise the inputs so plain tuples/lists work as well as arrays.
    A, A2, B, B2 = (np.array(point, dtype=float) for point in (A, A2, B, B2))

    # Direction vectors of the two lines.
    dir1 = A2 - A  # direction of line A -> A2
    dir2 = B2 - B  # direction of line B -> B2

    # Linear system  t * dir1 - s * dir2 = B - A
    system = np.array([
        [dir1[0], -dir2[0]],
        [dir1[1], -dir2[1]]
    ])
    rhs = np.array([B[0] - A[0], B[1] - A[1]])

    # Only t is needed to locate the point on the first line.
    t, _ = np.linalg.solve(system, rhs)

    return A + t * dir1

# Function to compute corners
def compute_corners(approx, A_in, B_in, C_in, D_in, empty_corner):
    """Compute the four corners A, B, C, D of the tracking rectangle.

    Each corner that has points is the integer (floor) average of the
    ``approx`` points selected by its index array. When one corner has no
    points (``empty_corner`` is "A", "B", "C" or "D"), it is reconstructed by
    find_last_corner() from the points that do not belong to any corner.

    Args:
        approx: Point array indexed as ``approx[i][0] -> (x, y)``.
        A_in, B_in, C_in, D_in: Index arrays into ``approx`` for the points
            of corners A, B, C and D.
        empty_corner (str): Name of the corner without points, or "None".

    Returns:
        tuple: ``(A, B, C, D)``.
    """
    def _average(indices):
        # The sum starts from a tuple; adding the first point turns it into
        # a NumPy array, and the floor division keeps integer coordinates.
        total = (0, 0)
        for idx in indices:
            total = total + approx[idx][0]
        return total // len(indices)

    A = B = C = D = (0, 0)
    if empty_corner != "A":
        A = _average(A_in)
    if empty_corner != "B":
        B = _average(B_in)
    if empty_corner != "C":
        C = _average(C_in)
    if empty_corner != "D":
        D = _average(D_in)

    if empty_corner == "None":
        return A, B, C, D

    # Points of approx that were not assigned to any corner.
    used = np.concatenate((A_in, B_in, C_in, D_in))
    approx_out = np.delete(approx, used, axis=0)

    # Quarter-width / quarter-height bands along each side of the bounding box.
    x_lo, x_hi = min(approx[:, 0, 0]), max(approx[:, 0, 0])
    y_lo, y_hi = min(approx[:, 0, 1]), max(approx[:, 0, 1])
    thrx_left = (x_hi - x_lo) // 4 + x_lo
    thry_up = (y_hi - y_lo) // 4 + y_lo
    thrx_right = x_hi - (x_hi - x_lo) // 4
    thry_down = y_hi - (y_hi - y_lo) // 4

    # Remaining points that lie in the up, left, right and down bands.
    out_x = approx_out[:, 0, 0]
    out_y = approx_out[:, 0, 1]
    values_up = np.where(out_y < thry_up)[0]
    values_left = np.where(out_x < thrx_left)[0]
    values_right = np.where(out_x > thrx_right)[0]
    values_down = np.where(out_y > thry_down)[0]

    # Reconstruct the missing corner.
    return find_last_corner(approx_out, empty_corner, A, B, C, D,
                            values_up, values_left, values_right, values_down)

# Function to find the last corner
def find_last_corner(approx_out, empty_corner, A, B, C, D, up, left, right, down):
    """Reconstruct the one missing corner as the intersection of two edges.

    For the missing corner, two lines are built and intersected:

    * a vertical-side line from the neighbouring corner on that side through
      the side point closest to the missing corner (extreme y), and
    * a horizontal-side line from the neighbouring corner on that side
      through the edge point closest to the missing corner (extreme x).

    Args:
        approx_out: Point array indexed as ``approx_out[i, 0] -> (x, y)``.
        empty_corner (str): "A", "B", "C" or "D"; any other value leaves the
            corners untouched.
        A, B, C, D: Current corners (left-up, right-up, right-down,
            left-down); the missing one is replaced.
        up, left, right, down: Index arrays into ``approx_out`` for the
            points along each side.

    Returns:
        tuple: ``(A, B, C, D)`` with the missing corner filled in as a tuple
        of two ints.

    Raises:
        SystemExit: With status 1, after printing a message, when one of the
            two required sides has no points or the two lines do not
            intersect (parallel edges).
    """
    # missing corner -> (side indices, y picker, side anchor,
    #                    edge indices, x picker, edge anchor)
    plans = {
        # left side from D up to its topmost point; top edge from B to its leftmost point
        "A": (left, np.argmin, D, up, np.argmin, B),
        # right side from C up to its topmost point; top edge from A to its rightmost point
        "B": (right, np.argmin, C, up, np.argmax, A),
        # right side from B down to its lowest point; bottom edge from D to its rightmost point
        "C": (right, np.argmax, B, down, np.argmax, D),
        # left side from A down to its lowest point; bottom edge from C to its leftmost point
        "D": (left, np.argmax, A, down, np.argmin, C),
    }
    plan = plans.get(empty_corner)
    if plan is None:
        return A, B, C, D

    side, pick_y, side_anchor, edge, pick_x, edge_anchor = plan
    if side.size == 0 or edge.size == 0:
        print("Please, try again with a better picture")
        # sys.exit instead of the interactive helper exit(): always available
        # and reports the failure through a non-zero status.
        sys.exit(1)

    side_point = approx_out[side[pick_y(approx_out[side, 0, 1])], 0]
    edge_point = approx_out[edge[pick_x(approx_out[edge, 0, 0])], 0]

    try:
        hit = line_intersection(side_anchor, side_point, edge_anchor, edge_point)
    except np.linalg.LinAlgError:
        # Parallel or degenerate edges: the corner cannot be reconstructed.
        print("Please, try again with a better picture")
        sys.exit(1)

    corners = {"A": A, "B": B, "C": C, "D": D}
    corners[empty_corner] = (int(hit[0]), int(hit[1]))
    return corners["A"], corners["B"], corners["C"], corners["D"]

# Function to get the transform matrix
def get_transform_matrix(current_w, track_cnt, im):
    """Find the tracking rectangle's corners and build the perspective matrix.

    The contour is simplified with ``cv.approxPolyDP``. Its points are then
    sorted into four corner regions, each the outer quarter of the bounding
    box in both x and y, and compute_corners() turns them into the corners
    A (left-up), B (right-up), C (right-down) and D (left-down). The matrix
    maps those corners onto a ``track_w`` x ``track_h`` rectangle with a 16:9
    aspect ratio.

    Args:
        current_w: Width used for the output rectangle.
        track_cnt: Contour of the tracking rectangle.
        im: Image on which the approximation and the corners are drawn.

    Returns:
        tuple: ``(corners, M, track_w, track_h)`` where ``corners`` is
        ``np.array([A, B, C, D])``.

    Side effects:
        Draws on ``im`` and assigns the module-level ``M``.

    Raises:
        SystemExit: With status 1, after printing a message, when more than
            one corner region contains no points.
    """
    global M  # the matrix is shared with get_transform_rect()

    track_w = current_w
    track_h = (track_w * 9) // 16  # 16:9 aspect ratio

    # Shape approximation of the tracking rectangle (tolerance: 0.1% of the
    # contour perimeter), drawn for visual inspection.
    approx = cv.approxPolyDP(track_cnt, 0.001 * cv.arcLength(track_cnt, True), True)
    cv.drawContours(im, approx, -1, (0, 255, 0), 1)

    # Bounding box of the approximation, computed once.
    xs = approx[:, 0, 0]
    ys = approx[:, 0, 1]
    x_lo, x_hi = xs.min(), xs.max()
    y_lo, y_hi = ys.min(), ys.max()

    # A point belongs to a side band when it lies in the outer quarter.
    thrx_left = (x_hi - x_lo) // 4 + x_lo
    thry_up = (y_hi - y_lo) // 4 + y_lo
    thrx_right = x_hi - (x_hi - x_lo) // 4
    thry_down = y_hi - (y_hi - y_lo) // 4
    is_left, is_right = xs < thrx_left, xs > thrx_right
    is_up, is_down = ys < thry_up, ys > thry_down

    # Indices of the points in each corner region.
    corner_points = {
        "A": np.where(is_left & is_up)[0],
        "B": np.where(is_right & is_up)[0],
        "C": np.where(is_right & is_down)[0],
        "D": np.where(is_left & is_down)[0],
    }

    # At most one corner may be missing; it is then rebuilt from the edges.
    missing = [name for name in "ABCD" if len(corner_points[name]) == 0]
    if len(missing) > 1:
        print("Please, try again with a better picture")
        # sys.exit instead of the interactive helper exit(): always available
        # and reports the failure through a non-zero status.
        sys.exit(1)
    empty_corner = missing[0] if missing else "None"

    A, B, C, D = compute_corners(approx, corner_points["A"], corner_points["B"],
                                 corner_points["C"], corner_points["D"], empty_corner)

    # Draw the corners; centres are passed as plain int tuples so the call
    # does not depend on OpenCV accepting NumPy values as points.
    for corner in (A, B, C, D):
        cv.circle(im, (int(corner[0]), int(corner[1])), 5, (255, 0, 225), -1)

    # Perspective transform from the detected corners to the output rectangle.
    pts1 = np.float32([A, B, C, D])
    pts2 = np.float32([[0, 0], [track_w, 0], [track_w, track_h], [0, track_h]])

    M = cv.getPerspectiveTransform(pts1, pts2)
    corners = np.array([A, B, C, D])

    return corners, M, track_w, track_h

# Function to print CPU usage
def print_cpu_usage():
    """Print the CPU usage percentage measured by psutil over a 1 s interval."""
    print(f"CPU Usage: {psutil.cpu_percent(interval=1):.2f}%")

# Function to get the latest record with status 1 via PHP
def get_latest_record(timeout=30):
    """Fetch the next record to process from the PHP backend.

    Calls ``get_last_image_to_process.php`` and expects a JSON object with a
    truthy ``success`` flag and a ``record`` entry.

    Args:
        timeout (float): Seconds to wait for the server before giving up, so
            an unresponsive server cannot block the script indefinitely.

    Returns:
        The ``record`` value from the response, or None when there is no
        pending record or the request/decoding failed (a message is printed).
    """
    url = f"{SERVER_URL}/get_last_image_to_process.php"
    try:
        with urllib.request.urlopen(url, timeout=timeout) as response:
            data = json.loads(response.read().decode('utf-8'))
    except Exception as e:
        # Best effort: network, HTTP and decoding problems all mean "no record".
        print(f"Error fetching record: {e}")
        return None

    if not isinstance(data, dict):
        print(f"Error fetching record: unexpected response {data!r}")
        return None

    if data.get('success') and data.get('record'):
        return data['record']

    print(f"No pending records found or error: {data.get('message', 'Unknown error')}")
    return None

# Function to update record status after processing via PHP
def update_record_status(tb_id, status, timeout=30):
    """Set the processing status of a record through the PHP backend.

    Calls ``update_processed_image.php`` with ``tb_id`` and ``status`` as
    query parameters. Callers in this script use status 3 for "processed"
    and 4 for "error".

    Args:
        tb_id: Identifier of the record to update.
        status: New status value.
        timeout (float): Seconds to wait for the server before giving up.

    Returns:
        bool: True when the server reports success, False otherwise
        (including any request or decoding failure; a message is printed).
    """
    try:
        # urlencode escapes the values, so an unusual id cannot break the URL.
        query = urllib.parse.urlencode([("tb_id", tb_id), ("status", status)])
        url = f"{SERVER_URL}/update_processed_image.php?{query}"

        with urllib.request.urlopen(url, timeout=timeout) as response:
            data = json.loads(response.read().decode('utf-8'))

        if data.get('success'):
            print(f"Record status updated to {status}")
            return True

        print(f"Failed to update record status: {data.get('error', 'Unknown error')}")
        return False
    except Exception as e:
        # Best effort: this is called from error paths and must not raise.
        print(f"Error updating record: {e}")
        return False

# Function to download image from server
def download_image(image_id, timeout=60):
    """Download ``<image_id>.jpg`` from the server into ``downloaded_images/``.

    Args:
        image_id: Identifier of the image; used both in the URL and as the
            local file name.
        timeout (float): Seconds to wait for the server before giving up.

    Returns:
        str | None: Path of the downloaded file, or None on any failure
        (a message is printed).
    """
    try:
        name = str(image_id)
        # The id comes from the server response; refuse ids that would write
        # outside the download directory.
        if os.path.basename(name) != name:
            raise ValueError(f"unsafe image id: {image_id!r}")

        # Percent-encode the id so special characters cannot alter the URL.
        image_url = f"{SERVER_URL}/images/{urllib.parse.quote(name, safe='')}.jpg"
        print(f"Downloading image from: {image_url}")

        # Create downloaded_images directory if it doesn't exist
        download_dir = "downloaded_images"
        os.makedirs(download_dir, exist_ok=True)

        # Set the path for the downloaded image
        local_path = os.path.join(download_dir, f"{name}.jpg")

        # Fetch the whole body first, then write it, so a failed or
        # interrupted request does not leave a truncated file behind.
        with urllib.request.urlopen(image_url, timeout=timeout) as response:
            payload = response.read()
        with open(local_path, "wb") as image_file:
            image_file.write(payload)

        print(f"Image downloaded to: {local_path}")
        return local_path
    except Exception as e:
        print(f"Error downloading image: {e}")
        return None

# NEW FUNCTION: Upload JSON to server
def upload_json_file(json_data, image_id, timeout=30):
    """POST the generated JSON to the PHP backend.

    Sends ``json_data`` as the request body to ``set_json.php`` with the
    image id in the ``id`` query parameter.

    Args:
        json_data (str): JSON text to upload.
        image_id: Identifier the server should store the JSON under.
        timeout (float): Seconds to wait for the server before giving up.

    Returns:
        bool: True when the server reports success, False otherwise
        (including any request or decoding failure; a message is printed).
    """
    try:
        # urlencode escapes the id, so an unusual value cannot break the URL.
        query = urllib.parse.urlencode([("id", image_id)])
        url = f"{SERVER_URL}/set_json.php?{query}"

        # Prepare the request
        req = urllib.request.Request(
            url,
            data=json_data.encode('utf-8'),
            headers={'Content-Type': 'application/json'},
            method='POST'
        )

        # Send the request
        with urllib.request.urlopen(req, timeout=timeout) as response:
            data = json.loads(response.read().decode('utf-8'))

        if data.get('success'):
            print(f"JSON data uploaded successfully as {image_id}.json")
            return True

        print(f"Failed to upload JSON data: {data.get('error', 'Unknown error')}")
        return False
    except Exception as e:
        print(f"Error uploading JSON data: {e}")
        return False

# Start the timer
start_time = time.time()

# Main execution: detect the tracking rectangle in a picture, project every
# other detected rectangle into its 16:9 frame and export them as JSON.
if __name__ == "__main__":
    # ---- Select the input -------------------------------------------------
    if len(sys.argv) > 1:
        # If image path is provided, use it directly
        img_path = sys.argv[1]
        json_name = sys.argv[2] if len(sys.argv) > 2 else "output.json"
        print(f"Using provided image: {img_path}")
        print(f"Output JSON will be: {json_name}")
        record_id = None  # None marks "not processing a database record"
    else:
        # If no image path is provided, get the latest record from database via PHP
        record = get_latest_record()

        if record is None:
            print("No pending records found.")
            sys.exit(1)

        # Get image from record ID
        img_id = record['id']
        record_id = record['tb_id']

        print(f"Processing record ID: {record_id}, Image ID: {img_id}")

        # Download the image from the server
        img_path = download_image(img_id)
        if img_path is None:
            print(f"Error: Could not download image for ID {img_id}")
            update_record_status(record_id, 4)  # Status 4 = error
            sys.exit(1)

        # Set JSON output path
        json_name = f"config_json/{img_id}.json"

    # ---- Read the image ---------------------------------------------------
    try:
        im = cv.imread(img_path)
        # Explicit check instead of assert, which is stripped under -O.
        if im is None:
            raise ValueError("Image not found or could not be read")
        im = cv.medianBlur(im, 5)
    except Exception as e:
        print(f"Error loading image: {e}")
        if record_id is not None:
            update_record_status(record_id, 4)  # Status 4 = error
        sys.exit(1)

    # ---- Detect and project the rectangles --------------------------------
    # Any failure in this section flags the record as failed, so it is not
    # left pending and picked up again on every run.
    try:
        # Convert image to grayscale
        imgray = cv.cvtColor(im, cv.COLOR_BGR2GRAY)

        # Apply an adaptive threshold to the grayscale image
        thresh = cv.adaptiveThreshold(imgray, 255, cv.ADAPTIVE_THRESH_GAUSSIAN_C, cv.THRESH_BINARY, 11, 2)

        # Find contours in the thresholded image
        contours, hierarchy = cv.findContours(thresh, cv.RETR_TREE, cv.CHAIN_APPROX_NONE)

        x = []
        y = []
        w = []
        h = []

        im_height, im_width = im.shape[:2]

        # The widest accepted contour is taken as the tracking rectangle.
        track_cnt = None
        current_w = 0
        for cnt in contours:
            # Extent of the contour; plain ints keep later arithmetic and
            # OpenCV size arguments free of NumPy scalar types.
            cnt_xs = cnt[:, 0, 0]
            cnt_ys = cnt[:, 0, 1]
            x_cnt = int(cnt_xs.min())
            y_cnt = int(cnt_ys.min())
            w_cnt = int(cnt_xs.max()) - x_cnt
            h_cnt = int(cnt_ys.max()) - y_cnt

            # Filter out small contours
            if w_cnt * h_cnt <= 100:
                continue

            # Filter full image contour
            if w_cnt >= (im_width - 10) or h_cnt >= (im_height - 10):
                continue

            x.append(x_cnt)
            y.append(y_cnt)
            w.append(w_cnt)
            h.append(h_cnt)

            if w_cnt > current_w:
                current_w = w_cnt
                track_cnt = cnt

        if track_cnt is None:
            # No contour survived the filters: nothing to track.
            print("Please, try again with a better picture")
            sys.exit(1)

        # Detect tracking rect
        cv.drawContours(im, track_cnt, -1, (0, 0, 255), 1)
        corners, M, track_w, track_h = get_transform_matrix(current_w, track_cnt, im)
        dst = cv.warpPerspective(im, M, (track_w, track_h))

        # Remove tracking rect from the list of rectangles
        track_index = w.index(current_w)
        for column in (x, y, w, h):
            column.pop(track_index)

        # Drop rectangles almost as large as the tracking rectangle itself
        keep = [k for k in range(len(x))
                if w[k] < track_w - 100 and h[k] < track_h - 100]
        x = [x[k] for k in keep]
        y = [y[k] for k in keep]
        w = [w[k] for k in keep]
        h = [h[k] for k in keep]

        # Draw the tracking rectangle
        top_left = (int(corners[0][0]), int(corners[0][1]))
        cv.rectangle(im, top_left, (top_left[0] + track_w, top_left[1] + track_h), (255, 0, 0), 2)

        # Delete rectangles that are outside tracking rectangle
        x, y, w, h = find_outside_rects(x, y, w, h, corners[0][0], corners[0][1], track_w, track_h)

        # Delete inside rectangles
        x, y, w, h = find_inside_rects(x, y, w, h)

        # Unify superposed rectangles
        x, y, w, h = unify_rects(x, y, w, h)

        # Draw each projected rectangle and build the items, with positions
        # and sizes normalised to the tracking rectangle
        items = []
        for rect in zip(x, y, w, h):
            tx, ty, tw, th = get_transform_rect(*rect)
            cv.rectangle(dst, (tx, ty), (tx + tw, ty + th), (0, 255, 0), 2)
            items.append(Item(
                objectType=1,
                position=np.array([float(tx / track_w), float(ty / track_h), 0]),
                rotation=np.array([0, 0, 0]),
                scale=np.array([1, 1, 1]),
                size=np.array([float(tw / track_w), float(th / track_h)])
            ))

        # Create JSON data
        items_dict = [
            {
                "objectType": item.objectType,
                "position": item.position.tolist(),  # Convert NumPy array to list
                "rotation": item.rotation.tolist(),
                "scale": item.scale.tolist(),
                "size": item.size.tolist()
            }
            for item in items
        ]
        json_data = json.dumps(items_dict, indent=4)
    except SystemExit as exit_request:
        # A helper (or the check above) already printed why it gave up.
        if record_id is not None:
            update_record_status(record_id, 4)  # Status 4 = error
        # Never report a failed run as success (exit code 0).
        sys.exit(exit_request.code or 1)
    except Exception:
        if record_id is not None:
            update_record_status(record_id, 4)  # Status 4 = error
        raise  # keep the traceback for debugging

    # ---- Save locally and, for database records, upload --------------------
    try:
        # Make sure the output directory exists for local save
        output_dir = os.path.dirname(json_name)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        # Save locally
        with open(str(json_name), "w") as json_file:
            json_file.write(json_data)

        print(f"JSON saved locally to {json_name}")

        # If processing a record from the database, upload JSON to server and update status
        if record_id is not None:
            upload_success = upload_json_file(json_data, img_id)
            print(f"JSON upload to server {'successful' if upload_success else 'failed'}")

            # Update record status regardless of upload success
            update_record_status(record_id, 3)  # Status 3 = processed

    except Exception as e:
        print(f"Error saving/uploading JSON: {e}")
        if record_id is not None:
            update_record_status(record_id, 4)  # Status 4 = error
        sys.exit(1)

    # End the timer
    end_time = time.time()

    # Calculate and print the execution time
    execution_time = end_time - start_time
    print(f"Execution time: {execution_time:.2f} seconds")
    # Report CPU usage at the end of the run
    print_cpu_usage()