diff --git a/arcee/arcee_receiver/migrations/20240524090000_artifact_index.py b/arcee/arcee_receiver/migrations/20240524090000_artifact_index.py new file mode 100644 index 00000000..32beb66f --- /dev/null +++ b/arcee/arcee_receiver/migrations/20240524090000_artifact_index.py @@ -0,0 +1,22 @@ +from mongodb_migrations.base import BaseMigration + + +INDEX_NAME = 'RunIdCreatedAtDt' +INDEX_FIELDS = ['run_id', '_created_at_dt'] + + +class Migration(BaseMigration): + def existing_indexes(self): + return [x['name'] for x in self.db.artifact.list_indexes()] + + def upgrade(self): + if INDEX_NAME not in self.existing_indexes(): + self.db.artifact.create_index( + [(key, 1) for key in INDEX_FIELDS], + name=INDEX_NAME, + background=True + ) + + def downgrade(self): + if INDEX_NAME in self.existing_indexes(): + self.db.artifact.drop_index(INDEX_NAME) diff --git a/arcee/arcee_receiver/models.py b/arcee/arcee_receiver/models.py index 189909f8..207e3616 100644 --- a/arcee/arcee_receiver/models.py +++ b/arcee/arcee_receiver/models.py @@ -2,7 +2,8 @@ from enum import Enum from datetime import datetime, timezone from pydantic import ( - BaseModel, BeforeValidator, ConfigDict, Field, model_validator) + BaseModel, BeforeValidator, ConfigDict, Field, NonNegativeInt, + model_validator) from typing import List, Optional, Union from typing_extensions import Annotated @@ -23,6 +24,10 @@ class BaseClass(BaseModel): default_factory=lambda: int(datetime.now(tz=timezone.utc).timestamp())) now_ms = Field( default_factory=lambda: datetime.now(tz=timezone.utc).timestamp()) +date_start = Field( + default_factory=lambda: int(datetime.now(tz=timezone.utc).replace( + hour=0, minute=0, second=0, microsecond=0).timestamp()), + alias='_created_at_dt') class ConsolePostIn(BaseClass): @@ -311,3 +316,77 @@ class MetricPatchIn(MetricPostIn): class Metric(MetricPatchIn): id: str = id_ token: str + + +class ArtifactPatchIn(BaseClass): + path: str = None + name: Optional[str] = None + description: Optional[str] = None + tags: Optional[dict] = {} + + +class ArtifactPostIn(ArtifactPatchIn): + run_id: str + path: str + + @model_validator(mode='after') + def set_name(self): + if not self.name: + self.name = self.path + return self + + +class Artifact(ArtifactPostIn): + id: str = id_ + token: str + created_at: int = now + created_at_dt: int = date_start + + +timestamp = Field(None, ge=0, le=2**31-1) +max_mongo_int = Field(0, ge=0, le=2**63-1) + + +class ArtifactSearchParams(BaseModel): + created_at_lt: Optional[NonNegativeInt] = timestamp + created_at_gt: Optional[NonNegativeInt] = timestamp + limit: Optional[NonNegativeInt] = max_mongo_int + start_from: Optional[NonNegativeInt] = max_mongo_int + run_id: Optional[Union[list, str]] = [] + text_like: Optional[str] = None + + @model_validator(mode='before') + def convert_to_expected_types(self): + """ + Converts a dict of request.args passed as query parameters to a dict + suitable for further model validation. + + Example: + request.args: {"limit": ["1"], "text_like": ["test"]} + return: {"limit": "1", "text_like": "test"} + """ + numeric_fields = ['created_at_lt', 'created_at_gt', + 'limit', 'start_from'] + for k, v in self.items(): + if isinstance(v, list) and len(v) == 1: + v = v[0] + self[k] = v + if k in numeric_fields: + try: + self[k] = int(v) + except (TypeError, ValueError): + continue + return self + + @model_validator(mode='after') + def validate_run_id(self): + if isinstance(self.run_id, str): + self.run_id = [self.run_id] + return self + + @model_validator(mode='after') + def validate_created_at(self): + if (self.created_at_gt is not None and self.created_at_lt is not None + and self.created_at_lt <= self.created_at_gt): + raise ValueError('Invalid created_at filter values') + return self diff --git a/arcee/arcee_receiver/server.py b/arcee/arcee_receiver/server.py index d928ddeb..3a49d72b 100644 --- a/arcee/arcee_receiver/server.py +++ b/arcee/arcee_receiver/server.py @@ -1,6 +1,6 @@ import time from collections import OrderedDict, defaultdict -from datetime import datetime, timezone +from datetime import datetime, timezone, timedelta import asyncio from etcd import Lock as EtcdLock, Client as EtcdClient @@ -10,6 +10,7 @@ from mongodb_migrations.cli import MigrationManager from mongodb_migrations.config import Configuration +from pydantic import ValidationError from sanic import Sanic from sanic.log import logger from sanic.response import json @@ -23,7 +24,8 @@ LeaderboardDatasetPatchIn, LeaderboardDatasetPostIn, Leaderboard, LeaderboardPostIn, LeaderboardPatchIn, Log, Platform, StatsPostIn, ModelPatchIn, ModelPostIn, Model, ModelVersionIn, - ModelVersion, Metric, MetricPostIn, MetricPatchIn + ModelVersion, Metric, MetricPostIn, MetricPatchIn, + ArtifactPostIn, ArtifactPatchIn, Artifact, ArtifactSearchParams, ) from arcee.arcee_receiver.modules.leader_board import ( get_calculated_leaderboard, Tendencies) @@ -360,15 +362,17 @@ async def delete_task(request, id_: str): db.proc_data.delete_many({'run_id': {'$in': runs}}), db.log.delete_many({'run_id': {'$in': runs}}), db.run.delete_many({"task_id": id_}), - db.console.delete_many({'run_id': {'$in': runs}}) + db.console.delete_many({'run_id': {'$in': runs}}), + db.artifact.delete_many({'run_id': {'$in': runs}}), ) - dm, ds, dpd, dl, dr, dc = results + dm, ds, dpd, dl, dr, dc, da = results deleted_milestones = dm.deleted_count deleted_stages = ds.deleted_count deleted_logs = dl.deleted_count deleted_runs = dr.deleted_count deleted_proc_data = dpd.deleted_count deleted_consoles = dc.deleted_count + deleted_artifacts = da.deleted_count leaderboard = await db.leaderboard.find_one( {"token": token, "task_id": id_, "deleted_at": 0}) now = int(datetime.now(tz=timezone.utc).timestamp()) @@ -396,7 +400,8 @@ async def delete_task(request, id_: str): "deleted_runs": deleted_runs, "deleted_stages": deleted_stages, "deleted_proc_data": deleted_proc_data, - "deleted_console_output": deleted_consoles + "deleted_console_output": deleted_consoles, + "deleted_artifacts": deleted_artifacts }) @@ -903,6 +908,7 @@ async def delete_run(request, run_id: str): await db.stage.delete_many({'run_id': run['_id']}) await db.milestone.delete_many({'run_id': run['_id']}) await db.proc_data.delete_many({'run_id': run['_id']}) + await db.artifact.delete_many({'run_id': run['_id']}) await db.model_version.update_many({'run_id': run['_id'], 'deleted_at': 0}, {'$set': {'deleted_at': now}}) await db.run.delete_one({'_id': run_id}) @@ -2287,6 +2293,165 @@ async def get_model_versions_for_task(request, task_id: str): return json(versions) +def _format_artifact(artifact: dict, run: dict) -> dict: + artifact.pop('run_id', None) + artifact['run'] = { + '_id': run['_id'], + 'task_id': run['task_id'], + 'name': run['name'], + 'number': run['number'] + } + return artifact + + +async def _create_artifact(**kwargs) -> dict: + artifact = Artifact(**kwargs).model_dump(by_alias=True) + await db.artifact.insert_one(artifact) + return artifact + + +@app.route('/arcee/v2/artifacts', methods=["POST", ], ctx_label='token') +@validate(json=ArtifactPostIn) +async def create_artifact(request, body: ArtifactPostIn): + token = request.ctx.token + run = await db.run.find_one({"_id": body.run_id, 'deleted_at': 0}) + if not run: + raise SanicException("Run not found", status_code=404) + artifact = await _create_artifact( + token=token, **body.model_dump(exclude_unset=True)) + artifact = _format_artifact(artifact, run) + return json(artifact, status=201) + + +def _build_artifact_filter_pipeline(run_ids: list, + query: ArtifactSearchParams): + filters = defaultdict(dict) + filters['run_id'] = {'$in': run_ids} + if query.created_at_lt: + created_at_dt = int((datetime.fromtimestamp( + query.created_at_lt) + timedelta(days=1)).replace( + hour=0, minute=0, second=0, microsecond=0).timestamp()) + filters['_created_at_dt'].update({'$lt': created_at_dt}) + filters['created_at'].update({'$lt': query.created_at_lt}) + if query.created_at_gt: + created_at_dt = int(datetime.fromtimestamp( + query.created_at_gt).replace(hour=0, minute=0, second=0, + microsecond=0).timestamp()) + filters['_created_at_dt'].update({'$gte': created_at_dt}) + filters['created_at'].update({'$gt': query.created_at_gt}) + pipeline = [{'$match': filters}] + if query.text_like: + pipeline += [ + {'$addFields': {'tags_array': {'$objectToArray': '$tags'}}}, + {'$match': {'$or': [ + {'name': {'$regex': f'(.*){query.text_like}(.*)'}}, + {'description': {'$regex': f'(.*){query.text_like}(.*)'}}, + {'path': {'$regex': f'(.*){query.text_like}(.*)'}}, + {'tags_array.k': {'$regex': f'(.*){query.text_like}(.*)'}}, + {'tags_array.v': {'$regex': f'(.*){query.text_like}(.*)'}}, + ]}} + ] + return pipeline + + +@app.route('/arcee/v2/artifacts', methods=["GET", ], ctx_label='token') +async def list_artifacts(request): + token = request.ctx.token + try: + query = ArtifactSearchParams(**request.args) + except ValidationError as e: + raise SanicException(f'Invalid query params: {str(e)}', + status_code=400) + result = { + 'artifacts': [], + 'limit': query.limit, + 'start_from': query.start_from, + 'total_count': 0 + } + tasks = [x['_id'] async for x in db.task.find({'token': token}, + {'_id': 1})] + run_query = {'task_id': {'$in': tasks}, 'deleted_at': 0} + if query.run_id: + run_query['_id'] = {'$in': query.run_id} + runs_map = {run['_id']: run async for run in db.run.find(run_query)} + runs_ids = list(runs_map.keys()) + + pipeline = _build_artifact_filter_pipeline(runs_ids, query) + pipeline.append({'$sort': {'created_at': -1, '_id': 1}}) + + paginate_pipeline = [{'$skip': query.start_from}] + if query.limit: + paginate_pipeline.append({'$limit': query.limit}) + pipeline.extend(paginate_pipeline) + async for artifact in db.artifact.aggregate(pipeline, allowDiskUse=True): + artifact.pop('tags_array', None) + res = Artifact(**artifact).model_dump(by_alias=True) + res.pop('run_id', None) + res = _format_artifact(artifact, runs_map[artifact['run_id']]) + result['artifacts'].append(res) + if len(result['artifacts']) != 0 and not query.limit: + result['total_count'] = len(result['artifacts']) + query.start_from + else: + pipeline = _build_artifact_filter_pipeline(runs_ids, query) + pipeline.append({'$count': 'count'}) + res = db.artifact.aggregate(pipeline) + try: + count = await res.next() + result['total_count'] = count['count'] + except StopAsyncIteration: + pass + return json(result) + + +async def _get_artifact(token, artifact_id): + artifact = await db.artifact.find_one( + {"$and": [ + {"token": token}, + {"_id": artifact_id} + ]}) + if not artifact: + raise SanicException("Artifact not found", status_code=404) + return artifact + + +@app.route('/arcee/v2/artifacts/', methods=["GET", ], ctx_label='token') +async def get_artifact(request, id_: str): + artifact = await _get_artifact(request.ctx.token, id_) + artifact_dict = Artifact(**artifact).model_dump(by_alias=True) + run = await db.run.find_one({"_id": artifact['run_id'], 'deleted_at': 0}) + if not run: + raise SanicException("Run not found", status_code=404) + artifact_dict = _format_artifact(artifact_dict, run) + return json(artifact_dict) + + +@app.route('/arcee/v2/artifacts/', methods=["PATCH", ], ctx_label='token') +@validate(json=ArtifactPatchIn) +async def update_artifact(request, body: ArtifactPatchIn, id_: str): + token = request.ctx.token + artifact = await _get_artifact(token, id_) + run = await db.run.find_one( + {"_id": artifact['run_id'], 'deleted_at': 0}) + if not run: + raise SanicException("Run not found", status_code=404) + updates = body.model_dump(exclude_unset=True) + if updates: + await db.artifact.update_one( + {"_id": id_}, {'$set': updates}) + obj = await db.artifact.find_one({"_id": id_}) + artifact = Artifact(**obj).model_dump(by_alias=True) + artifact = _format_artifact(artifact, run) + return json(artifact) + + +@app.route('/arcee/v2/artifacts/', methods=["DELETE", ], + ctx_label='token') +async def delete_artifact(request, id_: str): + await _get_artifact(request.ctx.token, id_) + await db.artifact.delete_one({"_id": id_}) + return json({'deleted': True, '_id': id_}, status=204) + + if __name__ == '__main__': logger.info('Waiting for migration lock') # trick to lock migrations diff --git a/arcee/arcee_receiver/tests/base.py b/arcee/arcee_receiver/tests/base.py index 9334656d..0487764e 100644 --- a/arcee/arcee_receiver/tests/base.py +++ b/arcee/arcee_receiver/tests/base.py @@ -34,6 +34,8 @@ class Urls: metrics = '/arcee/v2/metrics' metric = '/arcee/v2/metrics/{}' collect = '/arcee/v2/collect' + artifacts = '/arcee/v2/artifacts' + artifact = '/arcee/v2/artifacts/{}' async def prepare_token(): @@ -198,3 +200,32 @@ async def prepare_model_version(model_id, run_id, version='1', aliases=None, await DB_MOCK['model_version'].insert_one(model_version) return await DB_MOCK['model_version'].find_one({ "run_id": run_id, "model_id": model_id}) + + +async def prepare_artifact(run_id, name=None, description=None, + path=None, tags=None, created_at=None): + now = datetime.now(tz=timezone.utc) + if not created_at: + created_at = int(now.timestamp()) + if not name: + name = "my artifact" + if not description: + description = "my artifact" + if not path: + path = "/my/path" + if not tags: + tags = {"key": "value"} + artifact = { + "_id": str(uuid.uuid4()), + "path": path, + "name": name, + "description": description, + "tags": tags, + "run_id": run_id, + "created_at": created_at, + "token": TOKEN1, + '_created_at_dt': int(now.replace(hour=0, minute=0, second=0, + microsecond=0).timestamp()) + } + await DB_MOCK['artifact'].insert_one(artifact) + return await DB_MOCK['artifact'].find_one({'_id': artifact['_id']}) diff --git a/arcee/arcee_receiver/tests/conftest.py b/arcee/arcee_receiver/tests/conftest.py index 58af8348..1e1a6b64 100644 --- a/arcee/arcee_receiver/tests/conftest.py +++ b/arcee/arcee_receiver/tests/conftest.py @@ -29,6 +29,7 @@ async def clean_env(): await DB_MOCK['log'].drop() await DB_MOCK['platform'].drop() await DB_MOCK['dataset'].drop() + await DB_MOCK['artifact'].drop() @pytest.fixture(autouse=True) diff --git a/arcee/arcee_receiver/tests/test_artifact.py b/arcee/arcee_receiver/tests/test_artifact.py new file mode 100644 index 00000000..a80bd4f6 --- /dev/null +++ b/arcee/arcee_receiver/tests/test_artifact.py @@ -0,0 +1,552 @@ +import json +import uuid +import pytest +from datetime import datetime, timezone + +from arcee.arcee_receiver.tests.base import ( + Urls, TOKEN1, prepare_token, prepare_run, prepare_artifact, + prepare_tasks +) + + +@pytest.mark.asyncio +async def test_invalid_token(app): + client = app.asgi_client + for path, method in [ + (Urls.artifacts, client.post), + (Urls.artifacts, client.get), + (Urls.artifact.format(str(uuid.uuid4())), client.get), + (Urls.artifact.format(str(uuid.uuid4())), client.patch), + (Urls.artifact.format(str(uuid.uuid4())), client.delete), + ]: + _, response = await method(path, headers={"x-api-key": "wrong"}) + assert response.status == 401 + assert "Token not found" in response.text + + _, response = await method(path) + assert response.status == 401 + assert "API key is required" in response.text + + +@pytest.mark.asyncio +async def test_create_artifact(app): + client = app.asgi_client + await prepare_token() + run = await prepare_run('task_id', 99, 1, 1, {}) + artifact = { + "path": "/my/path", + "name": "my artifact", + "description": "my artifact", + "tags": {"key": "value"}, + "run_id": run['_id'] + } + _, response = await client.post(Urls.artifacts, + data=json.dumps(artifact), + headers={"x-api-key": TOKEN1}) + assert response.status == 201 + assert response.json['run']['_id'] == artifact.pop('run_id', None) + for key, value in artifact.items(): + assert response.json[key] == value + + +@pytest.mark.asyncio +async def test_create_invalid_run(app): + client = app.asgi_client + await prepare_token() + artifact = { + "path": "/my/path", + "name": "my artifact", + "description": "my artifact", + "tags": {"key": "value"}, + "run_id": str(uuid.uuid4()) + } + _, response = await client.post(Urls.artifacts, + data=json.dumps(artifact), + headers={"x-api-key": TOKEN1}) + assert response.status == 404 + assert 'Run not found' in response.text + + +@pytest.mark.asyncio +async def test_create_required_params(app): + client = app.asgi_client + await prepare_token() + run = await prepare_run('task_id', 99, 1, 1, {}) + artifact = { + "run_id": run['_id'], + "path": "path/test" + } + _, response = await client.post(Urls.artifacts, + data=json.dumps(artifact), + headers={"x-api-key": TOKEN1}) + assert response.status == 201 + assert response.json['token'] == TOKEN1 + + client = app.asgi_client + await prepare_token() + for param in ['run_id', 'path']: + params = artifact.copy() + params.pop(param, None) + _, response = await client.post(Urls.artifacts, + data=json.dumps(params), + headers={"x-api-key": TOKEN1}) + assert response.status == 400 + assert 'Field required' in response.text + + +@pytest.mark.asyncio +async def test_create_invalid_params_types(app): + client = app.asgi_client + await prepare_token() + run = await prepare_run('task_id', 99, 1, 1, {}) + artifact = { + "run_id": run['_id'], + "path": "path/test" + } + for param in ["description", "name", "path", "run_id"]: + for value in [1, {"test": 1}, ['test']]: + params = artifact.copy() + params[param] = value + _, response = await client.post( + Urls.artifacts, data=json.dumps(params), + headers={"x-api-key": TOKEN1}) + assert response.status == 400 + assert "Input should be a valid string" in response.text + + for value in [1, "test", ['test']]: + params = artifact.copy() + params['tags'] = value + _, response = await client.post( + Urls.artifacts, data=json.dumps(params), + headers={"x-api-key": TOKEN1}) + assert response.status == 400 + assert "Input should be a valid dictionary" in response.text + + +@pytest.mark.asyncio +async def test_create_unexpected(app): + client = app.asgi_client + await prepare_token() + run = await prepare_run('task_id', 99, 1, 1, {}) + artifact = { + "run_id": run['_id'], + "path": "path/test" + } + + for param in ['_id', 'created_at', 'token', 'test']: + data = artifact.copy() + data[param] = 'test' + _, response = await client.post(Urls.artifacts, + data=json.dumps(data), + headers={"x-api-key": TOKEN1}) + assert response.status == 400 + assert "Extra inputs are not permitted" in response.text + + +@pytest.mark.asyncio +async def test_list_artifacts_empty(app): + client = app.asgi_client + await prepare_token() + _, response = await client.get(Urls.artifacts, + headers={"x-api-key": TOKEN1}) + assert response.status == 200 + assert len(response.json['artifacts']) == 0 + assert response.json['total_count'] == 0 + assert response.json['limit'] == 0 + assert response.json['start_from'] == 0 + + +@pytest.mark.asyncio +async def test_list_invalid_query_params(app): + client = app.asgi_client + await prepare_token() + for param in ['created_at_lt', 'created_at_gt', 'limit', 'start_from']: + query_url = f'?{param}=test' + _, response = await client.get(Urls.artifacts + query_url, + headers={"x-api-key": TOKEN1}) + assert response.status == 400 + assert 'Input should be a valid integer' in response.text + + query_url = f'?{param}=-10' + _, response = await client.get(Urls.artifacts + query_url, + headers={"x-api-key": TOKEN1}) + assert response.status == 400 + assert 'Input should be greater than or equal to 0' in response.text + + query_url = '?created_at_lt=0&created_at_gt=2' + _, response = await client.get(Urls.artifacts + query_url, + headers={"x-api-key": TOKEN1}) + assert response.status == 400 + assert 'Invalid created_at filter values' in response.text + + +@pytest.mark.asyncio +async def test_list_created_at_filter(app): + client = app.asgi_client + await prepare_token() + for param in ['created_at_gt', 'created_at_lt']: + query_url = f'?{param}={2**32}' + _, response = await client.get(Urls.artifacts + query_url, + headers={"x-api-key": TOKEN1}) + assert response.status == 400 + assert 'Input should be less than or equal ' \ + 'to 2147483647' in response.text + + query_url = f'?{param}={2**30}' + _, response = await client.get(Urls.artifacts + query_url, + headers={"x-api-key": TOKEN1}) + assert response.status == 200 + + +@pytest.mark.asyncio +async def test_list_limit_filter(app): + client = app.asgi_client + await prepare_token() + query_url = f'?limit={2**64}' + _, response = await client.get(Urls.artifacts + query_url, + headers={"x-api-key": TOKEN1}) + assert response.status == 400 + assert 'Input should be less than or equal ' \ + 'to 9223372036854775807' in response.text + + query_url = f'?limit={2**62}' + _, response = await client.get(Urls.artifacts + query_url, + headers={"x-api-key": TOKEN1}) + assert response.status == 200 + + +@pytest.mark.asyncio +async def test_list_unexpected_query_params(app): + client = app.asgi_client + await prepare_token() + query_url = '?unexpected=1' + _, response = await client.get(Urls.artifacts + query_url, + headers={"x-api-key": TOKEN1}) + assert response.status == 200 + + +@pytest.mark.asyncio +async def test_list_run_id(app): + client = app.asgi_client + await prepare_token() + task = await prepare_tasks() + run = await prepare_run(task[0]['_id'], 99, 1, 1, {}) + run_id = run['_id'] + now_ts = int(datetime.now(tz=timezone.utc).timestamp()) + await prepare_artifact(run['_id'], created_at=now_ts) + + query_url = f'?run_id={run_id}' + _, response = await client.get(Urls.artifacts + query_url, + headers={"x-api-key": TOKEN1}) + assert response.status == 200 + assert len(response.json['artifacts']) == 1 + + query_url = f'?run_id={run_id}&run_id=run_id' + _, response = await client.get(Urls.artifacts + query_url, + headers={"x-api-key": TOKEN1}) + assert response.status == 200 + assert len(response.json['artifacts']) == 1 + + +@pytest.mark.asyncio +async def test_list_artifacts_created_at(app): + client = app.asgi_client + await prepare_token() + task = await prepare_tasks() + run = await prepare_run(task[0]['_id'], 99, 1, 1, {}) + now_ts = int(datetime.now(tz=timezone.utc).timestamp()) + artifact = await prepare_artifact(run['_id'], created_at=now_ts) + date1_ts = now_ts - 10 + await prepare_artifact(run['_id'], created_at=date1_ts) + date2_ts = now_ts + 10 + await prepare_artifact(run['_id'], created_at=date2_ts) + _, response = await client.get( + Urls.artifacts + f'?created_at_lt={date2_ts}&created_at_gt={date1_ts}', + headers={"x-api-key": TOKEN1}) + assert response.status == 200 + assert len(response.json['artifacts']) == 1 + assert response.json['artifacts'][0]['_id'] == artifact['_id'] + assert response.json['total_count'] == 1 + assert response.json['limit'] == 0 + assert response.json['start_from'] == 0 + + +@pytest.mark.asyncio +async def test_list_artifacts_run_id(app): + client = app.asgi_client + await prepare_token() + task = await prepare_tasks() + run1 = await prepare_run(task[0]['_id'], 99, 1, 1, {}) + run2 = await prepare_run(task[0]['_id'], 99, 1, 1, {}) + run3 = await prepare_run(task[0]['_id'], 99, 1, 1, {}) + artifact1 = await prepare_artifact(run1['_id']) + artifact2 = await prepare_artifact(run2['_id']) + await prepare_artifact(run3['_id']) + _, response = await client.get( + Urls.artifacts + f'?run_id={run1["_id"]}&run_id={run2["_id"]}', + headers={"x-api-key": TOKEN1}) + assert response.status == 200 + assert len(response.json['artifacts']) == 2 + assert response.json['total_count'] == 2 + assert response.json['limit'] == 0 + assert response.json['start_from'] == 0 + for artifact in response.json['artifacts']: + assert artifact['_id'] in [artifact1['_id'], artifact2['_id']] + + +@pytest.mark.asyncio +async def test_list_artifacts_text_like(app): + client = app.asgi_client + await prepare_token() + task = await prepare_tasks() + run = await prepare_run(task[0]['_id'], 99, 1, 1, {}) + artifact1 = await prepare_artifact(run['_id'], name='test1') + artifact2 = await prepare_artifact(run['_id'], description='test2') + artifact3 = await prepare_artifact(run['_id'], tags={'test3': 1}) + artifact4 = await prepare_artifact(run['_id'], tags={'artifact': 'test4'}) + _, response = await client.get(Urls.artifacts + f'?text_like=test1', + headers={"x-api-key": TOKEN1}) + assert response.status == 200 + assert len(response.json['artifacts']) == 1 + assert response.json['total_count'] == 1 + assert response.json['limit'] == 0 + assert response.json['start_from'] == 0 + assert response.json['artifacts'][0]['_id'] == artifact1['_id'] + + _, response = await client.get(Urls.artifacts + f'?text_like=test2', + headers={"x-api-key": TOKEN1}) + assert response.status == 200 + assert len(response.json['artifacts']) == 1 + assert response.json['total_count'] == 1 + assert response.json['limit'] == 0 + assert response.json['start_from'] == 0 + assert response.json['artifacts'][0]['_id'] == artifact2['_id'] + + _, response = await client.get(Urls.artifacts + f'?text_like=test3', + headers={"x-api-key": TOKEN1}) + assert response.status == 200 + assert len(response.json['artifacts']) == 1 + assert response.json['total_count'] == 1 + assert response.json['limit'] == 0 + assert response.json['start_from'] == 0 + assert response.json['artifacts'][0]['_id'] == artifact3['_id'] + + _, response = await client.get(Urls.artifacts + f'?text_like=test4', + headers={"x-api-key": TOKEN1}) + assert response.status == 200 + assert len(response.json['artifacts']) == 1 + assert response.json['total_count'] == 1 + assert response.json['limit'] == 0 + assert response.json['start_from'] == 0 + assert response.json['artifacts'][0]['_id'] == artifact4['_id'] + + +@pytest.mark.asyncio +async def test_list_artifacts_limit(app): + client = app.asgi_client + await prepare_token() + task = await prepare_tasks() + run = await prepare_run(task[0]['_id'], 99, 1, 1, {}) + await prepare_artifact(run['_id'], created_at=1) + artifact2 = await prepare_artifact(run['_id'], created_at=2) + artifact3 = await prepare_artifact(run['_id'], created_at=3) + _, response = await client.get(Urls.artifacts + f'?limit=1', + headers={"x-api-key": TOKEN1}) + assert response.status == 200 + assert len(response.json['artifacts']) == 1 + assert response.json['artifacts'][0]['_id'] == artifact3['_id'] + assert response.json['total_count'] == 3 + assert response.json['limit'] == 1 + assert response.json['start_from'] == 0 + + _, response = await client.get(Urls.artifacts + f'?limit=1&start_from=1', + headers={"x-api-key": TOKEN1}) + assert response.status == 200 + assert len(response.json['artifacts']) == 1 + assert response.json['artifacts'][0]['_id'] == artifact2['_id'] + assert response.json['total_count'] == 3 + assert response.json['limit'] == 1 + assert response.json['start_from'] == 1 + + +@pytest.mark.asyncio +async def test_list_artifacts_total_count(app): + client = app.asgi_client + await prepare_token() + task = await prepare_tasks() + run = await prepare_run(task[0]['_id'], 99, 1, 1, {}) + await prepare_artifact(run['_id'], created_at=1) + await prepare_artifact(run['_id'], created_at=2) + await prepare_artifact(run['_id'], created_at=3) + + _, response = await client.get(Urls.artifacts, + headers={"x-api-key": TOKEN1}) + assert response.status == 200 + assert response.json['total_count'] == 3 + + _, response = await client.get(Urls.artifacts + f'?limit=1', + headers={"x-api-key": TOKEN1}) + assert response.status == 200 + assert response.json['total_count'] == 3 + + _, response = await client.get(Urls.artifacts + f'?start_from=1', + headers={"x-api-key": TOKEN1}) + assert response.status == 200 + assert response.json['total_count'] == 3 + + _, response = await client.get(Urls.artifacts + f'?start_from=10', + headers={"x-api-key": TOKEN1}) + assert response.status == 200 + assert response.json['total_count'] == 3 + + +@pytest.mark.asyncio +async def test_patch_empty(app): + client = app.asgi_client + await prepare_token() + run = await prepare_run('task_id', 99, 1, 1, {}) + artifact = await prepare_artifact(run['_id']) + _, response = await client.patch( + Urls.artifact.format(artifact['_id']), + data=json.dumps({}), + headers={"x-api-key": TOKEN1}) + assert response.status == 200 + + +@pytest.mark.asyncio +async def test_patch_invalid_params_types(app): + client = app.asgi_client + await prepare_token() + artifact = await prepare_artifact(TOKEN1) + for param in ["description", "name", "path"]: + for value in [1, {"test": 1}, ['test']]: + updates = { + param: value + } + _, response = await client.patch( + Urls.artifact.format(artifact['_id']), + data=json.dumps(updates), + headers={"x-api-key": TOKEN1}) + assert response.status == 400 + assert "Input should be a valid string" in response.text + + for value in [1, "test", ['test']]: + updates = { + "tags": value + } + _, response = await client.patch( + Urls.artifact.format(artifact['_id']), + data=json.dumps(updates), + headers={"x-api-key": TOKEN1}) + assert response.status == 400 + assert "Input should be a valid dictionary" in response.text + + +@pytest.mark.asyncio +async def test_patch_unexpected(app): + client = app.asgi_client + await prepare_token() + artifact = await prepare_artifact(TOKEN1) + for param in ['_id', 'created_at', 'run_id', 'token', 'test']: + updates = {param: 'test'} + _, response = await client.patch( + Urls.artifact.format(artifact['_id']), + data=json.dumps(updates), + headers={"x-api-key": TOKEN1}) + assert response.status == 400 + assert "Extra inputs are not permitted" in response.text + + +@pytest.mark.asyncio +async def test_patch_artifact(app): + client = app.asgi_client + await prepare_token() + task = await prepare_tasks() + run = await prepare_run(task[0]['_id'], 99, 1, 1, {}) + artifact = await prepare_artifact(run_id=run['_id']) + updates = { + "name": "new", + "description": "new", + "tags": {"new": "new"}, + } + _, response = await client.patch(Urls.artifact.format(artifact['_id']), + data=json.dumps(updates), + headers={"x-api-key": TOKEN1}) + assert response.status == 200 + assert response.json['run']['_id'] == artifact.pop('run_id', None) + for key, value in updates.items(): + assert response.json[key] == value + + +@pytest.mark.asyncio +async def test_patch_artifact_empty_path(app): + client = app.asgi_client + await prepare_token() + task = await prepare_tasks() + run = await prepare_run(task[0]['_id'], 99, 1, 1, {}) + artifact = await prepare_artifact(run_id=run['_id']) + updates = { + "path": None + } + _, response = await client.patch(Urls.artifact.format(artifact['_id']), + data=json.dumps(updates), + headers={"x-api-key": TOKEN1}) + assert response.status == 400 + assert 'Input should be a valid string' in response.text + + +@pytest.mark.asyncio +async def test_patch_not_existing(app): + client = app.asgi_client + await prepare_token() + updates = { + "name": "my artifact" + } + _, response = await client.patch(Urls.artifact.format('artifact_id'), + data=json.dumps(updates), + headers={"x-api-key": TOKEN1}) + assert response.status == 404 + assert "Artifact not found" in response.text + + +@pytest.mark.asyncio +async def test_get_artifact(app): + client = app.asgi_client + await prepare_token() + run = await prepare_run('task_id', 99, 1, 1, {}) + artifact = await prepare_artifact(run['_id']) + _, response = await client.get(Urls.artifact.format(artifact['_id']), + headers={"x-api-key": TOKEN1}) + assert response.status == 200 + assert response.json['token'] == TOKEN1 + + +@pytest.mark.asyncio +async def test_get_not_existing(app): + client = app.asgi_client + await prepare_token() + _, response = await client.get(Urls.artifact.format('artifact_id'), + headers={"x-api-key": TOKEN1}) + assert response.status == 404 + assert "Artifact not found" in response.text + + +@pytest.mark.asyncio +async def test_delete_artifact(app): + client = app.asgi_client + await prepare_token() + run = await prepare_run('task_id', 99, 1, 1, {}) + artifact = await prepare_artifact(run['_id']) + _, response = await client.delete(Urls.artifact.format(artifact['_id']), + headers={"x-api-key": TOKEN1}) + assert response.status == 204 + + +@pytest.mark.asyncio +async def test_delete_not_existing(app): + client = app.asgi_client + await prepare_token() + _, response = await client.delete(Urls.artifact.format('artifact_id'), + headers={"x-api-key": TOKEN1}) + assert response.status == 404 + assert "Artifact not found" in response.text diff --git a/docker_images/cleanmongodb/clean-mongo-db.py b/docker_images/cleanmongodb/clean-mongo-db.py index b130ec15..0e2f9ef7 100644 --- a/docker_images/cleanmongodb/clean-mongo-db.py +++ b/docker_images/cleanmongodb/clean-mongo-db.py @@ -36,6 +36,7 @@ def __init__(self): self.mongo_client.arcee.proc_data: ROWS_LIMIT, self.mongo_client.arcee.stage: ROWS_LIMIT, self.mongo_client.arcee.model_version: ROWS_LIMIT, + self.mongo_client.arcee.artifact: ROWS_LIMIT, # linked to task_id self.mongo_client.arcee.run: ROWS_LIMIT, # linked to profiling_token.token @@ -178,7 +179,8 @@ def _delete_runs(self, runs_ids_chunk): self.mongo_client.arcee.milestone, self.mongo_client.arcee.stage, self.mongo_client.arcee.proc_data, - self.mongo_client.arcee.model_version] + self.mongo_client.arcee.model_version, + self.mongo_client.arcee.artifact] if all(self.limits.get(x) == 0 for x in run_collections): # maximum number of entities related to runs have already # been deleted @@ -353,6 +355,7 @@ def organization_limits(self): self.mongo_client.arcee.proc_data, self.mongo_client.arcee.model, self.mongo_client.arcee.model_version, + self.mongo_client.arcee.artifact, self.mongo_client.bulldozer.template, self.mongo_client.bulldozer.runset, self.mongo_client.bulldozer.runner] diff --git a/ngui/ui/src/api/index.ts b/ngui/ui/src/api/index.ts index 70834e28..d2c83d8d 100644 --- a/ngui/ui/src/api/index.ts +++ b/ngui/ui/src/api/index.ts @@ -220,7 +220,12 @@ import { updateMlModel, deleteMlModel, getMlTaskModelVersions, - updateMlModelVersion + updateMlModelVersion, + getMlArtifacts, + getMlArtifact, + updateMlArtifact, + createMlArtifact, + deleteMlArtifact } from "./restapi"; import { RESTAPI } from "./restapi/reducer"; @@ -445,7 +450,12 @@ export { updateMlModel, deleteMlModel, getMlTaskModelVersions, - updateMlModelVersion + updateMlModelVersion, + getMlArtifacts, + getMlArtifact, + updateMlArtifact, + createMlArtifact, + deleteMlArtifact }; export { RESTAPI, AUTH, JIRA_BUS }; diff --git a/ngui/ui/src/api/restapi/actionCreators.ts b/ngui/ui/src/api/restapi/actionCreators.ts index ad8692ce..15a83c15 100644 --- a/ngui/ui/src/api/restapi/actionCreators.ts +++ b/ngui/ui/src/api/restapi/actionCreators.ts @@ -315,7 +315,14 @@ import { DELETE_ML_MODEL, GET_ML_TASK_MODEL_VERSIONS, UPDATE_ML_MODEL_VERSION, - SET_ML_TASK_MODEL_VERSIONS + SET_ML_TASK_MODEL_VERSIONS, + GET_ML_ARTIFACTS, + SET_ML_ARTIFACTS, + SET_ML_ARTIFACT, + GET_ML_ARTIFACT, + UPDATE_ML_ARTIFACT, + CREATE_ML_ARTIFACT, + DELETE_ML_ARTIFACT } from "./actionTypes"; import { onUpdateOrganizationOption, @@ -353,7 +360,8 @@ import { onUpdateS3DuplicatesOrganizationSettings, onUpdateMlLeaderboardDataset, onUpdatePowerSchedule, - onUpdateMlModel + onUpdateMlModel, + onUpdateMlArtifact } from "./handlers"; export const API_URL = getApiUrl("restapi"); @@ -2374,6 +2382,67 @@ export const getMlExecutorsBreakdown = (organizationId) => hash: hashParams(organizationId) }); +export const getMlArtifacts = (organizationId, params = {}) => + apiAction({ + url: `${API_URL}/organizations/${organizationId}/artifacts`, + method: "GET", + ttl: 5 * MINUTE, + onSuccess: handleSuccess(SET_ML_ARTIFACTS), + hash: hashParams({ organizationId, ...params }), + label: GET_ML_ARTIFACTS, + params: { + limit: params.limit, + run_id: params.runId, + start_from: params.startFrom, + text_like: params.textLike, + created_at_gt: params.createdAtGt, + created_at_lt: params.createdAtLt + } + }); + +export const getMlArtifact = (organizationId, artifactId) => + apiAction({ + url: `${API_URL}/organizations/${organizationId}/artifacts/${artifactId}`, + method: "GET", + ttl: 5 * MINUTE, + onSuccess: handleSuccess(SET_ML_ARTIFACT), + hash: hashParams({ organizationId, artifactId }), + label: GET_ML_ARTIFACT + }); + +export const updateMlArtifact = (organizationId, artifactId, params) => + apiAction({ + url: `${API_URL}/organizations/${organizationId}/artifacts/${artifactId}`, + method: "PATCH", + label: UPDATE_ML_ARTIFACT, + onSuccess: onUpdateMlArtifact, + params, + affectedRequests: [GET_ML_ARTIFACTS] + }); + +export const createMlArtifact = (organizationId, params) => + apiAction({ + url: `${API_URL}/organizations/${organizationId}/artifacts`, + method: "POST", + label: CREATE_ML_ARTIFACT, + params: { + name: params.name, + path: params.path, + description: params.description, + tags: params.tags, + run_id: params.runId + }, + affectedRequests: [GET_ML_ARTIFACTS] + }); + +export const deleteMlArtifact = (organizationId, artifactId) => + apiAction({ + url: `${API_URL}/organizations/${organizationId}/artifacts/${artifactId}`, + method: "DELETE", + label: DELETE_ML_ARTIFACT, + affectedRequests: [GET_ML_ARTIFACTS] + }); + export const getReservedInstancesBreakdown = (organizationId, params) => apiAction({ url: `${API_URL}/organizations/${organizationId}/ri_breakdown`, diff --git a/ngui/ui/src/api/restapi/actionTypes.ts b/ngui/ui/src/api/restapi/actionTypes.ts index 94d16133..cb215b25 100644 --- a/ngui/ui/src/api/restapi/actionTypes.ts +++ b/ngui/ui/src/api/restapi/actionTypes.ts @@ -398,6 +398,14 @@ export const STOP_ML_RUNSET = "STOP_ML_RUNSET"; export const GET_ML_RUNSET_EXECUTORS = "GET_ML_RUNSET_EXECUTORS"; export const SET_ML_RUNSET_EXECUTORS = "SET_ML_RUNSET_EXECUTORS"; +export const GET_ML_ARTIFACTS = "GET_ML_ARTIFACTS"; +export const SET_ML_ARTIFACTS = "SET_ML_ARTIFACTS"; +export const GET_ML_ARTIFACT = "GET_ML_ARTIFACT"; +export const SET_ML_ARTIFACT = "SET_ML_ARTIFACT"; +export const CREATE_ML_ARTIFACT = "CREATE_ML_ARTIFACT"; +export const UPDATE_ML_ARTIFACT = "UPDATE_ML_ARTIFACT"; +export const DELETE_ML_ARTIFACT = "DELETE_ML_ARTIFACT"; + export const GET_ORGANIZATION_BI_EXPORT = "GET_ORGANIZATION_BI_EXPORT"; export const SET_ORGANIZATION_BI_EXPORTS = "SET_ORGANIZATION_BI_EXPORTS"; export const CREATE_ORGANIZATION_BI_EXPORT = "CREATE_ORGANIZATION_BI_EXPORT"; diff --git a/ngui/ui/src/api/restapi/handlers.ts b/ngui/ui/src/api/restapi/handlers.ts index dd52d56f..ab042b59 100644 --- a/ngui/ui/src/api/restapi/handlers.ts +++ b/ngui/ui/src/api/restapi/handlers.ts @@ -60,7 +60,9 @@ import { SET_POWER_SCHEDULE, GET_POWER_SCHEDULE, SET_ML_MODEL, - GET_ML_MODEL + GET_ML_MODEL, + SET_ML_ARTIFACT, + GET_ML_ARTIFACT } from "./actionTypes"; export const onUpdateOrganizationOption = (data) => ({ @@ -246,6 +248,12 @@ export const onUpdateMlTask = (data) => ({ label: GET_ML_TASK }); +export const onUpdateMlArtifact = (data) => ({ + type: SET_ML_ARTIFACT, + payload: data, + label: GET_ML_ARTIFACT +}); + export const onUpdateMlModel = (data) => ({ type: SET_ML_MODEL, payload: data, diff --git a/ngui/ui/src/api/restapi/index.ts b/ngui/ui/src/api/restapi/index.ts index 608c7385..3f336fd7 100644 --- a/ngui/ui/src/api/restapi/index.ts +++ b/ngui/ui/src/api/restapi/index.ts @@ -207,7 +207,12 @@ import { updateMlModel, deleteMlModel, getMlTaskModelVersions, - updateMlModelVersion + updateMlModelVersion, + getMlArtifacts, + getMlArtifact, + updateMlArtifact, + createMlArtifact, + deleteMlArtifact } from "./actionCreators"; export { @@ -419,5 +424,10 @@ export { updateMlModel, deleteMlModel, getMlTaskModelVersions, - updateMlModelVersion + updateMlModelVersion, + getMlArtifacts, + getMlArtifact, + updateMlArtifact, + createMlArtifact, + deleteMlArtifact }; diff --git a/ngui/ui/src/api/restapi/reducer.ts b/ngui/ui/src/api/restapi/reducer.ts index 4b9f46bf..ba721212 100644 --- a/ngui/ui/src/api/restapi/reducer.ts +++ b/ngui/ui/src/api/restapi/reducer.ts @@ -124,7 +124,9 @@ import { SET_RESERVED_INSTANCES_BREAKDOWN, SET_SAVING_PLANS_BREAKDOWN, SET_ML_MODEL, - SET_ML_TASK_MODEL_VERSIONS + SET_ML_TASK_MODEL_VERSIONS, + SET_ML_ARTIFACTS, + SET_ML_ARTIFACT } from "./actionTypes"; export const RESTAPI = "restapi"; @@ -839,6 +841,18 @@ const reducer = (state = {}, action) => { [action.label]: action.payload }; } + case SET_ML_ARTIFACTS: { + return { + ...state, + [action.label]: action.payload + }; + } + case SET_ML_ARTIFACT: { + return { + ...state, + [action.label]: action.payload + }; + } case SET_ORGANIZATION_BI_EXPORTS: { return { ...state, diff --git a/ngui/ui/src/components/ArtifactsTable/ArtifactsTable.tsx b/ngui/ui/src/components/ArtifactsTable/ArtifactsTable.tsx new file mode 100644 index 00000000..1c1c80a9 --- /dev/null +++ b/ngui/ui/src/components/ArtifactsTable/ArtifactsTable.tsx @@ -0,0 +1,161 @@ +import { useMemo } from "react"; +import DeleteOutlinedIcon from "@mui/icons-material/DeleteOutlined"; +import EditOutlinedIcon from "@mui/icons-material/EditOutlined"; +import { FormattedMessage } from "react-intl"; +import { useNavigate } from "react-router-dom"; +import { TABS } from "components/MlTaskRun"; +import { MlDeleteArtifactModal } from "components/SideModalManager/SideModals"; +import Table from "components/Table"; +import TableCellActions from "components/TableCellActions"; +import TextWithDataTestId from "components/TextWithDataTestId"; +import { Pagination, RangeFilter, Search } from "containers/MlArtifactsContainer/MlArtifactsContainer"; +import { useIsAllowed } from "hooks/useAllowedActions"; +import { useOpenSideModal } from "hooks/useOpenSideModal"; +import { Artifact } from "services/MlArtifactsService"; +import { getEditMlArtifactUrl } from "urls"; +import { markdown, run, slicedText, tags, utcTime } from "utils/columns"; +import { TAB_QUERY_PARAM_NAME } from "utils/constants"; + +type ArtifactsTableProps = { + artifacts: Artifact[]; + pagination: Pagination; + search?: Search; + rangeFilter?: RangeFilter; +}; + +const ArtifactsTable = ({ artifacts, pagination, search, rangeFilter }: ArtifactsTableProps) => { + const openSideModal = useOpenSideModal(); + + const tableData = useMemo(() => artifacts, [artifacts]); + + const isManageArtifactsAllowed = useIsAllowed({ + requiredActions: ["EDIT_PARTNER"] + }); + + const navigate = useNavigate(); + + const columns = useMemo(() => { + const getActionsColumn = () => ({ + header: ( + + + + ), + enableSorting: false, + id: "actions", + cell: ({ + row: { + original: { id: artifactId, name, index } + } + }) => ( + , + requiredActions: ["EDIT_PARTNER"], + dataTestId: `btn_edit_${index}`, + action: () => navigate(getEditMlArtifactUrl(artifactId)) + }, + { + key: "delete", + messageId: "delete", + icon: , + color: "error", + requiredActions: ["EDIT_PARTNER"], + dataTestId: `btn_delete_${index}`, + action: () => + openSideModal(MlDeleteArtifactModal, { + id: artifactId, + name, + onSuccess: () => { + const isLastArtifactOnPage = artifacts.length === 1; + + if (isLastArtifactOnPage) { + pagination.onPageIndexChange(Math.max(pagination.pageIndex - 1, 0)); + } + } + }) + } + ]} + /> + ) + }); + + return [ + slicedText({ + headerMessageId: "name", + headerDataTestId: "lbl_name", + accessorKey: "name", + maxTextLength: 70, + enableSorting: false + }), + slicedText({ + headerMessageId: "path", + headerDataTestId: "lbl_path", + accessorKey: "path", + maxTextLength: 70, + copy: true, + enableSorting: false + }), + markdown({ + id: "description", + accessorFn: (originalRow) => originalRow.description, + headerMessageId: "description", + headerDataTestId: "lbl_description", + enableSorting: false + }), + run({ + id: "run", + getRunNumber: ({ run: { number } }) => number, + getRunName: ({ run: { name } }) => name, + getRunId: ({ run: { id } }) => id, + getTaskId: ({ run: { task_id: taskId } }) => taskId, + headerMessageId: "run", + headerDataTestId: "lbl_run", + enableSorting: false, + runDetailsUrlOptions: { + [TAB_QUERY_PARAM_NAME]: TABS.ARTIFACTS + } + }), + utcTime({ + id: "createdAt", + accessorFn: (originalRow) => originalRow.created_at, + headerMessageId: "createdAt", + headerDataTestId: "lbl_created_at", + enableSorting: false + }), + tags({ + id: "tags", + accessorFn: (originalRow) => + Object.entries(originalRow.tags ?? {}) + .map(([key, val]) => `${key}: ${val}`) + .join(" "), + getTags: (originalRow) => originalRow.tags, + enableSorting: false + }), + ...(isManageArtifactsAllowed ? [getActionsColumn()] : []) + ]; + }, [artifacts.length, isManageArtifactsAllowed, navigate, openSideModal, pagination]); + + return ( + + ); +}; + +export default ArtifactsTable; diff --git a/ngui/ui/src/components/ArtifactsTable/index.ts b/ngui/ui/src/components/ArtifactsTable/index.ts new file mode 100644 index 00000000..8f116145 --- /dev/null +++ b/ngui/ui/src/components/ArtifactsTable/index.ts @@ -0,0 +1,3 @@ +import ArtifactsTable from "./ArtifactsTable"; + +export default ArtifactsTable; diff --git a/ngui/ui/src/components/MlArtifacts/MlArtifacts.tsx b/ngui/ui/src/components/MlArtifacts/MlArtifacts.tsx new file mode 100644 index 00000000..a2fc788a --- /dev/null +++ b/ngui/ui/src/components/MlArtifacts/MlArtifacts.tsx @@ -0,0 +1,43 @@ +import RefreshOutlinedIcon from "@mui/icons-material/RefreshOutlined"; +import { FormattedMessage } from "react-intl"; +import { GET_ML_ARTIFACTS } from "api/restapi/actionTypes"; +import ActionBar from "components/ActionBar"; +import ArtifactsTable from "components/ArtifactsTable"; +import PageContentWrapper from "components/PageContentWrapper"; +import MlArtifactsContainer from "containers/MlArtifactsContainer"; +import { useRefetchApis } from "hooks/useRefetchApis"; + +const MlArtifacts = () => { + const refetch = useRefetchApis(); + + const actionBarDefinition = { + title: { + text: , + dataTestId: "lbl_artifacts" + }, + items: [ + { + key: "btn-refresh", + icon: , + messageId: "refresh", + dataTestId: "btn_refresh", + type: "button", + action: () => refetch([GET_ML_ARTIFACTS]) + } + ] + }; + return ( + <> + + + ( + + )} + /> + + + ); +}; + +export default MlArtifacts; diff --git a/ngui/ui/src/components/MlArtifacts/index.ts b/ngui/ui/src/components/MlArtifacts/index.ts new file mode 100644 index 00000000..1feefd0c --- /dev/null +++ b/ngui/ui/src/components/MlArtifacts/index.ts @@ -0,0 +1,3 @@ +import MlArtifacts from "./MlArtifacts"; + +export default MlArtifacts; diff --git a/ngui/ui/src/components/MlTaskRun/Components/Overview.tsx b/ngui/ui/src/components/MlTaskRun/Components/Overview.tsx index a9f83eec..885cf0bf 100644 --- a/ngui/ui/src/components/MlTaskRun/Components/Overview.tsx +++ b/ngui/ui/src/components/MlTaskRun/Components/Overview.tsx @@ -227,7 +227,7 @@ const Overview = ({ reachedGoals, dataset, git, tags, hyperparameters, command, - + diff --git a/ngui/ui/src/components/MlTaskRun/MlTaskRun.tsx b/ngui/ui/src/components/MlTaskRun/MlTaskRun.tsx index 5f5a14d2..9a17e091 100644 --- a/ngui/ui/src/components/MlTaskRun/MlTaskRun.tsx +++ b/ngui/ui/src/components/MlTaskRun/MlTaskRun.tsx @@ -3,11 +3,12 @@ import RefreshOutlinedIcon from "@mui/icons-material/RefreshOutlined"; import { Link, Stack, Typography } from "@mui/material"; import { FormattedMessage } from "react-intl"; import { Link as RouterLink } from "react-router-dom"; -import { GET_ML_EXECUTORS, GET_ML_RUN_DETAILS, GET_ML_RUN_DETAILS_BREAKDOWN } from "api/restapi/actionTypes"; +import { GET_ML_ARTIFACTS, GET_ML_EXECUTORS, GET_ML_RUN_DETAILS, GET_ML_RUN_DETAILS_BREAKDOWN } from "api/restapi/actionTypes"; import ActionBar from "components/ActionBar"; import PageContentWrapper from "components/PageContentWrapper"; import TabsWrapper from "components/TabsWrapper"; import ExecutionBreakdownContainer from "containers/ExecutionBreakdownContainer"; +import RunArtifactsContainer from "containers/RunArtifactsContainer"; import { useRefetchApis } from "hooks/useRefetchApis"; import { ML_TASKS, getMlTaskDetailsUrl } from "urls"; import { SPACING_2 } from "utils/layouts"; @@ -15,8 +16,9 @@ import { formatRunFullName } from "utils/ml"; import { Executors, Overview } from "./Components"; import Status from "./Components/Status"; -const TABS = Object.freeze({ +export const TABS = Object.freeze({ OVERVIEW: "overview", + ARTIFACTS: "artifacts", CHARTS: "charts", EXECUTORS: "executors" }); @@ -49,6 +51,11 @@ const Tabs = ({ run, isLoading = false }) => { dataTestId: "tab_charts", node: }, + { + title: TABS.ARTIFACTS, + dataTestId: "tab_artifact", + node: + }, { title: TABS.EXECUTORS, dataTestId: "tab_executors", @@ -97,7 +104,7 @@ const MlTaskRun = ({ run, isLoading = false }) => { messageId: "refresh", dataTestId: "btn_refresh", type: "button", - action: () => refetch([GET_ML_RUN_DETAILS, GET_ML_EXECUTORS, GET_ML_RUN_DETAILS_BREAKDOWN]) + action: () => refetch([GET_ML_RUN_DETAILS, GET_ML_EXECUTORS, GET_ML_RUN_DETAILS_BREAKDOWN, GET_ML_ARTIFACTS]) } ] }; diff --git a/ngui/ui/src/components/MlTaskRun/index.ts b/ngui/ui/src/components/MlTaskRun/index.ts index f5510c69..d47d7e1b 100644 --- a/ngui/ui/src/components/MlTaskRun/index.ts +++ b/ngui/ui/src/components/MlTaskRun/index.ts @@ -1,3 +1,4 @@ -import MlTaskRun from "./MlTaskRun"; +import MlTaskRun, { TABS } from "./MlTaskRun"; +export { TABS }; export default MlTaskRun; diff --git a/ngui/ui/src/components/ProfilingIntegration/ProfilingIntegration.tsx b/ngui/ui/src/components/ProfilingIntegration/ProfilingIntegration.tsx index c1ed7813..20116cfc 100644 --- a/ngui/ui/src/components/ProfilingIntegration/ProfilingIntegration.tsx +++ b/ngui/ui/src/components/ProfilingIntegration/ProfilingIntegration.tsx @@ -468,6 +468,116 @@ const SettingModelVersionTag = () => ( ); +const CreatingArtifact = () => ( + <> + + + + + + + +
    +
  • + + {chunks} + }} + /> + +
  • +
  • + + {chunks} + }} + /> + +
  • +
  • + + {chunks} + }} + /> + +
  • +
  • + + {chunks}, + i: (chunks) => {chunks} + }} + /> + +
  • +
