Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found
Select Git revision

Target

Select target project
  • oss/libraries/go/services/job-queues
1 result
Select Git revision
Show changes
Showing
with 635 additions and 466 deletions
//go:build darwin || dragonfly || freebsd || (!android && linux) || netbsd || openbsd || solaris || aix || js
// +build darwin dragonfly freebsd !android,linux netbsd openbsd solaris aix js
//go:build darwin || dragonfly || freebsd || (!android && linux) || netbsd || openbsd || solaris || aix || js || zos
// +build darwin dragonfly freebsd !android,linux netbsd openbsd solaris aix js zos
package sftp
......
......@@ -2,6 +2,7 @@ package sftp
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
......@@ -256,7 +257,7 @@ func NewClientPipe(rd io.Reader, wr io.WriteCloser, opts ...ClientOption) (*Clie
// read/write at the same time. For those services you will need to use
// `client.OpenFile(os.O_WRONLY|os.O_CREATE|os.O_TRUNC)`.
func (c *Client) Create(path string) (*File, error) {
return c.open(path, flags(os.O_RDWR|os.O_CREATE|os.O_TRUNC))
return c.open(path, toPflags(os.O_RDWR|os.O_CREATE|os.O_TRUNC))
}
const sftpProtocolVersion = 3 // https://filezilla-project.org/specs/draft-ietf-secsh-filexfer-02.txt
......@@ -321,19 +322,27 @@ func (c *Client) Walk(root string) *fs.Walker {
return fs.WalkFS(root, c)
}
// ReadDir reads the directory named by dirname and returns a list of
// directory entries.
// ReadDir reads the directory named by p
// and returns a list of directory entries.
func (c *Client) ReadDir(p string) ([]os.FileInfo, error) {
handle, err := c.opendir(p)
return c.ReadDirContext(context.Background(), p)
}
// ReadDirContext reads the directory named by p
// and returns a list of directory entries.
// The passed context can be used to cancel the operation
// returning all entries listed up to the cancellation.
func (c *Client) ReadDirContext(ctx context.Context, p string) ([]os.FileInfo, error) {
handle, err := c.opendir(ctx, p)
if err != nil {
return nil, err
}
defer c.close(handle) // this has to defer earlier than the lock below
var attrs []os.FileInfo
var entries []os.FileInfo
var done = false
for !done {
id := c.nextID()
typ, data, err1 := c.sendPacket(nil, &sshFxpReaddirPacket{
typ, data, err1 := c.sendPacket(ctx, nil, &sshFxpReaddirPacket{
ID: id,
Handle: handle,
})
......@@ -354,11 +363,14 @@ func (c *Client) ReadDir(p string) ([]os.FileInfo, error) {
filename, data = unmarshalString(data)
_, data = unmarshalString(data) // discard longname
var attr *FileStat
attr, data = unmarshalAttrs(data)
attr, data, err = unmarshalAttrs(data)
if err != nil {
return nil, err
}
if filename == "." || filename == ".." {
continue
}
attrs = append(attrs, fileInfoFromStat(attr, path.Base(filename)))
entries = append(entries, fileInfoFromStat(attr, path.Base(filename)))
}
case sshFxpStatus:
// TODO(dfc) scope warning!
......@@ -371,12 +383,12 @@ func (c *Client) ReadDir(p string) ([]os.FileInfo, error) {
if err == io.EOF {
err = nil
}
return attrs, err
return entries, err
}
func (c *Client) opendir(path string) (string, error) {
func (c *Client) opendir(ctx context.Context, path string) (string, error) {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpOpendirPacket{
typ, data, err := c.sendPacket(ctx, nil, &sshFxpOpendirPacket{
ID: id,
Path: path,
})
......@@ -412,7 +424,7 @@ func (c *Client) Stat(p string) (os.FileInfo, error) {
// If 'p' is a symbolic link, the returned FileInfo structure describes the symbolic link.
func (c *Client) Lstat(p string) (os.FileInfo, error) {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpLstatPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpLstatPacket{
ID: id,
Path: p,
})
......@@ -425,7 +437,11 @@ func (c *Client) Lstat(p string) (os.FileInfo, error) {
if sid != id {
return nil, &unexpectedIDErr{id, sid}
}
attr, _ := unmarshalAttrs(data)
attr, _, err := unmarshalAttrs(data)
if err != nil {
// avoid returning a valid value from fileInfoFromStats if err != nil.
return nil, err
}
return fileInfoFromStat(attr, path.Base(p)), nil
case sshFxpStatus:
return nil, normaliseError(unmarshalStatus(id, data))
......@@ -437,7 +453,7 @@ func (c *Client) Lstat(p string) (os.FileInfo, error) {
// ReadLink reads the target of a symbolic link.
func (c *Client) ReadLink(p string) (string, error) {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpReadlinkPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpReadlinkPacket{
ID: id,
Path: p,
})
......@@ -466,7 +482,7 @@ func (c *Client) ReadLink(p string) (string, error) {
// Link creates a hard link at 'newname', pointing at the same inode as 'oldname'
func (c *Client) Link(oldname, newname string) error {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpHardlinkPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpHardlinkPacket{
ID: id,
Oldpath: oldname,
Newpath: newname,
......@@ -485,7 +501,7 @@ func (c *Client) Link(oldname, newname string) error {
// Symlink creates a symbolic link at 'newname', pointing at target 'oldname'
func (c *Client) Symlink(oldname, newname string) error {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpSymlinkPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpSymlinkPacket{
ID: id,
Linkpath: newname,
Targetpath: oldname,
......@@ -501,9 +517,9 @@ func (c *Client) Symlink(oldname, newname string) error {
}
}
func (c *Client) setfstat(handle string, flags uint32, attrs interface{}) error {
func (c *Client) fsetstat(handle string, flags uint32, attrs interface{}) error {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpFsetstatPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpFsetstatPacket{
ID: id,
Handle: handle,
Flags: flags,
......@@ -523,7 +539,7 @@ func (c *Client) setfstat(handle string, flags uint32, attrs interface{}) error
// setstat is a convience wrapper to allow for changing of various parts of the file descriptor.
func (c *Client) setstat(path string, flags uint32, attrs interface{}) error {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpSetstatPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpSetstatPacket{
ID: id,
Path: path,
Flags: flags,
......@@ -577,23 +593,37 @@ func (c *Client) Truncate(path string, size int64) error {
return c.setstat(path, sshFileXferAttrSize, uint64(size))
}
// SetExtendedData sets extended attributes of the named file. It uses the
// SSH_FILEXFER_ATTR_EXTENDED flag in the setstat request.
//
// This flag provides a general extension mechanism for vendor-specific extensions.
// Names of the attributes should be a string of the format "name@domain", where "domain"
// is a valid, registered domain name and "name" identifies the method. Server
// implementations SHOULD ignore extended data fields that they do not understand.
func (c *Client) SetExtendedData(path string, extended []StatExtended) error {
attrs := &FileStat{
Extended: extended,
}
return c.setstat(path, sshFileXferAttrExtended, attrs)
}
// Open opens the named file for reading. If successful, methods on the
// returned file can be used for reading; the associated file descriptor
// has mode O_RDONLY.
func (c *Client) Open(path string) (*File, error) {
return c.open(path, flags(os.O_RDONLY))
return c.open(path, toPflags(os.O_RDONLY))
}
// OpenFile is the generalized open call; most users will use Open or
// Create instead. It opens the named file with specified flag (O_RDONLY
// etc.). If successful, methods on the returned File can be used for I/O.
func (c *Client) OpenFile(path string, f int) (*File, error) {
return c.open(path, flags(f))
return c.open(path, toPflags(f))
}
func (c *Client) open(path string, pflags uint32) (*File, error) {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpOpenPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpOpenPacket{
ID: id,
Path: path,
Pflags: pflags,
......@@ -621,7 +651,7 @@ func (c *Client) open(path string, pflags uint32) (*File, error) {
// immediately after this request has been sent.
func (c *Client) close(handle string) error {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpClosePacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpClosePacket{
ID: id,
Handle: handle,
})
......@@ -638,7 +668,7 @@ func (c *Client) close(handle string) error {
func (c *Client) stat(path string) (*FileStat, error) {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpStatPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpStatPacket{
ID: id,
Path: path,
})
......@@ -651,8 +681,8 @@ func (c *Client) stat(path string) (*FileStat, error) {
if sid != id {
return nil, &unexpectedIDErr{id, sid}
}
attr, _ := unmarshalAttrs(data)
return attr, nil
attr, _, err := unmarshalAttrs(data)
return attr, err
case sshFxpStatus:
return nil, normaliseError(unmarshalStatus(id, data))
default:
......@@ -662,7 +692,7 @@ func (c *Client) stat(path string) (*FileStat, error) {
func (c *Client) fstat(handle string) (*FileStat, error) {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpFstatPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpFstatPacket{
ID: id,
Handle: handle,
})
......@@ -675,8 +705,8 @@ func (c *Client) fstat(handle string) (*FileStat, error) {
if sid != id {
return nil, &unexpectedIDErr{id, sid}
}
attr, _ := unmarshalAttrs(data)
return attr, nil
attr, _, err := unmarshalAttrs(data)
return attr, err
case sshFxpStatus:
return nil, normaliseError(unmarshalStatus(id, data))
default:
......@@ -691,7 +721,7 @@ func (c *Client) fstat(handle string) (*FileStat, error) {
func (c *Client) StatVFS(path string) (*StatVFS, error) {
// send the StatVFS packet to the server
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpStatvfsPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpStatvfsPacket{
ID: id,
Path: path,
})
......@@ -746,7 +776,7 @@ func (c *Client) Remove(path string) error {
func (c *Client) removeFile(path string) error {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpRemovePacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpRemovePacket{
ID: id,
Filename: path,
})
......@@ -764,7 +794,7 @@ func (c *Client) removeFile(path string) error {
// RemoveDirectory removes a directory path.
func (c *Client) RemoveDirectory(path string) error {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpRmdirPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpRmdirPacket{
ID: id,
Path: path,
})
......@@ -782,7 +812,7 @@ func (c *Client) RemoveDirectory(path string) error {
// Rename renames a file.
func (c *Client) Rename(oldname, newname string) error {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpRenamePacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpRenamePacket{
ID: id,
Oldpath: oldname,
Newpath: newname,
......@@ -802,7 +832,7 @@ func (c *Client) Rename(oldname, newname string) error {
// which will replace newname if it already exists.
func (c *Client) PosixRename(oldname, newname string) error {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpPosixRenamePacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpPosixRenamePacket{
ID: id,
Oldpath: oldname,
Newpath: newname,
......@@ -824,7 +854,7 @@ func (c *Client) PosixRename(oldname, newname string) error {
// or relative pathnames without a leading slash into absolute paths.
func (c *Client) RealPath(path string) (string, error) {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpRealpathPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpRealpathPacket{
ID: id,
Path: path,
})
......@@ -861,7 +891,7 @@ func (c *Client) Getwd() (string, error) {
// parent folder does not exist (the method cannot create complete paths).
func (c *Client) Mkdir(path string) error {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpMkdirPacket{
typ, data, err := c.sendPacket(context.Background(), nil, &sshFxpMkdirPacket{
ID: id,
Path: path,
})
......@@ -965,18 +995,34 @@ func (c *Client) RemoveAll(path string) error {
// File represents a remote file.
type File struct {
c *Client
path string
handle string
c *Client
path string
mu sync.Mutex
mu sync.RWMutex
handle string
offset int64 // current offset within remote file
}
// Close closes the File, rendering it unusable for I/O. It returns an
// error, if any.
func (f *File) Close() error {
return f.c.close(f.handle)
f.mu.Lock()
defer f.mu.Unlock()
if f.handle == "" {
return os.ErrClosed
}
// The design principle here is that when `openssh-portable/sftp-server.c` is doing `handle_close`,
// it will unconditionally mark the handle as unused,
// so we need to also unconditionally mark this handle as invalid.
// By invalidating our local copy of the handle,
// we ensure that there cannot be any erroneous use-after-close requests sent after Close.
handle := f.handle
f.handle = ""
return f.c.close(handle)
}
// Name returns the name of the file as presented to Open or Create.
......@@ -997,7 +1043,7 @@ func (f *File) Read(b []byte) (int, error) {
f.mu.Lock()
defer f.mu.Unlock()
n, err := f.ReadAt(b, f.offset)
n, err := f.readAt(b, f.offset)
f.offset += int64(n)
return n, err
}
......@@ -1007,7 +1053,7 @@ func (f *File) Read(b []byte) (int, error) {
func (f *File) readChunkAt(ch chan result, b []byte, off int64) (n int, err error) {
for err == nil && n < len(b) {
id := f.c.nextID()
typ, data, err := f.c.sendPacket(ch, &sshFxpReadPacket{
typ, data, err := f.c.sendPacket(context.Background(), ch, &sshFxpReadPacket{
ID: id,
Handle: f.handle,
Offset: uint64(off) + uint64(n),
......@@ -1062,6 +1108,19 @@ func (f *File) readAtSequential(b []byte, off int64) (read int, err error) {
// the number of bytes read and an error, if any. ReadAt follows io.ReaderAt semantics,
// so the file offset is not altered during the read.
func (f *File) ReadAt(b []byte, off int64) (int, error) {
f.mu.RLock()
defer f.mu.RUnlock()
return f.readAt(b, off)
}
// readAt must be called while holding either the Read or Write mutex in File.
// This code is concurrent safe with itself, but not with Close.
func (f *File) readAt(b []byte, off int64) (int, error) {
if f.handle == "" {
return 0, os.ErrClosed
}
if len(b) <= f.c.maxPacket {
// This should be able to be serviced with 1/2 requests.
// So, just do it directly.
......@@ -1179,7 +1238,9 @@ func (f *File) ReadAt(b []byte, off int64) (int, error) {
if err != nil {
// return the offset as the start + how much we read before the error.
errCh <- rErr{packet.off + int64(n), err}
return
// DO NOT return.
// We want to ensure that workCh is drained before wg.Wait returns.
}
}
}()
......@@ -1258,6 +1319,10 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) {
f.mu.Lock()
defer f.mu.Unlock()
if f.handle == "" {
return 0, os.ErrClosed
}
if f.c.disableConcurrentReads {
return f.writeToSequential(w)
}
......@@ -1405,12 +1470,10 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) {
select {
case readWork.cur <- writeWork:
case <-cancel:
return
}
if err != nil {
return
}
// DO NOT return.
// We want to ensure that readCh is drained before wg.Wait returns.
}
}()
}
......@@ -1450,6 +1513,17 @@ func (f *File) WriteTo(w io.Writer) (written int64, err error) {
// Stat returns the FileInfo structure describing file. If there is an
// error.
func (f *File) Stat() (os.FileInfo, error) {
f.mu.RLock()
defer f.mu.RUnlock()
if f.handle == "" {
return nil, os.ErrClosed
}
return f.stat()
}
func (f *File) stat() (os.FileInfo, error) {
fs, err := f.c.fstat(f.handle)
if err != nil {
return nil, err
......@@ -1469,13 +1543,17 @@ func (f *File) Write(b []byte) (int, error) {
f.mu.Lock()
defer f.mu.Unlock()
n, err := f.WriteAt(b, f.offset)
if f.handle == "" {
return 0, os.ErrClosed
}
n, err := f.writeAt(b, f.offset)
f.offset += int64(n)
return n, err
}
func (f *File) writeChunkAt(ch chan result, b []byte, off int64) (int, error) {
typ, data, err := f.c.sendPacket(ch, &sshFxpWritePacket{
typ, data, err := f.c.sendPacket(context.Background(), ch, &sshFxpWritePacket{
ID: f.c.nextID(),
Handle: f.handle,
Offset: uint64(off),
......@@ -1627,6 +1705,19 @@ func (f *File) writeAtConcurrent(b []byte, off int64) (int, error) {
// the number of bytes written and an error, if any. WriteAt follows io.WriterAt semantics,
// so the file offset is not altered during the write.
func (f *File) WriteAt(b []byte, off int64) (written int, err error) {
f.mu.RLock()
defer f.mu.RUnlock()
if f.handle == "" {
return 0, os.ErrClosed
}
return f.writeAt(b, off)
}
// writeAt must be called while holding either the Read or Write mutex in File.
// This code is concurrent safe with itself, but not with Close.
func (f *File) writeAt(b []byte, off int64) (written int, err error) {
if len(b) <= f.c.maxPacket {
// We can do this in one write.
return f.writeChunkAt(nil, b, off)
......@@ -1665,7 +1756,21 @@ func (f *File) WriteAt(b []byte, off int64) (written int, err error) {
// Giving a concurrency of less than one will default to the Client’s max concurrency.
//
// Otherwise, the given concurrency will be capped by the Client's max concurrency.
//
// When one needs to guarantee concurrent reads/writes, this method is preferred
// over ReadFrom.
func (f *File) ReadFromWithConcurrency(r io.Reader, concurrency int) (read int64, err error) {
f.mu.Lock()
defer f.mu.Unlock()
return f.readFromWithConcurrency(r, concurrency)
}
func (f *File) readFromWithConcurrency(r io.Reader, concurrency int) (read int64, err error) {
if f.handle == "" {
return 0, os.ErrClosed
}
// Split the write into multiple maxPacket sized concurrent writes.
// This allows writes with a suitably large reader
// to transfer data at a much faster rate due to overlapping round trip times.
......@@ -1757,6 +1862,9 @@ func (f *File) ReadFromWithConcurrency(r io.Reader, concurrency int) (read int64
if err != nil {
errCh <- rwErr{work.off, err}
// DO NOT return.
// We want to ensure that workCh is drained before wg.Wait returns.
}
}
}()
......@@ -1811,10 +1919,26 @@ func (f *File) ReadFromWithConcurrency(r io.Reader, concurrency int) (read int64
// This method is preferred over calling Write multiple times
// to maximise throughput for transferring the entire file,
// especially over high-latency links.
//
// To ensure concurrent writes, the given r needs to implement one of
// the following receiver methods:
//
// Len() int
// Size() int64
// Stat() (os.FileInfo, error)
//
// or be an instance of [io.LimitedReader] to determine the number of possible
// concurrent requests. Otherwise, reads/writes are performed sequentially.
// ReadFromWithConcurrency can be used explicitly to guarantee concurrent
// processing of the reader.
func (f *File) ReadFrom(r io.Reader) (int64, error) {
f.mu.Lock()
defer f.mu.Unlock()
if f.handle == "" {
return 0, os.ErrClosed
}
if f.c.useConcurrentWrites {
var remain int64
switch r := r.(type) {
......@@ -1836,7 +1960,7 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) {
if remain < 0 {
// We can strongly assert that we want default max concurrency here.
return f.ReadFromWithConcurrency(r, f.c.maxConcurrentRequests)
return f.readFromWithConcurrency(r, f.c.maxConcurrentRequests)
}
if remain > int64(f.c.maxPacket) {
......@@ -1851,7 +1975,7 @@ func (f *File) ReadFrom(r io.Reader) (int64, error) {
concurrency64 = int64(f.c.maxConcurrentRequests)
}
return f.ReadFromWithConcurrency(r, int(concurrency64))
return f.readFromWithConcurrency(r, int(concurrency64))
}
}
......@@ -1894,12 +2018,16 @@ func (f *File) Seek(offset int64, whence int) (int64, error) {
f.mu.Lock()
defer f.mu.Unlock()
if f.handle == "" {
return 0, os.ErrClosed
}
switch whence {
case io.SeekStart:
case io.SeekCurrent:
offset += f.offset
case io.SeekEnd:
fi, err := f.Stat()
fi, err := f.stat()
if err != nil {
return f.offset, err
}
......@@ -1918,22 +2046,84 @@ func (f *File) Seek(offset int64, whence int) (int64, error) {
// Chown changes the uid/gid of the current file.
func (f *File) Chown(uid, gid int) error {
return f.c.Chown(f.path, uid, gid)
f.mu.RLock()
defer f.mu.RUnlock()
if f.handle == "" {
return os.ErrClosed
}
return f.c.fsetstat(f.handle, sshFileXferAttrUIDGID, &FileStat{
UID: uint32(uid),
GID: uint32(gid),
})
}
// Chmod changes the permissions of the current file.
//
// See Client.Chmod for details.
func (f *File) Chmod(mode os.FileMode) error {
return f.c.setfstat(f.handle, sshFileXferAttrPermissions, toChmodPerm(mode))
f.mu.RLock()
defer f.mu.RUnlock()
if f.handle == "" {
return os.ErrClosed
}
return f.c.fsetstat(f.handle, sshFileXferAttrPermissions, toChmodPerm(mode))
}
// SetExtendedData sets extended attributes of the current file. It uses the
// SSH_FILEXFER_ATTR_EXTENDED flag in the setstat request.
//
// This flag provides a general extension mechanism for vendor-specific extensions.
// Names of the attributes should be a string of the format "name@domain", where "domain"
// is a valid, registered domain name and "name" identifies the method. Server
// implementations SHOULD ignore extended data fields that they do not understand.
func (f *File) SetExtendedData(path string, extended []StatExtended) error {
f.mu.RLock()
defer f.mu.RUnlock()
if f.handle == "" {
return os.ErrClosed
}
attrs := &FileStat{
Extended: extended,
}
return f.c.fsetstat(f.handle, sshFileXferAttrExtended, attrs)
}
// Truncate sets the size of the current file. Although it may be safely assumed
// that if the size is less than its current size it will be truncated to fit,
// the SFTP protocol does not specify what behavior the server should do when setting
// size greater than the current size.
// We send a SSH_FXP_FSETSTAT here since we have a file handle
func (f *File) Truncate(size int64) error {
f.mu.RLock()
defer f.mu.RUnlock()
if f.handle == "" {
return os.ErrClosed
}
return f.c.fsetstat(f.handle, sshFileXferAttrSize, uint64(size))
}
// Sync requests a flush of the contents of a File to stable storage.
//
// Sync requires the server to support the fsync@openssh.com extension.
func (f *File) Sync() error {
f.mu.Lock()
defer f.mu.Unlock()
if f.handle == "" {
return os.ErrClosed
}
id := f.c.nextID()
typ, data, err := f.c.sendPacket(nil, &sshFxpFsyncPacket{
typ, data, err := f.c.sendPacket(context.Background(), nil, &sshFxpFsyncPacket{
ID: id,
Handle: f.handle,
})
......@@ -1948,15 +2138,6 @@ func (f *File) Sync() error {
}
}
// Truncate sets the size of the current file. Although it may be safely assumed
// that if the size is less than its current size it will be truncated to fit,
// the SFTP protocol does not specify what behavior the server should do when setting
// size greater than the current size.
// We send a SSH_FXP_FSETSTAT here since we have a file handle
func (f *File) Truncate(size int64) error {
return f.c.setfstat(f.handle, sshFileXferAttrSize, uint64(size))
}
// normaliseError normalises an error into a more standard form that can be
// checked against stdlib errors like io.EOF or os.ErrNotExist.
func normaliseError(err error) error {
......@@ -1981,15 +2162,14 @@ func normaliseError(err error) error {
// flags converts the flags passed to OpenFile into ssh flags.
// Unsupported flags are ignored.
func flags(f int) uint32 {
func toPflags(f int) uint32 {
var out uint32
switch f & os.O_WRONLY {
case os.O_WRONLY:
out |= sshFxfWrite
switch f & (os.O_RDONLY | os.O_WRONLY | os.O_RDWR) {
case os.O_RDONLY:
out |= sshFxfRead
}
if f&os.O_RDWR == os.O_RDWR {
case os.O_WRONLY:
out |= sshFxfWrite
case os.O_RDWR:
out |= sshFxfRead | sshFxfWrite
}
if f&os.O_APPEND == os.O_APPEND {
......@@ -2013,7 +2193,7 @@ func flags(f int) uint32 {
// setuid, setgid and sticky in m, because we've historically supported those
// bits, and we mask off any non-permission bits.
func toChmodPerm(m os.FileMode) (perm uint32) {
const mask = os.ModePerm | s_ISUID | s_ISGID | s_ISVTX
const mask = os.ModePerm | os.FileMode(s_ISUID|s_ISGID|s_ISVTX)
perm = uint32(m & mask)
if m&os.ModeSetuid != 0 {
......
package sftp
import (
"context"
"encoding"
"fmt"
"io"
......@@ -128,14 +129,19 @@ type idmarshaler interface {
encoding.BinaryMarshaler
}
func (c *clientConn) sendPacket(ch chan result, p idmarshaler) (byte, []byte, error) {
func (c *clientConn) sendPacket(ctx context.Context, ch chan result, p idmarshaler) (byte, []byte, error) {
if cap(ch) < 1 {
ch = make(chan result, 1)
}
c.dispatchRequest(ch, p)
s := <-ch
return s.typ, s.data, s.err
select {
case <-ctx.Done():
return 0, nil, ctx.Err()
case s := <-ch:
return s.typ, s.data, s.err
}
}
// dispatchRequest should ideally only be called by race-detection tests outside of this file,
......
//go:build aix || darwin || dragonfly || freebsd || (!android && linux) || netbsd || openbsd || solaris || js
// +build aix darwin dragonfly freebsd !android,linux netbsd openbsd solaris js
//go:build aix || darwin || dragonfly || freebsd || (!android && linux) || netbsd || openbsd || solaris || js || zos
// +build aix darwin dragonfly freebsd !android,linux netbsd openbsd solaris js zos
package sftp
......
......@@ -56,6 +56,11 @@ func marshalFileInfo(b []byte, fi os.FileInfo) []byte {
flags, fileStat := fileStatFromInfo(fi)
b = marshalUint32(b, flags)
return marshalFileStat(b, flags, fileStat)
}
func marshalFileStat(b []byte, flags uint32, fileStat *FileStat) []byte {
if flags&sshFileXferAttrSize != 0 {
b = marshalUint64(b, fileStat.Size)
}
......@@ -91,10 +96,9 @@ func marshalStatus(b []byte, err StatusError) []byte {
}
func marshal(b []byte, v interface{}) []byte {
if v == nil {
return b
}
switch v := v.(type) {
case nil:
return b
case uint8:
return append(b, v)
case uint32:
......@@ -103,6 +107,8 @@ func marshal(b []byte, v interface{}) []byte {
return marshalUint64(b, v)
case string:
return marshalString(b, v)
case []byte:
return append(b, v...)
case os.FileInfo:
return marshalFileInfo(b, v)
default:
......@@ -168,38 +174,69 @@ func unmarshalStringSafe(b []byte) (string, []byte, error) {
return string(b[:n]), b[n:], nil
}
func unmarshalAttrs(b []byte) (*FileStat, []byte) {
flags, b := unmarshalUint32(b)
func unmarshalAttrs(b []byte) (*FileStat, []byte, error) {
flags, b, err := unmarshalUint32Safe(b)
if err != nil {
return nil, b, err
}
return unmarshalFileStat(flags, b)
}
func unmarshalFileStat(flags uint32, b []byte) (*FileStat, []byte) {
func unmarshalFileStat(flags uint32, b []byte) (*FileStat, []byte, error) {
var fs FileStat
var err error
if flags&sshFileXferAttrSize == sshFileXferAttrSize {
fs.Size, b, _ = unmarshalUint64Safe(b)
}
if flags&sshFileXferAttrUIDGID == sshFileXferAttrUIDGID {
fs.UID, b, _ = unmarshalUint32Safe(b)
fs.Size, b, err = unmarshalUint64Safe(b)
if err != nil {
return nil, b, err
}
}
if flags&sshFileXferAttrUIDGID == sshFileXferAttrUIDGID {
fs.GID, b, _ = unmarshalUint32Safe(b)
fs.UID, b, err = unmarshalUint32Safe(b)
if err != nil {
return nil, b, err
}
fs.GID, b, err = unmarshalUint32Safe(b)
if err != nil {
return nil, b, err
}
}
if flags&sshFileXferAttrPermissions == sshFileXferAttrPermissions {
fs.Mode, b, _ = unmarshalUint32Safe(b)
fs.Mode, b, err = unmarshalUint32Safe(b)
if err != nil {
return nil, b, err
}
}
if flags&sshFileXferAttrACmodTime == sshFileXferAttrACmodTime {
fs.Atime, b, _ = unmarshalUint32Safe(b)
fs.Mtime, b, _ = unmarshalUint32Safe(b)
fs.Atime, b, err = unmarshalUint32Safe(b)
if err != nil {
return nil, b, err
}
fs.Mtime, b, err = unmarshalUint32Safe(b)
if err != nil {
return nil, b, err
}
}
if flags&sshFileXferAttrExtended == sshFileXferAttrExtended {
var count uint32
count, b, _ = unmarshalUint32Safe(b)
count, b, err = unmarshalUint32Safe(b)
if err != nil {
return nil, b, err
}
ext := make([]StatExtended, count)
for i := uint32(0); i < count; i++ {
var typ string
var data string
typ, b, _ = unmarshalStringSafe(b)
data, b, _ = unmarshalStringSafe(b)
typ, b, err = unmarshalStringSafe(b)
if err != nil {
return nil, b, err
}
data, b, err = unmarshalStringSafe(b)
if err != nil {
return nil, b, err
}
ext[i] = StatExtended{
ExtType: typ,
ExtData: data,
......@@ -207,7 +244,7 @@ func unmarshalFileStat(flags uint32, b []byte) (*FileStat, []byte) {
}
fs.Extended = ext
}
return &fs, b
return &fs, b, nil
}
func unmarshalStatus(id uint32, data []byte) error {
......@@ -681,12 +718,13 @@ type sshFxpOpenPacket struct {
ID uint32
Path string
Pflags uint32
Flags uint32 // ignored
Flags uint32
Attrs interface{}
}
func (p *sshFxpOpenPacket) id() uint32 { return p.ID }
func (p *sshFxpOpenPacket) MarshalBinary() ([]byte, error) {
func (p *sshFxpOpenPacket) marshalPacket() ([]byte, []byte, error) {
l := 4 + 1 + 4 + // uint32(length) + byte(type) + uint32(id)
4 + len(p.Path) +
4 + 4
......@@ -698,7 +736,22 @@ func (p *sshFxpOpenPacket) MarshalBinary() ([]byte, error) {
b = marshalUint32(b, p.Pflags)
b = marshalUint32(b, p.Flags)
return b, nil
switch attrs := p.Attrs.(type) {
case []byte:
return b, attrs, nil // may as well short-ciruit this case.
case os.FileInfo:
_, fs := fileStatFromInfo(attrs) // we throw away the flags, and override with those in packet.
return b, marshalFileStat(nil, p.Flags, fs), nil
case *FileStat:
return b, marshalFileStat(nil, p.Flags, attrs), nil
}
return b, marshal(nil, p.Attrs), nil
}
func (p *sshFxpOpenPacket) MarshalBinary() ([]byte, error) {
header, payload, err := p.marshalPacket()
return append(header, payload...), err
}
func (p *sshFxpOpenPacket) UnmarshalBinary(b []byte) error {
......@@ -709,12 +762,25 @@ func (p *sshFxpOpenPacket) UnmarshalBinary(b []byte) error {
return err
} else if p.Pflags, b, err = unmarshalUint32Safe(b); err != nil {
return err
} else if p.Flags, _, err = unmarshalUint32Safe(b); err != nil {
} else if p.Flags, b, err = unmarshalUint32Safe(b); err != nil {
return err
}
p.Attrs = b
return nil
}
func (p *sshFxpOpenPacket) unmarshalFileStat(flags uint32) (*FileStat, error) {
switch attrs := p.Attrs.(type) {
case *FileStat:
return attrs, nil
case []byte:
fs, _, err := unmarshalFileStat(flags, attrs)
return fs, err
default:
return nil, fmt.Errorf("invalid type in unmarshalFileStat: %T", attrs)
}
}
type sshFxpReadPacket struct {
ID uint32
Len uint32
......@@ -757,7 +823,7 @@ func (p *sshFxpReadPacket) UnmarshalBinary(b []byte) error {
// So, we need: uint32(length) + byte(type) + uint32(id) + uint32(data_length)
const dataHeaderLen = 4 + 1 + 4 + 4
func (p *sshFxpReadPacket) getDataSlice(alloc *allocator, orderID uint32) []byte {
func (p *sshFxpReadPacket) getDataSlice(alloc *allocator, orderID uint32, maxTxPacket uint32) []byte {
dataLen := p.Len
if dataLen > maxTxPacket {
dataLen = maxTxPacket
......@@ -943,9 +1009,17 @@ func (p *sshFxpSetstatPacket) marshalPacket() ([]byte, []byte, error) {
b = marshalString(b, p.Path)
b = marshalUint32(b, p.Flags)
payload := marshal(nil, p.Attrs)
switch attrs := p.Attrs.(type) {
case []byte:
return b, attrs, nil // may as well short-ciruit this case.
case os.FileInfo:
_, fs := fileStatFromInfo(attrs) // we throw away the flags, and override with those in packet.
return b, marshalFileStat(nil, p.Flags, fs), nil
case *FileStat:
return b, marshalFileStat(nil, p.Flags, attrs), nil
}
return b, payload, nil
return b, marshal(nil, p.Attrs), nil
}
func (p *sshFxpSetstatPacket) MarshalBinary() ([]byte, error) {
......@@ -964,9 +1038,17 @@ func (p *sshFxpFsetstatPacket) marshalPacket() ([]byte, []byte, error) {
b = marshalString(b, p.Handle)
b = marshalUint32(b, p.Flags)
payload := marshal(nil, p.Attrs)
switch attrs := p.Attrs.(type) {
case []byte:
return b, attrs, nil // may as well short-ciruit this case.
case os.FileInfo:
_, fs := fileStatFromInfo(attrs) // we throw away the flags, and override with those in packet.
return b, marshalFileStat(nil, p.Flags, fs), nil
case *FileStat:
return b, marshalFileStat(nil, p.Flags, attrs), nil
}
return b, payload, nil
return b, marshal(nil, p.Attrs), nil
}
func (p *sshFxpFsetstatPacket) MarshalBinary() ([]byte, error) {
......@@ -987,6 +1069,18 @@ func (p *sshFxpSetstatPacket) UnmarshalBinary(b []byte) error {
return nil
}
func (p *sshFxpSetstatPacket) unmarshalFileStat(flags uint32) (*FileStat, error) {
switch attrs := p.Attrs.(type) {
case *FileStat:
return attrs, nil
case []byte:
fs, _, err := unmarshalFileStat(flags, attrs)
return fs, err
default:
return nil, fmt.Errorf("invalid type in unmarshalFileStat: %T", attrs)
}
}
func (p *sshFxpFsetstatPacket) UnmarshalBinary(b []byte) error {
var err error
if p.ID, b, err = unmarshalUint32Safe(b); err != nil {
......@@ -1000,6 +1094,18 @@ func (p *sshFxpFsetstatPacket) UnmarshalBinary(b []byte) error {
return nil
}
func (p *sshFxpFsetstatPacket) unmarshalFileStat(flags uint32) (*FileStat, error) {
switch attrs := p.Attrs.(type) {
case *FileStat:
return attrs, nil
case []byte:
fs, _, err := unmarshalFileStat(flags, attrs)
return fs, err
default:
return nil, fmt.Errorf("invalid type in unmarshalFileStat: %T", attrs)
}
}
type sshFxpHandlePacket struct {
ID uint32
Handle string
......
......@@ -3,7 +3,6 @@ package sftp
// Methods on the Request object to make working with the Flags bitmasks and
// Attr(ibutes) byte blob easier. Use Pflags() when working with an Open/Write
// request and AttrFlags() and Attributes() when working with SetStat requests.
import "os"
// FileOpenFlags defines Open and Write Flags. Correlate directly with with os.OpenFile flags
// (https://golang.org/pkg/os/#pkg-constants).
......@@ -50,14 +49,9 @@ func (r *Request) AttrFlags() FileAttrFlags {
return newFileAttrFlags(r.Flags)
}
// FileMode returns the Mode SFTP file attributes wrapped as os.FileMode
func (a FileStat) FileMode() os.FileMode {
return os.FileMode(a.Mode)
}
// Attributes parses file attributes byte blob and return them in a
// FileStat object.
func (r *Request) Attributes() *FileStat {
fs, _ := unmarshalFileStat(r.Flags, r.Attrs)
fs, _, _ := unmarshalFileStat(r.Flags, r.Attrs)
return fs
}
......@@ -30,7 +30,7 @@ type FileReader interface {
// FileWriter should return an io.WriterAt for the filepath.
//
// The request server code will call Close() on the returned io.WriterAt
// ojbect if an io.Closer type assertion succeeds.
// object if an io.Closer type assertion succeeds.
// Note in cases of an error, the error text will be sent to the client.
// Note when receiving an Append flag it is important to not open files using
// O_APPEND if you plan to use WriteAt, as they conflict.
......@@ -144,6 +144,8 @@ type NameLookupFileLister interface {
//
// If a populated entry implements [FileInfoExtendedData], extended attributes will also be returned to the client.
//
// The request server code will call Close() on ListerAt if an io.Closer type assertion succeeds.
//
// Note in cases of an error, the error text will be sent to the client.
type ListerAt interface {
ListAt([]os.FileInfo, int64) (int, error)
......
......@@ -10,7 +10,7 @@ import (
"sync"
)
var maxTxPacket uint32 = 1 << 15
const defaultMaxTxPacket uint32 = 1 << 15
// Handlers contains the 4 SFTP server request handlers.
type Handlers struct {
......@@ -28,6 +28,7 @@ type RequestServer struct {
pktMgr *packetManager
startDirectory string
maxTxPacket uint32
mu sync.RWMutex
handleCount int
......@@ -57,6 +58,22 @@ func WithStartDirectory(startDirectory string) RequestServerOption {
}
}
// WithRSMaxTxPacket sets the maximum size of the payload returned to the client,
// measured in bytes. The default value is 32768 bytes, and this option
// can only be used to increase it. Setting this option to a larger value
// should be safe, because the client decides the size of the requested payload.
//
// The default maximum packet size is 32768 bytes.
func WithRSMaxTxPacket(size uint32) RequestServerOption {
return func(rs *RequestServer) {
if size < defaultMaxTxPacket {
return
}
rs.maxTxPacket = size
}
}
// NewRequestServer creates/allocates/returns new RequestServer.
// Normally there will be one server per user-session.
func NewRequestServer(rwc io.ReadWriteCloser, h Handlers, options ...RequestServerOption) *RequestServer {
......@@ -73,6 +90,7 @@ func NewRequestServer(rwc io.ReadWriteCloser, h Handlers, options ...RequestServ
pktMgr: newPktMgr(svrConn),
startDirectory: "/",
maxTxPacket: defaultMaxTxPacket,
openRequests: make(map[string]*Request),
}
......@@ -260,7 +278,7 @@ func (rs *RequestServer) packetWorker(ctx context.Context, pktChan chan orderedR
Method: "Stat",
Filepath: cleanPathWithBase(rs.startDirectory, request.Filepath),
}
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket)
}
case *sshFxpFsetstatPacket:
handle := pkt.getHandle()
......@@ -272,7 +290,7 @@ func (rs *RequestServer) packetWorker(ctx context.Context, pktChan chan orderedR
Method: "Setstat",
Filepath: cleanPathWithBase(rs.startDirectory, request.Filepath),
}
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket)
}
case *sshFxpExtendedPacketPosixRename:
request := &Request{
......@@ -280,24 +298,24 @@ func (rs *RequestServer) packetWorker(ctx context.Context, pktChan chan orderedR
Filepath: cleanPathWithBase(rs.startDirectory, pkt.Oldpath),
Target: cleanPathWithBase(rs.startDirectory, pkt.Newpath),
}
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket)
case *sshFxpExtendedPacketStatVFS:
request := &Request{
Method: "StatVFS",
Filepath: cleanPathWithBase(rs.startDirectory, pkt.Path),
}
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket)
case hasHandle:
handle := pkt.getHandle()
request, ok := rs.getRequest(handle)
if !ok {
rpkt = statusFromError(pkt.id(), EBADF)
} else {
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket)
}
case hasPath:
request := requestFromPacket(ctx, pkt, rs.startDirectory)
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket)
request.close()
default:
rpkt = statusFromError(pkt.id(), ErrSSHFxOpUnsupported)
......
......@@ -121,6 +121,22 @@ func (s *state) getListerAt() ListerAt {
return s.listerAt
}
func (s *state) closeListerAt() error {
s.mu.Lock()
defer s.mu.Unlock()
var err error
if s.listerAt != nil {
if c, ok := s.listerAt.(io.Closer); ok {
err = c.Close()
}
s.listerAt = nil
}
return err
}
// Request contains the data and state for the incoming service request.
type Request struct {
// Get, Put, Setstat, Stat, Rename, Remove
......@@ -178,6 +194,7 @@ func requestFromPacket(ctx context.Context, pkt hasPath, baseDir string) *Reques
switch p := pkt.(type) {
case *sshFxpOpenPacket:
request.Flags = p.Pflags
request.Attrs = p.Attrs.([]byte)
case *sshFxpSetstatPacket:
request.Flags = p.Flags
request.Attrs = p.Attrs.([]byte)
......@@ -229,9 +246,9 @@ func (r *Request) close() error {
}
}()
rd, wr, rw := r.getAllReaderWriters()
err := r.state.closeListerAt()
var err error
rd, wr, rw := r.getAllReaderWriters()
// Close errors on a Writer are far more likely to be the important one.
// As they can be information that there was a loss of data.
......@@ -283,14 +300,14 @@ func (r *Request) transferError(err error) {
}
// called from worker to handle packet/request
func (r *Request) call(handlers Handlers, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
func (r *Request) call(handlers Handlers, pkt requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) responsePacket {
switch r.Method {
case "Get":
return fileget(handlers.FileGet, r, pkt, alloc, orderID)
return fileget(handlers.FileGet, r, pkt, alloc, orderID, maxTxPacket)
case "Put":
return fileput(handlers.FilePut, r, pkt, alloc, orderID)
return fileput(handlers.FilePut, r, pkt, alloc, orderID, maxTxPacket)
case "Open":
return fileputget(handlers.FilePut, r, pkt, alloc, orderID)
return fileputget(handlers.FilePut, r, pkt, alloc, orderID, maxTxPacket)
case "Setstat", "Rename", "Rmdir", "Mkdir", "Link", "Symlink", "Remove", "PosixRename", "StatVFS":
return filecmd(handlers.FileCmd, r, pkt)
case "List":
......@@ -375,13 +392,13 @@ func (r *Request) opendir(h Handlers, pkt requestPacket) responsePacket {
}
// wrap FileReader handler
func fileget(h FileReader, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
func fileget(h FileReader, r *Request, pkt requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) responsePacket {
rd := r.getReaderAt()
if rd == nil {
return statusFromError(pkt.id(), errors.New("unexpected read packet"))
}
data, offset, _ := packetData(pkt, alloc, orderID)
data, offset, _ := packetData(pkt, alloc, orderID, maxTxPacket)
n, err := rd.ReadAt(data, offset)
// only return EOF error if no data left to read
......@@ -397,20 +414,20 @@ func fileget(h FileReader, r *Request, pkt requestPacket, alloc *allocator, orde
}
// wrap FileWriter handler
func fileput(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
func fileput(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) responsePacket {
wr := r.getWriterAt()
if wr == nil {
return statusFromError(pkt.id(), errors.New("unexpected write packet"))
}
data, offset, _ := packetData(pkt, alloc, orderID)
data, offset, _ := packetData(pkt, alloc, orderID, maxTxPacket)
_, err := wr.WriteAt(data, offset)
return statusFromError(pkt.id(), err)
}
// wrap OpenFileWriter handler
func fileputget(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
func fileputget(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) responsePacket {
rw := r.getWriterAtReaderAt()
if rw == nil {
return statusFromError(pkt.id(), errors.New("unexpected write and read packet"))
......@@ -418,7 +435,7 @@ func fileputget(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, o
switch p := pkt.(type) {
case *sshFxpReadPacket:
data, offset := p.getDataSlice(alloc, orderID), int64(p.Offset)
data, offset := p.getDataSlice(alloc, orderID, maxTxPacket), int64(p.Offset)
n, err := rw.ReadAt(data, offset)
// only return EOF error if no data left to read
......@@ -444,10 +461,10 @@ func fileputget(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, o
}
// file data for additional read/write packets
func packetData(p requestPacket, alloc *allocator, orderID uint32) (data []byte, offset int64, length uint32) {
func packetData(p requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) (data []byte, offset int64, length uint32) {
switch p := p.(type) {
case *sshFxpReadPacket:
return p.getDataSlice(alloc, orderID), int64(p.Offset), p.Len
return p.getDataSlice(alloc, orderID, maxTxPacket), int64(p.Offset), p.Len
case *sshFxpWritePacket:
return p.Data, int64(p.Offset), p.Length
}
......
......@@ -34,6 +34,7 @@ type Server struct {
openFilesLock sync.RWMutex
handleCount int
workDir string
maxTxPacket uint32
}
func (svr *Server) nextHandle(f *os.File) string {
......@@ -86,6 +87,7 @@ func NewServer(rwc io.ReadWriteCloser, options ...ServerOption) (*Server, error)
debugStream: ioutil.Discard,
pktMgr: newPktMgr(svrConn),
openFiles: make(map[string]*os.File),
maxTxPacket: defaultMaxTxPacket,
}
for _, o := range options {
......@@ -139,6 +141,24 @@ func WithServerWorkingDirectory(workDir string) ServerOption {
}
}
// WithMaxTxPacket sets the maximum size of the payload returned to the client,
// measured in bytes. The default value is 32768 bytes, and this option
// can only be used to increase it. Setting this option to a larger value
// should be safe, because the client decides the size of the requested payload.
//
// The default maximum packet size is 32768 bytes.
func WithMaxTxPacket(size uint32) ServerOption {
return func(s *Server) error {
if size < defaultMaxTxPacket {
return errors.New("size must be greater than or equal to 32768")
}
s.maxTxPacket = size
return nil
}
}
type rxPacket struct {
pktType fxp
pktBytes []byte
......@@ -287,7 +307,7 @@ func handlePacket(s *Server, p orderedRequest) error {
f, ok := s.getHandle(p.Handle)
if ok {
err = nil
data := p.getDataSlice(s.pktMgr.alloc, orderID)
data := p.getDataSlice(s.pktMgr.alloc, orderID, s.maxTxPacket)
n, _err := f.ReadAt(data, int64(p.Offset))
if _err != nil && (_err != io.EOF || n == 0) {
err = _err
......@@ -462,7 +482,18 @@ func (p *sshFxpOpenPacket) respond(svr *Server) responsePacket {
osFlags |= os.O_EXCL
}
f, err := os.OpenFile(svr.toLocalPath(p.Path), osFlags, 0o644)
mode := os.FileMode(0o644)
// Like OpenSSH, we only handle permissions here, and only when the file is being created.
// Otherwise, the permissions are ignored.
if p.Flags&sshFileXferAttrPermissions != 0 {
fs, err := p.unmarshalFileStat(p.Flags)
if err != nil {
return statusFromError(p.ID, err)
}
mode = fs.FileMode() & os.ModePerm
}
f, err := os.OpenFile(svr.toLocalPath(p.Path), osFlags, mode)
if err != nil {
return statusFromError(p.ID, err)
}
......@@ -496,44 +527,23 @@ func (p *sshFxpReaddirPacket) respond(svr *Server) responsePacket {
}
func (p *sshFxpSetstatPacket) respond(svr *Server) responsePacket {
// additional unmarshalling is required for each possibility here
b := p.Attrs.([]byte)
var err error
path := svr.toLocalPath(p.Path)
p.Path = svr.toLocalPath(p.Path)
debug("setstat name %q", path)
debug("setstat name \"%s\"", p.Path)
if (p.Flags & sshFileXferAttrSize) != 0 {
var size uint64
if size, b, err = unmarshalUint64Safe(b); err == nil {
err = os.Truncate(p.Path, int64(size))
}
fs, err := p.unmarshalFileStat(p.Flags)
if err == nil && (p.Flags&sshFileXferAttrSize) != 0 {
err = os.Truncate(path, int64(fs.Size))
}
if (p.Flags & sshFileXferAttrPermissions) != 0 {
var mode uint32
if mode, b, err = unmarshalUint32Safe(b); err == nil {
err = os.Chmod(p.Path, os.FileMode(mode))
}
if err == nil && (p.Flags&sshFileXferAttrPermissions) != 0 {
err = os.Chmod(path, fs.FileMode())
}
if (p.Flags & sshFileXferAttrACmodTime) != 0 {
var atime uint32
var mtime uint32
if atime, b, err = unmarshalUint32Safe(b); err != nil {
} else if mtime, b, err = unmarshalUint32Safe(b); err != nil {
} else {
atimeT := time.Unix(int64(atime), 0)
mtimeT := time.Unix(int64(mtime), 0)
err = os.Chtimes(p.Path, atimeT, mtimeT)
}
if err == nil && (p.Flags&sshFileXferAttrUIDGID) != 0 {
err = os.Chown(path, int(fs.UID), int(fs.GID))
}
if (p.Flags & sshFileXferAttrUIDGID) != 0 {
var uid uint32
var gid uint32
if uid, b, err = unmarshalUint32Safe(b); err != nil {
} else if gid, _, err = unmarshalUint32Safe(b); err != nil {
} else {
err = os.Chown(p.Path, int(uid), int(gid))
}
if err == nil && (p.Flags&sshFileXferAttrACmodTime) != 0 {
err = os.Chtimes(path, fs.AccessTime(), fs.ModTime())
}
return statusFromError(p.ID, err)
......@@ -545,41 +555,32 @@ func (p *sshFxpFsetstatPacket) respond(svr *Server) responsePacket {
return statusFromError(p.ID, EBADF)
}
// additional unmarshalling is required for each possibility here
b := p.Attrs.([]byte)
var err error
path := f.Name()
debug("fsetstat name \"%s\"", f.Name())
if (p.Flags & sshFileXferAttrSize) != 0 {
var size uint64
if size, b, err = unmarshalUint64Safe(b); err == nil {
err = f.Truncate(int64(size))
}
debug("fsetstat name %q", path)
fs, err := p.unmarshalFileStat(p.Flags)
if err == nil && (p.Flags&sshFileXferAttrSize) != 0 {
err = f.Truncate(int64(fs.Size))
}
if (p.Flags & sshFileXferAttrPermissions) != 0 {
var mode uint32
if mode, b, err = unmarshalUint32Safe(b); err == nil {
err = f.Chmod(os.FileMode(mode))
}
if err == nil && (p.Flags&sshFileXferAttrPermissions) != 0 {
err = f.Chmod(fs.FileMode())
}
if (p.Flags & sshFileXferAttrACmodTime) != 0 {
var atime uint32
var mtime uint32
if atime, b, err = unmarshalUint32Safe(b); err != nil {
} else if mtime, b, err = unmarshalUint32Safe(b); err != nil {
} else {
atimeT := time.Unix(int64(atime), 0)
mtimeT := time.Unix(int64(mtime), 0)
err = os.Chtimes(f.Name(), atimeT, mtimeT)
}
if err == nil && (p.Flags&sshFileXferAttrUIDGID) != 0 {
err = f.Chown(int(fs.UID), int(fs.GID))
}
if (p.Flags & sshFileXferAttrUIDGID) != 0 {
var uid uint32
var gid uint32
if uid, b, err = unmarshalUint32Safe(b); err != nil {
} else if gid, _, err = unmarshalUint32Safe(b); err != nil {
} else {
err = f.Chown(int(uid), int(gid))
if err == nil && (p.Flags&sshFileXferAttrACmodTime) != 0 {
type chtimer interface {
Chtimes(atime, mtime time.Time) error
}
switch f := interface{}(f).(type) {
case chtimer:
// future-compatible, for when/if *os.File supports Chtimes.
err = f.Chtimes(fs.AccessTime(), fs.ModTime())
default:
err = os.Chtimes(path, fs.AccessTime(), fs.ModTime())
}
}
......
package sftp
import (
"os"
"syscall"
)
var EBADF = syscall.NewError("fd out of range or not open")
func wrapPathError(filepath string, err error) error {
if errno, ok := err.(syscall.ErrorString); ok {
return &os.PathError{Path: filepath, Err: errno}
}
return err
}
// translateErrno translates a syscall error number to a SFTP error code.
func translateErrno(errno syscall.ErrorString) uint32 {
switch errno {
case "":
return sshFxOk
case syscall.ENOENT:
return sshFxNoSuchFile
case syscall.EPERM:
return sshFxPermissionDenied
}
return sshFxFailure
}
func translateSyscallError(err error) (uint32, bool) {
switch e := err.(type) {
case syscall.ErrorString:
return translateErrno(e), true
case *os.PathError:
debug("statusFromError,pathError: error is %T %#v", e.Err, e.Err)
if errno, ok := e.Err.(syscall.ErrorString); ok {
return translateErrno(errno), true
}
}
return 0, false
}
// isRegular returns true if the mode describes a regular file.
func isRegular(mode uint32) bool {
return mode&S_IFMT == syscall.S_IFREG
}
// toFileMode converts sftp filemode bits to the os.FileMode specification
func toFileMode(mode uint32) os.FileMode {
var fm = os.FileMode(mode & 0777)
switch mode & S_IFMT {
case syscall.S_IFBLK:
fm |= os.ModeDevice
case syscall.S_IFCHR:
fm |= os.ModeDevice | os.ModeCharDevice
case syscall.S_IFDIR:
fm |= os.ModeDir
case syscall.S_IFIFO:
fm |= os.ModeNamedPipe
case syscall.S_IFLNK:
fm |= os.ModeSymlink
case syscall.S_IFREG:
// nothing to do
case syscall.S_IFSOCK:
fm |= os.ModeSocket
}
return fm
}
// fromFileMode converts from the os.FileMode specification to sftp filemode bits
func fromFileMode(mode os.FileMode) uint32 {
ret := uint32(mode & os.ModePerm)
switch mode & os.ModeType {
case os.ModeDevice | os.ModeCharDevice:
ret |= syscall.S_IFCHR
case os.ModeDevice:
ret |= syscall.S_IFBLK
case os.ModeDir:
ret |= syscall.S_IFDIR
case os.ModeNamedPipe:
ret |= syscall.S_IFIFO
case os.ModeSymlink:
ret |= syscall.S_IFLNK
case 0:
ret |= syscall.S_IFREG
case os.ModeSocket:
ret |= syscall.S_IFSOCK
}
return ret
}
// Plan 9 doesn't have setuid, setgid or sticky, but a Plan 9 client should
// be able to send these bits to a POSIX server.
const (
s_ISUID = 04000
s_ISGID = 02000
s_ISVTX = 01000
)
//go:build !plan9
// +build !plan9
package sftp
import (
"os"
"syscall"
)
const EBADF = syscall.EBADF
func wrapPathError(filepath string, err error) error {
if errno, ok := err.(syscall.Errno); ok {
return &os.PathError{Path: filepath, Err: errno}
}
return err
}
// translateErrno translates a syscall error number to a SFTP error code.
func translateErrno(errno syscall.Errno) uint32 {
switch errno {
case 0:
return sshFxOk
case syscall.ENOENT:
return sshFxNoSuchFile
case syscall.EACCES, syscall.EPERM:
return sshFxPermissionDenied
}
return sshFxFailure
}
func translateSyscallError(err error) (uint32, bool) {
switch e := err.(type) {
case syscall.Errno:
return translateErrno(e), true
case *os.PathError:
debug("statusFromError,pathError: error is %T %#v", e.Err, e.Err)
if errno, ok := e.Err.(syscall.Errno); ok {
return translateErrno(errno), true
}
}
return 0, false
}
// isRegular returns true if the mode describes a regular file.
func isRegular(mode uint32) bool {
return mode&S_IFMT == syscall.S_IFREG
}
// toFileMode converts sftp filemode bits to the os.FileMode specification
func toFileMode(mode uint32) os.FileMode {
var fm = os.FileMode(mode & 0777)
switch mode & S_IFMT {
case syscall.S_IFBLK:
fm |= os.ModeDevice
case syscall.S_IFCHR:
fm |= os.ModeDevice | os.ModeCharDevice
case syscall.S_IFDIR:
fm |= os.ModeDir
case syscall.S_IFIFO:
fm |= os.ModeNamedPipe
case syscall.S_IFLNK:
fm |= os.ModeSymlink
case syscall.S_IFREG:
// nothing to do
case syscall.S_IFSOCK:
fm |= os.ModeSocket
}
if mode&syscall.S_ISUID != 0 {
fm |= os.ModeSetuid
}
if mode&syscall.S_ISGID != 0 {
fm |= os.ModeSetgid
}
if mode&syscall.S_ISVTX != 0 {
fm |= os.ModeSticky
}
return fm
}
// fromFileMode converts from the os.FileMode specification to sftp filemode bits
func fromFileMode(mode os.FileMode) uint32 {
ret := uint32(mode & os.ModePerm)
switch mode & os.ModeType {
case os.ModeDevice | os.ModeCharDevice:
ret |= syscall.S_IFCHR
case os.ModeDevice:
ret |= syscall.S_IFBLK
case os.ModeDir:
ret |= syscall.S_IFDIR
case os.ModeNamedPipe:
ret |= syscall.S_IFIFO
case os.ModeSymlink:
ret |= syscall.S_IFLNK
case 0:
ret |= syscall.S_IFREG
case os.ModeSocket:
ret |= syscall.S_IFSOCK
}
if mode&os.ModeSetuid != 0 {
ret |= syscall.S_ISUID
}
if mode&os.ModeSetgid != 0 {
ret |= syscall.S_ISGID
}
if mode&os.ModeSticky != 0 {
ret |= syscall.S_ISVTX
}
return ret
}
const (
s_ISUID = syscall.S_ISUID
s_ISGID = syscall.S_ISGID
s_ISVTX = syscall.S_ISVTX
)
//go:build plan9 || windows || (js && wasm)
// +build plan9 windows js,wasm
// Go defines S_IFMT on windows, plan9 and js/wasm as 0x1f000 instead of
// 0xf000. None of the the other S_IFxyz values include the "1" (in 0x1f000)
// which prevents them from matching the bitmask.
package sftp
const S_IFMT = 0xf000
//go:build !plan9 && !windows && (!js || !wasm)
// +build !plan9
// +build !windows
// +build !js !wasm
package sftp
import "syscall"
const S_IFMT = syscall.S_IFMT
env:
CIRRUS_CLONE_DEPTH: 1
GO_VERSION: go1.22.2
GO_VERSION: go1.23.0
freebsd_13_task:
freebsd_instance:
image_family: freebsd-13-2
image_family: freebsd-13-3
install_script: |
pkg install -y go
GOBIN=$PWD/bin go install golang.org/dl/${GO_VERSION}@latest
......
......@@ -73,3 +73,26 @@ func GetPossible() (int, error) {
func GetPresent() (int, error) {
return getPresent()
}
// ListOffline returns the list of offline CPUs. See [GetOffline] for details on
// when a CPU is considered offline.
func ListOffline() ([]int, error) {
return listOffline()
}
// ListOnline returns the list of CPUs that are online and being scheduled.
func ListOnline() ([]int, error) {
return listOnline()
}
// ListPossible returns the list of possible CPUs. See [GetPossible] for
// details on when a CPU is considered possible.
func ListPossible() ([]int, error) {
return listPossible()
}
// ListPresent returns the list of present CPUs. See [GetPresent] for
// details on when a CPU is considered present.
func ListPresent() ([]int, error) {
return listPresent()
}
......@@ -15,6 +15,7 @@
package numcpus
import (
"fmt"
"os"
"path/filepath"
"strconv"
......@@ -23,7 +24,14 @@ import (
"golang.org/x/sys/unix"
)
const sysfsCPUBasePath = "/sys/devices/system/cpu"
const (
sysfsCPUBasePath = "/sys/devices/system/cpu"
offline = "offline"
online = "online"
possible = "possible"
present = "present"
)
func getFromCPUAffinity() (int, error) {
var cpuSet unix.CPUSet
......@@ -33,19 +41,26 @@ func getFromCPUAffinity() (int, error) {
return cpuSet.Count(), nil
}
func readCPURange(file string) (int, error) {
func readCPURangeWith[T any](file string, f func(cpus string) (T, error)) (T, error) {
var zero T
buf, err := os.ReadFile(filepath.Join(sysfsCPUBasePath, file))
if err != nil {
return 0, err
return zero, err
}
return parseCPURange(strings.Trim(string(buf), "\n "))
return f(strings.Trim(string(buf), "\n "))
}
func parseCPURange(cpus string) (int, error) {
func countCPURange(cpus string) (int, error) {
// Treat empty file as valid. This might be the case if there are no offline CPUs in which
// case /sys/devices/system/cpu/offline is empty.
if cpus == "" {
return 0, nil
}
n := int(0)
for _, cpuRange := range strings.Split(cpus, ",") {
if len(cpuRange) == 0 {
continue
if cpuRange == "" {
return 0, fmt.Errorf("empty CPU range in CPU string %q", cpus)
}
from, to, found := strings.Cut(cpuRange, "-")
first, err := strconv.ParseUint(from, 10, 32)
......@@ -60,11 +75,49 @@ func parseCPURange(cpus string) (int, error) {
if err != nil {
return 0, err
}
if last < first {
return 0, fmt.Errorf("last CPU in range (%d) less than first (%d)", last, first)
}
n += int(last - first + 1)
}
return n, nil
}
func listCPURange(cpus string) ([]int, error) {
// See comment in countCPURange.
if cpus == "" {
return []int{}, nil
}
list := []int{}
for _, cpuRange := range strings.Split(cpus, ",") {
if cpuRange == "" {
return nil, fmt.Errorf("empty CPU range in CPU string %q", cpus)
}
from, to, found := strings.Cut(cpuRange, "-")
first, err := strconv.ParseUint(from, 10, 32)
if err != nil {
return nil, err
}
if !found {
// range containing a single element
list = append(list, int(first))
continue
}
last, err := strconv.ParseUint(to, 10, 32)
if err != nil {
return nil, err
}
if last < first {
return nil, fmt.Errorf("last CPU in range (%d) less than first (%d)", last, first)
}
for cpu := int(first); cpu <= int(last); cpu++ {
list = append(list, cpu)
}
}
return list, nil
}
func getConfigured() (int, error) {
d, err := os.Open(sysfsCPUBasePath)
if err != nil {
......@@ -100,20 +153,36 @@ func getKernelMax() (int, error) {
}
func getOffline() (int, error) {
return readCPURange("offline")
return readCPURangeWith(offline, countCPURange)
}
func getOnline() (int, error) {
if n, err := getFromCPUAffinity(); err == nil {
return n, nil
}
return readCPURange("online")
return readCPURangeWith(online, countCPURange)
}
func getPossible() (int, error) {
return readCPURange("possible")
return readCPURangeWith(possible, countCPURange)
}
func getPresent() (int, error) {
return readCPURange("present")
return readCPURangeWith(present, countCPURange)
}
func listOffline() ([]int, error) {
return readCPURangeWith(offline, listCPURange)
}
func listOnline() ([]int, error) {
return readCPURangeWith(online, listCPURange)
}
func listPossible() ([]int, error) {
return readCPURangeWith(possible, listCPURange)
}
func listPresent() ([]int, error) {
return readCPURangeWith(present, listCPURange)
}
......@@ -510,8 +510,8 @@ userAuthLoop:
if err := s.transport.writePacket(Marshal(discMsg)); err != nil {
return nil, err
}
return nil, discMsg
authErrs = append(authErrs, discMsg)
return nil, &ServerAuthError{Errors: authErrs}
}
var userAuthReq userAuthRequestMsg
......
......@@ -156,7 +156,7 @@ from the generated architecture-specific files listed below, and merge these
into a common file for each OS.
The merge is performed in the following steps:
1. Construct the set of common code that is idential in all architecture-specific files.
1. Construct the set of common code that is identical in all architecture-specific files.
2. Write this common code to the merged file.
3. Remove the common code from all architecture-specific files.
......
......@@ -656,7 +656,7 @@ errors=$(
signals=$(
echo '#include <signal.h>' | $CC -x c - -E -dM $ccflags |
awk '$1=="#define" && $2 ~ /^SIG[A-Z0-9]+$/ { print $2 }' |
grep -v 'SIGSTKSIZE\|SIGSTKSZ\|SIGRT\|SIGMAX64' |
grep -E -v '(SIGSTKSIZE|SIGSTKSZ|SIGRT|SIGMAX64)' |
sort
)
......@@ -666,7 +666,7 @@ echo '#include <errno.h>' | $CC -x c - -E -dM $ccflags |
sort >_error.grep
echo '#include <signal.h>' | $CC -x c - -E -dM $ccflags |
awk '$1=="#define" && $2 ~ /^SIG[A-Z0-9]+$/ { print "^\t" $2 "[ \t]*=" }' |
grep -v 'SIGSTKSIZE\|SIGSTKSZ\|SIGRT\|SIGMAX64' |
grep -E -v '(SIGSTKSIZE|SIGSTKSZ|SIGRT|SIGMAX64)' |
sort >_signal.grep
echo '// mkerrors.sh' "$@"
......