diff --git a/src/Onda.jl b/src/Onda.jl index 81c3c74..6b4f648 100644 --- a/src/Onda.jl +++ b/src/Onda.jl @@ -1,6 +1,6 @@ module Onda -using Compat: @compat +using Compat: @compat, allequal using UUIDs, Dates, Random, Mmap using Compat, Legolas, TimeSpans, Arrow, Tables, TranscodingStreams, CodecZstd using Legolas: @schema, @version, write_full_path diff --git a/src/deprecations.jl b/src/deprecations.jl index b71111f..9c691fc 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -94,3 +94,15 @@ function upgrade(from::SignalV1, ::SignalV2SchemaVersion) from.channels, from.sample_unit, from.sample_resolution_in_unit, from.sample_offset_in_unit, from.sample_type, from.sample_rate) end + +# Not quite a deprecation, but we will backport `record_merge` for our own purposes +if pkgversion(Legolas) < v"0.5.18" + function record_merge(record::Legolas.AbstractRecord; fields_to_merge...) + # Avoid using `typeof(record)` as can cause constructor failures with parameterized + # record types. + R = Legolas.record_type(Legolas.schema_version_from_record(record)) + return R(Tables.rowmerge(record; fields_to_merge...)) + end +else + using Legolas: record_merge +end diff --git a/src/samples.jl b/src/samples.jl index f24e826..00b4e2c 100644 --- a/src/samples.jl +++ b/src/samples.jl @@ -186,6 +186,35 @@ function _column_arguments(samples::Samples, x) return _indices_fallback(_column_arguments, samples, x) end +##### +##### operations +##### + +# Ensure we don't match `vcat()` since that would be piracy +function Base.vcat(first_samples::Samples, more_samples::Samples...) + samples = (first_samples, more_samples...) + for field in setdiff(fieldnames(SamplesInfoV2), [:channels]) + if !allequal(getfield(s.info, field) for s in samples) + throw(ArgumentError("Cannot `vcat` samples objects which do not all have the same `$field`. Got values: $([getfield(s.info, field) for s in samples])")) + end + end + if !allequal(duration(s) for s in samples) + throw(ArgumentError("Cannot `vcat` samples objects which do not all have the same duration. Got values: $([duration(s) for s in samples])")) + end + if !allequal((s.encoded for s in samples)) + throw(ArgumentError("Cannot `vcat` samples objects which are not all encoded or all decoded. Got encoding values: $([s.encoded for s in samples])")) + end + all_channels = collect(Iterators.flatten(s.info.channels for s in samples)) + if !allunique(all_channels) + throw(ArgumentError("Cannot `vcat` samples objects which do not have unique channel names. Got channel names: $(all_channels)")) + end + # We checked all fields match except `channels`, so we can start with the first one and update the channels + # (we also know `samples` is non-empty by the signature) + info = record_merge(first(samples).info; channels=all_channels) + data = vcat((s.data for s in samples)...) + return Samples(data, info, first(samples).encoded) +end + ##### ##### encoding utilities ##### diff --git a/test/samples.jl b/test/samples.jl index 935bf07..b432f3e 100644 --- a/test/samples.jl +++ b/test/samples.jl @@ -247,6 +247,65 @@ end @test hash(samples) == hash(samples2) end +@testset "Base.vcat" begin + info = SamplesInfoV2(sensor_type="eeg", + channels=["a", "b", "c"], + sample_unit="unit", + sample_resolution_in_unit=1.0, + sample_offset_in_unit=0.0, + sample_type=Float32, + sample_rate=100.0) + + for encoded in (true, false) + samples1 = Samples(rand(sample_type(info), 3, 100), info, encoded) + + # Note: `record_merge` is defined in newer Legolas versions, but for the purposes of backwards compatibility + # with old Arrow versions, we've backported for internal use in Onda. + info2 = Onda.record_merge(info; channels = ["d", "e", "f"]) + samples2 = Samples(rand(sample_type(info2), 3, 100), info2, encoded) + + samples12 = vcat(samples1, samples2) + @test samples12.data[1:3, :] == samples1.data + @test samples12.data[4:6, :] == samples2.data + @test samples12.info.channels == map(string, 'a':'f') + end + + samples1 = Samples(rand(sample_type(info), 3, 100), info, true) + info2 = Onda.record_merge(info; channels = ["d", "e", "f"]) + + err = ArgumentError("""Cannot `vcat` samples objects which do not have unique channel names. Got channel names: ["a", "b", "c", "a", "b", "c"]""") + @test_throws err vcat(samples1, samples1) + + samples2 = Samples(rand(sample_type(info), 3, 100), Onda.record_merge(info2; sample_rate = 10), true) + err = ArgumentError("Cannot `vcat` samples objects which do not all have the same `sample_rate`. Got values: [100.0, 10.0]") + @test_throws err vcat(samples1, samples2) + + samples2 = Samples(rand(Float64, 3, 100), Onda.record_merge(info2; sample_type = Float64), true) + err = ArgumentError("""Cannot `vcat` samples objects which do not all have the same `sample_type`. Got values: ["float32", "float64"]""") + @test_throws err vcat(samples1, samples2) + + samples2 = Samples(rand(sample_type(info), 3, 100), Onda.record_merge(info2; sensor_type = "eeg2"), true) + err = ArgumentError("""Cannot `vcat` samples objects which do not all have the same `sensor_type`. Got values: ["eeg", "eeg2"]""") + @test_throws err vcat(samples1, samples2) + + samples2 = Samples(rand(sample_type(info), 3, 100), Onda.record_merge(info2; sample_unit = "unit2"), true) + err = ArgumentError("""Cannot `vcat` samples objects which do not all have the same `sample_unit`. Got values: ["unit", "unit2"]""") + @test_throws err vcat(samples1, samples2) + + samples2 = Samples(rand(sample_type(info), 3, 100), Onda.record_merge(info2; sample_resolution_in_unit = 5), true) + err = ArgumentError("""Cannot `vcat` samples objects which do not all have the same `sample_resolution_in_unit`. Got values: [1.0, 5.0]""") + @test_throws err vcat(samples1, samples2) + + samples2 = Samples(rand(sample_type(info), 3, 100), Onda.record_merge(info2; sample_offset_in_unit = 5), true) + err = ArgumentError("""Cannot `vcat` samples objects which do not all have the same `sample_offset_in_unit`. Got values: [0.0, 5.0]""") + @test_throws err vcat(samples1, samples2) + + samples2 = Samples(rand(sample_type(info), 3, 100), info2, false) + err = ArgumentError("""Cannot `vcat` samples objects which are not all encoded or all decoded. Got encoding values: Bool[1, 0]""") + @test_throws err vcat(samples1, samples2) +end + + @testset "Samples views" begin info = SamplesInfoV2(sensor_type="eeg",