Skip to content

Commit

Permalink
Apple-Silicon: Place unsupported Ops on to CPU
Browse files Browse the repository at this point in the history
  • Loading branch information
torzdf committed Jul 10, 2024
1 parent d80a6f8 commit ae3f68a
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
6 changes: 5 additions & 1 deletion lib/cli/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from lib.gpu_stats import GPUStats
from lib.logger import crash_log, log_setup
from lib.utils import FaceswapError, get_torch_version, safe_shutdown, set_backend
from lib.utils import FaceswapError, get_backend, get_torch_version, safe_shutdown, set_backend

if T.TYPE_CHECKING:
import argparse
Expand Down Expand Up @@ -47,6 +47,10 @@ def _set_environment_variables(self) -> None:
logger.debug("Setting NUMEXPR_MAX_THREADS to %s", allocate)
os.environ["NUMEXPR_MAX_THREADS"] = str(allocate)

if get_backend() == "apple_silicon": # Let apple put unsupported ops on the CPU
logger.debug("Enabling unsupported Ops on CPU for Apple Silicon")
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

def _import_script(self) -> Callable:
""" Imports the relevant script as indicated by :attr:`_command` from the scripts folder.
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ ignore_missing_imports = True
ignore_missing_imports = True
[mypy-tensorboard.*]
ignore_missing_imports = True
[mypy-tensorflow.*]
[mypy-torch.*]
ignore_missing_imports = True
[mypy-tqdm.*]
ignore_missing_imports = True
Expand Down

0 comments on commit ae3f68a

Please sign in to comment.