Trouter-Library committed
Commit bbad13f · verified · 1 Parent(s): 84f5cf9

Update inference.py

Files changed (1)
  1. inference.py +517 -79
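The update replaces the old single-path generate() wrapper with task-specific presets (code generation, mathematical reasoning, algorithm design, debugging, code completion) and adds optional 8-bit/4-bit loading, batching, logging, and custom stop sequences. A minimal usage sketch of the new interface (method names and defaults are taken from the diff below; running it assumes torch and transformers are installed and the DeepXR/Helion-OSC weights are reachable):

    # Sketch only: mirrors the API added in this commit, not an official snippet.
    from inference import HelionOSCInference

    helion = HelionOSCInference(load_in_8bit=False)  # set True to reduce memory use
    code = helion.code_generation(
        "Write a Python function that checks whether a string is a palindrome.",
        language="python",
    )
    proof = helion.mathematical_reasoning(
        "Show that the sum of the first n odd numbers is n^2."
    )
    answers = helion.batch_generate(
        ["Reverse a linked list in Python", "Implement binary search in Python"],
        batch_size=2,
    )
    print(code, proof, answers, sep="\n\n")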
inference.py CHANGED
@@ -1,142 +1,580 @@
"""
Helion-OSC Inference Script
DeepXR/Helion-OSC - Mathematical Coding Language Model
"""

import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM
- from typing import Optional, Dict, Any


class HelionOSCInference:
-     """Inference wrapper for Helion-OSC model"""

    def __init__(
        self,
        model_name: str = "DeepXR/Helion-OSC",
        device: Optional[str] = None,
-         load_in_8bit: bool = False
    ):
        """
        Initialize the Helion-OSC model

        Args:
            model_name: HuggingFace model identifier
-             device: Device to load model on (cuda/cpu)
-             load_in_8bit: Whether to load model in 8-bit precision
        """
-         self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

-         print(f"Loading Helion-OSC on {self.device}...")

-         self.tokenizer = AutoTokenizer.from_pretrained(model_name)

-         model_kwargs = {"device_map": "auto"} if self.device == "cuda" else {}
-         if load_in_8bit:
            model_kwargs["load_in_8bit"] = True
-
-         self.model = AutoModelForCausalLM.from_pretrained(
-             model_name,
-             torch_dtype=torch.bfloat16 if self.device == "cuda" else torch.float32,
            **model_kwargs
        )

-         if self.device == "cpu":
-             self.model = self.model.to(self.device)

-         self.model.eval()
-         print("Model loaded successfully!")

    def generate(
        self,
-         prompt: str,
-         max_length: int = 512,
-         temperature: float = 0.7,
-         top_p: float = 0.95,
-         top_k: int = 50,
-         num_return_sequences: int = 1,
-         do_sample: bool = True,
        **kwargs
-     ) -> str:
        """
-         Generate code or text based on prompt

        Args:
-             prompt: Input prompt
-             max_length: Maximum length of generated text
-             temperature: Sampling temperature
-             top_p: Nucleus sampling parameter
-             top_k: Top-k sampling parameter
-             num_return_sequences: Number of sequences to generate
-             do_sample: Whether to use sampling

        Returns:
-             Generated text
        """
-         inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
-                 max_length=max_length,
-                 temperature=temperature,
-                 top_p=top_p,
-                 top_k=top_k,
-                 num_return_sequences=num_return_sequences,
-                 do_sample=do_sample,
-                 pad_token_id=self.tokenizer.eos_token_id,
-                 **kwargs
            )

-         generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-         return generated_text

-     def code_generation(self, prompt: str, max_length: int = 1024) -> str:
-         """Optimized for code generation tasks"""
        return self.generate(
            prompt,
            max_length=max_length,
-             temperature=0.7,
-             top_p=0.95,
-             do_sample=True
        )

-     def mathematical_reasoning(self, prompt: str, max_length: int = 512) -> str:
-         """Optimized for mathematical reasoning tasks"""
        return self.generate(
            prompt,
            max_length=max_length,
-             temperature=0.3,
-             top_p=0.9,
-             do_sample=False
        )


def main():
-     """Example usage"""
    # Initialize model
-     helion = HelionOSCInference()
-
-     # Example 1: Code generation
-     code_prompt = "Write a Python function to calculate the factorial of a number using recursion:"
-     print("\n=== Code Generation ===")
-     print(f"Prompt: {code_prompt}")
-     result = helion.code_generation(code_prompt)
-     print(f"Output:\n{result}\n")
-
-     # Example 2: Mathematical reasoning
-     math_prompt = "Prove that the sum of first n natural numbers is n(n+1)/2:"
-     print("\n=== Mathematical Reasoning ===")
-     print(f"Prompt: {math_prompt}")
    result = helion.mathematical_reasoning(math_prompt)
-     print(f"Output:\n{result}\n")
-
-     # Example 3: Algorithm design
-     algo_prompt = "Design an efficient algorithm to find the longest palindromic substring:"
-     print("\n=== Algorithm Design ===")
-     print(f"Prompt: {algo_prompt}")
-     result = helion.generate(algo_prompt, max_length=1024)
-     print(f"Output:\n{result}\n")


if __name__ == "__main__":
 
"""
Helion-OSC Inference Script
DeepXR/Helion-OSC - Mathematical Coding Language Model
+
+ This module provides comprehensive inference capabilities for the Helion-OSC model,
+ including specialized methods for different programming and mathematical tasks.
"""

import torch
+ import json
+ import logging
+ from typing import Optional, Dict, Any, List, Union
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForCausalLM,
+     GenerationConfig,
+     StoppingCriteria,
+     StoppingCriteriaList
+ )
+ from dataclasses import dataclass
+ import warnings
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ )
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class GenerationParameters:
+     """Parameters for text generation"""
+     max_length: int = 2048
+     temperature: float = 0.7
+     top_p: float = 0.95
+     top_k: int = 50
+     repetition_penalty: float = 1.05
+     length_penalty: float = 1.0
+     do_sample: bool = True
+     num_return_sequences: int = 1
+     early_stopping: bool = False
+
+
+ class CodeStoppingCriteria(StoppingCriteria):
+     """Custom stopping criteria for code generation"""
+
+     def __init__(self, stop_sequences: List[str], tokenizer):
+         self.stop_sequences = stop_sequences
+         self.tokenizer = tokenizer
+
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+         decoded = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
+         return any(seq in decoded for seq in self.stop_sequences)


class HelionOSCInference:
+     """
+     Comprehensive inference wrapper for Helion-OSC model
+
+     Supports multiple generation modes:
+     - Code generation
+     - Mathematical reasoning
+     - Algorithm design
+     - Code debugging
+     - Documentation generation
+     """

    def __init__(
        self,
        model_name: str = "DeepXR/Helion-OSC",
        device: Optional[str] = None,
+         load_in_8bit: bool = False,
+         load_in_4bit: bool = False,
+         use_flash_attention: bool = True,
+         trust_remote_code: bool = True
    ):
        """
        Initialize the Helion-OSC model

        Args:
            model_name: HuggingFace model identifier
+             device: Device to load model on (cuda/cpu/mps)
+             load_in_8bit: Load model in 8-bit precision
+             load_in_4bit: Load model in 4-bit precision
+             use_flash_attention: Use flash attention for faster inference
+             trust_remote_code: Trust remote code from model repository
        """
+         self.model_name = model_name
+         self.device = self._get_device(device)
+         self.load_in_8bit = load_in_8bit
+         self.load_in_4bit = load_in_4bit
+
+         logger.info(f"Initializing Helion-OSC on {self.device}...")
+
+         # Load tokenizer
+         self.tokenizer = self._load_tokenizer(trust_remote_code)
+
+         # Load model
+         self.model = self._load_model(
+             use_flash_attention=use_flash_attention,
+             trust_remote_code=trust_remote_code
+         )

+         # Load generation configs
+         self.generation_configs = self._load_generation_configs()

+         logger.info("Model loaded successfully!")
+         self._print_model_info()
+
+     def _get_device(self, device: Optional[str]) -> str:
+         """Determine the best available device"""
+         if device:
+             return device
+         if torch.cuda.is_available():
+             return "cuda"
+         elif torch.backends.mps.is_available():
+             return "mps"
+         return "cpu"
+
+     def _load_tokenizer(self, trust_remote_code: bool):
+         """Load and configure tokenizer"""
+         logger.info("Loading tokenizer...")
+         tokenizer = AutoTokenizer.from_pretrained(
+             self.model_name,
+             trust_remote_code=trust_remote_code,
+             padding_side="left"
+         )

+         # Ensure pad token is set
+         if tokenizer.pad_token is None:
+             tokenizer.pad_token = tokenizer.eos_token
+
+         return tokenizer
+
+     def _load_model(self, use_flash_attention: bool, trust_remote_code: bool):
+         """Load and configure model"""
+         logger.info("Loading model...")
+
+         model_kwargs = {
+             "trust_remote_code": trust_remote_code,
+             "low_cpu_mem_usage": True
+         }
+
+         # Configure precision and quantization
+         if self.load_in_8bit:
            model_kwargs["load_in_8bit"] = True
+             logger.info("Loading in 8-bit precision")
+         elif self.load_in_4bit:
+             model_kwargs["load_in_4bit"] = True
+             model_kwargs["bnb_4bit_compute_dtype"] = torch.bfloat16
+             model_kwargs["bnb_4bit_use_double_quant"] = True
+             model_kwargs["bnb_4bit_quant_type"] = "nf4"
+             logger.info("Loading in 4-bit precision")
+         else:
+             if self.device == "cuda":
+                 model_kwargs["torch_dtype"] = torch.bfloat16
+             else:
+                 model_kwargs["torch_dtype"] = torch.float32
+
+         # Configure device mapping
+         if self.device == "cuda" and not (self.load_in_8bit or self.load_in_4bit):
+             model_kwargs["device_map"] = "auto"
+
+         # Load model
+         model = AutoModelForCausalLM.from_pretrained(
+             self.model_name,
            **model_kwargs
        )

+         # Move to device if needed
+         if self.device != "cuda" or (self.load_in_8bit or self.load_in_4bit):
+             if not (self.load_in_8bit or self.load_in_4bit):
+                 model = model.to(self.device)
+
+         model.eval()

+         # Enable gradient checkpointing for memory efficiency if needed
+         if hasattr(model, 'gradient_checkpointing_enable'):
+             model.gradient_checkpointing_enable()
+
+         return model
+
+     def _load_generation_configs(self) -> Dict[str, GenerationParameters]:
+         """Load task-specific generation configurations"""
+         return {
+             "code_generation": GenerationParameters(
+                 max_length=4096,
+                 temperature=0.7,
+                 top_p=0.95,
+                 top_k=50,
+                 repetition_penalty=1.05,
+                 do_sample=True
+             ),
+             "mathematical_reasoning": GenerationParameters(
+                 max_length=2048,
+                 temperature=0.3,
+                 top_p=0.9,
+                 top_k=40,
+                 repetition_penalty=1.0,
+                 do_sample=False
+             ),
+             "code_completion": GenerationParameters(
+                 max_length=1024,
+                 temperature=0.6,
+                 top_p=0.92,
+                 top_k=45,
+                 repetition_penalty=1.03,
+                 do_sample=True
+             ),
+             "algorithm_design": GenerationParameters(
+                 max_length=3072,
+                 temperature=0.5,
+                 top_p=0.93,
+                 top_k=50,
+                 repetition_penalty=1.08,
+                 do_sample=True
+             ),
+             "debugging": GenerationParameters(
+                 max_length=2048,
+                 temperature=0.4,
+                 top_p=0.88,
+                 repetition_penalty=1.0,
+                 do_sample=False
+             )
+         }
+
+     def _print_model_info(self):
+         """Print model information"""
+         try:
+             num_params = sum(p.numel() for p in self.model.parameters())
+             logger.info(f"Model parameters: {num_params:,}")
+             logger.info(f"Model dtype: {next(self.model.parameters()).dtype}")
+             logger.info(f"Device: {self.device}")
+         except Exception as e:
+             logger.warning(f"Could not get model info: {e}")

    def generate(
        self,
+         prompt: Union[str, List[str]],
+         task_type: str = "code_generation",
+         custom_params: Optional[GenerationParameters] = None,
+         stop_sequences: Optional[List[str]] = None,
+         return_full_text: bool = False,
        **kwargs
+     ) -> Union[str, List[str]]:
        """
+         Generate text based on prompt

        Args:
+             prompt: Input prompt or list of prompts
+             task_type: Type of task (code_generation, mathematical_reasoning, etc.)
+             custom_params: Custom generation parameters
+             stop_sequences: List of sequences to stop generation
+             return_full_text: Whether to return full text including prompt
+             **kwargs: Additional generation parameters

        Returns:
+             Generated text or list of generated texts
        """
+         # Get generation parameters
+         if custom_params:
+             params = custom_params
+         elif task_type in self.generation_configs:
+             params = self.generation_configs[task_type]
+         else:
+             logger.warning(f"Unknown task type '{task_type}', using default parameters")
+             params = GenerationParameters()
+
+         # Override with kwargs
+         for key, value in kwargs.items():
+             if hasattr(params, key):
+                 setattr(params, key, value)

+         # Tokenize input
+         is_batch = isinstance(prompt, list)
+         inputs = self.tokenizer(
+             prompt,
+             return_tensors="pt",
+             padding=True,
+             truncation=True,
+             max_length=self.model.config.max_position_embeddings
+         ).to(self.device)
+
+         # Setup stopping criteria
+         stopping_criteria = None
+         if stop_sequences:
+             stopping_criteria = StoppingCriteriaList([
+                 CodeStoppingCriteria(stop_sequences, self.tokenizer)
+             ])
+
+         # Generate
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
+                 max_length=params.max_length,
+                 temperature=params.temperature,
+                 top_p=params.top_p,
+                 top_k=params.top_k,
+                 repetition_penalty=params.repetition_penalty,
+                 length_penalty=params.length_penalty,
+                 do_sample=params.do_sample,
+                 num_return_sequences=params.num_return_sequences,
+                 early_stopping=params.early_stopping,
+                 pad_token_id=self.tokenizer.pad_token_id,
+                 eos_token_id=self.tokenizer.eos_token_id,
+                 stopping_criteria=stopping_criteria
            )

+         # Decode outputs
+         generated_texts = []
+         for output in outputs:
+             text = self.tokenizer.decode(output, skip_special_tokens=True)
+             if not return_full_text and not is_batch:
+                 # Remove prompt from single generation
+                 if isinstance(prompt, str):
+                     text = text[len(prompt):].strip()
+             generated_texts.append(text)
+
+         return generated_texts if is_batch or params.num_return_sequences > 1 else generated_texts[0]
+
+     def code_generation(
+         self,
+         prompt: str,
+         language: Optional[str] = None,
+         max_length: int = 4096,
+         **kwargs
+     ) -> str:
+         """
+         Generate code for a given prompt
+
+         Args:
+             prompt: Code generation prompt
+             language: Programming language (optional)
+             max_length: Maximum length of generated code
+             **kwargs: Additional generation parameters
+
+         Returns:
+             Generated code
+         """
+         if language:
+             prompt = f"Language: {language}\n{prompt}"
+
+         return self.generate(
+             prompt,
+             task_type="code_generation",
+             max_length=max_length,
+             **kwargs
+         )
+
+     def mathematical_reasoning(
+         self,
+         prompt: str,
+         max_length: int = 2048,
+         **kwargs
+     ) -> str:
+         """
+         Solve mathematical problems with step-by-step reasoning
+
+         Args:
+             prompt: Mathematical problem
+             max_length: Maximum length of solution
+             **kwargs: Additional generation parameters
+
+         Returns:
+             Mathematical solution with reasoning
+         """
+         return self.generate(
+             prompt,
+             task_type="mathematical_reasoning",
+             max_length=max_length,
+             **kwargs
+         )

+ def algorithm_design(
376
+ self,
377
+ prompt: str,
378
+ include_complexity: bool = True,
379
+ max_length: int = 3072,
380
+ **kwargs
381
+ ) -> str:
382
+ """
383
+ Design algorithms with complexity analysis
384
+
385
+ Args:
386
+ prompt: Algorithm design prompt
387
+ include_complexity: Whether to include complexity analysis
388
+ max_length: Maximum length of output
389
+ **kwargs: Additional generation parameters
390
+
391
+ Returns:
392
+ Algorithm design with analysis
393
+ """
394
+ if include_complexity:
395
+ prompt += "\n\nPlease include time and space complexity analysis."
396
+
397
  return self.generate(
398
  prompt,
399
+ task_type="algorithm_design",
400
  max_length=max_length,
401
+ **kwargs
 
 
402
  )
403
 
404
+ def debug_code(
405
+ self,
406
+ code: str,
407
+ error_message: Optional[str] = None,
408
+ max_length: int = 2048,
409
+ **kwargs
410
+ ) -> str:
411
+ """
412
+ Debug code and provide fixes
413
+
414
+ Args:
415
+ code: Code to debug
416
+ error_message: Optional error message
417
+ max_length: Maximum length of output
418
+ **kwargs: Additional generation parameters
419
+
420
+ Returns:
421
+ Debugging analysis and fixes
422
+ """
423
+ prompt = f"Debug the following code:\n\n```\n{code}\n```"
424
+ if error_message:
425
+ prompt += f"\n\nError message: {error_message}"
426
+ prompt += "\n\nProvide a detailed explanation and fixed code."
427
+
428
  return self.generate(
429
  prompt,
430
+ task_type="debugging",
431
+ max_length=max_length,
432
+ **kwargs
433
+ )
434
+
435
+ def complete_code(
436
+ self,
437
+ code_context: str,
438
+ max_length: int = 1024,
439
+ **kwargs
440
+ ) -> str:
441
+ """
442
+ Complete partial code
443
+
444
+ Args:
445
+ code_context: Partial code to complete
446
+ max_length: Maximum length of completion
447
+ **kwargs: Additional generation parameters
448
+
449
+ Returns:
450
+ Code completion
451
+ """
452
+ return self.generate(
453
+ code_context,
454
+ task_type="code_completion",
455
  max_length=max_length,
456
+ stop_sequences=["\n\n", "```", "###"],
457
+ **kwargs
 
458
  )
459
+
460
+ def batch_generate(
461
+ self,
462
+ prompts: List[str],
463
+ task_type: str = "code_generation",
464
+ batch_size: int = 4,
465
+ **kwargs
466
+ ) -> List[str]:
467
+ """
468
+ Generate responses for multiple prompts in batches
469
+
470
+ Args:
471
+ prompts: List of prompts
472
+ task_type: Type of task
473
+ batch_size: Batch size for processing
474
+ **kwargs: Additional generation parameters
475
+
476
+ Returns:
477
+ List of generated responses
478
+ """
479
+ results = []
480
+ for i in range(0, len(prompts), batch_size):
481
+ batch = prompts[i:i + batch_size]
482
+ batch_results = self.generate(batch, task_type=task_type, **kwargs)
483
+ if isinstance(batch_results, str):
484
+ batch_results = [batch_results]
485
+ results.extend(batch_results)
486
+ return results
487
 

def main():
+     """Example usage and demonstrations"""
+     print("=" * 80)
+     print("Helion-OSC Inference Examples")
+     print("=" * 80)
+
    # Initialize model
+     helion = HelionOSCInference(
+         load_in_8bit=False,  # Set to True for lower memory usage
+         load_in_4bit=False   # Set to True for even lower memory usage
+     )
+
+     # Example 1: Code Generation
+     print("\n" + "=" * 80)
+     print("Example 1: Code Generation")
+     print("=" * 80)
+     code_prompt = """Write a Python function to implement a binary search tree with the following methods:
+ - insert(value): Insert a new value
+ - search(value): Search for a value
+ - delete(value): Delete a value
+ - inorder_traversal(): Return inorder traversal
+
+ Include proper documentation and type hints."""
+
+     print(f"\nPrompt:\n{code_prompt}")
+     print("\nGenerating...")
+     result = helion.code_generation(code_prompt, language="python")
+     print(f"\nGenerated Code:\n{result}")
+
+     # Example 2: Mathematical Reasoning
+     print("\n" + "=" * 80)
+     print("Example 2: Mathematical Reasoning")
+     print("=" * 80)
+     math_prompt = """Prove that the sum of the first n natural numbers equals n(n+1)/2 using mathematical induction."""
+
+     print(f"\nPrompt:\n{math_prompt}")
+     print("\nGenerating...")
    result = helion.mathematical_reasoning(math_prompt)
+     print(f"\nSolution:\n{result}")
+
+     # Example 3: Algorithm Design
+     print("\n" + "=" * 80)
+     print("Example 3: Algorithm Design")
+     print("=" * 80)
+     algo_prompt = """Design an efficient algorithm to find the longest palindromic substring in a given string."""
+
+     print(f"\nPrompt:\n{algo_prompt}")
+     print("\nGenerating...")
+     result = helion.algorithm_design(algo_prompt, include_complexity=True)
+     print(f"\nAlgorithm:\n{result}")
+
+     # Example 4: Code Debugging
+     print("\n" + "=" * 80)
+     print("Example 4: Code Debugging")
+     print("=" * 80)
+     buggy_code = """
+ def fibonacci(n):
+     if n <= 1:
+         return n
+     return fibonacci(n-1) + fibonacci(n-2)
+
+ # This is too slow for large n
+ result = fibonacci(100)
+ """
+
+     print(f"\nBuggy Code:\n{buggy_code}")
+     print("\nGenerating debugging analysis...")
+     result = helion.debug_code(buggy_code, error_message="Takes too long to compute")
+     print(f"\nDebug Analysis:\n{result}")
+
+     # Example 5: Batch Processing
+     print("\n" + "=" * 80)
+     print("Example 5: Batch Code Generation")
+     print("=" * 80)
+     batch_prompts = [
+         "Write a Python function to reverse a linked list",
+         "Write a JavaScript function to debounce API calls",
+         "Write a Rust function to parse JSON safely"
+     ]
+
+     print("\nProcessing batch prompts...")
+     results = helion.batch_generate(batch_prompts, batch_size=2)
+     for i, (prompt, result) in enumerate(zip(batch_prompts, results), 1):
+         print(f"\nPrompt {i}: {prompt}")
+         print(f"Result {i}:\n{result}\n")
+
+     print("=" * 80)
+     print("Examples completed!")
+     print("=" * 80)


if __name__ == "__main__":
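For constrained GPUs, the same class can be constructed through the quantized paths added in this commit; a brief sketch (it assumes the bitsandbytes and accelerate packages are installed, which the diff itself does not state):

    # Sketch: 4-bit load via the new load_in_4bit flag, then a short completion.
    helion = HelionOSCInference(load_in_4bit=True)
    print(helion.complete_code("def quicksort(arr):"))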