"""
Sanity checks for the humongous ball of dragonspit that is
ubuntuone.syncdaemon.states
"""
import re
import unittest
from ubuntuone.syncdaemon import states, event_queue


# Every SyncDaemonState instance found in the states module's namespace,
# keyed by the name it is bound to there.
all_states = dict((k, v) for (k, v) in vars(states).items()
                  if isinstance(v, states.SyncDaemonState))
# The set of every transition target mentioned by any state. The targets
# are iterated directly instead of being concatenated with sum(..., []):
# that form copies the accumulated list at every step and only works
# where dict.values() returns a real list.
all_trns = set(target
               for state in all_states.values()
               for target in state.transitions.values())

# Search function for names ending in _WITH_<something>Q; group 1 of the
# resulting match is the queue tag (for instance METAQ).
has_q = re.compile(r'.*_WITH_([^_]+Q)$').search

class TestBasic(unittest.TestCase):
    """
    Checks that apply to (nearly) all states

    Every check reads the state under test from two attributes that the
    concrete test case has to provide: name (the state's name) and
    state (the state object itself).
    """

    # Expected queue bookkeeping. The outer key is the (lower-cased)
    # queue tag carried by the current state's name, '' when the name
    # has no _WITH_xxxQ suffix; the inner mapping gives, per queue
    # event, the tag the transition target is expected to carry.
    _OK_QUEUE_TRNS = {
        'metaq': {'SYS_CONTENT_QUEUE_DONE': 'metaq',
                  'SYS_CONTENT_QUEUE_WAITING': 'bothq',
                  'SYS_META_QUEUE_DONE': '',
                  'SYS_META_QUEUE_WAITING': 'metaq'},
        'contq': {'SYS_CONTENT_QUEUE_DONE': '',
                  'SYS_CONTENT_QUEUE_WAITING': 'contq',
                  'SYS_META_QUEUE_DONE': 'contq',
                  'SYS_META_QUEUE_WAITING': 'bothq'},
        'bothq': {'SYS_CONTENT_QUEUE_DONE': 'metaq',
                  'SYS_CONTENT_QUEUE_WAITING': 'bothq',
                  'SYS_META_QUEUE_DONE': 'contq',
                  'SYS_META_QUEUE_WAITING': 'bothq'},
        '': {'SYS_CONTENT_QUEUE_DONE': '',
             'SYS_CONTENT_QUEUE_WAITING': 'contq',
             'SYS_META_QUEUE_DONE': '',
             'SYS_META_QUEUE_WAITING': 'metaq'},
        }

    @staticmethod
    def _queue_tag(state_name):
        """
        Return the lower-cased queue tag of a state name, '' if it has none
        """
        match = has_q(state_name)
        if match is None:
            return ''
        # The pattern has a single group that is never empty when it
        # matches, so group(1) is the tag. Indexing the result of
        # filter() would only work where filter() returns a sequence.
        return match.group(1).lower()

    def test_name(self):
        """
        Check that the object's name is the same as the name of the object
        """
        self.assertEqual(self.name, self.state.name)

    def _test_handle_external_events(self, external_event):
        """
        Generic check, that the given external event is handled in some way

        'Handled' means the event is a key of the state's transitions.
        """
        self.assertTrue(external_event in self.state.transitions,
                        "unhandled external event: %s" % external_event)

    def test_handle_sys_net_disconnected(self):
        """
        Check SYS_NET_DISCONNECTED is handled
        """
        self._test_handle_external_events('SYS_NET_DISCONNECTED')

    def test_handle_sys_net_connected(self):
        """
        Check SYS_NET_CONNECTED is handled
        """
        self._test_handle_external_events('SYS_NET_CONNECTED')

    def test_handle_sys_connect(self):
        """
        Check SYS_CONNECT is handled
        """
        self._test_handle_external_events('SYS_CONNECT')

    def test_handle_sys_disconnect(self):
        """
        Check SYS_DISCONNECT is handled
        """
        self._test_handle_external_events('SYS_DISCONNECT')

    def test_handle_sys_connection_lost(self):
        """
        Check SYS_CONNECTION_LOST is handled
        """
        self._test_handle_external_events('SYS_CONNECTION_LOST')

    def test_reachable(self):
        """
        Check the state is reachable via a transition
        """
        self.assertTrue(self.name in all_trns,
                        "%s is not the target of any transition!" % self.name)

    def test_start_connecting(self):
        """
        Check START_CONNECTING is only reached in the permitted ways

        A transition whose target starts with START_CONNECTING is
        accepted when the event is SYS_CONNECTION_LOST, when this
        state's name contains WITH_CONNECTION_LOST, or when it is one
        of the specific state/event pairs listed in the body. Anything
        else raises AssertionError.
        """
        for evt, trn in self.state.transitions.items():
            if not trn.startswith('START_CONNECTING'):
                continue
            if (evt == 'SYS_CONNECTION_LOST'
                    or 'WITH_CONNECTION_LOST' in self.name):
                continue
            # the remaining permitted (state prefix, event) pairs
            if ((self.name.startswith('READY_WAITING')
                 and evt == 'SYS_NET_CONNECTED')
                or (self.name.startswith('READY_WITH_NETWORK')
                    and evt == 'SYS_CONNECT')
                or (self.name.startswith('READING_WAITING_WITH_NETWORK')
                    and evt == 'SYS_LOCAL_RESCAN_DONE')):
                continue
            raise AssertionError('%s --[%s]--> %s, not STANDOFFish'
                                 % (self.name, evt, trn))

    def test_events_exist(self):
        """
        Check the events are known to EventQueue
        """
        diff = set(self.state.transitions) - set(event_queue.EVENTS)
        self.assertFalse(diff, "unknown events: %s" % ", ".join(diff))

    def test_transitions_exist(self):
        """
        Check the target states are known
        """
        diff = set(self.state.transitions.values()) - set(all_states)
        self.assertFalse(diff, "unknown transitions: %s" % ", ".join(diff))

    def test_enter_only_on_start(self):
        """
        Only START_ states (and UNKNOWN_ERROR) should have an enter function
        """
        exempt = (self.name == 'UNKNOWN_ERROR'
                  or self.name.startswith('START_'))
        if not exempt and self.state.enter is not None:
            raise AssertionError("%s has enter" % self.name)

    def _test_withq_evt_ok(self, evt, exceptions):
        """
        Generic check that the different WITH_xxxQ states are sane

        'sane' here means that they follow the general pattern (seen
        in _OK_QUEUE_TRNS), i.e. that the different queue events move
        the state correctly. There are exceptions, which can be
        specified as (state name, target name) pairs.

        Nothing is checked when the state does not handle evt.
        Raises AssertionError when the target's queue tag is not the
        expected one and the pair is not listed in exceptions.
        """
        if evt not in self.state.transitions:
            return

        trn = self.state.transitions[evt]
        this_q = self._queue_tag(self.name)
        that_q = self._queue_tag(trn)
        expected_q = self._OK_QUEUE_TRNS[this_q][evt]

        if expected_q != that_q:
            # a few exceptions exist
            if (self.name, trn) not in exceptions:
                raise AssertionError('%s --[%s]--> %s instead of %s'
                                     % (self.name, evt, trn, expected_q))

    def test_metaq_waiting_evt_ok(self):
        """
        Check for correct handling of SYS_META_QUEUE_WAITING
        """
        self._test_withq_evt_ok(
            'SYS_META_QUEUE_WAITING',
            [('IDLE', 'START_WORKING_ON_METADATA'),
             ('START_WORKING_ON_BOTH', 'WORKING_ON_BOTH'),
             ('START_WORKING_ON_CONTENT', 'START_WORKING_ON_BOTH'),
             ('START_WORKING_ON_METADATA_WITH_CONTQ',
                                        'WORKING_ON_METADATA_WITH_CONTQ'),
             ('START_WORKING_ON_METADATA', 'WORKING_ON_METADATA'),
             ('WORKING_ON_BOTH', 'WORKING_ON_BOTH'),
             ('WORKING_ON_CONTENT', 'START_WORKING_ON_BOTH'),
             ('WORKING_ON_METADATA_WITH_CONTQ',
                                        'WORKING_ON_METADATA_WITH_CONTQ'),
             ('WORKING_ON_METADATA', 'WORKING_ON_METADATA'),
             ])

    def test_metaq_done_evt_ok(self):
        """
        Check for correct handling of SYS_META_QUEUE_DONE
        """
        self._test_withq_evt_ok(
            'SYS_META_QUEUE_DONE',
            [('START_WORKING_ON_METADATA_WITH_CONTQ',
                                        'START_WORKING_ON_CONTENT'),
             ('WORKING_ON_METADATA_WITH_CONTQ', 'START_WORKING_ON_CONTENT'),
             ])

    def test_contq_waiting_ok(self):
        """
        Check for correct handling of SYS_CONTENT_QUEUE_WAITING
        """
        self._test_withq_evt_ok(
            'SYS_CONTENT_QUEUE_WAITING',
            [('IDLE', 'START_WORKING_ON_CONTENT'),
             ('START_WORKING_ON_BOTH', 'WORKING_ON_BOTH'),
             ('START_WORKING_ON_CONTENT', 'WORKING_ON_CONTENT'),
             ('WORKING_ON_BOTH', 'WORKING_ON_BOTH'),
             ('WORKING_ON_CONTENT', 'WORKING_ON_CONTENT'),
             ])

    def test_contq_done_ok(self):
        """
        Check for correct handling of SYS_CONTENT_QUEUE_DONE
        """
        self._test_withq_evt_ok('SYS_CONTENT_QUEUE_DONE', [])

    def test_non_network_transition_leaves_network_alone(self):
        """
        Check that the non-network events do not affect the network state

        network events are SYS_NET_CONNECTED and SYS_NET_DISCONNECTED.
        Transitions into a state whose is_error is true are exempt.
        """
        bad = []
        tpl = "%s.has_network is %s but -[%s]->%s.has_network is %s"
        for evt, trn in self.state.transitions.items():
            if evt in ('SYS_NET_CONNECTED', 'SYS_NET_DISCONNECTED'):
                continue
            trn = all_states[trn]
            if self.state.has_network != trn.has_network and not trn.is_error:
                bad.append(tpl % (self.state.name,
                                  self.state.has_network,
                                  evt, trn.name,
                                  trn.has_network))
        self.assertFalse(bad, ";\n".join(bad))

    def test_network_transition_changes_network_state(self):
        """
        Check that network events affect the network state

        The SYS_NET_CONNECTED target must have has_network set, unless
        the current state is an error state, in which case it must not.
        The SYS_NET_DISCONNECTED target must never have it set.
        """
        connected = self.state.transitions['SYS_NET_CONNECTED']
        disconnected = self.state.transitions['SYS_NET_DISCONNECTED']
        if not self.state.is_error:
            self.assertTrue(all_states[connected].has_network)
        else:
            self.assertFalse(all_states[connected].has_network)
        self.assertFalse(all_states[disconnected].has_network)

    def test_non_volitional_transition_leaves_volition_alone(self):
        """
        Check that non-volitional events leave the volition alone

        Volitional events are SYS_CONNECT and SYS_DISCONNECT.
        Transitions into a state whose is_error is true are exempt.
        """
        bad = []
        tpl = "%s.wants_to_connect is %s but -[%s]->%s.wants_to_connect is %s"
        for evt, trn in self.state.transitions.items():
            if evt in ('SYS_CONNECT', 'SYS_DISCONNECT'):
                continue
            trn = all_states[trn]
            if (self.state.wants_to_connect != trn.wants_to_connect
                    and not trn.is_error):
                bad.append(tpl % (self.state.name,
                                  self.state.wants_to_connect,
                                  evt, trn.name,
                                  trn.wants_to_connect))
        self.assertFalse(bad, ";\n".join(bad))


