Source code for nsaph.requests.query

"""
Utility to generate SQL query based on a YAML query specification
"""


#  Copyright (c) 2021. Harvard University
#
#  Developed by Research Software Engineering,
#  Faculty of Arts and Sciences, Research Computing (FAS RC)
#  Author: Michael A Bouzinier
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#

import argparse
import json
import os
from collections import OrderedDict
from pathlib import Path
from typing import Dict, List, Set

import yaml
from nsaph_utils.utils.io_utils import as_dict

from nsaph import init_logging
from nsaph.db import Connection, ResultSetDeprecated


def fqn(t):
    """
    Return the name under which table ``t`` appears in generated SQL.

    The table name is currently used unchanged; no schema qualification
    is applied.

    :param t: table name
    :return: the same value, unchanged
    """
    qualified_name = t
    return qualified_name
class Query:
    """
    Class providing the API to generate SQL from a user request
    """

    def __init__(self, user_request, connection):
        """
        Load the table registry and the user request.

        :param user_request: user request; converted to a dictionary with
            ``as_dict`` (the ``__main__`` block of this module passes a
            path to a YAML file)
        :param connection: either a ``Connection`` instance or a sequence
            of positional arguments for the ``Connection`` constructor
        """
        request = as_dict(user_request)
        # NOTE(review): parents[3] is assumed to be the source root that
        # holds the "yml" directory -- confirm against the repository layout
        src = Path(__file__).parents[3]
        registry_path = os.path.join(src, "yml", "gridmet.yaml")
        with open(registry_path) as rf:
            self.registry = yaml.safe_load(rf)
        self.request = request
        if isinstance(connection, Connection):
            self.connection = connection
        else:
            self.connection = Connection(*connection)
        # Cursor created by prepare()
        self.cursor = None
        # Generated SQL (set by prepare())
        self.sql = None
        # Database type metadata (set by prepare()), passed to the result set
        self.metadata = None

    def execute(self):
        """
        Run the SQL produced by :meth:`prepare`.

        :return: ``ResultSetDeprecated`` wrapping the cursor and the
            database type metadata
        """
        self.cursor.execute(self.sql)
        return ResultSetDeprecated(cursor=self.cursor, metadata=self.metadata)

    def prepare(self):
        """
        Generates SQL. Result is stored as a class member sql

        Also opens the database connection, reads the database type
        metadata and creates the cursor used by :meth:`execute`.

        :return: None
        """
        connection = self.connection.connect()
        self.metadata = self.connection.get_database_types()
        self.cursor = connection.cursor()
        self.sql = generate(self.registry, self.request)

    def __enter__(self):
        """
        Prepare the query on entering a ``with`` block.

        :return: this Query instance
        """
        try:
            self.prepare()
        except Exception:
            # __exit__ is not invoked when __enter__ raises, so the
            # connection must be released here. A failure to close is
            # deliberately ignored so that it cannot hide the original error.
            try:
                self.connection.close()
            except Exception:
                pass
            raise
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Close the database connection; exceptions are not suppressed.
        """
        self.connection.close()
def find_tables(column: str, tables: Dict) -> List[str]:
    """
    Find the tables that contain a given column.

    The tables of the current hierarchy level are checked first. Only when
    none of them has the column does the search descend, into the children
    of all tables at this level together.

    :param column: column name to look for
    :param tables: mapping of table name to table definition; a definition
        has a ``"columns"`` collection and may have a ``"children"`` mapping
        of the same shape
    :return: names of the tables at the shallowest level where the column
        occurs
    :raises Exception: if no table at any level contains the column
    """
    children = dict()
    result = []
    for name, definition in tables.items():
        if column in definition["columns"]:
            result.append(name)
        if "children" in definition:
            # Merge instead of overwrite so that the children of every
            # table at this level are searched, not only the last one's
            children.update(definition["children"])
    if result:
        return result
    if children:
        return find_tables(column, children)
    raise Exception("Column named {} is not found in any of the defined tables"
                    .format(column))
def generate_select(variables: Dict) -> str:
    """
    Build the SELECT clause.

    Every selected column is qualified with the first of its candidate
    tables, one column per line.

    :param variables: mapping of column name to its list of candidate tables
    :return: SELECT clause text, terminated by a newline
    """
    qualified_columns = []
    for column_name, candidate_tables in variables.items():
        qualified_columns.append(
            "{}.{}".format(candidate_tables[0], column_name)
        )
    separator = ",\n\t"
    return "SELECT \n\t" + separator.join(qualified_columns) + "\n"
def collect_tables(source: Dict, tables: Dict, result: Dict):
    """
    Collect the used tables of a table hierarchy with their join partners.

    :param source: mapping of table name to table definition at the current
        hierarchy level; a definition may contain a ``"children"`` mapping
        whose entries name, under ``"parent"``, the column that joins them
        to their parent table
    :param tables: mapping of table name to the columns used from it; only
        tables present here are collected
    :param result: output mapping, filled in place: table name -> list of
        ``(column, other_table)`` pairs. Tables of this level are added in
        the order of ``source``, followed by the tables collected from their
        children.
    :return: None
    """
    # column -> tables of this level that use it (dict as an ordered set,
    # so the join partners come out in a deterministic order)
    join_columns = dict()
    for t in source:
        if t not in tables:
            continue
        result[t] = []
        for c in tables[t]:
            join_columns.setdefault(c, {})[t] = None

    # Tables of the same level are joined on every column they share
    for t in result:
        for c in tables[t]:
            if c not in join_columns:
                continue
            if len(join_columns[c]) < 2:
                continue
            for t2 in join_columns[c]:
                if t2 != t:
                    result[t].append((c, t2))

    # Iterate over a snapshot: result.update() below adds the child tables,
    # and a dict must not change size while it is being iterated
    for t in list(result):
        if "children" not in source[t]:
            continue
        child_tables = source[t]["children"]
        children_result = OrderedDict()
        collect_tables(child_tables, tables, children_result)
        for child in children_result:
            # A child table joins its parent table on the column named by
            # the child's "parent" attribute
            parent = child_tables[child]["parent"]
            children_result[child].append((parent, t))
        result.update(children_result)
    return
def generate_from(variables: Dict, aux: Dict, source: Dict) -> str:
    """
    Build the FROM clause joining every table the query needs.

    :param variables: selected column -> list of candidate tables
    :param aux: additional columns (e.g. ones used only in restrictions)
        -> list of candidate tables
    :param source: table hierarchy definition used to determine the joins
    :return: FROM clause text
    """
    columns = dict(variables)
    columns.update(aux)
    # Each column is read from the first of its candidate tables
    tables = {columns[v][0]: set() for v in columns}
    # For every used table, record all the requested columns it provides
    for c in columns:
        for t in columns[c]:
            if t in tables:
                tables[t].add(c)
    from_tables = OrderedDict()
    collect_tables(source, tables, from_tables)

    from_elements = []
    joined = set()
    for t in from_tables:
        # Only tables already placed in the FROM clause may appear in this
        # table's join condition
        conditions = [
            "{t1}.{c} = {t2}.{c}".format(t1=t, t2=t2, c=c)
            for c, t2 in from_tables[t]
            if t2 in joined
        ]
        element = fqn(t)
        # Emit ON only when there is something to join on; previously the
        # first table rendered as "<table> ON " (invalid SQL)
        if conditions:
            element += " ON " + " AND ".join(conditions)
        # NOTE(review): a non-first table without any condition still
        # renders as a bare "JOIN <table>", which PostgreSQL rejects
        joined.add(t)
        from_elements.append(element)
    return "FROM \n\t" + "\n\t JOIN ".join(from_elements)
def generate_where(variables: Dict, tables: Dict, used_tables: Set) -> str:
    """
    Build the WHERE clause from the restrictions of a user request.

    Each restriction is applied to the first table that contains the
    column and is used by the query:

    * a string or a number -> ``table.column = 'value'``
    * a list -> ``table.column IN ('v1', 'v2', ...)``
    * a dict -> ``EXTRACT (field FROM table.column) = value`` per entry

    :param variables: restricted column -> restriction value
    :param tables: table hierarchy definition used to locate the columns
    :param used_tables: names of the tables taking part in the query
    :return: WHERE clause text terminated by a newline, or an empty string
        when there are no conditions
    :raises Exception: if a restricted column is in none of the used tables
    """

    def literal(value) -> str:
        # SQL string literal; embedded single quotes are doubled so values
        # such as "O'Brien" cannot break (or inject into) the statement
        return "'{}'".format(str(value).replace("'", "''"))

    where = []
    for v in variables:
        tt = [t for t in find_tables(v, tables) if t in used_tables]
        if not tt:
            raise Exception("System Error")
        t = tt[0]
        expr = variables[v]
        if isinstance(expr, str):
            where.append("{}.{} = {}".format(t, v, literal(expr)))
        elif isinstance(expr, list):
            values = ", ".join(literal(e) for e in expr)
            where.append("{}.{} IN ({})".format(t, v, values))
        elif isinstance(expr, dict):
            # NOTE(review): field names and values are inserted verbatim;
            # they are expected to be e.g. year/month and integers
            for field in expr:
                e = expr[field]
                where.append(
                    "EXTRACT ({} FROM {}.{}) = {}".format(field, t, v, e)
                )
        elif isinstance(expr, (int, float)):
            # Numeric values used to be dropped silently, which left the
            # query unfiltered
            where.append("{}.{} = {}".format(t, v, literal(expr)))
    if not where:
        # A bare "WHERE" with no conditions is a syntax error
        return ""
    return "WHERE \n\t" + "\n\tAND ".join(where) + "\n"
def all_tables(variables: Dict) -> Set:
    """
    Return the union of all candidate tables of all columns.

    :param variables: mapping of column name to its candidate tables
    :return: set of every table name mentioned
    """
    return {
        table
        for candidate_tables in variables.values()
        for table in candidate_tables
    }
def reduce_tables(variables: Dict) -> bool:
    """
    Remove redundant candidate tables, modifying ``variables`` in place.

    A table is required if it is the only candidate for at least one
    column. Every table that is not required is removed from all candidate
    lists. The process repeats until nothing changes.

    :param variables: mapping of column name to its list of candidate tables
    :return: True if at least one table was removed, False otherwise
        (previously this always returned False)
    """
    reduced_any = False
    while True:
        # Candidate tables in order of first appearance, so the result does
        # not depend on set iteration order
        candidates = list(dict.fromkeys(
            t for tables in variables.values() for t in tables
        ))
        reduced = False
        for t in candidates:
            required = any(
                t in tables and len(tables) == 1
                for tables in variables.values()
            )
            if required:
                continue
            for tables in variables.values():
                if t in tables:
                    tables.remove(t)
            reduced = True
        if not reduced:
            return reduced_any
        reduced_any = True
def generate_order_by(request: dict) -> str:
    """
    Build the ORDER BY clause from ``request["package"]["group"]``.

    :param request: user request
    :return: ``"\\nORDER BY ..."``, or an empty string if no grouping is
        requested
    :raises Exception: if the grouping is neither a string nor a list
    """
    group = None
    package = request.get("package")
    # Tolerate a missing or empty "package" section (e.g. "package:" in YAML)
    if package and "group" in package:
        group = package["group"]
    if not group:
        return ""
    if isinstance(group, str):
        return "\nORDER BY {}".format(group)
    if isinstance(group, list):
        return "\nORDER BY {}".format(", ".join(group))
    # The format string previously had no placeholder, so the offending
    # value was never reported
    raise Exception("Invalid specification for grouping: {}".format(str(group)))
def generate(system, user) -> str:
    """
    Generate the SQL query for a user request.

    :param system: registry: source name -> source definition with a
        ``"tables"`` hierarchy
    :param user: user request with ``"source"``, ``"variables"`` and,
        optionally, ``"restrict"`` and ``"package"`` sections
    :return: SQL text (the previous ``(str, List)`` annotation was wrong;
        only a string has ever been returned)
    """
    source_name = user["source"]
    source = system[source_name]
    source_tables = source["tables"]
    # "restrict" may be absent or empty (e.g. "restrict:" in YAML)
    restrictions = user.get("restrict") or {}

    variables = {v: find_tables(v, source_tables) for v in user["variables"]}
    reduce_tables(variables)
    select = generate_select(variables)
    tables = all_tables(variables)

    # Restricted columns that none of the selected tables provides need
    # extra tables in the FROM clause
    filters = dict()
    for v in restrictions:
        if v in variables:
            continue
        t = find_tables(v, source_tables)
        if set(t) & tables:
            continue
        filters[v] = t
    reduce_tables(filters)
    tables.update(all_tables(filters))

    from_clause = generate_from(variables, filters, source_tables)
    # No restrictions -> no WHERE clause at all (a bare WHERE is invalid)
    where = generate_where(restrictions, source_tables, tables) if restrictions else ""
    order_by = generate_order_by(user)
    return select + "\n" + from_clause + "\n" + where + order_by
if __name__ == '__main__':
    init_logging()
    # The description previously read "Create table and load data", a
    # copy-paste leftover that did not describe this tool
    parser = argparse.ArgumentParser(
        description="Generate SQL for a user request and run the query"
    )
    parser.add_argument("--request", "-r",
                        help="Path to a YAML file containing user request",
                        default=None,
                        required=False)
    parser.add_argument("--db",
                        help="Path to a database connection parameters file",
                        default="database.ini",
                        required=False)
    parser.add_argument("--section",
                        help="Section in the database connection parameters file",
                        default="postgresql",
                        required=False)
    args = parser.parse_args()
    if not args.request:
        # Default to yml/ellen.yaml, located relative to this file
        d = os.path.dirname(__file__)
        args.request = os.path.join(d, "../../../yml/ellen.yaml")
    with Query(args.request, (args.db, args.section)) as q:
        print(q.sql)
        count = 0
        rs = q.execute()
        for r in rs:
            count += 1
            print(r)
        print("Count rows: {:d}".format(count))