from pathlib import Path
import base64
import collections
import datetime
import hashlib
import json
import logging
import os
import shelve  # only needed for the commented-out shelve backend below
import shutil
import sys
import tempfile

from lmdbm import Lmdb
from tqdm import tqdm
import dicom2nifti
import pydicom

# Patient hashes to skip, and (if non-empty) the only patient hashes to process.
EXCLUDED_HASH = [
    # "GWJU7LPC",
]
INCLUDED_HASH = [
    # "GWJU7LPC",
]

# Studies dated after LAST_DAY are skipped.
LAST_DAY = datetime.datetime.strptime("2025-11-01", "%Y-%m-%d")
# MAX_PATIENT = 10//3  # small limit for debugging
MAX_PATIENT = 1000  # per pathology directory

SRC_ROOT = "/mnt/t24/Public/kimo/TSHA"
RAW_DIR = "/mnt/t24/Public/kimo/raw/"
DST_ROOT = os.path.join(RAW_DIR, "DICOM")
imagesTs_DIR = os.path.join(RAW_DIR, "Dataset2602_BraTS-CK/imagesTs/")

# Maps "<hash>-<StudyDate>" to the anonymized DICOM directory (relative to DST_ROOT).
NII_JSON_PATH = os.path.join(RAW_DIR, 'nii.json')
if os.path.isfile(NII_JSON_PATH):
    with open(NII_JSON_PATH, 'r') as f:
        NII_DICT = json.load(f)
else:
    NII_DICT = {}

# {'', 'PELVISLOWEXTREM', 'BRAIN', 'CSPINE', 'KNEE', 'TSPINE', 'CAROTID', 'NECK', 'ABDOMEN', 'ORBIT', 'HEAD', 'CHEST', 'IAC', 'WHOLEBODY', 'WHOLESPINE', 'ABDOMENPELVIS', 'PELVIS', 'LSPINE', 'SPINE', 'CIRCLEOFWILLIS'}
# BodyPartExamined: Counter({'BRAIN': 152087, 'ABDOMEN': 14101, 'HEAD': 11806, 'ABDOMENPELVIS': 10905, 'SPINE': 9277, 'CHEST': 3746, 'PELVIS': 3208, 'NECK': 3205, 'CSPINE': 1527, 'CAROTID': 1186,
# 'HEART': 1122, 'LSPINE': 1080, 'KNEE': 591, 'PELVISLOWEXTREM': 496, '': 385, 'ORBIT': 360, 'CIRCLEOFWILLIS': 322, 'HUMERUS': 320, 'ARM': 304, 'IAC': 291,
# 'EXTREMITY': 287, 'SHOULDER': 242, 'WHOLEBODY': 190, 'TSPINE': 150, 'HEADNECK': 48, 'WHOLESPINE': 45})
BodyPartExamined = collections.Counter()
BodyPartIncluded = set([
    'BRAIN',
    'BRAIN_CE',
    'CIRCLEOFWILLIS',
    'HEAD',
    'IAC',
    # 'ORBIT',
])

STUDY_DB_PATH = 'study.db'


class JsonLmdb(Lmdb):
    """LMDB-backed mapping that stores values as JSON, including pydicom datasets."""

    def _pre_key(self, value):
        return value.encode("utf-8")

    def _post_key(self, value):
        return value.decode("utf-8")

    def _pre_value(self, value):
        def default(o):
            # Serialize pydicom datasets through their DICOM JSON representation.
            if isinstance(o, (pydicom.dataset.FileDataset, pydicom.Dataset)):
                return json.loads(o.to_json(suppress_invalid_tags=True))
            raise TypeError(f"Object of type {o.__class__.__name__} is not JSON serializable")
        return json.dumps(value, default=default).encode("utf-8")

    def _post_value(self, value):
        data = json.loads(value.decode("utf-8"))
        if isinstance(data, dict):
            # Rebuild the per-series dataset that _pre_value flattened to JSON.
            for k, v in data.items():
                if isinstance(v, dict) and 'FileDataset' in v:
                    if isinstance(v['FileDataset'], (dict, str)):
                        v['FileDataset'] = pydicom.Dataset.from_json(v['FileDataset'])
        return data


# study_db = shelve.open(STUDY_DB_PATH, 'c')
study_db = JsonLmdb.open(STUDY_DB_PATH, "c")


def is_axial(s):
    """Return True if a rounded orientation string describes an axial slice.

    The string holds the six direction cosines "rx ry rz cx cy cz"; a slice is
    treated as axial when ry, rz, cx and cz are all zero.
    """
    o = s.split(' ')
    return o[1] == '0' and o[2] == '0' and o[3] == '0' and o[5] == '0'


