Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement SearchByPks (work in progress) #600

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,17 @@ type Client interface {
// Upsert column-based data of collection, returns id column values
Upsert(ctx context.Context, collName string, partitionName string, columns ...entity.Column) (entity.Column, error)
// Search with bool expression
Search(ctx context.Context, collName string, partitions []string,
expr string, outputFields []string, vectors []entity.Vector, vectorField string, metricType entity.MetricType, topK int, sp entity.SearchParam, opts ...SearchQueryOptionFunc) ([]SearchResult, error)
Search(
ctx context.Context, collName string, partitions []string, expr string, outputFields []string,
vectors []entity.Vector, vectorField string, metricType entity.MetricType, topK int,
sp entity.SearchParam, opts ...SearchQueryOptionFunc,
) ([]SearchResult, error)
// SearchByPks searches using the vectors corresponding to the provided primary keys
SearchByPks(
ctx context.Context, collName string, partitions []string, expr string, outputFields []string,
primaryKeys entity.Column, vectorField string, metricType entity.MetricType, topK int,
sp entity.SearchParam, opts ...SearchQueryOptionFunc,
) ([]SearchResult, error)
// QueryByPks query record by specified primary key(s).
QueryByPks(ctx context.Context, collectionName string, partitionNames []string, ids entity.Column, outputFields []string, opts ...SearchQueryOptionFunc) (ResultSet, error)
// Query performs query records with boolean expression.
Expand Down
2 changes: 1 addition & 1 deletion client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ func TestGrpcClientNil(t *testing.T) {
mt := m.Type // type of function
if m.Name == "Close" || m.Name == "Connect" || // skip connect & close
m.Name == "UsingDatabase" || // skip use database
m.Name == "Search" || // type alias MetricType treated as string
m.Name == "Search" || m.Name == "SearchByPks" || // type alias MetricType treated as string
m.Name == "CalcDistance" ||
m.Name == "ManualCompaction" || // time.Duration hard to detect in reflect
m.Name == "Insert" || m.Name == "Upsert" { // complex methods with ...
Expand Down
200 changes: 139 additions & 61 deletions client/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"strings"

"github.com/cockroachdb/errors"
"github.com/golang/protobuf/proto"

"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
Expand All @@ -35,45 +36,149 @@ const (
)

// Search with bool expression
func (c *GrpcClient) Search(ctx context.Context, collName string, partitions []string,
expr string, outputFields []string, vectors []entity.Vector, vectorField string, metricType entity.MetricType, topK int, sp entity.SearchParam, opts ...SearchQueryOptionFunc) ([]SearchResult, error) {
func (c *GrpcClient) Search(
ctx context.Context, collName string, partitions []string, expr string, outputFields []string, vectors []entity.Vector,
vectorField string, metricType entity.MetricType, topK int, sp entity.SearchParam, opts ...SearchQueryOptionFunc,
) ([]SearchResult, error) {
if c.Service == nil {
return []SearchResult{}, ErrClientNotReady
}
var schema *entity.Schema
collInfo, ok := MetaCache.getCollectionInfo(collName)

_, ok := MetaCache.getCollectionInfo(collName)
if !ok {
coll, err := c.DescribeCollection(ctx, collName)
_, err := c.DescribeCollection(ctx, collName)
if err != nil {
return nil, err
}
schema = coll.Schema
} else {
schema = collInfo.Schema
}

option, err := makeSearchQueryOption(collName, opts...)
if err != nil {
return nil, err
}
// 2. Request milvus Service
req, err := prepareSearchRequest(collName, partitions, expr, outputFields, vectors, vectorField, metricType, topK, sp, option)

params := sp.Params()
bs, err := json.Marshal(params)
if err != nil {
return nil, err
}

sr := make([]SearchResult, 0, len(vectors))
searchParams := prepareSearchParamsForSearchRequest(
vectorField, metricType, topK, bs, option,
)

req := &milvuspb.SearchRequest{
DbName: "",
CollectionName: collName,
PartitionNames: partitions,
Dsl: expr,
PlaceholderGroup: vector2PlaceholderGroupBytes(vectors),
DslType: commonpb.DslType_BoolExprV1,
OutputFields: outputFields,
SearchParams: searchParams,
GuaranteeTimestamp: option.GuaranteeTimestamp,
Nq: int64(len(vectors)),
SearchByPrimaryKeys: false,
}

resp, err := c.Service.Search(ctx, req)
if err != nil {
return nil, err
}
if err := handleRespStatus(resp.GetStatus()); err != nil {
return nil, err
}

return processSearchResponse(resp, outputFields), nil
}

func (c *GrpcClient) SearchByPks(
ctx context.Context, collName string, partitions []string, expr string, outputFields []string,
primaryKeys entity.Column, vectorField string, metricType entity.MetricType, topK int,
sp entity.SearchParam, opts ...SearchQueryOptionFunc,
) ([]SearchResult, error) {
if c.Service == nil {
return []SearchResult{}, ErrClientNotReady
}

if primaryKeys.Len() == 0 {
return nil, errors.New("expected at least one primary key, but got zero")
}
if primaryKeys.Type() != entity.FieldTypeInt64 && primaryKeys.Type() != entity.FieldTypeVarChar {
return nil, errors.New("only int64 and varchar column can be primary key for now")
}

_, ok := MetaCache.getCollectionInfo(collName)
if !ok {
_, err := c.DescribeCollection(ctx, collName)
if err != nil {
return nil, err
}
}

option, err := makeSearchQueryOption(collName, opts...)
if err != nil {
return nil, err
}

params := sp.Params()
bs, err := json.Marshal(params)
if err != nil {
return nil, err
}

searchParams := prepareSearchParamsForSearchRequest(
vectorField, metricType, topK, bs, option,
)

req := &milvuspb.SearchRequest{
DbName: "",
CollectionName: collName,
PartitionNames: partitions,
Dsl: expr,
PlaceholderGroup: primaryKeysToPlaceholderGroupBytes(primaryKeys),
DslType: commonpb.DslType_BoolExprV1,
OutputFields: outputFields,
SearchParams: searchParams,
GuaranteeTimestamp: option.GuaranteeTimestamp,
Nq: int64(primaryKeys.Len()),
SearchByPrimaryKeys: true,
}

resp, err := c.Service.Search(ctx, req)
if err != nil {
return nil, err
}
if err := handleRespStatus(resp.GetStatus()); err != nil {
return nil, err
}
// 3. parse result into result
results := resp.GetResults()

return processSearchResponse(resp, outputFields), nil
}

func prepareSearchParamsForSearchRequest(
vectorField string, metricType entity.MetricType, topK int, bs []byte, opt *SearchQueryOption,
) []*commonpb.KeyValuePair {
searchParams := entity.MapKvPairs(map[string]string{
"anns_field": vectorField,
"topk": fmt.Sprintf("%d", topK),
"params": string(bs),
"metric_type": string(metricType),
"round_decimal": "-1",
ignoreGrowingKey: strconv.FormatBool(opt.IgnoreGrowing),
offsetKey: fmt.Sprintf("%d", opt.Offset),
})

return searchParams
}

func processSearchResponse(response *milvuspb.SearchResults, outputFields []string) []SearchResult {
results := response.GetResults()

sr := make([]SearchResult, 0, results.GetNumQueries())
offset := 0
fieldDataList := results.GetFieldsData()

for i := 0; i < int(results.GetNumQueries()); i++ {
rc := int(results.GetTopks()[i]) // result entry count for current query
entry := SearchResult{
Expand All @@ -85,14 +190,15 @@ func (c *GrpcClient) Search(ctx context.Context, collName string, partitions []s
offset += rc
continue
}
entry.Fields, entry.Err = c.parseSearchResult(schema, outputFields, fieldDataList, i, offset, offset+rc)
entry.Fields, entry.Err = parseSearchResult(outputFields, fieldDataList, offset, offset+rc)
sr = append(sr, entry)
offset += rc
}
return sr, nil

return sr
}

func (c *GrpcClient) parseSearchResult(_ *entity.Schema, outputFields []string, fieldDataList []*schemapb.FieldData, _, from, to int) ([]entity.Column, error) {
func parseSearchResult(outputFields []string, fieldDataList []*schemapb.FieldData, from, to int) ([]entity.Column, error) {
// duplicated name will have only one column now
outputSet := make(map[string]struct{})
for _, output := range outputFields {
Expand Down Expand Up @@ -208,16 +314,12 @@ func (c *GrpcClient) Query(ctx context.Context, collectionName string, partition
return nil, ErrClientNotReady
}

var sch *entity.Schema
collInfo, ok := MetaCache.getCollectionInfo(collectionName)
_, ok := MetaCache.getCollectionInfo(collectionName)
if !ok {
coll, err := c.DescribeCollection(ctx, collectionName)
_, err := c.DescribeCollection(ctx, collectionName)
if err != nil {
return nil, err
}
sch = coll.Schema
} else {
sch = collInfo.Schema
}

option, err := makeSearchQueryOption(collectionName, opts...)
Expand Down Expand Up @@ -254,7 +356,7 @@ func (c *GrpcClient) Query(ctx context.Context, collectionName string, partition

fieldsData := resp.GetFieldsData()

columns, err := c.parseSearchResult(sch, outputFields, fieldsData, 0, 0, -1) //entity.FieldDataColumn(fieldData, 0, -1)
columns, err := parseSearchResult(outputFields, fieldsData, 0, -1) //entity.FieldDataColumn(fieldData, 0, -1)
if err != nil {
return nil, err
}
Expand All @@ -271,47 +373,23 @@ func getPKField(schema *entity.Schema) *entity.Field {
return nil
}

func getVectorField(schema *entity.Schema) *entity.Field {
for _, f := range schema.Fields {
if f.DataType == entity.FieldTypeFloatVector || f.DataType == entity.FieldTypeBinaryVector {
return f
}
}
return nil
}
func primaryKeysToPlaceholderGroupBytes(primaryKeys entity.Column) []byte {

func prepareSearchRequest(collName string, partitions []string,
expr string, outputFields []string, vectors []entity.Vector, vectorField string,
metricType entity.MetricType, topK int, sp entity.SearchParam, opt *SearchQueryOption) (*milvuspb.SearchRequest, error) {
params := sp.Params()
params[forTuningKey] = opt.ForTuning
bs, err := json.Marshal(params)
if err != nil {
return nil, err
}
queryExpr := PKs2Expr("", primaryKeys)
queryExprBytes := []byte(queryExpr)

searchParams := entity.MapKvPairs(map[string]string{
"anns_field": vectorField,
"topk": fmt.Sprintf("%d", topK),
"params": string(bs),
"metric_type": string(metricType),
"round_decimal": "-1",
ignoreGrowingKey: strconv.FormatBool(opt.IgnoreGrowing),
offsetKey: fmt.Sprintf("%d", opt.Offset),
})
req := &milvuspb.SearchRequest{
DbName: "",
CollectionName: collName,
PartitionNames: partitions,
Dsl: expr,
PlaceholderGroup: vector2PlaceholderGroupBytes(vectors),
DslType: commonpb.DslType_BoolExprV1,
OutputFields: outputFields,
SearchParams: searchParams,
GuaranteeTimestamp: opt.GuaranteeTimestamp,
Nq: int64(len(vectors)),
placeholderGroup := &commonpb.PlaceholderGroup{
Placeholders: []*commonpb.PlaceholderValue{
{
Tag: "$0",
Type: commonpb.PlaceholderType_None,
Values: [][]byte{queryExprBytes},
},
},
}
return req, nil

bs, _ := proto.Marshal(placeholderGroup)
return bs
}

// GetPersistentSegmentInfo get persistent segment info
Expand Down
Loading