X-Git-Url: https://git.decadent.org.uk/gitweb/?a=blobdiff_plain;f=daklib%2Fdbconn.py;h=9b37af700f631b08fb7990dd4434d440cd6d8d54;hb=63a936065dc2979df325eb34a205c3c97e0cd4ce;hp=5aa40210fc5b62684101b5bf995e4b14aec013cc;hpb=d5c3522f3b94dd7a1377bc6cd5848a3a733044fb;p=dak.git diff --git a/daklib/dbconn.py b/daklib/dbconn.py index 5aa40210..9b37af70 100755 --- a/daklib/dbconn.py +++ b/daklib/dbconn.py @@ -34,14 +34,18 @@ ################################################################################ import os +import re import psycopg2 import traceback -from sqlalchemy import create_engine, Table, MetaData, select +from inspect import getargspec + +from sqlalchemy import create_engine, Table, MetaData from sqlalchemy.orm import sessionmaker, mapper, relation # Don't remove this, we re-export the exceptions to scripts which import us from sqlalchemy.exc import * +from sqlalchemy.orm.exc import NoResultFound # Only import Config until Queue stuff is changed to store its config # in the database @@ -55,6 +59,55 @@ __all__ = ['IntegrityError', 'SQLAlchemyError'] ################################################################################ +def session_wrapper(fn): + """ + Wrapper around common ".., session=None):" handling. If the wrapped + function is called without passing 'session', we create a local one + and destroy it when the function ends. + + Also attaches a commit_or_flush method to the session; if we created a + local session, this is a synonym for session.commit(), otherwise it is a + synonym for session.flush(). + """ + + def wrapped(*args, **kwargs): + private_transaction = False + + # Find the session object + session = kwargs.get('session') + + if session is None: + if len(args) <= len(getargspec(fn)[0]) - 1: + # No session specified as last argument or in kwargs + private_transaction = True + session = kwargs['session'] = DBConn().session() + else: + # Session is last argument in args + session = args[-1] + if session is None: + args = list(args) + session = args[-1] = DBConn().session() + private_transaction = True + + if private_transaction: + session.commit_or_flush = session.commit + else: + session.commit_or_flush = session.flush + + try: + return fn(*args, **kwargs) + finally: + if private_transaction: + # We created a session; close it. + session.close() + + wrapped.__doc__ = fn.__doc__ + wrapped.func_name = fn.func_name + + return wrapped + +################################################################################ + class Architecture(object): def __init__(self, *args, **kwargs): pass @@ -76,6 +129,7 @@ class Architecture(object): __all__.append('Architecture') +@session_wrapper def get_architecture(architecture, session=None): """ Returns database id for given C{architecture}. @@ -89,17 +143,18 @@ def get_architecture(architecture, session=None): @rtype: Architecture @return: Architecture object for the given arch (None if not present) - """ - if session is None: - session = DBConn().session() + q = session.query(Architecture).filter_by(arch_string=architecture) - if q.count() == 0: + + try: + return q.one() + except NoResultFound: return None - return q.one() __all__.append('get_architecture') +@session_wrapper def get_architecture_suites(architecture, session=None): """ Returns list of Suite objects for given C{architecture} name @@ -115,13 +170,13 @@ def get_architecture_suites(architecture, session=None): @return: list of Suite objects for the given name (may be empty) """ - if session is None: - session = DBConn().session() - q = session.query(Suite) q = q.join(SuiteArchitecture) q = q.join(Architecture).filter_by(arch_string=architecture).order_by('suite_name') - return q.all() + + ret = q.all() + + return ret __all__.append('get_architecture_suites') @@ -132,13 +187,14 @@ class Archive(object): pass def __repr__(self): - return '' % self.name + return '' % self.archive_name __all__.append('Archive') +@session_wrapper def get_archive(archive, session=None): """ - returns database id for given c{archive}. + returns database id for given C{archive}. @type archive: string @param archive: the name of the arhive @@ -152,12 +208,13 @@ def get_archive(archive, session=None): """ archive = archive.lower() - if session is None: - session = DBConn().session() + q = session.query(Archive).filter_by(archive_name=archive) - if q.count() == 0: + + try: + return q.one() + except NoResultFound: return None - return q.one() __all__.append('get_archive') @@ -183,6 +240,7 @@ class DBBinary(object): __all__.append('DBBinary') +@session_wrapper def get_suites_binary_in(package, session=None): """ Returns list of Suite objects which given C{package} name is in @@ -194,13 +252,11 @@ def get_suites_binary_in(package, session=None): @return: list of Suite objects for the given package """ - if session is None: - session = DBConn().session() - return session.query(Suite).join(BinAssociation).join(DBBinary).filter_by(package=package).all() __all__.append('get_suites_binary_in') +@session_wrapper def get_binary_from_id(id, session=None): """ Returns DBBinary object for given C{id} @@ -215,15 +271,17 @@ def get_binary_from_id(id, session=None): @rtype: DBBinary @return: DBBinary object for the given binary (None if not present) """ - if session is None: - session = DBConn().session() + q = session.query(DBBinary).filter_by(binary_id=id) - if q.count() == 0: + + try: + return q.one() + except NoResultFound: return None - return q.one() __all__.append('get_binary_from_id') +@session_wrapper def get_binaries_from_name(package, version=None, architecture=None, session=None): """ Returns list of DBBinary objects for given C{package} name @@ -244,8 +302,6 @@ def get_binaries_from_name(package, version=None, architecture=None, session=Non @rtype: list @return: list of DBBinary objects for the given name (may be empty) """ - if session is None: - session = DBConn().session() q = session.query(DBBinary).filter_by(package=package) @@ -257,10 +313,13 @@ def get_binaries_from_name(package, version=None, architecture=None, session=Non architecture = [architecture] q = q.join(Architecture).filter(Architecture.arch_string.in_(architecture)) - return q.all() + ret = q.all() + + return ret __all__.append('get_binaries_from_name') +@session_wrapper def get_binaries_from_source_id(source_id, session=None): """ Returns list of DBBinary objects for given C{source_id} @@ -275,18 +334,15 @@ def get_binaries_from_source_id(source_id, session=None): @rtype: list @return: list of DBBinary objects for the given name (may be empty) """ - if session is None: - session = DBConn().session() + return session.query(DBBinary).filter_by(source_id=source_id).all() __all__.append('get_binaries_from_source_id') - +@session_wrapper def get_binary_from_name_suite(package, suitename, session=None): ### For dak examine-package ### XXX: Doesn't use object API yet - if session is None: - session = DBConn().session() sql = """SELECT DISTINCT(b.package), b.version, c.name, su.suite_name FROM binaries b, files fi, location l, component c, bin_associations ba, suite su @@ -303,8 +359,9 @@ def get_binary_from_name_suite(package, suitename, session=None): __all__.append('get_binary_from_name_suite') +@session_wrapper def get_binary_components(package, suitename, arch, session=None): -# Check for packages that have moved from one component to another + # Check for packages that have moved from one component to another query = """SELECT c.name FROM binaries b, bin_associations ba, suite s, location l, component c, architecture a, files f WHERE b.package=:package AND s.suite_name=:suitename AND (a.arch_string = :arch OR a.arch_string = 'all') @@ -315,14 +372,34 @@ def get_binary_components(package, suitename, arch, session=None): vals = {'package': package, 'suitename': suitename, 'arch': arch} - if session is None: - session = DBConn().session() return session.execute(query, vals) __all__.append('get_binary_components') ################################################################################ +class BinaryACL(object): + def __init__(self, *args, **kwargs): + pass + + def __repr__(self): + return '' % self.binary_acl_id + +__all__.append('BinaryACL') + +################################################################################ + +class BinaryACLMap(object): + def __init__(self, *args, **kwargs): + pass + + def __repr__(self): + return '' % self.binary_acl_map_id + +__all__.append('BinaryACLMap') + +################################################################################ + class Component(object): def __init__(self, *args, **kwargs): pass @@ -345,6 +422,7 @@ class Component(object): __all__.append('Component') +@session_wrapper def get_component(component, session=None): """ Returns database id for given C{component}. @@ -357,12 +435,13 @@ def get_component(component, session=None): """ component = component.lower() - if session is None: - session = DBConn().session() + q = session.query(Component).filter_by(component_name=component) - if q.count() == 0: + + try: + return q.one() + except NoResultFound: return None - return q.one() __all__.append('get_component') @@ -388,6 +467,7 @@ class ContentFilename(object): __all__.append('ContentFilename') +@session_wrapper def get_or_set_contents_file_id(filename, session=None): """ Returns database id for given filename. @@ -404,29 +484,23 @@ def get_or_set_contents_file_id(filename, session=None): @rtype: int @return: the database id for the given component """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True + + q = session.query(ContentFilename).filter_by(filename=filename) try: - q = session.query(ContentFilename).filter_by(filename=filename) - if q.count() < 1: - cf = ContentFilename() - cf.filename = filename - session.add(cf) - if privatetrans: - session.commit() - return cf.cafilename_id - else: - return q.one().cafilename_id + ret = q.one().cafilename_id + except NoResultFound: + cf = ContentFilename() + cf.filename = filename + session.add(cf) + session.commit_or_flush() + ret = cf.cafilename_id - except: - traceback.print_exc() - raise + return ret __all__.append('get_or_set_contents_file_id') +@session_wrapper def get_contents(suite, overridetype, section=None, session=None): """ Returns contents for a suite / overridetype combination, limiting @@ -450,9 +524,6 @@ def get_contents(suite, overridetype, section=None, session=None): package, arch_id) """ - if session is None: - session = DBConn().session() - # find me all of the contents for a given suite contents_q = """SELECT (p.path||'/'||n.file) AS fn, s.section, @@ -491,7 +562,8 @@ class ContentFilepath(object): __all__.append('ContentFilepath') -def get_or_set_contents_path_id(filepath, session): +@session_wrapper +def get_or_set_contents_path_id(filepath, session=None): """ Returns database id for given path. @@ -507,26 +579,19 @@ def get_or_set_contents_path_id(filepath, session): @rtype: int @return: the database id for the given path """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True + + q = session.query(ContentFilepath).filter_by(filepath=filepath) try: - q = session.query(ContentFilepath).filter_by(filepath=filepath) - if q.count() < 1: - cf = ContentFilepath() - cf.filepath = filepath - session.add(cf) - if privatetrans: - session.commit() - return cf.cafilepath_id - else: - return q.one().cafilepath_id + ret = q.one().cafilepath_id + except NoResultFound: + cf = ContentFilepath() + cf.filepath = filepath + session.add(cf) + session.commit_or_flush() + ret = cf.cafilepath_id - except: - traceback.print_exc() - raise + return ret __all__.append('get_or_set_contents_path_id') @@ -560,33 +625,45 @@ def insert_content_paths(binary_id, fullpaths, session=None): """ privatetrans = False - if session is None: session = DBConn().session() privatetrans = True try: + # Insert paths + pathcache = {} for fullpath in fullpaths: + # Get the necessary IDs ... (path, file) = os.path.split(fullpath) - # Get the necessary IDs ... + filepath_id = get_or_set_contents_path_id(path, session) + filename_id = get_or_set_contents_file_id(file, session) + + pathcache[fullpath] = (filepath_id, filename_id) + + for fullpath, dat in pathcache.items(): ca = ContentAssociation() ca.binary_id = binary_id - ca.filename_id = get_or_set_contents_file_id(file) - ca.filepath_id = get_or_set_contents_path_id(path) + ca.filepath_id = dat[0] + ca.filename_id = dat[1] session.add(ca) # Only commit if we set up the session ourself if privatetrans: session.commit() + session.close() + else: + session.flush() return True + except: traceback.print_exc() # Only rollback if we set up the session ourself if privatetrans: session.rollback() + session.close() return False @@ -603,6 +680,39 @@ class DSCFile(object): __all__.append('DSCFile') +@session_wrapper +def get_dscfiles(dscfile_id=None, source_id=None, poolfile_id=None, session=None): + """ + Returns a list of DSCFiles which may be empty + + @type dscfile_id: int (optional) + @param dscfile_id: the dscfile_id of the DSCFiles to find + + @type source_id: int (optional) + @param source_id: the source id related to the DSCFiles to find + + @type poolfile_id: int (optional) + @param poolfile_id: the poolfile id related to the DSCFiles to find + + @rtype: list + @return: Possibly empty list of DSCFiles + """ + + q = session.query(DSCFile) + + if dscfile_id is not None: + q = q.filter_by(dscfile_id=dscfile_id) + + if source_id is not None: + q = q.filter_by(source_id=source_id) + + if poolfile_id is not None: + q = q.filter_by(poolfile_id=poolfile_id) + + return q.all() + +__all__.append('get_dscfiles') + ################################################################################ class PoolFile(object): @@ -614,6 +724,7 @@ class PoolFile(object): __all__.append('PoolFile') +@session_wrapper def check_poolfile(filename, filesize, md5sum, location_id, session=None): """ Returns a tuple: @@ -641,26 +752,50 @@ def check_poolfile(filename, filesize, md5sum, location_id, session=None): (False, PoolFile object) if file found with size/md5sum mismatch """ - if session is None: - session = DBConn().session() - q = session.query(PoolFile).filter_by(filename=filename) q = q.join(Location).filter_by(location_id=location_id) + ret = None + if q.count() > 1: - return (None, None) - if q.count() < 1: - return (False, None) + ret = (None, None) + elif q.count() < 1: + ret = (False, None) + else: + obj = q.one() + if obj.md5sum != md5sum or obj.filesize != filesize: + ret = (False, obj) - obj = q.one() - if obj.md5sum != md5sum or obj.filesize != filesize: - return (False, obj) + if ret is None: + ret = (True, obj) - return (True, obj) + return ret __all__.append('check_poolfile') +@session_wrapper +def get_poolfile_by_id(file_id, session=None): + """ + Returns a PoolFile objects or None for the given id + + @type file_id: int + @param file_id: the id of the file to look for + + @rtype: PoolFile or None + @return: either the PoolFile object or None + """ + + q = session.query(PoolFile).filter_by(file_id=file_id) + try: + return q.one() + except NoResultFound: + return None + +__all__.append('get_poolfile_by_id') + + +@session_wrapper def get_poolfile_by_name(filename, location_id=None, session=None): """ Returns an array of PoolFile objects for the given filename and @@ -676,9 +811,6 @@ def get_poolfile_by_name(filename, location_id=None, session=None): @return: array of PoolFile objects """ - if session is None: - session = DBConn().session() - q = session.query(PoolFile).filter_by(filename=filename) if location_id is not None: @@ -688,6 +820,7 @@ def get_poolfile_by_name(filename, location_id=None, session=None): __all__.append('get_poolfile_by_name') +@session_wrapper def get_poolfile_like_name(filename, session=None): """ Returns an array of PoolFile objects which are like the given name @@ -699,9 +832,6 @@ def get_poolfile_like_name(filename, session=None): @return: array of PoolFile objects """ - if session is None: - session = DBConn().session() - # TODO: There must be a way of properly using bind parameters with %FOO% q = session.query(PoolFile).filter(PoolFile.filename.like('%%%s%%' % filename)) @@ -720,6 +850,34 @@ class Fingerprint(object): __all__.append('Fingerprint') +@session_wrapper +def get_fingerprint(fpr, session=None): + """ + Returns Fingerprint object for given fpr. + + @type fpr: string + @param fpr: The fpr to find / add + + @type session: SQLAlchemy + @param session: Optional SQL session object (a temporary one will be + generated if not supplied). + + @rtype: Fingerprint + @return: the Fingerprint object for the given fpr or None + """ + + q = session.query(Fingerprint).filter_by(fingerprint=fpr) + + try: + ret = q.one() + except NoResultFound: + ret = None + + return ret + +__all__.append('get_fingerprint') + +@session_wrapper def get_or_set_fingerprint(fpr, session=None): """ Returns Fingerprint object for given fpr. @@ -738,42 +896,186 @@ def get_or_set_fingerprint(fpr, session=None): @rtype: Fingerprint @return: the Fingerprint object for the given fpr """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True + + q = session.query(Fingerprint).filter_by(fingerprint=fpr) try: - q = session.query(Fingerprint).filter_by(fingerprint=fpr) - if q.count() < 1: - fingerprint = Fingerprint() - fingerprint.fingerprint = fpr - session.add(fingerprint) - if privatetrans: - session.commit() - else: - session.flush() - return fingerprint - else: - return q.one() + ret = q.one() + except NoResultFound: + fingerprint = Fingerprint() + fingerprint.fingerprint = fpr + session.add(fingerprint) + session.commit_or_flush() + ret = fingerprint - except: - traceback.print_exc() - raise + return ret __all__.append('get_or_set_fingerprint') ################################################################################ +# Helper routine for Keyring class +def get_ldap_name(entry): + name = [] + for k in ["cn", "mn", "sn"]: + ret = entry.get(k) + if ret and ret[0] != "" and ret[0] != "-": + name.append(ret[0]) + return " ".join(name) + +################################################################################ + class Keyring(object): + gpg_invocation = "gpg --no-default-keyring --keyring %s" +\ + " --with-colons --fingerprint --fingerprint" + + keys = {} + fpr_lookup = {} + def __init__(self, *args, **kwargs): pass def __repr__(self): return '' % self.keyring_name + def de_escape_gpg_str(self, str): + esclist = re.split(r'(\\x..)', str) + for x in range(1,len(esclist),2): + esclist[x] = "%c" % (int(esclist[x][2:],16)) + return "".join(esclist) + + def load_keys(self, keyring): + import email.Utils + + if not self.keyring_id: + raise Exception('Must be initialized with database information') + + k = os.popen(self.gpg_invocation % keyring, "r") + key = None + signingkey = False + + for line in k.xreadlines(): + field = line.split(":") + if field[0] == "pub": + key = field[4] + (name, addr) = email.Utils.parseaddr(field[9]) + name = re.sub(r"\s*[(].*[)]", "", name) + if name == "" or addr == "" or "@" not in addr: + name = field[9] + addr = "invalid-uid" + name = self.de_escape_gpg_str(name) + self.keys[key] = {"email": addr} + if name != "": + self.keys[key]["name"] = name + self.keys[key]["aliases"] = [name] + self.keys[key]["fingerprints"] = [] + signingkey = True + elif key and field[0] == "sub" and len(field) >= 12: + signingkey = ("s" in field[11]) + elif key and field[0] == "uid": + (name, addr) = email.Utils.parseaddr(field[9]) + if name and name not in self.keys[key]["aliases"]: + self.keys[key]["aliases"].append(name) + elif signingkey and field[0] == "fpr": + self.keys[key]["fingerprints"].append(field[9]) + self.fpr_lookup[field[9]] = key + + def import_users_from_ldap(self, session): + import ldap + cnf = Config() + + LDAPDn = cnf["Import-LDAP-Fingerprints::LDAPDn"] + LDAPServer = cnf["Import-LDAP-Fingerprints::LDAPServer"] + + l = ldap.open(LDAPServer) + l.simple_bind_s("","") + Attrs = l.search_s(LDAPDn, ldap.SCOPE_ONELEVEL, + "(&(keyfingerprint=*)(gidnumber=%s))" % (cnf["Import-Users-From-Passwd::ValidGID"]), + ["uid", "keyfingerprint", "cn", "mn", "sn"]) + + ldap_fin_uid_id = {} + + byuid = {} + byname = {} + + for i in Attrs: + entry = i[1] + uid = entry["uid"][0] + name = get_ldap_name(entry) + fingerprints = entry["keyFingerPrint"] + keyid = None + for f in fingerprints: + key = self.fpr_lookup.get(f, None) + if key not in self.keys: + continue + self.keys[key]["uid"] = uid + + if keyid != None: + continue + keyid = get_or_set_uid(uid, session).uid_id + byuid[keyid] = (uid, name) + byname[uid] = (keyid, name) + + return (byname, byuid) + + def generate_users_from_keyring(self, format, session): + byuid = {} + byname = {} + any_invalid = False + for x in self.keys.keys(): + if self.keys[x]["email"] == "invalid-uid": + any_invalid = True + self.keys[x]["uid"] = format % "invalid-uid" + else: + uid = format % self.keys[x]["email"] + keyid = get_or_set_uid(uid, session).uid_id + byuid[keyid] = (uid, self.keys[x]["name"]) + byname[uid] = (keyid, self.keys[x]["name"]) + self.keys[x]["uid"] = uid + + if any_invalid: + uid = format % "invalid-uid" + keyid = get_or_set_uid(uid, session).uid_id + byuid[keyid] = (uid, "ungeneratable user id") + byname[uid] = (keyid, "ungeneratable user id") + + return (byname, byuid) + __all__.append('Keyring') +@session_wrapper +def get_keyring(keyring, session=None): + """ + If C{keyring} does not have an entry in the C{keyrings} table yet, return None + If C{keyring} already has an entry, simply return the existing Keyring + + @type keyring: string + @param keyring: the keyring name + + @rtype: Keyring + @return: the Keyring object for this keyring + """ + + q = session.query(Keyring).filter_by(keyring_name=keyring) + + try: + return q.one() + except NoResultFound: + return None + +__all__.append('get_keyring') + +################################################################################ + +class KeyringACLMap(object): + def __init__(self, *args, **kwargs): + pass + + def __repr__(self): + return '' % self.keyring_acl_map_id + +__all__.append('KeyringACLMap') + ################################################################################ class Location(object): @@ -785,6 +1087,7 @@ class Location(object): __all__.append('Location') +@session_wrapper def get_location(location, component=None, archive=None, session=None): """ Returns Location object for the given combination of location, component @@ -803,9 +1106,6 @@ def get_location(location, component=None, archive=None, session=None): @return: Either a Location object or None if one can't be found """ - if session is None: - session = DBConn().session() - q = session.query(Location).filter_by(path=location) if archive is not None: @@ -814,10 +1114,10 @@ def get_location(location, component=None, archive=None, session=None): if component is not None: q = q.join(Component).filter_by(component_name=component) - if q.count() < 1: - return None - else: + try: return q.one() + except NoResultFound: + return None __all__.append('get_location') @@ -838,6 +1138,7 @@ class Maintainer(object): __all__.append('Maintainer') +@session_wrapper def get_or_set_maintainer(name, session=None): """ Returns Maintainer object for given maintainer name. @@ -856,31 +1157,38 @@ def get_or_set_maintainer(name, session=None): @rtype: Maintainer @return: the Maintainer object for the given maintainer """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True + q = session.query(Maintainer).filter_by(name=name) try: - q = session.query(Maintainer).filter_by(name=name) - if q.count() < 1: - maintainer = Maintainer() - maintainer.name = name - session.add(maintainer) - if privatetrans: - session.commit() - else: - session.flush() - return maintainer - else: - return q.one() + ret = q.one() + except NoResultFound: + maintainer = Maintainer() + maintainer.name = name + session.add(maintainer) + session.commit_or_flush() + ret = maintainer - except: - traceback.print_exc() - raise + return ret __all__.append('get_or_set_maintainer') +@session_wrapper +def get_maintainer(maintainer_id, session=None): + """ + Return the name of the maintainer behind C{maintainer_id} or None if that + maintainer_id is invalid. + + @type maintainer_id: int + @param maintainer_id: the id of the maintainer + + @rtype: Maintainer + @return: the Maintainer with this C{maintainer_id} + """ + + return session.query(Maintainer).get(maintainer_id) + +__all__.append('get_maintainer') + ################################################################################ class NewComment(object): @@ -892,6 +1200,7 @@ class NewComment(object): __all__.append('NewComment') +@session_wrapper def has_new_comment(package, version, session=None): """ Returns true if the given combination of C{package}, C{version} has a comment. @@ -910,16 +1219,15 @@ def has_new_comment(package, version, session=None): @return: true/false """ - if session is None: - session = DBConn().session() - q = session.query(NewComment) q = q.filter_by(package=package) q = q.filter_by(version=version) - return q.count() > 0 + + return bool(q.count() > 0) __all__.append('has_new_comment') +@session_wrapper def get_new_comments(package=None, version=None, comment_id=None, session=None): """ Returns (possibly empty) list of NewComment objects for the given @@ -940,12 +1248,8 @@ def get_new_comments(package=None, version=None, comment_id=None, session=None): @rtype: list @return: A (possibly empty) list of NewComment objects will be returned - """ - if session is None: - session = DBConn().session() - q = session.query(NewComment) if package is not None: q = q.filter_by(package=package) if version is not None: q = q.filter_by(version=version) @@ -966,6 +1270,7 @@ class Override(object): __all__.append('Override') +@session_wrapper def get_override(package, suite=None, component=None, overridetype=None, session=None): """ Returns Override object for the given parameters @@ -991,10 +1296,7 @@ def get_override(package, suite=None, component=None, overridetype=None, session @rtype: list @return: A (possibly empty) list of Override objects will be returned - """ - if session is None: - session = DBConn().session() q = session.query(Override) q = q.filter_by(package=package) @@ -1027,6 +1329,7 @@ class OverrideType(object): __all__.append('OverrideType') +@session_wrapper def get_override_type(override_type, session=None): """ Returns OverrideType object for given C{override type}. @@ -1040,14 +1343,14 @@ def get_override_type(override_type, session=None): @rtype: int @return: the database id for the given override type - """ - if session is None: - session = DBConn().session() + q = session.query(OverrideType).filter_by(overridetype=override_type) - if q.count() == 0: + + try: + return q.one() + except NoResultFound: return None - return q.one() __all__.append('get_override_type') @@ -1098,31 +1401,42 @@ def insert_pending_content_paths(package, fullpaths, session=None): q.delete() # Insert paths + pathcache = {} for fullpath in fullpaths: (path, file) = os.path.split(fullpath) if path.startswith( "./" ): path = path[2:] + filepath_id = get_or_set_contents_path_id(path, session) + filename_id = get_or_set_contents_file_id(file, session) + + pathcache[fullpath] = (filepath_id, filename_id) + + for fullpath, dat in pathcache.items(): pca = PendingContentAssociation() pca.package = package['Package'] pca.version = package['Version'] - pca.filename_id = get_or_set_contents_file_id(file, session) - pca.filepath_id = get_or_set_contents_path_id(path, session) + pca.filepath_id = dat[0] + pca.filename_id = dat[1] pca.architecture = arch_id session.add(pca) # Only commit if we set up the session ourself if privatetrans: session.commit() + session.close() + else: + session.flush() return True - except: + except Exception, e: traceback.print_exc() # Only rollback if we set up the session ourself if privatetrans: session.rollback() + session.close() return False @@ -1151,6 +1465,7 @@ class Priority(object): __all__.append('Priority') +@session_wrapper def get_priority(priority, session=None): """ Returns Priority object for given C{priority name}. @@ -1164,17 +1479,39 @@ def get_priority(priority, session=None): @rtype: Priority @return: Priority object for the given priority - """ - if session is None: - session = DBConn().session() + q = session.query(Priority).filter_by(priority=priority) - if q.count() == 0: + + try: + return q.one() + except NoResultFound: return None - return q.one() __all__.append('get_priority') +@session_wrapper +def get_priorities(session=None): + """ + Returns dictionary of priority names -> id mappings + + @type session: Session + @param session: Optional SQL session object (a temporary one will be + generated if not supplied) + + @rtype: dictionary + @return: dictionary of priority names -> id mappings + """ + + ret = {} + q = session.query(Priority) + for x in q.all(): + ret[x.priority] = x.priority_id + + return ret + +__all__.append('get_priorities') + ################################################################################ class Queue(object): @@ -1205,10 +1542,10 @@ class Queue(object): @return: None if the operation failed, a string describing the error if not """ - localcommit = False + privatetrans = False if session is None: session = DBConn().session() - localcommit = True + privatetrans = True # TODO: Remove by moving queue config into the database conf = Config() @@ -1238,8 +1575,9 @@ class Queue(object): dest = os.path.join(dest_dir, file_entry) # TODO: Move into database as above - if Cnf.FindB("Dinstall::SecurityQueueBuild"): + if conf.FindB("Dinstall::SecurityQueueBuild"): # Copy it since the original won't be readable by www-data + import utils utils.copy(src, dest) else: # Create a symlink to it @@ -1253,23 +1591,27 @@ class Queue(object): session.add(qb) - # If the .orig.tar.gz is in the pool, create a symlink to - # it (if one doesn't already exist) - if changes.orig_tar_id: - # Determine the .orig.tar.gz file name - for dsc_file in changes.dsc_files.keys(): - if dsc_file.endswith(".orig.tar.gz"): - filename = dsc_file - - dest = os.path.join(dest_dir, filename) + # If the .orig tarballs are in the pool, create a symlink to + # them (if one doesn't already exist) + for dsc_file in changes.dsc_files.keys(): + # Skip all files except orig tarballs + from daklib.regexes import re_is_orig_source + if not re_is_orig_source.match(dsc_file): + continue + # Skip orig files not identified in the pool + if not (changes.orig_files.has_key(dsc_file) and + changes.orig_files[dsc_file].has_key("id")): + continue + orig_file_id = changes.orig_files[dsc_file]["id"] + dest = os.path.join(dest_dir, dsc_file) # If it doesn't exist, create a symlink if not os.path.exists(dest): q = session.execute("SELECT l.path, f.filename FROM location l, files f WHERE f.id = :id and f.location = l.id", - {'id': changes.orig_tar_id}) + {'id': orig_file_id}) res = q.fetchone() if not res: - return "[INTERNAL ERROR] Couldn't find id %s in files table." % (changes.orig_tar_id) + return "[INTERNAL ERROR] Couldn't find id %s in files table." % (orig_file_id) src = os.path.join(res[0], res[1]) os.symlink(src, dest) @@ -1284,22 +1626,25 @@ class Queue(object): # If it does, update things to ensure it's not removed prematurely else: - qb = get_queue_build(dest, suite_id, session) + qb = get_queue_build(dest, s.suite_id, session) if qb is None: qb.in_queue = True qb.last_used = None session.add(qb) - if localcommit: + if privatetrans: session.commit() + session.close() return None __all__.append('Queue') -def get_queue(queuename, session=None): +@session_wrapper +def get_or_set_queue(queuename, session=None): """ - Returns Queue object for given C{queue name}. + Returns Queue object for given C{queue name}, creating it if it does not + exist. @type queuename: string @param queuename: The name of the queue @@ -1310,16 +1655,22 @@ def get_queue(queuename, session=None): @rtype: Queue @return: Queue object for the given queue - """ - if session is None: - session = DBConn().session() + q = session.query(Queue).filter_by(queue_name=queuename) - if q.count() == 0: - return None - return q.one() -__all__.append('get_queue') + try: + ret = q.one() + except NoResultFound: + queue = Queue() + queue.queue_name = queuename + session.add(queue) + session.commit_or_flush() + ret = queue + + return ret + +__all__.append('get_or_set_queue') ################################################################################ @@ -1332,15 +1683,16 @@ class QueueBuild(object): __all__.append('QueueBuild') -def get_queue_build(filename, suite_id, session=None): +@session_wrapper +def get_queue_build(filename, suite, session=None): """ - Returns QueueBuild object for given C{filename} and C{suite id}. + Returns QueueBuild object for given C{filename} and C{suite}. @type filename: string @param filename: The name of the file - @type suiteid: int - @param suiteid: Suite ID + @type suiteid: int or str + @param suiteid: Suite name or ID @type session: Session @param session: Optional SQLA session object (a temporary one will be @@ -1348,14 +1700,18 @@ def get_queue_build(filename, suite_id, session=None): @rtype: Queue @return: Queue object for the given queue - """ - if session is None: - session = DBConn().session() - q = session.query(QueueBuild).filter_by(filename=filename).filter_by(suite_id=suite_id) - if q.count() == 0: + + if isinstance(suite, int): + q = session.query(QueueBuild).filter_by(filename=filename).filter_by(suite_id=suite) + else: + q = session.query(QueueBuild).filter_by(filename=filename) + q = q.join(Suite).filter_by(suite_name=suite) + + try: + return q.one() + except NoResultFound: return None - return q.one() __all__.append('get_queue_build') @@ -1382,6 +1738,7 @@ class Section(object): __all__.append('Section') +@session_wrapper def get_section(section, session=None): """ Returns Section object for given C{section name}. @@ -1395,17 +1752,39 @@ def get_section(section, session=None): @rtype: Section @return: Section object for the given section name - """ - if session is None: - session = DBConn().session() + q = session.query(Section).filter_by(section=section) - if q.count() == 0: + + try: + return q.one() + except NoResultFound: return None - return q.one() __all__.append('get_section') +@session_wrapper +def get_sections(session=None): + """ + Returns dictionary of section names -> id mappings + + @type session: Session + @param session: Optional SQL session object (a temporary one will be + generated if not supplied) + + @rtype: dictionary + @return: dictionary of section names -> id mappings + """ + + ret = {} + q = session.query(Section) + for x in q.all(): + ret[x.section] = x.section_id + + return ret + +__all__.append('get_sections') + ################################################################################ class DBSource(object): @@ -1417,6 +1796,7 @@ class DBSource(object): __all__.append('DBSource') +@session_wrapper def source_exists(source, source_version, suites = ["any"], session=None): """ Ensure that source exists somewhere in the archive for the binary @@ -1442,10 +1822,8 @@ def source_exists(source, source_version, suites = ["any"], session=None): """ - if session is None: - session = DBConn().session() - cnf = Config() + ret = 1 for suite in suites: q = session.query(DBSource).filter_by(source=source) @@ -1480,13 +1858,13 @@ def source_exists(source, source_version, suites = ["any"], session=None): continue # No source found so return not ok - return 0 + ret = 0 - # We're good - return 1 + return ret __all__.append('source_exists') +@session_wrapper def get_suites_source_in(source, session=None): """ Returns list of Suite objects which given C{source} name is in @@ -1498,13 +1876,11 @@ def get_suites_source_in(source, session=None): @return: list of Suite objects for the given source """ - if session is None: - session = DBConn().session() - return session.query(Suite).join(SrcAssociation).join(DBSource).filter_by(source=source).all() __all__.append('get_suites_source_in') +@session_wrapper def get_sources_from_name(source, version=None, dm_upload_allowed=None, session=None): """ Returns list of DBSource objects for given C{source} name and other parameters @@ -1526,8 +1902,6 @@ def get_sources_from_name(source, version=None, dm_upload_allowed=None, session= @rtype: list @return: list of DBSource objects for the given name (may be empty) """ - if session is None: - session = DBConn().session() q = session.query(DBSource).filter_by(source=source) @@ -1541,6 +1915,7 @@ def get_sources_from_name(source, version=None, dm_upload_allowed=None, session= __all__.append('get_sources_from_name') +@session_wrapper def get_source_in_suite(source, suite, session=None): """ Returns list of DBSource objects for a combination of C{source} and C{suite}. @@ -1558,20 +1933,31 @@ def get_source_in_suite(source, suite, session=None): @return: the version for I{source} in I{suite} """ - if session is None: - session = DBConn().session() + q = session.query(SrcAssociation) q = q.join('source').filter_by(source=source) q = q.join('suite').filter_by(suite_name=suite) - if q.count() == 0: + + try: + return q.one().source + except NoResultFound: return None - # ???: Maybe we should just return the SrcAssociation object instead - return q.one().source __all__.append('get_source_in_suite') ################################################################################ +class SourceACL(object): + def __init__(self, *args, **kwargs): + pass + + def __repr__(self): + return '' % self.source_acl_id + +__all__.append('SourceACL') + +################################################################################ + class SrcAssociation(object): def __init__(self, *args, **kwargs): pass @@ -1583,6 +1969,17 @@ __all__.append('SrcAssociation') ################################################################################ +class SrcFormat(object): + def __init__(self, *args, **kwargs): + pass + + def __repr__(self): + return '' % (self.format_name) + +__all__.append('SrcFormat') + +################################################################################ + class SrcUploader(object): def __init__(self, *args, **kwargs): pass @@ -1644,6 +2041,7 @@ class Suite(object): __all__.append('Suite') +@session_wrapper def get_suite_architecture(suite, architecture, session=None): """ Returns a SuiteArchitecture object given C{suite} and ${arch} or None if it @@ -1663,18 +2061,18 @@ def get_suite_architecture(suite, architecture, session=None): @return: the SuiteArchitecture object or None """ - if session is None: - session = DBConn().session() - q = session.query(SuiteArchitecture) q = q.join(Architecture).filter_by(arch_string=architecture) q = q.join(Suite).filter_by(suite_name=suite) - if q.count() == 0: + + try: + return q.one() + except NoResultFound: return None - return q.one() __all__.append('get_suite_architecture') +@session_wrapper def get_suite(suite, session=None): """ Returns Suite object for given C{suite name}. @@ -1687,15 +2085,15 @@ def get_suite(suite, session=None): generated if not supplied) @rtype: Suite - @return: Suite object for the requested suite name (None if not presenT) - + @return: Suite object for the requested suite name (None if not present) """ - if session is None: - session = DBConn().session() + q = session.query(Suite).filter_by(suite_name=suite) - if q.count() == 0: + + try: + return q.one() + except NoResultFound: return None - return q.one() __all__.append('get_suite') @@ -1710,6 +2108,7 @@ class SuiteArchitecture(object): __all__.append('SuiteArchitecture') +@session_wrapper def get_suite_architectures(suite, skipsrc=False, skipall=False, session=None): """ Returns list of Architecture objects for given C{suite} name @@ -1733,23 +2132,60 @@ def get_suite_architectures(suite, skipsrc=False, skipall=False, session=None): @return: list of Architecture objects for the given name (may be empty) """ - if session is None: - session = DBConn().session() - q = session.query(Architecture) q = q.join(SuiteArchitecture) q = q.join(Suite).filter_by(suite_name=suite) + if skipsrc: q = q.filter(Architecture.arch_string != 'source') + if skipall: q = q.filter(Architecture.arch_string != 'all') + q = q.order_by('arch_string') + return q.all() __all__.append('get_suite_architectures') ################################################################################ +class SuiteSrcFormat(object): + def __init__(self, *args, **kwargs): + pass + + def __repr__(self): + return '' % (self.suite_id, self.src_format_id) + +__all__.append('SuiteSrcFormat') + +@session_wrapper +def get_suite_src_formats(suite, session=None): + """ + Returns list of allowed SrcFormat for C{suite}. + + @type suite: str + @param suite: Suite name to search for + + @type session: Session + @param session: Optional SQL session object (a temporary one will be + generated if not supplied) + + @rtype: list + @return: the list of allowed source formats for I{suite} + """ + + q = session.query(SrcFormat) + q = q.join(SuiteSrcFormat) + q = q.join(Suite).filter_by(suite_name=suite) + q = q.order_by('format_name') + + return q.all() + +__all__.append('get_suite_src_formats') + +################################################################################ + class Uid(object): def __init__(self, *args, **kwargs): pass @@ -1771,6 +2207,7 @@ class Uid(object): __all__.append('Uid') +@session_wrapper def add_database_user(uidname, session=None): """ Adds a database user @@ -1786,21 +2223,13 @@ def add_database_user(uidname, session=None): @rtype: Uid @return: the uid object for the given uidname """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True - try: - session.execute("CREATE USER :uid", {'uid': uidname}) - if privatetrans: - session.commit() - except: - traceback.print_exc() - raise + session.execute("CREATE USER :uid", {'uid': uidname}) + session.commit_or_flush() __all__.append('add_database_user') +@session_wrapper def get_or_set_uid(uidname, session=None): """ Returns uid object for given uidname. @@ -1818,48 +2247,47 @@ def get_or_set_uid(uidname, session=None): @rtype: Uid @return: the uid object for the given uidname """ - privatetrans = False - if session is None: - session = DBConn().session() - privatetrans = True + + q = session.query(Uid).filter_by(uid=uidname) try: - q = session.query(Uid).filter_by(uid=uidname) - if q.count() < 1: - uid = Uid() - uid.uid = uidname - session.add(uid) - if privatetrans: - session.commit() - else: - session.flush() - return uid - else: - return q.one() + ret = q.one() + except NoResultFound: + uid = Uid() + uid.uid = uidname + session.add(uid) + session.commit_or_flush() + ret = uid - except: - traceback.print_exc() - raise + return ret __all__.append('get_or_set_uid') - +@session_wrapper def get_uid_from_fingerprint(fpr, session=None): - if session is None: - session = DBConn().session() - q = session.query(Uid) q = q.join(Fingerprint).filter_by(fingerprint=fpr) - if q.count() != 1: - return None - else: + try: return q.one() + except NoResultFound: + return None __all__.append('get_uid_from_fingerprint') ################################################################################ +class UploadBlock(object): + def __init__(self, *args, **kwargs): + pass + + def __repr__(self): + return '' % (self.source, self.upload_block_id) + +__all__.append('UploadBlock') + +################################################################################ + class DBConn(Singleton): """ database module init. @@ -1878,6 +2306,8 @@ class DBConn(Singleton): self.tbl_archive = Table('archive', self.db_meta, autoload=True) self.tbl_bin_associations = Table('bin_associations', self.db_meta, autoload=True) self.tbl_binaries = Table('binaries', self.db_meta, autoload=True) + self.tbl_binary_acl = Table('binary_acl', self.db_meta, autoload=True) + self.tbl_binary_acl_map = Table('binary_acl_map', self.db_meta, autoload=True) self.tbl_component = Table('component', self.db_meta, autoload=True) self.tbl_config = Table('config', self.db_meta, autoload=True) self.tbl_content_associations = Table('content_associations', self.db_meta, autoload=True) @@ -1887,6 +2317,7 @@ class DBConn(Singleton): self.tbl_files = Table('files', self.db_meta, autoload=True) self.tbl_fingerprint = Table('fingerprint', self.db_meta, autoload=True) self.tbl_keyrings = Table('keyrings', self.db_meta, autoload=True) + self.tbl_keyring_acl_map = Table('keyring_acl_map', self.db_meta, autoload=True) self.tbl_location = Table('location', self.db_meta, autoload=True) self.tbl_maintainer = Table('maintainer', self.db_meta, autoload=True) self.tbl_new_comments = Table('new_comments', self.db_meta, autoload=True) @@ -1898,11 +2329,15 @@ class DBConn(Singleton): self.tbl_queue_build = Table('queue_build', self.db_meta, autoload=True) self.tbl_section = Table('section', self.db_meta, autoload=True) self.tbl_source = Table('source', self.db_meta, autoload=True) + self.tbl_source_acl = Table('source_acl', self.db_meta, autoload=True) self.tbl_src_associations = Table('src_associations', self.db_meta, autoload=True) + self.tbl_src_format = Table('src_format', self.db_meta, autoload=True) self.tbl_src_uploaders = Table('src_uploaders', self.db_meta, autoload=True) self.tbl_suite = Table('suite', self.db_meta, autoload=True) self.tbl_suite_architectures = Table('suite_architectures', self.db_meta, autoload=True) + self.tbl_suite_src_formats = Table('suite_src_formats', self.db_meta, autoload=True) self.tbl_uid = Table('uid', self.db_meta, autoload=True) + self.tbl_upload_blocks = Table('upload_blocks', self.db_meta, autoload=True) def __setupmappers(self): mapper(Architecture, self.tbl_architecture, @@ -1938,6 +2373,14 @@ class DBConn(Singleton): binassociations = relation(BinAssociation, primaryjoin=(self.tbl_binaries.c.id==self.tbl_bin_associations.c.bin)))) + mapper(BinaryACL, self.tbl_binary_acl, + properties = dict(binary_acl_id = self.tbl_binary_acl.c.id)) + + mapper(BinaryACLMap, self.tbl_binary_acl_map, + properties = dict(binary_acl_map_id = self.tbl_binary_acl_map.c.id, + fingerprint = relation(Fingerprint, backref="binary_acl_map"), + architecture = relation(Architecture))) + mapper(Component, self.tbl_component, properties = dict(component_id = self.tbl_component.c.id, component_name = self.tbl_component.c.name)) @@ -1981,12 +2424,19 @@ class DBConn(Singleton): uid_id = self.tbl_fingerprint.c.uid, uid = relation(Uid), keyring_id = self.tbl_fingerprint.c.keyring, - keyring = relation(Keyring))) + keyring = relation(Keyring), + source_acl = relation(SourceACL), + binary_acl = relation(BinaryACL))) mapper(Keyring, self.tbl_keyrings, properties = dict(keyring_name = self.tbl_keyrings.c.name, keyring_id = self.tbl_keyrings.c.id)) + mapper(KeyringACLMap, self.tbl_keyring_acl_map, + properties = dict(keyring_acl_map_id = self.tbl_keyring_acl_map.c.id, + keyring = relation(Keyring, backref="keyring_acl_map"), + architecture = relation(Architecture))) + mapper(Location, self.tbl_location, properties = dict(location_id = self.tbl_location.c.id, component_id = self.tbl_location.c.component, @@ -2054,7 +2504,11 @@ class DBConn(Singleton): srcfiles = relation(DSCFile, primaryjoin=(self.tbl_source.c.id==self.tbl_dsc_files.c.source)), srcassociations = relation(SrcAssociation, - primaryjoin=(self.tbl_source.c.id==self.tbl_src_associations.c.source)))) + primaryjoin=(self.tbl_source.c.id==self.tbl_src_associations.c.source)), + srcuploaders = relation(SrcUploader))) + + mapper(SourceACL, self.tbl_source_acl, + properties = dict(source_acl_id = self.tbl_source_acl.c.id)) mapper(SrcAssociation, self.tbl_src_associations, properties = dict(sa_id = self.tbl_src_associations.c.id, @@ -2063,6 +2517,10 @@ class DBConn(Singleton): source_id = self.tbl_src_associations.c.source, source = relation(DBSource))) + mapper(SrcFormat, self.tbl_src_format, + properties = dict(src_format_id = self.tbl_src_format.c.id, + format_name = self.tbl_src_format.c.format_name)) + mapper(SrcUploader, self.tbl_src_uploaders, properties = dict(uploader_id = self.tbl_src_uploaders.c.id, source_id = self.tbl_src_uploaders.c.source, @@ -2081,10 +2539,21 @@ class DBConn(Singleton): arch_id = self.tbl_suite_architectures.c.architecture, architecture = relation(Architecture))) + mapper(SuiteSrcFormat, self.tbl_suite_src_formats, + properties = dict(suite_id = self.tbl_suite_src_formats.c.suite, + suite = relation(Suite, backref='suitesrcformats'), + src_format_id = self.tbl_suite_src_formats.c.src_format, + src_format = relation(SrcFormat))) + mapper(Uid, self.tbl_uid, properties = dict(uid_id = self.tbl_uid.c.id, fingerprint = relation(Fingerprint))) + mapper(UploadBlock, self.tbl_upload_blocks, + properties = dict(upload_block_id = self.tbl_upload_blocks.c.id, + fingerprint = relation(Fingerprint, backref="uploadblocks"), + uid = relation(Uid, backref="uploadblocks"))) + ## Connection functions def __createconn(self): from config import Config