
Pull request update/240702
fb75c5a OS-7661. regenerated live_demo.json
3908db3 OS-7679. Not validate stats length
6e994bb Merge feature/ml_artifacts into integration
stanfra authored Jul 2, 2024
2 parents 4f537bc + fb75c5a commit 15ad836
Showing 90 changed files with 3,931 additions and 58 deletions.
22 changes: 22 additions & 0 deletions arcee/arcee_receiver/migrations/20240524090000_artifact_index.py
@@ -0,0 +1,22 @@
from mongodb_migrations.base import BaseMigration


INDEX_NAME = 'RunIdCreatedAtDt'
INDEX_FIELDS = ['run_id', '_created_at_dt']


class Migration(BaseMigration):
def existing_indexes(self):
return [x['name'] for x in self.db.artifact.list_indexes()]

def upgrade(self):
if INDEX_NAME not in self.existing_indexes():
self.db.artifact.create_index(
[(key, 1) for key in INDEX_FIELDS],
name=INDEX_NAME,
background=True
)

def downgrade(self):
if INDEX_NAME in self.existing_indexes():
self.db.artifact.drop_index(INDEX_NAME)
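
The compound key (run_id, _created_at_dt) backs the artifact listing added to server.py below, which filters by run and by day bucket. A minimal sketch of the query shape this index serves; db is a pymongo/motor handle and the values are illustrative:

    # Equality on the index prefix (run_id) plus a range on the second
    # key (_created_at_dt) can be answered from RunIdCreatedAtDt.
    cursor = db.artifact.find({
        'run_id': {'$in': ['run-1', 'run-2']},
        '_created_at_dt': {'$gte': 1719878400},
    })
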
81 changes: 80 additions & 1 deletion arcee/arcee_receiver/models.py
@@ -2,7 +2,8 @@
from enum import Enum
from datetime import datetime, timezone
from pydantic import (
-    BaseModel, BeforeValidator, ConfigDict, Field, model_validator)
+    BaseModel, BeforeValidator, ConfigDict, Field, NonNegativeInt,
+    model_validator)
from typing import List, Optional, Union
from typing_extensions import Annotated

@@ -23,6 +24,10 @@ class BaseClass(BaseModel):
default_factory=lambda: int(datetime.now(tz=timezone.utc).timestamp()))
now_ms = Field(
default_factory=lambda: datetime.now(tz=timezone.utc).timestamp())
date_start = Field(
default_factory=lambda: int(datetime.now(tz=timezone.utc).replace(
hour=0, minute=0, second=0, microsecond=0).timestamp()),
alias='_created_at_dt')


class ConsolePostIn(BaseClass):
@@ -311,3 +316,77 @@ class MetricPatchIn(MetricPostIn):
class Metric(MetricPatchIn):
id: str = id_
token: str


class ArtifactPatchIn(BaseClass):
    path: Optional[str] = None
name: Optional[str] = None
description: Optional[str] = None
tags: Optional[dict] = {}


class ArtifactPostIn(ArtifactPatchIn):
run_id: str
path: str

@model_validator(mode='after')
def set_name(self):
if not self.name:
self.name = self.path
return self


class Artifact(ArtifactPostIn):
id: str = id_
token: str
created_at: int = now
created_at_dt: int = date_start


timestamp = Field(None, ge=0, le=2**31-1)
max_mongo_int = Field(0, ge=0, le=2**63-1)


class ArtifactSearchParams(BaseModel):
created_at_lt: Optional[NonNegativeInt] = timestamp
created_at_gt: Optional[NonNegativeInt] = timestamp
limit: Optional[NonNegativeInt] = max_mongo_int
start_from: Optional[NonNegativeInt] = max_mongo_int
run_id: Optional[Union[list, str]] = []
text_like: Optional[str] = None

@model_validator(mode='before')
def convert_to_expected_types(self):
"""
Converts a dict of request.args passed as query parameters to a dict
suitable for further model validation.
Example:
request.args: {"limit": ["1"], "text_like": ["test"]}
return: {"limit": "1", "text_like": "test"}
"""
numeric_fields = ['created_at_lt', 'created_at_gt',
'limit', 'start_from']
for k, v in self.items():
if isinstance(v, list) and len(v) == 1:
v = v[0]
self[k] = v
if k in numeric_fields:
try:
self[k] = int(v)
except (TypeError, ValueError):
continue
return self

@model_validator(mode='after')
def validate_run_id(self):
if isinstance(self.run_id, str):
self.run_id = [self.run_id]
return self

@model_validator(mode='after')
def validate_created_at(self):
if (self.created_at_gt is not None and self.created_at_lt is not None
and self.created_at_lt <= self.created_at_gt):
raise ValueError('Invalid created_at filter values')
return self
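
Two behaviors in these models are easy to miss. ArtifactSearchParams unwraps Sanic-style single-element list query args and coerces the numeric fields, and Artifact's date_start default lands under the _created_at_dt alias, aligned to midnight UTC. A quick sketch (values are illustrative):

    # request.args arrives with single-element list values
    params = ArtifactSearchParams(**{
        'limit': ['10'],
        'created_at_gt': ['1719900000'],
        'run_id': 'r1',          # a bare string is normalized to a list
    })
    assert params.limit == 10 and params.run_id == ['r1']

    art = Artifact(run_id='r1', path='/my/path', token='t1')
    dump = art.model_dump(by_alias=True)
    assert dump['name'] == '/my/path'            # set_name falls back to path
    assert dump['_created_at_dt'] % 86400 == 0   # midnight-UTC day bucket
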
175 changes: 170 additions & 5 deletions arcee/arcee_receiver/server.py
@@ -1,6 +1,6 @@
import time
from collections import OrderedDict, defaultdict
-from datetime import datetime, timezone
+from datetime import datetime, timezone, timedelta
import asyncio

from etcd import Lock as EtcdLock, Client as EtcdClient
@@ -10,6 +10,7 @@

from mongodb_migrations.cli import MigrationManager
from mongodb_migrations.config import Configuration
+from pydantic import ValidationError
from sanic import Sanic
from sanic.log import logger
from sanic.response import json
@@ -23,7 +24,8 @@
LeaderboardDatasetPatchIn, LeaderboardDatasetPostIn,
Leaderboard, LeaderboardPostIn, LeaderboardPatchIn, Log, Platform,
StatsPostIn, ModelPatchIn, ModelPostIn, Model, ModelVersionIn,
-    ModelVersion, Metric, MetricPostIn, MetricPatchIn
+    ModelVersion, Metric, MetricPostIn, MetricPatchIn,
+    ArtifactPostIn, ArtifactPatchIn, Artifact, ArtifactSearchParams,
)
from arcee.arcee_receiver.modules.leader_board import (
get_calculated_leaderboard, Tendencies)
@@ -360,15 +362,17 @@ async def delete_task(request, id_: str):
db.proc_data.delete_many({'run_id': {'$in': runs}}),
db.log.delete_many({'run_id': {'$in': runs}}),
db.run.delete_many({"task_id": id_}),
-        db.console.delete_many({'run_id': {'$in': runs}})
+        db.console.delete_many({'run_id': {'$in': runs}}),
+        db.artifact.delete_many({'run_id': {'$in': runs}}),
)
-    dm, ds, dpd, dl, dr, dc = results
+    dm, ds, dpd, dl, dr, dc, da = results
deleted_milestones = dm.deleted_count
deleted_stages = ds.deleted_count
deleted_logs = dl.deleted_count
deleted_runs = dr.deleted_count
deleted_proc_data = dpd.deleted_count
deleted_consoles = dc.deleted_count
+    deleted_artifacts = da.deleted_count
leaderboard = await db.leaderboard.find_one(
{"token": token, "task_id": id_, "deleted_at": 0})
now = int(datetime.now(tz=timezone.utc).timestamp())
@@ -396,7 +400,8 @@ async def delete_task(request, id_: str):
"deleted_runs": deleted_runs,
"deleted_stages": deleted_stages,
"deleted_proc_data": deleted_proc_data,
"deleted_console_output": deleted_consoles
"deleted_console_output": deleted_consoles,
"deleted_artifacts": deleted_artifacts
})


@@ -903,6 +908,7 @@ async def delete_run(request, run_id: str):
await db.stage.delete_many({'run_id': run['_id']})
await db.milestone.delete_many({'run_id': run['_id']})
await db.proc_data.delete_many({'run_id': run['_id']})
+    await db.artifact.delete_many({'run_id': run['_id']})
await db.model_version.update_many({'run_id': run['_id'], 'deleted_at': 0},
{'$set': {'deleted_at': now}})
await db.run.delete_one({'_id': run_id})
@@ -2287,6 +2293,165 @@ async def get_model_versions_for_task(request, task_id: str):
return json(versions)


def _format_artifact(artifact: dict, run: dict) -> dict:
artifact.pop('run_id', None)
artifact['run'] = {
'_id': run['_id'],
'task_id': run['task_id'],
'name': run['name'],
'number': run['number']
}
return artifact


async def _create_artifact(**kwargs) -> dict:
artifact = Artifact(**kwargs).model_dump(by_alias=True)
await db.artifact.insert_one(artifact)
return artifact


@app.route('/arcee/v2/artifacts', methods=["POST", ], ctx_label='token')
@validate(json=ArtifactPostIn)
async def create_artifact(request, body: ArtifactPostIn):
token = request.ctx.token
run = await db.run.find_one({"_id": body.run_id, 'deleted_at': 0})
if not run:
raise SanicException("Run not found", status_code=404)
artifact = await _create_artifact(
token=token, **body.model_dump(exclude_unset=True))
artifact = _format_artifact(artifact, run)
return json(artifact, status=201)


def _build_artifact_filter_pipeline(run_ids: list,
query: ArtifactSearchParams):
filters = defaultdict(dict)
filters['run_id'] = {'$in': run_ids}
if query.created_at_lt:
created_at_dt = int((datetime.fromtimestamp(
query.created_at_lt) + timedelta(days=1)).replace(
hour=0, minute=0, second=0, microsecond=0).timestamp())
filters['_created_at_dt'].update({'$lt': created_at_dt})
filters['created_at'].update({'$lt': query.created_at_lt})
if query.created_at_gt:
created_at_dt = int(datetime.fromtimestamp(
query.created_at_gt).replace(hour=0, minute=0, second=0,
microsecond=0).timestamp())
filters['_created_at_dt'].update({'$gte': created_at_dt})
filters['created_at'].update({'$gt': query.created_at_gt})
pipeline = [{'$match': filters}]
if query.text_like:
pipeline += [
{'$addFields': {'tags_array': {'$objectToArray': '$tags'}}},
{'$match': {'$or': [
{'name': {'$regex': f'(.*){query.text_like}(.*)'}},
{'description': {'$regex': f'(.*){query.text_like}(.*)'}},
{'path': {'$regex': f'(.*){query.text_like}(.*)'}},
{'tags_array.k': {'$regex': f'(.*){query.text_like}(.*)'}},
{'tags_array.v': {'$regex': f'(.*){query.text_like}(.*)'}},
]}}
]
return pipeline
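
For orientation, roughly what this helper emits for created_at_gt=1719900000 with text_like='model' (a sketch; the day-bucket value assumes the server clock runs in UTC, since fromtimestamp() converts to local time):

    [
        {'$match': {
            'run_id': {'$in': run_ids},
            '_created_at_dt': {'$gte': 1719878400},  # midnight of the same
                                                     # day, index-friendly
            'created_at': {'$gt': 1719900000},       # exact cutoff
        }},
        {'$addFields': {'tags_array': {'$objectToArray': '$tags'}}},
        {'$match': {'$or': [
            {'name': {'$regex': '(.*)model(.*)'}},
            # same pattern for description, path, tags_array.k, tags_array.v
        ]}},
    ]
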


@app.route('/arcee/v2/artifacts', methods=["GET", ], ctx_label='token')
async def list_artifacts(request):
token = request.ctx.token
try:
query = ArtifactSearchParams(**request.args)
except ValidationError as e:
raise SanicException(f'Invalid query params: {str(e)}',
status_code=400)
result = {
'artifacts': [],
'limit': query.limit,
'start_from': query.start_from,
'total_count': 0
}
tasks = [x['_id'] async for x in db.task.find({'token': token},
{'_id': 1})]
run_query = {'task_id': {'$in': tasks}, 'deleted_at': 0}
if query.run_id:
run_query['_id'] = {'$in': query.run_id}
runs_map = {run['_id']: run async for run in db.run.find(run_query)}
runs_ids = list(runs_map.keys())

pipeline = _build_artifact_filter_pipeline(runs_ids, query)
pipeline.append({'$sort': {'created_at': -1, '_id': 1}})

paginate_pipeline = [{'$skip': query.start_from}]
if query.limit:
paginate_pipeline.append({'$limit': query.limit})
pipeline.extend(paginate_pipeline)
async for artifact in db.artifact.aggregate(pipeline, allowDiskUse=True):
artifact.pop('tags_array', None)
        res = Artifact(**artifact).model_dump(by_alias=True)
        res = _format_artifact(res, runs_map[artifact['run_id']])
result['artifacts'].append(res)
if len(result['artifacts']) != 0 and not query.limit:
result['total_count'] = len(result['artifacts']) + query.start_from
else:
pipeline = _build_artifact_filter_pipeline(runs_ids, query)
pipeline.append({'$count': 'count'})
res = db.artifact.aggregate(pipeline)
try:
count = await res.next()
result['total_count'] = count['count']
except StopAsyncIteration:
pass
return json(result)
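
The handler returns a paginated envelope; roughly, with illustrative values and fields trimmed for brevity:

    {
        'artifacts': [{'_id': 'a1', 'name': 'weights.pt', 'path': '/my/path',
                       'run': {'_id': 'r1', 'task_id': 't1',
                               'name': 'train', 'number': 7}}],
        'limit': 10,
        'start_from': 0,
        'total_count': 42,
    }
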


async def _get_artifact(token, artifact_id):
artifact = await db.artifact.find_one(
{"$and": [
{"token": token},
{"_id": artifact_id}
]})
if not artifact:
raise SanicException("Artifact not found", status_code=404)
return artifact


@app.route('/arcee/v2/artifacts/<id_>', methods=["GET", ], ctx_label='token')
async def get_artifact(request, id_: str):
artifact = await _get_artifact(request.ctx.token, id_)
artifact_dict = Artifact(**artifact).model_dump(by_alias=True)
run = await db.run.find_one({"_id": artifact['run_id'], 'deleted_at': 0})
if not run:
raise SanicException("Run not found", status_code=404)
artifact_dict = _format_artifact(artifact_dict, run)
return json(artifact_dict)


@app.route('/arcee/v2/artifacts/<id_>', methods=["PATCH", ], ctx_label='token')
@validate(json=ArtifactPatchIn)
async def update_artifact(request, body: ArtifactPatchIn, id_: str):
token = request.ctx.token
artifact = await _get_artifact(token, id_)
run = await db.run.find_one(
{"_id": artifact['run_id'], 'deleted_at': 0})
if not run:
raise SanicException("Run not found", status_code=404)
updates = body.model_dump(exclude_unset=True)
if updates:
await db.artifact.update_one(
{"_id": id_}, {'$set': updates})
obj = await db.artifact.find_one({"_id": id_})
artifact = Artifact(**obj).model_dump(by_alias=True)
artifact = _format_artifact(artifact, run)
return json(artifact)


@app.route('/arcee/v2/artifacts/<id_>', methods=["DELETE", ],
ctx_label='token')
async def delete_artifact(request, id_: str):
await _get_artifact(request.ctx.token, id_)
await db.artifact.delete_one({"_id": id_})
return json({'deleted': True, '_id': id_}, status=204)
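
Together these routes form a plain CRUD set. A hypothetical client session; the base URL, port, and auth header name are assumptions for illustration, not taken from this diff:

    import requests

    BASE = 'http://localhost:8891/arcee/v2'        # assumed host/port
    HEADERS = {'x-api-key': '<token>'}             # assumed header name

    r = requests.post(f'{BASE}/artifacts', headers=HEADERS,
                      json={'run_id': 'r1', 'path': 's3://bucket/weights.pt'})
    artifact_id = r.json()['_id']
    requests.patch(f'{BASE}/artifacts/{artifact_id}', headers=HEADERS,
                   json={'tags': {'stage': 'train'}})
    requests.get(f'{BASE}/artifacts', headers=HEADERS,
                 params={'run_id': 'r1', 'text_like': 'weights'})
    requests.delete(f'{BASE}/artifacts/{artifact_id}', headers=HEADERS)
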


if __name__ == '__main__':
logger.info('Waiting for migration lock')
# trick to lock migrations
31 changes: 31 additions & 0 deletions arcee/arcee_receiver/tests/base.py
@@ -34,6 +34,8 @@ class Urls:
metrics = '/arcee/v2/metrics'
metric = '/arcee/v2/metrics/{}'
collect = '/arcee/v2/collect'
+    artifacts = '/arcee/v2/artifacts'
+    artifact = '/arcee/v2/artifacts/{}'


async def prepare_token():
@@ -198,3 +200,32 @@ async def prepare_model_version(model_id, run_id, version='1', aliases=None,
await DB_MOCK['model_version'].insert_one(model_version)
return await DB_MOCK['model_version'].find_one({
"run_id": run_id, "model_id": model_id})


async def prepare_artifact(run_id, name=None, description=None,
path=None, tags=None, created_at=None):
now = datetime.now(tz=timezone.utc)
if not created_at:
created_at = int(now.timestamp())
if not name:
name = "my artifact"
if not description:
description = "my artifact"
if not path:
path = "/my/path"
if not tags:
tags = {"key": "value"}
artifact = {
"_id": str(uuid.uuid4()),
"path": path,
"name": name,
"description": description,
"tags": tags,
"run_id": run_id,
"created_at": created_at,
"token": TOKEN1,
'_created_at_dt': int(now.replace(hour=0, minute=0, second=0,
microsecond=0).timestamp())
}
await DB_MOCK['artifact'].insert_one(artifact)
return await DB_MOCK['artifact'].find_one({'_id': artifact['_id']})
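
A sketch of how a test might use the helper; fixture setup is omitted and the values are illustrative:

    artifact = await prepare_artifact('run-1', name='weights.pt',
                                      tags={'stage': 'train'})
    assert artifact['name'] == 'weights.pt'
    assert artifact['_created_at_dt'] % 86400 == 0   # UTC midnight bucket
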
