Coverage Report

Created: 2026-04-29 19:21

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/tmp/bitcoin/src/util/sock.cpp
Line
Count
Source
1
// Copyright (c) 2020-present The Bitcoin Core developers
2
// Distributed under the MIT software license, see the accompanying
3
// file COPYING or http://www.opensource.org/licenses/mit-license.php.
4
5
#include <util/sock.h>
6
7
#include <compat/compat.h>
8
#include <span.h>
9
#include <tinyformat.h>
10
#include <util/check.h>
11
#include <util/log.h>
12
#include <util/syserror.h>
13
#include <util/threadinterrupt.h>
14
#include <util/time.h>
15
16
#include <algorithm>
17
#include <compare>
18
#include <exception>
19
#include <memory>
20
#include <stdexcept>
21
#include <string>
22
#include <utility>
23
#include <vector>
24
25
#ifdef USE_POLL
26
#include <poll.h>
27
#endif
28
29
static inline bool IOErrorIsPermanent(int err)
30
21
{
31
21
    return err != WSAEAGAIN && err != WSAEINTR && err != WSAEWOULDBLOCK && err != WSAEINPROGRESS;
32
21
}
33
34
2.66k
Sock::Sock(SOCKET s) : m_socket(s) {}
35
36
//! Move-construct: take over `other`'s handle and leave `other` holding
//! INVALID_SOCKET so that its destructor does not close our socket.
Sock::Sock(Sock&& other)
{
    m_socket = std::exchange(other.m_socket, INVALID_SOCKET);
}
41
42
2.66k
//! Release the owned socket, if any (Close() is a no-op for INVALID_SOCKET).
Sock::~Sock()
{
    Close();
}
43
44
//! Move-assign: close the socket currently owned by *this (if any), then take
//! over `other`'s handle, leaving `other` holding INVALID_SOCKET.
//! Self-move-assignment is a no-op. Without the guard, Close() would close the
//! very socket we are about to "receive", leaving *this with INVALID_SOCKET.
//! @return *this
Sock& Sock::operator=(Sock&& other)
{
    if (this != &other) {
        Close();
        m_socket = other.m_socket;
        other.m_socket = INVALID_SOCKET;
    }
    return *this;
}
51
52
//! Thin wrapper around send() on the owned socket; returns its result unchanged.
ssize_t Sock::Send(const void* data, size_t len, int flags) const
{
    const char* const bytes{static_cast<const char*>(data)};
    return send(m_socket, bytes, len, flags);
}
56
57
//! Thin wrapper around recv() on the owned socket; returns its result unchanged.
ssize_t Sock::Recv(void* buf, size_t len, int flags) const
{
    char* const dest{static_cast<char*>(buf)};
    return recv(m_socket, dest, len, flags);
}
61
62
int Sock::Connect(const sockaddr* addr, socklen_t addr_len) const
63
631
{
64
631
    return connect(m_socket, addr, addr_len);
65
631
}
66
67
int Sock::Bind(const sockaddr* addr, socklen_t addr_len) const
68
988
{
69
988
    return bind(m_socket, addr, addr_len);
70
988
}
71
72
int Sock::Listen(int backlog) const
73
987
{
74
987
    return listen(m_socket, backlog);
75
987
}
76
77
//! Accept a pending connection on the owned (listening) socket.
//! @param addr, addr_len  passed straight through to accept().
//! @return a new Sock owning the accepted descriptor, or nullptr if accept()
//!         failed or wrapping the descriptor threw a std::exception (in which
//!         case the descriptor is closed so it is not leaked).
std::unique_ptr<Sock> Sock::Accept(sockaddr* addr, socklen_t* addr_len) const
{
    // accept() signals failure differently per platform.
#ifdef WIN32
    static constexpr auto ACCEPT_FAILED = INVALID_SOCKET;
#else
    static constexpr auto ACCEPT_FAILED = SOCKET_ERROR;
#endif

    const auto accepted = accept(m_socket, addr, addr_len);
    if (accepted == ACCEPT_FAILED) {
        return nullptr;
    }

    try {
        return std::make_unique<Sock>(accepted);
    } catch (const std::exception&) {
        // Could not wrap it (e.g. allocation failure): don't leak the descriptor.
#ifdef WIN32
        closesocket(accepted);
#else
        close(accepted);
#endif
    }
    return nullptr;
}
102
103
int Sock::GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const
104
623
{
105
623
    return getsockopt(m_socket, level, opt_name, static_cast<char*>(opt_val), opt_len);
106
623
}
107
108
int Sock::SetSockOpt(int level, int opt_name, const void* opt_val, socklen_t opt_len) const
109
3.60k
{
110
3.60k
    return setsockopt(m_socket, level, opt_name, static_cast<const char*>(opt_val), opt_len);
111
3.60k
}
112
113
//! Thin wrapper around getsockname() on the owned socket; returns its result unchanged.
int Sock::GetSockName(sockaddr* name, socklen_t* name_len) const
{
    const int rc{getsockname(m_socket, name, name_len)};
    return rc;
}
117
118
bool Sock::SetNonBlocking() const
119
1.61k
{
120
#ifdef WIN32
121
    u_long on{1};
122
    if (ioctlsocket(m_socket, FIONBIO, &on) == SOCKET_ERROR) {
123
        return false;
124
    }
125
#else
126
1.61k
    const int flags{fcntl(m_socket, F_GETFL, 0)};
127
1.61k
    if (flags == SOCKET_ERROR) {
128
0
        return false;
129
0
    }
130
1.61k
    if (fcntl(m_socket, F_SETFL, flags | O_NONBLOCK) == SOCKET_ERROR) {
131
0
        return false;
132
0
    }
133
1.61k
#endif
134
1.61k
    return true;
135
1.61k
}
136
137
bool Sock::IsSelectable() const
138
2.62k
{
139
2.62k
#if defined(USE_POLL) || defined(WIN32)
140
2.62k
    return true;
141
#else
142
    return m_socket < FD_SETSIZE;
143
#endif
144
2.62k
}
145
146
bool Sock::Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred) const
147
852
{
148
    // We need a `shared_ptr` holding `this` for `WaitMany()`, but don't want
149
    // `this` to be destroyed when the `shared_ptr` goes out of scope at the
150
    // end of this function.
151
    // Create it with an aliasing shared_ptr that points to `this` without
152
    // owning it.
153
852
    std::shared_ptr<const Sock> shared{std::shared_ptr<const Sock>{}, this};
154
155
852
    EventsPerSock events_per_sock{std::make_pair(shared, Events{requested})};
156
157
852
    if (!WaitMany(timeout, events_per_sock)) {
158
0
        return false;
159
0
    }
160
161
852
    if (occurred != nullptr) {
162
730
        *occurred = events_per_sock.begin()->second.occurred;
163
730
    }
164
165
852
    return true;
166
852
}
167
168
bool Sock::WaitMany(std::chrono::milliseconds timeout, EventsPerSock& events_per_sock) const
169
315k
{
170
315k
#ifdef USE_POLL
171
315k
    std::vector<pollfd> pfds;
172
815k
    for (const auto& [sock, events] : events_per_sock) {
173
815k
        pfds.emplace_back();
174
815k
        auto& pfd = pfds.back();
175
815k
        pfd.fd = sock->m_socket;
176
815k
        if (events.requested & RECV) {
177
815k
            pfd.events |= POLLIN;
178
815k
        }
179
815k
        if (events.requested & SEND) {
180
889
            pfd.events |= POLLOUT;
181
889
        }
182
815k
    }
183
184
315k
    if (poll(pfds.data(), pfds.size(), count_milliseconds(timeout)) == SOCKET_ERROR) {
185
0
        return false;
186
0
    }
187
188
315k
    assert(pfds.size() == events_per_sock.size());
189
315k
    size_t i{0};
190
815k
    for (auto& [sock, events] : events_per_sock) {
191
815k
        assert(sock->m_socket == static_cast<SOCKET>(pfds[i].fd));
192
815k
        events.occurred = 0;
193
815k
        if (pfds[i].revents & POLLIN) {
194
160k
            events.occurred |= RECV;
195
160k
        }
196
815k
        if (pfds[i].revents & POLLOUT) {
197
881
            events.occurred |= SEND;
198
881
        }
199
815k
        if (pfds[i].revents & (POLLERR | POLLHUP)) {
200
47
            events.occurred |= ERR;
201
47
        }
202
815k
        ++i;
203
815k
    }
204
205
315k
    return true;
206
#else
207
    fd_set recv;
208
    fd_set send;
209
    fd_set err;
210
    FD_ZERO(&recv);
211
    FD_ZERO(&send);
212
    FD_ZERO(&err);
213
    SOCKET socket_max{0};
214
215
    for (const auto& [sock, events] : events_per_sock) {
216
        if (!sock->IsSelectable()) {
217
            return false;
218
        }
219
        const auto& s = sock->m_socket;
220
        if (events.requested & RECV) {
221
            FD_SET(s, &recv);
222
        }
223
        if (events.requested & SEND) {
224
            FD_SET(s, &send);
225
        }
226
        FD_SET(s, &err);
227
        socket_max = std::max(socket_max, s);
228
    }
229
230
    timeval tv = MillisToTimeval(timeout);
231
232
    if (select(socket_max + 1, &recv, &send, &err, &tv) == SOCKET_ERROR) {
233
        return false;
234
    }
235
236
    for (auto& [sock, events] : events_per_sock) {
237
        const auto& s = sock->m_socket;
238
        events.occurred = 0;
239
        if (FD_ISSET(s, &recv)) {
240
            events.occurred |= RECV;
241
        }
242
        if (FD_ISSET(s, &send)) {
243
            events.occurred |= SEND;
244
        }
245
        if (FD_ISSET(s, &err)) {
246
            events.occurred |= ERR;
247
        }
248
    }
249
250
    return true;
251
#endif /* USE_POLL */
252
315k
}
253
254
void Sock::SendComplete(std::span<const unsigned char> data,
255
                        std::chrono::milliseconds timeout,
256
                        CThreadInterrupt& interrupt) const
