Skip to content

Commit

Permalink
Allow to pass various models to set_llm.
Browse files Browse the repository at this point in the history
  • Loading branch information
krystian8207 committed Dec 4, 2024
1 parent 0ea163e commit 7e479d8
Show file tree
Hide file tree
Showing 17 changed files with 372 additions and 145 deletions.
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ License: MIT + file LICENSE
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.2
Depends:
R (>= 4.1.0)
Imports:
cli (>= 3.4.0),
elmer,
Expand Down
3 changes: 0 additions & 3 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,3 @@ export(set_llm)
export(set_prompt)
export(verbose_off)
export(verbose_on)
importFrom(R6,R6Class)
importFrom(httr2,with_verbosity)
importFrom(lubridate,as_datetime)
8 changes: 4 additions & 4 deletions R/GitAI-package.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#' @importFrom R6 R6Class
#' @importFrom httr2 with_verbosity
#' @importFrom lubridate as_datetime
NULL
#' Derive knowledge from GitHub or GitLab repositories with the use of AI/LLM
#'
#' @name GitAI-package
"_PACKAGE"
18 changes: 12 additions & 6 deletions R/add_metadata.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,26 @@ add_metadata <- function(result, content) {

get_repo_date <- S7::new_generic("get_repo_date", "repo_api_url")

github_repo <- S7::new_class("github_repo",
properties = list(repo = S7::class_character))
github_repo <- S7::new_class(
"github_repo",
properties = list(repo = S7::class_character)
)

gitlab_repo <- S7::new_class("gitlab_repo",
properties = list(repo = S7::class_character))
gitlab_repo <- S7::new_class(
"gitlab_repo",
properties = list(repo = S7::class_character)
)

S7::method(get_repo_date, github_repo) <- function(repo_api_url) {
repo_data <- get_response(repo_api_url@repo)
lubridate::as_datetime(repo_data$updated_at)
}

S7::method(get_repo_date, gitlab_repo) <- function(repo_api_url) {
repo_data <- get_response(endpoint = repo_api_url@repo,
token = Sys.getenv("GITLAB_PAT"))
repo_data <- get_response(
endpoint = repo_api_url@repo,
token = Sys.getenv("GITLAB_PAT")
)
lubridate::as_datetime(repo_data$last_activity_at)
}

Expand Down
71 changes: 46 additions & 25 deletions R/process_repos.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ process_repos <- function(gitai, verbose = is_verbose()) {

gitstats <- gitai$gitstats

gitai$repos_metadata <-
GitStats::get_repos(gitstats,
add_contributors = FALSE,
verbose = verbose)
gitai$repos_metadata <- GitStats::get_repos(
gitstats,

Check warning on line 13 in R/process_repos.R

View workflow job for this annotation

GitHub Actions / lint

file=R/process_repos.R,line=13,col=6,[indentation_linter] Indentation should be 4 spaces but is 6 spaces.
add_contributors = FALSE,
verbose = verbose
)

GitStats::get_files_structure(
gitstats_object = gitstats,
Expand All @@ -22,31 +23,51 @@ process_repos <- function(gitai, verbose = is_verbose()) {
)
files_content <- GitStats::get_files_content(gitstats, verbose = verbose)
repositories <- unique(files_content$repo_name)
results <-
repositories |>
purrr::map(function(repo_name) {
if (verbose) {
cli::cli_alert_info("Processing repository: {.pkg {repo_name}}")
}

filtered_content <-
files_content |>
dplyr::filter(repo_name == !!repo_name)
content_to_process <-
filtered_content |>
dplyr::pull(file_content) |>
paste(collapse = "\n\n")

result <- process_content(
gitai = gitai,
process_repo_content <- function(repo_name) {
if (verbose) {
cli::cli_alert_info("Processing repository: {.pkg {repo_name}}")
}

filtered_content <- files_content |>
dplyr::filter(repo_name == !!repo_name)
content_to_process <- filtered_content |>
dplyr::pull(file_content) |>
paste(collapse = "\n\n")

result <- gitai |>
process_content(
content = content_to_process
) |>
add_metadata(
content = filtered_content
)
add_metadata(
content = filtered_content
)

}) |>
}

results <- repositories |>
purrr::map(process_repo_content) |>
purrr::set_names(repositories)

results
}

process_repo_content <- function(repo_name) {
if (verbose) {
cli::cli_alert_info("Processing repository: {.pkg {repo_name}}")
}

filtered_content <- files_content |>
dplyr::filter(repo_name == !!repo_name)
content_to_process <- filtered_content |>
dplyr::pull(file_content) |>
paste(collapse = "\n\n")

result <- gitai |>
process_content(
content = content_to_process
) |>
add_metadata(
content = filtered_content
)

}
47 changes: 30 additions & 17 deletions R/set_llm.R
Original file line number Diff line number Diff line change
@@ -1,31 +1,44 @@
#' Set Large Language Model in `GitAI` object.
#'
#' @name set_llm
#' @param gitai A \code{GitAI} object.
#' @param provider A LLM provider.
#' @param model A LLM model.
#' @param seed An integer to make results more reproducible.
#' @param ... Other arguments to pass to `elmer::chat_openai()` function.
#' @param provider Name of LLM provider, a string. Results with setting up LLM using
#' \code{elmer::chat_<provider>} function.
#' @param ... Other arguments to pass to corresponding \code{elmer::chat_<provider>} function.
#' Please use \link{get_llm_defaults} to get default model arguments.
#' @return A \code{GitAI} object.
#' @export
set_llm <- function(gitai,
provider = "openai",
model = "gpt-4o-mini",
seed = NULL,
...) {
set_llm <- function(gitai, provider = "openai", ...) {

if (provider == "openai") {
provider_method <- rlang::env_get(
env = asNamespace("elmer"),
nm = glue::glue("chat_{provider}")
)
provider_args <- purrr::list_modify(
get_llm_defaults(provider),
!!!rlang::dots_list(...)
)

gitai$llm <- elmer::chat_openai(
model = model,
echo = "none",
seed = seed,
...
)
}
gitai$llm <- rlang::exec(provider_method, !!!provider_args)

invisible(gitai)
}

llm_default_args <- list(
openai = list(model = "gpt-4o-mini", seed = NULL, echo = "none"),
ollama = list(model = "llama3.2", seed = NULL),
bedrock = list(model = "anthropic.claude-3-5-sonnet-20240620-v1:0")
)

#' @rdname set_llm
get_llm_defaults <- function(provider) {
llm_defaults <- llm_default_args[[provider]]
if (!is.null(llm_defaults)) {
return(llm_defaults)
}
list()
}

#' Set prompt.
#' @name set_prompt
#' @param gitai A \code{GitAI} object.
Expand Down
12 changes: 12 additions & 0 deletions R/set_repos.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,17 @@ set_github_repos <- function(gitai,
GitStats::set_github_host(
host = host,
repos = repos,
token = get_github_pat(),
verbose = verbose
)
invisible(gitai)
}

get_github_pat <- function() {
key_get <- get("key_get", envir = asNamespace("elmer"))
key_get("GITHUB_PAT")
}

#' Set GitLab repositories in `GitAI` object.
#' @name set_gitlab_repos
#' @param gitai A \code{GitAI} object.
Expand All @@ -47,7 +53,13 @@ set_gitlab_repos <- function(gitai,
GitStats::set_gitlab_host(
host = host,
repos = repos,
token = get_gitlab_pat(),
verbose = verbose
)
invisible(gitai)
}

get_gitlab_pat <- function() {
key_get <- get("key_get", envir = asNamespace("elmer"))
key_get("GITLAB_PAT")
}
22 changes: 22 additions & 0 deletions man/GitAI-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 8 additions & 7 deletions man/set_llm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

80 changes: 80 additions & 0 deletions tests/testthat/setup.R
Original file line number Diff line number Diff line change
@@ -1 +1,81 @@
test_mocker <- Mocker$new()

# Override other methods when needed in the future
ChatMocked <- elmer:::Chat
ChatMocked$public_methods$chat <- function(..., echo = NULL) {
if (self$get_system_prompt() == "You always return only 'Hi there!'") {
return("Hi there!")
}
}

# This method allows to skip original checks (e.g. for api or other args structure) and returns
# object of class ChatMocked that we can modify for our testing purposes.
mock_chat_method <- function(turns = NULL,
echo = c("none", "text", "all"),
...,
provider_class) {

provider_args <- rlang::dots_list(...)
provider <- rlang::exec(provider_class, !!!provider_args)

ChatMocked$new(provider = provider, turns = turns, echo = echo)
}

chat_openai_mocked <- function(system_prompt = NULL,
turns = NULL,
base_url = "https://api.mocked.com/v1",
api_key = "mocked_key",
model = NULL,
seed = NULL,
api_args = list(),
echo = c("none", "text", "all")) {

turns <- elmer:::normalize_turns(turns, system_prompt)
model <- elmer:::set_default(model, "gpt-4o")
echo <- elmer:::check_echo(echo)

if (is.null(seed)) {
seed <- 1014
}

mock_chat_method(
turns = turns,
echo = echo,
base_url = base_url,
model = model,
seed = seed,
extra_args = api_args,
api_key = api_key,
provider_class = elmer:::ProviderOpenAI
)
}

chat_bedrock_mocked <- function(system_prompt = NULL,
turns = NULL,
model = NULL,
profile = NULL,
echo = NULL) {

credentials <- list(
access_key_id = "access_key_id_mocked",
secret_access_key = "access_key_id_mocked",
session_token = "session_token_mocked",
access_token = "access_token_mocked",
expiration = as.numeric(Sys.time() + 3600),
region = "eu-central-1"
)

turns <- elmer:::normalize_turns(turns, system_prompt)
model <- elmer:::set_default(model, "model_bedrock")
echo <- elmer:::check_echo(echo)

mock_chat_method(
turns = turns,
echo = echo,
base_url = "",
model = model,
profile = profile,
credentials = credentials,
provider_class = elmer:::ProviderBedrock
)
}
Loading

0 comments on commit 7e479d8

Please sign in to comment.