Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Rust] Refactor union case patterns #3932

Merged
merged 2 commits into from
Oct 20, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 66 additions & 87 deletions src/Fable.Transforms/Rust/Fable2Rust.fs
Original file line number Diff line number Diff line change
Expand Up @@ -1603,12 +1603,14 @@ module Util =
transformLeaveContext com ctx argType arg
)

let prepareRefForPatternMatch (com: IRustCompiler) ctx typ (name: string option) fableExpr =
let makeRefForPatternMatch (com: IRustCompiler) ctx typ (nameOpt: string option) fableExpr =
let expr = com.TransformExpr(ctx, fableExpr)

if isThisArgumentIdentExpr ctx fableExpr then
expr
elif (name.IsSome && isRefScoped ctx name.Value) || (isInRefType com typ) then
elif isInRefType com typ then
expr
elif nameOpt.IsSome && isRefScoped ctx nameOpt.Value then
expr
elif shouldBeRefCountWrapped com ctx typ |> Option.isSome then
expr |> makeAsRef
Expand Down Expand Up @@ -1856,13 +1858,13 @@ module Util =
entName + "::" + unionCase.Name
)

let getUnionCaseFields com ctx name tag (unionCase: Fable.UnionCase) =
unionCase.UnionCaseFields
|> List.mapi (fun i field ->
let fieldName = $"{name}_{tag}_{i}"
let fieldType = FableTransforms.uncurryType field.FieldType
makeTypedIdent fieldType fieldName
)
// let getUnionCaseFields com ctx name caseIndex (unionCase: Fable.UnionCase) =
// unionCase.UnionCaseFields
// |> List.mapi (fun i _field ->
// let fieldName = $"{name}_{caseIndex}_{i}"
// let fieldType = FableTransforms.uncurryType field.FieldType
// makeTypedIdent fieldType fieldName
// )

let makeUnion (com: IRustCompiler) ctx r values tag entRef genArgs =
let ent = com.GetEntity(entRef)
Expand Down Expand Up @@ -1991,7 +1993,7 @@ module Util =

let sourceIsRef =
match e with
| Fable.Get(Fable.IdentExpr ident, _, _, _)
| Fable.Get(Fable.IdentExpr ident, _, _, _) -> isArmScoped ctx ident.Name
| MaybeCasted(Fable.IdentExpr ident) -> isRefScoped ctx ident.Name
| _ -> false

Expand All @@ -2002,7 +2004,6 @@ module Util =
let mustClone =
match e with
| MaybeCasted(Fable.IdentExpr ident) ->
// isArmScoped ctx ident.Name ||
// clone non-mutable idents if used more than once
not (ident.IsMutable) && not (isUsedOnce ctx ident.Name) //&& not (isByRefType com ident.Type)
| Fable.Get(_, Fable.FieldGet _, _, _) -> true // always clone field get exprs
Expand Down Expand Up @@ -2558,7 +2559,7 @@ module Util =

let unionCaseName = getUnionCaseName com ctx info.Entity unionCase
let pat = makeUnionCasePat unionCaseName fields
let expr = fableExpr |> prepareRefForPatternMatch com ctx fableExpr.Type None
let expr = makeRefForPatternMatch com ctx fableExpr.Type None fableExpr
let thenExpr = mkGenericPathExpr [ fieldName ] None |> makeClone

let arms = [ mkArm [] pat None thenExpr ]
Expand Down Expand Up @@ -2774,7 +2775,7 @@ module Util =
| Fable.Test(Fable.IdentExpr ident, Fable.UnionCaseTest _, _) ->
// add scoped ident to ctx for thenBody
let usages = calcIdentUsages [ ident ] [ thenBody ]
getScopedIdentCtx com ctx ident true true false false usages
getScopedIdentCtx com ctx ident true false false false usages
| _ -> ctx

transformLeaveContext com ctx None thenBody
Expand Down Expand Up @@ -2888,32 +2889,60 @@ module Util =
mkLetExpr pat downcastExpr
| _ -> makeLibCall com ctx genArgsOpt "Native" "type_test" [ expr ]

let makeUnionCaseTest (com: IRustCompiler) ctx range tag (fableExpr: Fable.Expr) =
match fableExpr.Type with
| Fable.DeclaredType(entRef, genArgs) ->
let ent = com.GetEntity(entRef)
assert (ent.IsFSharpUnion)
// let genArgsOpt = transformGenArgs com ctx genArgs // TODO:
let unionCase = ent.UnionCases |> List.item tag
let makeUnionCasePatOpt (com: IRustCompiler) ctx typ nameOpt caseIndex =
match typ with
| Fable.Option(genArg, _) ->
// let genArgsOpt = transformGenArgs com ctx [genArg]
let unionCaseFullName = [ "Some"; "None" ] |> List.item caseIndex |> rawIdent

let fields =
match fableExpr with
| Fable.IdentExpr ident ->
let fieldIdents = getUnionCaseFields com ctx ident.Name tag unionCase
fieldIdents |> List.map (fun fi -> makeFullNameIdentPat fi.Name)
| _ ->
if List.isEmpty unionCase.UnionCaseFields then
[]
else
[ WILD_PAT ]
match caseIndex with
| 0 ->
match nameOpt with
| Some identName ->
let fieldName = $"{identName}_{caseIndex}_{0}"
[ makeFullNameIdentPat fieldName ]
| _ -> [ WILD_PAT ]
| _ -> []

