tsha-mri-tumor-labeling/kimo/2-check-series.py

463 lines
15 KiB
Python
Raw Permalink Normal View History

2026-03-07 11:32:28 +00:00
from pathlib import Path
import base64
import collections
import datetime
import hashlib
import json
import logging
import os
import shelve
import shutil
import tempfile
2026-03-07 11:32:28 +00:00
from lmdbm import Lmdb
from tqdm import tqdm
import dicom2nifti
import pydicom
# Patient hashes to skip entirely (takes precedence over INCLUDED_HASH).
EXCLUDED_HASH = [
    # "GWJU7LPC",
]
# When non-empty, ONLY these anonymized patient hashes are processed.
INCLUDED_HASH = [
    # "GWJU7LPC",
]
# Studies dated after this day are skipped.
LAST_DAY = datetime.datetime.strptime("2025-11-01", "%Y-%m-%d")
# Upper bound on patients processed per pathology directory.
# (A dead debug assignment `MAX_PATIENT = 10//3` that was immediately
# overwritten has been removed.)
MAX_PATIENT = 1000
SRC_ROOT = "/mnt/t24/Public/kimo/TSHA"
RAW_DIR = "/mnt/t24/Public/kimo/raw/"
DST_ROOT = os.path.join(RAW_DIR, "DICOM")
imagesTs_DIR = os.path.join(RAW_DIR, "Dataset2602_BraTS-CK/imagesTs/")
NII_JSON_PATH = os.path.join(RAW_DIR, 'nii.json')
# Mapping of produced NIfTI stem -> relative anonymized-DICOM dir,
# resumed from a previous run when the JSON file exists.
if os.path.isfile(NII_JSON_PATH):
    with open(NII_JSON_PATH, 'r') as f:
        NII_DICT = json.load(f)
else:
    NII_DICT = {}
# Observed values of BodyPartExamined across the corpus (for reference):
# {'', 'PELVISLOWEXTREM', 'BRAIN', 'CSPINE', 'KNEE', 'TSPINE', 'CAROTID', 'NECK', 'ABDOMEN', 'ORBIT', 'HEAD', 'CHEST', 'IAC', 'WHOLEBODY', 'WHOLESPINE', 'ABDOMENPELVIS', 'PELVIS', 'LSPINE', 'SPINE', 'CIRCLEOFWILLIS'}
# BodyPartExamined: Counter({'BRAIN': 152087, 'ABDOMEN': 14101, 'HEAD': 11806, 'ABDOMENPELVIS': 10905, 'SPINE': 9277, 'CHEST': 3746, 'PELVIS': 3208, 'NECK': 3205, 'CSPINE': 1527, 'CAROTID': 1186,
# 'HEART': 1122, 'LSPINE': 1080, 'KNEE': 591, 'PELVISLOWEXTREM': 496, '': 385, 'ORBIT': 360, 'CIRCLEOFWILLIS': 322, 'HUMERUS': 320, 'ARM': 304, 'IAC': 291,
# 'EXTREMITY': 287, 'SHOULDER': 242, 'WHOLEBODY': 190, 'TSPINE': 150, 'HEADNECK': 48, 'WHOLESPINE': 45})
# Running tally of BodyPartExamined values seen during scanning.
BodyPartExamined = collections.Counter()
# Body parts considered brain-adjacent and therefore eligible for selection.
BodyPartIncluded = {
    'BRAIN',
    'BRAIN_CE',
    'CIRCLEOFWILLIS',
    'HEAD',
    'IAC',
    # 'ORBIT',
}
STUDY_DB_PATH = 'study.db'
class JsonLmdb(Lmdb):
    """LMDB-backed mapping storing values as JSON, with pydicom dataset support.

    Keys are UTF-8 strings; values are JSON-serializable structures in which
    pydicom datasets round-trip through their DICOM-JSON representation.
    """

    def _pre_key(self, value):
        # Keys go into LMDB as UTF-8 bytes.
        return value.encode("utf-8")

    def _post_key(self, value):
        return value.decode("utf-8")

    def _pre_value(self, value):
        # Serialize pydicom datasets via their own DICOM-JSON encoding;
        # anything else non-serializable is a hard error.
        def default(o):
            if isinstance(o, (pydicom.dataset.FileDataset, pydicom.Dataset)):
                return json.loads(o.to_json(suppress_invalid_tags=True))
            raise TypeError(f"Object of type {o.__class__.__name__} is not JSON serializable")
        return json.dumps(value, default=default).encode("utf-8")

    def _post_value(self, value):
        # Decode JSON and rehydrate any nested 'FileDataset' entries back
        # into pydicom Dataset objects.
        data = json.loads(value.decode("utf-8"))
        if isinstance(data, dict):
            for entry in data.values():
                if isinstance(entry, dict) and 'FileDataset' in entry:
                    raw = entry['FileDataset']
                    if isinstance(raw, (dict, str)):
                        entry['FileDataset'] = pydicom.Dataset.from_json(raw)
        return data
