]> git.decadent.org.uk Git - dak.git/blobdiff - daklib/dbconn.py
Test and Improve source_exists().
[dak.git] / daklib / dbconn.py
index 6ea3aa398beaa1745795321e3d3d7bc3af2759f5..9ace924414717d4f6cafb4b2e68fc0fc70263725 100755 (executable)
@@ -1070,8 +1070,12 @@ __all__.append('get_dscfiles')
 ################################################################################
 
 class PoolFile(object):
-    def __init__(self, *args, **kwargs):
-        pass
+    def __init__(self, filename = None, location = None, filesize = -1, \
+        md5sum = None):
+        self.filename = filename
+        self.location = location
+        self.filesize = filesize
+        self.md5sum = md5sum
 
     def __repr__(self):
         return '<PoolFile %s>' % self.filename
@@ -1080,13 +1084,16 @@ class PoolFile(object):
     def fullpath(self):
         return os.path.join(self.location.path, self.filename)
 
+    def is_valid(self, filesize = -1, md5sum = None):\
+        return self.filesize == filesize and self.md5sum == md5sum
+
 __all__.append('PoolFile')
 
 @session_wrapper
 def check_poolfile(filename, filesize, md5sum, location_id, session=None):
     """
     Returns a tuple:
-    (ValidFileFound [boolean or None], PoolFile object or None)
+    (ValidFileFound [boolean], PoolFile object or None)
 
     @type filename: string
     @param filename: the filename of the file to check against the DB
@@ -1102,34 +1109,24 @@ def check_poolfile(filename, filesize, md5sum, location_id, session=None):
 
     @rtype: tuple
     @return: Tuple of length 2.
-                 - If more than one file found with that name: (C{None},  C{None})
                  - If valid pool file found: (C{True}, C{PoolFile object})
                  - If valid pool file not found:
                      - (C{False}, C{None}) if no file found
                      - (C{False}, C{PoolFile object}) if file found with size/md5sum mismatch
     """
 
-    q = session.query(PoolFile).filter_by(filename=filename)
-    q = q.join(Location).filter_by(location_id=location_id)
-
-    ret = None
+    poolfile = session.query(Location).get(location_id). \
+        files.filter_by(filename=filename).first()
+    valid = False
+    if poolfile and poolfile.is_valid(filesize = filesize, md5sum = md5sum):
+        valid = True
 
-    if q.count() > 1:
-        ret = (None, None)
-    elif q.count() < 1:
-        ret = (False, None)
-    else:
-        obj = q.one()
-        if obj.md5sum != md5sum or obj.filesize != int(filesize):
-            ret = (False, obj)
-
-    if ret is None:
-        ret = (True, obj)
-
-    return ret
+    return (valid, poolfile)
 
 __all__.append('check_poolfile')
 
+# TODO: the implementation can trivially be inlined at the place where the
+# function is called
 @session_wrapper
 def get_poolfile_by_id(file_id, session=None):
     """
@@ -1142,41 +1139,10 @@ def get_poolfile_by_id(file_id, session=None):
     @return: either the PoolFile object or None
     """
 
-    q = session.query(PoolFile).filter_by(file_id=file_id)
-
-    try:
-        return q.one()
-    except NoResultFound:
-        return None
+    return session.query(PoolFile).get(file_id)
 
 __all__.append('get_poolfile_by_id')
 
