Skip to content

Commit

Permalink
feat(reasoner): support ProntoQA and ProofWriter.
Browse files Browse the repository at this point in the history
  • Loading branch information
fishjoy committed Sep 11, 2024
1 parent dd44a64 commit a70392b
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ private List<List<Result>> prepareElements(
starts.put(e.getSubject().alias(), e.getSubject());
}
if (CollectionUtils.isEmpty(elements)) {
Triple triple = buildTriple(null, s, e);
Triple triple = bindTriple(null, s, e);
if (triple != null) {
List<List<Result>> singeRst = prepareElement(null, triple, context);
if (CollectionUtils.isNotEmpty(singeRst)) {
Expand All @@ -219,7 +219,7 @@ private List<List<Result>> prepareElements(
} else {
List<List<Result>> tmpElements = new LinkedList<>();
for (List<Result> evidence : elements) {
Triple triple = buildTriple(evidence, s, e);
Triple triple = bindTriple(evidence, s, e);
if (triple != null) {
List<List<Result>> singeRst = prepareElement(evidence, triple, context);
if (CollectionUtils.isNotEmpty(singeRst)) {
Expand All @@ -240,43 +240,23 @@ private Set<String> tripleAlias(Triple triple) {
return new HashSet<>(Arrays.asList(triple.getSubject().alias(), triple.getObject().alias()));
}

private Triple buildTriple(List<Result> evidence, Element s, Triple triple) {
Entity entity = null;
Triple trip = null;
if (CollectionUtils.isEmpty(evidence)) {
entity = (Entity) s;
} else {
private Triple bindTriple(List<Result> evidence, Element s, Triple triple) {
Map<String, Element> aliasToElement = new HashMap<>();
if (CollectionUtils.isNotEmpty(evidence)) {
for (Result r : evidence) {
Element e = r.getData();
if (triple.getSubject() instanceof Predicate) {
if (e instanceof Triple
&& ((Triple) e).getPredicate().alias() == triple.getSubject().alias()) {
trip = (Triple) e;
}
} else {
if (e instanceof Entity && e.alias() == s.alias()) {
entity = (Entity) r.getData();
} else if (e instanceof Triple && ((Triple) e).getSubject().alias() == s.alias()) {
entity = (Entity) ((Triple) e).getSubject();
} else if (e instanceof Triple && ((Triple) e).getObject().alias() == s.alias()) {
entity = (Entity) ((Triple) e).getObject();
}
Element data = r.getData();
aliasToElement.put(data.alias(), data);
if (data instanceof Triple) {
aliasToElement.put(((Triple) data).getSubject().alias(), ((Triple) data).getSubject());
aliasToElement.put(((Triple) data).getObject().alias(), ((Triple) data).getObject());
}
}
}

if (entity == null && trip == null) {
return null;
}
if (triple.getSubject().alias() == s.alias()) {
return new Triple(entity, triple.getPredicate(), triple.getObject());
} else if (triple.getObject().alias() == s.alias()) {
return new Triple(triple.getSubject(), triple.getPredicate(), entity);
} else if (triple.getSubject() instanceof Predicate && trip != null) {
return new Triple(trip, triple.getPredicate(), triple.getObject());
} else {
return null;
}
aliasToElement.put(s.alias(), s);
Element sub = aliasToElement.getOrDefault(triple.getSubject().alias(), triple.getSubject());
Element pre = aliasToElement.getOrDefault(triple.getPredicate().alias(), triple.getPredicate());
Element obj = aliasToElement.getOrDefault(triple.getObject().alias(), triple.getObject());
return new Triple(sub, pre, obj);
}

private Map<String, Element> getStart(Triple pattern, Triple head) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,11 @@ private Collection<Element> findTriple(Triple tripleMatch) {
}
}
} else {
throw new RuntimeException("Cannot support " + t);
for (Triple tri : sToTriple.getOrDefault(t.getSubject(), new LinkedList<>())) {
if (t.matches(tri)) {
elements.add(tri);
}
}
}
return elements;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ public int hashCode() {

@Override
public Element bind(Element pattern) {
if (pattern instanceof Entity || pattern instanceof Node) {
if (pattern instanceof Entity || pattern instanceof Node || pattern instanceof Any) {
return new Entity(this.id, this.type, pattern.alias());
} else {
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ public Element bind(Element pattern) {
} else if (pattern instanceof CombinationEntity) {
Entity entity = ((CombinationEntity) pattern).getEntityList().get(0);
return new Entity(entity.getId(), entity.getType(), entity.getAlias());
} else if (pattern instanceof Node) {
return new Node(this.type, ((Node) pattern).getAlias());
} else if (pattern instanceof Node || pattern instanceof Any) {
return new Node(this.type, pattern.alias());
} else {
return this;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

package com.antgroup.openspg.reasoner.thinker;

import com.antgroup.openspg.reasoner.common.constants.Constants;
import com.antgroup.openspg.reasoner.common.graph.vertex.IVertexId;
import com.antgroup.openspg.reasoner.graphstate.GraphState;
import com.antgroup.openspg.reasoner.graphstate.impl.MemGraphState;
Expand Down Expand Up @@ -67,6 +68,7 @@ public void combinationDrug() {
context.put("目标收缩压上界", 140);
context.put("BMI", 35);
context.put("GFR", 35);
context.put(Constants.SPG_REASONER_THINKER_STRICT, true);
List<Result> triples = thinker.find(null, new Predicate("基本用药方案"), new Node("药品"), context);
Assert.assertTrue(triples.size() == 2);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public void testCase1() {
Triple t2 = makeTriple("lion", "iss", "round");
Triple t3 = makeTriple("rabbit", "iss", "kind");
Triple t4 = makeTriple("tiger", "iss", "big");
Triple t5 = makeTriple("tiger", "iss", "round");
Triple t5 = makeTriple("tiger", "iss", "kind");

Triple t6 = makeTriple("cow", "needs", "lion");
Triple t7 = makeTriple("cow", "needs", "rabbit");
Expand All @@ -67,6 +67,6 @@ public void testCase1() {
triples.forEach(t -> context.put(t.toString(), t));
List<Result> result =
thinker.find(new Node("tiger"), new Predicate("iss"), new Node("young"), context);
Assert.assertTrue(result.size() == 2);
Assert.assertTrue(result.size() == 1);
}
}

0 comments on commit a70392b

Please sign in to comment.