Jamal的博客

ssh远程批量执行shell的代码

使用的是 ganymed-ssh2 库, Maven 依赖如下:

<dependency>
    <groupId>ch.ethz.ganymed</groupId>
    <artifactId>ganymed-ssh2</artifactId>
    <version>build251beta1</version>
</dependency>

直接上代码:

/**
 * Outcome of one remote command execution: a success flag, an optional error
 * description, the reported return code, and the captured stdout/stderr text.
 *
 * <p>Instances are normally obtained through the static factory methods; the
 * setters remain available for callers that populate a result step by step.
 */
public class ExecutionResult {
    private boolean success;
    private String errorMsg;
    private int returnCode;
    private String stdout;
    private String stderr;

    /**
     * Creates a result flagged as successful.
     *
     * @param returnCode return code to record
     * @param stdout     captured standard output
     * @param stderr     captured standard error
     * @return a new result with {@code success == true} and no error message
     */
    public static ExecutionResult successResult(int returnCode, String stdout, String stderr) {
        ExecutionResult ok = new ExecutionResult();
        ok.success = true;
        ok.returnCode = returnCode;
        ok.stdout = stdout;
        ok.stderr = stderr;
        return ok;
    }

    /**
     * Creates a result flagged as failed, carrying whatever output was captured.
     *
     * @param errorMsg   description of the failure
     * @param returnCode return code to record
     * @param stdout     captured standard output
     * @param stderr     captured standard error
     * @return a new result with {@code success == false}
     */
    public static ExecutionResult failResult(String errorMsg, int returnCode, String stdout, String stderr) {
        ExecutionResult failed = new ExecutionResult();
        failed.success = false;
        failed.errorMsg = errorMsg;
        failed.returnCode = returnCode;
        failed.stdout = stdout;
        failed.stderr = stderr;
        return failed;
    }

    /**
     * Creates a result flagged as failed that carries only an error message;
     * the return code and both output fields keep their default values.
     *
     * @param errorMsg description of the failure
     * @return a new result with {@code success == false}
     */
    public static ExecutionResult failResult(String errorMsg) {
        ExecutionResult failed = new ExecutionResult();
        failed.success = false;
        failed.errorMsg = errorMsg;
        return failed;
    }

    /**
     * Joins the error message and stderr into one line of the form
     * {@code errorMsg:<msg>;stderr:<text>}. A part that
     * {@code StringUtil.isNotBlank} rejects is omitted, and the {@code ;}
     * separator appears only when both parts are present.
     *
     * @return the combined text, or an empty string when neither part is present
     */
    public String getTaskErrorMsg() {
        StringBuilder combined = new StringBuilder();
        if (StringUtil.isNotBlank(errorMsg)) {
            combined.append("errorMsg:").append(errorMsg);
        }
        if (StringUtil.isNotBlank(stderr)) {
            // A non-empty builder here means the errorMsg part was written above.
            if (combined.length() > 0) {
                combined.append(";");
            }
            combined.append("stderr:").append(stderr);
        }
        return combined.toString();
    }

    /** @return whether this result is flagged as successful */
    public boolean isSuccess() {
        return this.success;
    }

    /** @param success new value of the success flag */
    public void setSuccess(boolean success) {
        this.success = success;
    }

    /** @return the error message, possibly {@code null} */
    public String getErrorMsg() {
        return this.errorMsg;
    }

    /** @param errorMsg new error message */
    public void setErrorMsg(String errorMsg) {
        this.errorMsg = errorMsg;
    }

    /** @return the recorded return code */
    public int getReturnCode() {
        return this.returnCode;
    }

    /** @param returnCode new return code */
    public void setReturnCode(int returnCode) {
        this.returnCode = returnCode;
    }

    /** @return the captured standard output, possibly {@code null} */
    public String getStdout() {
        return this.stdout;
    }

    /** @param stdout new standard output text */
    public void setStdout(String stdout) {
        this.stdout = stdout;
    }

    /** @return the captured standard error, possibly {@code null} */
    public String getStderr() {
        return this.stderr;
    }

