master
/ miniconda3 / lib / python3.11 / site-packages / conda / gateways / repodata / jlap / fetch.py

fetch.py @74036c5 raw · history · blame

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
# Copyright (C) 2012 Anaconda, Inc
# SPDX-License-Identifier: BSD-3-Clause
# Lappin' up the jlap
from __future__ import annotations

import io
import json
import logging
import pathlib
import pprint
import re
import time
from contextlib import contextmanager
from hashlib import blake2b
from typing import Iterator

import jsonpatch
import zstandard
from requests import HTTPError

from conda.base.context import context
from conda.gateways.connection import Response, Session
from conda.gateways.repodata import (
    ETAG_KEY,
    LAST_MODIFIED_KEY,
    RepodataCache,
    RepodataState,
)

from .core import JLAP

log = logging.getLogger(__name__)


# Digest length, in bytes, passed to blake2b() by hash() below.
DIGEST_SIZE = 32  # 256 bits

# Keys used in the cached repodata state and in the .jlap footer.
# State entry that holds the jlap-specific sub-state (headers, iv, pos, footer).
JLAP_KEY = "jlap"
# Key inside the jlap sub-state holding the saved response headers.
HEADERS = "headers"
# State key: hash of the upstream repodata.json our cached copy is equivalent to.
NOMINAL_HASH = "blake2_256_nominal"
# State key: hash of the bytes actually written to disk (differs from the
# nominal hash once patches have been applied and the json re-serialized).
ON_DISK_HASH = "blake2_256"
# Footer key naming the most recent repodata hash.
LATEST = "latest"
# NOTE(review): not referenced in the visible code; presumably the jlap
# counterpart of ZSTD_UNAVAILABLE -- confirm against other modules.
JLAP_UNAVAILABLE = "jlap_unavailable"
# State key set to time.time_ns() when the .json.zst download returns 404.
ZSTD_UNAVAILABLE = "zstd_unavailable"

# Lower-cased response header names copied into the saved state. At minimum
# etag, last-modified and cache-control are needed; the rest are useful extras.
STORE_HEADERS = {
    "etag",
    "last-modified",
    "cache-control",
    "content-range",
    "content-length",
    "date",
    "content-type",
    "content-encoding",
}


def hash():
    """Return a fresh BLAKE2b hasher that produces DIGEST_SIZE-byte digests."""
    hasher = blake2b(digest_size=DIGEST_SIZE)
    return hasher


class Jlap304NotModified(Exception):
    """Raised when the .jlap response has status 304 Not Modified."""


class JlapSkipZst(Exception):
    """Control-flow signal: skip the .json.zst attempt and fetch plain .json."""


class JlapPatchNotFound(LookupError):
    """Raised when no chain of patches leads from the local revision."""


def process_jlap_response(response: Response, pos=0, iv=b""):
    """Parse a streamed .jlap response into a JLAP buffer and new cache state.

    response: HTTP response whose body is the (possibly partial) .jlap file.
    pos: Offset passed through to ``JLAP.from_lines``.
    iv: Initial checksum value passed through to ``JLAP.from_lines``.

    Returns ``(buffer, new_state)`` where ``new_state`` has the keys
    ``"headers"`` (lower-cased response headers listed in STORE_HEADERS),
    ``"iv"``, ``"pos"`` and ``"footer"`` (the decoded footer line).

    Raises Jlap304NotModified if the response status is 304.
    """
    # A 304 carries no body; callers could in principle reuse a cached footer,
    # but here it is reported as an exception instead.
    if response.status_code == 304:
        raise Jlap304NotModified()

    def body_lines() -> Iterator[bytes]:
        for raw_line in response.iter_lines(delimiter=b"\n"):  # type: ignore
            yield raw_line

    buffer = JLAP.from_lines(body_lines(), iv, pos)

    # The second-to-last entry is the footer; its position becomes the next
    # starting offset. If nothing changed upstream the new iv equals the old.
    footer_pos, footer_text, _ = buffer[-2]
    decoded_footer = json.loads(footer_text)

    # Keep only the headers we care about (etag, last-modified,
    # cache-control, ...), normalised to lower case.
    kept_headers = {}
    for name, value in response.headers.items():
        lowered = name.lower()
        if lowered in STORE_HEADERS:
            kept_headers[lowered] = value

    new_state = {
        "headers": kept_headers,
        # last field of the entry just before the footer
        "iv": buffer[-3][-1],
        "pos": footer_pos,
        "footer": decoded_footer,
    }

    return buffer, new_state


