"""
A utility that executes test cases generated by `nsaph.dbt.create_test.py`
tool.
"""
# Copyright (c) 2023. Harvard University
#
# Developed by Research Software Engineering,
# Harvard University Research Computing and Data (RCD) Services.
#
# Author: Michael A Bouzinier
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
import logging
import os.path
from typing import List
from nsaph import init_logging
from nsaph.db import Connection
from nsaph.dbt.dbt_config import DBTConfig
class TestFailedError(Exception):
    """Raised by ``DBTRunner.test`` when at least one test case failed."""
class DBTRunner:
    """
    Execute SQL test scripts and report the outcome of every test case.

    Each script is run as a single query. Its result set must contain a
    column named ``passed``; every returned row is one test case, and a
    truthy ``passed`` value counts as a success. Totals accumulated across
    all scripts are kept in ``runs``, ``successes`` and ``failures``.
    """

    def __init__(self, context: "DBTConfig" = None):
        """
        :param context: configuration providing ``script`` (list of script
            paths), ``db`` and ``connection``. When omitted, one is built
            with ``DBTConfig(None, __doc__).instantiate()``.
        """
        if not context:
            context = DBTConfig(None, __doc__).instantiate()
        self.context = context
        self.scripts = self.context.script
        # Bare script names (no directory, no extension) label the log.
        self.test_names = [
            os.path.splitext(os.path.basename(t))[0] for t in self.scripts
        ]
        init_logging(name="run-tests-" + "-".join(self.test_names))
        self.reset()

    def reset(self):
        """Zero the run / success / failure counters."""
        self.runs = 0
        self.successes = 0
        self.failures = 0

    def run(self):
        """Open one database connection and execute every configured script."""
        with Connection(self.context.db, self.context.connection) as cnxn:
            for script_file in self.scripts:
                with open(script_file) as script:
                    self.run_script(script, cnxn)

    def run_script(self, script, cnxn):
        """
        Execute one test script, log a per-case table and update counters.

        :param script: iterable of text lines (e.g. an open file); their
            concatenation is the SQL query to execute
        :param cnxn: open database connection exposing ``cursor()``
        :raises ValueError: if the query returns no result set, or the
            result set has no ``passed`` column
        """
        query = "".join(script)
        with cnxn.cursor() as cursor:
            cursor.execute(query)
            if cursor.description is None:
                raise ValueError("Test script did not return a result set")
            columns = [desc[0] for desc in cursor.description]
            if "passed" not in columns:
                raise ValueError(
                    "Test script result must contain a 'passed' column; "
                    f"got columns: {columns}"
                )
            rows = list(cursor)

        pi = columns.index("passed")
        n = len(columns)
        test_cases = []
        for row in rows:
            values = [row[i] for i in range(n)]
            # Replace the raw flag with a readable verdict for the report.
            values[pi] = "passed" if row[pi] else "failed"
            test_cases.append(values)
        passes = sum(1 for values in test_cases if values[pi] == "passed")
        failures = len(test_cases) - passes

        # Column width = widest rendered cell, header included; values may
        # be non-strings (numbers, None, dates), so measure str(value).
        widths = [len(str(c)) for c in columns]
        for values in test_cases:
            for i, value in enumerate(values):
                widths[i] = max(widths[i], len(str(value)))
        widths = [w + 1 for w in widths]

        logging.info(self.report_row(columns, widths))
        for values in test_cases:
            line = self.report_row(values, widths)
            if values[pi] == "passed":
                logging.info(line)
            else:
                logging.error(line)
        logging.info("Passed: {:d}; Failed: {:d}".format(passes, failures))

        self.runs += len(test_cases)
        self.successes += passes
        self.failures += failures

    @classmethod
    def report_row(cls, row: List, lengths: List[int]) -> str:
        """
        Format one report line: each cell left-justified to its width and
        followed by a tab.

        :param row: cell values; only the first ``len(lengths)`` are used
        :param lengths: minimal width of each column
        :return: the formatted line (with a trailing tab)
        """
        return "".join(
            str(cell).ljust(width) + "\t" for cell, width in zip(row, lengths)
        )

    def test(self):
        """
        Reset counters, run all scripts and fail if any test case failed.

        :raises TestFailedError: if at least one test case failed
        """
        self.reset()
        self.run()
        if self.failures > 0:
            err = TestFailedError(f"There are {self.failures} failures")
            # Not inside an except block, so logging.exception is wrong here.
            logging.error("Tests FAILED: %s", err)
            raise err
        logging.info("All tests succeeded")
if __name__ == "__main__":
    # Build the runner from the default configuration and execute all
    # configured test scripts; a failing case raises TestFailedError.
    DBTRunner().test()