Skip to content

Commit

Permalink
Merge pull request #8 from scitran/cgc/host-support
Browse files Browse the repository at this point in the history
Add support for hosts.
  • Loading branch information
cgc authored Mar 22, 2017
2 parents 58407ce + 0c25f7f commit 3b40324
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 39 deletions.
9 changes: 5 additions & 4 deletions examples/flywheel_analyzer_afq.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ def afq_inputs(analyses, **kwargs):
)

if __name__ == '__main__':
fa.run([
fa.define_analysis('dtiinit', dtiinit_inputs),
fa.define_analysis('afq', afq_inputs),
], project=fa.find_project(label='ENGAGE'))
with fa.installed_client():
fa.run([
fa.define_analysis('dtiinit', dtiinit_inputs),
fa.define_analysis('afq', afq_inputs),
], project=fa.find_project(label='ENGAGE'))
30 changes: 16 additions & 14 deletions examples/flywheel_analyzer_engage.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import scitran_client.flywheel_analyzer as fa
from scitran_client import ScitranClient


# XXX at least make this be just the first thing without ' 2'?
Expand All @@ -17,7 +18,7 @@

def _find_file(container, glob):
return (
container.find_file(glob) or
container.find_file(glob, default=None) or
# HACK because flywheel does not currently support nested files
# in output folders, we are flattening hierarchy by replacing
# forward slashes with @@
Expand Down Expand Up @@ -80,16 +81,17 @@ def first_level_model_inputs(acquisition_label, analyses, acquisitions):
), dict(task_type=label_to_task_type[acquisition_label])

if __name__ == '__main__':
fa.run([
define_analysis('reactivity-preprocessing', 'go-no-go 2', reactivity_inputs),
define_analysis('connectivity-preprocessing', 'go-no-go 2', connectivity_inputs),
define_analysis('first-level-models', 'go-no-go 2', first_level_model_inputs),

define_analysis('reactivity-preprocessing', 'conscious 2', reactivity_inputs),
define_analysis('connectivity-preprocessing', 'conscious 2', connectivity_inputs),
define_analysis('first-level-models', 'conscious 2', first_level_model_inputs),

define_analysis('reactivity-preprocessing', 'nonconscious 2', reactivity_inputs),
define_analysis('connectivity-preprocessing', 'nonconscious 2', connectivity_inputs),
define_analysis('first-level-models', 'nonconscious 2', first_level_model_inputs),
], project=fa.find_project(label='ENGAGE'), session_limit=1)
with fa.installed_client(ScitranClient('https://flywheel-cni.scitran.stanford.edu')):
fa.run([
define_analysis('reactivity-preprocessing', 'go-no-go 2', reactivity_inputs),
define_analysis('connectivity-preprocessing', 'go-no-go 2', connectivity_inputs),
define_analysis('first-level-models', 'go-no-go 2', first_level_model_inputs),

define_analysis('reactivity-preprocessing', 'conscious 2', reactivity_inputs),
define_analysis('connectivity-preprocessing', 'conscious 2', connectivity_inputs),
define_analysis('first-level-models', 'conscious 2', first_level_model_inputs),

define_analysis('reactivity-preprocessing', 'nonconscious 2', reactivity_inputs),
define_analysis('connectivity-preprocessing', 'nonconscious 2', connectivity_inputs),
define_analysis('first-level-models', 'nonconscious 2', first_level_model_inputs),
], project=fa.find_project(label='ENGAGE'), session_limit=1)
17 changes: 11 additions & 6 deletions examples/flywheel_analyzer_showdes.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from scitran_client import ScitranClient
import scitran_client.flywheel_analyzer as fa

D99 = fa.find(
fa.request('sessions/588bd1ac449f9800159305c2/acquisitions'),
label='atlas')
client = ScitranClient('https://flywheel.scitran.stanford.edu')

with fa.installed_client(client):
D99 = fa.find(
client.request('sessions/588bd1ac449f9800159305c2/acquisitions').json(),
label='atlas')


def anatomical_warp_inputs(acquisitions, **kwargs):
Expand All @@ -14,6 +18,7 @@ def anatomical_warp_inputs(acquisitions, **kwargs):
)

if __name__ == '__main__':
fa.run([
fa.define_analysis('afni-brain-warp', anatomical_warp_inputs, label='anatomical warp'),
], project=fa.find_project(label='showdes'), max_workers=2)
with fa.installed_client(client):
fa.run([
fa.define_analysis('afni-brain-warp', anatomical_warp_inputs, label='anatomical warp'),
], project=fa.find_project(label='showdes'), max_workers=2)
52 changes: 45 additions & 7 deletions scitran_client/flywheel_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from fnmatch import fnmatch
from collections import namedtuple, Counter
import math
from contextlib import contextmanager


def _sleep(seconds):
Expand Down Expand Up @@ -41,26 +42,46 @@ def define_analysis(gear_name, create_inputs, label=None):


