Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OTA: chunked download #520

Merged
merged 18 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/ArduinoIoTCloudTCP.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,18 @@ class ArduinoIoTCloudTCP: public ArduinoIoTCloudClass
_get_ota_confirmation = cb;

if(_get_ota_confirmation) {
_ota.setOtaPolicies(OTACloudProcessInterface::ApprovalRequired);
_ota.enableOtaPolicy(OTACloudProcessInterface::ApprovalRequired);
} else {
_ota.setOtaPolicies(OTACloudProcessInterface::None);
_ota.disableOtaPolicy(OTACloudProcessInterface::ApprovalRequired);
}
}

/* Slower but more reliable in some corner cases */
void setOTAChunkMode(bool enable = true) {
if(enable) {
_ota.enableOtaPolicy(OTACloudProcessInterface::ChunkDownload);
} else {
_ota.disableOtaPolicy(OTACloudProcessInterface::ChunkDownload);
}
}
#endif
Expand Down
4 changes: 2 additions & 2 deletions src/ota/interface/OTAInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,10 @@ OTACloudProcessInterface::State OTACloudProcessInterface::idle(Message* msg) {
OTACloudProcessInterface::State OTACloudProcessInterface::otaAvailable() {
// depending on the policy decided on this device the ota process can start immediately
// or wait for confirmation from the user
if((policies & (ApprovalRequired | Approved)) == ApprovalRequired ) {
if(getOtaPolicy(ApprovalRequired) && !getOtaPolicy(Approved)) {
return OtaAvailable;
} else {
policies &= ~Approved;
disableOtaPolicy(Approved);
return StartOTA;
} // TODO add an abortOTA command? in this case delete the context
}
Expand Down
9 changes: 7 additions & 2 deletions src/ota/interface/OTAInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,17 +80,22 @@ class OTACloudProcessInterface: public CloudProcess {
enum OtaFlags: uint16_t {
None = 0,
ApprovalRequired = 1,
Approved = 1<<1
Approved = 1<<1,
ChunkDownload = 1<<2
};

virtual void handleMessage(Message*);
// virtual CloudProcess::State getState();
// virtual void hook(State s, void* action);
virtual void update() { handleMessage(nullptr); }

inline void approveOta() { policies |= Approved; }
inline void approveOta() { this->policies |= Approved; }
inline void setOtaPolicies(uint16_t policies) { this->policies = policies; }

inline void enableOtaPolicy(OtaFlags policyFlag) { this->policies |= policyFlag; }
inline void disableOtaPolicy(OtaFlags policyFlag) { this->policies &= ~policyFlag; }
inline bool getOtaPolicy(OtaFlags policyFlag) { return (this->policies & policyFlag) != 0;}

inline State getState() { return state; }

virtual bool isOtaCapable() = 0;
Expand Down
130 changes: 89 additions & 41 deletions src/ota/interface/OTAInterfaceDefault.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,39 +41,17 @@ OTACloudProcessInterface::State OTADefaultCloudProcessInterface::startOTA() {
}
);

// make the http get request
// check url
if(strcmp(context->parsed_url.schema(), "https") == 0) {
http_client = new HttpClient(*client, context->parsed_url.host(), context->parsed_url.port());
} else {
return UrlParseErrorFail;
}

http_client->beginRequest();
auto res = http_client->get(context->parsed_url.path());

if(username != nullptr && password != nullptr) {
http_client->sendBasicAuth(username, password);
}

http_client->endRequest();

if(res == HTTP_ERROR_CONNECTION_FAILED) {
DEBUG_VERBOSE("OTA ERROR: http client error connecting to server \"%s:%d\"",
context->parsed_url.host(), context->parsed_url.port());
return ServerConnectErrorFail;
} else if(res == HTTP_ERROR_TIMED_OUT) {
DEBUG_VERBOSE("OTA ERROR: http client timeout \"%s\"", OTACloudProcessInterface::context->url);
return OtaHeaderTimeoutFail;
} else if(res != HTTP_SUCCESS) {
DEBUG_VERBOSE("OTA ERROR: http client returned %d on get \"%s\"", res, OTACloudProcessInterface::context->url);
return OtaDownloadFail;
}

int statusCode = http_client->responseStatusCode();

if(statusCode != 200) {
DEBUG_VERBOSE("OTA ERROR: get response on \"%s\" returned status %d", OTACloudProcessInterface::context->url, statusCode);
return HttpResponseFail;
// make the http get request
OTACloudProcessInterface::State res = requestOta();
if(res != Fetch) {
return res;
}

// The following call is required to save the header value , keep it
Expand All @@ -82,16 +60,27 @@ OTACloudProcessInterface::State OTADefaultCloudProcessInterface::startOTA() {
return HttpHeaderErrorFail;
}

context->contentLength = http_client->contentLength();
context->lastReportTime = millis();

DEBUG_VERBOSE("OTA file length: %d", context->contentLength);
return Fetch;
}

OTACloudProcessInterface::State OTADefaultCloudProcessInterface::fetch() {
OTACloudProcessInterface::State res = Fetch;
int http_res = 0;
uint32_t start = millis();

if(getOtaPolicy(ChunkDownload)) {
res = requestOta(ChunkDownload);
}

context->downloadedChunkSize = 0;
context->downloadedChunkStartTime = millis();

if(res != Fetch) {
goto exit;
}

/* download chunked or timed */
do {
if(!http_client->connected()) {
res = OtaDownloadFail;
Expand All @@ -104,7 +93,7 @@ OTACloudProcessInterface::State OTADefaultCloudProcessInterface::fetch() {
continue;
}

http_res = http_client->read(context->buffer, context->buf_len);
int http_res = http_client->read(context->buffer, context->bufLen);

if(http_res < 0) {
DEBUG_VERBOSE("OTA ERROR: Download read error %d", http_res);
Expand All @@ -119,8 +108,10 @@ OTACloudProcessInterface::State OTADefaultCloudProcessInterface::fetch() {
res = ErrorWriteUpdateFileFail;
goto exit;
}
} while((context->downloadState == OtaDownloadFile || context->downloadState == OtaDownloadHeader) &&
millis() - start < downloadTime);

context->downloadedChunkSize += http_res;

} while(context->downloadState < OtaDownloadCompleted && fetchMore());

// TODO verify that the information present in the ota header match the info in context
if(context->downloadState == OtaDownloadCompleted) {
Expand Down Expand Up @@ -153,13 +144,69 @@ OTACloudProcessInterface::State OTADefaultCloudProcessInterface::fetch() {
return res;
}

void OTADefaultCloudProcessInterface::parseOta(uint8_t* buffer, size_t buf_len) {
OTACloudProcessInterface::State OTADefaultCloudProcessInterface::requestOta(OtaFlags mode) {
int http_res = 0;

/* stop connected client */
http_client->stop();

/* request chunk */
http_client->beginRequest();
http_res = http_client->get(context->parsed_url.path());

if(username != nullptr && password != nullptr) {
http_client->sendBasicAuth(username, password);
}

if((mode & ChunkDownload) == ChunkDownload) {
char range[128] = {0};
size_t rangeSize = context->downloadedSize + maxChunkSize > context->contentLength ? context->contentLength - context->downloadedSize : maxChunkSize;
sprintf(range, "bytes=%lu-%lu", context->downloadedSize, context->downloadedSize + rangeSize);
DEBUG_VERBOSE("OTA downloading range: %s", range);
http_client->sendHeader("Range", range);
}

http_client->endRequest();

if(http_res == HTTP_ERROR_CONNECTION_FAILED) {
DEBUG_VERBOSE("OTA ERROR: http client error connecting to server \"%s:%d\"",
context->parsed_url.host(), context->parsed_url.port());
return ServerConnectErrorFail;
} else if(http_res == HTTP_ERROR_TIMED_OUT) {
DEBUG_VERBOSE("OTA ERROR: http client timeout \"%s\"", OTACloudProcessInterface::context->url);
return OtaHeaderTimeoutFail;
} else if(http_res != HTTP_SUCCESS) {
DEBUG_VERBOSE("OTA ERROR: http client returned %d on get \"%s\"", http_res, OTACloudProcessInterface::context->url);
return OtaDownloadFail;
}

int statusCode = http_client->responseStatusCode();

if((((mode & ChunkDownload) == ChunkDownload) && (statusCode != 206)) ||
(((mode & ChunkDownload) != ChunkDownload) && (statusCode != 200))) {
DEBUG_VERBOSE("OTA ERROR: get response on \"%s\" returned status %d", OTACloudProcessInterface::context->url, statusCode);
return HttpResponseFail;
}

http_client->skipResponseHeaders();
andreagilardoni marked this conversation as resolved.
Show resolved Hide resolved
return Fetch;
}

bool OTADefaultCloudProcessInterface::fetchMore() {
if (getOtaPolicy(ChunkDownload)) {
return context->downloadedChunkSize < maxChunkSize;
} else {
return (millis() - context->downloadedChunkStartTime) < downloadTime;
}
}

void OTADefaultCloudProcessInterface::parseOta(uint8_t* buffer, size_t bufLen) {
assert(context != nullptr); // This should never fail

for(uint8_t* cursor=(uint8_t*)buffer; cursor<buffer+buf_len; ) {
for(uint8_t* cursor=(uint8_t*)buffer; cursor<buffer+bufLen; ) {
switch(context->downloadState) {
case OtaDownloadHeader: {
const uint32_t headerLeft = context->headerCopiedBytes + buf_len <= sizeof(context->header.buf) ? buf_len : sizeof(context->header.buf) - context->headerCopiedBytes;
const uint32_t headerLeft = context->headerCopiedBytes + bufLen <= sizeof(context->header.buf) ? bufLen : sizeof(context->header.buf) - context->headerCopiedBytes;
memcpy(context->header.buf+context->headerCopiedBytes, buffer, headerLeft);
cursor += headerLeft;
context->headerCopiedBytes += headerLeft;
Expand All @@ -184,8 +231,7 @@ void OTADefaultCloudProcessInterface::parseOta(uint8_t* buffer, size_t buf_len)
break;
}
case OtaDownloadFile: {
const uint32_t contentLength = http_client->contentLength();
const uint32_t dataLeft = buf_len - (cursor-buffer);
const uint32_t dataLeft = bufLen - (cursor-buffer);
context->decoder.decompress(cursor, dataLeft); // TODO verify return value

context->calculatedCrc32 = crc_update(
Expand All @@ -198,18 +244,18 @@ void OTADefaultCloudProcessInterface::parseOta(uint8_t* buffer, size_t buf_len)
context->downloadedSize += dataLeft;

if((millis() - context->lastReportTime) > 10000) { // Report the download progress each X millisecond
DEBUG_VERBOSE("OTA Download Progress %d/%d", context->downloadedSize, contentLength);
DEBUG_VERBOSE("OTA Download Progress %d/%d", context->downloadedSize, context->contentLength);

reportStatus(context->downloadedSize);
context->lastReportTime = millis();
}

// TODO there should be no more bytes available when the download is completed
if(context->downloadedSize == contentLength) {
if(context->downloadedSize == context->contentLength) {
context->downloadState = OtaDownloadCompleted;
}

if(context->downloadedSize > contentLength) {
if(context->downloadedSize > context->contentLength) {
context->downloadState = OtaDownloadError;
}
// TODO fail if we exceed a timeout? and available is 0 (client is broken)
Expand Down Expand Up @@ -250,7 +296,9 @@ OTADefaultCloudProcessInterface::Context::Context(
, headerCopiedBytes(0)
, downloadedSize(0)
, lastReportTime(0)
, contentLength(0)
, writeError(false)
, downloadedChunkSize(0)
, decoder(putc) { }

static const uint32_t crc_table[256] = {
Expand Down
16 changes: 13 additions & 3 deletions src/ota/interface/OTAInterfaceDefault.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ class OTADefaultCloudProcessInterface: public OTACloudProcessInterface {
virtual int writeFlash(uint8_t* const buffer, size_t len) = 0;

private:
void parseOta(uint8_t* buffer, size_t buf_len);
void parseOta(uint8_t* buffer, size_t bufLen);
State requestOta(OtaFlags mode = None);
bool fetchMore();

Client* client;
HttpClient* http_client;
Expand All @@ -53,6 +55,10 @@ class OTADefaultCloudProcessInterface: public OTACloudProcessInterface {
// This mitigate the issues arising from tasks run in main loop that are using all the computing time
static constexpr uint32_t downloadTime = 2000;

// The amount of data that each iteration of Fetch has to take at least
// This should be enabled setting ChunkDownload OtaFlag to 1 and mitigate some Ota corner cases
static constexpr size_t maxChunkSize = 1024 * 10;

enum OTADownloadState: uint8_t {
OtaDownloadHeader,
OtaDownloadFile,
Expand All @@ -74,13 +80,17 @@ class OTADefaultCloudProcessInterface: public OTACloudProcessInterface {
uint32_t headerCopiedBytes;
uint32_t downloadedSize;
uint32_t lastReportTime;
uint32_t contentLength;
bool writeError;

uint32_t downloadedChunkStartTime;
uint32_t downloadedChunkSize;

// LZSS decoder
LZSSDecoder decoder;

const size_t buf_len = 64;
uint8_t buffer[64];
static constexpr size_t bufLen = 64;
uint8_t buffer[bufLen];
} *context;
};

Expand Down
Loading