def int_orientation(o):
    """Round ImageOrientationPatient to integers and join them into one string."""
    orientation = [str(round(f)) for f in o]
    return ' '.join(orientation)


def dominant_orientation(series):
    """Return the most common rounded orientation string of a series."""
    return collections.Counter(series['orientations']).most_common(1)[0][0]


def check_4d_series(ds_list):
    """Return True if any slice position occurs more than once in the series."""
    # Group slices by their 3D coordinates (x, y, z).
    spatial_groups = collections.defaultdict(list)
    for ds in ds_list:
        try:
            # Use ImagePositionPatient (0020, 0032) to identify unique 3D locations.
            if "ImagePositionPatient" in ds:
                pos = tuple(ds.ImagePositionPatient)
                # Store relevant temporal tags for inspection.
                time_info = {
                    "AcquisitionTime": getattr(ds, "AcquisitionTime", "N/A"),
                    "ContentTime": getattr(ds, "ContentTime", "N/A"),
                    "InstanceNumber": getattr(ds, "InstanceNumber", "N/A"),
                }
                spatial_groups[pos].append(time_info)
        except Exception as e:
            # The dataset is already loaded, so report its source file if known.
            logging.error(f"Error reading {getattr(ds, 'filename', '<unknown>')}: {e}")

    # Several slices at one position are taken as temporal frames.
    is_4d = False
    for pos, slices in spatial_groups.items():
        if len(slices) > 1:
            is_4d = True
            logging.warning(f"{ds.SeriesDescription} Detected 4D: Position {pos} has {len(slices)} temporal frames.")
            break
    # if not is_4d:
    #     logging.warning(f"{ds.SeriesDescription} Series appears to be standard 3D (one slice per position).")
    return is_4d


def select_brain_series(study):
    """Return the series of a study that look like brain imaging.

    Series are chosen by BodyPartExamined. If no series carries that tag, all
    series are returned when the StudyDescription mentions "brain". An empty
    list means the study should be skipped.
    """
    brain_list = []
    body_parts = set()
    s = None
    for uid, s in study.items():
        # logging.warning(f"{s['FileDataset'].SeriesNumber} {s['FileDataset'].BodyPartExamined} {s['FileDataset'].SeriesDescription} {len(s['files'])} {s['1st_file']}")
        if 'BodyPartExamined' in s['FileDataset']:
            if s['FileDataset'].BodyPartExamined in BodyPartIncluded:
                brain_list.append(s)
            else:
                body_parts.add(s['FileDataset'].BodyPartExamined)
    if brain_list:
        return brain_list
    if body_parts:
        logging.warning(f"no brain, BodyPartExamined: {body_parts}")
        return []
    if not s:
        logging.warning("no series found")
        return []
    logging.warning("BodyPartExamined is empty")
    if 'brain' in s['FileDataset'].StudyDescription.lower():
        logging.warning(f"brain in {s['FileDataset'].StudyDescription}, adding all series")
        return list(study.values())
    logging.warning(f"no brain in {s['FileDataset'].StudyDescription}")
    return []