def fetch_jlap(url, pos=0, etag=None, iv=b"", ignore_etag=True, session=None):
    """Request the .jlap file at url and parse the response.

    Combines ``request_jlap`` and ``process_jlap_response``; returns the
    ``(buffer, new_state)`` pair produced by the latter.
    """
    request_options = {
        "pos": pos,
        "etag": etag,
        "ignore_etag": ignore_etag,
        "session": session,
    }
    jlap_response = request_jlap(url, **request_options)
    return process_jlap_response(jlap_response, pos=pos, iv=iv)


def request_jlap(
    url, pos=0, etag=None, ignore_etag=True, session: Session | None = None
):
    """Return the part of the remote .jlap file we are interested in.

    url: Location of the .jlap file.
    pos: When non-zero, request only the bytes from this offset onwards.
    etag: Sent as ``if-none-match`` unless ignore_etag is true.
    session: Session used for the request; must not be None.

    Returns the streamed response. Raises HTTPError for error statuses (via
    ``raise_for_status``) and for a range request answered with a status
    other than 206, 304, 404 or 416.
    """
    headers = {}
    if pos:
        headers["range"] = f"bytes={pos}-"
    if etag and not ignore_etag:
        headers["if-none-match"] = etag

    log.debug("%s %s", url, headers)

    assert session is not None

    timeout = (
        context.remote_connect_timeout_secs,
        context.remote_read_timeout_secs,
    )
    response = session.get(url, stream=True, headers=headers, timeout=timeout)
    response.raise_for_status()

    log.debug("request headers: %s", pprint.pformat(response.request.headers))
    interesting_headers = {
        name: value
        for name, value in response.headers.items()
        if name.lower() in STORE_HEADERS
    }
    log.debug("response headers: %s", pprint.pformat(interesting_headers))
    log.debug("status: %d", response.status_code)

    # 200 is also a possibility that we'd rather not deal with; if the server
    # can't do range requests, also mark jlap as unavailable. Which status
    # codes mean 'try again' instead of 'it will never work'?
    accepted_for_range = (206, 304, 404, 416)
    is_range_request = "range" in headers
    if is_range_request and response.status_code not in accepted_for_range:
        raise HTTPError(
            f"Unexpected response code for range request {response.status_code}",
            response=response,
        )

    log.info("%s", response)

    return response


def format_hash(hash):
    """Shorten a hash to its first 16 characters plus an ellipsis for display."""
    prefix = hash[:16]
    return prefix + "\N{HORIZONTAL ELLIPSIS}"


def find_patches(patches, have, want):
    """Collect the patches that lead from revision ``have`` to ``want``.

    Walks ``patches`` from newest to oldest, following each patch's ``"to"``
    and ``"from"`` hashes backwards from ``want`` until ``have`` is reached.

    Returns the matching patches ordered newest first (so the last element is
    the first one to apply). Raises JlapPatchNotFound when the chain never
    reaches ``have``.
    """
    chain = []
    target = want
    for candidate in reversed(patches):
        if have == target:
            break
        if candidate["to"] != target:
            continue
        chain.append(candidate)
        target = candidate["from"]

    if have != target:
        message = f"No patch from local revision {format_hash(have)}"
        log.debug(message)
        raise JlapPatchNotFound(message)

    return chain


