Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement #4442 , goctl generate unit test files for api handler and logic #4443

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions tools/goctl/api/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ func init() {
goCmdFlags.StringVar(&gogen.VarStringHome, "home")
goCmdFlags.StringVar(&gogen.VarStringRemote, "remote")
goCmdFlags.StringVar(&gogen.VarStringBranch, "branch")
goCmdFlags.BoolVar(&gogen.VarBoolWithTest, "test")
goCmdFlags.StringVarWithDefaultValue(&gogen.VarStringStyle, "style", config.DefaultFormat)

javaCmdFlags.StringVar(&javagen.VarStringDir, "dir")
Expand Down
12 changes: 9 additions & 3 deletions tools/goctl/api/gogen/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ var (
// VarStringBranch describes the branch.
VarStringBranch string
// VarStringStyle describes the style of output files.
VarStringStyle string
VarStringStyle string
VarBoolWithTest bool
)

// GoCommand gen go project files from command line
Expand All @@ -49,6 +50,7 @@ func GoCommand(_ *cobra.Command, _ []string) error {
home := VarStringHome
remote := VarStringRemote
branch := VarStringBranch
withTest := VarBoolWithTest
if len(remote) > 0 {
repo, _ := util.CloneIntoGitHome(remote, branch)
if len(repo) > 0 {
Expand All @@ -66,11 +68,11 @@ func GoCommand(_ *cobra.Command, _ []string) error {
return errors.New("missing -dir")
}

return DoGenProject(apiFile, dir, namingStyle)
return DoGenProject(apiFile, dir, namingStyle, withTest)
}

// DoGenProject gen go project files with api file
func DoGenProject(apiFile, dir, style string) error {
func DoGenProject(apiFile, dir, style string, withTest bool) error {
api, err := parser.Parse(apiFile)
if err != nil {
return err
Expand Down Expand Up @@ -100,6 +102,10 @@ func DoGenProject(apiFile, dir, style string) error {
logx.Must(genHandlers(dir, rootPkg, cfg, api))
logx.Must(genLogic(dir, rootPkg, cfg, api))
logx.Must(genMiddleware(dir, cfg, api))
if withTest {
logx.Must(genHandlersTest(dir, rootPkg, cfg, api))
logx.Must(genLogicTest(dir, rootPkg, cfg, api))
}

if err := backupAndSweep(apiFile); err != nil {
return err
Expand Down
2 changes: 1 addition & 1 deletion tools/goctl/api/gogen/gen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ func validateWithCamel(t *testing.T, api, camel string) {
assert.Nil(t, err)
err = initMod(dir)
assert.Nil(t, err)
err = DoGenProject(api, dir, camel)
err = DoGenProject(api, dir, camel, true)
assert.Nil(t, err)
filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
if strings.HasSuffix(path, ".go") {
Expand Down
80 changes: 80 additions & 0 deletions tools/goctl/api/gogen/genhandlerstest.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package gogen

import (
_ "embed"
"fmt"
"strings"

"github.com/zeromicro/go-zero/tools/goctl/api/spec"
"github.com/zeromicro/go-zero/tools/goctl/config"
"github.com/zeromicro/go-zero/tools/goctl/util"
"github.com/zeromicro/go-zero/tools/goctl/util/format"
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
)

//go:embed handler_test.tpl
var handlerTestTemplate string

func genHandlerTest(dir, rootPkg string, cfg *config.Config, group spec.Group, route spec.Route) error {
handler := getHandlerName(route)
handlerPath := getHandlerFolderPath(group, route)
pkgName := handlerPath[strings.LastIndex(handlerPath, "/")+1:]
logicName := defaultLogicPackage
if handlerPath != handlerDir {
handler = strings.Title(handler)
logicName = pkgName
}
filename, err := format.FileNamingFormat(cfg.NamingFormat, handler)
if err != nil {
return err
}

return genFile(fileGenConfig{
dir: dir,
subdir: getHandlerFolderPath(group, route),
filename: filename + "_test.go",
templateName: "handlerTestTemplate",
category: category,
templateFile: handlerTestTemplateFile,
builtinTemplate: handlerTestTemplate,
data: map[string]any{
"PkgName": pkgName,
"ImportPackages": genHandlerTestImports(group, route, rootPkg),
"HandlerName": handler,
"RequestType": util.Title(route.RequestTypeName()),
"ResponseType": util.Title(route.ResponseTypeName()),
"LogicName": logicName,
"LogicType": strings.Title(getLogicName(route)),
"Call": strings.Title(strings.TrimSuffix(handler, "Handler")),
"HasResp": len(route.ResponseTypeName()) > 0,
"HasRequest": len(route.RequestTypeName()) > 0,
"HasDoc": len(route.JoinedDoc()) > 0,
"Doc": getDoc(route.JoinedDoc()),
},
})
}

func genHandlersTest(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error {
for _, group := range api.Service.Groups {
for _, route := range group.Routes {
if err := genHandlerTest(dir, rootPkg, cfg, group, route); err != nil {
return err
}
}
}

return nil
}

func genHandlerTestImports(group spec.Group, route spec.Route, parentPkg string) string {
imports := []string{
//fmt.Sprintf("\"%s\"", pathx.JoinPackages(parentPkg, getLogicFolderPath(group, route))),
fmt.Sprintf("\"%s\"", pathx.JoinPackages(parentPkg, contextDir)),
fmt.Sprintf("\"%s\"", pathx.JoinPackages(parentPkg, configDir)),
}
if len(route.RequestTypeName()) > 0 {
imports = append(imports, fmt.Sprintf("\"%s\"\n", pathx.JoinPackages(parentPkg, typesDir)))
}

return strings.Join(imports, "\n\t")
}
90 changes: 90 additions & 0 deletions tools/goctl/api/gogen/genlogictest.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package gogen

import (
_ "embed"
"fmt"
"strings"

"github.com/zeromicro/go-zero/tools/goctl/api/spec"
"github.com/zeromicro/go-zero/tools/goctl/config"
"github.com/zeromicro/go-zero/tools/goctl/util/format"
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
)

//go:embed logic_test.tpl
var logicTestTemplate string

func genLogicTest(dir, rootPkg string, cfg *config.Config, api *spec.ApiSpec) error {
for _, g := range api.Service.Groups {
for _, r := range g.Routes {
err := genLogicTestByRoute(dir, rootPkg, cfg, g, r)
if err != nil {
return err
}
}
}
return nil
}

func genLogicTestByRoute(dir, rootPkg string, cfg *config.Config, group spec.Group, route spec.Route) error {
logic := getLogicName(route)
goFile, err := format.FileNamingFormat(cfg.NamingFormat, logic)
if err != nil {
return err
}

imports := genLogicTestImports(route, rootPkg)
var responseString string
var returnString string
var requestString string
var requestType string
if len(route.ResponseTypeName()) > 0 {
resp := responseGoTypeName(route, typesPacket)
responseString = "(resp " + resp + ", err error)"
returnString = "return"
} else {
responseString = "error"
returnString = "return nil"
}
if len(route.RequestTypeName()) > 0 {
requestString = "req *" + requestGoTypeName(route, typesPacket)
requestType = requestGoTypeName(route, typesPacket)
}

subDir := getLogicFolderPath(group, route)
return genFile(fileGenConfig{
dir: dir,
subdir: subDir,
filename: goFile + "_test.go",
templateName: "logicTestTemplate",
category: category,
templateFile: logicTestTemplateFile,
builtinTemplate: logicTestTemplate,
data: map[string]any{
"pkgName": subDir[strings.LastIndex(subDir, "/")+1:],
"imports": imports,
"logic": strings.Title(logic),
"function": strings.Title(strings.TrimSuffix(logic, "Logic")),
"responseType": responseString,
"returnString": returnString,
"request": requestString,
"hasRequest": len(requestType) > 0,
"hasResponse": len(route.ResponseTypeName()) > 0,
"requestType": requestType,
"hasDoc": len(route.JoinedDoc()) > 0,
"doc": getDoc(route.JoinedDoc()),
},
})
}

func genLogicTestImports(route spec.Route, parentPkg string) string {
var imports []string
//imports = append(imports, `"context"`+"\n")
imports = append(imports, fmt.Sprintf("\"%s\"", pathx.JoinPackages(parentPkg, contextDir)))
imports = append(imports, fmt.Sprintf("\"%s\"", pathx.JoinPackages(parentPkg, configDir)))
if shallImportTypesPackage(route) {
imports = append(imports, fmt.Sprintf("\"%s\"\n", pathx.JoinPackages(parentPkg, typesDir)))
}
//imports = append(imports, fmt.Sprintf("\"%s/core/logx\"", vars.ProjectOpenSourceURL))
return strings.Join(imports, "\n\t")
}
81 changes: 81 additions & 0 deletions tools/goctl/api/gogen/handler_test.tpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package {{.PkgName}}

import (
"bytes"
{{if .HasRequest}}"encoding/json"{{end}}
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
{{.ImportPackages}}
)

{{if .HasDoc}}{{.Doc}}{{end}}
func Test{{.HandlerName}}(t *testing.T) {
// new service context
c := config.Config{}
svcCtx := svc.NewServiceContext(c)
// init mock service context here

tests := []struct {
name string
reqBody interface{}
wantStatus int
wantResp string
setupMocks func()
}{
{
name: "invalid request body",
reqBody: "invalid",
wantStatus: http.StatusBadRequest,
wantResp: "unsupported type", // Adjust based on actual error response
setupMocks: func() {
// No setup needed for this test case
},
},
{
name: "handler error",
{{if .HasRequest}}reqBody: types.{{.RequestType}}{
//TODO: add fields here
},
{{end}}wantStatus: http.StatusBadRequest,
wantResp: "error", // Adjust based on actual error response
setupMocks: func() {
// Mock login logic to return an error
},
},
{
name: "handler successful",
{{if .HasRequest}}reqBody: types.{{.RequestType}}{
//TODO: add fields here
},
{{end}}wantStatus: http.StatusOK,
wantResp: `{"code":0,"msg":"success","data":{}}`, // Adjust based on actual success response
setupMocks: func() {
// Mock login logic to return success
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setupMocks()
var reqBody []byte
{{if .HasRequest}}var err error
reqBody, err = json.Marshal(tt.reqBody)
require.NoError(t, err){{end}}
req, err := http.NewRequest("POST", "/ut", bytes.NewBuffer(reqBody))
require.NoError(t, err)
req.Header.Set("Content-Type", "application/json")

rr := httptest.NewRecorder()
handler := {{.HandlerName}}(svcCtx)
handler.ServeHTTP(rr, req)
t.Log(rr.Body.String())
assert.Equal(t, tt.wantStatus, rr.Code)
assert.Contains(t, rr.Body.String(), tt.wantResp)
})
}
}
69 changes: 69 additions & 0 deletions tools/goctl/api/gogen/logic_test.tpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package {{.pkgName}}

import (
"context"
"testing"

{{.imports}}
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func Test{{.logic}}_{{.function}}(t *testing.T) {
c := config.Config{}
mockSvcCtx := svc.NewServiceContext(c)
// init mock service context here

tests := []struct {
name string
ctx context.Context
setupMocks func()
{{if .hasRequest}}req *{{.requestType}}{{end}}
wantErr bool
checkResp func{{if .hasResponse}}{{.responseType}}{{else}}(err error){{end}}
}{
{
name: "response error",
ctx: context.Background(),
setupMocks: func() {
// mock data for this test case
},
{{if .hasRequest}}req: &{{.requestType}}{
// TODO: init your request here
},{{end}}
wantErr: true,
checkResp: func{{if .hasResponse}}{{.responseType}}{{else}}(err error){{end}} {
// TODO: Add your check logic here
},
},
{
name: "successful",
ctx: context.Background(),
setupMocks: func() {
// Mock data for this test case
},
{{if .hasRequest}}req: &{{.requestType}}{
// TODO: init your request here
},{{end}}
wantErr: false,
checkResp: func{{if .hasResponse}}{{.responseType}}{{else}}(err error){{end}} {
// TODO: Add your check logic here
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setupMocks()
l := New{{.logic}}(tt.ctx, mockSvcCtx)
{{if .hasResponse}}resp, {{end}}err := l.{{.function}}({{if .hasRequest}}tt.req{{end}})
if tt.wantErr {
assert.Error(t, err)
} else {
require.NoError(t, err)
{{if .hasResponse}}assert.NotNil(t, resp){{end}}
}
tt.checkResp({{if .hasResponse}}resp, {{end}}err)
})
}
}
Loading