Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add keycloak auth #4600

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions examples/quickstart-pytorch/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,18 @@ learning-rate = 0.1
batch-size = 32

[tool.flwr.federations]
default = "local-simulation"
default = "my-federation"

[tool.flwr.federations.local-simulation]
options.num-supernodes = 10
[tool.flwr.federations.local-simulation.options]
num-supernodes = 10

[tool.flwr.federations.local-simulation-gpu]
options.num-supernodes = 10
options.backend.client-resources.num-cpus = 2 # each ClientApp assumes to use 2CPUs
options.backend.client-resources.num-gpus = 0.2 # at most 5 ClientApp will run in a given GPU
[tool.flwr.federations.local-simulation-gpu.options]
num-supernodes = 10

[tool.flwr.federations.local-simulation-gpu.options.backend.client-resources]
num-cpus = 2
num-gpus = 0.2

[tool.flwr.federations.my-federation]
address = "127.0.0.1:9093"
root-certificates = "certificates/ca.crt"
17 changes: 17 additions & 0 deletions src/proto/flwr/proto/exec.proto
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ service Exec {

// flwr ls command
rpc ListRuns(ListRunsRequest) returns (ListRunsResponse) {}

// Start login upon request
rpc Login(LoginRequest) returns (LoginResponse) {}

rpc GetAuthToken(GetAuthTokenRequest) returns (GetAuthTokenResponse) {}
}

message StartRunRequest {
Expand All @@ -52,3 +57,15 @@ message ListRunsResponse {
map<uint64, Run> run_dict = 1;
string now = 2;
}

message LoginRequest {}
message LoginResponse {
map<string, string> login_details = 1;
}

message GetAuthTokenRequest {
map<string, string> auth_details = 1;
}
message GetAuthTokenResponse {
map<string, string> auth_tokens = 1;
}
2 changes: 2 additions & 0 deletions src/py/flwr/cli/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .build import build
from .install import install
from .log import log
from .login import login
from .ls import ls
from .new import new
from .run import run
Expand All @@ -39,6 +40,7 @@
app.command()(install)
app.command()(log)
app.command()(ls)
app.command()(login)

typer_click_object = get_command(app)

Expand Down
21 changes: 21 additions & 0 deletions src/py/flwr/cli/login/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flower command line interface `login` command."""

from .login import login as login

__all__ = [
"login",
]
173 changes: 173 additions & 0 deletions src/py/flwr/cli/login/login.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flower command line interface `login` command."""

import sys
from logging import DEBUG
from pathlib import Path
from typing import Annotated, Any, Optional

import typer
from tomli_w import dump

from flwr.cli.build import build
from flwr.cli.config_utils import load_and_validate
from flwr.common.auth_plugin import KeycloakUserPlugin, UserAuthPlugin
from flwr.common.config import flatten_dict, get_flwr_dir, parse_config_args
from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
from flwr.common.logger import log
from flwr.proto.exec_pb2 import LoginRequest, LoginResponse # pylint: disable=E0611
from flwr.proto.exec_pb2_grpc import ExecStub

auth_plugins = {
"keycloak": KeycloakUserPlugin,
}


def on_channel_state_change(channel_connectivity: str) -> None:
"""Log channel connectivity."""
log(DEBUG, channel_connectivity)


def login(
app: Annotated[
Path,
typer.Argument(help="Path of the Flower App to run."),
] = Path("."),
federation: Annotated[
Optional[str],
typer.Argument(help="Name of the federation to login into."),
] = None,
) -> None:
"""Login to Flower SuperExec."""
typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)

pyproject_path = app / "pyproject.toml" if app else None
config, errors, warnings = load_and_validate(path=pyproject_path)

if config is None:
typer.secho(
"Project configuration could not be loaded.\n"
"pyproject.toml is invalid:\n"
+ "\n".join([f"- {line}" for line in errors]),
fg=typer.colors.RED,
bold=True,
)
sys.exit()

if warnings:
typer.secho(
"Project configuration is missing the following "
"recommended properties:\n" + "\n".join([f"- {line}" for line in warnings]),
fg=typer.colors.RED,
bold=True,
)

typer.secho("Success", fg=typer.colors.GREEN)

federation = federation or config["tool"]["flwr"]["federations"].get("default")

if federation is None:
typer.secho(
"❌ No federation name was provided and the project's `pyproject.toml` "
"doesn't declare a default federation (with a SuperExec address or an "
"`options.num-supernodes` value).",
fg=typer.colors.RED,
bold=True,
)
raise typer.Exit(code=1)

# Validate the federation exists in the configuration
federation_config = config["tool"]["flwr"]["federations"].get(federation)
if federation_config is None:
available_feds = {
fed for fed in config["tool"]["flwr"]["federations"] if fed != "default"
}
typer.secho(
f"❌ There is no `{federation}` federation declared in "
"`pyproject.toml`.\n The following federations were found:\n\n"
+ "\n".join(available_feds),
fg=typer.colors.RED,
bold=True,
)
raise typer.Exit(code=1)