def apply_patches(data, apply):
    """Apply the patches in ``apply`` to ``data`` in place.

    Patches are popped from the end of ``apply``, so the list is consumed and
    left empty. Returns None; the result is observed through the in-place
    modification of ``data``.
    """
    while apply:
        step = apply.pop()
        source = format_hash(step["from"])
        target = format_hash(step["to"])
        operations = step["patch"]
        log.debug(
            f"{source} \N{RIGHTWARDS ARROW} {target}, "
            f"{len(operations)} steps"
        )
        data = jsonpatch.JsonPatch(operations).apply(data, in_place=True)


def withext(url, ext):
    """Return url with its final ``.suffix`` replaced by ext.

    Only the last dot-plus-word-characters run at the end of url is replaced;
    a url without such a suffix is returned unchanged.
    """
    trailing_suffix = re.compile(r"(\.\w+)$")
    return trailing_suffix.sub(ext, url)


@contextmanager
def timeme(message):
    """Context manager that debug-logs how long its body took.

    The elapsed time is logged only when the body finishes without raising.
    """
    started = time.monotonic()
    yield
    elapsed = time.monotonic() - started
    log.debug("%sTook %0.02fs", message, elapsed)


def build_headers(json_path: pathlib.Path, state: RepodataState):
    """Build conditional-request headers for a cached file and its state.

    An ``if-none-match`` header is produced only when json_path exists and
    the state holds a non-empty ``"_etag"``; otherwise the result is empty.
    """
    # Without the cached file there is nothing a 304 could refer to.
    # (This could be simplified if state were required to be empty whenever
    # json_path is missing.)
    if not json_path.exists():
        return {}

    etag = state.get("_etag")
    return {"if-none-match": etag} if etag else {}


class HashWriter(io.RawIOBase):
    """Write-only stream that feeds every written chunk to a hasher.

    backing: Object with ``write()`` and ``close()`` that receives the data.
    hasher: Object with ``update()`` that is given each chunk before it is
        forwarded to ``backing``.
    """

    def __init__(self, backing, hasher):
        super().__init__()
        self.backing = backing
        self.hasher = hasher

    def writable(self):
        """Report that this stream supports write()."""
        return True

    def write(self, b: bytes):
        """Hash b, then write it to the backing stream; return its result."""
        self.hasher.update(b)
        return self.backing.write(b)

    def close(self):
        """Close the backing stream and mark this wrapper as closed.

        Calling the base-class close() keeps ``closed`` accurate and stops
        the IOBase finalizer from invoking close() on the backing stream a
        second time.
        """
        if self.closed:
            return
        try:
            self.backing.close()
        finally:
            super().close()


def download_and_hash(
    hasher,
    url,
    json_path,
    session: Session,
    state: RepodataState | None,
    is_zst=False,
    dest_path: pathlib.Path | None = None,
):
    """Download url to dest_path, passing the written bytes through hasher.update().

    hasher: Object with ``update()``; sees exactly the bytes written to disk
        (the decompressed data when is_zst is true).
    url: Resource to fetch.
    json_path: Path of old cached data (ignore etag if not exists).
    session: Session used for the request.
    state: Cached state consulted for the etag; an empty state is used when
        falsy.
    is_zst: When true, the body is zstandard-decompressed while being written.
    dest_path: Path to write new data; defaults to json_path.

    Returns the response, which may be a 304 Not Modified, in which case
    nothing is written or hashed. Raises HTTPError for error statuses.
    """
    if dest_path is None:
        dest_path = json_path
    state = state or RepodataState()
    headers = build_headers(json_path, state)
    timeout = context.remote_connect_timeout_secs, context.remote_read_timeout_secs
    response = session.get(url, stream=True, timeout=timeout, headers=headers)
    log.debug("%s %s", url, response.headers)
    response.raise_for_status()
    # Number of bytes received from iter_content (before zstd decompression).
    length = 0
    # is there a status code for which we must clear the file?
    if response.status_code == 200:
        if is_zst:
            decompressor = zstandard.ZstdDecompressor()
            # The decompressor writes its output into HashWriter, so the
            # hash covers the decompressed json rather than the .zst bytes.
            writer = decompressor.stream_writer(
                HashWriter(dest_path.open("wb"), hasher), closefd=True  # type: ignore
            )
        else:
            writer = HashWriter(dest_path.open("wb"), hasher)
        with writer as repodata:
            for block in response.iter_content(chunk_size=1 << 14):
                length += len(block)
                repodata.write(block)
    if response.request:
        log.info("Download %d bytes %r", length, response.request.headers)
    return response  # can be 304 not modified


