diff --git a/spoon-runner/src/main/java/com/squareup/spoon/CliArgs.kt b/spoon-runner/src/main/java/com/squareup/spoon/CliArgs.kt index e7ca600e..5b1161d0 100644 --- a/spoon-runner/src/main/java/com/squareup/spoon/CliArgs.kt +++ b/spoon-runner/src/main/java/com/squareup/spoon/CliArgs.kt @@ -75,7 +75,8 @@ internal class CliArgs(parser: ArgParser) { val coverage by parser.flagging("Enable code coverage") val singleInstrumentationCall by parser.flagging("--single-instrumentation-call", - help = "Run all tests in a single instrumentation call") + help = "Run tests in a single instrumentation call. If sharding is enabled each shard " + + "will be in a single instrumentation call otherwise all tests will be run.") private fun validateInstrumentationArgs() { val isTestRunPackageLimited = instrumentationArgs?.contains("package") ?: false diff --git a/spoon-runner/src/main/java/com/squareup/spoon/SpoonDeviceRunner.java b/spoon-runner/src/main/java/com/squareup/spoon/SpoonDeviceRunner.java index fb88b302..da352501 100644 --- a/spoon-runner/src/main/java/com/squareup/spoon/SpoonDeviceRunner.java +++ b/spoon-runner/src/main/java/com/squareup/spoon/SpoonDeviceRunner.java @@ -234,16 +234,20 @@ public DeviceResult run(AndroidDebugBridge adb) { result.startTests(); multiRunListener.multiRunStarted(recorder.runName(), recorder.testCount()); if (singleInstrumentationCall) { - logDebug(debug, "Running all tests in a single instrumentation call [%s]", serial); + logDebug(debug, "Running tests in a single instrumentation call [%s]", serial); try { - runAllTestOnDevice(testPackage, testRunner, device, listeners); + if (numShards != 0) { + runTestShardOnDevice(testPackage, testRunner, device, listeners); + } else { + runAllTestsOnDevice(testPackage, testRunner, device, listeners); + } } catch (Exception e) { result.addException(e); } } else { for (TestIdentifier test : activeTests) { try { - runTestOnDevice(testPackage, testRunner, device, listeners, test); + runSingleTestOnDevice(testPackage, testRunner, device, listeners, test); } catch (Exception e) { result.addException(e); } @@ -316,29 +320,34 @@ private LogRecordingTestRunListener queryTestSet(final String testPackage, return recorder; } - private void runAllTestOnDevice(final String testPackage, final String testRunner, + private void runAllTestsOnDevice(final String testPackage, final String testRunner, + final IDevice device, final List listeners) throws Exception { + + logDebug(debug, "Running tests [%s]", serial); + + RemoteAndroidTestRunner runner = createConfiguredRunner(testPackage, testRunner, device); + runner.run(listeners); + } + + private void runTestShardOnDevice(final String testPackage, final String testRunner, final IDevice device, final List listeners) throws Exception { - runTestOnDevice(testPackage, testRunner, device, listeners, null); + logDebug(debug, "Running tests for shardIndex [%d] out of numShards [%s] on [%s]", + shardIndex, numShards, serial); + RemoteAndroidTestRunner runner = createConfiguredRunner(testPackage, testRunner, device); + addShardingInstrumentationArgs(runner); + runner.run(listeners); } - private void runTestOnDevice(final String testPackage, final String testRunner, + private void runSingleTestOnDevice(final String testPackage, final String testRunner, final IDevice device, final List listeners, @Nullable final TestIdentifier test) throws Exception { - if (test != null) { - logDebug(debug, "Running %s [%s]", test, serial); - } else { - logDebug(debug, "Running tests [%s]", serial); - } + logDebug(debug, "Running %s on [%s]", test, serial); + RemoteAndroidTestRunner runner = createConfiguredRunner(testPackage, testRunner, device); runner.removeInstrumentationArg("package"); - if (codeCoverage) { - addCodeCoverageInstrumentationArgs(runner, device); - } - if (test != null) { - runner.setMethodName(test.getClassName(), test.getTestName()); - } + runner.setMethodName(test.getClassName(), test.getTestName()); runner.run(listeners); } @@ -356,6 +365,9 @@ private RemoteAndroidTestRunner createConfiguredRunner(String testPackage, Strin } runner.addInstrumentationArg(entry.getKey(), entry.getValue()); } + if (codeCoverage) { + addCodeCoverageInstrumentationArgs(runner, device); + } return runner; }