Added swiftdoc documentation
srgtuszy committed Oct 14, 2024
1 parent ebff28e commit 055eeb2
Showing 1 changed file with 104 additions and 32 deletions.
136 changes: 104 additions & 32 deletions Sources/llama-cpp-swift/LLama.swift
@@ -2,15 +2,19 @@ import Foundation
import Logging
@preconcurrency import llama

/// An actor that handles inference using the LLama language model.
public actor LLama {
private let logger = Logger.llama
private let model: Model
private let sampling: UnsafeMutablePointer<llama_sampler>
private var tokensList: [llama_token]
private var temporaryInvalidCChars: [CChar]

  // MARK: - Init & Teardown

/// Initializes a new instance of `LLama` with the specified model.
///
/// - Parameter model: The language model to use for inference.
public init(model: Model) {
self.model = model

@@ -26,52 +30,81 @@ public actor LLama {
self.temporaryInvalidCChars = []
}

deinit {
// llama_sampler_free(sampling)
}

// MARK: - Inference

/// Generates an asynchronous stream of tokens as strings based on the given prompt.
///
/// - Parameters:
/// - prompt: The input text prompt to generate completions for.
/// - maxTokens: The maximum number of tokens to generate. Defaults to 128.
///
/// - Returns: An `AsyncThrowingStream` emitting generated tokens as strings.
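  ///
  /// The stream finishes when generation completes, `maxTokens` is reached, or the surrounding task is cancelled.
  ///
  /// Example (a minimal usage sketch; assumes a `Model` value has already been loaded elsewhere):
  ///
  /// ```swift
  /// // `model` is an existing `Model` instance; its construction is not shown here.
  /// let llama = LLama(model: model)
  /// let stream = await llama.infer(prompt: "Write a haiku about Swift", maxTokens: 64)
  /// for try await token in stream {
  ///   print(token, terminator: "")
  /// }
  /// ```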
  public func infer(prompt: String, maxTokens: Int32 = 128) -> AsyncThrowingStream<String, Error> {
    return AsyncThrowingStream { continuation in
      Task {
        do {
          try await self.infer(prompt: prompt, maxTokens: maxTokens, continuation: continuation)
        } catch {
          continuation.finish(throwing: error)
          return
        }
      }
    }
  }

/// Performs the inference loop and yields generated tokens to the continuation.
///
/// - Parameters:
/// - prompt: The input text prompt to generate completions for.
/// - maxTokens: The maximum number of tokens to generate.
/// - continuation: The stream continuation to yield tokens to.
private func infer(
prompt: String,
maxTokens: Int32,
continuation: AsyncThrowingStream<String, Error>.Continuation
) async throws {
var isDone = false
let nLen: Int32 = 1024
var nCur: Int32 = 0
var nDecode: Int32 = 0
var batch = llama_batch_init(512, 0, 1)
defer {
llama_batch_free(batch)
}

    try self.completionInit(text: prompt, batch: &batch, nLen: nLen, nCur: &nCur)

while !isDone && nCur < nLen && nCur - batch.n_tokens < maxTokens {
guard !Task.isCancelled else {
continuation.finish()
return
}
let newTokenStr = self.completionLoop(
batch: &batch,
isDone: &isDone,
nLen: nLen,
nCur: &nCur,
nDecode: &nDecode
)
continuation.yield(newTokenStr)
}
continuation.finish()
}

  // MARK: - Private Helpers

/// Initializes the completion process by tokenizing the input text and preparing the batch.
///
/// - Parameters:
/// - text: The input text to tokenize.
/// - batch: The batch to initialize.
/// - nLen: The maximum length of the sequence.
/// - nCur: The current position in the sequence.
///
/// - Throws: An `InferError` if the KV cache is too small or decoding fails.
private func completionInit(
text: String,
batch: inout llama_batch,
@@ -109,6 +142,16 @@ public actor LLama {
nCur = batch.n_tokens
}

/// Performs a single iteration of the completion loop, generating the next token.
///
/// - Parameters:
/// - batch: The batch to use for decoding.
/// - isDone: A flag indicating whether the generation is complete.
/// - nLen: The maximum length of the sequence.
/// - nCur: The current position in the sequence.
/// - nDecode: The number of tokens decoded so far.
///
/// - Returns: The newly generated token as a string.
private func completionLoop(
batch: inout llama_batch,
isDone: inout Bool,
@@ -154,6 +197,14 @@ public actor LLama {
return newTokenStr
}

/// Adds a token to the batch.
///
/// - Parameters:
/// - batch: The batch to add the token to.
/// - id: The token ID to add.
/// - pos: The position of the token in the sequence.
/// - seq_ids: The sequence IDs associated with the token.
/// - logits: A flag indicating whether to compute logits for this token.
private func llamaBatchAdd(
_ batch: inout llama_batch,
_ id: llama_token,
Expand All @@ -172,6 +223,13 @@ public actor LLama {
batch.n_tokens += 1
}

/// Tokenizes the given text using the model's tokenizer.
///
/// - Parameters:
/// - text: The text to tokenize.
/// - add_bos: A flag indicating whether to add a beginning-of-sequence token.
///
/// - Returns: An array of token IDs.
private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
let utf8Data = text.utf8CString
let nTokens = Int32(utf8Data.count) + (add_bos ? 1 : 0)
Expand All @@ -187,6 +245,11 @@ public actor LLama {
return Array(UnsafeBufferPointer(start: tokens, count: Int(tokenCount)))
}

/// Converts a token ID to an array of CChars representing the token piece.
///
/// - Parameter token: The token ID to convert.
///
/// - Returns: An array of CChars representing the token piece.
private func tokenToPieceArray(token: llama_token) -> [CChar] {
var buffer = [CChar](repeating: 0, count: 8)
var nTokens = llama_token_to_piece(model.model, token, &buffer, 8, 0, false)
Expand All @@ -200,6 +263,11 @@ public actor LLama {
return Array(buffer.prefix(Int(nTokens)))
}

/// Attempts to create a partial string from an array of CChars if the full string is invalid.
///
/// - Parameter cchars: The array of CChars to attempt to convert.
///
/// - Returns: A valid string if possible; otherwise, `nil`.
private func attemptPartialString(from cchars: [CChar]) -> String? {
for i in (1..<cchars.count).reversed() {
let subArray = Array(cchars.prefix(i))
Expand All @@ -212,12 +280,16 @@ public actor LLama {
}

extension llama_batch {
/// Clears the batch by resetting the token count.
fileprivate mutating func clear() {
n_tokens = 0
}
}

extension String {
/// Initializes a string from a sequence of CChars, validating UTF8 encoding.
///
/// - Parameter validatingUTF8: The array of CChars to initialize the string from.
fileprivate init?(validatingUTF8 cchars: [CChar]) {
if #available(macOS 15.0, *) {
self.init(validating: cchars.map { UInt8(bitPattern: $0) }, as: UTF8.self)
