Skip to content

Commit

Permalink
postgres: implement query/query_raw using rust-postgres query_typed
Browse files Browse the repository at this point in the history
  • Loading branch information
tmm1 committed Nov 1, 2024
1 parent 67e8e11 commit 1cb4c45
Showing 1 changed file with 84 additions and 61 deletions.
145 changes: 84 additions & 61 deletions quaint/src/connector/postgres/native/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use futures::{future::FutureExt, lock::Mutex};
use lru_cache::LruCache;

Check warning on line 27 in quaint/src/connector/postgres/native/mod.rs

View workflow job for this annotation

GitHub Actions / rustfmt

Diff in /home/runner/work/prisma-engines/prisma-engines/quaint/src/connector/postgres/native/mod.rs
use native_tls::{Certificate, Identity, TlsConnector};
use postgres_native_tls::MakeTlsConnector;
use postgres_types::{Kind as PostgresKind, Type as PostgresType};
use postgres_types::{Kind as PostgresKind, Type as PostgresType, ToSql};
use std::hash::{DefaultHasher, Hash, Hasher};
use std::{
fmt::{Debug, Display},
Expand Down Expand Up @@ -540,29 +540,37 @@ impl Queryable for PostgreSql {
sql,
params,
move || async move {
let stmt = self.fetch_cached(sql, &[]).await?;

if stmt.params().len() != params.len() {
let kind = ErrorKind::IncorrectNumberOfParameters {
expected: stmt.params().len(),
actual: params.len(),
};

return Err(Error::builder(kind).build());
}
let converted_params = conversion::conv_params(params);
let param_types = conversion::params_to_types(params);
let params_with_types: Vec<(&(dyn ToSql + Sync), PostgresType)> = converted_params
.iter()
.zip(param_types)
.map(|(value, ty)| (*value as &(dyn ToSql + Sync), ty))
.collect();

// Execute the query using `query_typed`
let rows = self
.perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice()))
.perform_io(self.client.0.query_typed(sql, params_with_types.as_slice()))
.await?;

let col_types = stmt
.columns()
.iter()
.map(|c| PGColumnType::from_pg_type(c.type_()))
.map(ColumnType::from)
.collect::<Vec<_>>();
let mut result = ResultSet::new(stmt.to_column_names(), col_types, Vec::new());
// Extract column information from the first row, if available
let (col_types, column_names) = if let Some(row) = rows.first() {
let columns = row.columns();
let col_types = columns
.iter()
.map(|c| PGColumnType::from_pg_type(c.type_()))
.map(ColumnType::from)
.collect::<Vec<_>>();
let column_names = columns.iter().map(|c| c.name().to_string()).collect();

(col_types, column_names)
} else {
(Vec::new(), Vec::new())
};

let mut result = ResultSet::new(column_names, col_types, Vec::new());

// Process each row in the result set
for row in rows {
result.rows.push(row.get_result_row()?);
}
Expand All @@ -582,28 +590,35 @@ impl Queryable for PostgreSql {
sql,
params,
move || async move {
let stmt = self.fetch_cached(sql, params).await?;

if stmt.params().len() != params.len() {
let kind = ErrorKind::IncorrectNumberOfParameters {
expected: stmt.params().len(),
actual: params.len(),
};

return Err(Error::builder(kind).build());
}

let col_types = stmt
.columns()
let converted_params = conversion::conv_params(params);
let param_types = conversion::params_to_types(params);
let params_with_types: Vec<(&(dyn ToSql + Sync), PostgresType)> = converted_params
.iter()
.map(|c| PGColumnType::from_pg_type(c.type_()))
.map(ColumnType::from)
.collect::<Vec<_>>();
.zip(param_types)
.map(|(value, ty)| (*value as &(dyn ToSql + Sync), ty))
.collect();

// Execute the query using `query_typed`
let rows = self
.perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice()))
.perform_io(self.client.0.query_typed(sql, params_with_types.as_slice()))
.await?;

let mut result = ResultSet::new(stmt.to_column_names(), col_types, Vec::new());
// Extract column information from the first row, if available
let (col_types, column_names) = if let Some(row) = rows.first() {
let columns = row.columns();
let col_types = columns
.iter()
.map(|c| PGColumnType::from_pg_type(c.type_()))
.map(ColumnType::from)
.collect::<Vec<_>>();
let column_names = columns.iter().map(|c| c.name().to_string()).collect();

(col_types, column_names)
} else {
(Vec::new(), Vec::new())
};

let mut result = ResultSet::new(column_names, col_types, Vec::new());

for row in rows {
result.rows.push(row.get_result_row()?);
Expand Down Expand Up @@ -705,20 +720,24 @@ impl Queryable for PostgreSql {
sql,
params,
move || async move {
let stmt = self.fetch_cached(sql, &[]).await?;

if stmt.params().len() != params.len() {
let kind = ErrorKind::IncorrectNumberOfParameters {
expected: stmt.params().len(),
actual: params.len(),
};

return Err(Error::builder(kind).build());
}
let converted_params = conversion::conv_params(params);
let param_types = conversion::params_to_types(params);
let params_with_types: Vec<(&(dyn ToSql + Sync), PostgresType)> = converted_params
.iter()
.zip(param_types)
.map(|(value, ty)| (*value as &(dyn ToSql + Sync), ty))
.collect();

Check warning on line 729 in quaint/src/connector/postgres/native/mod.rs

View workflow job for this annotation

GitHub Actions / rustfmt

Diff in /home/runner/work/prisma-engines/prisma-engines/quaint/src/connector/postgres/native/mod.rs

let changes = self
.perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice()))
.await?;
.perform_io(self.client.0.query_typed_raw::<&(dyn ToSql + Sync), _>(
sql,
params_with_types.as_slice().iter()
.map(|(v, t)| (*v, t.clone()))
.collect::<Vec<_>>()
))
.await?
.rows_affected()
.unwrap_or(0);

Ok(changes)
},
Expand All @@ -735,20 +754,24 @@ impl Queryable for PostgreSql {
sql,
params,
move || async move {
let stmt = self.fetch_cached(sql, params).await?;

if stmt.params().len() != params.len() {
let kind = ErrorKind::IncorrectNumberOfParameters {
expected: stmt.params().len(),
actual: params.len(),
};

return Err(Error::builder(kind).build());
}
let converted_params = conversion::conv_params(params);
let param_types = conversion::params_to_types(params);
let params_with_types: Vec<(&(dyn ToSql + Sync), PostgresType)> = converted_params
.iter()
.zip(param_types)
.map(|(value, ty)| (*value as &(dyn ToSql + Sync), ty))
.collect();

Check warning on line 763 in quaint/src/connector/postgres/native/mod.rs

View workflow job for this annotation

GitHub Actions / rustfmt

Diff in /home/runner/work/prisma-engines/prisma-engines/quaint/src/connector/postgres/native/mod.rs

let changes = self
.perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice()))
.await?;
.perform_io(self.client.0.query_typed_raw::<&(dyn ToSql + Sync), _>(
sql,
params_with_types.as_slice().iter()
.map(|(v, t)| (*v, t.clone()))
.collect::<Vec<_>>()
))
.await?
.rows_affected()
.unwrap_or(0);

Ok(changes)
},
Expand Down

0 comments on commit 1cb4c45

Please sign in to comment.