Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[client-v2] Added implementation for Bearer token auth #1904

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 49 additions & 5 deletions client-v2/src/main/java/com/clickhouse/client/api/Client.java
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,8 @@ public class Client implements AutoCloseable {
private final ColumnToMethodMatchingStrategy columnToMethodMatchingStrategy;

private Client(Set<String> endpoints, Map<String,String> configuration, boolean useNewImplementation,
ExecutorService sharedOperationExecutor, ColumnToMethodMatchingStrategy columnToMethodMatchingStrategy) {
ExecutorService sharedOperationExecutor, ColumnToMethodMatchingStrategy columnToMethodMatchingStrategy,
Supplier<String> bearerTokenSupplier) {
this.endpoints = endpoints;
this.configuration = configuration;
this.endpoints.forEach(endpoint -> {
Expand All @@ -169,7 +170,7 @@ private Client(Set<String> endpoints, Map<String,String> configuration, boolean
}
this.useNewImplementation = useNewImplementation;
if (useNewImplementation) {
this.httpClientHelper = new HttpAPIClientHelper(configuration);
this.httpClientHelper = new HttpAPIClientHelper(configuration, bearerTokenSupplier);
LOG.info("Using new http client implementation");
} else {
this.oldClient = ClientV1AdaptorHelper.createClient(configuration);
Expand Down Expand Up @@ -219,6 +220,8 @@ public static class Builder {
private ExecutorService sharedOperationExecutor = null;
private ColumnToMethodMatchingStrategy columnToMethodMatchingStrategy;

private Supplier<String> bearerTokenSupplier = null;

public Builder() {
this.endpoints = new HashSet<>();
this.configuration = new HashMap<String, String>();
Expand Down Expand Up @@ -886,6 +889,32 @@ public Builder useHTTPBasicAuth(boolean useBasicAuth) {
return this;
}

/**
* Specifies whether to use Bearer Authentication and what token to use.
* The token will be sent as is, so it should be encoded before passing to this method.
*
* @param bearerToken - token to use
* @return same instance of the builder
*/
public Builder useBearerTokenAuth(String bearerToken) {
this.httpHeader("Authorization", "Bearer " + bearerToken);
return this;
}

/**
* Specifies a supplier for a bearer tokens. It is useful when token should be refreshed.
* Supplier is called each time before sending a request.
* Supplier should return encoded token.
* This configuration cannot be used with {@link #useBearerTokenAuth(String)}.
*
* @param tokenSupplier - token supplier
* @return
*/
public Builder useBearerTokenAuth(Supplier<String> tokenSupplier) {
this.bearerTokenSupplier = tokenSupplier;
return this;
}

public Client build() {
setDefaults();

Expand All @@ -896,15 +925,22 @@ public Client build() {
// check if username and password are empty. so can not initiate client?
if (!this.configuration.containsKey("access_token") &&
(!this.configuration.containsKey("user") || !this.configuration.containsKey("password")) &&
!MapUtils.getFlag(this.configuration, "ssl_authentication")) {
throw new IllegalArgumentException("Username and password (or access token, or SSL authentication) are required");
!MapUtils.getFlag(this.configuration, "ssl_authentication", false) &&
!this.configuration.containsKey(ClientSettings.HTTP_HEADER_PREFIX + "Authorization") &&
this.bearerTokenSupplier == null) {
throw new IllegalArgumentException("Username and password (or access token or SSL authentication or pre-define Authorization header) are required");
}

if (this.configuration.containsKey("ssl_authentication") &&
(this.configuration.containsKey("password") || this.configuration.containsKey("access_token"))) {
throw new IllegalArgumentException("Only one of password, access token or SSL authentication can be used per client.");
}

if (this.configuration.containsKey(ClientSettings.HTTP_HEADER_PREFIX + "Authorization") &&
this.bearerTokenSupplier != null) {
throw new IllegalArgumentException("Bearer token supplier cannot be used with a predefined Authorization header");
}

if (this.configuration.containsKey("ssl_authentication") &&
!this.configuration.containsKey(ClickHouseClientOption.SSL_CERTIFICATE.getKey())) {
throw new IllegalArgumentException("SSL authentication requires a client certificate");
Expand Down Expand Up @@ -943,7 +979,15 @@ public Client build() {
throw new IllegalArgumentException("Nor server timezone nor specific timezone is set");
}

return new Client(this.endpoints, this.configuration, this.useNewImplementation, this.sharedOperationExecutor, this.columnToMethodMatchingStrategy);
// check for only new implementation configuration
if (!this.useNewImplementation) {
if (this.bearerTokenSupplier != null) {
throw new IllegalArgumentException("Bearer token supplier cannot be used with old implementation");
}
}

return new Client(this.endpoints, this.configuration, this.useNewImplementation, this.sharedOperationExecutor,
this.columnToMethodMatchingStrategy, this.bearerTokenSupplier);
}

private static final int DEFAULT_NETWORK_BUFFER_SIZE = 300_000;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
import java.net.NoRouteToHostException;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URLEncoder;
import java.net.UnknownHostException;
import java.nio.charset.StandardCharsets;
import java.security.NoSuchAlgorithmException;
Expand All @@ -74,6 +73,7 @@
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.function.Supplier;

public class HttpAPIClientHelper {
private static final Logger LOG = LoggerFactory.getLogger(Client.class);
Expand All @@ -90,9 +90,12 @@ public class HttpAPIClientHelper {

private final Set<ClientFaultCause> defaultRetryCauses;

public HttpAPIClientHelper(Map<String, String> configuration) {
private final Supplier<String> bearerTokenSupplier;

public HttpAPIClientHelper(Map<String, String> configuration, Supplier<String> bearerTokenSupplier) {
this.chConfiguration = configuration;
this.httpClient = createHttpClient();
this.bearerTokenSupplier = bearerTokenSupplier;

RequestConfig.Builder reqConfBuilder = RequestConfig.custom();
MapUtils.applyLong(chConfiguration, "connection_request_timeout",
Expand Down Expand Up @@ -401,6 +404,8 @@ private void addHeaders(HttpPost req, Map<String, String> chConfig, Map<String,
if (MapUtils.getFlag(chConfig, "ssl_authentication", false)) {
req.addHeader(ClickHouseHttpProto.HEADER_DB_USER, chConfig.get(ClickHouseDefaults.USER.getKey()));
req.addHeader(ClickHouseHttpProto.HEADER_SSL_CERT_AUTH, "on");
} else if (bearerTokenSupplier != null) {
req.addHeader(HttpHeaders.AUTHORIZATION, "Bearer " + bearerTokenSupplier.get());
} else if (chConfig.getOrDefault(ClientSettings.HTTP_USE_BASIC_AUTH, "true").equalsIgnoreCase("true")) {
req.addHeader(HttpHeaders.AUTHORIZATION, "Basic " + Base64.getEncoder().encodeToString(
(chConfig.get(ClickHouseDefaults.USER.getKey()) + ":" + chConfig.get(ClickHouseDefaults.PASSWORD.getKey())).getBytes(StandardCharsets.UTF_8)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import java.io.ByteArrayInputStream;
import java.net.Socket;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.time.temporal.ChronoUnit;
import java.util.Arrays;
import java.util.Base64;
Expand All @@ -44,8 +45,11 @@
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;

import static com.github.tomakehurst.wiremock.stubbing.Scenario.STARTED;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.fail;

public class HttpTransportTests extends BaseIntegrationTest {
Expand All @@ -59,7 +63,6 @@ public void testConnectionTTL(Long connectionTtl, Long keepAlive, int openSocket
ClickHouseNode server = getServer(ClickHouseProtocol.HTTP);

int proxyPort = new Random().nextInt(1000) + 10000;
System.out.println("proxyPort: " + proxyPort);
ConnectionCounterListener connectionCounter = new ConnectionCounterListener();
WireMockServer proxy = new WireMockServer(WireMockConfiguration
.options().port(proxyPort)
Expand Down Expand Up @@ -147,7 +150,6 @@ public void closed(Socket socket) {
public void testConnectionRequestTimeout() {

int serverPort = new Random().nextInt(1000) + 10000;
System.out.println("proxyPort: " + serverPort);
ConnectionCounterListener connectionCounter = new ConnectionCounterListener();
WireMockServer proxy = new WireMockServer(WireMockConfiguration
.options().port(serverPort)
Expand Down Expand Up @@ -638,4 +640,69 @@ public void testErrorWithSendProgressHeaders() throws Exception {
}
}
}

@Test(groups = { "integration" })
public void testBearerTokenAuth() throws Exception {
WireMockServer mockServer = new WireMockServer( WireMockConfiguration
.options().port(9090).notifier(new ConsoleNotifier(false)));
mockServer.start();

String jwtToken1 = Arrays.stream(
new String[]{"header", "payload", "signature"})
.map(s -> Base64.getEncoder().encodeToString(s.getBytes(StandardCharsets.UTF_8)))
.reduce((s1, s2) -> s1 + "." + s2).get();
try (Client client = new Client.Builder().addEndpoint(Protocol.HTTP, "localhost", mockServer.port(), false)
.useBearerTokenAuth(jwtToken1)
.build()) {

mockServer.addStubMapping(WireMock.post(WireMock.anyUrl())
.withHeader("Authorization", WireMock.equalTo("Bearer " + jwtToken1))
.willReturn(WireMock.aResponse()
.withHeader("X-ClickHouse-Summary",
"{ \"read_bytes\": \"10\", \"read_rows\": \"1\"}")).build());

try (QueryResponse response = client.query("SELECT 1").get(1, TimeUnit.SECONDS)) {
Assert.assertEquals(response.getReadBytes(), 10);
} catch (Exception e) {
Assert.fail("Unexpected exception", e);
}
}

String jwtToken2 = Arrays.stream(
new String[]{"header2", "payload2", "signature2"})
.map(s -> Base64.getEncoder().encodeToString(s.getBytes(StandardCharsets.UTF_8)))
.reduce((s1, s2) -> s1 + "." + s2).get();
final AtomicInteger callCount = new AtomicInteger(0);
Supplier<String> tokenSupplier = () -> {
callCount.incrementAndGet();
return jwtToken2;
};

mockServer.addStubMapping(WireMock.post(WireMock.anyUrl())
.withHeader("Authorization", WireMock.equalTo("Bearer " + jwtToken2))
.willReturn(WireMock.aResponse()
.withHeader("X-ClickHouse-Summary",
"{ \"read_bytes\": \"10\", \"read_rows\": \"1\"}")).build());

try (Client client = new Client.Builder().addEndpoint(Protocol.HTTP, "localhost", mockServer.port(), false)
.useBearerTokenAuth(tokenSupplier)
.build()) {

for (int i = 0; i < 3; i++ ) {

try (QueryResponse response = client.query("SELECT 1").get(1, TimeUnit.SECONDS)) {
Assert.assertEquals(response.getReadBytes(), 10);
} catch (Exception e) {
Assert.fail("Unexpected exception", e);
}
}
}

assertEquals(callCount.get(), 3);

assertThrows(IllegalArgumentException.class, () -> {
new Client.Builder().useBearerTokenAuth("token")
.useBearerTokenAuth(() -> "token2").build();
});
}
}
Loading