From 3f565c45bb67d104e9793b90e64bba3d2b1fa0cd Mon Sep 17 00:00:00 2001 From: HawickMason <1914991129@qq.com> Date: Tue, 21 Jan 2025 19:15:57 +0800 Subject: [PATCH 1/4] feat(mcp_memory): add mcp server for memory --- jcommon/mcp/mcp-memory/pom.xml | 57 +++ .../mone/mcp/memory/MemoryMcpBootstrap.java | 13 + .../config/McpStdioTransportConfig.java | 21 + .../mcp/memory/function/MemoryFunctions.java | 417 ++++++++++++++++++ .../run/mone/mcp/memory/graph/Entity.java | 22 + .../mone/mcp/memory/graph/KnowledgeGraph.java | 11 + .../memory/graph/KnowledgeGraphManager.java | 232 ++++++++++ .../run/mone/mcp/memory/graph/Relation.java | 17 + .../mcp/memory/server/MemoryMcpServer.java | 103 +++++ .../src/main/resources/application.properties | 2 + .../mcp/memory/test/MemoryFunctionsTest.java | 180 ++++++++ .../resources/application-test.properties | 6 + jcommon/mcp/pom.xml | 3 +- 13 files changed, 1083 insertions(+), 1 deletion(-) create mode 100644 jcommon/mcp/mcp-memory/pom.xml create mode 100644 jcommon/mcp/mcp-memory/src/main/java/run/mone/mcp/memory/MemoryMcpBootstrap.java create mode 100644 jcommon/mcp/mcp-memory/src/main/java/run/mone/mcp/memory/config/McpStdioTransportConfig.java create mode 100644 jcommon/mcp/mcp-memory/src/main/java/run/mone/mcp/memory/function/MemoryFunctions.java create mode 100644 jcommon/mcp/mcp-memory/src/main/java/run/mone/mcp/memory/graph/Entity.java create mode 100644 jcommon/mcp/mcp-memory/src/main/java/run/mone/mcp/memory/graph/KnowledgeGraph.java create mode 100644 jcommon/mcp/mcp-memory/src/main/java/run/mone/mcp/memory/graph/KnowledgeGraphManager.java create mode 100644 jcommon/mcp/mcp-memory/src/main/java/run/mone/mcp/memory/graph/Relation.java create mode 100644 jcommon/mcp/mcp-memory/src/main/java/run/mone/mcp/memory/server/MemoryMcpServer.java create mode 100644 jcommon/mcp/mcp-memory/src/main/resources/application.properties create mode 100644 jcommon/mcp/mcp-memory/src/test/java/run/mone/mcp/memory/test/MemoryFunctionsTest.java create mode 100644 jcommon/mcp/mcp-memory/src/test/resources/application-test.properties diff --git a/jcommon/mcp/mcp-memory/pom.xml b/jcommon/mcp/mcp-memory/pom.xml new file mode 100644 index 000000000..9234e221b --- /dev/null +++ b/jcommon/mcp/mcp-memory/pom.xml @@ -0,0 +1,57 @@ + + + 4.0.0 + + run.mone + mcp + 1.6.1-jdk21-SNAPSHOT + + + mcp-memory + + + 17 + 17 + + + + + + + maven-compiler-plugin + 3.11.0 + + ${maven.compiler.source} + ${maven.compiler.target} + true + UTF-8 + + ${project.basedir}/src/main/java + + + + + org.springframework.boot + spring-boot-maven-plugin + 2.7.14 + + run.mone.mcp.memory.MemoryMcpBootstrap + app + + + + + repackage + + + + + + + + + + + \ No newline at end of file diff --git a/jcommon/mcp/mcp-memory/src/main/java/run/mone/mcp/memory/MemoryMcpBootstrap.java b/jcommon/mcp/mcp-memory/src/main/java/run/mone/mcp/memory/MemoryMcpBootstrap.java new file mode 100644 index 000000000..0f4c938ec --- /dev/null +++ b/jcommon/mcp/mcp-memory/src/main/java/run/mone/mcp/memory/MemoryMcpBootstrap.java @@ -0,0 +1,13 @@ +package run.mone.mcp.memory; + +import org.springframework.boot.SpringApplication; +import org.springframework.boot.autoconfigure.SpringBootApplication; +import org.springframework.context.annotation.ComponentScan; + +@SpringBootApplication +@ComponentScan("run.mone.mcp.memory") +public class MemoryMcpBootstrap { + public static void main(String[] args) { + SpringApplication.run(MemoryMcpBootstrap.class, args); + } +} \ No newline at end of file diff --git a/jcommon/mcp/mcp-memory/src/main/java/run/mone/mcp/memory/config/McpStdioTransportConfig.java b/jcommon/mcp/mcp-memory/src/main/java/run/mone/mcp/memory/config/McpStdioTransportConfig.java new file mode 100644 index 000000000..ec3e5ba9f --- /dev/null +++ b/jcommon/mcp/mcp-memory/src/main/java/run/mone/mcp/memory/config/McpStdioTransportConfig.java @@ -0,0 +1,21 @@ +package run.mone.mcp.memory.config; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import run.mone.hive.mcp.server.transport.StdioServerTransport; + +@Configuration +@ConditionalOnProperty(name = "stdio.enabled", havingValue = "true") +class McpStdioTransportConfig { + /** + * stdio 通信 + * @param mapper + * @return + */ + @Bean + StdioServerTransport stdioServerTransport(ObjectMapper mapper) { + return new StdioServerTransport(mapper); + } +} diff --git a/jcommon/mcp/mcp-memory/src/main/java/run/mone/mcp/memory/function/MemoryFunctions.java b/jcommon/mcp/mcp-memory/src/main/java/run/mone/mcp/memory/function/MemoryFunctions.java new file mode 100644 index 000000000..13aaca558 --- /dev/null +++ b/jcommon/mcp/mcp-memory/src/main/java/run/mone/mcp/memory/function/MemoryFunctions.java @@ -0,0 +1,417 @@ +package run.mone.mcp.memory.function; + +import lombok.Data; +import lombok.extern.slf4j.Slf4j; +import run.mone.hive.mcp.spec.McpSchema; +import run.mone.hive.mcp.spec.McpSchema.CallToolResult; +import run.mone.mcp.memory.graph.KnowledgeGraphManager; +import run.mone.mcp.memory.graph.Entity; +import run.mone.mcp.memory.graph.KnowledgeGraph; +import run.mone.mcp.memory.graph.Relation; + +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; + +@Data +@Slf4j +public class MemoryFunctions { + + private static final KnowledgeGraphManager graphManager = new KnowledgeGraphManager(); + + private static final ObjectMapper objectMapper = new ObjectMapper(); + + @Data + public static class CreateEntitiesFunction implements Function, McpSchema.CallToolResult> { + private String name = "create_entities"; + + private String desc = "Create multiple new entities in the knowledge graph"; + + private String toolScheme = """ + { + "type": "object", + "properties": { + "entities": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "The name of the entity" + }, + "entityType": { + "type": "string", + "description": "The type of the entity" + }, + "observations": { + "type": "array", + "items": { + "type": "string" + }, + "description": "An array of observation contents associated with the entity" + } + }, + "required": ["name", "entityType", "observations"] + } + } + }, + "required": ["entities"] + } + """; + + @Override + public CallToolResult apply(Map t) { + List entities = parseObject(t.get("entities"), new TypeReference>() {}); + log.info("Creating entities: {}", entities); + try { + List newEntities = graphManager.createEntities(entities); + return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent(objectMapper.writeValueAsString(newEntities))), false); + } catch (Throwable e) { + log.error("Failed to create entities", e); + return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("Failed to create entities")), true); + } + } + } + + @Data + public static class CreateRelationsFunction implements Function, McpSchema.CallToolResult> { + private String name = "create_relations"; + + private String desc = "Create multiple new relations between entities in the knowledge graph. Relations should be in active voice"; + + private String toolScheme = """ + { + "type": "object", + "properties": { + "relations": { + "type": "array", + "items": { + "type": "object", + "properties": { + "from": { + "type": "string", + "description": "The name of the entity where the relation starts" + }, + "to": { + "type": "string", + "description": "The name of the entity where the relation ends" + }, + "relationType": { + "type": "string", + "description": "The type of the relation" + } + }, + "required": ["from", "to", "relationType"] + } + } + }, + "required": ["relations"] + } + """; + + @Override + public CallToolResult apply(Map t) { + List relations = parseObject(t.get("relations"), new TypeReference>() {}); + log.info("Creating relations: {}", relations); + try { + List newRelations = graphManager.createRelations(relations); + return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent(objectMapper.writeValueAsString(newRelations))), false); + } catch (Throwable e) { + log.error("Failed to create relations", e); + return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("Failed to create relations")), true); + } + } + } + + @Data + public static class AddObservationsFunction implements Function, McpSchema.CallToolResult> { + private String name = "add_observations"; + + private String desc = "Add new observations to existing entities in the knowledge graph"; + + private String toolScheme = """ + { + "type": "object", + "properties": { + "observations": { + "type": "array", + "items": { + "type": "object", + "properties": { + "entityName": { + "type": "string", + "description": "The name of the entity to add the observations to" + }, + "contents": { + "type": "array", + "items": { + "type": "string" + }, + "description": "An array of observation contents to add" + } + }, + "required": ["entityName", "contents"] + } + } + }, + "required": ["observations"] + } + """; + + @Override + public CallToolResult apply(Map t) { + List observations = parseObject(t.get("observations"), new TypeReference>() {}); + log.info("Adding observations: {}", observations); + try { + List results = graphManager.addObservations(observations); + return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent(objectMapper.writeValueAsString(results))), false); + } catch (Throwable e) { + log.error("Failed to add observations", e); + return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("Failed to add observations")), true); + } + } + } + + @Data + public static class DeleteEntitiesFunction implements Function, McpSchema.CallToolResult> { + private String name = "delete_entities"; + + private String desc = "Delete multiple entities and their associated relations from the knowledge graph"; + + private String toolScheme = """ + { + "type": "object", + "properties": { + "entityNames": { + "type": "array", + "items": { + "type": "string" + }, + "description": "An array of entity names to delete" + } + }, + "required": ["entityNames"] + } + """; + + @Override + public CallToolResult apply(Map t) { + List entityNames = parseObject(t.get("entityNames"), new TypeReference>() {}); + log.info("Deleting entities: {}", entityNames); + try { + graphManager.deleteEntities(entityNames); + return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("Entities deleted successfully")), false); + } catch (Throwable e) { + log.error("Failed to delete entities", e); + return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("Failed to delete entities")), true); + } + } + } + + @Data + public static class DeleteObservationsFunction implements Function, McpSchema.CallToolResult> { + private String name = "delete_observations"; + + private String desc = "Delete specific observations from entities in the knowledge graph"; + + private String toolScheme = """ + { + "type": "object", + "properties": { + "deletions": { + "type": "array", + "items": { + "type": "object", + "properties": { + "entityName": { + "type": "string", + "description": "The name of the entity containing the observations" + }, + "observations": { + "type": "array", + "items": { + "type": "string" + }, + "description": "An array of observations to delete" + } + }, + "required": ["entityName", "observations"] + } + } + }, + "required": ["deletions"] + } + """; + + @Override + public CallToolResult apply(Map t) { + List deletions = parseObject(t.get("deletions"), new TypeReference>() {}); + log.info("Deleting observations: {}", deletions); + try { + graphManager.deleteObservations(deletions); + return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("Observations deleted successfully")), false); + } catch (Throwable e) { + log.error("Failed to delete observations", e); + return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("Failed to delete observations")), true); + } + } + } + + @Data + public static class DeleteRelationsFunction implements Function, McpSchema.CallToolResult> { + private String name = "delete_relations"; + + private String desc = "Delete multiple relations from the knowledge graph"; + + private String toolScheme = """ + { + "type": "object", + "properties": { + "relations": { + "type": "array", + "items": { + "type": "object", + "properties": { + "from": { + "type": "string", + "description": "The name of the entity where the relation starts" + }, + "to": { + "type": "string", + "description": "The name of the entity where the relation ends" + }, + "relationType": { + "type": "string", + "description": "The type of the relation" + } + }, + "required": ["from", "to", "relationType"] + } + } + }, + "required": ["relations"] + } + """; + + @Override + public CallToolResult apply(Map t) { + List relations = parseObject(t.get("relations"), new TypeReference>() {}); + log.info("Deleting relations: {}", relations); + try { + graphManager.deleteRelations(relations); + return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("Relations deleted successfully")), false); + } catch (Throwable e) { + log.error("Failed to delete relations", e); + return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("Failed to delete relations")), true); + } + } + } + + @Data + public static class ReadGraphFunction implements Function, McpSchema.CallToolResult> { + private String name = "read_graph"; + + private String desc = "Read the entire knowledge graph"; + + private String toolScheme = """ + { + "type": "object", + "properties": {} + } + """; + + @Override + public CallToolResult apply(Map t) { + log.info("Reading graph"); + try { + KnowledgeGraph graph = graphManager.readGraph(); + return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent(objectMapper.writeValueAsString(graph))), false); + } catch (Throwable e) { + log.error("Failed to read graph", e); + return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("Failed to read graph")), true); + } + } + } + + @Data + public static class SearchNodesFunction implements Function, McpSchema.CallToolResult> { + private String name = "search_nodes"; + + private String desc = "Search for nodes in the knowledge graph based on a query"; + + private String toolScheme = """ + { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The search query to match against entity names, types, and observation content" + } + }, + "required": ["query"] + } + """; + + @Override + public CallToolResult apply(Map t) { + String query = (String) t.get("query"); + log.info("Searching nodes with query: {}", query); + try { + KnowledgeGraph results = graphManager.searchNodes(query); + return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent(objectMapper.writeValueAsString(results))), false); + } catch (Throwable e) { + log.error("Failed to search nodes", e); + return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("Failed to search nodes")), true); + } + } + } + + @Data + public static class OpenNodesFunction implements Function, McpSchema.CallToolResult> { + private String name = "open_nodes"; + + private String desc = "Open specific nodes in the knowledge graph by their names"; + + private String toolScheme = """ + { + "type": "object", + "properties": { + "names": { + "type": "array", + "items": { + "type": "string" + }, + "description": "An array of entity names to retrieve" + } + }, + "required": ["names"] + } + """; + + @Override + public CallToolResult apply(Map t) { + List names = parseObject(t.get("names"), new TypeReference>() {}); + log.info("Opening nodes: {}", names); + try { + KnowledgeGraph results = graphManager.openNodes(names); + return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent(objectMapper.writeValueAsString(results))), false); + } catch (Throwable e) { + log.error("Failed to open nodes", e); + return new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("Failed to open nodes")), true); + } + } + } + + private static T parseObject(Object obj, TypeReference typeReference) { + try { + return objectMapper.readValue(objectMapper.writeValueAsString(obj), typeReference); + } catch (Exception e) { + log.error("Failed to parse JSON: {}", obj, e); + throw new RuntimeException("Failed to parse JSON", e); + } + } +} \ No newline at end of file diff --git a/jcommon/mcp/mcp-memory/src/main/java/run/mone/mcp/memory/graph/Entity.java b/jcommon/mcp/mcp-memory/src/main/java/run/mone/mcp/memory/graph/Entity.java new file mode 100644 index 000000000..d5a310ee7 --- /dev/null +++ b/jcommon/mcp/mcp-memory/src/main/java/run/mone/mcp/memory/graph/Entity.java @@ -0,0 +1,22 @@ +package run.mone.mcp.memory.graph; + +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.List; +import java.util.ArrayList; + +@Data +@NoArgsConstructor +public class Entity { + private String type = "entity"; + private String name; + private String entityType; + private List observations; + + public Entity(String name, String entityType) { + this.name = name; + this.entityType = entityType; + this.observations = new ArrayList<>(); + } +} diff --git a/jcommon/mcp/mcp-memory/src/main/java/run/mone/mcp/memory/graph/KnowledgeGraph.java b/jcommon/mcp/mcp-memory/src/main/java/run/mone/mcp/memory/graph/KnowledgeGraph.java new file mode 100644 index 000000000..73d15aea8 --- /dev/null +++ b/jcommon/mcp/mcp-memory/src/main/java/run/mone/mcp/memory/graph/KnowledgeGraph.java @@ -0,0 +1,11 @@ +package run.mone.mcp.memory.graph; + +import lombok.Data; +import java.util.List; +import java.util.ArrayList; + +@Data +public class KnowledgeGraph { + private List entities = new ArrayList<>(); + private List relations = new ArrayList<>(); +} \ No newline at end of file diff --git a/jcommon/mcp/mcp-memory/src/main/java/run/mone/mcp/memory/graph/KnowledgeGraphManager.java b/jcommon/mcp/mcp-memory/src/main/java/run/mone/mcp/memory/graph/KnowledgeGraphManager.java new file mode 100644 index 000000000..92b24ebf6 --- /dev/null +++ b/jcommon/mcp/mcp-memory/src/main/java/run/mone/mcp/memory/graph/KnowledgeGraphManager.java @@ -0,0 +1,232 @@ +package run.mone.mcp.memory.graph; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.core.type.TypeReference; +import lombok.Data; +import lombok.extern.slf4j.Slf4j; + +import java.io.*; +import java.nio.file.*; +import java.util.*; +import java.util.stream.Collectors; + +@Slf4j +public class KnowledgeGraphManager { + private static final String MEMORY_FILE_PATH = System.getProperty("java.io.tmpdir") + File.separator + "memory.jsonl"; + private final ObjectMapper objectMapper = new ObjectMapper(); + + private void ensureFileExists() throws IOException { + Path path = Paths.get(MEMORY_FILE_PATH); + if (!Files.exists(path)) { + Files.createDirectories(path.getParent()); + Files.createFile(path); + } + } + + private KnowledgeGraph loadGraph() throws IOException { + ensureFileExists(); + List lines = Files.readAllLines(Paths.get(MEMORY_FILE_PATH)); + + KnowledgeGraph graph = new KnowledgeGraph(); + + for (String line : lines) { + if (line.trim().isEmpty()) continue; + + Map item = objectMapper.readValue(line, new TypeReference>() {}); + String type = (String) item.get("type"); + + if ("entity".equals(type)) { + Entity entity = objectMapper.convertValue(item, Entity.class); + graph.getEntities().add(entity); + } else if ("relation".equals(type)) { + Relation relation = objectMapper.convertValue(item, Relation.class); + graph.getRelations().add(relation); + } + } + return graph; + } + + private void saveGraph(KnowledgeGraph graph) throws IOException { + List lines = new ArrayList<>(); + + for (Entity entity : graph.getEntities()) { + Map item = new HashMap<>(); + item.put("type", "entity"); + item.put("name", entity.getName()); + item.put("entityType", entity.getEntityType()); + item.put("observations", entity.getObservations()); + lines.add(objectMapper.writeValueAsString(item)); + } + + for (Relation relation : graph.getRelations()) { + Map item = new HashMap<>(); + item.put("type", "relation"); + item.put("from", relation.getFrom()); + item.put("to", relation.getTo()); + item.put("relationType", relation.getRelationType()); + lines.add(objectMapper.writeValueAsString(item)); + } + + Files.write(Paths.get(MEMORY_FILE_PATH), lines); + } + + public List createEntities(List entities) throws IOException { + KnowledgeGraph graph = loadGraph(); + List newEntities = entities.stream() + .filter(e -> !graph.getEntities().contains(e.getName())) + .collect(Collectors.toList()); + + graph.getEntities().addAll(newEntities); + saveGraph(graph); + return newEntities; + } + + public List createRelations(List relations) throws IOException { + KnowledgeGraph graph = loadGraph(); + List newRelations = relations.stream() + .filter(r -> !graph.getRelations().stream() + .anyMatch(existing -> + existing.getFrom().equals(r.getFrom()) && + existing.getTo().equals(r.getTo()) && + existing.getRelationType().equals(r.getRelationType()))) + .collect(Collectors.toList()); + + graph.getRelations().addAll(newRelations); + saveGraph(graph); + return newRelations; + } + + public List addObservations(List requests) throws IOException { + KnowledgeGraph graph = loadGraph(); + List results = new ArrayList<>(); + + for (ObservationRequest request : requests) { + Entity entity = graph.getEntities().stream() + .filter(e -> e.getName().equals(request.getEntityName())) + .findFirst() + .orElse(null); + if (entity == null) { + throw new IllegalArgumentException("Entity not found: " + request.getEntityName()); + } + + List newObservations = request.getContents().stream() + .filter(content -> !entity.getObservations().contains(content)) + .collect(Collectors.toList()); + + entity.getObservations().addAll(newObservations); + results.add(new ObservationResult(request.getEntityName(), newObservations)); + } + + saveGraph(graph); + return results; + } + + public void deleteEntities(List entityNames) throws IOException { + KnowledgeGraph graph = loadGraph(); + graph.getEntities().removeIf(e -> entityNames.contains(e.getName())); + saveGraph(graph); + } + + public void deleteObservations(List deletions) throws IOException { + KnowledgeGraph graph = loadGraph(); + + for (ObservationDeletion deletion : deletions) { + Entity entity = graph.getEntities().stream() + .filter(e -> e.getName().equals(deletion.getEntityName())) + .findFirst() + .orElse(null); + if (entity != null) { + entity.getObservations().removeAll(deletion.getObservations()); + } + } + + saveGraph(graph); + } + + public void deleteRelations(List relations) throws IOException { + KnowledgeGraph graph = loadGraph(); + graph.getRelations().removeIf(r -> + relations.stream().anyMatch(delRelation -> + r.getFrom().equals(delRelation.getFrom()) && + r.getTo().equals(delRelation.getTo()) && + r.getRelationType().equals(delRelation.getRelationType()) + ) + ); + saveGraph(graph); + } + + public KnowledgeGraph readGraph() throws IOException { + return loadGraph(); + } + + public KnowledgeGraph searchNodes(String query) throws IOException { + KnowledgeGraph graph = loadGraph(); + String lowercaseQuery = query.toLowerCase(); + + List filteredEntities = graph.getEntities().stream() + .filter(e -> + e.getName().toLowerCase().contains(lowercaseQuery) || + e.getEntityType().toLowerCase().contains(lowercaseQuery) || + e.getObservations().stream() + .anyMatch(o -> o.toLowerCase().contains(lowercaseQuery))) + .collect(Collectors.toList()); + + List filteredRelations = graph.getRelations().stream() + .filter(r -> + filteredEntities.contains(r.getFrom()) && + filteredEntities.contains(r.getTo())) + .collect(Collectors.toList()); + + KnowledgeGraph filteredGraph = new KnowledgeGraph(); + filteredGraph.getEntities().addAll(filteredEntities); + filteredGraph.getRelations().addAll(filteredRelations); + + return filteredGraph; + } + + public KnowledgeGraph openNodes(List names) throws IOException { + KnowledgeGraph graph = loadGraph(); + + List filteredEntities = graph.getEntities().stream() + .filter(e -> names.contains(e.getName())) + .collect(Collectors.toList()); + + List filteredRelations = graph.getRelations().stream() + .filter(r -> + filteredEntities.contains(r.getFrom()) && + filteredEntities.contains(r.getTo())) + .collect(Collectors.toList()); + + KnowledgeGraph filteredGraph = new KnowledgeGraph(); + filteredGraph.getEntities().addAll(filteredEntities); + filteredGraph.getRelations().addAll(filteredRelations); + + return filteredGraph; + } + + // 辅助类 + @Data + public static class ObservationRequest { + private String entityName; + private List contents; + } + + @Data + public static class ObservationResult { + private String entityName; + private List addedObservations; + + public ObservationResult(String entityName, List addedObservations) { + this.entityName = entityName; + this.addedObservations = addedObservations; + } + } + + @Data + public static class ObservationDeletion { + private String entityName; + private List observations; + } +} + + diff --git a/jcommon/mcp/mcp-memory/src/main/java/run/mone/mcp/memory/graph/Relation.java b/jcommon/mcp/mcp-memory/src/main/java/run/mone/mcp/memory/graph/Relation.java new file mode 100644 index 000000000..1fa5c8531 --- /dev/null +++ b/jcommon/mcp/mcp-memory/src/main/java/run/mone/mcp/memory/graph/Relation.java @@ -0,0 +1,17 @@ +package run.mone.mcp.memory.graph; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder +public class Relation { + private String type = "relation"; + private String from; + private String to; + private String relationType; +} diff --git a/jcommon/mcp/mcp-memory/src/main/java/run/mone/mcp/memory/server/MemoryMcpServer.java b/jcommon/mcp/mcp-memory/src/main/java/run/mone/mcp/memory/server/MemoryMcpServer.java new file mode 100644 index 000000000..e4a1baf64 --- /dev/null +++ b/jcommon/mcp/mcp-memory/src/main/java/run/mone/mcp/memory/server/MemoryMcpServer.java @@ -0,0 +1,103 @@ +package run.mone.mcp.memory.server; + +import org.springframework.stereotype.Component; + +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; +import run.mone.hive.mcp.server.McpServer; +import run.mone.hive.mcp.server.McpSyncServer; +import run.mone.hive.mcp.spec.McpSchema.ServerCapabilities; +import run.mone.hive.mcp.spec.ServerMcpTransport; +import run.mone.mcp.memory.function.MemoryFunctions; +import run.mone.hive.mcp.server.McpServer.ToolRegistration; +import run.mone.hive.mcp.spec.McpSchema.Tool; + +@Component +public class MemoryMcpServer { + + private ServerMcpTransport transport; + private McpSyncServer syncServer; + + public MemoryMcpServer(ServerMcpTransport transport) { + this.transport = transport; + } + + public McpSyncServer start() { + McpSyncServer syncServer = McpServer.using(transport) + .serverInfo("memory_mcp", "0.0.1") + .capabilities(ServerCapabilities.builder() + .tools(true) + .logging() + .build()) + .sync(); + + MemoryFunctions.CreateEntitiesFunction function1 = new MemoryFunctions.CreateEntitiesFunction(); + var toolRegistration1 = new ToolRegistration( + new Tool(function1.getName(), function1.getDesc(), function1.getToolScheme()), function1 + ); + + MemoryFunctions.CreateRelationsFunction function2 = new MemoryFunctions.CreateRelationsFunction(); + var toolRegistration2 = new ToolRegistration( + new Tool(function2.getName(), function2.getDesc(), function2.getToolScheme()), function2 + ); + + MemoryFunctions.AddObservationsFunction function3 = new MemoryFunctions.AddObservationsFunction(); + var toolRegistration3 = new ToolRegistration( + new Tool(function3.getName(), function3.getDesc(), function3.getToolScheme()), function3 + ); + + MemoryFunctions.DeleteEntitiesFunction function4 = new MemoryFunctions.DeleteEntitiesFunction(); + var toolRegistration4 = new ToolRegistration( + new Tool(function4.getName(), function4.getDesc(), function4.getToolScheme()), function4 + ); + + MemoryFunctions.DeleteRelationsFunction function5 = new MemoryFunctions.DeleteRelationsFunction(); + var toolRegistration5 = new ToolRegistration( + new Tool(function5.getName(), function5.getDesc(), function5.getToolScheme()), function5 + ); + + MemoryFunctions.DeleteObservationsFunction function6 = new MemoryFunctions.DeleteObservationsFunction(); + var toolRegistration6 = new ToolRegistration( + new Tool(function6.getName(), function6.getDesc(), function6.getToolScheme()), function6 + ); + + MemoryFunctions.ReadGraphFunction function7 = new MemoryFunctions.ReadGraphFunction(); + var toolRegistration7 = new ToolRegistration( + new Tool(function7.getName(), function7.getDesc(), function7.getToolScheme()), function7 + ); + + MemoryFunctions.SearchNodesFunction function8 = new MemoryFunctions.SearchNodesFunction(); + var toolRegistration8 = new ToolRegistration( + new Tool(function8.getName(), function8.getDesc(), function8.getToolScheme()), function8 + ); + + MemoryFunctions.OpenNodesFunction function9 = new MemoryFunctions.OpenNodesFunction(); + var toolRegistration9 = new ToolRegistration( + new Tool(function9.getName(), function9.getDesc(), function9.getToolScheme()), function9 + ); + + syncServer.addTool(toolRegistration1); + syncServer.addTool(toolRegistration2); + syncServer.addTool(toolRegistration3); + syncServer.addTool(toolRegistration4); + syncServer.addTool(toolRegistration5); + syncServer.addTool(toolRegistration6); + syncServer.addTool(toolRegistration7); + syncServer.addTool(toolRegistration8); + syncServer.addTool(toolRegistration9); + return syncServer; + } + + @PostConstruct + public void init() { + this.syncServer = start(); + } + + @PreDestroy + public void stop() { + if (this.syncServer != null) { + this.syncServer.closeGracefully(); + } + } + +} diff --git a/jcommon/mcp/mcp-memory/src/main/resources/application.properties b/jcommon/mcp/mcp-memory/src/main/resources/application.properties new file mode 100644 index 000000000..33a709ec8 --- /dev/null +++ b/jcommon/mcp/mcp-memory/src/main/resources/application.properties @@ -0,0 +1,2 @@ +stdio.enabled=true +spring.main.web-application-type=none \ No newline at end of file diff --git a/jcommon/mcp/mcp-memory/src/test/java/run/mone/mcp/memory/test/MemoryFunctionsTest.java b/jcommon/mcp/mcp-memory/src/test/java/run/mone/mcp/memory/test/MemoryFunctionsTest.java new file mode 100644 index 000000000..775ca7a7e --- /dev/null +++ b/jcommon/mcp/mcp-memory/src/test/java/run/mone/mcp/memory/test/MemoryFunctionsTest.java @@ -0,0 +1,180 @@ +package run.mone.mcp.memory.test; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.test.context.TestPropertySource; +import run.mone.hive.mcp.spec.McpSchema; +import run.mone.mcp.memory.MemoryMcpBootstrap; +import run.mone.mcp.memory.function.MemoryFunctions; +import run.mone.mcp.memory.graph.Entity; +import run.mone.mcp.memory.graph.Relation; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; + +@SpringBootTest(classes = MemoryMcpBootstrap.class) +@TestPropertySource(locations = "classpath:application-test.properties") +public class MemoryFunctionsTest { + + private MemoryFunctions.CreateEntitiesFunction createEntitiesFunction; + private MemoryFunctions.CreateRelationsFunction createRelationsFunction; + private MemoryFunctions.AddObservationsFunction addObservationsFunction; + private MemoryFunctions.DeleteEntitiesFunction deleteEntitiesFunction; + private MemoryFunctions.DeleteObservationsFunction deleteObservationsFunction; + private MemoryFunctions.DeleteRelationsFunction deleteRelationsFunction; + private MemoryFunctions.ReadGraphFunction readGraphFunction; + private MemoryFunctions.SearchNodesFunction searchNodesFunction; + private MemoryFunctions.OpenNodesFunction openNodesFunction; + + @BeforeEach + void setUp() { + createEntitiesFunction = new MemoryFunctions.CreateEntitiesFunction(); + createRelationsFunction = new MemoryFunctions.CreateRelationsFunction(); + addObservationsFunction = new MemoryFunctions.AddObservationsFunction(); + deleteEntitiesFunction = new MemoryFunctions.DeleteEntitiesFunction(); + deleteObservationsFunction = new MemoryFunctions.DeleteObservationsFunction(); + deleteRelationsFunction = new MemoryFunctions.DeleteRelationsFunction(); + readGraphFunction = new MemoryFunctions.ReadGraphFunction(); + searchNodesFunction = new MemoryFunctions.SearchNodesFunction(); + openNodesFunction = new MemoryFunctions.OpenNodesFunction(); + } + + @Test + void testCreateEntities() { + List entities = new ArrayList<>(); + Entity entity = new Entity(); + entity.setName("TestEntity"); + entity.setEntityType("TestType"); + entity.setObservations(List.of("Observation1", "Observation2")); + entities.add(entity); + + Map args = new HashMap<>(); + args.put("entities", entities); + + McpSchema.CallToolResult result = createEntitiesFunction.apply(args); + assertFalse(result.isError()); + } + + @Test + void testCreateRelations() { + // First create some entities + testCreateEntities(); + + List relations = new ArrayList<>(); + Relation relation = new Relation(); + relation.setFrom("TestEntity"); + relation.setTo("TestEntity"); + relation.setRelationType("SELF_RELATION"); + relations.add(relation); + + Map args = new HashMap<>(); + args.put("relations", relations); + + McpSchema.CallToolResult result = createRelationsFunction.apply(args); + assertFalse(result.isError()); + } + + @Test + void testAddObservations() { + // First create an entity + testCreateEntities(); + + Map args = new HashMap<>(); + List> observations = new ArrayList<>(); + Map observation = new HashMap<>(); + observation.put("entityName", "TestEntity"); + observation.put("contents", List.of("NewObservation1", "NewObservation2")); + observations.add(observation); + args.put("observations", observations); + + McpSchema.CallToolResult result = addObservationsFunction.apply(args); + assertFalse(result.isError()); + } + + @Test + void testDeleteEntities() { + // First create an entity + testCreateEntities(); + + Map args = new HashMap<>(); + args.put("entityNames", List.of("TestEntity")); + + McpSchema.CallToolResult result = deleteEntitiesFunction.apply(args); + assertFalse(result.isError()); + } + + @Test + void testDeleteObservations() { + // First create an entity with observations + testCreateEntities(); + + Map args = new HashMap<>(); + List> deletions = new ArrayList<>(); + Map deletion = new HashMap<>(); + deletion.put("entityName", "TestEntity"); + deletion.put("observations", List.of("Observation1")); + deletions.add(deletion); + args.put("deletions", deletions); + + McpSchema.CallToolResult result = deleteObservationsFunction.apply(args); + assertFalse(result.isError()); + } + + @Test + void testDeleteRelations() { + // First create some relations + testCreateRelations(); + + List relations = new ArrayList<>(); + Relation relation = new Relation(); + relation.setFrom("TestEntity"); + relation.setTo("TestEntity"); + relation.setRelationType("SELF_RELATION"); + relations.add(relation); + + Map args = new HashMap<>(); + args.put("relations", relations); + + McpSchema.CallToolResult result = deleteRelationsFunction.apply(args); + assertFalse(result.isError()); + } + + @Test + void testReadGraph() { + // First create some data + testCreateRelations(); + + Map args = new HashMap<>(); + McpSchema.CallToolResult result = readGraphFunction.apply(args); + assertFalse(result.isError()); + } + + @Test + void testSearchNodes() { + // First create some data + testCreateEntities(); + + Map args = new HashMap<>(); + args.put("query", "TestEntity"); + + McpSchema.CallToolResult result = searchNodesFunction.apply(args); + assertFalse(result.isError()); + } + + @Test + void testOpenNodes() { + // First create some entities + testCreateEntities(); + + Map args = new HashMap<>(); + args.put("names", List.of("TestEntity")); + + McpSchema.CallToolResult result = openNodesFunction.apply(args); + assertFalse(result.isError()); + } +} diff --git a/jcommon/mcp/mcp-memory/src/test/resources/application-test.properties b/jcommon/mcp/mcp-memory/src/test/resources/application-test.properties new file mode 100644 index 000000000..2fe0226be --- /dev/null +++ b/jcommon/mcp/mcp-memory/src/test/resources/application-test.properties @@ -0,0 +1,6 @@ +# Test-specific configuration +spring.main.banner-mode=off + +# Logging +logging.level.root=WARN +logging.level.run.mone.mcp.playwright=DEBUG \ No newline at end of file diff --git a/jcommon/mcp/pom.xml b/jcommon/mcp/pom.xml index 07d7431ed..05e45dd25 100644 --- a/jcommon/mcp/pom.xml +++ b/jcommon/mcp/pom.xml @@ -22,6 +22,7 @@ mcp-playwright mcp-applescript mcp-multimodal + mcp-memory @@ -149,4 +150,4 @@ - \ No newline at end of file + From 54fcc5e2fb8b17be1d774aeeb7015d962177fb44 Mon Sep 17 00:00:00 2001 From: HawickMason <1914991129@qq.com> Date: Tue, 21 Jan 2025 19:47:15 +0800 Subject: [PATCH 2/4] feat(mcp_memory): add mcp server for memory, change file location --- .../java/run/mone/mcp/memory/graph/KnowledgeGraphManager.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jcommon/mcp/mcp-memory/src/main/java/run/mone/mcp/memory/graph/KnowledgeGraphManager.java b/jcommon/mcp/mcp-memory/src/main/java/run/mone/mcp/memory/graph/KnowledgeGraphManager.java index 92b24ebf6..cec14e460 100644 --- a/jcommon/mcp/mcp-memory/src/main/java/run/mone/mcp/memory/graph/KnowledgeGraphManager.java +++ b/jcommon/mcp/mcp-memory/src/main/java/run/mone/mcp/memory/graph/KnowledgeGraphManager.java @@ -12,7 +12,7 @@ @Slf4j public class KnowledgeGraphManager { - private static final String MEMORY_FILE_PATH = System.getProperty("java.io.tmpdir") + File.separator + "memory.jsonl"; + private static final String MEMORY_FILE_PATH = System.getProperty("user.dir") + File.separator + ".memory.jsonl"; private final ObjectMapper objectMapper = new ObjectMapper(); private void ensureFileExists() throws IOException { From 662f409f0696f8122de69df4fc7f74c9fd7adf9b Mon Sep 17 00:00:00 2001 From: HawickMason <1914991129@qq.com> Date: Wed, 22 Jan 2025 14:26:12 +0800 Subject: [PATCH 3/4] feat(mcp_playwright): add functions to mcp-playwright --- .../function/PlaywrightFunctions.java | 750 ++++++++++++++++++ .../server/PlaywrightMcpServer.java | 107 ++- .../function/PlaywrightFunctionsTest.java | 192 +++++ 3 files changed, 1044 insertions(+), 5 deletions(-) create mode 100644 jcommon/mcp/mcp-playwright/src/main/java/run/mone/mcp/playwright/function/PlaywrightFunctions.java create mode 100644 jcommon/mcp/mcp-playwright/src/test/java/run/mone/mcp/playwright/function/PlaywrightFunctionsTest.java diff --git a/jcommon/mcp/mcp-playwright/src/main/java/run/mone/mcp/playwright/function/PlaywrightFunctions.java b/jcommon/mcp/mcp-playwright/src/main/java/run/mone/mcp/playwright/function/PlaywrightFunctions.java new file mode 100644 index 000000000..0aa16ba99 --- /dev/null +++ b/jcommon/mcp/mcp-playwright/src/main/java/run/mone/mcp/playwright/function/PlaywrightFunctions.java @@ -0,0 +1,750 @@ +package run.mone.mcp.playwright.function; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.microsoft.playwright.*; +import com.microsoft.playwright.Page.NavigateOptions; +import com.microsoft.playwright.options.RequestOptions; +import com.microsoft.playwright.options.WaitUntilState; + +import lombok.Data; +import lombok.extern.slf4j.Slf4j; +import run.mone.hive.mcp.spec.McpSchema; + +import java.nio.file.Path; +import java.nio.file.Paths; +import java.time.Instant; +import java.util.Base64; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +@Slf4j +@Data +public class PlaywrightFunctions { + + private static final ObjectMapper objectMapper = new ObjectMapper(); + private static Playwright playwright; + private static Browser browser; + private static Page page; + private static APIRequestContext apiContext; + + public static Playwright getPlaywright() { + return playwright; + } + + public static Browser getBrowser() { + return browser; + } + + public static Page getPage() { + return page; + } + + // 确保浏览器实例存在 + private static Page ensureBrowser(Integer width, Integer height, Boolean headless) { + try { + if (playwright == null) { + playwright = Playwright.create(); + } + if (browser == null) { + browser = playwright.chromium().launch(new BrowserType.LaunchOptions() + .setHeadless(headless != null ? headless : false)); + } + if (page == null) { + page = browser.newPage(new Browser.NewPageOptions() + .setViewportSize(width != null ? width : 1920, + height != null ? height : 1080)); + } + return page; + } catch (Exception e) { + log.error("Failed to ensure browser", e); + throw new RuntimeException("Failed to ensure browser", e); + } + } + + // 确保API上下文存在 + private static APIRequestContext ensureApiContext(String baseUrl) { + try { + if (playwright == null) { + playwright = Playwright.create(); + } + if (apiContext == null) { + apiContext = playwright.request().newContext(new APIRequest.NewContextOptions() + .setBaseURL(baseUrl)); + } + return apiContext; + } catch (Exception e) { + log.error("Failed to ensure API context", e); + throw new RuntimeException("Failed to ensure API context", e); + } + } + + // 工具方法:解析JSON对象 + private static T parseObject(Object obj, Class clazz) { + try { + return objectMapper.readValue(objectMapper.writeValueAsString(obj), clazz); + } catch (Exception e) { + log.error("Failed to parse JSON: {}", obj, e); + throw new RuntimeException("Failed to parse JSON", e); + } + } + + @Data + public static class NavigateFunction implements Function, McpSchema.CallToolResult> { + private String name = "playwright_navigate"; + + private String desc = "Navigate to a URL"; + + private String toolScheme = """ + { + "type": "object", + "properties": { + "url": { "type": "string" }, + "width": { + "type": "number", + "description": "Viewport width in pixels (default: 1920)" + }, + "height": { + "type": "number", + "description": "Viewport height in pixels (default: 1080)" + }, + "timeout": { + "type": "number", + "description": "Navigation timeout in milliseconds" + }, + "waitUntil": { + "type": "string", + "description": "Navigation wait condition" + }, + "headless": { + "type": "boolean", + "description": "Whether to run in headless mode (default: false)" + } + }, + "required": ["url"] + } + """; + + @Override + public McpSchema.CallToolResult apply(Map args) { + try { + Integer width = args.get("width") != null ? ((Number)args.get("width")).intValue() : 1024; + Integer height = args.get("height") != null ? ((Number)args.get("height")).intValue() : 768; + Boolean headless = args.get("headless") != null ? (Boolean)args.get("headless") : false; + + Page page = ensureBrowser(width, height, headless); + + NavigateOptions options = new NavigateOptions() + .setTimeout(args.get("timeout") != null ? ((Number)args.get("timeout")).intValue() : 30000) + .setWaitUntil(args.get("waitUntil") != null ? + WaitUntilState.valueOf(args.get("waitUntil").toString().toUpperCase()) : + WaitUntilState.LOAD); + + page.navigate(args.get("url").toString(), options); + + String message = String.format("Navigated to %s with %s wait%s", + args.get("url"), + args.get("waitUntil") != null ? args.get("waitUntil") : "load", + (width != null && height != null) ? String.format(" (viewport: %dx%d)", width, height) : "" + ); + + return new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent(message)), + false + ); + } catch (Exception e) { + log.error("Navigation failed", e); + return new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent("Navigation failed: " + e.getMessage() + " ,cause: " + e.getCause())), + true + ); + } + } + } + + @Data + public static class ScreenshotFunction implements Function, McpSchema.CallToolResult> { + private String name = "playwright_screenshot"; + + private String desc = "Take a screenshot of the current page or a specific element"; + + private String toolScheme = """ + { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "Name for the screenshot" + }, + "selector": { + "type": "string", + "description": "CSS selector for element to screenshot" + }, + "width": { + "type": "number", + "description": "Width in pixels (default: 800)" + }, + "height": { + "type": "number", + "description": "Height in pixels (default: 600)" + }, + "storeBase64": { + "type": "boolean", + "description": "Store screenshot in base64 format (default: true)" + }, + "savePng": { + "type": "boolean", + "description": "Save screenshot as PNG file (default: false)" + }, + "downloadsDir": { + "type": "string", + "description": "Custom downloads directory path (default: user's Downloads folder)" + } + }, + "required": ["name"] + } + """; + + @Override + public McpSchema.CallToolResult apply(Map args) { + try { + Page page = ensureBrowser( + args.get("width") != null ? ((Number)args.get("width")).intValue() : 1024, + args.get("height") != null ? ((Number)args.get("height")).intValue() : 768, + args.get("headless") != null ? (Boolean)args.get("headless") : false + ); + + + Path path = args.get("savePng") != null && (Boolean)args.get("savePng") ? + Paths.get(args.get("downloadsDir") != null ? + args.get("downloadsDir").toString() : + System.getProperty("user.home") + "/Downloads", + args.get("name") + "-" + Instant.now().toString().replace(":", "-") + ".png" + ) : Paths.get(System.getProperty("java.io.tmpdir"), "screenshot.png"); + + Page.ScreenshotOptions pageOptions = new Page.ScreenshotOptions() + .setPath(path); + + ElementHandle.ScreenshotOptions elementOptions = new ElementHandle.ScreenshotOptions() + .setPath(path); + + byte[] screenshot; + if (args.get("selector") != null) { + ElementHandle element = page.querySelector(args.get("selector").toString()); + if (element == null) { + return new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent("Element not found: " + args.get("selector"))), + true + ); + } + screenshot = element.screenshot(elementOptions); + } else { + screenshot = page.screenshot(pageOptions); + } + + List content = new java.util.ArrayList<>(); + + if (args.get("savePng") != null && (Boolean)args.get("savePng")) { + content.add(new McpSchema.TextContent( + "Screenshot saved to: " + path.toString() + )); + } + + if (args.get("storeBase64") == null || (Boolean)args.get("storeBase64")) { + content.add(new McpSchema.ImageContent(null, null, "image/png", + Base64.getEncoder().encodeToString(screenshot), + "image/png" + )); + } + + return new McpSchema.CallToolResult(content, false); + } catch (Exception e) { + log.error("Screenshot failed", e); + return new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent("Screenshot failed: " + e.getMessage() + " ,cause: " + e.getCause())), + true + ); + } + } + } + + @Data + public static class ClickFunction implements Function, McpSchema.CallToolResult> { + private String name = "playwright_click"; + + private String desc = "Click an element on the page"; + + private String toolScheme = """ + { + "type": "object", + "properties": { + "selector": { + "type": "string", + "description": "CSS selector for element to click" + } + }, + "required": ["selector"] + } + """; + + @Override + public McpSchema.CallToolResult apply(Map args) { + try { + Page page = ensureBrowser(null, null, null); + page.click(args.get("selector").toString()); + + return new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent("Clicked: " + args.get("selector"))), + false + ); + } catch (Exception e) { + log.error("Click failed", e); + return new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent( + "Failed to click " + args.get("selector") + ": " + e.getMessage() + )), + true + ); + } + } + } + + @Data + public static class FillFunction implements Function, McpSchema.CallToolResult> { + private String name = "playwright_fill"; + + private String desc = "Fill out an input field"; + + private String toolScheme = """ + { + "type": "object", + "properties": { + "selector": { + "type": "string", + "description": "CSS selector for input field" + }, + "value": { + "type": "string", + "description": "Value to fill" + } + }, + "required": ["selector", "value"] + } + """; + + @Override + public McpSchema.CallToolResult apply(Map args) { + try { + Page page = ensureBrowser(null, null, null); + page.waitForSelector(args.get("selector").toString()); + page.fill(args.get("selector").toString(), args.get("value").toString()); + + return new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent( + "Filled " + args.get("selector") + " with: " + args.get("value") + )), + false + ); + } catch (Exception e) { + log.error("Fill failed", e); + return new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent( + "Failed to fill " + args.get("selector") + ": " + e.getMessage() + )), + true + ); + } + } + } + + @Data + public static class SelectFunction implements Function, McpSchema.CallToolResult> { + private String name = "playwright_select"; + + private String desc = "Select an element on the page with Select tag"; + + private String toolScheme = """ + { + "type": "object", + "properties": { + "selector": { + "type": "string", + "description": "CSS selector for element to select" + }, + "value": { + "type": "string", + "description": "Value to select" + } + }, + "required": ["selector", "value"] + } + """; + + @Override + public McpSchema.CallToolResult apply(Map args) { + try { + Page page = ensureBrowser(null, null, null); + page.waitForSelector(args.get("selector").toString()); + page.selectOption(args.get("selector").toString(), args.get("value").toString()); + + return new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent( + "Selected " + args.get("selector") + " with: " + args.get("value") + )), + false + ); + } catch (Exception e) { + log.error("Select failed", e); + return new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent( + "Failed to select " + args.get("selector") + ": " + e.getMessage() + )), + true + ); + } + } + } + + @Data + public static class HoverFunction implements Function, McpSchema.CallToolResult> { + private String name = "playwright_hover"; + + private String desc = "Hover an element on the page"; + + private String toolScheme = """ + { + "type": "object", + "properties": { + "selector": { + "type": "string", + "description": "CSS selector for element to hover" + } + }, + "required": ["selector"] + } + """; + + @Override + public McpSchema.CallToolResult apply(Map args) { + try { + Page page = ensureBrowser(null, null, null); + page.waitForSelector(args.get("selector").toString()); + page.hover(args.get("selector").toString()); + + return new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent("Hovered " + args.get("selector"))), + false + ); + } catch (Exception e) { + log.error("Hover failed", e); + return new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent( + "Failed to hover " + args.get("selector") + ": " + e.getMessage() + )), + true + ); + } + } + } + + @Data + public static class EvaluateFunction implements Function, McpSchema.CallToolResult> { + private String name = "playwright_evaluate"; + + private String desc = "Execute JavaScript in the browser console"; + + private String toolScheme = """ + { + "type": "object", + "properties": { + "script": { + "type": "string", + "description": "JavaScript code to execute" + } + }, + "required": ["script"] + } + """; + + @Override + public McpSchema.CallToolResult apply(Map args) { + try { + Page page = ensureBrowser(null, null, null); + Object result = page.evaluate(args.get("script").toString()); + + return new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent( + "Execution result:\n" + objectMapper.writeValueAsString(result) + )), + false + ); + } catch (Exception e) { + log.error("Script execution failed", e); + return new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent( + "Script execution failed: " + e.getMessage() + )), + true + ); + } + } + } + + @Data + public static class GetFunction implements Function, McpSchema.CallToolResult> { + private String name = "playwright_get"; + + private String desc = "Perform an HTTP GET request"; + + private String toolScheme = """ + { + "type": "object", + "properties": { + "url": { + "type": "string", + "description": "URL to perform GET operation" + } + }, + "required": ["url"] + } + """; + + @Override + public McpSchema.CallToolResult apply(Map args) { + try { + APIRequestContext context = ensureApiContext(args.get("url").toString()); + APIResponse response = context.get(args.get("url").toString()); + + return new McpSchema.CallToolResult( + List.of( + new McpSchema.TextContent("Performed GET Operation " + args.get("url")), + new McpSchema.TextContent("Response: " + response.text()), + new McpSchema.TextContent("Response code " + response.status()) + ), + false + ); + } catch (Exception e) { + log.error("GET request failed", e); + return new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent( + "Failed to perform GET operation on " + args.get("url") + ": " + e.getMessage() + )), + true + ); + } + } + } + + @Data + public static class PostFunction implements Function, McpSchema.CallToolResult> { + private String name = "playwright_post"; + + private String desc = "Perform an HTTP POST request"; + + private String toolScheme = """ + { + "type": "object", + "properties": { + "url": { + "type": "string", + "description": "URL to perform POST operation" + }, + "value": { + "type": "string", + "description": "Data to post in the body" + } + }, + "required": ["url", "value"] + } + """; + + @Override + public McpSchema.CallToolResult apply(Map args) { + try { + APIRequestContext context = ensureApiContext(args.get("url").toString()); + RequestOptions options = RequestOptions.create() + .setData(args.get("value")) + .setHeader("Content-Type", "application/json"); + + APIResponse response = context.post(args.get("url").toString(), options); + + return new McpSchema.CallToolResult( + List.of( + new McpSchema.TextContent( + "Performed POST Operation " + args.get("url") + + " with data " + args.get("value") + ), + new McpSchema.TextContent("Response: " + response.text()), + new McpSchema.TextContent("Response code " + response.status()) + ), + false + ); + } catch (Exception e) { + log.error("POST request failed", e); + return new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent( + "Failed to perform POST operation on " + args.get("url") + ": " + e.getMessage() + )), + true + ); + } + } + } + + @Data + public static class PutFunction implements Function, McpSchema.CallToolResult> { + private String name = "playwright_put"; + + private String desc = "Perform an HTTP PUT request"; + + private String toolScheme = """ + { + "type": "object", + "properties": { + "url": { + "type": "string", + "description": "URL to perform PUT operation" + }, + "value": { + "type": "string", + "description": "Data to PUT in the body" + } + }, + "required": ["url", "value"] + } + """; + + @Override + public McpSchema.CallToolResult apply(Map args) { + try { + APIRequestContext context = ensureApiContext(args.get("url").toString()); + RequestOptions options = RequestOptions.create() + .setData(args.get("value")) + .setHeader("Content-Type", "application/json"); + + APIResponse response = context.put(args.get("url").toString(), options); + + return new McpSchema.CallToolResult( + List.of( + new McpSchema.TextContent( + "Performed PUT Operation " + args.get("url") + + " with data " + args.get("value") + ), + new McpSchema.TextContent("Response: " + response.text()), + new McpSchema.TextContent("Response code " + response.status()) + ), + false + ); + } catch (Exception e) { + log.error("PUT request failed", e); + return new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent( + "Failed to perform PUT operation on " + args.get("url") + ": " + e.getMessage() + )), + true + ); + } + } + } + + @Data + public static class PatchFunction implements Function, McpSchema.CallToolResult> { + private String name = "playwright_patch"; + + private String desc = "Perform an HTTP PATCH request"; + + private String toolScheme = """ + { + "type": "object", + "properties": { + "url": { + "type": "string", + "description": "URL to perform PATCH operation" + }, + "value": { + "type": "string", + "description": "Data to PATCH in the body" + } + }, + "required": ["url", "value"] + } + """; + + @Override + public McpSchema.CallToolResult apply(Map args) { + try { + APIRequestContext context = ensureApiContext(args.get("url").toString()); + RequestOptions options = RequestOptions.create() + .setData(args.get("value")) + .setHeader("Content-Type", "application/json"); + + APIResponse response = context.patch(args.get("url").toString(), options); + + return new McpSchema.CallToolResult( + List.of( + new McpSchema.TextContent( + "Performed PATCH Operation " + args.get("url") + + " with data " + args.get("value") + ), + new McpSchema.TextContent("Response: " + response.text()), + new McpSchema.TextContent("Response code " + response.status()) + ), + false + ); + } catch (Exception e) { + log.error("PATCH request failed", e); + return new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent( + "Failed to perform PATCH operation on " + args.get("url") + ": " + e.getMessage() + )), + true + ); + } + } + } + + @Data + public static class DeleteFunction implements Function, McpSchema.CallToolResult> { + private String name = "playwright_delete"; + + private String desc = "Perform an HTTP DELETE request"; + + private String toolScheme = """ + { + "type": "object", + "properties": { + "url": { + "type": "string", + "description": "URL to perform DELETE operation" + } + }, + "required": ["url"] + } + """; + + @Override + public McpSchema.CallToolResult apply(Map args) { + try { + APIRequestContext context = ensureApiContext(args.get("url").toString()); + APIResponse response = context.delete(args.get("url").toString()); + + return new McpSchema.CallToolResult( + List.of( + new McpSchema.TextContent("Performed DELETE Operation " + args.get("url")), + new McpSchema.TextContent("Response code " + response.status()) + ), + false + ); + } catch (Exception e) { + log.error("DELETE request failed", e); + return new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent( + "Failed to perform DELETE operation on " + args.get("url") + ": " + e.getMessage() + )), + true + ); + } + } + } + +} diff --git a/jcommon/mcp/mcp-playwright/src/main/java/run/mone/mcp/playwright/server/PlaywrightMcpServer.java b/jcommon/mcp/mcp-playwright/src/main/java/run/mone/mcp/playwright/server/PlaywrightMcpServer.java index 200f053a9..58f04ab25 100644 --- a/jcommon/mcp/mcp-playwright/src/main/java/run/mone/mcp/playwright/server/PlaywrightMcpServer.java +++ b/jcommon/mcp/mcp-playwright/src/main/java/run/mone/mcp/playwright/server/PlaywrightMcpServer.java @@ -10,6 +10,19 @@ import run.mone.hive.mcp.spec.McpSchema.Tool; import run.mone.hive.mcp.spec.ServerMcpTransport; import run.mone.mcp.playwright.function.PlaywrightFunction; +import run.mone.mcp.playwright.function.PlaywrightFunctions; +import run.mone.mcp.playwright.function.PlaywrightFunctions.ClickFunction; +import run.mone.mcp.playwright.function.PlaywrightFunctions.DeleteFunction; +import run.mone.mcp.playwright.function.PlaywrightFunctions.EvaluateFunction; +import run.mone.mcp.playwright.function.PlaywrightFunctions.FillFunction; +import run.mone.mcp.playwright.function.PlaywrightFunctions.GetFunction; +import run.mone.mcp.playwright.function.PlaywrightFunctions.HoverFunction; +import run.mone.mcp.playwright.function.PlaywrightFunctions.NavigateFunction; +import run.mone.mcp.playwright.function.PlaywrightFunctions.PatchFunction; +import run.mone.mcp.playwright.function.PlaywrightFunctions.PostFunction; +import run.mone.mcp.playwright.function.PlaywrightFunctions.PutFunction; +import run.mone.mcp.playwright.function.PlaywrightFunctions.ScreenshotFunction; +import run.mone.mcp.playwright.function.PlaywrightFunctions.SelectFunction; import run.mone.hive.mcp.server.McpServer.ToolRegistration; import run.mone.hive.mcp.spec.McpSchema.ServerCapabilities; @@ -37,13 +50,97 @@ public McpSyncServer start() { log.info("Registering execute_playwright tool..."); try { - playwrightFunction = new PlaywrightFunction(); - var playwrightToolRegistration = new ToolRegistration( - new Tool(playwrightFunction.getName(), playwrightFunction.getDesc(), playwrightFunction.getPlaywrightToolSchema()), playwrightFunction + // playwrightFunction = new PlaywrightFunction(); + // var playwrightToolRegistration = new ToolRegistration( + // new Tool(playwrightFunction.getName(), playwrightFunction.getDesc(), playwrightFunction.getPlaywrightToolSchema()), playwrightFunction + // ); + + // syncServer.addTool(playwrightToolRegistration); + // log.info("Successfully registered execute_playwright tool"); + + NavigateFunction navigateFunction = new PlaywrightFunctions.NavigateFunction(); + var navigateToolRegistration = new ToolRegistration( + new Tool(navigateFunction.getName(), navigateFunction.getDesc(), navigateFunction.getToolScheme()), navigateFunction + ); + syncServer.addTool(navigateToolRegistration); + log.info("Successfully registered navigate tool"); + + ClickFunction clickFunction = new PlaywrightFunctions.ClickFunction(); + var clickToolRegistration = new ToolRegistration( + new Tool(clickFunction.getName(), clickFunction.getDesc(), clickFunction.getToolScheme()), clickFunction + ); + syncServer.addTool(clickToolRegistration); + log.info("Successfully registered click tool"); + + ScreenshotFunction screenshotFunction = new PlaywrightFunctions.ScreenshotFunction(); + var screenshotToolRegistration = new ToolRegistration( + new Tool(screenshotFunction.getName(), screenshotFunction.getDesc(), screenshotFunction.getToolScheme()), screenshotFunction + ); + syncServer.addTool(screenshotToolRegistration); + log.info("Successfully registered screenshot tool"); + + FillFunction fillFunction = new PlaywrightFunctions.FillFunction(); + var fillToolRegistration = new ToolRegistration( + new Tool(fillFunction.getName(), fillFunction.getDesc(), fillFunction.getToolScheme()), fillFunction + ); + syncServer.addTool(fillToolRegistration); + log.info("Successfully registered fill tool"); + + SelectFunction selectFunction = new PlaywrightFunctions.SelectFunction(); + var selectToolRegistration = new ToolRegistration( + new Tool(selectFunction.getName(), selectFunction.getDesc(), selectFunction.getToolScheme()), selectFunction + ); + syncServer.addTool(selectToolRegistration); + log.info("Successfully registered select tool"); + + HoverFunction hoverFunction = new PlaywrightFunctions.HoverFunction(); + var hoverToolRegistration = new ToolRegistration( + new Tool(hoverFunction.getName(), hoverFunction.getDesc(), hoverFunction.getToolScheme()), hoverFunction ); + syncServer.addTool(hoverToolRegistration); + log.info("Successfully registered hover tool"); - syncServer.addTool(playwrightToolRegistration); - log.info("Successfully registered execute_playwright tool"); + EvaluateFunction evalFunction = new PlaywrightFunctions.EvaluateFunction(); + var evalToolRegistration = new ToolRegistration( + new Tool(evalFunction.getName(), evalFunction.getDesc(), evalFunction.getToolScheme()), evalFunction + ); + syncServer.addTool(evalToolRegistration); + log.info("Successfully registered eval tool"); + + GetFunction getFunction = new PlaywrightFunctions.GetFunction(); + var getToolRegistration = new ToolRegistration( + new Tool(getFunction.getName(), getFunction.getDesc(), getFunction.getToolScheme()), getFunction + ); + syncServer.addTool(getToolRegistration); + log.info("Successfully registered get tool"); + + PostFunction postFunction = new PlaywrightFunctions.PostFunction(); + var postToolRegistration = new ToolRegistration( + new Tool(postFunction.getName(), postFunction.getDesc(), postFunction.getToolScheme()), postFunction + ); + syncServer.addTool(postToolRegistration); + log.info("Successfully registered post tool"); + + PutFunction putFunction = new PlaywrightFunctions.PutFunction(); + var putToolRegistration = new ToolRegistration( + new Tool(putFunction.getName(), putFunction.getDesc(), putFunction.getToolScheme()), putFunction + ); + syncServer.addTool(putToolRegistration); + log.info("Successfully registered put tool"); + + DeleteFunction deleteFunction = new PlaywrightFunctions.DeleteFunction(); + var deleteToolRegistration = new ToolRegistration( + new Tool(deleteFunction.getName(), deleteFunction.getDesc(), deleteFunction.getToolScheme()), deleteFunction + ); + syncServer.addTool(deleteToolRegistration); + log.info("Successfully registered delete tool"); + + PatchFunction patchFunction = new PlaywrightFunctions.PatchFunction(); + var patchToolRegistration = new ToolRegistration( + new Tool(patchFunction.getName(), patchFunction.getDesc(), patchFunction.getToolScheme()), patchFunction + ); + syncServer.addTool(patchToolRegistration); + log.info("Successfully registered patch tool"); } catch (Exception e) { log.error("Failed to register execute_playwright tool", e); throw e; diff --git a/jcommon/mcp/mcp-playwright/src/test/java/run/mone/mcp/playwright/function/PlaywrightFunctionsTest.java b/jcommon/mcp/mcp-playwright/src/test/java/run/mone/mcp/playwright/function/PlaywrightFunctionsTest.java new file mode 100644 index 000000000..b547772a7 --- /dev/null +++ b/jcommon/mcp/mcp-playwright/src/test/java/run/mone/mcp/playwright/function/PlaywrightFunctionsTest.java @@ -0,0 +1,192 @@ +package run.mone.mcp.playwright.function; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.boot.test.web.server.LocalServerPort; +import org.springframework.test.context.TestPropertySource; +import run.mone.hive.mcp.spec.McpSchema; + +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; + +@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT) +@TestPropertySource(locations = "classpath:application-test.properties") +public class PlaywrightFunctionsTest { + + @LocalServerPort + private int port; + + private String baseUrl; + + private PlaywrightFunctions.NavigateFunction navigateFunction; + private PlaywrightFunctions.ScreenshotFunction screenshotFunction; + private PlaywrightFunctions.ClickFunction clickFunction; + private PlaywrightFunctions.FillFunction fillFunction; + private PlaywrightFunctions.SelectFunction selectFunction; + private PlaywrightFunctions.HoverFunction hoverFunction; + private PlaywrightFunctions.EvaluateFunction evaluateFunction; + private PlaywrightFunctions.GetFunction getFunction; + private PlaywrightFunctions.PostFunction postFunction; + private PlaywrightFunctions.PutFunction putFunction; + private PlaywrightFunctions.PatchFunction patchFunction; + private PlaywrightFunctions.DeleteFunction deleteFunction; + + @BeforeEach + void setUp() { + baseUrl = "https://www.baidu.com"; + + navigateFunction = new PlaywrightFunctions.NavigateFunction(); + screenshotFunction = new PlaywrightFunctions.ScreenshotFunction(); + clickFunction = new PlaywrightFunctions.ClickFunction(); + fillFunction = new PlaywrightFunctions.FillFunction(); + selectFunction = new PlaywrightFunctions.SelectFunction(); + hoverFunction = new PlaywrightFunctions.HoverFunction(); + evaluateFunction = new PlaywrightFunctions.EvaluateFunction(); + getFunction = new PlaywrightFunctions.GetFunction(); + postFunction = new PlaywrightFunctions.PostFunction(); + putFunction = new PlaywrightFunctions.PutFunction(); + patchFunction = new PlaywrightFunctions.PatchFunction(); + deleteFunction = new PlaywrightFunctions.DeleteFunction(); + } + + @AfterEach + void tearDown() { + // 清理浏览器资源 + try { + if (PlaywrightFunctions.getPage() != null) { + PlaywrightFunctions.getPage().close(); + } + if (PlaywrightFunctions.getBrowser() != null) { + PlaywrightFunctions.getBrowser().close(); + } + if (PlaywrightFunctions.getPlaywright() != null) { + PlaywrightFunctions.getPlaywright().close(); + } + } catch (Exception e) { + // 忽略清理时的异常 + } + } + + @Test + void testNavigate() { + Map args = new HashMap<>(); + args.put("url", baseUrl); + args.put("width", 1024); + args.put("height", 768); + args.put("waitUntil", "load"); + + McpSchema.CallToolResult result = navigateFunction.apply(args); + assertFalse(result.isError()); + assertTrue(result.content().get(0).toString().contains("Navigated to")); + } + + @Test + void testScreenshot() { + // 先导航到页面 + testNavigate(); + + Map args = new HashMap<>(); + args.put("name", "test-screenshot"); + args.put("storeBase64", true); + + McpSchema.CallToolResult result = screenshotFunction.apply(args); + assertFalse(result.isError()); + assertTrue(result.content().size() > 0); + } + + @Test + void testClick() { + // 先导航到页面 + testNavigate(); + + Map args = new HashMap<>(); + args.put("selector", "button"); + + McpSchema.CallToolResult result = clickFunction.apply(args); + // 如果页面上没有按钮,这里会失败 + assertTrue(result.isError()); + } + + @Test + void testFill() { + // 先导航到页面 + testNavigate(); + + Map args = new HashMap<>(); + args.put("selector", "input"); + args.put("value", "test value"); + + McpSchema.CallToolResult result = fillFunction.apply(args); + // 如果页面上没有输入框,这里会失败 + assertTrue(result.isError()); + } + + @Test + void testEvaluate() { + // 先导航到页面 + testNavigate(); + + Map args = new HashMap<>(); + args.put("script", "document.title"); + + McpSchema.CallToolResult result = evaluateFunction.apply(args); + assertFalse(result.isError()); + } + + @Test + void testGet() { + Map args = new HashMap<>(); + args.put("url", baseUrl); + + McpSchema.CallToolResult result = getFunction.apply(args); + assertFalse(result.isError()); + assertTrue(result.content().get(0).toString().contains("Performed GET Operation")); + } + + @Test + void testPost() { + Map args = new HashMap<>(); + args.put("url", baseUrl); + args.put("value", "{\"test\":\"value\"}"); + + McpSchema.CallToolResult result = postFunction.apply(args); + // 如果端点不支持POST,这里会失败 + assertTrue(result.isError()); + } + + @Test + void testPut() { + Map args = new HashMap<>(); + args.put("url", baseUrl); + args.put("value", "{\"test\":\"value\"}"); + + McpSchema.CallToolResult result = putFunction.apply(args); + // 如果端点不支持PUT,这里会失败 + assertTrue(result.isError()); + } + + @Test + void testPatch() { + Map args = new HashMap<>(); + args.put("url", baseUrl); + args.put("value", "{\"test\":\"value\"}"); + + McpSchema.CallToolResult result = patchFunction.apply(args); + // 如果端点不支持PATCH,这里会失败 + assertTrue(result.isError()); + } + + @Test + void testDelete() { + Map args = new HashMap<>(); + args.put("url", baseUrl); + + McpSchema.CallToolResult result = deleteFunction.apply(args); + // 如果端点不支持DELETE,这里会失败 + assertTrue(result.isError()); + } +} From 9ec53d2e02341b400e790cba0a0aaaf792197c90 Mon Sep 17 00:00:00 2001 From: HawickMason <1914991129@qq.com> Date: Wed, 22 Jan 2025 16:14:00 +0800 Subject: [PATCH 4/4] feat(mcp_playwright): add functions to mcp-playwright - 1 --- .../function/PlaywrightFunctions.java | 105 +++++++++++++++++- 1 file changed, 104 insertions(+), 1 deletion(-) diff --git a/jcommon/mcp/mcp-playwright/src/main/java/run/mone/mcp/playwright/function/PlaywrightFunctions.java b/jcommon/mcp/mcp-playwright/src/main/java/run/mone/mcp/playwright/function/PlaywrightFunctions.java index 0aa16ba99..e931a698d 100644 --- a/jcommon/mcp/mcp-playwright/src/main/java/run/mone/mcp/playwright/function/PlaywrightFunctions.java +++ b/jcommon/mcp/mcp-playwright/src/main/java/run/mone/mcp/playwright/function/PlaywrightFunctions.java @@ -3,6 +3,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.microsoft.playwright.*; import com.microsoft.playwright.Page.NavigateOptions; +import com.microsoft.playwright.options.LoadState; import com.microsoft.playwright.options.RequestOptions; import com.microsoft.playwright.options.WaitUntilState; @@ -50,7 +51,7 @@ private static Page ensureBrowser(Integer width, Integer height, Boolean headles browser = playwright.chromium().launch(new BrowserType.LaunchOptions() .setHeadless(headless != null ? headless : false)); } - if (page == null) { + if (page == null || page.isClosed()) { page = browser.newPage(new Browser.NewPageOptions() .setViewportSize(width != null ? width : 1920, height != null ? height : 1080)); @@ -747,4 +748,106 @@ public McpSchema.CallToolResult apply(Map args) { } } + @Data + public static class GetContentFunction implements Function, McpSchema.CallToolResult> { + private String name = "playwright_get_content"; + + private String desc = "Get content from the current page or a specific element"; + + private String toolScheme = """ + { + "type": "object", + "properties": { + "selector": { + "type": "string", + "description": "CSS selector for target element (optional)" + }, + "contentType": { + "type": "string", + "enum": ["text", "html"], + "description": "Type of content to retrieve (default: text)" + }, + "wait": { + "type": "boolean", + "description": "Whether to wait for element to be present (default: true)" + }, + "timeout": { + "type": "number", + "description": "Maximum time to wait in milliseconds (default: 30000)" + }, + "waitForLoadState": { + "type": "string", + "enum": ["load", "domcontentloaded", "networkidle"], + "description": "Wait for specific load state (default: load)" + }, + "waitForSelector": { + "type": "string", + "description": "Additional selector to wait for before getting content (optional)" + } + } + } + """; + + @Override + public McpSchema.CallToolResult apply(Map args) { + try { + Page page = ensureBrowser(null, null, null); + + String selector = args.get("selector") != null ? args.get("selector").toString() : null; + String contentType = args.get("contentType") != null ? args.get("contentType").toString() : "text"; + boolean wait = args.get("wait") != null ? (Boolean)args.get("wait") : true; + int timeout = args.get("timeout") != null ? ((Number)args.get("timeout")).intValue() : 30000; + String waitForLoadState = args.get("waitForLoadState") != null ? + args.get("waitForLoadState").toString() : "load"; + String waitForSelector = args.get("waitForSelector") != null ? + args.get("waitForSelector").toString() : null; + + // 等待页面加载状态 + page.waitForLoadState(LoadState.valueOf(waitForLoadState.toUpperCase()), + new Page.WaitForLoadStateOptions().setTimeout(timeout)); + + // 如果指定了额外的等待选择器 + if (waitForSelector != null) { + page.waitForSelector(waitForSelector, + new Page.WaitForSelectorOptions().setTimeout(timeout)); + } + + String content; + if (selector != null) { + if (wait) { + page.waitForSelector(selector, + new Page.WaitForSelectorOptions().setTimeout(timeout)); + } + ElementHandle element = page.querySelector(selector); + if (element == null) { + return new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent("Element not found: " + selector)), + true + ); + } + content = contentType.equals("html") ? + (String)element.evaluate("el => el.innerHTML") : + element.textContent(); + } else { + content = contentType.equals("html") ? + page.content() : + (String)page.evaluate("() => document.body.innerText"); + } + + return new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent(content)), + false + ); + } catch (Exception e) { + log.error("Failed to get content", e); + return new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent( + "Failed to get content: " + e.getMessage() + )), + true + ); + } + } + } + }