# Source: tsha-mri-tumor-labeling/kimo/2-check-series.py
# (repository viewer metadata: 462 lines, 15 KiB, Python)

from pathlib import Path
import base64
import collections
import datetime
import hashlib
import json
import logging
import os
import shelve
import shutil
import tempfile
from lmdbm import Lmdb
from tqdm import tqdm
import dicom2nifti
import pydicom
# Short patient hashes (second value returned by hashptid) that main() skips.
EXCLUDED_HASH = [
    # "GWJU7LPC",
]
# When non-empty, main() processes ONLY the patients whose hash is listed.
INCLUDED_HASH = [
    # "GWJU7LPC",
]
# Studies whose folder date is later than this day are skipped by main().
LAST_DAY = datetime.datetime.strptime("2025-11-01", "%Y-%m-%d")
# Upper bound on patients handled per pathology folder; lower it (e.g. to 3)
# for a quick trial run.
MAX_PATIENT = 1000
SRC_ROOT = "/mnt/t24/Public/kimo/TSHA"
RAW_DIR = "/mnt/t24/Public/kimo/raw/"
DST_ROOT = os.path.join(RAW_DIR, "DICOM")
imagesTs_DIR = os.path.join(RAW_DIR, "Dataset2602_BraTS-CK/imagesTs/")
NII_JSON_PATH = os.path.join(RAW_DIR, 'nii.json')
# NII_DICT maps "<patient hash>-<StudyDate>" to the anonymized DICOM folder
# (relative to DST_ROOT). Entries from a previous run are loaded so that
# main() extends the mapping instead of starting from scratch.
if os.path.isfile(NII_JSON_PATH):
    with open(NII_JSON_PATH, 'r', encoding='utf-8') as f:
        NII_DICT = json.load(f)
else:
    NII_DICT = {}
# BodyPartExamined values observed in the source data:
# {'', 'PELVISLOWEXTREM', 'BRAIN', 'CSPINE', 'KNEE', 'TSPINE', 'CAROTID', 'NECK', 'ABDOMEN', 'ORBIT', 'HEAD', 'CHEST', 'IAC', 'WHOLEBODY', 'WHOLESPINE', 'ABDOMENPELVIS', 'PELVIS', 'LSPINE', 'SPINE', 'CIRCLEOFWILLIS'}
# BodyPartExamined: Counter({'BRAIN': 152087, 'ABDOMEN': 14101, 'HEAD': 11806, 'ABDOMENPELVIS': 10905, 'SPINE': 9277, 'CHEST': 3746, 'PELVIS': 3208, 'NECK': 3205, 'CSPINE': 1527, 'CAROTID': 1186,
# 'HEART': 1122, 'LSPINE': 1080, 'KNEE': 591, 'PELVISLOWEXTREM': 496, '': 385, 'ORBIT': 360, 'CIRCLEOFWILLIS': 322, 'HUMERUS': 320, 'ARM': 304, 'IAC': 291,
# 'EXTREMITY': 287, 'SHOULDER': 242, 'WHOLEBODY': 190, 'TSPINE': 150, 'HEADNECK': 48, 'WHOLESPINE': 45})
# Running tally of BodyPartExamined values, filled while studies are scanned.
BodyPartExamined = collections.Counter()
# BodyPartExamined values that qualify a series as a head/brain acquisition.
BodyPartIncluded = {
    'BRAIN',
    'BRAIN_CE',
    'CIRCLEOFWILLIS',
    'HEAD',
    'IAC',
    # 'ORBIT',
}
# Path of the LMDB cache holding per-study scan results.
STUDY_DB_PATH = 'study.db'
class JsonLmdb(Lmdb):
    """Lmdb mapping with UTF-8 string keys and JSON-encoded values.

    pydicom datasets found while serializing a value are written in the form
    produced by ``Dataset.to_json``. When a stored dict is read back, the
    ``'FileDataset'`` entry of each nested dict is converted to a
    ``pydicom.Dataset`` again.
    """

    def _pre_key(self, value):
        """Encode a text key as UTF-8 bytes."""
        return value.encode("utf-8")

    def _post_key(self, value):
        """Decode a stored key from UTF-8 bytes to text."""
        return value.decode("utf-8")

    def _pre_value(self, value):
        """Serialize *value* to UTF-8 encoded JSON bytes."""

        def _encode_dataset(obj):
            # json.dumps only calls this for objects it cannot serialize itself.
            if not isinstance(obj, (pydicom.dataset.FileDataset, pydicom.Dataset)):
                raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
            return json.loads(obj.to_json(suppress_invalid_tags=True))

        text = json.dumps(value, default=_encode_dataset)
        return text.encode("utf-8")

    def _post_value(self, value):
        """Parse stored JSON bytes and rebuild nested 'FileDataset' entries."""
        data = json.loads(value.decode("utf-8"))
        if not isinstance(data, dict):
            return data
        for entry in data.values():
            if not (isinstance(entry, dict) and 'FileDataset' in entry):
                continue
            raw = entry['FileDataset']
            if isinstance(raw, (dict, str)):
                entry['FileDataset'] = pydicom.Dataset.from_json(raw)
        return data
