Skip to content

Commit

Permalink
Merge pull request #205 from JuliaAI/dev
Browse files Browse the repository at this point in the history
For a 1.11.0 release
  • Loading branch information
ablaom committed Jun 27, 2024
2 parents 08414af + e3d2571 commit 8107f2c
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 3 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJModelInterface"
uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
authors = ["Thibaut Lienart and Anthony Blaom"]
version = "1.10.0"
version = "1.11.0"

[deps]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -19,7 +19,7 @@ OrderedCollections = "1"
Random = "<0.0.1, 1"
ScientificTypes = "3"
ScientificTypesBase = "3"
StatisticalTraits = "3.3"
StatisticalTraits = "3.4"
Tables = "1"
Test = "<0.0.1, 1"
julia = "1.6"
Expand Down
1 change: 1 addition & 0 deletions src/MLJModelInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ const MODEL_TRAITS = [
:predict_scitype,
:transform_scitype,
:inverse_transform_scitype,
:target_in_fit,
:is_pure_julia,
:package_name,
:package_license,
Expand Down
13 changes: 12 additions & 1 deletion src/model_traits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ end
StatTraits.is_supervised(::Type{<:Supervised}) = true
StatTraits.is_supervised(::Type{<:SupervisedAnnotator}) = true

StatTraits.target_in_fit(::Type{<:Supervised}) = true
StatTraits.target_in_fit(::Type{<:Unsupervised}) = false

StatTraits.prediction_type(::Type{<:Deterministic}) = :deterministic
StatTraits.prediction_type(::Type{<:Probabilistic}) = :probabilistic
StatTraits.prediction_type(::Type{<:Interval}) = :interval
Expand Down Expand Up @@ -73,7 +76,15 @@ function supervised_fit_data_scitype(M)
return ret
end

StatTraits.fit_data_scitype(M::Type{<:Unsupervised}) = Tuple{input_scitype(M)}
# helper to determine the scitype of unsupervised models
function unsupervised_fit_data_scitype(M)
I = input_scitype(M)
T = target_scitype(M)
target_in_fit(M) && return Tuple{I, T}
return Tuple{I}
end

StatTraits.fit_data_scitype(M::Type{<:Unsupervised}) = unsupervised_fit_data_scitype(M)
StatTraits.fit_data_scitype(::Type{<:Static}) = Tuple{}
StatTraits.fit_data_scitype(M::Type{<:Supervised}) = supervised_fit_data_scitype(M)

Expand Down
14 changes: 14 additions & 0 deletions test/model_traits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ end
@mlj_model mutable struct UA <: UnsupervisedAnnotator
end

@mlj_model mutable struct SupervisedTransformer <: Unsupervised
end


foo(::P1) = 0
bar(::P1) = nothing

Expand All @@ -34,6 +38,10 @@ M.package_name(::Type{<:U1}) = "Bach"
M.package_url(::Type{<:U1}) = "www.did_he_write_565.com"
M.human_name(::Type{<:U1}) = "funky model"

M.target_in_fit(::Type{<:SupervisedTransformer}) = true
M.target_scitype(::Type{<:SupervisedTransformer}) = Continuous
M.input_scitype(::Type{<:SupervisedTransformer}) = Finite

@testset "traits" begin
ms = S1()
mu = U1(a=42, b=sin)
Expand All @@ -42,6 +50,7 @@ M.human_name(::Type{<:U1}) = "funky model"
mi = I1()
sa = SA()
ua = UA()
supervised_transformer = SupervisedTransformer()

@test input_scitype(ms) == Unknown
@test output_scitype(ms) == Unknown
Expand Down Expand Up @@ -115,6 +124,11 @@ M.human_name(::Type{<:U1}) = "funky model"
setfull()

@test Set(implemented_methods(mp)) == Set([:clean!,:bar,:foo])

@test fit_data_scitype(mu) == Tuple{Unknown};;;
@test fit_data_scitype(mu) == Tuple{Unknown}
@test fit_data_scitype(supervised_transformer) == Tuple{Finite,Continuous}

end

@testset "`_density` - helper for predict_scitype fallback" begin
Expand Down

0 comments on commit 8107f2c

Please sign in to comment.