ssh远程批量执行shell的代码 发表于 2017-02-05 | 分类于 Java | 使用的是12345<dependency> <groupId>ch.ethz.ganymed</groupId> <artifactId>ganymed-ssh2</artifactId> <version>build251beta1</version> </dependency> 直接上代码:123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301public class ExecutionResult { private boolean success; private String errorMsg; private int returnCode; private String stdout; private String stderr; public static ExecutionResult successResult(int returnCode, String stdout, String stderr) { ExecutionResult result = new ExecutionResult(); result.setSuccess(true); result.setReturnCode(returnCode); result.setStdout(stdout); result.setStderr(stderr); return result; } public static ExecutionResult failResult(String errorMsg, int returnCode, String stdout, String stderr) { ExecutionResult result = new ExecutionResult(); result.setSuccess(false); result.setErrorMsg(errorMsg); result.setReturnCode(returnCode); result.setStdout(stdout); result.setStderr(stderr); return result; } public static ExecutionResult failResult(String errorMsg) { ExecutionResult result = new ExecutionResult(); result.setSuccess(false); result.setErrorMsg(errorMsg); return result; } public String getTaskErrorMsg() { boolean hasErrorMsg = false; StringBuilder errorMsgBuilder = new StringBuilder(); if (StringUtil.isNotBlank(errorMsg)) { errorMsgBuilder.append("errorMsg:"); errorMsgBuilder.append(errorMsg); hasErrorMsg = true; } if (StringUtil.isNotBlank(stderr)) { if (hasErrorMsg) { errorMsgBuilder.append(";"); } errorMsgBuilder.append("stderr:"); errorMsgBuilder.append(stderr); } return errorMsgBuilder.toString(); } public boolean isSuccess() { return success; } public void setSuccess(boolean success) { this.success = success; } public String getErrorMsg() { return errorMsg; } public void setErrorMsg(String errorMsg) { this.errorMsg = errorMsg; } public int getReturnCode() { return returnCode; } public void setReturnCode(int returnCode) { this.returnCode = returnCode; } public String getStdout() { return stdout; } public void setStdout(String stdout) { this.stdout = stdout; } public String getStderr() { return stderr; } public void setStderr(String stderr) { this.stderr = stderr; }}class SshTask implements Callable<ExecutionResult> { private static final Logger logger = LoggerFactory.getLogger(SshTask.class); private static final String STD_CHARSET = Charset.defaultCharset().name(); private static final String SSH_USERNAME = "admin"; private static final File PUBLIC_KEY_FILE = new File("key的路径"); // 默认TCP连接建立超时时间(毫秒) private static final int DEFAULT_TCP_TIMEOUT_MILLISECONDS = 10000; // 默认整个ssh连接建立超时时间(毫秒) private static final int DEFAULT_SSH_TIMEOUT_MILLISECONDS = 30000; private String targetHost; private String command; private int timeoutSec; SshTask(String targetHost, String command, int timeoutSec) { this.targetHost = targetHost; this.command = command; this.timeoutSec = timeoutSec; } public ExecutionResult call() throws Exception { Connection sshConnection = new Connection(targetHost); try { sshConnection.connect(null, DEFAULT_TCP_TIMEOUT_MILLISECONDS, DEFAULT_SSH_TIMEOUT_MILLISECONDS); boolean isAuthenticated = sshConnection.authenticateWithPublicKey(SSH_USERNAME, PUBLIC_KEY_FILE, ""); if (!isAuthenticated) { // 无法正常通过验证 String canNotAuthenticatedMsg = String.format( "无法通过SshKey正常登陆目标服务器, targetHost:%s, connectTimeout:%d ms", targetHost, DEFAULT_SSH_TIMEOUT_MILLISECONDS); logger.error(canNotAuthenticatedMsg); return ExecutionResult.failResult(canNotAuthenticatedMsg); } return execWithConnection(sshConnection); } finally { sshConnection.close(); } } private ExecutionResult execWithConnection(Connection sshConnection) throws IOException { Session sshSession = sshConnection.openSession(); String cmd = new String(command.getBytes(STD_CHARSET), "ISO8859_1"); sshSession.execCommand(cmd); InputStream stdout = null; InputStream stderr = null; try { stdout = sshSession.getStdout(); stderr = sshSession.getStderr(); return getResultWithTimeout(sshSession, stdout, stderr, timeoutSec); } finally { if (stdout != null) { stdout.close(); } if (stderr != null) { stderr.close(); } } } private ExecutionResult getResultWithTimeout(Session sshSession, InputStream stdout, InputStream stderr, int timeoutSec) throws IOException { long startMillis = System.currentTimeMillis(); long endMillis = startMillis + timeoutSec * 1000; ByteArrayOutputStream stdoutByteArrayOutputStream = new ByteArrayOutputStream(); ByteArrayOutputStream stderrByteArrayOutputStream = new ByteArrayOutputStream(); try { while (true) { int condition = sshSession.waitForCondition( ChannelCondition.STDOUT_DATA | ChannelCondition.STDERR_DATA | ChannelCondition.EXIT_STATUS | ChannelCondition.EOF, timeoutSec * 1000); long currentMills = System.currentTimeMillis(); if ((condition & ChannelCondition.TIMEOUT) != 0 || currentMills > endMillis) { // 超时 对应超时错误码10002 无奈之举 return ExecutionResult.failResult("执行超时", 10002, readString(stdoutByteArrayOutputStream), readString(stderrByteArrayOutputStream)); } if ((condition & ChannelCondition.STDOUT_DATA) != 0) { copy(stdout, stdoutByteArrayOutputStream); } if ((condition & ChannelCondition.STDERR_DATA) != 0) { copy(stderr, stderrByteArrayOutputStream); } if (((condition & ChannelCondition.EOF) != 0) && ((condition & ChannelCondition.EXIT_STATUS) != 0)) { // 收到EOF和EXIT_STATUS int retCode = sshSession.getExitStatus(); String stdoutStr = readString(stdoutByteArrayOutputStream); String stderrStr = readString(stderrByteArrayOutputStream); return ExecutionResult.successResult(retCode, stdoutStr, stderrStr); } } } finally { stdoutByteArrayOutputStream.close(); stderrByteArrayOutputStream.close(); } } private String readString(ByteArrayOutputStream byteArrayOutputStream) throws IOException { byteArrayOutputStream.flush(); return byteArrayOutputStream.toString(); } private void copy(InputStream sourceStream, OutputStream targetStream) throws IOException { byte[] buffer = new byte[8192]; while (sourceStream.available() > 0) { int bytes = sourceStream.read(buffer); if (bytes > 0) { targetStream.write(buffer, 0, bytes); } } }}public class SshUtils { private static final Logger logger = LoggerFactory.getLogger(SshUtils.class); private static final ThreadPoolTaskExecutor sshThreadPool; static { sshThreadPool = new ThreadPoolTaskExecutor(); sshThreadPool.setThreadNamePrefix("SshThreadPool-"); sshThreadPool.setCorePoolSize(20); sshThreadPool.setMaxPoolSize(100); sshThreadPool.setQueueCapacity(100); sshThreadPool.initialize(); } public static ExecutionResult execute(String targetHost, String command, int timeoutSec) { if (logger.isInfoEnabled()) { logger.info(String.format("开始执行命令, targetHost:%s, command:%s, timeoutSec:%d", targetHost, command, timeoutSec)); } checkParams(targetHost, command, timeoutSec); ThreadPoolExecutor executor = sshThreadPool.getThreadPoolExecutor(); SshTask sshTask = new SshTask(targetHost, command, timeoutSec); Future<ExecutionResult> future = executor.submit(sshTask); ExecutionResult execResult; try { execResult = future.get(timeoutSec + 5, TimeUnit.SECONDS); // 多设置5s超时, 给子任务一些余地 } catch (InterruptedException e) { String intMsg = String.format("InterruptedException:%s, targetHost:%s, command:%s, timeoutSec:%d.", e.getMessage(), targetHost, command, timeoutSec); logger.error(intMsg, e); execResult = ExecutionResult.failResult(intMsg); } catch (ExecutionException e) { String execMsg = String.format("ExecutionException:%s, targetHost:%s, command:%s, timeoutSec:%d.", e.getMessage(), targetHost, command, timeoutSec); logger.error(execMsg, e); execResult = ExecutionResult.failResult(execMsg); } catch (TimeoutException e) { String timeoutMsg = String.format("TimeoutException:%s, targetHost:%s, command:%s, timeoutSec:%d.", e.getMessage(), targetHost, command, timeoutSec); logger.error(timeoutMsg, e); execResult = ExecutionResult.failResult(timeoutMsg); } finally { if (!future.isDone()) { String forceKillTaskMsg = String.format( "任务执行未完成, 强制停止, targetHost:%s, command:%s, timeoutSec:%s", targetHost, command, timeoutSec); logger.warn(forceKillTaskMsg); future.cancel(true); } } if (logger.isInfoEnabled()) { logger.info(String.format( "命令执行完成, targetHost:%s, command:%s, timeoutSec:%d, retCode:%d, stdout:%s, stderr:%s", targetHost, command, timeoutSec, execResult.getReturnCode(), execResult.getStdout(), execResult.getStderr())); } return execResult; } private static void checkParams(String targetHost, String command, int timeout) { Preconditions.checkArgument(StringUtil.isNotBlank(targetHost), "目标主机不应为空"); Preconditions.checkArgument(StringUtil.isNotBlank(command), "待执行命令不应为空"); Preconditions.checkArgument(timeout >= 0, "超时时间不应<0"); }}