diff --git a/lib/CodeGen/Llvm/Ir2LlvmIr.hs b/lib/CodeGen/Llvm/Ir2LlvmIr.hs index f03e668..5172001 100644 --- a/lib/CodeGen/Llvm/Ir2LlvmIr.hs +++ b/lib/CodeGen/Llvm/Ir2LlvmIr.hs @@ -1,15 +1,21 @@ +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RecursiveDo #-} -{-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-} +{-# LANGUAGE TupleSections #-} -module CodeGen.Llvm.Ir2LlvmIr (ppLlvm, ir2LlvmIr) where +module CodeGen.Llvm.Ir2LlvmIr (ppLlvmModule, genLlvmIrModule) where -import CodeGen.Module (Module (..)) -import Control.Monad.State (MonadFix) +import CodeGen.Module (Module (Module)) +import Control.Monad.State (MonadState, State, evalState, gets, modify) +import Data.Map (Map) +import qualified Data.Map as Map import Data.String.Transform (toShortByteString) +import qualified Data.Text as Text import Data.Text.Lazy (Text) import Foreign (fromBool) import qualified LLVM.AST as LLVM hiding (function) +import qualified LLVM.AST.Constant as C import qualified LLVM.AST.IntegerPredicate as LLVM import qualified LLVM.AST.Type as LLVM import qualified LLVM.IRBuilder.Constant as LLVM @@ -20,34 +26,86 @@ import LLVM.Pretty (ppllvm) import Transformations.Anf.Anf import Trees.Common -ppLlvm :: LLVM.Module -> Text -ppLlvm = ppllvm +ppLlvmModule :: LLVM.Module -> Text +ppLlvmModule = ppllvm -ir2LlvmIr :: Module -> LLVM.Module -ir2LlvmIr = genModule +genLlvmIrModule :: Module -> LLVM.Module +genLlvmIrModule = genModule -- Implementation -genModule :: Module -> LLVM.Module -genModule (Module name code) = LLVM.buildModule (toShortByteString name) undefined +type CodeGenM = LLVM.IRBuilderT Llvm + +type Llvm = LLVM.ModuleBuilderT (State Env) + +data Env = Env + { vars :: Map Identifier' LLVM.Operand, + funs :: Map Identifier' LLVM.Operand + } -genExpr :: - (LLVM.MonadIRBuilder m, LLVM.MonadModuleBuilder m, MonadFix m) => - Expression -> - m LLVM.Operand -genExpr expr = case expr of +genModule :: Module -> LLVM.Module +genModule (Module name (Program decls)) = flip evalState (Env Map.empty Map.empty) $ + LLVM.buildModuleT (toShortByteString name) $ do + notF <- LLVM.extern "not" [LLVM.i64] LLVM.i64 + printBoolF <- LLVM.extern "print_bool" [LLVM.i64] LLVM.i64 + printIntF <- LLVM.extern "print_int" [LLVM.i64] LLVM.i64 + + let stdFuns = + [ (Txt "not", notF), + (Txt "print_bool", printBoolF), + (Txt "print_int", printIntF) + ] + + mapM_ (uncurry regFun) stdFuns + + mapM_ genGlobDecl decls + + -- In the `main` we define our global variables. + LLVM.function "main" [] LLVM.i64 $ \_ -> do + mapM_ gVarDef decls + LLVM.ret (LLVM.int64 0) + where + gVarDef :: GlobalDeclaration -> CodeGenM () + gVarDef = \case + GlobVarDecl ident value -> do + operand <- findVar ident + value' <- genExpr value + store' operand value' + _ -> return () + +genGlobDecl :: GlobalDeclaration -> Llvm () +genGlobDecl decl = case decl of + GlobVarDecl ident _ -> do + var <- LLVM.global (LLVM.mkName $ genId ident) LLVM.i64 (C.Int 64 0) + regVar ident var + GlobFunDecl ident params body -> do + fun <- LLVM.function + (LLVM.mkName $ genId ident) + ((LLVM.i64,) . LLVM.ParameterName . toShortByteString . genId <$> params) + LLVM.i64 + $ \args -> do + mapM_ (uncurry regVar) (params `zip` args) + body' <- genExpr body + LLVM.ret body' + regFun ident fun + +genId :: Identifier' -> String +genId = \case + Txt txt -> Text.unpack txt + Gen n txt -> Text.unpack txt <> "'" <> show n + +genExpr :: Expression -> CodeGenM LLVM.Operand +genExpr = \case ExprAtom atom -> genAtom atom ExprComp ce -> genComp ce - ExprLetIn (ident, value) expr -> undefined - -genAtom :: - (LLVM.MonadIRBuilder m, LLVM.MonadModuleBuilder m, MonadFix m) => - AtomicExpression -> - m LLVM.Operand -genAtom atom = case atom of - AtomId ident -> undefined {- mdo - CompilerState map _ <- get - load' $ map ! name -} + ExprLetIn (ident, val) expr -> mdo + val' <- genExpr val `LLVM.named` toShortByteString (genId ident) + regVar ident val' + genExpr expr + +genAtom :: AtomicExpression -> CodeGenM LLVM.Operand +genAtom = \case + AtomId ident -> findVar ident AtomUnit -> return $ LLVM.int64 0 AtomBool bool -> return $ LLVM.int64 $ fromBool bool AtomInt int -> return $ LLVM.int64 $ toInteger int @@ -77,58 +135,64 @@ genAtom atom = case atom of UnMinusOp -> LLVM.mul (LLVM.int64 (-1)) opF x' -genComp :: - (LLVM.MonadIRBuilder m, LLVM.MonadModuleBuilder m, MonadFix m) => - ComplexExpression -> - m LLVM.Operand -genComp comp = case comp of - CompApp f arg -> undefined - CompIte c t e -> genIte c t e +genComp :: ComplexExpression -> CodeGenM LLVM.Operand +genComp = \case + CompApp f arg -> mdo + -- _ <- LLVM.call printInt [(ourExpression, [])] + undefined + CompIte c t e -> mdo + rv <- allocate' + + c' <- genAtom c >>= intToBool + LLVM.condBr c' tBlock eBlock -genIte :: - (LLVM.MonadIRBuilder m, LLVM.MonadModuleBuilder m, MonadFix m) => - AtomicExpression -> - Expression -> - Expression -> - m LLVM.Operand -genIte c t e = mdo - rv <- LLVM.alloca LLVM.i64 Nothing 0 + tBlock <- LLVM.block `LLVM.named` "if.then" + store' rv =<< genExpr t + LLVM.br end - c' <- genAtom c >>= intToBool - LLVM.condBr c' tBlock eBlock + eBlock <- LLVM.block `LLVM.named` "if.else" + store' rv =<< genExpr e + LLVM.br end - tBlock <- LLVM.block `LLVM.named` "if.then" - store' rv =<< genExpr t - LLVM.br end + end <- LLVM.block `LLVM.named` "if.end" - eBlock <- LLVM.block `LLVM.named` "if.else" - store' rv =<< genExpr e - LLVM.br end + load' rv - end <- LLVM.block `LLVM.named` "if.end" +-- Stack - load' rv +findVar :: MonadState Env m => Identifier' -> m LLVM.Operand +findVar k = gets ((Map.! k) . vars) -allocate :: (LLVM.MonadIRBuilder m) => LLVM.Operand -> m LLVM.Operand +regVar :: MonadState Env m => Identifier' -> LLVM.Operand -> m () +regVar k v = modify $ \env -> env {vars = Map.insert k v (vars env)} + +findFun :: Identifier' -> CodeGenM LLVM.Operand +findFun k = gets ((Map.! k) . funs) + +regFun :: MonadState Env m => Identifier' -> LLVM.Operand -> m () +regFun k v = modify $ \env -> env {funs = Map.insert k v (funs env)} + +-- Allocation utils + +allocate :: LLVM.Operand -> CodeGenM LLVM.Operand allocate value = do - addr <- LLVM.alloca LLVM.i64 (Just (LLVM.int64 0)) 0 + addr <- LLVM.alloca LLVM.i64 Nothing 0 store' addr value - pure addr + return addr + +allocate' :: CodeGenM LLVM.Operand +allocate' = LLVM.alloca LLVM.i64 Nothing 0 -load' :: (LLVM.MonadIRBuilder m) => LLVM.Operand -> m LLVM.Operand +load' :: LLVM.Operand -> CodeGenM LLVM.Operand load' addr = LLVM.load addr 0 -store' :: (LLVM.MonadIRBuilder m) => LLVM.Operand -> LLVM.Operand -> m () +store' :: LLVM.Operand -> LLVM.Operand -> CodeGenM () store' addr = LLVM.store addr 0 --- Utils +-- Conversion utils -boolToInt :: - (LLVM.MonadIRBuilder m, LLVM.MonadModuleBuilder m, MonadFix m) => - (LLVM.Operand -> m LLVM.Operand) +boolToInt :: LLVM.Operand -> CodeGenM LLVM.Operand boolToInt = flip LLVM.zext LLVM.i64 -intToBool :: - (LLVM.MonadIRBuilder m, LLVM.MonadModuleBuilder m, MonadFix m) => - (LLVM.Operand -> m LLVM.Operand) +intToBool :: LLVM.Operand -> CodeGenM LLVM.Operand intToBool = flip LLVM.trunc LLVM.i64