diff --git a/src/paramit/cli/__init__.py b/src/paramit/cli/__init__.py index 19f206c..09b5af0 100644 --- a/src/paramit/cli/__init__.py +++ b/src/paramit/cli/__init__.py @@ -648,8 +648,20 @@ def main(): print( f"{YELLOW}Warning: Configuration file {config_path} already exists{RESET}" ) - overwrite = input("Do you want to overwrite it? (y/n): ") - if overwrite.lower() == "y": + overwrite_file = False + + if "overwrite" in hyperparameters: + if hyperparameters["overwrite"].values[0].strip().lower() in ["y", "yes"]: + overwrite_file = True + # Remove "overwrite" from hyperparameters if it exists + del hyperparameters["overwrite"] + else: + overwrite = input("Do you want to overwrite it? (y/n): ").strip().lower() + if overwrite == "y": + overwrite_file = True + + # If overwrite is allowed, generate and write the config file + if overwrite_file: generated_config = generate_config_file(tree, path) with open(config_path, "wb") as f: @@ -660,6 +672,8 @@ def main(): orig_script_path = config["meta"]["script_path"] experiment_configs = generate_configs_from_hyperparameters(config, hyperparameters) + + if len(experiment_configs) > 100: print(f"{YELLOW}Warning: Running {len(experiment_configs)} experiments{RESET}")