#!/usr/bin/env python
#
# Copyright (c) 2006, 2007 Canonical
#
# Written by Gustavo Niemeyer <gustavo@niemeyer.net>
#
# This file is part of Storm Object Relational Mapper.
#
# Storm is free software; you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation; either version 2.1 of
# the License, or (at your option) any later version.
#
# Storm is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
import optparse
import unittest
import doctest
import new
import sys
import os

import tests


def disable_conftest():
    """Shadow py.test's tests/conftest.py with an empty placeholder module.

    Runners other than py.test walk the same test tree and must not pick up
    the py.test-specific conftest.py. An empty module is registered under
    the name "tests.conftest", both in sys.modules and as an attribute of
    the tests package. Any later lookup of that name then finds the
    placeholder instead of the real file.
    """
    placeholder = new.module("conftest")
    # Mimic the real file's location for code that inspects __file__.
    placeholder.__file__ = "tests/conftest.py"
    sys.modules["tests.conftest"] = placeholder
    tests.conftest = placeholder

def test_with_trial():
    """Run the suite under Twisted's trial runner.

    If sys.argv holds no positional (non-option) arguments, every .py file
    below the "tests" directory is appended to sys.argv as a test target.
    That directory is resolved relative to the current working directory.
    """
    from twisted.scripts import trial
    disable_conftest()
    positional = [arg for arg in sys.argv[1:] if not arg.startswith("-")]
    if not positional:
        for dirpath, subdirs, names in os.walk('tests'):
            sys.argv.extend([os.path.join(dirpath, name)
                             for name in names if name.endswith('.py')])
    trial.run()

def test_with_py_test():
    """Run the suite under py.test.

    If sys.argv holds no positional (non-option) arguments, this script's
    sibling "tests/" and "storm/" directories are passed as targets.
    "storm/" is included so py.test's looping mode also checks the library
    sources' timestamps.
    """
    import py
    basedir = os.path.dirname(__file__)
    for arg in sys.argv[1:]:
        if not arg.startswith("-"):
            # An explicit target was given; leave sys.argv untouched.
            break
    else:
        sys.argv.extend([os.path.join(basedir, "tests/"),
                         os.path.join(basedir, "storm/")])
    py.test.cmdline.main()

def test_with_unittest():

    usage = "test.py [options] [<test filename>, ...]"

    parser = optparse.OptionParser(usage=usage)

    parser.add_option('--verbose', action='store_true')
    opts, args = parser.parse_args()
    opts.args = args

    disable_conftest()

    runner = unittest.TextTestRunner()

    if opts.verbose:
        runner.verbosity = 2
        
    loader = unittest.TestLoader()
    topdir = os.path.abspath(os.path.dirname(__file__))
    testdir = os.path.dirname(tests.__file__)
    doctest_flags = doctest.ELLIPSIS
    unittests = []
    doctests = []
    for root, dirnames, filenames in os.walk(testdir):
        for filename in filenames:
            filepath = os.path.join(root, filename)
            relpath = filepath[len(topdir)+1:]
            if (filename == "__init__.py" or filename.endswith(".pyc") or
                opts.args and relpath not in opts.args):
                pass
            elif filename.endswith(".py"):
                unittests.append(relpath)
            elif filename.endswith(".txt"):
                doctests.append(relpath)

    class Summary:
        def __init__(self):
            self.total_failures = 0
            self.total_errors = 0
            self.total_tests = 0
        def __call__(self, tests, failures, errors):
            self.total_tests += tests
            self.total_failures += failures
            self.total_errors += errors
            print "(tests=%d, failures=%d, errors=%d)" % \
                  (tests, failures, errors)

    unittest_summary = Summary()
    doctest_summary = Summary()

    if unittests:
        print "Running unittests..."
        for relpath in unittests:
            print "[%s]" % relpath
            modpath = relpath.replace('/', '.')[:-3]
            module = __import__(modpath, None, None, [""])
            test = loader.loadTestsFromModule(module)
            result = runner.run(test)
            unittest_summary(test.countTestCases(),
                             len(result.failures), len(result.errors))
            print

    if doctests:
        print "Running doctests..."
        for relpath in doctests:
            print "[%s]" % relpath
            failures, total = doctest.testfile(relpath,
                                               optionflags=doctest_flags)
            doctest_summary(total, failures, 0)
            print

    print "Total test cases: %d" % unittest_summary.total_tests
    print "Total doctests: %d" % doctest_summary.total_tests
    print "Total failures: %d" % (unittest_summary.total_failures +
                                  doctest_summary.total_failures)
    print "Total errors: %d" % (unittest_summary.total_errors +
                                doctest_summary.total_errors)

    failed = bool(unittest_summary.total_failures or
                  unittest_summary.total_errors or
                  doctest_summary.total_failures or
                  doctest_summary.total_errors)

    sys.exit(failed)

if __name__ == "__main__":
    # Select the runner via $STORM_TEST_RUNNER; it defaults to "unittest"
    # when unset or empty. Dots map to underscores, so "py.test" selects
    # test_with_py_test.
    runner = os.environ.get("STORM_TEST_RUNNER") or "unittest"
    runner_func = globals().get("test_with_" + runner.replace(".", "_"))
    if runner_func:
        runner_func()
    else:
        sys.exit("Test runner not found: %s" % runner)

# vim:ts=4:sw=4:et
