diff --git a/interval/delete.go b/interval/delete.go index effd0e7..88019e6 100644 --- a/interval/delete.go +++ b/interval/delete.go @@ -13,8 +13,9 @@ func (st *SearchTree[V, T]) Delete(start, end T) error { } intervl := interval[V, T]{ - start: start, - end: end, + start: start, + end: end, + allowPoint: st.config.allowIntervalPoint, } if intervl.isInvalid(st.cmp) { @@ -142,8 +143,9 @@ func (st *MultiValueSearchTree[V, T]) Delete(start, end T) error { } intervl := interval[V, T]{ - start: start, - end: end, + start: start, + end: end, + allowPoint: st.config.allowIntervalPoint, } if intervl.isInvalid(st.cmp) { diff --git a/interval/delete_test.go b/interval/delete_test.go index 114d80e..9e5f3b2 100644 --- a/interval/delete_test.go +++ b/interval/delete_test.go @@ -59,6 +59,20 @@ func TestSearchTree_Delete(t *testing.T) { } } +func TestSearchTree_Delete_PointInterval(t *testing.T) { + cmpFunc := func(x, y int) int { return x - y } + st := NewSearchTreeWithOptions[int](cmpFunc, TreeWithIntervalPoint()) + + start, end := 17, 17 + if err := st.Insert(start, end, 0); err != nil { + t.Fatalf("st.Insert(%v, %v): got unexpected error: %v", start, end, err) + } + + if err := st.Delete(start, end); err != nil { + t.Errorf("st.Delete(%v, %v): got unexpected error: %v", start, end, err) + } +} + func TestSearchTree_Delete_EmptyTree(t *testing.T) { st := NewSearchTree[any](func(x, y int) int { return x - y }) @@ -83,10 +97,30 @@ func TestSearchTree_Delete_Error(t *testing.T) { st := NewSearchTree[any](func(x, y int) int { return x - y }) st.Insert(5, 10, nil) - start, end := 10, 4 - err := st.Delete(start, end) - if err == nil { - t.Errorf("st.Delete(%v, %v): got nil error", start, end) + testCases := []struct { + name string + start, end int + }{ + { + name: "EndSmallerThenStart", + start: 10, + end: 5, + }, + { + name: "PointInterval", + start: 10, + end: 10, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + err := st.Delete(tc.start, tc.end) + if err == nil { + t.Errorf("st.Delete(%v, %v): got nil error", tc.start, tc.end) + } + }) } }) } @@ -225,6 +259,20 @@ func TestMultiValueSearchTree_Delete(t *testing.T) { } } +func TestMultiSearchSearchTree_Delete_PointInterval(t *testing.T) { + cmpFunc := func(x, y int) int { return x - y } + st := NewMultiValueSearchTreeWithOptions[int](cmpFunc, TreeWithIntervalPoint()) + + start, end := 17, 17 + if err := st.Insert(start, end, 0); err != nil { + t.Fatalf("st.Insert(%v, %v): got unexpected error: %v", start, end, err) + } + + if err := st.Delete(start, end); err != nil { + t.Errorf("st.Delete(%v, %v): got unexpected error: %v", start, end, err) + } +} + func TestMultiValueSearchTree_Delete_EmptyTree(t *testing.T) { st := NewMultiValueSearchTree[any](func(x, y int) int { return x - y }) @@ -249,10 +297,29 @@ func TestMultiValueSearchTree_Delete_Error(t *testing.T) { st := NewMultiValueSearchTree[any](func(x, y int) int { return x - y }) st.Insert(5, 10, nil) - start, end := 10, 4 - err := st.Delete(start, end) - if err == nil { - t.Errorf("st.Delete(%v, %v): got nil error", start, end) + testCases := []struct { + name string + start, end int + }{ + { + name: "EndSmallerThanStart", + start: 10, + end: 4, + }, + { + name: "PointInterval", + start: 10, + end: 10, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := st.Delete(tc.start, tc.end) + if err == nil { + t.Errorf("st.Delete(%v, %v): got nil error", tc.start, tc.end) + } + }) } }) } diff --git a/interval/example_test.go b/interval/example_test.go index 73eff9f..0d30b29 100644 --- a/interval/example_test.go +++ b/interval/example_test.go @@ -123,3 +123,26 @@ func ExampleMultiValueSearchTree_Upsert() { // Output: // [event4 event5] true } + +func ExampleTreeWithIntervalPoint() { + cmpFn := func(start, end time.Time) int { + switch { + case start.After(end): + return 1 + case start.Before(end): + return -1 + default: + return 0 + } + } + + st := interval.NewSearchTreeWithOptions[string](cmpFn, interval.TreeWithIntervalPoint()) + + pointInerval := time.Now() + st.Insert(pointInerval, pointInerval, "event") + + vals, ok := st.Find(pointInerval, pointInerval) + fmt.Println(vals, ok) + // Output: + // event true +} diff --git a/interval/insert.go b/interval/insert.go index 2afbe9b..6aea229 100644 --- a/interval/insert.go +++ b/interval/insert.go @@ -14,9 +14,10 @@ func (st *SearchTree[V, T]) Insert(start, end T, val V) error { defer st.mu.Unlock() intervl := interval[V, T]{ - start: start, - end: end, - val: val, + start: start, + end: end, + val: val, + allowPoint: st.config.allowIntervalPoint, } if intervl.isInvalid(st.cmp) { @@ -73,9 +74,10 @@ func newEmptyValueListError[V, T any](it interval[V, T], action string) error { // or an EmptyValueListError if vals is an empty list. func (st *MultiValueSearchTree[V, T]) Insert(start, end T, vals ...V) error { intervl := interval[V, T]{ - start: start, - end: end, - vals: vals, + start: start, + end: end, + vals: vals, + allowPoint: st.config.allowIntervalPoint, } if intervl.isInvalid(st.cmp) { @@ -123,9 +125,10 @@ func insert[V, T any](n *node[V, T], intervl interval[V, T], cmp CmpFunc[T]) *no // or an EmptyValueListError if vals is an empty list. func (st *MultiValueSearchTree[V, T]) Upsert(start, end T, vals ...V) error { intervl := interval[V, T]{ - start: start, - end: end, - vals: vals, + start: start, + end: end, + vals: vals, + allowPoint: st.config.allowIntervalPoint, } if intervl.isInvalid(st.cmp) { diff --git a/interval/insert_test.go b/interval/insert_test.go index 2bb37f1..696b7a7 100644 --- a/interval/insert_test.go +++ b/interval/insert_test.go @@ -37,14 +37,57 @@ func TestSearchTree_Insert_UpdateValue(t *testing.T) { } } +func TestSearchTree_Insert_PointInterval(t *testing.T) { + cmpFunc := func(x, y int) int { return x - y } + st := NewSearchTreeWithOptions[string, int](cmpFunc, TreeWithIntervalPoint()) + + start, end := 16, 16 + val := "point-interval" + + err := st.Insert(start, end, val) + if err != nil { + t.Fatalf("st.Insert(%v, %v): got unexpected error: %v", start, end, err) + } + + got, ok := st.Find(start, end) + if !ok { + t.Errorf("st.Find(%v, %v): got not interval", start, end) + } + + if want := val; got != want { + t.Errorf("st.Find(%v, %v): got unexpected value %v; want %v", start, end, got, want) + } +} + func TestSearchTree_Insert_Error(t *testing.T) { - t.Run("InvalidRange", func(t *testing.T) { + t.Run("InvalidInterval", func(t *testing.T) { st := NewSearchTree[int](timeCmp) - start, end := time.Now(), time.Now().Add(-(1 * time.Hour)) - err := st.Insert(start, end, 0) - if err == nil { - t.Errorf("st.Insert(%v, %v): got nil error", start, end) + now := time.Now() + testCases := []struct { + name string + start, end time.Time + }{ + { + name: "EndBeforeStart", + start: now, + end: now.Add(-(1 * time.Hour)), + }, + { + name: "PointInterval", + start: now, + end: now, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := st.Insert(tc.start, tc.end, 0) + var wantErr InvalidIntervalError + if !errors.As(err, &wantErr) { + t.Errorf("st.Insert(%v, %v, 0): got error type %T; want it to be %T", tc.start, tc.end, err, wantErr) + } + }) } }) } @@ -80,16 +123,58 @@ func TestMultiValueSearchTree_Insert(t *testing.T) { } } +func TestMultiValueSearchTree_Insert_PointInterval(t *testing.T) { + cmpFunc := func(x, y int) int { return x - y } + st := NewMultiValueSearchTreeWithOptions[string, int](cmpFunc, TreeWithIntervalPoint()) + + vals := []string{"value1", "value2"} + start, end := 17, 17 + + err := st.Insert(start, end, vals...) + if err != nil { + t.Fatalf("MultiValueSearchTree.Insert(%v, %v): got unexpected error: %v", start, end, err) + } + + got, ok := st.Find(start, end) + if !ok { + t.Fatalf("st.Find(%v, %v): got no interval value; want %v", start, end, vals) + } + + if want := vals; !reflect.DeepEqual(got, want) { + t.Errorf("st.Find(%v, %v): got unexpected value %q; want %q", start, end, got, want) + } +} + func TestMultiValueSearchTree_Insert_Error(t *testing.T) { - t.Run("InvalidRange", func(t *testing.T) { + t.Run("InvalidInterval", func(t *testing.T) { st := NewMultiValueSearchTree[int](timeCmp) - start, end := time.Now(), time.Now().Add(-(1 * time.Hour)) - err := st.Insert(start, end, 0) + now := time.Now() + testCases := []struct { + name string + start, end time.Time + }{ + { + name: "EndBeforeStart", + start: now, + end: now.Add(-(1 * time.Hour)), + }, + { + name: "PointInterval", + start: now, + end: now, + }, + } - var wantErr InvalidIntervalError - if !errors.As(err, &wantErr) { - t.Errorf("st.Insert(%v, %v, 0): got error type %T; want it to be %T", start, end, err, wantErr) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := st.Insert(tc.start, tc.end, 0) + + var wantErr InvalidIntervalError + if !errors.As(err, &wantErr) { + t.Errorf("st.Insert(%v, %v, 0): got error type %T; want it to be %T", tc.start, tc.end, err, wantErr) + } + }) } }) @@ -137,16 +222,58 @@ func TestMultiValueSearchTree_Upsert(t *testing.T) { } } +func TestMultiValueSearchTree_Upsert_PointInterval(t *testing.T) { + cmpFunc := func(x, y int) int { return x - y } + st := NewMultiValueSearchTreeWithOptions[string, int](cmpFunc, TreeWithIntervalPoint()) + + vals := []string{"value1", "value2"} + start, end := 17, 17 + + err := st.Upsert(start, end, vals...) + if err != nil { + t.Fatalf("MultiValueSearchTree.Upsert(%v, %v): got unexpected error: %v", start, end, err) + } + + got, ok := st.Find(start, end) + if !ok { + t.Fatalf("st.Find(%v, %v): got no interval value; want %v", start, end, vals) + } + + if want := vals; !reflect.DeepEqual(got, want) { + t.Errorf("st.Find(%v, %v): got unexpected value %q; want %q", start, end, got, want) + } +} + func TestMultiValueSearchTree_Upsert_Error(t *testing.T) { t.Run("InvalidRange", func(t *testing.T) { st := NewMultiValueSearchTree[int](timeCmp) - start, end := time.Now(), time.Now().Add(-(1 * time.Hour)) - err := st.Upsert(start, end, 0) + now := time.Now() + testCases := []struct { + name string + start, end time.Time + }{ + { + name: "EndBeforeStart", + start: now, + end: now.Add(-(1 * time.Hour)), + }, + { + name: "PointInterval", + start: now, + end: now, + }, + } - var wantErr InvalidIntervalError - if !errors.As(err, &wantErr) { - t.Errorf("st.Upsert(%v, %v, 0): got error type %T; want it to be %T", start, end, err, wantErr) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := st.Upsert(tc.start, tc.end, 0) + + var wantErr InvalidIntervalError + if !errors.As(err, &wantErr) { + t.Errorf("st.Upsert(%v, %v, 0): got error type %T; want it to be %T", tc.start, tc.end, err, wantErr) + } + }) } }) diff --git a/interval/interval.go b/interval/interval.go index 1201398..e725741 100644 --- a/interval/interval.go +++ b/interval/interval.go @@ -1,6 +1,9 @@ package interval -import "fmt" +import ( + "fmt" + "strings" +) // InvalidIntervalError is a description of an invalid interval. type InvalidIntervalError string @@ -11,8 +14,13 @@ func (s InvalidIntervalError) Error() string { } func newInvalidIntervalError[V, T any](it interval[V, T]) error { - s := fmt.Sprintf("interval search tree invalid range: start value %v cannot be less than or equal to end value %v", it.start, it.end) - return InvalidIntervalError(s) + var b strings.Builder + fmt.Fprintf(&b, "interval search tree invalid range: start value %v cannot be less than ", it.start) + if !it.allowPoint { + b.WriteString("or equal to ") + } + fmt.Fprintf(&b, "end value %v", it.end) + return InvalidIntervalError(b.String()) } // CmpFunc must return a nagative integer, zero or a positive interger as x is @@ -45,13 +53,17 @@ func (f CmpFunc[T]) gte(x, y T) bool { } type interval[V, T any] struct { - start T - end T - val V - vals []V + start T + end T + val V + vals []V + allowPoint bool } func (it interval[V, T]) isInvalid(cmp CmpFunc[T]) bool { + if it.allowPoint { + return cmp.lt(it.end, it.start) + } return cmp.lte(it.end, it.start) } diff --git a/interval/search_tree.go b/interval/search_tree.go index cfc48b7..27a034c 100644 --- a/interval/search_tree.go +++ b/interval/search_tree.go @@ -8,6 +8,7 @@ // For more on interval trees, see https://en.wikipedia.org/wiki/Interval_tree // // To create a tree with time.Time as interval key type and string as value type: +// // cmpFn := func(t1, t2 time.Time) int { // switch{ // case t1.After(t2): return 1 @@ -15,25 +16,47 @@ // default: return 0 // } // } -// st := interval.NewSearchTree[string](cmpFn) +// st := interval.NewSearchTree[string](cmpFn) package interval import ( "sync" ) +// TreeConfig contains configuration fields that are used to customize the behavior +// of interval trees, specifically SearchTree and MultiValueSearchTree types. +type TreeConfig struct { + allowIntervalPoint bool +} + +// TreeOption is a functional option type used to customize the behavior +// of interval trees, such as the SearchTree and MultiValueSearchTree types. +type TreeOption func(*TreeConfig) + +// TreeWithIntervalPoint returns a TreeOption function that configures an interval tree to accept intervals +// in which the start and end key values are the same, effectively representing a point rather than a range in the tree. +func TreeWithIntervalPoint() TreeOption { + return func(c *TreeConfig) { + c.allowIntervalPoint = true + } +} + // SearchTree is a generic type representing the Interval Search Tree // where V is a generic value type, and T is a generic interval key type. +// For more details on how to use these configuration options, see the TreeOption +// function and their usage in the NewSearchTreeWithOptions and NewMultiValueSearchTreeWithOptions functions. type SearchTree[V, T any] struct { - mu sync.RWMutex // used to serialize read and write operations - root *node[V, T] - cmp CmpFunc[T] + mu sync.RWMutex // used to serialize read and write operations + root *node[V, T] + cmp CmpFunc[T] + config TreeConfig } // NewSearchTree returns an initialized interval search tree. // The cmp parameter is used for comparing total order of the interval key type T // when inserting or looking up an interval in the tree. // For more details on cmp, see the CmpFunc type. +// // NewSearchTree will panic if cmp is nil. func NewSearchTree[V, T any](cmp CmpFunc[T]) *SearchTree[V, T] { if cmp == nil { @@ -44,6 +67,28 @@ func NewSearchTree[V, T any](cmp CmpFunc[T]) *SearchTree[V, T] { } } +// NewSearchTreeWithOptions returns an initialized interval search tree with custom configuration options. +// The cmp parameter is used for comparing total order of the interval key type T when inserting or looking up an interval in the tree. +// The opts parameter is an optional list of TreeOptions that customize the behavior of the tree, +// such as allowing point intervals using TreeWithIntervalPoint. +// +// NewSearchTreeWithOptions will panic if cmp is nil. +func NewSearchTreeWithOptions[V, T any](cmp CmpFunc[T], opts ...TreeOption) *SearchTree[V, T] { + if cmp == nil { + panic("NewSearchTreeWithOptions: comparison function cmp cannot be nil") + } + + st := &SearchTree[V, T]{ + cmp: cmp, + } + + for _, opt := range opts { + opt(&st.config) + } + + return st +} + // Height returns the max depth of the tree. func (st *SearchTree[V, T]) Height() int { st.mu.RLock() @@ -77,6 +122,7 @@ type MultiValueSearchTree[V, T any] SearchTree[V, T] // The cmp parameter is used for comparing total order of the interval key type T // when inserting or looking up an interval in the tree. // For more details on cmp, see the CmpFunc type. +// // NewMultiValueSearchTree will panic if cmp is nil. func NewMultiValueSearchTree[V, T any](cmp CmpFunc[T]) *MultiValueSearchTree[V, T] { if cmp == nil { @@ -87,6 +133,28 @@ func NewMultiValueSearchTree[V, T any](cmp CmpFunc[T]) *MultiValueSearchTree[V, } } +// NewSearchTreeWithOptions returns an initialized multi-value interval search tree with custom configuration options. +// The cmp parameter is used for comparing total order of the interval key type T when inserting or looking up an interval in the tree. +// The opts parameter is an optional list of TreeOptions that customize the behavior of the tree, +// such as allowing point intervals using TreeWithIntervalPoint. +// +// NewMultiValueSearchTreeWithOptions will panic if cmp is nil. +func NewMultiValueSearchTreeWithOptions[V, T any](cmp CmpFunc[T], opts ...TreeOption) *MultiValueSearchTree[V, T] { + if cmp == nil { + panic("NewMultiValueSearchTreeWithOptions: comparison function cmp cannot be nil") + } + + st := &MultiValueSearchTree[V, T]{ + cmp: cmp, + } + + for _, opt := range opts { + opt(&st.config) + } + + return st +} + // Height returns the max depth of the tree. func (st *MultiValueSearchTree[V, T]) Height() int { st.mu.RLock() diff --git a/interval/search_tree_test.go b/interval/search_tree_test.go index 98d1233..7e78760 100644 --- a/interval/search_tree_test.go +++ b/interval/search_tree_test.go @@ -17,6 +17,16 @@ func TestNewSearchTree_EmptyCmp(t *testing.T) { NewSearchTree[string, int](nil) } +func TestNewSearchTreeWithOptions_EmptyCmp(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatal("NewSearchTreeWithOptions(nil): got execution without panic") + } + }() + + NewSearchTreeWithOptions[string, int](nil) +} + func TestSearchTree_Height(t *testing.T) { st := NewSearchTree[int](func(x, y int) int { return x - y }) @@ -135,6 +145,16 @@ func TestMultiValueSearchTree_NilCmpFunc(t *testing.T) { NewMultiValueSearchTree[string, int](nil) } +func TestMultiValueSearchTreeWithOptions_NilCmpFunc(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatal("NewMultiValueSearchTreeWithOptions(nil): got execution without panic") + } + }() + + NewMultiValueSearchTreeWithOptions[string, int](nil) +} + func TestMultiValueSearchTree_IsEmpty(t *testing.T) { st := NewMultiValueSearchTree[int](func(x, y int) int { return x - y })