From ed08bd275d1607e6d8a4f6e327f98cf08c7a8ab5 Mon Sep 17 00:00:00 2001 From: "Luke I. Wilson" Date: Sat, 20 May 2023 13:01:12 -0500 Subject: [PATCH] Use type safety for RSI indicator --- indicators.go | 24 ++++++++++++------------ indicators_test.go | 12 ++++++------ 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/indicators.go b/indicators.go index a535311..32d11af 100644 --- a/indicators.go +++ b/indicators.go @@ -7,28 +7,28 @@ import "math" // Traditionally, an RSI reading of 70 or above indicates an overbought condition, and a reading of 30 or below indicates an oversold condition. // // Typically, the RSI is calculated with a period of 14 days. -func RSI(series *Series, periods int) *Series { +func RSI(series *FloatSeries, periods int) *FloatSeries { // Calculate the difference between each day's close and the previous day's close. - delta := series.Copy().Map(func(i int, v interface{}) interface{} { + delta := series.Copy().MapReverse(func(i int, v float64) float64 { if i == 0 { - return float64(0) + return 0 } - return v.(float64) - series.Value(i-1).(float64) + return v - series.Value(i-1) }) // Calculate the average gain and average loss. - avgGain := delta.Copy(). - Map(func(i int, val interface{}) interface{} { return math.Max(val.(float64), 0) }). - Rolling(periods).Average() - avgLoss := delta.Copy(). - Map(func(i int, val interface{}) interface{} { return math.Abs(math.Min(val.(float64), 0)) }). - Rolling(periods).Average() + avgGain := &FloatSeries{delta.Copy(). + Map(func(i int, val float64) float64 { return math.Max(val, 0) }). + Rolling(periods).Average()} + avgLoss := &FloatSeries{delta.Copy(). + Map(func(i int, val float64) float64 { return math.Abs(math.Min(val, 0)) }). + Rolling(periods).Average()} // Calculate the RSI. - return avgGain.Map(func(i int, val interface{}) interface{} { + return avgGain.Map(func(i int, val float64) float64 { loss := avgLoss.Float(i) if loss == 0 { return float64(100) } - return float64(100. - 100./(1.+val.(float64)/loss)) + return float64(100. - 100./(1.+val/loss)) }) } diff --git a/indicators_test.go b/indicators_test.go index a53b2d7..deead57 100644 --- a/indicators_test.go +++ b/indicators_test.go @@ -5,16 +5,16 @@ import ( ) func TestRSI(t *testing.T) { - prices := NewSeries("Prices", 1., 0., 2., 1., 3., 2., 4., 3., 5., 4., 6., 5., 7., 6.) + prices := NewFloatSeries("Prices", 1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6) rsi := RSI(prices, 14) if rsi.Len() != 14 { t.Errorf("RSI length is %d, expected 14", rsi.Len()) } - if !EqualApprox(rsi.Float(0), 100) { - t.Errorf("RSI[0] is %f, expected 0", rsi.Float(0)) + if !EqualApprox(rsi.Value(0), 100) { + t.Errorf("RSI[0] is %f, expected 100", rsi.Value(0)) } // TODO: check the expected RSI - // if !EqualApprox(rsi.Float(-1), 61.02423) { - // t.Errorf("RSI[-1] is %f, expected 100", rsi.Float(-1)) - // } + if !EqualApprox(rsi.Value(-1), 63.157895) { + t.Errorf("RSI[-1] is %f, expected 63.157895", rsi.Value(-1)) + } }