""" ItemStore.py

Data store abstraction module.
"""
__copyright__ = "Copyright (c) 2002-2005 Free Software Foundation, Inc."
__license__ = """
Straw is free software; you can redistribute it and/or modify it under the
terms of the GNU General Public License as published by the Free Software
Foundation; either version 2 of the License, or (at your option) any later
version.

Straw is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
A PARTICULAR PURPOSE. See the GNU General Public License for more details.

You should have received a copy of the GNU General Public License along with
this program; if not, write to the Free Software Foundation, Inc., 59 Temple
Place - Suite 330, Boston, MA 02111-1307, USA. """

import cPickle as pickle
import os, sys
import time
from error import *
import tempfile
import traceback

try:
    from bsddb.db import *
    import bsddb
except ImportError:
    from bsddb3.db import *
    import bsddb3 as bsddb

import Event
import SummaryItem
import FeedList
import ImageCache
from MainloopManager import MainloopManager

# File name of the bsddb database; ItemStore passes it to MyDB together with
# the database home directory.
DATABASE_FILE_NAME = "itemstore.db"

class ConvertException(Exception):
    """Signals that a database version conversion could not be completed.

    version1 -- the version the conversion started from
    version2 -- the version the conversion was heading to
    reason   -- whatever explanation the raiser supplied (MyDB.__init__
                passes the path of a file holding the traceback)
    """

    def __init__(self, version1, version2, reason):
        self.reason = reason
        self.version2 = version2
        self.version1 = version1