let unionCaseName =
tryUseKnownUnionCaseNames unionCaseFullName
|> Option.defaultValue unionCaseFullName

let unionCaseName = getUnionCaseName com ctx entRef unionCase
let pat = makeUnionCasePat unionCaseName fields
Some(pat)
| Fable.DeclaredType(entRef, genArgs) ->
let ent = com.GetEntity(entRef)

let expr =
fableExpr
|> prepareRefForPatternMatch com ctx fableExpr.Type (tryGetIdentName fableExpr)
if ent.IsFSharpUnion then
// let genArgsOpt = transformGenArgs com ctx genArgs // TODO:
let unionCase = ent.UnionCases |> List.item caseIndex

let fields =
match nameOpt with
| Some identName ->
unionCase.UnionCaseFields
|> List.mapi (fun i _field ->
let fieldName = $"{identName}_{caseIndex}_{i}"
makeFullNameIdentPat fieldName
)
| _ -> unionCase.UnionCaseFields |> List.map (fun _field -> WILD_PAT)

let unionCaseName = getUnionCaseName com ctx entRef unionCase
let pat = makeUnionCasePat unionCaseName fields
Some(pat)
else
None
| _ -> None

let makeUnionCaseTest (com: IRustCompiler) ctx range tag (fableExpr: Fable.Expr) =
let typ = fableExpr.Type
let nameOpt = tryGetIdentName fableExpr
let patOpt = makeUnionCasePatOpt com ctx typ nameOpt tag

match patOpt with
| Some pat ->
let expr = makeRefForPatternMatch com ctx typ nameOpt fableExpr
let letExpr = mkLetExpr pat expr
letExpr
| _ -> failwith "unreachable"
Expand Down Expand Up @@ -2997,55 +3026,6 @@ module Util =

mkArm attrs pat guard body

let makeUnionCasePatOpt evalType evalName caseIndex =
match evalType with
| Fable.Option(genArg, _) ->
// let genArgsOpt = transformGenArgs com ctx [genArg]
let unionCaseFullName = [ "Some"; "None" ] |> List.item caseIndex |> rawIdent

let fields =
match evalName with
| Some idName ->
match caseIndex with
| 0 ->
let fieldName = $"{idName}_{caseIndex}_{0}"
[ makeFullNameIdentPat fieldName ]
| _ -> []
| _ -> [ WILD_PAT ]

let unionCaseName =
tryUseKnownUnionCaseNames unionCaseFullName
|> Option.defaultValue unionCaseFullName

Some(makeUnionCasePat unionCaseName fields)
| Fable.DeclaredType(entRef, genArgs) ->
let ent = com.GetEntity(entRef)

if ent.IsFSharpUnion then
// let genArgsOpt = transformGenArgs com ctx genArgs
let unionCase = ent.UnionCases |> List.item caseIndex

let fields =
match evalName with
| Some idName ->
unionCase.UnionCaseFields
|> List.mapi (fun i _field ->
let fieldName = $"{idName}_{caseIndex}_{i}"
makeFullNameIdentPat fieldName
)
| _ ->
if List.isEmpty unionCase.UnionCaseFields then
[]
else
[ WILD_PAT ]

let unionCaseName = getUnionCaseName com ctx entRef unionCase

Some(makeUnionCasePat unionCaseName fields)
else
None
| _ -> None

let evalType, evalName =
match evalExpr with
| Fable.Get(Fable.IdentExpr ident, Fable.UnionTag, _, _) -> ident.Type, Some ident.Name
Expand All @@ -3057,7 +3037,7 @@ module Util =
let patOpt =
match caseExpr with
| Fable.Value(Fable.NumberConstant(Fable.NumberValue.Int32 tag, Fable.NumberInfo.Empty), r) ->
makeUnionCasePatOpt evalType evalName tag
makeUnionCasePatOpt com ctx evalType evalName tag
| _ -> None

let pat =
Expand All @@ -3082,11 +3062,11 @@ module Util =
| Fable.Get(Fable.IdentExpr ident, Fable.OptionValue, _, _) when
Some ident.Name = evalName && ident.Type = evalType
->
makeUnionCasePatOpt evalType evalName 0
makeUnionCasePatOpt com ctx evalType evalName 0
| Fable.Get(Fable.IdentExpr ident, Fable.UnionField info, _, _) when
Some ident.Name = evalName && ident.Type = evalType
->
makeUnionCasePatOpt evalType evalName info.CaseIndex
makeUnionCasePatOpt com ctx evalType evalName info.CaseIndex
| _ ->
//need to recurse or this only works for trivial expressions
let subExprs = getSubExpressions expr
Expand All @@ -3098,8 +3078,7 @@ module Util =
let extraVals = namesForIndex evalType evalName targetIndex
makeArm pat targetIndex boundValues extraVals

let expr = evalExpr |> prepareRefForPatternMatch com ctx evalType evalName

let expr = makeRefForPatternMatch com ctx evalType evalName evalExpr
mkMatchExpr expr (arms @ [ defaultArm ])

let matchTargetIdentAndValues idents values =
Expand Down
Loading