I was watching a video by Yannic Kilcher about xLSTM: Extended Long Short-Term Memory when something caught my attention. At one point, Yannic describes something called bidirectional associative memory, which sounds pretty abstract and research-y and boring. The description got some gears turning in my head, though, and I arrived at a ridiculous idea. I had to try it out.
The idea
I want to take the basic idea behind bidirectional associative memory and twist it into a probabilistic data structure. The original paper calls it BAM so I’m going to go ahead and use that term instead of typing out the whole name every time.
I want to create a key-value store that can hold more data than it really should be able to.
It will hold vectors that have d elements by storing them in a d x d matrix.
That means I should only be able to store d of these vectors. d vectors with d values each is d x d.
But I think I can use some randomness to store more than d vectors.
I’ll call it a Fuzzy Store.
I’m going into this assuming it will be terrible, but I want to know exactly how terrible. That means experimenting and benchmarking and researching, which are all things I enjoy.
Building a key-value store from my vague understanding of BAM
My memory starts as an empty 2 x 2 matrix.

$$M_0 = \begin{bmatrix} 0 & 0 \\ 0 & 0 \end{bmatrix}$$
Now, I want to store a key-value pair in this memory. The key is a vector with two elements, and so is the value.
$$K_1 = \begin{bmatrix} 1 \\ 0 \end{bmatrix}, \qquad V_1 = \begin{bmatrix} 8 \\ 2 \end{bmatrix}$$
To store them, I need their outer product.
$$V_1 K_1^T = \begin{bmatrix} 8 \\ 2 \end{bmatrix} \begin{bmatrix} 1 & 0 \end{bmatrix} = \begin{bmatrix} 8 & 0 \\ 2 & 0 \end{bmatrix}$$
Now, I add that to the memory matrix.
$$M_1 = M_0 + V_1 K_1^T = \begin{bmatrix} 0 & 0 \\ 0 & 0 \end{bmatrix} + \begin{bmatrix} 8 & 0 \\ 2 & 0 \end{bmatrix} = \begin{bmatrix} 8 & 0 \\ 2 & 0 \end{bmatrix}$$
If I want to get my value back, all I have to do is multiply my memory by my key.
$$M_1 K_1 = \begin{bmatrix} 8 & 0 \\ 2 & 0 \end{bmatrix} \begin{bmatrix} 1 \\ 0 \end{bmatrix} = \begin{bmatrix} 8 \\ 2 \end{bmatrix}$$
If I do the same thing with another key-value pair like
$$K_2 = \begin{bmatrix} 0 \\ 1 \end{bmatrix}, \qquad V_2 = \begin{bmatrix} 3 \\ 9 \end{bmatrix}$$
my memory will end up looking like
$$M_2 = \begin{bmatrix} 8 & 3 \\ 2 & 9 \end{bmatrix}$$

The first column is the value associated with the first key, and the second column is the value associated with the second key.
Retrieving a value using its key still works.
$$M_2 K_2 = \begin{bmatrix} 8 & 3 \\ 2 & 9 \end{bmatrix} \begin{bmatrix} 0 \\ 1 \end{bmatrix} = \begin{bmatrix} 3 \\ 9 \end{bmatrix}$$
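Here’s the same toy example as a quick NumPy sketch (my own code, just to make the mechanics concrete):

```python
import numpy as np

# Start with an empty 2 x 2 memory.
memory = np.zeros((2, 2))

# First key-value pair, stored by adding the outer product value * transpose(key).
key1, value1 = np.array([1.0, 0.0]), np.array([8.0, 2.0])
memory += np.outer(value1, key1)

# Second key-value pair.
key2, value2 = np.array([0.0, 1.0]), np.array([3.0, 9.0])
memory += np.outer(value2, key2)

# Retrieval is just memory times key.
print(memory @ key1)  # [8. 2.]
print(memory @ key2)  # [3. 9.]
```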
Writing this all out using notation from linear algebra helps it make a little more sense. Adding the first value to my memory looks like this multiplication.
$$M_1 = V_1 K_1^T$$

If I want to get the value back, it would be really nice to have something to cancel out that transposed key.
Luckily, I have just the thing to turn it into a 1.
Multiplying a vector by its transpose is the same as taking the dot product of the vector with itself. Some quick intuition for the dot product is that it asks, “How much of this vector is going in the same direction as this other vector?” If you ask that about the same vector, the answer is, “All of it,” so you end up with the length (magnitude).
Technically, you get the magnitude squared.
But 1 squared is still 1.
As long as my key vectors have magnitude 1, I’m good to go.
I can multiply my memory by the key vector to cancel everything, leaving the value I stored in the memory.
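Putting that together for the first pair (with $\|K_1\| = 1$):

$$M_1 K_1 = V_1 K_1^T K_1 = V_1 \|K_1\|^2 = V_1$$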
And we can keep going! Here’s what happens when I add a second key-value pair to the memory.
$$M_2 = M_1 + V_2 K_2^T$$

Just like before, I’ll multiply by the second key to retrieve the second value.
$$M_2 K_2 = \left( V_1 K_1^T + V_2 K_2^T \right) K_2$$

Things look a little different now that I have multiple key-value pairs in my memory. I’ll distribute $K_2$ to cancel $K_2^T$.
$$M_2 K_2 = V_1 K_1^T K_2 + V_2 K_2^T K_2$$

I get my second value back, but there’s some extra stuff at the front there.
$$M_2 K_2 = V_1 K_1^T K_2 + V_2$$

But wait, the first example stored and retrieved two values perfectly—there wasn’t any error. What’s going on?
If two vectors are perpendicular to each other (orthogonal), like my keys were in that example, their dot product is 0.
Plugging in that 0 leaves me with the second value exactly how I stored it.
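For the toy example, the keys were orthogonal, so $K_1^T K_2 = 0$ and the error term vanishes entirely:

$$M_2 K_2 = V_1 \left( K_1^T K_2 \right) + V_2 \left( K_2^T K_2 \right) = V_1 \cdot 0 + V_2 \cdot 1 = V_2$$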
And there you have it. A key-value store.
It gets better
This probably sounds like a roundabout way to store vectors in an array.
I have my vectors in a group and then I grab the one I want based on the index of the 1 in my key.
But I don’t have to pick nice pretty vectors that are all 0 with one 1.
If you remember from above, all I need are orthogonal vectors with magnitude 1.
I can always normalize a vector to give it magnitude 1.
I just need some way to generate a ton of orthogonal vectors.
The simplest way to generate orthogonal vectors is to do what I did in my first example.
Make every vector all 0s with a single 1, and make sure every vector has the 1 at a different index.
That’s all well and good, but there are only d unique places to put the 1.
I need more than d.
Otherwise, what am I doing here?
This is where the randomness comes in.
In high dimensions, random vectors tend to be roughly orthogonal.
That gives me just enough to cobble together my Fuzzy Store.
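A quick way to convince yourself of that near-orthogonality claim (a minimal NumPy sketch of my own, not anything official):

```python
import numpy as np

rng = np.random.default_rng(0)

def random_unit_vector(d):
    """Random direction: sample a Gaussian vector, then normalize to magnitude 1."""
    v = rng.standard_normal(d)
    return v / np.linalg.norm(v)

# Average absolute dot product between pairs of random unit vectors.
for d in (10, 100, 1000):
    dots = [abs(random_unit_vector(d) @ random_unit_vector(d)) for _ in range(1000)]
    print(d, np.mean(dots))  # shrinks as the dimension grows
```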
Everything should cancel out and work perfectly as long as I pick random high-dimensional key vectors and normalize them.
That will let me store way more than the d vectors I should be able to.
Approximately. Fuzzily.
The only way to know whether that’s too good to be true is to give it a whirl.
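Here’s a minimal sketch of what I mean by a Fuzzy Store. The class name and methods are just mine for this post: store adds an outer product to the memory, retrieve multiplies the memory by the key.

```python
import numpy as np

class FuzzyStore:
    """A d x d memory that stores key-value pairs as summed outer products."""

    def __init__(self, d, seed=0):
        self.d = d
        self.memory = np.zeros((d, d))
        self.rng = np.random.default_rng(seed)

    def random_key(self):
        """A random high-dimensional key, normalized to magnitude 1."""
        key = self.rng.standard_normal(self.d)
        return key / np.linalg.norm(key)

    def store(self, key, value):
        """Add value * transpose(key) to the memory."""
        self.memory += np.outer(value, key)

    def retrieve(self, key):
        """Multiply the memory by the key to (approximately) recover the value."""
        return self.memory @ key
```

Nothing is ever removed: every pair I store gets smeared into the same d x d matrix.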
Benchmarking
Here’s what retrieval accuracy looks like when I use random keys and values.
Each point on this plot represents the mean retrieval accuracy for a different Fuzzy Store.
Points on the same line use the same value for d, meaning they have the same size memory.
Notice the x-axis!
I’m storing more than d vectors in my Fuzzy Store.
Sort of.
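Each point comes from a measurement loop roughly like this (a sketch that assumes the FuzzyStore class from earlier; the real benchmark code isn’t shown here):

```python
import numpy as np

def mean_retrieval_accuracy(d, n_pairs, seed=0):
    """Store n_pairs random key-value pairs, then report the mean cosine
    similarity between each stored value and the value retrieved by its key."""
    store = FuzzyStore(d, seed=seed)
    rng = np.random.default_rng(seed + 1)

    pairs = []
    for _ in range(n_pairs):
        key = store.random_key()
        value = rng.standard_normal(d)
        store.store(key, value)
        pairs.append((key, value))

    similarities = []
    for key, value in pairs:
        retrieved = store.retrieve(key)
        cos = retrieved @ value / (np.linalg.norm(retrieved) * np.linalg.norm(value))
        similarities.append(cos)
    return float(np.mean(similarities))

for n in (100, 500, 1000, 2000):
    print(n, mean_retrieval_accuracy(d=1000, n_pairs=n))
```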
As I store more things in a Fuzzy Store, it gets harder and harder to get my values back. Every time I add a key-value pair, I’m also adding another error term to that retrieval calculation. My random key vectors are only almost orthogonal, not exactly orthogonal. Each error term is pretty small, but they’re not quite zero. They add up.
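Writing the retrieval out with n pairs in the memory makes those error terms explicit:

$$M_n K_j = \sum_{i=1}^{n} V_i K_i^T K_j = V_j + \sum_{i \neq j} V_i \left( K_i^T K_j \right)$$

Each $K_i^T K_j$ with $i \neq j$ is close to zero for random keys, but there are n - 1 of them.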
To counteract that effect, I can use larger keys and values to get into higher dimensional space where my random keys will be more orthogonal. As the keys get more orthogonal, the error terms get smaller and the values I get out of the Fuzzy Store are closer to the values I put in.
It’s not very good, but it’s a little better than I expected.
A cosine similarity of 0 means I’m retrieving none of the information I stored in the Fuzzy Store.
You can see that the accuracy plummets towards 0 right from the start.
Once I get up to the 1000 x 1000 memory, it kind of works a little bit.
I can get a cosine similarity of about 0.56 even when I’m storing twice as many key-value pairs as the 1000 I should be able to.
How bad is it?
Looking at the plot of retrieval accuracy, my instinct is oof.
That intuition is great, but it’s not quantifiable. If I want to understand what’s happening and why, I need to do some analysis. And by “analysis,” I mean guessing and checking every decay function I can think of.
After about a month of experimenting, I landed on a proper fit.
There were some long digressions in there, including a week where I developed a general method to rotate any function of the form y = f(x, params) around the origin by a single parameter theta in a way that still works with scipy.optimize.curve_fit().
I didn’t end up using that, but that’s what experimenting is all about.
You can do some experimenting yourself on the plot below. It’s a bit like an interactive version of the classic xkcd 2048: Curve-Fitting.
The retrieval accuracy scales like one over the square root of the number of key-value pairs in the Fuzzy Store. In other words, O(1/√n).
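The fit itself is the easy part once you’ve guessed the form. Here’s a sketch with scipy; it assumes the mean_retrieval_accuracy function from the benchmark sketch above, and a plain a / sqrt(n) + b is just my stand-in for the exact parameterization I landed on.

```python
import numpy as np
from scipy.optimize import curve_fit

def inverse_sqrt_decay(n, a, b):
    """Decay of the form a / sqrt(n) + b."""
    return a / np.sqrt(n) + b

# Accuracy measurements from runs like the benchmark sketch above.
n_pairs = np.array([100, 200, 500, 1000, 2000])
accuracy = np.array([mean_retrieval_accuracy(d=1000, n_pairs=int(n)) for n in n_pairs])

(a, b), _ = curve_fit(inverse_sqrt_decay, n_pairs, accuracy)
print(a, b)
```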
Cooking up a theory
What else scales like the inverse square root? Standard error comes to mind. If you’re like me and have a vague recollection of hearing that term, but don’t really remember what makes it different from standard deviation, there’s a really good description here. The short version is that the standard error tells you how bad your guess for a population statistic will be based on the sample size you’re using to make that guess.
Here’s the basic equation for calculating the standard error of your estimated statistic.
standard_error = standard_deviation_of_population / sqrt(size_of_sample)
If you take a sample of 4 items from a population that has a standard deviation of 1 and use the mean of that sample as an estimate for the mean of the population, the standard error of your estimate is 1 / sqrt(4) = 0.5.
You can use a larger sample size to get a smaller standard error.
Since it scales like 1 / sqrt(n), you need to quadruple your sample size to halve your standard error.
To get the standard error down to 0.25 in this example, you would need a sample size of 16 since 1 / sqrt(16) = 0.25.
I think I can tie this back to the Fuzzy Store.
In general, multiplying a random vector by a random matrix should result in a random vector.
I shouldn’t be able to predict the outcome of that multiplication.
But in the Fuzzy Store, I’m deliberately constructing the matrix in a way that helps me predict the outcome.
That predictability is essentially statistical error.
In the “population” of all matrices and vectors, the mean cosine similarity between my prediction and the result of the multiplication should be 0.
In my sample (the Fuzzy Store), I have some error because I’m not using every possible vector and matrix. I’m only using the ones I’ve stored in the Fuzzy Store. In the Fuzzy Store, that error isn’t random. That error is the ability to get the value back using the key.
As I store more and more key-value pairs in the Fuzzy Store, my sample size gets larger and the standard error goes down.
With a larger sample size, I get a better estimate for the true population mean of cosine similarity, which should be 0.
I have less error to work with, so I can’t use the error to store key-value pairs as well anymore.
I slowly approach the true population of completely random results at a rate of 1 / sqrt(n) as I store more key-value pairs in the Fuzzy Store.
I think this also explains why the performance gets better when I use a larger matrix in the Fuzzy Store.
The standard deviation of the population is in the numerator of the error I have to work with.
If I look at all possible random vectors with length 1000, they have a higher standard deviation than all possible vectors with length 10.
Vectors can point in a lot more directions if they have a thousand dimensions to choose from instead of only ten.
This explanation is pretty hand-wavy because I don’t want to think about how to define the “direction” of a vector in a way that makes it easy to calculate the standard deviation. It makes intuitive sense to me. That’s only let me down like 50% of the time. Ship it.
What if I use orthogonal keys?
Let’s see what happens when I pick orthogonal keys instead of hoping random keys will be orthogonal enough.
I can only create d orthogonal vectors for each Fuzzy Store, so I’ll fall back to random keys once I’ve used up all the orthogonal ones.
Using keys that are all 0 with a single 1, like in the toy example I used to explain this concept, I can store d key-value pairs with perfect retrieval accuracy.
Once I’ve saturated the space by using up all d basis vectors, adding more key-value pairs knocks the retrieval accuracy off a cliff.
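The only change from the random-key benchmark is how the keys get picked. Here’s a sketch of the key generator I’d swap in for store.random_key() in the benchmark loop above (again, my own illustrative code):

```python
import numpy as np

def orthogonal_then_random_keys(d, n, rng):
    """Yield the d standard basis vectors first, then random unit vectors."""
    for i in range(n):
        if i < d:
            key = np.zeros(d)
            key[i] = 1.0
        else:
            key = rng.standard_normal(d)
            key /= np.linalg.norm(key)
        yield key
```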
The smaller Fuzzy Stores don’t explode as aggressively, and seem to recover more quickly.
I don’t know why it rebounds a bit.
That behavior intrigues me.
I expected it to drop to 0 at d key-value pairs and stay there, but it comes back up a bit after the tumble.
The smaller Fuzzy Stores do the same thing at a smaller scale.
I’ve run these simulations many times while designing these plots.
They always rebound like this, so it’s not just a fluke.
Really interesting.
Another time, I suppose.
What did I learn?
I learned that I can, in fact, store more than d x d things in a d x d matrix.
Approximately.
It was also good to do some simple linear algebra again. Keeping up with the breakneck pace of machine learning research has helped my intuition, but it’s important for me to put things into practice every now and then. That’s the best way for me to remember how things actually work.
In the spirit of treating this like an actual probabilistic data structure, I’ll list some of the benefits and drawbacks to the Fuzzy Store.
| Pros | Cons |
|---|---|
| Can store more than d x d things in a d x d matrix | Retrieval accuracy decreases like O(1/√n) |
| O(1) memory if you ignore the keys that someone else has to store. Not my problem. | That constant memory is pretty large (d²) |
| O(1) computational complexity for all operations | That constant computation always involves a large matrix multiplication (d² again) |
Honestly, that’s not the worst stat sheet I’ve ever seen. Looks pretty webscale to me.