# Module-level handle to the per-study scan cache used by check_study() and
# list_4d(). It is opened as a side effect of importing this module, with
# flag "c" (presumably dbm-style "create if missing" -- confirm against lmdbm).
# An earlier version used shelve instead: shelve.open(STUDY_DB_PATH, 'c').
# NOTE(review): no explicit close of this handle is visible in this file.
study_db = JsonLmdb.open(STUDY_DB_PATH, "c")
def is_axial(s):
    """Tell whether a rounded orientation string describes an axial slice.

    *s* is a space-separated string as produced by int_orientation(). The
    slice counts as axial when the components at indices 1, 2, 3 and 5 are
    all the string '0'.
    """
    parts = s.split(' ')
    # Checked lazily in index order, so a failing component short-circuits
    # before any later index is read.
    return all(parts[index] == '0' for index in (1, 2, 3, 5))
def int_orientation(o):
    """Round every component of *o* and join them into one space-separated string."""
    return ' '.join(str(round(component)) for component in o)
def check_4d_series(ds_list):
    """Return True when any slice position occurs more than once in a series.

    Args:
        ds_list: datasets of one series (pixel data is not needed). Each must
            support ``"ImagePositionPatient" in ds`` and attribute access.

    Returns:
        True if at least two datasets share the same ImagePositionPatient,
        which is treated as a sign of a multi-frame-per-position (4D) series;
        False otherwise, including for an empty list.

    A dataset that cannot be processed is logged and skipped; it never aborts
    the check.
    """
    # Slices grouped by their 3D position (ImagePositionPatient, tag 0020,0032).
    spatial_groups = collections.defaultdict(list)
    last_ds = None
    for ds in ds_list:
        last_ds = ds
        try:
            if "ImagePositionPatient" not in ds:
                continue
            pos = tuple(ds.ImagePositionPatient)
            # Temporal tags are kept alongside the position for inspection.
            spatial_groups[pos].append({
                "AcquisitionTime": getattr(ds, "AcquisitionTime", "N/A"),
                "ContentTime": getattr(ds, "ContentTime", "N/A"),
                "InstanceNumber": getattr(ds, "InstanceNumber", "N/A"),
            })
        except Exception as e:
            # Only the dataset is available here, not its path, so report the
            # dataset's own filename when it carries one.
            source = getattr(ds, "filename", None) or "<dataset>"
            logging.error(f"Error reading {source}: {e}")

    for pos, slices in spatial_groups.items():
        if len(slices) > 1:
            description = getattr(last_ds, "SeriesDescription", "")
            logging.warning(f"{description} Detected 4D: Position {pos} has {len(slices)} temporal frames.")
            return True
    return False
# Description fragments that rule a series out as T1 post-contrast unless the
# description also contains 't1' ('perf' marks perfusion series).
_NON_T1_MARKERS = ('flair', 't2', 'tof', 'dwi', 'ep2d', 'perf')


