MJaheen committed
Commit fb609fe · 1 Parent(s): 713f69a

Add new features and fixes


- Add new features:
  1. Choose between multiple models
  2. Add an LCM-optimized model for faster CPU inference
  3. Add a raw prompt input mode
  4. Add a progress bar for generation

- Fix some issues:
  1. Fix CPU/GPU compatibility and add a force-CPU mode for testing
  2. Fix the About section and other documentation issues

Files changed (5)
1. README.md  +10 −18
2. src/app.py  +132 −24
3. src/model/config.py  +71 −3
4. src/model/generator.py  +150 −44
5. src/utils/image_processor.py  +81 −0
README.md CHANGED
@@ -23,11 +23,19 @@ AI-powered meme generator using Stable Diffusion and LoRA fine-tuning.
 
 ## 🌟 Features
 
+- **Multiple Model Support**: Switch between fine-tuned LoRA and base models
+  - Pepe Fine-tuned (LoRA) - Custom trained model
+  - Base SD 1.5 - Standard Stable Diffusion
+  - Dreamlike Photoreal 2.0 - Photorealistic style
+  - Openjourney v4 - Artistic Midjourney-style
+- **Raw Prompt Mode**: Use exact prompts without automatic enhancements
 - Generate **custom Pepe memes** from text prompts
 - Multiple **style presets** (happy, sad, smug, angry, etc.)
-- **Add meme text overlays** and download results
+- **Add meme text overlays** with automatic "MJ" signature
+- **Real-time progress tracking** for each generation step
 - Adjustable generation parameters (CFG, steps, seed, etc.)
-- Batch generation and meme gallery system
+- Batch generation and meme gallery system
+- **GPU & CPU compatible** with automatic optimization
 
 ---
 
@@ -54,22 +62,6 @@ pip install -r requirements.txt
 streamlit run src/app.py
 ```
 
----
-
-## 🚀 Deployment on Hugging Face Spaces
-
-This app is optimized for deployment on Hugging Face Spaces with the following fixes:
-
-- **CPU Compatibility**: Uses `torch.float32` on CPU deployments to avoid dtype errors
-- **Memory Optimization**: Automatically enables attention and VAE slicing
-- **Error Handling**: Proper exception handling for optional dependencies like xformers
-- **Docker Support**: Updated Dockerfile with Python 3.11 and necessary system packages
-
-### Deployment Fixes Applied:
-- Fixed mixed dtype errors when running on CPU-only environments
-- Removed autocast context that can cause tensor type mismatches
-- Added proper device detection and dtype selection
-- Enhanced error handling for optional GPU optimizations
 
 ---
src/app.py CHANGED
@@ -41,12 +41,33 @@ def init_session_state():
         st.session_state.generated_images = []
     if 'generation_count' not in st.session_state:
         st.session_state.generation_count = 0
+    if 'current_model' not in st.session_state:
+        st.session_state.current_model = None
 
 
 @st.cache_resource
-def load_generator():
-    """Load and cache the generator"""
-    return PepeGenerator()
+def load_generator(model_name: str = "Pepe Fine-tuned (LoRA)"):
+    """Load and cache the generator based on the selected model"""
+    config = ModelConfig()
+    model_config = config.AVAILABLE_MODELS[model_name]
+
+    # Update config with selected model settings
+    config.BASE_MODEL = model_config['base']
+    config.LORA_PATH = model_config.get('lora')
+    config.USE_LORA = model_config.get('use_lora', False)
+    config.TRIGGER_WORD = model_config.get('trigger_word', 'pepe the frog')
+
+    # LCM settings
+    config.USE_LCM = model_config.get('use_lcm', False)
+    config.LCM_LORA_PATH = model_config.get('lcm_lora')
+
+    # Log which model is being loaded
+    import logging
+    logger = logging.getLogger(__name__)
+    logger.info(f"Loading model: {model_name}")
+    logger.info(f"Base: {config.BASE_MODEL}, LoRA: {config.USE_LORA}, LCM: {config.USE_LCM}")
+
+    return PepeGenerator(config)
 
 
 def get_example_prompts():
@@ -64,14 +85,41 @@ def main():
     """Main application"""
     init_session_state()
 
+    # Sidebar (needs to be first to define selected_model)
+    st.sidebar.header("⚙️ Settings")
+
+    # Model selection
+    st.sidebar.subheader("🤖 Model Selection")
+    config = ModelConfig()
+    available_models = list(config.AVAILABLE_MODELS.keys())
+    selected_model = st.sidebar.selectbox(
+        "Choose Model",
+        available_models,
+        index=0,
+        help="Select which model to use for generation"
+    )
+
+    # Detect model change and auto-clear cache
+    if st.session_state.current_model is not None and st.session_state.current_model != selected_model:
+        st.cache_resource.clear()
+        st.sidebar.success(f"✅ Switched to: {selected_model}")
+
+    # Update current model in session state
+    st.session_state.current_model = selected_model
+
+    # Show LCM mode indicator if enabled
+    model_config = config.AVAILABLE_MODELS[selected_model]
+    if model_config.get('use_lcm', False):
+        st.sidebar.success("⚡ LCM Mode: 8x Faster! (6-8 steps optimal)")
+
     # Header
     st.title("🐸 Pepe the Frog Meme Generator")
    st.markdown("Create custom Pepe memes using AI! Powered by Stable Diffusion.")
 
-    # Sidebar
-    st.sidebar.header("⚙️ Settings")
+    st.sidebar.divider()
 
     # Style selection
+    st.sidebar.subheader("🎨 Style & Prompt")
     style_options = {
         "Default": "default",
         "😊 Happy": "happy",
@@ -88,10 +136,28 @@
     )
     style = style_options[selected_style]
 
-    # Advanced settings
+    # Raw prompt mode
+    use_raw_prompt = st.sidebar.checkbox(
+        "Raw Prompt Mode",
+        help="Use your exact prompt without trigger words or style modifiers"
+    )
+
+    # Advanced settings - adjust defaults based on LCM mode
+    is_lcm_mode = model_config.get('use_lcm', False)
+
     with st.sidebar.expander("🔧 Advanced Settings"):
-        steps = st.slider("Steps", 20, 100, 50, 5)
-        guidance = st.slider("Guidance Scale", 1.0, 20.0, 7.5, 0.5)
+        if is_lcm_mode:
+            # LCM needs fewer steps and lower guidance
+            steps = st.slider("Steps", 4, 12, 6, 1,
+                              help="⚡ LCM Mode: 4-8 steps optimal. Recommended: 6")
+            guidance = st.slider("Guidance Scale", 1.0, 2.5, 1.5, 0.1,
+                                 help="⚡ LCM Mode: Lower guidance (1.0-2.0). Recommended: 1.5")
+        else:
+            # Normal mode settings
+            steps = st.slider("Steps", 15, 50, 25, 5,
+                              help="Fewer steps = faster generation. 20-25 recommended for CPU")
+            guidance = st.slider("Guidance Scale", 1.0, 20.0, 7.5, 0.5)
+
         use_seed = st.checkbox("Fixed Seed")
         seed = st.number_input("Seed", 0, 999999, 42) if use_seed else None
 
@@ -133,7 +199,7 @@
     if st.session_state.generated_images:
         placeholder.image(
             st.session_state.generated_images[-1],
-            use_container_width=True
+            width='stretch'
         )
     else:
         placeholder.info("Your meme will appear here...")
@@ -141,40 +207,58 @@
     # Generate
     if generate and prompt:
        try:
-            generator = load_generator()
+            generator = load_generator(selected_model)
+            processor = ImageProcessor()
+
+            # Overall progress for multiple images
+            overall_progress = st.progress(0)
+            overall_status = st.empty()
 
-            progress = st.progress(0)
-            status = st.empty()
+            # Progress for current image generation steps
+            step_progress = st.progress(0)
+            step_status = st.empty()
 
            for i in range(num_vars):
-                status.text(f"Generating {i+1}/{num_vars}...")
-                progress.progress((i + 1) / num_vars)
+                overall_status.text(f"Generating image {i+1}/{num_vars}...")
 
-                # Generate
+                # Define callback for step-by-step progress
+                def progress_callback(current_step: int, total_steps: int):
+                    step_progress.progress(current_step / total_steps)
+                    step_status.text(f"Step {current_step}/{total_steps}")
+
+                # Generate with progress callback
                 image = generator.generate(
                     prompt=prompt,
                     style=style,
                     num_inference_steps=steps,
                     guidance_scale=guidance,
-                    seed=seed
+                    seed=seed,
+                    callback=progress_callback,
+                    raw_prompt=use_raw_prompt
                 )
 
                 # Add text if requested
                 if add_text and (top_text or bottom_text):
-                    processor = ImageProcessor()
                     image = processor.add_meme_text(image, top_text, bottom_text)
 
+                # Always add MJ signature
+                image = processor.add_signature(image, signature="MJaheen", font_size=10, opacity=200)
+
                 st.session_state.generated_images.append(image)
                 st.session_state.generation_count += 1
+
+                # Update overall progress
+                overall_progress.progress((i + 1) / num_vars)
 
-            progress.empty()
-            status.empty()
-
-            st.success("✅ Meme generated!")
+            # Clear progress indicators
+            overall_progress.empty()
+            overall_status.empty()
+            step_progress.empty()
+            step_status.empty()
 
             # Show result
             if num_vars == 1:
-                placeholder.image(image, use_container_width=True)
+                placeholder.image(image, width='stretch')
 
                 # Download
                 buf = io.BytesIO()
@@ -190,7 +274,7 @@
             cols = st.columns(min(num_vars, 2))
             for idx, img in enumerate(st.session_state.generated_images[-num_vars:]):
                 with cols[idx % 2]:
-                    st.image(img, use_container_width=True)
+                    st.image(img, width='stretch')
 
     except Exception as e:
         st.error(f"Error: {str(e)}")
@@ -205,7 +289,7 @@
         cols = st.columns(4)
         for idx, img in enumerate(reversed(st.session_state.generated_images[-8:])):
             with cols[idx % 4]:
-                st.image(img, width='stretch')
 
     # Footer
     st.divider()
@@ -219,6 +303,30 @@
         st.session_state.generated_images = []
         st.session_state.generation_count = 0
         st.rerun()
+
+    # Personal Information
+    st.divider()
+    st.markdown("### 👨‍💻 About the Engineer")
+    info_col1, info_col2 = st.columns(2)
+
+    with info_col1:
+        st.markdown("""
+        **Contact Information:**
+        - 📧 Email: [[email protected]](mailto:[email protected])
+        - 🔗 LinkedIn: [Mohamed Jaheen](https://www.linkedin.com/in/mohamedjaheen/)
+        """)
+
+    with info_col2:
+        st.markdown("""
+        **About this App:**
+        - Supported by WorldQuant University
+        - Built with Streamlit & Stable Diffusion
+        - Fine-tuned Pepe model available
+        - Open source and customizable
+        - MIT license
+        """)
+
+    st.caption("© 2025 - AI Meme Generator (Pepe the Frog) | Made with ❤️ using Python and MJ")
 
 
 if __name__ == "__main__":
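
Editor's note: the per-step progress bar above hangs off diffusers' `callback_on_step_end` hook. A minimal standalone sketch of the same wiring, assuming diffusers ≥ 0.23 (where this hook replaced the older `callback` argument) and a running Streamlit session; the model ID, prompt, and step count are illustrative:

```python
import streamlit as st
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

steps = 25
bar = st.progress(0)
label = st.empty()

def on_step_end(pipeline, step, timestep, callback_kwargs):
    # diffusers passes a 0-indexed step; report it 1-indexed to the UI
    bar.progress((step + 1) / steps)
    label.text(f"Step {step + 1}/{steps}")
    return callback_kwargs  # the hook must return the kwargs dict

image = pipe(
    "pepe_style_frog, happy",
    num_inference_steps=steps,
    callback_on_step_end=on_step_end,
).images[0]
```

The closure in the app's inner `progress_callback` works because the hook is invoked synchronously during the same loop iteration that defined it.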
src/model/config.py CHANGED
@@ -8,12 +8,79 @@ from typing import Optional
 class ModelConfig:
     """Model configuration parameters"""
 
-    # Model paths
-    BASE_MODEL: str ="runwayml/stable-diffusion-v1-5"
+    # Available models
+    AVAILABLE_MODELS: dict = None
+
+    def __post_init__(self):
+        if self.AVAILABLE_MODELS is None:
+            self.AVAILABLE_MODELS = {
+                "Pepe Fine-tuned (LoRA)": {
+                    "base": "runwayml/stable-diffusion-v1-5",
+                    "lora": "MJaheen/Pepe_The_Frog_model_v1_lora",
+                    "trigger_word": "pepe_style_frog",
+                    "use_lora": True,
+                    "use_lcm": False
+                },
+                "Pepe + LCM (FAST)": {
+                    "base": "runwayml/stable-diffusion-v1-5",
+                    "lora": "MJaheen/Pepe_The_Frog_model_v1_lora",
+                    "lcm_lora": "latent-consistency/lcm-lora-sdv1-5",
+                    "trigger_word": "pepe_style_frog",
+                    "use_lora": True,
+                    "use_lcm": True
+                },
+                "Base SD 1.5": {
+                    "base": "runwayml/stable-diffusion-v1-5",
+                    "lora": None,
+                    "trigger_word": "pepe the frog",
+                    "use_lora": False,
+                    "use_lcm": False
+                },
+                "Dreamlike Photoreal 2.0": {
+                    "base": "dreamlike-art/dreamlike-photoreal-2.0",
+                    "lora": None,
+                    "trigger_word": "pepe the frog",
+                    "use_lora": False,
+                    "use_lcm": False
+                },
+                "Openjourney v4": {
+                    "base": "prompthero/openjourney-v4",
+                    "lora": None,
+                    "trigger_word": "pepe the frog",
+                    "use_lora": False,
+                    "use_lcm": False
+                },
+                "Tiny SD (Fast CPU)": {
+                    "base": "segmind/tiny-sd",
+                    "lora": None,
+                    "trigger_word": "pepe the frog",
+                    "use_lora": False,
+                    "use_lcm": False
+                },
+                "Small SD (Balanced CPU)": {
+                    "base": "segmind/small-sd",
+                    "lora": None,
+                    "trigger_word": "pepe the frog",
+                    "use_lora": False,
+                    "use_lcm": False
+                }
+            }
+
+    # Default model selection
+    SELECTED_MODEL: str = "Pepe Fine-tuned (LoRA)"
+
+    # Model paths (will be set based on selection)
+    BASE_MODEL: str = "runwayml/stable-diffusion-v1-5"
     LORA_PATH: str = "MJaheen/Pepe_The_Frog_model_v1_lora"
+    USE_LORA: bool = True
+    TRIGGER_WORD: str = "pepe_style_frog"
+
+    # LCM settings
+    USE_LCM: bool = False
+    LCM_LORA_PATH: Optional[str] = None
 
     # Default generation parameters
-    DEFAULT_STEPS: int = 50
+    DEFAULT_STEPS: int = 25  # Reduced for faster CPU inference (was 50)
     DEFAULT_GUIDANCE: float = 7.5
     DEFAULT_WIDTH: int = 512
     DEFAULT_HEIGHT: int = 512
@@ -27,6 +94,7 @@ class ModelConfig:
     # Performance
     ENABLE_ATTENTION_SLICING: bool = True
     ENABLE_VAE_SLICING: bool = True
+    FORCE_CPU: bool = True  # Set to True to force CPU, False to use GPU if available
 
     # Available styles
     AVAILABLE_STYLES: tuple = (
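
Editor's note: `__post_init__` is only invoked automatically when `ModelConfig` is a `@dataclass`; the decorator isn't visible in this hunk, so it presumably sits just above the class definition. A minimal sketch of the pattern with a trimmed-down model table for illustration:

```python
from dataclasses import dataclass

@dataclass
class ModelConfig:
    # A class-level default of None avoids sharing one mutable dict across instances
    AVAILABLE_MODELS: dict = None

    def __post_init__(self):
        # Runs automatically after the generated __init__ when @dataclass is applied
        if self.AVAILABLE_MODELS is None:
            self.AVAILABLE_MODELS = {
                "Base SD 1.5": {"base": "runwayml/stable-diffusion-v1-5"},
            }

config = ModelConfig()
print(list(config.AVAILABLE_MODELS.keys()))  # ['Base SD 1.5']
```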
src/model/generator.py CHANGED
@@ -1,11 +1,12 @@
 """Pepe Meme Generator - Core generation logic"""
 
-from typing import Optional, List
+from typing import Optional, List, Callable
 import torch
-from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
+from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler, LCMScheduler
 import streamlit as st
 from PIL import Image
 import logging
+import os
 
 from .config import ModelConfig
 
@@ -14,40 +15,118 @@ logger = logging.getLogger(__name__)
 
 class PepeGenerator:
     """Main generator class for creating Pepe memes"""
-
+
     def __init__(self, config: Optional[ModelConfig] = None):
         """Initialize the generator"""
         self.config = config or ModelConfig()
-        self.device = self._get_device()
-        self.pipe = self._load_model()
+        self.device = self._get_device(self.config.FORCE_CPU)
+        self.pipe = self._load_model(
+            self.config.BASE_MODEL,
+            self.config.USE_LORA,
+            self.config.LORA_PATH,
+            self.config.FORCE_CPU,
+            self.config.USE_LCM,
+            self.config.LCM_LORA_PATH
+        )
         logger.info(f"PepeGenerator initialized on {self.device}")
-
+
     @staticmethod
     @st.cache_resource
-    def _load_model() -> StableDiffusionPipeline:
-        """Load and cache the Stable Diffusion model"""
-        logger.info("Loading Stable Diffusion model...")
-
+    def _load_model(base_model: str, use_lora: bool, lora_path: Optional[str],
+                    force_cpu: bool = False, use_lcm: bool = False,
+                    lcm_lora_path: Optional[str] = None) -> StableDiffusionPipeline:
+        """Load and cache the Stable Diffusion model with LoRA and LCM support"""
+        logger.info("="*60)
+        logger.info("LOADING NEW MODEL PIPELINE")
+        logger.info(f"Base Model: {base_model}")
+        logger.info(f"LoRA Enabled: {use_lora}")
+        if use_lora and lora_path:
+            logger.info(f"LoRA Path: {lora_path}")
+        logger.info(f"LCM Enabled: {use_lcm}")
+        if use_lcm and lcm_lora_path:
+            logger.info(f"LCM-LoRA Path: {lcm_lora_path}")
+        logger.info(f"Force CPU: {force_cpu}")
+        logger.info("="*60)
+
         # Determine appropriate dtype based on device
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        torch_dtype = torch.float16 if device == "cuda" else torch.float32
+        if force_cpu:
+            device = "cpu"
+            logger.info("🔧 FORCED CPU MODE - GPU disabled for testing")
+        else:
+            device = "cuda" if torch.cuda.is_available() else "cpu"
 
+        torch_dtype = torch.float16 if (device == "cuda" and not force_cpu) else torch.float32
+        logger.info(f"Using device: {device}, dtype: {torch_dtype}")
+
         pipe = StableDiffusionPipeline.from_pretrained(
-            ModelConfig.BASE_MODEL,
+            base_model,
             torch_dtype=torch_dtype,
             safety_checker=None,  # Disabled for meme generation - users must comply with SD license
         )
+
+        # Load LoRA weights if configured
+        if use_lora and lora_path:
+            logger.info(f"Loading LoRA weights from: {lora_path}")
+            try:
+                # Check if it's a local path or Hugging Face model ID
+                # Explicitly name it "pepe" to avoid "default_0" naming
+                if os.path.exists(lora_path):
+                    # Local path
+                    pipe.load_lora_weights(lora_path, adapter_name="pepe")
+                    logger.info("LoRA weights loaded successfully from local path")
+                elif "/" in lora_path:
+                    # Hugging Face model ID (format: username/model_name)
+                    pipe.load_lora_weights(lora_path, adapter_name="pepe")
+                    logger.info(f"✅ LoRA weights loaded successfully from Hugging Face: {lora_path}")
+                else:
+                    logger.warning(f"Invalid LoRA path format: {lora_path}")
+
+                # If not using LCM, set Pepe LoRA as the active adapter
+                if not use_lcm:
+                    pipe.set_adapters(["pepe"])
+                    logger.info("✅ Pepe LoRA active")
+            except Exception as e:
+                logger.error(f"Failed to load LoRA weights: {e}")
+                logger.info("Continuing without LoRA weights...")
 
-        # Optimize scheduler
-        pipe.scheduler = DPMSolverMultistepScheduler.from_config(
-            pipe.scheduler.config
-        )
-
+        # Load LCM-LoRA on top if configured (this enables fast inference!)
+        if use_lcm and lcm_lora_path:
+            logger.info(f"Loading LCM-LoRA from: {lcm_lora_path}")
+            try:
+                # Load LCM-LoRA as a separate adapter
+                pipe.load_lora_weights(lcm_lora_path, adapter_name="lcm")
+                logger.info("✅ LCM-LoRA loaded successfully")
+
+                # If we have both Pepe LoRA and LCM-LoRA, fuse them
+                if use_lora:
+                    logger.info("Fusing Pepe LoRA + LCM-LoRA adapters...")
+                    # Use the correct adapter names: "pepe" and "lcm"
+                    pipe.set_adapters(["pepe", "lcm"], adapter_weights=[1.0, 1.0])
+                    logger.info("✅ Both LoRAs fused successfully (pepe + lcm)")
+                else:
+                    # Only LCM, set it as active
+                    pipe.set_adapters(["lcm"])
+                    logger.info("✅ LCM-LoRA active (solo mode)")
+            except Exception as e:
+                logger.error(f"Failed to load LCM-LoRA: {e}")
+                logger.info("Continuing without LCM...")
+                use_lcm = False
+
+        # Set appropriate scheduler based on LCM mode
+        if use_lcm:
+            pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+            logger.info("⚡ Using LCM Scheduler (few-step mode)")
+        else:
+            pipe.scheduler = DPMSolverMultistepScheduler.from_config(
+                pipe.scheduler.config
+            )
+            logger.info("🔧 Using DPM Solver Scheduler (standard mode)")
+
         # Enable memory optimizations
         pipe.enable_attention_slicing()
         pipe.enable_vae_slicing()
-
-        if device == "cuda":
+
+        if device == "cuda" and not force_cpu:
             pipe = pipe.to("cuda")
             try:
                 pipe.enable_xformers_memory_efficient_attention()
@@ -56,16 +135,21 @@ class PepeGenerator:
             except Exception as e:
                 logger.warning(f"Could not enable xformers: {e}")
         else:
-            logger.info("Running on CPU - memory optimizations applied")
-
+            if force_cpu:
+                logger.info("Running on CPU - FORCED for testing")
+            else:
+                logger.info("Running on CPU - memory optimizations applied")
+
         logger.info("Model loaded successfully")
         return pipe
-
+
     @staticmethod
-    def _get_device() -> str:
+    def _get_device(force_cpu: bool = False) -> str:
         """Determine the best available device"""
+        if force_cpu:
+            return "cpu"
         return "cuda" if torch.cuda.is_available() else "cpu"
-
+
     def generate(
         self,
         prompt: str,
@@ -76,23 +160,42 @@ class PepeGenerator:
         seed: Optional[int] = None,
         width: int = 512,
         height: int = 512,
+        callback: Optional[Callable[[int, int], None]] = None,
+        raw_prompt: bool = False,
     ) -> Image.Image:
-        """Generate a single Pepe meme image"""
-
-        # Apply style preset
-        enhanced_prompt = self._apply_style_preset(prompt, style)
+        """Generate a single Pepe meme image
+
+        Args:
+            callback: Optional callback function (current_step, total_steps)
+            raw_prompt: If True, use prompt as-is without modifications
+        """
+
+        # Apply style preset or use raw prompt
+        if raw_prompt:
+            enhanced_prompt = prompt
+        else:
+            enhanced_prompt = self._apply_style_preset(prompt, style)
+
         # Set default negative prompt
         if negative_prompt is None:
             negative_prompt = self.config.DEFAULT_NEGATIVE_PROMPT
-
+
         # Set seed for reproducibility
         generator = None
         if seed is not None:
            generator = torch.Generator(device=self.device).manual_seed(seed)
-
+
         logger.info(f"Generating: {enhanced_prompt[:50]}...")
-
+        logger.debug(f"Full prompt: {enhanced_prompt}")
+        logger.debug(f"Model config - Base: {self.config.BASE_MODEL}, LoRA: {self.config.USE_LORA}")
+
+        # Create callback wrapper if provided (using new API)
+        callback_on_step_end_fn = None
+        if callback:
+            def callback_on_step_end_fn(pipe, step, timestep, callback_kwargs):
+                callback(step + 1, num_inference_steps)
+                return callback_kwargs
+
         # Generate image (removed autocast for CPU compatibility)
         output = self.pipe(
             prompt=enhanced_prompt,
@@ -102,32 +205,32 @@ class PepeGenerator:
             generator=generator,
             width=width,
             height=height,
+            callback_on_step_end=callback_on_step_end_fn,
         )
-
+
         return output.images[0]
-
+
     def generate_batch(
         self,
         prompt: str,
         num_images: int = 4,
         **kwargs
     ) -> List[Image.Image]:
-        """Generate multiple variations"""
+        """Generate multiple variations with callback support"""
         images = []
         for i in range(num_images):
             if 'seed' not in kwargs:
                 kwargs['seed'] = torch.randint(0, 100000, (1,)).item()
-
+
             image = self.generate(prompt, **kwargs)
             images.append(image)
-
+
             if 'seed' in kwargs:
                 del kwargs['seed']
-
+
         return images
-
-    @staticmethod
-    def _apply_style_preset(prompt: str, style: str) -> str:
+
+    def _apply_style_preset(self, prompt: str, style: str) -> str:
         """Apply style-specific prompt enhancements"""
         style_modifiers = {
             "happy": "cheerful, smiling, joyful",
@@ -138,11 +241,14 @@ class PepeGenerator:
             "surprised": "shocked, amazed, wide eyes",
         }
 
-        base = f"pepe the frog, {prompt}"
+        # Use trigger word from config
+        trigger_word = self.config.TRIGGER_WORD
 
+        base = f"{trigger_word}, {prompt}"
+
         if style in style_modifiers:
             base = f"{base}, {style_modifiers[style]}"
-
+
         base = f"{base}, high quality, detailed, meme art"
-
+
         return base
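
Editor's note: the heart of the LCM fast path is stacking the style LoRA and the LCM-LoRA as named adapters, then swapping in `LCMScheduler`. A condensed sketch of that composition, assuming a diffusers install with PEFT adapter support; the model IDs are the ones used in this commit, while the prompt and parameters are illustrative:

```python
import torch
from diffusers import LCMScheduler, StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float32,  # CPU-friendly dtype, as in the generator above
    safety_checker=None,
)

# Load both LoRAs under explicit adapter names, then activate them together
pipe.load_lora_weights("MJaheen/Pepe_The_Frog_model_v1_lora", adapter_name="pepe")
pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5", adapter_name="lcm")
pipe.set_adapters(["pepe", "lcm"], adapter_weights=[1.0, 1.0])

# LCM replaces the ~25-step schedule with a few-step one at low guidance
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

image = pipe(
    "pepe_style_frog, smug, high quality, detailed, meme art",
    num_inference_steps=6,
    guidance_scale=1.5,
).images[0]
```

Low guidance matters here: LCM distillation bakes most of the classifier-free guidance into the weights, and the LCM-LoRA authors recommend values in roughly the 1.0-2.0 range.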
src/utils/image_processor.py CHANGED
@@ -72,6 +72,87 @@ class ImageProcessor:
         # Draw main text
         draw.text(position, text, font=font, fill="white", anchor="mm")
 
+    @staticmethod
+    def add_signature(
+        image: Image.Image,
+        signature: str = "MJ",
+        position: str = "bottom-right",
+        font_size: int = 20,
+        opacity: int = 180,
+    ) -> Image.Image:
+        """Add a small signature/watermark to the image
+
+        Args:
+            image: Input image
+            signature: Text to add as signature
+            position: Position of signature (bottom-right, bottom-left, top-right, top-left)
+            font_size: Size of the signature font
+            opacity: Opacity of the signature (0-255)
+        """
+        img = image.copy()
+
+        # Create a transparent overlay
+        overlay = Image.new('RGBA', img.size, (255, 255, 255, 0))
+        draw = ImageDraw.Draw(overlay)
+
+        # Load font
+        try:
+            font = ImageFont.truetype("arial.ttf", font_size)
+        except:
+            try:
+                font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", font_size)
+            except:
+                font = ImageFont.load_default()
+                logger.warning("Using default font for signature")
+
+        # Calculate text size and position
+        bbox = draw.textbbox((0, 0), signature, font=font)
+        text_width = bbox[2] - bbox[0]
+        text_height = bbox[3] - bbox[1]
+
+        padding = 10
+
+        if position == "bottom-right":
+            x = img.width - text_width - padding
+            y = img.height - text_height - padding
+        elif position == "bottom-left":
+            x = padding
+            y = img.height - text_height - padding
+        elif position == "top-right":
+            x = img.width - text_width - padding
+            y = padding
+        elif position == "top-left":
+            x = padding
+            y = padding
+        else:
+            x = img.width - text_width - padding
+            y = img.height - text_height - padding
+
+        # Draw signature with semi-transparent background
+        bg_padding = 5
+        draw.rectangle(
+            [x - bg_padding, y - bg_padding,
+             x + text_width + bg_padding, y + text_height + bg_padding],
+            fill=(0, 0, 0, opacity // 2)
+        )
+
+        # Draw text
+        draw.text((x, y), signature, font=font, fill=(255, 255, 255, opacity))
+
+        # Convert to RGBA if needed and composite
+        if img.mode != 'RGBA':
+            img = img.convert('RGBA')
+
+        img = Image.alpha_composite(img, overlay)
+
+        # Convert back to RGB
+        if img.mode == 'RGBA':
+            rgb_img = Image.new('RGB', img.size, (255, 255, 255))
+            rgb_img.paste(img, mask=img.split()[3])
+            return rgb_img
+
+        return img
+
     @staticmethod
     def enhance_image(
         image: Image.Image,
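
Editor's note: for reference, a hypothetical call mirroring how `src/app.py` uses the new helper; the import path and file names are assumed from the repo layout:

```python
from PIL import Image

from src.utils.image_processor import ImageProcessor  # assumed import path

img = Image.open("pepe.png")

# Static method: callable without instantiating ImageProcessor
signed = ImageProcessor.add_signature(img, signature="MJaheen", font_size=10, opacity=200)
signed.save("pepe_signed.png")
```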