class TestWithQ:
    """
    Checks for states with a given with_q

    Meant to be mixed into a test case that provides the name and
    state attributes.
    """
    def get_thisq(self):
        """
        Get the queue of the current state (from the state's name)

        Returns the queue tag matched by has_q, or None when the name
        carries no such tag.
        """
        m = has_q(self.state.name)
        if m is None:
            return None
        # The pattern has exactly one group, never empty on a match, so
        # group(1) is the tag. Indexing the result of filter() would
        # only work where filter() returns a sequence.
        return m.group(1)

    def test_withq(self):
        """
        Test the queue specified via with_q and the state's name match
        """
        if self.state.with_q is None and has_q(self.name):
            raise AssertionError('%s should have with_q %s, has None'
                                 % (self.name, self.get_thisq()))
        elif self.state.with_q != self.get_thisq():
            raise AssertionError("%s's name imples with_q %s, has %s"
                                 % (self.name, self.get_thisq(),
                                    self.state.with_q))

    def test_another_withq(self):
        """
        Test that states with one of the with_q's have the others.

        Some exceptions exist.
        """
        this_q = self.get_thisq() or ''
        is_ok = ((this_q == '' and self.state.with_q is None)
                 or (this_q.upper() == self.state.with_q))
        if not is_ok:
            # NOTE(review): get_thisq() only yields '' or a tag ending
            # in 'Q', so the 'cont'/'meta' entries below look like they
            # can never match -- confirm whether they are still needed.
            if (self.name, this_q, self.state.with_q) not in [
                ('START_WORKING_ON_CONTENT', 'cont', None),
                ('START_WORKING_ON_METADATA', 'meta', None),
                ('WORKING_ON_CONTENT', 'cont', None),
                ('WORKING_ON_METADATA', 'meta', None),
                ]:
                raise AssertionError(
                    "%s's name implies %s, but attribute is %r"
                    % (self.name, this_q or 'noQ', self.state.with_q))


