shard.c 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. /*
  2. * ZMap Copyright 2013 Regents of the University of Michigan
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License"); you may not
  5. * use this file except in compliance with the License. You may obtain a copy
  6. * of the License at http://www.apache.org/licenses/LICENSE-2.0
  7. */
  8. #include <stdint.h>
  9. #include <assert.h>
  10. #include <gmp.h>
  11. #include "../lib/includes.h"
  12. #include "../lib/blacklist.h"
  13. #include "shard.h"
  14. #include "state.h"
  15. void shard_init(shard_t* shard,
  16. uint8_t shard_id,
  17. uint8_t num_shards,
  18. uint8_t sub_id,
  19. uint8_t num_subshards,
  20. const cycle_t* cycle,
  21. shard_complete_cb cb,
  22. void *arg)
  23. {
  24. // Start out by figuring out the multiplication factor for this shard.
  25. // With one shard, this would just be the generator, but with n shards,
  26. // f = g^n.
  27. // Then on top of that, we want to shard internally (subshards) per
  28. // thread. With t threads, f = g^(nr).
  29. //
  30. // tot_shards = nr
  31. uint32_t tot_shards = (uint32_t) num_shards * (uint32_t) num_subshards;
  32. uint64_t num_elts = cycle->group->prime - 1;
  33. mpz_t start, generator, prime, result, power;
  34. mpz_init_set_ui(start, cycle->offset);
  35. mpz_init_set_ui(generator, cycle->generator);
  36. mpz_init_set_ui(power, tot_shards);
  37. mpz_init_set_ui(prime, cycle->group->prime);
  38. mpz_init(result);
  39. mpz_powm(result, generator, power, prime);
  40. shard->params.factor = (uint64_t) mpz_get_ui(result);
  41. shard->params.modulus = cycle->group->prime;
  42. // e = p - 1 = num_elts
  43. // begin_idx = s + tr
  44. // end_idx = [e - (e % nr) + (s + tr)] % e
  45. // = [e - (e % nr) + begin_idx] % e
  46. uint64_t begin_idx = shard_id + sub_id*num_shards;
  47. uint64_t end_idx = (num_elts - (num_elts % tot_shards) + begin_idx) % num_elts;
  48. if (end_idx >= tot_shards) {
  49. end_idx += tot_shards;
  50. end_idx %= num_elts;
  51. }
  52. mpz_powm_ui(result, generator, begin_idx + 1, prime);
  53. shard->params.first = (uint64_t) mpz_get_ui(result);
  54. shard->params.first *= cycle->offset;
  55. shard->params.first %= shard->params.modulus;
  56. mpz_powm_ui(result, generator, end_idx + 1, prime);
  57. shard->params.last = (uint64_t) mpz_get_ui(result);
  58. shard->params.last *= cycle->offset;
  59. shard->params.last %= shard->params.modulus;
  60. shard->current = shard->params.first;
  61. // Handle scanning a sample
  62. if (zsend.targets != zsend.max_index) {
  63. shard->state.max_targets = zsend.targets / num_subshards;
  64. uint32_t leftover = zsend.targets % num_subshards;
  65. if (leftover > sub_id) {
  66. shard->state.max_targets++;
  67. }
  68. } else {
  69. shard->state.max_targets = zsend.targets;
  70. }
  71. // Set the (thread) id
  72. shard->id = sub_id;
  73. // Set the callbacks
  74. shard->cb = cb;
  75. shard->arg = arg;
  76. if (shard->current - 1 >= zsend.max_index) {
  77. shard_get_next_ip(shard);
  78. }
  79. // Clear everything
  80. mpz_clear(start);
  81. mpz_clear(generator);
  82. mpz_clear(prime);
  83. mpz_clear(power);
  84. mpz_clear(result);
  85. }
  86. uint32_t shard_get_cur_ip(shard_t *shard)
  87. {
  88. return (uint32_t) blacklist_lookup_index(shard->current - 1);
  89. }
  90. static inline uint32_t shard_get_next_elem(shard_t *shard)
  91. {
  92. do {
  93. shard->current *= shard->params.factor;
  94. shard->current %= shard->params.modulus;
  95. } while (shard->current >= (1LL << 32));
  96. return (uint32_t) shard->current;
  97. }
  98. uint32_t shard_get_next_ip(shard_t *shard)
  99. {
  100. while (1) {
  101. uint32_t candidate = shard_get_next_elem(shard);
  102. if (candidate == shard->params.last) {
  103. return 0;
  104. }
  105. if (candidate - 1 < zsend.max_index) {
  106. shard->state.whitelisted++;
  107. return blacklist_lookup_index(candidate - 1);
  108. }
  109. shard->state.blacklisted++;
  110. }
  111. }