Saturday, July 28, 2007

Re: MIT HAKMEM BITCOUNT

This is a reply to the MIT HAKMEM article that I posted earlier.
In that article I mentioned that I did not realize what they were doing, but I recently stumbled on what is actually going on. They are doing the same thing I was aiming for, just in a sneaky way.

They are actually performing the Hamming weight algorithm using subtraction. On most architectures, subtraction and addition cost about the same number of clock cycles, though I should probably benchmark the two solutions.

To get an overall sense of what they are doing, let's first examine 3-bit numbers. This will make it easier to extend to a 32-bit number.

Say the 3 lower bits are all 1's, so the number is 7. We want the count of set bits in this 3-bit number; since all three are set, the answer should be 3.

There's a technique that you can use to count the number of bits:
x - (x>>1) - (x>>2), as mentioned in the previous article.
So that's all we have to do:

x = 111 (value of 7 in decimal)
x>>1 = 011 (value of 3 in decimal)
x>>2 = 001 (value of 1 in decimal)
so therefore x - x>>1 - x>>2 = 7 - 3 - 1 = 3, which is the correct number of bits set.
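
To convince yourself the identity holds for every 3-bit value, here is a minimal check (the function name count3 and the test harness are mine, not from the original):

#include <assert.h>

/* Count the set bits of a 3-bit value via subtraction:
   popcount(x) = x - (x>>1) - (x>>2) for 0 <= x <= 7. */
static unsigned count3(unsigned x) {
    return x - (x >> 1) - (x >> 2);
}

int main(void) {
    /* check all eight 3-bit values against a naive count */
    for (unsigned x = 0; x < 8; ++x) {
        unsigned naive = (x & 1) + ((x >> 1) & 1) + ((x >> 2) & 1);
        assert(count3(x) == naive);
    }
    return 0;
}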

To extend this to a 32-bit number, you have to notice one small thing.
When I shifted for x>>1 and x>>2, a zero was shifted in at the top.
Inside each 3-bit chunk of a 32-bit number you need the same effect, but a plain shift drags bits over from the neighboring chunk instead of zeros, meaning something like this:

say x = 101 111 (in binary, split into 3-bit segments)
x>>1 = 010 111
x>>2 = 001 011
x - x>>1 - x>>2 is not going to give us the effect we wanted, because bits have leaked across the chunk boundaries, so we have to mask out the set bits that would keep us from getting the correct result.

Therefore, looking back at the 3-bit version, the bits that were zero (0) in each term were:
x (no zeroed bits)
x>>1 (the top bit (3rd bit) was zeroed out)
x>>2 (the top 2 bits (3rd and 2nd bits) were zeroed out)

Therefore, to get the 32-bit behavior that we want we need the right masks: zero out the top bit of every 3-bit chunk for x>>1, and the top 2 bits of every 3-bit chunk for x>>2.
Surprisingly, if we use the octal numbering system, this is quite easy.
(The octal digit that zeroes out the top bit is just 3, and the one that zeroes out the top 2 bits is just 1.
This is because an octal digit is exactly 3 bits:
111 = 7
011 = 3
001 = 1)

Therefore you can think of the algorithm as actually doing something like:
x = (x&037777777777) - ((x>>1)&033333333333) - ((x>>2)&011111111111)

which is equivalent to x - ((x>>1)&033333333333) - ((x>>2)&011111111111)

And then each 3-bit chunk holds how many bits are set in it, so we just do a 3-bit right shift and addition, mask with 030707070707, and mod by 63 to get the final answer.
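
To make that concrete, here is the 32-bit version wrapped into a function (a small sketch; the wrapper name bitcount32 is mine, the two expressions are HAKMEM's, quoted later in this post):

/* Count set bits in a 32-bit value, HAKMEM style. */
unsigned bitcount32(unsigned x) {
    /* each octal digit (3-bit chunk) of x now holds its own popcount */
    x = x - ((x >> 1) & 033333333333) - ((x >> 2) & 011111111111);
    /* merge adjacent 3-bit chunks into 6-bit chunks, then sum them mod 63 */
    return ((x + (x >> 3)) & 030707070707) % 63;
}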

Here is my 64-bit version in C,

/* NIBBLE_LSB selects the lowest bit of every 4-bit chunk (nibble);
   BYTE_LO keeps the low nibble of every byte. */
#define NIBBLE_LSB 0x1111111111111111ull
#define BYTE_LO    0x0f0f0f0f0f0f0f0full

/* Sum the bits of each nibble of x: each nibble of the result
   holds the popcount of the corresponding nibble of x. */
#define HAMMING(x) \
( \
    (((x)     ) & NIBBLE_LSB) + \
    (((x) >> 1) & NIBBLE_LSB) + \
    (((x) >> 2) & NIBBLE_LSB) + \
    (((x) >> 3) & NIBBLE_LSB) \
)

/* Merge adjacent nibble counts into bytes, then sum the bytes
   by taking the result mod 255 (since 256 = 1 mod 255). */
#define MIT(x) \
( \
    ((HAMMING(x) + (HAMMING(x) >> 4)) & BYTE_LO) % 255 \
)

The code MIT(x) will yield the solution where x is a 64-bit number. I had to change the implementation a little between 32-bit and 64-bit numbers: for 64-bit numbers, you need to hold the sums in 4-bit chunks instead, and therefore mod'ing by 63 is no longer correct, but 255 is, and the masks have also changed accordingly. This was my way of implementing it from the last post; the code looks nearly equivalent:

x = x - ((x>>1) & 0x7777777777777777) -
        ((x>>2) & 0x3333333333333333) -
        ((x>>3) & 0x1111111111111111);
x = ((x + (x>>4)) & 0x0f0f0f0f0f0f0f0f) % 255;
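
Wrapped up as a compilable function with a quick sanity check (the function name popcount64 and the test values are mine):

#include <assert.h>
#include <stdint.h>

/* 64-bit popcount: per-nibble counts via subtraction, then sum bytes mod 255. */
unsigned popcount64(uint64_t x) {
    x = x - ((x >> 1) & 0x7777777777777777ull)
          - ((x >> 2) & 0x3333333333333333ull)
          - ((x >> 3) & 0x1111111111111111ull);
    return ((x + (x >> 4)) & 0x0f0f0f0f0f0f0f0full) % 255;
}

int main(void) {
    assert(popcount64(0) == 0);
    assert(popcount64(0xffffffffffffffffull) == 64);
    assert(popcount64(0x8000000000000001ull) == 2);
    return 0;
}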

q.e.d.

Monday, July 2, 2007

fls implementation

fls stands for find last bit set; here that means the rightmost (least significant) bit that is set in an unsigned int.
For example, the function signature looks like unsigned fls( unsigned ); it returns the bit index of the lowest set bit of the unsigned integer that you pass to it.

There are a lot of approaches to this. (I am not going to cover the binary search/scan across the unsigned integer, though it is the classical solution.)

First and foremost, I will explain what x & ~(x-1) actually means.
When x is a power of 2, the check (x & (x-1)) == 0 is a trick that is employed to verify that x is indeed a power of 2 (note the parentheses: == binds tighter than & in C). But why? What is actually taking place?

The simple fact is that x-1 takes the rightmost bit that is a '1' (the lowest set bit), converts it to a '0', and sets all the bits after it (to the right) to '1'.

This explains why, if x is a power of 2, x-1 sets all the bits below the single set bit to '1' and the bit that was initially set to '0'. Therefore, if that is the case, what does x & ~(x-1) mean?

Well, that's simple. The lowest set bit of x becomes a '0' in x-1, and that is the first position where x and x-1 differ. Every bit above it (toward the MSB) is the same in both, and every bit below it is the opposite. Therefore, if you flip the bits of x-1, which is ~(x-1), and AND (&) it with x, the bits above the lowest set bit all differ between x and ~(x-1), so they AND to '0', and the bits below it, which were '1's in x-1, become '0's after the flip. But the lowest set bit of x, which was a '0' in x-1, gets turned back into a '1' by the flip, and is therefore the only bit that survives the AND.

So there we have it, x & ~(x-1) actually yields only one bit set, namely that of lsb that was a '1'.
Therefore, if we know this, there is a simple algorithm to perform:


#include <math.h>

unsigned fls( unsigned x ) {
    /* x & ~(x-1) isolates the lowest set bit; log2 gives its index */
    return (unsigned)log2( (double)(x & ~(x-1)) );
}


This will give us the index of the last bit that is set.

However, this version is slow because it goes through the SIMD/floating-point unit on Intel's architecture. Context switches and first-use faults make this operation very slow compared to the other solutions that try to speed the fls() operation up.
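
For comparison, one of those faster routes, assuming a GCC- or Clang-style compiler, is the count-trailing-zeros builtin, which maps to a bit-scan instruction on x86. This builtin is my addition, not part of the original post:

/* Index of the lowest set bit via the compiler builtin.
   __builtin_ctz is undefined for 0, so guard that case. */
unsigned fls_builtin( unsigned x ) {
    return x ? (unsigned)__builtin_ctz(x) : 0;
}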

However, there is another solution. A solution that is quite unique. Something that is quite a hack, but the speed is incredibly astounding. Let's consider a char, call it c.
If c is split into 3-bit chunks, and only one bit is set in c, then it's rather easy to find the index by dividing by 2. For example,

100 = 4
4/2 = 2 which happens to be the index of the bit set

010 = 2
2/2 = 1 which happens to be the index of the bit set

001 = 1
1/2 = 0 which happens to be the index of the bit set

Therefore, having proved it for small numbers, shifting down by the right amount and adding the shift offset back gives the correct result for bigger ones. This indexing is quite fast. Also, dividing by 2 is the same as right shifting by 1, so instead of right shifting by a multiple of 3 and then dividing by 2, we can right shift by the multiple of 3 plus 1. The implementation is below:

#define BITS 8

unsigned fls( unsigned x ) {
    x &= ~(x-1);   /* isolate the lowest set bit */

    /* scan byte by byte for the one nonzero byte */
    for( unsigned i = 0; i < sizeof(x); ++i ) {
        unsigned y = (x >> (BITS*i)) & 0xff;

        // make sure it's not zero (0)
        if( !y ) continue;

        /* 3-bit chunks of the byte: 07, 070, 0700 (the top chunk has
           only 2 bits); the extra shift folds in the divide-by-2 */
        unsigned ret = (y & 07)  ? ((y & 07) >> 1)
                     : (y & 070) ? (((y & 070) >> 4) + 3)
                     :             (((y & 0700) >> 7) + 6);

        return ret + (BITS*i);
    }
    return 0;   /* x was 0: no bit set */
}


Therefore, (BITS*i) shifts the index so that it counts from the start of the byte where the set bit was found.
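
And a quick usage check of the byte-scan version above (the test values are mine):

#include <assert.h>

int main(void) {
    assert( fls(0x00000001u) == 0 );   /* lowest set bit is bit 0  */
    assert( fls(0x00000014u) == 2 );   /* 0b10100: lowest is bit 2 */
    assert( fls(0x80000000u) == 31 );  /* only bit 31 is set       */
    return 0;
}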

MIT HAKMEM BITCOUNT re-do:

Ok, so there's this famous question of how to count the bits in a number.
This is easily done by looping through the number, AND-ing with 0x1, checking if it's nonzero, and incrementing a counter (see the sketch below), but let's do it with constant memory and in constant time, without the loop.
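
That naive loop looks something like this (a quick sketch; the function name count_naive is mine):

/* Naive bit count: test each bit in turn. */
unsigned count_naive( unsigned x ) {
    unsigned n = 0;
    while( x ) {
        n += x & 0x1;   /* add the lowest bit */
        x >>= 1;
    }
    return n;
}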

I've read through a lot of tutorials that try to explain this, and they're overly complicated, imho.
I can give a simpler solution that makes sense to anybody. It uses almost the same concepts as the MIT HAKMEM solution.

So I guess I should give the HAKMEM MIT AI Labs Bit Count solution first:
(Let's assume the number you want to count is always the variable x)
x = x - ((x>>1)&033333333333) - ((x>>2)&011111111111);
x = ((x+(x>>3)) & 030707070707) % 63;
x now holds the bit count.

This technique uses the fact that if you have a 3-bit number then all you have to do is:
x - (x>>1) - (x>>2)
so if you have a 32 bit number it's just:
x - (x>>1) - ... - (x>>31), or more generally if you have k-bit number:
x - (x>>1) - ... - (x>>(k-1))
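
This identity is easy to check by brute force for the 32-bit case; a minimal sketch (the test harness and values are mine, not from the original):

#include <assert.h>

int main(void) {
    /* verify popcount(x) == x - (x>>1) - ... - (x>>31) for a few values */
    unsigned tests[] = { 0, 1, 7, 0xdeadbeef, 0xffffffffu };
    for (unsigned t = 0; t < sizeof(tests)/sizeof(tests[0]); ++t) {
        unsigned x = tests[t];
        unsigned sub = x, naive = 0;
        for (unsigned k = 1; k < 32; ++k)
            sub -= x >> k;
        for (unsigned k = 0; k < 32; ++k)
            naive += (x >> k) & 1;
        assert(sub == naive);
    }
    return 0;
}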

Somehow, (x>>1) & 033333333333 is supposed to give you the values x>>1, x>>3, ...,
and (x>>2) & 011111111111 is supposed to give you the values x>>2, x>>4, ...;
therefore, it will act like x - (x>>1) - ... - (x>>31) and give the solution, almost.

There is one last trick that they do. The counts end up stored in 3-bit chunks.
However, we want to extract them. One way is to get them into 6-bit chunks whose upper 3 bits are all zero; then we can just take the number mod 63 (2^6 - 1) and it will sum the chunks for us (this works because 2^6 = 64 = 1 mod 63, so every 6-bit chunk gets weight 1), which is exactly what this algorithm does. To get the counts into the low 3 bits of each 6-bit chunk, you mask the sum with something to the effect of 00707070707...; however, the highest chunk of a 32-bit word only has 2 bits left, so that's why the mask is actually 030707070707.

(But I never really got how they get (x>>1) + (x>>3) + (x>>5) by doing x>>1 & 03333....)
But that's the moral of the story.

That's why I came up with another implementation:

My implementation is a little easier if you are familiar with the Hamming weight algorithm.
What you do is add up the bits at each of the three positions of every 3-bit chunk, sum the chunks, mask with 030707070707, and then mod by 63 again.

Therefore, here is my solution with the utilization of the hamming weight algorithm:
(again assuming the value is in x):

x = (x&011111111111) + ((x>>1)&011111111111) + ((x>>2)&011111111111);
x = ((x+(x>>3)) & 030707070707) % 63;


unsigned count( unsigned x ) {
    /* each 3-bit chunk ends up holding its own bit count */
    x = ( x     & 011111111111) +
        ((x>>1) & 011111111111) +
        ((x>>2) & 011111111111);
    /* fold pairs of 3-bit chunks into 6-bit chunks, then sum them mod 63 */
    return ((x + (x>>3)) & 030707070707) % 63;
}
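
And a quick check that it behaves (the test values are mine):

#include <assert.h>

int main(void) {
    assert( count(0) == 0 );
    assert( count(0xff) == 8 );
    assert( count(0xffffffffu) == 32 );
    return 0;
}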



The Hamming weight step can be visualized like this:
0101 0111 1110 0001 = x
If you & 011111.... (shown here as its low 16 bits) with x:
0101 0111 1110 0001
1001 0010 0100 1001
= 0001 0010 0100 0001
which keeps every 3rd bit of x (starting from the last, the LSB) and zeroes everything else.
If you do the same with x>>1, this keeps every 3rd bit starting with the 2nd-to-last bit, and doing the same with x>>2 respectively starts with the 3rd-to-last bit.

Therefore if you add them, each 3-bit chunk will hold the sum of its bits. However, you then have to shift right by 3 and add, so that pairs of adjacent 3-bit chunks land in 6-bit chunks, and then mask to zero out the upper 3-bit half of each new 6-bit chunk, because each of the original 3-bit chunks already held the correct value:

  001 011 (looking at it as 3-bit chunks)
+ 010 000
= 011 011 (1+2=3 in the left chunk and 3+0=3 in the right; now we want to merge them, so shift the result right by 3 and add it to itself)
+ 000 011
= 011 110. The merged count sits in the right 3-bit chunk, so let's mask out the top 3-bit chunk, since it holds a partial sum we don't want (we are summing (chunk 0 + chunk 1) + (chunk 2 + chunk 3) + ...).
Therefore we mask with 030707070707..., and when we mod by 63, it takes the 6-bit chunks and adds their values, essentially (since the top 3 bits of each 6-bit chunk are always 0).