def _dicom_sort_key(filename):
    """Order files by the integer after the last '_'; non-numeric names go last."""
    token = filename.split("_")[-1].split(".")[0]
    try:
        return (0, int(token), filename)
    except ValueError:
        return (1, 0, filename)


def _series_text(series, keyword):
    """Return attribute *keyword* of the series' dataset as text ('' when absent)."""
    value = getattr(series['FileDataset'], keyword, None)
    return '' if value is None else str(value)


def _series_number(series):
    """Return the series' SeriesNumber as an int (0 when absent or not numeric)."""
    try:
        return int(getattr(series['FileDataset'], 'SeriesNumber', None))
    except (TypeError, ValueError):
        return 0


def _dominant_orientation(series):
    """Return the most frequent rounded orientation string among the series' slices."""
    return collections.Counter(series['orientations']).most_common(1)[0][0]


def _scan_study(study_dir):
    """Read every .dcm header below *study_dir* and group the slices by series.

    Returns a dict keyed by SeriesInstanceUID. Each value holds the first
    dataset of the series ('FileDataset'), its path ('1st_file'), the
    per-slice rounded orientations, the file paths relative to *study_dir*
    and the 'is_4d' flag. Datasets without ImageOrientationPatient are ignored.
    """
    dcm_paths = []
    for root, _dirs, files in os.walk(study_dir):
        # Filter first: the numeric sort key only has to make sense for .dcm files.
        dcm_names = [name for name in files if name.endswith(".dcm")]
        dcm_paths.extend(
            os.path.join(root, name)
            for name in sorted(dcm_names, key=_dicom_sort_key)
        )

    study = {}
    series_datasets = {}
    for dcm_file in tqdm(dcm_paths):
        ds = pydicom.dcmread(dcm_file, force=True, stop_before_pixels=True)
        if 'BodyPartExamined' in ds:
            BodyPartExamined[ds.BodyPartExamined] += 1
        if 'ImageOrientationPatient' not in ds:
            continue
        uid = ds.SeriesInstanceUID
        if uid not in study:
            study[uid] = {
                'FileDataset': ds,
                '1st_file': dcm_file,
                'orientations': [],
                'files': [],
            }
            series_datasets[uid] = []
        study[uid]['files'].append(os.path.relpath(dcm_file, study_dir))
        study[uid]['orientations'].append(int_orientation(ds.ImageOrientationPatient))
        series_datasets[uid].append(ds)

    for uid, ds_list in series_datasets.items():
        study[uid]['is_4d'] = check_4d_series(ds_list)
    return study


def _select_brain_series(study):
    """Return the series considered head/brain acquisitions, or None.

    Series whose BodyPartExamined is in BodyPartIncluded are selected. When
    no series carries BodyPartExamined at all, every series is selected if
    the StudyDescription of the last series contains 'brain'.
    """
    brain_list = []
    body_parts = set()
    last_series = None
    for last_series in study.values():
        dataset = last_series['FileDataset']
        if 'BodyPartExamined' not in dataset:
            continue
        if dataset.BodyPartExamined in BodyPartIncluded:
            brain_list.append(last_series)
        else:
            # NOTE(review): an empty-string BodyPartExamined lands here too,
            # which bypasses the StudyDescription fallback below -- confirm
            # that this is intended.
            body_parts.add(dataset.BodyPartExamined)
    if brain_list:
        return brain_list
    if body_parts:
        logging.warning(f"no brain, BodyPartExamined: {body_parts}")
        return None
    if not last_series:
        logging.warning("no series found")
        return None
    logging.warning("BodyPartExamined is empty")
    study_description = _series_text(last_series, 'StudyDescription')
    if 'brain' in study_description.lower():
        logging.warning(f"brain in {study_description}, adding all series")
        return list(study.values())
    logging.warning(f"no brain in {study_description}")
    return None


