]> git.decadent.org.uk Git - dak.git/blobdiff - daklib/dbconn.py
Test the DBSource and Suite relation.
[dak.git] / daklib / dbconn.py
index b46eede790c80edbd8759cd8bbdf64b8709a22d9..513abb5805f1ba5f67eee0d8effee131e05a318c 100755 (executable)
@@ -1084,13 +1084,16 @@ class PoolFile(object):
     def fullpath(self):
         return os.path.join(self.location.path, self.filename)
 
+    def is_valid(self, filesize = -1, md5sum = None):\
+        return self.filesize == filesize and self.md5sum == md5sum
+
 __all__.append('PoolFile')
 
 @session_wrapper
 def check_poolfile(filename, filesize, md5sum, location_id, session=None):
     """
     Returns a tuple:
-    (ValidFileFound [boolean or None], PoolFile object or None)
+    (ValidFileFound [boolean], PoolFile object or None)
 
     @type filename: string
     @param filename: the filename of the file to check against the DB
@@ -1106,34 +1109,24 @@ def check_poolfile(filename, filesize, md5sum, location_id, session=None):
 
     @rtype: tuple
     @return: Tuple of length 2.
-                 - If more than one file found with that name: (C{None},  C{None})
                  - If valid pool file found: (C{True}, C{PoolFile object})
                  - If valid pool file not found:
                      - (C{False}, C{None}) if no file found
                      - (C{False}, C{PoolFile object}) if file found with size/md5sum mismatch
     """
 
-    q = session.query(PoolFile).filter_by(filename=filename)
-    q = q.join(Location).filter_by(location_id=location_id)
-
-    ret = None
-
-    if q.count() > 1:
-        ret = (None, None)
-    elif q.count() < 1:
-        ret = (False, None)
-    else:
-        obj = q.one()
-        if obj.md5sum != md5sum or obj.filesize != int(filesize):
-            ret = (False, obj)
-
-    if ret is None:
-        ret = (True, obj)
+    poolfile = session.query(Location).get(location_id). \
+        files.filter_by(filename=filename).first()
+    valid = False
+    if poolfile and poolfile.is_valid(filesize = filesize, md5sum = md5sum):
+        valid = True
 
-    return ret
+    return (valid, poolfile)
 
 __all__.append('check_poolfile')
 
+# TODO: the implementation can trivially be inlined at the place where the
+# function is called
 @session_wrapper
 def get_poolfile_by_id(file_id, session=None):
     """
@@ -1146,41 +1139,10 @@ def get_poolfile_by_id(file_id, session=None):
     @return: either the PoolFile object or None
     """
 
-    q = session.query(PoolFile).filter_by(file_id=file_id)
-
-    try:
-        return q.one()
-    except NoResultFound:
-        return None
+    return session.query(PoolFile).get(file_id)
 
 __all__.append('get_poolfile_by_id')
 
-
-@session_wrapper
-def get_poolfile_by_name(filename, location_id=None, session=None):
-    """
-    Returns an array of PoolFile objects for the given filename and
-    (optionally) location_id
-
-    @type filename: string
-    @param filename: the filename of the file to check against the DB
-
-    @type location_id: int
-    @param location_id: the id of the location to look in (optional)
-
-    @rtype: array
-    @return: array of PoolFile objects
-    """
-
-    q = session.query(PoolFile).filter_by(filename=filename)
-
-    if location_id is not None:
-        q = q.join(Location).filter_by(location_id=location_id)
-
-    return q.all()
-
-__all__.append('get_poolfile_by_name')
-
 @session_wrapper
 def get_poolfile_like_name(filename, session=None):
     """
@@ -2112,9 +2074,14 @@ __all__.append('get_sections')
 ################################################################################
 
 class DBSource(object):
-    def __init__(self, maintainer = None, changedby = None):
+    def __init__(self, source = None, version = None, maintainer = None, \
+        changedby = None, poolfile = None, install_date = None):
+        self.source = source
+        self.version = version
         self.maintainer = maintainer
         self.changedby = changedby
+        self.poolfile = poolfile
+        self.install_date = install_date
 
     def __repr__(self):
         return '<DBSource %s (%s)>' % (self.source, self.version)
@@ -2240,10 +2207,11 @@ def get_sources_from_name(source, version=None, dm_upload_allowed=None, session=
 
 __all__.append('get_sources_from_name')
 
+# FIXME: This function fails badly if it finds more than 1 source package.
 @session_wrapper
 def get_source_in_suite(source, suite, session=None):
     """
-    Returns list of DBSource objects for a combination of C{source} and C{suite}.
+    Returns a DBSource object for a combination of C{source} and C{suite}.
 
       - B{source} - source package name, eg. I{mailfilter}, I{bbdb}, I{glibc}
       - B{suite} - a suite name, eg. I{unstable}
@@ -2259,12 +2227,11 @@ def get_source_in_suite(source, suite, session=None):
 
     """
 
-    q = session.query(SrcAssociation)
-    q = q.join('source').filter_by(source=source)
-    q = q.join('suite').filter_by(suite_name=suite)
+    q = session.query(DBSource).filter_by(source = source). \
+        filter(DBSource.suites.any(Suite.suite_name == suite))
 
     try:
-        return q.one().source
+        return q.one()
     except NoResultFound:
         return None
 
@@ -3074,7 +3041,7 @@ class DBConn(object):
                                  version = self.tbl_source.c.version,
                                  maintainer_id = self.tbl_source.c.maintainer,
                                  poolfile_id = self.tbl_source.c.file,
-                                 poolfile = relation(PoolFile),
+                                 poolfile = relation(PoolFile, backref=backref('source', uselist = False)),
                                  fingerprint_id = self.tbl_source.c.sig_fpr,
                                  fingerprint = relation(Fingerprint),
                                  changedby_id = self.tbl_source.c.changedby,