    /** @param stderr new standard error text */
    public void setStderr(String stderr) {
        this.stderr = stderr;
    }
}
/**
 * Runs one shell command on one remote host over SSH and captures its exit
 * status, stdout and stderr as an {@link ExecutionResult}.
 *
 * <p>The task connects, authenticates with a key file as a fixed user, executes
 * the command, and collects output until the command finishes or the
 * per-command timeout elapses. The connection is always closed before
 * {@link #call()} returns.
 */
class SshTask implements Callable<ExecutionResult> {
    private static final Logger logger = LoggerFactory.getLogger(SshTask.class);
    // Name of the platform default charset; used to encode the command text.
    private static final String STD_CHARSET = Charset.defaultCharset().name();
    private static final String SSH_USERNAME = "admin";
    // Key file handed to authenticateWithPublicKey; the literal reads as a placeholder for the real path.
    private static final File PUBLIC_KEY_FILE = new File("key的路径");
    // Default timeout for establishing the TCP connection (milliseconds)
    private static final int DEFAULT_TCP_TIMEOUT_MILLISECONDS = 10000;
    // Default timeout for establishing the whole SSH connection (milliseconds)
    private static final int DEFAULT_SSH_TIMEOUT_MILLISECONDS = 30000;
    // Return code reported in the ExecutionResult when the command times out
    private static final int TIMEOUT_RETURN_CODE = 10002;
    // Channel conditions that wake up the collection loop
    private static final int WATCHED_CONDITIONS = ChannelCondition.STDOUT_DATA
            | ChannelCondition.STDERR_DATA
            | ChannelCondition.EXIT_STATUS
            | ChannelCondition.EOF;

    private final String targetHost;
    private final String command;
    private final int timeoutSec;

    /**
     * @param targetHost host to connect to
     * @param command    shell command to execute remotely
     * @param timeoutSec maximum time, in seconds, to wait for the command to finish
     */
    SshTask(String targetHost, String command, int timeoutSec) {
        this.targetHost = targetHost;
        this.command = command;
        this.timeoutSec = timeoutSec;
    }

    /**
     * Connects to the target host, authenticates and runs the command.
     *
     * @return a fail result when authentication is refused or the command times
     *         out; otherwise a success result carrying the remote exit status
     *         and the captured output
     * @throws Exception any error raised while connecting, authenticating or
     *                   exchanging data with the remote side
     */
    @Override
    public ExecutionResult call() throws Exception {
        Connection sshConnection = new Connection(targetHost);
        try {
            sshConnection.connect(null, DEFAULT_TCP_TIMEOUT_MILLISECONDS, DEFAULT_SSH_TIMEOUT_MILLISECONDS);
            boolean isAuthenticated = sshConnection.authenticateWithPublicKey(SSH_USERNAME, PUBLIC_KEY_FILE, "");
            if (!isAuthenticated) {
                // Authentication was refused by the server
                String canNotAuthenticatedMsg = String.format(
                        "无法通过SshKey正常登陆目标服务器, targetHost:%s, connectTimeout:%d ms",
                        targetHost, DEFAULT_SSH_TIMEOUT_MILLISECONDS);
                logger.error(canNotAuthenticatedMsg);
                return ExecutionResult.failResult(canNotAuthenticatedMsg);
            }
            return execWithConnection(sshConnection);
        } finally {
            sshConnection.close();
        }
    }

    /**
     * Opens a session on an authenticated connection, starts the command and
     * collects its result. The session and both output streams are closed
     * before returning, whether or not collection succeeded.
     */
    private ExecutionResult execWithConnection(Connection sshConnection) throws IOException {
        Session sshSession = sshConnection.openSession();
        InputStream stdout = null;
        InputStream stderr = null;
        try {
            // Re-wrap the command's default-charset bytes as a one-char-per-byte string.
            // NOTE(review): presumably because the library encodes the command as
            // ISO-8859-1 on the wire, so this keeps non-ASCII bytes intact - confirm.
            String cmd = new String(command.getBytes(STD_CHARSET), "ISO8859_1");
            sshSession.execCommand(cmd);
            stdout = sshSession.getStdout();
            stderr = sshSession.getStderr();
            return getResultWithTimeout(sshSession, stdout, stderr, timeoutSec);
        } finally {
            // Nested finally blocks so that a failure closing one resource
            // does not prevent the remaining ones from being released.
            try {
                if (stdout != null) {
                    stdout.close();
                }
            } finally {
                try {
                    if (stderr != null) {
                        stderr.close();
                    }
                } finally {
                    sshSession.close();
                }
            }
        }
    }

