From c34fe5b0f404a29ea2dc8cd0b2775073613f2740 Mon Sep 17 00:00:00 2001
From: Gaston Thea <gaston.thea@harness.io>
Date: Fri, 25 Oct 2024 14:23:23 -0300
Subject: [PATCH] Fix

---
 .../MySegments/MyLargeSegmentsStorage.swift   | 46 ++++++++++++++-----
 .../streaming/MySegmentUpdateTest.swift       |  9 ++--
 2 files changed, 37 insertions(+), 18 deletions(-)

diff --git a/Split/Storage/MySegments/MyLargeSegmentsStorage.swift b/Split/Storage/MySegments/MyLargeSegmentsStorage.swift
index e421cafd..00c3e8e3 100644
--- a/Split/Storage/MySegments/MyLargeSegmentsStorage.swift
+++ b/Split/Storage/MySegments/MyLargeSegmentsStorage.swift
@@ -13,6 +13,8 @@ class MyLargeSegmentsStorage: MySegmentsStorage {
     private var inMemorySegments: SynchronizedDictionary<String, SegmentChange> = SynchronizedDictionary()
     private let persistentStorage: PersistentMySegmentsStorage
     private let defaultChangeNumber = ServiceConstants.defaultSegmentsChangeNumber
+    private let syncQueue: DispatchQueue
+    private let syncQueueKey = DispatchSpecificKey<Void>()
 
     var keys: Set<String> {
         return inMemorySegments.keys
@@ -20,11 +22,15 @@ class MyLargeSegmentsStorage: MySegmentsStorage {
 
     init(persistentStorage: PersistentMySegmentsStorage) {
         self.persistentStorage = persistentStorage
+        self.syncQueue = DispatchQueue(label: "split-large-segments-storage")
+        syncQueue.setSpecific(key: syncQueueKey, value: ())
     }
 
     func loadLocal(forKey key: String) {
-        let change = persistentStorage.getSnapshot(forKey: key) ?? SegmentChange.empty()
-        inMemorySegments.setValue(change, forKey: key)
+        safeSync {
+            let change = persistentStorage.getSnapshot(forKey: key) ?? SegmentChange.empty()
+            inMemorySegments.setValue(change, forKey: key)
+        }
     }
 
     func changeNumber(forKey key: String) -> Int64? {
@@ -34,19 +40,24 @@ class MyLargeSegmentsStorage: MySegmentsStorage {
     func lowerChangeNumber() -> Int64 {
         return inMemorySegments.all.values.compactMap { $0.changeNumber }.min() ?? -1
     }
+
     func getAll(forKey key: String) -> Set<String> {
         return inMemorySegments.value(forKey: key)?.segments.compactMap { $0.name }.asSet() ?? Set<String>()
     }
 
     func set(_ change: SegmentChange, forKey key: String) {
-        inMemorySegments.setValue(change, forKey: key)
-        persistentStorage.set(change, forKey: key)
+        safeSync {
+            inMemorySegments.setValue(change, forKey: key)
+            persistentStorage.set(change, forKey: key)
+        }
     }
 
     func clear(forKey key: String) {
-        let clearChange = SegmentChange(segments: [])
-        inMemorySegments.setValue(clearChange, forKey: key)
-        persistentStorage.set(clearChange, forKey: key)
+        safeSync {
+            let clearChange = SegmentChange(segments: [])
+            inMemorySegments.setValue(clearChange, forKey: key)
+            persistentStorage.set(clearChange, forKey: key)
+        }
     }
 
     func destroy() {
@@ -58,11 +69,22 @@ class MyLargeSegmentsStorage: MySegmentsStorage {
     }
 
     func getCount() -> Int {
-        let keys = inMemorySegments.keys
-        var count = 0
-        for key in keys {
-            count+=(inMemorySegments.value(forKey: key)?.segments.count ?? 0)
+        safeSync {
+            let keys = inMemorySegments.keys
+            var count = 0
+            for key in keys {
+                count+=(inMemorySegments.value(forKey: key)?.segments.count ?? 0)
+            }
+            return count
+        }
+    }
+
+    // if already being executed in the queue, do not dispatch to it
+    private func safeSync<T>(_ block: () -> T) -> T {
+        if DispatchQueue.getSpecific(key: syncQueueKey) != nil {
+            return block()
+        } else {
+            return syncQueue.sync(execute: block)
         }
-        return count
     }
 }
diff --git a/SplitTests/Integration/streaming/MySegmentUpdateTest.swift b/SplitTests/Integration/streaming/MySegmentUpdateTest.swift
index 160b07e5..3b1d501a 100644
--- a/SplitTests/Integration/streaming/MySegmentUpdateTest.swift
+++ b/SplitTests/Integration/streaming/MySegmentUpdateTest.swift
@@ -76,24 +76,21 @@ class MySegmentUpdateTest: XCTestCase {
         syncSpy.forceMySegmentsCalledCount = 0
         sdkUpdExp = XCTestExpectation()
         pushMessage(TestingData.unboundedNotification(type: type, cn: mySegmentsCns[cnIndex()]))
-        wait(for: [sdkUpdExp], timeout: 50)
-        Thread.sleep(forTimeInterval: 1.0)
+        wait(for: [sdkUpdExp], timeout: 5.0)
 
         // Should not trigger any fetch to my segments because
         // this payload doesn't have "key1" enabled
 
         pushMessage(TestingData.escapedBoundedNotificationZlib(type: type, cn: mySegmentsCns[cnIndex()]))
 
-        Thread.sleep(forTimeInterval: 1.0)
         // Pushed key list message. Key 1 should add a segment
         sdkUpdExp = XCTestExpectation()
         pushMessage(TestingData.escapedKeyListNotificationGzip(type: type, cn: mySegmentsCns[cnIndex()]))
-        wait(for: [sdkUpdExp], timeout: 50)
-        Thread.sleep(forTimeInterval: 1.0)
+        wait(for: [sdkUpdExp], timeout: 5.0)
 
         sdkUpdExp = XCTestExpectation()
         pushMessage(TestingData.segmentRemovalNotification(type: type, cn: mySegmentsCns[cnIndex()]))
-        wait(for: [sdkUpdExp], timeout: 50)
+        wait(for: [sdkUpdExp], timeout: 5.0)
 
         var segmentEntity: [String]!
         if type == .mySegmentsUpdate {