Skip to content

Commit

Permalink
resolves vladmihalcea#674 - extract parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
ymajoros2 committed Nov 8, 2023
1 parent d1d782c commit c4e39ef
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 55 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package io.hypersistence.utils.hibernate.query;

import io.hypersistence.utils.hibernate.util.ReflectionUtils;
import jakarta.persistence.Parameter;
import jakarta.persistence.Query;
import org.hibernate.query.spi.DomainQueryExecutionContext;
import org.hibernate.query.spi.QueryImplementor;
import org.hibernate.query.spi.QueryInterpretationCache;
import org.hibernate.query.spi.QueryParameterBindings;
import org.hibernate.query.spi.SelectQueryPlan;
import org.hibernate.query.sqm.internal.ConcreteSqmSelectQueryPlan;
import org.hibernate.query.sqm.internal.DomainParameterXref;
Expand All @@ -16,7 +18,11 @@
import java.lang.reflect.Field;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.function.Supplier;
import java.util.stream.Collectors;

/**
* The {@link SQLExtractor} allows you to extract the
Expand All @@ -39,67 +45,110 @@ protected SQLExtractor() {
* @return the underlying SQL generated by the provided JPA query
*/
public static String from(Query query) {
Query hibernateQuery = getHibernateQuery(query);
if (hibernateQuery instanceof SqmInterpretationsKey.InterpretationsKeySource &&
hibernateQuery instanceof QueryImplementor &&
hibernateQuery instanceof QuerySqmImpl) {
QueryInterpretationCache.Key cacheKey = SqmInterpretationsKey.createInterpretationsKey((SqmInterpretationsKey.InterpretationsKeySource) hibernateQuery);
QuerySqmImpl querySqm = (QuerySqmImpl) hibernateQuery;
Supplier buildSelectQueryPlan = () -> ReflectionUtils.invokeMethod(querySqm, "buildSelectQueryPlan");
SelectQueryPlan plan = cacheKey != null ? ((QueryImplementor) hibernateQuery).getSession().getFactory().getQueryEngine()
.getInterpretationCache()
.resolveSelectQueryPlan(cacheKey, buildSelectQueryPlan) :
(SelectQueryPlan) buildSelectQueryPlan.get();
if (plan instanceof ConcreteSqmSelectQueryPlan) {
ConcreteSqmSelectQueryPlan selectQueryPlan = (ConcreteSqmSelectQueryPlan) plan;
Object cacheableSqmInterpretation = ReflectionUtils.getFieldValueOrNull(selectQueryPlan, "cacheableSqmInterpretation");
if (cacheableSqmInterpretation == null) {
DomainQueryExecutionContext domainQueryExecutionContext = DomainQueryExecutionContext.class.cast(querySqm);
cacheableSqmInterpretation = ReflectionUtils.invokeStaticMethod(
ReflectionUtils.getMethod(
ConcreteSqmSelectQueryPlan.class,
"buildCacheableSqmInterpretation",
SqmSelectStatement.class,
DomainParameterXref.class,
DomainQueryExecutionContext.class
),
ReflectionUtils.getFieldValueOrNull(selectQueryPlan, "sqm"),
ReflectionUtils.getFieldValueOrNull(selectQueryPlan, "domainParameterXref"),
domainQueryExecutionContext
);
}
if (cacheableSqmInterpretation != null) {
JdbcOperationQuerySelect jdbcSelect = ReflectionUtils.getFieldValueOrNull(cacheableSqmInterpretation, "jdbcSelect");
if (jdbcSelect != null) {
return jdbcSelect.getSqlString();
}
return getSqmQueryOptional(query)
.map(SQLExtractor::getSQLFromSqmQuery)
.orElseGet(() -> ReflectionUtils.invokeMethod(query, "getQueryString"));
}

private static String getSQLFromSqmQuery(QuerySqmImpl<?> querySqm) {
QueryInterpretationCache.Key cacheKey = SqmInterpretationsKey.createInterpretationsKey(querySqm);
Supplier<SelectQueryPlan<Object>> buildSelectQueryPlan = () -> ReflectionUtils.invokeMethod(querySqm, "buildSelectQueryPlan");
SelectQueryPlan<Object> plan = cacheKey != null ? ((QueryImplementor<?>) querySqm).getSession().getFactory().getQueryEngine()
.getInterpretationCache()
.resolveSelectQueryPlan(cacheKey, buildSelectQueryPlan) :
buildSelectQueryPlan.get();
if (plan instanceof ConcreteSqmSelectQueryPlan) {
ConcreteSqmSelectQueryPlan<?> selectQueryPlan = (ConcreteSqmSelectQueryPlan<?>) plan;
Object cacheableSqmInterpretation = ReflectionUtils.getFieldValueOrNull(selectQueryPlan, "cacheableSqmInterpretation");
if (cacheableSqmInterpretation == null) {
cacheableSqmInterpretation = ReflectionUtils.invokeStaticMethod(
ReflectionUtils.getMethod(
ConcreteSqmSelectQueryPlan.class,
"buildCacheableSqmInterpretation",
SqmSelectStatement.class,
DomainParameterXref.class,
DomainQueryExecutionContext.class
),
ReflectionUtils.getFieldValueOrNull(selectQueryPlan, "sqm"),
ReflectionUtils.getFieldValueOrNull(selectQueryPlan, "domainParameterXref"),
querySqm
);
}
if (cacheableSqmInterpretation != null) {
JdbcOperationQuerySelect jdbcSelect = ReflectionUtils.getFieldValueOrNull(cacheableSqmInterpretation, "jdbcSelect");
if (jdbcSelect != null) {
return jdbcSelect.getSqlString();
}
}
}
return ReflectionUtils.invokeMethod(hibernateQuery, "getQueryString");
return querySqm.getQueryString();
}

public static List<Object> getSQLParameterValues(Query query) {
return getSqmQueryOptional(query)
.map(SQLExtractor::getParametersFromInternalQuerySqm)
.orElseGet(() -> getSQLParametersFromJPAQuery(query));
}

/**
* Retrieves the parameters from the internal query SQM.
*
* @param querySqm the internal query SQM object
* @return a list of parameter values
*/
private static List<Object> getParametersFromInternalQuerySqm(QuerySqmImpl<?> querySqm) {
List<Object> parameterValues = new ArrayList<>();

QueryParameterBindings parameterBindings = querySqm.getParameterBindings();
parameterBindings.visitBindings((queryParameterImplementor, queryParameterBinding) -> {
Object value = queryParameterBinding.getBindValue();
parameterValues.add(value);
});

return parameterValues;
}

/**
* Get parameters from JPA query without any magic or Hibernate implementation tricks. Order is probably lost in current Hibernate versions.
*
* @param query
* @return
*/
private static List<Object> getSQLParametersFromJPAQuery(Query query) {
return query.getParameters()
.stream()
.map(Parameter::getPosition)
.map(query::getParameter)
.collect(Collectors.toList());
}


/**
* Get the unproxied hibernate query underlying the provided query object.
*
* @param query JPA query
* @return the unproxied Hibernate query, or original query
* @return the unproxied Hibernate query, or original query if there is no proxy, or null if it's not an Hibernate query of required type
*/
private static Query getHibernateQuery(Query query) {
private static Optional<QuerySqmImpl<?>> getSqmQueryOptional(Query query) {
try {
if (query instanceof QuerySqmImpl || !Proxy.isProxyClass(query.getClass())) {
return query;
if (query instanceof QuerySqmImpl) {
QuerySqmImpl<?> querySqm = (QuerySqmImpl<?>) query;
return Optional.of(querySqm);
}
if (!Proxy.isProxyClass(query.getClass())) {
return Optional.empty();
}
// is proxyied, get it out
InvocationHandler invocationHandler = Proxy.getInvocationHandler(query);
Class<?> innerClass = invocationHandler.getClass();
Field targetField = innerClass.getDeclaredField("target");
targetField.setAccessible(true);
return (Query) targetField.get(invocationHandler);
QuerySqmImpl<?> querySqm = (QuerySqmImpl<?>) targetField.get(invocationHandler);
return Optional.of(querySqm);
} catch (NoSuchFieldException exception) {
return query; // seems it cannot extract it, probably not a hibernate proxy
return Optional.empty(); // not an Hibernate query
} catch (IllegalAccessException exception) {
throw new RuntimeException(exception);
throw new IllegalStateException(exception);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import jakarta.persistence.Query;
import jakarta.persistence.Table;
import jakarta.persistence.Tuple;
import jakarta.persistence.TypedQuery;
import jakarta.persistence.criteria.CriteriaBuilder;
import jakarta.persistence.criteria.CriteriaQuery;
import jakarta.persistence.criteria.Join;
Expand All @@ -20,7 +21,9 @@
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.time.LocalDate;
import java.util.List;

import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;

/**
Expand All @@ -39,15 +42,7 @@ protected Class<?>[] entities() {
@Test
public void testJPQL() {
doInJPA(entityManager -> {
Query jpql = entityManager
.createQuery(
"select " +
" YEAR(p.createdOn) as year, " +
" count(p) as postCount " +
"from " +
" Post p " +
"group by " +
" YEAR(p.createdOn)", Tuple.class);
Query jpql = createTestJPQL(entityManager);

String sql = SQLExtractor.from(jpql);

Expand All @@ -60,11 +55,10 @@ public void testJPQL() {
);
});
}

@Test
public void testCriteriaAPI() {
doInJPA(entityManager -> {
Query criteriaQuery = createTestQuery(entityManager);
Query criteriaQuery = createTestCriteriaQuery(entityManager);

String sql = SQLExtractor.from(criteriaQuery);

Expand All @@ -81,7 +75,7 @@ public void testCriteriaAPI() {
@Test
public void testCriteriaAPIWithProxy() {
doInJPA(entityManager -> {
Query criteriaQuery = createTestQuery(entityManager);
Query criteriaQuery = createTestCriteriaQuery(entityManager);
Query proxiedQuery = proxy(criteriaQuery);

String sql = SQLExtractor.from(proxiedQuery);
Expand All @@ -96,11 +90,59 @@ public void testCriteriaAPIWithProxy() {
});
}

@Test
public void testJPQLGetSQLParameters() {
doInJPA(entityManager -> {
Query jpql = createTestJPQL(entityManager);

List<?> parameters = SQLExtractor.getSQLParameterValues(jpql);

assertFalse(parameters.isEmpty());

LOGGER.info(
"The Criteria API query: [\n{}\n]\nhas following SQL parameters: \n{}\n",
jpql.unwrap(org.hibernate.query.Query.class).getQueryString(),
parameters
);
});
}

@Test
public void testCriteriaGetSQLParameters() {
doInJPA(entityManager -> {
Query criteriaQuery = createTestCriteriaQuery(entityManager);

List<?> parameters = SQLExtractor.getSQLParameterValues(criteriaQuery);

assertFalse(parameters.isEmpty());

LOGGER.info(
"The Criteria API query: [\n{}\n]\nhas following SQL parameters: \n{}\n",
criteriaQuery.unwrap(org.hibernate.query.Query.class).getQueryString(),
parameters
);
});
}

private static Query proxy(Query criteriaQuery) {
return (Query) Proxy.newProxyInstance(Query.class.getClassLoader(), new Class[]{Query.class}, new HibernateLikeInvocationHandler(criteriaQuery));
}

private static Query createTestQuery(EntityManager entityManager) {
private static Query createTestJPQL(EntityManager entityManager) {
Query jpql = entityManager
.createQuery(
"select " +
" YEAR(p.createdOn) as year, " +
" count(p) as postCount " +
"from Post p " +
"where p.title like :titleTemplate " +
"group by YEAR(p.createdOn) ",
Tuple.class);
jpql.setParameter("titleTemplate", "%Java%");
return jpql;
}

private static Query createTestCriteriaQuery(EntityManager entityManager) {
CriteriaBuilder builder = entityManager.getCriteriaBuilder();

CriteriaQuery<PostComment> criteria = builder.createQuery(PostComment.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import java.sql.Connection;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Date;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
Expand Down Expand Up @@ -103,6 +104,8 @@ public static class Post {
@Type(PostgreSQLEnumType.class)
private PostStatus status;

private Date createdOn;

public Long getId() {
return id;
}
Expand All @@ -126,5 +129,13 @@ public PostStatus getStatus() {
public void setStatus(PostStatus status) {
this.status = status;
}

public Date getCreatedOn() {
return createdOn;
}

public void setCreatedOn(Date createdOn) {
this.createdOn = createdOn;
}
}
}

0 comments on commit c4e39ef

Please sign in to comment.