Skip to content

Commit 33cb56b

Browse files
authored
[AN-146] Emit VM cost for GCP Batch (#7582)
1 parent e917e52 commit 33cb56b

File tree

7 files changed

+419
-15
lines changed

7 files changed

+419
-15
lines changed

project/Dependencies.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ object Dependencies {
8989
private val metrics3StatsdV = "4.2.0"
9090
private val mockFtpServerV = "3.0.0"
9191
private val mockitoV = "3.12.4"
92+
private val mockitoInlineV = "2.8.9"
9293
private val mockserverNettyV = "5.14.0"
9394
private val mouseV = "1.0.11"
9495

@@ -627,7 +628,8 @@ object Dependencies {
627628
"org.scalatest" %% "scalatest" % scalatestV,
628629
// Use mockito Java DSL directly instead of the numerous and often hard to keep updated Scala DSLs.
629630
// See also scaladoc in common.mock.MockSugar and that trait's various usages.
630-
"org.mockito" % "mockito-core" % mockitoV
631+
"org.mockito" % "mockito-core" % mockitoV,
632+
"org.mockito" % "mockito-inline" % mockitoInlineV
631633
) ++ slf4jBindingDependencies // During testing, add an slf4j binding for _all_ libraries.
632634

633635
val kindProjectorPlugin = "org.typelevel" % "kind-projector" % kindProjectorV cross CrossVersion.full

supportedBackends/google/batch/src/main/scala/cromwell/backend/google/batch/actors/BatchPollResultMonitorActor.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ class BatchPollResultMonitorActor(pollMonitorParameters: PollMonitorParameters)
4747
case event if event.name == CallMetadataKeys.VmEndTime => event.offsetDateTime
4848
}
4949

50+
override def extractVmInfoFromRunState(pollStatus: RunStatus): Option[InstantiatedVmInfo] =
51+
pollStatus.instantiatedVmInfo
52+
5053
override def handleVmCostLookup(vmInfo: InstantiatedVmInfo) = {
5154
val request = GcpCostLookupRequest(vmInfo, self)
5255
params.serviceRegistry ! request
@@ -69,6 +72,7 @@ class BatchPollResultMonitorActor(pollMonitorParameters: PollMonitorParameters)
6972
}
7073

7174
override def receive: Receive = {
75+
case costResponse: GcpCostLookupResponse => handleCostResponse(costResponse)
7276
case message: PollResultMessage =>
7377
message match {
7478
case ProcessThisPollResult(pollResult: RunStatus) => processPollResult(pollResult)
@@ -93,5 +97,4 @@ class BatchPollResultMonitorActor(pollMonitorParameters: PollMonitorParameters)
9397

9498
override def params: PollMonitorParameters = pollMonitorParameters
9599

96-
override def extractVmInfoFromRunState(pollStatus: RunStatus): Option[InstantiatedVmInfo] = Option.empty // TODO
97100
}

supportedBackends/google/batch/src/main/scala/cromwell/backend/google/batch/actors/GcpBatchAsyncBackendJobExecutionActor.scala

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1025,6 +1025,18 @@ class GcpBatchAsyncBackendJobExecutionActor(override val standardParams: Standar
10251025
} yield status
10261026
}
10271027

1028+
override val pollingResultMonitorActor: Option[ActorRef] = Option(
1029+
context.actorOf(
1030+
BatchPollResultMonitorActor.props(serviceRegistryActor,
1031+
workflowDescriptor,
1032+
jobDescriptor,
1033+
validatedRuntimeAttributes,
1034+
platform,
1035+
jobLogger
1036+
)
1037+
)
1038+
)
1039+
10281040
override def isTerminal(runStatus: RunStatus): Boolean =
10291041
runStatus match {
10301042
case _: RunStatus.TerminalRunStatus => true
@@ -1070,7 +1082,7 @@ class GcpBatchAsyncBackendJobExecutionActor(override val standardParams: Standar
10701082
Future.fromTry {
10711083
Try {
10721084
runStatus match {
1073-
case RunStatus.Aborted(_) => AbortedExecutionHandle
1085+
case RunStatus.Aborted(_, _) => AbortedExecutionHandle
10741086
case failedStatus: RunStatus.UnsuccessfulRunStatus => handleFailedRunStatus(failedStatus)
10751087
case unknown =>
10761088
throw new RuntimeException(

supportedBackends/google/batch/src/main/scala/cromwell/backend/google/batch/api/request/BatchRequestExecutor.scala

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package cromwell.backend.google.batch.api.request
22

33
import com.google.api.gax.rpc.{ApiException, StatusCode}
4+
import com.google.cloud.batch.v1.AllocationPolicy.ProvisioningModel
45
import com.google.cloud.batch.v1._
56
import com.typesafe.scalalogging.LazyLogging
67
import cromwell.backend.google.batch.actors.BatchApiAbortClient.{
@@ -11,6 +12,8 @@ import cromwell.backend.google.batch.api.BatchApiRequestManager._
1112
import cromwell.backend.google.batch.api.{BatchApiRequestManager, BatchApiResponse}
1213
import cromwell.backend.google.batch.models.{GcpBatchExitCode, RunStatus}
1314
import cromwell.core.ExecutionEvent
15+
import cromwell.services.cost.InstantiatedVmInfo
16+
import cromwell.services.metadata.CallMetadataKeys
1417

1518
import scala.annotation.unused
1619
import scala.concurrent.{ExecutionContext, Future, Promise}
@@ -136,14 +139,32 @@ object BatchRequestExecutor {
136139
)
137140
lazy val exitCode = findBatchExitCode(events)
138141

142+
// Get VM info for this job
143+
val allocationPolicy = job.getAllocationPolicy
144+
145+
// Get the instances that can be created with this AllocationPolicy; only instances[0] is supported
146+
val instancePolicy = allocationPolicy.getInstances(0).getPolicy
147+
val machineType = instancePolicy.getMachineType
148+
val preemtible = instancePolicy.getProvisioningModelValue == ProvisioningModel.PREEMPTIBLE.getNumber
149+
150+
// The allowed-locations list looks like [regions/us-central1, zones/us-central1-b]; the region is the first element
151+
val location = allocationPolicy.getLocation.getAllowedLocationsList.get(0)
152+
val region =
153+
if (location.isEmpty)
154+
"us-central1"
155+
else
156+
location.split("/").last
157+
158+
val instantiatedVmInfo = Some(InstantiatedVmInfo(region, machineType, preemtible))
159+
139160
if (job.getStatus.getState == JobStatus.State.SUCCEEDED) {
140-
RunStatus.Success(events)
161+
RunStatus.Success(events, instantiatedVmInfo)
141162
} else if (job.getStatus.getState == JobStatus.State.RUNNING) {
142-
RunStatus.Running(events)
163+
RunStatus.Running(events, instantiatedVmInfo)
143164
} else if (job.getStatus.getState == JobStatus.State.FAILED) {
144-
RunStatus.Failed(exitCode, events)
165+
RunStatus.Failed(exitCode, events, instantiatedVmInfo)
145166
} else {
146-
RunStatus.Initializing(events)
167+
RunStatus.Initializing(events, instantiatedVmInfo)
147168
}
148169
}
149170

@@ -152,12 +173,20 @@ object BatchRequestExecutor {
152173
GcpBatchExitCode.fromEventMessage(e.name.toLowerCase)
153174
}.headOption
154175

155-
private def getEventList(events: List[StatusEvent]): List[ExecutionEvent] =
176+
private def getEventList(events: List[StatusEvent]): List[ExecutionEvent] = {
177+
val startedRegex = ".*SCHEDULED to RUNNING.*".r
178+
val endedRegex = ".*RUNNING to.*".r // can be SUCCEEDED or FAILED
156179
events.map { e =>
157180
val time = java.time.Instant
158181
.ofEpochSecond(e.getEventTime.getSeconds, e.getEventTime.getNanos.toLong)
159182
.atOffset(java.time.ZoneOffset.UTC)
160-
ExecutionEvent(name = e.getDescription, offsetDateTime = time)
183+
val eventType = e.getDescription match {
184+
case startedRegex() => CallMetadataKeys.VmStartTime
185+
case endedRegex() => CallMetadataKeys.VmEndTime
186+
case _ => e.getType
187+
}
188+
ExecutionEvent(name = eventType, offsetDateTime = time)
161189
}
190+
}
162191
}
163192
}

supportedBackends/google/batch/src/main/scala/cromwell/backend/google/batch/models/RunStatus.scala

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,32 @@
11
package cromwell.backend.google.batch.models
22

33
import cromwell.core.ExecutionEvent
4+
import cromwell.services.cost.InstantiatedVmInfo
45

56
sealed trait RunStatus {
67
def eventList: Seq[ExecutionEvent]
78
def toString: String
9+
10+
val instantiatedVmInfo: Option[InstantiatedVmInfo]
811
}
912

1013
object RunStatus {
1114

12-
case class Initializing(eventList: Seq[ExecutionEvent]) extends RunStatus { override def toString = "Initializing" }
13-
case class AwaitingCloudQuota(eventList: Seq[ExecutionEvent]) extends RunStatus {
15+
case class Initializing(eventList: Seq[ExecutionEvent], instantiatedVmInfo: Option[InstantiatedVmInfo] = Option.empty)
16+
extends RunStatus { override def toString = "Initializing" }
17+
case class AwaitingCloudQuota(eventList: Seq[ExecutionEvent],
18+
instantiatedVmInfo: Option[InstantiatedVmInfo] = Option.empty
19+
) extends RunStatus {
1420
override def toString = "AwaitingCloudQuota"
1521
}
1622

17-
case class Running(eventList: Seq[ExecutionEvent]) extends RunStatus { override def toString = "Running" }
23+
case class Running(eventList: Seq[ExecutionEvent], instantiatedVmInfo: Option[InstantiatedVmInfo] = Option.empty)
24+
extends RunStatus { override def toString = "Running" }
1825

1926
sealed trait TerminalRunStatus extends RunStatus
2027

21-
case class Success(eventList: Seq[ExecutionEvent]) extends TerminalRunStatus {
28+
case class Success(eventList: Seq[ExecutionEvent], instantiatedVmInfo: Option[InstantiatedVmInfo] = Option.empty)
29+
extends TerminalRunStatus {
2230
override def toString = "Success"
2331
}
2432

@@ -29,7 +37,8 @@ object RunStatus {
2937

3038
final case class Failed(
3139
exitCode: Option[GcpBatchExitCode],
32-
eventList: Seq[ExecutionEvent]
40+
eventList: Seq[ExecutionEvent],
41+
instantiatedVmInfo: Option[InstantiatedVmInfo] = Option.empty
3342
) extends UnsuccessfulRunStatus {
3443
override def toString = "Failed"
3544

@@ -58,7 +67,9 @@ object RunStatus {
5867
}
5968
}
6069

61-
final case class Aborted(eventList: Seq[ExecutionEvent]) extends UnsuccessfulRunStatus {
70+
final case class Aborted(eventList: Seq[ExecutionEvent],
71+
instantiatedVmInfo: Option[InstantiatedVmInfo] = Option.empty
72+
) extends UnsuccessfulRunStatus {
6273
override def toString = "Aborted"
6374

6475
override val exitCode: Option[GcpBatchExitCode] = None
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
package cromwell.backend.google.batch.actors
2+
3+
import akka.actor.{ActorRef, ActorSystem, Props}
4+
import akka.testkit.{TestKit, TestProbe}
5+
import cats.data.Validated.Valid
6+
import common.mock.MockSugar
7+
import cromwell.backend.google.batch.models.GcpBatchRuntimeAttributes
8+
import cromwell.backend.{BackendJobDescriptor, BackendJobDescriptorKey, RuntimeAttributeDefinition}
9+
import cromwell.core.callcaching.NoDocker
10+
import cromwell.core.{ExecutionEvent, WorkflowOptions}
11+
import cromwell.core.logging.JobLogger
12+
import cromwell.services.cost.{GcpCostLookupRequest, GcpCostLookupResponse, InstantiatedVmInfo}
13+
import cromwell.services.keyvalue.InMemoryKvServiceActor
14+
import org.scalatest.flatspec.AnyFlatSpecLike
15+
import org.scalatest.matchers.should.Matchers
16+
import cromwell.backend.google.batch.models.GcpBatchTestConfig._
17+
import wom.graph.CommandCallNode
18+
import cromwell.backend._
19+
import cromwell.backend.google.batch.models._
20+
import cromwell.backend.io.TestWorkflows
21+
import cromwell.backend.standard.pollmonitoring.ProcessThisPollResult
22+
import cromwell.services.metadata.CallMetadataKeys
23+
import cromwell.services.metadata.MetadataService.PutMetadataAction
24+
import org.slf4j.helpers.NOPLogger
25+
import wom.values.WomString
26+
27+
import java.time.{Instant, OffsetDateTime}
28+
import java.time.temporal.ChronoUnit
29+
import scala.concurrent.duration.DurationInt
30+
31+
class BatchPollResultMonitorActorSpec
32+
extends TestKit(ActorSystem("BatchPollResultMonitorActorSpec"))
33+
with AnyFlatSpecLike
34+
with BackendSpec
35+
with Matchers
36+
with MockSugar {
37+
38+
var kvService: ActorRef = system.actorOf(Props(new InMemoryKvServiceActor), "kvService")
39+
val runtimeAttributesBuilder = GcpBatchRuntimeAttributes.runtimeAttributesBuilder(gcpBatchConfiguration)
40+
val jobLogger = mock[JobLogger]
41+
val serviceRegistry = TestProbe()
42+
43+
val workflowDescriptor = buildWdlWorkflowDescriptor(TestWorkflows.HelloWorld)
44+
val call: CommandCallNode = workflowDescriptor.callable.taskCallNodes.head
45+
val jobKey = BackendJobDescriptorKey(call, None, 1)
46+
47+
val jobDescriptor = BackendJobDescriptor(workflowDescriptor,
48+
jobKey,
49+
runtimeAttributes = Map.empty,
50+
evaluatedTaskInputs = Map.empty,
51+
NoDocker,
52+
None,
53+
Map.empty
54+
)
55+
56+
val runtimeAttributes = Map("docker" -> WomString("ubuntu:latest"))
57+
58+
val staticRuntimeAttributeDefinitions: Set[RuntimeAttributeDefinition] =
59+
GcpBatchRuntimeAttributes.runtimeAttributesBuilder(GcpBatchTestConfig.gcpBatchConfiguration).definitions.toSet
60+
61+
val defaultedAttributes =
62+
RuntimeAttributeDefinition.addDefaultsToAttributes(staticRuntimeAttributeDefinitions,
63+
WorkflowOptions.fromMap(Map.empty).get
64+
)(
65+
runtimeAttributes
66+
)
67+
val validatedRuntimeAttributes = runtimeAttributesBuilder.build(defaultedAttributes, NOPLogger.NOP_LOGGER)
68+
69+
val actor = system.actorOf(
70+
BatchPollResultMonitorActor.props(serviceRegistry.ref,
71+
workflowDescriptor,
72+
jobDescriptor,
73+
validatedRuntimeAttributes,
74+
Some(Gcp),
75+
jobLogger
76+
)
77+
)
78+
val vmInfo = InstantiatedVmInfo("europe-west9", "custom-16-32768", false)
79+
80+
behavior of "BatchPollResultMonitorActor"
81+
82+
it should "send a cost lookup request with the correct vm info after receiving a success pollResult" in {
83+
84+
val terminalPollResult =
85+
RunStatus.Success(Seq(ExecutionEvent("fakeEvent", OffsetDateTime.now().truncatedTo(ChronoUnit.MILLIS))),
86+
Some(vmInfo)
87+
)
88+
val message = ProcessThisPollResult(terminalPollResult)
89+
90+
actor ! message
91+
92+
serviceRegistry.expectMsgPF(1.seconds) { case m: GcpCostLookupRequest =>
93+
m.vmInfo shouldBe vmInfo
94+
}
95+
}
96+
97+
it should "emit the correct cost metadata after receiving a costLookupResponse" in {
98+
99+
val costLookupResponse = GcpCostLookupResponse(vmInfo, Valid(BigDecimal(0.1)))
100+
101+
actor ! costLookupResponse
102+
103+
serviceRegistry.expectMsgPF(1.seconds) { case m: PutMetadataAction =>
104+
val event = m.events.head
105+
m.events.size shouldBe 1
106+
event.key.key shouldBe CallMetadataKeys.VmCostPerHour
107+
event.value.get.value shouldBe "0.1"
108+
}
109+
}
110+
111+
it should "emit the correct start time after receiving a running pollResult" in {
112+
113+
val vmStartTime = OffsetDateTime.now().minus(2, ChronoUnit.HOURS)
114+
val pollResult = RunStatus.Running(
115+
Seq(ExecutionEvent(CallMetadataKeys.VmStartTime, vmStartTime)),
116+
Some(vmInfo)
117+
)
118+
val message = ProcessThisPollResult(pollResult)
119+
120+
actor ! message
121+
122+
serviceRegistry.expectMsgPF(1.seconds) { case m: PutMetadataAction =>
123+
val event = m.events.head
124+
m.events.size shouldBe 1
125+
event.key.key shouldBe CallMetadataKeys.VmStartTime
126+
assert(
127+
Instant
128+
.parse(event.value.get.value)
129+
.equals(vmStartTime.toInstant.truncatedTo(ChronoUnit.MILLIS))
130+
)
131+
}
132+
}
133+
134+
it should "emit the correct end time after receiving a running pollResult" in {
135+
136+
val vmEndTime = OffsetDateTime.now().minus(2, ChronoUnit.HOURS)
137+
val pollResult = RunStatus.Running(
138+
Seq(ExecutionEvent(CallMetadataKeys.VmEndTime, vmEndTime)),
139+
Some(vmInfo)
140+
)
141+
val message = ProcessThisPollResult(pollResult)
142+
143+
actor ! message
144+
145+
serviceRegistry.expectMsgPF(1.seconds) { case m: PutMetadataAction =>
146+
val event = m.events.head
147+
m.events.size shouldBe 1
148+
event.key.key shouldBe CallMetadataKeys.VmEndTime
149+
assert(
150+
Instant
151+
.parse(event.value.get.value)
152+
.equals(vmEndTime.toInstant.truncatedTo(ChronoUnit.MILLIS))
153+
)
154+
}
155+
}
156+
}

0 commit comments

Comments
 (0)