Converted xsnet Read() ctrlStatOp logic to switch

This commit is contained in:
Russ Magee 2022-09-25 11:33:09 -07:00
parent 667328a91c
commit b2e43f4bad
1 changed files with 28 additions and 13 deletions

View File

@ -1195,6 +1195,7 @@ func (hc Conn) Read(b []byte) (n int, err error) {
var hmacIn [HMAC_CHK_SZ]uint8 var hmacIn [HMAC_CHK_SZ]uint8
var payloadLen uint32 var payloadLen uint32
//------------- Read ctrl/status opcode --------------------
// Read ctrl/status opcode (CSOHmacInvalid on hmac mismatch) // Read ctrl/status opcode (CSOHmacInvalid on hmac mismatch)
err = binary.Read(*hc.c, binary.BigEndian, &ctrlStatOp) err = binary.Read(*hc.c, binary.BigEndian, &ctrlStatOp)
if err != nil { if err != nil {
@ -1215,7 +1216,9 @@ func (hc Conn) Read(b []byte) (n int, err error) {
hc.Close() hc.Close()
return 0, errors.New("** ALERT - remote end detected HMAC mismatch - possible channel tampering **") return 0, errors.New("** ALERT - remote end detected HMAC mismatch - possible channel tampering **")
} }
//----------------------------------------------------------
//------------------ Read HMAC len ------------------------
// Read the hmac and payload len first // Read the hmac and payload len first
err = binary.Read(*hc.c, binary.BigEndian, &hmacIn) err = binary.Read(*hc.c, binary.BigEndian, &hmacIn)
if err != nil { if err != nil {
@ -1230,7 +1233,9 @@ func (hc Conn) Read(b []byte) (n int, err error) {
logger.LogDebug(etxt) logger.LogDebug(etxt)
return 0, errors.New(etxt) return 0, errors.New(etxt)
} }
//----------------------------------------------------------
//------------------ Read Payload len ---------------------
err = binary.Read(*hc.c, binary.BigEndian, &payloadLen) err = binary.Read(*hc.c, binary.BigEndian, &payloadLen)
if err != nil { if err != nil {
if err.Error() == "EOF" { if err.Error() == "EOF" {
@ -1244,6 +1249,7 @@ func (hc Conn) Read(b []byte) (n int, err error) {
logger.LogDebug(etxt) logger.LogDebug(etxt)
return 0, errors.New(etxt) return 0, errors.New(etxt)
} }
//----------------------------------------------------------
if payloadLen > MAX_PAYLOAD_LEN { if payloadLen > MAX_PAYLOAD_LEN {
logger.LogDebug(fmt.Sprintf("[Insane payloadLen:%v]\n", payloadLen)) logger.LogDebug(fmt.Sprintf("[Insane payloadLen:%v]\n", payloadLen))
@ -1251,6 +1257,7 @@ func (hc Conn) Read(b []byte) (n int, err error) {
return 1, errors.New("Insane payloadLen") return 1, errors.New("Insane payloadLen")
} }
//-------------------- Read Payload ------------------------
var payloadBytes = make([]byte, payloadLen) var payloadBytes = make([]byte, payloadLen)
n, err = io.ReadFull(*hc.c, payloadBytes) n, err = io.ReadFull(*hc.c, payloadBytes)
if err != nil { if err != nil {
@ -1265,12 +1272,14 @@ func (hc Conn) Read(b []byte) (n int, err error) {
logger.LogDebug(etxt) logger.LogDebug(etxt)
return 0, errors.New(etxt) return 0, errors.New(etxt)
} }
//----------------------------------------------------------
if hc.logCipherText { if hc.logCipherText {
log.Printf(" <:ctext:\r\n%s\r\n", hex.Dump(payloadBytes[:n])) log.Printf(" <:ctext:\r\n%s\r\n", hex.Dump(payloadBytes[:n]))
} }
//fmt.Printf(" <:ctext:\r\n%s\r\n", hex.Dump(payloadBytes[:n])) //fmt.Printf(" <:ctext:\r\n%s\r\n", hex.Dump(payloadBytes[:n]))
//---------------- Verify Payload via HMAC -----------------
hc.rm.Write(payloadBytes) // Calc hmac on received data hc.rm.Write(payloadBytes) // Calc hmac on received data
hTmp := hc.rm.Sum(nil)[0:HMAC_CHK_SZ] hTmp := hc.rm.Sum(nil)[0:HMAC_CHK_SZ]
//log.Printf("<%04x) HMAC:(i)%s (c)%02x\r\n", decryptN, hex.EncodeToString([]byte(hmacIn[0:])), hTmp) //log.Printf("<%04x) HMAC:(i)%s (c)%02x\r\n", decryptN, hex.EncodeToString([]byte(hmacIn[0:])), hTmp)
@ -1280,7 +1289,9 @@ func (hc Conn) Read(b []byte) (n int, err error) {
logger.LogDebug(fmt.Sprintln("** ALERT - detected HMAC mismatch, possible channel tampering **")) logger.LogDebug(fmt.Sprintln("** ALERT - detected HMAC mismatch, possible channel tampering **"))
_, _ = (*hc.c).Write([]byte{CSOHmacInvalid}) _, _ = (*hc.c).Write([]byte{CSOHmacInvalid})
} }
//----------------------------------------------------------
//------------------- Decrypt Payload ----------------------
db := bytes.NewBuffer(payloadBytes[:n]) //copying payloadBytes to db db := bytes.NewBuffer(payloadBytes[:n]) //copying payloadBytes to db
// The StreamReader acts like a pipe, decrypting // The StreamReader acts like a pipe, decrypting
// whatever is available and forwarding the result // whatever is available and forwarding the result
@ -1289,6 +1300,7 @@ func (hc Conn) Read(b []byte) (n int, err error) {
// The caller isn't necessarily reading the full payload so we need // The caller isn't necessarily reading the full payload so we need
// to decrypt to an intermediate buffer, draining it on demand of caller // to decrypt to an intermediate buffer, draining it on demand of caller
decryptN, err := rs.Read(payloadBytes) decryptN, err := rs.Read(payloadBytes)
//----------------------------------------------------------
if hc.logPlainText { if hc.logPlainText {
log.Printf(" <:ptext:\r\n%s\r\n", hex.Dump(payloadBytes[:n])) log.Printf(" <:ptext:\r\n%s\r\n", hex.Dump(payloadBytes[:n]))
@ -1297,6 +1309,7 @@ func (hc Conn) Read(b []byte) (n int, err error) {
log.Println("xsnet.Read():", err) log.Println("xsnet.Read():", err)
//panic(err) //panic(err)
} else { } else {
//------------ Discard Padding ---------------------
// Padding: Read padSide, padLen, (padding | d) or (d | padding) // Padding: Read padSide, padLen, (padding | d) or (d | padding)
padSide := payloadBytes[0] padSide := payloadBytes[0]
padLen := payloadBytes[1] padLen := payloadBytes[1]
@ -1307,15 +1320,17 @@ func (hc Conn) Read(b []byte) (n int, err error) {
} else { } else {
payloadBytes = payloadBytes[0 : len(payloadBytes)-int(padLen)] payloadBytes = payloadBytes[0 : len(payloadBytes)-int(padLen)]
} }
//--------------------------------------------------
// Throw away pkt if it's chaff (ie., caller to Read() won't see this data) switch ctrlStatOp {
if ctrlStatOp == CSOChaff { case CSOChaff:
// Throw away pkt if it's chaff (ie., caller to Read() won't see this data)
log.Printf("[Chaff pkt, discarded (len %d)]\n", decryptN) log.Printf("[Chaff pkt, discarded (len %d)]\n", decryptN)
} else if ctrlStatOp == CSOTermSize { case CSOTermSize:
fmt.Sscanf(string(payloadBytes), "%d %d", &hc.Rows, &hc.Cols) fmt.Sscanf(string(payloadBytes), "%d %d", &hc.Rows, &hc.Cols)
log.Printf("[TermSize pkt: rows %v cols %v]\n", hc.Rows, hc.Cols) log.Printf("[TermSize pkt: rows %v cols %v]\n", hc.Rows, hc.Cols)
hc.WinCh <- WinSize{hc.Rows, hc.Cols} hc.WinCh <- WinSize{hc.Rows, hc.Cols}
} else if ctrlStatOp == CSOExitStatus { case CSOExitStatus:
if len(payloadBytes) > 0 { if len(payloadBytes) > 0 {
hc.SetStatus(CSOType(binary.BigEndian.Uint32(payloadBytes))) hc.SetStatus(CSOType(binary.BigEndian.Uint32(payloadBytes)))
} else { } else {
@ -1323,7 +1338,7 @@ func (hc Conn) Read(b []byte) (n int, err error) {
hc.SetStatus(CSETruncCSO) hc.SetStatus(CSETruncCSO)
} }
hc.Close() hc.Close()
} else if ctrlStatOp == CSOTunSetup { case CSOTunSetup:
// server side tunnel setup in response to client // server side tunnel setup in response to client
lport := binary.BigEndian.Uint16(payloadBytes[0:2]) lport := binary.BigEndian.Uint16(payloadBytes[0:2])
rport := binary.BigEndian.Uint16(payloadBytes[2:4]) rport := binary.BigEndian.Uint16(payloadBytes[2:4])
@ -1335,7 +1350,7 @@ func (hc Conn) Read(b []byte) (n int, err error) {
logger.LogDebug(fmt.Sprintf("[Server] Got CSOTunSetup [%d:%d]", lport, rport)) logger.LogDebug(fmt.Sprintf("[Server] Got CSOTunSetup [%d:%d]", lport, rport))
} }
(*hc.tuns)[rport].Ctl <- 'd' // Dial() rport (*hc.tuns)[rport].Ctl <- 'd' // Dial() rport
} else if ctrlStatOp == CSOTunSetupAck { case CSOTunSetupAck:
lport := binary.BigEndian.Uint16(payloadBytes[0:2]) lport := binary.BigEndian.Uint16(payloadBytes[0:2])
rport := binary.BigEndian.Uint16(payloadBytes[2:4]) rport := binary.BigEndian.Uint16(payloadBytes[2:4])
if _, ok := (*hc.tuns)[rport]; !ok { if _, ok := (*hc.tuns)[rport]; !ok {
@ -1346,7 +1361,7 @@ func (hc Conn) Read(b []byte) (n int, err error) {
logger.LogDebug(fmt.Sprintf("[Client] Got CSOTunSetupAck [%d:%d]", lport, rport)) logger.LogDebug(fmt.Sprintf("[Client] Got CSOTunSetupAck [%d:%d]", lport, rport))
} }
(*hc.tuns)[rport].Ctl <- 'a' // Listen() for lport connection (*hc.tuns)[rport].Ctl <- 'a' // Listen() for lport connection
} else if ctrlStatOp == CSOTunRefused { case CSOTunRefused:
// client side receiving CSOTunRefused means the remote side // client side receiving CSOTunRefused means the remote side
// could not dial() rport. So we cannot yet listen() // could not dial() rport. So we cannot yet listen()
// for client-side on lport. // for client-side on lport.
@ -1358,7 +1373,7 @@ func (hc Conn) Read(b []byte) (n int, err error) {
} else { } else {
logger.LogDebug(fmt.Sprintf("[Client] CSOTunRefused on already-closed tun [%d:%d]", lport, rport)) logger.LogDebug(fmt.Sprintf("[Client] CSOTunRefused on already-closed tun [%d:%d]", lport, rport))
} }
} else if ctrlStatOp == CSOTunDisconn { case CSOTunDisconn:
// server side's rport has disconnected (server lost) // server side's rport has disconnected (server lost)
lport := binary.BigEndian.Uint16(payloadBytes[0:2]) lport := binary.BigEndian.Uint16(payloadBytes[0:2])
rport := binary.BigEndian.Uint16(payloadBytes[2:4]) rport := binary.BigEndian.Uint16(payloadBytes[2:4])
@ -1368,7 +1383,7 @@ func (hc Conn) Read(b []byte) (n int, err error) {
} else { } else {
logger.LogDebug(fmt.Sprintf("[Client] CSOTunDisconn on already-closed tun [%d:%d]", lport, rport)) logger.LogDebug(fmt.Sprintf("[Client] CSOTunDisconn on already-closed tun [%d:%d]", lport, rport))
} }
} else if ctrlStatOp == CSOTunHangup { case CSOTunHangup:
// client side's lport has hung up // client side's lport has hung up
lport := binary.BigEndian.Uint16(payloadBytes[0:2]) lport := binary.BigEndian.Uint16(payloadBytes[0:2])
rport := binary.BigEndian.Uint16(payloadBytes[2:4]) rport := binary.BigEndian.Uint16(payloadBytes[2:4])
@ -1378,7 +1393,7 @@ func (hc Conn) Read(b []byte) (n int, err error) {
} else { } else {
logger.LogDebug(fmt.Sprintf("[Server] CSOTunHangup to already-closed tun [%d:%d]", lport, rport)) logger.LogDebug(fmt.Sprintf("[Server] CSOTunHangup to already-closed tun [%d:%d]", lport, rport))
} }
} else if ctrlStatOp == CSOTunData { case CSOTunData:
lport := binary.BigEndian.Uint16(payloadBytes[0:2]) lport := binary.BigEndian.Uint16(payloadBytes[0:2])
rport := binary.BigEndian.Uint16(payloadBytes[2:4]) rport := binary.BigEndian.Uint16(payloadBytes[2:4])
//fmt.Printf("[Got CSOTunData: [lport %d:rport %d] data:%v\n", lport, rport, payloadBytes[4:]) //fmt.Printf("[Got CSOTunData: [lport %d:rport %d] data:%v\n", lport, rport, payloadBytes[4:])
@ -1391,7 +1406,7 @@ func (hc Conn) Read(b []byte) (n int, err error) {
} else { } else {
logger.LogDebug(fmt.Sprintf("[Attempt to write data to closed tun [%d:%d]", lport, rport)) logger.LogDebug(fmt.Sprintf("[Attempt to write data to closed tun [%d:%d]", lport, rport))
} }
} else if ctrlStatOp == CSOTunKeepAlive { case CSOTunKeepAlive:
// client side has sent keepalive for tunnels -- if client // client side has sent keepalive for tunnels -- if client
// dies or exits unexpectedly the absence of this will // dies or exits unexpectedly the absence of this will
// let the server know to hang up on Dial()ed server rports. // let the server know to hang up on Dial()ed server rports.
@ -1402,9 +1417,9 @@ func (hc Conn) Read(b []byte) (n int, err error) {
t.KeepAlive = 0 t.KeepAlive = 0
hc.Unlock() hc.Unlock()
} }
} else if ctrlStatOp == CSONone { case CSONone:
hc.dBuf.Write(payloadBytes) hc.dBuf.Write(payloadBytes)
} else { default:
logger.LogDebug(fmt.Sprintf("[Unknown CSOType:%d]", ctrlStatOp)) logger.LogDebug(fmt.Sprintf("[Unknown CSOType:%d]", ctrlStatOp))
} }
} }