Skip to content

Commit

Permalink
Improve GT.save() usability (#499)
Browse files Browse the repository at this point in the history
* Improve `GT.save()` usability

* Isolate web driver preparation logic from `GT.save()`

* Add `quit()` to `_NoOpDriverCtx`

* Try to replace `_NoOpDriverCtx` with `no_op_callable()`

* Add `from __future__ import annotations` to `_utils_selenium.py`

* Introduce `cls_driver` and `cls_wd_options` to WebDriver

* Fix wrong `cls_driver` for `_SafariWebDriver`

* Fix the type hint for `debug_port` in `_BaseWebDriver`

* Add `test__utils_selenium.py`

* Remove `**params` from `GT.save()`

* Restore the comment about using `PIL` for converting tables into different formats
  • Loading branch information
jrycw authored Nov 22, 2024
1 parent aa7d553 commit 538fbf1
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 98 deletions.
131 changes: 33 additions & 98 deletions great_tables/_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,24 +271,6 @@ def as_latex(self: GT, use_longtable: bool = False, tbl_pos: str | None = None)
DebugDumpOptions: TypeAlias = Literal["zoom", "width_resize", "final_resize"]


class _NoOpDriverCtx:
"""Context manager that no-ops entering a webdriver(options=...) instance."""

def __init__(self, driver: webdriver.Remote):
self.driver = driver

def __call__(self, options):
# no-op what is otherwise instantiating webdriver with options,
# since a webdriver instance was already passed on init
return self

def __enter__(self):
return self.driver

def __exit__(self, *args):
pass


def save(
self: GT,
file: Path | str,
Expand All @@ -300,7 +282,7 @@ def save(
debug_port: None | int = None,
encoding: str = "utf-8",
_debug_dump: DebugDumpOptions | None = None,
) -> None:
) -> GTSelf:
"""
Produce a high-resolution image file or PDF of the table.
Expand Down Expand Up @@ -333,17 +315,17 @@ def save(
debug_port
Port number to use for debugging. By default no debugging port is opened.
encoding
The encoding used when writing temporary files.
The character encoding used for the HTML content.
_debug_dump
Whether the saved image should be a big browser window, with key elements outlined. This is
helpful for debugging this function's resizing, cropping heuristics. This is an internal
parameter and subject to change.
Returns
-------
None
This function does not return anything; it simply saves the image to the specified file
path.
GT
The GT object is returned. This is the same object that the method is called on so that we
can facilitate method chaining.
Details
-------
Expand All @@ -365,95 +347,54 @@ def save(
```
"""
import base64

# Import the required packages
_try_import(name="selenium", pip_install_line="pip install selenium")

from selenium import webdriver
from ._utils_selenium import _get_web_driver

if selector != "table":
raise NotImplementedError("Currently, only selector='table' is supported.")

if isinstance(file, Path):
file = str(file)

# If there is no file extension, add the .png extension
if not Path(file).suffix:
file += ".png"
file = str(Path(file).with_suffix(".png"))

# Get the HTML content from the displayed output
html_content = as_raw_html(self)

# Set the webdriver and options based on the chosen browser (`web_driver=` argument)
if isinstance(web_driver, webdriver.Remote):
wdriver = _NoOpDriverCtx(web_driver)
wd_options = None

elif web_driver == "chrome":
wdriver = webdriver.Chrome
wd_options = webdriver.ChromeOptions()
elif web_driver == "safari":
wdriver = webdriver.Safari
wd_options = webdriver.SafariOptions()
elif web_driver == "firefox":
wdriver = webdriver.Firefox
wd_options = webdriver.FirefoxOptions()
elif web_driver == "edge":
wdriver = webdriver.Edge
wd_options = webdriver.EdgeOptions()
else:
raise ValueError(f"Unsupported web driver: {web_driver}")

# specify headless flag ----
if web_driver in {"firefox", "edge"}:
wd_options.add_argument("--headless")
elif web_driver == "chrome":
# Operate all webdrivers in headless mode
wd_options.add_argument("--headless=new")
else:
# note that safari currently doesn't support headless browsing
pass

if debug_port:
if web_driver == "chrome":
wd_options.add_argument(f"--remote-debugging-port={debug_port}")
elif web_driver == "firefox":
# TODO: not sure how to connect to this session on firefox?
wd_options.add_argument(f"--start-debugger-server {debug_port}")
else:
warnings.warn("debug_port argument only supported on chrome and firefox")
debug_port = None
wdriver = _get_web_driver(web_driver)

# run browser ----
with (
tempfile.TemporaryDirectory() as tmp_dir,
wdriver(options=wd_options) as headless_browser,
):

# Write the HTML content to the temp file
with open(f"{tmp_dir}/table.html", "w", encoding=encoding) as temp_file:
temp_file.write(html_content)

# Open the HTML file in the headless browser
with wdriver(debug_port=debug_port) as headless_browser:
headless_browser.set_window_size(*window_size)
headless_browser.get("file://" + temp_file.name)
encoded = base64.b64encode(html_content.encode(encoding=encoding)).decode(encoding=encoding)
headless_browser.get(f"data:text/html;base64,{encoded}")

_save_screenshot(headless_browser, scale, file, debug=_debug_dump)

if debug_port:
input(
f"Currently debugging on port {debug_port}.\n\n"
"If you are using Chrome, enter chrome://inspect to preview the headless browser."
"Other browsers may have different ways to preview headless browser sessions.\n\n"
"Press enter to continue."
)
if debug_port and web_driver not in {"chrome", "firefox"}:
warnings.warn("debug_port argument only supported on chrome and firefox")
debug_port = None

if debug_port:
input(
f"Currently debugging on port {debug_port}.\n\n"
"If you are using Chrome, enter chrome://inspect to preview the headless browser."
"Other browsers may have different ways to preview headless browser sessions.\n\n"
"Press enter to continue."
)

return self


def _save_screenshot(
driver: webdriver.Chrome, scale, path: str, debug: DebugDumpOptions | None
driver: webdriver.Chrome, scale: float, path: str, debug: DebugDumpOptions | None
) -> None:
from io import BytesIO
from selenium.webdriver.common.by import By
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.ui import WebDriverWait

# Based on: https://stackoverflow.com/a/52572919/
# In some headless browsers, element position and width do not always reflect
Expand All @@ -466,7 +407,6 @@ def _save_screenshot(
#
# I can't say for sure whether the final sleep is needed. Only that it seems like
# on CI with firefox sometimes the final screencapture is wider than necessary.

original_size = driver.get_window_size()

# set table zoom ----
Expand Down Expand Up @@ -517,19 +457,14 @@ def _save_screenshot(
if debug == "final_resize":
return _dump_debug_screenshot(driver, path)

el = driver.find_element(by=By.TAG_NAME, value="body")
el = WebDriverWait(driver, 1).until(EC.visibility_of_element_located((By.TAG_NAME, "body")))

time.sleep(0.05)

if path.endswith(".png"):
el.screenshot(path)
else:
_try_import(name="PIL", pip_install_line="pip install pillow")
_try_import(name="PIL", pip_install_line="pip install pillow")

from PIL import Image
from PIL import Image

# convert to other formats (e.g. pdf, bmp) using PIL
Image.open(fp=BytesIO(el.screenshot_as_png)).save(fp=path)
# convert to other formats (e.g. pdf, bmp) using PIL
Image.open(fp=BytesIO(el.screenshot_as_png)).save(fp=path)


def _dump_debug_screenshot(driver, path):
Expand Down
91 changes: 91 additions & 0 deletions great_tables/_utils_selenium.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from __future__ import annotations

from types import TracebackType
from typing import Literal
from typing_extensions import TypeAlias
from selenium import webdriver

# Create a list of all selenium webdrivers
WebDrivers: TypeAlias = Literal[
"chrome",
"firefox",
"safari",
"edge",
]


class _BaseWebDriver:

def __init__(self, debug_port: int | None = None):
self.debug_port = debug_port
self.wd_options = self.cls_wd_options()
self.add_arguments()
self.driver = self.cls_driver(self.wd_options)

def add_arguments(self): ...

def __enter__(self) -> WebDrivers | webdriver.Remote:
return self.driver

def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> bool | None:
self.driver.quit()


class _ChromeWebDriver(_BaseWebDriver):
cls_driver = webdriver.Chrome
cls_wd_options = webdriver.ChromeOptions

def add_arguments(self):
self.wd_options.add_argument("--headless=new")
if self.debug_port is not None:
self.wd_options.add_argument(f"--remote-debugging-port={self.debug_port}")


class _SafariWebDriver(_BaseWebDriver):
cls_driver = webdriver.Safari
cls_wd_options = webdriver.SafariOptions


class _FirefoxWebDriver(_BaseWebDriver):
cls_driver = webdriver.Firefox
cls_wd_options = webdriver.FirefoxOptions

def add_arguments(self):
self.wd_options.add_argument("--headless")
if self.debug_port is not None:
self.wd_options.add_argument(f"--start-debugger-server {self.debug_port}")


class _EdgeWebDriver(_BaseWebDriver):
cls_driver = webdriver.Edge
cls_wd_options = webdriver.EdgeOptions

def add_arguments(self):
self.wd_options.add_argument("--headless")


def no_op_callable(web_driver: webdriver.Remote):
def wrapper(*args, **kwargs):
return web_driver

return wrapper


def _get_web_driver(web_driver: WebDrivers | webdriver.Remote):
if isinstance(web_driver, webdriver.Remote):
return no_op_callable(web_driver)
elif web_driver == "chrome":
return _ChromeWebDriver
elif web_driver == "safari":
return _SafariWebDriver
elif web_driver == "firefox":
return _FirefoxWebDriver
elif web_driver == "edge":
return _EdgeWebDriver
else:
raise ValueError(f"Unsupported web driver: {web_driver}")
39 changes: 39 additions & 0 deletions tests/test__utils_selenium.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import pytest

from great_tables._utils_selenium import (
_get_web_driver,
no_op_callable,
_ChromeWebDriver,
_SafariWebDriver,
_FirefoxWebDriver,
_EdgeWebDriver,
)


def test_no_op_callable():
"""
The test should cover the scenario of obtaining a remote driver in `_get_web_driver`.
"""
fake_input = object()
f = no_op_callable(fake_input)
assert f(1, x="x") is fake_input


@pytest.mark.parametrize(
"web_driver,Driver",
[
("chrome", _ChromeWebDriver),
("safari", _SafariWebDriver),
("firefox", _FirefoxWebDriver),
("edge", _EdgeWebDriver),
],
)
def test_get_web_driver(web_driver, Driver):
assert _get_web_driver(web_driver) is Driver


def test_get_web_driver_raise():
fake_web_driver = "fake_web_driver"
with pytest.raises(ValueError) as exc_info:
_get_web_driver(fake_web_driver)
assert exc_info.value.args[0] == f"Unsupported web driver: {fake_web_driver}"

0 comments on commit 538fbf1

Please sign in to comment.