diff --git a/nbkickoff/nbkickoff.py b/nbkickoff/nbkickoff.py index 9847299..340c134 100644 --- a/nbkickoff/nbkickoff.py +++ b/nbkickoff/nbkickoff.py @@ -1,10 +1,9 @@ #!/usr/bin/env python3 -from .__about__ import __summary__ +from .__about__ import (__summary__, __title__) import logging import os from pathlib import Path -import shutil import sys import webbrowser @@ -59,19 +58,38 @@ def open_notebook(notebook_file): launch_detached_process(sys.executable, '-m', 'notebook', '--NotebookApp.open_browser=True', notebook_file) -def kickoff(template_file, target_file): - create_notebook_from_template(template_file, target_file, {}) +def kickoff(template_file, target_file, variables): + create_notebook_from_template(template_file, target_file, variables) open_notebook(target_file) def main(): import argparse - parser = argparse.ArgumentParser(description=__summary__) - parser.add_argument('-t', '--template', required=True, help='Template notebook file') - parser.add_argument('-f', '--file', required=True, help='Target notebook file to create') + + def parse_template_var(s): + if '=' not in s: + raise argparse.ArgumentTypeError('Invalid template variable specification: ' + s) + parts = s.split('=', 1) + return {parts[0]: parts[1]} + + class DictMergeAction(argparse.Action): + def __call__(self, p, namespace, values, option_string=None): + current_dict = dict() + if hasattr(namespace, self.dest) and getattr(namespace, self.dest) is not None: + current_dict = getattr(namespace, self.dest) + for v in values: + current_dict = {**current_dict, **v} + setattr(namespace, self.dest, current_dict) + + parser = argparse.ArgumentParser(prog=__title__, description=__summary__) + parser.add_argument('template', help='Template notebook file') + parser.add_argument('target', help='Target notebook file to create') + parser.add_argument('variable', action=DictMergeAction, type=parse_template_var, nargs='*', + help='Template variable in the format NAME=VALUE') args = parser.parse_args() - kickoff(args.template, args.file) + kickoff(args.template, args.target, args.variable) if __name__ == '__main__': main() +