Skip to content

Commit

Permalink
Vector Search APIs (#55)
Browse files Browse the repository at this point in the history
* adding api calls + data structure tests

* Updating NEWS.md & pkgdown YAML

* adding tests + minor fixes

* Increment version number to 0.2.4

---------

Co-authored-by: Zac Davies <[email protected]>
  • Loading branch information
zacdav-db and Zac Davies authored Jul 12, 2024
1 parent c0b307a commit 7c1e83d
Show file tree
Hide file tree
Showing 31 changed files with 2,168 additions and 2 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: brickster
Title: R Toolkit for Databricks
Version: 0.2.3
Version: 0.2.4
Authors@R:
c(
person(given = "Zac",
Expand Down
23 changes: 23 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,20 @@ export(db_volume_file_exists)
export(db_volume_list)
export(db_volume_read)
export(db_volume_write)
export(db_vs_endpoints_create)
export(db_vs_endpoints_delete)
export(db_vs_endpoints_get)
export(db_vs_endpoints_list)
export(db_vs_indexes_create)
export(db_vs_indexes_delete)
export(db_vs_indexes_delete_data)
export(db_vs_indexes_get)
export(db_vs_indexes_list)
export(db_vs_indexes_query)
export(db_vs_indexes_query_next_page)
export(db_vs_indexes_scan)
export(db_vs_indexes_sync)
export(db_vs_indexes_upsert_data)
export(db_workspace_delete)
export(db_workspace_export)
export(db_workspace_get_status)
Expand All @@ -123,8 +137,12 @@ export(db_workspace_list)
export(db_workspace_mkdirs)
export(db_wsid)
export(dbfs_storage_info)
export(delta_sync_index_spec)
export(direct_access_index_spec)
export(docker_image)
export(email_notifications)
export(embedding_source_column)
export(embedding_vector_column)
export(file_storage_info)
export(gcp_attributes)
export(get_and_start_cluster)
Expand All @@ -143,8 +161,12 @@ export(is.cluster_autoscale)
export(is.cluster_log_conf)
export(is.cron_schedule)
export(is.dbfs_storage_info)
export(is.delta_sync_index)
export(is.direct_access_index)
export(is.docker_image)
export(is.email_notifications)
export(is.embedding_source_column)
export(is.embedding_vector_column)
export(is.file_storage_info)
export(is.gcp_attributes)
export(is.git_source)
Expand All @@ -167,6 +189,7 @@ export(is.spark_jar_task)
export(is.spark_python_task)
export(is.spark_submit_task)
export(is.valid_task_type)
export(is.vector_search_index_spec)
export(job_task)
export(job_tasks)
export(lib_cran)
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# brickster 0.2.4

# brickster 0.2.3

* Adding NEWS.md
Expand All @@ -6,3 +8,4 @@
`DATABRICKS_TOKEN` isn't specified (e.g. `db_token()` returns `NULL`)
* Updating authentication vignette to include information on OAuth
* Updating README.md to include quick start and clearer information
* Adding vector search index functions
251 changes: 251 additions & 0 deletions R/data-structures.R
Original file line number Diff line number Diff line change
Expand Up @@ -1350,3 +1350,254 @@ is.job_task <- function(x) {
inherits(x, "JobTaskSettings")
}


#' Embedding Source Column
#'
#' Builds the description of a single embedding source column: the column
#' name paired with the name of the embedding model endpoint.
#'
#' @param name Name of the column.
#' @param model_endpoint_name Name of the embedding model endpoint, stored in
#' the result as `embedding_model_endpoint_name`.
#'
#' @return A list of class `EmbeddingSourceColumn` with elements `name` and
#' `embedding_model_endpoint_name`.
#'
#' @family Vector Search API
#'
#' @export
embedding_source_column <- function(name, model_endpoint_name) {
  structure(
    list(
      name = name,
      embedding_model_endpoint_name = model_endpoint_name
    ),
    class = c("EmbeddingSourceColumn", "list")
  )
}

#' Check for an EmbeddingSourceColumn object
#'
#' @param x Any R object.
#' @return `TRUE` when `x` inherits from the `EmbeddingSourceColumn` class,
#' otherwise `FALSE`.
#' @export
is.embedding_source_column <- function(x) {
  inherits(x, what = "EmbeddingSourceColumn")
}

#' Embedding Vector Column
#'
#' Builds the description of a single embedding vector column: the column
#' name paired with the dimension of the embedding vector.
#'
#' @param name Name of the column.
#' @param dimension Dimension of the embedding vector, must be numeric. It is
#' stored in the result as `embedding_dimension`.
#'
#' @return A list of class `EmbeddingVectorColumn` with elements `name` and
#' `embedding_dimension`.
#'
#' @family Vector Search API
#'
#' @export
embedding_vector_column <- function(name, dimension) {
  # fail early on a non-numeric dimension
  stopifnot(is.numeric(dimension))

  structure(
    list(name = name, embedding_dimension = dimension),
    class = c("EmbeddingVectorColumn", "list")
  )
}

#' Check for an EmbeddingVectorColumn object
#'
#' @param x Any R object.
#' @return `TRUE` when `x` inherits from the `EmbeddingVectorColumn` class,
#' otherwise `FALSE`.
#' @export
is.embedding_vector_column <- function(x) {
  inherits(x, what = "EmbeddingVectorColumn")
}



#' Delta Sync Vector Search Index Specification
#'
#' @param source_table The name of the source table.
#' @param embedding_writeback_table Name of table to sync index contents and
#' computed embeddings back to delta table, see details.
#' @param embedding_source_columns The columns that contain the embedding
#' source, must be one or list of [embedding_source_column()]
#' @param embedding_vector_columns The columns that contain the embedding, must
#' be one or list of [embedding_vector_column()]
#' @param pipeline_type Pipeline execution mode, see details.
#'
#' @details
#' `pipeline_type` is either:
#' - `"TRIGGERED"`: If the pipeline uses the triggered execution mode, the
#' system stops processing after successfully refreshing the source table in
#' the pipeline once, ensuring the table is updated based on the data available
#' when the update started.
#' - `"CONTINUOUS"`: If the pipeline uses continuous execution, the pipeline
#' processes new data as it arrives in the source table to keep vector index
#' fresh.
#'
#' The only supported naming convention for `embedding_writeback_table` is
#' `"<index_name>_writeback_table"`.
#'
#' At least one embedding source or vector column must be supplied; `NULL` and
#' empty lists both count as "not supplied".
#'
#' @return A list of class `VectorSearchIndexSpec` / `DeltaSyncIndex` holding
#' `source_table`, `embedding_source_columns`, `embedding_vector_columns`,
#' `embedding_writeback_table` and `pipeline_type`.
#'
#' @seealso [db_vs_indexes_create()]
#' @family Vector Search API
#'
#' @export
delta_sync_index_spec <- function(source_table,
                                  embedding_writeback_table = NULL,
                                  embedding_source_columns = NULL,
                                  embedding_vector_columns = NULL,
                                  pipeline_type = c("TRIGGERED", "CONTINUOUS")) {

  pipeline_type <- match.arg(pipeline_type)

  # check embedding objects comply
  # a single column object is itself a list, so only a list that is *not*
  # already a column object is treated as a collection of column objects
  if (!is.null(embedding_source_columns)) {
    if (is.list(embedding_source_columns) &&
        !is.embedding_source_column(embedding_source_columns)) {
      valid_columns <- vapply(
        embedding_source_columns,
        is.embedding_source_column,
        logical(1)
      )
      if (!all(valid_columns)) {
        stop("`embedding_source_columns` must all be defined by `embedding_source_column` function")
      }
    } else {
      stopifnot(is.embedding_source_column(embedding_source_columns))
    }
  }

  if (!is.null(embedding_vector_columns)) {
    if (is.list(embedding_vector_columns) &&
        !is.embedding_vector_column(embedding_vector_columns)) {
      valid_columns <- vapply(
        embedding_vector_columns,
        is.embedding_vector_column,
        logical(1)
      )
      if (!all(valid_columns)) {
        stop("`embedding_vector_columns` must all be defined by `embedding_vector_column` function")
      }
    } else {
      stopifnot(is.embedding_vector_column(embedding_vector_columns))
    }
  }

  # length() is 0 for both NULL and an empty list, so `list()` cannot be used
  # to slip past the "at least one column" requirement
  if (length(embedding_vector_columns) == 0 &&
      length(embedding_source_columns) == 0) {
    stop("Must specify at least one embedding vector or source column")
  }

  obj <- list(
    source_table = source_table,
    embedding_source_columns = embedding_source_columns,
    embedding_vector_columns = embedding_vector_columns,
    embedding_writeback_table = embedding_writeback_table,
    pipeline_type = pipeline_type
  )

  class(obj) <- c("VectorSearchIndexSpec", "DeltaSyncIndex", "list")
  obj
}

#' Direct Access Vector Search Index Specification
#'
#' @param embedding_source_columns The columns that contain the embedding
#' source, must be one or list of [embedding_source_column()]
#' @param embedding_vector_columns The columns that contain the embedding
#' vectors, must be one or list of [embedding_vector_column()]
#' @param schema Named list, names are column names, values are types. See
#' details.
#'
#' @details
#' The supported types are:
#' - `"integer"`
#' - `"long"`
#' - `"float"`
#' - `"double"`
#' - `"boolean"`
#' - `"string"`
#' - `"date"`
#' - `"timestamp"`
#' - `"array<float>"`: supported for vector columns
#' - `"array<double>"`: supported for vector columns
#'
#' At least one embedding source or vector column must be supplied; `NULL` and
#' empty lists both count as "not supplied".
#'
#' @return A list of class `VectorSearchIndexSpec` / `DirectAccessIndex`
#' holding `schema_json` (`schema` serialised with [jsonlite::toJSON()]),
#' `embedding_source_columns` and `embedding_vector_columns`.
#'
#' @seealso [db_vs_indexes_create()]
#' @family Vector Search API
#'
#' @export
direct_access_index_spec <- function(embedding_source_columns = NULL,
                                     embedding_vector_columns = NULL,
                                     schema) {

  # check embedding objects comply
  # a single column object is itself a list, so only a list that is *not*
  # already a column object is treated as a collection of column objects
  if (!is.null(embedding_source_columns)) {
    if (is.list(embedding_source_columns) &&
        !is.embedding_source_column(embedding_source_columns)) {
      valid_columns <- vapply(
        embedding_source_columns,
        is.embedding_source_column,
        logical(1)
      )
      if (!all(valid_columns)) {
        stop("`embedding_source_columns` must all be defined by `embedding_source_column` function")
      }
    } else {
      stopifnot(is.embedding_source_column(embedding_source_columns))
    }
  }

  if (!is.null(embedding_vector_columns)) {
    if (is.list(embedding_vector_columns) &&
        !is.embedding_vector_column(embedding_vector_columns)) {
      valid_columns <- vapply(
        embedding_vector_columns,
        is.embedding_vector_column,
        logical(1)
      )
      if (!all(valid_columns)) {
        stop("`embedding_vector_columns` must all be defined by `embedding_vector_column` function")
      }
    } else {
      stopifnot(is.embedding_vector_column(embedding_vector_columns))
    }
  }

  # length() is 0 for both NULL and an empty list, so `list()` cannot be used
  # to slip past the "at least one column" requirement
  if (length(embedding_vector_columns) == 0 &&
      length(embedding_source_columns) == 0) {
    stop("Must specify at least one embedding vector or source column")
  }

  if (is.null(schema)) {
    stop("`schema` must be present.")
  }

  if (!(is.list(schema) && rlang::is_named(schema))) {
    stop("`schema` must be a named list.")
  }

  obj <- list(
    # auto_unbox so each type is emitted as a JSON string, not a 1-element array
    schema_json = jsonlite::toJSON(schema, auto_unbox = TRUE),
    embedding_source_columns = embedding_source_columns,
    embedding_vector_columns = embedding_vector_columns
  )

  class(obj) <- c("VectorSearchIndexSpec", "DirectAccessIndex", "list")
  obj
}


#' Check for a VectorSearchIndexSpec object
#'
#' @param x Any R object.
#' @return `TRUE` when `x` inherits from the `VectorSearchIndexSpec` class,
#' otherwise `FALSE`.
#' @export
is.vector_search_index_spec <- function(x) {
  inherits(x, what = "VectorSearchIndexSpec")
}


#' Check for a DirectAccessIndex object
#'
#' @param x Any R object.
#' @return `TRUE` when `x` inherits from the `DirectAccessIndex` class,
#' otherwise `FALSE`.
#' @export
is.direct_access_index <- function(x) {
  inherits(x, what = "DirectAccessIndex")
}


#' Check for a DeltaSyncIndex object
#'
#' @param x Any R object.
#' @return `TRUE` when `x` inherits from the `DeltaSyncIndex` class,
#' otherwise `FALSE`.
#' @export
is.delta_sync_index <- function(x) {
  inherits(x, what = "DeltaSyncIndex")
}

Loading

0 comments on commit 7c1e83d

Please sign in to comment.