257
169
{
258
169
    const auto deadline = GetTime<std::chrono::milliseconds>() + timeout;
259
169
    size_t sent{0};
260
261
169
    for (;;) {
262
169
        const ssize_t ret{Send(data.data() + sent, data.size() - sent, MSG_NOSIGNAL)};
263
264
169
        if (ret > 0) {
265
169
            sent += static_cast<size_t>(ret);
266
169
            if (sent == data.size()) {
267
169
                break;
268
169
            }
269
169
        } else {
270
0
            const int err{WSAGetLastError()};
271
0
            if (IOErrorIsPermanent(err)) {
272
0
                throw std::runtime_error(strprintf("send(): %s", NetworkErrorString(err)));
273
0
            }
274
0
        }
275
276
0
        const auto now = GetTime<std::chrono::milliseconds>();
277
278
0
        if (now >= deadline) {
279
0
            throw std::runtime_error(strprintf(
280
0
                "Send timeout (sent only %u of %u bytes before that)", sent, data.size()));
281
0
        }
282
283
0
        if (interrupt) {
284
0
            throw std::runtime_error(strprintf(
285
0
                "Send interrupted (sent only %u of %u bytes before that)", sent, data.size()));
286
0
        }
287
288
        // Wait for a short while (or the socket to become ready for sending) before retrying
289
        // if nothing was sent.
290
0
        const auto wait_time = std::min(deadline - now, std::chrono::milliseconds{MAX_WAIT_FOR_IO});
291
0
        (void)Wait(wait_time, SEND);
292
0
    }
293
169
}
294
295
void Sock::SendComplete(std::span<const char> data,
296
                        std::chrono::milliseconds timeout,
297
                        CThreadInterrupt& interrupt) const
