]> git.decadent.org.uk Git - dak.git/blobdiff - daklib/dbconn.py
Move get_queue => get_or_set_queue
[dak.git] / daklib / dbconn.py
index 2f5fef30f9866d2ee158b010fe89ed9cae5d9bb2..6d5497fc2d5f4b096d972631637e284eb2ce00aa 100755 (executable)
@@ -39,7 +39,7 @@ import traceback
 
 from inspect import getargspec
 
-from sqlalchemy import create_engine, Table, MetaData, select
+from sqlalchemy import create_engine, Table, MetaData
 from sqlalchemy.orm import sessionmaker, mapper, relation
 
 # Don't remove this, we re-export the exceptions to scripts which import us
@@ -73,9 +73,9 @@ def session_wrapper(fn):
         private_transaction = False
 
         # Find the session object
-        try:
-            session = kwargs['session']
-        except KeyError:
+        session = kwargs.get('session')
+
+        if session is None:
             if len(args) <= len(getargspec(fn)[0]) - 1:
                 # No session specified as last argument or in kwargs
                 private_transaction = True
@@ -83,6 +83,10 @@ def session_wrapper(fn):
             else:
                 # Session is last argument in args
                 session = args[-1]
+                if session is None:
+                    args = list(args)
+                    session = args[-1] = DBConn().session()
+                    private_transaction = True
 
         if private_transaction:
             session.commit_or_flush = session.commit
@@ -1414,6 +1418,7 @@ class Queue(object):
             # them (if one doesn't already exist)
             for dsc_file in changes.dsc_files.keys():
                 # Skip all files except orig tarballs
+                from daklib.regexes import re_is_orig_source
                 if not re_is_orig_source.match(dsc_file):
                     continue
                 # Skip orig files not identified in the pool
@@ -1459,9 +1464,10 @@ class Queue(object):
 __all__.append('Queue')
 
 @session_wrapper
-def get_queue(queuename, session=None):
+def get_or_set_queue(queuename, session=None):
     """
-    Returns Queue object for given C{queue name}.
+    Returns Queue object for given C{queue name}, creating it if it does not
+    exist.
 
     @type queuename: string
     @param queuename: The name of the queue
@@ -1477,11 +1483,17 @@ def get_queue(queuename, session=None):
     q = session.query(Queue).filter_by(queue_name=queuename)
 
     try:
-        return q.one()
+        ret = q.one()
     except NoResultFound:
-        return None
+        queue = Queue()
+        queue.queue_name = queuename
+        session.add(queue)
+        session.commit_or_flush()
+        ret = queue
+
+    return ret
 
-__all__.append('get_queue')
+__all__.append('get_or_set_queue')
 
 ################################################################################