diff --git a/great_tables/_export.py b/great_tables/_export.py index 29fc31577..8c13dc9ed 100644 --- a/great_tables/_export.py +++ b/great_tables/_export.py @@ -271,24 +271,6 @@ def as_latex(self: GT, use_longtable: bool = False, tbl_pos: str | None = None) DebugDumpOptions: TypeAlias = Literal["zoom", "width_resize", "final_resize"] -class _NoOpDriverCtx: - """Context manager that no-ops entering a webdriver(options=...) instance.""" - - def __init__(self, driver: webdriver.Remote): - self.driver = driver - - def __call__(self, options): - # no-op what is otherwise instantiating webdriver with options, - # since a webdriver instance was already passed on init - return self - - def __enter__(self): - return self.driver - - def __exit__(self, *args): - pass - - def save( self: GT, file: Path | str, @@ -300,7 +282,7 @@ def save( debug_port: None | int = None, encoding: str = "utf-8", _debug_dump: DebugDumpOptions | None = None, -) -> None: +) -> GTSelf: """ Produce a high-resolution image file or PDF of the table. @@ -333,7 +315,7 @@ def save( debug_port Port number to use for debugging. By default no debugging port is opened. encoding - The encoding used when writing temporary files. + The character encoding used for the HTML content. _debug_dump Whether the saved image should be a big browser window, with key elements outlined. This is helpful for debugging this function's resizing, cropping heuristics. This is an internal @@ -341,9 +323,9 @@ def save( Returns ------- - None - This function does not return anything; it simply saves the image to the specified file - path. + GT + The GT object is returned. This is the same object that the method is called on so that we + can facilitate method chaining. Details ------- @@ -365,95 +347,54 @@ def save( ``` """ + import base64 # Import the required packages _try_import(name="selenium", pip_install_line="pip install selenium") - from selenium import webdriver + from ._utils_selenium import _get_web_driver if selector != "table": raise NotImplementedError("Currently, only selector='table' is supported.") - if isinstance(file, Path): - file = str(file) - # If there is no file extension, add the .png extension - if not Path(file).suffix: - file += ".png" + file = str(Path(file).with_suffix(".png")) # Get the HTML content from the displayed output html_content = as_raw_html(self) - # Set the webdriver and options based on the chosen browser (`web_driver=` argument) - if isinstance(web_driver, webdriver.Remote): - wdriver = _NoOpDriverCtx(web_driver) - wd_options = None - - elif web_driver == "chrome": - wdriver = webdriver.Chrome - wd_options = webdriver.ChromeOptions() - elif web_driver == "safari": - wdriver = webdriver.Safari - wd_options = webdriver.SafariOptions() - elif web_driver == "firefox": - wdriver = webdriver.Firefox - wd_options = webdriver.FirefoxOptions() - elif web_driver == "edge": - wdriver = webdriver.Edge - wd_options = webdriver.EdgeOptions() - else: - raise ValueError(f"Unsupported web driver: {web_driver}") - - # specify headless flag ---- - if web_driver in {"firefox", "edge"}: - wd_options.add_argument("--headless") - elif web_driver == "chrome": - # Operate all webdrivers in headless mode - wd_options.add_argument("--headless=new") - else: - # note that safari currently doesn't support headless browsing - pass - - if debug_port: - if web_driver == "chrome": - wd_options.add_argument(f"--remote-debugging-port={debug_port}") - elif web_driver == "firefox": - # TODO: not sure how to connect to this session on firefox? - wd_options.add_argument(f"--start-debugger-server {debug_port}") - else: - warnings.warn("debug_port argument only supported on chrome and firefox") - debug_port = None + wdriver = _get_web_driver(web_driver) # run browser ---- - with ( - tempfile.TemporaryDirectory() as tmp_dir, - wdriver(options=wd_options) as headless_browser, - ): - - # Write the HTML content to the temp file - with open(f"{tmp_dir}/table.html", "w", encoding=encoding) as temp_file: - temp_file.write(html_content) - - # Open the HTML file in the headless browser + with wdriver(debug_port=debug_port) as headless_browser: headless_browser.set_window_size(*window_size) - headless_browser.get("file://" + temp_file.name) + encoded = base64.b64encode(html_content.encode(encoding=encoding)).decode(encoding=encoding) + headless_browser.get(f"data:text/html;base64,{encoded}") _save_screenshot(headless_browser, scale, file, debug=_debug_dump) - if debug_port: - input( - f"Currently debugging on port {debug_port}.\n\n" - "If you are using Chrome, enter chrome://inspect to preview the headless browser." - "Other browsers may have different ways to preview headless browser sessions.\n\n" - "Press enter to continue." - ) + if debug_port and web_driver not in {"chrome", "firefox"}: + warnings.warn("debug_port argument only supported on chrome and firefox") + debug_port = None + + if debug_port: + input( + f"Currently debugging on port {debug_port}.\n\n" + "If you are using Chrome, enter chrome://inspect to preview the headless browser." + "Other browsers may have different ways to preview headless browser sessions.\n\n" + "Press enter to continue." + ) + + return self def _save_screenshot( - driver: webdriver.Chrome, scale, path: str, debug: DebugDumpOptions | None + driver: webdriver.Chrome, scale: float, path: str, debug: DebugDumpOptions | None ) -> None: from io import BytesIO from selenium.webdriver.common.by import By + from selenium.webdriver.support import expected_conditions as EC + from selenium.webdriver.support.ui import WebDriverWait # Based on: https://stackoverflow.com/a/52572919/ # In some headless browsers, element position and width do not always reflect @@ -466,7 +407,6 @@ def _save_screenshot( # # I can't say for sure whether the final sleep is needed. Only that it seems like # on CI with firefox sometimes the final screencapture is wider than necessary. - original_size = driver.get_window_size() # set table zoom ---- @@ -517,19 +457,14 @@ def _save_screenshot( if debug == "final_resize": return _dump_debug_screenshot(driver, path) - el = driver.find_element(by=By.TAG_NAME, value="body") + el = WebDriverWait(driver, 1).until(EC.visibility_of_element_located((By.TAG_NAME, "body"))) - time.sleep(0.05) - - if path.endswith(".png"): - el.screenshot(path) - else: - _try_import(name="PIL", pip_install_line="pip install pillow") + _try_import(name="PIL", pip_install_line="pip install pillow") - from PIL import Image + from PIL import Image - # convert to other formats (e.g. pdf, bmp) using PIL - Image.open(fp=BytesIO(el.screenshot_as_png)).save(fp=path) + # convert to other formats (e.g. pdf, bmp) using PIL + Image.open(fp=BytesIO(el.screenshot_as_png)).save(fp=path) def _dump_debug_screenshot(driver, path): diff --git a/great_tables/_utils_selenium.py b/great_tables/_utils_selenium.py new file mode 100644 index 000000000..6e39d1a06 --- /dev/null +++ b/great_tables/_utils_selenium.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +from types import TracebackType +from typing import Literal +from typing_extensions import TypeAlias +from selenium import webdriver + +# Create a list of all selenium webdrivers +WebDrivers: TypeAlias = Literal[ + "chrome", + "firefox", + "safari", + "edge", +] + + +class _BaseWebDriver: + + def __init__(self, debug_port: int | None = None): + self.debug_port = debug_port + self.wd_options = self.cls_wd_options() + self.add_arguments() + self.driver = self.cls_driver(self.wd_options) + + def add_arguments(self): ... + + def __enter__(self) -> WebDrivers | webdriver.Remote: + return self.driver + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> bool | None: + self.driver.quit() + + +class _ChromeWebDriver(_BaseWebDriver): + cls_driver = webdriver.Chrome + cls_wd_options = webdriver.ChromeOptions + + def add_arguments(self): + self.wd_options.add_argument("--headless=new") + if self.debug_port is not None: + self.wd_options.add_argument(f"--remote-debugging-port={self.debug_port}") + + +class _SafariWebDriver(_BaseWebDriver): + cls_driver = webdriver.Safari + cls_wd_options = webdriver.SafariOptions + + +class _FirefoxWebDriver(_BaseWebDriver): + cls_driver = webdriver.Firefox + cls_wd_options = webdriver.FirefoxOptions + + def add_arguments(self): + self.wd_options.add_argument("--headless") + if self.debug_port is not None: + self.wd_options.add_argument(f"--start-debugger-server {self.debug_port}") + + +class _EdgeWebDriver(_BaseWebDriver): + cls_driver = webdriver.Edge + cls_wd_options = webdriver.EdgeOptions + + def add_arguments(self): + self.wd_options.add_argument("--headless") + + +def no_op_callable(web_driver: webdriver.Remote): + def wrapper(*args, **kwargs): + return web_driver + + return wrapper + + +def _get_web_driver(web_driver: WebDrivers | webdriver.Remote): + if isinstance(web_driver, webdriver.Remote): + return no_op_callable(web_driver) + elif web_driver == "chrome": + return _ChromeWebDriver + elif web_driver == "safari": + return _SafariWebDriver + elif web_driver == "firefox": + return _FirefoxWebDriver + elif web_driver == "edge": + return _EdgeWebDriver + else: + raise ValueError(f"Unsupported web driver: {web_driver}") diff --git a/tests/test__utils_selenium.py b/tests/test__utils_selenium.py new file mode 100644 index 000000000..993ea6f24 --- /dev/null +++ b/tests/test__utils_selenium.py @@ -0,0 +1,39 @@ +import pytest + +from great_tables._utils_selenium import ( + _get_web_driver, + no_op_callable, + _ChromeWebDriver, + _SafariWebDriver, + _FirefoxWebDriver, + _EdgeWebDriver, +) + + +def test_no_op_callable(): + """ + The test should cover the scenario of obtaining a remote driver in `_get_web_driver`. + """ + fake_input = object() + f = no_op_callable(fake_input) + assert f(1, x="x") is fake_input + + +@pytest.mark.parametrize( + "web_driver,Driver", + [ + ("chrome", _ChromeWebDriver), + ("safari", _SafariWebDriver), + ("firefox", _FirefoxWebDriver), + ("edge", _EdgeWebDriver), + ], +) +def test_get_web_driver(web_driver, Driver): + assert _get_web_driver(web_driver) is Driver + + +def test_get_web_driver_raise(): + fake_web_driver = "fake_web_driver" + with pytest.raises(ValueError) as exc_info: + _get_web_driver(fake_web_driver) + assert exc_info.value.args[0] == f"Unsupported web driver: {fake_web_driver}"