Skip to content

Commit

Permalink
Supports filling elements through templates for expression (#2308)
Browse files Browse the repository at this point in the history
issue: milvus-io/milvus#36672

Signed-off-by: Cai Zhang <[email protected]>
  • Loading branch information
xiaocai2333 authored Oct 29, 2024
1 parent df326c6 commit fadc01b
Show file tree
Hide file tree
Showing 12 changed files with 756 additions and 287 deletions.
196 changes: 196 additions & 0 deletions examples/search_with_template_expression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
# hello_milvus.py demonstrates the basic operations of PyMilvus, a Python SDK of Milvus.
# 1. connect to Milvus
# 2. create collection
# 3. insert data
# 4. create index
# 5. search, query, and hybrid search on entities
# 6. delete entities by PK
# 7. drop collection
import time

import numpy as np
from pymilvus import (
connections,
utility,
FieldSchema, CollectionSchema, DataType,
Collection,
)

fmt = "\n=== {:30} ===\n"
search_latency_fmt = "search latency = {:.4f}s"
num_entities, dim = 3000, 8

#################################################################################
# 1. connect to Milvus
# Add a new connection alias `default` for Milvus server in `localhost:19530`
# Actually the "default" alias is a buildin in PyMilvus.
# If the address of Milvus is the same as `localhost:19530`, you can omit all
# parameters and call the method as: `connections.connect()`.
#
# Note: the `using` parameter of the following methods is default to "default".
print(fmt.format("start connecting to Milvus"))
connections.connect("default", host="localhost", port="19530")

has = utility.has_collection("hello_milvus")
print(f"Does collection hello_milvus exist in Milvus: {has}")

#################################################################################
# 2. create collection
# We're going to create a collection with 3 fields.
# +-+------------+------------+------------------+------------------------------+
# | | field name | field type | other attributes | field description |
# +-+------------+------------+------------------+------------------------------+
# |1| "pk" | VarChar | is_primary=True | "primary field" |
# | | | | auto_id=False | |
# +-+------------+------------+------------------+------------------------------+
# |2| "random" | Double | | "a double field" |
# +-+------------+------------+------------------+------------------------------+
# |3|"embeddings"| FloatVector| dim=8 | "float vector with dim 8" |
# +-+------------+------------+------------------+------------------------------+
fields = [
FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=100),
FieldSchema(name="random", dtype=DataType.DOUBLE),
FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=dim)
]

schema = CollectionSchema(fields, "hello_milvus is the simplest demo to introduce the APIs")

print(fmt.format("Create collection `hello_milvus`"))
hello_milvus = Collection("hello_milvus", schema, consistency_level="Strong")

################################################################################
# 3. insert data
# We are going to insert 3000 rows of data into `hello_milvus`
# Data to be inserted must be organized in fields.
#
# The insert() method returns:
# - either automatically generated primary keys by Milvus if auto_id=True in the schema;
# - or the existing primary key field from the entities if auto_id=False in the schema.

print(fmt.format("Start inserting entities"))
rng = np.random.default_rng(seed=19530)
entities = [
# provide the pk field because `auto_id` is set to False
[str(i) for i in range(num_entities)],
rng.random(num_entities).tolist(), # field random, only supports list
rng.random((num_entities, dim), np.float32), # field embeddings, supports numpy.ndarray and list
]

insert_result = hello_milvus.insert(entities)

row = {
"pk": "19530",
"random": 0.5,
"embeddings": rng.random((1, dim), np.float32)[0]
}
hello_milvus.insert(row)

hello_milvus.flush()
print(f"Number of entities in Milvus: {hello_milvus.num_entities}") # check the num_entities

################################################################################
# 4. create index
# We are going to create an IVF_FLAT index for hello_milvus collection.
# create_index() can only be applied to `FloatVector` and `BinaryVector` fields.
print(fmt.format("Start Creating index IVF_FLAT"))
index = {
"index_type": "IVF_FLAT",
"metric_type": "L2",
"params": {"nlist": 128},
}

hello_milvus.create_index("embeddings", index)

################################################################################
# 5. search, query, and hybrid search
# After data were inserted into Milvus and indexed, you can perform:
# - search based on vector similarity
# - query based on scalar filtering(boolean, int, etc.)
# - hybrid search based on vector similarity and scalar filtering.
#

# Before conducting a search or a query, you need to load the data in `hello_milvus` into memory.
print(fmt.format("Start loading"))
hello_milvus.load()

# -----------------------------------------------------------------------------
# search based on vector similarity
print(fmt.format("Start searching based on vector similarity"))
vectors_to_search = entities[-1][-2:]
search_params = {
"metric_type": "L2",
"params": {"nprobe": 10},
}

exprs = {
"pk == {str}": {"str": "10"},
"pk in {list}": {"list": ["1", "10", "100"]},
"random > {target}": {"target": 5},
"random <= {target}": {"target": 111.5},
"{min} <= random < {max}": {"min": 0, "max": 9999},
}

for expr, expr_params in exprs.items():
print(f"search with expression: {expr}")
start_time = time.time()
result = hello_milvus.search(vectors_to_search, "embeddings", search_params, limit=3, expr=expr,
output_fields=["random"], expr_params=expr_params)
end_time = time.time()

for hits in result:
for hit in hits:
print(f"hit: {hit}, random field: {hit.entity.get('random')}")
print(search_latency_fmt.format(end_time - start_time))

# -----------------------------------------------------------------------------
# query based on scalar filtering(boolean, int, etc.)
start_time = time.time()
result = hello_milvus.query(expr=expr, output_fields=["random", "embeddings"], expr_params=expr_params)
end_time = time.time()