class TestCleanupWithConnectionLost:
    """
    Checks for CLEANUP states that have already seen SYS_CONNECTION_LOST
    """
    def test_cleanup_with_connection_lost(self):
        """
        CLEANUP_WITH_CONNECTION_LOST states have already gotten their
        CONNECTION_LOST, so they don't need to wait for it; they need
        to not wait for it, instead.
        """
        if self.name.startswith('START_'):
            name = self.name[6:]
        else:
            name = self.name
        target = self.state.transitions['SYS_CONNECTION_LOST']
        self.assertEqual(name, target)

class TestCleanupPlain:
    """
    Tests for CLEANUP states that have not yet seen SYS_CONNECTION_LOST
    """
    def test_cleanup_plain_on_connection_lost(self):
        """
        When the CONNECTION_LOST events come in, "plain" cleanup states
        need to go to the sibling CLEANUP_WITH_CONNECTION_LOST
        """
        evt = 'SYS_CONNECTION_LOST'
        trn = self.state.transitions[evt]
        expected = self.name
        # start transitions don't transition to self
        expected = re.sub(r'^START_', '', expected)
        expected = re.sub(r'^CONNECTED_CLEANUP',
                          'CLEANUP_WITH_NETWORK', expected)
        expected = re.sub(r'(.*?)((?:_WITH_(?:CONT|META|BOTH)Q)?)$',
                          r'\1_WITH_CONNECTION_LOST\2',
                          expected)
        self.assertEqual(trn, expected,
                         '%s --[%s]--> %s, should be %s'
                         % (self.name, evt, trn, expected))

