diff --git a/pkg/netdevice/netResourcePool_test.go b/pkg/netdevice/netResourcePool_test.go index fbbd72fa5..60aa49334 100644 --- a/pkg/netdevice/netResourcePool_test.go +++ b/pkg/netdevice/netResourcePool_test.go @@ -15,6 +15,8 @@ package netdevice_test import ( + "fmt" + nettypes "github.com/k8snetworkplumbingwg/network-attachment-definition-client/pkg/apis/k8s.cni.cncf.io/v1" "github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/factory" "github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/netdevice" "github.com/k8snetworkplumbingwg/sriov-network-device-plugin/pkg/types" @@ -24,6 +26,7 @@ import ( . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" + . "github.com/stretchr/testify/mock" ) var _ = Describe("NetResourcePool", func() { @@ -144,4 +147,55 @@ var _ = Describe("NetResourcePool", func() { }) }) }) + Describe("Saving and Cleaning DevInfo files ", func() { + Context("for valid pci devices", func() { + rc := &types.ResourceConfig{ + ResourceName: "fakeResource", + ResourcePrefix: "fakeOrg.io", + SelectorObj: &types.NetDeviceSelectors{ + IsRdma: true, + }, + } + nadutils := &mocks.NadUtils{} + nadutils.On("SaveDeviceInfoFile", "fakeOrg.io/fakeResource", "fake1", Anything). + Return(func(rName, id string, devInfo *nettypes.DeviceInfo) error { + if devInfo.Type != nettypes.DeviceInfoTypePCI || devInfo.Pci == nil || devInfo.Pci.PciAddress != "0000:01:00.1" { + return fmt.Errorf("wrong device info") + } + return nil + }) + nadutils.On("SaveDeviceInfoFile", "fakeOrg.io/fakeResource", "fake2", Anything). + Return(func(rName, id string, devInfo *nettypes.DeviceInfo) error { + if devInfo.Type != nettypes.DeviceInfoTypePCI || devInfo.Pci == nil || devInfo.Pci.PciAddress != "0000:01:00.2" { + return fmt.Errorf("wrong device info") + } + return nil + }) + nadutils.On("CleanDeviceInfoFile", "fakeOrg.io/fakeResource", "fake1").Return(nil) + nadutils.On("CleanDeviceInfoFile", "fakeOrg.io/fakeResource", "fake2").Return(nil) + + devs := map[string]*v1beta1.Device{} + fake1 := &mocks.PciNetDevice{} + fake1.On("GetPciAddr").Return("0000:01:00.1") + fake2 := &mocks.PciNetDevice{} + fake2.On("GetPciAddr").Return("0000:01:00.2") + pcis := map[string]types.PciDevice{"fake1": fake1, "fake2": fake2} + rp := netdevice.NewNetResourcePool(nadutils, rc, devs, pcis) + + err_store := rp.StoreDeviceInfoFile("fakeOrg.io") + err_clean := rp.CleanDeviceInfoFile("fakeOrg.io") + + It("should call nadutils to create a well formatted DeviceInfo object", func() { + nadutils.AssertCalled(GinkgoT(), "SaveDeviceInfoFile", "fakeOrg.io/fakeResource", "fake1", Anything) + nadutils.AssertCalled(GinkgoT(), "SaveDeviceInfoFile", "fakeOrg.io/fakeResource", "fake2", Anything) + Expect(err_store).ToNot(HaveOccurred()) + }) + It("should call nadutils to clean the DeviceInfo objects", func() { + fmt.Printf("%+v", nadutils.Calls) + nadutils.AssertCalled(GinkgoT(), "CleanDeviceInfoFile", "fakeOrg.io/fakeResource", "fake1") + nadutils.AssertCalled(GinkgoT(), "CleanDeviceInfoFile", "fakeOrg.io/fakeResource", "fake2") + Expect(err_clean).ToNot(HaveOccurred()) + }) + }) + }) }) diff --git a/pkg/resources/server_test.go b/pkg/resources/server_test.go index 3f2140768..60c99b52c 100644 --- a/pkg/resources/server_test.go +++ b/pkg/resources/server_test.go @@ -58,6 +58,8 @@ var _ = Describe("Server", func() { rp := mocks.ResourcePool{} rp.On("Probe").Return(false) rp.On("GetResourceName").Return("fakename") + rp.On("StoreDeviceInfoFile", "fakeprefix").Return(nil) + rp.On("CleanDeviceInfoFile", "fakeprefix").Return(nil) // Use faked dir as socket dir types.SockDir = fs.RootDir @@ -71,9 +73,9 @@ var _ = Describe("Server", func() { if shouldRunServer { if shouldEnablePluginWatch { - rp.On("StoreDeviceInfoFile", "fakeprefix").Return(nil) - rp.On("CleanDeviceInfoFile", "fakeprefix").Return(nil) rs.Start() + rp.AssertCalled(GinkgoT(), "CleanDeviceInfoFile", "fakeprefix") + rp.AssertCalled(GinkgoT(), "StoreDeviceInfoFile", "fakeprefix") } else { os.MkdirAll(pluginapi.DevicePluginPath, 0755) registrationServer.start() @@ -171,6 +173,7 @@ var _ = Describe("Server", func() { rp.On("CleanDeviceInfoFile", "fake").Return(nil) err := rs.Stop() Expect(err).NotTo(HaveOccurred()) + rp.AssertCalled(GinkgoT(), "CleanDeviceInfoFile", "fake") }() Eventually(rs.termSignal, time.Second*10).Should(Receive()) Eventually(rs.stopWatcher, time.Second*10).Should(Receive()) @@ -205,6 +208,7 @@ var _ = Describe("Server", func() { rp.On("CleanDeviceInfoFile", "fake").Return(nil) err := rs.Stop() Expect(err).NotTo(HaveOccurred()) + rp.AssertCalled(GinkgoT(), "CleanDeviceInfoFile", "fake") }() Eventually(rs.termSignal, time.Second*10).Should(Receive())