------------------------------------------------------------------------------
--The files in this directory are based on the programs described in:
--
--    A Modular fully-lazy lambda lifter in Haskell
--    Simon L. Peyton Jones and David Lester
--    Software -- Practice and Experience
--    Vol 21(5), pp.479-506
--    MAY 1991
--
--See the Readme file for more details.
------------------------------------------------------------------------------

-- 5.4 A fully lazy lambda lifter

-- | The complete fully lazy lambda lifter, written out as a sequence of
-- passes.  Each stage feeds the next: lambdas are separated, the expression
-- is annotated with levels, candidate expressions are wrapped in lets,
-- binders are renamed, the lets are floated outwards, and finally
-- 'lambdaLift' (defined elsewhere) produces the supercombinator definitions.
fullyLazyLift :: Expression -> [SCDefn]
fullyLazyLift expr = lambdaLift floated
  where
    separated = separateLams expr
    levelled  = addLevels separated
    withMFEs  = identifyMFEs levelled
    renamed   = rename withMFEs
    floated   = float renamed

-- 5.5 Separating the lambdas

-- | Rewrite every multi-argument lambda as a chain of nested single-argument
-- lambdas, so that @ELam [x,y] e@ becomes @ELam [x] (ELam [y] e')@ where
-- @e'@ is the separated form of @e@.  All other constructs are rebuilt
-- unchanged around their recursively separated sub-expressions.
--
-- A lambda with an empty argument list yields just its (separated) body.
separateLams                 :: Expression -> Expression
separateLams (EVar v)         = EVar v
separateLams (EConst k)       = EConst k
separateLams (EAp e1 e2)      = EAp (separateLams e1) (separateLams e2)
separateLams (ELam args body) = foldr mkSingleLam (separateLams body) args
                                -- The helper must wrap the accumulated
                                -- inner expression; wrapping the original
                                -- body instead would lose every argument
                                -- except the first.
                                where mkSingleLam arg inner = ELam [arg] inner
separateLams (ELet isRec defns body)
                              = ELet isRec
                                     [(n,separateLams rhs)|(n,rhs)<-defns]
                                     (separateLams body)

-- 5.6 Adding level numbers

-- | Lambda-nesting depth used to annotate expressions and binders: constants
-- get level 0 and the arguments of a lambda get one more than the level at
-- which the lambda occurs (see 'freeToLevel_e').
type Level     = Int

-- | Annotate an expression with levels: every binder is paired with a
-- 'Level' and every sub-expression is tagged with one.  The free-variable
-- annotation from 'freeVars' is computed first and then turned into levels.
addLevels     :: Expression -> AnnExpr (Name,Level) Level
addLevels expr = freeToLevel (freeVars expr)

-- | Convert a free-variable annotated expression into a level annotated
-- one, starting at level 0 with no variables in scope.
freeToLevel   :: AnnExpr Name (Set Name) -> AnnExpr (Name,Level) Level
freeToLevel expr = freeToLevel_e topLevel emptyEnv expr
  where
    topLevel = 0
    emptyEnv = []

-- | The level of a set of free variables: the largest level that the
-- environment assigns to any member of the set, or 0 when the set is empty.
freeSetToLevel         :: Set Name -> Assn Name Level -> Level
freeSetToLevel free env = maximum (0 : levels)
  where
    levels = [assLookup env name | name <- setToList free]

-- | Worker for 'freeToLevel'.
--
-- Arguments: the level of the enclosing context, an environment mapping
-- each variable in scope to its level, and the expression annotated with
-- its free variables.  The result carries a level on every node and on
-- every binder.
freeToLevel_e :: Level
              -> Assn Name Level
              -> AnnExpr Name (Set Name)
              -> AnnExpr (Name,Level) Level

-- Constants are always at level 0.
freeToLevel_e _ _ (_, AConst k) = (0, AConst k)

-- A variable has whatever level the environment records for it.
freeToLevel_e _ env (_, AVar v) = (assLookup env v, AVar v)

-- An application is as deep as the deeper of its two parts.
freeToLevel_e lev env (_, AAp fun arg)
  = (max (levelOf fun') (levelOf arg'), AAp fun' arg')
  where
    fun' = freeToLevel_e lev env fun
    arg' = freeToLevel_e lev env arg

-- The level of a lambda comes from its free variables; its arguments are
-- bound one level deeper than the context, and the body is processed at
-- that deeper level.
freeToLevel_e lev env (free, ALam args body)
  = (freeSetToLevel free env, ALam levelledArgs levelledBody)
  where
    innerLevel   = lev + 1
    levelledArgs = [(arg, innerLevel) | arg <- args]
    levelledBody = freeToLevel_e innerLevel (levelledArgs ++ env) body

-- A let takes the level of its body; each binder takes the level of its
-- own right-hand side.
freeToLevel_e lev env (_, ALet isRec defns body)
  = (levelOf body', ALet isRec defns' body')
  where
    names      = bindersOf defns
    rhsFree    = setUnionList [fvs | (fvs, _) <- rhssOf defns]
    -- Level of the whole group of right-hand sides, computed with the
    -- group's own binders counted as level 0.
    groupLevel = freeSetToLevel rhsFree ([(name, 0) | name <- names] ++ env)
    -- Recursive right-hand sides see their own binders at the group level;
    -- non-recursive ones only see the outer environment.
    rhsEnv     = if isRec
                   then [(name, groupLevel) | name <- names] ++ env
                   else env
    defns'     = map levelDefn defns
    levelDefn (name, rhs) = let rhs' = freeToLevel_e lev rhsEnv rhs
                            in  ((name, levelOf rhs'), rhs')
    body'      = freeToLevel_e lev (bindersOf defns' ++ env) body

-- | Extract the level annotation sitting at the root of an annotated
-- expression (the first component of the annotation pair).
levelOf           :: AnnExpr a Level -> Level
levelOf = fst

-- 5.7 Identifying MFEs

-- | Wrap candidate expressions in let-bindings, starting from a context
-- level of 0, and strip the per-node level annotations in the process.
identifyMFEs :: AnnExpr (Name,Level) Level -> Expr (Name,Level)
identifyMFEs expr = identifyMFEs_e 0 expr

-- | True for the expression forms that are never wrapped in a let by
-- 'identifyMFEs_e': constants and variables.  Every other form is a
-- candidate.
notMFECandidate e = case e of
  AConst _ -> True
  AVar _   -> True
  _        -> False

-- | Process one annotated expression relative to the level @cxt@ of its
-- enclosing context.  The sub-expressions are always processed; the
-- expression itself is additionally wrapped by 'transformMFE' when its
-- level differs from the context's and it is a candidate.
identifyMFEs_e :: Level -> AnnExpr (Name,Level) Level -> Expr (Name,Level)
identifyMFEs_e cxt (level, e) =
  if level == cxt || notMFECandidate e
    then inner
    else transformMFE level inner
  where
    inner = identifyMFEs_e1 level e

-- | Wrap an expression in a non-recursive let that binds it to the
-- variable @"v"@ (annotated with the given level) and whose body is just
-- that variable.  The same name is used every time; presumably the later
-- 'rename' pass is what makes these binders distinct.
transformMFE level e = ELet nonRecursive [(binder, e)] (EVar mfeName)
  where
    mfeName = "v"
    binder  = (mfeName, level)

-- | Rebuild one node of the expression without its level annotation,
-- processing the children with 'identifyMFEs_e'.  The context level given
-- to the children is this node's own level, except for a lambda body,
-- whose context is the level of the lambda's arguments.
identifyMFEs_e1 level expr = case expr of
  AConst k   -> EConst k
  AVar v     -> EVar v
  AAp fun arg ->
    EAp (identifyMFEs_e level fun) (identifyMFEs_e level arg)
  ALam args body ->
    -- Lazy pattern: only fails (when demanded) if the lambda has no
    -- arguments.
    let (_, argLevel) : _ = args
    in  ELam args (identifyMFEs_e argLevel body)
  ALet isRec defns body ->
    let newBody  = identifyMFEs_e level body
        newDefns = [ (binder, identifyMFEs_e level rhs)
                   | (binder, rhs) <- defns ]
    in  ELet isRec newDefns newBody

-- 5.8 Renaming

-- | Rename the binders of an expression, starting from an empty
-- environment and the initial name supply; the final name supply is
-- discarded.
rename  :: Expr (Name,a) -> Expr (Name,a)
rename e = snd (rename_e [] initialNameSupply e)

-- | Worker for 'rename'.
--
-- Takes an environment mapping old variable names to their replacements, a
-- name supply, and the expression to rename.  Returns the name supply left
-- over together with the renamed expression.  Binders introduced by lambdas
-- and lets get new names from 'newBinder'; variable occurrences are replaced
-- by looking them up in the environment.
rename_e :: Assn Name Name -> NameSupply -> Expr (Name,a)
                -> (NameSupply, Expr (Name, a))
-- A variable is replaced by whatever the environment maps it to.
rename_e env ns (EVar v)    = (ns,EVar (assLookup env v))
rename_e env ns (EConst k)  = (ns, EConst k)
-- The name supply is threaded left to right through the two sub-expressions.
rename_e env ns (EAp e1 e2) = (ns'', EAp e1' e2')
                              where (ns', e1') = rename_e env ns  e1
                                    (ns'',e2') = rename_e env ns' e2
-- Lambda: rename the arguments first, then the body under the environment
-- extended with the old-to-new argument names.
-- NOTE(review): the trailing "BUG????" marker is from the original author;
-- what it refers to is not evident from this code -- confirm before relying
-- on this case.
rename_e env ns (ELam args body)
  = (ns'', ELam args' body')         -- BUG????
  where (ns', args') = mapAccuml newBinder ns args
        (ns'',body') = rename_e (assocBinders args args' ++ env) ns' body
-- Let: the bindings below are mutually dependent.  The body is renamed
-- first in name-supply order, yet under env', which contains binders', which
-- are generated from ns', the supply that comes back from renaming the body.
-- This relies on lazy pattern bindings and on the returned name supply not
-- depending on the contents of the environment (assumption -- TODO confirm
-- against 'newName' and 'mapAccuml').
rename_e env ns (ELet isRec defns body)
  = (ns''', ELet isRec (zip binders' values') body')
  where (ns',  body')      = rename_e env' ns body
        binders            = bindersOf defns
        (ns'', binders')   = mapAccuml newBinder ns' binders
        env'               = assocBinders binders binders' ++ env
        (ns''',values')    = mapAccuml (rename_e rhsEnv) ns'' (rhssOf defns)
        -- Right-hand sides of a recursive let see the new binders; those of
        -- a non-recursive let only see the outer environment.
        rhsEnv | isRec     = env'
               | not isRec = env

-- | Replace the name of an annotated binder by one obtained from 'newName',
-- keeping the annotation, and return the updated name supply alongside.
-- The result of 'newName' is only taken apart lazily.
newBinder ns (name,info) = (fst fresh, (snd fresh, info))
  where
    fresh = newName ns name

-- | Pair up the names of two lists of annotated binders position by
-- position, giving an association from each name in the first list to the
-- corresponding name in the second.  The annotations are dropped.
assocBinders                 :: [(Name,a)] -> [(Name,a)] -> Assn Name Name
assocBinders binders binders' = zipWith pairNames binders binders'
  where
    pairNames old new = (fst old, fst new)

-- 5.9 Floating

-- | Float let-definitions outwards.  Any definitions that 'float_e' floats
-- all the way out of the expression are re-installed around the result.
float  :: Expr (Name,Level) -> Expression
float e = uncurry install (float_e e)

-- | Let-groups removed from their original position by 'float_e', waiting to
-- be re-installed further out.  Each entry holds the level of the group, its
-- recursion flag, and the (already floated) definitions of the group.
type FloatedDefns = [(Level, IsRec, [Defn Name])]

-- | Wrap an expression in one let per floated group.  The first group in
-- the list becomes the outermost let; the level stored with each group is
-- not used here.
install             :: FloatedDefns -> Expression -> Expression
install defnGroups e = foldr wrapGroup e defnGroups
  where
    wrapGroup (_, isRec, defns) = ELet isRec defns

-- | Worker for 'float'.  Returns the let-groups that are floated out of the
-- expression together with what remains of the expression, with the level
-- annotations removed from its binders.
float_e            :: Expr (Name,Level) -> (FloatedDefns, Expression)
float_e expr = case expr of
  EConst k -> ([], EConst k)

  EVar v -> ([], EVar v)

  -- Groups floated out of the function part come before those of the
  -- argument part.
  EAp fun arg ->
    let (funDefns, fun') = float_e fun
        (argDefns, arg') = float_e arg
    in  (funDefns ++ argDefns, EAp fun' arg')

  -- Groups whose level is at least that of this lambda's arguments are
  -- installed directly inside the lambda; the others keep floating outwards.
  ELam args body ->
    let plainArgs        = [ name | (name, _) <- args ]
        (_, lamLevel)    = head args
        (floated, body') = float_e body
        staysHere (groupLevel, _, _) = groupLevel >= lamLevel
        innerGroups      = filter staysHere floated
        outerGroups      = filter (not . staysHere) floated
    in  (outerGroups, ELam plainArgs (install innerGroups body'))

  -- A let is replaced by its body; the let itself becomes a floated group,
  -- placed after the groups floated out of its right-hand sides and before
  -- those floated out of its body.  The group's level is that of its first
  -- binder.
  ELet isRec defns body ->
    let (bodyDefns, body') = float_e body
        (rhsDefns, defns') = mapAccuml float_defn [] defns
        (_, groupLevel)    = head (bindersOf defns)
        ownGroup           = (groupLevel, isRec, defns')
    in  (rhsDefns ++ [ownGroup] ++ bodyDefns, body')

-- | Float one definition of a let-group.  The accumulator holds the groups
-- floated so far; the groups floated out of this right-hand side are put in
-- front of it.  The binder loses its level annotation.
float_defn floatedDefns ((name, _), rhs)
  = (fst floated ++ floatedDefns, (name, snd floated))
  where
    floated = float_e rhs
