From 0a4b1f1ce2a96b0b515c9d8778668df3045df56c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jerry=20Lundstr=C3=B6m?= Date: Tue, 26 Jun 2018 16:33:19 +0200 Subject: [PATCH] Timeout - `examples/*`: Fix handling of previous objects, need to explicitly check for not nil on cdata - `examples/replay.lua`: Handle timeouts - `examples/respdiff.lua`: Handle timeouts - `examples/test_throughput.lua`: Add tests for splitting in Lua - `core.object.dns`: Add column names in `print()` for resource records - `core.output.tcpcli`: - Issue #47: Remove DNS query check - Add timeout support, see `timeout()` - Continue parsing if we received more then needed - Add `received()` for the number of payloads received - Documentation - `core.output.udpcli`: - Issue #46: Remove DNS query check - Add timeout support, see `timeout()` - Add `received()` for the number of payloads received - Documentation --- examples/dumpdns-qr.lua | 4 +- examples/dumpdns.lua | 4 +- examples/filter_rcode.lua | 2 +- examples/replay.lua | 12 ++-- examples/respdiff.lua | 56 ++++++++--------- examples/test_throughput.lua | 90 +++++++++++++++++++++++++++ src/core/object/dns.lua | 72 ++++++++++++---------- src/output/tcpcli.c | 115 +++++++++++++++++++++++++++-------- src/output/tcpcli.h | 1 + src/output/tcpcli.hh | 10 ++- src/output/tcpcli.lua | 70 +++++++++++++-------- src/output/udpcli.c | 74 +++++++++++++++++++--- src/output/udpcli.h | 1 + src/output/udpcli.hh | 5 ++ src/output/udpcli.lua | 60 ++++++++++-------- 15 files changed, 422 insertions(+), 154 deletions(-) diff --git a/examples/dumpdns-qr.lua b/examples/dumpdns-qr.lua index e2fe52ec..9c0fd4b6 100755 --- a/examples/dumpdns-qr.lua +++ b/examples/dumpdns-qr.lua @@ -29,14 +29,14 @@ while true do local pl = obj:cast() if obj:type() == "payload" and pl.len > 0 then local transport = obj.obj_prev - while transport do + while transport ~= nil do if transport.obj_type == object.IP or transport.obj_type == object.IP6 then break end transport = transport.obj_prev end local protocol = obj.obj_prev - while protocol do + while protocol ~= nil do if protocol.obj_type == object.UDP or protocol.obj_type == object.TCP then break end diff --git a/examples/dumpdns.lua b/examples/dumpdns.lua index 3092eb9f..47e5c4d7 100755 --- a/examples/dumpdns.lua +++ b/examples/dumpdns.lua @@ -21,14 +21,14 @@ while true do local pl = obj:cast() if obj:type() == "payload" and pl.len > 0 then local transport = obj.obj_prev - while transport do + while transport ~= nil do if transport.obj_type == object.IP or transport.obj_type == object.IP6 then break end transport = transport.obj_prev end local protocol = obj.obj_prev - while protocol do + while protocol ~= nil do if protocol.obj_type == object.UDP or protocol.obj_type == object.TCP then break end diff --git a/examples/filter_rcode.lua b/examples/filter_rcode.lua index eafe725d..c3f02544 100755 --- a/examples/filter_rcode.lua +++ b/examples/filter_rcode.lua @@ -22,7 +22,7 @@ while true do local pl = obj:cast() if obj:type() == "payload" and pl.len > 0 then local transport = obj.obj_prev - while transport do + while transport ~= nil do if transport.obj_type == object.IP or transport.obj_type == object.IP6 then break end diff --git a/examples/replay.lua b/examples/replay.lua index f78a77ef..ef58d9b9 100755 --- a/examples/replay.lua +++ b/examples/replay.lua @@ -67,15 +67,17 @@ if printdns then recv(rctx, obj) - local resp = nil - while resp == nil do - resp = oprod(opctx) + local response = oprod(opctx) + if response == nil then + log.fatal("producer error") end - while resp ~= nil do + local payload = response:cast() + if payload.len == 0 then + print("timed out") + else dns.obj_prev = resp print("response:") dns:print() - resp = oprod(opctx) end end end diff --git a/examples/respdiff.lua b/examples/respdiff.lua index 509730c2..88da8d44 100755 --- a/examples/respdiff.lua +++ b/examples/respdiff.lua @@ -66,14 +66,14 @@ while true do dns.obj_prev = obj if dns:parse_header() == 0 then local transport = obj.obj_prev - while transport do + while transport ~= nil do if transport.obj_type == object.IP or transport.obj_type == object.IP6 then break end transport = transport.obj_prev end local protocol = obj.obj_prev - while protocol do + while protocol ~= nil do if protocol.obj_type == object.UDP or protocol.obj_type == object.TCP then break end @@ -102,7 +102,8 @@ while true do clipayload.payload = q.payload clipayload.len = q.len - local responses, response = {}, nil + local prod, pctx + if q.proto == "udp" then if not udpcli then udpcli = require("dnsjit.output.udpcli").new() @@ -111,13 +112,8 @@ while true do udpprod, _ = udpcli:produce() end udprecv(udpctx, cliobject) - while response == nil do - response = udpprod(udpctx) - end - while response ~= nil do - table.insert(responses, response) - response = udpprod(udpctx) - end + prod = udpprod + pctx = udpctx elseif q.proto == "tcp" then if not tcpcli then tcpcli = require("dnsjit.output.tcpcli").new() @@ -126,27 +122,31 @@ while true do tcpprod, _ = tcpcli:produce() end tcprecv(tcpctx, cliobject) - while response == nil do - response = tcpprod(tcpctx) - end - while response ~= nil do - table.insert(responses, response) - response = tcpprod(tcpctx) - end + prod = tcpprod + pctx = tcpctx end - for _, response in pairs(responses) do - dns.obj_prev = response - if dns:parse_header() == 0 and dns.id == q.id then - query_payload.payload = q.payload - query_payload.len = q.len - original_payload.payload = payload.payload - original_payload.len = payload.len - response = response:cast() - response_payload.payload = response.payload - response_payload.len = response.len + while true do + local response = prod(pctx) + if response == nil then + log.fatal("producer error") + end + local rpl = response:cast() + if rpl.len == 0 then + log.info("timed out") + else + dns.obj_prev = response + if dns:parse_header() == 0 and dns.id == q.id then + query_payload.payload = q.payload + query_payload.len = q.len + original_payload.payload = payload.payload + original_payload.len = payload.len + response_payload.payload = rpl.payload + response_payload.len = rpl.len - resprecv(respctx, query_payload_obj) + resprecv(respctx, query_payload_obj) + break + end end end end diff --git a/examples/test_throughput.lua b/examples/test_throughput.lua index 95e04e9a..b936c415 100755 --- a/examples/test_throughput.lua +++ b/examples/test_throughput.lua @@ -232,6 +232,96 @@ if getopt:val("s") then print(run, "runtime", runtime, num/runtime, "/sec", o1:packets() + o2:packets() + o3:packets() + o4:packets(), o1:packets(), o2:packets(), o3:packets(), o4:packets()) end + + print("zero:receiver() -> lua split table -> null:receive() x4") + local run + for run = 1, runs do + local i = require("dnsjit.input.zero").new() + local o1 = require("dnsjit.output.null").new() + local o2 = require("dnsjit.output.null").new() + local o3 = require("dnsjit.output.null").new() + local o4 = require("dnsjit.output.null").new() + + local prod, pctx = i:produce() + local recv, rctx = {}, {} + + local f, c = o1:receive() + table.insert(recv, f) + table.insert(rctx, c) + f, c = o2:receive() + table.insert(recv, f) + table.insert(rctx, c) + f, c = o3:receive() + table.insert(recv, f) + table.insert(rctx, c) + f, c = o4:receive() + table.insert(recv, f) + table.insert(rctx, c) + + local start_sec, start_nsec = clock:monotonic() + local idx = 1 + for n = 1, num do + local f, c = recv[idx], rctx[idx] + if not f then + idx = 1 + f, c = recv[1], rctx[1] + end + f(c, prod(pctx)) + idx = idx + 1 + end + local end_sec, end_nsec = clock:monotonic() + + local runtime = 0 + if end_sec > start_sec then + runtime = ((end_sec - start_sec) - 1) + ((1000000000 - start_nsec + end_nsec)/1000000000) + elseif end_sec == start_sec and end_nsec > start_nsec then + runtime = (end_nsec - start_nsec) / 1000000000 + end + + print(run, "runtime", runtime, num/runtime, "/sec", o1:packets() + o2:packets() + o3:packets() + o4:packets(), o1:packets(), o2:packets(), o3:packets(), o4:packets()) + end + + print("zero:receiver() -> lua split gen code -> null:receive() x4") + local run + for run = 1, runs do + local i = require("dnsjit.input.zero").new() + local o1 = require("dnsjit.output.null").new() + local o2 = require("dnsjit.output.null").new() + local o3 = require("dnsjit.output.null").new() + local o4 = require("dnsjit.output.null").new() + + local prod, pctx = i:produce() + local f1, c1 = o1:receive() + local f2, c2 = o2:receive() + local f3, c3 = o3:receive() + local f4, c4 = o4:receive() + + local code = "return function (num, prod, pctx, f1, c1, f2, c2, f3, c3, f4, c4)\nlocal n = 0\nwhile n < num do\n" + code = code .. "f1(c1,prod(pctx))\n" + code = code .. "n = n + 1\n" + code = code .. "f2(c2,prod(pctx))\n" + code = code .. "n = n + 1\n" + code = code .. "f3(c3,prod(pctx))\n" + code = code .. "n = n + 1\n" + code = code .. "f4(c4,prod(pctx))\n" + code = code .. "n = n + 1\n" + code = code .. "end\n" + code = code .. "end" + local f = assert(loadstring(code))() + + local start_sec, start_nsec = clock:monotonic() + f(num, prod, pctx, f1, c1, f2, c2, f3, c3, f4, c4) + local end_sec, end_nsec = clock:monotonic() + + local runtime = 0 + if end_sec > start_sec then + runtime = ((end_sec - start_sec) - 1) + ((1000000000 - start_nsec + end_nsec)/1000000000) + elseif end_sec == start_sec and end_nsec > start_nsec then + runtime = (end_nsec - start_nsec) / 1000000000 + end + + print(run, "runtime", runtime, num/runtime, "/sec", o1:packets() + o2:packets() + o3:packets() + o4:packets(), o1:packets(), o2:packets(), o3:packets(), o4:packets()) + end end if getopt:val("t") then diff --git a/src/core/object/dns.lua b/src/core/object/dns.lua index d195e647..d620aae1 100644 --- a/src/core/object/dns.lua +++ b/src/core/object/dns.lua @@ -671,44 +671,52 @@ function Dns:print(num_labels) print("", "nscount:", self.nscount) print("", "arcount:", self.arcount) - print("", "questions:") - for n = 1, self.qdcount do - if C.core_object_dns_parse_q(self, q, labels, num_labels) ~= 0 then - return + if self.qdcount > 0 then + print("questions:", "class", "type", "labels") + for n = 1, self.qdcount do + if C.core_object_dns_parse_q(self, q, labels, num_labels) ~= 0 then + return + end + print("", Dns.class_tostring(q.class), Dns.type_tostring(q.type), label.tooffstr(self, labels, num_labels)) end - print("", "", Dns.class_tostring(q.class), Dns.type_tostring(q.type), label.tooffstr(self, labels, num_labels)) end - print("", "answers:") - for n = 1, self.ancount do - if C.core_object_dns_parse_rr(self, rr, labels, num_labels) ~= 0 then - return - end - if rr.rdata_labels == 0 then - print("", "", Dns.class_tostring(rr.class), Dns.type_tostring(rr.type), rr.ttl, label.tooffstr(self, labels, rr.labels)) - else - print("", "", Dns.class_tostring(rr.class), Dns.type_tostring(rr.type), rr.ttl, label.tooffstr(self, labels, rr.labels), label.tooffstr(self, labels, rr.rdata_labels, rr.labels)) + if self.ancount > 0 then + print("answers:", "class", "type", "ttl", "labels", "RR labels") + for n = 1, self.ancount do + if C.core_object_dns_parse_rr(self, rr, labels, num_labels) ~= 0 then + return + end + if rr.rdata_labels == 0 then + print("", Dns.class_tostring(rr.class), Dns.type_tostring(rr.type), rr.ttl, label.tooffstr(self, labels, rr.labels)) + else + print("", Dns.class_tostring(rr.class), Dns.type_tostring(rr.type), rr.ttl, label.tooffstr(self, labels, rr.labels), label.tooffstr(self, labels, rr.rdata_labels, rr.labels)) + end end end - print("", "authorities:") - for n = 1, self.nscount do - if C.core_object_dns_parse_rr(self, rr, labels, num_labels) ~= 0 then - return - end - if rr.rdata_labels == 0 then - print("", "", Dns.class_tostring(rr.class), Dns.type_tostring(rr.type), rr.ttl, label.tooffstr(self, labels, rr.labels)) - else - print("", "", Dns.class_tostring(rr.class), Dns.type_tostring(rr.type), rr.ttl, label.tooffstr(self, labels, rr.labels), label.tooffstr(self, labels, rr.rdata_labels, rr.labels)) + if self.nscount > 0 then + print("authorities:", "class", "type", "ttl", "labels", "RR labels") + for n = 1, self.nscount do + if C.core_object_dns_parse_rr(self, rr, labels, num_labels) ~= 0 then + return + end + if rr.rdata_labels == 0 then + print("", Dns.class_tostring(rr.class), Dns.type_tostring(rr.type), rr.ttl, label.tooffstr(self, labels, rr.labels)) + else + print("", Dns.class_tostring(rr.class), Dns.type_tostring(rr.type), rr.ttl, label.tooffstr(self, labels, rr.labels), label.tooffstr(self, labels, rr.rdata_labels, rr.labels)) + end end end - print("", "additionals:") - for n = 1, self.arcount do - if C.core_object_dns_parse_rr(self, rr, labels, num_labels) ~= 0 then - return - end - if rr.rdata_labels == 0 then - print("", "", Dns.class_tostring(rr.class), Dns.type_tostring(rr.type), rr.ttl, label.tooffstr(self, labels, rr.labels)) - else - print("", "", Dns.class_tostring(rr.class), Dns.type_tostring(rr.type), rr.ttl, label.tooffstr(self, labels, rr.labels), label.tooffstr(self, labels, rr.rdata_labels, rr.labels)) + if self.arcount > 0 then + print("additionals:", "class", "type", "ttl", "labels", "RR labels") + for n = 1, self.arcount do + if C.core_object_dns_parse_rr(self, rr, labels, num_labels) ~= 0 then + return + end + if rr.rdata_labels == 0 then + print("", Dns.class_tostring(rr.class), Dns.type_tostring(rr.type), rr.ttl, label.tooffstr(self, labels, rr.labels)) + else + print("", Dns.class_tostring(rr.class), Dns.type_tostring(rr.type), rr.ttl, label.tooffstr(self, labels, rr.labels), label.tooffstr(self, labels, rr.rdata_labels, rr.labels)) + end end end end diff --git a/src/output/tcpcli.c b/src/output/tcpcli.c index a22090df..f583b4d0 100644 --- a/src/output/tcpcli.c +++ b/src/output/tcpcli.c @@ -32,13 +32,15 @@ #include #include #include +#include static core_log_t _log = LOG_T_INIT("output.tcpcli"); static output_tcpcli_t _defaults = { LOG_T_INIT_OBJ("output.tcpcli"), 0, 0, -1, { 0 }, CORE_OBJECT_PAYLOAD_INIT(0), - 0, 0, 0 + 0, 0, 0, 0, + { 5, 0 }, 1 }; core_log_t* output_tcpcli_log() @@ -134,8 +136,10 @@ int output_tcpcli_set_nonblocking(output_tcpcli_t* self, int nonblocking) if (nonblocking) { flags |= O_NONBLOCK; + self->blocking = 0; } else { flags &= ~O_NONBLOCK; + self->blocking = 1; } if (fcntl(self->fd, F_SETFL, flags | O_NONBLOCK)) { @@ -166,13 +170,7 @@ static void _receive(output_tcpcli_t* self, const core_object_t* obj) return; } - if (len < 3 || payload[2] & 0x80) { - return; - } - - sent = 0; - self->pkts++; - + sent = 0; dnslen = htons(len); for (;;) { @@ -189,6 +187,7 @@ static void _receive(output_tcpcli_t* self, const core_object_t* obj) sent += ret; if (sent < len) continue; + self->pkts++; return; } switch (errno) { @@ -200,10 +199,10 @@ static void _receive(output_tcpcli_t* self, const core_object_t* obj) default: break; } - self->errs++; break; } - break; + self->errs++; + return; } switch (errno) { case EAGAIN: @@ -214,9 +213,9 @@ static void _receive(output_tcpcli_t* self, const core_object_t* obj) default: break; } - self->errs++; break; } + self->errs++; break; } } @@ -234,30 +233,84 @@ core_receiver_t output_tcpcli_receiver(output_tcpcli_t* self) static const core_object_t* _produce(output_tcpcli_t* self) { - ssize_t n, recv; - uint16_t dnslen; + ssize_t n, recv = 0; + uint16_t dnslen; + struct pollfd p; + int to; mlassert_self(); + // Check if last recvfrom() got more then we needed + if (!self->have_dnslen && self->recv > self->dnslen) { + recv = self->recv - self->dnslen; + if (recv < sizeof(dnslen)) { + memcpy(((uint8_t*)&dnslen), self->recvbuf + self->dnslen, recv); + } else { + memcpy(((uint8_t*)&dnslen), self->recvbuf + self->dnslen, sizeof(dnslen)); + + self->dnslen = ntohs(dnslen); + self->have_dnslen = 1; + + if (recv > sizeof(dnslen)) { + self->recv = recv - sizeof(dnslen); + memmove(self->recvbuf, self->recvbuf + self->dnslen + sizeof(dnslen), self->recv); + + if (self->recv > self->dnslen) { + self->pkts_recv++; + self->pkt.len = self->dnslen; + return (core_object_t*)&self->pkt; + } + } else { + self->recv = 0; + } + } + } + + if (self->blocking) { + p.fd = self->fd; + p.events = POLLIN; + p.revents = 0; + to = (self->timeout.sec * 1e3) + (self->timeout.nsec / 1e6); + if (!to) { + to = 1; + } + } + if (!self->have_dnslen) { - recv = 0; for (;;) { + n = poll(&p, 1, to); + if (n < 0 || (p.revents & (POLLERR | POLLHUP | POLLNVAL))) { + self->errs++; + return 0; + } + if (!n || !(p.revents & POLLIN)) { + if (recv) { + self->errs++; + return 0; + } + self->pkt.len = 0; + return (core_object_t*)&self->pkt; + } + n = recvfrom(self->fd, ((uint8_t*)&dnslen) + recv, sizeof(dnslen) - recv, 0, 0, 0); - if (n > -1) { + if (n > 0) { recv += n; if (recv < sizeof(dnslen)) continue; break; } + if (!n) { + break; + } switch (errno) { case EAGAIN: #if EAGAIN != EWOULDBLOCK case EWOULDBLOCK: #endif - n = 0; - break; + continue; default: break; } + self->errs++; break; } @@ -271,23 +324,37 @@ static const core_object_t* _produce(output_tcpcli_t* self) } for (;;) { - n = recvfrom(self->fd, self->recvbuf, sizeof(self->recvbuf), 0, 0, 0); - if (n > -1) { + n = poll(&p, 1, to); + if (n < 0 || (p.revents & (POLLERR | POLLHUP | POLLNVAL))) { + self->errs++; + return 0; + } + if (!n || !(p.revents & POLLIN)) { + self->pkt.len = 0; + return (core_object_t*)&self->pkt; + } + + n = recvfrom(self->fd, self->recvbuf + self->recv, sizeof(self->recvbuf) - self->recv, 0, 0, 0); + if (n > 0) { self->recv += n; if (self->recv < self->dnslen) continue; break; } + if (!n) { + break; + } switch (errno) { case EAGAIN: #if EAGAIN != EWOULDBLOCK case EWOULDBLOCK: #endif - n = 0; - break; + self->pkt.len = 0; + return (core_object_t*)&self->pkt; default: break; } + self->errs++; break; } @@ -295,10 +362,8 @@ static const core_object_t* _produce(output_tcpcli_t* self) return 0; } - // TODO: recv more then dnslen - - self->pkt.len = self->dnslen; - self->have_dnslen = 0; + self->pkts_recv++; + self->pkt.len = self->dnslen; return (core_object_t*)&self->pkt; } diff --git a/src/output/tcpcli.h b/src/output/tcpcli.h index 62103a0d..abdc7aff 100644 --- a/src/output/tcpcli.h +++ b/src/output/tcpcli.h @@ -22,6 +22,7 @@ #include "core/receiver.h" #include "core/producer.h" #include "core/object/payload.h" +#include "core/timespec.h" #ifndef __dnsjit_output_tcpcli_h #define __dnsjit_output_tcpcli_h diff --git a/src/output/tcpcli.hh b/src/output/tcpcli.hh index 750f911b..1ce94087 100644 --- a/src/output/tcpcli.hh +++ b/src/output/tcpcli.hh @@ -22,17 +22,21 @@ //lua:require("dnsjit.core.receiver_h") //lua:require("dnsjit.core.producer_h") //lua:require("dnsjit.core.object.payload_h") +//lua:require("dnsjit.core.timespec_h") typedef struct output_tcpcli { core_log_t _log; size_t pkts, errs; int fd; - uint8_t recvbuf[4 * 1024]; + uint8_t recvbuf[64 * 1024]; core_object_payload_t pkt; uint16_t dnslen; - unsigned short have_dnslen; - size_t recv; + uint8_t have_dnslen; + size_t recv, pkts_recv; + + core_timespec_t timeout; + int8_t blocking; } output_tcpcli_t; core_log_t* output_tcpcli_log(); diff --git a/src/output/tcpcli.lua b/src/output/tcpcli.lua index edf84f34..53dd1b39 100644 --- a/src/output/tcpcli.lua +++ b/src/output/tcpcli.lua @@ -17,12 +17,25 @@ -- along with dnsjit. If not, see . -- dnsjit.output.tcpcli --- Simple TCP DNS client +-- Simple, length aware, TCP client -- local output = require("dnsjit.output.tcpcli").new("127.0.0.1", "53") -- --- Simple DNS client that takes any payload you give it, look for the bit in --- the payload that says it's a DNS query, sends the length of the DNS and --- then sends the full payload over TCP. +-- Simple TCP client that takes any payload you give it, sends the length of +-- the payload as an unsigned 16 bit integer and then sends the payload. +-- When receiving it will first retrieve the length of the payload as an +-- unsigned 16 bit integer and it will stall until it gets, even if +-- nonblocking mode is used. +-- Then it will retrieve at least that amount of bytes, if nonblocking mode +-- is used here then it will return a payload object with length zero if +-- there was nothing to receive or if the full payload have not been received +-- yet. +-- Additional calls will continue retrieving the payload. +-- .SS Attributes +-- .TP +-- timeout +-- A +-- .I core.timespec +-- that is used when producing objects. module(...,package.seeall) require("dnsjit.output.tcpcli_h") @@ -33,25 +46,20 @@ local t_name = "output_tcpcli_t" local output_tcpcli_t = ffi.typeof(t_name) local Tcpcli = {} --- Create a new Tcpcli output. Optinally connect to the --- .I host --- and --- .IR port right away or use --- .BR connect () --- later on. -function Tcpcli.new(host, port) +-- Create a new Tcpcli output. +function Tcpcli.new() local self = { obj = output_tcpcli_t(), } C.output_tcpcli_init(self.obj) ffi.gc(self.obj, C.output_tcpcli_destroy) - self = setmetatable(self, { __index = Tcpcli }) - if host and port then - if self:connect(host, port) ~= 0 then - return - end - end - return self + return setmetatable(self, { __index = Tcpcli }) +end + +-- Set the timeout when producing objects. +function Tcpcli:timeout(seconds, nanoseconds) + self.obj.timeout.sec = seconds + self.obj.timeout.nsec = nanoseconds end -- Connect to the @@ -60,11 +68,7 @@ end -- .I port -- and return 0 if successful. function Tcpcli:connect(host, port) - local ret = C.output_tcpcli_connect(self.obj, host, port) - if ret == 0 then - ret = self:nonblocking(true) - end - return ret + return C.output_tcpcli_connect(self.obj, host, port) end -- Enable (true) or disable (false) nonblocking mode and @@ -92,15 +96,33 @@ end -- Return the C functions and context for producing objects, these objects -- are received. +-- If nonblocking mode is enabled the producer will return a payload object +-- with length zero if there was nothing to receive or if the full payload +-- have not been received yet. +-- If nonblocking mode is disabled the producer will wait for data and if +-- timed out (see +-- .IR timeout ) +-- it will return a payload object with length zero. +-- If a timeout happens during during the first stage, getting the length, it +-- will fail and return nil. +-- Additional calls will continue retrieving the payload. +-- The producer returns nil on error. function Tcpcli:produce() return C.output_tcpcli_producer(self.obj), self.obj end --- Return the number of queries we sent. +-- Return the number of "packets" sent, actually the number of completely sent +-- payloads. function Tcpcli:packets() return tonumber(self.obj.pkts) end +-- Return the number of "packets" received, actually the number of completely +-- received DNS messages. +function Tcpcli:received() + return tonumber(self.obj.pkts_recv) +end + -- Return the number of errors when sending. function Tcpcli:errors() return tonumber(self.obj.errs) diff --git a/src/output/udpcli.c b/src/output/udpcli.c index 3d9e8ac8..c155662e 100644 --- a/src/output/udpcli.c +++ b/src/output/udpcli.c @@ -29,13 +29,15 @@ #include #include #include +#include static core_log_t _log = LOG_T_INIT("output.udpcli"); static output_udpcli_t _defaults = { LOG_T_INIT_OBJ("output.udpcli"), 0, 0, -1, { 0 }, 0, - { 0 }, CORE_OBJECT_PAYLOAD_INIT(0) + { 0 }, CORE_OBJECT_PAYLOAD_INIT(0), 0, + { 5, 0 }, 1 }; core_log_t* output_udpcli_log() @@ -127,8 +129,10 @@ int output_udpcli_set_nonblocking(output_udpcli_t* self, int nonblocking) if (nonblocking) { flags |= O_NONBLOCK; + self->blocking = 0; } else { flags &= ~O_NONBLOCK; + self->blocking = 1; } if (fcntl(self->fd, F_SETFL, flags | O_NONBLOCK)) { @@ -158,18 +162,14 @@ static void _receive(output_udpcli_t* self, const core_object_t* obj) return; } - if (len < 3 || payload[2] & 0x80) { - return; - } - sent = 0; - self->pkts++; for (;;) { ssize_t ret = sendto(self->fd, payload + sent, len - sent, 0, (struct sockaddr*)&self->addr, self->addr_len); if (ret > -1) { sent += ret; if (sent < len) continue; + self->pkts++; return; } switch (errno) { @@ -181,9 +181,9 @@ static void _receive(output_udpcli_t* self, const core_object_t* obj) default: break; } - self->errs++; break; } + self->errs++; break; } } @@ -214,11 +214,65 @@ static const core_object_t* _produce(output_udpcli_t* self) #if EAGAIN != EWOULDBLOCK case EWOULDBLOCK: #endif - n = 0; + self->pkt.len = 0; + return (core_object_t*)&self->pkt; + default: + break; + } + self->errs++; + break; + } + + if (n < 1) { + return 0; + } + + self->pkts_recv++; + self->pkt.len = n; + return (core_object_t*)&self->pkt; +} + +static const core_object_t* _produce_block(output_udpcli_t* self) +{ + ssize_t n; + struct pollfd p; + int to; + mlassert_self(); + + p.fd = self->fd; + p.events = POLLIN; + p.revents = 0; + to = (self->timeout.sec * 1e3) + (self->timeout.nsec / 1e6); + if (!to) { + to = 1; + } + + n = poll(&p, 1, to); + if (n < 0 || (p.revents & (POLLERR | POLLHUP | POLLNVAL))) { + self->errs++; + return 0; + } + if (!n || !(p.revents & POLLIN)) { + self->pkt.len = 0; + return (core_object_t*)&self->pkt; + } + + for (;;) { + n = recvfrom(self->fd, self->recvbuf, sizeof(self->recvbuf), 0, 0, 0); + if (n > -1) { break; + } + switch (errno) { + case EAGAIN: +#if EAGAIN != EWOULDBLOCK + case EWOULDBLOCK: +#endif + self->pkt.len = 0; + return (core_object_t*)&self->pkt; default: break; } + self->errs++; break; } @@ -226,6 +280,7 @@ static const core_object_t* _produce(output_udpcli_t* self) return 0; } + self->pkts_recv++; self->pkt.len = n; return (core_object_t*)&self->pkt; } @@ -238,5 +293,8 @@ core_producer_t output_udpcli_producer(output_udpcli_t* self) lfatal("not connected"); } + if (self->blocking) { + return (core_producer_t)_produce_block; + } return (core_producer_t)_produce; } diff --git a/src/output/udpcli.h b/src/output/udpcli.h index bdc98da0..79febf3d 100644 --- a/src/output/udpcli.h +++ b/src/output/udpcli.h @@ -22,6 +22,7 @@ #include "core/receiver.h" #include "core/producer.h" #include "core/object/payload.h" +#include "core/timespec.h" #ifndef __dnsjit_output_udpcli_h #define __dnsjit_output_udpcli_h diff --git a/src/output/udpcli.hh b/src/output/udpcli.hh index 6cd3f00d..f61e5536 100644 --- a/src/output/udpcli.hh +++ b/src/output/udpcli.hh @@ -23,6 +23,7 @@ //lua:require("dnsjit.core.receiver_h") //lua:require("dnsjit.core.producer_h") //lua:require("dnsjit.core.object.payload_h") +//lua:require("dnsjit.core.timespec_h") typedef struct output_udpcli { core_log_t _log; @@ -34,6 +35,10 @@ typedef struct output_udpcli { uint8_t recvbuf[4 * 1024]; core_object_payload_t pkt; + size_t pkts_recv; + + core_timespec_t timeout; + int8_t blocking; } output_udpcli_t; core_log_t* output_udpcli_log(); diff --git a/src/output/udpcli.lua b/src/output/udpcli.lua index e6b65168..9f0797ca 100644 --- a/src/output/udpcli.lua +++ b/src/output/udpcli.lua @@ -20,9 +20,14 @@ -- Simple and dumb UDP DNS client -- local output = require("dnsjit.output.udpcli").new("127.0.0.1", "53") -- --- Simple and rather dumb DNS client that takes any payload you give it, --- look for the bit in the payload that says it's a DNS query and sends --- the full payload over UDP if it is. +-- Simple and rather dumb DNS client that takes any payload you give it and +-- sends the full payload over UDP. +-- .SS Attributes +-- .TP +-- timeout +-- A +-- .I core.timespec +-- that is used when producing objects. module(...,package.seeall) require("dnsjit.output.udpcli_h") @@ -33,25 +38,20 @@ local t_name = "output_udpcli_t" local output_udpcli_t = ffi.typeof(t_name) local Udpcli = {} --- Create a new Udpcli output. Optinally connect to the --- .I host --- and --- .IR port right away or use --- .BR connect () --- later on. -function Udpcli.new(host, port) +-- Create a new Udpcli output. +function Udpcli.new() local self = { obj = output_udpcli_t(), } C.output_udpcli_init(self.obj) ffi.gc(self.obj, C.output_udpcli_destroy) - self = setmetatable(self, { __index = Udpcli }) - if host and port then - if self:connect(host, port) ~= 0 then - return - end - end - return self + return setmetatable(self, { __index = Udpcli }) +end + +-- Set the timeout when producing objects. +function Udpcli:timeout(seconds, nanoseconds) + self.obj.timeout.sec = seconds + self.obj.timeout.nsec = nanoseconds end -- Connect to the @@ -60,11 +60,7 @@ end -- .I port -- and return 0 if successful. function Udpcli:connect(host, port) - local ret = C.output_udpcli_connect(self.obj, host, port) - if ret == 0 then - ret = self:nonblocking(true) - end - return ret + return C.output_udpcli_connect(self.obj, host, port) end -- Enable (true) or disable (false) nonblocking mode and @@ -92,16 +88,32 @@ end -- Return the C functions and context for producing objects, these objects -- are received. +-- If nonblocking mode is enabled the producer will return a payload object +-- with length zero if there was nothing to receive. +-- If nonblocking mode is disabled the producer will wait for data and if +-- timed out (see +-- .IR timeout ) +-- it will return a payload object with length zero. +-- The producer returns nil on error. function Udpcli:produce() return C.output_udpcli_producer(self.obj), self.obj end --- Return the number of queries we sent. +-- Return the number of "packets" sent, actually the number of completely sent +-- payloads. function Udpcli:packets() return tonumber(self.obj.pkts) end --- Return the number of errors when sending. +-- Return the number of "packets" received, actually the number of successful +-- calls to +-- .IR recvfrom (2) +-- that returned data. +function Udpcli:received() + return tonumber(self.obj.pkts_recv) +end + +-- Return the number of errors when sending or receiving. function Udpcli:errors() return tonumber(self.obj.errs) end