Skip to content

Commit

Permalink
Learning Rate finder (#1341)
Browse files Browse the repository at this point in the history
* Add LR Finder support structure

* Move trainer.Feeder to lib.training.generator

* Expose model.io

* Add lr_finder

* Update docs and locales

* Pre-PR fixups
  - Fix training graph not displaying
  - CI fixes
  - Switch lr finder progress to tqdm
  - Exit lr finder early on NaN
  - Display lr finder progress in GUI
  • Loading branch information
torzdf authored Aug 22, 2023
1 parent 68a3322 commit cf0efff
Show file tree
Hide file tree
Showing 22 changed files with 1,128 additions and 633 deletions.
8 changes: 7 additions & 1 deletion docs/full/lib/training.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ training.generator module
:undoc-members:
:show-inheritance:

training.lr_finder module
=========================

.. automodule:: lib.training.lr_finder
:members:
:undoc-members:
:show-inheritance:

training.preview_cv module
==========================
Expand All @@ -41,7 +48,6 @@ training.preview_cv module
:undoc-members:
:show-inheritance:


training.preview_tk module
==========================

Expand Down
27 changes: 19 additions & 8 deletions lib/cli/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1059,6 +1059,25 @@ def get_argument_list() -> list[dict[str, T.Any]]:
"\nL|mirrored: Supports synchronous distributed training across multiple local "
"GPUs. A copy of the model and all variables are loaded onto each GPU with "
"batches distributed to each GPU at each iteration.")))
argument_list.append(dict(
opts=("-nl", "--no-logs"),
action="store_true",
dest="no_logs",
default=False,
group=_("training"),
help=_("Disables TensorBoard logging. NB: Disabling logs means that you will not be "
"able to use the graph or analysis for this session in the GUI.")))
argument_list.append(dict(
opts=("-r", "--use-lr-finder"),
action="store_true",
dest="use_lr_finder",
default=False,
group=_("training"),
help=_("Use the Learning Rate Finder to discover the optimal learning rate for "
"training. For new models, this will calculate the optimal learning rate for "
"the model. For existing models this will use the optimal learning rate that "
"was discovered when initializing the model. Setting this option will ignore "
"the manually configured learning rate (configurable in train settings).")))
argument_list.append(dict(
opts=("-s", "--save-interval"),
action=Slider,
Expand Down Expand Up @@ -1127,14 +1146,6 @@ def get_argument_list() -> list[dict[str, T.Any]]:
group=_("preview"),
help=_("Writes the training result to a file. The image will be stored in the root "
"of your FaceSwap folder.")))
argument_list.append(dict(
opts=("-nl", "--no-logs"),
action="store_true",
dest="no_logs",
default=False,
group=_("training"),
help=_("Disables TensorBoard logging. NB: Disabling logs means that you will not be "
"able to use the graph or analysis for this session in the GUI.")))
argument_list.append(dict(
opts=("-wl", "--warp-to-landmarks"),
action="store_true",
Expand Down
60 changes: 46 additions & 14 deletions lib/gui/custom_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import platform
import re
import sys
import typing as T
import tkinter as tk
from tkinter import ttk, TclError

Expand Down Expand Up @@ -101,7 +102,7 @@ def __init__(self, labels, actions, hotkeys=None):
def _create_menu(self):
""" Create the menu based on :attr:`_labels` and :attr:`_actions`. """
for idx, (label, action) in enumerate(zip(self._labels, self._actions)):
kwargs = dict(label=label, command=action)
kwargs = {"label": label, "command": action}
if isinstance(self._hotkeys, (list, tuple)) and self._hotkeys[idx]:
kwargs["accelerator"] = self._hotkeys[idx]
self.add_command(**kwargs)
Expand Down Expand Up @@ -428,12 +429,13 @@ class StatusBar(ttk.Frame): # pylint: disable=too-many-ancestors
frame otherwise ``False``. Default: ``False``
"""

def __init__(self, parent, hide_status=False):
def __init__(self, parent: ttk.Frame, hide_status: bool = False) -> None:
super().__init__(parent)
self._frame = ttk.Frame(self)
self._message = tk.StringVar()
self._pbar_message = tk.StringVar()
self._pbar_position = tk.IntVar()
self._mode: T.Literal["indeterminate", "determinate"] = "determinate"

self._message.set("Ready")

Expand All @@ -443,12 +445,12 @@ def __init__(self, parent, hide_status=False):
self._frame.pack(padx=10, pady=2, fill=tk.X, expand=False)

@property
def message(self):
def message(self) -> tk.StringVar:
""":class:`tkinter.StringVar`: The variable to hold the status bar message on the left
hand side of the status bar. """
return self._message

def _status(self, hide_status):
def _status(self, hide_status: bool) -> None:
""" Place Status label into left of the status bar.
Parameters
Expand All @@ -472,8 +474,14 @@ def _status(self, hide_status):
anchor=tk.W)
lblstatus.pack(side=tk.LEFT, anchor=tk.W, fill=tk.X, expand=True)

def _progress_bar(self):
""" Place progress bar into right of the status bar. """
def _progress_bar(self) -> ttk.Progressbar:
""" Place progress bar into right of the status bar.
Returns
-------
:class:`tkinter.ttk.Progressbar`
The progress bar object
"""
progressframe = ttk.Frame(self._frame)
progressframe.pack(side=tk.RIGHT, anchor=tk.E, fill=tk.X)

Expand All @@ -484,12 +492,12 @@ def _progress_bar(self):
length=200,
variable=self._pbar_position,
maximum=100,
mode="determinate")
mode=self._mode)
pbar.pack(side=tk.LEFT, padx=2, fill=tk.X, expand=True)
pbar.pack_forget()
return pbar

