# Source: tsha-mri-tumor-labeling/kimo/4-final-copy.py (Python, 271 lines, 8.4 KiB)

import logging
import json
import os
import shutil
from PIL import Image
import numpy as np
# 2026-02-22 04:12:24 +00:00  (revision timestamp left over from the repository web view; not code)
import pandas as pd
import SimpleITK as sitk
# Base directory of the dataset. It ends with "/", and every derived path
# below adds another "/", so the resulting strings contain "//".
ROOT_DIR = "/mnt/t24/Public/kimo/"
# Input area holding the DICOM tree, the NIfTI images and the predictions.
RAW_DIR = f"{ROOT_DIR}/raw/"
# Walked by main() as <pathology>/<patient>/<study> directories.
DICOM_ROOT= f"{RAW_DIR}/DICOM/"
# main() reads "<stem>_0000.nii.gz" images from here.
imagesTs_DIR = f"{RAW_DIR}/Dataset2602_BraTS-CK/imagesTs"
# main() reads "<stem>.nii.gz" label volumes from here.
POSTPROCESSED_DIR = f"{RAW_DIR}/ensemble_predictions_postprocessed/"
# Output area: one sub-directory per category, plus datalist.xlsx.
FINAL_DIR = f"{ROOT_DIR}/final/"
# Not referenced anywhere in the visible part of this file.
NII_JSON_PATH = f"{RAW_DIR}/nii.json"
# Patient identifiers already exported; main() skips any further study whose
# patient is in this set.
CASE_SET = set()
def copytree(src, dst):
    """Copy every file found anywhere under ``src`` directly into ``dst``.

    The source hierarchy is not reproduced: files from all nesting levels end
    up side by side in ``dst``, so a file visited later replaces an earlier
    one that has the same name. ``dst`` is created when it does not exist.

    Args:
        src: Directory that is walked recursively.
        dst: Flat destination directory.
    """
    os.makedirs(dst, exist_ok=True)
    # Lazily pair each source path with its bare file name while walking.
    discovered = (
        (os.path.join(folder, name), name)
        for folder, _subfolders, names in os.walk(src)
        for name in names
    )
    for source_path, name in discovered:
        shutil.copyfile(source_path, os.path.join(dst, name))
def get_largest_connected_component(binary_image):
    """Return a mask that keeps only the biggest connected object.

    Args:
        binary_image (sitk.Image): Image whose connected objects are analysed.

    Returns:
        sitk.Image: Image that is 1 inside the largest connected component
        and 0 everywhere else.
    """
    # Give every connected object its own integer label.
    components = sitk.ConnectedComponent(binary_image)
    # Renumber the objects by size so that the largest one becomes label 1.
    ranked = sitk.RelabelComponentImageFilter().Execute(components)
    # Keeping exactly label 1 isolates the largest object as a 0/1 mask.
    return sitk.BinaryThreshold(
        ranked,
        lowerThreshold=1,
        upperThreshold=1,
        insideValue=1,
        outsideValue=0,
    )
def rescale_coronal_sagittal(arr_cor, arr_sag, scale):
    """Stretch two slice arrays vertically by the same factor.

    Each array is converted to a PIL image and resized with bilinear
    interpolation. The width is kept and the height becomes
    ``int(height * scale)``.

    Args:
        arr_cor: Array of the coronal slice.
        arr_sag: Array of the sagittal slice.
        scale: Multiplier applied to the height of both slices.

    Returns:
        tuple: ``(coronal, sagittal)`` as NumPy arrays after resizing.
    """
    def stretch_rows(pixels):
        picture = Image.fromarray(pixels)
        width, height = picture.size  # PIL reports (width, height)
        target = (width, int(height * scale))
        return np.array(picture.resize(target, Image.Resampling.BILINEAR))

    return stretch_rows(arr_cor), stretch_rows(arr_sag)
def pad_to_h(arr, h):
    """Centre ``arr`` vertically inside ``h`` rows by zero-padding axis 0.

    Args:
        arr: 2-D array, or 3-D array whose last axis is left untouched.
        h: Wanted number of rows.

    Returns:
        The padded array, or ``arr`` itself when it already has at least
        ``h`` rows (nothing is cropped).
    """
    missing = h - arr.shape[0]
    if missing <= 0:
        return arr
    # With an odd amount of padding the extra row goes to the bottom.
    above = missing // 2
    below = missing - above
    if arr.ndim == 3:
        widths = ((above, below), (0, 0), (0, 0))
    else:
        widths = ((above, below), (0, 0))
    return np.pad(arr, widths, mode='constant')
