]> git.decadent.org.uk Git - dak.git/blobdiff - daklib/dbconn.py
Simplify validation of not NULL constraints.
[dak.git] / daklib / dbconn.py
index 1364539422961acf89e3a422e9cbf66a5f864e18..5e047847e0b3d36054b174461228371bd08c48c9 100755 (executable)
@@ -54,7 +54,8 @@ from inspect import getargspec
 
 import sqlalchemy
 from sqlalchemy import create_engine, Table, MetaData, Column, Integer
-from sqlalchemy.orm import sessionmaker, mapper, relation, object_session, backref
+from sqlalchemy.orm import sessionmaker, mapper, relation, object_session, \
+    backref, MapperExtension, EXT_CONTINUE
 from sqlalchemy import types as sqltypes
 
 # Don't remove this, we re-export the exceptions to scripts which import us
@@ -65,7 +66,7 @@ from sqlalchemy.orm.exc import NoResultFound
 # in the database
 from config import Config
 from textutils import fix_maintainer
-from dak_exceptions import NoSourceFieldError
+from dak_exceptions import DBUpdateError, NoSourceFieldError
 
 # suppress some deprecation warnings in squeeze related to sqlalchemy
 import warnings
@@ -167,7 +168,7 @@ __all__.append('session_wrapper')
 class ORMObject(object):
     """
     ORMObject is a base class for all ORM classes mapped by SQLalchemy. All
-    derived classes must implement the summary() method.
+    derived classes must implement the properties() method.
     """
 
     def properties(self):
@@ -192,7 +193,10 @@ class ORMObject(object):
         for property in all_properties:
             # check for list or query
             if property[-6:] == '_count':
-                value = getattr(self, property[:-6])
+                real_property = property[:-6]
+                if not hasattr(self, real_property):
+                    continue
+                value = getattr(self, real_property)
                 if hasattr(value, '__len__'):
                     # list
                     value = len(value)
@@ -202,17 +206,19 @@ class ORMObject(object):
                 else:
                     raise KeyError('Do not understand property %s.' % property)
             else:
+                if not hasattr(self, property):
+                    continue
                 # plain object
                 value = getattr(self, property)
                 if value is None:
                     # skip None
-                    pass
+                    continue
                 elif isinstance(value, ORMObject):
                     # use repr() for ORMObject types
                     value = repr(value)
                 else:
                     # we want a string for all other types because json cannot
-                    # everything
+                    # encode everything
                     value = str(value)
             data[property] = value
         return json.dumps(data)
