Skip to content

Commit

Permalink
fix data race sending mail (#82)
Browse files Browse the repository at this point in the history
* fix data race sending mail when timeout exceed

* change localhost to 127.0.0.1 in test

* minimal fixes in test

* reduce loop logic
  • Loading branch information
xhit authored Jul 6, 2023
1 parent 6250c42 commit 6b425f7
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 13 deletions.
37 changes: 24 additions & 13 deletions email.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"net/textproto"
"strconv"
"strings"
"sync"
"time"

"github.com/toorop/go-dkim"
Expand Down Expand Up @@ -55,6 +56,7 @@ type SMTPServer struct {

// SMTPClient represents a SMTP Client for send email
type SMTPClient struct {
mu sync.Mutex
Client *smtpClient
KeepAlive bool
SendTimeout time.Duration
Expand Down Expand Up @@ -865,21 +867,29 @@ func (server *SMTPServer) Connect() (*SMTPClient, error) {

// Reset send RSET command to smtp client
func (smtpClient *SMTPClient) Reset() error {
smtpClient.mu.Lock()
defer smtpClient.mu.Unlock()
return smtpClient.Client.reset()
}

// Noop send NOOP command to smtp client
func (smtpClient *SMTPClient) Noop() error {
smtpClient.mu.Lock()
defer smtpClient.mu.Unlock()
return smtpClient.Client.noop()
}

// Quit send QUIT command to smtp client
func (smtpClient *SMTPClient) Quit() error {
smtpClient.mu.Lock()
defer smtpClient.mu.Unlock()
return smtpClient.Client.quit()
}

// Close closes the connection
func (smtpClient *SMTPClient) Close() error {
smtpClient.mu.Lock()
defer smtpClient.mu.Unlock()
return smtpClient.Client.close()
}

Expand Down Expand Up @@ -909,14 +919,14 @@ func send(from string, to []string, msg string, client *SMTPClient) error {
if client.SendTimeout != 0 {
smtpSendChannel = make(chan error, 1)

go func(from string, to []string, msg string, c *smtpClient) {
smtpSendChannel <- sendMailProcess(from, to, msg, c)
}(from, to, msg, client.Client)
go func(from string, to []string, msg string, client *SMTPClient) {
smtpSendChannel <- sendMailProcess(from, to, msg, client)
}(from, to, msg, client)
}

if client.SendTimeout == 0 {
// no SendTimeout, just fire the sendMailProcess
return sendMailProcess(from, to, msg, client.Client)
return sendMailProcess(from, to, msg, client)
}

// get the send result or timeout result, which ever happens first
Expand All @@ -928,35 +938,36 @@ func send(from string, to []string, msg string, client *SMTPClient) error {
checkKeepAlive(client)
return errors.New("Mail Error: SMTP Send timed out")
}

}
}

return errors.New("Mail Error: No SMTP Client Provided")
}

func sendMailProcess(from string, to []string, msg string, c *smtpClient) error {
func sendMailProcess(from string, to []string, msg string, c *SMTPClient) error {
c.mu.Lock()
defer c.mu.Unlock()

cmdArgs := make(map[string]string)

if _, ok := c.ext["SIZE"]; ok {
if _, ok := c.Client.ext["SIZE"]; ok {
cmdArgs["SIZE"] = strconv.Itoa(len(msg))
}

// Set the sender
if err := c.mail(from, cmdArgs); err != nil {
if err := c.Client.mail(from, cmdArgs); err != nil {
return err
}

// Set the recipients
for _, address := range to {
if err := c.rcpt(address); err != nil {
if err := c.Client.rcpt(address); err != nil {
return err
}
}

// Send the data command
w, err := c.data()
w, err := c.Client.data()
if err != nil {
return err
}
Expand All @@ -978,9 +989,9 @@ func sendMailProcess(from string, to []string, msg string, c *smtpClient) error
// check if keepAlive for close or reset
func checkKeepAlive(client *SMTPClient) {
if client.KeepAlive {
client.Client.reset()
client.Reset()
} else {
client.Client.quit()
client.Client.close()
client.Quit()
client.Close()
}
}
112 changes: 112 additions & 0 deletions email_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package mail

import (
"fmt"
"log"
"net"
"testing"
"time"
)

func TestSendRace(t *testing.T) {
port := 56666
port2 := 56667
timeout := 1 * time.Second

responses := []string{
`220 test connected`,
`250 after helo`,
`250 after mail from`,
`250 after rcpt to`,
`354 after data`,
}

startService(port, responses, 5*time.Second)
startService(port2, responses, 0)

server := NewSMTPClient()
server.ConnectTimeout = timeout
server.SendTimeout = timeout
server.KeepAlive = false
server.Host = `127.0.0.1`
server.Port = port

smtpClient, err := server.Connect()
if err != nil {
log.Fatalf("couldn't connect: %s", err.Error())
}
defer smtpClient.Close()

// create another server in other port to test timeouts
server.Port = port2
smtpClient2, err := server.Connect()
if err != nil {
log.Fatalf("couldn't connect: %s", err.Error())
}
defer smtpClient2.Close()

msg := NewMSG().
SetFrom(`foo@bar`).
AddTo(`rcpt@bar`).
SetSubject("subject").
SetBody(TextPlain, "body")

// the smtpClient2 has not timeout
err = msg.Send(smtpClient2)
if err != nil {
log.Fatalf("couldn't send: %s", err.Error())
}

// the smtpClient send to listener with the last response is after SendTimeout, so when this error is returned the test succeed.
err = msg.Send(smtpClient)
if err != nil && err.Error() != "Mail Error: SMTP Send timed out" {
log.Fatalf("couldn't send: %s", err.Error())
}
}

func startService(port int, responses []string, timeout time.Duration) {
log.Printf("starting service at %d...\n", port)
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
log.Fatalf("couldn't listen to port %d: %s", port, err)
}

go func() {
for {
conn, err := listener.Accept()
if err != nil {
log.Fatalf("couldn't listen accept the request in port %d", port)
}
go respond(conn, responses, timeout)
}
}()
}

func respond(conn net.Conn, responses []string, timeout time.Duration) {
buf := make([]byte, 1024)
for _, resp := range responses {
write(conn, resp)
n, err := conn.Read(buf)
if err != nil {
log.Println("couldn't read data")
return
}
readStr := string(buf[:n])
log.Printf("READ:%s", string(readStr))
}

// if timeout, sleep for that time, otherwise sent a 250 OK
if timeout > 0 {
time.Sleep(timeout)
} else {
write(conn, "250 OK")
}

conn.Close()
fmt.Print("\n\n")
}

func write(conn net.Conn, command string) {
log.Printf("WRITE:%s", command)
conn.Write([]byte(command + "\n"))
}

0 comments on commit 6b425f7

Please sign in to comment.