diff --git a/pkg/podmounter/mountoptions/mount_options.go b/pkg/podmounter/mountoptions/mount_options.go index aa555872..73575e34 100644 --- a/pkg/podmounter/mountoptions/mount_options.go +++ b/pkg/podmounter/mountoptions/mount_options.go @@ -32,7 +32,8 @@ func Send(ctx context.Context, sockPath string, options Options) error { return fmt.Errorf("failed to marshal message to send %s: %w", sockPath, err) } - conn, err := net.Dial("unix", sockPath) + var d net.Dialer + conn, err := d.DialContext(ctx, "unix", sockPath) if err != nil { return fmt.Errorf("failed to dial to unix socket %s: %w", sockPath, err) } @@ -40,6 +41,14 @@ func Send(ctx context.Context, sockPath string, options Options) error { unixConn := conn.(*net.UnixConn) + // `unixConn.WriteMsgUnix` does not respect `ctx`'s deadline, we need to call `unixConn.SetDeadline` to ensure `unixConn.WriteMsgUnix` has a deadline. + if deadline, ok := ctx.Deadline(); ok { + err := unixConn.SetDeadline(deadline) + if err != nil { + return fmt.Errorf("failed to set deadline on unix socket %s: %w", sockPath, err) + } + } + unixRights := syscall.UnixRights(options.Fd) messageN, unixRightsN, err := unixConn.WriteMsgUnix(message, unixRights, nil) if err != nil { @@ -55,7 +64,7 @@ func Send(ctx context.Context, sockPath string, options Options) error { var ( messageRecvSize = 1024 - // We only pass one file descriptor and its 32 bits + // We only pass one file descriptor and it's 32 bits unixRightsRecvSize = syscall.CmsgSpace(4) ) @@ -63,12 +72,22 @@ var ( func Recv(ctx context.Context, sockPath string) (Options, error) { warnAboutLongUnixSocketPath(sockPath) - l, err := net.Listen("unix", sockPath) + var lc net.ListenConfig + l, err := lc.Listen(ctx, "unix", sockPath) if err != nil { return Options{}, fmt.Errorf("failed to listen unix socket %s: %w", sockPath, err) } defer l.Close() + // `l.Accept` does not respect `ctx`'s deadline, we need to call `ul.SetDeadline` to ensure `l.Accept` has a deadline. + if deadline, ok := ctx.Deadline(); ok { + ul := l.(*net.UnixListener) + err := ul.SetDeadline(deadline) + if err != nil { + return Options{}, fmt.Errorf("failed to set deadline on unix socket %s: %w", sockPath, err) + } + } + conn, err := l.Accept() if err != nil { return Options{}, fmt.Errorf("failed to accept connection from unix socket %s: %w", sockPath, err)