Skip to content

Commit

Permalink
Merge pull request #602 from HigherOrderCO/f64-u64-lb
Browse files Browse the repository at this point in the history
Check return type on ops + refactor
  • Loading branch information
Lorenzobattistela authored Oct 24, 2024
2 parents a25cb45 + 3cf4269 commit 6910c64
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 11 deletions.
18 changes: 8 additions & 10 deletions src/Kind/Check.hs
Original file line number Diff line number Diff line change
Expand Up @@ -93,16 +93,10 @@ infer sus src term dep = debug ("infer:" ++ (if sus then "* " else " ") ++ showT
go (Op2 opr fst snd) = do
fstT <- infer sus src fst dep
sndT <- infer sus src snd dep

let validTypes = [F64, U64]
let checkValidType typ = do
isValid <- foldr (\t acc -> do
isEqual <- equal typ t dep
if isEqual then return True else acc
) (return False) validTypes
return isValid

isValidType <- checkValidType (getType fstT)
isValidType <- checkValidType (getType fstT) validTypes dep

if not isValidType then do
envLog (Error src (Ref "Valid numeric type") (getType fstT) (Op2 opr fst snd) dep)
envFail
Expand All @@ -112,7 +106,11 @@ infer sus src term dep = debug ("infer:" ++ (if sus then "* " else " ") ++ showT
envLog (Error src (getType fstT) (getType sndT) (Op2 opr fst snd) dep)
envFail
else do
return $ Ann False (Op2 opr fstT sndT) (getType fstT)
book <- envGetBook
fill <- envGetFill
let reducedFst = reduce book fill 1 (getType fstT)
let returnType = getOpReturnType opr reducedFst
return $ Ann False (Op2 opr fstT sndT) returnType

go (Swi zer suc) = do
envLog (Error src (Ref "annotation") (Ref "switch") (Swi zer suc) dep)
Expand Down
3 changes: 2 additions & 1 deletion src/Kind/Type.hs
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,12 @@ data Loc = Loc String Int Int
data Cod = Cod Loc Loc

-- Numeric Operators
data Oper
data Oper
= ADD | SUB | MUL | DIV
| MOD | EQ | NE | LT
| GT | LTE | GTE | AND
| OR | XOR | LSH | RSH
deriving Show

-- Telescope
data Tele
Expand Down
33 changes: 33 additions & 0 deletions src/Kind/Util.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ module Kind.Util where

import Kind.Show
import Kind.Type
import Kind.Equal

import Prelude hiding (LT, GT, EQ)

import Debug.Trace
import qualified Data.Map.Strict as M
Expand Down Expand Up @@ -127,3 +130,33 @@ getADTCts :: Term -> [(String,Ctr)]
getADTCts (ADT _ cts _) = map (\ ctr -> (getCtrName ctr, ctr)) cts
getADTCts (Src loc val) = getADTCts val
getADTCts term = error ("not-an-adt:" ++ showTerm term)

getOpReturnType :: Oper -> Term -> Term
getOpReturnType ADD U64 = U64
getOpReturnType ADD F64 = F64
getOpReturnType SUB U64 = U64
getOpReturnType SUB F64 = F64
getOpReturnType MUL U64 = U64
getOpReturnType MUL F64 = F64
getOpReturnType DIV U64 = U64
getOpReturnType DIV F64 = F64
getOpReturnType MOD U64 = U64
getOpReturnType EQ _ = U64
getOpReturnType NE _ = U64
getOpReturnType LT _ = U64
getOpReturnType GT _ = U64
getOpReturnType LTE _ = U64
getOpReturnType GTE _ = U64
getOpReturnType AND U64 = U64
getOpReturnType OR U64 = U64
getOpReturnType XOR U64 = U64
getOpReturnType LSH U64 = U64
getOpReturnType RSH U64 = U64
getOpReturnType opr trm = error ("Invalid opertor: " ++ (show opr) ++ " Invalid operand type: " ++ (showTerm trm))

checkValidType :: Term -> [Term] -> Int -> Env Bool
checkValidType typ validTypes dep = foldr (\t acc -> do
isEqual <- equal typ t dep
if isEqual then return True else acc
) (return False) validTypes

0 comments on commit 6910c64

Please sign in to comment.