import logging
import json
import os
import shutil

from PIL import Image
import numpy as np
import pandas as pd
import SimpleITK as sitk

ROOT_DIR = "/mnt/t24/Public/kimo/"
RAW_DIR = f"{ROOT_DIR}/raw/"
DICOM_ROOT = f"{RAW_DIR}/DICOM/"
imagesTs_DIR = f"{RAW_DIR}/Dataset2602_BraTS-CK/imagesTs"
POSTPROCESSED_DIR = f"{RAW_DIR}/ensemble_predictions_postprocessed/"
FINAL_DIR = f"{ROOT_DIR}/final/"
NII_JSON_PATH = f"{RAW_DIR}/nii.json"

# Patient hashes already exported by main(). Only the first stem per hash (in
# sorted stem order) is exported -- presumably the earliest study if the date
# part sorts chronologically; TODO confirm the date format.
CASE_SET = set()


def copytree(src, dst):
    """Copy every file found anywhere under *src* into the flat directory *dst*.

    Unlike :func:`shutil.copytree`, sub-directory structure is NOT preserved:
    all files land directly in *dst*. When two files in different
    sub-directories share a name, the later one overwrites the earlier one;
    a warning is logged for each such collision.

    Args:
        src: Directory to walk recursively.
        dst: Destination directory (created if missing).
    """
    os.makedirs(dst, exist_ok=True)
    copied = {}  # file name -> source path already written to dst
    for root, _dirs, files in os.walk(src):
        for name in files:
            src_file = os.path.join(root, name)
            dst_file = os.path.join(dst, name)
            if name in copied:
                logging.warning(
                    "copytree: %s overwrites %s in %s", src_file, copied[name], dst
                )
            shutil.copyfile(src_file, dst_file)
            copied[name] = src_file


def get_largest_connected_component(binary_image):
    """Extract the largest connected component from a binary SimpleITK image.

    Args:
        binary_image (sitk.Image): Image whose non-zero voxels are foreground.

    Returns:
        sitk.Image: Binary image (1 inside, 0 outside) containing only the
        largest connected component; all zeros if there is no foreground.
    """
    # Give each connected component its own integer label (1, 2, 3, ...).
    label_image = sitk.ConnectedComponent(binary_image)
    # Relabel by size so that the largest component becomes label 1.
    relabel_image = sitk.RelabelComponentImageFilter().Execute(label_image)
    # Keep label 1 only.
    return sitk.BinaryThreshold(
        relabel_image, lowerThreshold=1, upperThreshold=1, insideValue=1, outsideValue=0
    )


def rescale_coronal_sagittal(arr_cor, arr_sag, scale, scale_sag=None):
    """Stretch the height (axis 0) of a coronal and a sagittal slice.

    Used to correct for anisotropic voxels: slice rows run along z, so the
    height is multiplied by (z spacing / column spacing). Widths are kept.

    Args:
        arr_cor: Coronal slice array, shape (H, W) or (H, W, C).
        arr_sag: Sagittal slice array, shape (H, W) or (H, W, C).
        scale: Height factor for the coronal slice, and also for the sagittal
            slice when *scale_sag* is None.
        scale_sag: Optional separate height factor for the sagittal slice.

    Returns:
        tuple[np.ndarray, np.ndarray]: The resized (coronal, sagittal) arrays.
    """
    if scale_sag is None:
        scale_sag = scale

    def _stretch(arr, factor):
        # Resize only the height with bilinear interpolation; clamp to >= 1
        # row because PIL rejects zero-sized images.
        img = Image.fromarray(arr)
        width, height = img.size
        new_height = max(1, int(height * factor))
        return np.array(img.resize((width, new_height), Image.Resampling.BILINEAR))

    return _stretch(arr_cor, scale), _stretch(arr_sag, scale_sag)


def pad_to_h(arr, h):
    """Zero-pad *arr* along axis 0 up to height *h*, keeping content centred.

    Extra rows are split between top and bottom (the bottom gets the odd
    row). Arrays already at least *h* rows tall are returned unchanged.
    Works for arrays of any rank >= 1.
    """
    pad_h = h - arr.shape[0]
    if pad_h <= 0:
        return arr
    pad_top = pad_h // 2
    pad_width = [(pad_top, pad_h - pad_top)] + [(0, 0)] * (arr.ndim - 1)
    return np.pad(arr, pad_width, mode='constant')


