Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
MultiData = "8cc5100c-b3d1-4f82-90cb-0ea93d317aba"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Query = "1a8c2f83-1ff3-5112-b086-8aa67b057ba1"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
67 changes: 67 additions & 0 deletions ext/SoleDataViz
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
module SoleDataViz

using SoleData
using GLMakie, GeometryBasics



function plot_scalardnf(
io::IO,
formula::DNF;
show_all_variables=false,
palette=[:cyan, :green, :yellow, :magenta, :blue],
body_char='=', # alternatives: ■, ━
gap = 0.0,
scale = Vec3f(1.0, 1.0, 0.5),
)
formula = normalize(formula)
disjs = SoleLogics.disjuncts(formula)
all_intervals = [extract_intervals(d) for d in disjs]
# Gather all variables
all_vars = Set{Any}()
for intervals in all_intervals, v in keys(intervals)
push!(all_vars, v)
end
all_vars = sort(collect(all_vars))
all_thresholds = collect_thresholds(all_intervals)

cube = fill(false, (length(all_vars), length(all_thresholds) - 1, length(disjs)))
# For each disjunct, produce a set of colored bars
for (i, (d, intervals)) in enumerate(zip(disjs, all_intervals))
for (j, var) in enumerate(all_vars)
if !haskey(intervals, var)
continue
end
interval = intervals[var]
segmenttypes, first_idx, last_idx = SoleData.compute_segmenttypes(interval, all_thresholds)
cube[j,:,i] = map(x->x.full, segmenttypes)
end
end

transf = Transformation(; scale) # anisotropic scaling

# Coloring index
color_index = [x for x in 1:size(cube, 1), y in 1:size(cube, 2), z in 1:size(cube, 3)]
color_data = ifelse.(cube, color_index, NaN)

# spacing = (1.5, 1.5, 2.0) # X, Y, Z

fig = Figure()
ax = LScene(fig[1, 1], scenekw = (camera = cam3d!,))

voxels!(ax,
0 .. (size(cube, 1)-1),
0 .. (size(cube, 2)-1),
0 .. (size(cube, 3)-1),
color_data;
colormap = palette,
colorrange = (1, length(palette)),
gap
# transformation = transf
)

fig
end


end
4 changes: 4 additions & 0 deletions src/SoleData.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ include("utils/multilogiset.jl")

export @scalarformula

module IntervalSetsWrap
using IntervalSets
end

include("scalar/main.jl")

export initlogiset, ninstances, maxchannelsize, worldtype, dimensionality, allworlds, featvalue
Expand Down
4 changes: 0 additions & 4 deletions src/scalar/conditions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -674,10 +674,6 @@ function _rangescalarcond_to_scalarconds_in_conjunction(cond)
conds
end

module IntervalSetsWrap
using IntervalSets: Interval
end

# function myisless(a::Number, aismin::Bool, b::Number, bismin::Bool)
# return a < b
# end
Expand Down
2 changes: 2 additions & 0 deletions src/scalar/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ include("random.jl")

include("representatives.jl")

include("visualizations.jl")

# # Types for representing common associations between features and operators
# include("canonical-conditions.jl") # TODO remove

Expand Down
4 changes: 3 additions & 1 deletion src/scalar/var-features.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import SoleData: AbstractFeature

using MultiData: instance_channel

import Base: show
import Base: show, isless
import SoleLogics: syntaxstring

# Feature parentheses (e.g., for parsing/showing "main[V2]")
Expand Down Expand Up @@ -301,6 +301,8 @@ struct VariableValue{I<:VariableId, N<:Union{VariableName, Nothing}} <: Abstract
end
featurename(f::VariableValue) = !isnothing(f.i_name) ? f.i_name : ""

Base.isless(a::VariableValue, b::VariableValue) = isless(a.i_variable, b.i_variable)

