Add WithValueFunc to Series to fix bug with accessing Value()

This commit is contained in:
Luke I. Wilson 2023-05-16 18:23:54 -05:00
parent 02799477a5
commit b1960b2b98
3 changed files with 218 additions and 148 deletions

309
data.go
View File

@ -18,14 +18,6 @@ import (
"golang.org/x/exp/slices"
)
// EasyIndex converts a possibly negative index i into an index for an object
// of length n, in the manner of Python's negative indexing. Non-negative
// indices are returned unchanged; negative indices count back from the end,
// so EasyIndex(-1, 5) returns 4. No bounds check is performed: the result is
// still negative when -i > n.
func EasyIndex(i, n int) int {
	if i >= 0 {
		return i
	}
	return i + n
}
type Series interface {
Signaler
@ -47,6 +39,9 @@ type Series interface {
// Statistical functions.
Rolling(period int) *RollingSeries
// WithValueFunc is used to implement other types of Series that modify values before returning them, for example by applying a function. It returns a copy of the original Series in which the new value function is used whenever a value is requested outside of the Value() method; Value() itself still returns the original value.
WithValueFunc(value func(i int) interface{}) Series
}
type Frame interface {
@ -95,11 +90,14 @@ func (s *AppliedSeries) Value(i int) interface{} {
return s.apply(EasyIndex(i, s.Len()), s.Series.Value(i))
}
// WithValueFunc returns a new AppliedSeries that keeps this series' apply
// function but wraps the result of calling WithValueFunc(value) on the
// embedded Series.
func (s *AppliedSeries) WithValueFunc(value func(i int) interface{}) Series {
	inner := s.Series.WithValueFunc(value)
	return &AppliedSeries{
		Series: inner,
		apply:  s.apply,
	}
}
// NewAppliedSeries wraps s in an AppliedSeries whose Value method passes each
// underlying value through apply.
//
// The wrapped series is obtained via s.WithValueFunc(appliedSeries.Value), so
// that accessors on the underlying series that go through its value function
// see the applied value rather than the raw one. The AppliedSeries must be
// allocated before the wrapped series is built because the method value
// appliedSeries.Value is bound to it.
func NewAppliedSeries(s Series, apply func(i int, val interface{}) interface{}) *AppliedSeries {
	appliedSeries := &AppliedSeries{apply: apply}
	appliedSeries.Series = s.WithValueFunc(appliedSeries.Value)
	return appliedSeries
}
type RollingSeries struct {
@ -107,146 +105,139 @@ type RollingSeries struct {
period int
}
// Average is an alias for Mean: it returns exactly what s.Mean() returns.
func (s *RollingSeries) Average() *AppliedSeries {
return s.Mean()
}
func (s *RollingSeries) Mean() *AppliedSeries {
return &AppliedSeries{
Series: s,
apply: func(_ int, v interface{}) interface{} {
switch v := v.(type) {
case []interface{}:
if len(v) == 0 {
return nil
}
switch v[0].(type) {
case float64:
var sum float64
for _, v := range v {
sum += v.(float64)
}
return sum / float64(len(v))
case int64:
var sum int64
for _, v := range v {
sum += v.(int64)
}
return sum / int64(len(v))
default:
return v[len(v)-1] // Do nothing
}
default:
panic(fmt.Sprintf("expected a slice of values, got %t", v))
return NewAppliedSeries(s, func(_ int, v interface{}) interface{} {
switch v := v.(type) {
case []interface{}:
if len(v) == 0 {
return nil
}
},
}
switch v[0].(type) {
case float64:
var sum float64
for _, v := range v {
sum += v.(float64)
}
return sum / float64(len(v))
case int64:
var sum int64
for _, v := range v {
sum += v.(int64)
}
return sum / int64(len(v))
default:
return v[len(v)-1] // Do nothing
}
default:
panic(fmt.Sprintf("expected a slice of values, got %t", v))
}
})
}
func (s *RollingSeries) EMA() *AppliedSeries {
return &AppliedSeries{
Series: s,
apply: func(i int, v interface{}) interface{} {
switch v := v.(type) {
case []interface{}:
if len(v) == 0 {
return nil
}
switch v[0].(type) {
case float64:
ema := v[0].(float64)
for _, v := range v[1:] {
ema += (v.(float64) - ema) * 2 / (float64(s.period) + 1)
}
return ema
case int64:
ema := v[0].(int64)
for _, v := range v[1:] {
ema += (v.(int64) - ema) * 2 / (int64(s.period) + 1)
}
return ema
default: // string, time.Time
return v[len(v)-1] // Do nothing
}
default:
panic(fmt.Sprintf("expected a slice of values, got %t", v))
return NewAppliedSeries(s, func(i int, v interface{}) interface{} {
switch v := v.(type) {
case []interface{}:
if len(v) == 0 {
return nil
}
},
}
switch v[0].(type) {
case float64:
ema := v[0].(float64)
for _, v := range v[1:] {
ema += (v.(float64) - ema) * 2 / (float64(s.period) + 1)
}
return ema
case int64:
ema := v[0].(int64)
for _, v := range v[1:] {
ema += (v.(int64) - ema) * 2 / (int64(s.period) + 1)
}
return ema
default: // string, time.Time
return v[len(v)-1] // Do nothing
}
default:
panic(fmt.Sprintf("expected a slice of values, got %t", v))
}
})
}
func (s *RollingSeries) Median() *AppliedSeries {
return &AppliedSeries{
Series: s,
apply: func(_ int, v interface{}) interface{} {
switch v := v.(type) {
case []interface{}:
if len(v) == 0 {
return nil
}
switch v[0].(type) {
case float64:
if len(v) == 0 {
return float64(0)
}
slices.SortFunc(v, func(a, b interface{}) bool {
x, y := a.(float64), b.(float64)
return x < y || (math.IsNaN(x) && !math.IsNaN(y))
})
if len(v)%2 == 0 {
return (v[len(v)/2-1].(float64) + v[len(v)/2].(float64)) / 2
}
return v[len(v)/2]
case int64:
if len(v) == 0 {
return int64(0)
}
slices.SortFunc(v, func(a, b interface{}) bool {
x, y := a.(int64), b.(int64)
return x < y
})
if len(v)%2 == 0 {
return (v[len(v)/2-1].(int64) + v[len(v)/2].(int64)) / 2
}
return v[len(v)/2]
default: // string, time.Time
return v[len(v)-1] // Do nothing
}
default:
panic(fmt.Sprintf("expected a slice of values, got %t", v))
return NewAppliedSeries(s, func(_ int, v interface{}) interface{} {
switch v := v.(type) {
case []interface{}:
if len(v) == 0 {
return nil
}
},
}
switch v[0].(type) {
case float64:
if len(v) == 0 {
return float64(0)
}
slices.SortFunc(v, func(a, b interface{}) bool {
x, y := a.(float64), b.(float64)
return x < y || (math.IsNaN(x) && !math.IsNaN(y))
})
if len(v)%2 == 0 {
return (v[len(v)/2-1].(float64) + v[len(v)/2].(float64)) / 2
}
return v[len(v)/2]
case int64:
if len(v) == 0 {
return int64(0)
}
slices.SortFunc(v, func(a, b interface{}) bool {
x, y := a.(int64), b.(int64)
return x < y
})
if len(v)%2 == 0 {
return (v[len(v)/2-1].(int64) + v[len(v)/2].(int64)) / 2
}
return v[len(v)/2]
default: // string, time.Time
return v[len(v)-1] // Do nothing
}
default:
panic(fmt.Sprintf("expected a slice of values, got %t", v))
}
})
}
func (s *RollingSeries) StdDev() *AppliedSeries {
return &AppliedSeries{
Series: s,
apply: func(i int, v interface{}) interface{} {
switch v := v.(type) {
case []interface{}:
if len(v) == 0 {
return nil
}
switch v[0].(type) {
case float64:
mean := s.Mean().Value(i).(float64) // Take the mean of the last period values for the current index
var sum float64
for _, v := range v {
sum += (v.(float64) - mean) * (v.(float64) - mean)
}
return math.Sqrt(sum / float64(len(v)))
case int64:
mean := s.Mean().Value(i).(int64)
var sum int64
for _, v := range v {
sum += (v.(int64) - mean) * (v.(int64) - mean)
}
return int64(math.Sqrt(float64(sum) / float64(len(v))))
default: // A slice of something else, just return the last value
return v[len(v)-1] // Do nothing
}
default:
panic(fmt.Sprintf("expected a slice of values, got %t", v))
return NewAppliedSeries(s, func(i int, v interface{}) interface{} {
switch v := v.(type) {
case []interface{}:
if len(v) == 0 {
return nil
}
},
}
switch v[0].(type) {
case float64:
mean := s.Mean().Value(i).(float64) // Take the mean of the last period values for the current index
var sum float64
for _, v := range v {
sum += (v.(float64) - mean) * (v.(float64) - mean)
}
return math.Sqrt(sum / float64(len(v)))
case int64:
mean := s.Mean().Value(i).(int64)
var sum int64
for _, v := range v {
sum += (v.(int64) - mean) * (v.(int64) - mean)
}
return int64(math.Sqrt(float64(sum) / float64(len(v))))
default: // A slice of something else, just return the last value
return v[len(v)-1] // Do nothing
}
default:
panic(fmt.Sprintf("expected a slice of values, got %t", v))
}
})
}
// Value returns []interface{} up to `period` long. The last item in the slice is the item at i. If i is out of bounds, nil is returned.
@ -263,6 +254,10 @@ func (s *RollingSeries) Value(i int) interface{} {
return items
}
// WithValueFunc returns a new RollingSeries with the same period that wraps
// the result of calling WithValueFunc(value) on the embedded Series.
func (s *RollingSeries) WithValueFunc(value func(i int) interface{}) Series {
	inner := s.Series.WithValueFunc(value)
	return &RollingSeries{
		Series: inner,
		period: s.period,
	}
}
// DataSeries is a Series that wraps a column of data. The data can be of the following types: float64, int64, string, or time.Time.
//
// Signals:
@ -270,7 +265,8 @@ func (s *RollingSeries) Value(i int) interface{} {
// - NameChanged(string) - when the name is changed.
type DataSeries struct {
	SignalManager
	// data is the wrapped dataframe column.
	data df.Series
	// value is the function the typed accessors (Float, Int, Str, Time) and
	// ValueRange call to fetch the item at index i. NewDataSeries sets it to
	// the series' own Value method; WithValueFunc substitutes another.
	value func(i int) interface{}
}
// Copy copies the Series from start to end (inclusive). If end is -1, it will copy to the end of the Series. If start is out of bounds, nil is returned.
@ -284,7 +280,11 @@ func (s *DataSeries) Copy(start, end int) Series {
}
_end = &end
}
return &DataSeries{SignalManager{}, s.data.Copy(df.Range{Start: &start, End: _end})}
return &DataSeries{
SignalManager: SignalManager{},
data: s.data.Copy(df.Range{Start: &start, End: _end}),
value: s.value,
}
}
func (s *DataSeries) Name() string {
@ -308,7 +308,9 @@ func (s *DataSeries) Len() int {
}
// Rolling returns a RollingSeries over s with the given period.
//
// The RollingSeries wraps s.WithValueFunc(rollingSeries.Value) rather than s
// itself, so the wrapped DataSeries uses the rolling series' Value method as
// its value function. The RollingSeries is allocated first because the method
// value rollingSeries.Value is bound to it.
func (s *DataSeries) Rolling(period int) *RollingSeries {
	rollingSeries := &RollingSeries{period: period}
	rollingSeries.Series = s.WithValueFunc(rollingSeries.Value)
	return rollingSeries
}
func (s *DataSeries) Push(value interface{}) Series {
@ -341,7 +343,7 @@ func (s *DataSeries) ValueRange(start, end int) []interface{} {
items := make([]interface{}, end-start+1)
for i := start; i <= end; i++ {
items[i-start] = s.Value(i)
items[i-start] = s.value(i)
}
return items
}
@ -354,7 +356,7 @@ func (s *DataSeries) Values() []interface{} {
}
func (s *DataSeries) Float(i int) float64 {
val := s.Value(i)
val := s.value(i)
if val == nil {
return 0
}
@ -367,7 +369,7 @@ func (s *DataSeries) Float(i int) float64 {
}
func (s *DataSeries) Int(i int) int64 {
val := s.Value(i)
val := s.value(i)
if val == nil {
return 0
}
@ -380,7 +382,7 @@ func (s *DataSeries) Int(i int) int64 {
}
func (s *DataSeries) Str(i int) string {
val := s.Value(i)
val := s.value(i)
if val == nil {
return ""
}
@ -393,7 +395,7 @@ func (s *DataSeries) Str(i int) string {
}
func (s *DataSeries) Time(i int) time.Time {
val := s.Value(i)
val := s.value(i)
if val == nil {
return time.Time{}
}
@ -405,8 +407,21 @@ func (s *DataSeries) Time(i int) time.Time {
}
}
// WithValueFunc returns a new DataSeries that shares this series' data and
// carries a copy of its SignalManager, but uses value as its value function.
// The receiver is not modified.
func (s *DataSeries) WithValueFunc(value func(i int) interface{}) Series {
	derived := new(DataSeries)
	derived.SignalManager = s.SignalManager
	derived.data = s.data
	derived.value = value
	return derived
}
// NewDataSeries wraps data in a DataSeries with a fresh SignalManager.
//
// The value function is initialised to the new series' own Value method, so
// the accessors that go through it behave like direct calls to Value until
// WithValueFunc substitutes another function. The assignment happens after
// allocation because the method value is bound to the new DataSeries.
func NewDataSeries(data df.Series) *DataSeries {
	dataSeries := &DataSeries{
		SignalManager: SignalManager{},
		data:          data,
	}
	dataSeries.value = dataSeries.Value
	return dataSeries
}
type DataFrame struct {

View File

@ -3,6 +3,8 @@ package autotrader
import (
"testing"
"time"
"github.com/rocketlaunchr/dataframe-go"
)
func newTestingDataFrame() *DataFrame {
@ -13,6 +15,36 @@ func newTestingDataFrame() *DataFrame {
return data
}
// TestAppliedSeries checks the 5-period rolling average and exponential
// moving average of the series 1..10 against precomputed values.
func TestAppliedSeries(t *testing.T) {
	series := NewDataSeries(dataframe.NewSeriesFloat64("test", nil, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))

	// verify builds an applied series and compares every row with expected.
	//
	// It deliberately reads through Float rather than Value: Value calls
	// AppliedSeries.Value directly, while Float is the accessor most callers
	// use and the one that has to route through the DataSeries value function.
	verify := func(build func() Series, expected []float64) {
		applied := build()
		if applied.Len() != 10 {
			t.Fatalf("Expected 10 rows, got %d", applied.Len())
		}
		for i := 0; i < 10; i++ {
			val := applied.Float(i)
			if !EqualApprox(val, expected[i]) {
				t.Errorf("(%d)\tExpected %f, got %v", i, expected[i], val)
			}
		}
	}

	// 5-period simple moving average, converted to the Series interface.
	verify(
		func() Series { return series.Rolling(5).Average() },
		[]float64{1, 1.5, 2, 2.5, 3, 4, 5, 6, 7, 8},
	)

	// 5-period exponential moving average.
	verify(
		func() Series { return series.Rolling(5).EMA() },
		[]float64{1, 1.3333333333333333, 1.8888888888888888, 2.5925925925925926, 3.3950617283950617, 4.395061728395062, 5.395061728395062, 6.395061728395062, 7.395061728395062, 8.395061728395062},
	)
}
func TestDataSeries(t *testing.T) {
data := newTestingDataFrame()

View File

@ -1,6 +1,29 @@
package autotrader
import "golang.org/x/exp/constraints"
import (
"golang.org/x/exp/constraints"
)
const floatComparisonTolerance = float64(1e-6)
// EasyIndex maps an index i onto an object of length n, allowing negative
// values to count back from the end as in Python: EasyIndex(-1, 5) returns 4.
// Non-negative indices pass through unchanged. The result is not bounds
// checked and remains negative when -i > n.
func EasyIndex(i, n int) int {
	idx := i
	if idx < 0 {
		idx += n
	}
	return idx
}
// EqualApprox reports whether a and b differ by strictly less than
// floatComparisonTolerance.
func EqualApprox(a, b float64) bool {
	diff := Abs(a - b)
	return diff < floatComparisonTolerance
}
// Abs returns a when a is not less than zero and -a otherwise, for any
// integer or floating-point type.
func Abs[T constraints.Integer | constraints.Float](a T) T {
	var zero T
	if a < zero {
		a = -a
	}
	return a
}
func Min[T constraints.Ordered](a, b T) T {
if a < b {