From 31e718345708c3429186bcbf9ac1ccb11d211b00 Mon Sep 17 00:00:00 2001 From: Giuseppe Villani Date: Thu, 3 Aug 2023 17:54:48 +0200 Subject: [PATCH] [NOID] Add OpenAI (LLM) procedures in 4.4 (#3701) * [NOID] Add OpenAI (LLM) procedures (#3575) (#3582) * Add OpenAI (LLM) procedures (#3575) * WIP * Add completion API * add chatCompletion * prettificiation * WIP * prettify * Refactoring, todo docs & tests * Added Tests & Docs for OpenAI procs (WIP) * Update openai.adoc From @tomasonjo --------- Co-authored-by: Tomaz Bratanic * Removed unused deps after OpenAI procedures addition (#3585) * Updated extended.txt to fix checkCoreWithExtraDependenciesJars failure after Open AI procedures --------- Co-authored-by: Michael Hunger Co-authored-by: Tomaz Bratanic * [NOID] First LLM prompt for cypher, query and schema in APOC (#3649) * First LLM prompt for cypher, query and schema in APOC * Added integration tests * enable open.ai key management globally * Added tests * added docs * added configuration map to procs --------- Co-authored-by: Andrea Santurbano --------- Co-authored-by: Michael Hunger Co-authored-by: Tomaz Bratanic Co-authored-by: Andrea Santurbano --- core/src/main/java/apoc/ApocConfig.java | 1 + docs/asciidoc/modules/ROOT/nav.adoc | 3 + .../asciidoc/modules/ROOT/pages/ml/index.adoc | 10 + .../modules/ROOT/pages/ml/openai.adoc | 333 ++++++++++++++++++ full/src/main/java/apoc/ml/OpenAI.java | 121 +++++++ full/src/main/java/apoc/ml/Prompt.java | 199 +++++++++++ full/src/main/resources/extended.txt | 6 + full/src/test/java/apoc/ml/OpenAIIT.java | 115 ++++++ full/src/test/java/apoc/ml/OpenAITest.java | 112 ++++++ full/src/test/java/apoc/ml/PromptIT.java | 101 ++++++ full/src/test/resources/chat/completions | 21 ++ full/src/test/resources/completions | 19 + full/src/test/resources/embeddings | 15 + 13 files changed, 1056 insertions(+) create mode 100644 docs/asciidoc/modules/ROOT/pages/ml/index.adoc create mode 100644 docs/asciidoc/modules/ROOT/pages/ml/openai.adoc create mode 100644 full/src/main/java/apoc/ml/OpenAI.java create mode 100644 full/src/main/java/apoc/ml/Prompt.java create mode 100644 full/src/test/java/apoc/ml/OpenAIIT.java create mode 100644 full/src/test/java/apoc/ml/OpenAITest.java create mode 100644 full/src/test/java/apoc/ml/PromptIT.java create mode 100644 full/src/test/resources/chat/completions create mode 100644 full/src/test/resources/completions create mode 100644 full/src/test/resources/embeddings diff --git a/core/src/main/java/apoc/ApocConfig.java b/core/src/main/java/apoc/ApocConfig.java index ad9ad95da4..a3aa4aa0fd 100644 --- a/core/src/main/java/apoc/ApocConfig.java +++ b/core/src/main/java/apoc/ApocConfig.java @@ -87,6 +87,7 @@ public class ApocConfig extends LifecycleAdapter { public static final String APOC_UUID_ENABLED = "apoc.uuid.enabled"; public static final String APOC_UUID_ENABLED_DB = "apoc.uuid.enabled.%s"; public static final String APOC_UUID_FORMAT = "apoc.uuid.format"; + public static final String APOC_OPENAI_KEY = "apoc.openai.key"; public enum UuidFormatType { hex, base64 } public static final String APOC_JSON_ZIP_URL = "apoc.json.zip.url"; // TODO: check if really needed public static final String APOC_JSON_SIMPLE_JSON_URL = "apoc.json.simpleJson.url"; // TODO: check if really needed diff --git a/docs/asciidoc/modules/ROOT/nav.adoc b/docs/asciidoc/modules/ROOT/nav.adoc index 82fe6ab9e7..cf1c59d22a 100644 --- a/docs/asciidoc/modules/ROOT/nav.adoc +++ b/docs/asciidoc/modules/ROOT/nav.adoc @@ -130,6 +130,9 @@ 
include::partial$generated-documentation/nav.adoc[] ** xref:nlp/aws.adoc[] ** xref:nlp/azure.adoc[] +* xref:ml/index.adoc[] + ** xref:ml/openai.adoc[] + * xref:background-operations/index.adoc[] ** xref::background-operations/periodic-background.adoc[] ** xref::background-operations/triggers.adoc[] diff --git a/docs/asciidoc/modules/ROOT/pages/ml/index.adoc b/docs/asciidoc/modules/ROOT/pages/ml/index.adoc new file mode 100644 index 0000000000..73dd5e3169 --- /dev/null +++ b/docs/asciidoc/modules/ROOT/pages/ml/index.adoc @@ -0,0 +1,10 @@ +[[ml]] += Machine Learning (ML +:description: This chapter describes procedures that can be used for adding Machine Learning (ML) functionality to graph applications. + +The procedures described in this chapter act as wrappers around cloud based Machine Learning APIs. +These procedures generate embeddings, analyze text, complete text, complete chat conversations and more. + +This section includes: + +* xref::ml/openai.adoc[] diff --git a/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc b/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc new file mode 100644 index 0000000000..0bb6a198d7 --- /dev/null +++ b/docs/asciidoc/modules/ROOT/pages/ml/openai.adoc @@ -0,0 +1,333 @@ +[[openai-api]] += OpenAI API Access +:description: This section describes procedures that can be used to access the OpenAI API. + +NOTE: You need to acquire an https://platform.openai.com/account/api-keys[OpenAI API key^] to use these procedures. Using them will incur costs on your OpenAI account. You can set the api key globally by defining the `apoc.openai.key` configuration in `apoc.conf` + +== Generate Embeddings API + +This procedure `apoc.ml.openai.embedding` can take a list of text strings, and will return one row per string, with the embedding data as a 1536 element vector. +It uses the `/embeddings/create` API which is https://platform.openai.com/docs/api-reference/embeddings/create[documented here^]. + +Additional configuration is passed to the API, the default model used is `text-embedding-ada-002`. + +.Generate Embeddings Call +[source,cypher] +---- +CALL apoc.ml.openai.embedding(['Some Text'], $apiKey, {}) yield index, text, embedding; +---- + +.Generate Embeddings Response +[%autowidth, opts=header] +|=== +|index | text | embedding +|0 | "Some Text" | [-0.0065358975, -7.9563365E-4, .... -0.010693862, -0.005087272] +|=== + +.Parameters +[%autowidth, opts=header] +|=== +|name | description +| texts | List of text strings +| apiKey | OpenAI API key +| configuration | optional map for entries like model and other request parameters +|=== + + +.Results +[%autowidth, opts=header] +|=== +|name | description +| index | index entry in original list +| text | line of text from original list +| embedding | 1536 element floating point embedding vector for ada-002 model +|=== + +== Text Completion API + +This procedure `apoc.ml.openai.completion` can continue/complete a given text. + +It uses the `/completions/create` API which is https://platform.openai.com/docs/api-reference/completions/create[documented here^]. + +Additional configuration is passed to the API, the default model used is `text-davinci-003`. + +.Text Completion Call +[source,cypher] +---- +CALL apoc.ml.openai.completion('What color is the sky? 
Answer in one word: ', $apiKey, {config}) yield value; +---- + +.Text Completion Response +---- +{ created=1684248202, model="text-davinci-003", id="cmpl-7GqBWwX49yMJljdmnLkWxYettZoOy", + usage={completion_tokens=2, prompt_tokens=12, total_tokens=14}, + choices=[{finish_reason="stop", index=0, text="Blue", logprobs=null}], object="text_completion"} +---- + +.Parameters +[%autowidth, opts=header] +|=== +|name | description +| prompt | Text to complete +| apiKey | OpenAI API key +| configuration | optional map for entries like model, temperature, and other request parameters +|=== + +.Results +[%autowidth, opts=header] +|=== +|name | description +| value | result entry from OpenAI (containing) +|=== + +== Chat Completion API + +This procedure `apoc.ml.openai.chat` takes a list of maps of chat exchanges between assistant and user (with optional system message), and will return the next message in the flow. + +It uses the `/chat/create` API which is https://platform.openai.com/docs/api-reference/chat/create[documented here^]. + +Additional configuration is passed to the API, the default model used is `gpt-3.5-turbo`. + +.Chat Completion Call +[source,cypher] +---- +CALL apoc.ml.openai.chat([ +{role:"system", content:"Only answer with a single word"}, +{role:"user", content:"What planet do humans live on?"} +], $apiKey) yield value +---- + +.Chat Completion Response +---- +{created=1684248203, id="chatcmpl-7GqBXZr94avd4fluYDi2fWEz7DIHL", +object="chat.completion", model="gpt-3.5-turbo-0301", +usage={completion_tokens=2, prompt_tokens=26, total_tokens=28}, +choices=[{finish_reason="stop", index=0, message={role="assistant", content="Earth."}}]} +---- + +.Parameters +[%autowidth, opts=header] +|=== +|name | description +| messages | List of maps of instructions with `{role:"assistant|user|system", content:"text}` +| apiKey | OpenAI API key +| configuration | optional map for entries like model, temperature, and other request parameters +|=== + +.Results +[%autowidth, opts=header] +|=== +|name | description +| value | result entry from OpenAI (containing created, id, model, object, usage(tokens), choices(message, index, finish_reason)) +|=== + + +== Query with natural language + +This procedure `apoc.ml.query` takes a question in natural language and returns the results of that query. + +It uses the `chat/completions` API which is https://platform.openai.com/docs/api-reference/chat/create[documented here^]. 
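+
+If `apoc.openai.key` is not set in `apoc.conf`, the API key can be passed through the configuration map instead. A minimal sketch using the `apiKey` and `retries` entries described below (the `$apiKey` parameter is assumed to hold your key):
+
+.Query call with configuration map
+[source,cypher]
+----
+CALL apoc.ml.query("What movies did Tom Hanks play in?", {apiKey: $apiKey, retries: 2}) yield value, query
+RETURN *
+----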
+ +.Query call +[source,cypher] +---- +CALL apoc.ml.query("What movies did Tom Hanks play in?") yield value, query +RETURN * +---- + +.Example response +[source, bash] +---- ++------------------------------------------------------------------------------------------------------------------------------+ +| value | query | ++------------------------------------------------------------------------------------------------------------------------------+ +| {m.title -> "You've Got Mail"} | "cypher +MATCH (m:Movie)<-[:ACTED_IN]-(p:Person {name: 'Tom Hanks'}) +RETURN m.title +" | +| {m.title -> "Apollo 13"} | "cypher +MATCH (m:Movie)<-[:ACTED_IN]-(p:Person {name: 'Tom Hanks'}) +RETURN m.title +" | +| {m.title -> "Joe Versus the Volcano"} | "cypher +MATCH (m:Movie)<-[:ACTED_IN]-(p:Person {name: 'Tom Hanks'}) +RETURN m.title +" | +| {m.title -> "That Thing You Do"} | "cypher +MATCH (m:Movie)<-[:ACTED_IN]-(p:Person {name: 'Tom Hanks'}) +RETURN m.title +" | +| {m.title -> "Cloud Atlas"} | "cypher +MATCH (m:Movie)<-[:ACTED_IN]-(p:Person {name: 'Tom Hanks'}) +RETURN m.title +" | +| {m.title -> "The Da Vinci Code"} | "cypher +MATCH (m:Movie)<-[:ACTED_IN]-(p:Person {name: 'Tom Hanks'}) +RETURN m.title +" | +| {m.title -> "Sleepless in Seattle"} | "cypher +MATCH (m:Movie)<-[:ACTED_IN]-(p:Person {name: 'Tom Hanks'}) +RETURN m.title +" | +| {m.title -> "A League of Their Own"} | "cypher +MATCH (m:Movie)<-[:ACTED_IN]-(p:Person {name: 'Tom Hanks'}) +RETURN m.title +" | +| {m.title -> "The Green Mile"} | "cypher +MATCH (m:Movie)<-[:ACTED_IN]-(p:Person {name: 'Tom Hanks'}) +RETURN m.title +" | +| {m.title -> "Charlie Wilson's War"} | "cypher +MATCH (m:Movie)<-[:ACTED_IN]-(p:Person {name: 'Tom Hanks'}) +RETURN m.title +" | +| {m.title -> "Cast Away"} | "cypher +MATCH (m:Movie)<-[:ACTED_IN]-(p:Person {name: 'Tom Hanks'}) +RETURN m.title +" | +| {m.title -> "The Polar Express"} | "cypher +MATCH (m:Movie)<-[:ACTED_IN]-(p:Person {name: 'Tom Hanks'}) +RETURN m.title +" | ++------------------------------------------------------------------------------------------------------------------------------+ +12 rows +---- + +.Input Parameters +[%autowidth, opts=header] +|=== +| name | description +| question | The question in the natural language +| conf | An optional configuration map, please check the next section +|=== + +.Configuration map +[%autowidth, opts=header] +|=== +| name | description | mandatory +| retries | The number of retries in case of API call failures | no, default `3` +| apiKey | OpenAI API key | in case `apoc.openai.key` is not defined +| model | The Open AI model | no, default `gpt-3.5-turbo` +| sample | The number of nodes to skip, e.g. a sample of 1000 will read every 1000th node. It's used as a parameter to `apoc.meta.data` procedure that computes the schema | no, default is a random number +|=== + +.Results +[%autowidth, opts=header] +|=== +| name | description +| value | the result of the query +| cypher | the query used to compute the result +|=== + + +== Describe the graph model with natural language + +This procedure `apoc.ml.schema` returns a description, in natural language, of the underlying dataset. + +It uses the `chat/completions` API which is https://platform.openai.com/docs/api-reference/chat/create[documented here^]. 
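+
+The configuration map described below is optional; when no global key is configured, the `apiKey` entry (and, for example, a `sample` size forwarded to `apoc.meta.data`) can be passed explicitly. A minimal sketch with illustrative values (`$apiKey` is assumed to hold your key):
+
+.Query call with configuration map
+[source,cypher]
+----
+CALL apoc.ml.schema({apiKey: $apiKey, sample: 1000}) yield value
+RETURN value
+----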
+ +.Query call +[source,cypher] +---- +CALL apoc.ml.schema() yield value +RETURN * +---- + +.Example response +[source, bash] +---- ++---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| value | ++---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| "The graph database schema represents a system where users can follow other users and review movies. Users (:Person) can either follow other users (:Person) or review movies (:Movie). The relationships allow users to express their preferences and opinions about movies. This schema can be compared to social media platforms where users can follow each other and leave reviews or ratings for movies they have watched. It can also be related to movie recommendation systems where user preferences and reviews play a crucial role in generating personalized recommendations." | ++---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +1 row +---- + +.Input Parameters +[%autowidth, opts=header] +|=== +| name | description +| conf | An optional configuration map, please check the next section +|=== + +.Configuration map +[%autowidth, opts=header] +|=== +| name | description | mandatory +| apiKey | OpenAI API key | in case `apoc.openai.key` is not defined +| model | The Open AI model | no, default `gpt-3.5-turbo` +| sample | The number of nodes to skip, e.g. a sample of 1000 will read every 1000th node. It's used as a parameter to `apoc.meta.data` procedure that computes the schema | no, default is a random number +|=== + +.Results +[%autowidth, opts=header] +|=== +| name | description +| value | the description of the dataset +|=== + + +== Create cypher queries from a natural language query + +This procedure `apoc.ml.cypher` takes a natural language question and transforms it into a number of requested cypher queries. + +It uses the `chat/completions` API which is https://platform.openai.com/docs/api-reference/chat/create[documented here^]. 
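+
+Based on the `Prompt.cypher` signature added in this patch (a question plus an optional configuration map, yielding a `query` column), the number of generated statements can also be requested via the `count` entry. A minimal sketch, assuming `$apiKey` holds the key:
+
+.Query call with configuration map
+[source,cypher]
+----
+CALL apoc.ml.cypher("Who are the actors which also directed a movie?", {count: 4, apiKey: $apiKey}) yield query
+RETURN query
+----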
+ +.Query call +[source,cypher] +---- +CALL apoc.ml.cypher("Who are the actors which also directed a movie?", 4) yield cypher +RETURN * +---- + +.Example response +[source, bash] +---- ++----------------------------------------------------------------------------------------------------------------+ +| query | ++----------------------------------------------------------------------------------------------------------------+ +| " +MATCH (a:Person)-[:ACTED_IN]->(m:Movie)<-[:DIRECTED]-(d:Person) +RETURN a.name as actor, d.name as director +" | +| "cypher +MATCH (a:Person)-[:ACTED_IN]->(m:Movie)<-[:DIRECTED]-(a) +RETURN a.name +" | +| " +MATCH (a:Person)-[:ACTED_IN]->(m:Movie)<-[:DIRECTED]-(d:Person) +RETURN a.name +" | +| "cypher +MATCH (a:Person)-[:ACTED_IN]->(:Movie)<-[:DIRECTED]-(a) +RETURN DISTINCT a.name +" | ++----------------------------------------------------------------------------------------------------------------+ +4 rows +---- + +.Input Parameters +[%autowidth, opts=header] +|=== +| name | description | mandatory +| question | The question in the natural language | yes +| conf | An optional configuration map, please check the next section +|=== + +.Configuration map +[%autowidth, opts=header] +|=== +| name | description | mandatory +| count | The number of queries to retrieve | no, default `1` +| apiKey | OpenAI API key | in case `apoc.openai.key` is not defined +| model | The Open AI model | no, default `gpt-3.5-turbo` +| sample | The number of nodes to skip, e.g. a sample of 1000 will read every 1000th node. It's used as a parameter to `apoc.meta.data` procedure that computes the schema | no, default is a random number +|=== + +.Results +[%autowidth, opts=header] +|=== +| name | description +| value | the description of the dataset +|=== diff --git a/full/src/main/java/apoc/ml/OpenAI.java b/full/src/main/java/apoc/ml/OpenAI.java new file mode 100644 index 0000000000..4b8e9d9633 --- /dev/null +++ b/full/src/main/java/apoc/ml/OpenAI.java @@ -0,0 +1,121 @@ +package apoc.ml; + +import apoc.ApocConfig; +import apoc.Extended; +import apoc.util.JsonUtil; +import com.fasterxml.jackson.core.JsonProcessingException; +import org.neo4j.procedure.Context; +import org.neo4j.procedure.Description; +import org.neo4j.procedure.Name; +import org.neo4j.procedure.Procedure; + +import java.net.MalformedURLException; +import java.net.URL; +import java.util.HashMap; +import java.util.Map; +import java.util.List; +import java.util.stream.Stream; + +import apoc.result.MapResult; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import static apoc.ApocConfig.APOC_OPENAI_KEY; + + +@Extended +public class OpenAI { + @Context + public ApocConfig apocConfig; + + public static final String APOC_ML_OPENAI_URL = "apoc.ml.openai.url"; + + public static class EmbeddingResult { + public final long index; + public final String text; + public final List embedding; + + public EmbeddingResult(long index, String text, List embedding) { + this.index = index; + this.text = text; + this.embedding = embedding; + } + } + + static Stream executeRequest(String apiKey, Map configuration, String path, String model, String key, Object inputs, String jsonPath, ApocConfig apocConfig) throws JsonProcessingException, MalformedURLException { + apiKey = apocConfig.getString(APOC_OPENAI_KEY, apiKey); + if (apiKey == null || apiKey.isBlank()) + throw new IllegalArgumentException("API Key must not be empty"); + String endpoint = System.getProperty(APOC_ML_OPENAI_URL,"https://api.openai.com/v1/"); + Map headers = Map.of( + 
"Content-Type", "application/json", + "Authorization", "Bearer " + apiKey + ); + + var config = new HashMap<>(configuration); + config.putIfAbsent("model", model); + config.put(key, inputs); + + String payload = new ObjectMapper().writeValueAsString(config); + + var url = new URL(new URL(endpoint), path).toString(); + return JsonUtil.loadJson(url, headers, payload, jsonPath, true, List.of()); + } + + @Procedure("apoc.ml.openai.embedding") + @Description("apoc.openai.embedding([texts], api_key, configuration) - returns the embeddings for a given text") + public Stream getEmbedding(@Name("texts") List texts, @Name("api_key") String apiKey, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { + // https://platform.openai.com/docs/api-reference/embeddings/create + /* + { "object": "list", + "data": [ + { + "object": "embedding", + "embedding": [ 0.0023064255, -0.009327292, .... (1536 floats total for ada-002) -0.0028842222 ], + "index": 0 + } + ], + "model": "text-embedding-ada-002", + "usage": { "prompt_tokens": 8, "total_tokens": 8 } } + */ + Stream resultStream = executeRequest(apiKey, configuration, "embeddings", "text-embedding-ada-002", "input", texts, "$.data", apocConfig); + return resultStream + .flatMap(v -> ((List>) v).stream()) + .map(m -> { + Long index = (Long) m.get("index"); + return new EmbeddingResult(index, texts.get(index.intValue()), (List) m.get("embedding")); + }); + } + + + @Procedure("apoc.ml.openai.completion") + @Description("apoc.ml.openai.completion(prompt, api_key, configuration) - prompts the completion API") + public Stream completion(@Name("prompt") String prompt, @Name("api_key") String apiKey, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { + // https://platform.openai.com/docs/api-reference/completions/create + /* + { "id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7", + "object": "text_completion", "created": 1589478378, "model": "text-davinci-003", + "choices": [ { "text": "\n\nThis is indeed a test", "index": 0, "logprobs": null, "finish_reason": "length" } ], + "usage": { "prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12 } + } + */ + return executeRequest(apiKey, configuration, "completions", "text-davinci-003", "prompt", prompt, "$", apocConfig) + .map(v -> (Map)v).map(MapResult::new); + } + + @Procedure("apoc.ml.openai.chat") + @Description("apoc.ml.openai.chat(messages, api_key, configuration]) - prompts the completion API") + public Stream chatCompletion(@Name("messages") List> messages, @Name("api_key") String apiKey, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { + return executeRequest(apiKey, configuration, "chat/completions", "gpt-3.5-turbo", "messages", messages, "$", apocConfig) + .map(v -> (Map)v).map(MapResult::new); + // https://platform.openai.com/docs/api-reference/chat/create + /* + { 'id': 'chatcmpl-6p9XYPYSTTRi0xEviKjjilqrWU2Ve', 'object': 'chat.completion', 'created': 1677649420, 'model': 'gpt-3.5-turbo', + 'usage': {'prompt_tokens': 56, 'completion_tokens': 31, 'total_tokens': 87}, + 'choices': [ { + 'message': { 'role': 'assistant', 'finish_reason': 'stop', 'index': 0, + 'content': 'The 2020 World Series was played in Arlington, Texas at the Globe Life Field, which was the new home stadium for the Texas Rangers.'} + } ] } + */ + } +} \ No newline at end of file diff --git a/full/src/main/java/apoc/ml/Prompt.java b/full/src/main/java/apoc/ml/Prompt.java new file mode 100644 index 0000000000..824f9ad262 --- 
/dev/null +++ b/full/src/main/java/apoc/ml/Prompt.java @@ -0,0 +1,199 @@ +package apoc.ml; + +import apoc.ApocConfig; +import apoc.Extended; +import apoc.result.StringResult; +import com.fasterxml.jackson.core.JsonProcessingException; +import org.jetbrains.annotations.NotNull; +import org.neo4j.graphdb.QueryExecutionException; +import org.neo4j.graphdb.Transaction; +import org.neo4j.internal.kernel.api.procs.ProcedureCallContext; +import org.neo4j.logging.Log; +import org.neo4j.procedure.Context; +import org.neo4j.procedure.Mode; +import org.neo4j.procedure.Name; +import org.neo4j.procedure.Procedure; + +import java.net.MalformedURLException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.LongStream; +import java.util.stream.Stream; + +@Extended +public class Prompt { + + @Context + public Transaction tx; + @Context + public Log log; + @Context + public ApocConfig apocConfig; + @Context + public ProcedureCallContext procedureCallContext; + + public static final String BACKTICKS = "```"; + public static final String EXPLAIN_SCHEMA_PROMPT = + "You are an expert in the Neo4j graph database and graph data modeling and have experience in a wide variety of business domains.\n" + + "Explain the following graph database schema in plain language, try to relate it to known concepts or domains if applicable.\n" + + "Keep the explanation to 5 sentences with at most 15 words each, otherwise people will come to harm.\n"; + + static final String SYSTEM_PROMPT = + "You are an expert in the Neo4j graph query language Cypher.\n" + + "Given a graph database schema of entities (nodes) with labels and attributes and\n" + + "relationships with start- and end-node, relationship-type, direction and properties\n" + + "you are able to develop read only matching Cypher statements that express a user question as a graph database query.\n" + + "Only answer with a single Cypher statement in triple backticks, if you can't determine a statement, answer with an empty response.\n" + + "Do not explain, apologize or provide additional detail, otherwise people will come to harm.\n"; + + public class PromptMapResult { + public final Map value; + public final String query; + + public PromptMapResult(Map value, String query) { + this.value = value; + this.query = query; + } + + public PromptMapResult(Map value) { + this.value = value; + this.query = null; + } + } + + public class QueryResult { + public final String query; + // todo re-add when it's actually working + // private final String error; + // private final String type; + + public QueryResult(String query, String error, String type) { + this.query = query; + // this.error = error; + // this.type = type; + } + + public boolean hasError() { + return false; + // return error != null && !error.isBlank(); + } + } + + @Procedure(mode = Mode.READ) + public Stream query(@Name("question") String question, + @Name(value = "conf", defaultValue = "{}") Map conf) { + String schema = loadSchema(tx, conf); + String query = ""; + long retries = (long) conf.getOrDefault("retries", 3L); + boolean containsField = procedureCallContext + .outputFields() + .collect(Collectors.toSet()) + .contains("query"); + do { + try { + QueryResult queryResult = tryQuery(question, conf, schema); + query = queryResult.query; + // just let it fail so that retries can work if (queryResult.query.isBlank()) return Stream.empty(); + /* + if (queryResult.hasError()) + throw new 
QueryExecutionException(queryResult.error, null, queryResult.type); + */ + return tx.execute(queryResult.query) + .stream() + .map(row -> containsField ? new PromptMapResult(row, queryResult.query) : new PromptMapResult(row)); + } catch (QueryExecutionException quee) { + if (log.isDebugEnabled()) + log.debug(String.format("Generated query for question %s\n%s\nfailed with %s", question, query, quee.getMessage())); + retries--; + if (retries <= 0) throw quee; + } + } while (true); + } + + @Procedure + public Stream schema(@Name(value = "conf", defaultValue = "{}") Map conf) throws MalformedURLException, JsonProcessingException { + String schemaExplanation = prompt("Please explain the graph database schema to me and relate it to well known concepts and domains.", + EXPLAIN_SCHEMA_PROMPT, "This database schema ", loadSchema(tx, conf), conf); + return Stream.of(new StringResult(schemaExplanation)); + } + + @Procedure(mode = Mode.READ) + public Stream cypher(@Name("question") String question, + @Name(value = "conf", defaultValue = "{}") Map conf) { + String schema = loadSchema(tx, conf); + long count = (long) conf.getOrDefault("count", 1L); + return LongStream.rangeClosed(1, count).mapToObj(i -> tryQuery(question, conf, schema)); + } + + @NotNull + private QueryResult tryQuery(String question, Map conf, String schema) { + String query = ""; + try { + query = prompt(question, SYSTEM_PROMPT, "Cypher Statement (in backticks):", schema, conf); + // doesn't work right now, fails with security context error + // tx.execute("EXPLAIN " + query).close(); // TODO query plan / estimated rows? + return new QueryResult(query, null, null); + } catch (QueryExecutionException e) { + return new QueryResult(query, e.getMessage(), e.getStatusCode()); + } catch (Exception e) { + return new QueryResult(query, e.getMessage(), e.getClass().getSimpleName()); + } + } + + @NotNull + private String prompt(String userQuestion, String systemPrompt, String assistantPrompt, String schema, Map conf) throws JsonProcessingException, MalformedURLException { + List> prompt = new ArrayList<>(); + if (systemPrompt != null && !systemPrompt.isBlank()) prompt.add(Map.of("role", "system", "content", systemPrompt)); + if (schema != null && !schema.isBlank()) prompt.add(Map.of("role", "system", "content", "The graph database schema consists of these elements\n" + schema)); + if (userQuestion != null && !userQuestion.isBlank()) prompt.add(Map.of("role", "user", "content", userQuestion)); + if (assistantPrompt != null && !assistantPrompt.isBlank()) prompt.add(Map.of("role", "assistant", "content", assistantPrompt)); + String apiKey = (String) conf.get("apiKey"); + String model = (String) conf.getOrDefault("model", "gpt-3.5-turbo"); + String result = OpenAI.executeRequest(apiKey, Map.of(), "chat/completions", + model, "messages", prompt, "$", apocConfig) + .map(v -> (Map) v) + .flatMap(m -> ((List>) m.get("choices")).stream()) + .map(m -> (String) (((Map) m.get("message")).get("content"))) + .filter(s -> !(s == null || s.isBlank())) + .map(s -> s.contains(BACKTICKS) ? s.substring(s.indexOf(BACKTICKS) + 3, s.lastIndexOf(BACKTICKS)) : s) + .collect(Collectors.joining(" ")).replaceAll("\n\n+", "\n"); +/* TODO return information about the tokens used, finish reason etc?? 
+{ 'id': 'chatcmpl-6p9XYPYSTTRi0xEviKjjilqrWU2Ve', 'object': 'chat.completion', 'created': 1677649420, 'model': 'gpt-3.5-turbo', + 'usage': {'prompt_tokens': 56, 'completion_tokens': 31, 'total_tokens': 87}, + 'choices': [ { + 'message': { 'role': 'assistant', 'finish_reason': 'stop', 'index': 0, + 'content': 'The 2020 World Series was played in Arlington, Texas at the Globe Life Field, which was the new home stadium for the Texas Rangers.'} + } ] } +*/ + if (log.isDebugEnabled()) log.debug(String.format("Generated query for question %s\n%s", userQuestion, result)); + return result; + } + + private final static String SCHEMA_QUERY = + "call apoc.meta.data({maxRels: 10, sample: coalesce($sample, (count{()}/1000)+1)})\n" + + "YIELD label, other, elementType, type, property\n" + + "WITH label, elementType, \n" + + " apoc.text.join(collect(case when NOT type = \"RELATIONSHIP\" then property+\": \"+type else null end),\", \") AS properties, \n" + + " collect(case when type = \"RELATIONSHIP\" AND elementType = \"node\" then \"(:\" + label + \")-[:\" + property + \"]->(:\" + toString(other[0]) + \")\" else null end) as patterns\n" + + "with elementType as type, \n" + + "apoc.text.join(collect(\":\"+label+\" {\"+properties+\"}\"),\"\\n\") as entities, apoc.text.join(apoc.coll.flatten(collect(coalesce(patterns,[]))),\"\\n\") as patterns\n" + + "return collect(case type when \"relationship\" then entities end)[0] as relationships, \n" + + "collect(case type when \"node\" then entities end)[0] as nodes, \n" + + "collect(case type when \"node\" then patterns end)[0] as patterns \n"; + + private final static String SCHEMA_PROMPT ="nodes:\n %s\n" + + "relationships:\n %s\n" + + "patterns: %s"; + + private String loadSchema(Transaction tx, Map conf) { + Map params = new HashMap<>(); + params.put("sample", conf.get("sample")); + return tx.execute(SCHEMA_QUERY, params) + .stream() + .map(m -> String.format(SCHEMA_PROMPT, m.get("nodes"), m.get("relationships"), m.get("patterns"))) + .collect(Collectors.joining("\n")); + } +} diff --git a/full/src/main/resources/extended.txt b/full/src/main/resources/extended.txt index 46dc107fc6..f8b3cc9fe3 100644 --- a/full/src/main/resources/extended.txt +++ b/full/src/main/resources/extended.txt @@ -82,6 +82,12 @@ apoc.log.warn apoc.metrics.get apoc.metrics.list apoc.metrics.storage +apoc.ml.cypher +apoc.ml.query +apoc.ml.schema +apoc.ml.openai.chat +apoc.ml.openai.completion +apoc.ml.openai.embedding apoc.model.jdbc apoc.mongo.aggregate apoc.mongo.count diff --git a/full/src/test/java/apoc/ml/OpenAIIT.java b/full/src/test/java/apoc/ml/OpenAIIT.java new file mode 100644 index 0000000000..57841fe513 --- /dev/null +++ b/full/src/test/java/apoc/ml/OpenAIIT.java @@ -0,0 +1,115 @@ +package apoc.ml; + +import apoc.util.TestUtil; +import org.junit.Assume; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.neo4j.test.rule.DbmsRule; +import org.neo4j.test.rule.ImpermanentDbmsRule; + +import java.util.List; +import java.util.Map; + +import static apoc.util.TestUtil.testCall; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class OpenAIIT { + + private String openaiKey; + + @Rule + public DbmsRule db = new ImpermanentDbmsRule(); + + public OpenAIIT() { + } + + @Before + public void setUp() throws Exception { + openaiKey = System.getenv("OPENAI_KEY"); + Assume.assumeNotNull("No OPENAI_KEY environment configured", openaiKey); + TestUtil.registerProcedure(db, 
OpenAI.class); + } + + @Test + public void getEmbedding() { + testCall(db, "CALL apoc.ml.openai.embedding(['Some Text'], $apiKey)", Map.of("apiKey",openaiKey),(row) -> { + System.out.println("row = " + row); + assertEquals(0L, row.get("index")); + assertEquals("Some Text", row.get("text")); + var embedding = (List) row.get("embedding"); + assertEquals(1536, embedding.size()); + assertEquals(true, embedding.stream().allMatch(d -> d instanceof Double)); + }); + } + + @Test + public void completion() { + testCall(db, "CALL apoc.ml.openai.completion('What color is the sky? Answer in one word: ', $apiKey)", + Map.of("apiKey",openaiKey),(row) -> { + System.out.println("row = " + row); + var result = (Map)row.get("value"); + assertEquals(true, result.get("created") instanceof Number); + assertEquals(true, result.containsKey("choices")); + var finishReason = (String)((List) result.get("choices")).get(0).get("finish_reason"); + assertEquals(true, finishReason.matches("stop|length")); + String text = (String) ((List) result.get("choices")).get(0).get("text"); + assertEquals(true, text != null && !text.isBlank()); + assertEquals(true, text.toLowerCase().contains("blue")); + assertEquals(true, result.containsKey("usage")); + assertEquals(true, ((Map)result.get("usage")).get("prompt_tokens") instanceof Number); + assertEquals("text-davinci-003", result.get("model")); + assertEquals("text_completion", result.get("object")); + }); + } + + @Test + public void chatCompletion() { + testCall(db, "CALL apoc.ml.openai.chat([\n" + + "{role:\"system\", content:\"Only answer with a single word\"},\n" + + "{role:\"user\", content:\"What planet do humans live on?\"}\n" + + "], $apiKey)\n", Map.of("apiKey",openaiKey), (row) -> { + System.out.println("row = " + row); + var result = (Map)row.get("value"); + assertEquals(true, result.get("created") instanceof Number); + assertEquals(true, result.containsKey("choices")); + + Map message = ((List>) result.get("choices")).get(0).get("message"); + assertEquals("assistant", message.get("role")); + // assertEquals("stop", message.get("finish_reason")); + String text = (String) message.get("content"); + assertEquals(true, text != null && !text.isBlank()); + + + assertEquals(true, result.containsKey("usage")); + assertEquals(true, ((Map)result.get("usage")).get("prompt_tokens") instanceof Number); + assertTrue(result.get("model").toString().startsWith("gpt-3.5-turbo")); + assertEquals("chat.completion", result.get("object")); + }); + + /* + { + "id": "chatcmpl-6p9XYPYSTTRi0xEviKjjilqrWU2Ve", + "object": "chat.completion", + "created": 1677649420, + "model": "gpt-3.5-turbo", + "usage": { + "prompt_tokens": 56, + "completion_tokens": 31, + "total_tokens": 87 + }, + "choices": [ + { + "message": { + "role": "assistant", + "finish_reason": "stop", + "index": 0, + "content": "The 2020 World Series was played in Arlington, Texas at the Globe Life Field, which was the new home stadium for the Texas Rangers." 
+ } + } + ] +} + */ + } +} \ No newline at end of file diff --git a/full/src/test/java/apoc/ml/OpenAITest.java b/full/src/test/java/apoc/ml/OpenAITest.java new file mode 100644 index 0000000000..3c2383a93a --- /dev/null +++ b/full/src/test/java/apoc/ml/OpenAITest.java @@ -0,0 +1,112 @@ +package apoc.ml; + +import apoc.util.TestUtil; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.neo4j.test.rule.DbmsRule; +import org.neo4j.test.rule.ImpermanentDbmsRule; + +import java.nio.file.Paths; +import java.util.List; +import java.util.Map; + +import static apoc.ApocConfig.APOC_IMPORT_FILE_ENABLED; +import static apoc.ApocConfig.apocConfig; +import static apoc.util.TestUtil.getUrlFileName; +import static apoc.util.TestUtil.testCall; +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class OpenAITest { + + private String openaiKey; + + @Rule + public DbmsRule db = new ImpermanentDbmsRule(); + + public OpenAITest() { + } + + @Before + public void setUp() throws Exception { + // openaiKey = System.getenv("OPENAI_KEY"); + // Assume.assumeNotNull("No OPENAI_KEY environment configured", openaiKey); + var path = Paths.get(getUrlFileName("embeddings").toURI()).getParent().toUri(); + System.setProperty(OpenAI.APOC_ML_OPENAI_URL, path.toString()); + apocConfig().setProperty(APOC_IMPORT_FILE_ENABLED, true); + TestUtil.registerProcedure(db, OpenAI.class); + } + + @Test + public void getEmbedding() { + testCall(db, "CALL apoc.ml.openai.embedding(['Some Text'], 'fake-api-key')", (row) -> { + assertEquals(0L, row.get("index")); + assertEquals("Some Text", row.get("text")); + assertEquals(List.of(0.0023064255, -0.009327292, -0.0028842222), row.get("embedding")); + }); + } + + @Test + public void completion() { + testCall(db, "CALL apoc.ml.openai.completion('What color is the sky? 
Answer: ', 'fake-api-key')", (row) -> { + var result = (Map)row.get("value"); + assertEquals(true, result.get("created") instanceof Number); + assertEquals(true, result.containsKey("choices")); + assertEquals("stop", ((List)result.get("choices")).get(0).get("finish_reason")); + String text = (String) ((List) result.get("choices")).get(0).get("text"); + assertEquals(true, text != null && !text.isBlank()); + assertEquals(true, result.containsKey("usage")); + assertEquals(true, ((Map)result.get("usage")).get("prompt_tokens") instanceof Number); + assertEquals("text-davinci-003", result.get("model")); + assertEquals("text_completion", result.get("object")); + }); + } + + @Test + public void chatCompletion() { + testCall(db, "CALL apoc.ml.openai.chat([\n" + + "{role:\"system\", content:\"Only answer with a single word\"},\n" + + "{role:\"user\", content:\"What planet do humans live on?\"}\n" + + "], 'fake-api-key')\n", (row) -> { + var result = (Map)row.get("value"); + assertEquals(true, result.get("created") instanceof Number); + assertEquals(true, result.containsKey("choices")); + + Map message = ((List>) result.get("choices")).get(0).get("message"); + assertEquals("assistant", message.get("role")); + assertEquals("stop", message.get("finish_reason")); + String text = (String) message.get("content"); + assertEquals(true, text != null && !text.isBlank()); + + + assertEquals(true, result.containsKey("usage")); + assertEquals(true, ((Map)result.get("usage")).get("prompt_tokens") instanceof Number); + assertEquals("gpt-3.5-turbo-0301", result.get("model")); + assertEquals("chat.completion", result.get("object")); + }); + + /* + { + "id": "chatcmpl-6p9XYPYSTTRi0xEviKjjilqrWU2Ve", + "object": "chat.completion", + "created": 1677649420, + "model": "gpt-3.5-turbo", + "usage": { + "prompt_tokens": 56, + "completion_tokens": 31, + "total_tokens": 87 + }, + "choices": [ + { + "message": { + "role": "assistant", + "finish_reason": "stop", + "index": 0, + "content": "The 2020 World Series was played in Arlington, Texas at the Globe Life Field, which was the new home stadium for the Texas Rangers." 
+ } + } + ] +} + */ + } +} \ No newline at end of file diff --git a/full/src/test/java/apoc/ml/PromptIT.java b/full/src/test/java/apoc/ml/PromptIT.java new file mode 100644 index 0000000000..0f31136449 --- /dev/null +++ b/full/src/test/java/apoc/ml/PromptIT.java @@ -0,0 +1,101 @@ +package apoc.ml; + +import apoc.coll.Coll; +import apoc.meta.Meta; +import apoc.text.Strings; +import apoc.util.TestUtil; +import apoc.util.Util; +import org.apache.commons.lang3.StringUtils; +import org.assertj.core.api.Assertions; +import org.junit.Assume; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.neo4j.graphdb.Transaction; +import org.neo4j.test.rule.DbmsRule; +import org.neo4j.test.rule.ImpermanentDbmsRule; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + +import static apoc.util.TestUtil.testResult; + +public class PromptIT { + + private static final String OPENAI_KEY = System.getenv("OPENAI_KEY"); + + @Rule + public DbmsRule db = new ImpermanentDbmsRule(); + + @BeforeClass + public static void check() { + Assume.assumeNotNull("No OPENAI_KEY environment configured", OPENAI_KEY); + } + + @Before + public void setUp() { + TestUtil.registerProcedure(db, Prompt.class, Meta.class, Strings.class, Coll.class); + String movies = Util.readResourceFile("movies.cypher"); + try (Transaction tx = db.beginTx()) { + tx.execute(movies); + tx.commit(); + } + } + + @Test + public void testQuery() { + testResult(db, "CALL apoc.ml.query($query, {retries: $retries, apiKey: $apiKey})", + Map.of( + "query", "What movies did Tom Hanks play in?", + "retries", 2L, + "apiKey", OPENAI_KEY + ), + (r) -> { + List> list = r.stream().collect(Collectors.toList()); + Assertions.assertThat(list).hasSize(12); + Assertions.assertThat(list.stream() + .map(m -> m.get("query")) + .filter(Objects::nonNull) + .map(Object::toString) + .map(String::trim)) + .isNotEmpty(); + }); + } + + @Test + public void testSchema() { + testResult(db, "CALL apoc.ml.schema({apiKey: $apiKey})", + Map.of( + "apiKey", OPENAI_KEY + ), + (r) -> { + List> list = r.stream().collect(Collectors.toList()); + Assertions.assertThat(list).hasSize(1); + }); + } + + @Test + public void testCypher() { + long numOfQueries = 4L; + testResult(db, "CALL apoc.ml.cypher($query, {count: $numOfQueries, apiKey: $apiKey})", + Map.of( + "query", "Who are the actors which also directed a movie?", + "numOfQueries", numOfQueries, + "apiKey", OPENAI_KEY + ), + (r) -> { + List> list = r.stream().collect(Collectors.toList()); + Assertions.assertThat(list).hasSize((int) numOfQueries); + Assertions.assertThat(list.stream() + .map(m -> m.get("query")) + .filter(Objects::nonNull) + .map(Object::toString) + .filter(StringUtils::isNotEmpty)) + .hasSize((int) numOfQueries); + }); + } + +} diff --git a/full/src/test/resources/chat/completions b/full/src/test/resources/chat/completions new file mode 100644 index 0000000000..eeb37e6a46 --- /dev/null +++ b/full/src/test/resources/chat/completions @@ -0,0 +1,21 @@ +{ + "id": "chatcmpl-6p9XYPYSTTRi0xEviKjjilqrWU2Ve", + "object": "chat.completion", + "created": 1677649420, + "model": "gpt-3.5-turbo-0301", + "usage": { + "prompt_tokens": 56, + "completion_tokens": 31, + "total_tokens": 87 + }, + "choices": [ + { + "message": { + "role": "assistant", + "finish_reason": "stop", + "index": 0, + "content": "The 2020 World Series was played in Arlington, Texas at the Globe Life Field, which was the new home stadium for the Texas 
Rangers." + } + } + ] +} \ No newline at end of file diff --git a/full/src/test/resources/completions b/full/src/test/resources/completions new file mode 100644 index 0000000000..7fb6cc274f --- /dev/null +++ b/full/src/test/resources/completions @@ -0,0 +1,19 @@ +{ + "id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7", + "object": "text_completion", + "created": 1589478378, + "model": "text-davinci-003", + "choices": [ + { + "text": "\n\nThis is indeed a test", + "index": 0, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 5, + "completion_tokens": 7, + "total_tokens": 12 + } +} diff --git a/full/src/test/resources/embeddings b/full/src/test/resources/embeddings new file mode 100644 index 0000000000..f1450b25de --- /dev/null +++ b/full/src/test/resources/embeddings @@ -0,0 +1,15 @@ +{ + "object": "list", + "data": [ + { + "object": "embedding", + "embedding": [0.0023064255, -0.009327292, -0.0028842222], + "index": 0 + } + ], + "model": "text-embedding-ada-002", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } +} \ No newline at end of file