Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add semaphore and illegal state exception to chat #21

Merged
merged 9 commits into from
Dec 19, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import com.google.ai.client.generativeai.type.InvalidStateException
import com.google.ai.client.generativeai.type.TextPart
import com.google.ai.client.generativeai.type.content
import java.util.LinkedList
import java.util.concurrent.Semaphore
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.flow.onEach
Expand All @@ -39,6 +40,7 @@ import kotlinx.coroutines.flow.onEach
* @property history the previous interactions with the model
*/
class Chat(private val model: GenerativeModel, val history: MutableList<Content> = ArrayList()) {
private var lock = Semaphore(1)

/**
* Generates a response from the backend with the provided [Content], and any previous ones
davidmotson marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -49,13 +51,15 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
*/
suspend fun sendMessage(prompt: Content): GenerateContentResponse {
prompt.assertComesFromUser()

val response = model.generateContent(*history.toTypedArray(), prompt)

history.add(prompt)
history.add(response.candidates.first().content)

return response
attemptLock()
try {
val response = model.generateContent(*history.toTypedArray(), prompt)
history.add(prompt)
history.add(response.candidates.first().content)
return response
} finally {
lock.release()
}
}

/**
Expand Down Expand Up @@ -87,6 +91,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
*/
davidmotson marked this conversation as resolved.
Show resolved Hide resolved
fun sendMessageStream(prompt: Content): Flow<GenerateContentResponse> {
prompt.assertComesFromUser()
attemptLock()

val flow = model.generateContentStream(*history.toTypedArray(), prompt)
val bitmaps = LinkedList<Bitmap>()
Expand All @@ -109,6 +114,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
}
}
.onCompletion {
lock.release()
if (it == null) {
val content =
content("model") {
Expand Down Expand Up @@ -156,4 +162,13 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
throw InvalidStateException("Chat prompts should come from the 'user' role.")
}
}

private fun attemptLock() {
if (!lock.tryAcquire()) {
throw IllegalStateException(
davidmotson marked this conversation as resolved.
Show resolved Hide resolved
"This chat instance currently has an ongoing request, please wait for it to complete " +
"before sending more messages"
)
}
}
}