def check_study(study_dir):
    """Return the best axial contrast-enhanced T1 series of a study, or None."""
    key = '_'.join(Path(study_dir).parts[-2:])
    # with JsonLmdb.open(STUDY_DB_PATH, "c") as db:
    study = study_db.get(key, None)
    if study:
        logging.warning(f"use cached series for {study_dir}")
    else:
        logging.warning(f"scan series for {study_dir}")
        SeriesDescription = set()
        study = {}
        DCMS = []
        SERIES_DS = {}
        for root, dirs, files in os.walk(study_dir):
            # Filter before sorting so that non-DICOM names cannot break the numeric sort key.
            dcm_names = [name for name in files if name.endswith(".dcm")]
            for file in sorted(dcm_names, key=lambda x: int(x.split("_")[-1].split(".")[0])):
                DCMS.append(os.path.join(root, file))
        for dcm_file in tqdm(DCMS):
            ds = pydicom.dcmread(dcm_file, force=True, stop_before_pixels=True)
            if 'BodyPartExamined' in ds:
                BodyPartExamined[ds.BodyPartExamined] += 1
            if 'ImageOrientationPatient' not in ds:
                continue
            # print(f"{dcm_file}")
            series_instance_uid = ds.SeriesInstanceUID
            SeriesDescription.add(ds.SeriesDescription)
            # print(body_part_examined, series_description)
            if series_instance_uid not in study:
                # The first dataset of each series stands in for the whole series.
                study[series_instance_uid] = {
                    'FileDataset': ds,
                    '1st_file': dcm_file,
                    'orientations': [],
                    'files': [],
                }
                SERIES_DS[series_instance_uid] = []
            study[series_instance_uid]['files'].append(os.path.relpath(dcm_file, study_dir))
            study[series_instance_uid]['orientations'].append(int_orientation(ds.ImageOrientationPatient))
            SERIES_DS[series_instance_uid].append(ds)
        for uid, ds_list in SERIES_DS.items():
            study[uid]['is_4d'] = check_4d_series(ds_list)
        # with JsonLmdb.open(STUDY_DB_PATH, "c") as db:
        # with shelve.open(STUDY_DB_PATH) as db:
        study_db[key] = study

    brain_list = select_brain_series(study)
    if not brain_list:
        return None

    # Keep contrast-enhanced series ('+' or 'gd' in the description) that are not 4D
    # and are not clearly another sequence type.
    t1c = []
    for s in brain_list:
        sd = s['FileDataset'].SeriesDescription.lower()
        if not ('+' in sd or 'gd' in sd):
            continue
        if s['is_4d']:
            logging.warning(f"4D series found: {s['FileDataset'].SeriesDescription}")
            continue
        if 't1' not in sd and (
            'flair' in sd
            or 't2' in sd
            or 'tof' in sd
            or 'dwi' in sd
            or 'ep2d' in sd
            or 'perf' in sd  # perfusion series
        ):
            continue
        t1c.append(s)
    if not t1c:
        logging.warning(f"no t1c in {s['FileDataset'].StudyDescription}")
        for s in brain_list:
            logging.warning(f"{s['FileDataset'].SeriesNumber} {s['FileDataset'].SeriesDescription} {len(s['files'])}")
        return None

    t1c_axial = []
    for s in t1c:
        orientation_str = dominant_orientation(s)
        if is_axial(orientation_str):
            logging.warning(f"--- {s['FileDataset'].SeriesNumber} {s['FileDataset'].SeriesDescription} {orientation_str} {len(s['files'])}")
            t1c_axial.append(s)
    if not t1c_axial:
        logging.warning(f"no axial t1c in {study_dir}")
        for s in t1c:
            # Log each series with its own orientation rather than the last one computed above.
            logging.warning(f"{s['FileDataset'].SeriesNumber} {s['FileDataset'].SeriesDescription} {dominant_orientation(s)} {len(s['files'])} {s['FileDataset'].StudyDescription}")
        return None

    # Prefer the series with the most slices; break ties with the lowest SeriesNumber.
    best_series = max(t1c_axial, key=lambda x: (len(x['files']), -x['FileDataset'].SeriesNumber))
    # best_series = min(t1c_axial, key=lambda x: len(x['files'], x['FileDataset'].SeriesNumber)))
    logging.warning(f"{best_series['FileDataset'].SeriesNumber} {best_series['FileDataset'].SeriesDescription} {dominant_orientation(best_series)} {len(best_series['files'])}")
    return best_series


def hashptid(mrn, hosp='NTUH'):
    """Return (md5 hex digest, 8-character base32 prefix) of the MRN plus hospital code."""
    ptsalt = (mrn + hosp).upper().encode()
    hash_in_bytes = hashlib.md5(ptsalt)
    md5 = hash_in_bytes.hexdigest()
    short_hash = base64.b32encode(hash_in_bytes.digest())[:8].decode()
    return md5, short_hash


def anonymize_series_to_nifti(study_dir, series_files, dst_dir):
    """Blank patient-group tags, save the series to dst_dir, and convert it to NIfTI."""
    os.makedirs(dst_dir, exist_ok=True)
    for f in series_files:
        ds = pydicom.dcmread(os.path.join(study_dir, f))
        _, pt_hash = hashptid(ds.PatientID)
        # Blank every element in the patient group (0010), then store the hash as PatientID.
        for elem in ds:
            if elem.tag.group == 0x0010:
                elem.value = ''
        ds.PatientID = pt_hash
        dst_file = os.path.join(dst_dir, os.path.basename(f.split("_")[-1]))
        ds.save_as(dst_file)

    with tempfile.TemporaryDirectory() as tmpdirname:
        dicom2nifti.convert_directory(dst_dir, tmpdirname, compression=True)
        for e in os.scandir(tmpdirname):
            if e.is_file() and e.name.endswith(".nii.gz"):
                # pt_hash and ds come from the last file written above.
                stem = f"{pt_hash}-{ds.StudyDate}"
                dst_file = os.path.join(imagesTs_DIR, f"{stem}_0000.nii.gz")
                logging.warning(f"copying to {dst_file}")
                shutil.copyfile(e.path, dst_file)
                NII_DICT[stem] = os.path.relpath(dst_dir, DST_ROOT)


