diff --git a/conftest.py b/conftest.py
index 2ae6a1e..5235f31 100644
--- a/conftest.py
+++ b/conftest.py
@@ -68,6 +68,12 @@ def pytest_addoption(parser):
         '--skip-xfail', action="store_true", dest="skip_xfail",
         help="Do not run the tests known to fail when in compare mode."
     )
+    parser.addoption(
+        '--check-duplicates',
+        type=int,
+        dest="check_duplicates",
+        help="Check that there are no duplicates in the first N results."
+    )
 
 
 def pytest_configure(config):
@@ -75,6 +81,7 @@ def pytest_configure(config):
     CONFIG['MAX_RUN'] = config.getoption('--max-run')
     CONFIG['LOOSE_COMPARE'] = config.getoption('--loose-compare')
     CONFIG['GEOJSON'] = config.getoption('--geojson')
+    CONFIG['CHECK_DUPLICATES'] = config.getoption('--check-duplicates')
     CONFIG['SKIP_XFAIL'] = config.getoption('--skip-xfail')
     if config.getoption('--compare-report'):
         with open(config.getoption('--compare-report')) as f:
@@ -151,6 +158,13 @@ def __init__(self, name, parent, **kwargs):
         self.comment = kwargs.get('comment')
         self.skip = kwargs.get('skip')
         self.mark = kwargs.get('mark', [])
+
+        self.max_matches = kwargs.get('max_matches')
+        if self.max_matches:
+            self.max_matches = int(self.max_matches)
+        else:
+            self.max_matches = None
+
         for mark in self.mark:
             self.add_marker(mark)
 
@@ -161,7 +175,8 @@ def runtest(self):
             'query': self.query,
             'expected': self.expected,
             'lang': self.lang,
-            'comment': self.comment
+            'comment': self.comment,
+            'max_matches': self.max_matches
         }
         if self.lat and self.lon:
             kwargs['center'] = [self.lat, self.lon]
diff --git a/geocoder_tester/base.py b/geocoder_tester/base.py
index b3bfacb..27d9f69 100644
--- a/geocoder_tester/base.py
+++ b/geocoder_tester/base.py
@@ -5,6 +5,8 @@
 from geopy import Point
 from geopy.distance import distance
 from unidecode import unidecode
+from collections import defaultdict
+
 
 POTSDAM = [52.3879, 13.0582]
 BERLIN = [52.519854, 13.438596]
@@ -16,11 +18,38 @@
     'MAX_RUN': 0,  # means no limit
     'GEOJSON': False,
     'FAILED': [],
+    'CHECK_DUPLICATES': None,
 }
 
 http = requests.Session()
 
 
+def get_properties(f):
+    if 'geocoding' in f['properties']:
+        return f['properties']['geocoding']
+    else:
+        return f['properties']
+
+
+def get_duplicates_key(feature):
+    """
+    Return the key used to check whether a feature has duplicates.
+    It is tuned to the way results are displayed to the end user.
+
+    For most objects the key is the label (or the name when there is no
+    label) plus the type of the object. The type matters because a POI
+    can, for example, have the same name as a stop.
+
+    POIs are trickier: their address is also part of the key, because two
+    bars with the same name can exist in the same city.
+    """
+    obj = get_properties(feature)
+    label = obj.get('label') or obj.get('name')
+    if obj.get('type') == 'poi':
+        addr = obj.get('address', {}).get('label', '')
+        return (label, obj['type'], addr)
+    return (label, obj.get('type'))
+
+
 class HttpSearchException(Exception):
 
     def __init__(self, **kwargs):
@@ -31,6 +60,41 @@ def __str__(self):
         return self.error
 
 
+class DuplicatesException(Exception):
+    """ custom exception for duplicates reporting. """
+
+    def __init__(self, duplicates, params):
+        super().__init__()
+        self.duplicates = duplicates
+        self.query = params.get('q')
+
+    def __str__(self):
+        lines = [
+            '',
+            'Duplicates found in the response',
+            "# Search was: {}".format(self.query),
+        ]
+        for key, features in self.duplicates.items():
+            lines.append('## Duplicates found for {}:'.format(key))
+            keys = [
+                'label', 'id', 'type', 'osm_id', 'housenumber', 'street',
+                'postcode', 'city', 'country', 'lat', 'lon', 'addr', 'poi_types'
+            ]
+
+            def flatten_res(f):
+                r = get_properties(f)
+                coords = f.get('geometry', {}).get('coordinates', [None, None])
+                r['lat'] = coords[1]
+                r['lon'] = coords[0]
+                r['addr'] = r.get('address', {}).get('label')
+                r['poi_types'] = "-".join(t['name'] for t in r.get('poi_types', []))
+                return r
+
+            results = [flatten_res(f) for f in features]
+            lines.extend(dicts_to_table(results, keys=keys))
+            lines.append('')
+
+        return "\n".join(lines)
+
+
 class SearchException(Exception):
     """ custom exception for error reporting. """
 
@@ -63,7 +127,7 @@ def __str__(self):
             'name', 'osm_key', 'osm_value', 'osm_id', 'housenumber', 'street',
             'postcode', 'city', 'country', 'lat', 'lon', 'distance'
         ]
-        results = [self.flat_result(f) for f in self.results['features']]
+        results = [self.flat_result(f) for f in self.results]
         lines.extend(dicts_to_table(results, keys=keys))
         lines.append('')
         if CONFIG['GEOJSON']:
@@ -85,19 +149,15 @@ def __str__(self):
         return "\n".join(lines)
 
     def to_geojson(self, coordinates, **properties):
-        self.results['features'].append({
+        self.results.append({
             "type": "Feature",
             "geometry": {"type": "Point", "coordinates": coordinates},
             "properties": properties,
         })
-        return json.dumps(self.results)
+        return json.dumps({'features': self.results})
 
     def flat_result(self, result):
-        out = None
-        if 'geocoding' in result['properties']:
-            out = result['properties']['geocoding']
-        else:
-            out = result['properties']
+        out = get_properties(result)
         if 'geometry' in result:
             out['lat'] = result['geometry']['coordinates'][1]
             out['lon'] = result['geometry']['coordinates'][0]
@@ -132,24 +192,23 @@ def compare_values(get, expected):
 
 
 def assert_search(query, expected, limit=1,
-                  comment=None, lang=None, center=None):
-    params = {"q": query, "limit": limit}
+                  comment=None, lang=None, center=None,
+                  max_matches=None):
+    query_limit = max(CONFIG['CHECK_DUPLICATES'] or 0, int(limit))
+    params = {"q": query, "limit": query_limit}
     if lang:
         params['lang'] = lang
     if center:
         params['lat'] = center[0]
         params['lon'] = center[1]
-    results = search(**params)
+    raw_results = search(**params)
+    results = raw_results['features'][:int(limit)]
 
     def assert_expected(expected):
-        found = False
-        for r in results['features']:
-            passed = True
-            properties = None
-            if 'geocoding' in r['properties']:
-                properties = r['properties']['geocoding']
-            else:
-                properties = r['properties']
+        nb_found = 0
+        for r in results:
+            found = True
+            properties = get_properties(r)
             failed = properties['failed'] = []
             for key, value in expected.items():
                 value = str(value)
@@ -166,22 +225,50 @@ def assert_expected(expected):
                     if int(deviation.meters) <= int(max_deviation):
                         continue  # Continue to other properties
                     failed.append('distance')
-                    passed = False
+                    found = False
                     failed.append(key)
-            if passed:
-                found = True
-        if not found:
+            if found:
+                nb_found += 1
+                if max_matches is None:
+                    break
+
+        if nb_found == 0:
             raise SearchException(
                 params=params,
                 expected=expected,
                 results=results
             )
+        elif max_matches is not None and nb_found > max_matches:
+            message = 'Got {} matching results. Expected at most {}.'.format(
+                nb_found, max_matches
+            )
+            raise SearchException(
+                params=params,
+                expected=expected,
+                results=results,
+                message=message
+            )
 
     if not isinstance(expected, list):
         expected = [expected]
     for s in expected:
         assert_expected(s)
+
+    if CONFIG['CHECK_DUPLICATES']:
+        check_duplicates(raw_results['features'][:CONFIG['CHECK_DUPLICATES']], params)
+
+
+def check_duplicates(features, params):
+    results = defaultdict(list)
+    for f in features:
+        key = get_duplicates_key(f)
+        results[key].append(f)
+
+    duplicates = {k: dup for k, dup in results.items() if len(dup) != 1}
+    if duplicates:
+        raise DuplicatesException(duplicates, params)
 
 
 def dicts_to_table(dicts, keys):
     if not dicts:
@@ -207,7 +294,7 @@ def dicts_to_table(dicts, keys):
         l = lengths.copy()
         for key in keys:
             value = d.get(key) or '_'
-            if key in d['failed']:
+            if key in d.get('failed', {}):
                 l[key] += 10  # Add ANSI chars so python len will turn out.
                 value = "\033[1;4m{}\033[0m".format(value)
             row[key] = value
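
As a quick illustration of the new helpers, here is a minimal sketch assuming the patch above is applied; the two features and all their values are invented for the example. check_duplicates groups features by their display key and raises as soon as a key occurs more than once:

    from geocoder_tester.base import DuplicatesException, check_duplicates

    # Hypothetical features: both render as "Le Central" with the same type,
    # so get_duplicates_key() returns the same key for each of them.
    features = [
        {
            "type": "Feature",
            "geometry": {"type": "Point", "coordinates": [2.3522, 48.8566]},
            "properties": {"geocoding": {"label": "Le Central", "type": "street"}},
        },
        {
            "type": "Feature",
            "geometry": {"type": "Point", "coordinates": [2.3401, 48.8534]},
            "properties": {"geocoding": {"label": "Le Central", "type": "street"}},
        },
    ]

    try:
        check_duplicates(features, {"q": "le central"})
    except DuplicatesException as report:
        print(report)  # table listing the two entries that share the key

Note that two POIs with the same name but different addresses are not reported, since the address label is part of the key for POIs. During a test run the same check is applied to the first N results of every query when pytest is started with the new option, e.g. --check-duplicates 10.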