diff --git a/examples/flywheel_analyzer_afq.py b/examples/flywheel_analyzer_afq.py index f451419..04da070 100644 --- a/examples/flywheel_analyzer_afq.py +++ b/examples/flywheel_analyzer_afq.py @@ -19,7 +19,8 @@ def afq_inputs(analyses, **kwargs): ) if __name__ == '__main__': - fa.run([ - fa.define_analysis('dtiinit', dtiinit_inputs), - fa.define_analysis('afq', afq_inputs), - ], project=fa.find_project(label='ENGAGE')) + with fa.installed_client(): + fa.run([ + fa.define_analysis('dtiinit', dtiinit_inputs), + fa.define_analysis('afq', afq_inputs), + ], project=fa.find_project(label='ENGAGE')) diff --git a/examples/flywheel_analyzer_engage.py b/examples/flywheel_analyzer_engage.py index efa7d42..c4f223a 100644 --- a/examples/flywheel_analyzer_engage.py +++ b/examples/flywheel_analyzer_engage.py @@ -1,4 +1,5 @@ import scitran_client.flywheel_analyzer as fa +from scitran_client import ScitranClient # XXX at least make this be just the first thing without ' 2'? @@ -17,7 +18,7 @@ def _find_file(container, glob): return ( - container.find_file(glob) or + container.find_file(glob, default=None) or # HACK because flywheel does not currently support nested files # in output folders, we are flattening hierarchy by replacing # forward slashes with @@ @@ -80,16 +81,17 @@ def first_level_model_inputs(acquisition_label, analyses, acquisitions): ), dict(task_type=label_to_task_type[acquisition_label]) if __name__ == '__main__': - fa.run([ - define_analysis('reactivity-preprocessing', 'go-no-go 2', reactivity_inputs), - define_analysis('connectivity-preprocessing', 'go-no-go 2', connectivity_inputs), - define_analysis('first-level-models', 'go-no-go 2', first_level_model_inputs), - - define_analysis('reactivity-preprocessing', 'conscious 2', reactivity_inputs), - define_analysis('connectivity-preprocessing', 'conscious 2', connectivity_inputs), - define_analysis('first-level-models', 'conscious 2', first_level_model_inputs), - - define_analysis('reactivity-preprocessing', 'nonconscious 2', reactivity_inputs), - define_analysis('connectivity-preprocessing', 'nonconscious 2', connectivity_inputs), - define_analysis('first-level-models', 'nonconscious 2', first_level_model_inputs), - ], project=fa.find_project(label='ENGAGE'), session_limit=1) + with fa.installed_client(ScitranClient('https://flywheel-cni.scitran.stanford.edu')): + fa.run([ + define_analysis('reactivity-preprocessing', 'go-no-go 2', reactivity_inputs), + define_analysis('connectivity-preprocessing', 'go-no-go 2', connectivity_inputs), + define_analysis('first-level-models', 'go-no-go 2', first_level_model_inputs), + + define_analysis('reactivity-preprocessing', 'conscious 2', reactivity_inputs), + define_analysis('connectivity-preprocessing', 'conscious 2', connectivity_inputs), + define_analysis('first-level-models', 'conscious 2', first_level_model_inputs), + + define_analysis('reactivity-preprocessing', 'nonconscious 2', reactivity_inputs), + define_analysis('connectivity-preprocessing', 'nonconscious 2', connectivity_inputs), + define_analysis('first-level-models', 'nonconscious 2', first_level_model_inputs), + ], project=fa.find_project(label='ENGAGE'), session_limit=1) diff --git a/examples/flywheel_analyzer_showdes.py b/examples/flywheel_analyzer_showdes.py index 5aefe21..4f38894 100644 --- a/examples/flywheel_analyzer_showdes.py +++ b/examples/flywheel_analyzer_showdes.py @@ -1,8 +1,12 @@ +from scitran_client import ScitranClient import scitran_client.flywheel_analyzer as fa -D99 = fa.find( - fa.request('sessions/588bd1ac449f9800159305c2/acquisitions'), - label='atlas') +client = ScitranClient('https://flywheel.scitran.stanford.edu') + +with fa.installed_client(client): + D99 = fa.find( + client.request('sessions/588bd1ac449f9800159305c2/acquisitions').json(), + label='atlas') def anatomical_warp_inputs(acquisitions, **kwargs): @@ -14,6 +18,7 @@ def anatomical_warp_inputs(acquisitions, **kwargs): ) if __name__ == '__main__': - fa.run([ - fa.define_analysis('afni-brain-warp', anatomical_warp_inputs, label='anatomical warp'), - ], project=fa.find_project(label='showdes'), max_workers=2) + with fa.installed_client(client): + fa.run([ + fa.define_analysis('afni-brain-warp', anatomical_warp_inputs, label='anatomical warp'), + ], project=fa.find_project(label='showdes'), max_workers=2) diff --git a/scitran_client/flywheel_analyzer.py b/scitran_client/flywheel_analyzer.py index 090fa7e..2d009e2 100644 --- a/scitran_client/flywheel_analyzer.py +++ b/scitran_client/flywheel_analyzer.py @@ -7,6 +7,7 @@ from fnmatch import fnmatch from collections import namedtuple, Counter import math +from contextlib import contextmanager def _sleep(seconds): @@ -41,26 +42,46 @@ def define_analysis(gear_name, create_inputs, label=None): class FlywheelFileContainer(dict): - def find_file(self, pattern): + def find_file(self, pattern, **kwargs): '''Find a file in this container with a name that matches pattern. This will look for a file in this container that matches the supplied pattern. Matching uses the fnmatch python library, which does Unix filename pattern matching. + kwargs['default'] - like `next`, when a default is supplied, it will be + returned when there are no matches. When a default is not supplied, an + exception will be thrown. + To find a specific file, you can simply match by name: > acquisition.find_file('anatomical.nii.gz') To find a file with an extension, you can use a Unix-style pattern: > stimulus_onsets.find_file('*.txt') + + When looking for a file that might be missing, supply a default value: + > partial_set_of_files.find_file('*.txt', default=None) ''' - # TODO make sure this throws for missing or multiple files. - is_analysis = 'job' in self and 'state' in self + has_default = 'default' in kwargs + is_analysis = 'job' in self # XXX if is_analysis then we should require the file to be an output?? - f = next( + matches = [ f for f in self['files'] - if fnmatch(f['name'], pattern)) + if fnmatch(f['name'], pattern)] + + assert len(matches) <= 1, ( + 'Multiple matches found for pattern "{}" in container {}. Patterns should uniquely identify a file.' + .format(pattern, self['_id'])) + if not matches: + if has_default: + return kwargs.get('default') + else: + raise Exception( + 'Could not find a match for "{}" in container {}.' + .format(pattern, self['_id'])) + + f = matches[0] return dict( type='analysis' if is_analysis else 'acquisition', @@ -104,8 +125,7 @@ class ShuttingDownException(Exception): def request(*args, **kwargs): # HACK client is a module variable for now. In the future, we should pass client around. - if 'client' not in state: - state['client'] = ScitranClient() + assert 'client' in state, 'client must be installed in state before using request. See `installed_client`.' response = state['client']._request(*args, **kwargs) return json.loads(response.text) @@ -223,6 +243,24 @@ def done(f): raise +@contextmanager +def installed_client(client=None): + ''' + This context manager handles the installation of a scitran client for use + in the flywheel analyzer. Most flywheel analyzer code depends on this being + set up. + + > with installed_client(): + > print fa.find_project(label='ADHD study') # actually works! + ''' + # BIG HACK + state['client'] = client or ScitranClient() + try: + yield state['client'] + finally: + state['client'] = None + + def run(operations, project=None, max_workers=10, session_limit=None): """Run a sequence of FlywheelAnalysisOperations. diff --git a/scitran_client/st_auth.py b/scitran_client/st_auth.py index 37dff87..e6f75a1 100644 --- a/scitran_client/st_auth.py +++ b/scitran_client/st_auth.py @@ -25,17 +25,19 @@ def _prompt_for_valid_api_key(url): return api_key -def create_token(instance_name, config_dir): +def create_token(instance_name_or_host, config_dir): ''' Get an API key for this instance, requesting a new one if no previous one exists. Args: - instance_name (str): The instance to generate a token for. + instance_name (str): The instance to generate a token for. Should only have one of `instance_name` or `host`. + host (str): The host we are are trying to generate a token for. config_dir (str): Path of directory where the tokens live. Returns: Python tuple: (token (str), client_url (str)): (The requested token, the base url for this client) ''' + if not os.path.exists(config_dir): os.mkdir(config_dir) @@ -47,13 +49,21 @@ def create_token(instance_name, config_dir): with open(auth_path, 'r') as f: auth_config = json.load(f) - auth = auth_config.get(instance_name) + matches = [ + a + for name, a in auth_config.iteritems() + if a['url'] == instance_name_or_host or name == instance_name_or_host + ] + assert len(matches) <= 1, \ + 'Too many matches for for {}. found: {}'.format(instance_name_or_host, matches) + + auth = matches and matches[0] example = json.dumps(dict(api_key='', url='https://myflywheel.io'), indent=4) assert isinstance(auth, dict) and set(auth.keys()) == {'api_key', 'url'}, ( 'Missing or invalid entry in {0} for instance {1}. You can fix this issue by ' 'adding an entry for {1} or making it look more like this: {2}' - .format(auth_path, instance_name, example)) + .format(auth_path, instance_name_or_host, example)) # We just wipe out keys that are invalid. if auth['api_key'] and not _is_valid_token(auth['url'], auth['api_key']): diff --git a/scitran_client/st_client.py b/scitran_client/st_client.py index c4fa8ee..9e8e01e 100644 --- a/scitran_client/st_client.py +++ b/scitran_client/st_client.py @@ -55,8 +55,7 @@ class ScitranClient(object): '''Handles api calls to a certain instance. Attributes: - instance (str): instance name. - base_url (str): The base url of that instance, as returned by stAuth.create_token(instance, st_dir) + instance_name (str): instance name or host. token (str): Authentication token. st_dir (str): The path to the directory where token and authentication file are kept for this instance. ''' @@ -70,7 +69,7 @@ def __init__(self, gear_out_dir=DEFAULT_OUTPUT_DIR): self.session = requests.Session() - self.instance = instance_name + self.instance_name_or_host = instance_name self.st_dir = st_dir self._authenticate() self.debug = debug @@ -139,7 +138,7 @@ def _authenticate_request(self, request): return request def _authenticate(self): - self.token, self.base_url = st_auth.create_token(self.instance, self.st_dir) + self.token, self.base_url = st_auth.create_token(self.instance_name_or_host, self.st_dir) self.base_url = urlparse.urljoin(self.base_url, 'api/') def _request(self, *args, **kwargs):