;; mst-unit.lisp -- Unit testing for Common Lisp
;;
;; Author: Mark Triggs <mark@dishevelled.net>

;; Public interface of the unit-testing library.
(defpackage #:mst-unit
  (:use #:common-lisp)
  (:export
   ;; Defining, disabling and removing individual tests
   #:define-test #:define-tests #:disable-test #:undefine-test #:remove-test
   ;; Running tests
   #:run-group-tests #:run-groups #:run-all-tests #:compile-all-tests
   ;; Managing test groups
   #:test-groups #:clear-group-tests #:clear-all-tests #:remove-test-group
   ;; Alternative spelling of the :RETURNS key accepted by DEFINE-TEST
   #:=>))

;; Advertise the library so client code can test for it with #+mst-unit.
;; PUSHNEW (rather than PUSH) keeps *FEATURES* free of duplicate entries when
;; this file is loaded more than once into the same image.
(pushnew :mst-unit *features*)

;; Everything defined below this line is interned in the MST-UNIT package.
(in-package "MST-UNIT")

(defvar *tests* (make-hash-table :test #'equalp)
  "Registry of all defined unit tests.
Maps a test group name to the list of UNIT-TEST objects in that group, most
recently added first.

The table compares keys with EQUALP so that group names given as strings
work (two equal string literals are generally not EQ), mirroring the EQUALP
comparison used for test identifiers.  Symbols compare exactly as under EQ.")

(defmacro required (name)
  "Expand into a form that signals an error saying that NAME is a required
slot.  NAME is not evaluated.  Used as a slot :INITFORM so that leaving out
the corresponding initarg fails when the instance is created."
  (list 'error "~A is a required slot." (list 'quote name)))

(defclass unit-test ()
  ((identifier
    :accessor test-identifier
    :initarg :identifier
    :initform (required identifier)
    :documentation "Name of the test within its group.  Must be supplied.")
   (active
    :initform t
    :documentation "Non-nil while the test is enabled.")
   (expected-return
    :initarg :expected-return
    :documentation "List of the values the test is expected to return.")
   (return-predicate
    :initarg :return-predicate
    :documentation "Optional function applied to the actual return values;
when non-nil it decides success instead of EXPECTED-RETURN.")
   (expected-error
    :initarg :expected-error
    :documentation "Condition (or condition type) the test is expected to
signal, or NIL if none.")
   (test-fn
    :initarg :fn
    :documentation "Function called with no arguments to run the test.")
   (test
    :initarg :test
    :documentation "Function used to compare actual and expected values.")
   (description
    :initarg :description
    :documentation "Short human-readable description of the test."))
  (:documentation "A single unit test together with its expectations."))

(defclass test-result ()
  ((actual-return
    :initarg :actual-return
    :documentation "List of the values the test function returned.")
   (actual-error
    :initarg :actual-error
    :documentation "Condition signalled while running the test, or NIL."))
  (:documentation "What actually happened when a unit test was run."))

(defun test-active-p (test)
  "True when TEST is still enabled, i.e. its ACTIVE slot is non-nil."
  (with-slots (active) test
    active))

(defun remove-test-group (test-group)
  "Remove an entire group of tests.
The entry for TEST-GROUP is deleted from *TESTS* altogether, so the group no
longer shows up among the registered groups.  Removing a group that does not
exist is a no-op.  Always returns NIL."
  ;; REMHASH, unlike storing NIL under the key, really forgets the group.
  (remhash test-group *tests*)
  nil)

(defun remove-test (test-identifier test-group)
  "Drop from TEST-GROUP every test whose identifier is EQUALP to
TEST-IDENTIFIER.  Returns the list of tests that remain in the group."
  (let ((remaining (remove test-identifier (gethash test-group *tests*)
                           :key #'test-identifier
                           :test #'equalp)))
    (setf (gethash test-group *tests*) remaining)))

(defun %disable-test (test-identifier test-group)
  "Disable a unit test.
Marks the first test in TEST-GROUP whose identifier is EQUALP to
TEST-IDENTIFIER as inactive.  If there is no such test (or no such group)
nothing happens.  Always returns NIL."
  (let ((target (find test-identifier (gethash test-group *tests*)
                      :key #'test-identifier
                      :test #'equalp)))
    ;; Guard against a missing test: SLOT-VALUE on NIL would signal an error.
    (when target
      (setf (slot-value target 'active) nil))))

(defun add-test (unit-test test-group)
  "Register UNIT-TEST at the front of TEST-GROUP's list of tests.
Returns the updated list."
  (let ((existing (gethash test-group *tests*)))
    (setf (gethash test-group *tests*)
          (cons unit-test existing))))

(defmacro disable-test (test-identifier test-group &body ignored)
  "Disable the test TEST-IDENTIFIER in TEST-GROUP.
Accepts the same argument shape as DEFINE-TEST, so a test can be switched off
by changing only the macro name; everything after TEST-GROUP is ignored.

Like DEFINE-TEST, TEST-IDENTIFIER and TEST-GROUP are not evaluated.  For
backward compatibility an explicitly quoted argument such as 'FOO is passed
through unchanged rather than being quoted a second time."
  (declare (ignore ignored))
  (flet ((literal (form)
           ;; Leave (QUOTE X) alone, quote anything else.
           (if (and (consp form) (eq (car form) 'quote))
               form
               `',form)))
    `(%disable-test ,(literal test-identifier) ,(literal test-group))))

(defmacro undefine-test (test-identifier test-group &body ignored)
  "Remove the test TEST-IDENTIFIER from TEST-GROUP.
Accepts the same argument shape as DEFINE-TEST, so a test can be removed by
changing only the macro name; everything after TEST-GROUP is ignored.

Like DEFINE-TEST, TEST-IDENTIFIER and TEST-GROUP are not evaluated.  For
backward compatibility an explicitly quoted argument such as 'FOO is passed
through unchanged rather than being quoted a second time."
  (declare (ignore ignored))
  (flet ((literal (form)
           ;; Leave (QUOTE X) alone, quote anything else.
           (if (and (consp form) (eq (car form) 'quote))
               form
               `',form)))
    `(remove-test ,(literal test-identifier) ,(literal test-group))))

(defmacro define-test (test-identifier test-group test-fn
                       &rest rest
                       &key returns return-predicate
                       raises description satisfies (test '#'equal)
                       &allow-other-keys)
  "Define a unit test, replacing any test with the same identifier in the
same group.

TEST-IDENTIFIER and TEST-GROUP name the test; neither is evaluated.

TEST-FN is what gets run.  A function object, a symbol, or a (FUNCTION ...)
or (QUOTE ...) form is used as it stands; any other form is wrapped in a
zero-argument lambda and evaluated each time the test runs.

RETURNS (also spelled =>) is a form yielding the expected return value(s);
multiple values are supported.  When both spellings are given and the =>
form is non-nil, => wins.

RETURN-PREDICATE (or, if that is nil, SATISFIES) is a function applied to
the values returned by TEST-FN; it should return non-nil when they are
acceptable.  When a predicate is present it takes precedence over RETURNS.

TEST is the function used to compare actual and expected values; it
defaults to EQUAL.

RAISES is a form evaluating to the type of condition TEST-FN is expected to
signal.

DESCRIPTION is a string briefly describing the test.  It is stored but
otherwise currently ignored."
  (let* ((instance (gensym))
         (arrow-form (getf rest '=>))
         ;; Things that already denote a function are passed through as-is.
         (used-as-is-p (or (functionp test-fn)
                           (symbolp test-fn)
                           (and (consp test-fn)
                                (member (car test-fn) '(function quote)))))
         (fn-form (if used-as-is-p
                      test-fn
                      (list 'lambda '() test-fn)))
         (expected-form (list 'multiple-value-list (or arrow-form returns))))
    `(let ((,instance (make-instance 'unit-test
                                     :identifier ',test-identifier
                                     :fn ,fn-form
                                     :test ,test
                                     :expected-return ,expected-form
                                     :return-predicate (or ,return-predicate
                                                           ,satisfies)
                                     :expected-error ,raises
                                     :description ,description)))
       ;; Redefinition: throw away any earlier test with this identifier.
       (remove-test ',test-identifier ',test-group)
       (add-test ,instance ',test-group))))


(defmacro define-tests (group &body test-forms)
  "Define several unit tests in GROUP at once.
Every element of TEST-FORMS has the shape (IDENTIFIER . DEFINE-TEST-ARGS) and
is rewritten as (DEFINE-TEST IDENTIFIER GROUP . DEFINE-TEST-ARGS); the
resulting forms are wrapped in a single PROGN.

E.g.
  (define-tests :fact
    (:zero (fact 0) :returns 1)
    (:one  (fact 1) :returns 1)
    (:five (fact 5) :returns 120))"
  (cons 'progn
        (loop for test-form in test-forms
              collect (list* 'define-test
                             (first test-form)
                             group
                             (rest test-form)))))


(defun unit-test-passes-p (unit-test)
  "Run UNIT-TEST and return two values: T if it passed (NIL otherwise), and a
TEST-RESULT recording the values it returned and the condition it signalled.

A test passes when both of the following hold:

 * The condition signalled (NIL if none) is EQUALP to the expected error or
   is of the expected error type.
 * If the test has a return predicate, applying it to the returned values
   yields non-nil without signalling an error.  Otherwise every expected
   value must match the corresponding returned value under the test's
   comparison function.  Extra returned values are ignored; returned values
   that are missing are treated as NIL, so a test that returns fewer values
   than expected cannot pass by accident.  An error signalled by the
   comparison function counts as a failure.

When the test signals a condition its return value is recorded as (NIL)."
  (with-slots (expected-return expected-error return-predicate test-fn test)
      unit-test
    (multiple-value-bind (returned condition)
        (handler-case (multiple-value-list (funcall test-fn))
          ;; Every condition signalled by the test is captured here.
          (condition (c) (values nil c)))
      (let* ((return-val (if condition (list nil) returned))
             (result (make-instance 'test-result
                                    :actual-return return-val
                                    :actual-error condition)))
        (values
         (and (or (equalp condition expected-error)
                  (typep condition expected-error))
              (if return-predicate
                  (multiple-value-bind (ok failure)
                      (ignore-errors (apply return-predicate return-val))
                    (and ok (not failure)))
                  ;; Walk the EXPECTED values so that none is skipped when
                  ;; the test returned too few; (FIRST NIL) supplies NIL for
                  ;; the missing ones.
                  (ignore-errors
                    (loop for expected in expected-return
                          for actual-tail = return-val then (rest actual-tail)
                          always (funcall test (first actual-tail) expected)))))
         result)))))

(defun run-groups (test-groups &rest keys)
  "Run the tests of each group in the sequence TEST-GROUPS, in order.
KEYS are passed on unchanged to RUN-GROUP-TESTS.  Returns a property list
(:PASSED n :FAILED m) with the totals over all the groups."
  (let ((passed 0)
        (failed 0))
    (map nil
         #'(lambda (group)
             (let ((result (apply #'run-group-tests group keys)))
               (setf passed (+ (getf result :passed) passed)
                     failed (+ (getf result :failed) failed))))
         test-groups)
    (list :passed passed :failed failed)))


(defun run-group-tests (test-group &key (show-passes nil) (show-fails t))
  "Run every unit test defined for TEST-GROUP, in the order they were added.
Disabled tests are counted but not run.  A report is printed for each passing
test when SHOW-PASSES is non-nil and for each failing test when SHOW-FAILS is
non-nil.  Returns a property list
(:PASSED n :FAILED n :DISABLED n :TIME seconds)."
  (let ((began (get-internal-real-time))
        (pass-count 0)
        (fail-count 0)
        (disabled-count 0))
    ;; Tests are stored newest first, hence the REVERSE.
    (loop for test in (reverse (gethash test-group *tests*))
          do (if (not (test-active-p test))
                 (incf disabled-count)
                 (multiple-value-bind (passed results)
                     (unit-test-passes-p test)
                   (when (if passed show-passes show-fails)
                     (test-report test-group passed test results))
                   (if passed
                       (incf pass-count)
                       (incf fail-count)))))
    (let ((elapsed (- (get-internal-real-time) began)))
      (list :passed pass-count
            :failed fail-count
            :disabled disabled-count
            :time (float (/ elapsed internal-time-units-per-second))))))

(defun format-multiple-values (values)
  "Return a string describing VALUES, a list of return values.
The elements of a non-empty list are printed readably and separated
by \"; \".  Anything that is not a cons (including the empty list) is printed
as a single object."
  (let ((control (if (consp values)
                     "~{~S~^; ~}"
                     "~A")))
    (format nil control values)))

(defun test-report (group status unit-test test-result)
  "Pretty-print the outcome of running UNIT-TEST to standard output.
GROUP is the name of the test's group, STATUS is non-nil if the test passed,
and TEST-RESULT holds what the test actually returned and signalled.

After the PASSED/FAILED line:

 * If the test has a return predicate, the returned values are always shown.
   Otherwise they are shown, together with the expected values, only when
   some expected value is not matched by the corresponding returned value
   under the test's comparison function.  A missing returned value is treated
   as NIL, and an error signalled by the comparison counts as a mismatch.
 * The signalled condition is shown unless it is EQUALP to the expected
   error or is of the expected error type."
  (with-slots (identifier return-predicate
                          expected-return expected-error test) unit-test
    (with-slots (actual-return actual-error) test-result
      (format t "  Test: ~A:~A~Vt ~A~%"
              group
              identifier 70
              (if status "PASSED" "FAILED"))
      (cond (return-predicate
             (format t "    Returned: ~A~%"
                     (format-multiple-values actual-return)))
            ;; Compare value by value: TEST expects individual values, not
            ;; the lists holding them.
            ((not (ignore-errors
                    (loop for expected in expected-return
                          for actual-tail = actual-return
                            then (rest actual-tail)
                          always (funcall test (first actual-tail) expected))))
             (format t "    Returned: ~A (Expected: ~A)~%"
                     (format-multiple-values actual-return)
                     (format-multiple-values expected-return))))
      ;; An expected error may be given as a type, so check TYPEP as well.
      (unless (or (equalp actual-error expected-error)
                  (typep actual-error expected-error))
        (format t "    Raised: ~A~&            (Expected: ~A)~%"
                actual-error expected-error)))))

(defun run-all-tests (&key (show-passes nil) (show-fails t))
  "Run the unit tests of every group that has any.
SHOW-PASSES and SHOW-FAILS are handed to RUN-GROUP-TESTS for each group.
Returns a property list (:PASSED n :FAILED n :TIME seconds), followed by
:DISABLED n only when at least one test is disabled."
  (let ((began (get-internal-real-time))
        (pass-total 0)
        (fail-total 0)
        (disabled-total 0))
    (loop for group being the hash-keys of *tests* using (hash-value tests)
          unless (null tests)
            do (let ((summary (run-group-tests group
                                               :show-passes show-passes
                                               :show-fails show-fails)))
                 (incf pass-total (getf summary :passed))
                 (incf fail-total (getf summary :failed))
                 (incf disabled-total (getf summary :disabled))))
    (let ((elapsed (- (get-internal-real-time) began)))
      (append (list :passed pass-total
                    :failed fail-total
                    :time (float (/ elapsed internal-time-units-per-second)))
              (when (> disabled-total 0)
                (list :disabled disabled-total))))))

(defun clear-group-tests (test-group)
  "Forget every unit test defined for TEST-GROUP.
The group's entry in *TESTS* is kept but reset to the empty list.
Returns NIL."
  (let ((no-tests '()))
    (setf (gethash test-group *tests*) no-tests)))

(defun compile-all-tests ()
  "Compile the test function of every defined unit test.
Each test function that is a function object and is not already compiled is
replaced by its compiled equivalent.  Test functions given as symbols are
left untouched: a symbol is not a valid definition for COMPILE, and it is
resolved to the named function whenever the test is run.  Returns NIL."
  (maphash #'(lambda (test-group tests)
               (declare (ignore test-group))
               (dolist (test tests)
                 (with-slots (test-fn) test
                   (when (and (functionp test-fn)
                              (not (compiled-function-p test-fn)))
                     (setq test-fn (compile nil test-fn))))))
           *tests*))

(defun clear-all-tests ()
  "Clear all defined unit tests.
*TESTS* is replaced by a fresh, empty hash table, which is returned.  The new
table uses the same equality test as the one it replaces, so the way group
names are compared is decided in exactly one place: where *TESTS* is first
created."
  (setq *tests* (make-hash-table :test (hash-table-test *tests*))))

(defun test-groups ()
  "Return a list of the names of all test groups present in *TESTS*.
The order of the list is unspecified."
  (let ((names '()))
    (loop for group being the hash-keys of *tests*
          do (push group names))
    names))
