diff --git a/setup.py b/setup.py index 915ec490a5..0b6c223291 100755 --- a/setup.py +++ b/setup.py @@ -776,6 +776,44 @@ def _check_rocm(self) -> None: _INSTALL_FAILED = True +def _check_ld_config(lib: str) -> str: + """ Locate a library in ldconfig + + Parameters + ---------- + lib: str The library to locate + + Returns + ------- + str + The library from ldconfig, or empty string if not found + """ + retval = "" + ldconfig = which("ldconfig") + if not ldconfig: + return retval + + retval = next((line.decode("utf-8", errors="replace").strip() + for line in run([ldconfig, "-p"], + capture_output=True, + check=False).stdout.splitlines() + if lib.encode("utf-8") in line), "") + + if retval or (not retval and not os.environ.get("LD_LIBRARY_PATH")): + return retval + + for path in os.environ["LD_LIBRARY_PATH"].split(":"): + if not path: + continue + + retval = next((fname.strip() for fname in reversed(os.listdir(path)) + if lib in fname), "") + if retval: + break + + return retval + + class ROCmCheck(): # pylint:disable=too-few-public-methods """ Find the location of system installed ROCm on Linux """ def __init__(self) -> None: @@ -796,16 +834,7 @@ def _rocm_check(self) -> None: with ldconfig then attempt to find it in LD_LIBRARY_PATH. If found, set the :attr:`rocm_version` to the discovered version """ - ldconfig = os.popen("which ldconfig").read() - if not ldconfig: - return - chk = os.popen(f"{ldconfig} -p | grep -P \"librocm-core.so.\\d+\" | head -n 1").read() - if not chk and os.environ.get("LD_LIBRARY_PATH"): - for path in os.environ["LD_LIBRARY_PATH"].split(":"): - chk = os.popen(f"ls {path} | grep -P -o \"librocmcore.so.\\d+\" | " - "head -n 1").read() - if chk: - break + chk = _check_ld_config("librocm-core.so.") if not chk: return @@ -852,8 +881,7 @@ def _cuda_check(self) -> None: stdout.decode(locale.getpreferredencoding(), errors="ignore")) if version is not None: self.cuda_version = version.groupdict().get("cuda", None) - locate = "where" if self._os == "windows" else "which" - path = os.popen(f"{locate} nvcc").read() + path = which("nvcc") if path: path = path.split("\n")[0] # Split multiple entries and take first found while True: # Get Cuda root folder @@ -870,22 +898,15 @@ def _cuda_check(self) -> None: def _cuda_check_linux(self) -> None: """ For Linux check the dynamic link loader for libcudart. If not found with ldconfig then attempt to find it in LD_LIBRARY_PATH. """ - ldconfig = os.popen("which ldconfig").read() - if not ldconfig: - return - chk = os.popen(f"{ldconfig} -p | grep -P \"libcudart.so.\\d+.\\d+\" | head -n 1").read() - if not chk and os.environ.get("LD_LIBRARY_PATH"): - for path in os.environ["LD_LIBRARY_PATH"].split(":"): - chk = os.popen(f"ls {path} | grep -P -o \"libcudart.so.\\d+.\\d+\" | " - "head -n 1").read() - if chk: - break + chk = _check_ld_config("libcudart.so.") if not chk: # Cuda not found return cudavers = chk.strip().replace("libcudart.so.", "") - self.cuda_version = cudavers[:cudavers.find(" ")] - self.cuda_path = chk[chk.find("=>") + 3:chk.find("targets") - 1] + self.cuda_version = cudavers[:cudavers.find(" ")] if " " in cudavers else cudavers + cuda_path = chk[chk.find("=>") + 3:chk.find("targets") - 1] + if os.path.exists(cuda_path): + self.cuda_path = cuda_path def _cuda_check_windows(self) -> None: """ Check Windows CUDA Version and path from Environment Variables""" @@ -930,10 +951,7 @@ def _cudnn_check(self) -> None: if self._os == "windows": return - ldconfig = os.popen("which ldconfig").read() - if not ldconfig: - return - chk = os.popen(f"{ldconfig} -p | grep -P \"libcudnn.so.\" | head -n 1").read() + chk = _check_ld_config("libcudnn.so.") if not chk: return cudnnvers = chk.strip().replace("libcudnn.so.", "").split()[0] @@ -952,10 +970,7 @@ def _get_checkfiles_linux(self) -> list[str]: list List of header file locations to scan for cuDNN versions """ - ldconfig = os.popen("which ldconfig").read() - if not ldconfig: - return [] - chk = os.popen(f"{ldconfig} -p | grep -P \"libcudnn.so.\\d+\" | head -n 1").read() + chk = _check_ld_config("libcudnn.so.") chk = chk.strip().replace("libcudnn.so.", "") if not chk: return [] @@ -978,7 +993,7 @@ def _get_checkfiles_windows(self) -> list[str]: List of header file locations to scan for cuDNN versions """ # TODO A more reliable way of getting the windows location - if not self.cuda_path: + if not self.cuda_path or not os.path.exists(self.cuda_path): return [] scandir = os.path.join(self.cuda_path, "include") cudnn_checkfiles = [os.path.join(scandir, header) for header in self._cudnn_header_files]