From 3ac3c9de53ad81f43d70bebfb642a89c2c69a4c1 Mon Sep 17 00:00:00 2001 From: Guilherme Soares <48023091+guilhas07@users.noreply.github.com> Date: Mon, 2 Sep 2024 08:14:40 -0700 Subject: [PATCH] fix: send SIGTERM signal to --cmd instead of SIGKILL (#687) Co-authored-by: Adrian Hesketh Co-authored-by: Adrian Hesketh --- cmd/templ/generatecmd/run/run_test.go | 108 ++++++++++++++++++ cmd/templ/generatecmd/run/run_unix.go | 46 ++++++-- .../generatecmd/run/testprogram/go.mod.embed | 3 + cmd/templ/generatecmd/run/testprogram/main.go | 63 ++++++++++ 4 files changed, 209 insertions(+), 11 deletions(-) create mode 100644 cmd/templ/generatecmd/run/run_test.go create mode 100644 cmd/templ/generatecmd/run/testprogram/go.mod.embed create mode 100644 cmd/templ/generatecmd/run/testprogram/main.go diff --git a/cmd/templ/generatecmd/run/run_test.go b/cmd/templ/generatecmd/run/run_test.go new file mode 100644 index 000000000..a1a0cff1f --- /dev/null +++ b/cmd/templ/generatecmd/run/run_test.go @@ -0,0 +1,108 @@ +package run_test + +import ( + "context" + "embed" + "io" + "net/http" + "os" + "path/filepath" + "syscall" + "testing" + "time" + + "github.com/a-h/templ/cmd/templ/generatecmd/run" +) + +//go:embed testprogram/* +var testprogram embed.FS + +func TestGoRun(t *testing.T) { + if testing.Short() { + t.Skip("Skipping test in short mode.") + } + + // Copy testprogram to a temporary directory. + dir, err := os.MkdirTemp("", "testprogram") + if err != nil { + t.Fatalf("failed to make test dir: %v", err) + } + files, err := testprogram.ReadDir("testprogram") + if err != nil { + t.Fatalf("failed to read embedded dir: %v", err) + } + for _, file := range files { + srcFileName := "testprogram/" + file.Name() + srcData, err := testprogram.ReadFile(srcFileName) + if err != nil { + t.Fatalf("failed to read src file %q: %v", srcFileName, err) + } + tgtFileName := filepath.Join(dir, file.Name()) + tgtFile, err := os.Create(tgtFileName) + if err != nil { + t.Fatalf("failed to create tgt file %q: %v", tgtFileName, err) + } + defer tgtFile.Close() + if _, err := tgtFile.Write(srcData); err != nil { + t.Fatalf("failed to write to tgt file %q: %v", tgtFileName, err) + } + } + // Rename the go.mod.embed file to go.mod. + if err := os.Rename(filepath.Join(dir, "go.mod.embed"), filepath.Join(dir, "go.mod")); err != nil { + t.Fatalf("failed to rename go.mod.embed: %v", err) + } + + tests := []struct { + name string + cmd string + }{ + { + name: "Well behaved programs get shut down", + cmd: "go run .", + }, + { + name: "Badly behaved programs get shut down", + cmd: "go run . -badly-behaved", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + cmd, err := run.Run(ctx, dir, tt.cmd) + if err != nil { + t.Fatalf("failed to run program: %v", err) + } + + time.Sleep(1 * time.Second) + + pid := cmd.Process.Pid + + if err := run.KillAll(); err != nil { + t.Fatalf("failed to kill all: %v", err) + } + + // Check the parent process is no longer running. + if err := cmd.Process.Signal(os.Signal(syscall.Signal(0))); err == nil { + t.Fatalf("process %d is still running", pid) + } + // Check that the child was stopped. + body, err := readResponse("http://localhost:7777") + if err == nil { + t.Fatalf("child process is still running: %s", body) + } + }) + } +} + +func readResponse(url string) (body string, err error) { + resp, err := http.Get(url) + if err != nil { + return body, err + } + defer resp.Body.Close() + b, err := io.ReadAll(resp.Body) + if err != nil { + return body, err + } + return string(b), nil +} diff --git a/cmd/templ/generatecmd/run/run_unix.go b/cmd/templ/generatecmd/run/run_unix.go index b989768b5..68f8d6724 100644 --- a/cmd/templ/generatecmd/run/run_unix.go +++ b/cmd/templ/generatecmd/run/run_unix.go @@ -4,31 +4,52 @@ package run import ( "context" + "errors" + "fmt" "os" "os/exec" "strings" "sync" "syscall" + "time" ) -var m = &sync.Mutex{} -var running = map[string]*exec.Cmd{} +var ( + m = &sync.Mutex{} + running = map[string]*exec.Cmd{} +) func KillAll() (err error) { m.Lock() defer m.Unlock() + var errs []error for _, cmd := range running { - err := syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL) - if err != nil { - return err + if err := kill(cmd); err != nil { + errs = append(errs, fmt.Errorf("failed to kill process %d: %w", cmd.Process.Pid, err)) } } running = map[string]*exec.Cmd{} - return + return errors.Join(errs...) +} + +func kill(cmd *exec.Cmd) (err error) { + errs := make([]error, 4) + errs[0] = ignoreExited(cmd.Process.Signal(syscall.SIGINT)) + errs[1] = ignoreExited(cmd.Process.Signal(syscall.SIGTERM)) + errs[2] = ignoreExited(cmd.Wait()) + errs[3] = ignoreExited(syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)) + return errors.Join(errs...) } -func Stop(cmd *exec.Cmd) (err error) { - return syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL) +func ignoreExited(err error) error { + if errors.Is(err, syscall.ESRCH) { + return nil + } + // Ignore *exec.ExitError + if _, ok := err.(*exec.ExitError); ok { + return nil + } + return err } func Run(ctx context.Context, workingDir, input string) (cmd *exec.Cmd, err error) { @@ -36,9 +57,10 @@ func Run(ctx context.Context, workingDir, input string) (cmd *exec.Cmd, err erro defer m.Unlock() cmd, ok := running[input] if ok { - if err = syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL); err != nil { - return cmd, err + if err := kill(cmd); err != nil { + return cmd, fmt.Errorf("failed to kill process %d: %w", cmd.Process.Pid, err) } + delete(running, input) } parts := strings.Fields(input) @@ -48,7 +70,9 @@ func Run(ctx context.Context, workingDir, input string) (cmd *exec.Cmd, err erro args = append(args, parts[1:]...) } - cmd = exec.Command(executable, args...) + cmd = exec.CommandContext(ctx, executable, args...) + // Wait for the process to finish gracefully before termination. + cmd.WaitDelay = time.Second * 3 cmd.Env = os.Environ() cmd.Dir = workingDir cmd.Stdout = os.Stdout diff --git a/cmd/templ/generatecmd/run/testprogram/go.mod.embed b/cmd/templ/generatecmd/run/testprogram/go.mod.embed new file mode 100644 index 000000000..719ef94b8 --- /dev/null +++ b/cmd/templ/generatecmd/run/testprogram/go.mod.embed @@ -0,0 +1,3 @@ +module testprogram + +go 1.22.6 diff --git a/cmd/templ/generatecmd/run/testprogram/main.go b/cmd/templ/generatecmd/run/testprogram/main.go new file mode 100644 index 000000000..04ef8810a --- /dev/null +++ b/cmd/templ/generatecmd/run/testprogram/main.go @@ -0,0 +1,63 @@ +package main + +import ( + "flag" + "fmt" + "net/http" + "os" + "os/signal" + "syscall" + "time" +) + +// This is a test program. It is used only to test the behaviour of the run package. +// The run package is supposed to be able to run and stop programs. Those programs may start +// child processes, which should also be stopped when the parent program is stopped. + +// For example, running `go run .` will compile an executable and run it. + +// So, this program does nothing. It just waits for a signal to stop. + +// In "Well behaved" mode, the program will stop when it receives a signal. +// In "Badly behaved" mode, the program will ignore the signal and continue running. + +// The run package should be able to stop the program in both cases. + +var badlyBehavedFlag = flag.Bool("badly-behaved", false, "If set, the program will ignore the stop signal and continue running.") + +func main() { + flag.Parse() + + mode := "Well behaved" + if *badlyBehavedFlag { + mode = "Badly behaved" + } + fmt.Printf("%s process %d started.\n", mode, os.Getpid()) + + // Start a web server on a known port so that we can check that this process is + // not running, when it's been started as a child process, and we don't know + // its pid. + go func() { + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "%d", os.Getpid()) + }) + err := http.ListenAndServe("127.0.0.1:7777", nil) + if err != nil { + fmt.Printf("Error running web server: %v\n", err) + } + }() + + sigs := make(chan os.Signal, 1) + if !*badlyBehavedFlag { + signal.Notify(sigs, os.Interrupt, syscall.SIGTERM) + } + for { + select { + case <-sigs: + fmt.Printf("Process %d received signal. Stopping.\n", os.Getpid()) + return + case <-time.After(1 * time.Second): + fmt.Printf("Process %d still running...\n", os.Getpid()) + } + } +}