-
-@session_wrapper
-def get_poolfile_by_name(filename, location_id=None, session=None):
-    """
-    Returns an array of PoolFile objects for the given filename and
-    (optionally) location_id
-
-    @type filename: string
-    @param filename: the filename of the file to check against the DB
-
-    @type location_id: int
-    @param location_id: the id of the location to look in (optional)
-
-    @rtype: array
-    @return: array of PoolFile objects
-    """
-
-    q = session.query(PoolFile).filter_by(filename=filename)
-
-    if location_id is not None:
-        q = q.join(Location).filter_by(location_id=location_id)
-
-    return q.all()
-
-__all__.append('get_poolfile_by_name')
-
 @session_wrapper
 def get_poolfile_like_name(filename, session=None):
     """
@@ -1522,8 +1488,10 @@ __all__.append('get_dbchange')
 ################################################################################
 
 class Location(object):
-    def __init__(self, *args, **kwargs):
-        pass
+    def __init__(self, path = None):
+        self.path = path
+        # the column 'type' should go away, see comment at mapper
+        self.archive_type = 'pool'
 
     def __repr__(self):
         return '<Location %s (%s)>' % (self.path, self.location_id)
@@ -1567,8 +1535,8 @@ __all__.append('get_location')
 ################################################################################
 
 class Maintainer(object):
-    def __init__(self, *args, **kwargs):
-        pass
+    def __init__(self, name = None):
+        self.name = name
 
     def __repr__(self):
         return '''<Maintainer '%s' (%s)>''' % (self.name, self.maintainer_id)
@@ -2106,8 +2074,14 @@ __all__.append('get_sections')
 ################################################################################
 
 class DBSource(object):
-    def __init__(self, *args, **kwargs):
-        pass
+    def __init__(self, source = None, version = None, maintainer = None, \
+        changedby = None, poolfile = None, install_date = None):
+        self.source = source
+        self.version = version
+        self.maintainer = maintainer
+        self.changedby = changedby
+        self.poolfile = poolfile
+        self.install_date = install_date
 
     def __repr__(self):
         return '<DBSource %s (%s)>' % (self.source, self.version)
@@ -2141,10 +2115,14 @@ def source_exists(source, source_version, suites = ["any"], session=None):
     """
 
     cnf = Config()
-    ret = 1
+    ret = True
+
+    from daklib.regexes import re_bin_only_nmu
+    orig_source_version = re_bin_only_nmu.sub('', source_version)
 
     for suite in suites:
-        q = session.query(DBSource).filter_by(source=source)
+        q = session.query(DBSource).filter_by(source=source). \
+            filter(DBSource.version.in_([source_version, orig_source_version]))
         if suite != "any":
             # source must exist in suite X, or in some other suite that's
             # mapped to X, recursively... silent-maps are counted too,
@@ -2159,24 +2137,13 @@ def source_exists(source, source_version, suites = ["any"], session=None):
                 if x[1] in s and x[0] not in s:
                     s.append(x[0])
 
-            q = q.join(SrcAssociation).join(Suite)
-            q = q.filter(Suite.suite_name.in_(s))
+            q = q.filter(DBSource.suites.any(Suite.suite_name.in_(s)))
 
-        # Reduce the query results to a list of version numbers
-        ql = [ j.version for j in q.all() ]
-
-        # Try (1)
-        if source_version in ql:
-            continue
-
-        # Try (2)
-        from daklib.regexes import re_bin_only_nmu
-        orig_source_version = re_bin_only_nmu.sub('', source_version)
-        if orig_source_version in ql:
+        if q.count() > 0:
             continue
 
         # No source found so return not ok
-        ret = 0
+        ret = False
 
     return ret
 
@@ -2194,7 +2161,7 @@ def get_suites_source_in(source, session=None):
     @return: list of Suite objects for the given source
     """
 
-    return session.query(Suite).join(SrcAssociation).join(DBSource).filter_by(source=source).all()
+    return session.query(Suite).filter(Suite.sources.any(source=source)).all()
 
 __all__.append('get_suites_source_in')
 
@@ -2233,10 +2200,12 @@ def get_sources_from_name(source, version=None, dm_upload_allowed=None, session=
 
 __all__.append('get_sources_from_name')
 
+# FIXME: This function fails badly if it finds more than 1 source package and
+# its implementation is trivial enough to be inlined.
 @session_wrapper
 def get_source_in_suite(source, suite, session=None):
     """
-    Returns list of DBSource objects for a combination of C{source} and C{suite}.
+    Returns a DBSource object for a combination of C{source} and C{suite}.
 
       - B{source} - source package name, eg. I{mailfilter}, I{bbdb}, I{glibc}
       - B{suite} - a suite name, eg. I{unstable}
