diff --git a/reasoner/thinker/src/main/java/com/antgroup/openspg/reasoner/thinker/engine/InfGraph.java b/reasoner/thinker/src/main/java/com/antgroup/openspg/reasoner/thinker/engine/InfGraph.java index 5328e7c53..ab7ec0718 100644 --- a/reasoner/thinker/src/main/java/com/antgroup/openspg/reasoner/thinker/engine/InfGraph.java +++ b/reasoner/thinker/src/main/java/com/antgroup/openspg/reasoner/thinker/engine/InfGraph.java @@ -209,7 +209,7 @@ private List> prepareElements( starts.put(e.getSubject().alias(), e.getSubject()); } if (CollectionUtils.isEmpty(elements)) { - Triple triple = buildTriple(null, s, e); + Triple triple = bindTriple(null, s, e); if (triple != null) { List> singeRst = prepareElement(null, triple, context); if (CollectionUtils.isNotEmpty(singeRst)) { @@ -219,7 +219,7 @@ private List> prepareElements( } else { List> tmpElements = new LinkedList<>(); for (List evidence : elements) { - Triple triple = buildTriple(evidence, s, e); + Triple triple = bindTriple(evidence, s, e); if (triple != null) { List> singeRst = prepareElement(evidence, triple, context); if (CollectionUtils.isNotEmpty(singeRst)) { @@ -240,43 +240,23 @@ private Set tripleAlias(Triple triple) { return new HashSet<>(Arrays.asList(triple.getSubject().alias(), triple.getObject().alias())); } - private Triple buildTriple(List evidence, Element s, Triple triple) { - Entity entity = null; - Triple trip = null; - if (CollectionUtils.isEmpty(evidence)) { - entity = (Entity) s; - } else { + private Triple bindTriple(List evidence, Element s, Triple triple) { + Map aliasToElement = new HashMap<>(); + if (CollectionUtils.isNotEmpty(evidence)) { for (Result r : evidence) { - Element e = r.getData(); - if (triple.getSubject() instanceof Predicate) { - if (e instanceof Triple - && ((Triple) e).getPredicate().alias() == triple.getSubject().alias()) { - trip = (Triple) e; - } - } else { - if (e instanceof Entity && e.alias() == s.alias()) { - entity = (Entity) r.getData(); - } else if (e instanceof Triple && ((Triple) e).getSubject().alias() == s.alias()) { - entity = (Entity) ((Triple) e).getSubject(); - } else if (e instanceof Triple && ((Triple) e).getObject().alias() == s.alias()) { - entity = (Entity) ((Triple) e).getObject(); - } + Element data = r.getData(); + aliasToElement.put(data.alias(), data); + if (data instanceof Triple) { + aliasToElement.put(((Triple) data).getSubject().alias(), ((Triple) data).getSubject()); + aliasToElement.put(((Triple) data).getObject().alias(), ((Triple) data).getObject()); } } } - - if (entity == null && trip == null) { - return null; - } - if (triple.getSubject().alias() == s.alias()) { - return new Triple(entity, triple.getPredicate(), triple.getObject()); - } else if (triple.getObject().alias() == s.alias()) { - return new Triple(triple.getSubject(), triple.getPredicate(), entity); - } else if (triple.getSubject() instanceof Predicate && trip != null) { - return new Triple(trip, triple.getPredicate(), triple.getObject()); - } else { - return null; - } + aliasToElement.put(s.alias(), s); + Element sub = aliasToElement.getOrDefault(triple.getSubject().alias(), triple.getSubject()); + Element pre = aliasToElement.getOrDefault(triple.getPredicate().alias(), triple.getPredicate()); + Element obj = aliasToElement.getOrDefault(triple.getObject().alias(), triple.getObject()); + return new Triple(sub, pre, obj); } private Map getStart(Triple pattern, Triple head) { diff --git a/reasoner/thinker/src/main/java/com/antgroup/openspg/reasoner/thinker/engine/MemTripleStore.java b/reasoner/thinker/src/main/java/com/antgroup/openspg/reasoner/thinker/engine/MemTripleStore.java index 96229ec5d..07768dcce 100644 --- a/reasoner/thinker/src/main/java/com/antgroup/openspg/reasoner/thinker/engine/MemTripleStore.java +++ b/reasoner/thinker/src/main/java/com/antgroup/openspg/reasoner/thinker/engine/MemTripleStore.java @@ -55,7 +55,11 @@ private Collection findTriple(Triple tripleMatch) { } } } else { - throw new RuntimeException("Cannot support " + t); + for (Triple tri : sToTriple.getOrDefault(t.getSubject(), new LinkedList<>())) { + if (t.matches(tri)) { + elements.add(tri); + } + } } return elements; } diff --git a/reasoner/thinker/src/main/java/com/antgroup/openspg/reasoner/thinker/logic/graph/Entity.java b/reasoner/thinker/src/main/java/com/antgroup/openspg/reasoner/thinker/logic/graph/Entity.java index 05212ad53..4cecd9c4c 100644 --- a/reasoner/thinker/src/main/java/com/antgroup/openspg/reasoner/thinker/logic/graph/Entity.java +++ b/reasoner/thinker/src/main/java/com/antgroup/openspg/reasoner/thinker/logic/graph/Entity.java @@ -138,7 +138,7 @@ public int hashCode() { @Override public Element bind(Element pattern) { - if (pattern instanceof Entity || pattern instanceof Node) { + if (pattern instanceof Entity || pattern instanceof Node || pattern instanceof Any) { return new Entity(this.id, this.type, pattern.alias()); } else { return this; diff --git a/reasoner/thinker/src/main/java/com/antgroup/openspg/reasoner/thinker/logic/graph/Node.java b/reasoner/thinker/src/main/java/com/antgroup/openspg/reasoner/thinker/logic/graph/Node.java index f4795dcc1..3261f81de 100644 --- a/reasoner/thinker/src/main/java/com/antgroup/openspg/reasoner/thinker/logic/graph/Node.java +++ b/reasoner/thinker/src/main/java/com/antgroup/openspg/reasoner/thinker/logic/graph/Node.java @@ -71,8 +71,8 @@ public Element bind(Element pattern) { } else if (pattern instanceof CombinationEntity) { Entity entity = ((CombinationEntity) pattern).getEntityList().get(0); return new Entity(entity.getId(), entity.getType(), entity.getAlias()); - } else if (pattern instanceof Node) { - return new Node(this.type, ((Node) pattern).getAlias()); + } else if (pattern instanceof Node || pattern instanceof Any) { + return new Node(this.type, pattern.alias()); } else { return this; } diff --git a/reasoner/thinker/src/test/java/com/antgroup/openspg/reasoner/thinker/HypertensionTest.java b/reasoner/thinker/src/test/java/com/antgroup/openspg/reasoner/thinker/HypertensionTest.java index 0f82cd851..f3b6bb08b 100644 --- a/reasoner/thinker/src/test/java/com/antgroup/openspg/reasoner/thinker/HypertensionTest.java +++ b/reasoner/thinker/src/test/java/com/antgroup/openspg/reasoner/thinker/HypertensionTest.java @@ -13,6 +13,7 @@ package com.antgroup.openspg.reasoner.thinker; +import com.antgroup.openspg.reasoner.common.constants.Constants; import com.antgroup.openspg.reasoner.common.graph.vertex.IVertexId; import com.antgroup.openspg.reasoner.graphstate.GraphState; import com.antgroup.openspg.reasoner.graphstate.impl.MemGraphState; @@ -67,6 +68,7 @@ public void combinationDrug() { context.put("目标收缩压上界", 140); context.put("BMI", 35); context.put("GFR", 35); + context.put(Constants.SPG_REASONER_THINKER_STRICT, true); List triples = thinker.find(null, new Predicate("基本用药方案"), new Node("药品"), context); Assert.assertTrue(triples.size() == 2); } diff --git a/reasoner/thinker/src/test/java/com/antgroup/openspg/reasoner/thinker/ProofWriterTest.java b/reasoner/thinker/src/test/java/com/antgroup/openspg/reasoner/thinker/ProofWriterTest.java index 9155025a6..f55b0d172 100644 --- a/reasoner/thinker/src/test/java/com/antgroup/openspg/reasoner/thinker/ProofWriterTest.java +++ b/reasoner/thinker/src/test/java/com/antgroup/openspg/reasoner/thinker/ProofWriterTest.java @@ -54,7 +54,7 @@ public void testCase1() { Triple t2 = makeTriple("lion", "iss", "round"); Triple t3 = makeTriple("rabbit", "iss", "kind"); Triple t4 = makeTriple("tiger", "iss", "big"); - Triple t5 = makeTriple("tiger", "iss", "round"); + Triple t5 = makeTriple("tiger", "iss", "kind"); Triple t6 = makeTriple("cow", "needs", "lion"); Triple t7 = makeTriple("cow", "needs", "rabbit"); @@ -67,6 +67,6 @@ public void testCase1() { triples.forEach(t -> context.put(t.toString(), t)); List result = thinker.find(new Node("tiger"), new Predicate("iss"), new Node("young"), context); - Assert.assertTrue(result.size() == 2); + Assert.assertTrue(result.size() == 1); } }