Skip to content

Commit

Permalink
feat(android): reconnect lost devices
Browse files Browse the repository at this point in the history
  • Loading branch information
Malinskiy committed Jan 29, 2025
1 parent 39fb827 commit 7953b11
Show file tree
Hide file tree
Showing 10 changed files with 99 additions and 18 deletions.
Original file line number Diff line number Diff line change
@@ -1,20 +1,27 @@
package com.malinskiy.marathon.device

import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.ReceiveChannel

interface DeviceProvider {
sealed class DeviceEvent {
class DeviceConnected(val device: Device) : DeviceEvent()
class DeviceDisconnected(val device: Device) : DeviceEvent()
class DeviceConnected(
val device: Device,
) : DeviceEvent()

class DeviceDisconnected(
val device: Device,
) : DeviceEvent()
}

suspend fun initialize()

/**
* Remote test parsers require a temp device
* This method should be called before reading from the [subscribe()] channel
* Remote test parsers require a temp device This method should be called before reading from
* the [subscribe()] channel
*/
suspend fun borrow() : Device
suspend fun borrow(): Device

suspend fun terminate()
fun subscribe(): Channel<DeviceEvent>

fun subscribe(scheduler: ReceiveChannel<DeviceEvent>): ReceiveChannel<DeviceEvent>
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import com.malinskiy.marathon.config.Configuration
import com.malinskiy.marathon.device.Device
import com.malinskiy.marathon.device.DeviceInfo
import com.malinskiy.marathon.device.DevicePoolId
import com.malinskiy.marathon.device.DeviceProvider
import com.malinskiy.marathon.device.toDeviceInfo
import com.malinskiy.marathon.execution.bundle.TestBundleIdentifier
import com.malinskiy.marathon.execution.device.DeviceActor
Expand All @@ -27,6 +28,7 @@ class DevicePoolActor(
private val poolId: DevicePoolId,
private val configuration: Configuration,
private val poolProgressAccumulator: PoolProgressAccumulator,
private val deviceProviderChannel: SendChannel<DeviceProvider.DeviceEvent>,
analytics: Analytics,
shard: TestShard,
timer: Timer,
Expand All @@ -47,9 +49,14 @@ class DevicePoolActor(
is DevicePoolMessage.FromQueue.Notify -> notifyDevices()
is DevicePoolMessage.FromQueue.Terminated -> onQueueTerminated()
is DevicePoolMessage.FromQueue.ExecuteBatch -> executeBatch(msg.device, msg.batch)
is DevicePoolMessage.FromDevice.DeviceLost -> deviceLost(msg.device)
}
}

private suspend fun deviceLost(device: Device) {
deviceProviderChannel.send(DeviceProvider.DeviceEvent.DeviceDisconnected(device))
}

/**
* Any problem with a device should not propagate a cancellation upstream
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ sealed class DevicePoolMessage {
data class IsReady(override val device: Device) : FromDevice(device)
data class CompletedTestBatch(override val device: Device, val results: TestBatchResults) : FromDevice(device)
data class ReturnTestBatch(override val device: Device, val batch: TestBatch, val reason: String) : FromDevice(device)
data class DeviceLost(override val device: Device) : FromDevice(device)
}

sealed class FromQueue : DevicePoolMessage() {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.malinskiy.marathon.execution

import com.malinskiy.marathon.actor.unboundedChannel
import com.malinskiy.marathon.analytics.external.Analytics
import com.malinskiy.marathon.analytics.internal.pub.Track
import com.malinskiy.marathon.config.Configuration
Expand Down Expand Up @@ -45,6 +46,7 @@ class Scheduler(
private val pools = ConcurrentHashMap<DevicePoolId, DevicePoolActor>()
private val results = ConcurrentHashMap<DevicePoolId, PoolProgressAccumulator>()
private val poolingStrategy = configuration.poolingStrategy.toPoolingStrategy()
private val deviceProviderChannel = unboundedChannel<DeviceProvider.DeviceEvent>()

private val logger = MarathonLogging.logger("Scheduler")

Expand Down Expand Up @@ -72,7 +74,7 @@ class Scheduler(

private fun subscribeOnDevices(job: Job): Job {
return launch {
for (msg in deviceProvider.subscribe()) {
for (msg in deviceProvider.subscribe(deviceProviderChannel)) {
when (msg) {
is DeviceProvider.DeviceEvent.DeviceConnected -> {
onDeviceConnected(msg, job, coroutineContext)
Expand Down Expand Up @@ -117,7 +119,7 @@ class Scheduler(

pools.computeIfAbsent(poolId) { id ->
logger.debug { "pool actor ${id.name} is being created" }
DevicePoolActor(id, configuration, accumulator, analytics, shard, timer, parent, context, testBundleIdentifier)
DevicePoolActor(id, configuration, accumulator, deviceProviderChannel, analytics, shard, timer, parent, context, testBundleIdentifier)
}
pools[poolId]?.send(AddDevice(device)) ?: logger.debug {
"not sending the AddDevice event " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,9 @@ class DeviceActor(
initializeJob?.cancelAndJoin()
executeJob?.cancelAndJoin()
close()
withContext(NonCancellable) {
pool.send(DevicePoolMessage.FromDevice.DeviceLost(device))
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import com.malinskiy.marathon.config.Configuration
import com.malinskiy.marathon.config.vendor.VendorConfiguration
import com.malinskiy.marathon.coroutines.newCoroutineExceptionHandler
import com.malinskiy.marathon.device.DeviceProvider
import com.malinskiy.marathon.device.DeviceProvider.DeviceEvent
import com.malinskiy.marathon.exceptions.NoDevicesException
import com.malinskiy.marathon.log.MarathonLogging
import com.malinskiy.marathon.time.Timer
Expand Down Expand Up @@ -52,6 +53,7 @@ class AdamDeviceProvider(
private val logger = MarathonLogging.logger("AdamDeviceProvider")

private val channel: Channel<DeviceProvider.DeviceEvent> = unboundedChannel()
private var scheduler: ReceiveChannel<DeviceProvider.DeviceEvent>? = null

private val dispatcher = newFixedThreadPoolContext(vendorConfiguration.threadingConfiguration.bootWaitingThreads, "DeviceMonitor")
private val installDispatcher = Dispatchers.IO.limitedParallelism(vendorConfiguration.threadingConfiguration.installThreads)
Expand Down Expand Up @@ -210,7 +212,40 @@ class AdamDeviceProvider(
socketFactory.close()
}

override fun subscribe() = channel
override fun subscribe(scheduler: ReceiveChannel<DeviceEvent>): ReceiveChannel<DeviceEvent> {
this.scheduler = scheduler
launch {
for (event in scheduler) {
processSchedulerEvents(event)
}
}
return channel
}

private suspend fun processSchedulerEvents(event: DeviceEvent) {
when (event) {
is DeviceEvent.DeviceConnected -> Unit
is DeviceEvent.DeviceDisconnected -> {
if (event.device is AdamAndroidDevice) {
val device = event.device as AdamAndroidDevice

val fullFormSerial = device.adbSerial.contains(":")
var host: String
var port: Int?
if (fullFormSerial) {
host = device.adbSerial.substringBefore(":")
port = device.adbSerial.substringAfter(":").toIntOrNull()
} else {
host = device.adbSerial
port = null
}
device.client.execute(DisconnectDeviceRequest(host, port))
} else {
logger.warn { "Received $event from scheduler, expected device to be AdamAndroidDevice, but was ${event.device::class.java.simpleName}" }
}
}
}
}
}

data class ProvidedDevice(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package com.malinskiy.marathon.android.adam

import com.malinskiy.adam.extension.readProtocolString
import com.malinskiy.adam.request.ComplexRequest
import com.malinskiy.adam.request.HostTarget
import com.malinskiy.adam.transport.Socket

class DisconnectDeviceRequest(
private val host: String? = null,
private val port: Int? = 5555
) : ComplexRequest<String>(target = HostTarget) {

override fun serialize() = createBaseRequest(
"disconnect:${
if (host == null) {
""
} else if (port != null) {
"$host:$port"
} else {
"$host"
}
}"
)

override suspend fun readElement(socket: Socket) = socket.readProtocolString()
}
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class AppleSimulatorProvider(
private val simulatorFactory = SimulatorFactory(configuration, vendorConfiguration, testBundleIdentifier, gson, track, timer)
private val deviceTracker = DeviceTracker()

override fun subscribe() = channel
override fun subscribe(scheduler: ReceiveChannel<DeviceProvider.DeviceEvent>): ReceiveChannel<DeviceProvider.DeviceEvent> = channel

private val sourceMutex = Mutex()
private lateinit var sourceChannel: ReceiveChannel<Marathondevices>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import kotlinx.coroutines.NonCancellable
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.coroutines.newFixedThreadPoolContext
Expand Down Expand Up @@ -72,7 +73,7 @@ class AppleMacosProvider(
configuration.outputDir
)

override fun subscribe() = channel
override fun subscribe(scheduler: ReceiveChannel<DeviceProvider.DeviceEvent>): ReceiveChannel<DeviceProvider.DeviceEvent> = channel

override suspend fun initialize() = withContext(coroutineContext) {
logger.debug("Initializing AppleMacosProvider")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,17 @@ package com.malinskiy.marathon.test
import com.malinskiy.marathon.actor.unboundedChannel
import com.malinskiy.marathon.device.Device
import com.malinskiy.marathon.device.DeviceProvider
import kotlin.coroutines.CoroutineContext
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.launch
import kotlin.coroutines.CoroutineContext

class StubDeviceProvider : DeviceProvider, CoroutineScope {
lateinit var context: CoroutineContext
lateinit var borrowingDevice: Device

override val coroutineContext: kotlin.coroutines.CoroutineContext
override val coroutineContext: CoroutineContext
get() = context

private val channel: Channel<DeviceProvider.DeviceEvent> = unboundedChannel()
Expand All @@ -22,11 +23,9 @@ class StubDeviceProvider : DeviceProvider, CoroutineScope {

override suspend fun borrow() = borrowingDevice

override fun subscribe(): Channel<DeviceProvider.DeviceEvent> {
override fun subscribe(scheduler: ReceiveChannel<DeviceProvider.DeviceEvent>): ReceiveChannel<DeviceProvider.DeviceEvent> {
providingLogic?.let {
launch(context = coroutineContext) {
providingLogic?.invoke(channel)
}
launch(context = coroutineContext) { providingLogic?.invoke(channel) }
}

return channel
Expand Down

0 comments on commit 7953b11

Please sign in to comment.