Here is what I came up with. It uses a general shift function, so you can shift by negative numbers. `src` is the attacker square, `occ` is the set of all occupied squares, `dir` is the attacking direction (the shift applied per step), and `n` is the maximum number of steps. So n=1 for king/knight, n=7 for bishop/rook/queen, and n=2 for a two-square pawn push. Can it be simplified?
Code:
// Step/ray attack generator (a "dumb7fill"-style flood).
//   src : bitboard of the attacking piece(s)
//   occ : all occupied squares
//   dir : shift applied per step (e.g. +9 = one square north-east, -8 = south)
//   n   : maximum number of steps (1 for king/knight, 7 for sliders)
// Returns every square reached. The first occupied square on each ray is
// included (it may hold either colour, so mask with ~own_pieces if needed).
// Each ray stops independently, so src may hold several pieces at once.
uint64 attacks (uint64 src, uint64 occ, int dir, size_t n = 7)
{
    uint64 a = uint64(0);
    // Destination squares a step in this direction may land on without having
    // wrapped across the a/h edge; directions missing from the table get no
    // edge mask (correct for straight vertical steps such as +-8, +-16).
    const uint64 keep = map_lookup(table, dir, all_squares);
    const uint64 empty = ~occ;
    // Stop early once every ray has been blocked or has left the board.
    for (size_t i = 0; i < n && src; ++i) {
        src = shift(src, dir) & keep;
        a |= src;
        // Only rays that landed on an empty square keep sliding; for a single
        // attacker this is equivalent to breaking out on the first blocker.
        src &= empty;
    }
    return a;
}
Complete test program:
#include <iostream>
#include <string>
#include <map>
// 64-bit bitboard type: one bit per square.
using uint64 = unsigned long long;
// Returns the value mapped to find_me in the_map, or default_val when the key
// is absent. The map and key are taken by const reference: the original
// by-value parameter copied the entire map on every call, and attacks()
// performs such a lookup for every ray it generates.
template <typename MAP_T, typename SEARCH_T, typename RETURN_T>
RETURN_T map_lookup (const MAP_T& the_map, const SEARCH_T& find_me, RETURN_T default_val) {
    const auto found = the_map.find(find_me);
    if (found != the_map.end())
        return found->second;
    return default_val;
}
// Bitboard with only square i set (a1 = 0, h1 = 7, a2 = 8, ... h8 = 63).
// i must lie in [0, 63].
uint64 mask (int i) {
    const uint64 single_bit = 1;
    return single_bit << i;
}
// Renders bitboards as an 8x8 text diagram, top row = highest rank, file a on
// the left, one line per rank. A square in `bits` prints '#'; otherwise a
// square in `bits2` prints 'X'; anything else prints '.'.
std::string bbstr ( uint64 bits, uint64 bits2 = uint64(0) )
{
    std::string board;
    for (int rank = 7; rank >= 0; --rank) {
        for (int file = 0; file < 8; ++file) {
            const uint64 sq = mask(rank * 8 + file);
            board += (bits & sq) ? '#' : (bits2 & sq) ? 'X' : '.';
        }
        board += '\n';
    }
    return board;
}
// Shifts bitboard b by n bits: left (towards higher square indices) for
// n > 0, right for n < 0, unchanged for n == 0. A shift of 64 or more in
// either direction yields an empty board. The guard avoids the undefined
// behaviour of a built-in shift whose count is >= the operand width, and the
// signed overflow of -n when n == INT_MIN.
uint64 shift (uint64 b, int n) {
    if (n >= 64 || n <= -64)
        return 0;
    return (n >= 0) ? (b << n) : (b >> -n);
}
// File masks (a1 = bit 0, so file a is the lowest bit of every rank byte).
// Compile-time constants: nothing should be able to modify them at runtime.
constexpr uint64 file_a = 0x0101010101010101;
constexpr uint64 file_b = 0x0202020202020202;
constexpr uint64 file_g = 0x4040404040404040;
constexpr uint64 file_h = 0x8080808080808080;
constexpr uint64 all_squares = 0xffffffffffffffff;

