Skip to content

Commit

Permalink
add semaphore and illegal state exception to chat
Browse files Browse the repository at this point in the history
  • Loading branch information
David Motsonashvili committed Dec 18, 2023
1 parent 440f960 commit 602cfb9
Showing 1 changed file with 15 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import com.google.ai.client.generativeai.type.InvalidStateException
import com.google.ai.client.generativeai.type.TextPart
import com.google.ai.client.generativeai.type.content
import java.util.LinkedList
import java.util.concurrent.Semaphore
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.flow.onEach
Expand All @@ -39,6 +40,7 @@ import kotlinx.coroutines.flow.onEach
* @property history the previous interactions with the model
*/
class Chat(private val model: GenerativeModel, val history: MutableList<Content> = ArrayList()) {
private var lock = Semaphore(1)

/**
* Generates a response from the backend with the provided [Content], and any previous ones
Expand All @@ -49,9 +51,11 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
*/
suspend fun sendMessage(prompt: Content): GenerateContentResponse {
prompt.assertComesFromUser()
assertNoOngoingCall()

val response = model.generateContent(*history.toTypedArray(), prompt)

lock.release()
history.add(prompt)
history.add(response.candidates.first().content)

Expand Down Expand Up @@ -87,6 +91,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
*/
fun sendMessageStream(prompt: Content): Flow<GenerateContentResponse> {
prompt.assertComesFromUser()
assertNoOngoingCall()

val flow = model.generateContentStream(*history.toTypedArray(), prompt)
val bitmaps = LinkedList<Bitmap>()
Expand All @@ -109,6 +114,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
}
}
.onCompletion {
lock.release()
if (it == null) {
val content =
content("model") {
Expand Down Expand Up @@ -156,4 +162,13 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
throw InvalidStateException("Chat prompts should come from the 'user' role.")
}
}

private fun assertNoOngoingCall() {
if (!lock.tryAcquire()) {
throw IllegalStateException(
"This chat instance currently has an ongoing request, please wait for it to complete " +
"before sending more messages"
)
}
}
}

0 comments on commit 602cfb9

Please sign in to comment.