Misapplied Math

Trading, Data Science, CS

Day Eight: LASSO Regression

TL/DR

LASSO regression (least absolute shrinkage and selection operator) is a modified form of least squares regression that penalizes model complexity via a regularization parameter. It does so by including a term proportional to $||\beta||_{l_1}$ in the objective function, which shrinks coefficients towards zero and can even eliminate them entirely. In that light, LASSO is a form of feature selection/dimensionality reduction. Unlike other forms of regularization such as ridge regression, LASSO will actually eliminate predictors. It's a simple, useful technique that performs quite well on many data sets.

Regularization

Regularization refers to the process of adding additional constraints to a problem to avoid overfitting. ML techniques such as neural networks can generate models of arbitrary complexity that will fit in-sample data one-for-one. As we recently saw in the post on Reed-Solomon FEC codes, the same applies to regression. We definitely have a problem anytime there are more regressors than data points, but any excessively complex model will generalize horribly and do you no good out of sample.
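
A quick way to see the degenerate case: with at least as many regressors as observations, ordinary least squares can fit pure noise perfectly in sample. The snippet below is just an illustration with made up data:

# With 10 observations and 10 random regressors (plus an intercept), OLS has
# enough degrees of freedom to interpolate pure noise. R will warn about an
# essentially perfect fit; that's the point.
set.seed(1)
x <- matrix(rnorm(10 * 10), nrow = 10)
y <- rnorm(10)
summary(lm(y ~ x))$r.squared  # 1, despite there being no signal at all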

Why LASSO?

There's a litany of regularization techniques for regression, ranging from heuristic, hands-on ones like stepwise regression to full blown dimensionality reduction. They all have their place, but I like LASSO because it works very well, and it's simpler than most dimensionality reduction/ML techniques. And, despite being a non-linear method, as of 2008 it has a relatively efficient solution via coordinate descent. We can solve the optimization in $O(n \cdot p)$ time, where $n$ is the length of the data set and $p$ is the number of regressors.

An Example

Our objective function has the form:

$$\frac{1}{2} \sum_i \left(y_i - x_i^T\beta\right)^2 + \lambda \sum_{j = 1}^p |\beta_j|$$

where $\lambda \geq 0$. The first half of the equation is just the standard objective function for least squares regression. The second half penalizes regression coefficients under the $l_1$ norm. The parameter $\lambda$ determines how important the penalty on coefficient weights is.
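
The coordinate descent solution mentioned earlier amounts to cycling through the coefficients and applying a soft-thresholding update to one at a time while holding the others fixed. Here's a minimal sketch for the objective above – an illustration, not the glmnet implementation – that assumes a centered response and ignores the intercept:

# Soft-thresholding operator: shrink z towards zero by gamma, zeroing it out
# when |z| <= gamma.
soft.threshold <- function(z, gamma) sign(z) * pmax(abs(z) - gamma, 0)

# Naive coordinate descent for (1/2) * RSS + lambda * ||beta||_1. Fixed
# iteration count, no convergence check, intercept omitted.
lasso.cd <- function(X, y, lambda, iters = 100) {
	beta <- rep(0, ncol(X))
	for (it in 1:iters) {
		for (j in 1:ncol(X)) {
			# Partial residual: what's left of y after the other coefficients.
			r <- y - X[, -j, drop = FALSE] %*% beta[-j]
			beta[j] <- soft.threshold(sum(X[, j] * r), lambda) / sum(X[, j]^2)
		}
	}
	beta
}

An efficient implementation maintains running residuals so that each coefficient update is $O(n)$, giving the $O(n \cdot p)$ per-sweep cost quoted earlier; the sketch above recomputes the residual from scratch for clarity.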

There are two R packages that I know of for LASSO: lars (short for least angle regression – a superset of LASSO) and glmnet. Glmnet includes solvers for more general models (including elastic net – a hybrid of LASSO and ridge that can handle categorical variables). Lars is simpler to work with but the documentation isn't great. As such, here are a few points worth noting:

  1. The primary lars function generates an object that's subsequently used to generate the fit that you actually want. There's a computational motivation behind this approach. The LARS technique works by solving for a series of "knot points" with associated, monotonically decreasing values of $\lambda$. The knot points are subsequently used to compute the LASSO regression for any value of $\lambda$ using only matrix math. This makes procedures such as cross validation, where we need to try lots of different values of $\lambda$, computationally tractable. Without it, we would have to recompute an expensive non-linear optimization each time $\lambda$ changed.
  2. There's a saturation point at which $\lambda$ is high enough that the null model is optimal. On the other end of the spectrum, when $\lambda = 0$, we're left with least squares. The final value of $\lambda$ on the path, right before we end up with least squares, will correspond to the largest coefficient norm. Let's call these coefficients $\beta_\text{thresh}$, and denote $\Delta = ||\beta_\text{thresh}||_{l_1}$. When the lars package does cross validation, it does so by computing the MSE for models where the second term in the objective function is fixed at $x \cdot \Delta,\ x \in [0, 1]$. This works from a calculation standpoint (and computationally it makes things pretty), but it's counterintuitive if you're interested in the actual value of $\lambda$ and not just trying to get the regression coefficients. You could easily write your own cross validation routine to use $\lambda$ directly (a sketch follows this list).
  3. The residual sum of squared errors will increase monotonically with $\lambda$. This makes sense as we're trading off between minimizing RSS and the model's complexity. As such, the smallest RSS will always correspond to the smallest value of $\lambda$, and not necessarily the optimal one.
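
As a sketch of the point about using $\lambda$ directly (from item 2 above), here's what a hand-rolled K-fold cross validation over $\lambda$ itself might look like. The fold count and grid size are arbitrary choices, and the grid is kept inside every fold's knot range so that interpolating along each path stays well defined:

require(lars)
data(diabetes)

x <- diabetes$x
y <- diabetes$y
k <- 10
set.seed(1)
folds <- sample(rep(1:k, length.out = nrow(x)))

# Fit the LARS path on each set of training folds up front (the expensive part).
fold.fits <- lapply(1:k, function(i) {
	lars(x[folds != i, ], y[folds != i], type = "lasso")
})

# Common lambda grid, bounded by the smallest of the folds' largest knots.
lambdas <- seq(0, min(sapply(fold.fits, function(f) max(f$lambda))), 
	length.out = 100)

# Held-out MSE for every (fold, lambda) pair.
cv.mse <- sapply(1:k, function(i) {
	preds <- predict(fold.fits[[i]], newx = x[folds == i, ], s = lambdas, 
		type = "fit", mode = "lambda")$fit
	colMeans((y[folds == i] - preds)^2)
})

# Average across folds and pick the lambda minimizing cross validated error.
opt.lambda <- lambdas[which.min(rowMeans(cv.mse))]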

Here's a simple example using data from the lars package. We'll follow a common heuristic that recommends choosing $\lambda$ one SD of MSE away from the minimum. Personally I prefer to examine the CV L-curve and pick a value right on the elbow, but this works.

lasso_regression.R
require(lars)
data(diabetes)

# Compute MSEs for a range of coefficient penalties expressed as a fraction 
# of the final L1 norm on the interval [0, 1].
cv.res <- cv.lars(diabetes$x, diabetes$y, type = "lasso", 
	mode = "fraction", plot = FALSE)

# Choose an "optimal" value one standard deviation away from the 
# minimum MSE.
opt.frac <- min(cv.res$cv) + sd(cv.res$cv)
opt.frac <- cv.res$index[which(cv.res$cv < opt.frac)[1]]

# Compute the LARS path
lasso.path <- lars(diabetes$x, diabetes$y, type = "lasso")

# Compute a fit given the LARS path that we precomputed, and our optimal 
# fraction of the final L1 norm
lasso.fit <- predict.lars(lasso.path, type = "coefficients", 
	mode = "fraction", s = opt.frac)

# Extract the final vector of regression coefficients
coef(lasso.fit)

Final Notes

LASSO is a biased, linear estimator whose bias increases with $\lambda$. It's not meant to provide the "best" fit as Gauss-Markov defines it – LASSO aims to find models that generalize well. Feature selection is a hard problem, and the best that we can do is a combination of common sense and model inference. However, no technique will save you from the worst case scenario: two very highly correlated variables, one of which is a good predictor, the other of which is spurious. It's a crap shoot as to which predictor a feature selection algorithm would penalize in that case. LASSO has a few technical issues as well. Omitted variable bias is still an issue, as it is in other forms of regression, and because of its non-linear solution, LASSO isn't invariant under transformations of the original data matrix.
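
That last point is easy to check empirically. The snippet below is a hypothetical illustration (not part of the example above): it turns off lars's internal standardization, rescales one predictor, and compares coefficients at the same fraction of the final L1 norm. For a scale invariant method the two fits would agree once the rescaling is undone.

require(lars)
data(diabetes)
x <- diabetes$x
y <- diabetes$y

# Fit on the raw data with lars's internal standardization turned off.
fit1 <- lars(x, y, type = "lasso", normalize = FALSE)

# Rescale a single column (column 3, bmi) and refit.
x2 <- x
x2[, 3] <- x2[, 3] * 100
fit2 <- lars(x2, y, type = "lasso", normalize = FALSE)

# Compare coefficients at the same fraction of the final L1 norm, undoing the
# rescaling on column 3.
c1 <- predict(fit1, type = "coefficients", mode = "fraction", s = .5)$coefficients
c2 <- predict(fit2, type = "coefficients", mode = "fraction", s = .5)$coefficients
c2[3] <- c2[3] * 100
rbind(c1, c2)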


Day Seven: Sensor Fusion

TL/DR

Sensor fusion is a generic term for techniques that address the issue of combining multiple noisy estimates of state in an optimal fashion. There's a straightforward view of it as the gain on a Kalman–Bucy filter, and an even simpler interpretation under the central limit theorem.

A Primer on Stochastic Control

Control theory is one of my favorite fields with a ton of applications. As the saying goes, "if all you have is a hammer, everything looks like a nail," and for me that means always looking for ways to pose a problem as a state space model and use the tools of control theory. Control theory gets you everything from cruise control and autopilot to the optimal means of executing an order under some set of volatility and market impact assumptions. The word "sensor" is general and can mean anything that produces a time series of values – it need not be a physical one like a GPS or LIDAR, but it certainly can be.

Estimating state is a pillar of control theory; before you can apply any sort of control feedback you need to know both what your system is currently doing and what you want it to be doing. What you want it to do is a hard problem in and of itself, as it requires you to figure out an optimal action given your current state, the cost of applying the control, and some (potentially infinite) time horizon. The "currently doing" part isn't a picnic either, as you'll usually have to figure out "where you are" given a set of noisy measurements, past and present; that's the problem of state estimation.

The Kalman filter is one of many approaches to state estimation, and the optimal one under some pretty strict and (usually) unrealistic assumptions (the model matches the system perfectly, all noise is stationary IID Gaussian, and the noise covariance matrix is known a priori). That said, the Kalman filter still performs well enough to enjoy widespread use, and alternatives such as particle filters are computationally intensive and have their own issues.

Sensor Fusion

A while back I discussed the geometric interpretation of signal extraction in which we addressed a similar problem. Assume that we have two processes generating normally distributed IID random values, $X = (\mu_1, \sigma_1)$ and $Y = (\mu_2, \sigma_2)$. We can only observe $Z = X + Y$, but what we want is $X$, so the best that we can do is $E[X \mid Z = c]$. As it turns out the solution has a pretty slick interpretation under the geometry of linear regression. Sensor fusion addresses a more general problem: given a set of measurements from multiple sensors, each one of them noisy, what's the best way to produce a unified estimate of state? The sensor noise might be correlated and/or time varying, and each sensor might provide a biased estimate of the true state. Good times.

Viewing each sensor independently brings us back to the conditional expectation that we found before (assuming that the sensor has normally distributed noise of constant variance). If we know the sensor noise a priori (the manufacturer tells us that $\sigma = 1\text{m}$ on a GPS, for example) it's easy to compute $E[X \mid Z = c]$, where $X$ is our true state, $Y$ is the sensor noise, and $Z$ is what we get to observe. In this context it's easy to see that we could probably just appeal to the central limit theorem, average across the state estimates using an inverse variance weighting, and call it a day. Given that we have a more detailed knowledge of the process and measurement model, can we do better?
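
Before answering that, here's a minimal sketch of the naive approach for reference. It treats the sensors as independent and unbiased (assumptions the example below violates) and fuses a single round of readings with inverse variance weights; the readings are made up for illustration:

# Inverse variance weighted average of independent, unbiased estimates of the
# same quantity. Weights are proportional to 1 / variance and sum to one.
fuse <- function(readings, variances) {
	weights <- (1 / variances) / sum(1 / variances)
	sum(weights * readings)
}

# Hypothetical single round of readings from three sensors with the same
# standard deviations used in the example below.
fuse(c(100.4, 99.7, 100.9), c(.6^2, .7^2, .8^2))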

A Simple Example

Let's consider the problem of modeling a Gaussian process with $\mu = 100$ and $\sigma = 2$. We have three sensors with $\sigma_1 = .6$, $\sigma_2 = .7$, and $\sigma_3 = .8$. Sensor one has a correlation of $r_{12} = .3$ with sensor two, a correlation of $r_{13} = .1$ with sensor three, and sensor two has a correlation of $r_{23} = .2$ with sensor three. Assume that they have a bias of .1, -.2, and 0, respectively.

Our process and measurement models are $\dot{x} = Ax + Bu + w$ with $w \sim N(0, Q)$ and $y = Cx + v$ with $v \sim N(0, R)$, respectively. For our simple Gaussian process that gives:

$$\begin{aligned} A &= 0 \\ B &= 0 \\ C &= \left[ \begin{array}{ccc} 1 & 1 & 1 \end{array} \right]^T \\ Q &= 4 \\ R &= \left[ \begin{array}{ccc} .36 & .126 & .048 \\ .126 & .49 & .112 \\ .048 & .112 & .64 \end{array} \right] \end{aligned}$$

From there we can use the dse package in R to compute our Kalman state estimate via sensor fusion. In many cases we would need to estimate the parameters of our model. That's a separate problem known as system identification and there are several R packages (dse included) that help with this. Since we're simulating data and working with known parameters we'll skip that step.

sensor_fusion.R
require(MASS)
require(dse)
require(reshape2)
require(ggplot2)

samples <- 1000
q <- 2
z0 <- 100
sigma1 <- .6
sigma2 <- .7
sigma3 <- .8
r12 <- .3
r13 <- .1
r23 <- .2
bias1 <- .1
bias2 <- -.2
bias3 <- 0
R <- matrix(nrow = 3, ncol = 3)
R <- rbind(c(sigma1^2, r12 * sigma1 * sigma2, r13 * sigma1 * sigma3), 
			c(r12 * sigma1 * sigma2, sigma2^2, r23 * sigma2 * sigma3), 
			c(r13 * sigma1 * sigma3, r23 * sigma2 * sigma3, sigma3^2))


path <- z0 + cumsum(rnorm(samples, mean = 0, sd = q))
observed <- path + mvrnorm(n = samples, c(bias1, bias2, bias3), R)

ss.model <- SS(F = as.matrix(1), Q = as.matrix(q^2), 
	H = as.matrix(c(1, 1, 1)), R = R, z0 = z0) 
smoothed.model <- smoother(ss.model, TSdata(output = observed))
state.est <- smoothed.model$smooth$state

rmsd <- function(actual, estimated) {
    sqrt(mean((actual - estimated)^2))
}

est.rmsd <- c(rmsd(path, state.est), 
	apply(observed, 2, function(x) rmsd(path, x)))

Plotting the first 50 data points gives:

[Figure: Kalman Tracking Errors]

It's a little hard to tell what's going on but you can probably squint and see that the fusion sensor is tracking the best, and that sensor three (the highest variance one) is tracking the worst. Computing the RMSD gives:

  • Kalman: .49
  • Sensor 1: .59
  • Sensor 2: .71
  • Sensor 3: .79

Note that the individual sensors have an RMSD almost identical to their measurement error. That's exactly what we would expect. And, as we expected, the sensor fusion estimate does better than any of the individual ones. Because our sensor errors were positively correlated we made things harder on ourselves. Re-running the simulation without correlation consistently gives a Kalman RMSD of $\approx .40$. How did the bias impact our simulation? Calculating $\text{bias} = \overline{y - \hat{y}}$ gives:

  • Kalman: -.04
  • Sensor 1: -.08
  • Sensor 2: .22
  • Sensor 3: .003
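
For reference, these are straightforward to compute from the objects in the sensor_fusion.R script above, using the same actual minus estimate convention:

# Mean error (bias) of the smoothed estimate and of each raw sensor, measured
# against the true path. Continues from the script above.
est.bias <- c(mean(path - state.est), 
	apply(observed, 2, function(x) mean(path - x)))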

The Kalman filter was able to significantly overcome the bias in sensors one and two while still reducing variance. I specifically chose the bias-free sensor as the one with the most variance to make things as hard as possible. This helps to illustrate one very cool property of Kalman sensor fusion – the ability to capitalize on the bias-variance trade-off and mix biased estimates with unbiased ones.


Day Six: Persistent Data Structures

TL/DR

Persistent data structures have nothing to do with disks, durable storage, or databases. They're (externally) immutable collections born out of functional programming, but they have great use cases for any programming paradigm. Immutability helps greatly in multi-threaded environments, and regardless of the threading model used they're a natural means of adding versioning and snapshot functionality to a collection. This proves useful for everything from synchronizing data in distributed systems to implementing an undo operation. Clojure relies on them extensively and they're included in the standard libraries of many other functional languages.

High Level Overview

Persistent data structures can be "fully persistent" or "partially persistent." Fully persistent data structures allow for changes to any previous version and keep track of those versions via a change graph. Partially persistent ones only allow for changes to the current state, but read-only access to any previous state. We'll focus on partially persistent ones.

One of the simplest examples of a persistent data structure is a collection with copy-on-write semantics. If you've used Java's CopyOnWriteArrayList you've already seen this in action. Get operations are only slightly more expensive than ArrayList access (the backing array is volatile so there's an mfence whenever it's accessed, which happens once per get operation, and once for the entirety of an iteration). However, all operations that mutate do so by locking and copying the backing array in its entirety before swapping out the old array for the new one. This ensures that any other thread working with the list or iterating over it will see a consistent view of the backing array. There are two problems with this approach: 1) it's prohibitively expensive if the collection changes frequently, and 2) it still relies on locks to mutate state. Before addressing concurrency, let's see what we can do about the first problem.

Persistence via Path Copy

Without too much work we can do a little better than copy-on-write for most use cases via path copying:

[Figure: Tree path copy]

We start by structuring our collection as a d-ary tree (or a trie if ordering matters). When we need to mutate state we copy all nodes along the path containing the mutation. We then work backwards from the point of the change, fixing up references along the way so that everything along the mutated path holds a reference to the newly created node reflecting our insertion/deletion.

You can represent pretty much any standard type of collection as a tree (maps, sets, lists…), albeit with some level of inefficiency. For something like a vector the degenerate case is just a d-ary tree, where d is the length of the vector, leading us back to the copy-on-write semantics described above. As such, the degree of branching trades off between time complexity for access and mutation.
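
R isn't the natural home for this kind of structure (its copy-on-modify semantics hide the sharing), but the shape of the algorithm is easy to sketch. Below is a hypothetical illustration on a tiny binary tree built from nested lists: updating a leaf rebuilds only the nodes along the path to it, and every untouched subtree is carried over from the previous version.

# A node is just a list of children plus an optional value.
node <- function(left = NULL, right = NULL, value = NULL) {
	list(left = left, right = right, value = value)
}

# Return a new tree with the leaf at the given path ("l"/"r" steps) replaced.
# Only nodes along the path are rebuilt; the other subtrees are reused as-is.
assoc <- function(tree, path, value) {
	if (length(path) == 0) return(node(value = value))
	if (path[1] == "l") {
		node(left = assoc(tree$left, path[-1], value), right = tree$right)
	} else {
		node(left = tree$left, right = assoc(tree$right, path[-1], value))
	}
}

v1 <- node(left = node(value = 1), right = node(value = 2))
v2 <- assoc(v1, "r", 3)    # v2 sees the update; v1 is untouched
v1$right$value             # still 2
v2$right$value             # 3
v2$left$value              # carried over from v1: 1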

Real World

In practice, there are more complex but way more efficient ways to go. For arrays, an efficient implementation is described in this paper. Ctries are efficient, concurrent implementations of hash array mapped tries. For further study, Rich Hickey, the creator of Clojure, gave a very nice presentation on persistent data structures at QCon (Clojure uses them for all of its standard collections). It's worth noting that making lock free persistent collections is really hard without garbage collection, so this is one area where managed runtimes have a big advantage. What about atomic reference counting and shared_ptr? They're pretty expensive and don't scale well (especially on x86's strong memory model), which is one of the reasons why Java uses a tracing garbage collector instead of something more fine grained.

A Motivating Example

In my work with distributed systems, synchronizing state is almost always a core concern. How do you get a "late joiner" caught up with the current state of the universe? Reading from logs is one possibility. Another involves creating a snapshot whenever a client asks for it. Assume that you have a single threaded, event driven architecture in which the server handles one request at a time. Also assume that all messages have a unique sequence number (so gap filling and detecting missing messages for retransmission is handled elsewhere). Here's the problem:

  1. Client joins the network and starts queuing messages.
  2. Client: "Hey, I need the starting state of collection A."
  3. Server: "I'm on message 50 now, and here's your collection!" The server then stops the world to perform an expensive serialization operation.
  4. Client: "Thanks…I'll take collections B-Z while you're at it."
  5. Server: "A little busy right now…"
  6. Client: "Tough luck. Do it."

This isn't going to scale well. As an alternative we could have the client replay all messages that ever mutated the collection, but that might require lots of time and bandwidth. Is there a good compromise that still gives us total consistency?

The solution involves persistent collections. Whenever a client needs a snapshot, the server kicks off a worker thread and uses a memory fence to hand off the reference and ensure that all updates to that version of the collection are visible. Then, the worker thread handles serialization and transmission while the server proceeds with business as usual. If we explicitly add snapshot functionality to our collection we don't even need thread safety as the API should ensure that we're holding an immutable reference to a specific point in time, and not just the current head of the collection.

This is a pretty awesome way to accomplish a lot of tasks, even mundane ones like ensuring that a GUI or web client displays a table that's always in sync with the server. For more advanced use cases such as snapshot/replay based recovery we can extend this concept and have the server take a snapshot every n messages or minutes. Each time it does so the server stores a reference to the collection in a list. Then, when a client needs to get caught up or recover from an outage it can ask for a snapshot as-of a specific time/sequence number. The server would then return the most recent snapshot before the requested version, and the client would replay all subsequent messages to recover state. Datomic applies this model to the database world to greatly simplify many use cases and allow for queries against any point in history.