# study_db = shelve.open(STUDY_DB_PATH, 'c')
# Module-level cache of per-study scan results, keyed "<patient>_<study>".
study_db = JsonLmdb.open(STUDY_DB_PATH, "c")
def is_axial(s):
    """Return True if a rounded-orientation string (6 space-separated ints,
    as produced by int_orientation) describes an axial slice plane.

    Axial orientation has zeros at positions 1, 2, 3 and 5 ('x 0 0 0 y 0').
    """
    parts = s.split(' ')
    return all(parts[i] == '0' for i in (1, 2, 3, 5))
def int_orientation(o):
    """Round each direction cosine of *o* to the nearest integer and join
    them into a single space-separated string (e.g. '1 0 0 0 1 0')."""
    return ' '.join(str(round(component)) for component in o)
def check_4d_series(ds_list):
    """Return True when a series is 4D, i.e. has more than one frame at the
    same 3D position.

    ds_list: pydicom datasets (header-only is fine) belonging to one series.
    """
    # Group temporal metadata by unique ImagePositionPatient (x, y, z).
    spatial_groups = collections.defaultdict(list)
    for ds in ds_list:
        try:
            # ImagePositionPatient (0020,0032) identifies the 3D slice location.
            if "ImagePositionPatient" in ds:
                pos = tuple(ds.ImagePositionPatient)
                # Keep temporal tags alongside, for later inspection.
                time_info = {
                    "AcquisitionTime": getattr(ds, "AcquisitionTime", "N/A"),
                    "ContentTime": getattr(ds, "ContentTime", "N/A"),
                    "InstanceNumber": getattr(ds, "InstanceNumber", "N/A")
                }
                spatial_groups[pos].append(time_info)
        except Exception as e:
            # BUG FIX: this previously logged `dcm_file.name`, a name that does
            # not exist in this function, raising NameError inside the handler.
            logging.error(f"Error reading dataset: {e}")
    # Any position with multiple frames means the series is 4D.
    is_4d = False
    for pos, slices in spatial_groups.items():
        if len(slices) > 1:
            is_4d = True
            # `ds` is the last dataset seen; this branch implies ds_list is non-empty.
            logging.warning(f"{ds.SeriesDescription} Detected 4D: Position {pos} has {len(slices)} temporal frames.")
            break
    return is_4d