// Wrap-around guard per shift amount. After shifting a bitboard by the key,
// AND it with the value to discard bits that wrapped across the a/h edge onto
// the opposite side of the board. Straight vertical shifts (+-8, +-16) cannot
// wrap and are intentionally absent; the lookup in attacks() falls back to
// all_squares for them.
// Positive keys move towards higher ranks (north), negative towards lower.
const std::map< int, uint64 > table = {
    { 17, ~file_a },             // knight: 2 north, 1 east
    { 15, ~file_h },             // knight: 2 north, 1 west
    { 10, ~(file_a | file_b) },  // knight: 1 north, 2 east
    { 9, ~file_a },              // north-east
    { 7, ~file_h },              // north-west
    { 6, ~(file_g | file_h) },   // knight: 1 north, 2 west
    { 1, ~file_a },              // east
    { -1, ~file_h },             // west
    { -6, ~(file_a | file_b) },  // knight: 1 south, 2 east
    { -7, ~file_a },             // south-east
    { -9, ~file_h },             // south-west
    { -10, ~(file_g | file_h) }, // knight: 1 south, 2 west
    { -15, ~file_a },            // knight: 2 south, 1 east
    { -17, ~file_h },            // knight: 2 south, 1 west
};
// Step/ray attack generator (a "dumb7fill"-style flood).
//   src : bitboard of the attacking piece(s)
//   occ : all occupied squares
//   dir : shift applied per step (e.g. +9 = one square north-east, -8 = south)
//   n   : maximum number of steps (1 for king/knight, 7 for sliders)
// Returns every square reached. The first occupied square on each ray is
// included (it may hold either colour, so mask with ~own_pieces if needed).
// Each ray stops independently, so src may hold several pieces at once.
uint64 attacks (uint64 src, uint64 occ, int dir, size_t n = 7)
{
    uint64 a = uint64(0);
    // Destination squares a step in this direction may land on without having
    // wrapped across the a/h edge; directions missing from the table get no
    // edge mask (correct for straight vertical steps such as +-8, +-16).
    const uint64 keep = map_lookup(table, dir, all_squares);
    const uint64 empty = ~occ;
    // Stop early once every ray has been blocked or has left the board.
    for (size_t i = 0; i < n && src; ++i) {
        src = shift(src, dir) & keep;
        a |= src;
        // Only rays that landed on an empty square keep sliding; for a single
        // attacker this is equivalent to breaking out on the first blocker.
        src &= empty;
    }
    return a;
}
// Union of the four diagonal rays (shifts -9, -7, +7, +9) from src, each up
// to 7 steps long and including the first occupied square it meets.
uint64 bishop_attacks (uint64 src, uint64 occ) {
    const int diagonals[] = { -9, -7, 7, 9 };
    uint64 result = 0;
    for (const int diagonal : diagonals)
        result |= attacks(src, occ, diagonal, 7);
    return result;
}
// Demo: prints the occupancy and the bishop attack set for a bishop on e4.
int main (int argc, char * argv[]) {
    // Squares are numbered rank by rank from a1 = 0 (index = 8 * rank + file):
    // 28 = e4 (the bishop), 9 = b2, 14 = g2, 49 = b7, 54 = g7.
    uint64 src = mask(28);
    uint64 occupied = src | mask(9) | mask (14) | mask(54) | mask(49);
    std::cout << "OCCUPIED\n";
    // bbstr() output already ends in '\n'; the extra '\n' leaves a blank line
    // between diagrams. No flush is needed, so '\n' rather than std::endl.
    std::cout << bbstr(occupied) << '\n';
    std::cout << "ATTACKS\n";
    uint64 a = bishop_attacks(src,occupied);
    std::cout << bbstr(a,src) << '\n';
    return 0;
}
Output:
OCCUPIED
........
.#....#.
........
........
....#...
........
.#....#.
........
ATTACKS
........
.#.....#
..#...#.
...#.#..
....X...
...#.#..
..#...#.
.#......