Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix:Add Type Safety and Validation for Provider Inputs #2353

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 26 additions & 13 deletions keep-ui/app/providers/provider-form.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ import "./provider-form.css";
import { useProviders } from "@/utils/hooks/useProviders";
import TimeAgo from "react-timeago";
import { toast } from "react-toastify";
import { providerInputSchema } from "../providers/providerInputSchema";

type ProviderFormProps = {
provider: Provider;
Expand Down Expand Up @@ -297,24 +298,36 @@ const ProviderForm = ({

const validateForm = (updatedFormValues) => {
const errors = {};
for (const [configKey, method] of Object.entries(provider.config)) {
if (!formValues[configKey] && method.required) {
errors[configKey] = true;
}
if (
"validation" in method &&
formValues[configKey] &&
!method.validation(updatedFormValues[configKey])
) {
errors[configKey] = true;
}
if (!formValues.provider_name) {
errors["provider_name"] = true;
try {
providerInputSchema.validateSync(updatedFormValues, { abortEarly: false });
setInputErrors({});
for (const [configKey, method] of Object.entries(provider.config)) {
if (!formValues[configKey] && method.required) {
errors[configKey] = true;
}
if (
"validation" in method &&
formValues[configKey] &&
!method.validation(updatedFormValues[configKey])
) {
errors[configKey] = true;
}
if (!formValues.provider_name) {
errors["provider_name"] = true;
}
}
return {};
}
catch (validationErrors) {
const errors = validationErrors.inner.reduce((acc, error) => {
acc[error.path] = error.message;
return acc;
}, {});

setInputErrors(errors);
return errors;
};
}

const handleInputChange = (event: React.ChangeEvent<HTMLInputElement>) => {
const { name, type } = event.target;
Expand Down
16 changes: 16 additions & 0 deletions keep-ui/app/providers/providerInputSchema.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import * as Yup from 'yup';

export const providerInputSchema = Yup.object().shape({
provider_name: Yup.string().required('Provider name is required'),
url: Yup.string()
.url('Invalid URL format')
.matches(/^(https?:\/\/)?(localhost|[\w.-]+)(:\d+)?\/?$/, 'URL must be in a valid format')
.required('URL is required'),
host: Yup.string()
.matches(/^[a-zA-Z0-9.-]+$/, 'Invalid host format')
.required('Host is required'),
port: Yup.number()
.min(1, 'Port number must be between 1 and 65535')
.max(65535, 'Port number must be between 1 and 65535'),
// Add any other fields as necessary
});
4 changes: 4 additions & 0 deletions keep-ui/app/providers/providers.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@ export interface ProviderAuthConfig {
name: string;
description: string;
hint?: string;
host: string;
placeholder?: string;
provider_name: string;
port: number;
url: string;
validation: string; // regex
required?: boolean;
value?: string;
Expand Down
38 changes: 38 additions & 0 deletions keep-ui/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions keep-ui/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@
"yallist": "^4.0.0",
"yaml": "^2.2.2",
"yocto-queue": "^0.1.0",
"yup": "^1.4.0",
"zod": "^3.22.3"
},
"devDependencies": {
Expand Down
18 changes: 15 additions & 3 deletions keep/api/models/provider.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from datetime import datetime
from typing import Literal

from pydantic import BaseModel

from pydantic import BaseModel, Field, HttpUrl, validator
from keep.providers.models.provider_config import ProviderScope
from keep.providers.models.provider_method import ProviderMethod

Expand All @@ -11,6 +9,20 @@ class ProviderAlertsCountResponseDTO(BaseModel):
count: int




class ProviderConfigInput(BaseModel):
provider_name: str = Field(..., min_length=3, max_length=50)
provider_url: HttpUrl
port: int = Field(..., gt=0, lt=65536) # Ensures port is within valid range
api_key: str = Field(..., min_length=10) # Example: Min length of 10 characters

@validator("provider_name")
def name_cannot_contain_special_chars(cls, v):
if not v.isalnum():
raise ValueError("Provider name should be alphanumeric")
return v

class Provider(BaseModel):
id: str | None = None
display_name: str
Expand Down
30 changes: 29 additions & 1 deletion keep/api/routes/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
GetAlertException,
ProviderMethodException,
)
from keep.api.models.provider import ProviderConfigInput
from keep.providers.providers_factory import ProvidersFactory
from keep.providers.providers_factory import ProvidersFactory
from keep.providers.providers_service import ProvidersService
from keep.secretmanager.secretmanagerfactory import SecretManagerFactory
Expand All @@ -34,7 +36,33 @@
PROVIDER_DISTRIBUTION_ENABLED = config(
"PROVIDER_DISTRIBUTION_ENABLED", cast=bool, default=True
)

@router.post(
"/{provider_type}/{provider_id}/alerts",
description="Push new alerts to the provider",
)
def add_alert(
provider_type: str,
provider_id: str,
alert: ProviderConfigInput, # Automatically validates the input data
authenticated_entity: AuthenticatedEntity = Depends(
IdentityManagerFactory.get_auth_verifier(["write:alert"])
),
) -> JSONResponse:
try:
# Proceed with business logic if validation passes
tenant_id = authenticated_entity.tenant_id
context_manager = ContextManager(tenant_id=tenant_id)
secret_manager = SecretManagerFactory.get_secret_manager(context_manager)
provider_config = secret_manager.read_secret(
f"{tenant_id}_{provider_type}_{provider_id}", is_json=True
)
provider = ProvidersFactory.get_provider(
context_manager, provider_id, provider_type, provider_config
)
provider.deploy_alert(alert.dict()) # Convert Pydantic model to dict
return JSONResponse(status_code=200, content={"message": "deployed"})
except Exception as e:
return JSONResponse(status_code=500, content=str(e))

def _is_localhost():
# TODO - there are more "advanced" cases that we don't catch here
Expand Down
Loading