diff --git a/examples/eratosthenes.c b/examples/eratosthenes.c
index 97abe191e973dc903825849b2718935cc58d2282..287aaa350667643fa28f771d1b26c2734f3bd32c 100644
--- a/examples/eratosthenes.c
+++ b/examples/eratosthenes.c
@@ -28,6 +28,7 @@
 # include "config.h"
 #endif
 
+#include <assert.h>
 #include <limits.h>
 #include <stdio.h>
 #include <stdlib.h>
@@ -53,7 +54,9 @@ usage(void)
   fprintf(stderr, "Usage: erathostenes [OPTIONS] [LIMIT]\n\n"
           "Options:\n"
           " -? Display this message.\n"
-          " -b SIZE Block size.\n");
+          " -b SIZE Block size.\n"
+          " -v Verbose output.\n"
+          " -s No output.\n");
 }
 
 static unsigned
@@ -83,16 +86,27 @@ static unsigned long *
 vector_alloc(unsigned long size)
 {
   unsigned long end = (size + BITS_PER_LONG - 1) / BITS_PER_LONG;
-  unsigned long i;
-  unsigned long *vector = malloc (end * sizeof(long));
+  /* Ask for at least one word: malloc(0) is allowed to return NULL,
+     which must not be mistaken for running out of memory. */
+  unsigned long *vector = malloc ((end ? end : 1) * sizeof(long));
 
   if (!vector)
-    return NULL;
+    {
+      fprintf(stderr, "Insufficient memory.\n");
+      exit(EXIT_FAILURE);
+    }
+  return vector;
+}
+
+/* Sets every bit in the words that hold the first SIZE bits of VECTOR. */
+static void
+vector_init(unsigned long *vector, unsigned long size)
+{
+  unsigned long end = (size + BITS_PER_LONG - 1) / BITS_PER_LONG;
+  unsigned long i;
 
   for (i = 0; i < end; i++)
     vector[i] = ~0;
-
-  return vector;
 }
 
 static void
@@ -183,39 +197,89 @@ vector_find_next (const unsigned long *vector, unsigned long bit, unsigned long
    will be spent converting the output to decimal). */
 #define OUTPUT(n) printf("%lu\n", (n))
 
+/* Parses a positive decimal byte count with an optional suffix: k or K
+   (multiply by 2^10), m or M (multiply by 2^20). Returns 0 for anything
+   invalid, non-positive, or too large. The result is limited so that
+   multiplying it by CHAR_BIT, as main does, cannot overflow a long. */
+static long
+atosize(const char *s)
+{
+  char *end;
+  long value = strtol(s, &end, 10);
+  unsigned shift;
+
+  /* Also covers the no-digits case, where strtol returns 0. */
+  if (value <= 0)
+    return 0;
+
+  switch(*end)
+    {
+    default:
+      return 0;
+    case '\0':
+      shift = 0;
+      break;
+    case 'k': case 'K':
+      shift = 10;
+      break;
+    case 'm': case 'M':
+      shift = 20;
+      break;
+    }
+
+  /* On overflow strtol returns LONG_MAX, which is rejected here too. */
+  if (value > ((LONG_MAX / CHAR_BIT) >> shift))
+    return 0;
+
+  return value << shift;
+}
+
 int
 main (int argc, char **argv)
 {
-  unsigned long *vector;
   /* Generate all primes p <= limit */
   unsigned long limit;
-  unsigned long size;
+  unsigned long root;
+
+  unsigned long limit_nbits;
+
+  /* Represents numbers up to sqrt(limit) */
+  unsigned long sieve_nbits;
+  unsigned long *sieve;
+  /* Block for the rest of the sieving. Size should match the cache,
+     the default value corresponds to 64 KB. */
+  unsigned long block_nbits = 64L << 13;
+  unsigned long block_start_bit;
+  unsigned long *block;
+
   unsigned long bit;
-  unsigned long sieve_limit;
-  unsigned long block_size;
-  unsigned long block;
-
+  int silent = 0;
+  int verbose = 0;
   int c;
 
-  block_size = 0;
-
-  while ( (c = getopt(argc, argv, "?b:")) != -1)
+  while ( (c = getopt(argc, argv, "?svb:")) != -1)
     switch (c)
       {
       case '?':
         usage();
         return EXIT_FAILURE;
       case 'b':
-        {
-          long arg = atol(optarg);
-          if (arg <= 10)
-            {
-              usage();
-              return EXIT_FAILURE;
-            }
-          block_size = (arg - 3) / 2;
-          break;
-        }
+        block_nbits = CHAR_BIT * atosize(optarg);
+        if (!block_nbits)
+          {
+            usage();
+            return EXIT_FAILURE;
+          }
+        break;
+
+      case 's':
+        silent = 1;
+        break;
+
+      case 'v':
+        verbose++;
+        break;
+
       default:
         abort();
       }
@@ -237,83 +301,105 @@ main (int argc, char **argv)
       return EXIT_FAILURE;
     }
 
-  size = (limit - 1) / 2;
-
-  if (!block_size || block_size > size)
-    block_size = size;
+  root = isqrt(limit);
+  /* Round down to odd */
+  root = (root - 1) | 1;
+  /* Represents odd numbers from 3 up. */
+  sieve_nbits = (root - 1) / 2;
+  sieve = vector_alloc(sieve_nbits);
+  vector_init(sieve, sieve_nbits);
 
