Skip to content

Commit

Permalink
Move the logic for retrieving the access point name to the view layer…
Browse files Browse the repository at this point in the history
… and add a configurable setting to control the maximum number of principals per access point policy.
  • Loading branch information
Chrystinne committed Jan 15, 2025
1 parent b42d709 commit ff28a70
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 40 deletions.
2 changes: 2 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ S3_OPEN_ACCESS_BUCKET=
S3_SERVER_ACCESS_LOG_BUCKET=
# The default bucket name to store projects with a 'RESTRICTED/CREDENTIALED' access policy.
S3_CONTROLLED_ACCESS_BUCKET=
# Maximum number of principals allowed per access point policy.
# Must be an integer; remove or comment out this line to use the default of 500.
MAX_PRINCIPALS_PER_AP_POLICY=

# Datacite
# Used to assign the DOIs
Expand Down
3 changes: 3 additions & 0 deletions physionet-django/physionet/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,9 @@
# Bucket name for the S3 bucket containing the controlled access data
S3_CONTROLLED_ACCESS_BUCKET = config('S3_CONTROLLED_ACCESS_BUCKET', default=None)

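# Maximum number of principals (AWS users) allowed per access point policy.
# Once an access point reaches this limit, a new access point is created for the project.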
MAX_PRINCIPALS_PER_AP_POLICY = config('MAX_PRINCIPALS_PER_AP_POLICY', default=500, cast=int)

# Header tags for the AWS lambda function that grants access to S3 storage
AWS_HEADER_KEY = config('AWS_KEY', default=False)
AWS_HEADER_VALUE = config('AWS_VALUE', default=False)
Expand Down
18 changes: 7 additions & 11 deletions physionet-django/project/cloud/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,12 +814,10 @@ def get_access_point_name(project):
access_points = project.aws.access_points.all()

if access_points is None:
print("No access points found for project.")
access_point_name = f"{project.slug}-v{project.version.replace('.', '-')}-01"

elif access_points.count() == 1:
access_point_name = project.aws.access_points.first().name
print("Only one access point found for project: ", access_point_name)

else:
access_point_name = get_latest_access_point(project)
Expand Down Expand Up @@ -1065,14 +1063,13 @@ def create_first_data_access_point_policy(project):
)


def add_user_to_access_point_policy(project, user_aws_id, max_users=500):
def add_user_to_access_point_policy(project, user_aws_id):
"""
Add a user to an existing access point or create a new one if no access point has capacity.
Args:
project (PublishedProject): The project associated with the access points.
user_aws_id (str): The AWS ID of the user to be added.
max_users (int): The maximum number of users an access point can have.
Returns:
dict: A dictionary containing the access point information where the user was added.
Expand All @@ -1081,7 +1078,7 @@ def add_user_to_access_point_policy(project, user_aws_id, max_users=500):
try:
aws_acount = get_aws_account_by_id(user_aws_id)
# Check if there is an access point with capacity
access_point_data = get_access_point_with_capacity(project, max_users=max_users)
access_point_data = get_access_point_with_capacity(project)
if access_point_data:
# If an access point with capacity exists, add the user
access_point = access_point_data['access_point']
Expand Down Expand Up @@ -1134,13 +1131,12 @@ def add_user_to_access_point_policy(project, user_aws_id, max_users=500):
return None


def get_access_point_with_capacity(project, max_users=500):
def get_access_point_with_capacity(project):
"""
Finds an access point associated with the project that can add a new user.
Args:
project (PublishedProject): The project to check.
max_users (int): The maximum number of users allowed for an access point.
Returns:
dict: A dictionary containing:
Expand All @@ -1158,7 +1154,7 @@ def get_access_point_with_capacity(project, max_users=500):
# Count the number of users associated with the access point
user_count = access_point.users.count()

if user_count < max_users:
if user_count < settings.MAX_PRINCIPALS_PER_AP_POLICY:
# Get the list of usernames associated with the access point
users = list(access_point.users.values_list('username', flat=True))
return {
Expand Down Expand Up @@ -1188,10 +1184,9 @@ def insert_access_point_policy(access_point, data_access_point_name, project, su


def initialize_access_points(project):
MAX_PRINCIPALS_PER_AP_POLICY = 500
project_name = project.slug + "-" + project.version
aws_ids = get_aws_accounts_for_dataset(project_name)
number_of_access_points_needed = ceil(len(aws_ids) / MAX_PRINCIPALS_PER_AP_POLICY)
number_of_access_points_needed = ceil(len(aws_ids) / settings.MAX_PRINCIPALS_PER_AP_POLICY)
bucket_name = get_bucket_name(project)
for i in range(number_of_access_points_needed):
data_access_point_version = str(i + 1).zfill(2)
Expand All @@ -1201,7 +1196,8 @@ def initialize_access_points(project):
)

subset_aws_ids = aws_ids[
i * MAX_PRINCIPALS_PER_AP_POLICY: (i + 1) * MAX_PRINCIPALS_PER_AP_POLICY
i * settings.MAX_PRINCIPALS_PER_AP_POLICY:
(i + 1) * settings.MAX_PRINCIPALS_PER_AP_POLICY
]
access_point = AWSAccessPoint.objects.filter(
name=data_access_point_name, aws__project=project
Expand Down
36 changes: 14 additions & 22 deletions physionet-django/project/modelcomponents/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,32 +62,24 @@ class AWS(models.Model):
class Meta:
default_permissions = ()

def s3_uri(self, user=None):
def get_public_s3_uri(self):
"""
Construct the S3 URI for the project.
Parameters:
user (User): The user requesting the S3 URI
Construct the S3 URI for public projects.
"""
from project.cloud.s3 import get_access_point_name_for_user_and_project
if self.is_private:
if not user or not user.is_authenticated:
print("Error: No valid user provided")
return None

# Fetch access point name
access_point_name = get_access_point_name_for_user_and_project(user, self)
if access_point_name and "No " not in access_point_name:
return (
f's3://arn:aws:s3:us-east-1:{settings.AWS_ACCOUNT_ID}:accesspoint/'
f'{access_point_name}/{self.project.slug}/{self.project.version}/'
)
else:
print(f"Error: {access_point_name}")
return None

# For public projects, construct URI using bucket name
return f's3://{self.bucket_name}/{self.project.slug}/{self.project.version}/'

def get_private_s3_uri(self, access_point_name):
"""
Construct the S3 URI for private projects using an access point.
"""
if not access_point_name:
return None

return (
f's3://arn:aws:s3:us-east-1:{settings.AWS_ACCOUNT_ID}:accesspoint/'
f'{access_point_name}/{self.project.slug}/{self.project.version}/'
)

def __str__(self):
return f"AWS instance for project: {self.project.slug}"

Expand Down
19 changes: 12 additions & 7 deletions physionet-django/project/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
SubmissionStatus,
Topic,
UploadedDocument,
AWS,
)
from project.authorization.access import can_view_project_files, can_access_project
from project.projectfiles import ProjectFiles
Expand All @@ -59,9 +60,8 @@
has_s3_credentials,
files_sent_to_S3,
add_user_to_access_point_policy,
s3_bucket_has_access_point,
get_access_point_name_for_user_and_project,
s3_bucket_has_credentialed_users,
initialize_access_points,
)
from django.db.models import F, DateTimeField, ExpressionWrapper

Expand Down Expand Up @@ -1932,14 +1932,19 @@ def published_project(request, project_slug, version, subdir=''):
bulk_url_prefix = notification.get_url_prefix(request, bulk_download=True)
all_project_versions = PublishedProject.objects.filter(slug=project_slug).order_by('version_order')

# Check if AWS instance exists for the project
s3_uri = None
if hasattr(project, 'aws'):
try:
if project.aws.is_private:
if has_signed_dua:
s3_uri = project.aws.s3_uri(user=request.user)
if has_signed_dua and request.user.is_authenticated:
access_point_name = get_access_point_name_for_user_and_project(
request.user,
project.aws
)
s3_uri = project.aws.get_private_s3_uri(access_point_name)
else:
s3_uri = '--no-sign-request ' + project.aws.s3_uri(user=None)
s3_uri = '--no-sign-request ' + project.aws.get_public_s3_uri()
except AWS.DoesNotExist:
s3_uri = None

context = {
'project': project,
Expand Down

0 comments on commit ff28a70

Please sign in to comment.