Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[#302] Advice using the SPECIALIZE pragma #363

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 58 additions & 18 deletions src/Stan/Analysis/Analyser.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ import Extensions (ExtensionsResult)
import GHC.LanguageExtensions.Type (Extension (Strict, StrictData))
import Slist (Slist)

import Stan.Analysis.Visitor (Visitor (..), VisitorState (..), addFixity, addObservation,
addObservations, addOpDecl, getFinalObservations)
import Stan.Analysis.Visitor (Visitor (..), VisitorState (..), addFixity, addFunToSpecialize,
addObservation, addObservations, addOpDecl, addSpecializePragma,
getFinalObservations)
import Stan.Core.Id (Id)
import Stan.Core.List (nonRepeatingPairs)
import Stan.FileInfo (isExtensionDisabled)
Expand All @@ -25,12 +26,12 @@ import Stan.Hie (eqAst)
import Stan.Hie.Compat (HieAST (..), HieFile (..), Identifier, NodeInfo (..), TypeIndex)
import Stan.Hie.MatchAst (hieMatchPatternAst)
import Stan.Inspection (Inspection (..), InspectionAnalysis (..))
import Stan.NameMeta (NameMeta, ghcPrimNameFrom)
import Stan.NameMeta (NameMeta, baseNameFrom, ghcPrimNameFrom, nameFromIdentifier, namesFromAst)
import Stan.Observation (Observations, mkObservation)
import Stan.Pattern.Ast (Literal (..), PatternAst (..), anyNamesToPatternAst, case', constructor,
constructorNameIdentifier, dataDecl, fixity, fun, guardBranch, lambdaCase,
lazyField, literalPat, opApp, patternMatchArrow, patternMatchBranch,
patternMatch_, rhs, tuple, typeSig)
patternMatch_, rhs, specializePragma, tuple, typeSig)
import Stan.Pattern.Edsl (PatternBool (..))

import qualified Data.Map.Strict as Map
Expand Down Expand Up @@ -59,6 +60,7 @@ createVisitor hie exts inspections = Visitor $ \node ->
forM_ inspections $ \Inspection{..} -> case inspectionAnalysis of
FindAst patAst -> matchAst inspectionId patAst hie node
Infix -> analyseInfix hie node
SpecializePragma -> analyseSpecializePragma hie node
LazyField -> when
(isExtensionDisabled StrictData exts && isExtensionDisabled Strict exts)
(analyseLazyFields inspectionId hie node)
Expand Down Expand Up @@ -301,7 +303,7 @@ analyseInfix hie curNode = do
matchInfix :: HieAST TypeIndex -> State VisitorState ()
matchInfix node@Node{..} = when
(hieMatchPatternAst hie node fixity)
(traverse_ addFixity $ concatMap nodeIds nodeChildren)
(traverse_ addFixity $ concatMap namesFromAst nodeChildren)

-- add to state a singleton or empty list with the top-level
-- operator definition:
Expand All @@ -317,19 +319,6 @@ analyseInfix hie curNode = do
(traverse_ (uncurry addOpDecl))
)

-- return AST node identifier names as a sized list of texts
nodeIds :: HieAST TypeIndex -> [Text]
nodeIds =
concatMap fixityName
. Map.keys
. nodeIdentifiers
. nodeInfo

fixityName :: Identifier -> [Text]
fixityName = \case
Left _ -> []
Right name -> [toText $ occNameString $ nameOccName name]

extractOperatorName :: HieAST TypeIndex -> [(Text, RealSrcSpan)]
extractOperatorName Node{..} =
concatMap (topLevelOperatorName nodeSpan)
Expand All @@ -344,6 +333,57 @@ analyseInfix hie curNode = do
-- return empty list if identifier name is not operator name
in [(toText $ occNameString occName, srcSpan) | isSymOcc occName]