function syntaxstring(f::VariableValue; variable_names_map = nothing, show_colon = false, kwargs...)
if !isnothing(f.i_name)
opening_parenthesis = UVF_OPENING_PARENTHESIS
Expand Down
193 changes: 193 additions & 0 deletions src/scalar/visualizations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
export show_scalardnf

using Printf: @sprintf
using SoleData: AbstractScalarCondition

using .IntervalSetsWrap: infimum, supremum, isleftopen, isrightopen, :(..)

function extract_intervals(conjunction::LeftmostConjunctiveForm)
_extract_intervals(SoleLogics.conjuncts(conjunction), true)
end
function extract_intervals(disjunction::LeftmostDisjunctiveForm)
_extract_intervals(SoleLogics.disjuncts(disjunction), false)
end

# polarity = true computes the intersection
# polarity = false computes the union
function _extract_intervals(atoms::Vector, polarity::Bool)
by_var = Dict{Any, IntervalSetsWrap.Interval}()

for atom in atoms
@assert atom isa Atom
cond = SoleLogics.value(atom)
@assert cond isa AbstractScalarCondition typeof(cond)
feat = SoleData.feature(cond)
interval = tointervalset(cond)
by_var[feat] = if haskey(by_var, feat)
# Update borders
if polarity
by_var[feat] ∩ interval
else
by_var[feat] ∪ interval
end
else
interval
end
end
return by_var
end


function collect_thresholds(all_intervals)
thresholds = Set{Number}()
for intervals in all_intervals
for interval in values(intervals)
push!(thresholds, infimum(interval))
push!(thresholds, supremum(interval))
end
end
return sort(collect(thresholds))
end

mutable struct IntervalType
left_closed::Bool
full::Bool
right_closed::Bool
end

function compute_segmenttypes(interval, thresholds)
(minv, mini, maxv, maxi) = infimum(interval), !isleftopen(interval), supremum(interval), !isrightopen(interval)
nseg = length(thresholds) - 1

isnonempty = [begin
t0 = thresholds[i]
t1 = thresholds[i+1]
if maxv <= t0 || minv >= t1
false
else
true
end
end for i in 1:nseg]

# Default all non-empty segmenttypes to be closed
segmenttypes = map(i-> IntervalType(i, i, i), isnonempty)
first_idx = findfirst(identity, isnonempty)
last_idx = findlast(identity, isnonempty)
if !mini && first_idx !== nothing
segmenttypes[first_idx].left_closed = false
end
if !maxi && last_idx !== nothing
segmenttypes[last_idx].right_closed = false
end
return segmenttypes, first_idx, last_idx
end

function draw_bar(segmenttypes, first_idx, last_idx; colwidth=5, body_char = "=")

segments_str = fill(" " ^ colwidth, length(segmenttypes))

for i in 1:length(segmenttypes)
if segmenttypes[i].full
segments_str[i] = body_char ^ colwidth
end
end

if !isnothing(first_idx)
segments_str[first_idx] = let x = collect(segments_str[first_idx])
x[1] = (segmenttypes[first_idx].left_closed ? '[' : '(')
String(x)
end
end
if !isnothing(last_idx)
segments_str[last_idx] = let x = collect(segments_str[last_idx])
x[colwidth] = (segmenttypes[last_idx].right_closed ? ']' : ')')
String(x)
end
end

return " " ^ colwidth * join(segments_str)
end


show_scalardnf(f::DNF; kwargs...) = show_scalardnf(stdout, f; kwargs...)

"""
Produce a graphical representation for a scalar DNF formula.

- `show_all_variables::Bool = false`: whether to force the printing of always-true variable constraints (e.g., \$-∞ <= V1 <= ∞\$)
- `print_disjunct_nrs::Bool = false`: whether to print the progressive number for each disjunct
- `palette::Vector`
"""
function show_scalardnf(
io::IO,
formula::DNF;
show_all_variables=false,
print_disjunct_nrs=false,
palette=[:cyan, :green, :yellow, :magenta, :blue],
colwidth=5,
body_char='=', # alternatives: ■, ━
)
@assert colwidth >= 5
formula = normalize(formula)
disjs = SoleLogics.disjuncts(formula)
all_intervals = [extract_intervals(d) for d in disjs]
# Gather all variables
all_vars = Set{Any}()
for intervals in all_intervals, v in keys(intervals)
push!(all_vars, v)
end
all_vars = sort(collect(all_vars))
all_thresholds = collect_thresholds(all_intervals)