298
49
{
299
49
    SendComplete(MakeUCharSpan(data), timeout, interrupt);
300
49
}
301
302
std::string Sock::RecvUntilTerminator(uint8_t terminator,
303
                                      std::chrono::milliseconds timeout,
304
                                      CThreadInterrupt& interrupt,
305
                                      size_t max_data) const
306
38
{
307
38
    const auto deadline = GetTime<std::chrono::milliseconds>() + timeout;
308
38
    std::string data;
309
38
    bool terminator_found{false};
310
311
    // We must not consume any bytes past the terminator from the socket.
312
    // One option is to read one byte at a time and check if we have read a terminator.
313
    // However that is very slow. Instead, we peek at what is in the socket and only read
314
    // as many bytes as possible without crossing the terminator.
315
    // Reading 64 MiB of random data with 262526 terminator chars takes 37 seconds to read
316
    // one byte at a time VS 0.71 seconds with the "peek" solution below. Reading one byte
317
    // at a time is about 50 times slower.
318
319
169
    for (;;) {
320
169
        if (data.size() >= max_data) {
321
2
            throw std::runtime_error(
322
2
                strprintf("Received too many bytes without a terminator (%u)", data.size()));
323
2
        }
324
325
167
        char buf[512];
326
327
167
        const ssize_t peek_ret{Recv(buf, std::min(sizeof(buf), max_data - data.size()), MSG_PEEK)};
328
329
167
        switch (peek_ret) {
330
0
        case -1: {
331
0
            const int err{WSAGetLastError()};
332
0
            if (IOErrorIsPermanent(err)) {
333
0
                throw std::runtime_error(strprintf("recv(): %s", NetworkErrorString(err)));
334
0
            }
335
0
            break;
336
0
        }
337
0
        case 0:
338
0
            throw std::runtime_error("Connection unexpectedly closed by peer");
339
167
        default:
340
167
            auto end = buf + peek_ret;
341
167
            auto terminator_pos = std::find(buf, end, terminator);
342
167
            terminator_found = terminator_pos != end;
343
344
167
            const size_t try_len{terminator_found ? terminator_pos - buf + 1 :
345
167
                                                    static_cast<size_t>(peek_ret)};
346
347
167
            const ssize_t read_ret{Recv(buf, try_len, 0)};
348
349
167
            if (read_ret < 0 || static_cast<size_t>(read_ret) != try_len) {
350
0
                throw std::runtime_error(
351
0
                    strprintf("recv() returned %u bytes on attempt to read %u bytes but previous "
352
0
                              "peek claimed %u bytes are available",
353
0
                              read_ret, try_len, peek_ret));
354
0
            }
355
356
            // Don't include the terminator in the output.
357
167
            const size_t append_len{terminator_found ? try_len - 1 : try_len};
358
359
167
            data.append(buf, buf + append_len);
360
361
167
            if (terminator_found) {
362
36
                return data;
363
36
            }
364
167
        }
365
366
131
        const auto now = GetTime<std::chrono::milliseconds>();
367
368
131
        if (now >= deadline) {
369
0
            throw std::runtime_error(strprintf(
370
0
                "Receive timeout (received %u bytes without terminator before that)", data.size()));
371
0
        }
372
373
131
        if (interrupt) {
374
0
            throw std::runtime_error(strprintf(
375
0
                "Receive interrupted (received %u bytes without terminator before that)",
376
0
                data.size()));
377
0
        }
378
379
        // Wait for a short while (or the socket to become ready for reading) before retrying.
380
131
        const auto wait_time = std::min(deadline - now, std::chrono::milliseconds{MAX_WAIT_FOR_IO});
381
131
        (void)Wait(wait_time, RECV);
382
131
    }
383
38
}
384
385
bool Sock::IsConnected(std::string& errmsg) const
386
105
{
387
105
    if (m_socket == INVALID_SOCKET) {
388
0
        errmsg = "not connected";
389
0
        return false;
390
0
    }
391
392
105
    char c;
393
105
    switch (Recv(&c, sizeof(c), MSG_PEEK)) {
394
21
    case -1: {
395
21
        const int err = WSAGetLastError();
396
21
        if (IOErrorIsPermanent(err)) {
397
0
            errmsg = NetworkErrorString(err);
398
0
            return false;
399
0
        }
400
21
        return true;
401
21
    }
402
0
    case 0:
403
0
        errmsg = "closed";
404
0
        return false;
405
84
    default:
406
84
        return true;
407
105
    }
408
105
}
409
410
void Sock::Close()
411
2.67k
{
412
2.67k
    if (m_socket == INVALID_SOCKET) {
413
34
        return;
414
34
    }
415
#ifdef WIN32
416
    int ret = closesocket(m_socket);
417
#else
418
2.63k
    int ret = close(m_socket);
419
2.63k
#endif
420
2.63k
    if (ret) {
421
0
        LogWarning("Error closing socket %d: %s", m_socket, NetworkErrorString(WSAGetLastError()));
422
0
    }
423
2.63k
    m_socket = INVALID_SOCKET;
424
2.63k
}
425
426
//! @return true if this object wraps exactly the raw socket handle `s`.
bool Sock::operator==(SOCKET s) const
{
    return m_socket == s;
}
430
431
std::string NetworkErrorString(int err)
432
38
{
433
#if defined(WIN32)
434
    return Win32ErrorString(err);
435
#else
436
    // On BSD sockets implementations, NetworkErrorString is the same as SysErrorString.
437
38
    return SysErrorString(err);
438
38
#endif
439
38
}