From 29b3965c5a65556a7eeb8a50247dc38d5a26b528 Mon Sep 17 00:00:00 2001 From: Furen Xiao Date: Sat, 7 Mar 2026 19:32:28 +0800 Subject: [PATCH] excluding 4D images which break nnUNet --- kimo/2-check-series.py | 307 ++++++++++++++++++++++++++++++++--------- kimo/3-inference.py | 21 ++- kimo/4-final-copy.py | 33 ++++- test_dicom_json.py | 10 ++ 4 files changed, 301 insertions(+), 70 deletions(-) create mode 100644 test_dicom_json.py diff --git a/kimo/2-check-series.py b/kimo/2-check-series.py index d56b70a..6d53b70 100644 --- a/kimo/2-check-series.py +++ b/kimo/2-check-series.py @@ -1,3 +1,6 @@ + +from pathlib import Path + import base64 import collections import datetime @@ -9,6 +12,9 @@ import shelve import shutil import tempfile +from lmdbm import Lmdb +from tqdm import tqdm + import dicom2nifti import pydicom @@ -23,6 +29,7 @@ INCLUDED_HASH = [ LAST_DAY = datetime.datetime.strptime("2025-11-01", "%Y-%m-%d") MAX_PATIENT = 10//3 +MAX_PATIENT = 1000 SRC_ROOT = "/mnt/t24/Public/kimo/TSHA" RAW_DIR = "/mnt/t24/Public/kimo/raw/" @@ -31,7 +38,7 @@ DST_ROOT = os.path.join(RAW_DIR, "DICOM") imagesTs_DIR = os.path.join(RAW_DIR, "Dataset2602_BraTS-CK/imagesTs/") NII_JSON_PATH = os.path.join(RAW_DIR, 'nii.json') -if os.path.exists(NII_JSON_PATH): +if os.path.isfile(NII_JSON_PATH): with open(NII_JSON_PATH, 'r') as f: NII_DICT = json.load(f) else: @@ -46,57 +53,147 @@ else: BodyPartExamined = collections.Counter() BodyPartIncluded = set([ 'BRAIN', + 'BRAIN_CE', 'CIRCLEOFWILLIS', 'HEAD', 'IAC', # 'ORBIT', ]) -def is_axial(o): - return o[1]==0 and o[2]==0 and o[3]==0 and o[5]==0 +STUDY_DB_PATH = 'study.db' + +class JsonLmdb(Lmdb): + def _pre_key(self, value): + return value.encode("utf-8") + def _post_key(self, value): + return value.decode("utf-8") + def _pre_value(self, value): + def default(o): + if isinstance(o, (pydicom.dataset.FileDataset, pydicom.Dataset)): + return json.loads(o.to_json(suppress_invalid_tags=True)) + raise TypeError(f"Object of type {o.__class__.__name__} is 
not JSON serializable") + return json.dumps(value, default=default).encode("utf-8") + def _post_value(self, value): + data = json.loads(value.decode("utf-8")) + if isinstance(data, dict): + for k, v in data.items(): + if isinstance(v, dict) and 'FileDataset' in v: + if isinstance(v['FileDataset'], dict) or isinstance(v['FileDataset'], str): + v['FileDataset'] = pydicom.Dataset.from_json(v['FileDataset']) + return data + +# study_db = shelve.open(STUDY_DB_PATH, 'c') +study_db = JsonLmdb.open(STUDY_DB_PATH, "c") + + +def is_axial(s): + o = s.split(' ') + return o[1]=='0' and o[2]=='0' and o[3]=='0' and o[5]=='0' + +def int_orientation(o): + orientation = [str(round(f)) for f in o] + return ' '.join(orientation) + +def check_4d_series(ds_list): + # path = Path(directory_path) + # Dictionary to group slices by their 3D coordinates (x, y, z) + spatial_groups = collections.defaultdict(list) + + for ds in ds_list: + try: + # ds = pydicom.dcmread(dcm_file, stop_before_pixels=True) + + # Use ImagePositionPatient (0020, 0032) to identify unique 3D locations + if "ImagePositionPatient" in ds: + pos = tuple(ds.ImagePositionPatient) + # Store relevant temporal tags for inspection + time_info = { + "AcquisitionTime": getattr(ds, "AcquisitionTime", "N/A"), + "ContentTime": getattr(ds, "ContentTime", "N/A"), + "InstanceNumber": getattr(ds, "InstanceNumber", "N/A") + } + spatial_groups[pos].append(time_info) + except Exception as e: + logging.error(f"Error reading {dcm_file.name}: {e}") + + # Analyze findings + is_4d = False + for pos, slices in spatial_groups.items(): + if len(slices) > 1: + is_4d = True + logging.warning(f"{ds.SeriesDescription} Detected 4D: Position {pos} has {len(slices)} temporal frames.") + break + + # if not is_4d: + # logging.warning(f"{ds.SeriesDescription} Series appears to be standard 3D (one slice per position).") + + return is_4d def check_study(study_dir): - SeriesDescription = set() - series = {} - for root, dirs, files in os.walk(study_dir): - for 
file in sorted(files, key=lambda x: int(x.split("_")[-1].split(".")[0])): - if file.endswith(".dcm"): - dcm_file = os.path.join(root, file) - ds = pydicom.dcmread(dcm_file, force=True, stop_before_pixels=True) + key = '_'.join(Path(study_dir).parts[-2:]) - if 'BodyPartExamined' in ds: - BodyPartExamined[ds.BodyPartExamined] += 1 - if 'StudyDescription' in ds: - StudyDescription = ds.StudyDescription + # with JsonLmdb.open(STUDY_DB_PATH, "c") as db: + study = study_db.get(key, None) - if 'ImageOrientationPatient' not in ds: - continue + if study: + logging.warning(f"use cached series for {study_dir}") + else: + logging.warning(f"scan series for {study_dir}") - # print(f"{dcm_file}") - series_instance_uid = ds.SeriesInstanceUID - SeriesDescription.add(ds.SeriesDescription) - # print(body_part_examined, series_description) + SeriesDescription = set() + study = {} - if series_instance_uid not in series: - series[series_instance_uid] = { - 'FileDataset': ds, + DCMS = [] + SERIES_DS = {} - '1st_file': dcm_file, + for root, dirs, files in os.walk(study_dir): + for file in sorted(files, key=lambda x: int(x.split("_")[-1].split(".")[0])): + if file.endswith(".dcm"): + DCMS.append(os.path.join(root, file)) - 'orientations': [], - 'files': [], - } + for dcm_file in tqdm(DCMS): + ds = pydicom.dcmread(dcm_file, force=True, stop_before_pixels=True) - series[series_instance_uid]['files'].append(dcm_file) - series[series_instance_uid]['orientations'].append(tuple(ds.ImageOrientationPatient)) - # print(ds.ImageOrientationPatient) - # exit() + if 'BodyPartExamined' in ds: + BodyPartExamined[ds.BodyPartExamined] += 1 + + if 'ImageOrientationPatient' not in ds: + continue + + # print(f"{dcm_file}") + series_instance_uid = ds.SeriesInstanceUID + SeriesDescription.add(ds.SeriesDescription) + # print(body_part_examined, series_description) + + if series_instance_uid not in study: + study[series_instance_uid] = { + 'FileDataset': ds, + + '1st_file': dcm_file, + + 'orientations': [], + 
'files': [], + } + SERIES_DS[series_instance_uid] = [] + + study[series_instance_uid]['files'].append(os.path.relpath(dcm_file, study_dir)) + study[series_instance_uid]['orientations'].append(int_orientation(ds.ImageOrientationPatient)) + SERIES_DS[series_instance_uid].append(ds) + + for uid, ds_list in SERIES_DS.items(): + study[uid]['is_4d'] = check_4d_series(ds_list) + + # with JsonLmdb.open(STUDY_DB_PATH, "c") as db: + # with shelve.open(STUDY_DB_PATH) as db: + study_db[key] = study brain_list = [] body_parts = set() - for uid, s in series.items(): - # logging.info(f"{s['FileDataset'].SeriesNumber} {s['FileDataset'].BodyPartExamined} {s['FileDataset'].SeriesDescription} {len(s['files'])} {s['1st_file']}") + s = None + + for uid, s in study.items(): + # logging.warning(f"{s['FileDataset'].SeriesNumber} {s['FileDataset'].BodyPartExamined} {s['FileDataset'].SeriesDescription} {len(s['files'])} {s['1st_file']}") if 'BodyPartExamined' in s['FileDataset']: if s['FileDataset'].BodyPartExamined in BodyPartIncluded: brain_list.append(s) @@ -105,41 +202,52 @@ def check_study(study_dir): if not brain_list: if body_parts: - logging.info(f"no brain, BodyPartExamined: {body_parts}") + logging.warning(f"no brain, BodyPartExamined: {body_parts}") return None else: - logging.info(f"BodyPartExamined is empty") - if 'brain' in StudyDescription.lower(): - logging.info(f"brain in {StudyDescription}, adding all series") - brain_list = list(series.values()) - # print(series) - # print(brain_list) - # exit() + if not s: + logging.warning(f"no series found") + return None + logging.warning(f"BodyPartExamined is empty") + if 'brain' in s['FileDataset'].StudyDescription.lower(): + logging.warning(f"brain in {s['FileDataset'].StudyDescription}, adding all series") + brain_list = list(study.values()) else: - logging.info(f"no brain in {StudyDescription}") + logging.warning(f"no brain in {s['FileDataset'].StudyDescription}") return None t1c = [] for s in brain_list: + + sd = 
s['FileDataset'].SeriesDescription.lower() if not ('+' in sd or 'gd' in sd): continue + if s['is_4d']: + logging.warning(f"4D series found: {s['FileDataset'].SeriesDescription}") + continue + if 't1' not in sd and ( 'flair' in sd or 't2' in sd or - 'perf' in sd # perfusion series (ep2d_perf) + + 'tof' in sd or + + 'dwi' in sd or + 'ep2d' in sd or + 'perf' in sd # perfusion series ): continue t1c.append(s) if not t1c: - logging.info(f"no t1c in {StudyDescription}") + logging.warning(f"no t1c in {s['FileDataset'].StudyDescription}") for s in brain_list: - logging.info(f"{s['FileDataset'].SeriesNumber} {s['FileDataset'].SeriesDescription} {len(s['files'])}") + logging.warning(f"{s['FileDataset'].SeriesNumber} {s['FileDataset'].SeriesDescription} {len(s['files'])}") return None t1c_axial = [] @@ -147,23 +255,20 @@ def check_study(study_dir): for s in t1c: c = collections.Counter(s['orientations']) orientation_str = c.most_common(1)[0][0] - orientation_float = tuple(float(f) for f in orientation_str) - orientation = tuple(round(f) for f in orientation_float) - s['Orientation'] = orientation - if is_axial(orientation): - logging.info(f"--- {s['FileDataset'].SeriesNumber} {s['FileDataset'].SeriesDescription} {s['Orientation']} {len(s['files'])}") + if is_axial(orientation_str): + logging.warning(f"--- {s['FileDataset'].SeriesNumber} {s['FileDataset'].SeriesDescription} {orientation_str} {len(s['files'])}") t1c_axial.append(s) if not t1c_axial: - logging.info(f"no axial t1c in {study_dir}") + logging.warning(f"no axial t1c in {study_dir}") for s in t1c: - logging.info(f"{s['FileDataset'].SeriesNumber} {s['FileDataset'].SeriesDescription} {s['Orientation']} {len(s['files'])} {StudyDescription}") + logging.warning(f"{s['FileDataset'].SeriesNumber} {s['FileDataset'].SeriesDescription} {orientation_str} {len(s['files'])} {s['FileDataset'].StudyDescription}") return None best_series = max(t1c_axial, key=lambda x: (len(x['files']), -x['FileDataset'].SeriesNumber)) # best_series 
= min(t1c_axial, key=lambda x: len(x['files'], x['FileDataset'].SeriesNumber))) - logging.info(f"{best_series['FileDataset'].SeriesNumber} {best_series['FileDataset'].SeriesDescription} {best_series['Orientation']} {len(best_series['files'])}") + logging.warning(f"{best_series['FileDataset'].SeriesNumber} {best_series['FileDataset'].SeriesDescription} {orientation_str} {len(best_series['files'])}") return best_series @@ -176,10 +281,10 @@ def hashptid(mrn, hosp='NTUH'): hash = base64.b32encode(hash_in_bytes.digest())[:8].decode() return md5, hash -def anonymize_series_to_nifti(series_files, dst_dir): +def anonymize_series_to_nifti(study_dir, series_files, dst_dir): os.makedirs(dst_dir, exist_ok=True) for f in series_files: - ds = pydicom.dcmread(f) + ds = pydicom.dcmread(os.path.join(study_dir, f)) md5, hash = hashptid(ds.PatientID) for elem in ds: if elem.tag.group == 0x0010: @@ -195,7 +300,7 @@ def anonymize_series_to_nifti(series_files, dst_dir): if e.is_file() and e.name.endswith(".nii.gz"): stem = f"{hash}-{ds.StudyDate}" dst_file = os.path.join(imagesTs_DIR, f"{stem}_0000.nii.gz") - logging.info(f"copying to {dst_file}") + logging.warning(f"copying to {dst_file}") shutil.copyfile(e.path, dst_file) NII_DICT[stem] = os.path.relpath(dst_dir, DST_ROOT) @@ -203,7 +308,7 @@ def main(): FORMAT = '%(asctime)s [%(filename)s:%(lineno)d] %(message)s' logging.basicConfig( - level=logging.INFO, + level=logging.WARNING, format=FORMAT, handlers=[ logging.StreamHandler(), @@ -212,7 +317,7 @@ def main(): ] ) - shutil.rmtree(imagesTs_DIR, ignore_errors=True) + # shutil.rmtree(imagesTs_DIR, ignore_errors=True) os.makedirs(imagesTs_DIR, exist_ok=True) for patho in sorted(os.listdir(SRC_ROOT)): @@ -234,8 +339,12 @@ def main(): if not os.path.isdir(patient_dir): continue + if not os.path.isfile(patient_dir+'.complete'): + continue + + if not os.path.isfile(os.path.join(patho_dir, f"{patient}.complete")): - logging.info(f"skip {patient_dir}") + logging.warning(f"skip 
{patient_dir}") continue md5, hash = hashptid(patient) @@ -244,7 +353,7 @@ def main(): complete_file = os.path.join(DST_ROOT, patho, f'{hash}.complete') if os.path.exists(complete_file): - logging.info(f"skip {patient_dir}") + logging.warning(f"skip {patient_dir}") continue num_study = 0 @@ -252,20 +361,20 @@ def main(): for study in sorted(os.listdir(patient_dir), reverse=True): study_date = study.split('_')[0] if datetime.datetime.strptime(study_date, "%Y%m%d") > LAST_DAY: - logging.info(f"skip {study_date}") + logging.warning(f"skip {study_date}") continue study_dir = os.path.join(patient_dir, study) if not os.path.isdir(study_dir): continue - logging.info(study_dir) + # logging.warning(study_dir) best_series = check_study(study_dir) if not best_series: continue dst_dir = os.path.join(dst_patient_dir, study) - anonymize_series_to_nifti(best_series['files'], dst_dir) + anonymize_series_to_nifti(study_dir, best_series['files'], dst_dir) num_study += 1 if num_study > 0: @@ -275,11 +384,79 @@ def main(): if num_patient >= MAX_PATIENT: break # break - print(NII_DICT) - logging.info(f"BodyPartExamined: {BodyPartExamined}") + logging.warning(f"BodyPartExamined: {BodyPartExamined}") with open(NII_JSON_PATH, 'w') as f: json.dump(NII_DICT, f, indent=1) +def list_4d(): + PLUS_3D = set() + PLUS_4D = set() + + + for study_id, study in study_db.items(): + + brain_list = [] + body_parts = set() + + s = None + + for uid, s in study.items(): + # logging.warning(f"{s['FileDataset'].SeriesNumber} {s['FileDataset'].BodyPartExamined} {s['FileDataset'].SeriesDescription} {len(s['files'])} {s['1st_file']}") + if 'BodyPartExamined' in s['FileDataset']: + if s['FileDataset'].BodyPartExamined in BodyPartIncluded: + brain_list.append(s) + else: + body_parts.add(s['FileDataset'].BodyPartExamined) + + if not brain_list: + if body_parts: + logging.warning(f"no brain, BodyPartExamined: {body_parts}") + continue + else: + if not s: + logging.warning(f"no series found") + continue + 
logging.warning(f"BodyPartExamined is empty") + if 'brain' in s['FileDataset'].StudyDescription.lower(): + logging.warning(f"brain in {s['FileDataset'].StudyDescription}, adding all series") + brain_list = list(study.values()) + else: + logging.warning(f"no brain in {s['FileDataset'].StudyDescription}") + continue + + + for s in brain_list: + sd = s['FileDataset'].SeriesDescription.lower() + if not ('+' in sd or 'gd' in sd): + continue + + if 't1' in sd: + continue + if 't2' in sd: + continue + if 'flair' in sd: + continue + + c = collections.Counter(s['orientations']) + orientation_str = c.most_common(1)[0][0] + if not is_axial(orientation_str): + continue + + if s['is_4d']: + logging.warning(f"4D series found: {s['FileDataset'].SeriesDescription}") + PLUS_4D.add(s['FileDataset'].SeriesDescription) + else: + PLUS_3D.add(s['FileDataset'].SeriesDescription) + + with open("plus_4d.json", 'w') as f: + json.dump(sorted(PLUS_4D), f, indent=1) + + with open("plus_3d.json", 'w') as f: + json.dump(sorted(PLUS_3D), f, indent=1) + + exit() + if __name__ == '__main__': + # list_4d() main() diff --git a/kimo/3-inference.py b/kimo/3-inference.py index b413f0d..897e759 100644 --- a/kimo/3-inference.py +++ b/kimo/3-inference.py @@ -38,14 +38,33 @@ def main(): config = yaml.safe_load(f) PREPROCESSED_DIR = config["nnunet_preprocessed"] + RAW_DIR = config["nnunet_raw"] RESULTS_DIR = config["nnunet_results"] + + IMAGES_TS = os.path.join(RAW_DIR, DATASET, "imagesTs") + IMAGES_TS_DONE = os.path.join(RAW_DIR, DATASET, "imagesTs_done") + + POSTPROCESSED_DIR = os.path.join(RESULTS_DIR, DATASET, "ensemble_predictions_postprocessed") + + os.makedirs(IMAGES_TS_DONE, exist_ok=True) + if os.path.exists(POSTPROCESSED_DIR): + processed_files = [f for f in os.listdir(POSTPROCESSED_DIR) if f.endswith('.nii.gz')] + processed_ids = [f.replace('.nii.gz', '') for f in processed_files] + for pid in processed_ids: + for mod in range(10): # support up to 10 modalities + mod_file = 
f"{pid}_{mod:04d}.nii.gz" + src_path = os.path.join(IMAGES_TS, mod_file) + dst_path = os.path.join(IMAGES_TS_DONE, mod_file) + if os.path.exists(src_path): + shutil.move(src_path, dst_path) + # clean up shutil.rmtree(f'{RESULTS_DIR}/{DATASET}/pred_2d', ignore_errors=True) shutil.rmtree(f'{RESULTS_DIR}/{DATASET}/pred_3d_fullres', ignore_errors=True) shutil.rmtree(f'{RESULTS_DIR}/{DATASET}/pred_3d_lowres', ignore_errors=True) shutil.rmtree(f'{RESULTS_DIR}/{DATASET}/ensemble_predictions', ignore_errors=True) - shutil.rmtree(f'{RESULTS_DIR}/{DATASET}/ensemble_predictions_postprocessed', ignore_errors=True) + # shutil.rmtree(f'{RESULTS_DIR}/{DATASET}/ensemble_predictions_postprocessed', ignore_errors=True) runner = nnUNetV2Runner(YAML) try: diff --git a/kimo/4-final-copy.py b/kimo/4-final-copy.py index b563423..e0dc199 100644 --- a/kimo/4-final-copy.py +++ b/kimo/4-final-copy.py @@ -187,15 +187,38 @@ def main(): ] ) - with open(NII_JSON_PATH, 'r') as f: - NII_DICT = json.load(f) + NII_DICT = {} + + for pathology in sorted(os.listdir(DICOM_ROOT)): + pathology_path = os.path.join(DICOM_ROOT, pathology) + if not os.path.isdir(pathology_path): + continue + pathology_name = pathology + + for patient in sorted(os.listdir(pathology_path)): + patient_path = os.path.join(pathology_path, patient) + if not os.path.isdir(patient_path): + continue + patient_name = patient + + for study in sorted(os.listdir(patient_path)): + study_path = os.path.join(patient_path, study) + if not os.path.isdir(study_path): + continue + study_name = study + + key = f"{patient_name}-{study_name.split('_')[0]}" + NII_DICT[key] = os.path.relpath(study_path, DICOM_ROOT) + # print(NII_DICT) + # exit() for stem, dicomdir in sorted(NII_DICT.items()): - logging.info(f"{stem} {dicomdir}") hash, date = stem.split("-") if hash in CASE_SET: + logging.info(f"skip {stem} {dicomdir}") continue + logging.info(f"{stem} {dicomdir}") category = dicomdir.split("/")[0] image_nii = f"{imagesTs_DIR}/{stem}_0000.nii.gz" @@ 
-207,8 +230,10 @@ def main(): dest_dir = f"{category_dir}/{stem}" if not os.path.exists(image_nii): + logging.info(f"{image_nii} not exists") continue if not os.path.exists(label_nii): + logging.info(f"{label_nii} not exists") continue con_ov = contour_overlay(image_nii, label_nii) @@ -222,7 +247,7 @@ def main(): shutil.copyfile(image_nii, f"{dest_dir}/image.nii.gz") shutil.copyfile(label_nii, f"{dest_dir}/label.nii.gz") - # copytree(dicomdir, f"{dest_dir}/DICOM") + copytree(dicomdir, f"{dest_dir}/DICOM") CASE_SET.add(hash) # exit() diff --git a/test_dicom_json.py b/test_dicom_json.py new file mode 100644 index 0000000..c408347 --- /dev/null +++ b/test_dicom_json.py @@ -0,0 +1,10 @@ +import pydicom +import json +ds = pydicom.dataset.Dataset() +ds.PatientName = "Test" +json_str = ds.to_json() +print("to_json:", json_str) +dict_obj = json.loads(json_str) + +print("type of dict_obj:", type(dict_obj)) +print("dict_obj keys:", dict_obj.keys())