From 576d1bfe2b3a9576058712b8309dbbc8be2a096b Mon Sep 17 00:00:00 2001 From: tx7do Date: Sun, 28 Apr 2024 09:42:53 +0800 Subject: [PATCH] feat: add math utils. --- math/gaussian.go | 134 ++++++++++++++++++++++++++++++++++++++++++ math/gaussian_test.go | 47 +++++++++++++++ math/math.go | 44 ++++++++++++++ math/math_test.go | 30 ++++++++++ 4 files changed, 255 insertions(+) create mode 100644 math/gaussian.go create mode 100644 math/gaussian_test.go create mode 100644 math/math.go create mode 100644 math/math_test.go diff --git a/math/gaussian.go b/math/gaussian.go new file mode 100644 index 0000000..2fe6f0e --- /dev/null +++ b/math/gaussian.go @@ -0,0 +1,134 @@ +package math + +import ( + "math" +) + +// prop +//mean: the mean (μ) of the distribution +//variance: the variance (σ^2) of the distribution +//standardDeviation: the standard deviation (σ) of the distribution + +// combination + +type Gaussian struct { + mean float64 + variance float64 + standardDeviation float64 +} + +func NewGaussian(mean, variance float64) *Gaussian { + if variance <= 0.0 { + panic("error") + } + + return &Gaussian{ + mean: mean, + variance: variance, + standardDeviation: math.Sqrt(float64(variance)), + } +} + +// Erfc Complementary error function +// From Numerical Recipes in C 2e p221 +func Erfc(x float64) float64 { + z := math.Abs(x) + t := 1 / (1 + z/2) + r := t * math.Exp(-z*z-1.26551223+t*(1.00002368+ + t*(0.37409196+t*(0.09678418+t*(-0.18628806+ + t*(0.27886807+t*(-1.13520398+t*(1.48851587+ + t*(-0.82215223+t*0.17087277))))))))) + if x >= 0 { + return r + } else { + return 2 - r + } +} + +// Ierfc Inverse complementary error function +// From Numerical Recipes 3e p265 +func Ierfc(x float64) float64 { + if x >= 2 { + return -100 + } + if x <= 0 { + return 100 + } + var xx float64 + if x < 1 { + xx = x + } else { + xx = 2 - x + } + t := math.Sqrt(-2 * math.Log(xx/2)) + r := -0.70711 * ((2.30753+t*0.27061)/ + (1+t*(0.99229+t*0.04481)) - t) + + for j := 0; j < 2; j++ { + e := Erfc(r) - xx + r += e / (1.12837916709551257*math.Exp(-(r*r)) - r*e) + } + + if x < 1 { + return r + } else { + return -r + } + +} + +// fromPrecisionMean Construct a new distribution from the precision and precisionmean +func fromPrecisionMean(precision, precisionmean float64) *Gaussian { + return NewGaussian(precisionmean/precision, 1/precision) +} + +/// PROB + +// Pdf pdf(x): the probability density function, which describes the probability +// of a random variable taking on the value x +func (g *Gaussian) Pdf(x float64) float64 { + m := g.standardDeviation * math.Sqrt(2*math.Pi) + e := math.Exp(-math.Pow(x-g.mean, 2) / (2 * g.variance)) + return e / m +} + +// Cdf cdf(x): the cumulative distribution function, +// which describes the probability of a random +// variable falling in the interval (−∞, x] +func (g *Gaussian) Cdf(x float64) float64 { + return 0.5 * Erfc(-(x-g.mean)/(g.standardDeviation*math.Sqrt(2))) +} + +// Ppf ppf(x): the percent point function, the inverse of cdf +func (g *Gaussian) Ppf(x float64) float64 { + return g.mean - g.standardDeviation*math.Sqrt(2)*Ierfc(2*x) +} + +// Add add(d): returns the result of adding this and the given distribution +func (g *Gaussian) Add(d *Gaussian) *Gaussian { + return NewGaussian(g.mean+d.mean, g.variance+d.variance) +} + +// Sub sub(d): returns the result of subtracting this and the given distribution +func (g *Gaussian) Sub(d *Gaussian) *Gaussian { + return NewGaussian(g.mean-d.mean, g.variance+d.variance) +} + +// Scale scale(c): returns the result of scaling this distribution by the given constant +func (g *Gaussian) Scale(c float64) *Gaussian { + return NewGaussian(g.mean*c, g.variance*c*c) +} + +// Mul mul(d): returns the product distribution of this and the given distribution. If a constant is passed in the distribution is scaled. +func (g *Gaussian) Mul(d *Gaussian) *Gaussian { + precision := 1 / g.variance + dprecision := 1 / d.variance + return fromPrecisionMean(precision+dprecision, precision*g.mean+dprecision*d.mean) +} + +// Div div(d): returns the quotient distribution of this and the given distribution. If a constant is passed in the distribution is scaled by 1/d. +func (g *Gaussian) Div(d *Gaussian) *Gaussian { + precision := 1 / g.variance + dprecision := 1 / d.variance + return fromPrecisionMean(precision-dprecision, precision*g.mean-dprecision*d.mean) +} diff --git a/math/gaussian_test.go b/math/gaussian_test.go new file mode 100644 index 0000000..af63fb5 --- /dev/null +++ b/math/gaussian_test.go @@ -0,0 +1,47 @@ +package math + +import ( + "fmt" + "testing" +) + +func TestGaussian(t *testing.T) { + g := NewGaussian(3.0, 1) + + fmt.Printf("g: %#v\n", g) + fmt.Printf("pdf: %f\n", g.Pdf(5)) + fmt.Printf("cdf: %f\n", g.Cdf(2)) + fmt.Printf("ppf: %f\n", g.Ppf(5)) + + d := NewGaussian(0, 1) + fmt.Printf("ppf: %f, %f\n", d.Pdf(-2), 0.053991) + fmt.Printf("ppf: %f, %f\n", d.Pdf(-1), 0.241971) + fmt.Printf("ppf: %f, %f\n", d.Pdf(0), 0.398942) + fmt.Printf("ppf: %f, %f\n", d.Pdf(1), 0.241971) + fmt.Printf("ppf: %f, %f\n", d.Pdf(2), 0.053991) + + fmt.Printf("cdf: %f, %f\n", d.Cdf(-1.28155), 0.1) + fmt.Printf("cdf: %f, %f\n", d.Cdf(-0.67499), 0.25) + fmt.Printf("cdf: %f, %f\n", d.Cdf(0), 0.5) + fmt.Printf("cdf: %f, %f\n", d.Cdf(0.67499), 0.75) + fmt.Printf("cdf: %f, %f\n", d.Cdf(1.28155), 0.9) + + fmt.Printf("ppf: %f, %f\n", d.Ppf(0.1), -1.28155) + fmt.Printf("ppf: %f, %f\n", d.Ppf(0.25), -0.67499) + fmt.Printf("ppf: %f, %f\n", d.Ppf(0.5), 0.0) + fmt.Printf("ppf: %f, %f\n", d.Ppf(0.75), 0.67449) + fmt.Printf("ppf: %f, %f\n", d.Ppf(0.9), 1.28155) + + d = d.Mul(NewGaussian(0, 1)) + fmt.Printf("Mul: %#v\n", d) + fmt.Printf("%#v\n%#v", NewGaussian(1, 1).Scale(2), NewGaussian(2, 4)) + + d = NewGaussian(1, 1).Div(NewGaussian(1, 2)) + fmt.Printf("div\n") + fmt.Printf("%#v\n%#v\n", d, NewGaussian(1, 2)) + fmt.Printf("%#v\n%#v\n", NewGaussian(1, 1).Scale(1/(1.0/2.0)), NewGaussian(2, 4)) + + fmt.Printf("ADD:\n%#v\n%#v\n", NewGaussian(1, 1).Add(NewGaussian(1, 2)), NewGaussian(2, 3)) + fmt.Printf("SUB:\n%#v\n%#v\n", NewGaussian(1, 1).Sub(NewGaussian(1, 2)), NewGaussian(0, 3)) + fmt.Printf("SCALE:\n%#v\n%#v\n", NewGaussian(1, 1).Scale(2), NewGaussian(2, 4)) +} diff --git a/math/math.go b/math/math.go new file mode 100644 index 0000000..c03b55a --- /dev/null +++ b/math/math.go @@ -0,0 +1,44 @@ +package math + +import ( + "math" +) + +// Sign 符号函数(Sign function,简称sgn)是一个逻辑函数,用以判断实数的正负号。为避免和英文读音相似的正弦函数(sine)混淆,它亦称为Signum function。 +func Sign[T int | int8 | int16 | int32 | int64 | float32 | float64](x T) T { + switch { + case x < 0: // x < 0 : -1 + return -1 + case x > 0: // x > 0 : +1 + return +1 + default: // x == 0 : 0 + return 0 + } +} + +// Mean 计算给定数据的平均值 +func Mean(num []float64) float64 { + var count = len(num) + var sum float64 = 0 + for i := 0; i < count; i++ { + sum += num[i] + } + return sum / float64(count) +} + +// Variance 使用平均值计算给定数据的方差 +func Variance(mean float64, num []float64) float64 { + var count = len(num) + var variance float64 = 0 + for i := 0; i < count; i++ { + variance += math.Pow(num[i]-mean, 2) + } + return variance / float64(count) +} + +// StandardDeviation 使用方差计算给定数据的标准偏差 +func StandardDeviation(num []float64) float64 { + var mean = Mean(num) + var variance = Variance(mean, num) + return math.Sqrt(variance) +} diff --git a/math/math_test.go b/math/math_test.go new file mode 100644 index 0000000..8e22372 --- /dev/null +++ b/math/math_test.go @@ -0,0 +1,30 @@ +package math + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSign(t *testing.T) { + assert.True(t, Sign(2) == 1) + assert.True(t, Sign(-2) == -1) + assert.True(t, Sign(0) == 0) + + assert.True(t, Sign(int64(2)) == 1) + assert.True(t, Sign(int64(-2)) == -1) + assert.True(t, Sign(int64(0)) == 0) + + assert.True(t, Sign(float32(2)) == 1) + assert.True(t, Sign(float32(-2)) == -1) + assert.True(t, Sign(float32(0)) == 0) + + assert.True(t, Sign(float64(2)) == 1) + assert.True(t, Sign(float64(-2)) == -1) + assert.True(t, Sign(float64(0)) == 0) +} + +func TestStandardDeviation(t *testing.T) { + assert.Equal(t, StandardDeviation([]float64{3, 5, 9, 1, 8, 6, 58, 9, 4, 10}), 15.8117045254457) + assert.Equal(t, StandardDeviation([]float64{1, 3, 5, 7, 9, 11, 2, 4, 6, 8}), 3.0397368307141326) +}