Skip to content

Commit

Permalink
feat: fp32 vector to fp16/bf16 vector conversion for RESTful API
Browse files Browse the repository at this point in the history
issue: #37448

Signed-off-by: Yinzuo Jiang <[email protected]>
Signed-off-by: Yinzuo Jiang <[email protected]>
  • Loading branch information
jiangyinzuo committed Nov 8, 2024
1 parent ebc3c82 commit 033d02f
Show file tree
Hide file tree
Showing 5 changed files with 227 additions and 71 deletions.
23 changes: 16 additions & 7 deletions internal/distributed/proxy/httpserver/handler_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1764,7 +1764,7 @@ func TestSearchV2(t *testing.T) {
Ids: generateIDs(schemapb.DataType_Int64, 3),
Scores: DefaultScores,
}}, nil).Once()
mp.EXPECT().HybridSearch(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Times(3)
mp.EXPECT().HybridSearch(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Times(4)
collSchema := generateCollectionSchema(schemapb.DataType_Int64, false)
binaryVectorField := generateVectorFieldSchema(schemapb.DataType_BinaryVector)
binaryVectorField.Name = "binaryVector"
Expand All @@ -1783,7 +1783,7 @@ func TestSearchV2(t *testing.T) {
Schema: collSchema,
ShardsNum: ShardNumDefault,
Status: &StatusSuccess,
}, nil).Times(10)
}, nil).Times(11)
mp.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{Status: commonSuccessStatus, Results: &schemapb.SearchResultData{TopK: int64(0)}}, nil).Times(3)
mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
Status: &commonpb.Status{
Expand Down Expand Up @@ -1882,6 +1882,15 @@ func TestSearchV2(t *testing.T) {
`{"data": ["AQIDBA=="], "annsField": "bfloat16Vector", "metricType": "L2", "limit": 3}` +
`], "rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: AdvancedSearchAction,
requestBody: []byte(`{"collectionName": "hello_milvus", "search": [` +
`{"data": [[0.1, 0.2]], "annsField": "book_intro", "metricType": "L2", "limit": 3},` +
`{"data": ["AQ=="], "annsField": "binaryVector", "metricType": "L2", "limit": 3},` +
`{"data": [[0.1, 0.23]], "annsField": "float16Vector", "metricType": "L2", "limit": 3},` +
`{"data": [[0.1, 0.43]], "annsField": "bfloat16Vector", "metricType": "L2", "limit": 3}` +
`], "rerank": {"strategy": "weighted", "params": {"weights": [0.9, 0.8]}}}`),
})
queryTestCases = append(queryTestCases, requestBodyTestCase{
path: AdvancedSearchAction,
requestBody: []byte(`{"collectionName": "hello_milvus", "search": [` +
Expand Down Expand Up @@ -1960,19 +1969,19 @@ func TestSearchV2(t *testing.T) {
errCode: 1100, // ErrParameterInvalid
})

for _, testcase := range queryTestCases {
for i, testcase := range queryTestCases {
t.Run(testcase.path, func(t *testing.T) {
bodyReader := bytes.NewReader(testcase.requestBody)
req := httptest.NewRequest(http.MethodPost, versionalV2(EntityCategory, testcase.path), bodyReader)
w := httptest.NewRecorder()
testEngine.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, http.StatusOK, w.Code, "case %d", i, string(testcase.requestBody))
returnBody := &ReturnErrMsg{}
err := json.Unmarshal(w.Body.Bytes(), returnBody)
assert.Nil(t, err)
assert.Equal(t, testcase.errCode, returnBody.Code)
assert.Nil(t, err, "case %d", i)
assert.Equal(t, testcase.errCode, returnBody.Code, "case %d", i, string(testcase.requestBody))
if testcase.errCode != 0 {
assert.Equal(t, testcase.errMsg, returnBody.Message)
assert.Equal(t, testcase.errMsg, returnBody.Message, "case %d", i, string(testcase.requestBody))
}
fmt.Println(w.Body.String())
})
Expand Down
100 changes: 84 additions & 16 deletions internal/distributed/proxy/httpserver/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -324,27 +324,25 @@ func checkAndSetData(body string, collSchema *schemapb.CollectionSchema) (error,
}
reallyData[fieldName] = sparseVec
case schemapb.DataType_Float16Vector:
if dataString == "" {
return merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(fieldType)], "", "missing vector field: "+fieldName), reallyDataArray, validDataMap
}
vectorStr := gjson.Get(data.Raw, fieldName).Raw
var vectorArray []byte
err := json.Unmarshal([]byte(vectorStr), &vectorArray)
if err != nil {
return merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(fieldType)], dataString, err.Error()), reallyDataArray, validDataMap
}
reallyData[fieldName] = vectorArray
fallthrough
case schemapb.DataType_BFloat16Vector:
if dataString == "" {
return merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(fieldType)], "", "missing vector field: "+fieldName), reallyDataArray, validDataMap
}
vectorStr := gjson.Get(data.Raw, fieldName).Raw
// []float32 or []byte
var vectorArray []byte
err := json.Unmarshal([]byte(vectorStr), &vectorArray)
if err != nil {
return merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(fieldType)], dataString, err.Error()), reallyDataArray, validDataMap
var float32Array []float32
err = json.Unmarshal([]byte(vectorStr), &float32Array)
if err != nil {
return merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(fieldType)], dataString, err.Error()), reallyDataArray, validDataMap
}
reallyData[fieldName] = serialize(float32Array)
} else {
reallyData[fieldName] = vectorArray
}
reallyData[fieldName] = vectorArray
case schemapb.DataType_Bool:
result, err := cast.ToBoolE(dataString)
if err != nil {
Expand Down Expand Up @@ -584,17 +582,39 @@ func convertFloatVectorToArray(vector [][]float32, dim int64) ([]float32, error)
}

func convertBinaryVectorToArray(vector [][]byte, dim int64, dataType schemapb.DataType) ([]byte, error) {
binaryArray := make([]byte, 0)
var bytesLen int64
var float32BytesLen int64 = -1
switch dataType {
case schemapb.DataType_BinaryVector:
bytesLen = dim / 8
case schemapb.DataType_Float16Vector:
bytesLen = dim * 2
float32BytesLen = dim * 4
case schemapb.DataType_BFloat16Vector:
bytesLen = dim * 2
float32BytesLen = dim * 4
}
binaryArray := make([]byte, 0, len(vector)*int(bytesLen))
for _, arr := range vector {
if int64(len(arr)) == float32BytesLen {
switch dataType {
// convert float32 to float16
case schemapb.DataType_Float16Vector:
for i := int64(0); i < float32BytesLen; i += 4 {
f32 := typeutil.BytesToFloat32(arr[i : i+4])
f16Bytes := typeutil.Float32ToFloat16Bytes(f32)
binaryArray = append(binaryArray, f16Bytes...)
}
// convert float32 to bfloat16
case schemapb.DataType_BFloat16Vector:
for i := int64(0); i < float32BytesLen; i += 4 {
f32 := typeutil.BytesToFloat32(arr[i : i+4])
f16Bytes := typeutil.Float32ToBFloat16Bytes(f32)
binaryArray = append(binaryArray, f16Bytes...)
}
}
continue
}
if int64(len(arr)) != bytesLen {
return nil, fmt.Errorf("[]byte size %d doesn't equal to vector dimension %d of %s",
len(arr), dim, schemapb.DataType_name[int32(dataType)])
Expand Down Expand Up @@ -1034,6 +1054,24 @@ func serialize(fv []float32) []byte {
return data
}

// seriaizeToFloat16 encodes every element of the float32 vector `fv` with
// typeutil.Float32ToFloat16Bytes and returns the encodings concatenated in
// element order.
func seriaizeToFloat16(fv []float32) []byte {
	// Reserve two bytes per element, the size of one float16 value.
	const bytesPerElem = 2
	out := make([]byte, 0, bytesPerElem*len(fv))
	for idx := range fv {
		encoded := typeutil.Float32ToFloat16Bytes(fv[idx])
		out = append(out, encoded...)
	}
	return out
}

// seriaizeToBFloat16 encodes every element of the float32 vector `fv` with
// typeutil.Float32ToBFloat16Bytes and returns the encodings concatenated in
// element order.
func seriaizeToBFloat16(fv []float32) []byte {
	// Reserve two bytes per element, the size of one bfloat16 value.
	const bytesPerElem = 2
	out := make([]byte, 0, bytesPerElem*len(fv))
	for idx := range fv {
		encoded := typeutil.Float32ToBFloat16Bytes(fv[idx])
		out = append(out, encoded...)
	}
	return out
}

func serializeFloatVectors(vectors []gjson.Result, dataType schemapb.DataType, dimension, bytesLen int64) ([][]byte, error) {
values := make([][]byte, 0)
for _, vector := range vectors {
Expand All @@ -1052,9 +1090,39 @@ func serializeFloatVectors(vectors []gjson.Result, dataType schemapb.DataType, d
return values, nil
}

// serializeByteVectorsMaybeFloat32 decodes the JSON text vectorStr into one
// byte slice per vector for a field of type dataType.
//
// The text is first unmarshalled as [][]byte; when that succeeds, every decoded
// vector must be exactly bytesLen bytes long. When it fails, the text is
// unmarshalled as [][]float32 instead: every vector must then contain exactly
// dimension elements and is converted to bytes with serializeFunc.
//
// Any decoding failure or length mismatch is reported through
// merr.WrapErrParameterInvalid, and the returned slice is nil in that case.
func serializeByteVectorsMaybeFloat32(vectorStr string, dataType schemapb.DataType, dimension, bytesLen int64, serializeFunc func([]float32) []byte) ([][]byte, error) {
	typeName := schemapb.DataType_name[int32(dataType)]
	payload := []byte(vectorStr)

	values := make([][]byte, 0)
	if bytesErr := json.Unmarshal(payload, &values); bytesErr == nil {
		// The input already is a list of byte vectors: only the lengths need checking.
		for _, raw := range values {
			if int64(len(raw)) == bytesLen {
				continue
			}
			return nil, merr.WrapErrParameterInvalid(typeName, string(raw),
				fmt.Sprintf("dimension: %d, bytesLen: %d, but length of []byte: %d", dimension, bytesLen, len(raw)))
		}
		return values, nil
	}

	// Decoding as byte vectors failed, so retry as a list of float32 vectors.
	fp32Values := make([][]float32, 0)
	if floatErr := json.Unmarshal(payload, &fp32Values); floatErr != nil {
		return nil, merr.WrapErrParameterInvalid(typeName, vectorStr, floatErr.Error())
	}

	// Start from a fresh slice so nothing left over from the failed byte decode is returned.
	values = make([][]byte, 0)
	for _, floats := range fp32Values {
		if int64(len(floats)) != dimension {
			// The marshal error is ignored: the text is only used inside the error message.
			floatsText, _ := json.MarshalToString(floats)
			return nil, merr.WrapErrParameterInvalid(typeName, floatsText,
				fmt.Sprintf("dimension: %d, but length of []float: %d", dimension, len(floats)))
		}
		values = append(values, serializeFunc(floats))
	}
	return values, nil
}

func serializeByteVectors(vectorStr string, dataType schemapb.DataType, dimension, bytesLen int64) ([][]byte, error) {
values := make([][]byte, 0)
err := json.Unmarshal([]byte(vectorStr), &values) // todo check len == dimension * 1/2/2
err := json.Unmarshal([]byte(vectorStr), &values)
if err != nil {
return nil, merr.WrapErrParameterInvalid(schemapb.DataType_name[int32(dataType)], vectorStr, err.Error())
}
Expand Down Expand Up @@ -1093,10 +1161,10 @@ func convertQueries2Placeholder(body string, dataType schemapb.DataType, dimensi
values, err = serializeByteVectors(gjson.Get(body, HTTPRequestData).Raw, dataType, dimension, dimension/8)
case schemapb.DataType_Float16Vector:
valueType = commonpb.PlaceholderType_Float16Vector
values, err = serializeByteVectors(gjson.Get(body, HTTPRequestData).Raw, dataType, dimension, dimension*2)
values, err = serializeByteVectorsMaybeFloat32(gjson.Get(body, HTTPRequestData).Raw, dataType, dimension, dimension*2, seriaizeToFloat16)
case schemapb.DataType_BFloat16Vector:
valueType = commonpb.PlaceholderType_BFloat16Vector
values, err = serializeByteVectors(gjson.Get(body, HTTPRequestData).Raw, dataType, dimension, dimension*2)
values, err = serializeByteVectorsMaybeFloat32(gjson.Get(body, HTTPRequestData).Raw, dataType, dimension, dimension*2, seriaizeToBFloat16)
case schemapb.DataType_SparseFloatVector:
valueType = commonpb.PlaceholderType_SparseFloatVector
values, err = serializeSparseFloatVectors(gjson.Get(body, HTTPRequestData).Array(), dataType)
Expand Down
108 changes: 78 additions & 30 deletions internal/distributed/proxy/httpserver/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,20 @@ func TestInsertWithDefaultValueField(t *testing.T) {
func TestSerialize(t *testing.T) {
parameters := []float32{0.11111, 0.22222}
assert.Equal(t, "\xa4\x8d\xe3=\xa4\x8dc>", string(serialize(parameters)))

f16vec := seriaizeToFloat16(parameters)
assert.Equal(t, 4, len(f16vec))
// \x1c/ is 0.1111, \x1c3 is 0.2222
assert.Equal(t, "\x1c/\x1c3", string(f16vec))
assert.Equal(t, "\x1c/", string(typeutil.Float32ToFloat16Bytes(0.11111)))
assert.Equal(t, "\x1c3", string(typeutil.Float32ToFloat16Bytes(0.22222)))

bf16vec := seriaizeToBFloat16(parameters)
assert.Equal(t, 4, len(bf16vec))
assert.Equal(t, "\xe3=c>", string(bf16vec))
assert.Equal(t, "\xe3=", string(typeutil.Float32ToBFloat16Bytes(0.11111)))
assert.Equal(t, "c>", string(typeutil.Float32ToBFloat16Bytes(0.22222)))

assert.Equal(t, "\n\x10\n\x02$0\x10e\x1a\b\xa4\x8d\xe3=\xa4\x8dc>", string(vectors2PlaceholderGroupBytes([][]float32{parameters}))) // todo
requestBody := "{\"data\": [[0.11111, 0.22222]]}"
vectors := gjson.Get(requestBody, HTTPRequestData)
Expand Down Expand Up @@ -1613,31 +1627,50 @@ func TestVector(t *testing.T) {
float16Vector := "vector-float16"
bfloat16Vector := "vector-bfloat16"
sparseFloatVector := "vector-sparse-float"
row1 := map[string]interface{}{
FieldBookID: int64(1),
floatVector: []float32{0.1, 0.11},
binaryVector: []byte{1},
float16Vector: []byte{1, 1, 11, 11},
bfloat16Vector: []byte{1, 1, 11, 11},
sparseFloatVector: map[uint32]float32{0: 0.1, 1: 0.11},
}
row2 := map[string]interface{}{
FieldBookID: int64(2),
floatVector: []float32{0.2, 0.22},
binaryVector: []byte{2},
float16Vector: []byte{2, 2, 22, 22},
bfloat16Vector: []byte{2, 2, 22, 22},
sparseFloatVector: map[uint32]float32{1000: 0.3, 200: 0.44},
testcaseRows := []map[string]interface{}{
map[string]interface{}{
FieldBookID: int64(1),
floatVector: []float32{0.1, 0.11},
binaryVector: []byte{1},
float16Vector: []byte{1, 1, 11, 11},
bfloat16Vector: []byte{1, 1, 11, 11},
sparseFloatVector: map[uint32]float32{0: 0.1, 1: 0.11},
},
map[string]interface{}{
FieldBookID: int64(2),
floatVector: []float32{0.2, 0.22},
binaryVector: []byte{2},
float16Vector: []byte{2, 2, 22, 22},
bfloat16Vector: []byte{2, 2, 22, 22},
sparseFloatVector: map[uint32]float32{1000: 0.3, 200: 0.44},
},
map[string]interface{}{
FieldBookID: int64(3),
floatVector: []float32{0.3, 0.33},
binaryVector: []byte{3},
float16Vector: []byte{3, 3, 33, 33},
bfloat16Vector: []byte{3, 3, 33, 33},
sparseFloatVector: map[uint32]float32{987621: 32190.31, 32189: 0.0001},
},
map[string]interface{}{
FieldBookID: int64(4),
floatVector: []float32{0.4, 0.44},
binaryVector: []byte{4},
float16Vector: []float32{0.4, 0.44},
bfloat16Vector: []float32{0.4, 0.44},
sparseFloatVector: map[uint32]float32{25: 0.1, 1: 0.11},
},
map[string]interface{}{
FieldBookID: int64(5),
floatVector: []float32{-0.4, -0.44},
binaryVector: []byte{5},
float16Vector: []int64{99999999, -99999999},
bfloat16Vector: []int64{99999999, -99999999},
sparseFloatVector: map[uint32]float32{1121: 0.1, 3: 0.11},
},
}
row3 := map[string]interface{}{
FieldBookID: int64(3),
floatVector: []float32{0.3, 0.33},
binaryVector: []byte{3},
float16Vector: []byte{3, 3, 33, 33},
bfloat16Vector: []byte{3, 3, 33, 33},
sparseFloatVector: map[uint32]float32{987621: 32190.31, 32189: 0.0001},
}
body, _ := wrapRequestBody([]map[string]interface{}{row1, row2, row3})
body, err := wrapRequestBody(testcaseRows)
assert.Nil(t, err)
primaryField := generatePrimaryField(schemapb.DataType_Int64, false)
floatVectorField := generateVectorFieldSchema(schemapb.DataType_FloatVector)
floatVectorField.Name = floatVector
Expand All @@ -1660,10 +1693,27 @@ func TestVector(t *testing.T) {
}
err, rows, validRows := checkAndSetData(string(body), collectionSchema)
assert.Equal(t, nil, err)
for _, row := range rows {
for i, row := range rows {
assert.Equal(t, 2, len(row[floatVector].([]float32)))
assert.Equal(t, 1, len(row[binaryVector].([]byte)))
assert.Equal(t, 4, len(row[float16Vector].([]byte)))
assert.Equal(t, 4, len(row[bfloat16Vector].([]byte)))
if fv, ok := testcaseRows[i][float16Vector].([]float32); ok {
assert.Equal(t, 8, len(row[float16Vector].([]byte)))
assert.Equal(t, serialize(fv), row[float16Vector].([]byte))
} else if _, ok := testcaseRows[i][float16Vector].([]int64); ok {
assert.Equal(t, 8, len(row[float16Vector].([]byte)))
} else {
assert.Equal(t, 4, len(row[float16Vector].([]byte)))
assert.Equal(t, testcaseRows[i][float16Vector].([]byte), row[float16Vector].([]byte))
}
if fv, ok := testcaseRows[i][bfloat16Vector].([]float32); ok {
assert.Equal(t, 8, len(row[bfloat16Vector].([]byte)))
assert.Equal(t, serialize(fv), row[bfloat16Vector].([]byte))
} else if _, ok := testcaseRows[i][bfloat16Vector].([]int64); ok {
assert.Equal(t, 8, len(row[bfloat16Vector].([]byte)))
} else {
assert.Equal(t, 4, len(row[bfloat16Vector].([]byte)))
assert.Equal(t, testcaseRows[i][bfloat16Vector].([]byte), row[bfloat16Vector].([]byte))
}
// all test sparse rows have 2 elements, each should be of 8 bytes
assert.Equal(t, 16, len(row[sparseFloatVector].([]byte)))
}
Expand All @@ -1674,7 +1724,7 @@ func TestVector(t *testing.T) {

assertError := func(field string, value interface{}) {
row := make(map[string]interface{})
for k, v := range row1 {
for k, v := range testcaseRows[0] {
row[k] = v
}
row[field] = value
Expand All @@ -1683,8 +1733,6 @@ func TestVector(t *testing.T) {
assert.Error(t, err)
}

assertError(bfloat16Vector, []int64{99999999, -99999999})
assertError(float16Vector, []int64{99999999, -99999999})
assertError(binaryVector, []int64{99999999, -99999999})
assertError(floatVector, []float64{math.MaxFloat64, 0})
assertError(sparseFloatVector, map[uint32]float32{0: -0.1, 1: 0.11, 2: 0.12})
Expand Down
2 changes: 2 additions & 0 deletions internal/json/sonic.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import (

var (
json = sonic.ConfigStd
// MarshalToString is exported by gin/json package.
MarshalToString = json.MarshalToString
// Marshal is exported by gin/json package.
Marshal = json.Marshal
// Unmarshal is exported by gin/json package.
Expand Down
Loading

0 comments on commit 033d02f

Please sign in to comment.