Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pinecone support #61

Merged
merged 13 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ jobs:

env:
GITHUB_PAT: ${{ secrets.TEST_GITHUB_PAT }}
PINECONE_API_KEY: ${{ secrets.PINECONE_API_KEY }}
R_KEEP_PKG_SOURCE: yes

steps:
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/test-coverage.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ jobs:
runs-on: ubuntu-latest
env:
GITHUB_PAT: ${{ secrets.TEST_GITHUB_PAT }}
PINECONE_API_KEY: ${{ secrets.PINECONE_API_KEY }}

steps:
- uses: actions/checkout@v4
Expand Down
18 changes: 10 additions & 8 deletions .lintr
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
linters: linters_with_defaults(
line_length_linter = line_length_linter(120L),
object_usage_linter = NULL,
object_length_linter = object_length_linter(45L),
object_name_linter = object_name_linter(
styles = c("snake_case", "CamelCase", "symbols"),
regexes = character()
),
cyclocomp_linter = NULL
trailing_whitespace_linter = NULL,
trailing_blank_lines_linter = NULL,
line_length_linter = NULL,
object_usage_linter = NULL,
object_length_linter = object_length_linter(45L),
object_name_linter = object_name_linter(
styles = c("snake_case", "CamelCase", "symbols"),
regexes = character()
),
cyclocomp_linter = NULL
)
encoding: "UTF-8"
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,4 @@ Imports:
Suggests:
testthat (>= 3.0.0)
Config/testthat/edition: 3
Config/testthat/parallel: true
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# Generated by roxygen2: do not edit by hand

export(add_files)
export(find_records)
export(initialize_project)
export(is_verbose)
export(process_repos)
export(set_database)
export(set_github_repos)
export(set_gitlab_repos)
export(set_llm)
Expand Down
6 changes: 6 additions & 0 deletions R/GitAI.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ GitAI <- R6::R6Class(
private$.llm <- value
},

db = function(value) {
if (missing(value)) return(private$.db)
private$.db <- value
},