class TestCleanupNonStarter:
    """
    Test for CLEANUP states that are not START_ states
    """
    def test_cleanup_on_cleanup_done(self):
        """
        Test non-START_ CLEANUP states handle SYS_CLEANUP_FINISHED
        appropriately
        """
        self.assertTrue('SYS_CLEANUP_FINISHED' in self.state.transitions,
                        '%s should handle SYS_CLEANUP_FINISHED' % self.name)
        evt = 'SYS_CLEANUP_FINISHED'
        trn = self.state.transitions[evt]
        self.assertTrue(trn.startswith('START_STANDOFF'),
                        '%s --[%s]--> %s, should be START_STANDOFFish'
                        % (self.name, evt, trn))

class TestStarter:
    """
    Tests for START_ states
    """
    def setUp(self):
        """
        Common set-up code
        """
        self.other_name = re.sub(r'^START_', '', self.name)
        self.other = all_states.get(self.other_name, None)
        if self.other is not None:
            self.this_trn = set(self.state.transitions)
            self.other_trn = set(self.other.transitions)
            self.only_in_this = self.this_trn - self.other_trn
            self.only_in_other = self.other_trn - self.this_trn \
                - set(['SYS_SERVER_RESCAN_DONE', 'SYS_CLEANUP_FINISHED'])

    def test_starters_have_enter(self):
        """
        All START_ states should have an enter function
        """
        self.assertTrue(self.state.enter is not None,
                        '%s should have an enter function' % self.name)

    def test_other_exists(self):
        """
        All START_ states should have a non-START_ state
        """
        self.assertTrue(self.other_name in all_states,
                        '%s has no non-starter' % self.name)

    def test_no_events_only_in_this(self):
        """
        All START_ states need to handle the same events as their
        non-START_ counterparts
        """
        if self.other is None:
            raise AssertionError('missing non-starter, unable to test')
        self.assertFalse(
            self.only_in_this,
            # an event in the START state is not handled by the non-START
            'missing from %s: %s' % (self.other_name,
                                     ", ".join(sorted(self.only_in_this))))

    def test_no_events_only_in_other(self):
        """
        in general, non-START_ states need to handle the same events
        as their START_ counterparts
        """
        if self.other is None:
            raise AssertionError('missing non-starter, unable to test')
        self.assertFalse(
          self.only_in_other,
          'missing from %s: %s' % (self.name,
                                   ", ".join(sorted(self.only_in_other))))

    def test_same_transitions(self):
        """
        in general, the transition dict for START_ and non-START_
        states should be the same
        """
        if self.other is None:
            raise AssertionError('missing non-starter, unable to test')
        bad = ['%s pushes %s --> %s, but %s --> %s'
               % (trn, self.name, self.state.transitions[trn],
                  self.other.name, self.other.transitions[trn])
               for trn in sorted(self.this_trn.intersection(self.other_trn))
               if (self.state.transitions[trn] != self.other.transitions[trn]
                   and not (self.name == self.state.transitions[trn] and
                            self.other.name == self.other.transitions[trn]))]
        self.assertFalse(bad, bad)


