]> git.decadent.org.uk Git - dak.git/blobdiff - daklib/dbconn.py
JSON: support python 2.5, too.
[dak.git] / daklib / dbconn.py
index 6d8d3bd69b230a1a61ceec456cdfbb18a6311d31..1364539422961acf89e3a422e9cbf66a5f864e18 100755 (executable)
@@ -38,6 +38,14 @@ import re
 import psycopg2
 import traceback
 import commands
+
+try:
+    # python >= 2.6
+    import json
+except:
+    # python <= 2.5
+    import simplejson as json
+
 from datetime import datetime, timedelta
 from errno import ENOENT
 from tempfile import mkstemp, mkdtemp
@@ -156,7 +164,86 @@ __all__.append('session_wrapper')
 
 ################################################################################
 
-class Architecture(object):
+class ORMObject(object):
+    """
+    ORMObject is a base class for all ORM classes mapped by SQLalchemy. All
+    derived classes must implement the summary() method.
+    """
+
+    def properties(self):
+        '''
+        This method should be implemented by all derived classes and returns a
+        list of the important properties. The properties 'created' and
+        'modified' will be added automatically. A suffix '_count' should be
+        added to properties that are lists or query objects. The most important
+        property name should be returned as the first element in the list
+        because it is used by repr().
+        '''
+        return []
+
+    def json(self):
+        '''
+        Returns a JSON representation of the object based on the properties
+        returned from the properties() method.
+        '''
+        data = {}
+        # add created and modified
+        all_properties = self.properties() + ['created', 'modified']
+        for property in all_properties:
+            # check for list or query
+            if property[-6:] == '_count':
+                value = getattr(self, property[:-6])
+                if hasattr(value, '__len__'):
+                    # list
+                    value = len(value)
+                elif hasattr(value, 'count'):
+                    # query
+                    value = value.count()
+                else:
+                    raise KeyError('Do not understand property %s.' % property)
+            else:
+                # plain object
+                value = getattr(self, property)
+                if value is None:
+                    # skip None
+                    pass
+                elif isinstance(value, ORMObject):
+                    # use repr() for ORMObject types
+                    value = repr(value)
+                else:
+                    # we want a string for all other types because json cannot
+                    # everything
+                    value = str(value)
+            data[property] = value
+        return json.dumps(data)
+
+    def classname(self):
+        '''
+        Returns the name of the class.
+        '''
+        return type(self).__name__
+
+    def __repr__(self):
+        '''
+        Returns a short string representation of the object using the first
+        element from the properties() method.
+        '''
+        primary_property = self.properties()[0]
+        value = getattr(self, primary_property)
+        return '<%s %s>' % (self.classname(), str(value))
+
+    def __str__(self):
+        '''
+        Returns a human readable form of the object using the properties()
+        method.
+        '''
+        return '<%s %s>' % (self.classname(), self.json())
+
+__all__.append('ORMObject')
+
+################################################################################
+
+class Architecture(ORMObject):
     def __init__(self, arch_string = None, description = None):
         self.arch_string = arch_string
         self.description = description
@@ -173,8 +260,8 @@ class Architecture(object):
         # This signals to use the normal comparison operator
         return NotImplemented
 
-    def __repr__(self):
-        return '<Architecture %s>' % self.arch_string
+    def properties(self):
+        return ['arch_string', 'arch_id', 'suites_count']
 
 __all__.append('Architecture')
 
@@ -1069,7 +1156,7 @@ __all__.append('get_dscfiles')
 
 ################################################################################
 
-class PoolFile(object):
+class PoolFile(ORMObject):
     def __init__(self, filename = None, location = None, filesize = -1, \
         md5sum = None):
         self.filename = filename
@@ -1077,9 +1164,6 @@ class PoolFile(object):
         self.filesize = filesize
         self.md5sum = md5sum
 
-    def __repr__(self):
-        return '<PoolFile %s>' % self.filename
-
     @property
     def fullpath(self):
         return os.path.join(self.location.path, self.filename)
@@ -1087,6 +1171,10 @@ class PoolFile(object):
     def is_valid(self, filesize = -1, md5sum = None):\
         return self.filesize == filesize and self.md5sum == md5sum
 
+    def properties(self):
+        return ['filename', 'file_id', 'filesize', 'md5sum', 'sha1sum', \
+            'sha256sum', 'location', 'source', 'last_used']
+
 __all__.append('PoolFile')
 
 @session_wrapper
@@ -2115,10 +2203,14 @@ def source_exists(source, source_version, suites = ["any"], session=None):
     """
 
     cnf = Config()
-    ret = 1
+    ret = True
+
+    from daklib.regexes import re_bin_only_nmu
+    orig_source_version = re_bin_only_nmu.sub('', source_version)
 
     for suite in suites:
-        q = session.query(DBSource).filter_by(source=source)
+        q = session.query(DBSource).filter_by(source=source). \
+            filter(DBSource.version.in_([source_version, orig_source_version]))
         if suite != "any":
             # source must exist in suite X, or in some other suite that's
             # mapped to X, recursively... silent-maps are counted too,
@@ -2133,24 +2225,13 @@ def source_exists(source, source_version, suites = ["any"], session=None):
                 if x[1] in s and x[0] not in s:
                     s.append(x[0])
 
-            q = q.join(SrcAssociation).join(Suite)
-            q = q.filter(Suite.suite_name.in_(s))
+            q = q.filter(DBSource.suites.any(Suite.suite_name.in_(s)))
 
-        # Reduce the query results to a list of version numbers
-        ql = [ j.version for j in q.all() ]
-
-        # Try (1)
-        if source_version in ql:
-            continue
-
-        # Try (2)
-        from daklib.regexes import re_bin_only_nmu
-        orig_source_version = re_bin_only_nmu.sub('', source_version)
-        if orig_source_version in ql:
+        if q.count() > 0:
             continue
 
         # No source found so return not ok
-        ret = 0
+        ret = False
 
     return ret
 
@@ -2168,7 +2249,7 @@ def get_suites_source_in(source, session=None):
     @return: list of Suite objects for the given source
     """
 
