{-# OPTIONS_GHC -fno-warn-orphans #-}

module Data.GenValidity.Tree (genTreeOf, shrinkTreeOf) where

import Data.GenValidity
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as NE
import Data.Tree
import Data.Validity.Tree ()
import Test.QuickCheck

-- | Valid trees of valid values.
--
-- Every node label is generated with the element type's own 'genValid'
-- (via 'genTreeOf'), and shrinking applies the element type's own
-- 'shrinkValid' to labels (via 'shrinkTreeOf').
instance GenValid a => GenValid (Tree a) where
  shrinkValid = shrinkTreeOf shrinkValid
  genValid = genTreeOf genValid

-- | Shrink a tree, using the given function to shrink individual node labels.
--
-- Candidates are produced in this order:
--
-- 1. every immediate subtree on its own. This is the most aggressive shrink,
--    following QuickCheck's advice to shrink a term to its immediate
--    subterms, and it matches QuickCheck's own 'Tree' shrinker.
-- 2. the candidates that 'shrinkTuple' derives from the pair of root label
--    and child list. The label is shrunk with the given function, and the
--    child list with 'shrinkList', whose elements are in turn shrunk
--    recursively with 'shrinkTreeOf'.
shrinkTreeOf :: (a -> [a]) -> Tree a -> [Tree a]
shrinkTreeOf shrinkLabel (Node label children) =
  -- A subtree always has strictly fewer nodes than its parent, so offering
  -- the children first can never make shrinking loop.
  children
    ++ [ Node label' children'
       | (label', children') <-
           shrinkTuple
             shrinkLabel
             (shrinkList (shrinkTreeOf shrinkLabel))
             (label, children)
       ]

-- | Generate a tree whose node labels are all drawn from the given generator.
--
-- The tree has exactly as many nodes as the non-empty list drawn with
-- 'genNonEmptyOf'. The size parameter therefore bounds the tree as a whole,
-- to the extent that 'genNonEmptyOf' respects it, rather than only its depth
-- or its branching.
--
-- At each level, the elements after the root are cut into random groups.
-- Each group becomes one child subtree, with the group's first element as
-- that subtree's root.
genTreeOf :: Gen a -> Gen (Tree a)
genTreeOf gen = genNonEmptyOf gen >>= buildTree
  where
    -- The head becomes the root label. Every group of the remaining
    -- elements is built recursively into one child.
    buildTree :: NonEmpty a -> Gen (Tree a)
    buildTree (root :| rest) =
      splitIntoGroups rest >>= \groups ->
        mapM buildTree groups >>= pure . Node root

    -- Randomly partition a list into non-empty groups, keeping every
    -- element exactly once.
    splitIntoGroups :: [a] -> Gen [NonEmpty a]
    splitIntoGroups = step []
      where
        -- 'pending' holds the elements gathered for the current group so far,
        -- most recently seen first.
        step :: [a] -> [a] -> Gen [NonEmpty a]
        step pending [] = pure (maybe [] (: []) (NE.nonEmpty pending))
        step pending (x : xs) =
          frequency
            [ -- weight 1: close the current group, with x at its head
              (1, step [] xs >>= pure . ((x :| pending) :)),
              -- weight 4: add x to the current group and keep going
              (4, step (x : pending) xs)
            ]