class TestOtherQ:
    """
    Mixin with the checks for states that have _WITH_fooQ siblings
    """
    def _test_other_queues(self, queue, exceptions):
        """
        Check the sibling named <this state>_WITH_<queue>.

        If the sibling does not exist, that is an error unless this
        state's name is among the given exceptions. If it does exist,
        it and this state have to handle exactly the same events;
        otherwise AssertionError names the state lacking events and
        the events it lacks.
        """
        sibling_name = self.name + '_WITH_' + queue
        if sibling_name not in all_states:
            if self.name in exceptions:
                return
            raise AssertionError('%s has some queues but not %s'
                                 % (self.name, queue))

        sibling = all_states[sibling_name]
        own_events = set(self.state.transitions)
        sibling_events = set(sibling.transitions)
        # first what the sibling lacks, then what this state lacks
        for lacking, missing in ((sibling_name, own_events - sibling_events),
                                 (self.name, sibling_events - own_events)):
            if missing:
                raise AssertionError('missing from %s: %s' %
                                     (lacking, ", ".join(sorted(missing))))

    def test_other_meta_queue(self):
        """
        The WITH_METAQ sibling of the current state should exist (a
        few states are excused) and, where it exists, handle the same
        events as the current state.
        """
        excused = ('START_CLEANUP_WAITING',
                   'START_CLEANUP_WAITING_WITH_NETWORK_WITH_CONNECTION_LOST',
                   'START_CONNECTED_CLEANUP',
                   'START_WORKING_ON_METADATA',
                   'WORKING_ON_METADATA')
        self._test_other_queues('METAQ', excused)

    def test_other_cont_queue(self):
        """
        The WITH_CONTQ sibling of the current state should exist (no
        state is excused) and, where it exists, handle the same events
        as the current state.
        """
        self._test_other_queues('CONTQ', ())

    def test_other_both_queue(self):
        """
        The WITH_BOTHQ sibling of the current state should exist (a
        few states are excused) and, where it exists, handle the same
        events as the current state.
        """
        excused = ('START_CLEANUP_WAITING',
                   'START_CLEANUP_WAITING_WITH_NETWORK_WITH_CONNECTION_LOST',
                   'START_CONNECTED_CLEANUP',
                   'START_WORKING_ON_METADATA',
                   'WORKING_ON_METADATA')
        self._test_other_queues('BOTHQ', excused)


