import argparse
import os
import re
import SimpleITK as sitk
import numpy as np

def find_subfolders(base_path):
    """Locate the image and label subfolders directly inside *base_path*.

    A subfolder counts as the image folder if its name starts with
    ``"images"`` and as the label folder if its name starts with ``"labels"``.
    Plain files with those prefixes are ignored.

    Args:
        base_path: Directory whose immediate children are scanned.

    Returns:
        A ``(image_dir, label_dir)`` tuple of full paths.

    Raises:
        ValueError: If either subfolder is missing.
    """
    image_dir, label_dir = None, None
    # os.listdir returns entries in arbitrary order. Scan them sorted and keep
    # the first match per prefix, so that the same folder is picked on every
    # run when several folders share a prefix (e.g. "imagesTr" and "imagesTs").
    for entry in sorted(os.listdir(base_path)):
        full_path = os.path.join(base_path, entry)
        if not os.path.isdir(full_path):
            continue
        if entry.startswith("images"):
            if image_dir is None:
                image_dir = full_path
        elif entry.startswith("labels"):
            if label_dir is None:
                label_dir = full_path
    if image_dir is None or label_dir is None:
        raise ValueError("Both 'images*' and 'labels*' subfolders must be found in the input path.")
    return image_dir, label_dir

def load_volume(path):
    """Read the image file at *path* with SimpleITK and return it as an ``sitk.Image``."""
    image = sitk.ReadImage(path)
    return image

def rotate_volume(volume):
    """Flip *volume*'s voxel grid along two axes and give it a fixed geometry.

    The pixel array is mirrored along the first two axes of the NumPy array
    returned by ``sitk.GetArrayFromImage``. That array orders its axes
    (z, y, x), so the flips are along the image's k (slice) and j (row) index
    axes. The i (column) axis is left as-is.

    The returned image keeps the input's spacing. Its origin is set to
    (0, 0, 0) and its direction matrix is replaced by diag(-1, -1, 1). The
    input's origin, direction and metadata dictionary are not carried over.

    Args:
        volume: A 3-D ``sitk.Image`` with scalar or multi-component pixels.

    Returns:
        A new ``sitk.Image``. *volume* itself is not modified.
    """
    array = sitk.GetArrayFromImage(volume)
    # Flip z (slices) and y (rows); no flip is applied along x.
    flipped = np.flip(array, axis=(0, 1))
    # For multi-component images, GetArrayFromImage appends a trailing
    # component axis. isVector folds it back into the pixel type instead of
    # treating it as a fourth spatial dimension. For scalar images this is
    # the same as the default (isVector=False).
    is_vector = volume.GetNumberOfComponentsPerPixel() > 1
    rotated_volume = sitk.GetImageFromArray(flipped, isVector=is_vector)

    # Flipping does not permute axes, so the per-axis spacing still applies.
    rotated_volume.SetSpacing(volume.GetSpacing())
    # The input origin is not reused: after flipping, it no longer marks the
    # first voxel. The grid is placed at the world origin instead.
    rotated_volume.SetOrigin((0.0, 0.0, 0.0))

    # The direction matrix's columns give the world direction of each index
    # axis in ITK's LPS frame (+x = patient Left, +y = Posterior,
    # +z = Superior). This matrix makes i advance toward Right, j toward
    # Anterior and k toward Superior. ITK's DICOMOrientImageFilter calls this
    # orientation "RAS".
    direction = (-1.0, 0.0, 0.0,
                 0.0, -1.0, 0.0,
                 0.0, 0.0, 1.0)
    rotated_volume.SetDirection(direction)

    return rotated_volume

def process_folder(input_folder, output_folder, max_files=None):
    """Reorient every .mha / .nii / .nii.gz file found under *input_folder*.

    The directory tree of *input_folder* is mirrored into *output_folder*.
    Each matching file is loaded, passed through ``rotate_volume`` and written
    under the same relative path and file name. Directories and files are
    visited in sorted order, so the subset selected by *max_files* is
    reproducible.

    Args:
        input_folder: Root directory searched recursively.
        output_folder: Destination root; created if missing.
        max_files: Stop after this many volumes have been written. ``None``
            means no limit.

    Returns:
        The number of volumes written.
    """
    os.makedirs(output_folder, exist_ok=True)
    volume_re = re.compile(r'\.(mha|nii|nii\.gz)$', re.IGNORECASE)
    limit_label = max_files if max_files is not None else '∞'
    output_real = os.path.realpath(output_folder)
    processed_count = 0

    for root, dirs, files in os.walk(input_folder):
        if max_files is not None and processed_count >= max_files:
            break

        # Sorting dirs in place steers os.walk's traversal order. Pruning the
        # output tree prevents a runaway walk when output_folder lives inside
        # input_folder: otherwise each newly created output directory would be
        # walked and mirrored again.
        dirs[:] = sorted(
            d for d in dirs
            if os.path.realpath(os.path.join(root, d)) != output_real
        )

        rel_path = os.path.relpath(root, input_folder)
        out_subdir = os.path.join(output_folder, rel_path)
        os.makedirs(out_subdir, exist_ok=True)

        for fname in sorted(files):
            if max_files is not None and processed_count >= max_files:
                break
            if not volume_re.search(fname):
                continue

            in_path = os.path.join(root, fname)
            out_path = os.path.join(out_subdir, fname)
            volume = load_volume(in_path)
            rotated = rotate_volume(volume)
            sitk.WriteImage(rotated, out_path)
            processed_count += 1
            print(f"Processed {processed_count}/{limit_label}: {fname}")

    return processed_count

def main():
    """Command-line entry point: reorient a dataset's image and label volumes.

    Expects *input_path* to contain an ``images*`` and a ``labels*``
    subfolder. Each is processed with ``process_folder`` into a same-named
    folder under *output_path*. The optional *max_files* limit is applied
    separately to each folder.
    """
    parser = argparse.ArgumentParser(description="Fix ToothFairy radiological orientation.")
    parser.add_argument("input_path", type=str, help="Input folder containing 'images*' and 'labels*' subfolders.")
    parser.add_argument("output_path", type=str, help="Output folder to store transformed volumes.")
    parser.add_argument("max_files", type=int, nargs='?', default=None, help="Maximum number of files to process from each folder (optional).")
    args = parser.parse_args()
    # A negative limit would silently process nothing; reject it up front.
    if args.max_files is not None and args.max_files < 0:
        parser.error("max_files must be a non-negative integer")

    image_dir, label_dir = find_subfolders(args.input_path)
    out_image_dir = os.path.join(args.output_path, os.path.basename(image_dir))
    out_label_dir = os.path.join(args.output_path, os.path.basename(label_dir))

    print(f"Processing images from: {image_dir}")
    # Use an explicit None check: a limit of 0 is still a limit and should be
    # reported.
    if args.max_files is not None:
        print(f"Limiting to {args.max_files} files per folder")

    images_processed = process_folder(image_dir, out_image_dir, args.max_files)
    print(f"Processed {images_processed} images")

    print(f"Processing labels from: {label_dir}")
    labels_processed = process_folder(label_dir, out_label_dir, args.max_files)
    print(f"Processed {labels_processed} labels")

    print("All volumes processed successfully.")

# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
