Skip to content

Commit

Permalink
change to applyConfiguration
Browse files Browse the repository at this point in the history
Signed-off-by: owenowenisme <[email protected]>
  • Loading branch information
owenowenisme committed Jan 9, 2025
1 parent d705cef commit 4a952cb
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 15 deletions.
96 changes: 81 additions & 15 deletions ray-operator/test/e2e/raycluster_gcs_ft_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,18 @@ import (
"testing"

. "github.com/onsi/gomega"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
"github.com/ray-project/kuberay/ray-operator/controllers/ray/common"
"github.com/ray-project/kuberay/ray-operator/controllers/ray/utils"
rayv1ac "github.com/ray-project/kuberay/ray-operator/pkg/client/applyconfiguration/ray/v1"
. "github.com/ray-project/kuberay/ray-operator/test/support"

corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
appsv1ac "k8s.io/client-go/applyconfigurations/apps/v1"
corev1ac "k8s.io/client-go/applyconfigurations/core/v1"
metav1ac "k8s.io/client-go/applyconfigurations/meta/v1"
)

func TestRayClusterGCSFaultTolerence(t *testing.T) {
Expand All @@ -18,22 +24,74 @@ func TestRayClusterGCSFaultTolerence(t *testing.T) {

// Create a namespace
namespace := test.NewTestNamespace()
testScriptAC := newConfigMap(namespace.Name, files(test, "test_detached_actor_1.py", "test_detached_actor_2.py"))

testScript, err := test.Client().Core().CoreV1().ConfigMaps(namespace.Name).Apply(test.Ctx(), testScriptAC, TestApplyOptions)
g.Expect(err).NotTo(HaveOccurred())

_, err = test.Client().Core().AppsV1().Deployments(namespace.Name).Apply(
test.Ctx(),
appsv1ac.Deployment("redis", namespace.Name).
WithSpec(appsv1ac.DeploymentSpec().
WithReplicas(1).
WithSelector(metav1ac.LabelSelector().WithMatchLabels(map[string]string{"app": "redis"})).
WithTemplate(corev1ac.PodTemplateSpec().
WithLabels(map[string]string{"app": "redis"}).
WithSpec(corev1ac.PodSpec().
WithContainers(corev1ac.Container().
WithName("redis").
WithImage("redis:7.4").
WithPorts(corev1ac.ContainerPort().WithContainerPort(6379)),
),
),
),
),
TestApplyOptions,
)
g.Expect(err).NotTo(HaveOccurred())

test.T().Log("Creating Cluster for GCSFaultTolerence testing.")
yamlFilePath := "testdata/ray-cluster.ray-ft.yaml"
rayClusterFromYaml := DeserializeRayClusterYAML(test, yamlFilePath)
KubectlApplyYAML(test, yamlFilePath, namespace.Name)
_, err = test.Client().Core().CoreV1().Services(namespace.Name).Apply(
test.Ctx(),
corev1ac.Service("redis", namespace.Name).
WithSpec(corev1ac.ServiceSpec().
WithSelector(map[string]string{"app": "redis"}).
WithPorts(corev1ac.ServicePort().
WithPort(6379),
),
),
TestApplyOptions,
)
g.Expect(err).NotTo(HaveOccurred())

rayClusterAC := rayv1ac.RayCluster("raycluster-gcsft", namespace.Name).WithSpec(
newRayClusterSpec((mountConfigMap[rayv1ac.RayClusterSpecApplyConfiguration](testScript, "/home/ray/samples"))).WithGcsFaultToleranceOptions(
rayv1ac.GcsFaultToleranceOptions().
WithRedisAddress("redis:6379"),
),
)

rayCluster, err := GetRayCluster(test, namespace.Name, rayClusterFromYaml.Name)
rayCluster, err := test.Client().Ray().RayV1().RayClusters(namespace.Name).Apply(test.Ctx(), rayClusterAC, TestApplyOptions)
g.Expect(err).NotTo(HaveOccurred())
g.Expect(rayCluster).NotTo(BeNil())
test.T().Logf("Created RayCluster %s/%s successfully", rayCluster.Namespace, rayCluster.Name)

// Make sure the RAY_REDIS_ADDRESS env is set on the Head Pod.
g.Eventually(func(g Gomega) bool {
rayCluster, err := test.Client().Ray().RayV1().RayClusters(namespace.Name).Apply(test.Ctx(), rayClusterAC, TestApplyOptions)
g.Expect(err).NotTo(HaveOccurred())
if rayCluster.Status.Head.PodName != "" {
headPod, err := test.Client().Core().CoreV1().Pods(namespace.Name).Get(test.Ctx(), rayCluster.Status.Head.PodName, metav1.GetOptions{})
g.Expect(err).NotTo(HaveOccurred())
return utils.EnvVarExists(utils.RAY_REDIS_ADDRESS, headPod.Spec.Containers[utils.RayContainerIndex].Env)
}
return false
}, TestTimeoutMedium).Should(BeTrue())

test.T().Logf("Waiting for RayCluster %s/%s to become ready", rayCluster.Namespace, rayCluster.Name)
g.Eventually(RayCluster(test, rayCluster.Namespace, rayCluster.Name), TestTimeoutMedium).
Should(WithTransform(RayClusterState, Equal(rayv1.Ready)))
g.Eventually(RayCluster(test, namespace.Name, rayCluster.Name), TestTimeoutMedium).
Should(WithTransform(StatusCondition(rayv1.RayClusterProvisioned), MatchCondition(metav1.ConditionTrue, rayv1.AllPodRunningAndReadyFirstTime)))

test.T().Run("Test Detached Actor", func(_ *testing.T) {
headPod, err := GetHeadPod(test, rayClusterFromYaml)
headPod, err := GetHeadPod(test, rayCluster)
g.Expect(err).NotTo(HaveOccurred())

test.T().Logf("HeadPod Name: %s", headPod.Name)
Expand All @@ -52,15 +110,19 @@ func TestRayClusterGCSFaultTolerence(t *testing.T) {

// Restart count should eventually become 1, not creating a new pod
HeadPodRestartCount := func(p *corev1.Pod) int32 { return p.Status.ContainerStatuses[0].RestartCount }
g.Eventually(HeadPod(test, rayCluster)).
HeadPodContainerReady := func(p *corev1.Pod) bool { return p.Status.ContainerStatuses[0].Ready }

g.Eventually(HeadPod(test, rayCluster), TestTimeoutMedium).
Should(WithTransform(HeadPodRestartCount, Equal(int32(1))))
g.Eventually(HeadPod(test, rayCluster), TestTimeoutMedium).
Should(WithTransform(HeadPodContainerReady, Equal(true)))

// Pod status should eventually become Running
PodState := func(p *corev1.Pod) string { return string(p.Status.Phase) }
g.Eventually(HeadPod(test, rayCluster)).
Should(WithTransform(PodState, Equal("Running")))

headPod, err = GetHeadPod(test, rayClusterFromYaml)
headPod, err = GetHeadPod(test, rayCluster)
g.Expect(err).NotTo(HaveOccurred())

expectedOutput := "3"
Expand All @@ -74,11 +136,15 @@ func TestRayClusterGCSFaultTolerence(t *testing.T) {
g.Expect(err).NotTo(HaveOccurred())
// Will get 2 head pods while one is terminating and another is creating, so wait until one is left
g.Eventually(func() error {
_, err := GetHeadPod(test, rayClusterFromYaml)
_, err := GetHeadPod(test, rayCluster)
return err
}, TestTimeoutMedium).ShouldNot(HaveOccurred())

headPod, err = GetHeadPod(test, rayClusterFromYaml)
g.Eventually(HeadPod(test, rayCluster)).
Should(WithTransform(PodState, Equal("Running")))

// Then get new head pod and run verification
headPod, err = GetHeadPod(test, rayCluster)
g.Expect(err).NotTo(HaveOccurred())
expectedOutput = "4"
ExecPodCmd(test, headPod, common.RayHeadContainer, []string{"python", "samples/test_detached_actor_2.py", rayNamespace, expectedOutput})
Expand Down
21 changes: 21 additions & 0 deletions ray-operator/test/e2e/test_detached_actor_1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import ray
import sys

# Initialize Ray, using the first command-line argument as the Ray namespace.
ray.init(namespace=sys.argv[1])

@ray.remote
class TestCounter:
    """Ray actor wrapping a single integer counter that starts at zero."""

    def __init__(self):
        # The counter begins at zero; increment() is the only method that changes it.
        self.value = 0

    def increment(self):
        """Add one to the counter and return the updated value."""
        self.value = self.value + 1
        return self.value

# Create the actor under the name "testCounter" with a detached lifetime.
counter = TestCounter.options(
    name="testCounter",
    lifetime="detached",
    max_restarts=-1,
).remote()

# Two sequential increments on the freshly created counter must yield 1, then 2.
val1 = ray.get(counter.increment.remote())
val2 = ray.get(counter.increment.remote())
print(f"val1: {val1}, val2: {val2}")

assert val1 == 1
assert val2 == 2
35 changes: 35 additions & 0 deletions ray-operator/test/e2e/test_detached_actor_2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import ray
import time
import sys

def retry_with_timeout(func, timeout=90, interval=1):
    """Call ``func`` repeatedly until it succeeds or ``timeout`` seconds elapse.

    Args:
        func: Zero-argument callable to invoke.
        timeout: Overall time budget in seconds, measured on a monotonic clock.
        interval: Seconds to wait after a failed attempt before trying again.

    Returns:
        Whatever ``func`` returns on its first successful call. A successful
        call returns immediately, with no trailing sleep.

    Raises:
        Exception: The error raised by the last failed attempt, once the time
            budget is exhausted. ``KeyboardInterrupt`` and ``SystemExit`` are
            not retried and propagate immediately.
    """
    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout
    attempt = 0
    # ``func`` is always tried at least once, so there is always a real
    # exception to re-raise when the budget runs out.
    while True:
        print(f"retry iter: {attempt}", flush=True)
        attempt += 1
        try:
            return func()
        except Exception as e:
            err = e
        # Back off only after a failure; success must not be delayed.
        time.sleep(interval)
        if time.monotonic() > deadline:
            raise err

def get_detached_actor():
    """Look up the actor named "testCounter" and return what ray.get_actor gives back."""
    handle = ray.get_actor("testCounter")
    return handle

def _connect():
    # The namespace comes from the first command-line argument, read on each attempt.
    return ray.init(address='ray://127.0.0.1:10001', namespace=sys.argv[1])


def _increment():
    # Uses the module-level ``counter`` handle obtained below.
    return ray.get(counter.increment.remote())


# Step 1: connect to the Ray cluster, retrying for up to 180 seconds.
print("Try to connect to Ray cluster.", flush=True)
retry_with_timeout(_connect, timeout=180)

# Step 2: fetch the TestCounter actor by name.
print("Get TestCounter actor.", flush=True)
counter = retry_with_timeout(get_detached_actor)

# Step 3: bump the counter once and report the value it returns.
print("Try to call remote function \'increment\'.", flush=True)
val = retry_with_timeout(_increment)
print(f"val: {val}", flush=True)

# The second command-line argument is the counter value the caller expects.
assert val == int(sys.argv[2])

0 comments on commit 4a952cb

Please sign in to comment.