    /**
     * Collects stdout/stderr until both EOF and the exit status have been
     * received, or until {@code timeoutSec} seconds have elapsed in total.
     *
     * <p>Each wait is bounded by the time left until the overall deadline, so
     * the method cannot block for longer than the requested timeout even when
     * output arrives in many small pieces.
     *
     * @return a success result with the remote exit status, or a fail result
     *         with return code {@code 10002} and the output captured so far
     *         when the deadline passes
     */
    private ExecutionResult getResultWithTimeout(Session sshSession, InputStream stdout, InputStream stderr,
            int timeoutSec) throws IOException {
        // 1000L forces long arithmetic; int multiplication overflows for large timeouts.
        long deadlineMillis = System.currentTimeMillis() + timeoutSec * 1000L;
        ByteArrayOutputStream stdoutBuffer = new ByteArrayOutputStream();
        ByteArrayOutputStream stderrBuffer = new ByteArrayOutputStream();
        // NOTE(review): if a server sends EOF but never an exit status, every wait
        // below presumably returns at once and this loop spins until the deadline - confirm.
        while (true) {
            long remainingMillis = deadlineMillis - System.currentTimeMillis();
            if (remainingMillis <= 0) {
                // Never pass a non-positive timeout to waitForCondition: per the
                // library documentation 0 means "no timeout" and could block forever.
                return timeoutResult(stdoutBuffer, stderrBuffer);
            }
            int condition = sshSession.waitForCondition(WATCHED_CONDITIONS, remainingMillis);
            if ((condition & ChannelCondition.TIMEOUT) != 0 || System.currentTimeMillis() > deadlineMillis) {
                return timeoutResult(stdoutBuffer, stderrBuffer);
            }
            if ((condition & ChannelCondition.STDOUT_DATA) != 0) {
                copy(stdout, stdoutBuffer);
            }
            if ((condition & ChannelCondition.STDERR_DATA) != 0) {
                copy(stderr, stderrBuffer);
            }
            if (((condition & ChannelCondition.EOF) != 0)
                    && ((condition & ChannelCondition.EXIT_STATUS) != 0)) {
                // Both EOF and EXIT_STATUS received: the command has finished
                int retCode = sshSession.getExitStatus();
                return ExecutionResult.successResult(retCode, readString(stdoutBuffer), readString(stderrBuffer));
            }
        }
    }

    /** Builds the timeout failure (return code 10002) carrying the output captured so far. */
    private ExecutionResult timeoutResult(ByteArrayOutputStream stdoutBuffer,
            ByteArrayOutputStream stderrBuffer) throws IOException {
        return ExecutionResult.failResult("执行超时", TIMEOUT_RETURN_CODE,
                readString(stdoutBuffer), readString(stderrBuffer));
    }

    /** Decodes the buffered bytes using the platform default charset. */
    private String readString(ByteArrayOutputStream byteArrayOutputStream) throws IOException {
        byteArrayOutputStream.flush();
        return byteArrayOutputStream.toString();
    }

    /** Drains whatever {@code sourceStream} currently reports as available into {@code targetStream}. */
    private void copy(InputStream sourceStream, OutputStream targetStream) throws IOException {
        byte[] buffer = new byte[8192];
        while (sourceStream.available() > 0) {
            int bytes = sourceStream.read(buffer);
            if (bytes > 0) {
                targetStream.write(buffer, 0, bytes);
            }
        }
    }
}
/**
 * Entry point for running a shell command on a remote host over SSH.
 *
 * <p>Each call submits an {@code SshTask} to a shared thread pool (20 core
 * threads, at most 100 threads, queue capacity 100) and blocks the caller
 * until the task finishes or the timeout plus a short grace period elapses.
 */
public class SshUtils {
    private static final Logger logger = LoggerFactory.getLogger(SshUtils.class);
    private static final ThreadPoolTaskExecutor sshThreadPool;
    // Extra seconds granted on top of the command timeout so the task can report its own timeout
    private static final long GRACE_PERIOD_SECONDS = 5L;

