Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found
Select Git revision

Target

Select target project
  • oss/libraries/go/services/job-queues
1 result
Select Git revision
Show changes
Showing
with 635 additions and 466 deletions
//go:build darwin || dragonfly || freebsd || (!android && linux) || netbsd || openbsd || solaris || aix || js //go:build darwin || dragonfly || freebsd || (!android && linux) || netbsd || openbsd || solaris || aix || js || zos
// +build darwin dragonfly freebsd !android,linux netbsd openbsd solaris aix js // +build darwin dragonfly freebsd !android,linux netbsd openbsd solaris aix js zos
package sftp package sftp
......
...@@ -2,6 +2,7 @@ package sftp ...@@ -2,6 +2,7 @@ package sftp
import ( import (
"bytes" "bytes"
"context"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
...@@ -256,7 +257,7 @@ func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Clie ...@@ -256,7 +257,7 @@ func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Clie
// read/write at the same time. For those services you will need to use // read/write at the same time. For those services you will need to use
// `client.OpenFile(os.O_WRONLY|os.O_CREATE|os.O_TRUNC)`. // `client.OpenFile(os.O_WRONLY|os.O_CREATE|os.O_TRUNC)`.
func (c *Client) Create(path string) (*File, error) { func (c *Client) Create(path string) (*File, error) {
return c.open(path, flags(os.O_RDWR|os.O_CREATE|os.O_TRUNC)) return c.open(path, toPflags(os.O_RDWR|os.O_CREATE|os.O_TRUNC))
} }
const sftpProtocolVersion = 3 // https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt const sftpProtocolVersion = 3 // https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt
...@@ -321,19 +322,27 @@ func (c *Client) Walk(root string) *fs.Walker { ...@@ -321,19 +322,27 @@ func (c *Client) Walk(root string) *fs.Walker {
return fs.WalkFS(root, c) return fs.WalkFS(root, c)
} }
// ReadDir reads the directory named by dirname and returns a list of // ReadDir reads the directory named by p
// directory entries. // and returns a list of directory entries.
func (c *Client) ReadDir(p string) ([]os.FileInfo, error) { func (c *Client) ReadDir(p string) ([]os.FileInfo, error) {
handle, err := c.opendir(p) return c.ReadDirContext(context.Background(), p)
}
// ReadDirContext reads the directory named by p
// and returns a list of directory entries.
// The passed context can be used to cancel the operation
// returning all entries listed up to the cancellation.
func (c *Client) ReadDirContext(ctx context.Context, p string) ([]os.FileInfo, error) {
handle, err := c.opendir(ctx, p)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer c.close(handle) // this has to defer earlier than the lock below defer c.close(handle) // this has to defer earlier than the lock below
var attrs []os.FileInfo var entries []os.FileInfo
var done = false var done = false
for !done { for !done {
id := c.nextID() id := c.nextID()
typ, data, err1 := c.sendPacket(nil, &sshFxpReaddirPacket{ typ, data, err1 := c.sendPacket(ctx, nil, &sshFxpReaddirPacket{
ID: id, ID: id,
Handle: handle, Handle: handle,
}) })
...@@ -354,11 +363,14 @@ func (c *Client) ReadDir(p string) ([]os.FileInfo, error) { ...@@ -354,11 +363,14 @@ func (c *Client) ReadDir(p string) ([]os.FileInfo, error) {
filename, data = unmarshalString(data) filename, data = unmarshalString(data)
_, data = unmarshalString(data) // discard longname _, data = unmarshalString(data) // discard longname
var attr *FileStat var attr *FileStat
attr, data = unmarshalAttrs(data) attr, data, err = unmarshalAttrs(data)
if err != nil {
return nil, err
}
if filename == "." || filename == ".." { if filename == "." || filename == ".." {
continue continue
} }
attrs = append(attrs, fileInfoFromStat(attr, path.Base(filename))) entries = append(entries, fileInfoFromStat(attr, path.Base(filename)))
} }
case sshFxpStatus: case sshFxpStatus:
// TODO(dfc) scope warning! // TODO(dfc) scope warning!
...@@ -371,12 +383,12 @@ func (c *Client) ReadDir(p string) ([]os.FileInfo, error) { ...@@ -371,12 +383,12 @@ func (c *Client) ReadDir(p string) ([]os.FileInfo, error) {
if err == io.EOF { if err == io.EOF {
err = nil err = nil
} }
return attrs, err return entries, err
} }
func (c *Client) opendir(path string) (string, error) { func (c *Client) opendir(ctx context.Context, path string) (string, error) {
id := c.nextID() id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpOpendirPacket{ typ, data, err := c.sendPacket(ctx, nil, &sshFxpOpendirPacket{
ID: id, ID: id,
Path: path, Path: path,
}) })
...@@ -412,7 +424,7 @@ func (c *Client) Stat(p string) (os.FileInfo, error) { ...@@ -412,7 +424,7 @@ func (c *Client) Stat(p string) (os.FileInfo, error) {
// If 'p' is a symbolic link, the returned FileInfo structure describes the symbolic link. // If 'p' is a symbolic link, the returned FileInfo structure describes the symbolic link.
func (c *Client) Lstat(p string) (os.FileInfo, error) { func (c *Client) Lstat(p string) (os.FileInfo, error) {
id := c.nextID() id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpLstatPacket{ typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpLstatPacket{
ID: id, ID: id,
Path: p, Path: p,
}) })
...@@ -425,7 +437,11 @@ func (c *Client) Lstat(p string) (os.FileInfo, error) { ...@@ -425,7 +437,11 @@ func (c *Client) Lstat(p string) (os.FileInfo, error) {
if sid != id { if sid != id {
return nil, &unexpectedIDErr{id, sid} return nil, &unexpectedIDErr{id, sid}
} }
attr, _ := unmarshalAttrs(data) attr, _, err := unmarshalAttrs(data)
if err != nil {
// avoid returning a valid value from fileInfoFromStats if err != nil.
return nil, err
}
return fileInfoFromStat(attr, path.Base(p)), nil return fileInfoFromStat(attr, path.Base(p)), nil
case sshFxpStatus: case sshFxpStatus:
return nil, normaliseError(unmarshalStatus(id, data)) return nil, normaliseError(unmarshalStatus(id, data))
...@@ -437,7 +453,7 @@ func (c *Client) Lstat(p string) (os.FileInfo, error) { ...@@ -437,7 +453,7 @@ func (c *Client) Lstat(p string) (os.FileInfo, error) {
// ReadLink reads the target of a symbolic link. // ReadLink reads the target of a symbolic link.
func (c *Client) ReadLink(p string) (string, error) { func (c *Client) ReadLink(p string) (string, error) {
id := c.nextID() id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpReadlinkPacket{ typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpReadlinkPacket{
ID: id, ID: id,
Path: p, Path: p,
}) })
...@@ -466,7 +482,7 @@ func (c *Client) ReadLink(p string) (string, error) { ...@@ -466,7 +482,7 @@ func (c *Client) ReadLink(p string) (string, error) {
// Link creates a hard link at 'newname', pointing at the same inode as 'oldname' // Link creates a hard link at 'newname', pointing at the same inode as 'oldname'
func (c *Client) Link(oldname, newname string) error { func (c *Client) Link(oldname, newname string) error {
id := c.nextID() id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpHardlinkPacket{ typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpHardlinkPacket{
ID: id, ID: id,
Oldpath: oldname, Oldpath: oldname,
Newpath: newname, Newpath: newname,
...@@ -485,7 +501,7 @@ func (c *Client) Link(oldname, newname string) error { ...@@ -485,7 +501,7 @@ func (c *Client) Link(oldname, newname string) error {
// Symlink creates a symbolic link at 'newname', pointing at target 'oldname' // Symlink creates a symbolic link at 'newname', pointing at target 'oldname'
func (c *Client) Symlink(oldname, newname string) error { func (c *Client) Symlink(oldname, newname string) error {
id := c.nextID() id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpSymlinkPacket{ typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpSymlinkPacket{
ID: id, ID: id,
Linkpath: newname, Linkpath: newname,
Targetpath: oldname, Targetpath: oldname,
...@@ -501,9 +517,9 @@ func (c *Client) Symlink(oldname, newname string) error { ...@@ -501,9 +517,9 @@ func (c *Client) Symlink(oldname, newname string) error {
} }
} }
func (c *Client) setfstat(handle string, flags uint32, attrs interface{}) error { func (c *Client) fsetstat(handle string, flags uint32, attrs interface{}) error {
id := c.nextID() id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpFsetstatPacket{ typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpFsetstatPacket{
ID: id, ID: id,
Handle: handle, Handle: handle,
Flags: flags, Flags: flags,
...@@ -523,7 +539,7 @@ func (c *Client) setfstat(handle string, flags uint32, attrs interface{}) error ...@@ -523,7 +539,7 @@ func (c *Client) setfstat(handle string, flags uint32, attrs interface{}) error
// setstat is a convience wrapper to allow for changing of various parts of the file descriptor. // setstat is a convience wrapper to allow for changing of various parts of the file descriptor.
func (c *Client) setstat(path string, flags uint32, attrs interface{}) error { func (c *Client) setstat(path string, flags uint32, attrs interface{}) error {
id := c.nextID() id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpSetstatPacket{ typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpSetstatPacket{
ID: id, ID: id,
Path: path, Path: path,
Flags: flags, Flags: flags,
...@@ -577,23 +593,37 @@ func (c *Client) Truncate(path string, size int64) error { ...@@ -577,23 +593,37 @@ func (c *Client) Truncate(path string, size int64) error {
return c.setstat(path, sshFileXferAttrSize, uint64(size)) return c.setstat(path, sshFileXferAttrSize, uint64(size))
} }
// SetExtendedData sets extended attributes of the named file. It uses the
// SSH_FILEXFER_ATTR_EXTENDED flag in the setstat request.
//
// This flag provides a general extension mechanism for vendor-specific extensions.
// Names of the attributes should be a string of the format "name@domain", where "domain"
// is a valid, registered domain name and "name" identifies the method. Server
// implementations SHOULD ignore extended data fields that they do not understand.
func (c *Client) SetExtendedData(path string, extended []StatExtended) error {
attrs := &FileStat{
Extended: extended,
}
return c.setstat(path, sshFileXferAttrExtended, attrs)
}
// Open opens the named file for reading. If successful, methods on the // Open opens the named file for reading. If successful, methods on the
// returned file can be used for reading; the associated file descriptor // returned file can be used for reading; the associated file descriptor
// has mode O_RDONLY. // has mode O_RDONLY.
func (c *Client) Open(path string) (*File, error) { func (c *Client) Open(path string) (*File, error) {
return c.open(path, flags(os.O_RDONLY)) return c.open(path, toPflags(os.O_RDONLY))
} }
// OpenFile is the generalized open call; most users will use Open or // OpenFile is the generalized open call; most users will use Open or
// Create instead. It opens the named file with specified flag (O_RDONLY // Create instead. It opens the named file with specified flag (O_RDONLY
// etc.). If successful, methods on the returned File can be used for I/O. // etc.). If successful, methods on the returned File can be used for I/O.
func (c *Client) OpenFile(path string, f int) (*File, error) { func (c *Client) OpenFile(path string, f int) (*File, error) {
return c.open(path, flags(f)) return c.open(path, toPflags(f))
} }
func (c *Client) open(path string, pflags uint32) (*File, error) { func (c *Client) open(path string, pflags uint32) (*File, error) {
id := c.nextID() id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpOpenPacket{ typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpOpenPacket{
ID: id, ID: id,
Path: path, Path: path,
Pflags: pflags, Pflags: pflags,
...@@ -621,7 +651,7 @@ func (c *Client) open(path string, pflags uint32) (*File, error) { ...@@ -621,7 +651,7 @@ func (c *Client) open(path string, pflags uint32) (*File, error) {
// immediately after this request has been sent. // immediately after this request has been sent.
func (c *Client) close(handle string) error { func (c *Client) close(handle string) error {
id := c.nextID() id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpClosePacket{ typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpClosePacket{
ID: id, ID: id,
Handle: handle, Handle: handle,
}) })
...@@ -638,7 +668,7 @@ func (c *Client) close(handle string) error { ...@@ -638,7 +668,7 @@ func (c *Client) close(handle string) error {
func (c *Client) stat(path string) (*FileStat, error) { func (c *Client) stat(path string) (*FileStat, error) {
id := c.nextID() id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpStatPacket{ typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpStatPacket{
ID: id, ID: id,
Path: path, Path: path,
}) })
...@@ -651,8 +681,8 @@ func (c *Client) stat(path string) (*FileStat, error) { ...@@ -651,8 +681,8 @@ func (c *Client) stat(path string) (*FileStat, error) {
if sid != id { if sid != id {
return nil, &unexpectedIDErr{id, sid} return nil, &unexpectedIDErr{id, sid}
} }
attr, _ := unmarshalAttrs(data) attr, _, err := unmarshalAttrs(data)
return attr, nil return attr, err
case sshFxpStatus: case sshFxpStatus:
return nil, normaliseError(unmarshalStatus(id, data)) return nil, normaliseError(unmarshalStatus(id, data))
default: default:
...@@ -662,7 +692,7 @@ func (c *Client) stat(path string) (*FileStat, error) { ...@@ -662,7 +692,7 @@ func (c *Client) stat(path string) (*FileStat, error) {
func (c *Client) fstat(handle string) (*FileStat, error) { func (c *Client) fstat(handle string) (*FileStat, error) {
id := c.nextID() id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpFstatPacket{ typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpFstatPacket{
ID: id, ID: id,
Handle: handle, Handle: handle,
}) })
...@@ -675,8 +705,8 @@ func (c *Client) fstat(handle string) (*FileStat, error) { ...@@ -675,8 +705,8 @@ func (c *Client) fstat(handle string) (*FileStat, error) {
if sid != id { if sid != id {
return nil, &unexpectedIDErr{id, sid} return nil, &unexpectedIDErr{id, sid}
} }
attr, _ := unmarshalAttrs(data) attr, _, err := unmarshalAttrs(data)
return attr, nil return attr, err
case sshFxpStatus: case sshFxpStatus:
return nil, normaliseError(unmarshalStatus(id, data)) return nil, normaliseError(unmarshalStatus(id, data))
default: default:
...@@ -691,7 +721,7 @@ func (c *Client) fstat(handle string) (*FileStat, error) { ...@@ -691,7 +721,7 @@ func (c *Client) fstat(handle string) (*FileStat, error) {
func (c *Client) StatVFS(path string) (*StatVFS, error) { func (c *Client) StatVFS(path string) (*StatVFS, error) {
// send the StatVFS packet to the server // send the StatVFS packet to the server
id := c.nextID() id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpStatvfsPacket{ typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpStatvfsPacket{
ID: id, ID: id,
Path: path, Path: path,
}) })
...@@ -746,7 +776,7 @@ func (c *Client) Remove(path string) error { ...@@ -746,7 +776,7 @@ func (c *Client) Remove(path string) error {
func (c *Client) removeFile(path string) error { func (c *Client) removeFile(path string) error {
id := c.nextID() id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpRemovePacket{ typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpRemovePacket{
ID: id, ID: id,
Filename: path, Filename: path,
}) })
...@@ -764,7 +794,7 @@ func (c *Client) removeFile(path string) error { ...@@ -764,7 +794,7 @@ func (c *Client) removeFile(path string) error {
// RemoveDirectory removes a directory path. // RemoveDirectory removes a directory path.
func (c *Client) RemoveDirectory(path string) error { func (c *Client) RemoveDirectory(path string) error {
id := c.nextID() id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpRmdirPacket{ typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpRmdirPacket{
ID: id, ID: id,
Path: path, Path: path,
}) })
...@@ -782,7 +812,7 @@ func (c *Client) RemoveDirectory(path string) error { ...@@ -782,7 +812,7 @@ func (c *Client) RemoveDirectory(path string) error {
// Rename renames a file. // Rename renames a file.
func (c *Client) Rename(oldname, newname string) error { func (c *Client) Rename(oldname, newname string) error {
id := c.nextID() id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpRenamePacket{ typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpRenamePacket{
ID: id, ID: id,
Oldpath: oldname, Oldpath: oldname,
Newpath: newname, Newpath: newname,
...@@ -802,7 +832,7 @@ func (c *Client) Rename(oldname, newname string) error { ...@@ -802,7 +832,7 @@ func (c *Client) Rename(oldname, newname string) error {
// which will replace newname if it already exists. // which will replace newname if it already exists.
func (c *Client) PosixRename(oldname, newname string) error { func (c *Client) PosixRename(oldname, newname string) error {
id := c.nextID() id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpPosixRenamePacket{ typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpPosixRenamePacket{
ID: id, ID: id,
Oldpath: oldname, Oldpath: oldname,
Newpath: newname, Newpath: newname,
...@@ -824,7 +854,7 @@ func (c *Client) PosixRename(oldname, newname string) error { ...@@ -824,7 +854,7 @@ func (c *Client) PosixRename(oldname, newname string) error {
// or relative pathnames without a leading slash into absolute paths. // or relative pathnames without a leading slash into absolute paths.
func (c *Client) RealPath(path string) (string, error) { func (c *Client) RealPath(path string) (string, error) {
id := c.nextID() id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpRealpathPacket{ typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpRealpathPacket{
ID: id, ID: id,
Path: path, Path: path,
}) })
...@@ -861,7 +891,7 @@ func (c *Client) Getwd() (string, error) { ...@@ -861,7 +891,7 @@ func (c *Client) Getwd() (string, error) {
// parent folder does not exist (the method cannot create complete paths). // parent folder does not exist (the method cannot create complete paths).
func (c *Client) Mkdir(path string) error { func (c *Client) Mkdir(path string) error {
id := c.nextID() id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpMkdirPacket{ typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpMkdirPacket{
ID: id, ID: id,
Path: path, Path: path,
}) })
...@@ -967,16 +997,32 @@ func (c *Client) RemoveAll(path string) error { ...@@ -967,16 +997,32 @@ func (c *Client) RemoveAll(path string) error {
type File struct { type File struct {
c *Client c *Client
path string path string
handle string
mu sync.Mutex mu sync.RWMutex
handle string
offset int64 // current offset within remote file offset int64 // current offset within remote file
} }
// Close closes the File, rendering it unusable for I/O. It returns an // Close closes the File, rendering it unusable for I/O. It returns an
// error, if any. // error, if any.
func (f *File) Close() error { func (f *File) Close() error {
return f.c.close(f.handle) f.mu.Lock()
defer f.mu.Unlock()
if f.handle == "" {
return os.ErrClosed
}
// The design principle here is that when `openssh-portable/sftp-server.c` is doing `handle_close`,
// it will unconditionally mark the handle as unused,
// so we need to also unconditionally mark this handle as invalid.
// By invalidating our local copy of the handle,
// we ensure that there cannot be any erroneous use-after-close requests sent after Close.
handle := f.handle
f.handle = ""
return f.c.close(handle)
} }
// Name returns the name of the file as presented to Open or Create. // Name returns the name of the file as presented to Open or Create.
...@@ -997,7 +1043,7 @@ func (f *File) Read(b []byte) (int, error) { ...@@ -997,7 +1043,7 @@ func (f *File) Read(b []byte) (int, error) {
f.mu.Lock() f.mu.Lock()
defer f.mu.Unlock() defer f.mu.Unlock()
n, err := f.ReadAt(b, f.offset) n, err := f.readAt(b, f.offset)
f.offset += int64(n) f.offset += int64(n)
return n, err return n, err
} }
...@@ -1007,7 +1053,7 @@ func (f *File) Read(b []byte) (int, error) { ...@@ -1007,7 +1053,7 @@ func (f *File) Read(b []byte) (int, error) {
func (f *File) readChunkAt(ch chan result, b []byte, off int64) (n int, err error) { func (f *File) readChunkAt(ch chan result, b []byte, off int64) (n int, err error) {
for err == nil && n < len(b) { for err == nil && n < len(b) {
id := f.c.nextID() id := f.c.nextID()
typ, data, err := f.c.sendPacket(ch, &sshFxpReadPacket{ typ, data, err := f.c.sendPacket(context.Background(), ch, &sshFxpReadPacket{
ID: id, ID: id,
Handle: f.handle, Handle: f.handle,
Offset: uint64(off) + uint64(n), Offset: uint64(off) + uint64(n),
...@@ -1062,6 +1108,19 @@ func (f *File) readAtSequential(b []byte, off int64) (read int, err error) { ...@@ -1062,6 +1108,19 @@ func (f *File) readAtSequential(b []byte, off int64) (read int, err error) {
// the number of bytes read and an error, if any. ReadAt follows io.ReaderAt semantics, // the number of bytes read and an error, if any. ReadAt follows io.ReaderAt semantics,
// so the file offset is not altered during the read. // so the file offset is not altered during the read.
func (f *File) ReadAt(b []byte, off int64) (int, error) { func (f *File) ReadAt(b []byte, off int64) (int, error) {
f.mu.RLock()
defer f.mu.RUnlock()
return f.readAt(b, off)
}
// readAt must be called while holding either the Read or Write mutex in File.
// This code is concurrent safe with itself, but not with Close.
func (f *File) readAt(b []byte, off int64) (int, error) {
if f.handle == "" {
return 0, os.ErrClosed
}
if len(b) <= f.c.maxPacket { if len(b) <= f.c.maxPacket {
// This should be able to be serviced with 1/2 requests. // This should be able to be serviced with 1/2 requests.
// So, just do it directly. // So, just do it directly.
...@@ -1179,7 +1238,9 @@ func (f *File) ReadAt(b []byte, off int64) (int, error) { ...@@ -1179,7 +1238,9 @@ func (f *File) ReadAt(b []byte, off int64) (int, error) {
if err != nil { if err != nil {
// return the offset as the start + how much we read before the error. // return the offset as the start + how much we read before the error.
errCh <- rErr{packet.off + int64(n), err} errCh <- rErr{packet.off + int64(n), err}
return
// DO NOT return.
// We want to ensure that workCh is drained before wg.Wait returns.
} }
} }
}() }()
...@@ -1258,6 +1319,10 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) { ...@@ -1258,6 +1319,10 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) {
f.mu.Lock() f.mu.Lock()
defer f.mu.Unlock() defer f.mu.Unlock()
if f.handle == "" {
return 0, os.ErrClosed
}
if f.c.disableConcurrentReads { if f.c.disableConcurrentReads {
return f.writeToSequential(w) return f.writeToSequential(w)
} }
...@@ -1405,12 +1470,10 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) { ...@@ -1405,12 +1470,10 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) {
select { select {
case readWork.cur <- writeWork: case readWork.cur <- writeWork:
case <-cancel: case <-cancel:
return
} }
if err != nil { // DO NOT return.
return // We want to ensure that readCh is drained before wg.Wait returns.
}
} }
}() }()
} }
...@@ -1450,6 +1513,17 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) { ...@@ -1450,6 +1513,17 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) {
// Stat returns the FileInfo structure describing file. If there is an // Stat returns the FileInfo structure describing file. If there is an
// error. // error.
func (f *File) Stat() (os.FileInfo, error) { func (f *File) Stat() (os.FileInfo, error) {
f.mu.RLock()
defer f.mu.RUnlock()
if f.handle == "" {
return nil, os.ErrClosed
}
return f.stat()
}
func (f *File) stat() (os.FileInfo, error) {
fs, err := f.c.fstat(f.handle) fs, err := f.c.fstat(f.handle)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -1469,13 +1543,17 @@ func (f *File) Write(b []byte) (int, error) { ...@@ -1469,13 +1543,17 @@ func (f *File) Write(b []byte) (int, error) {
f.mu.Lock() f.mu.Lock()
defer f.mu.Unlock() defer f.mu.Unlock()
n, err := f.WriteAt(b, f.offset) if f.handle == "" {
return 0, os.ErrClosed
}
n, err := f.writeAt(b, f.offset)
f.offset += int64(n) f.offset += int64(n)
return n, err return n, err
} }
func (f *File) writeChunkAt(ch chan result, b []byte, off int64) (int, error) { func (f *File) writeChunkAt(ch chan result, b []byte, off int64) (int, error) {
typ, data, err := f.c.sendPacket(ch, &sshFxpWritePacket{ typ, data, err := f.c.sendPacket(context.Background(), ch, &sshFxpWritePacket{
ID: f.c.nextID(), ID: f.c.nextID(),
Handle: f.handle, Handle: f.handle,
Offset: uint64(off), Offset: uint64(off),
...@@ -1627,6 +1705,19 @@ func (f *File) writeAtConcurrent(b []byte, off int64) (int, error) { ...@@ -1627,6 +1705,19 @@ func (f *File) writeAtConcurrent(b []byte, off int64) (int, error) {
// the number of bytes written and an error, if any. WriteAt follows io.WriterAt semantics, // the number of bytes written and an error, if any. WriteAt follows io.WriterAt semantics,
// so the file offset is not altered during the write. // so the file offset is not altered during the write.
func (f *File) WriteAt(b []byte, off int64) (written int, err error) { func (f *File) WriteAt(b []byte, off int64) (written int, err error) {
f.mu.RLock()
defer f.mu.RUnlock()
if f.handle == "" {
return 0, os.ErrClosed
}
return f.writeAt(b, off)
}
// writeAt must be called while holding either the Read or Write mutex in File.
// This code is concurrent safe with itself, but not with Close.
func (f *File) writeAt(b []byte, off int64) (written int, err error) {
if len(b) <= f.c.maxPacket { if len(b) <= f.c.maxPacket {
// We can do this in one write. // We can do this in one write.
return f.writeChunkAt(nil, b, off) return f.writeChunkAt(nil, b, off)
...@@ -1665,7 +1756,21 @@ func (f *File) WriteAt(b []byte, off int64) (written int, err error) { ...@@ -1665,7 +1756,21 @@ func (f *File) WriteAt(b []byte, off int64) (written int, err error) {
// Giving a concurrency of less than one will default to the Client’s max concurrency. // Giving a concurrency of less than one will default to the Client’s max concurrency.
// //
// Otherwise, the given concurrency will be capped by the Client's max concurrency. // Otherwise, the given concurrency will be capped by the Client's max concurrency.
//
// When one needs to guarantee concurrent reads/writes, this method is preferred
// over ReadFrom.
func (f *File) ReadFromWithConcurrency(r io.Reader, concurrency int) (read int64, err error) { func (f *File) ReadFromWithConcurrency(r io.Reader, concurrency int) (read int64, err error) {
f.mu.Lock()
defer f.mu.Unlock()
return f.readFromWithConcurrency(r, concurrency)
}
func (f *File) readFromWithConcurrency(r io.Reader, concurrency int) (read int64, err error) {
if f.handle == "" {
return 0, os.ErrClosed
}
// Split the write into multiple maxPacket sized concurrent writes. // Split the write into multiple maxPacket sized concurrent writes.
// This allows writes with a suitably large reader // This allows writes with a suitably large reader
// to transfer data at a much faster rate due to overlapping round trip times. // to transfer data at a much faster rate due to overlapping round trip times.
...@@ -1757,6 +1862,9 @@ func (f *File) ReadFromWithConcurrency(r io.Reader, concurrency int) (read int64 ...@@ -1757,6 +1862,9 @@ func (f *File) ReadFromWithConcurrency(r io.Reader, concurrency int) (read int64
if err != nil { if err != nil {
errCh <- rwErr{work.off, err} errCh <- rwErr{work.off, err}
// DO NOT return.
// We want to ensure that workCh is drained before wg.Wait returns.
} }
} }
}() }()
...@@ -1811,10 +1919,26 @@ func (f *File) ReadFromWithConcurrency(r io.Reader, concurrency int) (read int64 ...@@ -1811,10 +1919,26 @@ func (f *File) ReadFromWithConcurrency(r io.Reader, concurrency int) (read int64
// This method is preferred over calling Write multiple times // This method is preferred over calling Write multiple times
// to maximise throughput for transferring the entire file, // to maximise throughput for transferring the entire file,
// especially over high-latency links. // especially over high-latency links.
//
// To ensure concurrent writes, the given r needs to implement one of
// the following receiver methods:
//
// Len() int
// Size() int64
// Stat() (os.FileInfo, error)
//
// or be an instance of [io.LimitedReader] to determine the number of possible
// concurrent requests. Otherwise, reads/writes are performed sequentially.
// ReadFromWithConcurrency can be used explicitly to guarantee concurrent
// processing of the reader.
func (f *File) ReadFrom(r io.Reader) (int64, error) { func (f *File) ReadFrom(r io.Reader) (int64, error) {
f.mu.Lock() f.mu.Lock()
defer f.mu.Unlock() defer f.mu.Unlock()
if f.handle == "" {
return 0, os.ErrClosed
}
if f.c.useConcurrentWrites { if f.c.useConcurrentWrites {
var remain int64 var remain int64
switch r := r.(type) { switch r := r.(type) {
...@@ -1836,7 +1960,7 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) { ...@@ -1836,7 +1960,7 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) {
if remain < 0 { if remain < 0 {
// We can strongly assert that we want default max concurrency here. // We can strongly assert that we want default max concurrency here.
return f.ReadFromWithConcurrency(r, f.c.maxConcurrentRequests) return f.readFromWithConcurrency(r, f.c.maxConcurrentRequests)
} }
if remain > int64(f.c.maxPacket) { if remain > int64(f.c.maxPacket) {
...@@ -1851,7 +1975,7 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) { ...@@ -1851,7 +1975,7 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) {
concurrency64 = int64(f.c.maxConcurrentRequests) concurrency64 = int64(f.c.maxConcurrentRequests)
} }
return f.ReadFromWithConcurrency(r, int(concurrency64)) return f.readFromWithConcurrency(r, int(concurrency64))
} }
} }
...@@ -1894,12 +2018,16 @@ func (f *File) Seek(offset int64, whence int) (int64, error) { ...@@ -1894,12 +2018,16 @@ func (f *File) Seek(offset int64, whence int) (int64, error) {
f.mu.Lock() f.mu.Lock()
defer f.mu.Unlock() defer f.mu.Unlock()
if f.handle == "" {
return 0, os.ErrClosed
}
switch whence { switch whence {
case io.SeekStart: case io.SeekStart:
case io.SeekCurrent: case io.SeekCurrent:
offset += f.offset offset += f.offset
case io.SeekEnd: case io.SeekEnd:
fi, err := f.Stat() fi, err := f.stat()
if err != nil { if err != nil {
return f.offset, err return f.offset, err
} }
...@@ -1918,22 +2046,84 @@ func (f *File) Seek(offset int64, whence int) (int64, error) { ...@@ -1918,22 +2046,84 @@ func (f *File) Seek(offset int64, whence int) (int64, error) {
// Chown changes the uid/gid of the current file. // Chown changes the uid/gid of the current file.
func (f *File) Chown(uid, gid int) error { func (f *File) Chown(uid, gid int) error {
return f.c.Chown(f.path, uid, gid) f.mu.RLock()
defer f.mu.RUnlock()
if f.handle == "" {
return os.ErrClosed
}
return f.c.fsetstat(f.handle, sshFileXferAttrUIDGID, &FileStat{
UID: uint32(uid),
GID: uint32(gid),
})
} }
// Chmod changes the permissions of the current file. // Chmod changes the permissions of the current file.
// //
// See Client.Chmod for details. // See Client.Chmod for details.
func (f *File) Chmod(mode os.FileMode) error { func (f *File) Chmod(mode os.FileMode) error {
return f.c.setfstat(f.handle, sshFileXferAttrPermissions, toChmodPerm(mode)) f.mu.RLock()
defer f.mu.RUnlock()
if f.handle == "" {
return os.ErrClosed
}
return f.c.fsetstat(f.handle, sshFileXferAttrPermissions, toChmodPerm(mode))
}
// SetExtendedData sets extended attributes of the current file. It uses the
// SSH_FILEXFER_ATTR_EXTENDED flag in the setstat request.
//
// This flag provides a general extension mechanism for vendor-specific extensions.
// Names of the attributes should be a string of the format "name@domain", where "domain"
// is a valid, registered domain name and "name" identifies the method. Server
// implementations SHOULD ignore extended data fields that they do not understand.
func (f *File) SetExtendedData(path string, extended []StatExtended) error {
f.mu.RLock()
defer f.mu.RUnlock()
if f.handle == "" {
return os.ErrClosed
}
attrs := &FileStat{
Extended: extended,
}
return f.c.fsetstat(f.handle, sshFileXferAttrExtended, attrs)
}
// Truncate sets the size of the current file. Although it may be safely assumed
// that if the size is less than its current size it will be truncated to fit,
// the SFTP protocol does not specify what behavior the server should do when setting
// size greater than the current size.
// We send a SSH_FXP_FSETSTAT here since we have a file handle
func (f *File) Truncate(size int64) error {
f.mu.RLock()
defer f.mu.RUnlock()
if f.handle == "" {
return os.ErrClosed
}
return f.c.fsetstat(f.handle, sshFileXferAttrSize, uint64(size))
} }
// Sync requests a flush of the contents of a File to stable storage. // Sync requests a flush of the contents of a File to stable storage.
// //
// Sync requires the server to support the fsync@openssh.com extension. // Sync requires the server to support the fsync@openssh.com extension.
func (f *File) Sync() error { func (f *File) Sync() error {
f.mu.Lock()
defer f.mu.Unlock()
if f.handle == "" {
return os.ErrClosed
}
id := f.c.nextID() id := f.c.nextID()
typ, data, err := f.c.sendPacket(nil, &sshFxpFsyncPacket{ typ, data, err := f.c.sendPacket(context.Background(), nil, &sshFxpFsyncPacket{
ID: id, ID: id,
Handle: f.handle, Handle: f.handle,
}) })
...@@ -1948,15 +2138,6 @@ func (f *File) Sync() error { ...@@ -1948,15 +2138,6 @@ func (f *File) Sync() error {
} }
} }
// Truncate sets the size of the current file. Although it may be safely assumed
// that if the size is less than its current size it will be truncated to fit,
// the SFTP protocol does not specify what behavior the server should do when setting
// size greater than the current size.
// We send a SSH_FXP_FSETSTAT here since we have a file handle
func (f *File) Truncate(size int64) error {
return f.c.setfstat(f.handle, sshFileXferAttrSize, uint64(size))
}
// normaliseError normalises an error into a more standard form that can be // normaliseError normalises an error into a more standard form that can be
// checked against stdlib errors like io.EOF or os.ErrNotExist. // checked against stdlib errors like io.EOF or os.ErrNotExist.
func normaliseError(err error) error { func normaliseError(err error) error {
...@@ -1981,15 +2162,14 @@ func normaliseError(err error) error { ...@@ -1981,15 +2162,14 @@ func normaliseError(err error) error {
// flags converts the flags passed to OpenFile into ssh flags. // flags converts the flags passed to OpenFile into ssh flags.
// Unsupported flags are ignored. // Unsupported flags are ignored.
func flags(f int) uint32 { func toPflags(f int) uint32 {
var out uint32 var out uint32
switch f & os.O_WRONLY { switch f & (os.O_RDONLY | os.O_WRONLY | os.O_RDWR) {
case os.O_WRONLY:
out |= sshFxfWrite
case os.O_RDONLY: case os.O_RDONLY:
out |= sshFxfRead out |= sshFxfRead
} case os.O_WRONLY:
if f&os.O_RDWR == os.O_RDWR { out |= sshFxfWrite
case os.O_RDWR:
out |= sshFxfRead | sshFxfWrite out |= sshFxfRead | sshFxfWrite
} }
if f&os.O_APPEND == os.O_APPEND { if f&os.O_APPEND == os.O_APPEND {
...@@ -2013,7 +2193,7 @@ func flags(f int) uint32 { ...@@ -2013,7 +2193,7 @@ func flags(f int) uint32 {
// setuid, setgid and sticky in m, because we've historically supported those // setuid, setgid and sticky in m, because we've historically supported those
// bits, and we mask off any non-permission bits. // bits, and we mask off any non-permission bits.
func toChmodPerm(m os.FileMode) (perm uint32) { func toChmodPerm(m os.FileMode) (perm uint32) {
const mask = os.ModePerm | s_ISUID | s_ISGID | s_ISVTX const mask = os.ModePerm | os.FileMode(s_ISUID|s_ISGID|s_ISVTX)
perm = uint32(m & mask) perm = uint32(m & mask)
if m&os.ModeSetuid != 0 { if m&os.ModeSetuid != 0 {
......
package sftp package sftp
import ( import (
"context"
"encoding" "encoding"
"fmt" "fmt"
"io" "io"
...@@ -128,15 +129,20 @@ type idmarshaler interface { ...@@ -128,15 +129,20 @@ type idmarshaler interface {
encoding.BinaryMarshaler encoding.BinaryMarshaler
} }
func (c *clientConn) sendPacket(ch chan result, p idmarshaler) (byte, []byte, error) { func (c *clientConn) sendPacket(ctx context.Context, ch chan result, p idmarshaler) (byte, []byte, error) {
if cap(ch) < 1 { if cap(ch) < 1 {
ch = make(chan result, 1) ch = make(chan result, 1)
} }
c.dispatchRequest(ch, p) c.dispatchRequest(ch, p)
s := <-ch
select {
case <-ctx.Done():
return 0, nil, ctx.Err()
case s := <-ch:
return s.typ, s.data, s.err return s.typ, s.data, s.err
} }
}
// dispatchRequest should ideally only be called by race-detection tests outside of this file, // dispatchRequest should ideally only be called by race-detection tests outside of this file,
// where you have to ensure two packets are in flight sequentially after each other. // where you have to ensure two packets are in flight sequentially after each other.
......
//go:build aix || darwin || dragonfly || freebsd || (!android && linux) || netbsd || openbsd || solaris || js //go:build aix || darwin || dragonfly || freebsd || (!android && linux) || netbsd || openbsd || solaris || js || zos
// +build aix darwin dragonfly freebsd !android,linux netbsd openbsd solaris js // +build aix darwin dragonfly freebsd !android,linux netbsd openbsd solaris js zos
package sftp package sftp
......
...@@ -56,6 +56,11 @@ func marshalFileInfo(b []byte, fi os.FileInfo) []byte { ...@@ -56,6 +56,11 @@ func marshalFileInfo(b []byte, fi os.FileInfo) []byte {
flags, fileStat := fileStatFromInfo(fi) flags, fileStat := fileStatFromInfo(fi)
b = marshalUint32(b, flags) b = marshalUint32(b, flags)
return marshalFileStat(b, flags, fileStat)
}
func marshalFileStat(b []byte, flags uint32, fileStat *FileStat) []byte {
if flags&sshFileXferAttrSize != 0 { if flags&sshFileXferAttrSize != 0 {
b = marshalUint64(b, fileStat.Size) b = marshalUint64(b, fileStat.Size)
} }
...@@ -91,10 +96,9 @@ func marshalStatus(b []byte, err StatusError) []byte { ...@@ -91,10 +96,9 @@ func marshalStatus(b []byte, err StatusError) []byte {
} }
func marshal(b []byte, v interface{}) []byte { func marshal(b []byte, v interface{}) []byte {
if v == nil {
return b
}
switch v := v.(type) { switch v := v.(type) {
case nil:
return b
case uint8: case uint8:
return append(b, v) return append(b, v)
case uint32: case uint32:
...@@ -103,6 +107,8 @@ func marshal(b []byte, v interface{}) []byte { ...@@ -103,6 +107,8 @@ func marshal(b []byte, v interface{}) []byte {
return marshalUint64(b, v) return marshalUint64(b, v)
case string: case string:
return marshalString(b, v) return marshalString(b, v)
case []byte:
return append(b, v...)
case os.FileInfo: case os.FileInfo:
return marshalFileInfo(b, v) return marshalFileInfo(b, v)
default: default:
...@@ -168,38 +174,69 @@ func unmarshalStringSafe(b []byte) (string, []byte, error) { ...@@ -168,38 +174,69 @@ func unmarshalStringSafe(b []byte) (string, []byte, error) {
return string(b[:n]), b[n:], nil return string(b[:n]), b[n:], nil
} }
func unmarshalAttrs(b []byte) (*FileStat, []byte) { func unmarshalAttrs(b []byte) (*FileStat, []byte, error) {
flags, b := unmarshalUint32(b) flags, b, err := unmarshalUint32Safe(b)
if err != nil {
return nil, b, err
}
return unmarshalFileStat(flags, b) return unmarshalFileStat(flags, b)
} }
func unmarshalFileStat(flags uint32, b []byte) (*FileStat, []byte) { func unmarshalFileStat(flags uint32, b []byte) (*FileStat, []byte, error) {
var fs FileStat var fs FileStat
var err error
if flags&sshFileXferAttrSize == sshFileXferAttrSize { if flags&sshFileXferAttrSize == sshFileXferAttrSize {
fs.Size, b, _ = unmarshalUint64Safe(b) fs.Size, b, err = unmarshalUint64Safe(b)
if err != nil {
return nil, b, err
} }
if flags&sshFileXferAttrUIDGID == sshFileXferAttrUIDGID {
fs.UID, b, _ = unmarshalUint32Safe(b)
} }
if flags&sshFileXferAttrUIDGID == sshFileXferAttrUIDGID { if flags&sshFileXferAttrUIDGID == sshFileXferAttrUIDGID {
fs.GID, b, _ = unmarshalUint32Safe(b) fs.UID, b, err = unmarshalUint32Safe(b)
if err != nil {
return nil, b, err
}
fs.GID, b, err = unmarshalUint32Safe(b)
if err != nil {
return nil, b, err
}
} }
if flags&sshFileXferAttrPermissions == sshFileXferAttrPermissions { if flags&sshFileXferAttrPermissions == sshFileXferAttrPermissions {
fs.Mode, b, _ = unmarshalUint32Safe(b) fs.Mode, b, err = unmarshalUint32Safe(b)
if err != nil {
return nil, b, err
}
} }
if flags&sshFileXferAttrACmodTime == sshFileXferAttrACmodTime { if flags&sshFileXferAttrACmodTime == sshFileXferAttrACmodTime {
fs.Atime, b, _ = unmarshalUint32Safe(b) fs.Atime, b, err = unmarshalUint32Safe(b)
fs.Mtime, b, _ = unmarshalUint32Safe(b) if err != nil {
return nil, b, err
}
fs.Mtime, b, err = unmarshalUint32Safe(b)
if err != nil {
return nil, b, err
}
} }
if flags&sshFileXferAttrExtended == sshFileXferAttrExtended { if flags&sshFileXferAttrExtended == sshFileXferAttrExtended {
var count uint32 var count uint32
count, b, _ = unmarshalUint32Safe(b) count, b, err = unmarshalUint32Safe(b)
if err != nil {
return nil, b, err
}
ext := make([]StatExtended, count) ext := make([]StatExtended, count)
for i := uint32(0); i < count; i++ { for i := uint32(0); i < count; i++ {
var typ string var typ string
var data string var data string
typ, b, _ = unmarshalStringSafe(b) typ, b, err = unmarshalStringSafe(b)
data, b, _ = unmarshalStringSafe(b) if err != nil {
return nil, b, err
}
data, b, err = unmarshalStringSafe(b)
if err != nil {
return nil, b, err
}
ext[i] = StatExtended{ ext[i] = StatExtended{
ExtType: typ, ExtType: typ,
ExtData: data, ExtData: data,
...@@ -207,7 +244,7 @@ func unmarshalFileStat(flags uint32, b []byte) (*FileStat, []byte) { ...@@ -207,7 +244,7 @@ func unmarshalFileStat(flags uint32, b []byte) (*FileStat, []byte) {
} }
fs.Extended = ext fs.Extended = ext
} }
return &fs, b return &fs, b, nil
} }
func unmarshalStatus(id uint32, data []byte) error { func unmarshalStatus(id uint32, data []byte) error {
...@@ -681,12 +718,13 @@ type sshFxpOpenPacket struct { ...@@ -681,12 +718,13 @@ type sshFxpOpenPacket struct {
ID uint32 ID uint32
Path string Path string
Pflags uint32 Pflags uint32
Flags uint32 // ignored Flags uint32
Attrs interface{}
} }
func (p *sshFxpOpenPacket) id() uint32 { return p.ID } func (p *sshFxpOpenPacket) id() uint32 { return p.ID }
func (p *sshFxpOpenPacket) MarshalBinary() ([]byte, error) { func (p *sshFxpOpenPacket) marshalPacket() ([]byte, []byte, error) {
l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id) l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id)
4 + len(p.Path) + 4 + len(p.Path) +
4 + 4 4 + 4
...@@ -698,7 +736,22 @@ func (p *sshFxpOpenPacket) MarshalBinary() ([]byte, error) { ...@@ -698,7 +736,22 @@ func (p *sshFxpOpenPacket) MarshalBinary() ([]byte, error) {
b = marshalUint32(b, p.Pflags) b = marshalUint32(b, p.Pflags)
b = marshalUint32(b, p.Flags) b = marshalUint32(b, p.Flags)
return b, nil switch attrs := p.Attrs.(type) {
case []byte:
return b, attrs, nil // may as well short-ciruit this case.
case os.FileInfo:
_, fs := fileStatFromInfo(attrs) // we throw away the flags, and override with those in packet.
return b, marshalFileStat(nil, p.Flags, fs), nil
case *FileStat:
return b, marshalFileStat(nil, p.Flags, attrs), nil
}
return b, marshal(nil, p.Attrs), nil
}
func (p *sshFxpOpenPacket) MarshalBinary() ([]byte, error) {
header, payload, err := p.marshalPacket()
return append(header, payload...), err
} }
func (p *sshFxpOpenPacket) UnmarshalBinary(b []byte) error { func (p *sshFxpOpenPacket) UnmarshalBinary(b []byte) error {
...@@ -709,12 +762,25 @@ func (p *sshFxpOpenPacket) UnmarshalBinary(b []byte) error { ...@@ -709,12 +762,25 @@ func (p *sshFxpOpenPacket) UnmarshalBinary(b []byte) error {
return err return err
} else if p.Pflags, b, err = unmarshalUint32Safe(b); err != nil { } else if p.Pflags, b, err = unmarshalUint32Safe(b); err != nil {
return err return err
} else if p.Flags, _, err = unmarshalUint32Safe(b); err != nil { } else if p.Flags, b, err = unmarshalUint32Safe(b); err != nil {
return err return err
} }
p.Attrs = b
return nil return nil
} }
func (p *sshFxpOpenPacket) unmarshalFileStat(flags uint32) (*FileStat, error) {
switch attrs := p.Attrs.(type) {
case *FileStat:
return attrs, nil
case []byte:
fs, _, err := unmarshalFileStat(flags, attrs)
return fs, err
default:
return nil, fmt.Errorf("invalid type in unmarshalFileStat: %T", attrs)
}
}
type sshFxpReadPacket struct { type sshFxpReadPacket struct {
ID uint32 ID uint32
Len uint32 Len uint32
...@@ -757,7 +823,7 @@ func (p *sshFxpReadPacket) UnmarshalBinary(b []byte) error { ...@@ -757,7 +823,7 @@ func (p *sshFxpReadPacket) UnmarshalBinary(b []byte) error {
// So, we need: uint32(length) + byte(type) + uint32(id) + uint32(data_length) // So, we need: uint32(length) + byte(type) + uint32(id) + uint32(data_length)
const dataHeaderLen = 4 + 1 + 4 + 4 const dataHeaderLen = 4 + 1 + 4 + 4
func (p *sshFxpReadPacket) getDataSlice(alloc *allocator, orderID uint32) []byte { func (p *sshFxpReadPacket) getDataSlice(alloc *allocator, orderID uint32, maxTxPacket uint32) []byte {
dataLen := p.Len dataLen := p.Len
if dataLen > maxTxPacket { if dataLen > maxTxPacket {
dataLen = maxTxPacket dataLen = maxTxPacket
...@@ -943,9 +1009,17 @@ func (p *sshFxpSetstatPacket) marshalPacket() ([]byte, []byte, error) { ...@@ -943,9 +1009,17 @@ func (p *sshFxpSetstatPacket) marshalPacket() ([]byte, []byte, error) {
b = marshalString(b, p.Path) b = marshalString(b, p.Path)
b = marshalUint32(b, p.Flags) b = marshalUint32(b, p.Flags)
payload := marshal(nil, p.Attrs) switch attrs := p.Attrs.(type) {
case []byte:
return b, attrs, nil // may as well short-ciruit this case.
case os.FileInfo:
_, fs := fileStatFromInfo(attrs) // we throw away the flags, and override with those in packet.
return b, marshalFileStat(nil, p.Flags, fs), nil
case *FileStat:
return b, marshalFileStat(nil, p.Flags, attrs), nil
}
return b, payload, nil return b, marshal(nil, p.Attrs), nil
} }
func (p *sshFxpSetstatPacket) MarshalBinary() ([]byte, error) { func (p *sshFxpSetstatPacket) MarshalBinary() ([]byte, error) {
...@@ -964,9 +1038,17 @@ func (p *sshFxpFsetstatPacket) marshalPacket() ([]byte, []byte, error) { ...@@ -964,9 +1038,17 @@ func (p *sshFxpFsetstatPacket) marshalPacket() ([]byte, []byte, error) {
b = marshalString(b, p.Handle) b = marshalString(b, p.Handle)
b = marshalUint32(b, p.Flags) b = marshalUint32(b, p.Flags)
payload := marshal(nil, p.Attrs) switch attrs := p.Attrs.(type) {
case []byte:
return b, attrs, nil // may as well short-ciruit this case.
case os.FileInfo:
_, fs := fileStatFromInfo(attrs) // we throw away the flags, and override with those in packet.
return b, marshalFileStat(nil, p.Flags, fs), nil
case *FileStat:
return b, marshalFileStat(nil, p.Flags, attrs), nil
}
return b, payload, nil return b, marshal(nil, p.Attrs), nil
} }
func (p *sshFxpFsetstatPacket) MarshalBinary() ([]byte, error) { func (p *sshFxpFsetstatPacket) MarshalBinary() ([]byte, error) {
...@@ -987,6 +1069,18 @@ func (p *sshFxpSetstatPacket) UnmarshalBinary(b []byte) error { ...@@ -987,6 +1069,18 @@ func (p *sshFxpSetstatPacket) UnmarshalBinary(b []byte) error {
return nil return nil
} }
func (p *sshFxpSetstatPacket) unmarshalFileStat(flags uint32) (*FileStat, error) {
switch attrs := p.Attrs.(type) {
case *FileStat:
return attrs, nil
case []byte:
fs, _, err := unmarshalFileStat(flags, attrs)
return fs, err
default:
return nil, fmt.Errorf("invalid type in unmarshalFileStat: %T", attrs)
}
}
func (p *sshFxpFsetstatPacket) UnmarshalBinary(b []byte) error { func (p *sshFxpFsetstatPacket) UnmarshalBinary(b []byte) error {
var err error var err error
if p.ID, b, err = unmarshalUint32Safe(b); err != nil { if p.ID, b, err = unmarshalUint32Safe(b); err != nil {
...@@ -1000,6 +1094,18 @@ func (p *sshFxpFsetstatPacket) UnmarshalBinary(b []byte) error { ...@@ -1000,6 +1094,18 @@ func (p *sshFxpFsetstatPacket) UnmarshalBinary(b []byte) error {
return nil return nil
} }
func (p *sshFxpFsetstatPacket) unmarshalFileStat(flags uint32) (*FileStat, error) {
switch attrs := p.Attrs.(type) {
case *FileStat:
return attrs, nil
case []byte:
fs, _, err := unmarshalFileStat(flags, attrs)
return fs, err
default:
return nil, fmt.Errorf("invalid type in unmarshalFileStat: %T", attrs)
}
}
type sshFxpHandlePacket struct { type sshFxpHandlePacket struct {
ID uint32 ID uint32
Handle string Handle string
......
...@@ -3,7 +3,6 @@ package sftp ...@@ -3,7 +3,6 @@ package sftp
// Methods on the Request object to make working with the Flags bitmasks and // Methods on the Request object to make working with the Flags bitmasks and
// Attr(ibutes) byte blob easier. Use Pflags() when working with an Open/Write // Attr(ibutes) byte blob easier. Use Pflags() when working with an Open/Write
// request and AttrFlags() and Attributes() when working with SetStat requests. // request and AttrFlags() and Attributes() when working with SetStat requests.
import "os"
// FileOpenFlags defines Open and Write Flags. Correlate directly with with os.OpenFile flags // FileOpenFlags defines Open and Write Flags. Correlate directly with with os.OpenFile flags
// (https://golang.org/pkg/os/#pkg-constants). // (https://golang.org/pkg/os/#pkg-constants).
...@@ -50,14 +49,9 @@ func (r *Request) AttrFlags() FileAttrFlags { ...@@ -50,14 +49,9 @@ func (r *Request) AttrFlags() FileAttrFlags {
return newFileAttrFlags(r.Flags) return newFileAttrFlags(r.Flags)
} }
// FileMode returns the Mode SFTP file attributes wrapped as os.FileMode
func (a FileStat) FileMode() os.FileMode {
return os.FileMode(a.Mode)
}
// Attributes parses file attributes byte blob and return them in a // Attributes parses file attributes byte blob and return them in a
// FileStat object. // FileStat object.
func (r *Request) Attributes() *FileStat { func (r *Request) Attributes() *FileStat {
fs, _ := unmarshalFileStat(r.Flags, r.Attrs) fs, _, _ := unmarshalFileStat(r.Flags, r.Attrs)
return fs return fs
} }
...@@ -30,7 +30,7 @@ type FileReader interface { ...@@ -30,7 +30,7 @@ type FileReader interface {
// FileWriter should return an io.WriterAt for the filepath. // FileWriter should return an io.WriterAt for the filepath.
// //
// The request server code will call Close() on the returned io.WriterAt // The request server code will call Close() on the returned io.WriterAt
// ojbect if an io.Closer type assertion succeeds. // object if an io.Closer type assertion succeeds.
// Note in cases of an error, the error text will be sent to the client. // Note in cases of an error, the error text will be sent to the client.
// Note when receiving an Append flag it is important to not open files using // Note when receiving an Append flag it is important to not open files using
// O_APPEND if you plan to use WriteAt, as they conflict. // O_APPEND if you plan to use WriteAt, as they conflict.
...@@ -144,6 +144,8 @@ type NameLookupFileLister interface { ...@@ -144,6 +144,8 @@ type NameLookupFileLister interface {
// //
// If a populated entry implements [FileInfoExtendedData], extended attributes will also be returned to the client. // If a populated entry implements [FileInfoExtendedData], extended attributes will also be returned to the client.
// //
// The request server code will call Close() on ListerAt if an io.Closer type assertion succeeds.
//
// Note in cases of an error, the error text will be sent to the client. // Note in cases of an error, the error text will be sent to the client.
type ListerAt interface { type ListerAt interface {
ListAt([]os.FileInfo, int64) (int, error) ListAt([]os.FileInfo, int64) (int, error)
......
...@@ -10,7 +10,7 @@ import ( ...@@ -10,7 +10,7 @@ import (
"sync" "sync"
) )
var maxTxPacket uint32 = 1 << 15 const defaultMaxTxPacket uint32 = 1 << 15
// Handlers contains the 4 SFTP server request handlers. // Handlers contains the 4 SFTP server request handlers.
type Handlers struct { type Handlers struct {
...@@ -28,6 +28,7 @@ type RequestServer struct { ...@@ -28,6 +28,7 @@ type RequestServer struct {
pktMgr *packetManager pktMgr *packetManager
startDirectory string startDirectory string
maxTxPacket uint32
mu sync.RWMutex mu sync.RWMutex
handleCount int handleCount int
...@@ -57,6 +58,22 @@ func WithStartDirectory(startDirectory string) RequestServerOption { ...@@ -57,6 +58,22 @@ func WithStartDirectory(startDirectory string) RequestServerOption {
} }
} }
// WithRSMaxTxPacket sets the maximum size of the payload returned to the client,
// measured in bytes. The default value is 32768 bytes, and this option
// can only be used to increase it. Setting this option to a larger value
// should be safe, because the client decides the size of the requested payload.
//
// The default maximum packet size is 32768 bytes.
func WithRSMaxTxPacket(size uint32) RequestServerOption {
return func(rs *RequestServer) {
if size < defaultMaxTxPacket {
return
}
rs.maxTxPacket = size
}
}
// NewRequestServer creates/allocates/returns new RequestServer. // NewRequestServer creates/allocates/returns new RequestServer.
// Normally there will be one server per user-session. // Normally there will be one server per user-session.
func NewRequestServer(rwc io.ReadWriteCloser, h Handlers, options ...RequestServerOption) *RequestServer { func NewRequestServer(rwc io.ReadWriteCloser, h Handlers, options ...RequestServerOption) *RequestServer {
...@@ -73,6 +90,7 @@ func NewRequestServer(rwc io.ReadWriteCloser, h Handlers, options ...RequestServ ...@@ -73,6 +90,7 @@ func NewRequestServer(rwc io.ReadWriteCloser, h Handlers, options ...RequestServ
pktMgr: newPktMgr(svrConn), pktMgr: newPktMgr(svrConn),
startDirectory: "/", startDirectory: "/",
maxTxPacket: defaultMaxTxPacket,
openRequests: make(map[string]*Request), openRequests: make(map[string]*Request),
} }
...@@ -260,7 +278,7 @@ func (rs *RequestServer) packetWorker(ctx context.Context, pktChan chan orderedR ...@@ -260,7 +278,7 @@ func (rs *RequestServer) packetWorker(ctx context.Context, pktChan chan orderedR
Method: "Stat", Method: "Stat",
Filepath: cleanPathWithBase(rs.startDirectory, request.Filepath), Filepath: cleanPathWithBase(rs.startDirectory, request.Filepath),
} }
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID) rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket)
} }
case *sshFxpFsetstatPacket: case *sshFxpFsetstatPacket:
handle := pkt.getHandle() handle := pkt.getHandle()
...@@ -272,7 +290,7 @@ func (rs *RequestServer) packetWorker(ctx context.Context, pktChan chan orderedR ...@@ -272,7 +290,7 @@ func (rs *RequestServer) packetWorker(ctx context.Context, pktChan chan orderedR
Method: "Setstat", Method: "Setstat",
Filepath: cleanPathWithBase(rs.startDirectory, request.Filepath), Filepath: cleanPathWithBase(rs.startDirectory, request.Filepath),
} }
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID) rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket)
} }
case *sshFxpExtendedPacketPosixRename: case *sshFxpExtendedPacketPosixRename:
request := &Request{ request := &Request{
...@@ -280,24 +298,24 @@ func (rs *RequestServer) packetWorker(ctx context.Context, pktChan chan orderedR ...@@ -280,24 +298,24 @@ func (rs *RequestServer) packetWorker(ctx context.Context, pktChan chan orderedR
Filepath: cleanPathWithBase(rs.startDirectory, pkt.Oldpath), Filepath: cleanPathWithBase(rs.startDirectory, pkt.Oldpath),
Target: cleanPathWithBase(rs.startDirectory, pkt.Newpath), Target: cleanPathWithBase(rs.startDirectory, pkt.Newpath),
} }
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID) rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket)
case *sshFxpExtendedPacketStatVFS: case *sshFxpExtendedPacketStatVFS:
request := &Request{ request := &Request{
Method: "StatVFS", Method: "StatVFS",
Filepath: cleanPathWithBase(rs.startDirectory, pkt.Path), Filepath: cleanPathWithBase(rs.startDirectory, pkt.Path),
} }
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID) rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket)
case hasHandle: case hasHandle:
handle := pkt.getHandle() handle := pkt.getHandle()
request, ok := rs.getRequest(handle) request, ok := rs.getRequest(handle)
if !ok { if !ok {
rpkt = statusFromError(pkt.id(), EBADF) rpkt = statusFromError(pkt.id(), EBADF)
} else { } else {
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID) rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket)
} }
case hasPath: case hasPath:
request := requestFromPacket(ctx, pkt, rs.startDirectory) request := requestFromPacket(ctx, pkt, rs.startDirectory)
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID) rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket)
request.close() request.close()
default: default:
rpkt = statusFromError(pkt.id(), ErrSSHFxOpUnsupported) rpkt = statusFromError(pkt.id(), ErrSSHFxOpUnsupported)
......
...@@ -121,6 +121,22 @@ func (s *state) getListerAt() ListerAt { ...@@ -121,6 +121,22 @@ func (s *state) getListerAt() ListerAt {
return s.listerAt return s.listerAt
} }
func (s *state) closeListerAt() error {
s.mu.Lock()
defer s.mu.Unlock()
var err error
if s.listerAt != nil {
if c, ok := s.listerAt.(io.Closer); ok {
err = c.Close()
}
s.listerAt = nil
}
return err
}
// Request contains the data and state for the incoming service request. // Request contains the data and state for the incoming service request.
type Request struct { type Request struct {
// Get, Put, Setstat, Stat, Rename, Remove // Get, Put, Setstat, Stat, Rename, Remove
...@@ -178,6 +194,7 @@ func requestFromPacket(ctx context.Context, pkt hasPath, baseDir string) *Reques ...@@ -178,6 +194,7 @@ func requestFromPacket(ctx context.Context, pkt hasPath, baseDir string) *Reques
switch p := pkt.(type) { switch p := pkt.(type) {
case *sshFxpOpenPacket: case *sshFxpOpenPacket:
request.Flags = p.Pflags request.Flags = p.Pflags
request.Attrs = p.Attrs.([]byte)
case *sshFxpSetstatPacket: case *sshFxpSetstatPacket:
request.Flags = p.Flags request.Flags = p.Flags
request.Attrs = p.Attrs.([]byte) request.Attrs = p.Attrs.([]byte)
...@@ -229,9 +246,9 @@ func (r *Request) close() error { ...@@ -229,9 +246,9 @@ func (r *Request) close() error {
} }
}() }()
rd, wr, rw := r.getAllReaderWriters() err := r.state.closeListerAt()
var err error rd, wr, rw := r.getAllReaderWriters()
// Close errors on a Writer are far more likely to be the important one. // Close errors on a Writer are far more likely to be the important one.
// As they can be information that there was a loss of data. // As they can be information that there was a loss of data.
...@@ -283,14 +300,14 @@ func (r *Request) transferError(err error) { ...@@ -283,14 +300,14 @@ func (r *Request) transferError(err error) {
} }
// called from worker to handle packet/request // called from worker to handle packet/request
func (r *Request) call(handlers Handlers, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket { func (r *Request) call(handlers Handlers, pkt requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) responsePacket {
switch r.Method { switch r.Method {
case "Get": case "Get":
return fileget(handlers.FileGet, r, pkt, alloc, orderID) return fileget(handlers.FileGet, r, pkt, alloc, orderID, maxTxPacket)
case "Put": case "Put":
return fileput(handlers.FilePut, r, pkt, alloc, orderID) return fileput(handlers.FilePut, r, pkt, alloc, orderID, maxTxPacket)
case "Open": case "Open":
return fileputget(handlers.FilePut, r, pkt, alloc, orderID) return fileputget(handlers.FilePut, r, pkt, alloc, orderID, maxTxPacket)
case "Setstat", "Rename", "Rmdir", "Mkdir", "Link", "Symlink", "Remove", "PosixRename", "StatVFS": case "Setstat", "Rename", "Rmdir", "Mkdir", "Link", "Symlink", "Remove", "PosixRename", "StatVFS":
return filecmd(handlers.FileCmd, r, pkt) return filecmd(handlers.FileCmd, r, pkt)
case "List": case "List":
...@@ -375,13 +392,13 @@ func (r *Request) opendir(h Handlers, pkt requestPacket) responsePacket { ...@@ -375,13 +392,13 @@ func (r *Request) opendir(h Handlers, pkt requestPacket) responsePacket {
} }
// wrap FileReader handler // wrap FileReader handler
func fileget(h FileReader, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket { func fileget(h FileReader, r *Request, pkt requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) responsePacket {
rd := r.getReaderAt() rd := r.getReaderAt()
if rd == nil { if rd == nil {
return statusFromError(pkt.id(), errors.New("unexpected read packet")) return statusFromError(pkt.id(), errors.New("unexpected read packet"))
} }
data, offset, _ := packetData(pkt, alloc, orderID) data, offset, _ := packetData(pkt, alloc, orderID, maxTxPacket)
n, err := rd.ReadAt(data, offset) n, err := rd.ReadAt(data, offset)
// only return EOF error if no data left to read // only return EOF error if no data left to read
...@@ -397,20 +414,20 @@ func fileget(h FileReader, r *Request, pkt requestPacket, alloc *allocator, orde ...@@ -397,20 +414,20 @@ func fileget(h FileReader, r *Request, pkt requestPacket, alloc *allocator, orde
} }
// wrap FileWriter handler // wrap FileWriter handler
func fileput(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket { func fileput(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) responsePacket {
wr := r.getWriterAt() wr := r.getWriterAt()
if wr == nil { if wr == nil {
return statusFromError(pkt.id(), errors.New("unexpected write packet")) return statusFromError(pkt.id(), errors.New("unexpected write packet"))
} }
data, offset, _ := packetData(pkt, alloc, orderID) data, offset, _ := packetData(pkt, alloc, orderID, maxTxPacket)
_, err := wr.WriteAt(data, offset) _, err := wr.WriteAt(data, offset)
return statusFromError(pkt.id(), err) return statusFromError(pkt.id(), err)
} }
// wrap OpenFileWriter handler // wrap OpenFileWriter handler
func fileputget(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket { func fileputget(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) responsePacket {
rw := r.getWriterAtReaderAt() rw := r.getWriterAtReaderAt()
if rw == nil { if rw == nil {
return statusFromError(pkt.id(), errors.New("unexpected write and read packet")) return statusFromError(pkt.id(), errors.New("unexpected write and read packet"))
...@@ -418,7 +435,7 @@ func fileputget(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, o ...@@ -418,7 +435,7 @@ func fileputget(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, o
switch p := pkt.(type) { switch p := pkt.(type) {
case *sshFxpReadPacket: case *sshFxpReadPacket:
data, offset := p.getDataSlice(alloc, orderID), int64(p.Offset) data, offset := p.getDataSlice(alloc, orderID, maxTxPacket), int64(p.Offset)
n, err := rw.ReadAt(data, offset) n, err := rw.ReadAt(data, offset)
// only return EOF error if no data left to read // only return EOF error if no data left to read
...@@ -444,10 +461,10 @@ func fileputget(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, o ...@@ -444,10 +461,10 @@ func fileputget(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, o
} }
// file data for additional read/write packets // file data for additional read/write packets
func packetData(p requestPacket, alloc *allocator, orderID uint32) (data []byte, offset int64, length uint32) { func packetData(p requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) (data []byte, offset int64, length uint32) {
switch p := p.(type) { switch p := p.(type) {
case *sshFxpReadPacket: case *sshFxpReadPacket:
return p.getDataSlice(alloc, orderID), int64(p.Offset), p.Len return p.getDataSlice(alloc, orderID, maxTxPacket), int64(p.Offset), p.Len
case *sshFxpWritePacket: case *sshFxpWritePacket:
return p.Data, int64(p.Offset), p.Length return p.Data, int64(p.Offset), p.Length
} }
......
...@@ -34,6 +34,7 @@ type Server struct { ...@@ -34,6 +34,7 @@ type Server struct {
openFilesLock sync.RWMutex openFilesLock sync.RWMutex
handleCount int handleCount int
workDir string workDir string
maxTxPacket uint32
} }
func (svr *Server) nextHandle(f *os.File) string { func (svr *Server) nextHandle(f *os.File) string {
...@@ -86,6 +87,7 @@ func NewServer(rwc io.ReadWriteCloser, options ...ServerOption) (*Server, error) ...@@ -86,6 +87,7 @@ func NewServer(rwc io.ReadWriteCloser, options ...ServerOption) (*Server, error)
debugStream: ioutil.Discard, debugStream: ioutil.Discard,
pktMgr: newPktMgr(svrConn), pktMgr: newPktMgr(svrConn),
openFiles: make(map[string]*os.File), openFiles: make(map[string]*os.File),
maxTxPacket: defaultMaxTxPacket,
} }
for _, o := range options { for _, o := range options {
...@@ -139,6 +141,24 @@ func WithServerWorkingDirectory(workDir string) ServerOption { ...@@ -139,6 +141,24 @@ func WithServerWorkingDirectory(workDir string) ServerOption {
} }
} }
// WithMaxTxPacket sets the maximum size of the payload returned to the client,
// measured in bytes. The default value is 32768 bytes, and this option
// can only be used to increase it. Setting this option to a larger value
// should be safe, because the client decides the size of the requested payload.
//
// The default maximum packet size is 32768 bytes.
func WithMaxTxPacket(size uint32) ServerOption {
return func(s *Server) error {
if size < defaultMaxTxPacket {
return errors.New("size must be greater than or equal to 32768")
}
s.maxTxPacket = size
return nil
}
}
type rxPacket struct { type rxPacket struct {
pktType fxp pktType fxp
pktBytes []byte pktBytes []byte
...@@ -287,7 +307,7 @@ func handlePacket(s *Server, p orderedRequest) error { ...@@ -287,7 +307,7 @@ func handlePacket(s *Server, p orderedRequest) error {
f, ok := s.getHandle(p.Handle) f, ok := s.getHandle(p.Handle)
if ok { if ok {
err = nil err = nil
data := p.getDataSlice(s.pktMgr.alloc, orderID) data := p.getDataSlice(s.pktMgr.alloc, orderID, s.maxTxPacket)
n, _err := f.ReadAt(data, int64(p.Offset)) n, _err := f.ReadAt(data, int64(p.Offset))
if _err != nil && (_err != io.EOF || n == 0) { if _err != nil && (_err != io.EOF || n == 0) {
err = _err err = _err
...@@ -462,7 +482,18 @@ func (p *sshFxpOpenPacket) respond(svr *Server) responsePacket { ...@@ -462,7 +482,18 @@ func (p *sshFxpOpenPacket) respond(svr *Server) responsePacket {
osFlags |= os.O_EXCL osFlags |= os.O_EXCL
} }
f, err := os.OpenFile(svr.toLocalPath(p.Path), osFlags, 0o644) mode := os.FileMode(0o644)
// Like OpenSSH, we only handle permissions here, and only when the file is being created.
// Otherwise, the permissions are ignored.
if p.Flags&sshFileXferAttrPermissions != 0 {
fs, err := p.unmarshalFileStat(p.Flags)
if err != nil {
return statusFromError(p.ID, err)
}
mode = fs.FileMode() & os.ModePerm
}
f, err := os.OpenFile(svr.toLocalPath(p.Path), osFlags, mode)
if err != nil { if err != nil {
return statusFromError(p.ID, err) return statusFromError(p.ID, err)
} }
...@@ -496,44 +527,23 @@ func (p *sshFxpReaddirPacket) respond(svr *Server) responsePacket { ...@@ -496,44 +527,23 @@ func (p *sshFxpReaddirPacket) respond(svr *Server) responsePacket {
} }
func (p *sshFxpSetstatPacket) respond(svr *Server) responsePacket { func (p *sshFxpSetstatPacket) respond(svr *Server) responsePacket {
// additional unmarshalling is required for each possibility here path := svr.toLocalPath(p.Path)
b := p.Attrs.([]byte)
var err error
p.Path = svr.toLocalPath(p.Path) debug("setstat name %q", path)
debug("setstat name \"%s\"", p.Path) fs, err := p.unmarshalFileStat(p.Flags)
if (p.Flags & sshFileXferAttrSize) != 0 {
var size uint64 if err == nil && (p.Flags&sshFileXferAttrSize) != 0 {
if size, b, err = unmarshalUint64Safe(b); err == nil { err = os.Truncate(path, int64(fs.Size))
err = os.Truncate(p.Path, int64(size))
}
}
if (p.Flags & sshFileXferAttrPermissions) != 0 {
var mode uint32
if mode, b, err = unmarshalUint32Safe(b); err == nil {
err = os.Chmod(p.Path, os.FileMode(mode))
}
}
if (p.Flags & sshFileXferAttrACmodTime) != 0 {
var atime uint32
var mtime uint32
if atime, b, err = unmarshalUint32Safe(b); err != nil {
} else if mtime, b, err = unmarshalUint32Safe(b); err != nil {
} else {
atimeT := time.Unix(int64(atime), 0)
mtimeT := time.Unix(int64(mtime), 0)
err = os.Chtimes(p.Path, atimeT, mtimeT)
} }
if err == nil && (p.Flags&sshFileXferAttrPermissions) != 0 {
err = os.Chmod(path, fs.FileMode())
} }
if (p.Flags & sshFileXferAttrUIDGID) != 0 { if err == nil && (p.Flags&sshFileXferAttrUIDGID) != 0 {
var uid uint32 err = os.Chown(path, int(fs.UID), int(fs.GID))
var gid uint32
if uid, b, err = unmarshalUint32Safe(b); err != nil {
} else if gid, _, err = unmarshalUint32Safe(b); err != nil {
} else {
err = os.Chown(p.Path, int(uid), int(gid))
} }
if err == nil && (p.Flags&sshFileXferAttrACmodTime) != 0 {
err = os.Chtimes(path, fs.AccessTime(), fs.ModTime())
} }
return statusFromError(p.ID, err) return statusFromError(p.ID, err)
...@@ -545,41 +555,32 @@ func (p *sshFxpFsetstatPacket) respond(svr *Server) responsePacket { ...@@ -545,41 +555,32 @@ func (p *sshFxpFsetstatPacket) respond(svr *Server) responsePacket {
return statusFromError(p.ID, EBADF) return statusFromError(p.ID, EBADF)
} }
// additional unmarshalling is required for each possibility here path := f.Name()
b := p.Attrs.([]byte)
var err error
debug("fsetstat name \"%s\"", f.Name()) debug("fsetstat name %q", path)
if (p.Flags & sshFileXferAttrSize) != 0 {
var size uint64 fs, err := p.unmarshalFileStat(p.Flags)
if size, b, err = unmarshalUint64Safe(b); err == nil {
err = f.Truncate(int64(size)) if err == nil && (p.Flags&sshFileXferAttrSize) != 0 {
} err = f.Truncate(int64(fs.Size))
}
if (p.Flags & sshFileXferAttrPermissions) != 0 {
var mode uint32
if mode, b, err = unmarshalUint32Safe(b); err == nil {
err = f.Chmod(os.FileMode(mode))
} }
if err == nil && (p.Flags&sshFileXferAttrPermissions) != 0 {
err = f.Chmod(fs.FileMode())
} }
if (p.Flags & sshFileXferAttrACmodTime) != 0 { if err == nil && (p.Flags&sshFileXferAttrUIDGID) != 0 {
var atime uint32 err = f.Chown(int(fs.UID), int(fs.GID))
var mtime uint32
if atime, b, err = unmarshalUint32Safe(b); err != nil {
} else if mtime, b, err = unmarshalUint32Safe(b); err != nil {
} else {
atimeT := time.Unix(int64(atime), 0)
mtimeT := time.Unix(int64(mtime), 0)
err = os.Chtimes(f.Name(), atimeT, mtimeT)
} }
if err == nil && (p.Flags&sshFileXferAttrACmodTime) != 0 {
type chtimer interface {
Chtimes(atime, mtime time.Time) error
} }
if (p.Flags & sshFileXferAttrUIDGID) != 0 {
var uid uint32 switch f := interface{}(f).(type) {
var gid uint32 case chtimer:
if uid, b, err = unmarshalUint32Safe(b); err != nil { // future-compatible, for when/if *os.File supports Chtimes.
} else if gid, _, err = unmarshalUint32Safe(b); err != nil { err = f.Chtimes(fs.AccessTime(), fs.ModTime())
} else { default:
err = f.Chown(int(uid), int(gid)) err = os.Chtimes(path, fs.AccessTime(), fs.ModTime())
} }
} }
......
package sftp
import (
"os"
"syscall"
)
var EBADF = syscall.NewError("fd out of range or not open")
func wrapPathError(filepath string, err error) error {
if errno, ok := err.(syscall.ErrorString); ok {
return &os.PathError{Path: filepath, Err: errno}
}
return err
}
// translateErrno translates a syscall error number to a SFTP error code.
func translateErrno(errno syscall.ErrorString) uint32 {
switch errno {
case "":
return sshFxOk
case syscall.ENOENT:
return sshFxNoSuchFile
case syscall.EPERM:
return sshFxPermissionDenied
}
return sshFxFailure
}
func translateSyscallError(err error) (uint32, bool) {
switch e := err.(type) {
case syscall.ErrorString:
return translateErrno(e), true
case *os.PathError:
debug("statusFromError,pathError: error is %T %#v", e.Err, e.Err)
if errno, ok := e.Err.(syscall.ErrorString); ok {
return translateErrno(errno), true
}
}
return 0, false
}
// isRegular returns true if the mode describes a regular file.
func isRegular(mode uint32) bool {
return mode&S_IFMT == syscall.S_IFREG
}
// toFileMode converts sftp filemode bits to the os.FileMode specification
func toFileMode(mode uint32) os.FileMode {
var fm = os.FileMode(mode & 0777)
switch mode & S_IFMT {
case syscall.S_IFBLK:
fm |= os.ModeDevice
case syscall.S_IFCHR:
fm |= os.ModeDevice | os.ModeCharDevice
case syscall.S_IFDIR:
fm |= os.ModeDir
case syscall.S_IFIFO:
fm |= os.ModeNamedPipe
case syscall.S_IFLNK:
fm |= os.ModeSymlink
case syscall.S_IFREG:
// nothing to do
case syscall.S_IFSOCK:
fm |= os.ModeSocket
}
return fm
}
// fromFileMode converts from the os.FileMode specification to sftp filemode bits
func fromFileMode(mode os.FileMode) uint32 {
ret := uint32(mode & os.ModePerm)
switch mode & os.ModeType {
case os.ModeDevice | os.ModeCharDevice:
ret |= syscall.S_IFCHR
case os.ModeDevice:
ret |= syscall.S_IFBLK
case os.ModeDir:
ret |= syscall.S_IFDIR
case os.ModeNamedPipe:
ret |= syscall.S_IFIFO
case os.ModeSymlink:
ret |= syscall.S_IFLNK
case 0:
ret |= syscall.S_IFREG
case os.ModeSocket:
ret |= syscall.S_IFSOCK
}
return ret
}
// Plan 9 doesn't have setuid, setgid or sticky, but a Plan 9 client should
// be able to send these bits to a POSIX server.
const (
s_ISUID = 04000
s_ISGID = 02000
s_ISVTX = 01000
)
//go:build !plan9
// +build !plan9
package sftp
import (
"os"
"syscall"
)
const EBADF = syscall.EBADF
func wrapPathError(filepath string, err error) error {
if errno, ok := err.(syscall.Errno); ok {
return &os.PathError{Path: filepath, Err: errno}
}
return err
}
// translateErrno translates a syscall error number to a SFTP error code.
func translateErrno(errno syscall.Errno) uint32 {
switch errno {
case 0:
return sshFxOk
case syscall.ENOENT:
return sshFxNoSuchFile
case syscall.EACCES, syscall.EPERM:
return sshFxPermissionDenied
}
return sshFxFailure
}
func translateSyscallError(err error) (uint32, bool) {
switch e := err.(type) {
case syscall.Errno:
return translateErrno(e), true
case *os.PathError:
debug("statusFromError,pathError: error is %T %#v", e.Err, e.Err)
if errno, ok := e.Err.(syscall.Errno); ok {
return translateErrno(errno), true
}
}
return 0, false
}
// isRegular returns true if the mode describes a regular file.
func isRegular(mode uint32) bool {
return mode&S_IFMT == syscall.S_IFREG
}
// toFileMode converts sftp filemode bits to the os.FileMode specification
func toFileMode(mode uint32) os.FileMode {
var fm = os.FileMode(mode & 0777)
switch mode & S_IFMT {
case syscall.S_IFBLK:
fm |= os.ModeDevice
case syscall.S_IFCHR:
fm |= os.ModeDevice | os.ModeCharDevice
case syscall.S_IFDIR:
fm |= os.ModeDir
case syscall.S_IFIFO:
fm |= os.ModeNamedPipe
case syscall.S_IFLNK:
fm |= os.ModeSymlink
case syscall.S_IFREG:
// nothing to do
case syscall.S_IFSOCK:
fm |= os.ModeSocket
}
if mode&syscall.S_ISUID != 0 {
fm |= os.ModeSetuid
}
if mode&syscall.S_ISGID != 0 {
fm |= os.ModeSetgid
}
if mode&syscall.S_ISVTX != 0 {
fm |= os.ModeSticky
}
return fm
}
// fromFileMode converts from the os.FileMode specification to sftp filemode bits
func fromFileMode(mode os.FileMode) uint32 {
ret := uint32(mode & os.ModePerm)
switch mode & os.ModeType {
case os.ModeDevice | os.ModeCharDevice:
ret |= syscall.S_IFCHR
case os.ModeDevice:
ret |= syscall.S_IFBLK
case os.ModeDir:
ret |= syscall.S_IFDIR
case os.ModeNamedPipe:
ret |= syscall.S_IFIFO
case os.ModeSymlink:
ret |= syscall.S_IFLNK
case 0:
ret |= syscall.S_IFREG
case os.ModeSocket:
ret |= syscall.S_IFSOCK
}
if mode&os.ModeSetuid != 0 {
ret |= syscall.S_ISUID
}
if mode&os.ModeSetgid != 0 {
ret |= syscall.S_ISGID
}
if mode&os.ModeSticky != 0 {
ret |= syscall.S_ISVTX
}
return ret
}
const (
s_ISUID = syscall.S_ISUID
s_ISGID = syscall.S_ISGID
s_ISVTX = syscall.S_ISVTX
)
//go:build plan9 || windows || (js && wasm)
// +build plan9 windows js,wasm
// Go defines S_IFMT on windows, plan9 and js/wasm as 0x1f000 instead of
// 0xf000. None of the the other S_IFxyz values include the "1" (in 0x1f000)
// which prevents them from matching the bitmask.
package sftp
const S_IFMT = 0xf000
//go:build !plan9 && !windows && (!js || !wasm)
// +build !plan9
// +build !windows
// +build !js !wasm
package sftp
import "syscall"
const S_IFMT = syscall.S_IFMT
env: env:
CIRRUS_CLONE_DEPTH: 1 CIRRUS_CLONE_DEPTH: 1
GO_VERSION: go1.22.2 GO_VERSION: go1.23.0
freebsd_13_task: freebsd_13_task:
freebsd_instance: freebsd_instance:
image_family: freebsd-13-2 image_family: freebsd-13-3
install_script: | install_script: |
pkg install -y go pkg install -y go
GOBIN=$PWD/bin go install golang.org/dl/${GO_VERSION}@latest GOBIN=$PWD/bin go install golang.org/dl/${GO_VERSION}@latest
......
...@@ -73,3 +73,26 @@ func GetPossible() (int, error) { ...@@ -73,3 +73,26 @@ func GetPossible() (int, error) {
func GetPresent() (int, error) { func GetPresent() (int, error) {
return getPresent() return getPresent()
} }
// ListOffline returns the list of offline CPUs. See [GetOffline] for details on
// when a CPU is considered offline.
func ListOffline() ([]int, error) {
return listOffline()
}
// ListOnline returns the list of CPUs that are online and being scheduled.
func ListOnline() ([]int, error) {
return listOnline()
}
// ListPossible returns the list of possible CPUs. See [GetPossible] for
// details on when a CPU is considered possible.
func ListPossible() ([]int, error) {
return listPossible()
}
// ListPresent returns the list of present CPUs. See [GetPresent] for
// details on when a CPU is considered present.
func ListPresent() ([]int, error) {
return listPresent()
}
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
package numcpus package numcpus
import ( import (
"fmt"
"os" "os"
"path/filepath" "path/filepath"
"strconv" "strconv"
...@@ -23,7 +24,14 @@ import ( ...@@ -23,7 +24,14 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
const sysfsCPUBasePath = "/sys/devices/system/cpu" const (
sysfsCPUBasePath = "/sys/devices/system/cpu"
offline = "offline"
online = "online"
possible = "possible"
present = "present"
)
func getFromCPUAffinity() (int, error) { func getFromCPUAffinity() (int, error) {
var cpuSet unix.CPUSet var cpuSet unix.CPUSet
...@@ -33,19 +41,26 @@ func getFromCPUAffinity() (int, error) { ...@@ -33,19 +41,26 @@ func getFromCPUAffinity() (int, error) {
return cpuSet.Count(), nil return cpuSet.Count(), nil
} }
func readCPURange(file string) (int, error) { func readCPURangeWith[T any](file string, f func(cpus string) (T, error)) (T, error) {
var zero T
buf, err := os.ReadFile(filepath.Join(sysfsCPUBasePath, file)) buf, err := os.ReadFile(filepath.Join(sysfsCPUBasePath, file))
if err != nil { if err != nil {
return 0, err return zero, err
} }
return parseCPURange(strings.Trim(string(buf), "\n ")) return f(strings.Trim(string(buf), "\n "))
}
func countCPURange(cpus string) (int, error) {
// Treat empty file as valid. This might be the case if there are no offline CPUs in which
// case /sys/devices/system/cpu/offline is empty.
if cpus == "" {
return 0, nil
} }
func parseCPURange(cpus string) (int, error) {
n := int(0) n := int(0)
for _, cpuRange := range strings.Split(cpus, ",") { for _, cpuRange := range strings.Split(cpus, ",") {
if len(cpuRange) == 0 { if cpuRange == "" {
continue return 0, fmt.Errorf("empty CPU range in CPU string %q", cpus)
} }
from, to, found := strings.Cut(cpuRange, "-") from, to, found := strings.Cut(cpuRange, "-")
first, err := strconv.ParseUint(from, 10, 32) first, err := strconv.ParseUint(from, 10, 32)
...@@ -60,11 +75,49 @@ func parseCPURange(cpus string) (int, error) { ...@@ -60,11 +75,49 @@ func parseCPURange(cpus string) (int, error) {
if err != nil { if err != nil {
return 0, err return 0, err
} }
if last < first {
return 0, fmt.Errorf("last CPU in range (%d) less than first (%d)", last, first)
}
n += int(last - first + 1) n += int(last - first + 1)
} }
return n, nil return n, nil
} }
func listCPURange(cpus string) ([]int, error) {
// See comment in countCPURange.
if cpus == "" {
return []int{}, nil
}
list := []int{}
for _, cpuRange := range strings.Split(cpus, ",") {
if cpuRange == "" {
return nil, fmt.Errorf("empty CPU range in CPU string %q", cpus)
}
from, to, found := strings.Cut(cpuRange, "-")
first, err := strconv.ParseUint(from, 10, 32)
if err != nil {
return nil, err
}
if !found {
// range containing a single element
list = append(list, int(first))
continue
}
last, err := strconv.ParseUint(to, 10, 32)
if err != nil {
return nil, err
}
if last < first {
return nil, fmt.Errorf("last CPU in range (%d) less than first (%d)", last, first)
}
for cpu := int(first); cpu <= int(last); cpu++ {
list = append(list, cpu)
}
}
return list, nil
}
func getConfigured() (int, error) { func getConfigured() (int, error) {
d, err := os.Open(sysfsCPUBasePath) d, err := os.Open(sysfsCPUBasePath)
if err != nil { if err != nil {
...@@ -100,20 +153,36 @@ func getKernelMax() (int, error) { ...@@ -100,20 +153,36 @@ func getKernelMax() (int, error) {
} }
func getOffline() (int, error) { func getOffline() (int, error) {
return readCPURange("offline") return readCPURangeWith(offline, countCPURange)
} }
func getOnline() (int, error) { func getOnline() (int, error) {
if n, err := getFromCPUAffinity(); err == nil { if n, err := getFromCPUAffinity(); err == nil {
return n, nil return n, nil
} }
return readCPURange("online") return readCPURangeWith(online, countCPURange)
} }
func getPossible() (int, error) { func getPossible() (int, error) {
return readCPURange("possible") return readCPURangeWith(possible, countCPURange)
} }
func getPresent() (int, error) { func getPresent() (int, error) {
return readCPURange("present") return readCPURangeWith(present, countCPURange)
}
func listOffline() ([]int, error) {
return readCPURangeWith(offline, listCPURange)
}
func listOnline() ([]int, error) {
return readCPURangeWith(online, listCPURange)
}
func listPossible() ([]int, error) {
return readCPURangeWith(possible, listCPURange)
}
func listPresent() ([]int, error) {
return readCPURangeWith(present, listCPURange)
} }
...@@ -510,8 +510,8 @@ userAuthLoop: ...@@ -510,8 +510,8 @@ userAuthLoop:
if err := s.transport.writePacket(Marshal(discMsg)); err != nil { if err := s.transport.writePacket(Marshal(discMsg)); err != nil {
return nil, err return nil, err
} }
authErrs = append(authErrs, discMsg)
return nil, discMsg return nil, &ServerAuthError{Errors: authErrs}
} }
var userAuthReq userAuthRequestMsg var userAuthReq userAuthRequestMsg
......
...@@ -156,7 +156,7 @@ from the generated architecture-specific files listed below, and merge these ...@@ -156,7 +156,7 @@ from the generated architecture-specific files listed below, and merge these
into a common file for each OS. into a common file for each OS.
The merge is performed in the following steps: The merge is performed in the following steps:
1. Construct the set of common code that is idential in all architecture-specific files. 1. Construct the set of common code that is identical in all architecture-specific files.
2. Write this common code to the merged file. 2. Write this common code to the merged file.
3. Remove the common code from all architecture-specific files. 3. Remove the common code from all architecture-specific files.
......
...@@ -656,7 +656,7 @@ errors=$( ...@@ -656,7 +656,7 @@ errors=$(
signals=$( signals=$(
echo '#include <signal.h>' | $CC -x c - -E -dM $ccflags | echo '#include <signal.h>' | $CC -x c - -E -dM $ccflags |
awk '$1=="#define" && $2 ~ /^SIG[A-Z0-9]+$/ { print $2 }' | awk '$1=="#define" && $2 ~ /^SIG[A-Z0-9]+$/ { print $2 }' |
grep -v 'SIGSTKSIZE\|SIGSTKSZ\|SIGRT\|SIGMAX64' | grep -E -v '(SIGSTKSIZE|SIGSTKSZ|SIGRT|SIGMAX64)' |
sort sort
) )
...@@ -666,7 +666,7 @@ echo '#include <errno.h>' | $CC -x c - -E -dM $ccflags | ...@@ -666,7 +666,7 @@ echo '#include <errno.h>' | $CC -x c - -E -dM $ccflags |
sort >_error.grep sort >_error.grep
echo '#include <signal.h>' | $CC -x c - -E -dM $ccflags | echo '#include <signal.h>' | $CC -x c - -E -dM $ccflags |
awk '$1=="#define" && $2 ~ /^SIG[A-Z0-9]+$/ { print "^\t" $2 "[ \t]*=" }' | awk '$1=="#define" && $2 ~ /^SIG[A-Z0-9]+$/ { print "^\t" $2 "[ \t]*=" }' |
grep -v 'SIGSTKSIZE\|SIGSTKSZ\|SIGRT\|SIGMAX64' | grep -E -v '(SIGSTKSIZE|SIGSTKSZ|SIGRT|SIGMAX64)' |
sort >_signal.grep sort >_signal.grep
echo '// mkerrors.sh' "$@" echo '// mkerrors.sh' "$@"
......