From ba92a650fbc318df6b02f213be2437f7e9e92241 Mon Sep 17 00:00:00 2001 From: "Luke I. Wilson" Date: Fri, 19 May 2023 17:03:00 -0500 Subject: [PATCH] Wrote tests for the RSI function --- indicators.go | 10 +++++++++- indicators_test.go | 19 +++++++++++++++++++ series.go | 20 ++++++++++++++++++-- 3 files changed, 46 insertions(+), 3 deletions(-) create mode 100644 indicators_test.go diff --git a/indicators.go b/indicators.go index a840f79..ecec8e4 100644 --- a/indicators.go +++ b/indicators.go @@ -3,6 +3,10 @@ package autotrader import "math" // RSI calculates the Relative Strength Index for a given Series. Typically, the input series is the Close column of a DataFrame. Returns a Series of RSI values of the same length as the input. +// +// Traditionally, an RSI reading of 70 or above indicates an overbought condition, and a reading of 30 or below indicates an oversold condition. +// +// Typically, the RSI is calculated with a period of 14 days. func RSI(series Series, periods int) Series { // Calculate the difference between each day's close and the previous day's close. delta := series.MapReverse(func(i int, v interface{}) interface{} { @@ -19,6 +23,10 @@ func RSI(series Series, periods int) Series { avgLoss := losses.Rolling(periods).Mean() // Calculate the RSI. return avgGain.Map(func(i int, val interface{}) interface{} { - return 100 - (100 / (1 + (val.(float64) / avgLoss.Value(i).(float64)))) + loss := avgLoss.Float(i) + if loss == 0 { + return float64(100) + } + return float64(100. - 100./(1.+val.(float64)/loss)) }) } diff --git a/indicators_test.go b/indicators_test.go new file mode 100644 index 0000000..b4c6b06 --- /dev/null +++ b/indicators_test.go @@ -0,0 +1,19 @@ +package autotrader + +import ( + "testing" +) + +func TestRSI(t *testing.T) { + prices := NewDataSeriesFloat("Prices", 1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6) + rsi := RSI(prices, 14) + if rsi.Len() != 14 { + t.Errorf("RSI length is %d, expected 14", rsi.Len()) + } + if !EqualApprox(rsi.Float(0), 100) { + t.Errorf("RSI[0] is %f, expected 0", rsi.Float(0)) + } + if !EqualApprox(rsi.Float(-1), 61.02423) { + t.Errorf("RSI[-1] is %f, expected 100", rsi.Float(-1)) + } +} diff --git a/series.go b/series.go index 1c2481f..13624f2 100644 --- a/series.go +++ b/series.go @@ -362,6 +362,22 @@ func NewDataSeries(name string, vals ...any) *DataSeries { return dataSeries } +func NewDataSeriesFloat(name string, vals ...float64) *DataSeries { + anyVals := make([]any, len(vals)) + for i, v := range vals { + anyVals[i] = v + } + return NewDataSeries(name, anyVals...) +} + +func NewDataSeriesInt(name string, vals ...int) *DataSeries { + anyVals := make([]any, len(vals)) + for i, v := range vals { + anyVals[i] = v + } + return NewDataSeries(name, anyVals...) +} + // Copy returns a new DataSeries with a copy of the original data and Series name. start is an EasyIndex and count is the number of items to copy from start onward. If count is negative then all items from start to the end of the series are copied. If there are not enough items to copy then the maximum amount is returned. If there are no items to copy then an empty DataSeries is returned. // // Examples: @@ -528,7 +544,7 @@ func (s *DataSeries) Filter(f func(i int, val any) bool) Series { func (s *DataSeries) Map(f func(i int, val any) any) Series { series := s.Copy(0, -1) for i := 0; i < s.Len(); i++ { - series.SetValue(i, f(i, series.Value(i))) + series.SetValue(i, f(i, s.value(i))) } return series } @@ -537,7 +553,7 @@ func (s *DataSeries) Map(f func(i int, val any) any) Series { func (s *DataSeries) MapReverse(f func(i int, val any) any) Series { series := s.Copy(0, -1) for i := s.Len() - 1; i >= 0; i-- { - series.SetValue(i, f(i, series.Value(i))) + series.SetValue(i, f(i, s.value(i))) } return series }