    static {
        sshThreadPool = new ThreadPoolTaskExecutor();
        sshThreadPool.setThreadNamePrefix("SshThreadPool-");
        sshThreadPool.setCorePoolSize(20);
        sshThreadPool.setMaxPoolSize(100);
        sshThreadPool.setQueueCapacity(100);
        sshThreadPool.initialize();
    }

    /**
     * Executes {@code command} on {@code targetHost} and waits for the result.
     *
     * <p>Failures while waiting (interruption, an exception inside the task, or
     * the wait timing out) are logged and turned into a fail result rather than
     * thrown. When the caller is interrupted, the thread's interrupt status is
     * restored before this method returns. A task that has not completed by the
     * time waiting ends is cancelled with interruption.
     *
     * @param targetHost host to run the command on; must not be blank
     * @param command    shell command to run; must not be blank
     * @param timeoutSec command timeout in seconds; must not be negative
     * @return the execution result; never {@code null}
     * @throws IllegalArgumentException presumably (Guava-style
     *         {@code Preconditions.checkArgument}) when an argument is invalid
     */
    public static ExecutionResult execute(String targetHost, String command, int timeoutSec) {
        if (logger.isInfoEnabled()) {
            logger.info(String.format("开始执行命令, targetHost:%s, command:%s, timeoutSec:%d",
                    targetHost, command, timeoutSec));
        }
        checkParams(targetHost, command, timeoutSec);
        ThreadPoolExecutor executor = sshThreadPool.getThreadPoolExecutor();
        SshTask sshTask = new SshTask(targetHost, command, timeoutSec);
        // NOTE(review): submit() can throw RejectedExecutionException when pool and
        // queue are both full; that is not handled here and reaches the caller.
        Future<ExecutionResult> future = executor.submit(sshTask);
        ExecutionResult execResult;
        try {
            // Wait a little longer than the command timeout so the task can return its
            // own timeout result; long arithmetic avoids overflow near Integer.MAX_VALUE.
            execResult = future.get(timeoutSec + GRACE_PERIOD_SECONDS, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            String intMsg = String.format("InterruptedException:%s, targetHost:%s, command:%s, timeoutSec:%d.",
                    e.getMessage(), targetHost, command, timeoutSec);
            logger.error(intMsg, e);
            execResult = ExecutionResult.failResult(intMsg);
            // Restore the interrupt status so callers further up can still observe it.
            Thread.currentThread().interrupt();
        } catch (ExecutionException e) {
            String execMsg = String.format("ExecutionException:%s, targetHost:%s, command:%s, timeoutSec:%d.",
                    e.getMessage(), targetHost, command, timeoutSec);
            logger.error(execMsg, e);
            execResult = ExecutionResult.failResult(execMsg);
        } catch (TimeoutException e) {
            String timeoutMsg = String.format("TimeoutException:%s, targetHost:%s, command:%s, timeoutSec:%d.",
                    e.getMessage(), targetHost, command, timeoutSec);
            logger.error(timeoutMsg, e);
            execResult = ExecutionResult.failResult(timeoutMsg);
        } finally {
            if (!future.isDone()) {
                // The task is still running (wait timed out or was interrupted): cancel it.
                String forceKillTaskMsg = String.format(
                        "任务执行未完成, 强制停止, targetHost:%s, command:%s, timeoutSec:%s",
                        targetHost, command, timeoutSec);
                logger.warn(forceKillTaskMsg);
                future.cancel(true);
            }
        }
        if (logger.isInfoEnabled()) {
            logger.info(String.format(
                    "命令执行完成, targetHost:%s, command:%s, timeoutSec:%d, retCode:%d, stdout:%s, stderr:%s",
                    targetHost, command, timeoutSec,
                    execResult.getReturnCode(), execResult.getStdout(), execResult.getStderr()));
        }
        return execResult;
    }

    /** Rejects a blank host, a blank command, or a negative timeout. */
    private static void checkParams(String targetHost, String command, int timeout) {
        Preconditions.checkArgument(StringUtil.isNotBlank(targetHost), "目标主机不应为空");
        Preconditions.checkArgument(StringUtil.isNotBlank(command), "待执行命令不应为空");
        Preconditions.checkArgument(timeout >= 0, "超时时间不应<0");
    }
}