if "address" not in federation_config:
typer.secho(
f"❌ The federation `{federation}` does not have `SuperExec` "
"address in its config.\n Please specify the address in "
"`pyproject.toml` and try again.",
fg=typer.colors.RED,
bold=True,
)
raise typer.Exit(code=1)

stub = create_exec_stub(app, federation_config)
login_request = LoginRequest()
login_response: LoginResponse = stub.Login(login_request)
# login_response = LoginResponse(auth_type="supertokens", auth_url="https://api.flower.ai/auth/signin")
auth_plugin = auth_plugins[login_response.login_details.get("auth_type", "")]
config = auth_plugin.login(login_response.login_details, config, federation, stub)

base_path = get_flwr_dir()
credentials_dir = base_path / ".credentials"
credentials_dir.mkdir(parents=True, exist_ok=True)

credential = credentials_dir / federation_config["address"]

with open(credential, "wb") as config_file:
dump(config, config_file)


def create_exec_stub(app: Path, federation_config: dict[str, Any]) -> ExecStub:
insecure_str = federation_config.get("insecure")
if root_certificates := federation_config.get("root-certificates"):
root_certificates_bytes = (app / root_certificates).read_bytes()
if insecure := bool(insecure_str):
typer.secho(
"❌ `root_certificates` were provided but the `insecure` parameter"
"is set to `True`.",
fg=typer.colors.RED,
bold=True,
)
raise typer.Exit(code=1)
else:
root_certificates_bytes = None
if insecure_str is None:
typer.secho(
"❌ To disable TLS, set `insecure = true` in `pyproject.toml`.",
fg=typer.colors.RED,
bold=True,
)
raise typer.Exit(code=1)
if not (insecure := bool(insecure_str)):
typer.secho(
"❌ No certificate were given yet `insecure` is set to `False`.",
fg=typer.colors.RED,
bold=True,
)
raise typer.Exit(code=1)

channel = create_channel(
server_address=federation_config["address"],
insecure=insecure,
root_certificates=root_certificates_bytes,
max_message_length=GRPC_MAX_MESSAGE_LENGTH,
interceptors=None,
)
channel.subscribe(on_channel_state_change)
stub = ExecStub(channel)

return stub
43 changes: 40 additions & 3 deletions src/py/flwr/cli/run/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import subprocess
from logging import DEBUG
from pathlib import Path
from typing import Annotated, Any, Optional
from typing import Annotated, Any, Dict, Optional

import typer

Expand All @@ -29,8 +29,11 @@
validate_federation_in_project_config,
validate_project_config,
)
from flwr.cli.run.user_interceptor import UserInterceptor
from flwr.common.auth_plugin import KeycloakUserPlugin, UserAuthPlugin
from flwr.common.config import (
flatten_dict,
get_flwr_dir,
parse_config_args,
user_config_to_configsrecord,
)
Expand All @@ -50,6 +53,11 @@
CONN_REFRESH_PERIOD = 60 # Connection refresh period for log streaming (seconds)


auth_plugins: Dict[str, UserAuthPlugin] = {
"keycloak": KeycloakUserPlugin,
}


def on_channel_state_change(channel_connectivity: str) -> None:
"""Log channel connectivity."""
log(DEBUG, channel_connectivity)
Expand Down Expand Up @@ -97,7 +105,33 @@ def run(
)

if "address" in federation_config:
_run_with_exec_api(app, federation_config, config_overrides, stream)
base_path = get_flwr_dir()
credentials_dir = base_path / ".credentials"
credentials_dir.mkdir(parents=True, exist_ok=True)

credential = credentials_dir / federation_config["address"]

config_dict = {}
with credential.open("r", encoding="utf-8") as file:
for line in file:
# Ignore empty lines and comments
line = line.strip()
if not line or line.startswith("#"):
continue

# Split the key and value
if "=" in line:
key, value = line.split("=", 1)
# Remove quotes and whitespace from keys and values
config_dict[key.strip()] = value.strip().strip('"')

auth_type = config_dict.get("auth-type")
auth_plugin: Optional[UserAuthPlugin] = None
if auth_type is not None:
auth_plugin = auth_plugins.get(auth_type)(config_dict, credential)
_run_with_exec_api(
app, federation_config, config_overrides, stream, auth_plugin
)
else:
_run_without_exec_api(app, federation_config, config_overrides, federation)

Expand All @@ -108,6 +142,7 @@ def _run_with_exec_api(
federation_config: dict[str, Any],
config_overrides: Optional[list[str]],
stream: bool,
auth_plugin: Optional[UserAuthPlugin] = None,
) -> None:

insecure, root_certificates_bytes = validate_certificate_in_federation_config(
Expand All @@ -118,7 +153,9 @@ def _run_with_exec_api(
insecure=insecure,
root_certificates=root_certificates_bytes,
max_message_length=GRPC_MAX_MESSAGE_LENGTH,
interceptors=None,
interceptors=(
UserInterceptor(auth_plugin) if auth_plugin is not None else None
),
)
channel.subscribe(on_channel_state_change)
stub = ExecStub(channel)
Expand Down
Loading
Loading