def main():
    FORMAT = '%(asctime)s [%(filename)s:%(lineno)d] %(message)s'
    logging.basicConfig(
        level=logging.WARNING,
        format=FORMAT,
        handlers=[
            logging.StreamHandler(),
            # logging.FileHandler(__file__.replace('.py','.%s.log'%str(datetime.datetime.now()).replace(':','')), encoding='utf-8')
            logging.FileHandler(__file__.replace('.py', '.log'), encoding='utf-8'),
        ]
    )
    # shutil.rmtree(imagesTs_DIR, ignore_errors=True)
    os.makedirs(imagesTs_DIR, exist_ok=True)

    for patho in sorted(os.listdir(SRC_ROOT)):
        patho_dir = os.path.join(SRC_ROOT, patho)
        num_patient = 0
        for patient in sorted(os.listdir(patho_dir)):
            _, pt_hash = hashptid(patient)
            if INCLUDED_HASH:
                if pt_hash not in INCLUDED_HASH:
                    continue
            if pt_hash in EXCLUDED_HASH:
                continue
            patient_dir = os.path.join(patho_dir, patient)
            if not os.path.isdir(patient_dir):
                continue
            # Only process patients whose source directory has a "<patient>.complete" marker.
            if not os.path.isfile(patient_dir + '.complete'):
                logging.warning(f"skip {patient_dir}: source not marked complete")
                continue
            dst_patient_dir = os.path.join(DST_ROOT, patho, pt_hash)
            complete_file = os.path.join(DST_ROOT, patho, f'{pt_hash}.complete')
            if os.path.exists(complete_file):
                logging.warning(f"skip {patient_dir}: already converted")
                continue

            num_study = 0
            # Newest study first.
            for study in sorted(os.listdir(patient_dir), reverse=True):
                study_dir = os.path.join(patient_dir, study)
                # Check for a directory first so stray files never reach the date parser.
                if not os.path.isdir(study_dir):
                    continue
                study_date = study.split('_')[0]
                if datetime.datetime.strptime(study_date, "%Y%m%d") > LAST_DAY:
                    logging.warning(f"skip {study_date}")
                    continue
                # logging.warning(study_dir)
                best_series = check_study(study_dir)
                if not best_series:
                    continue
                dst_dir = os.path.join(dst_patient_dir, study)
                anonymize_series_to_nifti(study_dir, best_series['files'], dst_dir)
                num_study += 1

            if num_study > 0:
                with open(complete_file, 'w') as f:
                    f.write('done')
                # Assumption: only patients with at least one converted study
                # count toward MAX_PATIENT.
                num_patient += 1
                if num_patient >= MAX_PATIENT:
                    break
        # break

    logging.warning(f"BodyPartExamined: {BodyPartExamined}")
    with open(NII_JSON_PATH, 'w') as f:
        json.dump(NII_DICT, f, indent=1)


def list_4d():
    """Dump descriptions of axial contrast series not labeled T1/T2/FLAIR, split by 3D vs 4D."""
    PLUS_3D = set()
    PLUS_4D = set()
    for study_id, study in study_db.items():
        brain_list = select_brain_series(study)
        if not brain_list:
            continue
        for s in brain_list:
            sd = s['FileDataset'].SeriesDescription.lower()
            if not ('+' in sd or 'gd' in sd):
                continue
            if 't1' in sd:
                continue
            if 't2' in sd:
                continue
            if 'flair' in sd:
                continue
            if not is_axial(dominant_orientation(s)):
                continue
            if s['is_4d']:
                logging.warning(f"4D series found: {s['FileDataset'].SeriesDescription}")
                PLUS_4D.add(s['FileDataset'].SeriesDescription)
            else:
                PLUS_3D.add(s['FileDataset'].SeriesDescription)
    with open("plus_4d.json", 'w') as f:
        json.dump(sorted(PLUS_4D), f, indent=1)
    with open("plus_3d.json", 'w') as f:
        json.dump(sorted(PLUS_3D), f, indent=1)
    # Stop here so that main() does not run after the listing.
    sys.exit(0)


if __name__ == '__main__':
    # list_4d()
    main()