Skip to content

Commit

Permalink
slice: 重构 slice 中使用 equalFunc 的方法
Browse files Browse the repository at this point in the history
1. 引入 matchFunc
2. 简化 slice 中查找类型的 API
3. 引入 Find 和 FindAll 方法
  • Loading branch information
flycash committed Aug 14, 2023
1 parent 5f533c9 commit f290bac
Show file tree
Hide file tree
Showing 10 changed files with 241 additions and 38 deletions.
10 changes: 6 additions & 4 deletions slice/contains.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@ package slice

// Contains 判断 src 里面是否存在 dst
func Contains[T comparable](src []T, dst T) bool {
return ContainsFunc[T](src, dst, func(src, dst T) bool {
return ContainsFunc[T](src, func(src T) bool {
return src == dst
})
}

// ContainsFunc 判断 src 里面是否存在 dst
// 你应该优先使用 Contains
func ContainsFunc[T any](src []T, dst T, equal equalFunc[T]) bool {
func ContainsFunc[T any](src []T, equal func(src T) bool) bool {
// 遍历调用equal函数进行判断
for _, v := range src {
if equal(v, dst) {
if equal(v) {
return true
}
}
Expand Down Expand Up @@ -72,7 +72,9 @@ func ContainsAll[T comparable](src, dst []T) bool {
// 你应该优先使用 ContainsAll
func ContainsAllFunc[T any](src, dst []T, equal equalFunc[T]) bool {
for _, valDst := range dst {
if !ContainsFunc[T](src, valDst, equal) {
if !ContainsFunc[T](src, func(src T) bool {
return equal(src, valDst)
}) {
return false
}
}
Expand Down
8 changes: 4 additions & 4 deletions slice/contains_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ func TestContainsFunc(t *testing.T) {
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
assert.Equal(t, test.want, ContainsFunc[int](test.src, test.dst, func(src, dst int) bool {
return src == dst
assert.Equal(t, test.want, ContainsFunc[int](test.src, func(src int) bool {
return src == test.dst
}))
})
}
Expand Down Expand Up @@ -287,8 +287,8 @@ func ExampleContains() {
}

func ExampleContainsFunc() {
res := ContainsFunc[int]([]int{1, 2, 3}, 3, func(src, dst int) bool {
return src == dst
res := ContainsFunc[int]([]int{1, 2, 3}, func(src int) bool {
return src == 3
})
fmt.Println(res)
// Output:
Expand Down
5 changes: 3 additions & 2 deletions slice/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,11 @@ func DiffSet[T comparable](src, dst []T) []T {
// DiffSetFunc 差集,已去重
// 你应该优先使用 DiffSet
func DiffSetFunc[T any](src, dst []T, equal equalFunc[T]) []T {
// TODO 优化容量预估
var ret = make([]T, 0, len(src))
for _, val := range src {
if !ContainsFunc[T](dst, val, equal) {
if !ContainsFunc[T](dst, func(src T) bool {
return equal(src, val)
}) {
ret = append(ret, val)
}
}
Expand Down
43 changes: 43 additions & 0 deletions slice/find.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright 2021 ecodeclub
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package slice

// Find 查找元素
// 如果没有找到,第二个返回值返回 false
func Find[T any](src []T, match matchFunc[T]) (T, bool) {
for _, val := range src {
if match(val) {
return val, true
}
}
var t T
return t, false
}

// FindAll 查找所有符合条件的元素
// 永远不会返回 nil
func FindAll[T any](src []T, match matchFunc[T]) []T {
// 我们认为符合条件元素应该是少数
// 所以会除以 8
// 也就是触发扩容的情况下,最多三次就会和原本的容量一样
// +1 是为了保证,至少有一个元素
res := make([]T, 0, len(src)>>3+1)
for _, val := range src {
if match(val) {
res = append(res, val)
}
}
return res
}
149 changes: 149 additions & 0 deletions slice/find_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
// Copyright 2021 ecodeclub
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package slice

import (
"fmt"
"testing"

"github.com/stretchr/testify/assert"
)

func TestFind(t *testing.T) {
testCases := []struct {
name string
input []Number
match matchFunc[Number]

wantVal Number
found bool
}{
{
name: "找到了",
input: []Number{
{val: 123},
{val: 234},
},
match: func(src Number) bool {
return src.val == 123
},
wantVal: Number{val: 123},
found: true,
},
{
name: "没找到",
input: []Number{
{val: 123},
{val: 234},
},
match: func(src Number) bool {
return src.val == 456
},
},
{
name: "nil",
match: func(src Number) bool {
return src.val == 123
},
},
{
name: "没有元素",
input: []Number{},
match: func(src Number) bool {
return src.val == 123
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
val, found := Find[Number](tc.input, tc.match)
assert.Equal(t, tc.found, found)
assert.Equal(t, tc.wantVal, val)
})
}
}

func TestFindAll(t *testing.T) {
testCases := []struct {
name string
input []Number
match matchFunc[Number]

wantVals []Number
}{
{
name: "没有符合条件的",
input: []Number{{val: 2}, {val: 4}},
match: func(src Number) bool {
return src.val%2 == 1
},
wantVals: []Number{},
},
{
name: "找到了",
input: []Number{{val: 2}, {val: 3}, {val: 4}},
match: func(src Number) bool {
return src.val%2 == 1
},
wantVals: []Number{{val: 3}},
},
{
name: "nil",
match: func(src Number) bool {
return src.val%2 == 1
},
wantVals: []Number{},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
vals := FindAll[Number](tc.input, tc.match)
assert.Equal(t, tc.wantVals, vals)
})
}
}

func ExampleFind() {
val, ok := Find[int]([]int{1, 2, 3}, func(src int) bool {
return src == 2
})
fmt.Println(val, ok)
val, ok = Find[int]([]int{1, 2, 3}, func(src int) bool {
return src == 4
})
fmt.Println(val, ok)
// Output:
// 2 true
// 0 false
}

func ExampleFindAll() {
vals := FindAll[int]([]int{2, 3, 4}, func(src int) bool {
return src%2 == 1
})
fmt.Println(vals)
vals = FindAll[int]([]int{2, 3, 4}, func(src int) bool {
return src > 5
})
fmt.Println(vals)
// Output:
// [3]
// []
}

type Number struct {
val int
}
22 changes: 11 additions & 11 deletions slice/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,17 @@ package slice
// Index 返回和 dst 相等的第一个元素下标
// -1 表示没找到
func Index[T comparable](src []T, dst T) int {
return IndexFunc[T](src, dst, func(src, dst T) bool {
return IndexFunc[T](src, func(src T) bool {
return src == dst
})
}

// IndexFunc 返回和 dst 相等的第一个元素下标
// IndexFunc 返回 match 返回 true 的第一个下标
// -1 表示没找到
// 你应该优先使用 Index
func IndexFunc[T any](src []T, dst T, equal equalFunc[T]) int {
func IndexFunc[T any](src []T, match matchFunc[T]) int {
for k, v := range src {
if equal(v, dst) {
if match(v) {
return k
}
}
Expand All @@ -37,17 +37,17 @@ func IndexFunc[T any](src []T, dst T, equal equalFunc[T]) int {
// LastIndex 返回和 dst 相等的最后一个元素下标
// -1 表示没找到
func LastIndex[T comparable](src []T, dst T) int {
return LastIndexFunc[T](src, dst, func(src, dst T) bool {
return LastIndexFunc[T](src, func(src T) bool {
return src == dst
})
}

// LastIndexFunc 返回和 dst 相等的最后一个元素下标
// -1 表示没找到
// 你应该优先使用 LastIndex
func LastIndexFunc[T any](src []T, dst T, equal equalFunc[T]) int {
func LastIndexFunc[T any](src []T, match matchFunc[T]) int {
for i := len(src) - 1; i >= 0; i-- {
if equal(dst, src[i]) {
if match(src[i]) {
return i
}
}
Expand All @@ -56,17 +56,17 @@ func LastIndexFunc[T any](src []T, dst T, equal equalFunc[T]) int {

// IndexAll 返回和 dst 相等的所有元素的下标
func IndexAll[T comparable](src []T, dst T) []int {
return IndexAllFunc[T](src, dst, func(src, dst T) bool {
return IndexAllFunc[T](src, func(src T) bool {
return src == dst
})
}

// IndexAllFunc 返回和 dst 相等的所有元素的下标
// IndexAllFunc 返回和 match 返回 true 的所有元素的下标
// 你应该优先使用 IndexAll
func IndexAllFunc[T any](src []T, dst T, equal equalFunc[T]) []int {
func IndexAllFunc[T any](src []T, match matchFunc[T]) []int {
var indexes = make([]int, 0, len(src))
for k, v := range src {
if equal(v, dst) {
if match(v) {
indexes = append(indexes, k)
}
}
Expand Down
28 changes: 14 additions & 14 deletions slice/index_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ func TestIndexFunc(t *testing.T) {
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
assert.Equal(t, test.want, IndexFunc[int](test.src, test.dst, func(src, dst int) bool {
return src == dst
assert.Equal(t, test.want, IndexFunc[int](test.src, func(src int) bool {
return src == test.dst
}))
})
}
Expand Down Expand Up @@ -191,8 +191,8 @@ func TestLastIndexFunc(t *testing.T) {
},
}
for _, test := range tests {
assert.Equal(t, test.want, LastIndexFunc[int](test.src, test.dst, func(src, dst int) bool {
return src == dst
assert.Equal(t, test.want, LastIndexFunc[int](test.src, func(src int) bool {
return src == test.dst
}))
}
}
Expand Down Expand Up @@ -268,8 +268,8 @@ func TestIndexAllFunc(t *testing.T) {
},
}
for _, test := range tests {
res := IndexAllFunc[int](test.src, test.dst, func(src, dst int) bool {
return src == dst
res := IndexAllFunc[int](test.src, func(src int) bool {
return src == test.dst
})
assert.ElementsMatch(t, test.want, res)
}
Expand All @@ -286,12 +286,12 @@ func ExampleIndex() {
}

func ExampleIndexFunc() {
res := IndexFunc[int]([]int{1, 2, 3}, 1, func(src, dst int) bool {
return src == dst
res := IndexFunc[int]([]int{1, 2, 3}, func(src int) bool {
return src == 1
})
fmt.Println(res)
res = IndexFunc[int]([]int{1, 2, 3}, 4, func(src, dst int) bool {
return src == dst
res = IndexFunc[int]([]int{1, 2, 3}, func(src int) bool {
return src == 4
})
fmt.Println(res)
// Output:
Expand All @@ -310,12 +310,12 @@ func ExampleIndexAll() {
}

func ExampleIndexAllFunc() {
res := IndexAllFunc[int]([]int{1, 2, 3, 4, 5, 3, 9}, 3, func(src, dst int) bool {
return src == dst
res := IndexAllFunc[int]([]int{1, 2, 3, 4, 5, 3, 9}, func(src int) bool {
return src == 3
})
fmt.Println(res)
res = IndexAllFunc[int]([]int{1, 2, 3}, 4, func(src, dst int) bool {
return src == dst
res = IndexAllFunc[int]([]int{1, 2, 3}, func(src int) bool {
return src == 4
})
fmt.Println(res)
// Output:
Expand Down
Loading

0 comments on commit f290bac

Please sign in to comment.