# **************************************************************************
# *
# * Authors:     J.M. De la Rosa Trevin (delarosatrevin@scilifelab.se) [1]
# *              Grigory Sharov (gsharov@mrc-lmb.cam.ac.uk) [2]
# *
# * [1] SciLifeLab, Stockholm University
# * [2] MRC Laboratory of Molecular Biology (MRC-LMB)
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 3 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307  USA
# *
# *  All comments concerning this program package may be sent to the
# *  e-mail address 'delarosatrevin@scilifelab.se'
# *
# **************************************************************************

import os
import unittest

from emhub.data import DataManager, ImageSessionData, H5SessionData
from emhub.data.imports.testdata import TestData


class TestDataManager(unittest.TestCase):
    """Tests for DataManager queries, run against a newly created sqlite
    file that is populated through TestData.
    """

    @classmethod
    def setUpClass(cls):
        """Create the DataManager shared by all tests of this class."""
        sqlitePath = '/tmp/emhub.sqlite'

        # Remove any file left at this path so the DataManager does not
        # start from the content of a previous run
        if os.path.exists(sqlitePath):
            os.remove(sqlitePath)

        cls.dm = DataManager(sqlitePath)
        # populate db with test data
        TestData(cls.dm)

    def test_users(self):
        """Check that get_pi() resolves to the same PI on repeated calls."""
        print("=" * 80, "\nTesting users...")

        for u in self.dm.get_users():
            pi = u.get_pi()

            # Users for which get_pi() returns None are skipped; for the
            # others a second call must give a PI with the same id.
            if pi is not None:
                self.assertEqual(pi.id, u.get_pi().id)

    def test_applications(self):
        """Check the application codes and that every lab member reports
        the same applications as its PI.
        """
        print("=" * 80, "\nTesting applications...")
        applications = self.dm.get_applications()

        codes = [p.code for p in applications]

        self.assertEqual(codes, ['CEM00297', 'CEM00315', 'CEM00332', 'DBB00001',
                                 'CEM00345', 'CEM00346'])

        users = self.dm.get_users()
        uList = []
        # Maps PI id -> [pi, member, member, ...]; the PI is always first
        piDict = {}

        for u in users:
            if u.is_pi:
                # setdefault (instead of plain assignment) keeps the members
                # that were registered before the PI itself was visited
                piDict.setdefault(u.id, [u])
                uList.append("PI: %s, applications: %s" % (u.name,
                                                           u.created_applications))
                uList.append("   Lab members:")
                for u2 in u.lab_members:
                    uList.append("     - %s" % u2.name)
            else:
                pi = u.get_pi()
                if pi is not None:
                    # Do not rely on the PI being listed before its members:
                    # create the entry on demand with the PI in first place
                    piDict.setdefault(pi.id, [pi]).append(u)

        for line in uList:
            print(line)

        # Check that all users in the same lab have the same applications
        for members in piDict.values():
            pi = members[0]
            lab_users = members[1:]
            pRef = pi.get_applications()

            # Show both sides of the pi-lab_members relationship.
            # NOTE: the sets are only printed, the equality check below
            # is disabled.
            print(">>> Checking PI: ", pi.name)
            members_ids = {m.id for m in lab_users}
            lab_ids = {u.id for u in pi.lab_members}
            print("   members: ", members_ids)
            print("   lab_ids: ", lab_ids)
            # self.assertEqual(members_ids, lab_ids)

            # Check applications for lab_members
            for u in lab_users:
                self.assertEqual(pRef, u.get_applications())

    def test_count_booking_resources(self):
        """Check that count_booking_resources gives a non-empty result for
        all applications, with and without the 'krios' resource tag.
        """
        print("=" * 80, "\nTesting counting booking resources...")

        def print_count(count):
            # count is iterated as a dict of dicts, keyed by application id
            for a, count_dict in count.items():
                print("Application ID: ", a)
                for k, c in count_dict.items():
                    print("   %s: %s" % (k, c))

        applications = [a.id for a in self.dm.get_applications()]
        count_resources = self.dm.count_booking_resources(applications)
        print_count(count_resources)
        self.assertGreater(len(count_resources), 0)

        count_tags = self.dm.count_booking_resources(applications,
                                                     resource_tags=['krios'])
        self.assertGreater(len(count_tags), 0)
        print_count(count_tags)


class TestSessionData(unittest.TestCase):
    """Test writing session items into an H5SessionData file and reading
    them back.
    """

    def test_basic(self):
        """Copy every item of set 1 from ImageSessionData into an HDF5
        file, then reopen the file read-only and print the stored items.
        """
        setId = 1
        h5Path = '/tmp/data.h5'
        tsd = ImageSessionData()
        mics = tsd.get_items(setId)
        print("=" * 80, "\nTesting hdf5 session...")

        hsd = H5SessionData(h5Path, 'w')
        # try/finally so the file handle is closed even when one of the
        # calls below raises
        try:
            hsd.create_set(setId, label='Test set')

            for mic in mics:
                # Fetch the item again, this time requesting the extra
                # data attributes, and store all of its fields
                micData = tsd.get_item(setId, mic.id,
                                       dataAttrs=['micThumbData',
                                                  'psdData',
                                                  'shiftPlotData'])
                hsd.add_item(setId, itemId=mic.id, **micData._asdict())
        finally:
            hsd.close()

        # Reopen read-only and list what was written
        hsd = H5SessionData(h5Path, 'r')
        try:
            for mic in hsd.get_items(setId):
                print(mic)
        finally:
            # Close the read handle too
            hsd.close()