# Maximum length for variable names
namewidth = maximum(length(syntaxstring(v)) for v in all_vars)

# header
header = " " ^ (3+colwidth+namewidth)
for t in all_thresholds
header *= @sprintf("%-*.*f", colwidth, 2, t)
end
println(io, header)
println(io)

# Variable-color mapping
colors = Dict()
var_colors = Dict{Any,Symbol}()
for (i, v) in enumerate(all_vars)
var_colors[v] = get(colors, v, palette[(i-1) % length(palette) + 1])
end

# For each disjunct, produce a set of colored bars
for (i, (d, intervals)) in enumerate(zip(disjs, all_intervals))
print_disjunct_nrs && println(io, "Disjunct $i: ", syntaxstring(normalize(d)))
for v in all_vars
if !haskey(intervals, v)
if show_all_variables
interval = IntervalSetsWrap.Interval(-Inf .. Inf)
else
continue
end
else
interval = intervals[v]
end

# Variable is in disjunct
if interval == IntervalSetsWrap.Interval(-Inf .. Inf) && !show_all_variables
# Avoid showing all variables
continue
end

segmenttypes, first_idx, last_idx = compute_segmenttypes(interval, all_thresholds)

bar = draw_bar(segmenttypes, first_idx, last_idx; colwidth, body_char)
# colore
color = var_colors[v]
# stampo nome e barre
print(io, " ")
printstyled(io, rpad(syntaxstring(v), namewidth), " : ", bar, color=color)
println(io)
end
println(io)
end
end

1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ test_suites = [
("Conditions", [ "range-scalar-condition.jl", ]),
("Alphabets", [ "scalar-alphabet.jl", "discretization.jl"]),
("Features", [ "patchedfeatures.jl"]),
("Visualizations", [ "visualizations.jl", ]),
#
("MLJ", [ "MLJ.jl", ]),
("PLA", [ "pla.jl", ]),
Expand Down
62 changes: 62 additions & 0 deletions test/visualizations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
using SoleData

@testset "Visualizations" begin
f = @scalarformula(
((V1 < 5.85) ∧ (V1 ≥ 5.65) ∧ (V2 < 2.85) ∧ (V3 < 4.55) ∧ (V3 ≥ 4.45)) ∨
((V1 < 5.3) ∧ (V2 ≥ 2.85) ∧ (V3 < 5.05) ∧ (V3 ≥ 4.85) ∧ (V4 < 0.35))
) |> dnf

check_same(a, b) = myclean(a) == myclean(b)
function myclean(s)
lines = collect(split(s, '\n'))
join(filter(x->x != "", map(x -> rstrip(x), lines)), "\n")
end

@test myclean(
begin
io = IOBuffer()
show_scalardnf(
io,
f;
colwidth=6,
)
String(take!(io))
end) == myclean("
-Inf 0.35 2.85 4.45 4.55 4.85 5.05 5.30 5.65 5.85 Inf

V1 : [====)
V2 : [==========)
V3 : [====)

V1 : [========================================)
V2 : [==============================================]
V3 : [====)
V4 : [====)
")

@test myclean(
begin
io = IOBuffer()
show_scalardnf(
io,
f;
show_all_variables=true,
colwidth=6,
)
String(take!(io))
end) == myclean("
-Inf 0.35 2.85 4.45 4.55 4.85 5.05 5.30 5.65 5.85 Inf

V1 : [====)
V2 : [==========)
V3 : [====)
V4 : [==========================================================]

V1 : [========================================)
V2 : [==============================================]
V3 : [====)
V4 : [====)

")

end
Loading