def _is_t1c_candidate(series):
    """Tell whether a series description looks like a 3D T1 post-contrast scan."""
    description = _series_text(series, 'SeriesDescription')
    sd = description.lower()
    # Contrast is recognised by '+' or 'gd' in the description.
    if '+' not in sd and 'gd' not in sd:
        return False
    if series['is_4d']:
        logging.warning(f"4D series found: {description}")
        return False
    if 't1' not in sd and any(marker in sd for marker in _NON_T1_MARKERS):
        return False
    return True


def check_study(study_dir):
    """Pick the best axial T1 post-contrast brain series of one study.

    The per-series scan result is cached in ``study_db`` under a key built
    from the last two path components of *study_dir*; a cached result, even
    an empty one, is reused instead of re-reading the DICOM headers.

    Args:
        study_dir: directory holding the study's .dcm files (searched
            recursively).

    Returns:
        The series dict (keys 'FileDataset', '1st_file', 'orientations',
        'files', 'is_4d') with the most files among the axial candidates,
        ties going to the lower SeriesNumber, or None when the study has no
        usable series.
    """
    key = '_'.join(Path(study_dir).parts[-2:])
    study = study_db.get(key, None)
    if study is not None:
        logging.warning(f"use cached series for {study_dir}")
    else:
        logging.warning(f"scan series for {study_dir}")
        study = _scan_study(study_dir)
        study_db[key] = study

    brain_list = _select_brain_series(study)
    if not brain_list:
        return None

    t1c = [s for s in brain_list if _is_t1c_candidate(s)]
    if not t1c:
        logging.warning(f"no t1c in {_series_text(brain_list[-1], 'StudyDescription')}")
        for s in brain_list:
            logging.warning(f"{_series_text(s, 'SeriesNumber')} {_series_text(s, 'SeriesDescription')} {len(s['files'])}")
        return None

    t1c_axial = []
    for s in t1c:
        orientation_str = _dominant_orientation(s)
        if is_axial(orientation_str):
            logging.warning(f"--- {_series_text(s, 'SeriesNumber')} {_series_text(s, 'SeriesDescription')} {orientation_str} {len(s['files'])}")
            t1c_axial.append(s)
    if not t1c_axial:
        logging.warning(f"no axial t1c in {study_dir}")
        for s in t1c:
            # Report each series with its own orientation.
            logging.warning(f"{_series_text(s, 'SeriesNumber')} {_series_text(s, 'SeriesDescription')} {_dominant_orientation(s)} {len(s['files'])} {_series_text(s, 'StudyDescription')}")
        return None

    best_series = max(t1c_axial, key=lambda x: (len(x['files']), -_series_number(x)))
    logging.warning(f"{_series_text(best_series, 'SeriesNumber')} {_series_text(best_series, 'SeriesDescription')} {_dominant_orientation(best_series)} {len(best_series['files'])}")
    return best_series
def hashptid(mrn, hosp='NTUH'):
    """Derive a pseudonymous patient identifier from a record number.

    The upper-cased concatenation of *mrn* and *hosp* is hashed with MD5.

    Returns:
        A tuple ``(md5_hex, short_code)``: the hex digest, and the first
        8 characters of the Base32 encoding of the raw digest.
    """
    salted = (mrn + hosp).upper().encode()
    digest = hashlib.md5(salted)
    short_code = base64.b32encode(digest.digest())[:8].decode()
    return digest.hexdigest(), short_code
