X-Git-Url: https://git.decadent.org.uk/gitweb/?a=blobdiff_plain;f=daklib%2Fdbconn.py;h=fca8e1e0ec3ed347bf17cbb5cb23f4207bd07b14;hb=28d6d3160080ea84655f6e44834aa04567746e2e;hp=66921b851a8efef6d5f530561d5376bc0400ac09;hpb=98a9167cf4ca8ce6344152505c05537626f208f6;p=dak.git diff --git a/daklib/dbconn.py b/daklib/dbconn.py index 66921b85..fca8e1e0 100755 --- a/daklib/dbconn.py +++ b/daklib/dbconn.py @@ -59,6 +59,18 @@ class Architecture(object): def __init__(self, *args, **kwargs): pass + def __eq__(self, val): + if isinstance(val, str): + return (self.arch_string== val) + # This signals to use the normal comparison operator + return NotImplemented + + def __ne__(self, val): + if isinstance(val, str): + return (self.arch_string != val) + # This signals to use the normal comparison operator + return NotImplemented + def __repr__(self): return '' % self.arch_string @@ -171,6 +183,24 @@ class DBBinary(object): __all__.append('DBBinary') +def get_suites_binary_in(package, session=None): + """ + Returns list of Suite objects which given C{package} name is in + + @type source: str + @param source: DBBinary package name to search for + + @rtype: list + @return: list of Suite objects for the given package + """ + + if session is None: + session = DBConn().session() + + return session.query(Suite).join(BinAssociation).join(DBBinary).filter_by(package=package).all() + +__all__.append('get_suites_binary_in') + def get_binary_from_id(id, session=None): """ Returns DBBinary object for given C{id} @@ -297,6 +327,18 @@ class Component(object): def __init__(self, *args, **kwargs): pass + def __eq__(self, val): + if isinstance(val, str): + return (self.component_name == val) + # This signals to use the normal comparison operator + return NotImplemented + + def __ne__(self, val): + if isinstance(val, str): + return (self.component_name != val) + # This signals to use the normal comparison operator + return NotImplemented + def __repr__(self): return '' % self.component_name @@ -572,6 +614,53 @@ class PoolFile(object): __all__.append('PoolFile') +def check_poolfile(filename, filesize, md5sum, location_id, session=None): + """ + Returns a tuple: + (ValidFileFound [boolean or None], PoolFile object or None) + + @type filename: string + @param filename: the filename of the file to check against the DB + + @type filesize: int + @param filesize: the size of the file to check against the DB + + @type md5sum: string + @param md5sum: the md5sum of the file to check against the DB + + @type location_id: int + @param location_id: the id of the location to look in + + @rtype: tuple + @return: Tuple of length 2. + If more than one file found with that name: + (None, None) + If valid pool file found: (True, PoolFile object) + If valid pool file not found: + (False, None) if no file found + (False, PoolFile object) if file found with size/md5sum mismatch + """ + + if session is None: + session = DBConn().session() + + q = session.query(PoolFile).filter_by(filename=filename) + q = q.join(Location).filter_by(location_id=location_id) + + if q.count() > 1: + return (None, None) + if q.count() < 1: + return (False, None) + + obj = q.one() + if obj.md5sum != md5sum or obj.filesize != filesize: + return (False, obj) + + return (True, obj) + +__all__.append('check_poolfile') + + def get_poolfile_by_name(filename, location_id=None, session=None): """ Returns an array of PoolFile objects for the given filename and @@ -587,7 +676,7 @@ def get_poolfile_by_name(filename, location_id=None, session=None): @return: array of PoolFile objects """ - if session is not None: + if session is None: session = DBConn().session() q = session.query(PoolFile).filter_by(filename=filename) @@ -610,7 +699,7 @@ def get_poolfile_like_name(filename, session=None): @return: array of PoolFile objects """ - if session is not None: + if session is None: session = DBConn().session() # TODO: There must be a way of properly using bind parameters with %FOO% @@ -959,6 +1048,18 @@ class Priority(object): def __init__(self, *args, **kwargs): pass + def __eq__(self, val): + if isinstance(val, str): + return (self.priority == val) + # This signals to use the normal comparison operator + return NotImplemented + + def __ne__(self, val): + if isinstance(val, str): + return (self.priority != val) + # This signals to use the normal comparison operator + return NotImplemented + def __repr__(self): return '' % (self.priority, self.priority_id) @@ -1178,6 +1279,18 @@ class Section(object): def __init__(self, *args, **kwargs): pass + def __eq__(self, val): + if isinstance(val, str): + return (self.section == val) + # This signals to use the normal comparison operator + return NotImplemented + + def __ne__(self, val): + if isinstance(val, str): + return (self.section != val) + # This signals to use the normal comparison operator + return NotImplemented + def __repr__(self): return '
' % self.section @@ -1288,6 +1401,24 @@ def source_exists(source, source_version, suites = ["any"], session=None): __all__.append('source_exists') +def get_suites_source_in(source, session=None): + """ + Returns list of Suite objects which given C{source} name is in + + @type source: str + @param source: DBSource package name to search for + + @rtype: list + @return: list of Suite objects for the given source + """ + + if session is None: + session = DBConn().session() + + return session.query(Suite).join(SrcAssociation).join(DBSource).filter_by(source=source).all() + +__all__.append('get_suites_source_in') + def get_sources_from_name(source, version=None, dm_upload_allowed=None, session=None): """ Returns list of DBSource objects for given C{source} name and other parameters @@ -1404,6 +1535,18 @@ class Suite(object): def __repr__(self): return '' % self.suite_name + def __eq__(self, val): + if isinstance(val, str): + return (self.suite_name == val) + # This signals to use the normal comparison operator + return NotImplemented + + def __ne__(self, val): + if isinstance(val, str): + return (self.suite_name != val) + # This signals to use the normal comparison operator + return NotImplemented + def details(self): ret = [] for disp, field in SUITE_FIELDS: @@ -1525,6 +1668,18 @@ class Uid(object): def __init__(self, *args, **kwargs): pass + def __eq__(self, val): + if isinstance(val, str): + return (self.uid == val) + # This signals to use the normal comparison operator + return NotImplemented + + def __ne__(self, val): + if isinstance(val, str): + return (self.uid != val) + # This signals to use the normal comparison operator + return NotImplemented + def __repr__(self): return '' % (self.uid, self.name)