X-Git-Url: https://git.decadent.org.uk/gitweb/?a=blobdiff_plain;f=daklib%2Fdbconn.py;h=ca90ba8882bedaa6dc316d94efd44fb2c09086ae;hb=282d1ce6f56cb8ec1f37b6d9e5c2b9b856141a0e;hp=905bc47d081e5445228f250a13e951353fa943dd;hpb=d5c5510e389b89bddc072009132ed48b9097fd43;p=dak.git diff --git a/daklib/dbconn.py b/daklib/dbconn.py index 905bc47d..ca90ba88 100755 --- a/daklib/dbconn.py +++ b/daklib/dbconn.py @@ -59,21 +59,42 @@ __all__ = ['IntegrityError', 'SQLAlchemyError'] ################################################################################ def session_wrapper(fn): + """ + Wrapper around common ".., session=None):" handling. If the wrapped + function is called without passing 'session', we create a local one + and destroy it when the function ends. + + Also attaches a commit_or_flush method to the session; if we created a + local session, this is a synonym for session.commit(), otherwise it is a + synonym for session.flush(). + """ + def wrapped(*args, **kwargs): private_transaction = False - session = kwargs.get('session') - # No session specified as last argument or in kwargs, create one. - if session is None or len(args) == len(getargspec(fn)[0]) - 1: - private_transaction = True - kwargs['session'] = DBConn().session() + # Find the session object + try: + session = kwargs['session'] + except KeyError: + if len(args) <= len(getargspec(fn)[0]) - 1: + # No session specified as last argument or in kwargs + private_transaction = True + session = kwargs['session'] = DBConn().session() + else: + # Session is last argument in args + session = args[-1] + + if private_transaction: + session.commit_or_flush = session.commit + else: + session.commit_or_flush = session.flush try: return fn(*args, **kwargs) finally: if private_transaction: # We created a session; close it. - kwargs['session'].close() + session.close() wrapped.__doc__ = fn.__doc__ wrapped.func_name = fn.func_name @@ -168,7 +189,7 @@ __all__.append('Archive') @session_wrapper def get_archive(archive, session=None): """ - returns database id for given c{archive}. + returns database id for given C{archive}. @type archive: string @param archive: the name of the arhive @@ -1904,7 +1925,7 @@ def get_suite(suite, session=None): generated if not supplied) @rtype: Suite - @return: Suite object for the requested suite name (None if not presenT) + @return: Suite object for the requested suite name (None if not present) """ q = session.query(Suite).filter_by(suite_name=suite)