def start(self, mode):
def start(self, mode: T.Literal["indeterminate", "determinate"]) -> None:
""" Set progress bar mode and display,
Parameters
Expand All @@ -500,24 +508,48 @@ def start(self, mode):
self._set_mode(mode)
self._pbar.pack()

def stop(self):
def stop(self) -> None:
""" Reset progress bar and hide """
self._pbar_message.set("")
self._pbar_position.set(0)
self._set_mode("determinate")
self._mode = "determinate"
self._set_mode(self._mode)
self._pbar.pack_forget()

def _set_mode(self, mode):
""" Set the progress bar mode """
self._pbar.config(mode=mode)
def _set_mode(self, mode: T.Literal["indeterminate", "determinate"]) -> None:
    """ Set the progress bar mode

    Stores the requested mode on the instance, applies it to the progress bar widget and
    (re)sets the widget's maximum to 100. For "indeterminate" the widget's ``start`` method
    is called after configuration; for any other value the widget's ``stop`` method is called
    before the maximum is reset.

    Parameters
    ----------
    mode: ["indeterminate", "determinate"]
        The mode that the progress bar should be executed in
    """
    self._mode = mode
    self._pbar.config(mode=self._mode)
    if mode == "indeterminate":
        self._pbar.config(maximum=100)
        self._pbar.start()
    else:
        # Stop first, then reset the range: this call order matches the original code
        self._pbar.stop()
        self._pbar.config(maximum=100)

def progress_update(self, message, position, update_position=True):
def set_mode(self, mode: T.Literal["indeterminate", "determinate"]) -> None:
    """ Set the mode of a currently displayed progress bar and reset position to 0.

    If the given mode is the same as the currently configured mode, returns without performing
    any action. Otherwise the progress bar is stopped and then started again in the requested
    mode.

    Parameters
    ----------
    mode: ["indeterminate", "determinate"]
        The mode that the progress bar should be set to
    """
    if mode == self._mode:
        # Already in the requested mode: leave the bar untouched
        return
    self.stop()
    self.start(mode)

def progress_update(self, message: str, position: int, update_position: bool = True) -> None:
""" Update the GUIs progress bar and position.
Parameters
Expand Down
3 changes: 1 addition & 2 deletions lib/gui/display_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
if T.TYPE_CHECKING:
from matplotlib.lines import Line2D

matplotlib.use("TkAgg")

logger: logging.Logger = logging.getLogger(__name__)


Expand All @@ -44,6 +42,7 @@ class GraphBase(ttk.Frame): # pylint: disable=too-many-ancestors
def __init__(self, parent: ttk.Frame, data, ylabel: str) -> None:
logger.debug("Initializing %s", self.__class__.__name__)
super().__init__(parent)
matplotlib.use("TkAgg") # Can't be at module level as breaks Github CI
style.use("ggplot")

self._calcs = data
Expand Down
39 changes: 35 additions & 4 deletions lib/gui/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,9 @@ def _prepare(self, category: T.Literal["faceswap", "tools"]) -> list[str]:
print("Loading...")

