diff --git a/simplesat/priority_queue.py b/simplesat/priority_queue.py
new file mode 100644
index 0000000..15df6d0
--- /dev/null
+++ b/simplesat/priority_queue.py
@@ -0,0 +1,176 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from collections import defaultdict
+from functools import partial
+from heapq import heappush, heappop
+from itertools import count
+
+import six
+
+
+class _REMOVED_TASK(object):
+    pass
+REMOVED_TASK = _REMOVED_TASK()
+
+
+class PriorityQueue(object):
+    """ A priority queue implementation that supports reprioritizing or
+    removing tasks, given that tasks are unique.
+
+    Borrowed from: https://docs.python.org/3/library/heapq.html
+    """
+
+    def __init__(self):
+        # list of entries arranged in a heap
+        self._pq = []
+        # mapping of tasks to entries
+        self._entry_finder = {}
+        # unique id generator for tie-breaking
+        self._next_id = partial(next, count())
+
+    def __len__(self):
+        return len(self._entry_finder)
+
+    def __bool__(self):
+        return bool(len(self))
+
+    def __contains__(self, task):
+        return task in self._entry_finder
+
+    def clear(self):
+        self._pq = []
+        self._entry_finder = {}
+
+    def push(self, task, priority=0):
+        "Add a new task or update the priority of an existing task."
+        return self._push(priority, self._next_id(), task)
+
+    def peek(self):
+        """ Return the task with the lowest priority.
+
+        This will pop and re-push if a REMOVED task is found.
+        """
+        if not self._pq:
+            raise KeyError('peek from an empty priority queue')
+        entry = self._pq[0]
+        if entry[-1] is REMOVED_TASK:
+            entry = self._pop()
+            self._push(*entry)
+        return entry[-1]
+
+    def pop(self):
+        "Remove and return the lowest priority task. Raise KeyError if empty."
+        _, _, task = self._pop()
+        return task
+
+    def pop_many(self, n=None):
+        """ Return a list of length n of popped elements. If n is not
+        specified, pop the entire queue. """
+        if n is None:
+            n = len(self)
+        result = []
+        for _ in range(n):
+            result.append(self.pop())
+        return result
+
+    def discard(self, task):
+        "Remove an existing task if present. If not, do nothing."
+        try:
+            self.remove(task)
+        except KeyError:
+            pass
+
+    def remove(self, task):
+        "Remove an existing task. Raise KeyError if not found."
+        entry = self._entry_finder.pop(task)
+        entry[-1] = REMOVED_TASK
+
+    def _pop(self):
+        while self._pq:
+            entry = heappop(self._pq)
+            if entry[-1] is not REMOVED_TASK:
+                del self._entry_finder[entry[-1]]
+                return entry
+        raise KeyError('pop from an empty priority queue')
+
+    def _push(self, priority, task_id, task):
+        if task in self:
+            o_priority, _, o_task = self._entry_finder[task]
+            # Still check the task, which might now be REMOVED
+            if priority == o_priority and task == o_task:
+                # We're pushing something we already have, do nothing
+                return
+            else:
+                # Make space for the new entry
+                self.remove(task)
+        entry = [priority, task_id, task]
+        self._entry_finder[task] = entry
+        heappush(self._pq, entry)
+
+
+class GroupPrioritizer(object):
+
+    """ A helper for assigning hierarchical priorities to items
+    according to priority groups. """
+
+    def __init__(self, order_key_func=lambda x: x):
+        """
+        Parameters
+        ----------
+        order_key_func : callable
+            Used to sort items in each group.
+        """
+        self.key_func = order_key_func
+        self._priority_groups = defaultdict(set)
+        self._item_priority = {}
+        self.known = frozenset()
+        self.dirty = True
+
+    def __contains__(self, item):
+        return item in self._item_priority
+
+    def __getitem__(self, item):
+        "Return the priority of an item."
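+        # Priorities are computed lazily: `update()` only marks the
+        # prioritizer dirty; the next lookup triggers `_prioritize()`.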
+        if self.dirty:
+            self._prioritize()
+        return self._item_priority[item]
+
+    def get(self, item, default=None):
+        if item in self:
+            return self[item]
+        return default
+
+    def items(self):
+        "Return an (item, priority) iterator for all items."
+        if self.dirty:
+            self._prioritize()
+        return six.iteritems(self._item_priority)
+
+    def update(self, items, group):
+        """Add `items` to the `group`, remove `items` from all other groups,
+        and update all priority values."""
+        self.known = self.known.union(items)
+        for _group, _items in self._priority_groups.items():
+            if _group != group:
+                _items.difference_update(items)
+        self._priority_groups[group].update(items)
+        self.dirty = True
+
+    def group(self, group):
+        "Return the set of items in `group`."
+        if group not in self._priority_groups:
+            raise KeyError(repr(group))
+        return self._priority_groups[group]
+
+    def _prioritize(self):
+        item_priority = {}
+
+        for group, items in six.iteritems(self._priority_groups):
+            ordered_items = sorted(items, key=self.key_func)
+            for rank, item in enumerate(ordered_items):
+                priority = (group, rank)
+                item_priority[item] = priority
+
+        self._item_priority = item_priority
+        self.dirty = False
diff --git a/simplesat/sat/policy/__init__.py b/simplesat/sat/policy/__init__.py
index 2e8da63..8ee0ea5 100644
--- a/simplesat/sat/policy/__init__.py
+++ b/simplesat/sat/policy/__init__.py
@@ -4,5 +4,8 @@
 from .undetermined_clause_policy import (
     LoggedUndeterminedClausePolicy, UndeterminedClausePolicy
 )
+from .priority_queue_policy import (
+    LoggedPriorityQueuePolicy, PriorityQueuePolicy
+)
 
 InstalledFirstPolicy = LoggedUndeterminedClausePolicy
diff --git a/simplesat/sat/policy/priority_queue_policy.py b/simplesat/sat/policy/priority_queue_policy.py
new file mode 100644
index 0000000..4e0ce47
--- /dev/null
+++ b/simplesat/sat/policy/priority_queue_policy.py
@@ -0,0 +1,224 @@
+# -*- coding: utf-8 -*-
+
+from collections import defaultdict
+
+import six
+
+from simplesat.constraints.requirement import Requirement
+from simplesat.utils import DefaultOrderedDict, toposort, transitive_neighbors
+from simplesat.priority_queue import PriorityQueue, GroupPrioritizer
+from .policy import IPolicy
+from .policy_logger import PolicyLogger
+
+
+class PriorityQueuePolicy(IPolicy):
+
+    """ An IPolicy that uses a priority queue to determine which package id
+    should be suggested next.
+
+    Packages are split into groups:
+
+        1. currently installed,
+        2. explicitly specified as a requirement,
+        3. everything else,
+
+    where each group is arranged in topological order by dependency
+    relationships and then in descending order by version number.
+    The groups are then searched in order and the first unassigned
+    package id is suggested.
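+
+    When `prefer_installed` is False, installed packages are not ranked
+    ahead of explicit requirements; they fall into the default group
+    instead.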
+ """ + + def __init__(self, pool, installed_repository, prefer_installed=True): + self._pool = pool + self._installed_ids = set(map(pool.package_id, installed_repository)) + + package_ids = pool._id_to_package.keys() + self._package_id_to_rank = None # set the first time we check + self._all_ids = set(package_ids) + self._required_ids = set() + self._name_to_package_ids = self._group_packages_by_name(package_ids) + + def priority_func(p): + return self._package_id_to_rank[p] + + self._unassigned_pkg_ids = PriorityQueue() + + self.DEFAULT = 0 + if prefer_installed: + self.INSTALLED = -2 + self.REQUIRED = -1 + else: + self.REQUIRED = -1 + self.INSTALLED = self.DEFAULT + + self._prioritizer = GroupPrioritizer(priority_func) + self._add_packages(self._installed_ids.copy(), self.INSTALLED) + + def add_requirements(self, package_ids): + self._required_ids.update(package_ids) + if self.REQUIRED < self.INSTALLED: + self._installed_ids.difference_update(package_ids) + else: + package_ids = set(package_ids).difference(self._installed_ids) + self._add_packages(package_ids, self.REQUIRED) + + def get_next_package_id(self, assignments, clauses): + self._update_cache_from_assignments(assignments) + # Grab the most interesting looking currently unassigned id + p_id = self._unassigned_pkg_ids.peek() + return p_id + + def _add_packages(self, package_ids, group): + prioritizer = self._prioritizer + prioritizer.update(package_ids, group=group) + + # Removing an item from an ordering always maintains the ordering, + # so we only need to update priorities on groups that had items added + for pkg_id in prioritizer.group(group): + if pkg_id in self._unassigned_pkg_ids: + self._unassigned_pkg_ids.push(pkg_id, prioritizer[pkg_id]) + + def pkg_key(self, package_id): + """ Return the key used to compare two packages. """ + package = self._pool._id_to_package[package_id] + try: + installed = package.repository_info.name == 'installed' + except AttributeError: + installed = False + return (package.version, installed) + + def _rank_packages(self, package_ids): + """ Return a dictionary of package_id to priority rank. + + Currently we build a dependency tree of all the relevant packages and + then rank them topologically, starting with those at the top. + + This strategy causes packages which force more assignments via + unit propagation in the solver to be preferred. + """ + pool = self._pool + R = Requirement + + # The direct dependencies of each package + dependencies = defaultdict(set) + for package_id in package_ids: + dependencies[package_id].update( + pool.package_id(package) + for cons in pool._id_to_package[package_id].install_requires + for package in pool.what_provides(R.from_constraints(cons)) + ) + + # This is a flattened version of `dependencies` above + transitive = transitive_neighbors(dependencies) + + packages_by_name = self._group_packages_by_name(package_ids) + + # Some packages have unversioned dependencies, such as simply 'pandas'. + # This can produce cycles in the dependency graph which much be removed + # before topological sorting can be done. 
+        # The strategy is to ignore the dependencies of any package that is
+        # present in its own transitive dependency list.
+        removed_deps = []
+        for package_id in package_ids:
+            package = pool._id_to_package[package_id]
+            deps = dependencies[package_id]
+            package_group = packages_by_name[package.name]
+            for dep in list(deps):
+                circular = transitive[dep].intersection(package_group)
+                if circular:
+                    packages = [pool._id_to_package[p] for p in circular]
+                    depkg = pool._id_to_package[dep]
+                    pkg_strings = [
+                        "{}-{}".format(pkg.name, pkg.version)
+                        for pkg in packages
+                    ]
+                    msg = "Circular Deps: {}-{} -> {}-{} -> {}".format(
+                        package.name, package.version,
+                        depkg.name, depkg.version,
+                        pkg_strings
+                    )
+                    removed_deps.append(msg)
+                    deps.remove(dep)
+
+        # Mark packages as depending on older versions of themselves so that
+        # they will come out first in the toposort
+        for package_id in package_ids:
+            package = pool._id_to_package[package_id]
+            package_group = packages_by_name[package.name]
+            idx = package_group.index(package_id)
+            other_older = package_group[:idx + 1]
+            dependencies[package_id].update(other_older)
+
+        # Finally toposort the packages, preferring higher version and
+        # already-installed packages to break ties
+        ordered = [
+            package_id
+            for group in tuple(toposort(dependencies))
+            for package_id in sorted(group, key=self.pkg_key, reverse=True)
+        ]
+
+        package_id_to_rank = {
+            package_id: rank
+            for rank, package_id in enumerate(ordered)
+        }
+
+        return package_id_to_rank
+
+    def _group_packages_by_name(self, package_ids):
+        """ Return a dictionary from package name to all package ids
+        corresponding to packages with that name. """
+        pool = self._pool
+
+        name_map = DefaultOrderedDict(list)
+        for package_id in package_ids:
+            package = pool._id_to_package[package_id]
+            name_map[package.name].append(package_id)
+
+        name_to_package_ids = {}
+
+        for name, package_ids in name_map.items():
+            ordered = sorted(package_ids, key=self.pkg_key, reverse=True)
+            name_to_package_ids[name] = ordered
+
+        return name_to_package_ids
+
+    def _update_cache_from_assignments(self, assignments):
+        new_keys = assignments.new_keys.copy()
+        changelog = assignments.consume_changelog()
+
+        if new_keys:
+            unknown_ids = new_keys.difference(self._prioritizer.known)
+            self._all_ids.update(new_keys)
+            self._package_id_to_rank = self._rank_packages(self._all_ids)
+            self._prioritizer.update(unknown_ids, group=self.DEFAULT)
+
+            # Newly unassigned
+            self._unassigned_pkg_ids.clear()
+            for key in assignments.unassigned_ids:
+                priority = self._prioritizer[key]
+                self._unassigned_pkg_ids.push(key, priority=priority)
+        else:
+            for key, (old, new) in six.iteritems(changelog):
+                if new is None:
+                    # Newly unassigned
+                    priority = self._prioritizer[key]
+                    self._unassigned_pkg_ids.push(key, priority=priority)
+                elif old is None:
+                    # No longer unassigned (because new is not None)
+                    self._unassigned_pkg_ids.remove(key)
+
+                # The remaining cases are True -> False, False -> True, or
+                # MISSING -> (True|False)
+
+        # A very cheap sanity check
+        ours = len(self._unassigned_pkg_ids)
+        theirs = len(assignments) - assignments.num_assigned
+        has_new_keys = len(new_keys)
+        msg = "We failed to track variable assignments {} {} {}"
+        assert ours == theirs, msg.format(ours, theirs, has_new_keys)
+
+
+def LoggedPriorityQueuePolicy(pool, installed_repository, *args, **kwargs):
+    policy = PriorityQueuePolicy(pool, installed_repository, *args, **kwargs)
+    logger = PolicyLogger(policy, extra_args=args, extra_kwargs=kwargs)
+    return logger
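+
+
+# A minimal usage sketch (illustrative only; assumes a configured `pool` and
+# installed repository, with `assignments` and `clauses` supplied by the SAT
+# solver driving the policy; `requested_packages` is hypothetical):
+#
+#     policy = PriorityQueuePolicy(pool, installed_repository)
+#     policy.add_requirements(
+#         [pool.package_id(p) for p in requested_packages])
+#     suggestion = policy.get_next_package_id(assignments, clauses)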
diff --git a/simplesat/utils/__init__.py b/simplesat/utils/__init__.py
index 052c1fa..bb58d65 100644
--- a/simplesat/utils/__init__.py
+++ b/simplesat/utils/__init__.py
@@ -7,8 +7,9 @@
 import shutil
 import tempfile
 
-from .timed_context import timed_context
 from ._collections import DefaultOrderedDict
+from .graph import toposort, transitive_neighbors
+from .timed_context import timed_context
 
 
 @contextlib.contextmanager
diff --git a/simplesat/utils/graph.py b/simplesat/utils/graph.py
new file mode 100644
index 0000000..5dcdd1a
--- /dev/null
+++ b/simplesat/utils/graph.py
@@ -0,0 +1,76 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from __future__ import division, print_function
+
+from collections import defaultdict
+
+import six
+
+
+def toposort(nodes_to_edges):
+    """Dependencies are expressed as a dictionary whose keys are items and
+    whose values are sets of the items they depend upon. Output is a
+    sequence of sets in topological order. The first set consists of items
+    with no dependencies, each subsequent set consists of items that depend
+    upon items in the preceding sets.
+
+    >>> print('\\n'.join(repr(sorted(x)) for x in toposort({
+    ...     2: set([11]),
+    ...     9: set([11, 8]),
+    ...     10: set([11, 3]),
+    ...     11: set([7, 5]),
+    ...     8: set([7, 3]),
+    ... })))
+    [3, 5, 7]
+    [8, 11]
+    [2, 9, 10]
+
+    """
+
+    data = {k: v.copy() for k, v in six.iteritems(nodes_to_edges)}
+
+    # Ignore self dependencies.
+    for k, v in six.iteritems(data):
+        v.discard(k)
+
+    # Find items that only appear as dependencies; they depend on nothing
+    # themselves.
+    extra_items_in_deps = six.moves.reduce(set.union, six.itervalues(data))
+    extra_items_in_deps.difference_update(set(six.iterkeys(data)))
+
+    # Add empty dependencies where needed.
+    data.update({item: set() for item in extra_items_in_deps})
+
+    while True:
+        ordered = set(item for item, dep in six.iteritems(data) if not dep)
+        if not ordered:
+            break
+        yield ordered
+        data = {item: (dep - ordered)
+                for item, dep in six.iteritems(data)
+                if item not in ordered}
+    if data:
+        msg = "Cyclic dependencies exist among these items:\n{}"
+        cyclic = '\n'.join(repr(x) for x in six.iteritems(data))
+        raise ValueError(msg.format(cyclic))
+
+
+def transitive_neighbors(nodes_to_edges):
+    """ Return the set of all reachable nodes for each node in the
+    nodes_to_edges adjacency dict. """
+    trans = defaultdict(set)
+    for node in nodes_to_edges.keys():
+        _transitive(node, nodes_to_edges, trans)
+    return trans
+
+
+def _transitive(node, nodes_to_edges, trans):
+    trans = trans if trans is not None else defaultdict(set)
+    if node in trans:
+        return trans
+    neighbors = nodes_to_edges[node]
+    trans[node].update(neighbors)
+    for neighbor in neighbors:
+        _transitive(neighbor, nodes_to_edges, trans)
+        trans[node].update(trans[neighbor])
+    return trans
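+
+
+# Example (sketch): given edges = {1: {2}, 2: {3}, 3: set()},
+# list(toposort(edges)) yields [{3}, {2}, {1}], and transitive_neighbors(edges)
+# maps 1 -> {2, 3}, 2 -> {3}, 3 -> set().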