@@ -239,10 +245,50 @@ class ORMObject(object):
         '''
         return '<%s %s>' % (self.classname(), self.json())
 
+    def not_null_constraints(self):
+        '''
+        Returns a list of properties that must be not NULL. Derived classes
+        should override this method if needed.
+        '''
+        return []
+
+    validation_message = \
+        "Validation failed because property '%s' must not be empty in object\n%s"
+
+    def validate(self):
+        '''
+        This function validates the not NULL constraints as returned by
+        not_null_constraints(). It raises the DBUpdateError exception if
+        validation fails.
+        '''
+        for property in self.not_null_constraints():
+            if not hasattr(self, property) or getattr(self, property) is None:
+                raise DBUpdateError(self.validation_message % \
+                    (property, str(self)))
+
 __all__.append('ORMObject')
 
 ################################################################################
 
+class Validator(MapperExtension):
+    '''
+    This class calls the validate() method for each instance for the
+    'before_update' and 'before_insert' events. A global object validator is
+    used for configuring the individual mappers.
+    '''
+
+    def before_update(self, mapper, connection, instance):
+        instance.validate()
+        return EXT_CONTINUE
+
+    def before_insert(self, mapper, connection, instance):
+        instance.validate()
+        return EXT_CONTINUE
+
+validator = Validator()
+
+################################################################################
+
 class Architecture(ORMObject):
     def __init__(self, arch_string = None, description = None):
         self.arch_string = arch_string
@@ -263,6 +309,9 @@ class Architecture(ORMObject):
     def properties(self):
         return ['arch_string', 'arch_id', 'suites_count']
 
+    def not_null_constraints(self):
+        return ['arch_string']
+
 __all__.append('Architecture')
 
 @session_wrapper
@@ -1175,6 +1224,9 @@ class PoolFile(ORMObject):
         return ['filename', 'file_id', 'filesize', 'md5sum', 'sha1sum', \
             'sha256sum', 'location', 'source', 'last_used']
 
+    def not_null_constraints(self):
+        return ['filename', 'md5sum', 'location']
+
 __all__.append('PoolFile')
 
 @session_wrapper
@@ -1285,12 +1337,16 @@ __all__.append('add_poolfile')
 
 ################################################################################
 
-class Fingerprint(object):
+class Fingerprint(ORMObject):
     def __init__(self, fingerprint = None):
         self.fingerprint = fingerprint
 
-    def __repr__(self):
-        return '<Fingerprint %s>' % self.fingerprint
+    def properties(self):
+        return ['fingerprint', 'fingerprint_id', 'keyring', 'uid', \
+            'binary_reject']
+
+    def not_null_constraints(self):
+        return ['fingerprint']
 
 __all__.append('Fingerprint')
 
@@ -1575,14 +1631,17 @@ __all__.append('get_dbchange')
 
 ################################################################################
 
-class Location(object):
+class Location(ORMObject):
     def __init__(self, path = None):
         self.path = path
         # the column 'type' should go away, see comment at mapper
         self.archive_type = 'pool'
 
-    def __repr__(self):
-        return '<Location %s (%s)>' % (self.path, self.location_id)
+    def properties(self):
+        return ['path', 'archive_type', 'component', 'files_count']
+
+    def not_null_constraints(self):
+        return ['path', 'archive_type']
 
 __all__.append('Location')
 
@@ -1622,12 +1681,15 @@ __all__.append('get_location')
 
 ################################################################################
 
-class Maintainer(object):
+class Maintainer(ORMObject):
     def __init__(self, name = None):
         self.name = name
 
-    def __repr__(self):
-        return '''<Maintainer '%s' (%s)>''' % (self.name, self.maintainer_id)
+    def properties(self):
+        return ['name', 'maintainer_id']
+
+    def not_null_constraints(self):
+        return ['name']
 
     def get_split_maintainer(self):
         if not hasattr(self, 'name') or self.name is None:
@@ -2161,7 +2223,7 @@ __all__.append('get_sections')
 
 ################################################################################
 
-class DBSource(object):
+class DBSource(ORMObject):
     def __init__(self, source = None, version = None, maintainer = None, \
         changedby = None, poolfile = None, install_date = None):
         self.source = source
@@ -2171,8 +2233,14 @@ class DBSource(object):
         self.poolfile = poolfile
         self.install_date = install_date
 
-    def __repr__(self):
-        return '<DBSource %s (%s)>' % (self.source, self.version)
+    def properties(self):
+        return ['source', 'source_id', 'maintainer', 'changedby', \
+            'fingerprint', 'poolfile', 'version', 'suites_count', \
+            'install_date']
+
+    def not_null_constraints(self):
+        return ['source', 'version', 'maintainer', 'changedby', \
+            'poolfile', 'install_date']
 
 __all__.append('DBSource')
 
@@ -2544,13 +2612,16 @@ SUITE_FIELDS = [ ('SuiteName', 'suite_name'),
 
 # Why the heck don't we have any UNIQUE constraints in table suite?
 # TODO: Add UNIQUE constraints for appropriate columns.
-class Suite(object):
+class Suite(ORMObject):
     def __init__(self, suite_name = None, version = None):
         self.suite_name = suite_name
         self.version = version
 
-    def __repr__(self):
-        return '<Suite %s>' % self.suite_name
+    def properties(self):
+        return ['suite_name', 'version']
+
+    def not_null_constraints(self):
+        return ['suite_name', 'version']
 
     def __eq__(self, val):
         if isinstance(val, str):
@@ -2907,10 +2978,11 @@ class DBConn(object):
 
     def __setupmappers(self):
         mapper(Architecture, self.tbl_architecture,
-           properties = dict(arch_id = self.tbl_architecture.c.id,
+            properties = dict(arch_id = self.tbl_architecture.c.id,
                suites = relation(Suite, secondary=self.tbl_suite_architectures,
                    order_by='suite_name',
-                   backref=backref('architectures', order_by='arch_string'))))
+                   backref=backref('architectures', order_by='arch_string'))),
+            extension = validator)
 
         mapper(Archive, self.tbl_archive,
                properties = dict(archive_id = self.tbl_archive.c.id,
@@ -3003,7 +3075,8 @@ class DBConn(object):
                                      # using lazy='dynamic' in the back
                                      # reference because we have A LOT of
                                      # files in one location
-                                     backref=backref('files', lazy='dynamic'))))
+                                     backref=backref('files', lazy='dynamic'))),
+                extension = validator)
 
         mapper(Fingerprint, self.tbl_fingerprint,
                properties = dict(fingerprint_id = self.tbl_fingerprint.c.id,
@@ -3012,7 +3085,8 @@ class DBConn(object):
                                  keyring_id = self.tbl_fingerprint.c.keyring,
                                  keyring = relation(Keyring),
                                  source_acl = relation(SourceACL),
-                                 binary_acl = relation(BinaryACL)))
+                                 binary_acl = relation(BinaryACL)),
+               extension = validator)
 
         mapper(Keyring, self.tbl_keyrings,
                properties = dict(keyring_name = self.tbl_keyrings.c.name,
@@ -3078,14 +3152,16 @@ class DBConn(object):
                                  archive = relation(Archive),
                                  # FIXME: the 'type' column is old cruft and
                                  # should be removed in the future.
-                                 archive_type = self.tbl_location.c.type))
+                                 archive_type = self.tbl_location.c.type),
+               extension = validator)
 
         mapper(Maintainer, self.tbl_maintainer,
                properties = dict(maintainer_id = self.tbl_maintainer.c.id,
                    maintains_sources = relation(DBSource, backref='maintainer',
                        primaryjoin=(self.tbl_maintainer.c.id==self.tbl_source.c.maintainer)),
                    changed_sources = relation(DBSource, backref='changedby',
-                       primaryjoin=(self.tbl_maintainer.c.id==self.tbl_source.c.changedby))))
+                       primaryjoin=(self.tbl_maintainer.c.id==self.tbl_source.c.changedby))),
+                extension = validator)
 
         mapper(NewComment, self.tbl_new_comments,
                properties = dict(comment_id = self.tbl_new_comments.c.id))
@@ -3130,7 +3206,8 @@ class DBConn(object):
                                                      primaryjoin=(self.tbl_source.c.id==self.tbl_dsc_files.c.source)),
                                  suites = relation(Suite, secondary=self.tbl_src_associations,
                                      backref='sources'),
-                                 srcuploaders = relation(SrcUploader)))
+                                 srcuploaders = relation(SrcUploader)),
+               extension = validator)
 
         mapper(SourceACL, self.tbl_source_acl,
                properties = dict(source_acl_id = self.tbl_source_acl.c.id))
@@ -3151,7 +3228,9 @@ class DBConn(object):
         mapper(Suite, self.tbl_suite,
                properties = dict(suite_id = self.tbl_suite.c.id,
                                  policy_queue = relation(PolicyQueue),
-                                 copy_queues = relation(BuildQueue, secondary=self.tbl_suite_build_queue_copy)))
+                                 copy_queues = relation(BuildQueue,
+                                     secondary=self.tbl_suite_build_queue_copy)),
+                extension = validator)
 
         mapper(SuiteSrcFormat, self.tbl_suite_src_formats,
                properties = dict(suite_id = self.tbl_suite_src_formats.c.suite,