From 4f4e8f888211b0d2d35856833710fe7334a19713 Mon Sep 17 00:00:00 2001 From: Furen Xiao Date: Sun, 22 Feb 2026 09:49:47 +0800 Subject: [PATCH] feat: Add MRI data preprocessing pipeline including series checking, anonymization, and NIfTI conversion. --- kimo/2-check-series.py | 280 +++++++++++++++++++++++++++++++++++++++++ kimo/2602.yaml | 4 + kimo/3-inference.py | 72 +++++++++++ kimo/4-final-copy.py | 222 ++++++++++++++++++++++++++++++++ 4 files changed, 578 insertions(+) create mode 100644 kimo/2-check-series.py create mode 100644 kimo/2602.yaml create mode 100644 kimo/3-inference.py create mode 100644 kimo/4-final-copy.py diff --git a/kimo/2-check-series.py b/kimo/2-check-series.py new file mode 100644 index 0000000..74c157d --- /dev/null +++ b/kimo/2-check-series.py @@ -0,0 +1,280 @@ +import base64 +import collections +import datetime +import hashlib +import json +import logging +import os +import shelve +import shutil +import tempfile + +import dicom2nifti +import pydicom + +EXCLUDED_HASH = [ + # "GWJU7LPC", +] + +INCLUDED_HASH = [ + # "GWJU7LPC", +] + + +LAST_DAY = datetime.datetime.strptime("2025-11-01", "%Y-%m-%d") + +SRC_ROOT = "/mnt/t24/Public/kimo/TSHA" +RAW_DIR = "/mnt/t24/Public/kimo/raw/" + +DST_ROOT = os.path.join(RAW_DIR, "DICOM") +imagesTs_DIR = os.path.join(RAW_DIR, "Dataset2602_BraTS-CK/imagesTs/") +NII_JSON_PATH = os.path.join(RAW_DIR, 'nii.json') + +if os.path.exists(NII_JSON_PATH): + with open(NII_JSON_PATH, 'r') as f: + NII_DICT = json.load(f) +else: + NII_DICT = {} + +# {'', 'PELVISLOWEXTREM', 'BRAIN', 'CSPINE', 'KNEE', 'TSPINE', 'CAROTID', 'NECK', 'ABDOMEN', 'ORBIT', 'HEAD', 'CHEST', 'IAC', 'WHOLEBODY', 'WHOLESPINE', 'ABDOMENPELVIS', 'PELVIS', 'LSPINE', 'SPINE', 'CIRCLEOFWILLIS'} +# BodyPartExamined: Counter({'BRAIN': 152087, 'ABDOMEN': 14101, 'HEAD': 11806, 'ABDOMENPELVIS': 10905, 'SPINE': 9277, 'CHEST': 3746, 'PELVIS': 3208, 'NECK': 3205, 'CSPINE': 1527, 'CAROTID': 1186, +# 'HEART': 1122, 'LSPINE': 1080, 'KNEE': 591, 'PELVISLOWEXTREM': 496, '': 385, 'ORBIT': 360, 'CIRCLEOFWILLIS': 322, 'HUMERUS': 320, 'ARM': 304, 'IAC': 291, +# 'EXTREMITY': 287, 'SHOULDER': 242, 'WHOLEBODY': 190, 'TSPINE': 150, 'HEADNECK': 48, 'WHOLESPINE': 45}) + + +BodyPartExamined = collections.Counter() +BodyPartIncluded = set([ + 'BRAIN', + 'CIRCLEOFWILLIS', + 'HEAD', + 'IAC', + # 'ORBIT', +]) + +def is_axial(o): + return o[1]==0 and o[2]==0 and o[3]==0 and o[5]==0 + +def check_study(study_dir): + SeriesDescription = set() + series = {} + for root, dirs, files in os.walk(study_dir): + for file in sorted(files, key=lambda x: int(x.split("_")[-1].split(".")[0])): + if file.endswith(".dcm"): + dcm_file = os.path.join(root, file) + ds = pydicom.dcmread(dcm_file, force=True, stop_before_pixels=True) + + if 'BodyPartExamined' in ds: + BodyPartExamined[ds.BodyPartExamined] += 1 + if 'StudyDescription' in ds: + StudyDescription = ds.StudyDescription + + if 'ImageOrientationPatient' not in ds: + continue + + # print(f"{dcm_file}") + series_instance_uid = ds.SeriesInstanceUID + SeriesDescription.add(ds.SeriesDescription) + # print(body_part_examined, series_description) + + if series_instance_uid not in series: + series[series_instance_uid] = { + 'FileDataset': ds, + + '1st_file': dcm_file, + + 'orientations': [], + 'files': [], + } + + series[series_instance_uid]['files'].append(dcm_file) + series[series_instance_uid]['orientations'].append(tuple(ds.ImageOrientationPatient)) + # print(ds.ImageOrientationPatient) + # exit() + + brain_list = [] + body_parts = set() + + for uid, s in series.items(): + # logging.info(f"{s['FileDataset'].SeriesNumber} {s['FileDataset'].BodyPartExamined} {s['FileDataset'].SeriesDescription} {len(s['files'])} {s['1st_file']}") + if 'BodyPartExamined' in s['FileDataset']: + if s['FileDataset'].BodyPartExamined in BodyPartIncluded: + brain_list.append(s) + else: + body_parts.add(s['FileDataset'].BodyPartExamined) + + if not brain_list: + if body_parts: + logging.info(f"no brain, BodyPartExamined: {body_parts}") + return None + else: + logging.info(f"BodyPartExamined is empty") + if 'brain' in StudyDescription.lower(): + logging.info(f"brain in {StudyDescription}, adding all series") + brain_list = list(series.values()) + # print(series) + # print(brain_list) + # exit() + else: + logging.info(f"no brain in {StudyDescription}") + return None + + t1c = [] + + for s in brain_list: + sd = s['FileDataset'].SeriesDescription.lower() + + if not ('+' in sd or 'gd' in sd): + continue + + if 't1' not in sd and ( + 'flair' in sd or + 't2' in sd or + 'perf' in sd # perfusion series (ep2d_perf) + ): + continue + + t1c.append(s) + + if not t1c: + logging.info(f"no t1c in {StudyDescription}") + for s in brain_list: + logging.info(f"{s['FileDataset'].SeriesNumber} {s['FileDataset'].SeriesDescription} {len(s['files'])}") + return None + + t1c_axial = [] + + for s in t1c: + c = collections.Counter(s['orientations']) + orientation_str = c.most_common(1)[0][0] + orientation_float = tuple(float(f) for f in orientation_str) + orientation = tuple(round(f) for f in orientation_float) + + s['Orientation'] = orientation + if is_axial(orientation): + logging.info(f"--- {s['FileDataset'].SeriesNumber} {s['FileDataset'].SeriesDescription} {s['Orientation']} {len(s['files'])}") + t1c_axial.append(s) + + if not t1c_axial: + logging.info(f"no axial t1c in {study_dir}") + for s in t1c: + logging.info(f"{s['FileDataset'].SeriesNumber} {s['FileDataset'].SeriesDescription} {s['Orientation']} {len(s['files'])} {StudyDescription}") + return None + + best_series = max(t1c_axial, key=lambda x: (len(x['files']), -x['FileDataset'].SeriesNumber)) + # best_series = min(t1c_axial, key=lambda x: len(x['files'], x['FileDataset'].SeriesNumber))) + logging.info(f"{best_series['FileDataset'].SeriesNumber} {best_series['FileDataset'].SeriesDescription} {best_series['Orientation']} {len(best_series['files'])}") + + return best_series + +def hashptid(mrn, hosp='NTUH'): + + ptsalt = (mrn+hosp).upper().encode() + hash_in_bytes = hashlib.md5(ptsalt) + + md5 = hash_in_bytes.hexdigest() + hash = base64.b32encode(hash_in_bytes.digest())[:8].decode() + return md5, hash + +def anonymize_series_to_nifti(series_files, dst_dir): + os.makedirs(dst_dir, exist_ok=True) + for f in series_files: + ds = pydicom.dcmread(f) + md5, hash = hashptid(ds.PatientID) + for elem in ds: + if elem.tag.group == 0x0010: + elem.value = '' + + ds.PatientID = hash + dst_file = os.path.join(dst_dir, os.path.basename(f.split("_")[-1])) + ds.save_as(dst_file) + + with tempfile.TemporaryDirectory() as tmpdirname: + dicom2nifti.convert_directory(dst_dir, tmpdirname, compression=True) + for e in os.scandir(tmpdirname): + if e.is_file() and e.name.endswith(".nii.gz"): + stem = f"{hash}-{ds.StudyDate}" + dst_file = os.path.join(imagesTs_DIR, f"{stem}_0000.nii.gz") + logging.info(f"copying to {dst_file}") + shutil.copyfile(e.path, dst_file) + NII_DICT[stem] = os.path.relpath(dst_dir, DST_ROOT) + +def main(): + + FORMAT = '%(asctime)s [%(filename)s:%(lineno)d] %(message)s' + logging.basicConfig( + level=logging.INFO, + format=FORMAT, + handlers=[ + logging.StreamHandler(), + # logging.FileHandler(__file__.replace('.py','.%s.log'%str(datetime.datetime.now()).replace(':','')), encoding='utf-8') + logging.FileHandler(__file__.replace('.py','.log'), encoding='utf-8') + ] + ) + + shutil.rmtree(imagesTs_DIR, ignore_errors=True) + os.makedirs(imagesTs_DIR, exist_ok=True) + + for patho in sorted(os.listdir(SRC_ROOT)): + patho_dir = os.path.join(SRC_ROOT, patho) + + for patient in sorted(os.listdir(patho_dir)): + md5, hash = hashptid(patient) + + if INCLUDED_HASH: + if hash not in INCLUDED_HASH: + continue + + if hash in EXCLUDED_HASH: + continue + + patient_dir = os.path.join(patho_dir, patient) + if not os.path.isdir(patient_dir): + continue + + if not os.path.isfile(os.path.join(patho_dir, f"{patient}.complete")): + logging.info(f"skip {patient_dir}") + continue + + md5, hash = hashptid(patient) + + dst_patient_dir = os.path.join(DST_ROOT, patho, hash) + complete_file = os.path.join(DST_ROOT, patho, f'{hash}.complete') + + if os.path.exists(complete_file): + logging.info(f"skip {patient_dir}") + continue + + num_study = 0 + + for study in sorted(os.listdir(patient_dir), reverse=True): + study_date = study.split('_')[0] + if datetime.datetime.strptime(study_date, "%Y%m%d") > LAST_DAY: + logging.info(f"skip {study_date}") + continue + study_dir = os.path.join(patient_dir, study) + if not os.path.isdir(study_dir): + continue + + logging.info(study_dir) + best_series = check_study(study_dir) + if not best_series: + continue + + dst_dir = os.path.join(dst_patient_dir, study) + + anonymize_series_to_nifti(best_series['files'], dst_dir) + num_study += 1 + + if num_study > 0: + with open(complete_file, 'w') as f: + f.write('done') + # break + # break + print(NII_DICT) + logging.info(f"BodyPartExamined: {BodyPartExamined}") + + with open(NII_JSON_PATH, 'w') as f: + json.dump(NII_DICT, f, indent=1) + +if __name__ == '__main__': + main() diff --git a/kimo/2602.yaml b/kimo/2602.yaml new file mode 100644 index 0000000..b65b29a --- /dev/null +++ b/kimo/2602.yaml @@ -0,0 +1,4 @@ +dataset_name_or_id: 2602 # models trained on BraTS + NTUH +nnunet_preprocessed: "/mnt/b4/Public/nnUNet/preprocessed" # directory for storing pre-processed data (optional) +nnunet_raw: "/mnt/t24/Public/kimo/raw/" # directory for storing formated raw data (optional) +nnunet_results: "/mnt/b4/Public/nnUNet/results" # directory for storing trained model checkpoints (optional) diff --git a/kimo/3-inference.py b/kimo/3-inference.py new file mode 100644 index 0000000..b413f0d --- /dev/null +++ b/kimo/3-inference.py @@ -0,0 +1,72 @@ +# done with GWJU7LPC-20201007 + +import logging +import os +import shutil + +from monai.apps.nnunet import nnUNetV2Runner + +import yaml + +DATASET="Dataset2602_BraTS-CK" +YAML ="2602.yaml" + +DST_DIR="/mnt/t24/Public/kimo/raw/ensemble_predictions_postprocessed" + +def copytree(src, dst): + os.makedirs(dst, exist_ok=True) + for root, dirs, files in os.walk(src): + for file in files: + src_file = os.path.join(root, file) + dst_file = os.path.join(dst, file) + shutil.copyfile(src_file, dst_file) + + +def main(): + FORMAT = '%(asctime)s [%(filename)s:%(lineno)d] %(message)s' + logging.basicConfig( + level=logging.INFO, + format=FORMAT, + handlers=[ + logging.StreamHandler(), + # logging.FileHandler(__file__.replace('.py','.%s.log'%str(datetime.datetime.now()).replace(':','')), encoding='utf-8') + logging.FileHandler(__file__.replace('.py','.log'), encoding='utf-8') + ] + ) + + with open(YAML, "r") as f: + config = yaml.safe_load(f) + + PREPROCESSED_DIR = config["nnunet_preprocessed"] + RESULTS_DIR = config["nnunet_results"] + + # clean up + shutil.rmtree(f'{RESULTS_DIR}/{DATASET}/pred_2d', ignore_errors=True) + shutil.rmtree(f'{RESULTS_DIR}/{DATASET}/pred_3d_fullres', ignore_errors=True) + shutil.rmtree(f'{RESULTS_DIR}/{DATASET}/pred_3d_lowres', ignore_errors=True) + shutil.rmtree(f'{RESULTS_DIR}/{DATASET}/ensemble_predictions', ignore_errors=True) + shutil.rmtree(f'{RESULTS_DIR}/{DATASET}/ensemble_predictions_postprocessed', ignore_errors=True) + + runner = nnUNetV2Runner(YAML) + try: + runner.predict_ensemble_postprocessing( + # num_processes_preprocessing=1, + # num_processes_segmentation_export=1 + ) + except Exception as e: + logging.exception(e) + exit() + + copytree(f'{RESULTS_DIR}/{DATASET}/ensemble_predictions_postprocessed', DST_DIR) + + # exit() + + # clean up + # shutil.rmtree(f'{RESULTS_DIR}/{DATASET}/pred_2d', ignore_errors=True) + # shutil.rmtree(f'{RESULTS_DIR}/{DATASET}/pred_3d_fullres', ignore_errors=True) + # shutil.rmtree(f'{RESULTS_DIR}/{DATASET}/pred_3d_lowres', ignore_errors=True) + # shutil.rmtree(f'{RESULTS_DIR}/{DATASET}/ensemble_predictions', ignore_errors=True) + # shutil.rmtree(f'{RESULTS_DIR}/{DATASET}/ensemble_predictions_postprocessed', ignore_errors=True) + +if __name__ == "__main__": + main() diff --git a/kimo/4-final-copy.py b/kimo/4-final-copy.py new file mode 100644 index 0000000..79458cf --- /dev/null +++ b/kimo/4-final-copy.py @@ -0,0 +1,222 @@ +import logging +import json +import os +import shutil + +from PIL import Image + +import numpy as np +import SimpleITK as sitk + +ROOT_DIR = "/mnt/t24/Public/kimo/" +RAW_DIR = f"{ROOT_DIR}/raw/" +DICOM_ROOT= f"{RAW_DIR}/DICOM/" +imagesTs_DIR = f"{RAW_DIR}/Dataset2602_BraTS-CK/imagesTs" +POSTPROCESSED_DIR = f"{RAW_DIR}/ensemble_predictions_postprocessed/" + +FINAL_DIR = f"{ROOT_DIR}/final/" + +NII_JSON_PATH = f"{RAW_DIR}/nii.json" + +CASE_SET = set() + +def copytree(src, dst): + os.makedirs(dst, exist_ok=True) + for root, dirs, files in os.walk(src): + for file in files: + src_file = os.path.join(root, file) + dst_file = os.path.join(dst, file) + shutil.copyfile(src_file, dst_file) + + +def get_largest_connected_component(binary_image): + """ + Extracts the largest connected component from a binary SimpleITK image. + + Args: + binary_image (sitk.Image): A binary image (pixel values 0 or 1). + + Returns: + sitk.Image: A binary image containing only the largest connected component. + """ + # 1. Label connected components + # Each connected component will be assigned a unique integer label (1, 2, 3, etc) + label_image = sitk.ConnectedComponent(binary_image) + + # 2. Relabel components sorted by size + # The largest component will be assigned label 1 + relabel_filter = sitk.RelabelComponentImageFilter() + relabel_image = relabel_filter.Execute(label_image) + + # Optional: Get the size of objects in pixels if needed for verification + # sizes = relabel_filter.GetSizeOfObjectsInPixels() + # print(f"Object sizes (sorted): {sizes}") + + # 3. Threshold the relabeled image to keep only label 1 (the largest component) + # The result is a binary image of the largest component + largest_component_image = sitk.BinaryThreshold( + relabel_image, + lowerThreshold=1, + upperThreshold=1, + insideValue=1, + outsideValue=0 + ) + + return largest_component_image + +def rescale_coronal_sagittal(arr_cor, arr_sag, scale): + img_cor = Image.fromarray(arr_cor) + img_sag = Image.fromarray(arr_sag) + + cw, ch = img_cor.size + img_cor = img_cor.resize((cw, int(ch * scale)), Image.Resampling.BILINEAR) + + sw, sh = img_sag.size + img_sag = img_sag.resize((sw, int(sh * scale)), Image.Resampling.BILINEAR) + + return np.array(img_cor), np.array(img_sag) + +def pad_to_h(arr, h): + pad_h = h - arr.shape[0] + if pad_h > 0: + pad_top = pad_h // 2 + pad_bottom = pad_h - pad_top + pad_width = [(pad_top, pad_bottom), (0, 0)] + if arr.ndim == 3: + pad_width.append((0, 0)) + return np.pad(arr, tuple(pad_width), mode='constant') + return arr + +def contour_overlay(image_nii, label_nii): + + label = sitk.ReadImage(label_nii) + label = sitk.DICOMOrient(label) + largest = get_largest_connected_component(label) + + stats = sitk.LabelShapeStatisticsImageFilter() + stats.Execute(largest) + if not stats.GetNumberOfLabels(): + logging.warning(f"No label found in {label_nii}") + return None + + print(stats.GetPhysicalSize(1)) + exit() + + center_idx = largest.TransformPhysicalPointToIndex(stats.GetCentroid(1)) + + image = sitk.ReadImage(image_nii) + image = sitk.DICOMOrient(image) + image = sitk.Cast(sitk.RescaleIntensity(image), sitk.sitkUInt8) + + label.CopyInformation(image) + + + contour = sitk.LabelContour(label) + contour.CopyInformation(image) + contour = sitk.LabelOverlay(image, contour, opacity=1) + + overlay = sitk.LabelOverlay(image, label) + + # Get spacing (x, y, z) + spacing = image.GetSpacing() + sx, sy, sz = spacing + scale_z = sz / sy + + def extract_slices(img_3d): + axial = img_3d[:, :, center_idx[2]] + coronal = img_3d[:, center_idx[1], :] + sagittal = img_3d[center_idx[0], :, :] + + arr_axial = sitk.GetArrayFromImage(axial) + arr_coronal = sitk.GetArrayFromImage(coronal) + arr_sagittal = sitk.GetArrayFromImage(sagittal) + + if arr_axial.ndim == 2: + arr_axial = np.stack((arr_axial,) * 3, axis=-1) + arr_coronal = np.stack((arr_coronal,) * 3, axis=-1) + arr_sagittal = np.stack((arr_sagittal,) * 3, axis=-1) + + arr_coronal = np.flip(arr_coronal, axis=0) + arr_sagittal = np.flip(arr_sagittal, axis=0) + + arr_coronal, arr_sagittal = rescale_coronal_sagittal(arr_coronal, arr_sagittal, scale_z) + + max_h = max(arr_axial.shape[0], arr_coronal.shape[0], arr_sagittal.shape[0]) + + padded_arrays = [] + for arr in [arr_axial, arr_coronal, arr_sagittal]: + padded_arrays.append(pad_to_h(arr, max_h)) + return np.concatenate(padded_arrays, axis=1) + + row_con = extract_slices(contour) + row_orig = extract_slices(image) + row_ov = extract_slices(overlay) + + + # Image.fromarray(row_con).save(f"{save_dir}/contour.png") + + # Stack rows vertically + combined_arr = np.concatenate([row_orig, row_ov], axis=0) + combined_img = Image.fromarray(combined_arr) + # combined_img.save(f"{save_dir}/overlay.png") + + return { + 'contour_image': Image.fromarray(row_con), + 'overlay_image': combined_img + } + + +def main(): + FORMAT = '%(asctime)s [%(filename)s:%(lineno)d] %(message)s' + logging.basicConfig( + level=logging.INFO, + format=FORMAT, + handlers=[ + logging.StreamHandler(), + # logging.FileHandler(__file__.replace('.py','.%s.log'%str(datetime.datetime.now()).replace(':','')), encoding='utf-8') + logging.FileHandler(__file__.replace('.py','.log'), encoding='utf-8') + ] + ) + + with open(NII_JSON_PATH, 'r') as f: + NII_DICT = json.load(f) + + for stem, dicomdir in sorted(NII_DICT.items()): + logging.info(f"{stem} {dicomdir}") + hash, date = stem.split("-") + if hash in CASE_SET: + continue + + category = dicomdir.split("/")[0] + + image_nii = f"{imagesTs_DIR}/{stem}_0000.nii.gz" + label_nii = f"{POSTPROCESSED_DIR}/{stem}.nii.gz" + dicomdir = f"{DICOM_ROOT}/{dicomdir}" + + category_dir = f"{FINAL_DIR}/{category}" + + dest_dir = f"{category_dir}/{stem}" + + if not os.path.exists(image_nii): + continue + if not os.path.exists(label_nii): + continue + + con_ov = contour_overlay(image_nii, label_nii) + if con_ov is None: + continue + + os.makedirs(dest_dir, exist_ok=True) + + con_ov['contour_image'].save(f"{dest_dir}/contour.png") + con_ov['overlay_image'].save(f"{category_dir}/{stem}.png") + + shutil.copyfile(image_nii, f"{dest_dir}/image.nii.gz") + shutil.copyfile(label_nii, f"{dest_dir}/label.nii.gz") + copytree(dicomdir, f"{dest_dir}/DICOM") + + CASE_SET.add(hash) + # exit() + +if __name__ == "__main__": + main()