# File: tsha-mri-tumor-labeling/kimo/3-inference.py
# (copied from a repository web view: 92 lines, 3.3 KiB, Python)

# done with GWJU7LPC-20201007
import logging
import os
import shutil
from monai.apps.nnunet import nnUNetV2Runner
import yaml
# nnU-Net dataset folder name; joined under the raw and results roots read from YAML.
DATASET: str = "Dataset2602_BraTS-CK"
# nnUNetV2Runner configuration file, opened relative to the current working directory.
YAML: str = "2602.yaml"
# Directory the post-processed ensemble predictions are finally copied into.
DST_DIR: str = "/mnt/t24/Public/kimo/raw/ensemble_predictions_postprocessed"
def copytree(src, dst):
    """Copy every file found anywhere under *src* directly into *dst*.

    Unlike ``shutil.copytree`` the directory structure is flattened: each file
    is written to ``dst/<basename>``. Files sharing a base name in different
    sub-directories therefore overwrite each other, and the one walked last wins.
    Only file contents are copied (``shutil.copyfile``), not permissions or
    timestamps. *dst* is created if it does not exist. ``os.walk`` ignores
    errors by default, so a missing *src* simply copies nothing.
    """
    os.makedirs(dst, exist_ok=True)
    walked = (
        (folder, name)
        for folder, _subdirs, names in os.walk(src)
        for name in names
    )
    for folder, name in walked:
        shutil.copyfile(os.path.join(folder, name), os.path.join(dst, name))
# Intermediate prediction folders under <nnunet_results>/<DATASET>/ that are
# deleted before every run so stale outputs cannot leak into the ensemble.
_INTERMEDIATE_PREDICTION_DIRS = (
    "pred_2d",
    "pred_3d_fullres",
    "pred_3d_lowres",
    "ensemble_predictions",
)

_NIFTI_SUFFIX = ".nii.gz"


def _archive_processed_inputs(images_ts, images_ts_done, postprocessed_dir):
    """Move input channels of already-predicted cases out of ``imagesTs``.

    A case counts as processed when ``<case>.nii.gz`` exists in
    *postprocessed_dir*. Its input files, named ``<case>_XXXX.nii.gz`` where
    XXXX is a 4-digit channel index, are moved to *images_ts_done*. The next
    prediction run then only sees unprocessed cases. Any number of channels
    is handled.

    Returns the number of files moved. *images_ts_done* is always created.
    """
    os.makedirs(images_ts_done, exist_ok=True)
    if not (os.path.isdir(postprocessed_dir) and os.path.isdir(images_ts)):
        return 0

    cut = len(_NIFTI_SUFFIX)
    processed_ids = {
        name[:-cut]
        for name in os.listdir(postprocessed_dir)
        if name.endswith(_NIFTI_SUFFIX)
    }
    if not processed_ids:
        return 0

    moved = 0
    for name in sorted(os.listdir(images_ts)):
        if not name.endswith(_NIFTI_SUFFIX):
            continue
        # rpartition keeps underscores that belong to the case id itself.
        case_id, sep, channel = name[:-cut].rpartition("_")
        is_channel_file = (
            sep and len(channel) == 4 and channel.isascii() and channel.isdigit()
        )
        if is_channel_file and case_id in processed_ids:
            shutil.move(
                os.path.join(images_ts, name),
                os.path.join(images_ts_done, name),
            )
            moved += 1
    return moved


def main():
    """Run ensemble prediction with post-processing and publish the results.

    1. Log to stderr and to a ``.log`` file next to this script.
    2. Read ``nnunet_raw`` / ``nnunet_results`` from the ``YAML`` config.
    3. Move inputs of cases that already have a post-processed prediction
       from ``imagesTs`` to ``imagesTs_done``, so an interrupted run resumes.
    4. Delete intermediate prediction folders left by a previous run.
    5. Call ``nnUNetV2Runner.predict_ensemble_postprocessing``.
    6. Copy the post-processed predictions (flattened) into ``DST_DIR``.

    Exits the process with status 1 if prediction raises.
    """
    log_path = os.path.splitext(__file__)[0] + ".log"
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(filename)s:%(lineno)d] %(message)s',
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler(log_path, encoding='utf-8'),
        ],
    )

    with open(YAML, "r") as f:
        config = yaml.safe_load(f)
    raw_dir = config["nnunet_raw"]
    results_dir = config["nnunet_results"]

    images_ts = os.path.join(raw_dir, DATASET, "imagesTs")
    images_ts_done = os.path.join(raw_dir, DATASET, "imagesTs_done")
    results_dataset_dir = os.path.join(results_dir, DATASET)
    postprocessed_dir = os.path.join(
        results_dataset_dir, "ensemble_predictions_postprocessed"
    )

    moved = _archive_processed_inputs(images_ts, images_ts_done, postprocessed_dir)
    logging.info("moved %d already-predicted input file(s) to %s", moved, images_ts_done)

    # The post-processed folder is deliberately NOT removed: earlier results
    # accumulate there and drive the resume logic above.
    for sub in _INTERMEDIATE_PREDICTION_DIRS:
        shutil.rmtree(os.path.join(results_dataset_dir, sub), ignore_errors=True)

    runner = nnUNetV2Runner(YAML)
    try:
        runner.predict_ensemble_postprocessing(
            # num_processes_preprocessing=1,
            # num_processes_segmentation_export=1
        )
    except Exception as e:
        logging.exception(e)
        # Non-zero status so shells/schedulers see the failure; the bare
        # site-module exit() used before reported success (status 0).
        raise SystemExit(1)

    copytree(postprocessed_dir, DST_DIR)


if __name__ == "__main__":
    main()