class TestHandleQueues:
    """
    Tests for handling of queue events
    """
    def _test_handles_queue_event(self, queue_event):
        """
        Generic test that the current state handles the given queue event
        """
        self.assertTrue(queue_event in self.state.transitions,
                        self.name + ' does not handle ' + queue_event)

    def test_handles_meta_queue_waiting(self):
        """
        Check SYS_META_QUEUE_WAITING is handled
        """
        self._test_handles_queue_event('SYS_META_QUEUE_WAITING')

    def test_handles_content_queue_waiting(self):
        """
        Check SYS_CONTENT_QUEUE_WAITING is handled
        """
        self._test_handles_queue_event('SYS_CONTENT_QUEUE_WAITING')

    def test_handles_meta_queue_done(self):
        """
        Check SYS_META_QUEUE_DONE is handled
        """
        self._test_handles_queue_event('SYS_META_QUEUE_DONE')

    def test_handles_content_queue_done(self):
        """
        Check SYS_CONTENT_QUEUE_DONE is handled
        """
        self._test_handles_queue_event('SYS_CONTENT_QUEUE_DONE')


def test_suite():
    """
    Build the test suite

    One test case class is generated per entry of all_states. It
    always includes TestBasic; the other check classes are added
    according to the state's name and attributes. The generated class
    carries the state and its name as the state and name attributes.
    """
    missing = object()
    queue_tags = ('METAQ', 'CONTQ', 'BOTHQ')
    loader = unittest.TestLoader()
    suite = unittest.TestSuite()
    for name, state in all_states.items():
        mixins = [TestBasic]
        # the sentinel tells "no with_q attribute" from "with_q is None"
        if getattr(state, 'with_q', missing) is not missing:
            mixins.append(TestWithQ)
        if re.search('CLEANUP.*WITH_CONNECTION_LOST', name):
            mixins.append(TestCleanupWithConnectionLost)
        elif 'CLEANUP' in name:
            mixins.append(TestCleanupPlain)
            if name.startswith('CLEANUP'):
                mixins.append(TestCleanupNonStarter)
        # CONNECTED_CLEANUP doesn't exist;
        # we fake it with CLEANUP_WITH_NETWORK
        if (name.startswith('START_')
                and not name.startswith('START_CONNECTED_CLEANUP')):
            mixins.append(TestStarter)
        if any((name + '_WITH_' + tag in all_states) for tag in queue_tags):
            mixins.append(TestOtherQ)
        if (not isinstance(state, states.AQErrorState)
                and name != 'INIT'
                and not name.startswith('INIT_')):
            mixins.append(TestHandleQueues)
        # the class added last comes first in the bases, TestBasic last
        mixins.reverse()
        case = type('Test_'+name, tuple(mixins),
                    {'name': name, 'state': state})
        suite.addTest(loader.loadTestsFromTestCase(case))
    return suite