def contour_overlay(image_nii, label_nii, min_physical_size=1000):
    """Render QC views of a segmentation through its largest component.

    Both volumes are reoriented with ``sitk.DICOMOrient`` (default LPS). An
    axial, a coronal and a sagittal slice are taken through the centroid of
    the largest connected component of the label.

    Args:
        image_nii: Path to the image volume.
        label_nii: Path to the label volume (same grid size as the image;
            ``CopyInformation`` raises otherwise).
        min_physical_size: Minimum physical size of the largest component,
            as returned by ``LabelShapeStatisticsImageFilter.GetPhysicalSize``
            (product of spacing units, e.g. mm^3). Smaller cases are rejected.

    Returns:
        None if the label is empty or its largest component is too small.
        Otherwise a dict with:
            'contour_image': PIL image, one row of axial|coronal|sagittal
                views with the label contour drawn on the image.
            'overlay_image': PIL image, two such rows: the plain image on
                top and the filled label overlay below.
            'stats': the LabelShapeStatisticsImageFilter executed on the
                largest component (label 1).
    """
    label = sitk.DICOMOrient(sitk.ReadImage(label_nii))
    largest = get_largest_connected_component(label)

    stats = sitk.LabelShapeStatisticsImageFilter()
    stats.Execute(largest)
    if not stats.GetNumberOfLabels():
        logging.warning(f"No label found in {label_nii}")
        return None
    if stats.GetPhysicalSize(1) < min_physical_size:
        logging.warning(f"Label size too small in {label_nii}")
        return None

    # Voxel index (x, y, z) of the largest component's centroid.
    center_idx = largest.TransformPhysicalPointToIndex(stats.GetCentroid(1))

    image = sitk.DICOMOrient(sitk.ReadImage(image_nii))
    image = sitk.Cast(sitk.RescaleIntensity(image), sitk.sitkUInt8)

    # Force identical geometry so the overlay filters accept the pair even
    # when origins/directions differ by floating-point noise.
    label.CopyInformation(image)
    contour = sitk.LabelOverlay(image, sitk.LabelContour(label), opacity=1)
    overlay = sitk.LabelOverlay(image, label)

    # Coronal slices have x columns and sagittal slices y columns, so each
    # needs its own z-stretch to display with the correct aspect ratio.
    sx, sy, sz = image.GetSpacing()
    scale_cor = sz / sx
    scale_sag = sz / sy

    def extract_slices(img_3d):
        """Return axial | coronal | sagittal slices side by side as one RGB array."""
        ix, iy, iz = center_idx
        views = [
            sitk.GetArrayFromImage(img_3d[:, :, iz]),
            sitk.GetArrayFromImage(img_3d[:, iy, :]),
            sitk.GetArrayFromImage(img_3d[ix, :, :]),
        ]
        # Scalar (grey) slices become 3-channel so they can be concatenated
        # with the RGB output of LabelOverlay.
        views = [np.stack((v,) * 3, axis=-1) if v.ndim == 2 else v for v in views]
        axial, coronal, sagittal = views
        # With LPS orientation, array row 0 is the most inferior slice; flip
        # so superior is at the top.
        coronal, sagittal = rescale_coronal_sagittal(
            np.flip(coronal, axis=0), np.flip(sagittal, axis=0), scale_cor, scale_sag
        )
        max_h = max(axial.shape[0], coronal.shape[0], sagittal.shape[0])
        return np.concatenate(
            [pad_to_h(view, max_h) for view in (axial, coronal, sagittal)], axis=1
        )

    row_con = extract_slices(contour)
    row_orig = extract_slices(image)
    row_ov = extract_slices(overlay)

    return {
        'contour_image': Image.fromarray(row_con),
        'overlay_image': Image.fromarray(np.concatenate([row_orig, row_ov], axis=0)),
        'stats': stats,
    }


# One row per exported case; written to FINAL_DIR/datalist.xlsx by main().
DATALIST = []


def main():
    """Export QC renders and NIfTI copies for every usable case in NII_JSON_PATH.

    NII_JSON_PATH maps each stem ("<hash>-<date>") to a DICOM path whose
    first component is used as the output category. For each stem, in
    sorted order, the first usable stem per hash is exported to
    FINAL_DIR/<category>/<stem>/. A summary spreadsheet is written to
    FINAL_DIR/datalist.xlsx.
    """
    log_format = '%(asctime)s [%(filename)s:%(lineno)d] %(message)s'
    log_path = os.path.splitext(__file__)[0] + '.log'
    logging.basicConfig(
        level=logging.INFO,
        format=log_format,
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler(log_path, encoding='utf-8'),
        ],
    )

    with open(NII_JSON_PATH, 'r', encoding='utf-8') as f:
        nii_dict = json.load(f)

    for stem, rel_dicomdir in sorted(nii_dict.items()):
        logging.info(f"{stem} {rel_dicomdir}")

        parts = stem.split("-")
        if len(parts) != 2:
            logging.warning(f"Unexpected stem format (expected '<hash>-<date>'): {stem}")
            continue
        case_hash = parts[0]
        if case_hash in CASE_SET:
            continue

        category = rel_dicomdir.split("/")[0]
        image_nii = f"{imagesTs_DIR}/{stem}_0000.nii.gz"
        label_nii = f"{POSTPROCESSED_DIR}/{stem}.nii.gz"
        category_dir = f"{FINAL_DIR}/{category}"
        dest_dir = f"{category_dir}/{stem}"

        if not os.path.exists(image_nii):
            logging.warning(f"Missing image {image_nii}")
            continue
        if not os.path.exists(label_nii):
            logging.warning(f"Missing label {label_nii}")
            continue

        con_ov = contour_overlay(image_nii, label_nii)
        if con_ov is None:
            continue

        os.makedirs(dest_dir, exist_ok=True)
        con_ov['contour_image'].save(f"{dest_dir}/contour.png")
        con_ov['overlay_image'].save(f"{category_dir}/{stem}.png")
        shutil.copyfile(image_nii, f"{dest_dir}/image.nii.gz")
        shutil.copyfile(label_nii, f"{dest_dir}/label.nii.gz")
        # Disabled: also export the source DICOM files (flattened).
        # copytree(f"{DICOM_ROOT}/{rel_dicomdir}", f"{dest_dir}/DICOM")
        CASE_SET.add(case_hash)

        stats = con_ov['stats']
        DATALIST.append({
            "Category": category,
            "Stem": stem,
            "EquivalentEllipsoidDiameter": stats.GetEquivalentEllipsoidDiameter(1),
            "PhysicalSize": stats.GetPhysicalSize(1),
        })

    # FINAL_DIR may not exist yet if no case was exported.
    os.makedirs(FINAL_DIR, exist_ok=True)
    pd.DataFrame(DATALIST).to_excel(f"{FINAL_DIR}/datalist.xlsx", index=False)


if __name__ == "__main__":
    main()