From 7e3e8bec5cd421cabdb8897d3209a1707a0d161c Mon Sep 17 00:00:00 2001
From: chyezh
Date: Thu, 14 Nov 2024 11:07:28 +0800
Subject: [PATCH] enhance: move segcore code of segment into one package

- move most cgo operations related to search/query into the segcore package
  for reuse by streamingnode.
- add Go unit tests for segcore operations.

Signed-off-by: chyezh
---
 Makefile | 1 -
 internal/.mockery.yaml | 6 +
 internal/core/src/segcore/segcore_init_c.h | 3 +
 .../mocks/util/mock_segcore/mock_CSegment.go | 708 ++++++++++++++++++
 .../util/mock_segcore}/mock_data.go | 146 ++--
 .../mock_optimizers/mock_QueryHook.go} | 2 +-
 .../delegator/delegator_data_test.go | 3 +-
 internal/querynodev2/local_worker_test.go | 7 +-
 .../querynodev2/pipeline/insert_node_test.go | 9 +-
 .../querynodev2/pipeline/pipeline_test.go | 5 +-
 internal/querynodev2/segments/cgo_util.go | 38 -
 internal/querynodev2/segments/collection.go | 53 +-
 .../querynodev2/segments/count_reducer.go | 3 +-
 internal/querynodev2/segments/manager_test.go | 5 +-
 internal/querynodev2/segments/mock_segment.go | 71 +-
 internal/querynodev2/segments/reducer.go | 3 +-
 internal/querynodev2/segments/result.go | 10 +-
 internal/querynodev2/segments/result_test.go | 90 +--
 internal/querynodev2/segments/retrieve.go | 2 +-
 .../querynodev2/segments/retrieve_test.go | 30 +-
 internal/querynodev2/segments/search.go | 6 +-
 .../segments/search_reduce_test.go | 25 +-
 internal/querynodev2/segments/search_test.go | 13 +-
 internal/querynodev2/segments/segcore.go | 21 +
 internal/querynodev2/segments/segment.go | 448 ++++-------
 .../querynodev2/segments/segment_interface.go | 7 +-
 internal/querynodev2/segments/segment_l0.go | 7 +-
 .../querynodev2/segments/segment_loader.go | 5 +-
 .../segments/segment_loader_test.go | 57 +-
 internal/querynodev2/segments/segment_test.go | 13 +-
 internal/querynodev2/server.go | 3 +-
 internal/querynodev2/server_test.go | 7 +-
 internal/querynodev2/services_test.go | 63 +-
 .../querynodev2/tasks/query_stream_task.go | 6 +-
 internal/querynodev2/tasks/query_task.go | 6 +-
 internal/querynodev2/tasks/search_task.go | 33 +-
 .../searchutil/optimizers/query_hook_test.go | 11 +-
 internal/util/segcore/cgo_util.go | 78 ++
 internal/util/segcore/cgo_util_test.go | 19 +
 internal/util/segcore/collection.go | 83 ++
 internal/util/segcore/collection_test.go | 28 +
 .../segments => util/segcore}/plan.go | 115 ++-
 .../segments => util/segcore}/plan_test.go | 43 +-
 .../segments => util/segcore}/reduce.go | 47 +-
 .../segments => util/segcore}/reduce_test.go | 133 ++--
 internal/util/segcore/requests.go | 99 +++
 internal/util/segcore/requests_test.go | 33 +
 internal/util/segcore/responses.go | 47 ++
 internal/util/segcore/segcore_init.go | 24 +
 internal/util/segcore/segcore_init_test.go | 13 +
 internal/util/segcore/segment.go | 294 ++++++++
 internal/util/segcore/segment_interface.go | 74 ++
 internal/util/segcore/segment_test.go | 136 ++++
 internal/util/segcore/trace.go | 56 ++
 54 files changed, 2391 insertions(+), 857 deletions(-)
 create mode 100644 internal/mocks/util/mock_segcore/mock_CSegment.go
 rename internal/{querynodev2/segments => mocks/util/mock_segcore}/mock_data.go (92%)
 rename internal/{util/searchutil/optimizers/mock_query_hook.go => mocks/util/searchutil/mock_optimizers/mock_QueryHook.go} (99%)
 create mode 100644 internal/querynodev2/segments/segcore.go
 create mode 100644 internal/util/segcore/cgo_util.go
 create mode 100644 internal/util/segcore/cgo_util_test.go
 create mode 100644 internal/util/segcore/collection.go
 create mode 100644 internal/util/segcore/collection_test.go
 rename internal/{querynodev2/segments => util/segcore}/plan.go (60%)
 rename internal/{querynodev2/segments => util/segcore}/plan_test.go (66%)
 rename internal/{querynodev2/segments => util/segcore}/reduce.go (82%)
 rename internal/{querynodev2/segments => util/segcore}/reduce_test.go (57%)
 create mode 100644 internal/util/segcore/requests.go
 create mode 100644 internal/util/segcore/requests_test.go
 create mode 100644 internal/util/segcore/responses.go
 create mode 100644 internal/util/segcore/segcore_init.go
 create mode 100644 internal/util/segcore/segcore_init_test.go
 create mode 100644 internal/util/segcore/segment.go
 create mode 100644 internal/util/segcore/segment_interface.go
 create mode 100644 internal/util/segcore/segment_test.go
 create mode 100644 internal/util/segcore/trace.go

diff --git a/Makefile b/Makefile
index ad028ec3d3e6a..d1634e01ce0f0 100644
--- a/Makefile
+++ b/Makefile
@@ -473,7 +473,6 @@ generate-mockery-querycoord: getdeps

 generate-mockery-querynode: getdeps build-cpp
 	@source $(PWD)/scripts/setenv.sh # setup PKG_CONFIG_PATH
-	$(INSTALL_PATH)/mockery --name=QueryHook --dir=$(PWD)/internal/querynodev2/optimizers --output=$(PWD)/internal/querynodev2/optimizers --filename=mock_query_hook.go --with-expecter --outpkg=optimizers --structname=MockQueryHook --inpackage
 	$(INSTALL_PATH)/mockery --name=Manager --dir=$(PWD)/internal/querynodev2/cluster --output=$(PWD)/internal/querynodev2/cluster --filename=mock_manager.go --with-expecter --outpkg=cluster --structname=MockManager --inpackage
 	$(INSTALL_PATH)/mockery --name=SegmentManager --dir=$(PWD)/internal/querynodev2/segments --output=$(PWD)/internal/querynodev2/segments --filename=mock_segment_manager.go --with-expecter --outpkg=segments --structname=MockSegmentManager --inpackage
 	$(INSTALL_PATH)/mockery --name=CollectionManager --dir=$(PWD)/internal/querynodev2/segments --output=$(PWD)/internal/querynodev2/segments --filename=mock_collection_manager.go --with-expecter --outpkg=segments --structname=MockCollectionManager --inpackage
diff --git a/internal/.mockery.yaml b/internal/.mockery.yaml
index ea0cf349562f9..2179959e87d3a 100644
--- a/internal/.mockery.yaml
+++ b/internal/.mockery.yaml
@@ -61,6 +61,9 @@ packages:
     interfaces:
       StreamingCoordCataLog:
       StreamingNodeCataLog:
+  github.com/milvus-io/milvus/internal/util/segcore:
+    interfaces:
+      CSegment:
   github.com/milvus-io/milvus/internal/util/streamingutil/service/discoverer:
     interfaces:
       Discoverer:
@@ -72,6 +75,9 @@ packages:
     interfaces:
       Resolver:
       Builder:
+  github.com/milvus-io/milvus/internal/util/searchutil/optimizers:
+    interfaces:
+      QueryHook:
   google.golang.org/grpc/resolver:
     interfaces:
       ClientConn:
diff --git a/internal/core/src/segcore/segcore_init_c.h b/internal/core/src/segcore/segcore_init_c.h
index d617d796a8406..56092fe364bb1 100644
--- a/internal/core/src/segcore/segcore_init_c.h
+++ b/internal/core/src/segcore/segcore_init_c.h
@@ -11,6 +11,9 @@

 #pragma once

+#include <stdbool.h>
+#include <stdint.h>
+
 #ifdef __cplusplus
 extern "C" {
 #endif
diff --git a/internal/mocks/util/mock_segcore/mock_CSegment.go b/internal/mocks/util/mock_segcore/mock_CSegment.go
new file mode 100644
index 0000000000000..dc77513b2c610
--- /dev/null
+++ b/internal/mocks/util/mock_segcore/mock_CSegment.go
@@ -0,0 +1,708 @@
+// Code generated by mockery v2.46.0. DO NOT EDIT.
+ +package mock_segcore + +import ( + context "context" + + segcore "github.com/milvus-io/milvus/internal/util/segcore" + mock "github.com/stretchr/testify/mock" +) + +// MockCSegment is an autogenerated mock type for the CSegment type +type MockCSegment struct { + mock.Mock +} + +type MockCSegment_Expecter struct { + mock *mock.Mock +} + +func (_m *MockCSegment) EXPECT() *MockCSegment_Expecter { + return &MockCSegment_Expecter{mock: &_m.Mock} +} + +// AddFieldDataInfo provides a mock function with given fields: ctx, request +func (_m *MockCSegment) AddFieldDataInfo(ctx context.Context, request *segcore.LoadFieldDataRequest) (*segcore.AddFieldDataInfoResult, error) { + ret := _m.Called(ctx, request) + + if len(ret) == 0 { + panic("no return value specified for AddFieldDataInfo") + } + + var r0 *segcore.AddFieldDataInfoResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *segcore.LoadFieldDataRequest) (*segcore.AddFieldDataInfoResult, error)); ok { + return rf(ctx, request) + } + if rf, ok := ret.Get(0).(func(context.Context, *segcore.LoadFieldDataRequest) *segcore.AddFieldDataInfoResult); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*segcore.AddFieldDataInfoResult) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *segcore.LoadFieldDataRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockCSegment_AddFieldDataInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddFieldDataInfo' +type MockCSegment_AddFieldDataInfo_Call struct { + *mock.Call +} + +// AddFieldDataInfo is a helper method to define mock.On call +// - ctx context.Context +// - request *segcore.LoadFieldDataRequest +func (_e *MockCSegment_Expecter) AddFieldDataInfo(ctx interface{}, request interface{}) *MockCSegment_AddFieldDataInfo_Call { + return &MockCSegment_AddFieldDataInfo_Call{Call: _e.mock.On("AddFieldDataInfo", ctx, request)} +} + +func (_c *MockCSegment_AddFieldDataInfo_Call) Run(run func(ctx context.Context, request *segcore.LoadFieldDataRequest)) *MockCSegment_AddFieldDataInfo_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*segcore.LoadFieldDataRequest)) + }) + return _c +} + +func (_c *MockCSegment_AddFieldDataInfo_Call) Return(_a0 *segcore.AddFieldDataInfoResult, _a1 error) *MockCSegment_AddFieldDataInfo_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockCSegment_AddFieldDataInfo_Call) RunAndReturn(run func(context.Context, *segcore.LoadFieldDataRequest) (*segcore.AddFieldDataInfoResult, error)) *MockCSegment_AddFieldDataInfo_Call { + _c.Call.Return(run) + return _c +} + +// Delete provides a mock function with given fields: ctx, request +func (_m *MockCSegment) Delete(ctx context.Context, request *segcore.DeleteRequest) (*segcore.DeleteResult, error) { + ret := _m.Called(ctx, request) + + if len(ret) == 0 { + panic("no return value specified for Delete") + } + + var r0 *segcore.DeleteResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *segcore.DeleteRequest) (*segcore.DeleteResult, error)); ok { + return rf(ctx, request) + } + if rf, ok := ret.Get(0).(func(context.Context, *segcore.DeleteRequest) *segcore.DeleteResult); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*segcore.DeleteResult) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *segcore.DeleteRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) 
+ } + + return r0, r1 +} + +// MockCSegment_Delete_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Delete' +type MockCSegment_Delete_Call struct { + *mock.Call +} + +// Delete is a helper method to define mock.On call +// - ctx context.Context +// - request *segcore.DeleteRequest +func (_e *MockCSegment_Expecter) Delete(ctx interface{}, request interface{}) *MockCSegment_Delete_Call { + return &MockCSegment_Delete_Call{Call: _e.mock.On("Delete", ctx, request)} +} + +func (_c *MockCSegment_Delete_Call) Run(run func(ctx context.Context, request *segcore.DeleteRequest)) *MockCSegment_Delete_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*segcore.DeleteRequest)) + }) + return _c +} + +func (_c *MockCSegment_Delete_Call) Return(_a0 *segcore.DeleteResult, _a1 error) *MockCSegment_Delete_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockCSegment_Delete_Call) RunAndReturn(run func(context.Context, *segcore.DeleteRequest) (*segcore.DeleteResult, error)) *MockCSegment_Delete_Call { + _c.Call.Return(run) + return _c +} + +// HasRawData provides a mock function with given fields: fieldID +func (_m *MockCSegment) HasRawData(fieldID int64) bool { + ret := _m.Called(fieldID) + + if len(ret) == 0 { + panic("no return value specified for HasRawData") + } + + var r0 bool + if rf, ok := ret.Get(0).(func(int64) bool); ok { + r0 = rf(fieldID) + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// MockCSegment_HasRawData_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HasRawData' +type MockCSegment_HasRawData_Call struct { + *mock.Call +} + +// HasRawData is a helper method to define mock.On call +// - fieldID int64 +func (_e *MockCSegment_Expecter) HasRawData(fieldID interface{}) *MockCSegment_HasRawData_Call { + return &MockCSegment_HasRawData_Call{Call: _e.mock.On("HasRawData", fieldID)} +} + +func (_c *MockCSegment_HasRawData_Call) Run(run func(fieldID int64)) *MockCSegment_HasRawData_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64)) + }) + return _c +} + +func (_c *MockCSegment_HasRawData_Call) Return(_a0 bool) *MockCSegment_HasRawData_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockCSegment_HasRawData_Call) RunAndReturn(run func(int64) bool) *MockCSegment_HasRawData_Call { + _c.Call.Return(run) + return _c +} + +// ID provides a mock function with given fields: +func (_m *MockCSegment) ID() int64 { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ID") + } + + var r0 int64 + if rf, ok := ret.Get(0).(func() int64); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int64) + } + + return r0 +} + +// MockCSegment_ID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ID' +type MockCSegment_ID_Call struct { + *mock.Call +} + +// ID is a helper method to define mock.On call +func (_e *MockCSegment_Expecter) ID() *MockCSegment_ID_Call { + return &MockCSegment_ID_Call{Call: _e.mock.On("ID")} +} + +func (_c *MockCSegment_ID_Call) Run(run func()) *MockCSegment_ID_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockCSegment_ID_Call) Return(_a0 int64) *MockCSegment_ID_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockCSegment_ID_Call) RunAndReturn(run func() int64) *MockCSegment_ID_Call { + _c.Call.Return(run) + return _c +} + +// Insert provides a mock function with given fields: ctx, request +func (_m 
*MockCSegment) Insert(ctx context.Context, request *segcore.InsertRequest) (*segcore.InsertResult, error) { + ret := _m.Called(ctx, request) + + if len(ret) == 0 { + panic("no return value specified for Insert") + } + + var r0 *segcore.InsertResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *segcore.InsertRequest) (*segcore.InsertResult, error)); ok { + return rf(ctx, request) + } + if rf, ok := ret.Get(0).(func(context.Context, *segcore.InsertRequest) *segcore.InsertResult); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*segcore.InsertResult) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *segcore.InsertRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockCSegment_Insert_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Insert' +type MockCSegment_Insert_Call struct { + *mock.Call +} + +// Insert is a helper method to define mock.On call +// - ctx context.Context +// - request *segcore.InsertRequest +func (_e *MockCSegment_Expecter) Insert(ctx interface{}, request interface{}) *MockCSegment_Insert_Call { + return &MockCSegment_Insert_Call{Call: _e.mock.On("Insert", ctx, request)} +} + +func (_c *MockCSegment_Insert_Call) Run(run func(ctx context.Context, request *segcore.InsertRequest)) *MockCSegment_Insert_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*segcore.InsertRequest)) + }) + return _c +} + +func (_c *MockCSegment_Insert_Call) Return(_a0 *segcore.InsertResult, _a1 error) *MockCSegment_Insert_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockCSegment_Insert_Call) RunAndReturn(run func(context.Context, *segcore.InsertRequest) (*segcore.InsertResult, error)) *MockCSegment_Insert_Call { + _c.Call.Return(run) + return _c +} + +// LoadFieldData provides a mock function with given fields: ctx, request +func (_m *MockCSegment) LoadFieldData(ctx context.Context, request *segcore.LoadFieldDataRequest) (*segcore.LoadFieldDataResult, error) { + ret := _m.Called(ctx, request) + + if len(ret) == 0 { + panic("no return value specified for LoadFieldData") + } + + var r0 *segcore.LoadFieldDataResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *segcore.LoadFieldDataRequest) (*segcore.LoadFieldDataResult, error)); ok { + return rf(ctx, request) + } + if rf, ok := ret.Get(0).(func(context.Context, *segcore.LoadFieldDataRequest) *segcore.LoadFieldDataResult); ok { + r0 = rf(ctx, request) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*segcore.LoadFieldDataResult) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *segcore.LoadFieldDataRequest) error); ok { + r1 = rf(ctx, request) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockCSegment_LoadFieldData_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LoadFieldData' +type MockCSegment_LoadFieldData_Call struct { + *mock.Call +} + +// LoadFieldData is a helper method to define mock.On call +// - ctx context.Context +// - request *segcore.LoadFieldDataRequest +func (_e *MockCSegment_Expecter) LoadFieldData(ctx interface{}, request interface{}) *MockCSegment_LoadFieldData_Call { + return &MockCSegment_LoadFieldData_Call{Call: _e.mock.On("LoadFieldData", ctx, request)} +} + +func (_c *MockCSegment_LoadFieldData_Call) Run(run func(ctx context.Context, request *segcore.LoadFieldDataRequest)) *MockCSegment_LoadFieldData_Call { + 
_c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*segcore.LoadFieldDataRequest)) + }) + return _c +} + +func (_c *MockCSegment_LoadFieldData_Call) Return(_a0 *segcore.LoadFieldDataResult, _a1 error) *MockCSegment_LoadFieldData_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockCSegment_LoadFieldData_Call) RunAndReturn(run func(context.Context, *segcore.LoadFieldDataRequest) (*segcore.LoadFieldDataResult, error)) *MockCSegment_LoadFieldData_Call { + _c.Call.Return(run) + return _c +} + +// MemSize provides a mock function with given fields: +func (_m *MockCSegment) MemSize() int64 { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for MemSize") + } + + var r0 int64 + if rf, ok := ret.Get(0).(func() int64); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int64) + } + + return r0 +} + +// MockCSegment_MemSize_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'MemSize' +type MockCSegment_MemSize_Call struct { + *mock.Call +} + +// MemSize is a helper method to define mock.On call +func (_e *MockCSegment_Expecter) MemSize() *MockCSegment_MemSize_Call { + return &MockCSegment_MemSize_Call{Call: _e.mock.On("MemSize")} +} + +func (_c *MockCSegment_MemSize_Call) Run(run func()) *MockCSegment_MemSize_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockCSegment_MemSize_Call) Return(_a0 int64) *MockCSegment_MemSize_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockCSegment_MemSize_Call) RunAndReturn(run func() int64) *MockCSegment_MemSize_Call { + _c.Call.Return(run) + return _c +} + +// RawPointer provides a mock function with given fields: +func (_m *MockCSegment) RawPointer() segcore.CSegmentInterface { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for RawPointer") + } + + var r0 segcore.CSegmentInterface + if rf, ok := ret.Get(0).(func() segcore.CSegmentInterface); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(segcore.CSegmentInterface) + } + + return r0 +} + +// MockCSegment_RawPointer_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RawPointer' +type MockCSegment_RawPointer_Call struct { + *mock.Call +} + +// RawPointer is a helper method to define mock.On call +func (_e *MockCSegment_Expecter) RawPointer() *MockCSegment_RawPointer_Call { + return &MockCSegment_RawPointer_Call{Call: _e.mock.On("RawPointer")} +} + +func (_c *MockCSegment_RawPointer_Call) Run(run func()) *MockCSegment_RawPointer_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockCSegment_RawPointer_Call) Return(_a0 segcore.CSegmentInterface) *MockCSegment_RawPointer_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockCSegment_RawPointer_Call) RunAndReturn(run func() segcore.CSegmentInterface) *MockCSegment_RawPointer_Call { + _c.Call.Return(run) + return _c +} + +// Release provides a mock function with given fields: +func (_m *MockCSegment) Release() { + _m.Called() +} + +// MockCSegment_Release_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Release' +type MockCSegment_Release_Call struct { + *mock.Call +} + +// Release is a helper method to define mock.On call +func (_e *MockCSegment_Expecter) Release() *MockCSegment_Release_Call { + return &MockCSegment_Release_Call{Call: _e.mock.On("Release")} +} + +func (_c *MockCSegment_Release_Call) Run(run func()) *MockCSegment_Release_Call { + 
_c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockCSegment_Release_Call) Return() *MockCSegment_Release_Call { + _c.Call.Return() + return _c +} + +func (_c *MockCSegment_Release_Call) RunAndReturn(run func()) *MockCSegment_Release_Call { + _c.Call.Return(run) + return _c +} + +// Retrieve provides a mock function with given fields: ctx, plan +func (_m *MockCSegment) Retrieve(ctx context.Context, plan *segcore.RetrievePlan) (*segcore.RetrieveResult, error) { + ret := _m.Called(ctx, plan) + + if len(ret) == 0 { + panic("no return value specified for Retrieve") + } + + var r0 *segcore.RetrieveResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *segcore.RetrievePlan) (*segcore.RetrieveResult, error)); ok { + return rf(ctx, plan) + } + if rf, ok := ret.Get(0).(func(context.Context, *segcore.RetrievePlan) *segcore.RetrieveResult); ok { + r0 = rf(ctx, plan) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*segcore.RetrieveResult) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *segcore.RetrievePlan) error); ok { + r1 = rf(ctx, plan) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockCSegment_Retrieve_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Retrieve' +type MockCSegment_Retrieve_Call struct { + *mock.Call +} + +// Retrieve is a helper method to define mock.On call +// - ctx context.Context +// - plan *segcore.RetrievePlan +func (_e *MockCSegment_Expecter) Retrieve(ctx interface{}, plan interface{}) *MockCSegment_Retrieve_Call { + return &MockCSegment_Retrieve_Call{Call: _e.mock.On("Retrieve", ctx, plan)} +} + +func (_c *MockCSegment_Retrieve_Call) Run(run func(ctx context.Context, plan *segcore.RetrievePlan)) *MockCSegment_Retrieve_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*segcore.RetrievePlan)) + }) + return _c +} + +func (_c *MockCSegment_Retrieve_Call) Return(_a0 *segcore.RetrieveResult, _a1 error) *MockCSegment_Retrieve_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockCSegment_Retrieve_Call) RunAndReturn(run func(context.Context, *segcore.RetrievePlan) (*segcore.RetrieveResult, error)) *MockCSegment_Retrieve_Call { + _c.Call.Return(run) + return _c +} + +// RetrieveByOffsets provides a mock function with given fields: ctx, plan +func (_m *MockCSegment) RetrieveByOffsets(ctx context.Context, plan *segcore.RetrievePlanWithOffsets) (*segcore.RetrieveResult, error) { + ret := _m.Called(ctx, plan) + + if len(ret) == 0 { + panic("no return value specified for RetrieveByOffsets") + } + + var r0 *segcore.RetrieveResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *segcore.RetrievePlanWithOffsets) (*segcore.RetrieveResult, error)); ok { + return rf(ctx, plan) + } + if rf, ok := ret.Get(0).(func(context.Context, *segcore.RetrievePlanWithOffsets) *segcore.RetrieveResult); ok { + r0 = rf(ctx, plan) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*segcore.RetrieveResult) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *segcore.RetrievePlanWithOffsets) error); ok { + r1 = rf(ctx, plan) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockCSegment_RetrieveByOffsets_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveByOffsets' +type MockCSegment_RetrieveByOffsets_Call struct { + *mock.Call +} + +// RetrieveByOffsets is a helper method to define mock.On call +// - ctx context.Context +// - plan 
*segcore.RetrievePlanWithOffsets +func (_e *MockCSegment_Expecter) RetrieveByOffsets(ctx interface{}, plan interface{}) *MockCSegment_RetrieveByOffsets_Call { + return &MockCSegment_RetrieveByOffsets_Call{Call: _e.mock.On("RetrieveByOffsets", ctx, plan)} +} + +func (_c *MockCSegment_RetrieveByOffsets_Call) Run(run func(ctx context.Context, plan *segcore.RetrievePlanWithOffsets)) *MockCSegment_RetrieveByOffsets_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*segcore.RetrievePlanWithOffsets)) + }) + return _c +} + +func (_c *MockCSegment_RetrieveByOffsets_Call) Return(_a0 *segcore.RetrieveResult, _a1 error) *MockCSegment_RetrieveByOffsets_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockCSegment_RetrieveByOffsets_Call) RunAndReturn(run func(context.Context, *segcore.RetrievePlanWithOffsets) (*segcore.RetrieveResult, error)) *MockCSegment_RetrieveByOffsets_Call { + _c.Call.Return(run) + return _c +} + +// RowNum provides a mock function with given fields: +func (_m *MockCSegment) RowNum() int64 { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for RowNum") + } + + var r0 int64 + if rf, ok := ret.Get(0).(func() int64); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int64) + } + + return r0 +} + +// MockCSegment_RowNum_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RowNum' +type MockCSegment_RowNum_Call struct { + *mock.Call +} + +// RowNum is a helper method to define mock.On call +func (_e *MockCSegment_Expecter) RowNum() *MockCSegment_RowNum_Call { + return &MockCSegment_RowNum_Call{Call: _e.mock.On("RowNum")} +} + +func (_c *MockCSegment_RowNum_Call) Run(run func()) *MockCSegment_RowNum_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockCSegment_RowNum_Call) Return(_a0 int64) *MockCSegment_RowNum_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockCSegment_RowNum_Call) RunAndReturn(run func() int64) *MockCSegment_RowNum_Call { + _c.Call.Return(run) + return _c +} + +// Search provides a mock function with given fields: ctx, searchReq +func (_m *MockCSegment) Search(ctx context.Context, searchReq *segcore.SearchRequest) (*segcore.SearchResult, error) { + ret := _m.Called(ctx, searchReq) + + if len(ret) == 0 { + panic("no return value specified for Search") + } + + var r0 *segcore.SearchResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *segcore.SearchRequest) (*segcore.SearchResult, error)); ok { + return rf(ctx, searchReq) + } + if rf, ok := ret.Get(0).(func(context.Context, *segcore.SearchRequest) *segcore.SearchResult); ok { + r0 = rf(ctx, searchReq) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*segcore.SearchResult) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *segcore.SearchRequest) error); ok { + r1 = rf(ctx, searchReq) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockCSegment_Search_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Search' +type MockCSegment_Search_Call struct { + *mock.Call +} + +// Search is a helper method to define mock.On call +// - ctx context.Context +// - searchReq *segcore.SearchRequest +func (_e *MockCSegment_Expecter) Search(ctx interface{}, searchReq interface{}) *MockCSegment_Search_Call { + return &MockCSegment_Search_Call{Call: _e.mock.On("Search", ctx, searchReq)} +} + +func (_c *MockCSegment_Search_Call) Run(run func(ctx context.Context, searchReq 
*segcore.SearchRequest)) *MockCSegment_Search_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*segcore.SearchRequest)) + }) + return _c +} + +func (_c *MockCSegment_Search_Call) Return(_a0 *segcore.SearchResult, _a1 error) *MockCSegment_Search_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockCSegment_Search_Call) RunAndReturn(run func(context.Context, *segcore.SearchRequest) (*segcore.SearchResult, error)) *MockCSegment_Search_Call { + _c.Call.Return(run) + return _c +} + +// NewMockCSegment creates a new instance of MockCSegment. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockCSegment(t interface { + mock.TestingT + Cleanup(func()) +}) *MockCSegment { + mock := &MockCSegment{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/querynodev2/segments/mock_data.go b/internal/mocks/util/mock_segcore/mock_data.go similarity index 92% rename from internal/querynodev2/segments/mock_data.go rename to internal/mocks/util/mock_segcore/mock_data.go index 1f734c6fccb35..95257bc987cca 100644 --- a/internal/querynodev2/segments/mock_data.go +++ b/internal/mocks/util/mock_segcore/mock_data.go @@ -14,7 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package segments +package mock_segcore import ( "context" @@ -43,6 +43,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/segcorepb" storage "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/indexcgowrapper" + "github.com/milvus-io/milvus/internal/util/segcore" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/mq/msgstream" @@ -78,7 +79,7 @@ const ( rowIDFieldID = 0 timestampFieldID = 1 metricTypeKey = common.MetricTypeKey - defaultDim = 128 + DefaultDim = 128 defaultMetricType = metric.L2 dimKey = common.DimKey @@ -89,7 +90,7 @@ const ( // ---------- unittest util functions ---------- // gen collection schema for type vecFieldParam struct { - id int64 + ID int64 dim int metricType string vecType schemapb.DataType @@ -97,125 +98,125 @@ type vecFieldParam struct { } type constFieldParam struct { - id int64 + ID int64 dataType schemapb.DataType fieldName string } -var simpleFloatVecField = vecFieldParam{ - id: 100, - dim: defaultDim, +var SimpleFloatVecField = vecFieldParam{ + ID: 100, + dim: DefaultDim, metricType: defaultMetricType, vecType: schemapb.DataType_FloatVector, fieldName: "floatVectorField", } var simpleBinVecField = vecFieldParam{ - id: 101, - dim: defaultDim, + ID: 101, + dim: DefaultDim, metricType: metric.JACCARD, vecType: schemapb.DataType_BinaryVector, fieldName: "binVectorField", } var simpleFloat16VecField = vecFieldParam{ - id: 112, - dim: defaultDim, + ID: 112, + dim: DefaultDim, metricType: defaultMetricType, vecType: schemapb.DataType_Float16Vector, fieldName: "float16VectorField", } var simpleBFloat16VecField = vecFieldParam{ - id: 113, - dim: defaultDim, + ID: 113, + dim: DefaultDim, metricType: defaultMetricType, vecType: schemapb.DataType_BFloat16Vector, fieldName: "bfloat16VectorField", } -var simpleSparseFloatVectorField = vecFieldParam{ - id: 114, +var SimpleSparseFloatVectorField = vecFieldParam{ + ID: 114, metricType: metric.IP, vecType: schemapb.DataType_SparseFloatVector, fieldName: 
"sparseFloatVectorField", } var simpleBoolField = constFieldParam{ - id: 102, + ID: 102, dataType: schemapb.DataType_Bool, fieldName: "boolField", } var simpleInt8Field = constFieldParam{ - id: 103, + ID: 103, dataType: schemapb.DataType_Int8, fieldName: "int8Field", } var simpleInt16Field = constFieldParam{ - id: 104, + ID: 104, dataType: schemapb.DataType_Int16, fieldName: "int16Field", } var simpleInt32Field = constFieldParam{ - id: 105, + ID: 105, dataType: schemapb.DataType_Int32, fieldName: "int32Field", } var simpleInt64Field = constFieldParam{ - id: 106, + ID: 106, dataType: schemapb.DataType_Int64, fieldName: "int64Field", } var simpleFloatField = constFieldParam{ - id: 107, + ID: 107, dataType: schemapb.DataType_Float, fieldName: "floatField", } var simpleDoubleField = constFieldParam{ - id: 108, + ID: 108, dataType: schemapb.DataType_Double, fieldName: "doubleField", } var simpleJSONField = constFieldParam{ - id: 109, + ID: 109, dataType: schemapb.DataType_JSON, fieldName: "jsonField", } var simpleArrayField = constFieldParam{ - id: 110, + ID: 110, dataType: schemapb.DataType_Array, fieldName: "arrayField", } var simpleVarCharField = constFieldParam{ - id: 111, + ID: 111, dataType: schemapb.DataType_VarChar, fieldName: "varCharField", } -var rowIDField = constFieldParam{ - id: rowIDFieldID, +var RowIDField = constFieldParam{ + ID: rowIDFieldID, dataType: schemapb.DataType_Int64, fieldName: "RowID", } var timestampField = constFieldParam{ - id: timestampFieldID, + ID: timestampFieldID, dataType: schemapb.DataType_Int64, fieldName: "Timestamp", } func genConstantFieldSchema(param constFieldParam) *schemapb.FieldSchema { field := &schemapb.FieldSchema{ - FieldID: param.id, + FieldID: param.ID, Name: param.fieldName, IsPrimaryKey: false, DataType: param.dataType, @@ -231,7 +232,7 @@ func genConstantFieldSchema(param constFieldParam) *schemapb.FieldSchema { func genPKFieldSchema(param constFieldParam) *schemapb.FieldSchema { field := &schemapb.FieldSchema{ - FieldID: param.id, + FieldID: param.ID, Name: param.fieldName, IsPrimaryKey: true, DataType: param.dataType, @@ -247,7 +248,7 @@ func genPKFieldSchema(param constFieldParam) *schemapb.FieldSchema { func genVectorFieldSchema(param vecFieldParam) *schemapb.FieldSchema { fieldVec := &schemapb.FieldSchema{ - FieldID: param.id, + FieldID: param.ID, Name: param.fieldName, IsPrimaryKey: false, DataType: param.vecType, @@ -270,11 +271,11 @@ func genVectorFieldSchema(param vecFieldParam) *schemapb.FieldSchema { } func GenTestBM25CollectionSchema(collectionName string) *schemapb.CollectionSchema { - fieldRowID := genConstantFieldSchema(rowIDField) + fieldRowID := genConstantFieldSchema(RowIDField) fieldTimestamp := genConstantFieldSchema(timestampField) pkFieldSchema := genPKFieldSchema(simpleInt64Field) textFieldSchema := genConstantFieldSchema(simpleVarCharField) - sparseFieldSchema := genVectorFieldSchema(simpleSparseFloatVectorField) + sparseFieldSchema := genVectorFieldSchema(SimpleSparseFloatVectorField) sparseFieldSchema.IsFunctionOutput = true schema := &schemapb.CollectionSchema{ @@ -301,7 +302,7 @@ func GenTestBM25CollectionSchema(collectionName string) *schemapb.CollectionSche // some tests do not yet support sparse float vector, see comments of // GenSparseFloatVecDataset in indexcgowrapper/dataset.go func GenTestCollectionSchema(collectionName string, pkType schemapb.DataType, withSparse bool) *schemapb.CollectionSchema { - fieldRowID := genConstantFieldSchema(rowIDField) + fieldRowID := genConstantFieldSchema(RowIDField) 
fieldTimestamp := genConstantFieldSchema(timestampField) fieldBool := genConstantFieldSchema(simpleBoolField) fieldInt8 := genConstantFieldSchema(simpleInt8Field) @@ -312,7 +313,7 @@ func GenTestCollectionSchema(collectionName string, pkType schemapb.DataType, wi // fieldArray := genConstantFieldSchema(simpleArrayField) fieldJSON := genConstantFieldSchema(simpleJSONField) fieldArray := genConstantFieldSchema(simpleArrayField) - floatVecFieldSchema := genVectorFieldSchema(simpleFloatVecField) + floatVecFieldSchema := genVectorFieldSchema(SimpleFloatVecField) binVecFieldSchema := genVectorFieldSchema(simpleBinVecField) float16VecFieldSchema := genVectorFieldSchema(simpleFloat16VecField) bfloat16VecFieldSchema := genVectorFieldSchema(simpleBFloat16VecField) @@ -346,7 +347,7 @@ func GenTestCollectionSchema(collectionName string, pkType schemapb.DataType, wi } if withSparse { - schema.Fields = append(schema.Fields, genVectorFieldSchema(simpleSparseFloatVectorField)) + schema.Fields = append(schema.Fields, genVectorFieldSchema(SimpleSparseFloatVectorField)) } for i, field := range schema.GetFields() { @@ -477,7 +478,7 @@ func SaveBinLog(ctx context.Context, return nil, nil, err } - k := JoinIDPath(collectionID, partitionID, segmentID, fieldID) + k := metautil.JoinIDPath(collectionID, partitionID, segmentID, fieldID) key := path.Join(chunkManager.RootPath(), "insert-log", k) kvs[key] = blob.Value fieldBinlog = append(fieldBinlog, &datapb.FieldBinlog{ @@ -499,7 +500,7 @@ func SaveBinLog(ctx context.Context, return nil, nil, err } - k := JoinIDPath(collectionID, partitionID, segmentID, fieldID) + k := metautil.JoinIDPath(collectionID, partitionID, segmentID, fieldID) key := path.Join(chunkManager.RootPath(), "stats-log", k) kvs[key] = blob.Value statsBinlog = append(statsBinlog, &datapb.FieldBinlog{ @@ -597,7 +598,7 @@ func genInsertData(msgLength int, schema *schemapb.CollectionSchema) (*storage.I Data: testutils.GenerateJSONArray(msgLength), } case schemapb.DataType_FloatVector: - dim := simpleFloatVecField.dim // if no dim specified, use simpleFloatVecField's dim + dim := SimpleFloatVecField.dim // if no dim specified, use simpleFloatVecField's dim insertData.Data[f.FieldID] = &storage.FloatVectorFieldData{ Data: testutils.GenerateFloatVectors(msgLength, dim), Dim: dim, @@ -689,7 +690,7 @@ func SaveDeltaLog(collectionID int64, pkFieldID := int64(106) fieldBinlog := make([]*datapb.FieldBinlog, 0) log.Debug("[query node unittest] save delta log", zap.Int64("fieldID", pkFieldID)) - key := JoinIDPath(collectionID, partitionID, segmentID, pkFieldID) + key := metautil.JoinIDPath(collectionID, partitionID, segmentID, pkFieldID) // keyPath := path.Join(defaultLocalStorage, "delta-log", key) keyPath := path.Join(cm.RootPath(), "delta-log", key) kvs[keyPath] = blob.Value @@ -750,13 +751,13 @@ func GenAndSaveIndexV2(collectionID, partitionID, segmentID, buildID int64, var dataset *indexcgowrapper.Dataset switch fieldSchema.DataType { case schemapb.DataType_BinaryVector: - dataset = indexcgowrapper.GenBinaryVecDataset(testutils.GenerateBinaryVectors(msgLength, defaultDim)) + dataset = indexcgowrapper.GenBinaryVecDataset(testutils.GenerateBinaryVectors(msgLength, DefaultDim)) case schemapb.DataType_FloatVector: - dataset = indexcgowrapper.GenFloatVecDataset(testutils.GenerateFloatVectors(msgLength, defaultDim)) + dataset = indexcgowrapper.GenFloatVecDataset(testutils.GenerateFloatVectors(msgLength, DefaultDim)) case schemapb.DataType_Float16Vector: - dataset = 
indexcgowrapper.GenFloat16VecDataset(testutils.GenerateFloat16Vectors(msgLength, defaultDim)) + dataset = indexcgowrapper.GenFloat16VecDataset(testutils.GenerateFloat16Vectors(msgLength, DefaultDim)) case schemapb.DataType_BFloat16Vector: - dataset = indexcgowrapper.GenBFloat16VecDataset(testutils.GenerateBFloat16Vectors(msgLength, defaultDim)) + dataset = indexcgowrapper.GenBFloat16VecDataset(testutils.GenerateBFloat16Vectors(msgLength, DefaultDim)) case schemapb.DataType_SparseFloatVector: contents, dim := testutils.GenerateSparseFloatVectorsData(msgLength) dataset = indexcgowrapper.GenSparseFloatVecDataset(&storage.SparseFloatVectorFieldData{ @@ -806,14 +807,14 @@ func GenAndSaveIndexV2(collectionID, partitionID, segmentID, buildID int64, return nil, err } } - _, cCurrentIndexVersion := getIndexEngineVersion() + indexVersion := segcore.GetIndexEngineInfo() return &querypb.FieldIndexInfo{ FieldID: fieldSchema.GetFieldID(), IndexName: indexInfo.GetIndexName(), IndexParams: indexInfo.GetIndexParams(), IndexFilePaths: indexPaths, - CurrentIndexVersion: cCurrentIndexVersion, + CurrentIndexVersion: indexVersion.CurrentIndexVersion, }, nil } @@ -826,7 +827,7 @@ func GenAndSaveIndex(collectionID, partitionID, segmentID, fieldID int64, msgLen } defer index.Delete() - err = index.Build(indexcgowrapper.GenFloatVecDataset(testutils.GenerateFloatVectors(msgLength, defaultDim))) + err = index.Build(indexcgowrapper.GenFloatVecDataset(testutils.GenerateFloatVectors(msgLength, DefaultDim))) if err != nil { return nil, err } @@ -845,7 +846,7 @@ func GenAndSaveIndex(collectionID, partitionID, segmentID, fieldID int64, msgLen collectionID, partitionID, segmentID, - simpleFloatVecField.id, + SimpleFloatVecField.ID, indexParams, "querynode-test", 0, @@ -866,20 +867,20 @@ func GenAndSaveIndex(collectionID, partitionID, segmentID, fieldID int64, msgLen return nil, err } } - _, cCurrentIndexVersion := getIndexEngineVersion() + indexEngineInfo := segcore.GetIndexEngineInfo() return &querypb.FieldIndexInfo{ FieldID: fieldID, IndexName: "querynode-test", IndexParams: funcutil.Map2KeyValuePair(indexParams), IndexFilePaths: indexPaths, - CurrentIndexVersion: cCurrentIndexVersion, + CurrentIndexVersion: indexEngineInfo.CurrentIndexVersion, }, nil } func genIndexParams(indexType, metricType string) (map[string]string, map[string]string) { typeParams := make(map[string]string) - typeParams[common.DimKey] = strconv.Itoa(defaultDim) + typeParams[common.DimKey] = strconv.Itoa(DefaultDim) indexParams := make(map[string]string) indexParams[common.IndexTypeKey] = indexType @@ -927,7 +928,7 @@ func genStorageConfig() *indexpb.StorageConfig { } } -func genSearchRequest(nq int64, indexType string, collection *Collection) (*internalpb.SearchRequest, error) { +func genSearchRequest(nq int64, indexType string, collection *segcore.CCollection) (*internalpb.SearchRequest, error) { placeHolder, err := genPlaceHolderGroup(nq) if err != nil { return nil, err @@ -946,7 +947,6 @@ func genSearchRequest(nq int64, indexType string, collection *Collection) (*inte return &internalpb.SearchRequest{ Base: genCommonMsgBase(commonpb.MsgType_Search, 0), CollectionID: collection.ID(), - PartitionIDs: collection.GetPartitions(), PlaceholderGroup: placeHolder, SerializedExprPlan: serializedPlan, DslType: commonpb.DslType_BoolExprV1, @@ -969,8 +969,8 @@ func genPlaceHolderGroup(nq int64) ([]byte, error) { Values: make([][]byte, 0), } for i := int64(0); i < nq; i++ { - vec := make([]float32, defaultDim) - for j := 0; j < defaultDim; j++ { + vec := 
make([]float32, DefaultDim) + for j := 0; j < DefaultDim; j++ { vec[j] = rand.Float32() } var rawData []byte @@ -1070,22 +1070,22 @@ func genHNSWDSL(schema *schemapb.CollectionSchema, ef int, topK int64, roundDeci >`, nil } -func checkSearchResult(ctx context.Context, nq int64, plan *SearchPlan, searchResult *SearchResult) error { - searchResults := make([]*SearchResult, 0) +func CheckSearchResult(ctx context.Context, nq int64, plan *segcore.SearchPlan, searchResult *segcore.SearchResult) error { + searchResults := make([]*segcore.SearchResult, 0) searchResults = append(searchResults, searchResult) - topK := plan.getTopK() + topK := plan.GetTopK() sliceNQs := []int64{nq / 5, nq / 5, nq / 5, nq / 5, nq / 5} sliceTopKs := []int64{topK, topK / 2, topK, topK, topK / 2} - sInfo := ParseSliceInfo(sliceNQs, sliceTopKs, nq) + sInfo := segcore.ParseSliceInfo(sliceNQs, sliceTopKs, nq) - res, err := ReduceSearchResultsAndFillData(ctx, plan, searchResults, 1, sInfo.SliceNQs, sInfo.SliceTopKs) + res, err := segcore.ReduceSearchResultsAndFillData(ctx, plan, searchResults, 1, sInfo.SliceNQs, sInfo.SliceTopKs) if err != nil { return err } for i := 0; i < len(sInfo.SliceNQs); i++ { - blob, err := GetSearchResultDataBlob(ctx, res, i) + blob, err := segcore.GetSearchResultDataBlob(ctx, res, i) if err != nil { return err } @@ -1114,12 +1114,14 @@ func checkSearchResult(ctx context.Context, nq int64, plan *SearchPlan, searchRe } } - DeleteSearchResults(searchResults) - DeleteSearchResultDataBlobs(res) + for _, searchResult := range searchResults { + searchResult.Release() + } + segcore.DeleteSearchResultDataBlobs(res) return nil } -func genSearchPlanAndRequests(collection *Collection, segments []int64, indexType string, nq int64) (*SearchRequest, error) { +func GenSearchPlanAndRequests(collection *segcore.CCollection, segments []int64, indexType string, nq int64) (*segcore.SearchRequest, error) { iReq, _ := genSearchRequest(nq, indexType, collection) queryReq := &querypb.SearchRequest{ Req: iReq, @@ -1127,10 +1129,10 @@ func genSearchPlanAndRequests(collection *Collection, segments []int64, indexTyp SegmentIDs: segments, Scope: querypb.DataScope_Historical, } - return NewSearchRequest(context.Background(), collection, queryReq, queryReq.Req.GetPlaceholderGroup()) + return segcore.NewSearchRequest(collection, queryReq, queryReq.Req.GetPlaceholderGroup()) } -func genInsertMsg(collection *Collection, partitionID, segment int64, numRows int) (*msgstream.InsertMsg, error) { +func GenInsertMsg(collection *segcore.CCollection, partitionID, segment int64, numRows int) (*msgstream.InsertMsg, error) { fieldsData := make([]*schemapb.FieldData, 0) for _, f := range collection.Schema().Fields { @@ -1156,7 +1158,7 @@ func genInsertMsg(collection *Collection, partitionID, segment int64, numRows in case schemapb.DataType_JSON: fieldsData = append(fieldsData, testutils.GenerateScalarFieldDataWithID(f.DataType, simpleJSONField.fieldName, f.GetFieldID(), numRows)) case schemapb.DataType_FloatVector: - dim := simpleFloatVecField.dim // if no dim specified, use simpleFloatVecField's dim + dim := SimpleFloatVecField.dim // if no dim specified, use simpleFloatVecField's dim fieldsData = append(fieldsData, testutils.GenerateVectorFieldDataWithID(f.DataType, f.Name, f.FieldID, numRows, dim)) case schemapb.DataType_BinaryVector: dim := simpleBinVecField.dim // if no dim specified, use simpleFloatVecField's dim @@ -1227,14 +1229,14 @@ func genSimpleRowIDField(numRows int) []int64 { return ids } -func genSimpleRetrievePlan(collection 
*Collection) (*RetrievePlan, error) { +func GenSimpleRetrievePlan(collection *segcore.CCollection) (*segcore.RetrievePlan, error) { timestamp := storage.Timestamp(1000) - planBytes, err := genSimpleRetrievePlanExpr(collection.schema.Load()) + planBytes, err := genSimpleRetrievePlanExpr(collection.Schema()) if err != nil { return nil, err } - plan, err2 := NewRetrievePlan(context.Background(), collection, planBytes, timestamp, 100) + plan, err2 := segcore.NewRetrievePlan(collection, planBytes, timestamp, 100) return plan, err2 } @@ -1279,14 +1281,14 @@ func genSimpleRetrievePlanExpr(schema *schemapb.CollectionSchema) ([]byte, error return planExpr, err } -func genFieldData(fieldName string, fieldID int64, fieldType schemapb.DataType, fieldValue interface{}, dim int64) *schemapb.FieldData { +func GenFieldData(fieldName string, fieldID int64, fieldType schemapb.DataType, fieldValue interface{}, dim int64) *schemapb.FieldData { if fieldType < 100 { return testutils.GenerateScalarFieldDataWithValue(fieldType, fieldName, fieldID, fieldValue) } return testutils.GenerateVectorFieldDataWithValue(fieldType, fieldName, fieldID, fieldValue, int(dim)) } -func genSearchResultData(nq int64, topk int64, ids []int64, scores []float32, topks []int64) *schemapb.SearchResultData { +func GenSearchResultData(nq int64, topk int64, ids []int64, scores []float32, topks []int64) *schemapb.SearchResultData { return &schemapb.SearchResultData{ NumQueries: 1, TopK: topk, diff --git a/internal/util/searchutil/optimizers/mock_query_hook.go b/internal/mocks/util/searchutil/mock_optimizers/mock_QueryHook.go similarity index 99% rename from internal/util/searchutil/optimizers/mock_query_hook.go rename to internal/mocks/util/searchutil/mock_optimizers/mock_QueryHook.go index 6b084d7098fe4..5521418b1141b 100644 --- a/internal/util/searchutil/optimizers/mock_query_hook.go +++ b/internal/mocks/util/searchutil/mock_optimizers/mock_QueryHook.go @@ -1,6 +1,6 @@ // Code generated by mockery v2.46.0. DO NOT EDIT. 
-package optimizers +package mock_optimizers import mock "github.com/stretchr/testify/mock" diff --git a/internal/querynodev2/delegator/delegator_data_test.go b/internal/querynodev2/delegator/delegator_data_test.go index 7463c64da11a0..7fc98c8c7e944 100644 --- a/internal/querynodev2/delegator/delegator_data_test.go +++ b/internal/querynodev2/delegator/delegator_data_test.go @@ -36,6 +36,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/mocks/util/mock_segcore" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/planpb" @@ -1519,7 +1520,7 @@ func (s *DelegatorDataSuite) TestLevel0Deletions() { err = allPartitionDeleteData.Append(storage.NewInt64PrimaryKey(2), 101) s.Require().NoError(err) - schema := segments.GenTestCollectionSchema("test_stop", schemapb.DataType_Int64, true) + schema := mock_segcore.GenTestCollectionSchema("test_stop", schemapb.DataType_Int64, true) collection := segments.NewCollection(1, schema, nil, &querypb.LoadMetaInfo{ LoadType: querypb.LoadType_LoadCollection, }) diff --git a/internal/querynodev2/local_worker_test.go b/internal/querynodev2/local_worker_test.go index 724387e137fcb..50ed8b4c1d965 100644 --- a/internal/querynodev2/local_worker_test.go +++ b/internal/querynodev2/local_worker_test.go @@ -27,6 +27,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/mocks/util/mock_segcore" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/segcorepb" @@ -93,8 +94,8 @@ func (suite *LocalWorkerTestSuite) BeforeTest(suiteName, testName string) { err = suite.node.Start() suite.NoError(err) - suite.schema = segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, true) - suite.indexMeta = segments.GenTestIndexMeta(suite.collectionID, suite.schema) + suite.schema = mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, true) + suite.indexMeta = mock_segcore.GenTestIndexMeta(suite.collectionID, suite.schema) collection := segments.NewCollection(suite.collectionID, suite.schema, suite.indexMeta, &querypb.LoadMetaInfo{ LoadType: querypb.LoadType_LoadCollection, }) @@ -114,7 +115,7 @@ func (suite *LocalWorkerTestSuite) AfterTest(suiteName, testName string) { func (suite *LocalWorkerTestSuite) TestLoadSegment() { // load empty - schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, true) + schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, true) req := &querypb.LoadSegmentsRequest{ Base: &commonpb.MsgBase{ TargetID: suite.node.session.GetServerID(), diff --git a/internal/querynodev2/pipeline/insert_node_test.go b/internal/querynodev2/pipeline/insert_node_test.go index 65bf17240f339..c1380db2c28d7 100644 --- a/internal/querynodev2/pipeline/insert_node_test.go +++ b/internal/querynodev2/pipeline/insert_node_test.go @@ -24,6 +24,7 @@ import ( "github.com/stretchr/testify/suite" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/mocks/util/mock_segcore" "github.com/milvus-io/milvus/internal/proto/querypb" 
"github.com/milvus-io/milvus/internal/querynodev2/delegator" "github.com/milvus-io/milvus/internal/querynodev2/segments" @@ -58,10 +59,10 @@ func (suite *InsertNodeSuite) SetupSuite() { func (suite *InsertNodeSuite) TestBasic() { // data - schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, true) + schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, true) in := suite.buildInsertNodeMsg(schema) - collection := segments.NewCollection(suite.collectionID, schema, segments.GenTestIndexMeta(suite.collectionID, schema), &querypb.LoadMetaInfo{ + collection := segments.NewCollection(suite.collectionID, schema, mock_segcore.GenTestIndexMeta(suite.collectionID, schema), &querypb.LoadMetaInfo{ LoadType: querypb.LoadType_LoadCollection, }) collection.AddPartition(suite.partitionID) @@ -94,10 +95,10 @@ func (suite *InsertNodeSuite) TestBasic() { } func (suite *InsertNodeSuite) TestDataTypeNotSupported() { - schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, true) + schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, true) in := suite.buildInsertNodeMsg(schema) - collection := segments.NewCollection(suite.collectionID, schema, segments.GenTestIndexMeta(suite.collectionID, schema), &querypb.LoadMetaInfo{ + collection := segments.NewCollection(suite.collectionID, schema, mock_segcore.GenTestIndexMeta(suite.collectionID, schema), &querypb.LoadMetaInfo{ LoadType: querypb.LoadType_LoadCollection, }) collection.AddPartition(suite.partitionID) diff --git a/internal/querynodev2/pipeline/pipeline_test.go b/internal/querynodev2/pipeline/pipeline_test.go index 8e2edde82016b..2539e690beb2a 100644 --- a/internal/querynodev2/pipeline/pipeline_test.go +++ b/internal/querynodev2/pipeline/pipeline_test.go @@ -26,6 +26,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/mocks/util/mock_segcore" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querynodev2/delegator" "github.com/milvus-io/milvus/internal/querynodev2/segments" @@ -108,8 +109,8 @@ func (suite *PipelineTestSuite) SetupTest() { func (suite *PipelineTestSuite) TestBasic() { // init mock // mock collection manager - schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, true) - collection := segments.NewCollection(suite.collectionID, schema, segments.GenTestIndexMeta(suite.collectionID, schema), &querypb.LoadMetaInfo{ + schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, true) + collection := segments.NewCollection(suite.collectionID, schema, mock_segcore.GenTestIndexMeta(suite.collectionID, schema), &querypb.LoadMetaInfo{ LoadType: querypb.LoadType_LoadCollection, }) suite.collectionManager.EXPECT().Get(suite.collectionID).Return(collection) diff --git a/internal/querynodev2/segments/cgo_util.go b/internal/querynodev2/segments/cgo_util.go index b127a19909888..222e8ca95a4fa 100644 --- a/internal/querynodev2/segments/cgo_util.go +++ b/internal/querynodev2/segments/cgo_util.go @@ -28,13 +28,10 @@ import "C" import ( "context" - "math" "unsafe" "go.uber.org/zap" - "google.golang.org/protobuf/proto" - "github.com/milvus-io/milvus/internal/util/cgoconverter" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/merr" ) @@ -55,38 +52,3 @@ func HandleCStatus(ctx 
context.Context, status *C.CStatus, extraInfo string, fie log.Warn("CStatus returns err", zap.Error(err), zap.String("extra", extraInfo)) return err } - -// UnmarshalCProto unmarshal the proto from C memory -func UnmarshalCProto(cRes *C.CProto, msg proto.Message) error { - blob := (*(*[math.MaxInt32]byte)(cRes.proto_blob))[:int(cRes.proto_size):int(cRes.proto_size)] - return proto.Unmarshal(blob, msg) -} - -// CopyCProtoBlob returns the copy of C memory -func CopyCProtoBlob(cProto *C.CProto) []byte { - blob := C.GoBytes(cProto.proto_blob, C.int32_t(cProto.proto_size)) - C.free(cProto.proto_blob) - return blob -} - -// GetCProtoBlob returns the raw C memory, invoker should release it itself -func GetCProtoBlob(cProto *C.CProto) []byte { - lease, blob := cgoconverter.UnsafeGoBytes(&cProto.proto_blob, int(cProto.proto_size)) - cgoconverter.Extract(lease) - return blob -} - -func GetLocalUsedSize(ctx context.Context, path string) (int64, error) { - var availableSize int64 - cSize := (*C.int64_t)(&availableSize) - cPath := C.CString(path) - defer C.free(unsafe.Pointer(cPath)) - - status := C.GetLocalUsedSize(cPath, cSize) - err := HandleCStatus(ctx, &status, "get local used size failed") - if err != nil { - return 0, err - } - - return availableSize, nil -} diff --git a/internal/querynodev2/segments/collection.go b/internal/querynodev2/segments/collection.go index 561507baa91e9..c80572db00751 100644 --- a/internal/querynodev2/segments/collection.go +++ b/internal/querynodev2/segments/collection.go @@ -16,27 +16,18 @@ package segments -/* -#cgo pkg-config: milvus_core - -#include "segcore/collection_c.h" -#include "segcore/segment_c.h" -*/ -import "C" - import ( "sync" - "unsafe" "github.com/samber/lo" "go.uber.org/atomic" "go.uber.org/zap" - "google.golang.org/protobuf/proto" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/segcorepb" + "github.com/milvus-io/milvus/internal/util/segcore" "github.com/milvus-io/milvus/internal/util/vecindexmgr" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" @@ -145,7 +136,7 @@ func (m *collectionManager) Unref(collectionID int64, count uint32) bool { // In a query node, `Collection` is a replica info of a collection in these query node. type Collection struct { mu sync.RWMutex // protects colllectionPtr - collectionPtr C.CCollection + ccollection *segcore.CCollection id int64 partitions *typeutil.ConcurrentSet[int64] loadType querypb.LoadType @@ -178,6 +169,11 @@ func (c *Collection) ID() int64 { return c.id } +// GetCCollection returns the CCollection of collection +func (c *Collection) GetCCollection() *segcore.CCollection { + return c.ccollection +} + // Schema returns the schema of collection func (c *Collection) Schema() *schemapb.CollectionSchema { return c.schema.Load() @@ -254,23 +250,12 @@ func NewCollection(collectionID int64, schema *schemapb.CollectionSchema, indexM loadFieldIDs = typeutil.NewSet(lo.Map(loadSchema.GetFields(), func(field *schemapb.FieldSchema, _ int) int64 { return field.GetFieldID() })...) 
} - schemaBlob, err := proto.Marshal(loadSchema) - if err != nil { - log.Warn("marshal schema failed", zap.Error(err)) - return nil - } - - collection := C.NewCollection(unsafe.Pointer(&schemaBlob[0]), (C.int64_t)(len(schemaBlob))) - isGpuIndex := false + req := &segcore.CreateCCollectionRequest{ + Schema: loadSchema, + } if indexMeta != nil && len(indexMeta.GetIndexMetas()) > 0 && indexMeta.GetMaxIndexRowCount() > 0 { - indexMetaBlob, err := proto.Marshal(indexMeta) - if err != nil { - log.Warn("marshal index meta failed", zap.Error(err)) - return nil - } - C.SetIndexMeta(collection, unsafe.Pointer(&indexMetaBlob[0]), (C.int64_t)(len(indexMetaBlob))) - + req.IndexMeta = indexMeta for _, indexMeta := range indexMeta.GetIndexMetas() { isGpuIndex = lo.ContainsBy(indexMeta.GetIndexParams(), func(param *commonpb.KeyValuePair) bool { return param.Key == common.IndexTypeKey && vecindexmgr.GetVecIndexMgrInstance().IsGPUVecIndex(param.Value) @@ -281,8 +266,13 @@ func NewCollection(collectionID int64, schema *schemapb.CollectionSchema, indexM } } + ccollection, err := segcore.CreateCCollection(req) + if err != nil { + log.Warn("create collection failed", zap.Error(err)) + return nil + } coll := &Collection{ - collectionPtr: collection, + ccollection: ccollection, id: collectionID, partitions: typeutil.NewConcurrentSet[int64](), loadType: loadMetaInfo.GetLoadType(), @@ -330,10 +320,9 @@ func DeleteCollection(collection *Collection) { collection.mu.Lock() defer collection.mu.Unlock() - cPtr := collection.collectionPtr - if cPtr != nil { - C.DeleteCollection(cPtr) + if collection.ccollection == nil { + return } - - collection.collectionPtr = nil + collection.ccollection.Release() + collection.ccollection = nil } diff --git a/internal/querynodev2/segments/count_reducer.go b/internal/querynodev2/segments/count_reducer.go index 3cf5367ea4acc..ae484dd54fe14 100644 --- a/internal/querynodev2/segments/count_reducer.go +++ b/internal/querynodev2/segments/count_reducer.go @@ -6,6 +6,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/segcorepb" "github.com/milvus-io/milvus/internal/util/funcutil" + "github.com/milvus-io/milvus/internal/util/segcore" ) type cntReducer struct{} @@ -33,7 +34,7 @@ func (r *cntReducer) Reduce(ctx context.Context, results []*internalpb.RetrieveR type cntReducerSegCore struct{} -func (r *cntReducerSegCore) Reduce(ctx context.Context, results []*segcorepb.RetrieveResults, _ []Segment, _ *RetrievePlan) (*segcorepb.RetrieveResults, error) { +func (r *cntReducerSegCore) Reduce(ctx context.Context, results []*segcorepb.RetrieveResults, _ []Segment, _ *segcore.RetrievePlan) (*segcorepb.RetrieveResults, error) { cnt := int64(0) allRetrieveCount := int64(0) for _, res := range results { diff --git a/internal/querynodev2/segments/manager_test.go b/internal/querynodev2/segments/manager_test.go index a5f4bd668a304..7fc0851e16dfb 100644 --- a/internal/querynodev2/segments/manager_test.go +++ b/internal/querynodev2/segments/manager_test.go @@ -10,6 +10,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/mocks/util/mock_segcore" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/util/initcore" @@ -50,10 +51,10 @@ func (s *ManagerSuite) SetupTest() { s.segments = nil for i, id := range s.segmentIDs { - schema := 
GenTestCollectionSchema("manager-suite", schemapb.DataType_Int64, true) + schema := mock_segcore.GenTestCollectionSchema("manager-suite", schemapb.DataType_Int64, true) segment, err := NewSegment( context.Background(), - NewCollection(s.collectionIDs[i], schema, GenTestIndexMeta(s.collectionIDs[i], schema), &querypb.LoadMetaInfo{ + NewCollection(s.collectionIDs[i], schema, mock_segcore.GenTestIndexMeta(s.collectionIDs[i], schema), &querypb.LoadMetaInfo{ LoadType: querypb.LoadType_LoadCollection, }), s.types[i], diff --git a/internal/querynodev2/segments/mock_segment.go b/internal/querynodev2/segments/mock_segment.go index ca19cdb897d80..1af3012ed3038 100644 --- a/internal/querynodev2/segments/mock_segment.go +++ b/internal/querynodev2/segments/mock_segment.go @@ -17,6 +17,8 @@ import ( querypb "github.com/milvus-io/milvus/internal/proto/querypb" + segcore "github.com/milvus-io/milvus/internal/util/segcore" + segcorepb "github.com/milvus-io/milvus/internal/proto/segcorepb" storage "github.com/milvus-io/milvus/internal/storage" @@ -1358,7 +1360,7 @@ func (_c *MockSegment_ResourceUsageEstimate_Call) RunAndReturn(run func() Resour } // Retrieve provides a mock function with given fields: ctx, plan -func (_m *MockSegment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segcorepb.RetrieveResults, error) { +func (_m *MockSegment) Retrieve(ctx context.Context, plan *segcore.RetrievePlan) (*segcorepb.RetrieveResults, error) { ret := _m.Called(ctx, plan) if len(ret) == 0 { @@ -1367,10 +1369,10 @@ func (_m *MockSegment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segco var r0 *segcorepb.RetrieveResults var r1 error - if rf, ok := ret.Get(0).(func(context.Context, *RetrievePlan) (*segcorepb.RetrieveResults, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, *segcore.RetrievePlan) (*segcorepb.RetrieveResults, error)); ok { return rf(ctx, plan) } - if rf, ok := ret.Get(0).(func(context.Context, *RetrievePlan) *segcorepb.RetrieveResults); ok { + if rf, ok := ret.Get(0).(func(context.Context, *segcore.RetrievePlan) *segcorepb.RetrieveResults); ok { r0 = rf(ctx, plan) } else { if ret.Get(0) != nil { @@ -1378,7 +1380,7 @@ func (_m *MockSegment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segco } } - if rf, ok := ret.Get(1).(func(context.Context, *RetrievePlan) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, *segcore.RetrievePlan) error); ok { r1 = rf(ctx, plan) } else { r1 = ret.Error(1) @@ -1394,14 +1396,14 @@ type MockSegment_Retrieve_Call struct { // Retrieve is a helper method to define mock.On call // - ctx context.Context -// - plan *RetrievePlan +// - plan *segcore.RetrievePlan func (_e *MockSegment_Expecter) Retrieve(ctx interface{}, plan interface{}) *MockSegment_Retrieve_Call { return &MockSegment_Retrieve_Call{Call: _e.mock.On("Retrieve", ctx, plan)} } -func (_c *MockSegment_Retrieve_Call) Run(run func(ctx context.Context, plan *RetrievePlan)) *MockSegment_Retrieve_Call { +func (_c *MockSegment_Retrieve_Call) Run(run func(ctx context.Context, plan *segcore.RetrievePlan)) *MockSegment_Retrieve_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(*RetrievePlan)) + run(args[0].(context.Context), args[1].(*segcore.RetrievePlan)) }) return _c } @@ -1411,14 +1413,14 @@ func (_c *MockSegment_Retrieve_Call) Return(_a0 *segcorepb.RetrieveResults, _a1 return _c } -func (_c *MockSegment_Retrieve_Call) RunAndReturn(run func(context.Context, *RetrievePlan) (*segcorepb.RetrieveResults, error)) *MockSegment_Retrieve_Call { +func 
(_c *MockSegment_Retrieve_Call) RunAndReturn(run func(context.Context, *segcore.RetrievePlan) (*segcorepb.RetrieveResults, error)) *MockSegment_Retrieve_Call { _c.Call.Return(run) return _c } -// RetrieveByOffsets provides a mock function with given fields: ctx, plan, offsets -func (_m *MockSegment) RetrieveByOffsets(ctx context.Context, plan *RetrievePlan, offsets []int64) (*segcorepb.RetrieveResults, error) { - ret := _m.Called(ctx, plan, offsets) +// RetrieveByOffsets provides a mock function with given fields: ctx, plan +func (_m *MockSegment) RetrieveByOffsets(ctx context.Context, plan *segcore.RetrievePlanWithOffsets) (*segcorepb.RetrieveResults, error) { + ret := _m.Called(ctx, plan) if len(ret) == 0 { panic("no return value specified for RetrieveByOffsets") @@ -1426,19 +1428,19 @@ func (_m *MockSegment) RetrieveByOffsets(ctx context.Context, plan *RetrievePlan var r0 *segcorepb.RetrieveResults var r1 error - if rf, ok := ret.Get(0).(func(context.Context, *RetrievePlan, []int64) (*segcorepb.RetrieveResults, error)); ok { - return rf(ctx, plan, offsets) + if rf, ok := ret.Get(0).(func(context.Context, *segcore.RetrievePlanWithOffsets) (*segcorepb.RetrieveResults, error)); ok { + return rf(ctx, plan) } - if rf, ok := ret.Get(0).(func(context.Context, *RetrievePlan, []int64) *segcorepb.RetrieveResults); ok { - r0 = rf(ctx, plan, offsets) + if rf, ok := ret.Get(0).(func(context.Context, *segcore.RetrievePlanWithOffsets) *segcorepb.RetrieveResults); ok { + r0 = rf(ctx, plan) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*segcorepb.RetrieveResults) } } - if rf, ok := ret.Get(1).(func(context.Context, *RetrievePlan, []int64) error); ok { - r1 = rf(ctx, plan, offsets) + if rf, ok := ret.Get(1).(func(context.Context, *segcore.RetrievePlanWithOffsets) error); ok { + r1 = rf(ctx, plan) } else { r1 = ret.Error(1) } @@ -1453,15 +1455,14 @@ type MockSegment_RetrieveByOffsets_Call struct { // RetrieveByOffsets is a helper method to define mock.On call // - ctx context.Context -// - plan *RetrievePlan -// - offsets []int64 -func (_e *MockSegment_Expecter) RetrieveByOffsets(ctx interface{}, plan interface{}, offsets interface{}) *MockSegment_RetrieveByOffsets_Call { - return &MockSegment_RetrieveByOffsets_Call{Call: _e.mock.On("RetrieveByOffsets", ctx, plan, offsets)} +// - plan *segcore.RetrievePlanWithOffsets +func (_e *MockSegment_Expecter) RetrieveByOffsets(ctx interface{}, plan interface{}) *MockSegment_RetrieveByOffsets_Call { + return &MockSegment_RetrieveByOffsets_Call{Call: _e.mock.On("RetrieveByOffsets", ctx, plan)} } -func (_c *MockSegment_RetrieveByOffsets_Call) Run(run func(ctx context.Context, plan *RetrievePlan, offsets []int64)) *MockSegment_RetrieveByOffsets_Call { +func (_c *MockSegment_RetrieveByOffsets_Call) Run(run func(ctx context.Context, plan *segcore.RetrievePlanWithOffsets)) *MockSegment_RetrieveByOffsets_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(*RetrievePlan), args[2].([]int64)) + run(args[0].(context.Context), args[1].(*segcore.RetrievePlanWithOffsets)) }) return _c } @@ -1471,7 +1472,7 @@ func (_c *MockSegment_RetrieveByOffsets_Call) Return(_a0 *segcorepb.RetrieveResu return _c } -func (_c *MockSegment_RetrieveByOffsets_Call) RunAndReturn(run func(context.Context, *RetrievePlan, []int64) (*segcorepb.RetrieveResults, error)) *MockSegment_RetrieveByOffsets_Call { +func (_c *MockSegment_RetrieveByOffsets_Call) RunAndReturn(run func(context.Context, *segcore.RetrievePlanWithOffsets) (*segcorepb.RetrieveResults, error)) 
*MockSegment_RetrieveByOffsets_Call { _c.Call.Return(run) return _c } @@ -1522,27 +1523,27 @@ func (_c *MockSegment_RowNum_Call) RunAndReturn(run func() int64) *MockSegment_R } // Search provides a mock function with given fields: ctx, searchReq -func (_m *MockSegment) Search(ctx context.Context, searchReq *SearchRequest) (*SearchResult, error) { +func (_m *MockSegment) Search(ctx context.Context, searchReq *segcore.SearchRequest) (*segcore.SearchResult, error) { ret := _m.Called(ctx, searchReq) if len(ret) == 0 { panic("no return value specified for Search") } - var r0 *SearchResult + var r0 *segcore.SearchResult var r1 error - if rf, ok := ret.Get(0).(func(context.Context, *SearchRequest) (*SearchResult, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, *segcore.SearchRequest) (*segcore.SearchResult, error)); ok { return rf(ctx, searchReq) } - if rf, ok := ret.Get(0).(func(context.Context, *SearchRequest) *SearchResult); ok { + if rf, ok := ret.Get(0).(func(context.Context, *segcore.SearchRequest) *segcore.SearchResult); ok { r0 = rf(ctx, searchReq) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*SearchResult) + r0 = ret.Get(0).(*segcore.SearchResult) } } - if rf, ok := ret.Get(1).(func(context.Context, *SearchRequest) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, *segcore.SearchRequest) error); ok { r1 = rf(ctx, searchReq) } else { r1 = ret.Error(1) @@ -1558,24 +1559,24 @@ type MockSegment_Search_Call struct { // Search is a helper method to define mock.On call // - ctx context.Context -// - searchReq *SearchRequest +// - searchReq *segcore.SearchRequest func (_e *MockSegment_Expecter) Search(ctx interface{}, searchReq interface{}) *MockSegment_Search_Call { return &MockSegment_Search_Call{Call: _e.mock.On("Search", ctx, searchReq)} } -func (_c *MockSegment_Search_Call) Run(run func(ctx context.Context, searchReq *SearchRequest)) *MockSegment_Search_Call { +func (_c *MockSegment_Search_Call) Run(run func(ctx context.Context, searchReq *segcore.SearchRequest)) *MockSegment_Search_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(*SearchRequest)) + run(args[0].(context.Context), args[1].(*segcore.SearchRequest)) }) return _c } -func (_c *MockSegment_Search_Call) Return(_a0 *SearchResult, _a1 error) *MockSegment_Search_Call { +func (_c *MockSegment_Search_Call) Return(_a0 *segcore.SearchResult, _a1 error) *MockSegment_Search_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *MockSegment_Search_Call) RunAndReturn(run func(context.Context, *SearchRequest) (*SearchResult, error)) *MockSegment_Search_Call { +func (_c *MockSegment_Search_Call) RunAndReturn(run func(context.Context, *segcore.SearchRequest) (*segcore.SearchResult, error)) *MockSegment_Search_Call { _c.Call.Return(run) return _c } diff --git a/internal/querynodev2/segments/reducer.go b/internal/querynodev2/segments/reducer.go index d5ef51a7df7b3..8ef7a4102e800 100644 --- a/internal/querynodev2/segments/reducer.go +++ b/internal/querynodev2/segments/reducer.go @@ -9,6 +9,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/segcorepb" + "github.com/milvus-io/milvus/internal/util/segcore" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/typeutil" @@ -26,7 +27,7 @@ func CreateInternalReducer(req *querypb.QueryRequest, schema *schemapb.Collectio } type segCoreReducer interface 
{ - Reduce(context.Context, []*segcorepb.RetrieveResults, []Segment, *RetrievePlan) (*segcorepb.RetrieveResults, error) + Reduce(context.Context, []*segcorepb.RetrieveResults, []Segment, *segcore.RetrievePlan) (*segcorepb.RetrieveResults, error) } func CreateSegCoreReducer(req *querypb.QueryRequest, schema *schemapb.CollectionSchema, manager *Manager) segCoreReducer { diff --git a/internal/querynodev2/segments/result.go b/internal/querynodev2/segments/result.go index c3eaa032cc9fb..6eaaab1ad6715 100644 --- a/internal/querynodev2/segments/result.go +++ b/internal/querynodev2/segments/result.go @@ -30,6 +30,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/segcorepb" "github.com/milvus-io/milvus/internal/util/reduce" + "github.com/milvus-io/milvus/internal/util/segcore" typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" @@ -413,7 +414,7 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore return nil, err } validRetrieveResults = append(validRetrieveResults, tr) - if plan.ignoreNonPk { + if plan.IsIgnoreNonPk() { validSegments = append(validSegments, segments[i]) } loopEnd += size @@ -493,7 +494,7 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore log.Debug("skip duplicated query result while reducing segcore.RetrieveResults", zap.Int64("dupCount", skipDupCnt)) } - if !plan.ignoreNonPk { + if !plan.IsIgnoreNonPk() { // target entry already retrieved, don't do this after AppendPKs for better performance. Save the cost everytime // judge the `!plan.ignoreNonPk` condition. _, span2 := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "MergeSegcoreResults-AppendFieldData") @@ -524,7 +525,10 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore var r *segcorepb.RetrieveResults var err error if err := doOnSegment(ctx, manager, validSegments[idx], func(ctx context.Context, segment Segment) error { - r, err = segment.RetrieveByOffsets(ctx, plan, theOffsets) + r, err = segment.RetrieveByOffsets(ctx, &segcore.RetrievePlanWithOffsets{ + RetrievePlan: plan, + Offsets: theOffsets, + }) return err }); err != nil { return nil, err diff --git a/internal/querynodev2/segments/result_test.go b/internal/querynodev2/segments/result_test.go index 7bd6dc42c7706..60c5332cd9aec 100644 --- a/internal/querynodev2/segments/result_test.go +++ b/internal/querynodev2/segments/result_test.go @@ -27,9 +27,11 @@ import ( "github.com/stretchr/testify/suite" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/mocks/util/mock_segcore" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/segcorepb" "github.com/milvus-io/milvus/internal/util/reduce" + "github.com/milvus-io/milvus/internal/util/segcore" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/metric" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -50,7 +52,7 @@ type ResultSuite struct { } func MergeSegcoreRetrieveResultsV1(ctx context.Context, retrieveResults []*segcorepb.RetrieveResults, param *mergeParam) (*segcorepb.RetrieveResults, error) { - plan := &RetrievePlan{ignoreNonPk: false} + plan := &segcore.RetrievePlan{} return MergeSegcoreRetrieveResults(ctx, retrieveResults, param, nil, plan, nil) } @@ -66,14 +68,14 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() { FloatVector := 
[]float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0} var fieldDataArray1 []*schemapb.FieldData - fieldDataArray1 = append(fieldDataArray1, genFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{1000, 2000}, 1)) - fieldDataArray1 = append(fieldDataArray1, genFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:2], 1)) - fieldDataArray1 = append(fieldDataArray1, genFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:16], Dim)) + fieldDataArray1 = append(fieldDataArray1, mock_segcore.GenFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{1000, 2000}, 1)) + fieldDataArray1 = append(fieldDataArray1, mock_segcore.GenFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:2], 1)) + fieldDataArray1 = append(fieldDataArray1, mock_segcore.GenFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:16], Dim)) var fieldDataArray2 []*schemapb.FieldData - fieldDataArray2 = append(fieldDataArray2, genFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{2000, 3000}, 1)) - fieldDataArray2 = append(fieldDataArray2, genFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:2], 1)) - fieldDataArray2 = append(fieldDataArray2, genFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:16], Dim)) + fieldDataArray2 = append(fieldDataArray2, mock_segcore.GenFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{2000, 3000}, 1)) + fieldDataArray2 = append(fieldDataArray2, mock_segcore.GenFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:2], 1)) + fieldDataArray2 = append(fieldDataArray2, mock_segcore.GenFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:16], Dim)) suite.Run("test skip dupPK 2", func() { result1 := &segcorepb.RetrieveResults{ @@ -114,14 +116,14 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() { suite.Run("test_duppk_multipke_segment", func() { var fieldsData1 []*schemapb.FieldData - fieldsData1 = append(fieldsData1, genFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{2000, 3000}, 1)) - fieldsData1 = append(fieldsData1, genFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, []int64{1, 1}, 1)) - fieldsData1 = append(fieldsData1, genFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:16], Dim)) + fieldsData1 = append(fieldsData1, mock_segcore.GenFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{2000, 3000}, 1)) + fieldsData1 = append(fieldsData1, mock_segcore.GenFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, []int64{1, 1}, 1)) + fieldsData1 = append(fieldsData1, mock_segcore.GenFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:16], Dim)) var fieldsData2 []*schemapb.FieldData - fieldsData2 = append(fieldsData2, genFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{2500}, 1)) - fieldsData2 = append(fieldsData2, genFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, []int64{1}, 1)) - fieldsData2 = append(fieldsData2, genFieldData(FloatVectorFieldName, 
FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:8], Dim)) + fieldsData2 = append(fieldsData2, mock_segcore.GenFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{2500}, 1)) + fieldsData2 = append(fieldsData2, mock_segcore.GenFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, []int64{1}, 1)) + fieldsData2 = append(fieldsData2, mock_segcore.GenFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:8], Dim)) result1 := &segcorepb.RetrieveResults{ Ids: &schemapb.IDs{ @@ -254,7 +256,7 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() { ids[i] = int64(i) offsets[i] = int64(i) } - fieldData := genFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, ids, 1) + fieldData := mock_segcore.GenFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, ids, 1) result := &segcorepb.RetrieveResults{ Ids: &schemapb.IDs{ @@ -333,14 +335,14 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() { FloatVector := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0} var fieldDataArray1 []*schemapb.FieldData - fieldDataArray1 = append(fieldDataArray1, genFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{1000, 2000}, 1)) - fieldDataArray1 = append(fieldDataArray1, genFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:2], 1)) - fieldDataArray1 = append(fieldDataArray1, genFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:16], Dim)) + fieldDataArray1 = append(fieldDataArray1, mock_segcore.GenFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{1000, 2000}, 1)) + fieldDataArray1 = append(fieldDataArray1, mock_segcore.GenFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:2], 1)) + fieldDataArray1 = append(fieldDataArray1, mock_segcore.GenFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:16], Dim)) var fieldDataArray2 []*schemapb.FieldData - fieldDataArray2 = append(fieldDataArray2, genFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{2000, 3000}, 1)) - fieldDataArray2 = append(fieldDataArray2, genFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:2], 1)) - fieldDataArray2 = append(fieldDataArray2, genFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:16], Dim)) + fieldDataArray2 = append(fieldDataArray2, mock_segcore.GenFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{2000, 3000}, 1)) + fieldDataArray2 = append(fieldDataArray2, mock_segcore.GenFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:2], 1)) + fieldDataArray2 = append(fieldDataArray2, mock_segcore.GenFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:16], Dim)) suite.Run("test skip dupPK 2", func() { result1 := &internalpb.RetrieveResults{ @@ -395,9 +397,9 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() { }, }, FieldsData: []*schemapb.FieldData{ - genFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, + mock_segcore.GenFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{1, 2}, 1), - 
genFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, + mock_segcore.GenFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, []int64{3, 4}, 1), }, } @@ -410,9 +412,9 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() { }, }, FieldsData: []*schemapb.FieldData{ - genFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, + mock_segcore.GenFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{5, 6}, 1), - genFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, + mock_segcore.GenFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, []int64{7, 8}, 1), }, } @@ -493,7 +495,7 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() { ids[i] = int64(i) offsets[i] = int64(i) } - fieldData := genFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, ids, 1) + fieldData := mock_segcore.GenFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, ids, 1) result := &internalpb.RetrieveResults{ Ids: &schemapb.IDs{ @@ -572,17 +574,17 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() { FloatVector := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 11.0, 22.0, 33.0, 44.0} var fieldDataArray1 []*schemapb.FieldData - fieldDataArray1 = append(fieldDataArray1, genFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{1000, 2000, 3000}, 1)) - fieldDataArray1 = append(fieldDataArray1, genFieldData(Int64FieldName, Int64FieldID, + fieldDataArray1 = append(fieldDataArray1, mock_segcore.GenFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{1000, 2000, 3000}, 1)) + fieldDataArray1 = append(fieldDataArray1, mock_segcore.GenFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:3], 1)) - fieldDataArray1 = append(fieldDataArray1, genFieldData(FloatVectorFieldName, FloatVectorFieldID, + fieldDataArray1 = append(fieldDataArray1, mock_segcore.GenFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:12], Dim)) var fieldDataArray2 []*schemapb.FieldData - fieldDataArray2 = append(fieldDataArray2, genFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{2000, 3000, 4000}, 1)) - fieldDataArray2 = append(fieldDataArray2, genFieldData(Int64FieldName, Int64FieldID, + fieldDataArray2 = append(fieldDataArray2, mock_segcore.GenFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{2000, 3000, 4000}, 1)) + fieldDataArray2 = append(fieldDataArray2, mock_segcore.GenFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:3], 1)) - fieldDataArray2 = append(fieldDataArray2, genFieldData(FloatVectorFieldName, FloatVectorFieldID, + fieldDataArray2 = append(fieldDataArray2, mock_segcore.GenFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:12], Dim)) suite.Run("test stop seg core merge for best", func() { @@ -712,10 +714,10 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() { FieldsData: fieldDataArray1, } var drainDataArray2 []*schemapb.FieldData - drainDataArray2 = append(drainDataArray2, genFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{2000}, 1)) - drainDataArray2 = append(drainDataArray2, genFieldData(Int64FieldName, Int64FieldID, + drainDataArray2 = append(drainDataArray2, 
mock_segcore.GenFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{2000}, 1)) + drainDataArray2 = append(drainDataArray2, mock_segcore.GenFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:1], 1)) - drainDataArray2 = append(drainDataArray2, genFieldData(FloatVectorFieldName, FloatVectorFieldID, + drainDataArray2 = append(drainDataArray2, mock_segcore.GenFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:4], Dim)) result2 := &internalpb.RetrieveResults{ Ids: &schemapb.IDs{ @@ -878,28 +880,28 @@ func (suite *ResultSuite) TestSort() { }, Offset: []int64{5, 4, 3, 2, 9, 8, 7, 6}, FieldsData: []*schemapb.FieldData{ - genFieldData("int64 field", 100, schemapb.DataType_Int64, + mock_segcore.GenFieldData("int64 field", 100, schemapb.DataType_Int64, []int64{5, 4, 3, 2, 9, 8, 7, 6}, 1), - genFieldData("double field", 101, schemapb.DataType_Double, + mock_segcore.GenFieldData("double field", 101, schemapb.DataType_Double, []float64{5, 4, 3, 2, 9, 8, 7, 6}, 1), - genFieldData("string field", 102, schemapb.DataType_VarChar, + mock_segcore.GenFieldData("string field", 102, schemapb.DataType_VarChar, []string{"5", "4", "3", "2", "9", "8", "7", "6"}, 1), - genFieldData("bool field", 103, schemapb.DataType_Bool, + mock_segcore.GenFieldData("bool field", 103, schemapb.DataType_Bool, []bool{false, true, false, true, false, true, false, true}, 1), - genFieldData("float field", 104, schemapb.DataType_Float, + mock_segcore.GenFieldData("float field", 104, schemapb.DataType_Float, []float32{5, 4, 3, 2, 9, 8, 7, 6}, 1), - genFieldData("int field", 105, schemapb.DataType_Int32, + mock_segcore.GenFieldData("int field", 105, schemapb.DataType_Int32, []int32{5, 4, 3, 2, 9, 8, 7, 6}, 1), - genFieldData("float vector field", 106, schemapb.DataType_FloatVector, + mock_segcore.GenFieldData("float vector field", 106, schemapb.DataType_FloatVector, []float32{5, 4, 3, 2, 9, 8, 7, 6}, 1), - genFieldData("binary vector field", 107, schemapb.DataType_BinaryVector, + mock_segcore.GenFieldData("binary vector field", 107, schemapb.DataType_BinaryVector, []byte{5, 4, 3, 2, 9, 8, 7, 6}, 8), - genFieldData("json field", 108, schemapb.DataType_JSON, + mock_segcore.GenFieldData("json field", 108, schemapb.DataType_JSON, [][]byte{ []byte("{\"5\": 5}"), []byte("{\"4\": 4}"), []byte("{\"3\": 3}"), []byte("{\"2\": 2}"), []byte("{\"9\": 9}"), []byte("{\"8\": 8}"), []byte("{\"7\": 7}"), []byte("{\"6\": 6}"), }, 1), - genFieldData("json field", 108, schemapb.DataType_Array, + mock_segcore.GenFieldData("json field", 108, schemapb.DataType_Array, []*schemapb.ScalarField{ {Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: []int32{5, 6, 7}}}}, {Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: []int32{4, 5, 6}}}}, diff --git a/internal/querynodev2/segments/retrieve.go b/internal/querynodev2/segments/retrieve.go index b8fa826b8caea..f7c30bfa4ed62 100644 --- a/internal/querynodev2/segments/retrieve.go +++ b/internal/querynodev2/segments/retrieve.go @@ -54,7 +54,7 @@ func retrieveOnSegments(ctx context.Context, mgr *Manager, segments []Segment, s } return false }() - plan.ignoreNonPk = !anySegIsLazyLoad && len(segments) > 1 && req.GetReq().GetLimit() != typeutil.Unlimited && plan.ShouldIgnoreNonPk() + plan.SetIgnoreNonPk(!anySegIsLazyLoad && len(segments) > 1 && req.GetReq().GetLimit() != typeutil.Unlimited && plan.ShouldIgnoreNonPk()) label := metrics.SealedSegmentLabel if segType == 
commonpb.SegmentState_Growing { diff --git a/internal/querynodev2/segments/retrieve_test.go b/internal/querynodev2/segments/retrieve_test.go index 737c697dd4a5f..3e12eb334b68a 100644 --- a/internal/querynodev2/segments/retrieve_test.go +++ b/internal/querynodev2/segments/retrieve_test.go @@ -25,11 +25,13 @@ import ( "github.com/stretchr/testify/suite" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/mocks/util/mock_segcore" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/initcore" + "github.com/milvus-io/milvus/internal/util/segcore" "github.com/milvus-io/milvus/internal/util/streamrpc" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -71,8 +73,8 @@ func (suite *RetrieveSuite) SetupTest() { suite.segmentID = 1 suite.manager = NewManager() - schema := GenTestCollectionSchema("test-reduce", schemapb.DataType_Int64, true) - indexMeta := GenTestIndexMeta(suite.collectionID, schema) + schema := mock_segcore.GenTestCollectionSchema("test-reduce", schemapb.DataType_Int64, true) + indexMeta := mock_segcore.GenTestIndexMeta(suite.collectionID, schema) suite.manager.Collection.PutOrRef(suite.collectionID, schema, indexMeta, @@ -99,7 +101,7 @@ func (suite *RetrieveSuite) SetupTest() { ) suite.Require().NoError(err) - binlogs, _, err := SaveBinLog(ctx, + binlogs, _, err := mock_segcore.SaveBinLog(ctx, suite.collectionID, suite.partitionID, suite.segmentID, @@ -127,7 +129,7 @@ func (suite *RetrieveSuite) SetupTest() { ) suite.Require().NoError(err) - insertMsg, err := genInsertMsg(suite.collection, suite.partitionID, suite.growing.ID(), msgLength) + insertMsg, err := mock_segcore.GenInsertMsg(suite.collection.GetCCollection(), suite.partitionID, suite.growing.ID(), msgLength) suite.Require().NoError(err) insertRecord, err := storage.TransferInsertMsgToInsertRecord(suite.collection.Schema(), insertMsg) suite.Require().NoError(err) @@ -147,7 +149,7 @@ func (suite *RetrieveSuite) TearDownTest() { } func (suite *RetrieveSuite) TestRetrieveSealed() { - plan, err := genSimpleRetrievePlan(suite.collection) + plan, err := mock_segcore.GenSimpleRetrievePlan(suite.collection.GetCCollection()) suite.NoError(err) req := &querypb.QueryRequest{ @@ -164,13 +166,16 @@ func (suite *RetrieveSuite) TestRetrieveSealed() { suite.Len(res[0].Result.Offset, 3) suite.manager.Segment.Unpin(segments) - resultByOffsets, err := suite.sealed.RetrieveByOffsets(context.Background(), plan, []int64{0, 1}) + resultByOffsets, err := suite.sealed.RetrieveByOffsets(context.Background(), &segcore.RetrievePlanWithOffsets{ + RetrievePlan: plan, + Offsets: []int64{0, 1}, + }) suite.NoError(err) suite.Len(resultByOffsets.Offset, 0) } func (suite *RetrieveSuite) TestRetrieveGrowing() { - plan, err := genSimpleRetrievePlan(suite.collection) + plan, err := mock_segcore.GenSimpleRetrievePlan(suite.collection.GetCCollection()) suite.NoError(err) req := &querypb.QueryRequest{ @@ -187,13 +192,16 @@ func (suite *RetrieveSuite) TestRetrieveGrowing() { suite.Len(res[0].Result.Offset, 3) suite.manager.Segment.Unpin(segments) - resultByOffsets, err := suite.growing.RetrieveByOffsets(context.Background(), plan, []int64{0, 1}) + resultByOffsets, err := suite.growing.RetrieveByOffsets(context.Background(), &segcore.RetrievePlanWithOffsets{ + RetrievePlan: plan, + 
Offsets: []int64{0, 1}, + }) suite.NoError(err) suite.Len(resultByOffsets.Offset, 0) } func (suite *RetrieveSuite) TestRetrieveStreamSealed() { - plan, err := genSimpleRetrievePlan(suite.collection) + plan, err := mock_segcore.GenSimpleRetrievePlan(suite.collection.GetCCollection()) suite.NoError(err) req := &querypb.QueryRequest{ @@ -237,7 +245,7 @@ func (suite *RetrieveSuite) TestRetrieveStreamSealed() { } func (suite *RetrieveSuite) TestRetrieveNonExistSegment() { - plan, err := genSimpleRetrievePlan(suite.collection) + plan, err := mock_segcore.GenSimpleRetrievePlan(suite.collection.GetCCollection()) suite.NoError(err) req := &querypb.QueryRequest{ @@ -256,7 +264,7 @@ func (suite *RetrieveSuite) TestRetrieveNonExistSegment() { } func (suite *RetrieveSuite) TestRetrieveNilSegment() { - plan, err := genSimpleRetrievePlan(suite.collection) + plan, err := mock_segcore.GenSimpleRetrievePlan(suite.collection.GetCCollection()) suite.NoError(err) suite.sealed.Release(context.Background()) diff --git a/internal/querynodev2/segments/search.go b/internal/querynodev2/segments/search.go index cc7916c751585..ca5aa76a1ebe7 100644 --- a/internal/querynodev2/segments/search.go +++ b/internal/querynodev2/segments/search.go @@ -55,7 +55,7 @@ func searchSegments(ctx context.Context, mgr *Manager, segments []Segment, segTy metrics.QueryNodeSQSegmentLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, searchLabel).Observe(float64(elapsed)) metrics.QueryNodeSegmentSearchLatencyPerVector.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), - metrics.SearchLabel, searchLabel).Observe(float64(elapsed) / float64(searchReq.getNumOfQuery())) + metrics.SearchLabel, searchLabel).Observe(float64(elapsed) / float64(searchReq.GetNumOfQuery())) return nil } @@ -64,7 +64,7 @@ func searchSegments(ctx context.Context, mgr *Manager, segments []Segment, segTy segmentsWithoutIndex := make([]int64, 0) for _, segment := range segments { seg := segment - if !seg.ExistIndex(searchReq.searchFieldID) { + if !seg.ExistIndex(searchReq.SearchFieldID()) { segmentsWithoutIndex = append(segmentsWithoutIndex, seg.ID()) } errGroup.Go(func() error { @@ -148,7 +148,7 @@ func searchSegmentsStreamly(ctx context.Context, metrics.QueryNodeSQSegmentLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, searchLabel).Observe(float64(searchDuration)) metrics.QueryNodeSegmentSearchLatencyPerVector.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), - metrics.SearchLabel, searchLabel).Observe(float64(searchDuration) / float64(searchReq.getNumOfQuery())) + metrics.SearchLabel, searchLabel).Observe(float64(searchDuration) / float64(searchReq.GetNumOfQuery())) return nil } diff --git a/internal/querynodev2/segments/search_reduce_test.go b/internal/querynodev2/segments/search_reduce_test.go index 22ff091a0e1a3..2409158dd22da 100644 --- a/internal/querynodev2/segments/search_reduce_test.go +++ b/internal/querynodev2/segments/search_reduce_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/suite" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/mocks/util/mock_segcore" "github.com/milvus-io/milvus/internal/util/reduce" "github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -24,8 +25,8 @@ func (suite *SearchReduceSuite) TestResult_ReduceSearchResultData() { ids := []int64{1, 2, 3, 4} scores := []float32{-1.0, -2.0, -3.0, -4.0} topks := []int64{int64(len(ids))} - data1 := genSearchResultData(nq, topk, ids, scores, topks) - data2 := 
genSearchResultData(nq, topk, ids, scores, topks) + data1 := mock_segcore.GenSearchResultData(nq, topk, ids, scores, topks) + data2 := mock_segcore.GenSearchResultData(nq, topk, ids, scores, topks) dataArray := make([]*schemapb.SearchResultData, 0) dataArray = append(dataArray, data1) dataArray = append(dataArray, data2) @@ -43,8 +44,8 @@ func (suite *SearchReduceSuite) TestResult_ReduceSearchResultData() { ids2 := []int64{5, 1, 3, 4} scores2 := []float32{-1.0, -1.0, -3.0, -4.0} topks2 := []int64{int64(len(ids2))} - data1 := genSearchResultData(nq, topk, ids1, scores1, topks1) - data2 := genSearchResultData(nq, topk, ids2, scores2, topks2) + data1 := mock_segcore.GenSearchResultData(nq, topk, ids1, scores1, topks1) + data2 := mock_segcore.GenSearchResultData(nq, topk, ids2, scores2, topks2) dataArray := make([]*schemapb.SearchResultData, 0) dataArray = append(dataArray, data1) dataArray = append(dataArray, data2) @@ -68,8 +69,8 @@ func (suite *SearchReduceSuite) TestResult_SearchGroupByResult() { ids2 := []int64{5, 1, 3, 4} scores2 := []float32{-1.0, -1.0, -3.0, -4.0} topks2 := []int64{int64(len(ids2))} - data1 := genSearchResultData(nq, topk, ids1, scores1, topks1) - data2 := genSearchResultData(nq, topk, ids2, scores2, topks2) + data1 := mock_segcore.GenSearchResultData(nq, topk, ids1, scores1, topks1) + data2 := mock_segcore.GenSearchResultData(nq, topk, ids2, scores2, topks2) data1.GroupByFieldValue = &schemapb.FieldData{ Type: schemapb.DataType_Int8, Field: &schemapb.FieldData_Scalars{ @@ -112,8 +113,8 @@ func (suite *SearchReduceSuite) TestResult_SearchGroupByResult() { ids2 := []int64{3, 4} scores2 := []float32{-1.0, -1.0} topks2 := []int64{int64(len(ids2))} - data1 := genSearchResultData(nq, topk, ids1, scores1, topks1) - data2 := genSearchResultData(nq, topk, ids2, scores2, topks2) + data1 := mock_segcore.GenSearchResultData(nq, topk, ids1, scores1, topks1) + data2 := mock_segcore.GenSearchResultData(nq, topk, ids2, scores2, topks2) data1.GroupByFieldValue = &schemapb.FieldData{ Type: schemapb.DataType_Bool, Field: &schemapb.FieldData_Scalars{ @@ -156,8 +157,8 @@ func (suite *SearchReduceSuite) TestResult_SearchGroupByResult() { ids2 := []int64{5, 1, 3, 4} scores2 := []float32{-1.0, -1.0, -3.0, -4.0} topks2 := []int64{int64(len(ids2))} - data1 := genSearchResultData(nq, topk, ids1, scores1, topks1) - data2 := genSearchResultData(nq, topk, ids2, scores2, topks2) + data1 := mock_segcore.GenSearchResultData(nq, topk, ids1, scores1, topks1) + data2 := mock_segcore.GenSearchResultData(nq, topk, ids2, scores2, topks2) data1.GroupByFieldValue = &schemapb.FieldData{ Type: schemapb.DataType_VarChar, Field: &schemapb.FieldData_Scalars{ @@ -200,8 +201,8 @@ func (suite *SearchReduceSuite) TestResult_SearchGroupByResult() { ids2 := []int64{4, 5, 6, 7} scores2 := []float32{-1.0, -1.0, -3.0, -4.0} topks2 := []int64{int64(len(ids2))} - data1 := genSearchResultData(nq, topk, ids1, scores1, topks1) - data2 := genSearchResultData(nq, topk, ids2, scores2, topks2) + data1 := mock_segcore.GenSearchResultData(nq, topk, ids1, scores1, topks1) + data2 := mock_segcore.GenSearchResultData(nq, topk, ids2, scores2, topks2) data1.GroupByFieldValue = &schemapb.FieldData{ Type: schemapb.DataType_VarChar, Field: &schemapb.FieldData_Scalars{ diff --git a/internal/querynodev2/segments/search_test.go b/internal/querynodev2/segments/search_test.go index 81475b14c27db..11d003769a87e 100644 --- a/internal/querynodev2/segments/search_test.go +++ b/internal/querynodev2/segments/search_test.go @@ -24,6 +24,7 @@ import ( 
"github.com/stretchr/testify/suite" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/mocks/util/mock_segcore" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/querypb" storage "github.com/milvus-io/milvus/internal/storage" @@ -62,8 +63,8 @@ func (suite *SearchSuite) SetupTest() { suite.segmentID = 1 suite.manager = NewManager() - schema := GenTestCollectionSchema("test-reduce", schemapb.DataType_Int64, true) - indexMeta := GenTestIndexMeta(suite.collectionID, schema) + schema := mock_segcore.GenTestCollectionSchema("test-reduce", schemapb.DataType_Int64, true) + indexMeta := mock_segcore.GenTestIndexMeta(suite.collectionID, schema) suite.manager.Collection.PutOrRef(suite.collectionID, schema, indexMeta, @@ -90,7 +91,7 @@ func (suite *SearchSuite) SetupTest() { ) suite.Require().NoError(err) - binlogs, _, err := SaveBinLog(ctx, + binlogs, _, err := mock_segcore.SaveBinLog(ctx, suite.collectionID, suite.partitionID, suite.segmentID, @@ -118,7 +119,7 @@ func (suite *SearchSuite) SetupTest() { ) suite.Require().NoError(err) - insertMsg, err := genInsertMsg(suite.collection, suite.partitionID, suite.growing.ID(), msgLength) + insertMsg, err := mock_segcore.GenInsertMsg(suite.collection.GetCCollection(), suite.partitionID, suite.growing.ID(), msgLength) suite.Require().NoError(err) insertRecord, err := storage.TransferInsertMsgToInsertRecord(suite.collection.Schema(), insertMsg) suite.Require().NoError(err) @@ -141,7 +142,7 @@ func (suite *SearchSuite) TestSearchSealed() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - searchReq, err := genSearchPlanAndRequests(suite.collection, []int64{suite.sealed.ID()}, IndexFaissIDMap, nq) + searchReq, err := mock_segcore.GenSearchPlanAndRequests(suite.collection.GetCCollection(), []int64{suite.sealed.ID()}, mock_segcore.IndexFaissIDMap, nq) suite.NoError(err) _, segments, err := SearchHistorical(ctx, suite.manager, searchReq, suite.collectionID, nil, []int64{suite.sealed.ID()}) @@ -150,7 +151,7 @@ func (suite *SearchSuite) TestSearchSealed() { } func (suite *SearchSuite) TestSearchGrowing() { - searchReq, err := genSearchPlanAndRequests(suite.collection, []int64{suite.growing.ID()}, IndexFaissIDMap, 1) + searchReq, err := mock_segcore.GenSearchPlanAndRequests(suite.collection.GetCCollection(), []int64{suite.growing.ID()}, mock_segcore.IndexFaissIDMap, 1) suite.NoError(err) res, segments, err := SearchStreaming(context.TODO(), suite.manager, searchReq, diff --git a/internal/querynodev2/segments/segcore.go b/internal/querynodev2/segments/segcore.go new file mode 100644 index 0000000000000..75ef6bbcb358f --- /dev/null +++ b/internal/querynodev2/segments/segcore.go @@ -0,0 +1,21 @@ +package segments + +import "github.com/milvus-io/milvus/internal/util/segcore" + +type ( + SearchRequest = segcore.SearchRequest + SearchResult = segcore.SearchResult + SearchPlan = segcore.SearchPlan + RetrievePlan = segcore.RetrievePlan +) + +func DeleteSearchResults(results []*SearchResult) { + if len(results) == 0 { + return + } + for _, result := range results { + if result != nil { + result.Release() + } + } +} diff --git a/internal/querynodev2/segments/segment.go b/internal/querynodev2/segments/segment.go index 41f92557a5f7d..d873a2ef51d43 100644 --- a/internal/querynodev2/segments/segment.go +++ b/internal/querynodev2/segments/segment.go @@ -29,7 +29,6 @@ import "C" import ( "context" "fmt" - "runtime" "strings" "time" "unsafe" @@ -52,8 +51,8 @@ import ( 
"github.com/milvus-io/milvus/internal/querynodev2/pkoracle" "github.com/milvus-io/milvus/internal/querynodev2/segments/state" "github.com/milvus-io/milvus/internal/storage" - "github.com/milvus-io/milvus/internal/util/cgo" "github.com/milvus-io/milvus/internal/util/indexparamcheck" + "github.com/milvus-io/milvus/internal/util/segcore" "github.com/milvus-io/milvus/internal/util/vecindexmgr" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" @@ -268,7 +267,9 @@ var _ Segment = (*LocalSegment)(nil) type LocalSegment struct { baseSegment ptrLock *state.LoadStateLock - ptr C.CSegmentInterface + ptr C.CSegmentInterface // TODO: Remove in future, after move load index into segcore package. + // always keep same with csegment.RawPtr(), for eaiser to access, + csegment segcore.CSegment // cached results, to avoid too many CGO calls memSize *atomic.Int64 @@ -300,39 +301,17 @@ func NewSegment(ctx context.Context, return nil, err } - multipleChunkEnable := paramtable.Get().QueryNodeCfg.MultipleChunkedEnable.GetAsBool() - var cSegType C.SegmentType var locker *state.LoadStateLock switch segmentType { case SegmentTypeSealed: - if multipleChunkEnable { - cSegType = C.ChunkedSealed - } else { - cSegType = C.Sealed - } locker = state.NewLoadStateLock(state.LoadStateOnlyMeta) case SegmentTypeGrowing: locker = state.NewLoadStateLock(state.LoadStateDataLoaded) - cSegType = C.Growing default: return nil, fmt.Errorf("illegal segment type %d when create segment %d", segmentType, loadInfo.GetSegmentID()) } - var newPtr C.CSegmentInterface - _, err = GetDynamicPool().Submit(func() (any, error) { - status := C.NewSegment(collection.collectionPtr, cSegType, C.int64_t(loadInfo.GetSegmentID()), &newPtr, C.bool(loadInfo.GetIsSorted())) - err := HandleCStatus(ctx, &status, "NewSegmentFailed", - zap.Int64("collectionID", loadInfo.GetCollectionID()), - zap.Int64("partitionID", loadInfo.GetPartitionID()), - zap.Int64("segmentID", loadInfo.GetSegmentID()), - zap.String("segmentType", segmentType.String())) - return nil, err - }).Await() - if err != nil { - return nil, err - } - - log.Info("create segment", + logger := log.With( zap.Int64("collectionID", loadInfo.GetCollectionID()), zap.Int64("partitionID", loadInfo.GetPartitionID()), zap.Int64("segmentID", loadInfo.GetSegmentID()), @@ -340,10 +319,28 @@ func NewSegment(ctx context.Context, zap.String("level", loadInfo.GetLevel().String()), ) + var csegment segcore.CSegment + if _, err := GetDynamicPool().Submit(func() (any, error) { + var err error + csegment, err = segcore.CreateCSegment(&segcore.CreateCSegmentRequest{ + Collection: collection.ccollection, + SegmentID: loadInfo.GetSegmentID(), + SegmentType: segmentType, + IsSorted: loadInfo.GetIsSorted(), + EnableChunked: paramtable.Get().QueryNodeCfg.MultipleChunkedEnable.GetAsBool(), + }) + return nil, err + }).Await(); err != nil { + logger.Warn("create segment failed", zap.Error(err)) + return nil, err + } + log.Info("create segment done") + segment := &LocalSegment{ baseSegment: base, ptrLock: locker, - ptr: newPtr, + ptr: C.CSegmentInterface(csegment.RawPointer()), + csegment: csegment, lastDeltaTimestamp: atomic.NewUint64(0), fields: typeutil.NewConcurrentMap[int64, *FieldInfo](), fieldIndexes: typeutil.NewConcurrentMap[int64, *IndexedFieldInfo](), @@ -354,6 +351,7 @@ func NewSegment(ctx context.Context, } if err := segment.initializeSegment(); err != nil { + csegment.Release() return nil, err } return segment, nil @@ -424,15 +422,12 @@ func (s *LocalSegment) RowNum() int64 { rowNum := 
s.rowNum.Load() if rowNum < 0 { - var rowCount C.int64_t GetDynamicPool().Submit(func() (any, error) { - rowCount = C.GetRealCount(s.ptr) - s.rowNum.Store(int64(rowCount)) + rowNum = s.csegment.RowNum() + s.rowNum.Store(rowNum) return nil, nil }).Await() - rowNum = int64(rowCount) } - return rowNum } @@ -444,14 +439,11 @@ func (s *LocalSegment) MemSize() int64 { memSize := s.memSize.Load() if memSize < 0 { - var cMemSize C.int64_t GetDynamicPool().Submit(func() (any, error) { - cMemSize = C.GetMemoryUsageInBytes(s.ptr) - s.memSize.Store(int64(cMemSize)) + memSize = s.csegment.MemSize() + s.memSize.Store(memSize) return nil, nil }).Await() - - memSize = int64(cMemSize) } return memSize } @@ -479,8 +471,7 @@ func (s *LocalSegment) HasRawData(fieldID int64) bool { } defer s.ptrLock.RUnlock() - ret := C.HasRawData(s.ptr, C.int64_t(fieldID)) - return bool(ret) + return s.csegment.HasRawData(fieldID) } func (s *LocalSegment) Indexes() []*IndexedFieldInfo { @@ -498,192 +489,124 @@ func (s *LocalSegment) ResetIndexesLazyLoad(lazyState bool) { } } -func (s *LocalSegment) Search(ctx context.Context, searchReq *SearchRequest) (*SearchResult, error) { - /* - CStatus - Search(void* plan, - void* placeholder_groups, - uint64_t* timestamps, - int num_groups, - long int* result_ids, - float* result_distances); - */ - log := log.Ctx(ctx).With( +func (s *LocalSegment) Search(ctx context.Context, searchReq *segcore.SearchRequest) (*segcore.SearchResult, error) { + log := log.Ctx(ctx).WithLazy( zap.Int64("collectionID", s.Collection()), zap.Int64("segmentID", s.ID()), zap.String("segmentType", s.segmentType.String()), ) + if !s.ptrLock.RLockIf(state.IsNotReleased) { // TODO: check if the segment is readable but not released. too many related logic need to be refactor. return nil, merr.WrapErrSegmentNotLoaded(s.ID(), "segment released") } defer s.ptrLock.RUnlock() - traceCtx := ParseCTraceContext(ctx) - defer runtime.KeepAlive(traceCtx) - defer runtime.KeepAlive(searchReq) - - hasIndex := s.ExistIndex(searchReq.searchFieldID) + hasIndex := s.ExistIndex(searchReq.SearchFieldID()) log = log.With(zap.Bool("withIndex", hasIndex)) log.Debug("search segment...") tr := timerecord.NewTimeRecorder("cgoSearch") - - future := cgo.Async( - ctx, - func() cgo.CFuturePtr { - return (cgo.CFuturePtr)(C.AsyncSearch( - traceCtx.ctx, - s.ptr, - searchReq.plan.cSearchPlan, - searchReq.cPlaceholderGroup, - C.uint64_t(searchReq.mvccTimestamp), - )) - }, - cgo.WithName("search"), - ) - defer future.Release() - result, err := future.BlockAndLeakyGet() + result, err := s.csegment.Search(ctx, searchReq) if err != nil { log.Warn("Search failed") return nil, err } metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel).Observe(float64(tr.ElapseSpan().Milliseconds())) log.Debug("search segment done") - return &SearchResult{ - cSearchResult: (C.CSearchResult)(result), - }, nil + return result, nil } -func (s *LocalSegment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segcorepb.RetrieveResults, error) { +func (s *LocalSegment) retrieve(ctx context.Context, plan *segcore.RetrievePlan, log *zap.Logger) (*segcore.RetrieveResult, error) { if !s.ptrLock.RLockIf(state.IsNotReleased) { // TODO: check if the segment is readable but not released. too many related logic need to be refactor. 
return nil, merr.WrapErrSegmentNotLoaded(s.ID(), "segment released") } defer s.ptrLock.RUnlock() - log := log.Ctx(ctx).With( - zap.Int64("collectionID", s.Collection()), - zap.Int64("partitionID", s.Partition()), - zap.Int64("segmentID", s.ID()), - zap.Int64("msgID", plan.msgID), - zap.String("segmentType", s.segmentType.String()), - ) log.Debug("begin to retrieve") - traceCtx := ParseCTraceContext(ctx) - defer runtime.KeepAlive(traceCtx) - defer runtime.KeepAlive(plan) - - maxLimitSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64() tr := timerecord.NewTimeRecorder("cgoRetrieve") - - future := cgo.Async( - ctx, - func() cgo.CFuturePtr { - return (cgo.CFuturePtr)(C.AsyncRetrieve( - traceCtx.ctx, - s.ptr, - plan.cRetrievePlan, - C.uint64_t(plan.Timestamp), - C.int64_t(maxLimitSize), - C.bool(plan.ignoreNonPk), - )) - }, - cgo.WithName("retrieve"), - ) - defer future.Release() - result, err := future.BlockAndLeakyGet() + result, err := s.csegment.Retrieve(ctx, plan) if err != nil { log.Warn("Retrieve failed") return nil, err } - defer C.DeleteRetrieveResult((*C.CRetrieveResult)(result)) - metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel).Observe(float64(tr.ElapseSpan().Milliseconds())) + return result, nil +} + +func (s *LocalSegment) Retrieve(ctx context.Context, plan *segcore.RetrievePlan) (*segcorepb.RetrieveResults, error) { + log := log.Ctx(ctx).WithLazy( + zap.Int64("collectionID", s.Collection()), + zap.Int64("partitionID", s.Partition()), + zap.Int64("segmentID", s.ID()), + zap.Int64("msgID", plan.MsgID()), + zap.String("segmentType", s.segmentType.String()), + ) + + result, err := s.retrieve(ctx, plan, log) + if err != nil { + return nil, err + } + defer result.Release() _, span := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "partial-segcore-results-deserialization") defer span.End() - retrieveResult := new(segcorepb.RetrieveResults) - if err := UnmarshalCProto((*C.CRetrieveResult)(result), retrieveResult); err != nil { + retrieveResult, err := result.GetResult() + if err != nil { log.Warn("unmarshal retrieve result failed", zap.Error(err)) return nil, err } - - log.Debug("retrieve segment done", - zap.Int("resultNum", len(retrieveResult.Offset)), - ) - // Sort was done by the segcore. - // sort.Sort(&byPK{result}) + log.Debug("retrieve segment done", zap.Int("resultNum", len(retrieveResult.Offset))) return retrieveResult, nil } -func (s *LocalSegment) RetrieveByOffsets(ctx context.Context, plan *RetrievePlan, offsets []int64) (*segcorepb.RetrieveResults, error) { - if len(offsets) == 0 { - return nil, merr.WrapErrParameterInvalid("segment offsets", "empty offsets") - } - +func (s *LocalSegment) retrieveByOffsets(ctx context.Context, plan *segcore.RetrievePlanWithOffsets, log *zap.Logger) (*segcore.RetrieveResult, error) { if !s.ptrLock.RLockIf(state.IsNotReleased) { // TODO: check if the segment is readable but not released. too many related logic need to be refactor. return nil, merr.WrapErrSegmentNotLoaded(s.ID(), "segment released") } defer s.ptrLock.RUnlock() - fields := []zap.Field{ - zap.Int64("collectionID", s.Collection()), - zap.Int64("partitionID", s.Partition()), - zap.Int64("segmentID", s.ID()), - zap.Int64("msgID", plan.msgID), - zap.String("segmentType", s.segmentType.String()), - zap.Int("resultNum", len(offsets)), - } - - log := log.Ctx(ctx).With(fields...) 
log.Debug("begin to retrieve by offsets") tr := timerecord.NewTimeRecorder("cgoRetrieveByOffsets") - traceCtx := ParseCTraceContext(ctx) - defer runtime.KeepAlive(traceCtx) - defer runtime.KeepAlive(plan) - defer runtime.KeepAlive(offsets) - - future := cgo.Async( - ctx, - func() cgo.CFuturePtr { - return (cgo.CFuturePtr)(C.AsyncRetrieveByOffsets( - traceCtx.ctx, - s.ptr, - plan.cRetrievePlan, - (*C.int64_t)(unsafe.Pointer(&offsets[0])), - C.int64_t(len(offsets)), - )) - }, - cgo.WithName("retrieve-by-offsets"), - ) - defer future.Release() - result, err := future.BlockAndLeakyGet() + result, err := s.csegment.RetrieveByOffsets(ctx, plan) if err != nil { log.Warn("RetrieveByOffsets failed") return nil, err } - defer C.DeleteRetrieveResult((*C.CRetrieveResult)(result)) - metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel).Observe(float64(tr.ElapseSpan().Milliseconds())) + return result, nil +} + +func (s *LocalSegment) RetrieveByOffsets(ctx context.Context, plan *segcore.RetrievePlanWithOffsets) (*segcorepb.RetrieveResults, error) { + log := log.Ctx(ctx).WithLazy(zap.Int64("collectionID", s.Collection()), + zap.Int64("partitionID", s.Partition()), + zap.Int64("segmentID", s.ID()), + zap.Int64("msgID", plan.MsgID()), + zap.String("segmentType", s.segmentType.String()), + zap.Int("resultNum", len(plan.Offsets)), + ) + + result, err := s.retrieveByOffsets(ctx, plan, log) + if err != nil { + return nil, err + } + defer result.Release() _, span := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "reduced-segcore-results-deserialization") defer span.End() - retrieveResult := new(segcorepb.RetrieveResults) - if err := UnmarshalCProto((*C.CRetrieveResult)(result), retrieveResult); err != nil { + retrieveResult, err := result.GetResult() + if err != nil { log.Warn("unmarshal retrieve by offsets result failed", zap.Error(err)) return nil, err } - - log.Debug("retrieve by segment offsets done", - zap.Int("resultNum", len(retrieveResult.Offset)), - ) + log.Debug("retrieve by segment offsets done", zap.Int("resultNum", len(retrieveResult.Offset))) return retrieveResult, nil } @@ -700,26 +623,6 @@ func (s *LocalSegment) GetFieldDataPath(index *IndexedFieldInfo, offset int64) ( return dataPath, offsetInBinlog } -// -------------------------------------------------------------------------------------- interfaces for growing segment -func (s *LocalSegment) preInsert(ctx context.Context, numOfRecords int) (int64, error) { - /* - long int - PreInsert(CSegmentInterface c_segment, long int size); - */ - var offset int64 - cOffset := (*C.int64_t)(&offset) - - var status C.CStatus - GetDynamicPool().Submit(func() (any, error) { - status = C.PreInsert(s.ptr, C.int64_t(int64(numOfRecords)), cOffset) - return nil, nil - }).Await() - if err := HandleCStatus(ctx, &status, "PreInsert failed"); err != nil { - return 0, err - } - return offset, nil -} - func (s *LocalSegment) Insert(ctx context.Context, rowIDs []int64, timestamps []typeutil.Timestamp, record *segcorepb.InsertRecord) error { if s.Type() != SegmentTypeGrowing { return fmt.Errorf("unexpected segmentType when segmentInsert, segmentType = %s", s.segmentType.String()) @@ -729,24 +632,8 @@ func (s *LocalSegment) Insert(ctx context.Context, rowIDs []int64, timestamps [] } defer s.ptrLock.RUnlock() - offset, err := s.preInsert(ctx, len(rowIDs)) - if err != nil { - return err - } - - insertRecordBlob, err := proto.Marshal(record) - if err != nil { - return fmt.Errorf("failed to marshal insert record: %s", 
err) - } - - numOfRow := len(rowIDs) - cOffset := C.int64_t(offset) - cNumOfRows := C.int64_t(numOfRow) - cEntityIDsPtr := (*C.int64_t)(&(rowIDs)[0]) - cTimestampsPtr := (*C.uint64_t)(&(timestamps)[0]) - - var status C.CStatus - + var result *segcore.InsertResult + var err error GetDynamicPool().Submit(func() (any, error) { start := time.Now() defer func() { @@ -756,21 +643,19 @@ func (s *LocalSegment) Insert(ctx context.Context, rowIDs []int64, timestamps [] "Sync", ).Observe(float64(time.Since(start).Milliseconds())) }() - status = C.Insert(s.ptr, - cOffset, - cNumOfRows, - cEntityIDsPtr, - cTimestampsPtr, - (*C.uint8_t)(unsafe.Pointer(&insertRecordBlob[0])), - (C.uint64_t)(len(insertRecordBlob)), - ) + + result, err = s.csegment.Insert(ctx, &segcore.InsertRequest{ + RowIDs: rowIDs, + Timestamps: timestamps, + Record: record, + }) return nil, nil }).Await() - if err := HandleCStatus(ctx, &status, "Insert failed"); err != nil { + + if err != nil { return err } - - s.insertCount.Add(int64(numOfRow)) + s.insertCount.Add(int64(result.InsertedRows)) s.rowNum.Store(-1) s.memSize.Store(-1) return nil @@ -794,20 +679,7 @@ func (s *LocalSegment) Delete(ctx context.Context, primaryKeys storage.PrimaryKe } defer s.ptrLock.RUnlock() - cOffset := C.int64_t(0) // depre - cSize := C.int64_t(primaryKeys.Len()) - cTimestampsPtr := (*C.uint64_t)(&(timestamps)[0]) - - ids, err := storage.ParsePrimaryKeysBatch2IDs(primaryKeys) - if err != nil { - return err - } - - dataBlob, err := proto.Marshal(ids) - if err != nil { - return fmt.Errorf("failed to marshal ids: %s", err) - } - var status C.CStatus + var err error GetDynamicPool().Submit(func() (any, error) { start := time.Now() defer func() { @@ -817,23 +689,19 @@ func (s *LocalSegment) Delete(ctx context.Context, primaryKeys storage.PrimaryKe "Sync", ).Observe(float64(time.Since(start).Milliseconds())) }() - status = C.Delete(s.ptr, - cOffset, - cSize, - (*C.uint8_t)(unsafe.Pointer(&dataBlob[0])), - (C.uint64_t)(len(dataBlob)), - cTimestampsPtr, - ) + _, err = s.csegment.Delete(ctx, &segcore.DeleteRequest{ + PrimaryKeys: primaryKeys, + Timestamps: timestamps, + }) return nil, nil }).Await() - if err := HandleCStatus(ctx, &status, "Delete failed"); err != nil { + if err != nil { return err } s.rowNum.Store(-1) s.lastDeltaTimestamp.Store(timestamps[len(timestamps)-1]) - return nil } @@ -854,30 +722,17 @@ func (s *LocalSegment) LoadMultiFieldData(ctx context.Context) error { zap.Int64("segmentID", s.ID()), ) - loadFieldDataInfo, err := newLoadFieldDataInfo(ctx) - defer deleteFieldDataInfo(loadFieldDataInfo) - if err != nil { - return err + req := &segcore.LoadFieldDataRequest{ + MMapDir: paramtable.Get().QueryNodeCfg.MmapDirPath.GetValue(), + RowCount: rowCount, } - for _, field := range fields { - fieldID := field.FieldID - err = loadFieldDataInfo.appendLoadFieldInfo(ctx, fieldID, rowCount) - if err != nil { - return err - } - - for _, binlog := range field.Binlogs { - err = loadFieldDataInfo.appendLoadFieldDataPath(ctx, fieldID, binlog) - if err != nil { - return err - } - } - - loadFieldDataInfo.appendMMapDirPath(paramtable.Get().QueryNodeCfg.MmapDirPath.GetValue()) + req.Fields = append(req.Fields, segcore.LoadFieldDataInfo{ + Field: field, + }) } - var status C.CStatus + var err error GetLoadPool().Submit(func() (any, error) { start := time.Now() defer func() { @@ -887,20 +742,15 @@ func (s *LocalSegment) LoadMultiFieldData(ctx context.Context) error { "Sync", ).Observe(float64(time.Since(start).Milliseconds())) }() - status = C.LoadFieldData(s.ptr, 
loadFieldDataInfo.cLoadFieldDataInfo) + _, err = s.csegment.LoadFieldData(ctx, req) return nil, nil }).Await() - if err := HandleCStatus(ctx, &status, "LoadMultiFieldData failed", - zap.Int64("collectionID", s.Collection()), - zap.Int64("partitionID", s.Partition()), - zap.Int64("segmentID", s.ID())); err != nil { + if err != nil { + log.Warn("LoadMultiFieldData failed", zap.Error(err)) return err } - log.Info("load mutil field done", - zap.Int64("row count", rowCount), - zap.Int64("segmentID", s.ID())) - + log.Info("load mutil field done", zap.Int64("row count", rowCount), zap.Int64("segmentID", s.ID())) return nil } @@ -922,26 +772,6 @@ func (s *LocalSegment) LoadFieldData(ctx context.Context, fieldID int64, rowCoun ) log.Info("start loading field data for field") - loadFieldDataInfo, err := newLoadFieldDataInfo(ctx) - if err != nil { - return err - } - defer deleteFieldDataInfo(loadFieldDataInfo) - - err = loadFieldDataInfo.appendLoadFieldInfo(ctx, fieldID, rowCount) - if err != nil { - return err - } - - if field != nil { - for _, binlog := range field.Binlogs { - err = loadFieldDataInfo.appendLoadFieldDataPath(ctx, fieldID, binlog) - if err != nil { - return err - } - } - } - // TODO retrieve_enable should be considered collection := s.collection fieldSchema, err := getFieldSchema(collection.Schema(), fieldID) @@ -949,10 +779,15 @@ func (s *LocalSegment) LoadFieldData(ctx context.Context, fieldID int64, rowCoun return err } mmapEnabled := isDataMmapEnable(fieldSchema) - loadFieldDataInfo.appendMMapDirPath(paramtable.Get().QueryNodeCfg.MmapDirPath.GetValue()) - loadFieldDataInfo.enableMmap(fieldID, mmapEnabled) + req := &segcore.LoadFieldDataRequest{ + MMapDir: paramtable.Get().QueryNodeCfg.MmapDirPath.GetValue(), + Fields: []segcore.LoadFieldDataInfo{{ + Field: field, + EnableMMap: mmapEnabled, + }}, + RowCount: rowCount, + } - var status C.CStatus GetLoadPool().Submit(func() (any, error) { start := time.Now() defer func() { @@ -962,20 +797,16 @@ func (s *LocalSegment) LoadFieldData(ctx context.Context, fieldID int64, rowCoun "Sync", ).Observe(float64(time.Since(start).Milliseconds())) }() + _, err = s.csegment.LoadFieldData(ctx, req) log.Info("submitted loadFieldData task to load pool") - status = C.LoadFieldData(s.ptr, loadFieldDataInfo.cLoadFieldDataInfo) return nil, nil }).Await() - if err := HandleCStatus(ctx, &status, "LoadFieldData failed", - zap.Int64("collectionID", s.Collection()), - zap.Int64("partitionID", s.Partition()), - zap.Int64("segmentID", s.ID()), - zap.Int64("fieldID", fieldID)); err != nil { + + if err != nil { + log.Warn("LoadFieldData failed", zap.Error(err)) return err } - log.Info("load field done") - return nil } @@ -985,46 +816,33 @@ func (s *LocalSegment) AddFieldDataInfo(ctx context.Context, rowCount int64, fie } defer s.ptrLock.RUnlock() - log := log.Ctx(ctx).With( + log := log.Ctx(ctx).WithLazy( zap.Int64("collectionID", s.Collection()), zap.Int64("partitionID", s.Partition()), zap.Int64("segmentID", s.ID()), zap.Int64("row count", rowCount), ) - loadFieldDataInfo, err := newLoadFieldDataInfo(ctx) - if err != nil { - return err + req := &segcore.AddFieldDataInfoRequest{ + Fields: make([]segcore.LoadFieldDataInfo, 0, len(fields)), + RowCount: rowCount, } - defer deleteFieldDataInfo(loadFieldDataInfo) - for _, field := range fields { - fieldID := field.FieldID - err = loadFieldDataInfo.appendLoadFieldInfo(ctx, fieldID, rowCount) - if err != nil { - return err - } - - for _, binlog := range field.Binlogs { - err = 
loadFieldDataInfo.appendLoadFieldDataPath(ctx, fieldID, binlog) - if err != nil { - return err - } - } + req.Fields = append(req.Fields, segcore.LoadFieldDataInfo{ + Field: field, + }) } - var status C.CStatus + var err error GetLoadPool().Submit(func() (any, error) { - status = C.AddFieldDataInfoForSealed(s.ptr, loadFieldDataInfo.cLoadFieldDataInfo) + _, err = s.csegment.AddFieldDataInfo(ctx, req) return nil, nil }).Await() - if err := HandleCStatus(ctx, &status, "AddFieldDataInfo failed", - zap.Int64("collectionID", s.Collection()), - zap.Int64("partitionID", s.Partition()), - zap.Int64("segmentID", s.ID())); err != nil { + + if err != nil { + log.Warn("AddFieldDataInfo failed", zap.Error(err)) return err } - log.Info("add field data info done") return nil } @@ -1456,7 +1274,7 @@ func (s *LocalSegment) Release(ctx context.Context, opts ...releaseOption) { C.DeleteSegment(ptr) - localDiskUsage, err := GetLocalUsedSize(context.Background(), paramtable.Get().LocalStorageCfg.Path.GetValue()) + localDiskUsage, err := segcore.GetLocalUsedSize(context.Background(), paramtable.Get().LocalStorageCfg.Path.GetValue()) // ignore error here, shall not block releasing if err == nil { metrics.QueryNodeDiskUsedSize.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Set(float64(localDiskUsage) / 1024 / 1024) // in MB diff --git a/internal/querynodev2/segments/segment_interface.go b/internal/querynodev2/segments/segment_interface.go index 7df85a85cf302..400886ccd5edf 100644 --- a/internal/querynodev2/segments/segment_interface.go +++ b/internal/querynodev2/segments/segment_interface.go @@ -24,6 +24,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/segcorepb" "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/internal/util/segcore" "github.com/milvus-io/milvus/pkg/util/metautil" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -92,9 +93,9 @@ type Segment interface { GetBM25Stats() map[int64]*storage.BM25Stats // Read operations - Search(ctx context.Context, searchReq *SearchRequest) (*SearchResult, error) - Retrieve(ctx context.Context, plan *RetrievePlan) (*segcorepb.RetrieveResults, error) - RetrieveByOffsets(ctx context.Context, plan *RetrievePlan, offsets []int64) (*segcorepb.RetrieveResults, error) + Search(ctx context.Context, searchReq *segcore.SearchRequest) (*segcore.SearchResult, error) + Retrieve(ctx context.Context, plan *segcore.RetrievePlan) (*segcorepb.RetrieveResults, error) + RetrieveByOffsets(ctx context.Context, plan *segcore.RetrievePlanWithOffsets) (*segcorepb.RetrieveResults, error) IsLazyLoad() bool ResetIndexesLazyLoad(lazyState bool) diff --git a/internal/querynodev2/segments/segment_l0.go b/internal/querynodev2/segments/segment_l0.go index da8af518bf452..cab1f64b7645a 100644 --- a/internal/querynodev2/segments/segment_l0.go +++ b/internal/querynodev2/segments/segment_l0.go @@ -27,6 +27,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/segcorepb" storage "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/internal/util/segcore" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/typeutil" @@ -131,15 +132,15 @@ func (s *L0Segment) Level() datapb.SegmentLevel { return datapb.SegmentLevel_L0 } -func (s *L0Segment) Search(ctx context.Context, searchReq *SearchRequest) (*SearchResult, error) { +func (s *L0Segment) Search(ctx context.Context, 
searchReq *segcore.SearchRequest) (*segcore.SearchResult, error) { return nil, nil } -func (s *L0Segment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segcorepb.RetrieveResults, error) { +func (s *L0Segment) Retrieve(ctx context.Context, plan *segcore.RetrievePlan) (*segcorepb.RetrieveResults, error) { return nil, nil } -func (s *L0Segment) RetrieveByOffsets(ctx context.Context, plan *RetrievePlan, offsets []int64) (*segcorepb.RetrieveResults, error) { +func (s *L0Segment) RetrieveByOffsets(ctx context.Context, plan *segcore.RetrievePlanWithOffsets) (*segcorepb.RetrieveResults, error) { return nil, nil } diff --git a/internal/querynodev2/segments/segment_loader.go b/internal/querynodev2/segments/segment_loader.go index eb9b513e8adec..72d0f72911095 100644 --- a/internal/querynodev2/segments/segment_loader.go +++ b/internal/querynodev2/segments/segment_loader.go @@ -46,6 +46,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querynodev2/pkoracle" "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/internal/util/segcore" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" @@ -452,7 +453,7 @@ func (loader *segmentLoader) requestResource(ctx context.Context, infos ...*quer memoryUsage := hardware.GetUsedMemoryCount() totalMemory := hardware.GetMemoryCount() - diskUsage, err := GetLocalUsedSize(ctx, paramtable.Get().LocalStorageCfg.Path.GetValue()) + diskUsage, err := segcore.GetLocalUsedSize(ctx, paramtable.Get().LocalStorageCfg.Path.GetValue()) if err != nil { return result, errors.Wrap(err, "get local used size failed") } @@ -1365,7 +1366,7 @@ func (loader *segmentLoader) checkSegmentSize(ctx context.Context, segmentLoadIn return 0, 0, errors.New("get memory failed when checkSegmentSize") } - localDiskUsage, err := GetLocalUsedSize(ctx, paramtable.Get().LocalStorageCfg.Path.GetValue()) + localDiskUsage, err := segcore.GetLocalUsedSize(ctx, paramtable.Get().LocalStorageCfg.Path.GetValue()) if err != nil { return 0, 0, errors.Wrap(err, "get local used size failed") } diff --git a/internal/querynodev2/segments/segment_loader_test.go b/internal/querynodev2/segments/segment_loader_test.go index 79f8206ae1ba3..4b87d4a182afc 100644 --- a/internal/querynodev2/segments/segment_loader_test.go +++ b/internal/querynodev2/segments/segment_loader_test.go @@ -30,6 +30,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/mocks/util/mock_segcore" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/storage" @@ -84,8 +85,8 @@ func (suite *SegmentLoaderSuite) SetupTest() { initcore.InitRemoteChunkManager(paramtable.Get()) // Data - suite.schema = GenTestCollectionSchema("test", schemapb.DataType_Int64, false) - indexMeta := GenTestIndexMeta(suite.collectionID, suite.schema) + suite.schema = mock_segcore.GenTestCollectionSchema("test", schemapb.DataType_Int64, false) + indexMeta := mock_segcore.GenTestIndexMeta(suite.collectionID, suite.schema) loadMeta := &querypb.LoadMetaInfo{ LoadType: querypb.LoadType_LoadCollection, CollectionID: suite.collectionID, @@ -100,8 +101,8 @@ func (suite *SegmentLoaderSuite) SetupBM25() { suite.loader = NewLoader(suite.manager, suite.chunkManager) initcore.InitRemoteChunkManager(paramtable.Get()) - suite.schema = 
GenTestBM25CollectionSchema("test") - indexMeta := GenTestIndexMeta(suite.collectionID, suite.schema) + suite.schema = mock_segcore.GenTestBM25CollectionSchema("test") + indexMeta := mock_segcore.GenTestIndexMeta(suite.collectionID, suite.schema) loadMeta := &querypb.LoadMetaInfo{ LoadType: querypb.LoadType_LoadCollection, CollectionID: suite.collectionID, @@ -124,7 +125,7 @@ func (suite *SegmentLoaderSuite) TestLoad() { msgLength := 4 // Load sealed - binlogs, statsLogs, err := SaveBinLog(ctx, + binlogs, statsLogs, err := mock_segcore.SaveBinLog(ctx, suite.collectionID, suite.partitionID, suite.segmentID, @@ -146,7 +147,7 @@ func (suite *SegmentLoaderSuite) TestLoad() { suite.NoError(err) // Load growing - binlogs, statsLogs, err = SaveBinLog(ctx, + binlogs, statsLogs, err = mock_segcore.SaveBinLog(ctx, suite.collectionID, suite.partitionID, suite.segmentID+1, @@ -174,7 +175,7 @@ func (suite *SegmentLoaderSuite) TestLoadFail() { msgLength := 4 // Load sealed - binlogs, statsLogs, err := SaveBinLog(ctx, + binlogs, statsLogs, err := mock_segcore.SaveBinLog(ctx, suite.collectionID, suite.partitionID, suite.segmentID, @@ -211,7 +212,7 @@ func (suite *SegmentLoaderSuite) TestLoadMultipleSegments() { // Load sealed for i := 0; i < suite.segmentNum; i++ { segmentID := suite.segmentID + int64(i) - binlogs, statsLogs, err := SaveBinLog(ctx, + binlogs, statsLogs, err := mock_segcore.SaveBinLog(ctx, suite.collectionID, suite.partitionID, segmentID, @@ -247,7 +248,7 @@ func (suite *SegmentLoaderSuite) TestLoadMultipleSegments() { loadInfos = loadInfos[:0] for i := 0; i < suite.segmentNum; i++ { segmentID := suite.segmentID + int64(suite.segmentNum) + int64(i) - binlogs, statsLogs, err := SaveBinLog(ctx, + binlogs, statsLogs, err := mock_segcore.SaveBinLog(ctx, suite.collectionID, suite.partitionID, segmentID, @@ -287,7 +288,7 @@ func (suite *SegmentLoaderSuite) TestLoadWithIndex() { // Load sealed for i := 0; i < suite.segmentNum; i++ { segmentID := suite.segmentID + int64(i) - binlogs, statsLogs, err := SaveBinLog(ctx, + binlogs, statsLogs, err := mock_segcore.SaveBinLog(ctx, suite.collectionID, suite.partitionID, segmentID, @@ -298,13 +299,13 @@ func (suite *SegmentLoaderSuite) TestLoadWithIndex() { suite.NoError(err) vecFields := funcutil.GetVecFieldIDs(suite.schema) - indexInfo, err := GenAndSaveIndex( + indexInfo, err := mock_segcore.GenAndSaveIndex( suite.collectionID, suite.partitionID, segmentID, vecFields[0], msgLength, - IndexFaissIVFFlat, + mock_segcore.IndexFaissIVFFlat, metric.L2, suite.chunkManager, ) @@ -338,7 +339,7 @@ func (suite *SegmentLoaderSuite) TestLoadBloomFilter() { // Load sealed for i := 0; i < suite.segmentNum; i++ { segmentID := suite.segmentID + int64(i) - binlogs, statsLogs, err := SaveBinLog(ctx, + binlogs, statsLogs, err := mock_segcore.SaveBinLog(ctx, suite.collectionID, suite.partitionID, segmentID, @@ -379,7 +380,7 @@ func (suite *SegmentLoaderSuite) TestLoadDeltaLogs() { // Load sealed for i := 0; i < suite.segmentNum; i++ { segmentID := suite.segmentID + int64(i) - binlogs, statsLogs, err := SaveBinLog(ctx, + binlogs, statsLogs, err := mock_segcore.SaveBinLog(ctx, suite.collectionID, suite.partitionID, segmentID, @@ -390,7 +391,7 @@ func (suite *SegmentLoaderSuite) TestLoadDeltaLogs() { suite.NoError(err) // Delete PKs 1, 2 - deltaLogs, err := SaveDeltaLog(suite.collectionID, + deltaLogs, err := mock_segcore.SaveDeltaLog(suite.collectionID, suite.partitionID, segmentID, suite.chunkManager, @@ -428,13 +429,13 @@ func (suite *SegmentLoaderSuite) TestLoadDeltaLogs() 
{ func (suite *SegmentLoaderSuite) TestLoadBm25Stats() { suite.SetupBM25() msgLength := 1 - sparseFieldID := simpleSparseFloatVectorField.id + sparseFieldID := mock_segcore.SimpleSparseFloatVectorField.ID loadInfos := make([]*querypb.SegmentLoadInfo, 0, suite.segmentNum) for i := 0; i < suite.segmentNum; i++ { segmentID := suite.segmentID + int64(i) - bm25logs, err := SaveBM25Log(suite.collectionID, suite.partitionID, segmentID, sparseFieldID, msgLength, suite.chunkManager) + bm25logs, err := mock_segcore.SaveBM25Log(suite.collectionID, suite.partitionID, segmentID, sparseFieldID, msgLength, suite.chunkManager) suite.NoError(err) loadInfos = append(loadInfos, &querypb.SegmentLoadInfo{ @@ -468,7 +469,7 @@ func (suite *SegmentLoaderSuite) TestLoadDupDeltaLogs() { // Load sealed for i := 0; i < suite.segmentNum; i++ { segmentID := suite.segmentID + int64(i) - binlogs, statsLogs, err := SaveBinLog(ctx, + binlogs, statsLogs, err := mock_segcore.SaveBinLog(ctx, suite.collectionID, suite.partitionID, segmentID, @@ -479,7 +480,7 @@ func (suite *SegmentLoaderSuite) TestLoadDupDeltaLogs() { suite.NoError(err) // Delete PKs 1, 2 - deltaLogs, err := SaveDeltaLog(suite.collectionID, + deltaLogs, err := mock_segcore.SaveDeltaLog(suite.collectionID, suite.partitionID, segmentID, suite.chunkManager, @@ -602,7 +603,7 @@ func (suite *SegmentLoaderSuite) TestLoadWithMmap() { msgLength := 100 // Load sealed - binlogs, statsLogs, err := SaveBinLog(ctx, + binlogs, statsLogs, err := mock_segcore.SaveBinLog(ctx, suite.collectionID, suite.partitionID, suite.segmentID, @@ -629,7 +630,7 @@ func (suite *SegmentLoaderSuite) TestPatchEntryNum() { msgLength := 100 segmentID := suite.segmentID - binlogs, statsLogs, err := SaveBinLog(ctx, + binlogs, statsLogs, err := mock_segcore.SaveBinLog(ctx, suite.collectionID, suite.partitionID, segmentID, @@ -640,13 +641,13 @@ func (suite *SegmentLoaderSuite) TestPatchEntryNum() { suite.NoError(err) vecFields := funcutil.GetVecFieldIDs(suite.schema) - indexInfo, err := GenAndSaveIndex( + indexInfo, err := mock_segcore.GenAndSaveIndex( suite.collectionID, suite.partitionID, segmentID, vecFields[0], msgLength, - IndexFaissIVFFlat, + mock_segcore.IndexFaissIVFFlat, metric.L2, suite.chunkManager, ) @@ -690,7 +691,7 @@ func (suite *SegmentLoaderSuite) TestRunOutMemory() { msgLength := 4 // Load sealed - binlogs, statsLogs, err := SaveBinLog(ctx, + binlogs, statsLogs, err := mock_segcore.SaveBinLog(ctx, suite.collectionID, suite.partitionID, suite.segmentID, @@ -712,7 +713,7 @@ func (suite *SegmentLoaderSuite) TestRunOutMemory() { suite.Error(err) // Load growing - binlogs, statsLogs, err = SaveBinLog(ctx, + binlogs, statsLogs, err = mock_segcore.SaveBinLog(ctx, suite.collectionID, suite.partitionID, suite.segmentID+1, @@ -782,7 +783,7 @@ func (suite *SegmentLoaderDetailSuite) SetupSuite() { suite.partitionID = rand.Int63() suite.segmentID = rand.Int63() suite.segmentNum = 5 - suite.schema = GenTestCollectionSchema("test", schemapb.DataType_Int64, false) + suite.schema = mock_segcore.GenTestCollectionSchema("test", schemapb.DataType_Int64, false) } func (suite *SegmentLoaderDetailSuite) SetupTest() { @@ -801,9 +802,9 @@ func (suite *SegmentLoaderDetailSuite) SetupTest() { initcore.InitRemoteChunkManager(paramtable.Get()) // Data - schema := GenTestCollectionSchema("test", schemapb.DataType_Int64, false) + schema := mock_segcore.GenTestCollectionSchema("test", schemapb.DataType_Int64, false) - indexMeta := GenTestIndexMeta(suite.collectionID, schema) + indexMeta := 
mock_segcore.GenTestIndexMeta(suite.collectionID, schema) loadMeta := &querypb.LoadMetaInfo{ LoadType: querypb.LoadType_LoadCollection, CollectionID: suite.collectionID, diff --git a/internal/querynodev2/segments/segment_test.go b/internal/querynodev2/segments/segment_test.go index fe489d9bf6416..f35a2d513a0f5 100644 --- a/internal/querynodev2/segments/segment_test.go +++ b/internal/querynodev2/segments/segment_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/suite" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/mocks/util/mock_segcore" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/querypb" storage "github.com/milvus-io/milvus/internal/storage" @@ -54,8 +55,8 @@ func (suite *SegmentSuite) SetupTest() { suite.segmentID = 1 suite.manager = NewManager() - schema := GenTestCollectionSchema("test-reduce", schemapb.DataType_Int64, true) - indexMeta := GenTestIndexMeta(suite.collectionID, schema) + schema := mock_segcore.GenTestCollectionSchema("test-reduce", schemapb.DataType_Int64, true) + indexMeta := mock_segcore.GenTestIndexMeta(suite.collectionID, schema) suite.manager.Collection.PutOrRef(suite.collectionID, schema, indexMeta, @@ -93,7 +94,7 @@ func (suite *SegmentSuite) SetupTest() { ) suite.Require().NoError(err) - binlogs, _, err := SaveBinLog(ctx, + binlogs, _, err := mock_segcore.SaveBinLog(ctx, suite.collectionID, suite.partitionID, suite.segmentID, @@ -124,7 +125,7 @@ func (suite *SegmentSuite) SetupTest() { ) suite.Require().NoError(err) - insertMsg, err := genInsertMsg(suite.collection, suite.partitionID, suite.growing.ID(), msgLength) + insertMsg, err := mock_segcore.GenInsertMsg(suite.collection.GetCCollection(), suite.partitionID, suite.growing.ID(), msgLength) suite.Require().NoError(err) insertRecord, err := storage.TransferInsertMsgToInsertRecord(suite.collection.Schema(), insertMsg) suite.Require().NoError(err) @@ -187,9 +188,9 @@ func (suite *SegmentSuite) TestDelete() { } func (suite *SegmentSuite) TestHasRawData() { - has := suite.growing.HasRawData(simpleFloatVecField.id) + has := suite.growing.HasRawData(mock_segcore.SimpleFloatVecField.ID) suite.True(has) - has = suite.sealed.HasRawData(simpleFloatVecField.id) + has = suite.sealed.HasRawData(mock_segcore.SimpleFloatVecField.ID) suite.True(has) } diff --git a/internal/querynodev2/server.go b/internal/querynodev2/server.go index 17fb05cba416c..9bc9bf3de2a06 100644 --- a/internal/querynodev2/server.go +++ b/internal/querynodev2/server.go @@ -60,6 +60,7 @@ import ( "github.com/milvus-io/milvus/internal/util/initcore" "github.com/milvus-io/milvus/internal/util/searchutil/optimizers" "github.com/milvus-io/milvus/internal/util/searchutil/scheduler" + "github.com/milvus-io/milvus/internal/util/segcore" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/config" "github.com/milvus-io/milvus/pkg/log" @@ -323,7 +324,7 @@ func (node *QueryNode) Init() error { node.factory.Init(paramtable.Get()) localRootPath := paramtable.Get().LocalStorageCfg.Path.GetValue() - localUsedSize, err := segments.GetLocalUsedSize(node.ctx, localRootPath) + localUsedSize, err := segcore.GetLocalUsedSize(node.ctx, localRootPath) if err != nil { log.Warn("get local used size failed", zap.Error(err)) initError = err diff --git a/internal/querynodev2/server_test.go b/internal/querynodev2/server_test.go index 37070bbae9786..6800e9c4c8c0f 100644 --- a/internal/querynodev2/server_test.go +++ 
b/internal/querynodev2/server_test.go @@ -32,12 +32,13 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/mocks/util/mock_segcore" + "github.com/milvus-io/milvus/internal/mocks/util/searchutil/mock_optimizers" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/dependency" - "github.com/milvus-io/milvus/internal/util/searchutil/optimizers" "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -159,7 +160,7 @@ func (suite *QueryNodeSuite) TestInit_QueryHook() { err = suite.node.Init() suite.NoError(err) - mockHook := optimizers.NewMockQueryHook(suite.T()) + mockHook := mock_optimizers.NewMockQueryHook(suite.T()) suite.node.queryHook = mockHook suite.node.handleQueryHookEvent() @@ -219,7 +220,7 @@ func (suite *QueryNodeSuite) TestStop() { suite.node.manager = segments.NewManager() - schema := segments.GenTestCollectionSchema("test_stop", schemapb.DataType_Int64, true) + schema := mock_segcore.GenTestCollectionSchema("test_stop", schemapb.DataType_Int64, true) collection := segments.NewCollection(1, schema, nil, &querypb.LoadMetaInfo{ LoadType: querypb.LoadType_LoadCollection, }) diff --git a/internal/querynodev2/services_test.go b/internal/querynodev2/services_test.go index 3f6364a217c62..075290bf0543c 100644 --- a/internal/querynodev2/services_test.go +++ b/internal/querynodev2/services_test.go @@ -38,6 +38,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/json" + "github.com/milvus-io/milvus/internal/mocks/util/mock_segcore" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/internalpb" @@ -262,8 +263,8 @@ func (suite *ServiceSuite) TestWatchDmChannelsInt64() { ctx := context.Background() // data - schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) - deltaLogs, err := segments.SaveDeltaLog(suite.collectionID, + schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) + deltaLogs, err := mock_segcore.SaveDeltaLog(suite.collectionID, suite.partitionIDs[0], suite.flushedSegmentIDs[0], suite.node.chunkManager, @@ -306,7 +307,7 @@ func (suite *ServiceSuite) TestWatchDmChannelsInt64() { PartitionIDs: suite.partitionIDs, MetricType: defaultMetricType, }, - IndexInfoList: segments.GenTestIndexInfoList(suite.collectionID, schema), + IndexInfoList: mock_segcore.GenTestIndexInfoList(suite.collectionID, schema), } // mocks @@ -331,7 +332,7 @@ func (suite *ServiceSuite) TestWatchDmChannelsVarchar() { ctx := context.Background() // data - schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_VarChar, false) + schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_VarChar, false) req := &querypb.WatchDmChannelsRequest{ Base: &commonpb.MsgBase{ @@ -358,7 +359,7 @@ func (suite *ServiceSuite) TestWatchDmChannelsVarchar() { PartitionIDs: suite.partitionIDs, MetricType: defaultMetricType, }, - IndexInfoList: segments.GenTestIndexInfoList(suite.collectionID, schema), + IndexInfoList: 
mock_segcore.GenTestIndexInfoList(suite.collectionID, schema), } // mocks @@ -383,9 +384,9 @@ func (suite *ServiceSuite) TestWatchDmChannels_Failed() { ctx := context.Background() // data - schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) + schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) - indexInfos := segments.GenTestIndexInfoList(suite.collectionID, schema) + indexInfos := mock_segcore.GenTestIndexInfoList(suite.collectionID, schema) infos := suite.genSegmentLoadInfos(schema, indexInfos) segmentInfos := lo.SliceToMap(infos, func(info *querypb.SegmentLoadInfo) (int64, *datapb.SegmentInfo) { @@ -544,7 +545,7 @@ func (suite *ServiceSuite) genSegmentLoadInfos(schema *schemapb.CollectionSchema partNum := len(suite.partitionIDs) infos := make([]*querypb.SegmentLoadInfo, 0) for i := 0; i < segNum; i++ { - binlogs, statslogs, err := segments.SaveBinLog(ctx, + binlogs, statslogs, err := mock_segcore.SaveBinLog(ctx, suite.collectionID, suite.partitionIDs[i%partNum], suite.validSegmentIDs[i], @@ -559,7 +560,7 @@ func (suite *ServiceSuite) genSegmentLoadInfos(schema *schemapb.CollectionSchema for offset, field := range vectorFieldSchemas { indexInfo := lo.FindOrElse(indexInfos, nil, func(info *indexpb.IndexInfo) bool { return info.FieldID == field.GetFieldID() }) if indexInfo != nil { - index, err := segments.GenAndSaveIndexV2( + index, err := mock_segcore.GenAndSaveIndexV2( suite.collectionID, suite.partitionIDs[i%partNum], suite.validSegmentIDs[i], @@ -595,8 +596,8 @@ func (suite *ServiceSuite) TestLoadSegments_Int64() { ctx := context.Background() suite.TestWatchDmChannelsInt64() // data - schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) - indexInfos := segments.GenTestIndexInfoList(suite.collectionID, schema) + schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) + indexInfos := mock_segcore.GenTestIndexInfoList(suite.collectionID, schema) infos := suite.genSegmentLoadInfos(schema, indexInfos) for _, info := range infos { req := &querypb.LoadSegmentsRequest{ @@ -624,7 +625,7 @@ func (suite *ServiceSuite) TestLoadSegments_VarChar() { ctx := context.Background() suite.TestWatchDmChannelsVarchar() // data - schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_VarChar, false) + schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_VarChar, false) loadMeta := &querypb.LoadMetaInfo{ LoadType: querypb.LoadType_LoadCollection, CollectionID: suite.collectionID, @@ -661,7 +662,7 @@ func (suite *ServiceSuite) TestLoadDeltaInt64() { ctx := context.Background() suite.TestLoadSegments_Int64() // data - schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) + schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) req := &querypb.LoadSegmentsRequest{ Base: &commonpb.MsgBase{ MsgID: rand.Int63(), @@ -686,7 +687,7 @@ func (suite *ServiceSuite) TestLoadDeltaVarchar() { ctx := context.Background() suite.TestLoadSegments_VarChar() // data - schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) + schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) req := &querypb.LoadSegmentsRequest{ Base: &commonpb.MsgBase{ MsgID: rand.Int63(), @@ -711,9 +712,9 @@ func (suite *ServiceSuite) 
TestLoadIndex_Success() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) + schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) - indexInfos := segments.GenTestIndexInfoList(suite.collectionID, schema) + indexInfos := mock_segcore.GenTestIndexInfoList(suite.collectionID, schema) infos := suite.genSegmentLoadInfos(schema, indexInfos) infos = lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) *querypb.SegmentLoadInfo { info.SegmentID = info.SegmentID + 1000 @@ -782,10 +783,10 @@ func (suite *ServiceSuite) TestLoadIndex_Failed() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) + schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) suite.Run("load_non_exist_segment", func() { - indexInfos := segments.GenTestIndexInfoList(suite.collectionID, schema) + indexInfos := mock_segcore.GenTestIndexInfoList(suite.collectionID, schema) infos := suite.genSegmentLoadInfos(schema, indexInfos) infos = lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) *querypb.SegmentLoadInfo { info.SegmentID = info.SegmentID + 1000 @@ -828,7 +829,7 @@ func (suite *ServiceSuite) TestLoadIndex_Failed() { mockLoader.EXPECT().LoadIndex(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(errors.New("mocked error")) - indexInfos := segments.GenTestIndexInfoList(suite.collectionID, schema) + indexInfos := mock_segcore.GenTestIndexInfoList(suite.collectionID, schema) infos := suite.genSegmentLoadInfos(schema, indexInfos) req := &querypb.LoadSegmentsRequest{ Base: &commonpb.MsgBase{ @@ -854,7 +855,7 @@ func (suite *ServiceSuite) TestLoadIndex_Failed() { func (suite *ServiceSuite) TestLoadSegments_Failed() { ctx := context.Background() // data - schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) + schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) req := &querypb.LoadSegmentsRequest{ Base: &commonpb.MsgBase{ MsgID: rand.Int63(), @@ -901,7 +902,7 @@ func (suite *ServiceSuite) TestLoadSegments_Transfer() { delegator.EXPECT().TryCleanExcludedSegments(mock.Anything).Maybe() delegator.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")).Return(nil) // data - schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) + schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) req := &querypb.LoadSegmentsRequest{ Base: &commonpb.MsgBase{ MsgID: rand.Int63(), @@ -923,7 +924,7 @@ func (suite *ServiceSuite) TestLoadSegments_Transfer() { suite.Run("delegator_not_found", func() { // data - schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) + schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) req := &querypb.LoadSegmentsRequest{ Base: &commonpb.MsgBase{ MsgID: rand.Int63(), @@ -953,7 +954,7 @@ func (suite *ServiceSuite) TestLoadSegments_Transfer() { delegator.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")). 
Return(errors.New("mocked error")) // data - schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) + schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) req := &querypb.LoadSegmentsRequest{ Base: &commonpb.MsgBase{ MsgID: rand.Int63(), @@ -1245,7 +1246,7 @@ func (suite *ServiceSuite) TestSearch_Failed() { ctx := context.Background() // data - schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) + schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) creq, err := suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, "invalidMetricType", false) req := &querypb.SearchRequest{ Req: creq, @@ -1267,7 +1268,7 @@ func (suite *ServiceSuite) TestSearch_Failed() { CollectionID: suite.collectionID, PartitionIDs: suite.partitionIDs, } - indexMeta := suite.node.composeIndexMeta(segments.GenTestIndexInfoList(suite.collectionID, schema), schema) + indexMeta := suite.node.composeIndexMeta(mock_segcore.GenTestIndexInfoList(suite.collectionID, schema), schema) suite.node.manager.Collection.PutOrRef(suite.collectionID, schema, indexMeta, LoadMeta) // Delegator not found @@ -1459,7 +1460,7 @@ func (suite *ServiceSuite) TestQuery_Normal() { suite.TestLoadSegments_Int64() // data - schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) + schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) creq, err := suite.genCQueryRequest(10, IndexFaissIDMap, schema) suite.NoError(err) req := &querypb.QueryRequest{ @@ -1478,7 +1479,7 @@ func (suite *ServiceSuite) TestQuery_Failed() { defer cancel() // data - schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) + schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) creq, err := suite.genCQueryRequest(10, IndexFaissIDMap, schema) suite.NoError(err) req := &querypb.QueryRequest{ @@ -1540,7 +1541,7 @@ func (suite *ServiceSuite) TestQueryStream_Normal() { suite.TestLoadSegments_Int64() // data - schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) + schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) creq, err := suite.genCQueryRequest(10, IndexFaissIDMap, schema) suite.NoError(err) req := &querypb.QueryRequest{ @@ -1575,7 +1576,7 @@ func (suite *ServiceSuite) TestQueryStream_Failed() { defer cancel() // data - schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) + schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) creq, err := suite.genCQueryRequest(10, IndexFaissIDMap, schema) suite.NoError(err) req := &querypb.QueryRequest{ @@ -1653,7 +1654,7 @@ func (suite *ServiceSuite) TestQuerySegments_Normal() { suite.TestLoadSegments_Int64() // data - schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) + schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) creq, err := suite.genCQueryRequest(10, IndexFaissIDMap, schema) suite.NoError(err) req := &querypb.QueryRequest{ @@ -1675,7 +1676,7 @@ func (suite *ServiceSuite) TestQueryStreamSegments_Normal() { suite.TestLoadSegments_Int64() // data - schema := 
segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) + schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) creq, err := suite.genCQueryRequest(10, IndexFaissIDMap, schema) suite.NoError(err) req := &querypb.QueryRequest{ diff --git a/internal/querynodev2/tasks/query_stream_task.go b/internal/querynodev2/tasks/query_stream_task.go index e24c755373b11..cdb3c3c2e99ba 100644 --- a/internal/querynodev2/tasks/query_stream_task.go +++ b/internal/querynodev2/tasks/query_stream_task.go @@ -7,6 +7,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/internal/util/searchutil/scheduler" + "github.com/milvus-io/milvus/internal/util/segcore" "github.com/milvus-io/milvus/internal/util/streamrpc" ) @@ -59,9 +60,8 @@ func (t *QueryStreamTask) PreExecute() error { } func (t *QueryStreamTask) Execute() error { - retrievePlan, err := segments.NewRetrievePlan( - t.ctx, - t.collection, + retrievePlan, err := segcore.NewRetrievePlan( + t.collection.GetCCollection(), t.req.Req.GetSerializedExprPlan(), t.req.Req.GetMvccTimestamp(), t.req.Req.Base.GetMsgID(), diff --git a/internal/querynodev2/tasks/query_task.go b/internal/querynodev2/tasks/query_task.go index da4f18f72001e..7099e83defc22 100644 --- a/internal/querynodev2/tasks/query_task.go +++ b/internal/querynodev2/tasks/query_task.go @@ -16,6 +16,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/segcorepb" "github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/internal/util/searchutil/scheduler" + "github.com/milvus-io/milvus/internal/util/segcore" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -100,9 +101,8 @@ func (t *QueryTask) Execute() error { } tr := timerecord.NewTimeRecorderWithTrace(t.ctx, "QueryTask") - retrievePlan, err := segments.NewRetrievePlan( - t.ctx, - t.collection, + retrievePlan, err := segcore.NewRetrievePlan( + t.collection.GetCCollection(), t.req.Req.GetSerializedExprPlan(), t.req.Req.GetMvccTimestamp(), t.req.Req.Base.GetMsgID(), diff --git a/internal/querynodev2/tasks/search_task.go b/internal/querynodev2/tasks/search_task.go index 39f25f542f0d6..9c16c13253dfd 100644 --- a/internal/querynodev2/tasks/search_task.go +++ b/internal/querynodev2/tasks/search_task.go @@ -21,6 +21,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/internal/util/searchutil/scheduler" + "github.com/milvus-io/milvus/internal/util/segcore" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/util/funcutil" @@ -145,7 +146,7 @@ func (t *SearchTask) Execute() error { if err != nil { return err } - searchReq, err := segments.NewSearchRequest(t.ctx, t.collection, req, t.placeholderGroup) + searchReq, err := segcore.NewSearchRequest(t.collection.GetCCollection(), req, t.placeholderGroup) if err != nil { return err } @@ -215,7 +216,7 @@ func (t *SearchTask) Execute() error { }, 0) tr.RecordSpan() - blobs, err := segments.ReduceSearchResultsAndFillData( + blobs, err := segcore.ReduceSearchResultsAndFillData( t.ctx, searchReq.Plan(), results, @@ -227,7 +228,7 @@ func (t *SearchTask) Execute() error { log.Warn("failed to reduce search results", zap.Error(err)) return err } - defer 
segments.DeleteSearchResultDataBlobs(blobs) + defer segcore.DeleteSearchResultDataBlobs(blobs) metrics.QueryNodeReduceLatency.WithLabelValues( fmt.Sprint(t.GetNodeID()), metrics.SearchLabel, @@ -235,7 +236,7 @@ func (t *SearchTask) Execute() error { metrics.BatchReduce). Observe(float64(tr.RecordSpan().Milliseconds())) for i := range t.originNqs { - blob, err := segments.GetSearchResultDataBlob(t.ctx, blobs, i) + blob, err := segcore.GetSearchResultDataBlob(t.ctx, blobs, i) if err != nil { return err } @@ -385,8 +386,8 @@ func (t *SearchTask) combinePlaceHolderGroups() error { type StreamingSearchTask struct { SearchTask others []*StreamingSearchTask - resultBlobs segments.SearchResultDataBlobs - streamReducer segments.StreamSearchReducer + resultBlobs segcore.SearchResultDataBlobs + streamReducer segcore.StreamSearchReducer } func NewStreamingSearchTask(ctx context.Context, @@ -433,7 +434,7 @@ func (t *StreamingSearchTask) Execute() error { tr := timerecord.NewTimeRecorderWithTrace(t.ctx, "SearchTask") req := t.req t.combinePlaceHolderGroups() - searchReq, err := segments.NewSearchRequest(t.ctx, t.collection, req, t.placeholderGroup) + searchReq, err := segcore.NewSearchRequest(t.collection.GetCCollection(), req, t.placeholderGroup) if err != nil { return err } @@ -455,14 +456,14 @@ func (t *StreamingSearchTask) Execute() error { nil, req.GetSegmentIDs(), streamReduceFunc) - defer segments.DeleteStreamReduceHelper(t.streamReducer) + defer segcore.DeleteStreamReduceHelper(t.streamReducer) defer t.segmentManager.Segment.Unpin(pinnedSegments) if err != nil { log.Error("Failed to search sealed segments streamly", zap.Error(err)) return err } - t.resultBlobs, err = segments.GetStreamReduceResult(t.ctx, t.streamReducer) - defer segments.DeleteSearchResultDataBlobs(t.resultBlobs) + t.resultBlobs, err = segcore.GetStreamReduceResult(t.ctx, t.streamReducer) + defer segcore.DeleteSearchResultDataBlobs(t.resultBlobs) if err != nil { log.Error("Failed to get stream-reduced search result") return err @@ -488,7 +489,7 @@ func (t *StreamingSearchTask) Execute() error { return nil } tr.RecordSpan() - t.resultBlobs, err = segments.ReduceSearchResultsAndFillData( + t.resultBlobs, err = segcore.ReduceSearchResultsAndFillData( t.ctx, searchReq.Plan(), results, @@ -500,7 +501,7 @@ func (t *StreamingSearchTask) Execute() error { log.Warn("failed to reduce search results", zap.Error(err)) return err } - defer segments.DeleteSearchResultDataBlobs(t.resultBlobs) + defer segcore.DeleteSearchResultDataBlobs(t.resultBlobs) metrics.QueryNodeReduceLatency.WithLabelValues( fmt.Sprint(t.GetNodeID()), metrics.SearchLabel, @@ -514,7 +515,7 @@ func (t *StreamingSearchTask) Execute() error { // 2. 
reorganize blobs to original search request for i := range t.originNqs { - blob, err := segments.GetSearchResultDataBlob(t.ctx, t.resultBlobs, i) + blob, err := segcore.GetSearchResultDataBlob(t.ctx, t.resultBlobs, i) if err != nil { return err } @@ -584,19 +585,19 @@ func (t *StreamingSearchTask) maybeReturnForEmptyResults(results []*segments.Sea } func (t *StreamingSearchTask) streamReduce(ctx context.Context, - plan *segments.SearchPlan, + plan *segcore.SearchPlan, newResult *segments.SearchResult, sliceNQs []int64, sliceTopKs []int64, ) error { if t.streamReducer == nil { var err error - t.streamReducer, err = segments.NewStreamReducer(ctx, plan, sliceNQs, sliceTopKs) + t.streamReducer, err = segcore.NewStreamReducer(ctx, plan, sliceNQs, sliceTopKs) if err != nil { log.Error("Fail to init stream reducer, return") return err } } - return segments.StreamReduceSearchResult(ctx, newResult, t.streamReducer) + return segcore.StreamReduceSearchResult(ctx, newResult, t.streamReducer) } diff --git a/internal/util/searchutil/optimizers/query_hook_test.go b/internal/util/searchutil/optimizers/query_hook_test.go index ac5f8ed505d6e..6dc78bfeac403 100644 --- a/internal/util/searchutil/optimizers/query_hook_test.go +++ b/internal/util/searchutil/optimizers/query_hook_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/suite" "google.golang.org/protobuf/proto" + "github.com/milvus-io/milvus/internal/mocks/util/searchutil/mock_optimizers" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/planpb" "github.com/milvus-io/milvus/internal/proto/querypb" @@ -36,7 +37,7 @@ func (suite *QueryHookSuite) TestOptimizeSearchParam() { suite.Run("normal_run", func() { paramtable.Get().Save(paramtable.Get().AutoIndexConfig.Enable.Key, "true") - mockHook := NewMockQueryHook(suite.T()) + mockHook := mock_optimizers.NewMockQueryHook(suite.T()) mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) { params[common.TopKKey] = int64(50) params[common.SearchParamKey] = `{"param": 2}` @@ -87,7 +88,7 @@ func (suite *QueryHookSuite) TestOptimizeSearchParam() { }) suite.Run("disable optimization", func() { - mockHook := NewMockQueryHook(suite.T()) + mockHook := mock_optimizers.NewMockQueryHook(suite.T()) suite.queryHook = mockHook defer func() { suite.queryHook = nil }() @@ -144,7 +145,7 @@ func (suite *QueryHookSuite) TestOptimizeSearchParam() { suite.Run("other_plannode", func() { paramtable.Get().Save(paramtable.Get().AutoIndexConfig.Enable.Key, "true") - mockHook := NewMockQueryHook(suite.T()) + mockHook := mock_optimizers.NewMockQueryHook(suite.T()) mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) { params[common.TopKKey] = int64(50) params[common.SearchParamKey] = `{"param": 2}` @@ -174,7 +175,7 @@ func (suite *QueryHookSuite) TestOptimizeSearchParam() { suite.Run("no_serialized_plan", func() { paramtable.Get().Save(paramtable.Get().AutoIndexConfig.Enable.Key, "true") defer paramtable.Get().Reset(paramtable.Get().AutoIndexConfig.Enable.Key) - mockHook := NewMockQueryHook(suite.T()) + mockHook := mock_optimizers.NewMockQueryHook(suite.T()) suite.queryHook = mockHook defer func() { suite.queryHook = nil }() @@ -187,7 +188,7 @@ func (suite *QueryHookSuite) TestOptimizeSearchParam() { suite.Run("hook_run_error", func() { paramtable.Get().Save(paramtable.Get().AutoIndexConfig.Enable.Key, "true") - mockHook := NewMockQueryHook(suite.T()) + mockHook := mock_optimizers.NewMockQueryHook(suite.T()) 
mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) { params[common.TopKKey] = int64(50) params[common.SearchParamKey] = `{"param": 2}` diff --git a/internal/util/segcore/cgo_util.go b/internal/util/segcore/cgo_util.go new file mode 100644 index 0000000000000..a34c08657df13 --- /dev/null +++ b/internal/util/segcore/cgo_util.go @@ -0,0 +1,78 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package segcore + +/* +#cgo pkg-config: milvus_core + +#include "segcore/collection_c.h" +#include "common/type_c.h" +#include "segcore/segment_c.h" +#include "storage/storage_c.h" +*/ +import "C" + +import ( + "context" + "math" + "unsafe" + + "google.golang.org/protobuf/proto" + + "github.com/milvus-io/milvus/internal/util/cgoconverter" + "github.com/milvus-io/milvus/pkg/util/merr" +) + +type CStatus = C.CStatus + +// ConsumeCStatusIntoError consumes the CStatus and returns the error +func ConsumeCStatusIntoError(status *C.CStatus) error { + if status == nil || status.error_code == 0 { + return nil + } + errorCode := status.error_code + errorMsg := C.GoString(status.error_msg) + C.free(unsafe.Pointer(status.error_msg)) + return merr.SegcoreError(int32(errorCode), errorMsg) +} + +// unmarshalCProto unmarshal the proto from C memory +func unmarshalCProto(cRes *C.CProto, msg proto.Message) error { + blob := (*(*[math.MaxInt32]byte)(cRes.proto_blob))[:int(cRes.proto_size):int(cRes.proto_size)] + return proto.Unmarshal(blob, msg) +} + +// getCProtoBlob returns the raw C memory, invoker should release it itself +func getCProtoBlob(cProto *C.CProto) []byte { + lease, blob := cgoconverter.UnsafeGoBytes(&cProto.proto_blob, int(cProto.proto_size)) + cgoconverter.Extract(lease) + return blob +} + +// GetLocalUsedSize returns the used size of the local path +func GetLocalUsedSize(ctx context.Context, path string) (int64, error) { + var availableSize int64 + cSize := (*C.int64_t)(&availableSize) + cPath := C.CString(path) + defer C.free(unsafe.Pointer(cPath)) + + status := C.GetLocalUsedSize(cPath, cSize) + if err := ConsumeCStatusIntoError(&status); err != nil { + return 0, err + } + return availableSize, nil +} diff --git a/internal/util/segcore/cgo_util_test.go b/internal/util/segcore/cgo_util_test.go new file mode 100644 index 0000000000000..3a517ede69d56 --- /dev/null +++ b/internal/util/segcore/cgo_util_test.go @@ -0,0 +1,19 @@ +package segcore + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConsumeCStatusIntoError(t *testing.T) { + err := ConsumeCStatusIntoError(nil) + assert.NoError(t, err) +} + +func TestGetLocalUsedSize(t *testing.T) { + size, err := GetLocalUsedSize(context.Background(), "") + assert.NoError(t, err) + assert.NotNil(t, size) +} diff --git a/internal/util/segcore/collection.go 
b/internal/util/segcore/collection.go new file mode 100644 index 0000000000000..2292e9822fb03 --- /dev/null +++ b/internal/util/segcore/collection.go @@ -0,0 +1,83 @@ +package segcore + +/* +#cgo pkg-config: milvus_core + +#include "segcore/collection_c.h" +#include "segcore/segment_c.h" +*/ +import "C" + +import ( + "errors" + "unsafe" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/proto/segcorepb" + "google.golang.org/protobuf/proto" +) + +// CreateCCollectionRequest is a request to create a CCollection. +type CreateCCollectionRequest struct { + CollectionID int64 + Schema *schemapb.CollectionSchema + IndexMeta *segcorepb.CollectionIndexMeta +} + +// CreateCCollection creates a CCollection from a CreateCCollectionRequest. +func CreateCCollection(req *CreateCCollectionRequest) (*CCollection, error) { + schemaBlob, err := proto.Marshal(req.Schema) + if err != nil { + return nil, errors.New("marshal schema failed") + } + var indexMetaBlob []byte + if req.IndexMeta != nil { + indexMetaBlob, err = proto.Marshal(req.IndexMeta) + if err != nil { + return nil, errors.New("marshal index meta failed") + } + } + ptr := C.NewCollection(unsafe.Pointer(&schemaBlob[0]), (C.int64_t)(len(schemaBlob))) + if indexMetaBlob != nil { + C.SetIndexMeta(ptr, unsafe.Pointer(&indexMetaBlob[0]), (C.int64_t)(len(indexMetaBlob))) + } + return &CCollection{ + collectionID: req.CollectionID, + ptr: ptr, + schema: req.Schema, + indexMeta: req.IndexMeta, + }, nil +} + +// CCollection is just a wrapper of the underlying C-structure CCollection. +// Contains some additional immutable properties of collection. +type CCollection struct { + ptr C.CCollection + collectionID int64 + schema *schemapb.CollectionSchema + indexMeta *segcorepb.CollectionIndexMeta +} + +// ID returns the collection ID. +func (c *CCollection) ID() int64 { + return c.collectionID +} + +// rawPointer returns the underlying C-structure pointer. 
+func (c *CCollection) rawPointer() C.CCollection { + return c.ptr +} + +func (c *CCollection) Schema() *schemapb.CollectionSchema { + return c.schema +} + +func (c *CCollection) IndexMeta() *segcorepb.CollectionIndexMeta { + return c.indexMeta +} + +// Release releases the underlying collection +func (c *CCollection) Release() { + C.DeleteCollection(c.ptr) + c.ptr = nil +} diff --git a/internal/util/segcore/collection_test.go b/internal/util/segcore/collection_test.go new file mode 100644 index 0000000000000..3afda510a8268 --- /dev/null +++ b/internal/util/segcore/collection_test.go @@ -0,0 +1,28 @@ +package segcore_test + +import ( + "testing" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/mocks/util/mock_segcore" + "github.com/milvus-io/milvus/internal/util/segcore" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/stretchr/testify/assert" +) + +func TestCollection(t *testing.T) { + paramtable.Init() + schema := mock_segcore.GenTestCollectionSchema("test", schemapb.DataType_Int64, false) + indexMeta := mock_segcore.GenTestIndexMeta(1, schema) + ccollection, err := segcore.CreateCCollection(&segcore.CreateCCollectionRequest{ + CollectionID: 1, + Schema: schema, + IndexMeta: indexMeta, + }) + assert.NoError(t, err) + assert.NotNil(t, ccollection) + assert.NotNil(t, ccollection.Schema()) + assert.NotNil(t, ccollection.IndexMeta()) + assert.Equal(t, int64(1), ccollection.ID()) + defer ccollection.Release() +} diff --git a/internal/querynodev2/segments/plan.go b/internal/util/segcore/plan.go similarity index 60% rename from internal/querynodev2/segments/plan.go rename to internal/util/segcore/plan.go index ff94ac63c9d93..54779caf58213 100644 --- a/internal/querynodev2/segments/plan.go +++ b/internal/util/segcore/plan.go @@ -14,11 +14,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -package segments +package segcore /* #cgo pkg-config: milvus_core +#include "common/type_c.h" #include "segcore/collection_c.h" #include "segcore/segment_c.h" #include "segcore/plan_c.h" @@ -26,15 +27,14 @@ package segments import "C" import ( - "context" - "fmt" "unsafe" "github.com/cockroachdb/errors" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/pkg/util/merr" - . 
"github.com/milvus-io/milvus/pkg/util/typeutil" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) // SearchPlan is a wrapper of the underlying C-structure C.CSearchPlan @@ -42,22 +42,16 @@ type SearchPlan struct { cSearchPlan C.CSearchPlan } -func createSearchPlanByExpr(ctx context.Context, col *Collection, expr []byte) (*SearchPlan, error) { - if col.collectionPtr == nil { - return nil, errors.New("nil collection ptr, collectionID = " + fmt.Sprintln(col.id)) - } +func createSearchPlanByExpr(col *CCollection, expr []byte) (*SearchPlan, error) { var cPlan C.CSearchPlan - status := C.CreateSearchPlanByExpr(col.collectionPtr, unsafe.Pointer(&expr[0]), (C.int64_t)(len(expr)), &cPlan) - - err1 := HandleCStatus(ctx, &status, "Create Plan by expr failed") - if err1 != nil { - return nil, err1 + status := C.CreateSearchPlanByExpr(col.rawPointer(), unsafe.Pointer(&expr[0]), (C.int64_t)(len(expr)), &cPlan) + if err := ConsumeCStatusIntoError(&status); err != nil { + return nil, errors.Wrap(err, "Create Plan by expr failed") } - return &SearchPlan{cSearchPlan: cPlan}, nil } -func (plan *SearchPlan) getTopK() int64 { +func (plan *SearchPlan) GetTopK() int64 { topK := C.GetTopK(plan.cSearchPlan) return int64(topK) } @@ -82,15 +76,15 @@ func (plan *SearchPlan) delete() { type SearchRequest struct { plan *SearchPlan cPlaceholderGroup C.CPlaceholderGroup - msgID UniqueID - searchFieldID UniqueID - mvccTimestamp Timestamp + msgID int64 + searchFieldID int64 + mvccTimestamp typeutil.Timestamp } -func NewSearchRequest(ctx context.Context, collection *Collection, req *querypb.SearchRequest, placeholderGrp []byte) (*SearchRequest, error) { +func NewSearchRequest(collection *CCollection, req *querypb.SearchRequest, placeholderGrp []byte) (*SearchRequest, error) { metricType := req.GetReq().GetMetricType() expr := req.Req.SerializedExprPlan - plan, err := createSearchPlanByExpr(ctx, collection, expr) + plan, err := createSearchPlanByExpr(collection, expr) if err != nil { return nil, err } @@ -104,10 +98,9 @@ func NewSearchRequest(ctx context.Context, collection *Collection, req *querypb. blobSize := C.int64_t(len(placeholderGrp)) var cPlaceholderGroup C.CPlaceholderGroup status := C.ParsePlaceholderGroup(plan.cSearchPlan, blobPtr, blobSize, &cPlaceholderGroup) - - if err := HandleCStatus(ctx, &status, "parser searchRequest failed"); err != nil { + if err := ConsumeCStatusIntoError(&status); err != nil { plan.delete() - return nil, err + return nil, errors.Wrap(err, "parser searchRequest failed") } metricTypeInPlan := plan.GetMetricType() @@ -118,23 +111,21 @@ func NewSearchRequest(ctx context.Context, collection *Collection, req *querypb. 
 	var fieldID C.int64_t
 	status = C.GetFieldID(plan.cSearchPlan, &fieldID)
-	if err = HandleCStatus(ctx, &status, "get fieldID from plan failed"); err != nil {
+	if err := ConsumeCStatusIntoError(&status); err != nil {
 		plan.delete()
-		return nil, err
+		return nil, errors.Wrap(err, "get fieldID from plan failed")
 	}
 
-	ret := &SearchRequest{
+	return &SearchRequest{
 		plan:              plan,
 		cPlaceholderGroup: cPlaceholderGroup,
 		msgID:             req.GetReq().GetBase().GetMsgID(),
 		searchFieldID:     int64(fieldID),
 		mvccTimestamp:     req.GetReq().GetMvccTimestamp(),
-	}
-
-	return ret, nil
+	}, nil
 }
 
-func (req *SearchRequest) getNumOfQuery() int64 {
+func (req *SearchRequest) GetNumOfQuery() int64 {
 	numQueries := C.GetNumOfQueries(req.cPlaceholderGroup)
 	return int64(numQueries)
 }
@@ -143,6 +134,10 @@ func (req *SearchRequest) Plan() *SearchPlan {
 	return req.plan
 }
 
+func (req *SearchRequest) SearchFieldID() int64 {
+	return req.searchFieldID
+}
+
 func (req *SearchRequest) Delete() {
 	if req.plan != nil {
 		req.plan.delete()
@@ -150,59 +145,49 @@ func (req *SearchRequest) Delete() {
 	C.DeletePlaceholderGroup(req.cPlaceholderGroup)
 }
 
-func parseSearchRequest(ctx context.Context, plan *SearchPlan, searchRequestBlob []byte) (*SearchRequest, error) {
-	if len(searchRequestBlob) == 0 {
-		return nil, fmt.Errorf("empty search request")
-	}
-	blobPtr := unsafe.Pointer(&searchRequestBlob[0])
-	blobSize := C.int64_t(len(searchRequestBlob))
-	var cPlaceholderGroup C.CPlaceholderGroup
-	status := C.ParsePlaceholderGroup(plan.cSearchPlan, blobPtr, blobSize, &cPlaceholderGroup)
-
-	if err := HandleCStatus(ctx, &status, "parser searchRequest failed"); err != nil {
-		return nil, err
-	}
-
-	ret := &SearchRequest{cPlaceholderGroup: cPlaceholderGroup, plan: plan}
-	return ret, nil
-}
-
 // RetrievePlan is a wrapper of the underlying C-structure C.CRetrievePlan
 type RetrievePlan struct {
 	cRetrievePlan C.CRetrievePlan
-	Timestamp     Timestamp
-	msgID         UniqueID // only used to debug.
+	Timestamp     typeutil.Timestamp
+	msgID         int64 // only used for debugging.
+	maxLimitSize  int64
 	ignoreNonPk   bool
 }
 
-func NewRetrievePlan(ctx context.Context, col *Collection, expr []byte, timestamp Timestamp, msgID UniqueID) (*RetrievePlan, error) {
-	col.mu.RLock()
-	defer col.mu.RUnlock()
-
-	if col.collectionPtr == nil {
-		return nil, merr.WrapErrCollectionNotFound(col.id, "collection released")
+func NewRetrievePlan(col *CCollection, expr []byte, timestamp typeutil.Timestamp, msgID int64) (*RetrievePlan, error) {
+	if col.rawPointer() == nil {
+		return nil, errors.New("collection is released")
 	}
-
 	var cPlan C.CRetrievePlan
-	status := C.CreateRetrievePlanByExpr(col.collectionPtr, unsafe.Pointer(&expr[0]), (C.int64_t)(len(expr)), &cPlan)
-
-	err := HandleCStatus(ctx, &status, "Create retrieve plan by expr failed")
-	if err != nil {
-		return nil, err
+	status := C.CreateRetrievePlanByExpr(col.rawPointer(), unsafe.Pointer(&expr[0]), (C.int64_t)(len(expr)), &cPlan)
+	if err := ConsumeCStatusIntoError(&status); err != nil {
+		return nil, errors.Wrap(err, "Create retrieve plan by expr failed")
 	}
-
-	newPlan := &RetrievePlan{
+	maxLimitSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
+	return &RetrievePlan{
 		cRetrievePlan: cPlan,
 		Timestamp:     timestamp,
 		msgID:         msgID,
-	}
-	return newPlan, nil
+		maxLimitSize:  maxLimitSize,
+	}, nil
 }
 
 func (plan *RetrievePlan) ShouldIgnoreNonPk() bool {
 	return bool(C.ShouldIgnoreNonPk(plan.cRetrievePlan))
 }
 
+func (plan *RetrievePlan) SetIgnoreNonPk(ignore bool) {
+	plan.ignoreNonPk = ignore
+}
+
+func (plan *RetrievePlan) IsIgnoreNonPk() bool {
+	return plan.ignoreNonPk
+}
+
+func (plan *RetrievePlan) MsgID() int64 {
+	return plan.msgID
+}
+
 func (plan *RetrievePlan) Delete() {
 	C.DeleteRetrievePlan(plan.cRetrievePlan)
 }
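A minimal usage sketch of the relocated plan API above (illustrative, not part of the diff). It assumes the `CCollection` was created with `CreateCCollection`, that `serializedPlan` and `placeholderGroup` are non-empty blobs produced upstream (e.g. by the proxy), and `buildSearchRequest` itself is a hypothetical wrapper:

```go
package example

import (
	"github.com/milvus-io/milvus/internal/proto/internalpb"
	"github.com/milvus-io/milvus/internal/proto/querypb"
	"github.com/milvus-io/milvus/internal/util/segcore"
	"github.com/milvus-io/milvus/pkg/util/typeutil"
)

// buildSearchRequest shows the intended call pattern: the caller owns the
// returned SearchRequest and must Delete() it to free the C-side plan and
// placeholder group.
func buildSearchRequest(col *segcore.CCollection, serializedPlan, placeholderGroup []byte) (*segcore.SearchRequest, error) {
	req, err := segcore.NewSearchRequest(col, &querypb.SearchRequest{
		Req: &internalpb.SearchRequest{
			SerializedExprPlan: serializedPlan,
			MvccTimestamp:      typeutil.MaxTimestamp,
		},
	}, placeholderGroup)
	if err != nil {
		return nil, err
	}
	// Plan metadata is now exposed through exported accessors instead of
	// package-private fields.
	_ = req.Plan().GetTopK()
	_ = req.GetNumOfQuery()
	_ = req.SearchFieldID()
	return req, nil
}
```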
diff --git a/internal/querynodev2/segments/plan_test.go b/internal/util/segcore/plan_test.go
similarity index 66%
rename from internal/querynodev2/segments/plan_test.go
rename to internal/util/segcore/plan_test.go
index cc4c000113e40..4fad9c83ae3df 100644
--- a/internal/querynodev2/segments/plan_test.go
+++ b/internal/util/segcore/plan_test.go
@@ -14,18 +14,20 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-package segments
+package segcore_test
 
 import (
-	"context"
 	"testing"
 
 	"github.com/stretchr/testify/suite"
 	"google.golang.org/protobuf/proto"
 
 	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
+	"github.com/milvus-io/milvus/internal/mocks/util/mock_segcore"
+	"github.com/milvus-io/milvus/internal/proto/internalpb"
 	"github.com/milvus-io/milvus/internal/proto/planpb"
 	"github.com/milvus-io/milvus/internal/proto/querypb"
+	"github.com/milvus-io/milvus/internal/util/segcore"
 	"github.com/milvus-io/milvus/pkg/util/paramtable"
 )
 
@@ -36,47 +38,46 @@ type PlanSuite struct {
 	collectionID int64
 	partitionID  int64
 	segmentID    int64
-	collection   *Collection
+	collection   *segcore.CCollection
 }
 
 func (suite *PlanSuite) SetupTest() {
	suite.collectionID = 100
 	suite.partitionID = 10
 	suite.segmentID = 1
-	schema := GenTestCollectionSchema("plan-suite", schemapb.DataType_Int64, true)
-	suite.collection = NewCollection(suite.collectionID, schema, GenTestIndexMeta(suite.collectionID, schema), &querypb.LoadMetaInfo{
-		LoadType: querypb.LoadType_LoadCollection,
+	schema := mock_segcore.GenTestCollectionSchema("plan-suite", schemapb.DataType_Int64, true)
+	var err error
+	suite.collection, err = segcore.CreateCCollection(&segcore.CreateCCollectionRequest{
+		Schema:    schema,
+		IndexMeta: mock_segcore.GenTestIndexMeta(suite.collectionID, schema),
 	})
-	suite.collection.AddPartition(suite.partitionID)
+	if err != nil {
+		panic(err)
+	}
 }
 
 func (suite *PlanSuite) TearDownTest() {
-	DeleteCollection(suite.collection)
+	suite.collection.Release()
 }
 
 func (suite *PlanSuite) TestPlanCreateByExpr() {
 	planNode := &planpb.PlanNode{
-		OutputFieldIds: []int64{rowIDFieldID},
+		OutputFieldIds: []int64{0},
 	}
 	expr, err := proto.Marshal(planNode)
 	suite.NoError(err)
 
-	_, err = createSearchPlanByExpr(context.Background(), suite.collection, expr)
-	suite.Error(err)
-}
-
-func (suite *PlanSuite) TestPlanFail() {
-	collection := &Collection{
-		id: -1,
-	}
-
-	_, err := createSearchPlanByExpr(context.Background(), collection, nil)
+	_, err = segcore.NewSearchRequest(suite.collection, &querypb.SearchRequest{
+		Req: &internalpb.SearchRequest{
+			SerializedExprPlan: expr,
+		},
+	}, nil)
 	suite.Error(err)
 }
 
 func (suite *PlanSuite) TestQueryPlanCollectionReleased() {
-	collection := &Collection{id: suite.collectionID}
-	_, err := NewRetrievePlan(context.Background(), collection, nil, 0, 0)
+	suite.collection.Release()
+	_, err := segcore.NewRetrievePlan(suite.collection, nil, 0, 0)
 	suite.Error(err)
 }
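For reference, the placeholder-group blob that `ParsePlaceholderGroup` consumes is a marshaled `commonpb.PlaceholderGroup`; the reduce test later in this patch still builds it byte-by-byte (see its TODO about `genPlaceholderGroup`). A hedged sketch of an equivalent helper, assuming the same little-endian float layout the tests use; `encodeFloatVectorPlaceholder` is illustrative only:

```go
package example

import (
	"encoding/binary"
	"math"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"google.golang.org/protobuf/proto"
)

// encodeFloatVectorPlaceholder marshals a single float vector into a
// PlaceholderGroup blob. The tag "$0" must match the placeholder tag used in
// the serialized search plan.
func encodeFloatVectorPlaceholder(vec []float32) ([]byte, error) {
	raw := make([]byte, 0, 4*len(vec))
	buf := make([]byte, 4)
	for _, f := range vec {
		binary.LittleEndian.PutUint32(buf, math.Float32bits(f))
		raw = append(raw, buf...)
	}
	group := &commonpb.PlaceholderGroup{
		Placeholders: []*commonpb.PlaceholderValue{{
			Tag:    "$0",
			Type:   commonpb.PlaceholderType_FloatVector,
			Values: [][]byte{raw},
		}},
	}
	return proto.Marshal(group)
}
```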
diff --git a/internal/querynodev2/segments/reduce.go b/internal/util/segcore/reduce.go
similarity index 82%
rename from internal/querynodev2/segments/reduce.go
rename to internal/util/segcore/reduce.go
index 6fcee6995601b..e943206aace47 100644
--- a/internal/querynodev2/segments/reduce.go
+++ b/internal/util/segcore/reduce.go
@@ -14,7 +14,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-package segments
+package segcore
 
 /*
 #cgo pkg-config: milvus_core
@@ -27,6 +27,8 @@ import "C"
 import (
 	"context"
 	"fmt"
+
+	"github.com/cockroachdb/errors"
 )
 
 type SliceInfo struct {
@@ -34,22 +36,12 @@ type SliceInfo struct {
 	SliceTopKs []int64
 }
 
-// SearchResult contains a pointer to the search result in C++ memory
-type SearchResult struct {
-	cSearchResult C.CSearchResult
-}
-
 // SearchResultDataBlobs is the CSearchResultsDataBlobs in C++
 type (
 	SearchResultDataBlobs = C.CSearchResultDataBlobs
 	StreamSearchReducer   = C.CSearchStreamReducer
 )
 
-// RetrieveResult contains a pointer to the retrieve result in C++ memory
-type RetrieveResult struct {
-	cRetrieveResult C.CRetrieveResult
-}
-
 func ParseSliceInfo(originNQs []int64, originTopKs []int64, nqPerSlice int64) *SliceInfo {
 	sInfo := &SliceInfo{
 		SliceNQs:   make([]int64, 0),
@@ -94,8 +86,8 @@ func NewStreamReducer(ctx context.Context,
 	var streamReducer StreamSearchReducer
 	status := C.NewStreamReducer(plan.cSearchPlan, cSliceNQSPtr, cSliceTopKSPtr, cNumSlices, &streamReducer)
-	if err := HandleCStatus(ctx, &status, "MergeSearchResultsWithOutputFields failed"); err != nil {
-		return nil, err
+	if err := ConsumeCStatusIntoError(&status); err != nil {
+		return nil, errors.Wrap(err, "NewStreamReducer failed")
 	}
 	return streamReducer, nil
 }
@@ -108,8 +100,8 @@ func StreamReduceSearchResult(ctx context.Context,
 	cSearchResultPtr := &cSearchResults[0]
 	status := C.StreamReduce(streamReducer, cSearchResultPtr, 1)
-	if err := HandleCStatus(ctx, &status, "StreamReduceSearchResult failed"); err != nil {
-		return err
+	if err := ConsumeCStatusIntoError(&status); err != nil {
+		return errors.Wrap(err, "StreamReduceSearchResult failed")
 	}
 	return nil
 }
@@ -117,8 +109,8 @@ func StreamReduceSearchResult(ctx context.Context,
 func GetStreamReduceResult(ctx context.Context, streamReducer StreamSearchReducer) (SearchResultDataBlobs, error) {
 	var cSearchResultDataBlobs SearchResultDataBlobs
 	status := C.GetStreamReduceResult(streamReducer, &cSearchResultDataBlobs)
-	if err := HandleCStatus(ctx, &status, "ReduceSearchResultsAndFillData failed"); err != nil {
-		return nil, err
+	if err := ConsumeCStatusIntoError(&status); err != nil {
+		return nil, errors.Wrap(err, "GetStreamReduceResult failed")
 	}
 	return cSearchResultDataBlobs, nil
 }
@@ -154,8 +146,8 @@ func ReduceSearchResultsAndFillData(ctx context.Context, plan *SearchPlan, searc
 	traceCtx := ParseCTraceContext(ctx)
 	status := C.ReduceSearchResultsAndFillData(traceCtx.ctx, &cSearchResultDataBlobs, plan.cSearchPlan, cSearchResultPtr, cNumSegments, cSliceNQSPtr, cSliceTopKSPtr, cNumSlices)
-	if err := HandleCStatus(ctx, &status, "ReduceSearchResultsAndFillData failed"); err != nil {
-		return nil, err
+	if err := ConsumeCStatusIntoError(&status); err != nil {
+		return nil, errors.Wrap(err, "ReduceSearchResultsAndFillData failed")
 	}
 	return cSearchResultDataBlobs, nil
 }
@@ -163,10 +155,10 @@ func ReduceSearchResultsAndFillData(ctx context.Context, plan *SearchPlan, searc
 func GetSearchResultDataBlob(ctx context.Context, cSearchResultDataBlobs SearchResultDataBlobs, blobIndex int) ([]byte, error) {
 	var blob C.CProto
 	status := C.GetSearchResultDataBlob(&blob, cSearchResultDataBlobs, C.int32_t(blobIndex))
-	if err := HandleCStatus(ctx, &status, "marshal failed"); err != nil {
-		return nil, err
+	if err := ConsumeCStatusIntoError(&status); err != nil {
+		return nil, errors.Wrap(err, "marshal failed")
 	}
-	return GetCProtoBlob(&blob), nil
+	return getCProtoBlob(&blob), nil
 }
 
 func DeleteSearchResultDataBlobs(cSearchResultDataBlobs SearchResultDataBlobs) {
@@ -176,14 +168,3 @@ func DeleteSearchResultDataBlobs(cSearchResultDataBlobs SearchResultDataBlobs) {
 func DeleteStreamReduceHelper(cStreamReduceHelper StreamSearchReducer) {
 	C.DeleteStreamSearchReducer(cStreamReduceHelper)
 }
-
-func DeleteSearchResults(results []*SearchResult) {
-	if len(results) == 0 {
-		return
-	}
-	for _, result := range results {
-		if result != nil {
-			C.DeleteSearchResult(result.cSearchResult)
-		}
-	}
-}
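A sketch of the blocking reduce path end to end (illustrative, not part of the diff). It assumes the `numSegments` parameter is `int64`, as the truncated hunk headers suggest, and `reduceSegmentResults` is a hypothetical caller:

```go
package example

import (
	"context"

	"github.com/milvus-io/milvus/internal/util/segcore"
)

// reduceSegmentResults reduces per-segment results into C-side blobs, copies
// out the blob for each slice, then frees the blobs. sliceNQs/sliceTopKs
// would normally come from segcore.ParseSliceInfo.
func reduceSegmentResults(ctx context.Context, plan *segcore.SearchPlan, results []*segcore.SearchResult, sliceNQs, sliceTopKs []int64) ([][]byte, error) {
	blobs, err := segcore.ReduceSearchResultsAndFillData(ctx, plan, results, int64(len(results)), sliceNQs, sliceTopKs)
	if err != nil {
		return nil, err
	}
	defer segcore.DeleteSearchResultDataBlobs(blobs)

	out := make([][]byte, 0, len(sliceNQs))
	for i := range sliceNQs {
		blob, err := segcore.GetSearchResultDataBlob(ctx, blobs, i)
		if err != nil {
			return nil, err
		}
		// copy defensively in this sketch before the blobs are deleted
		out = append(out, append([]byte(nil), blob...))
	}
	return out, nil
}
```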
diff --git a/internal/querynodev2/segments/reduce_test.go b/internal/util/segcore/reduce_test.go
similarity index 57%
rename from internal/querynodev2/segments/reduce_test.go
rename to internal/util/segcore/reduce_test.go
index 9c3a5d4f179f1..8d7e21e6c2358 100644
--- a/internal/querynodev2/segments/reduce_test.go
+++ b/internal/util/segcore/reduce_test.go
@@ -14,13 +14,14 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-package segments
+package segcore_test
 
 import (
 	"context"
 	"fmt"
 	"log"
 	"math"
+	"path/filepath"
 	"testing"
 
 	"github.com/stretchr/testify/suite"
@@ -29,11 +30,13 @@ import (
 
 	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
 	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
-	"github.com/milvus-io/milvus/internal/proto/datapb"
+	"github.com/milvus-io/milvus/internal/mocks/util/mock_segcore"
+	"github.com/milvus-io/milvus/internal/proto/internalpb"
 	"github.com/milvus-io/milvus/internal/proto/planpb"
 	"github.com/milvus-io/milvus/internal/proto/querypb"
-	storage "github.com/milvus-io/milvus/internal/storage"
+	"github.com/milvus-io/milvus/internal/storage"
 	"github.com/milvus-io/milvus/internal/util/initcore"
+	"github.com/milvus-io/milvus/internal/util/segcore"
 	"github.com/milvus-io/milvus/pkg/common"
 	"github.com/milvus-io/milvus/pkg/util/funcutil"
 	"github.com/milvus-io/milvus/pkg/util/paramtable"
@@ -49,8 +52,8 @@ type ReduceSuite struct {
 	collectionID int64
 	partitionID  int64
 	segmentID    int64
-	collection   *Collection
-	segment      Segment
+	collection   *segcore.CCollection
+	segment      segcore.CSegment
 }
 
 func (suite *ReduceSuite) SetupSuite() {
@@ -58,7 +61,10 @@ func (suite *ReduceSuite) SetupSuite() {
 }
 
 func (suite *ReduceSuite) SetupTest() {
-	var err error
+	localDataRootPath := filepath.Join(paramtable.Get().LocalStorageCfg.Path.GetValue(), typeutil.QueryNodeRole)
+	initcore.InitLocalChunkManager(localDataRootPath)
+	err := initcore.InitMmapManager(paramtable.Get())
+	suite.NoError(err)
 	ctx := context.Background()
 	msgLength := 100
 
@@ -70,29 +76,22 @@ func (suite *ReduceSuite) SetupTest() {
 	suite.collectionID = 100
 	suite.partitionID = 10
 	suite.segmentID = 1
-	schema := GenTestCollectionSchema("test-reduce", schemapb.DataType_Int64, true)
-	suite.collection = NewCollection(suite.collectionID,
-		schema,
-		GenTestIndexMeta(suite.collectionID, schema),
-		&querypb.LoadMetaInfo{
-			LoadType: querypb.LoadType_LoadCollection,
-		})
-	suite.segment, err = NewSegment(ctx,
-		suite.collection,
-		SegmentTypeSealed,
-		0,
-		&querypb.SegmentLoadInfo{
-			SegmentID:     suite.segmentID,
-			CollectionID:  suite.collectionID,
-			PartitionID:   suite.partitionID,
-			NumOfRows:     int64(msgLength),
-			InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", suite.collectionID),
-			Level:         datapb.SegmentLevel_Legacy,
-		},
-	)
+	schema := mock_segcore.GenTestCollectionSchema("test-reduce", schemapb.DataType_Int64, true)
+	suite.collection, err = segcore.CreateCCollection(&segcore.CreateCCollectionRequest{
+		CollectionID: suite.collectionID,
+		Schema:       schema,
+		IndexMeta:    mock_segcore.GenTestIndexMeta(suite.collectionID, schema),
+	})
+	suite.NoError(err)
+	suite.segment, err = segcore.CreateCSegment(&segcore.CreateCSegmentRequest{
+		Collection:  suite.collection,
+		SegmentID:   suite.segmentID,
+		SegmentType: segcore.SegmentTypeSealed,
+		IsSorted:    false,
+	})
 	suite.Require().NoError(err)
 
-	binlogs, _, err := SaveBinLog(ctx,
+	binlogs, _, err := mock_segcore.SaveBinLog(ctx,
 		suite.collectionID,
 		suite.partitionID,
 		suite.segmentID,
@@ -101,15 +100,19 @@ func (suite *ReduceSuite) SetupTest() {
 		suite.chunkManager,
 	)
 	suite.Require().NoError(err)
+	req := &segcore.LoadFieldDataRequest{
+		RowCount: int64(msgLength),
+	}
 	for _, binlog := range binlogs {
-		err = suite.segment.(*LocalSegment).LoadFieldData(ctx, binlog.FieldID, int64(msgLength), binlog)
-		suite.Require().NoError(err)
+		req.Fields = append(req.Fields, segcore.LoadFieldDataInfo{Field: binlog})
 	}
+	_, err = suite.segment.LoadFieldData(ctx, req)
+	suite.Require().NoError(err)
 }
 
 func (suite *ReduceSuite) TearDownTest() {
-	suite.segment.Release(context.Background())
-	DeleteCollection(suite.collection)
+	suite.segment.Release()
+	suite.collection.Release()
 	ctx := context.Background()
 	suite.chunkManager.RemoveWithPrefix(ctx, suite.rootPath)
 }
@@ -118,7 +121,7 @@ func (suite *ReduceSuite) TestReduceParseSliceInfo() {
 	originNQs := []int64{2, 3, 2}
 	originTopKs := []int64{10, 5, 20}
 	nqPerSlice := int64(2)
-	sInfo := ParseSliceInfo(originNQs, originTopKs, nqPerSlice)
+	sInfo := segcore.ParseSliceInfo(originNQs, originTopKs, nqPerSlice)
 
 	expectedSliceNQs := []int64{2, 2, 1, 2}
 	expectedSliceTopKs := []int64{10, 5, 5, 20}
@@ -130,7 +133,7 @@ func (suite *ReduceSuite) TestReduceAllFunc() {
 	nq := int64(10)
 
 	// TODO: replace below by genPlaceholderGroup(nq)
-	vec := testutils.GenerateFloatVectors(1, defaultDim)
+	vec := testutils.GenerateFloatVectors(1, mock_segcore.DefaultDim)
 	var searchRawData []byte
 	for i, ele := range vec {
 		buf := make([]byte, 4)
@@ -167,35 +170,73 @@ func (suite *ReduceSuite) TestReduceAllFunc() {
 	>
 	placeholder_tag: "$0"
 	>`
 
-	var planpb planpb.PlanNode
+	var planNode planpb.PlanNode
 	// proto.UnmarshalText(planStr, &planpb)
-	prototext.Unmarshal([]byte(planStr), &planpb)
-	serializedPlan, err := proto.Marshal(&planpb)
-	suite.NoError(err)
-	plan, err := createSearchPlanByExpr(context.Background(), suite.collection, serializedPlan)
+	prototext.Unmarshal([]byte(planStr), &planNode)
+	serializedPlan, err := proto.Marshal(&planNode)
 	suite.NoError(err)
-	searchReq, err := parseSearchRequest(context.Background(), plan, placeGroupByte)
-	searchReq.mvccTimestamp = typeutil.MaxTimestamp
+	searchReq, err := segcore.NewSearchRequest(suite.collection, &querypb.SearchRequest{
+		Req: &internalpb.SearchRequest{
+			SerializedExprPlan: serializedPlan,
+			MvccTimestamp:      typeutil.MaxTimestamp,
+		},
+	}, placeGroupByte)
+	suite.NoError(err)
 	defer searchReq.Delete()
 
 	searchResult, err := suite.segment.Search(context.Background(), searchReq)
 	suite.NoError(err)
-	err = checkSearchResult(context.Background(), nq, plan, searchResult)
+	err = mock_segcore.CheckSearchResult(context.Background(), nq, searchReq.Plan(), searchResult)
+	suite.NoError(err)
+
+	// Test Illegal Query
+	retrievePlan, err := segcore.NewRetrievePlan(
+		suite.collection,
+		[]byte(fmt.Sprintf("%d > 100", mock_segcore.RowIDField.ID)),
+		typeutil.MaxTimestamp,
+		0)
+	suite.Error(err)
+	suite.Nil(retrievePlan)
+
+	plan := planpb.PlanNode{
+		Node: &planpb.PlanNode_Query{
+			Query: &planpb.QueryPlanNode{
+				IsCount: true,
+			},
+		},
+	}
+	expr, err := proto.Marshal(&plan)
+	suite.NoError(err)
+	retrievePlan, err = segcore.NewRetrievePlan(
+		suite.collection,
+		expr,
+		typeutil.MaxTimestamp,
+		0)
+	suite.NotNil(retrievePlan)
+	suite.NoError(err)
+
+	retrieveResult, err := suite.segment.Retrieve(context.Background(), retrievePlan)
+	suite.NotNil(retrieveResult)
+	suite.NoError(err)
+	result, err := retrieveResult.GetResult()
 	suite.NoError(err)
+	suite.NotNil(result)
+	suite.Equal(int64(100), result.AllRetrieveCount)
+	retrieveResult.Release()
 }
 
 func (suite *ReduceSuite) TestReduceInvalid() {
-	plan := &SearchPlan{}
-	_, err := ReduceSearchResultsAndFillData(context.Background(), plan, nil, 1, nil, nil)
+	plan := &segcore.SearchPlan{}
+	_, err := segcore.ReduceSearchResultsAndFillData(context.Background(), plan, nil, 1, nil, nil)
 	suite.Error(err)
 
-	searchReq, err := genSearchPlanAndRequests(suite.collection, []int64{suite.segmentID}, IndexHNSW, 10)
+	searchReq, err := mock_segcore.GenSearchPlanAndRequests(suite.collection, []int64{suite.segmentID}, mock_segcore.IndexHNSW, 10)
 	suite.NoError(err)
-	searchResults := make([]*SearchResult, 0)
+	searchResults := make([]*segcore.SearchResult, 0)
 	searchResults = append(searchResults, nil)
-	_, err = ReduceSearchResultsAndFillData(context.Background(), searchReq.plan, searchResults, 1, []int64{10}, []int64{10})
+	_, err = segcore.ReduceSearchResultsAndFillData(context.Background(), searchReq.Plan(), searchResults, 1, []int64{10}, []int64{10})
 	suite.Error(err)
 }
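The streaming counterpart of the reduce above folds results in one at a time. A hedged sketch, assuming the signatures implied by the truncated hunk headers of reduce.go, in particular that `StreamReduceSearchResult` takes the new result before the reducer; `streamReduce` is a hypothetical caller:

```go
package example

import (
	"context"

	"github.com/milvus-io/milvus/internal/util/segcore"
)

// streamReduce folds per-segment results into the stream reducer one at a
// time and materializes the merged blobs once at the end. The reducer must
// always be freed; the caller must also free the returned blobs with
// segcore.DeleteSearchResultDataBlobs.
func streamReduce(ctx context.Context, plan *segcore.SearchPlan, results []*segcore.SearchResult, sliceNQs, sliceTopKs []int64) (segcore.SearchResultDataBlobs, error) {
	reducer, err := segcore.NewStreamReducer(ctx, plan, sliceNQs, sliceTopKs)
	if err != nil {
		return nil, err
	}
	defer segcore.DeleteStreamReduceHelper(reducer)

	for _, res := range results {
		if err := segcore.StreamReduceSearchResult(ctx, res, reducer); err != nil {
			return nil, err
		}
	}
	return segcore.GetStreamReduceResult(ctx, reducer)
}
```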
diff --git a/internal/util/segcore/requests.go b/internal/util/segcore/requests.go
new file mode 100644
index 0000000000000..5b42aa8394275
--- /dev/null
+++ b/internal/util/segcore/requests.go
@@ -0,0 +1,99 @@
+package segcore
+
+/*
+#cgo pkg-config: milvus_core
+#include "segcore/load_field_data_c.h"
+*/
+import "C"
+
+import (
+	"unsafe"
+
+	"github.com/cockroachdb/errors"
+
+	"github.com/milvus-io/milvus/internal/proto/datapb"
+	"github.com/milvus-io/milvus/internal/proto/segcorepb"
+	"github.com/milvus-io/milvus/internal/storage"
+	"github.com/milvus-io/milvus/pkg/util/typeutil"
+)
+
+type RetrievePlanWithOffsets struct {
+	*RetrievePlan
+	Offsets []int64
+}
+
+type InsertRequest struct {
+	RowIDs     []int64
+	Timestamps []typeutil.Timestamp
+	Record     *segcorepb.InsertRecord
+}
+
+type DeleteRequest struct {
+	PrimaryKeys storage.PrimaryKeys
+	Timestamps  []typeutil.Timestamp
+}
+
+type LoadFieldDataRequest struct {
+	Fields   []LoadFieldDataInfo
+	MMapDir  string
+	RowCount int64
+}
+
+type LoadFieldDataInfo struct {
+	Field      *datapb.FieldBinlog
+	EnableMMap bool
+}
+
+func (req *LoadFieldDataRequest) getCLoadFieldDataRequest() (result *cLoadFieldDataRequest, err error) {
+	var cLoadFieldDataInfo C.CLoadFieldDataInfo
+	status := C.NewLoadFieldDataInfo(&cLoadFieldDataInfo)
+	if err := ConsumeCStatusIntoError(&status); err != nil {
+		return nil, errors.Wrap(err, "NewLoadFieldDataInfo failed")
+	}
+	defer func() {
+		if err != nil {
+			C.DeleteLoadFieldDataInfo(cLoadFieldDataInfo)
+		}
+	}()
+	rowCount := C.int64_t(req.RowCount)
+
+	for _, field := range req.Fields {
+		cFieldID := C.int64_t(field.Field.GetFieldID())
+
+		status = C.AppendLoadFieldInfo(cLoadFieldDataInfo, cFieldID, rowCount)
+		if err := ConsumeCStatusIntoError(&status); err != nil {
+			return nil, errors.Wrapf(err, "AppendLoadFieldInfo failed at fieldID %d", field.Field.GetFieldID())
+		}
+		for _, binlog := range field.Field.Binlogs {
+			cEntriesNum := C.int64_t(binlog.GetEntriesNum())
+			cFile := C.CString(binlog.GetLogPath())
+			// frees are deferred to function exit; the C strings must stay
+			// alive until the load info is fully assembled
+			defer C.free(unsafe.Pointer(cFile))
+
+			status = C.AppendLoadFieldDataPath(cLoadFieldDataInfo, cFieldID, cEntriesNum, cFile)
+			if err := ConsumeCStatusIntoError(&status); err != nil {
+				return nil, errors.Wrapf(err, "AppendLoadFieldDataPath failed at fieldID %d, binlog %s", field.Field.GetFieldID(), binlog.GetLogPath())
+			}
+		}
+
+		C.EnableMmap(cLoadFieldDataInfo, cFieldID, C.bool(field.EnableMMap))
+	}
+	if len(req.MMapDir) > 0 {
+		mmapDir := C.CString(req.MMapDir)
+		defer C.free(unsafe.Pointer(mmapDir))
+		C.AppendMMapDirPath(cLoadFieldDataInfo, mmapDir)
+	}
+	return &cLoadFieldDataRequest{
+		cLoadFieldDataInfo: cLoadFieldDataInfo,
+	}, nil
+}
+
+type cLoadFieldDataRequest struct {
+	cLoadFieldDataInfo C.CLoadFieldDataInfo
+}
+
+func (req *cLoadFieldDataRequest) Release() {
+	C.DeleteLoadFieldDataInfo(req.cLoadFieldDataInfo)
+}
+
+type AddFieldDataInfoRequest = LoadFieldDataRequest
+
+type AddFieldDataInfoResult struct{}
diff --git a/internal/util/segcore/requests_test.go b/internal/util/segcore/requests_test.go
new file mode 100644
index 0000000000000..7658087291176
--- /dev/null
+++ b/internal/util/segcore/requests_test.go
@@ -0,0 +1,33 @@
+package segcore
+
+import (
+	"testing"
+
+	"github.com/milvus-io/milvus/internal/proto/datapb"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestLoadFieldDataRequest(t *testing.T) {
+	req := &LoadFieldDataRequest{
+		Fields: []LoadFieldDataInfo{{
+			Field: &datapb.FieldBinlog{
+				FieldID: 1,
+				Binlogs: []*datapb.Binlog{
+					{
+						EntriesNum: 100,
+						LogPath:    "1",
+					}, {
+						EntriesNum: 101,
+						LogPath:    "2",
+					},
+				},
+			},
+		}},
+		RowCount: 100,
+		MMapDir:  "1234567890",
+	}
+	creq, err := req.getCLoadFieldDataRequest()
+	assert.NoError(t, err)
+	assert.NotNil(t, creq)
+	creq.Release()
+}
diff --git a/internal/util/segcore/responses.go b/internal/util/segcore/responses.go
new file mode 100644
index 0000000000000..f2405cbcf2569
--- /dev/null
+++ b/internal/util/segcore/responses.go
@@ -0,0 +1,47 @@
+package segcore
+
+/*
+#cgo pkg-config: milvus_core
+
+#include "segcore/plan_c.h"
+#include "segcore/reduce_c.h"
+*/
+import "C"
+
+import (
+	"github.com/milvus-io/milvus/internal/proto/segcorepb"
+)
+
+// SearchResult contains a pointer to the search result in C++ memory.
+type SearchResult struct {
+	cSearchResult C.CSearchResult
+}
+
+func (r *SearchResult) Release() {
+	C.DeleteSearchResult(r.cSearchResult)
+	r.cSearchResult = nil
+}
+
+// RetrieveResult contains a pointer to the retrieve result in C++ memory.
+type RetrieveResult struct {
+	cRetrieveResult *C.CRetrieveResult
+}
+
+func (r *RetrieveResult) GetResult() (*segcorepb.RetrieveResults, error) {
+	retrieveResult := new(segcorepb.RetrieveResults)
+	if err := unmarshalCProto(r.cRetrieveResult, retrieveResult); err != nil {
+		return nil, err
+	}
+	return retrieveResult, nil
+}
+
+func (r *RetrieveResult) Release() {
+	C.DeleteRetrieveResult(r.cRetrieveResult)
+	r.cRetrieveResult = nil
+}
+
+type InsertResult struct {
+	InsertedRows int64
+}
+
+type DeleteResult struct{}
+
+type LoadFieldDataResult struct{}
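A short sketch of the request shape consumed by the new load path (illustrative, not part of the diff); `loadSealedFields` is a hypothetical helper and the binlog paths are assumed to come from the segment load info:

```go
package example

import (
	"context"

	"github.com/milvus-io/milvus/internal/proto/datapb"
	"github.com/milvus-io/milvus/internal/util/segcore"
)

// loadSealedFields builds one LoadFieldDataInfo per field, each carrying the
// field's binlogs, plus an optional mmap directory, then hands the request to
// the sealed segment. rowCount must match the rows the binlogs contain.
func loadSealedFields(ctx context.Context, seg segcore.SealedSegment, rowCount int64, binlogs []*datapb.FieldBinlog, mmapDir string) error {
	req := &segcore.LoadFieldDataRequest{
		RowCount: rowCount,
		MMapDir:  mmapDir,
	}
	for _, fieldBinlog := range binlogs {
		req.Fields = append(req.Fields, segcore.LoadFieldDataInfo{
			Field:      fieldBinlog,
			EnableMMap: mmapDir != "",
		})
	}
	_, err := seg.LoadFieldData(ctx, req)
	return err
}
```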
diff --git a/internal/util/segcore/segcore_init.go b/internal/util/segcore/segcore_init.go
new file mode 100644
index 0000000000000..b05e056a692f9
--- /dev/null
+++ b/internal/util/segcore/segcore_init.go
@@ -0,0 +1,24 @@
+package segcore
+
+/*
+#cgo pkg-config: milvus_core
+
+#include "segcore/segcore_init_c.h"
+
+*/
+import "C"
+
+// IndexEngineInfo contains all the information about the index engine.
+type IndexEngineInfo struct {
+	MinIndexVersion     int32
+	CurrentIndexVersion int32
+}
+
+// GetIndexEngineInfo returns the minimal and current version of the index engine.
+func GetIndexEngineInfo() IndexEngineInfo {
+	cMinimal, cCurrent := C.GetMinimalIndexVersion(), C.GetCurrentIndexVersion()
+	return IndexEngineInfo{
+		MinIndexVersion:     int32(cMinimal),
+		CurrentIndexVersion: int32(cCurrent),
+	}
+}
diff --git a/internal/util/segcore/segcore_init_test.go b/internal/util/segcore/segcore_init_test.go
new file mode 100644
index 0000000000000..b6e8724be61d1
--- /dev/null
+++ b/internal/util/segcore/segcore_init_test.go
@@ -0,0 +1,13 @@
+package segcore
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestGetIndexEngineInfo(t *testing.T) {
+	r := GetIndexEngineInfo()
+	assert.NotZero(t, r.CurrentIndexVersion)
+	assert.Zero(t, r.MinIndexVersion)
+}
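Before any of the cgo paths in this package are exercised, some process-level initialization is required; the new unit tests show the order. A sketch mirroring that setup (illustrative, not part of the diff; `initSegcoreForTests` is hypothetical):

```go
package example

import (
	"path/filepath"

	"github.com/milvus-io/milvus/internal/util/initcore"
	"github.com/milvus-io/milvus/internal/util/segcore"
	"github.com/milvus-io/milvus/pkg/util/paramtable"
	"github.com/milvus-io/milvus/pkg/util/typeutil"
)

// initSegcoreForTests mirrors the setup the new tests perform before touching
// any cgo path: params first, then the local chunk manager and mmap manager.
// GetIndexEngineInfo is safe to call afterwards.
func initSegcoreForTests() (segcore.IndexEngineInfo, error) {
	paramtable.Init()
	localRoot := filepath.Join(paramtable.Get().LocalStorageCfg.Path.GetValue(), typeutil.QueryNodeRole)
	initcore.InitLocalChunkManager(localRoot)
	if err := initcore.InitMmapManager(paramtable.Get()); err != nil {
		return segcore.IndexEngineInfo{}, err
	}
	return segcore.GetIndexEngineInfo(), nil
}
```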
diff --git a/internal/util/segcore/segment.go b/internal/util/segcore/segment.go
new file mode 100644
index 0000000000000..87524870794cf
--- /dev/null
+++ b/internal/util/segcore/segment.go
@@ -0,0 +1,294 @@
+package segcore
+
+/*
+#cgo pkg-config: milvus_core
+
+#include "common/type_c.h"
+#include "futures/future_c.h"
+#include "segcore/collection_c.h"
+#include "segcore/plan_c.h"
+#include "segcore/reduce_c.h"
+#include "segcore/segment_c.h"
+*/
+import "C"
+
+import (
+	"context"
+	"fmt"
+	"runtime"
+	"unsafe"
+
+	"github.com/cockroachdb/errors"
+	"google.golang.org/protobuf/proto"
+
+	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
+	"github.com/milvus-io/milvus/internal/storage"
+	"github.com/milvus-io/milvus/internal/util/cgo"
+	"github.com/milvus-io/milvus/pkg/util/merr"
+)
+
+const (
+	SegmentTypeGrowing SegmentType = commonpb.SegmentState_Growing
+	SegmentTypeSealed  SegmentType = commonpb.SegmentState_Sealed
+)
+
+type (
+	SegmentType       = commonpb.SegmentState
+	CSegmentInterface C.CSegmentInterface
+)
+
+// CreateCSegmentRequest is a request to create a segment.
+type CreateCSegmentRequest struct {
+	Collection    *CCollection
+	SegmentID     int64
+	SegmentType   SegmentType
+	IsSorted      bool
+	EnableChunked bool
+}
+
+func (req *CreateCSegmentRequest) getCSegmentType() C.SegmentType {
+	var segmentType C.SegmentType
+	switch req.SegmentType {
+	case SegmentTypeGrowing:
+		segmentType = C.Growing
+	case SegmentTypeSealed:
+		if req.EnableChunked {
+			segmentType = C.ChunkedSealed
+			break
+		}
+		segmentType = C.Sealed
+	default:
+		panic(fmt.Sprintf("invalid segment type: %d", req.SegmentType))
+	}
+	return segmentType
+}
+
+// CreateCSegment creates a segment from a CreateCSegmentRequest.
+func CreateCSegment(req *CreateCSegmentRequest) (CSegment, error) {
+	var ptr C.CSegmentInterface
+	status := C.NewSegment(req.Collection.rawPointer(), req.getCSegmentType(), C.int64_t(req.SegmentID), &ptr, C.bool(req.IsSorted))
+	if err := ConsumeCStatusIntoError(&status); err != nil {
+		return nil, err
+	}
+	return &cSegmentImpl{id: req.SegmentID, ptr: ptr}, nil
+}
+
+// cSegmentImpl is the concrete implementation of the CSegment interface,
+// wrapping the underlying C.CSegmentInterface pointer.
+type cSegmentImpl struct {
+	id  int64
+	ptr C.CSegmentInterface
+}
+
+// ID returns the ID of the segment.
+func (s *cSegmentImpl) ID() int64 {
+	return s.id
+}
+
+// RawPointer returns the raw pointer of the segment.
+func (s *cSegmentImpl) RawPointer() CSegmentInterface {
+	return CSegmentInterface(s.ptr)
+}
+
+// RowNum returns the number of rows in the segment.
+func (s *cSegmentImpl) RowNum() int64 {
+	rowCount := C.GetRealCount(s.ptr)
+	return int64(rowCount)
+}
+
+// MemSize returns the memory size of the segment.
+func (s *cSegmentImpl) MemSize() int64 {
+	cMemSize := C.GetMemoryUsageInBytes(s.ptr)
+	return int64(cMemSize)
+}
+
+// HasRawData checks if the segment has raw data.
+func (s *cSegmentImpl) HasRawData(fieldID int64) bool {
+	ret := C.HasRawData(s.ptr, C.int64_t(fieldID))
+	return bool(ret)
+}
+
+// Search requests a search on the segment.
+func (s *cSegmentImpl) Search(ctx context.Context, searchReq *SearchRequest) (*SearchResult, error) {
+	traceCtx := ParseCTraceContext(ctx)
+	defer runtime.KeepAlive(traceCtx)
+	defer runtime.KeepAlive(searchReq)
+
+	future := cgo.Async(ctx,
+		func() cgo.CFuturePtr {
+			return (cgo.CFuturePtr)(C.AsyncSearch(
+				traceCtx.ctx,
+				s.ptr,
+				searchReq.plan.cSearchPlan,
+				searchReq.cPlaceholderGroup,
+				C.uint64_t(searchReq.mvccTimestamp),
+			))
+		},
+		cgo.WithName("search"),
+	)
+	defer future.Release()
+	result, err := future.BlockAndLeakyGet()
+	if err != nil {
+		return nil, err
+	}
+	return &SearchResult{cSearchResult: (C.CSearchResult)(result)}, nil
+}
+
+// Retrieve retrieves entities from the segment.
+func (s *cSegmentImpl) Retrieve(ctx context.Context, plan *RetrievePlan) (*RetrieveResult, error) {
+	traceCtx := ParseCTraceContext(ctx)
+	defer runtime.KeepAlive(traceCtx)
+	defer runtime.KeepAlive(plan)
+
+	future := cgo.Async(
+		ctx,
+		func() cgo.CFuturePtr {
+			return (cgo.CFuturePtr)(C.AsyncRetrieve(
+				traceCtx.ctx,
+				s.ptr,
+				plan.cRetrievePlan,
+				C.uint64_t(plan.Timestamp),
+				C.int64_t(plan.maxLimitSize),
+				C.bool(plan.ignoreNonPk),
+			))
+		},
+		cgo.WithName("retrieve"),
+	)
+	defer future.Release()
+	result, err := future.BlockAndLeakyGet()
+	if err != nil {
+		return nil, err
+	}
+	return &RetrieveResult{cRetrieveResult: (*C.CRetrieveResult)(result)}, nil
+}
+
+// RetrieveByOffsets retrieves entities from the segment by offsets.
+func (s *cSegmentImpl) RetrieveByOffsets(ctx context.Context, plan *RetrievePlanWithOffsets) (*RetrieveResult, error) {
+	if len(plan.Offsets) == 0 {
+		return nil, merr.WrapErrParameterInvalid("segment offsets", "empty offsets")
+	}
+
+	traceCtx := ParseCTraceContext(ctx)
+	defer runtime.KeepAlive(traceCtx)
+	defer runtime.KeepAlive(plan)
+	defer runtime.KeepAlive(plan.Offsets)
+
+	future := cgo.Async(
+		ctx,
+		func() cgo.CFuturePtr {
+			return (cgo.CFuturePtr)(C.AsyncRetrieveByOffsets(
+				traceCtx.ctx,
+				s.ptr,
+				plan.cRetrievePlan,
+				(*C.int64_t)(unsafe.Pointer(&plan.Offsets[0])),
+				C.int64_t(len(plan.Offsets)),
+			))
+		},
+		cgo.WithName("retrieve-by-offsets"),
+	)
+	defer future.Release()
+	result, err := future.BlockAndLeakyGet()
+	if err != nil {
+		return nil, err
+	}
+	return &RetrieveResult{cRetrieveResult: (*C.CRetrieveResult)(result)}, nil
+}
+
+// Insert inserts entities into the segment.
+func (s *cSegmentImpl) Insert(ctx context.Context, request *InsertRequest) (*InsertResult, error) {
+	offset, err := s.preInsert(len(request.RowIDs))
+	if err != nil {
+		return nil, err
+	}
+
+	insertRecordBlob, err := proto.Marshal(request.Record)
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal insert record: %s", err)
+	}
+
+	numOfRow := len(request.RowIDs)
+	cOffset := C.int64_t(offset)
+	cNumOfRows := C.int64_t(numOfRow)
+	cEntityIDsPtr := (*C.int64_t)(&(request.RowIDs)[0])
+	cTimestampsPtr := (*C.uint64_t)(&(request.Timestamps)[0])
+
+	status := C.Insert(s.ptr,
+		cOffset,
+		cNumOfRows,
+		cEntityIDsPtr,
+		cTimestampsPtr,
+		(*C.uint8_t)(unsafe.Pointer(&insertRecordBlob[0])),
+		(C.uint64_t)(len(insertRecordBlob)),
+	)
+	return &InsertResult{InsertedRows: int64(numOfRow)}, ConsumeCStatusIntoError(&status)
+}
+
+func (s *cSegmentImpl) preInsert(numOfRecords int) (int64, error) {
+	var offset int64
+	cOffset := (*C.int64_t)(&offset)
+	status := C.PreInsert(s.ptr, C.int64_t(int64(numOfRecords)), cOffset)
+	if err := ConsumeCStatusIntoError(&status); err != nil {
+		return 0, err
+	}
+	return offset, nil
+}
+
+// Delete deletes entities from the segment.
+func (s *cSegmentImpl) Delete(ctx context.Context, request *DeleteRequest) (*DeleteResult, error) {
+	if request.PrimaryKeys.Len() == 0 {
+		return &DeleteResult{}, nil
+	}
+
+	cOffset := C.int64_t(0) // deprecated, always 0
+	cSize := C.int64_t(request.PrimaryKeys.Len())
+	cTimestampsPtr := (*C.uint64_t)(&(request.Timestamps)[0])
+
+	ids, err := storage.ParsePrimaryKeysBatch2IDs(request.PrimaryKeys)
+	if err != nil {
+		return nil, errors.Wrap(err, "failed to parse primary keys")
+	}
+	dataBlob, err := proto.Marshal(ids)
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal ids: %s", err)
+	}
+	status := C.Delete(s.ptr,
+		cOffset,
+		cSize,
+		(*C.uint8_t)(unsafe.Pointer(&dataBlob[0])),
+		(C.uint64_t)(len(dataBlob)),
+		cTimestampsPtr,
+	)
+	return &DeleteResult{}, ConsumeCStatusIntoError(&status)
+}
+
+// LoadFieldData loads field data into the segment.
+func (s *cSegmentImpl) LoadFieldData(ctx context.Context, request *LoadFieldDataRequest) (*LoadFieldDataResult, error) {
+	creq, err := request.getCLoadFieldDataRequest()
+	if err != nil {
+		return nil, err
+	}
+	defer creq.Release()
+
+	status := C.LoadFieldData(s.ptr, creq.cLoadFieldDataInfo)
+	if err := ConsumeCStatusIntoError(&status); err != nil {
+		return nil, errors.Wrap(err, "failed to load field data")
+	}
+	return &LoadFieldDataResult{}, nil
+}
+
+// AddFieldDataInfo adds field data info into the segment.
+func (s *cSegmentImpl) AddFieldDataInfo(ctx context.Context, request *AddFieldDataInfoRequest) (*AddFieldDataInfoResult, error) {
+	creq, err := request.getCLoadFieldDataRequest()
+	if err != nil {
+		return nil, err
+	}
+	defer creq.Release()
+
+	status := C.AddFieldDataInfoForSealed(s.ptr, creq.cLoadFieldDataInfo)
+	if err := ConsumeCStatusIntoError(&status); err != nil {
+		return nil, errors.Wrap(err, "failed to add field data info")
+	}
+	return &AddFieldDataInfoResult{}, nil
+}
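A usage sketch for the write path above (illustrative, not part of the diff): `PreInsert` reservation is handled inside `Insert`, so a caller only assembles aligned row IDs, timestamps, and the insert record; `insertRows` is hypothetical and `record` is assumed to match the collection schema:

```go
package example

import (
	"context"

	"github.com/milvus-io/milvus/internal/proto/segcorepb"
	"github.com/milvus-io/milvus/internal/util/segcore"
	"github.com/milvus-io/milvus/pkg/util/typeutil"
)

// insertRows writes a batch into a growing segment and reports how many rows
// were accepted. rowIDs and tss must have the same length as record.NumRows.
func insertRows(ctx context.Context, seg segcore.GrowingSegment, rowIDs []int64, tss []typeutil.Timestamp, record *segcorepb.InsertRecord) (int64, error) {
	res, err := seg.Insert(ctx, &segcore.InsertRequest{
		RowIDs:     rowIDs,
		Timestamps: tss,
		Record:     record,
	})
	if err != nil {
		return 0, err
	}
	return res.InsertedRows, nil
}
```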
+// Release releases the segment.
+func (s *cSegmentImpl) Release() {
+	C.DeleteSegment(s.ptr)
+}
diff --git a/internal/util/segcore/segment_interface.go b/internal/util/segcore/segment_interface.go
new file mode 100644
index 0000000000000..303db2be36b45
--- /dev/null
+++ b/internal/util/segcore/segment_interface.go
@@ -0,0 +1,74 @@
+package segcore
+
+/*
+#cgo pkg-config: milvus_core
+
+#include "common/type_c.h"
+
+*/
+import "C"
+
+import "context"
+
+// CSegment is the interface of a segcore segment.
+// TODO: We should separate the interfaces of CGrowingSegment and CSealedSegment,
+// because they have different implementations: GrowingSegment is only used at the
+// streamingnode and SealedSegment is only used at the querynode.
+// For now they share one interface to stay compatible with the querynode LocalSegment.
+type CSegment interface {
+	GrowingSegment
+
+	SealedSegment
+}
+
+// GrowingSegment is the interface of a growing segment.
+type GrowingSegment interface {
+	basicSegmentMethodSet
+
+	// Insert inserts data into the segment.
+	Insert(ctx context.Context, request *InsertRequest) (*InsertResult, error)
+}
+
+// SealedSegment is the interface of a sealed segment.
+type SealedSegment interface {
+	basicSegmentMethodSet
+
+	// LoadFieldData loads field data into the segment.
+	LoadFieldData(ctx context.Context, request *LoadFieldDataRequest) (*LoadFieldDataResult, error)
+
+	// AddFieldDataInfo adds field data info into the segment.
+	AddFieldDataInfo(ctx context.Context, request *AddFieldDataInfoRequest) (*AddFieldDataInfoResult, error)
+}
+
+// basicSegmentMethodSet is the basic method set of a segment.
+type basicSegmentMethodSet interface {
+	// ID returns the ID of the segment.
+	ID() int64
+
+	// RawPointer returns the raw pointer of the segment.
+	// TODO: should be removed in future.
+	RawPointer() CSegmentInterface
+
+	// RowNum returns the number of rows in the segment.
+	RowNum() int64
+
+	// MemSize returns the memory size of the segment.
+	MemSize() int64
+
+	// HasRawData checks if the segment has raw data.
+	HasRawData(fieldID int64) bool
+
+	// Search requests a search on the segment.
+	Search(ctx context.Context, searchReq *SearchRequest) (*SearchResult, error)
+
+	// Retrieve retrieves entities from the segment.
+	Retrieve(ctx context.Context, plan *RetrievePlan) (*RetrieveResult, error)
+
+	// RetrieveByOffsets retrieves entities from the segment by offsets.
+	RetrieveByOffsets(ctx context.Context, plan *RetrievePlanWithOffsets) (*RetrieveResult, error)
+
+	// Delete deletes data from the segment.
+	Delete(ctx context.Context, request *DeleteRequest) (*DeleteResult, error)
+
+	// Release releases the segment.
+	Release()
+}
diff --git a/internal/util/segcore/segment_test.go b/internal/util/segcore/segment_test.go
new file mode 100644
index 0000000000000..930683e2a8976
--- /dev/null
+++ b/internal/util/segcore/segment_test.go
@@ -0,0 +1,136 @@
+package segcore_test
+
+import (
+	"context"
+	"path/filepath"
+	"testing"
+
+	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
+	"github.com/milvus-io/milvus/internal/mocks/util/mock_segcore"
+	"github.com/milvus-io/milvus/internal/proto/planpb"
+	"github.com/milvus-io/milvus/internal/proto/segcorepb"
+	"github.com/milvus-io/milvus/internal/storage"
+	"github.com/milvus-io/milvus/internal/util/initcore"
+	"github.com/milvus-io/milvus/internal/util/segcore"
+	"github.com/milvus-io/milvus/pkg/util/paramtable"
+	"github.com/milvus-io/milvus/pkg/util/typeutil"
+	"github.com/stretchr/testify/assert"
+	"google.golang.org/protobuf/proto"
+)
+
+func TestGrowingSegment(t *testing.T) {
+	paramtable.Init()
+	localDataRootPath := filepath.Join(paramtable.Get().LocalStorageCfg.Path.GetValue(), typeutil.QueryNodeRole)
+	initcore.InitLocalChunkManager(localDataRootPath)
+	err := initcore.InitMmapManager(paramtable.Get())
+	assert.NoError(t, err)
+
+	collectionID := int64(100)
+	segmentID := int64(100)
+
+	schema := mock_segcore.GenTestCollectionSchema("test-reduce", schemapb.DataType_Int64, true)
+	collection, err := segcore.CreateCCollection(&segcore.CreateCCollectionRequest{
+		CollectionID: collectionID,
+		Schema:       schema,
+		IndexMeta:    mock_segcore.GenTestIndexMeta(collectionID, schema),
+	})
+	assert.NoError(t, err)
+	assert.NotNil(t, collection)
+	defer collection.Release()
+
+	segment, err := segcore.CreateCSegment(&segcore.CreateCSegmentRequest{
+		Collection:  collection,
+		SegmentID:   segmentID,
+		SegmentType: segcore.SegmentTypeGrowing,
+		IsSorted:    false,
+	})
+	assert.NoError(t, err)
+	assert.NotNil(t, segment)
+	defer segment.Release()
+
+	assert.Equal(t, segmentID, segment.ID())
+	assert.Equal(t, int64(0), segment.RowNum())
+	assert.Zero(t, segment.MemSize())
+	assert.True(t, segment.HasRawData(0))
+	assertEqualCount(t, collection, segment, 0)
+
+	insertMsg, err := mock_segcore.GenInsertMsg(collection, 1, segmentID, 100)
+	assert.NoError(t, err)
+	insertResult, err := segment.Insert(context.Background(), &segcore.InsertRequest{
+		RowIDs:     insertMsg.RowIDs,
+		Timestamps: insertMsg.Timestamps,
+		Record: &segcorepb.InsertRecord{
+			FieldsData: insertMsg.FieldsData,
+			NumRows:    int64(len(insertMsg.RowIDs)),
+		},
+	})
+	assert.NoError(t, err)
+	assert.NotNil(t, insertResult)
+	assert.Equal(t, int64(100), insertResult.InsertedRows)
+
+	assert.Equal(t, int64(100), segment.RowNum())
+	assertEqualCount(t, collection, segment, 100)
+
+	pks := storage.NewInt64PrimaryKeys(1)
+	pks.AppendRaw(10)
+	deleteResult, err := segment.Delete(context.Background(), &segcore.DeleteRequest{
+		PrimaryKeys: pks,
+		Timestamps: []typeutil.Timestamp{
+			1000,
+		},
+	})
+	assert.NoError(t, err)
+	assert.NotNil(t, deleteResult)
+
+	assert.Equal(t, int64(99), segment.RowNum())
+}
+
+func assertEqualCount(
+	t *testing.T,
+	collection *segcore.CCollection,
+	segment segcore.CSegment,
+	count int64,
+) {
+	plan := planpb.PlanNode{
+		Node: &planpb.PlanNode_Query{
+			Query: &planpb.QueryPlanNode{
+				IsCount: true,
+			},
+		},
+	}
+	expr, err := proto.Marshal(&plan)
+	assert.NoError(t, err)
+	retrievePlan, err := segcore.NewRetrievePlan(
+		collection,
+		expr,
+		typeutil.MaxTimestamp,
+		100)
+	assert.NoError(t, err)
+	assert.NotNil(t, retrievePlan)
+	defer retrievePlan.Delete()
+
+	assert.True(t, retrievePlan.ShouldIgnoreNonPk())
+	assert.False(t, retrievePlan.IsIgnoreNonPk())
+	retrievePlan.SetIgnoreNonPk(true)
+	assert.True(t, retrievePlan.IsIgnoreNonPk())
+	assert.NotZero(t, retrievePlan.MsgID())
+
+	retrieveResult, err := segment.Retrieve(context.Background(), retrievePlan)
+	assert.NotNil(t, retrieveResult)
+	assert.NoError(t, err)
+	result, err := retrieveResult.GetResult()
+	assert.NoError(t, err)
+	assert.NotNil(t, result)
+	assert.Equal(t, count, result.AllRetrieveCount)
+	retrieveResult.Release()
+
+	retrieveResult2, err := segment.RetrieveByOffsets(context.Background(), &segcore.RetrievePlanWithOffsets{
+		RetrievePlan: retrievePlan,
+		Offsets:      []int64{0, 1, 2, 3, 4},
+	})
+	assert.NoError(t, err)
+	assert.NotNil(t, retrieveResult2)
+	retrieveResult2.Release()
+}
diff --git a/internal/util/segcore/trace.go b/internal/util/segcore/trace.go
new file mode 100644
index 0000000000000..523e8ef4892d0
--- /dev/null
+++ b/internal/util/segcore/trace.go
@@ -0,0 +1,56 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package segcore
+
+/*
+#cgo pkg-config: milvus_core
+
+#include "segcore/segment_c.h"
+*/
+import "C"
+
+import (
+	"context"
+	"unsafe"
+
+	"go.opentelemetry.io/otel/trace"
+)
+
+// CTraceContext is the wrapper for `C.CTraceContext`. It stores the traceID
+// and spanID on the Go side so that the pointers held by the internal
+// C.CTraceContext stay valid for as long as the wrapper is kept alive.
+type CTraceContext struct {
+	traceID trace.TraceID
+	spanID  trace.SpanID
+	ctx     C.CTraceContext
+}
+
+// ParseCTraceContext parses the tracing span and converts it into `C.CTraceContext`.
+func ParseCTraceContext(ctx context.Context) *CTraceContext {
+	span := trace.SpanFromContext(ctx)
+
+	cctx := &CTraceContext{
+		traceID: span.SpanContext().TraceID(),
+		spanID:  span.SpanContext().SpanID(),
+	}
+	cctx.ctx = C.CTraceContext{
+		traceID:    (*C.uint8_t)(unsafe.Pointer(&cctx.traceID[0])),
+		spanID:     (*C.uint8_t)(unsafe.Pointer(&cctx.spanID[0])),
+		traceFlags: (C.uint8_t)(span.SpanContext().TraceFlags()),
+	}
+
+	return cctx
+}
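A closing sketch of the pattern trace.go implies (illustrative, not part of the diff): because the C struct points into Go-owned buffers, the wrapper must be kept alive until the C call that consumes it returns, which is exactly what segment.go does with `runtime.KeepAlive`; `withCTrace` is a hypothetical helper:

```go
package example

import (
	"context"
	"runtime"

	"github.com/milvus-io/milvus/internal/util/segcore"
)

// withCTrace parses the span from ctx and guarantees the Go-side buffers
// behind the C trace context outlive the cgo call made inside `call`.
func withCTrace(ctx context.Context, call func(t *segcore.CTraceContext)) {
	traceCtx := segcore.ParseCTraceContext(ctx)
	defer runtime.KeepAlive(traceCtx)
	call(traceCtx)
}
```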