# tsha-mri-tumor-labeling/kimo/4-final-copy.py
# (223 lines, 6.7 KiB, Python)

import logging
import json
import os
import shutil
from PIL import Image
import numpy as np
import SimpleITK as sitk
# Filesystem layout of the labeling pipeline. Each derived path is built by
# plain concatenation onto its parent, which yields doubled slashes such as
# ".../kimo//raw/"; POSIX collapses repeated separators, so they still resolve.
ROOT_DIR = "/mnt/t24/Public/kimo/"
RAW_DIR = ROOT_DIR + "/raw/"
DICOM_ROOT = RAW_DIR + "/DICOM/"
imagesTs_DIR = RAW_DIR + "/Dataset2602_BraTS-CK/imagesTs"
POSTPROCESSED_DIR = RAW_DIR + "/ensemble_predictions_postprocessed/"
FINAL_DIR = ROOT_DIR + "/final/"
NII_JSON_PATH = RAW_DIR + "/nii.json"
# Case hashes already exported; main() skips any later stem with the same hash.
CASE_SET = set()
def copytree(src, dst, flatten=True):
    """Copy every file found under ``src`` into ``dst``.

    ``dst`` is created if needed and may already exist.

    Args:
        src: Directory to read recursively. If it does not exist, a warning
            is logged and ``dst`` is left empty.
        dst: Destination directory.
        flatten: When True (the default), all files are written directly
            into ``dst`` regardless of their subdirectory in ``src``. If two
            files in different subdirectories share a name, the later one is
            saved as ``<relative_subdir>_<name>`` (path separators replaced
            by ``_``) instead of overwriting the first, and a warning is
            logged. When False, the subdirectory structure of ``src`` is
            mirrored under ``dst``.
    """
    if not os.path.isdir(src):
        # os.walk() yields nothing for a missing path, which would otherwise
        # silently produce an empty destination.
        logging.warning(f"copytree: source directory not found: {src}")
    os.makedirs(dst, exist_ok=True)
    written = set()  # destination names produced so far in flatten mode
    for root, _dirs, files in os.walk(src):
        rel_dir = os.path.relpath(root, src)
        if flatten:
            target_dir = dst
        else:
            target_dir = os.path.normpath(os.path.join(dst, rel_dir))
            os.makedirs(target_dir, exist_ok=True)
        for file in files:
            name = file
            if flatten and name in written:
                # Same file name already copied from another subdirectory:
                # keep both instead of silently dropping one.
                name = f"{rel_dir.replace(os.sep, '_')}_{file}"
                logging.warning(
                    f"copytree: name clash for {file!r} in {root}; saved as {name!r}"
                )
            written.add(name)
            shutil.copyfile(os.path.join(root, file), os.path.join(target_dir, name))
def get_largest_connected_component(binary_image):
    """Keep only the largest connected blob of a binary SimpleITK mask.

    Args:
        binary_image (sitk.Image): Binary mask (pixel values 0 or 1).

    Returns:
        sitk.Image: Mask with 1 on the largest connected component and 0
        everywhere else (all zeros when the input has no foreground).
    """
    # Give every connected blob its own integer id (1, 2, 3, ...).
    blobs = sitk.ConnectedComponent(binary_image)
    # Renumber the ids by object size so that id 1 is the biggest blob.
    # (RelabelComponentImageFilter.GetSizeOfObjectsInPixels() would report
    # the sorted sizes if they are ever needed.)
    by_size = sitk.RelabelComponentImageFilter().Execute(blobs)
    # Map id 1 -> 1 and every other id (including background) -> 0.
    return sitk.BinaryThreshold(
        by_size,
        lowerThreshold=1,
        upperThreshold=1,
        insideValue=1,
        outsideValue=0,
    )
def rescale_coronal_sagittal(arr_cor, arr_sag, scale):
    """Stretch a coronal and a sagittal slice vertically by ``scale``.

    Widths are unchanged; each height becomes ``round(height * scale)``
    (at least 1 row), resampled bilinearly with PIL.

    Args:
        arr_cor: Coronal slice as an array accepted by ``PIL.Image.fromarray``.
        arr_sag: Sagittal slice, same requirements.
        scale: Vertical stretch factor (e.g. z-spacing / in-plane spacing).

    Returns:
        Tuple ``(coronal, sagittal)`` of resized numpy arrays.
    """
    def _stretch(arr):
        img = Image.fromarray(arr)
        width, height = img.size
        # round() rather than int() so e.g. 2.7 rows become 3, not 2; clamp
        # to 1 because PIL rejects a zero-height target size.
        new_height = max(1, round(height * scale))
        return np.array(img.resize((width, new_height), Image.Resampling.BILINEAR))

    return _stretch(arr_cor), _stretch(arr_sag)
