From 8c491e536146ef00491001c68d076830ce963ea8 Mon Sep 17 00:00:00 2001 From: Torsten Werner Date: Tue, 1 Feb 2011 07:09:10 +0100 Subject: [PATCH] Implement ORMObject.session() and .clone(). - session() is syntactic sugar - clone() is needed for safe multithreading Signed-off-by: Torsten Werner --- daklib/dbconn.py | 38 +++++++++++++++++++++++++++++++++++++- tests/dbtest_session.py | 37 ++++++++++++++++++++++++++++++++++--- 2 files changed, 71 insertions(+), 4 deletions(-) diff --git a/daklib/dbconn.py b/daklib/dbconn.py index aa71c180..df38b777 100755 --- a/daklib/dbconn.py +++ b/daklib/dbconn.py @@ -55,7 +55,7 @@ from inspect import getargspec import sqlalchemy from sqlalchemy import create_engine, Table, MetaData, Column, Integer, desc from sqlalchemy.orm import sessionmaker, mapper, relation, object_session, \ - backref, MapperExtension, EXT_CONTINUE + backref, MapperExtension, EXT_CONTINUE, object_mapper from sqlalchemy import types as sqltypes # Don't remove this, we re-export the exceptions to scripts which import us @@ -287,6 +287,42 @@ class ORMObject(object): ''' return session.query(cls).get(primary_key) + def session(self, replace = False): + ''' + Returns the current session that is associated with the object. May + return None is object is in detached state. + ''' + + return object_session(self) + + def clone(self, session = None): + ''' + Clones the current object in a new session and returns the new clone. A + fresh session is created if the optional session parameter is not + provided. + + RATIONALE: SQLAlchemy's session is not thread safe. This method allows + cloning of an existing object to allow several threads to work with + their own instances of an ORMObject. + + WARNING: Only persistent (committed) objects can be cloned. + ''' + + if session is None: + session = DBConn().session() + if self.session() is None: + raise RuntimeError('Method clone() failed for detached object:\n%s' % + self) + self.session().flush() + mapper = object_mapper(self) + primary_key = mapper.primary_key_from_instance(self) + object_class = self.__class__ + new_object = session.query(object_class).get(primary_key) + if new_object is None: + raise RuntimeError( \ + 'Method clone() failed for non-persistent object:\n%s' % self) + return new_object + __all__.append('ORMObject') ################################################################################ diff --git a/tests/dbtest_session.py b/tests/dbtest_session.py index 7c378ce9..72c2aff6 100755 --- a/tests/dbtest_session.py +++ b/tests/dbtest_session.py @@ -2,9 +2,8 @@ from db_test import DBDakTestCase -from daklib.dbconn import Uid +from daklib.dbconn import DBConn, Uid -from sqlalchemy.orm import object_session from sqlalchemy.exc import InvalidRequestError import time @@ -93,7 +92,6 @@ class SessionTestCase(DBDakTestCase): uid = Uid(uid = 'foobar') self.session.add(uid) self.assertTrue(uid in self.session) - self.assertEqual(self.session, object_session(uid)) self.session.expunge(uid) self.assertTrue(uid not in self.session) # test close() @@ -138,6 +136,39 @@ class SessionTestCase(DBDakTestCase): self.session.rollback() self.assertRaises(InvalidRequestError, self.refresh) + def test_session(self): + ''' + Tests the ORMObject.session() method. + ''' + + uid = Uid(uid = 'foobar') + self.session.add(uid) + self.assertEqual(self.session, uid.session()) + + def test_clone(self): + ''' + Tests the ORMObject.clone() method. + ''' + + uid1 = Uid(uid = 'foobar') + # no session yet + self.assertRaises(RuntimeError, uid1.clone) + self.session.add(uid1) + # object not persistent yet + self.assertRaises(RuntimeError, uid1.clone) + self.session.commit() + # test without session parameter + uid2 = uid1.clone() + self.assertTrue(uid1 is not uid2) + self.assertEqual(uid1.uid, uid2.uid) + self.assertTrue(uid2 not in uid1.session()) + self.assertTrue(uid1 not in uid2.session()) + # test with explicit session parameter + new_session = DBConn().session() + uid3 = uid1.clone(session = new_session) + self.assertEqual(uid1.uid, uid3.uid) + self.assertTrue(uid3 in new_session) + def classes_to_clean(self): # We need to clean all Uid objects in case some test fails. return (Uid,) -- 2.39.5