Convert check new samples command to cron job #4394

Merged · 16 commits · Oct 2, 2024
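This PR replaces the command's required `path`/`version` positional arguments with optional `--genome_version`, `--dataset_type`, and `--run-version` filters, so a scheduled (cron) invocation can run with no arguments and discover every new successful pipeline run by globbing for `_SUCCESS` files. A minimal sketch of the resulting invocations via Django's standard `call_command`; the filter values are hypothetical:

```python
# Minimal sketch, assuming Django's standard call_command API; filter values are hypothetical.
from django.core.management import call_command

# Cron-style invocation: no arguments, discover and load every new successful run
call_command('check_for_new_samples_from_pipeline')

# Manual invocation narrowed to a single run (hypothetical values)
call_command(
    'check_for_new_samples_from_pipeline',
    genome_version='GRCh38', dataset_type='SNV_INDEL', run_version='manual__2024-10-01',
)
```

With no filters, each field defaults to a `*` wildcard (see `options[field] or '*'` in the diff below), so a single cron run picks up every genome build and dataset type in one pass.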
109 changes: 68 additions & 41 deletions seqr/management/commands/check_for_new_samples_from_pipeline.py
@@ -5,11 +5,12 @@
 from django.db.models.functions import JSONObject
 import json
 import logging
+import re
 
 from reference_data.models import GENOME_VERSION_LOOKUP
 from seqr.models import Family, Sample, SavedVariant
 from seqr.utils.communication_utils import safe_post_to_slack
-from seqr.utils.file_utils import file_iter, does_file_exist
+from seqr.utils.file_utils import file_iter, list_files
 from seqr.utils.search.add_data_utils import notify_search_data_loaded
 from seqr.utils.search.utils import parse_valid_variant_id
 from seqr.utils.search.hail_search_utils import hail_variant_multi_lookup, search_data_type
@@ -18,11 +19,13 @@
 from seqr.views.utils.permissions_utils import is_internal_anvil_project, project_has_anvil
 from seqr.views.utils.variant_utils import reset_cached_search_results, update_projects_saved_variant_json, \
     get_saved_variants
-from settings import SEQR_SLACK_LOADING_NOTIFICATION_CHANNEL, BASE_URL
+from settings import SEQR_SLACK_LOADING_NOTIFICATION_CHANNEL, HAIL_SEARCH_DATA_DIR
 
 logger = logging.getLogger(__name__)
 
-GS_PATH_TEMPLATE = 'gs://seqr-hail-search-data/v3.1/{path}/runs/{version}/'
+RUN_SUCCESS_PATH_TEMPLATE = '{data_dir}/{genome_version}/{dataset_type}/runs/{run_version}/_SUCCESS'
+RUN_PATH_FIELDS = ['genome_version', 'dataset_type', 'run_version']
 
 DATASET_TYPE_MAP = {'GCNV': Sample.DATASET_TYPE_SV_CALLS}
 USER_EMAIL = 'manage_command'
 MAX_LOOKUP_VARIANTS = 5000
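The new template gets formatted two ways: with `*` wildcards to produce a glob pattern for `list_files`, and with named capture groups to produce a regex that parses each matched path back into its fields. A runnable sketch of both formattings (the `data_dir` value and run name are hypothetical):

```python
import re

RUN_SUCCESS_PATH_TEMPLATE = '{data_dir}/{genome_version}/{dataset_type}/runs/{run_version}/_SUCCESS'
RUN_PATH_FIELDS = ['genome_version', 'dataset_type', 'run_version']
DATA_DIR = 'gs://seqr-hail-search-data/v3.1'  # hypothetical HAIL_SEARCH_DATA_DIR value

# Glob form, used to list candidate run paths
glob_path = RUN_SUCCESS_PATH_TEMPLATE.format(data_dir=DATA_DIR, **{f: '*' for f in RUN_PATH_FIELDS})
# -> 'gs://seqr-hail-search-data/v3.1/*/*/runs/*/_SUCCESS'

# Regex form, used to parse the fields back out of each matched path
path_regex = RUN_SUCCESS_PATH_TEMPLATE.format(
    data_dir=DATA_DIR, **{f: f'(?P<{f}>[^/]+)' for f in RUN_PATH_FIELDS})
match = re.match(path_regex, f'{DATA_DIR}/GRCh38/SNV_INDEL/runs/auto__2024-09-14/_SUCCESS')
print(match.groupdict())
# {'genome_version': 'GRCh38', 'dataset_type': 'SNV_INDEL', 'run_version': 'auto__2024-09-14'}
```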
@@ -38,33 +41,65 @@ class Command(BaseCommand):
     help = 'Check for newly loaded seqr samples'
 
     def add_arguments(self, parser):
-        parser.add_argument('path')
-        parser.add_argument('version')
-        parser.add_argument('--allow-failed', action='store_true')
+        parser.add_argument('--genome_version')
+        parser.add_argument('--dataset_type')
+        parser.add_argument('--run-version')
 
     def handle(self, *args, **options):
-        path = options['path']
-        version = options['version']
-        genome_version, dataset_type = path.split('/')
-        dataset_type = DATASET_TYPE_MAP.get(dataset_type, dataset_type)
-
-        if Sample.objects.filter(data_source=version, is_active=True).exists():
-            logger.info(f'Data already loaded for {path}: {version}')
+        path = self._run_success_path(lambda field: options[field] or '*')
+        path_regex = self._run_success_path(lambda field: f'(?P<{field}>[^/]+)')
+        success_runs = {path: re.match(path_regex, path).groupdict() for path in list_files(path, user=None)}
+        if not success_runs:
+            user_args = [f'{k}={options[k]}' for k in RUN_PATH_FIELDS if options[k]]
+            raise CommandError(f'No successful runs found for {", ".join(user_args)}')
+
+        loaded_runs = set(Sample.objects.filter(data_source__isnull=False).values_list('data_source', flat=True))
+        new_runs = {path: run for path, run in success_runs.items() if run['run_version'] not in loaded_runs}
+        if not new_runs:
+            logger.info(f'Data already loaded for all {len(success_runs)} runs')
             return
 
-        logger.info(f'Loading new samples from {path}: {version}')
-        gs_path = GS_PATH_TEMPLATE.format(path=path, version=version)
-        if not does_file_exist(gs_path + '_SUCCESS'):
-            if options['allow_failed']:
-                logger.warning(f'Loading for failed run {path}: {version}')
-            else:
-                raise CommandError(f'Run failed for {path}: {version}, unable to load data')
+        logger.info(f'Loading new samples from {len(success_runs)} run(s)')
+        updated_families_by_data_type = defaultdict(set)
+        updated_variants_by_data_type = defaultdict(dict)
+        for path, run in new_runs.items():
+            try:
+                metadata_path = path.replace('_SUCCESS', 'metadata.json')
+                data_type, updated_families, updated_variants_by_id = self._load_new_samples(metadata_path, **run)
+                data_type_key = (data_type, run['genome_version'])
+                updated_families_by_data_type[data_type_key].update(updated_families)
+                updated_variants_by_data_type[data_type_key].update(updated_variants_by_id)
+            except Exception as e:
+                logger.error(f'Error loading {run["run_version"]}: {e}')
+
+        # Reset cached results for all projects, as seqr AFs will have changed for all projects when new data is added
+        reset_cached_search_results(project=None)
+
+        for data_type_key, updated_families in updated_families_by_data_type.items():
+            self._reload_shared_variant_annotations(
+                *data_type_key, updated_variants_by_data_type[data_type_key], exclude_families=updated_families,
+            )
+
+        logger.info('DONE')
 
-        metadata = json.loads(next(line for line in file_iter(gs_path + 'metadata.json')))
+    @staticmethod
+    def _run_success_path(get_field_format):
+        return RUN_SUCCESS_PATH_TEMPLATE.format(
+            data_dir=HAIL_SEARCH_DATA_DIR,
+            **{field: get_field_format(field) for field in RUN_PATH_FIELDS}
+        )
+
+    @classmethod
+    def _load_new_samples(cls, metadata_path, genome_version, dataset_type, run_version):
+        dataset_type = DATASET_TYPE_MAP.get(dataset_type, dataset_type)
+
+        logger.info(f'Loading new samples from {genome_version}/{dataset_type}: {run_version}')
+
+        metadata = json.loads(next(line for line in file_iter(metadata_path)))
         families = Family.objects.filter(guid__in=metadata['family_samples'].keys())
         if len(families) < len(metadata['family_samples']):
             invalid = metadata['family_samples'].keys() - set(families.values_list('guid', flat=True))
-            raise CommandError(f'Invalid families in run metadata {path}: {version} - {", ".join(invalid)}')
+            raise CommandError(f'Invalid families in run metadata {genome_version}/{dataset_type}: {run_version} - {", ".join(invalid)}')
 
         family_project_map = {f.guid: f.project for f in families.select_related('project')}
         samples_by_project = defaultdict(list)
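Each run's `metadata.json` sits beside its `_SUCCESS` marker (see `metadata_path = path.replace('_SUCCESS', 'metadata.json')` above). A hypothetical minimal payload, inferred only from the keys this command reads; the real pipeline output may carry more fields:

```python
# Hypothetical minimal metadata.json payload, inferred from the keys read in _load_new_samples.
metadata = {
    'callsets': ['callset_1.vcf'],        # joined into the samples' elasticsearch_index field
    'family_samples': {                   # family guid -> loaded sample ids (assumed shape)
        'F000001_1': ['SM-AAA1', 'SM-AAA2'],
    },
    'failed_family_samples': {},          # optional; drives the failure notifications below
}
```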
@@ -90,15 +125,12 @@ def handle(self, *args, **options):
         updated_samples, inactivated_sample_guids, *args = match_and_update_search_samples(
             projects=samples_by_project.keys(),
             sample_project_tuples=sample_project_tuples,
-            sample_data={'data_source': version, 'elasticsearch_index': ';'.join(metadata['callsets'])},
+            sample_data={'data_source': run_version, 'elasticsearch_index': ';'.join(metadata['callsets'])},
             sample_type=sample_type,
             dataset_type=dataset_type,
             user=None,
         )
 
-        # Reset cached results for all projects, as seqr AFs will have changed for all projects when new data is added
-        reset_cached_search_results(project=None)
-
         # Send loading notifications and update Airtable PDOs
         update_sample_data_by_project = {
             s['individual__family__project']: s for s in updated_samples.values('individual__family__project').annotate(
@@ -121,7 +153,7 @@ def handle(self, *args, **options):
             updated_families.update(project_families)
             updated_project_families.append((project.id, project.name, project.genome_version, project_families))
             if is_internal and dataset_type == Sample.DATASET_TYPE_VARIANT_CALLS:
-                split_project_pdos[project.name] = self._update_pdos(session, project.guid, sample_ids)
+                split_project_pdos[project.name] = cls._update_pdos(session, project.guid, sample_ids)
 
         # Send failure notifications
         failed_family_samples = metadata.get('failed_family_samples', {})
@@ -149,25 +181,19 @@ def handle(self, *args, **options):
         updated_variants_by_id = update_projects_saved_variant_json(
             updated_project_families, user_email=USER_EMAIL, dataset_type=dataset_type)
 
-        self._reload_shared_variant_annotations(
-            search_data_type(dataset_type, sample_type), genome_version, updated_variants_by_id, exclude_families=updated_families)
-
-        logger.info('DONE')
+        return search_data_type(dataset_type, sample_type), updated_families, updated_variants_by_id
 
     @staticmethod
     def _update_pdos(session, project_guid, sample_ids):
-        airtable_samples = session.fetch_records(
-            'Samples', fields=['CollaboratorSampleID', 'SeqrCollaboratorSampleID', 'PDOID'],
-            or_filters={'PDOStatus': LOADABLE_PDO_STATUSES},
-            and_filters={'SeqrProject': f'{BASE_URL}project/{project_guid}/project_page'}
+        airtable_samples = session.get_samples_for_matched_pdos(
+            LOADABLE_PDO_STATUSES, pdo_fields=['PDOID'], project_guid=project_guid,
         )
 
         pdo_ids = set()
         skipped_pdo_samples = defaultdict(list)
         for record_id, sample in airtable_samples.items():
-            pdo_id = sample['PDOID'][0]
-            sample_id = sample.get('SeqrCollaboratorSampleID') or sample['CollaboratorSampleID']
-            if sample_id in sample_ids:
+            pdo_id = sample['pdos'][0]['PDOID']
+            if sample['sample_id'] in sample_ids:
                 pdo_ids.add(pdo_id)
             else:
                 skipped_pdo_samples[pdo_id].append(record_id)
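The refactored `_update_pdos` now relies on `get_samples_for_matched_pdos` to pre-resolve the collaborator sample ID and nest the matched PDO records, instead of filtering raw `fetch_records` output itself. A sketch of the record shape this loop appears to expect, inferred from the field accesses above (all values hypothetical):

```python
# Hypothetical return shape of session.get_samples_for_matched_pdos(), inferred from the loop above.
airtable_samples = {
    'recAbC123': {                        # Airtable record id
        'sample_id': 'SM-AAA1',           # SeqrCollaboratorSampleID, falling back to CollaboratorSampleID
        'pdos': [{'PDOID': 'PDO-1234'}],  # matched PDO records, filtered to LOADABLE_PDO_STATUSES
    },
}
```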
Expand Down Expand Up @@ -220,15 +246,16 @@ def _reload_shared_variant_annotations(data_type, genome_version, updated_varian
family_guids=updated_annotation_samples.values_list('individual__family__guid', flat=True).distinct(),
)

variant_type_summary = f'{data_type} {genome_version} saved variants'
if not variant_models:
logger.info('No additional saved variants to update')
logger.info(f'No additional {variant_type_summary} to update')
return

variants_by_id = defaultdict(list)
for v in variant_models:
variants_by_id[v.variant_id].append(v)

logger.info(f'Reloading shared annotations for {len(variant_models)} {data_type} {genome_version} saved variants ({len(variants_by_id)} unique)')
logger.info(f'Reloading shared annotations for {len(variant_models)} {variant_type_summary} ({len(variants_by_id)} unique)')

updated_variants_by_id = {
variant_id: {k: v for k, v in variant.items() if k not in {'familyGuids', 'genotypes'}}
@@ -250,7 +277,7 @@ def _reload_shared_variant_annotations(data_type, genome_version, updated_varian
             updated_variant_models.append(variant_model)
 
         SavedVariant.objects.bulk_update(updated_variant_models, ['saved_variant_json'], batch_size=10000)
-        logger.info(f'Updated {len(updated_variant_models)} saved variants')
+        logger.info(f'Updated {len(updated_variant_models)} {variant_type_summary}')
 
 
 reload_shared_variant_annotations = Command._reload_shared_variant_annotations
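The module-level alias on the last line exposes the annotation-reload logic for reuse outside the management command. A sketch of how another module might call it, assuming the signature shown in the hunk headers above (argument values are hypothetical):

```python
# Minimal sketch of reusing the exported helper, e.g. from a migration or another command.
from seqr.management.commands.check_for_new_samples_from_pipeline import (
    reload_shared_variant_annotations,
)

# Reload shared annotations for all saved variants of one data type / genome build,
# excluding no families (data type and genome version values are hypothetical).
reload_shared_variant_annotations('SNV_INDEL', 'GRCh38', {}, exclude_families=set())
```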