Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add semaphore and illegal state exception to chat #21

Merged
merged 9 commits into from
Dec 19, 2023
1 change: 1 addition & 0 deletions .changes/calculator-bag-baby-chair.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"type":"MINOR","changes":["An instance of Chat will now throw an InvalidStateException if multiple requests are made simultaneously."]}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import com.google.ai.client.generativeai.type.InvalidStateException
import com.google.ai.client.generativeai.type.TextPart
import com.google.ai.client.generativeai.type.content
import java.util.LinkedList
import java.util.concurrent.Semaphore
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.flow.onEach
Expand All @@ -35,33 +36,41 @@ import kotlinx.coroutines.flow.onEach
* Handles the capturing and storage of the communication with the model, providing methods for
* further interaction.
*
* Note: This object is not thread-safe, and calling [sendMessage] multiple times without waiting
* for a response will throw an [InvalidStateException].
*
* @param model the model to use for the interaction
* @property history the previous interactions with the model
*/
class Chat(private val model: GenerativeModel, val history: MutableList<Content> = ArrayList()) {
private var lock = Semaphore(1)

/**
* Generates a response from the backend with the provided [Content], and any previous ones
davidmotson marked this conversation as resolved.
Show resolved Hide resolved
* sent/returned from this chat.
*
* @param prompt A [Content] to send to the model.
* @throws InvalidStateException if the prompt is not coming from the 'user' role
* @throws InvalidStateException if the [Chat] instance has an active request.
*/
suspend fun sendMessage(prompt: Content): GenerateContentResponse {
prompt.assertComesFromUser()

val response = model.generateContent(*history.toTypedArray(), prompt)

history.add(prompt)
history.add(response.candidates.first().content)

return response
attemptLock()
try {
val response = model.generateContent(*history.toTypedArray(), prompt)
history.add(prompt)
history.add(response.candidates.first().content)
return response
} finally {
lock.release()
}
}

/**
* Generates a response from the backend with the provided text represented [Content].
*
* @param prompt The text to be converted into a single piece of [Content] to send to the model.
* @throws InvalidStateException if the [Chat] instance has an active request.
*/
suspend fun sendMessage(prompt: String): GenerateContentResponse {
val content = content("user") { text(prompt) }
Expand All @@ -72,6 +81,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
* Generates a response from the backend with the provided image represented [Content].
*
* @param prompt The image to be converted into a single piece of [Content] to send to the model.
* @throws InvalidStateException if the [Chat] instance has an active request.
*/
suspend fun sendMessage(prompt: Bitmap): GenerateContentResponse {
val content = content("user") { image(prompt) }
Expand All @@ -84,9 +94,11 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
* @param prompt A [Content] to send to the model.
* @return A [Flow] which will emit responses as they are returned from the model.
* @throws InvalidStateException if the prompt is not coming from the 'user' role
* @throws InvalidStateException if the [Chat] instance has an active request.
*/
davidmotson marked this conversation as resolved.
Show resolved Hide resolved
fun sendMessageStream(prompt: Content): Flow<GenerateContentResponse> {
prompt.assertComesFromUser()
attemptLock()

val flow = model.generateContentStream(*history.toTypedArray(), prompt)
val bitmaps = LinkedList<Bitmap>()
Expand All @@ -109,6 +121,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
}
}
.onCompletion {
lock.release()
if (it == null) {
val content =
content("model") {
Expand All @@ -134,6 +147,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
*
* @param prompt A [Content] to send to the model.
* @return A [Flow] which will emit responses as they are returned from the model.
* @throws InvalidStateException if the [Chat] instance has an active request.
*/
fun sendMessageStream(prompt: String): Flow<GenerateContentResponse> {
val content = content("user") { text(prompt) }
Expand All @@ -145,6 +159,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
*
* @param prompt A [Content] to send to the model.
* @return A [Flow] which will emit responses as they are returned from the model.
* @throws InvalidStateException if the [Chat] instance has an active request.
*/
fun sendMessageStream(prompt: Bitmap): Flow<GenerateContentResponse> {
val content = content("user") { image(prompt) }
Expand All @@ -156,4 +171,13 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
throw InvalidStateException("Chat prompts should come from the 'user' role.")
}
}

private fun attemptLock() {
if (!lock.tryAcquire()) {
throw InvalidStateException(
"This chat instance currently has an ongoing request, please wait for it to complete " +
"before sending more messages"
)
}
}
}