class MyDB:
    CURRENT_VERSION = 3

    def __init__(self, filename, dbhome, create=0, truncate=0, mode=0600,
                 recover=0, dbflags=0):
        self._db = None
        self._env = None
        recoverenv = DB_CREATE | DB_RECOVER
        # DB_INIT_TXN automatically enables logging
        flagsforenv = DB_INIT_TXN | DB_INIT_MPOOL | DB_INIT_LOCK | DB_PRIVATE

        self._env = DBEnv()
        self._env.set_data_dir(dbhome)
        self._env.set_lk_detect(DB_LOCK_DEFAULT)  # enable auto deadlock avoidance
        self._env.set_lg_max(2**20)
        self._env.set_lk_max_locks   (10000)
        self._env.set_lk_max_objects (10000)

        try:
            self._env.open(dbhome, recoverenv | flagsforenv, mode)
        except bsddb._db.DBRunRecoveryError, err:
            self._env.remove(dbhome)
            self._env.close()
            log("%s" % err[1])
            sys.exit("Recovery Error: See README for details on how to recover data. ")

        flags = 0
        if truncate:
            flags |= DB_TRUNCATE

        try:
            flags |= DB_AUTO_COMMIT
        except NameError:
            pass

        try:
            self._db = DB(self._env)
            self._db.open(filename, DB_BTREE, flags, mode)
        except bsddb._db.DBNoSuchFileError:
            if create:
                self._db = DB(self._env)
                self._db.open(filename, DB_BTREE, flags | DB_CREATE, mode)
                self.set_db_version(self.CURRENT_VERSION)
            else:
                raise
        try:
            self.convert_old_versions()
        except Exception, ex:
            try:
                filename = tempfile.mktemp(prefix="straw-")
                fh = open(filename, "w")
                traceback.print_exc(None, fh)
                raise ConvertException(self.get_db_version(),
                                       MyDB.CURRENT_VERSION, "%s" % filename)
            finally:
                fh.close()

    def close(self):
        if self._db is not None:
            self._db.close()
            self._db = None
        if self._env is not None:
            self._env.close()
            self._env = None

    def checkpoint(self):
        """Force a transaction checkpoint, then delete the archivable log files.

        Every path returned by log_archive(DB_ARCH_ABS) is removed from disk.
        """
        # txn_checkpoint takes (kbyte, min, flag) positionally, so DB_FORCE has
        # to go in the third slot; passed first it is read as a kilobyte
        # threshold.  The flag must be 0 or DB_FORCE, else the call raises
        # EINVAL (InvalidArgError).
        self._env.txn_checkpoint(0, 0, DB_FORCE)
        for logfile in self._env.log_archive(DB_ARCH_ABS):
            os.remove(logfile)

    def begin_transaction(self):
        return self._env.txn_begin()

    def get_item_ids(self, iid, txn):
        key = "fids:%d" % iid
        dids = self._db.get(key, txn=txn)
        ids = []
        if dids:
            ids = pickle.loads(dids)
        return ids

    def save_feed_item_ids(self, feed, ids, txn=None):
        rowid = "fids:%d" % feed.id
        commit = 0
        if not txn:
            txn = self.begin_transaction()
            commit = 1
        try:
            try:
                self._db.delete(rowid, txn=txn)
            except DBNotFoundError:
                pass
            self._db.put(rowid, pickle.dumps(ids), txn=txn)
        except Exception, ex:
            if commit:
                txn.abort()
            logtb(str(ex))
        else:
            if commit:
                txn.commit()

    def get_item(self, feed_id, item_id, txn=None):
        """Load the record keyed "<feed_id>:<item_id>" via unstringify_item.

        Returns whatever unstringify_item returns for the stored value.
        """
        key = "%d:%d" % (feed_id, item_id)
        raw = self._db.get(key, txn=txn)
        return unstringify_item(raw)

    def add_items(self, feed, items):
        txn = self.begin_transaction()
        try:
            feed_item_ids = self.get_item_ids(feed.id, txn=txn)
            for item in items:
                self._db.put("%d:%d" % (item.feed.id, item.id), stringify_item(item), txn=txn)
                # TODO: it might be a good idea to check here that we don't add
                # duplicate items. It doesn't happen normally, but there can be
                # bugs that trigger that. Throwing an exception would be the
                # the right thing: it wouldn't hide the breakage.
                feed_item_ids.append(item.id)
            self.save_feed_item_ids(feed, feed_item_ids, txn)
        except Exception, ex:
            txn.abort()
            logtb(str(ex))
        else:
            txn.commit()

    def delete_items(self, feed, items):
        """ Deletes a list of items.

        Useful for cutting old items based on number of items stored.
        """
        txn = self.begin_transaction()
        try:
            feed_item_ids = self.get_item_ids(feed.id, txn=txn)
            # because of bugs, we sometime get here duplicate ids. instead of dying,
            # warn the user but continue
            item_ids = []
            for item in items:
                item.clean_up()
                if item.id in item_ids:
                    log("WARNING: skipping duplicate ids in delete items request %s and %s" % (item.title, item.id))
                    # filter out any duplicates
                    feed_item_ids = filter(lambda x: x != item.id, feed_item_ids)
                    continue
                item_ids.append(item.id)
                #log("deleting item %d:%d" % (feed.id, item.id))
                if item.id in feed_item_ids:
                    feed_item_ids.remove(item.id)
                    self._db.delete("%d:%d" % (feed.id, item.id), txn=txn)
            self.save_feed_item_ids(feed, feed_item_ids, txn)
        except Exception, ex:
            txn.abort()
            log_exc("error while deleting items")
        else:
            txn.commit()

    def modify_items(self, items):
        txn = self.begin_transaction()
        try:
            for item in items:
                self._db.put("%d:%d" % (item.feed.id, item.id),
                             stringify_item(item), txn=txn)
        except Exception, ex:
            txn.abort()
            logtb(str(ex))
        else:
            txn.commit()

    def get_feed_items(self, feed):
        txn = self.begin_transaction()
        items = []
        try:
            ids = self.get_item_ids(feed.id, txn=txn)
            for id in ids:
                item = self.get_item(feed.id, id, txn=txn)
                if item is not None:
                    items.append(item)
        except Exception, ex:
            txn.abort()
            log(str(ex))
        else:
            txn.commit()
            return items

    def get_number_of_unread(self, fid, cutoff):
        # Used by config conversion
        # NOTE: this is the number of unread items in 'number of items stored'
        # preference. Since straw stores the most recent items down the list,
        # we only count the unread items from the most recent N items,
        # where N = cutoff.
        txn = self.begin_transaction()
        num_unread = 0
        try:
            ids = self.get_item_ids(fid, txn=txn)
            for id in ids[len(ids)-cutoff:]:
                item = self.get_item(fid, id, txn=txn)
                if item is not None and item.seen == 0:
                    num_unread += 1
                else: continue
        except Exception, ex:
            txn.abort()
            logtb(str(ex))
        else:
            txn.commit()
            return num_unread

    def get_image_urls(self, txn=None):
        dkeys = self._db.get("images", txn=txn)
        keys = []
        if dkeys is not None:
            keys = pickle.loads(dkeys)
        return keys

    def save_image_urls(self, urls, txn=None):
        self._db.put("images", pickle.dumps(urls), txn=txn)

    def get_image_counts(self, txn=None):
        images = self.get_image_urls(txn)
        counts = []
        for image in images:
            key = ("imagecount:" + image).encode('utf-8')
            value = self._db.get(str(key))
            try:
                counts.append((image, int(value)))
            except:
                log("exception for ", key, ", type of value ", value, ": ", type(value))
        return counts

    def update_image_count(self, url, count):
        #logparam(locals(), "url", "count")
        key = ("imagecount:" + url).encode('utf-8')
        txn = self.begin_transaction()
        try:
            if count < 1:
                self._db.delete(key, txn=txn)
            else:
                self._db.put(key, str(count), txn=txn)
        except:
            txn.abort()
            raise
        else:
            txn.commit()

    def update_image(self, url, image):
        key = "image:%s" % str(url)
        txn = self.begin_transaction()
        try:
            image_urls = self.get_image_urls(txn)
            if image is not None:
                self._db.put(key.encode('utf-8'), image, txn=txn)
                if url not in image_urls:
                    image_urls.append(url)
                    self.save_image_urls(image_urls, txn)
            else:
                if url in image_urls:
                    try:
                        self._db.delete(key, txn=txn)
                    except DBNotFoundError:
                        log("Key not found", key)
                    image_urls.remove(url)
                    self.save_image_urls(image_urls, txn=txn)
        except:
            txn.abort()
            raise
        else:
            txn.commit()

    def get_image_data(self, url, txn=None):
        return self._db.get(
            "image:%s" % url.encode('utf-8'), default = None, txn=txn)

    def _image_print(self, key, data):
        if key[:6] == "image:":
            print key

    def _data_print(self, key, data):
        """Helper for _db_print: pretty-print ``key`` with its unpickled data."""
        # Imported here because the module only imports pprint when it is run
        # as a script; without this the helper fails on an imported module.
        from pprint import pprint
        pprint({key: pickle.loads(data)})

    def _db_print(self, helper):
        """Print the database to stdout for debugging"""
        print "******** Printing raw database for debugging ********"
        print "database version: %s" % self.get_db_version()
        cur = self._db.cursor()
        try:
            key, data = cur.first()
            while 1 :
                helper(key, data)
                next = cur.next()
                if next:
                    key, data = next
        finally:
            cur.close()

    def get_db_version(self, txn=None):
        version = self._db.get("straw_db_version", default = "1", txn=txn)
        return int(version)

    def set_db_version(self, version, txn=None):
        try:
            if txn is None:
                txn = self.begin_transaction()
                self._db.put("straw_db_version", str(version), txn=txn)
        except:
            txn.abort()
            raise
        else:
            txn.commit()

    def convert_old_versions(self):
        version = self.get_db_version()
        while version < self.CURRENT_VERSION:
            next = version + 1
            mname = "convert_%d_%d" % (version, next)
            try:
                method = getattr(self, mname)
            except AttributeError:
                raise ConvertException(version, next, "No conversion function specified")
            method()
            self.set_db_version(next)
            version = next

    def convert_1_2(self):
        """Convert the database from version 1 to version 2.

        Walks every record inside one transaction.  For each item record whose
        'pub_date' is an mx.DateTime value, the date is replaced by the result
        of time.gmtime() and the record is rewritten in place.

        Raises ConvertException if mx.DateTime cannot be imported.  Any other
        exception aborts the transaction and is re-raised.
        """
        def is_item(key):
            # Item records are keyed "<digits>:<digits>".
            parts = key.split(':')
            if len(parts) != 2:
                return False
            return parts[0].isdigit() and parts[1].isdigit()

        def round_second(ttuple):
            # Return the time tuple with its seconds field (index 5) rounded
            # to an int.
            l = list(ttuple)
            l[5] = int(round(l[5]))
            return tuple(l)

        try:
            import mx.DateTime as mxd
        except ImportError:
            raise ConvertException(1, 2, _("Couldn't import mx.DateTime"))
        txn = self.begin_transaction()
        try:
            cur = self._db.cursor(txn=txn)
            try:
                # A false result from first() means there is nothing to walk;
                # key then stays None and the loop body never runs.
                next = cur.first()
                key = None
                if next:
                    key, data = cur.first()
                while key is not None:
                    if is_item(key):
                        dict = pickle.loads(data)
                        if isinstance(dict['pub_date'], mxd.DateTimeType):
                            p = dict['pub_date']
                            t = time.gmtime(time.mktime(round_second(p.tuple())))
                            dict['pub_date'] = t
                            data = pickle.dumps(dict)
                            # Overwrite the record the cursor is positioned on.
                            cur.put(key, data, DB_CURRENT)
                    next = cur.next()
                    if next:
                        key, data = next
                    else:
                        break
            finally:
                cur.close()
        except Exception, ex:
            txn.abort()
            raise
        else:
            txn.commit()

    def convert_2_3(self):
        """Convert the database from version 2 to version 3.

        Counts, over all item records, how often each image is referenced,
        stores every count under "imagecount:<image>" and rewrites the
        "images" record as the list of referenced images.  Everything happens
        in one transaction, which is aborted before an exception is re-raised.
        """
        def is_item(key):
            # Item records are keyed "<digits>:<digits>".
            parts = key.split(':')
            if len(parts) != 2:
                return False
            return parts[0].isdigit() and parts[1].isdigit()

        imagelistcursor = None
        images = {}
        txn = self.begin_transaction()
        try:
            cur = self._db.cursor(txn=txn)
            try:
                record = cur.first()
                while record:
                    key, data = record
                    if is_item(key):
                        dic = pickle.loads(data)
                        for image in dic['images']:
                            images[image] = images.get(image, 0) + 1
                    elif key == "images":
                        # Keep a second cursor parked on the "images" record so
                        # it can be overwritten after the walk.
                        imagelistcursor = cur.dup(DB_POSITION)
                    record = cur.next()
                for image, count in images.items():
                    key = ("imagecount:" + image).encode('utf-8')
                    cur.put(key, str(count), DB_KEYFIRST)
                imagelist = pickle.dumps(images.keys())
                if imagelistcursor is not None:
                    imagelistcursor.put("images", imagelist, DB_CURRENT)
                else:
                    # No "images" record existed: create it rather than fail
                    # on the missing cursor.
                    cur.put("images", imagelist, DB_KEYFIRST)
            finally:
                cur.close()
                if imagelistcursor is not None:
                    imagelistcursor.close()
        except Exception:
            txn.abort()
            raise
        else:
            txn.commit()

