]> git.decadent.org.uk Git - dak.git/blobdiff - daklib/dbconn.py
Add classes Validator and ValidatorTestCase.
[dak.git] / daklib / dbconn.py
index 9c1ed2b39b677d0f1d85f21d19c89a1d57b95c96..877c5a39fc4032e670c9f6295e65e1276c50fd7c 100755 (executable)
@@ -38,6 +38,14 @@ import re
 import psycopg2
 import traceback
 import commands
+
+try:
+    # python >= 2.6
+    import json
+except:
+    # python <= 2.5
+    import simplejson as json
+
 from datetime import datetime, timedelta
 from errno import ENOENT
 from tempfile import mkstemp, mkdtemp
@@ -46,7 +54,8 @@ from inspect import getargspec
 
 import sqlalchemy
 from sqlalchemy import create_engine, Table, MetaData, Column, Integer
-from sqlalchemy.orm import sessionmaker, mapper, relation, object_session, backref
+from sqlalchemy.orm import sessionmaker, mapper, relation, object_session, \
+    backref, MapperExtension, EXT_CONTINUE
 from sqlalchemy import types as sqltypes
 
 # Don't remove this, we re-export the exceptions to scripts which import us
@@ -57,7 +66,7 @@ from sqlalchemy.orm.exc import NoResultFound
 # in the database
 from config import Config
 from textutils import fix_maintainer
-from dak_exceptions import NoSourceFieldError
+from dak_exceptions import DBUpdateError, NoSourceFieldError
 
 # suppress some deprecation warnings in squeeze related to sqlalchemy
 import warnings
@@ -156,7 +165,117 @@ __all__.append('session_wrapper')
 
 ################################################################################
 
-class Architecture(object):
+class ORMObject(object):
+    """
+    ORMObject is a base class for all ORM classes mapped by SQLalchemy. All
+    derived classes must implement the properties() method.
+    """
+
+    def properties(self):
+        '''
+        This method should be implemented by all derived classes and returns a
+        list of the important properties. The properties 'created' and
+        'modified' will be added automatically. A suffix '_count' should be
+        added to properties that are lists or query objects. The most important
+        property name should be returned as the first element in the list
+        because it is used by repr().
+        '''
+        return []
+
+    def json(self):
+        '''
+        Returns a JSON representation of the object based on the properties
+        returned from the properties() method.
+        '''
+        data = {}
+        # add created and modified
+        all_properties = self.properties() + ['created', 'modified']
+        for property in all_properties:
+            # check for list or query
+            if property[-6:] == '_count':
+                real_property = property[:-6]
+                if not hasattr(self, real_property):
+                    continue
+                value = getattr(self, real_property)
+                if hasattr(value, '__len__'):
+                    # list
+                    value = len(value)
+                elif hasattr(value, 'count'):
+                    # query
+                    value = value.count()
+                else:
+                    raise KeyError('Do not understand property %s.' % property)
+            else:
+                if not hasattr(self, property):
+                    continue
+                # plain object
+                value = getattr(self, property)
+                if value is None:
+                    # skip None
+                    pass
+                elif isinstance(value, ORMObject):
+                    # use repr() for ORMObject types
+                    value = repr(value)
+                else:
+                    # we want a string for all other types because json cannot
+                    # encode everything
+                    value = str(value)
+            data[property] = value
+        return json.dumps(data)
+
+    def classname(self):
+        '''
+        Returns the name of the class.
+        '''
+        return type(self).__name__
+
+    def __repr__(self):
+        '''
+        Returns a short string representation of the object using the first
+        element from the properties() method.
+        '''
+        primary_property = self.properties()[0]
+        value = getattr(self, primary_property)
+        return '<%s %s>' % (self.classname(), str(value))
+
+    def __str__(self):
+        '''
+        Returns a human readable form of the object using the properties()
+        method.
+        '''
+        return '<%s %s>' % (self.classname(), self.json())
+
+    def validate(self):
+        '''
+        This function should be implemented by derived classes to validate self.
+        It may raise the DBUpdateError exception if needed.
+        '''
+        pass
+
+__all__.append('ORMObject')
+
+################################################################################
+
+class Validator(MapperExtension):
+    '''
+    This class calls the validate() method for each instance for the
+    'before_update' and 'before_insert' events. A global object validator is
+    used for configuring the individual mappers.
+    '''
+
+    def before_update(self, mapper, connection, instance):
+        instance.validate()
+        return EXT_CONTINUE
+
+    def before_insert(self, mapper, connection, instance):
+        instance.validate()
+        return EXT_CONTINUE
+
+validator = Validator()
+
+################################################################################
+
+class Architecture(ORMObject):
     def __init__(self, arch_string = None, description = None):
         self.arch_string = arch_string
         self.description = description
@@ -173,8 +292,14 @@ class Architecture(object):
         # This signals to use the normal comparison operator
         return NotImplemented
 
-    def __repr__(self):
-        return '<Architecture %s>' % self.arch_string
+    def properties(self):
+        return ['arch_string', 'arch_id', 'suites_count']
+
+    def validate(self):
+        if self.arch_string is None or len(self.arch_string) == 0:
+            raise DBUpdateError( \
+                "Validation failed because 'arch_string' must not be empty in object\n%s" % \
+                str(self))
 
 __all__.append('Architecture')
 
