diff --git a/machine_test.go b/machine_test.go index b8668dc2..cf89ed2b 100644 --- a/machine_test.go +++ b/machine_test.go @@ -23,6 +23,7 @@ import ( "net" "os" "os/exec" + "os/signal" "path/filepath" "strconv" "strings" @@ -1199,3 +1200,106 @@ func createValidConfig(t *testing.T, socketPath string) Config { }, } } + +func TestSignalForwarding(t *testing.T) { + forwardedSignals := []os.Signal{ + syscall.SIGUSR1, + syscall.SIGUSR2, + syscall.SIGINT, + syscall.SIGTERM, + } + ignoredSignals := []os.Signal{ + syscall.SIGHUP, + syscall.SIGQUIT, + } + + cfg := Config{ + Debug: true, + KernelImagePath: filepath.Join(testDataPath, "vmlinux"), + SocketPath: "/tmp/TestSignalForwarding.sock", + Drives: []models.Drive{ + { + DriveID: String("0"), + IsRootDevice: Bool(true), + IsReadOnly: Bool(false), + PathOnHost: String(testRootfs), + }, + }, + DisableValidation: true, + ForwardSignals: forwardedSignals, + } + defer os.RemoveAll("/tmp/TestSignalForwarding.sock") + + opClient := fctesting.MockClient{} + + ctx := context.Background() + client := NewClient(cfg.SocketPath, fctesting.NewLogEntry(t), true, WithOpsClient(&opClient)) + + fd, err := net.Listen("unix", cfg.SocketPath) + if err != nil { + t.Fatalf("unexpected error during creation of unix socket: %v", err) + } + defer fd.Close() + + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + cmd := exec.Command(filepath.Join(testDataPath, "sigprint.sh")) + cmd.Stdout = stdout + cmd.Stderr = stderr + stdin, err := cmd.StdinPipe() + assert.NoError(t, err) + + m, err := NewMachine( + ctx, + cfg, + WithClient(client), + WithProcessRunner(cmd), + WithLogger(fctesting.NewLogEntry(t)), + ) + if err != nil { + t.Fatalf("failed to create new machine: %v", err) + } + + if err := m.startVMM(ctx); err != nil { + t.Fatalf("error startVMM: %v", err) + } + defer m.StopVMM() + + sigChan := make(chan os.Signal) + signal.Notify(sigChan, ignoredSignals...) + defer func() { + signal.Stop(sigChan) + close(sigChan) + }() + + go func() { + for sig := range sigChan { + t.Logf("received signal %v, ignoring", sig) + } + }() + + go func() { + for _, sig := range append(forwardedSignals, ignoredSignals...) { + t.Logf("sending signal %v to self", sig) + syscall.Kill(syscall.Getpid(), sig.(syscall.Signal)) + } + + // give the child process time to receive signals and flush pipes + time.Sleep(1 * time.Second) + + // terminate the signal printing process + stdin.Write([]byte("q")) + }() + + err = m.Wait(ctx) + require.NoError(t, err, "wait returned an error") + + receivedSignals := []os.Signal{} + for _, sigStr := range strings.Split(strings.TrimSpace(stdout.String()), "\n") { + i, err := strconv.Atoi(sigStr) + require.NoError(t, err, "expected numeric output") + receivedSignals = append(receivedSignals, syscall.Signal(i)) + } + + assert.ElementsMatch(t, forwardedSignals, receivedSignals) +} diff --git a/testdata/sigprint.sh b/testdata/sigprint.sh new file mode 100755 index 00000000..4595c7b7 --- /dev/null +++ b/testdata/sigprint.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +typeset -i sig=1 +while (( sig < 65 )); do + trap "echo '$sig'" $sig 2>/dev/null + let sig=sig+1 +done + +>&2 echo "Send signals to PID $$ and type [q] when done." + +while : +do + read -n1 input + [ "$input" == "q" ] && break + sleep .1 +done