package ssh import ( "bytes" "fmt" "io" "os" "os/user" "path/filepath" "strings" "sync" "time" "be.ems/src/framework/logger" "be.ems/src/framework/utils/cmd" gosftp "github.com/pkg/sftp" gossh "golang.org/x/crypto/ssh" ) // ConnSSH 连接SSH对象 type ConnSSH struct { User string `json:"user"` // 主机用户名 Addr string `json:"addr"` // 主机地址 Port int64 `json:"port"` // SSH端口 AuthMode string `json:"authMode"` // 认证模式(0密码 1主机私钥) Password string `json:"password"` // 认证密码 PrivateKey string `json:"privateKey"` // 认证私钥 PassPhrase string `json:"passPhrase"` // 认证私钥密码 DialTimeOut time.Duration `json:"dialTimeOut"` // 连接超时断开 Client *gossh.Client `json:"client"` LastResult string `json:"lastResult"` // 记最后一次执行命令的结果 } // NewClient 创建SSH客户端 func (c *ConnSSH) NewClient() (*ConnSSH, error) { // IPV6地址协议 proto := "tcp" if strings.Contains(c.Addr, ":") { proto = "tcp6" c.Addr = fmt.Sprintf("[%s]", c.Addr) } addr := fmt.Sprintf("%s:%d", c.Addr, c.Port) // ssh客户端配置 config := &gossh.ClientConfig{} config.SetDefaults() config.HostKeyCallback = gossh.InsecureIgnoreHostKey() config.User = c.User // 默认等待5s if c.DialTimeOut == 0 { c.DialTimeOut = 5 * time.Second } config.Timeout = c.DialTimeOut // 认证模式-0密码 1私钥 if c.AuthMode == "1" { var signer gossh.Signer var err error if len(c.PassPhrase) != 0 { signer, err = gossh.ParsePrivateKeyWithPassphrase([]byte(c.PrivateKey), []byte(c.PassPhrase)) } else { signer, err = gossh.ParsePrivateKey([]byte(c.PrivateKey)) } if err != nil { logger.Errorf("NewClient parse private key => %s", err.Error()) return nil, err } config.Auth = []gossh.AuthMethod{gossh.PublicKeys(signer)} } else { config.Auth = []gossh.AuthMethod{gossh.Password(c.Password)} } client, err := gossh.Dial(proto, addr, config) if nil != err { logger.Errorf("NewClient dial => %s", err.Error()) return c, err } c.Client = client return c, nil } // Close 关闭当前SSH客户端 func (c *ConnSSH) Close() { if c.Client != nil { c.Client.Close() } } // NewClientByLocalPrivate 创建SSH客户端-本地私钥(~/.ssh/id_rsa)直连 // // ssh.ConnSSH{ // User: "user", // Addr: "192.168.x.x", // Port: body.Port, // } func (c *ConnSSH) NewClientByLocalPrivate() (*ConnSSH, error) { c.Port = 22 c.AuthMode = "1" privateKey, err := c.CurrentUserRsaKey(false) if err != nil { return nil, err } c.PrivateKey = privateKey return c.NewClient() } // RunCMD 执行单次命令 func (c *ConnSSH) RunCMD(cmd string) (string, error) { if c.Client == nil { if _, err := c.NewClient(); err != nil { return "", err } } session, err := c.Client.NewSession() if err != nil { logger.Errorf("RunCMD failed to create session: => %s", err.Error()) return "", err } defer session.Close() buf, err := session.CombinedOutput(cmd) if err != nil { logger.Errorf("RunCMD failed run command: => %s", err.Error()) } c.LastResult = string(buf) return c.LastResult, err } // SendToAuthorizedKeys 发送当前用户私钥到远程服务器进行授权密钥 func (c *ConnSSH) SendToAuthorizedKeys() error { publicKey, err := c.CurrentUserRsaKey(true) if err != nil { return err } authorizedKeysEntry := fmt.Sprintln(strings.TrimSpace(publicKey)) cmdStrArr := []string{ fmt.Sprintf("sudo chown %s:%s /home/%s/.ssh && sudo chmod 700 /home/%s/.ssh", c.User, c.User, c.User, c.User), fmt.Sprintf("sudo chown %s:%s /home/%s/.ssh/authorized_keys && sudo chmod 600 /home/%s/.ssh/authorized_keys", c.User, c.User, c.User, c.User), fmt.Sprintf("sudo echo '%s' >> ~/.ssh/authorized_keys", authorizedKeysEntry), } _, err = c.RunCMD(strings.Join(cmdStrArr, " && ")) if err != nil { logger.Errorf("SendAuthorizedKeys echo err %s", err.Error()) return err } return nil } // CurrentUserRsaKey 当前用户OMC使用的RSA私钥 // 默认读取私钥 // ssh-keygen -t rsa -P "" -f ~/.ssh/id_rsa // ssh-keygen -y -f ~/.ssh/id_rsa > ~/.ssh/id_rsa.pub func (c *ConnSSH) CurrentUserRsaKey(publicKey bool) (string, error) { usr, err := user.Current() if err != nil { logger.Errorf("CurrentUserRsaKey get => %s", err.Error()) return "", err } // 是否存在私钥并创建 keyPath := fmt.Sprintf("%s/.ssh/id_rsa", usr.HomeDir) if _, err := os.Stat(keyPath); err != nil { _, err2 := cmd.ExecWithCheck("ssh-keygen", "-t", "rsa", "-P", "", "-f", keyPath) if err2 != nil { logger.Errorf("CurrentUserPrivateKey ssh-keygen [%s] rsa => %s", usr.Username, err2.Error()) } } // 读取用户默认的文件 if publicKey { keyPath = keyPath + ".pub" } key, err := os.ReadFile(keyPath) if err != nil { logger.Errorf("CurrentUserRsaKey [%s] read => %s", usr.Username, err.Error()) return "", fmt.Errorf("read file %s fail", keyPath) } return string(key), nil } // NewClientSession 创建SSH客户端会话对象 func (c *ConnSSH) NewClientSession(cols, rows int) (*SSHClientSession, error) { sshSession, err := c.Client.NewSession() if err != nil { return nil, err } stdin, err := sshSession.StdinPipe() if err != nil { return nil, err } comboWriter := new(singleWriter) sshSession.Stdout = comboWriter sshSession.Stderr = comboWriter modes := gossh.TerminalModes{ gossh.ECHO: 1, gossh.TTY_OP_ISPEED: 14400, gossh.TTY_OP_OSPEED: 14400, } if err := sshSession.RequestPty("xterm", rows, cols, modes); err != nil { return nil, err } if err := sshSession.Shell(); err != nil { return nil, err } return &SSHClientSession{ Stdin: stdin, Stdout: comboWriter, Session: sshSession, }, nil } // SSHClientSession SSH客户端会话对象 type SSHClientSession struct { Stdin io.WriteCloser Stdout *singleWriter Session *gossh.Session } // Close 关闭会话 func (s *SSHClientSession) Close() { if s.Stdin != nil { s.Stdin.Close() } if s.Stdout != nil { s.Stdout = nil } if s.Session != nil { s.Session.Close() } } // Write 写入命令 回车(\n)才会执行 func (s *SSHClientSession) Write(cmd string) (int, error) { if s.Stdin == nil { return 0, fmt.Errorf("ssh client session is nil to content write failed") } return s.Stdin.Write([]byte(cmd)) } // Read 读取结果 func (s *SSHClientSession) Read() []byte { if s.Stdout == nil { return []byte{} } bs := s.Stdout.Bytes() if len(bs) > 0 { s.Stdout.Reset() return bs } return []byte{} } // singleWriter SSH客户端会话消息 type singleWriter struct { b bytes.Buffer mu sync.Mutex } func (w *singleWriter) Write(p []byte) (int, error) { w.mu.Lock() defer w.mu.Unlock() return w.b.Write(p) } func (w *singleWriter) Bytes() []byte { w.mu.Lock() defer w.mu.Unlock() return w.b.Bytes() } func (w *singleWriter) Reset() { w.mu.Lock() defer w.mu.Unlock() w.b.Reset() } // NewClientSFTP 创建SSH客户端SFTP对象 func (c *ConnSSH) NewClientSFTP() (*SSHClientSFTP, error) { sftpClient, err := gosftp.NewClient(c.Client) if err != nil { logger.Errorf("NewClientSFTP failed to create sftp: => %s", err.Error()) return nil, err } return &SSHClientSFTP{ Client: sftpClient, }, nil } // SSHClientSFTP SSH客户端SFTP对象 type SSHClientSFTP struct { Client *gosftp.Client } // Close 关闭会话 func (s *SSHClientSFTP) Close() { if s.Client != nil { s.Client.Close() } } // CopyDirRemoteToLocal 复制目录-远程到本地 func (s *SSHClientSFTP) CopyDirRemoteToLocal(remoteDir, localDir string) error { // 列出远程目录中的文件和子目录 remoteFiles, err := s.Client.ReadDir(remoteDir) if err != nil { logger.Errorf("CopyDirRemoteToLocal failed to reading remote directory %s: => %s", remoteDir, err.Error()) return err } // 创建本地目录 err = os.MkdirAll(localDir, 0755) if err != nil { logger.Errorf("CopyDirRemoteToLocal failed to creating local directory %s: => %s", localDir, err.Error()) return err } // 遍历远程文件和子目录并复制到本地 for _, remoteFile := range remoteFiles { remotePath := filepath.Join(remoteDir, remoteFile.Name()) localPath := filepath.Join(localDir, remoteFile.Name()) if remoteFile.IsDir() { // 如果是子目录,则递归复制子目录 err = s.CopyDirRemoteToLocal(remotePath, localPath) if err != nil { logger.Errorf("CopyDirRemoteToLocal failed to copying remote directory %s: => %s", remotePath, err.Error()) continue } } else { // 如果是文件,则复制文件内容 remoteFile, err := s.Client.Open(remotePath) if err != nil { logger.Errorf("CopyDirRemoteToLocal failed to opening remote file %s: => %s", remotePath, err.Error()) continue } defer remoteFile.Close() localFile, err := os.Create(localPath) if err != nil { logger.Errorf("CopyDirRemoteToLocal failed to creating local file %s: => %s", localPath, err.Error()) continue } defer localFile.Close() _, err = io.Copy(localFile, remoteFile) if err != nil { logger.Errorf("CopyDirRemoteToLocal failed to copying file contents from %s to %s: => %s", remotePath, localPath, err.Error()) continue } } } return nil } // CopyDirRemoteToLocal 复制目录-本地到远程 func (s *SSHClientSFTP) CopyDirLocalToRemote(localDir, remoteDir string) error { // 创建远程目录 err := s.Client.MkdirAll(remoteDir) if err != nil { logger.Errorf("CopyDirLocalToRemote failed to creating remote directory %s: => %s", remoteDir, err.Error()) return err } // 遍历本地目录中的文件和子目录并复制到远程 err = filepath.Walk(localDir, func(localPath string, info os.FileInfo, err error) error { if err != nil { return err } // 生成远程路径 remotePath := filepath.Join(remoteDir, localPath[len(localDir):]) if info.IsDir() { // 如果是子目录,则创建远程目录 err := s.Client.MkdirAll(remotePath) if err != nil { logger.Errorf("CopyDirLocalToRemote failed to creating remote directory %s: => %s", remotePath, err.Error()) return nil } } else { // 如果是文件,则复制文件内容 localFile, err := os.Open(localPath) if err != nil { logger.Errorf("CopyDirLocalToRemote failed to opening local file %s: => %s", localPath, err.Error()) return nil } defer localFile.Close() remoteFile, err := s.Client.Create(remotePath) if err != nil { logger.Errorf("CopyDirLocalToRemote failed to creating remote file %s: => %s", remotePath, err.Error()) return nil } defer remoteFile.Close() _, err = io.Copy(remoteFile, localFile) if err != nil { logger.Errorf("CopyDirLocalToRemote failed to copying file contents from %s to %s: => %s", localPath, remotePath, err.Error()) return nil } } return nil }) if err != nil { logger.Errorf("CopyDirLocalToRemote failed to walking local directory: => %s", err.Error()) return err } return nil } // CopyDirRemoteToLocal 复制文件-远程到本地 func (s *SSHClientSFTP) CopyFileRemoteToLocal(remotePath, localPath string) error { // 打开远程文件 remoteFile, err := s.Client.Open(remotePath) if err != nil { logger.Errorf("CopyFileRemoteToLocal failed to opening remote file: => %s", err.Error()) return err } defer remoteFile.Close() if err := os.MkdirAll(filepath.Dir(localPath), 0750); err != nil { return err } // 如果目标文件已经存在,先将目标文件重命名 // if info, err := os.Stat(localPath); err == nil && !info.IsDir() { // ext := filepath.Ext(localPath) // name := localPath[0 : len(localPath)-len(ext)] // newName := fmt.Sprintf("%s-%s%s", name, time.Now().Format("20060102_150405"), ext) // err := os.Rename(localPath, newName) // if err != nil { // return err // } // } // 创建本地文件 localFile, err := os.Create(localPath) if err != nil { logger.Errorf("CopyFileRemoteToLocal failed to creating local file: => %s", err.Error()) return err } defer localFile.Close() // 复制文件内容 _, err = io.Copy(localFile, remoteFile) if err != nil { logger.Errorf("CopyFileRemoteToLocal failed to copying contents: => %s", err.Error()) return err } return nil } // CopyDirRemoteToLocal 复制文件-本地到远程 func (s *SSHClientSFTP) CopyFileLocalToRemote(localPath, remotePath string) error { // 打开本地文件 localFile, err := os.Open(localPath) if err != nil { logger.Errorf("CopyFileLocalToRemote failed to opening local file: => %s", err.Error()) return err } defer localFile.Close() // 创建远程文件 remoteFile, err := s.Client.Create(remotePath) if err != nil { logger.Errorf("CopyFileLocalToRemote failed to creating remote file: => %s", err.Error()) return err } defer remoteFile.Close() // 复制文件内容 _, err = io.Copy(remoteFile, localFile) if err != nil { logger.Errorf("CopyFileLocalToRemote failed to copying contents: => %s", err.Error()) return err } return nil }