print(f"query result:\n-{result}")
print(search_latency_fmt.format(end_time - start_time))

# -----------------------------------------------------------------------------
# pagination
r1 = hello_milvus.query(expr=expr, limit=4, output_fields=["random"], expr_params=expr_params)
r2 = hello_milvus.query(expr=expr, offset=1, limit=3, output_fields=["random"], expr_params=expr_params)
print(f"query pagination(limit=4):\n\t{r1}")
print(f"query pagination(offset=1, limit=3):\n\t{r2}")

# -----------------------------------------------------------------------------
# hybrid search

start_time = time.time()
result = hello_milvus.search(vectors_to_search, "embeddings", search_params, limit=3, expr=expr,
output_fields=["random"], expr_params=expr_params)
end_time = time.time()

for hits in result:
for hit in hits:
print(f"hit: {hit}, random field: {hit.entity.get('random')}")
print(search_latency_fmt.format(end_time - start_time))

###############################################################################
# 6. delete entities by PK
# You can delete entities by their PK values using boolean expressions.
ids = insert_result.primary_keys

expr = "pk in {list}"
expr_params = {"list": [ids[0], ids[1]]}
print(fmt.format(f"Start deleting with expr `{expr}`"))

result = hello_milvus.query(expr=expr, output_fields=["random", "embeddings"], expr_params=expr_params)
print(f"query before delete by expr=`{expr}` -> result: \n-{result[0]}\n-{result[1]}\n")

hello_milvus.delete(expr, expr_params=expr_params)

result = hello_milvus.query(expr=expr, output_fields=["random", "embeddings"], expr_params=expr_params)
print(f"query after delete by expr=`{expr}` -> result: {result}\n")


###############################################################################
# 7. drop collection
# Finally, drop the hello_milvus collection
print(fmt.format("Drop collection `hello_milvus`"))
utility.drop_collection("hello_milvus")
3 changes: 2 additions & 1 deletion pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,8 @@ def delete(
partition_name,
expression,
consistency_level=kwargs.get("consistency_level", 0),
param_name=kwargs.get("param_name"),
param_name=kwargs.pop("param_name", None),
**kwargs,
)
future = self._stub.Delete.future(req, timeout=timeout)

Expand Down
54 changes: 54 additions & 0 deletions pymilvus/client/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pymilvus.grpc_gen import milvus_pb2 as milvus_types
from pymilvus.grpc_gen import schema_pb2 as schema_types
from pymilvus.orm.schema import CollectionSchema
from pymilvus.orm.types import infer_dtype_by_scalar_data

from . import __version__, blob, entity_helper, ts_utils, utils
from .check import check_pass_param, is_legal_collection_properties
Expand Down Expand Up @@ -683,6 +684,7 @@ def check_str(instr: str, prefix: str):
partition_name=partition_name,
expr=expr,
consistency_level=get_consistency_level(consistency_level),
expr_template_values=cls.prepare_expression_template(kwargs.get("expr_params", {})),
)

@classmethod
Expand Down Expand Up @@ -726,6 +728,56 @@ def _prepare_placeholder_str(cls, data: Any):
common_types.PlaceholderGroup(placeholders=[pl])
)

@classmethod
def prepare_expression_template(cls, values: Dict) -> Any:
def add_data(v: Any) -> schema_types.TemplateValue:
dtype = infer_dtype_by_scalar_data(v)
data = schema_types.TemplateValue()
if dtype in (schema_types.Bool,):
data.bool_val = v
data.type = schema_types.Bool
return data
if dtype in (
schema_types.Int8,
schema_types.Int16,
schema_types.Int32,
schema_types.Int64,
):
data.int64_val = v
data.type = schema_types.Int64
return data
if dtype in (schema_types.Float, schema_types.Double):
data.float_val = v
data.type = schema_types.Double
return data
if dtype in (schema_types.VarChar, schema_types.String):
data.string_val = v
data.type = schema_types.VarChar
return data
if dtype in (schema_types.Array,):
element_datas = schema_types.TemplateArrayValue()
same_type = True
element_type = None
for element in v:
rdata = add_data(element)
element_datas.array.append(rdata)
if element_type is None:
element_type = rdata.type
elif element_type != rdata.type:
same_type = False
element_datas.element_type = element_type if same_type else schema_types.JSON
element_datas.same_type = same_type
data.array_val.CopyFrom(element_datas)
data.type = schema_types.Array
return data
raise ParamError(message=f"Unsupported element type: {dtype}")

expression_template_values = {}
for k, v in values.items():
expression_template_values[k] = add_data(v)

return expression_template_values

@classmethod
def search_requests_with_expr(
cls,
Expand Down Expand Up @@ -810,6 +862,7 @@ def search_requests_with_expr(
placeholder_group=plg_str,
dsl_type=common_types.DslType.BoolExprV1,
search_params=req_params,
expr_template_values=cls.prepare_expression_template(kwargs.get("expr_params", {})),
)
if expr is not None:
request.dsl = expr
Expand Down Expand Up @@ -1051,6 +1104,7 @@ def query_request(
guarantee_timestamp=kwargs.get("guarantee_timestamp", 0),
use_default_consistency=use_default_consistency,
consistency_level=kwargs.get("consistency_level", 0),
expr_template_values=cls.prepare_expression_template(kwargs.get("expr_params", {})),
)

limit = kwargs.get("limit")
Expand Down
Loading

0 comments on commit fadc01b

Please sign in to comment.