@@ -2252,12 +2221,9 @@ def get_source_in_suite(source, suite, session=None):
 
     """
 
-    q = session.query(SrcAssociation)
-    q = q.join('source').filter_by(source=source)
-    q = q.join('suite').filter_by(suite_name=suite)
-
+    q = get_suite(suite, session).get_sources(source)
     try:
-        return q.one().source
+        return q.one()
     except NoResultFound:
         return None
 
@@ -2293,15 +2259,10 @@ def add_dsc_to_db(u, filename, session=None):
 
     source.poolfile_id = entry["files id"]
     session.add(source)
-    session.flush()
-
-    for suite_name in u.pkg.changes["distribution"].keys():
-        sa = SrcAssociation()
-        sa.source_id = source.source_id
-        sa.suite_id = get_suite(suite_name).suite_id
-        session.add(sa)
 
-    session.flush()
+    suite_names = u.pkg.changes["distribution"].keys()
+    source.suites = session.query(Suite). \
+        filter(Suite.suite_name.in_(suite_names)).all()
 
     # Add the source files to the DB (files and dsc_files)
     dscfile = DSCFile()
@@ -2351,8 +2312,6 @@ def add_dsc_to_db(u, filename, session=None):
         df.poolfile_id = files_id
         session.add(df)
 
-    session.flush()
-
     # Add the src_uploaders to the DB
     uploader_ids = [source.maintainer_id]
     if u.pkg.dsc.has_key("uploaders"):
@@ -2506,6 +2465,8 @@ SUITE_FIELDS = [ ('SuiteName', 'suite_name'),
                  ('CopyChanges', 'copychanges'),
                  ('OverrideSuite', 'overridesuite')]
 
+# Why the heck don't we have any UNIQUE constraints in table suite?
+# TODO: Add UNIQUE constraints for appropriate columns.
 class Suite(object):
     def __init__(self, suite_name = None, version = None):
         self.suite_name = suite_name
@@ -2559,6 +2520,24 @@ class Suite(object):
             q = q.filter(Architecture.arch_string != 'all')
         return q.order_by(Architecture.arch_string).all()
 
+    def get_sources(self, source):
+        """
+        Returns a query object representing DBSource that is part of C{suite}.
+
+          - B{source} - source package name, eg. I{mailfilter}, I{bbdb}, I{glibc}
+
+        @type source: string
+        @param source: source package name
+
+        @rtype: sqlalchemy.orm.query.Query
+        @return: a query of DBSource
+
+        """
+
+        session = object_session(self)
+        return session.query(DBSource).filter_by(source = source). \
+            filter(DBSource.suites.contains(self))
+
 __all__.append('Suite')
 
 @session_wrapper
@@ -2944,7 +2923,11 @@ class DBConn(object):
                properties = dict(file_id = self.tbl_files.c.id,
                                  filesize = self.tbl_files.c.size,
                                  location_id = self.tbl_files.c.location,
-                                 location = relation(Location)))
+                                 location = relation(Location,
+                                     # using lazy='dynamic' in the back
+                                     # reference because we have A LOT of
+                                     # files in one location
+                                     backref=backref('files', lazy='dynamic'))))
 
         mapper(Fingerprint, self.tbl_fingerprint,
                properties = dict(fingerprint_id = self.tbl_fingerprint.c.id,
@@ -3017,6 +3000,8 @@ class DBConn(object):
                                  component = relation(Component),
                                  archive_id = self.tbl_location.c.archive,
                                  archive = relation(Archive),
+                                 # FIXME: the 'type' column is old cruft and
+                                 # should be removed in the future.
                                  archive_type = self.tbl_location.c.type))
 
         mapper(Maintainer, self.tbl_maintainer,
@@ -3061,7 +3046,7 @@ class DBConn(object):
                                  version = self.tbl_source.c.version,
                                  maintainer_id = self.tbl_source.c.maintainer,
                                  poolfile_id = self.tbl_source.c.file,
-                                 poolfile = relation(PoolFile),
+                                 poolfile = relation(PoolFile, backref=backref('source', uselist = False)),
                                  fingerprint_id = self.tbl_source.c.sig_fpr,
                                  fingerprint = relation(Fingerprint),
                                  changedby_id = self.tbl_source.c.changedby,