Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check for duplicates #48

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,20 @@ def pytest_addoption(parser):
'--skip-xfail', action="store_true", dest="skip_xfail",
help="Do not run the tests known to fail when in compare mode."
)
parser.addoption(
'--check-duplicates',
type=int,
dest="check_duplicates",
help="Activates the check that there is no duplicate in the nth first results"
)


def pytest_configure(config):
CONFIG['API_URL'] = config.getoption('--api-url')
CONFIG['MAX_RUN'] = config.getoption('--max-run')
CONFIG['LOOSE_COMPARE'] = config.getoption('--loose-compare')
CONFIG['GEOJSON'] = config.getoption('--geojson')
CONFIG['CHECK_DUPLICATES'] = config.getoption('--check-duplicates')
CONFIG['SKIP_XFAIL'] = config.getoption('--skip-xfail')
if config.getoption('--compare-report'):
with open(config.getoption('--compare-report')) as f:
Expand Down Expand Up @@ -151,6 +158,13 @@ def __init__(self, name, parent, **kwargs):
self.comment = kwargs.get('comment')
self.skip = kwargs.get('skip')
self.mark = kwargs.get('mark', [])

self.max_matches = kwargs.get('max_matches')
if self.max_matches:
self.max_matches = int(self.max_matches)
else:
self.max_matches = None

for mark in self.mark:
self.add_marker(mark)

Expand All @@ -161,7 +175,8 @@ def runtest(self):
'query': self.query,
'expected': self.expected,
'lang': self.lang,
'comment': self.comment
'comment': self.comment,
'max_matches': self.max_matches
}
if self.lat and self.lon:
kwargs['center'] = [self.lat, self.lon]
Expand Down
135 changes: 111 additions & 24 deletions geocoder_tester/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from geopy import Point
from geopy.distance import distance
from unidecode import unidecode
from collections import defaultdict


POTSDAM = [52.3879, 13.0582]
BERLIN = [52.519854, 13.438596]
Expand All @@ -16,11 +18,38 @@
'MAX_RUN': 0, # means no limit
'GEOJSON': False,
'FAILED': [],
'CHECK_DUPLICATES': None,
}

http = requests.Session()


def get_properties(f):
if 'geocoding' in f['properties']:
return f['properties']['geocoding']
else:
return f['properties']


def get_duplicates_key(feature):
"""
returns a key used to check if a feature has a duplicate.
This is a bit tuned on how the results are displayed to the end user.

For the majority of objects we use the label (or name if there is no label) + the type of the object
The type is used because for example there can be a POI with the same name as a Stop

For the POI it's a bit trickier, we also use the address of the POI
because there can be for example 2 bars with the same name in the same city
"""
obj = get_properties(feature)
label = obj.get('label') or feature['name']
if obj.get('type') == 'poi':
addr = obj.get('address', {}).get('label', '')
return (label, obj['type'], addr)
return (label, obj.get('type'))


class HttpSearchException(Exception):

def __init__(self, **kwargs):
Expand All @@ -31,6 +60,41 @@ def __str__(self):
return self.error


class DuplicatesException(Exception):
""" custom exception for duplicates reporting. """

def __init__(self, duplicates, params):
super().__init__()
self.duplicates = duplicates
self.query = params.get('q')

def __str__(self):
lines = [
'',
'Duplicates found in the response',
"# Search was: {}".format(self.query),
]
for key, features in self.duplicates.items():
lines.append('## Entry {} has been found for:'.format(key))
keys = [
'label', 'id', 'type', 'osm_id', 'housenumber', 'street',
'postcode', 'city', 'country', 'lat', 'lon', 'addr', 'poi_types'
]
def flatten_res(f):
r = get_properties(f)
coords = f.get('geometry', {}).get('coordinates', [None, None])
r['lat'] = coords[1]
r['lon'] = coords[0]
r['addr'] = r.get('address', {}).get('label')
r['poi_types'] = "-".join(t['name'] for t in r.get('poi_types', []))
return r
results = [flatten_res(f) for f in features]
lines.extend(dicts_to_table(results, keys=keys))
lines.append('')

