import sqlite3
from contextlib import closing
from time import time
from typing import Optional, Iterator

from pycs import db
from pycs.database.Collection import Collection
from pycs.database.File import File
from pycs.database.LabelProvider import LabelProvider
from pycs.database.Model import Model
from pycs.database.Project import Project
from pycs.database.Result import Result
from pycs.database.discovery.LabelProviderDiscovery import discover as discover_label_providers
from pycs.database.discovery.ModelDiscovery import discover as discover_models


class Database:
    """
    thin wrapper giving access to several database objects.

    Lookups (``model``, ``project``, ``file``, ...) issue raw SQL through an
    optional ``sqlite3`` connection; ``commit`` goes through ``db.session``.
    """

    def __init__(self, discovery: bool = True, con: Optional[sqlite3.Connection] = None):
        """
        wrapper for some DB-related functions.

        TODO: remove it!

        :param discovery: run model and label provider discovery on construction
        :param con: optional sqlite3 connection used by the raw-SQL accessors
                    (``get_object_by_id``, ``get_objects``, ``create_project``).
                    It is owned by the caller and is not closed by ``close``.
        """
        # ``self.con`` is read by the raw-SQL accessors; previously it was never
        # assigned, so every such call failed with an opaque AttributeError.
        self.con = con

        if discovery:
            discover_models()
            discover_label_providers()

    def _connection(self) -> sqlite3.Connection:
        """
        return the configured sqlite3 connection

        :raise RuntimeError: if no connection was passed to the constructor
        """
        con = getattr(self, 'con', None)
        if con is None:
            raise RuntimeError('Database has no sqlite3 connection configured; '
                               'pass one via Database(con=...)')
        return con

    def close(self):
        """
        does nothing; kept so existing callers of ``close()`` keep working
        """
        return

    def copy(self):
        """
        Return this database object itself (no copy is made).

        Kept for API compatibility with callers that request a separate
        instance, e.g. before handing it to another thread.
        NOTE(review): if a sqlite3 connection is configured it is shared by
        all "copies"; sqlite3 connections are by default restricted to the
        thread that created them — confirm this is acceptable for callers.

        :return: Database
        """
        return self

    def commit(self):
        """
        commit changes of ``db.session`` and, if configured, of the raw
        sqlite3 connection used by ``create_project``
        """
        db.session.commit()

        # rows inserted through the raw connection are otherwise never committed
        con = getattr(self, 'con', None)
        if con is not None:
            con.commit()

    def get_object_by_id(self, table_name: str, identifier: int, cls):
        """
        create an object from cls and a row fetched from table_name
        and identified by the identifier

        :param table_name: table name (interpolated into the SQL; must be a
                           trusted, internal value)
        :param identifier: unique identifier
        :param cls: class that is used to create the object
        :return: object of type cls, or None if no row matches
        :raise RuntimeError: if no sqlite3 connection is configured
        """
        with closing(self._connection().cursor()) as cursor:
            cursor.execute(f'SELECT * FROM {table_name} WHERE id = ?', [identifier])
            row = cursor.fetchone()

            if row is not None:
                return cls(self, row)

        return None

    def get_objects(self, table_name: str, cls):
        """
        lazily iterate over all objects in the table

        The cursor stays open until the generator is exhausted or closed.

        :param table_name: table name (interpolated into the SQL; must be a
                           trusted, internal value)
        :param cls: class that is used to create the objects
        :return: generator of objects of type cls
        :raise RuntimeError: on first iteration if no sqlite3 connection is configured
        """
        with closing(self._connection().cursor()) as cursor:
            cursor.execute(f'SELECT * FROM {table_name}')

            for row in cursor:
                yield cls(self, row)

    def models(self) -> Iterator[Model]:
        """
        get all available models

        :return: iterator of models
        """
        return self.get_objects("models", Model)

    def model(self, identifier: int) -> Optional[Model]:
        """
        get a model using its unique identifier

        :param identifier: unique identifier
        :return: model, or None if not found
        """
        return self.get_object_by_id("models", identifier, Model)

    def label_providers(self) -> Iterator[LabelProvider]:
        """
        get all available label providers

        :return: iterator over label providers
        """
        return self.get_objects("label_providers", LabelProvider)

    def label_provider(self, identifier: int) -> Optional[LabelProvider]:
        """
        get a label provider using its unique identifier

        :param identifier: unique identifier
        :return: label provider, or None if not found
        """
        return self.get_object_by_id("label_providers", identifier, LabelProvider)

    def projects(self) -> Iterator[Project]:
        """
        get all available projects

        :return: iterator over projects
        """
        return self.get_objects("projects", Project)

    def project(self, identifier: int) -> Optional[Project]:
        """
        get a project using its unique identifier

        :param identifier: unique identifier
        :return: project, or None if not found
        """
        return self.get_object_by_id("projects", identifier, Project)

    def create_project(self,
                       name: str,
                       description: str,
                       model: Model,
                       label_provider: Optional[LabelProvider],
                       root_folder: str,
                       external_data: bool,
                       data_folder: str):
        """
        insert a project into the database

        The insert is not committed here; call ``commit()`` afterwards.

        :param name: project name
        :param description: project description
        :param model: used model
        :param label_provider: used label provider (optional)
        :param root_folder: path to project folder
        :param external_data: whether an external data directory is used
        :param data_folder: path to data folder
        :return: created project
        :raise RuntimeError: if no sqlite3 connection is configured
        """
        # prepare some values
        created = int(time())
        label_provider_id = label_provider.identifier if label_provider is not None else None

        # insert statement
        with closing(self._connection().cursor()) as cursor:
            cursor.execute('''
                INSERT INTO projects (
                    name, description, created, model, label_provider, root_folder,
                    external_data, data_folder
                )
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            ''', (name, description, created, model.identifier, label_provider_id, root_folder,
                  external_data, data_folder))

            return self.project(cursor.lastrowid)

    def collection(self, identifier: int) -> Optional[Collection]:
        """
        get a collection using its unique identifier

        :param identifier: unique identifier
        :return: collection, or None if not found
        """
        return self.get_object_by_id("collections", identifier, Collection)

    def file(self, identifier: int) -> Optional[File]:
        """
        get a file using its unique identifier

        :param identifier: unique identifier
        :return: file, or None if not found
        """
        return self.get_object_by_id("files", identifier, File)

    def result(self, identifier: int) -> Optional[Result]:
        """
        get a result using its unique identifier

        :param identifier: unique identifier
        :return: result, or None if not found
        """
        return self.get_object_by_id("results", identifier, Result)