diff --git a/app/models.py b/app/models.py index a0956e7b..95a5cfb3 100644 --- a/app/models.py +++ b/app/models.py @@ -61,6 +61,7 @@ class PredictionRequestElement(BaseModel): class PredictionInputModel(BaseModel): document: str elements: str + viewport: Dict class PredictedElement(BaseModel): diff --git a/app/selenium_app.py b/app/selenium_app.py index ffe553af..410f3b66 100644 --- a/app/selenium_app.py +++ b/app/selenium_app.py @@ -57,7 +57,8 @@ def get_page_elements(driver: webdriver.Remote, page_content: str) -> List[WebEl return driver.find_elements(by=By.XPATH, value="//*") -def get_elements_visibility(page_content: str, starting_element_idx: int, ending_element_idx: int) -> Dict[str, bool]: +def get_elements_visibility(page_content: str, starting_element_idx: int, ending_element_idx: int, + viewport: Dict) -> Dict[str, bool]: """Returns a visibility of portion of elements contained in page_content starting_element_idx and ending_element_idx are referring to the starting @@ -65,6 +66,7 @@ def get_elements_visibility(page_content: str, starting_element_idx: int, ending get_page_elements() function. """ driver = get_webdriver() + driver.set_window_size(viewport['width'], viewport['height']) all_elements = get_page_elements(driver, page_content) result = {} @@ -92,7 +94,7 @@ def get_chunks_boundaries(data: Sized, desired_chunks_amount: int) -> Iterable[T yield i * chunk_size, data_size -def get_element_id_to_is_displayed_mapping(page_content: str) -> Dict[str, bool]: +def get_element_id_to_is_displayed_mapping(page_content: str, viewport: Dict) -> Dict[str, bool]: """Returns visibility status of all elements in the page Returned dictionary uses elements' jdn-hash property value as keys @@ -100,6 +102,7 @@ def get_element_id_to_is_displayed_mapping(page_content: str) -> Dict[str, bool] escaped_page_content = str(page_content).encode('utf-8').decode('unicode_escape') driver = get_webdriver() + driver.set_window_size(viewport['width'], viewport['height']) all_elements = get_page_elements(driver, escaped_page_content) driver.quit() @@ -109,7 +112,7 @@ def get_element_id_to_is_displayed_mapping(page_content: str) -> Dict[str, bool] with concurrent.futures.ProcessPoolExecutor(max_workers=num_of_workers) as executor: futures = [ - executor.submit(get_elements_visibility, escaped_page_content, s, e) + executor.submit(get_elements_visibility, escaped_page_content, s, e, viewport) for s, e in jobs_chunks ] for future in concurrent.futures.as_completed(futures): diff --git a/ds_methods/angular_predict.py b/ds_methods/angular_predict.py index 1852604f..0896d6c2 100644 --- a/ds_methods/angular_predict.py +++ b/ds_methods/angular_predict.py @@ -23,6 +23,7 @@ async def angular_predict_elements(body): body_json = json.loads(body_str) elements_json = body_json.get("elements", []) document_json = body_json.get("document", "") + viewport_json = body_json.get("viewport", {}) # create softmax layser function to get probabilities from logits softmax = torch.nn.Softmax(dim=1) @@ -111,7 +112,7 @@ async def angular_predict_elements(body): del model gc.collect() result = results_df[columns_to_publish].to_dict(orient="records") - element_id_to_is_displayed_map = get_element_id_to_is_displayed_mapping(document_json) + element_id_to_is_displayed_map = get_element_id_to_is_displayed_mapping(document_json, viewport_json) for element in result: element["is_shown"] = element_id_to_is_displayed_map.get(element["element_id"], None) return result diff --git a/ds_methods/html5_predict.py b/ds_methods/html5_predict.py index 454df308..6639c6e1 100644 --- a/ds_methods/html5_predict.py +++ b/ds_methods/html5_predict.py @@ -22,6 +22,7 @@ async def html5_predict_elements(body): body_json = json.loads(body_str) elements_json = body_json.get("elements", []) document_json = body_json.get("document", "") + viewport_json = body_json.get("viewport", {}) # generate temporary filename filename = dt.datetime.now().strftime("%Y%m%d%H%M%S%f.json") @@ -95,7 +96,7 @@ async def html5_predict_elements(body): result = results_df[columns_to_publish].to_dict(orient="records") logger.info("Determining visibility locators") - element_id_to_is_displayed_map = get_element_id_to_is_displayed_mapping(document_json) + element_id_to_is_displayed_map = get_element_id_to_is_displayed_mapping(document_json, viewport_json) for element in result: element["is_shown"] = element_id_to_is_displayed_map.get(element["element_id"], None) return result diff --git a/ds_methods/mui_predict.py b/ds_methods/mui_predict.py index 640e31c5..427aefc8 100644 --- a/ds_methods/mui_predict.py +++ b/ds_methods/mui_predict.py @@ -23,6 +23,7 @@ async def mui_predict_elements(body): body_json = json.loads(body_str) elements_json = body_json.get("elements", []) document_json = body_json.get("document", "") + viewport_json = body_json.get("viewport", {}) # create softmax layser function to get probabilities from logits softmax = torch.nn.Softmax(dim=1) @@ -113,7 +114,7 @@ async def mui_predict_elements(body): del model gc.collect() result = results_df[columns_to_publish].to_dict(orient="records") - element_id_to_is_displayed_map = get_element_id_to_is_displayed_mapping(document_json) + element_id_to_is_displayed_map = get_element_id_to_is_displayed_mapping(document_json, viewport_json) for element in result: element["is_shown"] = element_id_to_is_displayed_map.get(element["element_id"], None) return result