mirror of
https://github.com/lukewilson2002/autotrader.git
synced 2025-06-15 00:13:51 +00:00
Wrote tests for the RSI function
This commit is contained in:
parent
d851061d1f
commit
ba92a650fb
@ -3,6 +3,10 @@ package autotrader
|
||||
import "math"
|
||||
|
||||
// RSI calculates the Relative Strength Index for a given Series. Typically, the input series is the Close column of a DataFrame. Returns a Series of RSI values of the same length as the input.
|
||||
//
|
||||
// Traditionally, an RSI reading of 70 or above indicates an overbought condition, and a reading of 30 or below indicates an oversold condition.
|
||||
//
|
||||
// Typically, the RSI is calculated with a period of 14 days.
|
||||
func RSI(series Series, periods int) Series {
|
||||
// Calculate the difference between each day's close and the previous day's close.
|
||||
delta := series.MapReverse(func(i int, v interface{}) interface{} {
|
||||
@ -19,6 +23,10 @@ func RSI(series Series, periods int) Series {
|
||||
avgLoss := losses.Rolling(periods).Mean()
|
||||
// Calculate the RSI.
|
||||
return avgGain.Map(func(i int, val interface{}) interface{} {
|
||||
return 100 - (100 / (1 + (val.(float64) / avgLoss.Value(i).(float64))))
|
||||
loss := avgLoss.Float(i)
|
||||
if loss == 0 {
|
||||
return float64(100)
|
||||
}
|
||||
return float64(100. - 100./(1.+val.(float64)/loss))
|
||||
})
|
||||
}
|
||||
|
19
indicators_test.go
Normal file
19
indicators_test.go
Normal file
@ -0,0 +1,19 @@
|
||||
package autotrader
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRSI(t *testing.T) {
|
||||
prices := NewDataSeriesFloat("Prices", 1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6)
|
||||
rsi := RSI(prices, 14)
|
||||
if rsi.Len() != 14 {
|
||||
t.Errorf("RSI length is %d, expected 14", rsi.Len())
|
||||
}
|
||||
if !EqualApprox(rsi.Float(0), 100) {
|
||||
t.Errorf("RSI[0] is %f, expected 0", rsi.Float(0))
|
||||
}
|
||||
if !EqualApprox(rsi.Float(-1), 61.02423) {
|
||||
t.Errorf("RSI[-1] is %f, expected 100", rsi.Float(-1))
|
||||
}
|
||||
}
|
20
series.go
20
series.go
@ -362,6 +362,22 @@ func NewDataSeries(name string, vals ...any) *DataSeries {
|
||||
return dataSeries
|
||||
}
|
||||
|
||||
func NewDataSeriesFloat(name string, vals ...float64) *DataSeries {
|
||||
anyVals := make([]any, len(vals))
|
||||
for i, v := range vals {
|
||||
anyVals[i] = v
|
||||
}
|
||||
return NewDataSeries(name, anyVals...)
|
||||
}
|
||||
|
||||
func NewDataSeriesInt(name string, vals ...int) *DataSeries {
|
||||
anyVals := make([]any, len(vals))
|
||||
for i, v := range vals {
|
||||
anyVals[i] = v
|
||||
}
|
||||
return NewDataSeries(name, anyVals...)
|
||||
}
|
||||
|
||||
// Copy returns a new DataSeries with a copy of the original data and Series name. start is an EasyIndex and count is the number of items to copy from start onward. If count is negative then all items from start to the end of the series are copied. If there are not enough items to copy then the maximum amount is returned. If there are no items to copy then an empty DataSeries is returned.
|
||||
//
|
||||
// Examples:
|
||||
@ -528,7 +544,7 @@ func (s *DataSeries) Filter(f func(i int, val any) bool) Series {
|
||||
func (s *DataSeries) Map(f func(i int, val any) any) Series {
|
||||
series := s.Copy(0, -1)
|
||||
for i := 0; i < s.Len(); i++ {
|
||||
series.SetValue(i, f(i, series.Value(i)))
|
||||
series.SetValue(i, f(i, s.value(i)))
|
||||
}
|
||||
return series
|
||||
}
|
||||
@ -537,7 +553,7 @@ func (s *DataSeries) Map(f func(i int, val any) any) Series {
|
||||
func (s *DataSeries) MapReverse(f func(i int, val any) any) Series {
|
||||
series := s.Copy(0, -1)
|
||||
for i := s.Len() - 1; i >= 0; i-- {
|
||||
series.SetValue(i, f(i, series.Value(i)))
|
||||
series.SetValue(i, f(i, s.value(i)))
|
||||
}
|
||||
return series
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user