jeudi, décembre 08, 2011

Simple linear regression in Scala

Here is how to compute a simple linear regression in Scala.
Class LinearRegression takes n measurements as a List[(Double, Double)] of (x, y) pairs and computes the line that best fits the data according to the least-squares criterion.
This program is a Scala translation of the Java program available at http://introcs.cs.princeton.edu/java/97data/LinearRegression.java.html .
/**
 * Least-squares fit of the straight line `y = beta1 * x + beta0` to a list of
 * `(x, y)` measurements.
 *
 * Every statistic is computed eagerly while the instance is being constructed, and a
 * textual report is written to standard output as a side effect of construction.
 *
 * @param pairs the `(x, y)` measurements to fit
 */
class LinearRegression(val pairs: List[(Double,Double)]) {
  /** Number of measurements. */
  val size: Int = pairs.size
  println(s"pairs = $pairs")

  // Pass 1: accumulate sum(x), sum(x^2) and sum(y), then derive the two means.
  /** Running totals of x, x squared and y over all measurements. */
  val sums: X_X2_Y =
    pairs.foldLeft(new X_X2_Y(0D, 0D, 0D)) { (acc, point) => acc + new X_X2_Y(point) }

  /** Means of x and y, as `(xbar, ybar)`. */
  val bars: (Double, Double) = (sums.x / size, sums.y / size)

  // Private aliases so the formulas below read without tuple accessors.
  private val (xBar, yBar) = bars

  // Pass 2: accumulate the squared deviations and cross products around the means.
  /** Sums of `(x - xbar)^2`, `(y - ybar)^2` and `(x - xbar) * (y - ybar)`. */
  val sumstats: X2_Y2_XY =
    pairs.foldLeft(new X2_Y2_XY(0D, 0D, 0D)) { (acc, point) => acc + new X2_Y2_XY(point, bars) }

  /** Slope of the fitted line. */
  val beta1: Double = sumstats.xy / sumstats.x2

  /** Intercept of the fitted line. */
  val beta0: Double = yBar - (beta1 * xBar)

  /** Fitted coefficients as `(intercept, slope)`. */
  val betas: (Double, Double) = (beta0, beta1)

  println("y = %4.3f * x + %4.3f".format(beta1, beta0))

  // Pass 3: accumulate the residual and regression sums of squares of the fitted line.
  /** Residual (`rss`) and regression (`ssr`) sums of squares. */
  val correlation: RSS_SSR =
    pairs.foldLeft(new RSS_SSR(0D, 0D)) { (acc, point) =>
      acc + RSS_SSR.build(point, bars, betas)
    }

  // Two coefficients are estimated, hence the "- 2".
  private val degreesOfFreedom = size - 2

  /** Coefficient of determination: regression sum of squares over total sum of squares. */
  val R2: Double = correlation.ssr / sumstats.y2

  /** Residual variance estimate: residual sum of squares over `size - 2`. */
  val svar: Double = correlation.rss / degreesOfFreedom

  /** Variance of the slope estimate; its square root is reported as the std error of beta_1. */
  val svar1: Double = svar / sumstats.x2

  /** Variance of the intercept estimate; its square root is reported as the std error of beta_0. */
  val svar0: Double = (svar / size) + (xBar * xBar * svar1)

  /** Same quantity as `svar0`, via the algebraically equivalent form based on sum(x^2). */
  val svar0bis: Double = svar * sums.x2 / (size * sumstats.x2)

  println(s"R^2                 = $R2")
  println(s"std error of beta_1 = ${math.sqrt(svar1)}")
  println(s"std error of beta_0 = ${math.sqrt(svar0)}")
  println(s"std error of beta_0 = ${math.sqrt(svar0bis)}")
  println(s"SSTO = ${sumstats.y2}")
  println(s"SSE  = ${correlation.rss}")
  println(s"SSR  = ${correlation.ssr}")
}

/** Factory for the contribution of a single measurement to the two sums of squares. */
object RSS_SSR {
  /**
   * Evaluates the fitted line at one measurement and returns its squared residual and
   * its squared distance from the mean of y.
   *
   * @param p     the measurement, as `(x, y)`
   * @param bars  the means, as `(xbar, ybar)`; only `ybar` is read here
   * @param betas the line coefficients, as `(intercept, slope)`
   * @return an [[RSS_SSR]] whose `rss` is `(fit - y)^2` and whose `ssr` is `(fit - ybar)^2`
   */
  def build(p: (Double,Double), bars: (Double,Double), betas: (Double,Double)): RSS_SSR = {
    val (x, y) = p
    val (intercept, slope) = betas
    val predicted = (slope * x) + intercept
    val residual = predicted - y
    val explained = predicted - bars._2
    new RSS_SSR(residual * residual, explained * explained)
  }
}

/**
 * Pair of accumulators: a residual sum of squares and a regression sum of squares.
 *
 * @param rss residual sum of squares accumulated so far
 * @param ssr regression sum of squares accumulated so far
 */
class RSS_SSR(val rss: Double, val ssr: Double) {
  /** Component-wise sum of this accumulator and `p`; neither operand is modified. */
  def +(p: RSS_SSR): RSS_SSR = {
    val rssTotal = rss + p.rss
    val ssrTotal = ssr + p.ssr
    new RSS_SSR(rssTotal, ssrTotal)
  }
}

/**
 * Triple of accumulators used for the first pass over the data.
 *
 * @param x  sum of the x values
 * @param x2 sum of the squared x values
 * @param y  sum of the y values
 */
class X_X2_Y(val x: Double, val x2: Double, val y: Double) {

  /** Builds the contribution of a single measurement `p = (x, y)`: `(x, x * x, y)`. */
  def this(p: (Double,Double)) = this(p._1, p._1 * p._1, p._2)

  /** Component-wise sum of this accumulator and `p`; neither operand is modified. */
  def +(p: X_X2_Y): X_X2_Y = {
    val xTotal = x + p.x
    val x2Total = x2 + p.x2
    val yTotal = y + p.y
    new X_X2_Y(xTotal, x2Total, yTotal)
  }
}

/**
 * Triple of accumulators used for the second pass over the data.
 *
 * @param x2 sum of the squared x deviations
 * @param y2 sum of the squared y deviations
 * @param xy sum of the products of the x and y deviations
 */
class X2_Y2_XY(val x2: Double, val y2: Double, val xy: Double) {

  /**
   * Builds the contribution of a single measurement `p = (x, y)` measured against the
   * reference point `bars = (xbar, ybar)`:
   * `((x - xbar)^2, (y - ybar)^2, (x - xbar) * (y - ybar))`.
   */
  def this(p: (Double,Double), bars: (Double,Double)) = this(
    (p._1 - bars._1) * (p._1 - bars._1),
    (p._2 - bars._2) * (p._2 - bars._2),
    (p._1 - bars._1) * (p._2 - bars._2)
  )

  /** Component-wise sum of this accumulator and `p`; neither operand is modified. */
  def +(p: X2_Y2_XY): X2_Y2_XY = {
    val x2Total = x2 + p.x2
    val y2Total = y2 + p.y2
    val xyTotal = xy + p.xy
    new X2_Y2_XY(x2Total, y2Total, xyTotal)
  }
}