class FlywheelFileContainer(dict):
def find_file(self, pattern):
def find_file(self, pattern, **kwargs):
'''Find a file in this container with a name that matches pattern.
This will look for a file in this container that matches the supplied
pattern. Matching uses the fnmatch python library, which does Unix
filename pattern matching.
kwargs['default'] - like `next`, when a default is supplied, it will be
returned when there are no matches. When a default is not supplied, an
exception will be thrown.
To find a specific file, you can simply match by name:
> acquisition.find_file('anatomical.nii.gz')
To find a file with an extension, you can use a Unix-style pattern:
> stimulus_onsets.find_file('*.txt')
When looking for a file that might be missing, supply a default value:
> partial_set_of_files.find_file('*.txt', default=None)
'''
# TODO make sure this throws for missing or multiple files.
is_analysis = 'job' in self and 'state' in self
has_default = 'default' in kwargs
is_analysis = 'job' in self

# XXX if is_analysis then we should require the file to be an output??
f = next(
matches = [
f for f in self['files']
if fnmatch(f['name'], pattern))
if fnmatch(f['name'], pattern)]

assert len(matches) <= 1, (
'Multiple matches found for pattern "{}" in container {}. Patterns should uniquely identify a file.'
.format(pattern, self['_id']))
if not matches:
if has_default:
return kwargs.get('default')
else:
raise Exception(
'Could not find a match for "{}" in container {}.'
.format(pattern, self['_id']))

f = matches[0]

return dict(
type='analysis' if is_analysis else 'acquisition',
Expand Down Expand Up @@ -104,8 +125,7 @@ class ShuttingDownException(Exception):

def request(*args, **kwargs):
# HACK client is a module variable for now. In the future, we should pass client around.
if 'client' not in state:
state['client'] = ScitranClient()
assert 'client' in state, 'client must be installed in state before using request. See `installed_client`.'
response = state['client']._request(*args, **kwargs)
return json.loads(response.text)

Expand Down Expand Up @@ -223,6 +243,24 @@ def done(f):
raise


@contextmanager
def installed_client(client=None):
'''
This context manager handles the installation of a scitran client for use
in the flywheel analyzer. Most flywheel analyzer code depends on this being
set up.
> with installed_client():
> print fa.find_project(label='ADHD study') # actually works!
'''
# BIG HACK
state['client'] = client or ScitranClient()
try:
yield state['client']
finally:
state['client'] = None


def run(operations, project=None, max_workers=10, session_limit=None):
"""Run a sequence of FlywheelAnalysisOperations.
Expand Down
18 changes: 14 additions & 4 deletions scitran_client/st_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,19 @@ def _prompt_for_valid_api_key(url):
return api_key


def create_token(instance_name, config_dir):
def create_token(instance_name_or_host, config_dir):
'''
Get an API key for this instance, requesting a new one if no previous one exists.
Args:
instance_name (str): The instance to generate a token for.
instance_name (str): The instance to generate a token for. Should only have one of `instance_name` or `host`.
host (str): The host we are are trying to generate a token for.
config_dir (str): Path of directory where the tokens live.
Returns:
Python tuple: (token (str), client_url (str)): (The requested token, the base url for this client)
'''

if not os.path.exists(config_dir):
os.mkdir(config_dir)

Expand All @@ -47,13 +49,21 @@ def create_token(instance_name, config_dir):
with open(auth_path, 'r') as f:
auth_config = json.load(f)

auth = auth_config.get(instance_name)
matches = [
a
for name, a in auth_config.iteritems()
if a['url'] == instance_name_or_host or name == instance_name_or_host
]
assert len(matches) <= 1, \
'Too many matches for for {}. found: {}'.format(instance_name_or_host, matches)

auth = matches and matches[0]

example = json.dumps(dict(api_key='<secret>', url='https://myflywheel.io'), indent=4)
assert isinstance(auth, dict) and set(auth.keys()) == {'api_key', 'url'}, (
'Missing or invalid entry in {0} for instance {1}. You can fix this issue by '
'adding an entry for {1} or making it look more like this: {2}'
.format(auth_path, instance_name, example))
.format(auth_path, instance_name_or_host, example))

# We just wipe out keys that are invalid.
if auth['api_key'] and not _is_valid_token(auth['url'], auth['api_key']):
Expand Down
7 changes: 3 additions & 4 deletions scitran_client/st_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ class ScitranClient(object):
'''Handles api calls to a certain instance.
Attributes:
instance (str): instance name.
base_url (str): The base url of that instance, as returned by stAuth.create_token(instance, st_dir)
instance_name (str): instance name or host.
token (str): Authentication token.
st_dir (str): The path to the directory where token and authentication file are kept for this instance.
'''
Expand All @@ -70,7 +69,7 @@ def __init__(self,
gear_out_dir=DEFAULT_OUTPUT_DIR):

self.session = requests.Session()
self.instance = instance_name
self.instance_name_or_host = instance_name
self.st_dir = st_dir
self._authenticate()
self.debug = debug
Expand Down Expand Up @@ -139,7 +138,7 @@ def _authenticate_request(self, request):
return request

def _authenticate(self):
self.token, self.base_url = st_auth.create_token(self.instance, self.st_dir)
self.token, self.base_url = st_auth.create_token(self.instance_name_or_host, self.st_dir)
self.base_url = urlparse.urljoin(self.base_url, 'api/')

def _request(self, *args, **kwargs):
Expand Down

0 comments on commit 3b40324

Please sign in to comment.