From: Torsten Werner Date: Mon, 21 Mar 2011 21:26:21 +0000 (+0100) Subject: Add method reset() to class DBConn() and test it. X-Git-Url: https://git.decadent.org.uk/gitweb/?a=commitdiff_plain;h=a881423355b3e9455f44bb88d806bc04a7c7e7e5;p=dak.git Add method reset() to class DBConn() and test it. Signed-off-by: Torsten Werner --- diff --git a/daklib/dbconn.py b/daklib/dbconn.py index 1bf5a0f9..6cd84de3 100755 --- a/daklib/dbconn.py +++ b/daklib/dbconn.py @@ -59,7 +59,7 @@ import sqlalchemy from sqlalchemy import create_engine, Table, MetaData, Column, Integer, desc, \ Text, ForeignKey from sqlalchemy.orm import sessionmaker, mapper, relation, object_session, \ - backref, MapperExtension, EXT_CONTINUE, object_mapper + backref, MapperExtension, EXT_CONTINUE, object_mapper, clear_mappers from sqlalchemy import types as sqltypes # Don't remove this, we re-export the exceptions to scripts which import us @@ -3198,6 +3198,15 @@ class DBConn(object): def session(self): return self.db_smaker() + def reset(self): + ''' + Resets the DBConn object. This function must be called by subprocesses + created by the multiprocessing module. See tests/dbtest_multiproc.py + for an example. + ''' + clear_mappers() + self.__createconn() + __all__.append('DBConn') diff --git a/tests/dbtest_multiproc.py b/tests/dbtest_multiproc.py new file mode 100755 index 00000000..f4c2a37a --- /dev/null +++ b/tests/dbtest_multiproc.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python + +from db_test import DBDakTestCase + +from daklib.dbconn import DBConn + +from multiprocessing import Pool +from time import sleep +import unittest + +def read_number(): + DBConn().reset() + session = DBConn().session() + result = session.query('foo').from_statement('select 7 as foo').scalar() + sleep(0.1) + session.close() + return result + +class MultiProcTestCase(DBDakTestCase): + """ + This TestCase checks that DBConn works with multiprocessing. A fresh + subprocess needs to call reset() on DBConn(). See function read_number() + for an example. + """ + + def save_result(self, result): + self.result += result + + def test_seven(self): + ''' + Test apply_async() with a database session. + ''' + self.result = 0 + pool = Pool() + pool.apply_async(read_number, (), callback = self.save_result) + pool.apply_async(read_number, (), callback = self.save_result) + pool.apply_async(read_number, (), callback = self.save_result) + pool.apply_async(read_number, (), callback = self.save_result) + pool.apply_async(read_number, (), callback = self.save_result) + pool.close() + pool.join() + self.assertEqual(5 * 7, self.result) + +if __name__ == '__main__': + unittest.main()