Skip to content

Commit

Permalink
feat(controller): add domain names for object storage (#3029)
Browse files Browse the repository at this point in the history
  • Loading branch information
anda-ren authored Dec 6, 2023
1 parent de5e14d commit 145d209
Show file tree
Hide file tree
Showing 22 changed files with 969 additions and 65 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright 2022 Starwhale, Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai.starwhale.mlops.configuration.security;

import ai.starwhale.mlops.storage.configuration.StorageProperties;
import ai.starwhale.mlops.storage.domain.DomainAwareStorageAccessService;
import ai.starwhale.mlops.storage.s3.S3Config;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.jetbrains.annotations.NotNull;
import org.springframework.stereotype.Component;
import org.springframework.web.filter.OncePerRequestFilter;

@Component
public class ObjectStoreDomainDetectionFilter extends OncePerRequestFilter {

public static final String HEADER_NAME = "SW_CLIENT_FAVORED_OSS_DOMAIN_ALIAS";

final Map<String, Pattern> domainAliasMap;

public ObjectStoreDomainDetectionFilter(StorageProperties storageProperties) {
S3Config s3Config = storageProperties.getS3Config();
if (null == s3Config) {
domainAliasMap = new HashMap<>();
return;
}
Map<String, String> endpointEquivalentsMap = s3Config.getEndpointEquivalentsMap();
domainAliasMap = endpointEquivalentsMap.entrySet().stream().collect(Collectors.toMap(Entry::getKey, (entry) -> {
URI uri = null;
try {
uri = new URI(entry.getValue());
} catch (URISyntaxException e) {
throw new RuntimeException(e);
}
return Pattern.compile(uri.getHost().replace(".", "\\."));
}));

}

@Override
protected void doFilterInternal(
HttpServletRequest request,
@NotNull HttpServletResponse response,
FilterChain filterChain
) throws ServletException, IOException {
Pattern hostPattern = domainAliasMap.get(request.getHeader(HEADER_NAME));
if (null != hostPattern) {
request.setAttribute(DomainAwareStorageAccessService.OSS_DOMAIN_PATTERN_ATTR, hostPattern);
}
filterChain.doFilter(request, response);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ public class SecurityConfiguration extends WebSecurityConfigurerAdapter {
@Resource
private ContentCachingFilter contentCachingFilter;

@Resource
private ObjectStoreDomainDetectionFilter objectStoreDomainDetectionFilter;


public SecurityConfiguration() {
super();
Expand Down Expand Up @@ -139,6 +142,7 @@ protected void configure(HttpSecurity http) throws Exception {
JwtLoginFilter.class)
.addFilterBefore(projectDetectionFilter, JwtTokenFilter.class)
.addFilterBefore(contentCachingFilter, ProjectDetectionFilter.class)
.addFilterAfter(objectStoreDomainDetectionFilter, JwtTokenFilter.class)
;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import ai.starwhale.mlops.storage.LengthAbleInputStream;
import ai.starwhale.mlops.storage.StorageAccessService;
import ai.starwhale.mlops.storage.domain.DomainAwareStorageAccessService;
import ai.starwhale.mlops.storage.memory.StorageAccessServiceMemory;
import java.io.IOException;
import java.util.HashMap;
Expand All @@ -40,9 +41,12 @@ public CachedBlobService(StorageAccessService defaultStorageAccessService,
// for test only
storageAccessService = new StorageAccessServiceMemory();
} else {
storageAccessService = StorageAccessService.getS3LikeStorageAccessService(
cacheConfig.getStorageType(),
cacheConfig);
storageAccessService = new DomainAwareStorageAccessService(
StorageAccessService.getS3LikeStorageAccessService(
cacheConfig.getStorageType(),
cacheConfig
)
);
}
this.caches.put(cacheConfig.getBlobIdPrefix(),
new BlobServiceImpl(storageAccessService,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import ai.starwhale.mlops.storage.StorageAccessService;
import ai.starwhale.mlops.storage.StorageConnectionToken;
import ai.starwhale.mlops.storage.StorageUri;
import ai.starwhale.mlops.storage.domain.DomainAwareStorageAccessService;
import ai.starwhale.mlops.storage.fs.FsConfig;
import ai.starwhale.mlops.storage.s3.S3Config;
import java.util.Map;
Expand Down Expand Up @@ -103,9 +104,12 @@ private StorageAccessService buildStorageAccessService(StorageConnectionToken to
token.getTokens().get("sigKey")
));
default:
return StorageAccessService.getS3LikeStorageAccessService(
token.getType(),
new S3Config(token.getTokens()));
return new DomainAwareStorageAccessService(
StorageAccessService.getS3LikeStorageAccessService(
token.getType(),
new S3Config(token.getTokens())
)
);
}
} catch (Exception e) {
log.error("can not build storage access service", e);
Expand Down
1 change: 1 addition & 0 deletions server/controller/src/main/resources/application.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ sw:
secret-key: ${SW_STORAGE_SECRETKEY:starwhale}
region: ${SW_STORAGE_REGION:local}
endpoint: ${SW_STORAGE_ENDPOINT:http://localhost:9000}
endpoint-equivalents-raw: ${SW_STORAGE_ENDPOINT_EQS:{"test":"http://127.0.0.1:9000"}}
huge-file-threshold: 10485760 # 10MB
huge-file-part-size: 5242880 # 5MB
controller:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Copyright 2022 Starwhale, Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ai.starwhale.mlops.configuration.security;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import ai.starwhale.mlops.storage.configuration.StorageProperties;
import ai.starwhale.mlops.storage.s3.S3Config;
import java.io.IOException;
import java.util.Map;
import java.util.regex.Pattern;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;

class ObjectStoreDomainDetectionFilterTest {

ObjectStoreDomainDetectionFilter filter;

@BeforeEach
void setup() throws IOException {
StorageProperties storageProperties = mock(StorageProperties.class);
when(storageProperties.getS3Config()).thenReturn(new S3Config(
Map.of("endpointEquivalents", "{\"test1\":\"http://120.0.1.2\", \"test2\":\"http://a.b.com\"}"))
);
filter = new ObjectStoreDomainDetectionFilter(storageProperties);
}

@Test
void doFilterInternal() throws ServletException, IOException {
HttpServletRequest req = mock(HttpServletRequest.class);
HttpServletResponse resp = mock(HttpServletResponse.class);
FilterChain filterChain = mock(FilterChain.class);
when(req.getHeader("SW_CLIENT_FAVORED_OSS_DOMAIN_ALIAS")).thenReturn(null);
filter.doFilterInternal(req, resp, filterChain);
verify(req, times(0)).setAttribute(any(), any());
when(req.getHeader("SW_CLIENT_FAVORED_OSS_DOMAIN_ALIAS")).thenReturn("test1");
filter.doFilterInternal(req, resp, filterChain);
ArgumentCaptor<Pattern> ac = ArgumentCaptor.forClass(Pattern.class);
verify(req).setAttribute(eq("SW_OSS_DOMAIN_REG_PATTERN"), ac.capture());
Assertions.assertTrue(ac.getValue().matcher("120.0.1.2").matches());
Assertions.assertFalse(ac.getValue().matcher("120a0.1.2").matches());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
import ai.starwhale.mlops.schedule.impl.k8s.K8sJobTemplate;
import ai.starwhale.mlops.schedule.impl.k8s.reporting.ResourceEventHolder;
import ai.starwhale.mlops.storage.StorageAccessService;
import ai.starwhale.mlops.storage.configuration.StorageProperties;
import ai.starwhale.mlops.storage.memory.StorageAccessServiceMemory;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.github.pagehelper.autoconfigure.PageHelperAutoConfiguration;
Expand Down Expand Up @@ -536,5 +537,10 @@ ModelServingService modelServingService() {
RunExecutor runExecutor() {
return mock(RunExecutor.class);
}

@Bean
StorageProperties storageProperties() {
return mock(StorageProperties.class);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import ai.starwhale.mlops.storage.StorageConnectionToken;
import ai.starwhale.mlops.storage.StorageUri;
import ai.starwhale.mlops.storage.aliyun.StorageAccessServiceAliyun;
import ai.starwhale.mlops.storage.domain.DomainAwareStorageAccessService;
import ai.starwhale.mlops.storage.fs.StorageAccessServiceFile;
import ai.starwhale.mlops.storage.memory.StorageAccessServiceMemory;
import ai.starwhale.mlops.storage.s3.StorageAccessServiceS3;
Expand Down Expand Up @@ -84,15 +85,21 @@ public void testCache() throws URISyntaxException {
new StorageUri("s3://10.34.2.1:8080/b1/p2"));
StorageAccessService s3AccessService3 = storageAccessParser.getStorageAccessServiceFromUri(
new StorageUri("s3://10.34.2.1:8080/b2/p2"));
assertThat(s3AccessService1, instanceOf(StorageAccessServiceS3.class));
assertThat(
((DomainAwareStorageAccessService) s3AccessService1).getDelegated(),
instanceOf(StorageAccessServiceS3.class)
);
assertThat(s3AccessService1, is(s3AccessService2));
assertThat(s3AccessService1, not(is(s3AccessService3)));

assertThat(storageAccessParser.getStorageAccessServiceFromUri(new StorageUri("file://b1/c/d")),
instanceOf(StorageAccessServiceFile.class));

assertThat(storageAccessParser.getStorageAccessServiceFromUri(new StorageUri("oss://10.34.2.1:8080/b1/c/d")),
instanceOf(StorageAccessServiceAliyun.class));
assertThat(
((DomainAwareStorageAccessService) storageAccessParser.getStorageAccessServiceFromUri(new StorageUri(
"oss://10.34.2.1:8080/b1/c/d"))).getDelegated(),
instanceOf(StorageAccessServiceAliyun.class)
);
}

@Test
Expand All @@ -110,11 +117,17 @@ public void testSettingUpdate() throws URISyntaxException {
"ak", "ak",
"sk", "sk"))));
storageAccessParser.onUpdate(systemSetting);
assertThat(storageAccessParser.getStorageAccessServiceFromUri(new StorageUri("s3://10.34.2.1:8080/b1/p1")),
instanceOf(StorageAccessServiceS3.class));
assertThat(
((DomainAwareStorageAccessService) storageAccessParser.getStorageAccessServiceFromUri(new StorageUri(
"s3://10.34.2.1:8080/b1/p1"))).getDelegated(),
instanceOf(StorageAccessServiceS3.class)
);
assertThat(storageAccessParser.getStorageAccessServiceFromUri(new StorageUri("oss://10.34.2.1:8080/b1/c/d")),
is(this.defaultStorageAccessService));
assertThat(storageAccessParser.getStorageAccessServiceFromUri(new StorageUri("oss://10.34.2.1:8080/b2/c/d")),
instanceOf(StorageAccessServiceAliyun.class));
assertThat(
((DomainAwareStorageAccessService) storageAccessParser.getStorageAccessServiceFromUri(new StorageUri(
"oss://10.34.2.1:8080/b2/c/d"))).getDelegated(),
instanceOf(StorageAccessServiceAliyun.class)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
import ai.starwhale.mlops.schedule.log.RunLogSaver;
import ai.starwhale.mlops.storage.LengthAbleInputStream;
import ai.starwhale.mlops.storage.StorageAccessService;
import ai.starwhale.mlops.storage.configuration.StorageProperties;
import ai.starwhale.mlops.storage.memory.StorageAccessServiceMemory;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.primitives.Ints;
Expand Down Expand Up @@ -216,6 +217,11 @@ ModelServingService modelServingService() {
RunExecutor runExecutor() {
return mock(RunExecutor.class);
}

@Bean
StorageProperties storageProperties() {
return mock(StorageProperties.class);
}
}

@Autowired
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,5 +148,14 @@ static StorageAccessService getS3LikeStorageAccessService(String type, S3Config
*/
String signedUrl(String path, Long expTimeMillis) throws IOException;

default List<String> signedUrlAllDomains(String path, Long expTimeMillis) throws IOException {
return List.of(signedUrl(path, expTimeMillis));
}

String signedPutUrl(String path, String contentType, Long expTimeMillis) throws IOException;

default List<String> signedPutUrlAllDomains(String path, String contentType, Long expTimeMillis)
throws IOException {
return List.of(signedPutUrl(path, contentType, expTimeMillis));
}
}
Loading

0 comments on commit 145d209

Please sign in to comment.