Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Force the declaration of role at the internal level #77

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ internal enum class HarmCategory {

typealias Base64 = String

@Serializable internal data class Content(val role: String? = "user", val parts: List<Part>)
@Serializable internal data class Content(val role: String, val parts: List<Part>)

@Serializable(PartSerializer::class) internal sealed interface Part

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ import java.io.ByteArrayOutputStream
private const val BASE_64_FLAGS = Base64.NO_WRAP

internal fun com.google.ai.client.generativeai.type.Content.toInternal() =
Content(this.role, this.parts.map { it.toInternal() })
Content(this.role ?: "user", this.parts.map { it.toInternal() })

internal fun com.google.ai.client.generativeai.type.Part.toInternal(): Part {
return when (this) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import io.ktor.client.engine.mock.MockEngine
import io.ktor.client.engine.mock.respond
import io.ktor.http.HttpHeaders
import io.ktor.http.HttpStatusCode
import io.ktor.http.content.TextContent
import io.ktor.http.headersOf
import io.ktor.utils.io.ByteChannel
import io.ktor.utils.io.close
Expand Down Expand Up @@ -82,13 +83,14 @@ internal class ModelNamingTests(private val modelName: String, private val actua
createGenerativeModel(modelName, "super_cool_test_key", RequestOptions(), mockEngine)

withTimeout(5.seconds) {
model.generateContentStream().collect {
model.generateContentStream("sample content").collect {
it.candidates.isEmpty() shouldBe false
channel.close()
}
}

mockEngine.requestHistory.first().url.encodedPath shouldContain actualName
(mockEngine.requestHistory.first().body as TextContent).text shouldContain "role"
}

companion object {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class ConversionsTest {

@Test
fun `test content conversion toPublic (role not mentioned)`() {
val content = Content(parts = listOf(TextPart("test"))).toPublic()
val content = Content(role = "user", parts = listOf(TextPart("test"))).toPublic()
rlazo marked this conversation as resolved.
Show resolved Hide resolved
content.role shouldBe "user"
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@ internal fun prepareResponse(response: GenerateContentResponse) =
JSON.encodeToString(response).toByteArray()

internal fun createRequest(vararg text: String): GenerateContentRequest {
val contents = text.map { Content(parts = listOf(TextPart(it))) }
val contents = text.map { Content(role = "user", parts = listOf(TextPart(it))) }

return GenerateContentRequest("gemini", contents)
}

internal fun createResponse(text: String) = createResponses(text).single()

internal fun createResponses(vararg text: String): List<GenerateContentResponse> {
val candidates = text.map { Candidate(Content(parts = listOf(TextPart(it)))) }
val candidates = text.map { Candidate(Content(role = "user", parts = listOf(TextPart(it)))) }

return candidates.map { GenerateContentResponse(candidates = listOf(it)) }
}
Expand Down
Loading