|
diff --git a/dns.go b/dns.go |
|
index 9421b5a..f007f3f 100644 |
|
--- a/dns.go |
|
+++ b/dns.go |
|
@@ -237,6 +237,10 @@ type Msg struct { |
|
// over the wire. Note that this data is a snapshot of the Msg when it was packed or unpacked. |
|
Data []byte |
|
hijacked atomic.Bool // pool's allocation has been hijacked by caller |
|
+ |
|
+ // state of the unpacking, see Msg.Unpack. |
|
+ unpackedOffset int |
|
+ nextSection int |
|
} |
|
|
|
// Option is an option on how to handle a message. Options can be combined, but that have to be "in order", if |
|
diff --git a/msg.go b/msg.go |
|
index aa24461..46035ed 100644 |
|
--- a/msg.go |
|
+++ b/msg.go |
|
@@ -279,6 +279,8 @@ func (m *Msg) Pack() error { |
|
} |
|
} |
|
m.Data = m.Data[:off] |
|
+ m.unpackedOffset = off |
|
+ m.nextSection = 5 // Additional section done |
|
return nil |
|
} |
|
|
|
@@ -343,34 +345,49 @@ func unpackRRs(cnt uint16, msg *cryptobyte.String, msgBuf []byte) ([]RR, error) |
|
return dst, nil |
|
} |
|
|
|
-func (m *Msg) unpack(dh header, msg, msgBuf []byte) error { |
|
- s := cryptobyte.String(msg) |
|
+func (m *Msg) unpack(dh header, s cryptobyte.String, msgBuf []byte) error { |
|
var err error |
|
|
|
- m.Question, err = m.unpackQuestions(dh.Qdcount, &s, msgBuf) |
|
- if err != nil { |
|
- return err |
|
+ if m.nextSection <= 1 { |
|
+ m.Question, err = m.unpackQuestions(dh.Qdcount, &s, msgBuf) |
|
+ if err != nil { |
|
+ return err |
|
+ } |
|
+ m.unpackedOffset = len(msgBuf) - len(s) |
|
+ m.nextSection = 2 |
|
} |
|
if m.Options > 0 && m.Options <= MsgOptionUnpackQuestion { |
|
return nil |
|
} |
|
|
|
- m.Answer, err = unpackRRs(dh.Ancount, &s, msgBuf) |
|
- if err != nil { |
|
- return err |
|
+ if m.nextSection <= 2 { |
|
+ m.Answer, err = unpackRRs(dh.Ancount, &s, msgBuf) |
|
+ if err != nil { |
|
+ return err |
|
+ } |
|
+ m.unpackedOffset = len(msgBuf) - len(s) |
|
+ m.nextSection = 3 |
|
} |
|
if m.Options > 0 && m.Options <= MsgOptionUnpackAnswer { |
|
return nil |
|
} |
|
|
|
- m.Ns, err = unpackRRs(dh.Nscount, &s, msgBuf) |
|
- if err != nil { |
|
- return err |
|
+ if m.nextSection <= 3 { |
|
+ m.Ns, err = unpackRRs(dh.Nscount, &s, msgBuf) |
|
+ if err != nil { |
|
+ return err |
|
+ } |
|
+ m.unpackedOffset = len(msgBuf) - len(s) |
|
+ m.nextSection = 4 |
|
} |
|
|
|
- m.Extra, err = unpackRRs(dh.Arcount, &s, msgBuf) |
|
- if err != nil { |
|
- return err |
|
+ if m.nextSection <= 4 { |
|
+ m.Extra, err = unpackRRs(dh.Arcount, &s, msgBuf) |
|
+ if err != nil { |
|
+ return err |
|
+ } |
|
+ m.unpackedOffset = len(msgBuf) - len(s) |
|
+ m.nextSection = 5 |
|
} |
|
|
|
// Check for the OPT RR and remove it entirely, unpack the OPT for option codes and put those in the Pseudo |
|
@@ -418,12 +435,24 @@ func (m *Msg) unpack(dh header, msg, msgBuf []byte) error { |
|
|
|
// Unpack unpacks a binary message that sits in m.Data to a Msg structure. |
|
func (m *Msg) Unpack() error { |
|
- s := cryptobyte.String(m.Data) |
|
var dh header |
|
+ s := cryptobyte.String(m.Data) |
|
if !dh.unpack(&s) { |
|
return unpack.Errorf("overflow %s", "MsgHeader") |
|
} |
|
m.setMsgHeader(dh) |
|
+ |
|
+ // If we have unpacked something, we can resume. |
|
+ if m.unpackedOffset > MsgHeaderSize && len(m.Data) >= m.unpackedOffset { |
|
+ if !s.Skip(m.unpackedOffset - MsgHeaderSize) { |
|
+ // This should not happen if unpackedOffset is valid and verified against len(Data) |
|
+ return unpack.ErrTruncatedMessage |
|
+ } |
|
+ } else { |
|
+ m.unpackedOffset = MsgHeaderSize |
|
+ m.nextSection = 1 |
|
+ } |
|
+ |
|
if m.Options > 0 && m.Options <= MsgOptionUnpackHeader { |
|
return nil |
|
} |
|
@@ -673,6 +702,8 @@ func (m *Msg) Write(p []byte) (n int, err error) { |
|
} |
|
|
|
n = copy(m.Data, p) |
|
+ m.unpackedOffset = 0 |
|
+ m.nextSection = 0 |
|
return n, nil |
|
} |
|
|
|
@@ -758,6 +789,8 @@ func (m *Msg) ReadFrom(r io.Reader) (int64, error) { |
|
return 0, err |
|
} |
|
m.Data = m.Data[:n] |
|
+ m.unpackedOffset = 0 |
|
+ m.nextSection = 0 |
|
return int64(n), nil |
|
} |
|
|
|
@@ -784,6 +817,8 @@ func (m *Msg) ReadFrom(r io.Reader) (int64, error) { |
|
if err != nil { |
|
m.Data = m.Data[:n] |
|
} |
|
+ m.unpackedOffset = 0 |
|
+ m.nextSection = 0 |
|
return int64(n), err |
|
} |
|
|
|
@@ -834,8 +869,10 @@ func (m *Msg) Copy() *Msg { |
|
Ns: m.Ns, |
|
Extra: m.Extra, |
|
Pseudo: m.Pseudo, |
|
- Data: m.Data, |
|
- msgPool: m.msgPool, |
|
+ Data: m.Data, |
|
+ msgPool: m.msgPool, |
|
+ unpackedOffset: m.unpackedOffset, |
|
+ nextSection: m.nextSection, |
|
} |
|
} |
|
|
|
diff --git a/server.go b/server.go |
|
index cfd8f66..58aded5 100644 |
|
--- a/server.go |
|
+++ b/server.go |
|
@@ -311,7 +311,7 @@ func (srv *Server) serveTCP(wg *sync.WaitGroup, conn net.Conn) { |
|
// serveDNS serves the message it skip the message handling if the received message has the response bit set. |
|
func (srv *Server) serveDNS(w *response, r *Msg) { |
|
r.msgPool = srv.MsgPool |
|
- r.Options = MsgOptionUnpackQuestion | MsgOptionUnpackHeader |
|
+ r.Options = MsgOptionUnpackQuestion |
|
|
|
if err := r.Unpack(); err != nil { |
|
srv.MsgInvalidFunc(r, err) |