def anonymize_series_to_nifti(study_dir, series_files, dst_dir):
    """Anonymize one DICOM series, convert it to NIfTI and register the result.

    Every file is re-saved in *dst_dir* with all group 0x0010 elements
    blanked and PatientID replaced by the short hash from hashptid(). The
    anonymized folder is then converted with dicom2nifti, each resulting
    ``.nii.gz`` is copied to ``imagesTs_DIR`` as
    ``<hash>-<StudyDate>_0000.nii.gz`` and the mapping is recorded in
    ``NII_DICT``.

    Args:
        study_dir: directory that the entries of *series_files* are relative to.
        series_files: relative paths of the DICOM files of one series.
        dst_dir: output directory for the anonymized DICOM files.

    Returns:
        None. With an empty *series_files* nothing is created or converted.
    """
    if not series_files:
        logging.warning(f"no files to anonymize for {dst_dir}")
        return

    os.makedirs(dst_dir, exist_ok=True)
    patient_hash = None
    last_ds = None
    for rel_path in series_files:
        ds = pydicom.dcmread(os.path.join(study_dir, rel_path))
        # The hash must be taken before PatientID is blanked below.
        _md5, patient_hash = hashptid(ds.PatientID)
        for elem in ds:
            if elem.tag.group == 0x0010:
                elem.value = ''
        ds.PatientID = patient_hash
        # Output name: the part of the path after its last '_'.
        dst_file = os.path.join(dst_dir, os.path.basename(rel_path.split("_")[-1]))
        ds.save_as(dst_file)
        last_ds = ds

    with tempfile.TemporaryDirectory() as tmpdirname:
        dicom2nifti.convert_directory(dst_dir, tmpdirname, compression=True)
        for entry in os.scandir(tmpdirname):
            if not (entry.is_file() and entry.name.endswith(".nii.gz")):
                continue
            # Named after the last file written; all files belong to one series.
            stem = f"{patient_hash}-{last_ds.StudyDate}"
            nii_path = os.path.join(imagesTs_DIR, f"{stem}_0000.nii.gz")
            logging.warning(f"copying to {nii_path}")
            shutil.copyfile(entry.path, nii_path)
            NII_DICT[stem] = os.path.relpath(dst_dir, DST_ROOT)
def _export_patient_studies(patient_dir, dst_patient_dir):
    """Export the best series of every eligible study of one patient.

    Study folders are visited in reverse sorted name order. A folder is
    skipped when its name does not start with a ``YYYYMMDD`` date, when that
    date is after LAST_DAY, or when check_study() finds no usable series.

    Returns:
        The number of studies that were exported.
    """
    num_study = 0
    for study in sorted(os.listdir(patient_dir), reverse=True):
        study_dir = os.path.join(patient_dir, study)
        if not os.path.isdir(study_dir):
            continue
        study_date = study.split('_')[0]
        try:
            parsed_date = datetime.datetime.strptime(study_date, "%Y%m%d")
        except ValueError:
            logging.warning(f"skip {study_dir}: folder name does not start with a YYYYMMDD date")
            continue
        if parsed_date > LAST_DAY:
            logging.warning(f"skip {study_date}")
            continue
        best_series = check_study(study_dir)
        if not best_series:
            continue
        dst_dir = os.path.join(dst_patient_dir, study)
        anonymize_series_to_nifti(study_dir, best_series['files'], dst_dir)
        num_study += 1
    return num_study