class ModifyItemAction:
    """Queued action: write one modified item back to the database."""

    def __init__(self, item):
        self._item = item

    def doit(self, db):
        """Call db.modify_items with a one-element list holding the item."""
        changed = [self._item]
        db.modify_items(changed)

class ModifyItemsAction:
    """Queued action: write several modified items back to the database."""

    def __init__(self, items):
        self._items = items

    def doit(self, db):
        """Hand the stored items to db.modify_items."""
        changed = self._items
        db.modify_items(changed)

class ItemsAddedAction:
    """Queued action: store newly added items of a feed."""

    def __init__(self, feed, items):
        self._items = items
        self._feed = feed

    def doit(self, db):
        """Hand the stored feed and items to db.add_items."""
        feed, items = self._feed, self._items
        db.add_items(feed, items)

class DeleteItemAction:
    """Queued action: delete items of a feed from the database."""

    def __init__(self, feed, items):
        self._items = items
        self._feed = feed

    def doit(self, db):
        """Hand the stored feed and items to db.delete_items."""
        feed, items = self._feed, self._items
        db.delete_items(feed, items)

class ImageUpdateAction:
    """Queued action: store (or remove) the data of one image url."""

    def __init__(self, url, image):
        self._image = image
        self._url = url

    def doit(self, db):
        """Hand the stored url and image data to db.update_image."""
        url, image = self._url, self._image
        db.update_image(url, image)

