91 lines
3.3 KiB
Python
91 lines
3.3 KiB
Python
# done with GWJU7LPC-20201007
|
|
|
|
import logging
|
|
import os
|
|
import shutil
|
|
|
|
from monai.apps.nnunet import nnUNetV2Runner
|
|
|
|
import yaml
|
|
|
|
# nnU-Net dataset folder name; joined under the raw and results roots read
# from the YAML config (e.g. <nnunet_raw>/<DATASET>/imagesTs).
DATASET: str = "Dataset2602_BraTS-CK"

# Runner configuration file: parsed by main() and handed to nnUNetV2Runner.
YAML: str = "2602.yaml"

# Destination directory that receives a copy of the postprocessed ensemble
# predictions at the end of main().
DST_DIR: str = "/mnt/t24/Public/kimo/raw/ensemble_predictions_postprocessed"
|
|
|
|
def copytree(src, dst, *, flatten=True):
    """Copy every file found below *src* into *dst*.

    Unlike ``shutil.copytree`` this function, by default, FLATTENS the tree:
    each file found anywhere under *src* is written directly into *dst*.
    Only file contents are copied (``shutil.copyfile``), not permissions or
    timestamps. Existing files in *dst* with the same name are overwritten.

    Args:
        src: Directory to copy from. If it is not an existing directory,
            nothing is copied and a warning is logged (previously this was
            silent because ``os.walk`` simply yields nothing).
        dst: Directory to copy into; created (with parents) if missing.
        flatten: When True (default, original behaviour) all files land
            directly in *dst*, so same-named files from different
            subdirectories overwrite each other. When False the relative
            sub-directory layout of *src* is recreated under *dst*.
    """
    # Created up front, even when src is missing, to match the old behaviour.
    os.makedirs(dst, exist_ok=True)

    if not os.path.isdir(src):
        logging.warning("copytree: source directory %s does not exist; nothing copied", src)
        return

    for root, _dirs, files in os.walk(src):
        if flatten:
            target_dir = dst
        else:
            # relpath(src, src) is '.', so top-level files go straight into dst.
            target_dir = os.path.join(dst, os.path.relpath(root, src))
            os.makedirs(target_dir, exist_ok=True)
        for name in files:
            shutil.copyfile(os.path.join(root, name), os.path.join(target_dir, name))
|
|
|
|
|
|
def _setup_logging():
    """Send INFO-and-above log records to stderr and to a ``.log`` file next to this script."""
    fmt = '%(asctime)s [%(filename)s:%(lineno)d] %(message)s'
    # splitext swaps only the final extension. The previous
    # __file__.replace('.py', '.log') rewrote every '.py' in the path and, for
    # a script without a '.py' suffix, pointed the FileHandler at the script
    # itself.
    log_path = os.path.splitext(__file__)[0] + '.log'
    logging.basicConfig(
        level=logging.INFO,
        format=fmt,
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler(log_path, encoding='utf-8'),
        ],
    )


def _archive_processed_inputs(images_ts, images_ts_done, postprocessed_dir, max_modalities=10):
    """Move input images of cases that already have a postprocessed prediction.

    Every ``<case_id>.nii.gz`` in *postprocessed_dir* marks a finished case.
    Its inputs ``<case_id>_0000.nii.gz`` ... (up to *max_modalities* channels)
    are moved from *images_ts* to *images_ts_done*, so the next prediction
    run only sees unfinished cases. *images_ts_done* is always created.
    """
    os.makedirs(images_ts_done, exist_ok=True)
    if not os.path.isdir(postprocessed_dir):
        return

    suffix = '.nii.gz'
    # Slice off only the trailing suffix; str.replace would strip every occurrence.
    case_ids = [name[:-len(suffix)] for name in os.listdir(postprocessed_dir) if name.endswith(suffix)]
    for case_id in case_ids:
        for mod in range(max_modalities):
            mod_file = f"{case_id}_{mod:04d}{suffix}"
            src_path = os.path.join(images_ts, mod_file)
            if os.path.exists(src_path):
                shutil.move(src_path, os.path.join(images_ts_done, mod_file))


def _clean_intermediate_predictions(dataset_results_dir):
    """Delete per-configuration and raw ensemble outputs left by a previous run.

    ``ensemble_predictions_postprocessed`` is deliberately NOT removed: its
    files decide which cases _archive_processed_inputs skips.
    """
    for sub in ('pred_2d', 'pred_3d_fullres', 'pred_3d_lowres', 'ensemble_predictions'):
        shutil.rmtree(os.path.join(dataset_results_dir, sub), ignore_errors=True)


def main():
    """Run nnU-Net v2 ensemble prediction with postprocessing for DATASET.

    Steps:
      1. configure logging (stderr + ``<script>.log``);
      2. read the ``nnunet_raw`` / ``nnunet_results`` roots from the YAML file;
      3. move inputs of already-postprocessed cases out of ``imagesTs``;
      4. delete intermediate prediction folders from a previous run;
      5. run ``nnUNetV2Runner.predict_ensemble_postprocessing``;
      6. copy the postprocessed predictions (flattened) into DST_DIR.

    Raises:
        SystemExit: with status 1 if prediction fails; the traceback is logged.
    """
    _setup_logging()

    with open(YAML, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    raw_dir = config["nnunet_raw"]
    results_dir = config["nnunet_results"]
    dataset_results_dir = os.path.join(results_dir, DATASET)

    images_ts = os.path.join(raw_dir, DATASET, "imagesTs")
    images_ts_done = os.path.join(raw_dir, DATASET, "imagesTs_done")
    postprocessed_dir = os.path.join(dataset_results_dir, "ensemble_predictions_postprocessed")

    _archive_processed_inputs(images_ts, images_ts_done, postprocessed_dir)
    _clean_intermediate_predictions(dataset_results_dir)

    runner = nnUNetV2Runner(YAML)
    try:
        # num_processes_preprocessing=1 / num_processes_segmentation_export=1
        # may be passed here (presumably to limit worker processes).
        runner.predict_ensemble_postprocessing()
    except Exception as e:
        logging.exception("predict_ensemble_postprocessing failed: %s", e)
        # Non-zero status so wrapping shell scripts / schedulers see the
        # failure; the builtin exit() used before returned status 0 and
        # exists only when the `site` module is loaded.
        raise SystemExit(1)

    copytree(postprocessed_dir, DST_DIR)
    # Intermediate prediction folders are left on disk here; the next run
    # removes them in _clean_intermediate_predictions before predicting.
|
|
|
|
# Script entry point: run the pipeline only when executed directly, not on import.
if __name__ == "__main__":
    main()
|