def check_study(study_dir):
    """Scan one study directory and pick the best axial contrast-enhanced
    T1 series.

    Returns the selected series dict (keys: 'FileDataset', '1st_file',
    'orientations', 'files', 'is_4d') or None when no suitable series
    exists. Per-study scan results are cached in the module-level study_db.
    """
    # Cache key: "<patient>_<study>" from the last two path components.
    key = '_'.join(Path(study_dir).parts[-2:])
    study = study_db.get(key, None)
    if study:
        logging.warning(f"use cached series for {study_dir}")
    else:
        logging.warning(f"scan series for {study_dir}")
        SeriesDescription = set()
        study = {}
        DCMS = []
        SERIES_DS = {}
        for root, dirs, files in os.walk(study_dir):
            # Sort by the trailing instance number in the filename (xxx_<n>.dcm).
            for file in sorted(files, key=lambda x: int(x.split("_")[-1].split(".")[0])):
                if file.endswith(".dcm"):
                    DCMS.append(os.path.join(root, file))
        for dcm_file in tqdm(DCMS):
            ds = pydicom.dcmread(dcm_file, force=True, stop_before_pixels=True)
            if 'BodyPartExamined' in ds:
                BodyPartExamined[ds.BodyPartExamined] += 1
            # Files without orientation (scouts, secondary captures) are useless here.
            if 'ImageOrientationPatient' not in ds:
                continue
            series_instance_uid = ds.SeriesInstanceUID
            SeriesDescription.add(ds.SeriesDescription)
            if series_instance_uid not in study:
                study[series_instance_uid] = {
                    'FileDataset': ds,
                    '1st_file': dcm_file,
                    'orientations': [],
                    'files': [],
                }
                SERIES_DS[series_instance_uid] = []
            study[series_instance_uid]['files'].append(os.path.relpath(dcm_file, study_dir))
            study[series_instance_uid]['orientations'].append(int_orientation(ds.ImageOrientationPatient))
            SERIES_DS[series_instance_uid].append(ds)
        for uid, ds_list in SERIES_DS.items():
            study[uid]['is_4d'] = check_4d_series(ds_list)
        study_db[key] = study
    # Collect series whose BodyPartExamined is a brain-adjacent region.
    brain_list = []
    body_parts = set()
    s = None
    for uid, s in study.items():
        if 'BodyPartExamined' in s['FileDataset']:
            if s['FileDataset'].BodyPartExamined in BodyPartIncluded:
                brain_list.append(s)
            else:
                body_parts.add(s['FileDataset'].BodyPartExamined)
    if not brain_list:
        if body_parts:
            logging.warning(f"no brain, BodyPartExamined: {body_parts}")
            return None
        else:
            if not s:
                logging.warning(f"no series found")
                return None
            logging.warning(f"BodyPartExamined is empty")
            # Fall back to StudyDescription when no series carries a body part.
            if 'brain' in s['FileDataset'].StudyDescription.lower():
                logging.warning(f"brain in {s['FileDataset'].StudyDescription}, adding all series")
                brain_list = list(study.values())
            else:
                logging.warning(f"no brain in {s['FileDataset'].StudyDescription}")
                return None
    # Keep contrast-enhanced ('+' or 'gd') non-4D series that are not clearly
    # a different sequence type (unless the description also mentions t1).
    t1c = []
    for s in brain_list:
        sd = s['FileDataset'].SeriesDescription.lower()
        if not ('+' in sd or 'gd' in sd):
            continue
        if s['is_4d']:
            logging.warning(f"4D series found: {s['FileDataset'].SeriesDescription}")
            continue
        if 't1' not in sd and (
            'flair' in sd or
            't2' in sd or
            'tof' in sd or
            'dwi' in sd or
            'ep2d' in sd or
            'perf' in sd  # perfusion series
        ):
            continue
        t1c.append(s)
    if not t1c:
        # `s` is the last series inspected above; used only for its StudyDescription.
        logging.warning(f"no t1c in {s['FileDataset'].StudyDescription}")
        for s in brain_list:
            logging.warning(f"{s['FileDataset'].SeriesNumber} {s['FileDataset'].SeriesDescription} {len(s['files'])}")
        return None

    def _dominant_orientation(series):
        # Most common rounded orientation across the series' slices.
        return collections.Counter(series['orientations']).most_common(1)[0][0]

    # Keep only axial series, judged by the dominant rounded orientation.
    t1c_axial = []
    for s in t1c:
        orientation_str = _dominant_orientation(s)
        if is_axial(orientation_str):
            logging.warning(f"--- {s['FileDataset'].SeriesNumber} {s['FileDataset'].SeriesDescription} {orientation_str} {len(s['files'])}")
            t1c_axial.append(s)
    if not t1c_axial:
        logging.warning(f"no axial t1c in {study_dir}")
        for s in t1c:
            # BUG FIX: previously logged a stale orientation_str left over from
            # the prior loop for every series; log each series' own orientation.
            orientation_str = _dominant_orientation(s)
            logging.warning(f"{s['FileDataset'].SeriesNumber} {s['FileDataset'].SeriesDescription} {orientation_str} {len(s['files'])} {s['FileDataset'].StudyDescription}")
        return None
    # Prefer the series with the most slices; tie-break on lowest SeriesNumber.
    best_series = max(t1c_axial, key=lambda x: (len(x['files']), -x['FileDataset'].SeriesNumber))
    # BUG FIX: report the chosen series' own orientation, not a leftover variable.
    orientation_str = _dominant_orientation(best_series)
    logging.warning(f"{best_series['FileDataset'].SeriesNumber} {best_series['FileDataset'].SeriesDescription} {orientation_str} {len(best_series['files'])}")
    return best_series