def contour_overlay(image_nii, label_nii, min_physical_size=1000):
    """Render three-plane previews of a label volume on top of its image.

    Both volumes are read and reoriented with ``sitk.DICOMOrient``. The
    slices are taken through the centroid of the largest connected component
    of the label, and each preview row shows axial, coronal and sagittal
    slices side by side.

    Args:
        image_nii: Path of the image volume.
        label_nii: Path of the label volume.
        min_physical_size: Smallest accepted physical size of the largest
            component, as reported by
            ``LabelShapeStatisticsImageFilter.GetPhysicalSize``. Cases below
            it are skipped. Defaults to 1000, the value that used to be
            hard-coded.

    Returns:
        ``None`` when the label is empty or its largest component is smaller
        than ``min_physical_size``; otherwise a dict with

        * ``'contour_image'``: PIL image, one row with the label contour
          drawn on the image;
        * ``'overlay_image'``: PIL image, two rows - the plain image on top
          and the label overlay below;
        * ``'stats'``: the executed ``LabelShapeStatisticsImageFilter`` of
          the largest component (label 1).
    """
    label = sitk.DICOMOrient(sitk.ReadImage(label_nii))
    largest = get_largest_connected_component(label)

    stats = sitk.LabelShapeStatisticsImageFilter()
    stats.Execute(largest)
    if not stats.GetNumberOfLabels():
        logging.warning("No label found in %s", label_nii)
        return None
    if stats.GetPhysicalSize(1) < min_physical_size:
        logging.warning("Label size too small in %s", label_nii)
        return None

    # Voxel index of the centroid; all three slices pass through it.
    center_idx = largest.TransformPhysicalPointToIndex(stats.GetCentroid(1))

    image = sitk.DICOMOrient(sitk.ReadImage(image_nii))
    image = sitk.Cast(sitk.RescaleIntensity(image), sitk.sitkUInt8)

    # Put the label on the image's geometry before combining the two.
    label.CopyInformation(image)
    contour = sitk.LabelContour(label)
    contour.CopyInformation(image)
    contour = sitk.LabelOverlay(image, contour, opacity=1)
    overlay = sitk.LabelOverlay(image, label)

    # NOTE(review): the same sz / sy factor is applied to the coronal and the
    # sagittal slice; presumably this assumes equal x and y spacing - confirm.
    _, sy, sz = image.GetSpacing()
    scale_z = sz / sy

    def extract_slices(img_3d):
        """Build one row (axial | coronal | sagittal) from ``img_3d``."""
        arr_axial = sitk.GetArrayFromImage(img_3d[:, :, center_idx[2]])
        arr_coronal = sitk.GetArrayFromImage(img_3d[:, center_idx[1], :])
        arr_sagittal = sitk.GetArrayFromImage(img_3d[center_idx[0], :, :])

        # Single-channel slices are repeated into three channels so they can
        # be stacked with the colour overlays.
        if arr_axial.ndim == 2:
            arr_axial = np.stack((arr_axial,) * 3, axis=-1)
            arr_coronal = np.stack((arr_coronal,) * 3, axis=-1)
            arr_sagittal = np.stack((arr_sagittal,) * 3, axis=-1)

        # Reverse the row order of the two reformatted planes.
        arr_coronal = np.flip(arr_coronal, axis=0)
        arr_sagittal = np.flip(arr_sagittal, axis=0)
        arr_coronal, arr_sagittal = rescale_coronal_sagittal(
            arr_coronal, arr_sagittal, scale_z
        )

        panels = [arr_axial, arr_coronal, arr_sagittal]
        max_h = max(panel.shape[0] for panel in panels)
        return np.concatenate(
            [pad_to_h(panel, max_h) for panel in panels], axis=1
        )

    row_con = extract_slices(contour)
    row_orig = extract_slices(image)
    row_ov = extract_slices(overlay)

    # Plain image on top, overlay below.
    combined_arr = np.concatenate([row_orig, row_ov], axis=0)

    return {
        'contour_image': Image.fromarray(row_con),
        'overlay_image': Image.fromarray(combined_arr),
        'stats': stats,
    }
