import argparse
import os
import re
import SimpleITK as sitk
import numpy as np

def find_subfolders(base_path):
    """Locate the image and label subfolders directly under *base_path*.

    Only immediate subdirectories are considered. The first one (in sorted
    name order) whose name starts with ``images`` is the image folder, and
    the first one whose name starts with ``labels`` is the label folder.

    Args:
        base_path: directory to search.

    Returns:
        Tuple ``(image_dir, label_dir)``; each is ``base_path`` joined with
        the matching subfolder name.

    Raises:
        ValueError: if no ``images*`` or no ``labels*`` subfolder exists.
        OSError: propagated from ``os.listdir`` if *base_path* cannot be
            listed (e.g. it does not exist).
    """
    image_dir, label_dir = None, None
    # os.listdir() order is arbitrary; sort so that when several folders share
    # a prefix (e.g. 'imagesTr' and 'imagesTs') the same one is always chosen.
    for entry in sorted(os.listdir(base_path)):
        full_path = os.path.join(base_path, entry)
        if not os.path.isdir(full_path):
            continue
        if image_dir is None and entry.startswith("images"):
            image_dir = full_path
        elif label_dir is None and entry.startswith("labels"):
            label_dir = full_path
    if image_dir is None or label_dir is None:
        raise ValueError("Both 'images*' and 'labels*' subfolders must be found in the input path.")
    return image_dir, label_dir

def load_volume(path):
    """Read an image file from disk with SimpleITK.

    Args:
        path: location of any file format ``sitk.ReadImage`` understands.

    Returns:
        The SimpleITK image loaded from *path*.
    """
    image = sitk.ReadImage(path)
    return image

def rotate_volume(volume):
    """Rotate a volume's voxel data by 180 degrees about the z axis.

    The voxel array is reversed along its y and x axes. The origin,
    spacing and direction are then copied unchanged from the input
    (``CopyInformation``), so the data is moved relative to physical space
    while the image grid stays the same.

    Args:
        volume: SimpleITK image, assumed 3D. ``GetArrayFromImage`` returns
            it indexed ``(z, y, x)``, plus a trailing component axis for
            vector images.

    Returns:
        A new SimpleITK image of the same size with the input's geometry.
    """
    voxels = sitk.GetArrayFromImage(volume)  # (z, y, x[, component])
    # Reversing y and x together is a 180-degree in-plane rotation.
    flipped = np.flip(voxels, axis=(1, 2))
    # np.flip returns a negative-stride view; give SimpleITK a C-contiguous
    # buffer instead of relying on it to handle strided arrays.
    flipped = np.ascontiguousarray(flipped)
    # State the pixel kind explicitly rather than letting GetImageFromArray
    # guess "vector or not" from the array's rank.
    is_vector = volume.GetNumberOfComponentsPerPixel() > 1
    rotated_volume = sitk.GetImageFromArray(flipped, isVector=is_vector)
    rotated_volume.CopyInformation(volume)
    return rotated_volume

def process_folder(input_folder, output_folder):
    """Rotate every .mha / .nii / .nii.gz volume found under *input_folder*.

    The directory tree below *input_folder* is mirrored under
    *output_folder*. Every visited directory is created, even one with no
    volumes. Each file whose name matches the extensions
    (case-insensitive) is loaded, passed through ``rotate_volume`` and
    written to the same relative path. Other files are not copied.

    Args:
        input_folder: root directory to scan recursively.
        output_folder: destination root; created if missing.
    """
    volume_name = re.compile(r'\.(mha|nii|nii\.gz)$', re.IGNORECASE)
    os.makedirs(output_folder, exist_ok=True)
    output_real = os.path.realpath(output_folder)

    for root, dirs, files in os.walk(input_folder):
        # If the output tree lives inside the input tree, never descend into
        # it: otherwise already-rotated outputs would be picked up, rotated
        # again and written into an ever-deeper nested copy.
        dirs[:] = sorted(
            d for d in dirs
            if os.path.realpath(os.path.join(root, d)) != output_real
        )
        rel_path = os.path.relpath(root, input_folder)
        out_subdir = os.path.join(output_folder, rel_path)
        os.makedirs(out_subdir, exist_ok=True)

        for fname in sorted(files):
            if not volume_name.search(fname):
                continue
            in_path = os.path.join(root, fname)
            out_path = os.path.join(out_subdir, fname)
            volume = load_volume(in_path)
            rotated = rotate_volume(volume)
            sitk.WriteImage(rotated, out_path)

def main():
    """Command-line entry point: rotate every image and label volume.

    Finds the ``images*`` and ``labels*`` subfolders of the input path.
    Rotated copies of their volumes are written under the output path, in
    subfolders with the same names. Problems with the input path are
    reported as argparse usage errors (exit status 2) rather than
    tracebacks.
    """
    parser = argparse.ArgumentParser(description="Fix ToothFairy radiological orientation.")
    parser.add_argument("input_path", type=str, help="Input folder containing 'images*' and 'labels*' subfolders.")
    parser.add_argument("output_path", type=str, help="Output folder to store transformed volumes.")
    args = parser.parse_args()

    if not os.path.isdir(args.input_path):
        parser.error(f"input path is not a directory: {args.input_path}")
    try:
        image_dir, label_dir = find_subfolders(args.input_path)
    except ValueError as err:
        # parser.error() exits, so the names above are always bound below.
        parser.error(str(err))

    for kind, src_dir in (("images", image_dir), ("labels", label_dir)):
        print(f"Processing {kind} from: {src_dir}")
        # Keep the source subfolder name (e.g. 'images...') in the output.
        process_folder(src_dir, os.path.join(args.output_path, os.path.basename(src_dir)))

    print("All volumes processed successfully.")

# Run the command-line tool only when executed directly, not when imported.
if "__main__" == __name__:
    main()