self._statusbar.message.set(f"Executing - {self._command}.py")
mode = "indeterminate" if self._command in ("effmpeg", "train") else "determinate"
mode: T.Literal["indeterminate",
"determinate"] = ("indeterminate" if self._command in ("effmpeg", "train")
else "determinate")
self._statusbar.start(mode)

args = self._build_args(category)
Expand Down Expand Up @@ -236,6 +238,7 @@ def __init__(self, wrapper: ProcessWrapper) -> None:
"tqdm": re.compile(r"(?P<dsc>.*?)(?P<pct>\d+%).*?(?P<itm>\S+/\S+)\W\["
r"(?P<tme>[\d+:]+<.*),\W(?P<rte>.*)[a-zA-Z/]*\]"),
"ffmpeg": re.compile(r"([a-zA-Z]+)=\s*(-?[\d|N/A]\S+)")}
self._first_loss_seen = False
logger.debug("Initialized %s", self.__class__.__name__)

@property
Expand Down Expand Up @@ -269,6 +272,24 @@ def execute_script(self, command: str, args: list[str]) -> None:
self._thread_stderr()
logger.debug("Executed Faceswap")

def _process_training_determinate_function(self, output: str) -> bool:
    """ Process a stdout/stderr message to check for determinate TQDM output when training

    The line is only tested when the executing command is "train" and no loss output has been
    seen yet. If the line is then captured as TQDM output, the status bar is switched to
    "determinate" mode.

    Parameters
    ----------
    output: str
        The stdout/stderr string to test

    Returns
    -------
    bool
        ``True`` if a determinate TQDM line was parsed when training otherwise ``False``
    """
    # Order matters: the TQDM capture is only attempted once the cheaper checks have passed
    if self._command == "train" and not self._first_loss_seen and self._capture_tqdm(output):
        self._statusbar.set_mode("determinate")
        return True
    return False

def _process_progress_stdout(self, output: str) -> bool:
""" Process stdout for any faceswap processes that update the status/progress bar(s)
Expand All @@ -282,6 +303,9 @@ def _process_progress_stdout(self, output: str) -> bool:
bool
``True`` if all actions have been completed on the output line otherwise ``False``
"""
if self._process_training_determinate_function(output):
return True

if self._command == "train" and self._capture_loss(output):
return True

Expand All @@ -306,7 +330,9 @@ def _process_training_stdout(self, output: str) -> None:
if self._command != "train" or not tk_vars.is_training.get():
return

if "[saved models]" not in output.strip().lower():
t_output = output.strip().lower()
if "[saved model]" not in t_output or t_output.endswith("[saved model]"):
# Not a saved model line or saving the model for a reason other than standard saving
return

logger.debug("Trigger GUI Training update")
Expand Down Expand Up @@ -346,6 +372,7 @@ def _read_stdout(self) -> None:

returncode = self._process.poll()
assert returncode is not None
self._first_loss_seen = False
message = self._set_final_status(returncode)
self._wrapper.terminate(message)
logger.debug("Terminated stdout reader. returncode: %s", returncode)
Expand All @@ -369,8 +396,7 @@ def _read_stderr(self) -> None:
if output:
if self._command != "train" and self._capture_tqdm(output):
continue
if self._command == "train" and output.startswith("Reading training images"):
print(output.strip(), file=sys.stdout)
if self._process_training_determinate_function(output):
continue
if os.name == "nt" and "Call to CreateProcess failed. Error code: 2" in output:
# Suppress ptxas errors on Tensorflow for Windows
Expand Down Expand Up @@ -438,6 +464,11 @@ def _capture_loss(self, string: str) -> bool:
elapsed = self._calculate_elapsed()
message = (f"Elapsed: {elapsed} | "
f"Session Iterations: {self._train_stats['iterations']} {message}")

if not self._first_loss_seen:
self._statusbar.set_mode("indeterminate")
self._first_loss_seen = True

self._statusbar.progress_update(message, 0, False)
logger.trace("Succesfully captured loss: %s", message) # type:ignore[attr-defined]
return True
Expand Down
3 changes: 2 additions & 1 deletion lib/training/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import typing as T

from .augmentation import ImageAugmentation
from .generator import PreviewDataGenerator, TrainingDataGenerator
from .generator import Feeder
from .lr_finder import LearningRateFinder
from .preview_cv import PreviewBuffer, TriggerType

if T.TYPE_CHECKING:
Expand Down
Loading

0 comments on commit cf0efff

Please sign in to comment.