Add WithValueFunc to Series to fix bug with accessing Value()

This commit is contained in:
Luke I. Wilson 2023-05-16 18:23:54 -05:00
parent 02799477a5
commit b1960b2b98
3 changed files with 218 additions and 148 deletions

95
data.go
View File

@ -18,14 +18,6 @@ import (
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
) )
// EasyIndex returns an index to the `n` -length object that allows for negative indexing. For example, EasyIndex(-1, 5) returns 4. This is similar to Python's negative indexing. The return value may be less than zero if (-i) > n.
func EasyIndex(i, n int) int {
if i < 0 {
return n + i
}
return i
}
type Series interface { type Series interface {
Signaler Signaler
@ -47,6 +39,9 @@ type Series interface {
// Statistical functions. // Statistical functions.
Rolling(period int) *RollingSeries Rolling(period int) *RollingSeries
// WithValueFunc is used to implement other types of Series that may modify the values by applying a function before returning them, for example. This returns a Series that is a copy of the original with the new value function used whenever a value is requested outside of the Value() method, which will still return the original value.
WithValueFunc(value func(i int) interface{}) Series
} }
type Frame interface { type Frame interface {
@ -95,11 +90,14 @@ func (s *AppliedSeries) Value(i int) interface{} {
return s.apply(EasyIndex(i, s.Len()), s.Series.Value(i)) return s.apply(EasyIndex(i, s.Len()), s.Series.Value(i))
} }
func NewAppliedSeries(s Series, apply func(i int, val interface{}) interface{}) *AppliedSeries { func (s *AppliedSeries) WithValueFunc(value func(i int) interface{}) Series {
return &AppliedSeries{ return &AppliedSeries{Series: s.Series.WithValueFunc(value), apply: s.apply}
Series: s,
apply: apply,
} }
func NewAppliedSeries(s Series, apply func(i int, val interface{}) interface{}) *AppliedSeries {
appliedSeries := &AppliedSeries{apply: apply}
appliedSeries.Series = s.WithValueFunc(appliedSeries.Value)
return appliedSeries
} }
type RollingSeries struct { type RollingSeries struct {
@ -107,10 +105,13 @@ type RollingSeries struct {
period int period int
} }
// Average is an alias for Mean.
func (s *RollingSeries) Average() *AppliedSeries {
return s.Mean()
}
func (s *RollingSeries) Mean() *AppliedSeries { func (s *RollingSeries) Mean() *AppliedSeries {
return &AppliedSeries{ return NewAppliedSeries(s, func(_ int, v interface{}) interface{} {
Series: s,
apply: func(_ int, v interface{}) interface{} {
switch v := v.(type) { switch v := v.(type) {
case []interface{}: case []interface{}:
if len(v) == 0 { if len(v) == 0 {
@ -135,14 +136,11 @@ func (s *RollingSeries) Mean() *AppliedSeries {
default: default:
panic(fmt.Sprintf("expected a slice of values, got %t", v)) panic(fmt.Sprintf("expected a slice of values, got %t", v))
} }
}, })
}
} }
func (s *RollingSeries) EMA() *AppliedSeries { func (s *RollingSeries) EMA() *AppliedSeries {
return &AppliedSeries{ return NewAppliedSeries(s, func(i int, v interface{}) interface{} {
Series: s,
apply: func(i int, v interface{}) interface{} {
switch v := v.(type) { switch v := v.(type) {
case []interface{}: case []interface{}:
if len(v) == 0 { if len(v) == 0 {
@ -167,14 +165,11 @@ func (s *RollingSeries) EMA() *AppliedSeries {
default: default:
panic(fmt.Sprintf("expected a slice of values, got %t", v)) panic(fmt.Sprintf("expected a slice of values, got %t", v))
} }
}, })
}
} }
func (s *RollingSeries) Median() *AppliedSeries { func (s *RollingSeries) Median() *AppliedSeries {
return &AppliedSeries{ return NewAppliedSeries(s, func(_ int, v interface{}) interface{} {
Series: s,
apply: func(_ int, v interface{}) interface{} {
switch v := v.(type) { switch v := v.(type) {
case []interface{}: case []interface{}:
if len(v) == 0 { if len(v) == 0 {
@ -211,14 +206,11 @@ func (s *RollingSeries) Median() *AppliedSeries {
default: default:
panic(fmt.Sprintf("expected a slice of values, got %t", v)) panic(fmt.Sprintf("expected a slice of values, got %t", v))
} }
}, })
}
} }
func (s *RollingSeries) StdDev() *AppliedSeries { func (s *RollingSeries) StdDev() *AppliedSeries {
return &AppliedSeries{ return NewAppliedSeries(s, func(i int, v interface{}) interface{} {
Series: s,
apply: func(i int, v interface{}) interface{} {
switch v := v.(type) { switch v := v.(type) {
case []interface{}: case []interface{}:
if len(v) == 0 { if len(v) == 0 {
@ -245,8 +237,7 @@ func (s *RollingSeries) StdDev() *AppliedSeries {
default: default:
panic(fmt.Sprintf("expected a slice of values, got %t", v)) panic(fmt.Sprintf("expected a slice of values, got %t", v))
} }
}, })
}
} }
// Value returns []interface{} up to `period` long. The last item in the slice is the item at i. If i is out of bounds, nil is returned. // Value returns []interface{} up to `period` long. The last item in the slice is the item at i. If i is out of bounds, nil is returned.
@ -263,6 +254,10 @@ func (s *RollingSeries) Value(i int) interface{} {
return items return items
} }
func (s *RollingSeries) WithValueFunc(value func(i int) interface{}) Series {
return &RollingSeries{Series: s.Series.WithValueFunc(value), period: s.period}
}
// DataSeries is a Series that wraps a column of data. The data can be of the following types: float64, int64, string, or time.Time. // DataSeries is a Series that wraps a column of data. The data can be of the following types: float64, int64, string, or time.Time.
// //
// Signals: // Signals:
@ -271,6 +266,7 @@ func (s *RollingSeries) Value(i int) interface{} {
type DataSeries struct { type DataSeries struct {
SignalManager SignalManager
data df.Series data df.Series
value func(i int) interface{}
} }
// Copy copies the Series from start to end (inclusive). If end is -1, it will copy to the end of the Series. If start is out of bounds, nil is returned. // Copy copies the Series from start to end (inclusive). If end is -1, it will copy to the end of the Series. If start is out of bounds, nil is returned.
@ -284,7 +280,11 @@ func (s *DataSeries) Copy(start, end int) Series {
} }
_end = &end _end = &end
} }
return &DataSeries{SignalManager{}, s.data.Copy(df.Range{Start: &start, End: _end})} return &DataSeries{
SignalManager: SignalManager{},
data: s.data.Copy(df.Range{Start: &start, End: _end}),
value: s.value,
}
} }
func (s *DataSeries) Name() string { func (s *DataSeries) Name() string {
@ -308,7 +308,9 @@ func (s *DataSeries) Len() int {
} }
func (s *DataSeries) Rolling(period int) *RollingSeries { func (s *DataSeries) Rolling(period int) *RollingSeries {
return &RollingSeries{s, period} rollingSeries := &RollingSeries{period: period}
rollingSeries.Series = s.WithValueFunc(rollingSeries.Value)
return rollingSeries
} }
func (s *DataSeries) Push(value interface{}) Series { func (s *DataSeries) Push(value interface{}) Series {
@ -341,7 +343,7 @@ func (s *DataSeries) ValueRange(start, end int) []interface{} {
items := make([]interface{}, end-start+1) items := make([]interface{}, end-start+1)
for i := start; i <= end; i++ { for i := start; i <= end; i++ {
items[i-start] = s.Value(i) items[i-start] = s.value(i)
} }
return items return items
} }
@ -354,7 +356,7 @@ func (s *DataSeries) Values() []interface{} {
} }
func (s *DataSeries) Float(i int) float64 { func (s *DataSeries) Float(i int) float64 {
val := s.Value(i) val := s.value(i)
if val == nil { if val == nil {
return 0 return 0
} }
@ -367,7 +369,7 @@ func (s *DataSeries) Float(i int) float64 {
} }
func (s *DataSeries) Int(i int) int64 { func (s *DataSeries) Int(i int) int64 {
val := s.Value(i) val := s.value(i)
if val == nil { if val == nil {
return 0 return 0
} }
@ -380,7 +382,7 @@ func (s *DataSeries) Int(i int) int64 {
} }
func (s *DataSeries) Str(i int) string { func (s *DataSeries) Str(i int) string {
val := s.Value(i) val := s.value(i)
if val == nil { if val == nil {
return "" return ""
} }
@ -393,7 +395,7 @@ func (s *DataSeries) Str(i int) string {
} }
func (s *DataSeries) Time(i int) time.Time { func (s *DataSeries) Time(i int) time.Time {
val := s.Value(i) val := s.value(i)
if val == nil { if val == nil {
return time.Time{} return time.Time{}
} }
@ -405,8 +407,21 @@ func (s *DataSeries) Time(i int) time.Time {
} }
} }
func (s *DataSeries) WithValueFunc(value func(i int) interface{}) Series {
return &DataSeries{
SignalManager: s.SignalManager,
data: s.data,
value: value,
}
}
func NewDataSeries(data df.Series) *DataSeries { func NewDataSeries(data df.Series) *DataSeries {
return &DataSeries{SignalManager{}, data} dataSeries := &DataSeries{
SignalManager: SignalManager{},
data: data,
}
dataSeries.value = dataSeries.Value
return dataSeries
} }
type DataFrame struct { type DataFrame struct {

View File

@ -3,6 +3,8 @@ package autotrader
import ( import (
"testing" "testing"
"time" "time"
"github.com/rocketlaunchr/dataframe-go"
) )
func newTestingDataFrame() *DataFrame { func newTestingDataFrame() *DataFrame {
@ -13,6 +15,36 @@ func newTestingDataFrame() *DataFrame {
return data return data
} }
func TestAppliedSeries(t *testing.T) {
// Test rolling average.
series := NewDataSeries(dataframe.NewSeriesFloat64("test", nil, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
sma5Expected := []float64{1, 1.5, 2, 2.5, 3, 4, 5, 6, 7, 8}
sma5 := (Series)(series.Rolling(5).Average()) // Take the 5 period moving average and cast it to Series.
if sma5.Len() != 10 {
t.Fatalf("Expected 10 rows, got %d", sma5.Len())
}
for i := 0; i < 10; i++ {
// Calling Float instead of Value is very important. Value will call the AppliedSeries.Value method
// while Float calls Series.Float which is what most people will use and is the most likely to be
// problematic as it is supposed to route through the DataSeries.value method.
if val := sma5.Float(i); !EqualApprox(val, sma5Expected[i]) {
t.Errorf("(%d)\tExpected %f, got %v", i, sma5Expected[i], val)
}
}
ema5Expected := []float64{1, 1.3333333333333333, 1.8888888888888888, 2.5925925925925926, 3.3950617283950617, 4.395061728395062, 5.395061728395062, 6.395061728395062, 7.395061728395062, 8.395061728395062}
ema5 := (Series)(series.Rolling(5).EMA()) // Take the 5 period exponential moving average.
if ema5.Len() != 10 {
t.Fatalf("Expected 10 rows, got %d", ema5.Len())
}
for i := 0; i < 10; i++ {
if val := ema5.Float(i); !EqualApprox(val, ema5Expected[i]) {
t.Errorf("(%d)\tExpected %f, got %v", i, ema5Expected[i], val)
}
}
}
func TestDataSeries(t *testing.T) { func TestDataSeries(t *testing.T) {
data := newTestingDataFrame() data := newTestingDataFrame()

View File

@ -1,6 +1,29 @@
package autotrader package autotrader
import "golang.org/x/exp/constraints" import (
"golang.org/x/exp/constraints"
)
const floatComparisonTolerance = float64(1e-6)
// EasyIndex returns an index to the `n` -length object that allows for negative indexing. For example, EasyIndex(-1, 5) returns 4. This is similar to Python's negative indexing. The return value may be less than zero if (-i) > n.
func EasyIndex(i, n int) int {
if i < 0 {
return n + i
}
return i
}
func EqualApprox(a, b float64) bool {
return Abs(a-b) < floatComparisonTolerance
}
func Abs[T constraints.Integer | constraints.Float](a T) T {
if a < T(0) {
return -a
}
return a
}
func Min[T constraints.Ordered](a, b T) T { func Min[T constraints.Ordered](a, b T) T {
if a < b { if a < b {