system_prompt = function(value) {

if (is.null(private$.llm))
Expand Down Expand Up @@ -47,6 +52,7 @@ GitAI <- R6::R6Class(
private = list(
.project_id = NULL,
.llm = NULL,
.db = NULL,
.gitstats = NULL,
.files = NULL,
.repos_metadata = NULL,
Expand Down
151 changes: 151 additions & 0 deletions R/Pinecone.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
Pinecone <- R6::R6Class(
classname = "Pinecone",
inherit = VectorDatabase,
public = list(

get_index_metadata = function() {

pinecone_api_key <- Sys.getenv("PINECONE_API_KEY")

url <- paste0("https://api.pinecone.io/indexes/", private$.index)

httr2::request(url) |>
httr2::req_headers("Api-Key" = pinecone_api_key) |>
httr2::req_perform() |>
httr2::resp_body_json()
},

write_record = function(id, text, metadata = list()) {

pinecone_api_key <- Sys.getenv("PINECONE_API_KEY")

url <- paste0("https://", private$.index_host)

embeddings <- private$.get_embeddings(text = text)

metadata$text <- text

body <- list(
namespace = private$.namespace,
vectors = list(
id = id,
values = embeddings,
metadata = metadata
)
)

request <- httr2::request(url) |>
httr2::req_url_path_append("vectors/upsert") |>
httr2::req_headers(
"Api-Key" = pinecone_api_key,
"X-Pinecone-API-Version" = "2024-10"
) |>
httr2::req_body_json(body)

response <- request |>
httr2::req_perform()

response_body <- httr2::resp_body_json(response)
response_body
},

find_records = function(query, top_k = 1) {

embeddings <- private$.get_embeddings(query)

pinecone_api_key <- Sys.getenv("PINECONE_API_KEY")

url <- paste0("https://", private$.index_host)

body <- list(
namespace = private$.namespace,
vector = embeddings,
topK = top_k,
includeValues = FALSE,
includeMetadata = TRUE
)

request <- httr2::request(url) |>
httr2::req_url_path_append("query") |>
httr2::req_headers(
"Api-Key" = pinecone_api_key,
"X-Pinecone-API-Version" = "2024-10"
) |>
httr2::req_body_json(body)

response <- request |>
httr2::req_perform()

response_body <- httr2::resp_body_json(response)
results <- response_body$matches

results |>
purrr::map(function(result) {
result$values <- NULL
result
})
}
),

active = list(

namespace = function(value) {
if (missing(value)) return(private$.namespace)
private$.namespace <- value
},

index = function(value) {
if (missing(value)) return(private$.index)
private$.index <- value
}
),

private = list(

.project_id = NULL,
.index = NULL,
.namespace = NULL,
.index_host = NULL,

.initialize = function(index, namespace) {

private$.index <- index
private$.namespace <- namespace
private$.index_host <- self$get_index_metadata()$host
},

.get_embeddings = function(text) {

pinecone_api_key <- Sys.getenv("PINECONE_API_KEY")

url <- "https://api.pinecone.io"

body <- list(
model = "multilingual-e5-large",
parameters = list(
input_type = "passage",
truncate = "END"
),
inputs = list(
list(text = text)
)
)

request <- httr2::request(url) |>
httr2::req_url_path_append("embed") |>
httr2::req_headers(
"Api-Key" = pinecone_api_key,
"X-Pinecone-API-Version" = "2024-10"
) |>
httr2::req_body_json(body)

response <- request |>
httr2::req_perform()

response_body <- httr2::resp_body_json(response)

response_body$data[[1]]$values |> unlist()

}
)
)
27 changes: 27 additions & 0 deletions R/VectorDatabase.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
VectorDatabase <- R6::R6Class(
classname = "VectorDatabase",
public = list(

initialize = function(...) {

private$.initialize(...)
},

write_record = function(id, embeddings, metadata) {
stop(call. = FALSE, "Not implemented yet.")
},

find_records = function(query, top_k = 1) {
stop(call. = FALSE, "Not implemented yet.")
}
),

private = list(

.initialize = function(...) {},

.get_embeddings = function(text) {
stop(call. = FALSE, "Not implemented yet.")
}
)
)
19 changes: 19 additions & 0 deletions R/find_records.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#' Finding top K records in a vector database.
#'
#' @param query A character, user query.
#' @param top_k A numeric, number of top K records to return.
#' @inheritParams process_repos
#'
#' @export
find_records <- function(
gitai,
query,
top_k = 1,
verbose = is_verbose()
) {

gitai$db$find_records(
query = query,
top_k = top_k
)
}
60 changes: 38 additions & 22 deletions R/process_repos.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
#' additional diagnostic messages.
#' @return A list.
#' @export
process_repos <- function(gitai, verbose = is_verbose()) {
process_repos <- function(
gitai,
verbose = is_verbose()
) {

gitstats <- gitai$gitstats

Expand All @@ -23,29 +26,42 @@ process_repos <- function(gitai, verbose = is_verbose()) {
)
files_content <- GitStats::get_files_content(gitstats, verbose = verbose)
repositories <- unique(files_content$repo_name)
process_repo_content <- function(repo_name) {
if (verbose) {
cli::cli_alert_info("Processing repository: {.pkg {repo_name}}")
}

filtered_content <- files_content |>
dplyr::filter(repo_name == !!repo_name)
content_to_process <- filtered_content |>
dplyr::pull(file_content) |>
paste(collapse = "\n\n")

result <- gitai |>
process_content(
results <-
repositories |>
purrr::map(function(repo_name) {
if (verbose) {
cli::cli_alert_info("Processing repository: {.pkg {repo_name}}")
}

filtered_content <- files_content |>
dplyr::filter(repo_name == !!repo_name)

content_to_process <- filtered_content |>
dplyr::pull(file_content) |>
paste(collapse = "\n\n")

if (verbose) {
cli::cli_alert_info("Processing content with LLM...")
}
result <- process_content(
gitai = gitai,
content = content_to_process
) |>
add_metadata(
content = filtered_content
)
}

results <- repositories |>
purrr::map(process_repo_content) |>
add_metadata(content = filtered_content)

if (!is.null(gitai$db)) {
if (verbose) {
cli::cli_alert_info("Writing to database...")
}
gitai$db$write_record(
id = repo_name,
text = result$text,
metadata = result$metadata
)
}
result
}) |>
purrr::set_names(repositories)

results
invisible(results)
}
29 changes: 29 additions & 0 deletions R/set_database.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#' Setting database in `GitAI` object.
#'
#' @inheritParams process_repos
#' @param provider A string. Name of database provider.
#' @param ... Additional arguments to pass to database provider constructor.
#'
#' @export
set_database <- function(
gitai,
provider = "Pinecone",
...
) {

provider_class <- get(provider)

args <- list(...)

if (is.null(args$namespace)) {
args$namespace <- gitai$project_id
}

db <- do.call(
what = provider_class$new,
args = args
)

gitai$db <- db
invisible(gitai)
}
Loading
Loading