Use type safety for RSI indicator

This commit is contained in:
Luke I. Wilson 2023-05-20 13:01:12 -05:00
parent 1516604889
commit ed08bd275d
2 changed files with 18 additions and 18 deletions

View File

@ -7,28 +7,28 @@ import "math"
// Traditionally, an RSI reading of 70 or above indicates an overbought condition, and a reading of 30 or below indicates an oversold condition.
//
// Typically, the RSI is calculated with a period of 14 days.
func RSI(series *Series, periods int) *Series {
func RSI(series *FloatSeries, periods int) *FloatSeries {
// Calculate the difference between each day's close and the previous day's close.
delta := series.Copy().Map(func(i int, v interface{}) interface{} {
delta := series.Copy().MapReverse(func(i int, v float64) float64 {
if i == 0 {
return float64(0)
return 0
}
return v.(float64) - series.Value(i-1).(float64)
return v - series.Value(i-1)
})
// Calculate the average gain and average loss.
avgGain := delta.Copy().
Map(func(i int, val interface{}) interface{} { return math.Max(val.(float64), 0) }).
Rolling(periods).Average()
avgLoss := delta.Copy().
Map(func(i int, val interface{}) interface{} { return math.Abs(math.Min(val.(float64), 0)) }).
Rolling(periods).Average()
avgGain := &FloatSeries{delta.Copy().
Map(func(i int, val float64) float64 { return math.Max(val, 0) }).
Rolling(periods).Average()}
avgLoss := &FloatSeries{delta.Copy().
Map(func(i int, val float64) float64 { return math.Abs(math.Min(val, 0)) }).
Rolling(periods).Average()}
// Calculate the RSI.
return avgGain.Map(func(i int, val interface{}) interface{} {
return avgGain.Map(func(i int, val float64) float64 {
loss := avgLoss.Float(i)
if loss == 0 {
return float64(100)
}
return float64(100. - 100./(1.+val.(float64)/loss))
return float64(100. - 100./(1.+val/loss))
})
}

View File

@ -5,16 +5,16 @@ import (
)
func TestRSI(t *testing.T) {
prices := NewSeries("Prices", 1., 0., 2., 1., 3., 2., 4., 3., 5., 4., 6., 5., 7., 6.)
prices := NewFloatSeries("Prices", 1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6)
rsi := RSI(prices, 14)
if rsi.Len() != 14 {
t.Errorf("RSI length is %d, expected 14", rsi.Len())
}
if !EqualApprox(rsi.Float(0), 100) {
t.Errorf("RSI[0] is %f, expected 0", rsi.Float(0))
if !EqualApprox(rsi.Value(0), 100) {
t.Errorf("RSI[0] is %f, expected 100", rsi.Value(0))
}
// TODO: check the expected RSI
// if !EqualApprox(rsi.Float(-1), 61.02423) {
// t.Errorf("RSI[-1] is %f, expected 100", rsi.Float(-1))
// }
if !EqualApprox(rsi.Value(-1), 63.157895) {
t.Errorf("RSI[-1] is %f, expected 63.157895", rsi.Value(-1))
}
}