Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Detect termination for nested local definitions #3169

Merged
merged 7 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@ import Data.HashMap.Strict qualified as HashMap
import Juvix.Compiler.Internal.Extra
import Juvix.Compiler.Internal.Translation.FromInternal.Analysis.Termination.Data
import Juvix.Prelude
import Safe (headMay)

viewCall ::
forall r.
(Members '[Reader SizeInfo] r) =>
(Members '[Reader SizeInfoMap] r) =>
Expression ->
Sem r (Maybe FunCall)
viewCall = \case
Expand All @@ -19,12 +20,15 @@ viewCall = \case
ExpressionApplication (Application f x impl)
| isImplicitOrInstance impl -> viewCall f -- implicit arguments are ignored
| otherwise -> do
c <- viewCall f
x' <- callArg
return $ over callArgs (`snoc` x') <$> c
mc <- viewCall f
case mc of
Just c -> do
x' <- callArg (c ^. callRef)
return $ Just $ over callArgs (`snoc` x') c
Nothing -> return Nothing
where
callArg :: Sem r (CallRow, Expression)
callArg = do
callArg :: FunctionRef -> Sem r (CallRow, Expression)
callArg fref = do
lt <- (^. callRow) <$> lessThan
eq <- (^. callRow) <$> equalTo
return (CallRow (lt `mplus` eq), x)
Expand All @@ -33,19 +37,45 @@ viewCall = \case
lessThan = case viewExpressionAsPattern x of
Nothing -> return (CallRow Nothing)
Just x' -> do
s <- asks (findIndex (elem x') . (^. sizeSmaller))
s <- asks (findIndex (elem x') . (^. sizeSmaller) . findSizeInfo)
return $ case s of
Nothing -> CallRow Nothing
Just s' -> CallRow (Just (s', RLe))
equalTo :: Sem r CallRow
equalTo =
case viewExpressionAsPattern x of
Just x' -> do
s <- asks (elemIndex x' . (^. sizeEqual))
s <- asks (elemIndex x' . (^. sizeEqual) . findSizeInfo)
return $ case s of
Nothing -> CallRow Nothing
Just s' -> CallRow (Just (s', REq))
Nothing -> return (CallRow Nothing)
findSizeInfo :: SizeInfoMap -> SizeInfo
findSizeInfo infos =
{-
If the call is not to any nested function being defined, then we
associate it with the most nested function. Without this,
termination for mutually recursive functions doesn't work.

Consider:
```
isEven (x : Nat) : Bool :=
let
isEven' : Nat -> Bool
| zero := true
| (suc n) := isOdd' n;
isOdd' : Nat -> Bool
| zero := false
| (suc n) := isEven' n;
in isEven' x;
```
The call `isEven' n` inside `isOdd'` needs to be associated with
`isOdd'`, not with `isEven`, and not just forgotten.
-}
fromMaybe (maybe emptySizeInfo snd . headMay $ infos ^. sizeInfoMap)
janmasrovira marked this conversation as resolved.
Show resolved Hide resolved
. (lookup fref)
. (^. sizeInfoMap)
$ infos
_ -> return Nothing
where
singletonCall :: FunctionRef -> FunCall
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ instance Scannable Expression where
buildCallMap =
run
. execState emptyCallMap
. runReader emptySizeInfoMap
. scanTopExpression

runTerminationState :: TerminationState -> Sem (Termination ': r) a -> Sem r (TerminationState, a)
Expand Down Expand Up @@ -122,21 +123,21 @@ scanInductive i = do
scanMutualStatement :: (Members '[State CallMap] r) => MutualStatement -> Sem r ()
scanMutualStatement = \case
StatementInductive i -> scanInductive i
StatementFunction i -> scanFunctionDef i
StatementFunction i -> runReader emptySizeInfoMap $ scanFunctionDef i
StatementAxiom a -> scanAxiom a

scanAxiom :: (Members '[State CallMap] r) => AxiomDef -> Sem r ()
scanAxiom = scanTopExpression . (^. axiomType)

scanFunctionDef ::
(Members '[State CallMap] r) =>
(Members '[State CallMap, Reader SizeInfoMap] r) =>
FunctionDef ->
Sem r ()
scanFunctionDef f@FunctionDef {..} = do
registerFunctionDef f
runReader (Just _funDefName) $ do
scanTypeSignature _funDefType
scanFunctionBody _funDefBody
scanFunctionBody _funDefName _funDefBody
scanDefaultArgs _funDefArgsInfo

scanDefaultArgs ::
Expand All @@ -153,57 +154,68 @@ scanTypeSignature ::
(Members '[State CallMap, Reader (Maybe FunctionRef)] r) =>
Expression ->
Sem r ()
scanTypeSignature = runReader emptySizeInfo . scanExpression
scanTypeSignature = runReader emptySizeInfoMap . scanExpression

scanFunctionBody ::
forall r.
(Members '[State CallMap, Reader (Maybe FunctionRef)] r) =>
(Members '[State CallMap, Reader SizeInfoMap, Reader (Maybe FunctionRef)] r) =>
FunctionName ->
Expression ->
Sem r ()
scanFunctionBody topbody = go [] topbody
scanFunctionBody funName topbody = go [] topbody
where
go :: [PatternArg] -> Expression -> Sem r ()
go revArgs body = case body of
ExpressionLambda Lambda {..} -> mapM_ goClause _lambdaClauses
_ -> runReader (mkSizeInfo (reverse revArgs)) (scanExpression body)
_ ->
local
(over sizeInfoMap ((funName, mkSizeInfo (reverse revArgs)) :))
(scanExpression body)
where
goClause :: LambdaClause -> Sem r ()
goClause (LambdaClause pats clBody) = go (reverse (toList pats) ++ revArgs) clBody

scanLet ::
(Members '[State CallMap, Reader (Maybe FunctionRef), Reader SizeInfo] r) =>
(Members '[State CallMap, Reader (Maybe FunctionRef), Reader SizeInfoMap] r) =>
Let ->
Sem r ()
scanLet l = do
mapM_ scanLetClause (l ^. letClauses)
scanExpression (l ^. letExpression)

-- NOTE that we forget about the arguments of the hosting function
scanLetClause :: (Members '[State CallMap] r) => LetClause -> Sem r ()
scanLetClause :: (Members '[State CallMap, Reader SizeInfoMap] r) => LetClause -> Sem r ()
scanLetClause = \case
LetFunDef d -> scanFunctionDef d
LetMutualBlock m -> scanMutualBlockLet m

scanMutualBlockLet :: (Members '[State CallMap] r) => MutualBlockLet -> Sem r ()
scanMutualBlockLet :: (Members '[State CallMap, Reader SizeInfoMap] r) => MutualBlockLet -> Sem r ()
scanMutualBlockLet MutualBlockLet {..} = mapM_ scanFunctionDef _mutualLet

scanTopExpression ::
(Members '[State CallMap] r) =>
Expression ->
Sem r ()
scanTopExpression =
runReader (Nothing @FunctionRef)
. runReader emptySizeInfo
runReader emptySizeInfoMap
. runReader (Nothing @FunctionRef)
. scanExpression

scanExpression ::
(Members '[State CallMap, Reader (Maybe FunctionRef), Reader SizeInfo] r) =>
(Members '[State CallMap, Reader (Maybe FunctionRef), Reader SizeInfoMap] r) =>
Expression ->
Sem r ()
scanExpression e =
viewCall e >>= \case
Just c -> do
whenJustM (ask @(Maybe FunctionRef)) (\caller -> runReader caller (registerCall c))
-- Are we recursively calling a function being defined?
recCall <- asks (elem (c ^. callRef) . map fst . (^. sizeInfoMap))
if
| recCall ->
runReader (c ^. callRef) (registerCall c)
| otherwise ->
whenJustM
(ask @(Maybe FunctionRef))
(\caller -> runReader caller (registerCall c))
mapM_ (scanExpression . snd) (c ^. callArgs)
Nothing -> case e of
ExpressionApplication a -> directExpressions_ scanExpression a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,22 @@ data SizeInfo = SizeInfo

makeLenses ''SizeInfo

newtype SizeInfoMap = SizeInfoMap
{ _sizeInfoMap :: [(FunctionName, SizeInfo)]
}

makeLenses ''SizeInfoMap

emptySizeInfo :: SizeInfo
emptySizeInfo =
SizeInfo
{ _sizeEqual = mempty,
_sizeSmaller = mempty
}

emptySizeInfoMap :: SizeInfoMap
emptySizeInfoMap = SizeInfoMap []

mkSizeInfo :: [PatternArg] -> SizeInfo
mkSizeInfo args = SizeInfo {..}
where
Expand Down
10 changes: 9 additions & 1 deletion test/Termination/Positive.hs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,15 @@ tests =
PosTest
"Ignore instance arguments"
$(mkRelDir ".")
$(mkRelFile "issue2414.juvix")
$(mkRelFile "issue2414.juvix"),
PosTest
"Nested local definitions"
$(mkRelDir ".")
$(mkRelFile "Nested1.juvix"),
PosTest
"Named arguments"
$(mkRelDir ".")
$(mkRelFile "Nested2.juvix")
]

testsWithKeyword :: [PosTest]
Expand Down
11 changes: 11 additions & 0 deletions tests/positive/Termination/Nested1.juvix
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
module Nested1;

import Stdlib.Data.List open;

go {A B} (f : A -> B) : List A -> List B
| nil := nil
| (elem :: next) :=
let
var1 := f elem;
var2 := go f next;
in var1 :: var2;
16 changes: 16 additions & 0 deletions tests/positive/Termination/Nested2.juvix
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
module Nested2;

type MyList A :=
| myNil
| myCons@{
elem : A;
next : MyList A;
};

go {A B} (f : A -> B) : MyList A -> MyList B
| myNil := myNil
| myCons@{elem; next} :=
myCons@{
elem := f elem;
next := go f next;
};
Loading