def request_url_jlap_state(
    url,
    state: RepodataState,
    full_download=False,
    *,
    session: Session,
    cache: RepodataCache,
    temp_path: pathlib.Path,
) -> dict | None:
    """Bring the cached repodata up to date, by full download or jlap patches.

    url: URL of repodata.json; the .json.zst and .jlap URLs are derived from
        it with ``withext``.
    state: Cached state, updated in place (hashes, headers, jlap sub-state,
        format availability).
    full_download: Force a complete download instead of trying .jlap.
    session: Session used for all requests.
    cache: Provides the cached json path, the cached data (``load()``) and
        the on-disk state used to detect concurrent modification.
    temp_path: Where new repodata is written, both for full downloads and
        for patched output.

    Returns the parsed, patched repodata when patches were applied and
    written to temp_path. Returns None when a complete download was written
    to temp_path, when there was nothing to apply, or when the cache changed
    underneath us. If the local revision is not in the patch set or the
    cached json cannot be parsed, the function calls itself once with
    ``full_download=True`` and returns that result.
    """
    jlap_state = state.get(JLAP_KEY, {})
    headers = jlap_state.get(HEADERS, {})
    json_path = cache.cache_path_json

    buffer = JLAP()  # type checks

    if (
        full_download
        or not (NOMINAL_HASH in state and json_path.exists())
        or not state.should_check_format("jlap")
    ):
        hasher = hash()
        with timeme(f"Download complete {url} "):
            # Don't deal with 304 Not Modified if hash unavailable e.g. if
            # cached without jlap
            if NOMINAL_HASH not in state:
                state.pop(ETAG_KEY, None)
                state.pop(LAST_MODIFIED_KEY, None)

            try:
                if state.should_check_format("zst"):
                    response = download_and_hash(
                        hasher,
                        withext(url, ".json.zst"),
                        json_path,  # makes conditional request if exists
                        dest_path=temp_path,  # writes to
                        session=session,
                        state=state,
                        is_zst=True,
                    )
                else:
                    raise JlapSkipZst()
            except (JlapSkipZst, HTTPError) as e:
                # Only a 404 for the .zst falls back to plain json; any other
                # HTTP error is propagated.
                if isinstance(e, HTTPError) and e.response.status_code != 404:
                    raise
                if not isinstance(e, JlapSkipZst):
                    # don't update last-checked timestamp on skip
                    state.set_has_format("zst", False)
                    state[ZSTD_UNAVAILABLE] = time.time_ns()  # alternate method
                response = download_and_hash(
                    hasher,
                    withext(url, ".json"),
                    json_path,
                    dest_path=temp_path,
                    session=session,
                    state=state,
                )

            # will we use state['headers'] for caching against
            state["_mod"] = response.headers.get("last-modified")
            state["_etag"] = response.headers.get("etag")
            state["_cache_control"] = response.headers.get("cache-control")

        # was not re-hashed if 304 not modified
        if response.status_code == 200:
            state[NOMINAL_HASH] = state[ON_DISK_HASH] = hasher.hexdigest()

        have = state[NOMINAL_HASH]

        # a jlap buffer with zero patches. the buffer format is (position,
        # payload, checksum) where position is the offset from the beginning of
        # the file; payload is the leading or trailing checksum or other data;
        # and checksum is the running checksum for the file up to that point.
        buffer = JLAP([[-1, "", ""], [0, json.dumps({LATEST: have}), ""], [1, "", ""]])

    else:
        have = state[NOMINAL_HASH]
        # have_hash = state.get(ON_DISK_HASH)

        need_jlap = True
        try:
            iv_hex = jlap_state.get("iv", "")
            pos = jlap_state.get("pos", 0)
            etag = headers.get(ETAG_KEY, None)
            jlap_url = withext(url, ".jlap")
            log.debug("Fetch %s from iv=%s, pos=%s", jlap_url, iv_hex, pos)
            # wrong to read state outside of function, and totally rebuild inside
            buffer, jlap_state = fetch_jlap(
                jlap_url,
                pos=pos,
                etag=etag,
                iv=bytes.fromhex(iv_hex),
                session=session,
                ignore_etag=False,
            )
            state.set_has_format("jlap", True)
            need_jlap = False
        except ValueError:
            log.info("Checksum not OK")
        except IndexError as e:
            log.info("Incomplete file?", exc_info=e)
        except HTTPError as e:
            # If we get a 416 Requested range not satisfiable, the server-side
            # file may have been truncated and we need to fetch from 0
            if e.response.status_code == 404:
                # No .jlap on the server: remember that and download in full.
                state.set_has_format("jlap", False)
                return request_url_jlap_state(
                    url,
                    state,
                    full_download=True,
                    session=session,
                    cache=cache,
                    temp_path=temp_path,
                )
            log.exception("Requests error")

        if need_jlap:  # retry whole file, if range failed
            try:
                buffer, jlap_state = fetch_jlap(withext(url, ".jlap"), session=session)
            except (ValueError, IndexError) as e:
                log.exception("Error parsing jlap", exc_info=e)
                # a 'latest' hash that we can't achieve, triggering later error handling
                buffer = JLAP(
                    [[-1, "", ""], [0, json.dumps({LATEST: "0" * 32}), ""], [1, "", ""]]
                )
                state.set_has_format("jlap", False)

        state[JLAP_KEY] = jlap_state

    with timeme("Apply Patches "):
        # buffer[0] == previous iv
        # buffer[1:-2] == patches
        # buffer[-2] == footer = new_state["footer"]
        # buffer[-1] == trailing checksum

        patches = [json.loads(patch) for _, patch, _ in buffer.body]
        _, footer, _ = buffer.penultimate
        want = json.loads(footer)["latest"]

        try:
            apply = find_patches(patches, have, want)
            log.info(
                f"Apply {len(apply)} patches "
                f"{format_hash(have)} \N{RIGHTWARDS ARROW} {format_hash(want)}"
            )

            if apply:
                with timeme("Load "):
                    # we haven't loaded repodata yet; it could fail to parse, or
                    # have the wrong hash.
                    # if this fails, then we also need to fetch again from 0
                    repodata_json = json.loads(cache.load())
                    # XXX cache.state must equal what we started with, otherwise
                    # bail with 'repodata on disk' (indicating another process
                    # downloaded repodata.json in parallel with us)
                    if have != cache.state.get(NOMINAL_HASH):  # or check mtime_ns?
                        log.warning("repodata cache changed during jlap fetch.")
                        return None

                apply_patches(repodata_json, apply)

                with timeme("Write changed "), temp_path.open("wb") as repodata:
                    hasher = hash()
                    HashWriter(repodata, hasher).write(
                        json.dumps(repodata_json).encode("utf-8")
                    )

                    # actual hash of serialized json
                    state[ON_DISK_HASH] = hasher.hexdigest()

                    # hash of equivalent upstream json
                    state[NOMINAL_HASH] = want

                    # avoid duplicate parsing
                    return repodata_json
            else:
                assert state[NOMINAL_HASH] == want

        except (JlapPatchNotFound, json.JSONDecodeError) as e:
            if isinstance(e, JlapPatchNotFound):
                # 'have' hash not mentioned in patchset
                #
                # XXX or skip jlap at top of fn; make sure it is not
                # possible to download the complete json twice
                log.info(
                    "Current repodata.json %s not found in patchset. Re-download repodata.json",
                    format_hash(have),
                )

            assert not full_download, "Recursion error"  # pragma: no cover

            return request_url_jlap_state(
                url,
                state,
                full_download=True,
                session=session,
                cache=cache,
                temp_path=temp_path,
            )