# 2026-02-22 04:12:24 +00:00  (revision timestamp left over from the repository web view; not code)
# Rows collected by main(), one dict per exported case; main() writes them to
# datalist.xlsx under FINAL_DIR.
DATALIST = []
def _collect_studies(dicom_root):
    """Index every study directory found below ``dicom_root``.

    The tree is walked as ``<pathology>/<patient>/<study>``; entries that are
    not directories are ignored at every level.

    Args:
        dicom_root: Directory that contains the pathology folders.

    Returns:
        dict: ``"<patient>-<study prefix>"`` mapped to
        ``(pathology, patient, study path relative to dicom_root)``, where the
        study prefix is the study folder name up to its first underscore.
        When two studies produce the same key, the one visited last wins.
    """
    studies = {}
    for pathology in sorted(os.listdir(dicom_root)):
        pathology_path = os.path.join(dicom_root, pathology)
        if not os.path.isdir(pathology_path):
            continue
        for patient in sorted(os.listdir(pathology_path)):
            patient_path = os.path.join(pathology_path, patient)
            if not os.path.isdir(patient_path):
                continue
            for study in sorted(os.listdir(patient_path)):
                study_path = os.path.join(patient_path, study)
                if not os.path.isdir(study_path):
                    continue
                key = f"{patient}-{study.split('_')[0]}"
                studies[key] = (
                    pathology,
                    patient,
                    os.path.relpath(study_path, dicom_root),
                )
    return studies


def main():
    """Export one reviewed case per patient into ``FINAL_DIR``.

    For every study under ``DICOM_ROOT`` that has both an image and a
    post-processed label, this writes the preview PNGs, copies the two NIfTI
    files and the DICOM files into ``FINAL_DIR/<category>/<stem>``, and
    appends a row to ``DATALIST``. Once a patient has been exported, the
    remaining studies of that patient are skipped. Finally ``DATALIST`` is
    written to ``FINAL_DIR/datalist.xlsx``.
    """
    log_format = '%(asctime)s [%(filename)s:%(lineno)d] %(message)s'
    logging.basicConfig(
        level=logging.INFO,
        format=log_format,
        handlers=[
            logging.StreamHandler(),
            # splitext only touches the extension, unlike str.replace, which
            # would also rewrite a ".py" occurring earlier in the path.
            logging.FileHandler(
                os.path.splitext(__file__)[0] + '.log', encoding='utf-8'
            ),
        ],
    )

    studies = _collect_studies(DICOM_ROOT)

    # The patient and pathology travel with each entry instead of being
    # parsed back out of the key, which broke on names containing "-".
    for stem, (category, patient, rel_dicomdir) in sorted(studies.items()):
        if patient in CASE_SET:
            logging.info("skip %s %s", stem, rel_dicomdir)
            continue
        logging.info("%s %s", stem, rel_dicomdir)

        image_nii = f"{imagesTs_DIR}/{stem}_0000.nii.gz"
        label_nii = f"{POSTPROCESSED_DIR}/{stem}.nii.gz"
        dicomdir = os.path.join(DICOM_ROOT, rel_dicomdir)
        category_dir = f"{FINAL_DIR}/{category}"
        dest_dir = f"{category_dir}/{stem}"

        if not os.path.exists(image_nii):
            logging.info("%s not exists", image_nii)
            continue
        if not os.path.exists(label_nii):
            logging.info("%s not exists", label_nii)
            continue

        con_ov = contour_overlay(image_nii, label_nii)
        if con_ov is None:
            continue

        os.makedirs(dest_dir, exist_ok=True)
        con_ov['contour_image'].save(f"{dest_dir}/contour.png")
        con_ov['overlay_image'].save(f"{category_dir}/{stem}.png")
        shutil.copyfile(image_nii, f"{dest_dir}/image.nii.gz")
        shutil.copyfile(label_nii, f"{dest_dir}/label.nii.gz")
        copytree(dicomdir, f"{dest_dir}/DICOM")

        # Only a fully exported case marks its patient as done.
        CASE_SET.add(patient)
        stats = con_ov['stats']
        DATALIST.append({
            "Category": category,
            "Stem": stem,
            "EquivalentEllipsoidDiameter": stats.GetEquivalentEllipsoidDiameter(1),
            "PhysicalSize": stats.GetPhysicalSize(1),
        })

    # FINAL_DIR is otherwise only created as a side effect of exporting a
    # case, so make sure it exists even when nothing was exported.
    os.makedirs(FINAL_DIR, exist_ok=True)
    pd.DataFrame(DATALIST).to_excel(f"{FINAL_DIR}/datalist.xlsx", index=False)
# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()