# done with GWJU7LPC-20201007
"""Run MONAI's nnU-Net v2 ensemble prediction + postprocessing and export it.

Steps performed by :func:`main`:

1. Delete the prediction folders listed in ``_STALE_OUTPUT_DIRS`` under
   ``<nnunet_results>/<DATASET>/`` (``nnunet_results`` is read from ``YAML``).
2. Call ``nnUNetV2Runner(YAML).predict_ensemble_postprocessing()``.
3. Copy every file under
   ``<nnunet_results>/<DATASET>/ensemble_predictions_postprocessed`` into
   ``DST_DIR`` (flattened -- see :func:`copytree`).
"""
import logging
import os
import shutil
import sys

import yaml
from monai.apps.nnunet import nnUNetV2Runner

DATASET = "Dataset2602_BraTS-CK"
YAML = "2602.yaml"
DST_DIR = "/mnt/t24/Public/kimo/raw/ensemble_predictions_postprocessed"

# Folders (relative to <nnunet_results>/<DATASET>) deleted before each run.
# Presumably these are the folders nnUNetV2Runner writes into, so clearing
# them keeps results from an earlier run out of this one -- confirm against
# the runner if the list changes.
_STALE_OUTPUT_DIRS = (
    "pred_2d",
    "pred_3d_fullres",
    "pred_3d_lowres",
    "ensemble_predictions",
    "ensemble_predictions_postprocessed",
)


def copytree(src, dst):
    """Copy every file found anywhere under *src* directly into *dst*.

    Unlike :func:`shutil.copytree`, the directory structure is NOT kept.
    Files from all sub-directories land side by side in *dst*, and files
    with the same name overwrite each other. *dst* is created if missing.
    Only file contents are copied (:func:`shutil.copyfile`), not metadata.
    If *src* does not exist, ``os.walk`` yields nothing and no file is
    copied.

    Args:
        src: Directory to read from (walked recursively).
        dst: Directory to write into.
    """
    os.makedirs(dst, exist_ok=True)
    for root, _dirs, files in os.walk(src):
        for name in files:
            shutil.copyfile(os.path.join(root, name), os.path.join(dst, name))


def main():
    """Clean old outputs, run the ensemble, and export postprocessed predictions.

    Exits with status 1 if the runner raises, or if it finishes without
    producing the ``ensemble_predictions_postprocessed`` folder.
    """
    log_format = '%(asctime)s [%(filename)s:%(lineno)d] %(message)s'
    # Log next to this script: foo.py -> foo.log. splitext only swaps the
    # extension; str.replace('.py', ...) would also rewrite any '.py'
    # occurring earlier in the path.
    log_path = os.path.splitext(__file__)[0] + '.log'
    logging.basicConfig(
        level=logging.INFO,
        format=log_format,
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler(log_path, encoding='utf-8'),
        ],
    )

    with open(YAML, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    results_dir = config["nnunet_results"]

    # Clean up leftovers from a previous run.
    for sub_dir in _STALE_OUTPUT_DIRS:
        shutil.rmtree(f'{results_dir}/{DATASET}/{sub_dir}', ignore_errors=True)

    runner = nnUNetV2Runner(YAML)
    try:
        runner.predict_ensemble_postprocessing(
            # num_processes_preprocessing=1,
            # num_processes_segmentation_export=1
        )
    except Exception as e:
        # Top-level boundary: keep the traceback in the log, then fail loudly.
        # A bare exit() would report success (status 0) to the caller.
        logging.exception(e)
        sys.exit(1)

    src_dir = f'{results_dir}/{DATASET}/ensemble_predictions_postprocessed'
    if not os.path.isdir(src_dir):
        # copytree would silently copy nothing; treat this as a failure.
        logging.error("Expected output folder not found: %s", src_dir)
        sys.exit(1)
    copytree(src_dir, DST_DIR)
    # NOTE: post-copy cleanup of the result folders is currently disabled;
    # reuse the _STALE_OUTPUT_DIRS loop above here if it is wanted again.


if __name__ == "__main__":
    main()