]> git.decadent.org.uk Git - dak.git/blob - daklib/dbconn.py
add suite wrapper class
[dak.git] / daklib / dbconn.py
1 #!/usr/bin/python
2
3 """ DB access class
4
5 @contact: Debian FTPMaster <ftpmaster@debian.org>
6 @copyright: 2000, 2001, 2002, 2003, 2004, 2006  James Troup <james@nocrew.org>
7 @copyright: 2008-2009  Mark Hymers <mhy@debian.org>
8 @copyright: 2009  Joerg Jaspert <joerg@debian.org>
9 @copyright: 2009  Mike O'Connor <stew@debian.org>
10 @license: GNU General Public License version 2 or later
11 """
12
13 # This program is free software; you can redistribute it and/or modify
14 # it under the terms of the GNU General Public License as published by
15 # the Free Software Foundation; either version 2 of the License, or
16 # (at your option) any later version.
17
18 # This program is distributed in the hope that it will be useful,
19 # but WITHOUT ANY WARRANTY; without even the implied warranty of
20 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
21 # GNU General Public License for more details.
22
23 # You should have received a copy of the GNU General Public License
24 # along with this program; if not, write to the Free Software
25 # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
26
27 ################################################################################
28
29 # < mhy> I need a funny comment
30 # < sgran> two peanuts were walking down a dark street
31 # < sgran> one was a-salted
32 #  * mhy looks up the definition of "funny"
33
34 ################################################################################
35
36 import os
37 import psycopg2
38 import psycopg2.extras
39 import traceback
40
41 from singleton import Singleton
42 from config import Config
43
44 ################################################################################
45
46 class Cache(object):
47     def __init__(self, hashfunc=None):
48         if hashfunc:
49             self.hashfunc = hashfunc
50         else:
51             self.hashfunc = lambda x: x['value']
52
53         self.data = {}
54
55     def SetValue(self, keys, value):
56         self.data[self.hashfunc(keys)] = value
57
58     def GetValue(self, keys):
59         return self.data.get(self.hashfunc(keys))
60
61 ################################################################################
62
63 class DBConn(Singleton):
64     """
65     database module init.
66     """
67     def __init__(self, *args, **kwargs):
68         super(DBConn, self).__init__(*args, **kwargs)
69
70     def _startup(self, *args, **kwargs):
71         self.__createconn()
72         self.__init_caches()
73
74     ## Connection functions
75     def __createconn(self):
76         cnf = Config()
77         connstr = "dbname=%s" % cnf["DB::Name"]
78         if cnf["DB::Host"]:
79            connstr += " host=%s" % cnf["DB::Host"]
80         if cnf["DB::Port"] and cnf["DB::Port"] != "-1":
81            connstr += " port=%s" % cnf["DB::Port"]
82
83         self.db_con = psycopg2.connect(connstr)
84
85     def reconnect(self):
86         try:
87             self.db_con.close()
88         except psycopg2.InterfaceError:
89             pass
90
91         self.db_con = None
92         self.__createconn()
93
94     ## Cache functions
95     def __init_caches(self):
96         self.caches = {'suite':         Cache(),
97                        'section':       Cache(),
98                        'priority':      Cache(),
99                        'override_type': Cache(),
100                        'architecture':  Cache(),
101                        'archive':       Cache(),
102                        'component':     Cache(),
103                        'content_path_names':     Cache(),
104                        'content_file_names':     Cache(),
105                        'location':      Cache(lambda x: '%s_%s_%s' % (x['location'], x['component'], x['location'])),
106                        'maintainer':    {}, # TODO
107                        'keyring':       {}, # TODO
108                        'source':        Cache(lambda x: '%s_%s_' % (x['source'], x['version'])),
109                        'files':         Cache(lambda x: '%s_%s_' % (x['filename'], x['location'])),
110                        'maintainer':    {}, # TODO
111                        'fingerprint':   {}, # TODO
112                        'queue':         {}, # TODO
113                        'uid':           {}, # TODO
114                        'suite_version': Cache(lambda x: '%s_%s' % (x['source'], x['suite'])),
115                       }
116
117         self.prepared_statements = {}
118
119     def prepare(self,name,statement):
120         if not self.prepared_statements.has_key(name):
121             c = self.cursor()
122             c.execute(statement)
123             self.prepared_statements[name] = statement
124
125     def clear_caches(self):
126         self.__init_caches()
127
128     ## Functions to pass through to the database connector
129     def cursor(self):
130         return self.db_con.cursor()
131
132     def commit(self):
133         return self.db_con.commit()
134
135     ## Get functions
136     def __get_single_row(self, query, values):
137         c = self.db_con.cursor(cursor_factory=psycopg2.extras.DictCursor)
138         c.execute(query, values)
139
140         if c.rowcount < 1:
141             return None
142
143         res = c.fetchone()
144
145         return res
146
147     def __get_single_id(self, query, values, cachename=None):
148         # This is a bit of a hack but it's an internal function only
149         if cachename is not None:
150             res = self.caches[cachename].GetValue(values)
151             if res:
152                 return res
153
154         c = self.db_con.cursor()
155         c.execute(query, values)
156
157         if c.rowcount != 1:
158             return None
159
160         res = c.fetchone()[0]
161
162         if cachename is not None:
163             self.caches[cachename].SetValue(values, res)
164
165         return res
166
167     def __get_id(self, retfield, table, qfield, value):
168         query = "SELECT %s FROM %s WHERE %s = %%(value)s" % (retfield, table, qfield)
169         return self.__get_single_id(query, {'value': value}, cachename=table)
170
171     def get_suite_id(self, suite):
172         """
173         Returns database id for given C{suite}.
174         Results are kept in a cache during runtime to minimize database queries.
175
176         @type suite: string
177         @param suite: The name of the suite
178
179         @rtype: int
180         @return: the database id for the given suite
181
182         """
183         return int(self.__get_id('id', 'suite', 'suite_name', suite))
184
185     def get_section_id(self, section):
186         """
187         Returns database id for given C{section}.
188         Results are kept in a cache during runtime to minimize database queries.
189
190         @type section: string
191         @param section: The name of the section
192
193         @rtype: int
194         @return: the database id for the given section
195
196         """
197         return self.__get_id('id', 'section', 'section', section)
198
199     def get_priority_id(self, priority):
200         """
201         Returns database id for given C{priority}.
202         Results are kept in a cache during runtime to minimize database queries.
203
204         @type priority: string
205         @param priority: The name of the priority
206
207         @rtype: int
208         @return: the database id for the given priority
209
210         """
211         return self.__get_id('id', 'priority', 'priority', priority)
212
213     def get_override_type_id(self, override_type):
214         """
215         Returns database id for given override C{type}.
216         Results are kept in a cache during runtime to minimize database queries.
217
218         @type type: string
219         @param type: The name of the override type
220
221         @rtype: int
222         @return: the database id for the given override type
223
224         """
225         return self.__get_id('id', 'override_type', 'type', override_type)
226
227     def get_architecture_id(self, architecture):
228         """
229         Returns database id for given C{architecture}.
230         Results are kept in a cache during runtime to minimize database queries.
231
232         @type architecture: string
233         @param architecture: The name of the override type
234
235         @rtype: int
236         @return: the database id for the given architecture
237
238         """
239         return self.__get_id('id', 'architecture', 'arch_string', architecture)
240
241     def get_archive_id(self, archive):
242         """
243         returns database id for given c{archive}.
244         results are kept in a cache during runtime to minimize database queries.
245
246         @type archive: string
247         @param archive: the name of the override type
248
249         @rtype: int
250         @return: the database id for the given archive
251
252         """
253         return self.__get_id('id', 'archive', 'lower(name)', archive)
254
255     def get_component_id(self, component):
256         """
257         Returns database id for given C{component}.
258         Results are kept in a cache during runtime to minimize database queries.
259
260         @type component: string
261         @param component: The name of the override type
262
263         @rtype: int
264         @return: the database id for the given component
265
266         """
267         return self.__get_id('id', 'component', 'lower(name)', component)
268
269     def get_location_id(self, location, component, archive):
270         """
271         Returns database id for the location behind the given combination of
272           - B{location} - the path of the location, eg. I{/srv/ftp.debian.org/ftp/pool/}
273           - B{component} - the id of the component as returned by L{get_component_id}
274           - B{archive} - the id of the archive as returned by L{get_archive_id}
275         Results are kept in a cache during runtime to minimize database queries.
276
277         @type location: string
278         @param location: the path of the location
279
280         @type component: int
281         @param component: the id of the component
282
283         @type archive: int
284         @param archive: the id of the archive
285
286         @rtype: int
287         @return: the database id for the location
288
289         """
290
291         archive_id = self.get_archive_id(archive)
292
293         if not archive_id:
294             return None
295
296         res = None
297
298         if component:
299             component_id = self.get_component_id(component)
300             if component_id:
301                 res = self.__get_single_id("SELECT id FROM location WHERE path=%(location)s AND component=%(component)s AND archive=%(archive)s",
302                         {'location': location,
303                          'archive': int(archive_id),
304                          'component': component_id}, cachename='location')
305         else:
306             res = self.__get_single_id("SELECT id FROM location WHERE path=%(location)s AND archive=%(archive)d",
307                     {'location': location, 'archive': archive_id, 'component': ''}, cachename='location')
308
309         return res
310
311     def get_source_id(self, source, version):
312         """
313         Returns database id for the combination of C{source} and C{version}
314           - B{source} - source package name, eg. I{mailfilter}, I{bbdb}, I{glibc}
315           - B{version}
316         Results are kept in a cache during runtime to minimize database queries.
317
318         @type source: string
319         @param source: source package name
320
321         @type version: string
322         @param version: the source version
323
324         @rtype: int
325         @return: the database id for the source
326
327         """
328         return self.__get_single_id("SELECT id FROM source s WHERE s.source=%(source)s AND s.version=%(version)s",
329                                  {'source': source, 'version': version}, cachename='source')
330
331     def get_suite(self, suite):
332         if isinstance(suite, str):
333             suite_id = self.get_suite_id(suite.lower())
334         elif type(suite) == int:
335             suite_id = suite
336
337         print suite_id
338
339         return self.__get_single_row("SELECT * FROM suite WHERE id = %(id)s",
340                                      {'id': suite_id})
341
342     def get_suite_version(self, source, suite):
343         """
344         Returns database id for a combination of C{source} and C{suite}.
345
346           - B{source} - source package name, eg. I{mailfilter}, I{bbdb}, I{glibc}
347           - B{suite} - a suite name, eg. I{unstable}
348
349         Results are kept in a cache during runtime to minimize database queries.
350
351         @type source: string
352         @param source: source package name
353
354         @type suite: string
355         @param suite: the suite name
356
357         @rtype: string
358         @return: the version for I{source} in I{suite}
359
360         """
361         return self.__get_single_id("""
362         SELECT s.version FROM source s, suite su, src_associations sa
363         WHERE sa.source=s.id
364           AND sa.suite=su.id
365           AND su.suite_name=%(suite)s
366           AND s.source=%(source)""", {'suite': suite, 'source': source}, cachename='suite_version')
367
368
369     def get_files_id (self, filename, size, md5sum, location_id):
370         """
371         Returns -1, -2 or the file_id for filename, if its C{size} and C{md5sum} match an
372         existing copy.
373
374         The database is queried using the C{filename} and C{location_id}. If a file does exist
375         at that location, the existing size and md5sum are checked against the provided
376         parameters. A size or checksum mismatch returns -2. If more than one entry is
377         found within the database, a -1 is returned, no result returns None, otherwise
378         the file id.
379
380         Results are kept in a cache during runtime to minimize database queries.
381
382         @type filename: string
383         @param filename: the filename of the file to check against the DB
384
385         @type size: int
386         @param size: the size of the file to check against the DB
387
388         @type md5sum: string
389         @param md5sum: the md5sum of the file to check against the DB
390
391         @type location_id: int
392         @param location_id: the id of the location as returned by L{get_location_id}
393
394         @rtype: int / None
395         @return: Various return values are possible:
396                    - -2: size/checksum error
397                    - -1: more than one file found in database
398                    - None: no file found in database
399                    - int: file id
400
401         """
402         values = {'filename' : filename,
403                   'location' : location_id}
404
405         res = self.caches['files'].GetValue( values )
406
407         if not res:
408             query = """SELECT id, size, md5sum
409                        FROM files
410                        WHERE filename = %(filename)s AND location = %(location)s"""
411
412             cursor = self.db_con.cursor()
413             cursor.execute( query, values )
414
415             if cursor.rowcount == 0:
416                 res = None
417
418             elif cursor.rowcount != 1:
419                 res = -1
420
421             else:
422                 row = cursor.fetchone()
423
424                 if row[1] != size or row[2] != md5sum:
425                     res =  -2
426
427                 else:
428                     self.caches[cachename].SetValue(values, row[0])
429                     res = row[0]
430
431         return res
432
433
434     def get_or_set_contents_file_id(self, filename):
435         """
436         Returns database id for given filename.
437
438         Results are kept in a cache during runtime to minimize database queries.
439         If no matching file is found, a row is inserted.
440
441         @type filename: string
442         @param filename: The filename
443
444         @rtype: int
445         @return: the database id for the given component
446         """
447         try:
448             values={'value': filename}
449             query = "SELECT id FROM content_file_names WHERE file = %(value)s"
450             id = self.__get_single_id(query, values, cachename='content_file_names')
451             if not id:
452                 c = self.db_con.cursor()
453                 c.execute( "INSERT INTO content_file_names VALUES (DEFAULT, %(value)s) RETURNING id",
454                            values )
455
456                 id = c.fetchone()[0]
457                 self.caches['content_file_names'].SetValue(values, id)
458
459             return id
460         except:
461             traceback.print_exc()
462             raise
463
464     def get_or_set_contents_path_id(self, path):
465         """
466         Returns database id for given path.
467
468         Results are kept in a cache during runtime to minimize database queries.
469         If no matching file is found, a row is inserted.
470
471         @type path: string
472         @param path: The filename
473
474         @rtype: int
475         @return: the database id for the given component
476         """
477         try:
478             values={'value': path}
479             query = "SELECT id FROM content_file_paths WHERE path = %(value)s"
480             id = self.__get_single_id(query, values, cachename='content_path_names')
481             if not id:
482                 c = self.db_con.cursor()
483                 c.execute( "INSERT INTO content_file_paths VALUES (DEFAULT, %(value)s) RETURNING id",
484                            values )
485
486                 id = c.fetchone()[0]
487                 self.caches['content_path_names'].SetValue(values, id)
488
489             return id
490         except:
491             traceback.print_exc()
492             raise
493
494     def get_suite_architectures(self, suite):
495         """
496         Returns list of architectures for C{suite}.
497
498         @type suite: string, int
499         @param suite: the suite name or the suite_id
500
501         @rtype: list
502         @return: the list of architectures for I{suite}
503         """
504
505         suite_id = None
506         if type(suite) == str:
507             suite_id = self.get_suite_id(suite)
508         elif type(suite) == int:
509             suite_id = suite
510         else:
511             return None
512
513         c = self.db_con.cursor()
514         c.execute( """SELECT a.arch_string FROM suite_architectures sa
515                       JOIN architecture a ON (a.id = sa.architecture)
516                       WHERE suite='%s'""" % suite_id )
517
518         return map(lambda x: x[0], c.fetchall())
519
520     def insert_content_paths(self, bin_id, fullpaths):
521         """
522         Make sure given path is associated with given binary id
523
524         @type bin_id: int
525         @param bin_id: the id of the binary
526         @type fullpath: string
527         @param fullpath: the path of the file being associated with the binary
528
529         @return True upon success
530         """
531
532         c = self.db_con.cursor()
533
534         c.execute("BEGIN WORK")
535         try:
536
537             for fullpath in fullpaths:
538                 (path, file) = os.path.split(fullpath)
539
540                 # Get the necessary IDs ...
541                 file_id = self.get_or_set_contents_file_id(file)
542                 path_id = self.get_or_set_contents_path_id(path)
543
544                 c.execute("""INSERT INTO content_associations
545                                (binary_pkg, filepath, filename)
546                            VALUES ( '%d', '%d', '%d')""" % (bin_id, path_id, file_id) )
547
548             c.execute("COMMIT")
549             return True
550         except:
551             traceback.print_exc()
552             c.execute("ROLLBACK")
553             return False
554
555     def insert_pending_content_paths(self, package, fullpaths):
556         """
557         Make sure given paths are temporarily associated with given
558         package
559
560         @type package: dict
561         @param package: the package to associate with should have been read in from the binary control file
562         @type fullpaths: list
563         @param fullpaths: the list of paths of the file being associated with the binary
564
565         @return True upon success
566         """
567
568         c = self.db_con.cursor()
569
570         c.execute("BEGIN WORK")
571         try:
572
573                 # Remove any already existing recorded files for this package
574             c.execute("""DELETE FROM pending_content_associations
575                          WHERE package=%(Package)s
576                          AND version=%(Version)s""", package )
577
578             for fullpath in fullpaths:
579                 (path, file) = os.path.split(fullpath)
580
581                 if path.startswith( "./" ):
582                     path = path[2:]
583                 # Get the necessary IDs ...
584                 file_id = self.get_or_set_contents_file_id(file)
585                 path_id = self.get_or_set_contents_path_id(path)
586
587                 c.execute("""INSERT INTO pending_content_associations
588                                (package, version, filepath, filename)
589                            VALUES (%%(Package)s, %%(Version)s, '%d', '%d')""" % (path_id, file_id),
590                           package )
591             c.execute("COMMIT")
592             return True
593         except:
594             traceback.print_exc()
595             c.execute("ROLLBACK")
596             return False
597
598 ################################################################################
599
600 class Suite(object):
601     # This should be kept in sync with the suites table;
602     # we should probably just do introspection on the table
603     # (or maybe use an ORM)
604     _fieldnames = ['announce', 'changelogbase', 'codename', 'commentsdir',
605                    'copychanges', 'copydotdak', 'description', 'id',
606                    'label', 'notautomatic', 'origin', 'overridecodename',
607                    'overridesuite', 'policy_engine', 'priority', 'suite_name',
608                    'untouchable', 'validtime', 'version']
609
610     def __init_fields(self):
611         for k in self._fieldnames:
612             setattr(self, k, None)
613
614     def __init__(self, suite):
615         self.__init_fields()
616         if suite is not None:
617             db_conn = DBConn()
618             suite_data = db_conn.get_suite(suite)
619             print suite_data
620             if suite_data is not None:
621                 for k in suite_data.keys():
622                     setattr(self, k, suite_data[k])
623
624
625