class ImageCountChangedAction:
    """Queued action: record a changed reference count for an image url."""

    def __init__(self, url, count):
        self._count = count
        self._url = url

    def doit(self, db):
        """Hand the stored url and count to db.update_image_count."""
        url, count = self._url, self._count
        db.update_image_count(url, count)

class ItemStore:
    """Collects persistence actions from signals and applies them to a MyDB.

    Signal handlers only append an action object to a queue.  A repeating
    timer (see start) calls _run, which applies the queued actions to the
    database and then checkpoints it.  Reads go straight to the database.
    """

    def __init__(self, dbhome):
        """Open the database in ``dbhome`` and connect to the feed signals."""
        # Set up the queue state first so it exists before any signal handler
        # can be invoked.
        self._stop = False
        self._action_queue = []
        feedlist = FeedList.get_instance()
        self._db = MyDB(DATABASE_FILE_NAME, dbhome, create = 1)
        self.connect_signals()
        feedlist.signal_connect(Event.FeedCreatedSignal,
                                self._feed_created_cb)
        feedlist.signal_connect(Event.FeedDeletedSignal,
                                self._feed_deleted_cb)
        ImageCache.cache.signal_connect(Event.ImageUpdatedSignal,
                                        self.image_updated)

    def _feed_created_cb(self, signal):
        """Start listening to the signals of a newly created feed."""
        self._connect_feed_signals(signal.feed)

    def _feed_deleted_cb(self, signal):
        """Stop listening to the signals of a deleted feed."""
        self._disconnect_feed_signals(signal.feed)

    def connect_signals(self):
        """Connect to the item signals of every feed in the feed list."""
        for feed in FeedList.get_instance().flatten_list():
            self._connect_feed_signals(feed)

    def _connect_feed_signals(self, feed):
        """Connect the item handlers of this store to ``feed``."""
        feed.signal_connect(Event.NewItemsSignal, self.items_added)
        feed.signal_connect(Event.ItemReadSignal, self.item_modified)
        feed.signal_connect(Event.ItemStickySignal, self.item_modified)
        feed.signal_connect(Event.AllItemsReadSignal, self.all_items_read)
        feed.signal_connect(Event.ItemDeletedSignal, self.item_deleted)

    def _disconnect_feed_signals(self, feed):
        """Disconnect the item handlers of this store from ``feed``."""
        feed.signal_disconnect(Event.NewItemsSignal, self.items_added)
        feed.signal_disconnect(Event.ItemReadSignal, self.item_modified)
        feed.signal_disconnect(Event.ItemStickySignal, self.item_modified)
        feed.signal_disconnect(Event.AllItemsReadSignal, self.all_items_read)
        feed.signal_disconnect(Event.ItemDeletedSignal, self.item_deleted)

    def modify_item(self, item):
        """Queue a rewrite of ``item``."""
        self._action_queue.append(ModifyItemAction(item))

    def image_updated(self, signal):
        """Queue storing the image data carried by ``signal``."""
        self._action_queue.append(
            ImageUpdateAction(signal.url, signal.data))

    def read_image(self, url):
        """Return the stored data of image ``url`` straight from the database."""
        return self._db.get_image_data(url)

    def item_deleted(self, signal):
        """Queue the deletion announced by ``signal``."""
        # NOTE(review): MyDB.delete_items iterates over its second argument,
        # so signal.item presumably holds a sequence of items -- confirm
        # against Event.ItemDeletedSignal.
        self._action_queue.append(DeleteItemAction(signal.sender, signal.item))

    def item_modified(self, signal):
        """Queue a rewrite of the item carried by ``signal``."""
        self.modify_item(signal.item)

    def all_items_read(self, signal):
        """Queue a rewrite of every item listed in signal.changed."""
        self._action_queue.append(ModifyItemsAction(
            [item for index, item in signal.changed]))

    def items_added(self, signal):
        """Queue storing the new items carried by ``signal``."""
        self._action_queue.append(
            ItemsAddedAction(signal.sender, signal.items))

    def read_feed_items(self, feed):
        """Return the stored items of ``feed`` straight from the database."""
        return self._db.get_feed_items(feed)

    def get_number_of_unread(self, feed_id, cutoff):
        """Return the database's unread count for ``feed_id`` and ``cutoff``."""
        return self._db.get_number_of_unread(feed_id, cutoff)

    def get_image_counts(self):
        """Return the (image, count) pairs held by the database."""
        return self._db.get_image_counts()

    def set_image_count(self, image, count):
        """Queue storing ``count`` as the reference count of ``image``."""
        self._action_queue.append(
            ImageCountChangedAction(image, count))

    def start(self):
        """Register _run as a repeating timer with the main loop manager."""
        mlmgr = MainloopManager.get_instance()
        # NOTE(review): 5000 is presumably an interval in milliseconds --
        # confirm against MainloopManager.set_repeating_timer.
        mlmgr.set_repeating_timer(5000, self._run)

    def stop(self):
        """Stop the timer, write out pending actions and close the database."""
        mlmgr = MainloopManager.get_instance()
        mlmgr.end_repeating_timer(self._run)
        # Raise the flag before closing so that a timer callback arriving
        # late cannot touch the closed database.
        self._stop = True
        try:
            # Actions queued since the last tick would otherwise be lost.
            self._process_queue()
        finally:
            self._db.checkpoint()
            self._db.close()

    def _run(self):
        """Timer callback: apply the queued actions, then checkpoint."""
        if self._stop:
            return
        try:
            self._process_queue()
        finally:
            # Checkpoint on every tick, even if an action raised.
            self._db.checkpoint()

    def _process_queue(self):
        """Apply queued actions in order; a None entry ends this pass."""
        queue = self._action_queue
        while queue:
            action = queue.pop(0)
            if action is None:
                break
            action.doit(self._db)

