Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement pre-initialized Docker container pool to improve /eval #66

Open
wants to merge 6 commits into
base: develop
Choose a base branch
from
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package org.togetherjava.jshellapi.dto;

import java.io.BufferedReader;
import java.io.BufferedWriter;

/**
* Data record for the state of a container.
*
* @param containerId The id of the container.
* @param containerOutput The output of the container.
* @param containerInput The input of the container.
*/
public record ContainerState(String containerId, BufferedReader containerOutput,
BufferedWriter containerInput) {
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.github.dockerjava.api.DockerClient;
import com.github.dockerjava.api.async.ResultCallback;
import com.github.dockerjava.api.command.InspectContainerResponse;
import com.github.dockerjava.api.command.PullImageResultCallback;
import com.github.dockerjava.api.model.*;
import com.github.dockerjava.core.DefaultDockerClientConfig;
Expand All @@ -10,16 +11,16 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.lang.Nullable;
import org.springframework.stereotype.Service;

import org.togetherjava.jshellapi.Config;
import org.togetherjava.jshellapi.dto.ContainerState;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.*;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.*;

@Service
public class DockerService implements DisposableBean {
Expand All @@ -28,10 +29,17 @@ public class DockerService implements DisposableBean {
private static final UUID WORKER_UNIQUE_ID = UUID.randomUUID();

private final DockerClient client;
private final Config config;
private final ExecutorService executor = Executors.newSingleThreadExecutor();
private final ConcurrentHashMap<StartupScriptId, ContainerState> cachedContainers =
new ConcurrentHashMap<>();
private final StartupScriptsService startupScriptsService;

private final String jshellWrapperBaseImageName;

public DockerService(Config config) {
public DockerService(Config config, StartupScriptsService startupScriptsService)
throws InterruptedException {
this.startupScriptsService = startupScriptsService;
DefaultDockerClientConfig clientConfig =
DefaultDockerClientConfig.createDefaultConfigBuilder().build();
ApacheDockerHttpClient httpClient =
Expand All @@ -41,11 +49,16 @@ public DockerService(Config config) {
.connectionTimeout(Duration.ofSeconds(config.dockerConnectionTimeout()))
.build();
this.client = DockerClientImpl.getInstance(clientConfig, httpClient);
this.config = config;

this.jshellWrapperBaseImageName =
config.jshellWrapperImageName().split(Config.JSHELL_WRAPPER_IMAGE_NAME_TAG)[0];

if (!isImagePresentLocally()) {
pullImage();
}
cleanupLeftovers(WORKER_UNIQUE_ID);
executor.submit(() -> initializeCachedContainer(StartupScriptId.EMPTY));
EmmanuelStan12 marked this conversation as resolved.
Show resolved Hide resolved
}

private void cleanupLeftovers(UUID currentId) {
Expand All @@ -62,48 +75,147 @@ private void cleanupLeftovers(UUID currentId) {
}
}

public String spawnContainer(long maxMemoryMegs, long cpus, @Nullable String cpuSetCpus,
String name, Duration evalTimeout, long sysoutLimit) throws InterruptedException {

boolean presentLocally = client.listImagesCmd()
/**
* Checks if the Docker image with the given name and tag is present locally.
*
* @return true if the image is present, false otherwise.
*/
private boolean isImagePresentLocally() {
return client.listImagesCmd()
.withFilter("reference", List.of(jshellWrapperBaseImageName))
.exec()
.stream()
.flatMap(it -> Arrays.stream(it.getRepoTags()))
.anyMatch(it -> it.endsWith(Config.JSHELL_WRAPPER_IMAGE_NAME_TAG));
}

if (!presentLocally) {
client.pullImageCmd(jshellWrapperBaseImageName)
.withTag("master")
.exec(new PullImageResultCallback())
.awaitCompletion(5, TimeUnit.MINUTES);
}
/**
* Pulls the Docker image.
*/
private void pullImage() throws InterruptedException {
client.pullImageCmd(jshellWrapperBaseImageName)
.withTag("master")
.exec(new PullImageResultCallback())
.awaitCompletion(5, TimeUnit.MINUTES);
}

/**
* Creates a Docker container with the given name.
*
* @param name The name of the container to create.
* @return The ID of the created container.
*/
private String createContainer(String name) {
HostConfig hostConfig = HostConfig.newHostConfig()
EmmanuelStan12 marked this conversation as resolved.
Show resolved Hide resolved
.withAutoRemove(true)
.withInit(true)
.withCapDrop(Capability.ALL)
.withNetworkMode("none")
.withPidsLimit(2000L)
.withReadonlyRootfs(true)
.withMemory((long) config.dockerMaxRamMegaBytes() * 1024 * 1024)
.withCpuCount((long) Math.ceil(config.dockerCPUsUsage()))
.withCpusetCpus(config.dockerCPUSetCPUs());

return client
.createContainerCmd(jshellWrapperBaseImageName + Config.JSHELL_WRAPPER_IMAGE_NAME_TAG)
.withHostConfig(HostConfig.newHostConfig()
.withAutoRemove(true)
.withInit(true)
.withCapDrop(Capability.ALL)
.withNetworkMode("none")
.withPidsLimit(2000L)
.withReadonlyRootfs(true)
.withMemory(maxMemoryMegs * 1024 * 1024)
.withCpuCount(cpus)
.withCpusetCpus(cpuSetCpus))
.withHostConfig(hostConfig)
.withStdinOpen(true)
.withAttachStdin(true)
.withAttachStderr(true)
.withAttachStdout(true)
.withEnv("evalTimeoutSeconds=" + evalTimeout.toSeconds(),
"sysOutCharLimit=" + sysoutLimit)
.withEnv("evalTimeoutSeconds=" + config.evalTimeoutSeconds(),
"sysOutCharLimit=" + config.sysOutCharLimit())
.withLabels(Map.of(WORKER_LABEL, WORKER_UNIQUE_ID.toString()))
.withName(name)
.exec()
.getId();
}

public InputStream startAndAttachToContainer(String containerId, InputStream stdin)
/**
* Spawns a new Docker container with specified configurations.
*
* @param name Name of the container.
* @param startupScriptId Script to initialize the container with.
* @return The ContainerState of the newly created container.
*/
public ContainerState initializeContainer(String name, StartupScriptId startupScriptId)
throws IOException {
EmmanuelStan12 marked this conversation as resolved.
Show resolved Hide resolved
EmmanuelStan12 marked this conversation as resolved.
Show resolved Hide resolved
if (startupScriptId == null || cachedContainers.isEmpty()
EmmanuelStan12 marked this conversation as resolved.
Show resolved Hide resolved
|| !cachedContainers.containsKey(startupScriptId)) {
String containerId = createContainer(name);
return setupContainerWithScript(containerId, startupScriptId);
}
ContainerState containerState = cachedContainers.get(startupScriptId);
executor.submit(() -> initializeCachedContainer(startupScriptId));

client.renameContainerCmd(containerState.containerId()).withName(name).exec();
return containerState;
}

/**
* Initializes a new cached docker container with specified configurations.
*
* @param startupScriptId Script to initialize the container with.
*/
private void initializeCachedContainer(StartupScriptId startupScriptId) {
String containerName = cachedContainerName();
EmmanuelStan12 marked this conversation as resolved.
Show resolved Hide resolved
String id = createContainer(containerName);
startContainer(id);

try {
ContainerState containerState = setupContainerWithScript(id, startupScriptId);
cachedContainers.put(startupScriptId, containerState);
} catch (IOException e) {
killContainerByName(containerName);
EmmanuelStan12 marked this conversation as resolved.
Show resolved Hide resolved
throw new RuntimeException(e);
}
}

/**
EmmanuelStan12 marked this conversation as resolved.
Show resolved Hide resolved
* @param containerId The id of the container
* @param startupScriptId The startup script id of the session
* @return ContainerState of the spawned container.
* @throws IOException if an I/O error occurs
*/
private ContainerState setupContainerWithScript(String containerId,
StartupScriptId startupScriptId) throws IOException {
startContainer(containerId);
EmmanuelStan12 marked this conversation as resolved.
Show resolved Hide resolved
PipedInputStream containerInput = new PipedInputStream();
BufferedWriter writer =
new BufferedWriter(new OutputStreamWriter(new PipedOutputStream(containerInput)));

InputStream containerOutput = attachToContainer(containerId, containerInput);
BufferedReader reader = new BufferedReader(new InputStreamReader(containerOutput));

writer.write(Utils.sanitizeStartupScript(startupScriptsService.get(startupScriptId)));
writer.newLine();
writer.flush();

return new ContainerState(containerId, reader, writer);
}

/**
* Creates a new container
*
* @param containerId the ID of the container to start
*/
private void startContainer(String containerId) {
if (!isContainerRunning(containerId)) {
EmmanuelStan12 marked this conversation as resolved.
Show resolved Hide resolved
client.startContainerCmd(containerId).exec();
}
}

/**
* Attaches to a running Docker container's input (stdin) and output streams (stdout, stderr).
* Logs any output from stderr and returns an InputStream to read stdout.
*
* @param containerId The ID of the running container to attach to.
* @param containerInput The input stream (containerInput) to send to the container.
* @return InputStream to read the container's stdout
* @throws IOException if an I/O error occurs
*/
private InputStream attachToContainer(String containerId, InputStream containerInput)
throws IOException {
PipedInputStream pipeIn = new PipedInputStream();
PipedOutputStream pipeOut = new PipedOutputStream(pipeIn);
Expand All @@ -113,15 +225,15 @@ public InputStream startAndAttachToContainer(String containerId, InputStream std
.withFollowStream(true)
.withStdOut(true)
.withStdErr(true)
.withStdIn(stdin)
.withStdIn(containerInput)
.exec(new ResultCallback.Adapter<>() {
@Override
public void onNext(Frame object) {
try {
String payloadString =
new String(object.getPayload(), StandardCharsets.UTF_8);
if (object.getStreamType() == StreamType.STDOUT) {
pipeOut.write(object.getPayload());
pipeOut.write(object.getPayload()); // Write stdout data to pipeOut
} else {
LOGGER.warn("Received STDERR from container {}: {}", containerId,
payloadString);
Expand All @@ -131,11 +243,24 @@ public void onNext(Frame object) {
}
}
});

client.startContainerCmd(containerId).exec();
return pipeIn;
}

/**
* Checks if the Docker container with the given ID is currently running.
*
* @param containerId the ID of the container to check
* @return true if the container is running, false otherwise
*/
public boolean isContainerRunning(String containerId) {
InspectContainerResponse containerResponse = client.inspectContainerCmd(containerId).exec();
return Boolean.TRUE.equals(containerResponse.getState().getRunning());
}

private String cachedContainerName() {
EmmanuelStan12 marked this conversation as resolved.
Show resolved Hide resolved
return "cached_session_" + UUID.randomUUID();
}

public void killContainerByName(String name) {
LOGGER.debug("Fetching container to kill {}.", name);
List<Container> containers = client.listContainersCmd().withNameFilter(Set.of(name)).exec();
Expand All @@ -156,6 +281,7 @@ public boolean isDead(String containerName) {
@Override
public void destroy() throws Exception {
LOGGER.info("destroy() called. Destroying all containers...");
executor.shutdown();
cleanupLeftovers(UUID.randomUUID());
client.close();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.lang.Nullable;

import org.togetherjava.jshellapi.Config;
import org.togetherjava.jshellapi.dto.*;
import org.togetherjava.jshellapi.exceptions.DockerException;

import java.io.*;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
Expand All @@ -31,33 +30,26 @@ public class JShellService {
private final int startupScriptSize;

public JShellService(DockerService dockerService, JShellSessionService sessionService,
String id, long timeout, boolean renewable, long evalTimeout,
long evalTimeoutValidationLeeway, int sysOutCharLimit, int maxMemory, double cpus,
@Nullable String cpuSetCpus, String startupScript) throws DockerException {
SessionInfo sessionInfo, Config config) throws DockerException {
this.dockerService = dockerService;
this.sessionService = sessionService;
this.id = id;
this.timeout = timeout;
this.renewable = renewable;
this.evalTimeout = evalTimeout;
this.evalTimeoutValidationLeeway = evalTimeoutValidationLeeway;
this.id = sessionInfo.id();
this.timeout = config.dockerConnectionTimeout();
this.renewable = sessionInfo.renewable();
this.evalTimeout = sessionInfo.evalTimeout();
this.evalTimeoutValidationLeeway = sessionInfo.evalTimeoutValidationLeeway();
this.lastTimeoutUpdate = Instant.now();

if (!dockerService.isDead(containerName())) {
LOGGER.warn("Tried to create an existing container {}.", containerName());
throw new DockerException("The session isn't completely destroyed, try again later.");
}

try {
String containerId = dockerService.spawnContainer(maxMemory, (long) Math.ceil(cpus),
cpuSetCpus, containerName(), Duration.ofSeconds(evalTimeout), sysOutCharLimit);
PipedInputStream containerInput = new PipedInputStream();
this.writer = new BufferedWriter(
new OutputStreamWriter(new PipedOutputStream(containerInput)));
InputStream containerOutput =
dockerService.startAndAttachToContainer(containerId, containerInput);
reader = new BufferedReader(new InputStreamReader(containerOutput));
writer.write(sanitize(startupScript));
writer.newLine();
writer.flush();
ContainerState containerState = dockerService.initializeContainer(containerName(),
sessionInfo.startupScriptId());
this.writer = containerState.containerInput();
this.reader = containerState.containerOutput();
checkContainerOK();
startupScriptSize = Integer.parseInt(reader.readLine());
} catch (Exception e) {
Expand Down Expand Up @@ -127,7 +119,7 @@ private JShellResult readResult() throws IOException, NumberFormatException, Doc
int errorCount = Integer.parseInt(reader.readLine());
List<String> errors = new ArrayList<>();
for (int i = 0; i < errorCount; i++) {
errors.add(desanitize(reader.readLine()));
errors.add(Utils.deSanitizeStartupScript((reader.readLine())));
}
yield new JShellEvalAbortionCause.CompileTimeErrorAbortionCause(errors);
}
Expand All @@ -140,7 +132,7 @@ private JShellResult readResult() throws IOException, NumberFormatException, Doc
abortion = new JShellEvalAbortion(causeSource, remainingSource, abortionCause);
}
boolean stdoutOverflow = Boolean.parseBoolean(reader.readLine());
String stdout = desanitize(reader.readLine());
String stdout = Utils.deSanitizeStartupScript(reader.readLine());
return new JShellResult(snippetResults, abortion, stdoutOverflow, stdout);
}

Expand Down Expand Up @@ -282,14 +274,6 @@ private void stopOperation() {
doingOperation = false;
}

private static String sanitize(String s) {
return s.replace("\\", "\\\\").replace("\n", "\\n");
}

private static String desanitize(String text) {
return text.replace("\\n", "\n").replace("\\\\", "\\");
}

private static String cleanCode(String code) {
return code.translateEscapes();
}
Expand Down
Loading