This is a Triton Java API contributed by the Alibaba Cloud PAI Team. It is built on Triton's HTTP/REST protocol and is designed for both ease of use and performance.
This Java API mimics Triton's official Python API. It has similar classes and methods.
- triton.client.InferInput describes each input to the model.
- triton.client.InferRequestedOutput describes each output requested from the model.
- triton.client.InferenceServerClient is the main inference class.
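Conceptually, an input descriptor bundles a tensor name, a shape, and a Triton datatype, and the data attached to it must contain exactly as many elements as the shape implies (a shape of {1, 32} needs 32 values). The sketch below illustrates that relationship with a hypothetical TensorSpec class; it is not part of the Triton Java API and does not reflect how InferInput is implemented.

```java
import java.util.Arrays;

public class TensorSpecDemo {

    // Hypothetical stand-in for the metadata an input descriptor carries
    // (name, shape, datatype); NOT the real triton.client.InferInput class.
    static final class TensorSpec {
        final String name;
        final long[] shape;
        final String datatype;

        TensorSpec(String name, long[] shape, String datatype) {
            this.name = name;
            this.shape = shape.clone();
            this.datatype = datatype;
        }

        // Number of elements implied by the shape, e.g. {1, 32} -> 32.
        long elementCount() {
            long count = 1;
            for (long dim : shape) {
                if (dim < 0) {
                    throw new IllegalArgumentException(
                            "request shapes must be concrete: " + Arrays.toString(shape));
                }
                count = Math.multiplyExact(count, dim);
            }
            return count;
        }

        // Rejects data whose length does not match the declared shape.
        void checkData(int[] data) {
            long expected = elementCount();
            if (data.length != expected) {
                throw new IllegalArgumentException(String.format(
                        "%s: shape %s needs %d elements, got %d",
                        name, Arrays.toString(shape), expected, data.length));
            }
        }
    }

    public static void main(String[] args) {
        TensorSpec inputIds = new TensorSpec("input_ids", new long[] {1, 32}, "INT32");
        int[] data = new int[32];
        Arrays.fill(data, 1);
        inputIds.checkData(data); // OK: 1 * 32 == 32
        // Prints "input_ids: 32 elements of INT32".
        System.out.println(inputIds.name + ": " + inputIds.elementCount()
                + " elements of " + inputIds.datatype);
        // A 31-element array would throw IllegalArgumentException instead.
    }
}
```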
Currently the Java API supports only a subset of the entire Triton protocol. Specifically:
- Only the HTTP protocol is supported; gRPC is not.
- Only synchronous inference requests are supported; asynchronous and streaming inference requests are not.
- Health, metadata, statistics, model-management, and other extensions are not supported.
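Because only synchronous HTTP requests are supported, each inference boils down to a single POST to the server's KServe v2 endpoint, /v2/models/<model_name>/infer, whose JSON body lists the inputs and the requested outputs. The sketch below builds such a request using only the JDK's java.net.http package (Java 11+). It assumes the plain JSON data encoding (no binary data), a server at localhost:8000, and a model named roberta; it only builds the request without sending it, and it illustrates the wire protocol rather than the Java API's internals.

```java
import java.net.URI;
import java.net.http.HttpRequest;
import java.util.Arrays;
import java.util.stream.Collectors;

public class V2JsonRequestSketch {

    // Builds a KServe v2 / Triton HTTP inference body with one INT32 input
    // (JSON data encoding) and one requested output.
    static String buildBody(String inputName, long[] shape, int[] data, String outputName) {
        String shapeJson = Arrays.stream(shape).mapToObj(Long::toString)
                .collect(Collectors.joining(",", "[", "]"));
        String dataJson = Arrays.stream(data).mapToObj(Integer::toString)
                .collect(Collectors.joining(",", "[", "]"));
        return "{\"inputs\":[{\"name\":\"" + inputName + "\",\"shape\":" + shapeJson
                + ",\"datatype\":\"INT32\",\"data\":" + dataJson + "}],"
                + "\"outputs\":[{\"name\":\"" + outputName + "\"}]}";
    }

    // A synchronous inference is one POST to /v2/models/<model>/infer.
    static HttpRequest buildRequest(String hostPort, String model, String body) {
        return HttpRequest.newBuilder()
                .uri(URI.create("http://" + hostPort + "/v2/models/" + model + "/infer"))
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(body))
                .build();
    }

    public static void main(String[] args) {
        String body = buildBody("input_ids", new long[] {1, 4}, new int[] {1, 1, 1, 1}, "logits");
        HttpRequest request = buildRequest("localhost:8000", "roberta", body);
        // Prints "POST http://localhost:8000/v2/models/roberta/infer".
        System.out.println(request.method() + " " + request.uri());
        System.out.println(body);
        // Sending it synchronously would look like:
        // HttpClient.newHttpClient().send(request, HttpResponse.BodyHandlers.ofString());
    }
}
```

Hand-building the JSON string keeps the sketch dependency-free; real code would normally use a JSON library to handle escaping of names.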
A minimal example looks like this:
```java
package triton.client.example;

import java.util.Arrays;
import java.util.List;

import com.google.common.collect.Lists;

import triton.client.InferInput;
import triton.client.InferRequestedOutput;
import triton.client.InferResult;
import triton.client.InferenceServerClient;
import triton.client.pojo.DataType;

public class MinExample {
    public static void main(String[] args) throws Exception {
        // true: send and receive tensor data in binary form instead of JSON arrays.
        boolean isBinary = true;

        // Input 1: token IDs, shape [1, 32], INT32.
        InferInput inputIds = new InferInput("input_ids", new long[] {1, 32}, DataType.INT32);
        int[] inputIdsData = new int[32];
        Arrays.fill(inputIdsData, 1); // Placeholder data.
        inputIds.setData(inputIdsData, isBinary);

        // Input 2: attention mask, shape [1, 32], INT32.
        InferInput inputMask = new InferInput("input_mask", new long[] {1, 32}, DataType.INT32);
        int[] inputMaskData = new int[32];
        Arrays.fill(inputMaskData, 1);
        inputMask.setData(inputMaskData, isBinary);

        // Input 3: segment IDs, shape [1, 32], INT32.
        InferInput segmentIds = new InferInput("segment_ids", new long[] {1, 32}, DataType.INT32);
        int[] segmentIdsData = new int[32];
        Arrays.fill(segmentIdsData, 0);
        segmentIds.setData(segmentIdsData, isBinary);

        List<InferInput> inputs = Lists.newArrayList(inputIds, inputMask, segmentIds);
        // Request the "logits" output from the model.
        List<InferRequestedOutput> outputs = Lists.newArrayList(new InferRequestedOutput("logits", isBinary));

        // Server address (host:port) followed by two timeout settings.
        InferenceServerClient client = new InferenceServerClient("0.0.0.0:8000", 5000, 5000);
        // Run a synchronous inference against the "roberta" model.
        InferResult result = client.infer("roberta", inputs, outputs);
        float[] logits = result.getOutputAsFloat("logits");
        System.out.println(Arrays.toString(logits));
    }
}
```
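Setting isBinary to true selects Triton's binary tensor data extension: instead of a JSON array, each tensor's raw bytes are appended after the JSON header (whose length is sent in the Inference-Header-Content-Length HTTP header), and the input's parameters carry binary_data_size. The sketch below shows only the byte-level conversion the example relies on: INT32 input data to bytes, and bytes back to FP32 for the logits output. It assumes little-endian byte order (the native order of the x86 and ARM hosts Triton typically runs on); the class and method names are hypothetical and independent of the Java API.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;

public class BinaryTensorSketch {

    // Serializes INT32 tensor data into raw little-endian bytes
    // (4 bytes per element, in the row-major order given).
    // For 32 ints this yields 128 bytes, the value binary_data_size would carry.
    static byte[] int32ToBytes(int[] data) {
        ByteBuffer buf = ByteBuffer.allocate(data.length * Integer.BYTES)
                .order(ByteOrder.LITTLE_ENDIAN);
        // The int view inherits the little-endian order set above.
        buf.asIntBuffer().put(data);
        return buf.array();
    }

    // Decodes raw little-endian FP32 bytes (e.g. a binary "logits" output) into floats.
    static float[] bytesToFp32(byte[] raw) {
        if (raw.length % Float.BYTES != 0) {
            throw new IllegalArgumentException(
                    "FP32 payload length must be a multiple of 4, got " + raw.length);
        }
        float[] out = new float[raw.length / Float.BYTES];
        ByteBuffer.wrap(raw).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer().get(out);
        return out;
    }

    public static void main(String[] args) {
        // Prints [1, 0, 0, 0, 2, 0, 0, 0]: least significant byte first.
        System.out.println(Arrays.toString(int32ToBytes(new int[] {1, 2})));

        // Simulate a binary FP32 output payload, then decode it.
        byte[] payload = ByteBuffer.allocate(2 * Float.BYTES)
                .order(ByteOrder.LITTLE_ENDIAN)
                .putFloat(0.5f)
                .putFloat(-1.25f)
                .array();
        // Prints [0.5, -1.25].
        System.out.println(Arrays.toString(bytesToFp32(payload)));
    }
}
```

Binary transfer avoids formatting and parsing every number as text, which is why it is the usual choice for larger tensors.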
See the examples directory for more examples.