|
1 | 1 | use ast_grep_codemod_dynamic_lang::{DynamicLang, Registration};
|
2 | 2 | use dirs::data_local_dir;
|
| 3 | +use indicatif::{ProgressBar, ProgressState, ProgressStyle}; |
3 | 4 | use reqwest;
|
| 5 | +use reqwest::header::CONTENT_LENGTH; |
4 | 6 | use serde::{Deserialize, Serialize};
|
5 | 7 | use std::{collections::HashSet, env, fmt, path::PathBuf, str::FromStr};
|
| 8 | +use tokio::io::AsyncWriteExt; |
| 9 | +use tokio_stream::StreamExt; |
6 | 10 |
|
7 | 11 | use crate::sandbox::engine::language_data::get_extensions_for_language;
|
8 | 12 |
|
@@ -54,20 +58,59 @@ pub async fn load_tree_sitter(languages: &[SupportedLanguage]) -> Result<Vec<Dyn
|
54 | 58 | "https://tree-sitter-parsers.s3.us-east-1.amazonaws.com".to_string()
|
55 | 59 | });
|
56 | 60 | let url = format!("{base_url}/tree-sitter/parsers/tree-sitter-{language}/latest/{os}-{arch}.{extension}");
|
| 61 | + |
57 | 62 | let client = reqwest::Client::builder()
|
58 | 63 | .timeout(std::time::Duration::from_secs(30))
|
59 | 64 | .build()
|
60 | 65 | .map_err(|e| format!("Failed to build HTTP client: {e}"))?;
|
| 66 | + |
| 67 | + let head_response = client |
| 68 | + .head(&url) |
| 69 | + .send() |
| 70 | + .await |
| 71 | + .map_err(|e| format!("Failed to get header from {url}: {e}"))?; |
| 72 | + |
| 73 | + let total_size = head_response |
| 74 | + .headers() |
| 75 | + .get(CONTENT_LENGTH) |
| 76 | + .and_then(|val| val.to_str().ok()?.parse().ok()) |
| 77 | + .unwrap_or(0); |
| 78 | + |
| 79 | + let progress_bar = ProgressBar::new(total_size); |
| 80 | + progress_bar.set_style( |
| 81 | + ProgressStyle::with_template( |
| 82 | + "{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({eta})" |
| 83 | + ) |
| 84 | + .unwrap() |
| 85 | + .with_key("eta", |state: &ProgressState, w: &mut dyn std::fmt::Write| { |
| 86 | + write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap() |
| 87 | + }) |
| 88 | + .progress_chars("#>-") |
| 89 | + ); |
| 90 | + |
61 | 91 | let response = client
|
62 | 92 | .get(&url)
|
63 | 93 | .send()
|
64 | 94 | .await
|
65 | 95 | .map_err(|e| format!("Failed to download from {url}: {e}"))?;
|
66 |
| - let body = response |
67 |
| - .bytes() |
| 96 | + |
| 97 | + let mut stream = response.bytes_stream(); |
| 98 | + let mut file = tokio::fs::File::create(&lib_path) |
| 99 | + .await |
| 100 | + .map_err(|e| format!("Failed to create file: {e}"))?; |
| 101 | + |
| 102 | + while let Some(chunk) = stream.next().await { |
| 103 | + let chunk = chunk.map_err(|e| format!("Stream error from {url}: {e}"))?; |
| 104 | + file.write_all(&chunk) |
| 105 | + .await |
| 106 | + .map_err(|e| format!("Write error: {e}"))?; |
| 107 | + progress_bar.inc(chunk.len() as u64); |
| 108 | + } |
| 109 | + |
| 110 | + file.flush() |
68 | 111 | .await
|
69 |
| - .map_err(|e| format!("Failed to read response from {url}: {e}"))?; |
70 |
| - std::fs::write(&lib_path, body).map_err(|e| format!("Failed to write file: {e}"))?; |
| 112 | + .map_err(|e| format!("Flush error: {e}"))?; |
| 113 | + progress_bar.finish_with_message("Downloaded successfully"); |
71 | 114 | }
|
72 | 115 | ready_langs.insert(ReadyLang {
|
73 | 116 | language: *language,
|
|
0 commit comments