-    return session.query(Suite).join(SrcAssociation).join(DBSource).filter_by(source=source).all()
+    return session.query(Suite).filter(Suite.sources.any(source=source)).all()
 
 __all__.append('get_suites_source_in')
 
@@ -2207,10 +2288,12 @@ def get_sources_from_name(source, version=None, dm_upload_allowed=None, session=
 
 __all__.append('get_sources_from_name')
 
+# FIXME: This function fails badly if it finds more than 1 source package and
+# its implementation is trivial enough to be inlined.
 @session_wrapper
 def get_source_in_suite(source, suite, session=None):
     """
-    Returns list of DBSource objects for a combination of C{source} and C{suite}.
+    Returns a DBSource object for a combination of C{source} and C{suite}.
 
       - B{source} - source package name, eg. I{mailfilter}, I{bbdb}, I{glibc}
       - B{suite} - a suite name, eg. I{unstable}
@@ -2226,12 +2309,9 @@ def get_source_in_suite(source, suite, session=None):
 
     """
 
-    q = session.query(SrcAssociation)
-    q = q.join('source').filter_by(source=source)
-    q = q.join('suite').filter_by(suite_name=suite)
-
+    q = get_suite(suite, session).get_sources(source)
     try:
-        return q.one().source
+        return q.one()
     except NoResultFound:
         return None
 
@@ -2267,15 +2347,10 @@ def add_dsc_to_db(u, filename, session=None):
 
     source.poolfile_id = entry["files id"]
     session.add(source)
-    session.flush()
 
-    for suite_name in u.pkg.changes["distribution"].keys():
-        sa = SrcAssociation()
-        sa.source_id = source.source_id
-        sa.suite_id = get_suite(suite_name).suite_id
-        session.add(sa)
-
-    session.flush()
+    suite_names = u.pkg.changes["distribution"].keys()
+    source.suites = session.query(Suite). \
+        filter(Suite.suite_name.in_(suite_names)).all()
 
     # Add the source files to the DB (files and dsc_files)
     dscfile = DSCFile()
@@ -2325,8 +2400,6 @@ def add_dsc_to_db(u, filename, session=None):
         df.poolfile_id = files_id
         session.add(df)
 
-    session.flush()
-
     # Add the src_uploaders to the DB
     uploader_ids = [source.maintainer_id]
     if u.pkg.dsc.has_key("uploaders"):
@@ -2431,17 +2504,6 @@ __all__.append('SourceACL')
 
 ################################################################################
 
-class SrcAssociation(object):
-    def __init__(self, *args, **kwargs):
-        pass
-
-    def __repr__(self):
-        return '<SrcAssociation %s (%s, %s)>' % (self.sa_id, self.source, self.suite)
-
-__all__.append('SrcAssociation')
-
-################################################################################
-
 class SrcFormat(object):
     def __init__(self, *args, **kwargs):
         pass
@@ -2480,6 +2542,8 @@ SUITE_FIELDS = [ ('SuiteName', 'suite_name'),
                  ('CopyChanges', 'copychanges'),
                  ('OverrideSuite', 'overridesuite')]
 
+# Why the heck don't we have any UNIQUE constraints in table suite?
+# TODO: Add UNIQUE constraints for appropriate columns.
 class Suite(object):
     def __init__(self, suite_name = None, version = None):
         self.suite_name = suite_name
@@ -2525,14 +2589,31 @@ class Suite(object):
         @return: list of Architecture objects for the given name (may be empty)
         """
 
-        q = object_session(self).query(Architecture). \
-            filter(Architecture.suites.contains(self))
+        q = object_session(self).query(Architecture).with_parent(self)
         if skipsrc:
             q = q.filter(Architecture.arch_string != 'source')
         if skipall:
             q = q.filter(Architecture.arch_string != 'all')
         return q.order_by(Architecture.arch_string).all()
 
+    def get_sources(self, source):
+        """
+        Returns a query object representing DBSource that is part of C{suite}.
+
+          - B{source} - source package name, eg. I{mailfilter}, I{bbdb}, I{glibc}
+
+        @type source: string
+        @param source: source package name
+
+        @rtype: sqlalchemy.orm.query.Query
+        @return: a query of DBSource
+
+        """
+
+        session = object_session(self)
+        return session.query(DBSource).filter_by(source = source). \
+            with_parent(self)
+
 __all__.append('Suite')
 
 @session_wrapper
@@ -3054,13 +3135,6 @@ class DBConn(object):
         mapper(SourceACL, self.tbl_source_acl,
                properties = dict(source_acl_id = self.tbl_source_acl.c.id))
 
-        mapper(SrcAssociation, self.tbl_src_associations,
-               properties = dict(sa_id = self.tbl_src_associations.c.id,
-                                 suite_id = self.tbl_src_associations.c.suite,
-                                 suite = relation(Suite),
-                                 source_id = self.tbl_src_associations.c.source,
-                                 source = relation(DBSource)))
-
         mapper(SrcFormat, self.tbl_src_format,
                properties = dict(src_format_id = self.tbl_src_format.c.id,
                                  format_name = self.tbl_src_format.c.format_name))