diff --git a/hkexnet/hkexnet.go b/hkexnet/hkexnet.go index b58297a..7dc763d 100644 --- a/hkexnet/hkexnet.go +++ b/hkexnet/hkexnet.go @@ -1080,3 +1080,74 @@ func (hc *Conn) chaffHelper() { } }() } + +//////////////////////////////////////////////////// + +// Copy copies from src to dst until either EOF is reached +// on src or an error occurs. It returns the number of bytes +// copied and the first error encountered while copying, if any. +// +// A successful Copy returns err == nil, not err == EOF. +// Because Copy is defined to read from src until EOF, it does +// not treat an EOF from Read as an error to be reported. +// +// If src implements the WriterTo interface, +// the copy is implemented by calling src.WriteTo(dst). +// Otherwise, if dst implements the ReaderFrom interface, +// the copy is implemented by calling dst.ReadFrom(src). +func Copy(dst io.Writer, src io.Reader) (written int64, err error) { + written, err = copyBuffer(dst, src, nil) + return +} + +// copyBuffer is the actual implementation of Copy and CopyBuffer. +// if buf is nil, one is allocated. +func copyBuffer(dst io.Writer, src io.Reader, buf []byte) (written int64, err error) { + // If the reader has a WriteTo method, use it to do the copy. + // Avoids an allocation and a copy. + if wt, ok := src.(io.WriterTo); ok { + return wt.WriteTo(dst) + } + // Similarly, if the writer has a ReadFrom method, use it to do the copy. + if rt, ok := dst.(io.ReaderFrom); ok { + return rt.ReadFrom(src) + } + if buf == nil { + size := 32 * 1024 + if l, ok := src.(*io.LimitedReader); ok && int64(size) > l.N { + if l.N < 1 { + size = 1 + } else { + size = int(l.N) + } + } + buf = make([]byte, size) + } + for { + nr, er := src.Read(buf) + if nr > 0 { + nw, ew := dst.Write(buf[0:nr]) + if nw > 0 { + written += int64(nw) + } + if ew != nil { + err = ew + break + } + if nr != nw { + err = io.ErrShortWrite + break + } + } + if er != nil { + if er != io.EOF { + err = er + } + break + } + } + //_,_ = dst.Write([]byte{0x2f}) + return written, err +} + +//////////////////////////////////////////////////// diff --git a/hkexsh/hkexsh.go b/hkexsh/hkexsh.go index 62aae20..f32d0be 100755 --- a/hkexsh/hkexsh.go +++ b/hkexsh/hkexsh.go @@ -239,14 +239,17 @@ func doShellMode(isInteractive bool, conn *hkexnet.Conn, oldState *hkexsh.State, // #gv:s/label=\"doShellMode\$1\"/label=\"shellRemoteToStdin\"/ // TODO:.gv:doShellMode:1:shellRemoteToStdin shellRemoteToStdin := func() { - defer wg.Done() + defer func() { + wg.Done() + }() + // By deferring a call to wg.Done(), // each goroutine guarantees that it marks // its direction's stream as finished. - // io.Copy() expects EOF so normally this will + // pkg io/Copy expects EOF so normally this will // exit with inerr == nil - _, inerr := io.Copy(os.Stdout, conn) + _, inerr := hkexnet.Copy(os.Stdout, conn) if inerr != nil { _ = hkexsh.Restore(int(os.Stdin.Fd()), oldState) // #nosec // Copy operations and user logging off will cause @@ -264,6 +267,7 @@ func doShellMode(isInteractive bool, conn *hkexnet.Conn, oldState *hkexsh.State, if isInteractive { log.Println("[* Got EOF *]") _ = hkexsh.Restore(int(os.Stdin.Fd()), oldState) // #nosec + os.Exit(int(rec.Status())) } } go shellRemoteToStdin() @@ -281,10 +285,9 @@ func doShellMode(isInteractive bool, conn *hkexnet.Conn, oldState *hkexsh.State, shellStdinToRemote := func() { defer wg.Done() //!defer wg.Done() - // Copy() expects EOF so this will - // exit with outerr == nil - //!_, outerr := io.Copy(conn, os.Stdin) _, outerr := func(conn *hkexnet.Conn, r io.Reader) (w int64, e error) { + // Copy() expects EOF so this will + // exit with outerr == nil w, e = io.Copy(conn, r) return w, e }(conn, os.Stdin)