Skip to content

Commit

Permalink
Detect termination for nested local definitions (#3169)
Browse files Browse the repository at this point in the history
* Closes #3147 

When we call a function that is currently being defined (there may be
several such due to nested local definitions), we add a reflexive edge
in the call map instead of adding an edge from the most nested
definition. For example, for

```juvix
go {A B} (f : A -> B) : List A -> List B
  | nil := nil
  | (elem :: next) :=
    let var1 := f elem;
        var2 := go f next;
    in var1 :: var2;
```

we add an edge from `go` to the recursive call `go f next`, instead of
adding an edge from `var2` to `go f next` as before.

This makes the above type-check.

The following still doesn't type-check, because `next'` is not a
subpattern of the clause pattern of `go`. But this is a less pressing
problem.

```juvix
go {A B} (f : A -> B) : List A -> List B
  | nil := nil
  | (elem :: next) :=
    let var1 := f elem;
        var2 (next' : List A) : List B := go f next';
    in myCons var1 (var2 next);
```
  • Loading branch information
lukaszcz authored Nov 15, 2024
1 parent 1d7bf1f commit 9f25ffd
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@ import Data.HashMap.Strict qualified as HashMap
import Juvix.Compiler.Internal.Extra
import Juvix.Compiler.Internal.Translation.FromInternal.Analysis.Termination.Data
import Juvix.Prelude
import Safe (headMay)

viewCall ::
forall r.
(Members '[Reader SizeInfo] r) =>
(Members '[Reader SizeInfoMap] r) =>
Expression ->
Sem r (Maybe FunCall)
viewCall = \case
Expand All @@ -19,12 +20,15 @@ viewCall = \case
ExpressionApplication (Application f x impl)
| isImplicitOrInstance impl -> viewCall f -- implicit arguments are ignored
| otherwise -> do
c <- viewCall f
x' <- callArg
return $ over callArgs (`snoc` x') <$> c
mc <- viewCall f
case mc of
Just c -> do
x' <- callArg (c ^. callRef)
return $ Just $ over callArgs (`snoc` x') c
Nothing -> return Nothing
where
callArg :: Sem r (CallRow, Expression)
callArg = do
callArg :: FunctionRef -> Sem r (CallRow, Expression)
callArg fref = do
lt <- (^. callRow) <$> lessThan
eq <- (^. callRow) <$> equalTo
return (CallRow (lt `mplus` eq), x)
Expand All @@ -33,19 +37,45 @@ viewCall = \case
lessThan = case viewExpressionAsPattern x of
Nothing -> return (CallRow Nothing)
Just x' -> do
s <- asks (findIndex (elem x') . (^. sizeSmaller))
s <- asks (findIndex (elem x') . (^. sizeSmaller) . findSizeInfo)
return $ case s of
Nothing -> CallRow Nothing
Just s' -> CallRow (Just (s', RLe))
equalTo :: Sem r CallRow
equalTo =
case viewExpressionAsPattern x of
Just x' -> do
s <- asks (elemIndex x' . (^. sizeEqual))
s <- asks (elemIndex x' . (^. sizeEqual) . findSizeInfo)
return $ case s of
Nothing -> CallRow Nothing
Just s' -> CallRow (Just (s', REq))
Nothing -> return (CallRow Nothing)
findSizeInfo :: SizeInfoMap -> SizeInfo
findSizeInfo infos =
{-
If the call is not to any nested function being defined, then we
associate it with the most nested function. Without this,
termination for mutually recursive functions doesn't work.
Consider:
```
isEven (x : Nat) : Bool :=
let
isEven' : Nat -> Bool
| zero := true
| (suc n) := isOdd' n;
isOdd' : Nat -> Bool
| zero := false
| (suc n) := isEven' n;
in isEven' x;
```
The call `isEven' n` inside `isOdd'` needs to be associated with
`isOdd'`, not with `isEven`, and not just forgotten.
-}
fromMaybe (maybe emptySizeInfo snd . headMay $ infos ^. sizeInfoMap)
. (lookup fref)
. (^. sizeInfoMap)
$ infos
_ -> return Nothing
where
singletonCall :: FunctionRef -> FunCall
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ instance Scannable Expression where
buildCallMap =
run
. execState emptyCallMap
. runReader emptySizeInfoMap
. scanTopExpression

runTerminationState :: TerminationState -> Sem (Termination ': r) a -> Sem r (TerminationState, a)
Expand Down Expand Up @@ -122,21 +123,21 @@ scanInductive i = do
scanMutualStatement :: (Members '[State CallMap] r) => MutualStatement -> Sem r ()
scanMutualStatement = \case
StatementInductive i -> scanInductive i
StatementFunction i -> scanFunctionDef i
StatementFunction i -> runReader emptySizeInfoMap $ scanFunctionDef i
StatementAxiom a -> scanAxiom a

scanAxiom :: (Members '[State CallMap] r) => AxiomDef -> Sem r ()
scanAxiom = scanTopExpression . (^. axiomType)

scanFunctionDef ::
(Members '[State CallMap] r) =>
(Members '[State CallMap, Reader SizeInfoMap] r) =>
FunctionDef ->
Sem r ()
scanFunctionDef f@FunctionDef {..} = do
registerFunctionDef f
runReader (Just _funDefName) $ do
scanTypeSignature _funDefType
scanFunctionBody _funDefBody
scanFunctionBody _funDefName _funDefBody
scanDefaultArgs _funDefArgsInfo

scanDefaultArgs ::
Expand All @@ -153,57 +154,68 @@ scanTypeSignature ::
(Members '[State CallMap, Reader (Maybe FunctionRef)] r) =>
Expression ->
Sem r ()
scanTypeSignature = runReader emptySizeInfo . scanExpression
scanTypeSignature = runReader emptySizeInfoMap . scanExpression

scanFunctionBody ::
forall r.
(Members '[State CallMap, Reader (Maybe FunctionRef)] r) =>
(Members '[State CallMap, Reader SizeInfoMap, Reader (Maybe FunctionRef)] r) =>
FunctionName ->
Expression ->
Sem r ()
scanFunctionBody topbody = go [] topbody
scanFunctionBody funName topbody = go [] topbody
where
go :: [PatternArg] -> Expression -> Sem r ()
go revArgs body = case body of
ExpressionLambda Lambda {..} -> mapM_ goClause _lambdaClauses
_ -> runReader (mkSizeInfo (reverse revArgs)) (scanExpression body)
_ ->
local
(over sizeInfoMap ((funName, mkSizeInfo (reverse revArgs)) :))
(scanExpression body)
where
goClause :: LambdaClause -> Sem r ()
goClause (LambdaClause pats clBody) = go (reverse (toList pats) ++ revArgs) clBody

scanLet ::
(Members '[State CallMap, Reader (Maybe FunctionRef), Reader SizeInfo] r) =>
(Members '[State CallMap, Reader (Maybe FunctionRef), Reader SizeInfoMap] r) =>
Let ->
Sem r ()
scanLet l = do
mapM_ scanLetClause (l ^. letClauses)
scanExpression (l ^. letExpression)

-- NOTE that we forget about the arguments of the hosting function
scanLetClause :: (Members '[State CallMap] r) => LetClause -> Sem r ()
scanLetClause :: (Members '[State CallMap, Reader SizeInfoMap] r) => LetClause -> Sem r ()
scanLetClause = \case
LetFunDef d -> scanFunctionDef d
LetMutualBlock m -> scanMutualBlockLet m

scanMutualBlockLet :: (Members '[State CallMap] r) => MutualBlockLet -> Sem r ()
scanMutualBlockLet :: (Members '[State CallMap, Reader SizeInfoMap] r) => MutualBlockLet -> Sem r ()
scanMutualBlockLet MutualBlockLet {..} = mapM_ scanFunctionDef _mutualLet

scanTopExpression ::
(Members '[State CallMap] r) =>
Expression ->
Sem r ()
scanTopExpression =
runReader (Nothing @FunctionRef)
. runReader emptySizeInfo
runReader emptySizeInfoMap
. runReader (Nothing @FunctionRef)
. scanExpression

scanExpression ::
(Members '[State CallMap, Reader (Maybe FunctionRef), Reader SizeInfo] r) =>
(Members '[State CallMap, Reader (Maybe FunctionRef), Reader SizeInfoMap] r) =>
Expression ->
Sem r ()
scanExpression e =
viewCall e >>= \case
Just c -> do
whenJustM (ask @(Maybe FunctionRef)) (\caller -> runReader caller (registerCall c))
-- Are we recursively calling a function being defined?
recCall <- asks (elem (c ^. callRef) . map fst . (^. sizeInfoMap))
if
| recCall ->
runReader (c ^. callRef) (registerCall c)
| otherwise ->
whenJustM
(ask @(Maybe FunctionRef))
(\caller -> runReader caller (registerCall c))
mapM_ (scanExpression . snd) (c ^. callArgs)
Nothing -> case e of
ExpressionApplication a -> directExpressions_ scanExpression a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,22 @@ data SizeInfo = SizeInfo

makeLenses ''SizeInfo

newtype SizeInfoMap = SizeInfoMap
{ _sizeInfoMap :: [(FunctionName, SizeInfo)]
}

makeLenses ''SizeInfoMap

emptySizeInfo :: SizeInfo
emptySizeInfo =
SizeInfo
{ _sizeEqual = mempty,
_sizeSmaller = mempty
}

emptySizeInfoMap :: SizeInfoMap
emptySizeInfoMap = SizeInfoMap []

mkSizeInfo :: [PatternArg] -> SizeInfo
mkSizeInfo args = SizeInfo {..}
where
Expand Down
10 changes: 9 additions & 1 deletion test/Termination/Positive.hs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,15 @@ tests =
PosTest
"Ignore instance arguments"
$(mkRelDir ".")
$(mkRelFile "issue2414.juvix")
$(mkRelFile "issue2414.juvix"),
PosTest
"Nested local definitions"
$(mkRelDir ".")
$(mkRelFile "Nested1.juvix"),
PosTest
"Named arguments"
$(mkRelDir ".")
$(mkRelFile "Nested2.juvix")
]

testsWithKeyword :: [PosTest]
Expand Down
11 changes: 11 additions & 0 deletions tests/positive/Termination/Nested1.juvix
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
module Nested1;

import Stdlib.Data.List open;

go {A B} (f : A -> B) : List A -> List B
| nil := nil
| (elem :: next) :=
let
var1 := f elem;
var2 := go f next;
in var1 :: var2;
16 changes: 16 additions & 0 deletions tests/positive/Termination/Nested2.juvix
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
module Nested2;

type MyList A :=
| myNil
| myCons@{
elem : A;
next : MyList A;
};

go {A B} (f : A -> B) : MyList A -> MyList B
| myNil := myNil
| myCons@{elem; next} :=
myCons@{
elem := f elem;
next := go f next;
};

0 comments on commit 9f25ffd

Please sign in to comment.