-
Notifications
You must be signed in to change notification settings - Fork 8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add StatsBase.predict to the interface #81
base: main
Are you sure you want to change the base?
Conversation
Bump, maybe @devmotion or @torfjelde? |
Bump again. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cheers for the bump; had missed this!
It's worth noting that DPPL is still not compatible wtih [email protected] so we might also want to add this to [email protected].
Furthermore, I'm slightly worried about the state of AbstractPPL atm; it's not clear if anyone has any ownership of the package atm, and IMO it's objectives are a bit all over the place.
I'd personally be happy to go against what was originally suggested in TuringLang/DynamicPPL.jl#466 (comment) and just putting this directly in DPPL.
Or we need to start giving AbstractPPL some love 😕
@sunxd3 can help backport this to It would be great to update |
I can try and help bring |
@sunxd3 It is related to changing behavior of the colon syntax. You can follow this issue TuringLang/DynamicPPL.jl#440 and the issues it linked. We can discuss this more in our next meeting. |
Codecov ReportAttention:
Additional details and impacted files@@ Coverage Diff @@
## main #81 +/- ##
==========================================
- Coverage 84.82% 80.39% -4.44%
==========================================
Files 3 3
Lines 145 153 +8
==========================================
Hits 123 123
- Misses 22 30 +8
☔ View full report in Codecov by Sentry. |
@torfjelde @sunxd3 @penelopeysm, anything missing here? If not, can we push to merge this? |
as far as I can tell, we can introduce On a higher level, we can also add (I need to finish TuringLang/DynamicPPL.jl#651) |
I think @sethaxen might be preoccupied, so I am taking over. Let me know if this is bad. |
That's fine @sunxd3 , this has shifted way down on my priority list and I won't finish it anytime soon. |
apologies for the ping, this might not be ready yet, but maybe time to take a look and start some new discussions |
I guess the one question is whether we should perform |
just saw the comment, sorry. I thought about the same thing, but unsure what is the right thing to do. The issue for me is that the dimension of the prediction might not match the dimension of the data. How about we don't give a default implementation right now? To clarify, the default implementations for the optional arguments should be included, but not function StatsBase.predict(rng::AbstractRNG, T::Type, model::AbstractProbabilisticProgram, params)
return rand(rng, T, fix(model, params))
end |
function StatsBase.predict(rng::AbstractRNG, T::Type, model::AbstractProbabilisticProgram, params) | ||
return rand(rng, T, fix(model, params)) | ||
end | ||
function StatsBase.predict(T::Type, model::AbstractProbabilisticProgram, params) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@torfjelde @yebai @penelopeysm do you think type T
still a good idea?
I think it's probably okay for the function to be a bit under-speced now, so DynamicPPL and JuliaBUGS and others can decide what type to return.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure we need a concrete implementation of predict
here; usually, an interface function is a generic function with docstrings explaining the interface (input arguments + returned value).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think when Seth added this, it was modeled after the rand
interface
AbstractPPL.jl/src/abstractprobprog.jl
Lines 108 to 124 in b72a963
""" | |
rand([rng=Random.default_rng()], [T=NamedTuple], model::AbstractProbabilisticProgram) -> T | |
Draw a sample from the joint distribution of the model specified by the probabilistic program. | |
The sample will be returned as format specified by `T`. | |
""" | |
Base.rand(rng::Random.AbstractRNG, ::Type, model::AbstractProbabilisticProgram) | |
function Base.rand(rng::Random.AbstractRNG, model::AbstractProbabilisticProgram) | |
return rand(rng, NamedTuple, model) | |
end | |
function Base.rand(::Type{T}, model::AbstractProbabilisticProgram) where {T} | |
return rand(Random.default_rng(), T, model) | |
end | |
function Base.rand(model::AbstractProbabilisticProgram) | |
return rand(Random.default_rng(), NamedTuple, model) | |
end |
I am for just having a simple predict
now, or at most with rng
, but not output type T
.
Moreover, should we slim down the rand
interface also, as this is going to be a breaking release.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we want DynamicPPL function to have a type argument T
, it makes sense for this one to have T
as well. Otherwise I don't really see the point of having some interface here and then giving it a different signature in DynamicPPL, which completely ignores the interface.
(Likewise for JuliaBUGS or any other package that inherits this interface)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree with this, my comments only reflects that we don't have T argument right now if I recall correctly
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ohhh, I see! For a general interface, though, if we don't specify T
then we have to choose a privileged output type (like NamedTuple) right? Otherwise if it can return anything then it's not super useful either.
What do you think of something like this:
# default_return_type(model) specifies the default type returned by
# rand([rng, ]model) and predict([rng, ]model, params)
function default_return_type end
# Then we can have rand like this
function Base.rand(
rng::Random.AbstractRNG = Random.default_rng(),
::Type{T} = default_return_type(model),
model::AbstractProbabilisticProgram)
)
AbstractPPL._rand(rng, T, model) # User has to implement this
end
# And predict like this
function StatsBase.predict(
rng::Random.AbstractRNG = Random.default_rng(),
::Type{T} = default_return_type(model),
model::AbstractProbabilisticProgram),
params
)
AbstractPPL._predict(rng, T, model, params) # User has to implement this
end
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Basically, I don't feel super comfortable only being able to return NamedTuple. I think the user should be allowed to choose what return type they want (in our case sometimes we might want varinfo). Enforcing a specific return type at this level might be too limiting. I also know I'm possibly overcomplicating things, sorry 😄
Also, I don't know how this would interact with different params types as well. Because the output type would surely depend on whether we pass in one set of params (e.g. a NamedTuple) or multiple sets of params (e.g. a chain). 🤔
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Personally, I'm happy to merge as is.
This PR doesn't bump version. However, the version was previously bumped in #109 to 0.10.0 and we didn't make a release, so I think it's fine to merge this and then release 0.10.0.
As suggested in TuringLang/DynamicPPL.jl#466 (comment), this PR adds
StatsBase.predict
to the API with a default implementation.