git.decadent.org.uk Git - dak.git/commitdiff
check in old SQLa work
author Mark Hymers <mhy@debian.org>
Sun, 5 Apr 2009 19:14:16 +0000 (20:14 +0100)
committer Mark Hymers <mhy@debian.org>
Sun, 9 Aug 2009 15:49:19 +0000 (16:49 +0100)
Signed-off-by: Mark Hymers <mhy@debian.org>
daklib/dbconn.py

index 58dd7fc55520c09bda083d6ca51ec9536f28f971..f64a1fec144ee914a6a2a7de697d333505156736 100755 (executable)
@@ -37,6 +37,9 @@ import os
 import psycopg2
 import traceback
 
+from sqlalchemy import create_engine, Table, MetaData, select
+from sqlalchemy.orm import sessionmaker
+
 from singleton import Singleton
 from config import Config
 
@@ -47,7 +50,7 @@ class Cache(object):
         if hashfunc:
             self.hashfunc = hashfunc
         else:
-            self.hashfunc = lambda x: x['value']
+            self.hashfunc = lambda x: str(x)
 
         self.data = {}
 
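
The default cache hash function now stringifies whatever key it is given, because the reworked lookups below pass plain strings (suite, section, priority names) as well as dicts (as in get_location_id) as cache keys, rather than the old {'value': ...} query dicts. A minimal sketch of the idea, using a hypothetical stand-in class since the real GetValue/SetValue bodies are not part of this diff:

    # Hypothetical stand-in for daklib's Cache, only to illustrate why
    # str(x) works as a default hashfunc for both string and dict keys.
    class DemoCache(object):
        def __init__(self, hashfunc=None):
            self.hashfunc = hashfunc if hashfunc else (lambda x: str(x))
            self.data = {}

        def SetValue(self, key, value):
            self.data[self.hashfunc(key)] = value

        def GetValue(self, key):
            return self.data.get(self.hashfunc(key))

    c = DemoCache()
    c.SetValue('unstable', 5)                                      # plain string key
    c.SetValue({'archive': 'ftp-master', 'component': 'main'}, 3)  # dict key, str()-ified
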
@@ -70,25 +73,62 @@ class DBConn(Singleton):
         self.__createconn()
         self.__init_caches()
 
+    def __setuptables(self):
+        self.tbl_architecture = Table('architecture', self.db_meta, autoload=True)
+        self.tbl_archive = Table('archive', self.db_meta, autoload=True)
+        self.tbl_bin_associations = Table('bin_associations', self.db_meta, autoload=True)
+        self.tbl_binaries = Table('binaries', self.db_meta, autoload=True)
+        self.tbl_component = Table('component', self.db_meta, autoload=True)
+        self.tbl_config = Table('config', self.db_meta, autoload=True)
+        self.tbl_content_associations = Table('content_associations', self.db_meta, autoload=True)
+        self.tbl_content_file_names = Table('content_file_names', self.db_meta, autoload=True)
+        self.tbl_content_file_paths = Table('content_file_paths', self.db_meta, autoload=True)
+        self.tbl_dsc_files = Table('dsc_files', self.db_meta, autoload=True)
+        self.tbl_files = Table('files', self.db_meta, autoload=True)
+        self.tbl_fingerprint = Table('fingerprint', self.db_meta, autoload=True)
+        self.tbl_keyrings = Table('keyrings', self.db_meta, autoload=True)
+        self.tbl_location = Table('location', self.db_meta, autoload=True)
+        self.tbl_maintainer = Table('maintainer', self.db_meta, autoload=True)
+        self.tbl_override = Table('override', self.db_meta, autoload=True)
+        self.tbl_override_type = Table('override_type', self.db_meta, autoload=True)
+        self.tbl_pending_content_associations = Table('pending_content_associations', self.db_meta, autoload=True)
+        self.tbl_priority = Table('priority', self.db_meta, autoload=True)
+        self.tbl_queue = Table('queue', self.db_meta, autoload=True)
+        self.tbl_queue_build = Table('queue_build', self.db_meta, autoload=True)
+        self.tbl_section = Table('section', self.db_meta, autoload=True)
+        self.tbl_source = Table('source', self.db_meta, autoload=True)
+        self.tbl_src_associations = Table('src_associations', self.db_meta, autoload=True)
+        self.tbl_src_uploaders = Table('src_uploaders', self.db_meta, autoload=True)
+        self.tbl_suite = Table('suite', self.db_meta, autoload=True)
+        self.tbl_suite_architectures = Table('suite_architectures', self.db_meta, autoload=True)
+        self.tbl_uid = Table('uid', self.db_meta, autoload=True)
+
     ## Connection functions
     def __createconn(self):
         cnf = Config()
-        connstr = "dbname=%s" % cnf["DB::Name"]
         if cnf["DB::Host"]:
-           connstr += " host=%s" % cnf["DB::Host"]
-        if cnf["DB::Port"] and cnf["DB::Port"] != "-1":
-           connstr += " port=%s" % cnf["DB::Port"]
+            # TCP/IP
+            connstr = "postgres://%s" % cnf["DB::Host"]
+            if cnf["DB::Port"] and cnf["DB::Port"] != "-1":
+                connstr += ":%s" % cnf["DB::Port"]
+            connstr += "/%s" % cnf["DB::Name"]
+        else:
+            # Unix Socket
+            connstr = "postgres:///%s" % cnf["DB::Name"]
+            if cnf["DB::Port"] and cnf["DB::Port"] != "-1":
+                connstr += "?port=%s" % cnf["DB::Port"]
 
-        self.db_con = psycopg2.connect(connstr)
+        self.db_pg   = create_engine(connstr)
+        self.db_meta = MetaData()
+        self.db_meta.bind = self.db_pg
+        self.db_smaker = sessionmaker(bind=self.db_pg,
+                                      autoflush=True,
+                                      transactional=True)
 
-    def reconnect(self):
-        try:
-            self.db_con.close()
-        except psycopg2.InterfaceError:
-            pass
+        self.__setuptables()
 
-        self.db_con = None
-        self.__createconn()
+    def session(self):
+        return self.db_smaker()
 
     ## Cache functions
     def __init_caches(self):
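
For reference, a minimal standalone sketch of the connection setup this hunk introduces, using the same SQLAlchemy 0.4/0.5-era calls the diff itself uses (create_engine, bound MetaData with autoload=True, sessionmaker with transactional=True). The database name, host and port values below are placeholders, not values taken from dak's configuration:

    from sqlalchemy import create_engine, MetaData, Table
    from sqlalchemy.orm import sessionmaker

    def build_dsn(name, host=None, port=None):
        # No host means psycopg2 connects over the local Unix-domain socket,
        # mirroring the two branches in __createconn above.
        if host:
            dsn = "postgres://%s" % host
            if port and port != "-1":
                dsn += ":%s" % port
            return dsn + "/%s" % name
        dsn = "postgres:///%s" % name
        if port and port != "-1":
            dsn += "?port=%s" % port
        return dsn

    engine = create_engine(build_dsn("projectb"))        # placeholder DB name
    meta = MetaData()
    meta.bind = engine                                   # bound metadata, as in __createconn
    tbl_suite = Table('suite', meta, autoload=True)      # reflect columns from the live DB
    Session = sessionmaker(bind=engine, autoflush=True, transactional=True)
    session = Session()                                  # what DBConn.session() hands out
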
@@ -117,45 +157,37 @@ class DBConn(Singleton):
 
     def prepare(self,name,statement):
         if not self.prepared_statements.has_key(name):
-            c = self.cursor()
-            c.execute(statement)
+            self.db_pg.execute(statement)
             self.prepared_statements[name] = statement
 
     def clear_caches(self):
         self.__init_caches()
 
-    ## Functions to pass through to the database connector
-    def cursor(self):
-        return self.db_con.cursor()
-
-    def commit(self):
-        return self.db_con.commit()
-
     ## Get functions
-    def __get_single_id(self, query, values, cachename=None):
+    def __get_id(self, retfield, selectobj, cachekey, cachename=None):
         # This is a bit of a hack but it's an internal function only
         if cachename is not None:
-            res = self.caches[cachename].GetValue(values)
+            res = self.caches[cachename].GetValue(cachekey)
             if res:
                 return res
 
-        c = self.db_con.cursor()
-        c.execute(query, values)
+        c = selectobj.execute()
 
         if c.rowcount != 1:
             return None
 
-        res = c.fetchone()[0]
+        res = c.fetchone()
+
+        if retfield not in res.keys():
+            return None
+
+        res = res[retfield]
 
         if cachename is not None:
-            self.caches[cachename].SetValue(values, res)
+            self.caches[cachename].SetValue(cachekey, res)
 
         return res
 
-    def __get_id(self, retfield, table, qfield, value):
-        query = "SELECT %s FROM %s WHERE %s = %%(value)s" % (retfield, table, qfield)
-        return self.__get_single_id(query, {'value': value}, cachename=table)
-
     def get_suite_id(self, suite):
         """
         Returns database id for given C{suite}.
@@ -168,11 +200,10 @@ class DBConn(Singleton):
         @return: the database id for the given suite
 
         """
-        suiteid = self.__get_id('id', 'suite', 'suite_name', suite)
-        if suiteid is None:
-            return None
-        else:
-            return int(suiteid)
+        return int(self.__get_id('id',
+                                 self.tbl_suite.select(self.tbl_suite.columns.suite_name == suite),
+                                 suite,
+                                 'suite'))
 
     def get_section_id(self, section):
         """
@@ -186,7 +217,10 @@ class DBConn(Singleton):
         @return: the database id for the given section
 
         """
-        return self.__get_id('id', 'section', 'section', section)
+        return self.__get_id('id',
+                             self.tbl_section.select(self.tbl_section.columns.section == section),
+                             section,
+                             'section')
 
     def get_priority_id(self, priority):
         """
@@ -200,7 +234,10 @@ class DBConn(Singleton):
         @return: the database id for the given priority
 
         """
-        return self.__get_id('id', 'priority', 'priority', priority)
+        return self.__get_id('id',
+                             self.tbl_priority.select(self.tbl_priority.columns.priority == priority),
+                             priority,
+                             'priority')
 
     def get_override_type_id(self, override_type):
         """
@@ -214,7 +251,10 @@ class DBConn(Singleton):
         @return: the database id for the given override type
 
         """
-        return self.__get_id('id', 'override_type', 'type', override_type)
+        return self.__get_id('id',
+                             self.tbl_override_type.select(self.tbl_override_type.columns.type == override_type),
+                             override_type,
+                             'override_type')
 
     def get_architecture_id(self, architecture):
         """
@@ -228,7 +268,10 @@ class DBConn(Singleton):
         @return: the database id for the given architecture
 
         """
-        return self.__get_id('id', 'architecture', 'arch_string', architecture)
+        return self.__get_id('id',
+                             self.tbl_architecture.select(self.tbl_architecture.columns.arch_string == architecture),
+                             architecture,
+                             'architecture')
 
     def get_archive_id(self, archive):
         """
@@ -242,7 +285,11 @@ class DBConn(Singleton):
         @return: the database id for the given archive
 
         """
-        return self.__get_id('id', 'archive', 'lower(name)', archive)
+        archive = archive.lower()
+        return self.__get_id('id',
+                             self.tbl_archive.select(self.tbl_archive.columns.name == archive),
+                             archive,
+                             'archive')
 
     def get_component_id(self, component):
         """
@@ -256,7 +303,11 @@ class DBConn(Singleton):
         @return: the database id for the given component
 
         """
-        return self.__get_id('id', 'component', 'lower(name)', component)
+        component = component.lower()
+        return self.__get_id('id',
+                             self.tbl_component.select(self.tbl_component.columns.name == component),
+                             component,
+                             'component')
 
     def get_location_id(self, location, component, archive):
         """
@@ -269,36 +320,29 @@ class DBConn(Singleton):
         @type location: string
         @param location: the path of the location
 
-        @type component: int
-        @param component: the id of the component
+        @type component: string
+        @param component: the name of the component
 
-        @type archive: int
-        @param archive: the id of the archive
+        @type archive: string
+        @param archive: the name of the archive
 
         @rtype: int
         @return: the database id for the location
 
         """
 
-        archive_id = self.get_archive_id(archive)
+        archive = archive.lower()
+        component = component.lower()
 
-        if not archive_id:
-            return None
+        values = {'archive': archive, 'location': location, 'component': component}
 
-        res = None
+        s = self.tbl_location.join(self.tbl_archive).join(self.tbl_component)
 
-        if component:
-            component_id = self.get_component_id(component)
-            if component_id:
-                res = self.__get_single_id("SELECT id FROM location WHERE path=%(location)s AND component=%(component)s AND archive=%(archive)s",
-                        {'location': location,
-                         'archive': int(archive_id),
-                         'component': component_id}, cachename='location')
-        else:
-            res = self.__get_single_id("SELECT id FROM location WHERE path=%(location)s AND archive=%(archive)d",
-                    {'location': location, 'archive': archive_id, 'component': ''}, cachename='location')
+        s = s.select(self.tbl_location.columns.path == location)
+        s = s.where(self.tbl_archive.columns.name == archive)
+        s = s.where(self.tbl_component.columns.name == component)
 
-        return res
+        return self.__get_id('location.id', s, values, 'location')
 
     def get_source_id(self, source, version):
         """
@@ -317,8 +361,24 @@ class DBConn(Singleton):
         @return: the database id for the source
 
         """
-        return self.__get_single_id("SELECT id FROM source s WHERE s.source=%(source)s AND s.version=%(version)s",
-                                 {'source': source, 'version': version}, cachename='source')
+        s = self.tbl_source.select()
+        s = s.where(self.tbl_source.columns.source  == source)
+        s = s.where(self.tbl_source.columns.version == version)
+
+        return self.__get_id('id', s, {'source': source, 'version': version}, 'source')
+
+    def get_suite(self, suite):
+        if isinstance(suite, str):
+            suite_id = self.get_suite_id(suite.lower())
+        elif type(suite) == int:
+            suite_id = suite
+
+        s = self.tbl_suite.select(self.tbl_suite.columns.id == suite_id)
+        c = s.execute()
+        if c.rowcount < 1:
+            return None
+        else:
+            return c.fetchone()
 
     def get_suite_version(self, source, suite):
         """
@@ -339,12 +399,13 @@ class DBConn(Singleton):
         @return: the version for I{source} in I{suite}
 
         """
-        return self.__get_single_id("""
-        SELECT s.version FROM source s, suite su, src_associations sa
-        WHERE sa.source=s.id
-          AND sa.suite=su.id
-          AND su.suite_name=%(suite)s
-          AND s.source=%(source)""", {'suite': suite, 'source': source}, cachename='suite_version')
+        s = select([self.tbl_source.columns.source, self.tbl_source.columns.version])
+#        s = self.tbl_source.join(self.tbl_src_associations).join(self.tbl_suite)
+
+        s = s.select(self.tbl_suite.columns.suite_name == suite, use_labels=True)
+        s = s.select(self.tbl_source.columns.source == source)
+
+        return self.__get_id('source.version', s, {'suite': suite, 'source': source}, 'suite_version')
 
 
     def get_files_id (self, filename, size, md5sum, location_id):
@@ -582,3 +643,4 @@ class DBConn(Singleton):
             traceback.print_exc()
             c.execute("ROLLBACK")
             return False
+
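
Taken together, callers keep the existing singleton interface, with an ORM session now available on top of it. A short usage sketch (the suite name is an example value only):

    from daklib.dbconn import DBConn

    db = DBConn()                           # Singleton: same instance everywhere
    suite_id = db.get_suite_id('unstable')  # cached, select()-backed lookup
    session = db.session()                  # a fresh SQLAlchemy ORM session
    # ... work with the session, then:
    session.close()
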