Wrote tests for the RSI function

This commit is contained in:
Luke I. Wilson 2023-05-19 17:03:00 -05:00
parent d851061d1f
commit ba92a650fb
3 changed files with 46 additions and 3 deletions

View File

@ -3,6 +3,10 @@ package autotrader
import "math"
// RSI calculates the Relative Strength Index for a given Series. Typically, the input series is the Close column of a DataFrame. Returns a Series of RSI values of the same length as the input.
//
// Traditionally, an RSI reading of 70 or above indicates an overbought condition, and a reading of 30 or below indicates an oversold condition.
//
// Typically, the RSI is calculated with a period of 14 days.
func RSI(series Series, periods int) Series {
// Calculate the difference between each day's close and the previous day's close.
delta := series.MapReverse(func(i int, v interface{}) interface{} {
@ -19,6 +23,10 @@ func RSI(series Series, periods int) Series {
avgLoss := losses.Rolling(periods).Mean()
// Calculate the RSI.
return avgGain.Map(func(i int, val interface{}) interface{} {
return 100 - (100 / (1 + (val.(float64) / avgLoss.Value(i).(float64))))
loss := avgLoss.Float(i)
if loss == 0 {
return float64(100)
}
return float64(100. - 100./(1.+val.(float64)/loss))
})
}

19
indicators_test.go Normal file
View File

@ -0,0 +1,19 @@
package autotrader
import (
"testing"
)
func TestRSI(t *testing.T) {
prices := NewDataSeriesFloat("Prices", 1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6)
rsi := RSI(prices, 14)
if rsi.Len() != 14 {
t.Errorf("RSI length is %d, expected 14", rsi.Len())
}
if !EqualApprox(rsi.Float(0), 100) {
t.Errorf("RSI[0] is %f, expected 0", rsi.Float(0))
}
if !EqualApprox(rsi.Float(-1), 61.02423) {
t.Errorf("RSI[-1] is %f, expected 100", rsi.Float(-1))
}
}

View File

@ -362,6 +362,22 @@ func NewDataSeries(name string, vals ...any) *DataSeries {
return dataSeries
}
func NewDataSeriesFloat(name string, vals ...float64) *DataSeries {
anyVals := make([]any, len(vals))
for i, v := range vals {
anyVals[i] = v
}
return NewDataSeries(name, anyVals...)
}
func NewDataSeriesInt(name string, vals ...int) *DataSeries {
anyVals := make([]any, len(vals))
for i, v := range vals {
anyVals[i] = v
}
return NewDataSeries(name, anyVals...)
}
// Copy returns a new DataSeries with a copy of the original data and Series name. start is an EasyIndex and count is the number of items to copy from start onward. If count is negative then all items from start to the end of the series are copied. If there are not enough items to copy then the maximum amount is returned. If there are no items to copy then an empty DataSeries is returned.
//
// Examples:
@ -528,7 +544,7 @@ func (s *DataSeries) Filter(f func(i int, val any) bool) Series {
func (s *DataSeries) Map(f func(i int, val any) any) Series {
series := s.Copy(0, -1)
for i := 0; i < s.Len(); i++ {
series.SetValue(i, f(i, series.Value(i)))
series.SetValue(i, f(i, s.value(i)))
}
return series
}
@ -537,7 +553,7 @@ func (s *DataSeries) Map(f func(i int, val any) any) Series {
func (s *DataSeries) MapReverse(f func(i int, val any) any) Series {
series := s.Copy(0, -1)
for i := s.Len() - 1; i >= 0; i-- {
series.SetValue(i, f(i, series.Value(i)))
series.SetValue(i, f(i, s.value(i)))
}
return series
}