+ + +); + +const SettingArtifactTag = () => ( + <> + + + + + + + +
    +
  • + + {chunks} + }} + /> + +
  • +
  • + + {chunks} + }} + /> + +
  • +
  • + + {chunks} + }} + /> + +
  • +
+ + +); + const FinishTaskRun = () => ( <> @@ -557,6 +667,12 @@ const ProfilingIntegration = ({ profilingToken, taskKey, isLoading }) => (
+
+ +
+
+ +
diff --git a/ngui/ui/src/components/RunArtifactsTable/RunArtifactsTable.tsx b/ngui/ui/src/components/RunArtifactsTable/RunArtifactsTable.tsx new file mode 100644 index 00000000..fd6c173e --- /dev/null +++ b/ngui/ui/src/components/RunArtifactsTable/RunArtifactsTable.tsx @@ -0,0 +1,164 @@ +import { useMemo } from "react"; +import AddOutlinedIcon from "@mui/icons-material/AddOutlined"; +import DeleteOutlinedIcon from "@mui/icons-material/DeleteOutlined"; +import EditOutlinedIcon from "@mui/icons-material/EditOutlined"; +import { FormattedMessage } from "react-intl"; +import { useNavigate, useParams } from "react-router-dom"; +import { MlDeleteArtifactModal } from "components/SideModalManager/SideModals"; +import Table from "components/Table"; +import TableCellActions from "components/TableCellActions"; +import TextWithDataTestId from "components/TextWithDataTestId"; +import { Pagination, Search } from "containers/MlArtifactsContainer/MlArtifactsContainer"; +import { useIsAllowed } from "hooks/useAllowedActions"; +import { useOpenSideModal } from "hooks/useOpenSideModal"; +import { Artifact } from "services/MlArtifactsService"; +import { getCreateMlRunArtifactUrl, getEditMlRunArtifactUrl } from "urls"; +import { markdown, slicedText, tags, utcTime } from "utils/columns"; + +type RunArtifactsTableProps = { + artifacts: Artifact[]; + pagination: Pagination; + search: Search; +}; + +const RunArtifactsTable = ({ artifacts, pagination, search }: RunArtifactsTableProps) => { + const { taskId, runId } = useParams() as { taskId: string; runId: string }; + + const navigate = useNavigate(); + + const openSideModal = useOpenSideModal(); + + const isManageArtifactsAllowed = useIsAllowed({ + requiredActions: ["EDIT_PARTNER"] + }); + + const tableData = useMemo(() => artifacts, [artifacts]); + + const columns = useMemo(() => { + const getActionsColumn = () => ({ + header: ( + + + + ), + enableSorting: false, + id: "actions", + cell: ({ + row: { + original: { id: artifactId, name, run: { task_id: artifactRunTaskId, id: artifactRunId } = {}, index } + } + }) => ( + , + requiredActions: ["EDIT_PARTNER"], + dataTestId: `btn_edit_${index}`, + action: () => navigate(getEditMlRunArtifactUrl(artifactRunTaskId, artifactRunId, artifactId)) + }, + { + key: "delete", + messageId: "delete", + icon: , + color: "error", + requiredActions: ["EDIT_PARTNER"], + dataTestId: `btn_delete_${index}`, + action: () => + openSideModal(MlDeleteArtifactModal, { + id: artifactId, + name, + onSuccess: () => { + const isLastArtifactOnPage = artifacts.length === 1; + + if (isLastArtifactOnPage) { + pagination.onPageIndexChange(Math.max(pagination.pageIndex - 1, 0)); + } + } + }) + } + ]} + /> + ) + }); + + return [ + slicedText({ + headerMessageId: "name", + headerDataTestId: "lbl_name", + accessorKey: "name", + maxTextLength: 70, + enableSorting: false + }), + slicedText({ + headerMessageId: "path", + headerDataTestId: "lbl_path", + accessorKey: "path", + maxTextLength: 70, + copy: true, + enableSorting: false + }), + markdown({ + id: "description", + accessorFn: (originalRow) => originalRow.description, + headerMessageId: "description", + headerDataTestId: "lbl_description", + enableSorting: false + }), + utcTime({ + id: "createdAt", + accessorFn: (originalRow) => originalRow.created_at, + headerMessageId: "createdAt", + headerDataTestId: "lbl_created_at", + enableSorting: false + }), + tags({ + id: "tags", + accessorFn: (originalRow) => + Object.entries(originalRow.tags ?? {}) + .map(([key, val]) => `${key}: ${val}`) + .join(" "), + getTags: (originalRow) => originalRow.tags, + enableSorting: false + }), + ...(isManageArtifactsAllowed ? [getActionsColumn()] : []) + ]; + }, [artifacts.length, isManageArtifactsAllowed, navigate, openSideModal, pagination]); + + return ( +
, + messageId: "add", + color: "success", + variant: "contained", + type: "button", + link: getCreateMlRunArtifactUrl(taskId, runId), + dataTestId: "btn_add", + requiredActions: ["EDIT_PARTNER"] + } + ] + } + }} + localization={{ emptyMessageId: "noArtifacts" }} + manualPagination={pagination} + withSearch + manualGlobalFiltering={{ + search + }} + counters={{ + showCounters: true + }} + /> + ); +}; + +export default RunArtifactsTable; diff --git a/ngui/ui/src/components/RunArtifactsTable/index.ts b/ngui/ui/src/components/RunArtifactsTable/index.ts new file mode 100644 index 00000000..8433e6f4 --- /dev/null +++ b/ngui/ui/src/components/RunArtifactsTable/index.ts @@ -0,0 +1,3 @@ +import RunArtifactsTable from "./RunArtifactsTable"; + +export default RunArtifactsTable; diff --git a/ngui/ui/src/components/SideModalManager/SideModals/MlDeleteArtifactModal.tsx b/ngui/ui/src/components/SideModalManager/SideModals/MlDeleteArtifactModal.tsx new file mode 100644 index 00000000..c9848fa1 --- /dev/null +++ b/ngui/ui/src/components/SideModalManager/SideModals/MlDeleteArtifactModal.tsx @@ -0,0 +1,28 @@ +import MlDeleteArtifactContainer from "containers/MlDeleteArtifactContainer"; +import BaseSideModal from "./BaseSideModal"; + +class MlDeleteArtifactModal extends BaseSideModal { + headerProps = { + messageId: "mlDeleteArtifactTitle", + color: "error", + dataTestIds: { + title: "lbl_delete_ml_artifact", + closeButton: "btn_close" + } + }; + + dataTestId = "smodal_delete"; + + get content() { + return ( + + ); + } +} + +export default MlDeleteArtifactModal; diff --git a/ngui/ui/src/components/SideModalManager/SideModals/index.ts b/ngui/ui/src/components/SideModalManager/SideModals/index.ts index 227916e0..10e24e44 100644 --- a/ngui/ui/src/components/SideModalManager/SideModals/index.ts +++ b/ngui/ui/src/components/SideModalManager/SideModals/index.ts @@ -43,6 +43,7 @@ import EnvironmentCostModelModal from "./EnvironmentCostModelModal"; import ExcludePoolsFromRecommendationModal from "./ExcludePoolsFromRecommendationModal"; import KubernetesIntegrationModal from "./KubernetesIntegrationModal"; import LeaderboardRunGroupDetailsModal from "./LeaderboardRunGroupDetailsModal"; +import MlDeleteArtifactModal from "./MlDeleteArtifactModal"; import MlDeleteDatasetModal from "./MlDeleteDatasetModal"; import MlDeleteTaskModal from "./MlDeleteTaskModal"; import PoolModal from "./PoolModal"; @@ -131,5 +132,6 @@ export { DeleteMlModelModal, EditModelVersionAliasModal, EditModelPathModal, - EditModelVersionTagsModal + EditModelVersionTagsModal, + MlDeleteArtifactModal }; diff --git a/ngui/ui/src/components/Table/Table.tsx b/ngui/ui/src/components/Table/Table.tsx index b0f7898d..ac6736d0 100644 --- a/ngui/ui/src/components/Table/Table.tsx +++ b/ngui/ui/src/components/Table/Table.tsx @@ -87,7 +87,9 @@ const Table = ({ disableBottomBorderForLastRow = false, tableLayout = "auto", enableSearchQueryParam, - enablePaginationQueryParam + enablePaginationQueryParam, + manualPagination, + manualGlobalFiltering }) => { const { showCounters = false, hideTotal = false, hideDisplayed = false } = counters; @@ -164,9 +166,12 @@ const Table = ({ const initialSortingState = useInitialSortingState(columns); + const isManualPagination = !!manualPagination; + const isManualFiltering = !!manualGlobalFiltering; + const tableState = { - ...globalFilterState, - ...paginationState, + ...(isManualFiltering ? {} : globalFilterState), + ...(isManualPagination ? {} : paginationState), ...expandedState, ...columnOrderState, ...rowSelectionState, @@ -174,8 +179,8 @@ const Table = ({ }; const tableOptions = { - ...globalFilterTableOptions, - ...paginationTableOptions, + ...(isManualFiltering ? {} : globalFilterTableOptions), + ...(isManualPagination ? {} : paginationTableOptions), ...expandedTableOptions, ...columnsVisibilityTableOptions, ...columnOrderTableOptions, @@ -208,17 +213,56 @@ const Table = ({ const { rows } = table.getRowModel(); + const getPaginationSettings = () => { + if (isManualPagination) { + return { + pageCount: manualPagination.pageCount, + pageIndex: manualPagination.pageIndex, + pageSize: manualPagination.pageSize, + onPageIndexChange: manualPagination.onPageIndexChange, + totalRows: manualPagination.totalRows + }; + } + + return { + pageCount: table.getPageCount(), + pageIndex: table.getState().pagination.pageIndex, + pageSize: table.getState().pagination.pageSize, + onPageIndexChange: table.setPageIndex, + totalRows: data.length + }; + }; + + const paginationSettings = getPaginationSettings(); + + const getFilterOptions = () => { + if (isManualFiltering) { + return { + withSearch, + onSearchChange: withSearch ? manualGlobalFiltering.search.onChange : undefined, + searchValue: withSearch ? manualGlobalFiltering.search.value : undefined, + rangeFilter, + rangeValue: rangeFilter ? manualGlobalFiltering.rangeFilter.range : undefined, + onRangeChange: rangeFilter ? manualGlobalFiltering.rangeFilter.onChange : undefined + }; + } + + return { + withSearch, + searchValue: withSearch ? globalFilterState.globalFilter.search : undefined, + onSearchChange: withSearch ? (newSearchValue) => onSearchChange(newSearchValue, { tableContext: table }) : null, + rangeFilter, + rangeValue: rangeFilter ? globalFilterState.globalFilter.range : undefined, + onRangeChange: rangeFilter ? (newRangeValue) => onRangeChange(newRangeValue, { tableContext: table }) : null + }; + }; + return (
onSearchChange(newSearchValue, { tableContext: table }) : null} - searchValue={withSearch ? globalFilterState.globalFilter.search : undefined} - onRangeChange={rangeFilter ? (newRangeValue) => onRangeChange(newRangeValue, { tableContext: table }) : null} - rangeValue={rangeFilter ? globalFilterState.globalFilter.range : undefined} - rangeFilter={rangeFilter} + {...getFilterOptions()} tableContext={table} columnsSelectorUID={columnsSelectorUID} columnSetsSelectorId={columnSetsSelectorId} @@ -348,18 +392,18 @@ const Table = ({ showCounters={showCounters} hideTotal={hideTotal} hideDisplayed={hideDisplayed} - totalNumber={counters.total || data.length} + totalNumber={counters.total || paginationSettings.totalRows} rowsCount={rows.length} selectedRowsCount={selectedRowsCounts} dataTestIds={dataTestIds.infoAreaTestIds} showAllLink={showAllLink} - tableContext={table} + pagination={paginationSettings} /> - {table.getPageCount() > 1 && ( + {paginationSettings.pageCount > 1 && ( table.setPageIndex(newPageIndex)} + count={paginationSettings.pageCount} + page={paginationSettings.pageIndex + 1} + paginationHandler={paginationSettings.onPageIndexChange} /> )} diff --git a/ngui/ui/src/components/Table/components/InfoArea/InfoArea.tsx b/ngui/ui/src/components/Table/components/InfoArea/InfoArea.tsx index 9ad8ac53..9ba8b592 100644 --- a/ngui/ui/src/components/Table/components/InfoArea/InfoArea.tsx +++ b/ngui/ui/src/components/Table/components/InfoArea/InfoArea.tsx @@ -5,14 +5,14 @@ import { Link as RouterLink } from "react-router-dom"; import KeyValueLabel from "components/KeyValueLabel/KeyValueLabel"; import useStyles from "./InfoArea.styles"; -const DisplayedLabel = ({ tableContext, rowsCount, totalNumber, dataTestIds }) => { +const DisplayedLabel = ({ rowsCount, totalNumber, pagination, dataTestIds }) => { const getDisplayedValue = () => { - if (tableContext.getPageCount() <= 1) { + const { pageCount, pageIndex, pageSize } = pagination; + + if (pageCount <= 1 || pageIndex + 1 > pageCount) { return rowsCount; } - const { pageIndex, pageSize } = tableContext.getState().pagination; - const from = pageIndex * pageSize + 1; const currentPageLastRowIndex = (pageIndex + 1) * pageSize; @@ -44,7 +44,7 @@ const InfoArea = ({ selectedRowsCount = 0, dataTestIds = {}, showAllLink, - tableContext + pagination }) => { const { classes } = useStyles(); const { showAll: showAllDataTestId = null } = dataTestIds; @@ -66,10 +66,10 @@ const InfoArea = ({ )} {hideDisplayed ? null : ( )} {selectedRowsCount !== 0 && ( diff --git a/ngui/ui/src/components/forms/MlArtifactForm/FormElements/DescriptionField.tsx b/ngui/ui/src/components/forms/MlArtifactForm/FormElements/DescriptionField.tsx new file mode 100644 index 00000000..5fabab5a --- /dev/null +++ b/ngui/ui/src/components/forms/MlArtifactForm/FormElements/DescriptionField.tsx @@ -0,0 +1,24 @@ +import { FormattedMessage, useIntl } from "react-intl"; +import { TextInput } from "components/forms/common/fields"; +import { FIELD_NAMES } from "../constants"; + +const FIELD_NAME = FIELD_NAMES.DESCRIPTION; + +const DescriptionField = ({ isLoading = false }) => { + const intl = useIntl(); + + return ( + } + minRows={6} + maxRows={16} + multiline + placeholder={intl.formatMessage({ id: "markdownIsSupported" })} + isLoading={isLoading} + /> + ); +}; + +export default DescriptionField; diff --git a/ngui/ui/src/components/forms/MlArtifactForm/FormElements/FormButtons.tsx b/ngui/ui/src/components/forms/MlArtifactForm/FormElements/FormButtons.tsx new file mode 100644 index 00000000..8cd788d9 --- /dev/null +++ b/ngui/ui/src/components/forms/MlArtifactForm/FormElements/FormButtons.tsx @@ -0,0 +1,29 @@ +import { Box } from "@mui/material"; +import Button from "components/Button"; +import ButtonLoader from "components/ButtonLoader"; +import FormButtonsWrapper from "components/FormButtonsWrapper"; +import { useOrganizationInfo } from "hooks/useOrganizationInfo"; + +const FormButtons = ({ onCancel, isLoading = false }) => { + const { isDemo } = useOrganizationInfo(); + + return ( + + + +