From 68aa01e627b1476f59704bc9f15244de9b0a1a69 Mon Sep 17 00:00:00 2001 From: Sinju P Date: Tue, 20 Aug 2024 09:20:59 +0530 Subject: [PATCH 1/2] fix: Postgres doesn't reconnect after idle time #541 --- src/vanna/base/base.py | 62 +++++++++++++++++++++++++++++------------- 1 file changed, 43 insertions(+), 19 deletions(-) diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index fa78a5d4..801e6133 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -437,7 +437,7 @@ def get_training_data(self, **kwargs) -> pd.DataFrame: pass @abstractmethod - def remove_training_data(self, id: str, **kwargs) -> bool: + def remove_training_data(id: str, **kwargs) -> bool: """ Example: ```python @@ -840,6 +840,7 @@ def connect_to_postgres( port: int = None, **kwargs ): + """ Connect to postgres using the psycopg2 connector. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql] **Example:** @@ -913,26 +914,44 @@ def connect_to_postgres( except psycopg2.Error as e: raise ValidationError(e) + def connect_to_db(): + return psycopg2.connect(host=host, dbname=dbname, + user=user, password=password, port=port, **kwargs) + + def run_sql_postgres(sql: str) -> Union[pd.DataFrame, None]: - if conn: - try: - cs = conn.cursor() - cs.execute(sql) - results = cs.fetchall() + conn = None + try: + conn = connect_to_db() # Initial connection attempt + cs = conn.cursor() + cs.execute(sql) + results = cs.fetchall() - # Create a pandas dataframe from the results - df = pd.DataFrame( - results, columns=[desc[0] for desc in cs.description] - ) - return df + # Create a pandas dataframe from the results + df = pd.DataFrame(results, columns=[desc[0] for desc in cs.description]) + return df + + except psycopg2.InterfaceError as e: + # Attempt to reconnect and retry the operation + if conn: + conn.close() # Ensure any existing connection is closed + conn = connect_to_db() + cs = conn.cursor() + cs.execute(sql) + results = cs.fetchall() + + # Create a pandas dataframe from the results + df = pd.DataFrame(results, columns=[desc[0] for desc in cs.description]) + return df - except psycopg2.Error as e: + except psycopg2.Error as e: + if conn: conn.rollback() raise ValidationError(e) - except Exception as e: - conn.rollback() - raise e + except Exception as e: + conn.rollback() + raise e self.dialect = "PostgreSQL" self.run_sql_is_set = True @@ -1276,10 +1295,15 @@ def connect_to_bigquery( def run_sql_bigquery(sql: str) -> Union[pd.DataFrame, None]: if conn: - job = conn.query(sql) - df = job.result().to_dataframe() - return df - + try: + job = conn.query(sql) + df = job.result().to_dataframe() + return df + except GoogleAPIError as error: + errors = [] + for error in error.errors: + errors.append(error["message"]) + raise errors return None self.dialect = "BigQuery SQL" From 78c7efc9af26144c206c06fd6d1ac7c73830daea Mon Sep 17 00:00:00 2001 From: Zain Hoda <7146154+zainhoda@users.noreply.github.com> Date: Wed, 21 Aug 2024 11:19:36 -0400 Subject: [PATCH 2/2] isolate changes to the postgres method --- src/vanna/base/base.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index 78b167f3..e8e0bbf9 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -464,7 +464,7 @@ def get_training_data(self, **kwargs) -> pd.DataFrame: pass @abstractmethod - def remove_training_data(id: str, **kwargs) -> bool: + def remove_training_data(self, id: str, **kwargs) -> bool: """ Example: ```python @@ -1322,15 +1322,9 @@ def connect_to_bigquery( def run_sql_bigquery(sql: str) -> Union[pd.DataFrame, None]: if conn: - try: - job = conn.query(sql) - df = job.result().to_dataframe() - return df - except GoogleAPIError as error: - errors = [] - for error in error.errors: - errors.append(error["message"]) - raise errors + job = conn.query(sql) + df = job.result().to_dataframe() + return df return None self.dialect = "BigQuery SQL"