{- | Analyse HIE AST to find all operators which lack specialize pragmas
declaration (where appropriate).

The algorithm is the following:

1. Traverse AST and discover all top-level functions and @SPECIALIZE@ pragmas
in a single pass.
2. Compare two resulting sets to find out functions without @SPECIALIZE@ pragmas.
-}
analyseSpecializePragma :: HieFile -> HieAST TypeIndex -> State VisitorState ()
analyseSpecializePragma hie curNode = do
matchSpecializePragma curNode
matchFunToSpecialize curNode
where
-- adds to the state function names defined in a specialize pragma
-- @{-# SPECIALIZE foo :: _ #-}@
matchSpecializePragma :: HieAST TypeIndex -> State VisitorState ()
matchSpecializePragma node@Node{..} = when
(hieMatchPatternAst hie node specializePragma)
(traverse_ addSpecializePragma $ concatMap namesFromAst nodeChildren)

-- add to state a singleton or empty list with the function definition:
matchFunToSpecialize :: HieAST TypeIndex -> State VisitorState ()
matchFunToSpecialize node@Node{..} = when (hieMatchPatternAst hie node typeSig) $
case nodeChildren of
[] -> pass
[_] -> pass
name:rest -> when (findConstraint rest) $ whenJust
-- do nothing when cannot extract name
(viaNonEmpty head $ extractFunName name)
-- add each function from a list (should be singleton list)
(uncurry addFunToSpecialize)

extractFunName :: HieAST TypeIndex -> [(Text, RealSrcSpan)]
extractFunName Node{..} =
concatMap (map (, nodeSpan) . nameFromIdentifier)
$ Map.keys
$ nodeIdentifiers nodeInfo

findConstraint :: [HieAST TypeIndex] -> Bool
findConstraint [] = False
findConstraint (node@Node{..}:rest)
| hieMatchPatternAst hie node monadIO = True
| otherwise = findConstraint nodeChildren || findConstraint rest

monadIO :: PatternAst
monadIO = PatternAstNodeExact (one ("HsAppTy", "HsType"))
[ PatternAstName ("MonadIO" `baseNameFrom` "Control.Monad.IO.Class") (?)
, (?)
]

-- | Returns source spans of matched AST nodes.
createMatch
:: PatternAst
Expand Down
53 changes: 43 additions & 10 deletions src/Stan/Analysis/Visitor.hs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ module Stan.Analysis.Visitor
, addObservations
, addFixity
, addOpDecl
, addSpecializePragma
, addFunToSpecialize

, Visitor (..)
, visitAst
Expand All @@ -22,7 +24,8 @@ import Relude.Extra.Lens (Lens', lens, over)

import Stan.Ghc.Compat (RealSrcSpan)
import Stan.Hie.Compat (HieAST (..), HieASTs (..), HieFile (..), TypeIndex)
import Stan.Inspection (inspectionId)
import Stan.Inspection (Inspection, inspectionId)
import Stan.Inspection.Performance (stan0401)
import Stan.Inspection.Style (stan0301)
import Stan.Observation (Observation, Observations, mkObservation)

Expand All @@ -35,19 +38,25 @@ import qualified Slist as S
single HIE AST traversal.
-}
data VisitorState = VisitorState
{ visitorStateObservations :: !Observations
{ visitorStateObservations :: !Observations

-- Operators for STAN-0301
, visitorStateFixities :: !(HashMap Text ())
, visitorStateOpDecls :: !(HashMap Text RealSrcSpan)
, visitorStateFixities :: !(HashMap Text ())
, visitorStateOpDecls :: !(HashMap Text RealSrcSpan)

-- Operators for STAN-0401
, visitorStateSpecializePragmas :: !(HashMap Text ())
, visitorStateFunsToSpecialize :: !(HashMap Text RealSrcSpan)
}

-- | Initial empty state.
initialVisitorState :: VisitorState
initialVisitorState = VisitorState
{ visitorStateObservations = mempty
, visitorStateFixities = mempty
, visitorStateOpDecls = mempty
{ visitorStateObservations = mempty
, visitorStateFixities = mempty
, visitorStateOpDecls = mempty
, visitorStateSpecializePragmas = mempty
, visitorStateFunsToSpecialize = mempty
}

{- | Transform 'VisitorState' to the final list of observations for
Expand All @@ -62,10 +71,16 @@ finaliseState hie VisitorState{..} =
-- detected by finding a difference between two sets:
-- 1. Top-level defined operators
-- 2. Fixity declarations for operators in module
let operatorsWithoutFixity = HM.difference visitorStateOpDecls visitorStateFixities
stan0301inss = mkObservation (inspectionId stan0301) hie <$> S.slist (toList operatorsWithoutFixity)
let stan0301inss = evalInspections stan0301 visitorStateOpDecls visitorStateFixities
stan0401inss = evalInspections stan0401 visitorStateFunsToSpecialize visitorStateSpecializePragmas
-- combine final observations
in visitorStateObservations <> stan0301inss
in visitorStateObservations
<> stan0301inss
<> stan0401inss
where
evalInspections :: Inspection -> HashMap Text RealSrcSpan -> HashMap Text () -> Observations
evalInspections ins mapOfAll mapExclude = mkObservation (inspectionId ins) hie <$>
S.slist (toList $ HM.difference mapOfAll mapExclude)

-- | Get sized list of all 'Observations' from the given HIE file
-- using the created 'Visitor'.
Expand Down Expand Up @@ -93,6 +108,16 @@ opDeclsL = lens
visitorStateOpDecls
(\vstate new -> vstate { visitorStateOpDecls = new })

specializePragmasL :: Lens' VisitorState (HashMap Text ())
specializePragmasL = lens
visitorStateSpecializePragmas
(\vstate new -> vstate { visitorStateSpecializePragmas = new })

funsToSpecializeL :: Lens' VisitorState (HashMap Text RealSrcSpan)
funsToSpecializeL = lens
visitorStateFunsToSpecialize
(\vstate new -> vstate { visitorStateFunsToSpecialize = new })

-- | Add single 'Observation' to the existing 'VisitorState'.
addObservation :: Observation -> State VisitorState ()
addObservation obs = modify' $ over observationsL (S.one obs <>)
Expand All @@ -111,6 +136,14 @@ addFixity fixity = modify' $ over fixitiesL (HM.insert fixity ())
addOpDecl :: Text -> RealSrcSpan -> State VisitorState ()
addOpDecl opDecl srcSpan = modify' $ over opDeclsL (HM.insert opDecl srcSpan)

-- | Add single specialize pragma declaration declaration.
addSpecializePragma :: Text -> State VisitorState ()
addSpecializePragma pragma = modify' $ over specializePragmasL (HM.insert pragma ())

-- | Add single function that could be specialized top-level defintion with its position.
addFunToSpecialize :: Text -> RealSrcSpan -> State VisitorState ()
addFunToSpecialize fun srcSpan = modify' $ over funsToSpecializeL (HM.insert fun srcSpan)

-- | Object that implements the /Visitor pattern/.
newtype Visitor = Visitor
{ unVisitor :: HieAST TypeIndex -> State VisitorState ()
Expand Down
2 changes: 2 additions & 0 deletions src/Stan/Inspection.hs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ data InspectionAnalysis
= FindAst !PatternAst
-- | Find all operators without matching @infix[r|l]@
| Infix
-- | Find suitable functions without specialize pragma
| SpecializePragma
-- | Check if the data type has lazy fields
| LazyField
-- | Usage of tuples with size >= 4
Expand Down
2 changes: 2 additions & 0 deletions src/Stan/Inspection/All.hs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import Stan.Inspection (Inspection (..), InspectionsMap)
import Stan.Inspection.AntiPattern (antiPatternInspectionsMap)
import Stan.Inspection.Infinite (infiniteInspectionsMap)
import Stan.Inspection.Partial (partialInspectionsMap)
import Stan.Inspection.Performance (performanceInspectionsMap)
import Stan.Inspection.Style (styleInspectionsMap)

import qualified Data.HashMap.Strict as HM
Expand All @@ -33,6 +34,7 @@ inspectionsMap =
<> infiniteInspectionsMap
<> antiPatternInspectionsMap
<> styleInspectionsMap
<> performanceInspectionsMap

{- | List of all inspections.
-}
Expand Down
49 changes: 49 additions & 0 deletions src/Stan/Inspection/Performance.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
{- |
Copyright: (c) 2020 Kowainik
SPDX-License-Identifier: MPL-2.0
Maintainer: Kowainik <[email protected]>

Contains all 'Inspection's for known performance improvements.

The __preformance__ inspections are in ranges:

* @STAN-0401 .. STAN-0500@

-}

module Stan.Inspection.Performance
( -- * Performance inspections
-- *** @SPECIALIZE@ pragma
stan0401

-- * All inspections
, performanceInspectionsMap
) where

import Relude.Extra.Tuple (fmapToFst)

import Stan.Core.Id (Id (..))
import Stan.Inspection (Inspection (..), InspectionAnalysis (..), InspectionsMap)
import Stan.Severity (Severity (..))

import qualified Stan.Category as Category


-- | All performance 'Inspection's map from 'Id's.
performanceInspectionsMap :: InspectionsMap
performanceInspectionsMap = fromList $ fmapToFst inspectionId
[ stan0401
]


-- | 'Inspection' — @SPECIALIZE@ @STAN-0401@.
stan0401 :: Inspection
stan0401 = Inspection
{ inspectionId = Id "STAN-0401"
, inspectionName = "Performance: SPECIALIZE pragma"
, inspectionDescription = "Use {-# SPECIALIZE #-} pragma to improve performance"
, inspectionSolution = []
, inspectionCategory = Category.antiPattern :| []
, inspectionSeverity = Performance
, inspectionAnalysis = SpecializePragma
}
15 changes: 15 additions & 0 deletions src/Stan/NameMeta.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ module Stan.NameMeta
, compareNames
, hieMatchNameMeta
, hieFindIdentifier
, namesFromAst
, nameFromIdentifier

-- * Smart constructors
, baseNameFrom
Expand Down Expand Up @@ -115,6 +117,19 @@ hieFindIdentifier nameMeta =
. nodeIdentifiers
. nodeInfo

-- | Return AST node identifier names as a sized list of texts
namesFromAst :: HieAST TypeIndex -> [Text]
namesFromAst =
concatMap nameFromIdentifier
. Map.keys
. nodeIdentifiers
. nodeInfo

nameFromIdentifier :: Identifier -> [Text]
nameFromIdentifier = \case
Left _ -> []
Right name -> [toText $ occNameString $ nameOccName name]

{- | Create 'NameMeta' for a function from the @base@ package and
a given 'ModuleName'. module.
-}
Expand Down
10 changes: 10 additions & 0 deletions src/Stan/Pattern/Ast.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ module Stan.Pattern.Ast
, constructorNameIdentifier
, dataDecl
, fixity
, specializePragma
, fun
, guardBranch
, lazyField
Expand Down Expand Up @@ -182,6 +183,15 @@ infixr 7 ***, +++, ???
fixity :: PatternAst
fixity = PatternAstNode $ one ("FixitySig", "FixitySig")

{- | Pattern for the top-level specialize pragmas declaration:

@
{-# SPECIALIZE foo :: ... #-}
@
-}
specializePragma :: PatternAst
specializePragma = PatternAstNode $ one ("SpecSig", "Sig")

{- | Pattern for the function type signature declaration:

@
Expand Down
3 changes: 3 additions & 0 deletions stan.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ library
Stan.Inspection.AntiPattern
Stan.Inspection.Infinite
Stan.Inspection.Partial
Stan.Inspection.Performance
Stan.Inspection.Style
Stan.NameMeta
Stan.Observation
Expand Down Expand Up @@ -172,6 +173,7 @@ library target
Target.AntiPattern.Stan0214
Target.Infinite
Target.Partial
Target.Performance
Target.Style

test-suite stan-test
Expand All @@ -185,6 +187,7 @@ test-suite stan-test
Test.Stan.Analysis.Common
Test.Stan.Analysis.Infinite
Test.Stan.Analysis.Partial
Test.Stan.Analysis.Performance
Test.Stan.Analysis.Style
Test.Stan.Cli
Test.Stan.Config
Expand Down
14 changes: 14 additions & 0 deletions target/Target/Performance.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{-# OPTIONS_GHC -fno-warn-missing-export-lists #-}

module Target.Performance where

import Control.Monad.IO.Class (MonadIO)



foo :: (MonadIO m, Functor m) => m ()
foo = undefined

bar :: MonadIO m => Functor m => m ()
bar = undefined
{-# SPECIALIZE bar :: IO () #-}
Loading