def pad_to_h(arr, h):
    """Zero-pad ``arr`` along axis 0 so it is ``h`` rows tall.

    The padding is split between top and bottom, with the odd extra row (if
    any) going to the bottom. All other axes are left untouched, so arrays of
    any rank work, e.g. ``(H, W)`` grayscale or ``(H, W, C)`` colour.

    Args:
        arr: numpy array to pad.
        h: Target number of rows.

    Returns:
        The padded array, or ``arr`` itself (not a copy) when it already has
        ``h`` or more rows.
    """
    pad_h = h - arr.shape[0]
    if pad_h <= 0:
        return arr
    pad_top = pad_h // 2
    # Pad only the first axis; one (0, 0) entry per remaining axis.
    pad_width = ((pad_top, pad_h - pad_top),) + ((0, 0),) * (arr.ndim - 1)
    return np.pad(arr, pad_width, mode='constant')
def _extract_orthogonal_slices(img_3d, center_idx, scale_z):
    """Return axial | coronal | sagittal slices through ``center_idx`` side by side.

    ``img_3d`` is indexed SimpleITK-style as ``[x, y, z]`` and ``center_idx``
    is an ``(x, y, z)`` voxel index. Grayscale slices are replicated to three
    channels so they can be concatenated with RGB overlays. The coronal and
    sagittal views are flipped vertically and stretched by ``scale_z``, then
    all three views are zero-padded to a common height.
    """
    ix, iy, iz = center_idx
    arr_axial = sitk.GetArrayFromImage(img_3d[:, :, iz])
    arr_coronal = sitk.GetArrayFromImage(img_3d[:, iy, :])
    arr_sagittal = sitk.GetArrayFromImage(img_3d[ix, :, :])
    if arr_axial.ndim == 2:
        # Grayscale -> 3 identical channels so rows stack with LabelOverlay output.
        arr_axial, arr_coronal, arr_sagittal = (
            np.stack((a,) * 3, axis=-1) for a in (arr_axial, arr_coronal, arr_sagittal)
        )
    # Array rows of the coronal/sagittal slices run along z; flip so the
    # highest z index is drawn at the top.
    arr_coronal = np.flip(arr_coronal, axis=0)
    arr_sagittal = np.flip(arr_sagittal, axis=0)
    arr_coronal, arr_sagittal = rescale_coronal_sagittal(arr_coronal, arr_sagittal, scale_z)
    views = (arr_axial, arr_coronal, arr_sagittal)
    max_h = max(v.shape[0] for v in views)
    return np.concatenate([pad_to_h(v, max_h) for v in views], axis=1)


def contour_overlay(image_nii, label_nii):
    """Render orthogonal QA previews of a segmentation on top of its image.

    Both volumes are read and reoriented with ``sitk.DICOMOrient``. The three
    orthogonal slices pass through the centroid of the label's largest
    connected component. The contour and the overlay themselves are drawn
    from the full label.

    Args:
        image_nii: Path to the image volume.
        label_nii: Path to the label volume. It must have the same size as
            the image; its geometry is overwritten with the image's via
            ``CopyInformation``.

    Returns:
        ``None`` (after logging a warning) when the label has no foreground.
        Otherwise a dict with two PIL images:
        ``'contour_image'`` is one row of axial | coronal | sagittal views
        with the label contour drawn on the image.
        ``'overlay_image'`` is two such rows: the plain image on top and the
        filled label overlay below.
    """
    label = sitk.DICOMOrient(sitk.ReadImage(label_nii))
    largest = get_largest_connected_component(label)
    stats = sitk.LabelShapeStatisticsImageFilter()
    stats.Execute(largest)
    if not stats.GetNumberOfLabels():
        logging.warning(f"No label found in {label_nii}")
        return None
    # Size of the component the views are centred on; debug-level only so a
    # batch run is not interrupted.
    logging.debug(f"Largest component physical size in {label_nii}: {stats.GetPhysicalSize(1)}")
    center_idx = largest.TransformPhysicalPointToIndex(stats.GetCentroid(1))

    image = sitk.DICOMOrient(sitk.ReadImage(image_nii))
    image = sitk.Cast(sitk.RescaleIntensity(image), sitk.sitkUInt8)
    # Force identical geometry so the label filters accept image and label together.
    label.CopyInformation(image)
    contour = sitk.LabelContour(label)
    contour.CopyInformation(image)
    contour = sitk.LabelOverlay(image, contour, opacity=1)
    overlay = sitk.LabelOverlay(image, label)

    _, sy, sz = image.GetSpacing()
    # NOTE(review): coronal columns run along x, so sz/sx would be the exact
    # factor for that view; the two agree whenever sx == sy.
    scale_z = sz / sy

    row_con = _extract_orthogonal_slices(contour, center_idx, scale_z)
    row_orig = _extract_orthogonal_slices(image, center_idx, scale_z)
    row_ov = _extract_orthogonal_slices(overlay, center_idx, scale_z)
    return {
        'contour_image': Image.fromarray(row_con),
        'overlay_image': Image.fromarray(np.concatenate([row_orig, row_ov], axis=0)),
    }
