mirror of https://github.com/gogits/gogs.git
Kim Lindhardt Madsen
9 years ago
236 changed files with 23209 additions and 7088 deletions
@ -0,0 +1,8 @@
|
||||
#!/bin/sh |
||||
|
||||
if test -f ./setup; then |
||||
source ./setup |
||||
fi |
||||
|
||||
export USER=git |
||||
exec gosu $USER /app/gogs/gogs web |
@ -0,0 +1,7 @@
|
||||
#!/bin/sh |
||||
|
||||
if test -f ./setup; then |
||||
source ./setup |
||||
fi |
||||
|
||||
exec gosu root /usr/sbin/sshd -D -f /app/gogs/docker/sshd_config |
@ -0,0 +1,17 @@
|
||||
Port 22 |
||||
AddressFamily any |
||||
ListenAddress 0.0.0.0 |
||||
ListenAddress :: |
||||
Protocol 2 |
||||
LogLevel INFO |
||||
HostKey /data/ssh/ssh_host_key |
||||
HostKey /data/ssh/ssh_host_rsa_key |
||||
HostKey /data/ssh/ssh_host_dsa_key |
||||
HostKey /data/ssh/ssh_host_ecdsa_key |
||||
HostKey /data/ssh/ssh_host_ed25519_key |
||||
PermitRootLogin no |
||||
AuthorizedKeysFile .ssh/authorized_keys |
||||
PasswordAuthentication no |
||||
UsePrivilegeSeparation no |
||||
PermitUserEnvironment yes |
||||
AllowUsers git |
@ -1,106 +0,0 @@
|
||||
// Copyright 2014 The Gogs Authors. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package models |
||||
|
||||
import ( |
||||
"errors" |
||||
"time" |
||||
) |
||||
|
||||
type OauthType int |
||||
|
||||
const ( |
||||
GITHUB OauthType = iota + 1 |
||||
GOOGLE |
||||
TWITTER |
||||
QQ |
||||
WEIBO |
||||
BITBUCKET |
||||
FACEBOOK |
||||
) |
||||
|
||||
var ( |
||||
ErrOauth2RecordNotExist = errors.New("OAuth2 record does not exist") |
||||
ErrOauth2NotAssociated = errors.New("OAuth2 is not associated with user") |
||||
) |
||||
|
||||
type Oauth2 struct { |
||||
Id int64 |
||||
Uid int64 `xorm:"unique(s)"` // userId
|
||||
User *User `xorm:"-"` |
||||
Type int `xorm:"unique(s) unique(oauth)"` // twitter,github,google...
|
||||
Identity string `xorm:"unique(s) unique(oauth)"` // id..
|
||||
Token string `xorm:"TEXT not null"` |
||||
Created time.Time `xorm:"CREATED"` |
||||
Updated time.Time |
||||
HasRecentActivity bool `xorm:"-"` |
||||
} |
||||
|
||||
func BindUserOauth2(userId, oauthId int64) error { |
||||
_, err := x.Id(oauthId).Update(&Oauth2{Uid: userId}) |
||||
return err |
||||
} |
||||
|
||||
func AddOauth2(oa *Oauth2) error { |
||||
_, err := x.Insert(oa) |
||||
return err |
||||
} |
||||
|
||||
func GetOauth2(identity string) (oa *Oauth2, err error) { |
||||
oa = &Oauth2{Identity: identity} |
||||
isExist, err := x.Get(oa) |
||||
if err != nil { |
||||
return |
||||
} else if !isExist { |
||||
return nil, ErrOauth2RecordNotExist |
||||
} else if oa.Uid == -1 { |
||||
return oa, ErrOauth2NotAssociated |
||||
} |
||||
oa.User, err = GetUserByID(oa.Uid) |
||||
return oa, err |
||||
} |
||||
|
||||
func GetOauth2ById(id int64) (oa *Oauth2, err error) { |
||||
oa = new(Oauth2) |
||||
has, err := x.Id(id).Get(oa) |
||||
if err != nil { |
||||
return nil, err |
||||
} else if !has { |
||||
return nil, ErrOauth2RecordNotExist |
||||
} |
||||
return oa, nil |
||||
} |
||||
|
||||
// UpdateOauth2 updates given OAuth2.
|
||||
func UpdateOauth2(oa *Oauth2) error { |
||||
_, err := x.Id(oa.Id).AllCols().Update(oa) |
||||
return err |
||||
} |
||||
|
||||
// GetOauthByUserId returns list of oauthes that are related to given user.
|
||||
func GetOauthByUserId(uid int64) ([]*Oauth2, error) { |
||||
socials := make([]*Oauth2, 0, 5) |
||||
err := x.Find(&socials, Oauth2{Uid: uid}) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
for _, social := range socials { |
||||
social.HasRecentActivity = social.Updated.Add(7 * 24 * time.Hour).After(time.Now()) |
||||
} |
||||
return socials, err |
||||
} |
||||
|
||||
// DeleteOauth2ById deletes a oauth2 by ID.
|
||||
func DeleteOauth2ById(id int64) error { |
||||
_, err := x.Delete(&Oauth2{Id: id}) |
||||
return err |
||||
} |
||||
|
||||
// CleanUnbindOauth deletes all unbind OAuthes.
|
||||
func CleanUnbindOauth() error { |
||||
_, err := x.Delete(&Oauth2{Uid: -1}) |
||||
return err |
||||
} |
@ -0,0 +1,262 @@
|
||||
// Copyright 2015 The Gogs Authors. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package models |
||||
|
||||
import ( |
||||
"fmt" |
||||
"io/ioutil" |
||||
"os" |
||||
"path" |
||||
"strings" |
||||
"time" |
||||
|
||||
"github.com/Unknwon/com" |
||||
"github.com/go-xorm/xorm" |
||||
|
||||
"github.com/gogits/gogs/modules/git" |
||||
"github.com/gogits/gogs/modules/log" |
||||
"github.com/gogits/gogs/modules/process" |
||||
) |
||||
|
||||
type PullRequestType int |
||||
|
||||
const ( |
||||
PULL_REQUEST_GOGS PullRequestType = iota |
||||
PLLL_ERQUEST_GIT |
||||
) |
||||
|
||||
type PullRequestStatus int |
||||
|
||||
const ( |
||||
PULL_REQUEST_STATUS_CONFLICT PullRequestStatus = iota |
||||
PULL_REQUEST_STATUS_CHECKING |
||||
PULL_REQUEST_STATUS_MERGEABLE |
||||
) |
||||
|
||||
// PullRequest represents relation between pull request and repositories.
|
||||
type PullRequest struct { |
||||
ID int64 `xorm:"pk autoincr"` |
||||
Type PullRequestType |
||||
Status PullRequestStatus |
||||
|
||||
IssueID int64 `xorm:"INDEX"` |
||||
Issue *Issue `xorm:"-"` |
||||
Index int64 |
||||
|
||||
HeadRepoID int64 |
||||
HeadRepo *Repository `xorm:"-"` |
||||
BaseRepoID int64 |
||||
HeadUserName string |
||||
HeadBranch string |
||||
BaseBranch string |
||||
MergeBase string `xorm:"VARCHAR(40)"` |
||||
MergedCommitID string `xorm:"VARCHAR(40)"` |
||||
|
||||
HasMerged bool |
||||
Merged time.Time |
||||
MergerID int64 |
||||
Merger *User `xorm:"-"` |
||||
} |
||||
|
||||
// Note: don't try to get Pull because will end up recursive querying.
|
||||
func (pr *PullRequest) AfterSet(colName string, _ xorm.Cell) { |
||||
var err error |
||||
switch colName { |
||||
case "head_repo_id": |
||||
// FIXME: shouldn't show error if it's known that head repository has been removed.
|
||||
pr.HeadRepo, err = GetRepositoryByID(pr.HeadRepoID) |
||||
if err != nil { |
||||
log.Error(3, "GetRepositoryByID[%d]: %v", pr.ID, err) |
||||
} |
||||
case "merger_id": |
||||
if !pr.HasMerged { |
||||
return |
||||
} |
||||
|
||||
pr.Merger, err = GetUserByID(pr.MergerID) |
||||
if err != nil { |
||||
if IsErrUserNotExist(err) { |
||||
pr.MergerID = -1 |
||||
pr.Merger = NewFakeUser() |
||||
} else { |
||||
log.Error(3, "GetUserByID[%d]: %v", pr.ID, err) |
||||
} |
||||
} |
||||
case "merged": |
||||
if !pr.HasMerged { |
||||
return |
||||
} |
||||
|
||||
pr.Merged = regulateTimeZone(pr.Merged) |
||||
} |
||||
} |
||||
|
||||
// CanAutoMerge returns true if this pull request can be merged automatically.
|
||||
func (pr *PullRequest) CanAutoMerge() bool { |
||||
return pr.Status == PULL_REQUEST_STATUS_MERGEABLE |
||||
} |
||||
|
||||
// Merge merges pull request to base repository.
|
||||
func (pr *PullRequest) Merge(doer *User, baseGitRepo *git.Repository) (err error) { |
||||
sess := x.NewSession() |
||||
defer sessionRelease(sess) |
||||
if err = sess.Begin(); err != nil { |
||||
return err |
||||
} |
||||
|
||||
if err = pr.Issue.changeStatus(sess, doer, true); err != nil { |
||||
return fmt.Errorf("Pull.changeStatus: %v", err) |
||||
} |
||||
|
||||
headRepoPath := RepoPath(pr.HeadUserName, pr.HeadRepo.Name) |
||||
headGitRepo, err := git.OpenRepository(headRepoPath) |
||||
if err != nil { |
||||
return fmt.Errorf("OpenRepository: %v", err) |
||||
} |
||||
pr.MergedCommitID, err = headGitRepo.GetCommitIdOfBranch(pr.HeadBranch) |
||||
if err != nil { |
||||
return fmt.Errorf("GetCommitIdOfBranch: %v", err) |
||||
} |
||||
|
||||
if err = mergePullRequestAction(sess, doer, pr.Issue.Repo, pr.Issue); err != nil { |
||||
return fmt.Errorf("mergePullRequestAction: %v", err) |
||||
} |
||||
|
||||
pr.HasMerged = true |
||||
pr.Merged = time.Now() |
||||
pr.MergerID = doer.Id |
||||
if _, err = sess.Id(pr.ID).AllCols().Update(pr); err != nil { |
||||
return fmt.Errorf("update pull request: %v", err) |
||||
} |
||||
|
||||
// Clone base repo.
|
||||
tmpBasePath := path.Join("data/tmp/repos", com.ToStr(time.Now().Nanosecond())+".git") |
||||
os.MkdirAll(path.Dir(tmpBasePath), os.ModePerm) |
||||
defer os.RemoveAll(path.Dir(tmpBasePath)) |
||||
|
||||
var stderr string |
||||
if _, stderr, err = process.ExecTimeout(5*time.Minute, |
||||
fmt.Sprintf("PullRequest.Merge(git clone): %s", tmpBasePath), |
||||
"git", "clone", baseGitRepo.Path, tmpBasePath); err != nil { |
||||
return fmt.Errorf("git clone: %s", stderr) |
||||
} |
||||
|
||||
// Check out base branch.
|
||||
if _, stderr, err = process.ExecDir(-1, tmpBasePath, |
||||
fmt.Sprintf("PullRequest.Merge(git checkout): %s", tmpBasePath), |
||||
"git", "checkout", pr.BaseBranch); err != nil { |
||||
return fmt.Errorf("git checkout: %s", stderr) |
||||
} |
||||
|
||||
// Pull commits.
|
||||
if _, stderr, err = process.ExecDir(-1, tmpBasePath, |
||||
fmt.Sprintf("PullRequest.Merge(git pull): %s", tmpBasePath), |
||||
"git", "pull", headRepoPath, pr.HeadBranch); err != nil { |
||||
return fmt.Errorf("git pull[%s / %s -> %s]: %s", headRepoPath, pr.HeadBranch, tmpBasePath, stderr) |
||||
} |
||||
|
||||
// Push back to upstream.
|
||||
if _, stderr, err = process.ExecDir(-1, tmpBasePath, |
||||
fmt.Sprintf("PullRequest.Merge(git push): %s", tmpBasePath), |
||||
"git", "push", baseGitRepo.Path, pr.BaseBranch); err != nil { |
||||
return fmt.Errorf("git push: %s", stderr) |
||||
} |
||||
|
||||
return sess.Commit() |
||||
} |
||||
|
||||
// NewPullRequest creates new pull request with labels for repository.
|
||||
func NewPullRequest(repo *Repository, pull *Issue, labelIDs []int64, uuids []string, pr *PullRequest, patch []byte) (err error) { |
||||
sess := x.NewSession() |
||||
defer sessionRelease(sess) |
||||
if err = sess.Begin(); err != nil { |
||||
return err |
||||
} |
||||
|
||||
if err = newIssue(sess, repo, pull, labelIDs, uuids, true); err != nil { |
||||
return fmt.Errorf("newIssue: %v", err) |
||||
} |
||||
|
||||
// Notify watchers.
|
||||
act := &Action{ |
||||
ActUserID: pull.Poster.Id, |
||||
ActUserName: pull.Poster.Name, |
||||
ActEmail: pull.Poster.Email, |
||||
OpType: CREATE_PULL_REQUEST, |
||||
Content: fmt.Sprintf("%d|%s", pull.Index, pull.Name), |
||||
RepoID: repo.ID, |
||||
RepoUserName: repo.Owner.Name, |
||||
RepoName: repo.Name, |
||||
IsPrivate: repo.IsPrivate, |
||||
} |
||||
if err = notifyWatchers(sess, act); err != nil { |
||||
return err |
||||
} |
||||
|
||||
// Test apply patch.
|
||||
if err = repo.UpdateLocalCopy(); err != nil { |
||||
return fmt.Errorf("UpdateLocalCopy: %v", err) |
||||
} |
||||
|
||||
repoPath, err := repo.RepoPath() |
||||
if err != nil { |
||||
return fmt.Errorf("RepoPath: %v", err) |
||||
} |
||||
patchPath := path.Join(repoPath, "pulls", com.ToStr(pull.ID)+".patch") |
||||
|
||||
os.MkdirAll(path.Dir(patchPath), os.ModePerm) |
||||
if err = ioutil.WriteFile(patchPath, patch, 0644); err != nil { |
||||
return fmt.Errorf("save patch: %v", err) |
||||
} |
||||
|
||||
pr.Status = PULL_REQUEST_STATUS_MERGEABLE |
||||
_, stderr, err := process.ExecDir(-1, repo.LocalCopyPath(), |
||||
fmt.Sprintf("NewPullRequest(git apply --check): %d", repo.ID), |
||||
"git", "apply", "--check", patchPath) |
||||
if err != nil { |
||||
if strings.Contains(stderr, "patch does not apply") { |
||||
pr.Status = PULL_REQUEST_STATUS_CONFLICT |
||||
} else { |
||||
return fmt.Errorf("git apply --check: %v - %s", err, stderr) |
||||
} |
||||
} |
||||
|
||||
pr.IssueID = pull.ID |
||||
pr.Index = pull.Index |
||||
if _, err = sess.Insert(pr); err != nil { |
||||
return fmt.Errorf("insert pull repo: %v", err) |
||||
} |
||||
|
||||
return sess.Commit() |
||||
} |
||||
|
||||
// GetUnmergedPullRequest returnss a pull request that is open and has not been merged
|
||||
// by given head/base and repo/branch.
|
||||
func GetUnmergedPullRequest(headRepoID, baseRepoID int64, headBranch, baseBranch string) (*PullRequest, error) { |
||||
pr := new(PullRequest) |
||||
|
||||
has, err := x.Where("head_repo_id=? AND head_branch=? AND base_repo_id=? AND base_branch=? AND has_merged=? AND issue.is_closed=?", |
||||
headRepoID, headBranch, baseRepoID, baseBranch, false, false). |
||||
Join("INNER", "issue", "issue.id=pull_request.issue_id").Get(pr) |
||||
if err != nil { |
||||
return nil, err |
||||
} else if !has { |
||||
return nil, ErrPullRequestNotExist{0, 0, headRepoID, baseRepoID, headBranch, baseBranch} |
||||
} |
||||
|
||||
return pr, nil |
||||
} |
||||
|
||||
// GetPullRequestByIssueID returns pull request by given issue ID.
|
||||
func GetPullRequestByIssueID(pullID int64) (*PullRequest, error) { |
||||
pr := new(PullRequest) |
||||
has, err := x.Where("pull_id=?", pullID).Get(pr) |
||||
if err != nil { |
||||
return nil, err |
||||
} else if !has { |
||||
return nil, ErrPullRequestNotExist{0, pullID, 0, 0, "", ""} |
||||
} |
||||
return pr, nil |
||||
} |
File diff suppressed because one or more lines are too long
@ -0,0 +1,615 @@
|
||||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
/* |
||||
Package agent implements a client to an ssh-agent daemon. |
||||
|
||||
References: |
||||
[PROTOCOL.agent]: http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.agent?rev=HEAD
|
||||
*/ |
||||
package agent |
||||
|
||||
import ( |
||||
"bytes" |
||||
"crypto/dsa" |
||||
"crypto/ecdsa" |
||||
"crypto/elliptic" |
||||
"crypto/rsa" |
||||
"encoding/base64" |
||||
"encoding/binary" |
||||
"errors" |
||||
"fmt" |
||||
"io" |
||||
"math/big" |
||||
"sync" |
||||
|
||||
"github.com/gogits/gogs/modules/crypto/ssh" |
||||
) |
||||
|
||||
// Agent represents the capabilities of an ssh-agent.
|
||||
type Agent interface { |
||||
// List returns the identities known to the agent.
|
||||
List() ([]*Key, error) |
||||
|
||||
// Sign has the agent sign the data using a protocol 2 key as defined
|
||||
// in [PROTOCOL.agent] section 2.6.2.
|
||||
Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) |
||||
|
||||
// Add adds a private key to the agent.
|
||||
Add(key AddedKey) error |
||||
|
||||
// Remove removes all identities with the given public key.
|
||||
Remove(key ssh.PublicKey) error |
||||
|
||||
// RemoveAll removes all identities.
|
||||
RemoveAll() error |
||||
|
||||
// Lock locks the agent. Sign and Remove will fail, and List will empty an empty list.
|
||||
Lock(passphrase []byte) error |
||||
|
||||
// Unlock undoes the effect of Lock
|
||||
Unlock(passphrase []byte) error |
||||
|
||||
// Signers returns signers for all the known keys.
|
||||
Signers() ([]ssh.Signer, error) |
||||
} |
||||
|
||||
// AddedKey describes an SSH key to be added to an Agent.
|
||||
type AddedKey struct { |
||||
// PrivateKey must be a *rsa.PrivateKey, *dsa.PrivateKey or
|
||||
// *ecdsa.PrivateKey, which will be inserted into the agent.
|
||||
PrivateKey interface{} |
||||
// Certificate, if not nil, is communicated to the agent and will be
|
||||
// stored with the key.
|
||||
Certificate *ssh.Certificate |
||||
// Comment is an optional, free-form string.
|
||||
Comment string |
||||
// LifetimeSecs, if not zero, is the number of seconds that the
|
||||
// agent will store the key for.
|
||||
LifetimeSecs uint32 |
||||
// ConfirmBeforeUse, if true, requests that the agent confirm with the
|
||||
// user before each use of this key.
|
||||
ConfirmBeforeUse bool |
||||
} |
||||
|
||||
// See [PROTOCOL.agent], section 3.
|
||||
const ( |
||||
agentRequestV1Identities = 1 |
||||
|
||||
// 3.2 Requests from client to agent for protocol 2 key operations
|
||||
agentAddIdentity = 17 |
||||
agentRemoveIdentity = 18 |
||||
agentRemoveAllIdentities = 19 |
||||
agentAddIdConstrained = 25 |
||||
|
||||
// 3.3 Key-type independent requests from client to agent
|
||||
agentAddSmartcardKey = 20 |
||||
agentRemoveSmartcardKey = 21 |
||||
agentLock = 22 |
||||
agentUnlock = 23 |
||||
agentAddSmartcardKeyConstrained = 26 |
||||
|
||||
// 3.7 Key constraint identifiers
|
||||
agentConstrainLifetime = 1 |
||||
agentConstrainConfirm = 2 |
||||
) |
||||
|
||||
// maxAgentResponseBytes is the maximum agent reply size that is accepted. This
|
||||
// is a sanity check, not a limit in the spec.
|
||||
const maxAgentResponseBytes = 16 << 20 |
||||
|
||||
// Agent messages:
|
||||
// These structures mirror the wire format of the corresponding ssh agent
|
||||
// messages found in [PROTOCOL.agent].
|
||||
|
||||
// 3.4 Generic replies from agent to client
|
||||
const agentFailure = 5 |
||||
|
||||
type failureAgentMsg struct{} |
||||
|
||||
const agentSuccess = 6 |
||||
|
||||
type successAgentMsg struct{} |
||||
|
||||
// See [PROTOCOL.agent], section 2.5.2.
|
||||
const agentRequestIdentities = 11 |
||||
|
||||
type requestIdentitiesAgentMsg struct{} |
||||
|
||||
// See [PROTOCOL.agent], section 2.5.2.
|
||||
const agentIdentitiesAnswer = 12 |
||||
|
||||
type identitiesAnswerAgentMsg struct { |
||||
NumKeys uint32 `sshtype:"12"` |
||||
Keys []byte `ssh:"rest"` |
||||
} |
||||
|
||||
// See [PROTOCOL.agent], section 2.6.2.
|
||||
const agentSignRequest = 13 |
||||
|
||||
type signRequestAgentMsg struct { |
||||
KeyBlob []byte `sshtype:"13"` |
||||
Data []byte |
||||
Flags uint32 |
||||
} |
||||
|
||||
// See [PROTOCOL.agent], section 2.6.2.
|
||||
|
||||
// 3.6 Replies from agent to client for protocol 2 key operations
|
||||
const agentSignResponse = 14 |
||||
|
||||
type signResponseAgentMsg struct { |
||||
SigBlob []byte `sshtype:"14"` |
||||
} |
||||
|
||||
type publicKey struct { |
||||
Format string |
||||
Rest []byte `ssh:"rest"` |
||||
} |
||||
|
||||
// Key represents a protocol 2 public key as defined in
|
||||
// [PROTOCOL.agent], section 2.5.2.
|
||||
type Key struct { |
||||
Format string |
||||
Blob []byte |
||||
Comment string |
||||
} |
||||
|
||||
func clientErr(err error) error { |
||||
return fmt.Errorf("agent: client error: %v", err) |
||||
} |
||||
|
||||
// String returns the storage form of an agent key with the format, base64
|
||||
// encoded serialized key, and the comment if it is not empty.
|
||||
func (k *Key) String() string { |
||||
s := string(k.Format) + " " + base64.StdEncoding.EncodeToString(k.Blob) |
||||
|
||||
if k.Comment != "" { |
||||
s += " " + k.Comment |
||||
} |
||||
|
||||
return s |
||||
} |
||||
|
||||
// Type returns the public key type.
|
||||
func (k *Key) Type() string { |
||||
return k.Format |
||||
} |
||||
|
||||
// Marshal returns key blob to satisfy the ssh.PublicKey interface.
|
||||
func (k *Key) Marshal() []byte { |
||||
return k.Blob |
||||
} |
||||
|
||||
// Verify satisfies the ssh.PublicKey interface, but is not
|
||||
// implemented for agent keys.
|
||||
func (k *Key) Verify(data []byte, sig *ssh.Signature) error { |
||||
return errors.New("agent: agent key does not know how to verify") |
||||
} |
||||
|
||||
type wireKey struct { |
||||
Format string |
||||
Rest []byte `ssh:"rest"` |
||||
} |
||||
|
||||
func parseKey(in []byte) (out *Key, rest []byte, err error) { |
||||
var record struct { |
||||
Blob []byte |
||||
Comment string |
||||
Rest []byte `ssh:"rest"` |
||||
} |
||||
|
||||
if err := ssh.Unmarshal(in, &record); err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
var wk wireKey |
||||
if err := ssh.Unmarshal(record.Blob, &wk); err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
return &Key{ |
||||
Format: wk.Format, |
||||
Blob: record.Blob, |
||||
Comment: record.Comment, |
||||
}, record.Rest, nil |
||||
} |
||||
|
||||
// client is a client for an ssh-agent process.
|
||||
type client struct { |
||||
// conn is typically a *net.UnixConn
|
||||
conn io.ReadWriter |
||||
// mu is used to prevent concurrent access to the agent
|
||||
mu sync.Mutex |
||||
} |
||||
|
||||
// NewClient returns an Agent that talks to an ssh-agent process over
|
||||
// the given connection.
|
||||
func NewClient(rw io.ReadWriter) Agent { |
||||
return &client{conn: rw} |
||||
} |
||||
|
||||
// call sends an RPC to the agent. On success, the reply is
|
||||
// unmarshaled into reply and replyType is set to the first byte of
|
||||
// the reply, which contains the type of the message.
|
||||
func (c *client) call(req []byte) (reply interface{}, err error) { |
||||
c.mu.Lock() |
||||
defer c.mu.Unlock() |
||||
|
||||
msg := make([]byte, 4+len(req)) |
||||
binary.BigEndian.PutUint32(msg, uint32(len(req))) |
||||
copy(msg[4:], req) |
||||
if _, err = c.conn.Write(msg); err != nil { |
||||
return nil, clientErr(err) |
||||
} |
||||
|
||||
var respSizeBuf [4]byte |
||||
if _, err = io.ReadFull(c.conn, respSizeBuf[:]); err != nil { |
||||
return nil, clientErr(err) |
||||
} |
||||
respSize := binary.BigEndian.Uint32(respSizeBuf[:]) |
||||
if respSize > maxAgentResponseBytes { |
||||
return nil, clientErr(err) |
||||
} |
||||
|
||||
buf := make([]byte, respSize) |
||||
if _, err = io.ReadFull(c.conn, buf); err != nil { |
||||
return nil, clientErr(err) |
||||
} |
||||
reply, err = unmarshal(buf) |
||||
if err != nil { |
||||
return nil, clientErr(err) |
||||
} |
||||
return reply, err |
||||
} |
||||
|
||||
func (c *client) simpleCall(req []byte) error { |
||||
resp, err := c.call(req) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if _, ok := resp.(*successAgentMsg); ok { |
||||
return nil |
||||
} |
||||
return errors.New("agent: failure") |
||||
} |
||||
|
||||
func (c *client) RemoveAll() error { |
||||
return c.simpleCall([]byte{agentRemoveAllIdentities}) |
||||
} |
||||
|
||||
func (c *client) Remove(key ssh.PublicKey) error { |
||||
req := ssh.Marshal(&agentRemoveIdentityMsg{ |
||||
KeyBlob: key.Marshal(), |
||||
}) |
||||
return c.simpleCall(req) |
||||
} |
||||
|
||||
func (c *client) Lock(passphrase []byte) error { |
||||
req := ssh.Marshal(&agentLockMsg{ |
||||
Passphrase: passphrase, |
||||
}) |
||||
return c.simpleCall(req) |
||||
} |
||||
|
||||
func (c *client) Unlock(passphrase []byte) error { |
||||
req := ssh.Marshal(&agentUnlockMsg{ |
||||
Passphrase: passphrase, |
||||
}) |
||||
return c.simpleCall(req) |
||||
} |
||||
|
||||
// List returns the identities known to the agent.
|
||||
func (c *client) List() ([]*Key, error) { |
||||
// see [PROTOCOL.agent] section 2.5.2.
|
||||
req := []byte{agentRequestIdentities} |
||||
|
||||
msg, err := c.call(req) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
switch msg := msg.(type) { |
||||
case *identitiesAnswerAgentMsg: |
||||
if msg.NumKeys > maxAgentResponseBytes/8 { |
||||
return nil, errors.New("agent: too many keys in agent reply") |
||||
} |
||||
keys := make([]*Key, msg.NumKeys) |
||||
data := msg.Keys |
||||
for i := uint32(0); i < msg.NumKeys; i++ { |
||||
var key *Key |
||||
var err error |
||||
if key, data, err = parseKey(data); err != nil { |
||||
return nil, err |
||||
} |
||||
keys[i] = key |
||||
} |
||||
return keys, nil |
||||
case *failureAgentMsg: |
||||
return nil, errors.New("agent: failed to list keys") |
||||
} |
||||
panic("unreachable") |
||||
} |
||||
|
||||
// Sign has the agent sign the data using a protocol 2 key as defined
|
||||
// in [PROTOCOL.agent] section 2.6.2.
|
||||
func (c *client) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) { |
||||
req := ssh.Marshal(signRequestAgentMsg{ |
||||
KeyBlob: key.Marshal(), |
||||
Data: data, |
||||
}) |
||||
|
||||
msg, err := c.call(req) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
switch msg := msg.(type) { |
||||
case *signResponseAgentMsg: |
||||
var sig ssh.Signature |
||||
if err := ssh.Unmarshal(msg.SigBlob, &sig); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return &sig, nil |
||||
case *failureAgentMsg: |
||||
return nil, errors.New("agent: failed to sign challenge") |
||||
} |
||||
panic("unreachable") |
||||
} |
||||
|
||||
// unmarshal parses an agent message in packet, returning the parsed
|
||||
// form and the message type of packet.
|
||||
func unmarshal(packet []byte) (interface{}, error) { |
||||
if len(packet) < 1 { |
||||
return nil, errors.New("agent: empty packet") |
||||
} |
||||
var msg interface{} |
||||
switch packet[0] { |
||||
case agentFailure: |
||||
return new(failureAgentMsg), nil |
||||
case agentSuccess: |
||||
return new(successAgentMsg), nil |
||||
case agentIdentitiesAnswer: |
||||
msg = new(identitiesAnswerAgentMsg) |
||||
case agentSignResponse: |
||||
msg = new(signResponseAgentMsg) |
||||
default: |
||||
return nil, fmt.Errorf("agent: unknown type tag %d", packet[0]) |
||||
} |
||||
if err := ssh.Unmarshal(packet, msg); err != nil { |
||||
return nil, err |
||||
} |
||||
return msg, nil |
||||
} |
||||
|
||||
type rsaKeyMsg struct { |
||||
Type string `sshtype:"17"` |
||||
N *big.Int |
||||
E *big.Int |
||||
D *big.Int |
||||
Iqmp *big.Int // IQMP = Inverse Q Mod P
|
||||
P *big.Int |
||||
Q *big.Int |
||||
Comments string |
||||
Constraints []byte `ssh:"rest"` |
||||
} |
||||
|
||||
type dsaKeyMsg struct { |
||||
Type string `sshtype:"17"` |
||||
P *big.Int |
||||
Q *big.Int |
||||
G *big.Int |
||||
Y *big.Int |
||||
X *big.Int |
||||
Comments string |
||||
Constraints []byte `ssh:"rest"` |
||||
} |
||||
|
||||
type ecdsaKeyMsg struct { |
||||
Type string `sshtype:"17"` |
||||
Curve string |
||||
KeyBytes []byte |
||||
D *big.Int |
||||
Comments string |
||||
Constraints []byte `ssh:"rest"` |
||||
} |
||||
|
||||
// Insert adds a private key to the agent.
|
||||
func (c *client) insertKey(s interface{}, comment string, constraints []byte) error { |
||||
var req []byte |
||||
switch k := s.(type) { |
||||
case *rsa.PrivateKey: |
||||
if len(k.Primes) != 2 { |
||||
return fmt.Errorf("agent: unsupported RSA key with %d primes", len(k.Primes)) |
||||
} |
||||
k.Precompute() |
||||
req = ssh.Marshal(rsaKeyMsg{ |
||||
Type: ssh.KeyAlgoRSA, |
||||
N: k.N, |
||||
E: big.NewInt(int64(k.E)), |
||||
D: k.D, |
||||
Iqmp: k.Precomputed.Qinv, |
||||
P: k.Primes[0], |
||||
Q: k.Primes[1], |
||||
Comments: comment, |
||||
Constraints: constraints, |
||||
}) |
||||
case *dsa.PrivateKey: |
||||
req = ssh.Marshal(dsaKeyMsg{ |
||||
Type: ssh.KeyAlgoDSA, |
||||
P: k.P, |
||||
Q: k.Q, |
||||
G: k.G, |
||||
Y: k.Y, |
||||
X: k.X, |
||||
Comments: comment, |
||||
Constraints: constraints, |
||||
}) |
||||
case *ecdsa.PrivateKey: |
||||
nistID := fmt.Sprintf("nistp%d", k.Params().BitSize) |
||||
req = ssh.Marshal(ecdsaKeyMsg{ |
||||
Type: "ecdsa-sha2-" + nistID, |
||||
Curve: nistID, |
||||
KeyBytes: elliptic.Marshal(k.Curve, k.X, k.Y), |
||||
D: k.D, |
||||
Comments: comment, |
||||
Constraints: constraints, |
||||
}) |
||||
default: |
||||
return fmt.Errorf("agent: unsupported key type %T", s) |
||||
} |
||||
|
||||
// if constraints are present then the message type needs to be changed.
|
||||
if len(constraints) != 0 { |
||||
req[0] = agentAddIdConstrained |
||||
} |
||||
|
||||
resp, err := c.call(req) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if _, ok := resp.(*successAgentMsg); ok { |
||||
return nil |
||||
} |
||||
return errors.New("agent: failure") |
||||
} |
||||
|
||||
type rsaCertMsg struct { |
||||
Type string `sshtype:"17"` |
||||
CertBytes []byte |
||||
D *big.Int |
||||
Iqmp *big.Int // IQMP = Inverse Q Mod P
|
||||
P *big.Int |
||||
Q *big.Int |
||||
Comments string |
||||
Constraints []byte `ssh:"rest"` |
||||
} |
||||
|
||||
type dsaCertMsg struct { |
||||
Type string `sshtype:"17"` |
||||
CertBytes []byte |
||||
X *big.Int |
||||
Comments string |
||||
Constraints []byte `ssh:"rest"` |
||||
} |
||||
|
||||
type ecdsaCertMsg struct { |
||||
Type string `sshtype:"17"` |
||||
CertBytes []byte |
||||
D *big.Int |
||||
Comments string |
||||
Constraints []byte `ssh:"rest"` |
||||
} |
||||
|
||||
// Insert adds a private key to the agent. If a certificate is given,
|
||||
// that certificate is added instead as public key.
|
||||
func (c *client) Add(key AddedKey) error { |
||||
var constraints []byte |
||||
|
||||
if secs := key.LifetimeSecs; secs != 0 { |
||||
constraints = append(constraints, agentConstrainLifetime) |
||||
|
||||
var secsBytes [4]byte |
||||
binary.BigEndian.PutUint32(secsBytes[:], secs) |
||||
constraints = append(constraints, secsBytes[:]...) |
||||
} |
||||
|
||||
if key.ConfirmBeforeUse { |
||||
constraints = append(constraints, agentConstrainConfirm) |
||||
} |
||||
|
||||
if cert := key.Certificate; cert == nil { |
||||
return c.insertKey(key.PrivateKey, key.Comment, constraints) |
||||
} else { |
||||
return c.insertCert(key.PrivateKey, cert, key.Comment, constraints) |
||||
} |
||||
} |
||||
|
||||
func (c *client) insertCert(s interface{}, cert *ssh.Certificate, comment string, constraints []byte) error { |
||||
var req []byte |
||||
switch k := s.(type) { |
||||
case *rsa.PrivateKey: |
||||
if len(k.Primes) != 2 { |
||||
return fmt.Errorf("agent: unsupported RSA key with %d primes", len(k.Primes)) |
||||
} |
||||
k.Precompute() |
||||
req = ssh.Marshal(rsaCertMsg{ |
||||
Type: cert.Type(), |
||||
CertBytes: cert.Marshal(), |
||||
D: k.D, |
||||
Iqmp: k.Precomputed.Qinv, |
||||
P: k.Primes[0], |
||||
Q: k.Primes[1], |
||||
Comments: comment, |
||||
Constraints: constraints, |
||||
}) |
||||
case *dsa.PrivateKey: |
||||
req = ssh.Marshal(dsaCertMsg{ |
||||
Type: cert.Type(), |
||||
CertBytes: cert.Marshal(), |
||||
X: k.X, |
||||
Comments: comment, |
||||
}) |
||||
case *ecdsa.PrivateKey: |
||||
req = ssh.Marshal(ecdsaCertMsg{ |
||||
Type: cert.Type(), |
||||
CertBytes: cert.Marshal(), |
||||
D: k.D, |
||||
Comments: comment, |
||||
}) |
||||
default: |
||||
return fmt.Errorf("agent: unsupported key type %T", s) |
||||
} |
||||
|
||||
// if constraints are present then the message type needs to be changed.
|
||||
if len(constraints) != 0 { |
||||
req[0] = agentAddIdConstrained |
||||
} |
||||
|
||||
signer, err := ssh.NewSignerFromKey(s) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if bytes.Compare(cert.Key.Marshal(), signer.PublicKey().Marshal()) != 0 { |
||||
return errors.New("agent: signer and cert have different public key") |
||||
} |
||||
|
||||
resp, err := c.call(req) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if _, ok := resp.(*successAgentMsg); ok { |
||||
return nil |
||||
} |
||||
return errors.New("agent: failure") |
||||
} |
||||
|
||||
// Signers provides a callback for client authentication.
|
||||
func (c *client) Signers() ([]ssh.Signer, error) { |
||||
keys, err := c.List() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var result []ssh.Signer |
||||
for _, k := range keys { |
||||
result = append(result, &agentKeyringSigner{c, k}) |
||||
} |
||||
return result, nil |
||||
} |
||||
|
||||
type agentKeyringSigner struct { |
||||
agent *client |
||||
pub ssh.PublicKey |
||||
} |
||||
|
||||
func (s *agentKeyringSigner) PublicKey() ssh.PublicKey { |
||||
return s.pub |
||||
} |
||||
|
||||
func (s *agentKeyringSigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) { |
||||
// The agent has its own entropy source, so the rand argument is ignored.
|
||||
return s.agent.Sign(s.pub, data) |
||||
} |
@ -0,0 +1,287 @@
|
||||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package agent |
||||
|
||||
import ( |
||||
"bytes" |
||||
"crypto/rand" |
||||
"errors" |
||||
"net" |
||||
"os" |
||||
"os/exec" |
||||
"path/filepath" |
||||
"strconv" |
||||
"testing" |
||||
|
||||
"github.com/gogits/gogs/modules/crypto/ssh" |
||||
) |
||||
|
||||
// startAgent executes ssh-agent, and returns a Agent interface to it.
|
||||
func startAgent(t *testing.T) (client Agent, socket string, cleanup func()) { |
||||
if testing.Short() { |
||||
// ssh-agent is not always available, and the key
|
||||
// types supported vary by platform.
|
||||
t.Skip("skipping test due to -short") |
||||
} |
||||
|
||||
bin, err := exec.LookPath("ssh-agent") |
||||
if err != nil { |
||||
t.Skip("could not find ssh-agent") |
||||
} |
||||
|
||||
cmd := exec.Command(bin, "-s") |
||||
out, err := cmd.Output() |
||||
if err != nil { |
||||
t.Fatalf("cmd.Output: %v", err) |
||||
} |
||||
|
||||
/* Output looks like: |
||||
|
||||
SSH_AUTH_SOCK=/tmp/ssh-P65gpcqArqvH/agent.15541; export SSH_AUTH_SOCK; |
||||
SSH_AGENT_PID=15542; export SSH_AGENT_PID; |
||||
echo Agent pid 15542; |
||||
*/ |
||||
fields := bytes.Split(out, []byte(";")) |
||||
line := bytes.SplitN(fields[0], []byte("="), 2) |
||||
line[0] = bytes.TrimLeft(line[0], "\n") |
||||
if string(line[0]) != "SSH_AUTH_SOCK" { |
||||
t.Fatalf("could not find key SSH_AUTH_SOCK in %q", fields[0]) |
||||
} |
||||
socket = string(line[1]) |
||||
|
||||
line = bytes.SplitN(fields[2], []byte("="), 2) |
||||
line[0] = bytes.TrimLeft(line[0], "\n") |
||||
if string(line[0]) != "SSH_AGENT_PID" { |
||||
t.Fatalf("could not find key SSH_AGENT_PID in %q", fields[2]) |
||||
} |
||||
pidStr := line[1] |
||||
pid, err := strconv.Atoi(string(pidStr)) |
||||
if err != nil { |
||||
t.Fatalf("Atoi(%q): %v", pidStr, err) |
||||
} |
||||
|
||||
conn, err := net.Dial("unix", string(socket)) |
||||
if err != nil { |
||||
t.Fatalf("net.Dial: %v", err) |
||||
} |
||||
|
||||
ac := NewClient(conn) |
||||
return ac, socket, func() { |
||||
proc, _ := os.FindProcess(pid) |
||||
if proc != nil { |
||||
proc.Kill() |
||||
} |
||||
conn.Close() |
||||
os.RemoveAll(filepath.Dir(socket)) |
||||
} |
||||
} |
||||
|
||||
func testAgent(t *testing.T, key interface{}, cert *ssh.Certificate, lifetimeSecs uint32) { |
||||
agent, _, cleanup := startAgent(t) |
||||
defer cleanup() |
||||
|
||||
testAgentInterface(t, agent, key, cert, lifetimeSecs) |
||||
} |
||||
|
||||
func testAgentInterface(t *testing.T, agent Agent, key interface{}, cert *ssh.Certificate, lifetimeSecs uint32) { |
||||
signer, err := ssh.NewSignerFromKey(key) |
||||
if err != nil { |
||||
t.Fatalf("NewSignerFromKey(%T): %v", key, err) |
||||
} |
||||
// The agent should start up empty.
|
||||
if keys, err := agent.List(); err != nil { |
||||
t.Fatalf("RequestIdentities: %v", err) |
||||
} else if len(keys) > 0 { |
||||
t.Fatalf("got %d keys, want 0: %v", len(keys), keys) |
||||
} |
||||
|
||||
// Attempt to insert the key, with certificate if specified.
|
||||
var pubKey ssh.PublicKey |
||||
if cert != nil { |
||||
err = agent.Add(AddedKey{ |
||||
PrivateKey: key, |
||||
Certificate: cert, |
||||
Comment: "comment", |
||||
LifetimeSecs: lifetimeSecs, |
||||
}) |
||||
pubKey = cert |
||||
} else { |
||||
err = agent.Add(AddedKey{PrivateKey: key, Comment: "comment", LifetimeSecs: lifetimeSecs}) |
||||
pubKey = signer.PublicKey() |
||||
} |
||||
if err != nil { |
||||
t.Fatalf("insert(%T): %v", key, err) |
||||
} |
||||
|
||||
// Did the key get inserted successfully?
|
||||
if keys, err := agent.List(); err != nil { |
||||
t.Fatalf("List: %v", err) |
||||
} else if len(keys) != 1 { |
||||
t.Fatalf("got %v, want 1 key", keys) |
||||
} else if keys[0].Comment != "comment" { |
||||
t.Fatalf("key comment: got %v, want %v", keys[0].Comment, "comment") |
||||
} else if !bytes.Equal(keys[0].Blob, pubKey.Marshal()) { |
||||
t.Fatalf("key mismatch") |
||||
} |
||||
|
||||
// Can the agent make a valid signature?
|
||||
data := []byte("hello") |
||||
sig, err := agent.Sign(pubKey, data) |
||||
if err != nil { |
||||
t.Fatalf("Sign(%s): %v", pubKey.Type(), err) |
||||
} |
||||
|
||||
if err := pubKey.Verify(data, sig); err != nil { |
||||
t.Fatalf("Verify(%s): %v", pubKey.Type(), err) |
||||
} |
||||
} |
||||
|
||||
func TestAgent(t *testing.T) { |
||||
for _, keyType := range []string{"rsa", "dsa", "ecdsa"} { |
||||
testAgent(t, testPrivateKeys[keyType], nil, 0) |
||||
} |
||||
} |
||||
|
||||
func TestCert(t *testing.T) { |
||||
cert := &ssh.Certificate{ |
||||
Key: testPublicKeys["rsa"], |
||||
ValidBefore: ssh.CertTimeInfinity, |
||||
CertType: ssh.UserCert, |
||||
} |
||||
cert.SignCert(rand.Reader, testSigners["ecdsa"]) |
||||
|
||||
testAgent(t, testPrivateKeys["rsa"], cert, 0) |
||||
} |
||||
|
||||
func TestConstraints(t *testing.T) { |
||||
testAgent(t, testPrivateKeys["rsa"], nil, 3600 /* lifetime in seconds */) |
||||
} |
||||
|
||||
// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
|
||||
// therefore is buffered (net.Pipe deadlocks if both sides start with
|
||||
// a write.)
|
||||
func netPipe() (net.Conn, net.Conn, error) { |
||||
listener, err := net.Listen("tcp", "127.0.0.1:0") |
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
defer listener.Close() |
||||
c1, err := net.Dial("tcp", listener.Addr().String()) |
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
c2, err := listener.Accept() |
||||
if err != nil { |
||||
c1.Close() |
||||
return nil, nil, err |
||||
} |
||||
|
||||
return c1, c2, nil |
||||
} |
||||
|
||||
func TestAuth(t *testing.T) { |
||||
a, b, err := netPipe() |
||||
if err != nil { |
||||
t.Fatalf("netPipe: %v", err) |
||||
} |
||||
|
||||
defer a.Close() |
||||
defer b.Close() |
||||
|
||||
agent, _, cleanup := startAgent(t) |
||||
defer cleanup() |
||||
|
||||
if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["rsa"], Comment: "comment"}); err != nil { |
||||
t.Errorf("Add: %v", err) |
||||
} |
||||
|
||||
serverConf := ssh.ServerConfig{} |
||||
serverConf.AddHostKey(testSigners["rsa"]) |
||||
serverConf.PublicKeyCallback = func(c ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { |
||||
if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) { |
||||
return nil, nil |
||||
} |
||||
|
||||
return nil, errors.New("pubkey rejected") |
||||
} |
||||
|
||||
go func() { |
||||
conn, _, _, err := ssh.NewServerConn(a, &serverConf) |
||||
if err != nil { |
||||
t.Fatalf("Server: %v", err) |
||||
} |
||||
conn.Close() |
||||
}() |
||||
|
||||
conf := ssh.ClientConfig{} |
||||
conf.Auth = append(conf.Auth, ssh.PublicKeysCallback(agent.Signers)) |
||||
conn, _, _, err := ssh.NewClientConn(b, "", &conf) |
||||
if err != nil { |
||||
t.Fatalf("NewClientConn: %v", err) |
||||
} |
||||
conn.Close() |
||||
} |
||||
|
||||
func TestLockClient(t *testing.T) { |
||||
agent, _, cleanup := startAgent(t) |
||||
defer cleanup() |
||||
testLockAgent(agent, t) |
||||
} |
||||
|
||||
func testLockAgent(agent Agent, t *testing.T) { |
||||
if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["rsa"], Comment: "comment 1"}); err != nil { |
||||
t.Errorf("Add: %v", err) |
||||
} |
||||
if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["dsa"], Comment: "comment dsa"}); err != nil { |
||||
t.Errorf("Add: %v", err) |
||||
} |
||||
if keys, err := agent.List(); err != nil { |
||||
t.Errorf("List: %v", err) |
||||
} else if len(keys) != 2 { |
||||
t.Errorf("Want 2 keys, got %v", keys) |
||||
} |
||||
|
||||
passphrase := []byte("secret") |
||||
if err := agent.Lock(passphrase); err != nil { |
||||
t.Errorf("Lock: %v", err) |
||||
} |
||||
|
||||
if keys, err := agent.List(); err != nil { |
||||
t.Errorf("List: %v", err) |
||||
} else if len(keys) != 0 { |
||||
t.Errorf("Want 0 keys, got %v", keys) |
||||
} |
||||
|
||||
signer, _ := ssh.NewSignerFromKey(testPrivateKeys["rsa"]) |
||||
if _, err := agent.Sign(signer.PublicKey(), []byte("hello")); err == nil { |
||||
t.Fatalf("Sign did not fail") |
||||
} |
||||
|
||||
if err := agent.Remove(signer.PublicKey()); err == nil { |
||||
t.Fatalf("Remove did not fail") |
||||
} |
||||
|
||||
if err := agent.RemoveAll(); err == nil { |
||||
t.Fatalf("RemoveAll did not fail") |
||||
} |
||||
|
||||
if err := agent.Unlock(nil); err == nil { |
||||
t.Errorf("Unlock with wrong passphrase succeeded") |
||||
} |
||||
if err := agent.Unlock(passphrase); err != nil { |
||||
t.Errorf("Unlock: %v", err) |
||||
} |
||||
|
||||
if err := agent.Remove(signer.PublicKey()); err != nil { |
||||
t.Fatalf("Remove: %v", err) |
||||
} |
||||
|
||||
if keys, err := agent.List(); err != nil { |
||||
t.Errorf("List: %v", err) |
||||
} else if len(keys) != 1 { |
||||
t.Errorf("Want 1 keys, got %v", keys) |
||||
} |
||||
} |
@ -0,0 +1,103 @@
|
||||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package agent |
||||
|
||||
import ( |
||||
"errors" |
||||
"io" |
||||
"net" |
||||
"sync" |
||||
|
||||
"github.com/gogits/gogs/modules/crypto/ssh" |
||||
) |
||||
|
||||
// RequestAgentForwarding sets up agent forwarding for the session.
|
||||
// ForwardToAgent or ForwardToRemote should be called to route
|
||||
// the authentication requests.
|
||||
func RequestAgentForwarding(session *ssh.Session) error { |
||||
ok, err := session.SendRequest("auth-agent-req@openssh.com", true, nil) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if !ok { |
||||
return errors.New("forwarding request denied") |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// ForwardToAgent routes authentication requests to the given keyring.
|
||||
func ForwardToAgent(client *ssh.Client, keyring Agent) error { |
||||
channels := client.HandleChannelOpen(channelType) |
||||
if channels == nil { |
||||
return errors.New("agent: already have handler for " + channelType) |
||||
} |
||||
|
||||
go func() { |
||||
for ch := range channels { |
||||
channel, reqs, err := ch.Accept() |
||||
if err != nil { |
||||
continue |
||||
} |
||||
go ssh.DiscardRequests(reqs) |
||||
go func() { |
||||
ServeAgent(keyring, channel) |
||||
channel.Close() |
||||
}() |
||||
} |
||||
}() |
||||
return nil |
||||
} |
||||
|
||||
const channelType = "auth-agent@openssh.com" |
||||
|
||||
// ForwardToRemote routes authentication requests to the ssh-agent
|
||||
// process serving on the given unix socket.
|
||||
func ForwardToRemote(client *ssh.Client, addr string) error { |
||||
channels := client.HandleChannelOpen(channelType) |
||||
if channels == nil { |
||||
return errors.New("agent: already have handler for " + channelType) |
||||
} |
||||
conn, err := net.Dial("unix", addr) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
conn.Close() |
||||
|
||||
go func() { |
||||
for ch := range channels { |
||||
channel, reqs, err := ch.Accept() |
||||
if err != nil { |
||||
continue |
||||
} |
||||
go ssh.DiscardRequests(reqs) |
||||
go forwardUnixSocket(channel, addr) |
||||
} |
||||
}() |
||||
return nil |
||||
} |
||||
|
||||
func forwardUnixSocket(channel ssh.Channel, addr string) { |
||||
conn, err := net.Dial("unix", addr) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
var wg sync.WaitGroup |
||||
wg.Add(2) |
||||
go func() { |
||||
io.Copy(conn, channel) |
||||
conn.(*net.UnixConn).CloseWrite() |
||||
wg.Done() |
||||
}() |
||||
go func() { |
||||
io.Copy(channel, conn) |
||||
channel.CloseWrite() |
||||
wg.Done() |
||||
}() |
||||
|
||||
wg.Wait() |
||||
conn.Close() |
||||
channel.Close() |
||||
} |
@ -0,0 +1,184 @@
|
||||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package agent |
||||
|
||||
import ( |
||||
"bytes" |
||||
"crypto/rand" |
||||
"crypto/subtle" |
||||
"errors" |
||||
"fmt" |
||||
"sync" |
||||
|
||||
"github.com/gogits/gogs/modules/crypto/ssh" |
||||
) |
||||
|
||||
type privKey struct { |
||||
signer ssh.Signer |
||||
comment string |
||||
} |
||||
|
||||
type keyring struct { |
||||
mu sync.Mutex |
||||
keys []privKey |
||||
|
||||
locked bool |
||||
passphrase []byte |
||||
} |
||||
|
||||
var errLocked = errors.New("agent: locked") |
||||
|
||||
// NewKeyring returns an Agent that holds keys in memory. It is safe
|
||||
// for concurrent use by multiple goroutines.
|
||||
func NewKeyring() Agent { |
||||
return &keyring{} |
||||
} |
||||
|
||||
// RemoveAll removes all identities.
|
||||
func (r *keyring) RemoveAll() error { |
||||
r.mu.Lock() |
||||
defer r.mu.Unlock() |
||||
if r.locked { |
||||
return errLocked |
||||
} |
||||
|
||||
r.keys = nil |
||||
return nil |
||||
} |
||||
|
||||
// Remove removes all identities with the given public key.
|
||||
func (r *keyring) Remove(key ssh.PublicKey) error { |
||||
r.mu.Lock() |
||||
defer r.mu.Unlock() |
||||
if r.locked { |
||||
return errLocked |
||||
} |
||||
|
||||
want := key.Marshal() |
||||
found := false |
||||
for i := 0; i < len(r.keys); { |
||||
if bytes.Equal(r.keys[i].signer.PublicKey().Marshal(), want) { |
||||
found = true |
||||
r.keys[i] = r.keys[len(r.keys)-1] |
||||
r.keys = r.keys[len(r.keys)-1:] |
||||
continue |
||||
} else { |
||||
i++ |
||||
} |
||||
} |
||||
|
||||
if !found { |
||||
return errors.New("agent: key not found") |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// Lock locks the agent. Sign and Remove will fail, and List will empty an empty list.
|
||||
func (r *keyring) Lock(passphrase []byte) error { |
||||
r.mu.Lock() |
||||
defer r.mu.Unlock() |
||||
if r.locked { |
||||
return errLocked |
||||
} |
||||
|
||||
r.locked = true |
||||
r.passphrase = passphrase |
||||
return nil |
||||
} |
||||
|
||||
// Unlock undoes the effect of Lock
|
||||
func (r *keyring) Unlock(passphrase []byte) error { |
||||
r.mu.Lock() |
||||
defer r.mu.Unlock() |
||||
if !r.locked { |
||||
return errors.New("agent: not locked") |
||||
} |
||||
if len(passphrase) != len(r.passphrase) || 1 != subtle.ConstantTimeCompare(passphrase, r.passphrase) { |
||||
return fmt.Errorf("agent: incorrect passphrase") |
||||
} |
||||
|
||||
r.locked = false |
||||
r.passphrase = nil |
||||
return nil |
||||
} |
||||
|
||||
// List returns the identities known to the agent.
|
||||
func (r *keyring) List() ([]*Key, error) { |
||||
r.mu.Lock() |
||||
defer r.mu.Unlock() |
||||
if r.locked { |
||||
// section 2.7: locked agents return empty.
|
||||
return nil, nil |
||||
} |
||||
|
||||
var ids []*Key |
||||
for _, k := range r.keys { |
||||
pub := k.signer.PublicKey() |
||||
ids = append(ids, &Key{ |
||||
Format: pub.Type(), |
||||
Blob: pub.Marshal(), |
||||
Comment: k.comment}) |
||||
} |
||||
return ids, nil |
||||
} |
||||
|
||||
// Insert adds a private key to the keyring. If a certificate
|
||||
// is given, that certificate is added as public key. Note that
|
||||
// any constraints given are ignored.
|
||||
func (r *keyring) Add(key AddedKey) error { |
||||
r.mu.Lock() |
||||
defer r.mu.Unlock() |
||||
if r.locked { |
||||
return errLocked |
||||
} |
||||
signer, err := ssh.NewSignerFromKey(key.PrivateKey) |
||||
|
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
if cert := key.Certificate; cert != nil { |
||||
signer, err = ssh.NewCertSigner(cert, signer) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
} |
||||
|
||||
r.keys = append(r.keys, privKey{signer, key.Comment}) |
||||
|
||||
return nil |
||||
} |
||||
|
||||
// Sign returns a signature for the data.
|
||||
func (r *keyring) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) { |
||||
r.mu.Lock() |
||||
defer r.mu.Unlock() |
||||
if r.locked { |
||||
return nil, errLocked |
||||
} |
||||
|
||||
wanted := key.Marshal() |
||||
for _, k := range r.keys { |
||||
if bytes.Equal(k.signer.PublicKey().Marshal(), wanted) { |
||||
return k.signer.Sign(rand.Reader, data) |
||||
} |
||||
} |
||||
return nil, errors.New("not found") |
||||
} |
||||
|
||||
// Signers returns signers for all the known keys.
|
||||
func (r *keyring) Signers() ([]ssh.Signer, error) { |
||||
r.mu.Lock() |
||||
defer r.mu.Unlock() |
||||
if r.locked { |
||||
return nil, errLocked |
||||
} |
||||
|
||||
s := make([]ssh.Signer, 0, len(r.keys)) |
||||
for _, k := range r.keys { |
||||
s = append(s, k.signer) |
||||
} |
||||
return s, nil |
||||
} |
@ -0,0 +1,209 @@
|
||||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package agent |
||||
|
||||
import ( |
||||
"crypto/rsa" |
||||
"encoding/binary" |
||||
"fmt" |
||||
"io" |
||||
"log" |
||||
"math/big" |
||||
|
||||
"github.com/gogits/gogs/modules/crypto/ssh" |
||||
) |
||||
|
||||
// Server wraps an Agent and uses it to implement the agent side of
|
||||
// the SSH-agent, wire protocol.
|
||||
type server struct { |
||||
agent Agent |
||||
} |
||||
|
||||
func (s *server) processRequestBytes(reqData []byte) []byte { |
||||
rep, err := s.processRequest(reqData) |
||||
if err != nil { |
||||
if err != errLocked { |
||||
// TODO(hanwen): provide better logging interface?
|
||||
log.Printf("agent %d: %v", reqData[0], err) |
||||
} |
||||
return []byte{agentFailure} |
||||
} |
||||
|
||||
if err == nil && rep == nil { |
||||
return []byte{agentSuccess} |
||||
} |
||||
|
||||
return ssh.Marshal(rep) |
||||
} |
||||
|
||||
func marshalKey(k *Key) []byte { |
||||
var record struct { |
||||
Blob []byte |
||||
Comment string |
||||
} |
||||
record.Blob = k.Marshal() |
||||
record.Comment = k.Comment |
||||
|
||||
return ssh.Marshal(&record) |
||||
} |
||||
|
||||
type agentV1IdentityMsg struct { |
||||
Numkeys uint32 `sshtype:"2"` |
||||
} |
||||
|
||||
type agentRemoveIdentityMsg struct { |
||||
KeyBlob []byte `sshtype:"18"` |
||||
} |
||||
|
||||
type agentLockMsg struct { |
||||
Passphrase []byte `sshtype:"22"` |
||||
} |
||||
|
||||
type agentUnlockMsg struct { |
||||
Passphrase []byte `sshtype:"23"` |
||||
} |
||||
|
||||
func (s *server) processRequest(data []byte) (interface{}, error) { |
||||
switch data[0] { |
||||
case agentRequestV1Identities: |
||||
return &agentV1IdentityMsg{0}, nil |
||||
case agentRemoveIdentity: |
||||
var req agentRemoveIdentityMsg |
||||
if err := ssh.Unmarshal(data, &req); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var wk wireKey |
||||
if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return nil, s.agent.Remove(&Key{Format: wk.Format, Blob: req.KeyBlob}) |
||||
|
||||
case agentRemoveAllIdentities: |
||||
return nil, s.agent.RemoveAll() |
||||
|
||||
case agentLock: |
||||
var req agentLockMsg |
||||
if err := ssh.Unmarshal(data, &req); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return nil, s.agent.Lock(req.Passphrase) |
||||
|
||||
case agentUnlock: |
||||
var req agentLockMsg |
||||
if err := ssh.Unmarshal(data, &req); err != nil { |
||||
return nil, err |
||||
} |
||||
return nil, s.agent.Unlock(req.Passphrase) |
||||
|
||||
case agentSignRequest: |
||||
var req signRequestAgentMsg |
||||
if err := ssh.Unmarshal(data, &req); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var wk wireKey |
||||
if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
k := &Key{ |
||||
Format: wk.Format, |
||||
Blob: req.KeyBlob, |
||||
} |
||||
|
||||
sig, err := s.agent.Sign(k, req.Data) // TODO(hanwen): flags.
|
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return &signResponseAgentMsg{SigBlob: ssh.Marshal(sig)}, nil |
||||
case agentRequestIdentities: |
||||
keys, err := s.agent.List() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
rep := identitiesAnswerAgentMsg{ |
||||
NumKeys: uint32(len(keys)), |
||||
} |
||||
for _, k := range keys { |
||||
rep.Keys = append(rep.Keys, marshalKey(k)...) |
||||
} |
||||
return rep, nil |
||||
case agentAddIdentity: |
||||
return nil, s.insertIdentity(data) |
||||
} |
||||
|
||||
return nil, fmt.Errorf("unknown opcode %d", data[0]) |
||||
} |
||||
|
||||
func (s *server) insertIdentity(req []byte) error { |
||||
var record struct { |
||||
Type string `sshtype:"17"` |
||||
Rest []byte `ssh:"rest"` |
||||
} |
||||
if err := ssh.Unmarshal(req, &record); err != nil { |
||||
return err |
||||
} |
||||
|
||||
switch record.Type { |
||||
case ssh.KeyAlgoRSA: |
||||
var k rsaKeyMsg |
||||
if err := ssh.Unmarshal(req, &k); err != nil { |
||||
return err |
||||
} |
||||
|
||||
priv := rsa.PrivateKey{ |
||||
PublicKey: rsa.PublicKey{ |
||||
E: int(k.E.Int64()), |
||||
N: k.N, |
||||
}, |
||||
D: k.D, |
||||
Primes: []*big.Int{k.P, k.Q}, |
||||
} |
||||
priv.Precompute() |
||||
|
||||
return s.agent.Add(AddedKey{PrivateKey: &priv, Comment: k.Comments}) |
||||
} |
||||
return fmt.Errorf("not implemented: %s", record.Type) |
||||
} |
||||
|
||||
// ServeAgent serves the agent protocol on the given connection. It
|
||||
// returns when an I/O error occurs.
|
||||
func ServeAgent(agent Agent, c io.ReadWriter) error { |
||||
s := &server{agent} |
||||
|
||||
var length [4]byte |
||||
for { |
||||
if _, err := io.ReadFull(c, length[:]); err != nil { |
||||
return err |
||||
} |
||||
l := binary.BigEndian.Uint32(length[:]) |
||||
if l > maxAgentResponseBytes { |
||||
// We also cap requests.
|
||||
return fmt.Errorf("agent: request too large: %d", l) |
||||
} |
||||
|
||||
req := make([]byte, l) |
||||
if _, err := io.ReadFull(c, req); err != nil { |
||||
return err |
||||
} |
||||
|
||||
repData := s.processRequestBytes(req) |
||||
if len(repData) > maxAgentResponseBytes { |
||||
return fmt.Errorf("agent: reply too large: %d bytes", len(repData)) |
||||
} |
||||
|
||||
binary.BigEndian.PutUint32(length[:], uint32(len(repData))) |
||||
if _, err := c.Write(length[:]); err != nil { |
||||
return err |
||||
} |
||||
if _, err := c.Write(repData); err != nil { |
||||
return err |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,77 @@
|
||||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package agent |
||||
|
||||
import ( |
||||
"testing" |
||||
|
||||
"github.com/gogits/gogs/modules/crypto/ssh" |
||||
) |
||||
|
||||
func TestServer(t *testing.T) { |
||||
c1, c2, err := netPipe() |
||||
if err != nil { |
||||
t.Fatalf("netPipe: %v", err) |
||||
} |
||||
defer c1.Close() |
||||
defer c2.Close() |
||||
client := NewClient(c1) |
||||
|
||||
go ServeAgent(NewKeyring(), c2) |
||||
|
||||
testAgentInterface(t, client, testPrivateKeys["rsa"], nil, 0) |
||||
} |
||||
|
||||
func TestLockServer(t *testing.T) { |
||||
testLockAgent(NewKeyring(), t) |
||||
} |
||||
|
||||
func TestSetupForwardAgent(t *testing.T) { |
||||
a, b, err := netPipe() |
||||
if err != nil { |
||||
t.Fatalf("netPipe: %v", err) |
||||
} |
||||
|
||||
defer a.Close() |
||||
defer b.Close() |
||||
|
||||
_, socket, cleanup := startAgent(t) |
||||
defer cleanup() |
||||
|
||||
serverConf := ssh.ServerConfig{ |
||||
NoClientAuth: true, |
||||
} |
||||
serverConf.AddHostKey(testSigners["rsa"]) |
||||
incoming := make(chan *ssh.ServerConn, 1) |
||||
go func() { |
||||
conn, _, _, err := ssh.NewServerConn(a, &serverConf) |
||||
if err != nil { |
||||
t.Fatalf("Server: %v", err) |
||||
} |
||||
incoming <- conn |
||||
}() |
||||
|
||||
conf := ssh.ClientConfig{} |
||||
conn, chans, reqs, err := ssh.NewClientConn(b, "", &conf) |
||||
if err != nil { |
||||
t.Fatalf("NewClientConn: %v", err) |
||||
} |
||||
client := ssh.NewClient(conn, chans, reqs) |
||||
|
||||
if err := ForwardToRemote(client, socket); err != nil { |
||||
t.Fatalf("SetupForwardAgent: %v", err) |
||||
} |
||||
|
||||
server := <-incoming |
||||
ch, reqs, err := server.OpenChannel(channelType, nil) |
||||
if err != nil { |
||||
t.Fatalf("OpenChannel(%q): %v", channelType, err) |
||||
} |
||||
go ssh.DiscardRequests(reqs) |
||||
|
||||
agentClient := NewClient(ch) |
||||
testAgentInterface(t, agentClient, testPrivateKeys["rsa"], nil, 0) |
||||
conn.Close() |
||||
} |
@ -0,0 +1,64 @@
|
||||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// IMPLEMENTOR NOTE: To avoid a package loop, this file is in three places:
|
||||
// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three
|
||||
// instances.
|
||||
|
||||
package agent |
||||
|
||||
import ( |
||||
"crypto/rand" |
||||
"fmt" |
||||
|
||||
"github.com/gogits/gogs/modules/crypto/ssh" |
||||
"github.com/gogits/gogs/modules/crypto/ssh/testdata" |
||||
) |
||||
|
||||
var ( |
||||
testPrivateKeys map[string]interface{} |
||||
testSigners map[string]ssh.Signer |
||||
testPublicKeys map[string]ssh.PublicKey |
||||
) |
||||
|
||||
func init() { |
||||
var err error |
||||
|
||||
n := len(testdata.PEMBytes) |
||||
testPrivateKeys = make(map[string]interface{}, n) |
||||
testSigners = make(map[string]ssh.Signer, n) |
||||
testPublicKeys = make(map[string]ssh.PublicKey, n) |
||||
for t, k := range testdata.PEMBytes { |
||||
testPrivateKeys[t], err = ssh.ParseRawPrivateKey(k) |
||||
if err != nil { |
||||
panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err)) |
||||
} |
||||
testSigners[t], err = ssh.NewSignerFromKey(testPrivateKeys[t]) |
||||
if err != nil { |
||||
panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err)) |
||||
} |
||||
testPublicKeys[t] = testSigners[t].PublicKey() |
||||
} |
||||
|
||||
// Create a cert and sign it for use in tests.
|
||||
testCert := &ssh.Certificate{ |
||||
Nonce: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
|
||||
ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage
|
||||
ValidAfter: 0, // unix epoch
|
||||
ValidBefore: ssh.CertTimeInfinity, // The end of currently representable time.
|
||||
Reserved: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil
|
||||
Key: testPublicKeys["ecdsa"], |
||||
SignatureKey: testPublicKeys["rsa"], |
||||
Permissions: ssh.Permissions{ |
||||
CriticalOptions: map[string]string{}, |
||||
Extensions: map[string]string{}, |
||||
}, |
||||
} |
||||
testCert.SignCert(rand.Reader, testSigners["rsa"]) |
||||
testPrivateKeys["cert"] = testPrivateKeys["ecdsa"] |
||||
testSigners["cert"], err = ssh.NewCertSigner(testCert, testSigners["ecdsa"]) |
||||
if err != nil { |
||||
panic(fmt.Sprintf("Unable to create certificate signer: %v", err)) |
||||
} |
||||
} |
@ -0,0 +1,122 @@
|
||||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"errors" |
||||
"io" |
||||
"net" |
||||
"testing" |
||||
) |
||||
|
||||
type server struct { |
||||
*ServerConn |
||||
chans <-chan NewChannel |
||||
} |
||||
|
||||
func newServer(c net.Conn, conf *ServerConfig) (*server, error) { |
||||
sconn, chans, reqs, err := NewServerConn(c, conf) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
go DiscardRequests(reqs) |
||||
return &server{sconn, chans}, nil |
||||
} |
||||
|
||||
func (s *server) Accept() (NewChannel, error) { |
||||
n, ok := <-s.chans |
||||
if !ok { |
||||
return nil, io.EOF |
||||
} |
||||
return n, nil |
||||
} |
||||
|
||||
func sshPipe() (Conn, *server, error) { |
||||
c1, c2, err := netPipe() |
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
clientConf := ClientConfig{ |
||||
User: "user", |
||||
} |
||||
serverConf := ServerConfig{ |
||||
NoClientAuth: true, |
||||
} |
||||
serverConf.AddHostKey(testSigners["ecdsa"]) |
||||
done := make(chan *server, 1) |
||||
go func() { |
||||
server, err := newServer(c2, &serverConf) |
||||
if err != nil { |
||||
done <- nil |
||||
} |
||||
done <- server |
||||
}() |
||||
|
||||
client, _, reqs, err := NewClientConn(c1, "", &clientConf) |
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
server := <-done |
||||
if server == nil { |
||||
return nil, nil, errors.New("server handshake failed.") |
||||
} |
||||
go DiscardRequests(reqs) |
||||
|
||||
return client, server, nil |
||||
} |
||||
|
||||
func BenchmarkEndToEnd(b *testing.B) { |
||||
b.StopTimer() |
||||
|
||||
client, server, err := sshPipe() |
||||
if err != nil { |
||||
b.Fatalf("sshPipe: %v", err) |
||||
} |
||||
|
||||
defer client.Close() |
||||
defer server.Close() |
||||
|
||||
size := (1 << 20) |
||||
input := make([]byte, size) |
||||
output := make([]byte, size) |
||||
b.SetBytes(int64(size)) |
||||
done := make(chan int, 1) |
||||
|
||||
go func() { |
||||
newCh, err := server.Accept() |
||||
if err != nil { |
||||
b.Fatalf("Client: %v", err) |
||||
} |
||||
ch, incoming, err := newCh.Accept() |
||||
go DiscardRequests(incoming) |
||||
for i := 0; i < b.N; i++ { |
||||
if _, err := io.ReadFull(ch, output); err != nil { |
||||
b.Fatalf("ReadFull: %v", err) |
||||
} |
||||
} |
||||
ch.Close() |
||||
done <- 1 |
||||
}() |
||||
|
||||
ch, in, err := client.OpenChannel("speed", nil) |
||||
if err != nil { |
||||
b.Fatalf("OpenChannel: %v", err) |
||||
} |
||||
go DiscardRequests(in) |
||||
|
||||
b.ResetTimer() |
||||
b.StartTimer() |
||||
for i := 0; i < b.N; i++ { |
||||
if _, err := ch.Write(input); err != nil { |
||||
b.Fatalf("WriteFull: %v", err) |
||||
} |
||||
} |
||||
ch.Close() |
||||
b.StopTimer() |
||||
|
||||
<-done |
||||
} |
@ -0,0 +1,98 @@
|
||||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"io" |
||||
"sync" |
||||
) |
||||
|
||||
// buffer provides a linked list buffer for data exchange
|
||||
// between producer and consumer. Theoretically the buffer is
|
||||
// of unlimited capacity as it does no allocation of its own.
|
||||
type buffer struct { |
||||
// protects concurrent access to head, tail and closed
|
||||
*sync.Cond |
||||
|
||||
head *element // the buffer that will be read first
|
||||
tail *element // the buffer that will be read last
|
||||
|
||||
closed bool |
||||
} |
||||
|
||||
// An element represents a single link in a linked list.
|
||||
type element struct { |
||||
buf []byte |
||||
next *element |
||||
} |
||||
|
||||
// newBuffer returns an empty buffer that is not closed.
|
||||
func newBuffer() *buffer { |
||||
e := new(element) |
||||
b := &buffer{ |
||||
Cond: newCond(), |
||||
head: e, |
||||
tail: e, |
||||
} |
||||
return b |
||||
} |
||||
|
||||
// write makes buf available for Read to receive.
|
||||
// buf must not be modified after the call to write.
|
||||
func (b *buffer) write(buf []byte) { |
||||
b.Cond.L.Lock() |
||||
e := &element{buf: buf} |
||||
b.tail.next = e |
||||
b.tail = e |
||||
b.Cond.Signal() |
||||
b.Cond.L.Unlock() |
||||
} |
||||
|
||||
// eof closes the buffer. Reads from the buffer once all
|
||||
// the data has been consumed will receive os.EOF.
|
||||
func (b *buffer) eof() error { |
||||
b.Cond.L.Lock() |
||||
b.closed = true |
||||
b.Cond.Signal() |
||||
b.Cond.L.Unlock() |
||||
return nil |
||||
} |
||||
|
||||
// Read reads data from the internal buffer in buf. Reads will block
|
||||
// if no data is available, or until the buffer is closed.
|
||||
func (b *buffer) Read(buf []byte) (n int, err error) { |
||||
b.Cond.L.Lock() |
||||
defer b.Cond.L.Unlock() |
||||
|
||||
for len(buf) > 0 { |
||||
// if there is data in b.head, copy it
|
||||
if len(b.head.buf) > 0 { |
||||
r := copy(buf, b.head.buf) |
||||
buf, b.head.buf = buf[r:], b.head.buf[r:] |
||||
n += r |
||||
continue |
||||
} |
||||
// if there is a next buffer, make it the head
|
||||
if len(b.head.buf) == 0 && b.head != b.tail { |
||||
b.head = b.head.next |
||||
continue |
||||
} |
||||
|
||||
// if at least one byte has been copied, return
|
||||
if n > 0 { |
||||
break |
||||
} |
||||
|
||||
// if nothing was read, and there is nothing outstanding
|
||||
// check to see if the buffer is closed.
|
||||
if b.closed { |
||||
err = io.EOF |
||||
break |
||||
} |
||||
// out of buffers, wait for producer
|
||||
b.Cond.Wait() |
||||
} |
||||
return |
||||
} |
@ -0,0 +1,87 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"io" |
||||
"testing" |
||||
) |
||||
|
||||
var alphabet = []byte("abcdefghijklmnopqrstuvwxyz") |
||||
|
||||
func TestBufferReadwrite(t *testing.T) { |
||||
b := newBuffer() |
||||
b.write(alphabet[:10]) |
||||
r, _ := b.Read(make([]byte, 10)) |
||||
if r != 10 { |
||||
t.Fatalf("Expected written == read == 10, written: 10, read %d", r) |
||||
} |
||||
|
||||
b = newBuffer() |
||||
b.write(alphabet[:5]) |
||||
r, _ = b.Read(make([]byte, 10)) |
||||
if r != 5 { |
||||
t.Fatalf("Expected written == read == 5, written: 5, read %d", r) |
||||
} |
||||
|
||||
b = newBuffer() |
||||
b.write(alphabet[:10]) |
||||
r, _ = b.Read(make([]byte, 5)) |
||||
if r != 5 { |
||||
t.Fatalf("Expected written == 10, read == 5, written: 10, read %d", r) |
||||
} |
||||
|
||||
b = newBuffer() |
||||
b.write(alphabet[:5]) |
||||
b.write(alphabet[5:15]) |
||||
r, _ = b.Read(make([]byte, 10)) |
||||
r2, _ := b.Read(make([]byte, 10)) |
||||
if r != 10 || r2 != 5 || 15 != r+r2 { |
||||
t.Fatal("Expected written == read == 15") |
||||
} |
||||
} |
||||
|
||||
func TestBufferClose(t *testing.T) { |
||||
b := newBuffer() |
||||
b.write(alphabet[:10]) |
||||
b.eof() |
||||
_, err := b.Read(make([]byte, 5)) |
||||
if err != nil { |
||||
t.Fatal("expected read of 5 to not return EOF") |
||||
} |
||||
b = newBuffer() |
||||
b.write(alphabet[:10]) |
||||
b.eof() |
||||
r, err := b.Read(make([]byte, 5)) |
||||
r2, err2 := b.Read(make([]byte, 10)) |
||||
if r != 5 || r2 != 5 || err != nil || err2 != nil { |
||||
t.Fatal("expected reads of 5 and 5") |
||||
} |
||||
|
||||
b = newBuffer() |
||||
b.write(alphabet[:10]) |
||||
b.eof() |
||||
r, err = b.Read(make([]byte, 5)) |
||||
r2, err2 = b.Read(make([]byte, 10)) |
||||
r3, err3 := b.Read(make([]byte, 10)) |
||||
if r != 5 || r2 != 5 || r3 != 0 || err != nil || err2 != nil || err3 != io.EOF { |
||||
t.Fatal("expected reads of 5 and 5 and 0, with EOF") |
||||
} |
||||
|
||||
b = newBuffer() |
||||
b.write(make([]byte, 5)) |
||||
b.write(make([]byte, 10)) |
||||
b.eof() |
||||
r, err = b.Read(make([]byte, 9)) |
||||
r2, err2 = b.Read(make([]byte, 3)) |
||||
r3, err3 = b.Read(make([]byte, 3)) |
||||
r4, err4 := b.Read(make([]byte, 10)) |
||||
if err != nil || err2 != nil || err3 != nil || err4 != io.EOF { |
||||
t.Fatalf("Expected EOF on forth read only, err=%v, err2=%v, err3=%v, err4=%v", err, err2, err3, err4) |
||||
} |
||||
if r != 9 || r2 != 3 || r3 != 3 || r4 != 0 { |
||||
t.Fatal("Expected written == read == 15", r, r2, r3, r4) |
||||
} |
||||
} |
@ -0,0 +1,501 @@
|
||||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"bytes" |
||||
"errors" |
||||
"fmt" |
||||
"io" |
||||
"net" |
||||
"sort" |
||||
"time" |
||||
) |
||||
|
||||
// These constants from [PROTOCOL.certkeys] represent the algorithm names
|
||||
// for certificate types supported by this package.
|
||||
const ( |
||||
CertAlgoRSAv01 = "ssh-rsa-cert-v01@openssh.com" |
||||
CertAlgoDSAv01 = "ssh-dss-cert-v01@openssh.com" |
||||
CertAlgoECDSA256v01 = "ecdsa-sha2-nistp256-cert-v01@openssh.com" |
||||
CertAlgoECDSA384v01 = "ecdsa-sha2-nistp384-cert-v01@openssh.com" |
||||
CertAlgoECDSA521v01 = "ecdsa-sha2-nistp521-cert-v01@openssh.com" |
||||
) |
||||
|
||||
// Certificate types distinguish between host and user
|
||||
// certificates. The values can be set in the CertType field of
|
||||
// Certificate.
|
||||
const ( |
||||
UserCert = 1 |
||||
HostCert = 2 |
||||
) |
||||
|
||||
// Signature represents a cryptographic signature.
|
||||
type Signature struct { |
||||
Format string |
||||
Blob []byte |
||||
} |
||||
|
||||
// CertTimeInfinity can be used for OpenSSHCertV01.ValidBefore to indicate that
|
||||
// a certificate does not expire.
|
||||
const CertTimeInfinity = 1<<64 - 1 |
||||
|
||||
// An Certificate represents an OpenSSH certificate as defined in
|
||||
// [PROTOCOL.certkeys]?rev=1.8.
|
||||
type Certificate struct { |
||||
Nonce []byte |
||||
Key PublicKey |
||||
Serial uint64 |
||||
CertType uint32 |
||||
KeyId string |
||||
ValidPrincipals []string |
||||
ValidAfter uint64 |
||||
ValidBefore uint64 |
||||
Permissions |
||||
Reserved []byte |
||||
SignatureKey PublicKey |
||||
Signature *Signature |
||||
} |
||||
|
||||
// genericCertData holds the key-independent part of the certificate data.
|
||||
// Overall, certificates contain an nonce, public key fields and
|
||||
// key-independent fields.
|
||||
type genericCertData struct { |
||||
Serial uint64 |
||||
CertType uint32 |
||||
KeyId string |
||||
ValidPrincipals []byte |
||||
ValidAfter uint64 |
||||
ValidBefore uint64 |
||||
CriticalOptions []byte |
||||
Extensions []byte |
||||
Reserved []byte |
||||
SignatureKey []byte |
||||
Signature []byte |
||||
} |
||||
|
||||
func marshalStringList(namelist []string) []byte { |
||||
var to []byte |
||||
for _, name := range namelist { |
||||
s := struct{ N string }{name} |
||||
to = append(to, Marshal(&s)...) |
||||
} |
||||
return to |
||||
} |
||||
|
||||
type optionsTuple struct { |
||||
Key string |
||||
Value []byte |
||||
} |
||||
|
||||
type optionsTupleValue struct { |
||||
Value string |
||||
} |
||||
|
||||
// serialize a map of critical options or extensions
|
||||
// issue #10569 - per [PROTOCOL.certkeys] and SSH implementation,
|
||||
// we need two length prefixes for a non-empty string value
|
||||
func marshalTuples(tups map[string]string) []byte { |
||||
keys := make([]string, 0, len(tups)) |
||||
for key := range tups { |
||||
keys = append(keys, key) |
||||
} |
||||
sort.Strings(keys) |
||||
|
||||
var ret []byte |
||||
for _, key := range keys { |
||||
s := optionsTuple{Key: key} |
||||
if value := tups[key]; len(value) > 0 { |
||||
s.Value = Marshal(&optionsTupleValue{value}) |
||||
} |
||||
ret = append(ret, Marshal(&s)...) |
||||
} |
||||
return ret |
||||
} |
||||
|
||||
// issue #10569 - per [PROTOCOL.certkeys] and SSH implementation,
|
||||
// we need two length prefixes for a non-empty option value
|
||||
func parseTuples(in []byte) (map[string]string, error) { |
||||
tups := map[string]string{} |
||||
var lastKey string |
||||
var haveLastKey bool |
||||
|
||||
for len(in) > 0 { |
||||
var key, val, extra []byte |
||||
var ok bool |
||||
|
||||
if key, in, ok = parseString(in); !ok { |
||||
return nil, errShortRead |
||||
} |
||||
keyStr := string(key) |
||||
// according to [PROTOCOL.certkeys], the names must be in
|
||||
// lexical order.
|
||||
if haveLastKey && keyStr <= lastKey { |
||||
return nil, fmt.Errorf("ssh: certificate options are not in lexical order") |
||||
} |
||||
lastKey, haveLastKey = keyStr, true |
||||
// the next field is a data field, which if non-empty has a string embedded
|
||||
if val, in, ok = parseString(in); !ok { |
||||
return nil, errShortRead |
||||
} |
||||
if len(val) > 0 { |
||||
val, extra, ok = parseString(val) |
||||
if !ok { |
||||
return nil, errShortRead |
||||
} |
||||
if len(extra) > 0 { |
||||
return nil, fmt.Errorf("ssh: unexpected trailing data after certificate option value") |
||||
} |
||||
tups[keyStr] = string(val) |
||||
} else { |
||||
tups[keyStr] = "" |
||||
} |
||||
} |
||||
return tups, nil |
||||
} |
||||
|
||||
func parseCert(in []byte, privAlgo string) (*Certificate, error) { |
||||
nonce, rest, ok := parseString(in) |
||||
if !ok { |
||||
return nil, errShortRead |
||||
} |
||||
|
||||
key, rest, err := parsePubKey(rest, privAlgo) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var g genericCertData |
||||
if err := Unmarshal(rest, &g); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
c := &Certificate{ |
||||
Nonce: nonce, |
||||
Key: key, |
||||
Serial: g.Serial, |
||||
CertType: g.CertType, |
||||
KeyId: g.KeyId, |
||||
ValidAfter: g.ValidAfter, |
||||
ValidBefore: g.ValidBefore, |
||||
} |
||||
|
||||
for principals := g.ValidPrincipals; len(principals) > 0; { |
||||
principal, rest, ok := parseString(principals) |
||||
if !ok { |
||||
return nil, errShortRead |
||||
} |
||||
c.ValidPrincipals = append(c.ValidPrincipals, string(principal)) |
||||
principals = rest |
||||
} |
||||
|
||||
c.CriticalOptions, err = parseTuples(g.CriticalOptions) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
c.Extensions, err = parseTuples(g.Extensions) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
c.Reserved = g.Reserved |
||||
k, err := ParsePublicKey(g.SignatureKey) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
c.SignatureKey = k |
||||
c.Signature, rest, ok = parseSignatureBody(g.Signature) |
||||
if !ok || len(rest) > 0 { |
||||
return nil, errors.New("ssh: signature parse error") |
||||
} |
||||
|
||||
return c, nil |
||||
} |
||||
|
||||
type openSSHCertSigner struct { |
||||
pub *Certificate |
||||
signer Signer |
||||
} |
||||
|
||||
// NewCertSigner returns a Signer that signs with the given Certificate, whose
|
||||
// private key is held by signer. It returns an error if the public key in cert
|
||||
// doesn't match the key used by signer.
|
||||
func NewCertSigner(cert *Certificate, signer Signer) (Signer, error) { |
||||
if bytes.Compare(cert.Key.Marshal(), signer.PublicKey().Marshal()) != 0 { |
||||
return nil, errors.New("ssh: signer and cert have different public key") |
||||
} |
||||
|
||||
return &openSSHCertSigner{cert, signer}, nil |
||||
} |
||||
|
||||
func (s *openSSHCertSigner) Sign(rand io.Reader, data []byte) (*Signature, error) { |
||||
return s.signer.Sign(rand, data) |
||||
} |
||||
|
||||
func (s *openSSHCertSigner) PublicKey() PublicKey { |
||||
return s.pub |
||||
} |
||||
|
||||
const sourceAddressCriticalOption = "source-address" |
||||
|
||||
// CertChecker does the work of verifying a certificate. Its methods
|
||||
// can be plugged into ClientConfig.HostKeyCallback and
|
||||
// ServerConfig.PublicKeyCallback. For the CertChecker to work,
|
||||
// minimally, the IsAuthority callback should be set.
|
||||
type CertChecker struct { |
||||
// SupportedCriticalOptions lists the CriticalOptions that the
|
||||
// server application layer understands. These are only used
|
||||
// for user certificates.
|
||||
SupportedCriticalOptions []string |
||||
|
||||
// IsAuthority should return true if the key is recognized as
|
||||
// an authority. This allows for certificates to be signed by other
|
||||
// certificates.
|
||||
IsAuthority func(auth PublicKey) bool |
||||
|
||||
// Clock is used for verifying time stamps. If nil, time.Now
|
||||
// is used.
|
||||
Clock func() time.Time |
||||
|
||||
// UserKeyFallback is called when CertChecker.Authenticate encounters a
|
||||
// public key that is not a certificate. It must implement validation
|
||||
// of user keys or else, if nil, all such keys are rejected.
|
||||
UserKeyFallback func(conn ConnMetadata, key PublicKey) (*Permissions, error) |
||||
|
||||
// HostKeyFallback is called when CertChecker.CheckHostKey encounters a
|
||||
// public key that is not a certificate. It must implement host key
|
||||
// validation or else, if nil, all such keys are rejected.
|
||||
HostKeyFallback func(addr string, remote net.Addr, key PublicKey) error |
||||
|
||||
// IsRevoked is called for each certificate so that revocation checking
|
||||
// can be implemented. It should return true if the given certificate
|
||||
// is revoked and false otherwise. If nil, no certificates are
|
||||
// considered to have been revoked.
|
||||
IsRevoked func(cert *Certificate) bool |
||||
} |
||||
|
||||
// CheckHostKey checks a host key certificate. This method can be
|
||||
// plugged into ClientConfig.HostKeyCallback.
|
||||
func (c *CertChecker) CheckHostKey(addr string, remote net.Addr, key PublicKey) error { |
||||
cert, ok := key.(*Certificate) |
||||
if !ok { |
||||
if c.HostKeyFallback != nil { |
||||
return c.HostKeyFallback(addr, remote, key) |
||||
} |
||||
return errors.New("ssh: non-certificate host key") |
||||
} |
||||
if cert.CertType != HostCert { |
||||
return fmt.Errorf("ssh: certificate presented as a host key has type %d", cert.CertType) |
||||
} |
||||
|
||||
return c.CheckCert(addr, cert) |
||||
} |
||||
|
||||
// Authenticate checks a user certificate. Authenticate can be used as
|
||||
// a value for ServerConfig.PublicKeyCallback.
|
||||
func (c *CertChecker) Authenticate(conn ConnMetadata, pubKey PublicKey) (*Permissions, error) { |
||||
cert, ok := pubKey.(*Certificate) |
||||
if !ok { |
||||
if c.UserKeyFallback != nil { |
||||
return c.UserKeyFallback(conn, pubKey) |
||||
} |
||||
return nil, errors.New("ssh: normal key pairs not accepted") |
||||
} |
||||
|
||||
if cert.CertType != UserCert { |
||||
return nil, fmt.Errorf("ssh: cert has type %d", cert.CertType) |
||||
} |
||||
|
||||
if err := c.CheckCert(conn.User(), cert); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return &cert.Permissions, nil |
||||
} |
||||
|
||||
// CheckCert checks CriticalOptions, ValidPrincipals, revocation, timestamp and
|
||||
// the signature of the certificate.
|
||||
func (c *CertChecker) CheckCert(principal string, cert *Certificate) error { |
||||
if c.IsRevoked != nil && c.IsRevoked(cert) { |
||||
return fmt.Errorf("ssh: certicate serial %d revoked", cert.Serial) |
||||
} |
||||
|
||||
for opt, _ := range cert.CriticalOptions { |
||||
// sourceAddressCriticalOption will be enforced by
|
||||
// serverAuthenticate
|
||||
if opt == sourceAddressCriticalOption { |
||||
continue |
||||
} |
||||
|
||||
found := false |
||||
for _, supp := range c.SupportedCriticalOptions { |
||||
if supp == opt { |
||||
found = true |
||||
break |
||||
} |
||||
} |
||||
if !found { |
||||
return fmt.Errorf("ssh: unsupported critical option %q in certificate", opt) |
||||
} |
||||
} |
||||
|
||||
if len(cert.ValidPrincipals) > 0 { |
||||
// By default, certs are valid for all users/hosts.
|
||||
found := false |
||||
for _, p := range cert.ValidPrincipals { |
||||
if p == principal { |
||||
found = true |
||||
break |
||||
} |
||||
} |
||||
if !found { |
||||
return fmt.Errorf("ssh: principal %q not in the set of valid principals for given certificate: %q", principal, cert.ValidPrincipals) |
||||
} |
||||
} |
||||
|
||||
if !c.IsAuthority(cert.SignatureKey) { |
||||
return fmt.Errorf("ssh: certificate signed by unrecognized authority") |
||||
} |
||||
|
||||
clock := c.Clock |
||||
if clock == nil { |
||||
clock = time.Now |
||||
} |
||||
|
||||
unixNow := clock().Unix() |
||||
if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) { |
||||
return fmt.Errorf("ssh: cert is not yet valid") |
||||
} |
||||
if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(CertTimeInfinity) && (unixNow >= before || before < 0) { |
||||
return fmt.Errorf("ssh: cert has expired") |
||||
} |
||||
if err := cert.SignatureKey.Verify(cert.bytesForSigning(), cert.Signature); err != nil { |
||||
return fmt.Errorf("ssh: certificate signature does not verify") |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
// SignCert sets c.SignatureKey to the authority's public key and stores a
|
||||
// Signature, by authority, in the certificate.
|
||||
func (c *Certificate) SignCert(rand io.Reader, authority Signer) error { |
||||
c.Nonce = make([]byte, 32) |
||||
if _, err := io.ReadFull(rand, c.Nonce); err != nil { |
||||
return err |
||||
} |
||||
c.SignatureKey = authority.PublicKey() |
||||
|
||||
sig, err := authority.Sign(rand, c.bytesForSigning()) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
c.Signature = sig |
||||
return nil |
||||
} |
||||
|
||||
var certAlgoNames = map[string]string{ |
||||
KeyAlgoRSA: CertAlgoRSAv01, |
||||
KeyAlgoDSA: CertAlgoDSAv01, |
||||
KeyAlgoECDSA256: CertAlgoECDSA256v01, |
||||
KeyAlgoECDSA384: CertAlgoECDSA384v01, |
||||
KeyAlgoECDSA521: CertAlgoECDSA521v01, |
||||
} |
||||
|
||||
// certToPrivAlgo returns the underlying algorithm for a certificate algorithm.
|
||||
// Panics if a non-certificate algorithm is passed.
|
||||
func certToPrivAlgo(algo string) string { |
||||
for privAlgo, pubAlgo := range certAlgoNames { |
||||
if pubAlgo == algo { |
||||
return privAlgo |
||||
} |
||||
} |
||||
panic("unknown cert algorithm") |
||||
} |
||||
|
||||
func (cert *Certificate) bytesForSigning() []byte { |
||||
c2 := *cert |
||||
c2.Signature = nil |
||||
out := c2.Marshal() |
||||
// Drop trailing signature length.
|
||||
return out[:len(out)-4] |
||||
} |
||||
|
||||
// Marshal serializes c into OpenSSH's wire format. It is part of the
|
||||
// PublicKey interface.
|
||||
func (c *Certificate) Marshal() []byte { |
||||
generic := genericCertData{ |
||||
Serial: c.Serial, |
||||
CertType: c.CertType, |
||||
KeyId: c.KeyId, |
||||
ValidPrincipals: marshalStringList(c.ValidPrincipals), |
||||
ValidAfter: uint64(c.ValidAfter), |
||||
ValidBefore: uint64(c.ValidBefore), |
||||
CriticalOptions: marshalTuples(c.CriticalOptions), |
||||
Extensions: marshalTuples(c.Extensions), |
||||
Reserved: c.Reserved, |
||||
SignatureKey: c.SignatureKey.Marshal(), |
||||
} |
||||
if c.Signature != nil { |
||||
generic.Signature = Marshal(c.Signature) |
||||
} |
||||
genericBytes := Marshal(&generic) |
||||
keyBytes := c.Key.Marshal() |
||||
_, keyBytes, _ = parseString(keyBytes) |
||||
prefix := Marshal(&struct { |
||||
Name string |
||||
Nonce []byte |
||||
Key []byte `ssh:"rest"` |
||||
}{c.Type(), c.Nonce, keyBytes}) |
||||
|
||||
result := make([]byte, 0, len(prefix)+len(genericBytes)) |
||||
result = append(result, prefix...) |
||||
result = append(result, genericBytes...) |
||||
return result |
||||
} |
||||
|
||||
// Type returns the key name. It is part of the PublicKey interface.
|
||||
func (c *Certificate) Type() string { |
||||
algo, ok := certAlgoNames[c.Key.Type()] |
||||
if !ok { |
||||
panic("unknown cert key type") |
||||
} |
||||
return algo |
||||
} |
||||
|
||||
// Verify verifies a signature against the certificate's public
|
||||
// key. It is part of the PublicKey interface.
|
||||
func (c *Certificate) Verify(data []byte, sig *Signature) error { |
||||
return c.Key.Verify(data, sig) |
||||
} |
||||
|
||||
func parseSignatureBody(in []byte) (out *Signature, rest []byte, ok bool) { |
||||
format, in, ok := parseString(in) |
||||
if !ok { |
||||
return |
||||
} |
||||
|
||||
out = &Signature{ |
||||
Format: string(format), |
||||
} |
||||
|
||||
if out.Blob, in, ok = parseString(in); !ok { |
||||
return |
||||
} |
||||
|
||||
return out, in, ok |
||||
} |
||||
|
||||
func parseSignature(in []byte) (out *Signature, rest []byte, ok bool) { |
||||
sigBytes, rest, ok := parseString(in) |
||||
if !ok { |
||||
return |
||||
} |
||||
|
||||
out, trailing, ok := parseSignatureBody(sigBytes) |
||||
if !ok || len(trailing) > 0 { |
||||
return nil, nil, false |
||||
} |
||||
return |
||||
} |
@ -0,0 +1,216 @@
|
||||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"bytes" |
||||
"crypto/rand" |
||||
"reflect" |
||||
"testing" |
||||
"time" |
||||
) |
||||
|
||||
// Cert generated by ssh-keygen 6.0p1 Debian-4.
|
||||
// % ssh-keygen -s ca-key -I test user-key
|
||||
const exampleSSHCert = `ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgb1srW/W3ZDjYAO45xLYAwzHBDLsJ4Ux6ICFIkTjb1LEAAAADAQABAAAAYQCkoR51poH0wE8w72cqSB8Sszx+vAhzcMdCO0wqHTj7UNENHWEXGrU0E0UQekD7U+yhkhtoyjbPOVIP7hNa6aRk/ezdh/iUnCIt4Jt1v3Z1h1P+hA4QuYFMHNB+rmjPwAcAAAAAAAAAAAAAAAEAAAAEdGVzdAAAAAAAAAAAAAAAAP//////////AAAAAAAAAIIAAAAVcGVybWl0LVgxMS1mb3J3YXJkaW5nAAAAAAAAABdwZXJtaXQtYWdlbnQtZm9yd2FyZGluZwAAAAAAAAAWcGVybWl0LXBvcnQtZm9yd2FyZGluZwAAAAAAAAAKcGVybWl0LXB0eQAAAAAAAAAOcGVybWl0LXVzZXItcmMAAAAAAAAAAAAAAHcAAAAHc3NoLXJzYQAAAAMBAAEAAABhANFS2kaktpSGc+CcmEKPyw9mJC4nZKxHKTgLVZeaGbFZOvJTNzBspQHdy7Q1uKSfktxpgjZnksiu/tFF9ngyY2KFoc+U88ya95IZUycBGCUbBQ8+bhDtw/icdDGQD5WnUwAAAG8AAAAHc3NoLXJzYQAAAGC8Y9Z2LQKhIhxf52773XaWrXdxP0t3GBVo4A10vUWiYoAGepr6rQIoGGXFxT4B9Gp+nEBJjOwKDXPrAevow0T9ca8gZN+0ykbhSrXLE5Ao48rqr3zP4O1/9P7e6gp0gw8=` |
||||
|
||||
func TestParseCert(t *testing.T) { |
||||
authKeyBytes := []byte(exampleSSHCert) |
||||
|
||||
key, _, _, rest, err := ParseAuthorizedKey(authKeyBytes) |
||||
if err != nil { |
||||
t.Fatalf("ParseAuthorizedKey: %v", err) |
||||
} |
||||
if len(rest) > 0 { |
||||
t.Errorf("rest: got %q, want empty", rest) |
||||
} |
||||
|
||||
if _, ok := key.(*Certificate); !ok { |
||||
t.Fatalf("got %v (%T), want *Certificate", key, key) |
||||
} |
||||
|
||||
marshaled := MarshalAuthorizedKey(key) |
||||
// Before comparison, remove the trailing newline that
|
||||
// MarshalAuthorizedKey adds.
|
||||
marshaled = marshaled[:len(marshaled)-1] |
||||
if !bytes.Equal(authKeyBytes, marshaled) { |
||||
t.Errorf("marshaled certificate does not match original: got %q, want %q", marshaled, authKeyBytes) |
||||
} |
||||
} |
||||
|
||||
// Cert generated by ssh-keygen OpenSSH_6.8p1 OS X 10.10.3
|
||||
// % ssh-keygen -s ca -I testcert -O source-address=192.168.1.0/24 -O force-command=/bin/sleep user.pub
|
||||
// user.pub key: ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDACh1rt2DXfV3hk6fszSQcQ/rueMId0kVD9U7nl8cfEnFxqOCrNT92g4laQIGl2mn8lsGZfTLg8ksHq3gkvgO3oo/0wHy4v32JeBOHTsN5AL4gfHNEhWeWb50ev47hnTsRIt9P4dxogeUo/hTu7j9+s9lLpEQXCvq6xocXQt0j8MV9qZBBXFLXVT3cWIkSqOdwt/5ZBg+1GSrc7WfCXVWgTk4a20uPMuJPxU4RQwZW6X3+O8Pqo8C3cW0OzZRFP6gUYUKUsTI5WntlS+LAxgw1mZNsozFGdbiOPRnEryE3SRldh9vjDR3tin1fGpA5P7+CEB/bqaXtG3V+F2OkqaMN
|
||||
// Critical Options:
|
||||
// force-command /bin/sleep
|
||||
// source-address 192.168.1.0/24
|
||||
// Extensions:
|
||||
// permit-X11-forwarding
|
||||
// permit-agent-forwarding
|
||||
// permit-port-forwarding
|
||||
// permit-pty
|
||||
// permit-user-rc
|
||||
const exampleSSHCertWithOptions = `ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgDyysCJY0XrO1n03EeRRoITnTPdjENFmWDs9X58PP3VUAAAADAQABAAABAQDACh1rt2DXfV3hk6fszSQcQ/rueMId0kVD9U7nl8cfEnFxqOCrNT92g4laQIGl2mn8lsGZfTLg8ksHq3gkvgO3oo/0wHy4v32JeBOHTsN5AL4gfHNEhWeWb50ev47hnTsRIt9P4dxogeUo/hTu7j9+s9lLpEQXCvq6xocXQt0j8MV9qZBBXFLXVT3cWIkSqOdwt/5ZBg+1GSrc7WfCXVWgTk4a20uPMuJPxU4RQwZW6X3+O8Pqo8C3cW0OzZRFP6gUYUKUsTI5WntlS+LAxgw1mZNsozFGdbiOPRnEryE3SRldh9vjDR3tin1fGpA5P7+CEB/bqaXtG3V+F2OkqaMNAAAAAAAAAAAAAAABAAAACHRlc3RjZXJ0AAAAAAAAAAAAAAAA//////////8AAABLAAAADWZvcmNlLWNvbW1hbmQAAAAOAAAACi9iaW4vc2xlZXAAAAAOc291cmNlLWFkZHJlc3MAAAASAAAADjE5Mi4xNjguMS4wLzI0AAAAggAAABVwZXJtaXQtWDExLWZvcndhcmRpbmcAAAAAAAAAF3Blcm1pdC1hZ2VudC1mb3J3YXJkaW5nAAAAAAAAABZwZXJtaXQtcG9ydC1mb3J3YXJkaW5nAAAAAAAAAApwZXJtaXQtcHR5AAAAAAAAAA5wZXJtaXQtdXNlci1yYwAAAAAAAAAAAAABFwAAAAdzc2gtcnNhAAAAAwEAAQAAAQEAwU+c5ui5A8+J/CFpjW8wCa52bEODA808WWQDCSuTG/eMXNf59v9Y8Pk0F1E9dGCosSNyVcB/hacUrc6He+i97+HJCyKavBsE6GDxrjRyxYqAlfcOXi/IVmaUGiO8OQ39d4GHrjToInKvExSUeleQyH4Y4/e27T/pILAqPFL3fyrvMLT5qU9QyIt6zIpa7GBP5+urouNavMprV3zsfIqNBbWypinOQAw823a5wN+zwXnhZrgQiHZ/USG09Y6k98y1dTVz8YHlQVR4D3lpTAsKDKJ5hCH9WU4fdf+lU8OyNGaJ/vz0XNqxcToe1l4numLTnaoSuH89pHryjqurB7lJKwAAAQ8AAAAHc3NoLXJzYQAAAQCaHvUIoPL1zWUHIXLvu96/HU1s/i4CAW2IIEuGgxCUCiFj6vyTyYtgxQxcmbfZf6eaITlS6XJZa7Qq4iaFZh75C1DXTX8labXhRSD4E2t//AIP9MC1rtQC5xo6FmbQ+BoKcDskr+mNACcbRSxs3IL3bwCfWDnIw2WbVox9ZdcthJKk4UoCW4ix4QwdHw7zlddlz++fGEEVhmTbll1SUkycGApPFBsAYRTMupUJcYPIeReBI/m8XfkoMk99bV8ZJQTAd7OekHY2/48Ff53jLmyDjP7kNw1F8OaPtkFs6dGJXta4krmaekPy87j+35In5hFj7yoOqvSbmYUkeX70/GGQ` |
||||
|
||||
func TestParseCertWithOptions(t *testing.T) { |
||||
opts := map[string]string{ |
||||
"source-address": "192.168.1.0/24", |
||||
"force-command": "/bin/sleep", |
||||
} |
||||
exts := map[string]string{ |
||||
"permit-X11-forwarding": "", |
||||
"permit-agent-forwarding": "", |
||||
"permit-port-forwarding": "", |
||||
"permit-pty": "", |
||||
"permit-user-rc": "", |
||||
} |
||||
authKeyBytes := []byte(exampleSSHCertWithOptions) |
||||
|
||||
key, _, _, rest, err := ParseAuthorizedKey(authKeyBytes) |
||||
if err != nil { |
||||
t.Fatalf("ParseAuthorizedKey: %v", err) |
||||
} |
||||
if len(rest) > 0 { |
||||
t.Errorf("rest: got %q, want empty", rest) |
||||
} |
||||
cert, ok := key.(*Certificate) |
||||
if !ok { |
||||
t.Fatalf("got %v (%T), want *Certificate", key, key) |
||||
} |
||||
if !reflect.DeepEqual(cert.CriticalOptions, opts) { |
||||
t.Errorf("unexpected critical options - got %v, want %v", cert.CriticalOptions, opts) |
||||
} |
||||
if !reflect.DeepEqual(cert.Extensions, exts) { |
||||
t.Errorf("unexpected Extensions - got %v, want %v", cert.Extensions, exts) |
||||
} |
||||
marshaled := MarshalAuthorizedKey(key) |
||||
// Before comparison, remove the trailing newline that
|
||||
// MarshalAuthorizedKey adds.
|
||||
marshaled = marshaled[:len(marshaled)-1] |
||||
if !bytes.Equal(authKeyBytes, marshaled) { |
||||
t.Errorf("marshaled certificate does not match original: got %q, want %q", marshaled, authKeyBytes) |
||||
} |
||||
} |
||||
|
||||
func TestValidateCert(t *testing.T) { |
||||
key, _, _, _, err := ParseAuthorizedKey([]byte(exampleSSHCert)) |
||||
if err != nil { |
||||
t.Fatalf("ParseAuthorizedKey: %v", err) |
||||
} |
||||
validCert, ok := key.(*Certificate) |
||||
if !ok { |
||||
t.Fatalf("got %v (%T), want *Certificate", key, key) |
||||
} |
||||
checker := CertChecker{} |
||||
checker.IsAuthority = func(k PublicKey) bool { |
||||
return bytes.Equal(k.Marshal(), validCert.SignatureKey.Marshal()) |
||||
} |
||||
|
||||
if err := checker.CheckCert("user", validCert); err != nil { |
||||
t.Errorf("Unable to validate certificate: %v", err) |
||||
} |
||||
invalidCert := &Certificate{ |
||||
Key: testPublicKeys["rsa"], |
||||
SignatureKey: testPublicKeys["ecdsa"], |
||||
ValidBefore: CertTimeInfinity, |
||||
Signature: &Signature{}, |
||||
} |
||||
if err := checker.CheckCert("user", invalidCert); err == nil { |
||||
t.Error("Invalid cert signature passed validation") |
||||
} |
||||
} |
||||
|
||||
func TestValidateCertTime(t *testing.T) { |
||||
cert := Certificate{ |
||||
ValidPrincipals: []string{"user"}, |
||||
Key: testPublicKeys["rsa"], |
||||
ValidAfter: 50, |
||||
ValidBefore: 100, |
||||
} |
||||
|
||||
cert.SignCert(rand.Reader, testSigners["ecdsa"]) |
||||
|
||||
for ts, ok := range map[int64]bool{ |
||||
25: false, |
||||
50: true, |
||||
99: true, |
||||
100: false, |
||||
125: false, |
||||
} { |
||||
checker := CertChecker{ |
||||
Clock: func() time.Time { return time.Unix(ts, 0) }, |
||||
} |
||||
checker.IsAuthority = func(k PublicKey) bool { |
||||
return bytes.Equal(k.Marshal(), |
||||
testPublicKeys["ecdsa"].Marshal()) |
||||
} |
||||
|
||||
if v := checker.CheckCert("user", &cert); (v == nil) != ok { |
||||
t.Errorf("Authenticate(%d): %v", ts, v) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// TODO(hanwen): tests for
|
||||
//
|
||||
// host keys:
|
||||
// * fallbacks
|
||||
|
||||
func TestHostKeyCert(t *testing.T) { |
||||
cert := &Certificate{ |
||||
ValidPrincipals: []string{"hostname", "hostname.domain"}, |
||||
Key: testPublicKeys["rsa"], |
||||
ValidBefore: CertTimeInfinity, |
||||
CertType: HostCert, |
||||
} |
||||
cert.SignCert(rand.Reader, testSigners["ecdsa"]) |
||||
|
||||
checker := &CertChecker{ |
||||
IsAuthority: func(p PublicKey) bool { |
||||
return bytes.Equal(testPublicKeys["ecdsa"].Marshal(), p.Marshal()) |
||||
}, |
||||
} |
||||
|
||||
certSigner, err := NewCertSigner(cert, testSigners["rsa"]) |
||||
if err != nil { |
||||
t.Errorf("NewCertSigner: %v", err) |
||||
} |
||||
|
||||
for _, name := range []string{"hostname", "otherhost"} { |
||||
c1, c2, err := netPipe() |
||||
if err != nil { |
||||
t.Fatalf("netPipe: %v", err) |
||||
} |
||||
defer c1.Close() |
||||
defer c2.Close() |
||||
|
||||
errc := make(chan error) |
||||
|
||||
go func() { |
||||
conf := ServerConfig{ |
||||
NoClientAuth: true, |
||||
} |
||||
conf.AddHostKey(certSigner) |
||||
_, _, _, err := NewServerConn(c1, &conf) |
||||
errc <- err |
||||
}() |
||||
|
||||
config := &ClientConfig{ |
||||
User: "user", |
||||
HostKeyCallback: checker.CheckHostKey, |
||||
} |
||||
_, _, _, err = NewClientConn(c2, name, config) |
||||
|
||||
succeed := name == "hostname" |
||||
if (err == nil) != succeed { |
||||
t.Fatalf("NewClientConn(%q): %v", name, err) |
||||
} |
||||
|
||||
err = <-errc |
||||
if (err == nil) != succeed { |
||||
t.Fatalf("NewServerConn(%q): %v", name, err) |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,631 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"encoding/binary" |
||||
"errors" |
||||
"fmt" |
||||
"io" |
||||
"log" |
||||
"sync" |
||||
) |
||||
|
||||
const ( |
||||
minPacketLength = 9 |
||||
// channelMaxPacket contains the maximum number of bytes that will be
|
||||
// sent in a single packet. As per RFC 4253, section 6.1, 32k is also
|
||||
// the minimum.
|
||||
channelMaxPacket = 1 << 15 |
||||
// We follow OpenSSH here.
|
||||
channelWindowSize = 64 * channelMaxPacket |
||||
) |
||||
|
||||
// NewChannel represents an incoming request to a channel. It must either be
|
||||
// accepted for use by calling Accept, or rejected by calling Reject.
|
||||
type NewChannel interface { |
||||
// Accept accepts the channel creation request. It returns the Channel
|
||||
// and a Go channel containing SSH requests. The Go channel must be
|
||||
// serviced otherwise the Channel will hang.
|
||||
Accept() (Channel, <-chan *Request, error) |
||||
|
||||
// Reject rejects the channel creation request. After calling
|
||||
// this, no other methods on the Channel may be called.
|
||||
Reject(reason RejectionReason, message string) error |
||||
|
||||
// ChannelType returns the type of the channel, as supplied by the
|
||||
// client.
|
||||
ChannelType() string |
||||
|
||||
// ExtraData returns the arbitrary payload for this channel, as supplied
|
||||
// by the client. This data is specific to the channel type.
|
||||
ExtraData() []byte |
||||
} |
||||
|
||||
// A Channel is an ordered, reliable, flow-controlled, duplex stream
|
||||
// that is multiplexed over an SSH connection.
|
||||
type Channel interface { |
||||
// Read reads up to len(data) bytes from the channel.
|
||||
Read(data []byte) (int, error) |
||||
|
||||
// Write writes len(data) bytes to the channel.
|
||||
Write(data []byte) (int, error) |
||||
|
||||
// Close signals end of channel use. No data may be sent after this
|
||||
// call.
|
||||
Close() error |
||||
|
||||
// CloseWrite signals the end of sending in-band
|
||||
// data. Requests may still be sent, and the other side may
|
||||
// still send data
|
||||
CloseWrite() error |
||||
|
||||
// SendRequest sends a channel request. If wantReply is true,
|
||||
// it will wait for a reply and return the result as a
|
||||
// boolean, otherwise the return value will be false. Channel
|
||||
// requests are out-of-band messages so they may be sent even
|
||||
// if the data stream is closed or blocked by flow control.
|
||||
SendRequest(name string, wantReply bool, payload []byte) (bool, error) |
||||
|
||||
// Stderr returns an io.ReadWriter that writes to this channel
|
||||
// with the extended data type set to stderr. Stderr may
|
||||
// safely be read and written from a different goroutine than
|
||||
// Read and Write respectively.
|
||||
Stderr() io.ReadWriter |
||||
} |
||||
|
||||
// Request is a request sent outside of the normal stream of
|
||||
// data. Requests can either be specific to an SSH channel, or they
|
||||
// can be global.
|
||||
type Request struct { |
||||
Type string |
||||
WantReply bool |
||||
Payload []byte |
||||
|
||||
ch *channel |
||||
mux *mux |
||||
} |
||||
|
||||
// Reply sends a response to a request. It must be called for all requests
|
||||
// where WantReply is true and is a no-op otherwise. The payload argument is
|
||||
// ignored for replies to channel-specific requests.
|
||||
func (r *Request) Reply(ok bool, payload []byte) error { |
||||
if !r.WantReply { |
||||
return nil |
||||
} |
||||
|
||||
if r.ch == nil { |
||||
return r.mux.ackRequest(ok, payload) |
||||
} |
||||
|
||||
return r.ch.ackRequest(ok) |
||||
} |
||||
|
||||
// RejectionReason is an enumeration used when rejecting channel creation
|
||||
// requests. See RFC 4254, section 5.1.
|
||||
type RejectionReason uint32 |
||||
|
||||
const ( |
||||
Prohibited RejectionReason = iota + 1 |
||||
ConnectionFailed |
||||
UnknownChannelType |
||||
ResourceShortage |
||||
) |
||||
|
||||
// String converts the rejection reason to human readable form.
|
||||
func (r RejectionReason) String() string { |
||||
switch r { |
||||
case Prohibited: |
||||
return "administratively prohibited" |
||||
case ConnectionFailed: |
||||
return "connect failed" |
||||
case UnknownChannelType: |
||||
return "unknown channel type" |
||||
case ResourceShortage: |
||||
return "resource shortage" |
||||
} |
||||
return fmt.Sprintf("unknown reason %d", int(r)) |
||||
} |
||||
|
||||
func min(a uint32, b int) uint32 { |
||||
if a < uint32(b) { |
||||
return a |
||||
} |
||||
return uint32(b) |
||||
} |
||||
|
||||
type channelDirection uint8 |
||||
|
||||
const ( |
||||
channelInbound channelDirection = iota |
||||
channelOutbound |
||||
) |
||||
|
||||
// channel is an implementation of the Channel interface that works
|
||||
// with the mux class.
|
||||
type channel struct { |
||||
// R/O after creation
|
||||
chanType string |
||||
extraData []byte |
||||
localId, remoteId uint32 |
||||
|
||||
// maxIncomingPayload and maxRemotePayload are the maximum
|
||||
// payload sizes of normal and extended data packets for
|
||||
// receiving and sending, respectively. The wire packet will
|
||||
// be 9 or 13 bytes larger (excluding encryption overhead).
|
||||
maxIncomingPayload uint32 |
||||
maxRemotePayload uint32 |
||||
|
||||
mux *mux |
||||
|
||||
// decided is set to true if an accept or reject message has been sent
|
||||
// (for outbound channels) or received (for inbound channels).
|
||||
decided bool |
||||
|
||||
// direction contains either channelOutbound, for channels created
|
||||
// locally, or channelInbound, for channels created by the peer.
|
||||
direction channelDirection |
||||
|
||||
// Pending internal channel messages.
|
||||
msg chan interface{} |
||||
|
||||
// Since requests have no ID, there can be only one request
|
||||
// with WantReply=true outstanding. This lock is held by a
|
||||
// goroutine that has such an outgoing request pending.
|
||||
sentRequestMu sync.Mutex |
||||
|
||||
incomingRequests chan *Request |
||||
|
||||
sentEOF bool |
||||
|
||||
// thread-safe data
|
||||
remoteWin window |
||||
pending *buffer |
||||
extPending *buffer |
||||
|
||||
// windowMu protects myWindow, the flow-control window.
|
||||
windowMu sync.Mutex |
||||
myWindow uint32 |
||||
|
||||
// writeMu serializes calls to mux.conn.writePacket() and
|
||||
// protects sentClose and packetPool. This mutex must be
|
||||
// different from windowMu, as writePacket can block if there
|
||||
// is a key exchange pending.
|
||||
writeMu sync.Mutex |
||||
sentClose bool |
||||
|
||||
// packetPool has a buffer for each extended channel ID to
|
||||
// save allocations during writes.
|
||||
packetPool map[uint32][]byte |
||||
} |
||||
|
||||
// writePacket sends a packet. If the packet is a channel close, it updates
|
||||
// sentClose. This method takes the lock c.writeMu.
|
||||
func (c *channel) writePacket(packet []byte) error { |
||||
c.writeMu.Lock() |
||||
if c.sentClose { |
||||
c.writeMu.Unlock() |
||||
return io.EOF |
||||
} |
||||
c.sentClose = (packet[0] == msgChannelClose) |
||||
err := c.mux.conn.writePacket(packet) |
||||
c.writeMu.Unlock() |
||||
return err |
||||
} |
||||
|
||||
func (c *channel) sendMessage(msg interface{}) error { |
||||
if debugMux { |
||||
log.Printf("send %d: %#v", c.mux.chanList.offset, msg) |
||||
} |
||||
|
||||
p := Marshal(msg) |
||||
binary.BigEndian.PutUint32(p[1:], c.remoteId) |
||||
return c.writePacket(p) |
||||
} |
||||
|
||||
// WriteExtended writes data to a specific extended stream. These streams are
|
||||
// used, for example, for stderr.
|
||||
func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) { |
||||
if c.sentEOF { |
||||
return 0, io.EOF |
||||
} |
||||
// 1 byte message type, 4 bytes remoteId, 4 bytes data length
|
||||
opCode := byte(msgChannelData) |
||||
headerLength := uint32(9) |
||||
if extendedCode > 0 { |
||||
headerLength += 4 |
||||
opCode = msgChannelExtendedData |
||||
} |
||||
|
||||
c.writeMu.Lock() |
||||
packet := c.packetPool[extendedCode] |
||||
// We don't remove the buffer from packetPool, so
|
||||
// WriteExtended calls from different goroutines will be
|
||||
// flagged as errors by the race detector.
|
||||
c.writeMu.Unlock() |
||||
|
||||
for len(data) > 0 { |
||||
space := min(c.maxRemotePayload, len(data)) |
||||
if space, err = c.remoteWin.reserve(space); err != nil { |
||||
return n, err |
||||
} |
||||
if want := headerLength + space; uint32(cap(packet)) < want { |
||||
packet = make([]byte, want) |
||||
} else { |
||||
packet = packet[:want] |
||||
} |
||||
|
||||
todo := data[:space] |
||||
|
||||
packet[0] = opCode |
||||
binary.BigEndian.PutUint32(packet[1:], c.remoteId) |
||||
if extendedCode > 0 { |
||||
binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode)) |
||||
} |
||||
binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo))) |
||||
copy(packet[headerLength:], todo) |
||||
if err = c.writePacket(packet); err != nil { |
||||
return n, err |
||||
} |
||||
|
||||
n += len(todo) |
||||
data = data[len(todo):] |
||||
} |
||||
|
||||
c.writeMu.Lock() |
||||
c.packetPool[extendedCode] = packet |
||||
c.writeMu.Unlock() |
||||
|
||||
return n, err |
||||
} |
||||
|
||||
func (c *channel) handleData(packet []byte) error { |
||||
headerLen := 9 |
||||
isExtendedData := packet[0] == msgChannelExtendedData |
||||
if isExtendedData { |
||||
headerLen = 13 |
||||
} |
||||
if len(packet) < headerLen { |
||||
// malformed data packet
|
||||
return parseError(packet[0]) |
||||
} |
||||
|
||||
var extended uint32 |
||||
if isExtendedData { |
||||
extended = binary.BigEndian.Uint32(packet[5:]) |
||||
} |
||||
|
||||
length := binary.BigEndian.Uint32(packet[headerLen-4 : headerLen]) |
||||
if length == 0 { |
||||
return nil |
||||
} |
||||
if length > c.maxIncomingPayload { |
||||
// TODO(hanwen): should send Disconnect?
|
||||
return errors.New("ssh: incoming packet exceeds maximum payload size") |
||||
} |
||||
|
||||
data := packet[headerLen:] |
||||
if length != uint32(len(data)) { |
||||
return errors.New("ssh: wrong packet length") |
||||
} |
||||
|
||||
c.windowMu.Lock() |
||||
if c.myWindow < length { |
||||
c.windowMu.Unlock() |
||||
// TODO(hanwen): should send Disconnect with reason?
|
||||
return errors.New("ssh: remote side wrote too much") |
||||
} |
||||
c.myWindow -= length |
||||
c.windowMu.Unlock() |
||||
|
||||
if extended == 1 { |
||||
c.extPending.write(data) |
||||
} else if extended > 0 { |
||||
// discard other extended data.
|
||||
} else { |
||||
c.pending.write(data) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (c *channel) adjustWindow(n uint32) error { |
||||
c.windowMu.Lock() |
||||
// Since myWindow is managed on our side, and can never exceed
|
||||
// the initial window setting, we don't worry about overflow.
|
||||
c.myWindow += uint32(n) |
||||
c.windowMu.Unlock() |
||||
return c.sendMessage(windowAdjustMsg{ |
||||
AdditionalBytes: uint32(n), |
||||
}) |
||||
} |
||||
|
||||
func (c *channel) ReadExtended(data []byte, extended uint32) (n int, err error) { |
||||
switch extended { |
||||
case 1: |
||||
n, err = c.extPending.Read(data) |
||||
case 0: |
||||
n, err = c.pending.Read(data) |
||||
default: |
||||
return 0, fmt.Errorf("ssh: extended code %d unimplemented", extended) |
||||
} |
||||
|
||||
if n > 0 { |
||||
err = c.adjustWindow(uint32(n)) |
||||
// sendWindowAdjust can return io.EOF if the remote
|
||||
// peer has closed the connection, however we want to
|
||||
// defer forwarding io.EOF to the caller of Read until
|
||||
// the buffer has been drained.
|
||||
if n > 0 && err == io.EOF { |
||||
err = nil |
||||
} |
||||
} |
||||
|
||||
return n, err |
||||
} |
||||
|
||||
func (c *channel) close() { |
||||
c.pending.eof() |
||||
c.extPending.eof() |
||||
close(c.msg) |
||||
close(c.incomingRequests) |
||||
c.writeMu.Lock() |
||||
// This is not necesary for a normal channel teardown, but if
|
||||
// there was another error, it is.
|
||||
c.sentClose = true |
||||
c.writeMu.Unlock() |
||||
// Unblock writers.
|
||||
c.remoteWin.close() |
||||
} |
||||
|
||||
// responseMessageReceived is called when a success or failure message is
|
||||
// received on a channel to check that such a message is reasonable for the
|
||||
// given channel.
|
||||
func (c *channel) responseMessageReceived() error { |
||||
if c.direction == channelInbound { |
||||
return errors.New("ssh: channel response message received on inbound channel") |
||||
} |
||||
if c.decided { |
||||
return errors.New("ssh: duplicate response received for channel") |
||||
} |
||||
c.decided = true |
||||
return nil |
||||
} |
||||
|
||||
func (c *channel) handlePacket(packet []byte) error { |
||||
switch packet[0] { |
||||
case msgChannelData, msgChannelExtendedData: |
||||
return c.handleData(packet) |
||||
case msgChannelClose: |
||||
c.sendMessage(channelCloseMsg{PeersId: c.remoteId}) |
||||
c.mux.chanList.remove(c.localId) |
||||
c.close() |
||||
return nil |
||||
case msgChannelEOF: |
||||
// RFC 4254 is mute on how EOF affects dataExt messages but
|
||||
// it is logical to signal EOF at the same time.
|
||||
c.extPending.eof() |
||||
c.pending.eof() |
||||
return nil |
||||
} |
||||
|
||||
decoded, err := decode(packet) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
switch msg := decoded.(type) { |
||||
case *channelOpenFailureMsg: |
||||
if err := c.responseMessageReceived(); err != nil { |
||||
return err |
||||
} |
||||
c.mux.chanList.remove(msg.PeersId) |
||||
c.msg <- msg |
||||
case *channelOpenConfirmMsg: |
||||
if err := c.responseMessageReceived(); err != nil { |
||||
return err |
||||
} |
||||
if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { |
||||
return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize) |
||||
} |
||||
c.remoteId = msg.MyId |
||||
c.maxRemotePayload = msg.MaxPacketSize |
||||
c.remoteWin.add(msg.MyWindow) |
||||
c.msg <- msg |
||||
case *windowAdjustMsg: |
||||
if !c.remoteWin.add(msg.AdditionalBytes) { |
||||
return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes) |
||||
} |
||||
case *channelRequestMsg: |
||||
req := Request{ |
||||
Type: msg.Request, |
||||
WantReply: msg.WantReply, |
||||
Payload: msg.RequestSpecificData, |
||||
ch: c, |
||||
} |
||||
|
||||
c.incomingRequests <- &req |
||||
default: |
||||
c.msg <- msg |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (m *mux) newChannel(chanType string, direction channelDirection, extraData []byte) *channel { |
||||
ch := &channel{ |
||||
remoteWin: window{Cond: newCond()}, |
||||
myWindow: channelWindowSize, |
||||
pending: newBuffer(), |
||||
extPending: newBuffer(), |
||||
direction: direction, |
||||
incomingRequests: make(chan *Request, 16), |
||||
msg: make(chan interface{}, 16), |
||||
chanType: chanType, |
||||
extraData: extraData, |
||||
mux: m, |
||||
packetPool: make(map[uint32][]byte), |
||||
} |
||||
ch.localId = m.chanList.add(ch) |
||||
return ch |
||||
} |
||||
|
||||
var errUndecided = errors.New("ssh: must Accept or Reject channel") |
||||
var errDecidedAlready = errors.New("ssh: can call Accept or Reject only once") |
||||
|
||||
type extChannel struct { |
||||
code uint32 |
||||
ch *channel |
||||
} |
||||
|
||||
func (e *extChannel) Write(data []byte) (n int, err error) { |
||||
return e.ch.WriteExtended(data, e.code) |
||||
} |
||||
|
||||
func (e *extChannel) Read(data []byte) (n int, err error) { |
||||
return e.ch.ReadExtended(data, e.code) |
||||
} |
||||
|
||||
func (c *channel) Accept() (Channel, <-chan *Request, error) { |
||||
if c.decided { |
||||
return nil, nil, errDecidedAlready |
||||
} |
||||
c.maxIncomingPayload = channelMaxPacket |
||||
confirm := channelOpenConfirmMsg{ |
||||
PeersId: c.remoteId, |
||||
MyId: c.localId, |
||||
MyWindow: c.myWindow, |
||||
MaxPacketSize: c.maxIncomingPayload, |
||||
} |
||||
c.decided = true |
||||
if err := c.sendMessage(confirm); err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
return c, c.incomingRequests, nil |
||||
} |
||||
|
||||
func (ch *channel) Reject(reason RejectionReason, message string) error { |
||||
if ch.decided { |
||||
return errDecidedAlready |
||||
} |
||||
reject := channelOpenFailureMsg{ |
||||
PeersId: ch.remoteId, |
||||
Reason: reason, |
||||
Message: message, |
||||
Language: "en", |
||||
} |
||||
ch.decided = true |
||||
return ch.sendMessage(reject) |
||||
} |
||||
|
||||
func (ch *channel) Read(data []byte) (int, error) { |
||||
if !ch.decided { |
||||
return 0, errUndecided |
||||
} |
||||
return ch.ReadExtended(data, 0) |
||||
} |
||||
|
||||
func (ch *channel) Write(data []byte) (int, error) { |
||||
if !ch.decided { |
||||
return 0, errUndecided |
||||
} |
||||
return ch.WriteExtended(data, 0) |
||||
} |
||||
|
||||
func (ch *channel) CloseWrite() error { |
||||
if !ch.decided { |
||||
return errUndecided |
||||
} |
||||
ch.sentEOF = true |
||||
return ch.sendMessage(channelEOFMsg{ |
||||
PeersId: ch.remoteId}) |
||||
} |
||||
|
||||
func (ch *channel) Close() error { |
||||
if !ch.decided { |
||||
return errUndecided |
||||
} |
||||
|
||||
return ch.sendMessage(channelCloseMsg{ |
||||
PeersId: ch.remoteId}) |
||||
} |
||||
|
||||
// Extended returns an io.ReadWriter that sends and receives data on the given,
|
||||
// SSH extended stream. Such streams are used, for example, for stderr.
|
||||
func (ch *channel) Extended(code uint32) io.ReadWriter { |
||||
if !ch.decided { |
||||
return nil |
||||
} |
||||
return &extChannel{code, ch} |
||||
} |
||||
|
||||
func (ch *channel) Stderr() io.ReadWriter { |
||||
return ch.Extended(1) |
||||
} |
||||
|
||||
func (ch *channel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { |
||||
if !ch.decided { |
||||
return false, errUndecided |
||||
} |
||||
|
||||
if wantReply { |
||||
ch.sentRequestMu.Lock() |
||||
defer ch.sentRequestMu.Unlock() |
||||
} |
||||
|
||||
msg := channelRequestMsg{ |
||||
PeersId: ch.remoteId, |
||||
Request: name, |
||||
WantReply: wantReply, |
||||
RequestSpecificData: payload, |
||||
} |
||||
|
||||
if err := ch.sendMessage(msg); err != nil { |
||||
return false, err |
||||
} |
||||
|
||||
if wantReply { |
||||
m, ok := (<-ch.msg) |
||||
if !ok { |
||||
return false, io.EOF |
||||
} |
||||
switch m.(type) { |
||||
case *channelRequestFailureMsg: |
||||
return false, nil |
||||
case *channelRequestSuccessMsg: |
||||
return true, nil |
||||
default: |
||||
return false, fmt.Errorf("ssh: unexpected response to channel request: %#v", m) |
||||
} |
||||
} |
||||
|
||||
return false, nil |
||||
} |
||||
|
||||
// ackRequest either sends an ack or nack to the channel request.
|
||||
func (ch *channel) ackRequest(ok bool) error { |
||||
if !ch.decided { |
||||
return errUndecided |
||||
} |
||||
|
||||
var msg interface{} |
||||
if !ok { |
||||
msg = channelRequestFailureMsg{ |
||||
PeersId: ch.remoteId, |
||||
} |
||||
} else { |
||||
msg = channelRequestSuccessMsg{ |
||||
PeersId: ch.remoteId, |
||||
} |
||||
} |
||||
return ch.sendMessage(msg) |
||||
} |
||||
|
||||
func (ch *channel) ChannelType() string { |
||||
return ch.chanType |
||||
} |
||||
|
||||
func (ch *channel) ExtraData() []byte { |
||||
return ch.extraData |
||||
} |
@ -0,0 +1,549 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"crypto/aes" |
||||
"crypto/cipher" |
||||
"crypto/rc4" |
||||
"crypto/subtle" |
||||
"encoding/binary" |
||||
"errors" |
||||
"fmt" |
||||
"hash" |
||||
"io" |
||||
"io/ioutil" |
||||
) |
||||
|
||||
const ( |
||||
packetSizeMultiple = 16 // TODO(huin) this should be determined by the cipher.
|
||||
|
||||
// RFC 4253 section 6.1 defines a minimum packet size of 32768 that implementations
|
||||
// MUST be able to process (plus a few more kilobytes for padding and mac). The RFC
|
||||
// indicates implementations SHOULD be able to handle larger packet sizes, but then
|
||||
// waffles on about reasonable limits.
|
||||
//
|
||||
// OpenSSH caps their maxPacket at 256kB so we choose to do
|
||||
// the same. maxPacket is also used to ensure that uint32
|
||||
// length fields do not overflow, so it should remain well
|
||||
// below 4G.
|
||||
maxPacket = 256 * 1024 |
||||
) |
||||
|
||||
// noneCipher implements cipher.Stream and provides no encryption. It is used
|
||||
// by the transport before the first key-exchange.
|
||||
type noneCipher struct{} |
||||
|
||||
func (c noneCipher) XORKeyStream(dst, src []byte) { |
||||
copy(dst, src) |
||||
} |
||||
|
||||
func newAESCTR(key, iv []byte) (cipher.Stream, error) { |
||||
c, err := aes.NewCipher(key) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return cipher.NewCTR(c, iv), nil |
||||
} |
||||
|
||||
func newRC4(key, iv []byte) (cipher.Stream, error) { |
||||
return rc4.NewCipher(key) |
||||
} |
||||
|
||||
type streamCipherMode struct { |
||||
keySize int |
||||
ivSize int |
||||
skip int |
||||
createFunc func(key, iv []byte) (cipher.Stream, error) |
||||
} |
||||
|
||||
func (c *streamCipherMode) createStream(key, iv []byte) (cipher.Stream, error) { |
||||
if len(key) < c.keySize { |
||||
panic("ssh: key length too small for cipher") |
||||
} |
||||
if len(iv) < c.ivSize { |
||||
panic("ssh: iv too small for cipher") |
||||
} |
||||
|
||||
stream, err := c.createFunc(key[:c.keySize], iv[:c.ivSize]) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var streamDump []byte |
||||
if c.skip > 0 { |
||||
streamDump = make([]byte, 512) |
||||
} |
||||
|
||||
for remainingToDump := c.skip; remainingToDump > 0; { |
||||
dumpThisTime := remainingToDump |
||||
if dumpThisTime > len(streamDump) { |
||||
dumpThisTime = len(streamDump) |
||||
} |
||||
stream.XORKeyStream(streamDump[:dumpThisTime], streamDump[:dumpThisTime]) |
||||
remainingToDump -= dumpThisTime |
||||
} |
||||
|
||||
return stream, nil |
||||
} |
||||
|
||||
// cipherModes documents properties of supported ciphers. Ciphers not included
|
||||
// are not supported and will not be negotiated, even if explicitly requested in
|
||||
// ClientConfig.Crypto.Ciphers.
|
||||
var cipherModes = map[string]*streamCipherMode{ |
||||
// Ciphers from RFC4344, which introduced many CTR-based ciphers. Algorithms
|
||||
// are defined in the order specified in the RFC.
|
||||
"aes128-ctr": {16, aes.BlockSize, 0, newAESCTR}, |
||||
"aes192-ctr": {24, aes.BlockSize, 0, newAESCTR}, |
||||
"aes256-ctr": {32, aes.BlockSize, 0, newAESCTR}, |
||||
|
||||
// Ciphers from RFC4345, which introduces security-improved arcfour ciphers.
|
||||
// They are defined in the order specified in the RFC.
|
||||
"arcfour128": {16, 0, 1536, newRC4}, |
||||
"arcfour256": {32, 0, 1536, newRC4}, |
||||
|
||||
// Cipher defined in RFC 4253, which describes SSH Transport Layer Protocol.
|
||||
// Note that this cipher is not safe, as stated in RFC 4253: "Arcfour (and
|
||||
// RC4) has problems with weak keys, and should be used with caution."
|
||||
// RFC4345 introduces improved versions of Arcfour.
|
||||
"arcfour": {16, 0, 0, newRC4}, |
||||
|
||||
// AES-GCM is not a stream cipher, so it is constructed with a
|
||||
// special case. If we add any more non-stream ciphers, we
|
||||
// should invest a cleaner way to do this.
|
||||
gcmCipherID: {16, 12, 0, nil}, |
||||
|
||||
// insecure cipher, see http://www.isg.rhul.ac.uk/~kp/SandPfinal.pdf
|
||||
// uncomment below to enable it.
|
||||
// aes128cbcID: {16, aes.BlockSize, 0, nil},
|
||||
} |
||||
|
||||
// prefixLen is the length of the packet prefix that contains the packet length
|
||||
// and number of padding bytes.
|
||||
const prefixLen = 5 |
||||
|
||||
// streamPacketCipher is a packetCipher using a stream cipher.
|
||||
type streamPacketCipher struct { |
||||
mac hash.Hash |
||||
cipher cipher.Stream |
||||
|
||||
// The following members are to avoid per-packet allocations.
|
||||
prefix [prefixLen]byte |
||||
seqNumBytes [4]byte |
||||
padding [2 * packetSizeMultiple]byte |
||||
packetData []byte |
||||
macResult []byte |
||||
} |
||||
|
||||
// readPacket reads and decrypt a single packet from the reader argument.
|
||||
func (s *streamPacketCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) { |
||||
if _, err := io.ReadFull(r, s.prefix[:]); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
s.cipher.XORKeyStream(s.prefix[:], s.prefix[:]) |
||||
length := binary.BigEndian.Uint32(s.prefix[0:4]) |
||||
paddingLength := uint32(s.prefix[4]) |
||||
|
||||
var macSize uint32 |
||||
if s.mac != nil { |
||||
s.mac.Reset() |
||||
binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum) |
||||
s.mac.Write(s.seqNumBytes[:]) |
||||
s.mac.Write(s.prefix[:]) |
||||
macSize = uint32(s.mac.Size()) |
||||
} |
||||
|
||||
if length <= paddingLength+1 { |
||||
return nil, errors.New("ssh: invalid packet length, packet too small") |
||||
} |
||||
|
||||
if length > maxPacket { |
||||
return nil, errors.New("ssh: invalid packet length, packet too large") |
||||
} |
||||
|
||||
// the maxPacket check above ensures that length-1+macSize
|
||||
// does not overflow.
|
||||
if uint32(cap(s.packetData)) < length-1+macSize { |
||||
s.packetData = make([]byte, length-1+macSize) |
||||
} else { |
||||
s.packetData = s.packetData[:length-1+macSize] |
||||
} |
||||
|
||||
if _, err := io.ReadFull(r, s.packetData); err != nil { |
||||
return nil, err |
||||
} |
||||
mac := s.packetData[length-1:] |
||||
data := s.packetData[:length-1] |
||||
s.cipher.XORKeyStream(data, data) |
||||
|
||||
if s.mac != nil { |
||||
s.mac.Write(data) |
||||
s.macResult = s.mac.Sum(s.macResult[:0]) |
||||
if subtle.ConstantTimeCompare(s.macResult, mac) != 1 { |
||||
return nil, errors.New("ssh: MAC failure") |
||||
} |
||||
} |
||||
|
||||
return s.packetData[:length-paddingLength-1], nil |
||||
} |
||||
|
||||
// writePacket encrypts and sends a packet of data to the writer argument
|
||||
func (s *streamPacketCipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error { |
||||
if len(packet) > maxPacket { |
||||
return errors.New("ssh: packet too large") |
||||
} |
||||
|
||||
paddingLength := packetSizeMultiple - (prefixLen+len(packet))%packetSizeMultiple |
||||
if paddingLength < 4 { |
||||
paddingLength += packetSizeMultiple |
||||
} |
||||
|
||||
length := len(packet) + 1 + paddingLength |
||||
binary.BigEndian.PutUint32(s.prefix[:], uint32(length)) |
||||
s.prefix[4] = byte(paddingLength) |
||||
padding := s.padding[:paddingLength] |
||||
if _, err := io.ReadFull(rand, padding); err != nil { |
||||
return err |
||||
} |
||||
|
||||
if s.mac != nil { |
||||
s.mac.Reset() |
||||
binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum) |
||||
s.mac.Write(s.seqNumBytes[:]) |
||||
s.mac.Write(s.prefix[:]) |
||||
s.mac.Write(packet) |
||||
s.mac.Write(padding) |
||||
} |
||||
|
||||
s.cipher.XORKeyStream(s.prefix[:], s.prefix[:]) |
||||
s.cipher.XORKeyStream(packet, packet) |
||||
s.cipher.XORKeyStream(padding, padding) |
||||
|
||||
if _, err := w.Write(s.prefix[:]); err != nil { |
||||
return err |
||||
} |
||||
if _, err := w.Write(packet); err != nil { |
||||
return err |
||||
} |
||||
if _, err := w.Write(padding); err != nil { |
||||
return err |
||||
} |
||||
|
||||
if s.mac != nil { |
||||
s.macResult = s.mac.Sum(s.macResult[:0]) |
||||
if _, err := w.Write(s.macResult); err != nil { |
||||
return err |
||||
} |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
type gcmCipher struct { |
||||
aead cipher.AEAD |
||||
prefix [4]byte |
||||
iv []byte |
||||
buf []byte |
||||
} |
||||
|
||||
func newGCMCipher(iv, key, macKey []byte) (packetCipher, error) { |
||||
c, err := aes.NewCipher(key) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
aead, err := cipher.NewGCM(c) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return &gcmCipher{ |
||||
aead: aead, |
||||
iv: iv, |
||||
}, nil |
||||
} |
||||
|
||||
const gcmTagSize = 16 |
||||
|
||||
func (c *gcmCipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error { |
||||
// Pad out to multiple of 16 bytes. This is different from the
|
||||
// stream cipher because that encrypts the length too.
|
||||
padding := byte(packetSizeMultiple - (1+len(packet))%packetSizeMultiple) |
||||
if padding < 4 { |
||||
padding += packetSizeMultiple |
||||
} |
||||
|
||||
length := uint32(len(packet) + int(padding) + 1) |
||||
binary.BigEndian.PutUint32(c.prefix[:], length) |
||||
if _, err := w.Write(c.prefix[:]); err != nil { |
||||
return err |
||||
} |
||||
|
||||
if cap(c.buf) < int(length) { |
||||
c.buf = make([]byte, length) |
||||
} else { |
||||
c.buf = c.buf[:length] |
||||
} |
||||
|
||||
c.buf[0] = padding |
||||
copy(c.buf[1:], packet) |
||||
if _, err := io.ReadFull(rand, c.buf[1+len(packet):]); err != nil { |
||||
return err |
||||
} |
||||
c.buf = c.aead.Seal(c.buf[:0], c.iv, c.buf, c.prefix[:]) |
||||
if _, err := w.Write(c.buf); err != nil { |
||||
return err |
||||
} |
||||
c.incIV() |
||||
|
||||
return nil |
||||
} |
||||
|
||||
func (c *gcmCipher) incIV() { |
||||
for i := 4 + 7; i >= 4; i-- { |
||||
c.iv[i]++ |
||||
if c.iv[i] != 0 { |
||||
break |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (c *gcmCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) { |
||||
if _, err := io.ReadFull(r, c.prefix[:]); err != nil { |
||||
return nil, err |
||||
} |
||||
length := binary.BigEndian.Uint32(c.prefix[:]) |
||||
if length > maxPacket { |
||||
return nil, errors.New("ssh: max packet length exceeded.") |
||||
} |
||||
|
||||
if cap(c.buf) < int(length+gcmTagSize) { |
||||
c.buf = make([]byte, length+gcmTagSize) |
||||
} else { |
||||
c.buf = c.buf[:length+gcmTagSize] |
||||
} |
||||
|
||||
if _, err := io.ReadFull(r, c.buf); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
plain, err := c.aead.Open(c.buf[:0], c.iv, c.buf, c.prefix[:]) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
c.incIV() |
||||
|
||||
padding := plain[0] |
||||
if padding < 4 || padding >= 20 { |
||||
return nil, fmt.Errorf("ssh: illegal padding %d", padding) |
||||
} |
||||
|
||||
if int(padding+1) >= len(plain) { |
||||
return nil, fmt.Errorf("ssh: padding %d too large", padding) |
||||
} |
||||
plain = plain[1 : length-uint32(padding)] |
||||
return plain, nil |
||||
} |
||||
|
||||
// cbcCipher implements aes128-cbc cipher defined in RFC 4253 section 6.1
|
||||
type cbcCipher struct { |
||||
mac hash.Hash |
||||
macSize uint32 |
||||
decrypter cipher.BlockMode |
||||
encrypter cipher.BlockMode |
||||
|
||||
// The following members are to avoid per-packet allocations.
|
||||
seqNumBytes [4]byte |
||||
packetData []byte |
||||
macResult []byte |
||||
|
||||
// Amount of data we should still read to hide which
|
||||
// verification error triggered.
|
||||
oracleCamouflage uint32 |
||||
} |
||||
|
||||
func newAESCBCCipher(iv, key, macKey []byte, algs directionAlgorithms) (packetCipher, error) { |
||||
c, err := aes.NewCipher(key) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
cbc := &cbcCipher{ |
||||
mac: macModes[algs.MAC].new(macKey), |
||||
decrypter: cipher.NewCBCDecrypter(c, iv), |
||||
encrypter: cipher.NewCBCEncrypter(c, iv), |
||||
packetData: make([]byte, 1024), |
||||
} |
||||
if cbc.mac != nil { |
||||
cbc.macSize = uint32(cbc.mac.Size()) |
||||
} |
||||
|
||||
return cbc, nil |
||||
} |
||||
|
||||
func maxUInt32(a, b int) uint32 { |
||||
if a > b { |
||||
return uint32(a) |
||||
} |
||||
return uint32(b) |
||||
} |
||||
|
||||
const ( |
||||
cbcMinPacketSizeMultiple = 8 |
||||
cbcMinPacketSize = 16 |
||||
cbcMinPaddingSize = 4 |
||||
) |
||||
|
||||
// cbcError represents a verification error that may leak information.
|
||||
type cbcError string |
||||
|
||||
func (e cbcError) Error() string { return string(e) } |
||||
|
||||
func (c *cbcCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) { |
||||
p, err := c.readPacketLeaky(seqNum, r) |
||||
if err != nil { |
||||
if _, ok := err.(cbcError); ok { |
||||
// Verification error: read a fixed amount of
|
||||
// data, to make distinguishing between
|
||||
// failing MAC and failing length check more
|
||||
// difficult.
|
||||
io.CopyN(ioutil.Discard, r, int64(c.oracleCamouflage)) |
||||
} |
||||
} |
||||
return p, err |
||||
} |
||||
|
||||
func (c *cbcCipher) readPacketLeaky(seqNum uint32, r io.Reader) ([]byte, error) { |
||||
blockSize := c.decrypter.BlockSize() |
||||
|
||||
// Read the header, which will include some of the subsequent data in the
|
||||
// case of block ciphers - this is copied back to the payload later.
|
||||
// How many bytes of payload/padding will be read with this first read.
|
||||
firstBlockLength := uint32((prefixLen + blockSize - 1) / blockSize * blockSize) |
||||
firstBlock := c.packetData[:firstBlockLength] |
||||
if _, err := io.ReadFull(r, firstBlock); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
c.oracleCamouflage = maxPacket + 4 + c.macSize - firstBlockLength |
||||
|
||||
c.decrypter.CryptBlocks(firstBlock, firstBlock) |
||||
length := binary.BigEndian.Uint32(firstBlock[:4]) |
||||
if length > maxPacket { |
||||
return nil, cbcError("ssh: packet too large") |
||||
} |
||||
if length+4 < maxUInt32(cbcMinPacketSize, blockSize) { |
||||
// The minimum size of a packet is 16 (or the cipher block size, whichever
|
||||
// is larger) bytes.
|
||||
return nil, cbcError("ssh: packet too small") |
||||
} |
||||
// The length of the packet (including the length field but not the MAC) must
|
||||
// be a multiple of the block size or 8, whichever is larger.
|
||||
if (length+4)%maxUInt32(cbcMinPacketSizeMultiple, blockSize) != 0 { |
||||
return nil, cbcError("ssh: invalid packet length multiple") |
||||
} |
||||
|
||||
paddingLength := uint32(firstBlock[4]) |
||||
if paddingLength < cbcMinPaddingSize || length <= paddingLength+1 { |
||||
return nil, cbcError("ssh: invalid packet length") |
||||
} |
||||
|
||||
// Positions within the c.packetData buffer:
|
||||
macStart := 4 + length |
||||
paddingStart := macStart - paddingLength |
||||
|
||||
// Entire packet size, starting before length, ending at end of mac.
|
||||
entirePacketSize := macStart + c.macSize |
||||
|
||||
// Ensure c.packetData is large enough for the entire packet data.
|
||||
if uint32(cap(c.packetData)) < entirePacketSize { |
||||
// Still need to upsize and copy, but this should be rare at runtime, only
|
||||
// on upsizing the packetData buffer.
|
||||
c.packetData = make([]byte, entirePacketSize) |
||||
copy(c.packetData, firstBlock) |
||||
} else { |
||||
c.packetData = c.packetData[:entirePacketSize] |
||||
} |
||||
|
||||
if n, err := io.ReadFull(r, c.packetData[firstBlockLength:]); err != nil { |
||||
return nil, err |
||||
} else { |
||||
c.oracleCamouflage -= uint32(n) |
||||
} |
||||
|
||||
remainingCrypted := c.packetData[firstBlockLength:macStart] |
||||
c.decrypter.CryptBlocks(remainingCrypted, remainingCrypted) |
||||
|
||||
mac := c.packetData[macStart:] |
||||
if c.mac != nil { |
||||
c.mac.Reset() |
||||
binary.BigEndian.PutUint32(c.seqNumBytes[:], seqNum) |
||||
c.mac.Write(c.seqNumBytes[:]) |
||||
c.mac.Write(c.packetData[:macStart]) |
||||
c.macResult = c.mac.Sum(c.macResult[:0]) |
||||
if subtle.ConstantTimeCompare(c.macResult, mac) != 1 { |
||||
return nil, cbcError("ssh: MAC failure") |
||||
} |
||||
} |
||||
|
||||
return c.packetData[prefixLen:paddingStart], nil |
||||
} |
||||
|
||||
func (c *cbcCipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error { |
||||
effectiveBlockSize := maxUInt32(cbcMinPacketSizeMultiple, c.encrypter.BlockSize()) |
||||
|
||||
// Length of encrypted portion of the packet (header, payload, padding).
|
||||
// Enforce minimum padding and packet size.
|
||||
encLength := maxUInt32(prefixLen+len(packet)+cbcMinPaddingSize, cbcMinPaddingSize) |
||||
// Enforce block size.
|
||||
encLength = (encLength + effectiveBlockSize - 1) / effectiveBlockSize * effectiveBlockSize |
||||
|
||||
length := encLength - 4 |
||||
paddingLength := int(length) - (1 + len(packet)) |
||||
|
||||
// Overall buffer contains: header, payload, padding, mac.
|
||||
// Space for the MAC is reserved in the capacity but not the slice length.
|
||||
bufferSize := encLength + c.macSize |
||||
if uint32(cap(c.packetData)) < bufferSize { |
||||
c.packetData = make([]byte, encLength, bufferSize) |
||||
} else { |
||||
c.packetData = c.packetData[:encLength] |
||||
} |
||||
|
||||
p := c.packetData |
||||
|
||||
// Packet header.
|
||||
binary.BigEndian.PutUint32(p, length) |
||||
p = p[4:] |
||||
p[0] = byte(paddingLength) |
||||
|
||||
// Payload.
|
||||
p = p[1:] |
||||
copy(p, packet) |
||||
|
||||
// Padding.
|
||||
p = p[len(packet):] |
||||
if _, err := io.ReadFull(rand, p); err != nil { |
||||
return err |
||||
} |
||||
|
||||
if c.mac != nil { |
||||
c.mac.Reset() |
||||
binary.BigEndian.PutUint32(c.seqNumBytes[:], seqNum) |
||||
c.mac.Write(c.seqNumBytes[:]) |
||||
c.mac.Write(c.packetData) |
||||
// The MAC is now appended into the capacity reserved for it earlier.
|
||||
c.packetData = c.mac.Sum(c.packetData) |
||||
} |
||||
|
||||
c.encrypter.CryptBlocks(c.packetData[:encLength], c.packetData[:encLength]) |
||||
|
||||
if _, err := w.Write(c.packetData); err != nil { |
||||
return err |
||||
} |
||||
|
||||
return nil |
||||
} |
@ -0,0 +1,127 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"bytes" |
||||
"crypto" |
||||
"crypto/aes" |
||||
"crypto/rand" |
||||
"testing" |
||||
) |
||||
|
||||
func TestDefaultCiphersExist(t *testing.T) { |
||||
for _, cipherAlgo := range supportedCiphers { |
||||
if _, ok := cipherModes[cipherAlgo]; !ok { |
||||
t.Errorf("default cipher %q is unknown", cipherAlgo) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestPacketCiphers(t *testing.T) { |
||||
// Still test aes128cbc cipher althought it's commented out.
|
||||
cipherModes[aes128cbcID] = &streamCipherMode{16, aes.BlockSize, 0, nil} |
||||
defer delete(cipherModes, aes128cbcID) |
||||
|
||||
for cipher := range cipherModes { |
||||
kr := &kexResult{Hash: crypto.SHA1} |
||||
algs := directionAlgorithms{ |
||||
Cipher: cipher, |
||||
MAC: "hmac-sha1", |
||||
Compression: "none", |
||||
} |
||||
client, err := newPacketCipher(clientKeys, algs, kr) |
||||
if err != nil { |
||||
t.Errorf("newPacketCipher(client, %q): %v", cipher, err) |
||||
continue |
||||
} |
||||
server, err := newPacketCipher(clientKeys, algs, kr) |
||||
if err != nil { |
||||
t.Errorf("newPacketCipher(client, %q): %v", cipher, err) |
||||
continue |
||||
} |
||||
|
||||
want := "bla bla" |
||||
input := []byte(want) |
||||
buf := &bytes.Buffer{} |
||||
if err := client.writePacket(0, buf, rand.Reader, input); err != nil { |
||||
t.Errorf("writePacket(%q): %v", cipher, err) |
||||
continue |
||||
} |
||||
|
||||
packet, err := server.readPacket(0, buf) |
||||
if err != nil { |
||||
t.Errorf("readPacket(%q): %v", cipher, err) |
||||
continue |
||||
} |
||||
|
||||
if string(packet) != want { |
||||
t.Errorf("roundtrip(%q): got %q, want %q", cipher, packet, want) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestCBCOracleCounterMeasure(t *testing.T) { |
||||
cipherModes[aes128cbcID] = &streamCipherMode{16, aes.BlockSize, 0, nil} |
||||
defer delete(cipherModes, aes128cbcID) |
||||
|
||||
kr := &kexResult{Hash: crypto.SHA1} |
||||
algs := directionAlgorithms{ |
||||
Cipher: aes128cbcID, |
||||
MAC: "hmac-sha1", |
||||
Compression: "none", |
||||
} |
||||
client, err := newPacketCipher(clientKeys, algs, kr) |
||||
if err != nil { |
||||
t.Fatalf("newPacketCipher(client): %v", err) |
||||
} |
||||
|
||||
want := "bla bla" |
||||
input := []byte(want) |
||||
buf := &bytes.Buffer{} |
||||
if err := client.writePacket(0, buf, rand.Reader, input); err != nil { |
||||
t.Errorf("writePacket: %v", err) |
||||
} |
||||
|
||||
packetSize := buf.Len() |
||||
buf.Write(make([]byte, 2*maxPacket)) |
||||
|
||||
// We corrupt each byte, but this usually will only test the
|
||||
// 'packet too large' or 'MAC failure' cases.
|
||||
lastRead := -1 |
||||
for i := 0; i < packetSize; i++ { |
||||
server, err := newPacketCipher(clientKeys, algs, kr) |
||||
if err != nil { |
||||
t.Fatalf("newPacketCipher(client): %v", err) |
||||
} |
||||
|
||||
fresh := &bytes.Buffer{} |
||||
fresh.Write(buf.Bytes()) |
||||
fresh.Bytes()[i] ^= 0x01 |
||||
|
||||
before := fresh.Len() |
||||
_, err = server.readPacket(0, fresh) |
||||
if err == nil { |
||||
t.Errorf("corrupt byte %d: readPacket succeeded ", i) |
||||
continue |
||||
} |
||||
if _, ok := err.(cbcError); !ok { |
||||
t.Errorf("corrupt byte %d: got %v (%T), want cbcError", i, err, err) |
||||
continue |
||||
} |
||||
|
||||
after := fresh.Len() |
||||
bytesRead := before - after |
||||
if bytesRead < maxPacket { |
||||
t.Errorf("corrupt byte %d: read %d bytes, want more than %d", i, bytesRead, maxPacket) |
||||
continue |
||||
} |
||||
|
||||
if i > 0 && bytesRead != lastRead { |
||||
t.Errorf("corrupt byte %d: read %d bytes, want %d bytes read", i, bytesRead, lastRead) |
||||
} |
||||
lastRead = bytesRead |
||||
} |
||||
} |
@ -0,0 +1,213 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"errors" |
||||
"fmt" |
||||
"net" |
||||
"sync" |
||||
) |
||||
|
||||
// Client implements a traditional SSH client that supports shells,
|
||||
// subprocesses, port forwarding and tunneled dialing.
|
||||
type Client struct { |
||||
Conn |
||||
|
||||
forwards forwardList // forwarded tcpip connections from the remote side
|
||||
mu sync.Mutex |
||||
channelHandlers map[string]chan NewChannel |
||||
} |
||||
|
||||
// HandleChannelOpen returns a channel on which NewChannel requests
|
||||
// for the given type are sent. If the type already is being handled,
|
||||
// nil is returned. The channel is closed when the connection is closed.
|
||||
func (c *Client) HandleChannelOpen(channelType string) <-chan NewChannel { |
||||
c.mu.Lock() |
||||
defer c.mu.Unlock() |
||||
if c.channelHandlers == nil { |
||||
// The SSH channel has been closed.
|
||||
c := make(chan NewChannel) |
||||
close(c) |
||||
return c |
||||
} |
||||
|
||||
ch := c.channelHandlers[channelType] |
||||
if ch != nil { |
||||
return nil |
||||
} |
||||
|
||||
ch = make(chan NewChannel, 16) |
||||
c.channelHandlers[channelType] = ch |
||||
return ch |
||||
} |
||||
|
||||
// NewClient creates a Client on top of the given connection.
|
||||
func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client { |
||||
conn := &Client{ |
||||
Conn: c, |
||||
channelHandlers: make(map[string]chan NewChannel, 1), |
||||
} |
||||
|
||||
go conn.handleGlobalRequests(reqs) |
||||
go conn.handleChannelOpens(chans) |
||||
go func() { |
||||
conn.Wait() |
||||
conn.forwards.closeAll() |
||||
}() |
||||
go conn.forwards.handleChannels(conn.HandleChannelOpen("forwarded-tcpip")) |
||||
return conn |
||||
} |
||||
|
||||
// NewClientConn establishes an authenticated SSH connection using c
|
||||
// as the underlying transport. The Request and NewChannel channels
|
||||
// must be serviced or the connection will hang.
|
||||
func NewClientConn(c net.Conn, addr string, config *ClientConfig) (Conn, <-chan NewChannel, <-chan *Request, error) { |
||||
fullConf := *config |
||||
fullConf.SetDefaults() |
||||
conn := &connection{ |
||||
sshConn: sshConn{conn: c}, |
||||
} |
||||
|
||||
if err := conn.clientHandshake(addr, &fullConf); err != nil { |
||||
c.Close() |
||||
return nil, nil, nil, fmt.Errorf("ssh: handshake failed: %v", err) |
||||
} |
||||
conn.mux = newMux(conn.transport) |
||||
return conn, conn.mux.incomingChannels, conn.mux.incomingRequests, nil |
||||
} |
||||
|
||||
// clientHandshake performs the client side key exchange. See RFC 4253 Section
|
||||
// 7.
|
||||
func (c *connection) clientHandshake(dialAddress string, config *ClientConfig) error { |
||||
if config.ClientVersion != "" { |
||||
c.clientVersion = []byte(config.ClientVersion) |
||||
} else { |
||||
c.clientVersion = []byte(packageVersion) |
||||
} |
||||
var err error |
||||
c.serverVersion, err = exchangeVersions(c.sshConn.conn, c.clientVersion) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
c.transport = newClientTransport( |
||||
newTransport(c.sshConn.conn, config.Rand, true /* is client */), |
||||
c.clientVersion, c.serverVersion, config, dialAddress, c.sshConn.RemoteAddr()) |
||||
if err := c.transport.requestKeyChange(); err != nil { |
||||
return err |
||||
} |
||||
|
||||
if packet, err := c.transport.readPacket(); err != nil { |
||||
return err |
||||
} else if packet[0] != msgNewKeys { |
||||
return unexpectedMessageError(msgNewKeys, packet[0]) |
||||
} |
||||
|
||||
// We just did the key change, so the session ID is established.
|
||||
c.sessionID = c.transport.getSessionID() |
||||
|
||||
return c.clientAuthenticate(config) |
||||
} |
||||
|
||||
// verifyHostKeySignature verifies the host key obtained in the key
|
||||
// exchange.
|
||||
func verifyHostKeySignature(hostKey PublicKey, result *kexResult) error { |
||||
sig, rest, ok := parseSignatureBody(result.Signature) |
||||
if len(rest) > 0 || !ok { |
||||
return errors.New("ssh: signature parse error") |
||||
} |
||||
|
||||
return hostKey.Verify(result.H, sig) |
||||
} |
||||
|
||||
// NewSession opens a new Session for this client. (A session is a remote
|
||||
// execution of a program.)
|
||||
func (c *Client) NewSession() (*Session, error) { |
||||
ch, in, err := c.OpenChannel("session", nil) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return newSession(ch, in) |
||||
} |
||||
|
||||
func (c *Client) handleGlobalRequests(incoming <-chan *Request) { |
||||
for r := range incoming { |
||||
// This handles keepalive messages and matches
|
||||
// the behaviour of OpenSSH.
|
||||
r.Reply(false, nil) |
||||
} |
||||
} |
||||
|
||||
// handleChannelOpens channel open messages from the remote side.
|
||||
func (c *Client) handleChannelOpens(in <-chan NewChannel) { |
||||
for ch := range in { |
||||
c.mu.Lock() |
||||
handler := c.channelHandlers[ch.ChannelType()] |
||||
c.mu.Unlock() |
||||
|
||||
if handler != nil { |
||||
handler <- ch |
||||
} else { |
||||
ch.Reject(UnknownChannelType, fmt.Sprintf("unknown channel type: %v", ch.ChannelType())) |
||||
} |
||||
} |
||||
|
||||
c.mu.Lock() |
||||
for _, ch := range c.channelHandlers { |
||||
close(ch) |
||||
} |
||||
c.channelHandlers = nil |
||||
c.mu.Unlock() |
||||
} |
||||
|
||||
// Dial starts a client connection to the given SSH server. It is a
|
||||
// convenience function that connects to the given network address,
|
||||
// initiates the SSH handshake, and then sets up a Client. For access
|
||||
// to incoming channels and requests, use net.Dial with NewClientConn
|
||||
// instead.
|
||||
func Dial(network, addr string, config *ClientConfig) (*Client, error) { |
||||
conn, err := net.Dial(network, addr) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
c, chans, reqs, err := NewClientConn(conn, addr, config) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return NewClient(c, chans, reqs), nil |
||||
} |
||||
|
||||
// A ClientConfig structure is used to configure a Client. It must not be
|
||||
// modified after having been passed to an SSH function.
|
||||
type ClientConfig struct { |
||||
// Config contains configuration that is shared between clients and
|
||||
// servers.
|
||||
Config |
||||
|
||||
// User contains the username to authenticate as.
|
||||
User string |
||||
|
||||
// Auth contains possible authentication methods to use with the
|
||||
// server. Only the first instance of a particular RFC 4252 method will
|
||||
// be used during authentication.
|
||||
Auth []AuthMethod |
||||
|
||||
// HostKeyCallback, if not nil, is called during the cryptographic
|
||||
// handshake to validate the server's host key. A nil HostKeyCallback
|
||||
// implies that all host keys are accepted.
|
||||
HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error |
||||
|
||||
// ClientVersion contains the version identification string that will
|
||||
// be used for the connection. If empty, a reasonable default is used.
|
||||
ClientVersion string |
||||
|
||||
// HostKeyAlgorithms lists the key types that the client will
|
||||
// accept from the server as host key, in order of
|
||||
// preference. If empty, a reasonable default is used. Any
|
||||
// string returned from PublicKey.Type method may be used, or
|
||||
// any of the CertAlgoXxxx and KeyAlgoXxxx constants.
|
||||
HostKeyAlgorithms []string |
||||
} |
@ -0,0 +1,441 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"bytes" |
||||
"errors" |
||||
"fmt" |
||||
"io" |
||||
) |
||||
|
||||
// clientAuthenticate authenticates with the remote server. See RFC 4252.
|
||||
func (c *connection) clientAuthenticate(config *ClientConfig) error { |
||||
// initiate user auth session
|
||||
if err := c.transport.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})); err != nil { |
||||
return err |
||||
} |
||||
packet, err := c.transport.readPacket() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
var serviceAccept serviceAcceptMsg |
||||
if err := Unmarshal(packet, &serviceAccept); err != nil { |
||||
return err |
||||
} |
||||
|
||||
// during the authentication phase the client first attempts the "none" method
|
||||
// then any untried methods suggested by the server.
|
||||
tried := make(map[string]bool) |
||||
var lastMethods []string |
||||
for auth := AuthMethod(new(noneAuth)); auth != nil; { |
||||
ok, methods, err := auth.auth(c.transport.getSessionID(), config.User, c.transport, config.Rand) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if ok { |
||||
// success
|
||||
return nil |
||||
} |
||||
tried[auth.method()] = true |
||||
if methods == nil { |
||||
methods = lastMethods |
||||
} |
||||
lastMethods = methods |
||||
|
||||
auth = nil |
||||
|
||||
findNext: |
||||
for _, a := range config.Auth { |
||||
candidateMethod := a.method() |
||||
if tried[candidateMethod] { |
||||
continue |
||||
} |
||||
for _, meth := range methods { |
||||
if meth == candidateMethod { |
||||
auth = a |
||||
break findNext |
||||
} |
||||
} |
||||
} |
||||
} |
||||
return fmt.Errorf("ssh: unable to authenticate, attempted methods %v, no supported methods remain", keys(tried)) |
||||
} |
||||
|
||||
func keys(m map[string]bool) []string { |
||||
s := make([]string, 0, len(m)) |
||||
|
||||
for key := range m { |
||||
s = append(s, key) |
||||
} |
||||
return s |
||||
} |
||||
|
||||
// An AuthMethod represents an instance of an RFC 4252 authentication method.
|
||||
type AuthMethod interface { |
||||
// auth authenticates user over transport t.
|
||||
// Returns true if authentication is successful.
|
||||
// If authentication is not successful, a []string of alternative
|
||||
// method names is returned. If the slice is nil, it will be ignored
|
||||
// and the previous set of possible methods will be reused.
|
||||
auth(session []byte, user string, p packetConn, rand io.Reader) (bool, []string, error) |
||||
|
||||
// method returns the RFC 4252 method name.
|
||||
method() string |
||||
} |
||||
|
||||
// "none" authentication, RFC 4252 section 5.2.
|
||||
type noneAuth int |
||||
|
||||
func (n *noneAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { |
||||
if err := c.writePacket(Marshal(&userAuthRequestMsg{ |
||||
User: user, |
||||
Service: serviceSSH, |
||||
Method: "none", |
||||
})); err != nil { |
||||
return false, nil, err |
||||
} |
||||
|
||||
return handleAuthResponse(c) |
||||
} |
||||
|
||||
func (n *noneAuth) method() string { |
||||
return "none" |
||||
} |
||||
|
||||
// passwordCallback is an AuthMethod that fetches the password through
|
||||
// a function call, e.g. by prompting the user.
|
||||
type passwordCallback func() (password string, err error) |
||||
|
||||
func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { |
||||
type passwordAuthMsg struct { |
||||
User string `sshtype:"50"` |
||||
Service string |
||||
Method string |
||||
Reply bool |
||||
Password string |
||||
} |
||||
|
||||
pw, err := cb() |
||||
// REVIEW NOTE: is there a need to support skipping a password attempt?
|
||||
// The program may only find out that the user doesn't have a password
|
||||
// when prompting.
|
||||
if err != nil { |
||||
return false, nil, err |
||||
} |
||||
|
||||
if err := c.writePacket(Marshal(&passwordAuthMsg{ |
||||
User: user, |
||||
Service: serviceSSH, |
||||
Method: cb.method(), |
||||
Reply: false, |
||||
Password: pw, |
||||
})); err != nil { |
||||
return false, nil, err |
||||
} |
||||
|
||||
return handleAuthResponse(c) |
||||
} |
||||
|
||||
func (cb passwordCallback) method() string { |
||||
return "password" |
||||
} |
||||
|
||||
// Password returns an AuthMethod using the given password.
|
||||
func Password(secret string) AuthMethod { |
||||
return passwordCallback(func() (string, error) { return secret, nil }) |
||||
} |
||||
|
||||
// PasswordCallback returns an AuthMethod that uses a callback for
|
||||
// fetching a password.
|
||||
func PasswordCallback(prompt func() (secret string, err error)) AuthMethod { |
||||
return passwordCallback(prompt) |
||||
} |
||||
|
||||
type publickeyAuthMsg struct { |
||||
User string `sshtype:"50"` |
||||
Service string |
||||
Method string |
||||
// HasSig indicates to the receiver packet that the auth request is signed and
|
||||
// should be used for authentication of the request.
|
||||
HasSig bool |
||||
Algoname string |
||||
PubKey []byte |
||||
// Sig is tagged with "rest" so Marshal will exclude it during
|
||||
// validateKey
|
||||
Sig []byte `ssh:"rest"` |
||||
} |
||||
|
||||
// publicKeyCallback is an AuthMethod that uses a set of key
|
||||
// pairs for authentication.
|
||||
type publicKeyCallback func() ([]Signer, error) |
||||
|
||||
func (cb publicKeyCallback) method() string { |
||||
return "publickey" |
||||
} |
||||
|
||||
func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { |
||||
// Authentication is performed in two stages. The first stage sends an
|
||||
// enquiry to test if each key is acceptable to the remote. The second
|
||||
// stage attempts to authenticate with the valid keys obtained in the
|
||||
// first stage.
|
||||
|
||||
signers, err := cb() |
||||
if err != nil { |
||||
return false, nil, err |
||||
} |
||||
var validKeys []Signer |
||||
for _, signer := range signers { |
||||
if ok, err := validateKey(signer.PublicKey(), user, c); ok { |
||||
validKeys = append(validKeys, signer) |
||||
} else { |
||||
if err != nil { |
||||
return false, nil, err |
||||
} |
||||
} |
||||
} |
||||
|
||||
// methods that may continue if this auth is not successful.
|
||||
var methods []string |
||||
for _, signer := range validKeys { |
||||
pub := signer.PublicKey() |
||||
|
||||
pubKey := pub.Marshal() |
||||
sign, err := signer.Sign(rand, buildDataSignedForAuth(session, userAuthRequestMsg{ |
||||
User: user, |
||||
Service: serviceSSH, |
||||
Method: cb.method(), |
||||
}, []byte(pub.Type()), pubKey)) |
||||
if err != nil { |
||||
return false, nil, err |
||||
} |
||||
|
||||
// manually wrap the serialized signature in a string
|
||||
s := Marshal(sign) |
||||
sig := make([]byte, stringLength(len(s))) |
||||
marshalString(sig, s) |
||||
msg := publickeyAuthMsg{ |
||||
User: user, |
||||
Service: serviceSSH, |
||||
Method: cb.method(), |
||||
HasSig: true, |
||||
Algoname: pub.Type(), |
||||
PubKey: pubKey, |
||||
Sig: sig, |
||||
} |
||||
p := Marshal(&msg) |
||||
if err := c.writePacket(p); err != nil { |
||||
return false, nil, err |
||||
} |
||||
var success bool |
||||
success, methods, err = handleAuthResponse(c) |
||||
if err != nil { |
||||
return false, nil, err |
||||
} |
||||
if success { |
||||
return success, methods, err |
||||
} |
||||
} |
||||
return false, methods, nil |
||||
} |
||||
|
||||
// validateKey validates the key provided is acceptable to the server.
|
||||
func validateKey(key PublicKey, user string, c packetConn) (bool, error) { |
||||
pubKey := key.Marshal() |
||||
msg := publickeyAuthMsg{ |
||||
User: user, |
||||
Service: serviceSSH, |
||||
Method: "publickey", |
||||
HasSig: false, |
||||
Algoname: key.Type(), |
||||
PubKey: pubKey, |
||||
} |
||||
if err := c.writePacket(Marshal(&msg)); err != nil { |
||||
return false, err |
||||
} |
||||
|
||||
return confirmKeyAck(key, c) |
||||
} |
||||
|
||||
func confirmKeyAck(key PublicKey, c packetConn) (bool, error) { |
||||
pubKey := key.Marshal() |
||||
algoname := key.Type() |
||||
|
||||
for { |
||||
packet, err := c.readPacket() |
||||
if err != nil { |
||||
return false, err |
||||
} |
||||
switch packet[0] { |
||||
case msgUserAuthBanner: |
||||
// TODO(gpaul): add callback to present the banner to the user
|
||||
case msgUserAuthPubKeyOk: |
||||
var msg userAuthPubKeyOkMsg |
||||
if err := Unmarshal(packet, &msg); err != nil { |
||||
return false, err |
||||
} |
||||
if msg.Algo != algoname || !bytes.Equal(msg.PubKey, pubKey) { |
||||
return false, nil |
||||
} |
||||
return true, nil |
||||
case msgUserAuthFailure: |
||||
return false, nil |
||||
default: |
||||
return false, unexpectedMessageError(msgUserAuthSuccess, packet[0]) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// PublicKeys returns an AuthMethod that uses the given key
|
||||
// pairs.
|
||||
func PublicKeys(signers ...Signer) AuthMethod { |
||||
return publicKeyCallback(func() ([]Signer, error) { return signers, nil }) |
||||
} |
||||
|
||||
// PublicKeysCallback returns an AuthMethod that runs the given
|
||||
// function to obtain a list of key pairs.
|
||||
func PublicKeysCallback(getSigners func() (signers []Signer, err error)) AuthMethod { |
||||
return publicKeyCallback(getSigners) |
||||
} |
||||
|
||||
// handleAuthResponse returns whether the preceding authentication request succeeded
|
||||
// along with a list of remaining authentication methods to try next and
|
||||
// an error if an unexpected response was received.
|
||||
func handleAuthResponse(c packetConn) (bool, []string, error) { |
||||
for { |
||||
packet, err := c.readPacket() |
||||
if err != nil { |
||||
return false, nil, err |
||||
} |
||||
|
||||
switch packet[0] { |
||||
case msgUserAuthBanner: |
||||
// TODO: add callback to present the banner to the user
|
||||
case msgUserAuthFailure: |
||||
var msg userAuthFailureMsg |
||||
if err := Unmarshal(packet, &msg); err != nil { |
||||
return false, nil, err |
||||
} |
||||
return false, msg.Methods, nil |
||||
case msgUserAuthSuccess: |
||||
return true, nil, nil |
||||
case msgDisconnect: |
||||
return false, nil, io.EOF |
||||
default: |
||||
return false, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0]) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// KeyboardInteractiveChallenge should print questions, optionally
|
||||
// disabling echoing (e.g. for passwords), and return all the answers.
|
||||
// Challenge may be called multiple times in a single session. After
|
||||
// successful authentication, the server may send a challenge with no
|
||||
// questions, for which the user and instruction messages should be
|
||||
// printed. RFC 4256 section 3.3 details how the UI should behave for
|
||||
// both CLI and GUI environments.
|
||||
type KeyboardInteractiveChallenge func(user, instruction string, questions []string, echos []bool) (answers []string, err error) |
||||
|
||||
// KeyboardInteractive returns a AuthMethod using a prompt/response
|
||||
// sequence controlled by the server.
|
||||
func KeyboardInteractive(challenge KeyboardInteractiveChallenge) AuthMethod { |
||||
return challenge |
||||
} |
||||
|
||||
func (cb KeyboardInteractiveChallenge) method() string { |
||||
return "keyboard-interactive" |
||||
} |
||||
|
||||
func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { |
||||
type initiateMsg struct { |
||||
User string `sshtype:"50"` |
||||
Service string |
||||
Method string |
||||
Language string |
||||
Submethods string |
||||
} |
||||
|
||||
if err := c.writePacket(Marshal(&initiateMsg{ |
||||
User: user, |
||||
Service: serviceSSH, |
||||
Method: "keyboard-interactive", |
||||
})); err != nil { |
||||
return false, nil, err |
||||
} |
||||
|
||||
for { |
||||
packet, err := c.readPacket() |
||||
if err != nil { |
||||
return false, nil, err |
||||
} |
||||
|
||||
// like handleAuthResponse, but with less options.
|
||||
switch packet[0] { |
||||
case msgUserAuthBanner: |
||||
// TODO: Print banners during userauth.
|
||||
continue |
||||
case msgUserAuthInfoRequest: |
||||
// OK
|
||||
case msgUserAuthFailure: |
||||
var msg userAuthFailureMsg |
||||
if err := Unmarshal(packet, &msg); err != nil { |
||||
return false, nil, err |
||||
} |
||||
return false, msg.Methods, nil |
||||
case msgUserAuthSuccess: |
||||
return true, nil, nil |
||||
default: |
||||
return false, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0]) |
||||
} |
||||
|
||||
var msg userAuthInfoRequestMsg |
||||
if err := Unmarshal(packet, &msg); err != nil { |
||||
return false, nil, err |
||||
} |
||||
|
||||
// Manually unpack the prompt/echo pairs.
|
||||
rest := msg.Prompts |
||||
var prompts []string |
||||
var echos []bool |
||||
for i := 0; i < int(msg.NumPrompts); i++ { |
||||
prompt, r, ok := parseString(rest) |
||||
if !ok || len(r) == 0 { |
||||
return false, nil, errors.New("ssh: prompt format error") |
||||
} |
||||
prompts = append(prompts, string(prompt)) |
||||
echos = append(echos, r[0] != 0) |
||||
rest = r[1:] |
||||
} |
||||
|
||||
if len(rest) != 0 { |
||||
return false, nil, errors.New("ssh: extra data following keyboard-interactive pairs") |
||||
} |
||||
|
||||
answers, err := cb(msg.User, msg.Instruction, prompts, echos) |
||||
if err != nil { |
||||
return false, nil, err |
||||
} |
||||
|
||||
if len(answers) != len(prompts) { |
||||
return false, nil, errors.New("ssh: not enough answers from keyboard-interactive callback") |
||||
} |
||||
responseLength := 1 + 4 |
||||
for _, a := range answers { |
||||
responseLength += stringLength(len(a)) |
||||
} |
||||
serialized := make([]byte, responseLength) |
||||
p := serialized |
||||
p[0] = msgUserAuthInfoResponse |
||||
p = p[1:] |
||||
p = marshalUint32(p, uint32(len(answers))) |
||||
for _, a := range answers { |
||||
p = marshalString(p, []byte(a)) |
||||
} |
||||
|
||||
if err := c.writePacket(serialized); err != nil { |
||||
return false, nil, err |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,393 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"bytes" |
||||
"crypto/rand" |
||||
"errors" |
||||
"fmt" |
||||
"strings" |
||||
"testing" |
||||
) |
||||
|
||||
type keyboardInteractive map[string]string |
||||
|
||||
func (cr keyboardInteractive) Challenge(user string, instruction string, questions []string, echos []bool) ([]string, error) { |
||||
var answers []string |
||||
for _, q := range questions { |
||||
answers = append(answers, cr[q]) |
||||
} |
||||
return answers, nil |
||||
} |
||||
|
||||
// reused internally by tests
|
||||
var clientPassword = "tiger" |
||||
|
||||
// tryAuth runs a handshake with a given config against an SSH server
|
||||
// with config serverConfig
|
||||
func tryAuth(t *testing.T, config *ClientConfig) error { |
||||
c1, c2, err := netPipe() |
||||
if err != nil { |
||||
t.Fatalf("netPipe: %v", err) |
||||
} |
||||
defer c1.Close() |
||||
defer c2.Close() |
||||
|
||||
certChecker := CertChecker{ |
||||
IsAuthority: func(k PublicKey) bool { |
||||
return bytes.Equal(k.Marshal(), testPublicKeys["ecdsa"].Marshal()) |
||||
}, |
||||
UserKeyFallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { |
||||
if conn.User() == "testuser" && bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) { |
||||
return nil, nil |
||||
} |
||||
|
||||
return nil, fmt.Errorf("pubkey for %q not acceptable", conn.User()) |
||||
}, |
||||
IsRevoked: func(c *Certificate) bool { |
||||
return c.Serial == 666 |
||||
}, |
||||
} |
||||
|
||||
serverConfig := &ServerConfig{ |
||||
PasswordCallback: func(conn ConnMetadata, pass []byte) (*Permissions, error) { |
||||
if conn.User() == "testuser" && string(pass) == clientPassword { |
||||
return nil, nil |
||||
} |
||||
return nil, errors.New("password auth failed") |
||||
}, |
||||
PublicKeyCallback: certChecker.Authenticate, |
||||
KeyboardInteractiveCallback: func(conn ConnMetadata, challenge KeyboardInteractiveChallenge) (*Permissions, error) { |
||||
ans, err := challenge("user", |
||||
"instruction", |
||||
[]string{"question1", "question2"}, |
||||
[]bool{true, true}) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
ok := conn.User() == "testuser" && ans[0] == "answer1" && ans[1] == "answer2" |
||||
if ok { |
||||
challenge("user", "motd", nil, nil) |
||||
return nil, nil |
||||
} |
||||
return nil, errors.New("keyboard-interactive failed") |
||||
}, |
||||
AuthLogCallback: func(conn ConnMetadata, method string, err error) { |
||||
t.Logf("user %q, method %q: %v", conn.User(), method, err) |
||||
}, |
||||
} |
||||
serverConfig.AddHostKey(testSigners["rsa"]) |
||||
|
||||
go newServer(c1, serverConfig) |
||||
_, _, _, err = NewClientConn(c2, "", config) |
||||
return err |
||||
} |
||||
|
||||
func TestClientAuthPublicKey(t *testing.T) { |
||||
config := &ClientConfig{ |
||||
User: "testuser", |
||||
Auth: []AuthMethod{ |
||||
PublicKeys(testSigners["rsa"]), |
||||
}, |
||||
} |
||||
if err := tryAuth(t, config); err != nil { |
||||
t.Fatalf("unable to dial remote side: %s", err) |
||||
} |
||||
} |
||||
|
||||
func TestAuthMethodPassword(t *testing.T) { |
||||
config := &ClientConfig{ |
||||
User: "testuser", |
||||
Auth: []AuthMethod{ |
||||
Password(clientPassword), |
||||
}, |
||||
} |
||||
|
||||
if err := tryAuth(t, config); err != nil { |
||||
t.Fatalf("unable to dial remote side: %s", err) |
||||
} |
||||
} |
||||
|
||||
func TestAuthMethodFallback(t *testing.T) { |
||||
var passwordCalled bool |
||||
config := &ClientConfig{ |
||||
User: "testuser", |
||||
Auth: []AuthMethod{ |
||||
PublicKeys(testSigners["rsa"]), |
||||
PasswordCallback( |
||||
func() (string, error) { |
||||
passwordCalled = true |
||||
return "WRONG", nil |
||||
}), |
||||
}, |
||||
} |
||||
|
||||
if err := tryAuth(t, config); err != nil { |
||||
t.Fatalf("unable to dial remote side: %s", err) |
||||
} |
||||
|
||||
if passwordCalled { |
||||
t.Errorf("password auth tried before public-key auth.") |
||||
} |
||||
} |
||||
|
||||
func TestAuthMethodWrongPassword(t *testing.T) { |
||||
config := &ClientConfig{ |
||||
User: "testuser", |
||||
Auth: []AuthMethod{ |
||||
Password("wrong"), |
||||
PublicKeys(testSigners["rsa"]), |
||||
}, |
||||
} |
||||
|
||||
if err := tryAuth(t, config); err != nil { |
||||
t.Fatalf("unable to dial remote side: %s", err) |
||||
} |
||||
} |
||||
|
||||
func TestAuthMethodKeyboardInteractive(t *testing.T) { |
||||
answers := keyboardInteractive(map[string]string{ |
||||
"question1": "answer1", |
||||
"question2": "answer2", |
||||
}) |
||||
config := &ClientConfig{ |
||||
User: "testuser", |
||||
Auth: []AuthMethod{ |
||||
KeyboardInteractive(answers.Challenge), |
||||
}, |
||||
} |
||||
|
||||
if err := tryAuth(t, config); err != nil { |
||||
t.Fatalf("unable to dial remote side: %s", err) |
||||
} |
||||
} |
||||
|
||||
func TestAuthMethodWrongKeyboardInteractive(t *testing.T) { |
||||
answers := keyboardInteractive(map[string]string{ |
||||
"question1": "answer1", |
||||
"question2": "WRONG", |
||||
}) |
||||
config := &ClientConfig{ |
||||
User: "testuser", |
||||
Auth: []AuthMethod{ |
||||
KeyboardInteractive(answers.Challenge), |
||||
}, |
||||
} |
||||
|
||||
if err := tryAuth(t, config); err == nil { |
||||
t.Fatalf("wrong answers should not have authenticated with KeyboardInteractive") |
||||
} |
||||
} |
||||
|
||||
// the mock server will only authenticate ssh-rsa keys
|
||||
func TestAuthMethodInvalidPublicKey(t *testing.T) { |
||||
config := &ClientConfig{ |
||||
User: "testuser", |
||||
Auth: []AuthMethod{ |
||||
PublicKeys(testSigners["dsa"]), |
||||
}, |
||||
} |
||||
|
||||
if err := tryAuth(t, config); err == nil { |
||||
t.Fatalf("dsa private key should not have authenticated with rsa public key") |
||||
} |
||||
} |
||||
|
||||
// the client should authenticate with the second key
|
||||
func TestAuthMethodRSAandDSA(t *testing.T) { |
||||
config := &ClientConfig{ |
||||
User: "testuser", |
||||
Auth: []AuthMethod{ |
||||
PublicKeys(testSigners["dsa"], testSigners["rsa"]), |
||||
}, |
||||
} |
||||
if err := tryAuth(t, config); err != nil { |
||||
t.Fatalf("client could not authenticate with rsa key: %v", err) |
||||
} |
||||
} |
||||
|
||||
func TestClientHMAC(t *testing.T) { |
||||
for _, mac := range supportedMACs { |
||||
config := &ClientConfig{ |
||||
User: "testuser", |
||||
Auth: []AuthMethod{ |
||||
PublicKeys(testSigners["rsa"]), |
||||
}, |
||||
Config: Config{ |
||||
MACs: []string{mac}, |
||||
}, |
||||
} |
||||
if err := tryAuth(t, config); err != nil { |
||||
t.Fatalf("client could not authenticate with mac algo %s: %v", mac, err) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// issue 4285.
|
||||
func TestClientUnsupportedCipher(t *testing.T) { |
||||
config := &ClientConfig{ |
||||
User: "testuser", |
||||
Auth: []AuthMethod{ |
||||
PublicKeys(), |
||||
}, |
||||
Config: Config{ |
||||
Ciphers: []string{"aes128-cbc"}, // not currently supported
|
||||
}, |
||||
} |
||||
if err := tryAuth(t, config); err == nil { |
||||
t.Errorf("expected no ciphers in common") |
||||
} |
||||
} |
||||
|
||||
func TestClientUnsupportedKex(t *testing.T) { |
||||
config := &ClientConfig{ |
||||
User: "testuser", |
||||
Auth: []AuthMethod{ |
||||
PublicKeys(), |
||||
}, |
||||
Config: Config{ |
||||
KeyExchanges: []string{"diffie-hellman-group-exchange-sha256"}, // not currently supported
|
||||
}, |
||||
} |
||||
if err := tryAuth(t, config); err == nil || !strings.Contains(err.Error(), "common algorithm") { |
||||
t.Errorf("got %v, expected 'common algorithm'", err) |
||||
} |
||||
} |
||||
|
||||
func TestClientLoginCert(t *testing.T) { |
||||
cert := &Certificate{ |
||||
Key: testPublicKeys["rsa"], |
||||
ValidBefore: CertTimeInfinity, |
||||
CertType: UserCert, |
||||
} |
||||
cert.SignCert(rand.Reader, testSigners["ecdsa"]) |
||||
certSigner, err := NewCertSigner(cert, testSigners["rsa"]) |
||||
if err != nil { |
||||
t.Fatalf("NewCertSigner: %v", err) |
||||
} |
||||
|
||||
clientConfig := &ClientConfig{ |
||||
User: "user", |
||||
} |
||||
clientConfig.Auth = append(clientConfig.Auth, PublicKeys(certSigner)) |
||||
|
||||
t.Log("should succeed") |
||||
if err := tryAuth(t, clientConfig); err != nil { |
||||
t.Errorf("cert login failed: %v", err) |
||||
} |
||||
|
||||
t.Log("corrupted signature") |
||||
cert.Signature.Blob[0]++ |
||||
if err := tryAuth(t, clientConfig); err == nil { |
||||
t.Errorf("cert login passed with corrupted sig") |
||||
} |
||||
|
||||
t.Log("revoked") |
||||
cert.Serial = 666 |
||||
cert.SignCert(rand.Reader, testSigners["ecdsa"]) |
||||
if err := tryAuth(t, clientConfig); err == nil { |
||||
t.Errorf("revoked cert login succeeded") |
||||
} |
||||
cert.Serial = 1 |
||||
|
||||
t.Log("sign with wrong key") |
||||
cert.SignCert(rand.Reader, testSigners["dsa"]) |
||||
if err := tryAuth(t, clientConfig); err == nil { |
||||
t.Errorf("cert login passed with non-authoritive key") |
||||
} |
||||
|
||||
t.Log("host cert") |
||||
cert.CertType = HostCert |
||||
cert.SignCert(rand.Reader, testSigners["ecdsa"]) |
||||
if err := tryAuth(t, clientConfig); err == nil { |
||||
t.Errorf("cert login passed with wrong type") |
||||
} |
||||
cert.CertType = UserCert |
||||
|
||||
t.Log("principal specified") |
||||
cert.ValidPrincipals = []string{"user"} |
||||
cert.SignCert(rand.Reader, testSigners["ecdsa"]) |
||||
if err := tryAuth(t, clientConfig); err != nil { |
||||
t.Errorf("cert login failed: %v", err) |
||||
} |
||||
|
||||
t.Log("wrong principal specified") |
||||
cert.ValidPrincipals = []string{"fred"} |
||||
cert.SignCert(rand.Reader, testSigners["ecdsa"]) |
||||
if err := tryAuth(t, clientConfig); err == nil { |
||||
t.Errorf("cert login passed with wrong principal") |
||||
} |
||||
cert.ValidPrincipals = nil |
||||
|
||||
t.Log("added critical option") |
||||
cert.CriticalOptions = map[string]string{"root-access": "yes"} |
||||
cert.SignCert(rand.Reader, testSigners["ecdsa"]) |
||||
if err := tryAuth(t, clientConfig); err == nil { |
||||
t.Errorf("cert login passed with unrecognized critical option") |
||||
} |
||||
|
||||
t.Log("allowed source address") |
||||
cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42/24"} |
||||
cert.SignCert(rand.Reader, testSigners["ecdsa"]) |
||||
if err := tryAuth(t, clientConfig); err != nil { |
||||
t.Errorf("cert login with source-address failed: %v", err) |
||||
} |
||||
|
||||
t.Log("disallowed source address") |
||||
cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42"} |
||||
cert.SignCert(rand.Reader, testSigners["ecdsa"]) |
||||
if err := tryAuth(t, clientConfig); err == nil { |
||||
t.Errorf("cert login with source-address succeeded") |
||||
} |
||||
} |
||||
|
||||
func testPermissionsPassing(withPermissions bool, t *testing.T) { |
||||
serverConfig := &ServerConfig{ |
||||
PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { |
||||
if conn.User() == "nopermissions" { |
||||
return nil, nil |
||||
} else { |
||||
return &Permissions{}, nil |
||||
} |
||||
}, |
||||
} |
||||
serverConfig.AddHostKey(testSigners["rsa"]) |
||||
|
||||
clientConfig := &ClientConfig{ |
||||
Auth: []AuthMethod{ |
||||
PublicKeys(testSigners["rsa"]), |
||||
}, |
||||
} |
||||
if withPermissions { |
||||
clientConfig.User = "permissions" |
||||
} else { |
||||
clientConfig.User = "nopermissions" |
||||
} |
||||
|
||||
c1, c2, err := netPipe() |
||||
if err != nil { |
||||
t.Fatalf("netPipe: %v", err) |
||||
} |
||||
defer c1.Close() |
||||
defer c2.Close() |
||||
|
||||
go NewClientConn(c2, "", clientConfig) |
||||
serverConn, err := newServer(c1, serverConfig) |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
if p := serverConn.Permissions; (p != nil) != withPermissions { |
||||
t.Fatalf("withPermissions is %t, but Permissions object is %#v", withPermissions, p) |
||||
} |
||||
} |
||||
|
||||
func TestPermissionsPassing(t *testing.T) { |
||||
testPermissionsPassing(true, t) |
||||
} |
||||
|
||||
func TestNoPermissionsPassing(t *testing.T) { |
||||
testPermissionsPassing(false, t) |
||||
} |
@ -0,0 +1,39 @@
|
||||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"net" |
||||
"testing" |
||||
) |
||||
|
||||
func testClientVersion(t *testing.T, config *ClientConfig, expected string) { |
||||
clientConn, serverConn := net.Pipe() |
||||
defer clientConn.Close() |
||||
receivedVersion := make(chan string, 1) |
||||
go func() { |
||||
version, err := readVersion(serverConn) |
||||
if err != nil { |
||||
receivedVersion <- "" |
||||
} else { |
||||
receivedVersion <- string(version) |
||||
} |
||||
serverConn.Close() |
||||
}() |
||||
NewClientConn(clientConn, "", config) |
||||
actual := <-receivedVersion |
||||
if actual != expected { |
||||
t.Fatalf("got %s; want %s", actual, expected) |
||||
} |
||||
} |
||||
|
||||
func TestCustomClientVersion(t *testing.T) { |
||||
version := "Test-Client-Version-0.0" |
||||
testClientVersion(t, &ClientConfig{ClientVersion: version}, version) |
||||
} |
||||
|
||||
func TestDefaultClientVersion(t *testing.T) { |
||||
testClientVersion(t, &ClientConfig{}, packageVersion) |
||||
} |
@ -0,0 +1,354 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"crypto" |
||||
"crypto/rand" |
||||
"fmt" |
||||
"io" |
||||
"sync" |
||||
|
||||
_ "crypto/sha1" |
||||
_ "crypto/sha256" |
||||
_ "crypto/sha512" |
||||
) |
||||
|
||||
// These are string constants in the SSH protocol.
|
||||
const ( |
||||
compressionNone = "none" |
||||
serviceUserAuth = "ssh-userauth" |
||||
serviceSSH = "ssh-connection" |
||||
) |
||||
|
||||
// supportedCiphers specifies the supported ciphers in preference order.
|
||||
var supportedCiphers = []string{ |
||||
"aes128-ctr", "aes192-ctr", "aes256-ctr", |
||||
"aes128-gcm@openssh.com", |
||||
"arcfour256", "arcfour128", |
||||
} |
||||
|
||||
// supportedKexAlgos specifies the supported key-exchange algorithms in
|
||||
// preference order.
|
||||
var supportedKexAlgos = []string{ |
||||
kexAlgoCurve25519SHA256, |
||||
// P384 and P521 are not constant-time yet, but since we don't
|
||||
// reuse ephemeral keys, using them for ECDH should be OK.
|
||||
kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521, |
||||
kexAlgoDH14SHA1, kexAlgoDH1SHA1, |
||||
} |
||||
|
||||
// supportedKexAlgos specifies the supported host-key algorithms (i.e. methods
|
||||
// of authenticating servers) in preference order.
|
||||
var supportedHostKeyAlgos = []string{ |
||||
CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, |
||||
CertAlgoECDSA384v01, CertAlgoECDSA521v01, |
||||
|
||||
KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, |
||||
KeyAlgoRSA, KeyAlgoDSA, |
||||
} |
||||
|
||||
// supportedMACs specifies a default set of MAC algorithms in preference order.
|
||||
// This is based on RFC 4253, section 6.4, but with hmac-md5 variants removed
|
||||
// because they have reached the end of their useful life.
|
||||
var supportedMACs = []string{ |
||||
"hmac-sha2-256", "hmac-sha1", "hmac-sha1-96", |
||||
} |
||||
|
||||
var supportedCompressions = []string{compressionNone} |
||||
|
||||
// hashFuncs keeps the mapping of supported algorithms to their respective
|
||||
// hashes needed for signature verification.
|
||||
var hashFuncs = map[string]crypto.Hash{ |
||||
KeyAlgoRSA: crypto.SHA1, |
||||
KeyAlgoDSA: crypto.SHA1, |
||||
KeyAlgoECDSA256: crypto.SHA256, |
||||
KeyAlgoECDSA384: crypto.SHA384, |
||||
KeyAlgoECDSA521: crypto.SHA512, |
||||
CertAlgoRSAv01: crypto.SHA1, |
||||
CertAlgoDSAv01: crypto.SHA1, |
||||
CertAlgoECDSA256v01: crypto.SHA256, |
||||
CertAlgoECDSA384v01: crypto.SHA384, |
||||
CertAlgoECDSA521v01: crypto.SHA512, |
||||
} |
||||
|
||||
// unexpectedMessageError results when the SSH message that we received didn't
|
||||
// match what we wanted.
|
||||
func unexpectedMessageError(expected, got uint8) error { |
||||
return fmt.Errorf("ssh: unexpected message type %d (expected %d)", got, expected) |
||||
} |
||||
|
||||
// parseError results from a malformed SSH message.
|
||||
func parseError(tag uint8) error { |
||||
return fmt.Errorf("ssh: parse error in message type %d", tag) |
||||
} |
||||
|
||||
func findCommon(what string, client []string, server []string) (common string, err error) { |
||||
for _, c := range client { |
||||
for _, s := range server { |
||||
if c == s { |
||||
return c, nil |
||||
} |
||||
} |
||||
} |
||||
return "", fmt.Errorf("ssh: no common algorithm for %s; client offered: %v, server offered: %v", what, client, server) |
||||
} |
||||
|
||||
type directionAlgorithms struct { |
||||
Cipher string |
||||
MAC string |
||||
Compression string |
||||
} |
||||
|
||||
type algorithms struct { |
||||
kex string |
||||
hostKey string |
||||
w directionAlgorithms |
||||
r directionAlgorithms |
||||
} |
||||
|
||||
func findAgreedAlgorithms(clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms, err error) { |
||||
result := &algorithms{} |
||||
|
||||
result.kex, err = findCommon("key exchange", clientKexInit.KexAlgos, serverKexInit.KexAlgos) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
result.hostKey, err = findCommon("host key", clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
result.w.Cipher, err = findCommon("client to server cipher", clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
result.r.Cipher, err = findCommon("server to client cipher", clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
result.w.MAC, err = findCommon("client to server MAC", clientKexInit.MACsClientServer, serverKexInit.MACsClientServer) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
result.r.MAC, err = findCommon("server to client MAC", clientKexInit.MACsServerClient, serverKexInit.MACsServerClient) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
result.w.Compression, err = findCommon("client to server compression", clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
result.r.Compression, err = findCommon("server to client compression", clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
return result, nil |
||||
} |
||||
|
||||
// If rekeythreshold is too small, we can't make any progress sending
|
||||
// stuff.
|
||||
const minRekeyThreshold uint64 = 256 |
||||
|
||||
// Config contains configuration data common to both ServerConfig and
|
||||
// ClientConfig.
|
||||
type Config struct { |
||||
// Rand provides the source of entropy for cryptographic
|
||||
// primitives. If Rand is nil, the cryptographic random reader
|
||||
// in package crypto/rand will be used.
|
||||
Rand io.Reader |
||||
|
||||
// The maximum number of bytes sent or received after which a
|
||||
// new key is negotiated. It must be at least 256. If
|
||||
// unspecified, 1 gigabyte is used.
|
||||
RekeyThreshold uint64 |
||||
|
||||
// The allowed key exchanges algorithms. If unspecified then a
|
||||
// default set of algorithms is used.
|
||||
KeyExchanges []string |
||||
|
||||
// The allowed cipher algorithms. If unspecified then a sensible
|
||||
// default is used.
|
||||
Ciphers []string |
||||
|
||||
// The allowed MAC algorithms. If unspecified then a sensible default
|
||||
// is used.
|
||||
MACs []string |
||||
} |
||||
|
||||
// SetDefaults sets sensible values for unset fields in config. This is
|
||||
// exported for testing: Configs passed to SSH functions are copied and have
|
||||
// default values set automatically.
|
||||
func (c *Config) SetDefaults() { |
||||
if c.Rand == nil { |
||||
c.Rand = rand.Reader |
||||
} |
||||
if c.Ciphers == nil { |
||||
c.Ciphers = supportedCiphers |
||||
} |
||||
var ciphers []string |
||||
for _, c := range c.Ciphers { |
||||
if cipherModes[c] != nil { |
||||
// reject the cipher if we have no cipherModes definition
|
||||
ciphers = append(ciphers, c) |
||||
} |
||||
} |
||||
c.Ciphers = ciphers |
||||
|
||||
if c.KeyExchanges == nil { |
||||
c.KeyExchanges = supportedKexAlgos |
||||
} |
||||
|
||||
if c.MACs == nil { |
||||
c.MACs = supportedMACs |
||||
} |
||||
|
||||
if c.RekeyThreshold == 0 { |
||||
// RFC 4253, section 9 suggests rekeying after 1G.
|
||||
c.RekeyThreshold = 1 << 30 |
||||
} |
||||
if c.RekeyThreshold < minRekeyThreshold { |
||||
c.RekeyThreshold = minRekeyThreshold |
||||
} |
||||
} |
||||
|
||||
// buildDataSignedForAuth returns the data that is signed in order to prove
|
||||
// possession of a private key. See RFC 4252, section 7.
|
||||
func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte { |
||||
data := struct { |
||||
Session []byte |
||||
Type byte |
||||
User string |
||||
Service string |
||||
Method string |
||||
Sign bool |
||||
Algo []byte |
||||
PubKey []byte |
||||
}{ |
||||
sessionId, |
||||
msgUserAuthRequest, |
||||
req.User, |
||||
req.Service, |
||||
req.Method, |
||||
true, |
||||
algo, |
||||
pubKey, |
||||
} |
||||
return Marshal(data) |
||||
} |
||||
|
||||
func appendU16(buf []byte, n uint16) []byte { |
||||
return append(buf, byte(n>>8), byte(n)) |
||||
} |
||||
|
||||
func appendU32(buf []byte, n uint32) []byte { |
||||
return append(buf, byte(n>>24), byte(n>>16), byte(n>>8), byte(n)) |
||||
} |
||||
|
||||
func appendU64(buf []byte, n uint64) []byte { |
||||
return append(buf, |
||||
byte(n>>56), byte(n>>48), byte(n>>40), byte(n>>32), |
||||
byte(n>>24), byte(n>>16), byte(n>>8), byte(n)) |
||||
} |
||||
|
||||
func appendInt(buf []byte, n int) []byte { |
||||
return appendU32(buf, uint32(n)) |
||||
} |
||||
|
||||
func appendString(buf []byte, s string) []byte { |
||||
buf = appendU32(buf, uint32(len(s))) |
||||
buf = append(buf, s...) |
||||
return buf |
||||
} |
||||
|
||||
func appendBool(buf []byte, b bool) []byte { |
||||
if b { |
||||
return append(buf, 1) |
||||
} |
||||
return append(buf, 0) |
||||
} |
||||
|
||||
// newCond is a helper to hide the fact that there is no usable zero
|
||||
// value for sync.Cond.
|
||||
func newCond() *sync.Cond { return sync.NewCond(new(sync.Mutex)) } |
||||
|
||||
// window represents the buffer available to clients
|
||||
// wishing to write to a channel.
|
||||
type window struct { |
||||
*sync.Cond |
||||
win uint32 // RFC 4254 5.2 says the window size can grow to 2^32-1
|
||||
writeWaiters int |
||||
closed bool |
||||
} |
||||
|
||||
// add adds win to the amount of window available
|
||||
// for consumers.
|
||||
func (w *window) add(win uint32) bool { |
||||
// a zero sized window adjust is a noop.
|
||||
if win == 0 { |
||||
return true |
||||
} |
||||
w.L.Lock() |
||||
if w.win+win < win { |
||||
w.L.Unlock() |
||||
return false |
||||
} |
||||
w.win += win |
||||
// It is unusual that multiple goroutines would be attempting to reserve
|
||||
// window space, but not guaranteed. Use broadcast to notify all waiters
|
||||
// that additional window is available.
|
||||
w.Broadcast() |
||||
w.L.Unlock() |
||||
return true |
||||
} |
||||
|
||||
// close sets the window to closed, so all reservations fail
|
||||
// immediately.
|
||||
func (w *window) close() { |
||||
w.L.Lock() |
||||
w.closed = true |
||||
w.Broadcast() |
||||
w.L.Unlock() |
||||
} |
||||
|
||||
// reserve reserves win from the available window capacity.
|
||||
// If no capacity remains, reserve will block. reserve may
|
||||
// return less than requested.
|
||||
func (w *window) reserve(win uint32) (uint32, error) { |
||||
var err error |
||||
w.L.Lock() |
||||
w.writeWaiters++ |
||||
w.Broadcast() |
||||
for w.win == 0 && !w.closed { |
||||
w.Wait() |
||||
} |
||||
w.writeWaiters-- |
||||
if w.win < win { |
||||
win = w.win |
||||
} |
||||
w.win -= win |
||||
if w.closed { |
||||
err = io.EOF |
||||
} |
||||
w.L.Unlock() |
||||
return win, err |
||||
} |
||||
|
||||
// waitWriterBlocked waits until some goroutine is blocked for further
|
||||
// writes. It is used in tests only.
|
||||
func (w *window) waitWriterBlocked() { |
||||
w.Cond.L.Lock() |
||||
for w.writeWaiters == 0 { |
||||
w.Cond.Wait() |
||||
} |
||||
w.Cond.L.Unlock() |
||||
} |
@ -0,0 +1,144 @@
|
||||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"fmt" |
||||
"net" |
||||
) |
||||
|
||||
// OpenChannelError is returned if the other side rejects an
|
||||
// OpenChannel request.
|
||||
type OpenChannelError struct { |
||||
Reason RejectionReason |
||||
Message string |
||||
} |
||||
|
||||
func (e *OpenChannelError) Error() string { |
||||
return fmt.Sprintf("ssh: rejected: %s (%s)", e.Reason, e.Message) |
||||
} |
||||
|
||||
// ConnMetadata holds metadata for the connection.
|
||||
type ConnMetadata interface { |
||||
// User returns the user ID for this connection.
|
||||
// It is empty if no authentication is used.
|
||||
User() string |
||||
|
||||
// SessionID returns the sesson hash, also denoted by H.
|
||||
SessionID() []byte |
||||
|
||||
// ClientVersion returns the client's version string as hashed
|
||||
// into the session ID.
|
||||
ClientVersion() []byte |
||||
|
||||
// ServerVersion returns the server's version string as hashed
|
||||
// into the session ID.
|
||||
ServerVersion() []byte |
||||
|
||||
// RemoteAddr returns the remote address for this connection.
|
||||
RemoteAddr() net.Addr |
||||
|
||||
// LocalAddr returns the local address for this connection.
|
||||
LocalAddr() net.Addr |
||||
} |
||||
|
||||
// Conn represents an SSH connection for both server and client roles.
|
||||
// Conn is the basis for implementing an application layer, such
|
||||
// as ClientConn, which implements the traditional shell access for
|
||||
// clients.
|
||||
type Conn interface { |
||||
ConnMetadata |
||||
|
||||
// SendRequest sends a global request, and returns the
|
||||
// reply. If wantReply is true, it returns the response status
|
||||
// and payload. See also RFC4254, section 4.
|
||||
SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) |
||||
|
||||
// OpenChannel tries to open an channel. If the request is
|
||||
// rejected, it returns *OpenChannelError. On success it returns
|
||||
// the SSH Channel and a Go channel for incoming, out-of-band
|
||||
// requests. The Go channel must be serviced, or the
|
||||
// connection will hang.
|
||||
OpenChannel(name string, data []byte) (Channel, <-chan *Request, error) |
||||
|
||||
// Close closes the underlying network connection
|
||||
Close() error |
||||
|
||||
// Wait blocks until the connection has shut down, and returns the
|
||||
// error causing the shutdown.
|
||||
Wait() error |
||||
|
||||
// TODO(hanwen): consider exposing:
|
||||
// RequestKeyChange
|
||||
// Disconnect
|
||||
} |
||||
|
||||
// DiscardRequests consumes and rejects all requests from the
|
||||
// passed-in channel.
|
||||
func DiscardRequests(in <-chan *Request) { |
||||
for req := range in { |
||||
if req.WantReply { |
||||
req.Reply(false, nil) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// A connection represents an incoming connection.
|
||||
type connection struct { |
||||
transport *handshakeTransport |
||||
sshConn |
||||
|
||||
// The connection protocol.
|
||||
*mux |
||||
} |
||||
|
||||
func (c *connection) Close() error { |
||||
return c.sshConn.conn.Close() |
||||
} |
||||
|
||||
// sshconn provides net.Conn metadata, but disallows direct reads and
|
||||
// writes.
|
||||
type sshConn struct { |
||||
conn net.Conn |
||||
|
||||
user string |
||||
sessionID []byte |
||||
clientVersion []byte |
||||
serverVersion []byte |
||||
} |
||||
|
||||
func dup(src []byte) []byte { |
||||
dst := make([]byte, len(src)) |
||||
copy(dst, src) |
||||
return dst |
||||
} |
||||
|
||||
func (c *sshConn) User() string { |
||||
return c.user |
||||
} |
||||
|
||||
func (c *sshConn) RemoteAddr() net.Addr { |
||||
return c.conn.RemoteAddr() |
||||
} |
||||
|
||||
func (c *sshConn) Close() error { |
||||
return c.conn.Close() |
||||
} |
||||
|
||||
func (c *sshConn) LocalAddr() net.Addr { |
||||
return c.conn.LocalAddr() |
||||
} |
||||
|
||||
func (c *sshConn) SessionID() []byte { |
||||
return dup(c.sessionID) |
||||
} |
||||
|
||||
func (c *sshConn) ClientVersion() []byte { |
||||
return dup(c.clientVersion) |
||||
} |
||||
|
||||
func (c *sshConn) ServerVersion() []byte { |
||||
return dup(c.serverVersion) |
||||
} |
@ -0,0 +1,18 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
/* |
||||
Package ssh implements an SSH client and server. |
||||
|
||||
SSH is a transport security protocol, an authentication protocol and a |
||||
family of application protocols. The most typical application level |
||||
protocol is a remote shell and this is specifically implemented. However, |
||||
the multiplexed nature of SSH is exposed to users that wish to support |
||||
others. |
||||
|
||||
References: |
||||
[PROTOCOL.certkeys]: http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys?rev=HEAD
|
||||
[SSH-PARAMETERS]: http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1
|
||||
*/ |
||||
package ssh |
@ -0,0 +1,211 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh_test |
||||
|
||||
import ( |
||||
"bytes" |
||||
"fmt" |
||||
"io/ioutil" |
||||
"log" |
||||
"net" |
||||
"net/http" |
||||
|
||||
"github.com/gogits/gogs/modules/crypto/ssh" |
||||
"github.com/gogits/gogs/modules/crypto/ssh/terminal" |
||||
) |
||||
|
||||
func ExampleNewServerConn() { |
||||
// An SSH server is represented by a ServerConfig, which holds
|
||||
// certificate details and handles authentication of ServerConns.
|
||||
config := &ssh.ServerConfig{ |
||||
PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) { |
||||
// Should use constant-time compare (or better, salt+hash) in
|
||||
// a production setting.
|
||||
if c.User() == "testuser" && string(pass) == "tiger" { |
||||
return nil, nil |
||||
} |
||||
return nil, fmt.Errorf("password rejected for %q", c.User()) |
||||
}, |
||||
} |
||||
|
||||
privateBytes, err := ioutil.ReadFile("id_rsa") |
||||
if err != nil { |
||||
panic("Failed to load private key") |
||||
} |
||||
|
||||
private, err := ssh.ParsePrivateKey(privateBytes) |
||||
if err != nil { |
||||
panic("Failed to parse private key") |
||||
} |
||||
|
||||
config.AddHostKey(private) |
||||
|
||||
// Once a ServerConfig has been configured, connections can be
|
||||
// accepted.
|
||||
listener, err := net.Listen("tcp", "0.0.0.0:2022") |
||||
if err != nil { |
||||
panic("failed to listen for connection") |
||||
} |
||||
nConn, err := listener.Accept() |
||||
if err != nil { |
||||
panic("failed to accept incoming connection") |
||||
} |
||||
|
||||
// Before use, a handshake must be performed on the incoming
|
||||
// net.Conn.
|
||||
_, chans, reqs, err := ssh.NewServerConn(nConn, config) |
||||
if err != nil { |
||||
panic("failed to handshake") |
||||
} |
||||
// The incoming Request channel must be serviced.
|
||||
go ssh.DiscardRequests(reqs) |
||||
|
||||
// Service the incoming Channel channel.
|
||||
for newChannel := range chans { |
||||
// Channels have a type, depending on the application level
|
||||
// protocol intended. In the case of a shell, the type is
|
||||
// "session" and ServerShell may be used to present a simple
|
||||
// terminal interface.
|
||||
if newChannel.ChannelType() != "session" { |
||||
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") |
||||
continue |
||||
} |
||||
channel, requests, err := newChannel.Accept() |
||||
if err != nil { |
||||
panic("could not accept channel.") |
||||
} |
||||
|
||||
// Sessions have out-of-band requests such as "shell",
|
||||
// "pty-req" and "env". Here we handle only the
|
||||
// "shell" request.
|
||||
go func(in <-chan *ssh.Request) { |
||||
for req := range in { |
||||
ok := false |
||||
switch req.Type { |
||||
case "shell": |
||||
ok = true |
||||
if len(req.Payload) > 0 { |
||||
// We don't accept any
|
||||
// commands, only the
|
||||
// default shell.
|
||||
ok = false |
||||
} |
||||
} |
||||
req.Reply(ok, nil) |
||||
} |
||||
}(requests) |
||||
|
||||
term := terminal.NewTerminal(channel, "> ") |
||||
|
||||
go func() { |
||||
defer channel.Close() |
||||
for { |
||||
line, err := term.ReadLine() |
||||
if err != nil { |
||||
break |
||||
} |
||||
fmt.Println(line) |
||||
} |
||||
}() |
||||
} |
||||
} |
||||
|
||||
func ExampleDial() { |
||||
// An SSH client is represented with a ClientConn. Currently only
|
||||
// the "password" authentication method is supported.
|
||||
//
|
||||
// To authenticate with the remote server you must pass at least one
|
||||
// implementation of AuthMethod via the Auth field in ClientConfig.
|
||||
config := &ssh.ClientConfig{ |
||||
User: "username", |
||||
Auth: []ssh.AuthMethod{ |
||||
ssh.Password("yourpassword"), |
||||
}, |
||||
} |
||||
client, err := ssh.Dial("tcp", "yourserver.com:22", config) |
||||
if err != nil { |
||||
panic("Failed to dial: " + err.Error()) |
||||
} |
||||
|
||||
// Each ClientConn can support multiple interactive sessions,
|
||||
// represented by a Session.
|
||||
session, err := client.NewSession() |
||||
if err != nil { |
||||
panic("Failed to create session: " + err.Error()) |
||||
} |
||||
defer session.Close() |
||||
|
||||
// Once a Session is created, you can execute a single command on
|
||||
// the remote side using the Run method.
|
||||
var b bytes.Buffer |
||||
session.Stdout = &b |
||||
if err := session.Run("/usr/bin/whoami"); err != nil { |
||||
panic("Failed to run: " + err.Error()) |
||||
} |
||||
fmt.Println(b.String()) |
||||
} |
||||
|
||||
func ExampleClient_Listen() { |
||||
config := &ssh.ClientConfig{ |
||||
User: "username", |
||||
Auth: []ssh.AuthMethod{ |
||||
ssh.Password("password"), |
||||
}, |
||||
} |
||||
// Dial your ssh server.
|
||||
conn, err := ssh.Dial("tcp", "localhost:22", config) |
||||
if err != nil { |
||||
log.Fatalf("unable to connect: %s", err) |
||||
} |
||||
defer conn.Close() |
||||
|
||||
// Request the remote side to open port 8080 on all interfaces.
|
||||
l, err := conn.Listen("tcp", "0.0.0.0:8080") |
||||
if err != nil { |
||||
log.Fatalf("unable to register tcp forward: %v", err) |
||||
} |
||||
defer l.Close() |
||||
|
||||
// Serve HTTP with your SSH server acting as a reverse proxy.
|
||||
http.Serve(l, http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { |
||||
fmt.Fprintf(resp, "Hello world!\n") |
||||
})) |
||||
} |
||||
|
||||
func ExampleSession_RequestPty() { |
||||
// Create client config
|
||||
config := &ssh.ClientConfig{ |
||||
User: "username", |
||||
Auth: []ssh.AuthMethod{ |
||||
ssh.Password("password"), |
||||
}, |
||||
} |
||||
// Connect to ssh server
|
||||
conn, err := ssh.Dial("tcp", "localhost:22", config) |
||||
if err != nil { |
||||
log.Fatalf("unable to connect: %s", err) |
||||
} |
||||
defer conn.Close() |
||||
// Create a session
|
||||
session, err := conn.NewSession() |
||||
if err != nil { |
||||
log.Fatalf("unable to create session: %s", err) |
||||
} |
||||
defer session.Close() |
||||
// Set up terminal modes
|
||||
modes := ssh.TerminalModes{ |
||||
ssh.ECHO: 0, // disable echoing
|
||||
ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud
|
||||
ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud
|
||||
} |
||||
// Request pseudo terminal
|
||||
if err := session.RequestPty("xterm", 80, 40, modes); err != nil { |
||||
log.Fatalf("request for pseudo terminal failed: %s", err) |
||||
} |
||||
// Start remote shell
|
||||
if err := session.Shell(); err != nil { |
||||
log.Fatalf("failed to start shell: %s", err) |
||||
} |
||||
} |
@ -0,0 +1,412 @@
|
||||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"crypto/rand" |
||||
"errors" |
||||
"fmt" |
||||
"io" |
||||
"log" |
||||
"net" |
||||
"sync" |
||||
) |
||||
|
||||
// debugHandshake, if set, prints messages sent and received. Key
|
||||
// exchange messages are printed as if DH were used, so the debug
|
||||
// messages are wrong when using ECDH.
|
||||
const debugHandshake = false |
||||
|
||||
// keyingTransport is a packet based transport that supports key
|
||||
// changes. It need not be thread-safe. It should pass through
|
||||
// msgNewKeys in both directions.
|
||||
type keyingTransport interface { |
||||
packetConn |
||||
|
||||
// prepareKeyChange sets up a key change. The key change for a
|
||||
// direction will be effected if a msgNewKeys message is sent
|
||||
// or received.
|
||||
prepareKeyChange(*algorithms, *kexResult) error |
||||
|
||||
// getSessionID returns the session ID. prepareKeyChange must
|
||||
// have been called once.
|
||||
getSessionID() []byte |
||||
} |
||||
|
||||
// rekeyingTransport is the interface of handshakeTransport that we
|
||||
// (internally) expose to ClientConn and ServerConn.
|
||||
type rekeyingTransport interface { |
||||
packetConn |
||||
|
||||
// requestKeyChange asks the remote side to change keys. All
|
||||
// writes are blocked until the key change succeeds, which is
|
||||
// signaled by reading a msgNewKeys.
|
||||
requestKeyChange() error |
||||
|
||||
// getSessionID returns the session ID. This is only valid
|
||||
// after the first key change has completed.
|
||||
getSessionID() []byte |
||||
} |
||||
|
||||
// handshakeTransport implements rekeying on top of a keyingTransport
|
||||
// and offers a thread-safe writePacket() interface.
|
||||
type handshakeTransport struct { |
||||
conn keyingTransport |
||||
config *Config |
||||
|
||||
serverVersion []byte |
||||
clientVersion []byte |
||||
|
||||
// hostKeys is non-empty if we are the server. In that case,
|
||||
// it contains all host keys that can be used to sign the
|
||||
// connection.
|
||||
hostKeys []Signer |
||||
|
||||
// hostKeyAlgorithms is non-empty if we are the client. In that case,
|
||||
// we accept these key types from the server as host key.
|
||||
hostKeyAlgorithms []string |
||||
|
||||
// On read error, incoming is closed, and readError is set.
|
||||
incoming chan []byte |
||||
readError error |
||||
|
||||
// data for host key checking
|
||||
hostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error |
||||
dialAddress string |
||||
remoteAddr net.Addr |
||||
|
||||
readSinceKex uint64 |
||||
|
||||
// Protects the writing side of the connection
|
||||
mu sync.Mutex |
||||
cond *sync.Cond |
||||
sentInitPacket []byte |
||||
sentInitMsg *kexInitMsg |
||||
writtenSinceKex uint64 |
||||
writeError error |
||||
} |
||||
|
||||
func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport { |
||||
t := &handshakeTransport{ |
||||
conn: conn, |
||||
serverVersion: serverVersion, |
||||
clientVersion: clientVersion, |
||||
incoming: make(chan []byte, 16), |
||||
config: config, |
||||
} |
||||
t.cond = sync.NewCond(&t.mu) |
||||
return t |
||||
} |
||||
|
||||
func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport { |
||||
t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion) |
||||
t.dialAddress = dialAddr |
||||
t.remoteAddr = addr |
||||
t.hostKeyCallback = config.HostKeyCallback |
||||
if config.HostKeyAlgorithms != nil { |
||||
t.hostKeyAlgorithms = config.HostKeyAlgorithms |
||||
} else { |
||||
t.hostKeyAlgorithms = supportedHostKeyAlgos |
||||
} |
||||
go t.readLoop() |
||||
return t |
||||
} |
||||
|
||||
func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport { |
||||
t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion) |
||||
t.hostKeys = config.hostKeys |
||||
go t.readLoop() |
||||
return t |
||||
} |
||||
|
||||
func (t *handshakeTransport) getSessionID() []byte { |
||||
return t.conn.getSessionID() |
||||
} |
||||
|
||||
func (t *handshakeTransport) id() string { |
||||
if len(t.hostKeys) > 0 { |
||||
return "server" |
||||
} |
||||
return "client" |
||||
} |
||||
|
||||
func (t *handshakeTransport) readPacket() ([]byte, error) { |
||||
p, ok := <-t.incoming |
||||
if !ok { |
||||
return nil, t.readError |
||||
} |
||||
return p, nil |
||||
} |
||||
|
||||
func (t *handshakeTransport) readLoop() { |
||||
for { |
||||
p, err := t.readOnePacket() |
||||
if err != nil { |
||||
t.readError = err |
||||
close(t.incoming) |
||||
break |
||||
} |
||||
if p[0] == msgIgnore || p[0] == msgDebug { |
||||
continue |
||||
} |
||||
t.incoming <- p |
||||
} |
||||
|
||||
// If we can't read, declare the writing part dead too.
|
||||
t.mu.Lock() |
||||
defer t.mu.Unlock() |
||||
if t.writeError == nil { |
||||
t.writeError = t.readError |
||||
} |
||||
t.cond.Broadcast() |
||||
} |
||||
|
||||
func (t *handshakeTransport) readOnePacket() ([]byte, error) { |
||||
if t.readSinceKex > t.config.RekeyThreshold { |
||||
if err := t.requestKeyChange(); err != nil { |
||||
return nil, err |
||||
} |
||||
} |
||||
|
||||
p, err := t.conn.readPacket() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
t.readSinceKex += uint64(len(p)) |
||||
if debugHandshake { |
||||
msg, err := decode(p) |
||||
log.Printf("%s got %T %v (%v)", t.id(), msg, msg, err) |
||||
} |
||||
if p[0] != msgKexInit { |
||||
return p, nil |
||||
} |
||||
err = t.enterKeyExchange(p) |
||||
|
||||
t.mu.Lock() |
||||
if err != nil { |
||||
// drop connection
|
||||
t.conn.Close() |
||||
t.writeError = err |
||||
} |
||||
|
||||
if debugHandshake { |
||||
log.Printf("%s exited key exchange, err %v", t.id(), err) |
||||
} |
||||
|
||||
// Unblock writers.
|
||||
t.sentInitMsg = nil |
||||
t.sentInitPacket = nil |
||||
t.cond.Broadcast() |
||||
t.writtenSinceKex = 0 |
||||
t.mu.Unlock() |
||||
|
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
t.readSinceKex = 0 |
||||
return []byte{msgNewKeys}, nil |
||||
} |
||||
|
||||
// sendKexInit sends a key change message, and returns the message
|
||||
// that was sent. After initiating the key change, all writes will be
|
||||
// blocked until the change is done, and a failed key change will
|
||||
// close the underlying transport. This function is safe for
|
||||
// concurrent use by multiple goroutines.
|
||||
func (t *handshakeTransport) sendKexInit() (*kexInitMsg, []byte, error) { |
||||
t.mu.Lock() |
||||
defer t.mu.Unlock() |
||||
return t.sendKexInitLocked() |
||||
} |
||||
|
||||
func (t *handshakeTransport) requestKeyChange() error { |
||||
_, _, err := t.sendKexInit() |
||||
return err |
||||
} |
||||
|
||||
// sendKexInitLocked sends a key change message. t.mu must be locked
|
||||
// while this happens.
|
||||
func (t *handshakeTransport) sendKexInitLocked() (*kexInitMsg, []byte, error) { |
||||
// kexInits may be sent either in response to the other side,
|
||||
// or because our side wants to initiate a key change, so we
|
||||
// may have already sent a kexInit. In that case, don't send a
|
||||
// second kexInit.
|
||||
if t.sentInitMsg != nil { |
||||
return t.sentInitMsg, t.sentInitPacket, nil |
||||
} |
||||
msg := &kexInitMsg{ |
||||
KexAlgos: t.config.KeyExchanges, |
||||
CiphersClientServer: t.config.Ciphers, |
||||
CiphersServerClient: t.config.Ciphers, |
||||
MACsClientServer: t.config.MACs, |
||||
MACsServerClient: t.config.MACs, |
||||
CompressionClientServer: supportedCompressions, |
||||
CompressionServerClient: supportedCompressions, |
||||
} |
||||
io.ReadFull(rand.Reader, msg.Cookie[:]) |
||||
|
||||
if len(t.hostKeys) > 0 { |
||||
for _, k := range t.hostKeys { |
||||
msg.ServerHostKeyAlgos = append( |
||||
msg.ServerHostKeyAlgos, k.PublicKey().Type()) |
||||
} |
||||
} else { |
||||
msg.ServerHostKeyAlgos = t.hostKeyAlgorithms |
||||
} |
||||
packet := Marshal(msg) |
||||
|
||||
// writePacket destroys the contents, so save a copy.
|
||||
packetCopy := make([]byte, len(packet)) |
||||
copy(packetCopy, packet) |
||||
|
||||
if err := t.conn.writePacket(packetCopy); err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
t.sentInitMsg = msg |
||||
t.sentInitPacket = packet |
||||
return msg, packet, nil |
||||
} |
||||
|
||||
func (t *handshakeTransport) writePacket(p []byte) error { |
||||
t.mu.Lock() |
||||
defer t.mu.Unlock() |
||||
|
||||
if t.writtenSinceKex > t.config.RekeyThreshold { |
||||
t.sendKexInitLocked() |
||||
} |
||||
for t.sentInitMsg != nil && t.writeError == nil { |
||||
t.cond.Wait() |
||||
} |
||||
if t.writeError != nil { |
||||
return t.writeError |
||||
} |
||||
t.writtenSinceKex += uint64(len(p)) |
||||
|
||||
switch p[0] { |
||||
case msgKexInit: |
||||
return errors.New("ssh: only handshakeTransport can send kexInit") |
||||
case msgNewKeys: |
||||
return errors.New("ssh: only handshakeTransport can send newKeys") |
||||
default: |
||||
return t.conn.writePacket(p) |
||||
} |
||||
} |
||||
|
||||
func (t *handshakeTransport) Close() error { |
||||
return t.conn.Close() |
||||
} |
||||
|
||||
// enterKeyExchange runs the key exchange.
|
||||
func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error { |
||||
if debugHandshake { |
||||
log.Printf("%s entered key exchange", t.id()) |
||||
} |
||||
myInit, myInitPacket, err := t.sendKexInit() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
otherInit := &kexInitMsg{} |
||||
if err := Unmarshal(otherInitPacket, otherInit); err != nil { |
||||
return err |
||||
} |
||||
|
||||
magics := handshakeMagics{ |
||||
clientVersion: t.clientVersion, |
||||
serverVersion: t.serverVersion, |
||||
clientKexInit: otherInitPacket, |
||||
serverKexInit: myInitPacket, |
||||
} |
||||
|
||||
clientInit := otherInit |
||||
serverInit := myInit |
||||
if len(t.hostKeys) == 0 { |
||||
clientInit = myInit |
||||
serverInit = otherInit |
||||
|
||||
magics.clientKexInit = myInitPacket |
||||
magics.serverKexInit = otherInitPacket |
||||
} |
||||
|
||||
algs, err := findAgreedAlgorithms(clientInit, serverInit) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
// We don't send FirstKexFollows, but we handle receiving it.
|
||||
if otherInit.FirstKexFollows && algs.kex != otherInit.KexAlgos[0] { |
||||
// other side sent a kex message for the wrong algorithm,
|
||||
// which we have to ignore.
|
||||
if _, err := t.conn.readPacket(); err != nil { |
||||
return err |
||||
} |
||||
} |
||||
|
||||
kex, ok := kexAlgoMap[algs.kex] |
||||
if !ok { |
||||
return fmt.Errorf("ssh: unexpected key exchange algorithm %v", algs.kex) |
||||
} |
||||
|
||||
var result *kexResult |
||||
if len(t.hostKeys) > 0 { |
||||
result, err = t.server(kex, algs, &magics) |
||||
} else { |
||||
result, err = t.client(kex, algs, &magics) |
||||
} |
||||
|
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
t.conn.prepareKeyChange(algs, result) |
||||
if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil { |
||||
return err |
||||
} |
||||
if packet, err := t.conn.readPacket(); err != nil { |
||||
return err |
||||
} else if packet[0] != msgNewKeys { |
||||
return unexpectedMessageError(msgNewKeys, packet[0]) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (t *handshakeTransport) server(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) { |
||||
var hostKey Signer |
||||
for _, k := range t.hostKeys { |
||||
if algs.hostKey == k.PublicKey().Type() { |
||||
hostKey = k |
||||
} |
||||
} |
||||
|
||||
r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey) |
||||
return r, err |
||||
} |
||||
|
||||
func (t *handshakeTransport) client(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) { |
||||
result, err := kex.Client(t.conn, t.config.Rand, magics) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
hostKey, err := ParsePublicKey(result.HostKey) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
if err := verifyHostKeySignature(hostKey, result); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
if t.hostKeyCallback != nil { |
||||
err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
} |
||||
|
||||
return result, nil |
||||
} |
@ -0,0 +1,415 @@
|
||||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"bytes" |
||||
"crypto/rand" |
||||
"errors" |
||||
"fmt" |
||||
"net" |
||||
"runtime" |
||||
"strings" |
||||
"sync" |
||||
"testing" |
||||
) |
||||
|
||||
type testChecker struct { |
||||
calls []string |
||||
} |
||||
|
||||
func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error { |
||||
if dialAddr == "bad" { |
||||
return fmt.Errorf("dialAddr is bad") |
||||
} |
||||
|
||||
if tcpAddr, ok := addr.(*net.TCPAddr); !ok || tcpAddr == nil { |
||||
return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr) |
||||
} |
||||
|
||||
t.calls = append(t.calls, fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal())) |
||||
|
||||
return nil |
||||
} |
||||
|
||||
// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
|
||||
// therefore is buffered (net.Pipe deadlocks if both sides start with
|
||||
// a write.)
|
||||
func netPipe() (net.Conn, net.Conn, error) { |
||||
listener, err := net.Listen("tcp", "127.0.0.1:0") |
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
defer listener.Close() |
||||
c1, err := net.Dial("tcp", listener.Addr().String()) |
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
c2, err := listener.Accept() |
||||
if err != nil { |
||||
c1.Close() |
||||
return nil, nil, err |
||||
} |
||||
|
||||
return c1, c2, nil |
||||
} |
||||
|
||||
func handshakePair(clientConf *ClientConfig, addr string) (client *handshakeTransport, server *handshakeTransport, err error) { |
||||
a, b, err := netPipe() |
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
trC := newTransport(a, rand.Reader, true) |
||||
trS := newTransport(b, rand.Reader, false) |
||||
clientConf.SetDefaults() |
||||
|
||||
v := []byte("version") |
||||
client = newClientTransport(trC, v, v, clientConf, addr, a.RemoteAddr()) |
||||
|
||||
serverConf := &ServerConfig{} |
||||
serverConf.AddHostKey(testSigners["ecdsa"]) |
||||
serverConf.AddHostKey(testSigners["rsa"]) |
||||
serverConf.SetDefaults() |
||||
server = newServerTransport(trS, v, v, serverConf) |
||||
|
||||
return client, server, nil |
||||
} |
||||
|
||||
func TestHandshakeBasic(t *testing.T) { |
||||
if runtime.GOOS == "plan9" { |
||||
t.Skip("see golang.org/issue/7237") |
||||
} |
||||
checker := &testChecker{} |
||||
trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr") |
||||
if err != nil { |
||||
t.Fatalf("handshakePair: %v", err) |
||||
} |
||||
|
||||
defer trC.Close() |
||||
defer trS.Close() |
||||
|
||||
go func() { |
||||
// Client writes a bunch of stuff, and does a key
|
||||
// change in the middle. This should not confuse the
|
||||
// handshake in progress
|
||||
for i := 0; i < 10; i++ { |
||||
p := []byte{msgRequestSuccess, byte(i)} |
||||
if err := trC.writePacket(p); err != nil { |
||||
t.Fatalf("sendPacket: %v", err) |
||||
} |
||||
if i == 5 { |
||||
// halfway through, we request a key change.
|
||||
_, _, err := trC.sendKexInit() |
||||
if err != nil { |
||||
t.Fatalf("sendKexInit: %v", err) |
||||
} |
||||
} |
||||
} |
||||
trC.Close() |
||||
}() |
||||
|
||||
// Server checks that client messages come in cleanly
|
||||
i := 0 |
||||
for { |
||||
p, err := trS.readPacket() |
||||
if err != nil { |
||||
break |
||||
} |
||||
if p[0] == msgNewKeys { |
||||
continue |
||||
} |
||||
want := []byte{msgRequestSuccess, byte(i)} |
||||
if bytes.Compare(p, want) != 0 { |
||||
t.Errorf("message %d: got %q, want %q", i, p, want) |
||||
} |
||||
i++ |
||||
} |
||||
if i != 10 { |
||||
t.Errorf("received %d messages, want 10.", i) |
||||
} |
||||
|
||||
// If all went well, we registered exactly 1 key change.
|
||||
if len(checker.calls) != 1 { |
||||
t.Fatalf("got %d host key checks, want 1", len(checker.calls)) |
||||
} |
||||
|
||||
pub := testSigners["ecdsa"].PublicKey() |
||||
want := fmt.Sprintf("%s %v %s %x", "addr", trC.remoteAddr, pub.Type(), pub.Marshal()) |
||||
if want != checker.calls[0] { |
||||
t.Errorf("got %q want %q for host key check", checker.calls[0], want) |
||||
} |
||||
} |
||||
|
||||
func TestHandshakeError(t *testing.T) { |
||||
checker := &testChecker{} |
||||
trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "bad") |
||||
if err != nil { |
||||
t.Fatalf("handshakePair: %v", err) |
||||
} |
||||
defer trC.Close() |
||||
defer trS.Close() |
||||
|
||||
// send a packet
|
||||
packet := []byte{msgRequestSuccess, 42} |
||||
if err := trC.writePacket(packet); err != nil { |
||||
t.Errorf("writePacket: %v", err) |
||||
} |
||||
|
||||
// Now request a key change.
|
||||
_, _, err = trC.sendKexInit() |
||||
if err != nil { |
||||
t.Errorf("sendKexInit: %v", err) |
||||
} |
||||
|
||||
// the key change will fail, and afterwards we can't write.
|
||||
if err := trC.writePacket([]byte{msgRequestSuccess, 43}); err == nil { |
||||
t.Errorf("writePacket after botched rekey succeeded.") |
||||
} |
||||
|
||||
readback, err := trS.readPacket() |
||||
if err != nil { |
||||
t.Fatalf("server closed too soon: %v", err) |
||||
} |
||||
if bytes.Compare(readback, packet) != 0 { |
||||
t.Errorf("got %q want %q", readback, packet) |
||||
} |
||||
readback, err = trS.readPacket() |
||||
if err == nil { |
||||
t.Errorf("got a message %q after failed key change", readback) |
||||
} |
||||
} |
||||
|
||||
func TestHandshakeTwice(t *testing.T) { |
||||
checker := &testChecker{} |
||||
trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr") |
||||
if err != nil { |
||||
t.Fatalf("handshakePair: %v", err) |
||||
} |
||||
|
||||
defer trC.Close() |
||||
defer trS.Close() |
||||
|
||||
// send a packet
|
||||
packet := make([]byte, 5) |
||||
packet[0] = msgRequestSuccess |
||||
if err := trC.writePacket(packet); err != nil { |
||||
t.Errorf("writePacket: %v", err) |
||||
} |
||||
|
||||
// Now request a key change.
|
||||
_, _, err = trC.sendKexInit() |
||||
if err != nil { |
||||
t.Errorf("sendKexInit: %v", err) |
||||
} |
||||
|
||||
// Send another packet. Use a fresh one, since writePacket destroys.
|
||||
packet = make([]byte, 5) |
||||
packet[0] = msgRequestSuccess |
||||
if err := trC.writePacket(packet); err != nil { |
||||
t.Errorf("writePacket: %v", err) |
||||
} |
||||
|
||||
// 2nd key change.
|
||||
_, _, err = trC.sendKexInit() |
||||
if err != nil { |
||||
t.Errorf("sendKexInit: %v", err) |
||||
} |
||||
|
||||
packet = make([]byte, 5) |
||||
packet[0] = msgRequestSuccess |
||||
if err := trC.writePacket(packet); err != nil { |
||||
t.Errorf("writePacket: %v", err) |
||||
} |
||||
|
||||
packet = make([]byte, 5) |
||||
packet[0] = msgRequestSuccess |
||||
for i := 0; i < 5; i++ { |
||||
msg, err := trS.readPacket() |
||||
if err != nil { |
||||
t.Fatalf("server closed too soon: %v", err) |
||||
} |
||||
if msg[0] == msgNewKeys { |
||||
continue |
||||
} |
||||
|
||||
if bytes.Compare(msg, packet) != 0 { |
||||
t.Errorf("packet %d: got %q want %q", i, msg, packet) |
||||
} |
||||
} |
||||
if len(checker.calls) != 2 { |
||||
t.Errorf("got %d key changes, want 2", len(checker.calls)) |
||||
} |
||||
} |
||||
|
||||
func TestHandshakeAutoRekeyWrite(t *testing.T) { |
||||
checker := &testChecker{} |
||||
clientConf := &ClientConfig{HostKeyCallback: checker.Check} |
||||
clientConf.RekeyThreshold = 500 |
||||
trC, trS, err := handshakePair(clientConf, "addr") |
||||
if err != nil { |
||||
t.Fatalf("handshakePair: %v", err) |
||||
} |
||||
defer trC.Close() |
||||
defer trS.Close() |
||||
|
||||
for i := 0; i < 5; i++ { |
||||
packet := make([]byte, 251) |
||||
packet[0] = msgRequestSuccess |
||||
if err := trC.writePacket(packet); err != nil { |
||||
t.Errorf("writePacket: %v", err) |
||||
} |
||||
} |
||||
|
||||
j := 0 |
||||
for ; j < 5; j++ { |
||||
_, err := trS.readPacket() |
||||
if err != nil { |
||||
break |
||||
} |
||||
} |
||||
|
||||
if j != 5 { |
||||
t.Errorf("got %d, want 5 messages", j) |
||||
} |
||||
|
||||
if len(checker.calls) != 2 { |
||||
t.Errorf("got %d key changes, wanted 2", len(checker.calls)) |
||||
} |
||||
} |
||||
|
||||
type syncChecker struct { |
||||
called chan int |
||||
} |
||||
|
||||
func (t *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error { |
||||
t.called <- 1 |
||||
return nil |
||||
} |
||||
|
||||
func TestHandshakeAutoRekeyRead(t *testing.T) { |
||||
sync := &syncChecker{make(chan int, 2)} |
||||
clientConf := &ClientConfig{ |
||||
HostKeyCallback: sync.Check, |
||||
} |
||||
clientConf.RekeyThreshold = 500 |
||||
|
||||
trC, trS, err := handshakePair(clientConf, "addr") |
||||
if err != nil { |
||||
t.Fatalf("handshakePair: %v", err) |
||||
} |
||||
defer trC.Close() |
||||
defer trS.Close() |
||||
|
||||
packet := make([]byte, 501) |
||||
packet[0] = msgRequestSuccess |
||||
if err := trS.writePacket(packet); err != nil { |
||||
t.Fatalf("writePacket: %v", err) |
||||
} |
||||
// While we read out the packet, a key change will be
|
||||
// initiated.
|
||||
if _, err := trC.readPacket(); err != nil { |
||||
t.Fatalf("readPacket(client): %v", err) |
||||
} |
||||
|
||||
<-sync.called |
||||
} |
||||
|
||||
// errorKeyingTransport generates errors after a given number of
|
||||
// read/write operations.
|
||||
type errorKeyingTransport struct { |
||||
packetConn |
||||
readLeft, writeLeft int |
||||
} |
||||
|
||||
func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error { |
||||
return nil |
||||
} |
||||
func (n *errorKeyingTransport) getSessionID() []byte { |
||||
return nil |
||||
} |
||||
|
||||
func (n *errorKeyingTransport) writePacket(packet []byte) error { |
||||
if n.writeLeft == 0 { |
||||
n.Close() |
||||
return errors.New("barf") |
||||
} |
||||
|
||||
n.writeLeft-- |
||||
return n.packetConn.writePacket(packet) |
||||
} |
||||
|
||||
func (n *errorKeyingTransport) readPacket() ([]byte, error) { |
||||
if n.readLeft == 0 { |
||||
n.Close() |
||||
return nil, errors.New("barf") |
||||
} |
||||
|
||||
n.readLeft-- |
||||
return n.packetConn.readPacket() |
||||
} |
||||
|
||||
func TestHandshakeErrorHandlingRead(t *testing.T) { |
||||
for i := 0; i < 20; i++ { |
||||
testHandshakeErrorHandlingN(t, i, -1) |
||||
} |
||||
} |
||||
|
||||
func TestHandshakeErrorHandlingWrite(t *testing.T) { |
||||
for i := 0; i < 20; i++ { |
||||
testHandshakeErrorHandlingN(t, -1, i) |
||||
} |
||||
} |
||||
|
||||
// testHandshakeErrorHandlingN runs handshakes, injecting errors. If
|
||||
// handshakeTransport deadlocks, the go runtime will detect it and
|
||||
// panic.
|
||||
func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int) { |
||||
msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)}) |
||||
|
||||
a, b := memPipe() |
||||
defer a.Close() |
||||
defer b.Close() |
||||
|
||||
key := testSigners["ecdsa"] |
||||
serverConf := Config{RekeyThreshold: minRekeyThreshold} |
||||
serverConf.SetDefaults() |
||||
serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'}) |
||||
serverConn.hostKeys = []Signer{key} |
||||
go serverConn.readLoop() |
||||
|
||||
clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold} |
||||
clientConf.SetDefaults() |
||||
clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'}) |
||||
clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()} |
||||
go clientConn.readLoop() |
||||
|
||||
var wg sync.WaitGroup |
||||
wg.Add(4) |
||||
|
||||
for _, hs := range []packetConn{serverConn, clientConn} { |
||||
go func(c packetConn) { |
||||
for { |
||||
err := c.writePacket(msg) |
||||
if err != nil { |
||||
break |
||||
} |
||||
} |
||||
wg.Done() |
||||
}(hs) |
||||
go func(c packetConn) { |
||||
for { |
||||
_, err := c.readPacket() |
||||
if err != nil { |
||||
break |
||||
} |
||||
} |
||||
wg.Done() |
||||
}(hs) |
||||
} |
||||
|
||||
wg.Wait() |
||||
} |
@ -0,0 +1,526 @@
|
||||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"crypto" |
||||
"crypto/ecdsa" |
||||
"crypto/elliptic" |
||||
"crypto/subtle" |
||||
"crypto/rand" |
||||
"errors" |
||||
"io" |
||||
"math/big" |
||||
|
||||
"golang.org/x/crypto/curve25519" |
||||
) |
||||
|
||||
const ( |
||||
kexAlgoDH1SHA1 = "diffie-hellman-group1-sha1" |
||||
kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1" |
||||
kexAlgoECDH256 = "ecdh-sha2-nistp256" |
||||
kexAlgoECDH384 = "ecdh-sha2-nistp384" |
||||
kexAlgoECDH521 = "ecdh-sha2-nistp521" |
||||
kexAlgoCurve25519SHA256 = "curve25519-sha256@libssh.org" |
||||
) |
||||
|
||||
// kexResult captures the outcome of a key exchange.
|
||||
type kexResult struct { |
||||
// Session hash. See also RFC 4253, section 8.
|
||||
H []byte |
||||
|
||||
// Shared secret. See also RFC 4253, section 8.
|
||||
K []byte |
||||
|
||||
// Host key as hashed into H.
|
||||
HostKey []byte |
||||
|
||||
// Signature of H.
|
||||
Signature []byte |
||||
|
||||
// A cryptographic hash function that matches the security
|
||||
// level of the key exchange algorithm. It is used for
|
||||
// calculating H, and for deriving keys from H and K.
|
||||
Hash crypto.Hash |
||||
|
||||
// The session ID, which is the first H computed. This is used
|
||||
// to signal data inside transport.
|
||||
SessionID []byte |
||||
} |
||||
|
||||
// handshakeMagics contains data that is always included in the
|
||||
// session hash.
|
||||
type handshakeMagics struct { |
||||
clientVersion, serverVersion []byte |
||||
clientKexInit, serverKexInit []byte |
||||
} |
||||
|
||||
func (m *handshakeMagics) write(w io.Writer) { |
||||
writeString(w, m.clientVersion) |
||||
writeString(w, m.serverVersion) |
||||
writeString(w, m.clientKexInit) |
||||
writeString(w, m.serverKexInit) |
||||
} |
||||
|
||||
// kexAlgorithm abstracts different key exchange algorithms.
|
||||
type kexAlgorithm interface { |
||||
// Server runs server-side key agreement, signing the result
|
||||
// with a hostkey.
|
||||
Server(p packetConn, rand io.Reader, magics *handshakeMagics, s Signer) (*kexResult, error) |
||||
|
||||
// Client runs the client-side key agreement. Caller is
|
||||
// responsible for verifying the host key signature.
|
||||
Client(p packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) |
||||
} |
||||
|
||||
// dhGroup is a multiplicative group suitable for implementing Diffie-Hellman key agreement.
|
||||
type dhGroup struct { |
||||
g, p *big.Int |
||||
} |
||||
|
||||
func (group *dhGroup) diffieHellman(theirPublic, myPrivate *big.Int) (*big.Int, error) { |
||||
if theirPublic.Sign() <= 0 || theirPublic.Cmp(group.p) >= 0 { |
||||
return nil, errors.New("ssh: DH parameter out of bounds") |
||||
} |
||||
return new(big.Int).Exp(theirPublic, myPrivate, group.p), nil |
||||
} |
||||
|
||||
func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) { |
||||
hashFunc := crypto.SHA1 |
||||
|
||||
x, err := rand.Int(randSource, group.p) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
X := new(big.Int).Exp(group.g, x, group.p) |
||||
kexDHInit := kexDHInitMsg{ |
||||
X: X, |
||||
} |
||||
if err := c.writePacket(Marshal(&kexDHInit)); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
packet, err := c.readPacket() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var kexDHReply kexDHReplyMsg |
||||
if err = Unmarshal(packet, &kexDHReply); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
kInt, err := group.diffieHellman(kexDHReply.Y, x) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
h := hashFunc.New() |
||||
magics.write(h) |
||||
writeString(h, kexDHReply.HostKey) |
||||
writeInt(h, X) |
||||
writeInt(h, kexDHReply.Y) |
||||
K := make([]byte, intLength(kInt)) |
||||
marshalInt(K, kInt) |
||||
h.Write(K) |
||||
|
||||
return &kexResult{ |
||||
H: h.Sum(nil), |
||||
K: K, |
||||
HostKey: kexDHReply.HostKey, |
||||
Signature: kexDHReply.Signature, |
||||
Hash: crypto.SHA1, |
||||
}, nil |
||||
} |
||||
|
||||
func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) { |
||||
hashFunc := crypto.SHA1 |
||||
packet, err := c.readPacket() |
||||
if err != nil { |
||||
return |
||||
} |
||||
var kexDHInit kexDHInitMsg |
||||
if err = Unmarshal(packet, &kexDHInit); err != nil { |
||||
return |
||||
} |
||||
|
||||
y, err := rand.Int(randSource, group.p) |
||||
if err != nil { |
||||
return |
||||
} |
||||
|
||||
Y := new(big.Int).Exp(group.g, y, group.p) |
||||
kInt, err := group.diffieHellman(kexDHInit.X, y) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
hostKeyBytes := priv.PublicKey().Marshal() |
||||
|
||||
h := hashFunc.New() |
||||
magics.write(h) |
||||
writeString(h, hostKeyBytes) |
||||
writeInt(h, kexDHInit.X) |
||||
writeInt(h, Y) |
||||
|
||||
K := make([]byte, intLength(kInt)) |
||||
marshalInt(K, kInt) |
||||
h.Write(K) |
||||
|
||||
H := h.Sum(nil) |
||||
|
||||
// H is already a hash, but the hostkey signing will apply its
|
||||
// own key-specific hash algorithm.
|
||||
sig, err := signAndMarshal(priv, randSource, H) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
kexDHReply := kexDHReplyMsg{ |
||||
HostKey: hostKeyBytes, |
||||
Y: Y, |
||||
Signature: sig, |
||||
} |
||||
packet = Marshal(&kexDHReply) |
||||
|
||||
err = c.writePacket(packet) |
||||
return &kexResult{ |
||||
H: H, |
||||
K: K, |
||||
HostKey: hostKeyBytes, |
||||
Signature: sig, |
||||
Hash: crypto.SHA1, |
||||
}, nil |
||||
} |
||||
|
||||
// ecdh performs Elliptic Curve Diffie-Hellman key exchange as
|
||||
// described in RFC 5656, section 4.
|
||||
type ecdh struct { |
||||
curve elliptic.Curve |
||||
} |
||||
|
||||
func (kex *ecdh) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) { |
||||
ephKey, err := ecdsa.GenerateKey(kex.curve, rand) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
kexInit := kexECDHInitMsg{ |
||||
ClientPubKey: elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y), |
||||
} |
||||
|
||||
serialized := Marshal(&kexInit) |
||||
if err := c.writePacket(serialized); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
packet, err := c.readPacket() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var reply kexECDHReplyMsg |
||||
if err = Unmarshal(packet, &reply); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
x, y, err := unmarshalECKey(kex.curve, reply.EphemeralPubKey) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
// generate shared secret
|
||||
secret, _ := kex.curve.ScalarMult(x, y, ephKey.D.Bytes()) |
||||
|
||||
h := ecHash(kex.curve).New() |
||||
magics.write(h) |
||||
writeString(h, reply.HostKey) |
||||
writeString(h, kexInit.ClientPubKey) |
||||
writeString(h, reply.EphemeralPubKey) |
||||
K := make([]byte, intLength(secret)) |
||||
marshalInt(K, secret) |
||||
h.Write(K) |
||||
|
||||
return &kexResult{ |
||||
H: h.Sum(nil), |
||||
K: K, |
||||
HostKey: reply.HostKey, |
||||
Signature: reply.Signature, |
||||
Hash: ecHash(kex.curve), |
||||
}, nil |
||||
} |
||||
|
||||
// unmarshalECKey parses and checks an EC key.
|
||||
func unmarshalECKey(curve elliptic.Curve, pubkey []byte) (x, y *big.Int, err error) { |
||||
x, y = elliptic.Unmarshal(curve, pubkey) |
||||
if x == nil { |
||||
return nil, nil, errors.New("ssh: elliptic.Unmarshal failure") |
||||
} |
||||
if !validateECPublicKey(curve, x, y) { |
||||
return nil, nil, errors.New("ssh: public key not on curve") |
||||
} |
||||
return x, y, nil |
||||
} |
||||
|
||||
// validateECPublicKey checks that the point is a valid public key for
|
||||
// the given curve. See [SEC1], 3.2.2
|
||||
func validateECPublicKey(curve elliptic.Curve, x, y *big.Int) bool { |
||||
if x.Sign() == 0 && y.Sign() == 0 { |
||||
return false |
||||
} |
||||
|
||||
if x.Cmp(curve.Params().P) >= 0 { |
||||
return false |
||||
} |
||||
|
||||
if y.Cmp(curve.Params().P) >= 0 { |
||||
return false |
||||
} |
||||
|
||||
if !curve.IsOnCurve(x, y) { |
||||
return false |
||||
} |
||||
|
||||
// We don't check if N * PubKey == 0, since
|
||||
//
|
||||
// - the NIST curves have cofactor = 1, so this is implicit.
|
||||
// (We don't foresee an implementation that supports non NIST
|
||||
// curves)
|
||||
//
|
||||
// - for ephemeral keys, we don't need to worry about small
|
||||
// subgroup attacks.
|
||||
return true |
||||
} |
||||
|
||||
func (kex *ecdh) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) { |
||||
packet, err := c.readPacket() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var kexECDHInit kexECDHInitMsg |
||||
if err = Unmarshal(packet, &kexECDHInit); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
clientX, clientY, err := unmarshalECKey(kex.curve, kexECDHInit.ClientPubKey) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
// We could cache this key across multiple users/multiple
|
||||
// connection attempts, but the benefit is small. OpenSSH
|
||||
// generates a new key for each incoming connection.
|
||||
ephKey, err := ecdsa.GenerateKey(kex.curve, rand) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
hostKeyBytes := priv.PublicKey().Marshal() |
||||
|
||||
serializedEphKey := elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y) |
||||
|
||||
// generate shared secret
|
||||
secret, _ := kex.curve.ScalarMult(clientX, clientY, ephKey.D.Bytes()) |
||||
|
||||
h := ecHash(kex.curve).New() |
||||
magics.write(h) |
||||
writeString(h, hostKeyBytes) |
||||
writeString(h, kexECDHInit.ClientPubKey) |
||||
writeString(h, serializedEphKey) |
||||
|
||||
K := make([]byte, intLength(secret)) |
||||
marshalInt(K, secret) |
||||
h.Write(K) |
||||
|
||||
H := h.Sum(nil) |
||||
|
||||
// H is already a hash, but the hostkey signing will apply its
|
||||
// own key-specific hash algorithm.
|
||||
sig, err := signAndMarshal(priv, rand, H) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
reply := kexECDHReplyMsg{ |
||||
EphemeralPubKey: serializedEphKey, |
||||
HostKey: hostKeyBytes, |
||||
Signature: sig, |
||||
} |
||||
|
||||
serialized := Marshal(&reply) |
||||
if err := c.writePacket(serialized); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return &kexResult{ |
||||
H: H, |
||||
K: K, |
||||
HostKey: reply.HostKey, |
||||
Signature: sig, |
||||
Hash: ecHash(kex.curve), |
||||
}, nil |
||||
} |
||||
|
||||
var kexAlgoMap = map[string]kexAlgorithm{} |
||||
|
||||
func init() { |
||||
// This is the group called diffie-hellman-group1-sha1 in RFC
|
||||
// 4253 and Oakley Group 2 in RFC 2409.
|
||||
p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF", 16) |
||||
kexAlgoMap[kexAlgoDH1SHA1] = &dhGroup{ |
||||
g: new(big.Int).SetInt64(2), |
||||
p: p, |
||||
} |
||||
|
||||
// This is the group called diffie-hellman-group14-sha1 in RFC
|
||||
// 4253 and Oakley Group 14 in RFC 3526.
|
||||
p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16) |
||||
|
||||
kexAlgoMap[kexAlgoDH14SHA1] = &dhGroup{ |
||||
g: new(big.Int).SetInt64(2), |
||||
p: p, |
||||
} |
||||
|
||||
kexAlgoMap[kexAlgoECDH521] = &ecdh{elliptic.P521()} |
||||
kexAlgoMap[kexAlgoECDH384] = &ecdh{elliptic.P384()} |
||||
kexAlgoMap[kexAlgoECDH256] = &ecdh{elliptic.P256()} |
||||
kexAlgoMap[kexAlgoCurve25519SHA256] = &curve25519sha256{} |
||||
} |
||||
|
||||
// curve25519sha256 implements the curve25519-sha256@libssh.org key
|
||||
// agreement protocol, as described in
|
||||
// https://git.libssh.org/projects/libssh.git/tree/doc/curve25519-sha256@libssh.org.txt
|
||||
type curve25519sha256 struct{} |
||||
|
||||
type curve25519KeyPair struct { |
||||
priv [32]byte |
||||
pub [32]byte |
||||
} |
||||
|
||||
func (kp *curve25519KeyPair) generate(rand io.Reader) error { |
||||
if _, err := io.ReadFull(rand, kp.priv[:]); err != nil { |
||||
return err |
||||
} |
||||
curve25519.ScalarBaseMult(&kp.pub, &kp.priv) |
||||
return nil |
||||
} |
||||
|
||||
// curve25519Zeros is just an array of 32 zero bytes so that we have something
|
||||
// convenient to compare against in order to reject curve25519 points with the
|
||||
// wrong order.
|
||||
var curve25519Zeros [32]byte |
||||
|
||||
func (kex *curve25519sha256) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) { |
||||
var kp curve25519KeyPair |
||||
if err := kp.generate(rand); err != nil { |
||||
return nil, err |
||||
} |
||||
if err := c.writePacket(Marshal(&kexECDHInitMsg{kp.pub[:]})); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
packet, err := c.readPacket() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var reply kexECDHReplyMsg |
||||
if err = Unmarshal(packet, &reply); err != nil { |
||||
return nil, err |
||||
} |
||||
if len(reply.EphemeralPubKey) != 32 { |
||||
return nil, errors.New("ssh: peer's curve25519 public value has wrong length") |
||||
} |
||||
|
||||
var servPub, secret [32]byte |
||||
copy(servPub[:], reply.EphemeralPubKey) |
||||
curve25519.ScalarMult(&secret, &kp.priv, &servPub) |
||||
if subtle.ConstantTimeCompare(secret[:], curve25519Zeros[:]) == 1 { |
||||
return nil, errors.New("ssh: peer's curve25519 public value has wrong order") |
||||
} |
||||
|
||||
h := crypto.SHA256.New() |
||||
magics.write(h) |
||||
writeString(h, reply.HostKey) |
||||
writeString(h, kp.pub[:]) |
||||
writeString(h, reply.EphemeralPubKey) |
||||
|
||||
kInt := new(big.Int).SetBytes(secret[:]) |
||||
K := make([]byte, intLength(kInt)) |
||||
marshalInt(K, kInt) |
||||
h.Write(K) |
||||
|
||||
return &kexResult{ |
||||
H: h.Sum(nil), |
||||
K: K, |
||||
HostKey: reply.HostKey, |
||||
Signature: reply.Signature, |
||||
Hash: crypto.SHA256, |
||||
}, nil |
||||
} |
||||
|
||||
func (kex *curve25519sha256) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) { |
||||
packet, err := c.readPacket() |
||||
if err != nil { |
||||
return |
||||
} |
||||
var kexInit kexECDHInitMsg |
||||
if err = Unmarshal(packet, &kexInit); err != nil { |
||||
return |
||||
} |
||||
|
||||
if len(kexInit.ClientPubKey) != 32 { |
||||
return nil, errors.New("ssh: peer's curve25519 public value has wrong length") |
||||
} |
||||
|
||||
var kp curve25519KeyPair |
||||
if err := kp.generate(rand); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var clientPub, secret [32]byte |
||||
copy(clientPub[:], kexInit.ClientPubKey) |
||||
curve25519.ScalarMult(&secret, &kp.priv, &clientPub) |
||||
if subtle.ConstantTimeCompare(secret[:], curve25519Zeros[:]) == 1 { |
||||
return nil, errors.New("ssh: peer's curve25519 public value has wrong order") |
||||
} |
||||
|
||||
hostKeyBytes := priv.PublicKey().Marshal() |
||||
|
||||
h := crypto.SHA256.New() |
||||
magics.write(h) |
||||
writeString(h, hostKeyBytes) |
||||
writeString(h, kexInit.ClientPubKey) |
||||
writeString(h, kp.pub[:]) |
||||
|
||||
kInt := new(big.Int).SetBytes(secret[:]) |
||||
K := make([]byte, intLength(kInt)) |
||||
marshalInt(K, kInt) |
||||
h.Write(K) |
||||
|
||||
H := h.Sum(nil) |
||||
|
||||
sig, err := signAndMarshal(priv, rand, H) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
reply := kexECDHReplyMsg{ |
||||
EphemeralPubKey: kp.pub[:], |
||||
HostKey: hostKeyBytes, |
||||
Signature: sig, |
||||
} |
||||
if err := c.writePacket(Marshal(&reply)); err != nil { |
||||
return nil, err |
||||
} |
||||
return &kexResult{ |
||||
H: H, |
||||
K: K, |
||||
HostKey: hostKeyBytes, |
||||
Signature: sig, |
||||
Hash: crypto.SHA256, |
||||
}, nil |
||||
} |
@ -0,0 +1,50 @@
|
||||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
// Key exchange tests.
|
||||
|
||||
import ( |
||||
"crypto/rand" |
||||
"reflect" |
||||
"testing" |
||||
) |
||||
|
||||
func TestKexes(t *testing.T) { |
||||
type kexResultErr struct { |
||||
result *kexResult |
||||
err error |
||||
} |
||||
|
||||
for name, kex := range kexAlgoMap { |
||||
a, b := memPipe() |
||||
|
||||
s := make(chan kexResultErr, 1) |
||||
c := make(chan kexResultErr, 1) |
||||
var magics handshakeMagics |
||||
go func() { |
||||
r, e := kex.Client(a, rand.Reader, &magics) |
||||
a.Close() |
||||
c <- kexResultErr{r, e} |
||||
}() |
||||
go func() { |
||||
r, e := kex.Server(b, rand.Reader, &magics, testSigners["ecdsa"]) |
||||
b.Close() |
||||
s <- kexResultErr{r, e} |
||||
}() |
||||
|
||||
clientRes := <-c |
||||
serverRes := <-s |
||||
if clientRes.err != nil { |
||||
t.Errorf("client: %v", clientRes.err) |
||||
} |
||||
if serverRes.err != nil { |
||||
t.Errorf("server: %v", serverRes.err) |
||||
} |
||||
if !reflect.DeepEqual(clientRes.result, serverRes.result) { |
||||
t.Errorf("kex %q: mismatch %#v, %#v", name, clientRes.result, serverRes.result) |
||||
} |
||||
} |
||||
} |
@ -0,0 +1,628 @@
|
||||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"bytes" |
||||
"crypto" |
||||
"crypto/dsa" |
||||
"crypto/ecdsa" |
||||
"crypto/elliptic" |
||||
"crypto/rsa" |
||||
"crypto/x509" |
||||
"encoding/asn1" |
||||
"encoding/base64" |
||||
"encoding/pem" |
||||
"errors" |
||||
"fmt" |
||||
"io" |
||||
"math/big" |
||||
) |
||||
|
||||
// These constants represent the algorithm names for key types supported by this
|
||||
// package.
|
||||
const ( |
||||
KeyAlgoRSA = "ssh-rsa" |
||||
KeyAlgoDSA = "ssh-dss" |
||||
KeyAlgoECDSA256 = "ecdsa-sha2-nistp256" |
||||
KeyAlgoECDSA384 = "ecdsa-sha2-nistp384" |
||||
KeyAlgoECDSA521 = "ecdsa-sha2-nistp521" |
||||
) |
||||
|
||||
// parsePubKey parses a public key of the given algorithm.
|
||||
// Use ParsePublicKey for keys with prepended algorithm.
|
||||
func parsePubKey(in []byte, algo string) (pubKey PublicKey, rest []byte, err error) { |
||||
switch algo { |
||||
case KeyAlgoRSA: |
||||
return parseRSA(in) |
||||
case KeyAlgoDSA: |
||||
return parseDSA(in) |
||||
case KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521: |
||||
return parseECDSA(in) |
||||
case CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01: |
||||
cert, err := parseCert(in, certToPrivAlgo(algo)) |
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
return cert, nil, nil |
||||
} |
||||
return nil, nil, fmt.Errorf("ssh: unknown key algorithm: %v", err) |
||||
} |
||||
|
||||
// parseAuthorizedKey parses a public key in OpenSSH authorized_keys format
|
||||
// (see sshd(8) manual page) once the options and key type fields have been
|
||||
// removed.
|
||||
func parseAuthorizedKey(in []byte) (out PublicKey, comment string, err error) { |
||||
in = bytes.TrimSpace(in) |
||||
|
||||
i := bytes.IndexAny(in, " \t") |
||||
if i == -1 { |
||||
i = len(in) |
||||
} |
||||
base64Key := in[:i] |
||||
|
||||
key := make([]byte, base64.StdEncoding.DecodedLen(len(base64Key))) |
||||
n, err := base64.StdEncoding.Decode(key, base64Key) |
||||
if err != nil { |
||||
return nil, "", err |
||||
} |
||||
key = key[:n] |
||||
out, err = ParsePublicKey(key) |
||||
if err != nil { |
||||
return nil, "", err |
||||
} |
||||
comment = string(bytes.TrimSpace(in[i:])) |
||||
return out, comment, nil |
||||
} |
||||
|
||||
// ParseAuthorizedKeys parses a public key from an authorized_keys
|
||||
// file used in OpenSSH according to the sshd(8) manual page.
|
||||
func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, err error) { |
||||
for len(in) > 0 { |
||||
end := bytes.IndexByte(in, '\n') |
||||
if end != -1 { |
||||
rest = in[end+1:] |
||||
in = in[:end] |
||||
} else { |
||||
rest = nil |
||||
} |
||||
|
||||
end = bytes.IndexByte(in, '\r') |
||||
if end != -1 { |
||||
in = in[:end] |
||||
} |
||||
|
||||
in = bytes.TrimSpace(in) |
||||
if len(in) == 0 || in[0] == '#' { |
||||
in = rest |
||||
continue |
||||
} |
||||
|
||||
i := bytes.IndexAny(in, " \t") |
||||
if i == -1 { |
||||
in = rest |
||||
continue |
||||
} |
||||
|
||||
if out, comment, err = parseAuthorizedKey(in[i:]); err == nil { |
||||
return out, comment, options, rest, nil |
||||
} |
||||
|
||||
// No key type recognised. Maybe there's an options field at
|
||||
// the beginning.
|
||||
var b byte |
||||
inQuote := false |
||||
var candidateOptions []string |
||||
optionStart := 0 |
||||
for i, b = range in { |
||||
isEnd := !inQuote && (b == ' ' || b == '\t') |
||||
if (b == ',' && !inQuote) || isEnd { |
||||
if i-optionStart > 0 { |
||||
candidateOptions = append(candidateOptions, string(in[optionStart:i])) |
||||
} |
||||
optionStart = i + 1 |
||||
} |
||||
if isEnd { |
||||
break |
||||
} |
||||
if b == '"' && (i == 0 || (i > 0 && in[i-1] != '\\')) { |
||||
inQuote = !inQuote |
||||
} |
||||
} |
||||
for i < len(in) && (in[i] == ' ' || in[i] == '\t') { |
||||
i++ |
||||
} |
||||
if i == len(in) { |
||||
// Invalid line: unmatched quote
|
||||
in = rest |
||||
continue |
||||
} |
||||
|
||||
in = in[i:] |
||||
i = bytes.IndexAny(in, " \t") |
||||
if i == -1 { |
||||
in = rest |
||||
continue |
||||
} |
||||
|
||||
if out, comment, err = parseAuthorizedKey(in[i:]); err == nil { |
||||
options = candidateOptions |
||||
return out, comment, options, rest, nil |
||||
} |
||||
|
||||
in = rest |
||||
continue |
||||
} |
||||
|
||||
return nil, "", nil, nil, errors.New("ssh: no key found") |
||||
} |
||||
|
||||
// ParsePublicKey parses an SSH public key formatted for use in
|
||||
// the SSH wire protocol according to RFC 4253, section 6.6.
|
||||
func ParsePublicKey(in []byte) (out PublicKey, err error) { |
||||
algo, in, ok := parseString(in) |
||||
if !ok { |
||||
return nil, errShortRead |
||||
} |
||||
var rest []byte |
||||
out, rest, err = parsePubKey(in, string(algo)) |
||||
if len(rest) > 0 { |
||||
return nil, errors.New("ssh: trailing junk in public key") |
||||
} |
||||
|
||||
return out, err |
||||
} |
||||
|
||||
// MarshalAuthorizedKey serializes key for inclusion in an OpenSSH
|
||||
// authorized_keys file. The return value ends with newline.
|
||||
func MarshalAuthorizedKey(key PublicKey) []byte { |
||||
b := &bytes.Buffer{} |
||||
b.WriteString(key.Type()) |
||||
b.WriteByte(' ') |
||||
e := base64.NewEncoder(base64.StdEncoding, b) |
||||
e.Write(key.Marshal()) |
||||
e.Close() |
||||
b.WriteByte('\n') |
||||
return b.Bytes() |
||||
} |
||||
|
||||
// PublicKey is an abstraction of different types of public keys.
|
||||
type PublicKey interface { |
||||
// Type returns the key's type, e.g. "ssh-rsa".
|
||||
Type() string |
||||
|
||||
// Marshal returns the serialized key data in SSH wire format,
|
||||
// with the name prefix.
|
||||
Marshal() []byte |
||||
|
||||
// Verify that sig is a signature on the given data using this
|
||||
// key. This function will hash the data appropriately first.
|
||||
Verify(data []byte, sig *Signature) error |
||||
} |
||||
|
||||
// A Signer can create signatures that verify against a public key.
|
||||
type Signer interface { |
||||
// PublicKey returns an associated PublicKey instance.
|
||||
PublicKey() PublicKey |
||||
|
||||
// Sign returns raw signature for the given data. This method
|
||||
// will apply the hash specified for the keytype to the data.
|
||||
Sign(rand io.Reader, data []byte) (*Signature, error) |
||||
} |
||||
|
||||
type rsaPublicKey rsa.PublicKey |
||||
|
||||
func (r *rsaPublicKey) Type() string { |
||||
return "ssh-rsa" |
||||
} |
||||
|
||||
// parseRSA parses an RSA key according to RFC 4253, section 6.6.
|
||||
func parseRSA(in []byte) (out PublicKey, rest []byte, err error) { |
||||
var w struct { |
||||
E *big.Int |
||||
N *big.Int |
||||
Rest []byte `ssh:"rest"` |
||||
} |
||||
if err := Unmarshal(in, &w); err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
if w.E.BitLen() > 24 { |
||||
return nil, nil, errors.New("ssh: exponent too large") |
||||
} |
||||
e := w.E.Int64() |
||||
if e < 3 || e&1 == 0 { |
||||
return nil, nil, errors.New("ssh: incorrect exponent") |
||||
} |
||||
|
||||
var key rsa.PublicKey |
||||
key.E = int(e) |
||||
key.N = w.N |
||||
return (*rsaPublicKey)(&key), w.Rest, nil |
||||
} |
||||
|
||||
func (r *rsaPublicKey) Marshal() []byte { |
||||
e := new(big.Int).SetInt64(int64(r.E)) |
||||
wirekey := struct { |
||||
Name string |
||||
E *big.Int |
||||
N *big.Int |
||||
}{ |
||||
KeyAlgoRSA, |
||||
e, |
||||
r.N, |
||||
} |
||||
return Marshal(&wirekey) |
||||
} |
||||
|
||||
func (r *rsaPublicKey) Verify(data []byte, sig *Signature) error { |
||||
if sig.Format != r.Type() { |
||||
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, r.Type()) |
||||
} |
||||
h := crypto.SHA1.New() |
||||
h.Write(data) |
||||
digest := h.Sum(nil) |
||||
return rsa.VerifyPKCS1v15((*rsa.PublicKey)(r), crypto.SHA1, digest, sig.Blob) |
||||
} |
||||
|
||||
type rsaPrivateKey struct { |
||||
*rsa.PrivateKey |
||||
} |
||||
|
||||
func (r *rsaPrivateKey) PublicKey() PublicKey { |
||||
return (*rsaPublicKey)(&r.PrivateKey.PublicKey) |
||||
} |
||||
|
||||
func (r *rsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) { |
||||
h := crypto.SHA1.New() |
||||
h.Write(data) |
||||
digest := h.Sum(nil) |
||||
blob, err := rsa.SignPKCS1v15(rand, r.PrivateKey, crypto.SHA1, digest) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return &Signature{ |
||||
Format: r.PublicKey().Type(), |
||||
Blob: blob, |
||||
}, nil |
||||
} |
||||
|
||||
type dsaPublicKey dsa.PublicKey |
||||
|
||||
func (r *dsaPublicKey) Type() string { |
||||
return "ssh-dss" |
||||
} |
||||
|
||||
// parseDSA parses an DSA key according to RFC 4253, section 6.6.
|
||||
func parseDSA(in []byte) (out PublicKey, rest []byte, err error) { |
||||
var w struct { |
||||
P, Q, G, Y *big.Int |
||||
Rest []byte `ssh:"rest"` |
||||
} |
||||
if err := Unmarshal(in, &w); err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
key := &dsaPublicKey{ |
||||
Parameters: dsa.Parameters{ |
||||
P: w.P, |
||||
Q: w.Q, |
||||
G: w.G, |
||||
}, |
||||
Y: w.Y, |
||||
} |
||||
return key, w.Rest, nil |
||||
} |
||||
|
||||
func (k *dsaPublicKey) Marshal() []byte { |
||||
w := struct { |
||||
Name string |
||||
P, Q, G, Y *big.Int |
||||
}{ |
||||
k.Type(), |
||||
k.P, |
||||
k.Q, |
||||
k.G, |
||||
k.Y, |
||||
} |
||||
|
||||
return Marshal(&w) |
||||
} |
||||
|
||||
func (k *dsaPublicKey) Verify(data []byte, sig *Signature) error { |
||||
if sig.Format != k.Type() { |
||||
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type()) |
||||
} |
||||
h := crypto.SHA1.New() |
||||
h.Write(data) |
||||
digest := h.Sum(nil) |
||||
|
||||
// Per RFC 4253, section 6.6,
|
||||
// The value for 'dss_signature_blob' is encoded as a string containing
|
||||
// r, followed by s (which are 160-bit integers, without lengths or
|
||||
// padding, unsigned, and in network byte order).
|
||||
// For DSS purposes, sig.Blob should be exactly 40 bytes in length.
|
||||
if len(sig.Blob) != 40 { |
||||
return errors.New("ssh: DSA signature parse error") |
||||
} |
||||
r := new(big.Int).SetBytes(sig.Blob[:20]) |
||||
s := new(big.Int).SetBytes(sig.Blob[20:]) |
||||
if dsa.Verify((*dsa.PublicKey)(k), digest, r, s) { |
||||
return nil |
||||
} |
||||
return errors.New("ssh: signature did not verify") |
||||
} |
||||
|
||||
type dsaPrivateKey struct { |
||||
*dsa.PrivateKey |
||||
} |
||||
|
||||
func (k *dsaPrivateKey) PublicKey() PublicKey { |
||||
return (*dsaPublicKey)(&k.PrivateKey.PublicKey) |
||||
} |
||||
|
||||
func (k *dsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) { |
||||
h := crypto.SHA1.New() |
||||
h.Write(data) |
||||
digest := h.Sum(nil) |
||||
r, s, err := dsa.Sign(rand, k.PrivateKey, digest) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
sig := make([]byte, 40) |
||||
rb := r.Bytes() |
||||
sb := s.Bytes() |
||||
|
||||
copy(sig[20-len(rb):20], rb) |
||||
copy(sig[40-len(sb):], sb) |
||||
|
||||
return &Signature{ |
||||
Format: k.PublicKey().Type(), |
||||
Blob: sig, |
||||
}, nil |
||||
} |
||||
|
||||
type ecdsaPublicKey ecdsa.PublicKey |
||||
|
||||
func (key *ecdsaPublicKey) Type() string { |
||||
return "ecdsa-sha2-" + key.nistID() |
||||
} |
||||
|
||||
func (key *ecdsaPublicKey) nistID() string { |
||||
switch key.Params().BitSize { |
||||
case 256: |
||||
return "nistp256" |
||||
case 384: |
||||
return "nistp384" |
||||
case 521: |
||||
return "nistp521" |
||||
} |
||||
panic("ssh: unsupported ecdsa key size") |
||||
} |
||||
|
||||
func supportedEllipticCurve(curve elliptic.Curve) bool { |
||||
return curve == elliptic.P256() || curve == elliptic.P384() || curve == elliptic.P521() |
||||
} |
||||
|
||||
// ecHash returns the hash to match the given elliptic curve, see RFC
|
||||
// 5656, section 6.2.1
|
||||
func ecHash(curve elliptic.Curve) crypto.Hash { |
||||
bitSize := curve.Params().BitSize |
||||
switch { |
||||
case bitSize <= 256: |
||||
return crypto.SHA256 |
||||
case bitSize <= 384: |
||||
return crypto.SHA384 |
||||
} |
||||
return crypto.SHA512 |
||||
} |
||||
|
||||
// parseECDSA parses an ECDSA key according to RFC 5656, section 3.1.
|
||||
func parseECDSA(in []byte) (out PublicKey, rest []byte, err error) { |
||||
var w struct { |
||||
Curve string |
||||
KeyBytes []byte |
||||
Rest []byte `ssh:"rest"` |
||||
} |
||||
|
||||
if err := Unmarshal(in, &w); err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
key := new(ecdsa.PublicKey) |
||||
|
||||
switch w.Curve { |
||||
case "nistp256": |
||||
key.Curve = elliptic.P256() |
||||
case "nistp384": |
||||
key.Curve = elliptic.P384() |
||||
case "nistp521": |
||||
key.Curve = elliptic.P521() |
||||
default: |
||||
return nil, nil, errors.New("ssh: unsupported curve") |
||||
} |
||||
|
||||
key.X, key.Y = elliptic.Unmarshal(key.Curve, w.KeyBytes) |
||||
if key.X == nil || key.Y == nil { |
||||
return nil, nil, errors.New("ssh: invalid curve point") |
||||
} |
||||
return (*ecdsaPublicKey)(key), w.Rest, nil |
||||
} |
||||
|
||||
func (key *ecdsaPublicKey) Marshal() []byte { |
||||
// See RFC 5656, section 3.1.
|
||||
keyBytes := elliptic.Marshal(key.Curve, key.X, key.Y) |
||||
w := struct { |
||||
Name string |
||||
ID string |
||||
Key []byte |
||||
}{ |
||||
key.Type(), |
||||
key.nistID(), |
||||
keyBytes, |
||||
} |
||||
|
||||
return Marshal(&w) |
||||
} |
||||
|
||||
func (key *ecdsaPublicKey) Verify(data []byte, sig *Signature) error { |
||||
if sig.Format != key.Type() { |
||||
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, key.Type()) |
||||
} |
||||
|
||||
h := ecHash(key.Curve).New() |
||||
h.Write(data) |
||||
digest := h.Sum(nil) |
||||
|
||||
// Per RFC 5656, section 3.1.2,
|
||||
// The ecdsa_signature_blob value has the following specific encoding:
|
||||
// mpint r
|
||||
// mpint s
|
||||
var ecSig struct { |
||||
R *big.Int |
||||
S *big.Int |
||||
} |
||||
|
||||
if err := Unmarshal(sig.Blob, &ecSig); err != nil { |
||||
return err |
||||
} |
||||
|
||||
if ecdsa.Verify((*ecdsa.PublicKey)(key), digest, ecSig.R, ecSig.S) { |
||||
return nil |
||||
} |
||||
return errors.New("ssh: signature did not verify") |
||||
} |
||||
|
||||
type ecdsaPrivateKey struct { |
||||
*ecdsa.PrivateKey |
||||
} |
||||
|
||||
func (k *ecdsaPrivateKey) PublicKey() PublicKey { |
||||
return (*ecdsaPublicKey)(&k.PrivateKey.PublicKey) |
||||
} |
||||
|
||||
func (k *ecdsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) { |
||||
h := ecHash(k.PrivateKey.PublicKey.Curve).New() |
||||
h.Write(data) |
||||
digest := h.Sum(nil) |
||||
r, s, err := ecdsa.Sign(rand, k.PrivateKey, digest) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
sig := make([]byte, intLength(r)+intLength(s)) |
||||
rest := marshalInt(sig, r) |
||||
marshalInt(rest, s) |
||||
return &Signature{ |
||||
Format: k.PublicKey().Type(), |
||||
Blob: sig, |
||||
}, nil |
||||
} |
||||
|
||||
// NewSignerFromKey takes a pointer to rsa, dsa or ecdsa PrivateKey
|
||||
// returns a corresponding Signer instance. EC keys should use P256,
|
||||
// P384 or P521.
|
||||
func NewSignerFromKey(k interface{}) (Signer, error) { |
||||
var sshKey Signer |
||||
switch t := k.(type) { |
||||
case *rsa.PrivateKey: |
||||
sshKey = &rsaPrivateKey{t} |
||||
case *dsa.PrivateKey: |
||||
sshKey = &dsaPrivateKey{t} |
||||
case *ecdsa.PrivateKey: |
||||
if !supportedEllipticCurve(t.Curve) { |
||||
return nil, errors.New("ssh: only P256, P384 and P521 EC keys are supported.") |
||||
} |
||||
|
||||
sshKey = &ecdsaPrivateKey{t} |
||||
default: |
||||
return nil, fmt.Errorf("ssh: unsupported key type %T", k) |
||||
} |
||||
return sshKey, nil |
||||
} |
||||
|
||||
// NewPublicKey takes a pointer to rsa, dsa or ecdsa PublicKey
|
||||
// and returns a corresponding ssh PublicKey instance. EC keys should use P256, P384 or P521.
|
||||
func NewPublicKey(k interface{}) (PublicKey, error) { |
||||
var sshKey PublicKey |
||||
switch t := k.(type) { |
||||
case *rsa.PublicKey: |
||||
sshKey = (*rsaPublicKey)(t) |
||||
case *ecdsa.PublicKey: |
||||
if !supportedEllipticCurve(t.Curve) { |
||||
return nil, errors.New("ssh: only P256, P384 and P521 EC keys are supported.") |
||||
} |
||||
sshKey = (*ecdsaPublicKey)(t) |
||||
case *dsa.PublicKey: |
||||
sshKey = (*dsaPublicKey)(t) |
||||
default: |
||||
return nil, fmt.Errorf("ssh: unsupported key type %T", k) |
||||
} |
||||
return sshKey, nil |
||||
} |
||||
|
||||
// ParsePrivateKey returns a Signer from a PEM encoded private key. It supports
|
||||
// the same keys as ParseRawPrivateKey.
|
||||
func ParsePrivateKey(pemBytes []byte) (Signer, error) { |
||||
key, err := ParseRawPrivateKey(pemBytes) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return NewSignerFromKey(key) |
||||
} |
||||
|
||||
// ParseRawPrivateKey returns a private key from a PEM encoded private key. It
|
||||
// supports RSA (PKCS#1), DSA (OpenSSL), and ECDSA private keys.
|
||||
func ParseRawPrivateKey(pemBytes []byte) (interface{}, error) { |
||||
block, _ := pem.Decode(pemBytes) |
||||
if block == nil { |
||||
return nil, errors.New("ssh: no key found") |
||||
} |
||||
|
||||
switch block.Type { |
||||
case "RSA PRIVATE KEY": |
||||
return x509.ParsePKCS1PrivateKey(block.Bytes) |
||||
case "EC PRIVATE KEY": |
||||
return x509.ParseECPrivateKey(block.Bytes) |
||||
case "DSA PRIVATE KEY": |
||||
return ParseDSAPrivateKey(block.Bytes) |
||||
default: |
||||
return nil, fmt.Errorf("ssh: unsupported key type %q", block.Type) |
||||
} |
||||
} |
||||
|
||||
// ParseDSAPrivateKey returns a DSA private key from its ASN.1 DER encoding, as
|
||||
// specified by the OpenSSL DSA man page.
|
||||
func ParseDSAPrivateKey(der []byte) (*dsa.PrivateKey, error) { |
||||
var k struct { |
||||
Version int |
||||
P *big.Int |
||||
Q *big.Int |
||||
G *big.Int |
||||
Priv *big.Int |
||||
Pub *big.Int |
||||
} |
||||
rest, err := asn1.Unmarshal(der, &k) |
||||
if err != nil { |
||||
return nil, errors.New("ssh: failed to parse DSA key: " + err.Error()) |
||||
} |
||||
if len(rest) > 0 { |
||||
return nil, errors.New("ssh: garbage after DSA key") |
||||
} |
||||
|
||||
return &dsa.PrivateKey{ |
||||
PublicKey: dsa.PublicKey{ |
||||
Parameters: dsa.Parameters{ |
||||
P: k.P, |
||||
Q: k.Q, |
||||
G: k.G, |
||||
}, |
||||
Y: k.Priv, |
||||
}, |
||||
X: k.Pub, |
||||
}, nil |
||||
} |
@ -0,0 +1,306 @@
|
||||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"bytes" |
||||
"crypto/dsa" |
||||
"crypto/ecdsa" |
||||
"crypto/elliptic" |
||||
"crypto/rand" |
||||
"crypto/rsa" |
||||
"encoding/base64" |
||||
"fmt" |
||||
"reflect" |
||||
"strings" |
||||
"testing" |
||||
|
||||
"github.com/gogits/gogs/modules/crypto/ssh/testdata" |
||||
) |
||||
|
||||
func rawKey(pub PublicKey) interface{} { |
||||
switch k := pub.(type) { |
||||
case *rsaPublicKey: |
||||
return (*rsa.PublicKey)(k) |
||||
case *dsaPublicKey: |
||||
return (*dsa.PublicKey)(k) |
||||
case *ecdsaPublicKey: |
||||
return (*ecdsa.PublicKey)(k) |
||||
case *Certificate: |
||||
return k |
||||
} |
||||
panic("unknown key type") |
||||
} |
||||
|
||||
func TestKeyMarshalParse(t *testing.T) { |
||||
for _, priv := range testSigners { |
||||
pub := priv.PublicKey() |
||||
roundtrip, err := ParsePublicKey(pub.Marshal()) |
||||
if err != nil { |
||||
t.Errorf("ParsePublicKey(%T): %v", pub, err) |
||||
} |
||||
|
||||
k1 := rawKey(pub) |
||||
k2 := rawKey(roundtrip) |
||||
|
||||
if !reflect.DeepEqual(k1, k2) { |
||||
t.Errorf("got %#v in roundtrip, want %#v", k2, k1) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestUnsupportedCurves(t *testing.T) { |
||||
raw, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader) |
||||
if err != nil { |
||||
t.Fatalf("GenerateKey: %v", err) |
||||
} |
||||
|
||||
if _, err = NewSignerFromKey(raw); err == nil || !strings.Contains(err.Error(), "only P256") { |
||||
t.Fatalf("NewPrivateKey should not succeed with P224, got: %v", err) |
||||
} |
||||
|
||||
if _, err = NewPublicKey(&raw.PublicKey); err == nil || !strings.Contains(err.Error(), "only P256") { |
||||
t.Fatalf("NewPublicKey should not succeed with P224, got: %v", err) |
||||
} |
||||
} |
||||
|
||||
func TestNewPublicKey(t *testing.T) { |
||||
for _, k := range testSigners { |
||||
raw := rawKey(k.PublicKey()) |
||||
// Skip certificates, as NewPublicKey does not support them.
|
||||
if _, ok := raw.(*Certificate); ok { |
||||
continue |
||||
} |
||||
pub, err := NewPublicKey(raw) |
||||
if err != nil { |
||||
t.Errorf("NewPublicKey(%#v): %v", raw, err) |
||||
} |
||||
if !reflect.DeepEqual(k.PublicKey(), pub) { |
||||
t.Errorf("NewPublicKey(%#v) = %#v, want %#v", raw, pub, k.PublicKey()) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestKeySignVerify(t *testing.T) { |
||||
for _, priv := range testSigners { |
||||
pub := priv.PublicKey() |
||||
|
||||
data := []byte("sign me") |
||||
sig, err := priv.Sign(rand.Reader, data) |
||||
if err != nil { |
||||
t.Fatalf("Sign(%T): %v", priv, err) |
||||
} |
||||
|
||||
if err := pub.Verify(data, sig); err != nil { |
||||
t.Errorf("publicKey.Verify(%T): %v", priv, err) |
||||
} |
||||
sig.Blob[5]++ |
||||
if err := pub.Verify(data, sig); err == nil { |
||||
t.Errorf("publicKey.Verify on broken sig did not fail") |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestParseRSAPrivateKey(t *testing.T) { |
||||
key := testPrivateKeys["rsa"] |
||||
|
||||
rsa, ok := key.(*rsa.PrivateKey) |
||||
if !ok { |
||||
t.Fatalf("got %T, want *rsa.PrivateKey", rsa) |
||||
} |
||||
|
||||
if err := rsa.Validate(); err != nil { |
||||
t.Errorf("Validate: %v", err) |
||||
} |
||||
} |
||||
|
||||
func TestParseECPrivateKey(t *testing.T) { |
||||
key := testPrivateKeys["ecdsa"] |
||||
|
||||
ecKey, ok := key.(*ecdsa.PrivateKey) |
||||
if !ok { |
||||
t.Fatalf("got %T, want *ecdsa.PrivateKey", ecKey) |
||||
} |
||||
|
||||
if !validateECPublicKey(ecKey.Curve, ecKey.X, ecKey.Y) { |
||||
t.Fatalf("public key does not validate.") |
||||
} |
||||
} |
||||
|
||||
func TestParseDSA(t *testing.T) { |
||||
// We actually exercise the ParsePrivateKey codepath here, as opposed to
|
||||
// using the ParseRawPrivateKey+NewSignerFromKey path that testdata_test.go
|
||||
// uses.
|
||||
s, err := ParsePrivateKey(testdata.PEMBytes["dsa"]) |
||||
if err != nil { |
||||
t.Fatalf("ParsePrivateKey returned error: %s", err) |
||||
} |
||||
|
||||
data := []byte("sign me") |
||||
sig, err := s.Sign(rand.Reader, data) |
||||
if err != nil { |
||||
t.Fatalf("dsa.Sign: %v", err) |
||||
} |
||||
|
||||
if err := s.PublicKey().Verify(data, sig); err != nil { |
||||
t.Errorf("Verify failed: %v", err) |
||||
} |
||||
} |
||||
|
||||
// Tests for authorized_keys parsing.
|
||||
|
||||
// getTestKey returns a public key, and its base64 encoding.
|
||||
func getTestKey() (PublicKey, string) { |
||||
k := testPublicKeys["rsa"] |
||||
|
||||
b := &bytes.Buffer{} |
||||
e := base64.NewEncoder(base64.StdEncoding, b) |
||||
e.Write(k.Marshal()) |
||||
e.Close() |
||||
|
||||
return k, b.String() |
||||
} |
||||
|
||||
func TestMarshalParsePublicKey(t *testing.T) { |
||||
pub, pubSerialized := getTestKey() |
||||
line := fmt.Sprintf("%s %s user@host", pub.Type(), pubSerialized) |
||||
|
||||
authKeys := MarshalAuthorizedKey(pub) |
||||
actualFields := strings.Fields(string(authKeys)) |
||||
if len(actualFields) == 0 { |
||||
t.Fatalf("failed authKeys: %v", authKeys) |
||||
} |
||||
|
||||
// drop the comment
|
||||
expectedFields := strings.Fields(line)[0:2] |
||||
|
||||
if !reflect.DeepEqual(actualFields, expectedFields) { |
||||
t.Errorf("got %v, expected %v", actualFields, expectedFields) |
||||
} |
||||
|
||||
actPub, _, _, _, err := ParseAuthorizedKey([]byte(line)) |
||||
if err != nil { |
||||
t.Fatalf("cannot parse %v: %v", line, err) |
||||
} |
||||
if !reflect.DeepEqual(actPub, pub) { |
||||
t.Errorf("got %v, expected %v", actPub, pub) |
||||
} |
||||
} |
||||
|
||||
type authResult struct { |
||||
pubKey PublicKey |
||||
options []string |
||||
comments string |
||||
rest string |
||||
ok bool |
||||
} |
||||
|
||||
func testAuthorizedKeys(t *testing.T, authKeys []byte, expected []authResult) { |
||||
rest := authKeys |
||||
var values []authResult |
||||
for len(rest) > 0 { |
||||
var r authResult |
||||
var err error |
||||
r.pubKey, r.comments, r.options, rest, err = ParseAuthorizedKey(rest) |
||||
r.ok = (err == nil) |
||||
t.Log(err) |
||||
r.rest = string(rest) |
||||
values = append(values, r) |
||||
} |
||||
|
||||
if !reflect.DeepEqual(values, expected) { |
||||
t.Errorf("got %#v, expected %#v", values, expected) |
||||
} |
||||
} |
||||
|
||||
func TestAuthorizedKeyBasic(t *testing.T) { |
||||
pub, pubSerialized := getTestKey() |
||||
line := "ssh-rsa " + pubSerialized + " user@host" |
||||
testAuthorizedKeys(t, []byte(line), |
||||
[]authResult{ |
||||
{pub, nil, "user@host", "", true}, |
||||
}) |
||||
} |
||||
|
||||
func TestAuth(t *testing.T) { |
||||
pub, pubSerialized := getTestKey() |
||||
authWithOptions := []string{ |
||||
`# comments to ignore before any keys...`, |
||||
``, |
||||
`env="HOME=/home/root",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`, |
||||
`# comments to ignore, along with a blank line`, |
||||
``, |
||||
`env="HOME=/home/root2" ssh-rsa ` + pubSerialized + ` user2@host2`, |
||||
``, |
||||
`# more comments, plus a invalid entry`, |
||||
`ssh-rsa data-that-will-not-parse user@host3`, |
||||
} |
||||
for _, eol := range []string{"\n", "\r\n"} { |
||||
authOptions := strings.Join(authWithOptions, eol) |
||||
rest2 := strings.Join(authWithOptions[3:], eol) |
||||
rest3 := strings.Join(authWithOptions[6:], eol) |
||||
testAuthorizedKeys(t, []byte(authOptions), []authResult{ |
||||
{pub, []string{`env="HOME=/home/root"`, "no-port-forwarding"}, "user@host", rest2, true}, |
||||
{pub, []string{`env="HOME=/home/root2"`}, "user2@host2", rest3, true}, |
||||
{nil, nil, "", "", false}, |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestAuthWithQuotedSpaceInEnv(t *testing.T) { |
||||
pub, pubSerialized := getTestKey() |
||||
authWithQuotedSpaceInEnv := []byte(`env="HOME=/home/root dir",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`) |
||||
testAuthorizedKeys(t, []byte(authWithQuotedSpaceInEnv), []authResult{ |
||||
{pub, []string{`env="HOME=/home/root dir"`, "no-port-forwarding"}, "user@host", "", true}, |
||||
}) |
||||
} |
||||
|
||||
func TestAuthWithQuotedCommaInEnv(t *testing.T) { |
||||
pub, pubSerialized := getTestKey() |
||||
authWithQuotedCommaInEnv := []byte(`env="HOME=/home/root,dir",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`) |
||||
testAuthorizedKeys(t, []byte(authWithQuotedCommaInEnv), []authResult{ |
||||
{pub, []string{`env="HOME=/home/root,dir"`, "no-port-forwarding"}, "user@host", "", true}, |
||||
}) |
||||
} |
||||
|
||||
func TestAuthWithQuotedQuoteInEnv(t *testing.T) { |
||||
pub, pubSerialized := getTestKey() |
||||
authWithQuotedQuoteInEnv := []byte(`env="HOME=/home/\"root dir",no-port-forwarding` + "\t" + `ssh-rsa` + "\t" + pubSerialized + ` user@host`) |
||||
authWithDoubleQuotedQuote := []byte(`no-port-forwarding,env="HOME=/home/ \"root dir\"" ssh-rsa ` + pubSerialized + "\t" + `user@host`) |
||||
testAuthorizedKeys(t, []byte(authWithQuotedQuoteInEnv), []authResult{ |
||||
{pub, []string{`env="HOME=/home/\"root dir"`, "no-port-forwarding"}, "user@host", "", true}, |
||||
}) |
||||
|
||||
testAuthorizedKeys(t, []byte(authWithDoubleQuotedQuote), []authResult{ |
||||
{pub, []string{"no-port-forwarding", `env="HOME=/home/ \"root dir\""`}, "user@host", "", true}, |
||||
}) |
||||
} |
||||
|
||||
func TestAuthWithInvalidSpace(t *testing.T) { |
||||
_, pubSerialized := getTestKey() |
||||
authWithInvalidSpace := []byte(`env="HOME=/home/root dir", no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host |
||||
#more to follow but still no valid keys`) |
||||
testAuthorizedKeys(t, []byte(authWithInvalidSpace), []authResult{ |
||||
{nil, nil, "", "", false}, |
||||
}) |
||||
} |
||||
|
||||
func TestAuthWithMissingQuote(t *testing.T) { |
||||
pub, pubSerialized := getTestKey() |
||||
authWithMissingQuote := []byte(`env="HOME=/home/root,no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host |
||||
env="HOME=/home/root",shared-control ssh-rsa ` + pubSerialized + ` user@host`) |
||||
|
||||
testAuthorizedKeys(t, []byte(authWithMissingQuote), []authResult{ |
||||
{pub, []string{`env="HOME=/home/root"`, `shared-control`}, "user@host", "", true}, |
||||
}) |
||||
} |
||||
|
||||
func TestInvalidEntry(t *testing.T) { |
||||
authInvalid := []byte(`ssh-rsa`) |
||||
_, _, _, _, err := ParseAuthorizedKey(authInvalid) |
||||
if err == nil { |
||||
t.Errorf("got valid entry for %q", authInvalid) |
||||
} |
||||
} |
@ -0,0 +1,57 @@
|
||||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
// Message authentication support
|
||||
|
||||
import ( |
||||
"crypto/hmac" |
||||
"crypto/sha1" |
||||
"crypto/sha256" |
||||
"hash" |
||||
) |
||||
|
||||
type macMode struct { |
||||
keySize int |
||||
new func(key []byte) hash.Hash |
||||
} |
||||
|
||||
// truncatingMAC wraps around a hash.Hash and truncates the output digest to
|
||||
// a given size.
|
||||
type truncatingMAC struct { |
||||
length int |
||||
hmac hash.Hash |
||||
} |
||||
|
||||
func (t truncatingMAC) Write(data []byte) (int, error) { |
||||
return t.hmac.Write(data) |
||||
} |
||||
|
||||
func (t truncatingMAC) Sum(in []byte) []byte { |
||||
out := t.hmac.Sum(in) |
||||
return out[:len(in)+t.length] |
||||
} |
||||
|
||||
func (t truncatingMAC) Reset() { |
||||
t.hmac.Reset() |
||||
} |
||||
|
||||
func (t truncatingMAC) Size() int { |
||||
return t.length |
||||
} |
||||
|
||||
func (t truncatingMAC) BlockSize() int { return t.hmac.BlockSize() } |
||||
|
||||
var macModes = map[string]*macMode{ |
||||
"hmac-sha2-256": {32, func(key []byte) hash.Hash { |
||||
return hmac.New(sha256.New, key) |
||||
}}, |
||||
"hmac-sha1": {20, func(key []byte) hash.Hash { |
||||
return hmac.New(sha1.New, key) |
||||
}}, |
||||
"hmac-sha1-96": {20, func(key []byte) hash.Hash { |
||||
return truncatingMAC{12, hmac.New(sha1.New, key)} |
||||
}}, |
||||
} |
@ -0,0 +1,110 @@
|
||||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ssh |
||||
|
||||
import ( |
||||
"io" |
||||
"sync" |
||||
"testing" |
||||
) |
||||
|
||||
// An in-memory packetConn. It is safe to call Close and writePacket
|
||||
// from different goroutines.
|
||||
type memTransport struct { |
||||
eof bool |
||||
pending [][]byte |
||||
write *memTransport |
||||
sync.Mutex |
||||
*sync.Cond |
||||
} |
||||
|
||||
func (t *memTransport) readPacket() ([]byte, error) { |
||||
t.Lock() |
||||
defer t.Unlock() |
||||
for { |
||||
if len(t.pending) > 0 { |
||||
r := t.pending[0] |
||||
t.pending = t.pending[1:] |
||||
return r, nil |
||||
} |
||||
if t.eof { |
||||
return nil, io.EOF |
||||
} |
||||
t.Cond.Wait() |
||||
} |
||||
} |
||||
|
||||
func (t *memTransport) closeSelf() error { |
||||
t.Lock() |
||||
defer t.Unlock() |
||||
if t.eof { |
||||
return io.EOF |
||||
} |
||||
t.eof = true |
||||
t.Cond.Broadcast() |
||||
return nil |
||||
} |
||||
|
||||
func (t *memTransport) Close() error { |
||||
err := t.write.closeSelf() |
||||
t.closeSelf() |
||||
return err |
||||
} |
||||
|
||||
func (t *memTransport) writePacket(p []byte) error { |
||||
t.write.Lock() |
||||
defer t.write.Unlock() |
||||
if t.write.eof { |
||||
return io.EOF |
||||
} |
||||
c := make([]byte, len(p)) |
||||
copy(c, p) |
||||
t.write.pending = append(t.write.pending, c) |
||||
t.write.Cond.Signal() |
||||
return nil |
||||
} |
||||
|
||||
func memPipe() (a, b packetConn) { |
||||
t1 := memTransport{} |
||||
t2 := memTransport{} |
||||
t1.write = &t2 |
||||
t2.write = &t1 |
||||
t1.Cond = sync.NewCond(&t1.Mutex) |
||||
t2.Cond = sync.NewCond(&t2.Mutex) |
||||
return &t1, &t2 |
||||
} |
||||
|
||||
func TestMemPipe(t *testing.T) { |
||||
a, b := memPipe() |
||||
if err := a.writePacket([]byte{42}); err != nil { |
||||
t.Fatalf("writePacket: %v", err) |
||||
} |
||||
if err := a.Close(); err != nil { |
||||
t.Fatal("Close: ", err) |
||||
} |
||||
p, err := b.readPacket() |
||||
if err != nil { |
||||
t.Fatal("readPacket: ", err) |
||||
} |
||||
if len(p) != 1 || p[0] != 42 { |
||||
t.Fatalf("got %v, want {42}", p) |
||||
} |
||||
p, err = b.readPacket() |
||||
if err != io.EOF { |
||||
t.Fatalf("got %v, %v, want EOF", p, err) |
||||
} |
||||
} |
||||
|
||||
func TestDoubleClose(t *testing.T) { |
||||
a, _ := memPipe() |
||||
err := a.Close() |
||||
if err != nil { |
||||
t.Errorf("Close: %v", err) |
||||
} |
||||
err = a.Close() |
||||
if err != io.EOF { |
||||
t.Errorf("expect EOF on double close.") |
||||
} |
||||
} |
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue