diff --git a/examples/advanced/reuse-allocation.ts b/examples/advanced/reuse-allocation.ts new file mode 100644 index 000000000..5e1e99fa0 --- /dev/null +++ b/examples/advanced/reuse-allocation.ts @@ -0,0 +1,75 @@ +/** + * This advanced example demonstrates create an allocation manually and then reuse + * it across multiple market orders. + */ +import { MarketOrderSpec, GolemNetwork } from "@golem-sdk/golem-js"; +import { pinoPrettyLogger } from "@golem-sdk/pino-logger"; +(async () => { + const glm = new GolemNetwork({ + logger: pinoPrettyLogger({ + level: "info", + }), + }); + + try { + await glm.connect(); + + const allocation = await glm.payment.createAllocation({ + budget: 1, + expirationSec: 3600, + }); + + const firstOrder: MarketOrderSpec = { + demand: { + workload: { imageTag: "golem/alpine:latest" }, + }, + market: { + rentHours: 0.5, + pricing: { + model: "burn-rate", + avgGlmPerHour: 0.5, + }, + }, + payment: { + // You can either pass the allocation object ... + allocation, + }, + }; + const secondOrder: MarketOrderSpec = { + demand: { + workload: { imageTag: "golem/alpine:latest" }, + }, + market: { + rentHours: 0.5, + pricing: { + model: "burn-rate", + avgGlmPerHour: 0.5, + }, + }, + payment: { + // ... or just the allocation ID + allocation: allocation.id, + }, + }; + + const lease1 = await glm.oneOf(firstOrder); + const lease2 = await glm.oneOf(secondOrder); + + await lease1 + .getExeUnit() + .then((exe) => exe.run("echo Running on first lease")) + .then((res) => console.log(res.stdout)); + await lease2 + .getExeUnit() + .then((exe) => exe.run("echo Running on second lease")) + .then((res) => console.log(res.stdout)); + + await lease1.finalize(); + await lease2.finalize(); + await glm.payment.releaseAllocation(allocation); + } catch (err) { + console.error("Failed to run the example", err); + } finally { + await glm.disconnect(); + } +})().catch(console.error); diff --git a/examples/advanced/step-by-step.ts b/examples/advanced/step-by-step.ts index a20192432..c45d6ff16 100644 --- a/examples/advanced/step-by-step.ts +++ b/examples/advanced/step-by-step.ts @@ -104,7 +104,7 @@ import { filter, map, switchMap, take } from "rxjs"; // To keep this example simple, we will not retry and just crash if the signing fails const draftProposal = draftProposals[0]!; const agreement = await glm.market.proposeAgreement(draftProposal); - console.log("Agreement signed with provider", agreement.getProviderInfo().name); + console.log("Agreement signed with provider", agreement.provider.name); // Provider is ready to start the computation // Let's setup payment first diff --git a/package-lock.json b/package-lock.json index b95d6d289..a9460f078 100644 --- a/package-lock.json +++ b/package-lock.json @@ -28,7 +28,7 @@ "tmp": "^0.2.2", "uuid": "^9.0.1", "ws": "^8.16.0", - "ya-ts-client": "^1.1.1-beta.1" + "ya-ts-client": "^1.1.2" }, "devDependencies": { "@commitlint/cli": "^19.0.3", @@ -18973,9 +18973,9 @@ } }, "node_modules/ya-ts-client": { - "version": "1.1.1-beta.1", - "resolved": "https://registry.npmjs.org/ya-ts-client/-/ya-ts-client-1.1.1-beta.1.tgz", - "integrity": "sha512-rvkgvNphGdnm63hLe4hKnw8jvdavxEy1y7beLoAEn1nVhYRSF8unKveUdf0K2FbPhPNbiBT7vMEJVDF8x3fqFQ==", + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/ya-ts-client/-/ya-ts-client-1.1.2.tgz", + "integrity": "sha512-rZ2YMs3ATOzkaRo7NXDuNZ9xhjxufD/FLFpEm88LqnrpNJcRa81ZXy7B0yKQ2PaOedaKWELPjwcF7RoaOp2N1A==", "engines": { "node": ">=18.0.0" } diff --git a/package.json b/package.json index b2585af82..c8e70de7a 100644 --- a/package.json +++ b/package.json @@ -77,7 +77,7 @@ "tmp": "^0.2.2", "uuid": "^9.0.1", "ws": "^8.16.0", - "ya-ts-client": "^1.1.1-beta.1" + "ya-ts-client": "^1.1.2" }, "devDependencies": { "@commitlint/cli": "^19.0.3", diff --git a/src/activity/activity.module.ts b/src/activity/activity.module.ts index a59ec1e35..be982e272 100644 --- a/src/activity/activity.module.ts +++ b/src/activity/activity.module.ts @@ -125,11 +125,28 @@ export class ActivityModuleImpl implements ActivityModule { async executeScript(activity: Activity, script: ExeScriptRequest): Promise { this.logger.debug("Executing script on activity", { activityId: activity.id }); try { + this.events.emit("scriptSent", activity, script); const result = await this.activityApi.executeScript(activity, script); - this.events.emit("scriptExecuted", activity, script, result); + this.events.emit( + "scriptExecuted", + await this.refreshActivity(activity).catch(() => { + this.logger.warn("Failed to refresh activity after script execution", { activityId: activity.id }); + return activity; + }), + script, + result, + ); return result; } catch (error) { - this.events.emit("errorExecutingScript", activity, script, error); + this.events.emit( + "errorExecutingScript", + await this.refreshActivity(activity).catch(() => { + this.logger.warn("Failed to refresh activity after script execution error", { activityId: activity.id }); + return activity; + }), + script, + error, + ); throw error; } } @@ -142,10 +159,26 @@ export class ActivityModuleImpl implements ActivityModule { this.logger.debug("Fetching batch results", { activityId: activity.id, batchId }); try { const results = await this.activityApi.getExecBatchResults(activity, batchId, commandIndex, timeout); - this.events.emit("batchResultsReceived", activity, batchId, results); + this.events.emit( + "batchResultsReceived", + await this.refreshActivity(activity).catch(() => { + this.logger.warn("Failed to refresh activity after batch results received", { activityId: activity.id }); + return activity; + }), + batchId, + results, + ); return results; } catch (error) { - this.events.emit("errorGettingBatchResults", activity, batchId, error); + this.events.emit( + "errorGettingBatchResults", + await this.refreshActivity(activity).catch(() => { + this.logger.warn("Failed to refresh activity after batch results error", { activityId: activity.id }); + return activity; + }), + batchId, + error, + ); throw error; } } @@ -156,11 +189,27 @@ export class ActivityModuleImpl implements ActivityModule { ): Observable { this.logger.debug("Observing streaming batch events", { activityId: activity.id, batchId }); return this.activityApi.getExecBatchEvents(activity, batchId, commandIndex).pipe( - tap((event) => { - this.events.emit("batchEventsReceived", activity, batchId, event); + tap(async (event) => { + this.events.emit( + "batchEventsReceived", + await this.refreshActivity(activity).catch(() => { + this.logger.warn("Failed to refresh activity after batch events received", { activityId: activity.id }); + return activity; + }), + batchId, + event, + ); }), - catchError((error) => { - this.events.emit("errorGettingBatchEvents", activity, batchId, error); + catchError(async (error) => { + this.events.emit( + "errorGettingBatchEvents", + await this.refreshActivity(activity).catch(() => { + this.logger.warn("Failed to refresh activity after batch events error", { activityId: activity.id }); + return activity; + }), + batchId, + error, + ); throw error; }), ); @@ -169,7 +218,7 @@ export class ActivityModuleImpl implements ActivityModule { async createActivity(agreement: Agreement): Promise { this.logger.info("Creating activity", { agreementId: agreement.id, - provider: agreement.getProviderInfo(), + provider: agreement.provider, }); try { const activity = await this.activityApi.createActivity(agreement); @@ -185,7 +234,7 @@ export class ActivityModuleImpl implements ActivityModule { this.logger.info("Destroying activity", { activityId: activity.id, agreementId: activity.agreement.id, - provider: activity.agreement.getProviderInfo(), + provider: activity.agreement.provider, }); try { const updated = await this.activityApi.destroyActivity(activity); @@ -205,13 +254,13 @@ export class ActivityModuleImpl implements ActivityModule { }); try { const freshActivity = await this.activityApi.getActivity(staleActivity.id); - if (freshActivity.getState() !== staleActivity.getState()) { + if (freshActivity.getState() !== freshActivity.getPreviousState()) { this.logger.debug("Activity state changed", { activityId: staleActivity.id, - previousState: staleActivity.getState(), + previousState: freshActivity.getPreviousState(), newState: freshActivity.getState(), }); - this.events.emit("activityStateChanged", freshActivity, staleActivity.getState()); + this.events.emit("activityStateChanged", freshActivity, freshActivity.getPreviousState()); } return freshActivity; } catch (error) { @@ -236,10 +285,25 @@ export class ActivityModuleImpl implements ActivityModule { this.logger.debug("Initializing the exe-unit for activity", { activityId: activity.id }); try { await ctx.before(); - this.events.emit("activityInitialized", activity); + this.events.emit( + "workContextInitialized", + await this.refreshActivity(activity).catch(() => { + this.logger.warn("Failed to refresh activity after work context initialization", { activityId: activity.id }); + return activity; + }), + ); return ctx; } catch (error) { - this.events.emit("errorInitializingActivity", activity, error); + this.events.emit( + "errorInitializingWorkContext", + await this.refreshActivity(activity).catch(() => { + this.logger.warn("Failed to refresh activity after work context initialization error", { + activityId: activity.id, + }); + return activity; + }), + error, + ); throw error; } } diff --git a/src/activity/activity.ts b/src/activity/activity.ts index e89f6d41b..d9bdc601c 100644 --- a/src/activity/activity.ts +++ b/src/activity/activity.ts @@ -31,20 +31,26 @@ export class Activity { * @param id The ID of the activity in Yagna * @param agreement The agreement that's related to this activity * @param currentState The current state as it was obtained from yagna + * @param previousState The previous state (or New if this is the first time we're creating the activity) * @param usage Current resource usage vector information */ constructor( public readonly id: string, public readonly agreement: Agreement, protected readonly currentState: ActivityStateEnum = ActivityStateEnum.New, + protected readonly previousState: ActivityStateEnum = ActivityStateEnum.Unknown, protected readonly usage: ActivityUsageInfo, ) {} - public getProviderInfo(): ProviderInfo { - return this.agreement.getProviderInfo(); + public get provider(): ProviderInfo { + return this.agreement.provider; } public getState() { return this.currentState; } + + public getPreviousState() { + return this.previousState; + } } diff --git a/src/activity/api.ts b/src/activity/api.ts index 5e3174981..904cbc214 100644 --- a/src/activity/api.ts +++ b/src/activity/api.ts @@ -11,12 +11,13 @@ export type ActivityEvents = { activityDestroyed: (activity: Activity) => void; errorDestroyingActivity: (activity: Activity, error: Error) => void; - activityInitialized: (activity: Activity) => void; - errorInitializingActivity: (activity: Activity, error: Error) => void; + workContextInitialized: (activity: Activity) => void; + errorInitializingWorkContext: (activity: Activity, error: Error) => void; activityStateChanged: (activity: Activity, previousState: ActivityStateEnum) => void; errorRefreshingActivity: (activity: Activity, error: Error) => void; + scriptSent: (activity: Activity, script: ExeScriptRequest) => void; scriptExecuted: (activity: Activity, script: ExeScriptRequest, result: string) => void; errorExecutingScript: (activity: Activity, script: ExeScriptRequest, error: Error) => void; diff --git a/src/activity/exe-script-executor.test.ts b/src/activity/exe-script-executor.test.ts index 491159733..7f7d19f2a 100644 --- a/src/activity/exe-script-executor.test.ts +++ b/src/activity/exe-script-executor.test.ts @@ -24,7 +24,7 @@ describe("ExeScriptExecutor", () => { reset(mockStorageProvider); reset(mockActivityModule); resetAllMocks(); - when(mockActivity.getProviderInfo()).thenReturn({ + when(mockActivity.provider).thenReturn({ id: "test-provider-id", name: "test-provider-name", walletAddress: "0xProviderWallet", diff --git a/src/activity/exe-script-executor.ts b/src/activity/exe-script-executor.ts index 7d7f4093e..b8c00bc10 100644 --- a/src/activity/exe-script-executor.ts +++ b/src/activity/exe-script-executor.ts @@ -89,7 +89,7 @@ export class ExeScriptExecutor { WorkErrorCode.ScriptExecutionFailed, this.activity.agreement, this.activity, - this.activity.getProviderInfo(), + this.activity.provider, error, ); } @@ -175,7 +175,7 @@ export class ExeScriptExecutor { WorkErrorCode.ActivityResultsFetchingFailed, agreement, activity, - activity.getProviderInfo(), + activity.provider, error, ), ); @@ -224,7 +224,7 @@ export class ExeScriptExecutor { WorkErrorCode.ActivityResultsFetchingFailed, activity.agreement, activity, - activity.getProviderInfo(), + activity.provider, ); } if (error) { diff --git a/src/activity/work/batch.spec.ts b/src/activity/work/batch.spec.ts index f82db017f..ba9a72f59 100644 --- a/src/activity/work/batch.spec.ts +++ b/src/activity/work/batch.spec.ts @@ -34,8 +34,8 @@ describe("Batch", () => { walletAddress: "0xTestProvider", }; - when(mockAgreement.getProviderInfo()).thenReturn(providerInfo); - when(mockActivity.getProviderInfo()).thenReturn(providerInfo); + when(mockAgreement.provider).thenReturn(providerInfo); + when(mockActivity.provider).thenReturn(providerInfo); when(mockActivity.agreement).thenReturn(instance(mockAgreement)); activity = instance(mockActivity); @@ -161,7 +161,7 @@ describe("Batch", () => { WorkErrorCode.ScriptExecutionFailed, activity.agreement, activity, - activity.getProviderInfo(), + activity.provider, new Error("FAILURE"), ), ); @@ -179,7 +179,7 @@ describe("Batch", () => { WorkErrorCode.ScriptExecutionFailed, activity.agreement, activity, - activity.getProviderInfo(), + activity.provider, new Error("ERROR"), ), ); @@ -197,7 +197,7 @@ describe("Batch", () => { WorkErrorCode.ScriptExecutionFailed, activity.agreement, activity, - activity.getProviderInfo(), + activity.provider, new Error("FAILURE"), ), ); @@ -284,7 +284,7 @@ describe("Batch", () => { WorkErrorCode.ScriptExecutionFailed, activity.agreement, activity, - activity.getProviderInfo(), + activity.provider, new Error("ERROR"), ), ); diff --git a/src/activity/work/batch.ts b/src/activity/work/batch.ts index 0a6e5ff46..d99801ff5 100644 --- a/src/activity/work/batch.ts +++ b/src/activity/work/batch.ts @@ -107,7 +107,7 @@ export class Batch { WorkErrorCode.ScriptExecutionFailed, this.executor.activity.agreement, this.executor.activity, - this.executor.activity.agreement.getProviderInfo(), + this.executor.activity.agreement.provider, error, ); this.logger.debug("Error in batch script execution"); @@ -130,7 +130,7 @@ export class Batch { WorkErrorCode.ScriptExecutionFailed, this.executor.activity.agreement, this.executor.activity, - this.executor.activity.agreement.getProviderInfo(), + this.executor.activity.agreement.provider, error, ); } @@ -153,7 +153,7 @@ export class Batch { WorkErrorCode.ScriptExecutionFailed, this.executor.activity.agreement, this.executor.activity, - this.executor.activity.agreement.getProviderInfo(), + this.executor.activity.agreement.provider, error, ); } @@ -169,7 +169,7 @@ export class Batch { WorkErrorCode.ScriptExecutionFailed, activity.agreement, activity, - activity.getProviderInfo(), + activity.provider, ) : null; if (error) { diff --git a/src/activity/work/process.ts b/src/activity/work/process.ts index b6dec6ef0..7f9577eae 100644 --- a/src/activity/work/process.ts +++ b/src/activity/work/process.ts @@ -53,7 +53,7 @@ export class RemoteProcess { WorkErrorCode.ActivityResultsFetchingFailed, this.activity.agreement, this.activity, - this.activity.getProviderInfo(), + this.activity.provider, new GolemTimeoutError(`The waiting time (${timeoutInMs} ms) for the final result has been exceeded`), ), ); @@ -72,7 +72,7 @@ export class RemoteProcess { WorkErrorCode.ActivityResultsFetchingFailed, this.activity.agreement, this.activity, - this.activity.getProviderInfo(), + this.activity.provider, ), ); this.activityModule diff --git a/src/activity/work/work.ts b/src/activity/work/work.ts index ce47701e7..7214f59de 100644 --- a/src/activity/work/work.ts +++ b/src/activity/work/work.ts @@ -76,7 +76,7 @@ export class WorkContext { this.activityDeployingTimeout = options?.activityDeployingTimeout || DEFAULTS.activityDeployingTimeout; this.logger = options?.logger ?? defaultLogger("work"); - this.provider = activity.getProviderInfo(); + this.provider = activity.provider; this.storageProvider = options?.storageProvider ?? new NullStorageProvider(); this.networkNode = options?.networkNode; @@ -127,7 +127,7 @@ export class WorkContext { WorkErrorCode.ActivityDeploymentFailed, this.activity.agreement, this.activity, - this.activity.getProviderInfo(), + this.activity.provider, ); } await this.setupActivity(); @@ -157,7 +157,7 @@ export class WorkContext { WorkErrorCode.ActivityDeploymentFailed, this.activity.agreement, this.activity, - this.activity.getProviderInfo(), + this.activity.provider, e, ); }); @@ -180,7 +180,7 @@ export class WorkContext { WorkErrorCode.ActivityDeploymentFailed, this.activity.agreement, this.activity, - this.activity.getProviderInfo(), + this.activity.provider, ); } })(), @@ -194,7 +194,7 @@ export class WorkContext { WorkErrorCode.ActivityDeploymentFailed, this.activity.agreement, this.activity, - this.activity.getProviderInfo(), + this.activity.provider, error, ); }) @@ -282,7 +282,7 @@ export class WorkContext { WorkErrorCode.ScriptExecutionFailed, this.activity.agreement, this.activity, - this.activity.getProviderInfo(), + this.activity.provider, e, ); }); @@ -361,7 +361,7 @@ export class WorkContext { WorkErrorCode.NetworkSetupMissing, this.activity.agreement, this.activity, - this.activity.getProviderInfo(), + this.activity.provider, ); return this.networkNode.getWebsocketUri(port); @@ -374,7 +374,7 @@ export class WorkContext { WorkErrorCode.NetworkSetupMissing, this.activity.agreement, this.activity, - this.activity.getProviderInfo(), + this.activity.provider, ); return this.networkNode.ip; } @@ -413,7 +413,7 @@ export class WorkContext { WorkErrorCode.ScriptInitializationFailed, this.activity.agreement, this.activity, - this.activity.getProviderInfo(), + this.activity.provider, e, ); }); diff --git a/src/experimental/deployment/builder.test.ts b/src/experimental/deployment/builder.test.ts index 542709c51..bf1b1d708 100644 --- a/src/experimental/deployment/builder.test.ts +++ b/src/experimental/deployment/builder.test.ts @@ -59,22 +59,14 @@ describe("Deployment builder", () => { it("throws an error when creating a network with the same name", () => { const builder = new GolemDeploymentBuilder(mockGolemNetwork); expect(() => { - builder - .createNetwork("my-network", { - id: "test", - }) - .createNetwork("my-network", { - id: "test", - }); + builder.createNetwork("my-network").createNetwork("my-network"); }).toThrow(new GolemConfigError(`Network with name my-network already exists`)); }); it("throws an error when creating a deployment with an activity pool referencing a non-existing network", () => { const builder = new GolemDeploymentBuilder(mockGolemNetwork); expect(() => { builder - .createNetwork("existing-network", { - id: "test", - }) + .createNetwork("existing-network") .createLeaseProcessPool("my-pool", { demand: { workload: { imageTag: "image", minCpuCores: 1, minMemGib: 1, minStorageGib: 1 }, diff --git a/src/experimental/deployment/builder.ts b/src/experimental/deployment/builder.ts index 10f9f7d42..564b23500 100644 --- a/src/experimental/deployment/builder.ts +++ b/src/experimental/deployment/builder.ts @@ -38,7 +38,7 @@ export class GolemDeploymentBuilder { return this; } - createNetwork(name: string, options: NetworkOptions): this { + createNetwork(name: string, options: NetworkOptions = {}): this { if (this.components.networks.some((network) => network.name === name)) { throw new GolemConfigError(`Network with name ${name} already exists`); } diff --git a/src/golem-network/golem-network.test.ts b/src/golem-network/golem-network.test.ts index f392ad547..2b48916ad 100644 --- a/src/golem-network/golem-network.test.ts +++ b/src/golem-network/golem-network.test.ts @@ -96,6 +96,35 @@ describe("Golem Network", () => { verify(mockLeaseProcess.finalize()).once(); verify(mockPayment.releaseAllocation(allocation)).once(); }); + it("should not release the allocation if it was provided by the user", async () => { + const allocation = instance(mock(Allocation)); + + const mockLeaseProcess = mock(LeaseProcess); + const testProcess = instance(mockLeaseProcess); + when(mockLeaseProcess.finalize()).thenResolve(); + when(mockLease.createLease(_, _, _)).thenReturn(testProcess); + + when(mockMarket.collectDraftOfferProposals(_)).thenReturn(new Subject()); + jest.spyOn(DraftOfferProposalPool.prototype, "acquire").mockResolvedValue({} as OfferProposal); + + const glm = getGolemNetwork(); + await glm.connect(); + + const lease = await glm.oneOf({ + ...order, + payment: { + allocation, + }, + }); + + expect(lease).toBe(testProcess); + + await glm.disconnect(); + + verify(mockLeaseProcess.finalize()).once(); + verify(mockPayment.createAllocation(_)).never(); + verify(mockPayment.releaseAllocation(allocation)).never(); + }); }); describe("manyOf()", () => { @@ -127,5 +156,33 @@ describe("Golem Network", () => { verify(mockLeasePool.drainAndClear()).once(); verify(mockPayment.releaseAllocation(allocation)).once(); }); + it("should not release the allocation if it was provided by the user", async () => { + const allocation = instance(mock(Allocation)); + + when(mockMarket.collectDraftOfferProposals(_)).thenReturn(new Subject()); + const mockLeasePool = mock(LeaseProcessPool); + when(mockLeasePool.drainAndClear()).thenResolve(); + const leasePool = instance(mockLeasePool); + when(mockLease.createLeaseProcessPool(_, _, _)).thenReturn(leasePool); + + const glm = getGolemNetwork(); + await glm.connect(); + + const pool = await glm.manyOf({ + concurrency: 3, + order: { + ...order, + payment: { + allocation, + }, + }, + }); + + expect(pool).toBe(leasePool); + await glm.disconnect(); + verify(mockLeasePool.drainAndClear()).once(); + verify(mockPayment.createAllocation(_)).never(); + verify(mockPayment.releaseAllocation(allocation)).never(); + }); }); }); diff --git a/src/golem-network/golem-network.ts b/src/golem-network/golem-network.ts index c5e35e3e6..25579bf76 100644 --- a/src/golem-network/golem-network.ts +++ b/src/golem-network/golem-network.ts @@ -8,7 +8,7 @@ import { MarketOptions, OfferProposal, } from "../market"; -import { IPaymentApi, PaymentModule, PaymentModuleImpl, PaymentModuleOptions } from "../payment"; +import { Allocation, IPaymentApi, PaymentModule, PaymentModuleImpl, PaymentModuleOptions } from "../payment"; import { ActivityModule, ActivityModuleImpl, IActivityApi, IFileServer } from "../activity"; import { INetworkApi, Network, NetworkModule, NetworkModuleImpl, NetworkOptions } from "../network"; import { EventEmitter } from "eventemitter3"; @@ -111,6 +111,14 @@ export interface GolemNetworkOptions { >; } +type AllocationOptions = { + /** + * Optionally pass an existing allocation to use or an ID of an allocation that already exists in yagna. + * If this is not provided, a new allocation will be created based on an estimated budget. + */ + allocation?: Allocation | string; +}; + /** * Represents the order specifications which will result in access to LeaseProcess. */ @@ -118,7 +126,7 @@ export interface MarketOrderSpec { demand: BuildDemandOptions; market: MarketOptions; activity?: LeaseProcessOptions["activity"]; - payment?: LeaseProcessOptions["payment"]; + payment?: LeaseProcessOptions["payment"] & AllocationOptions; network?: Network; } @@ -249,7 +257,7 @@ export class GolemNetwork { marketApi: this.options.override?.marketApi || new MarketApiAdapter(this.yagna, agreementRepository, proposalRepository, demandRepository, this.logger), - networkApi: this.options.override?.networkApi || new NetworkApiAdapter(this.yagna, this.logger), + networkApi: this.options.override?.networkApi || new NetworkApiAdapter(this.yagna), fileServer: this.options.override?.fileServer || new GftpServerAdapter(this.storageProvider), }; this.network = getFactory(NetworkModuleImpl, this.options.override?.network)(this.services); @@ -314,6 +322,26 @@ export class GolemNetwork { this.hasConnection = false; } + private async getAllocationFromOrder({ + order, + concurrency, + }: { + order: MarketOrderSpec; + concurrency: Concurrency; + }): Promise { + if (!order.payment?.allocation) { + const budget = this.market.estimateBudget({ order, concurrency }); + return this.payment.createAllocation({ + budget, + expirationSec: order.market.rentHours * 60 * 60, + }); + } + if (typeof order.payment.allocation === "string") { + return this.payment.getAllocation(order.payment.allocation); + } + return order.payment.allocation; + } + /** * Define your computational resource demand and access a single instance * @@ -339,12 +367,7 @@ export class GolemNetwork { selectProposal: order.market.proposalSelector, }); - const budget = this.market.estimateBudget({ order, concurrency: 1 }); - const allocation = await this.payment.createAllocation({ - budget, - expirationSec: order.market.rentHours * 60 * 60, - }); - + const allocation = await this.getAllocationFromOrder({ order, concurrency: 1 }); const demandSpecification = await this.market.buildDemandDetails(order.demand, allocation); const draftProposal$ = this.market.collectDraftOfferProposals({ @@ -364,7 +387,7 @@ export class GolemNetwork { ); const networkNode = order.network - ? await this.network.createNetworkNode(order.network, agreement.getProviderInfo().id) + ? await this.network.createNetworkNode(order.network, agreement.provider.id) : undefined; const lease = this.lease.createLease(agreement, allocation, { @@ -385,6 +408,10 @@ export class GolemNetwork { .removeNetworkNode(order.network, networkNode) .catch((err) => this.logger.error("Error while removing network node", err)); } + // Don't release the allocation if it was provided by the user + if (order.payment?.allocation) { + return; + } await this.payment .releaseAllocation(allocation) .catch((err) => this.logger.error("Error while releasing allocation", err)); @@ -435,11 +462,7 @@ export class GolemNetwork { selectProposal: order.market.proposalSelector, }); - const budget = this.market.estimateBudget({ concurrency, order }); - const allocation = await this.payment.createAllocation({ - budget, - expirationSec: order.market.rentHours * 60 * 60, - }); + const allocation = await this.getAllocationFromOrder({ order, concurrency }); const demandSpecification = await this.market.buildDemandDetails(order.demand, allocation); const draftProposal$ = this.market.collectDraftOfferProposals({ @@ -469,7 +492,10 @@ export class GolemNetwork { await leaseProcessPool .drainAndClear() .catch((err) => this.logger.error("Error while draining lease process pool", err)); - + // Don't release the allocation if it was provided by the user + if (order.payment?.allocation) { + return; + } await this.payment .releaseAllocation(allocation) .catch((err) => this.logger.error("Error while releasing allocation", err)); diff --git a/src/lease-process/lease-process-pool.ts b/src/lease-process/lease-process-pool.ts index 348f8f46a..752f112a1 100644 --- a/src/lease-process/lease-process-pool.ts +++ b/src/lease-process/lease-process-pool.ts @@ -115,7 +115,7 @@ export class LeaseProcessPool { signalOrTimeout, ); const networkNode = this.network - ? await this.networkModule.createNetworkNode(this.network, agreement.getProviderInfo().id) + ? await this.networkModule.createNetworkNode(this.network, agreement.provider.id) : undefined; const leaseProcess = this.leaseModule.createLease(agreement, this.allocation, { networkNode, diff --git a/src/market/agreement/agreement.ts b/src/market/agreement/agreement.ts index 2efb53421..8e171ab20 100644 --- a/src/market/agreement/agreement.ts +++ b/src/market/agreement/agreement.ts @@ -51,7 +51,7 @@ export class Agreement { return this.model.state; } - getProviderInfo(): ProviderInfo { + get provider(): ProviderInfo { return { id: this.model.offer.providerId, name: this.model.offer.properties["golem.node.id.name"], diff --git a/src/market/api.ts b/src/market/api.ts index dd73df10b..779ac42cc 100644 --- a/src/market/api.ts +++ b/src/market/api.ts @@ -27,13 +27,16 @@ export type MarketEvents = { /** Emitted when offer proposal from the Provider is received */ offerProposalReceived: (event: OfferProposalReceivedEvent) => void; + offerCounterProposalSent: (offerProposal: OfferProposal, counterProposal: OfferCounterProposal) => void; + errorSendingCounterProposal: (offerProposal: OfferProposal, error: Error) => void; + /** Emitted when the Provider rejects the counter-proposal that the Requestor sent */ offerCounterProposalRejected: (event: OfferCounterProposalRejectedEvent) => void; /** Not implemented */ offerPropertyQueryReceived: (event: OfferPropertyQueryReceivedEvent) => void; - offerProposalRejectedByFilter: (offerProposal: OfferProposal, reason?: string) => void; + offerProposalRejectedByProposalFilter: (offerProposal: OfferProposal, reason?: string) => void; /** Emitted when proposal price does not meet user criteria */ offerProposalRejectedByPriceFilter: (offerProposal: OfferProposal, reason?: string) => void; diff --git a/src/market/market.module.test.ts b/src/market/market.module.test.ts index 875cba6a6..b6e34effa 100644 --- a/src/market/market.module.test.ts +++ b/src/market/market.module.test.ts @@ -1,9 +1,9 @@ import { _, imock, instance, mock, reset, spy, verify, when } from "@johanblumenberg/ts-mockito"; import { Logger, YagnaApi } from "../shared/utils"; import { MarketModuleImpl } from "./market.module"; -import { Demand, DemandSpecification, IDemandRepository } from "./demand"; +import { Demand, DemandSpecification } from "./demand"; import { Subject, take } from "rxjs"; -import { IProposalRepository, MarketProposalEvent, OfferProposal, ProposalProperties } from "./proposal"; +import { MarketProposalEvent, OfferProposal, ProposalProperties } from "./proposal"; import { MarketApiAdapter } from "../shared/yagna/"; import { IActivityApi, IFileServer } from "../activity"; import { StorageProvider } from "../shared/storage"; @@ -37,8 +37,6 @@ beforeEach(() => { activityApi: instance(imock()), paymentApi: instance(imock()), networkApi: instance(imock()), - proposalRepository: instance(imock()), - demandRepository: instance(imock()), yagna: instance(mockYagna), logger: instance(imock()), marketApi: instance(mockMarketApiAdapter), diff --git a/src/market/market.module.ts b/src/market/market.module.ts index cdedf763e..544972801 100644 --- a/src/market/market.module.ts +++ b/src/market/market.module.ts @@ -20,14 +20,13 @@ import { import { Allocation, IPaymentApi } from "../payment"; import { filter, map, Observable, OperatorFunction, switchMap, tap } from "rxjs"; import { - IProposalRepository, OfferCounterProposal, OfferProposal, OfferProposalReceivedEvent, ProposalFilter, ProposalsBatch, } from "./proposal"; -import { BuildDemandOptions, DemandBodyBuilder, DemandSpecification, IDemandRepository } from "./demand"; +import { BuildDemandOptions, DemandBodyBuilder, DemandSpecification } from "./demand"; import { IActivityApi, IFileServer } from "../activity"; import { StorageProvider } from "../shared/storage"; import { WorkloadDemandDirectorConfig } from "./demand/directors/workload-demand-director-config"; @@ -211,16 +210,12 @@ export class MarketModuleImpl implements MarketModule { private readonly logger = defaultLogger("market"); private readonly marketApi: IMarketApi; - private readonly proposalRepo: IProposalRepository; - private readonly demandRepo: IDemandRepository; private fileServer: IFileServer; constructor( private readonly deps: { logger: Logger; yagna: YagnaApi; - proposalRepository: IProposalRepository; - demandRepository: IDemandRepository; paymentApi: IPaymentApi; activityApi: IActivityApi; marketApi: IMarketApi; @@ -232,8 +227,6 @@ export class MarketModuleImpl implements MarketModule { ) { this.logger = deps.logger; this.marketApi = deps.marketApi; - this.proposalRepo = deps.proposalRepository; - this.demandRepo = deps.demandRepository; this.fileServer = deps.fileServer; this.collectAndEmitAgreementEvents(); @@ -377,11 +370,15 @@ export class MarketModuleImpl implements MarketModule { offerProposal: OfferProposal, counterDemand: DemandSpecification, ): Promise { - const counterProposal = await this.deps.marketApi.counterProposal(offerProposal, counterDemand); - - this.logger.debug("Counter proposal sent", counterProposal); - - return counterProposal; + try { + const counterProposal = await this.deps.marketApi.counterProposal(offerProposal, counterDemand); + this.logger.debug("Counter proposal sent", counterProposal); + this.events.emit("offerCounterProposalSent", offerProposal, counterProposal); + return counterProposal; + } catch (error) { + this.events.emit("errorSendingCounterProposal", offerProposal, error); + throw error; + } } async proposeAgreement(proposal: OfferProposal, options?: AgreementOptions): Promise { @@ -389,7 +386,7 @@ export class MarketModuleImpl implements MarketModule { this.logger.info("Proposed and got approval for agreement", { agreementId: agreement.id, - provider: agreement.getProviderInfo(), + provider: agreement.provider, }); return agreement; @@ -400,7 +397,7 @@ export class MarketModuleImpl implements MarketModule { this.logger.info("Terminated agreement", { agreementId: agreement.id, - provider: agreement.getProviderInfo(), + provider: agreement.provider, reason, }); @@ -631,7 +628,7 @@ export class MarketModuleImpl implements MarketModule { const result = filter(proposal); if (!result) { - this.events.emit("offerProposalRejectedByFilter", proposal); + this.events.emit("offerProposalRejectedByProposalFilter", proposal); this.logger.debug("The offer was rejected by the user filter", { id: proposal.id }); } diff --git a/src/market/proposal/offer-proposal.ts b/src/market/proposal/offer-proposal.ts index 9fce9e381..a6abf12c5 100644 --- a/src/market/proposal/offer-proposal.ts +++ b/src/market/proposal/offer-proposal.ts @@ -40,7 +40,6 @@ export type ProposalDTO = Partial<{ */ export class OfferProposal extends MarketProposal { public readonly issuer = "Provider"; - public provider: ProviderInfo; constructor( model: MarketApi.ProposalDTO, @@ -48,8 +47,6 @@ export class OfferProposal extends MarketProposal { ) { super(model); - this.provider = this.getProviderInfo(); - this.validate(); } @@ -107,7 +104,7 @@ export class OfferProposal extends MarketProposal { return this.pricing.start + this.pricing.cpuSec * threadsNo + this.pricing.envSec; } - public getProviderInfo(): ProviderInfo { + public get provider(): ProviderInfo { return { id: this.model.issuerId, name: this.properties["golem.node.id.name"], diff --git a/src/network/network.module.test.ts b/src/network/network.module.test.ts index 7d8890d98..aba27904a 100644 --- a/src/network/network.module.test.ts +++ b/src/network/network.module.test.ts @@ -48,10 +48,9 @@ describe("Network", () => { }); it("should create network with 16 bit mask", async () => { - await networkModule.createNetwork({ id: "1", ip: "192.168.7.0/16" }); + await networkModule.createNetwork({ ip: "192.168.7.0/16" }); expect(capture(mockNetworkApi.createNetwork).last()).toEqual([ { - id: "1", ip: "192.168.0.0", mask: "255.255.0.0", }, @@ -59,10 +58,9 @@ describe("Network", () => { }); it("should create network with 24 bit mask", async () => { - await networkModule.createNetwork({ id: "1", ip: "192.168.7.0/24" }); + await networkModule.createNetwork({ ip: "192.168.7.0/24" }); expect(capture(mockNetworkApi.createNetwork).last()).toEqual([ { - id: "1", ip: "192.168.7.0", mask: "255.255.255.0", }, @@ -70,10 +68,9 @@ describe("Network", () => { }); it("should create network with 8 bit mask", async () => { - await networkModule.createNetwork({ id: "1", ip: "192.168.7.0/8" }); + await networkModule.createNetwork({ ip: "192.168.7.0/8" }); expect(capture(mockNetworkApi.createNetwork).last()).toEqual([ { - id: "1", ip: "192.0.0.0", mask: "255.0.0.0", }, @@ -81,7 +78,7 @@ describe("Network", () => { }); it("should not create network with invalid ip", async () => { - const shouldFail = networkModule.createNetwork({ id: "1", ip: "123.1.2" }); + const shouldFail = networkModule.createNetwork({ ip: "123.1.2" }); await expect(shouldFail).rejects.toMatchError( new GolemNetworkError( "Unable to create network. An IP4 number cannot have less or greater than 4 octets", @@ -95,12 +92,10 @@ describe("Network", () => { it("should create network with custom gateway", async () => { await networkModule.createNetwork({ ip: "192.168.0.1/27", - id: "owner_1", gateway: "192.168.0.2", }); expect(capture(mockNetworkApi.createNetwork).last()).toEqual([ { - id: "owner_1", ip: "192.168.0.0", mask: "255.255.255.224", gateway: "192.168.0.2", diff --git a/src/network/network.module.ts b/src/network/network.module.ts index 532b8bf4f..62b18f4c9 100644 --- a/src/network/network.module.ts +++ b/src/network/network.module.ts @@ -9,16 +9,10 @@ import AsyncLock from "async-lock"; import { getMessageFromApiError } from "../shared/utils/apiErrorMessage"; export interface NetworkOptions { - /** - * The ID of the network. - * This is an optional field that can be used to specify a unique identifier for the network. - * If not provided, it will be generated automatically. - */ - id?: string; - /** * The IP address of the network. May contain netmask, e.g. "192.168.0.0/24". * This field can include the netmask directly in CIDR notation. + * @default "192.168.0.0" */ ip?: string; @@ -99,7 +93,6 @@ export class NetworkModuleImpl implements NetworkModule { const mask = ipRange.getPrefix().toMask(); const gateway = options?.gateway ? new IPv4(options.gateway) : undefined; const network = await this.networkApi.createNetwork({ - id: options?.id, ip: ip.toString(), mask: mask?.toString(), gateway: gateway?.toString(), diff --git a/src/payment/agreement_payment_process.spec.ts b/src/payment/agreement_payment_process.spec.ts index 15b9bcede..9d6f70079 100644 --- a/src/payment/agreement_payment_process.spec.ts +++ b/src/payment/agreement_payment_process.spec.ts @@ -30,7 +30,7 @@ beforeEach(() => { walletAddress: "0x1234", }; - when(agreementMock.getProviderInfo()).thenReturn(testProviderInfo); + when(agreementMock.provider).thenReturn(testProviderInfo); when(invoiceMock.provider).thenReturn(testProviderInfo); when(mockPaymentModule.observeInvoices()).thenReturn(new Subject()); when(mockPaymentModule.observeDebitNotes()).thenReturn(new Subject()); @@ -170,7 +170,7 @@ describe("AgreementPaymentProcess", () => { "Agreement agreement-id is already covered with an invoice: invoice-id", PaymentErrorCode.AgreementAlreadyPaid, allocation, - agreement.getProviderInfo(), + agreement.provider, ), ); expect(process.isFinished()).toEqual(true); diff --git a/src/payment/agreement_payment_process.ts b/src/payment/agreement_payment_process.ts index 8aeef8b30..0af0cce7e 100644 --- a/src/payment/agreement_payment_process.ts +++ b/src/payment/agreement_payment_process.ts @@ -256,7 +256,7 @@ export class AgreementPaymentProcess { invoiceId: invoice.id, agreementId: invoice.agreementId, amount: invoice.amount, - provider: this.agreement.getProviderInfo(), + provider: this.agreement.provider, }); } catch (error) { const message = getMessageFromApiError(error); diff --git a/src/payment/api.ts b/src/payment/api.ts index 806ceafa9..19279385b 100644 --- a/src/payment/api.ts +++ b/src/payment/api.ts @@ -57,7 +57,31 @@ export interface IPaymentApi { } export type CreateAllocationParams = { + /** + * How much to allocate + */ budget: number; - paymentPlatform: string; + /** + * How long the allocation should be valid + */ expirationSec: number; + /** + * Optionally override the payment platform to use for this allocation + */ + paymentPlatform?: string; + /** + * Optionally provide a deposit to be used for the allocation, instead of using funds from the yagna wallet. + * Deposit is a way to pay for the computation using someone else's funds. The other party has to + * call the `createDeposit` method on the `LockPayment` smart contract and provide the deposit ID. + */ + deposit?: { + /** + * Address of the smart contract that holds the deposit. + */ + contract: string; + /** + * ID of the deposit, obtained by calling the `createDeposit` method on the smart contract. + */ + id: string; + }; }; diff --git a/src/payment/payment.module.ts b/src/payment/payment.module.ts index c2c4e7682..aa18f61cf 100644 --- a/src/payment/payment.module.ts +++ b/src/payment/payment.module.ts @@ -47,12 +47,14 @@ export interface PaymentModule { observeInvoices(): Observable; - createAllocation(params: { budget: number; expirationSec: number }): Promise; + createAllocation(params: CreateAllocationParams): Promise; releaseAllocation(allocation: Allocation): Promise; amendAllocation(allocation: Allocation, params: CreateAllocationParams): Promise; + getAllocation(id: string): Promise; + acceptInvoice(invoice: Invoice, allocation: Allocation, amount: string): Promise; rejectInvoice(invoice: Invoice, reason: string): Promise; @@ -134,16 +136,13 @@ export class PaymentModuleImpl implements PaymentModule { return this.paymentApi.receivedInvoices$; } - async createAllocation(params: { budget: number; expirationSec: number }): Promise { - const payer = await this.getPayerDetails(); - - this.logger.info("Creating allocation", { params: params, payer }); + async createAllocation(params: CreateAllocationParams): Promise { + this.logger.info("Creating allocation", { params: params }); try { const allocation = await this.paymentApi.createAllocation({ - budget: params.budget, paymentPlatform: this.getPaymentPlatform(), - expirationSec: params.expirationSec, + ...params, }); this.events.emit("allocationCreated", allocation); return allocation; @@ -156,14 +155,30 @@ export class PaymentModuleImpl implements PaymentModule { async releaseAllocation(allocation: Allocation): Promise { this.logger.info("Releasing allocation", { id: allocation.id }); try { + const lastKnownAllocationState = await this.getAllocation(allocation.id).catch(() => { + this.logger.warn("Failed to fetch allocation before releasing", { id: allocation.id }); + return allocation; + }); await this.paymentApi.releaseAllocation(allocation); - this.events.emit("allocationReleased", allocation); + this.events.emit("allocationReleased", lastKnownAllocationState); } catch (error) { - this.events.emit("errorReleasingAllocation", allocation, error); + this.events.emit( + "errorReleasingAllocation", + await this.paymentApi.getAllocation(allocation.id).catch(() => { + this.logger.warn("Failed to fetch allocation after failed release attempt", { id: allocation.id }); + return allocation; + }), + error, + ); throw error; } } + getAllocation(id: string): Promise { + this.logger.debug("Fetching allocation by id", { id }); + return this.paymentApi.getAllocation(id); + } + // eslint-disable-next-line @typescript-eslint/no-unused-vars amendAllocation(allocation: Allocation, _newOpts: CreateAllocationParams): Promise { this.events.emit("errorAmendingAllocation", allocation, new Error("Amending allocation is not supported yet")); diff --git a/src/shared/yagna/adapters/activity-api-adapter.ts b/src/shared/yagna/adapters/activity-api-adapter.ts index 8c5000876..e0daae6b1 100644 --- a/src/shared/yagna/adapters/activity-api-adapter.ts +++ b/src/shared/yagna/adapters/activity-api-adapter.ts @@ -39,7 +39,7 @@ export class ActivityApiAdapter implements IActivityApi { WorkErrorCode.ActivityCreationFailed, agreement, undefined, - agreement.getProviderInfo(), + agreement.provider, ); } } @@ -55,7 +55,7 @@ export class ActivityApiAdapter implements IActivityApi { WorkErrorCode.ActivityDestroyingFailed, activity.agreement, activity, - activity.agreement.getProviderInfo(), + activity.agreement.provider, ); } } @@ -74,7 +74,7 @@ export class ActivityApiAdapter implements IActivityApi { WorkErrorCode.ScriptExecutionFailed, activity.agreement, activity, - activity.agreement.getProviderInfo(), + activity.agreement.provider, ); } } @@ -95,7 +95,7 @@ export class ActivityApiAdapter implements IActivityApi { WorkErrorCode.ActivityResultsFetchingFailed, activity.agreement, activity, - activity.getProviderInfo(), + activity.provider, error, ); } diff --git a/src/shared/yagna/adapters/market-api-adapter.ts b/src/shared/yagna/adapters/market-api-adapter.ts index 0dd5c8cf4..aa8c07d35 100644 --- a/src/shared/yagna/adapters/market-api-adapter.ts +++ b/src/shared/yagna/adapters/market-api-adapter.ts @@ -164,7 +164,6 @@ export class MarketApiAdapter implements IMarketApi { receivedProposal.id, bodyClone, ); - this.logger.debug("Proposal counter result from yagna", { result: maybeNewId }); if (typeof maybeNewId !== "string") { @@ -277,7 +276,7 @@ export class MarketApiAdapter implements IMarketApi { ); } - this.logger.info("Established agreement", { agreementId: agreement.id, provider: agreement.getProviderInfo() }); + this.logger.info("Established agreement", { agreementId: agreement.id, provider: agreement.provider }); return confirmed; } diff --git a/src/shared/yagna/adapters/network-api-adapter.ts b/src/shared/yagna/adapters/network-api-adapter.ts index c17046697..d4880c04c 100644 --- a/src/shared/yagna/adapters/network-api-adapter.ts +++ b/src/shared/yagna/adapters/network-api-adapter.ts @@ -1,19 +1,13 @@ import { YagnaApi } from "../yagnaApi"; -import { Logger } from "../../utils"; -import { INetworkApi } from "../../../network/api"; -import { GolemNetworkError, Network, NetworkErrorCode, NetworkNode } from "../../../network"; +import { GolemNetworkError, INetworkApi, Network, NetworkErrorCode, NetworkNode } from "../../../network"; import { getMessageFromApiError } from "../../utils/apiErrorMessage"; export class NetworkApiAdapter implements INetworkApi { - constructor( - private readonly yagnaApi: YagnaApi, - private readonly logger: Logger, - ) {} + constructor(private readonly yagnaApi: YagnaApi) {} - async createNetwork(options: { id: string; ip: string; mask?: string; gateway?: string }): Promise { + async createNetwork(options: { ip: string; mask?: string; gateway?: string }): Promise { try { const { id, ip, mask, gateway } = await this.yagnaApi.net.createNetwork(options); - // @ts-expect-error TODO: Remove when this PR is merged: https://github.com/golemfactory/ya-client/pull/179 return new Network(id, ip, mask, gateway); } catch (error) { const message = getMessageFromApiError(error); diff --git a/src/shared/yagna/adapters/payment-api-adapter.ts b/src/shared/yagna/adapters/payment-api-adapter.ts index 869697ed3..33033fd8d 100644 --- a/src/shared/yagna/adapters/payment-api-adapter.ts +++ b/src/shared/yagna/adapters/payment-api-adapter.ts @@ -182,13 +182,14 @@ export class PaymentApiAdapter implements IPaymentApi { const model = await this.yagna.payment.createAllocation({ totalAmount: params.budget.toString(), paymentPlatform: params.paymentPlatform, - address: address, + address, timestamp: now.toISOString(), timeout: new Date(+now + params.expirationSec * 1000).toISOString(), makeDeposit: false, remainingAmount: "", spentAmount: "", allocationId: "", + deposit: params.deposit, }); this.logger.debug( diff --git a/src/shared/yagna/repository/activity-repository.ts b/src/shared/yagna/repository/activity-repository.ts index 7d0279d42..aae7d15e7 100644 --- a/src/shared/yagna/repository/activity-repository.ts +++ b/src/shared/yagna/repository/activity-repository.ts @@ -3,8 +3,11 @@ import { ActivityApi } from "ya-ts-client"; import { IAgreementRepository } from "../../../market/agreement/agreement"; import { getMessageFromApiError } from "../../utils/apiErrorMessage"; import { GolemWorkError, WorkErrorCode } from "../../../activity"; +import { CacheService } from "../../cache/CacheService"; export class ActivityRepository implements IActivityRepository { + private stateCache: CacheService = new CacheService(); + constructor( private readonly state: ActivityApi.RequestorStateService, private readonly agreementRepo: IAgreementRepository, @@ -14,10 +17,11 @@ export class ActivityRepository implements IActivityRepository { try { const agreementId = await this.state.getActivityAgreement(id); const agreement = await this.agreementRepo.getById(agreementId); + const previousState = this.stateCache.get(id) ?? ActivityStateEnum.New; const state = await this.getStateOfActivity(id); const usage = await this.state.getActivityUsage(id); - return new Activity(id, agreement, state ?? ActivityStateEnum.Unknown, usage); + return new Activity(id, agreement, state ?? ActivityStateEnum.Unknown, previousState, usage); } catch (error) { const message = getMessageFromApiError(error); throw new GolemWorkError( @@ -33,12 +37,14 @@ export class ActivityRepository implements IActivityRepository { async getStateOfActivity(id: string): Promise { try { - const state = await this.state.getActivityState(id); - if (!state || state.state[0] === null) { + const yagnaStateResponse = await this.state.getActivityState(id); + if (!yagnaStateResponse || yagnaStateResponse.state[0] === null) { return ActivityStateEnum.Unknown; } - return ActivityStateEnum[state.state[0]]; + const state = ActivityStateEnum[yagnaStateResponse.state[0]]; + this.stateCache.set(id, state); + return state; } catch (error) { const message = getMessageFromApiError(error); throw new GolemWorkError( diff --git a/tests/examples/examples.json b/tests/examples/examples.json index 85036fefd..3aad9f743 100644 --- a/tests/examples/examples.json +++ b/tests/examples/examples.json @@ -10,6 +10,7 @@ { "cmd": "tsx", "path": "examples/advanced/proposal-filter.ts" }, { "cmd": "tsx", "path": "examples/advanced/proposal-predefined-filter.ts" }, { "cmd": "tsx", "path": "examples/advanced/override-module.ts" }, + { "cmd": "tsx", "path": "examples/advanced/reuse-allocation.ts" }, { "cmd": "tsx", "path": "examples/experimental/deployment/new-api.ts" }, { "cmd": "tsx", "path": "examples/experimental/job/getJobById.ts" }, { "cmd": "tsx", "path": "examples/experimental/job/waitForResults.ts" }, diff --git a/tests/unit/activity.test.ts b/tests/unit/activity.test.ts index 0e1586fe4..b5865091b 100644 --- a/tests/unit/activity.test.ts +++ b/tests/unit/activity.test.ts @@ -6,12 +6,20 @@ const mockAgreement = mock(Agreement); describe("Activity", () => { describe("Getting state", () => { it("should get activity state", () => { - const activity = new Activity("activity-id", instance(mockAgreement), ActivityStateEnum.New, { - currentUsage: [0.0, 0.0, 0.0], - timestamp: Date.now(), - }); + const activity = new Activity( + "activity-id", + instance(mockAgreement), + ActivityStateEnum.Initialized, + ActivityStateEnum.New, + { + currentUsage: [0.0, 0.0, 0.0], + timestamp: Date.now(), + }, + ); const state = activity.getState(); - expect(state).toEqual(ActivityStateEnum.New); + const prev = activity.getPreviousState(); + expect(state).toEqual(ActivityStateEnum.Initialized); + expect(prev).toEqual(ActivityStateEnum.New); }); }); }); diff --git a/tests/unit/agreement.test.ts b/tests/unit/agreement.test.ts index ff9736ab0..ba484ad91 100644 --- a/tests/unit/agreement.test.ts +++ b/tests/unit/agreement.test.ts @@ -38,12 +38,12 @@ const demand = new Demand( ); describe("Agreement", () => { - describe("getProviderInfo()", () => { + describe("get provider()", () => { it("should be a instance ProviderInfo with provider details", () => { const agreement = new Agreement(agreementData.agreementId, agreementData, demand); - expect(agreement.getProviderInfo().id).toEqual("provider-id"); - expect(agreement.getProviderInfo().name).toEqual("provider-name"); - expect(agreement.getProviderInfo().walletAddress).toEqual("0xProviderWallet"); + expect(agreement.provider.id).toEqual("provider-id"); + expect(agreement.provider.name).toEqual("provider-name"); + expect(agreement.provider.walletAddress).toEqual("0xProviderWallet"); }); }); diff --git a/tests/unit/work.test.ts b/tests/unit/work.test.ts index 3ce1a1b61..c0d049b6b 100644 --- a/tests/unit/work.test.ts +++ b/tests/unit/work.test.ts @@ -40,7 +40,7 @@ describe("Work Context", () => { reset(mockStorageProvider); reset(mockAgreement); reset(mockActivityModule); - when(mockActivity.getProviderInfo()).thenReturn({ + when(mockActivity.provider).thenReturn({ id: "test-provider-id", name: "test-provider-name", walletAddress: "0xProviderWallet",