diff --git a/nnunetv2/inference/export_prediction.py b/nnunetv2/inference/export_prediction.py index f5cdb958d..2720ccc39 100644 --- a/nnunetv2/inference/export_prediction.py +++ b/nnunetv2/inference/export_prediction.py @@ -1,11 +1,10 @@ -import os -from copy import deepcopy +import time from typing import Union, List import numpy as np import torch from acvl_utils.cropping_and_padding.bounding_boxes import bounding_box_to_slice -from batchgenerators.utilities.file_and_folder_operations import load_json, isfile, save_pickle +from batchgenerators.utilities.file_and_folder_operations import load_json, save_pickle from nnunetv2.configuration import default_num_processes from nnunetv2.utilities.label_handling.label_handling import LabelManager @@ -21,30 +20,45 @@ def convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits num_threads_torch: int = default_num_processes): old_threads = torch.get_num_threads() torch.set_num_threads(num_threads_torch) - # resample to original shape spacing_transposed = [properties_dict['spacing'][i] for i in plans_manager.transpose_forward] current_spacing = configuration_manager.spacing if \ len(configuration_manager.spacing) == \ len(properties_dict['shape_after_cropping_and_before_resampling']) else \ [spacing_transposed[0], *configuration_manager.spacing] - predicted_logits = configuration_manager.resampling_fn_probabilities(predicted_logits, - properties_dict['shape_after_cropping_and_before_resampling'], - current_spacing, - [properties_dict['spacing'][i] for i in plans_manager.transpose_forward]) - # return value of resampling_fn_probabilities can be ndarray or Tensor but that does not matter because - # apply_inference_nonlin will convert to torch - predicted_probabilities = label_manager.apply_inference_nonlin(predicted_logits) - del predicted_logits - segmentation = label_manager.convert_probabilities_to_segmentation(predicted_probabilities) + if return_probabilities: + predicted_logits = configuration_manager.resampling_fn_probabilities(predicted_logits, + properties_dict[ + 'shape_after_cropping_and_before_resampling'], + current_spacing, + [properties_dict['spacing'][i] for i in + plans_manager.transpose_forward]) + # return value of resampling_fn_probabilities can be ndarray or Tensor but that does not matter because + # apply_inference_nonlin will convert to torch + predicted_probabilities = label_manager.apply_inference_nonlin(predicted_logits) + del predicted_logits + segmentation = label_manager.convert_probabilities_to_segmentation(predicted_probabilities) + else: + predicted_probabilities = label_manager.apply_inference_nonlin(predicted_logits) + del predicted_logits + segmentation = label_manager.convert_probabilities_to_segmentation(predicted_probabilities) + segmentation = configuration_manager.resampling_fn_probabilities(segmentation.unsqueeze(0), + properties_dict[ + 'shape_after_cropping_and_before_resampling'], + current_spacing, + [properties_dict['spacing'][i] for i in + plans_manager.transpose_forward], + order=0 + ) # segmentation may be torch.Tensor but we continue with numpy if isinstance(segmentation, torch.Tensor): segmentation = segmentation.cpu().numpy() # put segmentation in bbox (revert cropping) segmentation_reverted_cropping = np.zeros(properties_dict['shape_before_cropping'], - dtype=np.uint8 if len(label_manager.foreground_labels) < 255 else np.uint16) + dtype=np.uint8 if len( + label_manager.foreground_labels) < 255 else np.uint16) slicer = bounding_box_to_slice(properties_dict['bbox_used_for_cropping']) segmentation_reverted_cropping[slicer] = segmentation del segmentation @@ -81,7 +95,8 @@ def export_prediction_from_logits(predicted_array_or_file: Union[np.ndarray, tor # elif predicted_array_or_file.endswith('.npz'): # predicted_array_or_file = np.load(predicted_array_or_file)['softmax'] # os.remove(tmp) - + print("[INFO] Start working on export_prediction_from_logits") + tic = time.time() if isinstance(dataset_json_dict_or_file, str): dataset_json_dict_or_file = load_json(dataset_json_dict_or_file) @@ -105,6 +120,7 @@ def export_prediction_from_logits(predicted_array_or_file: Union[np.ndarray, tor rw = plans_manager.image_reader_writer_class() rw.write_seg(segmentation_final, output_file_truncated + dataset_json_dict_or_file['file_ending'], properties_dict) + print(f"[INFO] Elapsed time for export_prediction_from_logits: {time.time() - tic}") def resample_and_save(predicted: Union[torch.Tensor, np.ndarray], target_shape: List[int], output_file: str, @@ -130,7 +146,8 @@ def resample_and_save(predicted: Union[torch.Tensor, np.ndarray], target_shape: len(configuration_manager.spacing) == len(properties_dict['shape_after_cropping_and_before_resampling']) else \ [spacing_transposed[0], *configuration_manager.spacing] target_spacing = configuration_manager.spacing if len(configuration_manager.spacing) == \ - len(properties_dict['shape_after_cropping_and_before_resampling']) else \ + len(properties_dict[ + 'shape_after_cropping_and_before_resampling']) else \ [spacing_transposed[0], *configuration_manager.spacing] predicted_array_or_file = configuration_manager.resampling_fn_probabilities(predicted, target_shape, diff --git a/nnunetv2/inference/predict_from_raw_data.py b/nnunetv2/inference/predict_from_raw_data.py index 1f5ede64f..e3db152e8 100644 --- a/nnunetv2/inference/predict_from_raw_data.py +++ b/nnunetv2/inference/predict_from_raw_data.py @@ -63,7 +63,7 @@ def __init__(self, self.perform_everything_on_device = perform_everything_on_device def initialize_from_trained_model_folder(self, model_training_output_dir: str, - use_folds: Union[Tuple[Union[int, str]], None], + use_folds: Union[Tuple[Union[int, str], ...], None], checkpoint_name: str = 'checkpoint_final.pth'): """ This is used when making predictions with a trained model