X-Git-Url: https://git.decadent.org.uk/gitweb/?a=blobdiff_plain;f=daklib%2Fdbconn.py;h=9ace924414717d4f6cafb4b2e68fc0fc70263725;hb=55173b9ccdca13cc375d07932b8bbb14a91e213d;hp=6ea3aa398beaa1745795321e3d3d7bc3af2759f5;hpb=e6d57b8d52035f9ec7f0efcd9cca2c1514c7d055;p=dak.git

diff --git a/daklib/dbconn.py b/daklib/dbconn.py
index 6ea3aa39..9ace9244 100755
--- a/daklib/dbconn.py
+++ b/daklib/dbconn.py
@@ -1070,8 +1070,12 @@ __all__.append('get_dscfiles')
 ################################################################################
 
 class PoolFile(object):
-    def __init__(self, *args, **kwargs):
-        pass
+    def __init__(self, filename = None, location = None, filesize = -1, \
+        md5sum = None):
+        self.filename = filename
+        self.location = location
+        self.filesize = filesize
+        self.md5sum = md5sum
 
     def __repr__(self):
         return '<PoolFile %s>' % self.filename
@@ -1080,13 +1084,16 @@ class PoolFile(object):
     def fullpath(self):
         return os.path.join(self.location.path, self.filename)
 
+    def is_valid(self, filesize = -1, md5sum = None):
+        return self.filesize == filesize and self.md5sum == md5sum
+
 __all__.append('PoolFile')
 
 @session_wrapper
 def check_poolfile(filename, filesize, md5sum, location_id, session=None):
     """
     Returns a tuple:
-     (ValidFileFound [boolean or None], PoolFile object or None)
+    (ValidFileFound [boolean], PoolFile object or None)
 
     @type filename: string
     @param filename: the filename of the file to check against the DB
@@ -1102,34 +1109,24 @@ def check_poolfile(filename, filesize, md5sum, location_id, session=None):
     @rtype: tuple
     @return: Tuple of length 2.
-              - If more than one file found with that name: (C{None}, C{None})
               - If valid pool file found: (C{True}, C{PoolFile object})
               - If valid pool file not found:
                   - (C{False}, C{None}) if no file found
                   - (C{False}, C{PoolFile object}) if file found with size/md5sum mismatch
     """
 
-    q = session.query(PoolFile).filter_by(filename=filename)
-    q = q.join(Location).filter_by(location_id=location_id)
-
-    ret = None
+    poolfile = session.query(Location).get(location_id). \
+        files.filter_by(filename=filename).first()
+    valid = False
+    if poolfile and poolfile.is_valid(filesize = filesize, md5sum = md5sum):
+        valid = True
 
-    if q.count() > 1:
-        ret = (None, None)
-    elif q.count() < 1:
-        ret = (False, None)
-    else:
-        obj = q.one()
-        if obj.md5sum != md5sum or obj.filesize != int(filesize):
-            ret = (False, obj)
-
-    if ret is None:
-        ret = (True, obj)
-
-    return ret
+    return (valid, poolfile)
 
 __all__.append('check_poolfile')
 
+# TODO: the implementation can trivially be inlined at the place where the
+# function is called
 @session_wrapper
 def get_poolfile_by_id(file_id, session=None):
     """
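
The reworked check_poolfile() above always returns a plain boolean as its first element, so callers no longer need to special-case the old (None, None) result. A minimal usage sketch, assuming a configured dak database; the pool filename, size, md5sum and location id are made-up example values:

from daklib.dbconn import DBConn, check_poolfile

session = DBConn().session()

# valid is True only when a file with this name exists in the given
# location and both the size and the md5sum match the database record.
valid, poolfile = check_poolfile('main/h/hello/hello_2.4-1.dsc', 1234,
                                 'd41d8cd98f00b204e9800998ecf8427e', 1,
                                 session=session)

if poolfile is not None and not valid:
    # known file, but the stored size/md5sum disagree
    print 'stale pool entry: %s' % poolfile.fullpath
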
@@ -1142,41 +1139,10 @@ def get_poolfile_by_id(file_id, session=None):
     @return: either the PoolFile object or None
     """
 
-    q = session.query(PoolFile).filter_by(file_id=file_id)
-
-    try:
-        return q.one()
-    except NoResultFound:
-        return None
+    return session.query(PoolFile).get(file_id)
 
 __all__.append('get_poolfile_by_id')
 
-
-@session_wrapper
-def get_poolfile_by_name(filename, location_id=None, session=None):
-    """
-    Returns an array of PoolFile objects for the given filename and
-    (optionally) location_id
-
-    @type filename: string
-    @param filename: the filename of the file to check against the DB
-
-    @type location_id: int
-    @param location_id: the id of the location to look in (optional)
-
-    @rtype: array
-    @return: array of PoolFile objects
-    """
-
-    q = session.query(PoolFile).filter_by(filename=filename)
-
-    if location_id is not None:
-        q = q.join(Location).filter_by(location_id=location_id)
-
-    return q.all()
-
-__all__.append('get_poolfile_by_name')
-
 @session_wrapper
 def get_poolfile_like_name(filename, session=None):
     """
@@ -1522,8 +1488,10 @@ __all__.append('get_dbchange')
 ################################################################################
 
 class Location(object):
-    def __init__(self, *args, **kwargs):
-        pass
+    def __init__(self, path = None):
+        self.path = path
+        # the column 'type' should go away, see comment at mapper
+        self.archive_type = 'pool'
 
     def __repr__(self):
         return '<Location %s (%s)>' % (self.path, self.location_id)
@@ -1567,8 +1535,8 @@ __all__.append('get_location')
 ################################################################################
 
 class Maintainer(object):
-    def __init__(self, *args, **kwargs):
-        pass
+    def __init__(self, name = None):
+        self.name = name
 
     def __repr__(self):
         return '''<Maintainer '%s' (%s)>''' % (self.name, self.maintainer_id)
@@ -2106,8 +2074,14 @@ __all__.append('get_sections')
 ################################################################################
 
 class DBSource(object):
-    def __init__(self, *args, **kwargs):
-        pass
+    def __init__(self, source = None, version = None, maintainer = None, \
+        changedby = None, poolfile = None, install_date = None):
+        self.source = source
+        self.version = version
+        self.maintainer = maintainer
+        self.changedby = changedby
+        self.poolfile = poolfile
+        self.install_date = install_date
 
     def __repr__(self):
         return '<DBSource %s (%s)>' % (self.source, self.version)
@@ -2141,10 +2115,14 @@ def source_exists(source, source_version, suites = ["any"], session=None):
     """
 
    cnf = Config()
-    ret = 1
+    ret = True
+
+    from daklib.regexes import re_bin_only_nmu
+    orig_source_version = re_bin_only_nmu.sub('', source_version)
 
     for suite in suites:
-        q = session.query(DBSource).filter_by(source=source)
+        q = session.query(DBSource).filter_by(source=source). \
+            filter(DBSource.version.in_([source_version, orig_source_version]))
         if suite != "any":
             # source must exist in suite X, or in some other suite that's
             # mapped to X, recursively... silent-maps are counted too,
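
source_exists() now strips a binary-only NMU suffix from the version once, up front, and matches either form in a single query instead of comparing version lists in Python. A small sketch of that preprocessing step; the version strings are just examples:

from daklib.regexes import re_bin_only_nmu

for version in ['2.4-1', '2.4-1+b3']:
    # a '+bN' suffix marks a binary-only NMU; the source package
    # itself is recorded under the version without that suffix
    print '%s -> %s' % (version, re_bin_only_nmu.sub('', version))
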
@@ -2159,24 +2137,13 @@ def source_exists(source, source_version, suites = ["any"], session=None):
                 if x[1] in s and x[0] not in s:
                     s.append(x[0])
 
-            q = q.join(SrcAssociation).join(Suite)
-            q = q.filter(Suite.suite_name.in_(s))
+            q = q.filter(DBSource.suites.any(Suite.suite_name.in_(s)))
 
-        # Reduce the query results to a list of version numbers
-        ql = [ j.version for j in q.all() ]
-
-        # Try (1)
-        if source_version in ql:
-            continue
-
-        # Try (2)
-        from daklib.regexes import re_bin_only_nmu
-        orig_source_version = re_bin_only_nmu.sub('', source_version)
-        if orig_source_version in ql:
+        if q.count() > 0:
            continue
 
         # No source found so return not ok
-        ret = 0
+        ret = False
 
     return ret
 
@@ -2194,7 +2161,7 @@ def get_suites_source_in(source, session=None):
     @return: list of Suite objects for the given source
     """
 
-    return session.query(Suite).join(SrcAssociation).join(DBSource).filter_by(source=source).all()
+    return session.query(Suite).filter(Suite.sources.any(source=source)).all()
 
 __all__.append('get_suites_source_in')
 
@@ -2233,10 +2200,12 @@ def get_sources_from_name(source, version=None, dm_upload_allowed=None, session=
 
 __all__.append('get_sources_from_name')
 
+# FIXME: This function fails badly if it finds more than 1 source package and
+# its implementation is trivial enough to be inlined.
 @session_wrapper
 def get_source_in_suite(source, suite, session=None):
     """
-    Returns list of DBSource objects for a combination of C{source} and C{suite}.
+    Returns a DBSource object for a combination of C{source} and C{suite}.
 
       - B{source} - source package name, eg. I{mailfilter}, I{bbdb}, I{glibc}
      - B{suite} - a suite name, eg. I{unstable}
@@ -2252,12 +2221,9 @@ def get_source_in_suite(source, suite, session=None):
 
     """
 
-    q = session.query(SrcAssociation)
-    q = q.join('source').filter_by(source=source)
-    q = q.join('suite').filter_by(suite_name=suite)
-
+    q = get_suite(suite, session).get_sources(source)
     try:
-        return q.one().source
+        return q.one()
     except NoResultFound:
         return None
 
@@ -2293,15 +2259,10 @@ def add_dsc_to_db(u, filename, session=None):
     source.poolfile_id = entry["files id"]
     session.add(source)
-    session.flush()
-
-    for suite_name in u.pkg.changes["distribution"].keys():
-        sa = SrcAssociation()
-        sa.source_id = source.source_id
-        sa.suite_id = get_suite(suite_name).suite_id
-        session.add(sa)
-
     session.flush()
+
+    suite_names = u.pkg.changes["distribution"].keys()
+    source.suites = session.query(Suite). \
+        filter(Suite.suite_name.in_(suite_names)).all()
 
     # Add the source files to the DB (files and dsc_files)
     dscfile = DSCFile()
@@ -2351,8 +2312,6 @@ def add_dsc_to_db(u, filename, session=None):
             df.poolfile_id = files_id
 
         session.add(df)
 
-    session.flush()
-
     # Add the src_uploaders to the DB
     uploader_ids = [source.maintainer_id]
     if u.pkg.dsc.has_key("uploaders"):
@@ -2506,6 +2465,8 @@ SUITE_FIELDS = [ ('SuiteName', 'suite_name'),
                  ('CopyChanges', 'copychanges'),
                  ('OverrideSuite', 'overridesuite')]
 
+# Why the heck don't we have any UNIQUE constraints in table suite?
+# TODO: Add UNIQUE constraints for appropriate columns.
 class Suite(object):
     def __init__(self, suite_name = None, version = None):
         self.suite_name = suite_name
@@ -2559,6 +2520,24 @@ class Suite(object):
             q = q.filter(Architecture.arch_string != 'all')
         return q.order_by(Architecture.arch_string).all()
 
+    def get_sources(self, source):
+        """
+        Returns a query object representing DBSource that is part of C{suite}.
+
+          - B{source} - source package name, eg. I{mailfilter}, I{bbdb}, I{glibc}
+
+        @type source: string
+        @param source: source package name
+
+        @rtype: sqlalchemy.orm.query.Query
+        @return: a query of DBSource
+
+        """
+
+        session = object_session(self)
+        return session.query(DBSource).filter_by(source = source). \
+            filter(DBSource.suites.contains(self))
+
 __all__.append('Suite')
 
 @session_wrapper
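
Suite.get_sources() deliberately returns an unevaluated query, which is what lets get_source_in_suite() call .one() on it while other callers choose .first() or .count(). A short sketch, assuming a configured dak database; the suite and package names are examples:

from daklib.dbconn import DBConn, get_suite

session = DBConn().session()

suite = get_suite('unstable', session)
# get_sources() hands back a Query over DBSource, not a list
q = suite.get_sources('hello')
if q.count() == 1:
    src = q.one()
    print '%s %s in %s' % (src.source, src.version, suite.suite_name)
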
@@ -2944,7 +2923,11 @@ class DBConn(object):
                properties = dict(file_id = self.tbl_files.c.id,
                                  filesize = self.tbl_files.c.size,
                                  location_id = self.tbl_files.c.location,
-                                 location = relation(Location)))
+                                 location = relation(Location,
+                                     # using lazy='dynamic' in the back
+                                     # reference because we have A LOT of
+                                     # files in one location
+                                     backref=backref('files', lazy='dynamic'))))
 
         mapper(Fingerprint, self.tbl_fingerprint,
                properties = dict(fingerprint_id = self.tbl_fingerprint.c.id,
@@ -3017,6 +3000,8 @@ class DBConn(object):
                                  component = relation(Component),
                                  archive_id = self.tbl_location.c.archive,
                                  archive = relation(Archive),
+                                 # FIXME: the 'type' column is old cruft and
+                                 # should be removed in the future.
                                  archive_type = self.tbl_location.c.type))
 
         mapper(Maintainer, self.tbl_maintainer,
@@ -3061,7 +3046,7 @@ class DBConn(object):
                                  version = self.tbl_source.c.version,
                                  maintainer_id = self.tbl_source.c.maintainer,
                                  poolfile_id = self.tbl_source.c.file,
-                                 poolfile = relation(PoolFile),
+                                 poolfile = relation(PoolFile, backref=backref('source', uselist = False)),
                                  fingerprint_id = self.tbl_source.c.sig_fpr,
                                  fingerprint = relation(Fingerprint),
                                  changedby_id = self.tbl_source.c.changedby,
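
Because the new files backref is declared with lazy='dynamic', Location.files behaves like a query rather than a loaded list, so a location holding a very large pool can be counted or filtered without pulling every row into memory. A sketch, assuming a configured dak database; the location id and filename pattern are examples:

from daklib.dbconn import DBConn, Location, PoolFile

session = DBConn().session()

location = session.query(Location).get(1)
# both calls below are executed in SQL, not on an in-memory collection
print location.files.count()
some_dscs = location.files.filter(PoolFile.filename.like('%.dsc')).limit(5).all()
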