def hashptid(mrn, hosp='NTUH'):
    """Derive a deterministic anonymized patient ID from an MRN.

    The MRN and hospital code are concatenated, upper-cased and MD5-hashed.
    Returns (md5_hexdigest, first 8 base32 characters of the digest).
    """
    digest = hashlib.md5((mrn + hosp).upper().encode())
    return digest.hexdigest(), base64.b32encode(digest.digest())[:8].decode()
2026-03-07 11:32:28 +00:00
def anonymize_series_to_nifti(study_dir, series_files, dst_dir):
    """Anonymize one DICOM series into dst_dir, then convert it to NIfTI.

    study_dir: study directory that the relative paths in series_files are under.
    series_files: relative paths of the series' .dcm files.
    dst_dir: destination directory for the anonymized DICOM copies.

    Side effects: writes anonymized DICOM files into dst_dir, copies the
    converted .nii.gz into imagesTs_DIR, and records the stem -> source-dir
    mapping in the module-level NII_DICT.
    """
    os.makedirs(dst_dir, exist_ok=True)
    for f in series_files:
        ds = pydicom.dcmread(os.path.join(study_dir, f))
        md5, hash = hashptid(ds.PatientID)
        # Blank every element in the patient group (0010,xxxx) ...
        for elem in ds:
            if elem.tag.group == 0x0010:
                elem.value = ''
        # ... then restore PatientID as the anonymized hash.
        ds.PatientID = hash
        # Destination filename keeps only the trailing "<n>.dcm" part of the name.
        dst_file = os.path.join(dst_dir, os.path.basename(f.split("_")[-1]))
        ds.save_as(dst_file)
    # Convert the anonymized series; dicom2nifti writes into a temp dir first.
    with tempfile.TemporaryDirectory() as tmpdirname:
        dicom2nifti.convert_directory(dst_dir, tmpdirname, compression=True)
        for e in os.scandir(tmpdirname):
            if e.is_file() and e.name.endswith(".nii.gz"):
                # NOTE(review): `hash` and `ds` are leftovers from the last loop
                # iteration; this raises NameError if series_files is empty —
                # confirm callers always pass a non-empty list.
                stem = f"{hash}-{ds.StudyDate}"
                dst_file = os.path.join(imagesTs_DIR, f"{stem}_0000.nii.gz")
                logging.warning(f"copying to {dst_file}")
                shutil.copyfile(e.path, dst_file)
                NII_DICT[stem] = os.path.relpath(dst_dir, DST_ROOT)
def main():
    """Walk SRC_ROOT, pick the best series per study, anonymize and convert.

    Resumable: patients with a destination '<hash>.complete' marker are
    skipped; NII_DICT is persisted to NII_JSON_PATH at the end.
    """
    FORMAT = '%(asctime)s [%(filename)s:%(lineno)d] %(message)s'
    logging.basicConfig(
        level=logging.WARNING,
        format=FORMAT,
        handlers=[
            logging.StreamHandler(),
            # logging.FileHandler(__file__.replace('.py','.%s.log'%str(datetime.datetime.now()).replace(':','')), encoding='utf-8')
            logging.FileHandler(__file__.replace('.py','.log'), encoding='utf-8')
        ]
    )
    # shutil.rmtree(imagesTs_DIR, ignore_errors=True)
    os.makedirs(imagesTs_DIR, exist_ok=True)
    for patho in sorted(os.listdir(SRC_ROOT)):
        patho_dir = os.path.join(SRC_ROOT, patho)
        num_patient = 0
        for patient in sorted(os.listdir(patho_dir)):
            md5, hash = hashptid(patient)
            # Optional allow/deny filtering by anonymized hash.
            if INCLUDED_HASH:
                if hash not in INCLUDED_HASH:
                    continue
            if hash in EXCLUDED_HASH:
                continue
            patient_dir = os.path.join(patho_dir, patient)
            if not os.path.isdir(patient_dir):
                continue
            # BUG FIX: the source-completeness marker used to be checked twice
            # with the same path spelled two ways; the first (silent) check
            # made the second (logging) check dead code. Check once and log.
            if not os.path.isfile(os.path.join(patho_dir, f"{patient}.complete")):
                logging.warning(f"skip {patient_dir}")
                continue
            # (A duplicate `md5, hash = hashptid(patient)` call was removed.)
            dst_patient_dir = os.path.join(DST_ROOT, patho, hash)
            complete_file = os.path.join(DST_ROOT, patho, f'{hash}.complete')
            # Skip patients already fully processed on a previous run.
            if os.path.exists(complete_file):
                logging.warning(f"skip {patient_dir}")
                continue
            num_study = 0
            # Most recent study first (directory names start with YYYYMMDD).
            for study in sorted(os.listdir(patient_dir), reverse=True):
                study_date = study.split('_')[0]
                if datetime.datetime.strptime(study_date, "%Y%m%d") > LAST_DAY:
                    logging.warning(f"skip {study_date}")
                    continue
                study_dir = os.path.join(patient_dir, study)
                if not os.path.isdir(study_dir):
                    continue
                best_series = check_study(study_dir)
                if not best_series:
                    continue
                dst_dir = os.path.join(dst_patient_dir, study)
                anonymize_series_to_nifti(study_dir, best_series['files'], dst_dir)
                num_study += 1
            # Mark this patient done only when at least one study converted.
            if num_study > 0:
                with open(complete_file, 'w') as f:
                    f.write('done')
            num_patient += 1
            if num_patient >= MAX_PATIENT:
                break
        # break
    logging.warning(f"BodyPartExamined: {BodyPartExamined}")
    with open(NII_JSON_PATH, 'w') as f:
        json.dump(NII_DICT, f, indent=1)
