From 88f490b19d1ed991b5154d49f009dcf925497fdc Mon Sep 17 00:00:00 2001
From: Praveen M <m.praveen@ibm.com>
Date: Thu, 23 May 2024 19:34:28 +0530
Subject: [PATCH] rbd: add volume locks for reclaimspace operations

This commit adds per-volume locks to the reclaimspace operations
to prevent multiple processes from executing rbd sparsify/fstrim
on the same volume at the same time.

Signed-off-by: Praveen M <m.praveen@ibm.com>
---
 internal/csi-addons/rbd/reclaimspace.go      | 26 +++++++++++++++++---
 internal/csi-addons/rbd/reclaimspace_test.go |  6 +++--
 internal/rbd/driver/driver.go                |  4 +--
 3 files changed, 28 insertions(+), 8 deletions(-)

diff --git a/internal/csi-addons/rbd/reclaimspace.go b/internal/csi-addons/rbd/reclaimspace.go
index 97a5ea02523f..7aa4962d73c2 100644
--- a/internal/csi-addons/rbd/reclaimspace.go
+++ b/internal/csi-addons/rbd/reclaimspace.go
@@ -37,12 +37,14 @@ import (
 // of CSI-addons reclaimspace controller service spec.
 type ReclaimSpaceControllerServer struct {
 	*rs.UnimplementedReclaimSpaceControllerServer
+	// Embed ControllerServer to use its VolumeLocks for per-volume locking.
+	*rbdutil.ControllerServer
 }
 
 // NewReclaimSpaceControllerServer creates a new ReclaimSpaceControllerServer which handles
 // the ReclaimSpace Service requests from the CSI-Addons specification.
-func NewReclaimSpaceControllerServer() *ReclaimSpaceControllerServer {
-	return &ReclaimSpaceControllerServer{}
+func NewReclaimSpaceControllerServer(c *rbdutil.ControllerServer) *ReclaimSpaceControllerServer {
+	return &ReclaimSpaceControllerServer{ControllerServer: c}
 }
 
 func (rscs *ReclaimSpaceControllerServer) RegisterService(server grpc.ServiceRegistrar) {
@@ -64,6 +66,13 @@ func (rscs *ReclaimSpaceControllerServer) ControllerReclaimSpace(
 	}
 	defer cr.DeleteCredentials()
 
+	if acquired := rscs.VolumeLocks.TryAcquire(volumeID); !acquired {
+		log.ErrorLog(ctx, util.VolumeOperationAlreadyExistsFmt, volumeID)
+
+		return nil, status.Errorf(codes.Aborted, util.VolumeOperationAlreadyExistsFmt, volumeID)
+	}
+	defer rscs.VolumeLocks.Release(volumeID)
+
 	rbdVol, err := rbdutil.GenVolFromVolID(ctx, volumeID, cr, req.GetSecrets())
 	if err != nil {
 		return nil, status.Errorf(codes.Aborted, "failed to find volume with ID %q: %s", volumeID, err.Error())
@@ -90,12 +99,14 @@ func (rscs *ReclaimSpaceControllerServer) ControllerReclaimSpace(
 // of CSI-addons reclaimspace controller service spec.
 type ReclaimSpaceNodeServer struct {
 	*rs.UnimplementedReclaimSpaceNodeServer
+	// Embed ControllerServer to use its VolumeLocks for per-volume locking.
+	*rbdutil.ControllerServer
 }
 
 // NewReclaimSpaceNodeServer creates a new IdentityServer which handles the
 // Identity Service requests from the CSI-Addons specification.
-func NewReclaimSpaceNodeServer() *ReclaimSpaceNodeServer {
-	return &ReclaimSpaceNodeServer{}
+func NewReclaimSpaceNodeServer(c *rbdutil.ControllerServer) *ReclaimSpaceNodeServer {
+	return &ReclaimSpaceNodeServer{ControllerServer: c}
 }
 
 func (rsns *ReclaimSpaceNodeServer) RegisterService(server grpc.ServiceRegistrar) {
@@ -116,6 +127,13 @@ func (rsns *ReclaimSpaceNodeServer) NodeReclaimSpace(
 		return nil, status.Error(codes.InvalidArgument, "empty volume ID in request")
 	}
 
+	if acquired := rsns.VolumeLocks.TryAcquire(volumeID); !acquired {
+		log.ErrorLog(ctx, util.VolumeOperationAlreadyExistsFmt, volumeID)
+
+		return nil, status.Errorf(codes.Aborted, util.VolumeOperationAlreadyExistsFmt, volumeID)
+	}
+	defer rsns.VolumeLocks.Release(volumeID)
+
 	// path can either be the staging path on the node, or the volume path
 	// inside an application container
 	path := req.GetStagingTargetPath()
diff --git a/internal/csi-addons/rbd/reclaimspace_test.go b/internal/csi-addons/rbd/reclaimspace_test.go
index 6effbc0fa0f4..bfc35038e6bb 100644
--- a/internal/csi-addons/rbd/reclaimspace_test.go
+++ b/internal/csi-addons/rbd/reclaimspace_test.go
@@ -22,6 +22,8 @@ import (
 
 	rs "github.com/csi-addons/spec/lib/go/reclaimspace"
 	"github.com/stretchr/testify/require"
+
+	rbdutil "github.com/ceph/ceph-csi/internal/rbd"
 )
 
 // TestControllerReclaimSpace is a minimal test for the
@@ -30,7 +32,7 @@ import (
 func TestControllerReclaimSpace(t *testing.T) {
 	t.Parallel()
 
-	controller := NewReclaimSpaceControllerServer()
+	controller := NewReclaimSpaceControllerServer(&rbdutil.ControllerServer{})
 
 	req := &rs.ControllerReclaimSpaceRequest{
 		VolumeId: "",
@@ -47,7 +49,7 @@ func TestControllerReclaimSpace(t *testing.T) {
 func TestNodeReclaimSpace(t *testing.T) {
 	t.Parallel()
 
-	node := NewReclaimSpaceNodeServer()
+	node := NewReclaimSpaceNodeServer(&rbdutil.ControllerServer{})
 
 	req := &rs.NodeReclaimSpaceRequest{
 		VolumeId:         "",
diff --git a/internal/rbd/driver/driver.go b/internal/rbd/driver/driver.go
index c4c736c3447d..087c70a7e26c 100644
--- a/internal/rbd/driver/driver.go
+++ b/internal/rbd/driver/driver.go
@@ -213,7 +213,7 @@ func (r *Driver) setupCSIAddonsServer(conf *util.Config) error {
 	r.cas.RegisterService(is)
 
 	if conf.IsControllerServer {
-		rs := casrbd.NewReclaimSpaceControllerServer()
+		rs := casrbd.NewReclaimSpaceControllerServer(NewControllerServer(r.cd))
 		r.cas.RegisterService(rs)
 
 		fcs := casrbd.NewFenceControllerServer()
@@ -224,7 +224,7 @@ func (r *Driver) setupCSIAddonsServer(conf *util.Config) error {
 	}
 
 	if conf.IsNodeServer {
-		rs := casrbd.NewReclaimSpaceNodeServer()
+		rs := casrbd.NewReclaimSpaceNodeServer(NewControllerServer(r.cd))
 		r.cas.RegisterService(rs)
 	}