-  /* FIXME: Allocates a bit for all odd numbers up to limit. Could
-     save a lot of memory by allocating space only for numbers up to
-     the square root, plus one additional block. */
-  vector = vector_alloc(size);
-  if (!vector)
-    {
-      fprintf(stderr, "Insufficient memory.\n");
-      return EXIT_FAILURE;
-    }
-
-  OUTPUT(2UL);
-
-  bit = 0;
+  if (verbose)
+    fprintf(stderr, "Initial sieve using %lu bits.\n", sieve_nbits);
+
+  if (!silent)
+    printf("2\n");
 
   if (limit == 2)
     return EXIT_SUCCESS;
 
-  sieve_limit = (isqrt(2*block_size + 1) - 1) / 2;
-
-  while (bit < sieve_limit)
+  for (bit = 0;
+       bit < sieve_nbits;
+       bit = vector_find_next(sieve, bit + 1, sieve_nbits))
     {
       unsigned long n = 3 + 2 * bit;
-
-      OUTPUT(n);
-
       /* First bit to clear corresponds to n^2, which is bit
-         (n^2 - 3) / 2 = n * bit + 3 (bit + 1)
-      */
-      vector_clear_bits (vector, n, n*bit + 3*(bit + 1), block_size);
+         (n^2 - 3) / 2 = (n + 3) * bit + 3
+      */
+      unsigned long n2_bit = (n+3)*bit + 3;
+
+      if (!silent)
+        printf("%lu\n", n);
 
-      bit = vector_find_next (vector, bit + 1, block_size);
+      vector_clear_bits (sieve, n, n2_bit, sieve_nbits);
     }
 
-  /* No more marking, just output the remaining primes. */
-  for (; bit < block_size ;
-       bit = vector_find_next (vector, bit + 1, block_size))
+  limit_nbits = (limit - 1) / 2;
 
-    OUTPUT(3 + 2 * bit);
+  if (sieve_nbits + block_nbits > limit_nbits)
+    block_nbits = limit_nbits - sieve_nbits;
 
-  for (block = block_size; block < size; block += block_size)
+  if (verbose)
+    {
+      double storage = block_nbits / 8.0;
+      unsigned shift = 0;
+      const char prefix[] = " KMG";
+
+      while (storage > 1024 && shift < 3)
+        {
+          storage /= 1024;
+          shift++;
+        }
+      fprintf(stderr, "Blockwise sieving using blocks of %lu bits (%.3g %cByte)\n",
+              block_nbits, storage, prefix[shift]);
+    }
+
+  block = vector_alloc(block_nbits);
+
+  for (block_start_bit = sieve_nbits; block_start_bit < limit_nbits; block_start_bit += block_nbits)
     {
       unsigned long block_start;
-      unsigned long block_end;
+
+      if (block_start_bit + block_nbits > limit_nbits)
+        block_nbits = limit_nbits - block_start_bit;
 
-      if (block + block_size > size)
-        /* For the final block */
-        block_size = size - block;
+      vector_init(block, block_nbits);
 
-      block_start = 2*block + 3;
-      block_end = 2*(block + block_size) + 3;
+      block_start = 3 + 2*block_start_bit;
 
-      sieve_limit = (isqrt(block_end) - 1) / 2;
-      for (bit = 0; bit < sieve_limit ;)
+      if (verbose > 1)
+        fprintf(stderr, "Next block, n = %lu\n", block_start);
+
+      /* Sieve */
+      for (bit = 0; bit < sieve_nbits;
+           bit = vector_find_next(sieve, bit + 1, sieve_nbits))
         {
           unsigned long n = 3 + 2 * bit;
+          unsigned long sieve_start_bit = (n + 3) * bit + 3;
 
-          unsigned long start = n*bit + 3*(bit + 1);
-          if (start < block)
+          if (sieve_start_bit < block_start_bit)
             {
-              unsigned long k = (block + 1) / n;
-              start = bit + k*n;
+              unsigned long k = (block_start + n - 1) / (2*n);
+              sieve_start_bit = n * k + bit;
+
+              assert(sieve_start_bit < block_start_bit + n);
             }
-          vector_clear_bits (vector, n, start, block + block_size);
+          assert(sieve_start_bit >= block_start_bit);
 
-          bit = vector_find_next (vector, bit + 1, block + block_size);
+          vector_clear_bits(block, n, sieve_start_bit - block_start_bit, block_nbits);
+        }
+      for (bit = vector_find_next(block, 0, block_nbits);
+           bit < block_nbits;
+           bit = vector_find_next(block, bit + 1, block_nbits))
+        {
+          unsigned long n = block_start + 2 * bit;
+          if (!silent)
+            printf("%lu\n", n);
         }
-      for (bit = vector_find_next (vector, block, block + block_size);
-           bit < block + block_size;
-           bit = vector_find_next (vector, bit + 1, block + block_size))
-        OUTPUT(3 + 2 * bit);
     }
 
+  free(block);
+  free(sieve);
 
   return EXIT_SUCCESS;
 }