mirror of https://github.com/gogits/gogs.git
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
356 lines
7.8 KiB
356 lines
7.8 KiB
// Copyright 2013 The Go Authors. All rights reserved. |
|
// Use of this source code is governed by a BSD-style |
|
// license that can be found in the LICENSE file. |
|
|
|
package ssh |
|
|
|
import ( |
|
"encoding/binary" |
|
"fmt" |
|
"io" |
|
"log" |
|
"sync" |
|
"sync/atomic" |
|
) |
|
|
|
// debugMux switches on verbose logging of connection-protocol traffic.
// When true, the mux logs each decoded packet and the error that ended its
// read loop, and every new mux is given a distinct channel-ID offset.
const (
	debugMux = false
)
|
|
|
// chanList is a thread safe channel list. |
|
type chanList struct { |
|
// protects concurrent access to chans |
|
sync.Mutex |
|
|
|
// chans are indexed by the local id of the channel, which the |
|
// other side should send in the PeersId field. |
|
chans []*channel |
|
|
|
// This is a debugging aid: it offsets all IDs by this |
|
// amount. This helps distinguish otherwise identical |
|
// server/client muxes |
|
offset uint32 |
|
} |
|
|
|
// Assigns a channel ID to the given channel. |
|
func (c *chanList) add(ch *channel) uint32 { |
|
c.Lock() |
|
defer c.Unlock() |
|
for i := range c.chans { |
|
if c.chans[i] == nil { |
|
c.chans[i] = ch |
|
return uint32(i) + c.offset |
|
} |
|
} |
|
c.chans = append(c.chans, ch) |
|
return uint32(len(c.chans)-1) + c.offset |
|
} |
|
|
|
// getChan returns the channel for the given ID. |
|
func (c *chanList) getChan(id uint32) *channel { |
|
id -= c.offset |
|
|
|
c.Lock() |
|
defer c.Unlock() |
|
if id < uint32(len(c.chans)) { |
|
return c.chans[id] |
|
} |
|
return nil |
|
} |
|
|
|
func (c *chanList) remove(id uint32) { |
|
id -= c.offset |
|
c.Lock() |
|
if id < uint32(len(c.chans)) { |
|
c.chans[id] = nil |
|
} |
|
c.Unlock() |
|
} |
|
|
|
// dropAll forgets all channels it knows, returning them in a slice. |
|
func (c *chanList) dropAll() []*channel { |
|
c.Lock() |
|
defer c.Unlock() |
|
var r []*channel |
|
|
|
for _, ch := range c.chans { |
|
if ch == nil { |
|
continue |
|
} |
|
r = append(r, ch) |
|
} |
|
c.chans = nil |
|
return r |
|
} |
|
|
|
// mux represents the state for the SSH connection protocol, which |
|
// multiplexes many channels onto a single packet transport. |
|
type mux struct { |
|
conn packetConn |
|
chanList chanList |
|
|
|
incomingChannels chan NewChannel |
|
|
|
globalSentMu sync.Mutex |
|
globalResponses chan interface{} |
|
incomingRequests chan *Request |
|
|
|
errCond *sync.Cond |
|
err error |
|
} |
|
|
|
// globalOff is a process-wide counter consulted only when debugMux is set:
// newMux bumps it atomically so each mux's chanList gets its own ID offset.
var (
	globalOff uint32
)
|
|
|
func (m *mux) Wait() error { |
|
m.errCond.L.Lock() |
|
defer m.errCond.L.Unlock() |
|
for m.err == nil { |
|
m.errCond.Wait() |
|
} |
|
return m.err |
|
} |
|
|
|
// newMux returns a mux that runs over the given connection. |
|
func newMux(p packetConn) *mux { |
|
m := &mux{ |
|
conn: p, |
|
incomingChannels: make(chan NewChannel, 16), |
|
globalResponses: make(chan interface{}, 1), |
|
incomingRequests: make(chan *Request, 16), |
|
errCond: newCond(), |
|
} |
|
if debugMux { |
|
m.chanList.offset = atomic.AddUint32(&globalOff, 1) |
|
} |
|
|
|
go m.loop() |
|
return m |
|
} |
|
|
|
func (m *mux) sendMessage(msg interface{}) error { |
|
p := Marshal(msg) |
|
return m.conn.writePacket(p) |
|
} |
|
|
|
func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) { |
|
if wantReply { |
|
m.globalSentMu.Lock() |
|
defer m.globalSentMu.Unlock() |
|
} |
|
|
|
if err := m.sendMessage(globalRequestMsg{ |
|
Type: name, |
|
WantReply: wantReply, |
|
Data: payload, |
|
}); err != nil { |
|
return false, nil, err |
|
} |
|
|
|
if !wantReply { |
|
return false, nil, nil |
|
} |
|
|
|
msg, ok := <-m.globalResponses |
|
if !ok { |
|
return false, nil, io.EOF |
|
} |
|
switch msg := msg.(type) { |
|
case *globalRequestFailureMsg: |
|
return false, msg.Data, nil |
|
case *globalRequestSuccessMsg: |
|
return true, msg.Data, nil |
|
default: |
|
return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg) |
|
} |
|
} |
|
|
|
// ackRequest must be called after processing a global request that |
|
// has WantReply set. |
|
func (m *mux) ackRequest(ok bool, data []byte) error { |
|
if ok { |
|
return m.sendMessage(globalRequestSuccessMsg{Data: data}) |
|
} |
|
return m.sendMessage(globalRequestFailureMsg{Data: data}) |
|
} |
|
|
|
// TODO(hanwen): Disconnect is a transport layer message. We should |
|
// probably send and receive Disconnect somewhere in the transport |
|
// code. |
|
|
|
// Disconnect sends a disconnect message. |
|
func (m *mux) Disconnect(reason uint32, message string) error { |
|
return m.sendMessage(disconnectMsg{ |
|
Reason: reason, |
|
Message: message, |
|
}) |
|
} |
|
|
|
func (m *mux) Close() error { |
|
return m.conn.Close() |
|
} |
|
|
|
// loop runs the connection machine. It will process packets until an |
|
// error is encountered. To synchronize on loop exit, use mux.Wait. |
|
func (m *mux) loop() { |
|
var err error |
|
for err == nil { |
|
err = m.onePacket() |
|
} |
|
|
|
for _, ch := range m.chanList.dropAll() { |
|
ch.close() |
|
} |
|
|
|
close(m.incomingChannels) |
|
close(m.incomingRequests) |
|
close(m.globalResponses) |
|
|
|
m.conn.Close() |
|
|
|
m.errCond.L.Lock() |
|
m.err = err |
|
m.errCond.Broadcast() |
|
m.errCond.L.Unlock() |
|
|
|
if debugMux { |
|
log.Println("loop exit", err) |
|
} |
|
} |
|
|
|
// onePacket reads and processes one packet. |
|
func (m *mux) onePacket() error { |
|
packet, err := m.conn.readPacket() |
|
if err != nil { |
|
return err |
|
} |
|
|
|
if debugMux { |
|
if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData { |
|
log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet)) |
|
} else { |
|
p, _ := decode(packet) |
|
log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet)) |
|
} |
|
} |
|
|
|
switch packet[0] { |
|
case msgNewKeys: |
|
// Ignore notification of key change. |
|
return nil |
|
case msgDisconnect: |
|
return m.handleDisconnect(packet) |
|
case msgChannelOpen: |
|
return m.handleChannelOpen(packet) |
|
case msgGlobalRequest, msgRequestSuccess, msgRequestFailure: |
|
return m.handleGlobalPacket(packet) |
|
} |
|
|
|
// assume a channel packet. |
|
if len(packet) < 5 { |
|
return parseError(packet[0]) |
|
} |
|
id := binary.BigEndian.Uint32(packet[1:]) |
|
ch := m.chanList.getChan(id) |
|
if ch == nil { |
|
return fmt.Errorf("ssh: invalid channel %d", id) |
|
} |
|
|
|
return ch.handlePacket(packet) |
|
} |
|
|
|
func (m *mux) handleDisconnect(packet []byte) error { |
|
var d disconnectMsg |
|
if err := Unmarshal(packet, &d); err != nil { |
|
return err |
|
} |
|
|
|
if debugMux { |
|
log.Printf("caught disconnect: %v", d) |
|
} |
|
return &d |
|
} |
|
|
|
func (m *mux) handleGlobalPacket(packet []byte) error { |
|
msg, err := decode(packet) |
|
if err != nil { |
|
return err |
|
} |
|
|
|
switch msg := msg.(type) { |
|
case *globalRequestMsg: |
|
m.incomingRequests <- &Request{ |
|
Type: msg.Type, |
|
WantReply: msg.WantReply, |
|
Payload: msg.Data, |
|
mux: m, |
|
} |
|
case *globalRequestSuccessMsg, *globalRequestFailureMsg: |
|
m.globalResponses <- msg |
|
default: |
|
panic(fmt.Sprintf("not a global message %#v", msg)) |
|
} |
|
|
|
return nil |
|
} |
|
|
|
// handleChannelOpen schedules a channel to be Accept()ed. |
|
func (m *mux) handleChannelOpen(packet []byte) error { |
|
var msg channelOpenMsg |
|
if err := Unmarshal(packet, &msg); err != nil { |
|
return err |
|
} |
|
|
|
if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { |
|
failMsg := channelOpenFailureMsg{ |
|
PeersId: msg.PeersId, |
|
Reason: ConnectionFailed, |
|
Message: "invalid request", |
|
Language: "en_US.UTF-8", |
|
} |
|
return m.sendMessage(failMsg) |
|
} |
|
|
|
c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData) |
|
c.remoteId = msg.PeersId |
|
c.maxRemotePayload = msg.MaxPacketSize |
|
c.remoteWin.add(msg.PeersWindow) |
|
m.incomingChannels <- c |
|
return nil |
|
} |
|
|
|
func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) { |
|
ch, err := m.openChannel(chanType, extra) |
|
if err != nil { |
|
return nil, nil, err |
|
} |
|
|
|
return ch, ch.incomingRequests, nil |
|
} |
|
|
|
func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) { |
|
ch := m.newChannel(chanType, channelOutbound, extra) |
|
|
|
ch.maxIncomingPayload = channelMaxPacket |
|
|
|
open := channelOpenMsg{ |
|
ChanType: chanType, |
|
PeersWindow: ch.myWindow, |
|
MaxPacketSize: ch.maxIncomingPayload, |
|
TypeSpecificData: extra, |
|
PeersId: ch.localId, |
|
} |
|
if err := m.sendMessage(open); err != nil { |
|
return nil, err |
|
} |
|
|
|
switch msg := (<-ch.msg).(type) { |
|
case *channelOpenConfirmMsg: |
|
return ch, nil |
|
case *channelOpenFailureMsg: |
|
return nil, &OpenChannelError{msg.Reason, msg.Message} |
|
default: |
|
return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg) |
|
} |
|
}
|
|
|