def main():
    """Walk SRC_ROOT/<pathology>/<patient>/<study> and export selected series.

    Only patient folders that have a sibling ``<patient>.complete`` marker
    are processed. Patients already marked ``<hash>.complete`` under DST_ROOT
    are skipped. After a patient yields at least one exported study its
    marker is written. At most MAX_PATIENT patients are handled per
    pathology folder. NII_DICT is written to NII_JSON_PATH on exit, also
    when the run is aborted by an exception, so that already exported
    patients keep their mapping.
    """
    FORMAT = '%(asctime)s [%(filename)s:%(lineno)d] %(message)s'
    logging.basicConfig(
        level=logging.WARNING,
        format=FORMAT,
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler(__file__.replace('.py', '.log'), encoding='utf-8'),
        ],
    )
    os.makedirs(imagesTs_DIR, exist_ok=True)
    try:
        for patho in sorted(os.listdir(SRC_ROOT)):
            patho_dir = os.path.join(SRC_ROOT, patho)
            if not os.path.isdir(patho_dir):
                continue
            num_patient = 0
            for patient in sorted(os.listdir(patho_dir)):
                _md5, patient_hash = hashptid(patient)
                if INCLUDED_HASH and patient_hash not in INCLUDED_HASH:
                    continue
                if patient_hash in EXCLUDED_HASH:
                    continue
                patient_dir = os.path.join(patho_dir, patient)
                if not os.path.isdir(patient_dir):
                    continue
                # Source side marker: <patho_dir>/<patient>.complete.
                if not os.path.isfile(patient_dir + '.complete'):
                    continue
                dst_patient_dir = os.path.join(DST_ROOT, patho, patient_hash)
                complete_file = os.path.join(DST_ROOT, patho, f'{patient_hash}.complete')
                if os.path.exists(complete_file):
                    logging.warning(f"skip {patient_dir}")
                    continue
                num_study = _export_patient_studies(patient_dir, dst_patient_dir)
                if num_study > 0:
                    with open(complete_file, 'w', encoding='utf-8') as f:
                        f.write('done')
                # NOTE(review): the original indentation was lost; this counts
                # every processed patient, not only those with an exported
                # study -- confirm.
                num_patient += 1
                if num_patient >= MAX_PATIENT:
                    break
        logging.warning(f"BodyPartExamined: {BodyPartExamined}")
    finally:
        with open(NII_JSON_PATH, 'w', encoding='utf-8') as f:
            json.dump(NII_DICT, f, indent=1)
def list_4d():
    """Survey cached studies for axial contrast series that are not T1/T2/FLAIR.

    For every study in ``study_db`` the brain series are selected the same
    way as in check_study(). The SeriesDescription of each axial series whose
    description contains '+' or 'gd' but none of 't1', 't2', 'flair' is
    collected. The sorted descriptions are written to ``plus_4d.json``
    (series flagged 'is_4d') and ``plus_3d.json`` (all others).

    Raises:
        SystemExit: always, once both files are written, so the script stops
            here when this function is used as the entry point.
    """
    PLUS_3D = set()
    PLUS_4D = set()
    for study_id, study in study_db.items():
        brain_list = []
        body_parts = set()
        s = None
        for uid, s in study.items():
            if 'BodyPartExamined' in s['FileDataset']:
                if s['FileDataset'].BodyPartExamined in BodyPartIncluded:
                    brain_list.append(s)
                else:
                    body_parts.add(s['FileDataset'].BodyPartExamined)
        if not brain_list:
            if body_parts:
                logging.warning(f"no brain, BodyPartExamined: {body_parts}")
                continue
            if not s:
                logging.warning("no series found")
                continue
            logging.warning("BodyPartExamined is empty")
            if 'brain' in s['FileDataset'].StudyDescription.lower():
                logging.warning(f"brain in {s['FileDataset'].StudyDescription}, adding all series")
                brain_list = list(study.values())
            else:
                logging.warning(f"no brain in {s['FileDataset'].StudyDescription}")
                continue
        for s in brain_list:
            sd = s['FileDataset'].SeriesDescription.lower()
            if not ('+' in sd or 'gd' in sd):
                continue
            if any(marker in sd for marker in ('t1', 't2', 'flair')):
                continue
            orientation_str = collections.Counter(s['orientations']).most_common(1)[0][0]
            if not is_axial(orientation_str):
                continue
            if s['is_4d']:
                logging.warning(f"4D series found: {s['FileDataset'].SeriesDescription}")
                PLUS_4D.add(s['FileDataset'].SeriesDescription)
            else:
                PLUS_3D.add(s['FileDataset'].SeriesDescription)
    with open("plus_4d.json", 'w', encoding='utf-8') as f:
        json.dump(sorted(PLUS_4D), f, indent=1)
    with open("plus_3d.json", 'w', encoding='utf-8') as f:
        json.dump(sorted(PLUS_3D), f, indent=1)
    # Same effect as the site helper exit(), which only exists when the
    # site module is loaded and is meant for interactive sessions.
    raise SystemExit
if __name__ == "__main__":
    # Default entry point: run the export pipeline. Call list_4d() here
    # instead to dump the 3D/4D series-description survey.
    main()