def main():
    """Export previews and copies of every segmented case listed in nii.json.

    ``NII_JSON_PATH`` maps a case stem ``"<hash>-<date>"`` to its DICOM
    directory relative to ``DICOM_ROOT``. The first path component of that
    directory is used as the output category. Stems are processed in sorted
    order, and only the first successfully exported stem per hash is kept
    (tracked in ``CASE_SET``).

    For each exported case, the following are written under
    ``FINAL_DIR/<category>/``:

    * ``<stem>.png``: the image/overlay preview.
    * ``<stem>/contour.png``, ``<stem>/image.nii.gz`` and ``<stem>/label.nii.gz``.
    * ``<stem>/DICOM/``: a copy of the files in the case's DICOM directory.

    Cases missing the image or the prediction, or whose previews fail to
    render, are logged and skipped. Empty predictions are skipped as well.
    """
    log_format = '%(asctime)s [%(filename)s:%(lineno)d] %(message)s'
    logging.basicConfig(
        level=logging.INFO,
        format=log_format,
        handlers=[
            logging.StreamHandler(),
            # "<script>.log" next to this file. splitext only touches the final
            # extension, unlike str.replace('.py', ...) which would also rewrite
            # an earlier ".py" in the path.
            logging.FileHandler(os.path.splitext(__file__)[0] + '.log', encoding='utf-8'),
        ],
    )
    with open(NII_JSON_PATH, 'r', encoding='utf-8') as f:
        nii_dict = json.load(f)
    for stem, rel_dicomdir in sorted(nii_dict.items()):
        logging.info(f"{stem} {rel_dicomdir}")
        # Stems look like "<hash>-<date>"; only the hash identifies the case.
        # partition() tolerates extra/missing '-' where a 2-way unpack raises.
        case_hash = stem.partition("-")[0]
        if case_hash in CASE_SET:
            continue
        category = rel_dicomdir.split("/")[0]
        image_nii = f"{imagesTs_DIR}/{stem}_0000.nii.gz"
        label_nii = f"{POSTPROCESSED_DIR}/{stem}.nii.gz"
        dicomdir = f"{DICOM_ROOT}/{rel_dicomdir}"
        category_dir = f"{FINAL_DIR}/{category}"
        dest_dir = f"{category_dir}/{stem}"
        if not os.path.exists(image_nii):
            logging.warning(f"Skipping {stem}: image not found: {image_nii}")
            continue
        if not os.path.exists(label_nii):
            logging.warning(f"Skipping {stem}: prediction not found: {label_nii}")
            continue
        try:
            con_ov = contour_overlay(image_nii, label_nii)
        except Exception:
            # Batch boundary: one unreadable or mismatched volume should not
            # abort the whole export. The traceback is logged.
            logging.exception(f"Skipping {stem}: failed to render previews")
            continue
        if con_ov is None:
            continue
        os.makedirs(dest_dir, exist_ok=True)
        con_ov['contour_image'].save(f"{dest_dir}/contour.png")
        con_ov['overlay_image'].save(f"{category_dir}/{stem}.png")
        shutil.copyfile(image_nii, f"{dest_dir}/image.nii.gz")
        shutil.copyfile(label_nii, f"{dest_dir}/label.nii.gz")
        copytree(dicomdir, f"{dest_dir}/DICOM")
        CASE_SET.add(case_hash)


if __name__ == "__main__":
    main()