import dataclasses
from datetime import datetime
from typing import List, Optional, Type
from nova_api.dao import GenericDAO, camel_to_snake
from nova_api.entity import Entity
from nova_api.exceptions import NoRowsAffectedException
from nova_api.persistence import PersistenceHelper
from nova_api.persistence.mysql_helper import MySQLHelper
class GenericSQLDAO(GenericDAO):
    """SQL implementation for the GenericDAO interface.

    Queries are built from the string templates exposed by the database \
    helper (``SELECT_QUERY``, ``INSERT_QUERY``, ``UPDATE_QUERY``, \
    ``DELETE_QUERY``, ``CREATE_QUERY``, ``FILTER``/``FILTERS``...) and every \
    value is sent to the database as a query parameter.
    """
    # pylint: disable=R0913
    def __init__(self, database_type: Type[PersistenceHelper] = None,
                 database_instance: PersistenceHelper = None,
                 table: str = None,
                 fields: dict = None,
                 return_class: Type[Entity] = Entity,
                 prefix: str = None, **kwargs) -> None:
        """Creates the DAO and, when needed, the database helper.

        :param database_type: `PersistenceHelper` class instantiated with \
        `kwargs` when `database_instance` is not given. Defaults to \
        `MySQLHelper` when neither this nor `database_instance` is passed.
        :param database_instance: An already built `PersistenceHelper` to \
        use instead of creating a new one.
        :param table: Table name. Defaults to the snake_case name of \
        `return_class` followed by an 's'.
        :param fields: Forwarded to `GenericDAO.__init__`.
        :param return_class: The `Entity` subclass handled by this DAO.
        :param prefix: Forwarded to `GenericDAO.__init__`.
        :param kwargs: Extra arguments passed to `database_type` when a new \
        helper is created.
        """
        super().__init__(fields, return_class, prefix)
        self.database_type = database_type
        self.database = database_instance
        if self.database_type is None and self.database is None:
            self.database_type = MySQLHelper
        self.table = table or camel_to_snake(return_class.__name__) + 's'
        # Log the resolved values (defaulted database type, derived table
        # name) instead of the raw arguments, which are frequently None.
        self.logger.debug("Started %s with database type as %s, table as %s, "
                          "fields as %s, return_class as %s and prefix as %s",
                          self.__class__.__name__,
                          str(self.database_type),
                          str(self.table),
                          str(fields),
                          return_class.__name__,
                          str(prefix))
        if self.database is None:
            self.logger.debug("Database connection starting. Extra args: %s. ",
                              str(kwargs))
            self.database = self.database_type(**kwargs)
            self.logger.debug("Connected to database.")

    def get(self, id_: str) -> Optional[Entity]:
        """Recovers one entity with `id_` from the database.

        The `id_` must be the nova_api generated `id_` which is \
        a 32-char uuid v4.

        :raises InvalidIDTypeException: If the UUID is not a string
        :raises InvalidIDException: If the UUID is not a valid UUID v4 \
        without '-'.
        :param id_: The UUID of the instance to recover
        :return: None if no instance is found or a `return_class` instance \
        if found
        """
        super().get(id_)
        self.logger.debug("Get called with valid id %s", id_)
        _, results = self.get_all(1, 0, {"id_": id_})
        if not results:
            self.logger.info("No entries with id %s found. Returning None",
                             id_)
            return None
        found = results[0]
        self.logger.debug("Found instance with id %s. Result: %s",
                          id_,
                          str(found))
        return found

    def get_all(self, length: int = 20, offset: int = 0,
                filters: dict = None) -> (int, List[Entity]):
        """Recovers all instances that match the given filters up to the
        length specified starting from the offset given.

        The filters should be given as a dictionary, available keys are the \
        `return_class` attributes. The values may be only the desired value \
        or a list with the comparator in the first position and the value in \
        the second.

        Example:

        >>> dao.get_all(length=50, offset=0,
        ...             filters={"birthday":[">", "1/1/1998"],
        ...                      "name":"John"})
        (2, [ent1, ent2])

        :raises ValueError: If a filter key is not a field of the entity, a \
        comparator is not allowed or a list filter has fewer than two items.
        :param length: The number of items to select
        :param offset: The number of items to skip before starting to select
        :param filters: A dict with the filters to use. The key must be a \
        valid attribute in the entity and the value may either be a specific \
        value or a list with two elements: an operator and a value, \
        respectively.
        :return: A tuple with the total number of entities in the table \
        (the count query does not apply the filters) and a list of the \
        matched results.
        """
        self.logger.debug("Getting all with filters %s limit %s and offset %s",
                          str(filters), length, offset)
        if filters:
            filters_, query_params = self._generate_filters(filters)
        else:
            filters_, query_params = '', []
        # Filter values first, then the LIMIT/OFFSET values.
        params = [*query_params, length, offset]

        query = self.database.SELECT_QUERY.format(
            fields=', '.join(self.fields.values()),
            table=self.table,
            filters=filters_
        )
        self.logger.debug("Running query in database %s with params %s",
                          query,
                          str(params))
        self.database.query(query, params)
        results = self.database.get_results()
        if results is None:
            self.logger.info("No results found for query %s, %s in get_all. "
                             "Returning none", query, str(params))
            return 0, []

        return_list = [self.return_class(*result) for result in results]

        query_total = self.database.QUERY_TOTAL_COLUMN.format(
            table=self.table,
            column=self.fields['id_'])
        self.database.query(query_total)
        total = self.database.get_results()[0][0]
        self.logger.debug("Results are %s and the total in the database is %s",
                          str(return_list),
                          total)
        return total, return_list

    def remove(self, entity: Entity = None,
               filters: dict = None) -> int:
        """Removes entities from database.

        May be called either with an instance of return_class or a dict of \
        filters. *If both are passed, the instance will be removed and the \
        filters won't be considered.* Filters are converted with \
        `_generate_filters`, which raises `ValueError` for keys that are not \
        entity fields or for disallowed comparators.

        :raises NotEntityException: If `entity` is not a `return_class` \
        instance and filters are None.
        :raises EntityNotFoundException: If the entity is not found in the \
        database.
        :raises InvalidFiltersException: If filters is not None and is not \
        a dict.
        :raises NoRowsAffectedException: If no rows are affected by the \
        delete query.
        :param entity: `return_class` instance to delete.
        :param filters: Filters to apply to delete query in dict format as \
        specified by `_generate_filters`
        :return: Number of affected rows.
        """
        super().remove(entity, filters)
        filters_ = None
        query_params = None
        if entity is not None:
            # The entity takes precedence: delete exactly that row by id.
            filters_, query_params = self._generate_filters(
                {"id_": entity.id_})
        elif filters is not None:
            filters_, query_params = self._generate_filters(filters)

        query = self.database.DELETE_QUERY.format(
            table=self.table,
            column=self.fields['id_'],
            filters=filters_)
        self.logger.debug("Running remove query in database: %s and params %s",
                          query,
                          query_params)
        row_count, _ = self.database.query(query, query_params)
        if row_count == 0:
            self.logger.error("No rows were affected in database during "
                              "remove!")
            raise NoRowsAffectedException()
        self.logger.info("%s entities removed from database.",
                         row_count)
        return row_count

    def create(self, entity: Entity) -> str:
        """Creates a new row in the database with data from `entity`.

        :raises NotEntityException: Raised if the entity argument
            is not of the return_class of this DAO
        :raises DuplicateEntityException: Raised if an entity with
            the same ID exists in the database already.
        :raises NoRowsAffectedException: If the insert affects no rows.
        :param entity: The instance to save in the database.
        :return: The entity uuid.
        """
        super().create(entity)
        ent_values = entity.get_db_values()
        # One positional placeholder per value returned by get_db_values.
        placeholders = ', '.join(['%s'] * len(ent_values))
        query = self.database.INSERT_QUERY.format(
            table=self.table,
            fields=', '.join(self.fields.values()),
            values=placeholders)
        self.logger.debug("Running query in database: %s and params %s",
                          query,
                          ent_values)
        row_count, _ = self.database.query(query, ent_values)
        if row_count == 0:
            self.logger.error("No rows were affected in database during "
                              "create!")
            raise NoRowsAffectedException()
        self.logger.info("Entity created as %s", entity)
        return entity.id_

    def update(self, entity: Entity) -> str:
        """Updates an entity on the database.

        Sets `entity.last_modified_datetime` to the current local time \
        before writing.

        :raises NotEntityException: If `entity` is not a `return_class` \
        instance.
        :raises EntityNotFoundException: If the entity is not found in the \
        database.
        :raises NoRowsAffectedException: If the update affects no rows.
        :param entity: The entity with updated values to update on \
        the database.
        :return: The `id_` of the updated entity.
        """
        super().update(entity)
        entity.last_modified_datetime = datetime.now()
        # Column values first, then the id used in the WHERE clause.
        # Unpacking accepts any sequence get_db_values returns (list or
        # tuple); list concatenation would fail on a tuple.
        params = [*entity.get_db_values(), entity.id_]
        query = self.database.UPDATE_QUERY.format(
            table=self.table,
            fields=', '.join(
                [field + '=%s' for field in
                 self.fields.values()]),
            column=self.fields['id_']
        )
        self.logger.debug("Running query in database: %s and params %s",
                          query,
                          params)
        row_count, _ = self.database.query(query, params)
        if row_count == 0:
            self.logger.error("No rows were affected in database during "
                              "update!")
            raise NoRowsAffectedException()
        self.logger.info("Entity updated to %s", entity)
        return entity.id_

    def create_table_if_not_exists(self) -> None:
        """Creates the table in the database based on the `return_class` \
        attributes.

        The types used in the database will be inferred through \
        `predict_db_type` or through the field metadata in the "type" key. \
        Fields whose metadata has ``database=False`` are skipped; fields with \
        ``primary_key`` metadata become part of the primary key.

        :return: None
        """
        columns = []
        primary_keys = []
        self.logger.info("Starting create table processing.")
        for field in dataclasses.fields(self.return_class):
            self.logger.debug("Processing field %s", field)
            if field.metadata.get("database") is False:
                self.logger.debug("Field '%s' not included in database table, "
                                  "skipping.", field.name)
                continue

            type_ = field.metadata.get('type') \
                or self.database.predict_db_type(field.type)
            self.logger.debug("'%s' type defined as '%s'", field.name, type_)

            default = field.metadata.get('default') or "NULL"
            field_name = self.fields.get(field.name)
            self.logger.debug("'%s' name defined as '%s'",
                              field.name, field_name)

            if field.metadata.get("primary_key"):
                self.logger.debug("'%s' added as primary key", field_name)
                primary_keys.append(str(field_name))
                # A primary key column cannot default to NULL.
                if default == "NULL":
                    self.logger.warning("Had to change '%s' default because "
                                        "it is primary key and set to NULL.",
                                        field_name)
                    default = "NOT NULL"

            columns.append(self.database.COLUMN.format(field=field_name,
                                                       type=type_,
                                                       default=default))

        query = self.database.CREATE_QUERY.format(
            table=self.table,
            fields=', '.join(columns),
            primary_keys=', '.join(primary_keys))
        self.logger.info("Creating table with query: %s", query)
        self.database.query(query)
        self.logger.info("Table created")

    def _generate_filters(self, filters: dict) -> (str, List[str]):
        """Converts a dict of filters to apply to a query to a SQL query format.

        Example:

        >>> dao._generate_filters(
        ...   filters={"id_": "12345678901234567890123456789012",
        ...            "creation_datetime": [">", "2020-1-1"]})
        ("WHERE id_ = %s AND creation_datetime > %s",
        ["12345678901234567890123456789012", "2020-1-1"])

        :raises ValueError: If filters is None, if a key is not a field of \
        `return_class`, if a comparator is not in the database's \
        ``ALLOWED_COMPARATORS`` or if a list value has fewer than two items.
        :raises TypeError: If filters is not a dict
        :param filters: dictionary of filters to apply. The key must be a \
        property of `return_class` and the value may be only the values, \
        if equality is expected or a list with the comparator and the value.
        :return: a tuple with the where statement and the list of params to \
        use, in the same order as the filters.
        """
        if filters is None:
            raise ValueError("No filters where passed! Filters must be a dict "
                             "with param names, expected values and "
                             "comparators in this form: "
                             "{'param':['comparator', 'value']} "
                             "or {'param': 'value'} for equality.")
        if not isinstance(filters, dict):
            raise TypeError("Filters where passed not as dict!"
                            " Filters must be a dict "
                            "with param names, expected values and "
                            "comparators in this form: "
                            "{'param':['comparator', 'value']} "
                            "or {'param': 'value'} for equality.")

        field_keys = self.fields.keys()
        filters_for_query = []
        query_params = []
        # Single pass: validate each filter, then emit its SQL fragment and
        # its parameter so both lists stay aligned with the dict order.
        for property_, value in filters.items():
            if property_ not in field_keys:
                self.logger.error("Property %s not available in %s for "
                                  "get_all.",
                                  property_,
                                  self.return_class.__name__)
                raise ValueError(
                    f"Property {property_} not available "
                    f"in {self.return_class.__name__}."
                )

            if isinstance(value, list):
                # Guard against IndexError on malformed list filters; any
                # items past the second are ignored, as before.
                if len(value) < 2:
                    self.logger.error("Filter for %s must be a list with a "
                                      "comparator and a value, got %s.",
                                      property_,
                                      value)
                    raise ValueError(
                        f"Filter for {property_} must be a list with a "
                        f"comparator and a value."
                    )
                comparator, param = value[0], value[1]
                if comparator not in self.database.ALLOWED_COMPARATORS:
                    self.logger.error("Comparator %s not available in %s for "
                                      "get_all.",
                                      comparator,
                                      self.return_class.__name__)
                    raise ValueError(
                        f"Comparator {comparator} not allowed "
                        f"for {self.return_class.__name__}"
                    )
            else:
                comparator, param = '=', value

            filters_for_query.append(
                self.database.FILTER.format(column=self.fields[property_],
                                            comparator=comparator))
            query_params.append(param)

        filters_ = self.database.FILTERS.format(
            filters=' AND '.join(filters_for_query))
        return filters_, query_params

    def close(self) -> None:
        """Closes the connection to the database.

        :return: None
        """
        self.logger.debug("Closing connection to database.")
        self.database.close()