return "\n".join(lines)


class SearchException(Exception):
""" custom exception for error reporting. """

Expand Down Expand Up @@ -63,7 +127,7 @@ def __str__(self):
'name', 'osm_key', 'osm_value', 'osm_id', 'housenumber', 'street',
'postcode', 'city', 'country', 'lat', 'lon', 'distance'
]
results = [self.flat_result(f) for f in self.results['features']]
results = [self.flat_result(f) for f in self.results]
lines.extend(dicts_to_table(results, keys=keys))
lines.append('')
if CONFIG['GEOJSON']:
Expand All @@ -85,19 +149,15 @@ def __str__(self):
return "\n".join(lines)

def to_geojson(self, coordinates, **properties):
self.results['features'].append({
self.results.append({
"type": "Feature",
"geometry": {"type": "Point", "coordinates": coordinates},
"properties": properties,
})
return json.dumps(self.results)
return json.dumps({'features': self.results})

def flat_result(self, result):
out = None
if 'geocoding' in result['properties']:
out = result['properties']['geocoding']
else:
out = result['properties']
out = get_properties(result)
if 'geometry' in result:
out['lat'] = result['geometry']['coordinates'][1]
out['lon'] = result['geometry']['coordinates'][0]
Expand Down Expand Up @@ -132,24 +192,23 @@ def compare_values(get, expected):


def assert_search(query, expected, limit=1,
comment=None, lang=None, center=None):
params = {"q": query, "limit": limit}
comment=None, lang=None, center=None,
max_matches=None):
query_limit = max(CONFIG['CHECK_DUPLICATES'] or 0, int(limit))
params = {"q": query, "limit": query_limit}
if lang:
params['lang'] = lang
if center:
params['lat'] = center[0]
params['lon'] = center[1]
results = search(**params)
raw_results = search(**params)
results = raw_results['features'][:int(limit)]

def assert_expected(expected):
found = False
for r in results['features']:
passed = True
properties = None
if 'geocoding' in r['properties']:
properties = r['properties']['geocoding']
else:
properties = r['properties']
nb_found = 0
for r in results:
found = True
properties = get_properties(r)
failed = properties['failed'] = []
for key, value in expected.items():
value = str(value)
Expand All @@ -166,22 +225,50 @@ def assert_expected(expected):
if int(deviation.meters) <= int(max_deviation):
continue # Continue to other properties
failed.append('distance')
passed = False
found = False
failed.append(key)
if passed:
found = True
if not found:
if found:
nb_found += 1
if max_matches is None:
break

if nb_found == 0:

raise SearchException(
params=params,
expected=expected,
results=results
)
elif max_matches is not None and nb_found > max_matches:
message = 'Got {} matching results. Expected at most {}.'.format(
nb_found, max_matches
)
raise SearchException(
params=params,
expected=expected,
results=results,
message=message
)

if not isinstance(expected, list):
expected = [expected]
for s in expected:
assert_expected(s)

if CONFIG['CHECK_DUPLICATES']:
check_duplicates(raw_results['features'][:CONFIG['CHECK_DUPLICATES']], params)


def check_duplicates(features, params):
results = defaultdict(list)
for f in features:
key = get_duplicates_key(f)
results[key].append(f)

duplicates = {k: dup for k, dup in results.items() if len(dup) != 1}
if duplicates:
raise DuplicatesException(duplicates, params)


def dicts_to_table(dicts, keys):
if not dicts:
Expand All @@ -207,7 +294,7 @@ def dicts_to_table(dicts, keys):
l = lengths.copy()
for key in keys:
value = d.get(key) or '_'
if key in d['failed']:
if key in d.get('failed', {}):
l[key] += 10 # Add ANSI chars so python len will turn out.
value = "\033[1;4m{}\033[0m".format(value)
row[key] = value
Expand Down