@@ -1069,7 +1194,7 @@ __all__.append('get_dscfiles')
 
 ################################################################################
 
-class PoolFile(object):
+class PoolFile(ORMObject):
     def __init__(self, filename = None, location = None, filesize = -1, \
         md5sum = None):
         self.filename = filename
@@ -1077,9 +1202,6 @@ class PoolFile(object):
         self.filesize = filesize
         self.md5sum = md5sum
 
-    def __repr__(self):
-        return '<PoolFile %s>' % self.filename
-
     @property
     def fullpath(self):
         return os.path.join(self.location.path, self.filename)
@@ -1087,6 +1209,10 @@ class PoolFile(object):
     def is_valid(self, filesize = -1, md5sum = None):\
         return self.filesize == filesize and self.md5sum == md5sum
 
+    def properties(self):
+        return ['filename', 'file_id', 'filesize', 'md5sum', 'sha1sum', \
+            'sha256sum', 'location', 'source', 'last_used']
+
 __all__.append('PoolFile')
 
 @session_wrapper
@@ -2115,10 +2241,14 @@ def source_exists(source, source_version, suites = ["any"], session=None):
     """
 
     cnf = Config()
-    ret = 1
+    ret = True
+
+    from daklib.regexes import re_bin_only_nmu
+    orig_source_version = re_bin_only_nmu.sub('', source_version)
 
     for suite in suites:
-        q = session.query(DBSource).filter_by(source=source)
+        q = session.query(DBSource).filter_by(source=source). \
+            filter(DBSource.version.in_([source_version, orig_source_version]))
         if suite != "any":
             # source must exist in suite X, or in some other suite that's
             # mapped to X, recursively... silent-maps are counted too,
@@ -2133,24 +2263,13 @@ def source_exists(source, source_version, suites = ["any"], session=None):
                 if x[1] in s and x[0] not in s:
                     s.append(x[0])
 
-            q = q.join(SrcAssociation).join(Suite)
-            q = q.filter(Suite.suite_name.in_(s))
-
-        # Reduce the query results to a list of version numbers
-        ql = [ j.version for j in q.all() ]
-
-        # Try (1)
-        if source_version in ql:
-            continue
+            q = q.filter(DBSource.suites.any(Suite.suite_name.in_(s)))
 
-        # Try (2)
-        from daklib.regexes import re_bin_only_nmu
-        orig_source_version = re_bin_only_nmu.sub('', source_version)
-        if orig_source_version in ql:
+        if q.count() > 0:
             continue
 
         # No source found so return not ok
-        ret = 0
+        ret = False
 
     return ret
 
@@ -2423,17 +2542,6 @@ __all__.append('SourceACL')
 
 ################################################################################
 
-class SrcAssociation(object):
-    def __init__(self, *args, **kwargs):
-        pass
-
-    def __repr__(self):
-        return '<SrcAssociation %s (%s, %s)>' % (self.sa_id, self.source, self.suite)
-
-__all__.append('SrcAssociation')
-
-################################################################################
-
 class SrcFormat(object):
     def __init__(self, *args, **kwargs):
         pass
@@ -2519,8 +2627,7 @@ class Suite(object):
         @return: list of Architecture objects for the given name (may be empty)
         """
 
-        q = object_session(self).query(Architecture). \
-            filter(Architecture.suites.contains(self))
+        q = object_session(self).query(Architecture).with_parent(self)
         if skipsrc:
             q = q.filter(Architecture.arch_string != 'source')
         if skipall:
@@ -2543,7 +2650,7 @@ class Suite(object):
 
         session = object_session(self)
         return session.query(DBSource).filter_by(source = source). \
-            filter(DBSource.suites.contains(self))
+            with_parent(self)
 
 __all__.append('Suite')
 
@@ -2838,10 +2945,11 @@ class DBConn(object):
 
     def __setupmappers(self):
         mapper(Architecture, self.tbl_architecture,
-           properties = dict(arch_id = self.tbl_architecture.c.id,
+            properties = dict(arch_id = self.tbl_architecture.c.id,
                suites = relation(Suite, secondary=self.tbl_suite_architectures,
                    order_by='suite_name',
-                   backref=backref('architectures', order_by='arch_string'))))
+                   backref=backref('architectures', order_by='arch_string'))),
+            extension = validator)
 
         mapper(Archive, self.tbl_archive,
                properties = dict(archive_id = self.tbl_archive.c.id,
@@ -3066,13 +3174,6 @@ class DBConn(object):
         mapper(SourceACL, self.tbl_source_acl,
                properties = dict(source_acl_id = self.tbl_source_acl.c.id))
 
-        mapper(SrcAssociation, self.tbl_src_associations,
-               properties = dict(sa_id = self.tbl_src_associations.c.id,
-                                 suite_id = self.tbl_src_associations.c.suite,
-                                 suite = relation(Suite),
-                                 source_id = self.tbl_src_associations.c.source,
-                                 source = relation(DBSource)))
-
         mapper(SrcFormat, self.tbl_src_format,
                properties = dict(src_format_id = self.tbl_src_format.c.id,
                                  format_name = self.tbl_src_format.c.format_name))