itemstore_instance = None
def get_instance(straw_dir=None):
    """Return the module-wide ItemStore, creating it on the first call.

    straw_dir -- database home used when the store is created; defaults to
                 ".straw" in the user's home directory.  Ignored once the
                 store exists.
    """
    global itemstore_instance
    if itemstore_instance is None:
        if straw_dir is None:
            # expanduser honours HOME when it is set and has a fallback when
            # it is not, where os.getenv('HOME') would hand None to join().
            straw_dir = os.path.join(os.path.expanduser("~"), ".straw")
        itemstore_instance = ItemStore(straw_dir)
    return itemstore_instance

def stringify_item(item):
    """Return ``item`` pickled as a dictionary of its persistent fields.

    Every key mirrors the attribute of the same name, except 'images', which
    comes from item.image_keys(), and 'sticky', which comes from item._sticky.
    """
    keys = (
        'title', 'link', 'description', 'guid', 'guidislink', 'pub_date',
        'source', 'images', 'seen', 'id', 'fm_license', 'fm_changes',
        'creator', 'contributors', 'license_urls', 'publication_name',
        'publication_volume', 'publication_number', 'publication_section',
        'publication_starting_page', 'sticky', 'enclosures',
    )
    itemdict = {}
    for key in keys:
        if key == 'images':
            value = item.image_keys()
        elif key == 'sticky':
            value = item._sticky
        else:
            value = getattr(item, key)
        itemdict[key] = value
    return pickle.dumps(itemdict)