2026-03-07 11:32:28 +00:00
def list_4d():
    """Offline report: classify cached contrast-enhanced series as 3D vs 4D.

    Replays the brain-selection logic of check_study() over every study cached
    in study_db (no filesystem scanning), keeps only axial contrast series that
    are NOT t1/t2/flair, and writes their SeriesDescriptions to plus_3d.json /
    plus_4d.json. Exits the process when done.
    """
    PLUS_3D = set()
    PLUS_4D = set()
    for study_id, study in study_db.items():
        # -- Same brain-selection logic as check_study() --
        brain_list = []
        body_parts = set()
        s = None
        for uid, s in study.items():
            if 'BodyPartExamined' in s['FileDataset']:
                if s['FileDataset'].BodyPartExamined in BodyPartIncluded:
                    brain_list.append(s)
                else:
                    body_parts.add(s['FileDataset'].BodyPartExamined)
        if not brain_list:
            if body_parts:
                logging.warning(f"no brain, BodyPartExamined: {body_parts}")
                continue
            else:
                if not s:
                    logging.warning(f"no series found")
                    continue
                logging.warning(f"BodyPartExamined is empty")
                if 'brain' in s['FileDataset'].StudyDescription.lower():
                    logging.warning(f"brain in {s['FileDataset'].StudyDescription}, adding all series")
                    brain_list = list(study.values())
                else:
                    logging.warning(f"no brain in {s['FileDataset'].StudyDescription}")
                    continue
        for s in brain_list:
            sd = s['FileDataset'].SeriesDescription.lower()
            # Contrast-enhanced series only.
            if not ('+' in sd or 'gd' in sd):
                continue
            # NOTE: unlike check_study(), t1 series are EXCLUDED here — this
            # report covers the remaining contrast series.
            if 't1' in sd:
                continue
            if 't2' in sd:
                continue
            if 'flair' in sd:
                continue
            # Dominant rounded orientation across the series' slices.
            c = collections.Counter(s['orientations'])
            orientation_str = c.most_common(1)[0][0]
            if not is_axial(orientation_str):
                continue
            if s['is_4d']:
                logging.warning(f"4D series found: {s['FileDataset'].SeriesDescription}")
                PLUS_4D.add(s['FileDataset'].SeriesDescription)
            else:
                PLUS_3D.add(s['FileDataset'].SeriesDescription)
    with open("plus_4d.json", 'w') as f:
        json.dump(sorted(PLUS_4D), f, indent=1)
    with open("plus_3d.json", 'w') as f:
        json.dump(sorted(PLUS_3D), f, indent=1)
    # Terminate so main() is not run after this report.
    exit()
if __name__ == '__main__':
    # Uncomment to produce the 3D/4D report from the cache instead (exits).
    # list_4d()
    main()