]> git.decadent.org.uk Git - dak.git/blob - daklib/dbconn.py
check in old SQLa work
[dak.git] / daklib / dbconn.py
1 #!/usr/bin/python
2
3 """ DB access class
4
5 @contact: Debian FTPMaster <ftpmaster@debian.org>
6 @copyright: 2000, 2001, 2002, 2003, 2004, 2006  James Troup <james@nocrew.org>
7 @copyright: 2008-2009  Mark Hymers <mhy@debian.org>
8 @copyright: 2009  Joerg Jaspert <joerg@debian.org>
9 @copyright: 2009  Mike O'Connor <stew@debian.org>
10 @license: GNU General Public License version 2 or later
11 """
12
13 # This program is free software; you can redistribute it and/or modify
14 # it under the terms of the GNU General Public License as published by
15 # the Free Software Foundation; either version 2 of the License, or
16 # (at your option) any later version.
17
18 # This program is distributed in the hope that it will be useful,
19 # but WITHOUT ANY WARRANTY; without even the implied warranty of
20 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
21 # GNU General Public License for more details.
22
23 # You should have received a copy of the GNU General Public License
24 # along with this program; if not, write to the Free Software
25 # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
26
27 ################################################################################
28
29 # < mhy> I need a funny comment
30 # < sgran> two peanuts were walking down a dark street
31 # < sgran> one was a-salted
32 #  * mhy looks up the definition of "funny"
33
34 ################################################################################
35
36 import os
37 import psycopg2
38 import traceback
39
40 from sqlalchemy import create_engine, Table, MetaData, select
41 from sqlalchemy.orm import sessionmaker
42
43 from singleton import Singleton
44 from config import Config
45
46 ################################################################################
47
48 class Cache(object):
49     def __init__(self, hashfunc=None):
50         if hashfunc:
51             self.hashfunc = hashfunc
52         else:
53             self.hashfunc = lambda x: str(x)
54
55         self.data = {}
56
57     def SetValue(self, keys, value):
58         self.data[self.hashfunc(keys)] = value
59
60     def GetValue(self, keys):
61         return self.data.get(self.hashfunc(keys))
62
63 ################################################################################
64
65 class DBConn(Singleton):
66     """
67     database module init.
68     """
69     def __init__(self, *args, **kwargs):
70         super(DBConn, self).__init__(*args, **kwargs)
71
72     def _startup(self, *args, **kwargs):
73         self.__createconn()
74         self.__init_caches()
75
76     def __setuptables(self):
77         self.tbl_architecture = Table('architecture', self.db_meta, autoload=True)
78         self.tbl_archive = Table('archive', self.db_meta, autoload=True)
79         self.tbl_bin_associations = Table('bin_associations', self.db_meta, autoload=True)
80         self.tbl_binaries = Table('binaries', self.db_meta, autoload=True)
81         self.tbl_component = Table('component', self.db_meta, autoload=True)
82         self.tbl_config = Table('config', self.db_meta, autoload=True)
83         self.tbl_content_associations = Table('content_associations', self.db_meta, autoload=True)
84         self.tbl_content_file_names = Table('content_file_names', self.db_meta, autoload=True)
85         self.tbl_content_file_paths = Table('content_file_paths', self.db_meta, autoload=True)
86         self.tbl_dsc_files = Table('dsc_files', self.db_meta, autoload=True)
87         self.tbl_files = Table('files', self.db_meta, autoload=True)
88         self.tbl_fingerprint = Table('fingerprint', self.db_meta, autoload=True)
89         self.tbl_keyrings = Table('keyrings', self.db_meta, autoload=True)
90         self.tbl_location = Table('location', self.db_meta, autoload=True)
91         self.tbl_maintainer = Table('maintainer', self.db_meta, autoload=True)
92         self.tbl_override = Table('override', self.db_meta, autoload=True)
93         self.tbl_override_type = Table('override_type', self.db_meta, autoload=True)
94         self.tbl_pending_content_associations = Table('pending_content_associations', self.db_meta, autoload=True)
95         self.tbl_priority = Table('priority', self.db_meta, autoload=True)
96         self.tbl_queue = Table('queue', self.db_meta, autoload=True)
97         self.tbl_queue_build = Table('queue_build', self.db_meta, autoload=True)
98         self.tbl_section = Table('section', self.db_meta, autoload=True)
99         self.tbl_source = Table('source', self.db_meta, autoload=True)
100         self.tbl_src_associations = Table('src_associations', self.db_meta, autoload=True)
101         self.tbl_src_uploaders = Table('src_uploaders', self.db_meta, autoload=True)
102         self.tbl_suite = Table('suite', self.db_meta, autoload=True)
103         self.tbl_suite_architectures = Table('suite_architectures', self.db_meta, autoload=True)
104         self.tbl_uid = Table('uid', self.db_meta, autoload=True)
105
106     ## Connection functions
107     def __createconn(self):
108         cnf = Config()
109         if cnf["DB::Host"]:
110             # TCP/IP
111             connstr = "postgres://%s" % cnf["DB::Host"]
112             if cnf["DB::Port"] and cnf["DB::Port"] != "-1":
113                 connstr += ":%s" % cnf["DB::Port"]
114             connstr += "/%s" % cnf["DB::Name"]
115         else:
116             # Unix Socket
117             connstr = "postgres:///%s" % cnf["DB::Name"]
118             if cnf["DB::Port"] and cnf["DB::Port"] != "-1":
119                 connstr += "?port=%s" % cnf["DB::Port"]
120
121         self.db_pg   = create_engine(connstr)
122         self.db_meta = MetaData()
123         self.db_meta.bind = self.db_pg
124         self.db_smaker = sessionmaker(bind=self.db_pg,
125                                       autoflush=True,
126                                       transactional=True)
127
128         self.__setuptables()
129
130     def session(self):
131         return self.db_smaker()
132
133     ## Cache functions
134     def __init_caches(self):
135         self.caches = {'suite':         Cache(),
136                        'section':       Cache(),
137                        'priority':      Cache(),
138                        'override_type': Cache(),
139                        'architecture':  Cache(),
140                        'archive':       Cache(),
141                        'component':     Cache(),
142                        'content_path_names':     Cache(),
143                        'content_file_names':     Cache(),
144                        'location':      Cache(lambda x: '%s_%s_%s' % (x['location'], x['component'], x['location'])),
145                        'maintainer':    {}, # TODO
146                        'keyring':       {}, # TODO
147                        'source':        Cache(lambda x: '%s_%s_' % (x['source'], x['version'])),
148                        'files':         Cache(lambda x: '%s_%s_' % (x['filename'], x['location'])),
149                        'maintainer':    {}, # TODO
150                        'fingerprint':   {}, # TODO
151                        'queue':         {}, # TODO
152                        'uid':           {}, # TODO
153                        'suite_version': Cache(lambda x: '%s_%s' % (x['source'], x['suite'])),
154                       }
155
156         self.prepared_statements = {}
157
158     def prepare(self,name,statement):
159         if not self.prepared_statements.has_key(name):
160             pgc.execute(statement)
161             self.prepared_statements[name] = statement
162
163     def clear_caches(self):
164         self.__init_caches()
165
166     ## Get functions
167     def __get_id(self, retfield, selectobj, cachekey, cachename=None):
168         # This is a bit of a hack but it's an internal function only
169         if cachename is not None:
170             res = self.caches[cachename].GetValue(cachekey)
171             if res:
172                 return res
173
174         c = selectobj.execute()
175
176         if c.rowcount != 1:
177             return None
178
179         res = c.fetchone()
180
181         if retfield not in res.keys():
182             return None
183
184         res = res[retfield]
185
186         if cachename is not None:
187             self.caches[cachename].SetValue(cachekey, res)
188
189         return res
190
191     def get_suite_id(self, suite):
192         """
193         Returns database id for given C{suite}.
194         Results are kept in a cache during runtime to minimize database queries.
195
196         @type suite: string
197         @param suite: The name of the suite
198
199         @rtype: int
200         @return: the database id for the given suite
201
202         """
203         return int(self.__get_id('id',
204                                  self.tbl_suite.select(self.tbl_suite.columns.suite_name == suite),
205                                  suite,
206                                  'suite'))
207
208     def get_section_id(self, section):
209         """
210         Returns database id for given C{section}.
211         Results are kept in a cache during runtime to minimize database queries.
212
213         @type section: string
214         @param section: The name of the section
215
216         @rtype: int
217         @return: the database id for the given section
218
219         """
220         return self.__get_id('id',
221                              self.tbl_section.select(self.tbl_section.columns.section == section),
222                              section,
223                              'section')
224
225     def get_priority_id(self, priority):
226         """
227         Returns database id for given C{priority}.
228         Results are kept in a cache during runtime to minimize database queries.
229
230         @type priority: string
231         @param priority: The name of the priority
232
233         @rtype: int
234         @return: the database id for the given priority
235
236         """
237         return self.__get_id('id',
238                              self.tbl_priority.select(self.tbl_priority.columns.priority == priority),
239                              priority,
240                              'priority')
241
242     def get_override_type_id(self, override_type):
243         """
244         Returns database id for given override C{type}.
245         Results are kept in a cache during runtime to minimize database queries.
246
247         @type override_type: string
248         @param override_type: The name of the override type
249
250         @rtype: int
251         @return: the database id for the given override type
252
253         """
254         return self.__get_id('id',
255                              self.tbl_override_type.select(self.tbl_override_type.columns.type == override_type),
256                              override_type,
257                              'override_type')
258
259     def get_architecture_id(self, architecture):
260         """
261         Returns database id for given C{architecture}.
262         Results are kept in a cache during runtime to minimize database queries.
263
264         @type architecture: string
265         @param architecture: The name of the override type
266
267         @rtype: int
268         @return: the database id for the given architecture
269
270         """
271         return self.__get_id('id',
272                              self.tbl_architecture.select(self.tbl_architecture.columns.arch_string == architecture),
273                              architecture,
274                              'architecture')
275
276     def get_archive_id(self, archive):
277         """
278         returns database id for given c{archive}.
279         results are kept in a cache during runtime to minimize database queries.
280
281         @type archive: string
282         @param archive: the name of the override type
283
284         @rtype: int
285         @return: the database id for the given archive
286
287         """
288         archive = archive.lower()
289         return self.__get_id('id',
290                              self.tbl_archive.select(self.tbl_archive.columns.name == archive),
291                              archive,
292                              'archive')
293
294     def get_component_id(self, component):
295         """
296         Returns database id for given C{component}.
297         Results are kept in a cache during runtime to minimize database queries.
298
299         @type component: string
300         @param component: The name of the override type
301
302         @rtype: int
303         @return: the database id for the given component
304
305         """
306         component = component.lower()
307         return self.__get_id('id',
308                              self.tbl_component.select(self.tbl_component.columns.name == component),
309                              component.lower(),
310                              'component')
311
312     def get_location_id(self, location, component, archive):
313         """
314         Returns database id for the location behind the given combination of
315           - B{location} - the path of the location, eg. I{/srv/ftp.debian.org/ftp/pool/}
316           - B{component} - the id of the component as returned by L{get_component_id}
317           - B{archive} - the id of the archive as returned by L{get_archive_id}
318         Results are kept in a cache during runtime to minimize database queries.
319
320         @type location: string
321         @param location: the path of the location
322
323         @type component: string
324         @param component: the name of the component
325
326         @type archive: string
327         @param archive: the name of the archive
328
329         @rtype: int
330         @return: the database id for the location
331
332         """
333
334         archive = archive.lower()
335         component = component.lower()
336
337         values = {'archive': archive, 'location': location, 'component': component}
338
339         s = self.tbl_location.join(self.tbl_archive).join(self.tbl_component)
340
341         s = s.select(self.tbl_location.columns.path == location)
342         s = s.where(self.tbl_archive.columns.name == archive)
343         s = s.where(self.tbl_component.columns.name == component)
344
345         return self.__get_id('location.id', s, values, 'location')
346
347     def get_source_id(self, source, version):
348         """
349         Returns database id for the combination of C{source} and C{version}
350           - B{source} - source package name, eg. I{mailfilter}, I{bbdb}, I{glibc}
351           - B{version}
352         Results are kept in a cache during runtime to minimize database queries.
353
354         @type source: string
355         @param source: source package name
356
357         @type version: string
358         @param version: the source version
359
360         @rtype: int
361         @return: the database id for the source
362
363         """
364         s = self.tbl_source.select()
365         s = s.where(self.tbl_source.columns.source  == source)
366         s = s.where(self.tbl_source.columns.version == version)
367
368         return self.__get_id('id', s, {'source': source, 'version': version}, 'source')
369
370     def get_suite(self, suite):
371         if isinstance(suite, str):
372             suite_id = self.get_suite_id(suite.lower())
373         elif type(suite) == int:
374             suite_id = suite
375
376         s = self.tbl_suite.select(self.tbl_suite.columns.id == suite_id)
377         c = s.execute()
378         if c.rowcount < 1:
379             return None
380         else:
381             return c.fetchone()
382
383     def get_suite_version(self, source, suite):
384         """
385         Returns database id for a combination of C{source} and C{suite}.
386
387           - B{source} - source package name, eg. I{mailfilter}, I{bbdb}, I{glibc}
388           - B{suite} - a suite name, eg. I{unstable}
389
390         Results are kept in a cache during runtime to minimize database queries.
391
392         @type source: string
393         @param source: source package name
394
395         @type suite: string
396         @param suite: the suite name
397
398         @rtype: string
399         @return: the version for I{source} in I{suite}
400
401         """
402         s = select([self.tbl_source.columns.source, self.tbl_source.columns.version])
403 #        s = self.tbl_source.join(self.tbl_src_associations).join(self.tbl_suite)
404
405         s = s.select(self.tbl_suite.columns.suite_name == suite, use_labels=True)
406         s = s.select(self.tbl_source.columns.source == source)
407
408         return self.__get_id('source.version', s, {'suite': suite, 'source': source}, 'suite_version')
409
410
411     def get_files_id (self, filename, size, md5sum, location_id):
412         """
413         Returns -1, -2 or the file_id for filename, if its C{size} and C{md5sum} match an
414         existing copy.
415
416         The database is queried using the C{filename} and C{location_id}. If a file does exist
417         at that location, the existing size and md5sum are checked against the provided
418         parameters. A size or checksum mismatch returns -2. If more than one entry is
419         found within the database, a -1 is returned, no result returns None, otherwise
420         the file id.
421
422         Results are kept in a cache during runtime to minimize database queries.
423
424         @type filename: string
425         @param filename: the filename of the file to check against the DB
426
427         @type size: int
428         @param size: the size of the file to check against the DB
429
430         @type md5sum: string
431         @param md5sum: the md5sum of the file to check against the DB
432
433         @type location_id: int
434         @param location_id: the id of the location as returned by L{get_location_id}
435
436         @rtype: int / None
437         @return: Various return values are possible:
438                    - -2: size/checksum error
439                    - -1: more than one file found in database
440                    - None: no file found in database
441                    - int: file id
442
443         """
444         values = {'filename' : filename,
445                   'location' : location_id}
446
447         res = self.caches['files'].GetValue( values )
448
449         if not res:
450             query = """SELECT id, size, md5sum
451                        FROM files
452                        WHERE filename = %(filename)s AND location = %(location)s"""
453
454             cursor = self.db_con.cursor()
455             cursor.execute( query, values )
456
457             if cursor.rowcount == 0:
458                 res = None
459
460             elif cursor.rowcount != 1:
461                 res = -1
462
463             else:
464                 row = cursor.fetchone()
465
466                 if row[1] != int(size) or row[2] != md5sum:
467                     res =  -2
468
469                 else:
470                     self.caches['files'].SetValue(values, row[0])
471                     res = row[0]
472
473         return res
474
475
476     def get_or_set_contents_file_id(self, filename):
477         """
478         Returns database id for given filename.
479
480         Results are kept in a cache during runtime to minimize database queries.
481         If no matching file is found, a row is inserted.
482
483         @type filename: string
484         @param filename: The filename
485
486         @rtype: int
487         @return: the database id for the given component
488         """
489         try:
490             values={'value': filename}
491             query = "SELECT id FROM content_file_names WHERE file = %(value)s"
492             id = self.__get_single_id(query, values, cachename='content_file_names')
493             if not id:
494                 c = self.db_con.cursor()
495                 c.execute( "INSERT INTO content_file_names VALUES (DEFAULT, %(value)s) RETURNING id",
496                            values )
497
498                 id = c.fetchone()[0]
499                 self.caches['content_file_names'].SetValue(values, id)
500
501             return id
502         except:
503             traceback.print_exc()
504             raise
505
506     def get_or_set_contents_path_id(self, path):
507         """
508         Returns database id for given path.
509
510         Results are kept in a cache during runtime to minimize database queries.
511         If no matching file is found, a row is inserted.
512
513         @type path: string
514         @param path: The filename
515
516         @rtype: int
517         @return: the database id for the given component
518         """
519         try:
520             values={'value': path}
521             query = "SELECT id FROM content_file_paths WHERE path = %(value)s"
522             id = self.__get_single_id(query, values, cachename='content_path_names')
523             if not id:
524                 c = self.db_con.cursor()
525                 c.execute( "INSERT INTO content_file_paths VALUES (DEFAULT, %(value)s) RETURNING id",
526                            values )
527
528                 id = c.fetchone()[0]
529                 self.caches['content_path_names'].SetValue(values, id)
530
531             return id
532         except:
533             traceback.print_exc()
534             raise
535
536     def get_suite_architectures(self, suite):
537         """
538         Returns list of architectures for C{suite}.
539
540         @type suite: string, int
541         @param suite: the suite name or the suite_id
542
543         @rtype: list
544         @return: the list of architectures for I{suite}
545         """
546
547         suite_id = None
548         if type(suite) == str:
549             suite_id = self.get_suite_id(suite)
550         elif type(suite) == int:
551             suite_id = suite
552         else:
553             return None
554
555         c = self.db_con.cursor()
556         c.execute( """SELECT a.arch_string FROM suite_architectures sa
557                       JOIN architecture a ON (a.id = sa.architecture)
558                       WHERE suite='%s'""" % suite_id )
559
560         return map(lambda x: x[0], c.fetchall())
561
562     def insert_content_paths(self, bin_id, fullpaths):
563         """
564         Make sure given path is associated with given binary id
565
566         @type bin_id: int
567         @param bin_id: the id of the binary
568         @type fullpaths: list
569         @param fullpaths: the list of paths of the file being associated with the binary
570
571         @return: True upon success
572         """
573
574         c = self.db_con.cursor()
575
576         c.execute("BEGIN WORK")
577         try:
578
579             for fullpath in fullpaths:
580                 (path, file) = os.path.split(fullpath)
581
582                 if path.startswith( "./" ):
583                     path = path[2:]
584                 # Get the necessary IDs ...
585                 file_id = self.get_or_set_contents_file_id(file)
586                 path_id = self.get_or_set_contents_path_id(path)
587
588                 c.execute("""INSERT INTO content_associations
589                                (binary_pkg, filepath, filename)
590                            VALUES ( '%d', '%d', '%d')""" % (bin_id, path_id, file_id) )
591
592             c.execute("COMMIT")
593             return True
594         except:
595             traceback.print_exc()
596             c.execute("ROLLBACK")
597             return False
598
599     def insert_pending_content_paths(self, package, fullpaths):
600         """
601         Make sure given paths are temporarily associated with given
602         package
603
604         @type package: dict
605         @param package: the package to associate with should have been read in from the binary control file
606         @type fullpaths: list
607         @param fullpaths: the list of paths of the file being associated with the binary
608
609         @return: True upon success
610         """
611
612         c = self.db_con.cursor()
613
614         c.execute("BEGIN WORK")
615         try:
616             arch_id = self.get_architecture_id(package['Architecture'])
617
618             # Remove any already existing recorded files for this package
619             c.execute("""DELETE FROM pending_content_associations
620                          WHERE package=%(Package)s
621                          AND version=%(Version)s
622                          AND architecture=%(ArchID)s""", {'Package': package['Package'],
623                                                           'Version': package['Version'],
624                                                           'ArchID':  arch_id})
625
626             for fullpath in fullpaths:
627                 (path, file) = os.path.split(fullpath)
628
629                 if path.startswith( "./" ):
630                     path = path[2:]
631                 # Get the necessary IDs ...
632                 file_id = self.get_or_set_contents_file_id(file)
633                 path_id = self.get_or_set_contents_path_id(path)
634
635                 c.execute("""INSERT INTO pending_content_associations
636                                (package, version, architecture, filepath, filename)
637                             VALUES (%%(Package)s, %%(Version)s, '%d', '%d', '%d')"""
638                     % (arch_id, path_id, file_id), package )
639
640             c.execute("COMMIT")
641             return True
642         except:
643             traceback.print_exc()
644             c.execute("ROLLBACK")
645             return False
646