def unstringify_item(itemstring):
    """Rebuild a SummaryItem from a string produced by stringify_item.

    Returns None when ``itemstring`` is empty or unpickles to nothing.  Fields
    read with a default below may be absent from the stored dictionary; all
    others must be present.
    """
    if not itemstring:
        return None

    idict = _unpickle(itemstring)
    if not idict:
        return None

    item = SummaryItem.SummaryItem()
    for name in ('title', 'link', 'description', 'guid', 'pub_date', 'source'):
        setattr(item, name, idict[name])
    for image in idict['images']:
        item.restore_image(image)
    for name in ('seen', 'id'):
        setattr(item, name, idict[name])

    # (attribute name, dictionary key, default when the key is missing)
    optional = (
        ('guidislink', 'guidislink', True),
        ('fm_license', 'fm_license', None),
        ('fm_changes', 'fm_changes', None),
        ('creator', 'creator', None),
        ('contributors', 'contributors', None),
        ('license_urls', 'license_urls', None),
        ('_sticky', 'sticky', 0),
        ('enclosures', 'enclosures', None),
        ('publication_name', 'publication_name', None),
        ('publication_volume', 'publication_volume', None),
        ('publication_number', 'publication_number', None),
        ('publication_section', 'publication_section', None),
        ('publication_starting_page', 'publication_starting_page', None),
    )
    for attr, key, default in optional:
        setattr(item, attr, idict.get(key, default))
    return item

def _unpickle(istring):
    itemdict = None
    try:
        itemdict = pickle.loads(istring)
    except ValueError, ve:
        log("ItemStore.unstringify_item: pickle.loads raised ValueError, argument was %s" % repr(itemstring))
    except Exception, ex:
        logtb(str(ex))
    return itemdict

if __name__ == '__main__':
    # Debugging aid: open the database under $HOME/.straw and pretty-print
    # every record in it.
    from pprint import pprint
    straw_home = "%s/.straw" % os.getenv('HOME')
    db = MyDB("itemstore.db", straw_home, create = 1)
    db._db_print(db._data_print)


# syntax highlighted by Code2HTML, v. 0.9.1