blob: 99254fedb25955c6975e42051b0324cee4d72954 [file] [log] [blame]
yuezonghe824eb0c2024-06-27 02:32:26 -07001// -*- mode: c++ -*-
2#ifndef ARES_TEST_H
3#define ARES_TEST_H
4
5#include "ares.h"
6
7#include "dns-proto.h"
8
9// Include ares internal file for DNS protocol constants
10#include "nameser.h"
11
12#include "gtest/gtest.h"
13#include "gmock/gmock.h"
14
15#ifdef HAVE_CONFIG_H
16#include "config.h"
17#endif
// Container-based tests need both user and UTS namespace support; the two
// feature macros are presumably supplied by config.h.
#if defined(HAVE_USER_NAMESPACE)
#  if defined(HAVE_UTS_NAMESPACE)
#    define HAVE_CONTAINER
#  endif
#endif
21
22#include <functional>
23#include <list>
24#include <map>
25#include <memory>
26#include <set>
27#include <string>
28#include <utility>
29#include <vector>
30
31namespace ares {
32
// Raw octet type used for DNS wire-format buffers in the tests.
using byte = unsigned char;
34
35namespace test {
36
37extern bool verbose;
38extern int mock_port;
39
40// Process all pending work on ares-owned file descriptors, plus
41// optionally the given set-of-FDs + work function.
42void ProcessWork(ares_channel channel,
43 std::function<std::set<int>()> get_extrafds,
44 std::function<void(int)> process_extra);
45std::set<int> NoExtraFDs();
46
47// Test fixture that ensures library initialization, and allows
48// memory allocations to be failed.
49class LibraryTest : public ::testing::Test {
50 public:
51 LibraryTest() {
52 EXPECT_EQ(ARES_SUCCESS,
53 ares_library_init_mem(ARES_LIB_INIT_ALL,
54 &LibraryTest::amalloc,
55 &LibraryTest::afree,
56 &LibraryTest::arealloc));
57 }
58 ~LibraryTest() {
59 ares_library_cleanup();
60 ClearFails();
61 }
62 // Set the n-th malloc call (of any size) from the library to fail.
63 // (nth == 1 means the next call)
64 static void SetAllocFail(int nth);
65 // Set the next malloc call for the given size to fail.
66 static void SetAllocSizeFail(size_t size);
67 // Remove any pending alloc failures.
68 static void ClearFails();
69
70 static void *amalloc(size_t size);
71 static void* arealloc(void *ptr, size_t size);
72 static void afree(void *ptr);
73 private:
74 static bool ShouldAllocFail(size_t size);
75 static unsigned long long fails_;
76 static std::map<size_t, int> size_fails_;
77};
78
79// Test fixture that uses a default channel.
80class DefaultChannelTest : public LibraryTest {
81 public:
82 DefaultChannelTest() : channel_(nullptr) {
83 EXPECT_EQ(ARES_SUCCESS, ares_init(&channel_));
84 EXPECT_NE(nullptr, channel_);
85 }
86
87 ~DefaultChannelTest() {
88 ares_destroy(channel_);
89 channel_ = nullptr;
90 }
91
92 // Process all pending work on ares-owned file descriptors.
93 void Process();
94
95 protected:
96 ares_channel channel_;
97};
98
99// Test fixture that uses a default channel with the specified lookup mode.
100class DefaultChannelModeTest
101 : public LibraryTest,
102 public ::testing::WithParamInterface<std::string> {
103 public:
104 DefaultChannelModeTest() : channel_(nullptr) {
105 struct ares_options opts = {0};
106 opts.lookups = strdup(GetParam().c_str());
107 int optmask = ARES_OPT_LOOKUPS;
108 EXPECT_EQ(ARES_SUCCESS, ares_init_options(&channel_, &opts, optmask, NULL));
109 EXPECT_NE(nullptr, channel_);
110 free(opts.lookups);
111 }
112
113 ~DefaultChannelModeTest() {
114 ares_destroy(channel_);
115 channel_ = nullptr;
116 }
117
118 // Process all pending work on ares-owned file descriptors.
119 void Process();
120
121 protected:
122 ares_channel channel_;
123};
124
125// Mock DNS server to allow responses to be scripted by tests.
126class MockServer {
127 public:
128 MockServer(int family, int port, int tcpport = 0);
129 ~MockServer();
130
131 // Mock method indicating the processing of a particular <name, RRtype>
132 // request.
133 MOCK_METHOD2(OnRequest, void(const std::string& name, int rrtype));
134
135 // Set the reply to be sent next; the query ID field will be overwritten
136 // with the value from the request.
137 void SetReplyData(const std::vector<byte>& reply) { reply_ = reply; }
138 void SetReply(const DNSPacket* reply) { SetReplyData(reply->data()); }
139 void SetReplyQID(int qid) { qid_ = qid; }
140
141 // The set of file descriptors that the server handles.
142 std::set<int> fds() const;
143
144 // Process activity on a file descriptor.
145 void ProcessFD(int fd);
146
147 // Ports the server is responding to
148 int udpport() const { return udpport_; }
149 int tcpport() const { return tcpport_; }
150
151 private:
152 void ProcessRequest(int fd, struct sockaddr_storage* addr, int addrlen,
153 int qid, const std::string& name, int rrtype);
154
155 int udpport_;
156 int tcpport_;
157 int udpfd_;
158 int tcpfd_;
159 std::set<int> connfds_;
160 std::vector<byte> reply_;
161 int qid_;
162};
163
164// Test fixture that uses a mock DNS server.
165class MockChannelOptsTest : public LibraryTest {
166 public:
167 MockChannelOptsTest(int count, int family, bool force_tcp, struct ares_options* givenopts, int optmask);
168 ~MockChannelOptsTest();
169
170 // Process all pending work on ares-owned and mock-server-owned file descriptors.
171 void Process();
172
173 protected:
174 // NiceMockServer doesn't complain about uninteresting calls.
175 typedef testing::NiceMock<MockServer> NiceMockServer;
176 typedef std::vector< std::unique_ptr<NiceMockServer> > NiceMockServers;
177
178 std::set<int> fds() const;
179 void ProcessFD(int fd);
180
181 static NiceMockServers BuildServers(int count, int family, int base_port);
182
183 NiceMockServers servers_;
184 // Convenience reference to first server.
185 NiceMockServer& server_;
186 ares_channel channel_;
187};
188
189class MockChannelTest
190 : public MockChannelOptsTest,
191 public ::testing::WithParamInterface< std::pair<int, bool> > {
192 public:
193 MockChannelTest() : MockChannelOptsTest(1, GetParam().first, GetParam().second, nullptr, 0) {}
194};
195
196class MockUDPChannelTest
197 : public MockChannelOptsTest,
198 public ::testing::WithParamInterface<int> {
199 public:
200 MockUDPChannelTest() : MockChannelOptsTest(1, GetParam(), false, nullptr, 0) {}
201};
202
203class MockTCPChannelTest
204 : public MockChannelOptsTest,
205 public ::testing::WithParamInterface<int> {
206 public:
207 MockTCPChannelTest() : MockChannelOptsTest(1, GetParam(), true, nullptr, 0) {}
208};
209
210// gMock action to set the reply for a mock server.
211ACTION_P2(SetReplyData, mockserver, data) {
212 mockserver->SetReplyData(data);
213}
214ACTION_P2(SetReply, mockserver, reply) {
215 mockserver->SetReply(reply);
216}
217ACTION_P2(SetReplyQID, mockserver, qid) {
218 mockserver->SetReplyQID(qid);
219}
220// gMock action to cancel a channel.
221ACTION_P2(CancelChannel, mockserver, channel) {
222 ares_cancel(channel);
223}
224
// C++ value-type wrapper for struct hostent, holding the entry's fields as
// standard strings/vectors so results are easy to compare and print.
struct HostEnt {
  // Empty entry; an addrtype_ of -1 means no address family is set.
  HostEnt() : addrtype_(-1) {}
  // Build from a C hostent (defined out of line).
  HostEnt(const struct hostent* hostent);

  std::string name_;
  std::vector<std::string> aliases_;
  int addrtype_;  // AF_INET or AF_INET6
  std::vector<std::string> addrs_;
};
// Stream a printable form of |result| (e.g. for gtest failure messages).
std::ostream& operator<<(std::ostream& os, const HostEnt& result);
235
236// Structure that describes the result of an ares_host_callback invocation.
237struct HostResult {
238 // Whether the callback has been invoked.
239 bool done_;
240 // Explicitly provided result information.
241 int status_;
242 int timeouts_;
243 // Contents of the hostent structure, if provided.
244 HostEnt host_;
245};
246std::ostream& operator<<(std::ostream& os, const HostResult& result);
247
248// Structure that describes the result of an ares_callback invocation.
249struct SearchResult {
250 // Whether the callback has been invoked.
251 bool done_;
252 // Explicitly provided result information.
253 int status_;
254 int timeouts_;
255 std::vector<byte> data_;
256};
257std::ostream& operator<<(std::ostream& os, const SearchResult& result);
258
// Structure that describes the result of an ares_nameinfo_callback invocation.
//
// A default-constructed result means "callback not invoked yet": done_ is
// false, status_/timeouts_ are 0 and node_/service_ are empty. Without
// this, a test that inspects done_ when the callback never fired would read
// an indeterminate value.
struct NameInfoResult {
  NameInfoResult() : done_(false), status_(0), timeouts_(0) {}
  // Whether the callback has been invoked.
  bool done_;
  // Explicitly provided result information.
  int status_;
  int timeouts_;
  std::string node_;
  std::string service_;
};
std::ostream& operator<<(std::ostream& os, const NameInfoResult& result);
270
// Stock ares callbacks shared by the tests. Each one fills out its
// corresponding result structure (HostResult, SearchResult, NameInfoResult
// respectively). The structure is presumably passed in through |data|.
void HostCallback(void *data, int status, int timeouts,
                  struct hostent *hostent);
void SearchCallback(void *data, int status, int timeouts,
                    unsigned char *abuf, int alen);
void NameInfoCallback(void *data, int status, int timeouts,
                      char *node, char *service);
279
280// Retrieve the name servers used by a channel.
281std::vector<std::string> GetNameServers(ares_channel channel);
282
283
// RAII class to temporarily create a directory of a given name.
//
// The constructor and destructor are defined out of line. Since the object
// stands for one on-disk directory, copying is disabled: two copies would
// both try to clean up the same directory.
class TransientDir {
 public:
  TransientDir(const std::string& dirname);
  ~TransientDir();

  TransientDir(const TransientDir&) = delete;
  TransientDir& operator=(const TransientDir&) = delete;

 private:
  std::string dirname_;
};
293
// std::string-returning wrapper around tempnam(): a temporary file name in
// directory |dir| whose base name starts with |prefix|.
std::string TempNam(const char *dir, const char *prefix);
296
// RAII class to temporarily create file of a given name and contents.
//
// The constructor and destructor are defined out of line.
// - The destructor is virtual: TempFile derives from this class, and owners
//   hold instances through std::unique_ptr<TransientFile>. Deleting a
//   derived object through such a pointer needs a virtual destructor.
// - Copying is disabled because the object stands for one on-disk file.
class TransientFile {
 public:
  TransientFile(const std::string &filename, const std::string &contents);
  virtual ~TransientFile();

  TransientFile(const TransientFile&) = delete;
  TransientFile& operator=(const TransientFile&) = delete;

 protected:
  std::string filename_;
};
306
307// RAII class for a temporary file with the given contents.
308class TempFile : public TransientFile {
309 public:
310 TempFile(const std::string& contents);
311 const char* filename() const { return filename_.c_str(); }
312};
313
#ifndef WIN32
// RAII class for a temporary environment variable value.
//
// The constructor sets |name| to |value|. A null |value| temporarily removes
// |name| from the environment instead, since setenv() needs a real string.
// The destructor puts things back:
// - the previous value is restored if the variable existed;
// - otherwise the variable is unset again.
// Copying is disabled: two copies would both restore.
class EnvValue {
 public:
  EnvValue(const char *name, const char *value) : name_(name), restore_(false) {
    // Take the snapshot before modifying the environment; the string that
    // getenv() points to may not survive a later setenv()/unsetenv().
    const char *original = getenv(name);
    if (original) {
      restore_ = true;
      original_ = original;
    }
    if (value) {
      setenv(name_.c_str(), value, 1);
    } else {
      unsetenv(name_.c_str());
    }
  }

  ~EnvValue() {
    if (restore_) {
      setenv(name_.c_str(), original_.c_str(), 1);
    } else {
      unsetenv(name_.c_str());
    }
  }

  EnvValue(const EnvValue&) = delete;
  EnvValue& operator=(const EnvValue&) = delete;

 private:
  std::string name_;
  bool restore_;
  std::string original_;
};
#endif
339
340
341#ifdef HAVE_CONTAINER
// Linux-specific functionality for running code in a container, implemented
// in ares-test-ns.cc.

// Entry point run inside the container; returns an int result code.
using VoidToIntFn = std::function<int(void)>;
// List of (name, contents) string pairs describing files to create.
using NameContentList = std::vector<std::pair<std::string, std::string>>;
346
347class ContainerFilesystem {
348 public:
349 ContainerFilesystem(NameContentList files, const std::string& mountpt);
350 ~ContainerFilesystem();
351 std::string root() const { return rootdir_; };
352 std::string mountpt() const { return mountpt_; };
353 private:
354 void EnsureDirExists(const std::string& dir);
355 std::string rootdir_;
356 std::string mountpt_;
357 std::list<std::string> dirs_;
358 std::vector<std::unique_ptr<TransientFile>> files_;
359};
360
361int RunInContainer(ContainerFilesystem* fs, const std::string& hostname,
362 const std::string& domainname, VoidToIntFn fn);
363
// Name of the helper fixture class generated for a contained test.
#define ICLASS_NAME(casename, testname) Contained##casename##_##testname

// CONTAINED_TEST_F(casename, testname, hostname, domainname, files) { body }
// Defines a TEST_F on a generated subclass of |casename|. The TEST_F builds
// a ContainerFilesystem from |files| and runs the static InnerTestBody() via
// RunInContainer() with |hostname|/|domainname|. The braces following the
// macro become InnerTestBody()'s body, which must return 0 to pass.
#define CONTAINED_TEST_F(casename, testname, hostname, domainname, files)      \
  class ICLASS_NAME(casename, testname) : public casename {                    \
   public:                                                                     \
    ICLASS_NAME(casename, testname)() {}                                       \
    static int InnerTestBody();                                                \
  };                                                                           \
  TEST_F(ICLASS_NAME(casename, testname), _) {                                 \
    ContainerFilesystem chroot(files, "..");                                   \
    VoidToIntFn fn(ICLASS_NAME(casename, testname)::InnerTestBody);            \
    EXPECT_EQ(0, RunInContainer(&chroot, hostname, domainname, fn));           \
  }                                                                            \
  int ICLASS_NAME(casename, testname)::InnerTestBody()
377
378#endif
379
380} // namespace test
381} // namespace ares
382
383#endif