Skip to content

Commit

Permalink
fixes #113
Browse files Browse the repository at this point in the history
  • Loading branch information
WolfgangFahl committed May 6, 2024
1 parent 5477616 commit 8e620b6
Show file tree
Hide file tree
Showing 12 changed files with 177 additions and 32 deletions.
2 changes: 1 addition & 1 deletion lodstorage/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.10.6"
__version__ = "0.11.0"
File renamed without changes.
84 changes: 84 additions & 0 deletions lodstorage/params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""
Created on 2024-05-06
@author: wf
"""
import argparse
import re
from typing import Dict, Optional


class Params:
"""
parameter handling
"""

def __init__(self, query: str):
"""
constructor
Args:
query(str): the query to analyze for parameters
"""
self.query = query
self.pattern = re.compile(r"{{\s*(\w+)\s*}}")
self.params = self.pattern.findall(query)
self.params_dict = {param: "" for param in self.params}
self.has_params = len(self.params) > 0

def set(self, params_dict: Dict):
"""
set my params
"""
self.params_dict = params_dict

def apply_parameters(self) -> str:
"""
Replace Jinja templates in the query with corresponding parameter values.
Returns:
str: The query with Jinja templates replaced by parameter values.
"""
query = self.query
for param, value in self.params_dict.items():
pattern = re.compile(r"{{\s*" + re.escape(param) + r"\s*\}\}")
query = re.sub(pattern, value, query)
return query


class StoreDictKeyPair(argparse.Action):
"""
Custom argparse action to store key-value pairs as a dictionary.
This class implements an argparse action to parse and store command-line
arguments in the form of key-value pairs. The pairs should be separated by
a comma and each key-value pair should be separated by an equals sign.
Example:
--option key1=value1,key2=value2,key3=value3
Reference:
https://stackoverflow.com/a/42355279/1497139
"""

def __call__(
self,
_parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: str,
_option_string: Optional[str] = None,
) -> None:
"""
Parse key-value pairs and store them as a dictionary in the namespace.
Args:
parser (argparse.ArgumentParser): The argument parser object.
namespace (argparse.Namespace): The namespace to store the parsed values.
values (str): The string containing key-value pairs separated by commas.
option_string (Optional[str]): The option string, if provided.
"""
my_dict = {}
for kv in values.split(","):
k, v = kv.split("=")
my_dict[k] = v
setattr(namespace, self.dest, my_dict)
20 changes: 13 additions & 7 deletions lodstorage/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __str__(self):

class YamlPath:
@staticmethod
def getPaths(yamlFileName: str, yamlPath: str = None, with_default:bool=True):
def getPaths(yamlFileName: str, yamlPath: str = None, with_default: bool = True):
"""
Args:
yamlFileName (str): The name of the YAML file to read from if (any) - legacy way to specify name
Expand Down Expand Up @@ -604,7 +604,9 @@ class QueryManager(object):
manages pre packaged Queries
"""

def __init__(self, lang: str = None, debug=False, queriesPath=None,with_defaults:bool=True):
def __init__(
self, lang: str = None, debug=False, queriesPath=None, with_default: bool = True
):
"""
Constructor
Args:
Expand All @@ -618,7 +620,9 @@ def __init__(self, lang: str = None, debug=False, queriesPath=None,with_defaults
self.queriesByName = {}
self.lang = lang
self.debug = debug
queries = QueryManager.getQueries(queriesPath=queriesPath,with_default=with_default)
queries = QueryManager.getQueries(
queriesPath=queriesPath, with_default=with_default
)
for name, queryDict in queries.items():
if self.lang in queryDict:
queryText = queryDict.pop(self.lang)
Expand All @@ -635,16 +639,18 @@ def __init__(self, lang: str = None, debug=False, queriesPath=None,with_defaults
self.queriesByName[name] = query

@staticmethod
def getQueries(queriesPath=None, with_default:bool=True):
def getQueries(queriesPath=None, with_default: bool = True):
"""
get the queries for the given queries Path
Args:
queriesPath(str): the path of the yaml file to load queries from
with_default(bool): if True also load the default yaml file
"""
queriesPaths = YamlPath.getPaths("queries.yaml", queriesPath, with_default=with_default)
queriesPaths = YamlPath.getPaths(
"queries.yaml", queriesPath, with_default=with_default
)
queries = {}
for queriesPath in queriesPaths:
if os.path.isfile(queriesPath):
Expand Down
28 changes: 24 additions & 4 deletions lodstorage/querymain.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""
from pathlib import Path

from lodstorage.params import Params, StoreDictKeyPair
from lodstorage.version import Version

__version__ = Version.version
Expand All @@ -22,7 +23,7 @@

import requests

from lodstorage.csv import CSV
from lodstorage.lod_csv import CSV
from lodstorage.query import (
Endpoint,
EndpointManager,
Expand Down Expand Up @@ -86,9 +87,18 @@ def main(cls, args):
queryCode = queryFilePath.read_text()
name = queryFilePath.stem
query = Query(name="?", query=queryCode, lang=args.language)

if queryCode:
params = Params(query.query)
if params.has_params:
if not args.params:
raise Exception(f"{query.name} needs parameters")
else:
params.set(args.params)
query.query = params.apply_parameters()
queryCode=query.query
if debug or args.showQuery:
print(f"{args.language}:\n{queryCode}")
print(f"{args.language}:\n{query.query}")
endpointConf = Endpoint()
endpointConf.method = "POST"
if args.endpointName:
Expand Down Expand Up @@ -148,7 +158,7 @@ def main(cls, args):
raise Exception(f"format {args.format} not supported yet")

@staticmethod
def rawQuery(endpointConf, query, resultFormat, mimeType, timeout:float=10.0):
def rawQuery(endpointConf, query, resultFormat, mimeType, timeout: float = 10.0):
"""
returns raw result of the endpoint
Expand All @@ -171,7 +181,12 @@ def rawQuery(endpointConf, query, resultFormat, mimeType, timeout:float=10.0):
endpoint = endpointConf.endpoint
method = endpointConf.method
response = requests.request(
method, endpoint, headers=headers, data=payload, params=params,timeout=timeout
method,
endpoint,
headers=headers,
data=payload,
params=params,
timeout=timeout,
)
return response.text

Expand Down Expand Up @@ -269,6 +284,11 @@ def main(argv=None, lang=None): # IGNORE:C0111
parser.add_argument(
"--limit", type=int, default=None, help="set limit parameter of query"
)
parser.add_argument(
"--params",
action=StoreDictKeyPair,
help="query parameters as Key-value pairs in the format key1=value1,key2=value2",
)
parser.add_argument(
"-le",
"--listEndpoints",
Expand Down
2 changes: 2 additions & 0 deletions lodstorage/sparql.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ def fix_comments(self, query_string: str) -> str:
"""
make sure broken SPARQLWrapper will find comments
"""
if query_string is None:
return None
return "#\n" + query_string

def getValue(self, sparqlQuery: str, attr: str):
Expand Down
26 changes: 13 additions & 13 deletions lodstorage/sql_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from lodstorage.query import QueryManager
from lodstorage.sparql import SPARQL


class SqlDB:
"""
general SQL database access using SQL Alchemy
Expand Down Expand Up @@ -65,11 +66,11 @@ def __init__(
self.debug = debug
self.entities = []
self.errors = []
self.fetched=False
self.fetched = False
# Ensure the table for the class exists
clazz.metadata.create_all(self.sql_db.engine)

def fetch_or_query(self, qm, force_query=False)->List[Dict]:
def fetch_or_query(self, qm, force_query=False) -> List[Dict]:
"""
Fetches data from the local cache if available.
If the data is not in the cache or if force_query is True,
Expand All @@ -79,12 +80,12 @@ def fetch_or_query(self, qm, force_query=False)->List[Dict]:
qm (QueryManager): The query manager object used for making SPARQL queries.
force_query (bool, optional): A flag to force querying via SPARQL even if the data exists in the local cache. Defaults to False.
Returns:
List: list of records from the SQL database
List: list of records from the SQL database
"""
if not force_query and self.check_local_cache():
lod=self.fetch_from_local()
lod = self.fetch_from_local()
else:
lod=self.get_lod(qm)
lod = self.get_lod(qm)
self.store()
return lod

Expand All @@ -98,11 +99,11 @@ def check_local_cache(self) -> bool:
with self.sql_db.get_session() as session:
result = session.exec(select(self.clazz)).first()
return result is not None

def fetch_from_local(self) -> List[Dict]:
"""
Fetches data from the local SQL database as list of dicts and entities.
Returns:
List[Dict]: List of records from the SQL database in dictionary form.
"""
Expand All @@ -115,7 +116,6 @@ def fetch_from_local(self) -> List[Dict]:
profiler.time()
return self.lod


def get_lod(self, qm: QueryManager) -> List[Dict]:
"""
Fetches data using the SPARQL query specified by my query_name.
Expand All @@ -137,7 +137,7 @@ def get_lod(self, qm: QueryManager) -> List[Dict]:
print(f"Found {len(self.lod)} records for {self.query_name}")
return self.lod

def to_entities(self, max_errors: int = None,cached:bool=True) -> List[Any]:
def to_entities(self, max_errors: int = None, cached: bool = True) -> List[Any]:
"""
Converts records fetched from the LOD into entity instances, applying validation.
Expand All @@ -152,7 +152,7 @@ def to_entities(self, max_errors: int = None,cached:bool=True) -> List[Any]:
self.errors = []
elif self.fetched:
return self.entities

error_records = []
if max_errors is None:
max_errors = self.max_errors
Expand All @@ -171,7 +171,7 @@ def to_entities(self, max_errors: int = None,cached:bool=True) -> List[Any]:
for i, e in enumerate(self.errors):
print(f"{i}:{str(e)} for \n{error_records[i]}")
raise Exception(msg)
self.fetched=True
self.fetched = True
return self.entities

def store(self, max_errors: int = None) -> List[Any]:
Expand All @@ -186,11 +186,11 @@ def store(self, max_errors: int = None) -> List[Any]:
"""
profiler = Profiler(f"store {self.query_name}", profile=self.debug)
self.to_entities(max_errors=max_errors,cached=False)
self.to_entities(max_errors=max_errors, cached=False)
with self.sql_db.get_session() as session:
session.add_all(self.entities)
session.commit()
if self.debug:
print(f"Stored {len(self.entities)} records in local cache")
profiler.time()
return self.entities
return self.entities
4 changes: 2 additions & 2 deletions sampledata/queries.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,14 @@
SELECT * FROM sample
LIMIT 5
'cities':
title: 'German cities by population'
title: 'Cities of a country by population'
sparql: |
PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
PREFIX wd: <http://www.wikidata.org/entity/>
PREFIX wdt: <http://www.wikidata.org/prop/direct/>
SELECT ?city_id ?name (MAX(?population_claim) AS ?population) WHERE {
?city_id wdt:P31/wdt:P279* wd:Q515 .
?city_id wdt:P17 wd:Q183 .
?city_id wdt:P17 wd:{{country}}.
?city_id wdt:P1082 ?population_claim .
?city_id rdfs:label ?name .
FILTER (LANG(?name) = "en")
Expand Down
2 changes: 1 addition & 1 deletion tests/testSPARQL.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def testStackoverflow71444069(self):
"""
https://stackoverflow.com/questions/71444069/create-csv-from-result-of-a-for-google-colab/71548650#71548650
"""
from lodstorage.csv import CSV
from lodstorage.lod_csv import CSV
from lodstorage.sparql import SPARQL

sparqlQuery = """SELECT ?org ?orgLabel
Expand Down
2 changes: 1 addition & 1 deletion tests/test_csv.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import tempfile

from lodstorage.csv import CSV
from lodstorage.lod_csv import CSV
from lodstorage.jsonable import JSONAble, JSONAbleList
from lodstorage.lod import LOD
from lodstorage.sample import Sample
Expand Down
34 changes: 34 additions & 0 deletions tests/test_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""
Created on 2024-05-06
@author: wf
"""
from lodstorage.params import Params
from tests.basetest import Basetest


class TestParams(Basetest):
"""
test the params handling
"""

def setUp(self, debug=False, profile=True):
Basetest.setUp(self, debug=debug, profile=profile)

def test_jinja_params(self):
"""
test jinia_params
"""
for sample, params_dict in [
("PREFIX target: <http://www.wikidata.org/entity/{{ q }}>", {"q": "Q80"})
]:
params = Params(sample)
if self.debug:
print(params.params)
self.assertEqual(["q"], params.params)
self.assertTrue("q" in params.params_dict)
params.params_dict = params_dict
query = params.apply_parameters()
self.assertEqual(
"PREFIX target: <http://www.wikidata.org/entity/Q80>", query
)
Loading

0 comments on commit 8e620b6

Please sign in to comment.