X-Git-Url: https://git.decadent.org.uk/gitweb/?a=blobdiff_plain;f=daklib%2Fdbconn.py;h=6d5497fc2d5f4b096d972631637e284eb2ce00aa;hb=852b95bdefa52aead80cad3a51381535774dc48a;hp=04afe7c4f352d473faf4c022a8156b4e6e59f9a4;hpb=1f7d2e0d967c3ae513c31c1ce3a1acd7db694909;p=dak.git diff --git a/daklib/dbconn.py b/daklib/dbconn.py index 04afe7c4..6d5497fc 100755 --- a/daklib/dbconn.py +++ b/daklib/dbconn.py @@ -39,7 +39,7 @@ import traceback from inspect import getargspec -from sqlalchemy import create_engine, Table, MetaData, select +from sqlalchemy import create_engine, Table, MetaData from sqlalchemy.orm import sessionmaker, mapper, relation # Don't remove this, we re-export the exceptions to scripts which import us @@ -59,21 +59,49 @@ __all__ = ['IntegrityError', 'SQLAlchemyError'] ################################################################################ def session_wrapper(fn): + """ + Wrapper around common ".., session=None):" handling. If the wrapped + function is called without passing 'session', we create a local one + and destroy it when the function ends. + + Also attaches a commit_or_flush method to the session; if we created a + local session, this is a synonym for session.commit(), otherwise it is a + synonym for session.flush(). + """ + def wrapped(*args, **kwargs): private_transaction = False + + # Find the session object session = kwargs.get('session') - # No session specified as last argument or in kwargs, create one. - if session is None and len(args) == len(getargspec(fn)[0]) - 1: - private_transaction = True - kwargs['session'] = DBConn().session() + if session is None: + if len(args) <= len(getargspec(fn)[0]) - 1: + # No session specified as last argument or in kwargs + private_transaction = True + session = kwargs['session'] = DBConn().session() + else: + # Session is last argument in args + session = args[-1] + if session is None: + args = list(args) + session = args[-1] = DBConn().session() + private_transaction = True + + if private_transaction: + session.commit_or_flush = session.commit + else: + session.commit_or_flush = session.flush try: return fn(*args, **kwargs) finally: if private_transaction: # We created a session; close it. - kwargs['session'].close() + session.close() + + wrapped.__doc__ = fn.__doc__ + wrapped.func_name = fn.func_name return wrapped @@ -165,7 +193,7 @@ __all__.append('Archive') @session_wrapper def get_archive(archive, session=None): """ - returns database id for given c{archive}. + returns database id for given C{archive}. @type archive: string @param archive: the name of the arhive @@ -223,9 +251,7 @@ def get_suites_binary_in(package, session=None): @return: list of Suite objects for the given package """ - ret = session.query(Suite).join(BinAssociation).join(DBBinary).filter_by(package=package).all() - - return ret + return session.query(Suite).join(BinAssociation).join(DBBinary).filter_by(package=package).all() __all__.append('get_suites_binary_in') @@ -308,9 +334,7 @@ def get_binaries_from_source_id(source_id, session=None): @return: list of DBBinary objects for the given name (may be empty) """ - ret = session.query(DBBinary).filter_by(source_id=source_id).all() - - return ret + return session.query(DBBinary).filter_by(source_id=source_id).all() __all__.append('get_binaries_from_source_id') @@ -330,9 +354,7 @@ def get_binary_from_name_suite(package, suitename, session=None): AND su.suite_name=:suitename ORDER BY b.version DESC""" - ret = session.execute(sql, {'package': package, 'suitename': suitename}) - - return ret + return session.execute(sql, {'package': package, 'suitename': suitename}) __all__.append('get_binary_from_name_suite') @@ -349,9 +371,7 @@ def get_binary_components(package, suitename, arch, session=None): vals = {'package': package, 'suitename': suitename, 'arch': arch} - ret = session.execute(query, vals) - - return ret + return session.execute(query, vals) __all__.append('get_binary_components') @@ -424,6 +444,7 @@ class ContentFilename(object): __all__.append('ContentFilename') +@session_wrapper def get_or_set_contents_file_id(filename, session=None): """ Returns database id for given filename. @@ -440,10 +461,6 @@ def get_or_set_contents_file_id(filename, session=None): @rtype: int @return: the database id for the given component """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True q = session.query(ContentFilename).filter_by(filename=filename) @@ -453,15 +470,9 @@ def get_or_set_contents_file_id(filename, session=None): cf = ContentFilename() cf.filename = filename session.add(cf) - if privatetrans: - session.commit() - else: - session.flush() + session.commit_or_flush() ret = cf.cafilename_id - if privatetrans: - session.close() - return ret __all__.append('get_or_set_contents_file_id') @@ -513,9 +524,7 @@ def get_contents(suite, overridetype, section=None, session=None): contents_q += " ORDER BY fn" - ret = session.execute(contents_q, vals) - - return ret + return session.execute(contents_q, vals) __all__.append('get_contents') @@ -530,6 +539,7 @@ class ContentFilepath(object): __all__.append('ContentFilepath') +@session_wrapper def get_or_set_contents_path_id(filepath, session=None): """ Returns database id for given path. @@ -546,10 +556,6 @@ def get_or_set_contents_path_id(filepath, session=None): @rtype: int @return: the database id for the given path """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True q = session.query(ContentFilepath).filter_by(filepath=filepath) @@ -559,15 +565,9 @@ def get_or_set_contents_path_id(filepath, session=None): cf = ContentFilepath() cf.filepath = filepath session.add(cf) - if privatetrans: - session.commit() - else: - session.flush() + session.commit_or_flush() ret = cf.cafilepath_id - if privatetrans: - session.close() - return ret __all__.append('get_or_set_contents_path_id') @@ -686,9 +686,7 @@ def get_dscfiles(dscfile_id=None, source_id=None, poolfile_id=None, session=None if poolfile_id is not None: q = q.filter_by(poolfile_id=poolfile_id) - ret = q.all() - - return ret + return q.all() __all__.append('get_dscfiles') @@ -795,9 +793,7 @@ def get_poolfile_by_name(filename, location_id=None, session=None): if location_id is not None: q = q.join(Location).filter_by(location_id=location_id) - ret = q.all() - - return ret + return q.all() __all__.append('get_poolfile_by_name') @@ -816,9 +812,7 @@ def get_poolfile_like_name(filename, session=None): # TODO: There must be a way of properly using bind parameters with %FOO% q = session.query(PoolFile).filter(PoolFile.filename.like('%%%s%%' % filename)) - ret = q.all() - - return ret + return q.all() __all__.append('get_poolfile_like_name') @@ -833,6 +827,7 @@ class Fingerprint(object): __all__.append('Fingerprint') +@session_wrapper def get_or_set_fingerprint(fpr, session=None): """ Returns Fingerprint object for given fpr. @@ -851,10 +846,6 @@ def get_or_set_fingerprint(fpr, session=None): @rtype: Fingerprint @return: the Fingerprint object for the given fpr """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True q = session.query(Fingerprint).filter_by(fingerprint=fpr) @@ -864,15 +855,9 @@ def get_or_set_fingerprint(fpr, session=None): fingerprint = Fingerprint() fingerprint.fingerprint = fpr session.add(fingerprint) - if privatetrans: - session.commit() - else: - session.flush() + session.commit_or_flush() ret = fingerprint - if privatetrans: - session.close() - return ret __all__.append('get_or_set_fingerprint') @@ -888,6 +873,7 @@ class Keyring(object): __all__.append('Keyring') +@session_wrapper def get_or_set_keyring(keyring, session=None): """ If C{keyring} does not have an entry in the C{keyrings} table yet, create one @@ -899,28 +885,17 @@ def get_or_set_keyring(keyring, session=None): @rtype: Keyring @return: the Keyring object for this keyring - """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True - - try: - obj = session.query(Keyring).filter_by(keyring_name=keyring).first() - if obj is None: - obj = Keyring(keyring_name=keyring) - session.add(obj) - if privatetrans: - session.commit() - else: - session.flush() + q = session.query(Keyring).filter_by(keyring_name=keyring) + try: + return q.one() + except NoResultFound: + obj = Keyring(keyring_name=keyring) + session.add(obj) + session.commit_or_flush() return obj - finally: - if privatetrans: - session.close() __all__.append('get_or_set_keyring') @@ -986,6 +961,7 @@ class Maintainer(object): __all__.append('Maintainer') +@session_wrapper def get_or_set_maintainer(name, session=None): """ Returns Maintainer object for given maintainer name. @@ -1004,10 +980,6 @@ def get_or_set_maintainer(name, session=None): @rtype: Maintainer @return: the Maintainer object for the given maintainer """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True q = session.query(Maintainer).filter_by(name=name) try: @@ -1016,19 +988,14 @@ def get_or_set_maintainer(name, session=None): maintainer = Maintainer() maintainer.name = name session.add(maintainer) - if privatetrans: - session.commit() - else: - session.flush() + session.commit_or_flush() ret = maintainer - if privatetrans: - session.close() - return ret __all__.append('get_or_set_maintainer') +@session_wrapper def get_maintainer(maintainer_id, session=None): """ Return the name of the maintainer behind C{maintainer_id} or None if that @@ -1041,16 +1008,7 @@ def get_maintainer(maintainer_id, session=None): @return: the Maintainer with this C{maintainer_id} """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True - - try: - return session.query(Maintainer).get(maintainer_id) - finally: - if privatetrans: - session.close() + return session.query(Maintainer).get(maintainer_id) __all__.append('get_maintainer') @@ -1088,9 +1046,7 @@ def has_new_comment(package, version, session=None): q = q.filter_by(package=package) q = q.filter_by(version=version) - ret = q.count() > 0 - - return ret + return bool(q.count() > 0) __all__.append('has_new_comment') @@ -1122,9 +1078,7 @@ def get_new_comments(package=None, version=None, comment_id=None, session=None): if version is not None: q = q.filter_by(version=version) if comment_id is not None: q = q.filter_by(comment_id=comment_id) - ret = q.all() - - return ret + return q.all() __all__.append('get_new_comments') @@ -1182,9 +1136,7 @@ def get_override(package, suite=None, component=None, overridetype=None, session if not isinstance(overridetype, list): overridetype = [overridetype] q = q.join(OverrideType).filter(OverrideType.overridetype.in_(overridetype)) - ret = q.all() - - return ret + return q.all() __all__.append('get_override') @@ -1223,8 +1175,6 @@ def get_override_type(override_type, session=None): except NoResultFound: return None - return ret - __all__.append('get_override_type') ################################################################################ @@ -1450,6 +1400,7 @@ class Queue(object): # TODO: Move into database as above if conf.FindB("Dinstall::SecurityQueueBuild"): # Copy it since the original won't be readable by www-data + import utils utils.copy(src, dest) else: # Create a symlink to it @@ -1463,23 +1414,27 @@ class Queue(object): session.add(qb) - # If the .orig.tar.gz is in the pool, create a symlink to - # it (if one doesn't already exist) - if changes.orig_tar_id: - # Determine the .orig.tar.gz file name - for dsc_file in changes.dsc_files.keys(): - if dsc_file.endswith(".orig.tar.gz"): - filename = dsc_file - - dest = os.path.join(dest_dir, filename) + # If the .orig tarballs are in the pool, create a symlink to + # them (if one doesn't already exist) + for dsc_file in changes.dsc_files.keys(): + # Skip all files except orig tarballs + from daklib.regexes import re_is_orig_source + if not re_is_orig_source.match(dsc_file): + continue + # Skip orig files not identified in the pool + if not (changes.orig_files.has_key(dsc_file) and + changes.orig_files[dsc_file].has_key("id")): + continue + orig_file_id = changes.orig_files[dsc_file]["id"] + dest = os.path.join(dest_dir, dsc_file) # If it doesn't exist, create a symlink if not os.path.exists(dest): q = session.execute("SELECT l.path, f.filename FROM location l, files f WHERE f.id = :id and f.location = l.id", - {'id': changes.orig_tar_id}) + {'id': orig_file_id}) res = q.fetchone() if not res: - return "[INTERNAL ERROR] Couldn't find id %s in files table." % (changes.orig_tar_id) + return "[INTERNAL ERROR] Couldn't find id %s in files table." % (orig_file_id) src = os.path.join(res[0], res[1]) os.symlink(src, dest) @@ -1509,9 +1464,10 @@ class Queue(object): __all__.append('Queue') @session_wrapper -def get_queue(queuename, session=None): +def get_or_set_queue(queuename, session=None): """ - Returns Queue object for given C{queue name}. + Returns Queue object for given C{queue name}, creating it if it does not + exist. @type queuename: string @param queuename: The name of the queue @@ -1527,11 +1483,17 @@ def get_queue(queuename, session=None): q = session.query(Queue).filter_by(queue_name=queuename) try: - return q.one() + ret = q.one() except NoResultFound: - return None + queue = Queue() + queue.queue_name = queuename + session.add(queue) + session.commit_or_flush() + ret = queue -__all__.append('get_queue') + return ret + +__all__.append('get_or_set_queue') ################################################################################ @@ -1737,9 +1699,7 @@ def get_suites_source_in(source, session=None): @return: list of Suite objects for the given source """ - ret = session.query(Suite).join(SrcAssociation).join(DBSource).filter_by(source=source).all() - - return ret + return session.query(Suite).join(SrcAssociation).join(DBSource).filter_by(source=source).all() __all__.append('get_suites_source_in') @@ -1774,9 +1734,7 @@ def get_sources_from_name(source, version=None, dm_upload_allowed=None, session= if dm_upload_allowed is not None: q = q.filter_by(dm_upload_allowed=dm_upload_allowed) - ret = q.all() - - return ret + return q.all() __all__.append('get_sources_from_name') @@ -1823,6 +1781,17 @@ __all__.append('SrcAssociation') ################################################################################ +class SrcFormat(object): + def __init__(self, *args, **kwargs): + pass + + def __repr__(self): + return '' % (self.format_name) + +__all__.append('SrcFormat') + +################################################################################ + class SrcUploader(object): def __init__(self, *args, **kwargs): pass @@ -1928,7 +1897,7 @@ def get_suite(suite, session=None): generated if not supplied) @rtype: Suite - @return: Suite object for the requested suite name (None if not presenT) + @return: Suite object for the requested suite name (None if not present) """ q = session.query(Suite).filter_by(suite_name=suite) @@ -1987,14 +1956,48 @@ def get_suite_architectures(suite, skipsrc=False, skipall=False, session=None): q = q.order_by('arch_string') - ret = q.all() - - return ret + return q.all() __all__.append('get_suite_architectures') ################################################################################ +class SuiteSrcFormat(object): + def __init__(self, *args, **kwargs): + pass + + def __repr__(self): + return '' % (self.suite_id, self.src_format_id) + +__all__.append('SuiteSrcFormat') + +@session_wrapper +def get_suite_src_formats(suite, session=None): + """ + Returns list of allowed SrcFormat for C{suite}. + + @type suite: str + @param suite: Suite name to search for + + @type session: Session + @param session: Optional SQL session object (a temporary one will be + generated if not supplied) + + @rtype: list + @return: the list of allowed source formats for I{suite} + """ + + q = session.query(SrcFormat) + q = q.join(SuiteSrcFormat) + q = q.join(Suite).filter_by(suite_name=suite) + q = q.order_by('format_name') + + return q.all() + +__all__.append('get_suite_src_formats') + +################################################################################ + class Uid(object): def __init__(self, *args, **kwargs): pass @@ -2016,6 +2019,7 @@ class Uid(object): __all__.append('Uid') +@session_wrapper def add_database_user(uidname, session=None): """ Adds a database user @@ -2032,19 +2036,12 @@ def add_database_user(uidname, session=None): @return: the uid object for the given uidname """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True - session.execute("CREATE USER :uid", {'uid': uidname}) - - if privatetrans: - session.commit() - session.close() + session.commit_or_flush() __all__.append('add_database_user') +@session_wrapper def get_or_set_uid(uidname, session=None): """ Returns uid object for given uidname. @@ -2063,11 +2060,6 @@ def get_or_set_uid(uidname, session=None): @return: the uid object for the given uidname """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True - q = session.query(Uid).filter_by(uid=uidname) try: @@ -2076,15 +2068,9 @@ def get_or_set_uid(uidname, session=None): uid = Uid() uid.uid = uidname session.add(uid) - if privatetrans: - session.commit() - else: - session.flush() + session.commit_or_flush() ret = uid - if privatetrans: - session.close() - return ret __all__.append('get_or_set_uid') @@ -2142,9 +2128,11 @@ class DBConn(Singleton): self.tbl_section = Table('section', self.db_meta, autoload=True) self.tbl_source = Table('source', self.db_meta, autoload=True) self.tbl_src_associations = Table('src_associations', self.db_meta, autoload=True) + self.tbl_src_format = Table('src_format', self.db_meta, autoload=True) self.tbl_src_uploaders = Table('src_uploaders', self.db_meta, autoload=True) self.tbl_suite = Table('suite', self.db_meta, autoload=True) self.tbl_suite_architectures = Table('suite_architectures', self.db_meta, autoload=True) + self.tbl_suite_src_formats = Table('suite_src_formats', self.db_meta, autoload=True) self.tbl_uid = Table('uid', self.db_meta, autoload=True) def __setupmappers(self): @@ -2306,6 +2294,10 @@ class DBConn(Singleton): source_id = self.tbl_src_associations.c.source, source = relation(DBSource))) + mapper(SrcFormat, self.tbl_src_format, + properties = dict(src_format_id = self.tbl_src_format.c.id, + format_name = self.tbl_src_format.c.format_name)) + mapper(SrcUploader, self.tbl_src_uploaders, properties = dict(uploader_id = self.tbl_src_uploaders.c.id, source_id = self.tbl_src_uploaders.c.source, @@ -2324,6 +2316,12 @@ class DBConn(Singleton): arch_id = self.tbl_suite_architectures.c.architecture, architecture = relation(Architecture))) + mapper(SuiteSrcFormat, self.tbl_suite_src_formats, + properties = dict(suite_id = self.tbl_suite_src_formats.c.suite, + suite = relation(Suite, backref='suitesrcformats'), + src_format_id = self.tbl_suite_src_formats.c.src_format, + src_format = relation(SrcFormat))) + mapper(Uid, self.tbl_uid, properties = dict(uid_id = self.tbl_uid.c.id, fingerprint = relation(Fingerprint)))