Use type safety for RSI indicator

This commit is contained in:
Luke I. Wilson 2023-05-20 13:01:12 -05:00
parent 1516604889
commit ed08bd275d
2 changed files with 18 additions and 18 deletions

View File

@ -7,28 +7,28 @@ import "math"
// Traditionally, an RSI reading of 70 or above indicates an overbought condition, and a reading of 30 or below indicates an oversold condition. // Traditionally, an RSI reading of 70 or above indicates an overbought condition, and a reading of 30 or below indicates an oversold condition.
// //
// Typically, the RSI is calculated with a period of 14 days. // Typically, the RSI is calculated with a period of 14 days.
func RSI(series *Series, periods int) *Series { func RSI(series *FloatSeries, periods int) *FloatSeries {
// Calculate the difference between each day's close and the previous day's close. // Calculate the difference between each day's close and the previous day's close.
delta := series.Copy().Map(func(i int, v interface{}) interface{} { delta := series.Copy().MapReverse(func(i int, v float64) float64 {
if i == 0 { if i == 0 {
return float64(0) return 0
} }
return v.(float64) - series.Value(i-1).(float64) return v - series.Value(i-1)
}) })
// Calculate the average gain and average loss. // Calculate the average gain and average loss.
avgGain := delta.Copy(). avgGain := &FloatSeries{delta.Copy().
Map(func(i int, val interface{}) interface{} { return math.Max(val.(float64), 0) }). Map(func(i int, val float64) float64 { return math.Max(val, 0) }).
Rolling(periods).Average() Rolling(periods).Average()}
avgLoss := delta.Copy(). avgLoss := &FloatSeries{delta.Copy().
Map(func(i int, val interface{}) interface{} { return math.Abs(math.Min(val.(float64), 0)) }). Map(func(i int, val float64) float64 { return math.Abs(math.Min(val, 0)) }).
Rolling(periods).Average() Rolling(periods).Average()}
// Calculate the RSI. // Calculate the RSI.
return avgGain.Map(func(i int, val interface{}) interface{} { return avgGain.Map(func(i int, val float64) float64 {
loss := avgLoss.Float(i) loss := avgLoss.Float(i)
if loss == 0 { if loss == 0 {
return float64(100) return float64(100)
} }
return float64(100. - 100./(1.+val.(float64)/loss)) return float64(100. - 100./(1.+val/loss))
}) })
} }

View File

@ -5,16 +5,16 @@ import (
) )
func TestRSI(t *testing.T) { func TestRSI(t *testing.T) {
prices := NewSeries("Prices", 1., 0., 2., 1., 3., 2., 4., 3., 5., 4., 6., 5., 7., 6.) prices := NewFloatSeries("Prices", 1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6)
rsi := RSI(prices, 14) rsi := RSI(prices, 14)
if rsi.Len() != 14 { if rsi.Len() != 14 {
t.Errorf("RSI length is %d, expected 14", rsi.Len()) t.Errorf("RSI length is %d, expected 14", rsi.Len())
} }
if !EqualApprox(rsi.Float(0), 100) { if !EqualApprox(rsi.Value(0), 100) {
t.Errorf("RSI[0] is %f, expected 0", rsi.Float(0)) t.Errorf("RSI[0] is %f, expected 100", rsi.Value(0))
} }
// TODO: check the expected RSI // TODO: check the expected RSI
// if !EqualApprox(rsi.Float(-1), 61.02423) { if !EqualApprox(rsi.Value(-1), 63.157895) {
// t.Errorf("RSI[-1] is %f, expected 100", rsi.Float(-1)) t.Errorf("RSI[-1] is %f, expected 63.157895", rsi.Value(-1))
// } }
} }