diff --git a/data.go b/data.go
index 4220051..20dea8b 100644
--- a/data.go
+++ b/data.go
@@ -18,14 +18,6 @@ import (
 	"golang.org/x/exp/slices"
 )
 
-// EasyIndex returns an index to the `n` -length object that allows for negative indexing. For example, EasyIndex(-1, 5) returns 4. This is similar to Python's negative indexing. The return value may be less than zero if (-i) > n.
-func EasyIndex(i, n int) int {
-	if i < 0 {
-		return n + i
-	}
-	return i
-}
-
 type Series interface {
 	Signaler
 
@@ -47,6 +39,9 @@ type Series interface {
 
 	// Statistical functions.
 	Rolling(period int) *RollingSeries
+
+	// WithValueFunc is used to implement other types of Series that may modify the values by applying a function before returning them, for example. This returns a Series that is a copy of the original with the new value function used whenever a value is requested outside of the Value() method, which will still return the original value.
+	WithValueFunc(value func(i int) interface{}) Series
 }
 
 type Frame interface {
@@ -95,11 +90,15 @@ func (s *AppliedSeries) Value(i int) interface{} {
 	return s.apply(EasyIndex(i, s.Len()), s.Series.Value(i))
 }
 
+// WithValueFunc returns a new AppliedSeries that keeps the same apply function and wraps the result of calling WithValueFunc(value) on the underlying Series.
+func (s *AppliedSeries) WithValueFunc(value func(i int) interface{}) Series {
+	return &AppliedSeries{Series: s.Series.WithValueFunc(value), apply: s.apply}
+}
+
 func NewAppliedSeries(s Series, apply func(i int, val interface{}) interface{}) *AppliedSeries {
-	return &AppliedSeries{
-		Series: s,
-		apply:  apply,
-	}
+	appliedSeries := &AppliedSeries{apply: apply}
+	appliedSeries.Series = s.WithValueFunc(appliedSeries.Value)
+	return appliedSeries
 }
 
 type RollingSeries struct {
@@ -107,146 +106,139 @@ type RollingSeries struct {
 	period int
 }
 
+// Average is an alias for Mean.
+func (s *RollingSeries) Average() *AppliedSeries {
+	return s.Mean()
+}
+
 func (s *RollingSeries) Mean() *AppliedSeries {
-	return &AppliedSeries{
-		Series: s,
-		apply: func(_ int, v interface{}) interface{} {
-			switch v := v.(type) {
-			case []interface{}:
-				if len(v) == 0 {
-					return nil
-				}
-				switch v[0].(type) {
-				case float64:
-					var sum float64
-					for _, v := range v {
-						sum += v.(float64)
-					}
-					return sum / float64(len(v))
-				case int64:
-					var sum int64
-					for _, v := range v {
-						sum += v.(int64)
-					}
-					return sum / int64(len(v))
-				default:
-					return v[len(v)-1] // Do nothing
-				}
-			default:
-				panic(fmt.Sprintf("expected a slice of values, got %t", v))
+	return NewAppliedSeries(s, func(_ int, v interface{}) interface{} {
+		switch v := v.(type) {
+		case []interface{}:
+			if len(v) == 0 {
+				return nil
 			}
-		},
-	}
+			switch v[0].(type) {
+			case float64:
+				var sum float64
+				for _, v := range v {
+					sum += v.(float64)
+				}
+				return sum / float64(len(v))
+			case int64:
+				var sum int64
+				for _, v := range v {
+					sum += v.(int64)
+				}
+				return sum / int64(len(v))
+			default:
+				return v[len(v)-1] // Do nothing
+			}
+		default:
+			panic(fmt.Sprintf("expected a slice of values, got %T", v))
+		}
+	})
 }
 
 func (s *RollingSeries) EMA() *AppliedSeries {
-	return &AppliedSeries{
-		Series: s,
-		apply: func(i int, v interface{}) interface{} {
-			switch v := v.(type) {
-			case []interface{}:
-				if len(v) == 0 {
-					return nil
-				}
-				switch v[0].(type) {
-				case float64:
-					ema := v[0].(float64)
-					for _, v := range v[1:] {
-						ema += (v.(float64) - ema) * 2 / (float64(s.period) + 1)
-					}
-					return ema
-				case int64:
-					ema := v[0].(int64)
-					for _, v := range v[1:] {
-						ema += (v.(int64) - ema) * 2 / (int64(s.period) + 1)
-					}
-					return ema
-				default: // string, time.Time
-					return v[len(v)-1] // Do nothing
-				}
-			default:
-				panic(fmt.Sprintf("expected a slice of values, got %t", v))
+	return NewAppliedSeries(s, func(i int, v interface{}) interface{} {
+		switch v := v.(type) {
+		case []interface{}:
+			if len(v) == 0 {
+				return nil
 			}
-		},
-	}
+			switch v[0].(type) {
+			case float64:
+				ema := v[0].(float64)
+				for _, v := range v[1:] {
+					ema += (v.(float64) - ema) * 2 / (float64(s.period) + 1)
+				}
+				return ema
+			case int64:
+				ema := v[0].(int64)
+				for _, v := range v[1:] {
+					ema += (v.(int64) - ema) * 2 / (int64(s.period) + 1)
+				}
+				return ema
+			default: // string, time.Time
+				return v[len(v)-1] // Do nothing
+			}
+		default:
+			panic(fmt.Sprintf("expected a slice of values, got %T", v))
+		}
+	})
 }
 
 func (s *RollingSeries) Median() *AppliedSeries {
-	return &AppliedSeries{
-		Series: s,
-		apply: func(_ int, v interface{}) interface{} {
-			switch v := v.(type) {
-			case []interface{}:
-				if len(v) == 0 {
-					return nil
-				}
-				switch v[0].(type) {
-				case float64:
-					if len(v) == 0 {
-						return float64(0)
-					}
-					slices.SortFunc(v, func(a, b interface{}) bool {
-						x, y := a.(float64), b.(float64)
-						return x < y || (math.IsNaN(x) && !math.IsNaN(y))
-					})
-					if len(v)%2 == 0 {
-						return (v[len(v)/2-1].(float64) + v[len(v)/2].(float64)) / 2
-					}
-					return v[len(v)/2]
-				case int64:
-					if len(v) == 0 {
-						return int64(0)
-					}
-					slices.SortFunc(v, func(a, b interface{}) bool {
-						x, y := a.(int64), b.(int64)
-						return x < y
-					})
-					if len(v)%2 == 0 {
-						return (v[len(v)/2-1].(int64) + v[len(v)/2].(int64)) / 2
-					}
-					return v[len(v)/2]
-				default: // string, time.Time
-					return v[len(v)-1] // Do nothing
-				}
-			default:
-				panic(fmt.Sprintf("expected a slice of values, got %t", v))
+	return NewAppliedSeries(s, func(_ int, v interface{}) interface{} {
+		switch v := v.(type) {
+		case []interface{}:
+			if len(v) == 0 {
+				return nil
 			}
-		},
-	}
+			switch v[0].(type) {
+			case float64:
+				if len(v) == 0 {
+					return float64(0)
+				}
+				slices.SortFunc(v, func(a, b interface{}) bool {
+					x, y := a.(float64), b.(float64)
+					return x < y || (math.IsNaN(x) && !math.IsNaN(y))
+				})
+				if len(v)%2 == 0 {
+					return (v[len(v)/2-1].(float64) + v[len(v)/2].(float64)) / 2
+				}
+				return v[len(v)/2]
+			case int64:
+				if len(v) == 0 {
+					return int64(0)
+				}
+				slices.SortFunc(v, func(a, b interface{}) bool {
+					x, y := a.(int64), b.(int64)
+					return x < y
+				})
+				if len(v)%2 == 0 {
+					return (v[len(v)/2-1].(int64) + v[len(v)/2].(int64)) / 2
+				}
+				return v[len(v)/2]
+			default: // string, time.Time
+				return v[len(v)-1] // Do nothing
+			}
+		default:
+			panic(fmt.Sprintf("expected a slice of values, got %T", v))
+		}
+	})
 }
 
 func (s *RollingSeries) StdDev() *AppliedSeries {
-	return &AppliedSeries{
-		Series: s,
-		apply: func(i int, v interface{}) interface{} {
-			switch v := v.(type) {
-			case []interface{}:
-				if len(v) == 0 {
-					return nil
-				}
-				switch v[0].(type) {
-				case float64:
-					mean := s.Mean().Value(i).(float64) // Take the mean of the last period values for the current index
-					var sum float64
-					for _, v := range v {
-						sum += (v.(float64) - mean) * (v.(float64) - mean)
-					}
-					return math.Sqrt(sum / float64(len(v)))
-				case int64:
-					mean := s.Mean().Value(i).(int64)
-					var sum int64
-					for _, v := range v {
-						sum += (v.(int64) - mean) * (v.(int64) - mean)
-					}
-					return int64(math.Sqrt(float64(sum) / float64(len(v))))
-				default: // A slice of something else, just return the last value
-					return v[len(v)-1] // Do nothing
-				}
-			default:
-				panic(fmt.Sprintf("expected a slice of values, got %t", v))
+	return NewAppliedSeries(s, func(i int, v interface{}) interface{} {
+		switch v := v.(type) {
+		case []interface{}:
+			if len(v) == 0 {
+				return nil
 			}
-		},
-	}
+			switch v[0].(type) {
+			case float64:
+				mean := s.Mean().Value(i).(float64) // Take the mean of the last period values for the current index
+				var sum float64
+				for _, v := range v {
+					sum += (v.(float64) - mean) * (v.(float64) - mean)
+				}
+				return math.Sqrt(sum / float64(len(v)))
+			case int64:
+				mean := s.Mean().Value(i).(int64)
+				var sum int64
+				for _, v := range v {
+					sum += (v.(int64) - mean) * (v.(int64) - mean)
+				}
+				return int64(math.Sqrt(float64(sum) / float64(len(v))))
+			default: // A slice of something else, just return the last value
+				return v[len(v)-1] // Do nothing
+			}
+		default:
+			panic(fmt.Sprintf("expected a slice of values, got %T", v))
+		}
+	})
 }
 
 // Value returns []interface{} up to `period` long. The last item in the slice is the item at i. If i is out of bounds, nil is returned.
@@ -263,6 +255,11 @@ func (s *RollingSeries) Value(i int) interface{} {
 	return items
 }
 
+// WithValueFunc returns a new RollingSeries with the same period that wraps the result of calling WithValueFunc(value) on the underlying Series.
+func (s *RollingSeries) WithValueFunc(value func(i int) interface{}) Series {
+	return &RollingSeries{Series: s.Series.WithValueFunc(value), period: s.period}
+}
+
 // DataSeries is a Series that wraps a column of data. The data can be of the following types: float64, int64, string, or time.Time.
 //
 // Signals:
@@ -270,7 +267,8 @@ func (s *RollingSeries) Value(i int) interface{} {
 // - NameChanged(string) - when the name is changed.
 type DataSeries struct {
 	SignalManager
-	data df.Series
+	data  df.Series
+	value func(i int) interface{}
 }
 
 // Copy copies the Series from start to end (inclusive). If end is -1, it will copy to the end of the Series. If start is out of bounds, nil is returned.
@@ -284,7 +282,11 @@ func (s *DataSeries) Copy(start, end int) Series {
 		}
 		_end = &end
 	}
-	return &DataSeries{SignalManager{}, s.data.Copy(df.Range{Start: &start, End: _end})}
+	return &DataSeries{
+		SignalManager: SignalManager{},
+		data:          s.data.Copy(df.Range{Start: &start, End: _end}),
+		value:         s.value,
+	}
 }
 
 func (s *DataSeries) Name() string {
@@ -308,7 +310,9 @@ func (s *DataSeries) Len() int {
 }
 
 func (s *DataSeries) Rolling(period int) *RollingSeries {
-	return &RollingSeries{s, period}
+	rollingSeries := &RollingSeries{period: period}
+	rollingSeries.Series = s.WithValueFunc(rollingSeries.Value)
+	return rollingSeries
 }
 
 func (s *DataSeries) Push(value interface{}) Series {
@@ -341,7 +345,7 @@ func (s *DataSeries) ValueRange(start, end int) []interface{} {
 
 	items := make([]interface{}, end-start+1)
 	for i := start; i <= end; i++ {
-		items[i-start] = s.Value(i)
+		items[i-start] = s.value(i)
 	}
 	return items
 }
@@ -354,7 +358,7 @@ func (s *DataSeries) Values() []interface{} {
 }
 
 func (s *DataSeries) Float(i int) float64 {
-	val := s.Value(i)
+	val := s.value(i)
 	if val == nil {
 		return 0
 	}
@@ -367,7 +371,7 @@ func (s *DataSeries) Float(i int) float64 {
 }
 
 func (s *DataSeries) Int(i int) int64 {
-	val := s.Value(i)
+	val := s.value(i)
 	if val == nil {
 		return 0
 	}
@@ -380,7 +384,7 @@ func (s *DataSeries) Int(i int) int64 {
 }
 
 func (s *DataSeries) Str(i int) string {
-	val := s.Value(i)
+	val := s.value(i)
 	if val == nil {
 		return ""
 	}
@@ -393,7 +397,7 @@ func (s *DataSeries) Str(i int) string {
 }
 
 func (s *DataSeries) Time(i int) time.Time {
-	val := s.Value(i)
+	val := s.value(i)
 	if val == nil {
 		return time.Time{}
 	}
@@ -405,8 +409,22 @@ func (s *DataSeries) Time(i int) time.Time {
 	}
 }
 
+// WithValueFunc returns a new DataSeries that reuses s's data and SignalManager and routes ValueRange, Float, Int, Str and Time through value.
+func (s *DataSeries) WithValueFunc(value func(i int) interface{}) Series {
+	return &DataSeries{
+		SignalManager: s.SignalManager,
+		data:          s.data,
+		value:         value,
+	}
+}
+
 func NewDataSeries(data df.Series) *DataSeries {
-	return &DataSeries{SignalManager{}, data}
+	dataSeries := &DataSeries{
+		SignalManager: SignalManager{},
+		data:          data,
+	}
+	dataSeries.value = dataSeries.Value
+	return dataSeries
 }
 
 type DataFrame struct {
diff --git a/data_test.go b/data_test.go
index 0a96fce..d02499a 100644
--- a/data_test.go
+++ b/data_test.go
@@ -3,6 +3,8 @@ package autotrader
 import (
 	"testing"
 	"time"
+
+	"github.com/rocketlaunchr/dataframe-go"
 )
 
 func newTestingDataFrame() *DataFrame {
@@ -13,6 +15,36 @@ func newTestingDataFrame() *DataFrame {
 	return data
 }
 
+func TestAppliedSeries(t *testing.T) {
+	// Test rolling average.
+	series := NewDataSeries(dataframe.NewSeriesFloat64("test", nil, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+	sma5Expected := []float64{1, 1.5, 2, 2.5, 3, 4, 5, 6, 7, 8}
+	sma5 := (Series)(series.Rolling(5).Average()) // Take the 5 period moving average and cast it to Series.
+	if sma5.Len() != 10 {
+		t.Fatalf("Expected 10 rows, got %d", sma5.Len())
+	}
+	for i := 0; i < 10; i++ {
+		// Calling Float instead of Value is very important. Value will call the AppliedSeries.Value method
+		// while Float calls Series.Float which is what most people will use and is the most likely to be
+		// problematic as it is supposed to route through the DataSeries.value method.
+		if val := sma5.Float(i); !EqualApprox(val, sma5Expected[i]) {
+			t.Errorf("(%d)\tExpected %f, got %v", i, sma5Expected[i], val)
+		}
+	}
+
+	ema5Expected := []float64{1, 1.3333333333333333, 1.8888888888888888, 2.5925925925925926, 3.3950617283950617, 4.395061728395062, 5.395061728395062, 6.395061728395062, 7.395061728395062, 8.395061728395062}
+	ema5 := (Series)(series.Rolling(5).EMA()) // Take the 5 period exponential moving average.
+	if ema5.Len() != 10 {
+		t.Fatalf("Expected 10 rows, got %d", ema5.Len())
+	}
+	for i := 0; i < 10; i++ {
+		if val := ema5.Float(i); !EqualApprox(val, ema5Expected[i]) {
+			t.Errorf("(%d)\tExpected %f, got %v", i, ema5Expected[i], val)
+		}
+	}
+}
+
 func TestDataSeries(t *testing.T) {
 	data := newTestingDataFrame()
 
diff --git a/utils.go b/utils.go
index e47214e..6330f83 100644
--- a/utils.go
+++ b/utils.go
@@ -1,6 +1,32 @@
 package autotrader
 
-import "golang.org/x/exp/constraints"
+import (
+	"golang.org/x/exp/constraints"
+)
+
+// floatComparisonTolerance is the maximum absolute difference at which EqualApprox still treats two floats as equal (exclusive).
+const floatComparisonTolerance = float64(1e-6)
+
+// EasyIndex returns an index to the `n`-length object that allows for negative indexing. For example, EasyIndex(-1, 5) returns 4. This is similar to Python's negative indexing. The return value may be less than zero if (-i) > n.
+func EasyIndex(i, n int) int {
+	if i < 0 {
+		return n + i
+	}
+	return i
+}
+
+// EqualApprox reports whether a and b differ by less than floatComparisonTolerance.
+func EqualApprox(a, b float64) bool {
+	return Abs(a-b) < floatComparisonTolerance
+}
+
+// Abs returns a if a is not less than zero, and -a otherwise.
+func Abs[T constraints.Integer | constraints.Float](a T) T {
+	if a < T(0) {
+		return -a
+	}
+	return a
+}
 
 func Min[T constraints.Ordered](a, b T) T {
 	if a < b {